Repository: fudan-generative-vision/hallo
Branch: main
Commit: 8fd7c572a3d4
Files: 48
Total size: 623.2 KB
Directory structure:
gitextract_p10rx2rb/
├── .github/
│ └── workflows/
│ └── static-check.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── .pylintrc
├── LICENSE
├── README.md
├── accelerate_config.yaml
├── configs/
│ ├── inference/
│ │ ├── .gitkeep
│ │ └── default.yaml
│ ├── train/
│ │ ├── stage1.yaml
│ │ └── stage2.yaml
│ └── unet/
│ └── unet.yaml
├── hallo/
│ ├── __init__.py
│ ├── animate/
│ │ ├── __init__.py
│ │ ├── face_animate.py
│ │ └── face_animate_static.py
│ ├── datasets/
│ │ ├── __init__.py
│ │ ├── audio_processor.py
│ │ ├── image_processor.py
│ │ ├── mask_image.py
│ │ └── talk_video.py
│ ├── models/
│ │ ├── __init__.py
│ │ ├── attention.py
│ │ ├── audio_proj.py
│ │ ├── face_locator.py
│ │ ├── image_proj.py
│ │ ├── motion_module.py
│ │ ├── mutual_self_attention.py
│ │ ├── resnet.py
│ │ ├── transformer_2d.py
│ │ ├── transformer_3d.py
│ │ ├── unet_2d_blocks.py
│ │ ├── unet_2d_condition.py
│ │ ├── unet_3d.py
│ │ ├── unet_3d_blocks.py
│ │ └── wav2vec.py
│ └── utils/
│ ├── __init__.py
│ ├── config.py
│ └── util.py
├── requirements.txt
├── scripts/
│ ├── app.py
│ ├── data_preprocess.py
│ ├── extract_meta_info_stage1.py
│ ├── extract_meta_info_stage2.py
│ ├── inference.py
│ ├── train_stage1.py
│ └── train_stage2.py
└── setup.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/workflows/static-check.yaml
================================================
name: Pylint
on: [push, pull_request]
jobs:
  static-check:
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ubuntu-22.04]
        python-version: ["3.10"]
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v3
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pylint
          python -m pip install --upgrade isort
          python -m pip install -r requirements.txt
      - name: Analysing the code with pylint
        run: |
          isort $(git ls-files '*.py') --check-only --diff
          pylint $(git ls-files '*.py')
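The workflow above runs two checks on Python 3.10: `isort` in check-only mode and `pylint`, both over every tracked `*.py` file. A minimal local equivalent might look like the sketch below. It is not part of the repository: the helper names `tracked_python_files`, `lint_commands` and `run_checks` are hypothetical, and it assumes `git`, `isort` and `pylint` are on `PATH` and that it is run from the repository root.

```python
"""Hypothetical local runner mirroring the static-check workflow (not in the repo)."""
import subprocess


def tracked_python_files() -> list[str]:
    """Return tracked *.py paths, like `git ls-files '*.py'` in the workflow."""
    result = subprocess.run(
        ["git", "ls-files", "*.py"],
        check=True,
        capture_output=True,
        text=True,
    )
    return [line for line in result.stdout.splitlines() if line]


def lint_commands(files: list[str]) -> list[list[str]]:
    """Build the two commands from the "Analysing the code with pylint" step."""
    return [
        # Report (and diff) unsorted imports without rewriting any file.
        ["isort", *files, "--check-only", "--diff"],
        # pylint picks up the repository's .pylintrc from the working directory.
        ["pylint", *files],
    ]


def run_checks() -> int:
    """Run both checks, stopping at the first failure as `bash -e` would."""
    files = tracked_python_files()
    if not files:
        # Sketch choice: nothing tracked means nothing to lint.
        return 0
    for cmd in lint_commands(files):
        code = subprocess.run(cmd, check=False).returncode
        if code != 0:
            return code
    return 0
```

Calling `raise SystemExit(run_checks())` from the repository root should exit non-zero whenever either check fails, although results can still differ from CI if the locally installed isort or pylint versions differ from the ones CI installs with `--upgrade`.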
================================================
FILE: .gitignore
================================================
# running cache
mlruns/
# Test directories
test_data/
pretrained_models/
# Poetry project
poetry.lock
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# IDE
.idea/
.vscode/
data
pretrained_models
test_data
================================================
FILE: .pre-commit-config.yaml
================================================
repos:
  - repo: local
    hooks:
      - id: isort
        name: isort
        language: system
        types: [python]
        pass_filenames: false
        entry: isort
        args: ["."]
      - id: pylint
        name: pylint
        language: system
        types: [python]
        pass_filenames: false
        entry: pylint
        args: ["**/*.py"]
================================================
FILE: .pylintrc
================================================
[MAIN]
# Analyse import fallback blocks. This can be used to support both Python 2 and
# 3 compatible code, which means that the block might have code that exists
# only in one or another interpreter, leading to false positives when analysed.
analyse-fallback-blocks=no
# Clear in-memory caches upon conclusion of linting. Useful if running pylint
# in a server-like mode.
clear-cache-post-run=no
# Load and enable all available extensions. Use --list-extensions to see a list
# of all available extensions.
#enable-all-extensions=
# In error mode, messages with a category besides ERROR or FATAL are
# suppressed, and no reports are done by default. Error mode is compatible with
# disabling specific errors.
#errors-only=
# Always return a 0 (non-error) status code, even if lint errors are found.
# This is primarily useful in continuous integration scripts.
#exit-zero=
# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loaded into the active Python interpreter and may
# run arbitrary code.
extension-pkg-allow-list=
# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loaded into the active Python interpreter and may
# run arbitrary code. (This is an alternative name to extension-pkg-allow-list
# for backward compatibility.)
extension-pkg-whitelist=cv2
# Return non-zero exit code if any of these messages/categories are detected,
# even if score is above --fail-under value. Syntax same as enable. Messages
# specified are enabled, while categories only check already-enabled messages.
fail-on=
# Specify a score threshold under which the program will exit with error.
fail-under=10
# Interpret the stdin as a python script, whose filename needs to be passed as
# the module_or_package argument.
#from-stdin=
# Files or directories to be skipped. They should be base names, not paths.
ignore=CVS
# Add files or directories matching the regular expressions patterns to the
# ignore-list. The regex matches against paths and can be in Posix or Windows
# format. Because '\\' represents the directory delimiter on Windows systems,
# it can't be used as an escape character.
ignore-paths=
# Files or directories matching the regular expression patterns are skipped.
# The regex matches against base names, not paths. The default value ignores
# Emacs file locks
ignore-patterns=^\.#
# List of module names for which member attributes should not be checked
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis). It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=cv2
# Python code to execute, usually for sys.path manipulation such as
# pygtk.require().
init-hook='import sys; sys.path.append(".")'
# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
# number of processors available to use, and will cap the count on Windows to
# avoid hangs.
jobs=1
# Control the amount of potential inferred values when inferring a single
# object. This can help the performance when dealing with large functions or
# complex, nested conditions.
limit-inference-results=100
# List of plugins (as comma separated values of python module names) to load,
# usually to register additional checkers.
load-plugins=
# Pickle collected data for later comparisons.
persistent=yes
# Minimum Python version to use for version dependent checks. Will default to
# the version used to run pylint.
py-version=3.10
# Discover python modules and packages in the file system subtree.
recursive=no
# Add paths to the list of the source roots. Supports globbing patterns. The
# source root is an absolute path or a path relative to the current working
# directory used to determine a package namespace for modules located under the
# source root.
source-roots=
# When enabled, pylint would attempt to guess common misconfiguration and emit
# user-friendly hints instead of false-positive error messages.
suggestion-mode=yes
# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no
# In verbose mode, extra non-checker-related info will be displayed.
#verbose=
[BASIC]
# Naming style matching correct argument names.
argument-naming-style=snake_case
# Regular expression matching correct argument names. Overrides argument-
# naming-style. If left empty, argument names will be checked with the set
# naming style.
#argument-rgx=
# Naming style matching correct attribute names.
attr-naming-style=snake_case
# Regular expression matching correct attribute names. Overrides attr-naming-
# style. If left empty, attribute names will be checked with the set naming
# style.
#attr-rgx=
# Bad variable names which should always be refused, separated by a comma.
bad-names=foo,
          bar,
          baz,
          toto,
          tutu,
          tata
# Bad variable names regexes, separated by a comma. If names match any regex,
# they will always be refused
bad-names-rgxs=
# Naming style matching correct class attribute names.
class-attribute-naming-style=any
# Regular expression matching correct class attribute names. Overrides class-
# attribute-naming-style. If left empty, class attribute names will be checked
# with the set naming style.
#class-attribute-rgx=
# Naming style matching correct class constant names.
class-const-naming-style=UPPER_CASE
# Regular expression matching correct class constant names. Overrides class-
# const-naming-style. If left empty, class constant names will be checked with
# the set naming style.
#class-const-rgx=
# Naming style matching correct class names.
class-naming-style=PascalCase
# Regular expression matching correct class names. Overrides class-naming-
# style. If left empty, class names will be checked with the set naming style.
#class-rgx=
# Naming style matching correct constant names.
const-naming-style=UPPER_CASE
# Regular expression matching correct constant names. Overrides const-naming-
# style. If left empty, constant names will be checked with the set naming
# style.
#const-rgx=
# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=-1
# Naming style matching correct function names.
function-naming-style=snake_case
# Regular expression matching correct function names. Overrides function-
# naming-style. If left empty, function names will be checked with the set
# naming style.
#function-rgx=
# Good variable names which should always be accepted, separated by a comma.
good-names=i,
           j,
           k,
           ex,
           Run,
           _
# Good variable names regexes, separated by a comma. If names match any regex,
# they will always be accepted
good-names-rgxs=
# Include a hint for the correct naming format with invalid-name.
include-naming-hint=no
# Naming style matching correct inline iteration names.
inlinevar-naming-style=any
# Regular expression matching correct inline iteration names. Overrides
# inlinevar-naming-style. If left empty, inline iteration names will be checked
# with the set naming style.
#inlinevar-rgx=
# Naming style matching correct method names.
method-naming-style=snake_case
# Regular expression matching correct method names. Overrides method-naming-
# style. If left empty, method names will be checked with the set naming style.
#method-rgx=
# Naming style matching correct module names.
module-naming-style=snake_case
# Regular expression matching correct module names. Overrides module-naming-
# style. If left empty, module names will be checked with the set naming style.
#module-rgx=
# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=
# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=^_
# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
# These decorators are taken in consideration only for invalid-name.
property-classes=abc.abstractproperty
# Regular expression matching correct type alias names. If left empty, type
# alias names will be checked with the set naming style.
#typealias-rgx=
# Regular expression matching correct type variable names. If left empty, type
# variable names will be checked with the set naming style.
#typevar-rgx=
# Naming style matching correct variable names.
variable-naming-style=snake_case
# Regular expression matching correct variable names. Overrides variable-
# naming-style. If left empty, variable names will be checked with the set
# naming style.
variable-rgx=(_?[a-z][A-Za-z0-9]{0,30})|([A-Z0-9]{1,30})
[CLASSES]
# Warn about protected attribute access inside special methods
check-protected-access-in-special-methods=no
# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,
                      __new__,
                      setUp,
                      asyncSetUp,
                      __post_init__
# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,_fields,_replace,_source,_make,os._exit
# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls
# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=mcs
[DESIGN]
# List of regular expressions of class ancestor names to ignore when counting
# public methods (see R0903)
exclude-too-few-public-methods=
# List of qualified class names to ignore when counting class parents (see
# R0901)
ignored-parents=
# Maximum number of arguments for function / method.
max-args=7
# Maximum number of attributes for a class (see R0902).
max-attributes=20
# Maximum number of boolean expressions in an if statement (see R0916).
max-bool-expr=5
# Maximum number of branch for function / method body.
max-branches=12
# Maximum number of locals for function / method body.
max-locals=15
# Maximum number of parents for a class (see R0901).
max-parents=7
# Maximum number of public methods for a class (see R0904).
max-public-methods=20
# Maximum number of return / yield for function / method body.
max-returns=6
# Maximum number of statements in function / method body.
max-statements=300
# Minimum number of public methods for a class (see R0903).
min-public-methods=1
[EXCEPTIONS]
# Exceptions that will emit a warning when caught.
overgeneral-exceptions=builtins.BaseException,builtins.Exception
[FORMAT]
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=
# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=^\s*(# )?<?https?://\S+>?$
# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4
# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
# tab).
indent-string=' '
# Maximum number of characters on a single line.
max-line-length=150
# Maximum number of lines in a module.
max-module-lines=2000
# Allow the body of a class to be on the same line as the declaration if body
# contains single statement.
single-line-class-stmt=no
# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=no
[IMPORTS]
# List of modules that can be imported at any level, not just the top level
# one.
allow-any-import-level=
# Allow explicit reexports by alias from a package __init__.
allow-reexport-from-package=no
# Allow wildcard imports from modules that define __all__.
allow-wildcard-with-all=no
# Deprecated modules which should not be used, separated by a comma.
deprecated-modules=
# Output a graph (.gv or any supported image format) of external dependencies
# to the given file (report RP0402 must not be disabled).
ext-import-graph=
# Output a graph (.gv or any supported image format) of all (i.e. internal and
# external) dependencies to the given file (report RP0402 must not be
# disabled).
import-graph=
# Output a graph (.gv or any supported image format) of internal dependencies
# to the given file (report RP0402 must not be disabled).
int-import-graph=
# Force import order to recognize a module as part of the standard
# compatibility libraries.
known-standard-library=
# Force import order to recognize a module as part of a third party library.
known-third-party=enchant
# Couples of modules and preferred modules, separated by a comma.
preferred-modules=
[LOGGING]
# The type of string formatting that logging methods do. `old` means using %
# formatting, `new` is for `{}` formatting.
logging-format-style=old
# Logging modules to check that the string format arguments are in logging
# function parameter format.
logging-modules=logging
[MESSAGES CONTROL]
# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE,
# UNDEFINED.
confidence=HIGH,
           CONTROL_FLOW,
           INFERENCE,
           INFERENCE_FAILURE,
           UNDEFINED
# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once). You can also use "--disable=all" to
# disable everything first and then re-enable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use "--disable=all --enable=classes
# --disable=W".
disable=too-many-arguments,
        too-many-locals,
        too-many-branches,
        protected-access
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifiers separated by comma (,) or put this option
# multiple times (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
enable=
[METHOD_ARGS]
# List of qualified names (i.e., library.method) which require a timeout
# parameter e.g. 'requests.api.get,requests.api.post'
timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request
[MISCELLANEOUS]
# List of note tags to take in consideration, separated by a comma.
notes=FIXME,
      XXX
# Regular expression of note tags to take in consideration.
notes-rgx=
[REFACTORING]
# Maximum number of nested blocks for function / method body
max-nested-blocks=5
# Complete name of functions that never returns. When checking for
# inconsistent-return-statements if a never returning function is called then
# it will be considered as an explicit return statement and no message will be
# printed.
never-returning-functions=sys.exit,argparse.parse_error
# Let 'consider-using-join' be raised when the separator to join on would be
# non-empty (resulting in expected fixes of the type: ``"- " + " -
# ".join(items)``)
# suggest-join-with-non-empty-separator=yes
[REPORTS]
# Python expression which should return a score less than or equal to 10. You
# have access to the variables 'fatal', 'error', 'warning', 'refactor',
# 'convention', and 'info' which contain the number of messages in each
# category, as well as 'statement' which is the total number of statements
# analyzed. This score is used by the global evaluation report (RP0004).
evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10))
# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details.
msg-template=
# Set the output format. Available formats are: text, parseable, colorized,
# json2 (improved json format), json (old json format) and msvs (visual
# studio). You can also give a reporter class, e.g.
# mypackage.mymodule.MyReporterClass.
#output-format=
# Tells whether to display a full report or only the messages.
reports=no
# Activate the evaluation score.
score=yes
[SIMILARITIES]
# Comments are removed from the similarity computation
ignore-comments=yes
# Docstrings are removed from the similarity computation
ignore-docstrings=yes
# Imports are removed from the similarity computation
ignore-imports=yes
# Signatures are removed from the similarity computation
ignore-signatures=yes
# Minimum lines number of a similarity.
min-similarity-lines=4
[SPELLING]
# Limits count of emitted suggestions for spelling mistakes.
max-spelling-suggestions=4
# Spelling dictionary name. No available dictionaries : You need to install
# both the python package and the system dependency for enchant to work.
spelling-dict=
# List of comma separated words that should be considered directives if they
# appear at the beginning of a comment and should not be checked.
spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy:
# List of comma separated words that should not be checked.
spelling-ignore-words=
# A path to a file that contains the private dictionary; one word per line.
spelling-private-dict-file=
# Tells whether to store unknown words to the private dictionary (see the
# --spelling-private-dict-file option) instead of raising a message.
spelling-store-unknown-words=no
[STRING]
# This flag controls whether inconsistent-quotes generates a warning when the
# character used as a quote delimiter is used inconsistently within a module.
check-quote-consistency=no
# This flag controls whether the implicit-str-concat should generate a warning
# on implicit string concatenation in sequences defined over several lines.
check-str-concat-over-line-jumps=no
[TYPECHECK]
# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager
# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=
# Tells whether to warn about missing members when the owner of the attribute
# is inferred to be None.
ignore-none=yes
# This flag controls whether pylint should warn about no-member and similar
# checks whenever an opaque object is returned when inferring. The inference
# can return multiple potential results while evaluating a Python object, but
# some branches might not be evaluated, which results in partial inference. In
# that case, it might be useful to still emit no-member and other checks for
# the rest of the inferred objects.
ignore-on-opaque-inference=yes
# List of symbolic message names to ignore for Mixin members.
ignored-checks-for-mixins=no-member,
                          not-async-context-manager,
                          not-context-manager,
                          attribute-defined-outside-init
# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace
# Show a hint with possible names when a member name was not found. The aspect
# of finding the hint is based on edit distance.
missing-member-hint=yes
# The minimum edit distance a name should have in order to be considered a
# similar match for a missing member name.
missing-member-hint-distance=1
# The total number of similar names that should be taken in consideration when
# showing a hint for a missing member.
missing-member-max-choices=1
# Regex pattern to define which classes are considered mixins.
mixin-class-rgx=.*[Mm]ixin
# List of decorators that change the signature of a decorated function.
signature-mutators=
[VARIABLES]
# List of additional names supposed to be defined in builtins. Remember that
# you should avoid defining new builtins when possible.
additional-builtins=
# Tells whether unused global variables should be treated as a violation.
allow-global-unused-variables=yes
# List of names allowed to shadow builtins
allowed-redefined-builtins=
# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,
          _cb
# A regular expression matching the name of dummy variables (i.e. expected to
# not be used).
dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
# Argument names that match this expression will be ignored.
ignored-argument-names=_.*|^ignored_|^unused_
# Tells whether we should check for unused import in __init__ files.
init-import=no
# List of qualified module names which can have objects that can redefine
# builtins.
redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
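The `evaluation` expression in the `[REPORTS]` section, combined with `fail-under=10` in `[MAIN]`, means that a single error, warning, refactor or convention message pulls the score below 10 and fails the run; info messages do not appear in the expression. The Python sketch below re-evaluates that expression by hand. The function names are hypothetical, and it only mirrors the formula as written in this file, not pylint's internal bookkeeping.

```python
"""Hypothetical re-evaluation of the .pylintrc `evaluation` score (not in the repo)."""


def pylint_score(fatal: int, error: int, warning: int, refactor: int,
                 convention: int, statement: int) -> float:
    """Evaluate the `evaluation` expression from [REPORTS] for given message counts."""
    if statement <= 0:
        # The expression divides by `statement`; this sketch simply rejects zero.
        raise ValueError("statement must be positive")
    # Each error costs five times as much as a warning, refactor or convention message.
    penalty = float(5 * error + warning + refactor + convention) / statement
    # Any fatal message forces 0; otherwise the score is clamped at 0 from below.
    return max(0, 0 if fatal else 10.0 - penalty * 10)


def passes_fail_under(score: float, fail_under: float = 10.0) -> bool:
    """pylint exits with an error when the score falls below `fail-under`."""
    return score >= fail_under
```

For example, one error in 100 statements gives `10.0 - (5 / 100) * 10 = 9.5`, which already fails `fail-under=10`.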
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2024 Fusion Lab: Generative Vision Lab of Fudan University
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
<h1 align='center'>Hallo: Hierarchical Audio-Driven Visual Synthesis for Portrait Image Animation</h1>
<div align='center'>
<a href='https://github.com/xumingw' target='_blank'>Mingwang Xu</a><sup>1*</sup> 
<a href='https://github.com/crystallee-ai' target='_blank'>Hui Li</a><sup>1*</sup> 
<a href='https://github.com/subazinga' target='_blank'>Qingkun Su</a><sup>1*</sup> 
<a href='https://github.com/NinoNeumann' target='_blank'>Hanlin Shang</a><sup>1</sup> 
<a href='https://github.com/AricGamma' target='_blank'>Liwei Zhang</a><sup>1</sup> 
<a href='https://github.com/cnexah' target='_blank'>Ce Liu</a><sup>3</sup> 
</div>
<div align='center'>
<a href='https://jingdongwang2017.github.io/' target='_blank'>Jingdong Wang</a><sup>2</sup> 
<a href='https://yoyo000.github.io/' target='_blank'>Yao Yao</a><sup>4</sup> 
<a href='https://sites.google.com/site/zhusiyucs/home' target='_blank'>Siyu Zhu</a><sup>1</sup> 
</div>
<div align='center'>
<sup>1</sup>Fudan University  <sup>2</sup>Baidu Inc  <sup>3</sup>ETH Zurich  <sup>4</sup>Nanjing University
</div>
<br>
<div align='center'>
<a href='https://github.com/fudan-generative-vision/hallo'><img src='https://img.shields.io/github/stars/fudan-generative-vision/hallo?style=social'></a>
<a href='https://fudan-generative-vision.github.io/hallo/#/'><img src='https://img.shields.io/badge/Project-HomePage-Green'></a>
<a href='https://arxiv.org/pdf/2406.08801'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a>
<a href='https://huggingface.co/fudan-generative-ai/hallo'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Model-yellow'></a>
<a href='https://huggingface.co/spaces/fffiloni/tts-hallo-talking-portrait'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Demo-yellow'></a>
<a href='https://www.modelscope.cn/models/fudan-generative-vision/Hallo/summary'><img src='https://img.shields.io/badge/Modelscope-Model-purple'></a>
<a href='assets/wechat.jpeg'><img src='https://badges.aleen42.com/src/wechat.svg'></a>
</div>
<br>
## 📸 Showcase
https://github.com/fudan-generative-vision/hallo/assets/17402682/9d1a0de4-3470-4d38-9e4f-412f517f834c
### 🎬 Honoring Classic Films
<table class="center">
<tr>
<td style="text-align: center"><b>Devil Wears Prada</b></td>
<td style="text-align: center"><b>Green Book</b></td>
<td style="text-align: center"><b>Infernal Affairs</b></td>
</tr>
<tr>
<td style="text-align: center"><a target="_blank" href="https://cdn.aondata.work/video/short_movie/Devil_Wears_Prada-480p.mp4"><img src="https://cdn.aondata.work/img/short_movie/Devil_Wears_Prada_GIF.gif"></a></td>
<td style="text-align: center"><a target="_blank" href="https://cdn.aondata.work/video/short_movie/Green_Book-480p.mp4"><img src="https://cdn.aondata.work/img/short_movie/Green_Book_GIF.gif"></a></td>
<td style="text-align: center"><a target="_blank" href="https://cdn.aondata.work/video/short_movie/无间道-480p.mp4"><img src="https://cdn.aondata.work/img/short_movie/Infernal_Affairs_GIF.gif"></a></td>
</tr>
<tr>
<td style="text-align: center"><b>Patch Adams</b></td>
<td style="text-align: center"><b>Tough Love</b></td>
<td style="text-align: center"><b>Shawshank Redemption</b></td>
</tr>
<tr>
<td style="text-align: center"><a target="_blank" href="https://cdn.aondata.work/video/short_movie/Patch_Adams-480p.mp4"><img src="https://cdn.aondata.work/img/short_movie/Patch_Adams_GIF.gif"></a></td>
<td style="text-align: center"><a target="_blank" href="https://cdn.aondata.work/video/short_movie/Tough_Love-480p.mp4"><img src="https://cdn.aondata.work/img/short_movie/Tough_Love_GIF.gif"></a></td>
<td style="text-align: center"><a target="_blank" href="https://cdn.aondata.work/video/short_movie/Shawshank-480p.mp4"><img src="https://cdn.aondata.work/img/short_movie/Shawshank_GIF.gif"></a></td>
</tr>
</table>
Explore [more examples](https://fudan-generative-vision.github.io/hallo).
## 📰 News
- **`2024/06/28`**: 🎉🎉🎉 We are proud to announce the release of our model training code. Try it on your own training data; see the [tutorial](#training).
- **`2024/06/21`**: 🚀🚀🚀 Cloned a Gradio demo on [🤗Huggingface space](https://huggingface.co/spaces/fudan-generative-ai/hallo).
- **`2024/06/20`**: 🌟🌟🌟 Received numerous contributions from the community, including a [Windows version](https://github.com/sdbds/hallo-for-windows), [ComfyUI](https://github.com/AIFSH/ComfyUI-Hallo), [WebUI](https://github.com/fudan-generative-vision/hallo/pull/51), and [Docker template](https://github.com/ashleykleynhans/hallo-docker).
- **`2024/06/15`**: ✨✨✨ Released sample images and audio clips for inference testing on [🤗Huggingface](https://huggingface.co/datasets/fudan-generative-ai/hallo_inference_samples).
- **`2024/06/15`**: 🎉🎉🎉 Launched the first version on 🫡[GitHub](https://github.com/fudan-generative-vision/hallo).
## 🤝 Community Resources
Explore the resources developed by our community to enhance your experience with Hallo:
- [TTS x Hallo Talking Portrait Generator](https://huggingface.co/spaces/fffiloni/tts-hallo-talking-portrait) - Check out this awesome Gradio demo by [@Sylvain Filoni](https://huggingface.co/fffiloni)! With this tool, you can conveniently prepare portrait images and audio for Hallo.
- [Demo on Huggingface](https://huggingface.co/spaces/multimodalart/hallo) - Check out this easy-to-use Gradio demo by [@multimodalart](https://huggingface.co/multimodalart).
- [hallo-webui](https://github.com/daswer123/hallo-webui) - Explore the WebUI created by [@daswer123](https://github.com/daswer123).
- [hallo-for-windows](https://github.com/sdbds/hallo-for-windows) - Utilize Hallo on Windows with the guide by [@sdbds](https://github.com/sdbds).
- [ComfyUI-Hallo](https://github.com/AIFSH/ComfyUI-Hallo) - Integrate Hallo with the ComfyUI tool by [@AIFSH](https://github.com/AIFSH).
- [hallo-docker](https://github.com/ashleykleynhans/hallo-docker) - Docker image for Hallo by [@ashleykleynhans](https://github.com/ashleykleynhans).
- [RunPod Template](https://runpod.io/console/deploy?template=aeyibwyvzy&ref=2xxro4syy) - Deploy Hallo to RunPod by [@ashleykleynhans](https://github.com/ashleykleynhans).
- [JoyHallo](https://jdh-algo.github.io/JoyHallo/) - JoyHallo extends the capabilities of Hallo, enabling it to support Mandarin.
Thanks to all of them.
Join our community and explore these amazing resources to make the most of Hallo. Enjoy, and elevate your creative projects!
## 🔧️ Framework


## ⚙️ Installation
- System requirements: Ubuntu 20.04/Ubuntu 22.04, CUDA 12.1
- Tested GPUs: A100
Create a conda environment:
```bash
conda create -n hallo python=3.10
conda activate hallo
```
Install the packages with `pip`:
```bash
pip install -r requirements.txt
pip install .
```
ffmpeg is also required:
```bash
apt-get install ffmpeg
```
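The repository does not ship an environment check. Below is a minimal, hypothetical sketch (`check_environment` is not part of Hallo) that reports whether Python 3.10, an `ffmpeg` binary on `PATH`, and a CUDA device visible to PyTorch appear to be available. It only checks for presence; it does not verify the exact CUDA 12.1 version or GPU model.

```python
import shutil
import sys


def check_environment():
    """Return a dict mapping prerequisite name -> bool (True means it looks present)."""
    report = {
        # The conda environment above is created with python=3.10.
        "python_3_10": sys.version_info[:2] == (3, 10),
        # ffmpeg is listed as a requirement; check that it can be found on PATH.
        "ffmpeg_on_path": shutil.which("ffmpeg") is not None,
    }
    try:
        import torch  # installed via requirements.txt
        report["cuda_available"] = bool(torch.cuda.is_available())
    except ImportError:
        # torch is not installed yet, so CUDA cannot be used from Python.
        report["cuda_available"] = False
    return report


if __name__ == "__main__":
    for name, ok in check_environment().items():
        print(f"{name}: {'ok' if ok else 'MISSING'}")
```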
## 🗝️️ Usage
The entry point for inference is `scripts/inference.py`. To test your own cases, follow these three steps:
1. [Download all required pretrained models](#download-pretrained-models).
2. [Prepare source image and driving audio pairs](#prepare-inference-data).
3. [Run inference](#run-inference).
### 📥 Download Pretrained Models
You can easily get all pretrained models required by inference from our [HuggingFace repo](https://huggingface.co/fudan-generative-ai/hallo).
Clone the pretrained models into the `${PROJECT_ROOT}/pretrained_models` directory with the commands below:
```shell
git lfs install
git clone https://huggingface.co/fudan-generative-ai/hallo pretrained_models
```
Alternatively, you can download them separately from their source repositories:
- [hallo](https://huggingface.co/fudan-generative-ai/hallo/tree/main/hallo): Our checkpoints, consisting of the denoising UNet, face locator, and image and audio projection modules.
- [audio_separator](https://huggingface.co/huangjackson/Kim_Vocal_2): Kim\_Vocal\_2 MDX-Net vocal removal model. (_Thanks to [KimberleyJensen](https://github.com/KimberleyJensen)_)
- [insightface](https://github.com/deepinsight/insightface/tree/master/python-package#model-zoo): 2D and 3D Face Analysis placed into `pretrained_models/face_analysis/models/`. (_Thanks to deepinsight_)
- [face landmarker](https://storage.googleapis.com/mediapipe-models/face_landmarker/face_landmarker/float16/1/face_landmarker.task): Face detection & mesh model from [mediapipe](https://ai.google.dev/edge/mediapipe/solutions/vision/face_landmarker#models) placed into `pretrained_models/face_analysis/models`.
- [motion module](https://github.com/guoyww/AnimateDiff/blob/main/README.md#202309-animatediff-v2): motion module from [AnimateDiff](https://github.com/guoyww/AnimateDiff). (_Thanks to [guoyww](https://github.com/guoyww)_).
- [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse): Weights are intended to be used with the diffusers library. (_Thanks to [stabilityai](https://huggingface.co/stabilityai)_)
- [StableDiffusion V1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5): Initialized and fine-tuned from Stable-Diffusion-v1-2. (_Thanks to [runwayml](https://huggingface.co/runwayml)_)
- [wav2vec](https://huggingface.co/facebook/wav2vec2-base-960h): wav audio to vector model from [Facebook](https://huggingface.co/facebook/wav2vec2-base-960h).
Finally, these pretrained models should be organized as follows:
```text
./pretrained_models/
|-- audio_separator/
|   |-- download_checks.json
|   |-- mdx_model_data.json
|   |-- vr_model_data.json
|   `-- Kim_Vocal_2.onnx
|-- face_analysis/
|   `-- models/
|       |-- face_landmarker_v2_with_blendshapes.task  # face landmarker model from mediapipe
|       |-- 1k3d68.onnx
|       |-- 2d106det.onnx
|       |-- genderage.onnx
|       |-- glintr100.onnx
|       `-- scrfd_10g_bnkps.onnx
|-- motion_module/
|   `-- mm_sd_v15_v2.ckpt
|-- sd-vae-ft-mse/
|   |-- config.json
|   `-- diffusion_pytorch_model.safetensors
|-- stable-diffusion-v1-5/
|   `-- unet/
|       |-- config.json
|       `-- diffusion_pytorch_model.safetensors
`-- wav2vec/
    `-- wav2vec2-base-960h/
        |-- config.json
        |-- feature_extractor_config.json
        |-- model.safetensors
        |-- preprocessor_config.json
        |-- special_tokens_map.json
        |-- tokenizer_config.json
        `-- vocab.json
```
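To catch a missing download before running inference, you can compare the directory against the tree above. The helper below is a hypothetical sketch, not part of the repository: `EXPECTED_FILES` is copied from the tree shown above (a subset of the JSON/tokenizer files is omitted for brevity), so it is not guaranteed to be a complete inventory of everything Hallo loads.

```python
from pathlib import Path

# Relative paths taken from the pretrained_models tree above (hypothetical checklist).
EXPECTED_FILES = [
    "audio_separator/Kim_Vocal_2.onnx",
    "face_analysis/models/face_landmarker_v2_with_blendshapes.task",
    "face_analysis/models/1k3d68.onnx",
    "face_analysis/models/2d106det.onnx",
    "face_analysis/models/genderage.onnx",
    "face_analysis/models/glintr100.onnx",
    "face_analysis/models/scrfd_10g_bnkps.onnx",
    "motion_module/mm_sd_v15_v2.ckpt",
    "sd-vae-ft-mse/config.json",
    "sd-vae-ft-mse/diffusion_pytorch_model.safetensors",
    "stable-diffusion-v1-5/unet/config.json",
    "stable-diffusion-v1-5/unet/diffusion_pytorch_model.safetensors",
    "wav2vec/wav2vec2-base-960h/config.json",
    "wav2vec/wav2vec2-base-960h/model.safetensors",
]


def missing_pretrained_files(root="./pretrained_models"):
    """Return the expected relative paths that are not regular files under `root`."""
    root_path = Path(root)
    return [rel for rel in EXPECTED_FILES if not (root_path / rel).is_file()]


if __name__ == "__main__":
    missing = missing_pretrained_files()
    if missing:
        print("Missing files:")
        for rel in missing:
            print("  " + rel)
    else:
        print("All expected pretrained files were found.")
```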
### 🛠️ Prepare Inference Data
Hallo has a few simple requirements for input data:
For the source image:
1. It should be cropped into squares.
2. The face should be the main focus, making up 50%-70% of the image.
3. The face should be facing forward, with a rotation angle of less than 30° (no side profiles).
For the driving audio:
1. It must be in WAV format.
2. It must be in English since our training datasets are only in this language.
3. Ensure the vocals are clear; background music is acceptable.
We have provided [some samples](examples/) for your reference.
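Two of the requirements above can be checked mechanically. The sketch below is illustrative and not part of Hallo: `center_square_box` computes the largest centered square crop, and `is_readable_wav` tests whether Python's standard `wave` module can open the audio. Note that `wave` only reads PCM WAV, so a valid IEEE-float WAV may be reported as unreadable. Face size (50%-70%) and head rotation cannot be checked this way, and a plain center crop does not guarantee that the face is centered.

```python
import wave


def center_square_box(width, height):
    """Return (left, top, right, bottom) of the largest square centered in the image."""
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    return (left, top, left + side, top + side)


def is_readable_wav(path):
    """Return True if the `wave` module can open `path` and it contains audio frames."""
    try:
        with wave.open(str(path), "rb") as wav:
            return wav.getnframes() > 0
    except (wave.Error, EOFError):
        # Not a RIFF/WAVE file, an unsupported (non-PCM) encoding, or a truncated file.
        return False
```

With Pillow, the returned box can be passed directly to `Image.crop`, for example `img.crop(center_square_box(*img.size))`.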
### 🎮 Run Inference
Simply run `scripts/inference.py`, passing `source_image` and `driving_audio` as input:
```bash
python scripts/inference.py --source_image examples/reference_images/1.jpg --driving_audio examples/driving_audios/1.wav
```
Animation results will be saved as `${PROJECT_ROOT}/.cache/output.mp4` by default. You can pass `--output` to specify the output file name. You can find more examples for inference at [examples folder](https://github.com/fudan-generative-vision/hallo/tree/main/examples).
For more options:
```shell
usage: inference.py [-h] [-c CONFIG] [--source_image SOURCE_IMAGE] [--driving_audio DRIVING_AUDIO] [--output OUTPUT] [--pose_weight POSE_WEIGHT]
                    [--face_weight FACE_WEIGHT] [--lip_weight LIP_WEIGHT] [--face_expand_ratio FACE_EXPAND_RATIO]

options:
  -h, --help            show this help message and exit
  -c CONFIG, --config CONFIG
  --source_image SOURCE_IMAGE
                        source image
  --driving_audio DRIVING_AUDIO
                        driving audio
  --output OUTPUT       output video file name
  --pose_weight POSE_WEIGHT
                        weight of pose
  --face_weight FACE_WEIGHT
                        weight of face
  --lip_weight LIP_WEIGHT
                        weight of lip
  --face_expand_ratio FACE_EXPAND_RATIO
                        face region
```
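To animate several image/audio pairs, you can call the script once per pair. The sketch below is a hypothetical wrapper (not part of the repository) that only uses flags listed in the `--help` output above; it assumes `--output` accepts a file path, as the default `.cache/output.mp4` suggests, and that it is run from the project root. Pairs run sequentially, since each invocation loads the full set of models.

```python
import shlex
import subprocess
import sys
from pathlib import Path

# Optional weight flags documented in the --help output above.
OPTIONAL_FLAGS = ("pose_weight", "face_weight", "lip_weight", "face_expand_ratio")


def build_inference_cmd(source_image, driving_audio, output, config=None, **options):
    """Build an argv list for scripts/inference.py using only documented flags."""
    cmd = [
        sys.executable, "scripts/inference.py",
        "--source_image", str(source_image),
        "--driving_audio", str(driving_audio),
        "--output", str(output),
    ]
    if config is not None:
        cmd += ["--config", str(config)]
    for name, value in options.items():
        if name not in OPTIONAL_FLAGS:
            raise ValueError(f"unsupported option: {name}")
        cmd += [f"--{name}", str(value)]
    return cmd


if __name__ == "__main__":
    # Hypothetical (image, audio) pairs; replace with your own files.
    pairs = [("examples/reference_images/1.jpg", "examples/driving_audios/1.wav")]
    for image, audio in pairs:
        output = Path(".cache") / f"{Path(image).stem}_{Path(audio).stem}.mp4"
        cmd = build_inference_cmd(image, audio, output, lip_weight=1.0)
        print(shlex.join(cmd))
        subprocess.run(cmd, check=True)
```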
## Training
### Prepare Data for Training
The training data consists of talking-face videos, which must meet the same requirements as the source images used for inference:
1. It should be cropped into squares.
2. The face should be the main focus, making up 50%-70% of the image.
3. The face should be facing forward, with a rotation angle of less than 30° (no side profiles).
Organize your raw videos into the following directory structure:
```text
dataset_name/
|-- videos/
|   |-- 0001.mp4
|   |-- 0002.mp4
|   |-- 0003.mp4
|   `-- 0004.mp4
```
You can use any `dataset_name`, but ensure the `videos` directory is named as shown above.
Next, process the videos with the following commands:
```bash
python -m scripts.data_preprocess --input_dir dataset_name/videos --step 1
python -m scripts.data_preprocess --input_dir dataset_name/videos --step 2
```
**Note:** Execute steps 1 and 2 sequentially as they perform different tasks. Step 1 converts videos into frames, extracts audio from each video, and generates the necessary masks. Step 2 generates face embeddings using InsightFace and audio embeddings using Wav2Vec, and requires a GPU. For parallel processing, use the `-p` and `-r` arguments. The `-p` argument specifies the total number of instances to launch, dividing the data into `p` parts. The `-r` argument specifies which part the current process should handle. You need to manually launch multiple instances with different values for `-r`.
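Launching the `-p`/`-r` instances by hand can be scripted. The following is a minimal sketch, not part of the repository, that starts all parts of one step concurrently and waits for every part to finish before the next step begins, which preserves the required step 1 then step 2 order. It assumes ranks are zero-based (`0` to `p-1`); check `scripts/data_preprocess.py` to confirm. It does not pin instances to GPUs; for step 2 on a multi-GPU machine you may want to set `CUDA_VISIBLE_DEVICES` per instance.

```python
import subprocess
import sys


def preprocess_commands(input_dir, step, parallelism):
    """Return one data_preprocess command per part, for ranks 0..parallelism-1."""
    return [
        [
            sys.executable, "-m", "scripts.data_preprocess",
            "--input_dir", str(input_dir),
            "--step", str(step),
            "-p", str(parallelism),
            "-r", str(rank),
        ]
        for rank in range(parallelism)
    ]


def run_step(input_dir, step, parallelism):
    """Launch all parts of one step concurrently and wait for all of them."""
    procs = [subprocess.Popen(cmd) for cmd in preprocess_commands(input_dir, step, parallelism)]
    codes = [proc.wait() for proc in procs]
    if any(code != 0 for code in codes):
        raise RuntimeError(f"step {step} failed; exit codes: {codes}")


if __name__ == "__main__":
    # Step 2 starts only after every part of step 1 has finished successfully.
    for step in (1, 2):
        run_step("dataset_name/videos", step, parallelism=4)
```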
Generate the metadata JSON files with the following commands:
```bash
python scripts/extract_meta_info_stage1.py -r path/to/dataset -n dataset_name
python scripts/extract_meta_info_stage2.py -r path/to/dataset -n dataset_name
```
Replace `path/to/dataset` with the path to the parent directory of `videos`, such as `dataset_name` in the example above. This will generate `dataset_name_stage1.json` and `dataset_name_stage2.json` in the `./data` directory.
### Training
Update the data meta path settings in the configuration YAML files, `configs/train/stage1.yaml` and `configs/train/stage2.yaml`:
```yaml
#stage1.yaml
data:
  meta_paths:
    - ./data/dataset_name_stage1.json

#stage2.yaml
data:
  train_meta_paths:
    - ./data/dataset_name_stage2.json
```
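Stage 2 loads the stage 1 checkpoints from `stage1_ckpt_dir` (set in `stage2.yaml`), so train stage 1 first. Start stage 1 training with the following command; stage 2 is launched the same way with `scripts.train_stage2` and `--config ./configs/train/stage2.yaml`: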
```shell
accelerate launch -m \
--config_file accelerate_config.yaml \
--machine_rank 0 \
--main_process_ip 0.0.0.0 \
--main_process_port 20055 \
--num_machines 1 \
--num_processes 8 \
scripts.train_stage1 --config ./configs/train/stage1.yaml
```
#### Accelerate Usage Explanation
The `accelerate launch` command is used to start the training process with distributed settings.
```shell
accelerate launch [arguments] {training_script} --{training_script-argument-1} --{training_script-argument-2} ...
```
**Arguments for Accelerate:**
- `-m, --module`: Interpret the launch script as a Python module.
- `--config_file`: Configuration file for Hugging Face Accelerate.
- `--machine_rank`: Rank of the current machine in a multi-node setup.
- `--main_process_ip`: IP address of the master node.
- `--main_process_port`: Port of the master node.
- `--num_machines`: Total number of nodes participating in the training.
- `--num_processes`: Total number of processes for training, matching the total number of GPUs across all machines.
**Arguments for Training:**
- `{training_script}`: The training script, such as `scripts.train_stage1` or `scripts.train_stage2`.
- `--{training_script-argument-1}`: Arguments specific to the training script. Our training scripts accept one argument, `--config`, to specify the training configuration file.
For multi-node training, run the command separately on each node, with a different `--machine_rank` on each.
For more settings, refer to the [Accelerate documentation](https://huggingface.co/docs/accelerate/en/index).
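Following the rules above (`--num_processes` is the total number of GPUs across all machines, and each node runs the same command with its own `--machine_rank`), the per-node commands can be generated rather than edited by hand. This is a hypothetical helper, not part of the repository; the IP address in the example is made up. In a multi-node setup, `--main_process_ip` should presumably be the master node's address reachable from the other nodes, rather than the `0.0.0.0` used in the single-node example.

```python
import shlex


def accelerate_commands(num_machines, gpus_per_machine, main_ip, stage=1, port=20055):
    """Return one `accelerate launch` command string per node (machine_rank 0..N-1)."""
    total_processes = num_machines * gpus_per_machine  # GPUs across all machines
    commands = []
    for machine_rank in range(num_machines):
        args = [
            "accelerate", "launch", "-m",
            "--config_file", "accelerate_config.yaml",
            "--machine_rank", str(machine_rank),
            "--main_process_ip", main_ip,
            "--main_process_port", str(port),
            "--num_machines", str(num_machines),
            "--num_processes", str(total_processes),
            f"scripts.train_stage{stage}",
            "--config", f"./configs/train/stage{stage}.yaml",
        ]
        commands.append(shlex.join(args))
    return commands


if __name__ == "__main__":
    # Hypothetical cluster: 2 nodes with 8 GPUs each; 10.0.0.1 is the master node.
    for cmd in accelerate_commands(2, 8, "10.0.0.1", stage=1):
        print(cmd)
```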
## 📅️ Roadmap
| Status | Milestone | ETA |
| :----: | :---------------------------------------------------------------------------------------------------- | :--------: |
| ✅ | **[Inference source code released on GitHub](https://github.com/fudan-generative-vision/hallo)** | 2024-06-15 |
| ✅ | **[Pretrained models on Huggingface](https://huggingface.co/fudan-generative-ai/hallo)** | 2024-06-15 |
| ✅ | **[Releasing data preparation and training scripts](#training)** | 2024-06-28 |
| 🚀 | **[Improving the model's performance on Mandarin Chinese]()** | TBD |
<details>
<summary>Other Enhancements</summary>
- [x] Enhancement: Test and ensure compatibility with Windows operating system. [#39](https://github.com/fudan-generative-vision/hallo/issues/39)
- [x] Bug: Output video may lose several frames. [#41](https://github.com/fudan-generative-vision/hallo/issues/41)
- [ ] Bug: Sound volume affecting inference results (audio normalization).
- [ ] ~~Enhancement: Inference code logic optimization~~. This solution doesn't show significant performance improvements. Trying other approaches.
</details>
## 📝 Citation
If you find our work useful for your research, please consider citing the paper:
```
@misc{xu2024hallo,
title={Hallo: Hierarchical Audio-Driven Visual Synthesis for Portrait Image Animation},
author={Mingwang Xu and Hui Li and Qingkun Su and Hanlin Shang and Liwei Zhang and Ce Liu and Jingdong Wang and Yao Yao and Siyu zhu},
year={2024},
eprint={2406.08801},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
## 🌟 Opportunities Available
Multiple research positions are open at the **Generative Vision Lab, Fudan University**! Positions include:
- Research assistants
- Postdoctoral researchers
- PhD candidates
- Master's students
Interested individuals are encouraged to contact us at [siyuzhu@fudan.edu.cn](mailto:siyuzhu@fudan.edu.cn) for further information.
## ⚠️ Social Risks and Mitigations
The development of portrait image animation technologies driven by audio inputs poses social risks, such as the ethical implications of creating realistic portraits that could be misused for deepfakes. To mitigate these risks, it is crucial to establish ethical guidelines and responsible use practices. Privacy and consent concerns also arise from using individuals' images and voices. Addressing these involves transparent data usage policies, informed consent, and safeguarding privacy rights. By addressing these risks and implementing mitigations, the research aims to ensure the responsible and ethical development of this technology.
## 🤗 Acknowledgements
We would like to thank the contributors to the [magic-animate](https://github.com/magic-research/magic-animate), [AnimateDiff](https://github.com/guoyww/AnimateDiff), [ultimatevocalremovergui](https://github.com/Anjok07/ultimatevocalremovergui), [AniPortrait](https://github.com/Zejun-Yang/AniPortrait) and [Moore-AnimateAnyone](https://github.com/MooreThreads/Moore-AnimateAnyone) repositories, for their open research and exploration.
If we have missed any open-source projects or related articles, please let us know and we will add an acknowledgement promptly.
## 👏 Community Contributors
Thank you to all the contributors who have helped to make this project better!
<a href="https://github.com/fudan-generative-vision/hallo/graphs/contributors">
<img src="https://contrib.rocks/image?repo=fudan-generative-vision/hallo" />
</a>
================================================
FILE: accelerate_config.yaml
================================================
compute_environment: LOCAL_MACHINE
debug: true
deepspeed_config:
  deepspeed_multinode_launcher: standard
  gradient_accumulation_steps: 1
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: "no"
main_training_function: main
mixed_precision: "fp16"
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
================================================
FILE: configs/inference/.gitkeep
================================================
================================================
FILE: configs/inference/default.yaml
================================================
source_image: examples/reference_images/1.jpg
driving_audio: examples/driving_audios/1.wav
weight_dtype: fp16
data:
  n_motion_frames: 2
  n_sample_frames: 16
  source_image:
    width: 512
    height: 512
  driving_audio:
    sample_rate: 16000
  export_video:
    fps: 25
inference_steps: 40
cfg_scale: 3.5
audio_ckpt_dir: ./pretrained_models/hallo
base_model_path: ./pretrained_models/stable-diffusion-v1-5
motion_module_path: ./pretrained_models/motion_module/mm_sd_v15_v2.ckpt
face_analysis:
  model_path: ./pretrained_models/face_analysis
wav2vec:
  model_path: ./pretrained_models/wav2vec/wav2vec2-base-960h
  features: all
audio_separator:
  model_path: ./pretrained_models/audio_separator/Kim_Vocal_2.onnx
vae:
  model_path: ./pretrained_models/sd-vae-ft-mse
save_path: ./.cache
face_expand_ratio: 1.2
pose_weight: 1.0
face_weight: 1.0
lip_weight: 1.0
unet_additional_kwargs:
  use_inflated_groupnorm: true
  unet_use_cross_frame_attention: false
  unet_use_temporal_attention: false
  use_motion_module: true
  use_audio_module: true
  motion_module_resolutions:
    - 1
    - 2
    - 4
    - 8
  motion_module_mid_block: true
  motion_module_decoder_only: false
  motion_module_type: Vanilla
  motion_module_kwargs:
    num_attention_heads: 8
    num_transformer_block: 1
    attention_block_types:
      - Temporal_Self
      - Temporal_Self
    temporal_position_encoding: true
    temporal_position_encoding_max_len: 32
    temporal_attention_dim_div: 1
  audio_attention_dim: 768
  stack_enable_blocks_name:
    - "up"
    - "down"
    - "mid"
  stack_enable_blocks_depth: [0,1,2,3]
enable_zero_snr: true
noise_scheduler_kwargs:
  beta_start: 0.00085
  beta_end: 0.012
  beta_schedule: "linear"
  clip_sample: false
  steps_offset: 1
  ### Zero-SNR params
  prediction_type: "v_prediction"
  rescale_betas_zero_snr: True
  timestep_spacing: "trailing"
sampler: DDIM
================================================
FILE: configs/train/stage1.yaml
================================================
data:
  train_bs: 8
  train_width: 512
  train_height: 512
  meta_paths:
    - "./data/HDTF_meta.json"
  # Margin of frame indexes between ref and tgt images
  sample_margin: 30
solver:
  gradient_accumulation_steps: 1
  mixed_precision: "no"
  enable_xformers_memory_efficient_attention: True
  gradient_checkpointing: False
  max_train_steps: 30000
  max_grad_norm: 1.0
  # lr
  learning_rate: 1.0e-5
  scale_lr: False
  lr_warmup_steps: 1
  lr_scheduler: "constant"
  # optimizer
  use_8bit_adam: False
  adam_beta1: 0.9
  adam_beta2: 0.999
  adam_weight_decay: 1.0e-2
  adam_epsilon: 1.0e-8
val:
  validation_steps: 500
noise_scheduler_kwargs:
  num_train_timesteps: 1000
  beta_start: 0.00085
  beta_end: 0.012
  beta_schedule: "scaled_linear"
  steps_offset: 1
  clip_sample: false
base_model_path: "./pretrained_models/stable-diffusion-v1-5/"
vae_model_path: "./pretrained_models/sd-vae-ft-mse"
face_analysis_model_path: "./pretrained_models/face_analysis"
weight_dtype: "fp16" # [fp16, fp32]
uncond_ratio: 0.1
noise_offset: 0.05
snr_gamma: 5.0
enable_zero_snr: True
face_locator_pretrained: False
seed: 42
resume_from_checkpoint: "latest"
checkpointing_steps: 500
exp_name: "stage1"
output_dir: "./exp_output"
ref_image_paths:
  - "examples/reference_images/1.jpg"
mask_image_paths:
  - "examples/masks/1.png"
================================================
FILE: configs/train/stage2.yaml
================================================
data:
  train_bs: 4
  val_bs: 1
  train_width: 512
  train_height: 512
  fps: 25
  sample_rate: 16000
  n_motion_frames: 2
  n_sample_frames: 14
  audio_margin: 2
  train_meta_paths:
    - "./data/hdtf_split_stage2.json"
wav2vec_config:
  audio_type: "vocals" # audio vocals
  model_scale: "base" # base large
  features: "all" # last avg all
  model_path: ./pretrained_models/wav2vec/wav2vec2-base-960h
audio_separator:
  model_path: ./pretrained_models/audio_separator/Kim_Vocal_2.onnx
face_expand_ratio: 1.2
solver:
  gradient_accumulation_steps: 1
  mixed_precision: "no"
  enable_xformers_memory_efficient_attention: True
  gradient_checkpointing: True
  max_train_steps: 30000
  max_grad_norm: 1.0
  # lr
  learning_rate: 1e-5
  scale_lr: False
  lr_warmup_steps: 1
  lr_scheduler: "constant"
  # optimizer
  use_8bit_adam: True
  adam_beta1: 0.9
  adam_beta2: 0.999
  adam_weight_decay: 1.0e-2
  adam_epsilon: 1.0e-8
val:
  validation_steps: 1000
noise_scheduler_kwargs:
  num_train_timesteps: 1000
  beta_start: 0.00085
  beta_end: 0.012
  beta_schedule: "linear"
  steps_offset: 1
  clip_sample: false
unet_additional_kwargs:
  use_inflated_groupnorm: true
  unet_use_cross_frame_attention: false
  unet_use_temporal_attention: false
  use_motion_module: true
  use_audio_module: true
  motion_module_resolutions:
    - 1
    - 2
    - 4
    - 8
  motion_module_mid_block: true
  motion_module_decoder_only: false
  motion_module_type: Vanilla
  motion_module_kwargs:
    num_attention_heads: 8
    num_transformer_block: 1
    attention_block_types:
      - Temporal_Self
      - Temporal_Self
    temporal_position_encoding: true
    temporal_position_encoding_max_len: 32
    temporal_attention_dim_div: 1
  audio_attention_dim: 768
  stack_enable_blocks_name:
    - "up"
    - "down"
    - "mid"
  stack_enable_blocks_depth: [0,1,2,3]
trainable_para:
  - audio_modules
  - motion_modules
base_model_path: "./pretrained_models/stable-diffusion-v1-5/"
vae_model_path: "./pretrained_models/sd-vae-ft-mse"
face_analysis_model_path: "./pretrained_models/face_analysis"
mm_path: "./pretrained_models/motion_module/mm_sd_v15_v2.ckpt"
weight_dtype: "fp16" # [fp16, fp32]
uncond_img_ratio: 0.05
uncond_audio_ratio: 0.05
uncond_ia_ratio: 0.05
start_ratio: 0.05
noise_offset: 0.05
snr_gamma: 5.0
enable_zero_snr: True
stage1_ckpt_dir: "./exp_output/stage1/"
single_inference_times: 10
inference_steps: 40
cfg_scale: 3.5
seed: 42
resume_from_checkpoint: "latest"
checkpointing_steps: 500
exp_name: "stage2"
output_dir: "./exp_output"
ref_img_path:
  - "examples/reference_images/1.jpg"
audio_path:
  - "examples/driving_audios/1.wav"
================================================
FILE: configs/unet/unet.yaml
================================================
unet_additional_kwargs:
  use_inflated_groupnorm: true
  unet_use_cross_frame_attention: false
  unet_use_temporal_attention: false
  use_motion_module: true
  use_audio_module: true
  motion_module_resolutions:
    - 1
    - 2
    - 4
    - 8
  motion_module_mid_block: true
  motion_module_decoder_only: false
  motion_module_type: Vanilla
  motion_module_kwargs:
    num_attention_heads: 8
    num_transformer_block: 1
    attention_block_types:
      - Temporal_Self
      - Temporal_Self
    temporal_position_encoding: true
    temporal_position_encoding_max_len: 32
    temporal_attention_dim_div: 1
  audio_attention_dim: 768
  stack_enable_blocks_name:
    - "up"
    - "down"
    - "mid"
  stack_enable_blocks_depth: [0,1,2,3]
enable_zero_snr: true
noise_scheduler_kwargs:
  beta_start: 0.00085
  beta_end: 0.012
  beta_schedule: "linear"
  clip_sample: false
  steps_offset: 1
  ### Zero-SNR params
  prediction_type: "v_prediction"
  rescale_betas_zero_snr: True
  timestep_spacing: "trailing"
sampler: DDIM
================================================
FILE: hallo/__init__.py
================================================
================================================
FILE: hallo/animate/__init__.py
================================================
================================================
FILE: hallo/animate/face_animate.py
================================================
# pylint: disable=R0801
"""
This module is responsible for animating faces in videos using a combination of deep learning techniques.
It provides a pipeline for generating face animations by processing video frames and extracting face features.
The module utilizes various schedulers and utilities for efficient face animation and supports different types
of latents for more control over the animation process.
Functions and Classes:
- FaceAnimatePipeline: A class that extends the DiffusionPipeline class from the diffusers library to handle face animation tasks.
- __init__: Initializes the pipeline with the necessary components (VAE, UNets, face locator, etc.).
- prepare_latents: Generates or loads latents for the animation process, scaling them according to the scheduler's requirements.
- prepare_extra_step_kwargs: Prepares extra keyword arguments for the scheduler step, ensuring compatibility with different schedulers.
- decode_latents: Decodes the latents into video frames, ready for animation.
Usage:
- Import the necessary packages and classes.
- Create a FaceAnimatePipeline instance with the required components.
- Prepare the latents for the animation process.
- Use the pipeline to generate the animated video.
Note:
- This module is designed to work with the diffusers library, which provides the underlying framework for face animation using deep learning.
- The module is intended for research and development purposes, and further optimization and customization may be required for specific use cases.
"""
import inspect
from dataclasses import dataclass
from typing import Callable, List, Optional, Union
import numpy as np
import torch
from diffusers import (DDIMScheduler, DiffusionPipeline,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
LMSDiscreteScheduler, PNDMScheduler)
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import BaseOutput
from diffusers.utils.torch_utils import randn_tensor
from einops import rearrange, repeat
from tqdm import tqdm
from hallo.models.mutual_self_attention import ReferenceAttentionControl


@dataclass
class FaceAnimatePipelineOutput(BaseOutput):
    """
    FaceAnimatePipelineOutput is a custom class that inherits from BaseOutput and represents the output of the FaceAnimatePipeline.
    Attributes:
        videos (Union[torch.Tensor, np.ndarray]): A tensor or numpy array containing the generated video frames.
    Methods:
        __init__(self, videos: Union[torch.Tensor, np.ndarray]): Initializes the FaceAnimatePipelineOutput object with the generated video frames.
    """
    videos: Union[torch.Tensor, np.ndarray]


class FaceAnimatePipeline(DiffusionPipeline):
    """
    FaceAnimatePipeline is a custom DiffusionPipeline for animating faces.
    It inherits from the DiffusionPipeline class and is used to animate faces by
    utilizing a variational autoencoder (VAE), a reference UNet, a denoising UNet,
    a face locator, and an image projection module. The pipeline is responsible for generating
    and animating face latents, and decoding the latents to produce the final video output.
    Attributes:
        vae (AutoencoderKL): Variational autoencoder for encoding images into latents and decoding latents into frames.
        reference_unet (nn.Module): Reference UNet for mutual self-attention.
        denoising_unet (nn.Module): Denoising UNet for image denoising.
        face_locator (nn.Module): Face locator that encodes the face-region mask into latent features.
        image_proj (nn.Module): Image projector that maps face embeddings to encoder hidden states.
        scheduler (Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler,
            EulerDiscreteScheduler, EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler]): Diffusion scheduler for
            controlling the noise level.
    Methods:
        __init__(self, vae, reference_unet, denoising_unet, face_locator,
            image_proj, scheduler): Initializes the FaceAnimatePipeline
            with the given components and scheduler.
        prepare_latents(self, batch_size, num_channels_latents, width, height,
            video_length, dtype, device, generator=None, latents=None):
            Prepares the initial latents for video generation.
        prepare_extra_step_kwargs(self, generator, eta): Prepares extra keyword
            arguments for the scheduler step.
        decode_latents(self, latents): Decodes the latents to produce the final
            video output.
    """

    def __init__(
        self,
        vae,
        reference_unet,
        denoising_unet,
        face_locator,
        image_proj,
        scheduler: Union[
            DDIMScheduler,
            PNDMScheduler,
            LMSDiscreteScheduler,
            EulerDiscreteScheduler,
            EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler,
        ],
    ) -> None:
        super().__init__()
        self.register_modules(
            vae=vae,
            reference_unet=reference_unet,
            denoising_unet=denoising_unet,
            face_locator=face_locator,
            scheduler=scheduler,
            image_proj=image_proj,
        )
        self.vae_scale_factor: int = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.ref_image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True,
        )

    @property
    def _execution_device(self):
        # The pipeline registers `denoising_unet` (there is no `self.unet`), so inspect that module's hooks.
        if self.device != torch.device("meta") or not hasattr(self.denoising_unet, "_hf_hook"):
            return self.device
        for module in self.denoising_unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def prepare_latents(
        self,
        batch_size: int,  # Number of videos to generate in parallel
        num_channels_latents: int,  # Number of channels in the latents
        width: int,  # Width of the video frame
        height: int,  # Height of the video frame
        video_length: int,  # Length of the video in frames
        dtype: torch.dtype,  # Data type of the latents
        device: torch.device,  # Device to store the latents on
        generator: Optional[torch.Generator] = None,  # Random number generator for reproducibility
        latents: Optional[torch.Tensor] = None  # Pre-generated latents (optional)
    ):
        """
        Prepares the initial latents for video generation.
        Args:
            batch_size (int): Number of videos to generate in parallel.
            num_channels_latents (int): Number of channels in the latents.
            width (int): Width of the video frame.
            height (int): Height of the video frame.
            video_length (int): Length of the video in frames.
            dtype (torch.dtype): Data type of the latents.
            device (torch.device): Device to store the latents on.
            generator (Optional[torch.Generator]): Random number generator for reproducibility.
            latents (Optional[torch.Tensor]): Pre-generated latents (optional).
        Returns:
            latents (torch.Tensor): Tensor of shape (batch_size, num_channels_latents, video_length,
                height // vae_scale_factor, width // vae_scale_factor) containing the initial latents
                for video generation, scaled by the scheduler's init_noise_sigma.
        """
        shape = (
            batch_size,
            num_channels_latents,
            video_length,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )
        if latents is None:
            latents = randn_tensor(
                shape, generator=generator, device=device, dtype=dtype
            )
        else:
            latents = latents.to(device)
        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    def prepare_extra_step_kwargs(self, generator, eta):
        """
        Prepares extra keyword arguments for the scheduler step.
        Args:
            generator (Optional[torch.Generator]): Random number generator for reproducibility.
            eta (float): The eta (η) parameter used with the DDIMScheduler.
                It corresponds to η in the DDIM paper (https://arxiv.org/abs/2010.02502) and should be between [0, 1].
        Returns:
            dict: A dictionary containing the extra keyword arguments for the scheduler step.
        """
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(
            inspect.signature(self.scheduler.step).parameters.keys()
        )
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta
        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(
            inspect.signature(self.scheduler.step).parameters.keys()
        )
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def decode_latents(self, latents):
        """
        Decode the latents to produce a video.
        Parameters:
            latents (torch.Tensor): The latents to be decoded, of shape (batch, channels, frames, height, width).
        Returns:
            video (np.ndarray): The decoded video as a float32 array of shape
                (batch, channels, frames, height, width) with values in [0, 1].
        """
        video_length = latents.shape[2]
        latents = 1 / 0.18215 * latents
        latents = rearrange(latents, "b c f h w -> (b f) c h w")
        # video = self.vae.decode(latents).sample
        video = []
        for frame_idx in tqdm(range(latents.shape[0])):
            video.append(self.vae.decode(
                latents[frame_idx: frame_idx + 1]).sample)
        video = torch.cat(video)
        video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
        video = (video / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        video = video.cpu().float().numpy()
        return video
@torch.no_grad()
def __call__(
self,
ref_image,
face_emb,
audio_tensor,
face_mask,
pixel_values_full_mask,
pixel_values_face_mask,
pixel_values_lip_mask,
width,
height,
video_length,
num_inference_steps,
guidance_scale,
num_images_per_prompt=1,
eta: float = 0.0,
motion_scale: Optional[List[torch.Tensor]] = None,
generator: Optional[Union[torch.Generator,
List[torch.Generator]]] = None,
output_type: Optional[str] = "tensor",
return_dict: bool = True,
callback: Optional[Callable[[
int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
# Default height and width to unet
height = height or self.denoising_unet.config.sample_size * self.vae_scale_factor
width = width or self.denoising_unet.config.sample_size * self.vae_scale_factor
device = self._execution_device
do_classifier_free_guidance = guidance_scale > 1.0
# Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
batch_size = 1
# prepare clip image embeddings
clip_image_embeds = face_emb
clip_image_embeds = clip_image_embeds.to(self.image_proj.device, self.image_proj.dtype)
encoder_hidden_states = self.image_proj(clip_image_embeds)
uncond_encoder_hidden_states = self.image_proj(torch.zeros_like(clip_image_embeds))
if do_classifier_free_guidance:
encoder_hidden_states = torch.cat([uncond_encoder_hidden_states, encoder_hidden_states], dim=0)
reference_control_writer = ReferenceAttentionControl(
self.reference_unet,
do_classifier_free_guidance=do_classifier_free_guidance,
mode="write",
batch_size=batch_size,
fusion_blocks="full",
)
reference_control_reader = ReferenceAttentionControl(
self.denoising_unet,
do_classifier_free_guidance=do_classifier_free_guidance,
mode="read",
batch_size=batch_size,
fusion_blocks="full",
)
num_channels_latents = self.denoising_unet.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
width,
height,
video_length,
clip_image_embeds.dtype,
device,
generator,
)
# Prepare extra step kwargs.
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# Prepare ref image latents
ref_image_tensor = rearrange(ref_image, "b f c h w -> (b f) c h w")
ref_image_tensor = self.ref_image_processor.preprocess(ref_image_tensor, height=height, width=width) # (bs, c, height, width)
ref_image_tensor = ref_image_tensor.to(dtype=self.vae.dtype, device=self.vae.device)
ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w)
face_mask = face_mask.unsqueeze(1).to(dtype=self.face_locator.dtype, device=self.face_locator.device) # (bs, f, c, H, W)
face_mask = repeat(face_mask, "b f c h w -> b (repeat f) c h w", repeat=video_length)
face_mask = face_mask.transpose(1, 2) # (bs, c, f, H, W)
face_mask = self.face_locator(face_mask)
face_mask = torch.cat([torch.zeros_like(face_mask), face_mask], dim=0) if do_classifier_free_guidance else face_mask
pixel_values_full_mask = (
[torch.cat([mask] * 2) for mask in pixel_values_full_mask]
if do_classifier_free_guidance
else pixel_values_full_mask
)
pixel_values_face_mask = (
[torch.cat([mask] * 2) for mask in pixel_values_face_mask]
if do_classifier_free_guidance
else pixel_values_face_mask
)
pixel_values_lip_mask = (
[torch.cat([mask] * 2) for mask in pixel_values_lip_mask]
if do_classifier_free_guidance
else pixel_values_lip_mask
)
pixel_values_face_mask_ = []
for mask in pixel_values_face_mask:
pixel_values_face_mask_.append(
mask.to(device=self.denoising_unet.device, dtype=self.denoising_unet.dtype))
pixel_values_face_mask = pixel_values_face_mask_
pixel_values_lip_mask_ = []
for mask in pixel_values_lip_mask:
pixel_values_lip_mask_.append(
mask.to(device=self.denoising_unet.device, dtype=self.denoising_unet.dtype))
pixel_values_lip_mask = pixel_values_lip_mask_
pixel_values_full_mask_ = []
for mask in pixel_values_full_mask:
pixel_values_full_mask_.append(
mask.to(device=self.denoising_unet.device, dtype=self.denoising_unet.dtype))
pixel_values_full_mask = pixel_values_full_mask_
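# pair the audio with a zeroed (unconditional) copy only when the batch is doubled for guidance
audio_tensor = torch.cat([torch.zeros_like(audio_tensor), audio_tensor], dim=0) if do_classifier_free_guidance else audio_tensor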
audio_tensor = audio_tensor.to(dtype=self.denoising_unet.dtype, device=self.denoising_unet.device)
# denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# Forward reference image
if i == 0:
self.reference_unet(
ref_image_latents.repeat(
(2 if do_classifier_free_guidance else 1), 1, 1, 1
),
torch.zeros_like(t),
encoder_hidden_states=encoder_hidden_states,
return_dict=False,
)
reference_control_reader.update(reference_control_writer)
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
noise_pred = self.denoising_unet(
latent_model_input,
t,
encoder_hidden_states=encoder_hidden_states,
mask_cond_fea=face_mask,
full_mask=pixel_values_full_mask,
face_mask=pixel_values_face_mask,
lip_mask=pixel_values_lip_mask,
audio_embedding=audio_tensor,
motion_scale=motion_scale,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
reference_control_reader.clear()
reference_control_writer.clear()
# Post-processing
images = self.decode_latents(latents) # (b, c, f, h, w)
# Convert to tensor
if output_type == "tensor":
images = torch.from_numpy(images)
if not return_dict:
return images
return FaceAnimatePipelineOutput(videos=images)
================================================
FILE: hallo/animate/face_animate_static.py
================================================
# pylint: disable=R0801
"""
This module is responsible for handling the animation of faces using a combination of deep learning models and image processing techniques.
It provides a pipeline to generate realistic face animations by incorporating user-provided conditions such as facial expressions and environments.
The module utilizes various schedulers and utilities to optimize the animation process and ensure efficient performance.
Functions and Classes:
- StaticPipelineOutput: A class that represents the output of the static pipeline,
containing the generated images.
- prepare_latents: A function that prepares the initial noise for the animation process,
scaling it according to the scheduler's requirements.
- prepare_condition: A function that processes the user-provided conditions
(e.g., facial expressions) and prepares them for use in the animation pipeline.
- decode_latents: A function that decodes the latent representations of the face animations into
their corresponding image formats.
- prepare_extra_step_kwargs: A function that prepares additional parameters for each step of
the animation process, such as the generator and eta values.
Dependencies:
- numpy: A library for numerical computing.
- torch: PyTorch, a deep learning framework.
- diffusers: Hugging Face's library of diffusion models, pipelines, and schedulers.
- transformers: Hugging Face's library of pre-trained transformer models.
Usage:
- To create an instance of the animation pipeline, provide the necessary components: the VAE,
reference UNet, denoising UNet, face locator, image projection module, and scheduler.
- Use the pipeline's methods to prepare the latents, conditions, and extra step arguments as
required for the animation process.
- Generate the face animations by decoding the latents and processing the conditions.
Note:
- The module is designed to work with the diffusers library, which is based on
the paper "Diffusion Models for Image-to-Image Translation" (https://arxiv.org/abs/2102.02765).
- The face animations generated by this module should be used for entertainment purposes
only and should respect the rights and privacy of the individuals involved.
"""
import inspect
from dataclasses import dataclass
from typing import Callable, List, Optional, Union
import numpy as np
import torch
from diffusers import DiffusionPipeline
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import (DDIMScheduler, DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler, LMSDiscreteScheduler,
PNDMScheduler)
from diffusers.utils import BaseOutput, is_accelerate_available
from diffusers.utils.torch_utils import randn_tensor
from einops import rearrange
from tqdm import tqdm
from transformers import CLIPImageProcessor
from hallo.models.mutual_self_attention import ReferenceAttentionControl
if is_accelerate_available():
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
@dataclass
class StaticPipelineOutput(BaseOutput):
"""
StaticPipelineOutput is a class that represents the output of the static pipeline.
It contains the images generated by the pipeline as a union of torch.Tensor and np.ndarray.
Attributes:
images (Union[torch.Tensor, np.ndarray]): The generated images.
"""
images: Union[torch.Tensor, np.ndarray]
class StaticPipeline(DiffusionPipeline):
"""
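StaticPipeline is a diffusion pipeline that generates a single still image from a
reference image, a face mask, and a face embedding. Features from the reference UNet
are injected into the denoising UNet through ReferenceAttentionControl.
Attributes:
vae: The VAE used to encode the reference image and decode the output latents.
reference_unet: The UNet that extracts reference-image features.
denoising_unet: The UNet that predicts the noise at each sampling step.
face_locator: The module that encodes the face mask into conditioning features.
imageproj: The projection that maps the face embedding to encoder hidden states.
scheduler: The noise scheduler used for sampling.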
"""
_optional_components = []
def __init__(
self,
vae,
reference_unet,
denoising_unet,
face_locator,
imageproj,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
],
):
super().__init__()
self.register_modules(
vae=vae,
reference_unet=reference_unet,
denoising_unet=denoising_unet,
face_locator=face_locator,
scheduler=scheduler,
imageproj=imageproj,
)
self.vae_scale_factor = 2 ** (
len(self.vae.config.block_out_channels) - 1)
self.clip_image_processor = CLIPImageProcessor()
self.ref_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
)
self.cond_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor,
do_convert_rgb=True,
do_normalize=False,
)
def enable_vae_slicing(self):
"""
Enable sliced VAE decoding.
With slicing enabled, the VAE decodes its input in slices rather than all at once, which lowers peak memory usage at a small cost in speed.
"""
self.vae.enable_slicing()
def disable_vae_slicing(self):
"""
Disable sliced VAE decoding.
This calls the VAE's `disable_slicing()` method so that latents are decoded in a single step,
which can be faster but uses more memory.
"""
self.vae.disable_slicing()
def enable_sequential_cpu_offload(self, gpu_id=0):
"""
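Offload the pipeline's models to the CPU with accelerate's `cpu_offload`.
Each model's weights stay on the CPU and are moved to the GPU only while that model runs,
which reduces GPU memory usage at the cost of speed.
Args:
gpu_id (int, optional): The ID of the GPU used for execution. Defaults to 0.
"""
device = torch.device(f"cuda:{gpu_id}")
for cpu_offloaded_model in [self.reference_unet, self.denoising_unet, self.face_locator, self.imageproj, self.vae]: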
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
@property
def _execution_device(self):
if self.device != torch.device("meta") or not hasattr(self.denoising_unet, "_hf_hook"):
return self.device
for module in self.denoising_unet.modules():
if (
hasattr(module, "_hf_hook")
and hasattr(module._hf_hook, "execution_device")
and module._hf_hook.execution_device is not None
):
return torch.device(module._hf_hook.execution_device)
return self.device
def decode_latents(self, latents):
"""
Decode the given latents into image frames with the VAE, one frame at a time.
Parameters:
latents (torch.Tensor): The latents to be decoded. Shape: (batch_size, num_channels_latents, video_length, height, width).
Returns:
np.ndarray: The decoded frames with values in [0, 1]. Shape: (batch_size, channels, video_length, H, W), where H and W are in pixel space.
"""
video_length = latents.shape[2]
latents = 1 / 0.18215 * latents
latents = rearrange(latents, "b c f h w -> (b f) c h w")
# video = self.vae.decode(latents).sample
video = []
for frame_idx in tqdm(range(latents.shape[0])):
video.append(self.vae.decode(
latents[frame_idx: frame_idx + 1]).sample)
video = torch.cat(video)
video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
video = (video / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
video = video.cpu().float().numpy()
return video
def prepare_extra_step_kwargs(self, generator, eta):
"""
Prepare extra keyword arguments for the scheduler step.
Since not all schedulers have the same signature, this function helps to create a consistent interface for the scheduler.
Args:
generator (Optional[torch.Generator]): A random number generator for reproducibility.
eta (float): The eta parameter used with the DDIMScheduler. It should be between 0 and 1.
Returns:
dict: A dictionary containing the extra keyword arguments for the scheduler step.
"""
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(
inspect.signature(self.scheduler.step).parameters.keys()
)
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(
inspect.signature(self.scheduler.step).parameters.keys()
)
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def prepare_latents(
self,
batch_size,
num_channels_latents,
width,
height,
dtype,
device,
generator,
latents=None,
):
"""
Prepares the initial latents for the diffusion pipeline.
Args:
batch_size (int): The number of images to generate in one forward pass.
num_channels_latents (int): The number of channels in the latents tensor.
width (int): The target image width in pixels; the latent width is width // vae_scale_factor.
height (int): The target image height in pixels; the latent height is height // vae_scale_factor.
dtype (torch.dtype): The data type of the latents tensor.
device (torch.device): The device to place the latents tensor on.
generator (Optional[torch.Generator], optional): A random number generator
for reproducibility. Defaults to None.
latents (Optional[torch.Tensor], optional): Pre-computed latents to use as
initial conditions for the diffusion process. Defaults to None.
Returns:
torch.Tensor: The prepared latents tensor.
"""
shape = (
batch_size,
num_channels_latents,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(
shape, generator=generator, device=device, dtype=dtype
)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
def prepare_condition(
self,
cond_image,
width,
height,
device,
dtype,
do_classififer_free_guidance=False,
):
"""
Prepares the condition for the face animation pipeline.
Args:
cond_image (torch.Tensor): The conditional image tensor.
width (int): The width of the output image.
height (int): The height of the output image.
device (torch.device): The device to run the pipeline on.
dtype (torch.dtype): The data type of the tensor.
do_classififer_free_guidance (bool, optional): Whether to use classifier-free guidance or not. Defaults to False.
Returns:
torch.Tensor: The preprocessed condition image, duplicated along the batch dimension when classifier-free guidance is enabled.
"""
image = self.cond_image_processor.preprocess(
cond_image, height=height, width=width
).to(dtype=torch.float32)
image = image.to(device=device, dtype=dtype)
if do_classififer_free_guidance:
image = torch.cat([image] * 2)
return image
@torch.no_grad()
def __call__(
self,
ref_image,
face_mask,
width,
height,
num_inference_steps,
guidance_scale,
face_embedding,
num_images_per_prompt=1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator,
List[torch.Generator]]] = None,
output_type: Optional[str] = "tensor",
return_dict: bool = True,
callback: Optional[Callable[[
int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
# Default height and width to unet
height = height or self.denoising_unet.config.sample_size * self.vae_scale_factor
width = width or self.denoising_unet.config.sample_size * self.vae_scale_factor
device = self._execution_device
do_classifier_free_guidance = guidance_scale > 1.0
# Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
batch_size = 1
image_prompt_embeds = self.imageproj(face_embedding)
uncond_image_prompt_embeds = self.imageproj(
torch.zeros_like(face_embedding))
if do_classifier_free_guidance:
image_prompt_embeds = torch.cat(
[uncond_image_prompt_embeds, image_prompt_embeds], dim=0
)
reference_control_writer = ReferenceAttentionControl(
self.reference_unet,
do_classifier_free_guidance=do_classifier_free_guidance,
mode="write",
batch_size=batch_size,
fusion_blocks="full",
)
reference_control_reader = ReferenceAttentionControl(
self.denoising_unet,
do_classifier_free_guidance=do_classifier_free_guidance,
mode="read",
batch_size=batch_size,
fusion_blocks="full",
)
num_channels_latents = self.denoising_unet.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
width,
height,
face_embedding.dtype,
device,
generator,
)
latents = latents.unsqueeze(2) # (bs, c, 1, h', w')
# latents_dtype = latents.dtype
# Prepare extra step kwargs.
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# Prepare ref image latents
ref_image_tensor = self.ref_image_processor.preprocess(
ref_image, height=height, width=width
) # (bs, c, height, width)
ref_image_tensor = ref_image_tensor.to(
dtype=self.vae.dtype, device=self.vae.device
)
ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w)
# Prepare face mask image
face_mask_tensor = self.cond_image_processor.preprocess(
face_mask, height=height, width=width
)
face_mask_tensor = face_mask_tensor.unsqueeze(2) # (bs, c, 1, h, w)
face_mask_tensor = face_mask_tensor.to(
device=device, dtype=self.face_locator.dtype
)
mask_fea = self.face_locator(face_mask_tensor)
mask_fea = (
torch.cat(
[mask_fea] * 2) if do_classifier_free_guidance else mask_fea
)
# denoising loop
num_warmup_steps = len(timesteps) - \
num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# 1. Forward reference image
if i == 0:
self.reference_unet(
ref_image_latents.repeat(
(2 if do_classifier_free_guidance else 1), 1, 1, 1
),
torch.zeros_like(t),
encoder_hidden_states=image_prompt_embeds,
return_dict=False,
)
# 2. Copy the reference UNet features into the denoising UNet
reference_control_reader.update(reference_control_writer)
# 3.1 expand the latents if we are doing classifier free guidance
latent_model_input = (
torch.cat(
[latents] * 2) if do_classifier_free_guidance else latents
)
latent_model_input = self.scheduler.scale_model_input(
latent_model_input, t
)
noise_pred = self.denoising_unet(
latent_model_input,
t,
encoder_hidden_states=image_prompt_embeds,
mask_cond_fea=mask_fea,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(
noise_pred, t, latents, **extra_step_kwargs, return_dict=False
)[0]
# call the callback, if provided
if i == len(timesteps) - 1 or (
(i + 1) > num_warmup_steps and (i +
1) % self.scheduler.order == 0
):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
reference_control_reader.clear()
reference_control_writer.clear()
# Post-processing
image = self.decode_latents(latents) # (b, c, 1, h, w)
# Convert to tensor
if output_type == "tensor":
image = torch.from_numpy(image)
if not return_dict:
return image
return StaticPipelineOutput(images=image)
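`prepare_extra_step_kwargs` (present in both pipelines) inspects the signature of `self.scheduler.step` so that one sampling loop works with DDIM, whose `step` accepts `eta`, and with schedulers whose `step` does not. The self-contained sketch below shows the pattern; `EtaScheduler`, `PlainScheduler`, and `extra_step_kwargs_for` are hypothetical stand-ins rather than diffusers classes, used so the example runs without loading a real scheduler.

```python
import inspect
from typing import Any, Dict, Optional

import torch


class EtaScheduler:
    """Hypothetical stand-in whose step() accepts both eta and generator (like DDIM)."""

    def step(self, model_output, timestep, sample, eta=0.0, generator=None, return_dict=True):
        return sample


class PlainScheduler:
    """Hypothetical stand-in whose step() accepts neither extra argument."""

    def step(self, model_output, timestep, sample, return_dict=True):
        return sample


def extra_step_kwargs_for(
    scheduler: Any, generator: Optional[torch.Generator], eta: float
) -> Dict[str, Any]:
    """Forward eta and generator only if scheduler.step() declares those parameters."""
    params = inspect.signature(scheduler.step).parameters  # bound method: 'self' is excluded
    kwargs: Dict[str, Any] = {}
    if "eta" in params:
        kwargs["eta"] = eta
    if "generator" in params:
        kwargs["generator"] = generator
    return kwargs


gen = torch.Generator().manual_seed(42)
print(sorted(extra_step_kwargs_for(EtaScheduler(), gen, eta=0.5)))  # ['eta', 'generator']
print(extra_step_kwargs_for(PlainScheduler(), gen, eta=0.5))  # {}
```

Filtering by signature keeps the call site `self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)` identical across scheduler types instead of branching on the scheduler class.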
================================================
FILE: hallo/datasets/__init__.py
================================================
================================================
FILE: hallo/datasets/audio_processor.py
================================================
# pylint: disable=C0301
'''
This module contains the AudioProcessor class and related functions for processing audio data.
It utilizes various libraries and models to perform tasks such as preprocessing, feature extraction,
and audio separation. The class is initialized with configuration parameters and can process
audio files using the provided models.
'''
import math
import os
import librosa
import numpy as np
import torch
from audio_separator.separator import Separator
from einops import rearrange
from transformers import Wav2Vec2FeatureExtractor
from hallo.models.wav2vec import Wav2VecModel
from hallo.utils.util import resample_audio
class AudioProcessor:
"""
AudioProcessor is a class that handles the processing of audio files.
It takes care of preprocessing the audio files, extracting features
using wav2vec models, and separating audio signals if needed.
:param sample_rate: Sampling rate of the audio file
:param fps: Frames per second for the extracted features
:param wav2vec_model_path: Path to the wav2vec model
:param only_last_features: If True, use only the last wav2vec hidden state; otherwise stack all hidden layers
:param audio_separator_model_path: Path to the audio separator model
:param audio_separator_model_name: Name of the audio separator model
:param cache_dir: Directory to cache the intermediate results
:param device: Device to run the processing on
"""
def __init__(
self,
sample_rate,
fps,
wav2vec_model_path,
only_last_features,
audio_separator_model_path:str=None,
audio_separator_model_name:str=None,
cache_dir:str='',
device="cuda:0",
) -> None:
self.sample_rate = sample_rate
self.fps = fps
self.device = device
self.audio_encoder = Wav2VecModel.from_pretrained(wav2vec_model_path, local_files_only=True).to(device=device)
self.audio_encoder.feature_extractor._freeze_parameters()
self.only_last_features = only_last_features
if audio_separator_model_name is not None:
try:
os.makedirs(cache_dir, exist_ok=True)
except OSError as _:
print("Failed to create the output cache dir.")
self.audio_separator = Separator(
output_dir=cache_dir,
output_single_stem="vocals",
model_file_dir=audio_separator_model_path,
)
self.audio_separator.load_model(audio_separator_model_name)
assert self.audio_separator.model_instance is not None, "Failed to load the audio separator model."
else:
self.audio_separator=None
print("Using the audio directly without vocal separation.")
self.wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_model_path, local_files_only=True)
def preprocess(self, wav_file: str, clip_length: int=-1):
"""
Preprocess a WAV audio file and convert it into wav2vec2 embeddings.
If an audio separator is configured, the vocals are first separated from the background and
resampled to `self.sample_rate`; otherwise the input file is used as-is.
Args:
wav_file (str): The path to the WAV file to be processed. This file should be accessible and in WAV format.
clip_length (int, optional): If positive, the audio is zero-padded so that the number of video
frames it covers is a multiple of clip_length. Defaults to -1 (no padding).
Raises:
RuntimeError: If the audio separator produces no output.
Returns:
Tuple[torch.Tensor, int]: The audio embedding on the CPU, and the number of video frames
covered by the original (unpadded) audio.
"""
if self.audio_separator is not None:
# 1. separate vocals
# TODO: process in memory
outputs = self.audio_separator.separate(wav_file)
if len(outputs) <= 0:
raise RuntimeError("Audio separate failed.")
vocal_audio_file = outputs[0]
vocal_audio_name, _ = os.path.splitext(vocal_audio_file)
vocal_audio_file = os.path.join(self.audio_separator.output_dir, vocal_audio_file)
vocal_audio_file = resample_audio(vocal_audio_file, os.path.join(self.audio_separator.output_dir, f"{vocal_audio_name}-16k.wav"), self.sample_rate)
else:
vocal_audio_file=wav_file
# 2. extract wav2vec features
speech_array, sampling_rate = librosa.load(vocal_audio_file, sr=self.sample_rate)
audio_feature = np.squeeze(self.wav2vec_feature_extractor(speech_array, sampling_rate=sampling_rate).input_values)
seq_len = math.ceil(len(audio_feature) / self.sample_rate * self.fps)
audio_length = seq_len
audio_feature = torch.from_numpy(audio_feature).float().to(device=self.device)
if clip_length>0 and seq_len % clip_length != 0:
audio_feature = torch.nn.functional.pad(audio_feature, (0, (clip_length - seq_len % clip_length) * (self.sample_rate // self.fps)), 'constant', 0.0)
seq_len += clip_length - seq_len % clip_length
audio_feature = audio_feature.unsqueeze(0)
with torch.no_grad():
embeddings = self.audio_encoder(audio_feature, seq_len=seq_len, output_hidden_states=True)
assert len(embeddings) > 0, "Failed to extract audio embedding"
if self.only_last_features:
audio_emb = embeddings.last_hidden_state.squeeze()
else:
audio_emb = torch.stack(embeddings.hidden_states[1:], dim=1).squeeze(0)
audio_emb = rearrange(audio_emb, "b s d -> s b d")
audio_emb = audio_emb.cpu().detach()
return audio_emb, audio_length
def get_embedding(self, wav_file: str):
"""Convert a WAV audio file directly into wav2vec2 embeddings, without vocal separation or padding.
The processor must be configured with a 16 kHz sample rate.
Args:
wav_file (str): The path to the WAV file to be processed. This file should be accessible and in WAV format.
Returns:
torch.Tensor: The audio embedding on the CPU.
"""
speech_array, sampling_rate = librosa.load(
wav_file, sr=self.sample_rate)
assert sampling_rate == 16000, "The audio sample rate must be 16000"
audio_feature = np.squeeze(self.wav2vec_feature_extractor(
speech_array, sampling_rate=sampling_rate).input_values)
seq_len = math.ceil(len(audio_feature) / self.sample_rate * self.fps)
audio_feature = torch.from_numpy(
audio_feature).float().to(device=self.device)
audio_feature = audio_feature.unsqueeze(0)
with torch.no_grad():
embeddings = self.audio_encoder(
audio_feature, seq_len=seq_len, output_hidden_states=True)
assert len(embeddings) > 0, "Failed to extract audio embedding"
if self.only_last_features:
audio_emb = embeddings.last_hidden_state.squeeze()
else:
audio_emb = torch.stack(
embeddings.hidden_states[1:], dim=1).squeeze(0)
audio_emb = rearrange(audio_emb, "b s d -> s b d")
audio_emb = audio_emb.cpu().detach()
return audio_emb
def close(self):
"""
Release resources held by the processor. Currently a no-op placeholder (TODO: to be implemented).
"""
return self
def __enter__(self):
return self
def __exit__(self, _exc_type, _exc_val, _exc_tb):
self.close()
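The frame bookkeeping in `preprocess` can be checked in isolation: the number of video frames is `ceil(num_samples / sample_rate * fps)`, and when `clip_length` is positive the waveform is zero-padded by `sample_rate // fps` samples per missing frame so that the frame count becomes a multiple of `clip_length`. The sketch below reimplements just that arithmetic on a raw tensor; `pad_to_clip_multiple` is a hypothetical helper, and the 16 kHz / 25 fps / 16-frame values are illustrative assumptions, not values read from Hallo's configs.

```python
import math
from typing import Tuple

import torch


def pad_to_clip_multiple(
    samples: torch.Tensor, sample_rate: int, fps: int, clip_length: int
) -> Tuple[torch.Tensor, int, int]:
    """Return (padded_samples, padded_frame_count, original_frame_count)."""
    seq_len = math.ceil(samples.shape[-1] / sample_rate * fps)
    audio_length = seq_len  # frames covered by the unpadded audio
    if clip_length > 0 and seq_len % clip_length != 0:
        missing_frames = clip_length - seq_len % clip_length
        # Zero-pad at the end: sample_rate // fps samples per missing frame.
        samples = torch.nn.functional.pad(
            samples, (0, missing_frames * (sample_rate // fps)), "constant", 0.0
        )
        seq_len += missing_frames
    return samples, seq_len, audio_length


# 3 s of silence at 16 kHz, 25 fps, 16-frame clips.
padded, seq_len, audio_length = pad_to_clip_multiple(torch.zeros(48000), 16000, 25, 16)
print(audio_length, seq_len, padded.shape[-1])  # 75 80 51200
```

Because `sample_rate // fps` is integer division, the padding lines up exactly with frame boundaries only when the sample rate is divisible by the frame rate (640 samples per frame in this example).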
================================================
FILE: hallo/datasets/image_processor.py
================================================
# pylint: disable=W0718
"""
This module is responsible for processing images, particularly for face-related tasks.
It uses various libraries such as OpenCV, NumPy, and InsightFace to perform tasks like
face detection, augmentation, and mask rendering. The ImageProcessor class encapsulates
the functionality for these operations.
"""
import os
from typing import List
import cv2
import mediapipe as mp
import numpy as np
import torch
from insightface.app import FaceAnalysis
from PIL import Image
from torchvision import transforms
from ..utils.util import (blur_mask, get_landmark_overframes, get_mask,
get_union_face_mask, get_union_lip_mask)
MEAN = 0.5
STD = 0.5
class ImageProcessor:
"""
ImageProcessor is a class responsible for processing images, particularly for face-related tasks.
It takes in an image and performs various operations such as augmentation, face detection,
face embedding extraction, and rendering a face mask. The processed images are then used for
further analysis or recognition purposes.
Attributes:
img_size (Tuple[int, int]): The (height, width) to which images are resized.
face_analysis_model_path (str): The path to the face analysis model.
Methods:
preprocess(source_image_path, cache_dir, face_region_ratio):
Preprocesses the input image by performing augmentation, face detection,
face embedding extraction, and face mask rendering.
close():
Closes the ImageProcessor and releases any resources being used.
_augmentation(images, transform, state=None):
Applies image augmentation to the input images using the given transform and state.
__enter__():
Enters a runtime context and returns the ImageProcessor object.
__exit__(_exc_type, _exc_val, _exc_tb):
Exits a runtime context and releases resources by calling close(); exceptions are not suppressed.
"""
def __init__(self, img_size, face_analysis_model_path) -> None:
self.img_size = img_size
self.pixel_transform = transforms.Compose(
[
transforms.Resize(self.img_size),
transforms.ToTensor(),
transforms.Normalize([MEAN], [STD]),
]
)
self.cond_transform = transforms.Compose(
[
transforms.Resize(self.img_size),
transforms.ToTensor(),
]
)
self.attn_transform_64 = transforms.Compose(
[
transforms.Resize(
(self.img_size[0] // 8, self.img_size[0] // 8)),
transforms.ToTensor(),
]
)
self.attn_transform_32 = transforms.Compose(
[
transforms.Resize(
(self.img_size[0] // 16, self.img_size[0] // 16)),
transforms.ToTensor(),
]
)
self.attn_transform_16 = transforms.Compose(
[
transforms.Resize(
(self.img_size[0] // 32, self.img_size[0] // 32)),
transforms.ToTensor(),
]
)
self.attn_transform_8 = transforms.Compose(
[
transforms.Resize(
(self.img_size[0] // 64, self.img_size[0] // 64)),
transforms.ToTensor(),
]
)
self.face_analysis = FaceAnalysis(
name="",
root=face_analysis_model_path,
providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
self.face_analysis.prepare(ctx_id=0, det_size=(640, 640))
def preprocess(self, source_image_path: str, cache_dir: str, face_region_ratio: float):
"""
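Apply preprocessing to the source image to prepare it for face analysis.
Parameters:
source_image_path (str): The path to the source image.
cache_dir (str): The directory in which intermediate mask images are cached.
face_region_ratio (float): The face region ratio passed to get_mask when rendering the face mask.
Returns:
tuple: (pixel_values_ref_img, face_mask, face_emb, pixel_values_full_mask,
pixel_values_face_mask, pixel_values_lip_mask), where the last three are lists of
flattened masks at four attention resolutions.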
"""
source_image = Image.open(source_image_path)
ref_image_pil = source_image.convert("RGB")
# 1. image augmentation
pixel_values_ref_img = self._augmentation(ref_image_pil, self.pixel_transform)
# 2.1 detect face
faces = self.face_analysis.get(cv2.cvtColor(np.array(ref_image_pil.copy()), cv2.COLOR_RGB2BGR))
if not faces:
print("No faces detected in the image. Using the entire image as the face region.")
# Use the entire image as the face region
face = {
"bbox": [0, 0, ref_image_pil.width, ref_image_pil.height],
"embedding": np.zeros(512)
}
else:
# Sort faces by size and select the largest one
faces_sorted = sorted(faces, key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]), reverse=True)
face = faces_sorted[0] # Select the largest face
# 2.2 face embedding
face_emb = face["embedding"]
# 2.3 render face mask
get_mask(source_image_path, cache_dir, face_region_ratio)
file_name = os.path.basename(source_image_path).split(".")[0]
face_mask_pil = Image.open(
os.path.join(cache_dir, f"{file_name}_face_mask.png")).convert("RGB")
face_mask = self._augmentation(face_mask_pil, self.cond_transform)
# 2.4 detect and expand lip, face mask
sep_background_mask = Image.open(
os.path.join(cache_dir, f"{file_name}_sep_background.png"))
sep_face_mask = Image.open(
os.path.join(cache_dir, f"{file_name}_sep_face.png"))
sep_lip_mask = Image.open(
os.path.join(cache_dir, f"{file_name}_sep_lip.png"))
pixel_values_face_mask = [
self._augmentation(sep_face_mask, self.attn_transform_64),
self._augmentation(sep_face_mask, self.attn_transform_32),
self._augmentation(sep_face_mask, self.attn_transform_16),
self._augmentation(sep_face_mask, self.attn_transform_8),
]
pixel_values_lip_mask = [
self._augmentation(sep_lip_mask, self.attn_transform_64),
self._augmentation(sep_lip_mask, self.attn_transform_32),
self._augmentation(sep_lip_mask, self.attn_transform_16),
self._augmentation(sep_lip_mask, self.attn_transform_8),
]
pixel_values_full_mask = [
self._augmentation(sep_background_mask, self.attn_transform_64),
self._augmentation(sep_background_mask, self.attn_transform_32),
self._augmentation(sep_background_mask, self.attn_transform_16),
self._augmentation(sep_background_mask, self.attn_transform_8),
]
pixel_values_full_mask = [mask.view(1, -1)
for mask in pixel_values_full_mask]
pixel_values_face_mask = [mask.view(1, -1)
for mask in pixel_values_face_mask]
pixel_values_lip_mask = [mask.view(1, -1)
for mask in pixel_values_lip_mask]
return pixel_values_ref_img, face_mask, face_emb, pixel_values_full_mask, pixel_values_face_mask, pixel_values_lip_mask
def close(self):
"""
Closes the ImageProcessor and releases any resources held by the FaceAnalysis instance.
Args:
self: The ImageProcessor instance.
Returns:
None.
"""
for _, model in self.face_analysis.models.items():
if hasattr(model, "Dispose"):
model.Dispose()
def _augmentation(self, images, transform, state=None):
if state is not None:
torch.set_rng_state(state)
if isinstance(images, List):
transformed_images = [transform(img) for img in images]
ret_tensor = torch.stack(transformed_images, dim=0) # (f, c, h, w)
else:
ret_tensor = transform(images) # (c, h, w)
return ret_tensor
def __enter__(self):
return self
def __exit__(self, _exc_type, _exc_val, _exc_tb):
self.close()
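# Sketch (hypothetical helpers, not called anywhere in hallo): two details of
# ImageProcessor.preprocess restated as pure functions so they can be checked in
# isolation. _sketch_largest_face mirrors the "largest bounding-box area wins"
# rule and the whole-image fallback used when no face is detected (a list of
# zeros stands in for np.zeros(512)). _sketch_attn_mask_sizes lists the side
# lengths produced by attn_transform_64/32/16/8 and the length of each mask
# after view(1, -1), assuming a square img_size and single-channel masks.
# Faces are modelled as plain dicts with a "bbox" of [x1, y1, x2, y2].
def _sketch_largest_face(faces, image_width, image_height, emb_dim=512):
    """Return the face with the largest bbox area, or a whole-image fallback."""
    if not faces:
        # No detection: use the full image as the face region and a zero embedding.
        return {"bbox": [0, 0, image_width, image_height], "embedding": [0.0] * emb_dim}
    def bbox_area(face):
        x1, y1, x2, y2 = face["bbox"]
        return (x2 - x1) * (y2 - y1)
    # max() keeps the first of several equally large faces, like the stable
    # sort with reverse=True in preprocess.
    return max(faces, key=bbox_area)
def _sketch_attn_mask_sizes(img_side):
    """Return the mask side lengths (img_side // 8, 16, 32, 64) and their flattened lengths."""
    sides = [img_side // factor for factor in (8, 16, 32, 64)]
    # A single-channel side x side mask flattened with view(1, -1) holds side * side values.
    flat_lengths = [side * side for side in sides]
    return sides, flat_lengths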
class ImageProcessorForDataProcessing():
"""
ImageProcessorForDataProcessing prepares face data during training-data preprocessing.
Depending on `step`, it either extracts a face embedding from a directory of frames
(step 2, using an InsightFace FaceAnalysis model) or renders union face and lip masks
from per-frame landmarks (any other step, using a MediaPipe FaceLandmarker).
Args:
face_analysis_model_path (str): The root directory of the face analysis models.
landmark_model_path (str): The path to the face landmarker model asset.
step (int): The preprocessing step; 2 enables face embedding extraction, any other
value enables landmark-based mask rendering.
Attributes:
face_analysis (FaceAnalysis or None): The face analysis model, created only when step == 2.
landmarker (FaceLandmarker or None): The landmark model, created only when step != 2.
Methods:
preprocess(source_image_path):
Processes a directory of frames and returns the face mask, the face embedding, and
the separated pose, face, and lip masks; values not computed in the current step are None.
close():
Releases the resources held by the face analysis and landmark models.
_augmentation(images, transform, state=None):
Applies the given transform to the input images, optionally restoring an RNG state first.
__enter__():
Enters a runtime context and returns the processor.
__exit__(_exc_type, _exc_val, _exc_tb):
Exits the runtime context and calls close().
"""
def __init__(self, face_analysis_model_path, landmark_model_path, step) -> None:
if step == 2:
self.face_analysis = FaceAnalysis(
name="",
root=face_analysis_model_path,
providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
self.face_analysis.prepare(ctx_id=0, det_size=(640, 640))
self.landmarker = None
else:
BaseOptions = mp.tasks.BaseOptions
FaceLandmarker = mp.tasks.vision.FaceLandmarker
FaceLandmarkerOptions = mp.tasks.vision.FaceLandmarkerOptions
VisionRunningMode = mp.tasks.vision.RunningMode
# Create a face landmarker instance in image mode:
options = FaceLandmarkerOptions(
base_options=BaseOptions(model_asset_path=landmark_model_path),
running_mode=VisionRunningMode.IMAGE,
)
self.landmarker = FaceLandmarker.create_from_options(options)
self.face_analysis = None
def preprocess(self, source_image_path: str):
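"""
Extract face data from a directory of video frames.
Parameters:
source_image_path (str): The path to a directory containing the extracted frames.
Returns:
tuple: (face_mask, face_emb, sep_pose_mask, sep_face_mask, sep_lip_mask).
face_emb is the embedding of the largest face in the first frame (in sorted order)
where detection succeeds (step 2 only); the masks are computed from landmarks over
all frames (landmark mode only). Values not computed in the current mode are None.
"""
# 1. get face embedding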
face_mask, face_emb, sep_pose_mask, sep_face_mask, sep_lip_mask = None, None, None, None, None
if self.face_analysis:
for frame in sorted(os.listdir(source_image_path)):
try:
source_image = Image.open(
os.path.join(source_image_path, frame))
ref_image_pil = source_image.convert("RGB")
# 2.1 detect face
faces = self.face_analysis.get(cv2.cvtColor(
np.array(ref_image_pil.copy()), cv2.COLOR_RGB2BGR))
# use max size face
face = sorted(faces, key=lambda x: (
x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]))[-1]
# 2.2 face embedding
face_emb = face["embedding"]
if face_emb is not None:
break
except Exception as _:
continue
if self.landmarker:
# 3.1 get landmark
landmarks, height, width = get_landmark_overframes(
self.landmarker, source_image_path)
assert len(landmarks) == len(os.listdir(source_image_path))
# 3 render face and lip mask
face_mask = get_union_face_mask(landmarks, height, width)
lip_mask = get_union_lip_mask(landmarks, height, width)
# 4 gaussian blur
blur_face_mask = blur_mask(face_mask, (64, 64), (51, 51))
blur_lip_mask = blur_mask(lip_mask, (64, 64), (31, 31))
# 5 separate masks: face without lips, everything outside the face (pose), and lips
sep_face_mask = cv2.subtract(blur_face_mask, blur_lip_mask)
sep_pose_mask = 255.0 - blur_face_mask
sep_lip_mask = blur_lip_mask
return face_mask, face_emb, sep_pose_mask, sep_face_mask, sep_lip_mask
def close(self):
"""
Closes the processor and releases any resources held by the FaceAnalysis and
FaceLandmarker instances. Only the model created for the current step is released.
Returns:
None.
"""
if self.face_analysis is not None:
for _, model in self.face_analysis.models.items():
if hasattr(model, "Dispose"):
model.Dispose()
if self.landmarker is not None:
self.landmarker.close()
def _augmentation(self, images, transform, state=None):
if state is not None:
torch.set_rng_state(state)
if isinstance(images, List):
transformed_images = [transform(img) for img in images]
ret_tensor = torch.stack(transformed_images, dim=0) # (f, c, h, w)
else:
ret_tensor = transform(images) # (c, h, w)
return ret_tensor
def __enter__(self):
return self
def __exit__(self, _exc_type, _exc_val, _exc_tb):
self.close()
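# Sketch (hypothetical helper, not called anywhere in hallo): step 5 of
# ImageProcessorForDataProcessing.preprocess restated with NumPy. It assumes the
# blurred masks are 8-bit arrays in [0, 255]; for uint8 inputs cv2.subtract
# saturates at 0 instead of wrapping around, which np.clip reproduces here, so
# lip pixels are removed from the face mask without producing large values.
# The pose mask is the complement of the blurred face mask and, as with
# 255.0 - mask in the original, comes out as a float array.
def _sketch_separate_masks(blur_face_mask, blur_lip_mask):
    """Return (sep_pose_mask, sep_face_mask, sep_lip_mask) from blurred face and lip masks."""
    import numpy as np  # local import keeps the sketch self-contained
    face = np.asarray(blur_face_mask, dtype=np.uint8)
    lip = np.asarray(blur_lip_mask, dtype=np.uint8)
    # Widen before subtracting so negative differences can be clipped to 0.
    sep_face_mask = np.clip(face.astype(np.int16) - lip.astype(np.int16), 0, 255).astype(np.uint8)
    sep_pose_mask = 255.0 - face.astype(np.float64)
    sep_lip_mask = lip
    return sep_pose_mask, sep_face_mask, sep_lip_mask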
================================================
FILE: hallo/datasets/mask_image.py
================================================
# pylint: disable=R0801
"""
This module defines FaceMaskDataset, a PyTorch Dataset that loads a reference frame
and a target frame from a directory of video frames, together with a face mask and a
precomputed face embedding. The class provides a transform helper (augmentation),
item access (__getitem__), and the dataset length (__len__).
"""
import json
import random
from pathlib import Path
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from transformers import CLIPImageProcessor
class FaceMaskDataset(Dataset):
"""
FaceMaskDataset is a custom dataset for face mask images.
Args:
img_size (int or tuple): The size to which frames and masks are resized.
drop_ratio (float, optional): A drop ratio stored as an attribute; it is not used within this class. Defaults to 0.1.
data_meta_paths (list, optional): The paths to the JSON metadata files; each entry provides "image_path" (a directory of frames), "mask_path" and "face_emb". Defaults to ["./data/HDTF_meta.json"].
sample_margin (int, optional): The preferred minimum distance, in frames, between the reference and target frames. Defaults to 30.
Attributes:
img_size (int or tuple): The size to which frames and masks are resized.
sample_margin (int): The preferred minimum distance between the reference and target frames.
vid_meta (list): The metadata entries loaded from all files in data_meta_paths.
drop_ratio (float): The drop ratio passed to the constructor.
clip_image_processor (CLIPImageProcessor): A CLIP image processor (not used by __getitem__).
transform (transforms.Compose): Resizes frames, converts them to tensors, and normalizes them to [-1, 1].
cond_transform (transforms.Compose): Resizes the mask and converts it to a tensor without normalization.
"""
def __init__(
self,
img_size,
drop_ratio=0.1,
data_meta_paths=None,
sample_margin=30,
):
super().__init__()
self.img_size = img_size
self.sample_margin = sample_margin
if data_meta_paths is None:
data_meta_paths = ["./data/HDTF_meta.json"]
vid_meta = []
for data_meta_path in data_meta_paths:
with open(data_meta_path, "r", encoding="utf-8") as f:
vid_meta.extend(json.load(f))
self.vid_meta = vid_meta
self.length = len(self.vid_meta)
self.clip_image_processor = CLIPImageProcessor()
self.transform = transforms.Compose(
[
transforms.Resize(self.img_size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
self.cond_transform = transforms.Compose(
[
transforms.Resize(self.img_size),
transforms.ToTensor(),
]
)
self.drop_ratio = drop_ratio
def augmentation(self, image, transform, state=None):
"""
Apply a transform to the input image, optionally restoring an RNG state first.
Args:
image (PIL.Image): The input image.
transform (torchvision.transforms.Compose): The transform to apply.
state (torch.ByteTensor, optional): A torch RNG state to restore before transforming,
so that several calls apply identical random operations. Defaults to None.
Returns:
torch.Tensor: The transformed image.
"""
if state is not None:
torch.set_rng_state(state)
return transform(image)
def __getitem__(self, index):
video_meta = self.vid_meta[index]
video_path = video_meta["image_path"]
mask_path = video_meta["mask_path"]
face_emb_path = video_meta["face_emb"]
video_frames = sorted(Path(video_path).iterdir())
video_length = len(video_frames)
margin = min(self.sample_margin, video_length)
ref_img_idx = random.randint(0, video_length - 1)
if ref_img_idx + margin < video_length:
tgt_img_idx = random.randint(
ref_img_idx + margin, video_length - 1)
elif ref_img_idx - margin > 0:
tgt_img_idx = random.randint(0, ref_img_idx - margin)
else:
tgt_img_idx = random.randint(0, video_length - 1)
ref_img_pil = Image.open(video_frames[ref_img_idx])
tgt_img_pil = Image.open(video_frames[tgt_img_idx])
tgt_mask_pil = Image.open(mask_path)
assert ref_img_pil is not None, "Failed to load reference image."
assert tgt_img_pil is not None, "Failed to load target image."
assert tgt_mask_pil is not None, "Failed to load target mask."
state = torch.get_rng_state()
tgt_img = self.augmentation(tgt_img_pil, self.transform, state)
tgt_mask_img = self.augmentation(
tgt_mask_pil, self.cond_transform, state)
tgt_mask_img = tgt_mask_img.repeat(3, 1, 1)
ref_img_vae = self.augmentation(
ref_img_pil, self.transform, state)
face_emb = torch.load(face_emb_path)
sample = {
"video_dir": video_path,
"img": tgt_img,
"tgt_mask": tgt_mask_img,
"ref_img": ref_img_vae,
"face_emb": face_emb,
}
return sample
def __len__(self):
return len(self.vid_meta)
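# Sketch (hypothetical helper, not called anywhere in hallo): the reference/target
# frame sampling of FaceMaskDataset.__getitem__ as a pure function. The target is
# drawn at least `margin` frames after the reference when the clip allows it,
# otherwise at least `margin` frames before it (only when the reference index is
# greater than `margin`), and anywhere in the clip as a last resort. The `seed`
# argument is an addition so the sampling can be reproduced in isolation.
def _sketch_sample_ref_tgt(video_length, sample_margin, seed=None):
    """Return (ref_idx, tgt_idx) using the same branches as FaceMaskDataset.__getitem__."""
    import random  # local import keeps the sketch self-contained
    rng = random.Random(seed)
    margin = min(sample_margin, video_length)
    ref_idx = rng.randint(0, video_length - 1)
    if ref_idx + margin < video_length:
        tgt_idx = rng.randint(ref_idx + margin, video_length - 1)
    elif ref_idx - margin > 0:
        tgt_idx = rng.randint(0, ref_idx - margin)
    else:
        tgt_idx = rng.randint(0, video_length - 1)
    return ref_idx, tgt_idx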
if __name__ == "__main__":
data = FaceMaskDataset(img_size=(512, 512))
train_dataloader = torch.utils.data.DataLoader(
data, batch_size=4, shuffle=True, num_workers=1
)
for step, batch in enumerate(train_dataloader):
print(batch["tgt_mask"].shape)
break
================================================
FILE: hallo/datasets/talk_video.py
================================================
# pylint: disable=R0801
"""
talk_video.py
This module defines the TalkingVideoDataset class, a custom PyTorch dataset
for handling talking video data. The dataset uses video files, masks, and
embeddings to prepare data for tasks such as video generation and
speech-driven video animation.
Classes:
TalkingVideoDataset
Dependencies:
json
random
torch
decord.VideoReader, decord.cpu
PIL.Image
torch.utils.data.Dataset
torchvision.transforms
Example:
from hallo.datasets.talk_video import TalkingVideoDataset
from torch.utils.data import DataLoader
# Example configuration for the Wav2Vec model
class Wav2VecConfig:
def __init__(self, audio_type, model_scale, features):
self.audio_type = audio_type
self.model_scale = model_scale
self.features = features
wav2vec_cfg = Wav2VecConfig(audio_type="wav2vec2", model_scale="base", features="feature")
# Initialize dataset
dataset = TalkingVideoDataset(
img_size=(512, 512),
sample_rate=16000,
audio_margin=2,
n_motion_frames=0,
n_sample_frames=16,
data_meta_paths=["path/to/meta1.json", "path/to/meta2.json"],
wav2vec_cfg=wav2vec_cfg,
)
# Initialize dataloader
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
# Fetch one batch of data
batch = next(iter(dataloader))
print(batch["pixel_values_vid"].shape) # Example output: (4, 16, 3, 512, 512)
The TalkingVideoDataset class provides methods for loading video frames, masks,
audio embeddings, and other relevant data, applying transformations, and preparing
the data for training and evaluation in a deep learning pipeline.
Attributes:
img_size (tuple): The dimensions to resize the video frames to.
sample_rate (int): The audio sample rate.
audio_margin (int): The number of neighbouring audio feature frames taken on each side of each video frame.
n_motion_frames (int): The number of motion frames.
n_sample_frames (int): The number of sample frames.
data_meta_paths (list): List of paths to the JSON metadata files.
wav2vec_cfg (object): Configuration for the Wav2Vec model.
Methods:
augmentation(images, transform, state=None): Apply transformation to input images.
__getitem__(index): Get a sample from the dataset at the specified index.
__len__(): Return the length of the dataset.
"""
import json
import random
from typing import List
import torch
from decord import VideoReader, cpu
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
class TalkingVideoDataset(Dataset):
"""
A dataset class for processing talking video data.
Args:
img_size (tuple, optional): The size of the output images. Defaults to (512, 512).
sample_rate (int, optional): The sample rate of the audio data. Defaults to 16000.
audio_margin (int, optional): The number of neighbouring audio feature frames taken on each side of every sampled video frame. Defaults to 2.
n_motion_frames (int, optional): The number of preceding frames appended to the reference image as motion frames. Defaults to 0.
n_sample_frames (int, optional): The number of consecutive frames sampled per clip. Defaults to 16.
data_meta_paths (list): The paths to the JSON metadata files. Required despite the None default.
wav2vec_cfg (object): A configuration object providing `audio_type`, `model_scale` and `features`, which select the audio embedding key in the metadata. Required despite the None default.
Attributes:
img_size (tuple): The size of the output images.
sample_rate (int): The sample rate of the audio data.
audio_margin (int): The number of neighbouring audio feature frames on each side of a video frame.
n_motion_frames (int): The number of motion frames.
n_sample_frames (int): The number of sample frames.
audio_type, audio_model, audio_features (str): Values copied from wav2vec_cfg.
vid_meta (list): The metadata entries loaded from all files in data_meta_paths.
"""
def __init__(
self,
img_size=(512, 512),
sample_rate=16000,
audio_margin=2,
n_motion_frames=0,
n_sample_frames=16,
data_meta_paths=None,
wav2vec_cfg=None,
):
super().__init__()
self.sample_rate = sample_rate
self.img_size = img_size
self.audio_margin = audio_margin
self.n_motion_frames = n_motion_frames
self.n_sample_frames = n_sample_frames
self.audio_type = wav2vec_cfg.audio_type
self.audio_model = wav2vec_cfg.model_scale
self.audio_features = wav2vec_cfg.features
vid_meta = []
for data_meta_path in data_meta_paths:
with open(data_meta_path, "r", encoding="utf-8") as f:
vid_meta.extend(json.load(f))
self.vid_meta = vid_meta
self.length = len(self.vid_meta)
self.pixel_transform = transforms.Compose(
[
transforms.Resize(self.img_size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
self.cond_transform = transforms.Compose(
[
transforms.Resize(self.img_size),
transforms.ToTensor(),
]
)
self.attn_transform_64 = transforms.Compose(
[
transforms.Resize(
(self.img_size[0] // 8, self.img_size[0] // 8)),
transforms.ToTensor(),
]
)
self.attn_transform_32 = transforms.Compose(
[
transforms.Resize(
(self.img_size[0] // 16, self.img_size[0] // 16)),
transforms.ToTensor(),
]
)
self.attn_transform_16 = transforms.Compose(
[
transforms.Resize(
(self.img_size[0] // 32, self.img_size[0] // 32)),
transforms.ToTensor(),
]
)
self.attn_transform_8 = transforms.Compose(
[
transforms.Resize(
(self.img_size[0] // 64, self.img_size[0] // 64)),
transforms.ToTensor(),
]
)
def augmentation(self, images, transform, state=None):
"""
Apply the given transformation to the input images.
Args:
images (List[PIL.Image] or PIL.Image): The input images to be transformed.
transform (torchvision.transforms.Compose): The transformation to be applied to the images.
state (torch.ByteTensor, optional): The state of the random number generator.
If provided, it will set the RNG state to this value before applying the transformation. Defaults to None.
Returns:
torch.Tensor: The transformed images as a tensor.
If the input was a list of images, the tensor will have shape (f, c, h, w),
where f is the number of images, c is the number of channels, h is the height, and w is the width.
If the input was a single image, the tensor will have shape (c, h, w),
where c is the number of channels, h is the height, and w is the width.
"""
if state is not None:
torch.set_rng_state(state)
if isinstance(images, List):
transformed_images = [transform(img) for img in images]
ret_tensor = torch.stack(transformed_images, dim=0) # (f, c, h, w)
else:
ret_tensor = transform(images) # (c, h, w)
return ret_tensor
def __getitem__(self, index):
video_meta = self.vid_meta[index]
video_path = video_meta["video_path"]
mask_path = video_meta["mask_path"]
lip_mask_union_path = video_meta.get("sep_mask_lip", None)
face_mask_union_path = video_meta.get("sep_mask_face", None)
full_mask_union_path = video_meta.get("sep_mask_border", None)
face_emb_path = video_meta["face_emb_path"]
audio_emb_path = video_meta[
f"{self.audio_type}_emb_{self.audio_model}_{self.audio_features}"
]
tgt_mask_pil = Image.open(mask_path)
video_frames = VideoReader(video_path, ctx=cpu(0))
assert tgt_mask_pil is not None, "Failed to load target mask."
assert (video_frames is not None and len(video_frames) > 0), "Failed to load video frames."
video_length = len(video_frames)
assert (
video_length
> self.n_sample_frames + self.n_motion_frames + 2 * self.audio_margin
)
# The lower bound also covers audio_margin so that the audio window
# (start_idx - audio_margin) never produces negative, wrap-around indices.
start_idx = random.randint(
max(self.n_motion_frames, self.audio_margin),
video_length - self.n_sample_frames - self.audio_margin - 1,
)
videos = video_frames[start_idx : start_idx + self.n_sample_frames]
frame_list = [
Image.fromarray(video).convert("RGB") for video in videos.asnumpy()
]
face_masks_list = [Image.open(face_mask_union_path)] * self.n_sample_frames
lip_masks_list = [Image.open(lip_mask_union_path)] * self.n_sample_frames
full_masks_list = [Image.open(full_mask_union_path)] * self.n_sample_frames
assert face_masks_list[0] is not None, "Failed to load face mask."
assert lip_masks_list[0] is not None, "Failed to load lip mask."
assert full_masks_list[0] is not None, "Failed to load full mask."
face_emb = torch.load(face_emb_path)
audio_emb = torch.load(audio_emb_path)
indices = (
torch.arange(2 * self.audio_margin + 1) - self.audio_margin
) # Offsets -audio_margin..audio_margin, e.g. [-2, -1, 0, 1, 2] when audio_margin == 2
center_indices = torch.arange(
start_idx,
start_idx + self.n_sample_frames,
).unsqueeze(1) + indices.unsqueeze(0)
audio_tensor = audio_emb[center_indices]
ref_img_idx = random.randint(
self.n_motion_frames,
video_length - self.n_sample_frames - self.audio_margin - 1,
)
ref_img = video_frames[ref_img_idx].asnumpy()
ref_img = Image.fromarray(ref_img)
if self.n_motion_frames > 0:
motions = video_frames[start_idx - self.n_motion_frames : start_idx]
motion_list = [
Image.fromarray(motion).convert("RGB") for motion in motions.asnumpy()
]
# transform
state = torch.get_rng_state()
pixel_values_vid = self.augmentation(frame_list, self.pixel_transform, state)
pixel_values_mask = self.augmentation(tgt_mask_pil, self.cond_transform, state)
pixel_values_mask = pixel_values_mask.repeat(3, 1, 1)
pixel_values_face_mask = [
self.augmentation(face_masks_list, self.attn_transform_64, state),
self.augmentation(face_masks_list, self.attn_transform_32, state),
self.augmentation(face_masks_list, self.attn_transform_16, state),
self.augmentation(face_masks_list, self.attn_transform_8, state),
]
pixel_values_lip_mask = [
self.augmentation(lip_masks_list, self.attn_transform_64, state),
self.augmentation(lip_masks_list, self.attn_transform_32, state),
self.augmentation(lip_masks_list, self.attn_transform_16, state),
self.augmentation(lip_masks_list, self.attn_transform_8, state),
]
pixel_values_full_mask = [
self.augmentation(full_masks_list, self.attn_transform_64, state),
self.augmentation(full_masks_list, self.attn_transform_32, state),
self.augmentation(full_masks_list, self.attn_transform_16, state),
self.augmentation(full_masks_list, self.attn_transform_8, state),
]
pixel_values_ref_img = self.augmentation(ref_img, self.pixel_transform, state)
pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0)
if self.n_motion_frames > 0:
pixel_values_motion = self.augmentation(
motion_list, self.pixel_transform, state
)
pixel_values_ref_img = torch.cat(
[pixel_values_ref_img, pixel_values_motion], dim=0
)
sample = {
"video_dir": video_path,
"pixel_values_vid": pixel_values_vid,
"pixel_values_mask": pixel_values_mask,
"pixel_values_face_mask": pixel_values_face_mask,
"pixel_values_lip_mask": pixel_values_lip_mask,
"pixel_values_full_mask": pixel_values_full_mask,
"audio_tensor": audio_tensor,
"pixel_values_ref_img": pixel_values_ref_img,
"face_emb": face_emb,
}
return sample
def __len__(self):
return len(self.vid_meta)
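# Sketch (hypothetical helpers, not called anywhere in hallo): how
# TalkingVideoDataset.__getitem__ builds the audio window. Each sampled frame t
# takes audio features t - audio_margin .. t + audio_margin, i.e. a column of
# frame indices broadcast against a row of offsets (unsqueeze(1) + unsqueeze(0)
# in the torch code; NumPy broadcasting here). _sketch_valid_start_range returns
# an inclusive start_idx range that keeps every window index non-negative and
# inside the clip, assuming one audio feature per video frame.
def _sketch_audio_window_indices(start_idx, n_sample_frames, audio_margin):
    """Return an (n_sample_frames, 2 * audio_margin + 1) array of audio feature indices."""
    import numpy as np  # local import keeps the sketch self-contained
    offsets = np.arange(2 * audio_margin + 1) - audio_margin  # e.g. [-2, -1, 0, 1, 2]
    frames = np.arange(start_idx, start_idx + n_sample_frames)
    return frames[:, None] + offsets[None, :]
def _sketch_valid_start_range(video_length, n_sample_frames, n_motion_frames, audio_margin):
    """Return (low, high), inclusive bounds for start_idx."""
    # Motion frames need start_idx >= n_motion_frames; the audio window needs start_idx >= audio_margin.
    low = max(n_motion_frames, audio_margin)
    # Same upper bound as in __getitem__, which keeps the last window inside the clip.
    high = video_length - n_sample_frames - audio_margin - 1
    return low, high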
================================================
FILE: hallo/models/__init__.py
================================================
================================================
FILE: hallo/models/attention.py
================================================
# pylint: disable=R0801
# pylint: disable=C0303
"""
This module contains various transformer blocks for different applications, such as BasicTransformerBlock,
TemporalBasicTransformerBlock, and AudioTemporalBasicTransformerBlock. These blocks are used in various models,
such as GLIGEN, UNet, and others. The transformer blocks implement self-attention, cross-attention, feed-forward
networks, and other related functions.
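Functions and classes included in this module are:
- GatedSelfAttentionDense: A gated self-attention layer that fuses visual features with projected object features (GLIGEN-style).
- BasicTransformerBlock: A basic transformer block with self-attention, cross-attention, and feed-forward layers.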
- TemporalBasicTransformerBlock: A transformer block with additional temporal attention mechanisms for video data.
- AudioTemporalBasicTransformerBlock: A transformer block with additional audio-specific mechanisms for audio data.
- zero_module: A function to zero out the parameters of a given module.
For more information on each specific class and function, please refer to the respective docstrings.
"""
from typing import Any, Dict, List, Optional
import torch
from diffusers.models.attention import (AdaLayerNorm, AdaLayerNormZero,
Attention, FeedForward)
from diffusers.models.embeddings import SinusoidalPositionalEmbedding
from einops import rearrange
from torch import nn
class GatedSelfAttentionDense(nn.Module):
"""
A gated self-attention dense layer that combines visual features and object features.
Parameters:
query_dim (`int`): The number of channels in the query.
context_dim (`int`): The number of channels in the context.
n_heads (`int`): The number of heads to use for attention.
d_head (`int`): The number of channels in each head.
"""
def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
super().__init__()
# Project the object features to query_dim so they can be concatenated with the visual features
self.linear = nn.Linear(context_dim, query_dim)
self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
self.ff = FeedForward(query_dim, activation_fn="geglu")
self.norm1 = nn.LayerNorm(query_dim)
self.norm2 = nn.LayerNorm(query_dim)
self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
self.enabled = True
def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
"""
Apply the Gated Self-Attention mechanism to the input tensor `x` and object tensor `objs`.
Args:
x (torch.Tensor): The input tensor.
objs (torch.Tensor): The object tensor.
Returns:
torch.Tensor: The output tensor after applying Gated Self-Attention.
"""
if not self.enabled:
return x
n_visual = x.shape[1]
objs = self.linear(objs)
x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
return x
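# Sketch (hypothetical helper, not called anywhere in hallo): the gated residual
# update used twice in GatedSelfAttentionDense.forward, x + tanh(alpha) * branch.
# alpha_attn and alpha_dense are initialised to 0.0 and tanh(0) = 0, so the layer
# starts out as an identity map, and tanh keeps each gate between -1 and 1 as the
# alphas are learned. Python floats stand in for tensors purely for illustration.
def _sketch_gated_residual(x, branch_output, alpha):
    """Return x + tanh(alpha) * branch_output for scalar inputs."""
    import math  # local import keeps the sketch self-contained
    return x + math.tanh(alpha) * branch_output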
class BasicTransformerBlock(nn.Module):
r"""
A basic Transformer block.
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm (`int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (`bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
only_cross_attention (`bool`, *optional*):
Whether to use only cross-attention layers. In this case two cross attention layers are used.
double_self_attention (`bool`, *optional*):
Whether to use two self-attention layers. In this case no cross attention layers are used.
upcast_attention (`bool`, *optional*):
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"`, `"ada_norm_zero"` or `"ada_norm_single"`.
norm_eps (`float`, *optional*, defaults to 1e-5):
The epsilon used by the `nn.LayerNorm` layers.
final_dropout (`bool`, *optional*, defaults to False):
Whether to apply a final dropout after the last feed-forward layer.
attention_type (`str`, *optional*, defaults to `"default"`):
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
positional_embeddings (`str`, *optional*, defaults to `None`):
The type of positional embeddings to apply. Only `"sinusoidal"` is supported; any other value disables them.
num_positional_embeddings (`int`, *optional*, defaults to `None`):
The maximum number of positional embeddings; required when `positional_embeddings` is set.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = True,
# 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
norm_type: str = "layer_norm",
norm_eps: float = 1e-5,
final_dropout: bool = False,
attention_type: str = "default",
positional_embeddings: Optional[str] = None,
num_positional_embeddings: Optional[int] = None,
):
super().__init__()
self.only_cross_attention = only_cross_attention
self.use_ada_layer_norm_zero = (
num_embeds_ada_norm is not None
) and norm_type == "ada_norm_zero"
self.use_ada_layer_norm = (
num_embeds_ada_norm is not None
) and norm_type == "ada_norm"
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
self.use_layer_norm = norm_type == "layer_norm"
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
raise ValueError(
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
)
if positional_embeddings and (num_positional_embeddings is None):
raise ValueError(
"If `positional_embeddings` type is defined, `num_positional_embeddings` must also be defined."
)
if positional_embeddings == "sinusoidal":
self.pos_embed = SinusoidalPositionalEmbedding(
dim, max_seq_length=num_positional_embeddings
)
else:
self.pos_embed = None
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
if self.use_ada_layer_norm:
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
elif self.use_ada_layer_norm_zero:
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
else:
self.norm1 = nn.LayerNorm(
dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
)
# 2. Cross-Attn
if cross_attention_dim is not None or double_self_attention:
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
# the second cross attention block.
self.norm2 = (
AdaLayerNorm(dim, num_embeds_ada_norm)
if self.use_ada_layer_norm
else nn.LayerNorm(
dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
)
)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=(
cross_attention_dim if not double_self_attention else None
),
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
) # is self-attn if encoder_hidden_states is none
else:
self.norm2 = None
self.attn2 = None
# 3. Feed-forward
if not self.use_ada_layer_norm_single:
self.norm3 = nn.LayerNorm(
dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
)
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
)
# 4. Fuser
if attention_type in {"gated", "gated-text-image"}:
self.fuser = GatedSelfAttentionDense(
dim, cross_attention_dim, num_attention_heads, attention_head_dim
)
# 5. Scale-shift for PixArt-Alpha.
if self.use_ada_layer_norm_single:
self.scale_shift_table = nn.Parameter(
torch.randn(6, dim) / dim**0.5)
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
"""
Sets the chunk size for feed-forward processing in the transformer block.
Args:
chunk_size (Optional[int]): The size of the chunks to process in feed-forward layers.
If None, the feed-forward layer processes the whole input at once (no chunking).
dim (int, optional): The dimension along which to split the input tensor into chunks. Defaults to 0.
Returns:
None.
"""
self._chunk_size = chunk_size
self._chunk_dim = dim
def forward(
self,
hidden_states: torch.FloatTensor,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
timestep: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
class_labels: Optional[torch.LongTensor] = None,
) -> torch.FloatTensor:
"""
This function defines the forward pass of the BasicTransformerBlock.
Args:
self (BasicTransformerBlock):
An instance of the BasicTransformerBlock class.
hidden_states (torch.FloatTensor):
A tensor containing the hidden states.
attention_mask (Optional[torch.FloatTensor], optional):
A tensor containing the attention mask. Defaults to None.
encoder_hidden_states (Optional[torch.FloatTensor], optional):
A tensor containing the encoder hidden states. Defaults to None.
encoder_attention_mask (Optional[torch.FloatTensor], optional):
A tensor containing the encoder attention mask. Defaults to None.
timestep (Optional[torch.LongTensor], optional):
A tensor containing the timesteps. Defaults to None.
cross_attention_kwargs (Dict[str, Any], optional):
Additional cross-attention arguments. Defaults to None.
class_labels (Optional[torch.LongTensor], optional):
A tensor containing the class labels. Defaults to None.
Returns:
torch.FloatTensor:
A tensor containing the transformed hidden states.
"""
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Self-Attention
batch_size = hidden_states.shape[0]
gate_msa = None
scale_mlp = None
shift_mlp = None
gate_mlp = None
if self.use_ada_layer_norm:
norm_hidden_states = self.norm1(hidden_states, timestep)
elif self.use_ada_layer_norm_zero:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
)
elif self.use_layer_norm:
norm_hidden_states = self.norm1(hidden_states)
elif self.use_ada_layer_norm_single:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[None] +
timestep.reshape(batch_size, 6, -1)
).chunk(6, dim=1)
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = norm_hidden_states * \
(1 + scale_msa) + shift_msa
norm_hidden_states = norm_hidden_states.squeeze(1)
else:
raise ValueError("Incorrect norm used")
if self.pos_embed is not None:
norm_hidden_states = self.pos_embed(norm_hidden_states)
# 1. Retrieve lora scale.
lora_scale = (
cross_attention_kwargs.get("scale", 1.0)
if cross_attention_kwargs is not None
else 1.0
)
# 2. Prepare GLIGEN inputs
cross_attention_kwargs = (
cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
)
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=(
encoder_hidden_states if self.only_cross_attention else None
),
attention_mask=attention_mask,
**cross_attention_kwargs,
)
if self.use_ada_layer_norm_zero:
attn_output = gate_msa.unsqueeze(1) * attn_output
elif self.use_ada_layer_norm_single:
attn_output = gate_msa * attn_output
hidden_states = attn_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
# 2.5 GLIGEN Control
if gligen_kwargs is not None:
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
# 3. Cross-Attention
if self.attn2 is not None:
if self.use_ada_layer_norm:
norm_hidden_states = self.norm2(hidden_states, timestep)
elif self.use_ada_layer_norm_zero or self.use_layer_norm:
norm_hidden_states = self.norm2(hidden_states)
elif self.use_ada_layer_norm_single:
# For PixArt norm2 isn't applied here:
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
norm_hidden_states = hidden_states
else:
raise ValueError("Incorrect norm")
if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
norm_hidden_states = self.pos_embed(norm_hidden_states)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
# 4. Feed-forward
if not self.use_ada_layer_norm_single:
norm_hidden_states = self.norm3(hidden_states)
if self.use_ada_layer_norm_zero:
norm_hidden_states = (
norm_hidden_states *
(1 + scale_mlp[:, None]) + shift_mlp[:, None]
)
if self.use_ada_layer_norm_single:
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * \
(1 + scale_mlp) + shift_mlp
ff_output = self.ff(norm_hidden_states, scale=lora_scale)
if self.use_ada_layer_norm_zero:
ff_output = gate_mlp.unsqueeze(1) * ff_output
elif self.use_ada_layer_norm_single:
ff_output = gate_mlp * ff_output
hidden_states = ff_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
return hidden_states
class TemporalBasicTransformerBlock(nn.Module):
"""
A PyTorch module that extends the BasicTransformerBlock to include temporal attention mechanisms.
This class is particularly useful for video-related tasks where capturing temporal information within the sequence of frames is necessary.
Attributes:
dim (int): The dimension of the input and output embeddings.
num_attention_heads (int): The number of attention heads in the multi-head self-attention mechanism.
attention_head_dim (int): The dimension of each attention head.
dropout (float): The dropout probability applied after the attention output projections and inside the feed-forward layer.
cross_attention_dim (Optional[int]): The feature dimension of encoder_hidden_states. If None, no cross-attention layer is created.
activation_fn (str): The activation function used in the feed-forward layer.
num_embeds_ada_norm (Optional[int]): The number of embeddings for adaptive normalization. If set, timestep-conditioned AdaLayerNorm replaces LayerNorm.
attention_bias (bool): If True, uses bias in the attention mechanism.
only_cross_attention (bool): Stored as an attribute; this block's forward pass does not read it.
upcast_attention (bool): If True, computes the attention scores in float32 for numerical stability.
unet_use_cross_frame_attention (Optional[bool]): If True, uses cross-frame attention in the UNet model.
unet_use_temporal_attention (Optional[bool]): If True, uses temporal attention in the UNet model.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
unet_use_cross_frame_attention=None,
unet_use_temporal_attention=None,
):
"""
Initializes the TemporalBasicTransformerBlock, which extends the BasicTransformerBlock design with an optional temporal attention layer.
This is particularly useful for video-related tasks, where the model needs to capture temporal information across the sequence of frames.
The block consists of self-attention, optional cross-attention, feed-forward, and optional temporal attention layers.
Args:
dim (int): The dimension of the input and output embeddings.
num_attention_heads (int): The number of attention heads in the multi-head self-attention mechanism.
attention_head_dim (int): The dimension of each attention head.
dropout (float, optional): The dropout probability applied after the attention output projections and inside the feed-forward layer. Defaults to 0.0.
cross_attention_dim (int, optional): The feature dimension of encoder_hidden_states. If None, no cross-attention layer is created. Defaults to None.
activation_fn (str, optional): The activation function used in the feed-forward layer. Defaults to "geglu".
num_embeds_ada_norm (int, optional): The number of embeddings for adaptive normalization. If set, AdaLayerNorm replaces LayerNorm. Defaults to None.
attention_bias (bool, optional): If True, uses bias in the attention mechanism. Defaults to False.
only_cross_attention (bool, optional): Stored as an attribute; not read by the forward pass. Defaults to False.
upcast_attention (bool, optional): If True, computes the attention scores in float32 for numerical stability. Defaults to False.
unet_use_cross_frame_attention (bool, optional): If True, uses cross-frame attention in the UNet model. Defaults to None.
unet_use_temporal_attention (bool, optional): If True, adds a temporal attention layer (attn_temp) whose output projection weight is zero-initialized. Defaults to None, which is treated as False.
"""
super().__init__()
self.only_cross_attention = only_cross_attention
self.use_ada_layer_norm = num_embeds_ada_norm is not None
self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
self.unet_use_temporal_attention = unet_use_temporal_attention
# SC-Attn
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
self.norm1 = (
AdaLayerNorm(dim, num_embeds_ada_norm)
if self.use_ada_layer_norm
else nn.LayerNorm(dim)
)
# Cross-Attn
if cross_attention_dim is not None:
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
else:
self.attn2 = None
if cross_attention_dim is not None:
self.norm2 = (
AdaLayerNorm(dim, num_embeds_ada_norm)
if self.use_ada_layer_norm
else nn.LayerNorm(dim)
)
else:
self.norm2 = None
# Feed-forward
self.ff = FeedForward(dim, dropout=dropout,
activation_fn=activation_fn)
self.norm3 = nn.LayerNorm(dim)
self.use_ada_layer_norm_zero = False
# Temp-Attn
# assert unet_use_temporal_attention is not None
if unet_use_temporal_attention is None:
unet_use_temporal_attention = False
if unet_use_temporal_attention:
self.attn_temp = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
self.norm_temp = (
AdaLayerNorm(dim, num_embeds_ada_norm)
if self.use_ada_layer_norm
else nn.LayerNorm(dim)
)
def forward(
self,
hidden_states,
encoder_hidden_states=None,
timestep=None,
attention_mask=None,
video_length=None,
):
"""
Forward pass for the TemporalBasicTransformerBlock.
Args:
hidden_states (torch.FloatTensor): The input hidden states with shape (batch_size * video_length, seq_len, dim), i.e. with frames folded into the batch dimension.
encoder_hidden_states (torch.FloatTensor, optional): The encoder hidden states with shape (batch_size * video_length, src_seq_len, cross_attention_dim).
timestep (torch.LongTensor, optional): The timestep for the transformer block.
attention_mask (torch.FloatTensor, optional): The attention mask with shape (batch_size, seq_len, seq_len).
video_length (int, optional): The length of the video sequence (number of frames); required when temporal attention is enabled.
Returns:
torch.FloatTensor: The output tensor after passing through the transformer block, with the same shape as hidden_states.
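The temporal-attention step rearranges `(b f) d c -> (b d) f c`. After this step, every spatial token position becomes its own sequence of `f` frames, so `attn_temp` attends across time rather than across space. The plain-Python sketch below shows that index mapping and its inverse, using nested lists in place of tensors. The helper names `to_temporal` and `to_spatial` are hypothetical.

```python
from typing import List, TypeVar

T = TypeVar("T")


def to_temporal(x: List[List[T]], f: int) -> List[List[T]]:
    # x has b*f rows (batch-major, then frame), each holding d tokens.
    # Returns b*d rows, each holding one token position across the f frames.
    assert len(x) % f == 0, "leading dimension must be divisible by video_length"
    b, d = len(x) // f, len(x[0])
    return [[x[bi * f + fi][di] for fi in range(f)]
            for bi in range(b) for di in range(d)]


def to_spatial(y: List[List[T]], d: int) -> List[List[T]]:
    # Inverse mapping: (b d) f -> (b f) d.
    assert len(y) % d == 0, "leading dimension must be divisible by d"
    b, f = len(y) // d, len(y[0])
    return [[y[bi * d + di][fi] for di in range(d)]
            for bi in range(b) for fi in range(f)]


# b=1 video, f=2 frames, d=3 tokens; each entry is labelled (frame, token).
frames = [[(fi, di) for di in range(3)] for fi in range(2)]
temporal = to_temporal(frames, f=2)
assert temporal[0] == [(0, 0), (1, 0)]  # token 0 seen across both frames
assert to_spatial(temporal, d=3) == frames
```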
"""
norm_hidden_states = (
self.norm1(hidden_states, timestep)
if self.use_ada_layer_norm
else self.norm1(hidden_states)
)
if self.unet_use_cross_frame_attention:
hidden_states = (
self.attn1(
norm_hidden_states,
attention_mask=attention_mask,
video_length=video_length,
)
+ hidden_states
)
else:
hidden_states = (
self.attn1(norm_hidden_states, attention_mask=attention_mask)
+ hidden_states
)
if self.attn2 is not None:
# Cross-Attention
norm_hidden_states = (
self.norm2(hidden_states, timestep)
if self.use_ada_layer_norm
else self.norm2(hidden_states)
)
hidden_states = (
self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
)
+ hidden_states
)
# Feed-forward
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
# Temporal-Attention
if self.unet_use_temporal_attention:
d = hidden_states.shape[1]
hidden_states = rearrange(
hidden_states, "(b f) d c -> (b d) f c", f=video_length
)
norm_hidden_states = (
self.norm_temp(hidden_states, timestep)
if self.use_ada_layer_norm
else self.norm_temp(hidden_states)
)
hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
hidden_states = rearrange(
hidden_states, "(b d) f c -> (b f) d c", d=d)
return hidden_states
class AudioTemporalBasicTransformerBlock(nn.Module):
"""
A PyTorch module designed to handle audio data within a transformer framework, including temporal attention mechanisms.
Attributes:
dim (int): The dimension of the input and output embeddings.
num_attention_heads (int): The number of attention heads.
attention_head_dim (int): The dimension of each attention head.
dropout (float): The dropout probability applied after the attention output projections and inside the feed-forward network.
cross_attention_dim (Optional[int]): The feature dimension of encoder_hidden_states. If None, no cross-attention layer is created.
activation_fn (str): The activation function for the feed-forward network.
num_embeds_ada_norm (Optional[int]): The number of embeddings for adaptive normalization.
attention_bias (bool): If True, uses bias in the attention mechanism.
only_cross_attention (bool): Stored as an attribute; not read by the forward pass.
upcast_attention (bool): If True, computes the attention scores in float32 for numerical stability.
unet_use_cross_frame_attention (Optional[bool]): If True, uses cross-frame attention in UNet.
unet_use_temporal_attention (Optional[bool]): Stored as an attribute; this block has no temporal attention layer.
depth (int): The depth level of this block in the UNet. Together with unet_block_name, it decides whether the split audio cross-attention is built, and forward uses it to index the mask lists.
unet_block_name (Optional[str]): The name of the UNet block that contains this transformer block.
stack_enable_blocks_name (Optional[List[str]]): Names of the UNet blocks in which audio cross-attention is split into full, face, and lip branches (attn2_0, attn2_1, attn2_2).
stack_enable_blocks_depth (Optional[List[int]]): Depths at which that split is enabled; both the block name and the depth must match.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
unet_use_cross_frame_attention=None,
unet_use_temporal_attention=None,
depth=0,
unet_block_name=None,
stack_enable_blocks_name: Optional[List[str]] = None,
stack_enable_blocks_depth: Optional[List[int]] = None,
):
"""
Initializes the AudioTemporalBasicTransformerBlock module.
Args:
dim (int): The dimension of the input and output embeddings.
num_attention_heads (int): The number of attention heads in the multi-head self-attention mechanism.
attention_head_dim (int): The dimension of each attention head.
dropout (float, optional): The dropout probability applied after the attention output projections and inside the feed-forward network. Defaults to 0.0.
cross_attention_dim (Optional[int], optional): The feature dimension of encoder_hidden_states. If None, no cross-attention layer is created. Defaults to None.
activation_fn (str, optional): The activation function to be used in the feed-forward network. Defaults to "geglu".
num_embeds_ada_norm (Optional[int], optional): The number of embeddings for adaptive normalization. Defaults to None.
attention_bias (bool, optional): If True, uses bias in the attention mechanism. Defaults to False.
only_cross_attention (bool, optional): Stored as an attribute; not read by the forward pass. Defaults to False.
upcast_attention (bool, optional): If True, computes the attention scores in float32 for numerical stability. Defaults to False.
unet_use_cross_frame_attention (Optional[bool], optional): If True, uses cross-frame attention in UNet. Defaults to None.
unet_use_temporal_attention (Optional[bool], optional): Stored as an attribute; this block has no temporal attention layer. Defaults to None.
depth (int, optional): The depth level of this block in the UNet. Together with unet_block_name, it decides whether the split audio cross-attention is built, and forward uses it to index the mask lists. Defaults to 0.
unet_block_name (Optional[str], optional): The name of the UNet block that contains this transformer block. Defaults to None.
stack_enable_blocks_name (Optional[List[str]], optional): Names of the UNet blocks in which audio cross-attention is split into full, face, and lip branches (attn2_0, attn2_1, attn2_2). Defaults to None.
stack_enable_blocks_depth (Optional[List[int]], optional): Depths at which that split is enabled; both the block name and the depth must match. Defaults to None.
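The constructor splits the audio cross-attention into three masked branches (`attn2_0`, `attn2_1`, `attn2_2`) only under three conditions: a cross-attention dimension is given, both stack lists are provided, and this block's name and depth both appear in them. Otherwise, it builds a single `attn2`. The sketch below expresses that predicate as a standalone function. The function name and the example block names are hypothetical.

```python
from typing import List, Optional


def uses_split_audio_attention(
    cross_attention_dim: Optional[int],
    unet_block_name: Optional[str],
    depth: int,
    stack_enable_blocks_name: Optional[List[str]],
    stack_enable_blocks_depth: Optional[List[int]],
) -> bool:
    # Without a cross-attention dimension no audio cross-attention is built at all.
    if cross_attention_dim is None:
        return False
    # All four conditions must hold for the full/face/lip split.
    return (
        stack_enable_blocks_name is not None
        and stack_enable_blocks_depth is not None
        and unet_block_name in stack_enable_blocks_name
        and depth in stack_enable_blocks_depth
    )


names = ["up", "down", "mid"]  # hypothetical block names for illustration
assert uses_split_audio_attention(768, "up", 0, names, [0])
assert not uses_split_audio_attention(768, "up", 1, names, [0])
```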
"""
super().__init__()
self.only_cross_attention = only_cross_attention
self.use_ada_layer_norm = num_embeds_ada_norm is not None
self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
self.unet_use_temporal_attention = unet_use_temporal_attention
self.unet_block_name = unet_block_name
self.depth = depth
zero_conv_full = nn.Conv2d(
dim, dim, kernel_size=1)
self.zero_conv_full = zero_module(zero_conv_full)
zero_conv_face = nn.Conv2d(
dim, dim, kernel_size=1)
self.zero_conv_face = zero_module(zero_conv_face)
zero_conv_lip = nn.Conv2d(
dim, dim, kernel_size=1)
self.zero_conv_lip = zero_module(zero_conv_lip)
# SC-Attn
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
self.norm1 = (
AdaLayerNorm(dim, num_embeds_ada_norm)
if self.use_ada_layer_norm
else nn.LayerNorm(dim)
)
# Cross-Attn
if cross_attention_dim is not None:
if (stack_enable_blocks_name is not None and
stack_enable_blocks_depth is not None and
self.unet_block_name in stack_enable_blocks_name and
self.depth in stack_enable_blocks_depth):
self.attn2_0 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
self.attn2_1 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
self.attn2_2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
self.attn2 = None
else:
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
self.attn2_0 = None
else:
self.attn2 = None
self.attn2_0 = None
if cross_attention_dim is not None:
self.norm2 = (
AdaLayerNorm(dim, num_embeds_ada_norm)
if self.use_ada_layer_norm
else nn.LayerNorm(dim)
)
else:
self.norm2 = None
# Feed-forward
self.ff = FeedForward(dim, dropout=dropout,
activation_fn=activation_fn)
self.norm3 = nn.LayerNorm(dim)
self.use_ada_layer_norm_zero = False
def forward(
self,
hidden_states,
encoder_hidden_states=None,
timestep=None,
attention_mask=None,
full_mask=None,
face_mask=None,
lip_mask=None,
motion_scale=None,
video_length=None,
):
"""
Forward pass for the AudioTemporalBasicTransformerBlock.
Args:
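hidden_states (torch.FloatTensor): The input hidden states with shape (batch, seq_len, dim). When the split audio cross-attention is used, seq_len must form a square spatial grid.
encoder_hidden_states (torch.FloatTensor, optional): The audio encoder hidden states. Defaults to None.
timestep (torch.LongTensor, optional): The timestep for the transformer block. Defaults to None.
attention_mask (torch.FloatTensor, optional): The attention mask. Defaults to None.
full_mask (Sequence[torch.FloatTensor], optional): Per-level masks for the full branch; full_mask[depth] has shape (batch, seq_len) or is broadcastable to it. Only used by the split audio cross-attention. Defaults to None.
face_mask (Sequence[torch.FloatTensor], optional): Per-level masks for the face branch, indexed like full_mask. Defaults to None.
lip_mask (Sequence[torch.FloatTensor], optional): Per-level masks for the lip branch, indexed like full_mask. Defaults to None.
motion_scale (Sequence[float], optional): Weights for the full, face, and lip branch outputs, in that order. If None, each branch has weight 1. Defaults to None.
video_length (int, optional): The length of the video. Defaults to None.
Returns:
torch.FloatTensor: The output tensor after passing through the AudioTemporalBasicTransformerBlock, with the same shape as hidden_states.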
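When the split branches are active, the forward pass processes each branch output in four steps:

1. It multiplies the output by that branch's mask at this block's `depth`.
2. It reshapes the result onto a square `sqrt(seq_len) x sqrt(seq_len)` grid.
3. It passes the grid through a zero-initialized 1x1 convolution.
4. It adds the result to the residual stream, optionally weighted by `motion_scale`.

The per-token plain-Python sketch below covers the weighted, masked sum and checks the square-grid assumption. The convolutions are omitted (including their bias), and the helper names are hypothetical.

```python
import math
from typing import Optional, Sequence


def square_side(seq_len: int) -> int:
    # The block reshapes (bz, seq_len, c) to (bz, side, side, c), which only
    # works when seq_len is a perfect square.
    side = math.isqrt(seq_len)
    if side * side != seq_len:
        raise ValueError(f"seq_len={seq_len} is not a perfect square")
    return side


def fuse(
    hidden: float,
    full: float,
    face: float,
    lip: float,
    masks: Sequence[float],
    motion_scale: Optional[Sequence[float]] = None,
) -> float:
    # masks = (full_mask, face_mask, lip_mask) values at this token.
    branches = [full * masks[0], face * masks[1], lip * masks[2]]
    weights = motion_scale if motion_scale is not None else (1.0, 1.0, 1.0)
    return hidden + sum(w * v for w, v in zip(weights, branches))


assert square_side(4096) == 64
# With the convolutions omitted, a token outside the lip mask gets nothing
# from the lip branch.
assert fuse(1.0, 0.5, 0.25, 9.0, masks=(1.0, 1.0, 0.0)) == 1.75
```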
"""
norm_hidden_states = (
self.norm1(hidden_states, timestep)
if self.use_ada_layer_norm
else self.norm1(hidden_states)
)
if self.unet_use_cross_frame_attention:
hidden_states = (
self.attn1(
norm_hidden_states,
attention_mask=attention_mask,
video_length=video_length,
)
+ hidden_states
)
else:
hidden_states = (
self.attn1(norm_hidden_states, attention_mask=attention_mask)
+ hidden_states
)
if self.attn2 is not None:
# Cross-Attention
norm_hidden_states = (
self.norm2(hidden_states, timestep)
if self.use_ada_layer_norm
else self.norm2(hidden_states)
)
hidden_states = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
) + hidden_states
elif self.attn2_0 is not None:
norm_hidden_states = (
self.norm2(hidden_states, timestep)
if self.use_ada_layer_norm
else self.norm2(hidden_states)
)
level = self.depth
full_hidden_states = (
self.attn2_0(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
) * full_mask[level][:, :, None]
)
bz, sz, c = full_hidden_states.shape
# The token sequence is assumed to form a square spatial grid (sz == side * side).
sz_sqrt = int(sz ** 0.5)
full_hidden_states = full_hidden_states.reshape(
bz, sz_sqrt, sz_sqrt, c).permute(0, 3, 1, 2)
full_hidden_states = self.zero_conv_full(full_hidden_states).permute(0, 2, 3, 1).reshape(bz, -1, c)
face_hidden_state = (
self.attn2_1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
) * face_mask[level][:, :, None]
)
face_hidden_state = face_hidden_state.reshape(
bz, sz_sqrt, sz_sqrt, c).permute(0, 3, 1, 2)
face_hidden_state = self.zero_conv_face(
face_hidden_state).permute(0, 2, 3, 1).reshape(bz, -1, c)
lip_hidden_state = (
self.attn2_2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
) * lip_mask[level][:, :, None]
) # [32, 4096, 320]
lip_hidden_state = lip_hidden_state.reshape(
bz, sz_sqrt, sz_sqrt, c).permute(0, 3, 1, 2)
lip_hidden_state = self.zero_conv_lip(
lip_hidden_state).permute(0, 2, 3, 1).reshape(bz, -1, c)
if motion_scale is not None:
hidden_states = (
motion_scale[0] * full_hidden_states +
motion_scale[1] * face_hidden_state +
motion_scale[2] * lip_hidden_state + hidden_states
)
else:
hidden_states = (
full_hidden_states +
face_hidden_state +
lip_hidden_state + hidden_states
)
# Feed-forward
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
return hidden_states
def zero_module(module):
"""
Zeroes out the parameters of a given module.
Args:
module (nn.Module): The module whose parameters need to be zeroed out.
Returns:
nn.Module: The same module, with all of its parameters set to zero in place.
"""
for p in module.parameters():
nn.init.zeros_(p)
return module
================================================
FILE: hallo/models/audio_proj.py
================================================
"""
This module provides the implementation of an Audio Projection Model, which is designed for
audio processing tasks. The model takes audio embeddings as input and outputs context tokens
that can be used for various downstream applications, such as audio analysis or synthesis.
The AudioProjModel class is based on the ModelMixin class from the diffusers library, which
provides a foundation for building custom models. This implementation includes multiple linear
layers with ReLU activation functions and a LayerNorm for normalization.
Key Features:
- Audio embedding input with flexible sequence length and block structure.
- Multiple linear layers for feature transformation.
- ReLU activation for non-linear transformation.
- LayerNorm for stabilizing and speeding up training.
- Rearrangement of input embeddings to match the model's expected input shape.
- Customizable number of blocks, channels, and context tokens for adaptability.
The module is structured to be easily integrated into larger systems or used as a standalone
component for audio feature extraction and processing.
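The sketch below walks through the tensor shapes in AudioProjModel.forward using the default constructor arguments. It uses plain Python arithmetic rather than tensors. The batch size and frame count are arbitrary example values, and `audio_proj_shapes` is a hypothetical helper.

```python
from typing import Dict, Tuple


def audio_proj_shapes(
    batch_size: int,
    video_length: int,
    seq_len: int = 5,
    blocks: int = 12,
    channels: int = 768,
    intermediate_dim: int = 512,
    output_dim: int = 768,
    context_tokens: int = 32,
) -> Dict[str, Tuple[int, ...]]:
    # Input: (batch_size, video_length, seq_len, blocks, channels).
    rows = batch_size * video_length         # "bz f w b c -> (bz f) w b c"
    input_dim = seq_len * blocks * channels  # one flattened audio window
    return {
        "flattened": (rows, input_dim),
        "proj1": (rows, intermediate_dim),
        "proj2": (rows, intermediate_dim),
        "proj3": (rows, context_tokens * output_dim),
        "output": (batch_size, video_length, context_tokens, output_dim),
    }


shapes = audio_proj_shapes(batch_size=2, video_length=16)
assert shapes["flattened"] == (32, 46080)
assert shapes["output"] == (2, 16, 32, 768)
```

The flattened input dimension grows with all three of seq_len, blocks, and channels. As a result, proj1 carries most of the model's parameters.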
Classes:
- AudioProjModel: A class representing the audio projection model with configurable parameters.
Functions:
- (none)
Dependencies:
- torch: For tensor operations and neural network components.
- diffusers: For the ModelMixin base class.
- einops: For tensor rearrangement operations.
"""
import torch
from diffusers import ModelMixin
from einops import rearrange
from torch import nn
class AudioProjModel(ModelMixin):
"""Audio Projection Model
This class defines an audio projection model that takes audio embeddings as input
and produces context tokens as output. The model is based on the ModelMixin class
and consists of multiple linear layers and activation functions. It can be used
for various audio processing tasks.
Attributes:
seq_len (int): The length of the audio sequence.
blocks (int): The number of blocks in the audio projection model.
channels (int): The number of channels in the audio projection model.
intermediate_dim (int): The intermediate dimension of the model.
context_tokens (int): The number of context tokens in the output.
output_dim (int): The output dimension of the context tokens.
Methods:
__init__(self, seq_len=5, blocks=12, channels=768, intermediate_dim=512, output_dim=768, context_tokens=32):
Initializes the AudioProjModel with the given parameters.
forward(self, audio_embeds):
Defines the forward pass for the AudioProjModel.
Parameters:
audio_embeds (torch.Tensor): The input audio embeddings with shape (batch_size, video_length, seq_len, blocks, channels).
Returns:
context_tokens (torch.Tensor): The output context tokens with shape (batch_size, video_length, context_tokens, output_dim).
"""
def __init__(
self,
seq_len=5,
blocks=12, # number of stacked audio feature blocks per window position
channels=768, # number of feature channels per block
intermediate_dim=512,
output_dim=768,
context_tokens=32,
):
super().__init__()
self.seq_len = seq_len
self.blocks = blocks
self.channels = channels
self.input_dim = (
seq_len * blocks * channels
) # flattened size of one audio window: seq_len * blocks * channels
self.intermediate_dim = intermediate_dim
self.context_tokens = context_tokens
self.output_dim = output_dim
# define multiple linear layers
self.proj1 = nn.Linear(self.input_dim, intermediate_dim)
self.proj2 = nn.Linear(intermediate_dim, intermediate_dim)
self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim)
self.norm = nn.LayerNorm(output_dim)
def forward(self, audio_embeds):
"""
Defines the forward pass for the AudioProjModel.
Parameters:
audio_embeds (torch.Tensor): The input audio embeddings with shape (batch_size, video_length, seq_len, blocks, channels).
Returns:
context_tokens (torch.Tensor): The output context tokens with shape (batch_size, video_length, context_tokens, output_dim).
"""
# merge the batch and frame dimensions
video_length = audio_embeds.shape[1]
audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
batch_size, window_size, blocks, channels = audio_embeds.shape
audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
audio_embeds = torch.relu(self.proj1(audio_embeds))
audio_embeds = torch.relu(self.proj2(audio_embeds))
context_tokens = self.proj3(audio_embeds).reshape(
batch_size, self.context_tokens, self.output_dim
)
context_tokens = self.norm(context_tokens)
context_tokens = rearrange(
context_tokens, "(bz f) m c -> bz f m c", f=video_length
)
return context_tokens
================================================
FILE: hallo/models/face_locator.py
================================================
"""
This module implements the FaceLocator class, which is a neural network model designed to
locate and extract facial features from input images or tensors. It uses a series of
convolutional layers to progressively downsample and refine the facial feature map.
The FaceLocator class is part of a larger system that may involve facial recognition or
similar tasks where precise location and extraction of facial features are required.
Attributes:
conditioning_embedding_channels (int): The number of channels in the output embedding.
conditioning_channels (int): The number of input channels for the conditioning tensor.
block_out_channels (Tuple[int]): A tuple of integers representing the output channels
for each block in the model.
The model uses the following components:
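- InflatedConv3d: A 2D convolution "inflated" for video input; it is applied to each frame of a
(batch, channels, frames, height, width) tensor independently.
- zero_module: A utility function that zero-initializes all parameters of the final convolution,
so the locator outputs all zeros at initialization.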
The forward method of the FaceLocator class takes a conditioning tensor as input and
produces an embedding tensor as output, which can be used for further processing or analysis.
"""
from typing import Tuple
import torch.nn.functional as F
from diffusers.models.modeling_utils import ModelMixin
from torch import nn
from .motion_module import zero_module
from .resnet import InflatedConv3d
class FaceLocator(ModelMixin):
"""
The FaceLocator class is a neural network model designed to process and extract facial
features from an input tensor. It consists of a series of convolutional layers that
progressively downsample the input while increasing the depth of the feature map.
The model is built using InflatedConv3d layers, which apply 2D convolutions frame by frame so
the network can process video tensors. The final convolution is zero-initialized. The final output
is a conditioning embedding that can be used for various tasks such as facial recognition or
feature-based image manipulation.
Parameters:
conditioning_embedding_channels (int): The number of channels in the output embedding.
conditioning_channels (int, optional): The number of input channels for the conditioning tensor. Default is 3.
block_out_channels (Tuple[int], optional): A tuple of integers representing the output channels
for each block in the model. The default is (16, 32, 64, 128), which defines the
progression of the network's depth.
Attributes:
conv_in (InflatedConv3d): The initial convolutional layer that starts the feature extraction process.
blocks (ModuleList[InflatedConv3d]): A list of convolutional layers that form the core of the model.
conv_out (InflatedConv3d): The final convolutional layer that produces the output embedding.
The forward method applies the convolutional layers to the input conditioning tensor and
returns the resulting embedding tensor.
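Each block pairs a stride-1 3x3 convolution with a stride-2 3x3 convolution, both with padding 1. With the default four `block_out_channels`, the spatial size is therefore halved three times, an overall factor of 8. The standard convolution output-size formula makes this concrete. The sketch below assumes, as in AnimateDiff, that `InflatedConv3d` applies a 2D convolution to each frame, so the frame count does not change. The 512x512 input is only an example, and the helper names are hypothetical.

```python
from typing import Tuple


def conv_out_size(size: int, kernel: int = 3, stride: int = 1, padding: int = 1) -> int:
    # Standard convolution arithmetic: floor((size + 2p - k) / s) + 1.
    return (size + 2 * padding - kernel) // stride + 1


def face_locator_spatial_size(
    size: int, block_out_channels: Tuple[int, ...] = (16, 32, 64, 128)
) -> int:
    size = conv_out_size(size)  # conv_in keeps the size
    for _ in range(len(block_out_channels) - 1):
        size = conv_out_size(size)            # stride-1 conv keeps the size
        size = conv_out_size(size, stride=2)  # stride-2 conv halves it
    return conv_out_size(size)  # conv_out keeps the size


assert face_locator_spatial_size(512) == 64
```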
"""
def __init__(
self,
conditioning_embedding_channels: int,
conditioning_channels: int = 3,
block_out_channels: Tuple[int] = (16, 32, 64, 128),
):
super().__init__()
self.conv_in = InflatedConv3d(
conditioning_channels, block_out_channels[0], kernel_size=3, padding=1
)
self.blocks = nn.ModuleList([])
for i in range(len(block_out_channels) - 1):
channel_in = block_out_channels[i]
channel_out = block_out_channels[i + 1]
self.blocks.append(
InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1)
)
self.blocks.append(
InflatedConv3d(
channel_in, channel_out, kernel_size=3, padding=1, stride=2
)
)
self.conv_out = zero_module(
InflatedConv3d(
block_out_channels[-1],
conditioning_embedding_channels,
kernel_size=3,
padding=1,
)
)
def forward(self, conditioning):
"""
Forward pass of the FaceLocator model.
Args:
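conditioning (Tensor): The input conditioning tensor with shape (batch, conditioning_channels, frames, height, width).
Returns:
Tensor: The output embedding tensor with conditioning_embedding_channels channels and the spatial size reduced by a factor of 2 ** (len(block_out_channels) - 1).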
"""
embedding = self.conv_in(conditioning)
embedding = F.silu(embedding)
for block in self.blocks:
embedding = block(embedding)
embedding = F.silu(embedding)
embedding = self.conv_out(embedding)
return embedding
================================================
FILE: hallo/models/image_proj.py
================================================
"""
image_proj.py
This module defines the ImageProjModel class, which is responsible for
projecting image embeddings into a different dimensional space. The model
leverages a linear transformation followed by a layer normalization to
reshape and normalize the input image embeddings for further processing in
cross-attention mechanisms or other downstream tasks.
Classes:
ImageProjModel
Dependencies:
torch
diffusers.ModelMixin
"""
import torch
from diffusers import ModelMixin
class ImageProjModel(ModelMixin):
"""
ImageProjModel is a class that projects image embeddings into a different
dimensional space. It inherits from ModelMixin, providing additional functionalities
specific to image projection.
Attributes:
cross_attention_dim (int): The dimension of the cross attention.
clip_embeddings_dim (int): The dimension of the CLIP embeddings.
clip_extra_context_tokens (int): The number of context tokens produced from each CLIP image embedding.
Methods:
forward(image_embeds): Forward pass of the ImageProjModel, which takes in image
embeddings and returns the projected tokens.
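`proj` maps each CLIP embedding to `clip_extra_context_tokens * cross_attention_dim` values. `reshape(-1, clip_extra_context_tokens, cross_attention_dim)` then splits those values into separate tokens, and `norm` normalizes each token. The plain-Python sketch below mirrors that reshape for a single projected embedding, with small sizes chosen for readability. `split_into_tokens` is a hypothetical helper.

```python
from typing import List


def split_into_tokens(flat: List[float], num_tokens: int, dim: int) -> List[List[float]]:
    # Mirrors reshape(num_tokens, dim) for one projected embedding.
    if len(flat) != num_tokens * dim:
        raise ValueError("expected num_tokens * dim values")
    return [flat[i * dim:(i + 1) * dim] for i in range(num_tokens)]


projected = [float(i) for i in range(4 * 3)]  # 4 context tokens, dim 3
tokens = split_into_tokens(projected, num_tokens=4, dim=3)
assert len(tokens) == 4 and tokens[1] == [3.0, 4.0, 5.0]
```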
"""
def __init__(
self,
cross_attention_dim=1024,
clip_embeddings_dim=1024,
clip_extra_context_tokens=4,
):
super().__init__()
self.generator = None
self.cross_attention_dim = cross_attention_dim
self.clip_extra_context_tokens = clip_extra_context_tokens
self.proj = torch.nn.Linear(
clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim
)
self.norm = torch.nn.LayerNorm(cross_attention_dim)
def forward(self, image_embeds):
"""
Forward pass of the ImageProjModel, which takes in image embeddings and returns the
projected tokens after reshaping and normalization.
Args:
image_embeds (torch.Tensor): The input image embeddings, with shape
(..., clip_embeddings_dim), typically batch_size x clip_embeddings_dim; all leading dimensions are flattened into the first output dimension.
Returns:
clip_extra_context_tokens (torch.Tensor): The projected tokens after reshaping
and normalization, with shape batch_size x clip_extra_context_tokens x
cross_attention_dim.
"""
embeds = image_embeds
clip_extra_context_tokens = self.proj(embeds).reshape(
-1, self.clip_extra_context_tokens, self.cross_attention_dim
)
clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
return clip_extra_context_tokens
================================================
FILE: hallo/models/motion_module.py
================================================
# pylint: disable=R0801
# pylint: disable=W0613
# pylint: disable=W0221
"""
motion_module.py
This module provides classes and functions for implementing Temporal Transformers
in PyTorch, designed for handling video data and temporal sequences within transformer-based models.
Functions:
zero_module(module)
Zero out the parameters of a module and return it.
Classes:
TemporalTransformer3DModelOutput(BaseOutput)
Dataclass for storing the output of TemporalTransformer3DModel.
VanillaTemporalModule(nn.Module)
A Vanilla Temporal Module class for handling temporal data.
TemporalTransformer3DModel(nn.Module)
A Temporal Transformer 3D Model class for transforming temporal data.
TemporalTransformerBlock(nn.Module)
A Temporal Transformer Block class for building the transformer architecture.
PositionalEncoding(nn.Module)
A Positional Encoding module for transformers to encode positional information.
Dependencies:
math
dataclasses.dataclass
typing (Callable, Optional)
torch
diffusers (FeedForward, Attention, AttnProcessor)
diffusers.utils (BaseOutput)
diffusers.utils.import_utils (is_xformers_available)
einops (rearrange, repeat)
torch.nn
xformers
xformers.ops
Example Usage:
>>> motion_module = get_motion_module(in_channels=512, motion_module_type="Vanilla", motion_module_kwargs={})
>>> output = motion_module(input_tensor, temb, encoder_hidden_states)
This module is designed to facilitate the creation, training, and inference of transformer models
that operate on temporal data, such as videos or time-series. It includes mechanisms for applying temporal attention,
managing positional encoding, and integrating with external libraries for efficient attention operations.
"""
# This code is copied from https://github.com/guoyww/AnimateDiff.
import math
from dataclasses import dataclass
import torch
import xformers
import xformers.ops
from diffusers.models.attention import FeedForward
from diffusers.models.attention_processor import Attention, AttnProcessor
from diffusers.utils import BaseOutput
from diffusers.utils.import_utils import is_xformers_available
from einops import rearrange, repeat
from torch import nn
def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    Args:
    - module: The PyTorch module whose parameters will be zeroed in place.
    Returns:
    The same module, with all of its parameters set to zero.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module
@dataclass
class TemporalTransformer3DModelOutput(BaseOutput):
    """
    Output class for the TemporalTransformer3DModel.
    Attributes:
        sample (torch.FloatTensor): The output sample tensor from the model.
    """
    sample: torch.FloatTensor
    def get_sample_shape(self):
        """
        Returns the shape of the sample tensor.
        Returns:
            Tuple: The shape of the sample tensor.
        """
        return self.sample.shape
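diffusers' `BaseOutput` is an `OrderedDict` subclass whose `__post_init__` copies dataclass fields into dictionary keys, so output classes such as `TemporalTransformer3DModelOutput` are meant to be declared with `@dataclass`. The stdlib-only sketch below uses a simplified, hypothetical stand-in (`MiniBaseOutput`, not the real diffusers class) to show what the decorator changes: with it, a field is reachable both as an attribute and as a key; without it, keyword arguments only become dictionary keys.

```python
from collections import OrderedDict
from dataclasses import dataclass, fields


class MiniBaseOutput(OrderedDict):
    """Simplified, hypothetical stand-in for diffusers' BaseOutput."""

    def __post_init__(self):
        # Called by the dataclass-generated __init__: mirror each non-None field as a key.
        for field in fields(self):
            value = getattr(self, field.name)
            if value is not None:
                self[field.name] = value


@dataclass
class WithDecorator(MiniBaseOutput):
    sample: list


class WithoutDecorator(MiniBaseOutput):
    sample: list  # only an annotation; no dataclass __init__ is generated


with_dc = WithDecorator(sample=[1, 2, 3])
without_dc = WithoutDecorator(sample=[1, 2, 3])  # behaves like OrderedDict(sample=...)
print(with_dc.sample, with_dc["sample"])               # [1, 2, 3] [1, 2, 3]
print(hasattr(without_dc, "sample"), without_dc["sample"])  # False [1, 2, 3]
```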
def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict):
    """
    This function returns a motion module based on the given type and parameters.
    Args:
    - in_channels (int): The number of input channels for the motion module.
    - motion_module_type (str): The type of motion module to create. Currently, only "Vanilla" is supported.
    - motion_module_kwargs (dict): Additional keyword arguments to pass to the motion module constructor.
    Returns:
        VanillaTemporalModule: The created motion module.
    Raises:
        ValueError: If an unsupported motion_module_type is provided.
    """
    if motion_module_type == "Vanilla":
        return VanillaTemporalModule(
            in_channels=in_channels,
            **motion_module_kwargs,
        )
    raise ValueError(f"Unsupported motion_module_type: {motion_module_type}")
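`get_motion_module` dispatches on a string with an `if` statement. A minimal alternative sketch, assuming you want to add motion-module types without editing the factory body, keeps a name-to-class registry. `MOTION_MODULE_REGISTRY`, `build_motion_module`, and `PlaceholderVanillaModule` are hypothetical names; the placeholder class only records its arguments instead of building a real `VanillaTemporalModule`.

```python
class PlaceholderVanillaModule:
    """Hypothetical stand-in that just records the constructor arguments."""

    def __init__(self, in_channels, **kwargs):
        self.in_channels = in_channels
        self.kwargs = kwargs


# Maps motion_module_type strings to constructors.
MOTION_MODULE_REGISTRY = {"Vanilla": PlaceholderVanillaModule}


def build_motion_module(in_channels, motion_module_type, motion_module_kwargs):
    """Look up the constructor by name and forward the keyword arguments."""
    try:
        module_cls = MOTION_MODULE_REGISTRY[motion_module_type]
    except KeyError:
        supported = ", ".join(sorted(MOTION_MODULE_REGISTRY))
        raise ValueError(
            f"Unsupported motion_module_type {motion_module_type!r}; "
            f"supported types: {supported}"
        ) from None
    return module_cls(in_channels=in_channels, **motion_module_kwargs)


module = build_motion_module(320, "Vanilla", {"num_attention_heads": 8})
print(module.in_channels, module.kwargs)  # 320 {'num_attention_heads': 8}
```

The registry trades the explicit comparison for a lookup table, and the error message lists the names that are actually accepted.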
class VanillaTemporalModule(nn.Module):
    """
    A Vanilla Temporal Module class.
    Args:
    - in_channels (int): The number of input channels for the motion module.
    - num_attention_heads (int): Number of attention heads.
    - num_transformer_block (int): Number of transformer blocks.
    - attention_block_types (tuple): Types of attention blocks.
    - cross_frame_attention_mode: Mode for cross-frame attention.
    - temporal_position_encoding (bool): Flag for temporal position encoding.
    - temporal_position_encoding_max_len (int): Maximum length for temporal position encoding.
    - temporal_attention_dim_div (int): Divisor applied to the per-head attention dimension
      (in_channels // num_attention_heads).
    - zero_initialize (bool): If True, zero-initialize the output projection (proj_out) of the
      temporal transformer, so the module initially returns its input unchanged.
    """
    def __init__(
        self,
        in_channels,
        num_attention_heads=8,
        num_transformer_block=2,
        attention_block_types=("Temporal_Self", "Temporal_Self"),
        cross_frame_attention_mode=None,
        temporal_position_encoding=False,
        temporal_position_encoding_max_len=24,
        temporal_attention_dim_div=1,
        zero_initialize=True,
    ):
        super().__init__()
        self.temporal_transformer = TemporalTransformer3DModel(
            in_channels=in_channels,
            num_attention_heads=num_attention_heads,
            attention_head_dim=in_channels
            // num_attention_heads
            // temporal_attention_dim_div,
            num_layers=num_transformer_block,
            attention_block_types=attention_block_types,
            cross_frame_attention_mode=cross_frame_attention_mode,
            temporal_position_encoding=temporal_position_encoding,
            temporal_position_encoding_max_len=temporal_position_encoding_max_len,
        )
        if zero_initialize:
            self.temporal_transformer.proj_out = zero_module(
                self.temporal_transformer.proj_out
            )
    def forward(
        self,
        input_tensor,
        encoder_hidden_states,
        attention_mask=None,
    ):
        """
        Forward pass of the VanillaTemporalModule.
        Args:
            input_tensor (torch.Tensor): The input video features, with shape
                (batch_size, channels, num_frames, height, width).
            encoder_hidden_states (torch.Tensor or None): The hidden states of the encoder.
            attention_mask (torch.Tensor, optional): Unused; accepted for interface compatibility.
        Returns:
            torch.Tensor: The output tensor, with the same shape as input_tensor.
        """
        hidden_states = input_tensor
        hidden_states = self.temporal_transformer(
            hidden_states, encoder_hidden_states
        )
        output = hidden_states
        return output
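With `zero_initialize=True`, `VanillaTemporalModule` zeroes `proj_out`, and `TemporalTransformer3DModel.forward` adds the transformed features to a residual, so a freshly built module starts out as an identity mapping on its input. The dependency-free sketch below illustrates that property with a hypothetical list-based linear layer (`ToyLinear`) and a copy of the `zero_module` loop (`zero_toy_module`); it illustrates the idea and is not Hallo code.

```python
class ToyLinear:
    """Hypothetical list-based stand-in for nn.Linear: y = W x + b."""

    def __init__(self, weight, bias):
        self.weight = weight  # list of rows, shape (out_features, in_features)
        self.bias = bias      # list, shape (out_features,)

    def parameters(self):
        return [self.weight, self.bias]

    def __call__(self, x):
        return [sum(w * xi for w, xi in zip(row, x)) + b
                for row, b in zip(self.weight, self.bias)]


def zero_toy_module(module):
    """Same idea as zero_module: overwrite every parameter in place, return the module."""
    for param in module.parameters():
        for i, value in enumerate(param):
            param[i] = [0.0] * len(value) if isinstance(value, list) else 0.0
    return module


proj_out = zero_toy_module(ToyLinear([[1.0, 2.0], [3.0, 4.0]], [0.5, -0.5]))
x = [1.0, -2.0]
# A zeroed layer outputs zeros for any input, so branch(x) + x returns x unchanged.
output = [branch + res for branch, res in zip(proj_out(x), x)]
print(output)  # [1.0, -2.0]
```

During training the zeroed weights receive gradients and move away from zero, so the temporal branch is phased in gradually instead of perturbing the pretrained features from the first step.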
class TemporalTransformer3DModel(nn.Module):
    """
    A Temporal Transformer 3D Model class.
    Args:
    - in_channels (int): The number of input channels.
    - num_attention_heads (int): Number of attention heads.
    - attention_head_dim (int): Dimension of attention heads.
    - num_layers (int): Number of transformer layers.
    - attention_block_types (tuple): Types of attention blocks.
    - dropout (float): Dropout rate.
    - norm_num_groups (int): Number of groups for normalization.
    - cross_attention_dim (int): Dimension for cross-attention.
    - activation_fn (str): Activation function.
    - attention_bias (bool): Flag for attention bias.
    - upcast_attention (bool): Flag for upcast attention.
    - cross_frame_attention_mode: Mode for cross-frame attention.
    - temporal_position_encoding (bool): Flag for temporal position encoding.
    - temporal_position_encoding_max_len (int): Maximum length for temporal position encoding.
    """
    def __init__(
        self,
        in_channels,
        num_attention_heads,
        attention_head_dim,
        num_layers,
        attention_block_types=(
            "Temporal_Self",
            "Temporal_Self",
        ),
        dropout=0.0,
        norm_num_groups=32,
        cross_attention_dim=768,
        activation_fn="geglu",
        attention_bias=False,
        upcast_attention=False,
        cross_frame_attention_mode=None,
        temporal_position_encoding=False,
        temporal_position_encoding_max_len=24,
    ):
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim
        self.norm = torch.nn.GroupNorm(
            num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
        )
        self.proj_in = nn.Linear(in_channels, inner_dim)
        self.transformer_blocks = nn.ModuleList(
            [
                TemporalTransformerBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    attention_block_types=attention_block_types,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    upcast_attention=upcast_attention,
                    cross_frame_attention_mode=cross_frame_attention_mode,
                    temporal_position_encoding=temporal_position_encoding,
                    temporal_position_encoding_max_len=temporal_position_encoding_max_len,
                )
                for _ in range(num_layers)
            ]
        )
        self.proj_out = nn.Linear(inner_dim, in_channels)
    def forward(self, hidden_states, encoder_hidden_states=None):
        """
        Forward pass for the TemporalTransformer3DModel.
        Args:
            hidden_states (torch.Tensor): The input hidden states with shape
                (batch_size, in_channels, num_frames, height, width).
            encoder_hidden_states (torch.Tensor, optional): The encoder hidden states with shape
                (batch_size, encoder_sequence_length, cross_attention_dim).
        Returns:
            torch.Tensor: The output hidden states, with the same shape as the input:
                (batch_size, in_channels, num_frames, height, width).
        """
        assert (
            hidden_states.dim() == 5
        ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
        video_length = hidden_states.shape[2]
        # Fold frames into the batch axis: (b, c, f, h, w) -> (b*f, c, h, w)
        hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
        batch, channels, height, width = hidden_states.shape
        residual = hidden_states
        hidden_states = self.norm(hidden_states)
        # One token per spatial position: (b*f, c, h, w) -> (b*f, h*w, c)
        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
            batch, height * width, channels
        )
        hidden_states = self.proj_in(hidden_states)
        # Transformer blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                video_length=video_length,
            )
        # Project back to in_channels and restore the (b*f, c, h, w) layout
        hidden_states = self.proj_out(hidden_states)
        hidden_states = (
            hidden_states.reshape(batch, height, width, channels)
            .permute(0, 3, 1, 2)
            .contiguous()
        )
        output = hidden_states + residual
        output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
        return output
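The channel and shape bookkeeping in `VanillaTemporalModule` and `TemporalTransformer3DModel` can be traced without tensors. The plain-Python sketch below uses hypothetical helper names (`temporal_dims`, `trace_forward_shapes`, `folded_index`): it computes `attention_head_dim` and `inner_dim` the way the constructors do, follows the shapes through `forward`, and shows where the einops composition `(b f)` places each (batch, frame) pair. Because `b` is listed first it is the outer factor, so the folded index is `b * num_frames + f`.

```python
def temporal_dims(in_channels, num_attention_heads=8, temporal_attention_dim_div=1):
    """attention_head_dim and inner_dim as computed in the constructors."""
    attention_head_dim = in_channels // num_attention_heads // temporal_attention_dim_div
    inner_dim = num_attention_heads * attention_head_dim
    return attention_head_dim, inner_dim


def trace_forward_shapes(b, c, f, h, w, inner_dim):
    """Shapes after each step of TemporalTransformer3DModel.forward."""
    return [
        ("input", (b, c, f, h, w)),
        ("rearrange b c f h w -> (b f) c h w", (b * f, c, h, w)),
        ("permute + reshape to tokens", (b * f, h * w, c)),
        ("proj_in", (b * f, h * w, inner_dim)),
        ("transformer blocks", (b * f, h * w, inner_dim)),
        ("proj_out", (b * f, h * w, c)),
        ("reshape + permute, add residual", (b * f, c, h, w)),
        ("rearrange (b f) c h w -> b c f h w", (b, c, f, h, w)),
    ]


def folded_index(batch_idx, frame_idx, num_frames):
    """Row of the (b f) axis that holds frame frame_idx of sample batch_idx."""
    return batch_idx * num_frames + frame_idx


head_dim, inner = temporal_dims(320)
for step, shape in trace_forward_shapes(2, 320, 16, 32, 32, inner):
    print(f"{step:40s} {shape}")
```

When `in_channels` is not divisible by `num_attention_heads * temporal_attention_dim_div`, `inner_dim` is smaller than `in_channels` (for example, 100 channels with 8 heads gives `inner_dim = 96`); `proj_in` and `proj_out` map between the two widths, so the residual addition still sees matching shapes.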
class TemporalTransformerBlock(nn.Module):
    """
    A Temporal Transformer Block class.
    Args:
    - dim (int): Dimension of the block.
    - num_attention_heads (int): Number of attention heads.
    - attention_head_dim (int): Dimension of attention heads.
    - attention_block_types (tuple): Types of attention blocks.
    - dropout (float): Dropout rate.
    - cross_attention_dim (int): Dimension for cross-attention.
    - activation_fn (str): Activation function.
    - attention_bias (bool): Flag for attention bias.
    - upcast_attention (bool): Flag for upcast attention.
    - cross_frame_attention_mode: Mode for cross-frame attention.
    - temporal_position_encoding (bool): Flag for temporal position encoding.
    - temporal_position_encoding_max_len (int): Maximum length for temporal position encoding.
    """
    def __init__(
        self,
        dim,
        num_attention_heads,
        attention_head_dim,
        attention_block_types=(
            "Temporal_Self",
            "Temporal_Self",
        ),
        dropout=0.0,
        cross_attention_dim=768,
        activation_fn="geglu",
        attention_bias=False,
        upcast_attention=False,
        cross_frame_attention_mode=None,
        temporal_position_encoding=False,
        temporal_position_encoding_max_len=24,
    ):
        super().__init__()
        attention_blocks = []
        norms = []
        for block_name in attention_block_types:
            attention_blocks.append(
                VersatileAttention(
                    attention_mode=block_name.split("_", maxsplit=1)[0],
                    cross_attention_dim=cross_attention_dim
                    if block_name.endswith("_Cross")
                    else None,
                    query_dim=dim,
                    heads=num_attention_heads,
                    dim_head=attention_head_dim,
                    dropout=dropout,
                    bias=attention_bias,
                    upcast_attention=upcast_attention,
                    cross_frame_attention_mode=cross_frame_attention_mode,
                    temporal_position_encoding=temporal_position_encoding,
                    temporal_position_encoding_max_len=temporal_position_encoding_max_len,
                )
            )
            norms.append(nn.LayerNorm(dim))
        self.attention_blocks = nn.ModuleList(attention_blocks)
        self.no
SYMBOL INDEX (259 symbols across 29 files)
FILE: hallo/animate/face_animate.py
class FaceAnimatePipelineOutput (line 46) | class FaceAnimatePipelineOutput(BaseOutput):
class FaceAnimatePipeline (line 58) | class FaceAnimatePipeline(DiffusionPipeline):
method __init__ (line 90) | def __init__(
method _execution_device (line 124) | def _execution_device(self):
method prepare_latents (line 136) | def prepare_latents(
method prepare_extra_step_kwargs (line 190) | def prepare_extra_step_kwargs(self, generator, eta):
method decode_latents (line 222) | def decode_latents(self, latents):
method __call__ (line 250) | def __call__(
FILE: hallo/animate/face_animate_static.py
class StaticPipelineOutput (line 65) | class StaticPipelineOutput(BaseOutput):
class StaticPipeline (line 76) | class StaticPipeline(DiffusionPipeline):
method __init__ (line 86) | def __init__(
method enable_vae_slicing (line 124) | def enable_vae_slicing(self):
method disable_vae_slicing (line 132) | def disable_vae_slicing(self):
method enable_sequential_cpu_offload (line 143) | def enable_sequential_cpu_offload(self, gpu_id=0):
method _execution_device (line 157) | def _execution_device(self):
method decode_latents (line 169) | def decode_latents(self, latents):
method prepare_extra_step_kwargs (line 194) | def prepare_extra_step_kwargs(self, generator, eta):
method prepare_latents (line 226) | def prepare_latents(
method prepare_condition (line 278) | def prepare_condition(
method __call__ (line 313) | def __call__(
FILE: hallo/datasets/audio_processor.py
class AudioProcessor (line 22) | class AudioProcessor:
method __init__ (line 37) | def __init__(
method preprocess (line 76) | def preprocess(self, wav_file: str, clip_length: int=-1):
method get_embedding (line 131) | def get_embedding(self, wav_file: str):
method close (line 167) | def close(self):
method __enter__ (line 173) | def __enter__(self):
method __exit__ (line 176) | def __exit__(self, _exc_type, _exc_val, _exc_tb):
FILE: hallo/datasets/image_processor.py
class ImageProcessor (line 25) | class ImageProcessor:
method __init__ (line 53) | def __init__(self, img_size, face_analysis_model_path) -> None:
method preprocess (line 107) | def preprocess(self, source_image_path: str, cache_dir: str, face_regi...
method close (line 184) | def close(self):
method _augmentation (line 198) | def _augmentation(self, images, transform, state=None):
method __enter__ (line 208) | def __enter__(self):
method __exit__ (line 211) | def __exit__(self, _exc_type, _exc_val, _exc_tb):
class ImageProcessorForDataProcessing (line 215) | class ImageProcessorForDataProcessing():
method __init__ (line 243) | def __init__(self, face_analysis_model_path, landmark_model_path, step...
method preprocess (line 265) | def preprocess(self, source_image_path: str):
method close (line 318) | def close(self):
method _augmentation (line 332) | def _augmentation(self, images, transform, state=None):
method __enter__ (line 342) | def __enter__(self):
method __exit__ (line 345) | def __exit__(self, _exc_type, _exc_val, _exc_tb):
FILE: hallo/datasets/mask_image.py
class FaceMaskDataset (line 21) | class FaceMaskDataset(Dataset):
method __init__ (line 40) | def __init__(
method augmentation (line 78) | def augmentation(self, image, transform, state=None):
method __getitem__ (line 94) | def __getitem__(self, index):
method __len__ (line 143) | def __len__(self):
FILE: hallo/datasets/talk_video.py
class TalkingVideoDataset (line 83) | class TalkingVideoDataset(Dataset):
method __init__ (line 106) | def __init__(
method augmentation (line 175) | def augmentation(self, images, transform, state=None):
method __getitem__ (line 201) | def __getitem__(self, index):
method __len__ (line 315) | def __len__(self):
FILE: hallo/models/attention.py
class GatedSelfAttentionDense (line 29) | class GatedSelfAttentionDense(nn.Module):
method __init__ (line 40) | def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_h...
method forward (line 57) | def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
class BasicTransformerBlock (line 79) | class BasicTransformerBlock(nn.Module):
method __init__ (line 114) | def __init__(
method set_chunk_feed_forward (line 242) | def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int =...
method forward (line 257) | def forward(
class TemporalBasicTransformerBlock (line 410) | class TemporalBasicTransformerBlock(nn.Module):
method __init__ (line 429) | def __init__(
method forward (line 542) | def forward(
class AudioTemporalBasicTransformerBlock (line 621) | class AudioTemporalBasicTransformerBlock(nn.Module):
method __init__ (line 643) | def __init__(
method forward (line 784) | def forward(
function zero_module (line 909) | def zero_module(module):
FILE: hallo/models/audio_proj.py
class AudioProjModel (line 40) | class AudioProjModel(ModelMixin):
method __init__ (line 68) | def __init__(
method forward (line 96) | def forward(self, audio_embeds):
FILE: hallo/models/face_locator.py
class FaceLocator (line 34) | class FaceLocator(ModelMixin):
method __init__ (line 60) | def __init__(
method forward (line 94) | def forward(self, conditioning):
FILE: hallo/models/image_proj.py
class ImageProjModel (line 23) | class ImageProjModel(ModelMixin):
method __init__ (line 40) | def __init__(
method forward (line 56) | def forward(self, image_embeds):
FILE: hallo/models/motion_module.py
function zero_module (line 68) | def zero_module(module):
class TemporalTransformer3DModelOutput (line 83) | class TemporalTransformer3DModelOutput(BaseOutput):
method get_sample_shape (line 92) | def get_sample_shape(self):
function get_motion_module (line 102) | def get_motion_module(in_channels, motion_module_type: str, motion_modul...
class VanillaTemporalModule (line 126) | class VanillaTemporalModule(nn.Module):
method __init__ (line 142) | def __init__(
method forward (line 174) | def forward(
class TemporalTransformer3DModel (line 200) | class TemporalTransformer3DModel(nn.Module):
method __init__ (line 220) | def __init__(
method forward (line 270) | def forward(self, hidden_states, encoder_hidden_states=None):
class TemporalTransformerBlock (line 319) | class TemporalTransformerBlock(nn.Module):
method __init__ (line 337) | def __init__(
method forward (line 387) | def forward(
class PositionalEncoding (line 426) | class PositionalEncoding(nn.Module):
method __init__ (line 435) | def __init__(self, d_model, dropout=0.0, max_len=24):
method forward (line 447) | def forward(self, x):
class VersatileAttention (line 464) | class VersatileAttention(Attention):
method __init__ (line 473) | def __init__(
method extra_repr (line 498) | def extra_repr(self):
method set_use_memory_efficient_attention_xformers (line 507) | def set_use_memory_efficient_attention_xformers(
method forward (line 553) | def forward(
FILE: hallo/models/mutual_self_attention.py
function torch_dfs (line 19) | def torch_dfs(model: torch.nn.Module):
class ReferenceAttentionControl (line 39) | class ReferenceAttentionControl:
method __init__ (line 64) | def __init__(
method register_reference_hooks (line 115) | def register_reference_hooks(
method update (line 404) | def update(self, writer, dtype=torch.float16):
method clear (line 456) | def clear(self):
FILE: hallo/models/resnet.py
class InflatedConv3d (line 30) | class InflatedConv3d(nn.Conv2d):
method forward (line 50) | def forward(self, x):
class InflatedGroupNorm (line 69) | class InflatedGroupNorm(nn.GroupNorm):
method forward (line 88) | def forward(self, x):
class Upsample3D (line 104) | class Upsample3D(nn.Module):
method __init__ (line 115) | def __init__(
method forward (line 135) | def forward(self, hidden_states, output_size=None):
class Downsample3D (line 188) | class Downsample3D(nn.Module):
method __init__ (line 204) | def __init__(
method forward (line 232) | def forward(self, hidden_states):
class ResnetBlock3D (line 255) | class ResnetBlock3D(nn.Module):
method __init__ (line 279) | def __init__(
method forward (line 372) | def forward(self, input_tensor, temb):
class Mish (line 415) | class Mish(torch.nn.Module):
method forward (line 425) | def forward(self, hidden_states):
FILE: hallo/models/transformer_2d.py
class Transformer2DModelOutput (line 51) | class Transformer2DModelOutput(BaseOutput):
class Transformer2DModel (line 66) | class Transformer2DModel(ModelMixin, ConfigMixin):
method __init__ (line 97) | def __init__(
method _set_gradient_checkpointing (line 241) | def _set_gradient_checkpointing(self, module, value=False):
method forward (line 245) | def forward(
FILE: hallo/models/transformer_3d.py
class Transformer3DModelOutput (line 26) | class Transformer3DModelOutput(BaseOutput):
class Transformer3DModel (line 38) | class Transformer3DModel(ModelMixin, ConfigMixin):
method __init__ (line 47) | def __init__(
method _set_gradient_checkpointing (line 143) | def _set_gradient_checkpointing(self, module, value=False):
method forward (line 147) | def forward(
FILE: hallo/models/unet_2d_blocks.py
function get_down_block (line 35) | def get_down_block(
function get_up_block (line 132) | def get_up_block(
class AutoencoderTinyBlock (line 215) | class AutoencoderTinyBlock(nn.Module):
method __init__ (line 231) | def __init__(self, in_channels: int, out_channels: int, act_fn: str):
method forward (line 248) | def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
class UNetMidBlock2D (line 261) | class UNetMidBlock2D(nn.Module):
method __init__ (line 292) | def __init__(
method forward (line 384) | def forward(
class UNetMidBlock2DCrossAttn (line 407) | class UNetMidBlock2DCrossAttn(nn.Module):
method __init__ (line 428) | def __init__(
method forward (line 523) | def forward(
class CrossAttnDownBlock2D (line 595) | class CrossAttnDownBlock2D(nn.Module):
method __init__ (line 627) | def __init__(
method forward (line 722) | def forward(
class DownBlock2D (line 812) | class DownBlock2D(nn.Module):
method __init__ (line 842) | def __init__(
method forward (line 897) | def forward(
class CrossAttnUpBlock2D (line 950) | class CrossAttnUpBlock2D(nn.Module):
method __init__ (line 987) | def __init__(
method forward (line 1079) | def forward(
class UpBlock2D (line 1186) | class UpBlock2D(nn.Module):
method __init__ (line 1217) | def __init__(
method forward (line 1268) | def forward(
FILE: hallo/models/unet_2d_condition.py
class UNet2DConditionOutput (line 80) | class UNet2DConditionOutput(BaseOutput):
class UNet2DConditionModel (line 93) | class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoade...
method __init__ (line 191) | def __init__(
method attn_processors (line 703) | def attn_processors(self) -> Dict[str, AttentionProcessor]:
method set_attn_processor (line 733) | def set_attn_processor(
method set_default_attn_processor (line 774) | def set_default_attn_processor(self):
method set_attention_slice (line 795) | def set_attention_slice(self, slice_size):
method _set_gradient_checkpointing (line 866) | def _set_gradient_checkpointing(self, module, value=False):
method enable_freeu (line 870) | def enable_freeu(self, s1, s2, b1, b2):
method disable_freeu (line 894) | def disable_freeu(self):
method forward (line 905) | def forward(
method load_change_cross_attention_dim (line 1361) | def load_change_cross_attention_dim(
FILE: hallo/models/unet_3d.py
class UNet3DConditionOutput (line 47) | class UNet3DConditionOutput(BaseOutput):
class UNet3DConditionModel (line 59) | class UNet3DConditionModel(ModelMixin, ConfigMixin):
method __init__ (line 121) | def __init__(
method attn_processors (line 365) | def attn_processors(self) -> Dict[str, AttentionProcessor]:
method set_attention_slice (line 395) | def set_attention_slice(self, slice_size):
method _set_gradient_checkpointing (line 466) | def _set_gradient_checkpointing(self, module, value=False):
method set_attn_processor (line 471) | def set_attn_processor(
method forward (line 510) | def forward(
method from_pretrained_2d (line 718) | def from_pretrained_2d(
FILE: hallo/models/unet_3d_blocks.py
function get_down_block (line 26) | def get_down_block(
function get_up_block (line 137) | def get_up_block(
class UNetMidBlock3DCrossAttn (line 247) | class UNetMidBlock3DCrossAttn(nn.Module):
method __init__ (line 283) | def __init__(
method forward (line 407) | def forward(
class CrossAttnDownBlock3D (line 497) | class CrossAttnDownBlock3D(nn.Module):
method __init__ (line 509) | def __init__(
method forward (line 638) | def forward(
class DownBlock3D (line 783) | class DownBlock3D(nn.Module):
method __init__ (line 812) | def __init__(
method forward (line 884) | def forward(
class CrossAttnUpBlock3D (line 940) | class CrossAttnUpBlock3D(nn.Module):
method __init__ (line 969) | def __init__(
method forward (line 1092) | def forward(
class UpBlock3D (line 1238) | class UpBlock3D(nn.Module):
method __init__ (line 1281) | def __init__(
method forward (line 1347) | def forward(
FILE: hallo/models/wav2vec.py
class Wav2VecModel (line 21) | class Wav2VecModel(Wav2Vec2Model):
method forward (line 42) | def forward(
method feature_extract (line 112) | def feature_extract(
method encode (line 133) | def encode(
function linear_interpolation (line 196) | def linear_interpolation(features, seq_len):
FILE: hallo/utils/config.py
function filter_non_none (line 8) | def filter_non_none(dict_obj: Dict):
FILE: hallo/utils/util.py
function seed_everything (line 84) | def seed_everything(seed):
function import_filename (line 97) | def import_filename(filename):
function delete_additional_ckpt (line 120) | def delete_additional_ckpt(base_path, num_keep):
function save_videos_from_pil (line 154) | def save_videos_from_pil(pil_images, path, fps=8):
function save_videos_grid (line 206) | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_r...
function read_frames (line 244) | def read_frames(video_path):
function get_fps (line 280) | def get_fps(video_path):
function tensor_to_video (line 297) | def tensor_to_video(tensor, output_video_file, audio_source, fps=25):
function compute_face_landmarks (line 332) | def compute_face_landmarks(detection_result, h, w):
function get_landmark (line 351) | def get_landmark(file):
function get_landmark_overframes (line 382) | def get_landmark_overframes(landmark_model, frames_path):
function get_lip_mask (line 407) | def get_lip_mask(landmarks, height, width, out_path=None, expand_ratio=2...
function get_union_lip_mask (line 433) | def get_union_lip_mask(landmarks, height, width, expand_ratio=1):
function get_face_mask (line 451) | def get_face_mask(landmarks, height, width, out_path=None, expand_ratio=...
function get_union_face_mask (line 479) | def get_union_face_mask(landmarks, height, width, expand_ratio=1):
function get_mask (line 497) | def get_mask(file, cache_dir, face_expand_raio):
function expand_region (line 529) | def expand_region(region, image_w, image_h, expand_ratio=1.0):
function get_blur_mask (line 567) | def get_blur_mask(file_path, output_file_path, resize_dim=(64, 64), kern...
function blur_mask (line 589) | def blur_mask(mask, resize_dim=(64, 64), kernel_size=(51, 51)):
function get_background_mask (line 611) | def get_background_mask(file_path, output_file_path):
function get_sep_face_mask (line 638) | def get_sep_face_mask(file_path1, file_path2, output_file_path):
function resample_audio (line 668) | def resample_audio(input_audio_file: str, output_audio_file: str, sample...
function get_face_region (line 676) | def get_face_region(image_path: str, detector):
function save_checkpoint (line 707) | def save_checkpoint(model: torch.nn.Module, save_dir: str, prefix: str, ...
function init_output_dir (line 771) | def init_output_dir(dir_list: List[str]):
function load_checkpoint (line 784) | def load_checkpoint(cfg, save_dir, accelerator):
function compute_snr (line 822) | def compute_snr(noise_scheduler, timesteps):
function extract_audio_from_videos (line 854) | def extract_audio_from_videos(video_path: Path, audio_output_path: Path)...
function convert_video_to_images (line 889) | def convert_video_to_images(video_path: Path, output_dir: Path) -> Path:
function get_union_mask (line 923) | def get_union_mask(masks):
function move_final_checkpoint (line 960) | def move_final_checkpoint(save_dir, module_dir, prefix):
FILE: scripts/app.py
function predict (line 18) | def predict(image, audio, pose_weight, face_weight, lip_weight, face_exp...
FILE: scripts/data_preprocess.py
function setup_directories (line 33) | def setup_directories(video_path: Path) -> dict:
function process_single_video (line 59) | def process_single_video(video_path: Path,
function process_all_videos (line 116) | def process_all_videos(input_video_list: List[Path], output_dir: Path, s...
function get_video_paths (line 148) | def get_video_paths(source_dir: Path, parallelism: int, rank: int) -> Li...
FILE: scripts/extract_meta_info_stage1.py
function collect_video_folder_paths (line 27) | def collect_video_folder_paths(root_path: Path) -> list:
function construct_meta_info (line 40) | def construct_meta_info(frames_dir_path: Path) -> dict:
function main (line 68) | def main():
FILE: scripts/extract_meta_info_stage2.py
function get_video_paths (line 37) | def get_video_paths(root_path: Path, extensions: list) -> list:
function file_exists (line 51) | def file_exists(file_path: str) -> bool:
function construct_paths (line 64) | def construct_paths(video_path: str, base_dir: str, new_dir: str, new_ex...
function extract_meta_info (line 80) | def extract_meta_info(video_path: str) -> dict:
function main (line 155) | def main():
FILE: scripts/inference.py
class Net (line 51) | class Net(nn.Module):
method __init__ (line 62) | def __init__(
method forward (line 77) | def forward(self,):
method get_modules (line 82) | def get_modules(self):
function process_audio_emb (line 95) | def process_audio_emb(audio_emb):
function inference_process (line 118) | def inference_process(args: argparse.Namespace):
FILE: scripts/train_stage1.py
class Net (line 68) | class Net(nn.Module):
method __init__ (line 93) | def __init__(
method forward (line 110) | def forward(
function get_noise_scheduler (line 157) | def get_noise_scheduler(cfg: argparse.Namespace):
function log_validation (line 181) | def log_validation(
function train_stage1_process (line 289) | def train_stage1_process(cfg: argparse.Namespace) -> None:
function load_config (line 765) | def load_config(config_path: str) -> dict:
FILE: scripts/train_stage2.py
class Net (line 74) | class Net(nn.Module):
method __init__ (line 104) | def __init__(
method forward (line 123) | def forward(
function get_attention_mask (line 182) | def get_attention_mask(mask: torch.Tensor, weight_dtype: torch.dtype) ->...
function get_noise_scheduler (line 203) | def get_noise_scheduler(cfg: argparse.Namespace) -> Tuple[DDIMScheduler,...
function process_audio_emb (line 228) | def process_audio_emb(audio_emb: torch.Tensor) -> torch.Tensor:
function log_validation (line 250) | def log_validation(
function train_stage2_process (line 421) | def train_stage2_process(cfg: argparse.Namespace) -> None:
function load_config (line 962) | def load_config(config_path: str) -> dict:
Condensed preview: 48 files, each showing path, character count, and a content snippet.
[
{
"path": ".github/workflows/static-check.yaml",
"chars": 747,
"preview": "name: Pylint\n\non: [push, pull_request]\n\njobs:\n static-check:\n runs-on: ${{ matrix.os }}\n strategy:\n matrix:\n"
},
{
"path": ".gitignore",
"chars": 2859,
"preview": "# running cache\nmlruns/\n\n# Test directories\ntest_data/\npretrained_models/\n\n# Poetry project\npoetry.lock\n\n# Byte-compiled"
},
{
"path": ".pre-commit-config.yaml",
"chars": 359,
"preview": "repos:\n - repo: local\n hooks:\n - id: isort\n name: isort\n language: system\n types: [python]"
},
{
"path": ".pylintrc",
"chars": 21382,
"preview": "[MAIN]\n\n# Analyse import fallback blocks. This can be used to support both Python 2 and\n# 3 compatible code, which means"
},
{
"path": "LICENSE",
"chars": 1110,
"preview": "MIT License\n\nCopyright (c) 2024 Fusion Lab: Generative Vision Lab of Fudan University\n\nPermission is hereby granted, fre"
},
{
"path": "README.md",
"chars": 19957,
"preview": "<h1 align='center'>Hallo: Hierarchical Audio-Driven Visual Synthesis for Portrait Image Animation</h1>\n\n<div align='cent"
},
{
"path": "accelerate_config.yaml",
"chars": 486,
"preview": "compute_environment: LOCAL_MACHINE\ndebug: true\ndeepspeed_config:\n deepspeed_multinode_launcher: standard\n gradient_acc"
},
{
"path": "configs/inference/.gitkeep",
"chars": 0,
"preview": ""
},
{
"path": "configs/inference/default.yaml",
"chars": 1900,
"preview": "source_image: examples/reference_images/1.jpg\ndriving_audio: examples/driving_audios/1.wav\n\nweight_dtype: fp16\n\ndata:\n "
},
{
"path": "configs/train/stage1.yaml",
"chars": 1328,
"preview": "data:\n train_bs: 8\n train_width: 512\n train_height: 512\n meta_paths:\n - \"./data/HDTF_meta.json\"\n # Margin of fra"
},
{
"path": "configs/train/stage2.yaml",
"chars": 2654,
"preview": "data:\n train_bs: 4\n val_bs: 1\n train_width: 512\n train_height: 512\n fps: 25\n sample_rate: 16000\n n_motion_frames:"
},
{
"path": "configs/unet/unet.yaml",
"chars": 1023,
"preview": "unet_additional_kwargs:\n use_inflated_groupnorm: true\n unet_use_cross_frame_attention: false\n unet_use_temporal_atten"
},
{
"path": "hallo/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "hallo/animate/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "hallo/animate/face_animate.py",
"chars": 19299,
"preview": "# pylint: disable=R0801\n\"\"\"\nThis module is responsible for animating faces in videos using a combination of deep learnin"
},
{
"path": "hallo/animate/face_animate_static.py",
"chars": 18753,
"preview": "# pylint: disable=R0801\n\"\"\"\nThis module is responsible for handling the animation of faces using a combination of deep l"
},
{
"path": "hallo/datasets/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "hallo/datasets/audio_processor.py",
"chars": 7330,
"preview": "# pylint: disable=C0301\n'''\nThis module contains the AudioProcessor class and related functions for processing audio dat"
},
{
"path": "hallo/datasets/image_processor.py",
"chars": 13650,
"preview": "# pylint: disable=W0718\n\"\"\"\nThis module is responsible for processing images, particularly for face-related tasks.\nIt us"
},
{
"path": "hallo/datasets/mask_image.py",
"chars": 5348,
"preview": "# pylint: disable=R0801\n\"\"\"\nThis module contains the code for a dataset class called FaceMaskDataset, which is used to p"
},
{
"path": "hallo/datasets/talk_video.py",
"chars": 12492,
"preview": "# pylint: disable=R0801\n\"\"\"\ntalking_video_dataset.py\n\nThis module defines the TalkingVideoDataset class, a custom PyTorc"
},
{
"path": "hallo/models/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "hallo/models/attention.py",
"chars": 40117,
"preview": "# pylint: disable=R0801\n# pylint: disable=C0303\n\n\"\"\"\nThis module contains various transformer blocks for different appli"
},
{
"path": "hallo/models/audio_proj.py",
"chars": 5007,
"preview": "\"\"\"\nThis module provides the implementation of an Audio Projection Model, which is designed for\naudio processing tasks. "
},
{
"path": "hallo/models/face_locator.py",
"chars": 4580,
"preview": "\"\"\"\nThis module implements the FaceLocator class, which is a neural network model designed to\nlocate and extract facial "
},
{
"path": "hallo/models/image_proj.py",
"chars": 2571,
"preview": "\"\"\"\nimage_proj_model.py\n\nThis module defines the ImageProjModel class, which is responsible for\nprojecting image embeddi"
},
{
"path": "hallo/models/motion_module.py",
"chars": 21381,
"preview": "# pylint: disable=R0801\n# pylint: disable=W0613\n# pylint: disable=W0221\n\n\"\"\"\ntemporal_transformers.py\n\nThis module provi"
},
{
"path": "hallo/models/mutual_self_attention.py",
"chars": 21535,
"preview": "# pylint: disable=E1120\n\"\"\"\nThis module contains the implementation of mutual self-attention, \nwhich is a type of attent"
},
{
"path": "hallo/models/resnet.py",
"chars": 16079,
"preview": "# pylint: disable=E1120\n# pylint: disable=E1102\n# pylint: disable=W0237\n\n# src/models/resnet.py\n\n\"\"\"\nThis module defines"
},
{
"path": "hallo/models/transformer_2d.py",
"chars": 20356,
"preview": "# pylint: disable=E1101\n# src/models/transformer_2d.py\n\n\"\"\"\nThis module defines the Transformer2DModel, a PyTorch model "
},
{
"path": "hallo/models/transformer_3d.py",
"chars": 10140,
"preview": "# pylint: disable=R0801\n\"\"\"\nThis module implements the Transformer3DModel, a PyTorch model designed for processing\n3D da"
},
{
"path": "hallo/models/unet_2d_blocks.py",
"chars": 57770,
"preview": "# pylint: disable=R0801\n# pylint: disable=W1203\n\n\"\"\"\nThis file defines the 2D blocks for the UNet model in a PyTorch imp"
},
{
"path": "hallo/models/unet_2d_condition.py",
"chars": 69276,
"preview": "# pylint: disable=R0801\n# pylint: disable=E1101\n# pylint: disable=W1203\n\n\"\"\"\nThis module implements the `UNet2DCondition"
},
{
"path": "hallo/models/unet_3d.py",
"chars": 37102,
"preview": "# pylint: disable=R0801\n# pylint: disable=E1101\n# pylint: disable=R0402\n# pylint: disable=W1203\n\n\"\"\"\nThis is the main fi"
},
{
"path": "hallo/models/unet_3d_blocks.py",
"chars": 56716,
"preview": "# pylint: disable=R0801\n# src/models/unet_3d_blocks.py\n\n\"\"\"\nThis module defines various 3D UNet blocks used in the video"
},
{
"path": "hallo/models/wav2vec.py",
"chars": 8035,
"preview": "# pylint: disable=R0901\n# src/models/wav2vec.py\n\n\"\"\"\nThis module defines the Wav2Vec model, which is a pre-trained model"
},
{
"path": "hallo/utils/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "hallo/utils/config.py",
"chars": 797,
"preview": "\"\"\"\nThis module provides utility functions for configuration manipulation.\n\"\"\"\n\nfrom typing import Dict\n\n\ndef filter_non"
},
{
"path": "hallo/utils/util.py",
"chars": 34402,
"preview": "# pylint: disable=C0116\n# pylint: disable=W0718\n# pylint: disable=R1732\n# pylint: disable=R0801\n\"\"\"\nutils.py\n\nThis modul"
},
{
"path": "requirements.txt",
"chars": 657,
"preview": "--find-links https://download.pytorch.org/whl/torch_stable.html\n\naccelerate==0.28.0\naudio-separator==0.17.2\nav==12.1.0\nb"
},
{
"path": "scripts/app.py",
"chars": 1505,
"preview": "\"\"\"\nThis script is a gradio web ui.\n\nThe script takes an image and an audio clip, and lets you configure all the\nvariabl"
},
{
"path": "scripts/data_preprocess.py",
"chars": 7711,
"preview": "# pylint: disable=W1203,W0718\n\"\"\"\nThis module is used to process videos to prepare data for training. It utilizes variou"
},
{
"path": "scripts/extract_meta_info_stage1.py",
"chars": 3361,
"preview": "# pylint: disable=R0801\n\"\"\"\nThis module is used to extract meta information from video directories.\n\nIt takes in two com"
},
{
"path": "scripts/extract_meta_info_stage2.py",
"chars": 6681,
"preview": "# pylint: disable=R0801\n\"\"\"\nThis module is used to extract meta information from video files and store them in a JSON fi"
},
{
"path": "scripts/inference.py",
"chars": 13649,
"preview": "# pylint: disable=E1101\n# scripts/inference.py\n\n\"\"\"\nThis script contains the main inference pipeline for processing audi"
},
{
"path": "scripts/train_stage1.py",
"chars": 29374,
"preview": "# pylint: disable=E1101,C0415,W0718,R0801\n# scripts/train_stage1.py\n\"\"\"\nThis is the main training script for stage 1 of "
},
{
"path": "scripts/train_stage2.py",
"chars": 37347,
"preview": "# pylint: disable=E1101,C0415,W0718,R0801\n# scripts/train_stage2.py\n\"\"\"\nThis is the main training script for stage 2 of "
},
{
"path": "setup.py",
"chars": 1309,
"preview": "\"\"\"\nsetup.py\n----\nThis is the main setup file for the hallo face animation project. It defines the package\nmetadata, req"
}
]
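The manifest above is a JSON array whose objects carry `path`, `chars`, and `preview` keys. The sketch below totals the character counts for each top-level directory and groups root-level files under `"."`. It assumes the array is available as a JSON string. The helper name `chars_by_top_level` is hypothetical, and the sample entries are copied from the manifest with their previews left empty.

```python
import json
from collections import defaultdict
from typing import Dict, List


def chars_by_top_level(entries: List[dict]) -> Dict[str, int]:
    """Sum the 'chars' field per top-level directory; root files go under '.'."""
    totals: Dict[str, int] = defaultdict(int)
    for entry in entries:
        parts = entry["path"].split("/")
        top = parts[0] if len(parts) > 1 else "."
        totals[top] += entry["chars"]
    return dict(totals)


# A few entries from the manifest above (previews emptied for brevity).
manifest_json = """[
  {"path": "hallo/models/attention.py", "chars": 40117, "preview": ""},
  {"path": "hallo/models/audio_proj.py", "chars": 5007, "preview": ""},
  {"path": "scripts/inference.py", "chars": 13649, "preview": ""},
  {"path": "setup.py", "chars": 1309, "preview": ""}
]"""

entries = json.loads(manifest_json)
print(chars_by_top_level(entries))
# {'hallo': 45124, 'scripts': 13649, '.': 1309}
```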
About this extraction
This page contains the full source code of the fudan-generative-vision/hallo GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 48 files (623.2 KB), approximately 136.5k tokens, and a symbol index with 259 extracted functions, classes, methods, constants, and types.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.