Repository: ali-vilab/ACE Branch: main Commit: 886bf9510b85 Files: 42 Total size: 269.5 KB Directory structure: gitextract_gx2sgdse/ ├── .gitignore ├── LICENSE ├── __init__.py ├── chatbot/ │ ├── __init__.py │ ├── ace_inference.py │ ├── config/ │ │ ├── chatbot_ui.yaml │ │ └── models/ │ │ └── ace_0.6b_512.yaml │ ├── example.py │ ├── infer.py │ ├── run_gradio.py │ └── utils.py ├── config/ │ ├── inference_config/ │ │ ├── chatbot_ui.yaml │ │ └── models/ │ │ ├── ace_0.6b_1024.yaml │ │ ├── ace_0.6b_1024_refiner.yaml │ │ └── ace_0.6b_512.yaml │ └── train_config/ │ ├── ace_0.6b_1024_train.yaml │ └── ace_0.6b_512_train.yaml ├── modules/ │ ├── __init__.py │ ├── data/ │ │ ├── __init__.py │ │ └── dataset/ │ │ ├── __init__.py │ │ └── dataset.py │ ├── inference/ │ │ └── __init__.py │ ├── model/ │ │ ├── __init__.py │ │ ├── backbone/ │ │ │ ├── __init__.py │ │ │ ├── ace.py │ │ │ ├── layers.py │ │ │ └── pos_embed.py │ │ ├── diffusion/ │ │ │ ├── __init__.py │ │ │ ├── diffusions.py │ │ │ ├── samplers.py │ │ │ └── schedules.py │ │ ├── embedder/ │ │ │ ├── __init__.py │ │ │ └── embedder.py │ │ ├── network/ │ │ │ ├── __init__.py │ │ │ └── ldm_ace.py │ │ └── utils/ │ │ └── basic_utils.py │ └── solver/ │ ├── __init__.py │ └── ace_solver.py ├── readme.md ├── requirements.txt └── tools/ ├── run_inference.py └── run_train.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ *.pyc *.pth *.pt *.pkl *.ckpt *.DS_Store *__pycache__* *.cache* *.bin *.idea *.csv cache build dist dev scepter.egg-info .readthedocs.yml *resources *.ipynb_checkpoints* *.vscode ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). 
"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: __init__.py ================================================ from . import modules from . import chatbot ================================================ FILE: chatbot/__init__.py ================================================ ================================================ FILE: chatbot/ace_inference.py ================================================ # -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
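"""Inference wrappers for the ACE chatbot UI.

ACEInference drives the main DiT editing model: it VAE-encodes the input
images (first_stage_model), encodes prompts with the T5-based
cond_stage_model, runs the diffusion sampler, and decodes the latents back to
pixels. RefinerInference optionally re-noises and re-samples the output at a
higher resolution (a partial-denoising pass controlled by reverse_scale).
With USE_DYNAMIC_MODEL enabled, each submodule is loaded/unloaded around the
stage that needs it to save GPU memory.

Minimal usage sketch, mirroring how run_gradio.py builds the pipeline
(untested; the yaml path and the PIL inputs `pil_image`/`pil_mask` are
illustrative assumptions, not shipped defaults):

    from scepter.modules.utils.config import Config

    cfg = Config(load=True,
                 cfg_file='chatbot/config/models/ace_0.6b_512.yaml')
    pipe = ACEInference()
    pipe.init_from_cfg(cfg)
    out_images = pipe(image=[pil_image], mask=[pil_mask], task=['inpainting'],
                      prompt='Refashion the mask portion of {image}',
                      sample_steps=20, guide_scale=4.5, seed=6666)
    out_images[0].save('result.png')  # __call__ returns a list of PIL images
"""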
import copy import math import random import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torchvision.transforms.functional as TF from PIL import Image import torchvision.transforms as T from scepter.modules.model.registry import DIFFUSIONS from scepter.modules.model.utils.basic_utils import check_list_of_list from scepter.modules.model.utils.basic_utils import \ pack_imagelist_into_tensor_v2 as pack_imagelist_into_tensor from scepter.modules.model.utils.basic_utils import ( to_device, unpack_tensor_into_imagelist) from scepter.modules.utils.distribute import we from scepter.modules.utils.logger import get_logger from scepter.modules.inference.diffusion_inference import DiffusionInference, get_model def process_edit_image(images, masks, tasks, max_seq_len=1024, max_aspect_ratio=4, d=16, **kwargs): if not isinstance(images, list): images = [images] if not isinstance(masks, list): masks = [masks] if not isinstance(tasks, list): tasks = [tasks] img_tensors = [] mask_tensors = [] for img, mask, task in zip(images, masks, tasks): if mask is None or mask == '': mask = Image.new('L', img.size, 0) W, H = img.size if H / W > max_aspect_ratio: img = TF.center_crop(img, [int(max_aspect_ratio * W), W]) mask = TF.center_crop(mask, [int(max_aspect_ratio * W), W]) elif W / H > max_aspect_ratio: img = TF.center_crop(img, [H, int(max_aspect_ratio * H)]) mask = TF.center_crop(mask, [H, int(max_aspect_ratio * H)]) H, W = img.height, img.width scale = min(1.0, math.sqrt(max_seq_len / ((H / d) * (W / d)))) rH = int(H * scale) // d * d # ensure divisible by d rW = int(W * scale) // d * d img = TF.resize(img, (rH, rW), interpolation=TF.InterpolationMode.BICUBIC) mask = TF.resize(mask, (rH, rW), interpolation=TF.InterpolationMode.NEAREST_EXACT) mask = np.asarray(mask) mask = np.where(mask > 128, 1, 0) mask = mask.astype( np.float32) if np.any(mask) else np.ones_like(mask).astype( np.float32) img_tensor = TF.to_tensor(img).to(we.device_id) img_tensor = TF.normalize(img_tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) mask_tensor = TF.to_tensor(mask).to(we.device_id) if task in ['inpainting', 'Try On', 'Inpainting']: mask_indicator = mask_tensor.repeat(3, 1, 1) img_tensor[mask_indicator == 1] = -1.0 img_tensors.append(img_tensor) mask_tensors.append(mask_tensor) return img_tensors, mask_tensors class TextEmbedding(nn.Module): def __init__(self, embedding_shape): super().__init__() self.pos = nn.Parameter(data=torch.zeros(embedding_shape)) class RefinerInference(DiffusionInference): def init_from_cfg(self, cfg): self.use_dynamic_model = cfg.get('USE_DYNAMIC_MODEL', True) super().init_from_cfg(cfg) self.diffusion = DIFFUSIONS.build(cfg.MODEL.DIFFUSION, logger=self.logger) \ if cfg.MODEL.have('DIFFUSION') else None self.max_seq_length = cfg.MODEL.get("MAX_SEQ_LENGTH", 4096) assert self.diffusion is not None if not self.use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model') self.dynamic_load(self.cond_stage_model, 'cond_stage_model') self.dynamic_load(self.diffusion_model, 'diffusion_model') @torch.no_grad() def encode_first_stage(self, x, **kwargs): _, dtype = self.get_function_info(self.first_stage_model, 'encode') with torch.autocast('cuda', enabled=dtype in ('float16', 'bfloat16'), dtype=getattr(torch, dtype)): def run_one_image(u): zu = get_model(self.first_stage_model).encode(u) if isinstance(zu, (tuple, list)): zu = zu[0] return zu z = [run_one_image(u.unsqueeze(0) if u.dim() == 3 else u) for u in x] return z def upscale_resize(self, image,
interpolation=T.InterpolationMode.BILINEAR): c, H, W = image.shape scale = max(1.0, math.sqrt(self.max_seq_length / ((H / 16) * (W / 16)))) rH = int(H * scale) // 16 * 16 # ensure divisible by self.d rW = int(W * scale) // 16 * 16 image = T.Resize((rH, rW), interpolation=interpolation, antialias=True)(image) return image @torch.no_grad() def decode_first_stage(self, z): _, dtype = self.get_function_info(self.first_stage_model, 'decode') with torch.autocast('cuda', enabled=dtype in ('float16', 'bfloat16'), dtype=getattr(torch, dtype)): return [get_model(self.first_stage_model).decode(zu) for zu in z] def noise_sample(self, num_samples, h, w, seed, device = None, dtype = torch.bfloat16): noise = torch.randn( num_samples, 16, # allow for packing 2 * math.ceil(h / 16), 2 * math.ceil(w / 16), device=device, dtype=dtype, generator=torch.Generator(device=device).manual_seed(seed), ) return noise def refine(self, x_samples=None, prompt=None, reverse_scale=-1., seed = 2024, **kwargs ): print(prompt) value_input = copy.deepcopy(self.input) x_samples = [self.upscale_resize(x) for x in x_samples] noise = [] for i, x in enumerate(x_samples): noise_ = self.noise_sample(1, x.shape[1], x.shape[2], seed, device = x.device) noise.append(noise_) noise, x_shapes = pack_imagelist_into_tensor(noise) if reverse_scale > 0: self.dynamic_load(self.first_stage_model, 'first_stage_model') x_samples = [x.unsqueeze(0) for x in x_samples] x_start = self.encode_first_stage(x_samples, **kwargs) self.dynamic_unload(self.first_stage_model, 'first_stage_model', skip_loaded=not self.use_dynamic_model) x_start, _ = pack_imagelist_into_tensor(x_start) else: x_start = None # cond stage self.dynamic_load(self.cond_stage_model, 'cond_stage_model') function_name, dtype = self.get_function_info(self.cond_stage_model) with torch.autocast('cuda', enabled=dtype == 'float16', dtype=getattr(torch, dtype)): ctx = getattr(get_model(self.cond_stage_model), function_name)(prompt) ctx["x_shapes"] = x_shapes self.dynamic_unload(self.cond_stage_model, 'cond_stage_model', skip_loaded=not self.use_dynamic_model) self.dynamic_load(self.diffusion_model, 'diffusion_model') # UNet use input n_prompt function_name, dtype = self.get_function_info( self.diffusion_model) with torch.autocast('cuda', enabled=dtype in ('float16', 'bfloat16'), dtype=getattr(torch, dtype)): solver_sample = value_input.get('sample', 'flow_euler') sample_steps = value_input.get('sample_steps', 20) guide_scale = value_input.get('guide_scale', 3.5) if guide_scale is not None: guide_scale = torch.full((noise.shape[0],), guide_scale, device=noise.device, dtype=noise.dtype) else: guide_scale = None latent = self.diffusion.sample( noise=noise, sampler=solver_sample, model=get_model(self.diffusion_model), model_kwargs={"cond": ctx, "guidance": guide_scale}, steps=sample_steps, show_progress=True, guide_scale=guide_scale, return_intermediate=None, reverse_scale=reverse_scale, x=x_start, **kwargs).float() latent = unpack_tensor_into_imagelist(latent, x_shapes) self.dynamic_unload(self.diffusion_model, 'diffusion_model', skip_loaded=not self.use_dynamic_model) self.dynamic_load(self.first_stage_model, 'first_stage_model') x_samples = self.decode_first_stage(latent) self.dynamic_unload(self.first_stage_model, 'first_stage_model', skip_loaded=not self.use_dynamic_model) return x_samples class ACEInference(DiffusionInference): def __init__(self, logger=None): if logger is None: logger = get_logger(name='scepter') self.logger = logger self.loaded_model = {} self.loaded_model_name = [ 
'diffusion_model', 'first_stage_model', 'cond_stage_model' ] def init_from_cfg(self, cfg): self.name = cfg.NAME self.is_default = cfg.get('IS_DEFAULT', False) self.use_dynamic_model = cfg.get('USE_DYNAMIC_MODEL', True) module_paras = self.load_default(cfg.get('DEFAULT_PARAS', None)) assert cfg.have('MODEL') self.diffusion_model = self.infer_model( cfg.MODEL.DIFFUSION_MODEL, module_paras.get( 'DIFFUSION_MODEL', None)) if cfg.MODEL.have('DIFFUSION_MODEL') else None self.first_stage_model = self.infer_model( cfg.MODEL.FIRST_STAGE_MODEL, module_paras.get( 'FIRST_STAGE_MODEL', None)) if cfg.MODEL.have('FIRST_STAGE_MODEL') else None self.cond_stage_model = self.infer_model( cfg.MODEL.COND_STAGE_MODEL, module_paras.get( 'COND_STAGE_MODEL', None)) if cfg.MODEL.have('COND_STAGE_MODEL') else None self.refiner_model_cfg = cfg.get('REFINER_MODEL', None) # self.refiner_scale = cfg.get('REFINER_SCALE', 0.) # self.refiner_prompt = cfg.get('REFINER_PROMPT', "") self.ace_prompt = cfg.get("ACE_PROMPT", []) if self.refiner_model_cfg: self.refiner_model_cfg.USE_DYNAMIC_MODEL = self.use_dynamic_model self.refiner_module = RefinerInference(self.logger) self.refiner_module.init_from_cfg(self.refiner_model_cfg) else: self.refiner_module = None self.diffusion = DIFFUSIONS.build(cfg.MODEL.DIFFUSION, logger=self.logger) self.interpolate_func = lambda x: (F.interpolate( x.unsqueeze(0), scale_factor=1 / self.size_factor, mode='nearest-exact') if x is not None else None) self.text_indentifers = cfg.MODEL.get('TEXT_IDENTIFIER', []) self.use_text_pos_embeddings = cfg.MODEL.get('USE_TEXT_POS_EMBEDDINGS', False) if self.use_text_pos_embeddings: self.text_position_embeddings = TextEmbedding( (10, 4096)).eval().requires_grad_(False).to(we.device_id) else: self.text_position_embeddings = None self.max_seq_len = cfg.MODEL.DIFFUSION_MODEL.MAX_SEQ_LEN self.scale_factor = cfg.get('SCALE_FACTOR', 0.18215) self.size_factor = cfg.get('SIZE_FACTOR', 8) self.decoder_bias = cfg.get('DECODER_BIAS', 0) self.default_n_prompt = cfg.get('DEFAULT_N_PROMPT', '') if not self.use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model') self.dynamic_load(self.cond_stage_model, 'cond_stage_model') self.dynamic_load(self.diffusion_model, 'diffusion_model') @torch.no_grad() def encode_first_stage(self, x, **kwargs): _, dtype = self.get_function_info(self.first_stage_model, 'encode') with torch.autocast('cuda', enabled=(dtype != 'float32'), dtype=getattr(torch, dtype)): z = [ self.scale_factor * get_model(self.first_stage_model)._encode( i.unsqueeze(0).to(getattr(torch, dtype))) for i in x ] return z @torch.no_grad() def decode_first_stage(self, z): _, dtype = self.get_function_info(self.first_stage_model, 'decode') with torch.autocast('cuda', enabled=(dtype != 'float32'), dtype=getattr(torch, dtype)): x = [ get_model(self.first_stage_model)._decode( 1. 
/ self.scale_factor * i.to(getattr(torch, dtype))) for i in z ] return x @torch.no_grad() def __call__(self, image=None, mask=None, prompt='', task=None, negative_prompt='', output_height=512, output_width=512, sampler='ddim', sample_steps=20, guide_scale=4.5, guide_rescale=0.5, seed=-1, history_io=None, tar_index=0, **kwargs): input_image, input_mask = image, mask g = torch.Generator(device=we.device_id) seed = seed if seed >= 0 else random.randint(0, 2**32 - 1) g.manual_seed(int(seed)) if input_image is not None: # assert isinstance(input_image, list) and isinstance(input_mask, list) if task is None: task = [''] * len(input_image) if not isinstance(prompt, list): prompt = [prompt] * len(input_image) if history_io is not None and len(history_io) > 0: his_image, his_maks, his_prompt, his_task = history_io[ 'image'], history_io['mask'], history_io[ 'prompt'], history_io['task'] assert len(his_image) == len(his_maks) == len( his_prompt) == len(his_task) input_image = his_image + input_image input_mask = his_maks + input_mask task = his_task + task prompt = his_prompt + [prompt[-1]] prompt = [ pp.replace('{image}', f'{{image{i}}}') if i > 0 else pp for i, pp in enumerate(prompt) ] edit_image, edit_image_mask = process_edit_image( input_image, input_mask, task, max_seq_len=self.max_seq_len) image, image_mask = edit_image[tar_index], edit_image_mask[ tar_index] edit_image, edit_image_mask = [edit_image], [edit_image_mask] else: edit_image = edit_image_mask = [[]] image = torch.zeros( size=[3, int(output_height), int(output_width)]) image_mask = torch.ones( size=[1, int(output_height), int(output_width)]) if not isinstance(prompt, list): prompt = [prompt] image, image_mask, prompt = [image], [image_mask], [prompt] assert check_list_of_list(prompt) and check_list_of_list( edit_image) and check_list_of_list(edit_image_mask) # Assign Negative Prompt if isinstance(negative_prompt, list): negative_prompt = negative_prompt[0] assert isinstance(negative_prompt, str) n_prompt = copy.deepcopy(prompt) for nn_p_id, nn_p in enumerate(n_prompt): assert isinstance(nn_p, list) n_prompt[nn_p_id][-1] = negative_prompt is_txt_image = sum([len(e_i) for e_i in edit_image]) < 1 image = to_device(image) refiner_scale = kwargs.pop("refiner_scale", 0.0) refiner_prompt = kwargs.pop("refiner_prompt", "") use_ace = kwargs.pop("use_ace", True) # <= 0 use ace as the txt2img generator. 
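        # Branching: the ACE DiT runs whenever edit images are present, or when
        # it should act as the txt2img generator (refiner_scale <= 0). For pure
        # text-to-image requests with refiner_scale > 0, generation is delegated
        # entirely to the refiner below: ACE is skipped, x_samples stays the
        # blank placeholder image, and the refiner's reverse_scale is set to -1
        # so it samples from pure noise instead of partially re-noising.
        # When guide_scale > 1, model_kwargs is passed to the sampler as a
        # [conditional, unconditional] pair for classifier-free guidance.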
if use_ace and (not is_txt_image or refiner_scale <= 0): ctx, null_ctx = {}, {} # Get Noise Shape self.dynamic_load(self.first_stage_model, 'first_stage_model') x = self.encode_first_stage(image) self.dynamic_unload(self.first_stage_model, 'first_stage_model', skip_loaded=not self.use_dynamic_model) noise = [ torch.empty(*i.shape, device=we.device_id).normal_(generator=g) for i in x ] noise, x_shapes = pack_imagelist_into_tensor(noise) ctx['x_shapes'] = null_ctx['x_shapes'] = x_shapes image_mask = to_device(image_mask, strict=False) cond_mask = [self.interpolate_func(i) for i in image_mask ] if image_mask is not None else [None] * len(image) ctx['x_mask'] = null_ctx['x_mask'] = cond_mask # Encode Prompt self.dynamic_load(self.cond_stage_model, 'cond_stage_model') function_name, dtype = self.get_function_info(self.cond_stage_model) cont, cont_mask = getattr(get_model(self.cond_stage_model), function_name)(prompt) cont, cont_mask = self.cond_stage_embeddings(prompt, edit_image, cont, cont_mask) null_cont, null_cont_mask = getattr(get_model(self.cond_stage_model), function_name)(n_prompt) null_cont, null_cont_mask = self.cond_stage_embeddings( prompt, edit_image, null_cont, null_cont_mask) self.dynamic_unload(self.cond_stage_model, 'cond_stage_model', skip_loaded=not self.use_dynamic_model) ctx['crossattn'] = cont null_ctx['crossattn'] = null_cont # Encode Edit Images self.dynamic_load(self.first_stage_model, 'first_stage_model') edit_image = [to_device(i, strict=False) for i in edit_image] edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask] e_img, e_mask = [], [] for u, m in zip(edit_image, edit_image_mask): if u is None: continue if m is None: m = [None] * len(u) e_img.append(self.encode_first_stage(u, **kwargs)) e_mask.append([self.interpolate_func(i) for i in m]) self.dynamic_unload(self.first_stage_model, 'first_stage_model', skip_loaded=not self.use_dynamic_model) null_ctx['edit'] = ctx['edit'] = e_img null_ctx['edit_mask'] = ctx['edit_mask'] = e_mask # Diffusion Process self.dynamic_load(self.diffusion_model, 'diffusion_model') function_name, dtype = self.get_function_info(self.diffusion_model) with torch.autocast('cuda', enabled=dtype in ('float16', 'bfloat16'), dtype=getattr(torch, dtype)): latent = self.diffusion.sample( noise=noise, sampler=sampler, model=get_model(self.diffusion_model), model_kwargs=[{ 'cond': ctx, 'mask': cont_mask, 'text_position_embeddings': self.text_position_embeddings.pos if hasattr( self.text_position_embeddings, 'pos') else None }, { 'cond': null_ctx, 'mask': null_cont_mask, 'text_position_embeddings': self.text_position_embeddings.pos if hasattr( self.text_position_embeddings, 'pos') else None }] if guide_scale is not None and guide_scale > 1 else { 'cond': null_ctx, 'mask': cont_mask, 'text_position_embeddings': self.text_position_embeddings.pos if hasattr( self.text_position_embeddings, 'pos') else None }, steps=sample_steps, show_progress=True, seed=seed, guide_scale=guide_scale, guide_rescale=guide_rescale, return_intermediate=None, **kwargs) if self.use_dynamic_model: self.dynamic_unload(self.diffusion_model, 'diffusion_model', skip_loaded=not self.use_dynamic_model) # Decode to Pixel Space self.dynamic_load(self.first_stage_model, 'first_stage_model') samples = unpack_tensor_into_imagelist(latent, x_shapes) x_samples = self.decode_first_stage(samples) self.dynamic_unload(self.first_stage_model, 'first_stage_model', skip_loaded=not self.use_dynamic_model) x_samples = [x.squeeze(0) for x in x_samples] else: x_samples = image if 
self.refiner_module and refiner_scale > 0: if is_txt_image: random.shuffle(self.ace_prompt) input_refine_prompt = [self.ace_prompt[0] + refiner_prompt if p[0] == "" else p[0] for p in prompt] input_refine_scale = -1. else: input_refine_prompt = [p[0].replace("{image}", "") + " " + refiner_prompt for p in prompt] input_refine_scale = refiner_scale print(input_refine_prompt) x_samples = self.refiner_module.refine(x_samples, reverse_scale = input_refine_scale, prompt= input_refine_prompt, seed=seed, use_dynamic_model=self.use_dynamic_model) imgs = [ torch.clamp((x_i.float() + 1.0) / 2.0 + self.decoder_bias / 255, min=0.0, max=1.0).squeeze(0).permute(1, 2, 0).cpu().numpy() for x_i in x_samples ] imgs = [Image.fromarray((img * 255).astype(np.uint8)) for img in imgs] return imgs def cond_stage_embeddings(self, prompt, edit_image, cont, cont_mask): if self.use_text_pos_embeddings and not torch.sum( self.text_position_embeddings.pos) > 0: identifier_cont, _ = getattr(get_model(self.cond_stage_model), 'encode')(self.text_indentifers, return_mask=True) self.text_position_embeddings.load_state_dict( {'pos': identifier_cont[:, 0, :]}) cont_, cont_mask_ = [], [] for pp, edit, c, cm in zip(prompt, edit_image, cont, cont_mask): if isinstance(pp, list): cont_.append([c[-1], *c] if len(edit) > 0 else [c[-1]]) cont_mask_.append([cm[-1], *cm] if len(edit) > 0 else [cm[-1]]) else: raise NotImplementedError return cont_, cont_mask_ ================================================ FILE: chatbot/config/chatbot_ui.yaml ================================================ WORK_DIR: ./cache/chatbot FILE_SYSTEM: - NAME: "HuggingfaceFs" TEMP_DIR: ./cache - NAME: "ModelscopeFs" TEMP_DIR: ./cache - NAME: "LocalFs" TEMP_DIR: ./cache - NAME: "HttpFs" TEMP_DIR: ./cache # ENABLE_I2V: False # MODEL: EDIT_MODEL: MODEL_CFG_DIR: chatbot/config/models/ DEFAULT: ace_0.6b_512 I2V: MODEL_NAME: CogVideoX-5b-I2V MODEL_DIR: ms://ZhipuAI/CogVideoX-5b-I2V/ CAPTIONER: MODEL_NAME: InternVL2-2B MODEL_DIR: ms://OpenGVLab/InternVL2-2B/ PROMPT: '\nThis image is the first frame of a video. Based on this image, please imagine what changes may occur in the next few seconds of the video. Please output brief description, such as "a dog running" or "a person turns to left". No more than 30 words.' 
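# CAPTIONER drafts a short motion description from the selected frame, and
# ENHANCER below is the LLM used to expand user prompts for the I2V pipeline;
# these (like I2V itself) are only loaded when ENABLE_I2V is set
# (see run_gradio.py).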
ENHANCER: MODEL_NAME: Meta-Llama-3.1-8B-Instruct MODEL_DIR: ms://LLM-Research/Meta-Llama-3.1-8B-Instruct/ ================================================ FILE: chatbot/config/models/ace_0.6b_512.yaml ================================================ NAME: ACE_0.6B_512 IS_DEFAULT: False DEFAULT_PARAS: PARAS: # INPUT: INPUT_IMAGE: INPUT_MASK: TASK: PROMPT: "" NEGATIVE_PROMPT: "" OUTPUT_HEIGHT: 512 OUTPUT_WIDTH: 512 SAMPLER: ddim SAMPLE_STEPS: 20 GUIDE_SCALE: 4.5 GUIDE_RESCALE: 0.5 SEED: -1 TAR_INDEX: 0 OUTPUT: LATENT: IMAGES: SEED: MODULES_PARAS: FIRST_STAGE_MODEL: FUNCTION: - NAME: encode DTYPE: float16 INPUT: ["IMAGE"] - NAME: decode DTYPE: float16 INPUT: ["LATENT"] # DIFFUSION_MODEL: FUNCTION: - NAME: forward DTYPE: float16 INPUT: ["SAMPLE_STEPS", "SAMPLE", "GUIDE_SCALE"] # COND_STAGE_MODEL: FUNCTION: - NAME: encode_list DTYPE: bfloat16 INPUT: ["PROMPT"] # MODEL: NAME: LdmACE PRETRAINED_MODEL: IGNORE_KEYS: [ ] SCALE_FACTOR: 0.18215 SIZE_FACTOR: 8 DECODER_BIAS: 0.5 DEFAULT_N_PROMPT: "" TEXT_IDENTIFIER: [ '{image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ] USE_TEXT_POS_EMBEDDINGS: True # DIFFUSION: NAME: ACEDiffusion PREDICTION_TYPE: eps MIN_SNR_GAMMA: NOISE_SCHEDULER: NAME: LinearScheduler NUM_TIMESTEPS: 1000 BETA_MIN: 0.0001 BETA_MAX: 0.02 # DIFFUSION_MODEL: NAME: DiTACE PRETRAINED_MODEL: ms://iic/ACE-0.6B-512px@models/dit/ace_0.6b_512px.pth IGNORE_KEYS: [ ] PATCH_SIZE: 2 IN_CHANNELS: 4 HIDDEN_SIZE: 1152 DEPTH: 28 NUM_HEADS: 16 MLP_RATIO: 4.0 PRED_SIGMA: True DROP_PATH: 0.0 WINDOW_SIZE: 0 Y_CHANNELS: 4096 MAX_SEQ_LEN: 1024 QK_NORM: True USE_GRAD_CHECKPOINT: True ATTENTION_BACKEND: flash_attn # FIRST_STAGE_MODEL: NAME: AutoencoderKL EMBED_DIM: 4 PRETRAINED_MODEL: ms://iic/ACE-0.6B-512px@models/vae/vae.bin IGNORE_KEYS: [] # ENCODER: NAME: Encoder CH: 128 OUT_CH: 3 NUM_RES_BLOCKS: 2 IN_CHANNELS: 3 ATTN_RESOLUTIONS: [ ] CH_MULT: [ 1, 2, 4, 4 ] Z_CHANNELS: 4 DOUBLE_Z: True DROPOUT: 0.0 RESAMP_WITH_CONV: True # DECODER: NAME: Decoder CH: 128 OUT_CH: 3 NUM_RES_BLOCKS: 2 IN_CHANNELS: 3 ATTN_RESOLUTIONS: [ ] CH_MULT: [ 1, 2, 4, 4 ] Z_CHANNELS: 4 DROPOUT: 0.0 RESAMP_WITH_CONV: True GIVE_PRE_END: False TANH_OUT: False # COND_STAGE_MODEL: NAME: ACETextEmbedder PRETRAINED_MODEL: ms://iic/ACE-0.6B-512px@models/text_encoder/t5-v1_1-xxl/ TOKENIZER_PATH: ms://iic/ACE-0.6B-512px@models/tokenizer/t5-v1_1-xxl LENGTH: 120 T5_DTYPE: bfloat16 ADDED_IDENTIFIER: [ '{image}', '{caption}', '{mask}', '{ref_image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ] CLEAN: whitespace USE_GRAD: False ================================================ FILE: chatbot/example.py ================================================ # -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates.
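"""Demo assets for the Gradio chatbot.

get_examples() downloads the sample images into `cache_dir` and returns rows
of [task name, input image, mask, reference image, prompt, seed];
blank_image() fills the unused image/mask slots, and prompts reference their
inputs through the {image}/{image1}/... identifiers that the text encoder
treats as special tokens.
"""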
import os from scepter.modules.utils.file_system import FS from PIL import Image def download_image(image, local_path=None): if not FS.exists(local_path): local_path = FS.get_from(image, local_path=local_path) return local_path def blank_image(): return Image.new('RGBA', (128, 128), (0, 0, 0, 0)) def get_examples(cache_dir): print('Downloading Examples ...') bl_img = blank_image() examples = [ [ 'Facial Editing', download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/e33edc106953.png?raw=true', os.path.join(cache_dir, 'examples/e33edc106953.png')), bl_img, bl_img, '{image} let the man smile', 6666 ], [ 'Facial Editing', download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/5d2bcc91a3e9.png?raw=true', os.path.join(cache_dir, 'examples/5d2bcc91a3e9.png')), bl_img, bl_img, 'let the man in {image} wear sunglasses', 9999 ], [ 'Facial Editing', download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/3a52eac708bd.png?raw=true', os.path.join(cache_dir, 'examples/3a52eac708bd.png')), bl_img, bl_img, '{image} red hair', 9999 ], [ 'Facial Editing', download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/3f4dc464a0ea.png?raw=true', os.path.join(cache_dir, 'examples/3f4dc464a0ea.png')), bl_img, bl_img, '{image} let the man serious', 99999 ], [ 'Controllable Generation', download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/131ca90fd2a9.png?raw=true', os.path.join(cache_dir, 'examples/131ca90fd2a9.png')), bl_img, bl_img, '"A person sits contemplatively on the ground, surrounded by falling autumn leaves. Dressed in a green sweater and dark blue pants, they rest their chin on their hand, exuding a relaxed demeanor. Their stylish checkered slip-on shoes add a touch of flair, while a black purse lies in their lap. The backdrop of muted brown enhances the warm, cozy atmosphere of the scene." 
, generate the image that corresponds to the given scribble {image}.', 613725 ], [ 'Render Text', download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/33e9f27c2c48.png?raw=true', os.path.join(cache_dir, 'examples/33e9f27c2c48.png')), download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/33e9f27c2c48_mask.png?raw=true', os.path.join(cache_dir, 'examples/33e9f27c2c48_mask.png')), bl_img, 'Put the text "C A T" at the position marked by mask in the {image}', 6666 ], [ 'Style Transfer', download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/9e73e7eeef55.png?raw=true', os.path.join(cache_dir, 'examples/9e73e7eeef55.png')), bl_img, download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/2e02975293d6.png?raw=true', os.path.join(cache_dir, 'examples/2e02975293d6.png')), 'edit {image} based on the style of {image1} ', 99999 ], [ 'Outpainting', download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/f2b22c08be3f.png?raw=true', os.path.join(cache_dir, 'examples/f2b22c08be3f.png')), download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/f2b22c08be3f_mask.png?raw=true', os.path.join(cache_dir, 'examples/f2b22c08be3f_mask.png')), bl_img, 'Could the {image} be widened within the space designated by mask, while retaining the original?', 6666 ], [ 'Image Segmentation', download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/db3ebaa81899.png?raw=true', os.path.join(cache_dir, 'examples/db3ebaa81899.png')), bl_img, bl_img, '{image} Segmentation', 6666 ], [ 'Depth Estimation', download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/f1927c4692ba.png?raw=true', os.path.join(cache_dir, 'examples/f1927c4692ba.png')), bl_img, bl_img, '{image} Depth Estimation', 6666 ], [ 'Pose Estimation', download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/014e5bf3b4d1.png?raw=true', os.path.join(cache_dir, 'examples/014e5bf3b4d1.png')), bl_img, bl_img, '{image} distinguish the poses of the figures', 999999 ], [ 'Scribble Extraction', download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/5f59a202f8ac.png?raw=true', os.path.join(cache_dir, 'examples/5f59a202f8ac.png')), bl_img, bl_img, 'Generate a scribble of {image}, please.', 6666 ], [ 'Mosaic', download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/3a2f52361eea.png?raw=true', os.path.join(cache_dir, 'examples/3a2f52361eea.png')), bl_img, bl_img, 'Adapt {image} into a mosaic representation.', 6666 ], [ 'Edge map Extraction', download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/b9d1e519d6e5.png?raw=true', os.path.join(cache_dir, 'examples/b9d1e519d6e5.png')), bl_img, bl_img, 'Get the edge-enhanced result for {image}.', 6666 ], [ 'Grayscale', download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/c4ebbe2ba29b.png?raw=true', os.path.join(cache_dir, 'examples/c4ebbe2ba29b.png')), bl_img, bl_img, 'transform {image} into a black and white one', 6666 ], [ 'Contour Extraction', download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/19652d0f6c4b.png?raw=true', os.path.join(cache_dir, 'examples/19652d0f6c4b.png')), bl_img, bl_img, 'Would you be able to make a contour picture from {image} for me?', 6666 ], [ 'Controllable Generation', download_image( 
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/249cda2844b7.png?raw=true', os.path.join(cache_dir, 'examples/249cda2844b7.png')), bl_img, bl_img, 'Following the segmentation outcome in mask of {image}, develop a real-life image using the explanatory note in "a mighty cat lying on the bed".', 6666 ], [ 'Controllable Generation', download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/411f6c4b8e6c.png?raw=true', os.path.join(cache_dir, 'examples/411f6c4b8e6c.png')), bl_img, bl_img, 'use the depth map {image} and the text caption "a cute white cat" to create a corresponding graphic image', 999999 ], [ 'Controllable Generation', download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/a35c96ed137a.png?raw=true', os.path.join(cache_dir, 'examples/a35c96ed137a.png')), bl_img, bl_img, 'help translate this posture schema {image} into a colored image based on the context I provided "A beautiful woman Climbing the climbing wall, wearing a harness and climbing gear, skillfully maneuvering up the wall with her back to the camera, with a safety rope."', 3599999 ], [ 'Controllable Generation', download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/dcb2fc86f1ce.png?raw=true', os.path.join(cache_dir, 'examples/dcb2fc86f1ce.png')), bl_img, bl_img, 'Transform and generate an image using mosaic {image} and "Monarch butterflies gracefully perch on vibrant purple flowers, showcasing their striking orange and black wings in a lush garden setting." description', 6666 ], [ 'Controllable Generation', download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/4cd4ee494962.png?raw=true', os.path.join(cache_dir, 'examples/4cd4ee494962.png')), bl_img, bl_img, 'make this {image} colorful as per the "beautiful sunflowers"', 6666 ], [ 'Controllable Generation', download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/a47e3a9cd166.png?raw=true', os.path.join(cache_dir, 'examples/a47e3a9cd166.png')), bl_img, bl_img, 'Take the edge conscious {image} and the written guideline "A whimsical animated character is depicted holding a delectable cake adorned with blue and white frosting and a drizzle of chocolate. The character wears a yellow headband with a bow, matching a cozy yellow sweater. Her dark hair is styled in a braid, tied with a yellow ribbon. With a golden fork in hand, she stands ready to enjoy a slice, exuding an air of joyful anticipation. The scene is creatively rendered with a charming and playful aesthetic." and produce a realistic image.', 613725 ], [ 'Controllable Generation', download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/d890ed8a3ac2.png?raw=true', os.path.join(cache_dir, 'examples/d890ed8a3ac2.png')), bl_img, bl_img, 'creating a vivid image based on {image} and description "This image features a delicious rectangular tart with a flaky, golden-brown crust. The tart is topped with evenly sliced tomatoes, layered over a creamy cheese filling. Aromatic herbs are sprinkled on top, adding a touch of green and enhancing the visual appeal. The background includes a soft, textured fabric and scattered white flowers, creating an elegant and inviting presentation.
Bright red tomatoes in the upper right corner hint at the fresh ingredients used in the dish."', 6666 ], [ 'Image Denoising', download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/0844a686a179.png?raw=true', os.path.join(cache_dir, 'examples/0844a686a179.png')), bl_img, bl_img, 'Eliminate noise interference in {image} and maximize the crispness to obtain superior high-definition quality', 6666 ], [ 'Inpainting', download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/fa91b6b7e59b.png?raw=true', os.path.join(cache_dir, 'examples/fa91b6b7e59b.png')), download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/fa91b6b7e59b_mask.png?raw=true', os.path.join(cache_dir, 'examples/fa91b6b7e59b_mask.png')), bl_img, 'Ensure to overhaul the parts of the {image} indicated by the mask.', 6666 ], [ 'Inpainting', download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/632899695b26.png?raw=true', os.path.join(cache_dir, 'examples/632899695b26.png')), download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/632899695b26_mask.png?raw=true', os.path.join(cache_dir, 'examples/632899695b26_mask.png')), bl_img, 'Refashion the mask portion of {image} in accordance with "A yellow egg with a smiling face painted on it"', 6666 ], [ 'General Editing', download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/354d17594afe.png?raw=true', os.path.join(cache_dir, 'examples/354d17594afe.png')), bl_img, bl_img, '{image} change the dog\'s posture to walking in the water, and change the background to green plants and a pond.', 6666 ], [ 'General Editing', download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/38946455752b.png?raw=true', os.path.join(cache_dir, 'examples/38946455752b.png')), bl_img, bl_img, '{image} change the color of the dress from white to red and the model\'s hair color from red brown to blonde. Other parts remain unchanged', 6669 ], [ 'Facial Editing', download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/3ba5202f0cd8.png?raw=true', os.path.join(cache_dir, 'examples/3ba5202f0cd8.png')), bl_img, bl_img, 'Keep the same facial feature in @3ba5202f0cd8, change the woman\'s clothing from a Blue denim jacket to a white turtleneck sweater and adjust her posture so that she is supporting her chin with both hands.
Other aspects, such as background, hairstyle, facial expression, etc, remain unchanged.', 99999 ], [ 'Facial Editing', download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/369365b94725.png?raw=true', os.path.join(cache_dir, 'examples/369365b94725.png')), bl_img, bl_img, '{image} Make her looking at the camera', 6666 ], [ 'Facial Editing', download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/92751f2e4a0e.png?raw=true', os.path.join(cache_dir, 'examples/92751f2e4a0e.png')), bl_img, bl_img, '{image} Remove the smile from his face', 9899999 ], [ 'Remove Text', download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/8530a6711b2e.png?raw=true', os.path.join(cache_dir, 'examples/8530a6711b2e.png')), bl_img, bl_img, 'Aim to remove any textual element in {image}', 6666 ], [ 'Remove Text', download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/c4d7fb28f8f6.png?raw=true', os.path.join(cache_dir, 'examples/c4d7fb28f8f6.png')), download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/c4d7fb28f8f6_mask.png?raw=true', os.path.join(cache_dir, 'examples/c4d7fb28f8f6_mask.png')), bl_img, 'Rub out any text found in the mask sector of the {image}.', 6666 ], [ 'Remove Object', download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/e2f318fa5e5b.png?raw=true', os.path.join(cache_dir, 'examples/e2f318fa5e5b.png')), bl_img, bl_img, 'Remove the unicorn in this {image}, ensuring a smooth edit.', 99999 ], [ 'Remove Object', download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/1ae96d8aca00.png?raw=true', os.path.join(cache_dir, 'examples/1ae96d8aca00.png')), download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/1ae96d8aca00_mask.png?raw=true', os.path.join(cache_dir, 'examples/1ae96d8aca00_mask.png')), bl_img, 'Discard the contents of the mask area from {image}.', 99999 ], [ 'Add Object', download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/80289f48e511.png?raw=true', os.path.join(cache_dir, 'examples/80289f48e511.png')), download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/80289f48e511_mask.png?raw=true', os.path.join(cache_dir, 'examples/80289f48e511_mask.png')), bl_img, 'add a Hot Air Balloon into the {image}, per the mask', 613725 ], [ 'Style Transfer', download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/d725cb2009e8.png?raw=true', os.path.join(cache_dir, 'examples/d725cb2009e8.png')), bl_img, bl_img, 'Change the style of {image} to colored pencil style', 99999 ], [ 'Style Transfer', download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/e0f48b3fd010.png?raw=true', os.path.join(cache_dir, 'examples/e0f48b3fd010.png')), bl_img, bl_img, 'make {image} to Walt Disney Animation style', 99999 ], [ 'Try On', download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/ee4ca60b8c96.png?raw=true', os.path.join(cache_dir, 'examples/ee4ca60b8c96.png')), download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/ee4ca60b8c96_mask.png?raw=true', os.path.join(cache_dir, 'examples/ee4ca60b8c96_mask.png')), download_image( 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/ebe825bbfe3c.png?raw=true', os.path.join(cache_dir, 'examples/ebe825bbfe3c.png')), 'Change the cloth in {image} to the one in {image1}', 99999 ], [ 'Workflow', download_image( 
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/cb85353c004b.png?raw=true', os.path.join(cache_dir, 'examples/cb85353c004b.png')), bl_img, bl_img, ' ice cream {image}', 99999 ], ] print('Finish. Start building UI ...') return examples ================================================ FILE: chatbot/infer.py ================================================ # -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import copy import math import random import numpy as np from PIL import Image import torch import torch.nn as nn import torch.nn.functional as F import torchvision.transforms.functional as TF from scepter.modules.model.registry import DIFFUSIONS from scepter.modules.utils.distribute import we from scepter.modules.utils.logger import get_logger from scepter.modules.inference.diffusion_inference import DiffusionInference, get_model from modules.model.utils.basic_utils import ( check_list_of_list, pack_imagelist_into_tensor_v2 as pack_imagelist_into_tensor, to_device, unpack_tensor_into_imagelist ) def process_edit_image(images, masks, tasks, max_seq_len=1024, max_aspect_ratio=4, d=16, **kwargs): if not isinstance(images, list): images = [images] if not isinstance(masks, list): masks = [masks] if not isinstance(tasks, list): tasks = [tasks] img_tensors = [] mask_tensors = [] for img, mask, task in zip(images, masks, tasks): if mask is None or mask == '': mask = Image.new('L', img.size, 0) W, H = img.size if H / W > max_aspect_ratio: img = TF.center_crop(img, [int(max_aspect_ratio * W), W]) mask = TF.center_crop(mask, [int(max_aspect_ratio * W), W]) elif W / H > max_aspect_ratio: img = TF.center_crop(img, [H, int(max_aspect_ratio * H)]) mask = TF.center_crop(mask, [H, int(max_aspect_ratio * H)]) H, W = img.height, img.width scale = min(1.0, math.sqrt(max_seq_len / ((H / d) * (W / d)))) rH = int(H * scale) // d * d # ensure divisible by self.d rW = int(W * scale) // d * d img = TF.resize(img, (rH, rW), interpolation=TF.InterpolationMode.BICUBIC) mask = TF.resize(mask, (rH, rW), interpolation=TF.InterpolationMode.NEAREST_EXACT) mask = np.asarray(mask) mask = np.where(mask > 128, 1, 0) mask = mask.astype( np.float32) if np.any(mask) else np.ones_like(mask).astype( np.float32) img_tensor = TF.to_tensor(img).to(we.device_id) img_tensor = TF.normalize(img_tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) mask_tensor = TF.to_tensor(mask).to(we.device_id) if task in ['inpainting', 'Try On', 'Inpainting']: mask_indicator = mask_tensor.repeat(3, 1, 1) img_tensor[mask_indicator == 1] = -1.0 img_tensors.append(img_tensor) mask_tensors.append(mask_tensor) return img_tensors, mask_tensors class TextEmbedding(nn.Module): def __init__(self, embedding_shape): super().__init__() self.pos = nn.Parameter(data=torch.zeros(embedding_shape)) class ACEInference(DiffusionInference): def __init__(self, logger=None): if logger is None: logger = get_logger(name='scepter') self.logger = logger self.loaded_model = {} self.loaded_model_name = [ 'diffusion_model', 'first_stage_model', 'cond_stage_model' ] def init_from_cfg(self, cfg): self.name = cfg.NAME self.is_default = cfg.get('IS_DEFAULT', False) module_paras = self.load_default(cfg.get('DEFAULT_PARAS', None)) assert cfg.have('MODEL') self.diffusion_model = self.infer_model( cfg.MODEL.DIFFUSION_MODEL, module_paras.get( 'DIFFUSION_MODEL', None)) if cfg.MODEL.have('DIFFUSION_MODEL') else None self.first_stage_model = self.infer_model( cfg.MODEL.FIRST_STAGE_MODEL, module_paras.get( 'FIRST_STAGE_MODEL', None)) if 
cfg.MODEL.have('FIRST_STAGE_MODEL') else None self.cond_stage_model = self.infer_model( cfg.MODEL.COND_STAGE_MODEL, module_paras.get( 'COND_STAGE_MODEL', None)) if cfg.MODEL.have('COND_STAGE_MODEL') else None self.diffusion = DIFFUSIONS.build(cfg.MODEL.DIFFUSION, logger=self.logger) self.interpolate_func = lambda x: (F.interpolate( x.unsqueeze(0), scale_factor=1 / self.size_factor, mode='nearest-exact') if x is not None else None) self.text_indentifers = cfg.MODEL.get('TEXT_IDENTIFIER', []) self.use_text_pos_embeddings = cfg.MODEL.get('USE_TEXT_POS_EMBEDDINGS', False) if self.use_text_pos_embeddings: self.text_position_embeddings = TextEmbedding( (10, 4096)).eval().requires_grad_(False).to(we.device_id) else: self.text_position_embeddings = None self.max_seq_len = cfg.MODEL.DIFFUSION_MODEL.MAX_SEQ_LEN self.scale_factor = cfg.get('SCALE_FACTOR', 0.18215) self.size_factor = cfg.get('SIZE_FACTOR', 8) self.decoder_bias = cfg.get('DECODER_BIAS', 0) self.default_n_prompt = cfg.get('DEFAULT_N_PROMPT', '') self.dynamic_load(self.first_stage_model, 'first_stage_model') self.dynamic_load(self.cond_stage_model, 'cond_stage_model') self.dynamic_load(self.diffusion_model, 'diffusion_model') @torch.no_grad() def encode_first_stage(self, x, **kwargs): _, dtype = self.get_function_info(self.first_stage_model, 'encode') with torch.autocast('cuda', enabled=(dtype != 'float32'), dtype=getattr(torch, dtype)): z = [ self.scale_factor * get_model(self.first_stage_model)._encode( i.unsqueeze(0).to(getattr(torch, dtype))) for i in x ] return z @torch.no_grad() def decode_first_stage(self, z): _, dtype = self.get_function_info(self.first_stage_model, 'decode') with torch.autocast('cuda', enabled=(dtype != 'float32'), dtype=getattr(torch, dtype)): x = [ get_model(self.first_stage_model)._decode( 1. 
/ self.scale_factor * i.to(getattr(torch, dtype))) for i in z ] return x @torch.no_grad() def __call__(self, image=None, mask=None, prompt='', task=None, negative_prompt='', output_height=512, output_width=512, sampler='ddim', sample_steps=20, guide_scale=4.5, guide_rescale=0.5, seed=-1, history_io=None, tar_index=0, **kwargs): input_image, input_mask = image, mask g = torch.Generator(device=we.device_id) seed = seed if seed >= 0 else random.randint(0, 2**32 - 1) g.manual_seed(int(seed)) if input_image is not None: assert isinstance(input_image, list) and isinstance( input_mask, list) if task is None: task = [''] * len(input_image) if not isinstance(prompt, list): prompt = [prompt] * len(input_image) if history_io is not None and len(history_io) > 0: his_image, his_maks, his_prompt, his_task = history_io[ 'image'], history_io['mask'], history_io[ 'prompt'], history_io['task'] assert len(his_image) == len(his_maks) == len( his_prompt) == len(his_task) input_image = his_image + input_image input_mask = his_maks + input_mask task = his_task + task prompt = his_prompt + [prompt[-1]] prompt = [ pp.replace('{image}', f'{{image{i}}}') if i > 0 else pp for i, pp in enumerate(prompt) ] edit_image, edit_image_mask = process_edit_image( input_image, input_mask, task, max_seq_len=self.max_seq_len) image, image_mask = edit_image[tar_index], edit_image_mask[ tar_index] edit_image, edit_image_mask = [edit_image], [edit_image_mask] else: edit_image = edit_image_mask = [[]] image = torch.zeros( size=[3, int(output_height), int(output_width)]) image_mask = torch.ones( size=[1, int(output_height), int(output_width)]) if not isinstance(prompt, list): prompt = [prompt] image, image_mask, prompt = [image], [image_mask], [prompt] assert check_list_of_list(prompt) and check_list_of_list( edit_image) and check_list_of_list(edit_image_mask) # Assign Negative Prompt if isinstance(negative_prompt, list): negative_prompt = negative_prompt[0] assert isinstance(negative_prompt, str) n_prompt = copy.deepcopy(prompt) for nn_p_id, nn_p in enumerate(n_prompt): assert isinstance(nn_p, list) n_prompt[nn_p_id][-1] = negative_prompt ctx, null_ctx = {}, {} # Get Noise Shape image = to_device(image) x = self.encode_first_stage(image) noise = [ torch.empty(*i.shape, device=we.device_id).normal_(generator=g) for i in x ] noise, x_shapes = pack_imagelist_into_tensor(noise) ctx['x_shapes'] = null_ctx['x_shapes'] = x_shapes image_mask = to_device(image_mask, strict=False) cond_mask = [self.interpolate_func(i) for i in image_mask ] if image_mask is not None else [None] * len(image) ctx['x_mask'] = null_ctx['x_mask'] = cond_mask # Encode Prompt function_name, dtype = self.get_function_info(self.cond_stage_model) cont, cont_mask = getattr(get_model(self.cond_stage_model), function_name)(prompt) cont, cont_mask = self.cond_stage_embeddings(prompt, edit_image, cont, cont_mask) null_cont, null_cont_mask = getattr(get_model(self.cond_stage_model), function_name)(n_prompt) null_cont, null_cont_mask = self.cond_stage_embeddings( prompt, edit_image, null_cont, null_cont_mask) ctx['crossattn'] = cont null_ctx['crossattn'] = null_cont # Encode Edit Images edit_image = [to_device(i, strict=False) for i in edit_image] edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask] e_img, e_mask = [], [] for u, m in zip(edit_image, edit_image_mask): if u is None: continue if m is None: m = [None] * len(u) e_img.append(self.encode_first_stage(u, **kwargs)) e_mask.append([self.interpolate_func(i) for i in m]) null_ctx['edit'] = ctx['edit'] = 
e_img null_ctx['edit_mask'] = ctx['edit_mask'] = e_mask # Diffusion Process function_name, dtype = self.get_function_info(self.diffusion_model) with torch.autocast('cuda', enabled=dtype in ('float16', 'bfloat16'), dtype=getattr(torch, dtype)): latent = self.diffusion.sample( noise=noise, sampler=sampler, model=get_model(self.diffusion_model), model_kwargs=[{ 'cond': ctx, 'mask': cont_mask, 'text_position_embeddings': self.text_position_embeddings.pos if hasattr( self.text_position_embeddings, 'pos') else None }, { 'cond': null_ctx, 'mask': null_cont_mask, 'text_position_embeddings': self.text_position_embeddings.pos if hasattr( self.text_position_embeddings, 'pos') else None }] if guide_scale is not None and guide_scale > 1 else { 'cond': null_ctx, 'mask': cont_mask, 'text_position_embeddings': self.text_position_embeddings.pos if hasattr( self.text_position_embeddings, 'pos') else None }, steps=sample_steps, show_progress=True, seed=seed, guide_scale=guide_scale, guide_rescale=guide_rescale, return_intermediate=None, **kwargs) # Decode to Pixel Space samples = unpack_tensor_into_imagelist(latent, x_shapes) x_samples = self.decode_first_stage(samples) imgs = [ torch.clamp((x_i + 1.0) / 2.0 + self.decoder_bias / 255, min=0.0, max=1.0).squeeze(0).permute(1, 2, 0).cpu().numpy() for x_i in x_samples ] imgs = [Image.fromarray((img * 255).astype(np.uint8)) for img in imgs] return imgs def cond_stage_embeddings(self, prompt, edit_image, cont, cont_mask): if self.use_text_pos_embeddings and not torch.sum( self.text_position_embeddings.pos) > 0: identifier_cont, _ = getattr(get_model(self.cond_stage_model), 'encode')(self.text_indentifers, return_mask=True) self.text_position_embeddings.load_state_dict( {'pos': identifier_cont[:, 0, :]}) cont_, cont_mask_ = [], [] for pp, edit, c, cm in zip(prompt, edit_image, cont, cont_mask): if isinstance(pp, list): cont_.append([c[-1], *c] if len(edit) > 0 else [c[-1]]) cont_mask_.append([cm[-1], *cm] if len(edit) > 0 else [cm[-1]]) else: raise NotImplementedError return cont_, cont_mask_ ================================================ FILE: chatbot/run_gradio.py ================================================ # -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import argparse import base64 import copy import csv import glob import io import os import random import re import string import sys import threading import warnings import cv2 import gradio as gr import numpy as np import torch import transformers from PIL import Image from transformers import AutoModel, AutoTokenizer from scepter.modules.utils.config import Config from scepter.modules.utils.directory import get_md5 from scepter.modules.utils.file_system import FS from scepter.studio.utils.env import init_env from importlib.metadata import version from ace_inference import ACEInference from example import get_examples from utils import load_image csv.field_size_limit(sys.maxsize) refresh_sty = '\U0001f504' # 🔄 clear_sty = '\U0001f5d1' # 🗑️ upload_sty = '\U0001f5bc' # 🖼️ sync_sty = '\U0001f4be' # 💾 chat_sty = '\U0001F4AC' # 💬 video_sty = '\U0001f3a5' # 🎥 lock = threading.Lock() class ChatBotUI(object): def __init__(self, cfg_general_file, is_debug=False, language='en', root_work_dir='./'): try: from diffusers import CogVideoXImageToVideoPipeline from diffusers.utils import export_to_video except Exception as e: print(f"Import diffusers failed, please install or upgrade diffusers. 
Error information: {e}") if isinstance(cfg_general_file, str): cfg = Config(cfg_file=cfg_general_file) else: cfg = cfg_general_file cfg.WORK_DIR = os.path.join(root_work_dir, cfg.WORK_DIR) if not FS.exists(cfg.WORK_DIR): FS.make_dir(cfg.WORK_DIR) cfg = init_env(cfg) self.cache_dir = cfg.WORK_DIR self.chatbot_examples = get_examples(self.cache_dir) if not cfg.get('SKIP_EXAMPLES', False) else [] self.model_cfg_dir = cfg.MODEL.EDIT_MODEL.MODEL_CFG_DIR self.model_yamls = glob.glob(os.path.join(self.model_cfg_dir, '*.yaml')) self.model_choices = dict() self.default_model_name = '' for i in self.model_yamls: model_cfg = Config(load=True, cfg_file=i) model_name = model_cfg.NAME if model_cfg.IS_DEFAULT: self.default_model_name = model_name self.model_choices[model_name] = model_cfg print('Models: ', self.model_choices.keys()) assert len(self.model_choices) > 0 if self.default_model_name == "": self.default_model_name = list(self.model_choices.keys())[0] self.model_name = self.default_model_name self.pipe = ACEInference() self.pipe.init_from_cfg(self.model_choices[self.default_model_name]) self.max_msgs = 20 self.enable_i2v = cfg.get('ENABLE_I2V', False) self.gradio_version = version('gradio') if self.enable_i2v: self.i2v_model_dir = cfg.MODEL.I2V.MODEL_DIR self.i2v_model_name = cfg.MODEL.I2V.MODEL_NAME if self.i2v_model_name == 'CogVideoX-5b-I2V': with FS.get_dir_to_local_dir(self.i2v_model_dir) as local_dir: self.i2v_pipe = CogVideoXImageToVideoPipeline.from_pretrained( local_dir, torch_dtype=torch.bfloat16).cuda() else: raise NotImplementedError with FS.get_dir_to_local_dir( cfg.MODEL.CAPTIONER.MODEL_DIR) as local_dir: self.captioner = AutoModel.from_pretrained( local_dir, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, use_flash_attn=True, trust_remote_code=True).eval().cuda() self.llm_tokenizer = AutoTokenizer.from_pretrained( local_dir, trust_remote_code=True, use_fast=False) self.llm_generation_config = dict(max_new_tokens=1024, do_sample=True) self.llm_prompt = cfg.LLM.PROMPT self.llm_max_num = 2 with FS.get_dir_to_local_dir( cfg.MODEL.ENHANCER.MODEL_DIR) as local_dir: self.enhancer = transformers.pipeline( 'text-generation', model=local_dir, model_kwargs={'torch_dtype': torch.bfloat16}, device_map='auto', ) sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets. For example, outputting "a beautiful morning in the woods with the sun peeking through the trees" will trigger your partner bot to output a video of a forest morning, as described. You will be prompted by people looking to create detailed, amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive. There are a few rules to follow: You will only ever output a single video description per user request. When modifications are requested, you should not simply make the description longer. You should refactor the entire description to integrate the suggestions. Other times the user will not want modifications, but instead will want a new image. In this case, you should ignore your previous conversation with the user. Video descriptions must have the same number of words as the examples below. Extra words will be ignored.
""" self.enhance_ctx = [ { 'role': 'system', 'content': sys_prompt }, { 'role': 'user', 'content': 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "a girl is on the beach"', }, { 'role': 'assistant', 'content': "A radiant woman stands on a deserted beach, arms outstretched, wearing a beige trench coat, white blouse, light blue jeans, and chic boots, against a backdrop of soft sky and sea. Moments later, she is seen mid-twirl, arms exuberant, with the lighting suggesting dawn or dusk. Then, she runs along the beach, her attire complemented by an off-white scarf and black ankle boots, the tranquil sea behind her. Finally, she holds a paper airplane, her pose reflecting joy and freedom, with the ocean's gentle waves and the sky's soft pastel hues enhancing the serene ambiance.", }, { 'role': 'user', 'content': 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "A man jogging on a football field"', }, { 'role': 'assistant', 'content': "A determined man in athletic attire, including a blue long-sleeve shirt, black shorts, and blue socks, jogs around a snow-covered soccer field, showcasing his solitary exercise in a quiet, overcast setting. His long dreadlocks, focused expression, and the serene winter backdrop highlight his dedication to fitness. As he moves, his attire, consisting of a blue sports sweatshirt, black athletic pants, gloves, and sneakers, grips the snowy ground. He is seen running past a chain-link fence enclosing the playground area, with a basketball hoop and children's slide, suggesting a moment of solitary exercise amidst the empty field.", }, { 'role': 'user', 'content': 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : " A woman is dancing, HD footage, close-up"', }, { 'role': 'assistant', 'content': 'A young woman with her hair in an updo and wearing a teal hoodie stands against a light backdrop, initially looking over her shoulder with a contemplative expression. She then confidently makes a subtle dance move, suggesting rhythm and movement. Next, she appears poised and focused, looking directly at the camera. Her expression shifts to one of introspection as she gazes downward slightly. 
Finally, she dances with confidence, her left hand over her heart, symbolizing a poignant moment, all while dressed in the same teal hoodie against a plain, light-colored background.', }, ] def create_ui(self): css = '.chatbot.prose.md {opacity: 1.0 !important} #chatbot {opacity: 1.0 !important}' with gr.Blocks(css=css, title='Chatbot', head='Chatbot', analytics_enabled=False): self.history = gr.State(value=[]) self.images = gr.State(value={}) self.history_result = gr.State(value={}) self.retry_msg = gr.State(value='') with gr.Group(): self.ui_mode = gr.State(value='legacy') with gr.Row(equal_height=True, visible=False) as self.chat_group: with gr.Column(visible=True) as self.chat_page: self.chatbot = gr.Chatbot( height=600, value=[], bubble_full_width=False, show_copy_button=True, container=False, placeholder='Chat Box') with gr.Row(): self.clear_btn = gr.Button(clear_sty + ' Clear Chat', size='sm') with gr.Column(visible=False) as self.editor_page: with gr.Tabs(visible=False) as self.upload_tabs: with gr.Tab(id='ImageUploader', label='Image Uploader', visible=True) as self.upload_tab: self.image_uploader = gr.Image( height=550, interactive=True, type='pil', image_mode='RGB', sources=['upload'], elem_id='image_uploader', format='png') with gr.Row(): self.sub_btn_1 = gr.Button( value='Submit', elem_id='upload_submit') self.ext_btn_1 = gr.Button(value='Exit') with gr.Tabs(visible=False) as self.edit_tabs: with gr.Tab(id='ImageEditor', label='Image Editor') as self.edit_tab: self.mask_type = gr.Dropdown( label='Mask Type', choices=[ 'Background', 'Composite', 'Outpainting' ], value='Background') self.mask_type_info = gr.HTML( value= "
Background mode will not erase the visual content in the mask area
" ) with gr.Accordion( label='Outpainting Setting', open=True, visible=False) as self.outpaint_tab: with gr.Row(variant='panel'): self.top_ext = gr.Slider( show_label=True, label='Top Extend Ratio', minimum=0.0, maximum=2.0, step=0.1, value=0.25) self.bottom_ext = gr.Slider( show_label=True, label='Bottom Extend Ratio', minimum=0.0, maximum=2.0, step=0.1, value=0.25) with gr.Row(variant='panel'): self.left_ext = gr.Slider( show_label=True, label='Left Extend Ratio', minimum=0.0, maximum=2.0, step=0.1, value=0.25) self.right_ext = gr.Slider( show_label=True, label='Right Extend Ratio', minimum=0.0, maximum=2.0, step=0.1, value=0.25) with gr.Row(variant='panel'): self.img_pad_btn = gr.Button( value='Pad Image') self.image_editor = gr.ImageMask( value=None, sources=[], layers=False, label='Edit Image', elem_id='image_editor', format='png') with gr.Row(): self.sub_btn_2 = gr.Button( value='Submit', elem_id='edit_submit') self.ext_btn_2 = gr.Button(value='Exit') with gr.Tab(id='ImageViewer', label='Image Viewer') as self.image_view_tab: if self.gradio_version >= '5.0.0': self.image_viewer = gr.Image( label='Image', type='pil', show_download_button=True, elem_id='image_viewer') else: try: from gradio_imageslider import ImageSlider except Exception as e: print(f"Import gradio_imageslider failed, please install.") self.image_viewer = ImageSlider( label='Image', type='pil', show_download_button=True, elem_id='image_viewer') self.ext_btn_3 = gr.Button(value='Exit') with gr.Tab(id='VideoViewer', label='Video Viewer', visible=False) as self.video_view_tab: self.video_viewer = gr.Video( label='Video', interactive=False, sources=[], format='mp4', show_download_button=True, elem_id='video_viewer', loop=True, autoplay=True) self.ext_btn_4 = gr.Button(value='Exit') with gr.Row(equal_height=True, visible=True) as self.legacy_group: with gr.Column(): self.legacy_image_uploader = gr.Image( height=550, interactive=True, type='pil', image_mode='RGB', elem_id='legacy_image_uploader', format='png') with gr.Column(): self.legacy_image_viewer = gr.Image( label='Image', height=550, type='pil', interactive=False, show_download_button=True, elem_id='image_viewer') with gr.Accordion(label='Setting', open=False): with gr.Row(): self.model_name_dd = gr.Dropdown( choices=self.model_choices, value=self.default_model_name, label='Model Version') with gr.Row(): self.negative_prompt = gr.Textbox( value='', placeholder= 'Negative prompt used for Classifier-Free Guidance', label='Negative Prompt', container=False) with gr.Row(): # REFINER_PROMPT self.refiner_prompt = gr.Textbox( value=self.pipe.input.get("refiner_prompt", ""), visible=self.pipe.input.get("refiner_prompt", None) is not None, placeholder= 'Prompt used for refiner', label='Refiner Prompt', container=False) with gr.Row(): with gr.Column(scale=8, min_width=500): with gr.Row(): self.step = gr.Slider(minimum=1, maximum=1000, value=self.pipe.input.get("sample_steps", 20), visible=self.pipe.input.get("sample_steps", None) is not None, label='Sample Step') self.cfg_scale = gr.Slider( minimum=1.0, maximum=20.0, value=self.pipe.input.get("guide_scale", 4.5), visible=self.pipe.input.get("guide_scale", None) is not None, label='Guidance Scale') self.rescale = gr.Slider(minimum=0.0, maximum=1.0, value=self.pipe.input.get("guide_rescale", 0.5), visible=self.pipe.input.get("guide_rescale", None) is not None, label='Rescale') self.refiner_scale = gr.Slider(minimum=-0.1, maximum=1.0, value=self.pipe.input.get("refiner_scale", -1), visible=self.pipe.input.get("refiner_scale", 
None) is not None, label='Refiner Scale') self.seed = gr.Slider(minimum=-1, maximum=10000000, value=-1, label='Seed') self.output_height = gr.Slider( minimum=256, maximum=1440, value=self.pipe.input.get("output_height", 1024), visible=self.pipe.input.get("output_height", None) is not None, label='Output Height') self.output_width = gr.Slider( minimum=256, maximum=1440, value=self.pipe.input.get("output_width", 1024), visible=self.pipe.input.get("output_width", None) is not None, label='Output Width') with gr.Column(scale=1, min_width=50): self.use_history = gr.Checkbox(value=False, label='Use History') self.use_ace = gr.Checkbox(value=self.pipe.input.get("use_ace", True), visible=self.pipe.input.get("use_ace", None) is not None, label='Use ACE') self.video_auto = gr.Checkbox( value=False, label='Auto Gen Video', visible=self.enable_i2v) with gr.Row(variant='panel', equal_height=True, visible=self.enable_i2v): self.video_fps = gr.Slider(minimum=1, maximum=16, value=8, label='Video FPS', visible=True) self.video_frames = gr.Slider(minimum=8, maximum=49, value=49, label='Video Frame Num', visible=True) self.video_step = gr.Slider(minimum=1, maximum=1000, value=50, label='Video Sample Step', visible=True) self.video_cfg_scale = gr.Slider( minimum=1.0, maximum=20.0, value=6.0, label='Video Guidance Scale', visible=True) self.video_seed = gr.Slider(minimum=-1, maximum=10000000, value=-1, label='Video Seed', visible=True) with gr.Row(): self.chatbot_inst = """ **Instruction**: 1. Click 'Upload' button to upload one or more images as input images. 2. Enter '@' in the text box will exhibit all images in the gallery. 3. Select the image you wish to edit from the gallery, and its Image ID will be displayed in the text box. 4. Compose the editing instruction for the selected image, incorporating image id '@xxxxxx' into your instruction. For example, you might say, "Change the girl's skirt in @123456 to blue." The '@xxxxx' token will facilitate the identification of the specific image, and will be automatically replaced by a special token '{image}' in the instruction. Furthermore, it is also possible to engage in text-to-image generation without any initial image input. 5. Once your instructions are prepared, please click the "Chat" button to view the edited result in the chat window. 6. **Important** To render text on an image, please ensure to include a space between each letter. For instance, "add text 'g i r l' on the mask area of @xxxxx". 7. To implement local editing based on a specified mask, simply click on the image within the chat window to access the image editor. Here, you can draw a mask and then click the 'Submit' button to upload the edited image along with the mask. For inpainting tasks, select the 'Composite' mask type, while for outpainting tasks, choose the 'Outpainting' mask type. For all other local editing tasks, please select the 'Background' mask type. 8. If you find our work valuable, we invite you to refer to the [ACE Page](https://ali-vilab.github.io/ace-page/) for comprehensive information. """ self.legacy_inst = """ **Instruction**: 1. You can edit the image by uploading it; if no image is uploaded, an image will be generated from text.. 2. Enter '@' in the text box will exhibit all images in the gallery. 3. Select the image you wish to edit from the gallery, and its Image ID will be displayed in the text box. 4. **Important** To render text on an image, please ensure to include a space between each letter. For instance, "add text 'g i r l' on the mask area of @xxxxx". 5. 
To perform multi-step editing, partial editing, inpainting, outpainting, and other operations, please click the Chatbot Checkbox to enable the conversational editing mode and follow the relevant instructions.. 6. If you find our work valuable, we invite you to refer to the [ACE Page](https://ali-vilab.github.io/ace-page/) for comprehensive information. """ self.instruction = gr.Markdown(value=self.legacy_inst) with gr.Row(variant='panel', equal_height=True, show_progress=False): with gr.Column(scale=1, min_width=100, visible=False) as self.upload_panel: self.upload_btn = gr.Button(value=upload_sty + ' Upload', variant='secondary') with gr.Column(scale=5, min_width=500): self.text = gr.Textbox( placeholder='Input "@" find history of image', label='Instruction', container=False) with gr.Column(scale=1, min_width=100): self.chat_btn = gr.Button(value='Generate', variant='primary') with gr.Column(scale=1, min_width=100): self.retry_btn = gr.Button(value=refresh_sty + ' Retry', variant='secondary') with gr.Column(scale=1, min_width=100): self.mode_checkbox = gr.Checkbox( value=False, label='ChatBot') with gr.Column(scale=(1 if self.enable_i2v else 0), min_width=0): self.video_gen_btn = gr.Button(value=video_sty + ' Gen Video', variant='secondary', visible=self.enable_i2v) with gr.Column(scale=(1 if self.enable_i2v else 0), min_width=0): self.extend_prompt = gr.Checkbox( value=True, label='Extend Prompt', visible=self.enable_i2v) with gr.Row(): self.gallery = gr.Gallery(visible=False, label='History', columns=10, allow_preview=False, interactive=False) self.eg = gr.Column(visible=True) def set_callbacks(self, *args, **kwargs): ######################################## def change_model(model_name): if model_name not in self.model_choices: gr.Info('The provided model name is not a valid choice!') return model_name, gr.update(), gr.update() if model_name != self.model_name: lock.acquire() del self.pipe torch.cuda.empty_cache() torch.cuda.ipc_collect() self.pipe = ACEInference() self.pipe.init_from_cfg(self.model_choices[model_name]) self.model_name = model_name lock.release() return (model_name, gr.update(), gr.update(), gr.Slider( value=self.pipe.input.get("sample_steps", 20), visible=self.pipe.input.get("sample_steps", None) is not None), gr.Slider( value=self.pipe.input.get("guide_scale", 4.5), visible=self.pipe.input.get("guide_scale", None) is not None), gr.Slider( value=self.pipe.input.get("guide_rescale", 0.5), visible=self.pipe.input.get("guide_rescale", None) is not None), gr.Slider( value=self.pipe.input.get("output_height", 1024), visible=self.pipe.input.get("output_height", None) is not None), gr.Slider( value=self.pipe.input.get("output_width", 1024), visible=self.pipe.input.get("output_width", None) is not None), gr.Textbox( value=self.pipe.input.get("refiner_prompt", ""), visible=self.pipe.input.get("refiner_prompt", None) is not None), gr.Slider( value=self.pipe.input.get("refiner_scale", -1), visible=self.pipe.input.get("refiner_scale", None) is not None ), gr.Checkbox( value=self.pipe.input.get("use_ace", True), visible=self.pipe.input.get("use_ace", None) is not None ) ) self.model_name_dd.change( change_model, inputs=[self.model_name_dd], outputs=[ self.model_name_dd, self.chatbot, self.text, self.step, self.cfg_scale, self.rescale, self.output_height, self.output_width, self.refiner_prompt, self.refiner_scale, self.use_ace]) def mode_change(mode_check): if mode_check: # ChatBot return ( gr.Row(visible=False), gr.Row(visible=True), gr.Button(value='Generate'), 
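# This ChatBot branch of mode_change() shows self.chat_group (plus the upload
# panel and the chatbot instructions) and relabels the action button to
# 'Generate'; the Legacy branch below inverts the visibility and restores the
# chat_sty + ' Chat' label. Only widget visibility changes; session state
# (history, images, history_result) is untouched.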
gr.State(value='chatbot'), gr.Column(visible=True), gr.Markdown(value=self.chatbot_inst) ) else: # Legacy return ( gr.Row(visible=True), gr.Row(visible=False), gr.Button(value=chat_sty + ' Chat'), gr.State(value='legacy'), gr.Column(visible=False), gr.Markdown(value=self.legacy_inst) ) self.mode_checkbox.change(mode_change, inputs=[self.mode_checkbox], outputs=[self.legacy_group, self.chat_group, self.chat_btn, self.ui_mode, self.upload_panel, self.instruction]) ######################################## def generate_gallery(text, images): if text.endswith(' '): return gr.update(), gr.update(visible=False) elif text.endswith('@'): gallery_info = [] for image_id, image_meta in images.items(): thumbnail_path = image_meta['thumbnail'] gallery_info.append((thumbnail_path, image_id)) return gr.update(), gr.update(visible=True, value=gallery_info) else: gallery_info = [] match = re.search('@([^@ ]+)$', text) if match: prefix = match.group(1) for image_id, image_meta in images.items(): if not image_id.startswith(prefix): continue thumbnail_path = image_meta['thumbnail'] gallery_info.append((thumbnail_path, image_id)) if len(gallery_info) > 0: return gr.update(), gr.update(visible=True, value=gallery_info) else: return gr.update(), gr.update(visible=False) else: return gr.update(), gr.update(visible=False) self.text.input(generate_gallery, inputs=[self.text, self.images], outputs=[self.text, self.gallery], show_progress='hidden') ######################################## def select_image(text, evt: gr.SelectData): image_id = evt.value['caption'] text = '@'.join(text.split('@')[:-1]) + f'@{image_id} ' return gr.update(value=text), gr.update(visible=False, value=None) self.gallery.select(select_image, inputs=self.text, outputs=[self.text, self.gallery]) ######################################## def generate_video(message, extend_prompt, history, images, num_steps, num_frames, cfg_scale, fps, seed, progress=gr.Progress(track_tqdm=True)): from diffusers.utils import export_to_video generator = torch.Generator(device='cuda').manual_seed(seed) img_ids = re.findall('@(.*?)[ ,;.?$]', message) if len(img_ids) == 0: history.append(( message, 'Sorry, no images were found in the prompt to be used as the first frame of the video.' 
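# history is kept as a bounded log: every path that appends also drains the
# oldest entries via `while len(history) >= self.max_msgs: history.pop(0)`,
# so the chat transcript never exceeds self.max_msgs (20 by default) rows.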
)) while len(history) >= self.max_msgs: history.pop(0) return history, self.get_history( history), gr.update(), gr.update(visible=False) img_id = img_ids[0] prompt = re.sub(rf'@{img_id}\s+', '', message) if extend_prompt: messages = copy.deepcopy(self.enhance_ctx) messages.append({ 'role': 'user', 'content': f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{prompt}"', }) lock.acquire() outputs = self.enhancer( messages, max_new_tokens=200, ) prompt = outputs[0]['generated_text'][-1]['content'] print(prompt) lock.release() img_meta = images[img_id] img_path = img_meta['image'] image = Image.open(img_path).convert('RGB') lock.acquire() video = self.i2v_pipe( prompt=prompt, image=image, num_videos_per_prompt=1, num_inference_steps=num_steps, num_frames=num_frames, guidance_scale=cfg_scale, generator=generator, ).frames[0] lock.release() out_video_path = export_to_video(video, fps=fps) history.append(( f"Based on first frame @{img_id} and description '{prompt}', generate a video", 'This is generated video:')) history.append((None, out_video_path)) while len(history) >= self.max_msgs: history.pop(0) return history, self.get_history(history), gr.update( value=''), gr.update(visible=False) self.video_gen_btn.click( generate_video, inputs=[ self.text, self.extend_prompt, self.history, self.images, self.video_step, self.video_frames, self.video_cfg_scale, self.video_fps, self.video_seed ], outputs=[self.history, self.chatbot, self.text, self.gallery]) ######################################## def run_chat( message, legacy_image, ui_mode, use_ace, extend_prompt, history, images, use_history, history_result, negative_prompt, cfg_scale, rescale, refiner_prompt, refiner_scale, step, seed, output_h, output_w, video_auto, video_steps, video_frames, video_cfg_scale, video_fps, video_seed, progress=gr.Progress(track_tqdm=True)): legacy_img_ids = [] if ui_mode == 'legacy': if legacy_image is not None: history, images, img_id = self.add_uploaded_image_to_history( legacy_image, history, images) legacy_img_ids.append(img_id) retry_msg = message gen_id = get_md5(message)[:12] save_path = os.path.join(self.cache_dir, f'{gen_id}.png') img_ids = re.findall('@(.*?)[ ,;.?$]', message) history_io = None if len(img_ids) < 1: img_ids = legacy_img_ids for img_id in img_ids: if f'@{img_id}' not in message: message = f'@{img_id} ' + message new_message = message if len(img_ids) > 0: edit_image, edit_image_mask, edit_task = [], [], [] for i, img_id in enumerate(img_ids): if img_id not in images: gr.Info( f'The input image ID {img_id} does not exist... Skip loading image.' ) continue placeholder = '{image}' if i == 0 else '{' + f'image{i}' + '}' new_message = re.sub(f'@{img_id}', placeholder, new_message) img_meta = images[img_id] img_path = img_meta['image'] img_mask = img_meta['mask'] img_mask_type = img_meta['mask_type'] if img_mask_type is not None and img_mask_type == 'Composite': task = 'inpainting' else: task = '' edit_image.append(Image.open(img_path).convert('RGB')) edit_image_mask.append( Image.open(img_mask).
convert('L') if img_mask is not None else None) edit_task.append(task) if use_history and (img_id in history_result): history_io = history_result[img_id] buffered = io.BytesIO() edit_image[0].save(buffered, format='PNG') img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8') img_str = f'' pre_info = f'Received one or more images, so image editing is conducted.\n The first input image @{img_ids[0]} is:\n {img_str}' else: pre_info = 'No image ids were found in the provided text prompt, so text-guided image generation is conducted. \n' edit_image = None edit_image_mask = None edit_task = '' print(new_message) imgs = self.pipe( image=edit_image, mask=edit_image_mask, task=edit_task, prompt=[new_message] * len(edit_image) if edit_image is not None else [new_message], negative_prompt=[negative_prompt] * len(edit_image) if edit_image is not None else [negative_prompt], history_io=history_io, output_height=output_h, output_width=output_w, sampler='ddim', sample_steps=step, guide_scale=cfg_scale, guide_rescale=rescale, seed=seed, refiner_prompt=refiner_prompt, refiner_scale=refiner_scale, use_ace=use_ace ) img = imgs[0] img.save(save_path, format='PNG') if history_io: history_io_new = copy.deepcopy(history_io) history_io_new['image'] += edit_image[:1] history_io_new['mask'] += edit_image_mask[:1] history_io_new['task'] += edit_task[:1] history_io_new['prompt'] += [new_message] history_io_new['image'] = history_io_new['image'][-5:] history_io_new['mask'] = history_io_new['mask'][-5:] history_io_new['task'] = history_io_new['task'][-5:] history_io_new['prompt'] = history_io_new['prompt'][-5:] history_result[gen_id] = history_io_new elif edit_image is not None and len(edit_image) > 0: history_io_new = { 'image': edit_image[:1], 'mask': edit_image_mask[:1], 'task': edit_task[:1], 'prompt': [new_message] } history_result[gen_id] = history_io_new w, h = img.size if w > h: tb_w = 128 tb_h = int(h * tb_w / w) else: tb_h = 128 tb_w = int(w * tb_h / h) thumbnail_path = os.path.join(self.cache_dir, f'{gen_id}_thumbnail.jpg') thumbnail = img.resize((tb_w, tb_h)) thumbnail.save(thumbnail_path, format='JPEG') images[gen_id] = { 'image': save_path, 'mask': None, 'mask_type': None, 'thumbnail': thumbnail_path } buffered = io.BytesIO() img.convert('RGB').save(buffered, format='PNG') img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8') img_str = f'' history.append( (message, f'{pre_info} The generated image @{gen_id} is:\n {img_str}')) if video_auto: if video_seed is None or video_seed == -1: video_seed = random.randint(0, 10000000) lock.acquire() generator = torch.Generator( device='cuda').manual_seed(video_seed) pixel_values = load_image(img.convert('RGB'), max_num=self.llm_max_num).to( torch.bfloat16).cuda() prompt = self.captioner.chat(self.llm_tokenizer, pixel_values, self.llm_prompt, self.llm_generation_config) print(prompt) lock.release() if extend_prompt: messages = copy.deepcopy(self.enhance_ctx) messages.append({ 'role': 'user', 'content': f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{prompt}"', }) lock.acquire() outputs = self.enhancer( messages, max_new_tokens=200, ) prompt = outputs[0]['generated_text'][-1]['content'] print(prompt) lock.release() lock.acquire() video = self.i2v_pipe( prompt=prompt, image=img, num_videos_per_prompt=1, num_inference_steps=video_steps, num_frames=video_frames, guidance_scale=video_cfg_scale, generator=generator, ).frames[0] lock.release() out_video_path = export_to_video(video, 
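# The auto-video branch above chains three GPU stages, each serialized by the
# module-level `lock`: InternVL (self.captioner.chat) captions the freshly
# generated image, the Llama-based self.enhancer optionally rewrites that
# caption into a detailed video prompt, and CogVideoX-5b-I2V (self.i2v_pipe)
# animates it before export_to_video writes the mp4.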
fps=video_fps) history.append(( f"Based on first frame @{gen_id} and description '{prompt}', generate a video", 'This is generated video:')) history.append((None, out_video_path)) while len(history) >= self.max_msgs: history.pop(0) return (history, images, gr.Image(value=save_path), history_result, self.get_history( history), gr.update(), gr.update( visible=False), retry_msg) chat_inputs = [ self.legacy_image_uploader, self.ui_mode, self.use_ace, self.extend_prompt, self.history, self.images, self.use_history, self.history_result, self.negative_prompt, self.cfg_scale, self.rescale, self.refiner_prompt, self.refiner_scale, self.step, self.seed, self.output_height, self.output_width, self.video_auto, self.video_step, self.video_frames, self.video_cfg_scale, self.video_fps, self.video_seed ] chat_outputs = [ self.history, self.images, self.legacy_image_viewer, self.history_result, self.chatbot, self.text, self.gallery, self.retry_msg ] self.chat_btn.click(run_chat, inputs=[self.text] + chat_inputs, outputs=chat_outputs) self.text.submit(run_chat, inputs=[self.text] + chat_inputs, outputs=chat_outputs) def retry_fn(*args): return run_chat(*args) self.retry_btn.click(retry_fn, inputs=[self.retry_msg] + chat_inputs, outputs=chat_outputs) ######################################## def run_example(task, img, img_mask, ref1, prompt, seed): edit_image, edit_image_mask, edit_task = [], [], [] if img is not None: w, h = img.size if w > 2048: ratio = w / 2048. w = 2048 h = int(h / ratio) if h > 2048: ratio = h / 2048. h = 2048 w = int(w / ratio) img = img.resize((w, h)) edit_image.append(img) if img_mask is not None: img_mask = img_mask if np.sum(np.array(img_mask)) > 0 else None edit_image_mask.append( img_mask if img_mask is not None else None) edit_task.append(task) if ref1 is not None: ref1 = ref1 if np.sum(np.array(ref1)) > 0 else None if ref1 is not None: edit_image.append(ref1) edit_image_mask.append(None) edit_task.append('') buffered = io.BytesIO() img.save(buffered, format='PNG') img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8') img_str = f'' pre_info = f'Received one or more images, so image editing is conducted.\n The first input image is:\n {img_str}' else: pre_info = 'No image ids were found in the provided text prompt, so text-guided image generation is conducted. 
\n' edit_image = None edit_image_mask = None edit_task = '' img_num = len(edit_image) if edit_image is not None else 1 imgs = self.pipe( image=edit_image, mask=edit_image_mask, task=edit_task, prompt=[prompt] * img_num, negative_prompt=[''] * img_num, seed=seed, refiner_prompt=self.pipe.input.get("refiner_prompt", ""), refiner_scale=self.pipe.input.get("refiner_scale", 0.0), ) img = imgs[0] buffered = io.BytesIO() img.convert('RGB').save(buffered, format='PNG') img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8') img_str = f'' history = [(prompt, f'{pre_info} The generated image is:\n {img_str}')] img_id = get_md5(img_b64)[:12] save_path = os.path.join(self.cache_dir, f'{img_id}.png') img.convert('RGB').save(save_path) return self.get_history(history), gr.update(value=''), gr.update( visible=False), gr.update(value=save_path), gr.update(value=-1) with self.eg: self.example_task = gr.Text(label='Task Name', value='', visible=False) self.example_image = gr.Image(label='Edit Image', type='pil', image_mode='RGB', visible=False) self.example_mask = gr.Image(label='Edit Image Mask', type='pil', image_mode='L', visible=False) self.example_ref_im1 = gr.Image(label='Ref Image', type='pil', image_mode='RGB', visible=False) self.examples = gr.Examples( fn=run_example, examples=self.chatbot_examples, inputs=[ self.example_task, self.example_image, self.example_mask, self.example_ref_im1, self.text, self.seed ], outputs=[self.chatbot, self.text, self.gallery, self.legacy_image_viewer, self.seed], examples_per_page=4, cache_examples=False, run_on_click=True) ######################################## def upload_image(): return (gr.update(visible=True, scale=1), gr.update(visible=True, scale=1), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)) self.upload_btn.click(upload_image, inputs=[], outputs=[ self.chat_page, self.editor_page, self.upload_tab, self.edit_tab, self.image_view_tab, self.video_view_tab, self.upload_tabs ]) ######################################## def edit_image(evt: gr.SelectData): if isinstance(evt.value, str): img_b64s = re.findall( '', evt.value) imgs = [ Image.open(io.BytesIO(base64.b64decode(copy.deepcopy(i)))) for i in img_b64s ] if len(imgs) > 0: if len(imgs) == 2: if self.gradio_version >= '5.0.0': view_img = copy.deepcopy(imgs[-1]) else: view_img = copy.deepcopy(imgs) edit_img = copy.deepcopy(imgs[-1]) else: if self.gradio_version >= '5.0.0': view_img = copy.deepcopy(imgs[-1]) else: view_img = [ copy.deepcopy(imgs[-1]), copy.deepcopy(imgs[-1]) ] edit_img = copy.deepcopy(imgs[-1]) return (gr.update(visible=True, scale=1), gr.update(visible=True, scale=1), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), gr.update(value=edit_img), gr.update(value=view_img), gr.update(value=None), gr.update(visible=True)) else: return (gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update()) elif isinstance(evt.value, dict) and evt.value.get( 'component', '') == 'video': value = evt.value['value']['video']['path'] return (gr.update(visible=True, scale=1), gr.update(visible=True, scale=1), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(), gr.update(), gr.update(value=value), gr.update()) else: return (gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), 
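# chatbot.select dispatches on the clicked payload: message strings carrying
# embedded base64 images open the image editor/viewer tabs, dict payloads
# whose 'component' is 'video' open the video viewer, and anything else falls
# through to this tuple of bare gr.update() no-ops.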
gr.update()) self.chatbot.select(edit_image, outputs=[ self.chat_page, self.editor_page, self.upload_tab, self.edit_tab, self.image_view_tab, self.video_view_tab, self.image_editor, self.image_viewer, self.video_viewer, self.edit_tabs ]) if self.gradio_version < '5.0.0': self.image_viewer.change(lambda x: x, inputs=self.image_viewer, outputs=self.image_viewer) ######################################## def submit_upload_image(image, history, images): history, images, _ = self.add_uploaded_image_to_history( image, history, images) return gr.update(visible=False), gr.update( visible=True), gr.update( value=self.get_history(history)), history, images self.sub_btn_1.click( submit_upload_image, inputs=[self.image_uploader, self.history, self.images], outputs=[ self.editor_page, self.chat_page, self.chatbot, self.history, self.images ]) ######################################## def submit_edit_image(imagemask, mask_type, history, images): history, images = self.add_edited_image_to_history( imagemask, mask_type, history, images) return gr.update(visible=False), gr.update( visible=True), gr.update( value=self.get_history(history)), history, images self.sub_btn_2.click(submit_edit_image, inputs=[ self.image_editor, self.mask_type, self.history, self.images ], outputs=[ self.editor_page, self.chat_page, self.chatbot, self.history, self.images ]) ######################################## def exit_edit(): return gr.update(visible=False), gr.update(visible=True, scale=3) self.ext_btn_1.click(exit_edit, outputs=[self.editor_page, self.chat_page]) self.ext_btn_2.click(exit_edit, outputs=[self.editor_page, self.chat_page]) self.ext_btn_3.click(exit_edit, outputs=[self.editor_page, self.chat_page]) self.ext_btn_4.click(exit_edit, outputs=[self.editor_page, self.chat_page]) ######################################## def update_mask_type_info(mask_type): if mask_type == 'Background': info = 'Background mode will not erase the visual content in the mask area' visible = False elif mask_type == 'Composite': info = 'Composite mode will erase the visual content in the mask area' visible = False elif mask_type == 'Outpainting': info = 'Outpaint mode is used for preparing input image for outpainting task' visible = True return (gr.update( visible=True, value= f"
{info}
" ), gr.update(visible=visible)) self.mask_type.change(update_mask_type_info, inputs=self.mask_type, outputs=[self.mask_type_info, self.outpaint_tab]) ######################################## def extend_image(top_ratio, bottom_ratio, left_ratio, right_ratio, image): img = cv2.cvtColor(image['background'], cv2.COLOR_RGBA2RGB) h, w = img.shape[:2] new_h = int(h * (top_ratio + bottom_ratio + 1)) new_w = int(w * (left_ratio + right_ratio + 1)) start_h = int(h * top_ratio) start_w = int(w * left_ratio) new_img = np.zeros((new_h, new_w, 3), dtype=np.uint8) new_mask = np.ones((new_h, new_w, 1), dtype=np.uint8) * 255 new_img[start_h:start_h + h, start_w:start_w + w, :] = img new_mask[start_h:start_h + h, start_w:start_w + w] = 0 layer = np.concatenate([new_img, new_mask], axis=2) value = { 'background': new_img, 'composite': new_img, 'layers': [layer] } return gr.update(value=value) self.img_pad_btn.click(extend_image, inputs=[ self.top_ext, self.bottom_ext, self.left_ext, self.right_ext, self.image_editor ], outputs=self.image_editor) ######################################## def clear_chat(history, images, history_result): history.clear() images.clear() history_result.clear() return history, images, history_result, self.get_history(history) self.clear_btn.click( clear_chat, inputs=[self.history, self.images, self.history_result], outputs=[ self.history, self.images, self.history_result, self.chatbot ]) def get_history(self, history): info = [] for item in history: new_item = [None, None] if isinstance(item[0], str) and item[0].endswith('.mp4'): new_item[0] = gr.Video(item[0], format='mp4') else: new_item[0] = item[0] if isinstance(item[1], str) and item[1].endswith('.mp4'): new_item[1] = gr.Video(item[1], format='mp4') else: new_item[1] = item[1] info.append(new_item) return info def generate_random_string(self, length=20): letters_and_digits = string.ascii_letters + string.digits random_string = ''.join( random.choice(letters_and_digits) for i in range(length)) return random_string def add_edited_image_to_history(self, image, mask_type, history, images): if mask_type == 'Composite': img = Image.fromarray(image['composite']) else: img = Image.fromarray(image['background']) img_id = get_md5(self.generate_random_string())[:12] save_path = os.path.join(self.cache_dir, f'{img_id}.png') img.convert('RGB').save(save_path) mask = image['layers'][0][:, :, 3] mask = Image.fromarray(mask).convert('RGB') mask_path = os.path.join(self.cache_dir, f'{img_id}_mask.png') mask.save(mask_path) w, h = img.size if w > h: tb_w = 128 tb_h = int(h * tb_w / w) else: tb_h = 128 tb_w = int(w * tb_h / h) if mask_type == 'Background': comp_mask = np.array(mask, dtype=np.uint8) mask_alpha = (comp_mask[:, :, 0:1].astype(np.float32) * 0.6).astype(np.uint8) comp_mask = np.concatenate([comp_mask, mask_alpha], axis=2) thumbnail = Image.alpha_composite( img.convert('RGBA'), Image.fromarray(comp_mask).convert('RGBA')).convert('RGB') else: thumbnail = img.convert('RGB') thumbnail_path = os.path.join(self.cache_dir, f'{img_id}_thumbnail.jpg') thumbnail = thumbnail.resize((tb_w, tb_h)) thumbnail.save(thumbnail_path, format='JPEG') buffered = io.BytesIO() img.convert('RGB').save(buffered, format='PNG') img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8') img_str = f'' buffered = io.BytesIO() mask.convert('RGB').save(buffered, format='PNG') mask_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8') mask_str = f'' images[img_id] = { 'image': save_path, 'mask': mask_path, 'mask_type': mask_type, 'thumbnail': 
thumbnail_path } history.append(( None, f'This is edited image and mask:\n {img_str} {mask_str} image ID is: {img_id}' )) return history, images def add_uploaded_image_to_history(self, img, history, images): img_id = get_md5(self.generate_random_string())[:12] save_path = os.path.join(self.cache_dir, f'{img_id}.png') w, h = img.size if w > 2048: ratio = w / 2048. w = 2048 h = int(h / ratio) if h > 2048: ratio = h / 2048. h = 2048 w = int(w / ratio) img = img.resize((w, h)) img.save(save_path) w, h = img.size if w > h: tb_w = 128 tb_h = int(h * tb_w / w) else: tb_h = 128 tb_w = int(w * tb_h / h) thumbnail_path = os.path.join(self.cache_dir, f'{img_id}_thumbnail.jpg') thumbnail = img.resize((tb_w, tb_h)) thumbnail.save(thumbnail_path, format='JPEG') images[img_id] = { 'image': save_path, 'mask': None, 'mask_type': None, 'thumbnail': thumbnail_path } buffered = io.BytesIO() img.convert('RGB').save(buffered, format='PNG') img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8') img_str = f'' history.append( (None, f'This is uploaded image:\n {img_str} image ID is: {img_id}')) return history, images, img_id def run_gr(cfg): with gr.Blocks() as demo: chatbot = ChatBotUI(cfg) chatbot.create_ui() chatbot.set_callbacks() demo.launch(server_name='0.0.0.0', server_port=cfg.args.server_port, root_path=cfg.args.root_path) if __name__ == '__main__': parser = argparse.ArgumentParser(description='Argparser for Scepter:\n') parser.add_argument('--server_port', dest='server_port', help='', type=int, default=2345) parser.add_argument('--root_path', dest='root_path', help='', default='') cfg = Config(load=True, parser_ins=parser) run_gr(cfg) ================================================ FILE: chatbot/utils.py ================================================ # -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
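# The helpers in this module implement InternVL-style dynamic tiling: the
# input image is resized to the closest supported aspect ratio, split into up
# to `max_num` 448x448 crops (plus a thumbnail when more than one crop is
# produced), ImageNet-normalized, and stacked into a single tensor.
# Illustrative usage (the file path is hypothetical; run_gradio.py calls this
# with max_num=self.llm_max_num and moves the result to bfloat16 on CUDA):
#
#   pixel_values = load_image('demo.png', input_size=448, max_num=12)
#   # -> FloatTensor of shape [num_crops (+1 thumbnail), 3, 448, 448]
#   pixel_values = pixel_values.to(torch.bfloat16).cuda()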
import torch import torchvision.transforms as T from PIL import Image from torchvision.transforms.functional import InterpolationMode IMAGENET_MEAN = (0.485, 0.456, 0.406) IMAGENET_STD = (0.229, 0.224, 0.225) def build_transform(input_size): MEAN, STD = IMAGENET_MEAN, IMAGENET_STD transform = T.Compose([ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), T.ToTensor(), T.Normalize(mean=MEAN, std=STD) ]) return transform def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): best_ratio_diff = float('inf') best_ratio = (1, 1) area = width * height for ratio in target_ratios: target_aspect_ratio = ratio[0] / ratio[1] ratio_diff = abs(aspect_ratio - target_aspect_ratio) if ratio_diff < best_ratio_diff: best_ratio_diff = ratio_diff best_ratio = ratio elif ratio_diff == best_ratio_diff: if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: best_ratio = ratio return best_ratio def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False): orig_width, orig_height = image.size aspect_ratio = orig_width / orig_height # calculate the existing image aspect ratio target_ratios = set((i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if i * j <= max_num and i * j >= min_num) target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) # find the closest aspect ratio to the target target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size) # calculate the target width and height target_width = image_size * target_aspect_ratio[0] target_height = image_size * target_aspect_ratio[1] blocks = target_aspect_ratio[0] * target_aspect_ratio[1] # resize the image resized_img = image.resize((target_width, target_height)) processed_images = [] for i in range(blocks): box = ((i % (target_width // image_size)) * image_size, (i // (target_width // image_size)) * image_size, ((i % (target_width // image_size)) + 1) * image_size, ((i // (target_width // image_size)) + 1) * image_size) # split the image split_img = resized_img.crop(box) processed_images.append(split_img) assert len(processed_images) == blocks if use_thumbnail and len(processed_images) != 1: thumbnail_img = image.resize((image_size, image_size)) processed_images.append(thumbnail_img) return processed_images def load_image(image_file, input_size=448, max_num=12): if isinstance(image_file, str): image = Image.open(image_file).convert('RGB') else: image = image_file transform = build_transform(input_size=input_size) images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num) pixel_values = [transform(image) for image in images] pixel_values = torch.stack(pixel_values) return pixel_values ================================================ FILE: config/inference_config/chatbot_ui.yaml ================================================ WORK_DIR: ./cache/chatbot FILE_SYSTEM: - NAME: "HuggingfaceFs" TEMP_DIR: ./cache - NAME: "ModelscopeFs" TEMP_DIR: ./cache - NAME: "LocalFs" TEMP_DIR: ./cache - NAME: "HttpFs" TEMP_DIR: ./cache # ENABLE_I2V: False SKIP_EXAMPLES: False # MODEL: EDIT_MODEL: MODEL_CFG_DIR: config/inference_config/models/ DEFAULT: ace_0.6b_512 I2V: MODEL_NAME: CogVideoX-5b-I2V MODEL_DIR: ms://ZhipuAI/CogVideoX-5b-I2V/ CAPTIONER: MODEL_NAME: InternVL2-2B MODEL_DIR: ms://OpenGVLab/InternVL2-2B/ PROMPT: '\nThis image is the first frame of a video. 
Based on this image, please imagine what changes may occur in the next few seconds of the video. Please output brief description, such as "a dog running" or "a person turns to left". No more than 30 words.' ENHANCER: MODEL_NAME: Meta-Llama-3.1-8B-Instruct MODEL_DIR: ms://LLM-Research/Meta-Llama-3.1-8B-Instruct/ ================================================ FILE: config/inference_config/models/ace_0.6b_1024.yaml ================================================ NAME: ACE_0.6B_1024 IS_DEFAULT: False USE_DYNAMIC_MODEL: False DEFAULT_PARAS: PARAS: # INPUT: INPUT_IMAGE: INPUT_MASK: TASK: PROMPT: "" NEGATIVE_PROMPT: "" OUTPUT_HEIGHT: 1024 OUTPUT_WIDTH: 1024 SAMPLER: ddim SAMPLE_STEPS: 50 GUIDE_SCALE: 4.5 GUIDE_RESCALE: 0.5 SEED: -1 TAR_INDEX: 0 OUTPUT: LATENT: IMAGES: SEED: MODULES_PARAS: FIRST_STAGE_MODEL: FUNCTION: - NAME: encode DTYPE: float16 INPUT: ["IMAGE"] - NAME: decode DTYPE: float16 INPUT: ["LATENT"] # DIFFUSION_MODEL: FUNCTION: - NAME: forward DTYPE: float16 INPUT: ["SAMPLE_STEPS", "SAMPLE", "GUIDE_SCALE"] # COND_STAGE_MODEL: FUNCTION: - NAME: encode_list DTYPE: bfloat16 INPUT: ["PROMPT"] # MODEL: NAME: LdmACE PRETRAINED_MODEL: IGNORE_KEYS: [ ] SCALE_FACTOR: 0.18215 SIZE_FACTOR: 8 DECODER_BIAS: 0.5 DEFAULT_N_PROMPT: "" TEXT_IDENTIFIER: [ '{image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ] USE_TEXT_POS_EMBEDDINGS: True # DIFFUSION: NAME: ACEDiffusion PREDICTION_TYPE: eps MIN_SNR_GAMMA: NOISE_SCHEDULER: NAME: LinearScheduler NUM_TIMESTEPS: 1000 BETA_MIN: 0.0001 BETA_MAX: 0.02 # DIFFUSION_MODEL: NAME: DiTACE PRETRAINED_MODEL: ms://iic/ACE-0.6B-1024px@models/dit/ace_0.6b_1024px.pth IGNORE_KEYS: [ ] PATCH_SIZE: 2 IN_CHANNELS: 4 HIDDEN_SIZE: 1152 DEPTH: 28 NUM_HEADS: 16 MLP_RATIO: 4.0 PRED_SIGMA: True DROP_PATH: 0.0 WINDOW_DIZE: 0 Y_CHANNELS: 4096 MAX_SEQ_LEN: 4096 QK_NORM: True USE_GRAD_CHECKPOINT: True ATTENTION_BACKEND: flash_attn # FIRST_STAGE_MODEL: NAME: AutoencoderKL EMBED_DIM: 4 PRETRAINED_MODEL: ms://iic/ACE-0.6B-1024px@models/vae/vae.bin IGNORE_KEYS: [] # ENCODER: NAME: Encoder CH: 128 OUT_CH: 3 NUM_RES_BLOCKS: 2 IN_CHANNELS: 3 ATTN_RESOLUTIONS: [ ] CH_MULT: [ 1, 2, 4, 4 ] Z_CHANNELS: 4 DOUBLE_Z: True DROPOUT: 0.0 RESAMP_WITH_CONV: True # DECODER: NAME: Decoder CH: 128 OUT_CH: 3 NUM_RES_BLOCKS: 2 IN_CHANNELS: 3 ATTN_RESOLUTIONS: [ ] CH_MULT: [ 1, 2, 4, 4 ] Z_CHANNELS: 4 DROPOUT: 0.0 RESAMP_WITH_CONV: True GIVE_PRE_END: False TANH_OUT: False # COND_STAGE_MODEL: NAME: ACETextEmbedder PRETRAINED_MODEL: ms://iic/ACE-0.6B-1024px@models/text_encoder/t5-v1_1-xxl/ TOKENIZER_PATH: ms://iic/ACE-0.6B-1024px@models/tokenizer/t5-v1_1-xxl LENGTH: 120 T5_DTYPE: bfloat16 ADDED_IDENTIFIER: [ '{image}', '{caption}', '{mask}', '{ref_image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ] CLEAN: whitespace USE_GRAD: False ================================================ FILE: config/inference_config/models/ace_0.6b_1024_refiner.yaml ================================================ NAME: ACE_0.6B_1024_REFINER IS_DEFAULT: False USE_DYNAMIC_MODEL: False DEFAULT_PARAS: PARAS: # INPUT: INPUT_IMAGE: INPUT_MASK: TASK: PROMPT: "" NEGATIVE_PROMPT: "" OUTPUT_HEIGHT: 1024 OUTPUT_WIDTH: 1024 SAMPLER: ddim SAMPLE_STEPS: 50 GUIDE_SCALE: 4.5 GUIDE_RESCALE: 0.5 SEED: -1 TAR_INDEX: 0 REFINER_SCALE: 0.2 USE_ACE: True #REFINER_PROMPT: "High Resolution, Sharpness, Clarity, Detail Enhancement, Noise Reduction, HD, 4k, Image Restoration, HDR" REFINER_PROMPT: "High Resolution, Sharpness, 
Clarity, Detail Enhancement, Noise Reduction, HD, 4k, Image Restoration, HDR" OUTPUT: LATENT: IMAGES: SEED: MODULES_PARAS: FIRST_STAGE_MODEL: FUNCTION: - NAME: encode DTYPE: float16 INPUT: ["IMAGE"] - NAME: decode DTYPE: float16 INPUT: ["LATENT"] # DIFFUSION_MODEL: FUNCTION: - NAME: forward DTYPE: float16 INPUT: ["SAMPLE_STEPS", "SAMPLE", "GUIDE_SCALE"] # COND_STAGE_MODEL: FUNCTION: - NAME: encode_list DTYPE: bfloat16 INPUT: ["PROMPT"] # MODEL: NAME: LdmACE PRETRAINED_MODEL: IGNORE_KEYS: [ ] SCALE_FACTOR: 0.18215 SIZE_FACTOR: 8 DECODER_BIAS: 0.5 DEFAULT_N_PROMPT: "" TEXT_IDENTIFIER: [ '{image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ] USE_TEXT_POS_EMBEDDINGS: True # DIFFUSION: NAME: ACEDiffusion PREDICTION_TYPE: eps MIN_SNR_GAMMA: NOISE_SCHEDULER: NAME: LinearScheduler NUM_TIMESTEPS: 1000 BETA_MIN: 0.0001 BETA_MAX: 0.02 # DIFFUSION_MODEL: NAME: DiTACE PRETRAINED_MODEL: ms://iic/ACE-0.6B-1024px@models/dit/ace_0.6b_1024px.pth IGNORE_KEYS: [ ] PATCH_SIZE: 2 IN_CHANNELS: 4 HIDDEN_SIZE: 1152 DEPTH: 28 NUM_HEADS: 16 MLP_RATIO: 4.0 PRED_SIGMA: True DROP_PATH: 0.0 WINDOW_DIZE: 0 Y_CHANNELS: 4096 MAX_SEQ_LEN: 4096 QK_NORM: True USE_GRAD_CHECKPOINT: True ATTENTION_BACKEND: flash_attn # FIRST_STAGE_MODEL: NAME: AutoencoderKL EMBED_DIM: 4 PRETRAINED_MODEL: ms://iic/ACE-0.6B-1024px@models/vae/vae.bin IGNORE_KEYS: [] # ENCODER: NAME: Encoder CH: 128 OUT_CH: 3 NUM_RES_BLOCKS: 2 IN_CHANNELS: 3 ATTN_RESOLUTIONS: [ ] CH_MULT: [ 1, 2, 4, 4 ] Z_CHANNELS: 4 DOUBLE_Z: True DROPOUT: 0.0 RESAMP_WITH_CONV: True # DECODER: NAME: Decoder CH: 128 OUT_CH: 3 NUM_RES_BLOCKS: 2 IN_CHANNELS: 3 ATTN_RESOLUTIONS: [ ] CH_MULT: [ 1, 2, 4, 4 ] Z_CHANNELS: 4 DROPOUT: 0.0 RESAMP_WITH_CONV: True GIVE_PRE_END: False TANH_OUT: False # COND_STAGE_MODEL: NAME: ACETextEmbedder PRETRAINED_MODEL: ms://iic/ACE-0.6B-1024px@models/text_encoder/t5-v1_1-xxl/ TOKENIZER_PATH: ms://iic/ACE-0.6B-1024px@models/tokenizer/t5-v1_1-xxl LENGTH: 120 T5_DTYPE: bfloat16 ADDED_IDENTIFIER: [ '{image}', '{caption}', '{mask}', '{ref_image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ] CLEAN: whitespace USE_GRAD: False ACE_PROMPT: [ "A cute cartoon rabbit holding a whiteboard that says 'ACE Refiner', standing in a sunny meadow filled with flowers, with a big smile and bright colors.", "A beautiful young woman with long flowing hair, wearing a summer dress, holding a whiteboard that reads 'ACE Refiner' while sitting on a park bench surrounded by cherry blossoms.", "An adorable cartoon cat wearing oversized glasses, holding a whiteboard that says 'ACE Refiner', perched on a stack of colorful books in a cozy library setting.", "A charming girl with pigtails, wearing a cute school uniform, enthusiastically holding a whiteboard that has 'ACE Refiner' written on it, in a bright and cheerful classroom full of educational posters.", "A friendly cartoon dog with floppy ears, sitting in front of a doghouse, proudly holding a whiteboard that says 'ACE Refiner', with a playful expression and a blue sky in the background.", "A cute anime girl with big expressive eyes, dressed in a colorful outfit, holding a whiteboard that reads 'ACE Refiner' in a fantastical landscape filled with mythical creatures.", "A vibrant cartoon fox holding a whiteboard that says 'ACE Refiner', standing on a rock by a sparkling stream, surrounded by lush greenery and butterflies.", "A stylish young woman in a business outfit, smiling as she holds a whiteboard 
written with 'ACE Refiner', in a modern office filled with plants and natural light.", "A cute cartoon unicorn holding a sparkling whiteboard that says 'ACE Refiner', frolicking in a magical forest, with rainbows and stars in the background.", "A happy family, consisting of a cute little girl and her playful puppy, holding a whiteboard that says 'ACE Refiner', together in their backyard on a sunny day." ] REFINER_MODEL: NAME: "" IS_DEFAULT: False DEFAULT_PARAS: PARAS: RESOLUTIONS: [ [ 1024, 1024 ] ] INPUT: INPUT_IMAGE: INPUT_MASK: TASK: PROMPT: "" NEGATIVE_PROMPT: "" OUTPUT_HEIGHT: 1024 OUTPUT_WIDTH: 1024 SAMPLER: flow_euler SAMPLE_STEPS: 30 GUIDE_SCALE: 3.5 GUIDE_RESCALE: OUTPUT: LATENT: IMAGES: SEED: MODULES_PARAS: FIRST_STAGE_MODEL: FUNCTION: - NAME: encode DTYPE: bfloat16 INPUT: [ "IMAGE" ] - NAME: decode DTYPE: bfloat16 INPUT: [ "LATENT" ] PARAS: SCALE_FACTOR: 1.5305 SHIFT_FACTOR: 0.0609 SIZE_FACTOR: 8 DIFFUSION_MODEL: FUNCTION: - NAME: forward DTYPE: bfloat16 INPUT: [ "SAMPLE_STEPS", "SAMPLE", "GUIDE_SCALE" ] COND_STAGE_MODEL: FUNCTION: - NAME: encode DTYPE: bfloat16 INPUT: [ "PROMPT" ] MODEL: DIFFUSION: NAME: DiffusionFluxRF PREDICTION_TYPE: raw NOISE_SCHEDULER: NAME: FlowMatchSigmaScheduler WEIGHTING_SCHEME: logit_normal SHIFT: 3.0 LOGIT_MEAN: 0.0 LOGIT_STD: 1.0 MODE_SCALE: 1.29 DIFFUSION_MODEL: NAME: FluxMR PRETRAINED_MODEL: ms://AI-ModelScope/FLUX.1-dev@flux1-dev.safetensors IN_CHANNELS: 64 OUT_CHANNELS: 64 HIDDEN_SIZE: 3072 NUM_HEADS: 24 AXES_DIM: [ 16, 56, 56 ] THETA: 10000 VEC_IN_DIM: 768 GUIDANCE_EMBED: True CONTEXT_IN_DIM: 4096 MLP_RATIO: 4.0 QKV_BIAS: True DEPTH: 19 DEPTH_SINGLE_BLOCKS: 38 USE_GRAD_CHECKPOINT: True ATTN_BACKEND: flash_attn # FIRST_STAGE_MODEL: NAME: AutoencoderKLFlux EMBED_DIM: 16 PRETRAINED_MODEL: ms://AI-ModelScope/FLUX.1-dev@ae.safetensors IGNORE_KEYS: [ ] BATCH_SIZE: 8 USE_CONV: False SCALE_FACTOR: 0.3611 SHIFT_FACTOR: 0.1159 # ENCODER: NAME: Encoder USE_CHECKPOINT: False CH: 128 OUT_CH: 3 NUM_RES_BLOCKS: 2 IN_CHANNELS: 3 ATTN_RESOLUTIONS: [ ] CH_MULT: [ 1, 2, 4, 4 ] Z_CHANNELS: 16 DOUBLE_Z: True DROPOUT: 0.0 RESAMP_WITH_CONV: True # DECODER: NAME: Decoder USE_CHECKPOINT: False CH: 128 OUT_CH: 3 NUM_RES_BLOCKS: 2 IN_CHANNELS: 3 ATTN_RESOLUTIONS: [ ] CH_MULT: [ 1, 2, 4, 4 ] Z_CHANNELS: 16 DROPOUT: 0.0 RESAMP_WITH_CONV: True GIVE_PRE_END: False TANH_OUT: False # COND_STAGE_MODEL: NAME: T5PlusClipFluxEmbedder T5_MODEL: NAME: HFEmbedder HF_MODEL_CLS: T5EncoderModel MODEL_PATH: ms://AI-ModelScope/FLUX.1-dev@text_encoder_2/ HF_TOKENIZER_CLS: T5Tokenizer TOKENIZER_PATH: ms://AI-ModelScope/FLUX.1-dev@tokenizer_2/ MAX_LENGTH: 512 OUTPUT_KEY: last_hidden_state D_TYPE: bfloat16 BATCH_INFER: False CLEAN: whitespace CLIP_MODEL: NAME: HFEmbedder HF_MODEL_CLS: CLIPTextModel MODEL_PATH: ms://AI-ModelScope/FLUX.1-dev@text_encoder/ HF_TOKENIZER_CLS: CLIPTokenizer TOKENIZER_PATH: ms://AI-ModelScope/FLUX.1-dev@tokenizer/ MAX_LENGTH: 77 OUTPUT_KEY: pooler_output D_TYPE: bfloat16 BATCH_INFER: True CLEAN: whitespace ================================================ FILE: config/inference_config/models/ace_0.6b_512.yaml ================================================ NAME: ACE_0.6B_512 IS_DEFAULT: True USE_DYNAMIC_MODEL: False DEFAULT_PARAS: PARAS: # INPUT: INPUT_IMAGE: INPUT_MASK: TASK: PROMPT: "" NEGATIVE_PROMPT: "" OUTPUT_HEIGHT: 512 OUTPUT_WIDTH: 512 SAMPLER: ddim SAMPLE_STEPS: 20 GUIDE_SCALE: 4.5 GUIDE_RESCALE: 0.5 SEED: -1 TAR_INDEX: 0 OUTPUT: LATENT: IMAGES: SEED: MODULES_PARAS: FIRST_STAGE_MODEL: FUNCTION: - NAME: encode DTYPE: float16 INPUT: ["IMAGE"] - NAME: decode 
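# Each FUNCTION entry names a module method and the dtype it should run
# under; ace_inference.py looks these up via get_function_info() and wraps
# the call in torch.autocast('cuda', enabled=(dtype != 'float32')), so VAE
# encode/decode execute in float16 while the T5 text encoder stays bfloat16.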
DTYPE: float16 INPUT: ["LATENT"] # DIFFUSION_MODEL: FUNCTION: - NAME: forward DTYPE: float16 INPUT: ["SAMPLE_STEPS", "SAMPLE", "GUIDE_SCALE"] # COND_STAGE_MODEL: FUNCTION: - NAME: encode_list DTYPE: bfloat16 INPUT: ["PROMPT"] # MODEL: NAME: LdmACE PRETRAINED_MODEL: IGNORE_KEYS: [ ] SCALE_FACTOR: 0.18215 SIZE_FACTOR: 8 DECODER_BIAS: 0.5 DEFAULT_N_PROMPT: "" TEXT_IDENTIFIER: [ '{image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ] USE_TEXT_POS_EMBEDDINGS: True # DIFFUSION: NAME: ACEDiffusion PREDICTION_TYPE: eps MIN_SNR_GAMMA: NOISE_SCHEDULER: NAME: LinearScheduler NUM_TIMESTEPS: 1000 BETA_MIN: 0.0001 BETA_MAX: 0.02 # DIFFUSION_MODEL: NAME: DiTACE PRETRAINED_MODEL: ms://iic/ACE-0.6B-512px@models/dit/ace_0.6b_512px.pth IGNORE_KEYS: [ ] PATCH_SIZE: 2 IN_CHANNELS: 4 HIDDEN_SIZE: 1152 DEPTH: 28 NUM_HEADS: 16 MLP_RATIO: 4.0 PRED_SIGMA: True DROP_PATH: 0.0 WINDOW_DIZE: 0 Y_CHANNELS: 4096 MAX_SEQ_LEN: 1024 QK_NORM: True USE_GRAD_CHECKPOINT: True ATTENTION_BACKEND: flash_attn # FIRST_STAGE_MODEL: NAME: AutoencoderKL EMBED_DIM: 4 PRETRAINED_MODEL: ms://iic/ACE-0.6B-512px@models/vae/vae.bin IGNORE_KEYS: [] # ENCODER: NAME: Encoder CH: 128 OUT_CH: 3 NUM_RES_BLOCKS: 2 IN_CHANNELS: 3 ATTN_RESOLUTIONS: [ ] CH_MULT: [ 1, 2, 4, 4 ] Z_CHANNELS: 4 DOUBLE_Z: True DROPOUT: 0.0 RESAMP_WITH_CONV: True # DECODER: NAME: Decoder CH: 128 OUT_CH: 3 NUM_RES_BLOCKS: 2 IN_CHANNELS: 3 ATTN_RESOLUTIONS: [ ] CH_MULT: [ 1, 2, 4, 4 ] Z_CHANNELS: 4 DROPOUT: 0.0 RESAMP_WITH_CONV: True GIVE_PRE_END: False TANH_OUT: False # COND_STAGE_MODEL: NAME: ACETextEmbedder PRETRAINED_MODEL: ms://iic/ACE-0.6B-512px@models/text_encoder/t5-v1_1-xxl/ TOKENIZER_PATH: ms://iic/ACE-0.6B-512px@models/tokenizer/t5-v1_1-xxl LENGTH: 120 T5_DTYPE: bfloat16 ADDED_IDENTIFIER: [ '{image}', '{caption}', '{mask}', '{ref_image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ] CLEAN: whitespace USE_GRAD: False ================================================ FILE: config/train_config/ace_0.6b_1024_train.yaml ================================================ ENV: BACKEND: nccl SEED: 2024 # SOLVER: NAME: ACESolverV1 RESUME_FROM: LOAD_MODEL_ONLY: True USE_FSDP: False SHARDING_STRATEGY: USE_AMP: True DTYPE: float16 CHANNELS_LAST: True MAX_STEPS: 500 MAX_EPOCHS: -1 NUM_FOLDS: 1 ACCU_STEP: 1 EVAL_INTERVAL: 50 RESCALE_LR: False # WORK_DIR: ./cache/exp/exp1 LOG_FILE: std_log.txt # FILE_SYSTEM: - NAME: "HuggingfaceFs" TEMP_DIR: ./cache - NAME: "ModelscopeFs" TEMP_DIR: ./cache - NAME: "LocalFs" TEMP_DIR: ./cache - NAME: "HttpFs" TEMP_DIR: ./cache # MODEL: NAME: LdmACE PRETRAINED_MODEL: IGNORE_KEYS: [ ] SCALE_FACTOR: 0.18215 SIZE_FACTOR: 8 DECODER_BIAS: 0.5 DEFAULT_N_PROMPT: USE_EMA: True EVAL_EMA: False TEXT_IDENTIFIER: [ '{image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ] USE_TEXT_POS_EMBEDDINGS: True # DIFFUSION: NAME: ACEDiffusion PREDICTION_TYPE: eps MIN_SNR_GAMMA: NOISE_SCHEDULER: NAME: LinearScheduler NUM_TIMESTEPS: 1000 BETA_MIN: 0.0001 BETA_MAX: 0.02 # DIFFUSION_MODEL: NAME: DiTACE PRETRAINED_MODEL: ms://iic/ACE-0.6B-512px@models/dit/ace_0.6b_512px.pth IGNORE_KEYS: [ ] PATCH_SIZE: 2 IN_CHANNELS: 4 HIDDEN_SIZE: 1152 DEPTH: 28 NUM_HEADS: 16 MLP_RATIO: 4.0 PRED_SIGMA: True DROP_PATH: 0.0 WINDOW_DIZE: 0 Y_CHANNELS: 4096 MAX_SEQ_LEN: 4096 QK_NORM: True USE_GRAD_CHECKPOINT: True ATTENTION_BACKEND: flash_attn # FIRST_STAGE_MODEL: NAME: AutoencoderKL EMBED_DIM: 4 
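# Note on the first-stage settings here: AutoencoderKL is the standard SD-style KL
# autoencoder. Its 4 latent channels (EMBED_DIM, matching Z_CHANNELS below and the
# DiTACE IN_CHANNELS) and 8x spatial downsample (SIZE_FACTOR) turn a 1024x1024 image
# into a 128x128x4 latent; at PATCH_SIZE 2 that is 64x64 = 4096 tokens, which is
# exactly the MAX_SEQ_LEN used by the backbone and the training data pipeline.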
PRETRAINED_MODEL: ms://iic/ACE-0.6B-1024px@models/vae/vae.bin IGNORE_KEYS: [] # ENCODER: NAME: Encoder CH: 128 OUT_CH: 3 NUM_RES_BLOCKS: 2 IN_CHANNELS: 3 ATTN_RESOLUTIONS: [ ] CH_MULT: [ 1, 2, 4, 4 ] Z_CHANNELS: 4 DOUBLE_Z: True DROPOUT: 0.0 RESAMP_WITH_CONV: True # DECODER: NAME: Decoder CH: 128 OUT_CH: 3 NUM_RES_BLOCKS: 2 IN_CHANNELS: 3 ATTN_RESOLUTIONS: [ ] CH_MULT: [ 1, 2, 4, 4 ] Z_CHANNELS: 4 DROPOUT: 0.0 RESAMP_WITH_CONV: True GIVE_PRE_END: False TANH_OUT: False # COND_STAGE_MODEL: NAME: T5EmbedderHF PRETRAINED_MODEL: ms://iic/ACE-0.6B-1024px@models/text_encoder/t5-v1_1-xxl/ TOKENIZER_PATH: ms://iic/ACE-0.6B-1024px@models/tokenizer/t5-v1_1-xxl LENGTH: 120 T5_DTYPE: bfloat16 ADDED_IDENTIFIER: [ '{image}', '{caption}', '{mask}', '{ref_image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ] CLEAN: whitespace USE_GRAD: False LOSS: NAME: ReconstructLoss LOSS_TYPE: l2 # SAMPLE_ARGS: SAMPLER: ddim SAMPLE_STEPS: 20 GUIDE_SCALE: 4.5 GUIDE_RESCALE: 0.5 # OPTIMIZER: NAME: AdamW LEARNING_RATE: 1e-5 EPS: 1e-10 WEIGHT_DECAY: 5e-4 # TRAIN_DATA: NAME: ACEDemoDataset MODE: train MS_DATASET_NAME: cache/datasets/hed_pair MS_DATASET_NAMESPACE: "" MS_DATASET_SPLIT: "train" MS_DATASET_SUBNAME: "" PROMPT_PREFIX: "" REPLACE_STYLE: False MAX_SEQ_LEN: 4096 PIN_MEMORY: True BATCH_SIZE: 1 NUM_WORKERS: 1 SAMPLER: NAME: LoopSampler # TRAIN_HOOKS: - NAME: BackwardHook PRIORITY: 0 - NAME: LogHook LOG_INTERVAL: 50 - NAME: CheckpointHook INTERVAL: 100 - NAME: ProbeDataHook PROB_INTERVAL: 100 ================================================ FILE: config/train_config/ace_0.6b_512_train.yaml ================================================ ENV: BACKEND: nccl SEED: 2024 # SOLVER: NAME: ACESolverV1 RESUME_FROM: LOAD_MODEL_ONLY: True USE_FSDP: False SHARDING_STRATEGY: USE_AMP: True DTYPE: float16 CHANNELS_LAST: True MAX_STEPS: 500 MAX_EPOCHS: -1 NUM_FOLDS: 1 ACCU_STEP: 1 EVAL_INTERVAL: 50 RESCALE_LR: False # WORK_DIR: ./cache/exp/exp1 LOG_FILE: std_log.txt # FILE_SYSTEM: - NAME: "HuggingfaceFs" TEMP_DIR: ./cache - NAME: "ModelscopeFs" TEMP_DIR: ./cache - NAME: "LocalFs" TEMP_DIR: ./cache - NAME: "HttpFs" TEMP_DIR: ./cache # MODEL: NAME: LdmACE PRETRAINED_MODEL: IGNORE_KEYS: [ ] SCALE_FACTOR: 0.18215 SIZE_FACTOR: 8 DECODER_BIAS: 0.5 DEFAULT_N_PROMPT: USE_EMA: True EVAL_EMA: False TEXT_IDENTIFIER: [ '{image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ] USE_TEXT_POS_EMBEDDINGS: True # DIFFUSION: NAME: ACEDiffusion PREDICTION_TYPE: eps MIN_SNR_GAMMA: NOISE_SCHEDULER: NAME: LinearScheduler NUM_TIMESTEPS: 1000 BETA_MIN: 0.0001 BETA_MAX: 0.02 # DIFFUSION_MODEL: NAME: DiTACE PRETRAINED_MODEL: ms://iic/ACE-0.6B-512px@models/dit/ace_0.6b_512px.pth IGNORE_KEYS: [ ] PATCH_SIZE: 2 IN_CHANNELS: 4 HIDDEN_SIZE: 1152 DEPTH: 28 NUM_HEADS: 16 MLP_RATIO: 4.0 PRED_SIGMA: True DROP_PATH: 0.0 WINDOW_DIZE: 0 Y_CHANNELS: 4096 MAX_SEQ_LEN: 1024 QK_NORM: True USE_GRAD_CHECKPOINT: True ATTENTION_BACKEND: flash_attn # FIRST_STAGE_MODEL: NAME: AutoencoderKL EMBED_DIM: 4 PRETRAINED_MODEL: ms://iic/ACE-0.6B-512px@models/vae/vae.bin IGNORE_KEYS: [] # ENCODER: NAME: Encoder CH: 128 OUT_CH: 3 NUM_RES_BLOCKS: 2 IN_CHANNELS: 3 ATTN_RESOLUTIONS: [ ] CH_MULT: [ 1, 2, 4, 4 ] Z_CHANNELS: 4 DOUBLE_Z: True DROPOUT: 0.0 RESAMP_WITH_CONV: True # DECODER: NAME: Decoder CH: 128 OUT_CH: 3 NUM_RES_BLOCKS: 2 IN_CHANNELS: 3 ATTN_RESOLUTIONS: [ ] CH_MULT: [ 1, 2, 4, 4 ] Z_CHANNELS: 4 DROPOUT: 0.0 RESAMP_WITH_CONV: True GIVE_PRE_END: False TANH_OUT: 
False # COND_STAGE_MODEL: NAME: ACETextEmbedder PRETRAINED_MODEL: ms://iic/ACE-0.6B-512px@models/text_encoder/t5-v1_1-xxl/ TOKENIZER_PATH: ms://iic/ACE-0.6B-512px@models/tokenizer/t5-v1_1-xxl LENGTH: 120 T5_DTYPE: bfloat16 ADDED_IDENTIFIER: [ '{image}', '{caption}', '{mask}', '{ref_image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ] CLEAN: whitespace USE_GRAD: False LOSS: NAME: ReconstructLoss LOSS_TYPE: l2 # SAMPLE_ARGS: SAMPLER: ddim SAMPLE_STEPS: 20 GUIDE_SCALE: 4.5 GUIDE_RESCALE: 0.5 # OPTIMIZER: NAME: AdamW LEARNING_RATE: 1e-5 EPS: 1e-10 WEIGHT_DECAY: 5e-4 # TRAIN_DATA: NAME: ACEDemoDataset MODE: train MS_DATASET_NAME: cache/datasets/hed_pair MS_DATASET_NAMESPACE: "" MS_DATASET_SPLIT: "train" MS_DATASET_SUBNAME: "" PROMPT_PREFIX: "" REPLACE_STYLE: False MAX_SEQ_LEN: 1024 PIN_MEMORY: True BATCH_SIZE: 1 NUM_WORKERS: 1 SAMPLER: NAME: LoopSampler # TRAIN_HOOKS: - NAME: BackwardHook PRIORITY: 0 - NAME: LogHook LOG_INTERVAL: 50 - NAME: CheckpointHook INTERVAL: 100 - NAME: ProbeDataHook PROB_INTERVAL: 100 ================================================ FILE: modules/__init__.py ================================================ from . import data, model, solver ================================================ FILE: modules/data/__init__.py ================================================ from . import dataset ================================================ FILE: modules/data/dataset/__init__.py ================================================ from .dataset import ACEDemoDataset ================================================ FILE: modules/data/dataset/dataset.py ================================================ # -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import io import math import os import sys from collections import defaultdict import numpy as np import torch import torchvision.transforms as T from PIL import Image from torchvision.transforms.functional import InterpolationMode from scepter.modules.data.dataset.base_dataset import BaseDataset from scepter.modules.data.dataset.registry import DATASETS from scepter.modules.transform.io import pillow_convert from scepter.modules.utils.config import dict_to_yaml from scepter.modules.utils.file_system import FS Image.MAX_IMAGE_PIXELS = None @DATASETS.register_class() class ACEDemoDataset(BaseDataset): para_dict = { 'MS_DATASET_NAME': { 'value': '', 'description': 'Modelscope dataset name.' }, 'MS_DATASET_NAMESPACE': { 'value': '', 'description': 'Modelscope dataset namespace.' }, 'MS_DATASET_SUBNAME': { 'value': '', 'description': 'Modelscope dataset subname.' }, 'MS_DATASET_SPLIT': { 'value': '', 'description': 'Modelscope dataset split set name, default is train.' }, 'MS_REMAP_KEYS': { 'value': None, 'description': 'Modelscope dataset header of list file, the default is Target:FILE; ' 'If your file is not this header, please set this field, which is a map dict.' "For example, { 'Image:FILE': 'Target:FILE' } will replace the filed Image:FILE to Target:FILE" }, 'MS_REMAP_PATH': { 'value': None, 'description': 'When modelscope dataset name is not None, that means you use the dataset from modelscope,' ' default is None. But if you want to use the datalist from modelscope and the file from ' 'local device, you can use this field to set the root path of your images. ' }, 'TRIGGER_WORDS': { 'value': '', 'description': 'The words used to describe the common features of your data, especially when you customize a ' 'tuner. 
Using these words, you can get what you want.' }, 'HIGHLIGHT_KEYWORDS': { 'value': '', 'description': 'The keywords you want to highlight in the prompt, which will be replaced by <{HIGHLIGHT_KEYWORDS}{KEYWORDS_SIGN}>.' }, 'KEYWORDS_SIGN': { 'value': '', 'description': 'The keywords sign you want to add, which is like <{HIGHLIGHT_KEYWORDS}{KEYWORDS_SIGN}>' }, } def __init__(self, cfg, logger=None): super().__init__(cfg=cfg, logger=logger) from modelscope import MsDataset from modelscope.utils.constant import DownloadMode ms_dataset_name = cfg.get('MS_DATASET_NAME', None) ms_dataset_namespace = cfg.get('MS_DATASET_NAMESPACE', None) ms_dataset_subname = cfg.get('MS_DATASET_SUBNAME', None) ms_dataset_split = cfg.get('MS_DATASET_SPLIT', 'train') ms_remap_keys = cfg.get('MS_REMAP_KEYS', None) ms_remap_path = cfg.get('MS_REMAP_PATH', None) self.max_seq_len = cfg.get('MAX_SEQ_LEN', 1024) self.max_aspect_ratio = cfg.get('MAX_ASPECT_RATIO', 4) self.d = cfg.get('DOWNSAMPLE_RATIO', 16) self.replace_style = cfg.get('REPLACE_STYLE', False) self.trigger_words = cfg.get('TRIGGER_WORDS', '') self.replace_keywords = cfg.get('HIGHLIGHT_KEYWORDS', '') self.keywords_sign = cfg.get('KEYWORDS_SIGN', '') self.add_indicator = cfg.get('ADD_INDICATOR', False) # Use modelscope dataset if not ms_dataset_name: raise ValueError( 'You must set MS_DATASET_NAME to a ModelScope dataset, or to a local dataset organized ' 'in the ModelScope format.') if FS.exists(ms_dataset_name): ms_dataset_name = FS.get_dir_to_local_dir(ms_dataset_name) self.ms_dataset_name = ms_dataset_name # ms_remap_path = ms_dataset_name try: self.data = MsDataset.load(str(ms_dataset_name), namespace=ms_dataset_namespace, subset_name=ms_dataset_subname, split=ms_dataset_split) except Exception: self.logger.info( "Load Modelscope dataset failed, retry with download_mode='force_redownload'."
) try: self.data = MsDataset.load( str(ms_dataset_name), namespace=ms_dataset_namespace, subset_name=ms_dataset_subname, split=ms_dataset_split, download_mode=DownloadMode.FORCE_REDOWNLOAD) except Exception as sec_e: raise ValueError(f'Load Modelscope dataset failed {sec_e}.') if ms_remap_keys: self.data = self.data.remap_columns(ms_remap_keys.get_dict()) if ms_remap_path: def map_func(example): return { k: os.path.join(ms_remap_path, v) if k.endswith(':FILE') else v for k, v in example.items() } self.data = self.data.ds_instance.map(map_func) self.transforms = T.Compose([ T.ToTensor(), T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) ]) def __len__(self): if self.mode == 'train': return sys.maxsize else: return len(self.data) def _get(self, index: int): current_data = self.data[index % len(self.data)] tar_image_path = current_data.get('Target:FILE', '') src_image_path = current_data.get('Source:FILE', '') style = current_data.get('Style', '') prompt = current_data.get('Prompt', current_data.get('prompt', '')) if self.replace_style and not style == '': prompt = prompt.replace(style, f'<{self.keywords_sign}>') elif not self.replace_keywords.strip() == '': prompt = prompt.replace( self.replace_keywords, '<' + self.replace_keywords + f'{self.keywords_sign}>') if not self.trigger_words == '': prompt = self.trigger_words.strip() + ' ' + prompt src_image = self.load_image(self.ms_dataset_name, src_image_path, cvt_type='RGB') tar_image = self.load_image(self.ms_dataset_name, tar_image_path, cvt_type='RGB') src_image = self.image_preprocess(src_image) tar_image = self.image_preprocess(tar_image) tar_image = self.transforms(tar_image) src_image = self.transforms(src_image) src_mask = torch.ones_like(src_image[[0]]) tar_mask = torch.ones_like(tar_image[[0]]) if self.add_indicator: if '{image}' not in prompt: prompt = '{image}, ' + prompt return { 'edit_image': [src_image], 'edit_image_mask': [src_mask], 'image': tar_image, 'image_mask': tar_mask, 'prompt': [prompt], } def load_image(self, prefix, img_path, cvt_type=None): if img_path is None or img_path == '': return None img_path = os.path.join(prefix, img_path) with FS.get_object(img_path) as image_bytes: image = Image.open(io.BytesIO(image_bytes)) if cvt_type is not None: image = pillow_convert(image, cvt_type) return image def image_preprocess(self, img, size=None, interpolation=InterpolationMode.BILINEAR): H, W = img.height, img.width if H / W > self.max_aspect_ratio: img = T.CenterCrop((self.max_aspect_ratio * W, W))(img) elif W / H > self.max_aspect_ratio: img = T.CenterCrop((H, self.max_aspect_ratio * H))(img) if size is None: # resize image for max_seq_len, while keep the aspect ratio H, W = img.height, img.width scale = min( 1.0, math.sqrt(self.max_seq_len / ((H / self.d) * (W / self.d)))) rH = int( H * scale) // self.d * self.d # ensure divisible by self.d rW = int(W * scale) // self.d * self.d else: rH, rW = size img = T.Resize((rH, rW), interpolation=interpolation, antialias=True)(img) return np.array(img, dtype=np.uint8) @staticmethod def get_config_template(): return dict_to_yaml('DATASet', __class__.__name__, ACEDemoDataset.para_dict, set_name=True) @staticmethod def collate_fn(batch): collect = defaultdict(list) for sample in batch: for k, v in sample.items(): collect[k].append(v) new_batch = dict() for k, v in collect.items(): if all([i is None for i in v]): new_batch[k] = None else: new_batch[k] = v return new_batch ================================================ FILE: modules/inference/__init__.py 
================================================ ================================================ FILE: modules/model/__init__.py ================================================ from . import backbone, embedder, diffusion, network ================================================ FILE: modules/model/backbone/__init__.py ================================================ # -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. from .ace import DiTACE ================================================ FILE: modules/model/backbone/ace.py ================================================ # -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import re from collections import OrderedDict from functools import partial import torch import torch.nn as nn from einops import rearrange from torch.nn.utils.rnn import pad_sequence from torch.utils.checkpoint import checkpoint_sequential from scepter.modules.model.base_model import BaseModel from scepter.modules.model.registry import BACKBONES from scepter.modules.utils.config import dict_to_yaml from scepter.modules.utils.file_system import FS from .layers import ( Mlp, TimestepEmbedder, PatchEmbed, DiTACEBlock, T2IFinalLayer ) from .pos_embed import rope_params @BACKBONES.register_class() class DiTACE(BaseModel): para_dict = { 'PATCH_SIZE': { 'value': 2, 'description': '' }, 'IN_CHANNELS': { 'value': 4, 'description': '' }, 'HIDDEN_SIZE': { 'value': 1152, 'description': '' }, 'DEPTH': { 'value': 28, 'description': '' }, 'NUM_HEADS': { 'value': 16, 'description': '' }, 'MLP_RATIO': { 'value': 4.0, 'description': '' }, 'PRED_SIGMA': { 'value': True, 'description': '' }, 'DROP_PATH': { 'value': 0., 'description': '' }, 'WINDOW_SIZE': { 'value': 0, 'description': '' }, 'WINDOW_BLOCK_INDEXES': { 'value': None, 'description': '' }, 'Y_CHANNELS': { 'value': 4096, 'description': '' }, 'ATTENTION_BACKEND': { 'value': None, 'description': '' }, 'QK_NORM': { 'value': True, 'description': 'Whether to use RMSNorm for query and key.', }, } para_dict.update(BaseModel.para_dict) def __init__(self, cfg, logger): super().__init__(cfg, logger=logger) self.window_block_indexes = cfg.get('WINDOW_BLOCK_INDEXES', None) if self.window_block_indexes is None: self.window_block_indexes = [] self.pred_sigma = cfg.get('PRED_SIGMA', True) self.in_channels = cfg.get('IN_CHANNELS', 4) self.out_channels = self.in_channels * 2 if self.pred_sigma else self.in_channels self.patch_size = cfg.get('PATCH_SIZE', 2) self.num_heads = cfg.get('NUM_HEADS', 16) self.hidden_size = cfg.get('HIDDEN_SIZE', 1152) self.y_channels = cfg.get('Y_CHANNELS', 4096) self.drop_path = cfg.get('DROP_PATH', 0.) 
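# The rotary table built below splits each attention head's channels across three
# axes (frame index T, height H, width W): d - 4 * (d // 6) channels for T and
# 2 * (d // 6) each for H and W, which always sums back to d. With HIDDEN_SIZE 1152
# and NUM_HEADS 16, d = 72, i.e. a 24/24/24 split.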
self.depth = cfg.get('DEPTH', 28) self.mlp_ratio = cfg.get('MLP_RATIO', 4.0) self.use_grad_checkpoint = cfg.get('USE_GRAD_CHECKPOINT', False) self.attention_backend = cfg.get('ATTENTION_BACKEND', None) self.max_seq_len = cfg.get('MAX_SEQ_LEN', 1024) self.qk_norm = cfg.get('QK_NORM', False) self.window_size = cfg.get('WINDOW_SIZE', 0) self.ignore_keys = cfg.get('IGNORE_KEYS', []) assert (self.hidden_size % self.num_heads ) == 0 and (self.hidden_size // self.num_heads) % 2 == 0 d = self.hidden_size // self.num_heads self.freqs = torch.cat( [ rope_params(self.max_seq_len, d - 4 * (d // 6)), # T (~1/3) rope_params(self.max_seq_len, 2 * (d // 6)), # H (~1/3) rope_params(self.max_seq_len, 2 * (d // 6)) # W (~1/3) ], dim=1) # init embedder self.x_embedder = PatchEmbed(self.patch_size, self.in_channels + 1, self.hidden_size, bias=True, flatten=False) self.t_embedder = TimestepEmbedder(self.hidden_size) self.y_embedder = Mlp(in_features=self.y_channels, hidden_features=self.hidden_size, out_features=self.hidden_size, act_layer=lambda: nn.GELU(approximate='tanh'), drop=0) self.t_block = nn.Sequential( nn.SiLU(), nn.Linear(self.hidden_size, 6 * self.hidden_size, bias=True)) # init blocks drop_path = [ x.item() for x in torch.linspace(0, self.drop_path, self.depth) ] self.blocks = nn.ModuleList([ DiTACEBlock(self.hidden_size, self.num_heads, mlp_ratio=self.mlp_ratio, drop_path=drop_path[i], window_size=self.window_size if i in self.window_block_indexes else 0, backend=self.attention_backend, use_condition=True, qk_norm=self.qk_norm) for i in range(self.depth) ]) self.final_layer = T2IFinalLayer(self.hidden_size, self.patch_size, self.out_channels) self.initialize_weights() def load_pretrained_model(self, pretrained_model): if pretrained_model: with FS.get_from(pretrained_model, wait_finish=True) as local_path: model = torch.load(local_path, map_location='cpu') if 'state_dict' in model: model = model['state_dict'] new_ckpt = OrderedDict() for k, v in model.items(): if self.ignore_keys is not None: if (isinstance(self.ignore_keys, str) and re.match(self.ignore_keys, k)) or \ (isinstance(self.ignore_keys, list) and k in self.ignore_keys): continue k = k.replace('.cross_attn.q_linear.', '.cross_attn.q.') k = k.replace('.cross_attn.proj.', '.cross_attn.o.').replace( '.attn.proj.', '.attn.o.') if '.cross_attn.kv_linear.' in k: k_p, v_p = torch.split(v, v.shape[0] // 2) new_ckpt[k.replace('.cross_attn.kv_linear.', '.cross_attn.k.')] = k_p new_ckpt[k.replace('.cross_attn.kv_linear.', '.cross_attn.v.')] = v_p elif '.attn.qkv.' in k: q_p, k_p, v_p = torch.split(v, v.shape[0] // 3) new_ckpt[k.replace('.attn.qkv.', '.attn.q.')] = q_p new_ckpt[k.replace('.attn.qkv.', '.attn.k.')] = k_p new_ckpt[k.replace('.attn.qkv.', '.attn.v.')] = v_p elif 'y_embedder.y_proj.'
in k: new_ckpt[k.replace('y_embedder.y_proj.', 'y_embedder.')] = v elif k in ('x_embedder.proj.weight'): model_p = self.state_dict()[k] if v.shape != model_p.shape: model_p.zero_() model_p[:, :4, :, :].copy_(v) new_ckpt[k] = torch.nn.parameter.Parameter(model_p) else: new_ckpt[k] = v elif k in ('x_embedder.proj.bias'): new_ckpt[k] = v else: new_ckpt[k] = v missing, unexpected = self.load_state_dict(new_ckpt, strict=False) print( f'Restored from {pretrained_model} with {len(missing)} missing and {len(unexpected)} unexpected keys' ) if len(missing) > 0: print(f'Missing Keys:\n {missing}') if len(unexpected) > 0: print(f'\nUnexpected Keys:\n {unexpected}') def forward(self, x, t=None, cond=dict(), mask=None, text_position_embeddings=None, gc_seg=-1, **kwargs): if self.freqs.device != x.device: self.freqs = self.freqs.to(x.device) if isinstance(cond, dict): context = cond.get('crossattn', None) else: context = cond if text_position_embeddings is not None: # default use the text_position_embeddings in state_dict # if state_dict doesn't including this key, use the arg: text_position_embeddings proj_position_embeddings = self.y_embedder( text_position_embeddings) else: proj_position_embeddings = None ctx_batch, txt_lens = [], [] if mask is not None and isinstance(mask, list): for ctx, ctx_mask in zip(context, mask): for frame_id, one_ctx in enumerate(zip(ctx, ctx_mask)): u, m = one_ctx t_len = m.flatten().sum() # l u = u[:t_len] u = self.y_embedder(u) if frame_id == 0: u = u + proj_position_embeddings[ len(ctx) - 1] if proj_position_embeddings is not None else u else: u = u + proj_position_embeddings[ frame_id - 1] if proj_position_embeddings is not None else u ctx_batch.append(u) txt_lens.append(t_len) else: raise TypeError y = torch.cat(ctx_batch, dim=0) txt_lens = torch.LongTensor(txt_lens).to(x.device, non_blocking=True) batch_frames = [] for u, shape, m in zip(x, cond['x_shapes'], cond['x_mask']): u = u[:, :shape[0] * shape[1]].view(-1, shape[0], shape[1]) m = torch.ones_like(u[[0], :, :]) if m is None else m.squeeze(0) batch_frames.append([torch.cat([u, m], dim=0).unsqueeze(0)]) if 'edit' in cond: for i, (edit, edit_mask) in enumerate( zip(cond['edit'], cond['edit_mask'])): if edit is None: continue for u, m in zip(edit, edit_mask): u = u.squeeze(0) m = torch.ones_like( u[[0], :, :]) if m is None else m.squeeze(0) batch_frames[i].append( torch.cat([u, m], dim=0).unsqueeze(0)) patch_batch, shape_batch, self_x_len, cross_x_len = [], [], [], [] for frames in batch_frames: patches, patch_shapes = [], [] self_x_len.append(0) for frame_id, u in enumerate(frames): u = self.x_embedder(u) h, w = u.size(2), u.size(3) u = rearrange(u, '1 c h w -> (h w) c') if frame_id == 0: u = u + proj_position_embeddings[ len(frames) - 1] if proj_position_embeddings is not None else u else: u = u + proj_position_embeddings[ frame_id - 1] if proj_position_embeddings is not None else u patches.append(u) patch_shapes.append([h, w]) cross_x_len.append(h * w) # b*s, 1 self_x_len[-1] += h * w # b, 1 # u = torch.cat(patches, dim=0) patch_batch.extend(patches) shape_batch.append( torch.LongTensor(patch_shapes).to(x.device, non_blocking=True)) # repeat t to align with x t = torch.cat([t[i].repeat(l) for i, l in enumerate(self_x_len)]) self_x_len, cross_x_len = (torch.LongTensor(self_x_len).to( x.device, non_blocking=True), torch.LongTensor(cross_x_len).to( x.device, non_blocking=True)) # x = pad_sequence(tuple(patch_batch), batch_first=True) # b, s*max(cl), c x = torch.cat(patch_batch, dim=0) x_shapes = 
pad_sequence(tuple(shape_batch), batch_first=True) # b, max(len(frames)), 2 t = self.t_embedder(t) # (N, D) t0 = self.t_block(t) # y = self.y_embedder(context) kwargs = dict(y=y, t=t0, x_shapes=x_shapes, self_x_len=self_x_len, cross_x_len=cross_x_len, freqs=self.freqs, txt_lens=txt_lens) if self.use_grad_checkpoint and gc_seg >= 0: x = checkpoint_sequential( functions=[partial(block, **kwargs) for block in self.blocks], segments=gc_seg if gc_seg > 0 else len(self.blocks), input=x, use_reentrant=False) else: for block in self.blocks: x = block(x, **kwargs) x = self.final_layer(x, t) # b*s*n, d outs, cur_length = [], 0 p = self.patch_size for seq_length, shape in zip(self_x_len, shape_batch): x_i = x[cur_length:cur_length + seq_length] h, w = shape[0].tolist() u = x_i[:h * w].view(h, w, p, p, -1) u = rearrange(u, 'h w p q c -> (h p w q) c' ) # dump into sequence for following tensor ops cur_length = cur_length + seq_length outs.append(u) x = pad_sequence(tuple(outs), batch_first=True).permute(0, 2, 1) if self.pred_sigma: return x.chunk(2, dim=1)[0] else: return x def initialize_weights(self): # Initialize transformer layers: def _basic_init(module): if isinstance(module, nn.Linear): torch.nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.constant_(module.bias, 0) self.apply(_basic_init) # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): w = self.x_embedder.proj.weight.data nn.init.xavier_uniform_(w.view([w.shape[0], -1])) # Initialize timestep embedding MLP: nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) nn.init.normal_(self.t_block[1].weight, std=0.02) # Initialize caption embedding MLP: if hasattr(self, 'y_embedder'): nn.init.normal_(self.y_embedder.fc1.weight, std=0.02) nn.init.normal_(self.y_embedder.fc2.weight, std=0.02) # Zero-out adaLN modulation layers for block in self.blocks: nn.init.constant_(block.cross_attn.o.weight, 0) nn.init.constant_(block.cross_attn.o.bias, 0) # Zero-out output layers: nn.init.constant_(self.final_layer.linear.weight, 0) nn.init.constant_(self.final_layer.linear.bias, 0) @property def dtype(self): return next(self.parameters()).dtype @staticmethod def get_config_template(): return dict_to_yaml('BACKBONE', __class__.__name__, DiTACE.para_dict, set_name=True) ================================================ FILE: modules/model/backbone/layers.py ================================================ # -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import math import warnings import torch import torch.nn as nn from .pos_embed import rope_apply_multires as rope_apply try: from flash_attn import (flash_attn_varlen_func) FLASHATTN_IS_AVAILABLE = True except ImportError as e: FLASHATTN_IS_AVAILABLE = False flash_attn_varlen_func = None warnings.warn(f'{e}') __all__ = [ "drop_path", "modulate", "PatchEmbed", "DropPath", "RMSNorm", "Mlp", "TimestepEmbedder", "DiTACEBlock", "MultiHeadAttention", "T2IFinalLayer", ] def drop_path(x, drop_prob: float = 0., training: bool = False): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
I've opted for changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument. """ if drop_prob == 0. or not training: return x keep_prob = 1 - drop_prob shape = (x.shape[0], ) + (1, ) * ( x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets random_tensor = keep_prob + torch.rand( shape, dtype=x.dtype, device=x.device) random_tensor.floor_() # binarize output = x.div(keep_prob) * random_tensor return output def modulate(x, shift, scale, unsqueeze=False): if unsqueeze: return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) else: return x * (1 + scale) + shift class PatchEmbed(nn.Module): """ 2D Image to Patch Embedding """ def __init__( self, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, bias=True, ): super().__init__() self.flatten = flatten self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() def forward(self, x): x = self.proj(x) if self.flatten: x = x.flatten(2).transpose(1, 2) # BCHW -> BNC x = self.norm(x) return x class DropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). """ def __init__(self, drop_prob=None): super(DropPath, self).__init__() self.drop_prob = drop_prob def forward(self, x): return drop_path(x, self.drop_prob, self.training) class RMSNorm(nn.Module): def __init__(self, dim, eps=1e-6): super().__init__() self.dim = dim self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) def forward(self, x): return self._norm(x.float()).type_as(x) * self.weight def _norm(self, x): return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) class Mlp(nn.Module): """ MLP as used in Vision Transformer, MLP-Mixer and related networks """ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x class TimestepEmbedder(nn.Module): """ Embeds scalar timesteps into vector representations. """ def __init__(self, hidden_size, frequency_embedding_size=256): super().__init__() self.mlp = nn.Sequential( nn.Linear(frequency_embedding_size, hidden_size, bias=True), nn.SiLU(), nn.Linear(hidden_size, hidden_size, bias=True), ) self.frequency_embedding_size = frequency_embedding_size @staticmethod def timestep_embedding(t, dim, max_period=10000): """ Create sinusoidal timestep embeddings. :param t: a 1-D Tensor of N indices, one per batch element. These may be fractional. :param dim: the dimension of the output. :param max_period: controls the minimum frequency of the embeddings. :return: an (N, D) Tensor of positional embeddings. 
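For example, with dim=4 and max_period=10000: half=2 and freqs = [1.0, 0.01], so t=5 gives args = [5.0, 0.05] and embedding = [cos 5.0, cos 0.05, sin 5.0, sin 0.05].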
""" # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py half = dim // 2 freqs = torch.exp( -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device) args = t[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: embedding = torch.cat( [embedding, torch.zeros_like(embedding[:, :1])], dim=-1) return embedding def forward(self, t): t_freq = self.timestep_embedding(t, self.frequency_embedding_size) t_emb = self.mlp(t_freq) return t_emb class DiTACEBlock(nn.Module): def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., window_size=0, backend=None, use_condition=True, qk_norm=False, **block_kwargs): super().__init__() self.hidden_size = hidden_size self.use_condition = use_condition self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.attn = MultiHeadAttention(hidden_size, num_heads=num_heads, qkv_bias=True, backend=backend, qk_norm=qk_norm, **block_kwargs) if self.use_condition: self.cross_attn = MultiHeadAttention( hidden_size, context_dim=hidden_size, num_heads=num_heads, qkv_bias=True, backend=backend, qk_norm=qk_norm, **block_kwargs) self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) # to be compatible with lower version pytorch approx_gelu = lambda: nn.GELU(approximate='tanh') self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0) self.drop_path = DropPath( drop_path) if drop_path > 0. else nn.Identity() self.window_size = window_size self.scale_shift_table = nn.Parameter( torch.randn(6, hidden_size) / hidden_size**0.5) def forward(self, x, y, t, **kwargs): B = x.size(0) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( shift_msa.squeeze(1), scale_msa.squeeze(1), gate_msa.squeeze(1), shift_mlp.squeeze(1), scale_mlp.squeeze(1), gate_mlp.squeeze(1)) x = x + self.drop_path(gate_msa * self.attn( modulate(self.norm1(x), shift_msa, scale_msa, unsqueeze=False), ** kwargs)) if self.use_condition: x = x + self.cross_attn(x, context=y, **kwargs) x = x + self.drop_path(gate_mlp * self.mlp( modulate(self.norm2(x), shift_mlp, scale_mlp, unsqueeze=False))) return x class MultiHeadAttention(nn.Module): def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None, attn_drop=0.0, qkv_bias=False, dropout=0.0, backend=None, qk_norm=False, eps=1e-6, **block_kwargs): super().__init__() # consider head_dim first, then num_heads num_heads = dim // head_dim if head_dim else num_heads head_dim = dim // num_heads assert num_heads * head_dim == dim context_dim = context_dim or dim self.dim = dim self.context_dim = context_dim self.num_heads = num_heads self.head_dim = head_dim self.scale = math.pow(head_dim, -0.25) # layers self.q = nn.Linear(dim, dim, bias=qkv_bias) self.k = nn.Linear(context_dim, dim, bias=qkv_bias) self.v = nn.Linear(context_dim, dim, bias=qkv_bias) self.o = nn.Linear(dim, dim) self.norm_q = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity() self.norm_k = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity() self.dropout = nn.Dropout(dropout) self.attention_op = None self.attn_drop = nn.Dropout(attn_drop) self.backend = backend assert self.backend in ('flash_attn', 'xformer_attn', 'pytorch_attn', None) if FLASHATTN_IS_AVAILABLE and self.backend in ('flash_attn', None): self.backend = 
'flash_attn' self.softmax_scale = block_kwargs.get('softmax_scale', None) self.causal = block_kwargs.get('causal', False) self.window_size = block_kwargs.get('window_size', (-1, -1)) self.deterministic = block_kwargs.get('deterministic', False) else: raise NotImplementedError def flash_attn(self, x, context=None, **kwargs): ''' The implementation will be very slow when mask is not None, because we need rearange the x/context features according to mask. Args: x: context: mask: **kwargs: Returns: x ''' dtype = kwargs.get('dtype', torch.float16) def half(x): return x if x.dtype in [torch.float16, torch.bfloat16 ] else x.to(dtype) x_shapes = kwargs['x_shapes'] freqs = kwargs['freqs'] self_x_len = kwargs['self_x_len'] cross_x_len = kwargs['cross_x_len'] txt_lens = kwargs['txt_lens'] n, d = self.num_heads, self.head_dim if context is None: # self-attn q = self.norm_q(self.q(x)).view(-1, n, d) k = self.norm_q(self.k(x)).view(-1, n, d) v = self.v(x).view(-1, n, d) q = rope_apply(q, self_x_len, x_shapes, freqs, pad=False) k = rope_apply(k, self_x_len, x_shapes, freqs, pad=False) q_lens = k_lens = self_x_len else: # cross-attn q = self.norm_q(self.q(x)).view(-1, n, d) k = self.norm_q(self.k(context)).view(-1, n, d) v = self.v(context).view(-1, n, d) q_lens = cross_x_len k_lens = txt_lens cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32) cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32) max_seqlen_q = q_lens.max() max_seqlen_k = k_lens.max() out_dtype = q.dtype q, k, v = half(q), half(k), half(v) x = flash_attn_varlen_func(q, k, v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, dropout_p=self.attn_drop.p, softmax_scale=self.softmax_scale, causal=self.causal, window_size=self.window_size, deterministic=self.deterministic) x = x.type(out_dtype) x = x.reshape(-1, n * d) x = self.o(x) x = self.dropout(x) return x def forward(self, x, context=None, **kwargs): x = getattr(self, self.backend)(x, context=context, **kwargs) return x class T2IFinalLayer(nn.Module): """ The final layer of PixArt. 
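It modulates the normalized tokens with a learned shift/scale conditioned on the timestep embedding t (via scale_shift_table), then projects each token to patch_size * patch_size * out_channels values, which DiTACE.forward later unfolds back into latent patches.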
""" def __init__(self, hidden_size, patch_size, out_channels): super().__init__() self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) self.scale_shift_table = nn.Parameter( torch.randn(2, hidden_size) / hidden_size**0.5) self.out_channels = out_channels def forward(self, x, t): shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1) shift, scale = shift.squeeze(1), scale.squeeze(1) x = modulate(self.norm_final(x), shift, scale) x = self.linear(x) return x ================================================ FILE: modules/model/backbone/pos_embed.py ================================================ import numpy as np from einops import rearrange import torch import torch.cuda.amp as amp import torch.nn.functional as F from torch.nn.utils.rnn import pad_sequence def frame_pad(x, seq_len, shapes): max_h, max_w = np.max(shapes, 0) frames = [] cur_len = 0 for h, w in shapes: frame_len = h * w frames.append( F.pad( x[cur_len:cur_len + frame_len].view(h, w, -1), (0, 0, 0, max_w - w, 0, max_h - h)) # .view(max_h * max_w, -1) ) cur_len += frame_len if cur_len >= seq_len: break return torch.stack(frames) def frame_unpad(x, shapes): max_h, max_w = np.max(shapes, 0) x = rearrange(x, '(b h w) n c -> b h w n c', h=max_h, w=max_w) frames = [] for i, (h, w) in enumerate(shapes): if i >= len(x): break frames.append(rearrange(x[i, :h, :w], 'h w n c -> (h w) n c')) return torch.concat(frames) @amp.autocast(enabled=False) def rope_apply_multires(x, x_lens, x_shapes, freqs, pad=True): """ x: [B*L, N, C]. x_lens: [B]. x_shapes: [B, F, 2]. freqs: [M, C // 2]. """ n, c = x.size(1), x.size(2) // 2 # split freqs freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) # loop over samples output = [] st = 0 for i, (seq_len, shapes) in enumerate(zip(x_lens.tolist(), x_shapes.tolist())): x_i = frame_pad(x[st:st + seq_len], seq_len, shapes) # f, h, w, c f, h, w = x_i.shape[:3] pad_seq_len = f * h * w # precompute multipliers x_i = torch.view_as_complex( x_i.to(torch.float64).reshape(pad_seq_len, n, -1, 2)) freqs_i = torch.cat([ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) ], dim=-1).reshape(pad_seq_len, 1, -1) # apply rotary embedding x_i = torch.view_as_real(x_i * freqs_i).flatten(2).type_as(x) x_i = frame_unpad(x_i, shapes) # append to collection output.append(x_i) st += seq_len return pad_sequence(output) if pad else torch.concat(output) @amp.autocast(enabled=False) def rope_params(max_seq_len, dim, theta=10000): """ Precompute the frequency tensor for complex exponentials. """ assert dim % 2 == 0 freqs = torch.outer( torch.arange(max_seq_len), 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim))) freqs = torch.polar(torch.ones_like(freqs), freqs) return freqs ================================================ FILE: modules/model/diffusion/__init__.py ================================================ # -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. from .diffusions import ACEDiffusion from .samplers import DDIMSampler from .schedules import LinearScheduler ================================================ FILE: modules/model/diffusion/diffusions.py ================================================ # -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
import math import os from collections import OrderedDict import torch from tqdm import trange from scepter.modules.model.registry import (DIFFUSION_SAMPLERS, DIFFUSIONS, NOISE_SCHEDULERS) from scepter.modules.utils.config import Config, dict_to_yaml from scepter.modules.utils.distribute import we from scepter.modules.utils.file_system import FS @DIFFUSIONS.register_class() class ACEDiffusion(object): para_dict = { 'NOISE_SCHEDULER': {}, 'SAMPLER_SCHEDULER': {}, 'PREDICTION_TYPE': { 'value': 'eps', 'description': 'The type of prediction to use for the loss function.' } } def __init__(self, cfg, logger=None): super(ACEDiffusion, self).__init__() self.logger = logger self.cfg = cfg self.init_params() def init_params(self): self.prediction_type = self.cfg.get('PREDICTION_TYPE', 'eps') self.noise_scheduler = NOISE_SCHEDULERS.build(self.cfg.NOISE_SCHEDULER, logger=self.logger) self.sampler_scheduler = NOISE_SCHEDULERS.build(self.cfg.get( 'SAMPLER_SCHEDULER', self.cfg.NOISE_SCHEDULER), logger=self.logger) self.num_timesteps = self.noise_scheduler.num_timesteps if self.cfg.have('WORK_DIR') and we.rank == 0: schedule_visualization = os.path.join(self.cfg.WORK_DIR, 'noise_schedule.png') with FS.put_to(schedule_visualization) as local_path: self.noise_scheduler.plot_noise_sampling_map(local_path) schedule_visualization = os.path.join(self.cfg.WORK_DIR, 'sampler_schedule.png') with FS.put_to(schedule_visualization) as local_path: self.sampler_scheduler.plot_noise_sampling_map(local_path) def sample(self, noise, model, model_kwargs={}, steps=20, sampler=None, use_dynamic_cfg=False, guide_scale=None, guide_rescale=None, show_progress=False, return_intermediate=None, intermediate_callback=None, reverse_scale=-1., x=None, **kwargs): assert isinstance(steps, (int, torch.LongTensor)) assert return_intermediate in (None, 'x_0', 'x_t') assert isinstance(sampler, (str, dict, Config)) intermediates = [] def callback_fn(x_t, t, sigma=None, alpha_bar=None): timestamp = t t = t.repeat(len(x_t)).round().long().to(x_t.device) sigma = sigma.repeat(len(x_t), *([1] * (len(sigma.shape) - 1))) alpha_bar = alpha_bar.repeat(len(x_t), *([1] * (len(alpha_bar.shape) - 1))) if guide_scale is None or guide_scale == 1.0: out = model(x=x_t, t=t, **model_kwargs) else: if use_dynamic_cfg: guidance_scale = 1 + guide_scale * ( (1 - math.cos(math.pi * ( (steps - timestamp.item()) / steps)**5.0)) / 2) else: guidance_scale = guide_scale y_out = model(x=x_t, t=t, **model_kwargs[0]) u_out = model(x=x_t, t=t, **model_kwargs[1]) out = u_out + guidance_scale * (y_out - u_out) if guide_rescale is not None and guide_rescale > 0.0: ratio = ( y_out.flatten(1).std(dim=1) / (out.flatten(1).std(dim=1) + 1e-12)).view((-1, ) + (1, ) * (y_out.ndim - 1)) out *= guide_rescale * ratio + (1 - guide_rescale) * 1.0 if self.prediction_type == 'x0': x0 = out elif self.prediction_type == 'eps': x0 = (x_t - sigma * out) / alpha_bar elif self.prediction_type == 'v': x0 = alpha_bar * x_t - sigma * out else: raise NotImplementedError( f'prediction_type {self.prediction_type} not implemented') return x0 sampler_ins = self.get_sampler(sampler) sampler_output = sampler_ins.preprare_sampler( noise, x=x, steps=steps, reverse_scale=reverse_scale, prediction_type=self.prediction_type, scheduler_ins=self.sampler_scheduler, callback_fn=callback_fn) pbar = trange(sampler_output.steps, disable=not show_progress) for _ in pbar: pbar.set_description(sampler_output.msg) sampler_output = sampler_ins.step(sampler_output) if return_intermediate == 'x_0': intermediates.append(sampler_output.x_0) elif return_intermediate == 'x_t': intermediates.append(sampler_output.x_t) if intermediate_callback is not None: intermediate_callback(intermediates[-1]) return (sampler_output.x_0, intermediates ) if return_intermediate is not None else sampler_output.x_0 def loss(self, x_0, model, model_kwargs={}, reduction='mean', noise=None, **kwargs): # use noise scheduler to add noise if noise is None: noise = torch.randn_like(x_0) schedule_output = self.noise_scheduler.add_noise(x_0, noise, **kwargs) x_t, t, sigma, alpha_bar = schedule_output.x_t, schedule_output.t, schedule_output.sigma, schedule_output.alpha_bar out = model(x=x_t, t=t, **model_kwargs) # mse loss target = { 'eps': noise, 'x0': x_0, 'v': alpha_bar * noise - sigma * x_0 }[self.prediction_type] loss = (out - target).pow(2) if reduction == 'mean': loss = loss.flatten(1).mean(dim=1) return loss def get_sampler(self, sampler): if isinstance(sampler, str): if sampler not in DIFFUSION_SAMPLERS.class_map: if self.logger is not None: self.logger.info( f'{sampler} not in the defined samplers list {DIFFUSION_SAMPLERS.class_map.keys()}' ) else: print( f'{sampler} not in the defined samplers list {DIFFUSION_SAMPLERS.class_map.keys()}' ) return None sampler_cfg = Config(cfg_dict={'NAME': sampler}, load=False) sampler_ins = DIFFUSION_SAMPLERS.build(sampler_cfg, logger=self.logger) elif isinstance(sampler, (Config, dict, OrderedDict)): if isinstance(sampler, (dict, OrderedDict)): sampler = Config( cfg_dict={k.upper(): v for k, v in dict(sampler).items()}, load=False) sampler_ins = DIFFUSION_SAMPLERS.build(sampler, logger=self.logger) else: raise NotImplementedError return sampler_ins def __repr__(self) -> str: return f'{self.__class__.__name__}' + ' ' + super().__repr__() @staticmethod def get_config_template(): return dict_to_yaml('DIFFUSIONS', __class__.__name__, ACEDiffusion.para_dict, set_name=True) ================================================ FILE: modules/model/diffusion/samplers.py ================================================ # -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import torch from scepter.modules.model.registry import DIFFUSION_SAMPLERS from scepter.modules.model.diffusion.samplers import BaseDiffusionSampler def _i(tensor, t, x): """ Index tensor using t and format the output according to x. """ shape = (x.size(0), ) + (1, ) * (x.ndim - 1) if isinstance(t, torch.Tensor): t = t.to(tensor.device) return tensor[t].view(shape).to(x.device) @DIFFUSION_SAMPLERS.register_class('ddim') class DDIMSampler(BaseDiffusionSampler): def init_params(self): super().init_params() self.eta = self.cfg.get('ETA', 0.) self.discretization_type = self.cfg.get('DISCRETIZATION_TYPE', 'trailing') def preprare_sampler(self, noise, x=None, steps=20, reverse_scale=-1., scheduler_ins=None, prediction_type='', sigmas=None, betas=None, alphas=None, alphas_bar=None, callback_fn=None, **kwargs): output = super().preprare_sampler(noise, x=x, steps=steps, reverse_scale=reverse_scale, scheduler_ins=scheduler_ins, prediction_type=prediction_type, sigmas=sigmas, betas=betas, alphas=alphas, alphas_bar=alphas_bar, callback_fn=callback_fn, **kwargs) sigmas = output.sigmas sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) sigmas_vp = (sigmas**2 / (1 + sigmas**2))**0.5 sigmas_vp[sigmas == float('inf')] = 1.
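# sigmas_vp = sigma / sqrt(1 + sigma**2) maps the sigma grid into variance-preserving
# form, so (1 - sigmas_vp**2)**0.5 plays the role of alpha_bar in the DDIM update in
# step(); the trailing zero appended to sigmas makes sigmas_vp[step + 1] valid on the
# final iteration, where the returned x_t collapses to the noise-free x_0.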
output.add_custom_field('sigmas_vp', sigmas_vp) output.steps += 1 return output def step(self, sampler_output): x_t = sampler_output.x_t step = sampler_output.step t = sampler_output.ts[step] sigmas_vp = sampler_output.sigmas_vp.to(x_t.device) alpha_bar_init = _i(sampler_output.alphas_bar_init, step, x_t[:1]) sigma_init = _i(sampler_output.sigmas_init, step, x_t[:1]) x = sampler_output.callback_fn(x_t, t, sigma_init, alpha_bar_init) noise_factor = self.eta * (sigmas_vp[step + 1]**2 / sigmas_vp[step]**2 * (1 - (1 - sigmas_vp[step]**2) / (1 - sigmas_vp[step + 1]**2))) d = (x_t - (1 - sigmas_vp[step]**2)**0.5 * x) / sigmas_vp[step] x = (1 - sigmas_vp[step + 1] ** 2) ** 0.5 * x + \ (sigmas_vp[step + 1] ** 2 - noise_factor ** 2) ** 0.5 * d sampler_output.x_0 = x if sigmas_vp[step + 1] > 0: x += noise_factor * torch.randn_like(x) sampler_output.x_t = x sampler_output.step += 1 sampler_output.msg = f'step {step}' return sampler_output ================================================ FILE: modules/model/diffusion/schedules.py ================================================ # -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import torch from dataclasses import dataclass, field from scepter.modules.model.registry import NOISE_SCHEDULERS from scepter.modules.model.diffusion.schedules import BaseNoiseScheduler from scepter.modules.model.diffusion.util import _i @dataclass class ScheduleOutput(object): x_t: torch.Tensor x_0: torch.Tensor t: torch.Tensor sigma: torch.Tensor alpha_bar: torch.Tensor custom_fields: dict = field(default_factory=dict) def add_custom_field(self, key: str, value) -> None: self.__setattr__(key, value) @NOISE_SCHEDULERS.register_class() class LinearScheduler(BaseNoiseScheduler): para_dict = {} def init_params(self): super().init_params() self.beta_min = self.cfg.get('BETA_MIN', 0.00085) self.beta_max = self.cfg.get('BETA_MAX', 0.012) def betas_to_sigmas(self, betas): return torch.sqrt(1 - torch.cumprod(1 - betas, dim=0)) def get_schedule(self): betas = torch.linspace(self.beta_min, self.beta_max, self.num_timesteps, dtype=torch.float32) sigmas = self.betas_to_sigmas(betas) self._sigmas = sigmas self._betas = betas self._alphas = torch.sqrt(1 - betas**2) self._alphas_bar = torch.sqrt(1 - sigmas**2) self._timesteps = torch.arange(len(sigmas), dtype=torch.float32) def add_noise(self, x_0, noise=None, t=None, **kwargs): if t is None: t = torch.randint(0, self.num_timesteps, (x_0.shape[0], ), device=x_0.device).long() # index the cumulative alpha_bar (not the per-step alphas) so that alpha_bar**2 + sigma**2 == 1, as assumed by the v-target in ACEDiffusion.loss and by the DDIM update alpha_bar = _i(self.alphas_bar, t, x_0) sigma = _i(self.sigmas, t, x_0) x_t = alpha_bar * x_0 + sigma * noise return ScheduleOutput(x_0=x_0, x_t=x_t, t=t, alpha_bar=alpha_bar, sigma=sigma) ================================================ FILE: modules/model/embedder/__init__.py ================================================ from .embedder import ACETextEmbedder ================================================ FILE: modules/model/embedder/embedder.py ================================================ # -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates.
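# Usage sketch (hypothetical, not shipped code): ACETextEmbedder below is normally
# built from the COND_STAGE_MODEL block of the inference configs and queried via
# encode()/encode_list(). Assuming `cfg` is a scepter Config loaded from one of the
# YAML files above:
#
#   from scepter.modules.model.registry import EMBEDDERS
#   embedder = EMBEDDERS.build(cfg.MODEL.COND_STAGE_MODEL, logger=None)
#   cont, mask = embedder.encode('{image} make the sky pink', return_mask=True)
#   # cont: [1, LENGTH, 4096] T5 features; mask: [1, LENGTH] attention mask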
import warnings from contextlib import nullcontext import torch import torch.nn.functional as F import torch.utils.dlpack from scepter.modules.model.embedder.base_embedder import BaseEmbedder from scepter.modules.model.registry import EMBEDDERS from scepter.modules.model.tokenizer.tokenizer_component import ( basic_clean, canonicalize, heavy_clean, whitespace_clean) from scepter.modules.utils.config import dict_to_yaml from scepter.modules.utils.distribute import we from scepter.modules.utils.file_system import FS try: from transformers import AutoTokenizer, T5EncoderModel except Exception as e: warnings.warn( f'Failed to import transformers, please resolve this problem: {e}') @EMBEDDERS.register_class() class ACETextEmbedder(BaseEmbedder): """ Uses a T5-family transformer encoder (e.g. umt5) for text. """ para_dict = { 'PRETRAINED_MODEL': { 'value': 'google/umt5-small', 'description': 'Pretrained Model for umt5, modelcard path or local path.' }, 'TOKENIZER_PATH': { 'value': 'google/umt5-small', 'description': 'Tokenizer Path for umt5, modelcard path or local path.' }, 'FREEZE': { 'value': True, 'description': '' }, 'USE_GRAD': { 'value': False, 'description': 'Compute grad or not.' }, 'CLEAN': { 'value': 'whitespace', 'description': 'Set the clean strategy for the tokenizer, used when TOKENIZER_PATH is not None.' }, 'LAYER': { 'value': 'last', 'description': '' }, 'LEGACY': { 'value': True, 'description': 'Whether to use the legacy returned feature or not, default True.' } } def __init__(self, cfg, logger=None): super().__init__(cfg, logger=logger) pretrained_path = cfg.get('PRETRAINED_MODEL', None) self.t5_dtype = cfg.get('T5_DTYPE', 'float32') assert pretrained_path with FS.get_dir_to_local_dir(pretrained_path, wait_finish=True) as local_path: self.model = T5EncoderModel.from_pretrained( local_path, torch_dtype=getattr( torch, 'float' if self.t5_dtype == 'float32' else self.t5_dtype)) tokenizer_path = cfg.get('TOKENIZER_PATH', None) self.length = cfg.get('LENGTH', 77) self.use_grad = cfg.get('USE_GRAD', False) self.clean = cfg.get('CLEAN', 'whitespace') self.added_identifier = cfg.get('ADDED_IDENTIFIER', None) if tokenizer_path: self.tokenize_kargs = {'return_tensors': 'pt'} with FS.get_dir_to_local_dir(tokenizer_path, wait_finish=True) as local_path: # the added identifiers (e.g. '{image}') are expected to already be part of the pretrained tokenizer's vocabulary self.tokenizer = AutoTokenizer.from_pretrained(local_path) if self.length is not None: self.tokenize_kargs.update({ 'padding': 'max_length', 'truncation': True, 'max_length': self.length }) self.eos_token = self.tokenizer( self.tokenizer.eos_token)['input_ids'][0] else: self.tokenizer = None self.tokenize_kargs = {} def freeze(self): self.model = self.model.eval() for param in self.parameters(): param.requires_grad = False # encode && encode_text def forward(self, tokens, return_mask=False, use_mask=True): # tokenization embedding_context = nullcontext if self.use_grad else torch.no_grad with embedding_context(): if use_mask: x = self.model(tokens.input_ids.to(we.device_id), tokens.attention_mask.to(we.device_id)) else: x = self.model(tokens.input_ids.to(we.device_id)) x = x.last_hidden_state if return_mask: return x.detach() + 0.0, tokens.attention_mask.to(we.device_id) else: return x.detach() + 0.0, None def _clean(self, text): if self.clean == 'whitespace': text =
whitespace_clean(basic_clean(text)) elif self.clean == 'lower': text = whitespace_clean(basic_clean(text)).lower() elif self.clean == 'canonicalize': text = canonicalize(basic_clean(text)) elif self.clean == 'heavy': text = heavy_clean(basic_clean(text)) return text def encode(self, text, return_mask=False, use_mask=True): if isinstance(text, str): text = [text] if self.clean: text = [self._clean(u) for u in text] assert self.tokenizer is not None cont, mask = [], [] with torch.autocast(device_type='cuda', enabled=self.t5_dtype in ('float16', 'bfloat16'), dtype=getattr(torch, self.t5_dtype)): for tt in text: tokens = self.tokenizer([tt], **self.tokenize_kargs) one_cont, one_mask = self(tokens, return_mask=return_mask, use_mask=use_mask) cont.append(one_cont) mask.append(one_mask) if return_mask: return torch.cat(cont, dim=0), torch.cat(mask, dim=0) else: return torch.cat(cont, dim=0) def encode_list(self, text_list, return_mask=True): cont_list = [] mask_list = [] for pp in text_list: cont, cont_mask = self.encode(pp, return_mask=return_mask) cont_list.append(cont) mask_list.append(cont_mask) if return_mask: return cont_list, mask_list else: return cont_list @staticmethod def get_config_template(): return dict_to_yaml('MODELS', __class__.__name__, ACETextEmbedder.para_dict, set_name=True) ================================================ FILE: modules/model/network/__init__.py ================================================ from .ldm_ace import LdmACE ================================================ FILE: modules/model/network/ldm_ace.py ================================================ # -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import copy import random from contextlib import nullcontext import torch import torch.nn.functional as F from torch import nn from scepter.modules.model.network.ldm import LatentDiffusion from scepter.modules.model.registry import MODELS import torchvision.transforms as T from scepter.modules.model.utils.basic_utils import check_list_of_list from scepter.modules.model.utils.basic_utils import \ pack_imagelist_into_tensor_v2 as pack_imagelist_into_tensor from scepter.modules.model.utils.basic_utils import ( to_device, unpack_tensor_into_imagelist) from scepter.modules.utils.config import dict_to_yaml from scepter.modules.utils.distribute import we class TextEmbedding(nn.Module): def __init__(self, embedding_shape): super().__init__() self.pos = nn.Parameter(data=torch.zeros(embedding_shape)) @MODELS.register_class() class LdmACE(LatentDiffusion): para_dict = LatentDiffusion.para_dict para_dict['DECODER_BIAS'] = {'value': 0, 'description': ''} def __init__(self, cfg, logger=None): super().__init__(cfg, logger=logger) self.interpolate_func = lambda x: (F.interpolate( x.unsqueeze(0), scale_factor=1 / self.size_factor, mode='nearest-exact') if x is not None else None) self.text_indentifers = cfg.get('TEXT_IDENTIFIER', []) self.use_text_pos_embeddings = cfg.get('USE_TEXT_POS_EMBEDDINGS', False) if self.use_text_pos_embeddings: self.text_position_embeddings = TextEmbedding( (10, 4096)).eval().requires_grad_(False) else: self.text_position_embeddings = None self.logger.info(self.model) @torch.no_grad() def encode_first_stage(self, x, **kwargs): return [ self.scale_factor * self.first_stage_model._encode(i.unsqueeze(0).to(torch.float16)) for i in x ] @torch.no_grad() def decode_first_stage(self, z): return [ self.first_stage_model._decode(1. 
/ self.scale_factor * i.to(torch.float16)) for i in z ] def cond_stage_embeddings(self, prompt, edit_image, cont, cont_mask): if self.use_text_pos_embeddings and not torch.sum( self.text_position_embeddings.pos) > 0: identifier_cont, identifier_cont_mask = getattr( self.cond_stage_model, 'encode_list_of_list')(self.text_indentifers, return_mask=True) self.text_position_embeddings.load_state_dict( {'pos': torch.cat( [one_id[0][0, :].unsqueeze(0) for one_id in identifier_cont], dim=0)}) cont_, cont_mask_ = [], [] for pp, edit, c, cm in zip(prompt, edit_image, cont, cont_mask): if isinstance(pp, list): cont_.append([c[-1], *c] if len(edit) > 0 else [c[-1]]) cont_mask_.append([cm[-1], *cm] if len(edit) > 0 else [cm[-1]]) else: raise NotImplementedError return cont_, cont_mask_ def limit_batch_data(self, batch_data_list, log_num): if log_num and log_num > 0: batch_data_list_limited = [] for sub_data in batch_data_list: if sub_data is not None: sub_data = sub_data[:log_num] batch_data_list_limited.append(sub_data) return batch_data_list_limited else: return batch_data_list def forward_train(self, edit_image=[], edit_image_mask=[], image=None, image_mask=None, noise=None, prompt=[], **kwargs): ''' Args: edit_image: list of list of edit images edit_image_mask: list of list of edit image masks image: target image image_mask: target image mask noise: default is None, generated automatically prompt: list of list of text **kwargs: Returns: ''' assert check_list_of_list(prompt) and check_list_of_list( edit_image) and check_list_of_list(edit_image_mask) assert len(edit_image) == len(edit_image_mask) == len(prompt) assert self.cond_stage_model is not None gc_seg = kwargs.pop('gc_seg', []) gc_seg = int(gc_seg[0]) if len(gc_seg) > 0 else 0 context = {} # process image image = to_device(image) x_start = self.encode_first_stage(image, **kwargs) x_start, x_shapes = pack_imagelist_into_tensor(x_start) # B, C, L n, _, _ = x_start.shape t = torch.randint(0, self.num_timesteps, (n, ), device=x_start.device).long() context['x_shapes'] = x_shapes # process image mask image_mask = to_device(image_mask, strict=False) context['x_mask'] = [self.interpolate_func(i) for i in image_mask ] if image_mask is not None else [None] * n # process text # with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16): prompt_ = [[pp] if isinstance(pp, str) else pp for pp in prompt] try: cont, cont_mask = getattr(self.cond_stage_model, 'encode_list_of_list')(prompt_, return_mask=True) except Exception as e: print(e, prompt_) raise # re-raise: cont/cont_mask would otherwise be undefined below cont, cont_mask = self.cond_stage_embeddings(prompt, edit_image, cont, cont_mask) context['crossattn'] = cont # process edit image & edit image mask edit_image = [to_device(i, strict=False) for i in edit_image] edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask] e_img, e_mask = [], [] for u, m in zip(edit_image, edit_image_mask): if m is None: m = [None] * len(u) if u is not None else [None] e_img.append( self.encode_first_stage(u, **kwargs) if u is not None else u) e_mask.append([ self.interpolate_func(i) if i is not None else None for i in m ]) context['edit'], context['edit_mask'] = e_img, e_mask # process loss loss = self.diffusion.loss( x_0=x_start, t=t, noise=noise, model=self.model, model_kwargs={ 'cond': context, 'mask': cont_mask, 'gc_seg': gc_seg, 'text_position_embeddings': self.text_position_embeddings.pos if hasattr( self.text_position_embeddings, 'pos') else None }, **kwargs) loss = loss.mean() ret = {'loss': loss, 'probe_data': {'prompt': prompt}} return ret
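    # NOTE: forward_train/forward_test consume list-of-list batches so that a
    # single sample can carry several source images and a multi-turn
    # instruction. An illustrative single-sample batch (the tensor shapes are
    # assumptions, not taken from this file):
    #   prompt          = [['{image} make the boy cry']]
    #   edit_image      = [[src_tensor_3xHxW]]
    #   edit_image_mask = [[mask_tensor_1xHxW]]
    #   image           = [target_tensor_3xHxW]
    #   image_mask      = [target_mask_1xHxW]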
@torch.no_grad() def forward_test(self, edit_image=[], edit_image_mask=[], image=None, image_mask=None, prompt=[], n_prompt=[], sampler='ddim', sample_steps=20, guide_scale=4.5, guide_rescale=0.5, log_num=-1, seed=2024, **kwargs): assert check_list_of_list(prompt) and check_list_of_list( edit_image) and check_list_of_list(edit_image_mask) assert len(edit_image) == len(edit_image_mask) == len(prompt) assert self.cond_stage_model is not None # gc_seg is unused kwargs.pop('gc_seg', -1) # prepare data context, null_context = {}, {} prompt, n_prompt, image, image_mask, edit_image, edit_image_mask = self.limit_batch_data( [prompt, n_prompt, image, image_mask, edit_image, edit_image_mask], log_num) g = torch.Generator(device=we.device_id) seed = seed if seed >= 0 else random.randint(0, 2**32 - 1) g.manual_seed(seed) n_prompt = copy.deepcopy(prompt) # the passed n_prompt is ignored; the negative prompt is derived by clearing the last prompt of each sample for nn_p_id, nn_p in enumerate(n_prompt): if isinstance(nn_p, str): n_prompt[nn_p_id] = [''] elif isinstance(nn_p, list): n_prompt[nn_p_id][-1] = '' else: raise NotImplementedError # process image image = to_device(image) x = self.encode_first_stage(image, **kwargs) noise = [ torch.empty(*i.shape, device=we.device_id).normal_(generator=g) for i in x ] noise, x_shapes = pack_imagelist_into_tensor(noise) context['x_shapes'] = null_context['x_shapes'] = x_shapes # process image mask image_mask = to_device(image_mask, strict=False) cond_mask = [self.interpolate_func(i) for i in image_mask ] if image_mask is not None else [None] * len(image) context['x_mask'] = null_context['x_mask'] = cond_mask # process text # with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16): prompt_ = [[pp] if isinstance(pp, str) else pp for pp in prompt] cont, cont_mask = getattr(self.cond_stage_model, 'encode_list_of_list')(prompt_, return_mask=True) cont, cont_mask = self.cond_stage_embeddings(prompt, edit_image, cont, cont_mask) null_cont, null_cont_mask = getattr(self.cond_stage_model, 'encode_list_of_list')(n_prompt, return_mask=True) null_cont, null_cont_mask = self.cond_stage_embeddings( prompt, edit_image, null_cont, null_cont_mask) context['crossattn'] = cont null_context['crossattn'] = null_cont # process edit image & edit image mask edit_image = [to_device(i, strict=False) for i in edit_image] edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask] e_img, e_mask = [], [] for u, m in zip(edit_image, edit_image_mask): if u is None: continue if m is None: m = [None] * len(u) e_img.append(self.encode_first_stage(u, **kwargs)) e_mask.append([self.interpolate_func(i) for i in m]) null_context['edit'] = context['edit'] = e_img null_context['edit_mask'] = context['edit_mask'] = e_mask # process sample model = self.model_ema if self.use_ema and self.eval_ema else self.model embedding_context = model.no_sync if isinstance(model, torch.distributed.fsdp.FullyShardedDataParallel) \ else nullcontext with embedding_context(): samples = self.diffusion.sample( sampler=sampler, noise=noise, model=model, model_kwargs=[{ 'cond': context, 'mask': cont_mask, 'text_position_embeddings': self.text_position_embeddings.pos if hasattr( self.text_position_embeddings, 'pos') else None }, { 'cond': null_context, 'mask': null_cont_mask, 'text_position_embeddings': self.text_position_embeddings.pos if hasattr( self.text_position_embeddings, 'pos') else None }] if guide_scale is not None and guide_scale > 1 else { 'cond': context, 'mask': cont_mask, 'text_position_embeddings': self.text_position_embeddings.pos if hasattr(
self.text_position_embeddings, 'pos') else None }, steps=sample_steps, guide_scale=guide_scale, guide_rescale=guide_rescale, show_progress=True, **kwargs) samples = unpack_tensor_into_imagelist(samples, x_shapes) x_samples = self.decode_first_stage(samples) outputs = list() for i in range(len(prompt)): rec_img = torch.clamp( (x_samples[i] + 1.0) / 2.0 + self.decoder_bias / 255, min=0.0, max=1.0) rec_img = rec_img.squeeze(0) edit_imgs, edit_img_masks = [], [] if edit_image is not None and edit_image[i] is not None: if edit_image_mask[i] is None: edit_image_mask[i] = [None] * len(edit_image[i]) for edit_img, edit_mask in zip(edit_image[i], edit_image_mask[i]): edit_img = torch.clamp((edit_img + 1.0) / 2.0, min=0.0, max=1.0) edit_imgs.append(edit_img.squeeze(0)) if edit_mask is None: edit_mask = torch.ones_like(edit_img[[0], :, :]) edit_img_masks.append(edit_mask) one_tup = { 'reconstruct_image': rec_img, 'instruction': prompt[i], 'edit_image': edit_imgs if len(edit_imgs) > 0 else None, 'edit_mask': edit_img_masks if len(edit_imgs) > 0 else None } if image is not None: if image_mask is None: image_mask = [None] * len(image) ori_img = torch.clamp((image[i] + 1.0) / 2.0, min=0.0, max=1.0) one_tup['target_image'] = ori_img.squeeze(0) one_tup['target_mask'] = image_mask[i] if image_mask[ i] is not None else torch.ones_like(ori_img[[0], :, :]) outputs.append(one_tup) return outputs @staticmethod def get_config_template(): return dict_to_yaml('MODEL', __class__.__name__, LdmACE.para_dict, set_name=True) ================================================ FILE: modules/model/utils/basic_utils.py ================================================ # -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. from inspect import isfunction import torch from torch.nn.utils.rnn import pad_sequence from scepter.modules.utils.distribute import we def exists(x): return x is not None def default(val, d): if exists(val): return val return d() if isfunction(d) else d def disabled_train(self, mode=True): """Overwrite model.train with this function to make sure train/eval mode does not change anymore.""" return self def transfer_size(para_num): if para_num > 1000 * 1000 * 1000 * 1000: bill = para_num / (1000 * 1000 * 1000 * 1000) return '{:.2f}T'.format(bill) elif para_num > 1000 * 1000 * 1000: gyte = para_num / (1000 * 1000 * 1000) return '{:.2f}B'.format(gyte) elif para_num > (1000 * 1000): meta = para_num / (1000 * 1000) return '{:.2f}M'.format(meta) elif para_num > 1000: kelo = para_num / 1000 return '{:.2f}K'.format(kelo) else: return para_num def count_params(model): total_params = sum(p.numel() for p in model.parameters()) return transfer_size(total_params) def expand_dims_like(x, y): while x.dim() != y.dim(): x = x.unsqueeze(-1) return x def unpack_tensor_into_imagelist(image_tensor, shapes): image_list = [] for img, shape in zip(image_tensor, shapes): h, w = shape[0], shape[1] image_list.append(img[:, :h * w].view(1, -1, h, w)) return image_list def find_example(tensor_list, image_list): for i in tensor_list: if isinstance(i, torch.Tensor): return torch.zeros_like(i) for i in image_list: if isinstance(i, torch.Tensor): _, c, h, w = i.size() return torch.zeros_like(i.view(c, h * w).transpose(1, 0)) return None def pack_imagelist_into_tensor_v2(image_list): # allow None example = None image_tensor, shapes = [], [] for img in image_list: if img is None: example = find_example(image_tensor, image_list) if example is None else example image_tensor.append(example) shapes.append(None) continue _, c, h, 
w = img.size() image_tensor.append(img.view(c, h * w).transpose(1, 0)) # h*w, c shapes.append((h, w)) image_tensor = pad_sequence(image_tensor, batch_first=True).permute(0, 2, 1) # b, c, l return image_tensor, shapes def to_device(inputs, strict=True): if inputs is None: return None if strict: assert all(isinstance(i, torch.Tensor) for i in inputs) return [i.to(we.device_id) if i is not None else None for i in inputs] def check_list_of_list(ll): return isinstance(ll, list) and all(isinstance(i, list) for i in ll) ================================================ FILE: modules/solver/__init__.py ================================================ from .ace_solver import ACESolverV1 ================================================ FILE: modules/solver/ace_solver.py ================================================ # -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import numpy as np import torch from tqdm import tqdm from scepter.modules.utils.data import transfer_data_to_cuda from scepter.modules.utils.distribute import we from scepter.modules.utils.probe import ProbeData from scepter.modules.solver.registry import SOLVERS from scepter.modules.solver.diffusion_solver import LatentDiffusionSolver @SOLVERS.register_class() class ACESolverV1(LatentDiffusionSolver): def __init__(self, cfg, logger=None): super().__init__(cfg, logger=logger) self.log_train_num = cfg.get('LOG_TRAIN_NUM', -1) def save_results(self, results): log_data, log_label = [], [] for result in results: ret_images, ret_labels = [], [] edit_image = result.get('edit_image', None) edit_mask = result.get('edit_mask', None) if edit_image is not None: for i, edit_img in enumerate(result['edit_image']): if edit_img is None: continue ret_images.append( (edit_img.permute(1, 2, 0).cpu().numpy() * 255).astype( np.uint8)) ret_labels.append(f'edit_image{i}; ') if edit_mask is not None: ret_images.append( (edit_mask[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)) ret_labels.append(f'edit_mask{i}; ') target_image = result.get('target_image', None) target_mask = result.get('target_mask', None) if target_image is not None: ret_images.append( (target_image.permute(1, 2, 0).cpu().numpy() * 255).astype( np.uint8)) ret_labels.append('target_image; ') if target_mask is not None: ret_images.append( (target_mask.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)) ret_labels.append('target_mask; ') reconstruct_image = result.get('reconstruct_image', None) if reconstruct_image is not None: ret_images.append( (reconstruct_image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)) ret_labels.append(f"{result['instruction']}") log_data.append(ret_images) log_label.append(ret_labels) return log_data, log_label @torch.no_grad() def run_eval(self): self.eval_mode() self.before_all_iter(self.hooks_dict[self._mode]) all_results = [] for batch_idx, batch_data in tqdm( enumerate(self.datas[self._mode].dataloader)): self.before_iter(self.hooks_dict[self._mode]) if self.sample_args: batch_data.update(self.sample_args.get_lowercase_dict()) with torch.autocast(device_type='cuda', enabled=self.use_amp, dtype=self.dtype): results = self.run_step_eval(transfer_data_to_cuda(batch_data), batch_idx, step=self.total_iter, rank=we.rank) all_results.extend(results) self.after_iter(self.hooks_dict[self._mode]) log_data, log_label = self.save_results(all_results) self.register_probe({'eval_label': log_label}) self.register_probe({ 'eval_image': ProbeData(log_data, is_image=True, build_html=True, build_label=log_label) }) 
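        # The probes registered above expose the eval images/labels to the
        # logging hooks; ProbeData(build_html=True) additionally builds an
        # HTML preview page from the per-image labels.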
self.after_all_iter(self.hooks_dict[self._mode]) @torch.no_grad() def run_test(self): self.test_mode() self.before_all_iter(self.hooks_dict[self._mode]) all_results = [] for batch_idx, batch_data in tqdm( enumerate(self.datas[self._mode].dataloader)): self.before_iter(self.hooks_dict[self._mode]) if self.sample_args: batch_data.update(self.sample_args.get_lowercase_dict()) with torch.autocast(device_type='cuda', enabled=self.use_amp, dtype=self.dtype): results = self.run_step_eval(transfer_data_to_cuda(batch_data), batch_idx, step=self.total_iter, rank=we.rank) all_results.extend(results) self.after_iter(self.hooks_dict[self._mode]) log_data, log_label = self.save_results(all_results) self.register_probe({'test_label': log_label}) self.register_probe({ 'test_image': ProbeData(log_data, is_image=True, build_html=True, build_label=log_label) }) self.after_all_iter(self.hooks_dict[self._mode]) @property def probe_data(self): if not we.debug and self.mode == 'train': batch_data = transfer_data_to_cuda( self.current_batch_data[self.mode]) self.eval_mode() with torch.autocast(device_type='cuda', enabled=self.use_amp, dtype=self.dtype): batch_data['log_num'] = self.log_train_num results = self.run_step_eval(batch_data) self.train_mode() log_data, log_label = self.save_results(results) self.register_probe({ 'train_image': ProbeData(log_data, is_image=True, build_html=True, build_label=log_label) }) self.register_probe({'train_label': log_label}) return super(LatentDiffusionSolver, self).probe_data ================================================ FILE: readme.md ================================================
# ACE: All-round Creator and Editor Following Instructions via Diffusion Transformer

[Paper PDF] · [Project Page]

Zhen Han* · Zeyinzi Jiang* · Yulin Pan* · Jingfeng Zhang* · Chaojie Mao*
Chenwei Xie · Yu Liu · Jingren Zhou

Tongyi Lab, Alibaba Group
## 📢 News
* **[2024.9.30]** Release the ACE paper on arXiv.
* **[2024.10.31]** Release the ACE checkpoint on [ModelScope](https://www.modelscope.cn/models/iic/ACE-0.6B-512px) and [HuggingFace](https://huggingface.co/scepter-studio/ACE-0.6B-512px).
* **[2024.11.1]** Support the online demo on [HuggingFace](https://huggingface.co/spaces/scepter-studio/ACE-Chat).
* **[2024.11.20]** Release the [ACE-0.6b-1024px](https://huggingface.co/scepter-studio/ACE-0.6B-1024px) model, which significantly enhances image generation quality compared with [ACE-0.6b-512px](https://huggingface.co/scepter-studio/ACE-0.6B-512px).
* **[2025.01.06]** Release [ACE++](https://ali-vilab.github.io/ACE_plus_page/), the follow-up project built on FLUX.1-dev.

## 🚀 Installation
Install the necessary packages with `pip`:
```bash
pip install -r requirements.txt
```

## 🔥 ACE Models

| **Model** | **Status** |
|:----------------:|:----------:|
| ACE-0.6B-512px | [![Demo link](https://img.shields.io/badge/Demo-ACE_Chat-purple)](https://huggingface.co/spaces/scepter-studio/ACE-Chat) [![ModelScope link](https://img.shields.io/badge/ModelScope-Model-blue)](https://www.modelscope.cn/models/iic/ACE-0.6B-512px) [![HuggingFace link](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-yellow)](https://huggingface.co/scepter-studio/ACE-0.6B-512px) |
| ACE-0.6B-1024px | [![ModelScope link](https://img.shields.io/badge/ModelScope-Model-blue)](https://www.modelscope.cn/models/iic/ACE-0.6B-1024px) [![HuggingFace link](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-yellow)](https://huggingface.co/scepter-studio/ACE-0.6B-1024px) |
| ACE-12B-FLUX-dev | The ACE model based on the FLUX.1-dev base model adopts a new adaptation method and has been organized into a separate project, [ACE++](https://ali-vilab.github.io/ACE_plus_page/); the relevant models are open-sourced there. |

## 🖼 Model Performance Visualization
The current ACE model has a parameter scale of 0.6B, which imposes certain limitations on image generation quality. [FLUX.1-Dev](https://huggingface.co/black-forest-labs/FLUX.1-dev), on the other hand, has a significant advantage in text-to-image generation quality. Using SDEdit, we can leverage the generative capability of FLUX to further enhance the images generated by ACE. Based on these considerations, we designed the ACE-Refiner pipeline, shown in the diagram below.

![ACE_REFINER](assets/ace_method/ace_refiner_process.webp)

As shown in the figure below, a high refinement strength σ causes the refined image to lose fidelity to the original ACE output, while a low σ does not noticeably improve image quality. Users can therefore trade off fidelity against image quality to suit their needs by setting "REFINER_SCALE" in the configuration file `config/inference_config/models/ace_0.6b_1024_refiner.yaml`. We recommend using the advanced options in the [webui-demo](#-chat-bot-) to verify the effect.

![ACE_REFINER_EXAMPLE](assets/ace_method/ace_refiner.webp)

We compared the generation and editing performance of different models on several tasks, as shown below.

![Samples](assets/ace_method/samples_compare.webp)

## 🔥 Training
We offer a demonstration training YAML that enables end-to-end training of ACE on a toy dataset. For a comprehensive overview of the hyperparameter configuration, please consult `config/train_config/ace_0.6b_512_train.yaml`.

### Prepare datasets
The dataset class is located in `modules/data/dataset/dataset.py` and is designed to facilitate end-to-end training on an open-source toy dataset. Download the dataset zip file from [ModelScope](https://www.modelscope.cn/models/iic/scepter/resolve/master/datasets/hed_pair.zip) and extract its contents into the `cache/datasets/` directory. Should you wish to prepare your own datasets, we recommend consulting `modules/data/dataset/dataset.py` for detailed guidance on the required data format.

### Prepare initial weight
The ACE checkpoint has been uploaded to both ModelScope and HuggingFace:
* [ModelScope](https://www.modelscope.cn/models/iic/ACE-0.6B-512px)
* [HuggingFace](https://huggingface.co/scepter-studio/ACE-0.6B-512px)

In the provided training YAML configuration, we have designated the ModelScope URL as the default checkpoint URL.
Should you wish to switch to Hugging Face, you can do so by modifying the PRETRAINED_MODEL value within the YAML file (replace the prefix "ms://iic" with "hf://scepter-studio").

### Start training
You can start the training procedure by executing the following command:
```bash
# ACE-0.6B-512px
PYTHONPATH=. python tools/run_train.py --cfg config/train_config/ace_0.6b_512_train.yaml
# ACE-0.6B-1024px
PYTHONPATH=. python tools/run_train.py --cfg config/train_config/ace_0.6b_1024_train.yaml
```

## 🚀 Inference
We provide a simple inference demo that allows users to generate or edit images from text instructions:
```bash
PYTHONPATH=. python tools/run_inference.py --cfg config/inference_config/models/ace_0.6b_512.yaml --instruction "make the boy cry, his eyes filled with tears" --seed 199999 --input_image examples/input_images/example0.webp
```

For quick testing, we recommend running the bundled examples. The following command runs the example inference, and the results are saved in `examples/output_images/`:
```bash
PYTHONPATH=. python tools/run_inference.py --cfg config/inference_config/models/ace_0.6b_512.yaml
```

## 💬 Chat Bot
We have developed a chatbot UI with Gradio that turns natural-language user input into images that align semantically with the given instructions. Launch the chatbot app by executing the following command:
```bash
python chatbot/run_gradio.py --cfg chatbot/config/chatbot_ui.yaml --server_port 2024
```
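Beyond the CLI entry points above, `tools/run_inference.py` also illustrates how to drive `ACEInference` directly from Python. Below is a minimal sketch along the same lines; the keyword arguments mirror `run_one_case` in that script, while loading the YAML via `Config(cfg_file=...)` and skipping the script's file-system client setup are simplifying assumptions:

```python
from PIL import Image

from chatbot.ace_inference import ACEInference
from scepter.modules.utils.config import Config

# Build the pipeline from an inference YAML (assumption: Config accepts
# cfg_file=...; tools/run_inference.py builds the same Config via its CLI parser).
cfg = Config(cfg_file='config/inference_config/models/ace_0.6b_512.yaml')
pipe = ACEInference()
pipe.init_from_cfg(cfg)

# One editing request: a single source image plus an instruction that
# refers to it through the {image} placeholder.
image = Image.open('examples/input_images/example0.webp').convert('RGB')
imgs = pipe(
    image=[image],
    mask=[None],  # no region mask, so the whole image may be edited
    task=[''],
    prompt=['{image} make the boy cry, his eyes filled with tears'],
    negative_prompt=[''],
    output_height=pipe.input.get('output_height', 1024),
    output_width=pipe.input.get('output_width', 1024),
    sampler=pipe.input.get('sampler', 'ddim'),
    sample_steps=pipe.input.get('sample_steps', 20),
    guide_scale=pipe.input.get('guide_scale', 4.5),
    guide_rescale=pipe.input.get('guide_rescale', 0.5),
    seed=199999,
)
imgs[0].save('examples/output_images/example0.png')
```

As in the script, `pipe.input` supplies per-model defaults for the sampler settings, so the explicit overrides are optional.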
## ⚙️ ComfyUI Workflow

![Workflow](assets/comfyui/ace_example.jpg)

We support using ACE in the ComfyUI workflow in two ways:
1) Automatic installation directly via the ComfyUI Manager by searching for the **ComfyUI-Scepter** node.
2) Manual installation by copying the custom nodes from Scepter into ComfyUI:
```shell
git clone https://github.com/modelscope/scepter.git
cd path/to/scepter
pip install -e .
cp -r path/to/scepter/workflow/ path/to/ComfyUI/custom_nodes/ComfyUI-Scepter
cd path/to/ComfyUI
python main.py
```

**Note**: You can use the nodes by dragging the sample workflow images below into ComfyUI. Additionally, our nodes can automatically pull models from ModelScope or HuggingFace by selecting the *model_source* field, or you can place already-downloaded models in a local path.
**ACE Workflow Examples**: *Control Semantic Element* (the example workflow images are omitted from this extract; drag them into ComfyUI to load the corresponding workflows).
## 📝 Citation
```bibtex
@inproceedings{ICLR2025_ACE,
    title = {ACE: All-round Creator and Editor Following Instructions via Diffusion Transformer},
    author = {Han, Zhen and Jiang, Zeyinzi and Pan, Yulin and Zhang, Jingfeng and Mao, Chaojie and Xie, Chen-Wei and Liu, Yu and Zhou, Jingren},
    booktitle = {International Conference on Learning Representations},
    pages = {57096--57111},
    year = {2025}
}
```
================================================
FILE: requirements.txt
================================================
git+https://github.com/modelscope/scepter.git@v1.3.0_dev#egg=scepter
pycocotools
pyyaml>=5.3.1
scikit-image
torchsde
transformers
scikit-learn
numpy
opencv-python
opencv_transforms>=0.0.6
oss2>=2.15.0
einops
torch==2.4.0
torchvision
flash-attn==2.5.8
bitsandbytes
gradio==4.44.1
gradio_imageslider
diffusers
addict
datasets==3.0.1
================================================
FILE: tools/run_inference.py
================================================
# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import argparse import importlib import io import os import sys from PIL import Image from scepter.modules.utils.config import Config from scepter.modules.utils.file_system import FS if os.path.exists('__init__.py'): package_name = 'scepter_ext' spec = importlib.util.spec_from_file_location(package_name, '__init__.py') package = importlib.util.module_from_spec(spec) sys.modules[package_name] = package spec.loader.exec_module(package) from chatbot.ace_inference import ACEInference fs_list = [ Config(cfg_dict={"NAME": "HuggingfaceFs", "TEMP_DIR": "./cache"}, load=False), Config(cfg_dict={"NAME": "ModelscopeFs", "TEMP_DIR": "./cache"}, load=False), Config(cfg_dict={"NAME": "HttpFs", "TEMP_DIR": "./cache"}, load=False), Config(cfg_dict={"NAME": "LocalFs", "TEMP_DIR": "./cache"}, load=False), ] for one_fs in fs_list: FS.init_fs_client(one_fs) def run_one_case(pipe, input_image, input_mask, edit_k, instruction, negative_prompt, seed, output_h, output_w, save_path): edit_image, edit_image_mask, edit_task = [], [], [] if input_image is not None: image = Image.open(io.BytesIO(FS.get_object(input_image))) edit_image.append(image.convert('RGB')) edit_image_mask.append( Image.open(io.BytesIO(FS.get_object(input_mask))).
convert('L') if input_mask is not None else None) edit_task.append(edit_k) imgs = pipe( image=edit_image, mask=edit_image_mask, task=edit_task, prompt=[instruction] * len(edit_image) if edit_image is not None else [instruction], negative_prompt=[negative_prompt] * len(edit_image) if edit_image is not None else [negative_prompt], output_height=output_h, output_width=output_w, sampler=pipe.input.get("sampler", "ddim"), sample_steps=pipe.input.get("sample_steps", 20), guide_scale=pipe.input.get("guide_scale", 4.5), guide_rescale=pipe.input.get("guide_rescale", 0.5), seed=seed, ) with FS.put_to(save_path) as local_path: imgs[0].save(local_path) return def run(): parser = argparse.ArgumentParser(description='Argparser for Scepter:\n') parser.add_argument('--instruction', dest='instruction', help='The instruction for editing or generating!', default="") parser.add_argument('--negative_prompt', dest='negative_prompt', help='The negative prompt for editing or generating!', default="") parser.add_argument('--output_h', dest='output_h', help='The height of output image for generation tasks!', type=int, default=None) parser.add_argument('--output_w', dest='output_w', help='The width of output image for generation tasks!', type=int, default=None) parser.add_argument('--input_image', dest='input_image', help='The input image!', default=None ) parser.add_argument('--input_mask', dest='input_mask', help='The input mask!', default=None ) parser.add_argument('--save_path', dest='save_path', help='The save path for output image!', default='examples/output_images/output.png' ) parser.add_argument('--seed', dest='seed', help='The seed for generation!', type=int, default=-1) cfg = Config(load=True, parser_ins=parser) pipe = ACEInference() pipe.init_from_cfg(cfg) output_h = cfg.args.output_h or pipe.input.get("output_height", 1024) output_w = cfg.args.output_w or pipe.input.get("output_width", 1024) negative_prompt = cfg.args.negative_prompt if cfg.args.instruction == "" and cfg.args.input_image is None: # run examples all_examples = [ ["examples/input_images/example0.webp", None, "", "{image} make the boy cry, his eyes filled with tears", "", 199999, output_h, output_w, "examples/output_images/example0.png"], ["examples/input_images/example1.webp", None, "", "{image}use the depth map @cb638863a0e9 and the text caption \"Vincent van Gogh with expressive, " "soulful eyes and a gentle smile, wearing traditional 19th-century artist's attire, including a " "paint-streaked smock, a straw hat with sunflowers, and an artist's easel slung over his shoulder." "Subtle elements of \"Starry Night\" swirling around, with hints of sunflowers and wheat fields " "from his famous paintings. 
Include a palette and paintbrushes, a small sun painted in the top " "corner, and subtle curling patterns reminiscent of his brush strokes\" to create a image", "", 899999, output_h, output_w, "examples/output_images/example1.png"], ["examples/input_images/example2.webp", None, "", "make this {image} colorful", "", 199999, output_h, output_w, "examples/output_images/example2.png"], ["examples/input_images/example3.webp", None, "", "change the style to 3D cartoon style", "", 2023, output_h, output_w, "examples/output_images/example3.png"], ] for example in all_examples: run_one_case(pipe, example[0], example[1], example[2], example[3], example[4], example[5], example[6], example[7], example[8]) else: if "{image}" not in cfg.args.instruction: instruction = "{image} " + cfg.args.instruction else: instruction = cfg.args.instruction run_one_case(pipe, cfg.args.input_image, cfg.args.input_mask, "", instruction, negative_prompt, cfg.args.seed, output_h, output_w, cfg.args.save_path) if __name__ == '__main__': run() ================================================ FILE: tools/run_train.py ================================================ # -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import argparse import importlib import os import sys from scepter.modules.solver.registry import SOLVERS from scepter.modules.utils.config import Config from scepter.modules.utils.distribute import we from scepter.modules.utils.logger import get_logger if os.path.exists('__init__.py'): package_name = 'scepter_ext' spec = importlib.util.spec_from_file_location(package_name, '__init__.py') package = importlib.util.module_from_spec(spec) sys.modules[package_name] = package spec.loader.exec_module(package) def run_task(cfg): std_logger = get_logger(name='scepter') solver = SOLVERS.build(cfg.SOLVER, logger=std_logger) solver.set_up_pre() solver.set_up() solver.solve() def update_config(cfg): if hasattr(cfg.args, 'learning_rate') and cfg.args.learning_rate: if cfg.SOLVER.OPTIMIZER.get('LEARNING_RATE', None) is not None: print( f'learning_rate change from {cfg.SOLVER.OPTIMIZER.LEARNING_RATE} to {cfg.args.learning_rate}' ) cfg.SOLVER.OPTIMIZER.LEARNING_RATE = float(cfg.args.learning_rate) if hasattr(cfg.args, 'max_steps') and cfg.args.max_steps: if cfg.SOLVER.get('MAX_STEPS', None) is not None: print( f'max_steps change from {cfg.SOLVER.MAX_STEPS} to {cfg.args.max_steps}' ) cfg.SOLVER.MAX_STEPS = int(cfg.args.max_steps) return cfg def run(): parser = argparse.ArgumentParser(description='Argparser for Scepter:\n') parser.add_argument('--learning_rate', dest='learning_rate', help='The learning rate for our network!', default=None) parser.add_argument('--max_steps', dest='max_steps', help='The max steps for training!', default=None) cfg = Config(load=True, parser_ins=parser) cfg = update_config(cfg) we.init_env(cfg, logger=None, fn=run_task) if __name__ == '__main__': run()