[
  {
    "path": ".gitignore",
    "content": "*.pyc\n*.pth\n*.pt\n*.pkl\n*.ckpt\n*.DS_Store\n*__pycache__*\n*.cache*\n*.bin\n*.idea\n*.csv\ncache\nbuild\ndist\ndev\nscepter.egg-info\n.readthedocs.yml\n*resources\n*.ipynb_checkpoints*\n*.vscode\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. 
For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. 
Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. (Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "__init__.py",
    "content": "from . import modules\nfrom . import chatbot"
  },
  {
    "path": "chatbot/__init__.py",
    "content": ""
  },
  {
    "path": "chatbot/ace_inference.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (c) Alibaba, Inc. and its affiliates.\nimport copy\nimport math\nimport random\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchvision.transforms.functional as TF\nfrom PIL import Image\nimport torchvision.transforms as T\nfrom scepter.modules.model.registry import DIFFUSIONS\nfrom scepter.modules.model.utils.basic_utils import check_list_of_list\nfrom scepter.modules.model.utils.basic_utils import \\\n    pack_imagelist_into_tensor_v2 as pack_imagelist_into_tensor\nfrom scepter.modules.model.utils.basic_utils import (\n    to_device, unpack_tensor_into_imagelist)\nfrom scepter.modules.utils.distribute import we\nfrom scepter.modules.utils.logger import get_logger\n\nfrom scepter.modules.inference.diffusion_inference import DiffusionInference, get_model\n\n\ndef process_edit_image(images,\n                       masks,\n                       tasks,\n                       max_seq_len=1024,\n                       max_aspect_ratio=4,\n                       d=16,\n                       **kwargs):\n\n    if not isinstance(images, list):\n        images = [images]\n    if not isinstance(masks, list):\n        masks = [masks]\n    if not isinstance(tasks, list):\n        tasks = [tasks]\n\n    img_tensors = []\n    mask_tensors = []\n    for img, mask, task in zip(images, masks, tasks):\n        if mask is None or mask == '':\n            mask = Image.new('L', img.size, 0)\n        W, H = img.size\n        if H / W > max_aspect_ratio:\n            img = TF.center_crop(img, [int(max_aspect_ratio * W), W])\n            mask = TF.center_crop(mask, [int(max_aspect_ratio * W), W])\n        elif W / H > max_aspect_ratio:\n            img = TF.center_crop(img, [H, int(max_aspect_ratio * H)])\n            mask = TF.center_crop(mask, [H, int(max_aspect_ratio * H)])\n\n        H, W = img.height, img.width\n        scale = min(1.0, math.sqrt(max_seq_len / ((H / d) * (W / d))))\n        rH = int(H * scale) // d * d  # ensure divisible by self.d\n        rW = int(W * scale) // d * d\n\n        img = TF.resize(img, (rH, rW),\n                        interpolation=TF.InterpolationMode.BICUBIC)\n        mask = TF.resize(mask, (rH, rW),\n                         interpolation=TF.InterpolationMode.NEAREST_EXACT)\n\n        mask = np.asarray(mask)\n        mask = np.where(mask > 128, 1, 0)\n        mask = mask.astype(\n            np.float32) if np.any(mask) else np.ones_like(mask).astype(\n                np.float32)\n\n        img_tensor = TF.to_tensor(img).to(we.device_id)\n        img_tensor = TF.normalize(img_tensor,\n                                  mean=[0.5, 0.5, 0.5],\n                                  std=[0.5, 0.5, 0.5])\n        mask_tensor = TF.to_tensor(mask).to(we.device_id)\n        if task in ['inpainting', 'Try On', 'Inpainting']:\n            mask_indicator = mask_tensor.repeat(3, 1, 1)\n            img_tensor[mask_indicator == 1] = -1.0\n        img_tensors.append(img_tensor)\n        mask_tensors.append(mask_tensor)\n    return img_tensors, mask_tensors\n\n\nclass TextEmbedding(nn.Module):\n    def __init__(self, embedding_shape):\n        super().__init__()\n        self.pos = nn.Parameter(data=torch.zeros(embedding_shape))\n\nclass RefinerInference(DiffusionInference):\n    def init_from_cfg(self, cfg):\n        self.use_dynamic_model = cfg.get('USE_DYNAMIC_MODEL', True)\n        super().init_from_cfg(cfg)\n        self.diffusion = DIFFUSIONS.build(cfg.MODEL.DIFFUSION, 
logger=self.logger) \\\n            if cfg.MODEL.have('DIFFUSION') else None\n        self.max_seq_length = cfg.MODEL.get(\"MAX_SEQ_LENGTH\", 4096)\n        assert self.diffusion is not None\n        if not self.use_dynamic_model:\n            self.dynamic_load(self.first_stage_model, 'first_stage_model')\n            self.dynamic_load(self.cond_stage_model, 'cond_stage_model')\n            self.dynamic_load(self.diffusion_model, 'diffusion_model')\n    @torch.no_grad()\n    def encode_first_stage(self, x, **kwargs):\n        _, dtype = self.get_function_info(self.first_stage_model, 'encode')\n        with torch.autocast('cuda',\n                            enabled=dtype in ('float16', 'bfloat16'),\n                            dtype=getattr(torch, dtype)):\n            def run_one_image(u):\n                zu = get_model(self.first_stage_model).encode(u)\n                if isinstance(zu, (tuple, list)):\n                    zu = zu[0]\n                return zu\n            z = [run_one_image(u.unsqueeze(0) if u.dim() == 3 else u) for u in x]\n            return z\n    def upscale_resize(self, image, interpolation=T.InterpolationMode.BILINEAR):\n        c, H, W = image.shape\n        scale = max(1.0, math.sqrt(self.max_seq_length / ((H / 16) * (W / 16))))\n        rH = int(H * scale) // 16 * 16  # ensure divisible by 16\n        rW = int(W * scale) // 16 * 16\n        image = T.Resize((rH, rW), interpolation=interpolation, antialias=True)(image)\n        return image\n    @torch.no_grad()\n    def decode_first_stage(self, z):\n        _, dtype = self.get_function_info(self.first_stage_model, 'decode')\n        with torch.autocast('cuda',\n                            enabled=dtype in ('float16', 'bfloat16'),\n                            dtype=getattr(torch, dtype)):\n            return [get_model(self.first_stage_model).decode(zu) for zu in z]\n\n    def noise_sample(self, num_samples, h, w, seed, device = None, dtype = torch.bfloat16):\n        noise = torch.randn(\n            num_samples,\n            16,\n            # allow for packing\n            2 * math.ceil(h / 16),\n            2 * math.ceil(w / 16),\n            device=device,\n            dtype=dtype,\n            generator=torch.Generator(device=device).manual_seed(seed),\n        )\n        return noise\n    def refine(self,\n               x_samples=None,\n               prompt=None,\n               reverse_scale=-1.,\n               seed = 2024,\n               **kwargs\n               ):\n        print(prompt)\n        value_input = copy.deepcopy(self.input)\n        x_samples = [self.upscale_resize(x) for x in x_samples]\n\n        noise = []\n        for i, x in enumerate(x_samples):\n            noise_ = self.noise_sample(1, x.shape[1],\n                                       x.shape[2], seed,\n                                       device = x.device)\n            noise.append(noise_)\n        noise, x_shapes = pack_imagelist_into_tensor(noise)\n        if reverse_scale > 0:\n            self.dynamic_load(self.first_stage_model, 'first_stage_model')\n            x_samples = [x.unsqueeze(0) for x in x_samples]\n            x_start = self.encode_first_stage(x_samples, **kwargs)\n            self.dynamic_unload(self.first_stage_model,\n                                'first_stage_model',\n                                skip_loaded=not self.use_dynamic_model)\n            x_start, _ = pack_imagelist_into_tensor(x_start)\n        else:\n            x_start = None\n        # cond stage\n        
self.dynamic_load(self.cond_stage_model, 'cond_stage_model')\n        function_name, dtype = self.get_function_info(self.cond_stage_model)\n        with torch.autocast('cuda',\n                            enabled=dtype == 'float16',\n                            dtype=getattr(torch, dtype)):\n            ctx = getattr(get_model(self.cond_stage_model),\n                          function_name)(prompt)\n            ctx[\"x_shapes\"] = x_shapes\n        self.dynamic_unload(self.cond_stage_model,\n                            'cond_stage_model',\n                            skip_loaded=not self.use_dynamic_model)\n\n\n        self.dynamic_load(self.diffusion_model, 'diffusion_model')\n        # UNet use input n_prompt\n        function_name, dtype = self.get_function_info(\n            self.diffusion_model)\n        with torch.autocast('cuda',\n                            enabled=dtype in ('float16', 'bfloat16'),\n                            dtype=getattr(torch, dtype)):\n            solver_sample = value_input.get('sample', 'flow_euler')\n            sample_steps = value_input.get('sample_steps', 20)\n            guide_scale = value_input.get('guide_scale', 3.5)\n            if guide_scale is not None:\n                guide_scale = torch.full((noise.shape[0],), guide_scale, device=noise.device,\n                                         dtype=noise.dtype)\n            else:\n                guide_scale = None\n            latent = self.diffusion.sample(\n                noise=noise,\n                sampler=solver_sample,\n                model=get_model(self.diffusion_model),\n                model_kwargs={\"cond\": ctx, \"guidance\": guide_scale},\n                steps=sample_steps,\n                show_progress=True,\n                guide_scale=guide_scale,\n                return_intermediate=None,\n                reverse_scale=reverse_scale,\n                x=x_start,\n                **kwargs).float()\n        latent = unpack_tensor_into_imagelist(latent, x_shapes)\n        self.dynamic_unload(self.diffusion_model,\n                            'diffusion_model',\n                            skip_loaded=not self.use_dynamic_model)\n        self.dynamic_load(self.first_stage_model, 'first_stage_model')\n        x_samples = self.decode_first_stage(latent)\n        self.dynamic_unload(self.first_stage_model,\n                            'first_stage_model',\n                            skip_loaded=not self.use_dynamic_model)\n        return x_samples\n\n\nclass ACEInference(DiffusionInference):\n    def __init__(self, logger=None):\n        if logger is None:\n            logger = get_logger(name='scepter')\n        self.logger = logger\n        self.loaded_model = {}\n        self.loaded_model_name = [\n            'diffusion_model', 'first_stage_model', 'cond_stage_model'\n        ]\n\n    def init_from_cfg(self, cfg):\n        self.name = cfg.NAME\n        self.is_default = cfg.get('IS_DEFAULT', False)\n        self.use_dynamic_model = cfg.get('USE_DYNAMIC_MODEL', True)\n        module_paras = self.load_default(cfg.get('DEFAULT_PARAS', None))\n        assert cfg.have('MODEL')\n\n        self.diffusion_model = self.infer_model(\n            cfg.MODEL.DIFFUSION_MODEL, module_paras.get(\n                'DIFFUSION_MODEL',\n                None)) if cfg.MODEL.have('DIFFUSION_MODEL') else None\n        self.first_stage_model = self.infer_model(\n            cfg.MODEL.FIRST_STAGE_MODEL,\n            module_paras.get(\n                'FIRST_STAGE_MODEL',\n                None)) if 
cfg.MODEL.have('FIRST_STAGE_MODEL') else None\n        self.cond_stage_model = self.infer_model(\n            cfg.MODEL.COND_STAGE_MODEL,\n            module_paras.get(\n                'COND_STAGE_MODEL',\n                None)) if cfg.MODEL.have('COND_STAGE_MODEL') else None\n\n        self.refiner_model_cfg = cfg.get('REFINER_MODEL', None)\n        # self.refiner_scale = cfg.get('REFINER_SCALE', 0.)\n        # self.refiner_prompt = cfg.get('REFINER_PROMPT', \"\")\n        self.ace_prompt = cfg.get(\"ACE_PROMPT\", [])\n        if self.refiner_model_cfg:\n            self.refiner_model_cfg.USE_DYNAMIC_MODEL = self.use_dynamic_model\n            self.refiner_module = RefinerInference(self.logger)\n            self.refiner_module.init_from_cfg(self.refiner_model_cfg)\n        else:\n            self.refiner_module = None\n\n        self.diffusion = DIFFUSIONS.build(cfg.MODEL.DIFFUSION,\n                                          logger=self.logger)\n\n\n        self.interpolate_func = lambda x: (F.interpolate(\n            x.unsqueeze(0),\n            scale_factor=1 / self.size_factor,\n            mode='nearest-exact') if x is not None else None)\n        self.text_indentifers = cfg.MODEL.get('TEXT_IDENTIFIER', [])\n        self.use_text_pos_embeddings = cfg.MODEL.get('USE_TEXT_POS_EMBEDDINGS',\n                                                     False)\n        if self.use_text_pos_embeddings:\n            self.text_position_embeddings = TextEmbedding(\n                (10, 4096)).eval().requires_grad_(False).to(we.device_id)\n        else:\n            self.text_position_embeddings = None\n\n        self.max_seq_len = cfg.MODEL.DIFFUSION_MODEL.MAX_SEQ_LEN\n        self.scale_factor = cfg.get('SCALE_FACTOR', 0.18215)\n        self.size_factor = cfg.get('SIZE_FACTOR', 8)\n        self.decoder_bias = cfg.get('DECODER_BIAS', 0)\n        self.default_n_prompt = cfg.get('DEFAULT_N_PROMPT', '')\n        if not self.use_dynamic_model:\n            self.dynamic_load(self.first_stage_model, 'first_stage_model')\n            self.dynamic_load(self.cond_stage_model, 'cond_stage_model')\n            self.dynamic_load(self.diffusion_model, 'diffusion_model')\n\n    @torch.no_grad()\n    def encode_first_stage(self, x, **kwargs):\n        _, dtype = self.get_function_info(self.first_stage_model, 'encode')\n        with torch.autocast('cuda',\n                            enabled=(dtype != 'float32'),\n                            dtype=getattr(torch, dtype)):\n            z = [\n                self.scale_factor * get_model(self.first_stage_model)._encode(\n                    i.unsqueeze(0).to(getattr(torch, dtype))) for i in x\n            ]\n        return z\n\n    @torch.no_grad()\n    def decode_first_stage(self, z):\n        _, dtype = self.get_function_info(self.first_stage_model, 'decode')\n        with torch.autocast('cuda',\n                            enabled=(dtype != 'float32'),\n                            dtype=getattr(torch, dtype)):\n            x = [\n                get_model(self.first_stage_model)._decode(\n                    1. 
/ self.scale_factor * i.to(getattr(torch, dtype)))\n                for i in z\n            ]\n        return x\n\n\n\n    @torch.no_grad()\n    def __call__(self,\n                 image=None,\n                 mask=None,\n                 prompt='',\n                 task=None,\n                 negative_prompt='',\n                 output_height=512,\n                 output_width=512,\n                 sampler='ddim',\n                 sample_steps=20,\n                 guide_scale=4.5,\n                 guide_rescale=0.5,\n                 seed=-1,\n                 history_io=None,\n                 tar_index=0,\n                 **kwargs):\n        input_image, input_mask = image, mask\n        g = torch.Generator(device=we.device_id)\n        seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)\n        g.manual_seed(int(seed))\n        if input_image is not None:\n            # assert isinstance(input_image, list) and isinstance(input_mask, list)\n            if task is None:\n                task = [''] * len(input_image)\n            if not isinstance(prompt, list):\n                prompt = [prompt] * len(input_image)\n            if history_io is not None and len(history_io) > 0:\n                his_image, his_maks, his_prompt, his_task = history_io[\n                    'image'], history_io['mask'], history_io[\n                        'prompt'], history_io['task']\n                assert len(his_image) == len(his_maks) == len(\n                    his_prompt) == len(his_task)\n                input_image = his_image + input_image\n                input_mask = his_maks + input_mask\n                task = his_task + task\n                prompt = his_prompt + [prompt[-1]]\n                prompt = [\n                    pp.replace('{image}', f'{{image{i}}}') if i > 0 else pp\n                    for i, pp in enumerate(prompt)\n                ]\n\n            edit_image, edit_image_mask = process_edit_image(\n                input_image, input_mask, task, max_seq_len=self.max_seq_len)\n\n            image, image_mask = edit_image[tar_index], edit_image_mask[\n                tar_index]\n            edit_image, edit_image_mask = [edit_image], [edit_image_mask]\n\n        else:\n            edit_image = edit_image_mask = [[]]\n            image = torch.zeros(\n                size=[3, int(output_height),\n                      int(output_width)])\n            image_mask = torch.ones(\n                size=[1, int(output_height),\n                      int(output_width)])\n            if not isinstance(prompt, list):\n                prompt = [prompt]\n\n        image, image_mask, prompt = [image], [image_mask], [prompt]\n        assert check_list_of_list(prompt) and check_list_of_list(\n            edit_image) and check_list_of_list(edit_image_mask)\n        # Assign Negative Prompt\n        if isinstance(negative_prompt, list):\n            negative_prompt = negative_prompt[0]\n        assert isinstance(negative_prompt, str)\n\n        n_prompt = copy.deepcopy(prompt)\n        for nn_p_id, nn_p in enumerate(n_prompt):\n            assert isinstance(nn_p, list)\n            n_prompt[nn_p_id][-1] = negative_prompt\n\n        is_txt_image = sum([len(e_i) for e_i in edit_image]) < 1\n        image = to_device(image)\n\n        refiner_scale = kwargs.pop(\"refiner_scale\", 0.0)\n        refiner_prompt = kwargs.pop(\"refiner_prompt\", \"\")\n        use_ace = kwargs.pop(\"use_ace\", True)\n        # <= 0 use ace as the txt2img generator.\n        if use_ace and (not 
is_txt_image or refiner_scale <= 0):\n            ctx, null_ctx = {}, {}\n            # Get Noise Shape\n            self.dynamic_load(self.first_stage_model, 'first_stage_model')\n            x = self.encode_first_stage(image)\n            self.dynamic_unload(self.first_stage_model,\n                                'first_stage_model',\n                                skip_loaded=not self.use_dynamic_model)\n            noise = [\n                torch.empty(*i.shape, device=we.device_id).normal_(generator=g)\n                for i in x\n            ]\n            noise, x_shapes = pack_imagelist_into_tensor(noise)\n            ctx['x_shapes'] = null_ctx['x_shapes'] = x_shapes\n\n            image_mask = to_device(image_mask, strict=False)\n            cond_mask = [self.interpolate_func(i) for i in image_mask\n                         ] if image_mask is not None else [None] * len(image)\n            ctx['x_mask'] = null_ctx['x_mask'] = cond_mask\n\n            # Encode Prompt\n            self.dynamic_load(self.cond_stage_model, 'cond_stage_model')\n            function_name, dtype = self.get_function_info(self.cond_stage_model)\n            cont, cont_mask = getattr(get_model(self.cond_stage_model),\n                                      function_name)(prompt)\n            cont, cont_mask = self.cond_stage_embeddings(prompt, edit_image, cont,\n                                                         cont_mask)\n            null_cont, null_cont_mask = getattr(get_model(self.cond_stage_model),\n                                                function_name)(n_prompt)\n            null_cont, null_cont_mask = self.cond_stage_embeddings(\n                prompt, edit_image, null_cont, null_cont_mask)\n            self.dynamic_unload(self.cond_stage_model,\n                                'cond_stage_model',\n                                skip_loaded=not self.use_dynamic_model)\n            ctx['crossattn'] = cont\n            null_ctx['crossattn'] = null_cont\n\n            # Encode Edit Images\n            self.dynamic_load(self.first_stage_model, 'first_stage_model')\n            edit_image = [to_device(i, strict=False) for i in edit_image]\n            edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]\n            e_img, e_mask = [], []\n            for u, m in zip(edit_image, edit_image_mask):\n                if u is None:\n                    continue\n                if m is None:\n                    m = [None] * len(u)\n                e_img.append(self.encode_first_stage(u, **kwargs))\n                e_mask.append([self.interpolate_func(i) for i in m])\n            self.dynamic_unload(self.first_stage_model,\n                                'first_stage_model',\n                                skip_loaded=not self.use_dynamic_model)\n            null_ctx['edit'] = ctx['edit'] = e_img\n            null_ctx['edit_mask'] = ctx['edit_mask'] = e_mask\n\n            # Diffusion Process\n            self.dynamic_load(self.diffusion_model, 'diffusion_model')\n            function_name, dtype = self.get_function_info(self.diffusion_model)\n            with torch.autocast('cuda',\n                                enabled=dtype in ('float16', 'bfloat16'),\n                                dtype=getattr(torch, dtype)):\n                latent = self.diffusion.sample(\n                    noise=noise,\n                    sampler=sampler,\n                    model=get_model(self.diffusion_model),\n                    model_kwargs=[{\n                        'cond':\n       
                 ctx,\n                        'mask':\n                        cont_mask,\n                        'text_position_embeddings':\n                        self.text_position_embeddings.pos if hasattr(\n                            self.text_position_embeddings, 'pos') else None\n                    }, {\n                        'cond':\n                        null_ctx,\n                        'mask':\n                        null_cont_mask,\n                        'text_position_embeddings':\n                        self.text_position_embeddings.pos if hasattr(\n                            self.text_position_embeddings, 'pos') else None\n                    }] if guide_scale is not None and guide_scale > 1 else {\n                        'cond':\n                        null_ctx,\n                        'mask':\n                        cont_mask,\n                        'text_position_embeddings':\n                        self.text_position_embeddings.pos if hasattr(\n                            self.text_position_embeddings, 'pos') else None\n                    },\n                    steps=sample_steps,\n                    show_progress=True,\n                    seed=seed,\n                    guide_scale=guide_scale,\n                    guide_rescale=guide_rescale,\n                    return_intermediate=None,\n                    **kwargs)\n            if self.use_dynamic_model: self.dynamic_unload(self.diffusion_model,\n                                'diffusion_model',\n                                skip_loaded=not self.use_dynamic_model)\n\n            # Decode to Pixel Space\n            self.dynamic_load(self.first_stage_model, 'first_stage_model')\n            samples = unpack_tensor_into_imagelist(latent, x_shapes)\n            x_samples = self.decode_first_stage(samples)\n            self.dynamic_unload(self.first_stage_model,\n                                'first_stage_model',\n                                skip_loaded=not self.use_dynamic_model)\n            x_samples = [x.squeeze(0) for x in x_samples]\n        else:\n            x_samples = image\n        if self.refiner_module and refiner_scale > 0:\n            if is_txt_image:\n                random.shuffle(self.ace_prompt)\n                input_refine_prompt = [self.ace_prompt[0] + refiner_prompt if p[0] == \"\" else p[0] for p in prompt]\n                input_refine_scale = -1.\n            else:\n                input_refine_prompt = [p[0].replace(\"{image}\", \"\") + \" \" + refiner_prompt for p in prompt]\n                input_refine_scale = refiner_scale\n                print(input_refine_prompt)\n\n            x_samples = self.refiner_module.refine(x_samples,\n                                                   reverse_scale = input_refine_scale,\n                                                   prompt= input_refine_prompt,\n                                                   seed=seed,\n                                                   use_dynamic_model=self.use_dynamic_model)\n\n        imgs = [\n            torch.clamp((x_i.float() + 1.0) / 2.0 + self.decoder_bias / 255,\n                        min=0.0,\n                        max=1.0).squeeze(0).permute(1, 2, 0).cpu().numpy()\n            for x_i in x_samples\n        ]\n        imgs = [Image.fromarray((img * 255).astype(np.uint8)) for img in imgs]\n        return imgs\n\n    def cond_stage_embeddings(self, prompt, edit_image, cont, cont_mask):\n        if self.use_text_pos_embeddings and not torch.sum(\n                
self.text_position_embeddings.pos) > 0:\n            identifier_cont, _ = getattr(get_model(self.cond_stage_model),\n                                         'encode')(self.text_indentifers,\n                                                   return_mask=True)\n            self.text_position_embeddings.load_state_dict(\n                {'pos': identifier_cont[:, 0, :]})\n\n        cont_, cont_mask_ = [], []\n        for pp, edit, c, cm in zip(prompt, edit_image, cont, cont_mask):\n            if isinstance(pp, list):\n                cont_.append([c[-1], *c] if len(edit) > 0 else [c[-1]])\n                cont_mask_.append([cm[-1], *cm] if len(edit) > 0 else [cm[-1]])\n            else:\n                raise NotImplementedError\n\n        return cont_, cont_mask_\n"
  },
  {
    "path": "chatbot/config/chatbot_ui.yaml",
    "content": "WORK_DIR: ./cache/chatbot\nFILE_SYSTEM:\n  - NAME: \"HuggingfaceFs\"\n    TEMP_DIR: ./cache\n  - NAME: \"ModelscopeFs\"\n    TEMP_DIR: ./cache\n  - NAME: \"LocalFs\"\n    TEMP_DIR: ./cache\n  - NAME: \"HttpFs\"\n    TEMP_DIR: ./cache\n#\nENABLE_I2V: False\n#\nMODEL:\n  EDIT_MODEL:\n    MODEL_CFG_DIR: chatbot/config/models/\n    DEFAULT: ace_0.6b_512\n  I2V:\n    MODEL_NAME: CogVideoX-5b-I2V\n    MODEL_DIR: ms://ZhipuAI/CogVideoX-5b-I2V/\n  CAPTIONER:\n    MODEL_NAME: InternVL2-2B\n    MODEL_DIR: ms://OpenGVLab/InternVL2-2B/\n    PROMPT: '<image>\\nThis image is the first frame of a video. Based on this image, please imagine what changes may occur in the next few seconds of the video. Please output brief description, such as \"a dog running\" or \"a person turns to left\". No more than 30 words.'\n  ENHANCER:\n    MODEL_NAME: Meta-Llama-3.1-8B-Instruct\n    MODEL_DIR: ms://LLM-Research/Meta-Llama-3.1-8B-Instruct/\n"
  },
  {
    "path": "chatbot/config/models/ace_0.6b_512.yaml",
    "content": "NAME: ACE_0.6B_512\nIS_DEFAULT: False\nDEFAULT_PARAS:\n  PARAS:\n  #\n  INPUT:\n    INPUT_IMAGE:\n    INPUT_MASK:\n    TASK:\n    PROMPT: \"\"\n    NEGATIVE_PROMPT: \"\"\n    OUTPUT_HEIGHT: 512\n    OUTPUT_WIDTH: 512\n    SAMPLER: ddim\n    SAMPLE_STEPS: 20\n    GUIDE_SCALE: 4.5\n    GUIDE_RESCALE: 0.5\n    SEED: -1\n    TAR_INDEX: 0\n  OUTPUT:\n    LATENT:\n    IMAGES:\n    SEED:\n  MODULES_PARAS:\n    FIRST_STAGE_MODEL:\n      FUNCTION:\n        - NAME: encode\n          DTYPE: float16\n          INPUT: [\"IMAGE\"]\n        - NAME: decode\n          DTYPE: float16\n          INPUT: [\"LATENT\"]\n    #\n    DIFFUSION_MODEL:\n      FUNCTION:\n        - NAME: forward\n          DTYPE: float16\n          INPUT: [\"SAMPLE_STEPS\", \"SAMPLE\", \"GUIDE_SCALE\"]\n    #\n    COND_STAGE_MODEL:\n      FUNCTION:\n        - NAME: encode_list\n          DTYPE: bfloat16\n          INPUT: [\"PROMPT\"]\n#\nMODEL:\n  NAME: LdmACE\n  PRETRAINED_MODEL:\n  IGNORE_KEYS: [ ]\n  SCALE_FACTOR: 0.18215\n  SIZE_FACTOR: 8\n  DECODER_BIAS: 0.5\n  DEFAULT_N_PROMPT: \"\"\n  TEXT_IDENTIFIER: [ '{image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]\n  USE_TEXT_POS_EMBEDDINGS: True\n  #\n  DIFFUSION:\n    NAME: ACEDiffusion\n    PREDICTION_TYPE: eps\n    MIN_SNR_GAMMA:\n    NOISE_SCHEDULER:\n      NAME: LinearScheduler\n      NUM_TIMESTEPS: 1000\n      BETA_MIN: 0.0001\n      BETA_MAX: 0.02\n  #\n  DIFFUSION_MODEL:\n    NAME: DiTACE\n    PRETRAINED_MODEL: ms://iic/ACE-0.6B-512px@models/dit/ace_0.6b_512px.pth\n    IGNORE_KEYS: [ ]\n    PATCH_SIZE: 2\n    IN_CHANNELS: 4\n    HIDDEN_SIZE: 1152\n    DEPTH: 28\n    NUM_HEADS: 16\n    MLP_RATIO: 4.0\n    PRED_SIGMA: True\n    DROP_PATH: 0.0\n    WINDOW_DIZE: 0\n    Y_CHANNELS: 4096\n    MAX_SEQ_LEN: 1024\n    QK_NORM: True\n    USE_GRAD_CHECKPOINT: True\n    ATTENTION_BACKEND: flash_attn\n  #\n  FIRST_STAGE_MODEL:\n    NAME: AutoencoderKL\n    EMBED_DIM: 4\n    PRETRAINED_MODEL: ms://iic/ACE-0.6B-512px@models/vae/vae.bin\n    IGNORE_KEYS: []\n    #\n    ENCODER:\n      NAME: Encoder\n      CH: 128\n      OUT_CH: 3\n      NUM_RES_BLOCKS: 2\n      IN_CHANNELS: 3\n      ATTN_RESOLUTIONS: [ ]\n      CH_MULT: [ 1, 2, 4, 4 ]\n      Z_CHANNELS: 4\n      DOUBLE_Z: True\n      DROPOUT: 0.0\n      RESAMP_WITH_CONV: True\n    #\n    DECODER:\n      NAME: Decoder\n      CH: 128\n      OUT_CH: 3\n      NUM_RES_BLOCKS: 2\n      IN_CHANNELS: 3\n      ATTN_RESOLUTIONS: [ ]\n      CH_MULT: [ 1, 2, 4, 4 ]\n      Z_CHANNELS: 4\n      DROPOUT: 0.0\n      RESAMP_WITH_CONV: True\n      GIVE_PRE_END: False\n      TANH_OUT: False\n  #\n  COND_STAGE_MODEL:\n    NAME: ACETextEmbedder\n    PRETRAINED_MODEL: ms://iic/ACE-0.6B-512px@models/text_encoder/t5-v1_1-xxl/\n    TOKENIZER_PATH: ms://iic/ACE-0.6B-512px@models/tokenizer/t5-v1_1-xxl\n    LENGTH: 120\n    T5_DTYPE: bfloat16\n    ADDED_IDENTIFIER: [ '{image}', '{caption}', '{mask}', '{ref_image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]\n    CLEAN: whitespace\n    USE_GRAD: False\n"
  },
  {
    "path": "chatbot/example.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (c) Alibaba, Inc. and its affiliates.\nimport os\n\nfrom scepter.modules.utils.file_system import FS\nfrom PIL import Image\n\n\ndef download_image(image, local_path=None):\n    if not FS.exists(local_path):\n        local_path = FS.get_from(image, local_path=local_path)\n    return local_path\n\ndef blank_image():\n    return Image.new('RGBA', (128, 128), (0, 0, 0, 0))\n\n\n\ndef get_examples(cache_dir):\n    print('Downloading Examples ...')\n    bl_img = blank_image()\n    examples = [\n        [\n            'Facial Editing',\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/e33edc106953.png?raw=true',\n                os.path.join(cache_dir, 'examples/e33edc106953.png')), bl_img,\n                bl_img, '{image} let the man smile', 6666\n        ],\n        [\n            'Facial Editing',\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/5d2bcc91a3e9.png?raw=true',\n                os.path.join(cache_dir, 'examples/5d2bcc91a3e9.png')), bl_img,\n                bl_img, 'let the man in {image} wear sunglasses', 9999\n        ],\n        [\n            'Facial Editing',\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/3a52eac708bd.png?raw=true',\n                os.path.join(cache_dir, 'examples/3a52eac708bd.png')), bl_img,\n                bl_img, '{image} red hair', 9999\n        ],\n        [\n            'Facial Editing',\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/3f4dc464a0ea.png?raw=true',\n                os.path.join(cache_dir, 'examples/3f4dc464a0ea.png')), bl_img,\n                bl_img, '{image} let the man serious', 99999\n        ],\n        [\n            'Controllable Generation',\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/131ca90fd2a9.png?raw=true',\n                os.path.join(cache_dir,\n                             'examples/131ca90fd2a9.png')), bl_img, bl_img,\n            '\"A person sits contemplatively on the ground, surrounded by falling autumn leaves. Dressed in a green sweater and dark blue pants, they rest their chin on their hand, exuding a relaxed demeanor. Their stylish checkered slip-on shoes add a touch of flair, while a black purse lies in their lap. 
The backdrop of muted brown enhances the warm, cozy atmosphere of the scene.\" , generate the image that corresponds to the given scribble {image}.',\n            613725\n        ],\n        [\n            'Render Text',\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/33e9f27c2c48.png?raw=true',\n                os.path.join(cache_dir, 'examples/33e9f27c2c48.png')),\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/33e9f27c2c48_mask.png?raw=true',\n                os.path.join(cache_dir,\n                             'examples/33e9f27c2c48_mask.png')), bl_img,\n            'Put the text \"C A T\" at the position marked by mask in the {image}',\n            6666\n        ],\n        [\n            'Style Transfer',\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/9e73e7eeef55.png?raw=true',\n                os.path.join(cache_dir, 'examples/9e73e7eeef55.png')), bl_img,\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/2e02975293d6.png?raw=true',\n                os.path.join(cache_dir, 'examples/2e02975293d6.png')),\n            'edit {image} based on the style of {image1} ', 99999\n        ],\n        [\n            'Outpainting',\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/f2b22c08be3f.png?raw=true',\n                os.path.join(cache_dir, 'examples/f2b22c08be3f.png')),\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/f2b22c08be3f_mask.png?raw=true',\n                os.path.join(cache_dir,\n                             'examples/f2b22c08be3f_mask.png')), bl_img,\n            'Could the {image} be widened within the space designated by mask, while retaining the original?',\n            6666\n        ],\n        [\n            'Image Segmentation',\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/db3ebaa81899.png?raw=true',\n                os.path.join(cache_dir, 'examples/db3ebaa81899.png')), bl_img,\n            bl_img, '{image} Segmentation', 6666\n        ],\n        [\n            'Depth Estimation',\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/f1927c4692ba.png?raw=true',\n                os.path.join(cache_dir, 'examples/f1927c4692ba.png')), bl_img,\n            bl_img, '{image} Depth Estimation', 6666\n        ],\n        [\n            'Pose Estimation',\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/014e5bf3b4d1.png?raw=true',\n                os.path.join(cache_dir, 'examples/014e5bf3b4d1.png')), bl_img,\n            bl_img, '{image} distinguish the poses of the figures', 999999\n        ],\n        [\n            'Scribble Extraction',\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/5f59a202f8ac.png?raw=true',\n                os.path.join(cache_dir, 'examples/5f59a202f8ac.png')), bl_img,\n            bl_img, 'Generate a scribble of {image}, please.', 6666\n        ],\n        [\n            'Mosaic',\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/3a2f52361eea.png?raw=true',\n                
os.path.join(cache_dir, 'examples/3a2f52361eea.png')), bl_img,\n            bl_img, 'Adapt {image} into a mosaic representation.', 6666\n        ],\n        [\n            'Edge map Extraction',\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/b9d1e519d6e5.png?raw=true',\n                os.path.join(cache_dir, 'examples/b9d1e519d6e5.png')), bl_img,\n            bl_img, 'Get the edge-enhanced result for {image}.', 6666\n        ],\n        [\n            'Grayscale',\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/c4ebbe2ba29b.png?raw=true',\n                os.path.join(cache_dir, 'examples/c4ebbe2ba29b.png')), bl_img,\n            bl_img, 'transform {image} into a black and white one', 6666\n        ],\n        [\n            'Contour Extraction',\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/19652d0f6c4b.png?raw=true',\n                os.path.join(cache_dir,\n                             'examples/19652d0f6c4b.png')), bl_img, bl_img,\n            'Would you be able to make a contour picture from {image} for me?',\n            6666\n        ],\n        [\n            'Controllable Generation',\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/249cda2844b7.png?raw=true',\n                os.path.join(cache_dir,\n                             'examples/249cda2844b7.png')), bl_img, bl_img,\n            'Following the segmentation outcome in mask of {image}, develop a real-life image using the explanatory note in \"a mighty cat lying on the bed”.',\n            6666\n        ],\n        [\n            'Controllable Generation',\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/411f6c4b8e6c.png?raw=true',\n                os.path.join(cache_dir,\n                             'examples/411f6c4b8e6c.png')), bl_img, bl_img,\n            'use the depth map {image} and the text caption \"a cut white cat\" to create a corresponding graphic image',\n            999999\n        ],\n        [\n            'Controllable Generation',\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/a35c96ed137a.png?raw=true',\n                os.path.join(cache_dir,\n                             'examples/a35c96ed137a.png')), bl_img, bl_img,\n            'help translate this posture schema {image} into a colored image based on the context I provided \"A beautiful woman Climbing the climbing wall, wearing a harness and climbing gear, skillfully maneuvering up the wall with her back to the camera, with a safety rope.\"',\n            3599999\n        ],\n        [\n            'Controllable Generation',\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/dcb2fc86f1ce.png?raw=true',\n                os.path.join(cache_dir,\n                             'examples/dcb2fc86f1ce.png')), bl_img, bl_img,\n            'Transform and generate an image using mosaic {image} and \"Monarch butterflies gracefully perch on vibrant purple flowers, showcasing their striking orange and black wings in a lush garden setting.\" description',\n            6666\n        ],\n        [\n            'Controllable Generation',\n            download_image(\n                
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/4cd4ee494962.png?raw=true',\n                os.path.join(cache_dir,\n                             'examples/4cd4ee494962.png')), bl_img, bl_img,\n            'make this {image} colorful as per the \"beautiful sunflowers\"',\n            6666\n        ],\n        [\n            'Controllable Generation',\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/a47e3a9cd166.png?raw=true',\n                os.path.join(cache_dir,\n                             'examples/a47e3a9cd166.png')), bl_img, bl_img,\n            'Take the edge conscious {image} and the written guideline \"A whimsical animated character is depicted holding a delectable cake adorned with blue and white frosting and a drizzle of chocolate. The character wears a yellow headband with a bow, matching a cozy yellow sweater. Her dark hair is styled in a braid, tied with a yellow ribbon. With a golden fork in hand, she stands ready to enjoy a slice, exuding an air of joyful anticipation. The scene is creatively rendered with a charming and playful aesthetic.\" and produce a realistic image.',\n            613725\n        ],\n        [\n            'Controllable Generation',\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/d890ed8a3ac2.png?raw=true',\n                os.path.join(cache_dir,\n                             'examples/d890ed8a3ac2.png')), bl_img, bl_img,\n            'creating a vivid image based on {image} and description \"This image features a delicious rectangular tart with a flaky, golden-brown crust. The tart is topped with evenly sliced tomatoes, layered over a creamy cheese filling. Aromatic herbs are sprinkled on top, adding a touch of green and enhancing the visual appeal. The background includes a soft, textured fabric and scattered white flowers, creating an elegant and inviting presentation. 
Bright red tomatoes in the upper right corner hint at the fresh ingredients used in the dish.\"',\n            6666\n        ],\n        [\n            'Image Denoising',\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/0844a686a179.png?raw=true',\n                os.path.join(cache_dir,\n                             'examples/0844a686a179.png')), bl_img, bl_img,\n            'Eliminate noise interference in {image} and maximize the crispness to obtain superior high-definition quality',\n            6666\n        ],\n        [\n            'Inpainting',\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/fa91b6b7e59b.png?raw=true',\n                os.path.join(cache_dir, 'examples/fa91b6b7e59b.png')),\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/fa91b6b7e59b_mask.png?raw=true',\n                os.path.join(cache_dir,\n                             'examples/fa91b6b7e59b_mask.png')), bl_img,\n            'Ensure to overhaul the parts of the {image} indicated by the mask.',\n            6666\n        ],\n        [\n            'Inpainting',\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/632899695b26.png?raw=true',\n                os.path.join(cache_dir, 'examples/632899695b26.png')),\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/632899695b26_mask.png?raw=true',\n                os.path.join(cache_dir,\n                             'examples/632899695b26_mask.png')), bl_img,\n            'Refashion the mask portion of {image} in accordance with \"A yellow egg with a smiling face painted on it\"',\n            6666\n        ],\n        [\n            'General Editing',\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/354d17594afe.png?raw=true',\n                os.path.join(cache_dir,\n                             'examples/354d17594afe.png')), bl_img, bl_img,\n            '{image} change the dog\\'s posture to walking in the water, and change the background to green plants and a pond.',\n            6666\n        ],\n        [\n            'General Editing',\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/38946455752b.png?raw=true',\n                os.path.join(cache_dir,\n                             'examples/38946455752b.png')), bl_img, bl_img,\n            '{image} change the color of the dress from white to red and the model\\'s hair color red brown to blonde.Other parts remain unchanged',\n            6669\n        ],\n        [\n            'Facial Editing',\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/3ba5202f0cd8.png?raw=true',\n                os.path.join(cache_dir,\n                             'examples/3ba5202f0cd8.png')), bl_img, bl_img,\n            'Keep the same facial feature in @3ba5202f0cd8, change the woman\\'s clothing from a Blue denim jacket to a white turtleneck sweater and adjust her posture so that she is supporting her chin with both hands. 
Other aspects, such as background, hairstyle, facial expression, etc, remain unchanged.',\n            99999\n        ],\n        [\n            'Facial Editing',\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/369365b94725.png?raw=true',\n                os.path.join(cache_dir, 'examples/369365b94725.png')), bl_img,\n            bl_img, '{image} Make her looking at the camera', 6666\n        ],\n        [\n            'Facial Editing',\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/92751f2e4a0e.png?raw=true',\n                os.path.join(cache_dir, 'examples/92751f2e4a0e.png')), bl_img,\n            bl_img, '{image} Remove the smile from his face', 9899999\n        ],\n        [\n            'Remove Text',\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/8530a6711b2e.png?raw=true',\n                os.path.join(cache_dir, 'examples/8530a6711b2e.png')), bl_img,\n            bl_img, 'Aim to remove any textual element in {image}', 6666\n        ],\n        [\n            'Remove Text',\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/c4d7fb28f8f6.png?raw=true',\n                os.path.join(cache_dir, 'examples/c4d7fb28f8f6.png')),\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/c4d7fb28f8f6_mask.png?raw=true',\n                os.path.join(cache_dir,\n                             'examples/c4d7fb28f8f6_mask.png')), bl_img,\n            'Rub out any text found in the mask sector of the {image}.', 6666\n        ],\n        [\n            'Remove Object',\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/e2f318fa5e5b.png?raw=true',\n                os.path.join(cache_dir,\n                             'examples/e2f318fa5e5b.png')), bl_img, bl_img,\n            'Remove the unicorn in this {image}, ensuring a smooth edit.',\n            99999\n        ],\n        [\n            'Remove Object',\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/1ae96d8aca00.png?raw=true',\n                os.path.join(cache_dir, 'examples/1ae96d8aca00.png')),\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/1ae96d8aca00_mask.png?raw=true',\n                os.path.join(cache_dir, 'examples/1ae96d8aca00_mask.png')),\n            bl_img, 'Discard the contents of the mask area from {image}.', 99999\n        ],\n        [\n            'Add Object',\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/80289f48e511.png?raw=true',\n                os.path.join(cache_dir, 'examples/80289f48e511.png')),\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/80289f48e511_mask.png?raw=true',\n                os.path.join(cache_dir,\n                             'examples/80289f48e511_mask.png')), bl_img,\n            'add a Hot Air Balloon into the {image}, per the mask', 613725\n        ],\n        [\n            'Style Transfer',\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/d725cb2009e8.png?raw=true',\n                os.path.join(cache_dir, 
'examples/d725cb2009e8.png')), bl_img,\n            bl_img, 'Change the style of {image} to colored pencil style', 99999\n        ],\n        [\n            'Style Transfer',\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/e0f48b3fd010.png?raw=true',\n                os.path.join(cache_dir, 'examples/e0f48b3fd010.png')), bl_img,\n            bl_img, 'make {image} to Walt Disney Animation style', 99999\n        ],\n        [\n            'Try On',\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/ee4ca60b8c96.png?raw=true',\n                os.path.join(cache_dir, 'examples/ee4ca60b8c96.png')),\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/ee4ca60b8c96_mask.png?raw=true',\n                os.path.join(cache_dir, 'examples/ee4ca60b8c96_mask.png')),\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/ebe825bbfe3c.png?raw=true',\n                os.path.join(cache_dir, 'examples/ebe825bbfe3c.png')),\n            'Change the cloth in {image} to the one in {image1}', 99999\n        ],\n        [\n            'Workflow',\n            download_image(\n                'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/cb85353c004b.png?raw=true',\n                os.path.join(cache_dir, 'examples/cb85353c004b.png')), bl_img,\n            bl_img, '<workflow> ice cream {image}', 99999\n        ],\n    ]\n    print('Finish. Start building UI ...')\n    return examples\n"
  },
  {
    "path": "chatbot/infer.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (c) Alibaba, Inc. and its affiliates.\nimport copy\nimport math\nimport random\nimport numpy as np\nfrom PIL import Image\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchvision.transforms.functional as TF\n\nfrom scepter.modules.model.registry import DIFFUSIONS\nfrom scepter.modules.utils.distribute import we\nfrom scepter.modules.utils.logger import get_logger\nfrom scepter.modules.inference.diffusion_inference import DiffusionInference, get_model\nfrom modules.model.utils.basic_utils import (\n    check_list_of_list,\n    pack_imagelist_into_tensor_v2 as pack_imagelist_into_tensor,\n    to_device,\n    unpack_tensor_into_imagelist\n)\n\n\ndef process_edit_image(images,\n                       masks,\n                       tasks,\n                       max_seq_len=1024,\n                       max_aspect_ratio=4,\n                       d=16,\n                       **kwargs):\n\n    if not isinstance(images, list):\n        images = [images]\n    if not isinstance(masks, list):\n        masks = [masks]\n    if not isinstance(tasks, list):\n        tasks = [tasks]\n\n    img_tensors = []\n    mask_tensors = []\n    for img, mask, task in zip(images, masks, tasks):\n        if mask is None or mask == '':\n            mask = Image.new('L', img.size, 0)\n        W, H = img.size\n        if H / W > max_aspect_ratio:\n            img = TF.center_crop(img, [int(max_aspect_ratio * W), W])\n            mask = TF.center_crop(mask, [int(max_aspect_ratio * W), W])\n        elif W / H > max_aspect_ratio:\n            img = TF.center_crop(img, [H, int(max_aspect_ratio * H)])\n            mask = TF.center_crop(mask, [H, int(max_aspect_ratio * H)])\n\n        H, W = img.height, img.width\n        scale = min(1.0, math.sqrt(max_seq_len / ((H / d) * (W / d))))\n        rH = int(H * scale) // d * d  # ensure divisible by self.d\n        rW = int(W * scale) // d * d\n\n        img = TF.resize(img, (rH, rW),\n                        interpolation=TF.InterpolationMode.BICUBIC)\n        mask = TF.resize(mask, (rH, rW),\n                         interpolation=TF.InterpolationMode.NEAREST_EXACT)\n\n        mask = np.asarray(mask)\n        mask = np.where(mask > 128, 1, 0)\n        mask = mask.astype(\n            np.float32) if np.any(mask) else np.ones_like(mask).astype(\n                np.float32)\n\n        img_tensor = TF.to_tensor(img).to(we.device_id)\n        img_tensor = TF.normalize(img_tensor,\n                                  mean=[0.5, 0.5, 0.5],\n                                  std=[0.5, 0.5, 0.5])\n        mask_tensor = TF.to_tensor(mask).to(we.device_id)\n        if task in ['inpainting', 'Try On', 'Inpainting']:\n            mask_indicator = mask_tensor.repeat(3, 1, 1)\n            img_tensor[mask_indicator == 1] = -1.0\n        img_tensors.append(img_tensor)\n        mask_tensors.append(mask_tensor)\n    return img_tensors, mask_tensors\n\n\nclass TextEmbedding(nn.Module):\n    def __init__(self, embedding_shape):\n        super().__init__()\n        self.pos = nn.Parameter(data=torch.zeros(embedding_shape))\n\n\nclass ACEInference(DiffusionInference):\n    def __init__(self, logger=None):\n        if logger is None:\n            logger = get_logger(name='scepter')\n        self.logger = logger\n        self.loaded_model = {}\n        self.loaded_model_name = [\n            'diffusion_model', 'first_stage_model', 'cond_stage_model'\n        ]\n\n    def init_from_cfg(self, cfg):\n        self.name = 
cfg.NAME\n        self.is_default = cfg.get('IS_DEFAULT', False)\n        module_paras = self.load_default(cfg.get('DEFAULT_PARAS', None))\n        assert cfg.have('MODEL')\n\n        self.diffusion_model = self.infer_model(\n            cfg.MODEL.DIFFUSION_MODEL, module_paras.get(\n                'DIFFUSION_MODEL',\n                None)) if cfg.MODEL.have('DIFFUSION_MODEL') else None\n        self.first_stage_model = self.infer_model(\n            cfg.MODEL.FIRST_STAGE_MODEL,\n            module_paras.get(\n                'FIRST_STAGE_MODEL',\n                None)) if cfg.MODEL.have('FIRST_STAGE_MODEL') else None\n        self.cond_stage_model = self.infer_model(\n            cfg.MODEL.COND_STAGE_MODEL,\n            module_paras.get(\n                'COND_STAGE_MODEL',\n                None)) if cfg.MODEL.have('COND_STAGE_MODEL') else None\n        self.diffusion = DIFFUSIONS.build(cfg.MODEL.DIFFUSION,\n                                          logger=self.logger)\n\n        self.interpolate_func = lambda x: (F.interpolate(\n            x.unsqueeze(0),\n            scale_factor=1 / self.size_factor,\n            mode='nearest-exact') if x is not None else None)\n        self.text_indentifers = cfg.MODEL.get('TEXT_IDENTIFIER', [])\n        self.use_text_pos_embeddings = cfg.MODEL.get('USE_TEXT_POS_EMBEDDINGS',\n                                               False)\n        if self.use_text_pos_embeddings:\n            self.text_position_embeddings = TextEmbedding(\n                (10, 4096)).eval().requires_grad_(False).to(we.device_id)\n        else:\n            self.text_position_embeddings = None\n\n        self.max_seq_len = cfg.MODEL.DIFFUSION_MODEL.MAX_SEQ_LEN\n        self.scale_factor = cfg.get('SCALE_FACTOR', 0.18215)\n        self.size_factor = cfg.get('SIZE_FACTOR', 8)\n        self.decoder_bias = cfg.get('DECODER_BIAS', 0)\n        self.default_n_prompt = cfg.get('DEFAULT_N_PROMPT', '')\n\n        self.dynamic_load(self.first_stage_model, 'first_stage_model')\n        self.dynamic_load(self.cond_stage_model, 'cond_stage_model')\n        self.dynamic_load(self.diffusion_model, 'diffusion_model')\n\n    @torch.no_grad()\n    def encode_first_stage(self, x, **kwargs):\n        _, dtype = self.get_function_info(self.first_stage_model, 'encode')\n        with torch.autocast('cuda',\n                            enabled=(dtype != 'float32'),\n                            dtype=getattr(torch, dtype)):\n            z = [\n                self.scale_factor * get_model(self.first_stage_model)._encode(\n                    i.unsqueeze(0).to(getattr(torch, dtype))) for i in x\n            ]\n        return z\n\n    @torch.no_grad()\n    def decode_first_stage(self, z):\n        _, dtype = self.get_function_info(self.first_stage_model, 'decode')\n        with torch.autocast('cuda',\n                            enabled=(dtype != 'float32'),\n                            dtype=getattr(torch, dtype)):\n            x = [\n                get_model(self.first_stage_model)._decode(\n                    1. 
/ self.scale_factor * i.to(getattr(torch, dtype)))\n                for i in z\n            ]\n        return x\n\n    @torch.no_grad()\n    def __call__(self,\n                 image=None,\n                 mask=None,\n                 prompt='',\n                 task=None,\n                 negative_prompt='',\n                 output_height=512,\n                 output_width=512,\n                 sampler='ddim',\n                 sample_steps=20,\n                 guide_scale=4.5,\n                 guide_rescale=0.5,\n                 seed=-1,\n                 history_io=None,\n                 tar_index=0,\n                 **kwargs):\n        input_image, input_mask = image, mask\n        g = torch.Generator(device=we.device_id)\n        seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)\n        g.manual_seed(int(seed))\n\n        if input_image is not None:\n            assert isinstance(input_image, list) and isinstance(\n                input_mask, list)\n            if task is None:\n                task = [''] * len(input_image)\n            if not isinstance(prompt, list):\n                prompt = [prompt] * len(input_image)\n            if history_io is not None and len(history_io) > 0:\n                his_image, his_maks, his_prompt, his_task = history_io[\n                    'image'], history_io['mask'], history_io[\n                        'prompt'], history_io['task']\n                assert len(his_image) == len(his_maks) == len(\n                    his_prompt) == len(his_task)\n                input_image = his_image + input_image\n                input_mask = his_maks + input_mask\n                task = his_task + task\n                prompt = his_prompt + [prompt[-1]]\n                prompt = [\n                    pp.replace('{image}', f'{{image{i}}}') if i > 0 else pp\n                    for i, pp in enumerate(prompt)\n                ]\n\n            edit_image, edit_image_mask = process_edit_image(\n                input_image, input_mask, task, max_seq_len=self.max_seq_len)\n\n            image, image_mask = edit_image[tar_index], edit_image_mask[\n                tar_index]\n            edit_image, edit_image_mask = [edit_image], [edit_image_mask]\n\n        else:\n            edit_image = edit_image_mask = [[]]\n            image = torch.zeros(\n                size=[3, int(output_height),\n                      int(output_width)])\n            image_mask = torch.ones(\n                size=[1, int(output_height),\n                      int(output_width)])\n            if not isinstance(prompt, list):\n                prompt = [prompt]\n\n        image, image_mask, prompt = [image], [image_mask], [prompt]\n        assert check_list_of_list(prompt) and check_list_of_list(\n            edit_image) and check_list_of_list(edit_image_mask)\n        # Assign Negative Prompt\n        if isinstance(negative_prompt, list):\n            negative_prompt = negative_prompt[0]\n        assert isinstance(negative_prompt, str)\n\n        n_prompt = copy.deepcopy(prompt)\n        for nn_p_id, nn_p in enumerate(n_prompt):\n            assert isinstance(nn_p, list)\n            n_prompt[nn_p_id][-1] = negative_prompt\n\n        ctx, null_ctx = {}, {}\n\n        # Get Noise Shape\n        image = to_device(image)\n        x = self.encode_first_stage(image)\n        noise = [\n            torch.empty(*i.shape, device=we.device_id).normal_(generator=g)\n            for i in x\n        ]\n        noise, x_shapes = pack_imagelist_into_tensor(noise)\n        
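# x_shapes records each noise latent's original shape so the packed tensor produced by\n        # pack_imagelist_into_tensor can be unpacked back into per-image latents after sampling.\n        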
ctx['x_shapes'] = null_ctx['x_shapes'] = x_shapes\n\n        image_mask = to_device(image_mask, strict=False)\n        cond_mask = [self.interpolate_func(i) for i in image_mask\n                     ] if image_mask is not None else [None] * len(image)\n        ctx['x_mask'] = null_ctx['x_mask'] = cond_mask\n\n        # Encode Prompt\n        \n        function_name, dtype = self.get_function_info(self.cond_stage_model)\n        cont, cont_mask = getattr(get_model(self.cond_stage_model),\n                                  function_name)(prompt)\n        cont, cont_mask = self.cond_stage_embeddings(prompt, edit_image, cont,\n                                                     cont_mask)\n        null_cont, null_cont_mask = getattr(get_model(self.cond_stage_model),\n                                            function_name)(n_prompt)\n        null_cont, null_cont_mask = self.cond_stage_embeddings(\n            prompt, edit_image, null_cont, null_cont_mask)\n        ctx['crossattn'] = cont\n        null_ctx['crossattn'] = null_cont\n\n        # Encode Edit Images\n        edit_image = [to_device(i, strict=False) for i in edit_image]\n        edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]\n        e_img, e_mask = [], []\n        for u, m in zip(edit_image, edit_image_mask):\n            if u is None:\n                continue\n            if m is None:\n                m = [None] * len(u)\n            e_img.append(self.encode_first_stage(u, **kwargs))\n            e_mask.append([self.interpolate_func(i) for i in m])\n\n        null_ctx['edit'] = ctx['edit'] = e_img\n        null_ctx['edit_mask'] = ctx['edit_mask'] = e_mask\n\n        # Diffusion Process\n        function_name, dtype = self.get_function_info(self.diffusion_model)\n        with torch.autocast('cuda',\n                            enabled=dtype in ('float16', 'bfloat16'),\n                            dtype=getattr(torch, dtype)):\n            latent = self.diffusion.sample(\n                noise=noise,\n                sampler=sampler,\n                model=get_model(self.diffusion_model),\n                model_kwargs=[{\n                    'cond':\n                    ctx,\n                    'mask':\n                    cont_mask,\n                    'text_position_embeddings':\n                    self.text_position_embeddings.pos if hasattr(\n                        self.text_position_embeddings, 'pos') else None\n                }, {\n                    'cond':\n                    null_ctx,\n                    'mask':\n                    null_cont_mask,\n                    'text_position_embeddings':\n                    self.text_position_embeddings.pos if hasattr(\n                        self.text_position_embeddings, 'pos') else None\n                }] if guide_scale is not None and guide_scale > 1 else {\n                    'cond':\n                    null_ctx,\n                    'mask':\n                    cont_mask,\n                    'text_position_embeddings':\n                    self.text_position_embeddings.pos if hasattr(\n                        self.text_position_embeddings, 'pos') else None\n                },\n                steps=sample_steps,\n                show_progress=True,\n                seed=seed,\n                guide_scale=guide_scale,\n                guide_rescale=guide_rescale,\n                return_intermediate=None,\n                **kwargs)\n\n        # Decode to Pixel Space\n        samples = unpack_tensor_into_imagelist(latent, x_shapes)\n  
      x_samples = self.decode_first_stage(samples)\n\n        imgs = [\n            torch.clamp((x_i + 1.0) / 2.0 + self.decoder_bias / 255,\n                        min=0.0,\n                        max=1.0).squeeze(0).permute(1, 2, 0).cpu().numpy()\n            for x_i in x_samples\n        ]\n        imgs = [Image.fromarray((img * 255).astype(np.uint8)) for img in imgs]\n        return imgs\n\n    def cond_stage_embeddings(self, prompt, edit_image, cont, cont_mask):\n        if self.use_text_pos_embeddings and not torch.sum(\n                self.text_position_embeddings.pos) > 0:\n            identifier_cont, _ = getattr(get_model(self.cond_stage_model),\n                                         'encode')(self.text_indentifers,\n                                                   return_mask=True)\n            self.text_position_embeddings.load_state_dict(\n                {'pos': identifier_cont[:, 0, :]})\n\n        cont_, cont_mask_ = [], []\n        for pp, edit, c, cm in zip(prompt, edit_image, cont, cont_mask):\n            if isinstance(pp, list):\n                cont_.append([c[-1], *c] if len(edit) > 0 else [c[-1]])\n                cont_mask_.append([cm[-1], *cm] if len(edit) > 0 else [cm[-1]])\n            else:\n                raise NotImplementedError\n\n        return cont_, cont_mask_\n"
  },
  {
    "path": "chatbot/run_gradio.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (c) Alibaba, Inc. and its affiliates.\nimport argparse\nimport base64\nimport copy\nimport csv\nimport glob\nimport io\nimport os\nimport random\nimport re\nimport string\nimport sys\nimport threading\nimport warnings\n\nimport cv2\nimport gradio as gr\nimport numpy as np\nimport torch\nimport transformers\nfrom PIL import Image\nfrom transformers import AutoModel, AutoTokenizer\n\nfrom scepter.modules.utils.config import Config\nfrom scepter.modules.utils.directory import get_md5\nfrom scepter.modules.utils.file_system import FS\nfrom scepter.studio.utils.env import init_env\nfrom importlib.metadata import version\n\nfrom ace_inference import ACEInference\nfrom example import get_examples\nfrom utils import load_image\n\ncsv.field_size_limit(sys.maxsize)\n\nrefresh_sty = '\\U0001f504'  # 🔄\nclear_sty = '\\U0001f5d1'  # 🗑️\nupload_sty = '\\U0001f5bc'  # 🖼️\nsync_sty = '\\U0001f4be'  # 💾\nchat_sty = '\\U0001F4AC'  # 💬\nvideo_sty = '\\U0001f3a5'  # 🎥\n\nlock = threading.Lock()\n\n\nclass ChatBotUI(object):\n    def __init__(self,\n                 cfg_general_file,\n                 is_debug=False,\n                 language='en',\n                 root_work_dir='./'):\n        try:\n            from diffusers import CogVideoXImageToVideoPipeline\n            from diffusers.utils import export_to_video\n        except Exception as e:\n            print(f\"Import diffusers failed, please install or upgrade diffusers. Error information: {e}\")\n        if isinstance(cfg_general_file, str):\n            cfg = Config(cfg_file=cfg_general_file)\n        else:\n            cfg = cfg_general_file\n        cfg.WORK_DIR = os.path.join(root_work_dir, cfg.WORK_DIR)\n        if not FS.exists(cfg.WORK_DIR):\n            FS.make_dir(cfg.WORK_DIR)\n        cfg = init_env(cfg)\n        self.cache_dir = cfg.WORK_DIR\n        self.chatbot_examples = get_examples(self.cache_dir) if not cfg.get('SKIP_EXAMPLES', False) else []\n        self.model_cfg_dir = cfg.MODEL.EDIT_MODEL.MODEL_CFG_DIR\n        self.model_yamls = glob.glob(os.path.join(self.model_cfg_dir,\n                                                  '*.yaml'))\n        self.model_choices = dict()\n        self.default_model_name = ''\n        for i in self.model_yamls:\n            model_cfg = Config(load=True, cfg_file=i)\n            model_name = model_cfg.NAME\n            if model_cfg.IS_DEFAULT: self.default_model_name = model_name\n            self.model_choices[model_name] = model_cfg\n        print('Models: ', self.model_choices.keys())\n        assert len(self.model_choices) > 0\n        if self.default_model_name == \"\": self.default_model_name = list(self.model_choices.keys())[0]\n        self.model_name = self.default_model_name\n        self.pipe = ACEInference()\n        self.pipe.init_from_cfg(self.model_choices[self.default_model_name])\n        self.max_msgs = 20\n        self.enable_i2v = cfg.get('ENABLE_I2V', False)\n        self.gradio_version = version('gradio')\n\n        if self.enable_i2v:\n            self.i2v_model_dir = cfg.MODEL.I2V.MODEL_DIR\n            self.i2v_model_name = cfg.MODEL.I2V.MODEL_NAME\n            if self.i2v_model_name == 'CogVideoX-5b-I2V':\n                with FS.get_dir_to_local_dir(self.i2v_model_dir) as local_dir:\n                    self.i2v_pipe = CogVideoXImageToVideoPipeline.from_pretrained(\n                        local_dir, torch_dtype=torch.bfloat16).cuda()\n            else:\n                raise NotImplementedError\n\n            with 
FS.get_dir_to_local_dir(\n                    cfg.MODEL.CAPTIONER.MODEL_DIR) as local_dir:\n                self.captioner = AutoModel.from_pretrained(\n                    local_dir,\n                    torch_dtype=torch.bfloat16,\n                    low_cpu_mem_usage=True,\n                    use_flash_attn=True,\n                    trust_remote_code=True).eval().cuda()\n                self.llm_tokenizer = AutoTokenizer.from_pretrained(\n                    local_dir, trust_remote_code=True, use_fast=False)\n                self.llm_generation_config = dict(max_new_tokens=1024,\n                                                  do_sample=True)\n                self.llm_prompt = cfg.LLM.PROMPT\n                self.llm_max_num = 2\n\n            with FS.get_dir_to_local_dir(\n                    cfg.MODEL.ENHANCER.MODEL_DIR) as local_dir:\n                self.enhancer = transformers.pipeline(\n                    'text-generation',\n                    model=local_dir,\n                    model_kwargs={'torch_dtype': torch.bfloat16},\n                    device_map='auto',\n                )\n\n            sys_prompt = \"\"\"You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.\n\n            For example , outputting \" a beautiful morning in the woods with the sun peaking through the trees \" will trigger your partner bot to output an video of a forest morning , as described. You will be prompted by people looking to create detailed , amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.\n            There are a few rules to follow:\n\n            You will only ever output a single video description per user request.\n\n            When modifications are requested , you should not simply make the description longer . You should refactor the entire description to integrate the suggestions.\n            Other times the user will not want modifications , but instead want a new image . In this case , you should ignore your previous conversation with the user.\n\n            Video descriptions must have the same num of words as examples below. Extra words will be ignored.\n            \"\"\"\n            self.enhance_ctx = [\n                {\n                    'role': 'system',\n                    'content': sys_prompt\n                },\n                {\n                    'role':\n                    'user',\n                    'content':\n                    'Create an imaginative video descriptive caption or modify an earlier caption for the user input : \"a girl is on the beach\"',\n                },\n                {\n                    'role':\n                    'assistant',\n                    'content':\n                    \"A radiant woman stands on a deserted beach, arms outstretched, wearing a beige trench coat, white blouse, light blue jeans, and chic boots, against a backdrop of soft sky and sea. Moments later, she is seen mid-twirl, arms exuberant, with the lighting suggesting dawn or dusk. Then, she runs along the beach, her attire complemented by an off-white scarf and black ankle boots, the tranquil sea behind her. 
Finally, she holds a paper airplane, her pose reflecting joy and freedom, with the ocean's gentle waves and the sky's soft pastel hues enhancing the serene ambiance.\",\n                },\n                {\n                    'role':\n                    'user',\n                    'content':\n                    'Create an imaginative video descriptive caption or modify an earlier caption for the user input : \"A man jogging on a football field\"',\n                },\n                {\n                    'role':\n                    'assistant',\n                    'content':\n                    \"A determined man in athletic attire, including a blue long-sleeve shirt, black shorts, and blue socks, jogs around a snow-covered soccer field, showcasing his solitary exercise in a quiet, overcast setting. His long dreadlocks, focused expression, and the serene winter backdrop highlight his dedication to fitness. As he moves, his attire, consisting of a blue sports sweatshirt, black athletic pants, gloves, and sneakers, grips the snowy ground. He is seen running past a chain-link fence enclosing the playground area, with a basketball hoop and children's slide, suggesting a moment of solitary exercise amidst the empty field.\",\n                },\n                {\n                    'role':\n                    'user',\n                    'content':\n                    'Create an imaginative video descriptive caption or modify an earlier caption for the user input : \" A woman is dancing, HD footage, close-up\"',\n                },\n                {\n                    'role':\n                    'assistant',\n                    'content':\n                    'A young woman with her hair in an updo and wearing a teal hoodie stands against a light backdrop, initially looking over her shoulder with a contemplative expression. She then confidently makes a subtle dance move, suggesting rhythm and movement. Next, she appears poised and focused, looking directly at the camera. Her expression shifts to one of introspection as she gazes downward slightly. 
Finally, she dances with confidence, her left hand over her heart, symbolizing a poignant moment, all while dressed in the same teal hoodie against a plain, light-colored background.',\n                },\n            ]\n\n    def create_ui(self):\n\n        css = '.chatbot.prose.md {opacity: 1.0 !important} #chatbot {opacity: 1.0 !important}'\n        with gr.Blocks(css=css,\n                       title='Chatbot',\n                       head='Chatbot',\n                       analytics_enabled=False):\n            self.history = gr.State(value=[])\n            self.images = gr.State(value={})\n            self.history_result = gr.State(value={})\n            self.retry_msg = gr.State(value='')\n            with gr.Group():\n                self.ui_mode = gr.State(value='legacy')\n                with gr.Row(equal_height=True, visible=False) as self.chat_group:\n                    with gr.Column(visible=True) as self.chat_page:\n                        self.chatbot = gr.Chatbot(\n                            height=600,\n                            value=[],\n                            bubble_full_width=False,\n                            show_copy_button=True,\n                            container=False,\n                            placeholder='<strong>Chat Box</strong>')\n                        with gr.Row():\n                            self.clear_btn = gr.Button(clear_sty +\n                                                       ' Clear Chat',\n                                                       size='sm')\n\n                    with gr.Column(visible=False) as self.editor_page:\n                        with gr.Tabs(visible=False) as self.upload_tabs:\n                            with gr.Tab(id='ImageUploader',\n                                        label='Image Uploader',\n                                        visible=True) as self.upload_tab:\n                                self.image_uploader = gr.Image(\n                                    height=550,\n                                    interactive=True,\n                                    type='pil',\n                                    image_mode='RGB',\n                                    sources=['upload'],\n                                    elem_id='image_uploader',\n                                    format='png')\n                                with gr.Row():\n                                    self.sub_btn_1 = gr.Button(\n                                        value='Submit',\n                                        elem_id='upload_submit')\n                                    self.ext_btn_1 = gr.Button(value='Exit')\n                        with gr.Tabs(visible=False) as self.edit_tabs:\n                            with gr.Tab(id='ImageEditor',\n                                        label='Image Editor') as self.edit_tab:\n                                self.mask_type = gr.Dropdown(\n                                    label='Mask Type',\n                                    choices=[\n                                        'Background', 'Composite',\n                                        'Outpainting'\n                                    ],\n                                    value='Background')\n                                self.mask_type_info = gr.HTML(\n                                    value=\n                                    \"<div style='background-color: white; padding-left: 15px; color: grey;'>Background mode will not erase the visual content in the mask area</div>\"\n           
                     )\n                                with gr.Accordion(\n                                        label='Outpainting Setting',\n                                        open=True,\n                                        visible=False) as self.outpaint_tab:\n                                    with gr.Row(variant='panel'):\n                                        self.top_ext = gr.Slider(\n                                            show_label=True,\n                                            label='Top Extend Ratio',\n                                            minimum=0.0,\n                                            maximum=2.0,\n                                            step=0.1,\n                                            value=0.25)\n                                        self.bottom_ext = gr.Slider(\n                                            show_label=True,\n                                            label='Bottom Extend Ratio',\n                                            minimum=0.0,\n                                            maximum=2.0,\n                                            step=0.1,\n                                            value=0.25)\n                                    with gr.Row(variant='panel'):\n                                        self.left_ext = gr.Slider(\n                                            show_label=True,\n                                            label='Left Extend Ratio',\n                                            minimum=0.0,\n                                            maximum=2.0,\n                                            step=0.1,\n                                            value=0.25)\n                                        self.right_ext = gr.Slider(\n                                            show_label=True,\n                                            label='Right Extend Ratio',\n                                            minimum=0.0,\n                                            maximum=2.0,\n                                            step=0.1,\n                                            value=0.25)\n                                    with gr.Row(variant='panel'):\n                                        self.img_pad_btn = gr.Button(\n                                            value='Pad Image')\n\n                                self.image_editor = gr.ImageMask(\n                                    value=None,\n                                    sources=[],\n                                    layers=False,\n                                    label='Edit Image',\n                                    elem_id='image_editor',\n                                    format='png')\n                                with gr.Row():\n                                    self.sub_btn_2 = gr.Button(\n                                        value='Submit', elem_id='edit_submit')\n                                    self.ext_btn_2 = gr.Button(value='Exit')\n\n                            with gr.Tab(id='ImageViewer',\n                                        label='Image Viewer') as self.image_view_tab:\n                                if self.gradio_version >= '5.0.0':\n                                    self.image_viewer = gr.Image(\n                                        label='Image',\n                                        type='pil',\n                                        show_download_button=True,\n                                        elem_id='image_viewer')\n                                
else:\n                                    try:\n                                        from gradio_imageslider import ImageSlider\n                                    except Exception as e:\n                                        print(f\"Import gradio_imageslider failed, please install. Error information: {e}\")\n                                    self.image_viewer = ImageSlider(\n                                        label='Image',\n                                        type='pil',\n                                        show_download_button=True,\n                                        elem_id='image_viewer')\n\n                                self.ext_btn_3 = gr.Button(value='Exit')\n\n                            with gr.Tab(id='VideoViewer',\n                                        label='Video Viewer',\n                                        visible=False) as self.video_view_tab:\n                                self.video_viewer = gr.Video(\n                                    label='Video',\n                                    interactive=False,\n                                    sources=[],\n                                    format='mp4',\n                                    show_download_button=True,\n                                    elem_id='video_viewer',\n                                    loop=True,\n                                    autoplay=True)\n\n                                self.ext_btn_4 = gr.Button(value='Exit')\n\n                with gr.Row(equal_height=True, visible=True) as self.legacy_group:\n                    with gr.Column():\n                        self.legacy_image_uploader = gr.Image(\n                            height=550,\n                            interactive=True,\n                            type='pil',\n                            image_mode='RGB',\n                            elem_id='legacy_image_uploader',\n                            format='png')\n                    with gr.Column():\n                        self.legacy_image_viewer = gr.Image(\n                            label='Image',\n                            height=550,\n                            type='pil',\n                            interactive=False,\n                            show_download_button=True,\n                            elem_id='image_viewer')\n\n\n                with gr.Accordion(label='Setting', open=False):\n                    with gr.Row():\n                        self.model_name_dd = gr.Dropdown(\n                            choices=self.model_choices,\n                            value=self.default_model_name,\n                            label='Model Version')\n\n                    with gr.Row():\n                        self.negative_prompt = gr.Textbox(\n                            value='',\n                            placeholder=\n                            'Negative prompt used for Classifier-Free Guidance',\n                            label='Negative Prompt',\n                            container=False)\n\n                    with gr.Row():\n                        # REFINER_PROMPT\n                        self.refiner_prompt = gr.Textbox(\n                            value=self.pipe.input.get(\"refiner_prompt\", \"\"),\n                            visible=self.pipe.input.get(\"refiner_prompt\", None) is not None,\n                            placeholder=\n                            'Prompt used for refiner',\n                            label='Refiner Prompt',\n                            container=False)\n\n\n                    with 
gr.Row():\n                        with gr.Column(scale=8, min_width=500):\n                            with gr.Row():\n                                self.step = gr.Slider(minimum=1,\n                                                      maximum=1000,\n                                                      value=self.pipe.input.get(\"sample_steps\", 20),\n                                                      visible=self.pipe.input.get(\"sample_steps\", None) is not None,\n                                                      label='Sample Step')\n                                self.cfg_scale = gr.Slider(\n                                    minimum=1.0,\n                                    maximum=20.0,\n                                    value=self.pipe.input.get(\"guide_scale\", 4.5),\n                                    visible=self.pipe.input.get(\"guide_scale\", None) is not None,\n                                    label='Guidance Scale')\n                                self.rescale = gr.Slider(minimum=0.0,\n                                                         maximum=1.0,\n                                                         value=self.pipe.input.get(\"guide_rescale\", 0.5),\n                                                         visible=self.pipe.input.get(\"guide_rescale\", None) is not None,\n                                                         label='Rescale')\n                                self.refiner_scale = gr.Slider(minimum=-0.1,\n                                                         maximum=1.0,\n                                                         value=self.pipe.input.get(\"refiner_scale\", -1),\n                                                         visible=self.pipe.input.get(\"refiner_scale\", None) is not None,\n                                                         label='Refiner Scale')\n                                self.seed = gr.Slider(minimum=-1,\n                                                      maximum=10000000,\n                                                      value=-1,\n                                                      label='Seed')\n                                self.output_height = gr.Slider(\n                                    minimum=256,\n                                    maximum=1440,\n                                    value=self.pipe.input.get(\"output_height\", 1024),\n                                    visible=self.pipe.input.get(\"output_height\", None) is not None,\n                                    label='Output Height')\n                                self.output_width = gr.Slider(\n                                    minimum=256,\n                                    maximum=1440,\n                                    value=self.pipe.input.get(\"output_width\", 1024),\n                                    visible=self.pipe.input.get(\"output_width\", None) is not None,\n                                    label='Output Width')\n                        with gr.Column(scale=1, min_width=50):\n                            self.use_history = gr.Checkbox(value=False,\n                                                           label='Use History')\n                            self.use_ace = gr.Checkbox(value=self.pipe.input.get(\"use_ace\", True),\n                                                       visible=self.pipe.input.get(\"use_ace\", None) is not None,\n                                                       label='Use ACE')\n                            self.video_auto = gr.Checkbox(\n   
                             value=False,\n                                label='Auto Gen Video',\n                                visible=self.enable_i2v)\n\n                    with gr.Row(variant='panel',\n                                equal_height=True,\n                                visible=self.enable_i2v):\n                        self.video_fps = gr.Slider(minimum=1,\n                                                   maximum=16,\n                                                   value=8,\n                                                   label='Video FPS',\n                                                   visible=True)\n                        self.video_frames = gr.Slider(minimum=8,\n                                                      maximum=49,\n                                                      value=49,\n                                                      label='Video Frame Num',\n                                                      visible=True)\n                        self.video_step = gr.Slider(minimum=1,\n                                                    maximum=1000,\n                                                    value=50,\n                                                    label='Video Sample Step',\n                                                    visible=True)\n                        self.video_cfg_scale = gr.Slider(\n                            minimum=1.0,\n                            maximum=20.0,\n                            value=6.0,\n                            label='Video Guidance Scale',\n                            visible=True)\n                        self.video_seed = gr.Slider(minimum=-1,\n                                                    maximum=10000000,\n                                                    value=-1,\n                                                    label='Video Seed',\n                                                    visible=True)\n\n                with gr.Row():\n                    self.chatbot_inst = \"\"\"\n                       **Instruction**:\n\n                       1. Click 'Upload' button to upload one or more images as input images.\n                       2. Enter '@' in the text box will exhibit all images in the gallery.\n                       3. Select the image you wish to edit from the gallery, and its Image ID will be displayed in the text box.\n                       4. Compose the editing instruction for the selected image, incorporating image id '@xxxxxx' into your instruction.\n                       For example, you might say, \"Change the girl's skirt in @123456 to blue.\" The '@xxxxx' token will facilitate the identification of the specific image, and will be automatically replaced by a special token '{image}' in the instruction. Furthermore, it is also possible to engage in text-to-image generation without any initial image input.\n                       5. Once your instructions are prepared, please click the \"Chat\" button to view the edited result in the chat window.\n                       6. **Important** To render text on an image, please ensure to include a space between each letter. For instance, \"add text 'g i r l' on the mask area of @xxxxx\".\n                       7. To implement local editing based on a specified mask, simply click on the image within the chat window to access the image editor. Here, you can draw a mask and then click the 'Submit' button to upload the edited image along with the mask. 
For inpainting tasks, select the 'Composite' mask type, while for outpainting tasks, choose the 'Outpainting' mask type. For all other local editing tasks, please select the 'Background' mask type.\n                       8. If you find our work valuable, we invite you to refer to the [ACE Page](https://ali-vilab.github.io/ace-page/) for comprehensive information.\n\n                    \"\"\"\n\n                    self.legacy_inst = \"\"\"\n                       **Instruction**:\n\n                       1. You can edit the image by uploading it; if no image is uploaded, an image will be generated from text..\n                       2. Enter '@' in the text box will exhibit all images in the gallery.\n                       3. Select the image you wish to edit from the gallery, and its Image ID will be displayed in the text box.\n                       4. **Important** To render text on an image, please ensure to include a space between each letter. For instance, \"add text 'g i r l' on the mask area of @xxxxx\".\n                       5. To perform multi-step editing, partial editing, inpainting, outpainting, and other operations, please click the Chatbot Checkbox to enable the conversational editing mode and follow the relevant instructions..\n                       6. If you find our work valuable, we invite you to refer to the [ACE Page](https://ali-vilab.github.io/ace-page/) for comprehensive information.\n\n                    \"\"\"\n\n                    self.instruction = gr.Markdown(value=self.legacy_inst)\n\n                with gr.Row(variant='panel',\n                            equal_height=True,\n                            show_progress=False):\n                    with gr.Column(scale=1, min_width=100, visible=False) as self.upload_panel:\n                        self.upload_btn = gr.Button(value=upload_sty +\n                                                    ' Upload',\n                                                    variant='secondary')\n                    with gr.Column(scale=5, min_width=500):\n                        self.text = gr.Textbox(\n                            placeholder='Input \"@\" find history of image',\n                            label='Instruction',\n                            container=False)\n                    with gr.Column(scale=1, min_width=100):\n                        self.chat_btn = gr.Button(value='Generate',\n                                                  variant='primary')\n                    with gr.Column(scale=1, min_width=100):\n                        self.retry_btn = gr.Button(value=refresh_sty +\n                                                   ' Retry',\n                                                   variant='secondary')\n                    with gr.Column(scale=1, min_width=100):\n                        self.mode_checkbox = gr.Checkbox(\n                            value=False,\n                            label='ChatBot')\n                    with gr.Column(scale=(1 if self.enable_i2v else 0),\n                                   min_width=0):\n                        self.video_gen_btn = gr.Button(value=video_sty +\n                                                       ' Gen Video',\n                                                       variant='secondary',\n                                                       visible=self.enable_i2v)\n                    with gr.Column(scale=(1 if self.enable_i2v else 0),\n                                   min_width=0):\n                        self.extend_prompt = 
gr.Checkbox(\n                            value=True,\n                            label='Extend Prompt',\n                            visible=self.enable_i2v)\n\n                with gr.Row():\n                    self.gallery = gr.Gallery(visible=False,\n                                              label='History',\n                                              columns=10,\n                                              allow_preview=False,\n                                              interactive=False)\n\n                self.eg = gr.Column(visible=True)\n\n    def set_callbacks(self, *args, **kwargs):\n\n        ########################################\n        def change_model(model_name):\n            if model_name not in self.model_choices:\n                gr.Info('The provided model name is not a valid choice!')\n                return model_name, gr.update(), gr.update()\n\n            if model_name != self.model_name:\n                lock.acquire()\n                del self.pipe\n                torch.cuda.empty_cache()\n                torch.cuda.ipc_collect()\n                self.pipe = ACEInference()\n                self.pipe.init_from_cfg(self.model_choices[model_name])\n                self.model_name = model_name\n                lock.release()\n\n            return (model_name, gr.update(), gr.update(),\n                    gr.Slider(\n                              value=self.pipe.input.get(\"sample_steps\", 20),\n                              visible=self.pipe.input.get(\"sample_steps\", None) is not None),\n                    gr.Slider(\n                        value=self.pipe.input.get(\"guide_scale\", 4.5),\n                        visible=self.pipe.input.get(\"guide_scale\", None) is not None),\n                    gr.Slider(\n                              value=self.pipe.input.get(\"guide_rescale\", 0.5),\n                              visible=self.pipe.input.get(\"guide_rescale\", None) is not None),\n                    gr.Slider(\n                        value=self.pipe.input.get(\"output_height\", 1024),\n                        visible=self.pipe.input.get(\"output_height\", None) is not None),\n                    gr.Slider(\n                        value=self.pipe.input.get(\"output_width\", 1024),\n                        visible=self.pipe.input.get(\"output_width\", None) is not None),\n                    gr.Textbox(\n                        value=self.pipe.input.get(\"refiner_prompt\", \"\"),\n                        visible=self.pipe.input.get(\"refiner_prompt\", None) is not None),\n                    gr.Slider(\n                              value=self.pipe.input.get(\"refiner_scale\", -1),\n                              visible=self.pipe.input.get(\"refiner_scale\", None) is not None\n                        ),\n                    gr.Checkbox(\n                        value=self.pipe.input.get(\"use_ace\", True),\n                        visible=self.pipe.input.get(\"use_ace\", None) is not None\n                    )\n                    )\n\n        self.model_name_dd.change(\n            change_model,\n            inputs=[self.model_name_dd],\n            outputs=[\n                self.model_name_dd, self.chatbot, self.text,\n                self.step,\n                self.cfg_scale, self.rescale, self.output_height,\n                self.output_width, self.refiner_prompt, self.refiner_scale,\n                self.use_ace])\n\n\n        def mode_change(mode_check):\n            if mode_check:\n                # ChatBot\n                
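# Chatbot mode: hide the legacy single-image row, show the chat group and upload panel,\n                # and swap in the chatbot instructions.\n                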
return (\n                    gr.Row(visible=False),\n                    gr.Row(visible=True),\n                    gr.Button(value='Generate'),\n                    gr.State(value='chatbot'),\n                    gr.Column(visible=True),\n                    gr.Markdown(value=self.chatbot_inst)\n                )\n            else:\n                # Legacy\n                return (\n                    gr.Row(visible=True),\n                    gr.Row(visible=False),\n                    gr.Button(value=chat_sty + ' Chat'),\n                    gr.State(value='legacy'),\n                    gr.Column(visible=False),\n                    gr.Markdown(value=self.legacy_inst)\n                )\n        self.mode_checkbox.change(mode_change, inputs=[self.mode_checkbox],\n                                  outputs=[self.legacy_group, self.chat_group,\n                                           self.chat_btn, self.ui_mode,\n                                           self.upload_panel, self.instruction])\n\n\n        ########################################\n        def generate_gallery(text, images):\n            if text.endswith(' '):\n                return gr.update(), gr.update(visible=False)\n            elif text.endswith('@'):\n                gallery_info = []\n                for image_id, image_meta in images.items():\n                    thumbnail_path = image_meta['thumbnail']\n                    gallery_info.append((thumbnail_path, image_id))\n                return gr.update(), gr.update(visible=True, value=gallery_info)\n            else:\n                gallery_info = []\n                match = re.search('@([^@ ]+)$', text)\n                if match:\n                    prefix = match.group(1)\n                    for image_id, image_meta in images.items():\n                        if not image_id.startswith(prefix):\n                            continue\n                        thumbnail_path = image_meta['thumbnail']\n                        gallery_info.append((thumbnail_path, image_id))\n\n                    if len(gallery_info) > 0:\n                        return gr.update(), gr.update(visible=True,\n                                                      value=gallery_info)\n                    else:\n                        return gr.update(), gr.update(visible=False)\n                else:\n                    return gr.update(), gr.update(visible=False)\n\n        self.text.input(generate_gallery,\n                        inputs=[self.text, self.images],\n                        outputs=[self.text, self.gallery],\n                        show_progress='hidden')\n\n        ########################################\n        def select_image(text, evt: gr.SelectData):\n            image_id = evt.value['caption']\n            text = '@'.join(text.split('@')[:-1]) + f'@{image_id} '\n            return gr.update(value=text), gr.update(visible=False, value=None)\n\n        self.gallery.select(select_image,\n                            inputs=self.text,\n                            outputs=[self.text, self.gallery])\n\n        ########################################\n        def generate_video(message,\n                           extend_prompt,\n                           history,\n                           images,\n                           num_steps,\n                           num_frames,\n                           cfg_scale,\n                           fps,\n                           seed,\n                           progress=gr.Progress(track_tqdm=True)):\n\n      
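      # The first @image_id referenced in the message supplies the initial frame; the prompt is\n            # optionally rewritten by the enhancer before the CogVideoX image-to-video pipeline runs.\n      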
      from diffusers.utils import export_to_video\n\n            generator = torch.Generator(device='cuda').manual_seed(seed)\n            img_ids = re.findall('@(.*?)[ ,;.?$]', message)\n            if len(img_ids) == 0:\n                history.append((\n                    message,\n                    'Sorry, no images were found in the prompt to be used as the first frame of the video.'\n                ))\n                while len(history) >= self.max_msgs:\n                    history.pop(0)\n                return history, self.get_history(\n                    history), gr.update(), gr.update(visible=False)\n\n            img_id = img_ids[0]\n            prompt = re.sub(f'@{img_id}\\s+', '', message)\n\n            if extend_prompt:\n                messages = copy.deepcopy(self.enhance_ctx)\n                messages.append({\n                    'role':\n                    'user',\n                    'content':\n                    f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: \"{prompt}\"',\n                })\n                lock.acquire()\n                outputs = self.enhancer(\n                    messages,\n                    max_new_tokens=200,\n                )\n\n                prompt = outputs[0]['generated_text'][-1]['content']\n                print(prompt)\n                lock.release()\n\n            img_meta = images[img_id]\n            img_path = img_meta['image']\n            image = Image.open(img_path).convert('RGB')\n\n            lock.acquire()\n            video = self.i2v_pipe(\n                prompt=prompt,\n                image=image,\n                num_videos_per_prompt=1,\n                num_inference_steps=num_steps,\n                num_frames=num_frames,\n                guidance_scale=cfg_scale,\n                generator=generator,\n            ).frames[0]\n            lock.release()\n\n            out_video_path = export_to_video(video, fps=fps)\n            history.append((\n                f\"Based on first frame @{img_id} and description '{prompt}', generate a video\",\n                'This is generated video:'))\n            history.append((None, out_video_path))\n            while len(history) >= self.max_msgs:\n                history.pop(0)\n\n            return history, self.get_history(history), gr.update(\n                value=''), gr.update(visible=False)\n\n        self.video_gen_btn.click(\n            generate_video,\n            inputs=[\n                self.text, self.extend_prompt, self.history, self.images,\n                self.video_step, self.video_frames, self.video_cfg_scale,\n                self.video_fps, self.video_seed\n            ],\n            outputs=[self.history, self.chatbot, self.text, self.gallery])\n\n        ########################################\n        def run_chat(\n                     message,\n                     legacy_image,\n                     ui_mode,\n                     use_ace,\n                     extend_prompt,\n                     history,\n                     images,\n                     use_history,\n                     history_result,\n                     negative_prompt,\n                     cfg_scale,\n                     rescale,\n                     refiner_prompt,\n                     refiner_scale,\n                     step,\n                     seed,\n                     output_h,\n                     output_w,\n                     video_auto,\n                     video_steps,\n  
                    video_frames,\n                     video_cfg_scale,\n                     video_fps,\n                     video_seed,\n                     progress=gr.Progress(track_tqdm=True)):\n            legacy_img_ids = []\n            if ui_mode == 'legacy':\n                if legacy_image is not None:\n                    history, images, img_id = self.add_uploaded_image_to_history(\n                        legacy_image, history, images)\n                    legacy_img_ids.append(img_id)\n            retry_msg = message\n            gen_id = get_md5(message)[:12]\n            save_path = os.path.join(self.cache_dir, f'{gen_id}.png')\n\n            img_ids = re.findall('@(.*?)[ ,;.?$]', message)\n            history_io = None\n\n            if len(img_ids) < 1:\n                img_ids = legacy_img_ids\n                for img_id in img_ids:\n                    if f'@{img_id}' not in message:\n                        message = f'@{img_id} ' + message\n\n            new_message = message\n\n            if len(img_ids) > 0:\n                edit_image, edit_image_mask, edit_task = [], [], []\n                for i, img_id in enumerate(img_ids):\n                    if img_id not in images:\n                        gr.Info(\n                            f'The input image ID {img_id} does not exist... Skipping this image.'\n                        )\n                        continue\n                    placeholder = '{image}' if i == 0 else '{' + f'image{i}' + '}'\n                    new_message = re.sub(f'@{img_id}', placeholder,\n                                         new_message)\n                    img_meta = images[img_id]\n                    img_path = img_meta['image']\n                    img_mask = img_meta['mask']\n                    img_mask_type = img_meta['mask_type']\n                    if img_mask_type is not None and img_mask_type == 'Composite':\n                        task = 'inpainting'\n                    else:\n                        task = ''\n                    edit_image.append(Image.open(img_path).convert('RGB'))\n                    edit_image_mask.append(\n                        Image.open(img_mask).\n                        convert('L') if img_mask is not None else None)\n                    edit_task.append(task)\n\n                    if use_history and (img_id in history_result):\n                        history_io = history_result[img_id]\n\n                buffered = io.BytesIO()\n                edit_image[0].save(buffered, format='PNG')\n                img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')\n                img_str = f'<img src=\"data:image/png;base64,{img_b64}\" style=\"pointer-events: none;\">'\n                pre_info = f'Received one or more images, so image editing is conducted.\\n The first input image @{img_ids[0]} is:\\n {img_str}'\n            else:\n                pre_info = 'No image ids were found in the provided text prompt, so text-guided image generation is conducted. 
\\n'\n                edit_image = None\n                edit_image_mask = None\n                edit_task = ''\n\n            print(new_message)\n            imgs = self.pipe(\n                image=edit_image,\n                mask=edit_image_mask,\n                task=edit_task,\n                prompt=[new_message] *\n                len(edit_image) if edit_image is not None else [new_message],\n                negative_prompt=[negative_prompt] * len(edit_image)\n                if edit_image is not None else [negative_prompt],\n                history_io=history_io,\n                output_height=output_h,\n                output_width=output_w,\n                sampler='ddim',\n                sample_steps=step,\n                guide_scale=cfg_scale,\n                guide_rescale=rescale,\n                seed=seed,\n                refiner_prompt=refiner_prompt,\n                refiner_scale=refiner_scale,\n                use_ace=use_ace\n            )\n\n            img = imgs[0]\n            img.save(save_path, format='PNG')\n\n            if history_io:\n                history_io_new = copy.deepcopy(history_io)\n                history_io_new['image'] += edit_image[:1]\n                history_io_new['mask'] += edit_image_mask[:1]\n                history_io_new['task'] += edit_task[:1]\n                history_io_new['prompt'] += [new_message]\n                history_io_new['image'] = history_io_new['image'][-5:]\n                history_io_new['mask'] = history_io_new['mask'][-5:]\n                history_io_new['task'] = history_io_new['task'][-5:]\n                history_io_new['prompt'] = history_io_new['prompt'][-5:]\n                history_result[gen_id] = history_io_new\n            elif edit_image is not None and len(edit_image) > 0:\n                history_io_new = {\n                    'image': edit_image[:1],\n                    'mask': edit_image_mask[:1],\n                    'task': edit_task[:1],\n                    'prompt': [new_message]\n                }\n                history_result[gen_id] = history_io_new\n\n            w, h = img.size\n            if w > h:\n                tb_w = 128\n                tb_h = int(h * tb_w / w)\n            else:\n                tb_h = 128\n                tb_w = int(w * tb_h / h)\n\n            thumbnail_path = os.path.join(self.cache_dir,\n                                          f'{gen_id}_thumbnail.jpg')\n            thumbnail = img.resize((tb_w, tb_h))\n            thumbnail.save(thumbnail_path, format='JPEG')\n\n            images[gen_id] = {\n                'image': save_path,\n                'mask': None,\n                'mask_type': None,\n                'thumbnail': thumbnail_path\n            }\n\n            buffered = io.BytesIO()\n            img.convert('RGB').save(buffered, format='PNG')\n            img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')\n            img_str = f'<img src=\"data:image/png;base64,{img_b64}\" style=\"pointer-events: none;\">'\n\n            history.append(\n                (message,\n                 f'{pre_info} The generated image @{gen_id} is:\\n {img_str}'))\n\n            if video_auto:\n                if video_seed is None or video_seed == -1:\n                    video_seed = random.randint(0, 10000000)\n\n                lock.acquire()\n                generator = torch.Generator(\n                    device='cuda').manual_seed(video_seed)\n                pixel_values = load_image(img.convert('RGB'),\n                                       
   max_num=self.llm_max_num).to(\n                                              torch.bfloat16).cuda()\n                prompt = self.captioner.chat(self.llm_tokenizer, pixel_values,\n                                             self.llm_prompt,\n                                             self.llm_generation_config)\n                print(prompt)\n                lock.release()\n\n                if extend_prompt:\n                    messages = copy.deepcopy(self.enhance_ctx)\n                    messages.append({\n                        'role':\n                        'user',\n                        'content':\n                        f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: \"{prompt}\"',\n                    })\n                    lock.acquire()\n                    outputs = self.enhancer(\n                        messages,\n                        max_new_tokens=200,\n                    )\n                    prompt = outputs[0]['generated_text'][-1]['content']\n                    print(prompt)\n                    lock.release()\n\n                lock.acquire()\n                video = self.i2v_pipe(\n                    prompt=prompt,\n                    image=img,\n                    num_videos_per_prompt=1,\n                    num_inference_steps=video_steps,\n                    num_frames=video_frames,\n                    guidance_scale=video_cfg_scale,\n                    generator=generator,\n                ).frames[0]\n                lock.release()\n\n                out_video_path = export_to_video(video, fps=video_fps)\n                history.append((\n                    f\"Based on first frame @{gen_id} and description '{prompt}', generate a video\",\n                    'This is generated video:'))\n                history.append((None, out_video_path))\n\n            while len(history) >= self.max_msgs:\n                history.pop(0)\n\n            return (history, images, gr.Image(value=save_path),\n                    history_result, self.get_history(\n                history), gr.update(), gr.update(\n                    visible=False), retry_msg)\n\n        chat_inputs = [\n            self.legacy_image_uploader, self.ui_mode, self.use_ace,\n            self.extend_prompt, self.history, self.images, self.use_history,\n            self.history_result, self.negative_prompt, self.cfg_scale,\n            self.rescale, self.refiner_prompt, self.refiner_scale,\n            self.step, self.seed, self.output_height,\n            self.output_width, self.video_auto, self.video_step,\n            self.video_frames, self.video_cfg_scale, self.video_fps,\n            self.video_seed\n        ]\n\n        chat_outputs = [\n            self.history, self.images, self.legacy_image_viewer,\n            self.history_result, self.chatbot,\n            self.text, self.gallery, self.retry_msg\n        ]\n\n        self.chat_btn.click(run_chat,\n                            inputs=[self.text] + chat_inputs,\n                            outputs=chat_outputs)\n\n        self.text.submit(run_chat,\n                         inputs=[self.text] + chat_inputs,\n                         outputs=chat_outputs)\n\n        def retry_fn(*args):\n            return run_chat(*args)\n\n        self.retry_btn.click(retry_fn,\n                             inputs=[self.retry_msg] + chat_inputs,\n                             outputs=chat_outputs)\n\n        ########################################\n        def 
run_example(task, img, img_mask, ref1, prompt, seed):\n            edit_image, edit_image_mask, edit_task = [], [], []\n            if img is not None:\n                w, h = img.size\n                if w > 2048:\n                    ratio = w / 2048.\n                    w = 2048\n                    h = int(h / ratio)\n                if h > 2048:\n                    ratio = h / 2048.\n                    h = 2048\n                    w = int(w / ratio)\n                img = img.resize((w, h))\n                edit_image.append(img)\n                if img_mask is not None:\n                    img_mask = img_mask if np.sum(np.array(img_mask)) > 0 else None\n                edit_image_mask.append(\n                    img_mask if img_mask is not None else None)\n                edit_task.append(task)\n                if ref1 is not None:\n                    ref1 = ref1 if np.sum(np.array(ref1)) > 0 else None\n                if ref1 is not None:\n                    edit_image.append(ref1)\n                    edit_image_mask.append(None)\n                    edit_task.append('')\n\n                buffered = io.BytesIO()\n                img.save(buffered, format='PNG')\n                img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')\n                img_str = f'<img src=\"data:image/png;base64,{img_b64}\" style=\"pointer-events: none;\">'\n                pre_info = f'Received one or more images, so image editing is conducted.\\n The first input image is:\\n {img_str}'\n            else:\n                pre_info = 'No image ids were found in the provided text prompt, so text-guided image generation is conducted. \\n'\n                edit_image = None\n                edit_image_mask = None\n                edit_task = ''\n\n            img_num = len(edit_image) if edit_image is not None else 1\n            imgs = self.pipe(\n                image=edit_image,\n                mask=edit_image_mask,\n                task=edit_task,\n                prompt=[prompt] * img_num,\n                negative_prompt=[''] * img_num,\n                seed=seed,\n                refiner_prompt=self.pipe.input.get(\"refiner_prompt\", \"\"),\n                refiner_scale=self.pipe.input.get(\"refiner_scale\", 0.0),\n            )\n\n            img = imgs[0]\n            buffered = io.BytesIO()\n            img.convert('RGB').save(buffered, format='PNG')\n            img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')\n            img_str = f'<img src=\"data:image/png;base64,{img_b64}\" style=\"pointer-events: none;\">'\n            history = [(prompt,\n                        f'{pre_info} The generated image is:\\n {img_str}')]\n\n            img_id = get_md5(img_b64)[:12]\n            save_path = os.path.join(self.cache_dir, f'{img_id}.png')\n            img.convert('RGB').save(save_path)\n            return self.get_history(history), gr.update(value=''), gr.update(\n                visible=False),  gr.update(value=save_path), gr.update(value=-1)\n\n        with self.eg:\n            self.example_task = gr.Text(label='Task Name',\n                                        value='',\n                                        visible=False)\n            self.example_image = gr.Image(label='Edit Image',\n                                          type='pil',\n                                          image_mode='RGB',\n                                          visible=False)\n            self.example_mask = gr.Image(label='Edit Image Mask',\n                                
         type='pil',\n                                         image_mode='L',\n                                         visible=False)\n            self.example_ref_im1 = gr.Image(label='Ref Image',\n                                            type='pil',\n                                            image_mode='RGB',\n                                            visible=False)\n            self.examples = gr.Examples(\n                fn=run_example,\n                examples=self.chatbot_examples,\n                inputs=[\n                    self.example_task, self.example_image, self.example_mask,\n                    self.example_ref_im1, self.text, self.seed\n                ],\n                outputs=[self.chatbot, self.text, self.gallery, self.legacy_image_viewer, self.seed],\n                examples_per_page=4,\n                cache_examples=False,\n                run_on_click=True)\n\n        ########################################\n        def upload_image():\n            return (gr.update(visible=True,\n                              scale=1), gr.update(visible=True, scale=1),\n                    gr.update(visible=True), gr.update(visible=False),\n                    gr.update(visible=False), gr.update(visible=False),\n                    gr.update(visible=True))\n\n        self.upload_btn.click(upload_image,\n                              inputs=[],\n                              outputs=[\n                                  self.chat_page, self.editor_page,\n                                  self.upload_tab, self.edit_tab,\n                                  self.image_view_tab, self.video_view_tab,\n                                  self.upload_tabs\n                              ])\n\n        ########################################\n        def edit_image(evt: gr.SelectData):\n            if isinstance(evt.value, str):\n                img_b64s = re.findall(\n                    '<img src=\"data:image/png;base64,(.*?)\" style=\"pointer-events: none;\">',\n                    evt.value)\n                imgs = [\n                    Image.open(io.BytesIO(base64.b64decode(copy.deepcopy(i))))\n                    for i in img_b64s\n                ]\n                if len(imgs) > 0:\n                    if len(imgs) == 2:\n                        if self.gradio_version >= '5.0.0':\n                            view_img = copy.deepcopy(imgs[-1])\n                        else:\n                            view_img = copy.deepcopy(imgs)\n                        edit_img = copy.deepcopy(imgs[-1])\n                    else:\n                        if self.gradio_version >= '5.0.0':\n                            view_img = copy.deepcopy(imgs[-1])\n                        else:\n                            view_img = [\n                                copy.deepcopy(imgs[-1]),\n                                copy.deepcopy(imgs[-1])\n                            ]\n                        edit_img = copy.deepcopy(imgs[-1])\n\n                    return (gr.update(visible=True,\n                                      scale=1), gr.update(visible=True,\n                                                          scale=1),\n                            gr.update(visible=False), gr.update(visible=True),\n                            gr.update(visible=True), gr.update(visible=False),\n                            gr.update(value=edit_img),\n                            gr.update(value=view_img), gr.update(value=None),\n                            gr.update(visible=True))\n                
else:\n                    return (gr.update(), gr.update(), gr.update(), gr.update(),\n                            gr.update(), gr.update(), gr.update(), gr.update(),\n                            gr.update(), gr.update())\n            elif isinstance(evt.value, dict) and evt.value.get(\n                    'component', '') == 'video':\n                value = evt.value['value']['video']['path']\n                return (gr.update(visible=True,\n                                  scale=1), gr.update(visible=True, scale=1),\n                        gr.update(visible=False), gr.update(visible=False),\n                        gr.update(visible=False), gr.update(visible=True),\n                        gr.update(), gr.update(), gr.update(value=value),\n                        gr.update())\n            else:\n                return (gr.update(), gr.update(), gr.update(), gr.update(),\n                        gr.update(), gr.update(), gr.update(), gr.update(),\n                        gr.update(), gr.update())\n\n        self.chatbot.select(edit_image,\n                            outputs=[\n                                self.chat_page, self.editor_page,\n                                self.upload_tab, self.edit_tab,\n                                self.image_view_tab, self.video_view_tab,\n                                self.image_editor, self.image_viewer,\n                                self.video_viewer, self.edit_tabs\n                            ])\n\n        if self.gradio_version < '5.0.0':\n            self.image_viewer.change(lambda x: x,\n                                     inputs=self.image_viewer,\n                                     outputs=self.image_viewer)\n\n        ########################################\n        def submit_upload_image(image, history, images):\n            history, images, _ = self.add_uploaded_image_to_history(\n                image, history, images)\n            return gr.update(visible=False), gr.update(\n                visible=True), gr.update(\n                    value=self.get_history(history)), history, images\n\n        self.sub_btn_1.click(\n            submit_upload_image,\n            inputs=[self.image_uploader, self.history, self.images],\n            outputs=[\n                self.editor_page, self.chat_page, self.chatbot, self.history,\n                self.images\n            ])\n\n        ########################################\n        def submit_edit_image(imagemask, mask_type, history, images):\n            history, images = self.add_edited_image_to_history(\n                imagemask, mask_type, history, images)\n            return gr.update(visible=False), gr.update(\n                visible=True), gr.update(\n                    value=self.get_history(history)), history, images\n\n        self.sub_btn_2.click(submit_edit_image,\n                             inputs=[\n                                 self.image_editor, self.mask_type,\n                                 self.history, self.images\n                             ],\n                             outputs=[\n                                 self.editor_page, self.chat_page,\n                                 self.chatbot, self.history, self.images\n                             ])\n\n        ########################################\n        def exit_edit():\n            return gr.update(visible=False), gr.update(visible=True, scale=3)\n\n        self.ext_btn_1.click(exit_edit,\n                             outputs=[self.editor_page, self.chat_page])\n        
self.ext_btn_2.click(exit_edit,\n                             outputs=[self.editor_page, self.chat_page])\n        self.ext_btn_3.click(exit_edit,\n                             outputs=[self.editor_page, self.chat_page])\n        self.ext_btn_4.click(exit_edit,\n                             outputs=[self.editor_page, self.chat_page])\n\n        ########################################\n        def update_mask_type_info(mask_type):\n            if mask_type == 'Background':\n                info = 'Background mode will not erase the visual content in the mask area'\n                visible = False\n            elif mask_type == 'Composite':\n                info = 'Composite mode will erase the visual content in the mask area'\n                visible = False\n            elif mask_type == 'Outpainting':\n                info = 'Outpaint mode is used for preparing input image for outpainting task'\n                visible = True\n            return (gr.update(\n                visible=True,\n                value=\n                f\"<div style='background-color: white; padding-left: 15px; color: grey;'>{info}</div>\"\n            ), gr.update(visible=visible))\n\n        self.mask_type.change(update_mask_type_info,\n                              inputs=self.mask_type,\n                              outputs=[self.mask_type_info, self.outpaint_tab])\n\n        ########################################\n        def extend_image(top_ratio, bottom_ratio, left_ratio, right_ratio,\n                         image):\n            img = cv2.cvtColor(image['background'], cv2.COLOR_RGBA2RGB)\n            h, w = img.shape[:2]\n            new_h = int(h * (top_ratio + bottom_ratio + 1))\n            new_w = int(w * (left_ratio + right_ratio + 1))\n            start_h = int(h * top_ratio)\n            start_w = int(w * left_ratio)\n            new_img = np.zeros((new_h, new_w, 3), dtype=np.uint8)\n            new_mask = np.ones((new_h, new_w, 1), dtype=np.uint8) * 255\n            new_img[start_h:start_h + h, start_w:start_w + w, :] = img\n            new_mask[start_h:start_h + h, start_w:start_w + w] = 0\n            layer = np.concatenate([new_img, new_mask], axis=2)\n            value = {\n                'background': new_img,\n                'composite': new_img,\n                'layers': [layer]\n            }\n            return gr.update(value=value)\n\n        self.img_pad_btn.click(extend_image,\n                               inputs=[\n                                   self.top_ext, self.bottom_ext,\n                                   self.left_ext, self.right_ext,\n                                   self.image_editor\n                               ],\n                               outputs=self.image_editor)\n\n        ########################################\n        def clear_chat(history, images, history_result):\n            history.clear()\n            images.clear()\n            history_result.clear()\n            return history, images, history_result, self.get_history(history)\n\n        self.clear_btn.click(\n            clear_chat,\n            inputs=[self.history, self.images, self.history_result],\n            outputs=[\n                self.history, self.images, self.history_result, self.chatbot\n            ])\n\n    def get_history(self, history):\n        info = []\n        for item in history:\n            new_item = [None, None]\n            if isinstance(item[0], str) and item[0].endswith('.mp4'):\n                new_item[0] = gr.Video(item[0], format='mp4')\n           
 else:\n                new_item[0] = item[0]\n            if isinstance(item[1], str) and item[1].endswith('.mp4'):\n                new_item[1] = gr.Video(item[1], format='mp4')\n            else:\n                new_item[1] = item[1]\n            info.append(new_item)\n        return info\n\n    def generate_random_string(self, length=20):\n        letters_and_digits = string.ascii_letters + string.digits\n        random_string = ''.join(\n            random.choice(letters_and_digits) for i in range(length))\n        return random_string\n\n    def add_edited_image_to_history(self, image, mask_type, history, images):\n        if mask_type == 'Composite':\n            img = Image.fromarray(image['composite'])\n        else:\n            img = Image.fromarray(image['background'])\n\n        img_id = get_md5(self.generate_random_string())[:12]\n        save_path = os.path.join(self.cache_dir, f'{img_id}.png')\n        img.convert('RGB').save(save_path)\n\n        mask = image['layers'][0][:, :, 3]\n        mask = Image.fromarray(mask).convert('RGB')\n        mask_path = os.path.join(self.cache_dir, f'{img_id}_mask.png')\n        mask.save(mask_path)\n\n        w, h = img.size\n        if w > h:\n            tb_w = 128\n            tb_h = int(h * tb_w / w)\n        else:\n            tb_h = 128\n            tb_w = int(w * tb_h / h)\n\n        if mask_type == 'Background':\n            comp_mask = np.array(mask, dtype=np.uint8)\n            mask_alpha = (comp_mask[:, :, 0:1].astype(np.float32) *\n                          0.6).astype(np.uint8)\n            comp_mask = np.concatenate([comp_mask, mask_alpha], axis=2)\n            thumbnail = Image.alpha_composite(\n                img.convert('RGBA'),\n                Image.fromarray(comp_mask).convert('RGBA')).convert('RGB')\n        else:\n            thumbnail = img.convert('RGB')\n\n        thumbnail_path = os.path.join(self.cache_dir,\n                                      f'{img_id}_thumbnail.jpg')\n        thumbnail = thumbnail.resize((tb_w, tb_h))\n        thumbnail.save(thumbnail_path, format='JPEG')\n\n        buffered = io.BytesIO()\n        img.convert('RGB').save(buffered, format='PNG')\n        img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')\n        img_str = f'<img src=\"data:image/png;base64,{img_b64}\" style=\"pointer-events: none;\">'\n\n        buffered = io.BytesIO()\n        mask.convert('RGB').save(buffered, format='PNG')\n        mask_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')\n        mask_str = f'<img src=\"data:image/png;base64,{mask_b64}\" style=\"pointer-events: none;\">'\n\n        images[img_id] = {\n            'image': save_path,\n            'mask': mask_path,\n            'mask_type': mask_type,\n            'thumbnail': thumbnail_path\n        }\n        history.append((\n            None,\n            f'This is edited image and mask:\\n {img_str} {mask_str} image ID is: {img_id}'\n        ))\n        return history, images\n\n    def add_uploaded_image_to_history(self, img, history, images):\n        img_id = get_md5(self.generate_random_string())[:12]\n        save_path = os.path.join(self.cache_dir, f'{img_id}.png')\n        w, h = img.size\n        if w > 2048:\n            ratio = w / 2048.\n            w = 2048\n            h = int(h / ratio)\n        if h > 2048:\n            ratio = h / 2048.\n            h = 2048\n            w = int(w / ratio)\n        img = img.resize((w, h))\n        img.save(save_path)\n\n        w, h = img.size\n        if w > h:\n            
tb_w = 128\n            tb_h = int(h * tb_w / w)\n        else:\n            tb_h = 128\n            tb_w = int(w * tb_h / h)\n        thumbnail_path = os.path.join(self.cache_dir,\n                                      f'{img_id}_thumbnail.jpg')\n        thumbnail = img.resize((tb_w, tb_h))\n        thumbnail.save(thumbnail_path, format='JPEG')\n\n        images[img_id] = {\n            'image': save_path,\n            'mask': None,\n            'mask_type': None,\n            'thumbnail': thumbnail_path\n        }\n\n        buffered = io.BytesIO()\n        img.convert('RGB').save(buffered, format='PNG')\n        img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')\n        img_str = f'<img src=\"data:image/png;base64,{img_b64}\" style=\"pointer-events: none;\">'\n\n        history.append(\n            (None,\n             f'This is the uploaded image:\\n {img_str} image ID is: {img_id}'))\n        return history, images, img_id\n\n\ndef run_gr(cfg):\n    with gr.Blocks() as demo:\n        chatbot = ChatBotUI(cfg)\n        chatbot.create_ui()\n        chatbot.set_callbacks()\n        demo.launch(server_name='0.0.0.0',\n                    server_port=cfg.args.server_port,\n                    root_path=cfg.args.root_path)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='Argparser for Scepter:\\n')\n    parser.add_argument('--server_port',\n                        dest='server_port',\n                        help='Port for the Gradio server.',\n                        type=int,\n                        default=2345)\n    parser.add_argument('--root_path',\n                        dest='root_path',\n                        help='Root path prefix when the app is served behind a reverse proxy.',\n                        default='')\n    cfg = Config(load=True, parser_ins=parser)\n    run_gr(cfg)\n"
  },
  {
    "path": "chatbot/utils.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (c) Alibaba, Inc. and its affiliates.\nimport torch\nimport torchvision.transforms as T\nfrom PIL import Image\nfrom torchvision.transforms.functional import InterpolationMode\n\nIMAGENET_MEAN = (0.485, 0.456, 0.406)\nIMAGENET_STD = (0.229, 0.224, 0.225)\n\n\ndef build_transform(input_size):\n    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD\n    transform = T.Compose([\n        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),\n        T.Resize((input_size, input_size),\n                 interpolation=InterpolationMode.BICUBIC),\n        T.ToTensor(),\n        T.Normalize(mean=MEAN, std=STD)\n    ])\n    return transform\n\n\ndef find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,\n                              image_size):\n    best_ratio_diff = float('inf')\n    best_ratio = (1, 1)\n    area = width * height\n    for ratio in target_ratios:\n        target_aspect_ratio = ratio[0] / ratio[1]\n        ratio_diff = abs(aspect_ratio - target_aspect_ratio)\n        if ratio_diff < best_ratio_diff:\n            best_ratio_diff = ratio_diff\n            best_ratio = ratio\n        elif ratio_diff == best_ratio_diff:\n            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:\n                best_ratio = ratio\n    return best_ratio\n\n\ndef dynamic_preprocess(image,\n                       min_num=1,\n                       max_num=12,\n                       image_size=448,\n                       use_thumbnail=False):\n    orig_width, orig_height = image.size\n    aspect_ratio = orig_width / orig_height\n\n    # calculate the existing image aspect ratio\n    target_ratios = set((i, j) for n in range(min_num, max_num + 1)\n                        for i in range(1, n + 1) for j in range(1, n + 1)\n                        if i * j <= max_num and i * j >= min_num)\n    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])\n\n    # find the closest aspect ratio to the target\n    target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio,\n                                                    target_ratios, orig_width,\n                                                    orig_height, image_size)\n\n    # calculate the target width and height\n    target_width = image_size * target_aspect_ratio[0]\n    target_height = image_size * target_aspect_ratio[1]\n    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]\n\n    # resize the image\n    resized_img = image.resize((target_width, target_height))\n    processed_images = []\n    for i in range(blocks):\n        box = ((i % (target_width // image_size)) * image_size,\n               (i // (target_width // image_size)) * image_size,\n               ((i % (target_width // image_size)) + 1) * image_size,\n               ((i // (target_width // image_size)) + 1) * image_size)\n        # split the image\n        split_img = resized_img.crop(box)\n        processed_images.append(split_img)\n    assert len(processed_images) == blocks\n    if use_thumbnail and len(processed_images) != 1:\n        thumbnail_img = image.resize((image_size, image_size))\n        processed_images.append(thumbnail_img)\n    return processed_images\n\n\ndef load_image(image_file, input_size=448, max_num=12):\n    if isinstance(image_file, str):\n        image = Image.open(image_file).convert('RGB')\n    else:\n        image = image_file\n    transform = build_transform(input_size=input_size)\n    images = dynamic_preprocess(image,\n                               
 image_size=input_size,\n                                use_thumbnail=True,\n                                max_num=max_num)\n    pixel_values = [transform(image) for image in images]\n    pixel_values = torch.stack(pixel_values)\n    return pixel_values\n"
  },
  {
    "path": "config/inference_config/chatbot_ui.yaml",
    "content": "WORK_DIR: ./cache/chatbot\nFILE_SYSTEM:\n  - NAME: \"HuggingfaceFs\"\n    TEMP_DIR: ./cache\n  - NAME: \"ModelscopeFs\"\n    TEMP_DIR: ./cache\n  - NAME: \"LocalFs\"\n    TEMP_DIR: ./cache\n  - NAME: \"HttpFs\"\n    TEMP_DIR: ./cache\n#\nENABLE_I2V: False\nSKIP_EXAMPLES: False\n#\nMODEL:\n  EDIT_MODEL:\n    MODEL_CFG_DIR: config/inference_config/models/\n    DEFAULT: ace_0.6b_512\n  I2V:\n    MODEL_NAME: CogVideoX-5b-I2V\n    MODEL_DIR: ms://ZhipuAI/CogVideoX-5b-I2V/\n  CAPTIONER:\n    MODEL_NAME: InternVL2-2B\n    MODEL_DIR: ms://OpenGVLab/InternVL2-2B/\n    PROMPT: '<image>\\nThis image is the first frame of a video. Based on this image, please imagine what changes may occur in the next few seconds of the video. Please output brief description, such as \"a dog running\" or \"a person turns to left\". No more than 30 words.'\n  ENHANCER:\n    MODEL_NAME: Meta-Llama-3.1-8B-Instruct\n    MODEL_DIR: ms://LLM-Research/Meta-Llama-3.1-8B-Instruct/\n"
  },
  {
    "path": "config/inference_config/models/ace_0.6b_1024.yaml",
    "content": "NAME: ACE_0.6B_1024\nIS_DEFAULT: False\nUSE_DYNAMIC_MODEL: False\nDEFAULT_PARAS:\n  PARAS:\n  #\n  INPUT:\n    INPUT_IMAGE:\n    INPUT_MASK:\n    TASK:\n    PROMPT: \"\"\n    NEGATIVE_PROMPT: \"\"\n    OUTPUT_HEIGHT: 1024\n    OUTPUT_WIDTH: 1024\n    SAMPLER: ddim\n    SAMPLE_STEPS: 50\n    GUIDE_SCALE: 4.5\n    GUIDE_RESCALE: 0.5\n    SEED: -1\n    TAR_INDEX: 0\n  OUTPUT:\n    LATENT:\n    IMAGES:\n    SEED:\n  MODULES_PARAS:\n    FIRST_STAGE_MODEL:\n      FUNCTION:\n        - NAME: encode\n          DTYPE: float16\n          INPUT: [\"IMAGE\"]\n        - NAME: decode\n          DTYPE: float16\n          INPUT: [\"LATENT\"]\n    #\n    DIFFUSION_MODEL:\n      FUNCTION:\n        - NAME: forward\n          DTYPE: float16\n          INPUT: [\"SAMPLE_STEPS\", \"SAMPLE\", \"GUIDE_SCALE\"]\n    #\n    COND_STAGE_MODEL:\n      FUNCTION:\n        - NAME: encode_list\n          DTYPE: bfloat16\n          INPUT: [\"PROMPT\"]\n#\nMODEL:\n  NAME: LdmACE\n  PRETRAINED_MODEL:\n  IGNORE_KEYS: [ ]\n  SCALE_FACTOR: 0.18215\n  SIZE_FACTOR: 8\n  DECODER_BIAS: 0.5\n  DEFAULT_N_PROMPT: \"\"\n  TEXT_IDENTIFIER: [ '{image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]\n  USE_TEXT_POS_EMBEDDINGS: True\n  #\n  DIFFUSION:\n    NAME: ACEDiffusion\n    PREDICTION_TYPE: eps\n    MIN_SNR_GAMMA:\n    NOISE_SCHEDULER:\n      NAME: LinearScheduler\n      NUM_TIMESTEPS: 1000\n      BETA_MIN: 0.0001\n      BETA_MAX: 0.02\n  #\n  DIFFUSION_MODEL:\n    NAME: DiTACE\n    PRETRAINED_MODEL: ms://iic/ACE-0.6B-1024px@models/dit/ace_0.6b_1024px.pth\n    IGNORE_KEYS: [ ]\n    PATCH_SIZE: 2\n    IN_CHANNELS: 4\n    HIDDEN_SIZE: 1152\n    DEPTH: 28\n    NUM_HEADS: 16\n    MLP_RATIO: 4.0\n    PRED_SIGMA: True\n    DROP_PATH: 0.0\n    WINDOW_DIZE: 0\n    Y_CHANNELS: 4096\n    MAX_SEQ_LEN: 4096\n    QK_NORM: True\n    USE_GRAD_CHECKPOINT: True\n    ATTENTION_BACKEND: flash_attn\n  #\n  FIRST_STAGE_MODEL:\n    NAME: AutoencoderKL\n    EMBED_DIM: 4\n    PRETRAINED_MODEL: ms://iic/ACE-0.6B-1024px@models/vae/vae.bin\n    IGNORE_KEYS: []\n    #\n    ENCODER:\n      NAME: Encoder\n      CH: 128\n      OUT_CH: 3\n      NUM_RES_BLOCKS: 2\n      IN_CHANNELS: 3\n      ATTN_RESOLUTIONS: [ ]\n      CH_MULT: [ 1, 2, 4, 4 ]\n      Z_CHANNELS: 4\n      DOUBLE_Z: True\n      DROPOUT: 0.0\n      RESAMP_WITH_CONV: True\n    #\n    DECODER:\n      NAME: Decoder\n      CH: 128\n      OUT_CH: 3\n      NUM_RES_BLOCKS: 2\n      IN_CHANNELS: 3\n      ATTN_RESOLUTIONS: [ ]\n      CH_MULT: [ 1, 2, 4, 4 ]\n      Z_CHANNELS: 4\n      DROPOUT: 0.0\n      RESAMP_WITH_CONV: True\n      GIVE_PRE_END: False\n      TANH_OUT: False\n  #\n  COND_STAGE_MODEL:\n    NAME: ACETextEmbedder\n    PRETRAINED_MODEL: ms://iic/ACE-0.6B-1024px@models/text_encoder/t5-v1_1-xxl/\n    TOKENIZER_PATH: ms://iic/ACE-0.6B-1024px@models/tokenizer/t5-v1_1-xxl\n    LENGTH: 120\n    T5_DTYPE: bfloat16\n    ADDED_IDENTIFIER: [ '{image}', '{caption}', '{mask}', '{ref_image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]\n    CLEAN: whitespace\n    USE_GRAD: False\n"
  },
  {
    "path": "config/inference_config/models/ace_0.6b_1024_refiner.yaml",
    "content": "NAME: ACE_0.6B_1024_REFINER\nIS_DEFAULT: False\nUSE_DYNAMIC_MODEL: False\nDEFAULT_PARAS:\n  PARAS:\n  #\n  INPUT:\n    INPUT_IMAGE:\n    INPUT_MASK:\n    TASK:\n    PROMPT: \"\"\n    NEGATIVE_PROMPT: \"\"\n    OUTPUT_HEIGHT: 1024\n    OUTPUT_WIDTH: 1024\n    SAMPLER: ddim\n    SAMPLE_STEPS: 50\n    GUIDE_SCALE: 4.5\n    GUIDE_RESCALE: 0.5\n    SEED: -1\n    TAR_INDEX: 0\n    REFINER_SCALE: 0.2\n    USE_ACE: True\n    #REFINER_PROMPT: \"High Resolution, Sharpness, Clarity, Detail Enhancement, Noise Reduction, HD, 4k, Image Restoration, HDR\"\n    REFINER_PROMPT: \"High Resolution, Sharpness, Clarity, Detail Enhancement, Noise Reduction, HD, 4k, Image Restoration, HDR\"\n  OUTPUT:\n    LATENT:\n    IMAGES:\n    SEED:\n  MODULES_PARAS:\n    FIRST_STAGE_MODEL:\n      FUNCTION:\n        - NAME: encode\n          DTYPE: float16\n          INPUT: [\"IMAGE\"]\n        - NAME: decode\n          DTYPE: float16\n          INPUT: [\"LATENT\"]\n    #\n    DIFFUSION_MODEL:\n      FUNCTION:\n        - NAME: forward\n          DTYPE: float16\n          INPUT: [\"SAMPLE_STEPS\", \"SAMPLE\", \"GUIDE_SCALE\"]\n    #\n    COND_STAGE_MODEL:\n      FUNCTION:\n        - NAME: encode_list\n          DTYPE: bfloat16\n          INPUT: [\"PROMPT\"]\n#\nMODEL:\n  NAME: LdmACE\n  PRETRAINED_MODEL:\n  IGNORE_KEYS: [ ]\n  SCALE_FACTOR: 0.18215\n  SIZE_FACTOR: 8\n  DECODER_BIAS: 0.5\n  DEFAULT_N_PROMPT: \"\"\n  TEXT_IDENTIFIER: [ '{image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]\n  USE_TEXT_POS_EMBEDDINGS: True\n  #\n  DIFFUSION:\n    NAME: ACEDiffusion\n    PREDICTION_TYPE: eps\n    MIN_SNR_GAMMA:\n    NOISE_SCHEDULER:\n      NAME: LinearScheduler\n      NUM_TIMESTEPS: 1000\n      BETA_MIN: 0.0001\n      BETA_MAX: 0.02\n  #\n  DIFFUSION_MODEL:\n    NAME: DiTACE\n    PRETRAINED_MODEL: ms://iic/ACE-0.6B-1024px@models/dit/ace_0.6b_1024px.pth\n    IGNORE_KEYS: [ ]\n    PATCH_SIZE: 2\n    IN_CHANNELS: 4\n    HIDDEN_SIZE: 1152\n    DEPTH: 28\n    NUM_HEADS: 16\n    MLP_RATIO: 4.0\n    PRED_SIGMA: True\n    DROP_PATH: 0.0\n    WINDOW_DIZE: 0\n    Y_CHANNELS: 4096\n    MAX_SEQ_LEN: 4096\n    QK_NORM: True\n    USE_GRAD_CHECKPOINT: True\n    ATTENTION_BACKEND: flash_attn\n  #\n  FIRST_STAGE_MODEL:\n    NAME: AutoencoderKL\n    EMBED_DIM: 4\n    PRETRAINED_MODEL: ms://iic/ACE-0.6B-1024px@models/vae/vae.bin\n    IGNORE_KEYS: []\n    #\n    ENCODER:\n      NAME: Encoder\n      CH: 128\n      OUT_CH: 3\n      NUM_RES_BLOCKS: 2\n      IN_CHANNELS: 3\n      ATTN_RESOLUTIONS: [ ]\n      CH_MULT: [ 1, 2, 4, 4 ]\n      Z_CHANNELS: 4\n      DOUBLE_Z: True\n      DROPOUT: 0.0\n      RESAMP_WITH_CONV: True\n    #\n    DECODER:\n      NAME: Decoder\n      CH: 128\n      OUT_CH: 3\n      NUM_RES_BLOCKS: 2\n      IN_CHANNELS: 3\n      ATTN_RESOLUTIONS: [ ]\n      CH_MULT: [ 1, 2, 4, 4 ]\n      Z_CHANNELS: 4\n      DROPOUT: 0.0\n      RESAMP_WITH_CONV: True\n      GIVE_PRE_END: False\n      TANH_OUT: False\n  #\n  COND_STAGE_MODEL:\n    NAME: ACETextEmbedder\n    PRETRAINED_MODEL: ms://iic/ACE-0.6B-1024px@models/text_encoder/t5-v1_1-xxl/\n    TOKENIZER_PATH: ms://iic/ACE-0.6B-1024px@models/tokenizer/t5-v1_1-xxl\n    LENGTH: 120\n    T5_DTYPE: bfloat16\n    ADDED_IDENTIFIER: [ '{image}', '{caption}', '{mask}', '{ref_image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]\n    CLEAN: whitespace\n    USE_GRAD: False\n\nACE_PROMPT: [\n  \"A cute cartoon rabbit holding a whiteboard that says 'ACE 
Refiner', standing in a sunny meadow filled with flowers, with a big smile and bright colors.\",\n  \"A beautiful young woman with long flowing hair, wearing a summer dress, holding a whiteboard that reads 'ACE Refiner' while sitting on a park bench surrounded by cherry blossoms.\",\n  \"An adorable cartoon cat wearing oversized glasses, holding a whiteboard that says 'ACE Refiner', perched on a stack of colorful books in a cozy library setting.\",\n  \"A charming girl with pigtails, wearing a cute school uniform, enthusiastically holding a whiteboard that has 'ACE Refiner' written on it, in a bright and cheerful classroom full of educational posters.\",\n  \"A friendly cartoon dog with floppy ears, sitting in front of a doghouse, proudly holding a whiteboard that says 'ACE Refiner', with a playful expression and a blue sky in the background.\",\n  \"A cute anime girl with big expressive eyes, dressed in a colorful outfit, holding a whiteboard that reads 'ACE Refiner' in a fantastical landscape filled with mythical creatures.\",\n  \"A vibrant cartoon fox holding a whiteboard that says 'ACE Refiner', standing on a rock by a sparkling stream, surrounded by lush greenery and butterflies.\",\n  \"A stylish young woman in a business outfit, smiling as she holds a whiteboard written with 'ACE Refiner', in a modern office filled with plants and natural light.\",\n  \"A cute cartoon unicorn holding a sparkling whiteboard that says 'ACE Refiner', frolicking in a magical forest, with rainbows and stars in the background.\",\n  \"A happy family, consisting of a cute little girl and her playful puppy, holding a whiteboard that says 'ACE Refiner', together in their backyard on a sunny day.\"\n]\nREFINER_MODEL:\n  NAME: \"\"\n  IS_DEFAULT: False\n  DEFAULT_PARAS:\n    PARAS:\n      RESOLUTIONS: [ [ 1024, 1024 ] ]\n    INPUT:\n      INPUT_IMAGE:\n      INPUT_MASK:\n      TASK:\n      PROMPT: \"\"\n      NEGATIVE_PROMPT: \"\"\n      OUTPUT_HEIGHT: 1024\n      OUTPUT_WIDTH: 1024\n      SAMPLER: flow_euler\n      SAMPLE_STEPS: 30\n      GUIDE_SCALE: 3.5\n      GUIDE_RESCALE:\n    OUTPUT:\n      LATENT:\n      IMAGES:\n      SEED:\n    MODULES_PARAS:\n      FIRST_STAGE_MODEL:\n        FUNCTION:\n          - NAME: encode\n            DTYPE: bfloat16\n            INPUT: [ \"IMAGE\" ]\n          - NAME: decode\n            DTYPE: bfloat16\n            INPUT: [ \"LATENT\" ]\n        PARAS:\n          SCALE_FACTOR: 1.5305\n          SHIFT_FACTOR: 0.0609\n          SIZE_FACTOR: 8\n      DIFFUSION_MODEL:\n        FUNCTION:\n          - NAME: forward\n            DTYPE: bfloat16\n            INPUT: [ \"SAMPLE_STEPS\", \"SAMPLE\", \"GUIDE_SCALE\" ]\n      COND_STAGE_MODEL:\n        FUNCTION:\n          - NAME: encode\n            DTYPE: bfloat16\n            INPUT: [ \"PROMPT\" ]\n\n  MODEL:\n    DIFFUSION:\n      NAME: DiffusionFluxRF\n      PREDICTION_TYPE: raw\n      NOISE_SCHEDULER:\n        NAME: FlowMatchSigmaScheduler\n        WEIGHTING_SCHEME: logit_normal\n        SHIFT: 3.0\n        LOGIT_MEAN: 0.0\n        LOGIT_STD: 1.0\n        MODE_SCALE: 1.29\n    DIFFUSION_MODEL:\n      NAME: FluxMR\n      PRETRAINED_MODEL: ms://AI-ModelScope/FLUX.1-dev@flux1-dev.safetensors\n      IN_CHANNELS: 64\n      OUT_CHANNELS: 64\n      HIDDEN_SIZE: 3072\n      NUM_HEADS: 24\n      AXES_DIM: [ 16, 56, 56 ]\n      THETA: 10000\n      VEC_IN_DIM: 768\n      GUIDANCE_EMBED: True\n      CONTEXT_IN_DIM: 4096\n      MLP_RATIO: 4.0\n      QKV_BIAS: True\n      DEPTH: 19\n      DEPTH_SINGLE_BLOCKS: 38\n      USE_GRAD_CHECKPOINT: 
True\n      ATTN_BACKEND: flash_attn\n    #\n    FIRST_STAGE_MODEL:\n      NAME: AutoencoderKLFlux\n      EMBED_DIM: 16\n      PRETRAINED_MODEL:  ms://AI-ModelScope/FLUX.1-dev@ae.safetensors\n      IGNORE_KEYS: [ ]\n      BATCH_SIZE: 8\n      USE_CONV: False\n      SCALE_FACTOR: 0.3611\n      SHIFT_FACTOR: 0.1159\n      #\n      ENCODER:\n        NAME: Encoder\n        USE_CHECKPOINT: False\n        CH: 128\n        OUT_CH: 3\n        NUM_RES_BLOCKS: 2\n        IN_CHANNELS: 3\n        ATTN_RESOLUTIONS: [ ]\n        CH_MULT: [ 1, 2, 4, 4 ]\n        Z_CHANNELS: 16\n        DOUBLE_Z: True\n        DROPOUT: 0.0\n        RESAMP_WITH_CONV: True\n      #\n      DECODER:\n        NAME: Decoder\n        USE_CHECKPOINT: False\n        CH: 128\n        OUT_CH: 3\n        NUM_RES_BLOCKS: 2\n        IN_CHANNELS: 3\n        ATTN_RESOLUTIONS: [ ]\n        CH_MULT: [ 1, 2, 4, 4 ]\n        Z_CHANNELS: 16\n        DROPOUT: 0.0\n        RESAMP_WITH_CONV: True\n        GIVE_PRE_END: False\n        TANH_OUT: False\n    #\n    COND_STAGE_MODEL:\n      NAME: T5PlusClipFluxEmbedder\n      T5_MODEL:\n        NAME: HFEmbedder\n        HF_MODEL_CLS: T5EncoderModel\n        MODEL_PATH: ms://AI-ModelScope/FLUX.1-dev@text_encoder_2/\n        HF_TOKENIZER_CLS: T5Tokenizer\n        TOKENIZER_PATH: ms://AI-ModelScope/FLUX.1-dev@tokenizer_2/\n        MAX_LENGTH: 512\n        OUTPUT_KEY: last_hidden_state\n        D_TYPE: bfloat16\n        BATCH_INFER: False\n        CLEAN: whitespace\n      CLIP_MODEL:\n        NAME: HFEmbedder\n        HF_MODEL_CLS: CLIPTextModel\n        MODEL_PATH: ms://AI-ModelScope/FLUX.1-dev@text_encoder/\n        HF_TOKENIZER_CLS: CLIPTokenizer\n        TOKENIZER_PATH: ms://AI-ModelScope/FLUX.1-dev@tokenizer/\n        MAX_LENGTH: 77\n        OUTPUT_KEY: pooler_output\n        D_TYPE: bfloat16\n        BATCH_INFER: True\n        CLEAN: whitespace\n"
  },
  {
    "path": "config/inference_config/models/ace_0.6b_512.yaml",
    "content": "NAME: ACE_0.6B_512\nIS_DEFAULT: True\nUSE_DYNAMIC_MODEL: False\nDEFAULT_PARAS:\n  PARAS:\n  #\n  INPUT:\n    INPUT_IMAGE:\n    INPUT_MASK:\n    TASK:\n    PROMPT: \"\"\n    NEGATIVE_PROMPT: \"\"\n    OUTPUT_HEIGHT: 512\n    OUTPUT_WIDTH: 512\n    SAMPLER: ddim\n    SAMPLE_STEPS: 20\n    GUIDE_SCALE: 4.5\n    GUIDE_RESCALE: 0.5\n    SEED: -1\n    TAR_INDEX: 0\n  OUTPUT:\n    LATENT:\n    IMAGES:\n    SEED:\n  MODULES_PARAS:\n    FIRST_STAGE_MODEL:\n      FUNCTION:\n        - NAME: encode\n          DTYPE: float16\n          INPUT: [\"IMAGE\"]\n        - NAME: decode\n          DTYPE: float16\n          INPUT: [\"LATENT\"]\n    #\n    DIFFUSION_MODEL:\n      FUNCTION:\n        - NAME: forward\n          DTYPE: float16\n          INPUT: [\"SAMPLE_STEPS\", \"SAMPLE\", \"GUIDE_SCALE\"]\n    #\n    COND_STAGE_MODEL:\n      FUNCTION:\n        - NAME: encode_list\n          DTYPE: bfloat16\n          INPUT: [\"PROMPT\"]\n#\nMODEL:\n  NAME: LdmACE\n  PRETRAINED_MODEL:\n  IGNORE_KEYS: [ ]\n  SCALE_FACTOR: 0.18215\n  SIZE_FACTOR: 8\n  DECODER_BIAS: 0.5\n  DEFAULT_N_PROMPT: \"\"\n  TEXT_IDENTIFIER: [ '{image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]\n  USE_TEXT_POS_EMBEDDINGS: True\n  #\n  DIFFUSION:\n    NAME: ACEDiffusion\n    PREDICTION_TYPE: eps\n    MIN_SNR_GAMMA:\n    NOISE_SCHEDULER:\n      NAME: LinearScheduler\n      NUM_TIMESTEPS: 1000\n      BETA_MIN: 0.0001\n      BETA_MAX: 0.02\n  #\n  DIFFUSION_MODEL:\n    NAME: DiTACE\n    PRETRAINED_MODEL: ms://iic/ACE-0.6B-512px@models/dit/ace_0.6b_512px.pth\n    IGNORE_KEYS: [ ]\n    PATCH_SIZE: 2\n    IN_CHANNELS: 4\n    HIDDEN_SIZE: 1152\n    DEPTH: 28\n    NUM_HEADS: 16\n    MLP_RATIO: 4.0\n    PRED_SIGMA: True\n    DROP_PATH: 0.0\n    WINDOW_DIZE: 0\n    Y_CHANNELS: 4096\n    MAX_SEQ_LEN: 1024\n    QK_NORM: True\n    USE_GRAD_CHECKPOINT: True\n    ATTENTION_BACKEND: flash_attn\n  #\n  FIRST_STAGE_MODEL:\n    NAME: AutoencoderKL\n    EMBED_DIM: 4\n    PRETRAINED_MODEL: ms://iic/ACE-0.6B-512px@models/vae/vae.bin\n    IGNORE_KEYS: []\n    #\n    ENCODER:\n      NAME: Encoder\n      CH: 128\n      OUT_CH: 3\n      NUM_RES_BLOCKS: 2\n      IN_CHANNELS: 3\n      ATTN_RESOLUTIONS: [ ]\n      CH_MULT: [ 1, 2, 4, 4 ]\n      Z_CHANNELS: 4\n      DOUBLE_Z: True\n      DROPOUT: 0.0\n      RESAMP_WITH_CONV: True\n    #\n    DECODER:\n      NAME: Decoder\n      CH: 128\n      OUT_CH: 3\n      NUM_RES_BLOCKS: 2\n      IN_CHANNELS: 3\n      ATTN_RESOLUTIONS: [ ]\n      CH_MULT: [ 1, 2, 4, 4 ]\n      Z_CHANNELS: 4\n      DROPOUT: 0.0\n      RESAMP_WITH_CONV: True\n      GIVE_PRE_END: False\n      TANH_OUT: False\n  #\n  COND_STAGE_MODEL:\n    NAME: ACETextEmbedder\n    PRETRAINED_MODEL: ms://iic/ACE-0.6B-512px@models/text_encoder/t5-v1_1-xxl/\n    TOKENIZER_PATH: ms://iic/ACE-0.6B-512px@models/tokenizer/t5-v1_1-xxl\n    LENGTH: 120\n    T5_DTYPE: bfloat16\n    ADDED_IDENTIFIER: [ '{image}', '{caption}', '{mask}', '{ref_image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]\n    CLEAN: whitespace\n    USE_GRAD: False\n"
  },
  {
    "path": "config/train_config/ace_0.6b_1024_train.yaml",
    "content": "ENV:\n  BACKEND: nccl\n  SEED: 2024\n#\nSOLVER:\n  NAME: ACESolverV1\n  RESUME_FROM:\n  LOAD_MODEL_ONLY: True\n  USE_FSDP: False\n  SHARDING_STRATEGY:\n  USE_AMP: True\n  DTYPE: float16\n  CHANNELS_LAST: True\n  MAX_STEPS: 500\n  MAX_EPOCHS: -1\n  NUM_FOLDS: 1\n  ACCU_STEP: 1\n  EVAL_INTERVAL: 50\n  RESCALE_LR: False\n  #\n  WORK_DIR: ./cache/exp/exp1\n  LOG_FILE: std_log.txt\n  #\n  FILE_SYSTEM:\n    - NAME: \"HuggingfaceFs\"\n      TEMP_DIR: ./cache\n    - NAME: \"ModelscopeFs\"\n      TEMP_DIR: ./cache\n    - NAME: \"LocalFs\"\n      TEMP_DIR: ./cache\n    - NAME: \"HttpFs\"\n      TEMP_DIR: ./cache\n  #\n  MODEL:\n    NAME: LdmACE\n    PRETRAINED_MODEL:\n    IGNORE_KEYS: [ ]\n    SCALE_FACTOR: 0.18215\n    SIZE_FACTOR: 8\n    DECODER_BIAS: 0.5\n    DEFAULT_N_PROMPT:\n    USE_EMA: True\n    EVAL_EMA: False\n    TEXT_IDENTIFIER: [ '{image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]\n    USE_TEXT_POS_EMBEDDINGS: True\n    #\n    DIFFUSION:\n      NAME: ACEDiffusion\n      PREDICTION_TYPE: eps\n      MIN_SNR_GAMMA:\n      NOISE_SCHEDULER:\n        NAME: LinearScheduler\n        NUM_TIMESTEPS: 1000\n        BETA_MIN: 0.0001\n        BETA_MAX: 0.02\n    #\n    DIFFUSION_MODEL:\n      NAME: DiTACE\n      PRETRAINED_MODEL: ms://iic/ACE-0.6B-512px@models/dit/ace_0.6b_512px.pth\n      IGNORE_KEYS: [ ]\n      PATCH_SIZE: 2\n      IN_CHANNELS: 4\n      HIDDEN_SIZE: 1152\n      DEPTH: 28\n      NUM_HEADS: 16\n      MLP_RATIO: 4.0\n      PRED_SIGMA: True\n      DROP_PATH: 0.0\n      WINDOW_DIZE: 0\n      Y_CHANNELS: 4096\n      MAX_SEQ_LEN: 4096\n      QK_NORM: True\n      USE_GRAD_CHECKPOINT: True\n      ATTENTION_BACKEND: flash_attn\n    #\n    FIRST_STAGE_MODEL:\n      NAME: AutoencoderKL\n      EMBED_DIM: 4\n      PRETRAINED_MODEL: ms://iic/ACE-0.6B-1024px@models/vae/vae.bin\n      IGNORE_KEYS: []\n      #\n      ENCODER:\n        NAME: Encoder\n        CH: 128\n        OUT_CH: 3\n        NUM_RES_BLOCKS: 2\n        IN_CHANNELS: 3\n        ATTN_RESOLUTIONS: [ ]\n        CH_MULT: [ 1, 2, 4, 4 ]\n        Z_CHANNELS: 4\n        DOUBLE_Z: True\n        DROPOUT: 0.0\n        RESAMP_WITH_CONV: True\n      #\n      DECODER:\n        NAME: Decoder\n        CH: 128\n        OUT_CH: 3\n        NUM_RES_BLOCKS: 2\n        IN_CHANNELS: 3\n        ATTN_RESOLUTIONS: [ ]\n        CH_MULT: [ 1, 2, 4, 4 ]\n        Z_CHANNELS: 4\n        DROPOUT: 0.0\n        RESAMP_WITH_CONV: True\n        GIVE_PRE_END: False\n        TANH_OUT: False\n    #\n    COND_STAGE_MODEL:\n      NAME: T5EmbedderHF\n      PRETRAINED_MODEL: ms://iic/ACE-0.6B-1024px@models/text_encoder/t5-v1_1-xxl/\n      TOKENIZER_PATH: ms://iic/ACE-0.6B-1024px@models/tokenizer/t5-v1_1-xxl\n      LENGTH: 120\n      T5_DTYPE: bfloat16\n      ADDED_IDENTIFIER: [ '{image}', '{caption}', '{mask}', '{ref_image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]\n      CLEAN: whitespace\n      USE_GRAD: False\n    LOSS:\n      NAME: ReconstructLoss\n      LOSS_TYPE: l2\n  #\n  SAMPLE_ARGS:\n    SAMPLER: ddim\n    SAMPLE_STEPS: 20\n    GUIDE_SCALE: 4.5\n    GUIDE_RESCALE: 0.5\n  #\n  OPTIMIZER:\n    NAME: AdamW\n    LEARNING_RATE: 1e-5\n    EPS: 1e-10\n    WEIGHT_DECAY: 5e-4\n  #\n  TRAIN_DATA:\n    NAME: ACEDemoDataset\n    MODE: train\n    MS_DATASET_NAME: cache/datasets/hed_pair\n    MS_DATASET_NAMESPACE: \"\"\n    MS_DATASET_SPLIT: \"train\"\n    MS_DATASET_SUBNAME: \"\"\n    PROMPT_PREFIX: \"\"\n    REPLACE_STYLE: 
False\n    MAX_SEQ_LEN: 4096\n    PIN_MEMORY: True\n    BATCH_SIZE: 1\n    NUM_WORKERS: 1\n    SAMPLER:\n      NAME: LoopSampler\n  #\n  TRAIN_HOOKS:\n    -\n      NAME: BackwardHook\n      PRIORITY: 0\n    -\n      NAME: LogHook\n      LOG_INTERVAL: 50\n    -\n      NAME: CheckpointHook\n      INTERVAL: 100\n    -\n      NAME: ProbeDataHook\n      PROB_INTERVAL: 100\n"
  },
  {
    "path": "config/train_config/ace_0.6b_512_train.yaml",
    "content": "ENV:\n  BACKEND: nccl\n  SEED: 2024\n#\nSOLVER:\n  NAME: ACESolverV1\n  RESUME_FROM:\n  LOAD_MODEL_ONLY: True\n  USE_FSDP: False\n  SHARDING_STRATEGY:\n  USE_AMP: True\n  DTYPE: float16\n  CHANNELS_LAST: True\n  MAX_STEPS: 500\n  MAX_EPOCHS: -1\n  NUM_FOLDS: 1\n  ACCU_STEP: 1\n  EVAL_INTERVAL: 50\n  RESCALE_LR: False\n  #\n  WORK_DIR: ./cache/exp/exp1\n  LOG_FILE: std_log.txt\n  #\n  FILE_SYSTEM:\n    - NAME: \"HuggingfaceFs\"\n      TEMP_DIR: ./cache\n    - NAME: \"ModelscopeFs\"\n      TEMP_DIR: ./cache\n    - NAME: \"LocalFs\"\n      TEMP_DIR: ./cache\n    - NAME: \"HttpFs\"\n      TEMP_DIR: ./cache\n  #\n  MODEL:\n    NAME: LdmACE\n    PRETRAINED_MODEL:\n    IGNORE_KEYS: [ ]\n    SCALE_FACTOR: 0.18215\n    SIZE_FACTOR: 8\n    DECODER_BIAS: 0.5\n    DEFAULT_N_PROMPT:\n    USE_EMA: True\n    EVAL_EMA: False\n    TEXT_IDENTIFIER: [ '{image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]\n    USE_TEXT_POS_EMBEDDINGS: True\n    #\n    DIFFUSION:\n      NAME: ACEDiffusion\n      PREDICTION_TYPE: eps\n      MIN_SNR_GAMMA:\n      NOISE_SCHEDULER:\n        NAME: LinearScheduler\n        NUM_TIMESTEPS: 1000\n        BETA_MIN: 0.0001\n        BETA_MAX: 0.02\n    #\n    DIFFUSION_MODEL:\n      NAME: DiTACE\n      PRETRAINED_MODEL: ms://iic/ACE-0.6B-512px@models/dit/ace_0.6b_512px.pth\n      IGNORE_KEYS: [ ]\n      PATCH_SIZE: 2\n      IN_CHANNELS: 4\n      HIDDEN_SIZE: 1152\n      DEPTH: 28\n      NUM_HEADS: 16\n      MLP_RATIO: 4.0\n      PRED_SIGMA: True\n      DROP_PATH: 0.0\n      WINDOW_DIZE: 0\n      Y_CHANNELS: 4096\n      MAX_SEQ_LEN: 1024\n      QK_NORM: True\n      USE_GRAD_CHECKPOINT: True\n      ATTENTION_BACKEND: flash_attn\n    #\n    FIRST_STAGE_MODEL:\n      NAME: AutoencoderKL\n      EMBED_DIM: 4\n      PRETRAINED_MODEL: ms://iic/ACE-0.6B-512px@models/vae/vae.bin\n      IGNORE_KEYS: []\n      #\n      ENCODER:\n        NAME: Encoder\n        CH: 128\n        OUT_CH: 3\n        NUM_RES_BLOCKS: 2\n        IN_CHANNELS: 3\n        ATTN_RESOLUTIONS: [ ]\n        CH_MULT: [ 1, 2, 4, 4 ]\n        Z_CHANNELS: 4\n        DOUBLE_Z: True\n        DROPOUT: 0.0\n        RESAMP_WITH_CONV: True\n      #\n      DECODER:\n        NAME: Decoder\n        CH: 128\n        OUT_CH: 3\n        NUM_RES_BLOCKS: 2\n        IN_CHANNELS: 3\n        ATTN_RESOLUTIONS: [ ]\n        CH_MULT: [ 1, 2, 4, 4 ]\n        Z_CHANNELS: 4\n        DROPOUT: 0.0\n        RESAMP_WITH_CONV: True\n        GIVE_PRE_END: False\n        TANH_OUT: False\n    #\n    COND_STAGE_MODEL:\n      NAME: ACETextEmbedder\n      PRETRAINED_MODEL: ms://iic/ACE-0.6B-512px@models/text_encoder/t5-v1_1-xxl/\n      TOKENIZER_PATH: ms://iic/ACE-0.6B-512px@models/tokenizer/t5-v1_1-xxl\n      LENGTH: 120\n      T5_DTYPE: bfloat16\n      ADDED_IDENTIFIER: [ '{image}', '{caption}', '{mask}', '{ref_image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]\n      CLEAN: whitespace\n      USE_GRAD: False\n    LOSS:\n      NAME: ReconstructLoss\n      LOSS_TYPE: l2\n  #\n  SAMPLE_ARGS:\n    SAMPLER: ddim\n    SAMPLE_STEPS: 20\n    GUIDE_SCALE: 4.5\n    GUIDE_RESCALE: 0.5\n  #\n  OPTIMIZER:\n    NAME: AdamW\n    LEARNING_RATE: 1e-5\n    EPS: 1e-10\n    WEIGHT_DECAY: 5e-4\n  #\n  TRAIN_DATA:\n    NAME: ACEDemoDataset\n    MODE: train\n    MS_DATASET_NAME: cache/datasets/hed_pair\n    MS_DATASET_NAMESPACE: \"\"\n    MS_DATASET_SPLIT: \"train\"\n    MS_DATASET_SUBNAME: \"\"\n    PROMPT_PREFIX: \"\"\n    REPLACE_STYLE: 
False\n    MAX_SEQ_LEN: 1024\n    PIN_MEMORY: True\n    BATCH_SIZE: 1\n    NUM_WORKERS: 1\n    SAMPLER:\n      NAME: LoopSampler\n  #\n  TRAIN_HOOKS:\n    -\n      NAME: BackwardHook\n      PRIORITY: 0\n    -\n      NAME: LogHook\n      LOG_INTERVAL: 50\n    -\n      NAME: CheckpointHook\n      INTERVAL: 100\n    -\n      NAME: ProbeDataHook\n      PROB_INTERVAL: 100\n"
  },
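  {
    "path": "docs/token_budget_note.py",
    "content": "# -*- coding: utf-8 -*-\n# NOTE: illustrative note only -- this file is not part of the original ACE example and\n# its path is hypothetical. It spells out the arithmetic that ties together several\n# values in config/train_config/ace_0.6b_512_train.yaml: with the VAE SIZE_FACTOR of 8\n# and the DiT PATCH_SIZE of 2, a 512px image becomes a 32x32 grid of patch tokens,\n# i.e. 1024 tokens, which is why MAX_SEQ_LEN is set to 1024.\nimage_size = 512        # training resolution implied by the ACE-0.6B-512px checkpoints\nsize_factor = 8         # SIZE_FACTOR: VAE spatial downsampling\npatch_size = 2          # PATCH_SIZE: DiT patch embedding stride\n\nlatent_size = image_size // size_factor          # 64\ntokens_per_side = latent_size // patch_size      # 32\nmax_seq_len = tokens_per_side ** 2               # 1024, matches MAX_SEQ_LEN\nprint(latent_size, tokens_per_side, max_seq_len)\n"
  },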
  {
    "path": "modules/__init__.py",
    "content": "from . import data, model, solver"
  },
  {
    "path": "modules/data/__init__.py",
    "content": "from . import dataset"
  },
  {
    "path": "modules/data/dataset/__init__.py",
    "content": "from .dataset import ACEDemoDataset"
  },
  {
    "path": "modules/data/dataset/dataset.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (c) Alibaba, Inc. and its affiliates.\n\nimport io\nimport math\nimport os\nimport sys\nfrom collections import defaultdict\n\nimport numpy as np\nimport torch\nimport torchvision.transforms as T\nfrom PIL import Image\nfrom torchvision.transforms.functional import InterpolationMode\n\nfrom scepter.modules.data.dataset.base_dataset import BaseDataset\nfrom scepter.modules.data.dataset.registry import DATASETS\nfrom scepter.modules.transform.io import pillow_convert\nfrom scepter.modules.utils.config import dict_to_yaml\nfrom scepter.modules.utils.file_system import FS\n\nImage.MAX_IMAGE_PIXELS = None\n\n@DATASETS.register_class()\nclass ACEDemoDataset(BaseDataset):\n    para_dict = {\n        'MS_DATASET_NAME': {\n            'value': '',\n            'description': 'Modelscope dataset name.'\n        },\n        'MS_DATASET_NAMESPACE': {\n            'value': '',\n            'description': 'Modelscope dataset namespace.'\n        },\n        'MS_DATASET_SUBNAME': {\n            'value': '',\n            'description': 'Modelscope dataset subname.'\n        },\n        'MS_DATASET_SPLIT': {\n            'value': '',\n            'description':\n            'Modelscope dataset split set name, default is train.'\n        },\n        'MS_REMAP_KEYS': {\n            'value':\n            None,\n            'description':\n            'Modelscope dataset header of list file, the default is Target:FILE; '\n            'If your file is not this header, please set this field, which is a map dict.'\n            \"For example, { 'Image:FILE': 'Target:FILE' } will replace the filed Image:FILE to Target:FILE\"\n        },\n        'MS_REMAP_PATH': {\n            'value':\n            None,\n            'description':\n            'When modelscope dataset name is not None, that means you use the dataset from modelscope,'\n            ' default is None. But if you want to use the datalist from modelscope and the file from '\n            'local device, you can use this field to set the root path of your images. '\n        },\n        'TRIGGER_WORDS': {\n            'value':\n            '',\n            'description':\n            'The words used to describe the common features of your data, especially when you customize a '\n            'tuner. 
Use these words you can get what you want.'\n        },\n        'HIGHLIGHT_KEYWORDS': {\n            'value':\n            '',\n            'description':\n            'The keywords you want to highlight in prompt, which will be replace by <HIGHLIGHT_KEYWORDS>.'\n        },\n        'KEYWORDS_SIGN': {\n            'value':\n            '',\n            'description':\n            'The keywords sign you want to add, which is like <{HIGHLIGHT_KEYWORDS}{KEYWORDS_SIGN}>'\n        },\n    }\n\n    def __init__(self, cfg, logger=None):\n        super().__init__(cfg=cfg, logger=logger)\n        from modelscope import MsDataset\n        from modelscope.utils.constant import DownloadMode\n        ms_dataset_name = cfg.get('MS_DATASET_NAME', None)\n        ms_dataset_namespace = cfg.get('MS_DATASET_NAMESPACE', None)\n        ms_dataset_subname = cfg.get('MS_DATASET_SUBNAME', None)\n        ms_dataset_split = cfg.get('MS_DATASET_SPLIT', 'train')\n        ms_remap_keys = cfg.get('MS_REMAP_KEYS', None)\n        ms_remap_path = cfg.get('MS_REMAP_PATH', None)\n\n        self.max_seq_len = cfg.get('MAX_SEQ_LEN', 1024)\n        self.max_aspect_ratio = cfg.get('MAX_ASPECT_RATIO', 4)\n        self.d = cfg.get('DOWNSAMPLE_RATIO', 16)\n        self.replace_style = cfg.get('REPLACE_STYLE', False)\n        self.trigger_words = cfg.get('TRIGGER_WORDS', '')\n        self.replace_keywords = cfg.get('HIGHLIGHT_KEYWORDS', '')\n        self.keywords_sign = cfg.get('KEYWORDS_SIGN', '')\n        self.add_indicator = cfg.get('ADD_INDICATOR', False)\n        # Use modelscope dataset\n        if not ms_dataset_name:\n            raise ValueError(\n                'Your must set MS_DATASET_NAME as modelscope dataset or your local dataset orignized '\n                'as modelscope dataset.')\n        if FS.exists(ms_dataset_name):\n            ms_dataset_name = FS.get_dir_to_local_dir(ms_dataset_name)\n            self.ms_dataset_name = ms_dataset_name\n            # ms_remap_path = ms_dataset_name\n        try:\n            self.data = MsDataset.load(str(ms_dataset_name),\n                                       namespace=ms_dataset_namespace,\n                                       subset_name=ms_dataset_subname,\n                                       split=ms_dataset_split)\n        except Exception:\n            self.logger.info(\n                \"Load Modelscope dataset failed, retry with download_mode='force_redownload'.\"\n            )\n            try:\n                self.data = MsDataset.load(\n                    str(ms_dataset_name),\n                    namespace=ms_dataset_namespace,\n                    subset_name=ms_dataset_subname,\n                    split=ms_dataset_split,\n                    download_mode=DownloadMode.FORCE_REDOWNLOAD)\n            except Exception as sec_e:\n                raise ValueError(f'Load Modelscope dataset failed {sec_e}.')\n        if ms_remap_keys:\n            self.data = self.data.remap_columns(ms_remap_keys.get_dict())\n\n        if ms_remap_path:\n\n            def map_func(example):\n                return {\n                    k: os.path.join(ms_remap_path, v)\n                    if k.endswith(':FILE') else v\n                    for k, v in example.items()\n                }\n\n            self.data = self.data.ds_instance.map(map_func)\n\n        self.transforms = T.Compose([\n            T.ToTensor(),\n            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])\n        ])\n\n    def __len__(self):\n        if self.mode == 'train':\n            
return sys.maxsize\n        else:\n            return len(self.data)\n\n    def _get(self, index: int):\n        current_data = self.data[index % len(self.data)]\n\n        tar_image_path = current_data.get('Target:FILE', '')\n        src_image_path = current_data.get('Source:FILE', '')\n\n        style = current_data.get('Style', '')\n        prompt = current_data.get('Prompt', current_data.get('prompt', ''))\n        if self.replace_style and not style == '':\n            prompt = prompt.replace(style, f'<{self.keywords_sign}>')\n\n        elif not self.replace_keywords.strip() == '':\n            prompt = prompt.replace(\n                self.replace_keywords,\n                '<' + self.replace_keywords + f'{self.keywords_sign}>')\n\n        if not self.trigger_words == '':\n            prompt = self.trigger_words.strip() + ' ' + prompt\n\n        src_image = self.load_image(self.ms_dataset_name,\n                                    src_image_path,\n                                    cvt_type='RGB')\n        tar_image = self.load_image(self.ms_dataset_name,\n                                    tar_image_path,\n                                    cvt_type='RGB')\n        src_image = self.image_preprocess(src_image)\n        tar_image = self.image_preprocess(tar_image)\n\n        tar_image = self.transforms(tar_image)\n        src_image = self.transforms(src_image)\n        src_mask = torch.ones_like(src_image[[0]])\n        tar_mask = torch.ones_like(tar_image[[0]])\n        if self.add_indicator:\n            if '{image}' not in prompt:\n                prompt = '{image}, ' + prompt\n\n        return {\n            'edit_image': [src_image],\n            'edit_image_mask': [src_mask],\n            'image': tar_image,\n            'image_mask': tar_mask,\n            'prompt': [prompt],\n        }\n\n    def load_image(self, prefix, img_path, cvt_type=None):\n        if img_path is None or img_path == '':\n            return None\n        img_path = os.path.join(prefix, img_path)\n        with FS.get_object(img_path) as image_bytes:\n            image = Image.open(io.BytesIO(image_bytes))\n            if cvt_type is not None:\n                image = pillow_convert(image, cvt_type)\n        return image\n\n    def image_preprocess(self,\n                         img,\n                         size=None,\n                         interpolation=InterpolationMode.BILINEAR):\n        H, W = img.height, img.width\n        if H / W > self.max_aspect_ratio:\n            img = T.CenterCrop((self.max_aspect_ratio * W, W))(img)\n        elif W / H > self.max_aspect_ratio:\n            img = T.CenterCrop((H, self.max_aspect_ratio * H))(img)\n\n        if size is None:\n            # resize image for max_seq_len, while keep the aspect ratio\n            H, W = img.height, img.width\n            scale = min(\n                1.0,\n                math.sqrt(self.max_seq_len / ((H / self.d) * (W / self.d))))\n            rH = int(\n                H * scale) // self.d * self.d  # ensure divisible by self.d\n            rW = int(W * scale) // self.d * self.d\n        else:\n            rH, rW = size\n        img = T.Resize((rH, rW), interpolation=interpolation,\n                       antialias=True)(img)\n        return np.array(img, dtype=np.uint8)\n\n    @staticmethod\n    def get_config_template():\n        return dict_to_yaml('DATASet',\n                            __class__.__name__,\n                            ACEDemoDataset.para_dict,\n                            set_name=True)\n\n    
@staticmethod\n    def collate_fn(batch):\n        collect = defaultdict(list)\n        for sample in batch:\n            for k, v in sample.items():\n                collect[k].append(v)\n\n        new_batch = dict()\n        for k, v in collect.items():\n            if all([i is None for i in v]):\n                new_batch[k] = None\n            else:\n                new_batch[k] = v\n\n        return new_batch\n"
  },
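  {
    "path": "docs/resize_math_note.py",
    "content": "# -*- coding: utf-8 -*-\n# NOTE: illustrative sketch only -- not part of the original ACE example; the path is\n# hypothetical. It reproduces, in plain Python, the resize rule used by\n# ACEDemoDataset.image_preprocess in modules/data/dataset/dataset.py: scale the image so\n# that (H/d) * (W/d) stays within MAX_SEQ_LEN, then round each side down to a multiple\n# of the downsample ratio d.\nimport math\n\nmax_seq_len = 1024   # MAX_SEQ_LEN from the train config\nd = 16               # DOWNSAMPLE_RATIO default in the dataset\n\ndef target_size(H, W):\n    scale = min(1.0, math.sqrt(max_seq_len / ((H / d) * (W / d))))\n    rH = int(H * scale) // d * d   # ensure divisibility by d\n    rW = int(W * scale) // d * d\n    return rH, rW\n\n# a 1024x768 source image is scaled down to 576x432, i.e. 36 * 27 = 972 <= 1024 tokens\nprint(target_size(1024, 768))\n# an already-small image only gets the rounding, e.g. 384x256 stays 384x256\nprint(target_size(384, 256))\n"
  },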
  {
    "path": "modules/inference/__init__.py",
    "content": ""
  },
  {
    "path": "modules/model/__init__.py",
    "content": "from . import backbone, embedder, diffusion, network"
  },
  {
    "path": "modules/model/backbone/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (c) Alibaba, Inc. and its affiliates.\nfrom .ace import DiTACE\n"
  },
  {
    "path": "modules/model/backbone/ace.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (c) Alibaba, Inc. and its affiliates.\nimport re\nfrom collections import OrderedDict\nfrom functools import partial\n\nimport torch\nimport torch.nn as nn\nfrom einops import rearrange\nfrom torch.nn.utils.rnn import pad_sequence\nfrom torch.utils.checkpoint import checkpoint_sequential\n\nfrom scepter.modules.model.base_model import BaseModel\nfrom scepter.modules.model.registry import BACKBONES\nfrom scepter.modules.utils.config import dict_to_yaml\nfrom scepter.modules.utils.file_system import FS\n\nfrom .layers import (\n    Mlp,\n    TimestepEmbedder,\n    PatchEmbed,\n    DiTACEBlock,\n    T2IFinalLayer\n)\nfrom .pos_embed import rope_params\n\n\n@BACKBONES.register_class()\nclass DiTACE(BaseModel):\n\n    para_dict = {\n        'PATCH_SIZE': {\n            'value': 2,\n            'description': ''\n        },\n        'IN_CHANNELS': {\n            'value': 4,\n            'description': ''\n        },\n        'HIDDEN_SIZE': {\n            'value': 1152,\n            'description': ''\n        },\n        'DEPTH': {\n            'value': 28,\n            'description': ''\n        },\n        'NUM_HEADS': {\n            'value': 16,\n            'description': ''\n        },\n        'MLP_RATIO': {\n            'value': 4.0,\n            'description': ''\n        },\n        'PRED_SIGMA': {\n            'value': True,\n            'description': ''\n        },\n        'DROP_PATH': {\n            'value': 0.,\n            'description': ''\n        },\n        'WINDOW_SIZE': {\n            'value': 0,\n            'description': ''\n        },\n        'WINDOW_BLOCK_INDEXES': {\n            'value': None,\n            'description': ''\n        },\n        'Y_CHANNELS': {\n            'value': 4096,\n            'description': ''\n        },\n        'ATTENTION_BACKEND': {\n            'value': None,\n            'description': ''\n        },\n        'QK_NORM': {\n            'value': True,\n            'description': 'Whether to use RMSNorm for query and key.',\n        },\n    }\n    para_dict.update(BaseModel.para_dict)\n\n    def __init__(self, cfg, logger):\n        super().__init__(cfg, logger=logger)\n        self.window_block_indexes = cfg.get('WINDOW_BLOCK_INDEXES', None)\n        if self.window_block_indexes is None:\n            self.window_block_indexes = []\n        self.pred_sigma = cfg.get('PRED_SIGMA', True)\n        self.in_channels = cfg.get('IN_CHANNELS', 4)\n        self.out_channels = self.in_channels * 2 if self.pred_sigma else self.in_channels\n        self.patch_size = cfg.get('PATCH_SIZE', 2)\n        self.num_heads = cfg.get('NUM_HEADS', 16)\n        self.hidden_size = cfg.get('HIDDEN_SIZE', 1152)\n        self.y_channels = cfg.get('Y_CHANNELS', 4096)\n        self.drop_path = cfg.get('DROP_PATH', 0.)\n        self.depth = cfg.get('DEPTH', 28)\n        self.mlp_ratio = cfg.get('MLP_RATIO', 4.0)\n        self.use_grad_checkpoint = cfg.get('USE_GRAD_CHECKPOINT', False)\n        self.attention_backend = cfg.get('ATTENTION_BACKEND', None)\n        self.max_seq_len = cfg.get('MAX_SEQ_LEN', 1024)\n        self.qk_norm = cfg.get('QK_NORM', False)\n        self.ignore_keys = cfg.get('IGNORE_KEYS', [])\n        assert (self.hidden_size % self.num_heads\n                ) == 0 and (self.hidden_size // self.num_heads) % 2 == 0\n        d = self.hidden_size // self.num_heads\n        self.freqs = torch.cat(\n            [\n                rope_params(self.max_seq_len, d - 4 * (d // 6)),  # T (~1/3)\n              
  rope_params(self.max_seq_len, 2 * (d // 6)),  # H (~1/3)\n                rope_params(self.max_seq_len, 2 * (d // 6))  # W (~1/3)\n            ],\n            dim=1)\n\n        # init embedder\n        self.x_embedder = PatchEmbed(self.patch_size,\n                                     self.in_channels + 1,\n                                     self.hidden_size,\n                                     bias=True,\n                                     flatten=False)\n        self.t_embedder = TimestepEmbedder(self.hidden_size)\n        self.y_embedder = Mlp(in_features=self.y_channels,\n                              hidden_features=self.hidden_size,\n                              out_features=self.hidden_size,\n                              act_layer=lambda: nn.GELU(approximate='tanh'),\n                              drop=0)\n        self.t_block = nn.Sequential(\n            nn.SiLU(),\n            nn.Linear(self.hidden_size, 6 * self.hidden_size, bias=True))\n        # init blocks\n        drop_path = [\n            x.item() for x in torch.linspace(0, self.drop_path, self.depth)\n        ]\n        self.blocks = nn.ModuleList([\n            DiTACEBlock(self.hidden_size,\n                        self.num_heads,\n                        mlp_ratio=self.mlp_ratio,\n                        drop_path=drop_path[i],\n                        window_size=self.window_size\n                        if i in self.window_block_indexes else 0,\n                        backend=self.attention_backend,\n                        use_condition=True,\n                        qk_norm=self.qk_norm) for i in range(self.depth)\n        ])\n        self.final_layer = T2IFinalLayer(self.hidden_size, self.patch_size,\n                                         self.out_channels)\n        self.initialize_weights()\n\n    def load_pretrained_model(self, pretrained_model):\n        if pretrained_model:\n            with FS.get_from(pretrained_model, wait_finish=True) as local_path:\n                model = torch.load(local_path, map_location='cpu')\n                if 'state_dict' in model:\n                    model = model['state_dict']\n                new_ckpt = OrderedDict()\n                for k, v in model.items():\n                    if self.ignore_keys is not None:\n                        if (isinstance(self.ignore_keys, str) and re.match(self.ignore_keys, k)) or \\\n                                (isinstance(self.ignore_keys, list) and k in self.ignore_keys):\n                            continue\n                    k = k.replace('.cross_attn.q_linear.', '.cross_attn.q.')\n                    k = k.replace('.cross_attn.proj.',\n                                  '.cross_attn.o.').replace(\n                                      '.attn.proj.', '.attn.o.')\n                    if '.cross_attn.kv_linear.' in k:\n                        k_p, v_p = torch.split(v, v.shape[0] // 2)\n                        new_ckpt[k.replace('.cross_attn.kv_linear.',\n                                           '.cross_attn.k.')] = k_p\n                        new_ckpt[k.replace('.cross_attn.kv_linear.',\n                                           '.cross_attn.v.')] = v_p\n                    elif '.attn.qkv.' 
in k:\n                        q_p, k_p, v_p = torch.split(v, v.shape[0] // 3)\n                        new_ckpt[k.replace('.attn.qkv.', '.attn.q.')] = q_p\n                        new_ckpt[k.replace('.attn.qkv.', '.attn.k.')] = k_p\n                        new_ckpt[k.replace('.attn.qkv.', '.attn.v.')] = v_p\n                    elif 'y_embedder.y_proj.' in k:\n                        new_ckpt[k.replace('y_embedder.y_proj.',\n                                           'y_embedder.')] = v\n                    elif k in ('x_embedder.proj.weight'):\n                        model_p = self.state_dict()[k]\n                        if v.shape != model_p.shape:\n                            model_p.zero_()\n                            model_p[:, :4, :, :].copy_(v)\n                            new_ckpt[k] = torch.nn.parameter.Parameter(model_p)\n                        else:\n                            new_ckpt[k] = v\n                    elif k in ('x_embedder.proj.bias'):\n                        new_ckpt[k] = v\n                    else:\n                        new_ckpt[k] = v\n                missing, unexpected = self.load_state_dict(new_ckpt,\n                                                           strict=False)\n                print(\n                    f'Restored from {pretrained_model} with {len(missing)} missing and {len(unexpected)} unexpected keys'\n                )\n                if len(missing) > 0:\n                    print(f'Missing Keys:\\n {missing}')\n                if len(unexpected) > 0:\n                    print(f'\\nUnexpected Keys:\\n {unexpected}')\n\n    def forward(self,\n                x,\n                t=None,\n                cond=dict(),\n                mask=None,\n                text_position_embeddings=None,\n                gc_seg=-1,\n                **kwargs):\n        if self.freqs.device != x.device:\n            self.freqs = self.freqs.to(x.device)\n        if isinstance(cond, dict):\n            context = cond.get('crossattn', None)\n        else:\n            context = cond\n        if text_position_embeddings is not None:\n            # default use the text_position_embeddings in state_dict\n            # if state_dict doesn't including this key, use the arg: text_position_embeddings\n            proj_position_embeddings = self.y_embedder(\n                text_position_embeddings)\n        else:\n            proj_position_embeddings = None\n\n        ctx_batch, txt_lens = [], []\n        if mask is not None and isinstance(mask, list):\n            for ctx, ctx_mask in zip(context, mask):\n                for frame_id, one_ctx in enumerate(zip(ctx, ctx_mask)):\n                    u, m = one_ctx\n                    t_len = m.flatten().sum()  # l\n                    u = u[:t_len]\n                    u = self.y_embedder(u)\n                    if frame_id == 0:\n                        u = u + proj_position_embeddings[\n                            len(ctx) -\n                            1] if proj_position_embeddings is not None else u\n                    else:\n                        u = u + proj_position_embeddings[\n                            frame_id -\n                            1] if proj_position_embeddings is not None else u\n                    ctx_batch.append(u)\n                    txt_lens.append(t_len)\n        else:\n            raise TypeError\n        y = torch.cat(ctx_batch, dim=0)\n        txt_lens = torch.LongTensor(txt_lens).to(x.device, non_blocking=True)\n\n        batch_frames = []\n        for u, shape, m in 
zip(x, cond['x_shapes'], cond['x_mask']):\n            u = u[:, :shape[0] * shape[1]].view(-1, shape[0], shape[1])\n            m = torch.ones_like(u[[0], :, :]) if m is None else m.squeeze(0)\n            batch_frames.append([torch.cat([u, m], dim=0).unsqueeze(0)])\n        if 'edit' in cond:\n            for i, (edit, edit_mask) in enumerate(\n                    zip(cond['edit'], cond['edit_mask'])):\n                if edit is None:\n                    continue\n                for u, m in zip(edit, edit_mask):\n                    u = u.squeeze(0)\n                    m = torch.ones_like(\n                        u[[0], :, :]) if m is None else m.squeeze(0)\n                    batch_frames[i].append(\n                        torch.cat([u, m], dim=0).unsqueeze(0))\n\n        patch_batch, shape_batch, self_x_len, cross_x_len = [], [], [], []\n        for frames in batch_frames:\n            patches, patch_shapes = [], []\n            self_x_len.append(0)\n            for frame_id, u in enumerate(frames):\n                u = self.x_embedder(u)\n                h, w = u.size(2), u.size(3)\n                u = rearrange(u, '1 c h w -> (h w) c')\n                if frame_id == 0:\n                    u = u + proj_position_embeddings[\n                        len(frames) -\n                        1] if proj_position_embeddings is not None else u\n                else:\n                    u = u + proj_position_embeddings[\n                        frame_id -\n                        1] if proj_position_embeddings is not None else u\n                patches.append(u)\n                patch_shapes.append([h, w])\n                cross_x_len.append(h * w)  # b*s, 1\n                self_x_len[-1] += h * w  # b, 1\n            # u = torch.cat(patches, dim=0)\n            patch_batch.extend(patches)\n            shape_batch.append(\n                torch.LongTensor(patch_shapes).to(x.device, non_blocking=True))\n        # repeat t to align with x\n        t = torch.cat([t[i].repeat(l) for i, l in enumerate(self_x_len)])\n        self_x_len, cross_x_len = (torch.LongTensor(self_x_len).to(\n            x.device, non_blocking=True), torch.LongTensor(cross_x_len).to(\n                x.device, non_blocking=True))\n        # x = pad_sequence(tuple(patch_batch), batch_first=True)  # b, s*max(cl), c\n        x = torch.cat(patch_batch, dim=0)\n        x_shapes = pad_sequence(tuple(shape_batch),\n                                batch_first=True)  # b, max(len(frames)), 2\n        t = self.t_embedder(t)  # (N, D)\n        t0 = self.t_block(t)\n        # y = self.y_embedder(context)\n\n        kwargs = dict(y=y,\n                      t=t0,\n                      x_shapes=x_shapes,\n                      self_x_len=self_x_len,\n                      cross_x_len=cross_x_len,\n                      freqs=self.freqs,\n                      txt_lens=txt_lens)\n        if self.use_grad_checkpoint and gc_seg >= 0:\n            x = checkpoint_sequential(\n                functions=[partial(block, **kwargs) for block in self.blocks],\n                segments=gc_seg if gc_seg > 0 else len(self.blocks),\n                input=x,\n                use_reentrant=False)\n        else:\n            for block in self.blocks:\n                x = block(x, **kwargs)\n        x = self.final_layer(x, t)  # b*s*n, d\n        outs, cur_length = [], 0\n        p = self.patch_size\n        for seq_length, shape in zip(self_x_len, shape_batch):\n            x_i = x[cur_length:cur_length + seq_length]\n            h, w = 
shape[0].tolist()\n            u = x_i[:h * w].view(h, w, p, p, -1)\n            u = rearrange(u, 'h w p q c -> (h p w q) c'\n                          )  # dump into sequence for following tensor ops\n            cur_length = cur_length + seq_length\n            outs.append(u)\n        x = pad_sequence(tuple(outs), batch_first=True).permute(0, 2, 1)\n        if self.pred_sigma:\n            return x.chunk(2, dim=1)[0]\n        else:\n            return x\n\n    def initialize_weights(self):\n        # Initialize transformer layers:\n        def _basic_init(module):\n            if isinstance(module, nn.Linear):\n                torch.nn.init.xavier_uniform_(module.weight)\n                if module.bias is not None:\n                    nn.init.constant_(module.bias, 0)\n\n        self.apply(_basic_init)\n        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):\n        w = self.x_embedder.proj.weight.data\n        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))\n        # Initialize timestep embedding MLP:\n        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)\n        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)\n        nn.init.normal_(self.t_block[1].weight, std=0.02)\n        # Initialize caption embedding MLP:\n        if hasattr(self, 'y_embedder'):\n            nn.init.normal_(self.y_embedder.fc1.weight, std=0.02)\n            nn.init.normal_(self.y_embedder.fc2.weight, std=0.02)\n        # Zero-out adaLN modulation layers\n        for block in self.blocks:\n            nn.init.constant_(block.cross_attn.o.weight, 0)\n            nn.init.constant_(block.cross_attn.o.bias, 0)\n        # Zero-out output layers:\n        nn.init.constant_(self.final_layer.linear.weight, 0)\n        nn.init.constant_(self.final_layer.linear.bias, 0)\n\n    @property\n    def dtype(self):\n        return next(self.parameters()).dtype\n\n    @staticmethod\n    def get_config_template():\n        return dict_to_yaml('BACKBONE',\n                            __class__.__name__,\n                            DiTACE.para_dict,\n                            set_name=True)\n"
  },
  {
    "path": "modules/model/backbone/layers.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (c) Alibaba, Inc. and its affiliates.\nimport math\nimport warnings\nimport torch\nimport torch.nn as nn\nfrom .pos_embed import rope_apply_multires as rope_apply\n\ntry:\n    from flash_attn import (flash_attn_varlen_func)\n    FLASHATTN_IS_AVAILABLE = True\nexcept ImportError as e:\n    FLASHATTN_IS_AVAILABLE = False\n    flash_attn_varlen_func = None\n    warnings.warn(f'{e}')\n\n__all__ = [\n    \"drop_path\",\n    \"modulate\",\n    \"PatchEmbed\",\n    \"DropPath\",\n    \"RMSNorm\",\n    \"Mlp\",\n    \"TimestepEmbedder\",\n    \"DiTACEBlock\",\n    \"MultiHeadAttention\",\n    \"T2IFinalLayer\",\n]\n\ndef drop_path(x, drop_prob: float = 0., training: bool = False):\n    \"\"\"Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).\n    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,\n    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...\n    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for\n    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use\n    'survival rate' as the argument.\n    \"\"\"\n    if drop_prob == 0. or not training:\n        return x\n    keep_prob = 1 - drop_prob\n    shape = (x.shape[0], ) + (1, ) * (\n        x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets\n    random_tensor = keep_prob + torch.rand(\n        shape, dtype=x.dtype, device=x.device)\n    random_tensor.floor_()  # binarize\n    output = x.div(keep_prob) * random_tensor\n    return output\n\n\ndef modulate(x, shift, scale, unsqueeze=False):\n    if unsqueeze:\n        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)\n    else:\n        return x * (1 + scale) + shift\n    \n\nclass PatchEmbed(nn.Module):\n    \"\"\" 2D Image to Patch Embedding\n    \"\"\"\n    def __init__(\n        self,\n        patch_size=16,\n        in_chans=3,\n        embed_dim=768,\n        norm_layer=None,\n        flatten=True,\n        bias=True,\n    ):\n        super().__init__()\n        self.flatten = flatten\n        self.proj = nn.Conv2d(in_chans,\n                              embed_dim,\n                              kernel_size=patch_size,\n                              stride=patch_size,\n                              bias=bias)\n        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()\n\n    def forward(self, x):\n        x = self.proj(x)\n        if self.flatten:\n            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC\n        x = self.norm(x)\n        return x\n\n\nclass DropPath(nn.Module):\n    \"\"\"Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).\n    \"\"\"\n    def __init__(self, drop_prob=None):\n        super(DropPath, self).__init__()\n        self.drop_prob = drop_prob\n\n    def forward(self, x):\n        return drop_path(x, self.drop_prob, self.training)\n    \n\nclass RMSNorm(nn.Module):\n    def __init__(self, dim, eps=1e-6):\n        super().__init__()\n        self.dim = dim\n        self.eps = eps\n        self.weight = nn.Parameter(torch.ones(dim))\n\n    def forward(self, x):\n        return self._norm(x.float()).type_as(x) * self.weight\n\n    def _norm(self, x):\n        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)\n\n\nclass Mlp(nn.Module):\n    \"\"\" MLP as used in Vision 
Transformer, MLP-Mixer and related networks\n    \"\"\"\n    def __init__(self,\n                 in_features,\n                 hidden_features=None,\n                 out_features=None,\n                 act_layer=nn.GELU,\n                 drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\nclass TimestepEmbedder(nn.Module):\n    \"\"\"\n    Embeds scalar timesteps into vector representations.\n    \"\"\"\n    def __init__(self, hidden_size, frequency_embedding_size=256):\n        super().__init__()\n        self.mlp = nn.Sequential(\n            nn.Linear(frequency_embedding_size, hidden_size, bias=True),\n            nn.SiLU(),\n            nn.Linear(hidden_size, hidden_size, bias=True),\n        )\n        self.frequency_embedding_size = frequency_embedding_size\n\n    @staticmethod\n    def timestep_embedding(t, dim, max_period=10000):\n        \"\"\"\n        Create sinusoidal timestep embeddings.\n        :param t: a 1-D Tensor of N indices, one per batch element.\n                          These may be fractional.\n        :param dim: the dimension of the output.\n        :param max_period: controls the minimum frequency of the embeddings.\n        :return: an (N, D) Tensor of positional embeddings.\n        \"\"\"\n        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py\n        half = dim // 2\n        freqs = torch.exp(\n            -math.log(max_period) *\n            torch.arange(start=0, end=half, dtype=torch.float32) /\n            half).to(device=t.device)\n        args = t[:, None].float() * freqs[None]\n        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)\n        if dim % 2:\n            embedding = torch.cat(\n                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1)\n        return embedding\n\n    def forward(self, t):\n        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)\n        t_emb = self.mlp(t_freq)\n        return t_emb\n    \n\nclass DiTACEBlock(nn.Module):\n    def __init__(self,\n                 hidden_size,\n                 num_heads,\n                 mlp_ratio=4.0,\n                 drop_path=0.,\n                 window_size=0,\n                 backend=None,\n                 use_condition=True,\n                 qk_norm=False,\n                 **block_kwargs):\n        super().__init__()\n        self.hidden_size = hidden_size\n        self.use_condition = use_condition\n        self.norm1 = nn.LayerNorm(hidden_size,\n                                  elementwise_affine=False,\n                                  eps=1e-6)\n        self.attn = MultiHeadAttention(hidden_size,\n                                        num_heads=num_heads,\n                                        qkv_bias=True,\n                                        backend=backend,\n                                        qk_norm=qk_norm,\n                                        **block_kwargs)\n        if self.use_condition:\n            self.cross_attn = MultiHeadAttention(\n                hidden_size,\n                
context_dim=hidden_size,\n                num_heads=num_heads,\n                qkv_bias=True,\n                backend=backend,\n                qk_norm=qk_norm,\n                **block_kwargs)\n        self.norm2 = nn.LayerNorm(hidden_size,\n                                  elementwise_affine=False,\n                                  eps=1e-6)\n        # to be compatible with lower version pytorch\n        approx_gelu = lambda: nn.GELU(approximate='tanh')\n        self.mlp = Mlp(in_features=hidden_size,\n                       hidden_features=int(hidden_size * mlp_ratio),\n                       act_layer=approx_gelu,\n                       drop=0)\n        self.drop_path = DropPath(\n            drop_path) if drop_path > 0. else nn.Identity()\n        self.window_size = window_size\n        self.scale_shift_table = nn.Parameter(\n            torch.randn(6, hidden_size) / hidden_size**0.5)\n\n    def forward(self, x, y, t, **kwargs):\n        B = x.size(0)\n        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (\n            self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1)\n        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (\n            shift_msa.squeeze(1), scale_msa.squeeze(1), gate_msa.squeeze(1),\n            shift_mlp.squeeze(1), scale_mlp.squeeze(1), gate_mlp.squeeze(1))\n        x = x + self.drop_path(gate_msa * self.attn(\n            modulate(self.norm1(x), shift_msa, scale_msa, unsqueeze=False), **\n            kwargs))\n        if self.use_condition:\n            x = x + self.cross_attn(x, context=y, **kwargs)\n\n        x = x + self.drop_path(gate_mlp * self.mlp(\n            modulate(self.norm2(x), shift_mlp, scale_mlp, unsqueeze=False)))\n        return x\n\n\nclass MultiHeadAttention(nn.Module):\n    def __init__(self,\n                 dim,\n                 context_dim=None,\n                 num_heads=None,\n                 head_dim=None,\n                 attn_drop=0.0,\n                 qkv_bias=False,\n                 dropout=0.0,\n                 backend=None,\n                 qk_norm=False,\n                 eps=1e-6,\n                 **block_kwargs):\n        super().__init__()\n        # consider head_dim first, then num_heads\n        num_heads = dim // head_dim if head_dim else num_heads\n        head_dim = dim // num_heads\n        assert num_heads * head_dim == dim\n        context_dim = context_dim or dim\n        self.dim = dim\n        self.context_dim = context_dim\n        self.num_heads = num_heads\n        self.head_dim = head_dim\n        self.scale = math.pow(head_dim, -0.25)\n        # layers\n        self.q = nn.Linear(dim, dim, bias=qkv_bias)\n        self.k = nn.Linear(context_dim, dim, bias=qkv_bias)\n        self.v = nn.Linear(context_dim, dim, bias=qkv_bias)\n        self.o = nn.Linear(dim, dim)\n        self.norm_q = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity()\n        self.norm_k = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity()\n\n        self.dropout = nn.Dropout(dropout)\n        self.attention_op = None\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.backend = backend\n        assert self.backend in ('flash_attn', 'xformer_attn', 'pytorch_attn',\n                                None)\n        if FLASHATTN_IS_AVAILABLE and self.backend in ('flash_attn', None):\n            self.backend = 'flash_attn'\n            self.softmax_scale = block_kwargs.get('softmax_scale', None)\n            self.causal = block_kwargs.get('causal', False)\n           
 self.window_size = block_kwargs.get('window_size', (-1, -1))\n            self.deterministic = block_kwargs.get('deterministic', False)\n        else:\n            raise NotImplementedError\n\n    def flash_attn(self, x, context=None, **kwargs):\n        '''\n         The implementation will be very slow when mask is not None,\n         because we need rearrange the x/context features according to mask.\n        Args:\n            x:\n            context:\n            mask:\n            **kwargs:\n        Returns: x\n        '''\n        dtype = kwargs.get('dtype', torch.float16)\n\n        def half(x):\n            return x if x.dtype in [torch.float16, torch.bfloat16\n                                    ] else x.to(dtype)\n\n        x_shapes = kwargs['x_shapes']\n        freqs = kwargs['freqs']\n        self_x_len = kwargs['self_x_len']\n        cross_x_len = kwargs['cross_x_len']\n        txt_lens = kwargs['txt_lens']\n        n, d = self.num_heads, self.head_dim\n\n        if context is None:\n            # self-attn\n            q = self.norm_q(self.q(x)).view(-1, n, d)\n            k = self.norm_k(self.k(x)).view(-1, n, d)\n            v = self.v(x).view(-1, n, d)\n            q = rope_apply(q, self_x_len, x_shapes, freqs, pad=False)\n            k = rope_apply(k, self_x_len, x_shapes, freqs, pad=False)\n            q_lens = k_lens = self_x_len\n        else:\n            # cross-attn\n            q = self.norm_q(self.q(x)).view(-1, n, d)\n            k = self.norm_k(self.k(context)).view(-1, n, d)\n            v = self.v(context).view(-1, n, d)\n            q_lens = cross_x_len\n            k_lens = txt_lens\n\n        cu_seqlens_q = torch.cat([q_lens.new_zeros([1]),\n                                  q_lens]).cumsum(0, dtype=torch.int32)\n        cu_seqlens_k = torch.cat([k_lens.new_zeros([1]),\n                                  k_lens]).cumsum(0, dtype=torch.int32)\n        max_seqlen_q = q_lens.max()\n        max_seqlen_k = k_lens.max()\n\n        out_dtype = q.dtype\n        q, k, v = half(q), half(k), half(v)\n        x = flash_attn_varlen_func(q,\n                                   k,\n                                   v,\n                                   cu_seqlens_q=cu_seqlens_q,\n                                   cu_seqlens_k=cu_seqlens_k,\n                                   max_seqlen_q=max_seqlen_q,\n                                   max_seqlen_k=max_seqlen_k,\n                                   dropout_p=self.attn_drop.p,\n                                   softmax_scale=self.softmax_scale,\n                                   causal=self.causal,\n                                   window_size=self.window_size,\n                                   deterministic=self.deterministic)\n\n        x = x.type(out_dtype)\n        x = x.reshape(-1, n * d)\n        x = self.o(x)\n        x = self.dropout(x)\n        return x\n\n    def forward(self, x, context=None, **kwargs):\n        x = getattr(self, self.backend)(x, context=context, **kwargs)\n        return x\n\n\nclass T2IFinalLayer(nn.Module):\n    \"\"\"\n    The final layer of PixArt.\n    \"\"\"\n    def __init__(self, hidden_size, patch_size, out_channels):\n        super().__init__()\n        self.norm_final = nn.LayerNorm(hidden_size,\n                                       elementwise_affine=False,\n                                       eps=1e-6)\n        self.linear = nn.Linear(hidden_size,\n                                patch_size * patch_size * out_channels,\n                                bias=True)\n        self.scale_shift_table = nn.Parameter(\n            torch.randn(2, hidden_size) / hidden_size**0.5)\n        self.out_channels = out_channels\n\n    def forward(self, x, t):\n        shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2,\n                                                                         dim=1)\n        shift, scale = shift.squeeze(1), scale.squeeze(1)\n        x = modulate(self.norm_final(x), shift, scale)\n        x = self.linear(x)\n        return x"
  },
  {
    "path": "modules/model/backbone/pos_embed.py",
    "content": "import numpy as np\nfrom einops import rearrange\n\nimport torch\nimport torch.cuda.amp as amp\nimport torch.nn.functional as F\nfrom torch.nn.utils.rnn import pad_sequence\n\ndef frame_pad(x, seq_len, shapes):\n    max_h, max_w = np.max(shapes, 0)\n    frames = []\n    cur_len = 0\n    for h, w in shapes:\n        frame_len = h * w\n        frames.append(\n            F.pad(\n                x[cur_len:cur_len + frame_len].view(h, w, -1),\n                (0, 0, 0, max_w - w, 0, max_h - h))  # .view(max_h * max_w, -1)\n        )\n        cur_len += frame_len\n        if cur_len >= seq_len:\n            break\n    return torch.stack(frames)\n\n\ndef frame_unpad(x, shapes):\n    max_h, max_w = np.max(shapes, 0)\n    x = rearrange(x, '(b h w) n c -> b h w n c', h=max_h, w=max_w)\n    frames = []\n    for i, (h, w) in enumerate(shapes):\n        if i >= len(x):\n            break\n        frames.append(rearrange(x[i, :h, :w], 'h w n c -> (h w) n c'))\n    return torch.concat(frames)\n\n\n@amp.autocast(enabled=False)\ndef rope_apply_multires(x, x_lens, x_shapes, freqs, pad=True):\n    \"\"\"\n    x:          [B*L, N, C].\n    x_lens:     [B].\n    x_shapes:   [B, F, 2].\n    freqs:      [M, C // 2].\n    \"\"\"\n    n, c = x.size(1), x.size(2) // 2\n    # split freqs\n    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)\n    # loop over samples\n    output = []\n    st = 0\n    for i, (seq_len,\n            shapes) in enumerate(zip(x_lens.tolist(), x_shapes.tolist())):\n        x_i = frame_pad(x[st:st + seq_len], seq_len, shapes)  # f, h, w, c\n        f, h, w = x_i.shape[:3]\n        pad_seq_len = f * h * w\n        # precompute multipliers\n        x_i = torch.view_as_complex(\n            x_i.to(torch.float64).reshape(pad_seq_len, n, -1, 2))\n        freqs_i = torch.cat([\n            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),\n            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),\n            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)\n        ],\n                            dim=-1).reshape(pad_seq_len, 1, -1)\n        # apply rotary embedding\n        x_i = torch.view_as_real(x_i * freqs_i).flatten(2).type_as(x)\n        x_i = frame_unpad(x_i, shapes)\n        # append to collection\n        output.append(x_i)\n        st += seq_len\n    return pad_sequence(output) if pad else torch.concat(output)\n\n\n@amp.autocast(enabled=False)\ndef rope_params(max_seq_len, dim, theta=10000):\n    \"\"\"\n    Precompute the frequency tensor for complex exponentials.\n    \"\"\"\n    assert dim % 2 == 0\n    freqs = torch.outer(\n        torch.arange(max_seq_len),\n        1.0 / torch.pow(theta,\n                        torch.arange(0, dim, 2).to(torch.float64).div(dim)))\n    freqs = torch.polar(torch.ones_like(freqs), freqs)\n    return freqs"
  },
  {
    "path": "modules/model/diffusion/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (c) Alibaba, Inc. and its affiliates.\n\nfrom .diffusions import ACEDiffusion\nfrom .samplers import DDIMSampler\nfrom .schedules import LinearScheduler"
  },
  {
    "path": "modules/model/diffusion/diffusions.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (c) Alibaba, Inc. and its affiliates.\nimport math\nimport os\nfrom collections import OrderedDict\n\nimport torch\nfrom tqdm import trange\n\nfrom scepter.modules.model.registry import (DIFFUSION_SAMPLERS, DIFFUSIONS,\n                                            NOISE_SCHEDULERS)\nfrom scepter.modules.utils.config import Config, dict_to_yaml\nfrom scepter.modules.utils.distribute import we\nfrom scepter.modules.utils.file_system import FS\n\n\n@DIFFUSIONS.register_class()\nclass ACEDiffusion(object):\n    para_dict = {\n        'NOISE_SCHEDULER': {},\n        'SAMPLER_SCHEDULER': {},\n        'PREDICTION_TYPE': {\n            'value': 'eps',\n            'description':\n            'The type of prediction to use for the loss function.'\n        }\n    }\n\n    def __init__(self, cfg, logger=None):\n        super(ACEDiffusion, self).__init__()\n        self.logger = logger\n        self.cfg = cfg\n        self.init_params()\n\n    def init_params(self):\n        self.prediction_type = self.cfg.get('PREDICTION_TYPE', 'eps')\n        self.noise_scheduler = NOISE_SCHEDULERS.build(self.cfg.NOISE_SCHEDULER,\n                                                      logger=self.logger)\n        self.sampler_scheduler = NOISE_SCHEDULERS.build(self.cfg.get(\n            'SAMPLER_SCHEDULER', self.cfg.NOISE_SCHEDULER),\n                                                        logger=self.logger)\n        self.num_timesteps = self.noise_scheduler.num_timesteps\n        if self.cfg.have('WORK_DIR') and we.rank == 0:\n            schedule_visualization = os.path.join(self.cfg.WORK_DIR,\n                                                  'noise_schedule.png')\n            with FS.put_to(schedule_visualization) as local_path:\n                self.noise_scheduler.plot_noise_sampling_map(local_path)\n            schedule_visualization = os.path.join(self.cfg.WORK_DIR,\n                                                  'sampler_schedule.png')\n            with FS.put_to(schedule_visualization) as local_path:\n                self.sampler_scheduler.plot_noise_sampling_map(local_path)\n\n    def sample(self,\n               noise,\n               model,\n               model_kwargs={},\n               steps=20,\n               sampler=None,\n               use_dynamic_cfg=False,\n               guide_scale=None,\n               guide_rescale=None,\n               show_progress=False,\n               return_intermediate=None,\n               intermediate_callback=None,\n               reverse_scale = -1.,\n               x = None,\n               **kwargs):\n        assert isinstance(steps, (int, torch.LongTensor))\n        assert return_intermediate in (None, 'x0', 'xt')\n        assert isinstance(sampler, (str, dict, Config))\n        intermediates = []\n\n        def callback_fn(x_t, t, sigma=None, alpha_bar=None):\n            timestamp = t\n            t = t.repeat(len(x_t)).round().long().to(x_t.device)\n            sigma = sigma.repeat(len(x_t), *([1] * (len(sigma.shape) - 1)))\n            alpha_bar = alpha_bar.repeat(len(x_t), *([1] * (len(alpha_bar.shape) - 1)))\n\n            if guide_scale is None or guide_scale == 1.0:\n                out = model(x=x_t, t=t, **model_kwargs)\n            else:\n                if use_dynamic_cfg:\n                    guidance_scale = 1 + guide_scale * (\n                        (1 - math.cos(math.pi * (\n                            (steps - timestamp.item()) / steps)**5.0)) / 2)\n                else:\n               
     guidance_scale = guide_scale\n                y_out = model(x=x_t, t=t, **model_kwargs[0])\n                u_out = model(x=x_t, t=t, **model_kwargs[1])\n                out = u_out + guidance_scale * (y_out - u_out)\n            if (guide_scale is not None and guide_scale != 1.0\n                    and guide_rescale is not None and guide_rescale > 0.0):\n                ratio = (\n                    y_out.flatten(1).std(dim=1) /\n                    (out.flatten(1).std(dim=1) + 1e-12)).view((-1, ) + (1, ) *\n                                                              (y_out.ndim - 1))\n                out *= guide_rescale * ratio + (1 - guide_rescale) * 1.0\n\n            if self.prediction_type == 'x0':\n                x0 = out\n            elif self.prediction_type == 'eps':\n                x0 = (x_t - sigma * out) / alpha_bar\n            elif self.prediction_type == 'v':\n                x0 = alpha_bar * x_t - sigma * out\n            else:\n                raise NotImplementedError(\n                    f'prediction_type {self.prediction_type} not implemented')\n            return x0\n\n        sampler_ins = self.get_sampler(sampler)\n\n        sampler_output = sampler_ins.preprare_sampler(\n            noise,\n            x=x,\n            steps=steps,\n            reverse_scale=reverse_scale,\n            prediction_type=self.prediction_type,\n            scheduler_ins=self.sampler_scheduler,\n            callback_fn=callback_fn)\n\n        pbar = trange(sampler_output.steps, disable=not show_progress)\n        for _ in pbar:\n            pbar.set_description(sampler_output.msg)\n            sampler_output = sampler_ins.step(sampler_output)\n            if return_intermediate == 'x0':\n                intermediates.append(sampler_output.x_0)\n            elif return_intermediate == 'xt':\n                intermediates.append(sampler_output.x_t)\n            if intermediate_callback is not None:\n                intermediate_callback(intermediates[-1])\n        return (sampler_output.x_0, intermediates\n                ) if return_intermediate is not None else sampler_output.x_0\n\n    def loss(self,\n             x_0,\n             model,\n             model_kwargs={},\n             reduction='mean',\n             noise=None,\n             **kwargs):\n        # use noise scheduler to add noise\n        if noise is None:\n            noise = torch.randn_like(x_0)\n        schedule_output = self.noise_scheduler.add_noise(x_0, noise, **kwargs)\n        x_t, t, sigma, alpha_bar = schedule_output.x_t, schedule_output.t, schedule_output.sigma, schedule_output.alpha_bar\n        out = model(x=x_t, t=t, **model_kwargs)\n\n        # mse loss\n        target = {\n            'eps': noise,\n            'x0': x_0,\n            'v': alpha_bar * noise - sigma * x_0\n        }[self.prediction_type]\n\n        loss = (out - target).pow(2)\n        if reduction == 'mean':\n            loss = loss.flatten(1).mean(dim=1)\n        return loss\n\n    def get_sampler(self, sampler):\n        if isinstance(sampler, str):\n            if sampler not in DIFFUSION_SAMPLERS.class_map:\n                if self.logger is not None:\n                    self.logger.info(\n                        f'{sampler} not in the defined samplers list {DIFFUSION_SAMPLERS.class_map.keys()}'\n                    )\n                else:\n                    print(\n                        f'{sampler} not in the defined samplers list {DIFFUSION_SAMPLERS.class_map.keys()}'\n                    )\n                return None\n            sampler_cfg = Config(cfg_dict={'NAME': 
sampler}, load=False)\n            sampler_ins = DIFFUSION_SAMPLERS.build(sampler_cfg,\n                                                   logger=self.logger)\n        elif isinstance(sampler, (Config, dict, OrderedDict)):\n            if isinstance(sampler, (dict, OrderedDict)):\n                sampler = Config(\n                    cfg_dict={k.upper(): v\n                              for k, v in dict(sampler).items()},\n                    load=False)\n            sampler_ins = DIFFUSION_SAMPLERS.build(sampler, logger=self.logger)\n        else:\n            raise NotImplementedError\n        return sampler_ins\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}' + ' ' + super().__repr__()\n\n    @staticmethod\n    def get_config_template():\n        return dict_to_yaml('DIFFUSIONS',\n                            __class__.__name__,\n                            ACEDiffusion.para_dict,\n                            set_name=True)"
  },
  {
    "path": "modules/model/diffusion/samplers.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (c) Alibaba, Inc. and its affiliates.\nimport torch\n\nfrom scepter.modules.model.registry import DIFFUSION_SAMPLERS\nfrom scepter.modules.model.diffusion.samplers import BaseDiffusionSampler\nfrom scepter.modules.model.diffusion.util import _i\n\ndef _i(tensor, t, x):\n    \"\"\"\n    Index tensor using t and format the output according to x.\n    \"\"\"\n    shape = (x.size(0), ) + (1, ) * (x.ndim - 1)\n    if isinstance(t, torch.Tensor):\n        t = t.to(tensor.device)\n    return tensor[t].view(shape).to(x.device)\n\n\n@DIFFUSION_SAMPLERS.register_class('ddim')\nclass DDIMSampler(BaseDiffusionSampler):\n    def init_params(self):\n        super().init_params()\n        self.eta = self.cfg.get('ETA', 0.)\n        self.discretization_type = self.cfg.get('DISCRETIZATION_TYPE',\n                                                'trailing')\n\n    def preprare_sampler(self,\n                         noise,\n                         x=None,\n                         steps=20,\n                         reverse_scale = -1.,\n                         scheduler_ins=None,\n                         prediction_type='',\n                         sigmas=None,\n                         betas=None,\n                         alphas=None,\n                         alphas_bar=None,\n                         callback_fn=None,\n                         **kwargs):\n        output = super().preprare_sampler(noise,\n                                          x = x,\n                                          steps = steps,\n                                          reverse_scale = reverse_scale,\n                                          scheduler_ins = scheduler_ins,\n                                          prediction_type = prediction_type,\n                                          sigmas = sigmas,\n                                          betas = betas,\n                                          alphas = alphas,\n                                          alphas_bar = alphas_bar,\n                                          callback_fn = callback_fn,\n                                          **kwargs)\n        sigmas = output.sigmas\n        sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])\n        sigmas_vp = (sigmas**2 / (1 + sigmas**2))**0.5\n        sigmas_vp[sigmas == float('inf')] = 1.\n        output.add_custom_field('sigmas_vp', sigmas_vp)\n        output.steps += 1\n        return output\n\n    def step(self, sampler_output):\n        x_t = sampler_output.x_t\n        step = sampler_output.step\n        t = sampler_output.ts[step]\n        sigmas_vp = sampler_output.sigmas_vp.to(x_t.device)\n        alpha_bar_init = _i(sampler_output.alphas_bar_init, step, x_t[:1])\n        sigma_init = _i(sampler_output.sigmas_init, step, x_t[:1])\n\n        x = sampler_output.callback_fn(x_t, t, sigma_init, alpha_bar_init)\n        noise_factor = self.eta * (sigmas_vp[step + 1]**2 /\n                                   sigmas_vp[step]**2 *\n                                   (1 - (1 - sigmas_vp[step]**2) /\n                                    (1 - sigmas_vp[step + 1]**2)))\n        d = (x_t - (1 - sigmas_vp[step]**2)**0.5 * x) / sigmas_vp[step]\n        x = (1 - sigmas_vp[step + 1] ** 2) ** 0.5 * x + \\\n            (sigmas_vp[step + 1] ** 2 - noise_factor ** 2) ** 0.5 * d\n        sampler_output.x_0 = x\n        if sigmas_vp[step + 1] > 0:\n            x += noise_factor * torch.randn_like(x)\n        sampler_output.x_t = x\n        sampler_output.step += 1\n  
      sampler_output.msg = f'step {step}'\n        return sampler_output"
  },
  {
    "path": "modules/model/diffusion/schedules.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (c) Alibaba, Inc. and its affiliates.\nimport torch\nfrom dataclasses import dataclass, field\nfrom scepter.modules.model.registry import NOISE_SCHEDULERS\nfrom scepter.modules.model.diffusion.schedules import BaseNoiseScheduler\nfrom scepter.modules.model.diffusion.util import _i\n\n@dataclass\nclass ScheduleOutput(object):\n    x_t: torch.Tensor\n    x_0: torch.Tensor\n    t: torch.Tensor\n    sigma: torch.Tensor\n    alpha_bar: torch.Tensor\n    custom_fields: dict = field(default_factory=dict)\n\n    def add_custom_field(self, key: str, value) -> None:\n        self.__setattr__(key, value)\n\n\n@NOISE_SCHEDULERS.register_class()\nclass LinearScheduler(BaseNoiseScheduler):\n    para_dict = {}\n\n    def init_params(self):\n        super().init_params()\n        self.beta_min = self.cfg.get('BETA_MIN', 0.00085)\n        self.beta_max = self.cfg.get('BETA_MAX', 0.012)\n\n    def betas_to_sigmas(self, betas):\n        return torch.sqrt(1 - torch.cumprod(1 - betas, dim=0))\n\n    def get_schedule(self):\n        betas = torch.linspace(self.beta_min,\n                               self.beta_max,\n                               self.num_timesteps,\n                               dtype=torch.float32)\n        sigmas = self.betas_to_sigmas(betas)\n        self._sigmas = sigmas\n        self._betas = betas\n        # cumulative signal scale sqrt(cumprod(1 - betas)) of the VP forward process\n        self._alphas = torch.sqrt(1 - sigmas**2)\n        self._alphas_bar = torch.sqrt(1 - sigmas**2)\n        self._timesteps = torch.arange(len(sigmas), dtype=torch.float32)\n\n    def add_noise(self, x_0, noise=None, t=None, **kwargs):\n        if noise is None:\n            noise = torch.randn_like(x_0)\n        if t is None:\n            t = torch.randint(0,\n                              self.num_timesteps, (x_0.shape[0], ),\n                              device=x_0.device).long()\n        alpha = _i(self.alphas, t, x_0)\n        sigma = _i(self.sigmas, t, x_0)\n        x_t = alpha * x_0 + sigma * noise\n\n        return ScheduleOutput(x_0=x_0, x_t=x_t, t=t, alpha_bar=alpha, sigma=sigma)"
  },
  {
    "path": "modules/model/embedder/__init__.py",
    "content": "from .embedder import ACETextEmbedder"
  },
  {
    "path": "modules/model/embedder/embedder.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (c) Alibaba, Inc. and its affiliates.\nimport warnings\nfrom contextlib import nullcontext\n\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.dlpack\nfrom scepter.modules.model.embedder.base_embedder import BaseEmbedder\nfrom scepter.modules.model.registry import EMBEDDERS\nfrom scepter.modules.model.tokenizer.tokenizer_component import (\n    basic_clean, canonicalize, heavy_clean, whitespace_clean)\nfrom scepter.modules.utils.config import dict_to_yaml\nfrom scepter.modules.utils.distribute import we\nfrom scepter.modules.utils.file_system import FS\n\ntry:\n    from transformers import AutoTokenizer, T5EncoderModel\nexcept Exception as e:\n    warnings.warn(\n        f'Import transformers error, please deal with this problem: {e}')\n\n\n@EMBEDDERS.register_class()\nclass ACETextEmbedder(BaseEmbedder):\n    \"\"\"\n    Uses the OpenCLIP transformer encoder for text\n    \"\"\"\n    \"\"\"\n        Uses the OpenCLIP transformer encoder for text\n        \"\"\"\n    para_dict = {\n        'PRETRAINED_MODEL': {\n            'value':\n            'google/umt5-small',\n            'description':\n            'Pretrained Model for umt5, modelcard path or local path.'\n        },\n        'TOKENIZER_PATH': {\n            'value': 'google/umt5-small',\n            'description':\n            'Tokenizer Path for umt5, modelcard path or local path.'\n        },\n        'FREEZE': {\n            'value': True,\n            'description': ''\n        },\n        'USE_GRAD': {\n            'value': False,\n            'description': 'Compute grad or not.'\n        },\n        'CLEAN': {\n            'value':\n            'whitespace',\n            'description':\n            'Set the clean strtegy for tokenizer, used when TOKENIZER_PATH is not None.'\n        },\n        'LAYER': {\n            'value': 'last',\n            'description': ''\n        },\n        'LEGACY': {\n            'value':\n            True,\n            'description':\n            'Whether use legacy returnd feature or not ,default True.'\n        }\n    }\n\n    def __init__(self, cfg, logger=None):\n        super().__init__(cfg, logger=logger)\n        pretrained_path = cfg.get('PRETRAINED_MODEL', None)\n        self.t5_dtype = cfg.get('T5_DTYPE', 'float32')\n        assert pretrained_path\n        with FS.get_dir_to_local_dir(pretrained_path,\n                                     wait_finish=True) as local_path:\n            self.model = T5EncoderModel.from_pretrained(\n                local_path,\n                torch_dtype=getattr(\n                    torch,\n                    'float' if self.t5_dtype == 'float32' else self.t5_dtype))\n        tokenizer_path = cfg.get('TOKENIZER_PATH', None)\n        self.length = cfg.get('LENGTH', 77)\n\n        self.use_grad = cfg.get('USE_GRAD', False)\n        self.clean = cfg.get('CLEAN', 'whitespace')\n        self.added_identifier = cfg.get('ADDED_IDENTIFIER', None)\n        if tokenizer_path:\n            self.tokenize_kargs = {'return_tensors': 'pt'}\n            with FS.get_dir_to_local_dir(tokenizer_path,\n                                         wait_finish=True) as local_path:\n                if self.added_identifier is not None and isinstance(\n                        self.added_identifier, list):\n                    self.tokenizer = AutoTokenizer.from_pretrained(local_path)\n                else:\n                    self.tokenizer = AutoTokenizer.from_pretrained(local_path)\n            if 
            if self.length is not None:\n                self.tokenize_kargs.update({\n                    'padding': 'max_length',\n                    'truncation': True,\n                    'max_length': self.length\n                })\n            self.eos_token = self.tokenizer(\n                self.tokenizer.eos_token)['input_ids'][0]\n        else:\n            self.tokenizer = None\n            self.tokenize_kargs = {}\n\n    def freeze(self):\n        self.model = self.model.eval()\n        for param in self.parameters():\n            param.requires_grad = False\n\n    # encode && encode_text\n    def forward(self, tokens, return_mask=False, use_mask=True):\n        # tokenization\n        embedding_context = nullcontext if self.use_grad else torch.no_grad\n        with embedding_context():\n            if use_mask:\n                x = self.model(tokens.input_ids.to(we.device_id),\n                               tokens.attention_mask.to(we.device_id))\n            else:\n                x = self.model(tokens.input_ids.to(we.device_id))\n            x = x.last_hidden_state\n\n            if return_mask:\n                return x.detach() + 0.0, tokens.attention_mask.to(we.device_id)\n            else:\n                return x.detach() + 0.0, None\n\n    def _clean(self, text):\n        if self.clean == 'whitespace':\n            text = whitespace_clean(basic_clean(text))\n        elif self.clean == 'lower':\n            text = whitespace_clean(basic_clean(text)).lower()\n        elif self.clean == 'canonicalize':\n            text = canonicalize(basic_clean(text))\n        elif self.clean == 'heavy':\n            text = heavy_clean(basic_clean(text))\n        return text\n\n    def encode(self, text, return_mask=False, use_mask=True):\n        if isinstance(text, str):\n            text = [text]\n        if self.clean:\n            text = [self._clean(u) for u in text]\n        assert self.tokenizer is not None\n        cont, mask = [], []\n        with torch.autocast(device_type='cuda',\n                            enabled=self.t5_dtype in ('float16', 'bfloat16'),\n                            dtype=getattr(torch, self.t5_dtype)):\n            for tt in text:\n                tokens = self.tokenizer([tt], **self.tokenize_kargs)\n                one_cont, one_mask = self(tokens,\n                                          return_mask=return_mask,\n                                          use_mask=use_mask)\n                cont.append(one_cont)\n                mask.append(one_mask)\n        if return_mask:\n            return torch.cat(cont, dim=0), torch.cat(mask, dim=0)\n        else:\n            return torch.cat(cont, dim=0)\n\n    def encode_list_of_list(self, text_list, return_mask=True):\n        \"\"\"Encode a list whose items are themselves lists of prompts (one per sample).\"\"\"\n        cont_list = []\n        mask_list = []\n        for pp in text_list:\n            cont, cont_mask = self.encode(pp, return_mask=return_mask)\n            cont_list.append(cont)\n            mask_list.append(cont_mask)\n        if return_mask:\n            return cont_list, mask_list\n        else:\n            return cont_list\n\n    @staticmethod\n    def get_config_template():\n        return dict_to_yaml('MODELS',\n                            __class__.__name__,\n                            ACETextEmbedder.para_dict,\n                            set_name=True)"
  },
  {
    "path": "modules/model/network/__init__.py",
    "content": "from .ldm_ace import LdmACE"
  },
  {
    "path": "modules/model/network/ldm_ace.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (c) Alibaba, Inc. and its affiliates.\nimport copy\nimport random\nfrom contextlib import nullcontext\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nfrom scepter.modules.model.network.ldm import LatentDiffusion\nfrom scepter.modules.model.registry import MODELS\nimport torchvision.transforms as T\nfrom scepter.modules.model.utils.basic_utils import check_list_of_list\nfrom scepter.modules.model.utils.basic_utils import \\\n    pack_imagelist_into_tensor_v2 as pack_imagelist_into_tensor\nfrom scepter.modules.model.utils.basic_utils import (\n    to_device, unpack_tensor_into_imagelist)\nfrom scepter.modules.utils.config import dict_to_yaml\nfrom scepter.modules.utils.distribute import we\n\n\nclass TextEmbedding(nn.Module):\n    def __init__(self, embedding_shape):\n        super().__init__()\n        self.pos = nn.Parameter(data=torch.zeros(embedding_shape))\n\n\n@MODELS.register_class()\nclass LdmACE(LatentDiffusion):\n    para_dict = LatentDiffusion.para_dict\n    para_dict['DECODER_BIAS'] = {'value': 0, 'description': ''}\n\n    def __init__(self, cfg, logger=None):\n        super().__init__(cfg, logger=logger)\n        self.interpolate_func = lambda x: (F.interpolate(\n            x.unsqueeze(0),\n            scale_factor=1 / self.size_factor,\n            mode='nearest-exact') if x is not None else None)\n\n        self.text_indentifers = cfg.get('TEXT_IDENTIFIER', [])\n        self.use_text_pos_embeddings = cfg.get('USE_TEXT_POS_EMBEDDINGS',\n                                               False)\n        if self.use_text_pos_embeddings:\n            self.text_position_embeddings = TextEmbedding(\n                (10, 4096)).eval().requires_grad_(False)\n        else:\n            self.text_position_embeddings = None\n\n        self.logger.info(self.model)\n\n    @torch.no_grad()\n    def encode_first_stage(self, x, **kwargs):\n        return [\n            self.scale_factor *\n            self.first_stage_model._encode(i.unsqueeze(0).to(torch.float16))\n            for i in x\n        ]\n\n    @torch.no_grad()\n    def decode_first_stage(self, z):\n        return [\n            self.first_stage_model._decode(1. 
/ self.scale_factor *\n                                           i.to(torch.float16)) for i in z\n        ]\n\n    def cond_stage_embeddings(self, prompt, edit_image, cont, cont_mask):\n        if self.use_text_pos_embeddings and not torch.sum(\n                self.text_position_embeddings.pos) > 0:\n            identifier_cont, identifier_cont_mask = getattr(\n                self.cond_stage_model, 'encode_list_of_list')(self.text_indentifers,\n                                                 return_mask=True)\n            self.text_position_embeddings.load_state_dict(\n                {'pos': torch.cat( [one_id[0][0, :].unsqueeze(0) for one_id in identifier_cont], dim=0)})\n        cont_, cont_mask_ = [], []\n        for pp, edit, c, cm in zip(prompt, edit_image, cont, cont_mask):\n            if isinstance(pp, list):\n                cont_.append([c[-1], *c] if len(edit) > 0 else [c[-1]])\n                cont_mask_.append([cm[-1], *cm] if len(edit) > 0 else [cm[-1]])\n            else:\n                raise NotImplementedError\n\n        return cont_, cont_mask_\n\n    def limit_batch_data(self, batch_data_list, log_num):\n        if log_num and log_num > 0:\n            batch_data_list_limited = []\n            for sub_data in batch_data_list:\n                if sub_data is not None:\n                    sub_data = sub_data[:log_num]\n                batch_data_list_limited.append(sub_data)\n            return batch_data_list_limited\n        else:\n            return batch_data_list\n\n    def forward_train(self,\n                      edit_image=[],\n                      edit_image_mask=[],\n                      image=None,\n                      image_mask=None,\n                      noise=None,\n                      prompt=[],\n                      **kwargs):\n        '''\n        Args:\n            edit_image: list of list of edit_image\n            edit_image_mask: list of list of edit_image_mask\n            image: target image\n            image_mask: target image mask\n            noise: default is None, generated automatically\n            prompt: list of list of text\n            **kwargs:\n        Returns:\n        '''\n        assert check_list_of_list(prompt) and check_list_of_list(\n            edit_image) and check_list_of_list(edit_image_mask)\n        assert len(edit_image) == len(edit_image_mask) == len(prompt)\n        assert self.cond_stage_model is not None\n        gc_seg = kwargs.pop('gc_seg', [])\n        gc_seg = int(gc_seg[0]) if len(gc_seg) > 0 else 0\n        context = {}\n\n        # process image\n        image = to_device(image)\n        x_start = self.encode_first_stage(image, **kwargs)\n        x_start, x_shapes = pack_imagelist_into_tensor(x_start)  # B, C, L\n        n, _, _ = x_start.shape\n        t = torch.randint(0, self.num_timesteps, (n, ),\n                          device=x_start.device).long()\n        context['x_shapes'] = x_shapes\n\n        # process image mask\n        image_mask = to_device(image_mask, strict=False)\n        context['x_mask'] = [self.interpolate_func(i) for i in image_mask\n                             ] if image_mask is not None else [None] * n\n\n        # process text\n        # with torch.autocast(device_type=\"cuda\", enabled=True, dtype=torch.bfloat16):\n        prompt_ = [[pp] if isinstance(pp, str) else pp for pp in prompt]\n
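        # Each sample carries a list of text instructions; encode_list_of_list returns\n        # one list of embeddings (and masks) per sample.\n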
        try:\n            cont, cont_mask = getattr(self.cond_stage_model,\n                                      'encode_list_of_list')(prompt_, return_mask=True)\n        except Exception as e:\n            print(e, prompt_)\n            raise\n        cont, cont_mask = self.cond_stage_embeddings(prompt, edit_image, cont,\n                                                     cont_mask)\n        context['crossattn'] = cont\n\n        # process edit image & edit image mask\n        edit_image = [to_device(i, strict=False) for i in edit_image]\n        edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]\n        e_img, e_mask = [], []\n        for u, m in zip(edit_image, edit_image_mask):\n            if m is None:\n                m = [None] * len(u) if u is not None else [None]\n            e_img.append(\n                self.encode_first_stage(u, **kwargs) if u is not None else u)\n            e_mask.append([\n                self.interpolate_func(i) if i is not None else None for i in m\n            ])\n        context['edit'], context['edit_mask'] = e_img, e_mask\n\n        # process loss\n        loss = self.diffusion.loss(\n            x_0=x_start,\n            t=t,\n            noise=noise,\n            model=self.model,\n            model_kwargs={\n                'cond':\n                context,\n                'mask':\n                cont_mask,\n                'gc_seg':\n                gc_seg,\n                'text_position_embeddings':\n                self.text_position_embeddings.pos if hasattr(\n                    self.text_position_embeddings, 'pos') else None\n            },\n            **kwargs)\n        loss = loss.mean()\n        ret = {'loss': loss, 'probe_data': {'prompt': prompt}}\n        return ret\n\n    @torch.no_grad()\n    def forward_test(self,\n                     edit_image=[],\n                     edit_image_mask=[],\n                     image=None,\n                     image_mask=None,\n                     prompt=[],\n                     n_prompt=[],\n                     sampler='ddim',\n                     sample_steps=20,\n                     guide_scale=4.5,\n                     guide_rescale=0.5,\n                     log_num=-1,\n                     seed=2024,\n                     **kwargs):\n\n        assert check_list_of_list(prompt) and check_list_of_list(\n            edit_image) and check_list_of_list(edit_image_mask)\n        assert len(edit_image) == len(edit_image_mask) == len(prompt)\n        assert self.cond_stage_model is not None\n        # gc_seg is unused\n        kwargs.pop('gc_seg', -1)\n        # prepare data\n        context, null_context = {}, {}\n\n        prompt, n_prompt, image, image_mask, edit_image, edit_image_mask = self.limit_batch_data(\n            [prompt, n_prompt, image, image_mask, edit_image, edit_image_mask],\n            log_num)\n        g = torch.Generator(device=we.device_id)\n        seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)\n        g.manual_seed(seed)\n        n_prompt = copy.deepcopy(prompt)\n        # only modify the last prompt to be zero\n        for nn_p_id, nn_p in enumerate(n_prompt):\n            if isinstance(nn_p, str):\n                n_prompt[nn_p_id] = ['']\n            elif isinstance(nn_p, list):\n                n_prompt[nn_p_id][-1] = ''\n            else:\n                raise NotImplementedError\n        # process image\n        image = to_device(image)\n        x = self.encode_first_stage(image, **kwargs)\n        noise = [\n            torch.empty(*i.shape, device=we.device_id).normal_(generator=g)\n            for i in x\n        ]\n        noise, x_shapes = pack_imagelist_into_tensor(noise)\n
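        # The conditional and unconditional (classifier-free guidance) branches share\n        # latent shapes, masks and edit-image latents; only the text conditioning differs.\n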
        context['x_shapes'] = null_context['x_shapes'] = x_shapes\n\n        # process image mask\n        image_mask = to_device(image_mask, strict=False)\n        cond_mask = [self.interpolate_func(i) for i in image_mask\n                     ] if image_mask is not None else [None] * len(image)\n        context['x_mask'] = null_context['x_mask'] = cond_mask\n        # process text\n        # with torch.autocast(device_type=\"cuda\", enabled=True, dtype=torch.bfloat16):\n        prompt_ = [[pp] if isinstance(pp, str) else pp for pp in prompt]\n        cont, cont_mask = getattr(self.cond_stage_model,\n                                  'encode_list_of_list')(prompt_, return_mask=True)\n        cont, cont_mask = self.cond_stage_embeddings(prompt, edit_image, cont,\n                                                     cont_mask)\n        null_cont, null_cont_mask = getattr(self.cond_stage_model,\n                                            'encode_list_of_list')(n_prompt,\n                                                           return_mask=True)\n        null_cont, null_cont_mask = self.cond_stage_embeddings(\n            prompt, edit_image, null_cont, null_cont_mask)\n        context['crossattn'] = cont\n        null_context['crossattn'] = null_cont\n\n        # process edit image & edit image mask\n        edit_image = [to_device(i, strict=False) for i in edit_image]\n        edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]\n        e_img, e_mask = [], []\n        for u, m in zip(edit_image, edit_image_mask):\n            if u is None:\n                continue\n            if m is None:\n                m = [None] * len(u)\n            e_img.append(self.encode_first_stage(u, **kwargs))\n            e_mask.append([self.interpolate_func(i) for i in m])\n        null_context['edit'] = context['edit'] = e_img\n        null_context['edit_mask'] = context['edit_mask'] = e_mask\n\n        # process sample\n        model = self.model_ema if self.use_ema and self.eval_ema else self.model\n        embedding_context = model.no_sync if isinstance(model, torch.distributed.fsdp.FullyShardedDataParallel) \\\n            else nullcontext\n        with embedding_context():\n            samples = self.diffusion.sample(\n                sampler=sampler,\n                noise=noise,\n                model=model,\n                model_kwargs=[{\n                    'cond':\n                    context,\n                    'mask':\n                    cont_mask,\n                    'text_position_embeddings':\n                    self.text_position_embeddings.pos if hasattr(\n                        self.text_position_embeddings, 'pos') else None\n                }, {\n                    'cond':\n                    null_context,\n                    'mask':\n                    null_cont_mask,\n                    'text_position_embeddings':\n                    self.text_position_embeddings.pos if hasattr(\n                        self.text_position_embeddings, 'pos') else None\n                }] if guide_scale is not None and guide_scale > 1 else {\n                    'cond':\n                    context,\n                    'mask':\n                    cont_mask,\n                    'text_position_embeddings':\n                    self.text_position_embeddings.pos if hasattr(\n                        self.text_position_embeddings, 'pos') else None\n                },\n                steps=sample_steps,\n                guide_scale=guide_scale,\n                guide_rescale=guide_rescale,\n           
     show_progress=True,\n                **kwargs)\n\n        samples = unpack_tensor_into_imagelist(samples, x_shapes)\n        x_samples = self.decode_first_stage(samples)\n        outputs = list()\n        for i in range(len(prompt)):\n            rec_img = torch.clamp(\n                (x_samples[i] + 1.0) / 2.0 + self.decoder_bias / 255,\n                min=0.0,\n                max=1.0)\n            rec_img = rec_img.squeeze(0)\n            edit_imgs, edit_img_masks = [], []\n            if edit_image is not None and edit_image[i] is not None:\n                if edit_image_mask[i] is None:\n                    edit_image_mask[i] = [None] * len(edit_image[i])\n                for edit_img, edit_mask in zip(edit_image[i],\n                                               edit_image_mask[i]):\n                    edit_img = torch.clamp((edit_img + 1.0) / 2.0,\n                                           min=0.0,\n                                           max=1.0)\n                    edit_imgs.append(edit_img.squeeze(0))\n                    if edit_mask is None:\n                        edit_mask = torch.ones_like(edit_img[[0], :, :])\n                    edit_img_masks.append(edit_mask)\n            one_tup = {\n                'reconstruct_image': rec_img,\n                'instruction': prompt[i],\n                'edit_image': edit_imgs if len(edit_imgs) > 0 else None,\n                'edit_mask': edit_img_masks if len(edit_imgs) > 0 else None\n            }\n            if image is not None:\n                if image_mask is None:\n                    image_mask = [None] * len(image)\n                ori_img = torch.clamp((image[i] + 1.0) / 2.0, min=0.0, max=1.0)\n                one_tup['target_image'] = ori_img.squeeze(0)\n                one_tup['target_mask'] = image_mask[i] if image_mask[\n                    i] is not None else torch.ones_like(ori_img[[0], :, :])\n            outputs.append(one_tup)\n        return outputs\n\n    @staticmethod\n    def get_config_template():\n        return dict_to_yaml('MODEL',\n                            __class__.__name__,\n                            LdmACE.para_dict,\n                            set_name=True)\n"
  },
  {
    "path": "modules/model/utils/basic_utils.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (c) Alibaba, Inc. and its affiliates.\nfrom inspect import isfunction\n\nimport torch\nfrom torch.nn.utils.rnn import pad_sequence\n\nfrom scepter.modules.utils.distribute import we\n\n\ndef exists(x):\n    return x is not None\n\n\ndef default(val, d):\n    if exists(val):\n        return val\n    return d() if isfunction(d) else d\n\n\ndef disabled_train(self, mode=True):\n    \"\"\"Overwrite model.train with this function to make sure train/eval mode\n    does not change anymore.\"\"\"\n    return self\n\n\ndef transfer_size(para_num):\n    if para_num > 1000 * 1000 * 1000 * 1000:\n        bill = para_num / (1000 * 1000 * 1000 * 1000)\n        return '{:.2f}T'.format(bill)\n    elif para_num > 1000 * 1000 * 1000:\n        gyte = para_num / (1000 * 1000 * 1000)\n        return '{:.2f}B'.format(gyte)\n    elif para_num > (1000 * 1000):\n        meta = para_num / (1000 * 1000)\n        return '{:.2f}M'.format(meta)\n    elif para_num > 1000:\n        kelo = para_num / 1000\n        return '{:.2f}K'.format(kelo)\n    else:\n        return para_num\n\n\ndef count_params(model):\n    total_params = sum(p.numel() for p in model.parameters())\n    return transfer_size(total_params)\n\n\ndef expand_dims_like(x, y):\n    while x.dim() != y.dim():\n        x = x.unsqueeze(-1)\n    return x\n\n\ndef unpack_tensor_into_imagelist(image_tensor, shapes):\n    image_list = []\n    for img, shape in zip(image_tensor, shapes):\n        h, w = shape[0], shape[1]\n        image_list.append(img[:, :h * w].view(1, -1, h, w))\n\n    return image_list\n\n\ndef find_example(tensor_list, image_list):\n    for i in tensor_list:\n        if isinstance(i, torch.Tensor):\n            return torch.zeros_like(i)\n    for i in image_list:\n        if isinstance(i, torch.Tensor):\n            _, c, h, w = i.size()\n            return torch.zeros_like(i.view(c, h * w).transpose(1, 0))\n    return None\n\n\ndef pack_imagelist_into_tensor_v2(image_list):\n    # allow None\n    example = None\n    image_tensor, shapes = [], []\n    for img in image_list:\n        if img is None:\n            example = find_example(image_tensor,\n                                   image_list) if example is None else example\n            image_tensor.append(example)\n            shapes.append(None)\n            continue\n        _, c, h, w = img.size()\n        image_tensor.append(img.view(c, h * w).transpose(1, 0))  # h*w, c\n        shapes.append((h, w))\n\n    image_tensor = pad_sequence(image_tensor,\n                                batch_first=True).permute(0, 2, 1)  # b, c, l\n    return image_tensor, shapes\n\n\ndef to_device(inputs, strict=True):\n    if inputs is None:\n        return None\n    if strict:\n        assert all(isinstance(i, torch.Tensor) for i in inputs)\n    return [i.to(we.device_id) if i is not None else None for i in inputs]\n\n\ndef check_list_of_list(ll):\n    return isinstance(ll, list) and all(isinstance(i, list) for i in ll)\n"
  },
  {
    "path": "modules/solver/__init__.py",
    "content": "from .ace_solver import ACESolverV1"
  },
  {
    "path": "modules/solver/ace_solver.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (c) Alibaba, Inc. and its affiliates.\nimport numpy as np\nimport torch\nfrom tqdm import tqdm\n\nfrom scepter.modules.utils.data import transfer_data_to_cuda\nfrom scepter.modules.utils.distribute import we\nfrom scepter.modules.utils.probe import ProbeData\nfrom scepter.modules.solver.registry import SOLVERS\nfrom scepter.modules.solver.diffusion_solver import LatentDiffusionSolver\n\n\n\n@SOLVERS.register_class()\nclass ACESolverV1(LatentDiffusionSolver):\n    def __init__(self, cfg, logger=None):\n        super().__init__(cfg, logger=logger)\n        self.log_train_num = cfg.get('LOG_TRAIN_NUM', -1)\n\n    def save_results(self, results):\n        log_data, log_label = [], []\n        for result in results:\n            ret_images, ret_labels = [], []\n            edit_image = result.get('edit_image', None)\n            edit_mask = result.get('edit_mask', None)\n            if edit_image is not None:\n                for i, edit_img in enumerate(result['edit_image']):\n                    if edit_img is None:\n                        continue\n                    ret_images.append(\n                        (edit_img.permute(1, 2, 0).cpu().numpy() * 255).astype(\n                            np.uint8))\n                    ret_labels.append(f'edit_image{i}; ')\n                    if edit_mask is not None:\n                        ret_images.append(\n                            (edit_mask[i].permute(1, 2, 0).cpu().numpy() *\n                             255).astype(np.uint8))\n                        ret_labels.append(f'edit_mask{i}; ')\n\n            target_image = result.get('target_image', None)\n            target_mask = result.get('target_mask', None)\n            if target_image is not None:\n                ret_images.append(\n                    (target_image.permute(1, 2, 0).cpu().numpy() * 255).astype(\n                        np.uint8))\n                ret_labels.append('target_image; ')\n                if target_mask is not None:\n                    ret_images.append(\n                        (target_mask.permute(1, 2, 0).cpu().numpy() *\n                         255).astype(np.uint8))\n                    ret_labels.append('target_mask; ')\n\n            reconstruct_image = result.get('reconstruct_image', None)\n            if reconstruct_image is not None:\n                ret_images.append(\n                    (reconstruct_image.permute(1, 2, 0).cpu().numpy() *\n                     255).astype(np.uint8))\n                ret_labels.append(f\"{result['instruction']}\")\n            log_data.append(ret_images)\n            log_label.append(ret_labels)\n        return log_data, log_label\n\n    @torch.no_grad()\n    def run_eval(self):\n        self.eval_mode()\n        self.before_all_iter(self.hooks_dict[self._mode])\n        all_results = []\n        for batch_idx, batch_data in tqdm(\n                enumerate(self.datas[self._mode].dataloader)):\n            self.before_iter(self.hooks_dict[self._mode])\n            if self.sample_args:\n                batch_data.update(self.sample_args.get_lowercase_dict())\n            with torch.autocast(device_type='cuda',\n                                enabled=self.use_amp,\n                                dtype=self.dtype):\n                results = self.run_step_eval(transfer_data_to_cuda(batch_data),\n                                             batch_idx,\n                                             step=self.total_iter,\n                                             
rank=we.rank)\n                all_results.extend(results)\n            self.after_iter(self.hooks_dict[self._mode])\n        log_data, log_label = self.save_results(all_results)\n        self.register_probe({'eval_label': log_label})\n        self.register_probe({\n            'eval_image':\n            ProbeData(log_data,\n                      is_image=True,\n                      build_html=True,\n                      build_label=log_label)\n        })\n        self.after_all_iter(self.hooks_dict[self._mode])\n\n    @torch.no_grad()\n    def run_test(self):\n        self.test_mode()\n        self.before_all_iter(self.hooks_dict[self._mode])\n        all_results = []\n        for batch_idx, batch_data in tqdm(\n                enumerate(self.datas[self._mode].dataloader)):\n            self.before_iter(self.hooks_dict[self._mode])\n            if self.sample_args:\n                batch_data.update(self.sample_args.get_lowercase_dict())\n            with torch.autocast(device_type='cuda',\n                                enabled=self.use_amp,\n                                dtype=self.dtype):\n                results = self.run_step_eval(transfer_data_to_cuda(batch_data),\n                                             batch_idx,\n                                             step=self.total_iter,\n                                             rank=we.rank)\n                all_results.extend(results)\n            self.after_iter(self.hooks_dict[self._mode])\n        log_data, log_label = self.save_results(all_results)\n        self.register_probe({'test_label': log_label})\n        self.register_probe({\n            'test_image':\n            ProbeData(log_data,\n                      is_image=True,\n                      build_html=True,\n                      build_label=log_label)\n        })\n\n        self.after_all_iter(self.hooks_dict[self._mode])\n\n    @property\n    def probe_data(self):\n        if not we.debug and self.mode == 'train':\n            batch_data = transfer_data_to_cuda(\n                self.current_batch_data[self.mode])\n            self.eval_mode()\n            with torch.autocast(device_type='cuda',\n                                enabled=self.use_amp,\n                                dtype=self.dtype):\n                batch_data['log_num'] = self.log_train_num\n                results = self.run_step_eval(batch_data)\n            self.train_mode()\n            log_data, log_label = self.save_results(results)\n            self.register_probe({\n                'train_image':\n                ProbeData(log_data,\n                          is_image=True,\n                          build_html=True,\n                          build_label=log_label)\n            })\n            self.register_probe({'train_label': log_label})\n        return super(LatentDiffusionSolver, self).probe_data\n"
  },
  {
    "path": "readme.md",
    "content": "<p align=\"center\">\n\n  <h2 align=\"center\"><img src=\"assets/figures/icon.png\" height=16> : All-round Creator and Editor Following <br> Instructions via Diffusion Transformer</h2>\n\n  <p align=\"center\">\n    <a href=\"https://arxiv.org/abs/2410.00086\"><img src='https://img.shields.io/badge/arXiv-ACE-red' alt='Paper PDF'></a>\n    <a href='https://ali-vilab.github.io/ace-page'><img src='https://img.shields.io/badge/Project_Page-ACE-blue' alt='Project Page'></a>\n    <a href='https://github.com/modelscope/scepter'><img src='https://img.shields.io/badge/Scepter-ACE-green'></a>\n    <a href='https://huggingface.co/spaces/scepter-studio/ACE-Chat'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Space-orange'></a>\n    <a href='https://huggingface.co/scepter-studio/ACE-0.6B-512px'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-orange'></a>\n    <a href='https://www.modelscope.cn/models/iic/ACE-0.6B-512px'><img src='https://img.shields.io/badge/ModelScope-Model-purple'></a>\n    <br>\n    <strong>Zhen Han*</strong>\n    ·\n    <strong>Zeyinzi Jiang*</strong>\n    ·\n    <strong>Yulin Pan*</strong>\n    ·\n    <strong>Jingfeng Zhang*</strong>\n    ·\n    <strong>Chaojie Mao*</strong>\n    <br>\n    <strong>Chenwei Xie</strong>\n    ·\n    <strong>Yu Liu</strong>\n    ·\n    <strong>Jingren Zhou</strong>\n    <br>\n    Tongyi Lab, Alibaba Group\n  </p>\n  <table align=\"center\">\n    <tr>\n    <td>\n      <img src=\"assets/figures/teaser.png\">\n    </td>\n    </tr>\n  </table>\n\n## 📢 News\n* **[2024.9.30]** Release the paper of ACE on arxiv.\n* **[2024.10.31]** Release the ACE checkpoint on [ModelScope](https://www.modelscope.cn/models/iic/ACE-0.6B-512px) and [HuggingFace](https://huggingface.co/scepter-studio/ACE-0.6B-512px).\n* **[2024.11.1]** Support online demo on [HuggingFace](https://huggingface.co/spaces/scepter-studio/ACE-Chat).\n* **[2024.11.20]** Release the [ACE-0.6b-1024px](https://huggingface.co/scepter-studio/ACE-0.6B-1024px) model, \nwhich significantly enhances image generation quality compared with [ACE-0.6b-512px](https://huggingface.co/scepter-studio/ACE-0.6B-512px).\n* **[2025.01.06]** Release the [ACE++](https://ali-vilab.github.io/ACE_plus_page/).\n\n\n## 🚀 Installation\nInstall the necessary packages with `pip`: \n```bash\npip install -r requirements.txt\n```\n\n##  🔥 ACE Models\n|    **Model**     |                                                                                                                                                                                                            **Status**                                                                                                                                                                                                             | \n|:----------------:|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|\n|  ACE-0.6B-512px  |          [![Demo link](https://img.shields.io/badge/Demo-ACE_Chat-purple)](https://huggingface.co/spaces/scepter-studio/ACE-Chat)<br>[![ModelScope 
link](https://img.shields.io/badge/ModelScope-Model-blue)](https://www.modelscope.cn/models/iic/ACE-0.6B-512px)  [![HuggingFace link](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-yellow)](https://huggingface.co/scepter-studio/ACE-0.6B-512px)          |\n| ACE-0.6B-1024px  | [![ModelScope link](https://img.shields.io/badge/ModelScope-Model-blue)](https://www.modelscope.cn/models/iic/ACE-0.6B-1024px)  [![HuggingFace link](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-yellow)](https://huggingface.co/scepter-studio/ACE-0.6B-1024px) |\n| ACE-12B-FLUX-dev | The ACE model based on the FLUX.1-dev base model has adopted a new adaptation method. We have organized a new project called [ACE++](https://ali-vilab.github.io/ACE_plus_page/), and the relevant models have been open-sourced. Please visit the project page to learn more. |\n\n## 🖼 Model Performance Visualization\n\nThe current ACE model has 0.6B parameters, which imposes certain limitations on image generation quality. [FLUX.1-Dev](https://huggingface.co/black-forest-labs/FLUX.1-dev), on the other hand, has a significant advantage in text-to-image generation quality. By using SDEdit, we can effectively leverage the generative capabilities of FLUX to further enhance the images generated by ACE. Based on these considerations, we have designed the ACE-Refiner pipeline, as shown in the diagram below.\n\n![ACE_REFINER](assets/ace_method/ace_refiner_process.webp)\n\nAs shown in the figure below, a high refinement strength σ causes the result to lose fidelity to the original ACE output, while a low σ does not noticeably improve image quality. Users can therefore trade off fidelity against image quality according to their needs by setting the value of \"REFINER_SCALE\" in the configuration file `config/inference_config/models/ace_0.6b_1024_refiner.yaml`. We recommend using the advanced options in the [webui-demo](#-chat-bot-) to verify the effect.\n\n![ACE_REFINER_EXAMPLE](assets/ace_method/ace_refiner.webp)\n\n\nWe compare the generation and editing performance of different models on several tasks, as shown below.\n![Samples](assets/ace_method/samples_compare.webp)\n\n\n## 🔥 Training\n\nWe offer a demonstration training YAML that enables end-to-end training of ACE on a toy dataset. For a comprehensive overview of the hyperparameter configuration, please consult `config/ace_0.6b_512_train.yaml`.\n\n### Prepare datasets\n\nThe dataset class is located in `modules/data/dataset/dataset.py` and is designed to facilitate end-to-end training on an open-source toy dataset.\nDownload a dataset zip file from [modelscope](https://www.modelscope.cn/models/iic/scepter/resolve/master/datasets/hed_pair.zip), and then extract its contents into the `cache/datasets/` directory.\n
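For example, assuming `wget` and `unzip` are available, the toy dataset can be fetched and unpacked like this:\n```bash\nmkdir -p cache/datasets\nwget -O cache/datasets/hed_pair.zip https://www.modelscope.cn/models/iic/scepter/resolve/master/datasets/hed_pair.zip\nunzip cache/datasets/hed_pair.zip -d cache/datasets/\n```\n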
\nShould you wish to prepare your own datasets, we recommend consulting `modules/data/dataset/dataset.py` for detailed guidance on the required data format.\n\n### Prepare initial weight\nThe ACE checkpoint has been uploaded to both ModelScope and HuggingFace:\n* [ModelScope](https://www.modelscope.cn/models/iic/ACE-0.6B-512px)\n* [HuggingFace](https://huggingface.co/scepter-studio/ACE-0.6B-512px)\n\nIn the provided training YAML configuration, we have designated the ModelScope URL as the default checkpoint URL. Should you wish to switch to Hugging Face, you can do so by modifying the PRETRAINED_MODEL value within the YAML file (replace the prefix \"ms://iic\" with \"hf://scepter-studio\").\n\n\n### Start training\n\nYou can start the training procedure by executing the following command:\n```bash\n# ACE-0.6B-512px\nPYTHONPATH=. python tools/run_train.py --cfg config/ace_0.6b_512_train.yaml\n# ACE-0.6B-1024px\nPYTHONPATH=. python tools/run_train.py --cfg config/ace_0.6b_1024_train.yaml\n```\n\n## 🚀 Inference\n\nWe provide a simple inference demo that allows users to generate images from text descriptions.\n```bash\nPYTHONPATH=. python tools/run_inference.py --cfg config/inference_config/models/ace_0.6b_512.yaml --instruction \"make the boy cry, his eyes filled with tears\" --seed 199999 --input_image examples/input_images/example0.webp\n```\nWe recommend running the examples for a quick test. The following command runs the example inference, and the results will be saved in `examples/output_images/`.\n```bash\nPYTHONPATH=. python tools/run_inference.py --cfg config/inference_config/models/ace_0.6b_512.yaml\n```\n\n## 💬 Chat Bot \nWe have developed a chatbot UI utilizing Gradio, designed to transform natural-language user input into visually stunning images that align semantically with the provided instructions. Users can launch the chatbot app by executing the following command:\n```bash\npython chatbot/run_gradio.py --cfg chatbot/config/chatbot_ui.yaml --server_port 2024\n```\n\n<table align=\"center\">\n  <tr>\n  <td>\n    <img src=\"assets/videos/demo_chat.gif\">\n  </td>\n  </tr>\n</table>\n\n## ⚙️️ ComfyUI Workflow\n\n![Workflow](assets/comfyui/ace_example.jpg)\n\nWe support the use of ACE in the ComfyUI workflow through the following methods:\n\n1) Automatic installation directly via the ComfyUI Manager by searching for the **ComfyUI-Scepter** node.\n2) Manual installation by copying the custom nodes from Scepter into ComfyUI:\n```shell\ngit clone https://github.com/modelscope/scepter.git\ncd path/to/scepter\npip install -e .\ncp -r path/to/scepter/workflow/ path/to/ComfyUI/custom_nodes/ComfyUI-Scepter\ncd path/to/ComfyUI\npython main.py\n```\n\n**Note**: You can use the nodes by dragging the sample images below into ComfyUI.\n
Additionally, our nodes can automatically pull models from ModelScope or HuggingFace by selecting the *model_source* field, or you can place the already downloaded models in a local path.\n\n<table><tbody>\n  <tr>\n    <th align=\"center\" colspan=\"4\">ACE Workflow Examples</th>\n  </tr>\n  <tr>\n    <th align=\"center\" colspan=\"1\">Control</th>\n    <th align=\"center\" colspan=\"1\">Semantic</th>\n    <th align=\"center\" colspan=\"1\">Element</th>\n  </tr>\n  <tr>\n    <td>\n      <a href=\"assets/comfyui/ace_control.png\" target=\"_blank\">\n        <img src=\"assets/comfyui/ace_control.png\" width=\"200\">\n      </a>\n    </td>\n    <td>\n      <a href=\"assets/comfyui/ace_semantic.png\" target=\"_blank\">\n        <img src=\"assets/comfyui/ace_semantic.png\" width=\"200\">\n      </a>\n    </td>\n    <td>\n      <a href=\"assets/comfyui/ace_element.png\" target=\"_blank\">\n        <img src=\"assets/comfyui/ace_element.png\" width=\"200\">\n      </a>\n    </td>\n  </tr>\n</tbody>\n</table>\n\n\n## 📝 Citation\n\n```bibtex\n@inproceedings{ICLR2025_ACE,\n title = {ACE: All-round Creator and Editor Following Instructions via Diffusion Transformer},\n author = {Han, Zhen and Jiang, Zeyinzi and Pan, Yulin and Zhang, Jingfeng and Mao, Chaojie and Xie, Chen-Wei and Liu, Yu and Zhou, Jingren},\n booktitle = {International Conference on Representation Learning},\n pages = {57096--57111},\n year = {2025}\n}\n```\n"
  },
  {
    "path": "requirements.txt",
    "content": "git+https://github.com/modelscope/scepter.git@v1.3.0_dev#egg=scepter\npycocotools\npyyaml>=5.3.1\nscikit-image\ntorchsde\ntransformers\nscikit-learn\nnumpy\nopencv-python\nopencv_transforms>=0.0.6\noss2>=2.15.0\neinops\ntorch==2.4.0\ntorchvision\nflash-attn==2.5.8\nbitsandbytes\ngradio==4.44.1\ngradio_imageslider\ndiffusers\naddict\ndatasets==3.0.1"
  },
  {
    "path": "tools/run_inference.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (c) Alibaba, Inc. and its affiliates.\nimport argparse\nimport importlib\nimport io\nimport os\nimport sys\nfrom PIL import Image\nfrom scepter.modules.utils.config import Config\nfrom scepter.modules.utils.file_system import FS\nif os.path.exists('__init__.py'):\n    package_name = 'scepter_ext'\n    spec = importlib.util.spec_from_file_location(package_name, '__init__.py')\n    package = importlib.util.module_from_spec(spec)\n    sys.modules[package_name] = package\n    spec.loader.exec_module(package)\n\nfrom chatbot.ace_inference import ACEInference\n\nfs_list = [\n    Config(cfg_dict={\"NAME\": \"HuggingfaceFs\", \"TEMP_DIR\": \"./cache\"}, load=False),\n    Config(cfg_dict={\"NAME\": \"ModelscopeFs\", \"TEMP_DIR\": \"./cache\"}, load=False),\n    Config(cfg_dict={\"NAME\": \"HttpFs\", \"TEMP_DIR\": \"./cache\"}, load=False),\n    Config(cfg_dict={\"NAME\": \"LocalFs\", \"TEMP_DIR\": \"./cache\"}, load=False),\n]\n\nfor one_fs in fs_list:\n    FS.init_fs_client(one_fs)\n\n\ndef run_one_case(pipe, input_image, input_mask, edit_k,\n                 instruction, negative_prompt, seed,\n                 output_h, output_w, save_path):\n    edit_image, edit_image_mask, edit_task = [], [], []\n    if input_image is not None:\n        image = Image.open(io.BytesIO(FS.get_object(input_image)))\n        edit_image.append(image.convert('RGB'))\n        edit_image_mask.append(\n            Image.open(Image.open(io.BytesIO(FS.get_object(input_mask)))).\n            convert('L') if input_mask is not None else None)\n        edit_task.append(edit_k)\n    imgs = pipe(\n        image=edit_image,\n        mask=edit_image_mask,\n        task=edit_task,\n        prompt=[instruction] *\n               len(edit_image) if edit_image is not None else [instruction],\n        negative_prompt=[negative_prompt] * len(edit_image)\n        if edit_image is not None else [negative_prompt],\n        output_height=output_h,\n        output_width=output_w,\n        sampler=pipe.input.get(\"sampler\", \"ddim\"),\n        sample_steps=pipe.input.get(\"sample_steps\", 20),\n        guide_scale=pipe.input.get(\"guide_scale\", 4.5),\n        guide_rescale=pipe.input.get(\"guide_rescale\", 0.5),\n        seed=seed,\n    )\n    with FS.put_to(save_path) as local_path:\n        imgs[0].save(local_path)\n    return\n\n\ndef run():\n    parser = argparse.ArgumentParser(description='Argparser for Scepter:\\n')\n    parser.add_argument('--instruction',\n                        dest='instruction',\n                        help='The instruction for editing or generating!',\n                        default=\"\")\n    parser.add_argument('--negative_prompt',\n                        dest='negative_prompt',\n                        help='The negative prompt for editing or generating!',\n                        default=\"\")\n    parser.add_argument('--output_h',\n                        dest='output_h',\n                        help='The height of output image for generation tasks!',\n                        type=int,\n                        default=None)\n    parser.add_argument('--output_w',\n                        dest='output_w',\n                        help='The width of output image for generation tasks!',\n                        type=int,\n                        default=None)\n    parser.add_argument('--input_image',\n                        dest='input_image',\n                        help='The input image!',\n                        default=None\n                      
  )\n    parser.add_argument('--input_mask',\n                        dest='input_mask',\n                        help='The input mask!',\n                        default=None\n                        )\n    parser.add_argument('--save_path',\n                        dest='save_path',\n                        help='The save path for output image!',\n                        default='examples/output_images/output.png'\n                        )\n    parser.add_argument('--seed',\n                        dest='seed',\n                        help='The seed for generation!',\n                        type=int,\n                        default=-1)\n    cfg = Config(load=True, parser_ins=parser)\n    pipe = ACEInference()\n    pipe.init_from_cfg(cfg)\n\n\n    output_h = cfg.args.output_h or pipe.input.get(\"output_height\", 1024)\n    output_w = cfg.args.output_w or pipe.input.get(\"output_width\", 1024)\n    negative_prompt = cfg.args.negative_prompt\n\n    if cfg.args.instruction == \"\" and cfg.args.input_image is None:\n        # run examples\n        all_examples = [\n            [\"examples/input_images/example0.webp\", None, \"\",\n             \"{image} make the boy cry, his eyes filled with tears\",\n             \"\", 199999, output_h, output_w, \"examples/output_images/example0.png\"],\n            [\"examples/input_images/example1.webp\", None, \"\",\n             \"{image}use the depth map @cb638863a0e9 and the text caption  \\\"Vincent van Gogh with expressive, \"\n             \"soulful eyes and a gentle smile, wearing traditional 19th-century artist's attire, including a \"\n             \"paint-streaked smock, a straw hat with sunflowers, and an artist's easel slung over his shoulder.\"\n             \"Subtle elements of \\\"Starry Night\\\" swirling around, with hints of sunflowers and wheat fields \"\n             \"from his famous paintings. Include a palette and paintbrushes, a small sun painted in the top \"\n             \"corner, and subtle curling patterns reminiscent of his brush strokes\\\" to create a image\",\n             \"\", 899999, output_h, output_w, \"examples/output_images/example1.png\"],\n            [\"examples/input_images/example2.webp\", None, \"\",\n             \"make this {image} colorful\",\n             \"\", 199999, output_h, output_w, \"examples/output_images/example2.png\"],\n            [\"examples/input_images/example3.webp\", None, \"\",\n             \"change the style to 3D cartoon style\",\n             \"\", 2023, output_h, output_w, \"examples/output_images/example3.png\"],\n\n        ]\n        for example in all_examples:\n            run_one_case(pipe, example[0], example[1], example[2], example[3],\n                         example[4], example[5], example[6], example[7], example[8])\n    else:\n        if \"{image}\" not in cfg.args.instruction:\n            instruction = \"{image} \" + cfg.args.instruction\n        else:\n            instruction = cfg.args.instruction\n\n        run_one_case(pipe, cfg.args.input_image, cfg.args.input_mask, \"\",\n                 instruction, negative_prompt, cfg.args.seed,\n                 output_h, output_w, cfg.args.save_path)\n\nif __name__ == '__main__':\n    run()\n\n"
  },
  {
    "path": "tools/run_train.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (c) Alibaba, Inc. and its affiliates.\nimport argparse\nimport importlib\nimport os\nimport sys\n\nfrom scepter.modules.solver.registry import SOLVERS\nfrom scepter.modules.utils.config import Config\nfrom scepter.modules.utils.distribute import we\nfrom scepter.modules.utils.logger import get_logger\n\nif os.path.exists('__init__.py'):\n    package_name = 'scepter_ext'\n    spec = importlib.util.spec_from_file_location(package_name, '__init__.py')\n    package = importlib.util.module_from_spec(spec)\n    sys.modules[package_name] = package\n    spec.loader.exec_module(package)\n\n\ndef run_task(cfg):\n    std_logger = get_logger(name='scepter')\n    solver = SOLVERS.build(cfg.SOLVER, logger=std_logger)\n    solver.set_up_pre()\n    solver.set_up()\n    solver.solve()\n\n\ndef update_config(cfg):\n    if hasattr(cfg.args, 'learning_rate') and cfg.args.learning_rate:\n        if cfg.SOLVER.OPTIMIZER.get('LEARNING_RATE', None) is not None:\n            print(\n                f'learning_rate change from {cfg.SOLVER.OPTIMIZER.LEARNING_RATE} to {cfg.args.learning_rate}'\n            )\n        cfg.SOLVER.OPTIMIZER.LEARNING_RATE = float(cfg.args.learning_rate)\n    if hasattr(cfg.args, 'max_steps') and cfg.args.max_steps:\n        if cfg.SOLVER.get('MAX_STEPS', None) is not None:\n            print(\n                f'max_steps change from {cfg.SOLVER.MAX_STEPS} to {cfg.args.max_steps}'\n            )\n        cfg.SOLVER.MAX_STEPS = int(cfg.args.max_steps)\n    return cfg\n\n\ndef run():\n    parser = argparse.ArgumentParser(description='Argparser for Scepter:\\n')\n    parser.add_argument('--learning_rate',\n                        dest='learning_rate',\n                        help='The learning rate for our network!',\n                        default=None)\n    parser.add_argument('--max_steps',\n                        dest='max_steps',\n                        help='The max steps for training!',\n                        default=None)\n\n    cfg = Config(load=True, parser_ins=parser)\n    cfg = update_config(cfg)\n    we.init_env(cfg, logger=None, fn=run_task)\n\n\nif __name__ == '__main__':\n    run()\n"
  }
]