Repository: PeterL1n/RobustVideoMatting
Branch: master
Commit: 53d74c682673
Files: 42
Total size: 233.4 KB

Directory structure:
gitextract_gj6tfm4_/

├── LICENSE
├── README.md
├── README_zh_Hans.md
├── dataset/
│   ├── augmentation.py
│   ├── coco.py
│   ├── imagematte.py
│   ├── spd.py
│   ├── videomatte.py
│   └── youtubevis.py
├── documentation/
│   ├── inference.md
│   ├── inference_zh_Hans.md
│   ├── misc/
│   │   ├── aim_test.txt
│   │   ├── d646_test.txt
│   │   ├── dvm_background_test_clips.txt
│   │   ├── dvm_background_train_clips.txt
│   │   ├── imagematte_train.txt
│   │   ├── imagematte_valid.txt
│   │   └── spd_preprocess.py
│   └── training.md
├── evaluation/
│   ├── evaluate_hr.py
│   ├── evaluate_lr.py
│   ├── generate_imagematte_with_background_image.py
│   ├── generate_imagematte_with_background_video.py
│   ├── generate_videomatte_with_background_image.py
│   └── generate_videomatte_with_background_video.py
├── hubconf.py
├── inference.py
├── inference_speed_test.py
├── inference_utils.py
├── model/
│   ├── __init__.py
│   ├── decoder.py
│   ├── deep_guided_filter.py
│   ├── fast_guided_filter.py
│   ├── lraspp.py
│   ├── mobilenetv3.py
│   ├── model.py
│   └── resnet.py
├── requirements_inference.txt
├── requirements_training.txt
├── train.py
├── train_config.py
└── train_loss.py

================================================
FILE CONTENTS
================================================

================================================
FILE: LICENSE
================================================
                    GNU GENERAL PUBLIC LICENSE
                       Version 3, 29 June 2007

 Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
 Everyone is permitted to copy and distribute verbatim copies
 of this license document, but changing it is not allowed.

                            Preamble

  The GNU General Public License is a free, copyleft license for
software and other kinds of works.

  The licenses for most software and other practical works are designed
to take away your freedom to share and change the works.  By contrast,
the GNU General Public License is intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
software for all its users.  We, the Free Software Foundation, use the
GNU General Public License for most of our software; it applies also to
any other work released this way by its authors.  You can apply it to
your programs, too.

  When we speak of free software, we are referring to freedom, not
price.  Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for
them if you wish), that you receive source code or can get it if you
want it, that you can change the software or use pieces of it in new
free programs, and that you know you can do these things.

  To protect your rights, we need to prevent others from denying you
these rights or asking you to surrender the rights.  Therefore, you have
certain responsibilities if you distribute copies of the software, or if
you modify it: responsibilities to respect the freedom of others.

  For example, if you distribute copies of such a program, whether
gratis or for a fee, you must pass on to the recipients the same
freedoms that you received.  You must make sure that they, too, receive
or can get the source code.  And you must show them these terms so they
know their rights.

  Developers that use the GNU GPL protect your rights with two steps:
(1) assert copyright on the software, and (2) offer you this License
giving you legal permission to copy, distribute and/or modify it.

  For the developers' and authors' protection, the GPL clearly explains
that there is no warranty for this free software.  For both users' and
authors' sake, the GPL requires that modified versions be marked as
changed, so that their problems will not be attributed erroneously to
authors of previous versions.

  Some devices are designed to deny users access to install or run
modified versions of the software inside them, although the manufacturer
can do so.  This is fundamentally incompatible with the aim of
protecting users' freedom to change the software.  The systematic
pattern of such abuse occurs in the area of products for individuals to
use, which is precisely where it is most unacceptable.  Therefore, we
have designed this version of the GPL to prohibit the practice for those
products.  If such problems arise substantially in other domains, we
stand ready to extend this provision to those domains in future versions
of the GPL, as needed to protect the freedom of users.

  Finally, every program is threatened constantly by software patents.
States should not allow patents to restrict development and use of
software on general-purpose computers, but in those that do, we wish to
avoid the special danger that patents applied to a free program could
make it effectively proprietary.  To prevent this, the GPL assures that
patents cannot be used to render the program non-free.

  The precise terms and conditions for copying, distribution and
modification follow.

                       TERMS AND CONDITIONS

  0. Definitions.

  "This License" refers to version 3 of the GNU General Public License.

  "Copyright" also means copyright-like laws that apply to other kinds of
works, such as semiconductor masks.

  "The Program" refers to any copyrightable work licensed under this
License.  Each licensee is addressed as "you".  "Licensees" and
"recipients" may be individuals or organizations.

  To "modify" a work means to copy from or adapt all or part of the work
in a fashion requiring copyright permission, other than the making of an
exact copy.  The resulting work is called a "modified version" of the
earlier work or a work "based on" the earlier work.

  A "covered work" means either the unmodified Program or a work based
on the Program.

  To "propagate" a work means to do anything with it that, without
permission, would make you directly or secondarily liable for
infringement under applicable copyright law, except executing it on a
computer or modifying a private copy.  Propagation includes copying,
distribution (with or without modification), making available to the
public, and in some countries other activities as well.

  To "convey" a work means any kind of propagation that enables other
parties to make or receive copies.  Mere interaction with a user through
a computer network, with no transfer of a copy, is not conveying.

  An interactive user interface displays "Appropriate Legal Notices"
to the extent that it includes a convenient and prominently visible
feature that (1) displays an appropriate copyright notice, and (2)
tells the user that there is no warranty for the work (except to the
extent that warranties are provided), that licensees may convey the
work under this License, and how to view a copy of this License.  If
the interface presents a list of user commands or options, such as a
menu, a prominent item in the list meets this criterion.

  1. Source Code.

  The "source code" for a work means the preferred form of the work
for making modifications to it.  "Object code" means any non-source
form of a work.

  A "Standard Interface" means an interface that either is an official
standard defined by a recognized standards body, or, in the case of
interfaces specified for a particular programming language, one that
is widely used among developers working in that language.

  The "System Libraries" of an executable work include anything, other
than the work as a whole, that (a) is included in the normal form of
packaging a Major Component, but which is not part of that Major
Component, and (b) serves only to enable use of the work with that
Major Component, or to implement a Standard Interface for which an
implementation is available to the public in source code form.  A
"Major Component", in this context, means a major essential component
(kernel, window system, and so on) of the specific operating system
(if any) on which the executable work runs, or a compiler used to
produce the work, or an object code interpreter used to run it.

  The "Corresponding Source" for a work in object code form means all
the source code needed to generate, install, and (for an executable
work) run the object code and to modify the work, including scripts to
control those activities.  However, it does not include the work's
System Libraries, or general-purpose tools or generally available free
programs which are used unmodified in performing those activities but
which are not part of the work.  For example, Corresponding Source
includes interface definition files associated with source files for
the work, and the source code for shared libraries and dynamically
linked subprograms that the work is specifically designed to require,
such as by intimate data communication or control flow between those
subprograms and other parts of the work.

  The Corresponding Source need not include anything that users
can regenerate automatically from other parts of the Corresponding
Source.

  The Corresponding Source for a work in source code form is that
same work.

  2. Basic Permissions.

  All rights granted under this License are granted for the term of
copyright on the Program, and are irrevocable provided the stated
conditions are met.  This License explicitly affirms your unlimited
permission to run the unmodified Program.  The output from running a
covered work is covered by this License only if the output, given its
content, constitutes a covered work.  This License acknowledges your
rights of fair use or other equivalent, as provided by copyright law.

  You may make, run and propagate covered works that you do not
convey, without conditions so long as your license otherwise remains
in force.  You may convey covered works to others for the sole purpose
of having them make modifications exclusively for you, or provide you
with facilities for running those works, provided that you comply with
the terms of this License in conveying all material for which you do
not control copyright.  Those thus making or running the covered works
for you must do so exclusively on your behalf, under your direction
and control, on terms that prohibit them from making any copies of
your copyrighted material outside their relationship with you.

  Conveying under any other circumstances is permitted solely under
the conditions stated below.  Sublicensing is not allowed; section 10
makes it unnecessary.

  3. Protecting Users' Legal Rights From Anti-Circumvention Law.

  No covered work shall be deemed part of an effective technological
measure under any applicable law fulfilling obligations under article
11 of the WIPO copyright treaty adopted on 20 December 1996, or
similar laws prohibiting or restricting circumvention of such
measures.

  When you convey a covered work, you waive any legal power to forbid
circumvention of technological measures to the extent such circumvention
is effected by exercising rights under this License with respect to
the covered work, and you disclaim any intention to limit operation or
modification of the work as a means of enforcing, against the work's
users, your or third parties' legal rights to forbid circumvention of
technological measures.

  4. Conveying Verbatim Copies.

  You may convey verbatim copies of the Program's source code as you
receive it, in any medium, provided that you conspicuously and
appropriately publish on each copy an appropriate copyright notice;
keep intact all notices stating that this License and any
non-permissive terms added in accord with section 7 apply to the code;
keep intact all notices of the absence of any warranty; and give all
recipients a copy of this License along with the Program.

  You may charge any price or no price for each copy that you convey,
and you may offer support or warranty protection for a fee.

  5. Conveying Modified Source Versions.

  You may convey a work based on the Program, or the modifications to
produce it from the Program, in the form of source code under the
terms of section 4, provided that you also meet all of these conditions:

    a) The work must carry prominent notices stating that you modified
    it, and giving a relevant date.

    b) The work must carry prominent notices stating that it is
    released under this License and any conditions added under section
    7.  This requirement modifies the requirement in section 4 to
    "keep intact all notices".

    c) You must license the entire work, as a whole, under this
    License to anyone who comes into possession of a copy.  This
    License will therefore apply, along with any applicable section 7
    additional terms, to the whole of the work, and all its parts,
    regardless of how they are packaged.  This License gives no
    permission to license the work in any other way, but it does not
    invalidate such permission if you have separately received it.

    d) If the work has interactive user interfaces, each must display
    Appropriate Legal Notices; however, if the Program has interactive
    interfaces that do not display Appropriate Legal Notices, your
    work need not make them do so.

  A compilation of a covered work with other separate and independent
works, which are not by their nature extensions of the covered work,
and which are not combined with it such as to form a larger program,
in or on a volume of a storage or distribution medium, is called an
"aggregate" if the compilation and its resulting copyright are not
used to limit the access or legal rights of the compilation's users
beyond what the individual works permit.  Inclusion of a covered work
in an aggregate does not cause this License to apply to the other
parts of the aggregate.

  6. Conveying Non-Source Forms.

  You may convey a covered work in object code form under the terms
of sections 4 and 5, provided that you also convey the
machine-readable Corresponding Source under the terms of this License,
in one of these ways:

    a) Convey the object code in, or embodied in, a physical product
    (including a physical distribution medium), accompanied by the
    Corresponding Source fixed on a durable physical medium
    customarily used for software interchange.

    b) Convey the object code in, or embodied in, a physical product
    (including a physical distribution medium), accompanied by a
    written offer, valid for at least three years and valid for as
    long as you offer spare parts or customer support for that product
    model, to give anyone who possesses the object code either (1) a
    copy of the Corresponding Source for all the software in the
    product that is covered by this License, on a durable physical
    medium customarily used for software interchange, for a price no
    more than your reasonable cost of physically performing this
    conveying of source, or (2) access to copy the
    Corresponding Source from a network server at no charge.

    c) Convey individual copies of the object code with a copy of the
    written offer to provide the Corresponding Source.  This
    alternative is allowed only occasionally and noncommercially, and
    only if you received the object code with such an offer, in accord
    with subsection 6b.

    d) Convey the object code by offering access from a designated
    place (gratis or for a charge), and offer equivalent access to the
    Corresponding Source in the same way through the same place at no
    further charge.  You need not require recipients to copy the
    Corresponding Source along with the object code.  If the place to
    copy the object code is a network server, the Corresponding Source
    may be on a different server (operated by you or a third party)
    that supports equivalent copying facilities, provided you maintain
    clear directions next to the object code saying where to find the
    Corresponding Source.  Regardless of what server hosts the
    Corresponding Source, you remain obligated to ensure that it is
    available for as long as needed to satisfy these requirements.

    e) Convey the object code using peer-to-peer transmission, provided
    you inform other peers where the object code and Corresponding
    Source of the work are being offered to the general public at no
    charge under subsection 6d.

  A separable portion of the object code, whose source code is excluded
from the Corresponding Source as a System Library, need not be
included in conveying the object code work.

  A "User Product" is either (1) a "consumer product", which means any
tangible personal property which is normally used for personal, family,
or household purposes, or (2) anything designed or sold for incorporation
into a dwelling.  In determining whether a product is a consumer product,
doubtful cases shall be resolved in favor of coverage.  For a particular
product received by a particular user, "normally used" refers to a
typical or common use of that class of product, regardless of the status
of the particular user or of the way in which the particular user
actually uses, or expects or is expected to use, the product.  A product
is a consumer product regardless of whether the product has substantial
commercial, industrial or non-consumer uses, unless such uses represent
the only significant mode of use of the product.

  "Installation Information" for a User Product means any methods,
procedures, authorization keys, or other information required to install
and execute modified versions of a covered work in that User Product from
a modified version of its Corresponding Source.  The information must
suffice to ensure that the continued functioning of the modified object
code is in no case prevented or interfered with solely because
modification has been made.

  If you convey an object code work under this section in, or with, or
specifically for use in, a User Product, and the conveying occurs as
part of a transaction in which the right of possession and use of the
User Product is transferred to the recipient in perpetuity or for a
fixed term (regardless of how the transaction is characterized), the
Corresponding Source conveyed under this section must be accompanied
by the Installation Information.  But this requirement does not apply
if neither you nor any third party retains the ability to install
modified object code on the User Product (for example, the work has
been installed in ROM).

  The requirement to provide Installation Information does not include a
requirement to continue to provide support service, warranty, or updates
for a work that has been modified or installed by the recipient, or for
the User Product in which it has been modified or installed.  Access to a
network may be denied when the modification itself materially and
adversely affects the operation of the network or violates the rules and
protocols for communication across the network.

  Corresponding Source conveyed, and Installation Information provided,
in accord with this section must be in a format that is publicly
documented (and with an implementation available to the public in
source code form), and must require no special password or key for
unpacking, reading or copying.

  7. Additional Terms.

  "Additional permissions" are terms that supplement the terms of this
License by making exceptions from one or more of its conditions.
Additional permissions that are applicable to the entire Program shall
be treated as though they were included in this License, to the extent
that they are valid under applicable law.  If additional permissions
apply only to part of the Program, that part may be used separately
under those permissions, but the entire Program remains governed by
this License without regard to the additional permissions.

  When you convey a copy of a covered work, you may at your option
remove any additional permissions from that copy, or from any part of
it.  (Additional permissions may be written to require their own
removal in certain cases when you modify the work.)  You may place
additional permissions on material, added by you to a covered work,
for which you have or can give appropriate copyright permission.

  Notwithstanding any other provision of this License, for material you
add to a covered work, you may (if authorized by the copyright holders of
that material) supplement the terms of this License with terms:

    a) Disclaiming warranty or limiting liability differently from the
    terms of sections 15 and 16 of this License; or

    b) Requiring preservation of specified reasonable legal notices or
    author attributions in that material or in the Appropriate Legal
    Notices displayed by works containing it; or

    c) Prohibiting misrepresentation of the origin of that material, or
    requiring that modified versions of such material be marked in
    reasonable ways as different from the original version; or

    d) Limiting the use for publicity purposes of names of licensors or
    authors of the material; or

    e) Declining to grant rights under trademark law for use of some
    trade names, trademarks, or service marks; or

    f) Requiring indemnification of licensors and authors of that
    material by anyone who conveys the material (or modified versions of
    it) with contractual assumptions of liability to the recipient, for
    any liability that these contractual assumptions directly impose on
    those licensors and authors.

  All other non-permissive additional terms are considered "further
restrictions" within the meaning of section 10.  If the Program as you
received it, or any part of it, contains a notice stating that it is
governed by this License along with a term that is a further
restriction, you may remove that term.  If a license document contains
a further restriction but permits relicensing or conveying under this
License, you may add to a covered work material governed by the terms
of that license document, provided that the further restriction does
not survive such relicensing or conveying.

  If you add terms to a covered work in accord with this section, you
must place, in the relevant source files, a statement of the
additional terms that apply to those files, or a notice indicating
where to find the applicable terms.

  Additional terms, permissive or non-permissive, may be stated in the
form of a separately written license, or stated as exceptions;
the above requirements apply either way.

  8. Termination.

  You may not propagate or modify a covered work except as expressly
provided under this License.  Any attempt otherwise to propagate or
modify it is void, and will automatically terminate your rights under
this License (including any patent licenses granted under the third
paragraph of section 11).

  However, if you cease all violation of this License, then your
license from a particular copyright holder is reinstated (a)
provisionally, unless and until the copyright holder explicitly and
finally terminates your license, and (b) permanently, if the copyright
holder fails to notify you of the violation by some reasonable means
prior to 60 days after the cessation.

  Moreover, your license from a particular copyright holder is
reinstated permanently if the copyright holder notifies you of the
violation by some reasonable means, this is the first time you have
received notice of violation of this License (for any work) from that
copyright holder, and you cure the violation prior to 30 days after
your receipt of the notice.

  Termination of your rights under this section does not terminate the
licenses of parties who have received copies or rights from you under
this License.  If your rights have been terminated and not permanently
reinstated, you do not qualify to receive new licenses for the same
material under section 10.

  9. Acceptance Not Required for Having Copies.

  You are not required to accept this License in order to receive or
run a copy of the Program.  Ancillary propagation of a covered work
occurring solely as a consequence of using peer-to-peer transmission
to receive a copy likewise does not require acceptance.  However,
nothing other than this License grants you permission to propagate or
modify any covered work.  These actions infringe copyright if you do
not accept this License.  Therefore, by modifying or propagating a
covered work, you indicate your acceptance of this License to do so.

  10. Automatic Licensing of Downstream Recipients.

  Each time you convey a covered work, the recipient automatically
receives a license from the original licensors, to run, modify and
propagate that work, subject to this License.  You are not responsible
for enforcing compliance by third parties with this License.

  An "entity transaction" is a transaction transferring control of an
organization, or substantially all assets of one, or subdividing an
organization, or merging organizations.  If propagation of a covered
work results from an entity transaction, each party to that
transaction who receives a copy of the work also receives whatever
licenses to the work the party's predecessor in interest had or could
give under the previous paragraph, plus a right to possession of the
Corresponding Source of the work from the predecessor in interest, if
the predecessor has it or can get it with reasonable efforts.

  You may not impose any further restrictions on the exercise of the
rights granted or affirmed under this License.  For example, you may
not impose a license fee, royalty, or other charge for exercise of
rights granted under this License, and you may not initiate litigation
(including a cross-claim or counterclaim in a lawsuit) alleging that
any patent claim is infringed by making, using, selling, offering for
sale, or importing the Program or any portion of it.

  11. Patents.

  A "contributor" is a copyright holder who authorizes use under this
License of the Program or a work on which the Program is based.  The
work thus licensed is called the contributor's "contributor version".

  A contributor's "essential patent claims" are all patent claims
owned or controlled by the contributor, whether already acquired or
hereafter acquired, that would be infringed by some manner, permitted
by this License, of making, using, or selling its contributor version,
but do not include claims that would be infringed only as a
consequence of further modification of the contributor version.  For
purposes of this definition, "control" includes the right to grant
patent sublicenses in a manner consistent with the requirements of
this License.

  Each contributor grants you a non-exclusive, worldwide, royalty-free
patent license under the contributor's essential patent claims, to
make, use, sell, offer for sale, import and otherwise run, modify and
propagate the contents of its contributor version.

  In the following three paragraphs, a "patent license" is any express
agreement or commitment, however denominated, not to enforce a patent
(such as an express permission to practice a patent or covenant not to
sue for patent infringement).  To "grant" such a patent license to a
party means to make such an agreement or commitment not to enforce a
patent against the party.

  If you convey a covered work, knowingly relying on a patent license,
and the Corresponding Source of the work is not available for anyone
to copy, free of charge and under the terms of this License, through a
publicly available network server or other readily accessible means,
then you must either (1) cause the Corresponding Source to be so
available, or (2) arrange to deprive yourself of the benefit of the
patent license for this particular work, or (3) arrange, in a manner
consistent with the requirements of this License, to extend the patent
license to downstream recipients.  "Knowingly relying" means you have
actual knowledge that, but for the patent license, your conveying the
covered work in a country, or your recipient's use of the covered work
in a country, would infringe one or more identifiable patents in that
country that you have reason to believe are valid.

  If, pursuant to or in connection with a single transaction or
arrangement, you convey, or propagate by procuring conveyance of, a
covered work, and grant a patent license to some of the parties
receiving the covered work authorizing them to use, propagate, modify
or convey a specific copy of the covered work, then the patent license
you grant is automatically extended to all recipients of the covered
work and works based on it.

  A patent license is "discriminatory" if it does not include within
the scope of its coverage, prohibits the exercise of, or is
conditioned on the non-exercise of one or more of the rights that are
specifically granted under this License.  You may not convey a covered
work if you are a party to an arrangement with a third party that is
in the business of distributing software, under which you make payment
to the third party based on the extent of your activity of conveying
the work, and under which the third party grants, to any of the
parties who would receive the covered work from you, a discriminatory
patent license (a) in connection with copies of the covered work
conveyed by you (or copies made from those copies), or (b) primarily
for and in connection with specific products or compilations that
contain the covered work, unless you entered into that arrangement,
or that patent license was granted, prior to 28 March 2007.

  Nothing in this License shall be construed as excluding or limiting
any implied license or other defenses to infringement that may
otherwise be available to you under applicable patent law.

  12. No Surrender of Others' Freedom.

  If conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License.  If you cannot convey a
covered work so as to satisfy simultaneously your obligations under this
License and any other pertinent obligations, then as a consequence you may
not convey it at all.  For example, if you agree to terms that obligate you
to collect a royalty for further conveying from those to whom you convey
the Program, the only way you could satisfy both those terms and this
License would be to refrain entirely from conveying the Program.

  13. Use with the GNU Affero General Public License.

  Notwithstanding any other provision of this License, you have
permission to link or combine any covered work with a work licensed
under version 3 of the GNU Affero General Public License into a single
combined work, and to convey the resulting work.  The terms of this
License will continue to apply to the part which is the covered work,
but the special requirements of the GNU Affero General Public License,
section 13, concerning interaction through a network will apply to the
combination as such.

  14. Revised Versions of this License.

  The Free Software Foundation may publish revised and/or new versions of
the GNU General Public License from time to time.  Such new versions will
be similar in spirit to the present version, but may differ in detail to
address new problems or concerns.

  Each version is given a distinguishing version number.  If the
Program specifies that a certain numbered version of the GNU General
Public License "or any later version" applies to it, you have the
option of following the terms and conditions either of that numbered
version or of any later version published by the Free Software
Foundation.  If the Program does not specify a version number of the
GNU General Public License, you may choose any version ever published
by the Free Software Foundation.

  If the Program specifies that a proxy can decide which future
versions of the GNU General Public License can be used, that proxy's
public statement of acceptance of a version permanently authorizes you
to choose that version for the Program.

  Later license versions may give you additional or different
permissions.  However, no additional obligations are imposed on any
author or copyright holder as a result of your choosing to follow a
later version.

  15. Disclaimer of Warranty.

  THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
APPLICABLE LAW.  EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
PURPOSE.  THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
IS WITH YOU.  SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.

  16. Limitation of Liability.

  IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
SUCH DAMAGES.

  17. Interpretation of Sections 15 and 16.

  If the disclaimer of warranty and limitation of liability provided
above cannot be given local legal effect according to their terms,
reviewing courts shall apply local law that most closely approximates
an absolute waiver of all civil liability in connection with the
Program, unless a warranty or assumption of liability accompanies a
copy of the Program in return for a fee.

                     END OF TERMS AND CONDITIONS

            How to Apply These Terms to Your New Programs

  If you develop a new program, and you want it to be of the greatest
possible use to the public, the best way to achieve this is to make it
free software which everyone can redistribute and change under these terms.

  To do so, attach the following notices to the program.  It is safest
to attach them to the start of each source file to most effectively
state the exclusion of warranty; and each file should have at least
the "copyright" line and a pointer to where the full notice is found.

    <one line to give the program's name and a brief idea of what it does.>
    Copyright (C) <year>  <name of author>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.

Also add information on how to contact you by electronic and paper mail.

  If the program does terminal interaction, make it output a short
notice like this when it starts in an interactive mode:

    <program>  Copyright (C) <year>  <name of author>
    This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
    This is free software, and you are welcome to redistribute it
    under certain conditions; type `show c' for details.

The hypothetical commands `show w' and `show c' should show the appropriate
parts of the General Public License.  Of course, your program's commands
might be different; for a GUI interface, you would use an "about box".

  You should also get your employer (if you work as a programmer) or school,
if any, to sign a "copyright disclaimer" for the program, if necessary.
For more information on this, and how to apply and follow the GNU GPL, see
<https://www.gnu.org/licenses/>.

  The GNU General Public License does not permit incorporating your program
into proprietary programs.  If your program is a subroutine library, you
may consider it more useful to permit linking proprietary applications with
the library.  If this is what you want to do, use the GNU Lesser General
Public License instead of this License.  But first, please read
<https://www.gnu.org/licenses/why-not-lgpl.html>.

================================================
FILE: README.md
================================================
# Robust Video Matting (RVM)

![Teaser](/documentation/image/teaser.gif)

<p align="center">English | <a href="README_zh_Hans.md">中文</a></p>

Official repository for the paper [Robust High-Resolution Video Matting with Temporal Guidance](https://peterl1n.github.io/RobustVideoMatting/). RVM is specifically designed for robust human video matting. Unlike existing neural models that process frames as independent images, RVM uses a recurrent neural network to process videos with temporal memory. RVM can perform real-time matting on any video without additional inputs. It achieves **4K 76FPS** and **HD 104FPS** on an Nvidia GTX 1080 Ti GPU. The project was developed at [ByteDance Inc.](https://www.bytedance.com/).

<br>

## News

* [Nov 03 2021] Fixed a bug in [train.py](https://github.com/PeterL1n/RobustVideoMatting/commit/48effc91576a9e0e7a8519f3da687c0d3522045f).
* [Sep 16 2021] Code is re-released under GPL-3.0 license.
* [Aug 25 2021] Source code and pretrained models are published.
* [Jul 27 2021] Paper is accepted by WACV 2022.

<br>

## Showreel
Watch the showreel video ([YouTube](https://youtu.be/Jvzltozpbpk), [Bilibili](https://www.bilibili.com/video/BV1Z3411B7g7/)) to see the model's performance. 

<p align="center">
    <a href="https://youtu.be/Jvzltozpbpk">
        <img src="documentation/image/showreel.gif">
    </a>
</p>

All footage in the video is available on [Google Drive](https://drive.google.com/drive/folders/1VFnWwuu-YXDKG-N6vcjK_nL7YZMFapMU?usp=sharing).

<br>


## Demo
* [Webcam Demo](https://peterl1n.github.io/RobustVideoMatting/#/demo): Run the model live in your browser. Visualize recurrent states.
* [Colab Demo](https://colab.research.google.com/drive/10z-pNKRnVNsp0Lq9tH1J_XPZ7CBC_uHm?usp=sharing): Test our model on your own videos with free GPU. 

<br>

## Download

We recommend the MobileNetV3 models for most use cases. The ResNet50 models are a larger variant with small performance improvements. Our models are available for various inference frameworks. See the [inference documentation](documentation/inference.md) for more instructions.

<table>
    <thead>
        <tr>
            <td>Framework</td>
            <td>Download</td>
            <td>Notes</td>
        </tr>
    </thead>
    <tbody>
        <tr>
            <td>PyTorch</td>
            <td>
                <a  href="https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_mobilenetv3.pth">rvm_mobilenetv3.pth</a><br>
                <a  href="https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_resnet50.pth">rvm_resnet50.pth</a>
            </td>
            <td>
                Official weights for PyTorch. <a href="documentation/inference.md#pytorch">Doc</a>
            </td>
        </tr>
        <tr>
            <td>TorchHub</td>
            <td>
                Nothing to download.
            </td>
            <td>
                Easiest way to use our model in your PyTorch project. <a href="documentation/inference.md#torchhub">Doc</a>
            </td>
        </tr>
        <tr>
            <td>TorchScript</td>
            <td>
                <a  href="https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_mobilenetv3_fp32.torchscript">rvm_mobilenetv3_fp32.torchscript</a><br>
                <a  href="https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_mobilenetv3_fp16.torchscript">rvm_mobilenetv3_fp16.torchscript</a><br>
                <a  href="https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_resnet50_fp32.torchscript">rvm_resnet50_fp32.torchscript</a><br>
                <a  href="https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_resnet50_fp16.torchscript">rvm_resnet50_fp16.torchscript</a>
            </td>
            <td>
                For mobile inference, consider exporting int8 quantized models yourself. <a href="documentation/inference.md#torchscript">Doc</a>
            </td>
        </tr>
        <tr>
            <td>ONNX</td>
            <td>
                <a  href="https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_mobilenetv3_fp32.onnx">rvm_mobilenetv3_fp32.onnx</a><br>
                <a  href="https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_mobilenetv3_fp16.onnx">rvm_mobilenetv3_fp16.onnx</a><br>
                <a  href="https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_resnet50_fp32.onnx">rvm_resnet50_fp32.onnx</a><br>
                <a  href="https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_resnet50_fp16.onnx">rvm_resnet50_fp16.onnx</a>
            </td>
            <td>
                Tested on ONNX Runtime with CPU and CUDA backends. Provided models use opset 12. <a href="documentation/inference.md#onnx">Doc</a>, <a href="https://github.com/PeterL1n/RobustVideoMatting/tree/onnx">Exporter</a>.
            </td>
        </tr>
        <tr>
            <td>TensorFlow</td>
            <td>
                <a  href="https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_mobilenetv3_tf.zip">rvm_mobilenetv3_tf.zip</a><br>
                <a  href="https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_resnet50_tf.zip">rvm_resnet50_tf.zip</a>
            </td>
            <td>
                TensorFlow 2 SavedModel. <a href="documentation/inference.md#tensorflow">Doc</a>
            </td>
        </tr>
        <tr>
            <td>TensorFlow.js</td>
            <td>
                <a  href="https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_mobilenetv3_tfjs_int8.zip">rvm_mobilenetv3_tfjs_int8.zip</a><br>
            </td>
            <td>
                Run the model on the web. <a href="https://peterl1n.github.io/RobustVideoMatting/#/demo">Demo</a>, <a href="https://github.com/PeterL1n/RobustVideoMatting/tree/tfjs">Starter Code</a>
            </td>
        </tr>
        <tr>
            <td>CoreML</td>
            <td>
                <a  href="https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_mobilenetv3_1280x720_s0.375_fp16.mlmodel">rvm_mobilenetv3_1280x720_s0.375_fp16.mlmodel</a><br>
                <a  href="https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_mobilenetv3_1280x720_s0.375_int8.mlmodel">rvm_mobilenetv3_1280x720_s0.375_int8.mlmodel</a><br>
                <a  href="https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_mobilenetv3_1920x1080_s0.25_fp16.mlmodel">rvm_mobilenetv3_1920x1080_s0.25_fp16.mlmodel</a><br>
                <a  href="https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_mobilenetv3_1920x1080_s0.25_int8.mlmodel">rvm_mobilenetv3_1920x1080_s0.25_int8.mlmodel</a><br>
            </td>
            <td>
                CoreML does not support dynamic resolution. You can export other resolutions yourself. Models require iOS 13+. <code>s</code> denotes <code>downsample_ratio</code>. <a href="documentation/inference.md#coreml">Doc</a>, <a href="https://github.com/PeterL1n/RobustVideoMatting/tree/coreml">Exporter</a>
            </td>
        </tr>
    </tbody>
</table>

All models are available in [Google Drive](https://drive.google.com/drive/folders/1pBsG-SCTatv-95SnEuxmnvvlRx208VKj?usp=sharing) and [Baidu Pan](https://pan.baidu.com/s/1puPSxQqgBFOVpW4W7AolkA) (code: gym7).

<br>

## PyTorch Example

1. Install dependencies:
```sh
pip install -r requirements_inference.txt
```

2. Load the model:

```python
import torch
from model import MattingNetwork

model = MattingNetwork('mobilenetv3').eval().cuda()  # or "resnet50"
model.load_state_dict(torch.load('rvm_mobilenetv3.pth'))
```

3. To convert videos, we provide a simple conversion API:

```python
from inference import convert_video

convert_video(
    model,                           # The model, can be on any device (cpu or cuda).
    input_source='input.mp4',        # A video file or an image sequence directory.
    output_type='video',             # Choose "video" or "png_sequence"
    output_composition='com.mp4',    # File path if video; directory path if png sequence.
    output_alpha="pha.mp4",          # [Optional] Output the raw alpha prediction.
    output_foreground="fgr.mp4",     # [Optional] Output the raw foreground prediction.
    output_video_mbps=4,             # Output video bitrate in Mbps. Not needed for png sequence.
    downsample_ratio=None,           # A hyperparameter to adjust or use None for auto.
    seq_chunk=12,                    # Process n frames at once for better parallelism.
)
```
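Passing `downsample_ratio=None` asks the converter to choose a ratio automatically. The exact rule lives in `inference.py`, which is not shown in this section. The sketch below assumes the rule scales the longer side of the frame to about 512 px and never upsamples. The function name `auto_downsample_ratio` and the 512 target are both assumptions made for illustration.

```python
def auto_downsample_ratio(height: int, width: int, target: int = 512) -> float:
    """Hypothetical rule: scale the longer side to about `target` px, never upsample."""
    return min(target / max(height, width), 1.0)


for h, w in [(1080, 1920), (2160, 3840), (360, 480)]:
    r = auto_downsample_ratio(h, w)
    # Print the chosen ratio and the approximate downsampled size it produces.
    print(f"{w}x{h}: ratio={r:.4f}, internal size ~{round(w * r)}x{round(h * r)}")
```

Under this assumed rule, HD input gets a ratio of about 0.267, which is close to the 0.25 used for HD in the speed benchmarks below. Small inputs keep a ratio of 1.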

4. Or write your own inference code:
```python
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from inference_utils import VideoReader, VideoWriter

reader = VideoReader('input.mp4', transform=ToTensor())
writer = VideoWriter('output.mp4', frame_rate=30)

bgr = torch.tensor([.47, 1, .6]).view(3, 1, 1).cuda()  # Green background.
rec = [None] * 4                                       # Initial recurrent states.
downsample_ratio = 0.25                                # Adjust based on your video.

with torch.no_grad():
    for src in DataLoader(reader):                     # RGB tensor normalized to 0 ~ 1.
        fgr, pha, *rec = model(src.cuda(), *rec, downsample_ratio)  # Cycle the recurrent states.
        com = fgr * pha + bgr * (1 - pha)              # Composite onto the green background.
        writer.write(com)                              # Write frame.

writer.close()                                         # Flush and finalize the output file.
```
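The loop above feeds each call's four recurrent outputs back in as the next call's inputs, starting from `[None] * 4`. The pure-Python sketch below isolates that state-cycling pattern. `dummy_model` is a hypothetical stand-in, not `MattingNetwork`, and its "states" are simple counters rather than feature maps.

```python
from typing import List, Optional


def dummy_model(frame, r1, r2, r3, r4, downsample_ratio):
    """Hypothetical stand-in for MattingNetwork: returns (fgr, pha, r1, r2, r3, r4)."""
    # Each state increments once per processed frame; None means "no memory yet".
    new_states = [0 if s is None else s + 1 for s in (r1, r2, r3, r4)]
    fgr, pha = frame, 1.0  # Placeholder predictions.
    return (fgr, pha, *new_states)


rec: List[Optional[int]] = [None] * 4  # Initial recurrent states, as in the README.
for frame in [0.1, 0.2, 0.3]:
    # Unpack the two predictions and rebind `rec` to the returned states.
    fgr, pha, *rec = dummy_model(frame, *rec, 0.25)

print(rec)  # [2, 2, 2, 2]
```

Resetting `rec` to `[None] * 4` before a new, unrelated video keeps memory from one clip from leaking into the next.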

5. The models and converter API are also available through TorchHub.

```python
# Load the model.
model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3") # or "resnet50"

# Converter API.
convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter")
```

Please see the [inference documentation](documentation/inference.md) for details on the `downsample_ratio` hyperparameter, additional converter arguments, and more advanced usage.

<br>

## Training and Evaluation

Please refer to the [training documentation](documentation/training.md) to train and evaluate your own model.

<br>

## Speed

Speed is measured with `inference_speed_test.py` for reference.

| GPU            | dType | HD (1920x1080) | 4K (3840x2160) |
| -------------- | ----- | -------------- |----------------|
| RTX 3090       | FP16  | 172 FPS        | 154 FPS        |
| RTX 2060 Super | FP16  | 134 FPS        | 108 FPS        |
| GTX 1080 Ti    | FP32  | 104 FPS        | 74 FPS         |

* Note 1: HD uses `downsample_ratio=0.25`, 4K uses `downsample_ratio=0.125`. All tests use batch size 1 and frame chunk 1.
* Note 2: GPUs older than the Turing architecture lack fast FP16 support, so the GTX 1080 Ti uses FP32.
* Note 3: We only measure tensor throughput. The video conversion script provided in this repo is expected to be much slower, because it does not use hardware video encoding/decoding and does not perform tensor transfers on parallel threads. If you are interested in implementing hardware video encoding/decoding in Python, please refer to [PyNvCodec](https://github.com/NVIDIA/VideoProcessingFramework).
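Note 1 implies that the HD and 4K benchmarks feed the same downsampled size to the low-resolution pass. The arithmetic below checks this and converts the GTX 1080 Ti FPS figures into per-frame latency. The helper names are hypothetical. Reading the matching sizes as the reason 4K is not four times slower than HD is our interpretation, not a measured result.

```python
def low_res_size(width: int, height: int, downsample_ratio: float) -> tuple:
    """Size of the downsampled input for a given frame size and ratio."""
    return round(width * downsample_ratio), round(height * downsample_ratio)


def ms_per_frame(fps: float) -> float:
    """Convert throughput (frames per second) into time per frame in milliseconds."""
    return 1000.0 / fps


# Settings from Note 1 and the GTX 1080 Ti row of the table above.
for name, (w, h, ratio, fps) in {
    "HD": (1920, 1080, 0.25, 104),
    "4K": (3840, 2160, 0.125, 74),
}.items():
    lw, lh = low_res_size(w, h, ratio)
    print(f"{name}: low-res {lw}x{lh}, {ms_per_frame(fps):.1f} ms/frame")
# HD: low-res 480x270, 9.6 ms/frame
# 4K: low-res 480x270, 13.5 ms/frame
```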

<br>  

## Project Members
* [Shanchuan Lin](https://www.linkedin.com/in/shanchuanlin/)
* [Linjie Yang](https://sites.google.com/site/linjieyang89/)
* [Imran Saleemi](https://www.linkedin.com/in/imran-saleemi/)
* [Soumyadip Sengupta](https://homes.cs.washington.edu/~soumya91/)

<br>

## Third-Party Projects

* [NCNN C++ Android](https://github.com/FeiGeChuanShu/ncnn_Android_RobustVideoMatting) ([@FeiGeChuanShu](https://github.com/FeiGeChuanShu))
* [lite.ai.toolkit](https://github.com/DefTruth/RobustVideoMatting.lite.ai.toolkit) ([@DefTruth](https://github.com/DefTruth))
* [Gradio Web Demo](https://huggingface.co/spaces/akhaliq/Robust-Video-Matting) ([@AK391](https://github.com/AK391))
* [Unity Engine demo with NatML](https://hub.natml.ai/@natsuite/robust-video-matting) ([@natsuite](https://github.com/natsuite))  
* [MNN C++ Demo](https://github.com/DefTruth/lite.ai.toolkit/blob/main/lite/mnn/cv/mnn_rvm.cpp) ([@DefTruth](https://github.com/DefTruth))
* [TNN C++ Demo](https://github.com/DefTruth/lite.ai.toolkit/blob/main/lite/tnn/cv/tnn_rvm.cpp) ([@DefTruth](https://github.com/DefTruth))



================================================
FILE: README_zh_Hans.md
================================================
# Robust Video Matting (RVM)

![Teaser](/documentation/image/teaser.gif)

<p align="center"><a href="README.md">English</a> | 中文</p>

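Official GitHub repository for the paper [Robust High-Resolution Video Matting with Temporal Guidance](https://peterl1n.github.io/RobustVideoMatting/). RVM is designed specifically for robust human video matting. Unlike existing neural networks that process each frame as an independent image, RVM uses a recurrent neural network that keeps temporal memory while processing a video stream. RVM can perform real-time high-resolution matting on any video. It achieves **4K 76FPS** and **HD 104FPS** on an Nvidia GTX 1080Ti. This research project comes from [ByteDance](https://www.bytedance.com/).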

<br>

## News

* [Nov 3, 2021] Fixed a bug in [train.py](https://github.com/PeterL1n/RobustVideoMatting/commit/48effc91576a9e0e7a8519f3da687c0d3522045f).
* [Sep 16, 2021] Code re-released under the GPL-3.0 license.
* [Aug 25, 2021] Code and models published.
* [Jul 27, 2021] Paper accepted by WACV 2022.

<br>

## Showreel
Watch the showreel video ([YouTube](https://youtu.be/Jvzltozpbpk), [Bilibili](https://www.bilibili.com/video/BV1Z3411B7g7/)) to see what the model can do.
<p align="center">
    <a href="https://youtu.be/Jvzltozpbpk">
        <img src="documentation/image/showreel.gif">
    </a>
</p>

All footage in the video is available for download and can be used to test the model: [Google Drive](https://drive.google.com/drive/folders/1VFnWwuu-YXDKG-N6vcjK_nL7YZMFapMU?usp=sharing)

<br>


## Demo
* [Web](https://peterl1n.github.io/RobustVideoMatting/#/demo): See webcam matting in your browser, with a visualization of the model's internal recurrent states.
* [Colab](https://colab.research.google.com/drive/10z-pNKRnVNsp0Lq9tH1J_XPZ7CBC_uHm?usp=sharing): Convert your own videos with our model.

<br>

## Download

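We recommend the MobileNetV3 models in general. The ResNet50 models are much larger, with slightly better results. Our models support many frameworks. For details, please read the [inference documentation](documentation/inference_zh_Hans.md).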

<table>
    <thead>
        <tr>
            <td>Framework</td>
            <td>Download</td>
            <td>Notes</td>
        </tr>
    </thead>
    <tbody>
        <tr>
            <td>PyTorch</td>
            <td>
                <a  href="https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_mobilenetv3.pth">rvm_mobilenetv3.pth</a><br>
                <a  href="https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_resnet50.pth">rvm_resnet50.pth</a>
            </td>
            <td>
                Official PyTorch model weights. <a href="documentation/inference_zh_Hans.md#pytorch">Documentation</a>
            </td>
        </tr>
        <tr>
            <td>TorchHub</td>
            <td>
                No manual download needed.
            </td>
            <td>
                Use this model more conveniently in your PyTorch project. <a href="documentation/inference_zh_Hans.md#torchhub">Documentation</a>
            </td>
        </tr>
        <tr>
            <td>TorchScript</td>
            <td>
                <a  href="https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_mobilenetv3_fp32.torchscript">rvm_mobilenetv3_fp32.torchscript</a><br>
                <a  href="https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_mobilenetv3_fp16.torchscript">rvm_mobilenetv3_fp16.torchscript</a><br>
                <a  href="https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_resnet50_fp32.torchscript">rvm_resnet50_fp32.torchscript</a><br>
                <a  href="https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_resnet50_fp16.torchscript">rvm_resnet50_fp16.torchscript</a>
            </td>
            <td>
                For mobile inference, consider exporting int8 quantized models yourself. <a href="documentation/inference_zh_Hans.md#torchscript">Documentation</a>
            </td>
        </tr>
        <tr>
            <td>ONNX</td>
            <td>
                <a  href="https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_mobilenetv3_fp32.onnx">rvm_mobilenetv3_fp32.onnx</a><br>
                <a  href="https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_mobilenetv3_fp16.onnx">rvm_mobilenetv3_fp16.onnx</a><br>
                <a  href="https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_resnet50_fp32.onnx">rvm_resnet50_fp32.onnx</a><br>
                <a  href="https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_resnet50_fp16.onnx">rvm_resnet50_fp16.onnx</a>
            </td>
            <td>
                Tested on the CPU and CUDA backends of ONNX Runtime. The provided models use opset 12. <a href="documentation/inference_zh_Hans.md#onnx">Documentation</a>, <a href="https://github.com/PeterL1n/RobustVideoMatting/tree/onnx">Exporter</a>
            </td>
        </tr>
        <tr>
            <td>TensorFlow</td>
            <td>
                <a  href="https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_mobilenetv3_tf.zip">rvm_mobilenetv3_tf.zip</a><br>
                <a  href="https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_resnet50_tf.zip">rvm_resnet50_tf.zip</a>
            </td>
            <td>
                TensorFlow 2 SavedModel format. <a href="documentation/inference_zh_Hans.md#tensorflow">Documentation</a>
            </td>
        </tr>
        <tr>
            <td>TensorFlow.js</td>
            <td>
                <a  href="https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_mobilenetv3_tfjs_int8.zip">rvm_mobilenetv3_tfjs_int8.zip</a><br>
            </td>
            <td>
                Run the model on the web. <a href="https://peterl1n.github.io/RobustVideoMatting/#/demo">Demo</a>, <a href="https://github.com/PeterL1n/RobustVideoMatting/tree/tfjs">Sample code</a>
            </td>
        </tr>
        <tr>
            <td>CoreML</td>
            <td>
                <a  href="https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_mobilenetv3_1280x720_s0.375_fp16.mlmodel">rvm_mobilenetv3_1280x720_s0.375_fp16.mlmodel</a><br>
                <a  href="https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_mobilenetv3_1280x720_s0.375_int8.mlmodel">rvm_mobilenetv3_1280x720_s0.375_int8.mlmodel</a><br>
                <a  href="https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_mobilenetv3_1920x1080_s0.25_fp16.mlmodel">rvm_mobilenetv3_1920x1080_s0.25_fp16.mlmodel</a><br>
                <a  href="https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_mobilenetv3_1920x1080_s0.25_int8.mlmodel">rvm_mobilenetv3_1920x1080_s0.25_int8.mlmodel</a><br>
            </td>
            <td>
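                CoreML only supports fixed-resolution export; you can export other resolutions yourself. Requires iOS 13+. <code>s</code> denotes the downsample ratio. <a href="documentation/inference_zh_Hans.md#coreml">Documentation</a>, <a href="https://github.com/PeterL1n/RobustVideoMatting/tree/coreml">Exporter</a>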
            </td>
        </tr>
    </tbody>
</table>

All models can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1pBsG-SCTatv-95SnEuxmnvvlRx208VKj?usp=sharing) or [Baidu Pan](https://pan.baidu.com/s/1puPSxQqgBFOVpW4W7AolkA) (password: gym7).

<br>

## PyTorch Example

1. Install the Python dependencies:
```sh
pip install -r requirements_inference.txt
```

2. Load the model:

```python
import torch
from model import MattingNetwork

model = MattingNetwork('mobilenetv3').eval().cuda()  # or "resnet50"
model.load_state_dict(torch.load('rvm_mobilenetv3.pth'))
```

3. If you only need video matting, we provide a simple API:

```python
from inference import convert_video

convert_video(
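    model,                           # The model; can be on any device (cpu or cuda)
    input_source='input.mp4',        # A video file, or an image sequence directory
    output_type='video',             # "video" or "png_sequence" (PNG sequence)
    output_composition='com.mp4',    # File path if exporting video; directory path if exporting a PNG sequence
    output_alpha="pha.mp4",          # [Optional] Output the alpha prediction
    output_foreground="fgr.mp4",     # [Optional] Output the foreground prediction
    output_video_mbps=4,             # Video bitrate in Mbps, if exporting video
    downsample_ratio=None,           # Downsample ratio; tune per video, or None for auto
    seq_chunk=12,                    # Number of frames computed in parallel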
)
```

4. Or write your own inference logic:
```python
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from inference_utils import VideoReader, VideoWriter

reader = VideoReader('input.mp4', transform=ToTensor())
writer = VideoWriter('output.mp4', frame_rate=30)

bgr = torch.tensor([.47, 1, .6]).view(3, 1, 1).cuda()  # Green background
rec = [None] * 4                                       # Initial recurrent states
downsample_ratio = 0.25                                # Downsample ratio; tune per video

with torch.no_grad():
    for src in DataLoader(reader):                     # Input tensor, RGB channels, range 0~1
        fgr, pha, *rec = model(src.cuda(), *rec, downsample_ratio)  # Pass this frame's memory to the next frame
        com = fgr * pha + bgr * (1 - pha)              # Composite the foreground onto the green background
        writer.write(com)                              # Write frame

writer.close()                                         # Flush and finalize the output file
```

5. The models and API can also be loaded quickly through TorchHub.

```python
# Load the model
model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3") # or "resnet50"

# Conversion API
convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter")
```

The [inference documentation](documentation/inference_zh_Hans.md) explains the `downsample_ratio` parameter, API usage, and advanced usage.

<br>

## Training and Evaluation

Please refer to the [training documentation (in English)](documentation/training.md).

<br>

## Speed

Speed is measured with `inference_speed_test.py` for reference.

| GPU            | dType | HD (1920x1080) | 4K (3840x2160) |
| -------------- | ----- | -------------- |----------------|
| RTX 3090       | FP16  | 172 FPS        | 154 FPS        |
| RTX 2060 Super | FP16  | 134 FPS        | 108 FPS        |
| GTX 1080 Ti    | FP32  | 104 FPS        | 74 FPS         |

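* Note 1: HD uses `downsample_ratio=0.25`, and 4K uses `downsample_ratio=0.125`. All tests use batch size 1 and frame chunk 1.
* Note 2: GPUs older than the Turing architecture do not support FP16 inference, so the GTX 1080 Ti uses FP32.
* Note 3: We only measure tensor throughput. The provided video conversion script is much slower, because it does not use hardware video encoding/decoding and does not perform tensor transfers on parallel threads. If you are interested in implementing hardware video encoding/decoding in Python, please refer to [PyNvCodec](https://github.com/NVIDIA/VideoProcessingFramework).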

<br>

## Project Members
* [Shanchuan Lin](https://www.linkedin.com/in/shanchuanlin/)
* [Linjie Yang](https://sites.google.com/site/linjieyang89/)
* [Imran Saleemi](https://www.linkedin.com/in/imran-saleemi/)
* [Soumyadip Sengupta](https://homes.cs.washington.edu/~soumya91/)

<br>

## Third-Party Projects

* [NCNN C++ Android](https://github.com/FeiGeChuanShu/ncnn_Android_RobustVideoMatting) ([@FeiGeChuanShu](https://github.com/FeiGeChuanShu))
* [lite.ai.toolkit](https://github.com/DefTruth/RobustVideoMatting.lite.ai.toolkit) ([@DefTruth](https://github.com/DefTruth))
* [Gradio Web Demo](https://huggingface.co/spaces/akhaliq/Robust-Video-Matting) ([@AK391](https://github.com/AK391))
* [Unity Engine with NatML](https://hub.natml.ai/@natsuite/robust-video-matting) ([@natsuite](https://github.com/natsuite))
* [MNN C++ Demo](https://github.com/DefTruth/lite.ai.toolkit/blob/main/lite/mnn/cv/mnn_rvm.cpp) ([@DefTruth](https://github.com/DefTruth))
* [TNN C++ Demo](https://github.com/DefTruth/lite.ai.toolkit/blob/main/lite/tnn/cv/tnn_rvm.cpp) ([@DefTruth](https://github.com/DefTruth))



================================================
FILE: dataset/augmentation.py
================================================
import easing_functions as ef
import random
import torch
from torchvision import transforms
from torchvision.transforms import functional as F


class MotionAugmentation:
    def __init__(self,
                 size,
                 prob_fgr_affine,
                 prob_bgr_affine,
                 prob_noise,
                 prob_color_jitter,
                 prob_grayscale,
                 prob_sharpness,
                 prob_blur,
                 prob_hflip,
                 prob_pause,
                 static_affine=True,
                 aspect_ratio_range=(0.9, 1.1)):
        self.size = size
        self.prob_fgr_affine = prob_fgr_affine
        self.prob_bgr_affine = prob_bgr_affine
        self.prob_noise = prob_noise
        self.prob_color_jitter = prob_color_jitter
        self.prob_grayscale = prob_grayscale
        self.prob_sharpness = prob_sharpness
        self.prob_blur = prob_blur
        self.prob_hflip = prob_hflip
        self.prob_pause = prob_pause
        self.static_affine = static_affine
        self.aspect_ratio_range = aspect_ratio_range
        
    def __call__(self, fgrs, phas, bgrs):
        # Foreground affine
        if random.random() < self.prob_fgr_affine:
            fgrs, phas = self._motion_affine(fgrs, phas)

        # Background affine
        if random.random() < self.prob_bgr_affine / 2:
            bgrs = self._motion_affine(bgrs)
        if random.random() < self.prob_bgr_affine / 2:
            fgrs, phas, bgrs = self._motion_affine(fgrs, phas, bgrs)
                
        # Static affine
        if self.static_affine:
            fgrs, phas = self._static_affine(fgrs, phas, scale_ranges=(0.5, 1))
            bgrs = self._static_affine(bgrs, scale_ranges=(1, 1.5))
        
        # To tensor
        fgrs = torch.stack([F.to_tensor(fgr) for fgr in fgrs])
        phas = torch.stack([F.to_tensor(pha) for pha in phas])
        bgrs = torch.stack([F.to_tensor(bgr) for bgr in bgrs])
        
        # Resize
        params = transforms.RandomResizedCrop.get_params(fgrs, scale=(1, 1), ratio=self.aspect_ratio_range)
        fgrs = F.resized_crop(fgrs, *params, self.size, interpolation=F.InterpolationMode.BILINEAR)
        phas = F.resized_crop(phas, *params, self.size, interpolation=F.InterpolationMode.BILINEAR)
        params = transforms.RandomResizedCrop.get_params(bgrs, scale=(1, 1), ratio=self.aspect_ratio_range)
        bgrs = F.resized_crop(bgrs, *params, self.size, interpolation=F.InterpolationMode.BILINEAR)

        # Horizontal flip
        if random.random() < self.prob_hflip:
            fgrs = F.hflip(fgrs)
            phas = F.hflip(phas)
        if random.random() < self.prob_hflip:
            bgrs = F.hflip(bgrs)

        # Noise
        if random.random() < self.prob_noise:
            fgrs, bgrs = self._motion_noise(fgrs, bgrs)
        
        # Color jitter
        if random.random() < self.prob_color_jitter:
            fgrs = self._motion_color_jitter(fgrs)
        if random.random() < self.prob_color_jitter:
            bgrs = self._motion_color_jitter(bgrs)
            
        # Grayscale
        if random.random() < self.prob_grayscale:
            fgrs = F.rgb_to_grayscale(fgrs, num_output_channels=3).contiguous()
            bgrs = F.rgb_to_grayscale(bgrs, num_output_channels=3).contiguous()
            
        # Sharpen
        if random.random() < self.prob_sharpness:
            sharpness = random.random() * 8
            fgrs = F.adjust_sharpness(fgrs, sharpness)
            phas = F.adjust_sharpness(phas, sharpness)
            bgrs = F.adjust_sharpness(bgrs, sharpness)
        
        # Blur
        if random.random() < self.prob_blur / 3:
            fgrs, phas = self._motion_blur(fgrs, phas)
        if random.random() < self.prob_blur / 3:
            bgrs = self._motion_blur(bgrs)
        if random.random() < self.prob_blur / 3:
            fgrs, phas, bgrs = self._motion_blur(fgrs, phas, bgrs)

        # Pause
        if random.random() < self.prob_pause:
            fgrs, phas, bgrs = self._motion_pause(fgrs, phas, bgrs)
        
        return fgrs, phas, bgrs
    
    def _static_affine(self, *imgs, scale_ranges):
        params = transforms.RandomAffine.get_params(
            degrees=(-10, 10), translate=(0.1, 0.1), scale_ranges=scale_ranges,
            shears=(-5, 5), img_size=imgs[0][0].size)
        imgs = [[F.affine(t, *params, F.InterpolationMode.BILINEAR) for t in img] for img in imgs]
        return imgs if len(imgs) > 1 else imgs[0] 
    
    def _motion_affine(self, *imgs):
        config = dict(degrees=(-10, 10), translate=(0.1, 0.1),
                      scale_ranges=(0.9, 1.1), shears=(-5, 5), img_size=imgs[0][0].size)
        angleA, (transXA, transYA), scaleA, (shearXA, shearYA) = transforms.RandomAffine.get_params(**config)
        angleB, (transXB, transYB), scaleB, (shearXB, shearYB) = transforms.RandomAffine.get_params(**config)
        
        T = len(imgs[0])
        easing = random_easing_fn()
        for t in range(T):
            percentage = easing(t / (T - 1))
            angle = lerp(angleA, angleB, percentage)
            transX = lerp(transXA, transXB, percentage)
            transY = lerp(transYA, transYB, percentage)
            scale = lerp(scaleA, scaleB, percentage)
            shearX = lerp(shearXA, shearXB, percentage)
            shearY = lerp(shearYA, shearYB, percentage)
            for img in imgs:
                img[t] = F.affine(img[t], angle, (transX, transY), scale, (shearX, shearY), F.InterpolationMode.BILINEAR)
        return imgs if len(imgs) > 1 else imgs[0]
    
    def _motion_noise(self, *imgs):
        grain_size = random.random() * 3 + 1 # range 1 ~ 4
        monochrome = random.random() < 0.5
        for img in imgs:
            T, C, H, W = img.shape
            noise = torch.randn((T, 1 if monochrome else C, round(H / grain_size), round(W / grain_size)))
            noise.mul_(random.random() * 0.2 / grain_size)
            if grain_size != 1:
                noise = F.resize(noise, (H, W))
            img.add_(noise).clamp_(0, 1)
        return imgs if len(imgs) > 1 else imgs[0]
    
    def _motion_color_jitter(self, *imgs):
        brightnessA, brightnessB, contrastA, contrastB, saturationA, saturationB, hueA, hueB \
            = torch.randn(8).mul(0.1).tolist()
        strength = random.random() * 0.2
        easing = random_easing_fn()
        T = len(imgs[0])
        for t in range(T):
            percentage = easing(t / (T - 1)) * strength
            for img in imgs:
                img[t] = F.adjust_brightness(img[t], max(1 + lerp(brightnessA, brightnessB, percentage), 0.1))
                img[t] = F.adjust_contrast(img[t], max(1 + lerp(contrastA, contrastB, percentage), 0.1))
                img[t] = F.adjust_saturation(img[t], max(1 + lerp(saturationA, saturationB, percentage), 0.1))
                img[t] = F.adjust_hue(img[t], min(0.5, max(-0.5, lerp(hueA, hueB, percentage) * 0.1)))
        return imgs if len(imgs) > 1 else imgs[0]
    
    def _motion_blur(self, *imgs):
        blurA = random.random() * 10
        blurB = random.random() * 10

        T = len(imgs[0])
        easing = random_easing_fn()
        for t in range(T):
            percentage = easing(t / (T - 1))
            blur = max(lerp(blurA, blurB, percentage), 0)
            if blur != 0:
                kernel_size = int(blur * 2)
                if kernel_size % 2 == 0:
                    kernel_size += 1 # Make kernel_size odd
                for img in imgs:
                    img[t] = F.gaussian_blur(img[t], kernel_size, sigma=blur)
    
        return imgs if len(imgs) > 1 else imgs[0]
    
    def _motion_pause(self, *imgs):
        T = len(imgs[0])
        pause_frame = random.choice(range(T - 1))
        pause_length = random.choice(range(T - pause_frame))
        for img in imgs:
            img[pause_frame + 1 : pause_frame + pause_length] = img[pause_frame]
        return imgs if len(imgs) > 1 else imgs[0]
    

def lerp(a, b, percentage):
    return a * (1 - percentage) + b * percentage


def random_easing_fn():
    if random.random() < 0.2:
        return ef.LinearInOut()
    else:
        return random.choice([
            ef.BackEaseIn,
            ef.BackEaseOut,
            ef.BackEaseInOut,
            ef.BounceEaseIn,
            ef.BounceEaseOut,
            ef.BounceEaseInOut,
            ef.CircularEaseIn,
            ef.CircularEaseOut,
            ef.CircularEaseInOut,
            ef.CubicEaseIn,
            ef.CubicEaseOut,
            ef.CubicEaseInOut,
            ef.ExponentialEaseIn,
            ef.ExponentialEaseOut,
            ef.ExponentialEaseInOut,
            ef.ElasticEaseIn,
            ef.ElasticEaseOut,
            ef.ElasticEaseInOut,
            ef.QuadEaseIn,
            ef.QuadEaseOut,
            ef.QuadEaseInOut,
            ef.QuarticEaseIn,
            ef.QuarticEaseOut,
            ef.QuarticEaseInOut,
            ef.QuinticEaseIn,
            ef.QuinticEaseOut,
            ef.QuinticEaseInOut,
            ef.SineEaseIn,
            ef.SineEaseOut,
            ef.SineEaseInOut,
            Step,
        ])()

class Step: # Custom easing function for sudden change.
    def __call__(self, value):
        return 0 if value < 0.5 else 1


# ---------------------------- Frame Sampler ----------------------------


class TrainFrameSampler:
    def __init__(self, speed=(0.5, 1, 2, 3, 4, 5)):
        self.speed = speed
    
    def __call__(self, seq_length):
        frames = list(range(seq_length))
        
        # Speed up
        speed = random.choice(self.speed)
        frames = [int(f * speed) for f in frames]
        
        # Shift
        shift = random.choice(range(seq_length))
        frames = [f + shift for f in frames]
        
        # Reverse
        if random.random() < 0.5:
            frames = frames[::-1]

        return frames
    
class ValidFrameSampler:
    def __call__(self, seq_length):
        return range(seq_length)


================================================
FILE: dataset/coco.py
================================================
import os
import numpy as np
import random
import json
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms import functional as F
from PIL import Image


class CocoPanopticDataset(Dataset):
    def __init__(self,
                 imgdir: str,
                 anndir: str,
                 annfile: str,
                 transform=None):
        with open(annfile) as f:
            self.data = json.load(f)['annotations']
            self.data = list(filter(lambda data: any(info['category_id'] == 1 for info in data['segments_info']), self.data))
        self.imgdir = imgdir
        self.anndir = anndir
        self.transform = transform
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        data = self.data[idx]
        img = self._load_img(data)
        seg = self._load_seg(data)
        
        if self.transform is not None:
            img, seg = self.transform(img, seg)
            
        return img, seg

    def _load_img(self, data):
        with Image.open(os.path.join(self.imgdir, data['file_name'].replace('.png', '.jpg'))) as img:
            return img.convert('RGB')
    
    def _load_seg(self, data):
        with Image.open(os.path.join(self.anndir, data['file_name'])) as ann:
            ann.load()
            
        ann = np.asarray(ann).astype(np.int32)
        ann = ann[:, :, 0] + 256 * ann[:, :, 1] + 256 * 256 * ann[:, :, 2]
        seg = np.zeros(ann.shape, np.uint8)
        
        for segments_info in data['segments_info']:
            if segments_info['category_id'] in [1, 27, 32]: # person, backpack, tie
                seg[ann == segments_info['id']] = 255
        
        return Image.fromarray(seg)
    

class CocoPanopticTrainAugmentation:
    def __init__(self, size):
        self.size = size
        self.jitter = transforms.ColorJitter(0.1, 0.1, 0.1, 0.1)
    
    def __call__(self, img, seg):
        # Affine
        params = transforms.RandomAffine.get_params(degrees=(-20, 20), translate=(0.1, 0.1),
                                                    scale_ranges=(1, 1), shears=(-10, 10), img_size=img.size)
        img = F.affine(img, *params, interpolation=F.InterpolationMode.BILINEAR)
        seg = F.affine(seg, *params, interpolation=F.InterpolationMode.NEAREST)
        
        # Resize
        params = transforms.RandomResizedCrop.get_params(img, scale=(0.5, 1), ratio=(0.7, 1.3))
        img = F.resized_crop(img, *params, self.size, interpolation=F.InterpolationMode.BILINEAR)
        seg = F.resized_crop(seg, *params, self.size, interpolation=F.InterpolationMode.NEAREST)
        
        # Horizontal flip
        if random.random() < 0.5:
            img = F.hflip(img)
            seg = F.hflip(seg)
        
        # Color jitter
        img = self.jitter(img)
        
        # To tensor
        img = F.to_tensor(img)
        seg = F.to_tensor(seg)
        
        return img, seg
    

class CocoPanopticValidAugmentation:
    def __init__(self, size):
        self.size = size
    
    def __call__(self, img, seg):
        # Resize
        params = transforms.RandomResizedCrop.get_params(img, scale=(1, 1), ratio=(1., 1.))
        img = F.resized_crop(img, *params, self.size, interpolation=F.InterpolationMode.BILINEAR)
        seg = F.resized_crop(seg, *params, self.size, interpolation=F.InterpolationMode.NEAREST)
        
        # To tensor
        img = F.to_tensor(img)
        seg = F.to_tensor(seg)
        
        return img, seg

================================================
FILE: dataset/imagematte.py
================================================
import os
import random
from torch.utils.data import Dataset
from PIL import Image

from .augmentation import MotionAugmentation


class ImageMatteDataset(Dataset):
    def __init__(self,
                 imagematte_dir,
                 background_image_dir,
                 background_video_dir,
                 size,
                 seq_length,
                 seq_sampler,
                 transform):
        self.imagematte_dir = imagematte_dir
        self.imagematte_files = os.listdir(os.path.join(imagematte_dir, 'fgr'))
        self.background_image_dir = background_image_dir
        self.background_image_files = os.listdir(background_image_dir)
        self.background_video_dir = background_video_dir
        self.background_video_clips = os.listdir(background_video_dir)
        self.background_video_frames = [sorted(os.listdir(os.path.join(background_video_dir, clip)))
                                        for clip in self.background_video_clips]
        self.seq_length = seq_length
        self.seq_sampler = seq_sampler
        self.size = size
        self.transform = transform
        
    def __len__(self):
        return max(len(self.imagematte_files), len(self.background_image_files) + len(self.background_video_clips))
    
    def __getitem__(self, idx):
        if random.random() < 0.5:
            bgrs = self._get_random_image_background()
        else:
            bgrs = self._get_random_video_background()
        
        fgrs, phas = self._get_imagematte(idx)
        
        if self.transform is not None:
            return self.transform(fgrs, phas, bgrs)
        
        return fgrs, phas, bgrs
    
    def _get_imagematte(self, idx):
        with Image.open(os.path.join(self.imagematte_dir, 'fgr', self.imagematte_files[idx % len(self.imagematte_files)])) as fgr, \
             Image.open(os.path.join(self.imagematte_dir, 'pha', self.imagematte_files[idx % len(self.imagematte_files)])) as pha:
            fgr = self._downsample_if_needed(fgr.convert('RGB'))
            pha = self._downsample_if_needed(pha.convert('L'))
        fgrs = [fgr] * self.seq_length
        phas = [pha] * self.seq_length
        return fgrs, phas
    
    def _get_random_image_background(self):
        with Image.open(os.path.join(self.background_image_dir, random.choice(self.background_image_files))) as bgr:
            bgr = self._downsample_if_needed(bgr.convert('RGB'))
        bgrs = [bgr] * self.seq_length
        return bgrs

    def _get_random_video_background(self):
        clip_idx = random.choice(range(len(self.background_video_clips)))
        frame_count = len(self.background_video_frames[clip_idx])
        frame_idx = random.choice(range(max(1, frame_count - self.seq_length)))
        clip = self.background_video_clips[clip_idx]
        bgrs = []
        for i in self.seq_sampler(self.seq_length):
            frame_idx_t = frame_idx + i
            frame = self.background_video_frames[clip_idx][frame_idx_t % frame_count]
            with Image.open(os.path.join(self.background_video_dir, clip, frame)) as bgr:
                bgr = self._downsample_if_needed(bgr.convert('RGB'))
            bgrs.append(bgr)
        return bgrs
    
    def _downsample_if_needed(self, img):
        w, h = img.size
        if min(w, h) > self.size:
            scale = self.size / min(w, h)
            w = int(scale * w)
            h = int(scale * h)
            img = img.resize((w, h))
        return img

class ImageMatteAugmentation(MotionAugmentation):
    def __init__(self, size):
        super().__init__(
            size=size,
            prob_fgr_affine=0.95,
            prob_bgr_affine=0.3,
            prob_noise=0.05,
            prob_color_jitter=0.3,
            prob_grayscale=0.03,
            prob_sharpness=0.05,
            prob_blur=0.02,
            prob_hflip=0.5,
            prob_pause=0.03,
        )


================================================
FILE: dataset/spd.py
================================================
import os
from torch.utils.data import Dataset
from PIL import Image


class SuperviselyPersonDataset(Dataset):
    def __init__(self, imgdir, segdir, transform=None):
        self.img_dir = imgdir
        self.img_files = sorted(os.listdir(imgdir))
        self.seg_dir = segdir
        self.seg_files = sorted(os.listdir(segdir))
        assert len(self.img_files) == len(self.seg_files)
        self.transform = transform
        
    def __len__(self):
        return len(self.img_files)
    
    def __getitem__(self, idx):
        with Image.open(os.path.join(self.img_dir, self.img_files[idx])) as img, \
             Image.open(os.path.join(self.seg_dir, self.seg_files[idx])) as seg:
            img = img.convert('RGB')
            seg = seg.convert('L')
        
        if self.transform is not None:
            img, seg = self.transform(img, seg)
            
        return img, seg


================================================
FILE: dataset/videomatte.py
================================================
import os
import random
from torch.utils.data import Dataset
from PIL import Image

from .augmentation import MotionAugmentation


class VideoMatteDataset(Dataset):
    def __init__(self,
                 videomatte_dir,
                 background_image_dir,
                 background_video_dir,
                 size,
                 seq_length,
                 seq_sampler,
                 transform=None):
        self.background_image_dir = background_image_dir
        self.background_image_files = os.listdir(background_image_dir)
        self.background_video_dir = background_video_dir
        self.background_video_clips = sorted(os.listdir(background_video_dir))
        self.background_video_frames = [sorted(os.listdir(os.path.join(background_video_dir, clip)))
                                        for clip in self.background_video_clips]
        
        self.videomatte_dir = videomatte_dir
        self.videomatte_clips = sorted(os.listdir(os.path.join(videomatte_dir, 'fgr')))
        self.videomatte_frames = [sorted(os.listdir(os.path.join(videomatte_dir, 'fgr', clip))) 
                                  for clip in self.videomatte_clips]
        self.videomatte_idx = [(clip_idx, frame_idx) 
                               for clip_idx in range(len(self.videomatte_clips)) 
                               for frame_idx in range(0, len(self.videomatte_frames[clip_idx]), seq_length)]
        self.size = size
        self.seq_length = seq_length
        self.seq_sampler = seq_sampler
        self.transform = transform

    def __len__(self):
        return len(self.videomatte_idx)
    
    def __getitem__(self, idx):
        if random.random() < 0.5:
            bgrs = self._get_random_image_background()
        else:
            bgrs = self._get_random_video_background()
        
        fgrs, phas = self._get_videomatte(idx)
        
        if self.transform is not None:
            return self.transform(fgrs, phas, bgrs)
        
        return fgrs, phas, bgrs
    
    def _get_random_image_background(self):
        with Image.open(os.path.join(self.background_image_dir, random.choice(self.background_image_files))) as bgr:
            bgr = self._downsample_if_needed(bgr.convert('RGB'))
        bgrs = [bgr] * self.seq_length
        return bgrs
    
    def _get_random_video_background(self):
        clip_idx = random.choice(range(len(self.background_video_clips)))
        frame_count = len(self.background_video_frames[clip_idx])
        frame_idx = random.choice(range(max(1, frame_count - self.seq_length)))
        clip = self.background_video_clips[clip_idx]
        bgrs = []
        for i in self.seq_sampler(self.seq_length):
            frame_idx_t = frame_idx + i
            frame = self.background_video_frames[clip_idx][frame_idx_t % frame_count]
            with Image.open(os.path.join(self.background_video_dir, clip, frame)) as bgr:
                bgr = self._downsample_if_needed(bgr.convert('RGB'))
            bgrs.append(bgr)
        return bgrs
    
    def _get_videomatte(self, idx):
        clip_idx, frame_idx = self.videomatte_idx[idx]
        clip = self.videomatte_clips[clip_idx]
        frame_count = len(self.videomatte_frames[clip_idx])
        fgrs, phas = [], []
        for i in self.seq_sampler(self.seq_length):
            frame = self.videomatte_frames[clip_idx][(frame_idx + i) % frame_count]
            with Image.open(os.path.join(self.videomatte_dir, 'fgr', clip, frame)) as fgr, \
                 Image.open(os.path.join(self.videomatte_dir, 'pha', clip, frame)) as pha:
                fgr = self._downsample_if_needed(fgr.convert('RGB'))
                pha = self._downsample_if_needed(pha.convert('L'))
            fgrs.append(fgr)
            phas.append(pha)
        return fgrs, phas
    
    def _downsample_if_needed(self, img):
        w, h = img.size
        if min(w, h) > self.size:
            scale = self.size / min(w, h)
            w = int(scale * w)
            h = int(scale * h)
            img = img.resize((w, h))
        return img

class VideoMatteTrainAugmentation(MotionAugmentation):
    def __init__(self, size):
        super().__init__(
            size=size,
            prob_fgr_affine=0.3,
            prob_bgr_affine=0.3,
            prob_noise=0.1,
            prob_color_jitter=0.3,
            prob_grayscale=0.02,
            prob_sharpness=0.1,
            prob_blur=0.02,
            prob_hflip=0.5,
            prob_pause=0.03,
        )

class VideoMatteValidAugmentation(MotionAugmentation):
    def __init__(self, size):
        super().__init__(
            size=size,
            prob_fgr_affine=0,
            prob_bgr_affine=0,
            prob_noise=0,
            prob_color_jitter=0,
            prob_grayscale=0,
            prob_sharpness=0,
            prob_blur=0,
            prob_hflip=0,
            prob_pause=0,
        )


================================================
FILE: dataset/youtubevis.py
================================================
import torch
import os
import json
import numpy as np
import random
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms
from torchvision.transforms import functional as F


class YouTubeVISDataset(Dataset):
    def __init__(self, videodir, annfile, size, seq_length, seq_sampler, transform=None):
        self.videodir = videodir
        self.size = size
        self.seq_length = seq_length
        self.seq_sampler = seq_sampler
        self.transform = transform
        
        with open(annfile) as f:
            data = json.load(f)

        self.masks = {}
        for ann in data['annotations']:
            if ann['category_id'] == 26: # person
                video_id = ann['video_id']
                if video_id not in self.masks:
                    self.masks[video_id] = [[] for _ in range(len(ann['segmentations']))]
                for frame, mask in zip(self.masks[video_id], ann['segmentations']):
                    if mask is not None:
                        frame.append(mask)
        
        self.videos = {}
        for video in data['videos']:
            video_id = video['id']
            if video_id in self.masks:
                self.videos[video_id] = video
        
        self.index = []
        for video_id in self.videos.keys():
            for frame in range(len(self.videos[video_id]['file_names'])):
                self.index.append((video_id, frame))
                
    def __len__(self):
        return len(self.index)
    
    def __getitem__(self, idx):
        video_id, frame_id = self.index[idx]
        video = self.videos[video_id]
        frame_count = len(self.videos[video_id]['file_names'])
        H, W = video['height'], video['width']
        
        imgs, segs = [], []
        for t in self.seq_sampler(self.seq_length):
            frame = (frame_id + t) % frame_count

            filename = video['file_names'][frame]
            masks = self.masks[video_id][frame]
        
            with Image.open(os.path.join(self.videodir, filename)) as img:
                imgs.append(self._downsample_if_needed(img.convert('RGB'), Image.BILINEAR))
        
            seg = np.zeros((H, W), dtype=np.uint8)
            for mask in masks:
                seg |= self._decode_rle(mask)
            segs.append(self._downsample_if_needed(Image.fromarray(seg), Image.NEAREST))
            
        if self.transform is not None:
            imgs, segs = self.transform(imgs, segs)
        
        return imgs, segs
    
    def _decode_rle(self, rle):
        H, W = rle['size']
        msk = np.zeros(H * W, dtype=np.uint8)
        encoding = rle['counts']
        skip = 0
        for i in range(0, len(encoding) - 1, 2):
            skip += encoding[i]
            draw = encoding[i + 1]
            msk[skip : skip + draw] = 255
            skip += draw
        return msk.reshape(W, H).transpose()
    
    def _downsample_if_needed(self, img, resample):
        w, h = img.size
        if min(w, h) > self.size:
            scale = self.size / min(w, h)
            w = int(scale * w)
            h = int(scale * h)
            img = img.resize((w, h), resample)
        return img


class YouTubeVISAugmentation:
    def __init__(self, size):
        self.size = size
        self.jitter = transforms.ColorJitter(0.3, 0.3, 0.3, 0.15)
    
    def __call__(self, imgs, segs):
        
        # To tensor
        imgs = torch.stack([F.to_tensor(img) for img in imgs])
        segs = torch.stack([F.to_tensor(seg) for seg in segs])
        
        # Resize
        params = transforms.RandomResizedCrop.get_params(imgs, scale=(0.8, 1), ratio=(0.9, 1.1))
        imgs = F.resized_crop(imgs, *params, self.size, interpolation=F.InterpolationMode.BILINEAR)
        segs = F.resized_crop(segs, *params, self.size, interpolation=F.InterpolationMode.BILINEAR)
        
        # Color jitter
        imgs = self.jitter(imgs)
        
        # Grayscale
        if random.random() < 0.05:
            imgs = F.rgb_to_grayscale(imgs, num_output_channels=3)
        
        # Horizontal flip
        if random.random() < 0.5:
            imgs = F.hflip(imgs)
            segs = F.hflip(segs)
        
        return imgs, segs


================================================
FILE: documentation/inference.md
================================================
# Inference

<p align="center">English | <a href="inference_zh_Hans.md">中文</a></p>

## Content

* [Concepts](#concepts)
    * [Downsample Ratio](#downsample-ratio)
    * [Recurrent States](#recurrent-states)
* [PyTorch](#pytorch)
* [TorchHub](#torchhub)
* [TorchScript](#torchscript)
* [ONNX](#onnx)
* [TensorFlow](#tensorflow)
* [TensorFlow.js](#tensorflowjs)
* [CoreML](#coreml)

<br>


## Concepts

### Downsample Ratio

The table provides a general guideline. Please adjust based on your video content.

| Resolution    | Portrait      | Full-Body      |
| ------------- | ------------- | -------------- |
| <= 512x512    | 1             | 1              |
| 1280x720      | 0.375         | 0.6            |
| 1920x1080     | 0.25          | 0.4            |
| 3840x2160     | 0.125         | 0.2            |

Internally, the model downsamples the input for stage 1, then refines the result at high resolution in stage 2.

Set `downsample_ratio` so that the downsampled resolution is between 256 and 512. For example, for `1920x1080` input with `downsample_ratio=0.25`, the resized resolution `480x270` is between 256 and 512.

Adjust `downsample_ratio` based on the video content. If the shot is a portrait, a lower `downsample_ratio` is sufficient. If the shot contains the full human body, use a higher `downsample_ratio`. Note that a higher `downsample_ratio` is not always better.
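
The rule of thumb above can be checked with a little arithmetic. This is a sketch: `downsampled_size` and `max_side_ratio` are hypothetical helpers (not part of the repository), and the exact rounding the model applies internally may differ. `max_side_ratio` mirrors the behavior documented for `convert_video` when `downsample_ratio=None`, i.e. a downsampled max side of 512px.

```python
def downsampled_size(width, height, downsample_ratio):
    """Approximate resolution of the stage-1 (low-resolution) pass."""
    return round(width * downsample_ratio), round(height * downsample_ratio)


def max_side_ratio(width, height, max_side=512):
    """Hypothetical helper: largest ratio whose longer side fits in max_side.

    The ratio is capped at 1, so small inputs are never upsampled.
    """
    return min(max_side / max(width, height), 1.0)


print(downsampled_size(1920, 1080, 0.25))  # (480, 270), both sides in 256~512
print(max_side_ratio(640, 480))            # 0.8
print(max_side_ratio(320, 240))            # 1.0
```

These numbers are only a starting point; as noted above, the best value still depends on the content of the shot.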


<br>

### Recurrent States
The model is a recurrent neural network. You must process frames sequentially and recycle its recurrent states. 

**Correct Way**

The recurrent outputs are recycled back as input when processing the next frame. The states are essentially the model's memory.

```python
rec = [None] * 4  # Initial recurrent states are None

for frame in YOUR_VIDEO:
    fgr, pha, *rec = model(frame, *rec, downsample_ratio)
```

**Wrong Way**

Here the model does not use its recurrent states. Use this only to process independent images.

```python
for frame in YOUR_VIDEO:
    fgr, pha = model(frame, downsample_ratio)[:2]
```

More technical details are in the [paper](https://peterl1n.github.io/RobustVideoMatting/).

<br><br><br>


## PyTorch

Model loading:

```python
import torch
from model import MattingNetwork

model = MattingNetwork(variant='mobilenetv3').eval().cuda() # Or variant="resnet50"
model.load_state_dict(torch.load('rvm_mobilenetv3.pth'))
```

Example inference loop:
```python
rec = [None] * 4 # Set initial recurrent states to None

for src in YOUR_VIDEO:  # src can be [B, C, H, W] or [B, T, C, H, W]
    fgr, pha, *rec = model(src, *rec, downsample_ratio=0.25)
```

* `src`: Input frame. 
    * Can be of shape `[B, C, H, W]` or `[B, T, C, H, W]`. 
    * If `[B, T, C, H, W]`, a chunk of `T` frames can be given at once for better parallelism.
    * RGB values must be normalized to the `0~1` range.

* `fgr, pha`: Foreground and alpha predictions. 
    * Of shape `[B, C, H, W]` or `[B, T, C, H, W]`, depending on `src`. 
    * `fgr` has `C=3` for RGB, `pha` has `C=1`.
    * Outputs are normalized to the `0~1` range.
* `rec`: Recurrent states. 
    * Of type `List[Tensor, Tensor, Tensor, Tensor]`. 
    * Initial `rec` can be `List[None, None, None, None]`.
    * It has 4 recurrent states because the model has 4 ConvGRU layers.
    * All tensors are rank 4 regardless of `src` rank.
    * If a chunk of `T` frames is given, only the last frame's recurrent states will be returned.
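
The shape and state rules above can be illustrated without the real network. The sketch below uses a hypothetical `toy_model` that is not RVM: its single recurrent state is an exponential moving average of the frames (the real model returns four ConvGRU states). It follows the same calling convention, accepting `[B, C, H, W]` or `[B, T, C, H, W]` and returning only the last frame's state for a chunk. Recycling that state gives identical results whether frames are fed one at a time or in chunks, while resetting it changes the output.

```python
import numpy as np


def toy_model(src, state=None, decay=0.5):
    """Hypothetical stand-in for a recurrent matting model (NOT RVM).

    `src` is [B, C, H, W] or [B, T, C, H, W]. Returns (output, new_state);
    for a chunk, only the state after the last frame is returned.
    """
    is_chunk = src.ndim == 5
    frames = src if is_chunk else src[:, None]      # view as [B, T, C, H, W]
    if state is None:                               # like rec = [None] * 4
        state = np.zeros_like(frames[:, 0])
    outputs = []
    for t in range(frames.shape[1]):
        state = decay * state + (1 - decay) * frames[:, t]  # update "memory"
        outputs.append(state)
    out = np.stack(outputs, axis=1)
    return (out if is_chunk else out[:, 0]), state


video = np.random.default_rng(0).random((1, 8, 3, 4, 4))  # [B, T, C, H, W]

# Correct way 1: frame by frame, recycling the state.
state, per_frame = None, []
for t in range(video.shape[1]):
    out, state = toy_model(video[:, t], state)
    per_frame.append(out)
per_frame = np.stack(per_frame, axis=1)

# Correct way 2: chunks of 4 frames, recycling the last frame's state.
state, chunked = None, []
for start in range(0, video.shape[1], 4):
    out, state = toy_model(video[:, start:start + 4], state)
    chunked.append(out)
chunked = np.concatenate(chunked, axis=1)

print(np.allclose(per_frame, chunked))                        # True
# Wrong way: starting from a fresh state on frame 1 gives a different result.
print(np.allclose(per_frame[:, 1], toy_model(video[:, 1])[0]))  # False
```

With the real model, the same pattern applies to all four states via `fgr, pha, *rec = model(src, *rec, downsample_ratio=...)`.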

To run inference on a video, here is a complete example:

```python
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from inference_utils import VideoReader, VideoWriter

reader = VideoReader('input.mp4', transform=ToTensor())
writer = VideoWriter('output.mp4', frame_rate=30)

bgr = torch.tensor([.47, 1, .6]).view(3, 1, 1).cuda()  # Green background.
rec = [None] * 4                                       # Initial recurrent states.

with torch.no_grad():
    for src in DataLoader(reader):                     # src: [1, C, H, W] (default batch_size=1).
        fgr, pha, *rec = model(src.cuda(), *rec, downsample_ratio=0.25)  # Cycle the recurrent states.
        writer.write(fgr * pha + bgr * (1 - pha))

writer.close()  # Flush and finalize the output video.
```

Or you can use the provided video converter:

```python
from inference import convert_video

convert_video(
    model,                           # The loaded model, can be on any device (cpu or cuda).
    input_source='input.mp4',        # A video file or an image sequence directory.
    input_resize=(1920, 1080),       # [Optional] Resize the input (also the output).
    downsample_ratio=0.25,           # [Optional] If None, auto-set so the downsampled max side is 512px.
    output_type='video',             # Choose "video" or "png_sequence"
    output_composition='com.mp4',    # File path if video; directory path if png sequence.
    output_alpha="pha.mp4",          # [Optional] Output the raw alpha prediction.
    output_foreground="fgr.mp4",     # [Optional] Output the raw foreground prediction.
    output_video_mbps=4,             # Output video mbps. Not needed for png sequence.
    seq_chunk=12,                    # Process n frames at once for better parallelism.
    num_workers=1,                   # Only for image sequence input. Reader threads.
    progress=True                    # Print conversion progress.
)
```

The converter can also be invoked in command line:

```sh
python inference.py \
    --variant mobilenetv3 \
    --checkpoint "CHECKPOINT" \
    --device cuda \
    --input-source "input.mp4" \
    --downsample-ratio 0.25 \
    --output-type video \
    --output-composition "composition.mp4" \
    --output-alpha "alpha.mp4" \
    --output-foreground "foreground.mp4" \
    --output-video-mbps 4 \
    --seq-chunk 12
```

<br><br><br>

## TorchHub

Model loading:

```python
model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3") # or "resnet50"
```

Use the conversion function. Refer to the `convert_video` documentation above.

```python
convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter")

convert_video(model, ...args...)
```

<br><br><br>

## TorchScript

Model loading:

```python
import torch
model = torch.jit.load('rvm_mobilenetv3.torchscript')
```

Optionally, freeze the model. This triggers graph optimizations such as BatchNorm fusion, so frozen models are typically faster.

```python
model = torch.jit.freeze(model)
```

Then you can use `model` exactly like a PyTorch model, except that for a frozen model you must manually provide `device` and `dtype` to the converter API. For example:

```python
convert_video(model, ...args..., device='cuda', dtype=torch.float32)
```

<br><br><br>

## ONNX

Model spec:
* Inputs: [`src`, `r1i`, `r2i`, `r3i`, `r4i`, `downsample_ratio`]. 
    * `src` is the RGB input frame of shape `[B, C, H, W]` normalized to `0~1` range. 
    * `rXi` are the recurrent state inputs. Initial recurrent states are zero-valued tensors of shape `[1, 1, 1, 1]`.
    * `downsample_ratio` is a tensor of shape `[1]`.
    * `downsample_ratio` must always have `dtype=FP32`; all other inputs must have a `dtype` matching the loaded model's precision.
* Outputs: [`fgr`, `pha`, `r1o`, `r2o`, `r3o`, `r4o`]
    * `fgr, pha` are the foreground and alpha predictions, normalized to the `0~1` range.
    * `rXo` are the recurrent state outputs.
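
As an illustration of this spec, the sketch below prepares the inputs with NumPy. It assumes frames arrive as `H x W x 3` `uint8` RGB arrays (for example, decoded with PIL) and an FP16 model; the helper names `make_src` and `initial_feed` are hypothetical, not part of the repository.

```python
import numpy as np


def make_src(frame_rgb_u8, dtype=np.float16):
    """HxWx3 uint8 RGB -> [1, 3, H, W] in the 0~1 range, cast to the model's dtype."""
    src = frame_rgb_u8.astype(np.float32) / 255.0  # normalize to 0~1
    src = np.transpose(src, (2, 0, 1))[None]        # HWC -> CHW, add batch dim
    return np.ascontiguousarray(src, dtype=dtype)


def initial_feed(dtype=np.float16, downsample_ratio=0.25):
    """Initial recurrent inputs are [1, 1, 1, 1] zeros in the model's dtype;
    downsample_ratio is always an FP32 tensor of shape [1]."""
    feed = {f'r{i}i': np.zeros([1, 1, 1, 1], dtype=dtype) for i in range(1, 5)}
    feed['downsample_ratio'] = np.array([downsample_ratio], dtype=np.float32)
    return feed


frame = np.full((720, 1280, 3), 255, dtype=np.uint8)  # dummy white frame
feed = {**initial_feed(), 'src': make_src(frame)}
print(feed['src'].shape, feed['src'].dtype)  # (1, 3, 720, 1280) float16
```

On each subsequent frame, replace `r1i` to `r4i` in the feed with the previous frame's `r1o` to `r4o` outputs.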

We only show examples using the onnxruntime CUDA backend in Python.

Model loading:

```python
import onnxruntime as ort

sess = ort.InferenceSession('rvm_mobilenetv3_fp16.onnx', providers=['CUDAExecutionProvider'])
```

Naive inference loop:

```python
import numpy as np

rec = [ np.zeros([1, 1, 1, 1], dtype=np.float16) ] * 4  # Must match dtype of the model.
downsample_ratio = np.array([0.25], dtype=np.float32)  # dtype always FP32

for src in YOUR_VIDEO:  # src is of [B, C, H, W] with dtype of the model.
    fgr, pha, *rec = sess.run([], {
        'src': src, 
        'r1i': rec[0], 
        'r2i': rec[1], 
        'r3i': rec[2], 
        'r4i': rec[3], 
        'downsample_ratio': downsample_ratio
    })
```

If you use the GPU version of ONNX Runtime, the naive implementation above transfers the recurrent states between CPU and GPU on every frame, even though they could stay on the GPU for better performance. The example below uses `iobinding` to eliminate these unnecessary transfers.

```python
import onnxruntime as ort
import numpy as np

# Load model.
sess = ort.InferenceSession('rvm_mobilenetv3_fp16.onnx', providers=['CUDAExecutionProvider'])

# Create an io binding.
io = sess.io_binding()

# Create tensors on CUDA.
rec = [ ort.OrtValue.ortvalue_from_numpy(np.zeros([1, 1, 1, 1], dtype=np.float16), 'cuda') ] * 4
downsample_ratio = ort.OrtValue.ortvalue_from_numpy(np.asarray([0.25], dtype=np.float32), 'cuda')

# Set output binding.
for name in ['fgr', 'pha', 'r1o', 'r2o', 'r3o', 'r4o']:
    io.bind_output(name, 'cuda')

# Inference loop
for src in YOUR_VIDEO:
    io.bind_cpu_input('src', src)
    io.bind_ortvalue_input('r1i', rec[0])
    io.bind_ortvalue_input('r2i', rec[1])
    io.bind_ortvalue_input('r3i', rec[2])
    io.bind_ortvalue_input('r4i', rec[3])
    io.bind_ortvalue_input('downsample_ratio', downsample_ratio)

    sess.run_with_iobinding(io)

    fgr, pha, *rec = io.get_outputs()

    # Only transfer `fgr` and `pha` to CPU.
    fgr = fgr.numpy()
    pha = pha.numpy()
```

Note: depending on the inference framework you choose, it may not support every operation in our official ONNX model. In that case, you will need to modify the model code and export your own ONNX model. You can refer to our exporter code in the [onnx branch](https://github.com/PeterL1n/RobustVideoMatting/tree/onnx).

<br><br><br>

## TensorFlow

Example usage:

```python
import tensorflow as tf

model = tf.keras.models.load_model('rvm_mobilenetv3_tf')
model = tf.function(model)

rec = [ tf.constant(0.) ] * 4         # Initial recurrent states.
downsample_ratio = tf.constant(0.25)  # Adjust based on your video.

for src in YOUR_VIDEO:  # src is of shape [B, H, W, C], not [B, C, H, W]!
    out = model([src, *rec, downsample_ratio])
    fgr, pha, *rec = out['fgr'], out['pha'], out['r1o'], out['r2o'], out['r3o'], out['r4o']
```

Note that all tensors are channel-last. Otherwise, the inputs and outputs are exactly the same as in PyTorch.

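If your frames come from a channel-first pipeline (for example, the same `[B, C, H, W]` arrays you would feed the ONNX model), they need to be transposed before being passed to the TensorFlow model. A minimal NumPy sketch, assuming frames are NumPy arrays; the helper names are illustrative:

```python
import numpy as np

def nchw_to_nhwc(x):
    # [B, C, H, W] -> [B, H, W, C], the layout the TensorFlow model expects.
    return np.transpose(x, (0, 2, 3, 1))

def nhwc_to_nchw(x):
    # [B, H, W, C] -> [B, C, H, W], the layout used by the PyTorch and ONNX models.
    return np.transpose(x, (0, 3, 1, 2))

frame_nchw = np.zeros((1, 3, 1080, 1920), dtype=np.float32)
frame_nhwc = nchw_to_nhwc(frame_nchw)
print(frame_nhwc.shape)  # (1, 1080, 1920, 3)
```

The same `nhwc_to_nchw` transpose converts the `fgr` and `pha` outputs back if the rest of your pipeline is channel-first.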
We also provide the raw TensorFlow model code in the [tensorflow branch](https://github.com/PeterL1n/RobustVideoMatting/tree/tensorflow). You can transfer PyTorch checkpoint weights to TensorFlow models.

<br><br><br>

## TensorFlow.js

We provide starter code in the [tfjs branch](https://github.com/PeterL1n/RobustVideoMatting/tree/tfjs). The example is self-explanatory and shows how to use the model properly.

<br><br><br>

## CoreML

We only show example usage of the CoreML models with the `coremltools` Python API. In production, the same logic can be applied in Swift. When processing the first frame, do not provide recurrent states; CoreML internally constructs zero tensors of the correct shapes as the initial recurrent states.

```python
import coremltools as ct

model = ct.models.model.MLModel('rvm_mobilenetv3_1920x1080_s0.25_int8.mlmodel')

r1, r2, r3, r4 = None, None, None, None

for src in YOUR_VIDEO:  # src is PIL.Image.
    
    if r1 is None:
        # Initial frame, do not provide recurrent states.
        inputs = {'src': src}
    else:
        # Subsequent frames, provide recurrent states.
        inputs = {'src': src, 'r1i': r1, 'r2i': r2, 'r3i': r3, 'r4i': r4}

    outputs = model.predict(inputs)

    fgr = outputs['fgr']  # PIL.Image.
    pha = outputs['pha']  # PIL.Image.
    
    r1 = outputs['r1o']  # Numpy array.
    r2 = outputs['r2o']  # Numpy array.
    r3 = outputs['r3o']  # Numpy array.
    r4 = outputs['r4o']  # Numpy array.

```

Our CoreML models only support fixed resolutions. If you need other resolutions, you can export the models yourself. See the [coreml branch](https://github.com/PeterL1n/RobustVideoMatting/tree/coreml) for the export code.

================================================
FILE: documentation/inference_zh_Hans.md
================================================
# Inference Documentation

<p align="center"><a href="inference.md">English</a> | 中文</p>

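## Contents

* [Concepts](#concepts)
    * [Downsample Ratio](#downsample-ratio)
    * [Recurrent States](#recurrent-states)
* [PyTorch](#pytorch)
* [TorchHub](#torchhub)
* [TorchScript](#torchscript)
* [ONNX](#onnx)
* [TensorFlow](#tensorflow)
* [TensorFlow.js](#tensorflowjs)
* [CoreML](#coreml)

<br>


## Concepts

### Downsample Ratio

This table is for reference only. Adjust the ratio based on your video content.

| Resolution    | Portrait      | Full-Body      |
| ------------- | ------------- | -------------- |
| <= 512x512    | 1             | 1              |
| 1280x720      | 0.375         | 0.6            |
| 1920x1080     | 0.25          | 0.4            |
| 3840x2160     | 0.125         | 0.2            |

Internally, the model downsamples the high-resolution input for an initial coarse pass, then upsamples it for refinement.

We recommend setting `downsample_ratio` so that the downsampled resolution stays between 256 and 512 pixels. For example, for a `1920x1080` input, `downsample_ratio=0.25` gives a downsampled resolution of `480x270`, which falls within that range.

Adjust `downsample_ratio` based on the video content. If the video is an upper-body portrait, a low `downsample_ratio` is sufficient. If the video shows the full body, try a higher `downsample_ratio`. Note, however, that an excessively high `downsample_ratio` can actually degrade the results.


<br>

### Recurrent States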
This model is a recurrent neural network. You must process the video frames sequentially and provide the network with its recurrent states.

**Correct Way**

The recurrent outputs are passed to the next frame as inputs.

```python
rec = [None] * 4  # Initial recurrent states are None

for frame in YOUR_VIDEO:
    fgr, pha, *rec = model(frame, *rec, downsample_ratio)
```

**Wrong Way**

The recurrent states are not used. This approach is only suitable for processing individual images.

```python
for frame in YOUR_VIDEO:
    fgr, pha = model(frame, downsample_ratio=downsample_ratio)[:2]
```

For more technical details, see the [paper](https://peterl1n.github.io/RobustVideoMatting/).

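To make the difference between the two usages concrete, here is a pure-Python sketch with a hypothetical `toy_model` that mimics only the calling convention of the real network (a frame, four recurrent states, and a `downsample_ratio` keyword). Its single state simply counts the frames it has seen, standing in for the temporal information carried by the real ConvGRU states.

```python
def toy_model(src, r1=None, r2=None, r3=None, r4=None, downsample_ratio=1.0):
    # Hypothetical stand-in: r1 counts frames seen so far; None means "first frame".
    r1 = (r1 or 0) + 1
    fgr, pha = src, r1  # Placeholder outputs: pha reports how much history the model has.
    return fgr, pha, r1, r2, r3, r4

video = ['frame0', 'frame1', 'frame2']

# Correct: each frame's recurrent outputs are fed into the next call.
rec = [None] * 4
history_correct = []
for frame in video:
    fgr, pha, *rec = toy_model(frame, *rec, downsample_ratio=0.25)
    history_correct.append(pha)

# Wrong: the states are discarded, so every frame is treated as the first one.
history_wrong = []
for frame in video:
    fgr, pha = toy_model(frame, downsample_ratio=0.25)[:2]
    history_wrong.append(pha)

print(history_correct)  # [1, 2, 3]
print(history_wrong)    # [1, 1, 1]
```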
<br><br><br>


## PyTorch

Model loading:

```python
import torch
from model import MattingNetwork

model = MattingNetwork(variant='mobilenetv3').eval().cuda() # or variant="resnet50"
model.load_state_dict(torch.load('rvm_mobilenetv3.pth'))
```

Inference loop:
```python
rec = [None] * 4 # Initial recurrent states are None

for src in YOUR_VIDEO:  # src can be [B, C, H, W] or [B, T, C, H, W]
    fgr, pha, *rec = model(src, *rec, downsample_ratio=0.25)
```

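* `src`: Input frame(s).
    * Can be a tensor of shape `[B, C, H, W]` or `[B, T, C, H, W]`.
    * If `[B, T, C, H, W]`, the model receives `T` frames at a time and processes the video chunk by chunk for better parallelism.
    * RGB input, in the `0~1` range.

* `fgr, pha`: Foreground and alpha predictions.
    * Of shape `[B, C, H, W]` or `[B, T, C, H, W]`, matching `src`.
    * `fgr` has three RGB channels; `pha` has one channel.
    * Outputs are in the `0~1` range.
* `rec`: Recurrent states.
    * Of type `List[Tensor, Tensor, Tensor, Tensor]`.
    * The initial `rec` is `List[None, None, None, None]`.
    * There are four recurrent states because the network uses four `ConvGRU` layers.
    * All recurrent state tensors have rank 4, regardless of the rank of `src`.
    * If `T` frames are given at once, only the states after processing the last frame are returned.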

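Feeding `T` frames at once means splitting the video into consecutive chunks, similar in spirit to the `seq_chunk` option of the conversion API. A framework-agnostic sketch of the chunking step is shown below; the helper name `chunk_frames` is an assumption. With PyTorch, each chunk of `[B, C, H, W]` frames could then be stacked with `torch.stack(chunk, dim=1)` to obtain `[B, T, C, H, W]`. Chunks must still be processed in order, carrying `rec` from one chunk to the next.

```python
from typing import Iterable, Iterator, List, TypeVar

FrameT = TypeVar('FrameT')

def chunk_frames(frames: Iterable[FrameT], seq_chunk: int) -> Iterator[List[FrameT]]:
    """Group consecutive frames into chunks of at most `seq_chunk` frames.

    Order is preserved; the last chunk may be shorter than `seq_chunk`.
    """
    if seq_chunk < 1:
        raise ValueError("seq_chunk must be >= 1")
    chunk: List[FrameT] = []
    for frame in frames:
        chunk.append(frame)
        if len(chunk) == seq_chunk:
            yield chunk
            chunk = []
    if chunk:
        yield chunk

chunks = list(chunk_frames(range(30), seq_chunk=12))
print([len(c) for c in chunks])  # [12, 12, 6]
```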
A complete inference example:

```python
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from inference_utils import VideoReader, VideoWriter

reader = VideoReader('input.mp4', transform=ToTensor())
writer = VideoWriter('output.mp4', frame_rate=30)

bgr = torch.tensor([.47, 1, .6]).view(3, 1, 1).cuda()  # Green background
rec = [None] * 4                                       # Initial recurrent states

with torch.no_grad():
    for src in DataLoader(reader):
        fgr, pha, *rec = model(src.cuda(), *rec, downsample_ratio=0.25)  # Pass this frame's recurrent states to the next frame
        writer.write(fgr * pha + bgr * (1 - pha))
```
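The expression `fgr * pha + bgr * (1 - pha)` is standard alpha compositing, applied per pixel and broadcast over the RGB channels. The pure-Python sketch below spells it out for a single pixel; the function name `composite` is illustrative.

```python
def composite(fgr, pha, bgr):
    """Alpha-composite one RGB pixel: fgr * pha + bgr * (1 - pha).

    fgr and bgr are (r, g, b) tuples in 0~1; pha is a scalar alpha in 0~1.
    """
    return tuple(f * pha + b * (1 - pha) for f, b in zip(fgr, bgr))

green = (0.47, 1.0, 0.6)  # Same green background as in the example above.
red = (1.0, 0.0, 0.0)

print(composite(red, 1.0, green))  # Fully opaque: the foreground color.
print(composite(red, 0.0, green))  # Fully transparent: the background color.
print(composite(red, 0.5, green))  # Half transparent: an even blend of both.
```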

Or use the provided video conversion API:

```python
from inference import convert_video

convert_video(
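    model,                           # The model; can be loaded on any device (cpu or cuda)
    input_source='input.mp4',        # A video file, or a folder containing an image sequence
    input_resize=(1920, 1080),       # [Optional] Resize the input video
    downsample_ratio=0.25,           # [Optional] Downsample ratio; if None, automatically downsample to 512px
    output_type='video',             # Either "video" or "png_sequence"
    output_composition='com.mp4',    # A file path for video output, or a folder path for PNG sequence output
    output_alpha="pha.mp4",          # [Optional] Output the raw alpha prediction
    output_foreground="fgr.mp4",     # [Optional] Output the raw foreground prediction
    output_video_mbps=4,             # Output video bitrate (Mbps) when outputting video
    seq_chunk=12,                    # Process multiple frames in parallel
    num_workers=1,                   # Only for image sequence input: number of reader workers
    progress=True                    # Show a progress bar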
)
```
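The comment above says that with `downsample_ratio=None` the input is automatically downsampled to 512px. A minimal sketch of that rule is shown below, assuming it means "scale the longer side to 512 pixels and never upsample"; the function names are illustrative rather than the exact code in `inference.py`. The second helper shows the internal resolution a given ratio produces, which you can compare against the 256 to 512 pixel guideline from the Concepts section.

```python
def auto_downsample_ratio(h: int, w: int, target: int = 512) -> float:
    """Ratio that scales the longer side to `target` pixels, capped at 1 (no upsampling)."""
    return min(target / max(h, w), 1.0)

def downsampled_size(h: int, w: int, ratio: float) -> tuple:
    # Resolution the model works at internally for its coarse pass.
    return round(h * ratio), round(w * ratio)

for h, w in [(512, 512), (1080, 1920), (2160, 3840)]:
    r = auto_downsample_ratio(h, w)
    print((h, w), round(r, 4), downsampled_size(h, w, r))

# The manual choice from the table: 1920x1080 with ratio 0.25 -> 480x270 internally.
print(downsampled_size(1080, 1920, 0.25))  # (270, 480)
```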

The conversion API can also be called from the command line:

```sh
python inference.py \
    --variant mobilenetv3 \
    --checkpoint "CHECKPOINT" \
    --device cuda \
    --input-source "input.mp4" \
    --downsample-ratio 0.25 \
    --output-type video \
    --output-composition "composition.mp4" \
    --output-alpha "alpha.mp4" \
    --output-foreground "foreground.mp4" \
    --output-video-mbps 4 \
    --seq-chunk 12
```

<br><br><br>

## TorchHub

Model loading:

```python
model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3") # or "resnet50"
```

Use the conversion API. Refer to the `convert_video` documentation above.

```python
convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter")

convert_video(model, input_source='input.mp4', output_composition='com.mp4')  # Plus any other arguments documented above.
```

<br><br><br>

## TorchScript

Model loading:

```python
import torch
model = torch.jit.load('rvm_mobilenetv3.torchscript')
```

Optionally, freeze the model. Freezing optimizes the model, for example with BatchNorm fusion, and frozen models run faster.

```python
model = torch.jit.freeze(model)
```

Then you can use `model` like a regular PyTorch model. Note, however, that when you call the conversion API with a frozen model, you must provide `device` and `dtype` manually:

```python
convert_video(model, input_source='input.mp4', output_composition='com.mp4', device='cuda', dtype=torch.float32)
```

<br><br><br>

## ONNX

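Model spec:
* Inputs: [`src`, `r1i`, `r2i`, `r3i`, `r4i`, `downsample_ratio`].
    * `src`: the input frame, with RGB channels, of shape `[B, C, H, W]`, in the `0~1` range.
    * `rXi`: the recurrent state inputs. The initial values are zero tensors of shape `[1, 1, 1, 1]`.
    * `downsample_ratio`: the downsample ratio, a tensor of shape `[1]`.
    * Only `downsample_ratio` must be `FP32`; all other inputs must use the same `dtype` as the loaded model.
* Outputs: [`fgr`, `pha`, `r1o`, `r2o`, `r3o`, `r4o`]
    * `fgr, pha`: the foreground and alpha outputs, in the `0~1` range.
    * `rXo`: the recurrent state outputs.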

We only show examples using the ONNX Runtime CUDA backend in Python.

Model loading:

```python
import onnxruntime as ort

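# Request the CUDA backend explicitly (GPU builds of ONNX Runtime 1.9+ require `providers`).
sess = ort.InferenceSession('rvm_mobilenetv3_fp16.onnx', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])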
```

A simple inference loop, although this approach is not optimal:

```python
import numpy as np

rec = [ np.zeros([1, 1, 1, 1], dtype=np.float16) ] * 4  # Must use the same dtype as the model
downsample_ratio = np.array([0.25], dtype=np.float32)  # Must be FP32

for src in YOUR_VIDEO:  # src is a tensor of shape [B, C, H, W] and must use the same dtype as the model
    fgr, pha, *rec = sess.run([], {
        'src': src, 
        'r1i': rec[0], 
        'r2i': rec[1], 
        'r3i': rec[2], 
        'r4i': rec[3], 
        'downsample_ratio': downsample_ratio
    })
```

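When using the GPU, the example above transfers the recurrent outputs back to the CPU, then back to the GPU for the next frame. These transfers are pointless because the recurrent states can stay on the GPU. The example below uses `iobinding` to eliminate them.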

```python
import onnxruntime as ort
import numpy as np

# Load the model.
sess = ort.InferenceSession('rvm_mobilenetv3_fp16.onnx', providers=['CUDAExecutionProvider'])

# Create an io binding.
io = sess.io_binding()

# Create tensors on CUDA.
rec = [ ort.OrtValue.ortvalue_from_numpy(np.zeros([1, 1, 1, 1], dtype=np.float16), 'cuda') ] * 4
downsample_ratio = ort.OrtValue.ortvalue_from_numpy(np.asarray([0.25], dtype=np.float32), 'cuda')

# Set output bindings.
for name in ['fgr', 'pha', 'r1o', 'r2o', 'r3o', 'r4o']:
    io.bind_output(name, 'cuda')

# Inference loop
for src in YOUR_VIDEO:
    io.bind_cpu_input('src', src)
    io.bind_ortvalue_input('r1i', rec[0])
    io.bind_ortvalue_input('r2i', rec[1])
    io.bind_ortvalue_input('r3i', rec[2])
    io.bind_ortvalue_input('r4i', rec[3])
    io.bind_ortvalue_input('downsample_ratio', downsample_ratio)

    sess.run_with_iobinding(io)

    fgr, pha, *rec = io.get_outputs()

    # Only transfer `fgr` and `pha` back to the CPU.
    fgr = fgr.numpy()
    pha = pha.numpy()
```

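Note: if you use another inference framework, some ONNX ops may not be supported and will need to be replaced. You can refer to the code in the [onnx](https://github.com/PeterL1n/RobustVideoMatting/tree/onnx) branch to export your own model.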

<br><br><br>

## TensorFlow

Example:

```python
import tensorflow as tf

model = tf.keras.models.load_model('rvm_mobilenetv3_tf')
model = tf.function(model)

rec = [ tf.constant(0.) ] * 4         # Initial recurrent states
downsample_ratio = tf.constant(0.25)  # Downsample ratio; adjust based on your video

for src in YOUR_VIDEO:  # src is a tensor of shape [B, H, W, C], not [B, C, H, W]!
    out = model([src, *rec, downsample_ratio])
    fgr, pha, *rec = out['fgr'], out['pha'], out['r1o'], out['r2o'], out['r3o'], out['r4o']
```

Note that in TensorFlow, all tensors are in channel-last format.

We provide the raw TensorFlow model code in the [tensorflow](https://github.com/PeterL1n/RobustVideoMatting/tree/tensorflow) branch. You can transfer the PyTorch weights to the TensorFlow model yourself.


<br><br><br>

## TensorFlow.js

We provide example code in the [tfjs](https://github.com/PeterL1n/RobustVideoMatting/tree/tfjs) branch. The code is easy to follow and shows how to use the model properly.

<br><br><br>

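## CoreML

We only show how to use the CoreML models in Python via `coremltools`. In deployment, the same logic applies in Swift. The recurrent state inputs do not need to be provided when processing the first frame; CoreML automatically creates zero tensors internally as the initial recurrent states.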

```python
import coremltools as ct

model = ct.models.model.MLModel('rvm_mobilenetv3_1920x1080_s0.25_int8.mlmodel')

r1, r2, r3, r4 = None, None, None, None

for src in YOUR_VIDEO:  # src is a PIL.Image.
    
    if r1 is None:
        # Initial frame, do not provide recurrent states.
        inputs = {'src': src}
    else:
        # Subsequent frames, provide recurrent states.
        inputs = {'src': src, 'r1i': r1, 'r2i': r2, 'r3i': r3, 'r4i': r4}

    outputs = model.predict(inputs)

    fgr = outputs['fgr']  # PIL.Image
    pha = outputs['pha']  # PIL.Image
    
    r1 = outputs['r1o']  # Numpy array
    r2 = outputs['r2o']  # Numpy array
    r3 = outputs['r3o']  # Numpy array
    r4 = outputs['r4o']  # Numpy array

```

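Our CoreML models only support fixed resolutions. If you need other resolutions, you can export the models yourself. See the [coreml](https://github.com/PeterL1n/RobustVideoMatting/tree/coreml) branch for the export code.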

================================================
FILE: documentation/misc/aim_test.txt
================================================
boy-1518482_1920.png
girl-1219339_1920.png
girl-1467820_1280.png
girl-beautiful-young-face-53000.png
long-1245787_1920.png
model-600238_1920.png
pexels-photo-58463.png
sea-sunny-person-beach.png
wedding-dresses-1486260_1280.png
woman-952506_1920 (1).png
woman-morning-bathrobe-bathroom.png


================================================
FILE: documentation/misc/d646_test.txt
================================================
test_13.png
test_16.png
test_18.png
test_22.png
test_32.png
test_35.png
test_39.png
test_42.png
test_46.png
test_4.png
test_6.png


================================================
FILE: documentation/misc/dvm_background_test_clips.txt
================================================
0000
0001
0002
0004
0005
0007
0008
0009
0010
0012
0013
0014
0015
0016
0017
0018
0019
0021
0022
0023
0024
0025
0027
0029
0030
0032
0033
0034
0035
0037
0038
0039
0040
0041
0042
0043
0045
0046
0047
0048
0050
0051
0052
0054
0055
0057
0058
0059
0060
0061
0062
0063
0064
0065
0066
0068
0070
0071
0073
0074
0075
0077
0078
0079
0080
0081
0082
0083
0084
0085
0086
0089
0097
0100
0101
0102
0103
0104
0106
0107
0109
0110
0111
0113
0115
0116
0117
0119
0120
0121
0122
0123
0124
0125
0126
0127
0128
0129
0130
0131
0132
0133
0134
0135
0136
0137
0143
0145
0147
0148
0150
0159
0160
0161
0162
0165
0166
0168
0172
0174
0175
0176
0178
0181
0182
0183
0184
0185
0187
0194
0198
0200
0201
0207
0210
0211
0212
0215
0217
0218
0219
0220
0222
0223
0224
0225
0226
0227
0229
0230
0231
0232
0233
0234
0235
0237
0240
0241
0242
0243
0244
0245


================================================
FILE: documentation/misc/dvm_background_train_clips.txt
================================================
0000
0002
0003
0004
0005
0006
0007
0009
0010
0012
0013
0014
0015
0016
0019
0021
0022
0023
0024
0025
0028
0029
0030
0031
0032
0034
0035
0036
0037
0039
0040
0041
0042
0043
0044
0046
0047
0048
0049
0050
0051
0052
0053
0054
0055
0056
0057
0058
0060
0061
0062
0063
0064
0065
0066
0067
0068
0069
0070
0071
0073
0074
0075
0076
0077
0078
0079
0081
0082
0087
0088
0099
0100
0101
0104
0105
0107
0108
0109
0110
0111
0112
0113
0114
0115
0117
0118
0119
0120
0122
0123
0124
0125
0127
0128
0129
0130
0131
0132
0133
0134
0135
0136
0137
0138
0139
0140
0141
0142
0144
0146
0147
0148
0150
0151
0152
0153
0154
0155
0156
0157
0158
0159
0160
0161
0163
0164
0165
0167
0168
0169
0170
0171
0172
0174
0175
0176
0177
0178
0180
0181
0182
0184
0185
0187
0188
0189
0190
0192
0193
0194
0195
0196
0197
0198
0199
0200
0202
0203
0204
0205
0206
0207
0208
0209
0210
0211
0212
0213
0214
0215
0217
0218
0219
0220
0221
0222
0223
0224
0225
0226
0227
0229
0230
0231
0233
0234
0235
0236
0237
0238
0240
0241
0242
0243
0244
0245
0246
0247
0248
0249
0250
0251
0252
0253
0254
0255
0256
0257
0258
0259
0260
0261
0262
0263
0264
0265
0266
0267
0268
0269
0270
0271
0272
0273
0274
0275
0276
0277
0278
0279
0280
0281
0282
0283
0284
0285
0286
0287
0288
0289
0290
0291
0292
0293
0294
0297
0298
0299
0300
0301
0302
0303
0304
0305
0306
0308
0309
0310
0311
0312
0313
0314
0315
0316
0317
0319
0320
0321
0322
0323
0324
0325
0326
0327
0328
0329
0330
0331
0332
0333
0335
0336
0337
0338
0339
0341
0342
0344
0345
0346
0348
0349
0352
0353
0356
0357
0358
0359
0360
0361
0362
0363
0364
0365
0366
0368
0369
0370
0371
0372
0373
0374
0375
0376
0377
0378
0379
0380
0381
0382
0383
0384
0385
0386
0387
0388
0389
0391
0392
0393
0394
0395
0397
0398
0399
0400
0401
0402
0403
0404
0405
0406
0407
0408
0409
0410
0411
0413
0414
0415
0416
0417
0419
0420
0421
0422
0423
0424
0425
0426
0427
0428
0429
0431
0433
0434
0435
0436
0437
0438
0439
0440
0441
0442
0443
0445
0446
0447
0448
0449
0450
0451
0452
0453
0454
0456
0457
0458
0459
0462
0463
0464
0465
0466
0467
0468
0469
0470
0471
0472
0473
0474
0475
0476
0477
0478
0479
0480
0481
0482
0483
0484
0485
0486
0487
0488
0489
0490
0491
0492
0493
0494
0499
0501
0502
0503
0504
0505
0506
0507
0509
0510
0511
0512
0513
0514
0515
0517
0518
0519
0520
0521
0522
0524
0526
0527
0529
0530
0534
0535
0536
0538
0539
0541
0542
0543
0544
0545
0546
0548
0549
0550
0552
0554
0555
0556
0557
0558
0559
0560
0561
0562
0563
0564
0565
0566
0567
0568
0571
0572
0573
0574
0575
0576
0577
0578
0579
0580
0581
0582
0583
0584
0586
0587
0589
0590
0591
0592
0594
0595
0596
0597
0598
0600
0601
0602
0603
0604
0605
0606
0608
0609
0610
0611
0612
0613
0614
0615
0616
0617
0618
0619
0620
0624
0625
0626
0627
0628
0629
0630
0631
0634
0635
0636
0637
0638
0639
0640
0641
0642
0643
0644
0645
0646
0647
0648
0650
0651
0652
0654
0655
0656
0658
0659
0660
0661
0662
0663
0664
0665
0666
0667
0669
0670
0671
0672
0673
0674
0675
0676
0677
0678
0679
0680
0681
0682
0683
0684
0685
0686
0687
0689
0690
0691
0692
0693
0694
0695
0696
0697
0698
0699
0700
0701
0702
0703
0704
0705
0706
0707
0708
0709
0710
0711
0712
0713
0714
0715
0716
0717
0718
0719
0720
0721
0723
0724
0725
0726
0727
0729
0730
0731
0732
0733
0734
0735
0736
0738
0740
0741
0742
0743
0744
0746
0747
0748
0749
0750
0752
0753
0754
0755
0756
0757
0758
0759
0760
0762
0763
0764
0765
0766
0767
0768
0770
0771
0772
0773
0774
0775
0776
0777
0778
0779
0780
0781
0782
0783
0784
0786
0787
0788
0789
0790
0791
0792
0793
0794
0795
0796
0797
0798
0800
0801
0804
0806
0808
0809
0811
0812
0813
0814
0815
0816
0817
0819
0823
0824
0825
0827
0828
0829
0830
0831
0832
0833
0834
0835
0836
0837
0840
0841
0842
0847
0848
0850
0851
0852
0853
0854
0855
0856
0857
0858
0859
0860
0861
0862
0864
0867
0868
0869
0870
0871
0872
0873
0874
0876
0877
0878
0879
0880
0881
0882
0883
0885
0886
0887
0889
0890
0891
0892
0893
0894
0895
0896
0899
0900
0901
0902
0903
0904
0905
0906
0907
0908
0909
0910
0911
0912
0913
0914
0915
0916
0917
0918
0919
0921
0922
0923
0924
0925
0926
0927
0929
0930
0931
0932
0933
0934
0935
0936
0937
0939
0940
0941
0942
0945
0946
0947
0948
0949
0950
0951
0952
0953
0954
0955
0956
0957
0958
0960
0962
0963
0964
0965
0966
0967
0968
0969
0970
0971
0973
0974
0976
0977
0978
0979
0980
0981
0982
0983
0984
0985
0986
0987
0988
0989
0990
0991
0992
0993
0994
0995
0996
0997
0998
0999
1000
1001
1002
1003
1005
1006
1008
1009
1010
1011
1012
1013
1014
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1028
1029
1030
1032
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1052
1053
1054
1055
1056
1057
1058
1061
1062
1063
1064
1065
1066
1067
1068
1069
1072
1075
1076
1077
1078
1079
1081
1082
1083
1084
1087
1088
1089
1090
1096
1097
1098
1099
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1128
1129
1130
1131
1132
1134
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1165
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1199
1200
1201
1202
1203
1204
1206
1207
1208
1211
1212
1213
1215
1216
1217
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1237
1238
1239
1240
1241
1242
1245
1246
1248
1249
1252
1253
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1277
1278
1279
1280
1281
1282
1283
1284
1287
1288
1289
1290
1291
1293
1294
1295
1296
1297
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1350
1351
1352
1353
1354
1355
1356
1357
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1378
1379
1380
1381
1382
1383
1385
1386
1387
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1402
1403
1405
1406
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1493
1494
1495
1496
1497
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1560
1561
1562
1563
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1608
1609
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1778
1779
1780
1781
1782
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1804
1806
1807
1808
1809
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1825
1826
1827
1828
1829
1831
1833
1834
1835
1836
1837
1838
1839
1840
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1879
1880
1881
1882
1886
1887
1889
1891
1892
1893
1894
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1952
1953
1954
1955
1956
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1980
1981
1982
1983
1984
1985
1986
1988
1989
1990
1991
1992
1993
1994
1995
1997
1998
1999
2000
2002
2003
2004
2005
2007
2008
2010
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
2024
2025
2026
2027
2028
2029
2030
2031
2032
2035
2036
2037
2038
2039
2040
2041
2042
2043
2045
2046
2047
2049
2050
2052
2053
2054
2056
2057
2059
2060
2061
2064
2066
2068
2069
2070
2071
2072
2073
2074
2075
2077
2078
2079
2080
2081
2082
2083
2084
2085
2086
2087
2088
2089
2090
2091
2092
2093
2094
2096
2097
2098
2099
2100
2101
2102
2103
2104
2105
2107
2108
2109
2111
2112
2113
2114
2115
2116
2117
2119
2120
2121
2122
2123
2124
2125
2126
2127
2128
2129
2130
2131
2132
2133
2134
2135
2136
2137
2138
2139
2140
2141
2143
2144
2145
2146
2147
2148
2149
2152
2153
2155
2158
2159
2160
2161
2162
2163
2164
2165
2166
2167
2168
2169
2170
2171
2174
2176
2177
2178
2179
2180
2181
2182
2183
2184
2186
2187
2188
2189
2190
2191
2192
2193
2195
2197
2198
2199
2200
2201
2202
2203
2204
2205
2206
2207
2208
2209
2210
2211
2212
2213
2214
2215
2216
2217
2218
2219
2220
2221
2222
2223
2225
2226
2227
2228
2229
2230
2231
2232
2233
2234
2235
2236
2238
2239
2240
2241
2242
2243
2244
2245
2246
2247
2248
2249
2250
2251
2252
2253
2255
2256
2257
2258
2259
2260
2261
2263
2264
2265
2267
2268
2269
2270
2271
2272
2273
2274
2275
2276
2277
2278
2279
2280
2281
2282
2283
2284
2285
2286
2287
2288
2290
2291
2292
2293
2294
2295
2297
2299
2300
2301
2302
2303
2304
2305
2311
2312
2313
2314
2315
2316
2317
2318
2319
2320
2322
2324
2325
2326
2329
2331
2332
2334
2335
2336
2337
2338
2339
2340
2341
2342
2343
2344
2345
2347
2349
2350
2351
2352
2353
2355
2356
2358
2359
2360
2361
2362
2363
2364
2365
2367
2368
2369
2370
2372
2373
2374
2375
2376
2377
2378
2379
2380
2381
2382
2383
2384
2386
2389
2390
2391
2392
2393
2394
2395
2396
2397
2398
2399
2400
2401
2402
2403
2404
2406
2407
2408
2409
2410
2411
2412
2413
2414
2416
2417
2418
2419
2420
2421
2422
2423
2424
2425
2426
2427
2428
2429
2431
2432
2433
2434
2437
2439
2440
2441
2442
2443
2444
2445
2446
2448
2449
2450
2451
2452
2453
2454
2455
2456
2457
2460
2461
2462
2463
2464
2465
2466
2467
2468
2469
2470
2472
2474
2475
2479
2480
2481
2482
2483
2484
2485
2486
2487
2488
2489
2490
2491
2492
2493
2494
2495
2496
2497
2499
2500
2501
2502
2503
2504
2505
2506
2507
2508
2509
2510
2511
2512
2513
2514
2515
2516
2517
2518
2519
2520
2521
2522
2523
2524
2525
2530
2534
2536
2537
2539
2540
2541
2542
2543
2544
2545
2546
2548
2549
2550
2551
2553
2554
2555
2556
2559
2560
2562
2567
2570
2573
2575
2576
2578
2579
2582
2585
2588
2590
2593
2594
2595
2596
2597
2598
2600
2601
2602
2603
2604
2605
2606
2607
2612
2613
2615
2616
2617
2618
2619
2620
2622
2623
2624
2625
2629
2630
2633
2634
2635
2637
2638
2639
2641
2643
2644
2646
2648
2649
2650
2652
2654
2656
2659
2661
2662
2663
2664
2665
2667
2669
2670
2671
2672
2674
2675
2677
2679
2680
2682
2684
2685
2687
2689
2691
2692
2693
2694
2695
2696
2697
2698
2699
2700
2701
2702
2703
2705
2706
2707
2710
2713
2714
2715
2717
2718
2720
2721
2722
2726
2727
2729
2733
2734
2737
2739
2740
2741
2742
2743
2745
2747
2748
2749
2755
2756
2757
2758
2759
2761
2762
2763
2765
2767
2768
2769
2770
2773
2775
2776
2780
2781
2782
2785
2787
2790
2791
2793
2794
2795
2796
2797
2798
2800
2801
2802
2803
2806
2808
2810
2812
2813
2814
2815
2816
2818
2819
2820
2821
2822
2823
2825
2827
2828
2829
2830
2831
2832
2833
2834
2835
2837
2838
2842
2843
2844
2845
2847
2848
2849
2850
2851
2852
2853
2854
2856
2857
2858
2860
2861
2862
2863
2864
2865
2866
2868
2869
2871
2874
2875
2877
2879
2881
2882
2884
2885
2887
2888
2889
2890
2891
2893
2894
2895
2896
2900
2902
2903
2904
2905
2906
2907
2908
2911
2912
2914
2915
2916
2917
2918
2919
2921
2923
2924
2925
2926
2928
2929
2930
2931
2932
2933
2934
2935
2936
2938
2939
2940
2941
2944
2945
2946
2947
2952
2954
2957
2958
2960
2962
2963
2966
2967
2968
2969
2970
2971
2972
2973
2974
2975
2976
2981
2982
2984
2985
2986
2988
2993
2994
2996
2998
2999
3000
3001
3002
3004
3006
3007
3008
3010
3015
3016
3017
3018
3019
3020
3021
3022
3024
3026
3027
3029
3033
3034
3035
3036
3037
3040
3041
3042
3043
3044
3045
3046
3047
3048
3049
3050
3051
3053
3056
3057
3058
3059
3060
3061
3063
3065
3066
3068
3069
3070
3074
3077
3078
3079
3080
3081
3082
3083
3084
3085
3086
3087
3094
3095
3097
3098
3100
3101
3102
3103
3105
3106
3108
3109
3110
3111
3112
3113
3114
3115
3116
3117
3118
3119
3121
3123
3124
3125
3127
3128
3130
3131
3133
3134
3136
3137
3140
3142
3143
3144
3145
3147
3148
3152
3153
3155
3156
3157
3159
3160
3162
3164
3165
3166
3168
3170
3171
3173
3176
3178
3179
3180
3181
3184
3187
3188
3190
3191
3194
3197
3199
3200
3201
3202
3204
3209
3210
3211
3212
3214
3215
3217
3218
3219
3221
3222
3223
3224
3226
3228
3229
3230
3232
3237
3238
3239
3240
3241
3242
3243
3245
3247
3248
3250
3251
3252
3253
3254
3257
3258
3259
3260
3262
3264
3266
3267
3269
3270
3275
3278
3279
3280
3282
3284
3285
3286
3287
3288
3289
3292
3293
3295
3296
3298
3299
3300
3301
3302
3303
3304
3309
3311
3312
3315
3316
3317
3318
3319
3322
3325
3328
3332
3334
3339
3340
3342
3346
3348
3349
3350
3351
3352
3354
3355
3357
3358
3361
3363
3364
3365
3366
3367
3368
3369
3370
3373
3374
3377
3378
3380
3381
3382
3383
3384
3385
3386
3391
3393
3395
3396
3397
3398
3399
3400
3401
3402
3403
3404
3405
3407
3408
3409
3410
3411
3415
3416
3418
3420
3422
3423
3424
3427
3428
3431
3433
3435
3437
3438
3439
3440
3441
3442
3443
3444
3446
3449
3450
3451
3452
3453
3454
3455
3456
3457
3460
3462
3463
3464
3465
3466
3467
3468
3470
3475
3476
3477
3483
3484
3486
3487
3489
3492
3496
3497
3498
3500
3501
3502
3505
3507
3508
3509
3510
3511
3512
3513
3514
3515
3517
3518
3519
3521
3524
3525
3526
3528
3529
3532
3536
3541
3542
3543
3544
3545
3546
3549
3550
3551
3552
3554
3556
3557
3558
3559
3560
3562
3563
3564
3566
3567
3568
3571
3572
3573
3574
3575
3576
3578
3579
3580
3582
3584
3585
3588
3590
3592
3593
3594
3599
3602
3605
3606
3608
3611
3612
3615
3617
3620
3621
3622
3623
3624
3625
3626
3629
3630
3632
3633
3636
3637
3638
3641
3644
3646
3647
3648
3649
3654
3655
3656
3657
3660
3662
3667
3671
3672
3673
3674
3675
3676
3678
3679
3680
3681
3682
3683
3685
3687
3691
3692
3694
3695
3697
3698
3699
3700
3701
3703
3704
3705
3707
3709
3711
3712
3715
3717
3718
3719
3720
3721
3723
3724
3725
3726
3728
3729
3733
3734
3735
3737
3738
3739
3741
3743
3746
3748
3750
3753
3754
3756
3757
3759
3760
3762
3764
3765
3766
3768
3772
3773
3774
3776
3777
3779
3780
3782
3783
3784
3785
3786
3790
3792
3793
3794
3798
3799
3800
3801
3802
3803
3804
3807
3809
3813
3814
3816
3819
3822
3823
3824
3826
3827
3828
3829
3830
3831
3832
3833
3834
3836
3837
3838
3839
3840
3841
3842
3844
3845
3847
3848
3849
3850
3852
3854
3855
3856
3857
3858
3861
3862
3863
3865
3869
3872
3873
3874
3879
3880
3883
3885
3886
3887
3888
3889
3890
3891
3893
3894
3895
3897
3898
3899
3900
3901
3903
3904
3906
3914
3915
3916
3920
3923
3924
3925
3926
3928
3929
3931
3932
3934
3935
3937
3939
3942
3943
3945
3949
3950
3952
3955
3956
3963
3965
3966
3967
3969
3970
3971
3974
3976
3977
3978
3980
3981
3982
3983
3984
3987
3988
3992
3995
3996
3997
3999


================================================
FILE: documentation/misc/imagematte_train.txt
================================================
10743257206_18e7f44f2e_b.jpg
10845279884_d2d4c7b4d1_b.jpg
1-1252426161dfXY.jpg
1-1255621189mTnS.jpg
1-1259162624NMFK.jpg
1-1259245823Un3j.jpg
11363165393_05d7a21d76_b.jpg
131686738165901828.jpg
13564741125_753939e9ce_o.jpg
14731860273_5b40b19b51_o.jpg
16087-a-young-woman-showing-a-bitten-green-apple-pv.jpg
1609484818_b9bb12b.jpg
17620-a-beautiful-woman-in-a-bikini-pv.jpg
20672673163_20c8467827_b.jpg
3262986095_2d5afe583c_b.jpg
3588101233_f91aa5e3a3.jpg
3858897226_cae5b75963_o.jpg
4889657410_2d503ca287_o.jpg
4981835627_c4e6c4ffa8_o.jpg
5025666458_576b974455_o.jpg
5149410930_3a943dc43f_b.jpg
539641011387760661.jpg
5892503248_4b882863c7_o.jpg
604673748289192179.jpg
606189768665464996.jpg
624753897218113578.jpg
657454154710122500.jpg
664308724952072193.jpg
7669262460_e4be408343_b.jpg
8244818049_dfa59a3eb8_b.jpg
8688417335_01f3bafbe5_o.jpg
9434599749_e7ccfc7812_b.jpg
Aaron_Friedman_Headshot.jpg
arrgh___r___28_by_mjranum_stock.jpg
arrgh___r___29_by_mjranum_stock.jpg
arrgh___r___30_by_mjranum_stock.jpg
a-single-person-1084191_960_720.jpg
ballerina-855652_1920.jpg
beautiful-19075_960_720.jpg
boy-454633_1920.jpg
bride-2819673_1920.jpg
bride-442894_1920.jpg
face-1223346_960_720.jpg
fashion-model-portrait.jpg
fashion-model-pose.jpg
girl-1535859_1920.jpg
Girl_in_front_of_a_green_background.jpg
goth_by_bugidifino-d4w7zms.jpg
h_0.jpg
h_100.jpg
h_101.jpg
h_102.jpg
h_103.jpg
h_104.jpg
h_105.jpg
h_106.jpg
h_107.jpg
h_108.jpg
h_109.jpg
h_10.jpg
h_111.jpg
h_112.jpg
h_113.jpg
h_114.jpg
h_115.jpg
h_116.jpg
h_117.jpg
h_118.jpg
h_119.jpg
h_11.jpg
h_120.jpg
h_121.jpg
h_122.jpg
h_123.jpg
h_124.jpg
h_125.jpg
h_126.jpg
h_127.jpg
h_128.jpg
h_129.jpg
h_12.jpg
h_130.jpg
h_131.jpg
h_132.jpg
h_133.jpg
h_134.jpg
h_135.jpg
h_136.jpg
h_137.jpg
h_138.jpg
h_139.jpg
h_13.jpg
h_140.jpg
h_141.jpg
h_142.jpg
h_143.jpg
h_144.jpg
h_145.jpg
h_146.jpg
h_147.jpg
h_148.jpg
h_149.jpg
h_14.jpg
h_151.jpg
h_152.jpg
h_153.jpg
h_154.jpg
h_155.jpg
h_156.jpg
h_157.jpg
h_158.jpg
h_159.jpg
h_15.jpg
h_160.jpg
h_161.jpg
h_162.jpg
h_163.jpg
h_164.jpg
h_165.jpg
h_166.jpg
h_167.jpg
h_168.jpg
h_169.jpg
h_170.jpg
h_171.jpg
h_172.jpg
h_173.jpg
h_174.jpg
h_175.jpg
h_176.jpg
h_177.jpg
h_178.jpg
h_179.jpg
h_17.jpg
h_180.jpg
h_181.jpg
h_182.jpg
h_183.jpg
h_184.jpg
h_185.jpg
h_186.jpg
h_187.jpg
h_188.jpg
h_189.jpg
h_18.jpg
h_190.jpg
h_191.jpg
h_192.jpg
h_193.jpg
h_194.jpg
h_195.jpg
h_196.jpg
h_197.jpg
h_198.jpg
h_199.jpg
h_19.jpg
h_1.jpg
h_200.jpg
h_201.jpg
h_202.jpg
h_204.jpg
h_205.jpg
h_206.jpg
h_207.jpg
h_208.jpg
h_209.jpg
h_20.jpg
h_210.jpg
h_211.jpg
h_212.jpg
h_213.jpg
h_214.jpg
h_215.jpg
h_216.jpg
h_217.jpg
h_218.jpg
h_219.jpg
h_21.jpg
h_220.jpg
h_221.jpg
h_222.jpg
h_223.jpg
h_224.jpg
h_225.jpg
h_226.jpg
h_227.jpg
h_228.jpg
h_229.jpg
h_22.jpg
h_230.jpg
h_231.jpg
h_232.jpg
h_233.jpg
h_234.jpg
h_235.jpg
h_236.jpg
h_237.jpg
h_238.jpg
h_239.jpg
h_23.jpg
h_240.jpg
h_241.jpg
h_242.jpg
h_243.jpg
h_244.jpg
h_245.jpg
h_247.jpg
h_248.jpg
h_249.jpg
h_24.jpg
h_250.jpg
h_251.jpg
h_252.jpg
h_253.jpg
h_254.jpg
h_255.jpg
h_256.jpg
h_257.jpg
h_258.jpg
h_259.jpg
h_25.jpg
h_260.jpg
h_261.jpg
h_262.jpg
h_263.jpg
h_264.jpg
h_265.jpg
h_266.jpg
h_268.jpg
h_269.jpg
h_26.jpg
h_270.jpg
h_271.jpg
h_272.jpg
h_273.jpg
h_274.jpg
h_276.jpg
h_277.jpg
h_278.jpg
h_279.jpg
h_27.jpg
h_280.jpg
h_281.jpg
h_282.jpg
h_283.jpg
h_284.jpg
h_285.jpg
h_286.jpg
h_287.jpg
h_288.jpg
h_289.jpg
h_28.jpg
h_290.jpg
h_291.jpg
h_292.jpg
h_293.jpg
h_294.jpg
h_295.jpg
h_296.jpg
h_297.jpg
h_298.jpg
h_299.jpg
h_29.jpg
h_300.jpg
h_301.jpg
h_302.jpg
h_303.jpg
h_304.jpg
h_305.jpg
h_307.jpg
h_308.jpg
h_309.jpg
h_30.jpg
h_310.jpg
h_311.jpg
h_312.jpg
h_313.jpg
h_314.jpg
h_315.jpg
h_316.jpg
h_317.jpg
h_318.jpg
h_319.jpg
h_31.jpg
h_320.jpg
h_321.jpg
h_322.jpg
h_323.jpg
h_324.jpg
h_325.jpg
h_326.jpg
h_327.jpg
h_329.jpg
h_32.jpg
h_33.jpg
h_34.jpg
h_35.jpg
h_36.jpg
h_37.jpg
h_38.jpg
h_39.jpg
h_3.jpg
h_40.jpg
h_41.jpg
h_42.jpg
h_43.jpg
h_44.jpg
h_45.jpg
h_46.jpg
h_47.jpg
h_48.jpg
h_49.jpg
h_4.jpg
h_50.jpg
h_51.jpg
h_52.jpg
h_53.jpg
h_54.jpg
h_55.jpg
h_56.jpg
h_57.jpg
h_58.jpg
h_59.jpg
h_5.jpg
h_60.jpg
h_61.jpg
h_62.jpg
h_63.jpg
h_65.jpg
h_67.jpg
h_68.jpg
h_69.jpg
h_6.jpg
h_70.jpg
h_71.jpg
h_72.jpg
h_73.jpg
h_74.jpg
h_75.jpg
h_76.jpg
h_77.jpg
h_78.jpg
h_79.jpg
h_7.jpg
h_80.jpg
h_81.jpg
h_82.jpg
h_83.jpg
h_84.jpg
h_85.jpg
h_86.jpg
h_87.jpg
h_88.jpg
h_89.jpg
h_8.jpg
h_90.jpg
h_91.jpg
h_92.jpg
h_93.jpg
h_94.jpg
h_95.jpg
h_96.jpg
h_97.jpg
h_98.jpg
h_99.jpg
h_9.jpg
hair-flying-142210_1920.jpg
headshotid_by_bokogreat_stock-d355xf3.jpg
lil_white_goth_grl___23_by_mjranum_stock.jpg
lil_white_goth_grl___26_by_mjranum_stock.jpg
man-388104_960_720.jpg
man_headshot.jpg
MFettes-headshot.jpg
model-429733_960_720.jpg
model-610352_960_720.jpg
model-858753_960_720.jpg
model-858755_960_720.jpg
model-873675_960_720.jpg
model-873678_960_720.jpg
model-873690_960_720.jpg
model-881425_960_720.jpg
model-881431_960_720.jpg
model-female-girl-beautiful-51969.jpg
Model_in_green_dress_3.jpg
Modern_shingle_bob_haircut.jpg
Motivate_(Fitness_model).jpg
Official_portrait_of_Barack_Obama.jpg
person-woman-eyes-face.jpg
pink-hair-855660_960_720.jpg
portrait-750774_1920.jpg
Professor_Steven_Chu_ForMemRS_headshot.jpg
sailor_flying_4_by_senshistock-d4k2wmr.jpg
skin-care-937667_960_720.jpg
sorcery___8_by_mjranum_stock.jpg
t_62.jpg
t_65.jpg
test_32.jpg
test_8.jpg
train_245.jpg
train_246.jpg
train_255.jpg
train_304.jpg
train_333.jpg
train_361.jpg
train_395.jpg
train_480.jpg
train_488.jpg
train_539.jpg
wedding-846926_1920.jpg
Wild_hair.jpg
with_wings___pose_reference_by_senshistock-d6by42n_2.jpg
with_wings___pose_reference_by_senshistock-d6by42n.jpg
woman-1138435_960_720.jpg
woman1.jpg
woman2.jpg
woman-659354_960_720.jpg
woman-804072_960_720.jpg
woman-868519_960_720.jpg
Woman_in_white_shirt_on_August_2009_02.jpg
women-878869_1920.jpg


================================================
FILE: documentation/misc/imagematte_valid.txt
================================================
13564741125_753939e9ce_o.jpg
3858897226_cae5b75963_o.jpg
538724499685900405.jpg
ballerina-855652_1920.jpg
boy-454633_1920.jpg
h_110.jpg
h_150.jpg
h_16.jpg
h_246.jpg
h_267.jpg
h_275.jpg
h_306.jpg
h_328.jpg
model-610352_960_720.jpg
t_66.jpg


================================================
FILE: documentation/misc/spd_preprocess.py
================================================
# pip install supervisely
import supervisely_lib as sly
import numpy as np
import os
from PIL import Image
from tqdm import tqdm

# Download dataset from <https://supervise.ly/explore/projects/supervisely-person-dataset-23304/datasets>
project_root = 'PATH_TO/Supervisely Person Dataset'  # <-- Configure input
project = sly.Project(project_root, sly.OpenMode.READ)

output_path = 'OUTPUT_DIR'  # <-- Configure output
os.makedirs(os.path.join(output_path, 'train', 'src'))
os.makedirs(os.path.join(output_path, 'train', 'msk'))
os.makedirs(os.path.join(output_path, 'valid', 'src'))
os.makedirs(os.path.join(output_path, 'valid', 'msk'))

max_size = 2048  # <-- Configure max size

for dataset in project.datasets:
    for item in tqdm(dataset):
        ann = sly.Annotation.load_json_file(dataset.get_ann_path(item), project.meta)
        msk = np.zeros(ann.img_size, dtype=np.uint8)
        for label in ann.labels:
            label.geometry.draw(msk, color=[255])
        msk = Image.fromarray(msk)
        
        img = Image.open(dataset.get_img_path(item)).convert('RGB')
        if img.size[0] > max_size or img.size[1] > max_size:
            scale = max_size / max(img.size)
            img = img.resize((int(img.size[0] * scale), int(img.size[1] * scale)), Image.BILINEAR)
            msk = msk.resize((int(msk.size[0] * scale), int(msk.size[1] * scale)), Image.NEAREST)
        
        img.save(os.path.join(output_path, 'train', 'src', item.replace('.png', '.jpg')))
        msk.save(os.path.join(output_path, 'train', 'msk', item.replace('.png', '.jpg')))

# Move the first 100 (in sorted order, so the split is reproducible) to the validation set
names = sorted(os.listdir(os.path.join(output_path, 'train', 'src')))
for name in tqdm(names[:100]):
    os.rename(
        os.path.join(output_path, 'train', 'src', name),
        os.path.join(output_path, 'valid', 'src', name))
    os.rename(
        os.path.join(output_path, 'train', 'msk', name),
        os.path.join(output_path, 'valid', 'msk', name))

================================================
FILE: documentation/training.md
================================================
# Training Documentation

This documentation only shows how to reproduce our [paper](https://peterl1n.github.io/RobustVideoMatting/). If you would like to remove or add a dataset to the training, you are responsible for adapting the training code yourself.

## Datasets

The following datasets are used during our training.

**IMPORTANT: If you choose to download our preprocessed versions, please avoid repeated downloads and cache the data locally. All traffic is at our expense. Please be responsible. We may only provide the preprocessed versions for a limited time.**
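
Since the preprocessed archives are large and their hosting is paid for by the authors, it is worth fetching each one exactly once. Below is a minimal standard-library sketch of one way to do that: it downloads a `.tar` into a local cache directory only if it is not already there, then extracts it only once. The function name `fetch_and_extract`, the `.part` temporary file and the `.extracted` marker are our own hypothetical conventions, not part of this repository.

```python
import pathlib
import tarfile
import urllib.parse
import urllib.request


def fetch_and_extract(url: str, cache_dir: str) -> pathlib.Path:
    """Download a .tar archive into cache_dir at most once, then extract it at most once."""
    cache = pathlib.Path(cache_dir)
    cache.mkdir(parents=True, exist_ok=True)
    archive = cache / pathlib.PurePosixPath(urllib.parse.urlparse(url).path).name
    target = cache / archive.stem  # e.g. SuperviselyPersonDataset.tar -> SuperviselyPersonDataset/

    if not archive.exists():
        # Download under a temporary name so an interrupted transfer is never
        # mistaken for a complete archive on the next run.
        partial = archive.with_name(archive.name + ".part")
        urllib.request.urlretrieve(url, str(partial))
        partial.rename(archive)

    marker = cache / (archive.name + ".extracted")
    if not marker.exists():
        with tarfile.open(archive) as tar:
            if hasattr(tarfile, "data_filter"):
                # Newer Pythons support extraction filters; "data" rejects unsafe members.
                tar.extractall(target, filter="data")
            else:
                tar.extractall(target)
        marker.touch()
    return target
```

For example, `fetch_and_extract("https://robustvideomatting.blob.core.windows.net/data/SuperviselyPersonDataset.tar", "matting-data")` downloads on the first call and returns immediately on later calls. The sketch makes no assumption about the layout inside each archive, so inspect the extracted folder before pointing `train_config.py` at it.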

### Matting Datasets
* [VideoMatte240K](https://grail.cs.washington.edu/projects/background-matting-v2/#/datasets)
    * Download JPEG SD version (6G) for stages 1 and 2.
    * Download JPEG HD version (60G) for stages 3 and 4.
    * Manually move clips `0000`, `0100`, `0200`, `0300` from the training set to a validation set.
* ImageMatte
    * ImageMatte consists of [Distinctions-646](https://wukaoliu.github.io/HAttMatting/) and [Adobe Image Matting](https://sites.google.com/view/deepimagematting) datasets.
    * Only needed for stage 4.
    * You need to contact their authors to acquire them.
    * After downloading both datasets, merge their samples together to form the ImageMatte dataset.
    * Only keep samples of humans.
    * Full list of images we used in ImageMatte for training:
        * [imagematte_train.txt](/documentation/misc/imagematte_train.txt)
        * [imagematte_valid.txt](/documentation/misc/imagematte_valid.txt)
    * Full list of images we used for evaluation:
        * [aim_test.txt](/documentation/misc/aim_test.txt)
        * [d646_test.txt](/documentation/misc/d646_test.txt)
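
The manual preparation steps for VideoMatte240K and ImageMatte above can be scripted. The sketch below assumes that VideoMatte240K stores each clip as a folder under `train/fgr/` and `train/pha/`, and that the merged, human-only ImageMatte pool has matching `fgr/` and `pha/` folders with identical file names. Both layouts are assumptions, so check them against your download and against `dataset/videomatte.py` and `dataset/imagematte.py`. The function names are hypothetical.

```python
import pathlib
import shutil

# Clips named above for the VideoMatte240K validation split.
VIDEOMATTE_VALID_CLIPS = ("0000", "0100", "0200", "0300")


def move_videomatte_valid_clips(root: str, clips=VIDEOMATTE_VALID_CLIPS) -> None:
    """Move clips from <root>/train/{fgr,pha}/<clip> to <root>/valid/{fgr,pha}/<clip>.

    ASSUMPTION: one folder per clip under train/fgr and train/pha.
    """
    root_path = pathlib.Path(root)
    for kind in ("fgr", "pha"):
        dst_dir = root_path / "valid" / kind
        dst_dir.mkdir(parents=True, exist_ok=True)
        for clip in clips:
            src = root_path / "train" / kind / clip
            if src.exists():  # skip clips that are absent, e.g. already moved
                shutil.move(str(src), str(dst_dir / clip))


def split_imagematte(pool: str, out: str, list_files: dict) -> None:
    """Copy the samples named in each list file into <out>/<split>/{fgr,pha}.

    list_files maps a split name ("train", "valid") to a text file with one
    image name per line, such as imagematte_train.txt.
    ASSUMPTION: <pool>/fgr and <pool>/pha hold the merged samples under
    identical file names. A missing name raises FileNotFoundError.
    """
    pool_path, out_path = pathlib.Path(pool), pathlib.Path(out)
    for split, list_file in list_files.items():
        lines = pathlib.Path(list_file).read_text().splitlines()
        names = [line.strip() for line in lines if line.strip()]
        for kind in ("fgr", "pha"):
            dst_dir = out_path / split / kind
            dst_dir.mkdir(parents=True, exist_ok=True)
            for name in names:
                shutil.copy2(pool_path / kind / name, dst_dir / name)
```

`split_imagematte` copies rather than moves on purpose: a few names (for example `ballerina-855652_1920.jpg`) appear in both `imagematte_train.txt` and `imagematte_valid.txt`.
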

### Background Datasets
* Video Backgrounds
    * We processed the [DVM Background Set](https://github.com/nowsyn/DVM) by selecting clips without humans and extracting only the first 100 frames of each clip as a JPEG sequence.
    * Full list of clips we used:
        * [dvm_background_train_clips.txt](/documentation/misc/dvm_background_train_clips.txt)
        * [dvm_background_test_clips.txt](/documentation/misc/dvm_background_test_clips.txt)
    * You can download our preprocessed versions:
        * [Train set (14.6G)](https://robustvideomatting.blob.core.windows.net/data/BackgroundVideosTrain.tar) (Manually move some clips to a validation set.)
        * [Test set (936M)](https://robustvideomatting.blob.core.windows.net/data/BackgroundVideosTest.tar) (Not needed for training; only used for making synthetic test samples for evaluation.)
* Image Backgrounds
    * Train set:
        * We crawled 8000 suitable images from Google and Flickr.
        * We will not publish these images.
    * [Test set](https://grail.cs.washington.edu/projects/background-matting-v2/#/datasets)
        * We use the validation background set from [BGMv2](https://grail.cs.washington.edu/projects/background-matting-v2/) project.
        * It contains about 200 images.
        * It is not used in our training; it is only used for making synthetic test samples for evaluation.
        * However, if you just want to try out training quickly, you may use it as a temporary substitute for the train set.
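
The background video train set asks you to move some clips to a validation set by hand, without saying how many. A reproducible way to do that is sketched below. It assumes the extracted train archive contains one folder per clip; `num_valid` and `seed` are choices you have to make yourself, not values taken from the paper, and the function name is hypothetical.

```python
import pathlib
import random
import shutil


def split_background_clips(train_dir: str, valid_dir: str, num_valid: int, seed: int = 0) -> list:
    """Move `num_valid` randomly chosen clip folders from train_dir to valid_dir.

    Clip names are sorted before sampling, so a given seed always selects the
    same clips regardless of the order the filesystem lists them in.
    """
    train_path, valid_path = pathlib.Path(train_dir), pathlib.Path(valid_dir)
    clips = sorted(p.name for p in train_path.iterdir() if p.is_dir())
    if num_valid > len(clips):
        raise ValueError(f"asked for {num_valid} clips, only {len(clips)} available")
    chosen = sorted(random.Random(seed).sample(clips, num_valid))
    valid_path.mkdir(parents=True, exist_ok=True)
    for name in chosen:
        shutil.move(str(train_path / name), str(valid_path / name))
    return chosen
```

Returning the chosen names makes it easy to record the split next to your checkpoints.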

### Segmentation Datasets

* [COCO](https://cocodataset.org/#download)
    * Download [train2017.zip (18G)](http://images.cocodataset.org/zips/train2017.zip)
    * Download [panoptic_annotations_trainval2017.zip (821M)](http://images.cocodataset.org/annotations/panoptic_annotations_trainval2017.zip)
    * Note that our train script expects the panoptic version.
* [YouTubeVIS 2021](https://youtube-vos.org/dataset/vis/)
    * Download the train set. No preprocessing needed.
* [Supervisely Person Dataset](https://supervise.ly/explore/projects/supervisely-person-dataset-23304/datasets)
    * We used the Supervisely library to convert their encoding to bitmap masks before using our script. We also downscaled some of the large images to avoid a disk-loading bottleneck.
    * You can refer to [spd_preprocess.py](/documentation/misc/spd_preprocess.py)
    * Or, you can download our [preprocessed version (800M)](https://robustvideomatting.blob.core.windows.net/data/SuperviselyPersonDataset.tar)
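
Because the train script expects the panoptic version of COCO, it may help to see how a person mask is recovered from it. In the COCO panoptic format, each PNG pixel encodes a segment id as `R + 256 * G + 256**2 * B`, and each annotation's `segments_info` maps segment ids to category ids, with `person` having category id 1. The sketch below builds a binary person mask from those two pieces. It only illustrates the format and is not a copy of what `dataset/coco.py` does; the function names are hypothetical.

```python
import numpy as np


def rgb2id(color: np.ndarray) -> np.ndarray:
    """Decode a COCO panoptic PNG (H, W, 3 uint8) into per-pixel segment ids."""
    color = color.astype(np.int64)  # widen first so the products cannot overflow uint8
    return color[..., 0] + 256 * color[..., 1] + 256 * 256 * color[..., 2]


def person_mask(panoptic_png: np.ndarray, segments_info: list, person_category_id: int = 1) -> np.ndarray:
    """Union of all segments whose category is 'person', as a uint8 mask of 0 or 255."""
    ids = rgb2id(panoptic_png)
    person_ids = [s["id"] for s in segments_info if s["category_id"] == person_category_id]
    return np.isin(ids, person_ids).astype(np.uint8) * 255
```

Typical use: load `panoptic_train2017.json`, take an entry from its `annotations` list, open the PNG named by that entry's `file_name` with `PIL.Image.open(...).convert("RGB")`, convert it with `np.array`, and pass it together with the entry's `segments_info`.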

## Training

For reference, our training was done on data center machines with 48 CPU cores, 300G CPU memory, and 4 Nvidia V100 32G GPUs.

Our official training code contained custom logic for our infrastructure, and the script was cleaned up for release. As a result, this version may contain bugs that were not present in our official training. If you find problems, please file an issue.

After you have downloaded the datasets, please configure `train_config.py` to provide the paths to your datasets.

The training consists of 4 stages. For details, please refer to the [paper](https://peterl1n.github.io/RobustVideoMatting/).

### Stage 1
```sh
python train.py \
    --model-variant mobilenetv3 \
    --dataset videomatte \
    --resolution-lr 512 \
    --seq-length-lr 15 \
    --learning-rate-backbone 0.0001 \
    --learning-rate-aspp 0.0002 \
    --learning-rate-decoder 0.0002 \
    --learning-rate-refiner 0 \
    --checkpoint-dir checkpoint/stage1 \
    --log-dir log/stage1 \
    --epoch-start 0 \
    --epoch-end 20
```

### Stage 2
```sh
python train.py \
    --model-variant mobilenetv3 \
    --dataset videomatte \
    --resolution-lr 512 \
    --seq-length-lr 50 \
    --learning-rate-backbone 0.00005 \
    --learning-rate-aspp 0.0001 \
    --learning-rate-decoder 0.0001 \
    --learning-rate-refiner 0 \
    --checkpoint checkpoint/stage1/epoch-19.pth \
    --checkpoint-dir checkpoint/stage2 \
    --log-dir log/stage2 \
    --epoch-start 20 \
    --epoch-end 22
```

### Stage 3
```sh
python train.py \
    --model-variant mobilenetv3 \
    --dataset videomatte \
    --train-hr \
    --resolution-lr 512 \
    --resolution-hr 2048 \
    --seq-length-lr 40 \
    --seq-length-hr 6 \
    --learning-rate-backbone 0.00001 \
    --learning-rate-aspp 0.00001 \
    --learning-rate-decoder 0.00001 \
    --learning-rate-refiner 0.0002 \
    --checkpoint checkpoint/stage2/epoch-21.pth \
    --checkpoint-dir checkpoint/stage3 \
    --log-dir log/stage3 \
    --epoch-start 22 \
    --epoch-end 23
```

### Stage 4
```sh
python train.py \
    --model-variant mobilenetv3 \
    --dataset imagematte \
    --train-hr \
    --resolution-lr 512 \
    --resolution-hr 2048 \
    --seq-length-lr 40 \
    --seq-length-hr 6 \
    --learning-rate-backbone 0.00001 \
    --learning-rate-aspp 0.00001 \
    --learning-rate-decoder 0.00005 \
    --learning-rate-refiner 0.0002 \
    --checkpoint checkpoint/stage3/epoch-22.pth \
    --checkpoint-dir checkpoint/stage4 \
    --log-dir log/stage4 \
    --epoch-start 23 \
    --epoch-end 28
```
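
The four stages above differ in only a handful of values, and each stage resumes from the last checkpoint of the previous one: `epoch-19.pth` after a stage run with `--epoch-end 20`, `epoch-21.pth` after `--epoch-end 22`, and `epoch-22.pth` after `--epoch-end 23`. That pattern is consistent with the epoch range being half-open and checkpoints being named after the 0-based epoch, which is our reading of the numbers rather than documented behaviour of `train.py`. The sketch below rebuilds the commands above from a small table so the stages can be printed or chained; `STAGES` and `stage_command` are hypothetical names.

```python
import shlex

# Per-stage settings copied from the four commands above. Learning rates are
# kept as strings so they are passed through exactly as written.
STAGES = {
    1: dict(dataset="videomatte", seq_lr=15, lrs=("0.0001", "0.0002", "0.0002", "0"), epochs=(0, 20), hr=False),
    2: dict(dataset="videomatte", seq_lr=50, lrs=("0.00005", "0.0001", "0.0001", "0"), epochs=(20, 22), hr=False),
    3: dict(dataset="videomatte", seq_lr=40, lrs=("0.00001", "0.00001", "0.00001", "0.0002"), epochs=(22, 23), hr=True),
    4: dict(dataset="imagematte", seq_lr=40, lrs=("0.00001", "0.00001", "0.00005", "0.0002"), epochs=(23, 28), hr=True),
}


def stage_command(stage: int) -> list:
    """Return the argv list for one training stage, in the same flag order as above."""
    cfg = STAGES[stage]
    start, end = cfg["epochs"]
    cmd = ["python", "train.py", "--model-variant", "mobilenetv3", "--dataset", cfg["dataset"]]
    if cfg["hr"]:
        cmd += ["--train-hr"]
    cmd += ["--resolution-lr", "512"]
    if cfg["hr"]:
        cmd += ["--resolution-hr", "2048"]
    cmd += ["--seq-length-lr", str(cfg["seq_lr"])]
    if cfg["hr"]:
        cmd += ["--seq-length-hr", "6"]
    for part, lr in zip(("backbone", "aspp", "decoder", "refiner"), cfg["lrs"]):
        cmd += [f"--learning-rate-{part}", lr]
    if stage > 1:
        # The previous stage ended at `start`, so its last checkpoint is epoch-(start - 1).
        cmd += ["--checkpoint", f"checkpoint/stage{stage - 1}/epoch-{start - 1}.pth"]
    cmd += ["--checkpoint-dir", f"checkpoint/stage{stage}", "--log-dir", f"log/stage{stage}",
            "--epoch-start", str(start), "--epoch-end", str(end)]
    return cmd


if __name__ == "__main__":
    for s in sorted(STAGES):
        print(shlex.join(stage_command(s)))
```

To run the stages back to back instead of printing them, call `subprocess.run(stage_command(s), check=True)` for each stage; `check=True` stops the chain if a stage fails, so a later stage never starts from a missing checkpoint.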

<br><br><br>

## Evaluation

We synthetically composite test samples onto both image and video backgrounds. Image samples (from D646, AIM) are augmented with synthetic motion.

We only provide the composited VideoMatte240K test set, which is the one used in our paper's evaluation. For D646 and AIM, you need to acquire the data from their authors and composite it yourself. The composition scripts we used are kept in the `/evaluation` folder as a reference backup; you will need to modify them for your setup.
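
Compositing a test sample onto a background follows the standard alpha compositing equation, `comp = fgr * pha + bgr * (1 - pha)`. A minimal NumPy sketch is below. It assumes float images in `[0, 1]`, the function names are hypothetical, and it does not reproduce the synthetic motion augmentation; the scripts in `/evaluation` remain the reference for that.

```python
import numpy as np


def composite(fgr: np.ndarray, pha: np.ndarray, bgr: np.ndarray) -> np.ndarray:
    """Alpha-composite fgr over bgr: fgr * pha + bgr * (1 - pha).

    fgr, bgr: float arrays in [0, 1], shape (H, W, 3).
    pha:      float array in [0, 1], shape (H, W) or (H, W, 1).
    """
    if pha.ndim == 2:
        pha = pha[..., None]  # add a channel axis so alpha broadcasts over RGB
    return fgr * pha + bgr * (1.0 - pha)


def to_uint8(img: np.ndarray) -> np.ndarray:
    """Round a [0, 1] float image to uint8 so it can be written as PNG/JPEG."""
    return (np.clip(img, 0.0, 1.0) * 255.0 + 0.5).astype(np.uint8)
```

For a video background, the same call is repeated per frame with that frame's `bgr`. The `pha` and `fgr` that went into each composite are the ground truth that predictions are later compared against.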

* [videomatte_512x288.tar (PNG 1.8G)](https://robustvideomatting.blob.core.windows.net/eval/videomatte_512x288.tar)
* [videomatte_1920x1080.tar (JPG 2.2G)](https://robustvideomatting.blob.core.windows.net/eval/videomatte_1920x1080.tar)

Evaluation scripts are provided in the `/evaluation` folder.
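
Both evaluation scripts walk `<pred-dir>/<dataset>/<clip>/pha`, read the ground-truth frame with the same relative path from `--true-dir`, and, for the foreground metrics, look up `fgr` frames under the same names. A missing file therefore only surfaces midway through a run, as an error after a failed `cv2.imread`. The hypothetical helper below performs the same walk up front and lists every file the scripts would try to read but cannot find.

```python
import os


def find_missing(pred_dir: str, true_dir: str, check_fgr: bool = True) -> list:
    """List files the evaluation scripts would try to read but cannot find.

    Mirrors the traversal in evaluate_worker: frames are enumerated from
    <pred_dir>/<dataset>/<clip>/pha, and every other path is derived from
    the same dataset, clip, and frame name.
    """
    missing = []
    for dataset in sorted(os.listdir(pred_dir)):
        if not os.path.isdir(os.path.join(pred_dir, dataset)):
            continue  # e.g. the .xlsx written into pred_dir by a previous run
        for clip in sorted(os.listdir(os.path.join(pred_dir, dataset))):
            pha_dir = os.path.join(pred_dir, dataset, clip, "pha")
            for frame in sorted(os.listdir(pha_dir)):
                needed = [os.path.join(true_dir, dataset, clip, "pha", frame)]
                if check_fgr:  # fgr frames are only read for fgr_mse / fgr_mad
                    needed += [
                        os.path.join(pred_dir, dataset, clip, "fgr", frame),
                        os.path.join(true_dir, dataset, clip, "fgr", frame),
                    ]
                missing += [path for path in needed if not os.path.isfile(path)]
    return missing
```

For example, `find_missing("pred/videomatte_1920x1080", "true/videomatte_1920x1080")` should return an empty list before you start `evaluate_hr.py`.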

================================================
FILE: evaluation/evaluate_hr.py
================================================
"""
HR (High-Resolution) evaluation. We found numpy to be very slow at high resolution, so this script computes the metrics with PyTorch on CUDA.

Note: the script only performs evaluation. You need to run inference yourself first and save the results to disk.
Expected directory format for both prediction and ground-truth is:

    videomatte_1920x1080
        ├── videomatte_motion
          ├── pha
            ├── 0000
              ├── 0000.png
          ├── fgr
            ├── 0000
              ├── 0000.png
        ├── videomatte_static
          ├── pha
            ├── 0000
              ├── 0000.png
          ├── fgr
            ├── 0000
              ├── 0000.png

Prediction must have the exact file structure and file name as the ground-truth,
meaning that if the ground-truth is png/jpg, prediction should be png/jpg.

Example usage:

python evaluate_hr.py \
    --pred-dir pred/videomatte_1920x1080 \
    --true-dir true/videomatte_1920x1080

An Excel sheet with evaluation results will be written to "pred/videomatte_1920x1080/videomatte_1920x1080.xlsx"
"""


import argparse
import os
import cv2
import kornia
import numpy as np
import xlsxwriter
import torch
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm


class Evaluator:
    def __init__(self):
        self.parse_args()
        self.init_metrics()
        self.evaluate()
        self.write_excel()
        
    def parse_args(self):
        parser = argparse.ArgumentParser()
        parser.add_argument('--pred-dir', type=str, required=True)
        parser.add_argument('--true-dir', type=str, required=True)
        parser.add_argument('--num-workers', type=int, default=48)
        parser.add_argument('--metrics', type=str, nargs='+', default=[
            'pha_mad', 'pha_mse', 'pha_grad', 'pha_dtssd', 'fgr_mse'])
        self.args = parser.parse_args()
        
    def init_metrics(self):
        self.mad = MetricMAD()
        self.mse = MetricMSE()
        self.grad = MetricGRAD()
        self.dtssd = MetricDTSSD()
        
    def evaluate(self):
        tasks = []
        position = 0
        
        with ThreadPoolExecutor(max_workers=self.args.num_workers) as executor:
            for dataset in sorted(os.listdir(self.args.pred_dir)):
                if os.path.isdir(os.path.join(self.args.pred_dir, dataset)):
                    for clip in sorted(os.listdir(os.path.join(self.args.pred_dir, dataset))):
                        future = executor.submit(self.evaluate_worker, dataset, clip, position)
                        tasks.append((dataset, clip, future))
                        position += 1
                    
        self.results = [(dataset, clip, future.result()) for dataset, clip, future in tasks]
        
    def write_excel(self):
        workbook = xlsxwriter.Workbook(os.path.join(self.args.pred_dir, f'{os.path.basename(self.args.pred_dir)}.xlsx'))
        summarysheet = workbook.add_worksheet('summary')
        metricsheets = [workbook.add_worksheet(metric) for metric in self.results[0][2].keys()]
        
        for i, metric in enumerate(self.results[0][2].keys()):
            summarysheet.write(i, 0, metric)
            summarysheet.write(i, 1, f'={metric}!B2')
        
        for row, (dataset, clip, metrics) in enumerate(self.results):
            for metricsheet, metric in zip(metricsheets, metrics.values()):
                # Write the header
                if row == 0:
                    metricsheet.write(1, 0, 'Average')
                    metricsheet.write(1, 1, f'=AVERAGE(C2:ZZ2)')
                    for col in range(len(metric)):
                        metricsheet.write(0, col + 2, col)
                        colname = xlsxwriter.utility.xl_col_to_name(col + 2)
                        metricsheet.write(1, col + 2, f'=AVERAGE({colname}3:{colname}9999)')
                        
                metricsheet.write(row + 2, 0, dataset)
                metricsheet.write(row + 2, 1, clip)
                metricsheet.write_row(row + 2, 2, metric)
        
        workbook.close()

    def evaluate_worker(self, dataset, clip, position):
        framenames = sorted(os.listdir(os.path.join(self.args.pred_dir, dataset, clip, 'pha')))
        metrics = {metric_name : [] for metric_name in self.args.metrics}
        
        pred_pha_tm1 = None
        true_pha_tm1 = None
        
        for i, framename in enumerate(tqdm(framenames, desc=f'{dataset} {clip}', position=position, dynamic_ncols=True)):
            true_pha = cv2.imread(os.path.join(self.args.true_dir, dataset, clip, 'pha', framename), cv2.IMREAD_GRAYSCALE)
            pred_pha = cv2.imread(os.path.join(self.args.pred_dir, dataset, clip, 'pha', framename), cv2.IMREAD_GRAYSCALE)
            
            true_pha = torch.from_numpy(true_pha).cuda(non_blocking=True).float().div_(255)
            pred_pha = torch.from_numpy(pred_pha).cuda(non_blocking=True).float().div_(255)
            
            if 'pha_mad' in self.args.metrics:
                metrics['pha_mad'].append(self.mad(pred_pha, true_pha))
            if 'pha_mse' in self.args.metrics:
                metrics['pha_mse'].append(self.mse(pred_pha, true_pha))
            if 'pha_grad' in self.args.metrics:
                metrics['pha_grad'].append(self.grad(pred_pha, true_pha))
            if 'pha_conn' in self.args.metrics:
                metrics['pha_conn'].append(self.conn(pred_pha, true_pha))
            if 'pha_dtssd' in self.args.metrics:
                if i == 0:
                    metrics['pha_dtssd'].append(0)
                else:
                    metrics['pha_dtssd'].append(self.dtssd(pred_pha, pred_pha_tm1, true_pha, true_pha_tm1))
                    
            pred_pha_tm1 = pred_pha
            true_pha_tm1 = true_pha
            
            if 'fgr_mse' in self.args.metrics:
                true_fgr = cv2.imread(os.path.join(self.args.true_dir, dataset, clip, 'fgr', framename), cv2.IMREAD_COLOR)
                pred_fgr = cv2.imread(os.path.join(self.args.pred_dir, dataset, clip, 'fgr', framename), cv2.IMREAD_COLOR)
                
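                # Move to the GPU like the alpha tensors, so the CUDA mask below can index them.
                true_fgr = torch.from_numpy(true_fgr).cuda(non_blocking=True).float().div_(255)
                pred_fgr = torch.from_numpy(pred_fgr).cuda(non_blocking=True).float().div_(255)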
                
                true_msk = true_pha > 0
                metrics['fgr_mse'].append(self.mse(pred_fgr[true_msk], true_fgr[true_msk]))

        return metrics


class MetricMAD:
    def __call__(self, pred, true):
        return (pred - true).abs_().mean() * 1e3


class MetricMSE:
    def __call__(self, pred, true):
        return ((pred - true) ** 2).mean() * 1e3


class MetricGRAD:
    def __init__(self, sigma=1.4):
        self.filter_x, self.filter_y = self.gauss_filter(sigma)
        self.filter_x = torch.from_numpy(self.filter_x).unsqueeze(0).cuda()
        self.filter_y = torch.from_numpy(self.filter_y).unsqueeze(0).cuda()
    
    def __call__(self, pred, true):
        true_grad = self.gauss_gradient(true)
        pred_grad = self.gauss_gradient(pred)
        return ((true_grad - pred_grad) ** 2).sum() / 1000
    
    def gauss_gradient(self, img):
        img_filtered_x = kornia.filters.filter2D(img[None, None, :, :], self.filter_x, border_type='replicate')[0, 0]
        img_filtered_y = kornia.filters.filter2D(img[None, None, :, :], self.filter_y, border_type='replicate')[0, 0]
        return (img_filtered_x**2 + img_filtered_y**2).sqrt()
    
    @staticmethod
    def gauss_filter(sigma, epsilon=1e-2):
        half_size = np.ceil(sigma * np.sqrt(-2 * np.log(np.sqrt(2 * np.pi) * sigma * epsilon)))
        size = int(2 * half_size + 1)  # builtin int: np.int was removed in NumPy 1.24

        # create filter in x axis
        filter_x = np.zeros((size, size))
        for i in range(size):
            for j in range(size):
                filter_x[i, j] = MetricGRAD.gaussian(i - half_size, sigma) * MetricGRAD.dgaussian(
                    j - half_size, sigma)

        # normalize filter
        norm = np.sqrt((filter_x**2).sum())
        filter_x = filter_x / norm
        filter_y = np.transpose(filter_x)

        return filter_x, filter_y
        
    @staticmethod
    def gaussian(x, sigma):
        return np.exp(-x**2 / (2 * sigma**2)) / (sigma * np.sqrt(2 * np.pi))
    
    @staticmethod
    def dgaussian(x, sigma):
        return -x * MetricGRAD.gaussian(x, sigma) / sigma**2


class MetricDTSSD:
    def __call__(self, pred_t, pred_tm1, true_t, true_tm1):
        dtSSD = ((pred_t - pred_tm1) - (true_t - true_tm1)) ** 2
        dtSSD = dtSSD.sum() / true_t.numel()
        dtSSD = dtSSD.sqrt()
        return dtSSD * 1e2


if __name__ == '__main__':
    Evaluator()

================================================
FILE: evaluation/evaluate_lr.py
================================================
"""
LR (Low-Resolution) evaluation.

Note: the script only performs evaluation. You need to run inference yourself first and save the results to disk.
Expected directory format for both prediction and ground-truth is:

    videomatte_512x288
        ├── videomatte_motion
          ├── pha
            ├── 0000
              ├── 0000.png
          ├── fgr
            ├── 0000
              ├── 0000.png
        ├── videomatte_static
          ├── pha
            ├── 0000
              ├── 0000.png
          ├── fgr
            ├── 0000
              ├── 0000.png

Prediction must have the exact file structure and file name as the ground-truth,
meaning that if the ground-truth is png/jpg, prediction should be png/jpg.

Example usage:

python evaluate_lr.py \
    --pred-dir PATH_TO_PREDICTIONS/videomatte_512x288 \
    --true-dir PATH_TO_GROUNDTRUTH/videomatte_512x288

An Excel sheet with evaluation results will be written to "PATH_TO_PREDICTIONS/videomatte_512x288/videomatte_512x288.xlsx"
"""


import argparse
import os
import cv2
import numpy as np
import xlsxwriter
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm


class Evaluator:
    def __init__(self):
        self.parse_args()
        self.init_metrics()
        self.evaluate()
        self.write_excel()
        
    def parse_args(self):
        parser = argparse.ArgumentParser()
        parser.add_argument('--pred-dir', type=str, required=True)
        parser.add_argument('--true-dir', type=str, required=True)
        parser.add_argument('--num-workers', type=int, default=48)
        parser.add_argument('--metrics', type=str, nargs='+', default=[
            'pha_mad', 'pha_mse', 'pha_grad', 'pha_conn', 'pha_dtssd', 'fgr_mad', 'fgr_mse'])
        self.args = parser.parse_args()
        
    def init_metrics(self):
        self.mad = MetricMAD()
        self.mse = MetricMSE()
        self.grad = MetricGRAD()
        self.conn = MetricCONN()
        self.dtssd = MetricDTSSD()
        
    def evaluate(self):
        tasks = []
        position = 0
        
        with ThreadPoolExecutor(max_workers=self.args.num_workers) as executor:
            for dataset in sorted(os.listdir(self.args.pred_dir)):
                if os.path.isdir(os.path.join(self.args.pred_dir, dataset)):
                    for clip in sorted(os.listdir(os.path.join(self.args.pred_dir, dataset))):
                        future = executor.submit(self.evaluate_worker, dataset, clip, position)
                        tasks.append((dataset, clip, future))
                        position += 1
                    
        self.results = [(dataset, clip, future.result()) for dataset, clip, future in tasks]
        
    def write_excel(self):
        workbook = xlsxwriter.Workbook(os.path.join(self.args.pred_dir, f'{os.path.basename(self.args.pred_dir)}.xlsx'))
        summarysheet = workbook.add_worksheet('summary')
        metricsheets = [workbook.add_worksheet(metric) for metric in self.results[0][2].keys()]
        
        for i, metric in enumerate(self.results[0][2].keys()):
            summarysheet.write(i, 0, metric)
            summarysheet.write(i, 1, f'={metric}!B2')
        
        for row, (dataset, clip, metrics) in enumerate(self.results):
            for metricsheet, metric in zip(metricsheets, metrics.values()):
                # Write the header
                if row == 0:
                    metricsheet.write(1, 0, 'Average')
                    metricsheet.write(1, 1, f'=AVERAGE(C2:ZZ2)')
                    for col in range(len(metric)):
                        metricsheet.write(0, col + 2, col)
                        colname = xlsxwriter.utility.xl_col_to_name(col + 2)
                        metricsheet.write(1, col + 2, f'=AVERAGE({colname}3:{colname}9999)')
                        
                metricsheet.write(row + 2, 0, dataset)
                metricsheet.write(row + 2, 1, clip)
                metricsheet.write_row(row + 2, 2, metric)
        
        workbook.close()

    def evaluate_worker(self, dataset, clip, position):
        framenames = sorted(os.listdir(os.path.join(self.args.pred_dir, dataset, clip, 'pha')))
        metrics = {metric_name : [] for metric_name in self.args.metrics}
        
        pred_pha_tm1 = None
        true_pha_tm1 = None
        
        for i, framename in enumerate(tqdm(framenames, desc=f'{dataset} {clip}', position=position, dynamic_ncols=True)):
            true_pha = cv2.imread(os.path.join(self.args.true_dir, dataset, clip, 'pha', framename), cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255
            pred_pha = cv2.imread(os.path.join(self.args.pred_dir, dataset, clip, 'pha', framename), cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255
            if 'pha_mad' in self.args.metrics:
                metrics['pha_mad'].append(self.mad(pred_pha, true_pha))
            if 'pha_mse' in self.args.metrics:
                metrics['pha_mse'].append(self.mse(pred_pha, true_pha))
            if 'pha_grad' in self.args.metrics:
                metrics['pha_grad'].append(self.grad(pred_pha, true_pha))
            if 'pha_conn' in self.args.metrics:
                metrics['pha_conn'].append(self.conn(pred_pha, true_pha))
            if 'pha_dtssd' in self.args.metrics:
                if i == 0:
                    metrics['pha_dtssd'].append(0)
                else:
                    metrics['pha_dtssd'].append(self.dtssd(pred_pha, pred_pha_tm1, true_pha, true_pha_tm1))
                    
            pred_pha_tm1 = pred_pha
            true_pha_tm1 = true_pha
            
            if 'fgr_mse' in self.args.metrics or 'fgr_mad' in self.args.metrics:
                true_fgr = cv2.imread(os.path.join(self.args.true_dir, dataset, clip, 'fgr', framename), cv2.IMREAD_COLOR).astype(np.float32) / 255
                pred_fgr = cv2.imread(os.path.join(self.args.pred_dir, dataset, clip, 'fgr', framename), cv2.IMREAD_COLOR).astype(np.float32) / 255
                true_msk = true_pha > 0
                
                if 'fgr_mse' in self.args.metrics:
                    metrics['fgr_mse'].append(self.mse(pred_fgr[true_msk], true_fgr[true_msk]))
                if 'fgr_mad' in self.args.metrics:
                    metrics['fgr_mad'].append(self.mad(pred_fgr[true_msk], true_fgr[true_msk]))

        return metrics

    
class MetricMAD:
    def __call__(self, pred, true):
        return np.abs(pred - true).mean() * 1e3


class MetricMSE:
    def __call__(self, pred, true):
        return ((pred - true) ** 2).mean() * 1e3


class MetricGRAD:
    def __init__(self, sigma=1.4):
        self.filter_x, self.filter_y = self.gauss_filter(sigma)
    
    def __call__(self, pred, true):
        pred_normed = np.zeros_like(pred)
        true_normed = np.zeros_like(true)
        cv2.normalize(pred, pred_normed, 1., 0., cv2.NORM_MINMAX)
        cv2.normalize(true, true_normed, 1., 0., cv2.NORM_MINMAX)

        true_grad = self.gauss_gradient(true_normed).astype(np.float32)
        pred_grad = self.gauss_gradient(pred_normed).astype(np.float32)

        grad_loss = ((true_grad - pred_grad) ** 2).sum()
        return grad_loss / 1000
    
    def gauss_gradient(self, img):
        img_filtered_x = cv2.filter2D(img, -1, self.filter_x, borderType=cv2.BORDER_REPLICATE)
        img_filtered_y = cv2.filter2D(img, -1, self.filter_y, borderType=cv2.BORDER_REPLICATE)
        return np.sqrt(img_filtered_x**2 + img_filtered_y**2)
    
    @staticmethod
    def gauss_filter(sigma, epsilon=1e-2):
        half_size = np.ceil(sigma * np.sqrt(-2 * np.log(np.sqrt(2 * np.pi) * sigma * epsilon)))
        size = int(2 * half_size + 1)  # builtin int: np.int was removed in NumPy 1.24

        # create filter in x axis
        filter_x = np.zeros((size, size))
        for i in range(size):
            for j in range(size):
                filter_x[i, j] = MetricGRAD.gaussian(i - half_size, sigma) * MetricGRAD.dgaussian(
                    j - half_size, sigma)

        # normalize filter
        norm = np.sqrt((filter_x**2).sum())
        filter_x = filter_x / norm
        filter_y = np.transpose(filter_x)

        return filter_x, filter_y
        
    @staticmethod
    def gaussian(x, sigma):
        return np.exp(-x**2 / (2 * sigma**2)) / (sigma * np.sqrt(2 * np.pi))
    
    @staticmethod
    def dgaussian(x, sigma):
        return -x * MetricGRAD.gaussian(x, sigma) / sigma**2


class MetricCONN:
    def __call__(self, pred, true):
        step=0.1
        thresh_steps = np.arange(0, 1 + step, step)
        round_down_map = -np.ones_like(true)
        for i in range(1, len(thresh_steps)):
            true_thresh = true >= thresh_steps[i]
            pred_thresh = pred >= thresh_steps[i]
            intersection = (true_thresh & pred_thresh).astype(np.uint8)

            # connected components
            _, output, stats, _ = cv2.connectedComponentsWithStats(
                intersection, connectivity=4)
            # start from 1 in dim 0 to exclude background
            size = stats[1:, -1]

            # largest connected component of the intersection
            omega = np.zeros_like(true)
            if len(size) != 0:
                max_id = np.argmax(size)
                # +1 because size skips label 0 (the background), so index k is label k + 1
                omega[output == max_id + 1] = 1

            mask = (round_down_map == -1) & (omega == 0)
            round_down_map[mask] = thresh_steps[i - 1]
        round_down_map[round_down_map == -1] = 1

        true_diff = true - round_down_map
        pred_diff = pred - round_down_map
        # only calculate difference larger than or equal to 0.15
        true_phi = 1 - true_diff * (true_diff >= 0.15)
        pred_phi = 1 - pred_diff * (pred_diff >= 0.15)

        connectivity_error = np.sum(np.abs(true_phi - pred_phi))
        return connectivity_error / 1000
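
# Sketch (hypothetical helper, not part of the original repo): MetricCONN keeps only the
# largest 4-connected component of each thresholded intersection, using
# cv2.connectedComponentsWithStats(..., connectivity=4) and np.argmax over component sizes.
# The breadth-first search below makes that 4-connectivity rule explicit in plain
# Python/NumPy. It is far slower than OpenCV and is meant only for small illustrative arrays.
from collections import deque

import numpy as np


def largest_component_mask(binary):
    binary = np.asarray(binary, dtype=bool)
    h, w = binary.shape
    seen = np.zeros_like(binary)
    best = []
    # raster scan; each unseen foreground pixel seeds a new component
    for sy in range(h):
        for sx in range(w):
            if not binary[sy, sx] or seen[sy, sx]:
                continue
            component = []
            queue = deque([(sy, sx)])
            seen[sy, sx] = True
            while queue:
                y, x = queue.popleft()
                component.append((y, x))
                # 4-connectivity: up, down, left, right (no diagonals)
                for ny, nx in ((y - 1, x), (y + 1, x), (y, x - 1), (y, x + 1)):
                    if 0 <= ny < h and 0 <= nx < w and binary[ny, nx] and not seen[ny, nx]:
                        seen[ny, nx] = True
                        queue.append((ny, nx))
            # strict '>' keeps the first component found when sizes tie
            if len(component) > len(best):
                best = component
    mask = np.zeros((h, w), dtype=np.uint8)
    for y, x in best:
        mask[y, x] = 1
    return mask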


class MetricDTSSD:
    def __call__(self, pred_t, pred_tm1, true_t, true_tm1):
        dtSSD = ((pred_t - pred_tm1) - (true_t - true_tm1)) ** 2
        dtSSD = np.sum(dtSSD) / true_t.size
        dtSSD = np.sqrt(dtSSD)
        return dtSSD * 1e2
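
# Sketch (hypothetical helper, not part of the original repo): MetricDTSSD scores a single
# pair of consecutive frames. This section does not show how Evaluator aggregates it, so
# averaging the per-pair scores over a clip, as below, is an assumption for illustration.
# Per pair it applies the same formula: sqrt(mean(((p_t - p_{t-1}) - (g_t - g_{t-1}))**2)) * 100.
import numpy as np


def sequence_dtssd(preds, trues):
    # preds, trues: arrays of shape [T, H, W] with alpha values in [0, 1]
    preds = np.asarray(preds, dtype=np.float64)
    trues = np.asarray(trues, dtype=np.float64)
    if preds.shape != trues.shape or preds.ndim != 3:
        raise ValueError('preds and trues must both have shape [T, H, W]')
    # temporal differences along the frame axis: x_t - x_{t-1}
    diff = np.diff(preds, axis=0) - np.diff(trues, axis=0)
    # mean over pixels of each pair, matching np.sum(...) / true_t.size in MetricDTSSD
    per_pair = np.sqrt((diff ** 2).reshape(diff.shape[0], -1).mean(axis=1)) * 1e2
    return float(per_pair.mean())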



if __name__ == '__main__':
    Evaluator()

================================================
FILE: evaluation/generate_imagematte_with_background_image.py
================================================
"""
python generate_imagematte_with_background_image.py \
    --imagematte-dir ../matting-data/Distinctions/test \
    --background-dir ../matting-data/Backgrounds/valid \
    --resolution 512 \
    --out-dir ../matting-data/evaluation/distinction_static_sd/ \
    --random-seed 10
    
Seed:
    10 - distinction-static
    11 - distinction-motion
    12 - adobe-static
    13 - adobe-motion
    
"""

import argparse
import os
import pims
import numpy as np
import random
from PIL import Image
from tqdm import tqdm
from tqdm.contrib.concurrent import process_map
from torchvision import transforms
from torchvision.transforms import functional as F

parser = argparse.ArgumentParser()
parser.add_argument('--imagematte-dir', type=str, required=True)
parser.add_argument('--background-dir', type=str, required=True)
parser.add_argument('--num-samples', type=int, default=20)
parser.add_argument('--num-frames', type=int, default=100)
parser.add_argument('--resolution', type=int, required=True)
parser.add_argument('--out-dir', type=str, required=True)
parser.add_argument('--random-seed', type=int)
parser.add_argument('--extension', type=str, default='.png')
args = parser.parse_args()
    
random.seed(args.random_seed)

imagematte_filenames = os.listdir(os.path.join(args.imagematte_dir, 'fgr'))
background_filenames = os.listdir(args.background_dir)
random.shuffle(imagematte_filenames)
random.shuffle(background_filenames)


def lerp(a, b, percentage):
    return a * (1 - percentage) + b * percentage

def motion_affine(*imgs):
    config = dict(degrees=(-10, 10), translate=(0.1, 0.1),
                  scale_ranges=(0.9, 1.1), shears=(-5, 5), img_size=imgs[0][0].size)
    angleA, (transXA, transYA), scaleA, (shearXA, shearYA) = transforms.RandomAffine.get_params(**config)
    angleB, (transXB, transYB), scaleB, (shearXB, shearYB) = transforms.RandomAffine.get_params(**config)

    T = len(imgs[0])
    variation_over_time = random.random()
    for t in range(T):
        percentage = (t / (T - 1)) * variation_over_time
        angle = lerp(angleA, angleB, percentage)
        transX = lerp(transXA, transXB, percentage)
        transY = lerp(transYA, transYB, percentage)
        scale = lerp(scaleA, scaleB, percentage)
        shearX = lerp(shearXA, shearXB, percentage)
        shearY = lerp(shearYA, shearYB, percentage)
        for img in imgs:
            img[t] = F.affine(img[t], angle, (transX, transY), scale, (shearX, shearY), F.InterpolationMode.BILINEAR)
    return imgs
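
# Sketch (hypothetical helper, not part of the original script): motion_affine samples two
# random affine parameter sets A and B and linearly interpolates from A toward B over the
# clip. variation_over_time (random.random(), so in [0, 1)) caps how far toward B the last
# frame gets. This reproduces only that per-frame schedule so it can be inspected without
# images. motion_affine divides by (T - 1), so --num-frames 1 would raise ZeroDivisionError;
# the sketch guards that case by returning A unchanged.
def affine_param_schedule(params_a, params_b, num_frames, variation_over_time):
    # params_a / params_b: flat tuples such as (angle, trans_x, trans_y, scale, shear_x, shear_y)
    schedule = []
    for t in range(num_frames):
        # same percentage formula as motion_affine, guarded for a single frame
        percentage = (t / (num_frames - 1)) * variation_over_time if num_frames > 1 else 0.0
        # same linear interpolation as lerp(a, b, percentage)
        schedule.append(tuple(a * (1 - percentage) + b * percentage
                              for a, b in zip(params_a, params_b)))
    return schedule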
    


def process(i):
    imagematte_filename = imagematte_filenames[i % len(imagematte_filenames)]
    background_filename = background_filenames[i % len(background_filenames)]
    
    out_path = os.path.join(args.out_dir, str(i).zfill(4))
    os.makedirs(os.path.join(out_path, 'fgr'), exist_ok=True)
    os.makedirs(os.path.join(out_path, 'pha'), exist_ok=True)
    os.makedirs(os.path.join(out_path, 'com'), exist_ok=True)
    os.makedirs(os.path.join(out_path, 'bgr'), exist_ok=True)
    
    with Image.open(os.path.join(args.background_dir, background_filename)) as bgr:
        bgr = bgr.convert('RGB')
        
        w, h = bgr.size
        scale = args.resolution / min(h, w)
        w, h = int(w * scale), int(h * scale)
        bgr = bgr.resize((w, h))
        bgr = F.center_crop(bgr, (args.resolution, args.resolution))

    with Image.open(os.path.join(args.imagematte_dir, 'fgr', imagematte_filename)) as fgr, \
         Image.open(os.path.join(args.imagematte_dir, 'pha', imagematte_filename)) as pha:
        fgr = fgr.convert('RGB')
        pha = pha.convert('L')
        
    fgrs = [fgr] * args.num_frames
    phas = [pha] * args.num_frames
    fgrs, phas = motion_affine(fgrs, phas)
    
    for t in tqdm(range(args.num_frames), desc=str(i).zfill(4)):
        fgr = fgrs[t]
        pha = phas[t]
        
        w, h = fgr.size
        scale = args.resolution / max(h, w)
        w, h = int(w * scale), int(h * scale)
        
        fgr = fgr.resize((w, h))
        pha = pha.resize((w, h))
        
        if h < args.resolution:
            pt = (args.resolution - h) // 2
            pb = args.resolution - h - pt
        else:
            pt = 0
            pb = 0
            
        if w < args.resolution:
            pl = (args.resolution - w) // 2
            pr = args.resolution - w - pl
        else:
            pl = 0
            pr = 0
            
        fgr = F.pad(fgr, [pl, pt, pr, pb])
        pha = F.pad(pha, [pl, pt, pr, pb])
        
        if i // len(imagematte_filenames) % 2 == 1:
            fgr = fgr.transpose(Image.FLIP_LEFT_RIGHT)
            pha = pha.transpose(Image.FLIP_LEFT_RIGHT)
            
        fgr.save(os.path.join(out_path, 'fgr', str(t).zfill(4) + args.extension))
        pha.save(os.path.join(out_path, 'pha', str(t).zfill(4) + args.extension))
        
        if t == 0:
            bgr.save(os.path.join(out_path, 'bgr', str(t).zfill(4) + args.extension))
        else:
            os.symlink(str(0).zfill(4) + args.extension, os.path.join(out_path, 'bgr', str(t).zfill(4) + args.extension))
        
        pha = np.asarray(pha).astype(float)[:, :, None] / 255
        com = Image.fromarray(np.uint8(np.asarray(fgr) * pha + np.asarray(bgr) * (1 - pha)))
        com.save(os.path.join(out_path, 'com', str(t).zfill(4) + args.extension))
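
# Sketch (hypothetical helper, not part of the original script): process() fits each
# foreground inside a resolution x resolution square by scaling the longer side to
# --resolution, then pads the shorter side so the image is centered, with any odd leftover
# pixel going to the bottom/right. The returned padding uses the [left, top, right, bottom]
# order that torchvision's F.pad expects for a 4-element list.
def letterbox_params(w, h, resolution):
    scale = resolution / max(h, w)
    new_w, new_h = int(w * scale), int(h * scale)
    pad_w = max(resolution - new_w, 0)
    pad_h = max(resolution - new_h, 0)
    # floor half goes to left/top, remainder to right/bottom, as in process()
    pl, pt = pad_w // 2, pad_h // 2
    return (new_w, new_h), (pl, pt, pad_w - pl, pad_h - pt)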


if __name__ == '__main__':
    r = process_map(process, range(args.num_samples), max_workers=32)

================================================
FILE: evaluation/generate_imagematte_with_background_video.py
================================================
"""
python generate_imagematte_with_background_video.py \
    --imagematte-dir ../matting-data/Distinctions/test \
    --background-dir ../matting-data/BackgroundVideos_mp4/test \
    --resolution 512 \
    --out-dir ../matting-data/evaluation/distinction_motion_sd/ \
    --random-seed 11
    
Seed:
    10 - distinction-static
    11 - distinction-motion
    12 - adobe-static
    13 - adobe-motion
    
"""

import argparse
import os
import pims
import numpy as np
import random
from multiprocessing import Pool
from PIL import Image
# from tqdm import tqdm
from tqdm.contrib.concurrent import process_map
from torchvision import transforms
from torchvision.transforms import functional as F

parser = argparse.ArgumentParser()
parser.add_argument('--imagematte-dir', type=str, required=True)
parser.add_argument('--background-dir', type=str, required=True)
parser.add_argument('--num-samples', type=int, default=20)
parser.add_argument('--num-frames', type=int, default=100)
parser.add_argument('--resolution', type=int, required=True)
parser.add_argument('--out-dir', type=str, required=True)
parser.add_argument('--random-seed', type=int)
parser.add_argument('--extension', type=str, default='.png')
args = parser.parse_args()
    
random.seed(args.random_seed)

imagematte_filenames = os.listdir(os.path.join(args.imagematte_dir, 'fgr'))
random.shuffle(imagematte_filenames)

background_filenames = [
    "0000.mp4",
    "0007.mp4",
    "0008.mp4",
    "0010.mp4",
    "0013.mp4",
    "0015.mp4",
    "0016.mp4",
    "0018.mp4",
    "0021.mp4",
    "0029.mp4",
    "0033.mp4",
    "0035.mp4",
    "0039.mp4",
    "0050.mp4",
    "0052.mp4",
    "0055.mp4",
    "0060.mp4",
    "0063.mp4",
    "0087.mp4",
    "0086.mp4",
    "0090.mp4",
    "0101.mp4",
    "0110.mp4",
    "0117.mp4",
    "0120.mp4",
    "0122.mp4",
    "0123.mp4",
    "0125.mp4",
    "0128.mp4",
    "0131.mp4",
    "0172.mp4",
    "0176.mp4",
    "0181.mp4",
    "0187.mp4",
    "0193.mp4",
    "0198.mp4",
    "0220.mp4",
    "0221.mp4",
    "0224.mp4",
    "0229.mp4",
    "0233.mp4",
    "0238.mp4",
    "0241.mp4",
    "0245.mp4",
    "0246.mp4"
]

random.shuffle(background_filenames)

def lerp(a, b, percentage):
    return a * (1 - percentage) + b * percentage

def motion_affine(*imgs):
    config = dict(degrees=(-10, 10), translate=(0.1, 0.1),
                  scale_ranges=(0.9, 1.1), shears=(-5, 5), img_size=imgs[0][0].size)
    angleA, (transXA, transYA), scaleA, (shearXA, shearYA) = transforms.RandomAffine.get_params(**config)
    angleB, (transXB, transYB), scaleB, (shearXB, shearYB) = transforms.RandomAffine.get_params(**config)

    T = len(imgs[0])
    variation_over_time = random.random()
    for t in range(T):
        percentage = (t / (T - 1)) * variation_over_time
        angle = lerp(angleA, angleB, percentage)
        transX = lerp(transXA, transXB, percentage)
        transY = lerp(transYA, transYB, percentage)
        scale = lerp(scaleA, scaleB, percentage)
        shearX = lerp(shearXA, shearXB, percentage)
        shearY = lerp(shearYA, shearYB, percentage)
        for img in imgs:
            img[t] = F.affine(img[t], angle, (transX, transY), scale, (shearX, shearY), F.InterpolationMode.BILINEAR)
    return imgs


def process(i):
    imagematte_filename = imagematte_filenames[i % len(imagematte_filenames)]
    background_filename = background_filenames[i % len(background_filenames)]
    
    bgrs = pims.PyAVVideoReader(os.path.join(args.background_dir, background_filename))
    
    out_path = os.path.join(args.out_dir, str(i).zfill(4))
    os.makedirs(os.path.join(out_path, 'fgr'), exist_ok=True)
    os.makedirs(os.path.join(out_path, 'pha'), exist_ok=True)
    os.makedirs(os.path.join(out_path, 'com'), exist_ok=True)
    os.makedirs(os.path.join(out_path, 'bgr'), exist_ok=True)

    with Image.open(os.path.join(args.imagematte_dir, 'fgr', imagematte_filename)) as fgr, \
         Image.open(os.path.join(args.imagematte_dir, 'pha', imagematte_filename)) as pha:
        fgr = fgr.convert('RGB')
        pha = pha.convert('L')
        
    fgrs = [fgr] * args.num_frames
    phas = [pha] * args.num_frames
    fgrs, phas = motion_affine(fgrs, phas)
    
    for t in range(args.num_frames):
        fgr = fgrs[t]
        pha = phas[t]
        
        w, h = fgr.size
        scale = args.resolution / max(h, w)
        w, h = int(w * scale), int(h * scale)
        
        fgr = fgr.resize((w, h))
        pha = pha.resize((w, h))
        
        if h < args.resolution:
            pt = (args.resolution - h) // 2
            pb = args.resolution - h - pt
        else:
            pt = 0
            pb = 0
            
        if w < args.resolution:
            pl = (args.resolution - w) // 2
            pr = args.resolution - w - pl
        else:
            pl = 0
            pr = 0
            
        fgr = F.pad(fgr, [pl, pt, pr, pb])
        pha = F.pad(pha, [pl, pt, pr, pb])
        
        if i // len(imagematte_filenames) % 2 == 1:
            fgr = fgr.transpose(Image.FLIP_LEFT_RIGHT)
            pha = pha.transpose(Image.FLIP_LEFT_RIGHT)
            
        fgr.save(os.path.join(out_path, 'fgr', str(t).zfill(4) + args.extension))
        pha.save(os.path.join(out_path, 'pha', str(t).zfill(4) + args.extension))
        
        bgr = Image.fromarray(bgrs[t]).convert('RGB')
        w, h = bgr.size
        scale = args.resolution / min(h, w)
        w, h = int(w * scale), int(h * scale)
        bgr = bgr.resize((w, h))
        bgr = F.center_crop(bgr, (args.resolution, args.resolution))
        bgr.save(os.path.join(out_path, 'bgr', str(t).zfill(4) + args.extension))
        
        pha = np.asarray(pha).astype(float)[:, :, None] / 255
        com = Image.fromarray(np.uint8(np.asarray(fgr) * pha + np.asarray(bgr) * (1 - pha)))
        com.save(os.path.join(out_path, 'com', str(t).zfill(4) + args.extension))
        
if __name__ == '__main__':
    r = process_map(process, range(args.num_samples), max_workers=10)



================================================
FILE: evaluation/generate_videomatte_with_background_image.py
================================================
"""
python generate_videomatte_with_background_image.py \
    --videomatte-dir ../matting-data/VideoMatte240K_JPEG_HD/test \
    --background-dir ../matting-data/Backgrounds/valid \
    --num-samples 25 \
    --resize 512 288 \
    --out-dir ../matting-data/evaluation/vidematte_static_sd/
"""

import argparse
import os
import pims
import numpy as np
import random
from PIL import Image
from tqdm import tqdm

parser = argparse.ArgumentParser()
parser.add_argument('--videomatte-dir', type=str, required=True)
parser.add_argument('--background-dir', type=str, required=True)
parser.add_argument('--num-samples', type=int, default=20)
parser.add_argument('--num-frames', type=int, default=100)
parser.add_argument('--resize', type=int, default=None, nargs=2)
parser.add_argument('--out-dir', type=str, required=True)
parser.add_argument('--extension', type=str, default='.png')
args = parser.parse_args()
    
random.seed(10)

videomatte_filenames = [(clipname, sorted(os.listdir(os.path.join(args.videomatte_dir, 'fgr', clipname)))) 
                        for clipname in sorted(os.listdir(os.path.join(args.videomatte_dir, 'fgr')))]

background_filenames = os.listdir(args.background_dir)
random.shuffle(background_filenames)

for i in range(args.num_samples):
    
    clipname, framenames = videomatte_filenames[i % len(videomatte_filenames)]
    
    out_path = os.path.join(args.out_dir, str(i).zfill(4))
    os.makedirs(os.path.join(out_path, 'fgr'), exist_ok=True)
    os.makedirs(os.path.join(out_path, 'pha'), exist_ok=True)
    os.makedirs(os.path.join(out_path, 'com'), exist_ok=True)
    os.makedirs(os.path.join(out_path, 'bgr'), exist_ok=True)
    
    with Image.open(os.path.join(args.background_dir, background_filenames[i % len(background_filenames)])) as bgr:
        bgr = bgr.convert('RGB')

    
    base_t = random.choice(range(len(framenames) - args.num_frames))
    
    for t in tqdm(range(args.num_frames), desc=str(i).zfill(4)):
        with Image.open(os.path.join(args.videomatte_dir, 'fgr', clipname, framenames[base_t + t])) as fgr, \
             Image.open(os.path.join(args.videomatte_dir, 'pha', clipname, framenames[base_t + t])) as pha:
            fgr = fgr.convert('RGB')
            pha = pha.convert('L')
            
            if args.resize is not None:
                fgr = fgr.resize(args.resize, Image.BILINEAR)
                pha = pha.resize(args.resize, Image.BILINEAR)
                
            
            if i // len(videomatte_filenames) % 2 == 1:
                fgr = fgr.transpose(Image.FLIP_LEFT_RIGHT)
                pha = pha.transpose(Image.FLIP_LEFT_RIGHT)
            
            fgr.save(os.path.join(out_path, 'fgr', str(t).zfill(4) + args.extension))
            pha.save(os.path.join(out_path, 'pha', str(t).zfill(4) + args.extension))
        
        if t == 0:
            bgr = bgr.resize(fgr.size, Image.BILINEAR)
            bgr.save(os.path.join(out_path, 'bgr', str(t).zfill(4) + args.extension))
        else:
            os.symlink(str(0).zfill(4) + args.extension, os.path.join(out_path, 'bgr', str(t).zfill(4) + args.extension))
        
        pha = np.asarray(pha).astype(float)[:, :, None] / 255
        com = Image.fromarray(np.uint8(np.asarray(fgr) * pha + np.asarray(bgr) * (1 - pha)))
        com.save(os.path.join(out_path, 'com', str(t).zfill(4) + args.extension))


================================================
FILE: evaluation/generate_videomatte_with_background_video.py
================================================
"""
python generate_videomatte_with_background_video.py \
    --videomatte-dir ../matting-data/VideoMatte240K_JPEG_HD/test \
    --background-dir ../matting-data/BackgroundVideos_mp4/test \
    --resize 512 288 \
    --out-dir ../matting-data/evaluation/vidematte_motion_sd/
"""

import argparse
import os
import pims
import numpy as np
import random
from PIL import Image
from tqdm import tqdm

parser = argparse.ArgumentParser()
parser.add_argument('--videomatte-dir', type=str, required=True)
parser.add_argument('--background-dir', type=str, required=True)
parser.add_argument('--num-samples', type=int, default=20)
parser.add_argument('--num-frames', type=int, default=100)
parser.add_argument('--resize', type=int, default=None, nargs=2)
parser.add_argument('--out-dir', type=str, required=True)
args = parser.parse_args()

# Hand-selected list of background videos
background_filenames = [
    "0000.mp4",
    "0007.mp4",
    "0008.mp4",
    "0010.mp4",
    "0013.mp4",
    "0015.mp4",
    "0016.mp4",
    "0018.mp4",
    "0021.mp4",
    "0029.mp4",
    "0033.mp4",
    "0035.mp4",
    "0039.mp4",
    "0050.mp4",
    "0052.mp4",
    "0055.mp4",
    "0060.mp4",
    "0063.mp4",
    "0087.mp4",
    "0086.mp4",
    "0090.mp4",
    "0101.mp4",
    "0110.mp4",
    "0117.mp4",
    "0120.mp4",
    "0122.mp4",
    "0123.mp4",
    "0125.mp4",
    "0128.mp4",
    "0131.mp4",
    "0172.mp4",
    "0176.mp4",
    "0181.mp4",
    "0187.mp4",
    "0193.mp4",
    "0198.mp4",
    "0220.mp4",
    "0221.mp4",
    "0224.mp4",
    "0229.mp4",
    "0233.mp4",
    "0238.mp4",
    "0241.mp4",
    "0245.mp4",
    "0246.mp4"
]

random.seed(10)
    
videomatte_filenames = [(clipname, sorted(os.listdir(os.path.join(args.videomatte_dir, 'fgr', clipname)))) 
                        for clipname in sorted(os.listdir(os.path.join(args.videomatte_dir, 'fgr')))]

random.shuffle(background_filenames)

for i in range(args.num_samples):
    bgrs = pims.PyAVVideoReader(os.path.join(args.background_dir, background_filenames[i % len(background_filenames)]))
    clipname, framenames = videomatte_filenames[i % len(videomatte_filenames)]
    
    out_path = os.path.join(args.out_dir, str(i).zfill(4))
    os.makedirs(os.path.join(out_path, 'fgr'), exist_ok=True)
    os.makedirs(os.path.join(out_path, 'pha'), exist_ok=True)
    os.makedirs(os.path.join(out_path, 'com'), exist_ok=True)
    os.makedirs(os.path.join(out_path, 'bgr'), exist_ok=True)
    
    base_t = random.choice(range(len(framenames) - args.num_frames))
    
    for t in tqdm(range(args.num_frames), desc=str(i).zfill(4)):
        with Image.open(os.path.join(args.videomatte_dir, 'fgr', clipname, framenames[base_t + t])) as fgr, \
             Image.open(os.path.join(args.videomatte_dir, 'pha', clipname, framenames[base_t + t])) as pha:
            fgr = fgr.convert('RGB')
            pha = pha.convert('L')
            
            if args.resize is not None:
                fgr = fgr.resize(args.resize, Image.BILINEAR)
                pha = pha.resize(args.resize, Image.BILINEAR)
                
            
            if i // len(videomatte_filenames) % 2 == 1:
                fgr = fgr.transpose(Image.FLIP_LEFT_RIGHT)
                pha = pha.transpose(Image.FLIP_LEFT_RIGHT)
            
            fgr.save(os.path.join(out_path, 'fgr', str(t).zfill(4) + '.png'))
            pha.save(os.path.join(out_path, 'pha', str(t).zfill(4) + '.png'))
        
        bgr = Image.fromarray(bgrs[t])
        bgr = bgr.resize(fgr.size, Image.BILINEAR)
        bgr.save(os.path.join(out_path, 'bgr', str(t).zfill(4) + '.png'))
        
        pha = np.asarray(pha).astype(float)[:, :, None] / 255
        com = Image.fromarray(np.uint8(np.asarray(fgr) * pha + np.asarray(bgr) * (1 - pha)))
        com.save(os.path.join(out_path, 'com', str(t).zfill(4) + '.png'))


================================================
FILE: hubconf.py
================================================
"""
Loading model
    model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3")
    model = torch.hub.load("PeterL1n/RobustVideoMatting", "resnet50")

Converter API
    convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter")
"""


dependencies = ['torch', 'torchvision']

import torch
from model import MattingNetwork


def mobilenetv3(pretrained: bool = True, progress: bool = True):
    model = MattingNetwork('mobilenetv3')
    if pretrained:
        url = 'https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_mobilenetv3.pth'
        model.load_state_dict(torch.hub.load_state_dict_from_url(url, map_location='cpu', progress=progress))
    return model


def resnet50(pretrained: bool = True, progress: bool = True):
    model = MattingNetwork('resnet50')
    if pretrained:
        url = 'https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_resnet50.pth'
        model.load_state_dict(torch.hub.load_state_dict_from_url(url, map_location='cpu', progress=progress))
    return model


def converter():
    try:
        from inference import convert_video
        return convert_video
    except ModuleNotFoundError as error:
        print(error)
        print('Please run "pip install av tqdm pims"')


================================================
FILE: inference.py
================================================
"""
python inference.py \
    --variant mobilenetv3 \
    --checkpoint "CHECKPOINT" \
    --device cuda \
    --input-source "input.mp4" \
    --output-type video \
    --output-composition "composition.mp4" \
    --output-alpha "alpha.mp4" \
    --output-foreground "foreground.mp4" \
    --output-video-mbps 4 \
    --seq-chunk 1
"""

import torch
import os
from torch.utils.data import DataLoader
from torchvision import transforms
from typing import Optional, Tuple
from tqdm.auto import tqdm

from inference_utils import VideoReader, VideoWriter, ImageSequenceReader, ImageSequenceWriter
from model import MattingNetwork  # used by Converter, which may be imported outside __main__

def convert_video(model,
                  input_source: str,
                  input_resize: Optional[Tuple[int, int]] = None,
                  downsample_ratio: Optional[float] = None,
                  output_type: str = 'video',
                  output_composition: Optional[str] = None,
                  output_alpha: Optional[str] = None,
                  output_foreground: Optional[str] = None,
                  output_video_mbps: Optional[float] = None,
                  seq_chunk: int = 1,
                  num_workers: int = 0,
                  progress: bool = True,
                  device: Optional[str] = None,
                  dtype: Optional[torch.dtype] = None):
    
    """
    Args:
        input_source:A video file, or an image sequence directory. Images must be sorted in accending order, support png and jpg.
        input_resize: If provided, the input are first resized to (w, h).
        downsample_ratio: The model's downsample_ratio hyperparameter. If not provided, model automatically set one.
        output_type: Options: ["video", "png_sequence"].
        output_composition:
            The composition output path. File path if output_type == 'video'. Directory path if output_type == 'png_sequence'.
            If output_type == 'video', the composition has green screen background.
            If output_type == 'png_sequence'. the composition is RGBA png images.
        output_alpha: The alpha output from the model.
        output_foreground: The foreground output from the model.
        seq_chunk: Number of frames to process at once. Increase it for better parallelism.
        num_workers: PyTorch's DataLoader workers. Only use >0 for image input.
        progress: Show progress bar.
        device: Only need to manually provide if model is a TorchScript freezed model.
        dtype: Only need to manually provide if model is a TorchScript freezed model.
    """
    
    assert downsample_ratio is None or (downsample_ratio > 0 and downsample_ratio <= 1), 'Downsample ratio must be between 0 (exclusive) and 1 (inclusive).'
    assert any([output_composition, output_alpha, output_foreground]), 'Must provide at least one output.'
    assert output_type in ['video', 'png_sequence'], 'Only support "video" and "png_sequence" output modes.'
    assert seq_chunk >= 1, 'Sequence chunk must be >= 1'
    assert num_workers >= 0, 'Number of workers must be >= 0'
    
    # Initialize transform
    if input_resize is not None:
        transform = transforms.Compose([
            transforms.Resize(input_resize[::-1]),
            transforms.ToTensor()
        ])
    else:
        transform = transforms.ToTensor()

    # Initialize reader
    if os.path.isfile(input_source):
        source = VideoReader(input_source, transform)
    else:
        source = ImageSequenceReader(input_source, transform)
    reader = DataLoader(source, batch_size=seq_chunk, pin_memory=True, num_workers=num_workers)
    
    # Initialize writers
    if output_type == 'video':
        frame_rate = source.frame_rate if isinstance(source, VideoReader) else 30
        output_video_mbps = 1 if output_video_mbps is None else output_video_mbps
        if output_composition is not None:
            writer_com = VideoWriter(
                path=output_composition,
                frame_rate=frame_rate,
                bit_rate=int(output_video_mbps * 1000000))
        if output_alpha is not None:
            writer_pha = VideoWriter(
                path=output_alpha,
                frame_rate=frame_rate,
                bit_rate=int(output_video_mbps * 1000000))
        if output_foreground is not None:
            writer_fgr = VideoWriter(
                path=output_foreground,
                frame_rate=frame_rate,
                bit_rate=int(output_video_mbps * 1000000))
    else:
        if output_composition is not None:
            writer_com = ImageSequenceWriter(output_composition, 'png')
        if output_alpha is not None:
            writer_pha = ImageSequenceWriter(output_alpha, 'png')
        if output_foreground is not None:
            writer_fgr = ImageSequenceWriter(output_foreground, 'png')

    # Inference
    model = model.eval()
    if device is None or dtype is None:
        param = next(model.parameters())
        dtype = param.dtype
        device = param.device
    
    if (output_composition is not None) and (output_type == 'video'):
        bgr = torch.tensor([120, 255, 155], device=device, dtype=dtype).div(255).view(1, 1, 3, 1, 1)
    
    try:
        with torch.no_grad():
            bar = tqdm(total=len(source), disable=not progress, dynamic_ncols=True)
            rec = [None] * 4
            for src in reader:

                if downsample_ratio is None:
                    downsample_ratio = auto_downsample_ratio(*src.shape[2:])

                src = src.to(device, dtype, non_blocking=True).unsqueeze(0) # [B, T, C, H, W]
                fgr, pha, *rec = model(src, *rec, downsample_ratio)

                if output_foreground is not None:
                    writer_fgr.write(fgr[0])
                if output_alpha is not None:
                    writer_pha.write(pha[0])
                if output_composition is not None:
                    if output_type == 'video':
                        com = fgr * pha + bgr * (1 - pha)
                    else:
                        fgr = fgr * pha.gt(0)
                        com = torch.cat([fgr, pha], dim=-3)
                    writer_com.write(com[0])
                
                bar.update(src.size(1))

    finally:
        # Clean up
        if output_composition is not None:
            writer_com.close()
        if output_alpha is not None:
            writer_pha.close()
        if output_foreground is not None:
            writer_fgr.close()


def auto_downsample_ratio(h, w):
    """
    Automatically find a downsample ratio so that the largest side of the resolution be 512px.
    """
    return min(512 / max(h, w), 1)
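
# Sketch (hypothetical helper, not part of the original repo): convert_video composites in two
# ways. For 'video' output it blends the foreground over the green color (120, 255, 155) / 255;
# for 'png_sequence' it zeroes RGB wherever alpha == 0 and appends alpha as a fourth channel
# (RGBA). NumPy is swapped in for torch here purely for illustration; shapes follow the
# [T, C, H, W] layout that convert_video writes per chunk.
import numpy as np

GREEN_SCREEN = np.array([120, 255, 155], dtype=np.float32) / 255  # same color as convert_video


def composite_numpy(fgr, pha, mode='video'):
    # fgr: [T, 3, H, W] floats in [0, 1]; pha: [T, 1, H, W] floats in [0, 1]
    if mode == 'video':
        bgr = GREEN_SCREEN.reshape(1, 3, 1, 1)
        # standard alpha blend: fgr * alpha + background * (1 - alpha)
        return fgr * pha + bgr * (1 - pha)
    # png_sequence: clear RGB where fully transparent, then concatenate alpha on the channel axis
    fgr = fgr * (pha > 0)
    return np.concatenate([fgr, pha], axis=-3)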


class Converter:
    def __init__(self, variant: str, checkpoint: str, device: str):
        self.model = MattingNetwork(variant).eval().to(device)
        self.model.load_state_dict(torch.load(checkpoint, map_location=device))
        self.model = torch.jit.script(self.model)
        self.model = torch.jit.freeze(self.model)
        self.device = device
    
    def convert(self, *args, **kwargs):
        convert_video(self.model, *args, device=self.device, dtype=torch.float32, **kwargs)
    
if __name__ == '__main__':
    import argparse
    from model import MattingNetwork
    
    parser = argparse.ArgumentParser()
    parser.add_argument('--variant', type=str, required=True, choices=['mobilenetv3', 'resnet50'])
    parser.add_argument('--checkpoint', type=str, required=True)
    parser.add_argument('--device', type=str, required=True)
    parser.add_argument('--input-source', type=str, required=True)
    parser.add_argument('--input-resize', type=int, default=None, nargs=2)
    parser.add_argument('--downsample-ratio', type=float)
    parser.add_argument('--output-composition', type=str)
    parser.add_argument('--output-alpha', type=str)
    parser.add_argument('--output-foreground', type=str)
    parser.add_argument('--output-type', type=str, required=True, choices=['video', 'png_sequence'])
    parser.add_argument('--output-video-mbps', type=float, default=1)
    parser.add_argument('--seq-chunk', type=int, default=1)
    parser.add_argument('--num-workers', type=int, default=0)
    parser.add_argument('--disable-progress', action='store_true')
    args = parser.parse_args()
    
    converter = Converter(args.variant, args.checkpoint, args.device)
    converter.convert(
        input_source=args.input_source,
        input_resize=args.input_resize,
        downsample_ratio=args.downsample_ratio,
        output_type=args.output_type,
        output_composition=args.output_composition,
        output_alpha=args.output_alpha,
        output_foreground=args.output_foreground,
        output_video_mbps=args.output_video_mbps,
        seq_chunk=args.seq_chunk,
        num_workers=args.num_workers,
        progress=not args.disable_progress
    )
    
    


================================================
FILE: inference_speed_test.py
================================================
"""
python inference_speed_test.py \
    --model-variant mobilenetv3 \
    --resolution 1920 1080 \
    --downsample-ratio 0.25 \
    --precision float32
"""

import argparse
import torch
from tqdm import tqdm

from model.model import MattingNetwork

torch.backends.cudnn.benchmark = True

class InferenceSpeedTest:
    def __init__(self):
        self.parse_args()
        self.init_model()
        self.loop()
        
    def parse_args(self):
        parser = argparse.ArgumentParser()
        parser.add_argument('--model-variant', type=str, required=True)
        parser.add_argument('--resolution', type=int, required=True, nargs=2)
        parser.add_argument('--downsample-ratio', type=float, required=True)
        parser.add_argument('--precision', type=str, default='float32')
        parser.add_argument('--disable-refiner', action='store_true')
        self.args = parser.parse_args()
        
    def init_model(self):
        self.device = 'cuda'
        self.precision = {'float32': torch.float32, 'float16': torch.float16}[self.args.precision]
        self.model = MattingNetwork(self.args.model_variant)
        self.model = self.model.to(device=self.device, dtype=self.precision).eval()
        self.model = torch.jit.script(self.model)
        self.model = torch.jit.freeze(self.model)
    
    def loop(self):
        w, h = self.args.resolution
        src = torch.randn((1, 3, h, w), device=self.device, dtype=self.precision)
        with torch.no_grad():
            rec = None, None, None, None
            for _ in tqdm(range(1000)):
                fgr, pha, *rec = self.model(src, *rec, self.args.downsample_ratio)
                torch.cuda.synchronize()  # wait for the GPU so tqdm's rate counts finished frames

if __name__ == '__main__':
    InferenceSpeedTest()
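The speed test above reports throughput only through tqdm's iteration rate. The following is a minimal, framework-agnostic timing sketch; `benchmark_recurrent` is a hypothetical helper, not part of the repository. It keeps the script's calling convention, `fgr, pha, *rec = model(src, *rec, downsample_ratio)`, so the recurrent state is threaded from one call to the next. It excludes warm-up iterations, which is where `cudnn.benchmark` autotuning would happen, and calls an optional `sync` barrier (for example `torch.cuda.synchronize`) only around the timed region.

```python
import time
from typing import Callable, Optional, Sequence, Tuple


def benchmark_recurrent(step: Callable[..., Tuple],
                        src,
                        downsample_ratio: float,
                        num_states: int = 4,
                        warmup: int = 10,
                        iters: int = 100,
                        sync: Optional[Callable[[], None]] = None) -> float:
    """Return mean seconds per call of a recurrent matting-style step.

    `step(src, *rec, downsample_ratio)` must return (fgr, pha, *new_rec),
    mirroring the convention used in inference_speed_test.py (assumption).
    `sync` is an optional barrier such as torch.cuda.synchronize.
    """
    rec: Sequence = (None,) * num_states  # no recurrent state before the first frame
    # Warm-up calls are not timed (autotuning, lazy allocation, JIT work).
    for _ in range(warmup):
        _, _, *rec = step(src, *rec, downsample_ratio)
    if sync is not None:
        sync()  # make sure warm-up work has finished before starting the clock
    start = time.perf_counter()
    for _ in range(iters):
        _, _, *rec = step(src, *rec, downsample_ratio)
    if sync is not None:
        sync()  # wait for queued GPU work before reading the clock
    return (time.perf_counter() - start) / iters
```

With a real model this might be called as `benchmark_recurrent(model, src, 0.25, sync=torch.cuda.synchronize)` inside `with torch.no_grad():`. Synchronizing once before and once after the timed loop measures throughput while letting GPU work queue up; the original's per-iteration `synchronize` is closer to a per-frame latency measurement.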

================================================
FILE: inference_utils.py
================================================
import av
import os
import pims
import numpy as np
from torch.utils.data import Dataset
from torchvision.transforms.functional import to_pil_image
from PIL import Image


class VideoReader(Dataset):
    def __init__(self, path, transform=None):
        self.video = pims.PyAVVideoReader(path)
        self.rate = self.video.frame_rate
        self.transform = transform
        
    @property
    def frame_rate(self):
        return self.rate
        
    def __len__(self):
        return len(self.video)
        
    def __getitem__(self, idx):
        frame = self.video[idx]
        frame = Image.fromarray(np.asarray(frame))
        if self.transform is not None:
            frame = self.transform(frame)
        return frame


class VideoWriter:
    def __init__(self, path, frame_rate, bit_rate=1000000):
        self.container = av.open(path, mode='w')
        self.stream = self.container.add_stream('h264', rate=f'{frame_rate:.4f}')
        self.stream.pix_fmt = 'yuv420p'
        self.stream.bit_rate = bit_rate
    
    def write(self, frames):
        # frames: [T, C, H, W]
        self.stream.width = frames.size(3)
        self.stream.height = frames.size(2)
        if frames.size(1) == 1:
            frames = frames.repeat(1, 3, 1, 1) # convert grayscale to RGB
        frames = frames.mul(255).byte().cpu().permute(0, 2, 3, 1).numpy()
        for t in range(frames.shape[0]):
            frame = frames[t]
            frame = av.VideoFrame.from_ndarray(frame, format='rgb24')
            self.container.mux(self.stream.encode(frame))
                
    def close(self):
        self.container.mux(self.stream.encode())
        self.container.close()
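`VideoWriter.write` expands single-channel frames (such as alpha) to RGB, scales values from [0, 1] to [0, 255], and moves channels last before handing each frame to PyAV as `rgb24`. The NumPy sketch below reproduces that conversion for illustration. `frames_to_rgb24` is a hypothetical helper, and unlike the original it clamps and rounds instead of truncating with `.byte()`, which is a deliberate choice made here and not the repository's behavior.

```python
import numpy as np


def frames_to_rgb24(frames: np.ndarray) -> np.ndarray:
    """Convert float frames in [0, 1], shaped [T, C, H, W], to uint8 [T, H, W, 3].

    Same steps as VideoWriter.write (grayscale -> RGB, scale to 255,
    channels-last), but values are clamped and rounded instead of truncated.
    """
    if frames.ndim != 4:
        raise ValueError(f"expected [T, C, H, W], got shape {frames.shape}")
    if frames.shape[1] == 1:
        frames = np.repeat(frames, 3, axis=1)  # grayscale alpha -> 3 identical channels
    frames = np.clip(frames, 0.0, 1.0) * 255.0
    frames = np.rint(frames).astype(np.uint8)
    # [T, C, H, W] -> [T, H, W, C]; make it contiguous for downstream encoders
    return np.ascontiguousarray(frames.transpose(0, 2, 3, 1))
```

Each `out[t]` is then an `H x W x 3` uint8 array of the kind `av.VideoFrame.from_ndarray(..., format='rgb24')` accepts.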


class ImageSequenceReader(Dataset):
    def __init__(self, path, transform=None):
        self.path = path
        self.files = sorted(os.listdir(path))
        self.transform = transform
        
    def __len__(self):
        return len(self.files)
    
    def __getitem__(self, idx):
        with Image.open(os.path.join(self.path, self.files[idx])) as img:
            img.load()
        if self.transform is not None:
            return self.transform(img)
        return img


class ImageSequenceWriter:
    def __init__(self, path, extension='jpg'):
        self.path = path
        self.extension = extension
        self.counter = 0
        os.makedirs(path, exist_ok=True)
    
    def write(self, frames):
        # frames: [T, C, H, W]
        for t in range(frames.shape[0]):
            to_pil_image(frames[t]).save(os.path.join(
                self.path, str(self.counter).zfill(4) + '.' + self.extension))
            self.counter += 1
            
    def close(self):
        pass
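`ImageSequenceWriter` names frames with `str(self.counter).zfill(4)`, and `ImageSequenceReader` orders files with a plain `sorted(os.listdir(path))`. That lexicographic order is only correct while every name has the same width: from frame 10000 on, names grow to five digits and `'10000.jpg'` sorts before `'1001.jpg'`. A natural-sort key is one possible fix. `natural_key` below is a hypothetical helper, not part of the repository.

```python
import re
from typing import List


def natural_key(name: str) -> List:
    """Split a filename into text and integer runs so '10000.jpg' sorts after '9999.jpg'.

    re.split with a capturing group alternates text/digit runs, so keys of
    different names compare str-to-str and int-to-int at each position.
    """
    return [int(part) if part.isdigit() else part.lower()
            for part in re.split(r"(\d+)", name)]
```

In the reader this would become `self.files = sorted(os.listdir(path), key=natural_key)`. Alternatively, the writer could pad to more digits so plain sorting keeps working.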
        


================================================
FILE: model/__init__.py
================================================
from .model import MattingNetwork

================================================
FILE: model/decoder.py
================================================
import torch
from torch import Tensor
from torch import nn
from torch.nn import functional as F
from typing import Tuple, Optional

class RecurrentDecoder(nn.Module):
    def __init__(self, feature_channels, decoder_channels):
        super().__init__()
        self.avgpool = AvgPool()
        self.decode4 = BottleneckBlock(feature_channels[3])
        self.decode3 = UpsamplingBlock(feature_channels[3], feature_channels[2], 3, decoder_channels[0])
        self.decode2 = UpsamplingBlock(decoder_channels[0], feature_channels[1], 3, decoder_channels[1])
        self.decode1 = UpsamplingBlock(decoder_channels[1], feature_channels[0], 3, decoder_channels[2])
        self.decode0 = OutputBlock(decoder_channels[2], 3, decoder_channels[3])

    def forward(self,
                s0: Tensor, f1: Tensor, f2: Tensor, f3: Tensor, f4: Tensor,
                r1: Optional[Tensor], r2: Optional[Tensor],
                r3: Optional[Tensor], r4: Optional[Tensor]):
        s1, s2, s3 = self.avgpool(s0)
        x4, r4 = self.decode4(f4, r4)
        x3, r3 = self.decode3(x4, f3, s3, r3)
        x2, r2 = self.decode2(x3, f2, s2, r2)
        x1, r1 = self.decode1(x2, f1, s1, r1)
        x0 = self.decode0(x1, s0)
        return x0, r1, r2, r3, r4
    

class AvgPool(nn.Module):
    def __init__(self):
        super().__init__()
        self.avgpool = nn.AvgPool2d(2, 2, count_include_pad=False, ceil_mode=True)
        
    def forward_single_frame(self, s0):
        s1 = self.avgpool(s0)
        s2 = self.avgpool(s1)
        s3 = self.avgpool(s2)
        return s1, s2, s3
    
    def forward_time_series(self, s0):
        B, T = s0.shape[:2]
        s0 = s0.flatten(0, 1)
        s1, s2, s3 = self.forward_single_frame(s0)
        s1 = s1.unflatten(0, (B, T))
        s2 = s2.unflatten(0, (B, T))
        s3 = s3.unflatten(0, (B, T))
        return s1, s2, s3
    
    def forward(self, s0):
        if s0.ndim == 5:
            return self.forward_time_series(s0)
        else:
            return self.forward_single_frame(s0)
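`AvgPool.forward_time_series` uses a pattern that recurs across the decoder: fold the time axis into the batch axis (`flatten(0, 1)`), run the per-frame operation once on `[B*T, C, H, W]`, then restore `[B, T, ...]` with `unflatten`. The NumPy sketch below shows the same idea with `reshape`. Both `apply_per_frame` and the simplified `avgpool2x2` (even sizes only, no `ceil_mode`) are illustrative assumptions, not repository code.

```python
import numpy as np
from typing import Callable


def apply_per_frame(x: np.ndarray, fn: Callable[[np.ndarray], np.ndarray]) -> np.ndarray:
    """Apply a [N, C, H, W] -> [N, C', H', W'] function to a [B, T, C, H, W] clip."""
    B, T = x.shape[:2]
    y = fn(x.reshape(B * T, *x.shape[2:]))   # like s0.flatten(0, 1): index (b, t) -> b*T + t
    return y.reshape(B, T, *y.shape[1:])     # like s1.unflatten(0, (B, T))


def avgpool2x2(x: np.ndarray) -> np.ndarray:
    """2x2, stride-2 average pooling for even H and W (simplified: no ceil_mode)."""
    N, C, H, W = x.shape
    return x.reshape(N, C, H // 2, 2, W // 2, 2).mean(axis=(3, 5))
```

Because row-major reshaping keeps each frame contiguous, the folded result for frame `(b, t)` equals running the operation on that frame alone.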


class BottleneckBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.gru = ConvGRU(channels // 2)
        
    def forward(self, x, r: Optional[Tensor]):
        a, b = x.split(self.channels // 2, dim=-3)
        b, r = self.gru(b, r)
        x = torch.cat([a, b], dim=-3)
        return x, r
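`BottleneckBlock` needs no separate time-series path for its split and concatenation because the channel axis is third from last in both `[B, C, H, W]` and `[B, T, C, H, W]`; `dim=-3` selects it either way, and only half of the channels go through the ConvGRU. The NumPy sketch below demonstrates the axis choice; `split_channels` is a hypothetical helper and assumes an even channel count.

```python
import numpy as np


def split_channels(x: np.ndarray, channels: int):
    """Split the channel axis (third from last) into two equal halves.

    Works for [B, C, H, W] and [B, T, C, H, W] alike; assumes `channels` is even.
    """
    half = channels // 2
    a, b = np.split(x, [half], axis=-3)  # a: passthrough half, b: half fed to the GRU
    return a, b
```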

    
class UpsamplingBlock(nn.Module):
    def __init__(self, in_channels, skip_channels, src_channels, out_channels):
        super().__init__()
        self.out_channels = out_channels
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels + skip_channels + src_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
        )
        self.gru = ConvGRU(out_channels // 2)

    def forward_single_frame(self, x, f, s, r: Optional[Tensor]):
        x = self.upsample(x)
        x = x[:, :, :s.size(2), :s.size(3)]  # drop the extra row/column a 2x upsample adds when s has odd size
        x = torch.cat([x, f, s], dim=1)
        x = self.conv(x)
        a, b = x.split(self.out_channels // 2, dim=1)
        b, r = self.gru(b, r)
        x = torch.cat([a, b], dim=1)
        return x, r
    
    def forward_time_series(self, x, f, s, r: Optional[Tensor]):
        B, T, _, H, W = s.shape
        x = x.flatten(0, 1)
        f = f.flatten(0, 1)
        s = s.flatten(0, 1)
        x = self.upsample(x)
        x = x[:, :, :H, :W]
        x = torch.cat([x, f, s], dim=1)
        x = self.conv(x)
        x = x.unflatten(0, (B, T))
        a, b = x.split(self.out_channels // 2, dim=2)
        b, r = self.gru(b, r)
        x = torch.cat([a, b], dim=2)
        return x, r
    
    def forward(self, x, f, s, r: Optional[Tensor]):
        if x.ndim == 5:
            return self.forward_time_series(x, f, s, r)
        else:
            return self.forward_single_frame(x, f, s, r)
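`AvgPool` halves `s0` three times with `kernel_size=2, stride=2, ceil_mode=True` and no padding, so each level has side length ceil(n / 2). Upsampling by 2 always yields an even size, which overshoots an odd-sized skip input by exactly one; the `x[:, :, :H, :W]` crop in `UpsamplingBlock` (and `OutputBlock`) removes that extra row or column. The pure-Python arithmetic below checks this. `pooled_sizes` and `crop_overshoot` are hypothetical helpers, and the example height of 270 (a 1080-pixel frame scaled by 0.25) is only illustrative.

```python
import math
from typing import List


def pooled_sizes(n: int, levels: int = 3) -> List[int]:
    """Side lengths after repeated AvgPool2d(2, 2, ceil_mode=True): ceil(n / 2) per level."""
    sizes = []
    for _ in range(levels):
        n = math.ceil(n / 2)
        sizes.append(n)
    return sizes


def crop_overshoot(n: int) -> int:
    """Rows (or columns) the crop removes after 2x upsampling the ceil-pooled map of size n."""
    return 2 * math.ceil(n / 2) - n
```

For a 270-pixel-high `s0`, the pyramid is 135, 68, 34. Upsampling the 68-row map gives 136 rows, so one row is cropped to match the 135-row skip level, while 34 to 68 needs no crop.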


class OutputBlock(nn.Module):
    def __init__(self, in_channels, src_channels, out_channels):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels + src_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
        )
        
    def forward_single_frame(self, x, s):
        x = self.upsample(x)
        x = x[:, :, :s.size(2), :s.size(3)]
        x = torch.cat([x, s], dim=1)
        x = self.conv(x)
        return x
    
    def forward_time_series(self, x, s):
        B, T, _, H, W = s.shape
        x = x.flatten(0, 1)
        s = s.flatten(0, 1)
        x = self.upsample(x)
        x = x[:, :, :H, :W]
        x = torch.cat([x, s], dim=1)
        x = self.conv(x)
        x = x.unflatten(0, (B, T))
        return x
    
    def forward(self, x, s):
        if x.ndim == 5:
            return self.forward_time_series(x, s)
        else:
            return self.forward_single_frame(x, s)
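The `ConvGRU` class that follows computes reset and update gates `r, z` with a sigmoid convolution over `[x, h]`, then a tanh candidate `c` over `[x, r * h]`. Its final update line is truncated in this copy. The NumPy sketch below shows one step under two stated assumptions. First, the 3x3 `Conv2d` layers are swapped for 1x1 convolutions, written as channel-mixing `einsum`s, to keep the sketch short. Second, the truncated line is assumed to be the blend `h' = (1 - z) * h + z * c`. `conv_gru_step_1x1` and its weight arguments are hypothetical.

```python
import numpy as np


def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))


def conv_gru_step_1x1(x, h, w_ih, b_ih, w_hh, b_hh):
    """One ConvGRU update with 1x1 kernels for [N, C, H, W] arrays x and h.

    w_ih: [2C, 2C] and w_hh: [C, 2C] act as 1x1 convolutions (channel mixing).
    The final blend (1 - z) * h + z * c is an assumption for the truncated line.
    """
    C = x.shape[1]
    xh = np.concatenate([x, h], axis=1)                                   # [N, 2C, H, W]
    gates = sigmoid(np.einsum('oc,nchw->nohw', w_ih, xh) + b_ih[None, :, None, None])
    r, z = gates[:, :C], gates[:, C:]                                     # reset, update gates
    xrh = np.concatenate([x, r * h], axis=1)                              # reset applied to h
    c = np.tanh(np.einsum('oc,nchw->nohw', w_hh, xrh) + b_hh[None, :, None, None])
    return (1 - z) * h + z * c                                            # gated interpolation
```

Under this convention, an update gate near 0 keeps the previous state and a gate near 1 replaces it with the candidate. Starting from `h = 0`, which is a common choice for the first frame when no state is given, the new state stays strictly inside (-1, 1).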


class ConvGRU(nn.Module):
    def __init__(self,
                 channels: int,
                 kernel_size: int = 3,
                 padding: int = 1):
        super().__init__()
        self.channels = channels
        self.ih = nn.Sequential(
            nn.Conv2d(channels * 2, channels * 2, kernel_size, padding=padding),
            nn.Sigmoid()
        )
        self.hh = nn.Sequential(
            nn.Conv2d(channels * 2, channels, kernel_size, padding=padding),
            nn.Tanh()
        )
        
    def forward_single_frame(self, x, h):
        r, z = self.ih(torch.cat([x, h], dim=1)).split(self.channels, dim=1)
        c = self.hh(torch.cat([x, r * h], dim=1))
        h = (1 - z) * 
  },
  {
    "path": "inference_utils.py",
    "chars": 2686,
    "preview": "import av\nimport os\nimport pims\nimport numpy as np\nfrom torch.utils.data import Dataset\nfrom torchvision.transforms.func"
  },
  {
    "path": "model/__init__.py",
    "chars": 33,
    "preview": "from .model import MattingNetwork"
  },
  {
    "path": "model/decoder.py",
    "chars": 7091,
    "preview": "import torch\nfrom torch import Tensor\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom typing import Tuple"
  },
  {
    "path": "model/deep_guided_filter.py",
    "chars": 2508,
    "preview": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\"\"\"\nAdopted from <https://github.com/wuhuikai/De"
  },
  {
    "path": "model/fast_guided_filter.py",
    "chars": 3009,
    "preview": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\"\"\"\nAdopted from <https://github.com/wuhuikai/De"
  },
  {
    "path": "model/lraspp.py",
    "chars": 896,
    "preview": "from torch import nn\n\nclass LRASPP(nn.Module):\n    def __init__(self, in_channels, out_channels):\n        super().__init"
  },
  {
    "path": "model/mobilenetv3.py",
    "chars": 3001,
    "preview": "import torch\nfrom torch import nn\nfrom torchvision.models.mobilenetv3 import MobileNetV3, InvertedResidualConfig\nfrom to"
  },
  {
    "path": "model/model.py",
    "chars": 3091,
    "preview": "import torch\nfrom torch import Tensor\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom typing import Optio"
  },
  {
    "path": "model/resnet.py",
    "chars": 1349,
    "preview": "import torch\nfrom torch import nn\nfrom torchvision.models.resnet import ResNet, Bottleneck\n\nclass ResNet50Encoder(ResNet"
  },
  {
    "path": "requirements_inference.txt",
    "chars": 65,
    "preview": "av==8.0.3\ntorch==1.9.0\ntorchvision==0.10.0\ntqdm==4.61.1\npims==0.5"
  },
  {
    "path": "requirements_training.txt",
    "chars": 88,
    "preview": "easing_functions==1.0.4\ntensorboard==2.5.0\ntorch==1.9.0\ntorchvision==0.10.0\ntqdm==4.61.1"
  },
  {
    "path": "train.py",
    "chars": 23036,
    "preview": "\"\"\"\n# First update `train_config.py` to set paths to your dataset locations.\n\n# You may want to change `--num-workers` a"
  },
  {
    "path": "train_config.py",
    "chars": 1669,
    "preview": "\"\"\"\nExpected directory format:\n\nVideoMatte Train/Valid:\n    ├──fgr/\n      ├── 0001/\n        ├── 00000.jpg\n        ├── 00"
  },
  {
    "path": "train_loss.py",
    "chars": 3358,
    "preview": "import torch\nfrom torch.nn import functional as F\n\n# -------------------------------------------------------------------"
  }
]

About this extraction

This page contains the full source code of the PeterL1n/RobustVideoMatting GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 42 files (233.4 KB), approximately 70.8k tokens, and a symbol index with 236 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.