Full Code of ali-vilab/ACE for AI

Repository: ali-vilab/ACE
Branch: main
Commit: 886bf9510b85
Files: 42
Total size: 269.5 KB

Directory structure:
gitextract_gx2sgdse/
├── .gitignore
├── LICENSE
├── __init__.py
├── chatbot/
│   ├── __init__.py
│   ├── ace_inference.py
│   ├── config/
│   │   ├── chatbot_ui.yaml
│   │   └── models/
│   │       └── ace_0.6b_512.yaml
│   ├── example.py
│   ├── infer.py
│   ├── run_gradio.py
│   └── utils.py
├── config/
│   ├── inference_config/
│   │   ├── chatbot_ui.yaml
│   │   └── models/
│   │       ├── ace_0.6b_1024.yaml
│   │       ├── ace_0.6b_1024_refiner.yaml
│   │       └── ace_0.6b_512.yaml
│   └── train_config/
│       ├── ace_0.6b_1024_train.yaml
│       └── ace_0.6b_512_train.yaml
├── modules/
│   ├── __init__.py
│   ├── data/
│   │   ├── __init__.py
│   │   └── dataset/
│   │       ├── __init__.py
│   │       └── dataset.py
│   ├── inference/
│   │   └── __init__.py
│   ├── model/
│   │   ├── __init__.py
│   │   ├── backbone/
│   │   │   ├── __init__.py
│   │   │   ├── ace.py
│   │   │   ├── layers.py
│   │   │   └── pos_embed.py
│   │   ├── diffusion/
│   │   │   ├── __init__.py
│   │   │   ├── diffusions.py
│   │   │   ├── samplers.py
│   │   │   └── schedules.py
│   │   ├── embedder/
│   │   │   ├── __init__.py
│   │   │   └── embedder.py
│   │   ├── network/
│   │   │   ├── __init__.py
│   │   │   └── ldm_ace.py
│   │   └── utils/
│   │       └── basic_utils.py
│   └── solver/
│       ├── __init__.py
│       └── ace_solver.py
├── readme.md
├── requirements.txt
└── tools/
    ├── run_inference.py
    └── run_train.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
*.pyc
*.pth
*.pt
*.pkl
*.ckpt
*.DS_Store
*__pycache__*
*.cache*
*.bin
*.idea
*.csv
cache
build
dist
dev
scepter.egg-info
.readthedocs.yml
*resources
*.ipynb_checkpoints*
*.vscode


================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: __init__.py
================================================
from . import modules
from . import chatbot

================================================
FILE: chatbot/__init__.py
================================================


================================================
FILE: chatbot/ace_inference.py
================================================
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import math
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from PIL import Image
import torchvision.transforms as T
from scepter.modules.model.registry import DIFFUSIONS
from scepter.modules.model.utils.basic_utils import check_list_of_list
from scepter.modules.model.utils.basic_utils import \
    pack_imagelist_into_tensor_v2 as pack_imagelist_into_tensor
from scepter.modules.model.utils.basic_utils import (
    to_device, unpack_tensor_into_imagelist)
from scepter.modules.utils.distribute import we
from scepter.modules.utils.logger import get_logger

from scepter.modules.inference.diffusion_inference import DiffusionInference, get_model


def process_edit_image(images,
                       masks,
                       tasks,
                       max_seq_len=1024,
                       max_aspect_ratio=4,
                       d=16,
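                       **kwargs):
    """Prepare edit images and masks as normalized tensors.

    Each image is center-cropped to at most ``max_aspect_ratio``, downscaled
    so that (H / d) * (W / d) stays within ``max_seq_len``, and floored to a
    multiple of ``d``. Masks are binarized at 128; an empty mask becomes all
    ones. For inpainting-style tasks the masked pixels are set to -1.
    """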
    if not isinstance(images, list):
        images = [images]
    if not isinstance(masks, list):
        masks = [masks]
    if not isinstance(tasks, list):
        tasks = [tasks]

    img_tensors = []
    mask_tensors = []
    for img, mask, task in zip(images, masks, tasks):
        if mask is None or mask == '':
            mask = Image.new('L', img.size, 0)
        W, H = img.size
        if H / W > max_aspect_ratio:
            img = TF.center_crop(img, [int(max_aspect_ratio * W), W])
            mask = TF.center_crop(mask, [int(max_aspect_ratio * W), W])
        elif W / H > max_aspect_ratio:
            img = TF.center_crop(img, [H, int(max_aspect_ratio * H)])
            mask = TF.center_crop(mask, [H, int(max_aspect_ratio * H)])

        H, W = img.height, img.width
        scale = min(1.0, math.sqrt(max_seq_len / ((H / d) * (W / d))))
        rH = int(H * scale) // d * d  # ensure divisible by d
        rW = int(W * scale) // d * d

        img = TF.resize(img, (rH, rW),
                        interpolation=TF.InterpolationMode.BICUBIC)
        mask = TF.resize(mask, (rH, rW),
                         interpolation=TF.InterpolationMode.NEAREST_EXACT)

        mask = np.asarray(mask)
        mask = np.where(mask > 128, 1, 0)
        mask = mask.astype(
            np.float32) if np.any(mask) else np.ones_like(mask).astype(
                np.float32)

        img_tensor = TF.to_tensor(img).to(we.device_id)
        img_tensor = TF.normalize(img_tensor,
                                  mean=[0.5, 0.5, 0.5],
                                  std=[0.5, 0.5, 0.5])
        mask_tensor = TF.to_tensor(mask).to(we.device_id)
        if task in ['inpainting', 'Try On', 'Inpainting']:
            mask_indicator = mask_tensor.repeat(3, 1, 1)
            img_tensor[mask_indicator == 1] = -1.0
        img_tensors.append(img_tensor)
        mask_tensors.append(mask_tensor)
    return img_tensors, mask_tensors
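

# --- Illustrative sketch (not part of the upstream ACE code) ---
# The resize step in process_edit_image shrinks an image so that its
# (H / d) * (W / d) patch count fits within max_seq_len, then floors both
# sides to a multiple of d. The hypothetical helper below isolates that
# arithmetic, assuming the same defaults (max_seq_len=1024, d=16), so it
# can be checked without PIL or torch. The name _sketch_fit_resolution is
# an assumption made for illustration only.
import math


def _sketch_fit_resolution(H, W, max_seq_len=1024, d=16):
    """Return (rH, rW) following the resize arithmetic above."""
    # Never upscale: the scale factor is capped at 1.0.
    scale = min(1.0, math.sqrt(max_seq_len / ((H / d) * (W / d))))
    rH = int(H * scale) // d * d  # floor to a multiple of d
    rW = int(W * scale) // d * d
    return rH, rW


# Usage: a 1024x1024 input has 64 * 64 = 4096 patches, four times the
# default budget, so each side is halved and
# _sketch_fit_resolution(1024, 1024) gives (512, 512).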


class TextEmbedding(nn.Module):
    def __init__(self, embedding_shape):
        super().__init__()
        self.pos = nn.Parameter(data=torch.zeros(embedding_shape))


class RefinerInference(DiffusionInference):
    def init_from_cfg(self, cfg):
        self.use_dynamic_model = cfg.get('USE_DYNAMIC_MODEL', True)
        super().init_from_cfg(cfg)
        self.diffusion = DIFFUSIONS.build(cfg.MODEL.DIFFUSION, logger=self.logger) \
            if cfg.MODEL.have('DIFFUSION') else None
        self.max_seq_length = cfg.MODEL.get("MAX_SEQ_LENGTH", 4096)
        assert self.diffusion is not None
        if not self.use_dynamic_model:
            self.dynamic_load(self.first_stage_model, 'first_stage_model')
            self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
            self.dynamic_load(self.diffusion_model, 'diffusion_model')

    @torch.no_grad()
    def encode_first_stage(self, x, **kwargs):
        _, dtype = self.get_function_info(self.first_stage_model, 'encode')
        with torch.autocast('cuda',
                            enabled=dtype in ('float16', 'bfloat16'),
                            dtype=getattr(torch, dtype)):
            def run_one_image(u):
                zu = get_model(self.first_stage_model).encode(u)
                if isinstance(zu, (tuple, list)):
                    zu = zu[0]
                return zu
            # Tensor.dim is a method; add a batch axis only to 3-D (C, H, W) inputs.
            z = [run_one_image(u.unsqueeze(0) if u.dim() == 3 else u) for u in x]
            return z

    def upscale_resize(self, image, interpolation=T.InterpolationMode.BILINEAR):
        c, H, W = image.shape
        scale = max(1.0, math.sqrt(self.max_seq_length / ((H / 16) * (W / 16))))
        rH = int(H * scale) // 16 * 16  # ensure divisible by 16
        rW = int(W * scale) // 16 * 16
        image = T.Resize((rH, rW), interpolation=interpolation, antialias=True)(image)
        return image

    @torch.no_grad()
    def decode_first_stage(self, z):
        _, dtype = self.get_function_info(self.first_stage_model, 'decode')
        with torch.autocast('cuda',
                            enabled=dtype in ('float16', 'bfloat16'),
                            dtype=getattr(torch, dtype)):
            return [get_model(self.first_stage_model).decode(zu) for zu in z]

    def noise_sample(self, num_samples, h, w, seed, device=None, dtype=torch.bfloat16):
        noise = torch.randn(
            num_samples,
            16,
            # allow for packing
            2 * math.ceil(h / 16),
            2 * math.ceil(w / 16),
            device=device,
            dtype=dtype,
            generator=torch.Generator(device=device).manual_seed(seed),
        )
        return noise

    def refine(self,
               x_samples=None,
               prompt=None,
               reverse_scale=-1.,
               seed=2024,
               **kwargs):
        print(prompt)
        value_input = copy.deepcopy(self.input)
        x_samples = [self.upscale_resize(x) for x in x_samples]

        noise = []
        for x in x_samples:
            noise_ = self.noise_sample(1, x.shape[1],
                                       x.shape[2], seed,
                                       device=x.device)
            noise.append(noise_)
        noise, x_shapes = pack_imagelist_into_tensor(noise)
        if reverse_scale > 0:
            self.dynamic_load(self.first_stage_model, 'first_stage_model')
            x_samples = [x.unsqueeze(0) for x in x_samples]
            x_start = self.encode_first_stage(x_samples, **kwargs)
            self.dynamic_unload(self.first_stage_model,
                                'first_stage_model',
                                skip_loaded=not self.use_dynamic_model)
            x_start, _ = pack_imagelist_into_tensor(x_start)
        else:
            x_start = None
        # cond stage
        self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
        function_name, dtype = self.get_function_info(self.cond_stage_model)
        with torch.autocast('cuda',
                            enabled=dtype == 'float16',
                            dtype=getattr(torch, dtype)):
            ctx = getattr(get_model(self.cond_stage_model),
                          function_name)(prompt)
            ctx["x_shapes"] = x_shapes
        self.dynamic_unload(self.cond_stage_model,
                            'cond_stage_model',
                            skip_loaded=not self.use_dynamic_model)


        self.dynamic_load(self.diffusion_model, 'diffusion_model')
        # UNet use input n_prompt
        function_name, dtype = self.get_function_info(
            self.diffusion_model)
        with torch.autocast('cuda',
                            enabled=dtype in ('float16', 'bfloat16'),
                            dtype=getattr(torch, dtype)):
            solver_sample = value_input.get('sample', 'flow_euler')
            sample_steps = value_input.get('sample_steps', 20)
            guide_scale = value_input.get('guide_scale', 3.5)
            if guide_scale is not None:
                guide_scale = torch.full((noise.shape[0],), guide_scale, device=noise.device,
                                         dtype=noise.dtype)
            else:
                guide_scale = None
            latent = self.diffusion.sample(
                noise=noise,
                sampler=solver_sample,
                model=get_model(self.diffusion_model),
                model_kwargs={"cond": ctx, "guidance": guide_scale},
                steps=sample_steps,
                show_progress=True,
                guide_scale=guide_scale,
                return_intermediate=None,
                reverse_scale=reverse_scale,
                x=x_start,
                **kwargs).float()
        latent = unpack_tensor_into_imagelist(latent, x_shapes)
        self.dynamic_unload(self.diffusion_model,
                            'diffusion_model',
                            skip_loaded=not self.use_dynamic_model)
        self.dynamic_load(self.first_stage_model, 'first_stage_model')
        x_samples = self.decode_first_stage(latent)
        self.dynamic_unload(self.first_stage_model,
                            'first_stage_model',
                            skip_loaded=not self.use_dynamic_model)
        return x_samples
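

# --- Illustrative sketch (not part of the upstream ACE code) ---
# RefinerInference.noise_sample draws noise with 16 channels and spatial
# dimensions 2 * ceil(h / 16) by 2 * ceil(w / 16); the source comment says
# the doubling is there to "allow for packing". The hypothetical helper
# below reproduces only that shape arithmetic in plain Python, so the
# latent size for a given image size can be checked without torch. The
# name _sketch_refiner_noise_shape is an assumption for illustration.
import math


def _sketch_refiner_noise_shape(num_samples, h, w):
    """Return the noise shape used for an image of height h and width w."""
    return (num_samples, 16, 2 * math.ceil(h / 16), 2 * math.ceil(w / 16))


# Usage: a 1024x768 image maps to a (1, 16, 128, 96) noise tensor, and a
# height of 1000 rounds up to 63 blocks of 16, giving 126 latent rows.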


class ACEInference(DiffusionInference):
    def __init__(self, logger=None):
        if logger is None:
            logger = get_logger(name='scepter')
        self.logger = logger
        self.loaded_model = {}
        self.loaded_model_name = [
            'diffusion_model', 'first_stage_model', 'cond_stage_model'
        ]

    def init_from_cfg(self, cfg):
        self.name = cfg.NAME
        self.is_default = cfg.get('IS_DEFAULT', False)
        self.use_dynamic_model = cfg.get('USE_DYNAMIC_MODEL', True)
        module_paras = self.load_default(cfg.get('DEFAULT_PARAS', None))
        assert cfg.have('MODEL')

        self.diffusion_model = self.infer_model(
            cfg.MODEL.DIFFUSION_MODEL, module_paras.get(
                'DIFFUSION_MODEL',
                None)) if cfg.MODEL.have('DIFFUSION_MODEL') else None
        self.first_stage_model = self.infer_model(
            cfg.MODEL.FIRST_STAGE_MODEL,
            module_paras.get(
                'FIRST_STAGE_MODEL',
                None)) if cfg.MODEL.have('FIRST_STAGE_MODEL') else None
        self.cond_stage_model = self.infer_model(
            cfg.MODEL.COND_STAGE_MODEL,
            module_paras.get(
                'COND_STAGE_MODEL',
                None)) if cfg.MODEL.have('COND_STAGE_MODEL') else None

        self.refiner_model_cfg = cfg.get('REFINER_MODEL', None)
        # self.refiner_scale = cfg.get('REFINER_SCALE', 0.)
        # self.refiner_prompt = cfg.get('REFINER_PROMPT', "")
        self.ace_prompt = cfg.get("ACE_PROMPT", [])
        if self.refiner_model_cfg:
            self.refiner_model_cfg.USE_DYNAMIC_MODEL = self.use_dynamic_model
            self.refiner_module = RefinerInference(self.logger)
            self.refiner_module.init_from_cfg(self.refiner_model_cfg)
        else:
            self.refiner_module = None

        self.diffusion = DIFFUSIONS.build(cfg.MODEL.DIFFUSION,
                                          logger=self.logger)


        self.interpolate_func = lambda x: (F.interpolate(
            x.unsqueeze(0),
            scale_factor=1 / self.size_factor,
            mode='nearest-exact') if x is not None else None)
        self.text_indentifers = cfg.MODEL.get('TEXT_IDENTIFIER', [])
        self.use_text_pos_embeddings = cfg.MODEL.get('USE_TEXT_POS_EMBEDDINGS',
                                                     False)
        if self.use_text_pos_embeddings:
            self.text_position_embeddings = TextEmbedding(
                (10, 4096)).eval().requires_grad_(False).to(we.device_id)
        else:
            self.text_position_embeddings = None

        self.max_seq_len = cfg.MODEL.DIFFUSION_MODEL.MAX_SEQ_LEN
        self.scale_factor = cfg.get('SCALE_FACTOR', 0.18215)
        self.size_factor = cfg.get('SIZE_FACTOR', 8)
        self.decoder_bias = cfg.get('DECODER_BIAS', 0)
        self.default_n_prompt = cfg.get('DEFAULT_N_PROMPT', '')
        if not self.use_dynamic_model:
            self.dynamic_load(self.first_stage_model, 'first_stage_model')
            self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
            self.dynamic_load(self.diffusion_model, 'diffusion_model')

    @torch.no_grad()
    def encode_first_stage(self, x, **kwargs):
        _, dtype = self.get_function_info(self.first_stage_model, 'encode')
        with torch.autocast('cuda',
                            enabled=(dtype != 'float32'),
                            dtype=getattr(torch, dtype)):
            z = [
                self.scale_factor * get_model(self.first_stage_model)._encode(
                    i.unsqueeze(0).to(getattr(torch, dtype))) for i in x
            ]
        return z

    @torch.no_grad()
    def decode_first_stage(self, z):
        _, dtype = self.get_function_info(self.first_stage_model, 'decode')
        with torch.autocast('cuda',
                            enabled=(dtype != 'float32'),
                            dtype=getattr(torch, dtype)):
            x = [
                get_model(self.first_stage_model)._decode(
                    1. / self.scale_factor * i.to(getattr(torch, dtype)))
                for i in z
            ]
        return x



    @torch.no_grad()
    def __call__(self,
                 image=None,
                 mask=None,
                 prompt='',
                 task=None,
                 negative_prompt='',
                 output_height=512,
                 output_width=512,
                 sampler='ddim',
                 sample_steps=20,
                 guide_scale=4.5,
                 guide_rescale=0.5,
                 seed=-1,
                 history_io=None,
                 tar_index=0,
                 **kwargs):
        input_image, input_mask = image, mask
        g = torch.Generator(device=we.device_id)
        seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)
        g.manual_seed(int(seed))
        if input_image is not None:
            # assert isinstance(input_image, list) and isinstance(input_mask, list)
            if task is None:
                task = [''] * len(input_image)
            if not isinstance(prompt, list):
                prompt = [prompt] * len(input_image)
            if history_io is not None and len(history_io) > 0:
                his_image, his_mask, his_prompt, his_task = history_io[
                    'image'], history_io['mask'], history_io[
                        'prompt'], history_io['task']
                assert len(his_image) == len(his_mask) == len(
                    his_prompt) == len(his_task)
                input_image = his_image + input_image
                input_mask = his_mask + input_mask
                task = his_task + task
                prompt = his_prompt + [prompt[-1]]
                prompt = [
                    pp.replace('{image}', f'{{image{i}}}') if i > 0 else pp
                    for i, pp in enumerate(prompt)
                ]

            edit_image, edit_image_mask = process_edit_image(
                input_image, input_mask, task, max_seq_len=self.max_seq_len)

            image, image_mask = edit_image[tar_index], edit_image_mask[
                tar_index]
            edit_image, edit_image_mask = [edit_image], [edit_image_mask]

        else:
            edit_image = edit_image_mask = [[]]
            image = torch.zeros(
                size=[3, int(output_height),
                      int(output_width)])
            image_mask = torch.ones(
                size=[1, int(output_height),
                      int(output_width)])
            if not isinstance(prompt, list):
                prompt = [prompt]

        image, image_mask, prompt = [image], [image_mask], [prompt]
        assert check_list_of_list(prompt) and check_list_of_list(
            edit_image) and check_list_of_list(edit_image_mask)
        # Assign Negative Prompt
        if isinstance(negative_prompt, list):
            negative_prompt = negative_prompt[0]
        assert isinstance(negative_prompt, str)

        n_prompt = copy.deepcopy(prompt)
        for nn_p_id, nn_p in enumerate(n_prompt):
            assert isinstance(nn_p, list)
            n_prompt[nn_p_id][-1] = negative_prompt

        is_txt_image = sum([len(e_i) for e_i in edit_image]) < 1
        image = to_device(image)

        refiner_scale = kwargs.pop("refiner_scale", 0.0)
        refiner_prompt = kwargs.pop("refiner_prompt", "")
        use_ace = kwargs.pop("use_ace", True)
        # For text-only input, ACE is the generator only when refiner_scale <= 0.
        if use_ace and (not is_txt_image or refiner_scale <= 0):
            ctx, null_ctx = {}, {}
            # Get Noise Shape
            self.dynamic_load(self.first_stage_model, 'first_stage_model')
            x = self.encode_first_stage(image)
            self.dynamic_unload(self.first_stage_model,
                                'first_stage_model',
                                skip_loaded=not self.use_dynamic_model)
            noise = [
                torch.empty(*i.shape, device=we.device_id).normal_(generator=g)
                for i in x
            ]
            noise, x_shapes = pack_imagelist_into_tensor(noise)
            ctx['x_shapes'] = null_ctx['x_shapes'] = x_shapes

            image_mask = to_device(image_mask, strict=False)
            cond_mask = [self.interpolate_func(i) for i in image_mask
                         ] if image_mask is not None else [None] * len(image)
            ctx['x_mask'] = null_ctx['x_mask'] = cond_mask

            # Encode Prompt
            self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
            function_name, dtype = self.get_function_info(self.cond_stage_model)
            cont, cont_mask = getattr(get_model(self.cond_stage_model),
                                      function_name)(prompt)
            cont, cont_mask = self.cond_stage_embeddings(prompt, edit_image, cont,
                                                         cont_mask)
            null_cont, null_cont_mask = getattr(get_model(self.cond_stage_model),
                                                function_name)(n_prompt)
            null_cont, null_cont_mask = self.cond_stage_embeddings(
                prompt, edit_image, null_cont, null_cont_mask)
            self.dynamic_unload(self.cond_stage_model,
                                'cond_stage_model',
                                skip_loaded=not self.use_dynamic_model)
            ctx['crossattn'] = cont
            null_ctx['crossattn'] = null_cont

            # Encode Edit Images
            self.dynamic_load(self.first_stage_model, 'first_stage_model')
            edit_image = [to_device(i, strict=False) for i in edit_image]
            edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
            e_img, e_mask = [], []
            for u, m in zip(edit_image, edit_image_mask):
                if u is None:
                    continue
                if m is None:
                    m = [None] * len(u)
                e_img.append(self.encode_first_stage(u, **kwargs))
                e_mask.append([self.interpolate_func(i) for i in m])
            self.dynamic_unload(self.first_stage_model,
                                'first_stage_model',
                                skip_loaded=not self.use_dynamic_model)
            null_ctx['edit'] = ctx['edit'] = e_img
            null_ctx['edit_mask'] = ctx['edit_mask'] = e_mask

            # Diffusion Process
            self.dynamic_load(self.diffusion_model, 'diffusion_model')
            function_name, dtype = self.get_function_info(self.diffusion_model)
            with torch.autocast('cuda',
                                enabled=dtype in ('float16', 'bfloat16'),
                                dtype=getattr(torch, dtype)):
                latent = self.diffusion.sample(
                    noise=noise,
                    sampler=sampler,
                    model=get_model(self.diffusion_model),
                    model_kwargs=[{
                        'cond':
                        ctx,
                        'mask':
                        cont_mask,
                        'text_position_embeddings':
                        self.text_position_embeddings.pos if hasattr(
                            self.text_position_embeddings, 'pos') else None
                    }, {
                        'cond':
                        null_ctx,
                        'mask':
                        null_cont_mask,
                        'text_position_embeddings':
                        self.text_position_embeddings.pos if hasattr(
                            self.text_position_embeddings, 'pos') else None
                    }] if guide_scale is not None and guide_scale > 1 else {
                        # No guidance: run a single pass on the positive
                        # context, which is what cont_mask belongs to.
                        'cond':
                        ctx,
                        'mask':
                        cont_mask,
                        'text_position_embeddings':
                        self.text_position_embeddings.pos if hasattr(
                            self.text_position_embeddings, 'pos') else None
                    },
                    steps=sample_steps,
                    show_progress=True,
                    seed=seed,
                    guide_scale=guide_scale,
                    guide_rescale=guide_rescale,
                    return_intermediate=None,
                    **kwargs)
            if self.use_dynamic_model:
                self.dynamic_unload(self.diffusion_model,
                                    'diffusion_model',
                                    skip_loaded=not self.use_dynamic_model)

            # Decode to Pixel Space
            self.dynamic_load(self.first_stage_model, 'first_stage_model')
            samples = unpack_tensor_into_imagelist(latent, x_shapes)
            x_samples = self.decode_first_stage(samples)
            self.dynamic_unload(self.first_stage_model,
                                'first_stage_model',
                                skip_loaded=not self.use_dynamic_model)
            x_samples = [x.squeeze(0) for x in x_samples]
        else:
            x_samples = image
        if self.refiner_module and refiner_scale > 0:
            if is_txt_image:
                random.shuffle(self.ace_prompt)
                input_refine_prompt = [self.ace_prompt[0] + refiner_prompt if p[0] == "" else p[0] for p in prompt]
                input_refine_scale = -1.
            else:
                input_refine_prompt = [p[0].replace("{image}", "") + " " + refiner_prompt for p in prompt]
                input_refine_scale = refiner_scale
                print(input_refine_prompt)

            x_samples = self.refiner_module.refine(
                x_samples,
                reverse_scale=input_refine_scale,
                prompt=input_refine_prompt,
                seed=seed,
                use_dynamic_model=self.use_dynamic_model)

        imgs = [
            torch.clamp((x_i.float() + 1.0) / 2.0 + self.decoder_bias / 255,
                        min=0.0,
                        max=1.0).squeeze(0).permute(1, 2, 0).cpu().numpy()
            for x_i in x_samples
        ]
        imgs = [Image.fromarray((img * 255).astype(np.uint8)) for img in imgs]
        return imgs

    def cond_stage_embeddings(self, prompt, edit_image, cont, cont_mask):
        if self.use_text_pos_embeddings and not torch.sum(
                self.text_position_embeddings.pos) > 0:
            identifier_cont, _ = getattr(get_model(self.cond_stage_model),
                                         'encode')(self.text_indentifers,
                                                   return_mask=True)
            self.text_position_embeddings.load_state_dict(
                {'pos': identifier_cont[:, 0, :]})

        cont_, cont_mask_ = [], []
        for pp, edit, c, cm in zip(prompt, edit_image, cont, cont_mask):
            if isinstance(pp, list):
                cont_.append([c[-1], *c] if len(edit) > 0 else [c[-1]])
                cont_mask_.append([cm[-1], *cm] if len(edit) > 0 else [cm[-1]])
            else:
                raise NotImplementedError

        return cont_, cont_mask_
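

# --- Illustrative sketch (not part of the original module) ---
# The helper below mirrors, under simplifying assumptions, how
# `cond_stage_embeddings` above reorders the per-sample text embeddings: the
# last entry (which appears to be the prompt for the image being generated)
# is moved to the front and, when edit images are present, the full list
# follows it. `reorder_context` is a hypothetical name, and plain strings
# stand in for the embedding tensors.
def reorder_context(entries, num_edit_images):
    """Return the context order for a sample with or without edit images."""
    if not isinstance(entries, list):
        # Same restriction as the method above: only list prompts are handled.
        raise NotImplementedError
    target = entries[-1]
    # With edits: target first, then every entry (the target appears again at
    # the end). Without edits: only the target entry is kept.
    return [target, *entries] if num_edit_images > 0 else [target]


# Example: reorder_context(['edit_0', 'edit_1', 'target'], 2) returns
# ['target', 'edit_0', 'edit_1', 'target'].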


================================================
FILE: chatbot/config/chatbot_ui.yaml
================================================
WORK_DIR: ./cache/chatbot
FILE_SYSTEM:
  - NAME: "HuggingfaceFs"
    TEMP_DIR: ./cache
  - NAME: "ModelscopeFs"
    TEMP_DIR: ./cache
  - NAME: "LocalFs"
    TEMP_DIR: ./cache
  - NAME: "HttpFs"
    TEMP_DIR: ./cache
#
ENABLE_I2V: False
#
MODEL:
  EDIT_MODEL:
    MODEL_CFG_DIR: chatbot/config/models/
    DEFAULT: ace_0.6b_512
  I2V:
    MODEL_NAME: CogVideoX-5b-I2V
    MODEL_DIR: ms://ZhipuAI/CogVideoX-5b-I2V/
  CAPTIONER:
    MODEL_NAME: InternVL2-2B
    MODEL_DIR: ms://OpenGVLab/InternVL2-2B/
    PROMPT: '<image>\nThis image is the first frame of a video. Based on this image, please imagine what changes may occur in the next few seconds of the video. Please output brief description, such as "a dog running" or "a person turns to left". No more than 30 words.'
  ENHANCER:
    MODEL_NAME: Meta-Llama-3.1-8B-Instruct
    MODEL_DIR: ms://LLM-Research/Meta-Llama-3.1-8B-Instruct/


================================================
FILE: chatbot/config/models/ace_0.6b_512.yaml
================================================
NAME: ACE_0.6B_512
IS_DEFAULT: False
DEFAULT_PARAS:
  PARAS:
  #
  INPUT:
    INPUT_IMAGE:
    INPUT_MASK:
    TASK:
    PROMPT: ""
    NEGATIVE_PROMPT: ""
    OUTPUT_HEIGHT: 512
    OUTPUT_WIDTH: 512
    SAMPLER: ddim
    SAMPLE_STEPS: 20
    GUIDE_SCALE: 4.5
    GUIDE_RESCALE: 0.5
    SEED: -1
    TAR_INDEX: 0
  OUTPUT:
    LATENT:
    IMAGES:
    SEED:
  MODULES_PARAS:
    FIRST_STAGE_MODEL:
      FUNCTION:
        - NAME: encode
          DTYPE: float16
          INPUT: ["IMAGE"]
        - NAME: decode
          DTYPE: float16
          INPUT: ["LATENT"]
    #
    DIFFUSION_MODEL:
      FUNCTION:
        - NAME: forward
          DTYPE: float16
          INPUT: ["SAMPLE_STEPS", "SAMPLE", "GUIDE_SCALE"]
    #
    COND_STAGE_MODEL:
      FUNCTION:
        - NAME: encode_list
          DTYPE: bfloat16
          INPUT: ["PROMPT"]
#
MODEL:
  NAME: LdmACE
  PRETRAINED_MODEL:
  IGNORE_KEYS: [ ]
  SCALE_FACTOR: 0.18215
  SIZE_FACTOR: 8
  DECODER_BIAS: 0.5
  DEFAULT_N_PROMPT: ""
  TEXT_IDENTIFIER: [ '{image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]
  USE_TEXT_POS_EMBEDDINGS: True
  #
  DIFFUSION:
    NAME: ACEDiffusion
    PREDICTION_TYPE: eps
    MIN_SNR_GAMMA:
    NOISE_SCHEDULER:
      NAME: LinearScheduler
      NUM_TIMESTEPS: 1000
      BETA_MIN: 0.0001
      BETA_MAX: 0.02
  #
  DIFFUSION_MODEL:
    NAME: DiTACE
    PRETRAINED_MODEL: ms://iic/ACE-0.6B-512px@models/dit/ace_0.6b_512px.pth
    IGNORE_KEYS: [ ]
    PATCH_SIZE: 2
    IN_CHANNELS: 4
    HIDDEN_SIZE: 1152
    DEPTH: 28
    NUM_HEADS: 16
    MLP_RATIO: 4.0
    PRED_SIGMA: True
    DROP_PATH: 0.0
    WINDOW_DIZE: 0
    Y_CHANNELS: 4096
    MAX_SEQ_LEN: 1024
    QK_NORM: True
    USE_GRAD_CHECKPOINT: True
    ATTENTION_BACKEND: flash_attn
  #
  FIRST_STAGE_MODEL:
    NAME: AutoencoderKL
    EMBED_DIM: 4
    PRETRAINED_MODEL: ms://iic/ACE-0.6B-512px@models/vae/vae.bin
    IGNORE_KEYS: []
    #
    ENCODER:
      NAME: Encoder
      CH: 128
      OUT_CH: 3
      NUM_RES_BLOCKS: 2
      IN_CHANNELS: 3
      ATTN_RESOLUTIONS: [ ]
      CH_MULT: [ 1, 2, 4, 4 ]
      Z_CHANNELS: 4
      DOUBLE_Z: True
      DROPOUT: 0.0
      RESAMP_WITH_CONV: True
    #
    DECODER:
      NAME: Decoder
      CH: 128
      OUT_CH: 3
      NUM_RES_BLOCKS: 2
      IN_CHANNELS: 3
      ATTN_RESOLUTIONS: [ ]
      CH_MULT: [ 1, 2, 4, 4 ]
      Z_CHANNELS: 4
      DROPOUT: 0.0
      RESAMP_WITH_CONV: True
      GIVE_PRE_END: False
      TANH_OUT: False
  #
  COND_STAGE_MODEL:
    NAME: ACETextEmbedder
    PRETRAINED_MODEL: ms://iic/ACE-0.6B-512px@models/text_encoder/t5-v1_1-xxl/
    TOKENIZER_PATH: ms://iic/ACE-0.6B-512px@models/tokenizer/t5-v1_1-xxl
    LENGTH: 120
    T5_DTYPE: bfloat16
    ADDED_IDENTIFIER: [ '{image}', '{caption}', '{mask}', '{ref_image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]
    CLEAN: whitespace
    USE_GRAD: False


================================================
FILE: chatbot/example.py
================================================
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import os

from scepter.modules.utils.file_system import FS
from PIL import Image


def download_image(image, local_path=None):
    if not FS.exists(local_path):
        local_path = FS.get_from(image, local_path=local_path)
    return local_path

def blank_image():
    return Image.new('RGBA', (128, 128), (0, 0, 0, 0))



def get_examples(cache_dir):
    print('Downloading Examples ...')
    bl_img = blank_image()
    examples = [
        [
            'Facial Editing',
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/e33edc106953.png?raw=true',
                os.path.join(cache_dir, 'examples/e33edc106953.png')), bl_img,
                bl_img, '{image} let the man smile', 6666
        ],
        [
            'Facial Editing',
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/5d2bcc91a3e9.png?raw=true',
                os.path.join(cache_dir, 'examples/5d2bcc91a3e9.png')), bl_img,
                bl_img, 'let the man in {image} wear sunglasses', 9999
        ],
        [
            'Facial Editing',
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/3a52eac708bd.png?raw=true',
                os.path.join(cache_dir, 'examples/3a52eac708bd.png')), bl_img,
                bl_img, '{image} red hair', 9999
        ],
        [
            'Facial Editing',
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/3f4dc464a0ea.png?raw=true',
                os.path.join(cache_dir, 'examples/3f4dc464a0ea.png')), bl_img,
                bl_img, '{image} let the man serious', 99999
        ],
        [
            'Controllable Generation',
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/131ca90fd2a9.png?raw=true',
                os.path.join(cache_dir,
                             'examples/131ca90fd2a9.png')), bl_img, bl_img,
            '"A person sits contemplatively on the ground, surrounded by falling autumn leaves. Dressed in a green sweater and dark blue pants, they rest their chin on their hand, exuding a relaxed demeanor. Their stylish checkered slip-on shoes add a touch of flair, while a black purse lies in their lap. The backdrop of muted brown enhances the warm, cozy atmosphere of the scene." , generate the image that corresponds to the given scribble {image}.',
            613725
        ],
        [
            'Render Text',
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/33e9f27c2c48.png?raw=true',
                os.path.join(cache_dir, 'examples/33e9f27c2c48.png')),
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/33e9f27c2c48_mask.png?raw=true',
                os.path.join(cache_dir,
                             'examples/33e9f27c2c48_mask.png')), bl_img,
            'Put the text "C A T" at the position marked by mask in the {image}',
            6666
        ],
        [
            'Style Transfer',
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/9e73e7eeef55.png?raw=true',
                os.path.join(cache_dir, 'examples/9e73e7eeef55.png')), bl_img,
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/2e02975293d6.png?raw=true',
                os.path.join(cache_dir, 'examples/2e02975293d6.png')),
            'edit {image} based on the style of {image1} ', 99999
        ],
        [
            'Outpainting',
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/f2b22c08be3f.png?raw=true',
                os.path.join(cache_dir, 'examples/f2b22c08be3f.png')),
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/f2b22c08be3f_mask.png?raw=true',
                os.path.join(cache_dir,
                             'examples/f2b22c08be3f_mask.png')), bl_img,
            'Could the {image} be widened within the space designated by mask, while retaining the original?',
            6666
        ],
        [
            'Image Segmentation',
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/db3ebaa81899.png?raw=true',
                os.path.join(cache_dir, 'examples/db3ebaa81899.png')), bl_img,
            bl_img, '{image} Segmentation', 6666
        ],
        [
            'Depth Estimation',
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/f1927c4692ba.png?raw=true',
                os.path.join(cache_dir, 'examples/f1927c4692ba.png')), bl_img,
            bl_img, '{image} Depth Estimation', 6666
        ],
        [
            'Pose Estimation',
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/014e5bf3b4d1.png?raw=true',
                os.path.join(cache_dir, 'examples/014e5bf3b4d1.png')), bl_img,
            bl_img, '{image} distinguish the poses of the figures', 999999
        ],
        [
            'Scribble Extraction',
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/5f59a202f8ac.png?raw=true',
                os.path.join(cache_dir, 'examples/5f59a202f8ac.png')), bl_img,
            bl_img, 'Generate a scribble of {image}, please.', 6666
        ],
        [
            'Mosaic',
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/3a2f52361eea.png?raw=true',
                os.path.join(cache_dir, 'examples/3a2f52361eea.png')), bl_img,
            bl_img, 'Adapt {image} into a mosaic representation.', 6666
        ],
        [
            'Edge map Extraction',
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/b9d1e519d6e5.png?raw=true',
                os.path.join(cache_dir, 'examples/b9d1e519d6e5.png')), bl_img,
            bl_img, 'Get the edge-enhanced result for {image}.', 6666
        ],
        [
            'Grayscale',
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/c4ebbe2ba29b.png?raw=true',
                os.path.join(cache_dir, 'examples/c4ebbe2ba29b.png')), bl_img,
            bl_img, 'transform {image} into a black and white one', 6666
        ],
        [
            'Contour Extraction',
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/19652d0f6c4b.png?raw=true',
                os.path.join(cache_dir,
                             'examples/19652d0f6c4b.png')), bl_img, bl_img,
            'Would you be able to make a contour picture from {image} for me?',
            6666
        ],
        [
            'Controllable Generation',
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/249cda2844b7.png?raw=true',
                os.path.join(cache_dir,
                             'examples/249cda2844b7.png')), bl_img, bl_img,
            'Following the segmentation outcome in mask of {image}, develop a real-life image using the explanatory note in "a mighty cat lying on the bed".',
            6666
        ],
        [
            'Controllable Generation',
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/411f6c4b8e6c.png?raw=true',
                os.path.join(cache_dir,
                             'examples/411f6c4b8e6c.png')), bl_img, bl_img,
            'use the depth map {image} and the text caption "a cut white cat" to create a corresponding graphic image',
            999999
        ],
        [
            'Controllable Generation',
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/a35c96ed137a.png?raw=true',
                os.path.join(cache_dir,
                             'examples/a35c96ed137a.png')), bl_img, bl_img,
            'help translate this posture schema {image} into a colored image based on the context I provided "A beautiful woman Climbing the climbing wall, wearing a harness and climbing gear, skillfully maneuvering up the wall with her back to the camera, with a safety rope."',
            3599999
        ],
        [
            'Controllable Generation',
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/dcb2fc86f1ce.png?raw=true',
                os.path.join(cache_dir,
                             'examples/dcb2fc86f1ce.png')), bl_img, bl_img,
            'Transform and generate an image using mosaic {image} and "Monarch butterflies gracefully perch on vibrant purple flowers, showcasing their striking orange and black wings in a lush garden setting." description',
            6666
        ],
        [
            'Controllable Generation',
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/4cd4ee494962.png?raw=true',
                os.path.join(cache_dir,
                             'examples/4cd4ee494962.png')), bl_img, bl_img,
            'make this {image} colorful as per the "beautiful sunflowers"',
            6666
        ],
        [
            'Controllable Generation',
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/a47e3a9cd166.png?raw=true',
                os.path.join(cache_dir,
                             'examples/a47e3a9cd166.png')), bl_img, bl_img,
            'Take the edge conscious {image} and the written guideline "A whimsical animated character is depicted holding a delectable cake adorned with blue and white frosting and a drizzle of chocolate. The character wears a yellow headband with a bow, matching a cozy yellow sweater. Her dark hair is styled in a braid, tied with a yellow ribbon. With a golden fork in hand, she stands ready to enjoy a slice, exuding an air of joyful anticipation. The scene is creatively rendered with a charming and playful aesthetic." and produce a realistic image.',
            613725
        ],
        [
            'Controllable Generation',
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/d890ed8a3ac2.png?raw=true',
                os.path.join(cache_dir,
                             'examples/d890ed8a3ac2.png')), bl_img, bl_img,
            'creating a vivid image based on {image} and description "This image features a delicious rectangular tart with a flaky, golden-brown crust. The tart is topped with evenly sliced tomatoes, layered over a creamy cheese filling. Aromatic herbs are sprinkled on top, adding a touch of green and enhancing the visual appeal. The background includes a soft, textured fabric and scattered white flowers, creating an elegant and inviting presentation. Bright red tomatoes in the upper right corner hint at the fresh ingredients used in the dish."',
            6666
        ],
        [
            'Image Denoising',
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/0844a686a179.png?raw=true',
                os.path.join(cache_dir,
                             'examples/0844a686a179.png')), bl_img, bl_img,
            'Eliminate noise interference in {image} and maximize the crispness to obtain superior high-definition quality',
            6666
        ],
        [
            'Inpainting',
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/fa91b6b7e59b.png?raw=true',
                os.path.join(cache_dir, 'examples/fa91b6b7e59b.png')),
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/fa91b6b7e59b_mask.png?raw=true',
                os.path.join(cache_dir,
                             'examples/fa91b6b7e59b_mask.png')), bl_img,
            'Ensure to overhaul the parts of the {image} indicated by the mask.',
            6666
        ],
        [
            'Inpainting',
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/632899695b26.png?raw=true',
                os.path.join(cache_dir, 'examples/632899695b26.png')),
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/632899695b26_mask.png?raw=true',
                os.path.join(cache_dir,
                             'examples/632899695b26_mask.png')), bl_img,
            'Refashion the mask portion of {image} in accordance with "A yellow egg with a smiling face painted on it"',
            6666
        ],
        [
            'General Editing',
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/354d17594afe.png?raw=true',
                os.path.join(cache_dir,
                             'examples/354d17594afe.png')), bl_img, bl_img,
            '{image} change the dog\'s posture to walking in the water, and change the background to green plants and a pond.',
            6666
        ],
        [
            'General Editing',
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/38946455752b.png?raw=true',
                os.path.join(cache_dir,
                             'examples/38946455752b.png')), bl_img, bl_img,
            '{image} change the color of the dress from white to red and the model\'s hair color red brown to blonde. Other parts remain unchanged',
            6669
        ],
        [
            'Facial Editing',
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/3ba5202f0cd8.png?raw=true',
                os.path.join(cache_dir,
                             'examples/3ba5202f0cd8.png')), bl_img, bl_img,
            'Keep the same facial feature in @3ba5202f0cd8, change the woman\'s clothing from a Blue denim jacket to a white turtleneck sweater and adjust her posture so that she is supporting her chin with both hands. Other aspects, such as background, hairstyle, facial expression, etc, remain unchanged.',
            99999
        ],
        [
            'Facial Editing',
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/369365b94725.png?raw=true',
                os.path.join(cache_dir, 'examples/369365b94725.png')), bl_img,
            bl_img, '{image} Make her looking at the camera', 6666
        ],
        [
            'Facial Editing',
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/92751f2e4a0e.png?raw=true',
                os.path.join(cache_dir, 'examples/92751f2e4a0e.png')), bl_img,
            bl_img, '{image} Remove the smile from his face', 9899999
        ],
        [
            'Remove Text',
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/8530a6711b2e.png?raw=true',
                os.path.join(cache_dir, 'examples/8530a6711b2e.png')), bl_img,
            bl_img, 'Aim to remove any textual element in {image}', 6666
        ],
        [
            'Remove Text',
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/c4d7fb28f8f6.png?raw=true',
                os.path.join(cache_dir, 'examples/c4d7fb28f8f6.png')),
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/c4d7fb28f8f6_mask.png?raw=true',
                os.path.join(cache_dir,
                             'examples/c4d7fb28f8f6_mask.png')), bl_img,
            'Rub out any text found in the mask sector of the {image}.', 6666
        ],
        [
            'Remove Object',
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/e2f318fa5e5b.png?raw=true',
                os.path.join(cache_dir,
                             'examples/e2f318fa5e5b.png')), bl_img, bl_img,
            'Remove the unicorn in this {image}, ensuring a smooth edit.',
            99999
        ],
        [
            'Remove Object',
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/1ae96d8aca00.png?raw=true',
                os.path.join(cache_dir, 'examples/1ae96d8aca00.png')),
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/1ae96d8aca00_mask.png?raw=true',
                os.path.join(cache_dir, 'examples/1ae96d8aca00_mask.png')),
            bl_img, 'Discard the contents of the mask area from {image}.', 99999
        ],
        [
            'Add Object',
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/80289f48e511.png?raw=true',
                os.path.join(cache_dir, 'examples/80289f48e511.png')),
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/80289f48e511_mask.png?raw=true',
                os.path.join(cache_dir,
                             'examples/80289f48e511_mask.png')), bl_img,
            'add a Hot Air Balloon into the {image}, per the mask', 613725
        ],
        [
            'Style Transfer',
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/d725cb2009e8.png?raw=true',
                os.path.join(cache_dir, 'examples/d725cb2009e8.png')), bl_img,
            bl_img, 'Change the style of {image} to colored pencil style', 99999
        ],
        [
            'Style Transfer',
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/e0f48b3fd010.png?raw=true',
                os.path.join(cache_dir, 'examples/e0f48b3fd010.png')), bl_img,
            bl_img, 'make {image} to Walt Disney Animation style', 99999
        ],
        [
            'Try On',
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/ee4ca60b8c96.png?raw=true',
                os.path.join(cache_dir, 'examples/ee4ca60b8c96.png')),
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/ee4ca60b8c96_mask.png?raw=true',
                os.path.join(cache_dir, 'examples/ee4ca60b8c96_mask.png')),
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/ebe825bbfe3c.png?raw=true',
                os.path.join(cache_dir, 'examples/ebe825bbfe3c.png')),
            'Change the cloth in {image} to the one in {image1}', 99999
        ],
        [
            'Workflow',
            download_image(
                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/cb85353c004b.png?raw=true',
                os.path.join(cache_dir, 'examples/cb85353c004b.png')), bl_img,
            bl_img, '<workflow> ice cream {image}', 99999
        ],
    ]
    print('Finish. Start building UI ...')
    return examples
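

# --- Illustrative sketch (not part of the original module) ---
# Each row returned by `get_examples` above has six positional fields. The
# field names below are assumptions inferred from the rows themselves (task
# label, input image, mask, second image, prompt, seed); `EXAMPLE_FIELDS` and
# `describe_example` are hypothetical and only make a row easier to inspect.
EXAMPLE_FIELDS = ('task', 'image', 'mask', 'ref_image', 'prompt', 'seed')


def describe_example(row):
    """Map one positional example row to a dict keyed by EXAMPLE_FIELDS."""
    if len(row) != len(EXAMPLE_FIELDS):
        raise ValueError(
            f'expected {len(EXAMPLE_FIELDS)} fields, got {len(row)}')
    return dict(zip(EXAMPLE_FIELDS, row))


# Example: describe_example(row)['prompt'] gives the instruction text, which
# refers to its inputs through placeholders such as {image} and {image1}.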


================================================
FILE: chatbot/infer.py
================================================
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import math
import random
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF

from scepter.modules.model.registry import DIFFUSIONS
from scepter.modules.utils.distribute import we
from scepter.modules.utils.logger import get_logger
from scepter.modules.inference.diffusion_inference import DiffusionInference, get_model
from modules.model.utils.basic_utils import (
    check_list_of_list,
    pack_imagelist_into_tensor_v2 as pack_imagelist_into_tensor,
    to_device,
    unpack_tensor_into_imagelist
)


def process_edit_image(images,
                       masks,
                       tasks,
                       max_seq_len=1024,
                       max_aspect_ratio=4,
                       d=16,
                       **kwargs):

    if not isinstance(images, list):
        images = [images]
    if not isinstance(masks, list):
        masks = [masks]
    if not isinstance(tasks, list):
        tasks = [tasks]

    img_tensors = []
    mask_tensors = []
    for img, mask, task in zip(images, masks, tasks):
        if mask is None or mask == '':
            mask = Image.new('L', img.size, 0)
        W, H = img.size
        if H / W > max_aspect_ratio:
            img = TF.center_crop(img, [int(max_aspect_ratio * W), W])
            mask = TF.center_crop(mask, [int(max_aspect_ratio * W), W])
        elif W / H > max_aspect_ratio:
            img = TF.center_crop(img, [H, int(max_aspect_ratio * H)])
            mask = TF.center_crop(mask, [H, int(max_aspect_ratio * H)])

        H, W = img.height, img.width
        scale = min(1.0, math.sqrt(max_seq_len / ((H / d) * (W / d))))
        rH = int(H * scale) // d * d  # round down to a multiple of d
        rW = int(W * scale) // d * d

        img = TF.resize(img, (rH, rW),
                        interpolation=TF.InterpolationMode.BICUBIC)
        mask = TF.resize(mask, (rH, rW),
                         interpolation=TF.InterpolationMode.NEAREST_EXACT)

        mask = np.asarray(mask)
        mask = np.where(mask > 128, 1, 0)
        mask = mask.astype(
            np.float32) if np.any(mask) else np.ones_like(mask).astype(
                np.float32)

        img_tensor = TF.to_tensor(img).to(we.device_id)
        img_tensor = TF.normalize(img_tensor,
                                  mean=[0.5, 0.5, 0.5],
                                  std=[0.5, 0.5, 0.5])
        mask_tensor = TF.to_tensor(mask).to(we.device_id)
        if task in ['inpainting', 'Try On', 'Inpainting']:
            mask_indicator = mask_tensor.repeat(3, 1, 1)
            img_tensor[mask_indicator == 1] = -1.0
        img_tensors.append(img_tensor)
        mask_tensors.append(mask_tensor)
    return img_tensors, mask_tensors
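

# Illustrative sketch, not part of the original module: the resize rule of
# process_edit_image isolated as plain arithmetic, so the effect of
# `max_seq_len` and `d` can be checked without loading any image. The helper
# name `_sketch_target_size` is hypothetical, and it assumes the aspect-ratio
# crop has already been applied.
import math  # already imported above; repeated to keep the sketch standalone


def _sketch_target_size(height, width, max_seq_len=1024, d=16):
    """Return (rH, rW) as process_edit_image would compute them."""
    # Shrink only when the image covers more than max_seq_len d x d patches.
    scale = min(1.0, math.sqrt(max_seq_len / ((height / d) * (width / d))))
    # Round each side down to a multiple of d.
    r_h = int(height * scale) // d * d
    r_w = int(width * scale) // d * d
    return r_h, r_w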


class TextEmbedding(nn.Module):
    def __init__(self, embedding_shape):
        super().__init__()
        self.pos = nn.Parameter(data=torch.zeros(embedding_shape))


class ACEInference(DiffusionInference):
    def __init__(self, logger=None):
        if logger is None:
            logger = get_logger(name='scepter')
        self.logger = logger
        self.loaded_model = {}
        self.loaded_model_name = [
            'diffusion_model', 'first_stage_model', 'cond_stage_model'
        ]

    def init_from_cfg(self, cfg):
        self.name = cfg.NAME
        self.is_default = cfg.get('IS_DEFAULT', False)
        module_paras = self.load_default(cfg.get('DEFAULT_PARAS', None))
        assert cfg.have('MODEL')

        self.diffusion_model = self.infer_model(
            cfg.MODEL.DIFFUSION_MODEL, module_paras.get(
                'DIFFUSION_MODEL',
                None)) if cfg.MODEL.have('DIFFUSION_MODEL') else None
        self.first_stage_model = self.infer_model(
            cfg.MODEL.FIRST_STAGE_MODEL,
            module_paras.get(
                'FIRST_STAGE_MODEL',
                None)) if cfg.MODEL.have('FIRST_STAGE_MODEL') else None
        self.cond_stage_model = self.infer_model(
            cfg.MODEL.COND_STAGE_MODEL,
            module_paras.get(
                'COND_STAGE_MODEL',
                None)) if cfg.MODEL.have('COND_STAGE_MODEL') else None
        self.diffusion = DIFFUSIONS.build(cfg.MODEL.DIFFUSION,
                                          logger=self.logger)

        self.interpolate_func = lambda x: (F.interpolate(
            x.unsqueeze(0),
            scale_factor=1 / self.size_factor,
            mode='nearest-exact') if x is not None else None)
        self.text_indentifers = cfg.MODEL.get('TEXT_IDENTIFIER', [])
        self.use_text_pos_embeddings = cfg.MODEL.get('USE_TEXT_POS_EMBEDDINGS',
                                               False)
        if self.use_text_pos_embeddings:
            self.text_position_embeddings = TextEmbedding(
                (10, 4096)).eval().requires_grad_(False).to(we.device_id)
        else:
            self.text_position_embeddings = None

        self.max_seq_len = cfg.MODEL.DIFFUSION_MODEL.MAX_SEQ_LEN
        self.scale_factor = cfg.get('SCALE_FACTOR', 0.18215)
        self.size_factor = cfg.get('SIZE_FACTOR', 8)
        self.decoder_bias = cfg.get('DECODER_BIAS', 0)
        self.default_n_prompt = cfg.get('DEFAULT_N_PROMPT', '')

        self.dynamic_load(self.first_stage_model, 'first_stage_model')
        self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
        self.dynamic_load(self.diffusion_model, 'diffusion_model')

    @torch.no_grad()
    def encode_first_stage(self, x, **kwargs):
        _, dtype = self.get_function_info(self.first_stage_model, 'encode')
        with torch.autocast('cuda',
                            enabled=(dtype != 'float32'),
                            dtype=getattr(torch, dtype)):
            z = [
                self.scale_factor * get_model(self.first_stage_model)._encode(
                    i.unsqueeze(0).to(getattr(torch, dtype))) for i in x
            ]
        return z

    @torch.no_grad()
    def decode_first_stage(self, z):
        _, dtype = self.get_function_info(self.first_stage_model, 'decode')
        with torch.autocast('cuda',
                            enabled=(dtype != 'float32'),
                            dtype=getattr(torch, dtype)):
            x = [
                get_model(self.first_stage_model)._decode(
                    1. / self.scale_factor * i.to(getattr(torch, dtype)))
                for i in z
            ]
        return x

    @torch.no_grad()
    def __call__(self,
                 image=None,
                 mask=None,
                 prompt='',
                 task=None,
                 negative_prompt='',
                 output_height=512,
                 output_width=512,
                 sampler='ddim',
                 sample_steps=20,
                 guide_scale=4.5,
                 guide_rescale=0.5,
                 seed=-1,
                 history_io=None,
                 tar_index=0,
                 **kwargs):
        input_image, input_mask = image, mask
        g = torch.Generator(device=we.device_id)
        seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)
        g.manual_seed(int(seed))

        if input_image is not None:
            assert isinstance(input_image, list) and isinstance(
                input_mask, list)
            if task is None:
                task = [''] * len(input_image)
            if not isinstance(prompt, list):
                prompt = [prompt] * len(input_image)
            if history_io is not None and len(history_io) > 0:
                his_image, his_mask, his_prompt, his_task = history_io[
                    'image'], history_io['mask'], history_io[
                        'prompt'], history_io['task']
                assert len(his_image) == len(his_mask) == len(
                    his_prompt) == len(his_task)
                input_image = his_image + input_image
                input_mask = his_mask + input_mask
                task = his_task + task
                prompt = his_prompt + [prompt[-1]]
                prompt = [
                    pp.replace('{image}', f'{{image{i}}}') if i > 0 else pp
                    for i, pp in enumerate(prompt)
                ]

            edit_image, edit_image_mask = process_edit_image(
                input_image, input_mask, task, max_seq_len=self.max_seq_len)

            image, image_mask = edit_image[tar_index], edit_image_mask[
                tar_index]
            edit_image, edit_image_mask = [edit_image], [edit_image_mask]

        else:
            edit_image = edit_image_mask = [[]]
            image = torch.zeros(
                size=[3, int(output_height),
                      int(output_width)])
            image_mask = torch.ones(
                size=[1, int(output_height),
                      int(output_width)])
            if not isinstance(prompt, list):
                prompt = [prompt]

        image, image_mask, prompt = [image], [image_mask], [prompt]
        assert check_list_of_list(prompt) and check_list_of_list(
            edit_image) and check_list_of_list(edit_image_mask)
        # Assign Negative Prompt
        if isinstance(negative_prompt, list):
            negative_prompt = negative_prompt[0]
        assert isinstance(negative_prompt, str)

        n_prompt = copy.deepcopy(prompt)
        for nn_p_id, nn_p in enumerate(n_prompt):
            assert isinstance(nn_p, list)
            n_prompt[nn_p_id][-1] = negative_prompt

        ctx, null_ctx = {}, {}

        # Get Noise Shape
        image = to_device(image)
        x = self.encode_first_stage(image)
        noise = [
            torch.empty(*i.shape, device=we.device_id).normal_(generator=g)
            for i in x
        ]
        noise, x_shapes = pack_imagelist_into_tensor(noise)
        ctx['x_shapes'] = null_ctx['x_shapes'] = x_shapes

        image_mask = to_device(image_mask, strict=False)
        cond_mask = [self.interpolate_func(i) for i in image_mask
                     ] if image_mask is not None else [None] * len(image)
        ctx['x_mask'] = null_ctx['x_mask'] = cond_mask

        # Encode Prompt
        
        function_name, dtype = self.get_function_info(self.cond_stage_model)
        cont, cont_mask = getattr(get_model(self.cond_stage_model),
                                  function_name)(prompt)
        cont, cont_mask = self.cond_stage_embeddings(prompt, edit_image, cont,
                                                     cont_mask)
        null_cont, null_cont_mask = getattr(get_model(self.cond_stage_model),
                                            function_name)(n_prompt)
        null_cont, null_cont_mask = self.cond_stage_embeddings(
            prompt, edit_image, null_cont, null_cont_mask)
        ctx['crossattn'] = cont
        null_ctx['crossattn'] = null_cont

        # Encode Edit Images
        edit_image = [to_device(i, strict=False) for i in edit_image]
        edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
        e_img, e_mask = [], []
        for u, m in zip(edit_image, edit_image_mask):
            if u is None:
                continue
            if m is None:
                m = [None] * len(u)
            e_img.append(self.encode_first_stage(u, **kwargs))
            e_mask.append([self.interpolate_func(i) for i in m])

        null_ctx['edit'] = ctx['edit'] = e_img
        null_ctx['edit_mask'] = ctx['edit_mask'] = e_mask

        # Diffusion Process
        function_name, dtype = self.get_function_info(self.diffusion_model)
        with torch.autocast('cuda',
                            enabled=dtype in ('float16', 'bfloat16'),
                            dtype=getattr(torch, dtype)):
            latent = self.diffusion.sample(
                noise=noise,
                sampler=sampler,
                model=get_model(self.diffusion_model),
                model_kwargs=[{
                    'cond':
                    ctx,
                    'mask':
                    cont_mask,
                    'text_position_embeddings':
                    self.text_position_embeddings.pos if hasattr(
                        self.text_position_embeddings, 'pos') else None
                }, {
                    'cond':
                    null_ctx,
                    'mask':
                    null_cont_mask,
                    'text_position_embeddings':
                    self.text_position_embeddings.pos if hasattr(
                        self.text_position_embeddings, 'pos') else None
                }] if guide_scale is not None and guide_scale > 1 else {
                    # Without classifier-free guidance, sample with the
                    # conditional context so that it matches cont_mask.
                    'cond':
                    ctx,
                    'mask':
                    cont_mask,
                    'text_position_embeddings':
                    self.text_position_embeddings.pos if hasattr(
                        self.text_position_embeddings, 'pos') else None
                },
                steps=sample_steps,
                show_progress=True,
                seed=seed,
                guide_scale=guide_scale,
                guide_rescale=guide_rescale,
                return_intermediate=None,
                **kwargs)

        # Decode to Pixel Space
        samples = unpack_tensor_into_imagelist(latent, x_shapes)
        x_samples = self.decode_first_stage(samples)

        imgs = [
            torch.clamp((x_i + 1.0) / 2.0 + self.decoder_bias / 255,
                        min=0.0,
                        max=1.0).squeeze(0).permute(1, 2, 0).cpu().numpy()
            for x_i in x_samples
        ]
        imgs = [Image.fromarray((img * 255).astype(np.uint8)) for img in imgs]
        return imgs

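    def cond_stage_embeddings(self, prompt, edit_image, cont, cont_mask):
        """Arrange per-sample text embeddings for the diffusion model.

        On first use (while the position embeddings are still all zeros),
        the text position embeddings are filled from the encoded text
        identifiers. For each sample, the embedding of the last prompt is
        placed first, followed by the embeddings of all prompts when edit
        images are present.
        """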
        if self.use_text_pos_embeddings and not torch.sum(
                self.text_position_embeddings.pos) > 0:
            identifier_cont, _ = getattr(get_model(self.cond_stage_model),
                                         'encode')(self.text_indentifers,
                                                   return_mask=True)
            self.text_position_embeddings.load_state_dict(
                {'pos': identifier_cont[:, 0, :]})

        cont_, cont_mask_ = [], []
        for pp, edit, c, cm in zip(prompt, edit_image, cont, cont_mask):
            if isinstance(pp, list):
                cont_.append([c[-1], *c] if len(edit) > 0 else [c[-1]])
                cont_mask_.append([cm[-1], *cm] if len(edit) > 0 else [cm[-1]])
            else:
                raise NotImplementedError

        return cont_, cont_mask_
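

# Illustrative sketch, not part of the original module: how
# ACEInference.__call__ renumbers the `{image}` placeholder once chat history
# is prepended to the prompt list. The helper name
# `_sketch_renumber_placeholders` is hypothetical.
def _sketch_renumber_placeholders(prompts):
    """Keep `{image}` in the first prompt; prompt i > 0 gets `{image<i>}`."""
    return [
        pp.replace('{image}', f'{{image{i}}}') if i > 0 else pp
        for i, pp in enumerate(prompts)
    ]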


================================================
FILE: chatbot/run_gradio.py
================================================
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import base64
import copy
import csv
import glob
import io
import os
import random
import re
import string
import sys
import threading
import warnings

import cv2
import gradio as gr
import numpy as np
import torch
import transformers
from PIL import Image
from transformers import AutoModel, AutoTokenizer

from scepter.modules.utils.config import Config
from scepter.modules.utils.directory import get_md5
from scepter.modules.utils.file_system import FS
from scepter.studio.utils.env import init_env
from importlib.metadata import version

from ace_inference import ACEInference
from example import get_examples
from utils import load_image

csv.field_size_limit(sys.maxsize)

refresh_sty = '\U0001f504'  # 🔄
clear_sty = '\U0001f5d1'  # 🗑️
upload_sty = '\U0001f5bc'  # 🖼️
sync_sty = '\U0001f4be'  # 💾
chat_sty = '\U0001F4AC'  # 💬
video_sty = '\U0001f3a5'  # 🎥

lock = threading.Lock()


class ChatBotUI(object):
    def __init__(self,
                 cfg_general_file,
                 is_debug=False,
                 language='en',
                 root_work_dir='./'):
        try:
            from diffusers import CogVideoXImageToVideoPipeline
            from diffusers.utils import export_to_video
        except Exception as e:
            print(f"Import diffusers failed, please install or upgrade diffusers. Error information: {e}")
        if isinstance(cfg_general_file, str):
            cfg = Config(cfg_file=cfg_general_file)
        else:
            cfg = cfg_general_file
        cfg.WORK_DIR = os.path.join(root_work_dir, cfg.WORK_DIR)
        if not FS.exists(cfg.WORK_DIR):
            FS.make_dir(cfg.WORK_DIR)
        cfg = init_env(cfg)
        self.cache_dir = cfg.WORK_DIR
        self.chatbot_examples = get_examples(self.cache_dir) if not cfg.get('SKIP_EXAMPLES', False) else []
        self.model_cfg_dir = cfg.MODEL.EDIT_MODEL.MODEL_CFG_DIR
        self.model_yamls = glob.glob(os.path.join(self.model_cfg_dir,
                                                  '*.yaml'))
        self.model_choices = dict()
        self.default_model_name = ''
        for i in self.model_yamls:
            model_cfg = Config(load=True, cfg_file=i)
            model_name = model_cfg.NAME
            if model_cfg.IS_DEFAULT:
                self.default_model_name = model_name
            self.model_choices[model_name] = model_cfg
        print('Models: ', self.model_choices.keys())
        assert len(self.model_choices) > 0
        if self.default_model_name == "":
            self.default_model_name = list(self.model_choices.keys())[0]
        self.model_name = self.default_model_name
        self.pipe = ACEInference()
        self.pipe.init_from_cfg(self.model_choices[self.default_model_name])
        self.max_msgs = 20
        self.enable_i2v = cfg.get('ENABLE_I2V', False)
        self.gradio_version = version('gradio')

        if self.enable_i2v:
            self.i2v_model_dir = cfg.MODEL.I2V.MODEL_DIR
            self.i2v_model_name = cfg.MODEL.I2V.MODEL_NAME
            if self.i2v_model_name == 'CogVideoX-5b-I2V':
                with FS.get_dir_to_local_dir(self.i2v_model_dir) as local_dir:
                    self.i2v_pipe = CogVideoXImageToVideoPipeline.from_pretrained(
                        local_dir, torch_dtype=torch.bfloat16).cuda()
            else:
                raise NotImplementedError

            with FS.get_dir_to_local_dir(
                    cfg.MODEL.CAPTIONER.MODEL_DIR) as local_dir:
                self.captioner = AutoModel.from_pretrained(
                    local_dir,
                    torch_dtype=torch.bfloat16,
                    low_cpu_mem_usage=True,
                    use_flash_attn=True,
                    trust_remote_code=True).eval().cuda()
                self.llm_tokenizer = AutoTokenizer.from_pretrained(
                    local_dir, trust_remote_code=True, use_fast=False)
                self.llm_generation_config = dict(max_new_tokens=1024,
                                                  do_sample=True)
                self.llm_prompt = cfg.LLM.PROMPT
                self.llm_max_num = 2

            with FS.get_dir_to_local_dir(
                    cfg.MODEL.ENHANCER.MODEL_DIR) as local_dir:
                self.enhancer = transformers.pipeline(
                    'text-generation',
                    model=local_dir,
                    model_kwargs={'torch_dtype': torch.bfloat16},
                    device_map='auto',
                )

            sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.

            For example , outputting " a beautiful morning in the woods with the sun peeking through the trees " will trigger your partner bot to output a video of a forest morning , as described. You will be prompted by people looking to create detailed , amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
            There are a few rules to follow:

            You will only ever output a single video description per user request.

            When modifications are requested , you should not simply make the description longer . You should refactor the entire description to integrate the suggestions.
            Other times the user will not want modifications , but instead want a new image . In this case , you should ignore your previous conversation with the user.

            Video descriptions must have the same number of words as the examples below. Extra words will be ignored.
            """
            self.enhance_ctx = [
                {
                    'role': 'system',
                    'content': sys_prompt
                },
                {
                    'role':
                    'user',
                    'content':
                    'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "a girl is on the beach"',
                },
                {
                    'role':
                    'assistant',
                    'content':
                    "A radiant woman stands on a deserted beach, arms outstretched, wearing a beige trench coat, white blouse, light blue jeans, and chic boots, against a backdrop of soft sky and sea. Moments later, she is seen mid-twirl, arms exuberant, with the lighting suggesting dawn or dusk. Then, she runs along the beach, her attire complemented by an off-white scarf and black ankle boots, the tranquil sea behind her. Finally, she holds a paper airplane, her pose reflecting joy and freedom, with the ocean's gentle waves and the sky's soft pastel hues enhancing the serene ambiance.",
                },
                {
                    'role':
                    'user',
                    'content':
                    'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "A man jogging on a football field"',
                },
                {
                    'role':
                    'assistant',
                    'content':
                    "A determined man in athletic attire, including a blue long-sleeve shirt, black shorts, and blue socks, jogs around a snow-covered soccer field, showcasing his solitary exercise in a quiet, overcast setting. His long dreadlocks, focused expression, and the serene winter backdrop highlight his dedication to fitness. As he moves, his attire, consisting of a blue sports sweatshirt, black athletic pants, gloves, and sneakers, grips the snowy ground. He is seen running past a chain-link fence enclosing the playground area, with a basketball hoop and children's slide, suggesting a moment of solitary exercise amidst the empty field.",
                },
                {
                    'role':
                    'user',
                    'content':
                    'Create an imaginative video descriptive caption or modify an earlier caption for the user input : " A woman is dancing, HD footage, close-up"',
                },
                {
                    'role':
                    'assistant',
                    'content':
                    'A young woman with her hair in an updo and wearing a teal hoodie stands against a light backdrop, initially looking over her shoulder with a contemplative expression. She then confidently makes a subtle dance move, suggesting rhythm and movement. Next, she appears poised and focused, looking directly at the camera. Her expression shifts to one of introspection as she gazes downward slightly. Finally, she dances with confidence, her left hand over her heart, symbolizing a poignant moment, all while dressed in the same teal hoodie against a plain, light-colored background.',
                },
            ]

    def create_ui(self):

        css = '.chatbot.prose.md {opacity: 1.0 !important} #chatbot {opacity: 1.0 !important}'
        with gr.Blocks(css=css,
                       title='Chatbot',
                       head='Chatbot',
                       analytics_enabled=False):
            self.history = gr.State(value=[])
            self.images = gr.State(value={})
            self.history_result = gr.State(value={})
            self.retry_msg = gr.State(value='')
            with gr.Group():
                self.ui_mode = gr.State(value='legacy')
                with gr.Row(equal_height=True, visible=False) as self.chat_group:
                    with gr.Column(visible=True) as self.chat_page:
                        self.chatbot = gr.Chatbot(
                            height=600,
                            value=[],
                            bubble_full_width=False,
                            show_copy_button=True,
                            container=False,
                            placeholder='<strong>Chat Box</strong>')
                        with gr.Row():
                            self.clear_btn = gr.Button(clear_sty +
                                                       ' Clear Chat',
                                                       size='sm')

                    with gr.Column(visible=False) as self.editor_page:
                        with gr.Tabs(visible=False) as self.upload_tabs:
                            with gr.Tab(id='ImageUploader',
                                        label='Image Uploader',
                                        visible=True) as self.upload_tab:
                                self.image_uploader = gr.Image(
                                    height=550,
                                    interactive=True,
                                    type='pil',
                                    image_mode='RGB',
                                    sources=['upload'],
                                    elem_id='image_uploader',
                                    format='png')
                                with gr.Row():
                                    self.sub_btn_1 = gr.Button(
                                        value='Submit',
                                        elem_id='upload_submit')
                                    self.ext_btn_1 = gr.Button(value='Exit')
                        with gr.Tabs(visible=False) as self.edit_tabs:
                            with gr.Tab(id='ImageEditor',
                                        label='Image Editor') as self.edit_tab:
                                self.mask_type = gr.Dropdown(
                                    label='Mask Type',
                                    choices=[
                                        'Background', 'Composite',
                                        'Outpainting'
                                    ],
                                    value='Background')
                                self.mask_type_info = gr.HTML(
                                    value=
                                    "<div style='background-color: white; padding-left: 15px; color: grey;'>Background mode will not erase the visual content in the mask area</div>"
                                )
                                with gr.Accordion(
                                        label='Outpainting Setting',
                                        open=True,
                                        visible=False) as self.outpaint_tab:
                                    with gr.Row(variant='panel'):
                                        self.top_ext = gr.Slider(
                                            show_label=True,
                                            label='Top Extend Ratio',
                                            minimum=0.0,
                                            maximum=2.0,
                                            step=0.1,
                                            value=0.25)
                                        self.bottom_ext = gr.Slider(
                                            show_label=True,
                                            label='Bottom Extend Ratio',
                                            minimum=0.0,
                                            maximum=2.0,
                                            step=0.1,
                                            value=0.25)
                                    with gr.Row(variant='panel'):
                                        self.left_ext = gr.Slider(
                                            show_label=True,
                                            label='Left Extend Ratio',
                                            minimum=0.0,
                                            maximum=2.0,
                                            step=0.1,
                                            value=0.25)
                                        self.right_ext = gr.Slider(
                                            show_label=True,
                                            label='Right Extend Ratio',
                                            minimum=0.0,
                                            maximum=2.0,
                                            step=0.1,
                                            value=0.25)
                                    with gr.Row(variant='panel'):
                                        self.img_pad_btn = gr.Button(
                                            value='Pad Image')

                                self.image_editor = gr.ImageMask(
                                    value=None,
                                    sources=[],
                                    layers=False,
                                    label='Edit Image',
                                    elem_id='image_editor',
                                    format='png')
                                with gr.Row():
                                    self.sub_btn_2 = gr.Button(
                                        value='Submit', elem_id='edit_submit')
                                    self.ext_btn_2 = gr.Button(value='Exit')

                            with gr.Tab(id='ImageViewer',
                                        label='Image Viewer') as self.image_view_tab:
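                                # Compare the major version numerically; a plain
                                # string comparison would misorder e.g. '10.0.0'.
                                if int(self.gradio_version.split('.')[0]) >= 5: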
                                    self.image_viewer = gr.Image(
                                        label='Image',
                                        type='pil',
                                        show_download_button=True,
                                        elem_id='image_viewer')
                                else:
                                    try:
                                        from gradio_imageslider import ImageSlider
                                    except Exception as e:
                                        print(f'Import gradio_imageslider failed, please install it. Error information: {e}')
                                    self.image_viewer = ImageSlider(
                                        label='Image',
                                        type='pil',
                                        show_download_button=True,
                                        elem_id='image_viewer')

                                self.ext_btn_3 = gr.Button(value='Exit')

                            with gr.Tab(id='VideoViewer',
                                        label='Video Viewer',
                                        visible=False) as self.video_view_tab:
                                self.video_viewer = gr.Video(
                                    label='Video',
                                    interactive=False,
                                    sources=[],
                                    format='mp4',
                                    show_download_button=True,
                                    elem_id='video_viewer',
                                    loop=True,
                                    autoplay=True)

                                self.ext_btn_4 = gr.Button(value='Exit')

                with gr.Row(equal_height=True, visible=True) as self.legacy_group:
                    with gr.Column():
                        self.legacy_image_uploader = gr.Image(
                            height=550,
                            interactive=True,
                            type='pil',
                            image_mode='RGB',
                            elem_id='legacy_image_uploader',
                            format='png')
                    with gr.Column():
                        self.legacy_image_viewer = gr.Image(
                            label='Image',
                            height=550,
                            type='pil',
                            interactive=False,
                            show_download_button=True,
                            elem_id='image_viewer')


                with gr.Accordion(label='Setting', open=False):
                    with gr.Row():
                        self.model_name_dd = gr.Dropdown(
                            choices=list(self.model_choices.keys()),
                            value=self.default_model_name,
                            label='Model Version')

                    with gr.Row():
                        self.negative_prompt = gr.Textbox(
                            value='',
                            placeholder=
                            'Negative prompt used for Classifier-Free Guidance',
                            label='Negative Prompt',
                            container=False)

                    with gr.Row():
                        # REFINER_PROMPT
                        self.refiner_prompt = gr.Textbox(
                            value=self.pipe.input.get("refiner_prompt", ""),
                            visible=self.pipe.input.get("refiner_prompt", None) is not None,
                            placeholder=
                            'Prompt used for refiner',
                            label='Refiner Prompt',
                            container=False)


                    with gr.Row():
                        with gr.Column(scale=8, min_width=500):
                            with gr.Row():
                                self.step = gr.Slider(minimum=1,
                                                      maximum=1000,
                                                      value=self.pipe.input.get("sample_steps", 20),
                                                      visible=self.pipe.input.get("sample_steps", None) is not None,
                                                      label='Sample Step')
                                self.cfg_scale = gr.Slider(
                                    minimum=1.0,
                                    maximum=20.0,
                                    value=self.pipe.input.get("guide_scale", 4.5),
                                    visible=self.pipe.input.get("guide_scale", None) is not None,
                                    label='Guidance Scale')
                                self.rescale = gr.Slider(minimum=0.0,
                                                         maximum=1.0,
                                                         value=self.pipe.input.get("guide_rescale", 0.5),
                                                         visible=self.pipe.input.get("guide_rescale", None) is not None,
                                                         label='Rescale')
                                # The range includes the -1 fallback default so
                                # that the initial value is valid.
                                self.refiner_scale = gr.Slider(minimum=-1.0,
                                                               maximum=1.0,
                                                               value=self.pipe.input.get("refiner_scale", -1),
                                                               visible=self.pipe.input.get("refiner_scale", None) is not None,
                                                               label='Refiner Scale')
                                self.seed = gr.Slider(minimum=-1,
                                                      maximum=10000000,
                                                      value=-1,
                                                      label='Seed')
                                self.output_height = gr.Slider(
                                    minimum=256,
                                    maximum=1440,
                                    value=self.pipe.input.get("output_height", 1024),
                                    visible=self.pipe.input.get("output_height", None) is not None,
                                    label='Output Height')
                                self.output_width = gr.Slider(
                                    minimum=256,
                                    maximum=1440,
                                    value=self.pipe.input.get("output_width", 1024),
                                    visible=self.pipe.input.get("output_width", None) is not None,
                                    label='Output Width')
                        with gr.Column(scale=1, min_width=50):
                            self.use_history = gr.Checkbox(value=False,
                                                           label='Use History')
                            self.use_ace = gr.Checkbox(value=self.pipe.input.get("use_ace", True),
                                                       visible=self.pipe.input.get("use_ace", None) is not None,
                                                       label='Use ACE')
                            self.video_auto = gr.Checkbox(
                                value=False,
                                label='Auto Gen Video',
                                visible=self.enable_i2v)

                    with gr.Row(variant='panel',
                                equal_height=True,
                                visible=self.enable_i2v):
                        self.video_fps = gr.Slider(minimum=1,
                                                   maximum=16,
                                                   value=8,
                                                   label='Video FPS',
                                                   visible=True)
                        self.video_frames = gr.Slider(minimum=8,
                                                      maximum=49,
                                                      value=49,
                                                      label='Video Frame Num',
                                                      visible=True)
                        self.video_step = gr.Slider(minimum=1,
                                                    maximum=1000,
                                                    value=50,
                                                    label='Video Sample Step',
                                                    visible=True)
                        self.video_cfg_scale = gr.Slider(
                            minimum=1.0,
                            maximum=20.0,
                            value=6.0,
                            label='Video Guidance Scale',
                            visible=True)
                        self.video_seed = gr.Slider(minimum=-1,
                                                    maximum=10000000,
                                                    value=-1,
                                                    label='Video Seed',
                                                    visible=True)

                with gr.Row():
                    self.chatbot_inst = """
                       **Instruction**:

                       1. Click the 'Upload' button to upload one or more input images.
                       2. Type '@' in the text box to show all images in the gallery.
                       3. Select the image you wish to edit from the gallery; its Image ID will be inserted into the text box.
                       4. Write the editing instruction for the selected image, including its image ID '@xxxxxx' in the instruction.
                       For example: "Change the girl's skirt in @123456 to blue." The '@xxxxxx' token identifies the specific image and is automatically replaced by the special token '{image}' in the instruction. Text-to-image generation without any input image is also supported.
                       5. Once your instruction is ready, click the "Chat" button to view the edited result in the chat window.
                       6. **Important**: To render text on an image, put a space between each letter. For instance, "add text 'g i r l' on the mask area of @xxxxxx".
                       7. To edit a local region using a mask, click the image in the chat window to open the image editor. There you can draw a mask and click the 'Submit' button to upload the edited image along with the mask. For inpainting tasks, select the 'Composite' mask type; for outpainting tasks, choose the 'Outpainting' mask type. For all other local editing tasks, select the 'Background' mask type.
                       8. If you find our work valuable, please visit the [ACE Page](https://ali-vilab.github.io/ace-page/) for comprehensive information.

                    """

                    self.legacy_inst = """
                       **Instruction**:

                       1. Upload an image to edit it; if no image is uploaded, an image will be generated from the text.
                       2. Type '@' in the text box to show all images in the gallery.
                       3. Select the image you wish to edit from the gallery; its Image ID will be inserted into the text box.
                       4. **Important**: To render text on an image, put a space between each letter. For instance, "add text 'g i r l' on the mask area of @xxxxxx".
                       5. For multi-step editing, partial editing, inpainting, outpainting, and other operations, check the 'ChatBot' checkbox to enable the conversational editing mode and follow the instructions shown there.
                       6. If you find our work valuable, please visit the [ACE Page](https://ali-vilab.github.io/ace-page/) for comprehensive information.

                    """

                    self.instruction = gr.Markdown(value=self.legacy_inst)

                with gr.Row(variant='panel',
                            equal_height=True,
                            show_progress=False):
                    with gr.Column(scale=1, min_width=100, visible=False) as self.upload_panel:
                        self.upload_btn = gr.Button(value=upload_sty +
                                                    ' Upload',
                                                    variant='secondary')
                    with gr.Column(scale=5, min_width=500):
                        self.text = gr.Textbox(
                            placeholder='Type "@" to find images in the history',
                            label='Instruction',
                            container=False)
                    with gr.Column(scale=1, min_width=100):
                        self.chat_btn = gr.Button(value='Generate',
                                                  variant='primary')
                    with gr.Column(scale=1, min_width=100):
                        self.retry_btn = gr.Button(value=refresh_sty +
                                                   ' Retry',
                                                   variant='secondary')
                    with gr.Column(scale=1, min_width=100):
                        self.mode_checkbox = gr.Checkbox(
                            value=False,
                            label='ChatBot')
                    with gr.Column(scale=(1 if self.enable_i2v else 0),
                                   min_width=0):
                        self.video_gen_btn = gr.Button(value=video_sty +
                                                       ' Gen Video',
                                                       variant='secondary',
                                                       visible=self.enable_i2v)
                    with gr.Column(scale=(1 if self.enable_i2v else 0),
                                   min_width=0):
                        self.extend_prompt = gr.Checkbox(
                            value=True,
                            label='Extend Prompt',
                            visible=self.enable_i2v)

                with gr.Row():
                    self.gallery = gr.Gallery(visible=False,
                                              label='History',
                                              columns=10,
                                              allow_preview=False,
                                              interactive=False)

                self.eg = gr.Column(visible=True)

    def set_callbacks(self, *args, **kwargs):

        ########################################
        def change_model(model_name):
            if model_name not in self.model_choices:
                gr.Info('The provided model name is not a valid choice!')
                # Return one value per component in `outputs` (11 in total).
                return (model_name, ) + (gr.update(), ) * 10

            if model_name != self.model_name:
                # `with` releases the lock even if model loading raises.
                with lock:
                    del self.pipe
                    torch.cuda.empty_cache()
                    torch.cuda.ipc_collect()
                    self.pipe = ACEInference()
                    self.pipe.init_from_cfg(self.model_choices[model_name])
                    self.model_name = model_name

            return (model_name, gr.update(), gr.update(),
                    gr.Slider(
                              value=self.pipe.input.get("sample_steps", 20),
                              visible=self.pipe.input.get("sample_steps", None) is not None),
                    gr.Slider(
                        value=self.pipe.input.get("guide_scale", 4.5),
                        visible=self.pipe.input.get("guide_scale", None) is not None),
                    gr.Slider(
                              value=self.pipe.input.get("guide_rescale", 0.5),
                              visible=self.pipe.input.get("guide_rescale", None) is not None),
                    gr.Slider(
                        value=self.pipe.input.get("output_height", 1024),
                        visible=self.pipe.input.get("output_height", None) is not None),
                    gr.Slider(
                        value=self.pipe.input.get("output_width", 1024),
                        visible=self.pipe.input.get("output_width", None) is not None),
                    gr.Textbox(
                        value=self.pipe.input.get("refiner_prompt", ""),
                        visible=self.pipe.input.get("refiner_prompt", None) is not None),
                    gr.Slider(
                              value=self.pipe.input.get("refiner_scale", -1),
                              visible=self.pipe.input.get("refiner_scale", None) is not None
                        ),
                    gr.Checkbox(
                        value=self.pipe.input.get("use_ace", True),
                        visible=self.pipe.input.get("use_ace", None) is not None
                    )
                    )

        self.model_name_dd.change(
            change_model,
            inputs=[self.model_name_dd],
            outputs=[
                self.model_name_dd, self.chatbot, self.text,
                self.step,
                self.cfg_scale, self.rescale, self.output_height,
                self.output_width, self.refiner_prompt, self.refiner_scale,
                self.use_ace])


        def mode_change(mode_check):
            if mode_check:
                # ChatBot: the instructions refer to the "Chat" button.
                return (
                    gr.Row(visible=False),
                    gr.Row(visible=True),
                    gr.Button(value=chat_sty + ' Chat'),
                    gr.State(value='chatbot'),
                    gr.Column(visible=True),
                    gr.Markdown(value=self.chatbot_inst)
                )
            else:
                # Legacy: restore the initial 'Generate' label.
                return (
                    gr.Row(visible=True),
                    gr.Row(visible=False),
                    gr.Button(value='Generate'),
                    gr.State(value='legacy'),
                    gr.Column(visible=False),
                    gr.Markdown(value=self.legacy_inst)
                )
        self.mode_checkbox.change(mode_change, inputs=[self.mode_checkbox],
                                  outputs=[self.legacy_group, self.chat_group,
                                           self.chat_btn, self.ui_mode,
                                           self.upload_panel, self.instruction])


        ########################################
        def generate_gallery(text, images):
            if text.endswith(' '):
                return gr.update(), gr.update(visible=False)
            elif text.endswith('@'):
                gallery_info = []
                for image_id, image_meta in images.items():
                    thumbnail_path = image_meta['thumbnail']
                    gallery_info.append((thumbnail_path, image_id))
                return gr.update(), gr.update(visible=True, value=gallery_info)
            else:
                gallery_info = []
                match = re.search('@([^@ ]+)$', text)
                if match:
                    prefix = match.group(1)
                    for image_id, image_meta in images.items():
                        if not image_id.startswith(prefix):
                            continue
                        thumbnail_path = image_meta['thumbnail']
                        gallery_info.append((thumbnail_path, image_id))

                    if len(gallery_info) > 0:
                        return gr.update(), gr.update(visible=True,
                                                      value=gallery_info)
                    else:
                        return gr.update(), gr.update(visible=False)
                else:
                    return gr.update(), gr.update(visible=False)

        self.text.input(generate_gallery,
                        inputs=[self.text, self.images],
                        outputs=[self.text, self.gallery],
                        show_progress='hidden')

        ########################################
        def select_image(text, evt: gr.SelectData):
            image_id = evt.value['caption']
            text = '@'.join(text.split('@')[:-1]) + f'@{image_id} '
            return gr.update(value=text), gr.update(visible=False, value=None)

        self.gallery.select(select_image,
                            inputs=self.text,
                            outputs=[self.text, self.gallery])

        ########################################
        def generate_video(message,
                           extend_prompt,
                           history,
                           images,
                           num_steps,
                           num_frames,
                           cfg_scale,
                           fps,
                           seed,
                           progress=gr.Progress(track_tqdm=True)):

            from diffusers.utils import export_to_video

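            # A seed of -1 (the slider default) requests a random seed,
            # matching the handling in run_chat.
            if seed is None or seed == -1:
                seed = random.randint(0, 10000000)
            generator = torch.Generator(device='cuda').manual_seed(int(seed))
            # An ID ends at whitespace, punctuation, or the end of the message.
            img_ids = re.findall(r'@([^@\s,;.?$]+)', message)
            if len(img_ids) == 0 or img_ids[0] not in images:
                history.append((
                    message,
                    'Sorry, no valid image was found in the prompt to use as the first frame of the video.'
                ))
                while len(history) >= self.max_msgs:
                    history.pop(0)
                return history, self.get_history(
                    history), gr.update(), gr.update(visible=False)

            img_id = img_ids[0]
            prompt = re.sub(rf'@{re.escape(img_id)}\s*', '', message)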

            if extend_prompt:
                messages = copy.deepcopy(self.enhance_ctx)
                messages.append({
                    'role':
                    'user',
                    'content':
                    f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{prompt}"',
                })
                # `with` releases the lock even if the enhancer raises.
                with lock:
                    outputs = self.enhancer(
                        messages,
                        max_new_tokens=200,
                    )

                prompt = outputs[0]['generated_text'][-1]['content']
                print(prompt)

            img_meta = images[img_id]
            img_path = img_meta['image']
            image = Image.open(img_path).convert('RGB')

            # `with` releases the lock even if the pipeline raises.
            with lock:
                video = self.i2v_pipe(
                    prompt=prompt,
                    image=image,
                    num_videos_per_prompt=1,
                    num_inference_steps=num_steps,
                    num_frames=num_frames,
                    guidance_scale=cfg_scale,
                    generator=generator,
                ).frames[0]

            out_video_path = export_to_video(video, fps=fps)
            history.append((
                f"Generate a video based on the first frame @{img_id} and the description '{prompt}'",
                'Here is the generated video:'))
            history.append((None, out_video_path))
            while len(history) >= self.max_msgs:
                history.pop(0)

            return history, self.get_history(history), gr.update(
                value=''), gr.update(visible=False)

        self.video_gen_btn.click(
            generate_video,
            inputs=[
                self.text, self.extend_prompt, self.history, self.images,
                self.video_step, self.video_frames, self.video_cfg_scale,
                self.video_fps, self.video_seed
            ],
            outputs=[self.history, self.chatbot, self.text, self.gallery])

        ########################################
        def run_chat(
                     message,
                     legacy_image,
                     ui_mode,
                     use_ace,
                     extend_prompt,
                     history,
                     images,
                     use_history,
                     history_result,
                     negative_prompt,
                     cfg_scale,
                     rescale,
                     refiner_prompt,
                     refiner_scale,
                     step,
                     seed,
                     output_h,
                     output_w,
                     video_auto,
                     video_steps,
                     video_frames,
                     video_cfg_scale,
                     video_fps,
                     video_seed,
                     progress=gr.Progress(track_tqdm=True)):
            legacy_img_ids = []
            if ui_mode == 'legacy':
                if legacy_image is not None:
                    history, images, img_id = self.add_uploaded_image_to_history(
                        legacy_image, history, images)
                    legacy_img_ids.append(img_id)
            retry_msg = message
            gen_id = get_md5(message)[:12]
            save_path = os.path.join(self.cache_dir, f'{gen_id}.png')

            # An ID ends at whitespace, punctuation, or the end of the message.
            img_ids = re.findall(r'@([^@\s,;.?$]+)', message)
            history_io = None

            if len(img_ids) < 1:
                img_ids = legacy_img_ids
                for img_id in img_ids:
                    if f'@{img_id}' not in message:
                        message = f'@{img_id} ' + message

            new_message = message

            edit_image, edit_image_mask, edit_task = [], [], []
            loaded_img_ids = []
            for img_id in img_ids:
                if img_id not in images:
                    gr.Info(
                        f'The input image ID {img_id} does not exist. Skipping this image.'
                    )
                    continue
                # Number placeholders by the images actually loaded so they
                # stay aligned with the entries of edit_image.
                idx = len(edit_image)
                placeholder = '{image}' if idx == 0 else '{' + f'image{idx}' + '}'
                new_message = new_message.replace(f'@{img_id}', placeholder)
                img_meta = images[img_id]
                img_path = img_meta['image']
                img_mask = img_meta['mask']
                img_mask_type = img_meta['mask_type']
                task = 'inpainting' if img_mask_type == 'Composite' else ''
                edit_image.append(Image.open(img_path).convert('RGB'))
                edit_image_mask.append(
                    Image.open(img_mask).convert('L')
                    if img_mask is not None else None)
                edit_task.append(task)
                loaded_img_ids.append(img_id)

                if use_history and (img_id in history_result):
                    history_io = history_result[img_id]

            if len(edit_image) > 0:
                buffered = io.BytesIO()
                edit_image[0].save(buffered, format='PNG')
                img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
                img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
                pre_info = f'Received one or more images, so image editing is conducted.\n The first input image @{loaded_img_ids[0]} is:\n {img_str}'
            else:
                # No usable image was found: fall back to text-to-image
                # generation instead of indexing into an empty list.
                pre_info = 'No valid image IDs were found in the provided text prompt, so text-guided image generation is conducted. \n'
                edit_image = None
                edit_image_mask = None
                edit_task = ''

            print(new_message)
            imgs = self.pipe(
                image=edit_image,
                mask=edit_image_mask,
                task=edit_task,
                prompt=[new_message] *
                len(edit_image) if edit_image is not None else [new_message],
                negative_prompt=[negative_prompt] * len(edit_image)
                if edit_image is not None else [negative_prompt],
                history_io=history_io,
                output_height=output_h,
                output_width=output_w,
                sampler='ddim',
                sample_steps=step,
                guide_scale=cfg_scale,
                guide_rescale=rescale,
                seed=seed,
                refiner_prompt=refiner_prompt,
                refiner_scale=refiner_scale,
                use_ace=use_ace
            )

            img = imgs[0]
            img.save(save_path, format='PNG')

            if history_io:
                history_io_new = copy.deepcopy(history_io)
                history_io_new['image'] += edit_image[:1]
                history_io_new['mask'] += edit_image_mask[:1]
                history_io_new['task'] += edit_task[:1]
                history_io_new['prompt'] += [new_message]
                history_io_new['image'] = history_io_new['image'][-5:]
                history_io_new['mask'] = history_io_new['mask'][-5:]
                history_io_new['task'] = history_io_new['task'][-5:]
                history_io_new['prompt'] = history_io_new['prompt'][-5:]
                history_result[gen_id] = history_io_new
            elif edit_image is not None and len(edit_image) > 0:
                history_io_new = {
                    'image': edit_image[:1],
                    'mask': edit_image_mask[:1],
                    'task': edit_task[:1],
                    'prompt': [new_message]
                }
                history_result[gen_id] = history_io_new

            w, h = img.size
            if w > h:
                tb_w = 128
                tb_h = int(h * tb_w / w)
            else:
                tb_h = 128
                tb_w = int(w * tb_h / h)

            thumbnail_path = os.path.join(self.cache_dir,
                                          f'{gen_id}_thumbnail.jpg')
            thumbnail = img.resize((tb_w, tb_h))
            thumbnail.save(thumbnail_path, format='JPEG')

            images[gen_id] = {
                'image': save_path,
                'mask': None,
                'mask_type': None,
                'thumbnail': thumbnail_path
            }

            buffered = io.BytesIO()
            img.convert('RGB').save(buffered, format='PNG')
            img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
            img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'

            history.append(
                (message,
                 f'{pre_info} The generated image @{gen_id} is:\n {img_str}'))

            if video_auto:
                # Imported lazily, as in generate_video; it is only needed
                # when video generation is enabled.
                from diffusers.utils import export_to_video

                if video_seed is None or video_seed == -1:
                    video_seed = random.randint(0, 10000000)

                # `with` releases the lock even if a model call raises.
                with lock:
                    generator = torch.Generator(
                        device='cuda').manual_seed(video_seed)
                    pixel_values = load_image(img.convert('RGB'),
                                              max_num=self.llm_max_num).to(
                                                  torch.bfloat16).cuda()
                    prompt = self.captioner.chat(self.llm_tokenizer,
                                                 pixel_values,
                                                 self.llm_prompt,
                                                 self.llm_generation_config)
                    print(prompt)

                if extend_prompt:
                    messages = copy.deepcopy(self.enhance_ctx)
                    messages.append({
                        'role':
                        'user',
                        'content':
                        f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{prompt}"',
                    })
                    with lock:
                        outputs = self.enhancer(
                            messages,
                            max_new_tokens=200,
                        )
                    prompt = outputs[0]['generated_text'][-1]['content']
                    print(prompt)

                with lock:
                    video = self.i2v_pipe(
                        prompt=prompt,
                        image=img,
                        num_videos_per_prompt=1,
                        num_inference_steps=video_steps,
                        num_frames=video_frames,
                        guidance_scale=video_cfg_scale,
                        generator=generator,
                    ).frames[0]

                out_video_path = export_to_video(video, fps=video_fps)
                history.append((
                    f"Generate a video based on the first frame @{gen_id} and the description '{prompt}'",
                    'Here is the generated video:'))
                history.append((None, out_video_path))

            while len(history) >= self.max_msgs:
                history.pop(0)

            return (history, images, gr.Image(value=save_path),
                    history_result, self.get_history(history), gr.update(),
                    gr.update(visible=False), retry_msg)

        chat_inputs = [
            self.legacy_image_uploader, self.ui_mode, self.use_ace,
            self.extend_prompt, self.history, self.images, self.use_history,
            self.history_result, self.negative_prompt, self.cfg_scale,
            self.rescale, self.refiner_prompt, self.refiner_scale,
            self.step, self.seed, self.output_height,
            self.output_width, self.video_auto, self.video_step,
            self.video_frames, self.video_cfg_scale, self.video_fps,
            self.video_seed
        ]

        chat_outputs = [
            self.history, self.images, self.legacy_image_viewer,
            self.history_result, self.chatbot,
            self.text, self.gallery, self.retry_msg
        ]

        self.chat_btn.click(run_chat,
                            inputs=[self.text] + chat_inputs,
                            outputs=chat_outputs)

        self.text.submit(run_chat,
                         inputs=[self.text] + chat_inputs,
                         outputs=chat_outputs)

        def retry_fn(*args):
            return run_chat(*args)

        self.retry_btn.click(retry_fn,
                             inputs=[self.retry_msg] + chat_inputs,
                             outputs=chat_outputs)

        ########################################
        def run_example(task, img, img_mask, ref1, prompt, seed):
            edit_image, edit_image_mask, edit_task = [], [], []
            if img is not None:
                w, h = img.size
                if w > 2048:
                    ratio = w / 2048.
                    w = 2048
                    h = int(h / ratio)
                if h > 2048:
                    ratio = h / 2048.
                    h = 2048
                    w = int(w / ratio)
                img = img.resize((w, h))
                edit_image.append(img)
                # An all-zero mask or reference image is treated as absent.
                if img_mask is not None and np.sum(np.array(img_mask)) <= 0:
                    img_mask = None
                edit_image_mask.append(img_mask)
                edit_task.append(task)
                if ref1 is not None and np.sum(np.array(ref1)) > 0:
                    edit_image.append(ref1)
                    edit_image_mask.append(None)
                    edit_task.append('')

                buffered = io.BytesIO()
                img.save(buffered, format='PNG')
                img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
                img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
                pre_info = f'Received one or more images, so image editing is conducted.\n The first input image is:\n {img_str}'
            else:
                pre_info = 'No image ids were found in the provided text prompt, so text-guided image generation is conducted. \n'
                edit_image = None
                edit_image_mask = None
                edit_task = ''

            img_num = len(edit_image) if edit_image is not None else 1
            imgs = self.pipe(
                image=edit_image,
                mask=edit_image_mask,
                task=edit_task,
                prompt=[prompt] * img_num,
                negative_prompt=[''] * img_num,
                seed=seed,
                refiner_prompt=self.pipe.input.get("refiner_prompt", ""),
                refiner_scale=self.pipe.input.get("refiner_scale", 0.0),
            )

            img = imgs[0]
            buffered = io.BytesIO()
            img.convert('RGB').save(buffered, format='PNG')
            img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
            img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
            history = [(prompt,
                        f'{pre_info} The generated image is:\n {img_str}')]

            img_id = get_md5(img_b64)[:12]
            save_path = os.path.join(self.cache_dir, f'{img_id}.png')
            img.convert('RGB').save(save_path)
            return (self.get_history(history), gr.update(value=''),
                    gr.update(visible=False), gr.update(value=save_path),
                    gr.update(value=-1))

        with self.eg:
            self.example_task = gr.Text(label='Task Name',
                                        value='',
                                        visible=False)
            self.example_image = gr.Image(label='Edit Image',
                                          type='pil',
                                          image_mode='RGB',
                                          visible=False)
            self.example_mask = gr.Image(label='Edit Image Mask',
                                         type='pil',
                                         image_mode='L',
                                         visible=False)
            self.example_ref_im1 = gr.Image(label='Ref Image',
                                            type='pil',
                                            image_mode='RGB',
                                            visible=False)
            self.examples = gr.Examples(
                fn=run_example,
                examples=self.chatbot_examples,
                inputs=[
                    self.example_task, self.example_image, self.example_mask,
                    self.example_ref_im1, self.text, self.seed
                ],
                outputs=[self.chatbot, self.text, self.gallery, self.legacy_image_viewer, self.seed],
                examples_per_page=4,
                cache_examples=False,
                run_on_click=True)

        ########################################
        def upload_image():
            return (gr.update(visible=True,
                              scale=1), gr.update(visible=True, scale=1),
                    gr.update(visible=True), gr.update(visible=False),
                    gr.update(visible=False), gr.update(visible=False),
                    gr.update(visible=True))

        self.upload_btn.click(upload_image,
                              inputs=[],
                              outputs=[
                                  self.chat_page, self.editor_page,
                                  self.upload_tab, self.edit_tab,
                                  self.image_view_tab, self.video_view_tab,
                                  self.upload_tabs
                              ])

        ########################################
        def edit_image(evt: gr.SelectData):
            if isinstance(evt.value, str):
                img_b64s = re.findall(
                    '<img src="data:image/png;base64,(.*?)" style="pointer-events: none;">',
                    evt.value)
                imgs = [
                    Image.open(io.BytesIO(base64.b64decode(copy.deepcopy(i))))
                    for i in img_b64s
                ]
                if len(imgs) > 0:
                    # Compare the major version numerically; a plain string
                    # comparison would misorder versions such as '10.0.0'.
                    if int(self.gradio_version.split('.')[0]) >= 5:
                        view_img = copy.deepcopy(imgs[-1])
                    elif len(imgs) == 2:
                        view_img = copy.deepcopy(imgs)
                    else:
                        view_img = [
                            copy.deepcopy(imgs[-1]),
                            copy.deepcopy(imgs[-1])
                        ]
                    edit_img = copy.deepcopy(imgs[-1])

                    return (gr.update(visible=True,
                                      scale=1), gr.update(visible=True,
                                                          scale=1),
                            gr.update(visible=False), gr.update(visible=True),
                            gr.update(visible=True), gr.update(visible=False),
                            gr.update(value=edit_img),
                            gr.update(value=view_img), gr.update(value=None),
                            gr.update(visible=True))
                else:
                    return (gr.update(), gr.update(), gr.update(), gr.update(),
                            gr.update(), gr.update(), gr.update(), gr.update(),
                            gr.update(), gr.update())
            elif isinstance(evt.value, dict) and evt.value.get(
                    'component', '') == 'video':
                value = evt.value['value']['video']['path']
                return (gr.update(visible=True,
                                  scale=1), gr.update(visible=True, scale=1),
                        gr.update(visible=False), gr.update(visible=False),
                        gr.update(visible=False), gr.update(visible=True),
                        gr.update(), gr.update(), gr.update(value=value),
                        gr.update())
            else:
                return (gr.update(), gr.update(), gr.update(), gr.update(),
                        gr.update(), gr.update(), gr.update(), gr.update(),
                        gr.update(), gr.update())

        self.chatbot.select(edit_image,
                            outputs=[
                                self.chat_page, self.editor_page,
                                self.upload_tab, self.edit_tab,
                                self.image_view_tab, self.video_view_tab,
                                self.image_editor, self.image_viewer,
                                self.video_viewer, self.edit_tabs
                            ])

        if int(self.gradio_version.split('.')[0]) < 5:
            self.image_viewer.change(lambda x: x,
                                     inputs=self.image_viewer,
                                     outputs=self.image_viewer)

        ########################################
        def submit_upload_image(image, history, images):
            history, images, _ = self.add_uploaded_image_to_history(
                image, history, images)
            return gr.update(visible=False), gr.update(
                visible=True), gr.update(
                    value=self.get_history(history)), history, images

        self.sub_btn_1.click(
            submit_upload_image,
            inputs=[self.image_uploader, self.history, self.images],
            outputs=[
                self.editor_page, self.chat_page, self.chatbot, self.history,
                self.images
            ])

        ########################################
        def submit_edit_image(imagemask, mask_type, history, images):
            history, images = self.add_edited_image_to_history(
                imagemask, mask_type, history, images)
            return gr.update(visible=False), gr.update(
                visible=True), gr.update(
                    value=self.get_history(history)), history, images

        self.sub_btn_2.click(submit_edit_image,
                             inputs=[
                                 self.image_editor, self.mask_type,
                                 self.history, self.images
                             ],
                             outputs=[
                                 self.editor_page, self.chat_page,
                                 self.chatbot, self.history, self.images
                             ])

        ########################################
        def exit_edit():
            return gr.update(visible=False), gr.update(visible=True, scale=3)

        self.ext_btn_1.click(exit_edit,
                             outputs=[self.editor_page, self.chat_page])
        self.ext_btn_2.click(exit_edit,
                             outputs=[self.editor_page, self.chat_page])
        self.ext_btn_3.click(exit_edit,
                             outputs=[self.editor_page, self.chat_page])
        self.ext_btn_4.click(exit_edit,
                             outputs=[self.editor_page, self.chat_page])

        ########################################
        def update_mask_type_info(mask_type):
            if mask_type == 'Background':
                info = 'Background mode will not erase the visual content in the mask area'
                visible = False
            elif mask_type == 'Composite':
                info = 'Composite mode will erase the visual content in the mask area'
                visible = False
            elif mask_type == 'Outpainting':
                info = 'Outpainting mode prepares the input image for the outpainting task'
                visible = True
            else:
                # Unknown mask type: show no hint and hide the outpainting tab.
                info = ''
                visible = False
            return (gr.update(
                visible=True,
                value=
                f"<div style='background-color: white; padding-left: 15px; color: grey;'>{info}</div>"
            ), gr.update(visible=visible))

        self.mask_type.change(update_mask_type_info,
                              inputs=self.mask_type,
                              outputs=[self.mask_type_info, self.outpaint_tab])

        ########################################
        def extend_image(top_ratio, bottom_ratio, left_ratio, right_ratio,
                         image):
            img = cv2.cvtColor(image['background'], cv2.COLOR_RGBA2RGB)
            h, w = img.shape[:2]
            new_h = int(h * (top_ratio + bottom_ratio + 1))
            new_w = int(w * (left_ratio + right_ratio + 1))
            start_h = int(h * top_ratio)
            start_w = int(w * left_ratio)
            new_img = np.zeros((new_h, new_w, 3), dtype=np.uint8)
            new_mask = np.ones((new_h, new_w, 1), dtype=np.uint8) * 255
            new_img[start_h:start_h + h, start_w:start_w + w, :] = img
            new_mask[start_h:start_h + h, start_w:start_w + w] = 0
            layer = np.concatenate([new_img, new_mask], axis=2)
            value = {
                'background': new_img,
                'composite': new_img,
                'layers': [layer]
            }
            return gr.update(value=value)

        self.img_pad_btn.click(extend_image,
                               inputs=[
                                   self.top_ext, self.bottom_ext,
                                   self.left_ext, self.right_ext,
                                   self.image_editor
                               ],
                               outputs=self.image_editor)

        ########################################
        def clear_chat(history, images, history_result):
            history.clear()
            images.clear()
            history_result.clear()
            return history, images, history_result, self.get_history(history)

        self.clear_btn.click(
            clear_chat,
            inputs=[self.history, self.images, self.history_result],
            outputs=[
                self.history, self.images, self.history_result, self.chatbot
            ])

    def get_history(self, history):
        info = []
        for item in history:
            new_item = [None, None]
            if isinstance(item[0], str) and item[0].endswith('.mp4'):
                new_item[0] = gr.Video(item[0], format='mp4')
            else:
                new_item[0] = item[0]
            if isinstance(item[1], str) and item[1].endswith('.mp4'):
                new_item[1] = gr.Video(item[1], format='mp4')
            else:
                new_item[1] = item[1]
            info.append(new_item)
        return info

    def generate_random_string(self, length=20):
        letters_and_digits = string.ascii_letters + string.digits
        random_string = ''.join(
            random.choice(letters_and_digits) for _ in range(length))
        return random_string

    def add_edited_image_to_history(self, image, mask_type, history, images):
        if mask_type == 'Composite':
            img = Image.fromarray(image['composite'])
        else:
            img = Image.fromarray(image['background'])

        img_id = get_md5(self.generate_random_string())[:12]
        save_path = os.path.join(self.cache_dir, f'{img_id}.png')
        img.convert('RGB').save(save_path)

        mask = image['layers'][0][:, :, 3]
        mask = Image.fromarray(mask).convert('RGB')
        mask_path = os.path.join(self.cache_dir, f'{img_id}_mask.png')
        mask.save(mask_path)

        w, h = img.size
        if w > h:
            tb_w = 128
            tb_h = int(h * tb_w / w)
        else:
            tb_h = 128
            tb_w = int(w * tb_h / h)

        if mask_type == 'Background':
            comp_mask = np.array(mask, dtype=np.uint8)
            mask_alpha = (comp_mask[:, :, 0:1].astype(np.float32) *
                          0.6).astype(np.uint8)
            comp_mask = np.concatenate([comp_mask, mask_alpha], axis=2)
            thumbnail = Image.alpha_composite(
                img.convert('RGBA'),
                Image.fromarray(comp_mask).convert('RGBA')).convert('RGB')
        else:
            thumbnail = img.convert('RGB')

        thumbnail_path = os.path.join(self.cache_dir,
                                      f'{img_id}_thumbnail.jpg')
        thumbnail = thumbnail.resize((tb_w, tb_h))
        thumbnail.save(thumbnail_path, format='JPEG')

        buffered = io.BytesIO()
        img.convert('RGB').save(buffered, format='PNG')
        img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
        img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'

        buffered = io.BytesIO()
        mask.convert('RGB').save(buffered, format='PNG')
        mask_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
        mask_str = f'<img src="data:image/png;base64,{mask_b64}" style="pointer-events: none;">'

        images[img_id] = {
            'image': save_path,
            'mask': mask_path,
            'mask_type': mask_type,
            'thumbnail': thumbnail_path
        }
        history.append((
            None,
            f'This is edited image and mask:\n {img_str} {mask_str} image ID is: {img_id}'
        ))
        return history, images

    def add_uploaded_image_to_history(self, img, history, images):
        img_id = get_md5(self.generate_random_string())[:12]
        save_path = os.path.join(self.cache_dir, f'{img_id}.png')
        w, h = img.size
        if w > 2048:
            ratio = w / 2048.
            w = 2048
            h = int(h / ratio)
        if h > 2048:
            ratio = h / 2048.
            h = 2048
            w = int(w / ratio)
        img = img.resize((w, h))
        img.save(save_path)

        w, h = img.size
        if w > h:
            tb_w = 128
            tb_h = int(h * tb_w / w)
        else:
            tb_h = 128
            tb_w = int(w * tb_h / h)
        thumbnail_path = os.path.join(self.cache_dir,
                                      f'{img_id}_thumbnail.jpg')
        thumbnail = img.resize((tb_w, tb_h))
        thumbnail.save(thumbnail_path, format='JPEG')

        images[img_id] = {
            'image': save_path,
            'mask': None,
            'mask_type': None,
            'thumbnail': thumbnail_path
        }

        buffered = io.BytesIO()
        img.convert('RGB').save(buffered, format='PNG')
        img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
        img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'

        history.append(
            (None,
             f'This is uploaded image:\n {img_str} image ID is: {img_id}'))
        return history, images, img_id


def run_gr(cfg):
    with gr.Blocks() as demo:
        chatbot = ChatBotUI(cfg)
        chatbot.create_ui()
        chatbot.set_callbacks()
        demo.launch(server_name='0.0.0.0',
                    server_port=cfg.args.server_port,
                    root_path=cfg.args.root_path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Argparser for Scepter:\n')
    parser.add_argument('--server_port',
                        dest='server_port',
                        help='Port the Gradio server listens on.',
                        type=int,
                        default=2345)
    parser.add_argument('--root_path',
                        dest='root_path',
                        help='Root path passed to Gradio launch().',
                        default='')
    cfg = Config(load=True, parser_ins=parser)
    run_gr(cfg)


================================================
FILE: chatbot/utils.py
================================================
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size),
                 interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
                              image_size):
    """Return the (cols, rows) grid in target_ratios closest to aspect_ratio.

    On a tie, a later candidate wins only if the image area exceeds half the
    area of that candidate's tiled canvas.
    """
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio


def dynamic_preprocess(image,
                       min_num=1,
                       max_num=12,
                       image_size=448,
                       use_thumbnail=False):
    """Split an image into image_size x image_size tiles.

    The image is resized to the tile grid (between min_num and max_num tiles)
    whose aspect ratio best matches the original. If use_thumbnail is set and
    more than one tile was produced, a resized copy of the whole image is
    appended.
    """
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # enumerate candidate (cols, rows) tile grids whose tile count lies in
    # [min_num, max_num], ordered by tile count
    target_ratios = set((i, j) for n in range(min_num, max_num + 1)
                        for i in range(1, n + 1) for j in range(1, n + 1)
                        if i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # pick the grid whose aspect ratio is closest to the image's
    target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio,
                                                    target_ratios, orig_width,
                                                    orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = ((i % (target_width // image_size)) * image_size,
               (i // (target_width // image_size)) * image_size,
               ((i % (target_width // image_size)) + 1) * image_size,
               ((i // (target_width // image_size)) + 1) * image_size)
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images


def load_image(image_file, input_size=448, max_num=12):
    """Load an image (file path or PIL image), tile it, and normalize it.

    Returns a tensor of shape (num_tiles, 3, input_size, input_size).
    """
    if isinstance(image_file, str):
        image = Image.open(image_file).convert('RGB')
    else:
        image = image_file
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image,
                                image_size=input_size,
                                use_thumbnail=True,
                                max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values


================================================
FILE: config/inference_config/chatbot_ui.yaml
================================================
WORK_DIR: ./cache/chatbot
FILE_SYSTEM:
  - NAME: "HuggingfaceFs"
    TEMP_DIR: ./cache
  - NAME: "ModelscopeFs"
    TEMP_DIR: ./cache
  - NAME: "LocalFs"
    TEMP_DIR: ./cache
  - NAME: "HttpFs"
    TEMP_DIR: ./cache
#
ENABLE_I2V: False
SKIP_EXAMPLES: False
#
MODEL:
  EDIT_MODEL:
    MODEL_CFG_DIR: config/inference_config/models/
    DEFAULT: ace_0.6b_512
  I2V:
    MODEL_NAME: CogVideoX-5b-I2V
    MODEL_DIR: ms://ZhipuAI/CogVideoX-5b-I2V/
  CAPTIONER:
    MODEL_NAME: InternVL2-2B
    MODEL_DIR: ms://OpenGVLab/InternVL2-2B/
    PROMPT: '<image>\nThis image is the first frame of a video. Based on this image, please imagine what changes may occur in the next few seconds of the video. Please output brief description, such as "a dog running" or "a person turns to left". No more than 30 words.'
  ENHANCER:
    MODEL_NAME: Meta-Llama-3.1-8B-Instruct
    MODEL_DIR: ms://LLM-Research/Meta-Llama-3.1-8B-Instruct/


================================================
FILE: config/inference_config/models/ace_0.6b_1024.yaml
================================================
NAME: ACE_0.6B_1024
IS_DEFAULT: False
USE_DYNAMIC_MODEL: False
DEFAULT_PARAS:
  PARAS:
  #
  INPUT:
    INPUT_IMAGE:
    INPUT_MASK:
    TASK:
    PROMPT: ""
    NEGATIVE_PROMPT: ""
    OUTPUT_HEIGHT: 1024
    OUTPUT_WIDTH: 1024
    SAMPLER: ddim
    SAMPLE_STEPS: 50
    GUIDE_SCALE: 4.5
    GUIDE_RESCALE: 0.5
    SEED: -1
    TAR_INDEX: 0
  OUTPUT:
    LATENT:
    IMAGES:
    SEED:
  MODULES_PARAS:
    FIRST_STAGE_MODEL:
      FUNCTION:
        - NAME: encode
          DTYPE: float16
          INPUT: ["IMAGE"]
        - NAME: decode
          DTYPE: float16
          INPUT: ["LATENT"]
    #
    DIFFUSION_MODEL:
      FUNCTION:
        - NAME: forward
          DTYPE: float16
          INPUT: ["SAMPLE_STEPS", "SAMPLE", "GUIDE_SCALE"]
    #
    COND_STAGE_MODEL:
      FUNCTION:
        - NAME: encode_list
          DTYPE: bfloat16
          INPUT: ["PROMPT"]
#
MODEL:
  NAME: LdmACE
  PRETRAINED_MODEL:
  IGNORE_KEYS: [ ]
  SCALE_FACTOR: 0.18215
  SIZE_FACTOR: 8
  DECODER_BIAS: 0.5
  DEFAULT_N_PROMPT: ""
  TEXT_IDENTIFIER: [ '{image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]
  USE_TEXT_POS_EMBEDDINGS: True
  #
  DIFFUSION:
    NAME: ACEDiffusion
    PREDICTION_TYPE: eps
    MIN_SNR_GAMMA:
    NOISE_SCHEDULER:
      NAME: LinearScheduler
      NUM_TIMESTEPS: 1000
      BETA_MIN: 0.0001
      BETA_MAX: 0.02
  #
  DIFFUSION_MODEL:
    NAME: DiTACE
    PRETRAINED_MODEL: ms://iic/ACE-0.6B-1024px@models/dit/ace_0.6b_1024px.pth
    IGNORE_KEYS: [ ]
    PATCH_SIZE: 2
    IN_CHANNELS: 4
    HIDDEN_SIZE: 1152
    DEPTH: 28
    NUM_HEADS: 16
    MLP_RATIO: 4.0
    PRED_SIGMA: True
    DROP_PATH: 0.0
    WINDOW_DIZE: 0
    Y_CHANNELS: 4096
    MAX_SEQ_LEN: 4096
    QK_NORM: True
    USE_GRAD_CHECKPOINT: True
    ATTENTION_BACKEND: flash_attn
  #
  FIRST_STAGE_MODEL:
    NAME: AutoencoderKL
    EMBED_DIM: 4
    PRETRAINED_MODEL: ms://iic/ACE-0.6B-1024px@models/vae/vae.bin
    IGNORE_KEYS: []
    #
    ENCODER:
      NAME: Encoder
      CH: 128
      OUT_CH: 3
      NUM_RES_BLOCKS: 2
      IN_CHANNELS: 3
      ATTN_RESOLUTIONS: [ ]
      CH_MULT: [ 1, 2, 4, 4 ]
      Z_CHANNELS: 4
      DOUBLE_Z: True
      DROPOUT: 0.0
      RESAMP_WITH_CONV: True
    #
    DECODER:
      NAME: Decoder
      CH: 128
      OUT_CH: 3
      NUM_RES_BLOCKS: 2
      IN_CHANNELS: 3
      ATTN_RESOLUTIONS: [ ]
      CH_MULT: [ 1, 2, 4, 4 ]
      Z_CHANNELS: 4
      DROPOUT: 0.0
      RESAMP_WITH_CONV: True
      GIVE_PRE_END: False
      TANH_OUT: False
  #
  COND_STAGE_MODEL:
    NAME: ACETextEmbedder
    PRETRAINED_MODEL: ms://iic/ACE-0.6B-1024px@models/text_encoder/t5-v1_1-xxl/
    TOKENIZER_PATH: ms://iic/ACE-0.6B-1024px@models/tokenizer/t5-v1_1-xxl
    LENGTH: 120
    T5_DTYPE: bfloat16
    ADDED_IDENTIFIER: [ '{image}', '{caption}', '{mask}', '{ref_image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]
    CLEAN: whitespace
    USE_GRAD: False


================================================
FILE: config/inference_config/models/ace_0.6b_1024_refiner.yaml
================================================
NAME: ACE_0.6B_1024_REFINER
IS_DEFAULT: False
USE_DYNAMIC_MODEL: False
DEFAULT_PARAS:
  PARAS:
  #
  INPUT:
    INPUT_IMAGE:
    INPUT_MASK:
    TASK:
    PROMPT: ""
    NEGATIVE_PROMPT: ""
    OUTPUT_HEIGHT: 1024
    OUTPUT_WIDTH: 1024
    SAMPLER: ddim
    SAMPLE_STEPS: 50
    GUIDE_SCALE: 4.5
    GUIDE_RESCALE: 0.5
    SEED: -1
    TAR_INDEX: 0
    REFINER_SCALE: 0.2
    USE_ACE: True
    REFINER_PROMPT: "High Resolution, Sharpness, Clarity, Detail Enhancement, Noise Reduction, HD, 4k, Image Restoration, HDR"
  OUTPUT:
    LATENT:
    IMAGES:
    SEED:
  MODULES_PARAS:
    FIRST_STAGE_MODEL:
      FUNCTION:
        - NAME: encode
          DTYPE: float16
          INPUT: ["IMAGE"]
        - NAME: decode
          DTYPE: float16
          INPUT: ["LATENT"]
    #
    DIFFUSION_MODEL:
      FUNCTION:
        - NAME: forward
          DTYPE: float16
          INPUT: ["SAMPLE_STEPS", "SAMPLE", "GUIDE_SCALE"]
    #
    COND_STAGE_MODEL:
      FUNCTION:
        - NAME: encode_list
          DTYPE: bfloat16
          INPUT: ["PROMPT"]
#
MODEL:
  NAME: LdmACE
  PRETRAINED_MODEL:
  IGNORE_KEYS: [ ]
  SCALE_FACTOR: 0.18215
  SIZE_FACTOR: 8
  DECODER_BIAS: 0.5
  DEFAULT_N_PROMPT: ""
  TEXT_IDENTIFIER: [ '{image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]
  USE_TEXT_POS_EMBEDDINGS: True
  #
  DIFFUSION:
    NAME: ACEDiffusion
    PREDICTION_TYPE: eps
    MIN_SNR_GAMMA:
    NOISE_SCHEDULER:
      NAME: LinearScheduler
      NUM_TIMESTEPS: 1000
      BETA_MIN: 0.0001
      BETA_MAX: 0.02
  #
  DIFFUSION_MODEL:
    NAME: DiTACE
    PRETRAINED_MODEL: ms://iic/ACE-0.6B-1024px@models/dit/ace_0.6b_1024px.pth
    IGNORE_KEYS: [ ]
    PATCH_SIZE: 2
    IN_CHANNELS: 4
    HIDDEN_SIZE: 1152
    DEPTH: 28
    NUM_HEADS: 16
    MLP_RATIO: 4.0
    PRED_SIGMA: True
    DROP_PATH: 0.0
    WINDOW_DIZE: 0
    Y_CHANNELS: 4096
    MAX_SEQ_LEN: 4096
    QK_NORM: True
    USE_GRAD_CHECKPOINT: True
    ATTENTION_BACKEND: flash_attn
  #
  FIRST_STAGE_MODEL:
    NAME: AutoencoderKL
    EMBED_DIM: 4
    PRETRAINED_MODEL: ms://iic/ACE-0.6B-1024px@models/vae/vae.bin
    IGNORE_KEYS: []
    #
    ENCODER:
      NAME: Encoder
      CH: 128
      OUT_CH: 3
      NUM_RES_BLOCKS: 2
      IN_CHANNELS: 3
      ATTN_RESOLUTIONS: [ ]
      CH_MULT: [ 1, 2, 4, 4 ]
      Z_CHANNELS: 4
      DOUBLE_Z: True
      DROPOUT: 0.0
      RESAMP_WITH_CONV: True
    #
    DECODER:
      NAME: Decoder
      CH: 128
      OUT_CH: 3
      NUM_RES_BLOCKS: 2
      IN_CHANNELS: 3
      ATTN_RESOLUTIONS: [ ]
      CH_MULT: [ 1, 2, 4, 4 ]
      Z_CHANNELS: 4
      DROPOUT: 0.0
      RESAMP_WITH_CONV: True
      GIVE_PRE_END: False
      TANH_OUT: False
  #
  COND_STAGE_MODEL:
    NAME: ACETextEmbedder
    PRETRAINED_MODEL: ms://iic/ACE-0.6B-1024px@models/text_encoder/t5-v1_1-xxl/
    TOKENIZER_PATH: ms://iic/ACE-0.6B-1024px@models/tokenizer/t5-v1_1-xxl
    LENGTH: 120
    T5_DTYPE: bfloat16
    ADDED_IDENTIFIER: [ '{image}', '{caption}', '{mask}', '{ref_image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]
    CLEAN: whitespace
    USE_GRAD: False

ACE_PROMPT: [
  "A cute cartoon rabbit holding a whiteboard that says 'ACE Refiner', standing in a sunny meadow filled with flowers, with a big smile and bright colors.",
  "A beautiful young woman with long flowing hair, wearing a summer dress, holding a whiteboard that reads 'ACE Refiner' while sitting on a park bench surrounded by cherry blossoms.",
  "An adorable cartoon cat wearing oversized glasses, holding a whiteboard that says 'ACE Refiner', perched on a stack of colorful books in a cozy library setting.",
  "A charming girl with pigtails, wearing a cute school uniform, enthusiastically holding a whiteboard that has 'ACE Refiner' written on it, in a bright and cheerful classroom full of educational posters.",
  "A friendly cartoon dog with floppy ears, sitting in front of a doghouse, proudly holding a whiteboard that says 'ACE Refiner', with a playful expression and a blue sky in the background.",
  "A cute anime girl with big expressive eyes, dressed in a colorful outfit, holding a whiteboard that reads 'ACE Refiner' in a fantastical landscape filled with mythical creatures.",
  "A vibrant cartoon fox holding a whiteboard that says 'ACE Refiner', standing on a rock by a sparkling stream, surrounded by lush greenery and butterflies.",
  "A stylish young woman in a business outfit, smiling as she holds a whiteboard written with 'ACE Refiner', in a modern office filled with plants and natural light.",
  "A cute cartoon unicorn holding a sparkling whiteboard that says 'ACE Refiner', frolicking in a magical forest, with rainbows and stars in the background.",
  "A happy family, consisting of a cute little girl and her playful puppy, holding a whiteboard that says 'ACE Refiner', together in their backyard on a sunny day."
]
REFINER_MODEL:
  NAME: ""
  IS_DEFAULT: False
  DEFAULT_PARAS:
    PARAS:
      RESOLUTIONS: [ [ 1024, 1024 ] ]
    INPUT:
      INPUT_IMAGE:
      INPUT_MASK:
      TASK:
      PROMPT: ""
      NEGATIVE_PROMPT: ""
      OUTPUT_HEIGHT: 1024
      OUTPUT_WIDTH: 1024
      SAMPLER: flow_euler
      SAMPLE_STEPS: 30
      GUIDE_SCALE: 3.5
      GUIDE_RESCALE:
    OUTPUT:
      LATENT:
      IMAGES:
      SEED:
    MODULES_PARAS:
      FIRST_STAGE_MODEL:
        FUNCTION:
          - NAME: encode
            DTYPE: bfloat16
            INPUT: [ "IMAGE" ]
          - NAME: decode
            DTYPE: bfloat16
            INPUT: [ "LATENT" ]
        PARAS:
          SCALE_FACTOR: 1.5305
          SHIFT_FACTOR: 0.0609
          SIZE_FACTOR: 8
      DIFFUSION_MODEL:
        FUNCTION:
          - NAME: forward
            DTYPE: bfloat16
            INPUT: [ "SAMPLE_STEPS", "SAMPLE", "GUIDE_SCALE" ]
      COND_STAGE_MODEL:
        FUNCTION:
          - NAME: encode
            DTYPE: bfloat16
            INPUT: [ "PROMPT" ]

  MODEL:
    DIFFUSION:
      NAME: DiffusionFluxRF
      PREDICTION_TYPE: raw
      NOISE_SCHEDULER:
        NAME: FlowMatchSigmaScheduler
        WEIGHTING_SCHEME: logit_normal
        SHIFT: 3.0
        LOGIT_MEAN: 0.0
        LOGIT_STD: 1.0
        MODE_SCALE: 1.29
    DIFFUSION_MODEL:
      NAME: FluxMR
      PRETRAINED_MODEL: ms://AI-ModelScope/FLUX.1-dev@flux1-dev.safetensors
      IN_CHANNELS: 64
      OUT_CHANNELS: 64
      HIDDEN_SIZE: 3072
      NUM_HEADS: 24
      AXES_DIM: [ 16, 56, 56 ]
      THETA: 10000
      VEC_IN_DIM: 768
      GUIDANCE_EMBED: True
      CONTEXT_IN_DIM: 4096
      MLP_RATIO: 4.0
      QKV_BIAS: True
      DEPTH: 19
      DEPTH_SINGLE_BLOCKS: 38
      USE_GRAD_CHECKPOINT: True
      ATTN_BACKEND: flash_attn
    #
    FIRST_STAGE_MODEL:
      NAME: AutoencoderKLFlux
      EMBED_DIM: 16
      PRETRAINED_MODEL:  ms://AI-ModelScope/FLUX.1-dev@ae.safetensors
      IGNORE_KEYS: [ ]
      BATCH_SIZE: 8
      USE_CONV: False
      SCALE_FACTOR: 0.3611
      SHIFT_FACTOR: 0.1159
      #
      ENCODER:
        NAME: Encoder
        USE_CHECKPOINT: False
        CH: 128
        OUT_CH: 3
        NUM_RES_BLOCKS: 2
        IN_CHANNELS: 3
        ATTN_RESOLUTIONS: [ ]
        CH_MULT: [ 1, 2, 4, 4 ]
        Z_CHANNELS: 16
        DOUBLE_Z: True
        DROPOUT: 0.0
        RESAMP_WITH_CONV: True
      #
      DECODER:
        NAME: Decoder
        USE_CHECKPOINT: False
        CH: 128
        OUT_CH: 3
        NUM_RES_BLOCKS: 2
        IN_CHANNELS: 3
        ATTN_RESOLUTIONS: [ ]
        CH_MULT: [ 1, 2, 4, 4 ]
        Z_CHANNELS: 16
        DROPOUT: 0.0
        RESAMP_WITH_CONV: True
        GIVE_PRE_END: False
        TANH_OUT: False
    #
    COND_STAGE_MODEL:
      NAME: T5PlusClipFluxEmbedder
      T5_MODEL:
        NAME: HFEmbedder
        HF_MODEL_CLS: T5EncoderModel
        MODEL_PATH: ms://AI-ModelScope/FLUX.1-dev@text_encoder_2/
        HF_TOKENIZER_CLS: T5Tokenizer
        TOKENIZER_PATH: ms://AI-ModelScope/FLUX.1-dev@tokenizer_2/
        MAX_LENGTH: 512
        OUTPUT_KEY: last_hidden_state
        D_TYPE: bfloat16
        BATCH_INFER: False
        CLEAN: whitespace
      CLIP_MODEL:
        NAME: HFEmbedder
        HF_MODEL_CLS: CLIPTextModel
        MODEL_PATH: ms://AI-ModelScope/FLUX.1-dev@text_encoder/
        HF_TOKENIZER_CLS: CLIPTokenizer
        TOKENIZER_PATH: ms://AI-ModelScope/FLUX.1-dev@tokenizer/
        MAX_LENGTH: 77
        OUTPUT_KEY: pooler_output
        D_TYPE: bfloat16
        BATCH_INFER: True
        CLEAN: whitespace


================================================
FILE: config/inference_config/models/ace_0.6b_512.yaml
================================================
NAME: ACE_0.6B_512
IS_DEFAULT: True
USE_DYNAMIC_MODEL: False
DEFAULT_PARAS:
  PARAS:
  #
  INPUT:
    INPUT_IMAGE:
    INPUT_MASK:
    TASK:
    PROMPT: ""
    NEGATIVE_PROMPT: ""
    OUTPUT_HEIGHT: 512
    OUTPUT_WIDTH: 512
    SAMPLER: ddim
    SAMPLE_STEPS: 20
    GUIDE_SCALE: 4.5
    GUIDE_RESCALE: 0.5
    SEED: -1
    TAR_INDEX: 0
  OUTPUT:
    LATENT:
    IMAGES:
    SEED:
  MODULES_PARAS:
    FIRST_STAGE_MODEL:
      FUNCTION:
        - NAME: encode
          DTYPE: float16
          INPUT: ["IMAGE"]
        - NAME: decode
          DTYPE: float16
          INPUT: ["LATENT"]
    #
    DIFFUSION_MODEL:
      FUNCTION:
        - NAME: forward
          DTYPE: float16
          INPUT: ["SAMPLE_STEPS", "SAMPLE", "GUIDE_SCALE"]
    #
    COND_STAGE_MODEL:
      FUNCTION:
        - NAME: encode_list
          DTYPE: bfloat16
          INPUT: ["PROMPT"]
#
MODEL:
  NAME: LdmACE
  PRETRAINED_MODEL:
  IGNORE_KEYS: [ ]
  SCALE_FACTOR: 0.18215
  SIZE_FACTOR: 8
  DECODER_BIAS: 0.5
  DEFAULT_N_PROMPT: ""
  TEXT_IDENTIFIER: [ '{image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]
  USE_TEXT_POS_EMBEDDINGS: True
  #
  DIFFUSION:
    NAME: ACEDiffusion
    PREDICTION_TYPE: eps
    MIN_SNR_GAMMA:
    NOISE_SCHEDULER:
      NAME: LinearScheduler
      NUM_TIMESTEPS: 1000
      BETA_MIN: 0.0001
      BETA_MAX: 0.02
  #
  DIFFUSION_MODEL:
    NAME: DiTACE
    PRETRAINED_MODEL: ms://iic/ACE-0.6B-512px@models/dit/ace_0.6b_512px.pth
    IGNORE_KEYS: [ ]
    PATCH_SIZE: 2
    IN_CHANNELS: 4
    HIDDEN_SIZE: 1152
    DEPTH: 28
    NUM_HEADS: 16
    MLP_RATIO: 4.0
    PRED_SIGMA: True
    DROP_PATH: 0.0
    WINDOW_SIZE: 0
    Y_CHANNELS: 4096
    MAX_SEQ_LEN: 1024
    QK_NORM: True
    USE_GRAD_CHECKPOINT: True
    ATTENTION_BACKEND: flash_attn
  #
  FIRST_STAGE_MODEL:
    NAME: AutoencoderKL
    EMBED_DIM: 4
    PRETRAINED_MODEL: ms://iic/ACE-0.6B-512px@models/vae/vae.bin
    IGNORE_KEYS: []
    #
    ENCODER:
      NAME: Encoder
      CH: 128
      OUT_CH: 3
      NUM_RES_BLOCKS: 2
      IN_CHANNELS: 3
      ATTN_RESOLUTIONS: [ ]
      CH_MULT: [ 1, 2, 4, 4 ]
      Z_CHANNELS: 4
      DOUBLE_Z: True
      DROPOUT: 0.0
      RESAMP_WITH_CONV: True
    #
    DECODER:
      NAME: Decoder
      CH: 128
      OUT_CH: 3
      NUM_RES_BLOCKS: 2
      IN_CHANNELS: 3
      ATTN_RESOLUTIONS: [ ]
      CH_MULT: [ 1, 2, 4, 4 ]
      Z_CHANNELS: 4
      DROPOUT: 0.0
      RESAMP_WITH_CONV: True
      GIVE_PRE_END: False
      TANH_OUT: False
  #
  COND_STAGE_MODEL:
    NAME: ACETextEmbedder
    PRETRAINED_MODEL: ms://iic/ACE-0.6B-512px@models/text_encoder/t5-v1_1-xxl/
    TOKENIZER_PATH: ms://iic/ACE-0.6B-512px@models/tokenizer/t5-v1_1-xxl
    LENGTH: 120
    T5_DTYPE: bfloat16
    ADDED_IDENTIFIER: [ '{image}', '{caption}', '{mask}', '{ref_image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]
    CLEAN: whitespace
    USE_GRAD: False
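
The 512px config above pairs `SIZE_FACTOR: 8` and `PATCH_SIZE: 2` with `MAX_SEQ_LEN: 1024`. A minimal sketch of the arithmetic that connects them, assuming `MAX_SEQ_LEN` is read as a per-image token budget (which is how `ACEDemoDataset.image_preprocess` uses its own `MAX_SEQ_LEN`). The helper name `tokens_per_image` is hypothetical and is not part of the repository.

```python
def tokens_per_image(height, width, size_factor=8, patch_size=2):
    """Count the patch tokens for one image (hypothetical helper).

    Assumption: the VAE downsamples by size_factor and the DiT patch
    embedding downsamples by patch_size again, so one token covers a
    (size_factor * patch_size)-pixel square.
    """
    stride = size_factor * patch_size  # 8 * 2 = 16 pixels per token side
    return (height // stride) * (width // stride)


print(tokens_per_image(512, 512))    # 1024
print(tokens_per_image(1024, 1024))  # 4096
```

Under this reading, a 512x512 output uses exactly the 1024 tokens configured here, and a 1024x1024 output uses the 4096 tokens set in the 1024px training config.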


================================================
FILE: config/train_config/ace_0.6b_1024_train.yaml
================================================
ENV:
  BACKEND: nccl
  SEED: 2024
#
SOLVER:
  NAME: ACESolverV1
  RESUME_FROM:
  LOAD_MODEL_ONLY: True
  USE_FSDP: False
  SHARDING_STRATEGY:
  USE_AMP: True
  DTYPE: float16
  CHANNELS_LAST: True
  MAX_STEPS: 500
  MAX_EPOCHS: -1
  NUM_FOLDS: 1
  ACCU_STEP: 1
  EVAL_INTERVAL: 50
  RESCALE_LR: False
  #
  WORK_DIR: ./cache/exp/exp1
  LOG_FILE: std_log.txt
  #
  FILE_SYSTEM:
    - NAME: "HuggingfaceFs"
      TEMP_DIR: ./cache
    - NAME: "ModelscopeFs"
      TEMP_DIR: ./cache
    - NAME: "LocalFs"
      TEMP_DIR: ./cache
    - NAME: "HttpFs"
      TEMP_DIR: ./cache
  #
  MODEL:
    NAME: LdmACE
    PRETRAINED_MODEL:
    IGNORE_KEYS: [ ]
    SCALE_FACTOR: 0.18215
    SIZE_FACTOR: 8
    DECODER_BIAS: 0.5
    DEFAULT_N_PROMPT:
    USE_EMA: True
    EVAL_EMA: False
    TEXT_IDENTIFIER: [ '{image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]
    USE_TEXT_POS_EMBEDDINGS: True
    #
    DIFFUSION:
      NAME: ACEDiffusion
      PREDICTION_TYPE: eps
      MIN_SNR_GAMMA:
      NOISE_SCHEDULER:
        NAME: LinearScheduler
        NUM_TIMESTEPS: 1000
        BETA_MIN: 0.0001
        BETA_MAX: 0.02
    #
    DIFFUSION_MODEL:
      NAME: DiTACE
      PRETRAINED_MODEL: ms://iic/ACE-0.6B-512px@models/dit/ace_0.6b_512px.pth
      IGNORE_KEYS: [ ]
      PATCH_SIZE: 2
      IN_CHANNELS: 4
      HIDDEN_SIZE: 1152
      DEPTH: 28
      NUM_HEADS: 16
      MLP_RATIO: 4.0
      PRED_SIGMA: True
      DROP_PATH: 0.0
      WINDOW_SIZE: 0
      Y_CHANNELS: 4096
      MAX_SEQ_LEN: 4096
      QK_NORM: True
      USE_GRAD_CHECKPOINT: True
      ATTENTION_BACKEND: flash_attn
    #
    FIRST_STAGE_MODEL:
      NAME: AutoencoderKL
      EMBED_DIM: 4
      PRETRAINED_MODEL: ms://iic/ACE-0.6B-1024px@models/vae/vae.bin
      IGNORE_KEYS: []
      #
      ENCODER:
        NAME: Encoder
        CH: 128
        OUT_CH: 3
        NUM_RES_BLOCKS: 2
        IN_CHANNELS: 3
        ATTN_RESOLUTIONS: [ ]
        CH_MULT: [ 1, 2, 4, 4 ]
        Z_CHANNELS: 4
        DOUBLE_Z: True
        DROPOUT: 0.0
        RESAMP_WITH_CONV: True
      #
      DECODER:
        NAME: Decoder
        CH: 128
        OUT_CH: 3
        NUM_RES_BLOCKS: 2
        IN_CHANNELS: 3
        ATTN_RESOLUTIONS: [ ]
        CH_MULT: [ 1, 2, 4, 4 ]
        Z_CHANNELS: 4
        DROPOUT: 0.0
        RESAMP_WITH_CONV: True
        GIVE_PRE_END: False
        TANH_OUT: False
    #
    COND_STAGE_MODEL:
      NAME: T5EmbedderHF
      PRETRAINED_MODEL: ms://iic/ACE-0.6B-1024px@models/text_encoder/t5-v1_1-xxl/
      TOKENIZER_PATH: ms://iic/ACE-0.6B-1024px@models/tokenizer/t5-v1_1-xxl
      LENGTH: 120
      T5_DTYPE: bfloat16
      ADDED_IDENTIFIER: [ '{image}', '{caption}', '{mask}', '{ref_image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]
      CLEAN: whitespace
      USE_GRAD: False
    LOSS:
      NAME: ReconstructLoss
      LOSS_TYPE: l2
  #
  SAMPLE_ARGS:
    SAMPLER: ddim
    SAMPLE_STEPS: 20
    GUIDE_SCALE: 4.5
    GUIDE_RESCALE: 0.5
  #
  OPTIMIZER:
    NAME: AdamW
    LEARNING_RATE: 1e-5
    EPS: 1e-10
    WEIGHT_DECAY: 5e-4
  #
  TRAIN_DATA:
    NAME: ACEDemoDataset
    MODE: train
    MS_DATASET_NAME: cache/datasets/hed_pair
    MS_DATASET_NAMESPACE: ""
    MS_DATASET_SPLIT: "train"
    MS_DATASET_SUBNAME: ""
    PROMPT_PREFIX: ""
    REPLACE_STYLE: False
    MAX_SEQ_LEN: 4096
    PIN_MEMORY: True
    BATCH_SIZE: 1
    NUM_WORKERS: 1
    SAMPLER:
      NAME: LoopSampler
  #
  TRAIN_HOOKS:
    -
      NAME: BackwardHook
      PRIORITY: 0
    -
      NAME: LogHook
      LOG_INTERVAL: 50
    -
      NAME: CheckpointHook
      INTERVAL: 100
    -
      NAME: ProbeDataHook
      PROB_INTERVAL: 100


================================================
FILE: config/train_config/ace_0.6b_512_train.yaml
================================================
ENV:
  BACKEND: nccl
  SEED: 2024
#
SOLVER:
  NAME: ACESolverV1
  RESUME_FROM:
  LOAD_MODEL_ONLY: True
  USE_FSDP: False
  SHARDING_STRATEGY:
  USE_AMP: True
  DTYPE: float16
  CHANNELS_LAST: True
  MAX_STEPS: 500
  MAX_EPOCHS: -1
  NUM_FOLDS: 1
  ACCU_STEP: 1
  EVAL_INTERVAL: 50
  RESCALE_LR: False
  #
  WORK_DIR: ./cache/exp/exp1
  LOG_FILE: std_log.txt
  #
  FILE_SYSTEM:
    - NAME: "HuggingfaceFs"
      TEMP_DIR: ./cache
    - NAME: "ModelscopeFs"
      TEMP_DIR: ./cache
    - NAME: "LocalFs"
      TEMP_DIR: ./cache
    - NAME: "HttpFs"
      TEMP_DIR: ./cache
  #
  MODEL:
    NAME: LdmACE
    PRETRAINED_MODEL:
    IGNORE_KEYS: [ ]
    SCALE_FACTOR: 0.18215
    SIZE_FACTOR: 8
    DECODER_BIAS: 0.5
    DEFAULT_N_PROMPT:
    USE_EMA: True
    EVAL_EMA: False
    TEXT_IDENTIFIER: [ '{image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]
    USE_TEXT_POS_EMBEDDINGS: True
    #
    DIFFUSION:
      NAME: ACEDiffusion
      PREDICTION_TYPE: eps
      MIN_SNR_GAMMA:
      NOISE_SCHEDULER:
        NAME: LinearScheduler
        NUM_TIMESTEPS: 1000
        BETA_MIN: 0.0001
        BETA_MAX: 0.02
    #
    DIFFUSION_MODEL:
      NAME: DiTACE
      PRETRAINED_MODEL: ms://iic/ACE-0.6B-512px@models/dit/ace_0.6b_512px.pth
      IGNORE_KEYS: [ ]
      PATCH_SIZE: 2
      IN_CHANNELS: 4
      HIDDEN_SIZE: 1152
      DEPTH: 28
      NUM_HEADS: 16
      MLP_RATIO: 4.0
      PRED_SIGMA: True
      DROP_PATH: 0.0
      WINDOW_SIZE: 0
      Y_CHANNELS: 4096
      MAX_SEQ_LEN: 1024
      QK_NORM: True
      USE_GRAD_CHECKPOINT: True
      ATTENTION_BACKEND: flash_attn
    #
    FIRST_STAGE_MODEL:
      NAME: AutoencoderKL
      EMBED_DIM: 4
      PRETRAINED_MODEL: ms://iic/ACE-0.6B-512px@models/vae/vae.bin
      IGNORE_KEYS: []
      #
      ENCODER:
        NAME: Encoder
        CH: 128
        OUT_CH: 3
        NUM_RES_BLOCKS: 2
        IN_CHANNELS: 3
        ATTN_RESOLUTIONS: [ ]
        CH_MULT: [ 1, 2, 4, 4 ]
        Z_CHANNELS: 4
        DOUBLE_Z: True
        DROPOUT: 0.0
        RESAMP_WITH_CONV: True
      #
      DECODER:
        NAME: Decoder
        CH: 128
        OUT_CH: 3
        NUM_RES_BLOCKS: 2
        IN_CHANNELS: 3
        ATTN_RESOLUTIONS: [ ]
        CH_MULT: [ 1, 2, 4, 4 ]
        Z_CHANNELS: 4
        DROPOUT: 0.0
        RESAMP_WITH_CONV: True
        GIVE_PRE_END: False
        TANH_OUT: False
    #
    COND_STAGE_MODEL:
      NAME: ACETextEmbedder
      PRETRAINED_MODEL: ms://iic/ACE-0.6B-512px@models/text_encoder/t5-v1_1-xxl/
      TOKENIZER_PATH: ms://iic/ACE-0.6B-512px@models/tokenizer/t5-v1_1-xxl
      LENGTH: 120
      T5_DTYPE: bfloat16
      ADDED_IDENTIFIER: [ '{image}', '{caption}', '{mask}', '{ref_image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]
      CLEAN: whitespace
      USE_GRAD: False
    LOSS:
      NAME: ReconstructLoss
      LOSS_TYPE: l2
  #
  SAMPLE_ARGS:
    SAMPLER: ddim
    SAMPLE_STEPS: 20
    GUIDE_SCALE: 4.5
    GUIDE_RESCALE: 0.5
  #
  OPTIMIZER:
    NAME: AdamW
    LEARNING_RATE: 1e-5
    EPS: 1e-10
    WEIGHT_DECAY: 5e-4
  #
  TRAIN_DATA:
    NAME: ACEDemoDataset
    MODE: train
    MS_DATASET_NAME: cache/datasets/hed_pair
    MS_DATASET_NAMESPACE: ""
    MS_DATASET_SPLIT: "train"
    MS_DATASET_SUBNAME: ""
    PROMPT_PREFIX: ""
    REPLACE_STYLE: False
    MAX_SEQ_LEN: 1024
    PIN_MEMORY: True
    BATCH_SIZE: 1
    NUM_WORKERS: 1
    SAMPLER:
      NAME: LoopSampler
  #
  TRAIN_HOOKS:
    -
      NAME: BackwardHook
      PRIORITY: 0
    -
      NAME: LogHook
      LOG_INTERVAL: 50
    -
      NAME: CheckpointHook
      INTERVAL: 100
    -
      NAME: ProbeDataHook
      PROB_INTERVAL: 100
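
The `NOISE_SCHEDULER` block above sets `NUM_TIMESTEPS: 1000`, `BETA_MIN: 0.0001` and `BETA_MAX: 0.02`. The sketch below shows what those numbers give under the plain DDPM convention of linearly spaced betas. This is an assumption: the actual `LinearScheduler` is implemented in scepter, is not shown in this dump, and may parameterize the schedule differently. Both function names are hypothetical.

```python
def linear_betas(num_timesteps=1000, beta_min=1e-4, beta_max=0.02):
    """Linearly spaced betas from beta_min to beta_max (assumed convention)."""
    step = (beta_max - beta_min) / (num_timesteps - 1)
    return [beta_min + i * step for i in range(num_timesteps)]


def alphas_cumprod(betas):
    """Running product of (1 - beta_t): the signal fraction kept at step t."""
    out, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        out.append(running)
    return out


betas = linear_betas()
alpha_bar = alphas_cumprod(betas)
# alpha_bar starts close to 1 (almost no noise) and decays toward 0.
print(alpha_bar[0], alpha_bar[-1])
```

With `PREDICTION_TYPE: eps`, the network is trained to predict the noise that was mixed in at a given step, so the decay of `alpha_bar` is what determines how much of the clean latent survives at each timestep.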


================================================
FILE: modules/__init__.py
================================================
from . import data, model, solver

================================================
FILE: modules/data/__init__.py
================================================
from . import dataset

================================================
FILE: modules/data/dataset/__init__.py
================================================
from .dataset import ACEDemoDataset

================================================
FILE: modules/data/dataset/dataset.py
================================================
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.

import io
import math
import os
import sys
from collections import defaultdict

import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode

from scepter.modules.data.dataset.base_dataset import BaseDataset
from scepter.modules.data.dataset.registry import DATASETS
from scepter.modules.transform.io import pillow_convert
from scepter.modules.utils.config import dict_to_yaml
from scepter.modules.utils.file_system import FS

Image.MAX_IMAGE_PIXELS = None

@DATASETS.register_class()
class ACEDemoDataset(BaseDataset):
    para_dict = {
        'MS_DATASET_NAME': {
            'value': '',
            'description': 'Modelscope dataset name.'
        },
        'MS_DATASET_NAMESPACE': {
            'value': '',
            'description': 'Modelscope dataset namespace.'
        },
        'MS_DATASET_SUBNAME': {
            'value': '',
            'description': 'Modelscope dataset subname.'
        },
        'MS_DATASET_SPLIT': {
            'value': '',
            'description':
            'Modelscope dataset split set name, default is train.'
        },
        'MS_REMAP_KEYS': {
            'value':
            None,
            'description':
            'Header of the Modelscope dataset list file; the default is Target:FILE. '
            'If your file uses a different header, set this field to a mapping dict. '
            "For example, { 'Image:FILE': 'Target:FILE' } renames the field Image:FILE to Target:FILE."
        },
        'MS_REMAP_PATH': {
            'value':
            None,
            'description':
            'Root path of your local images; the default is None. Set this field when you take '
            'the data list from Modelscope (MS_DATASET_NAME is set) but read the image files '
            'from a local device. '
        },
        'TRIGGER_WORDS': {
            'value':
            '',
            'description':
            'Words that describe the common features of your data, especially when you customize a '
            'tuner. They are prepended to every prompt.'
        },
        'HIGHLIGHT_KEYWORDS': {
            'value':
            '',
            'description':
            'The keywords you want to highlight in the prompt, which will be replaced by <HIGHLIGHT_KEYWORDS>.'
        },
        'KEYWORDS_SIGN': {
            'value':
            '',
            'description':
            'The sign appended to the highlighted keywords, giving <{HIGHLIGHT_KEYWORDS}{KEYWORDS_SIGN}>.'
        },
    }

    def __init__(self, cfg, logger=None):
        super().__init__(cfg=cfg, logger=logger)
        from modelscope import MsDataset
        from modelscope.utils.constant import DownloadMode
        ms_dataset_name = cfg.get('MS_DATASET_NAME', None)
        ms_dataset_namespace = cfg.get('MS_DATASET_NAMESPACE', None)
        ms_dataset_subname = cfg.get('MS_DATASET_SUBNAME', None)
        ms_dataset_split = cfg.get('MS_DATASET_SPLIT', 'train')
        ms_remap_keys = cfg.get('MS_REMAP_KEYS', None)
        ms_remap_path = cfg.get('MS_REMAP_PATH', None)

        self.max_seq_len = cfg.get('MAX_SEQ_LEN', 1024)
        self.max_aspect_ratio = cfg.get('MAX_ASPECT_RATIO', 4)
        self.d = cfg.get('DOWNSAMPLE_RATIO', 16)
        self.replace_style = cfg.get('REPLACE_STYLE', False)
        self.trigger_words = cfg.get('TRIGGER_WORDS', '')
        self.replace_keywords = cfg.get('HIGHLIGHT_KEYWORDS', '')
        self.keywords_sign = cfg.get('KEYWORDS_SIGN', '')
        self.add_indicator = cfg.get('ADD_INDICATOR', False)
        # Use modelscope dataset
        if not ms_dataset_name:
            raise ValueError(
                'You must set MS_DATASET_NAME to a Modelscope dataset or to a local dataset '
                'organized as a Modelscope dataset.')
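        if FS.exists(ms_dataset_name):
            ms_dataset_name = FS.get_dir_to_local_dir(ms_dataset_name)
        # Always set the attribute: _get() uses it as the image path prefix,
        # so leaving it unset would raise AttributeError for non-local datasets.
        self.ms_dataset_name = ms_dataset_name
        # ms_remap_path = ms_dataset_name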
        try:
            self.data = MsDataset.load(str(ms_dataset_name),
                                       namespace=ms_dataset_namespace,
                                       subset_name=ms_dataset_subname,
                                       split=ms_dataset_split)
        except Exception:
            self.logger.info(
                "Load Modelscope dataset failed, retry with download_mode='force_redownload'."
            )
            try:
                self.data = MsDataset.load(
                    str(ms_dataset_name),
                    namespace=ms_dataset_namespace,
                    subset_name=ms_dataset_subname,
                    split=ms_dataset_split,
                    download_mode=DownloadMode.FORCE_REDOWNLOAD)
            except Exception as sec_e:
                # Chain the original exception so its traceback is preserved.
                raise ValueError(
                    f'Load Modelscope dataset failed {sec_e}.') from sec_e
        if ms_remap_keys:
            self.data = self.data.remap_columns(ms_remap_keys.get_dict())

        if ms_remap_path:

            def map_func(example):
                return {
                    k: os.path.join(ms_remap_path, v)
                    if k.endswith(':FILE') else v
                    for k, v in example.items()
                }

            self.data = self.data.ds_instance.map(map_func)

        self.transforms = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

    def __len__(self):
        if self.mode == 'train':
            return sys.maxsize
        else:
            return len(self.data)

    def _get(self, index: int):
        current_data = self.data[index % len(self.data)]

        tar_image_path = current_data.get('Target:FILE', '')
        src_image_path = current_data.get('Source:FILE', '')

        style = current_data.get('Style', '')
        prompt = current_data.get('Prompt', current_data.get('prompt', ''))
        if self.replace_style and style != '':
            prompt = prompt.replace(style, f'<{self.keywords_sign}>')

        elif self.replace_keywords.strip() != '':
            prompt = prompt.replace(
                self.replace_keywords,
                f'<{self.replace_keywords}{self.keywords_sign}>')

        if self.trigger_words != '':
            prompt = self.trigger_words.strip() + ' ' + prompt

        src_image = self.load_image(self.ms_dataset_name,
                                    src_image_path,
                                    cvt_type='RGB')
        tar_image = self.load_image(self.ms_dataset_name,
                                    tar_image_path,
                                    cvt_type='RGB')
        src_image = self.image_preprocess(src_image)
        tar_image = self.image_preprocess(tar_image)

        tar_image = self.transforms(tar_image)
        src_image = self.transforms(src_image)
        src_mask = torch.ones_like(src_image[[0]])
        tar_mask = torch.ones_like(tar_image[[0]])
        if self.add_indicator:
            if '{image}' not in prompt:
                prompt = '{image}, ' + prompt

        return {
            'edit_image': [src_image],
            'edit_image_mask': [src_mask],
            'image': tar_image,
            'image_mask': tar_mask,
            'prompt': [prompt],
        }

    def load_image(self, prefix, img_path, cvt_type=None):
        if img_path is None or img_path == '':
            return None
        img_path = os.path.join(prefix, img_path)
        with FS.get_object(img_path) as image_bytes:
            image = Image.open(io.BytesIO(image_bytes))
            if cvt_type is not None:
                image = pillow_convert(image, cvt_type)
        return image

    def image_preprocess(self,
                         img,
                         size=None,
                         interpolation=InterpolationMode.BILINEAR):
        H, W = img.height, img.width
        if H / W > self.max_aspect_ratio:
            img = T.CenterCrop((self.max_aspect_ratio * W, W))(img)
        elif W / H > self.max_aspect_ratio:
            img = T.CenterCrop((H, self.max_aspect_ratio * H))(img)

        if size is None:
            # resize the image to fit max_seq_len while keeping the aspect ratio
            H, W = img.height, img.width
            scale = min(
                1.0,
                math.sqrt(self.max_seq_len / ((H / self.d) * (W / self.d))))
            rH = int(
                H * scale) // self.d * self.d  # ensure divisible by self.d
            rW = int(W * scale) // self.d * self.d
        else:
            rH, rW = size
        img = T.Resize((rH, rW), interpolation=interpolation,
                       antialias=True)(img)
        return np.array(img, dtype=np.uint8)

    @staticmethod
    def get_config_template():
        return dict_to_yaml('DATASETS',
                            __class__.__name__,
                            ACEDemoDataset.para_dict,
                            set_name=True)

    @staticmethod
    def collate_fn(batch):
        collect = defaultdict(list)
        for sample in batch:
            for k, v in sample.items():
                collect[k].append(v)

        new_batch = dict()
        for k, v in collect.items():
            if all(i is None for i in v):
                new_batch[k] = None
            else:
                new_batch[k] = v

        return new_batch
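
The resize step in `image_preprocess` above can be checked in isolation. Below is a minimal pure-Python sketch of the same arithmetic, without PIL or torchvision. The function name `fit_to_seq_len` is hypothetical; the defaults mirror the dataset's `MAX_SEQ_LEN` (1024) and `DOWNSAMPLE_RATIO` (16) defaults.

```python
import math


def fit_to_seq_len(height, width, max_seq_len=1024, d=16):
    """Return (rH, rW) such that (rH / d) * (rW / d) <= max_seq_len.

    Mirrors the `size is None` branch of image_preprocess; the aspect-ratio
    crop that precedes it is not reproduced here.
    """
    # Never upscale: the scale factor is capped at 1.0.
    scale = min(1.0, math.sqrt(max_seq_len / ((height / d) * (width / d))))
    # Round each side down to a multiple of d.
    r_h = int(height * scale) // d * d
    r_w = int(width * scale) // d * d
    return r_h, r_w


print(fit_to_seq_len(1024, 1024))  # (512, 512)
print(fit_to_seq_len(512, 512))    # (512, 512)
```

Because both sides are rounded down, the resulting token count can fall slightly below `max_seq_len` for sizes that are not multiples of `d`, but it never exceeds it.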


================================================
FILE: modules/inference/__init__.py
================================================


================================================
FILE: modules/model/__init__.py
================================================
from . import backbone, embedder, diffusion, network

================================================
FILE: modules/model/backbone/__init__.py
================================================
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
from .ace import DiTACE


================================================
FILE: modules/model/backbone/ace.py
================================================
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import re
from collections import OrderedDict
from functools import partial

import torch
import torch.nn as nn
from einops import rearrange
from torch.nn.utils.rnn import pad_sequence
from torch.utils.checkpoint import checkpoint_sequential

from scepter.modules.model.base_model import BaseModel
from scepter.modules.model.registry import BACKBONES
from scepter.modules.utils.config import dict_to_yaml
from scepter.modules.utils.file_system import FS

from .layers import (
    Mlp,
    TimestepEmbedder,
    PatchEmbed,
    DiTACEBlock,
    T2IFinalLayer
)
from .pos_embed import rope_params


@BACKBONES.register_class()
class DiTACE(BaseModel):

    para_dict = {
        'PATCH_SIZE': {
            'value': 2,
            'description': ''
        },
        'IN_CHANNELS': {
            'value': 4,
            'description': ''
        },
        'HIDDEN_SIZE': {
            'value': 1152,
            'description': ''
        },
        'DEPTH': {
            'value': 28,
            'description': ''
        },
        'NUM_HEADS': {
            'value': 16,
            'description': ''
        },
        'MLP_RATIO': {
            'value': 4.0,
            'description': ''
        },
        'PRED_SIGMA': {
            'value': True,
            'description': ''
        },
        'DROP_PATH': {
            'value': 0.,
            'description': ''
        },
        'WINDOW_SIZE': {
            'value': 0,
            'description': ''
        },
        'WINDOW_BLOCK_INDEXES': {
            'value': None,
            'description': ''
        },
        'Y_CHANNELS': {
            'value': 4096,
            'description': ''
        },
        'ATTENTION_BACKEND': {
            'value': None,
            'description': ''
        },
        'QK_NORM': {
            'value': True,
            'description': 'Whether to use RMSNorm for query and key.',
        },
    }
    para_dict.update(BaseModel.para_dict)

    def __init__(self, cfg, logger):
        super().__init__(cfg, logger=logger)
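        # WINDOW_SIZE is declared in para_dict and read when the blocks are built below.
        self.window_size = cfg.get('WINDOW_SIZE', 0)
        self.window_block_indexes = cfg.get('WINDOW_BLOCK_INDEXES', None)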
        if self.window_block_indexes is None:
            self.window_block_indexes = []
        self.pred_sigma = cfg.get('PRED_SIGMA', True)
        self.in_channels = cfg.get('IN_CHANNELS', 4)
        self.out_channels = self.in_channels * 2 if self.pred_sigma else self.in_channels
        self.patch_size = cfg.get('PATCH_SIZE', 2)
        self.num_heads = cfg.get('NUM_HEADS', 16)
        self.hidden_size = cfg.get('HIDDEN_SIZE', 1152)
        self.y_channels = cfg.get('Y_CHANNELS', 4096)
        self.drop_path = cfg.get('DROP_PATH', 0.)
        self.depth = cfg.get('DEPTH', 28)
        self.mlp_ratio = cfg.get('MLP_RATIO', 4.0)
        self.use_grad_checkpoint = cfg.get('USE_GRAD_CHECKPOINT', False)
        self.attention_backend = cfg.get('ATTENTION_BACKEND', None)
        self.max_seq_len = cfg.get('MAX_SEQ_LEN', 1024)
        self.qk_norm = cfg.get('QK_NORM', False)
        self.ignore_keys = cfg.get('IGNORE_KEYS', [])
        assert (self.hidden_size % self.num_heads
                ) == 0 and (self.hidden_size // self.num_heads) % 2 == 0
        d = self.hidden_size // self.num_heads
        # Split the per-head dim d across three RoPE axes (T, H, W); the three sizes sum to d.
        self.freqs = torch.cat(
            [
                rope_params(self.max_seq_len, d - 4 * (d // 6)),  # T (~1/3)
                rope_params(self.max_seq_len, 2 * (d // 6)),  # H (~1/3)
                rope_params(self.max_seq_len, 2 * (d // 6))  # W (~1/3)
            ],
            dim=1)

        # init embedder
        self.x_embedder = PatchEmbed(self.patch_size,
                                     self.in_channels + 1,
                                     self.hidden_size,
                                     bias=True,
                                     flatten=False)
        self.t_embedder = TimestepEmbedder(self.hidden_size)
        self.y_embedder = Mlp(in_features=self.y_channels,
                              hidden_features=self.hidden_size,
                              out_features=self.hidden_size,
                              act_layer=lambda: nn.GELU(approximate='tanh'),
                              drop=0)
        self.t_block = nn.Sequential(
            nn.SiLU(),
            nn.Linear(self.hidden_size, 6 * self.hidden_size, bias=True))
        # init blocks
        drop_path = [
            x.item() for x in torch.linspace(0, self.drop_path, self.depth)
        ]
        self.blocks = nn.ModuleList([
            DiTACEBlock(self.hidden_size,
                        self.num_heads,
                        mlp_ratio=self.mlp_ratio,
                        drop_path=drop_path[i],
                        window_size=self.window_size
                        if i in self.window_block_indexes else 0,
                        backend=self.attention_backend,
                        use_condition=True,
                        qk_norm=self.qk_norm) for i in range(self.depth)
        ])
        self.final_layer = T2IFinalLayer(self.hidden_size, self.patch_size,
                                         self.out_channels)
        self.initialize_weights()

    def load_pretrained_model(self, pretrained_model):
        if pretrained_model:
            with FS.get_from(pretrained_model, wait_finish=True) as local_path:
                model = torch.load(local_path, map_location='cpu')
                if 'state_dict' in model:
                    model = model['state_dict']
                new_ckpt = OrderedDict()
                for k, v in model.items():
                    if self.ignore_keys is not None:
                        if (isinstance(self.ignore_keys, str) and re.match(self.ignore_keys, k)) or \
                                (isinstance(self.ignore_keys, list) and k in self.ignore_keys):
                            continue
                    k = k.replace('.cross_attn.q_linear.', '.cross_attn.q.')
                    k = k.replace('.cross_attn.proj.',
                                  '.cross_attn.o.').replace(
                                      '.attn.proj.', '.attn.o.')
                    if '.cross_attn.kv_linear.' in k:
                        k_p, v_p = torch.split(v, v.shape[0] // 2)
                        new_ckpt[k.replace('.cross_attn.kv_linear.',
                                           '.cross_attn.k.')] = k_p
                        new_ckpt[k.replace('.cross_attn.kv_linear.',
                                           '.cross_attn.v.')] = v_p
                    elif '.attn.qkv.' in k:
                        q_p, k_p, v_p = torch.split(v, v.shape[0] // 3)
                        new_ckpt[k.replace('.attn.qkv.', '.attn.q.')] = q_p
                        new_ckpt[k.replace('.attn.qkv.', '.attn.k.')] = k_p
                        new_ckpt[k.replace('.attn.qkv.', '.attn.v.')] = v_p
                    elif 'y_embedder.y_proj.' in k:
                        new_ckpt[k.replace('y_embedder.y_proj.',
                                           'y_embedder.')] = v
                    elif k == 'x_embedder.proj.weight':
                        model_p = self.state_dict()[k]
                        if v.shape != model_p.shape:
                            # Shape mismatch: copy the checkpoint weights into
                            # the first 4 input channels and zero the rest.
                            model_p.zero_()
                            model_p[:, :4, :, :].copy_(v)
                            new_ckpt[k] = torch.nn.parameter.Parameter(model_p)
                        else:
                            new_ckpt[k] = v
                    elif k == 'x_embedder.proj.bias':
                        new_ckpt[k] = v
                    else:
                        new_ckpt[k] = v
                missing, unexpected = self.load_state_dict(new_ckpt,
                                                           strict=False)
                print(
                    f'Restored from {pretrained_model} with {len(missing)} missing and {len(unexpected)} unexpected keys'
                )
                if len(missing) > 0:
                    print(f'Missing Keys:\n {missing}')
                if len(unexpected) > 0:
                    print(f'\nUnexpected Keys:\n {unexpected}')

    def forward(self,
                x,
                t=None,
                cond=dict(),
                mask=None,
                text_position_embeddings=None,
                gc_seg=-1,
                **kwargs):
        if self.freqs.device != x.device:
            self.freqs = self.freqs.to(x.device)
        if isinstance(cond, dict):
            context = cond.get('crossattn', None)
        else:
            context = cond
        if text_position_embeddings is not None:
            # Defaults to the text_position_embeddings stored in the state_dict;
            # if the state_dict does not include this key, use the
            # text_position_embeddings argument.
            proj_position_embeddings = self.y_embedder(
                text_position_embeddings)
        else:
            proj_position_embeddings = None

        ctx_batch, txt_lens = [], []
        if mask is not None and isinstance(mask, list):
            for ctx, ctx_mask in zip(context, mask):
                for frame_id, one_ctx in enumerate(zip(ctx, ctx_mask)):
                    u, m = one_ctx
                    t_len = m.flatten().sum()  # l
                    u = u[:t_len]
                    u = self.y_embedder(u)
                    if frame_id == 0:
                        u = u + proj_position_embeddings[
                            len(ctx) -
                            1] if proj_position_embeddings is not None else u
                    else:
                        u = u + proj_position_embeddings[
                            frame_id -
                            1] if proj_position_embeddings is not None else u
                    ctx_batch.append(u)
                    txt_lens.append(t_len)
        else:
            raise TypeError(
                f'mask must be a list, got {type(mask).__name__}')
        y = torch.cat(ctx_batch, dim=0)
        txt_lens = torch.LongTensor(txt_lens).to(x.device, non_blocking=True)

        batch_frames = []
        for u, shape, m in zip(x, cond['x_shapes'], cond['x_mask']):
            u = u[:, :shape[0] * shape[1]].view(-1, shape[0], shape[1])
            m = torch.ones_like(u[[0], :, :]) if m is None else m.squeeze(0)
            batch_frames.append([torch.cat([u, m], dim=0).unsqueeze(0)])
        if 'edit' in cond:
            for i, (edit, edit_mask) in enumerate(
                    zip(cond['edit'], cond['edit_mask'])):
                if edit is None:
                    continue
                for u, m in zip(edit, edit_mask):
                    u = u.squeeze(0)
                    m = torch.ones_like(
                        u[[0], :, :]) if m is None else m.squeeze(0)
                    batch_frames[i].append(
                        torch.cat([u, m], dim=0).unsqueeze(0))

        patch_batch, shape_batch, self_x_len, cross_x_len = [], [], [], []
        for frames in batch_frames:
            patches, patch_shapes = [], []
            self_x_len.append(0)
            for frame_id, u in enumerate(frames):
                u = self.x_embedder(u)
                h, w = u.size(2), u.size(3)
                u = rearrange(u, '1 c h w -> (h w) c')
                if frame_id == 0:
                    u = u + proj_position_embeddings[
                        len(frames) -
                        1] if proj_position_embeddings is not None else u
                else:
                    u = u + proj_position_embeddings[
                        frame_id -
                        1] if proj_position_embeddings is not None else u
                patches.append(u)
                patch_shapes.append([h, w])
                cross_x_len.append(h * w)  # b*s, 1
                self_x_len[-1] += h * w  # b, 1
            # u = torch.cat(patches, dim=0)
            patch_batch.extend(patches)
            shape_batch.append(
                torch.LongTensor(patch_shapes).to(x.device, non_blocking=True))
        # repeat t to align with x
        t = torch.cat([t[i].repeat(l) for i, l in enumerate(self_x_len)])
        self_x_len, cross_x_len = (torch.LongTensor(self_x_len).to(
            x.device, non_blocking=True), torch.LongTensor(cross_x_len).to(
                x.device, non_blocking=True))
        # x = pad_sequence(tuple(patch_batch), batch_first=True)  # b, s*max(cl), c
        x = torch.cat(patch_batch, dim=0)
        x_shapes = pad_sequence(tuple(shape_batch),
                                batch_first=True)  # b, max(len(frames)), 2
        t = self.t_embedder(t)  # (N, D)
        t0 = self.t_block(t)
        # y = self.y_embedder(context)

        kwargs = dict(y=y,
                      t=t0,
                      x_shapes=x_shapes,
                      self_x_len=self_x_len,
                      cross_x_len=cross_x_len,
                      freqs=self.freqs,
                      txt_lens=txt_lens)
        if self.use_grad_checkpoint and gc_seg >= 0:
            x = checkpoint_sequential(
                functions=[partial(block, **kwargs) for block in self.blocks],
                segments=gc_seg if gc_seg > 0 else len(self.blocks),
                input=x,
                use_reentrant=False)
        else:
            for block in self.blocks:
                x = block(x, **kwargs)
        x = self.final_layer(x, t)  # b*s*n, d
        outs, cur_length = [], 0
        p = self.patch_size
        for seq_length, shape in zip(self_x_len, shape_batch):
            x_i = x[cur_length:cur_length + seq_length]
            h, w = shape[0].tolist()
            u = x_i[:h * w].view(h, w, p, p, -1)
            u = rearrange(u, 'h w p q c -> (h p w q) c'
                          )  # dump into sequence for following tensor ops
            cur_length = cur_length + seq_length
            outs.append(u)
        x = pad_sequence(tuple(outs), batch_first=True).permute(0, 2, 1)
        if self.pred_sigma:
            return x.chunk(2, dim=1)[0]
        else:
            return x

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)
        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        nn.init.normal_(self.t_block[1].weight, std=0.02)
        # Initialize caption embedding MLP:
        if hasattr(self, 'y_embedder'):
            nn.init.normal_(self.y_embedder.fc1.weight, std=0.02)
            nn.init.normal_(self.y_embedder.fc2.weight, std=0.02)
        # Zero-out the cross-attention output projections:
        for block in self.blocks:
            nn.init.constant_(block.cross_attn.o.weight, 0)
            nn.init.constant_(block.cross_attn.o.bias, 0)
        # Zero-out output layers:
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    @property
    def dtype(self):
        return next(self.parameters()).dtype

    @staticmethod
    def get_config_template():
        return dict_to_yaml('BACKBONE',
                            __class__.__name__,
                            DiTACE.para_dict,
                            set_name=True)
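
# Illustrative sketch (not part of the repository): load_pretrained_model above
# renames checkpoint keys and splits fused attention weights (`kv_linear`,
# `qkv`) into separate q/k/v entries along the first dimension. The helper name
# `remap_fused_attention_keys` and the demo keys are hypothetical, and plain
# Python lists stand in for tensors so the idea can be tried without torch.

```python
from collections import OrderedDict


def remap_fused_attention_keys(state):
    """Split fused attention entries into one entry per projection.

    `state` maps a checkpoint key to a list of rows; the list is a stand-in
    for a tensor whose first dimension stacks the fused projections.
    """
    out = OrderedDict()
    for key, rows in state.items():
        if '.cross_attn.kv_linear.' in key:
            marker, prefix, names = '.cross_attn.kv_linear.', '.cross_attn.', ('k', 'v')
        elif '.attn.qkv.' in key:
            marker, prefix, names = '.attn.qkv.', '.attn.', ('q', 'k', 'v')
        else:
            # Keys that are not fused pass through unchanged.
            out[key] = rows
            continue
        chunk = len(rows) // len(names)  # equal-sized slice per projection
        for i, name in enumerate(names):
            new_key = key.replace(marker, f'{prefix}{name}.')
            out[new_key] = rows[i * chunk:(i + 1) * chunk]
    return out


demo = {
    'blocks.0.attn.qkv.weight': [[1], [2], [3], [4], [5], [6]],
    'blocks.0.cross_attn.kv_linear.bias': [10, 20, 30, 40],
    'final_layer.linear.bias': [0],
}
remapped = remap_fused_attention_keys(demo)
```

# With real tensors, the slicing step would correspond to
# torch.split(v, v.shape[0] // 3) as used in load_pretrained_model.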

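# Illustrative sketch (not part of the repository): DiTACE.forward concatenates
# variable-length patch sequences into one flat tensor, repeats each sample's
# timestep once per token, and later recovers each sample by walking a running
# offset over the recorded lengths. The function names `pack`, `unpack` and
# `repeat_per_token` are hypothetical, and lists stand in for tensors.

```python
def pack(sequences):
    """Concatenate variable-length sequences and record each length."""
    flat, lengths = [], []
    for seq in sequences:
        flat.extend(seq)
        lengths.append(len(seq))
    return flat, lengths


def unpack(flat, lengths):
    """Recover the original sequences by advancing a running offset."""
    outs, cur = [], 0
    for n in lengths:
        outs.append(flat[cur:cur + n])
        cur += n
    return outs


def repeat_per_token(values, lengths):
    """Repeat one per-sample value for every token of that sample."""
    return [v for v, n in zip(values, lengths) for _ in range(n)]


samples = [['a0', 'a1', 'a2'], ['b0'], ['c0', 'c1']]
flat, lengths = pack(samples)
timesteps = repeat_per_token([10, 20, 30], lengths)
restored = unpack(flat, lengths)
```

# Packing avoids padding every sample to the longest sequence; the recorded
# lengths are what allow the flat output to be sliced back per sample.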

================================================
FILE: modules/model/backbone/layers.py
================================================
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import warnings
import torch
import torch.nn as nn
from .pos_embed import rope_apply_multires as rope_apply

try:
    from flash_attn import (flash_attn_varlen_func)
    FLASHATTN_IS_AVAILABLE = True
except ImportError as e:
    FLASHATTN_IS_AVAILABLE = False
    flash_attn_varlen_func = None
    warnings.warn(f'{e}')

__all__ = [
    "drop_path",
    "modulate",
    "PatchEmbed",
    "DropPath",
    "RMSNorm",
    "Mlp",
    "TimestepEmbedder",
    "DiTACEBlock",
    "MultiHeadAttention",
    "T2IFinalLayer",
]

def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape 
SYMBOL INDEX (152 symbols across 18 files)

FILE: chatbot/ace_inference.py
  function process_edit_image (line 26) | def process_edit_image(images,
  class TextEmbedding (line 83) | class TextEmbedding(nn.Module):
    method __init__ (line 84) | def __init__(self, embedding_shape):
  class RefinerInference (line 88) | class RefinerInference(DiffusionInference):
    method init_from_cfg (line 89) | def init_from_cfg(self, cfg):
    method encode_first_stage (line 101) | def encode_first_stage(self, x, **kwargs):
    method upscale_resize (line 113) | def upscale_resize(self, image, interpolation=T.InterpolationMode.BILI...
    method decode_first_stage (line 121) | def decode_first_stage(self, z):
    method noise_sample (line 128) | def noise_sample(self, num_samples, h, w, seed, device = None, dtype =...
    method refine (line 140) | def refine(self,
  class ACEInference (line 221) | class ACEInference(DiffusionInference):
    method __init__ (line 222) | def __init__(self, logger=None):
    method init_from_cfg (line 231) | def init_from_cfg(self, cfg):
    method encode_first_stage (line 292) | def encode_first_stage(self, x, **kwargs):
    method decode_first_stage (line 304) | def decode_first_stage(self, z):
    method __call__ (line 319) | def __call__(self,
    method cond_stage_embeddings (line 534) | def cond_stage_embeddings(self, prompt, edit_image, cont, cont_mask):

FILE: chatbot/example.py
  function download_image (line 9) | def download_image(image, local_path=None):
  function blank_image (line 14) | def blank_image():
  function get_examples (line 19) | def get_examples(cache_dir):

FILE: chatbot/infer.py
  function process_edit_image (line 26) | def process_edit_image(images,
  class TextEmbedding (line 83) | class TextEmbedding(nn.Module):
    method __init__ (line 84) | def __init__(self, embedding_shape):
  class ACEInference (line 89) | class ACEInference(DiffusionInference):
    method __init__ (line 90) | def __init__(self, logger=None):
    method init_from_cfg (line 99) | def init_from_cfg(self, cfg):
    method encode_first_stage (line 146) | def encode_first_stage(self, x, **kwargs):
    method decode_first_stage (line 158) | def decode_first_stage(self, z):
    method __call__ (line 171) | def __call__(self,
    method cond_stage_embeddings (line 346) | def cond_stage_embeddings(self, prompt, edit_image, cont, cont_mask):

FILE: chatbot/run_gradio.py
  class ChatBotUI (line 47) | class ChatBotUI(object):
    method __init__ (line 48) | def __init__(self,
    method create_ui (line 177) | def create_ui(self):
    method set_callbacks (line 523) | def set_callbacks(self, *args, **kwargs):
    method get_history (line 1274) | def get_history(self, history):
    method generate_random_string (line 1289) | def generate_random_string(self, length=20):
    method add_edited_image_to_history (line 1295) | def add_edited_image_to_history(self, image, mask_type, history, images):
    method add_uploaded_image_to_history (line 1356) | def add_uploaded_image_to_history(self, img, history, images):
  function run_gr (line 1401) | def run_gr(cfg):

FILE: chatbot/utils.py
  function build_transform (line 12) | def build_transform(input_size):
  function find_closest_aspect_ratio (line 24) | def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
  function dynamic_preprocess (line 41) | def dynamic_preprocess(image,
  function load_image (line 83) | def load_image(image_file, input_size=448, max_num=12):

FILE: modules/data/dataset/dataset.py
  class ACEDemoDataset (line 25) | class ACEDemoDataset(BaseDataset):
    method __init__ (line 81) | def __init__(self, cfg, logger=None):
    method __len__ (line 146) | def __len__(self):
    method _get (line 152) | def _get(self, index: int):
    method load_image (line 196) | def load_image(self, prefix, img_path, cvt_type=None):
    method image_preprocess (line 206) | def image_preprocess(self,
    method get_config_template (line 232) | def get_config_template():
    method collate_fn (line 239) | def collate_fn(batch):

FILE: modules/model/backbone/ace.py
  class DiTACE (line 29) | class DiTACE(BaseModel):
    method __init__ (line 87) | def __init__(self, cfg, logger):
    method load_pretrained_model (line 152) | def load_pretrained_model(self, pretrained_model):
    method forward (line 204) | def forward(self,
    method initialize_weights (line 336) | def initialize_weights(self):
    method dtype (line 365) | def dtype(self):
    method get_config_template (line 369) | def get_config_template():

FILE: modules/model/backbone/layers.py
  function drop_path (line 30) | def drop_path(x, drop_prob: float = 0., training: bool = False):
  function modulate (line 50) | def modulate(x, shift, scale, unsqueeze=False):
  class PatchEmbed (line 57) | class PatchEmbed(nn.Module):
    method __init__ (line 60) | def __init__(
    method forward (line 78) | def forward(self, x):
  class DropPath (line 86) | class DropPath(nn.Module):
    method __init__ (line 89) | def __init__(self, drop_prob=None):
    method forward (line 93) | def forward(self, x):
  class RMSNorm (line 97) | class RMSNorm(nn.Module):
    method __init__ (line 98) | def __init__(self, dim, eps=1e-6):
    method forward (line 104) | def forward(self, x):
    method _norm (line 107) | def _norm(self, x):
  class Mlp (line 111) | class Mlp(nn.Module):
    method __init__ (line 114) | def __init__(self,
    method forward (line 128) | def forward(self, x):
  class TimestepEmbedder (line 137) | class TimestepEmbedder(nn.Module):
    method __init__ (line 141) | def __init__(self, hidden_size, frequency_embedding_size=256):
    method timestep_embedding (line 151) | def timestep_embedding(t, dim, max_period=10000):
    method forward (line 173) | def forward(self, t):
  class DiTACEBlock (line 179) | class DiTACEBlock(nn.Module):
    method __init__ (line 180) | def __init__(self,
    method forward (line 226) | def forward(self, x, y, t, **kwargs):
  class MultiHeadAttention (line 244) | class MultiHeadAttention(nn.Module):
    method __init__ (line 245) | def __init__(self,
    method flash_attn (line 291) | def flash_attn(self, x, context=None, **kwargs):
    method forward (line 359) | def forward(self, x, context=None, **kwargs):
  class T2IFinalLayer (line 364) | class T2IFinalLayer(nn.Module):
    method __init__ (line 368) | def __init__(self, hidden_size, patch_size, out_channels):
    method forward (line 380) | def forward(self, x, t):

FILE: modules/model/backbone/pos_embed.py
  function frame_pad (line 9) | def frame_pad(x, seq_len, shapes):
  function frame_unpad (line 26) | def frame_unpad(x, shapes):
  function rope_apply_multires (line 38) | def rope_apply_multires(x, x_lens, x_shapes, freqs, pad=True):
  function rope_params (line 75) | def rope_params(max_seq_len, dim, theta=10000):

FILE: modules/model/diffusion/diffusions.py
  class ACEDiffusion (line 18) | class ACEDiffusion(object):
    method __init__ (line 29) | def __init__(self, cfg, logger=None):
    method init_params (line 35) | def init_params(self):
    method sample (line 53) | def sample(self,
    method loss (line 133) | def loss(self,
    method get_sampler (line 159) | def get_sampler(self, sampler):
    method __repr__ (line 185) | def __repr__(self) -> str:
    method get_config_template (line 189) | def get_config_template():

FILE: modules/model/diffusion/samplers.py
  function _i (line 9) | def _i(tensor, t, x):
  class DDIMSampler (line 20) | class DDIMSampler(BaseDiffusionSampler):
    method init_params (line 21) | def init_params(self):
    method preprare_sampler (line 27) | def preprare_sampler(self,
    method step (line 60) | def step(self, sampler_output):

FILE: modules/model/diffusion/schedules.py
  class ScheduleOutput (line 10) | class ScheduleOutput(object):
    method add_custom_field (line 18) | def add_custom_field(self, key: str, value) -> None:
  class LinearScheduler (line 23) | class LinearScheduler(BaseNoiseScheduler):
    method init_params (line 26) | def init_params(self):
    method betas_to_sigmas (line 31) | def betas_to_sigmas(self, betas):
    method get_schedule (line 34) | def get_schedule(self):
    method add_noise (line 46) | def add_noise(self, x_0, noise=None, t=None, **kwargs):

FILE: modules/model/embedder/embedder.py
  class ACETextEmbedder (line 25) | class ACETextEmbedder(BaseEmbedder):
    method __init__ (line 70) | def __init__(self, cfg, logger=None):
    method freeze (line 112) | def freeze(self):
    method forward (line 118) | def forward(self, tokens, return_mask=False, use_mask=True):
    method _clean (line 134) | def _clean(self, text):
    method encode (line 145) | def encode(self, text, return_mask=False, use_mask=True):
    method encode_list (line 167) | def encode_list(self, text_list, return_mask=True):
    method get_config_template (line 180) | def get_config_template():

FILE: modules/model/network/ldm_ace.py
  class TextEmbedding (line 23) | class TextEmbedding(nn.Module):
    method __init__ (line 24) | def __init__(self, embedding_shape):
  class LdmACE (line 30) | class LdmACE(LatentDiffusion):
    method __init__ (line 34) | def __init__(self, cfg, logger=None):
    method encode_first_stage (line 53) | def encode_first_stage(self, x, **kwargs):
    method decode_first_stage (line 61) | def decode_first_stage(self, z):
    method cond_stage_embeddings (line 67) | def cond_stage_embeddings(self, prompt, edit_image, cont, cont_mask):
    method limit_batch_data (line 85) | def limit_batch_data(self, batch_data_list, log_num):
    method forward_train (line 96) | def forward_train(self,
    method forward_test (line 186) | def forward_test(self,
    method get_config_template (line 348) | def get_config_template():

FILE: modules/model/utils/basic_utils.py
  function exists (line 11) | def exists(x):
  function default (line 15) | def default(val, d):
  function disabled_train (line 21) | def disabled_train(self, mode=True):
  function transfer_size (line 27) | def transfer_size(para_num):
  function count_params (line 44) | def count_params(model):
  function expand_dims_like (line 49) | def expand_dims_like(x, y):
  function unpack_tensor_into_imagelist (line 55) | def unpack_tensor_into_imagelist(image_tensor, shapes):
  function find_example (line 64) | def find_example(tensor_list, image_list):
  function pack_imagelist_into_tensor_v2 (line 75) | def pack_imagelist_into_tensor_v2(image_list):
  function to_device (line 95) | def to_device(inputs, strict=True):
  function check_list_of_list (line 103) | def check_list_of_list(ll):

FILE: modules/solver/ace_solver.py
  class ACESolverV1 (line 16) | class ACESolverV1(LatentDiffusionSolver):
    method __init__ (line 17) | def __init__(self, cfg, logger=None):
    method save_results (line 21) | def save_results(self, results):
    method run_eval (line 65) | def run_eval(self):
    method run_test (line 95) | def run_test(self):
    method probe_data (line 126) | def probe_data(self):

FILE: tools/run_inference.py
  function run_one_case (line 31) | def run_one_case(pipe, input_image, input_mask, edit_k,
  function run (line 63) | def run():

FILE: tools/run_train.py
  function run_task (line 21) | def run_task(cfg):
  function update_config (line 29) | def update_config(cfg):
  function run (line 45) | def run():
Condensed preview: 42 files, each showing path, character count, and a content snippet (full structured content: 286K chars).
[
  {
    "path": ".gitignore",
    "chars": 179,
    "preview": "*.pyc\n*.pth\n*.pt\n*.pkl\n*.ckpt\n*.DS_Store\n*__pycache__*\n*.cache*\n*.bin\n*.idea\n*.csv\ncache\nbuild\ndist\ndev\nscepter.egg-info"
  },
  {
    "path": "LICENSE",
    "chars": 11357,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "__init__.py",
    "chars": 43,
    "preview": "from . import modules\nfrom . import chatbot"
  },
  {
    "path": "chatbot/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "chatbot/ace_inference.py",
    "chars": 24539,
    "preview": "# -*- coding: utf-8 -*-\n# Copyright (c) Alibaba, Inc. and its affiliates.\nimport copy\nimport math\nimport random\n\nimport "
  },
  {
    "path": "chatbot/config/chatbot_ui.yaml",
    "chars": 887,
    "preview": "WORK_DIR: ./cache/chatbot\nFILE_SYSTEM:\n  - NAME: \"HuggingfaceFs\"\n    TEMP_DIR: ./cache\n  - NAME: \"ModelscopeFs\"\n    TEMP"
  },
  {
    "path": "chatbot/config/models/ace_0.6b_512.yaml",
    "chars": 2961,
    "preview": "NAME: ACE_0.6B_512\nIS_DEFAULT: False\nDEFAULT_PARAS:\n  PARAS:\n  #\n  INPUT:\n    INPUT_IMAGE:\n    INPUT_MASK:\n    TASK:\n   "
  },
  {
    "path": "chatbot/example.py",
    "chars": 19638,
    "preview": "# -*- coding: utf-8 -*-\n# Copyright (c) Alibaba, Inc. and its affiliates.\nimport os\n\nfrom scepter.modules.utils.file_sys"
  },
  {
    "path": "chatbot/infer.py",
    "chars": 14660,
    "preview": "# -*- coding: utf-8 -*-\n# Copyright (c) Alibaba, Inc. and its affiliates.\nimport copy\nimport math\nimport random\nimport n"
  },
  {
    "path": "chatbot/run_gradio.py",
    "chars": 68798,
    "preview": "# -*- coding: utf-8 -*-\n# Copyright (c) Alibaba, Inc. and its affiliates.\nimport argparse\nimport base64\nimport copy\nimpo"
  },
  {
    "path": "chatbot/utils.py",
    "chars": 3702,
    "preview": "# -*- coding: utf-8 -*-\n# Copyright (c) Alibaba, Inc. and its affiliates.\nimport torch\nimport torchvision.transforms as "
  },
  {
    "path": "config/inference_config/chatbot_ui.yaml",
    "chars": 917,
    "preview": "WORK_DIR: ./cache/chatbot\nFILE_SYSTEM:\n  - NAME: \"HuggingfaceFs\"\n    TEMP_DIR: ./cache\n  - NAME: \"ModelscopeFs\"\n    TEMP"
  },
  {
    "path": "config/inference_config/models/ace_0.6b_1024.yaml",
    "chars": 2994,
    "preview": "NAME: ACE_0.6B_1024\nIS_DEFAULT: False\nUSE_DYNAMIC_MODEL: False\nDEFAULT_PARAS:\n  PARAS:\n  #\n  INPUT:\n    INPUT_IMAGE:\n   "
  },
  {
    "path": "config/inference_config/models/ace_0.6b_1024_refiner.yaml",
    "chars": 8606,
    "preview": "NAME: ACE_0.6B_1024_REFINER\nIS_DEFAULT: False\nUSE_DYNAMIC_MODEL: False\nDEFAULT_PARAS:\n  PARAS:\n  #\n  INPUT:\n    INPUT_IM"
  },
  {
    "path": "config/inference_config/models/ace_0.6b_512.yaml",
    "chars": 2985,
    "preview": "NAME: ACE_0.6B_512\nIS_DEFAULT: True\nUSE_DYNAMIC_MODEL: False\nDEFAULT_PARAS:\n  PARAS:\n  #\n  INPUT:\n    INPUT_IMAGE:\n    I"
  },
  {
    "path": "config/train_config/ace_0.6b_1024_train.yaml",
    "chars": 3727,
    "preview": "ENV:\n  BACKEND: nccl\n  SEED: 2024\n#\nSOLVER:\n  NAME: ACESolverV1\n  RESUME_FROM:\n  LOAD_MODEL_ONLY: True\n  USE_FSDP: False"
  },
  {
    "path": "config/train_config/ace_0.6b_512_train.yaml",
    "chars": 3727,
    "preview": "ENV:\n  BACKEND: nccl\n  SEED: 2024\n#\nSOLVER:\n  NAME: ACESolverV1\n  RESUME_FROM:\n  LOAD_MODEL_ONLY: True\n  USE_FSDP: False"
  },
  {
    "path": "modules/__init__.py",
    "chars": 33,
    "preview": "from . import data, model, solver"
  },
  {
    "path": "modules/data/__init__.py",
    "chars": 21,
    "preview": "from . import dataset"
  },
  {
    "path": "modules/data/dataset/__init__.py",
    "chars": 35,
    "preview": "from .dataset import ACEDemoDataset"
  },
  {
    "path": "modules/data/dataset/dataset.py",
    "chars": 9569,
    "preview": "# -*- coding: utf-8 -*-\n# Copyright (c) Alibaba, Inc. and its affiliates.\n\nimport io\nimport math\nimport os\nimport sys\nfr"
  },
  {
    "path": "modules/inference/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "modules/model/__init__.py",
    "chars": 52,
    "preview": "from . import backbone, embedder, diffusion, network"
  },
  {
    "path": "modules/model/backbone/__init__.py",
    "chars": 98,
    "preview": "# -*- coding: utf-8 -*-\n# Copyright (c) Alibaba, Inc. and its affiliates.\nfrom .ace import DiTACE\n"
  },
  {
    "path": "modules/model/backbone/ace.py",
    "chars": 15610,
    "preview": "# -*- coding: utf-8 -*-\n# Copyright (c) Alibaba, Inc. and its affiliates.\nimport re\nfrom collections import OrderedDict\n"
  },
  {
    "path": "modules/model/backbone/layers.py",
    "chars": 14218,
    "preview": "# -*- coding: utf-8 -*-\n# Copyright (c) Alibaba, Inc. and its affiliates.\nimport math\nimport warnings\nimport torch\nimpor"
  },
  {
    "path": "modules/model/backbone/pos_embed.py",
    "chars": 2724,
    "preview": "import numpy as np\nfrom einops import rearrange\n\nimport torch\nimport torch.cuda.amp as amp\nimport torch.nn.functional as"
  },
  {
    "path": "modules/model/diffusion/__init__.py",
    "chars": 184,
    "preview": "# -*- coding: utf-8 -*-\n# Copyright (c) Alibaba, Inc. and its affiliates.\n\nfrom .diffusions import ACEDiffusion\nfrom .sa"
  },
  {
    "path": "modules/model/diffusion/diffusions.py",
    "chars": 7872,
    "preview": "# -*- coding: utf-8 -*-\n# Copyright (c) Alibaba, Inc. and its affiliates.\nimport math\nimport os\nfrom collections import "
  },
  {
    "path": "modules/model/diffusion/samplers.py",
    "chars": 3523,
    "preview": "# -*- coding: utf-8 -*-\n# Copyright (c) Alibaba, Inc. and its affiliates.\nimport torch\n\nfrom scepter.modules.model.regis"
  },
  {
    "path": "modules/model/diffusion/schedules.py",
    "chars": 1939,
    "preview": "# -*- coding: utf-8 -*-\n# Copyright (c) Alibaba, Inc. and its affiliates.\nimport torch\nfrom dataclasses import dataclass"
  },
  {
    "path": "modules/model/embedder/__init__.py",
    "chars": 37,
    "preview": "from .embedder import ACETextEmbedder"
  },
  {
    "path": "modules/model/embedder/embedder.py",
    "chars": 6754,
    "preview": "# -*- coding: utf-8 -*-\n# Copyright (c) Alibaba, Inc. and its affiliates.\nimport warnings\nfrom contextlib import nullcon"
  },
  {
    "path": "modules/model/network/__init__.py",
    "chars": 27,
    "preview": "from .ldm_ace import LdmACE"
  },
  {
    "path": "modules/model/network/ldm_ace.py",
    "chars": 14723,
    "preview": "# -*- coding: utf-8 -*-\n# Copyright (c) Alibaba, Inc. and its affiliates.\nimport copy\nimport random\nfrom contextlib impo"
  },
  {
    "path": "modules/model/utils/basic_utils.py",
    "chars": 2885,
    "preview": "# -*- coding: utf-8 -*-\n# Copyright (c) Alibaba, Inc. and its affiliates.\nfrom inspect import isfunction\n\nimport torch\nf"
  },
  {
    "path": "modules/solver/__init__.py",
    "chars": 35,
    "preview": "from .ace_solver import ACESolverV1"
  },
  {
    "path": "modules/solver/ace_solver.py",
    "chars": 6257,
    "preview": "# -*- coding: utf-8 -*-\n# Copyright (c) Alibaba, Inc. and its affiliates.\nimport numpy as np\nimport torch\nfrom tqdm impo"
  },
  {
    "path": "readme.md",
    "chars": 10710,
    "preview": "<p align=\"center\">\n\n  <h2 align=\"center\"><img src=\"assets/figures/icon.png\" height=16> : All-round Creator and Editor Fo"
  },
  {
    "path": "requirements.txt",
    "chars": 330,
    "preview": "git+https://github.com/modelscope/scepter.git@v1.3.0_dev#egg=scepter\npycocotools\npyyaml>=5.3.1\nscikit-image\ntorchsde\ntra"
  },
  {
    "path": "tools/run_inference.py",
    "chars": 6495,
    "preview": "# -*- coding: utf-8 -*-\n# Copyright (c) Alibaba, Inc. and its affiliates.\nimport argparse\nimport importlib\nimport io\nimp"
  },
  {
    "path": "tools/run_train.py",
    "chars": 2114,
    "preview": "# -*- coding: utf-8 -*-\n# Copyright (c) Alibaba, Inc. and its affiliates.\nimport argparse\nimport importlib\nimport os\nimp"
  }
]

About this extraction

This page contains the full source code of the ali-vilab/ACE GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 42 files (269.5 KB), approximately 62.7k tokens, and a symbol index with 152 extracted functions, classes, methods, constants, and types.
