Repository: ali-vilab/ACE
Branch: main
Commit: 886bf9510b85
Files: 42
Total size: 269.5 KB
Directory structure:
gitextract_gx2sgdse/
├── .gitignore
├── LICENSE
├── __init__.py
├── chatbot/
│ ├── __init__.py
│ ├── ace_inference.py
│ ├── config/
│ │ ├── chatbot_ui.yaml
│ │ └── models/
│ │ └── ace_0.6b_512.yaml
│ ├── example.py
│ ├── infer.py
│ ├── run_gradio.py
│ └── utils.py
├── config/
│ ├── inference_config/
│ │ ├── chatbot_ui.yaml
│ │ └── models/
│ │ ├── ace_0.6b_1024.yaml
│ │ ├── ace_0.6b_1024_refiner.yaml
│ │ └── ace_0.6b_512.yaml
│ └── train_config/
│ ├── ace_0.6b_1024_train.yaml
│ └── ace_0.6b_512_train.yaml
├── modules/
│ ├── __init__.py
│ ├── data/
│ │ ├── __init__.py
│ │ └── dataset/
│ │ ├── __init__.py
│ │ └── dataset.py
│ ├── inference/
│ │ └── __init__.py
│ ├── model/
│ │ ├── __init__.py
│ │ ├── backbone/
│ │ │ ├── __init__.py
│ │ │ ├── ace.py
│ │ │ ├── layers.py
│ │ │ └── pos_embed.py
│ │ ├── diffusion/
│ │ │ ├── __init__.py
│ │ │ ├── diffusions.py
│ │ │ ├── samplers.py
│ │ │ └── schedules.py
│ │ ├── embedder/
│ │ │ ├── __init__.py
│ │ │ └── embedder.py
│ │ ├── network/
│ │ │ ├── __init__.py
│ │ │ └── ldm_ace.py
│ │ └── utils/
│ │ └── basic_utils.py
│ └── solver/
│ ├── __init__.py
│ └── ace_solver.py
├── readme.md
├── requirements.txt
└── tools/
├── run_inference.py
└── run_train.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
*.pyc
*.pth
*.pt
*.pkl
*.ckpt
*.DS_Store
*__pycache__*
*.cache*
*.bin
*.idea
*.csv
cache
build
dist
dev
scepter.egg-info
.readthedocs.yml
*resources
*.ipynb_checkpoints*
*.vscode
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: __init__.py
================================================
from . import modules
from . import chatbot
================================================
FILE: chatbot/__init__.py
================================================
================================================
FILE: chatbot/ace_inference.py
================================================
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from PIL import Image
import torchvision.transforms as T
from scepter.modules.model.registry import DIFFUSIONS
from scepter.modules.model.utils.basic_utils import check_list_of_list
from scepter.modules.model.utils.basic_utils import \
pack_imagelist_into_tensor_v2 as pack_imagelist_into_tensor
from scepter.modules.model.utils.basic_utils import (
to_device, unpack_tensor_into_imagelist)
from scepter.modules.utils.distribute import we
from scepter.modules.utils.logger import get_logger
from scepter.modules.inference.diffusion_inference import DiffusionInference, get_model
def process_edit_image(images,
masks,
tasks,
max_seq_len=1024,
max_aspect_ratio=4,
d=16,
**kwargs):
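"""Prepare edit images and masks for ACE (descriptive summary of the code below).
Each image is center-cropped so its aspect ratio stays within
`max_aspect_ratio`, then resized so the latent token count
(H / d) * (W / d) does not exceed `max_seq_len`, with both sides rounded
down to multiples of `d`. Masks are binarized at a threshold of 128 (an
all-zero mask becomes all-ones); for inpainting-style tasks the masked
pixels of the image are blanked to -1.
"""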
if not isinstance(images, list):
images = [images]
if not isinstance(masks, list):
masks = [masks]
if not isinstance(tasks, list):
tasks = [tasks]
img_tensors = []
mask_tensors = []
for img, mask, task in zip(images, masks, tasks):
if mask is None or mask == '':
mask = Image.new('L', img.size, 0)
W, H = img.size
if H / W > max_aspect_ratio:
img = TF.center_crop(img, [int(max_aspect_ratio * W), W])
mask = TF.center_crop(mask, [int(max_aspect_ratio * W), W])
elif W / H > max_aspect_ratio:
img = TF.center_crop(img, [H, int(max_aspect_ratio * H)])
mask = TF.center_crop(mask, [H, int(max_aspect_ratio * H)])
H, W = img.height, img.width
scale = min(1.0, math.sqrt(max_seq_len / ((H / d) * (W / d))))
rH = int(H * scale) // d * d # ensure divisible by d
rW = int(W * scale) // d * d
img = TF.resize(img, (rH, rW),
interpolation=TF.InterpolationMode.BICUBIC)
mask = TF.resize(mask, (rH, rW),
interpolation=TF.InterpolationMode.NEAREST_EXACT)
mask = np.asarray(mask)
mask = np.where(mask > 128, 1, 0)
mask = mask.astype(
np.float32) if np.any(mask) else np.ones_like(mask).astype(
np.float32)
img_tensor = TF.to_tensor(img).to(we.device_id)
img_tensor = TF.normalize(img_tensor,
mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5])
mask_tensor = TF.to_tensor(mask).to(we.device_id)
if task in ['inpainting', 'Try On', 'Inpainting']:
mask_indicator = mask_tensor.repeat(3, 1, 1)
img_tensor[mask_indicator == 1] = -1.0
img_tensors.append(img_tensor)
mask_tensors.append(mask_tensor)
return img_tensors, mask_tensors
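# Hypothetical usage (illustrative names, not from the original file):
# tensors, masks = process_edit_image(pil_image, None, '')
# returns lists of normalized CHW tensors on the current device; with no
# mask supplied, the mask defaults to all-ones.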
class TextEmbedding(nn.Module):
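# Holds the learnable positional embeddings attached to the {image}/{imageN}
# text identifiers; ACEInference fills them with the T5 encodings of
# TEXT_IDENTIFIER (one 4096-dim row per identifier, 10 in total).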
def __init__(self, embedding_shape):
super().__init__()
self.pos = nn.Parameter(data=torch.zeros(embedding_shape))
class RefinerInference(DiffusionInference):
def init_from_cfg(self, cfg):
self.use_dynamic_model = cfg.get('USE_DYNAMIC_MODEL', True)
super().init_from_cfg(cfg)
self.diffusion = DIFFUSIONS.build(cfg.MODEL.DIFFUSION, logger=self.logger) \
if cfg.MODEL.have('DIFFUSION') else None
self.max_seq_length = cfg.MODEL.get("MAX_SEQ_LENGTH", 4096)
assert self.diffusion is not None
if not self.use_dynamic_model:
self.dynamic_load(self.first_stage_model, 'first_stage_model')
self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
self.dynamic_load(self.diffusion_model, 'diffusion_model')
@torch.no_grad()
def encode_first_stage(self, x, **kwargs):
_, dtype = self.get_function_info(self.first_stage_model, 'encode')
with torch.autocast('cuda',
enabled=dtype in ('float16', 'bfloat16'),
dtype=getattr(torch, dtype)):
def run_one_image(u):
zu = get_model(self.first_stage_model).encode(u)
if isinstance(zu, (tuple, list)):
zu = zu[0]
return zu
z = [run_one_image(u.unsqueeze(0) if u.dim() == 3 else u) for u in x]
return z
def upscale_resize(self, image, interpolation=T.InterpolationMode.BILINEAR):
c, H, W = image.shape
scale = max(1.0, math.sqrt(self.max_seq_length / ((H / 16) * (W / 16))))
rH = int(H * scale) // 16 * 16 # ensure divisible by 16
rW = int(W * scale) // 16 * 16
image = T.Resize((rH, rW), interpolation=interpolation, antialias=True)(image)
return image
@torch.no_grad()
def decode_first_stage(self, z):
_, dtype = self.get_function_info(self.first_stage_model, 'decode')
with torch.autocast('cuda',
enabled=dtype in ('float16', 'bfloat16'),
dtype=getattr(torch, dtype)):
return [get_model(self.first_stage_model).decode(zu) for zu in z]
def noise_sample(self, num_samples, h, w, seed, device=None, dtype=torch.bfloat16):
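# Sampled latent has 16 channels; spatial dims are 2 * ceil(size / 16) so the
# latent can later be packed into 2x2 patches (FLUX-style latent packing is
# an assumption, suggested by the flow_euler sampler this refiner uses).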
noise = torch.randn(
num_samples,
16,
# allow for packing
2 * math.ceil(h / 16),
2 * math.ceil(w / 16),
device=device,
dtype=dtype,
generator=torch.Generator(device=device).manual_seed(seed),
)
return noise
def refine(self,
x_samples=None,
prompt=None,
reverse_scale=-1.,
seed=2024,
**kwargs
):
print(prompt)
value_input = copy.deepcopy(self.input)
x_samples = [self.upscale_resize(x) for x in x_samples]
noise = []
for i, x in enumerate(x_samples):
noise_ = self.noise_sample(1, x.shape[1],
x.shape[2], seed,
device = x.device)
noise.append(noise_)
noise, x_shapes = pack_imagelist_into_tensor(noise)
if reverse_scale > 0:
self.dynamic_load(self.first_stage_model, 'first_stage_model')
x_samples = [x.unsqueeze(0) for x in x_samples]
x_start = self.encode_first_stage(x_samples, **kwargs)
self.dynamic_unload(self.first_stage_model,
'first_stage_model',
skip_loaded=not self.use_dynamic_model)
x_start, _ = pack_imagelist_into_tensor(x_start)
else:
x_start = None
# cond stage
self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
function_name, dtype = self.get_function_info(self.cond_stage_model)
with torch.autocast('cuda',
enabled=dtype == 'float16',
dtype=getattr(torch, dtype)):
ctx = getattr(get_model(self.cond_stage_model),
function_name)(prompt)
ctx["x_shapes"] = x_shapes
self.dynamic_unload(self.cond_stage_model,
'cond_stage_model',
skip_loaded=not self.use_dynamic_model)
self.dynamic_load(self.diffusion_model, 'diffusion_model')
# Run the diffusion model with the encoded prompt context
function_name, dtype = self.get_function_info(
self.diffusion_model)
with torch.autocast('cuda',
enabled=dtype in ('float16', 'bfloat16'),
dtype=getattr(torch, dtype)):
solver_sample = value_input.get('sample', 'flow_euler')
sample_steps = value_input.get('sample_steps', 20)
guide_scale = value_input.get('guide_scale', 3.5)
if guide_scale is not None:
guide_scale = torch.full((noise.shape[0],), guide_scale, device=noise.device,
dtype=noise.dtype)
else:
guide_scale = None
latent = self.diffusion.sample(
noise=noise,
sampler=solver_sample,
model=get_model(self.diffusion_model),
model_kwargs={"cond": ctx, "guidance": guide_scale},
steps=sample_steps,
show_progress=True,
guide_scale=guide_scale,
return_intermediate=None,
reverse_scale=reverse_scale,
x=x_start,
**kwargs).float()
latent = unpack_tensor_into_imagelist(latent, x_shapes)
self.dynamic_unload(self.diffusion_model,
'diffusion_model',
skip_loaded=not self.use_dynamic_model)
self.dynamic_load(self.first_stage_model, 'first_stage_model')
x_samples = self.decode_first_stage(latent)
self.dynamic_unload(self.first_stage_model,
'first_stage_model',
skip_loaded=not self.use_dynamic_model)
return x_samples
class ACEInference(DiffusionInference):
def __init__(self, logger=None):
if logger is None:
logger = get_logger(name='scepter')
self.logger = logger
self.loaded_model = {}
self.loaded_model_name = [
'diffusion_model', 'first_stage_model', 'cond_stage_model'
]
def init_from_cfg(self, cfg):
self.name = cfg.NAME
self.is_default = cfg.get('IS_DEFAULT', False)
self.use_dynamic_model = cfg.get('USE_DYNAMIC_MODEL', True)
module_paras = self.load_default(cfg.get('DEFAULT_PARAS', None))
assert cfg.have('MODEL')
self.diffusion_model = self.infer_model(
cfg.MODEL.DIFFUSION_MODEL, module_paras.get(
'DIFFUSION_MODEL',
None)) if cfg.MODEL.have('DIFFUSION_MODEL') else None
self.first_stage_model = self.infer_model(
cfg.MODEL.FIRST_STAGE_MODEL,
module_paras.get(
'FIRST_STAGE_MODEL',
None)) if cfg.MODEL.have('FIRST_STAGE_MODEL') else None
self.cond_stage_model = self.infer_model(
cfg.MODEL.COND_STAGE_MODEL,
module_paras.get(
'COND_STAGE_MODEL',
None)) if cfg.MODEL.have('COND_STAGE_MODEL') else None
self.refiner_model_cfg = cfg.get('REFINER_MODEL', None)
# self.refiner_scale = cfg.get('REFINER_SCALE', 0.)
# self.refiner_prompt = cfg.get('REFINER_PROMPT', "")
self.ace_prompt = cfg.get("ACE_PROMPT", [])
if self.refiner_model_cfg:
self.refiner_model_cfg.USE_DYNAMIC_MODEL = self.use_dynamic_model
self.refiner_module = RefinerInference(self.logger)
self.refiner_module.init_from_cfg(self.refiner_model_cfg)
else:
self.refiner_module = None
self.diffusion = DIFFUSIONS.build(cfg.MODEL.DIFFUSION,
logger=self.logger)
self.interpolate_func = lambda x: (F.interpolate(
x.unsqueeze(0),
scale_factor=1 / self.size_factor,
mode='nearest-exact') if x is not None else None)
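# interpolate_func downsamples pixel-space masks by SIZE_FACTOR (8 by
# default) so they align with the VAE latent grid.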
self.text_identifiers = cfg.MODEL.get('TEXT_IDENTIFIER', [])
self.use_text_pos_embeddings = cfg.MODEL.get('USE_TEXT_POS_EMBEDDINGS',
False)
if self.use_text_pos_embeddings:
self.text_position_embeddings = TextEmbedding(
(10, 4096)).eval().requires_grad_(False).to(we.device_id)
else:
self.text_position_embeddings = None
self.max_seq_len = cfg.MODEL.DIFFUSION_MODEL.MAX_SEQ_LEN
self.scale_factor = cfg.get('SCALE_FACTOR', 0.18215)
self.size_factor = cfg.get('SIZE_FACTOR', 8)
self.decoder_bias = cfg.get('DECODER_BIAS', 0)
self.default_n_prompt = cfg.get('DEFAULT_N_PROMPT', '')
if not self.use_dynamic_model:
self.dynamic_load(self.first_stage_model, 'first_stage_model')
self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
self.dynamic_load(self.diffusion_model, 'diffusion_model')
@torch.no_grad()
def encode_first_stage(self, x, **kwargs):
_, dtype = self.get_function_info(self.first_stage_model, 'encode')
with torch.autocast('cuda',
enabled=(dtype != 'float32'),
dtype=getattr(torch, dtype)):
z = [
self.scale_factor * get_model(self.first_stage_model)._encode(
i.unsqueeze(0).to(getattr(torch, dtype))) for i in x
]
return z
@torch.no_grad()
def decode_first_stage(self, z):
_, dtype = self.get_function_info(self.first_stage_model, 'decode')
with torch.autocast('cuda',
enabled=(dtype != 'float32'),
dtype=getattr(torch, dtype)):
x = [
get_model(self.first_stage_model)._decode(
1. / self.scale_factor * i.to(getattr(torch, dtype)))
for i in z
]
return x
@torch.no_grad()
def __call__(self,
image=None,
mask=None,
prompt='',
task=None,
negative_prompt='',
output_height=512,
output_width=512,
sampler='ddim',
sample_steps=20,
guide_scale=4.5,
guide_rescale=0.5,
seed=-1,
history_io=None,
tar_index=0,
**kwargs):
input_image, input_mask = image, mask
g = torch.Generator(device=we.device_id)
seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)
g.manual_seed(int(seed))
if input_image is not None:
# assert isinstance(input_image, list) and isinstance(input_mask, list)
if task is None:
task = [''] * len(input_image)
if not isinstance(prompt, list):
prompt = [prompt] * len(input_image)
if history_io is not None and len(history_io) > 0:
his_image, his_masks, his_prompt, his_task = history_io[
'image'], history_io['mask'], history_io[
'prompt'], history_io['task']
assert len(his_image) == len(his_masks) == len(
his_prompt) == len(his_task)
input_image = his_image + input_image
input_mask = his_masks + input_mask
task = his_task + task
prompt = his_prompt + [prompt[-1]]
prompt = [
pp.replace('{image}', f'{{image{i}}}') if i > 0 else pp
for i, pp in enumerate(prompt)
]
edit_image, edit_image_mask = process_edit_image(
input_image, input_mask, task, max_seq_len=self.max_seq_len)
image, image_mask = edit_image[tar_index], edit_image_mask[
tar_index]
edit_image, edit_image_mask = [edit_image], [edit_image_mask]
else:
edit_image = edit_image_mask = [[]]
image = torch.zeros(
size=[3, int(output_height),
int(output_width)])
image_mask = torch.ones(
size=[1, int(output_height),
int(output_width)])
if not isinstance(prompt, list):
prompt = [prompt]
image, image_mask, prompt = [image], [image_mask], [prompt]
assert check_list_of_list(prompt) and check_list_of_list(
edit_image) and check_list_of_list(edit_image_mask)
# Assign Negative Prompt
if isinstance(negative_prompt, list):
negative_prompt = negative_prompt[0]
assert isinstance(negative_prompt, str)
n_prompt = copy.deepcopy(prompt)
for nn_p_id, nn_p in enumerate(n_prompt):
assert isinstance(nn_p, list)
n_prompt[nn_p_id][-1] = negative_prompt
is_txt_image = sum([len(e_i) for e_i in edit_image]) < 1
image = to_device(image)
refiner_scale = kwargs.pop("refiner_scale", 0.0)
refiner_prompt = kwargs.pop("refiner_prompt", "")
use_ace = kwargs.pop("use_ace", True)
# refiner_scale <= 0: use ACE itself as the txt2img generator.
if use_ace and (not is_txt_image or refiner_scale <= 0):
ctx, null_ctx = {}, {}
# Get Noise Shape
self.dynamic_load(self.first_stage_model, 'first_stage_model')
x = self.encode_first_stage(image)
self.dynamic_unload(self.first_stage_model,
'first_stage_model',
skip_loaded=not self.use_dynamic_model)
noise = [
torch.empty(*i.shape, device=we.device_id).normal_(generator=g)
for i in x
]
noise, x_shapes = pack_imagelist_into_tensor(noise)
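# pack_imagelist_into_tensor merges the variable-sized per-image latents
# into a single batched tensor; x_shapes records each latent's shape so the
# result can be unpacked back into a list after sampling.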
ctx['x_shapes'] = null_ctx['x_shapes'] = x_shapes
image_mask = to_device(image_mask, strict=False)
cond_mask = [self.interpolate_func(i) for i in image_mask
] if image_mask is not None else [None] * len(image)
ctx['x_mask'] = null_ctx['x_mask'] = cond_mask
# Encode Prompt
self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
function_name, dtype = self.get_function_info(self.cond_stage_model)
cont, cont_mask = getattr(get_model(self.cond_stage_model),
function_name)(prompt)
cont, cont_mask = self.cond_stage_embeddings(prompt, edit_image, cont,
cont_mask)
null_cont, null_cont_mask = getattr(get_model(self.cond_stage_model),
function_name)(n_prompt)
null_cont, null_cont_mask = self.cond_stage_embeddings(
prompt, edit_image, null_cont, null_cont_mask)
self.dynamic_unload(self.cond_stage_model,
'cond_stage_model',
skip_loaded=not self.use_dynamic_model)
ctx['crossattn'] = cont
null_ctx['crossattn'] = null_cont
# Encode Edit Images
self.dynamic_load(self.first_stage_model, 'first_stage_model')
edit_image = [to_device(i, strict=False) for i in edit_image]
edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
e_img, e_mask = [], []
for u, m in zip(edit_image, edit_image_mask):
if u is None:
continue
if m is None:
m = [None] * len(u)
e_img.append(self.encode_first_stage(u, **kwargs))
e_mask.append([self.interpolate_func(i) for i in m])
self.dynamic_unload(self.first_stage_model,
'first_stage_model',
skip_loaded=not self.use_dynamic_model)
null_ctx['edit'] = ctx['edit'] = e_img
null_ctx['edit_mask'] = ctx['edit_mask'] = e_mask
# Diffusion Process
self.dynamic_load(self.diffusion_model, 'diffusion_model')
function_name, dtype = self.get_function_info(self.diffusion_model)
with torch.autocast('cuda',
enabled=dtype in ('float16', 'bfloat16'),
dtype=getattr(torch, dtype)):
latent = self.diffusion.sample(
noise=noise,
sampler=sampler,
model=get_model(self.diffusion_model),
model_kwargs=[{
'cond':
ctx,
'mask':
cont_mask,
'text_position_embeddings':
self.text_position_embeddings.pos if hasattr(
self.text_position_embeddings, 'pos') else None
}, {
'cond':
null_ctx,
'mask':
null_cont_mask,
'text_position_embeddings':
self.text_position_embeddings.pos if hasattr(
self.text_position_embeddings, 'pos') else None
}] if guide_scale is not None and guide_scale > 1 else {
'cond':
null_ctx,
'mask':
cont_mask,
'text_position_embeddings':
self.text_position_embeddings.pos if hasattr(
self.text_position_embeddings, 'pos') else None
},
steps=sample_steps,
show_progress=True,
seed=seed,
guide_scale=guide_scale,
guide_rescale=guide_rescale,
return_intermediate=None,
**kwargs)
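# When guide_scale > 1, model_kwargs above is a [conditional, unconditional]
# pair and the sampler applies classifier-free guidance with guide_scale /
# guide_rescale; otherwise a single kwargs dict is passed.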
if self.use_dynamic_model:
self.dynamic_unload(self.diffusion_model,
'diffusion_model',
skip_loaded=not self.use_dynamic_model)
# Decode to Pixel Space
self.dynamic_load(self.first_stage_model, 'first_stage_model')
samples = unpack_tensor_into_imagelist(latent, x_shapes)
x_samples = self.decode_first_stage(samples)
self.dynamic_unload(self.first_stage_model,
'first_stage_model',
skip_loaded=not self.use_dynamic_model)
x_samples = [x.squeeze(0) for x in x_samples]
else:
x_samples = image
if self.refiner_module and refiner_scale > 0:
if is_txt_image:
random.shuffle(self.ace_prompt)
input_refine_prompt = [self.ace_prompt[0] + refiner_prompt if p[0] == "" else p[0] for p in prompt]
input_refine_scale = -1.
else:
input_refine_prompt = [p[0].replace("{image}", "") + " " + refiner_prompt for p in prompt]
input_refine_scale = refiner_scale
print(input_refine_prompt)
x_samples = self.refiner_module.refine(x_samples,
reverse_scale=input_refine_scale,
prompt=input_refine_prompt,
seed=seed,
use_dynamic_model=self.use_dynamic_model)
imgs = [
torch.clamp((x_i.float() + 1.0) / 2.0 + self.decoder_bias / 255,
min=0.0,
max=1.0).squeeze(0).permute(1, 2, 0).cpu().numpy()
for x_i in x_samples
]
imgs = [Image.fromarray((img * 255).astype(np.uint8)) for img in imgs]
return imgs
def cond_stage_embeddings(self, prompt, edit_image, cont, cont_mask):
if self.use_text_pos_embeddings and not torch.sum(
self.text_position_embeddings.pos) > 0:
identifier_cont, _ = getattr(get_model(self.cond_stage_model),
'encode')(self.text_identifiers,
return_mask=True)
self.text_position_embeddings.load_state_dict(
{'pos': identifier_cont[:, 0, :]})
cont_, cont_mask_ = [], []
for pp, edit, c, cm in zip(prompt, edit_image, cont, cont_mask):
if isinstance(pp, list):
cont_.append([c[-1], *c] if len(edit) > 0 else [c[-1]])
cont_mask_.append([cm[-1], *cm] if len(edit) > 0 else [cm[-1]])
else:
raise NotImplementedError
return cont_, cont_mask_
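# Hypothetical end-to-end sketch (not part of the original file; the config
# path and the Config(load=True, ...) signature are assumptions based on
# run_gradio.py):
#   from scepter.modules.utils.config import Config
#   cfg = Config(load=True, cfg_file='chatbot/config/models/ace_0.6b_512.yaml')
#   ace = ACEInference()
#   ace.init_from_cfg(cfg)
#   out = ace(image=[pil_image], mask=[None], prompt='{image} let the man smile', seed=6666)
#   out[0].save('edited.png')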
================================================
FILE: chatbot/config/chatbot_ui.yaml
================================================
WORK_DIR: ./cache/chatbot
FILE_SYSTEM:
- NAME: "HuggingfaceFs"
TEMP_DIR: ./cache
- NAME: "ModelscopeFs"
TEMP_DIR: ./cache
- NAME: "LocalFs"
TEMP_DIR: ./cache
- NAME: "HttpFs"
TEMP_DIR: ./cache
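# Each backend above resolves one URI scheme (the ms:// paths below are
# presumably fetched via ModelscopeFs) and caches downloads under its TEMP_DIR.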
#
ENABLE_I2V: False
#
MODEL:
EDIT_MODEL:
MODEL_CFG_DIR: chatbot/config/models/
DEFAULT: ace_0.6b_512
I2V:
MODEL_NAME: CogVideoX-5b-I2V
MODEL_DIR: ms://ZhipuAI/CogVideoX-5b-I2V/
CAPTIONER:
MODEL_NAME: InternVL2-2B
MODEL_DIR: ms://OpenGVLab/InternVL2-2B/
PROMPT: '<image>\nThis image is the first frame of a video. Based on this image, please imagine what changes may occur in the next few seconds of the video. Please output brief description, such as "a dog running" or "a person turns to left". No more than 30 words.'
ENHANCER:
MODEL_NAME: Meta-Llama-3.1-8B-Instruct
MODEL_DIR: ms://LLM-Research/Meta-Llama-3.1-8B-Instruct/
================================================
FILE: chatbot/config/models/ace_0.6b_512.yaml
================================================
NAME: ACE_0.6B_512
IS_DEFAULT: False
DEFAULT_PARAS:
PARAS:
#
INPUT:
INPUT_IMAGE:
INPUT_MASK:
TASK:
PROMPT: ""
NEGATIVE_PROMPT: ""
OUTPUT_HEIGHT: 512
OUTPUT_WIDTH: 512
SAMPLER: ddim
SAMPLE_STEPS: 20
GUIDE_SCALE: 4.5
GUIDE_RESCALE: 0.5
SEED: -1
TAR_INDEX: 0
OUTPUT:
LATENT:
IMAGES:
SEED:
MODULES_PARAS:
FIRST_STAGE_MODEL:
FUNCTION:
- NAME: encode
DTYPE: float16
INPUT: ["IMAGE"]
- NAME: decode
DTYPE: float16
INPUT: ["LATENT"]
#
DIFFUSION_MODEL:
FUNCTION:
- NAME: forward
DTYPE: float16
INPUT: ["SAMPLE_STEPS", "SAMPLE", "GUIDE_SCALE"]
#
COND_STAGE_MODEL:
FUNCTION:
- NAME: encode_list
DTYPE: bfloat16
INPUT: ["PROMPT"]
#
MODEL:
NAME: LdmACE
PRETRAINED_MODEL:
IGNORE_KEYS: [ ]
SCALE_FACTOR: 0.18215
SIZE_FACTOR: 8
DECODER_BIAS: 0.5
DEFAULT_N_PROMPT: ""
TEXT_IDENTIFIER: [ '{image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]
USE_TEXT_POS_EMBEDDINGS: True
#
DIFFUSION:
NAME: ACEDiffusion
PREDICTION_TYPE: eps
MIN_SNR_GAMMA:
NOISE_SCHEDULER:
NAME: LinearScheduler
NUM_TIMESTEPS: 1000
BETA_MIN: 0.0001
BETA_MAX: 0.02
#
DIFFUSION_MODEL:
NAME: DiTACE
PRETRAINED_MODEL: ms://iic/ACE-0.6B-512px@models/dit/ace_0.6b_512px.pth
IGNORE_KEYS: [ ]
PATCH_SIZE: 2
IN_CHANNELS: 4
HIDDEN_SIZE: 1152
DEPTH: 28
NUM_HEADS: 16
MLP_RATIO: 4.0
PRED_SIGMA: True
DROP_PATH: 0.0
WINDOW_SIZE: 0
Y_CHANNELS: 4096
MAX_SEQ_LEN: 1024
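# 1024 tokens corresponds to a 512x512 image: 512 / 8 (VAE) / 2 (patch) = 32 per side, and 32 * 32 = 1024.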
QK_NORM: True
USE_GRAD_CHECKPOINT: True
ATTENTION_BACKEND: flash_attn
#
FIRST_STAGE_MODEL:
NAME: AutoencoderKL
EMBED_DIM: 4
PRETRAINED_MODEL: ms://iic/ACE-0.6B-512px@models/vae/vae.bin
IGNORE_KEYS: []
#
ENCODER:
NAME: Encoder
CH: 128
OUT_CH: 3
NUM_RES_BLOCKS: 2
IN_CHANNELS: 3
ATTN_RESOLUTIONS: [ ]
CH_MULT: [ 1, 2, 4, 4 ]
Z_CHANNELS: 4
DOUBLE_Z: True
DROPOUT: 0.0
RESAMP_WITH_CONV: True
#
DECODER:
NAME: Decoder
CH: 128
OUT_CH: 3
NUM_RES_BLOCKS: 2
IN_CHANNELS: 3
ATTN_RESOLUTIONS: [ ]
CH_MULT: [ 1, 2, 4, 4 ]
Z_CHANNELS: 4
DROPOUT: 0.0
RESAMP_WITH_CONV: True
GIVE_PRE_END: False
TANH_OUT: False
#
COND_STAGE_MODEL:
NAME: ACETextEmbedder
PRETRAINED_MODEL: ms://iic/ACE-0.6B-512px@models/text_encoder/t5-v1_1-xxl/
TOKENIZER_PATH: ms://iic/ACE-0.6B-512px@models/tokenizer/t5-v1_1-xxl
LENGTH: 120
T5_DTYPE: bfloat16
ADDED_IDENTIFIER: [ '{image}', '{caption}', '{mask}', '{ref_image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]
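# The {imageN} identifiers are added to the T5 tokenizer so prompts can
# reference multiple input images by position; their encodings feed the
# TEXT_IDENTIFIER positional embeddings configured above.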
CLEAN: whitespace
USE_GRAD: False
================================================
FILE: chatbot/example.py
================================================
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from scepter.modules.utils.file_system import FS
from PIL import Image
def download_image(image, local_path=None):
if not FS.exists(local_path):
local_path = FS.get_from(image, local_path=local_path)
return local_path
def blank_image():
return Image.new('RGBA', (128, 128), (0, 0, 0, 0))
def get_examples(cache_dir):
print('Downloading Examples ...')
bl_img = blank_image()
examples = [
[
'Facial Editing',
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/e33edc106953.png?raw=true',
os.path.join(cache_dir, 'examples/e33edc106953.png')), bl_img,
bl_img, '{image} let the man smile', 6666
],
[
'Facial Editing',
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/5d2bcc91a3e9.png?raw=true',
os.path.join(cache_dir, 'examples/5d2bcc91a3e9.png')), bl_img,
bl_img, 'let the man in {image} wear sunglasses', 9999
],
[
'Facial Editing',
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/3a52eac708bd.png?raw=true',
os.path.join(cache_dir, 'examples/3a52eac708bd.png')), bl_img,
bl_img, '{image} red hair', 9999
],
[
'Facial Editing',
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/3f4dc464a0ea.png?raw=true',
os.path.join(cache_dir, 'examples/3f4dc464a0ea.png')), bl_img,
bl_img, '{image} let the man serious', 99999
],
[
'Controllable Generation',
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/131ca90fd2a9.png?raw=true',
os.path.join(cache_dir,
'examples/131ca90fd2a9.png')), bl_img, bl_img,
'"A person sits contemplatively on the ground, surrounded by falling autumn leaves. Dressed in a green sweater and dark blue pants, they rest their chin on their hand, exuding a relaxed demeanor. Their stylish checkered slip-on shoes add a touch of flair, while a black purse lies in their lap. The backdrop of muted brown enhances the warm, cozy atmosphere of the scene." , generate the image that corresponds to the given scribble {image}.',
613725
],
[
'Render Text',
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/33e9f27c2c48.png?raw=true',
os.path.join(cache_dir, 'examples/33e9f27c2c48.png')),
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/33e9f27c2c48_mask.png?raw=true',
os.path.join(cache_dir,
'examples/33e9f27c2c48_mask.png')), bl_img,
'Put the text "C A T" at the position marked by mask in the {image}',
6666
],
[
'Style Transfer',
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/9e73e7eeef55.png?raw=true',
os.path.join(cache_dir, 'examples/9e73e7eeef55.png')), bl_img,
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/2e02975293d6.png?raw=true',
os.path.join(cache_dir, 'examples/2e02975293d6.png')),
'edit {image} based on the style of {image1} ', 99999
],
[
'Outpainting',
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/f2b22c08be3f.png?raw=true',
os.path.join(cache_dir, 'examples/f2b22c08be3f.png')),
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/f2b22c08be3f_mask.png?raw=true',
os.path.join(cache_dir,
'examples/f2b22c08be3f_mask.png')), bl_img,
'Could the {image} be widened within the space designated by mask, while retaining the original?',
6666
],
[
'Image Segmentation',
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/db3ebaa81899.png?raw=true',
os.path.join(cache_dir, 'examples/db3ebaa81899.png')), bl_img,
bl_img, '{image} Segmentation', 6666
],
[
'Depth Estimation',
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/f1927c4692ba.png?raw=true',
os.path.join(cache_dir, 'examples/f1927c4692ba.png')), bl_img,
bl_img, '{image} Depth Estimation', 6666
],
[
'Pose Estimation',
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/014e5bf3b4d1.png?raw=true',
os.path.join(cache_dir, 'examples/014e5bf3b4d1.png')), bl_img,
bl_img, '{image} distinguish the poses of the figures', 999999
],
[
'Scribble Extraction',
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/5f59a202f8ac.png?raw=true',
os.path.join(cache_dir, 'examples/5f59a202f8ac.png')), bl_img,
bl_img, 'Generate a scribble of {image}, please.', 6666
],
[
'Mosaic',
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/3a2f52361eea.png?raw=true',
os.path.join(cache_dir, 'examples/3a2f52361eea.png')), bl_img,
bl_img, 'Adapt {image} into a mosaic representation.', 6666
],
[
'Edge map Extraction',
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/b9d1e519d6e5.png?raw=true',
os.path.join(cache_dir, 'examples/b9d1e519d6e5.png')), bl_img,
bl_img, 'Get the edge-enhanced result for {image}.', 6666
],
[
'Grayscale',
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/c4ebbe2ba29b.png?raw=true',
os.path.join(cache_dir, 'examples/c4ebbe2ba29b.png')), bl_img,
bl_img, 'transform {image} into a black and white one', 6666
],
[
'Contour Extraction',
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/19652d0f6c4b.png?raw=true',
os.path.join(cache_dir,
'examples/19652d0f6c4b.png')), bl_img, bl_img,
'Would you be able to make a contour picture from {image} for me?',
6666
],
[
'Controllable Generation',
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/249cda2844b7.png?raw=true',
os.path.join(cache_dir,
'examples/249cda2844b7.png')), bl_img, bl_img,
'Following the segmentation outcome in mask of {image}, develop a real-life image using the explanatory note in "a mighty cat lying on the bed".',
6666
],
[
'Controllable Generation',
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/411f6c4b8e6c.png?raw=true',
os.path.join(cache_dir,
'examples/411f6c4b8e6c.png')), bl_img, bl_img,
'use the depth map {image} and the text caption "a cute white cat" to create a corresponding graphic image',
999999
],
[
'Controllable Generation',
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/a35c96ed137a.png?raw=true',
os.path.join(cache_dir,
'examples/a35c96ed137a.png')), bl_img, bl_img,
'help translate this posture schema {image} into a colored image based on the context I provided "A beautiful woman Climbing the climbing wall, wearing a harness and climbing gear, skillfully maneuvering up the wall with her back to the camera, with a safety rope."',
3599999
],
[
'Controllable Generation',
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/dcb2fc86f1ce.png?raw=true',
os.path.join(cache_dir,
'examples/dcb2fc86f1ce.png')), bl_img, bl_img,
'Transform and generate an image using mosaic {image} and "Monarch butterflies gracefully perch on vibrant purple flowers, showcasing their striking orange and black wings in a lush garden setting." description',
6666
],
[
'Controllable Generation',
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/4cd4ee494962.png?raw=true',
os.path.join(cache_dir,
'examples/4cd4ee494962.png')), bl_img, bl_img,
'make this {image} colorful as per the "beautiful sunflowers"',
6666
],
[
'Controllable Generation',
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/a47e3a9cd166.png?raw=true',
os.path.join(cache_dir,
'examples/a47e3a9cd166.png')), bl_img, bl_img,
'Take the edge conscious {image} and the written guideline "A whimsical animated character is depicted holding a delectable cake adorned with blue and white frosting and a drizzle of chocolate. The character wears a yellow headband with a bow, matching a cozy yellow sweater. Her dark hair is styled in a braid, tied with a yellow ribbon. With a golden fork in hand, she stands ready to enjoy a slice, exuding an air of joyful anticipation. The scene is creatively rendered with a charming and playful aesthetic." and produce a realistic image.',
613725
],
[
'Controllable Generation',
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/d890ed8a3ac2.png?raw=true',
os.path.join(cache_dir,
'examples/d890ed8a3ac2.png')), bl_img, bl_img,
'creating a vivid image based on {image} and description "This image features a delicious rectangular tart with a flaky, golden-brown crust. The tart is topped with evenly sliced tomatoes, layered over a creamy cheese filling. Aromatic herbs are sprinkled on top, adding a touch of green and enhancing the visual appeal. The background includes a soft, textured fabric and scattered white flowers, creating an elegant and inviting presentation. Bright red tomatoes in the upper right corner hint at the fresh ingredients used in the dish."',
6666
],
[
'Image Denoising',
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/0844a686a179.png?raw=true',
os.path.join(cache_dir,
'examples/0844a686a179.png')), bl_img, bl_img,
'Eliminate noise interference in {image} and maximize the crispness to obtain superior high-definition quality',
6666
],
[
'Inpainting',
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/fa91b6b7e59b.png?raw=true',
os.path.join(cache_dir, 'examples/fa91b6b7e59b.png')),
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/fa91b6b7e59b_mask.png?raw=true',
os.path.join(cache_dir,
'examples/fa91b6b7e59b_mask.png')), bl_img,
'Ensure to overhaul the parts of the {image} indicated by the mask.',
6666
],
[
'Inpainting',
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/632899695b26.png?raw=true',
os.path.join(cache_dir, 'examples/632899695b26.png')),
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/632899695b26_mask.png?raw=true',
os.path.join(cache_dir,
'examples/632899695b26_mask.png')), bl_img,
'Refashion the mask portion of {image} in accordance with "A yellow egg with a smiling face painted on it"',
6666
],
[
'General Editing',
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/354d17594afe.png?raw=true',
os.path.join(cache_dir,
'examples/354d17594afe.png')), bl_img, bl_img,
'{image} change the dog\'s posture to walking in the water, and change the background to green plants and a pond.',
6666
],
[
'General Editing',
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/38946455752b.png?raw=true',
os.path.join(cache_dir,
'examples/38946455752b.png')), bl_img, bl_img,
'{image} change the color of the dress from white to red and the model\'s hair color from red brown to blonde. Other parts remain unchanged',
6669
],
[
'Facial Editing',
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/3ba5202f0cd8.png?raw=true',
os.path.join(cache_dir,
'examples/3ba5202f0cd8.png')), bl_img, bl_img,
'Keep the same facial feature in @3ba5202f0cd8, change the woman\'s clothing from a Blue denim jacket to a white turtleneck sweater and adjust her posture so that she is supporting her chin with both hands. Other aspects, such as background, hairstyle, facial expression, etc, remain unchanged.',
99999
],
[
'Facial Editing',
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/369365b94725.png?raw=true',
os.path.join(cache_dir, 'examples/369365b94725.png')), bl_img,
bl_img, '{image} Make her looking at the camera', 6666
],
[
'Facial Editing',
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/92751f2e4a0e.png?raw=true',
os.path.join(cache_dir, 'examples/92751f2e4a0e.png')), bl_img,
bl_img, '{image} Remove the smile from his face', 9899999
],
[
'Remove Text',
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/8530a6711b2e.png?raw=true',
os.path.join(cache_dir, 'examples/8530a6711b2e.png')), bl_img,
bl_img, 'Aim to remove any textual element in {image}', 6666
],
[
'Remove Text',
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/c4d7fb28f8f6.png?raw=true',
os.path.join(cache_dir, 'examples/c4d7fb28f8f6.png')),
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/c4d7fb28f8f6_mask.png?raw=true',
os.path.join(cache_dir,
'examples/c4d7fb28f8f6_mask.png')), bl_img,
'Rub out any text found in the mask sector of the {image}.', 6666
],
[
'Remove Object',
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/e2f318fa5e5b.png?raw=true',
os.path.join(cache_dir,
'examples/e2f318fa5e5b.png')), bl_img, bl_img,
'Remove the unicorn in this {image}, ensuring a smooth edit.',
99999
],
[
'Remove Object',
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/1ae96d8aca00.png?raw=true',
os.path.join(cache_dir, 'examples/1ae96d8aca00.png')),
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/1ae96d8aca00_mask.png?raw=true',
os.path.join(cache_dir, 'examples/1ae96d8aca00_mask.png')),
bl_img, 'Discard the contents of the mask area from {image}.', 99999
],
[
'Add Object',
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/80289f48e511.png?raw=true',
os.path.join(cache_dir, 'examples/80289f48e511.png')),
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/80289f48e511_mask.png?raw=true',
os.path.join(cache_dir,
'examples/80289f48e511_mask.png')), bl_img,
'add a Hot Air Balloon into the {image}, per the mask', 613725
],
[
'Style Transfer',
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/d725cb2009e8.png?raw=true',
os.path.join(cache_dir, 'examples/d725cb2009e8.png')), bl_img,
bl_img, 'Change the style of {image} to colored pencil style', 99999
],
[
'Style Transfer',
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/e0f48b3fd010.png?raw=true',
os.path.join(cache_dir, 'examples/e0f48b3fd010.png')), bl_img,
bl_img, 'make {image} to Walt Disney Animation style', 99999
],
[
'Try On',
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/ee4ca60b8c96.png?raw=true',
os.path.join(cache_dir, 'examples/ee4ca60b8c96.png')),
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/ee4ca60b8c96_mask.png?raw=true',
os.path.join(cache_dir, 'examples/ee4ca60b8c96_mask.png')),
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/ebe825bbfe3c.png?raw=true',
os.path.join(cache_dir, 'examples/ebe825bbfe3c.png')),
'Change the cloth in {image} to the one in {image1}', 99999
],
[
'Workflow',
download_image(
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/cb85353c004b.png?raw=true',
os.path.join(cache_dir, 'examples/cb85353c004b.png')), bl_img,
bl_img, '<workflow> ice cream {image}', 99999
],
]
print('Finish. Start building UI ...')
return examples
================================================
FILE: chatbot/infer.py
================================================
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import math
import random
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from scepter.modules.model.registry import DIFFUSIONS
from scepter.modules.utils.distribute import we
from scepter.modules.utils.logger import get_logger
from scepter.modules.inference.diffusion_inference import DiffusionInference, get_model
from modules.model.utils.basic_utils import (
check_list_of_list,
pack_imagelist_into_tensor_v2 as pack_imagelist_into_tensor,
to_device,
unpack_tensor_into_imagelist
)
def process_edit_image(images,
masks,
tasks,
max_seq_len=1024,
max_aspect_ratio=4,
d=16,
**kwargs):
if not isinstance(images, list):
images = [images]
if not isinstance(masks, list):
masks = [masks]
if not isinstance(tasks, list):
tasks = [tasks]
img_tensors = []
mask_tensors = []
for img, mask, task in zip(images, masks, tasks):
if mask is None or mask == '':
mask = Image.new('L', img.size, 0)
W, H = img.size
if H / W > max_aspect_ratio:
img = TF.center_crop(img, [int(max_aspect_ratio * W), W])
mask = TF.center_crop(mask, [int(max_aspect_ratio * W), W])
elif W / H > max_aspect_ratio:
img = TF.center_crop(img, [H, int(max_aspect_ratio * H)])
mask = TF.center_crop(mask, [H, int(max_aspect_ratio * H)])
H, W = img.height, img.width
scale = min(1.0, math.sqrt(max_seq_len / ((H / d) * (W / d))))
rH = int(H * scale) // d * d # ensure divisible by d
rW = int(W * scale) // d * d
img = TF.resize(img, (rH, rW),
interpolation=TF.InterpolationMode.BICUBIC)
mask = TF.resize(mask, (rH, rW),
interpolation=TF.InterpolationMode.NEAREST_EXACT)
mask = np.asarray(mask)
mask = np.where(mask > 128, 1, 0)
mask = mask.astype(
np.float32) if np.any(mask) else np.ones_like(mask).astype(
np.float32)
img_tensor = TF.to_tensor(img).to(we.device_id)
img_tensor = TF.normalize(img_tensor,
mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5])
mask_tensor = TF.to_tensor(mask).to(we.device_id)
if task in ['inpainting', 'Try On', 'Inpainting']:
mask_indicator = mask_tensor.repeat(3, 1, 1)
img_tensor[mask_indicator == 1] = -1.0
img_tensors.append(img_tensor)
mask_tensors.append(mask_tensor)
return img_tensors, mask_tensors
class TextEmbedding(nn.Module):
def __init__(self, embedding_shape):
super().__init__()
self.pos = nn.Parameter(data=torch.zeros(embedding_shape))
class ACEInference(DiffusionInference):
def __init__(self, logger=None):
if logger is None:
logger = get_logger(name='scepter')
self.logger = logger
self.loaded_model = {}
self.loaded_model_name = [
'diffusion_model', 'first_stage_model', 'cond_stage_model'
]
def init_from_cfg(self, cfg):
self.name = cfg.NAME
self.is_default = cfg.get('IS_DEFAULT', False)
module_paras = self.load_default(cfg.get('DEFAULT_PARAS', None))
assert cfg.have('MODEL')
self.diffusion_model = self.infer_model(
cfg.MODEL.DIFFUSION_MODEL, module_paras.get(
'DIFFUSION_MODEL',
None)) if cfg.MODEL.have('DIFFUSION_MODEL') else None
self.first_stage_model = self.infer_model(
cfg.MODEL.FIRST_STAGE_MODEL,
module_paras.get(
'FIRST_STAGE_MODEL',
None)) if cfg.MODEL.have('FIRST_STAGE_MODEL') else None
self.cond_stage_model = self.infer_model(
cfg.MODEL.COND_STAGE_MODEL,
module_paras.get(
'COND_STAGE_MODEL',
None)) if cfg.MODEL.have('COND_STAGE_MODEL') else None
self.diffusion = DIFFUSIONS.build(cfg.MODEL.DIFFUSION,
logger=self.logger)
self.interpolate_func = lambda x: (F.interpolate(
x.unsqueeze(0),
scale_factor=1 / self.size_factor,
mode='nearest-exact') if x is not None else None)
self.text_identifiers = cfg.MODEL.get('TEXT_IDENTIFIER', [])
self.use_text_pos_embeddings = cfg.MODEL.get('USE_TEXT_POS_EMBEDDINGS',
False)
if self.use_text_pos_embeddings:
self.text_position_embeddings = TextEmbedding(
(10, 4096)).eval().requires_grad_(False).to(we.device_id)
else:
self.text_position_embeddings = None
self.max_seq_len = cfg.MODEL.DIFFUSION_MODEL.MAX_SEQ_LEN
self.scale_factor = cfg.get('SCALE_FACTOR', 0.18215)
self.size_factor = cfg.get('SIZE_FACTOR', 8)
self.decoder_bias = cfg.get('DECODER_BIAS', 0)
self.default_n_prompt = cfg.get('DEFAULT_N_PROMPT', '')
self.dynamic_load(self.first_stage_model, 'first_stage_model')
self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
self.dynamic_load(self.diffusion_model, 'diffusion_model')
@torch.no_grad()
def encode_first_stage(self, x, **kwargs):
_, dtype = self.get_function_info(self.first_stage_model, 'encode')
with torch.autocast('cuda',
enabled=(dtype != 'float32'),
dtype=getattr(torch, dtype)):
z = [
self.scale_factor * get_model(self.first_stage_model)._encode(
i.unsqueeze(0).to(getattr(torch, dtype))) for i in x
]
return z
@torch.no_grad()
def decode_first_stage(self, z):
_, dtype = self.get_function_info(self.first_stage_model, 'decode')
with torch.autocast('cuda',
enabled=(dtype != 'float32'),
dtype=getattr(torch, dtype)):
x = [
get_model(self.first_stage_model)._decode(
1. / self.scale_factor * i.to(getattr(torch, dtype)))
for i in z
]
return x
@torch.no_grad()
def __call__(self,
image=None,
mask=None,
prompt='',
task=None,
negative_prompt='',
output_height=512,
output_width=512,
sampler='ddim',
sample_steps=20,
guide_scale=4.5,
guide_rescale=0.5,
seed=-1,
history_io=None,
tar_index=0,
**kwargs):
input_image, input_mask = image, mask
g = torch.Generator(device=we.device_id)
seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)
g.manual_seed(int(seed))
if input_image is not None:
assert isinstance(input_image, list) and isinstance(
input_mask, list)
if task is None:
task = [''] * len(input_image)
if not isinstance(prompt, list):
prompt = [prompt] * len(input_image)
if history_io is not None and len(history_io) > 0:
his_image, his_masks, his_prompt, his_task = history_io[
'image'], history_io['mask'], history_io[
'prompt'], history_io['task']
assert len(his_image) == len(his_masks) == len(
his_prompt) == len(his_task)
input_image = his_image + input_image
input_mask = his_masks + input_mask
task = his_task + task
prompt = his_prompt + [prompt[-1]]
prompt = [
pp.replace('{image}', f'{{image{i}}}') if i > 0 else pp
for i, pp in enumerate(prompt)
]
edit_image, edit_image_mask = process_edit_image(
input_image, input_mask, task, max_seq_len=self.max_seq_len)
image, image_mask = edit_image[tar_index], edit_image_mask[
tar_index]
edit_image, edit_image_mask = [edit_image], [edit_image_mask]
else:
edit_image = edit_image_mask = [[]]
image = torch.zeros(
size=[3, int(output_height),
int(output_width)])
image_mask = torch.ones(
size=[1, int(output_height),
int(output_width)])
if not isinstance(prompt, list):
prompt = [prompt]
image, image_mask, prompt = [image], [image_mask], [prompt]
assert check_list_of_list(prompt) and check_list_of_list(
edit_image) and check_list_of_list(edit_image_mask)
# Assign Negative Prompt
if isinstance(negative_prompt, list):
negative_prompt = negative_prompt[0]
assert isinstance(negative_prompt, str)
n_prompt = copy.deepcopy(prompt)
for nn_p_id, nn_p in enumerate(n_prompt):
assert isinstance(nn_p, list)
n_prompt[nn_p_id][-1] = negative_prompt
ctx, null_ctx = {}, {}
# Get Noise Shape
image = to_device(image)
x = self.encode_first_stage(image)
noise = [
torch.empty(*i.shape, device=we.device_id).normal_(generator=g)
for i in x
]
noise, x_shapes = pack_imagelist_into_tensor(noise)
ctx['x_shapes'] = null_ctx['x_shapes'] = x_shapes
image_mask = to_device(image_mask, strict=False)
cond_mask = [self.interpolate_func(i) for i in image_mask
] if image_mask is not None else [None] * len(image)
ctx['x_mask'] = null_ctx['x_mask'] = cond_mask
# Encode Prompt
function_name, dtype = self.get_function_info(self.cond_stage_model)
cont, cont_mask = getattr(get_model(self.cond_stage_model),
function_name)(prompt)
cont, cont_mask = self.cond_stage_embeddings(prompt, edit_image, cont,
cont_mask)
null_cont, null_cont_mask = getattr(get_model(self.cond_stage_model),
function_name)(n_prompt)
null_cont, null_cont_mask = self.cond_stage_embeddings(
prompt, edit_image, null_cont, null_cont_mask)
ctx['crossattn'] = cont
null_ctx['crossattn'] = null_cont
# Encode Edit Images
edit_image = [to_device(i, strict=False) for i in edit_image]
edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
e_img, e_mask = [], []
for u, m in zip(edit_image, edit_image_mask):
if u is None:
continue
if m is None:
m = [None] * len(u)
e_img.append(self.encode_first_stage(u, **kwargs))
e_mask.append([self.interpolate_func(i) for i in m])
null_ctx['edit'] = ctx['edit'] = e_img
null_ctx['edit_mask'] = ctx['edit_mask'] = e_mask
# Diffusion Process
function_name, dtype = self.get_function_info(self.diffusion_model)
with torch.autocast('cuda',
enabled=dtype in ('float16', 'bfloat16'),
dtype=getattr(torch, dtype)):
latent = self.diffusion.sample(
noise=noise,
sampler=sampler,
model=get_model(self.diffusion_model),
model_kwargs=[{
'cond':
ctx,
'mask':
cont_mask,
'text_position_embeddings':
self.text_position_embeddings.pos if hasattr(
self.text_position_embeddings, 'pos') else None
}, {
'cond':
null_ctx,
'mask':
null_cont_mask,
'text_position_embeddings':
self.text_position_embeddings.pos if hasattr(
self.text_position_embeddings, 'pos') else None
}] if guide_scale is not None and guide_scale > 1 else {
'cond':
null_ctx,
'mask':
cont_mask,
'text_position_embeddings':
self.text_position_embeddings.pos if hasattr(
self.text_position_embeddings, 'pos') else None
},
steps=sample_steps,
show_progress=True,
seed=seed,
guide_scale=guide_scale,
guide_rescale=guide_rescale,
return_intermediate=None,
**kwargs)
# Decode to Pixel Space
samples = unpack_tensor_into_imagelist(latent, x_shapes)
x_samples = self.decode_first_stage(samples)
imgs = [
torch.clamp((x_i + 1.0) / 2.0 + self.decoder_bias / 255,
min=0.0,
max=1.0).squeeze(0).permute(1, 2, 0).cpu().numpy()
for x_i in x_samples
]
imgs = [Image.fromarray((img * 255).astype(np.uint8)) for img in imgs]
return imgs
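# Arrange per-sample text embeddings: when edit images are present, the
# target (last) prompt embedding is duplicated at the front so the model
# sees [target, source_0, ..., source_n]; the position embeddings for the
# '{image}', '{image1}', ... identifiers are lazily cached on first use.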
def cond_stage_embeddings(self, prompt, edit_image, cont, cont_mask):
if self.use_text_pos_embeddings and not torch.sum(
self.text_position_embeddings.pos) > 0:
identifier_cont, _ = getattr(get_model(self.cond_stage_model),
'encode')(self.text_indentifers,
return_mask=True)
self.text_position_embeddings.load_state_dict(
{'pos': identifier_cont[:, 0, :]})
cont_, cont_mask_ = [], []
for pp, edit, c, cm in zip(prompt, edit_image, cont, cont_mask):
if isinstance(pp, list):
cont_.append([c[-1], *c] if len(edit) > 0 else [c[-1]])
cont_mask_.append([cm[-1], *cm] if len(edit) > 0 else [cm[-1]])
else:
raise NotImplementedError
return cont_, cont_mask_
================================================
FILE: chatbot/run_gradio.py
================================================
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import base64
import copy
import csv
import glob
import io
import os
import random
import re
import string
import sys
import threading
import warnings
import cv2
import gradio as gr
import numpy as np
import torch
import transformers
from PIL import Image
from transformers import AutoModel, AutoTokenizer
from scepter.modules.utils.config import Config
from scepter.modules.utils.directory import get_md5
from scepter.modules.utils.file_system import FS
from scepter.studio.utils.env import init_env
from importlib.metadata import version
from packaging.version import parse as parse_version
from ace_inference import ACEInference
from example import get_examples
from utils import load_image
csv.field_size_limit(sys.maxsize)
refresh_sty = '\U0001f504' # 🔄
clear_sty = '\U0001f5d1' # 🗑️
upload_sty = '\U0001f5bc' # 🖼️
sync_sty = '\U0001f4be' # 💾
chat_sty = '\U0001F4AC' # 💬
video_sty = '\U0001f3a5' # 🎥
lock = threading.Lock()
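# A single global lock serializes every GPU-bound pipeline call (editing,
# captioning, prompt enhancement, video generation) across concurrent
# Gradio callbacks, since the models are shared by all sessions.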
class ChatBotUI(object):
def __init__(self,
cfg_general_file,
is_debug=False,
language='en',
root_work_dir='./'):
try:
from diffusers import CogVideoXImageToVideoPipeline
from diffusers.utils import export_to_video
except Exception as e:
print(f"Import diffusers failed, please install or upgrade diffusers. Error information: {e}")
if isinstance(cfg_general_file, str):
cfg = Config(cfg_file=cfg_general_file)
else:
cfg = cfg_general_file
cfg.WORK_DIR = os.path.join(root_work_dir, cfg.WORK_DIR)
if not FS.exists(cfg.WORK_DIR):
FS.make_dir(cfg.WORK_DIR)
cfg = init_env(cfg)
self.cache_dir = cfg.WORK_DIR
self.chatbot_examples = get_examples(self.cache_dir) if not cfg.get('SKIP_EXAMPLES', False) else []
self.model_cfg_dir = cfg.MODEL.EDIT_MODEL.MODEL_CFG_DIR
self.model_yamls = glob.glob(os.path.join(self.model_cfg_dir,
'*.yaml'))
self.model_choices = dict()
self.default_model_name = ''
for i in self.model_yamls:
model_cfg = Config(load=True, cfg_file=i)
model_name = model_cfg.NAME
if model_cfg.IS_DEFAULT: self.default_model_name = model_name
self.model_choices[model_name] = model_cfg
print('Models: ', self.model_choices.keys())
assert len(self.model_choices) > 0
if self.default_model_name == "": self.default_model_name = list(self.model_choices.keys())[0]
self.model_name = self.default_model_name
self.pipe = ACEInference()
self.pipe.init_from_cfg(self.model_choices[self.default_model_name])
self.max_msgs = 20
self.enable_i2v = cfg.get('ENABLE_I2V', False)
self.gradio_version = parse_version(version('gradio'))
if self.enable_i2v:
self.i2v_model_dir = cfg.MODEL.I2V.MODEL_DIR
self.i2v_model_name = cfg.MODEL.I2V.MODEL_NAME
if self.i2v_model_name == 'CogVideoX-5b-I2V':
with FS.get_dir_to_local_dir(self.i2v_model_dir) as local_dir:
self.i2v_pipe = CogVideoXImageToVideoPipeline.from_pretrained(
local_dir, torch_dtype=torch.bfloat16).cuda()
else:
raise NotImplementedError
with FS.get_dir_to_local_dir(
cfg.MODEL.CAPTIONER.MODEL_DIR) as local_dir:
self.captioner = AutoModel.from_pretrained(
local_dir,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
use_flash_attn=True,
trust_remote_code=True).eval().cuda()
self.llm_tokenizer = AutoTokenizer.from_pretrained(
local_dir, trust_remote_code=True, use_fast=False)
self.llm_generation_config = dict(max_new_tokens=1024,
do_sample=True)
self.llm_prompt = cfg.LLM.PROMPT
self.llm_max_num = 2
with FS.get_dir_to_local_dir(
cfg.MODEL.ENHANCER.MODEL_DIR) as local_dir:
self.enhancer = transformers.pipeline(
'text-generation',
model=local_dir,
model_kwargs={'torch_dtype': torch.bfloat16},
device_map='auto',
)
sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
For example, outputting "a beautiful morning in the woods with the sun peeking through the trees" will trigger your partner bot to output a video of a forest morning, as described. You will be prompted by people looking to create detailed, amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
There are a few rules to follow:
You will only ever output a single video description per user request.
When modifications are requested, you should not simply make the description longer. You should refactor the entire description to integrate the suggestions.
Other times the user will not want modifications, but instead want a new video. In this case, you should ignore your previous conversation with the user.
Video descriptions must have the same number of words as the examples below. Extra words will be ignored.
"""
self.enhance_ctx = [
{
'role': 'system',
'content': sys_prompt
},
{
'role':
'user',
'content':
'Create an imaginative video descriptive caption or modify an earlier caption for the user input: "a girl is on the beach"',
},
{
'role':
'assistant',
'content':
"A radiant woman stands on a deserted beach, arms outstretched, wearing a beige trench coat, white blouse, light blue jeans, and chic boots, against a backdrop of soft sky and sea. Moments later, she is seen mid-twirl, arms exuberant, with the lighting suggesting dawn or dusk. Then, she runs along the beach, her attire complemented by an off-white scarf and black ankle boots, the tranquil sea behind her. Finally, she holds a paper airplane, her pose reflecting joy and freedom, with the ocean's gentle waves and the sky's soft pastel hues enhancing the serene ambiance.",
},
{
'role':
'user',
'content':
'Create an imaginative video descriptive caption or modify an earlier caption for the user input: "A man jogging on a football field"',
},
{
'role':
'assistant',
'content':
"A determined man in athletic attire, including a blue long-sleeve shirt, black shorts, and blue socks, jogs around a snow-covered soccer field, showcasing his solitary exercise in a quiet, overcast setting. His long dreadlocks, focused expression, and the serene winter backdrop highlight his dedication to fitness. As he moves, his attire, consisting of a blue sports sweatshirt, black athletic pants, gloves, and sneakers, grips the snowy ground. He is seen running past a chain-link fence enclosing the playground area, with a basketball hoop and children's slide, suggesting a moment of solitary exercise amidst the empty field.",
},
{
'role':
'user',
'content':
'Create an imaginative video descriptive caption or modify an earlier caption for the user input: "A woman is dancing, HD footage, close-up"',
},
{
'role':
'assistant',
'content':
'A young woman with her hair in an updo and wearing a teal hoodie stands against a light backdrop, initially looking over her shoulder with a contemplative expression. She then confidently makes a subtle dance move, suggesting rhythm and movement. Next, she appears poised and focused, looking directly at the camera. Her expression shifts to one of introspection as she gazes downward slightly. Finally, she dances with confidence, her left hand over her heart, symbolizing a poignant moment, all while dressed in the same teal hoodie against a plain, light-colored background.',
},
]
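# Few-shot chat context for the prompt enhancer: the system prompt plus
# three user/assistant exemplar pairs; each video-generation request is
# appended to a copy of this list before calling the enhancer.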
def create_ui(self):
css = '.chatbot.prose.md {opacity: 1.0 !important} #chatbot {opacity: 1.0 !important}'
with gr.Blocks(css=css,
title='Chatbot',
head='Chatbot',
analytics_enabled=False):
self.history = gr.State(value=[])
self.images = gr.State(value={})
self.history_result = gr.State(value={})
self.retry_msg = gr.State(value='')
with gr.Group():
self.ui_mode = gr.State(value='legacy')
with gr.Row(equal_height=True, visible=False) as self.chat_group:
with gr.Column(visible=True) as self.chat_page:
self.chatbot = gr.Chatbot(
height=600,
value=[],
bubble_full_width=False,
show_copy_button=True,
container=False,
placeholder='<strong>Chat Box</strong>')
with gr.Row():
self.clear_btn = gr.Button(clear_sty +
' Clear Chat',
size='sm')
with gr.Column(visible=False) as self.editor_page:
with gr.Tabs(visible=False) as self.upload_tabs:
with gr.Tab(id='ImageUploader',
label='Image Uploader',
visible=True) as self.upload_tab:
self.image_uploader = gr.Image(
height=550,
interactive=True,
type='pil',
image_mode='RGB',
sources=['upload'],
elem_id='image_uploader',
format='png')
with gr.Row():
self.sub_btn_1 = gr.Button(
value='Submit',
elem_id='upload_submit')
self.ext_btn_1 = gr.Button(value='Exit')
with gr.Tabs(visible=False) as self.edit_tabs:
with gr.Tab(id='ImageEditor',
label='Image Editor') as self.edit_tab:
self.mask_type = gr.Dropdown(
label='Mask Type',
choices=[
'Background', 'Composite',
'Outpainting'
],
value='Background')
self.mask_type_info = gr.HTML(
value=
"<div style='background-color: white; padding-left: 15px; color: grey;'>Background mode will not erase the visual content in the mask area</div>"
)
with gr.Accordion(
label='Outpainting Setting',
open=True,
visible=False) as self.outpaint_tab:
with gr.Row(variant='panel'):
self.top_ext = gr.Slider(
show_label=True,
label='Top Extend Ratio',
minimum=0.0,
maximum=2.0,
step=0.1,
value=0.25)
self.bottom_ext = gr.Slider(
show_label=True,
label='Bottom Extend Ratio',
minimum=0.0,
maximum=2.0,
step=0.1,
value=0.25)
with gr.Row(variant='panel'):
self.left_ext = gr.Slider(
show_label=True,
label='Left Extend Ratio',
minimum=0.0,
maximum=2.0,
step=0.1,
value=0.25)
self.right_ext = gr.Slider(
show_label=True,
label='Right Extend Ratio',
minimum=0.0,
maximum=2.0,
step=0.1,
value=0.25)
with gr.Row(variant='panel'):
self.img_pad_btn = gr.Button(
value='Pad Image')
self.image_editor = gr.ImageMask(
value=None,
sources=[],
layers=False,
label='Edit Image',
elem_id='image_editor',
format='png')
with gr.Row():
self.sub_btn_2 = gr.Button(
value='Submit', elem_id='edit_submit')
self.ext_btn_2 = gr.Button(value='Exit')
with gr.Tab(id='ImageViewer',
label='Image Viewer') as self.image_view_tab:
if self.gradio_version >= parse_version('5.0.0'):
self.image_viewer = gr.Image(
label='Image',
type='pil',
show_download_button=True,
elem_id='image_viewer')
else:
try:
from gradio_imageslider import ImageSlider
except Exception as e:
print(f"Import gradio_imageslider failed, please install.")
self.image_viewer = ImageSlider(
label='Image',
type='pil',
show_download_button=True,
elem_id='image_viewer')
self.ext_btn_3 = gr.Button(value='Exit')
with gr.Tab(id='VideoViewer',
label='Video Viewer',
visible=False) as self.video_view_tab:
self.video_viewer = gr.Video(
label='Video',
interactive=False,
sources=[],
format='mp4',
show_download_button=True,
elem_id='video_viewer',
loop=True,
autoplay=True)
self.ext_btn_4 = gr.Button(value='Exit')
with gr.Row(equal_height=True, visible=True) as self.legacy_group:
with gr.Column():
self.legacy_image_uploader = gr.Image(
height=550,
interactive=True,
type='pil',
image_mode='RGB',
elem_id='legacy_image_uploader',
format='png')
with gr.Column():
self.legacy_image_viewer = gr.Image(
label='Image',
height=550,
type='pil',
interactive=False,
show_download_button=True,
elem_id='image_viewer')
with gr.Accordion(label='Setting', open=False):
with gr.Row():
self.model_name_dd = gr.Dropdown(
choices=self.model_choices,
value=self.default_model_name,
label='Model Version')
with gr.Row():
self.negative_prompt = gr.Textbox(
value='',
placeholder=
'Negative prompt used for Classifier-Free Guidance',
label='Negative Prompt',
container=False)
with gr.Row():
# REFINER_PROMPT
self.refiner_prompt = gr.Textbox(
value=self.pipe.input.get("refiner_prompt", ""),
visible=self.pipe.input.get("refiner_prompt", None) is not None,
placeholder=
'Prompt used for refiner',
label='Refiner Prompt',
container=False)
with gr.Row():
with gr.Column(scale=8, min_width=500):
with gr.Row():
self.step = gr.Slider(minimum=1,
maximum=1000,
value=self.pipe.input.get("sample_steps", 20),
visible=self.pipe.input.get("sample_steps", None) is not None,
label='Sample Step')
self.cfg_scale = gr.Slider(
minimum=1.0,
maximum=20.0,
value=self.pipe.input.get("guide_scale", 4.5),
visible=self.pipe.input.get("guide_scale", None) is not None,
label='Guidance Scale')
self.rescale = gr.Slider(minimum=0.0,
maximum=1.0,
value=self.pipe.input.get("guide_rescale", 0.5),
visible=self.pipe.input.get("guide_rescale", None) is not None,
label='Rescale')
self.refiner_scale = gr.Slider(minimum=-0.1,
maximum=1.0,
value=self.pipe.input.get("refiner_scale", -1),
visible=self.pipe.input.get("refiner_scale", None) is not None,
label='Refiner Scale')
self.seed = gr.Slider(minimum=-1,
maximum=10000000,
value=-1,
label='Seed')
self.output_height = gr.Slider(
minimum=256,
maximum=1440,
value=self.pipe.input.get("output_height", 1024),
visible=self.pipe.input.get("output_height", None) is not None,
label='Output Height')
self.output_width = gr.Slider(
minimum=256,
maximum=1440,
value=self.pipe.input.get("output_width", 1024),
visible=self.pipe.input.get("output_width", None) is not None,
label='Output Width')
with gr.Column(scale=1, min_width=50):
self.use_history = gr.Checkbox(value=False,
label='Use History')
self.use_ace = gr.Checkbox(value=self.pipe.input.get("use_ace", True),
visible=self.pipe.input.get("use_ace", None) is not None,
label='Use ACE')
self.video_auto = gr.Checkbox(
value=False,
label='Auto Gen Video',
visible=self.enable_i2v)
with gr.Row(variant='panel',
equal_height=True,
visible=self.enable_i2v):
self.video_fps = gr.Slider(minimum=1,
maximum=16,
value=8,
label='Video FPS',
visible=True)
self.video_frames = gr.Slider(minimum=8,
maximum=49,
value=49,
label='Video Frame Num',
visible=True)
self.video_step = gr.Slider(minimum=1,
maximum=1000,
value=50,
label='Video Sample Step',
visible=True)
self.video_cfg_scale = gr.Slider(
minimum=1.0,
maximum=20.0,
value=6.0,
label='Video Guidance Scale',
visible=True)
self.video_seed = gr.Slider(minimum=-1,
maximum=10000000,
value=-1,
label='Video Seed',
visible=True)
with gr.Row():
self.chatbot_inst = """
**Instruction**:
1. Click 'Upload' button to upload one or more images as input images.
2. Entering '@' in the text box will display all images in the gallery.
3. Select the image you wish to edit from the gallery, and its Image ID will be displayed in the text box.
4. Compose the editing instruction for the selected image, incorporating the image ID '@xxxxxx' into your instruction.
For example, you might say, "Change the girl's skirt in @123456 to blue." The '@xxxxxx' token identifies the specific image and will be automatically replaced by the special token '{image}' in the instruction. Furthermore, it is also possible to perform text-to-image generation without any initial image input.
5. Once your instructions are prepared, please click the "Chat" button to view the edited result in the chat window.
6. **Important** To render text on an image, please ensure there is a space between each letter. For instance, "add text 'g i r l' on the mask area of @xxxxx".
7. To implement local editing based on a specified mask, simply click on the image within the chat window to access the image editor. Here, you can draw a mask and then click the 'Submit' button to upload the edited image along with the mask. For inpainting tasks, select the 'Composite' mask type, while for outpainting tasks, choose the 'Outpainting' mask type. For all other local editing tasks, please select the 'Background' mask type.
8. If you find our work valuable, we invite you to refer to the [ACE Page](https://ali-vilab.github.io/ace-page/) for comprehensive information.
"""
self.legacy_inst = """
**Instruction**:
1. You can edit an image by uploading it; if no image is uploaded, an image will be generated from text.
2. Entering '@' in the text box will display all images in the gallery.
3. Select the image you wish to edit from the gallery, and its Image ID will be displayed in the text box.
4. **Important** To render text on an image, please ensure there is a space between each letter. For instance, "add text 'g i r l' on the mask area of @xxxxx".
5. To perform multi-step editing, partial editing, inpainting, outpainting, and other operations, please click the ChatBot checkbox to enable the conversational editing mode and follow the relevant instructions.
6. If you find our work valuable, we invite you to refer to the [ACE Page](https://ali-vilab.github.io/ace-page/) for comprehensive information.
"""
self.instruction = gr.Markdown(value=self.legacy_inst)
with gr.Row(variant='panel',
equal_height=True,
show_progress=False):
with gr.Column(scale=1, min_width=100, visible=False) as self.upload_panel:
self.upload_btn = gr.Button(value=upload_sty +
' Upload',
variant='secondary')
with gr.Column(scale=5, min_width=500):
self.text = gr.Textbox(
placeholder='Input "@" find history of image',
label='Instruction',
container=False)
with gr.Column(scale=1, min_width=100):
self.chat_btn = gr.Button(value='Generate',
variant='primary')
with gr.Column(scale=1, min_width=100):
self.retry_btn = gr.Button(value=refresh_sty +
' Retry',
variant='secondary')
with gr.Column(scale=1, min_width=100):
self.mode_checkbox = gr.Checkbox(
value=False,
label='ChatBot')
with gr.Column(scale=(1 if self.enable_i2v else 0),
min_width=0):
self.video_gen_btn = gr.Button(value=video_sty +
' Gen Video',
variant='secondary',
visible=self.enable_i2v)
with gr.Column(scale=(1 if self.enable_i2v else 0),
min_width=0):
self.extend_prompt = gr.Checkbox(
value=True,
label='Extend Prompt',
visible=self.enable_i2v)
with gr.Row():
self.gallery = gr.Gallery(visible=False,
label='History',
columns=10,
allow_preview=False,
interactive=False)
self.eg = gr.Column(visible=True)
def set_callbacks(self, *args, **kwargs):
########################################
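# Hot-swap the editing pipeline when a different model is selected: drop
# the current ACEInference instance under the global lock, flush CUDA
# caches, rebuild from the chosen YAML config, and refresh every widget
# whose default value and visibility depend on the new pipeline's inputs.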
def change_model(model_name):
if model_name not in self.model_choices:
gr.Info('The provided model name is not a valid choice!')
return (model_name, ) + (gr.update(), ) * 10
if model_name != self.model_name:
lock.acquire()
del self.pipe
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
self.pipe = ACEInference()
self.pipe.init_from_cfg(self.model_choices[model_name])
self.model_name = model_name
lock.release()
return (model_name, gr.update(), gr.update(),
gr.Slider(
value=self.pipe.input.get("sample_steps", 20),
visible=self.pipe.input.get("sample_steps", None) is not None),
gr.Slider(
value=self.pipe.input.get("guide_scale", 4.5),
visible=self.pipe.input.get("guide_scale", None) is not None),
gr.Slider(
value=self.pipe.input.get("guide_rescale", 0.5),
visible=self.pipe.input.get("guide_rescale", None) is not None),
gr.Slider(
value=self.pipe.input.get("output_height", 1024),
visible=self.pipe.input.get("output_height", None) is not None),
gr.Slider(
value=self.pipe.input.get("output_width", 1024),
visible=self.pipe.input.get("output_width", None) is not None),
gr.Textbox(
value=self.pipe.input.get("refiner_prompt", ""),
visible=self.pipe.input.get("refiner_prompt", None) is not None),
gr.Slider(
value=self.pipe.input.get("refiner_scale", -1),
visible=self.pipe.input.get("refiner_scale", None) is not None
),
gr.Checkbox(
value=self.pipe.input.get("use_ace", True),
visible=self.pipe.input.get("use_ace", None) is not None
)
)
self.model_name_dd.change(
change_model,
inputs=[self.model_name_dd],
outputs=[
self.model_name_dd, self.chatbot, self.text,
self.step,
self.cfg_scale, self.rescale, self.output_height,
self.output_width, self.refiner_prompt, self.refiner_scale,
self.use_ace])
def mode_change(mode_check):
if mode_check:
# ChatBot
return (
gr.Row(visible=False),
gr.Row(visible=True),
gr.Button(value='Generate'),
gr.State(value='chatbot'),
gr.Column(visible=True),
gr.Markdown(value=self.chatbot_inst)
)
else:
# Legacy
return (
gr.Row(visible=True),
gr.Row(visible=False),
gr.Button(value=chat_sty + ' Chat'),
gr.State(value='legacy'),
gr.Column(visible=False),
gr.Markdown(value=self.legacy_inst)
)
self.mode_checkbox.change(mode_change, inputs=[self.mode_checkbox],
outputs=[self.legacy_group, self.chat_group,
self.chat_btn, self.ui_mode,
self.upload_panel, self.instruction])
########################################
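# '@' autocomplete for the instruction box: a bare trailing '@' shows the
# whole image-history gallery, '@<prefix>' filters image IDs by prefix,
# and anything else hides the gallery again.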
def generate_gallery(text, images):
if text.endswith(' '):
return gr.update(), gr.update(visible=False)
elif text.endswith('@'):
gallery_info = []
for image_id, image_meta in images.items():
thumbnail_path = image_meta['thumbnail']
gallery_info.append((thumbnail_path, image_id))
return gr.update(), gr.update(visible=True, value=gallery_info)
else:
gallery_info = []
match = re.search('@([^@ ]+)$', text)
if match:
prefix = match.group(1)
for image_id, image_meta in images.items():
if not image_id.startswith(prefix):
continue
thumbnail_path = image_meta['thumbnail']
gallery_info.append((thumbnail_path, image_id))
if len(gallery_info) > 0:
return gr.update(), gr.update(visible=True,
value=gallery_info)
else:
return gr.update(), gr.update(visible=False)
else:
return gr.update(), gr.update(visible=False)
self.text.input(generate_gallery,
inputs=[self.text, self.images],
outputs=[self.text, self.gallery],
show_progress='hidden')
########################################
def select_image(text, evt: gr.SelectData):
image_id = evt.value['caption']
text = '@'.join(text.split('@')[:-1]) + f'@{image_id} '
return gr.update(value=text), gr.update(visible=False, value=None)
self.gallery.select(select_image,
inputs=self.text,
outputs=[self.text, self.gallery])
########################################
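# Image-to-video generation: the first '@<image_id>' token in the message
# selects the first frame; the remaining text is optionally rewritten by
# the LLM enhancer before being fed to the CogVideoX image-to-video
# pipeline, and the resulting clip is appended to the chat history.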
def generate_video(message,
extend_prompt,
history,
images,
num_steps,
num_frames,
cfg_scale,
fps,
seed,
progress=gr.Progress(track_tqdm=True)):
from diffusers.utils import export_to_video
generator = torch.Generator(device='cuda').manual_seed(int(seed))
img_ids = re.findall('@(.*?)[ ,;.?$]', message)
if len(img_ids) == 0:
history.append((
message,
'Sorry, no images were found in the prompt to be used as the first frame of the video.'
))
while len(history) >= self.max_msgs:
history.pop(0)
return history, self.get_history(
history), gr.update(), gr.update(visible=False)
img_id = img_ids[0]
prompt = re.sub(rf'@{img_id}\s+', '', message)
if extend_prompt:
messages = copy.deepcopy(self.enhance_ctx)
messages.append({
'role':
'user',
'content':
f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{prompt}"',
})
lock.acquire()
outputs = self.enhancer(
messages,
max_new_tokens=200,
)
prompt = outputs[0]['generated_text'][-1]['content']
print(prompt)
lock.release()
img_meta = images[img_id]
img_path = img_meta['image']
image = Image.open(img_path).convert('RGB')
lock.acquire()
video = self.i2v_pipe(
prompt=prompt,
image=image,
num_videos_per_prompt=1,
num_inference_steps=num_steps,
num_frames=num_frames,
guidance_scale=cfg_scale,
generator=generator,
).frames[0]
lock.release()
out_video_path = export_to_video(video, fps=fps)
history.append((
f"Based on first frame @{img_id} and description '{prompt}', generate a video",
'This is the generated video:'))
history.append((None, out_video_path))
while len(history) >= self.max_msgs:
history.pop(0)
return history, self.get_history(history), gr.update(
value=''), gr.update(visible=False)
self.video_gen_btn.click(
generate_video,
inputs=[
self.text, self.extend_prompt, self.history, self.images,
self.video_step, self.video_frames, self.video_cfg_scale,
self.video_fps, self.video_seed
],
outputs=[self.history, self.chatbot, self.text, self.gallery])
########################################
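# Main chat handler. Every '@<image_id>' token in the message (the regex
# expects a trailing delimiter after each ID) is resolved against the
# image history and rewritten to the '{image}' / '{imageN}' placeholders
# that the ACE prompt format expects; when no image is referenced, the
# call falls back to plain text-to-image generation.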
def run_chat(
message,
legacy_image,
ui_mode,
use_ace,
extend_prompt,
history,
images,
use_history,
history_result,
negative_prompt,
cfg_scale,
rescale,
refiner_prompt,
refiner_scale,
step,
seed,
output_h,
output_w,
video_auto,
video_steps,
video_frames,
video_cfg_scale,
video_fps,
video_seed,
progress=gr.Progress(track_tqdm=True)):
legacy_img_ids = []
if ui_mode == 'legacy':
if legacy_image is not None:
history, images, img_id = self.add_uploaded_image_to_history(
legacy_image, history, images)
legacy_img_ids.append(img_id)
retry_msg = message
gen_id = get_md5(message)[:12]
save_path = os.path.join(self.cache_dir, f'{gen_id}.png')
img_ids = re.findall('@(.*?)[ ,;.?$]', message)
history_io = None
if len(img_ids) < 1:
img_ids = legacy_img_ids
for img_id in img_ids:
if f'@{img_id}' not in message:
message = f'@{img_id} ' + message
new_message = message
if len(img_ids) > 0:
edit_image, edit_image_mask, edit_task = [], [], []
for i, img_id in enumerate(img_ids):
if img_id not in images:
gr.Info(
f'The input image ID {img_id} does not exist... Skipping this image.'
)
continue
placeholder = '{image}' if i == 0 else '{' + f'image{i}' + '}'
new_message = re.sub(f'@{img_id}', placeholder,
new_message)
img_meta = images[img_id]
img_path = img_meta['image']
img_mask = img_meta['mask']
img_mask_type = img_meta['mask_type']
if img_mask_type is not None and img_mask_type == 'Composite':
task = 'inpainting'
else:
task = ''
edit_image.append(Image.open(img_path).convert('RGB'))
edit_image_mask.append(
Image.open(img_mask).
convert('L') if img_mask is not None else None)
edit_task.append(task)
if use_history and (img_id in history_result):
history_io = history_result[img_id]
buffered = io.BytesIO()
edit_image[0].save(buffered, format='PNG')
img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
pre_info = f'Received one or more images, so image editing is conducted.\n The first input image @{img_ids[0]} is:\n {img_str}'
else:
pre_info = 'No image IDs were found in the provided text prompt, so text-guided image generation is conducted.\n'
edit_image = None
edit_image_mask = None
edit_task = ''
print(new_message)
imgs = self.pipe(
image=edit_image,
mask=edit_image_mask,
task=edit_task,
prompt=[new_message] *
len(edit_image) if edit_image is not None else [new_message],
negative_prompt=[negative_prompt] * len(edit_image)
if edit_image is not None else [negative_prompt],
history_io=history_io,
output_height=output_h,
output_width=output_w,
sampler='ddim',
sample_steps=step,
guide_scale=cfg_scale,
guide_rescale=rescale,
seed=seed,
refiner_prompt=refiner_prompt,
refiner_scale=refiner_scale,
use_ace=use_ace
)
img = imgs[0]
img.save(save_path, format='PNG')
if history_io:
history_io_new = copy.deepcopy(history_io)
history_io_new['image'] += edit_image[:1]
history_io_new['mask'] += edit_image_mask[:1]
history_io_new['task'] += edit_task[:1]
history_io_new['prompt'] += [new_message]
history_io_new['image'] = history_io_new['image'][-5:]
history_io_new['mask'] = history_io_new['mask'][-5:]
history_io_new['task'] = history_io_new['task'][-5:]
history_io_new['prompt'] = history_io_new['prompt'][-5:]
history_result[gen_id] = history_io_new
elif edit_image is not None and len(edit_image) > 0:
history_io_new = {
'image': edit_image[:1],
'mask': edit_image_mask[:1],
'task': edit_task[:1],
'prompt': [new_message]
}
history_result[gen_id] = history_io_new
w, h = img.size
if w > h:
tb_w = 128
tb_h = int(h * tb_w / w)
else:
tb_h = 128
tb_w = int(w * tb_h / h)
thumbnail_path = os.path.join(self.cache_dir,
f'{gen_id}_thumbnail.jpg')
thumbnail = img.resize((tb_w, tb_h))
thumbnail.save(thumbnail_path, format='JPEG')
images[gen_id] = {
'image': save_path,
'mask': None,
'mask_type': None,
'thumbnail': thumbnail_path
}
buffered = io.BytesIO()
img.convert('RGB').save(buffered, format='PNG')
img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
history.append(
(message,
f'{pre_info} The generated image @{gen_id} is:\n {img_str}'))
if video_auto:
if video_seed is None or video_seed == -1:
video_seed = random.randint(0, 10000000)
lock.acquire()
generator = torch.Generator(
device='cuda').manual_seed(int(video_seed))
pixel_values = load_image(img.convert('RGB'),
max_num=self.llm_max_num).to(
torch.bfloat16).cuda()
prompt = self.captioner.chat(self.llm_tokenizer, pixel_values,
self.llm_prompt,
self.llm_generation_config)
print(prompt)
lock.release()
if extend_prompt:
messages = copy.deepcopy(self.enhance_ctx)
messages.append({
'role':
'user',
'content':
f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{prompt}"',
})
lock.acquire()
outputs = self.enhancer(
messages,
max_new_tokens=200,
)
prompt = outputs[0]['generated_text'][-1]['content']
print(prompt)
lock.release()
lock.acquire()
video = self.i2v_pipe(
prompt=prompt,
image=img,
num_videos_per_prompt=1,
num_inference_steps=video_steps,
num_frames=video_frames,
guidance_scale=video_cfg_scale,
generator=generator,
).frames[0]
lock.release()
out_video_path = export_to_video(video, fps=video_fps)
history.append((
f"Based on first frame @{gen_id} and description '{prompt}', generate a video",
'This is the generated video:'))
history.append((None, out_video_path))
while len(history) >= self.max_msgs:
history.pop(0)
return (history, images, gr.Image(value=save_path),
history_result, self.get_history(
history), gr.update(), gr.update(
visible=False), retry_msg)
chat_inputs = [
self.legacy_image_uploader, self.ui_mode, self.use_ace,
self.extend_prompt, self.history, self.images, self.use_history,
self.history_result, self.negative_prompt, self.cfg_scale,
self.rescale, self.refiner_prompt, self.refiner_scale,
self.step, self.seed, self.output_height,
self.output_width, self.video_auto, self.video_step,
self.video_frames, self.video_cfg_scale, self.video_fps,
self.video_seed
]
chat_outputs = [
self.history, self.images, self.legacy_image_viewer,
self.history_result, self.chatbot,
self.text, self.gallery, self.retry_msg
]
self.chat_btn.click(run_chat,
inputs=[self.text] + chat_inputs,
outputs=chat_outputs)
self.text.submit(run_chat,
inputs=[self.text] + chat_inputs,
outputs=chat_outputs)
def retry_fn(*args):
return run_chat(*args)
self.retry_btn.click(retry_fn,
inputs=[self.retry_msg] + chat_inputs,
outputs=chat_outputs)
########################################
def run_example(task, img, img_mask, ref1, prompt, seed):
edit_image, edit_image_mask, edit_task = [], [], []
if img is not None:
w, h = img.size
if w > 2048:
ratio = w / 2048.
w = 2048
h = int(h / ratio)
if h > 2048:
ratio = h / 2048.
h = 2048
w = int(w / ratio)
img = img.resize((w, h))
edit_image.append(img)
if img_mask is not None:
img_mask = img_mask if np.sum(np.array(img_mask)) > 0 else None
edit_image_mask.append(
img_mask if img_mask is not None else None)
edit_task.append(task)
if ref1 is not None:
ref1 = ref1 if np.sum(np.array(ref1)) > 0 else None
if ref1 is not None:
edit_image.append(ref1)
edit_image_mask.append(None)
edit_task.append('')
buffered = io.BytesIO()
img.save(buffered, format='PNG')
img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
pre_info = f'Received one or more images, so image editing is conducted.\n The first input image is:\n {img_str}'
else:
pre_info = 'No image IDs were found in the provided text prompt, so text-guided image generation is conducted.\n'
edit_image = None
edit_image_mask = None
edit_task = ''
img_num = len(edit_image) if edit_image is not None else 1
imgs = self.pipe(
image=edit_image,
mask=edit_image_mask,
task=edit_task,
prompt=[prompt] * img_num,
negative_prompt=[''] * img_num,
seed=seed,
refiner_prompt=self.pipe.input.get("refiner_prompt", ""),
refiner_scale=self.pipe.input.get("refiner_scale", 0.0),
)
img = imgs[0]
buffered = io.BytesIO()
img.convert('RGB').save(buffered, format='PNG')
img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
history = [(prompt,
f'{pre_info} The generated image is:\n {img_str}')]
img_id = get_md5(img_b64)[:12]
save_path = os.path.join(self.cache_dir, f'{img_id}.png')
img.convert('RGB').save(save_path)
return self.get_history(history), gr.update(value=''), gr.update(
visible=False), gr.update(value=save_path), gr.update(value=-1)
with self.eg:
self.example_task = gr.Text(label='Task Name',
value='',
visible=False)
self.example_image = gr.Image(label='Edit Image',
type='pil',
image_mode='RGB',
visible=False)
self.example_mask = gr.Image(label='Edit Image Mask',
type='pil',
image_mode='L',
visible=False)
self.example_ref_im1 = gr.Image(label='Ref Image',
type='pil',
image_mode='RGB',
visible=False)
self.examples = gr.Examples(
fn=run_example,
examples=self.chatbot_examples,
inputs=[
self.example_task, self.example_image, self.example_mask,
self.example_ref_im1, self.text, self.seed
],
outputs=[self.chatbot, self.text, self.gallery, self.legacy_image_viewer, self.seed],
examples_per_page=4,
cache_examples=False,
run_on_click=True)
########################################
def upload_image():
return (gr.update(visible=True,
scale=1), gr.update(visible=True, scale=1),
gr.update(visible=True), gr.update(visible=False),
gr.update(visible=False), gr.update(visible=False),
gr.update(visible=True))
self.upload_btn.click(upload_image,
inputs=[],
outputs=[
self.chat_page, self.editor_page,
self.upload_tab, self.edit_tab,
self.image_view_tab, self.video_view_tab,
self.upload_tabs
])
########################################
def edit_image(evt: gr.SelectData):
if isinstance(evt.value, str):
img_b64s = re.findall(
'<img src="data:image/png;base64,(.*?)" style="pointer-events: none;">',
evt.value)
imgs = [
Image.open(io.BytesIO(base64.b64decode(copy.deepcopy(i))))
for i in img_b64s
]
if len(imgs) > 0:
if len(imgs) == 2:
if self.gradio_version >= parse_version('5.0.0'):
view_img = copy.deepcopy(imgs[-1])
else:
view_img = copy.deepcopy(imgs)
edit_img = copy.deepcopy(imgs[-1])
else:
if self.gradio_version >= parse_version('5.0.0'):
view_img = copy.deepcopy(imgs[-1])
else:
view_img = [
copy.deepcopy(imgs[-1]),
copy.deepcopy(imgs[-1])
]
edit_img = copy.deepcopy(imgs[-1])
return (gr.update(visible=True,
scale=1), gr.update(visible=True,
scale=1),
gr.update(visible=False), gr.update(visible=True),
gr.update(visible=True), gr.update(visible=False),
gr.update(value=edit_img),
gr.update(value=view_img), gr.update(value=None),
gr.update(visible=True))
else:
return (gr.update(), gr.update(), gr.update(), gr.update(),
gr.update(), gr.update(), gr.update(), gr.update(),
gr.update(), gr.update())
elif isinstance(evt.value, dict) and evt.value.get(
'component', '') == 'video':
value = evt.value['value']['video']['path']
return (gr.update(visible=True,
scale=1), gr.update(visible=True, scale=1),
gr.update(visible=False), gr.update(visible=False),
gr.update(visible=False), gr.update(visible=True),
gr.update(), gr.update(), gr.update(value=value),
gr.update())
else:
return (gr.update(), gr.update(), gr.update(), gr.update(),
gr.update(), gr.update(), gr.update(), gr.update(),
gr.update(), gr.update())
self.chatbot.select(edit_image,
outputs=[
self.chat_page, self.editor_page,
self.upload_tab, self.edit_tab,
self.image_view_tab, self.video_view_tab,
self.image_editor, self.image_viewer,
self.video_viewer, self.edit_tabs
])
if self.gradio_version < parse_version('5.0.0'):
self.image_viewer.change(lambda x: x,
inputs=self.image_viewer,
outputs=self.image_viewer)
########################################
def submit_upload_image(image, history, images):
history, images, _ = self.add_uploaded_image_to_history(
image, history, images)
return gr.update(visible=False), gr.update(
visible=True), gr.update(
value=self.get_history(history)), history, images
self.sub_btn_1.click(
submit_upload_image,
inputs=[self.image_uploader, self.history, self.images],
outputs=[
self.editor_page, self.chat_page, self.chatbot, self.history,
self.images
])
########################################
def submit_edit_image(imagemask, mask_type, history, images):
history, images = self.add_edited_image_to_history(
imagemask, mask_type, history, images)
return gr.update(visible=False), gr.update(
visible=True), gr.update(
value=self.get_history(history)), history, images
self.sub_btn_2.click(submit_edit_image,
inputs=[
self.image_editor, self.mask_type,
self.history, self.images
],
outputs=[
self.editor_page, self.chat_page,
self.chatbot, self.history, self.images
])
########################################
def exit_edit():
return gr.update(visible=False), gr.update(visible=True, scale=3)
self.ext_btn_1.click(exit_edit,
outputs=[self.editor_page, self.chat_page])
self.ext_btn_2.click(exit_edit,
outputs=[self.editor_page, self.chat_page])
self.ext_btn_3.click(exit_edit,
outputs=[self.editor_page, self.chat_page])
self.ext_btn_4.click(exit_edit,
outputs=[self.editor_page, self.chat_page])
########################################
def update_mask_type_info(mask_type):
if mask_type == 'Background':
info = 'Background mode will not erase the visual content in the mask area'
visible = False
elif mask_type == 'Composite':
info = 'Composite mode will erase the visual content in the mask area'
visible = False
elif mask_type == 'Outpainting':
info = 'Outpainting mode is used to prepare the input image for the outpainting task'
visible = True
return (gr.update(
visible=True,
value=
f"<div style='background-color: white; padding-left: 15px; color: grey;'>{info}</div>"
), gr.update(visible=visible))
self.mask_type.change(update_mask_type_info,
inputs=self.mask_type,
outputs=[self.mask_type_info, self.outpaint_tab])
########################################
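# Prepare an outpainting input: pad the canvas by the given ratios, keep
# the original pixels in the center, and mark the new border with 255 in
# the single-channel mask layer so that the padded region is the area to
# be generated.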
def extend_image(top_ratio, bottom_ratio, left_ratio, right_ratio,
image):
img = cv2.cvtColor(image['background'], cv2.COLOR_RGBA2RGB)
h, w = img.shape[:2]
new_h = int(h * (top_ratio + bottom_ratio + 1))
new_w = int(w * (left_ratio + right_ratio + 1))
start_h = int(h * top_ratio)
start_w = int(w * left_ratio)
new_img = np.zeros((new_h, new_w, 3), dtype=np.uint8)
new_mask = np.ones((new_h, new_w, 1), dtype=np.uint8) * 255
new_img[start_h:start_h + h, start_w:start_w + w, :] = img
new_mask[start_h:start_h + h, start_w:start_w + w] = 0
layer = np.concatenate([new_img, new_mask], axis=2)
value = {
'background': new_img,
'composite': new_img,
'layers': [layer]
}
return gr.update(value=value)
self.img_pad_btn.click(extend_image,
inputs=[
self.top_ext, self.bottom_ext,
self.left_ext, self.right_ext,
self.image_editor
],
outputs=self.image_editor)
########################################
def clear_chat(history, images, history_result):
history.clear()
images.clear()
history_result.clear()
return history, images, history_result, self.get_history(history)
self.clear_btn.click(
clear_chat,
inputs=[self.history, self.images, self.history_result],
outputs=[
self.history, self.images, self.history_result, self.chatbot
])
def get_history(self, history):
info = []
for item in history:
new_item = [None, None]
if isinstance(item[0], str) and item[0].endswith('.mp4'):
new_item[0] = gr.Video(item[0], format='mp4')
else:
new_item[0] = item[0]
if isinstance(item[1], str) and item[1].endswith('.mp4'):
new_item[1] = gr.Video(item[1], format='mp4')
else:
new_item[1] = item[1]
info.append(new_item)
return info
def generate_random_string(self, length=20):
letters_and_digits = string.ascii_letters + string.digits
random_string = ''.join(
random.choice(letters_and_digits) for _ in range(length))
return random_string
def add_edited_image_to_history(self, image, mask_type, history, images):
if mask_type == 'Composite':
img = Image.fromarray(image['composite'])
else:
img = Image.fromarray(image['background'])
img_id = get_md5(self.generate_random_string())[:12]
save_path = os.path.join(self.cache_dir, f'{img_id}.png')
img.convert('RGB').save(save_path)
mask = image['layers'][0][:, :, 3]
mask = Image.fromarray(mask).convert('RGB')
mask_path = os.path.join(self.cache_dir, f'{img_id}_mask.png')
mask.save(mask_path)
w, h = img.size
if w > h:
tb_w = 128
tb_h = int(h * tb_w / w)
else:
tb_h = 128
tb_w = int(w * tb_h / h)
if mask_type == 'Background':
comp_mask = np.array(mask, dtype=np.uint8)
mask_alpha = (comp_mask[:, :, 0:1].astype(np.float32) *
0.6).astype(np.uint8)
comp_mask = np.concatenate([comp_mask, mask_alpha], axis=2)
thumbnail = Image.alpha_composite(
img.convert('RGBA'),
Image.fromarray(comp_mask).convert('RGBA')).convert('RGB')
else:
thumbnail = img.convert('RGB')
thumbnail_path = os.path.join(self.cache_dir,
f'{img_id}_thumbnail.jpg')
thumbnail = thumbnail.resize((tb_w, tb_h))
thumbnail.save(thumbnail_path, format='JPEG')
buffered = io.BytesIO()
img.convert('RGB').save(buffered, format='PNG')
img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
buffered = io.BytesIO()
mask.convert('RGB').save(buffered, format='PNG')
mask_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
mask_str = f'<img src="data:image/png;base64,{mask_b64}" style="pointer-events: none;">'
images[img_id] = {
'image': save_path,
'mask': mask_path,
'mask_type': mask_type,
'thumbnail': thumbnail_path
}
history.append((
None,
f'This is the edited image and mask:\n {img_str} {mask_str} The image ID is: {img_id}'
))
return history, images
def add_uploaded_image_to_history(self, img, history, images):
img_id = get_md5(self.generate_random_string())[:12]
save_path = os.path.join(self.cache_dir, f'{img_id}.png')
w, h = img.size
if w > 2048:
ratio = w / 2048.
w = 2048
h = int(h / ratio)
if h > 2048:
ratio = h / 2048.
h = 2048
w = int(w / ratio)
img = img.resize((w, h))
img.save(save_path)
w, h = img.size
if w > h:
tb_w = 128
tb_h = int(h * tb_w / w)
else:
tb_h = 128
tb_w = int(w * tb_h / h)
thumbnail_path = os.path.join(self.cache_dir,
f'{img_id}_thumbnail.jpg')
thumbnail = img.resize((tb_w, tb_h))
thumbnail.save(thumbnail_path, format='JPEG')
images[img_id] = {
'image': save_path,
'mask': None,
'mask_type': None,
'thumbnail': thumbnail_path
}
buffered = io.BytesIO()
img.convert('RGB').save(buffered, format='PNG')
img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
history.append(
(None,
f'This is the uploaded image:\n {img_str} The image ID is: {img_id}'))
return history, images, img_id
def run_gr(cfg):
with gr.Blocks() as demo:
chatbot = ChatBotUI(cfg)
chatbot.create_ui()
chatbot.set_callbacks()
demo.launch(server_name='0.0.0.0',
server_port=cfg.args.server_port,
root_path=cfg.args.root_path)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Argparser for Scepter:\n')
parser.add_argument('--server_port',
dest='server_port',
help='',
type=int,
default=2345)
parser.add_argument('--root_path', dest='root_path', help='', default='')
cfg = Config(load=True, parser_ins=parser)
run_gr(cfg)
================================================
FILE: chatbot/utils.py
================================================
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
def build_transform(input_size):
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
transform = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize((input_size, input_size),
interpolation=InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=MEAN, std=STD)
])
return transform
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
image_size):
best_ratio_diff = float('inf')
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio = ratio
elif ratio_diff == best_ratio_diff:
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
best_ratio = ratio
return best_ratio
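# InternVL-style dynamic tiling: choose the (cols, rows) grid whose aspect
# ratio best matches the input (with at most max_num tiles), resize the
# image to fill that grid, and crop it into image_size x image_size tiles,
# optionally appending a global thumbnail.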
def dynamic_preprocess(image,
min_num=1,
max_num=12,
image_size=448,
use_thumbnail=False):
orig_width, orig_height = image.size
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = set((i, j) for n in range(min_num, max_num + 1)
for i in range(1, n + 1) for j in range(1, n + 1)
if i * j <= max_num and i * j >= min_num)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio,
target_ratios, orig_width,
orig_height, image_size)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
box = ((i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
return processed_images
def load_image(image_file, input_size=448, max_num=12):
if isinstance(image_file, str):
image = Image.open(image_file).convert('RGB')
else:
image = image_file
transform = build_transform(input_size=input_size)
images = dynamic_preprocess(image,
image_size=input_size,
use_thumbnail=True,
max_num=max_num)
pixel_values = [transform(image) for image in images]
pixel_values = torch.stack(pixel_values)
return pixel_values
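# A minimal usage sketch (hypothetical path; shapes follow from the code
# above with the default input_size of 448):
#   pixel_values = load_image('demo.jpg', max_num=12)
#   # -> torch.Size([N, 3, 448, 448]), with N <= max_num + 1
#   # (the extra tile, when present, is the global thumbnail)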
================================================
FILE: config/inference_config/chatbot_ui.yaml
================================================
WORK_DIR: ./cache/chatbot
FILE_SYSTEM:
- NAME: "HuggingfaceFs"
TEMP_DIR: ./cache
- NAME: "ModelscopeFs"
TEMP_DIR: ./cache
- NAME: "LocalFs"
TEMP_DIR: ./cache
- NAME: "HttpFs"
TEMP_DIR: ./cache
#
ENABLE_I2V: False
SKIP_EXAMPLES: False
#
MODEL:
EDIT_MODEL:
MODEL_CFG_DIR: config/inference_config/models/
DEFAULT: ace_0.6b_512
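# Note: the chatbot UI (chatbot/run_gradio.py) globs MODEL_CFG_DIR for
# *.yaml files and picks the config whose IS_DEFAULT is True, falling
# back to the first one found if none is marked default.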
I2V:
MODEL_NAME: CogVideoX-5b-I2V
MODEL_DIR: ms://ZhipuAI/CogVideoX-5b-I2V/
CAPTIONER:
MODEL_NAME: InternVL2-2B
MODEL_DIR: ms://OpenGVLab/InternVL2-2B/
PROMPT: '<image>\nThis image is the first frame of a video. Based on this image, please imagine what changes may occur in the next few seconds of the video. Please output brief description, such as "a dog running" or "a person turns to left". No more than 30 words.'
ENHANCER:
MODEL_NAME: Meta-Llama-3.1-8B-Instruct
MODEL_DIR: ms://LLM-Research/Meta-Llama-3.1-8B-Instruct/
================================================
FILE: config/inference_config/models/ace_0.6b_1024.yaml
================================================
NAME: ACE_0.6B_1024
IS_DEFAULT: False
USE_DYNAMIC_MODEL: False
DEFAULT_PARAS:
PARAS:
#
INPUT:
INPUT_IMAGE:
INPUT_MASK:
TASK:
PROMPT: ""
NEGATIVE_PROMPT: ""
OUTPUT_HEIGHT: 1024
OUTPUT_WIDTH: 1024
SAMPLER: ddim
SAMPLE_STEPS: 50
GUIDE_SCALE: 4.5
GUIDE_RESCALE: 0.5
SEED: -1
TAR_INDEX: 0
OUTPUT:
LATENT:
IMAGES:
SEED:
MODULES_PARAS:
FIRST_STAGE_MODEL:
FUNCTION:
- NAME: encode
DTYPE: float16
INPUT: ["IMAGE"]
- NAME: decode
DTYPE: float16
INPUT: ["LATENT"]
#
DIFFUSION_MODEL:
FUNCTION:
- NAME: forward
DTYPE: float16
INPUT: ["SAMPLE_STEPS", "SAMPLE", "GUIDE_SCALE"]
#
COND_STAGE_MODEL:
FUNCTION:
- NAME: encode_list
DTYPE: bfloat16
INPUT: ["PROMPT"]
#
MODEL:
NAME: LdmACE
PRETRAINED_MODEL:
IGNORE_KEYS: [ ]
SCALE_FACTOR: 0.18215
SIZE_FACTOR: 8
DECODER_BIAS: 0.5
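# SCALE_FACTOR is the SD-style VAE latent scaling applied around the
# first-stage encode/decode; SIZE_FACTOR 8 matches the autoencoder's
# spatial downsampling; DECODER_BIAS is added as DECODER_BIAS / 255 to
# the decoded image before clamping (see modules/model/network/ldm_ace.py).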
DEFAULT_N_PROMPT: ""
TEXT_IDENTIFIER: [ '{image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]
USE_TEXT_POS_EMBEDDINGS: True
#
DIFFUSION:
NAME: ACEDiffusion
PREDICTION_TYPE: eps
MIN_SNR_GAMMA:
NOISE_SCHEDULER:
NAME: LinearScheduler
NUM_TIMESTEPS: 1000
BETA_MIN: 0.0001
BETA_MAX: 0.02
#
DIFFUSION_MODEL:
NAME: DiTACE
PRETRAINED_MODEL: ms://iic/ACE-0.6B-1024px@models/dit/ace_0.6b_1024px.pth
IGNORE_KEYS: [ ]
PATCH_SIZE: 2
IN_CHANNELS: 4
HIDDEN_SIZE: 1152
DEPTH: 28
NUM_HEADS: 16
MLP_RATIO: 4.0
PRED_SIGMA: True
DROP_PATH: 0.0
WINDOW_DIZE: 0
Y_CHANNELS: 4096
MAX_SEQ_LEN: 4096
QK_NORM: True
USE_GRAD_CHECKPOINT: True
ATTENTION_BACKEND: flash_attn
#
FIRST_STAGE_MODEL:
NAME: AutoencoderKL
EMBED_DIM: 4
PRETRAINED_MODEL: ms://iic/ACE-0.6B-1024px@models/vae/vae.bin
IGNORE_KEYS: []
#
ENCODER:
NAME: Encoder
CH: 128
OUT_CH: 3
NUM_RES_BLOCKS: 2
IN_CHANNELS: 3
ATTN_RESOLUTIONS: [ ]
CH_MULT: [ 1, 2, 4, 4 ]
Z_CHANNELS: 4
DOUBLE_Z: True
DROPOUT: 0.0
RESAMP_WITH_CONV: True
#
DECODER:
NAME: Decoder
CH: 128
OUT_CH: 3
NUM_RES_BLOCKS: 2
IN_CHANNELS: 3
ATTN_RESOLUTIONS: [ ]
CH_MULT: [ 1, 2, 4, 4 ]
Z_CHANNELS: 4
DROPOUT: 0.0
RESAMP_WITH_CONV: True
GIVE_PRE_END: False
TANH_OUT: False
#
COND_STAGE_MODEL:
NAME: ACETextEmbedder
PRETRAINED_MODEL: ms://iic/ACE-0.6B-1024px@models/text_encoder/t5-v1_1-xxl/
TOKENIZER_PATH: ms://iic/ACE-0.6B-1024px@models/tokenizer/t5-v1_1-xxl
LENGTH: 120
T5_DTYPE: bfloat16
ADDED_IDENTIFIER: [ '{image}', '{caption}', '{mask}', '{ref_image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]
CLEAN: whitespace
USE_GRAD: False
================================================
FILE: config/inference_config/models/ace_0.6b_1024_refiner.yaml
================================================
NAME: ACE_0.6B_1024_REFINER
IS_DEFAULT: False
USE_DYNAMIC_MODEL: False
DEFAULT_PARAS:
PARAS:
#
INPUT:
INPUT_IMAGE:
INPUT_MASK:
TASK:
PROMPT: ""
NEGATIVE_PROMPT: ""
OUTPUT_HEIGHT: 1024
OUTPUT_WIDTH: 1024
SAMPLER: ddim
SAMPLE_STEPS: 50
GUIDE_SCALE: 4.5
GUIDE_RESCALE: 0.5
SEED: -1
TAR_INDEX: 0
REFINER_SCALE: 0.2
USE_ACE: True
REFINER_PROMPT: "High Resolution, Sharpness, Clarity, Detail Enhancement, Noise Reduction, HD, 4k, Image Restoration, HDR"
OUTPUT:
LATENT:
IMAGES:
SEED:
MODULES_PARAS:
FIRST_STAGE_MODEL:
FUNCTION:
- NAME: encode
DTYPE: float16
INPUT: ["IMAGE"]
- NAME: decode
DTYPE: float16
INPUT: ["LATENT"]
#
DIFFUSION_MODEL:
FUNCTION:
- NAME: forward
DTYPE: float16
INPUT: ["SAMPLE_STEPS", "SAMPLE", "GUIDE_SCALE"]
#
COND_STAGE_MODEL:
FUNCTION:
- NAME: encode_list
DTYPE: bfloat16
INPUT: ["PROMPT"]
#
MODEL:
NAME: LdmACE
PRETRAINED_MODEL:
IGNORE_KEYS: [ ]
SCALE_FACTOR: 0.18215
SIZE_FACTOR: 8
DECODER_BIAS: 0.5
DEFAULT_N_PROMPT: ""
TEXT_IDENTIFIER: [ '{image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]
USE_TEXT_POS_EMBEDDINGS: True
#
DIFFUSION:
NAME: ACEDiffusion
PREDICTION_TYPE: eps
MIN_SNR_GAMMA:
NOISE_SCHEDULER:
NAME: LinearScheduler
NUM_TIMESTEPS: 1000
BETA_MIN: 0.0001
BETA_MAX: 0.02
#
DIFFUSION_MODEL:
NAME: DiTACE
PRETRAINED_MODEL: ms://iic/ACE-0.6B-1024px@models/dit/ace_0.6b_1024px.pth
IGNORE_KEYS: [ ]
PATCH_SIZE: 2
IN_CHANNELS: 4
HIDDEN_SIZE: 1152
DEPTH: 28
NUM_HEADS: 16
MLP_RATIO: 4.0
PRED_SIGMA: True
DROP_PATH: 0.0
WINDOW_DIZE: 0
Y_CHANNELS: 4096
MAX_SEQ_LEN: 4096
QK_NORM: True
USE_GRAD_CHECKPOINT: True
ATTENTION_BACKEND: flash_attn
#
FIRST_STAGE_MODEL:
NAME: AutoencoderKL
EMBED_DIM: 4
PRETRAINED_MODEL: ms://iic/ACE-0.6B-1024px@models/vae/vae.bin
IGNORE_KEYS: []
#
ENCODER:
NAME: Encoder
CH: 128
OUT_CH: 3
NUM_RES_BLOCKS: 2
IN_CHANNELS: 3
ATTN_RESOLUTIONS: [ ]
CH_MULT: [ 1, 2, 4, 4 ]
Z_CHANNELS: 4
DOUBLE_Z: True
DROPOUT: 0.0
RESAMP_WITH_CONV: True
#
DECODER:
NAME: Decoder
CH: 128
OUT_CH: 3
NUM_RES_BLOCKS: 2
IN_CHANNELS: 3
ATTN_RESOLUTIONS: [ ]
CH_MULT: [ 1, 2, 4, 4 ]
Z_CHANNELS: 4
DROPOUT: 0.0
RESAMP_WITH_CONV: True
GIVE_PRE_END: False
TANH_OUT: False
#
COND_STAGE_MODEL:
NAME: ACETextEmbedder
PRETRAINED_MODEL: ms://iic/ACE-0.6B-1024px@models/text_encoder/t5-v1_1-xxl/
TOKENIZER_PATH: ms://iic/ACE-0.6B-1024px@models/tokenizer/t5-v1_1-xxl
LENGTH: 120
T5_DTYPE: bfloat16
ADDED_IDENTIFIER: [ '{image}', '{caption}', '{mask}', '{ref_image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]
CLEAN: whitespace
USE_GRAD: False
ACE_PROMPT: [
"A cute cartoon rabbit holding a whiteboard that says 'ACE Refiner', standing in a sunny meadow filled with flowers, with a big smile and bright colors.",
"A beautiful young woman with long flowing hair, wearing a summer dress, holding a whiteboard that reads 'ACE Refiner' while sitting on a park bench surrounded by cherry blossoms.",
"An adorable cartoon cat wearing oversized glasses, holding a whiteboard that says 'ACE Refiner', perched on a stack of colorful books in a cozy library setting.",
"A charming girl with pigtails, wearing a cute school uniform, enthusiastically holding a whiteboard that has 'ACE Refiner' written on it, in a bright and cheerful classroom full of educational posters.",
"A friendly cartoon dog with floppy ears, sitting in front of a doghouse, proudly holding a whiteboard that says 'ACE Refiner', with a playful expression and a blue sky in the background.",
"A cute anime girl with big expressive eyes, dressed in a colorful outfit, holding a whiteboard that reads 'ACE Refiner' in a fantastical landscape filled with mythical creatures.",
"A vibrant cartoon fox holding a whiteboard that says 'ACE Refiner', standing on a rock by a sparkling stream, surrounded by lush greenery and butterflies.",
"A stylish young woman in a business outfit, smiling as she holds a whiteboard written with 'ACE Refiner', in a modern office filled with plants and natural light.",
"A cute cartoon unicorn holding a sparkling whiteboard that says 'ACE Refiner', frolicking in a magical forest, with rainbows and stars in the background.",
"A happy family, consisting of a cute little girl and her playful puppy, holding a whiteboard that says 'ACE Refiner', together in their backyard on a sunny day."
]
REFINER_MODEL:
NAME: ""
IS_DEFAULT: False
DEFAULT_PARAS:
PARAS:
RESOLUTIONS: [ [ 1024, 1024 ] ]
INPUT:
INPUT_IMAGE:
INPUT_MASK:
TASK:
PROMPT: ""
NEGATIVE_PROMPT: ""
OUTPUT_HEIGHT: 1024
OUTPUT_WIDTH: 1024
SAMPLER: flow_euler
SAMPLE_STEPS: 30
GUIDE_SCALE: 3.5
GUIDE_RESCALE:
OUTPUT:
LATENT:
IMAGES:
SEED:
MODULES_PARAS:
FIRST_STAGE_MODEL:
FUNCTION:
- NAME: encode
DTYPE: bfloat16
INPUT: [ "IMAGE" ]
- NAME: decode
DTYPE: bfloat16
INPUT: [ "LATENT" ]
PARAS:
SCALE_FACTOR: 1.5305
SHIFT_FACTOR: 0.0609
SIZE_FACTOR: 8
DIFFUSION_MODEL:
FUNCTION:
- NAME: forward
DTYPE: bfloat16
INPUT: [ "SAMPLE_STEPS", "SAMPLE", "GUIDE_SCALE" ]
COND_STAGE_MODEL:
FUNCTION:
- NAME: encode
DTYPE: bfloat16
INPUT: [ "PROMPT" ]
MODEL:
DIFFUSION:
NAME: DiffusionFluxRF
PREDICTION_TYPE: raw
NOISE_SCHEDULER:
NAME: FlowMatchSigmaScheduler
WEIGHTING_SCHEME: logit_normal
SHIFT: 3.0
LOGIT_MEAN: 0.0
LOGIT_STD: 1.0
MODE_SCALE: 1.29
DIFFUSION_MODEL:
NAME: FluxMR
PRETRAINED_MODEL: ms://AI-ModelScope/FLUX.1-dev@flux1-dev.safetensors
IN_CHANNELS: 64
OUT_CHANNELS: 64
HIDDEN_SIZE: 3072
NUM_HEADS: 24
AXES_DIM: [ 16, 56, 56 ]
THETA: 10000
VEC_IN_DIM: 768
GUIDANCE_EMBED: True
CONTEXT_IN_DIM: 4096
MLP_RATIO: 4.0
QKV_BIAS: True
DEPTH: 19
DEPTH_SINGLE_BLOCKS: 38
USE_GRAD_CHECKPOINT: True
ATTN_BACKEND: flash_attn
#
FIRST_STAGE_MODEL:
NAME: AutoencoderKLFlux
EMBED_DIM: 16
PRETRAINED_MODEL: ms://AI-ModelScope/FLUX.1-dev@ae.safetensors
IGNORE_KEYS: [ ]
BATCH_SIZE: 8
USE_CONV: False
SCALE_FACTOR: 0.3611
SHIFT_FACTOR: 0.1159
#
ENCODER:
NAME: Encoder
USE_CHECKPOINT: False
CH: 128
OUT_CH: 3
NUM_RES_BLOCKS: 2
IN_CHANNELS: 3
ATTN_RESOLUTIONS: [ ]
CH_MULT: [ 1, 2, 4, 4 ]
Z_CHANNELS: 16
DOUBLE_Z: True
DROPOUT: 0.0
RESAMP_WITH_CONV: True
#
DECODER:
NAME: Decoder
USE_CHECKPOINT: False
CH: 128
OUT_CH: 3
NUM_RES_BLOCKS: 2
IN_CHANNELS: 3
ATTN_RESOLUTIONS: [ ]
CH_MULT: [ 1, 2, 4, 4 ]
Z_CHANNELS: 16
DROPOUT: 0.0
RESAMP_WITH_CONV: True
GIVE_PRE_END: False
TANH_OUT: False
#
COND_STAGE_MODEL:
NAME: T5PlusClipFluxEmbedder
T5_MODEL:
NAME: HFEmbedder
HF_MODEL_CLS: T5EncoderModel
MODEL_PATH: ms://AI-ModelScope/FLUX.1-dev@text_encoder_2/
HF_TOKENIZER_CLS: T5Tokenizer
TOKENIZER_PATH: ms://AI-ModelScope/FLUX.1-dev@tokenizer_2/
MAX_LENGTH: 512
OUTPUT_KEY: last_hidden_state
D_TYPE: bfloat16
BATCH_INFER: False
CLEAN: whitespace
CLIP_MODEL:
NAME: HFEmbedder
HF_MODEL_CLS: CLIPTextModel
MODEL_PATH: ms://AI-ModelScope/FLUX.1-dev@text_encoder/
HF_TOKENIZER_CLS: CLIPTokenizer
TOKENIZER_PATH: ms://AI-ModelScope/FLUX.1-dev@tokenizer/
MAX_LENGTH: 77
OUTPUT_KEY: pooler_output
D_TYPE: bfloat16
BATCH_INFER: True
CLEAN: whitespace
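# --- Editorial sketch (not repository code) ---------------------------------
# The SCALE_FACTOR/SHIFT_FACTOR pairs in the FIRST_STAGE_MODEL blocks above
# follow the usual latent-diffusion affine convention; a minimal sketch of
# that convention, assuming the standard formulation (helper names here are
# illustrative, and the repository's exact usage lives in its inference code):
#
#     import torch
#
#     SCALE_FACTOR = 0.3611   # from the AutoencoderKLFlux block above
#     SHIFT_FACTOR = 0.1159
#
#     def to_latent(z_raw: torch.Tensor) -> torch.Tensor:
#         # raw VAE encoder output -> normalized latent seen by the DiT
#         return (z_raw - SHIFT_FACTOR) * SCALE_FACTOR
#
#     def from_latent(z: torch.Tensor) -> torch.Tensor:
#         # invert the affine map before handing latents to the VAE decoder
#         return z / SCALE_FACTOR + SHIFT_FACTOR
# -----------------------------------------------------------------------------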
================================================
FILE: config/inference_config/models/ace_0.6b_512.yaml
================================================
NAME: ACE_0.6B_512
IS_DEFAULT: True
USE_DYNAMIC_MODEL: False
DEFAULT_PARAS:
PARAS:
#
INPUT:
INPUT_IMAGE:
INPUT_MASK:
TASK:
PROMPT: ""
NEGATIVE_PROMPT: ""
OUTPUT_HEIGHT: 512
OUTPUT_WIDTH: 512
SAMPLER: ddim
SAMPLE_STEPS: 20
GUIDE_SCALE: 4.5
GUIDE_RESCALE: 0.5
SEED: -1
TAR_INDEX: 0
OUTPUT:
LATENT:
IMAGES:
SEED:
MODULES_PARAS:
FIRST_STAGE_MODEL:
FUNCTION:
- NAME: encode
DTYPE: float16
INPUT: ["IMAGE"]
- NAME: decode
DTYPE: float16
INPUT: ["LATENT"]
#
DIFFUSION_MODEL:
FUNCTION:
- NAME: forward
DTYPE: float16
INPUT: ["SAMPLE_STEPS", "SAMPLE", "GUIDE_SCALE"]
#
COND_STAGE_MODEL:
FUNCTION:
- NAME: encode_list
DTYPE: bfloat16
INPUT: ["PROMPT"]
#
MODEL:
NAME: LdmACE
PRETRAINED_MODEL:
IGNORE_KEYS: [ ]
SCALE_FACTOR: 0.18215
SIZE_FACTOR: 8
DECODER_BIAS: 0.5
DEFAULT_N_PROMPT: ""
TEXT_IDENTIFIER: [ '{image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]
USE_TEXT_POS_EMBEDDINGS: True
#
DIFFUSION:
NAME: ACEDiffusion
PREDICTION_TYPE: eps
MIN_SNR_GAMMA:
NOISE_SCHEDULER:
NAME: LinearScheduler
NUM_TIMESTEPS: 1000
BETA_MIN: 0.0001
BETA_MAX: 0.02
#
DIFFUSION_MODEL:
NAME: DiTACE
PRETRAINED_MODEL: ms://iic/ACE-0.6B-512px@models/dit/ace_0.6b_512px.pth
IGNORE_KEYS: [ ]
PATCH_SIZE: 2
IN_CHANNELS: 4
HIDDEN_SIZE: 1152
DEPTH: 28
NUM_HEADS: 16
MLP_RATIO: 4.0
PRED_SIGMA: True
DROP_PATH: 0.0
    WINDOW_SIZE: 0
Y_CHANNELS: 4096
MAX_SEQ_LEN: 1024
QK_NORM: True
USE_GRAD_CHECKPOINT: True
ATTENTION_BACKEND: flash_attn
#
FIRST_STAGE_MODEL:
NAME: AutoencoderKL
EMBED_DIM: 4
PRETRAINED_MODEL: ms://iic/ACE-0.6B-512px@models/vae/vae.bin
IGNORE_KEYS: []
#
ENCODER:
NAME: Encoder
CH: 128
OUT_CH: 3
NUM_RES_BLOCKS: 2
IN_CHANNELS: 3
ATTN_RESOLUTIONS: [ ]
CH_MULT: [ 1, 2, 4, 4 ]
Z_CHANNELS: 4
DOUBLE_Z: True
DROPOUT: 0.0
RESAMP_WITH_CONV: True
#
DECODER:
NAME: Decoder
CH: 128
OUT_CH: 3
NUM_RES_BLOCKS: 2
IN_CHANNELS: 3
ATTN_RESOLUTIONS: [ ]
CH_MULT: [ 1, 2, 4, 4 ]
Z_CHANNELS: 4
DROPOUT: 0.0
RESAMP_WITH_CONV: True
GIVE_PRE_END: False
TANH_OUT: False
#
COND_STAGE_MODEL:
NAME: ACETextEmbedder
PRETRAINED_MODEL: ms://iic/ACE-0.6B-512px@models/text_encoder/t5-v1_1-xxl/
TOKENIZER_PATH: ms://iic/ACE-0.6B-512px@models/tokenizer/t5-v1_1-xxl
LENGTH: 120
T5_DTYPE: bfloat16
ADDED_IDENTIFIER: [ '{image}', '{caption}', '{mask}', '{ref_image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]
CLEAN: whitespace
USE_GRAD: False
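# --- Editorial sketch (not repository code) ---------------------------------
# The default sampling parameters of this config can be read back with plain
# PyYAML; a minimal sketch, assuming the file is opened from the repository
# root (the project itself loads configs through scepter's utilities):
#
#     import yaml
#
#     with open('config/inference_config/models/ace_0.6b_512.yaml') as f:
#         cfg = yaml.safe_load(f)
#
#     paras = cfg['DEFAULT_PARAS']['PARAS']
#     print(paras['SAMPLER'], paras['SAMPLE_STEPS'], paras['GUIDE_SCALE'])
#     # -> ddim 20 4.5
# -----------------------------------------------------------------------------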
================================================
FILE: config/train_config/ace_0.6b_1024_train.yaml
================================================
ENV:
BACKEND: nccl
SEED: 2024
#
SOLVER:
NAME: ACESolverV1
RESUME_FROM:
LOAD_MODEL_ONLY: True
USE_FSDP: False
SHARDING_STRATEGY:
USE_AMP: True
DTYPE: float16
CHANNELS_LAST: True
MAX_STEPS: 500
MAX_EPOCHS: -1
NUM_FOLDS: 1
ACCU_STEP: 1
EVAL_INTERVAL: 50
RESCALE_LR: False
#
WORK_DIR: ./cache/exp/exp1
LOG_FILE: std_log.txt
#
FILE_SYSTEM:
- NAME: "HuggingfaceFs"
TEMP_DIR: ./cache
- NAME: "ModelscopeFs"
TEMP_DIR: ./cache
- NAME: "LocalFs"
TEMP_DIR: ./cache
- NAME: "HttpFs"
TEMP_DIR: ./cache
#
MODEL:
NAME: LdmACE
PRETRAINED_MODEL:
IGNORE_KEYS: [ ]
SCALE_FACTOR: 0.18215
SIZE_FACTOR: 8
DECODER_BIAS: 0.5
DEFAULT_N_PROMPT:
USE_EMA: True
EVAL_EMA: False
TEXT_IDENTIFIER: [ '{image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]
USE_TEXT_POS_EMBEDDINGS: True
#
DIFFUSION:
NAME: ACEDiffusion
PREDICTION_TYPE: eps
MIN_SNR_GAMMA:
NOISE_SCHEDULER:
NAME: LinearScheduler
NUM_TIMESTEPS: 1000
BETA_MIN: 0.0001
BETA_MAX: 0.02
#
DIFFUSION_MODEL:
NAME: DiTACE
PRETRAINED_MODEL: ms://iic/ACE-0.6B-512px@models/dit/ace_0.6b_512px.pth
IGNORE_KEYS: [ ]
PATCH_SIZE: 2
IN_CHANNELS: 4
HIDDEN_SIZE: 1152
DEPTH: 28
NUM_HEADS: 16
MLP_RATIO: 4.0
PRED_SIGMA: True
DROP_PATH: 0.0
      WINDOW_SIZE: 0
Y_CHANNELS: 4096
MAX_SEQ_LEN: 4096
QK_NORM: True
USE_GRAD_CHECKPOINT: True
ATTENTION_BACKEND: flash_attn
#
FIRST_STAGE_MODEL:
NAME: AutoencoderKL
EMBED_DIM: 4
PRETRAINED_MODEL: ms://iic/ACE-0.6B-1024px@models/vae/vae.bin
IGNORE_KEYS: []
#
ENCODER:
NAME: Encoder
CH: 128
OUT_CH: 3
NUM_RES_BLOCKS: 2
IN_CHANNELS: 3
ATTN_RESOLUTIONS: [ ]
CH_MULT: [ 1, 2, 4, 4 ]
Z_CHANNELS: 4
DOUBLE_Z: True
DROPOUT: 0.0
RESAMP_WITH_CONV: True
#
DECODER:
NAME: Decoder
CH: 128
OUT_CH: 3
NUM_RES_BLOCKS: 2
IN_CHANNELS: 3
ATTN_RESOLUTIONS: [ ]
CH_MULT: [ 1, 2, 4, 4 ]
Z_CHANNELS: 4
DROPOUT: 0.0
RESAMP_WITH_CONV: True
GIVE_PRE_END: False
TANH_OUT: False
#
COND_STAGE_MODEL:
NAME: T5EmbedderHF
PRETRAINED_MODEL: ms://iic/ACE-0.6B-1024px@models/text_encoder/t5-v1_1-xxl/
TOKENIZER_PATH: ms://iic/ACE-0.6B-1024px@models/tokenizer/t5-v1_1-xxl
LENGTH: 120
T5_DTYPE: bfloat16
ADDED_IDENTIFIER: [ '{image}', '{caption}', '{mask}', '{ref_image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]
CLEAN: whitespace
USE_GRAD: False
LOSS:
NAME: ReconstructLoss
LOSS_TYPE: l2
#
SAMPLE_ARGS:
SAMPLER: ddim
SAMPLE_STEPS: 20
GUIDE_SCALE: 4.5
GUIDE_RESCALE: 0.5
#
OPTIMIZER:
NAME: AdamW
LEARNING_RATE: 1e-5
EPS: 1e-10
WEIGHT_DECAY: 5e-4
#
TRAIN_DATA:
NAME: ACEDemoDataset
MODE: train
MS_DATASET_NAME: cache/datasets/hed_pair
MS_DATASET_NAMESPACE: ""
MS_DATASET_SPLIT: "train"
MS_DATASET_SUBNAME: ""
PROMPT_PREFIX: ""
REPLACE_STYLE: False
MAX_SEQ_LEN: 4096
PIN_MEMORY: True
BATCH_SIZE: 1
NUM_WORKERS: 1
SAMPLER:
NAME: LoopSampler
#
TRAIN_HOOKS:
-
NAME: BackwardHook
PRIORITY: 0
-
NAME: LogHook
LOG_INTERVAL: 50
-
NAME: CheckpointHook
INTERVAL: 100
-
NAME: ProbeDataHook
PROB_INTERVAL: 100
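# --- Editorial sketch (not repository code) ---------------------------------
# The NOISE_SCHEDULER block above fully determines the training noise levels.
# A minimal sketch of what a linear beta ramp with these numbers implies,
# assuming the conventional DDPM relation sigma_t^2 = 1 - prod_i(1 - beta_i)
# (the repository's exact formula is in modules/model/diffusion/schedules.py):
#
#     import torch
#
#     betas = torch.linspace(0.0001, 0.02, 1000)  # BETA_MIN..BETA_MAX, NUM_TIMESTEPS
#     sigmas = torch.sqrt(1 - torch.cumprod(1 - betas, dim=0))
#     print(float(sigmas[0]), float(sigmas[-1]))  # ~0.01 .. ~1.0 (pure noise)
# -----------------------------------------------------------------------------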
================================================
FILE: config/train_config/ace_0.6b_512_train.yaml
================================================
ENV:
BACKEND: nccl
SEED: 2024
#
SOLVER:
NAME: ACESolverV1
RESUME_FROM:
LOAD_MODEL_ONLY: True
USE_FSDP: False
SHARDING_STRATEGY:
USE_AMP: True
DTYPE: float16
CHANNELS_LAST: True
MAX_STEPS: 500
MAX_EPOCHS: -1
NUM_FOLDS: 1
ACCU_STEP: 1
EVAL_INTERVAL: 50
RESCALE_LR: False
#
WORK_DIR: ./cache/exp/exp1
LOG_FILE: std_log.txt
#
FILE_SYSTEM:
- NAME: "HuggingfaceFs"
TEMP_DIR: ./cache
- NAME: "ModelscopeFs"
TEMP_DIR: ./cache
- NAME: "LocalFs"
TEMP_DIR: ./cache
- NAME: "HttpFs"
TEMP_DIR: ./cache
#
MODEL:
NAME: LdmACE
PRETRAINED_MODEL:
IGNORE_KEYS: [ ]
SCALE_FACTOR: 0.18215
SIZE_FACTOR: 8
DECODER_BIAS: 0.5
DEFAULT_N_PROMPT:
USE_EMA: True
EVAL_EMA: False
TEXT_IDENTIFIER: [ '{image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]
USE_TEXT_POS_EMBEDDINGS: True
#
DIFFUSION:
NAME: ACEDiffusion
PREDICTION_TYPE: eps
MIN_SNR_GAMMA:
NOISE_SCHEDULER:
NAME: LinearScheduler
NUM_TIMESTEPS: 1000
BETA_MIN: 0.0001
BETA_MAX: 0.02
#
DIFFUSION_MODEL:
NAME: DiTACE
PRETRAINED_MODEL: ms://iic/ACE-0.6B-512px@models/dit/ace_0.6b_512px.pth
IGNORE_KEYS: [ ]
PATCH_SIZE: 2
IN_CHANNELS: 4
HIDDEN_SIZE: 1152
DEPTH: 28
NUM_HEADS: 16
MLP_RATIO: 4.0
PRED_SIGMA: True
DROP_PATH: 0.0
      WINDOW_SIZE: 0
Y_CHANNELS: 4096
MAX_SEQ_LEN: 1024
QK_NORM: True
USE_GRAD_CHECKPOINT: True
ATTENTION_BACKEND: flash_attn
#
FIRST_STAGE_MODEL:
NAME: AutoencoderKL
EMBED_DIM: 4
PRETRAINED_MODEL: ms://iic/ACE-0.6B-512px@models/vae/vae.bin
IGNORE_KEYS: []
#
ENCODER:
NAME: Encoder
CH: 128
OUT_CH: 3
NUM_RES_BLOCKS: 2
IN_CHANNELS: 3
ATTN_RESOLUTIONS: [ ]
CH_MULT: [ 1, 2, 4, 4 ]
Z_CHANNELS: 4
DOUBLE_Z: True
DROPOUT: 0.0
RESAMP_WITH_CONV: True
#
DECODER:
NAME: Decoder
CH: 128
OUT_CH: 3
NUM_RES_BLOCKS: 2
IN_CHANNELS: 3
ATTN_RESOLUTIONS: [ ]
CH_MULT: [ 1, 2, 4, 4 ]
Z_CHANNELS: 4
DROPOUT: 0.0
RESAMP_WITH_CONV: True
GIVE_PRE_END: False
TANH_OUT: False
#
COND_STAGE_MODEL:
NAME: ACETextEmbedder
PRETRAINED_MODEL: ms://iic/ACE-0.6B-512px@models/text_encoder/t5-v1_1-xxl/
TOKENIZER_PATH: ms://iic/ACE-0.6B-512px@models/tokenizer/t5-v1_1-xxl
LENGTH: 120
T5_DTYPE: bfloat16
ADDED_IDENTIFIER: [ '{image}', '{caption}', '{mask}', '{ref_image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]
CLEAN: whitespace
USE_GRAD: False
LOSS:
NAME: ReconstructLoss
LOSS_TYPE: l2
#
SAMPLE_ARGS:
SAMPLER: ddim
SAMPLE_STEPS: 20
GUIDE_SCALE: 4.5
GUIDE_RESCALE: 0.5
#
OPTIMIZER:
NAME: AdamW
LEARNING_RATE: 1e-5
EPS: 1e-10
WEIGHT_DECAY: 5e-4
#
TRAIN_DATA:
NAME: ACEDemoDataset
MODE: train
MS_DATASET_NAME: cache/datasets/hed_pair
MS_DATASET_NAMESPACE: ""
MS_DATASET_SPLIT: "train"
MS_DATASET_SUBNAME: ""
PROMPT_PREFIX: ""
REPLACE_STYLE: False
MAX_SEQ_LEN: 1024
PIN_MEMORY: True
BATCH_SIZE: 1
NUM_WORKERS: 1
SAMPLER:
NAME: LoopSampler
#
TRAIN_HOOKS:
-
NAME: BackwardHook
PRIORITY: 0
-
NAME: LogHook
LOG_INTERVAL: 50
-
NAME: CheckpointHook
INTERVAL: 100
-
NAME: ProbeDataHook
PROB_INTERVAL: 100
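# --- Editorial sketch (not repository code) ---------------------------------
# The OPTIMIZER block above maps one-to-one onto PyTorch's AdamW; a minimal
# sketch with a stand-in module (the solver builds the real optimizer from
# this config):
#
#     import torch
#
#     model = torch.nn.Linear(8, 8)  # stand-in for the trainable model
#     optimizer = torch.optim.AdamW(model.parameters(),
#                                   lr=1e-5,            # LEARNING_RATE
#                                   eps=1e-10,          # EPS
#                                   weight_decay=5e-4)  # WEIGHT_DECAY
# -----------------------------------------------------------------------------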
================================================
FILE: modules/__init__.py
================================================
from . import data, model, solver
================================================
FILE: modules/data/__init__.py
================================================
from . import dataset
================================================
FILE: modules/data/dataset/__init__.py
================================================
from .dataset import ACEDemoDataset
================================================
FILE: modules/data/dataset/dataset.py
================================================
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import io
import math
import os
import sys
from collections import defaultdict
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from scepter.modules.data.dataset.base_dataset import BaseDataset
from scepter.modules.data.dataset.registry import DATASETS
from scepter.modules.transform.io import pillow_convert
from scepter.modules.utils.config import dict_to_yaml
from scepter.modules.utils.file_system import FS
Image.MAX_IMAGE_PIXELS = None
@DATASETS.register_class()
class ACEDemoDataset(BaseDataset):
para_dict = {
'MS_DATASET_NAME': {
'value': '',
'description': 'Modelscope dataset name.'
},
'MS_DATASET_NAMESPACE': {
'value': '',
'description': 'Modelscope dataset namespace.'
},
'MS_DATASET_SUBNAME': {
'value': '',
'description': 'Modelscope dataset subname.'
},
'MS_DATASET_SPLIT': {
'value': '',
'description':
'Modelscope dataset split set name, default is train.'
},
'MS_REMAP_KEYS': {
'value':
None,
'description':
            'Modelscope dataset header of the list file; the default is Target:FILE. '
            'If your file does not use this header, set this field as a mapping dict. '
            "For example, { 'Image:FILE': 'Target:FILE' } remaps the field Image:FILE to Target:FILE."
},
'MS_REMAP_PATH': {
'value':
None,
'description':
            'When the modelscope dataset name is not None, the dataset is loaded from '
            'modelscope; the default is None. If you want to use the data list from '
            'modelscope with files stored on the local device, use this field to set '
            'the root path of your images. '
},
'TRIGGER_WORDS': {
'value':
'',
'description':
            'The words used to describe the common features of your data, especially '
            'when you customize a tuner. Use these words to get what you want.'
},
'HIGHLIGHT_KEYWORDS': {
'value':
'',
'description':
            'The keywords you want to highlight in the prompt, which will be replaced by <HIGHLIGHT_KEYWORDS>.'
},
'KEYWORDS_SIGN': {
'value':
'',
'description':
            'The keyword sign you want to append, which produces a tag like <{HIGHLIGHT_KEYWORDS}{KEYWORDS_SIGN}>.'
},
}
def __init__(self, cfg, logger=None):
super().__init__(cfg=cfg, logger=logger)
from modelscope import MsDataset
from modelscope.utils.constant import DownloadMode
ms_dataset_name = cfg.get('MS_DATASET_NAME', None)
ms_dataset_namespace = cfg.get('MS_DATASET_NAMESPACE', None)
ms_dataset_subname = cfg.get('MS_DATASET_SUBNAME', None)
ms_dataset_split = cfg.get('MS_DATASET_SPLIT', 'train')
ms_remap_keys = cfg.get('MS_REMAP_KEYS', None)
ms_remap_path = cfg.get('MS_REMAP_PATH', None)
self.max_seq_len = cfg.get('MAX_SEQ_LEN', 1024)
self.max_aspect_ratio = cfg.get('MAX_ASPECT_RATIO', 4)
self.d = cfg.get('DOWNSAMPLE_RATIO', 16)
self.replace_style = cfg.get('REPLACE_STYLE', False)
self.trigger_words = cfg.get('TRIGGER_WORDS', '')
self.replace_keywords = cfg.get('HIGHLIGHT_KEYWORDS', '')
self.keywords_sign = cfg.get('KEYWORDS_SIGN', '')
self.add_indicator = cfg.get('ADD_INDICATOR', False)
# Use modelscope dataset
if not ms_dataset_name:
raise ValueError(
                'You must set MS_DATASET_NAME to a modelscope dataset or a local dataset '
                'organized in modelscope format.')
if FS.exists(ms_dataset_name):
ms_dataset_name = FS.get_dir_to_local_dir(ms_dataset_name)
self.ms_dataset_name = ms_dataset_name
# ms_remap_path = ms_dataset_name
try:
self.data = MsDataset.load(str(ms_dataset_name),
namespace=ms_dataset_namespace,
subset_name=ms_dataset_subname,
split=ms_dataset_split)
except Exception:
self.logger.info(
"Load Modelscope dataset failed, retry with download_mode='force_redownload'."
)
try:
self.data = MsDataset.load(
str(ms_dataset_name),
namespace=ms_dataset_namespace,
subset_name=ms_dataset_subname,
split=ms_dataset_split,
download_mode=DownloadMode.FORCE_REDOWNLOAD)
except Exception as sec_e:
                raise ValueError(f'Loading the Modelscope dataset failed: {sec_e}')
if ms_remap_keys:
self.data = self.data.remap_columns(ms_remap_keys.get_dict())
if ms_remap_path:
def map_func(example):
return {
k: os.path.join(ms_remap_path, v)
if k.endswith(':FILE') else v
for k, v in example.items()
}
self.data = self.data.ds_instance.map(map_func)
self.transforms = T.Compose([
T.ToTensor(),
T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
def __len__(self):
if self.mode == 'train':
return sys.maxsize
else:
return len(self.data)
def _get(self, index: int):
current_data = self.data[index % len(self.data)]
tar_image_path = current_data.get('Target:FILE', '')
src_image_path = current_data.get('Source:FILE', '')
style = current_data.get('Style', '')
prompt = current_data.get('Prompt', current_data.get('prompt', ''))
if self.replace_style and not style == '':
prompt = prompt.replace(style, f'<{self.keywords_sign}>')
elif not self.replace_keywords.strip() == '':
prompt = prompt.replace(
self.replace_keywords,
'<' + self.replace_keywords + f'{self.keywords_sign}>')
if not self.trigger_words == '':
prompt = self.trigger_words.strip() + ' ' + prompt
src_image = self.load_image(self.ms_dataset_name,
src_image_path,
cvt_type='RGB')
tar_image = self.load_image(self.ms_dataset_name,
tar_image_path,
cvt_type='RGB')
src_image = self.image_preprocess(src_image)
tar_image = self.image_preprocess(tar_image)
tar_image = self.transforms(tar_image)
src_image = self.transforms(src_image)
src_mask = torch.ones_like(src_image[[0]])
tar_mask = torch.ones_like(tar_image[[0]])
if self.add_indicator:
if '{image}' not in prompt:
prompt = '{image}, ' + prompt
return {
'edit_image': [src_image],
'edit_image_mask': [src_mask],
'image': tar_image,
'image_mask': tar_mask,
'prompt': [prompt],
}
def load_image(self, prefix, img_path, cvt_type=None):
if img_path is None or img_path == '':
return None
img_path = os.path.join(prefix, img_path)
with FS.get_object(img_path) as image_bytes:
image = Image.open(io.BytesIO(image_bytes))
if cvt_type is not None:
image = pillow_convert(image, cvt_type)
return image
def image_preprocess(self,
img,
size=None,
interpolation=InterpolationMode.BILINEAR):
H, W = img.height, img.width
if H / W > self.max_aspect_ratio:
img = T.CenterCrop((self.max_aspect_ratio * W, W))(img)
elif W / H > self.max_aspect_ratio:
img = T.CenterCrop((H, self.max_aspect_ratio * H))(img)
if size is None:
            # resize the image to fit max_seq_len while keeping the aspect ratio
H, W = img.height, img.width
scale = min(
1.0,
math.sqrt(self.max_seq_len / ((H / self.d) * (W / self.d))))
rH = int(
H * scale) // self.d * self.d # ensure divisible by self.d
rW = int(W * scale) // self.d * self.d
else:
rH, rW = size
img = T.Resize((rH, rW), interpolation=interpolation,
antialias=True)(img)
return np.array(img, dtype=np.uint8)
@staticmethod
def get_config_template():
return dict_to_yaml('DATASet',
__class__.__name__,
ACEDemoDataset.para_dict,
set_name=True)
@staticmethod
def collate_fn(batch):
collect = defaultdict(list)
for sample in batch:
for k, v in sample.items():
collect[k].append(v)
new_batch = dict()
for k, v in collect.items():
if all([i is None for i in v]):
new_batch[k] = None
else:
new_batch[k] = v
return new_batch
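# --- Editorial sketch (not repository code) ---------------------------------
# The sizing rule in ACEDemoDataset.image_preprocess keeps the token count
# under MAX_SEQ_LEN while preserving the aspect ratio and divisibility by the
# downsample ratio. A minimal worked example for a hypothetical 1024x768 input
# with the defaults MAX_SEQ_LEN=1024 and d=16:
#
#     import math
#
#     H, W, d, max_seq_len = 1024, 768, 16, 1024
#     scale = min(1.0, math.sqrt(max_seq_len / ((H / d) * (W / d))))
#     rH = int(H * scale) // d * d
#     rW = int(W * scale) // d * d
#     print(rH, rW, (rH // d) * (rW // d))  # 576 432 972  (972 <= 1024 tokens)
# -----------------------------------------------------------------------------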
================================================
FILE: modules/inference/__init__.py
================================================
================================================
FILE: modules/model/__init__.py
================================================
from . import backbone, embedder, diffusion, network
================================================
FILE: modules/model/backbone/__init__.py
================================================
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
from .ace import DiTACE
================================================
FILE: modules/model/backbone/ace.py
================================================
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import re
from collections import OrderedDict
from functools import partial
import torch
import torch.nn as nn
from einops import rearrange
from torch.nn.utils.rnn import pad_sequence
from torch.utils.checkpoint import checkpoint_sequential
from scepter.modules.model.base_model import BaseModel
from scepter.modules.model.registry import BACKBONES
from scepter.modules.utils.config import dict_to_yaml
from scepter.modules.utils.file_system import FS
from .layers import (
Mlp,
TimestepEmbedder,
PatchEmbed,
DiTACEBlock,
T2IFinalLayer
)
from .pos_embed import rope_params
@BACKBONES.register_class()
class DiTACE(BaseModel):
para_dict = {
'PATCH_SIZE': {
'value': 2,
'description': ''
},
'IN_CHANNELS': {
'value': 4,
'description': ''
},
'HIDDEN_SIZE': {
'value': 1152,
'description': ''
},
'DEPTH': {
'value': 28,
'description': ''
},
'NUM_HEADS': {
'value': 16,
'description': ''
},
'MLP_RATIO': {
'value': 4.0,
'description': ''
},
'PRED_SIGMA': {
'value': True,
'description': ''
},
'DROP_PATH': {
'value': 0.,
'description': ''
},
'WINDOW_SIZE': {
'value': 0,
'description': ''
},
'WINDOW_BLOCK_INDEXES': {
'value': None,
'description': ''
},
'Y_CHANNELS': {
'value': 4096,
'description': ''
},
'ATTENTION_BACKEND': {
'value': None,
'description': ''
},
'QK_NORM': {
'value': True,
'description': 'Whether to use RMSNorm for query and key.',
},
}
para_dict.update(BaseModel.para_dict)
def __init__(self, cfg, logger):
super().__init__(cfg, logger=logger)
self.window_block_indexes = cfg.get('WINDOW_BLOCK_INDEXES', None)
if self.window_block_indexes is None:
self.window_block_indexes = []
self.pred_sigma = cfg.get('PRED_SIGMA', True)
self.in_channels = cfg.get('IN_CHANNELS', 4)
self.out_channels = self.in_channels * 2 if self.pred_sigma else self.in_channels
self.patch_size = cfg.get('PATCH_SIZE', 2)
self.num_heads = cfg.get('NUM_HEADS', 16)
self.hidden_size = cfg.get('HIDDEN_SIZE', 1152)
self.y_channels = cfg.get('Y_CHANNELS', 4096)
self.drop_path = cfg.get('DROP_PATH', 0.)
self.depth = cfg.get('DEPTH', 28)
self.mlp_ratio = cfg.get('MLP_RATIO', 4.0)
self.use_grad_checkpoint = cfg.get('USE_GRAD_CHECKPOINT', False)
self.attention_backend = cfg.get('ATTENTION_BACKEND', None)
        self.max_seq_len = cfg.get('MAX_SEQ_LEN', 1024)
        self.window_size = cfg.get('WINDOW_SIZE', 0)  # referenced when building blocks below
        self.qk_norm = cfg.get('QK_NORM', False)
self.ignore_keys = cfg.get('IGNORE_KEYS', [])
assert (self.hidden_size % self.num_heads
) == 0 and (self.hidden_size // self.num_heads) % 2 == 0
d = self.hidden_size // self.num_heads
self.freqs = torch.cat(
[
rope_params(self.max_seq_len, d - 4 * (d // 6)), # T (~1/3)
rope_params(self.max_seq_len, 2 * (d // 6)), # H (~1/3)
rope_params(self.max_seq_len, 2 * (d // 6)) # W (~1/3)
],
dim=1)
# init embedder
self.x_embedder = PatchEmbed(self.patch_size,
self.in_channels + 1,
self.hidden_size,
bias=True,
flatten=False)
self.t_embedder = TimestepEmbedder(self.hidden_size)
self.y_embedder = Mlp(in_features=self.y_channels,
hidden_features=self.hidden_size,
out_features=self.hidden_size,
act_layer=lambda: nn.GELU(approximate='tanh'),
drop=0)
self.t_block = nn.Sequential(
nn.SiLU(),
nn.Linear(self.hidden_size, 6 * self.hidden_size, bias=True))
# init blocks
drop_path = [
x.item() for x in torch.linspace(0, self.drop_path, self.depth)
]
self.blocks = nn.ModuleList([
DiTACEBlock(self.hidden_size,
self.num_heads,
mlp_ratio=self.mlp_ratio,
drop_path=drop_path[i],
window_size=self.window_size
if i in self.window_block_indexes else 0,
backend=self.attention_backend,
use_condition=True,
qk_norm=self.qk_norm) for i in range(self.depth)
])
self.final_layer = T2IFinalLayer(self.hidden_size, self.patch_size,
self.out_channels)
self.initialize_weights()
def load_pretrained_model(self, pretrained_model):
if pretrained_model:
with FS.get_from(pretrained_model, wait_finish=True) as local_path:
model = torch.load(local_path, map_location='cpu')
if 'state_dict' in model:
model = model['state_dict']
new_ckpt = OrderedDict()
for k, v in model.items():
if self.ignore_keys is not None:
if (isinstance(self.ignore_keys, str) and re.match(self.ignore_keys, k)) or \
(isinstance(self.ignore_keys, list) and k in self.ignore_keys):
continue
k = k.replace('.cross_attn.q_linear.', '.cross_attn.q.')
k = k.replace('.cross_attn.proj.',
'.cross_attn.o.').replace(
'.attn.proj.', '.attn.o.')
if '.cross_attn.kv_linear.' in k:
k_p, v_p = torch.split(v, v.shape[0] // 2)
new_ckpt[k.replace('.cross_attn.kv_linear.',
'.cross_attn.k.')] = k_p
new_ckpt[k.replace('.cross_attn.kv_linear.',
'.cross_attn.v.')] = v_p
elif '.attn.qkv.' in k:
q_p, k_p, v_p = torch.split(v, v.shape[0] // 3)
new_ckpt[k.replace('.attn.qkv.', '.attn.q.')] = q_p
new_ckpt[k.replace('.attn.qkv.', '.attn.k.')] = k_p
new_ckpt[k.replace('.attn.qkv.', '.attn.v.')] = v_p
elif 'y_embedder.y_proj.' in k:
new_ckpt[k.replace('y_embedder.y_proj.',
'y_embedder.')] = v
                    elif k == 'x_embedder.proj.weight':  # equality, not substring membership
model_p = self.state_dict()[k]
if v.shape != model_p.shape:
model_p.zero_()
model_p[:, :4, :, :].copy_(v)
new_ckpt[k] = torch.nn.parameter.Parameter(model_p)
else:
new_ckpt[k] = v
                    elif k == 'x_embedder.proj.bias':
new_ckpt[k] = v
else:
new_ckpt[k] = v
missing, unexpected = self.load_state_dict(new_ckpt,
strict=False)
print(
f'Restored from {pretrained_model} with {len(missing)} missing and {len(unexpected)} unexpected keys'
)
if len(missing) > 0:
print(f'Missing Keys:\n {missing}')
if len(unexpected) > 0:
print(f'\nUnexpected Keys:\n {unexpected}')
def forward(self,
x,
t=None,
cond=dict(),
mask=None,
text_position_embeddings=None,
gc_seg=-1,
**kwargs):
if self.freqs.device != x.device:
self.freqs = self.freqs.to(x.device)
if isinstance(cond, dict):
context = cond.get('crossattn', None)
else:
context = cond
if text_position_embeddings is not None:
            # By default, use the text_position_embeddings from the state_dict;
            # if the state_dict doesn't include this key, fall back to the
            # text_position_embeddings argument.
proj_position_embeddings = self.y_embedder(
text_position_embeddings)
else:
proj_position_embeddings = None
ctx_batch, txt_lens = [], []
if mask is not None and isinstance(mask, list):
for ctx, ctx_mask in zip(context, mask):
for frame_id, one_ctx in enumerate(zip(ctx, ctx_mask)):
u, m = one_ctx
t_len = m.flatten().sum() # l
u = u[:t_len]
u = self.y_embedder(u)
if frame_id == 0:
u = u + proj_position_embeddings[
len(ctx) -
1] if proj_position_embeddings is not None else u
else:
u = u + proj_position_embeddings[
frame_id -
1] if proj_position_embeddings is not None else u
ctx_batch.append(u)
txt_lens.append(t_len)
else:
raise TypeError
y = torch.cat(ctx_batch, dim=0)
txt_lens = torch.LongTensor(txt_lens).to(x.device, non_blocking=True)
batch_frames = []
for u, shape, m in zip(x, cond['x_shapes'], cond['x_mask']):
u = u[:, :shape[0] * shape[1]].view(-1, shape[0], shape[1])
m = torch.ones_like(u[[0], :, :]) if m is None else m.squeeze(0)
batch_frames.append([torch.cat([u, m], dim=0).unsqueeze(0)])
if 'edit' in cond:
for i, (edit, edit_mask) in enumerate(
zip(cond['edit'], cond['edit_mask'])):
if edit is None:
continue
for u, m in zip(edit, edit_mask):
u = u.squeeze(0)
m = torch.ones_like(
u[[0], :, :]) if m is None else m.squeeze(0)
batch_frames[i].append(
torch.cat([u, m], dim=0).unsqueeze(0))
patch_batch, shape_batch, self_x_len, cross_x_len = [], [], [], []
for frames in batch_frames:
patches, patch_shapes = [], []
self_x_len.append(0)
for frame_id, u in enumerate(frames):
u = self.x_embedder(u)
h, w = u.size(2), u.size(3)
u = rearrange(u, '1 c h w -> (h w) c')
if frame_id == 0:
u = u + proj_position_embeddings[
len(frames) -
1] if proj_position_embeddings is not None else u
else:
u = u + proj_position_embeddings[
frame_id -
1] if proj_position_embeddings is not None else u
patches.append(u)
patch_shapes.append([h, w])
cross_x_len.append(h * w) # b*s, 1
self_x_len[-1] += h * w # b, 1
# u = torch.cat(patches, dim=0)
patch_batch.extend(patches)
shape_batch.append(
torch.LongTensor(patch_shapes).to(x.device, non_blocking=True))
# repeat t to align with x
t = torch.cat([t[i].repeat(l) for i, l in enumerate(self_x_len)])
self_x_len, cross_x_len = (torch.LongTensor(self_x_len).to(
x.device, non_blocking=True), torch.LongTensor(cross_x_len).to(
x.device, non_blocking=True))
# x = pad_sequence(tuple(patch_batch), batch_first=True) # b, s*max(cl), c
x = torch.cat(patch_batch, dim=0)
x_shapes = pad_sequence(tuple(shape_batch),
batch_first=True) # b, max(len(frames)), 2
t = self.t_embedder(t) # (N, D)
t0 = self.t_block(t)
# y = self.y_embedder(context)
kwargs = dict(y=y,
t=t0,
x_shapes=x_shapes,
self_x_len=self_x_len,
cross_x_len=cross_x_len,
freqs=self.freqs,
txt_lens=txt_lens)
if self.use_grad_checkpoint and gc_seg >= 0:
x = checkpoint_sequential(
functions=[partial(block, **kwargs) for block in self.blocks],
segments=gc_seg if gc_seg > 0 else len(self.blocks),
input=x,
use_reentrant=False)
else:
for block in self.blocks:
x = block(x, **kwargs)
x = self.final_layer(x, t) # b*s*n, d
outs, cur_length = [], 0
p = self.patch_size
for seq_length, shape in zip(self_x_len, shape_batch):
x_i = x[cur_length:cur_length + seq_length]
h, w = shape[0].tolist()
u = x_i[:h * w].view(h, w, p, p, -1)
u = rearrange(u, 'h w p q c -> (h p w q) c'
) # dump into sequence for following tensor ops
cur_length = cur_length + seq_length
outs.append(u)
x = pad_sequence(tuple(outs), batch_first=True).permute(0, 2, 1)
if self.pred_sigma:
return x.chunk(2, dim=1)[0]
else:
return x
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
nn.init.normal_(self.t_block[1].weight, std=0.02)
# Initialize caption embedding MLP:
if hasattr(self, 'y_embedder'):
nn.init.normal_(self.y_embedder.fc1.weight, std=0.02)
nn.init.normal_(self.y_embedder.fc2.weight, std=0.02)
# Zero-out adaLN modulation layers
for block in self.blocks:
nn.init.constant_(block.cross_attn.o.weight, 0)
nn.init.constant_(block.cross_attn.o.bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
@property
def dtype(self):
return next(self.parameters()).dtype
@staticmethod
def get_config_template():
return dict_to_yaml('BACKBONE',
__class__.__name__,
DiTACE.para_dict,
set_name=True)
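# --- Editorial sketch (not repository code) ---------------------------------
# The torch.cat of rope_params in DiTACE.__init__ splits each attention head's
# dimension across temporal, height, and width rotary components. A minimal
# check of the arithmetic for the configured HIDDEN_SIZE=1152, NUM_HEADS=16:
#
#     d = 1152 // 16            # per-head dim = 72
#     t_dim = d - 4 * (d // 6)  # 72 - 48 = 24 (temporal, ~1/3)
#     h_dim = 2 * (d // 6)      # 24 (height, ~1/3)
#     w_dim = 2 * (d // 6)      # 24 (width, ~1/3)
#     assert t_dim + h_dim + w_dim == d  # the three components tile the head dim
# -----------------------------------------------------------------------------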
================================================
FILE: modules/model/backbone/layers.py
================================================
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import warnings
import torch
import torch.nn as nn
from .pos_embed import rope_apply_multires as rope_apply
try:
from flash_attn import (flash_attn_varlen_func)
FLASHATTN_IS_AVAILABLE = True
except ImportError as e:
FLASHATTN_IS_AVAILABLE = False
flash_attn_varlen_func = None
warnings.warn(f'{e}')
__all__ = [
"drop_path",
"modulate",
"PatchEmbed",
"DropPath",
"RMSNorm",
"Mlp",
"TimestepEmbedder",
"DiTEditBlock",
"MultiHeadAttentionDiTEdit",
"T2IFinalLayer",
]
def drop_path(x, drop_prob: float = 0., training: bool = False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
    shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)  # broadcast mask over non-batch dims
    random_tensor = keep_prob + torch.rand(
        shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize to a per-sample keep/drop mask
    output = x.div(keep_prob) * random_tensor
    return output
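# Minimal usage sketch (editorial addition, not repository code): with
# drop_prob=0.5 each kept sample is rescaled by 1/keep_prob = 2, so the
# output matches the input in expectation.
if __name__ == '__main__':
    _x = torch.ones(4, 3)
    _out = drop_path(_x, drop_prob=0.5, training=True)
    print(_out)  # each row is either all 0.0 or all 2.0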
SYMBOL INDEX (152 symbols across 18 files)
FILE: chatbot/ace_inference.py
function process_edit_image (line 26) | def process_edit_image(images,
class TextEmbedding (line 83) | class TextEmbedding(nn.Module):
method __init__ (line 84) | def __init__(self, embedding_shape):
class RefinerInference (line 88) | class RefinerInference(DiffusionInference):
method init_from_cfg (line 89) | def init_from_cfg(self, cfg):
method encode_first_stage (line 101) | def encode_first_stage(self, x, **kwargs):
method upscale_resize (line 113) | def upscale_resize(self, image, interpolation=T.InterpolationMode.BILI...
method decode_first_stage (line 121) | def decode_first_stage(self, z):
method noise_sample (line 128) | def noise_sample(self, num_samples, h, w, seed, device = None, dtype =...
method refine (line 140) | def refine(self,
class ACEInference (line 221) | class ACEInference(DiffusionInference):
method __init__ (line 222) | def __init__(self, logger=None):
method init_from_cfg (line 231) | def init_from_cfg(self, cfg):
method encode_first_stage (line 292) | def encode_first_stage(self, x, **kwargs):
method decode_first_stage (line 304) | def decode_first_stage(self, z):
method __call__ (line 319) | def __call__(self,
method cond_stage_embeddings (line 534) | def cond_stage_embeddings(self, prompt, edit_image, cont, cont_mask):
FILE: chatbot/example.py
function download_image (line 9) | def download_image(image, local_path=None):
function blank_image (line 14) | def blank_image():
function get_examples (line 19) | def get_examples(cache_dir):
FILE: chatbot/infer.py
function process_edit_image (line 26) | def process_edit_image(images,
class TextEmbedding (line 83) | class TextEmbedding(nn.Module):
method __init__ (line 84) | def __init__(self, embedding_shape):
class ACEInference (line 89) | class ACEInference(DiffusionInference):
method __init__ (line 90) | def __init__(self, logger=None):
method init_from_cfg (line 99) | def init_from_cfg(self, cfg):
method encode_first_stage (line 146) | def encode_first_stage(self, x, **kwargs):
method decode_first_stage (line 158) | def decode_first_stage(self, z):
method __call__ (line 171) | def __call__(self,
method cond_stage_embeddings (line 346) | def cond_stage_embeddings(self, prompt, edit_image, cont, cont_mask):
FILE: chatbot/run_gradio.py
class ChatBotUI (line 47) | class ChatBotUI(object):
method __init__ (line 48) | def __init__(self,
method create_ui (line 177) | def create_ui(self):
method set_callbacks (line 523) | def set_callbacks(self, *args, **kwargs):
method get_history (line 1274) | def get_history(self, history):
method generate_random_string (line 1289) | def generate_random_string(self, length=20):
method add_edited_image_to_history (line 1295) | def add_edited_image_to_history(self, image, mask_type, history, images):
method add_uploaded_image_to_history (line 1356) | def add_uploaded_image_to_history(self, img, history, images):
function run_gr (line 1401) | def run_gr(cfg):
FILE: chatbot/utils.py
function build_transform (line 12) | def build_transform(input_size):
function find_closest_aspect_ratio (line 24) | def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
function dynamic_preprocess (line 41) | def dynamic_preprocess(image,
function load_image (line 83) | def load_image(image_file, input_size=448, max_num=12):
FILE: modules/data/dataset/dataset.py
class ACEDemoDataset (line 25) | class ACEDemoDataset(BaseDataset):
method __init__ (line 81) | def __init__(self, cfg, logger=None):
method __len__ (line 146) | def __len__(self):
method _get (line 152) | def _get(self, index: int):
method load_image (line 196) | def load_image(self, prefix, img_path, cvt_type=None):
method image_preprocess (line 206) | def image_preprocess(self,
method get_config_template (line 232) | def get_config_template():
method collate_fn (line 239) | def collate_fn(batch):
FILE: modules/model/backbone/ace.py
class DiTACE (line 29) | class DiTACE(BaseModel):
method __init__ (line 87) | def __init__(self, cfg, logger):
method load_pretrained_model (line 152) | def load_pretrained_model(self, pretrained_model):
method forward (line 204) | def forward(self,
method initialize_weights (line 336) | def initialize_weights(self):
method dtype (line 365) | def dtype(self):
method get_config_template (line 369) | def get_config_template():
FILE: modules/model/backbone/layers.py
function drop_path (line 30) | def drop_path(x, drop_prob: float = 0., training: bool = False):
function modulate (line 50) | def modulate(x, shift, scale, unsqueeze=False):
class PatchEmbed (line 57) | class PatchEmbed(nn.Module):
method __init__ (line 60) | def __init__(
method forward (line 78) | def forward(self, x):
class DropPath (line 86) | class DropPath(nn.Module):
method __init__ (line 89) | def __init__(self, drop_prob=None):
method forward (line 93) | def forward(self, x):
class RMSNorm (line 97) | class RMSNorm(nn.Module):
method __init__ (line 98) | def __init__(self, dim, eps=1e-6):
method forward (line 104) | def forward(self, x):
method _norm (line 107) | def _norm(self, x):
class Mlp (line 111) | class Mlp(nn.Module):
method __init__ (line 114) | def __init__(self,
method forward (line 128) | def forward(self, x):
class TimestepEmbedder (line 137) | class TimestepEmbedder(nn.Module):
method __init__ (line 141) | def __init__(self, hidden_size, frequency_embedding_size=256):
method timestep_embedding (line 151) | def timestep_embedding(t, dim, max_period=10000):
method forward (line 173) | def forward(self, t):
class DiTACEBlock (line 179) | class DiTACEBlock(nn.Module):
method __init__ (line 180) | def __init__(self,
method forward (line 226) | def forward(self, x, y, t, **kwargs):
class MultiHeadAttention (line 244) | class MultiHeadAttention(nn.Module):
method __init__ (line 245) | def __init__(self,
method flash_attn (line 291) | def flash_attn(self, x, context=None, **kwargs):
method forward (line 359) | def forward(self, x, context=None, **kwargs):
class T2IFinalLayer (line 364) | class T2IFinalLayer(nn.Module):
method __init__ (line 368) | def __init__(self, hidden_size, patch_size, out_channels):
method forward (line 380) | def forward(self, x, t):
FILE: modules/model/backbone/pos_embed.py
function frame_pad (line 9) | def frame_pad(x, seq_len, shapes):
function frame_unpad (line 26) | def frame_unpad(x, shapes):
function rope_apply_multires (line 38) | def rope_apply_multires(x, x_lens, x_shapes, freqs, pad=True):
function rope_params (line 75) | def rope_params(max_seq_len, dim, theta=10000):
FILE: modules/model/diffusion/diffusions.py
class ACEDiffusion (line 18) | class ACEDiffusion(object):
method __init__ (line 29) | def __init__(self, cfg, logger=None):
method init_params (line 35) | def init_params(self):
method sample (line 53) | def sample(self,
method loss (line 133) | def loss(self,
method get_sampler (line 159) | def get_sampler(self, sampler):
method __repr__ (line 185) | def __repr__(self) -> str:
method get_config_template (line 189) | def get_config_template():
FILE: modules/model/diffusion/samplers.py
function _i (line 9) | def _i(tensor, t, x):
class DDIMSampler (line 20) | class DDIMSampler(BaseDiffusionSampler):
method init_params (line 21) | def init_params(self):
method preprare_sampler (line 27) | def preprare_sampler(self,
method step (line 60) | def step(self, sampler_output):
FILE: modules/model/diffusion/schedules.py
class ScheduleOutput (line 10) | class ScheduleOutput(object):
method add_custom_field (line 18) | def add_custom_field(self, key: str, value) -> None:
class LinearScheduler (line 23) | class LinearScheduler(BaseNoiseScheduler):
method init_params (line 26) | def init_params(self):
method betas_to_sigmas (line 31) | def betas_to_sigmas(self, betas):
method get_schedule (line 34) | def get_schedule(self):
method add_noise (line 46) | def add_noise(self, x_0, noise=None, t=None, **kwargs):
FILE: modules/model/embedder/embedder.py
class ACETextEmbedder (line 25) | class ACETextEmbedder(BaseEmbedder):
method __init__ (line 70) | def __init__(self, cfg, logger=None):
method freeze (line 112) | def freeze(self):
method forward (line 118) | def forward(self, tokens, return_mask=False, use_mask=True):
method _clean (line 134) | def _clean(self, text):
method encode (line 145) | def encode(self, text, return_mask=False, use_mask=True):
method encode_list (line 167) | def encode_list(self, text_list, return_mask=True):
method get_config_template (line 180) | def get_config_template():
FILE: modules/model/network/ldm_ace.py
class TextEmbedding (line 23) | class TextEmbedding(nn.Module):
method __init__ (line 24) | def __init__(self, embedding_shape):
class LdmACE (line 30) | class LdmACE(LatentDiffusion):
method __init__ (line 34) | def __init__(self, cfg, logger=None):
method encode_first_stage (line 53) | def encode_first_stage(self, x, **kwargs):
method decode_first_stage (line 61) | def decode_first_stage(self, z):
method cond_stage_embeddings (line 67) | def cond_stage_embeddings(self, prompt, edit_image, cont, cont_mask):
method limit_batch_data (line 85) | def limit_batch_data(self, batch_data_list, log_num):
method forward_train (line 96) | def forward_train(self,
method forward_test (line 186) | def forward_test(self,
method get_config_template (line 348) | def get_config_template():
FILE: modules/model/utils/basic_utils.py
function exists (line 11) | def exists(x):
function default (line 15) | def default(val, d):
function disabled_train (line 21) | def disabled_train(self, mode=True):
function transfer_size (line 27) | def transfer_size(para_num):
function count_params (line 44) | def count_params(model):
function expand_dims_like (line 49) | def expand_dims_like(x, y):
function unpack_tensor_into_imagelist (line 55) | def unpack_tensor_into_imagelist(image_tensor, shapes):
function find_example (line 64) | def find_example(tensor_list, image_list):
function pack_imagelist_into_tensor_v2 (line 75) | def pack_imagelist_into_tensor_v2(image_list):
function to_device (line 95) | def to_device(inputs, strict=True):
function check_list_of_list (line 103) | def check_list_of_list(ll):
FILE: modules/solver/ace_solver.py
class ACESolverV1 (line 16) | class ACESolverV1(LatentDiffusionSolver):
method __init__ (line 17) | def __init__(self, cfg, logger=None):
method save_results (line 21) | def save_results(self, results):
method run_eval (line 65) | def run_eval(self):
method run_test (line 95) | def run_test(self):
method probe_data (line 126) | def probe_data(self):
FILE: tools/run_inference.py
function run_one_case (line 31) | def run_one_case(pipe, input_image, input_mask, edit_k,
function run (line 63) | def run():
FILE: tools/run_train.py
function run_task (line 21) | def run_task(cfg):
function update_config (line 29) | def update_config(cfg):
function run (line 45) | def run():