Repository: PeterL1n/RobustVideoMatting Branch: master Commit: 53d74c682673 Files: 42 Total size: 233.4 KB Directory structure: gitextract_gj6tfm4_/ ├── LICENSE ├── README.md ├── README_zh_Hans.md ├── dataset/ │ ├── augmentation.py │ ├── coco.py │ ├── imagematte.py │ ├── spd.py │ ├── videomatte.py │ └── youtubevis.py ├── documentation/ │ ├── inference.md │ ├── inference_zh_Hans.md │ ├── misc/ │ │ ├── aim_test.txt │ │ ├── d646_test.txt │ │ ├── dvm_background_test_clips.txt │ │ ├── dvm_background_train_clips.txt │ │ ├── imagematte_train.txt │ │ ├── imagematte_valid.txt │ │ └── spd_preprocess.py │ └── training.md ├── evaluation/ │ ├── evaluate_hr.py │ ├── evaluate_lr.py │ ├── generate_imagematte_with_background_image.py │ ├── generate_imagematte_with_background_video.py │ ├── generate_videomatte_with_background_image.py │ └── generate_videomatte_with_background_video.py ├── hubconf.py ├── inference.py ├── inference_speed_test.py ├── inference_utils.py ├── model/ │ ├── __init__.py │ ├── decoder.py │ ├── deep_guided_filter.py │ ├── fast_guided_filter.py │ ├── lraspp.py │ ├── mobilenetv3.py │ ├── model.py │ └── resnet.py ├── requirements_inference.txt ├── requirements_training.txt ├── train.py ├── train_config.py └── train_loss.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: LICENSE ================================================ GNU GENERAL PUBLIC LICENSE Version 3, 29 June 2007 Copyright (C) 2007 Free Software Foundation, Inc. Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed. Preamble The GNU General Public License is a free, copyleft license for software and other kinds of works. The licenses for most software and other practical works are designed to take away your freedom to share and change the works. 
By contrast, the GNU General Public License is intended to guarantee your freedom to share and change all versions of a program--to make sure it remains free software for all its users. We, the Free Software Foundation, use the GNU General Public License for most of our software; it applies also to any other work released this way by its authors. You can apply it to your programs, too. When we speak of free software, we are referring to freedom, not price. Our General Public Licenses are designed to make sure that you have the freedom to distribute copies of free software (and charge for them if you wish), that you receive source code or can get it if you want it, that you can change the software or use pieces of it in new free programs, and that you know you can do these things. To protect your rights, we need to prevent others from denying you these rights or asking you to surrender the rights. Therefore, you have certain responsibilities if you distribute copies of the software, or if you modify it: responsibilities to respect the freedom of others. For example, if you distribute copies of such a program, whether gratis or for a fee, you must pass on to the recipients the same freedoms that you received. You must make sure that they, too, receive or can get the source code. And you must show them these terms so they know their rights. Developers that use the GNU GPL protect your rights with two steps: (1) assert copyright on the software, and (2) offer you this License giving you legal permission to copy, distribute and/or modify it. For the developers' and authors' protection, the GPL clearly explains that there is no warranty for this free software. For both users' and authors' sake, the GPL requires that modified versions be marked as changed, so that their problems will not be attributed erroneously to authors of previous versions. 
Some devices are designed to deny users access to install or run modified versions of the software inside them, although the manufacturer can do so. This is fundamentally incompatible with the aim of protecting users' freedom to change the software. The systematic pattern of such abuse occurs in the area of products for individuals to use, which is precisely where it is most unacceptable. Therefore, we have designed this version of the GPL to prohibit the practice for those products. If such problems arise substantially in other domains, we stand ready to extend this provision to those domains in future versions of the GPL, as needed to protect the freedom of users. Finally, every program is threatened constantly by software patents. States should not allow patents to restrict development and use of software on general-purpose computers, but in those that do, we wish to avoid the special danger that patents applied to a free program could make it effectively proprietary. To prevent this, the GPL assures that patents cannot be used to render the program non-free. The precise terms and conditions for copying, distribution and modification follow. TERMS AND CONDITIONS 0. Definitions. "This License" refers to version 3 of the GNU General Public License. "Copyright" also means copyright-like laws that apply to other kinds of works, such as semiconductor masks. "The Program" refers to any copyrightable work licensed under this License. Each licensee is addressed as "you". "Licensees" and "recipients" may be individuals or organizations. To "modify" a work means to copy from or adapt all or part of the work in a fashion requiring copyright permission, other than the making of an exact copy. The resulting work is called a "modified version" of the earlier work or a work "based on" the earlier work. A "covered work" means either the unmodified Program or a work based on the Program. 
To "propagate" a work means to do anything with it that, without permission, would make you directly or secondarily liable for infringement under applicable copyright law, except executing it on a computer or modifying a private copy. Propagation includes copying, distribution (with or without modification), making available to the public, and in some countries other activities as well. To "convey" a work means any kind of propagation that enables other parties to make or receive copies. Mere interaction with a user through a computer network, with no transfer of a copy, is not conveying. An interactive user interface displays "Appropriate Legal Notices" to the extent that it includes a convenient and prominently visible feature that (1) displays an appropriate copyright notice, and (2) tells the user that there is no warranty for the work (except to the extent that warranties are provided), that licensees may convey the work under this License, and how to view a copy of this License. If the interface presents a list of user commands or options, such as a menu, a prominent item in the list meets this criterion. 1. Source Code. The "source code" for a work means the preferred form of the work for making modifications to it. "Object code" means any non-source form of a work. A "Standard Interface" means an interface that either is an official standard defined by a recognized standards body, or, in the case of interfaces specified for a particular programming language, one that is widely used among developers working in that language. The "System Libraries" of an executable work include anything, other than the work as a whole, that (a) is included in the normal form of packaging a Major Component, but which is not part of that Major Component, and (b) serves only to enable use of the work with that Major Component, or to implement a Standard Interface for which an implementation is available to the public in source code form. 
A "Major Component", in this context, means a major essential component (kernel, window system, and so on) of the specific operating system (if any) on which the executable work runs, or a compiler used to produce the work, or an object code interpreter used to run it. The "Corresponding Source" for a work in object code form means all the source code needed to generate, install, and (for an executable work) run the object code and to modify the work, including scripts to control those activities. However, it does not include the work's System Libraries, or general-purpose tools or generally available free programs which are used unmodified in performing those activities but which are not part of the work. For example, Corresponding Source includes interface definition files associated with source files for the work, and the source code for shared libraries and dynamically linked subprograms that the work is specifically designed to require, such as by intimate data communication or control flow between those subprograms and other parts of the work. The Corresponding Source need not include anything that users can regenerate automatically from other parts of the Corresponding Source. The Corresponding Source for a work in source code form is that same work. 2. Basic Permissions. All rights granted under this License are granted for the term of copyright on the Program, and are irrevocable provided the stated conditions are met. This License explicitly affirms your unlimited permission to run the unmodified Program. The output from running a covered work is covered by this License only if the output, given its content, constitutes a covered work. This License acknowledges your rights of fair use or other equivalent, as provided by copyright law. You may make, run and propagate covered works that you do not convey, without conditions so long as your license otherwise remains in force. 
You may convey covered works to others for the sole purpose of having them make modifications exclusively for you, or provide you with facilities for running those works, provided that you comply with the terms of this License in conveying all material for which you do not control copyright. Those thus making or running the covered works for you must do so exclusively on your behalf, under your direction and control, on terms that prohibit them from making any copies of your copyrighted material outside their relationship with you. Conveying under any other circumstances is permitted solely under the conditions stated below. Sublicensing is not allowed; section 10 makes it unnecessary. 3. Protecting Users' Legal Rights From Anti-Circumvention Law. No covered work shall be deemed part of an effective technological measure under any applicable law fulfilling obligations under article 11 of the WIPO copyright treaty adopted on 20 December 1996, or similar laws prohibiting or restricting circumvention of such measures. When you convey a covered work, you waive any legal power to forbid circumvention of technological measures to the extent such circumvention is effected by exercising rights under this License with respect to the covered work, and you disclaim any intention to limit operation or modification of the work as a means of enforcing, against the work's users, your or third parties' legal rights to forbid circumvention of technological measures. 4. Conveying Verbatim Copies. You may convey verbatim copies of the Program's source code as you receive it, in any medium, provided that you conspicuously and appropriately publish on each copy an appropriate copyright notice; keep intact all notices stating that this License and any non-permissive terms added in accord with section 7 apply to the code; keep intact all notices of the absence of any warranty; and give all recipients a copy of this License along with the Program. 
You may charge any price or no price for each copy that you convey, and you may offer support or warranty protection for a fee. 5. Conveying Modified Source Versions. You may convey a work based on the Program, or the modifications to produce it from the Program, in the form of source code under the terms of section 4, provided that you also meet all of these conditions: a) The work must carry prominent notices stating that you modified it, and giving a relevant date. b) The work must carry prominent notices stating that it is released under this License and any conditions added under section 7. This requirement modifies the requirement in section 4 to "keep intact all notices". c) You must license the entire work, as a whole, under this License to anyone who comes into possession of a copy. This License will therefore apply, along with any applicable section 7 additional terms, to the whole of the work, and all its parts, regardless of how they are packaged. This License gives no permission to license the work in any other way, but it does not invalidate such permission if you have separately received it. d) If the work has interactive user interfaces, each must display Appropriate Legal Notices; however, if the Program has interactive interfaces that do not display Appropriate Legal Notices, your work need not make them do so. A compilation of a covered work with other separate and independent works, which are not by their nature extensions of the covered work, and which are not combined with it such as to form a larger program, in or on a volume of a storage or distribution medium, is called an "aggregate" if the compilation and its resulting copyright are not used to limit the access or legal rights of the compilation's users beyond what the individual works permit. Inclusion of a covered work in an aggregate does not cause this License to apply to the other parts of the aggregate. 6. Conveying Non-Source Forms. 
You may convey a covered work in object code form under the terms of sections 4 and 5, provided that you also convey the machine-readable Corresponding Source under the terms of this License, in one of these ways: a) Convey the object code in, or embodied in, a physical product (including a physical distribution medium), accompanied by the Corresponding Source fixed on a durable physical medium customarily used for software interchange. b) Convey the object code in, or embodied in, a physical product (including a physical distribution medium), accompanied by a written offer, valid for at least three years and valid for as long as you offer spare parts or customer support for that product model, to give anyone who possesses the object code either (1) a copy of the Corresponding Source for all the software in the product that is covered by this License, on a durable physical medium customarily used for software interchange, for a price no more than your reasonable cost of physically performing this conveying of source, or (2) access to copy the Corresponding Source from a network server at no charge. c) Convey individual copies of the object code with a copy of the written offer to provide the Corresponding Source. This alternative is allowed only occasionally and noncommercially, and only if you received the object code with such an offer, in accord with subsection 6b. d) Convey the object code by offering access from a designated place (gratis or for a charge), and offer equivalent access to the Corresponding Source in the same way through the same place at no further charge. You need not require recipients to copy the Corresponding Source along with the object code. If the place to copy the object code is a network server, the Corresponding Source may be on a different server (operated by you or a third party) that supports equivalent copying facilities, provided you maintain clear directions next to the object code saying where to find the Corresponding Source. 
Regardless of what server hosts the Corresponding Source, you remain obligated to ensure that it is available for as long as needed to satisfy these requirements. e) Convey the object code using peer-to-peer transmission, provided you inform other peers where the object code and Corresponding Source of the work are being offered to the general public at no charge under subsection 6d. A separable portion of the object code, whose source code is excluded from the Corresponding Source as a System Library, need not be included in conveying the object code work. A "User Product" is either (1) a "consumer product", which means any tangible personal property which is normally used for personal, family, or household purposes, or (2) anything designed or sold for incorporation into a dwelling. In determining whether a product is a consumer product, doubtful cases shall be resolved in favor of coverage. For a particular product received by a particular user, "normally used" refers to a typical or common use of that class of product, regardless of the status of the particular user or of the way in which the particular user actually uses, or expects or is expected to use, the product. A product is a consumer product regardless of whether the product has substantial commercial, industrial or non-consumer uses, unless such uses represent the only significant mode of use of the product. "Installation Information" for a User Product means any methods, procedures, authorization keys, or other information required to install and execute modified versions of a covered work in that User Product from a modified version of its Corresponding Source. The information must suffice to ensure that the continued functioning of the modified object code is in no case prevented or interfered with solely because modification has been made. 
If you convey an object code work under this section in, or with, or specifically for use in, a User Product, and the conveying occurs as part of a transaction in which the right of possession and use of the User Product is transferred to the recipient in perpetuity or for a fixed term (regardless of how the transaction is characterized), the Corresponding Source conveyed under this section must be accompanied by the Installation Information. But this requirement does not apply if neither you nor any third party retains the ability to install modified object code on the User Product (for example, the work has been installed in ROM). The requirement to provide Installation Information does not include a requirement to continue to provide support service, warranty, or updates for a work that has been modified or installed by the recipient, or for the User Product in which it has been modified or installed. Access to a network may be denied when the modification itself materially and adversely affects the operation of the network or violates the rules and protocols for communication across the network. Corresponding Source conveyed, and Installation Information provided, in accord with this section must be in a format that is publicly documented (and with an implementation available to the public in source code form), and must require no special password or key for unpacking, reading or copying. 7. Additional Terms. "Additional permissions" are terms that supplement the terms of this License by making exceptions from one or more of its conditions. Additional permissions that are applicable to the entire Program shall be treated as though they were included in this License, to the extent that they are valid under applicable law. If additional permissions apply only to part of the Program, that part may be used separately under those permissions, but the entire Program remains governed by this License without regard to the additional permissions. 
When you convey a copy of a covered work, you may at your option remove any additional permissions from that copy, or from any part of it. (Additional permissions may be written to require their own removal in certain cases when you modify the work.) You may place additional permissions on material, added by you to a covered work, for which you have or can give appropriate copyright permission. Notwithstanding any other provision of this License, for material you add to a covered work, you may (if authorized by the copyright holders of that material) supplement the terms of this License with terms: a) Disclaiming warranty or limiting liability differently from the terms of sections 15 and 16 of this License; or b) Requiring preservation of specified reasonable legal notices or author attributions in that material or in the Appropriate Legal Notices displayed by works containing it; or c) Prohibiting misrepresentation of the origin of that material, or requiring that modified versions of such material be marked in reasonable ways as different from the original version; or d) Limiting the use for publicity purposes of names of licensors or authors of the material; or e) Declining to grant rights under trademark law for use of some trade names, trademarks, or service marks; or f) Requiring indemnification of licensors and authors of that material by anyone who conveys the material (or modified versions of it) with contractual assumptions of liability to the recipient, for any liability that these contractual assumptions directly impose on those licensors and authors. All other non-permissive additional terms are considered "further restrictions" within the meaning of section 10. If the Program as you received it, or any part of it, contains a notice stating that it is governed by this License along with a term that is a further restriction, you may remove that term. 
If a license document contains a further restriction but permits relicensing or conveying under this License, you may add to a covered work material governed by the terms of that license document, provided that the further restriction does not survive such relicensing or conveying. If you add terms to a covered work in accord with this section, you must place, in the relevant source files, a statement of the additional terms that apply to those files, or a notice indicating where to find the applicable terms. Additional terms, permissive or non-permissive, may be stated in the form of a separately written license, or stated as exceptions; the above requirements apply either way. 8. Termination. You may not propagate or modify a covered work except as expressly provided under this License. Any attempt otherwise to propagate or modify it is void, and will automatically terminate your rights under this License (including any patent licenses granted under the third paragraph of section 11). However, if you cease all violation of this License, then your license from a particular copyright holder is reinstated (a) provisionally, unless and until the copyright holder explicitly and finally terminates your license, and (b) permanently, if the copyright holder fails to notify you of the violation by some reasonable means prior to 60 days after the cessation. Moreover, your license from a particular copyright holder is reinstated permanently if the copyright holder notifies you of the violation by some reasonable means, this is the first time you have received notice of violation of this License (for any work) from that copyright holder, and you cure the violation prior to 30 days after your receipt of the notice. Termination of your rights under this section does not terminate the licenses of parties who have received copies or rights from you under this License. 
If your rights have been terminated and not permanently reinstated, you do not qualify to receive new licenses for the same material under section 10. 9. Acceptance Not Required for Having Copies. You are not required to accept this License in order to receive or run a copy of the Program. Ancillary propagation of a covered work occurring solely as a consequence of using peer-to-peer transmission to receive a copy likewise does not require acceptance. However, nothing other than this License grants you permission to propagate or modify any covered work. These actions infringe copyright if you do not accept this License. Therefore, by modifying or propagating a covered work, you indicate your acceptance of this License to do so. 10. Automatic Licensing of Downstream Recipients. Each time you convey a covered work, the recipient automatically receives a license from the original licensors, to run, modify and propagate that work, subject to this License. You are not responsible for enforcing compliance by third parties with this License. An "entity transaction" is a transaction transferring control of an organization, or substantially all assets of one, or subdividing an organization, or merging organizations. If propagation of a covered work results from an entity transaction, each party to that transaction who receives a copy of the work also receives whatever licenses to the work the party's predecessor in interest had or could give under the previous paragraph, plus a right to possession of the Corresponding Source of the work from the predecessor in interest, if the predecessor has it or can get it with reasonable efforts. You may not impose any further restrictions on the exercise of the rights granted or affirmed under this License. 
For example, you may not impose a license fee, royalty, or other charge for exercise of rights granted under this License, and you may not initiate litigation (including a cross-claim or counterclaim in a lawsuit) alleging that any patent claim is infringed by making, using, selling, offering for sale, or importing the Program or any portion of it. 11. Patents. A "contributor" is a copyright holder who authorizes use under this License of the Program or a work on which the Program is based. The work thus licensed is called the contributor's "contributor version". A contributor's "essential patent claims" are all patent claims owned or controlled by the contributor, whether already acquired or hereafter acquired, that would be infringed by some manner, permitted by this License, of making, using, or selling its contributor version, but do not include claims that would be infringed only as a consequence of further modification of the contributor version. For purposes of this definition, "control" includes the right to grant patent sublicenses in a manner consistent with the requirements of this License. Each contributor grants you a non-exclusive, worldwide, royalty-free patent license under the contributor's essential patent claims, to make, use, sell, offer for sale, import and otherwise run, modify and propagate the contents of its contributor version. In the following three paragraphs, a "patent license" is any express agreement or commitment, however denominated, not to enforce a patent (such as an express permission to practice a patent or covenant not to sue for patent infringement). To "grant" such a patent license to a party means to make such an agreement or commitment not to enforce a patent against the party. 
If you convey a covered work, knowingly relying on a patent license, and the Corresponding Source of the work is not available for anyone to copy, free of charge and under the terms of this License, through a publicly available network server or other readily accessible means, then you must either (1) cause the Corresponding Source to be so available, or (2) arrange to deprive yourself of the benefit of the patent license for this particular work, or (3) arrange, in a manner consistent with the requirements of this License, to extend the patent license to downstream recipients. "Knowingly relying" means you have actual knowledge that, but for the patent license, your conveying the covered work in a country, or your recipient's use of the covered work in a country, would infringe one or more identifiable patents in that country that you have reason to believe are valid. If, pursuant to or in connection with a single transaction or arrangement, you convey, or propagate by procuring conveyance of, a covered work, and grant a patent license to some of the parties receiving the covered work authorizing them to use, propagate, modify or convey a specific copy of the covered work, then the patent license you grant is automatically extended to all recipients of the covered work and works based on it. A patent license is "discriminatory" if it does not include within the scope of its coverage, prohibits the exercise of, or is conditioned on the non-exercise of one or more of the rights that are specifically granted under this License. 
You may not convey a covered work if you are a party to an arrangement with a third party that is in the business of distributing software, under which you make payment to the third party based on the extent of your activity of conveying the work, and under which the third party grants, to any of the parties who would receive the covered work from you, a discriminatory patent license (a) in connection with copies of the covered work conveyed by you (or copies made from those copies), or (b) primarily for and in connection with specific products or compilations that contain the covered work, unless you entered into that arrangement, or that patent license was granted, prior to 28 March 2007. Nothing in this License shall be construed as excluding or limiting any implied license or other defenses to infringement that may otherwise be available to you under applicable patent law. 12. No Surrender of Others' Freedom. If conditions are imposed on you (whether by court order, agreement or otherwise) that contradict the conditions of this License, they do not excuse you from the conditions of this License. If you cannot convey a covered work so as to satisfy simultaneously your obligations under this License and any other pertinent obligations, then as a consequence you may not convey it at all. For example, if you agree to terms that obligate you to collect a royalty for further conveying from those to whom you convey the Program, the only way you could satisfy both those terms and this License would be to refrain entirely from conveying the Program. 13. Use with the GNU Affero General Public License. Notwithstanding any other provision of this License, you have permission to link or combine any covered work with a work licensed under version 3 of the GNU Affero General Public License into a single combined work, and to convey the resulting work. 
The terms of this License will continue to apply to the part which is the covered work, but the special requirements of the GNU Affero General Public License, section 13, concerning interaction through a network will apply to the combination as such. 14. Revised Versions of this License. The Free Software Foundation may publish revised and/or new versions of the GNU General Public License from time to time. Such new versions will be similar in spirit to the present version, but may differ in detail to address new problems or concerns. Each version is given a distinguishing version number. If the Program specifies that a certain numbered version of the GNU General Public License "or any later version" applies to it, you have the option of following the terms and conditions either of that numbered version or of any later version published by the Free Software Foundation. If the Program does not specify a version number of the GNU General Public License, you may choose any version ever published by the Free Software Foundation. If the Program specifies that a proxy can decide which future versions of the GNU General Public License can be used, that proxy's public statement of acceptance of a version permanently authorizes you to choose that version for the Program. Later license versions may give you additional or different permissions. However, no additional obligations are imposed on any author or copyright holder as a result of your choosing to follow a later version. 15. Disclaimer of Warranty. THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. 
SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 16. Limitation of Liability. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 17. Interpretation of Sections 15 and 16. If the disclaimer of warranty and limitation of liability provided above cannot be given local legal effect according to their terms, reviewing courts shall apply local law that most closely approximates an absolute waiver of all civil liability in connection with the Program, unless a warranty or assumption of liability accompanies a copy of the Program in return for a fee. END OF TERMS AND CONDITIONS How to Apply These Terms to Your New Programs If you develop a new program, and you want it to be of the greatest possible use to the public, the best way to achieve this is to make it free software which everyone can redistribute and change under these terms. To do so, attach the following notices to the program. It is safest to attach them to the start of each source file to most effectively state the exclusion of warranty; and each file should have at least the "copyright" line and a pointer to where the full notice is found. Copyright (C) This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. 
This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see . Also add information on how to contact you by electronic and paper mail. If the program does terminal interaction, make it output a short notice like this when it starts in an interactive mode: Copyright (C) This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. This is free software, and you are welcome to redistribute it under certain conditions; type `show c' for details. The hypothetical commands `show w' and `show c' should show the appropriate parts of the General Public License. Of course, your program's commands might be different; for a GUI interface, you would use an "about box". You should also get your employer (if you work as a programmer) or school, if any, to sign a "copyright disclaimer" for the program, if necessary. For more information on this, and how to apply and follow the GNU GPL, see . The GNU General Public License does not permit incorporating your program into proprietary programs. If your program is a subroutine library, you may consider it more useful to permit linking proprietary applications with the library. If this is what you want to do, use the GNU Lesser General Public License instead of this License. But first, please read . ================================================ FILE: README.md ================================================ # Robust Video Matting (RVM) ![Teaser](/documentation/image/teaser.gif)

English | 中文

Official repository for the paper [Robust High-Resolution Video Matting with Temporal Guidance](https://peterl1n.github.io/RobustVideoMatting/). RVM is specifically designed for robust human video matting. Unlike existing neural models that process frames as independent images, RVM uses a recurrent neural network to process videos with temporal memory. RVM can perform matting in real time on any video without additional inputs. It achieves **4K 76FPS** and **HD 104FPS** on an Nvidia GTX 1080 Ti GPU. The project was developed at [ByteDance Inc.](https://www.bytedance.com/)
## News * [Nov 03 2021] Fixed a bug in [train.py](https://github.com/PeterL1n/RobustVideoMatting/commit/48effc91576a9e0e7a8519f3da687c0d3522045f). * [Sep 16 2021] Code is re-released under GPL-3.0 license. * [Aug 25 2021] Source code and pretrained models are published. * [Jul 27 2021] Paper is accepted by WACV 2022.
## Showreel Watch the showreel video ([YouTube](https://youtu.be/Jvzltozpbpk), [Bilibili](https://www.bilibili.com/video/BV1Z3411B7g7/)) to see the model's performance.

All footage in the video is available on [Google Drive](https://drive.google.com/drive/folders/1VFnWwuu-YXDKG-N6vcjK_nL7YZMFapMU?usp=sharing).
## Demo * [Webcam Demo](https://peterl1n.github.io/RobustVideoMatting/#/demo): Run the model live in your browser. Visualize recurrent states. * [Colab Demo](https://colab.research.google.com/drive/10z-pNKRnVNsp0Lq9tH1J_XPZ7CBC_uHm?usp=sharing): Test our model on your own videos with free GPU.
## Download We recommend MobileNetv3 models for most use cases. ResNet50 models are the larger variant with small performance improvements. Our model is available on various inference frameworks. See [inference documentation](documentation/inference.md) for more instructions.
Framework Download Notes
PyTorch rvm_mobilenetv3.pth
rvm_resnet50.pth
Official weights for PyTorch. Doc
TorchHub Nothing to Download. Easiest way to use our model in your PyTorch project. Doc
TorchScript rvm_mobilenetv3_fp32.torchscript
rvm_mobilenetv3_fp16.torchscript
rvm_resnet50_fp32.torchscript
rvm_resnet50_fp16.torchscript
For mobile inference, consider exporting int8 quantized models yourself. Doc
ONNX rvm_mobilenetv3_fp32.onnx
rvm_mobilenetv3_fp16.onnx
rvm_resnet50_fp32.onnx
rvm_resnet50_fp16.onnx
Tested on ONNX Runtime with CPU and CUDA backends. Provided models use opset 12. Doc, Exporter.
TensorFlow rvm_mobilenetv3_tf.zip
rvm_resnet50_tf.zip
TensorFlow 2 SavedModel. Doc
TensorFlow.js rvm_mobilenetv3_tfjs_int8.zip
Run the model on the web. Demo, Starter Code
CoreML rvm_mobilenetv3_1280x720_s0.375_fp16.mlmodel
rvm_mobilenetv3_1280x720_s0.375_int8.mlmodel
rvm_mobilenetv3_1920x1080_s0.25_fp16.mlmodel
rvm_mobilenetv3_1920x1080_s0.25_int8.mlmodel
CoreML does not support dynamic resolution; you can export other resolutions yourself. Models require iOS 13+. s denotes downsample_ratio. Doc, Exporter
All models are available in [Google Drive](https://drive.google.com/drive/folders/1pBsG-SCTatv-95SnEuxmnvvlRx208VKj?usp=sharing) and [Baidu Pan](https://pan.baidu.com/s/1puPSxQqgBFOVpW4W7AolkA) (code: gym7).
## PyTorch Example 1. Install dependencies: ```sh pip install -r requirements_inference.txt ``` 2. Load the model: ```python import torch from model import MattingNetwork model = MattingNetwork('mobilenetv3').eval().cuda() # or "resnet50" model.load_state_dict(torch.load('rvm_mobilenetv3.pth')) ``` 3. To convert videos, we provide a simple conversion API: ```python from inference import convert_video convert_video( model, # The model, can be on any device (cpu or cuda). input_source='input.mp4', # A video file or an image sequence directory. output_type='video', # Choose "video" or "png_sequence" output_composition='com.mp4', # File path if video; directory path if png sequence. output_alpha="pha.mp4", # [Optional] Output the raw alpha prediction. output_foreground="fgr.mp4", # [Optional] Output the raw foreground prediction. output_video_mbps=4, # Output video mbps. Not needed for png sequence. downsample_ratio=None, # A hyperparameter to adjust or use None for auto. seq_chunk=12, # Process n frames at once for better parallelism. ) ``` 4. Or write your own inference code: ```python from torch.utils.data import DataLoader from torchvision.transforms import ToTensor from inference_utils import VideoReader, VideoWriter reader = VideoReader('input.mp4', transform=ToTensor()) writer = VideoWriter('output.mp4', frame_rate=30) bgr = torch.tensor([.47, 1, .6]).view(3, 1, 1).cuda() # Green background. rec = [None] * 4 # Initial recurrent states. downsample_ratio = 0.25 # Adjust based on your video. with torch.no_grad(): for src in DataLoader(reader): # RGB tensor normalized to 0 ~ 1. fgr, pha, *rec = model(src.cuda(), *rec, downsample_ratio) # Cycle the recurrent states. com = fgr * pha + bgr * (1 - pha) # Composite to green background. writer.write(com) # Write frame. ``` 5. The models and converter API are also available through TorchHub. ```python # Load the model. model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3") # or "resnet50" # Converter API. 
convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter") ``` Please see [inference documentation](documentation/inference.md) for details on `downsample_ratio` hyperparameter, more converter arguments, and more advanced usage.
## Training and Evaluation Please refer to the [training documentation](documentation/training.md) to train and evaluate your own model.
## Speed Speed is measured with `inference_speed_test.py` for reference. | GPU | dType | HD (1920x1080) | 4K (3840x2160) | | -------------- | ----- | -------------- |----------------| | RTX 3090 | FP16 | 172 FPS | 154 FPS | | RTX 2060 Super | FP16 | 134 FPS | 108 FPS | | GTX 1080 Ti | FP32 | 104 FPS | 74 FPS | * Note 1: HD uses `downsample_ratio=0.25`, 4K uses `downsample_ratio=0.125`. All tests use batch size 1 and frame chunk 1. * Note 2: GPUs before the Turing architecture do not support FP16 inference, so the GTX 1080 Ti uses FP32. * Note 3: We only measure tensor throughput. The video conversion script provided in this repo is expected to be much slower because it does not use hardware video encoding/decoding and does not perform tensor transfers on parallel threads. If you are interested in implementing hardware video encoding/decoding in Python, please refer to [PyNvCodec](https://github.com/NVIDIA/VideoProcessingFramework).
## Project Members * [Shanchuan Lin](https://www.linkedin.com/in/shanchuanlin/) * [Linjie Yang](https://sites.google.com/site/linjieyang89/) * [Imran Saleemi](https://www.linkedin.com/in/imran-saleemi/) * [Soumyadip Sengupta](https://homes.cs.washington.edu/~soumya91/)
## Third-Party Projects * [NCNN C++ Android](https://github.com/FeiGeChuanShu/ncnn_Android_RobustVideoMatting) ([@FeiGeChuanShu](https://github.com/FeiGeChuanShu)) * [lite.ai.toolkit](https://github.com/DefTruth/RobustVideoMatting.lite.ai.toolkit) ([@DefTruth](https://github.com/DefTruth)) * [Gradio Web Demo](https://huggingface.co/spaces/akhaliq/Robust-Video-Matting) ([@AK391](https://github.com/AK391)) * [Unity Engine demo with NatML](https://hub.natml.ai/@natsuite/robust-video-matting) ([@natsuite](https://github.com/natsuite)) * [MNN C++ Demo](https://github.com/DefTruth/lite.ai.toolkit/blob/main/lite/mnn/cv/mnn_rvm.cpp) ([@DefTruth](https://github.com/DefTruth)) * [TNN C++ Demo](https://github.com/DefTruth/lite.ai.toolkit/blob/main/lite/tnn/cv/tnn_rvm.cpp) ([@DefTruth](https://github.com/DefTruth)) ================================================ FILE: README_zh_Hans.md ================================================ # 稳定视频抠像 (RVM) ![Teaser](/documentation/image/teaser.gif)

English | 中文

论文 [Robust High-Resolution Video Matting with Temporal Guidance](https://peterl1n.github.io/RobustVideoMatting/) 的官方 GitHub 库。RVM 专为稳定人物视频抠像设计。不同于现有神经网络将每一帧作为单独图片处理,RVM 使用循环神经网络,在处理视频流时有时间记忆。RVM 可在任意视频上做实时高清抠像。在 Nvidia GTX 1080Ti 上实现 **4K 76FPS** 和 **HD 104FPS**。此研究项目来自[字节跳动](https://www.bytedance.com/)。
## 更新 * [2021年11月3日] 修复了 [train.py](https://github.com/PeterL1n/RobustVideoMatting/commit/48effc91576a9e0e7a8519f3da687c0d3522045f) 的 bug。 * [2021年9月16日] 代码重新以 GPL-3.0 许可发布。 * [2021年8月25日] 公开代码和模型。 * [2021年7月27日] 论文被 WACV 2022 收录。
## 展示视频 观看展示视频 ([YouTube](https://youtu.be/Jvzltozpbpk), [Bilibili](https://www.bilibili.com/video/BV1Z3411B7g7/)),了解模型能力。

视频中的所有素材都提供下载,可用于测试模型:[Google Drive](https://drive.google.com/drive/folders/1VFnWwuu-YXDKG-N6vcjK_nL7YZMFapMU?usp=sharing)
## Demo * [网页](https://peterl1n.github.io/RobustVideoMatting/#/demo): 在浏览器里看摄像头抠像效果,展示模型内部循环记忆值。 * [Colab](https://colab.research.google.com/drive/10z-pNKRnVNsp0Lq9tH1J_XPZ7CBC_uHm?usp=sharing): 用我们的模型转换你的视频。
## 下载 推荐在通常情况下使用 MobileNetV3 的模型。ResNet50 的模型大很多,效果稍有提高。我们的模型支持很多框架。详情请阅读[推断文档](documentation/inference_zh_Hans.md)。
框架 下载 备注
PyTorch rvm_mobilenetv3.pth
rvm_resnet50.pth
官方 PyTorch 模型权值。文档
TorchHub 无需手动下载。 更方便地在你的 PyTorch 项目里使用此模型。文档
TorchScript rvm_mobilenetv3_fp32.torchscript
rvm_mobilenetv3_fp16.torchscript
rvm_resnet50_fp32.torchscript
rvm_resnet50_fp16.torchscript
若需在移动端推断,可以考虑自行导出 int8 量化的模型。文档
ONNX rvm_mobilenetv3_fp32.onnx
rvm_mobilenetv3_fp16.onnx
rvm_resnet50_fp32.onnx
rvm_resnet50_fp16.onnx
在 ONNX Runtime 的 CPU 和 CUDA backend 上测试过。提供的模型用 opset 12。文档导出
TensorFlow rvm_mobilenetv3_tf.zip
rvm_resnet50_tf.zip
TensorFlow 2 SavedModel 格式。文档
TensorFlow.js rvm_mobilenetv3_tfjs_int8.zip
在网页上跑模型。展示示范代码
CoreML rvm_mobilenetv3_1280x720_s0.375_fp16.mlmodel
rvm_mobilenetv3_1280x720_s0.375_int8.mlmodel
rvm_mobilenetv3_1920x1080_s0.25_fp16.mlmodel
rvm_mobilenetv3_1920x1080_s0.25_int8.mlmodel
CoreML 只能导出固定分辨率,其他分辨率可自行导出。支持 iOS 13+。s 代表下采样比。文档导出
所有模型可在 [Google Drive](https://drive.google.com/drive/folders/1pBsG-SCTatv-95SnEuxmnvvlRx208VKj?usp=sharing) 或[百度网盘](https://pan.baidu.com/s/1puPSxQqgBFOVpW4W7AolkA)(密码: gym7)上下载。
## PyTorch 范例 1. 安装 Python 库: ```sh pip install -r requirements_inference.txt ``` 2. 加载模型: ```python import torch from model import MattingNetwork model = MattingNetwork('mobilenetv3').eval().cuda() # 或 "resnet50" model.load_state_dict(torch.load('rvm_mobilenetv3.pth')) ``` 3. 若只需要做视频抠像处理,我们提供简单的 API: ```python from inference import convert_video convert_video( model, # 模型,可以加载到任何设备(cpu 或 cuda) input_source='input.mp4', # 视频文件,或图片序列文件夹 output_type='video', # 可选 "video"(视频)或 "png_sequence"(PNG 序列) output_composition='com.mp4', # 若导出视频,提供文件路径。若导出 PNG 序列,提供文件夹路径 output_alpha="pha.mp4", # [可选项] 输出透明度预测 output_foreground="fgr.mp4", # [可选项] 输出前景预测 output_video_mbps=4, # 若导出视频,提供视频码率 downsample_ratio=None, # 下采样比,可根据具体视频调节,或 None 选择自动 seq_chunk=12, # 设置多帧并行计算 ) ``` 4. 或自己写推断逻辑: ```python from torch.utils.data import DataLoader from torchvision.transforms import ToTensor from inference_utils import VideoReader, VideoWriter reader = VideoReader('input.mp4', transform=ToTensor()) writer = VideoWriter('output.mp4', frame_rate=30) bgr = torch.tensor([.47, 1, .6]).view(3, 1, 1).cuda() # 绿背景 rec = [None] * 4 # 初始循环记忆(Recurrent States) downsample_ratio = 0.25 # 下采样比,根据视频调节 with torch.no_grad(): for src in DataLoader(reader): # 输入张量,RGB通道,范围为 0~1 fgr, pha, *rec = model(src.cuda(), *rec, downsample_ratio) # 将上一帧的记忆给下一帧 com = fgr * pha + bgr * (1 - pha) # 将前景合成到绿色背景 writer.write(com) # 输出帧 ``` 5. 模型和 API 也可通过 TorchHub 快速载入。 ```python # 加载模型 model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3") # 或 "resnet50" # 转换 API convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter") ``` [推断文档](documentation/inference_zh_Hans.md)里有对 `downsample_ratio` 参数,API 使用,和高阶使用的讲解。
## 训练和评估 请参照[训练文档(英文)](documentation/training.md)。
## 速度 速度用 `inference_speed_test.py` 测量以供参考。 | GPU | dType | HD (1920x1080) | 4K (3840x2160) | | -------------- | ----- | -------------- |----------------| | RTX 3090 | FP16 | 172 FPS | 154 FPS | | RTX 2060 Super | FP16 | 134 FPS | 108 FPS | | GTX 1080 Ti | FP32 | 104 FPS | 74 FPS | * 注释1:HD 使用 `downsample_ratio=0.25`,4K 使用 `downsample_ratio=0.125`。 所有测试都使用 batch size 1 和 frame chunk 1。 * 注释2:图灵架构之前的 GPU 不支持 FP16 推理,所以 GTX 1080 Ti 使用 FP32。 * 注释3:我们只测量张量吞吐量(tensor throughput)。 提供的视频转换脚本会慢得多,因为它不使用硬件视频编码/解码,也没有在并行线程上完成张量传输。如果您有兴趣在 Python 中实现硬件视频编码/解码,请参考 [PyNvCodec](https://github.com/NVIDIA/VideoProcessingFramework)。
## 项目成员 * [Shanchuan Lin](https://www.linkedin.com/in/shanchuanlin/) * [Linjie Yang](https://sites.google.com/site/linjieyang89/) * [Imran Saleemi](https://www.linkedin.com/in/imran-saleemi/) * [Soumyadip Sengupta](https://homes.cs.washington.edu/~soumya91/)
## 第三方资源 * [NCNN C++ Android](https://github.com/FeiGeChuanShu/ncnn_Android_RobustVideoMatting) ([@FeiGeChuanShu](https://github.com/FeiGeChuanShu)) * [lite.ai.toolkit](https://github.com/DefTruth/RobustVideoMatting.lite.ai.toolkit) ([@DefTruth](https://github.com/DefTruth)) * [Gradio Web Demo](https://huggingface.co/spaces/akhaliq/Robust-Video-Matting) ([@AK391](https://github.com/AK391)) * [带有 NatML 的 Unity 引擎](https://hub.natml.ai/@natsuite/robust-video-matting) ([@natsuite](https://github.com/natsuite)) * [MNN C++ Demo](https://github.com/DefTruth/lite.ai.toolkit/blob/main/lite/mnn/cv/mnn_rvm.cpp) ([@DefTruth](https://github.com/DefTruth)) * [TNN C++ Demo](https://github.com/DefTruth/lite.ai.toolkit/blob/main/lite/tnn/cv/tnn_rvm.cpp) ([@DefTruth](https://github.com/DefTruth)) ================================================ FILE: dataset/augmentation.py ================================================ import easing_functions as ef import random import torch from torchvision import transforms from torchvision.transforms import functional as F class MotionAugmentation: def __init__(self, size, prob_fgr_affine, prob_bgr_affine, prob_noise, prob_color_jitter, prob_grayscale, prob_sharpness, prob_blur, prob_hflip, prob_pause, static_affine=True, aspect_ratio_range=(0.9, 1.1)): self.size = size self.prob_fgr_affine = prob_fgr_affine self.prob_bgr_affine = prob_bgr_affine self.prob_noise = prob_noise self.prob_color_jitter = prob_color_jitter self.prob_grayscale = prob_grayscale self.prob_sharpness = prob_sharpness self.prob_blur = prob_blur self.prob_hflip = prob_hflip self.prob_pause = prob_pause self.static_affine = static_affine self.aspect_ratio_range = aspect_ratio_range def __call__(self, fgrs, phas, bgrs): # Foreground affine if random.random() < self.prob_fgr_affine: fgrs, phas = self._motion_affine(fgrs, phas) # Background affine if random.random() < self.prob_bgr_affine / 2: bgrs = self._motion_affine(bgrs) if random.random() < self.prob_bgr_affine / 2: 
fgrs, phas, bgrs = self._motion_affine(fgrs, phas, bgrs) # Still Affine if self.static_affine: fgrs, phas = self._static_affine(fgrs, phas, scale_ranges=(0.5, 1)) bgrs = self._static_affine(bgrs, scale_ranges=(1, 1.5)) # To tensor fgrs = torch.stack([F.to_tensor(fgr) for fgr in fgrs]) phas = torch.stack([F.to_tensor(pha) for pha in phas]) bgrs = torch.stack([F.to_tensor(bgr) for bgr in bgrs]) # Resize params = transforms.RandomResizedCrop.get_params(fgrs, scale=(1, 1), ratio=self.aspect_ratio_range) fgrs = F.resized_crop(fgrs, *params, self.size, interpolation=F.InterpolationMode.BILINEAR) phas = F.resized_crop(phas, *params, self.size, interpolation=F.InterpolationMode.BILINEAR) params = transforms.RandomResizedCrop.get_params(bgrs, scale=(1, 1), ratio=self.aspect_ratio_range) bgrs = F.resized_crop(bgrs, *params, self.size, interpolation=F.InterpolationMode.BILINEAR) # Horizontal flip if random.random() < self.prob_hflip: fgrs = F.hflip(fgrs) phas = F.hflip(phas) if random.random() < self.prob_hflip: bgrs = F.hflip(bgrs) # Noise if random.random() < self.prob_noise: fgrs, bgrs = self._motion_noise(fgrs, bgrs) # Color jitter if random.random() < self.prob_color_jitter: fgrs = self._motion_color_jitter(fgrs) if random.random() < self.prob_color_jitter: bgrs = self._motion_color_jitter(bgrs) # Grayscale if random.random() < self.prob_grayscale: fgrs = F.rgb_to_grayscale(fgrs, num_output_channels=3).contiguous() bgrs = F.rgb_to_grayscale(bgrs, num_output_channels=3).contiguous() # Sharpen if random.random() < self.prob_sharpness: sharpness = random.random() * 8 fgrs = F.adjust_sharpness(fgrs, sharpness) phas = F.adjust_sharpness(phas, sharpness) bgrs = F.adjust_sharpness(bgrs, sharpness) # Blur if random.random() < self.prob_blur / 3: fgrs, phas = self._motion_blur(fgrs, phas) if random.random() < self.prob_blur / 3: bgrs = self._motion_blur(bgrs) if random.random() < self.prob_blur / 3: fgrs, phas, bgrs = self._motion_blur(fgrs, phas, bgrs) # Pause if random.random() 
< self.prob_pause: fgrs, phas, bgrs = self._motion_pause(fgrs, phas, bgrs) return fgrs, phas, bgrs def _static_affine(self, *imgs, scale_ranges): params = transforms.RandomAffine.get_params( degrees=(-10, 10), translate=(0.1, 0.1), scale_ranges=scale_ranges, shears=(-5, 5), img_size=imgs[0][0].size) imgs = [[F.affine(t, *params, F.InterpolationMode.BILINEAR) for t in img] for img in imgs] return imgs if len(imgs) > 1 else imgs[0] def _motion_affine(self, *imgs): config = dict(degrees=(-10, 10), translate=(0.1, 0.1), scale_ranges=(0.9, 1.1), shears=(-5, 5), img_size=imgs[0][0].size) angleA, (transXA, transYA), scaleA, (shearXA, shearYA) = transforms.RandomAffine.get_params(**config) angleB, (transXB, transYB), scaleB, (shearXB, shearYB) = transforms.RandomAffine.get_params(**config) T = len(imgs[0]) easing = random_easing_fn() for t in range(T): percentage = easing(t / (T - 1)) angle = lerp(angleA, angleB, percentage) transX = lerp(transXA, transXB, percentage) transY = lerp(transYA, transYB, percentage) scale = lerp(scaleA, scaleB, percentage) shearX = lerp(shearXA, shearXB, percentage) shearY = lerp(shearYA, shearYB, percentage) for img in imgs: img[t] = F.affine(img[t], angle, (transX, transY), scale, (shearX, shearY), F.InterpolationMode.BILINEAR) return imgs if len(imgs) > 1 else imgs[0] def _motion_noise(self, *imgs): grain_size = random.random() * 3 + 1 # range 1 ~ 4 monochrome = random.random() < 0.5 for img in imgs: T, C, H, W = img.shape noise = torch.randn((T, 1 if monochrome else C, round(H / grain_size), round(W / grain_size))) noise.mul_(random.random() * 0.2 / grain_size) if grain_size != 1: noise = F.resize(noise, (H, W)) img.add_(noise).clamp_(0, 1) return imgs if len(imgs) > 1 else imgs[0] def _motion_color_jitter(self, *imgs): brightnessA, brightnessB, contrastA, contrastB, saturationA, saturationB, hueA, hueB \ = torch.randn(8).mul(0.1).tolist() strength = random.random() * 0.2 easing = random_easing_fn() T = len(imgs[0]) for t in range(T): 
percentage = easing(t / (T - 1)) * strength for img in imgs: img[t] = F.adjust_brightness(img[t], max(1 + lerp(brightnessA, brightnessB, percentage), 0.1)) img[t] = F.adjust_contrast(img[t], max(1 + lerp(contrastA, contrastB, percentage), 0.1)) img[t] = F.adjust_saturation(img[t], max(1 + lerp(brightnessA, brightnessB, percentage), 0.1)) img[t] = F.adjust_hue(img[t], min(0.5, max(-0.5, lerp(hueA, hueB, percentage) * 0.1))) return imgs if len(imgs) > 1 else imgs[0] def _motion_blur(self, *imgs): blurA = random.random() * 10 blurB = random.random() * 10 T = len(imgs[0]) easing = random_easing_fn() for t in range(T): percentage = easing(t / (T - 1)) blur = max(lerp(blurA, blurB, percentage), 0) if blur != 0: kernel_size = int(blur * 2) if kernel_size % 2 == 0: kernel_size += 1 # Make kernel_size odd for img in imgs: img[t] = F.gaussian_blur(img[t], kernel_size, sigma=blur) return imgs if len(imgs) > 1 else imgs[0] def _motion_pause(self, *imgs): T = len(imgs[0]) pause_frame = random.choice(range(T - 1)) pause_length = random.choice(range(T - pause_frame)) for img in imgs: img[pause_frame + 1 : pause_frame + pause_length] = img[pause_frame] return imgs if len(imgs) > 1 else imgs[0] def lerp(a, b, percentage): return a * (1 - percentage) + b * percentage def random_easing_fn(): if random.random() < 0.2: return ef.LinearInOut() else: return random.choice([ ef.BackEaseIn, ef.BackEaseOut, ef.BackEaseInOut, ef.BounceEaseIn, ef.BounceEaseOut, ef.BounceEaseInOut, ef.CircularEaseIn, ef.CircularEaseOut, ef.CircularEaseInOut, ef.CubicEaseIn, ef.CubicEaseOut, ef.CubicEaseInOut, ef.ExponentialEaseIn, ef.ExponentialEaseOut, ef.ExponentialEaseInOut, ef.ElasticEaseIn, ef.ElasticEaseOut, ef.ElasticEaseInOut, ef.QuadEaseIn, ef.QuadEaseOut, ef.QuadEaseInOut, ef.QuarticEaseIn, ef.QuarticEaseOut, ef.QuarticEaseInOut, ef.QuinticEaseIn, ef.QuinticEaseOut, ef.QuinticEaseInOut, ef.SineEaseIn, ef.SineEaseOut, ef.SineEaseInOut, Step, ])() class Step: # Custom easing function for sudden change. 
def __call__(self, value): return 0 if value < 0.5 else 1 # ---------------------------- Frame Sampler ---------------------------- class TrainFrameSampler: def __init__(self, speed=[0.5, 1, 2, 3, 4, 5]): self.speed = speed def __call__(self, seq_length): frames = list(range(seq_length)) # Speed up speed = random.choice(self.speed) frames = [int(f * speed) for f in frames] # Shift shift = random.choice(range(seq_length)) frames = [f + shift for f in frames] # Reverse if random.random() < 0.5: frames = frames[::-1] return frames class ValidFrameSampler: def __call__(self, seq_length): return range(seq_length) ================================================ FILE: dataset/coco.py ================================================ import os import numpy as np import random import json import os from torch.utils.data import Dataset from torchvision import transforms from torchvision.transforms import functional as F from PIL import Image class CocoPanopticDataset(Dataset): def __init__(self, imgdir: str, anndir: str, annfile: str, transform=None): with open(annfile) as f: self.data = json.load(f)['annotations'] self.data = list(filter(lambda data: any(info['category_id'] == 1 for info in data['segments_info']), self.data)) self.imgdir = imgdir self.anndir = anndir self.transform = transform def __len__(self): return len(self.data) def __getitem__(self, idx): data = self.data[idx] img = self._load_img(data) seg = self._load_seg(data) if self.transform is not None: img, seg = self.transform(img, seg) return img, seg def _load_img(self, data): with Image.open(os.path.join(self.imgdir, data['file_name'].replace('.png', '.jpg'))) as img: return img.convert('RGB') def _load_seg(self, data): with Image.open(os.path.join(self.anndir, data['file_name'])) as ann: ann.load() ann = np.array(ann, copy=False).astype(np.int32) ann = ann[:, :, 0] + 256 * ann[:, :, 1] + 256 * 256 * ann[:, :, 2] seg = np.zeros(ann.shape, np.uint8) for segments_info in data['segments_info']: if 
segments_info['category_id'] in [1, 27, 32]: # person, backpack, tie seg[ann == segments_info['id']] = 255 return Image.fromarray(seg) class CocoPanopticTrainAugmentation: def __init__(self, size): self.size = size self.jitter = transforms.ColorJitter(0.1, 0.1, 0.1, 0.1) def __call__(self, img, seg): # Affine params = transforms.RandomAffine.get_params(degrees=(-20, 20), translate=(0.1, 0.1), scale_ranges=(1, 1), shears=(-10, 10), img_size=img.size) img = F.affine(img, *params, interpolation=F.InterpolationMode.BILINEAR) seg = F.affine(seg, *params, interpolation=F.InterpolationMode.NEAREST) # Resize params = transforms.RandomResizedCrop.get_params(img, scale=(0.5, 1), ratio=(0.7, 1.3)) img = F.resized_crop(img, *params, self.size, interpolation=F.InterpolationMode.BILINEAR) seg = F.resized_crop(seg, *params, self.size, interpolation=F.InterpolationMode.NEAREST) # Horizontal flip if random.random() < 0.5: img = F.hflip(img) seg = F.hflip(seg) # Color jitter img = self.jitter(img) # To tensor img = F.to_tensor(img) seg = F.to_tensor(seg) return img, seg class CocoPanopticValidAugmentation: def __init__(self, size): self.size = size def __call__(self, img, seg): # Resize params = transforms.RandomResizedCrop.get_params(img, scale=(1, 1), ratio=(1., 1.)) img = F.resized_crop(img, *params, self.size, interpolation=F.InterpolationMode.BILINEAR) seg = F.resized_crop(seg, *params, self.size, interpolation=F.InterpolationMode.NEAREST) # To tensor img = F.to_tensor(img) seg = F.to_tensor(seg) return img, seg ================================================ FILE: dataset/imagematte.py ================================================ import os import random from torch.utils.data import Dataset from PIL import Image from .augmentation import MotionAugmentation class ImageMatteDataset(Dataset): def __init__(self, imagematte_dir, background_image_dir, background_video_dir, size, seq_length, seq_sampler, transform): self.imagematte_dir = imagematte_dir self.imagematte_files = 
os.listdir(os.path.join(imagematte_dir, 'fgr')) self.background_image_dir = background_image_dir self.background_image_files = os.listdir(background_image_dir) self.background_video_dir = background_video_dir self.background_video_clips = os.listdir(background_video_dir) self.background_video_frames = [sorted(os.listdir(os.path.join(background_video_dir, clip))) for clip in self.background_video_clips] self.seq_length = seq_length self.seq_sampler = seq_sampler self.size = size self.transform = transform def __len__(self): return max(len(self.imagematte_files), len(self.background_image_files) + len(self.background_video_clips)) def __getitem__(self, idx): if random.random() < 0.5: bgrs = self._get_random_image_background() else: bgrs = self._get_random_video_background() fgrs, phas = self._get_imagematte(idx) if self.transform is not None: return self.transform(fgrs, phas, bgrs) return fgrs, phas, bgrs def _get_imagematte(self, idx): with Image.open(os.path.join(self.imagematte_dir, 'fgr', self.imagematte_files[idx % len(self.imagematte_files)])) as fgr, \ Image.open(os.path.join(self.imagematte_dir, 'pha', self.imagematte_files[idx % len(self.imagematte_files)])) as pha: fgr = self._downsample_if_needed(fgr.convert('RGB')) pha = self._downsample_if_needed(pha.convert('L')) fgrs = [fgr] * self.seq_length phas = [pha] * self.seq_length return fgrs, phas def _get_random_image_background(self): with Image.open(os.path.join(self.background_image_dir, self.background_image_files[random.choice(range(len(self.background_image_files)))])) as bgr: bgr = self._downsample_if_needed(bgr.convert('RGB')) bgrs = [bgr] * self.seq_length return bgrs def _get_random_video_background(self): clip_idx = random.choice(range(len(self.background_video_clips))) frame_count = len(self.background_video_frames[clip_idx]) frame_idx = random.choice(range(max(1, frame_count - self.seq_length))) clip = self.background_video_clips[clip_idx] bgrs = [] for i in self.seq_sampler(self.seq_length): 
frame_idx_t = frame_idx + i frame = self.background_video_frames[clip_idx][frame_idx_t % frame_count] with Image.open(os.path.join(self.background_video_dir, clip, frame)) as bgr: bgr = self._downsample_if_needed(bgr.convert('RGB')) bgrs.append(bgr) return bgrs def _downsample_if_needed(self, img): w, h = img.size if min(w, h) > self.size: scale = self.size / min(w, h) w = int(scale * w) h = int(scale * h) img = img.resize((w, h)) return img class ImageMatteAugmentation(MotionAugmentation): def __init__(self, size): super().__init__( size=size, prob_fgr_affine=0.95, prob_bgr_affine=0.3, prob_noise=0.05, prob_color_jitter=0.3, prob_grayscale=0.03, prob_sharpness=0.05, prob_blur=0.02, prob_hflip=0.5, prob_pause=0.03, ) ================================================ FILE: dataset/spd.py ================================================ import os from torch.utils.data import Dataset from PIL import Image class SuperviselyPersonDataset(Dataset): def __init__(self, imgdir, segdir, transform=None): self.img_dir = imgdir self.img_files = sorted(os.listdir(imgdir)) self.seg_dir = segdir self.seg_files = sorted(os.listdir(segdir)) assert len(self.img_files) == len(self.seg_files) self.transform = transform def __len__(self): return len(self.img_files) def __getitem__(self, idx): with Image.open(os.path.join(self.img_dir, self.img_files[idx])) as img, \ Image.open(os.path.join(self.seg_dir, self.seg_files[idx])) as seg: img = img.convert('RGB') seg = seg.convert('L') if self.transform is not None: img, seg = self.transform(img, seg) return img, seg ================================================ FILE: dataset/videomatte.py ================================================ import os import random from torch.utils.data import Dataset from PIL import Image from .augmentation import MotionAugmentation class VideoMatteDataset(Dataset): def __init__(self, videomatte_dir, background_image_dir, background_video_dir, size, seq_length, seq_sampler, transform=None): 
self.background_image_dir = background_image_dir self.background_image_files = os.listdir(background_image_dir) self.background_video_dir = background_video_dir self.background_video_clips = sorted(os.listdir(background_video_dir)) self.background_video_frames = [sorted(os.listdir(os.path.join(background_video_dir, clip))) for clip in self.background_video_clips] self.videomatte_dir = videomatte_dir self.videomatte_clips = sorted(os.listdir(os.path.join(videomatte_dir, 'fgr'))) self.videomatte_frames = [sorted(os.listdir(os.path.join(videomatte_dir, 'fgr', clip))) for clip in self.videomatte_clips] self.videomatte_idx = [(clip_idx, frame_idx) for clip_idx in range(len(self.videomatte_clips)) for frame_idx in range(0, len(self.videomatte_frames[clip_idx]), seq_length)] self.size = size self.seq_length = seq_length self.seq_sampler = seq_sampler self.transform = transform def __len__(self): return len(self.videomatte_idx) def __getitem__(self, idx): if random.random() < 0.5: bgrs = self._get_random_image_background() else: bgrs = self._get_random_video_background() fgrs, phas = self._get_videomatte(idx) if self.transform is not None: return self.transform(fgrs, phas, bgrs) return fgrs, phas, bgrs def _get_random_image_background(self): with Image.open(os.path.join(self.background_image_dir, random.choice(self.background_image_files))) as bgr: bgr = self._downsample_if_needed(bgr.convert('RGB')) bgrs = [bgr] * self.seq_length return bgrs def _get_random_video_background(self): clip_idx = random.choice(range(len(self.background_video_clips))) frame_count = len(self.background_video_frames[clip_idx]) frame_idx = random.choice(range(max(1, frame_count - self.seq_length))) clip = self.background_video_clips[clip_idx] bgrs = [] for i in self.seq_sampler(self.seq_length): frame_idx_t = frame_idx + i frame = self.background_video_frames[clip_idx][frame_idx_t % frame_count] with Image.open(os.path.join(self.background_video_dir, clip, frame)) as bgr: bgr = 
self._downsample_if_needed(bgr.convert('RGB')) bgrs.append(bgr) return bgrs def _get_videomatte(self, idx): clip_idx, frame_idx = self.videomatte_idx[idx] clip = self.videomatte_clips[clip_idx] frame_count = len(self.videomatte_frames[clip_idx]) fgrs, phas = [], [] for i in self.seq_sampler(self.seq_length): frame = self.videomatte_frames[clip_idx][(frame_idx + i) % frame_count] with Image.open(os.path.join(self.videomatte_dir, 'fgr', clip, frame)) as fgr, \ Image.open(os.path.join(self.videomatte_dir, 'pha', clip, frame)) as pha: fgr = self._downsample_if_needed(fgr.convert('RGB')) pha = self._downsample_if_needed(pha.convert('L')) fgrs.append(fgr) phas.append(pha) return fgrs, phas def _downsample_if_needed(self, img): w, h = img.size if min(w, h) > self.size: scale = self.size / min(w, h) w = int(scale * w) h = int(scale * h) img = img.resize((w, h)) return img class VideoMatteTrainAugmentation(MotionAugmentation): def __init__(self, size): super().__init__( size=size, prob_fgr_affine=0.3, prob_bgr_affine=0.3, prob_noise=0.1, prob_color_jitter=0.3, prob_grayscale=0.02, prob_sharpness=0.1, prob_blur=0.02, prob_hflip=0.5, prob_pause=0.03, ) class VideoMatteValidAugmentation(MotionAugmentation): def __init__(self, size): super().__init__( size=size, prob_fgr_affine=0, prob_bgr_affine=0, prob_noise=0, prob_color_jitter=0, prob_grayscale=0, prob_sharpness=0, prob_blur=0, prob_hflip=0, prob_pause=0, ) ================================================ FILE: dataset/youtubevis.py ================================================ import torch import os import json import numpy as np import random from torch.utils.data import Dataset from PIL import Image from torchvision import transforms from torchvision.transforms import functional as F class YouTubeVISDataset(Dataset): def __init__(self, videodir, annfile, size, seq_length, seq_sampler, transform=None): self.videodir = videodir self.size = size self.seq_length = seq_length self.seq_sampler = seq_sampler self.transform = 
transform with open(annfile) as f: data = json.load(f) self.masks = {} for ann in data['annotations']: if ann['category_id'] == 26: # person video_id = ann['video_id'] if video_id not in self.masks: self.masks[video_id] = [[] for _ in range(len(ann['segmentations']))] for frame, mask in zip(self.masks[video_id], ann['segmentations']): if mask is not None: frame.append(mask) self.videos = {} for video in data['videos']: video_id = video['id'] if video_id in self.masks: self.videos[video_id] = video self.index = [] for video_id in self.videos.keys(): for frame in range(len(self.videos[video_id]['file_names'])): self.index.append((video_id, frame)) def __len__(self): return len(self.index) def __getitem__(self, idx): video_id, frame_id = self.index[idx] video = self.videos[video_id] frame_count = len(self.videos[video_id]['file_names']) H, W = video['height'], video['width'] imgs, segs = [], [] for t in self.seq_sampler(self.seq_length): frame = (frame_id + t) % frame_count filename = video['file_names'][frame] masks = self.masks[video_id][frame] with Image.open(os.path.join(self.videodir, filename)) as img: imgs.append(self._downsample_if_needed(img.convert('RGB'), Image.BILINEAR)) seg = np.zeros((H, W), dtype=np.uint8) for mask in masks: seg |= self._decode_rle(mask) segs.append(self._downsample_if_needed(Image.fromarray(seg), Image.NEAREST)) if self.transform is not None: imgs, segs = self.transform(imgs, segs) return imgs, segs def _decode_rle(self, rle): H, W = rle['size'] msk = np.zeros(H * W, dtype=np.uint8) encoding = rle['counts'] skip = 0 for i in range(0, len(encoding) - 1, 2): skip += encoding[i] draw = encoding[i + 1] msk[skip : skip + draw] = 255 skip += draw return msk.reshape(W, H).transpose() def _downsample_if_needed(self, img, resample): w, h = img.size if min(w, h) > self.size: scale = self.size / min(w, h) w = int(scale * w) h = int(scale * h) img = img.resize((w, h), resample) return img class YouTubeVISAugmentation: def __init__(self, size): 
self.size = size self.jitter = transforms.ColorJitter(0.3, 0.3, 0.3, 0.15) def __call__(self, imgs, segs): # To tensor imgs = torch.stack([F.to_tensor(img) for img in imgs]) segs = torch.stack([F.to_tensor(seg) for seg in segs]) # Resize params = transforms.RandomResizedCrop.get_params(imgs, scale=(0.8, 1), ratio=(0.9, 1.1)) imgs = F.resized_crop(imgs, *params, self.size, interpolation=F.InterpolationMode.BILINEAR) segs = F.resized_crop(segs, *params, self.size, interpolation=F.InterpolationMode.BILINEAR) # Color jitter imgs = self.jitter(imgs) # Grayscale if random.random() < 0.05: imgs = F.rgb_to_grayscale(imgs, num_output_channels=3) # Horizontal flip if random.random() < 0.5: imgs = F.hflip(imgs) segs = F.hflip(segs) return imgs, segs ================================================ FILE: documentation/inference.md ================================================ # Inference

English | 中文

## Contents * [Concepts](#concepts) * [Downsample Ratio](#downsample-ratio) * [Recurrent States](#recurrent-states) * [PyTorch](#pytorch) * [TorchHub](#torchhub) * [TorchScript](#torchscript) * [ONNX](#onnx) * [TensorFlow](#tensorflow) * [TensorFlow.js](#tensorflowjs) * [CoreML](#coreml)
## Concepts ### Downsample Ratio The table provides a general guideline. Please adjust based on your video content. | Resolution | Portrait | Full-Body | | ------------- | ------------- | -------------- | | <= 512x512 | 1 | 1 | | 1280x720 | 0.375 | 0.6 | | 1920x1080 | 0.25 | 0.4 | | 3840x2160 | 0.125 | 0.2 | Internally, the model downsamples the input for stage 1. Then, it refines at high resolution for stage 2. Set `downsample_ratio` so that the downsampled resolution is between 256 and 512. For example, for `1920x1080` input with `downsample_ratio=0.25`, the resized resolution `480x270` is between 256 and 512. Adjust `downsample_ratio` based on the video content. If the shot is a portrait, a lower `downsample_ratio` is sufficient. If the shot contains the full human body, use a higher `downsample_ratio`. Note that a higher `downsample_ratio` is not always better.
### Recurrent States The model is a recurrent neural network. You must process frames sequentially and recycle its recurrent states. **Correct Way** The recurrent outputs are recycled back as input when processing the next frame. The states are essentially the model's memory. ```python rec = [None] * 4 # Initial recurrent states are None for frame in YOUR_VIDEO: fgr, pha, *rec = model(frame, *rec, downsample_ratio) ``` **Wrong Way** The model does not utilize the recurrent states. Only use it to process independent images. ```python for frame in YOUR_VIDEO: fgr, pha = model(frame, downsample_ratio)[:2] ``` More technical details are in the [paper](https://peterl1n.github.io/RobustVideoMatting/).


## PyTorch Model loading: ```python import torch from model import MattingNetwork model = MattingNetwork(variant='mobilenetv3').eval().cuda() # Or variant="resnet50" model.load_state_dict(torch.load('rvm_mobilenetv3.pth')) ``` Example inference loop: ```python rec = [None] * 4 # Set initial recurrent states to None for src in YOUR_VIDEO: # src can be [B, C, H, W] or [B, T, C, H, W] fgr, pha, *rec = model(src, *rec, downsample_ratio=0.25) ``` * `src`: Input frame. * Can be of shape `[B, C, H, W]` or `[B, T, C, H, W]`. * If `[B, T, C, H, W]`, a chunk of `T` frames can be given at once for better parallelism. * RGB input is normalized to `0~1` range. * `fgr, pha`: Foreground and alpha predictions. * Can be of shape `[B, C, H, W]` or `[B, T, C, H, W]` depending on `src`. * `fgr` has `C=3` for RGB, `pha` has `C=1`. * Outputs normalized to `0~1` range. * `rec`: Recurrent states. * Type of `List[Tensor, Tensor, Tensor, Tensor]`. * Initial `rec` can be `List[None, None, None, None]`. * It has 4 recurrent states because the model has 4 ConvGRU layers. * All tensors are rank 4 regardless of `src` rank. * If a chunk of `T` frames is given, only the last frame's recurrent states will be returned. To run inference on a video, here is a complete example: ```python from torch.utils.data import DataLoader from torchvision.transforms import ToTensor from inference_utils import VideoReader, VideoWriter reader = VideoReader('input.mp4', transform=ToTensor()) writer = VideoWriter('output.mp4', frame_rate=30) bgr = torch.tensor([.47, 1, .6]).view(3, 1, 1).cuda() # Green background. rec = [None] * 4 # Initial recurrent states. with torch.no_grad(): for src in DataLoader(reader): fgr, pha, *rec = model(src.cuda(), *rec, downsample_ratio=0.25) # Cycle the recurrent states. writer.write(fgr * pha + bgr * (1 - pha)) ``` Or you can use the provided video converter: ```python from inference import convert_video convert_video( model, # The loaded model, can be on any device (cpu or cuda).
input_source='input.mp4', # A video file or an image sequence directory. input_resize=(1920, 1080), # [Optional] Resize the input (also the output). downsample_ratio=0.25, # [Optional] If None, make downsampled max size be 512px. output_type='video', # Choose "video" or "png_sequence" output_composition='com.mp4', # File path if video; directory path if png sequence. output_alpha="pha.mp4", # [Optional] Output the raw alpha prediction. output_foreground="fgr.mp4", # [Optional] Output the raw foreground prediction. output_video_mbps=4, # Output video mbps. Not needed for png sequence. seq_chunk=12, # Process n frames at once for better parallelism. num_workers=1, # Only for image sequence input. Reader threads. progress=True # Print conversion progress. ) ``` The converter can also be invoked in command line: ```sh python inference.py \ --variant mobilenetv3 \ --checkpoint "CHECKPOINT" \ --device cuda \ --input-source "input.mp4" \ --downsample-ratio 0.25 \ --output-type video \ --output-composition "composition.mp4" \ --output-alpha "alpha.mp4" \ --output-foreground "foreground.mp4" \ --output-video-mbps 4 \ --seq-chunk 12 ```


## TorchHub Model loading: ```python model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3") # or "resnet50" ``` Use the conversion function. Refer to the documentation for `convert_video` function above. ```python convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter") convert_video(model, ...args...) ```


## TorchScript Model loading: ```python import torch model = torch.jit.load('rvm_mobilenetv3.torchscript') ``` Optionally, freeze the model. This will trigger graph optimizations such as BatchNorm fusion. Frozen models are faster. ```python model = torch.jit.freeze(model) ``` Then, you can use the `model` exactly the same as a PyTorch model, with the exception that you must manually provide `device` and `dtype` to the converter API for a frozen model. For example: ```python convert_video(frozen_model, ...args..., device='cuda', dtype=torch.float32) ```


## ONNX Model spec: * Inputs: [`src`, `r1i`, `r2i`, `r3i`, `r4i`, `downsample_ratio`]. * `src` is the RGB input frame of shape `[B, C, H, W]` normalized to `0~1` range. * `rXi` are the recurrent state inputs. Initial recurrent states are zero value tensors of shape `[1, 1, 1, 1]`. * `downsample_ratio` is a tensor of shape `[1]`. * Only `downsample_ratio` must have `dtype=FP32`. Other inputs must have `dtype` matching the loaded model's precision. * Outputs: [`fgr`, `pha`, `r1o`, `r2o`, `r3o`, `r4o`] * `fgr, pha` are the foreground and alpha predictions. Normalized to `0~1` range. * `rXo` are the recurrent state outputs. We only show examples of using the onnxruntime CUDA backend in Python. Model loading ```python import onnxruntime as ort sess = ort.InferenceSession('rvm_mobilenetv3_fp16.onnx') ``` Naive inference loop ```python import numpy as np rec = [ np.zeros([1, 1, 1, 1], dtype=np.float16) ] * 4 # Must match dtype of the model. downsample_ratio = np.array([0.25], dtype=np.float32) # dtype always FP32 for src in YOUR_VIDEO: # src is of [B, C, H, W] with dtype of the model. fgr, pha, *rec = sess.run([], { 'src': src, 'r1i': rec[0], 'r2i': rec[1], 'r3i': rec[2], 'r4i': rec[3], 'downsample_ratio': downsample_ratio }) ``` If you use the GPU version of ONNX Runtime, the above naive implementation transfers the recurrent states between CPU and GPU on every frame. They could have just stayed on the GPU for better performance. Below is an example using `iobinding` to eliminate unnecessary transfers. ```python import onnxruntime as ort import numpy as np # Load model. sess = ort.InferenceSession('rvm_mobilenetv3_fp16.onnx') # Create an io binding. io = sess.io_binding() # Create tensors on CUDA. rec = [ ort.OrtValue.ortvalue_from_numpy(np.zeros([1, 1, 1, 1], dtype=np.float16), 'cuda') ] * 4 downsample_ratio = ort.OrtValue.ortvalue_from_numpy(np.asarray([0.25], dtype=np.float32), 'cuda') # Set output binding.
for name in ['fgr', 'pha', 'r1o', 'r2o', 'r3o', 'r4o']: io.bind_output(name, 'cuda') # Inference loop for src in YOUR_VIDEO: io.bind_cpu_input('src', src) io.bind_ortvalue_input('r1i', rec[0]) io.bind_ortvalue_input('r2i', rec[1]) io.bind_ortvalue_input('r3i', rec[2]) io.bind_ortvalue_input('r4i', rec[3]) io.bind_ortvalue_input('downsample_ratio', downsample_ratio) sess.run_with_iobinding(io) fgr, pha, *rec = io.get_outputs() # Only transfer `fgr` and `pha` to CPU. fgr = fgr.numpy() pha = pha.numpy() ``` Note: depending on the inference tool you choose, it may not support all the operations in our official ONNX model. You are responsible for modifying the model code and exporting your own ONNX model. You can refer to our exporter code in the [onnx branch](https://github.com/PeterL1n/RobustVideoMatting/tree/onnx).


### TensorFlow Example usage: ```python import tensorflow as tf model = tf.keras.models.load_model('rvm_mobilenetv3_tf') model = tf.function(model) rec = [ tf.constant(0.) ] * 4 # Initial recurrent states. downsample_ratio = tf.constant(0.25) # Adjust based on your video. for src in YOUR_VIDEO: # src is of shape [B, H, W, C], not [B, C, H, W]! out = model([src, *rec, downsample_ratio]) fgr, pha, *rec = out['fgr'], out['pha'], out['r1o'], out['r2o'], out['r3o'], out['r4o'] ``` Note that the tensors are all channels-last. Otherwise, the inputs and outputs are exactly the same as PyTorch. We also provide the raw TensorFlow model code in the [tensorflow branch](https://github.com/PeterL1n/RobustVideoMatting/tree/tensorflow). You can transfer PyTorch checkpoint weights to TensorFlow models.


### TensorFlow.js We provide starter code in the [tfjs branch](https://github.com/PeterL1n/RobustVideoMatting/tree/tfjs). The example is very self-explanatory. It shows how to properly use the model.


### CoreML We only show example usage of the CoreML models in Python API using `coremltools`. In production, the same logic can be applied in Swift. When processing the first frame, do not provide recurrent states. CoreML will internally construct zero tensors of the correct shapes as the initial recurrent states. ```python import coremltools as ct model = ct.models.model.MLModel('rvm_mobilenetv3_1920x1080_s0.25_int8.mlmodel') r1, r2, r3, r4 = None, None, None, None for src in YOUR_VIDEO: # src is PIL.Image. if r1 is None: # Initial frame, do not provide recurrent states. inputs = {'src': src} else: # Subsequent frames, provide recurrent states. inputs = {'src': src, 'r1i': r1, 'r2i': r2, 'r3i': r3, 'r4i': r4} outputs = model.predict(inputs) fgr = outputs['fgr'] # PIL.Image. pha = outputs['pha'] # PIL.Image. r1 = outputs['r1o'] # Numpy array. r2 = outputs['r2o'] # Numpy array. r3 = outputs['r3o'] # Numpy array. r4 = outputs['r4o'] # Numpy array. ``` Our CoreML models only support fixed resolutions. If you need other resolutions, you can export them yourself. See [coreml branch](https://github.com/PeterL1n/RobustVideoMatting/tree/coreml) for model export. ================================================ FILE: documentation/inference_zh_Hans.md ================================================ # 推断文档

English | 中文

## 目录 * [概念](#概念) * [下采样比](#下采样比) * [循环记忆](#循环记忆) * [PyTorch](#pytorch) * [TorchHub](#torchhub) * [TorchScript](#torchscript) * [ONNX](#onnx) * [TensorFlow](#tensorflow) * [TensorFlow.js](#tensorflowjs) * [CoreML](#coreml)
## 概念 ### 下采样比 该表仅供参考。可根据视频内容进行调节。 | 分辨率 | 人像 | 全身 | | ------------- | ------------- | -------------- | | <= 512x512 | 1 | 1 | | 1280x720 | 0.375 | 0.6 | | 1920x1080 | 0.25 | 0.4 | | 3840x2160 | 0.125 | 0.2 | 模型在内部将高分辨率输入缩小做初步的处理,然后再放大做细分处理。 建议设置 `downsample_ratio` 使缩小后的分辨率维持在 256 到 512 像素之间. 例如,`1920x1080` 的输入用 `downsample_ratio=0.25`,缩小后的分辨率 `480x270` 在 256 到 512 像素之间。 根据视频内容调整 `downsample_ratio`。若视频是上身人像,低 `downsample_ratio` 足矣。若视频是全身像,建议尝试更高的 `downsample_ratio`。但注意,过高的 `downsample_ratio` 反而会降低效果。
### 循环记忆 此模型是循环神经网络(Recurrent Neural Network)。必须按顺序处理视频每帧,并提供网络循环记忆。 **正确用法** 循环记忆输出被传递到下一帧做输入。 ```python rec = [None] * 4 # 初始值设置为 None for frame in YOUR_VIDEO: fgr, pha, *rec = model(frame, *rec, downsample_ratio) ``` **错误用法** 没有使用循环记忆。此方法仅可用于处理单独的图片。 ```python for frame in YOUR_VIDEO: fgr, pha = model(frame, downsample_ratio)[:2] ``` 更多技术细节见[论文](https://peterl1n.github.io/RobustVideoMatting/)。


## PyTorch 载入模型: ```python import torch from model import MattingNetwork model = MattingNetwork(variant='mobilenetv3').eval().cuda() # 或 variant="resnet50" model.load_state_dict(torch.load('rvm_mobilenetv3.pth')) ``` 推断循环: ```python rec = [None] * 4 # 初始值设置为 None for src in YOUR_VIDEO: # src 可以是 [B, C, H, W] 或 [B, T, C, H, W] fgr, pha, *rec = model(src, *rec, downsample_ratio=0.25) ``` * `src`: 输入帧(Source)。 * 可以是 `[B, C, H, W]` 或 `[B, T, C, H, W]` 的张量。 * 若是 `[B, T, C, H, W]`,可给模型一次 `T` 帧,做一小段一小段地处理,用于更好的并行计算。 * RGB 通道输入,范围为 `0~1`。 * `fgr, pha`: 前景(Foreground)和透明度通道(Alpha)的预测。 * 根据`src`,可为 `[B, C, H, W]` 或 `[B, T, C, H, W]` 的输出。 * `fgr` 是 RGB 三通道,`pha` 为一通道。 * 输出范围为 `0~1`。 * `rec`: 循环记忆(Recurrent States)。 * `List[Tensor, Tensor, Tensor, Tensor]` 类型。 * 初始 `rec` 为 `List[None, None, None, None]`。 * 有四个记忆,因为网络使用四个 `ConvGRU` 层。 * 无论 `src` 的 Rank,所有记忆张量的 Rank 为 4。 * 若一次给予 `T` 帧,只返回处理完最后一帧后的记忆。 完整的推断例子: ```python from torch.utils.data import DataLoader from torchvision.transforms import ToTensor from inference_utils import VideoReader, VideoWriter reader = VideoReader('input.mp4', transform=ToTensor()) writer = VideoWriter('output.mp4', frame_rate=30) bgr = torch.tensor([.47, 1, .6]).view(3, 1, 1).cuda() # 绿背景 rec = [None] * 4 # 初始记忆 with torch.no_grad(): for src in DataLoader(reader): fgr, pha, *rec = model(src.cuda(), *rec, downsample_ratio=0.25) # 将上一帧的记忆给下一帧 writer.write(fgr * pha + bgr * (1 - pha)) ``` 或者使用提供的视频转换 API: ```python from inference import convert_video convert_video( model, # 模型,可以加载到任何设备(cpu 或 cuda) input_source='input.mp4', # 视频文件,或图片序列文件夹 input_resize=(1920, 1080), # [可选项] 缩放视频大小 downsample_ratio=0.25, # [可选项] 下采样比,若 None,自动下采样至 512px output_type='video', # 可选 "video"(视频)或 "png_sequence"(PNG 序列) output_composition='com.mp4', # 若导出视频,提供文件路径。若导出 PNG 序列,提供文件夹路径 output_alpha="pha.mp4", # [可选项] 输出透明度预测 output_foreground="fgr.mp4", # [可选项] 输出前景预测 output_video_mbps=4, # 若导出视频,提供视频码率 seq_chunk=12, # 设置多帧并行计算 num_workers=1, # 只适用于图片序列输入,读取线程 progress=True # 
显示进度条 ) ``` 也可通过命令行调用转换 API: ```sh python inference.py \ --variant mobilenetv3 \ --checkpoint "CHECKPOINT" \ --device cuda \ --input-source "input.mp4" \ --downsample-ratio 0.25 \ --output-type video \ --output-composition "composition.mp4" \ --output-alpha "alpha.mp4" \ --output-foreground "foreground.mp4" \ --output-video-mbps 4 \ --seq-chunk 12 ```


## TorchHub 载入模型: ```python model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3") # or "resnet50" ``` 使用转换 API,具体请参考之前对 `convert_video` 的文档。 ```python convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter") convert_video(model, ...args...) ```


## TorchScript 载入模型: ```python import torch model = torch.jit.load('rvm_mobilenetv3.torchscript') ``` 也可以可选的将模型固化(Freeze)。这会对模型进行优化,例如 BatchNorm Fusion 等。固化的模型更快。 ```python model = torch.jit.freeze(model) ``` 然后,可以将 `model` 作为普通的 PyTorch 模型使用。但注意,若用固化模型调用转换 API,必须手动提供 `device` 和 `dtype`: ```python convert_video(frozen_model, ...args..., device='cuda', dtype=torch.float32) ```


## ONNX 模型规格: * 输入: [`src`, `r1i`, `r2i`, `r3i`, `r4i`, `downsample_ratio`]. * `src`:输入帧,RGB 通道,形状为 `[B, C, H, W]`,范围为`0~1`。 * `rXi`:记忆输入,初始值是形状为 `[1, 1, 1, 1]` 的零张量。 * `downsample_ratio` 下采样比,张量形状为 `[1]`。 * 只有 `downsample_ratio` 必须是 `FP32`,其他输入必须和加载的模型使用一样的 `dtype`。 * 输出: [`fgr`, `pha`, `r1o`, `r2o`, `r3o`, `r4o`] * `fgr, pha`:前景和透明度通道输出,范围为 `0~1`。 * `rXo`:记忆输出。 我们只展示用 ONNX Runtime CUDA Backend 在 Python 上的使用范例。 载入模型: ```python import onnxruntime as ort sess = ort.InferenceSession('rvm_mobilenetv3_fp16.onnx') ``` 简单推断循环,但此方法不是最优化的: ```python import numpy as np rec = [ np.zeros([1, 1, 1, 1], dtype=np.float16) ] * 4 # 必须用模型一样的 dtype downsample_ratio = np.array([0.25], dtype=np.float32) # 必须是 FP32 for src in YOUR_VIDEO: # src 张量是 [B, C, H, W] 形状,必须用模型一样的 dtype fgr, pha, *rec = sess.run([], { 'src': src, 'r1i': rec[0], 'r2i': rec[1], 'r3i': rec[2], 'r4i': rec[3], 'downsample_ratio': downsample_ratio }) ``` 若使用 GPU,上例会将记忆输出传回到 CPU,再在下一帧时传回到 GPU。这种传输是无意义的,因为记忆值可以留在 GPU 上。下例使用 `iobinding` 来杜绝无用的传输。 ```python import onnxruntime as ort import numpy as np # 载入模型 sess = ort.InferenceSession('rvm_mobilenetv3_fp16.onnx') # 创建 io binding. io = sess.io_binding() # 在 CUDA 上创建张量 rec = [ ort.OrtValue.ortvalue_from_numpy(np.zeros([1, 1, 1, 1], dtype=np.float16), 'cuda') ] * 4 downsample_ratio = ort.OrtValue.ortvalue_from_numpy(np.asarray([0.25], dtype=np.float32), 'cuda') # 设置输出项 for name in ['fgr', 'pha', 'r1o', 'r2o', 'r3o', 'r4o']: io.bind_output(name, 'cuda') # 推断 for src in YOUR_VIDEO: io.bind_cpu_input('src', src) io.bind_ortvalue_input('r1i', rec[0]) io.bind_ortvalue_input('r2i', rec[1]) io.bind_ortvalue_input('r3i', rec[2]) io.bind_ortvalue_input('r4i', rec[3]) io.bind_ortvalue_input('downsample_ratio', downsample_ratio) sess.run_with_iobinding(io) fgr, pha, *rec = io.get_outputs() # 只将 `fgr` 和 `pha` 回传到 CPU fgr = fgr.numpy() pha = pha.numpy() ``` 注:若你使用其他推断框架,可能有些 ONNX ops 不被支持,需被替换。可以参考 [onnx](https://github.com/PeterL1n/RobustVideoMatting/tree/onnx) 分支的代码做自行导出。


### TensorFlow 范例: ```python import tensorflow as tf model = tf.keras.models.load_model('rvm_mobilenetv3_tf') model = tf.function(model) rec = [ tf.constant(0.) ] * 4 # 初始记忆 downsample_ratio = tf.constant(0.25) # 下采样率,根据视频调整 for src in YOUR_VIDEO: # src 张量是 [B, H, W, C] 的形状,而不是 [B, C, H, W]! out = model([src, *rec, downsample_ratio]) fgr, pha, *rec = out['fgr'], out['pha'], out['r1o'], out['r2o'], out['r3o'], out['r4o'] ``` 注意,在 TensorFlow 上,所有张量都是 Channel Last 的格式。 我们提供 TensorFlow 的原始模型代码,请参考 [tensorflow](https://github.com/PeterL1n/RobustVideoMatting/tree/tensorflow) 分支。您可自行将 PyTorch 的权值转到 TensorFlow 模型上。


### TensorFlow.js 我们在 [tfjs](https://github.com/PeterL1n/RobustVideoMatting/tree/tfjs) 分支提供范例代码。代码简单易懂,解释如何正确使用模型。


### CoreML 我们只展示在 Python 下通过 `coremltools` 使用 CoreML 模型。在部署时,同样逻辑可用于 Swift。模型的循环记忆输入不需要在处理第一帧时提供。CoreML 内部会自动创建零张量作为初始记忆。 ```python import coremltools as ct model = ct.models.model.MLModel('rvm_mobilenetv3_1920x1080_s0.25_int8.mlmodel') r1, r2, r3, r4 = None, None, None, None for src in YOUR_VIDEO: # src 是 PIL.Image. if r1 is None: # 初始帧, 不用提供循环记忆 inputs = {'src': src} else: # 剩余帧,提供循环记忆 inputs = {'src': src, 'r1i': r1, 'r2i': r2, 'r3i': r3, 'r4i': r4} outputs = model.predict(inputs) fgr = outputs['fgr'] # PIL.Image pha = outputs['pha'] # PIL.Image r1 = outputs['r1o'] # Numpy array r2 = outputs['r2o'] # Numpy array r3 = outputs['r3o'] # Numpy array r4 = outputs['r4o'] # Numpy array ``` 我们的 CoreML 模型只支持固定分辨率。如果你需要其他分辨率,可自行导出。导出代码见 [coreml](https://github.com/PeterL1n/RobustVideoMatting/tree/coreml) 分支。 ================================================ FILE: documentation/misc/aim_test.txt ================================================ boy-1518482_1920.png girl-1219339_1920.png girl-1467820_1280.png girl-beautiful-young-face-53000.png long-1245787_1920.png model-600238_1920.png pexels-photo-58463.png sea-sunny-person-beach.png wedding-dresses-1486260_1280.png woman-952506_1920 (1).png woman-morning-bathrobe-bathroom.png ================================================ FILE: documentation/misc/d646_test.txt ================================================ test_13.png test_16.png test_18.png test_22.png test_32.png test_35.png test_39.png test_42.png test_46.png test_4.png test_6.png ================================================ FILE: documentation/misc/dvm_background_test_clips.txt ================================================ 0000 0001 0002 0004 0005 0007 0008 0009 0010 0012 0013 0014 0015 0016 0017 0018 0019 0021 0022 0023 0024 0025 0027 0029 0030 0032 0033 0034 0035 0037 0038 0039 0040 0041 0042 0043 0045 0046 0047 0048 0050 0051 0052 0054 0055 0057 0058 0059 0060 0061 0062 0063 0064 0065 0066 0068 0070 0071 0073 0074 0075 0077 0078 0079 0080 0081 0082 0083 
0084 0085 0086 0089 0097 0100 0101 0102 0103 0104 0106 0107 0109 0110 0111 0113 0115 0116 0117 0119 0120 0121 0122 0123 0124 0125 0126 0127 0128 0129 0130 0131 0132 0133 0134 0135 0136 0137 0143 0145 0147 0148 0150 0159 0160 0161 0162 0165 0166 0168 0172 0174 0175 0176 0178 0181 0182 0183 0184 0185 0187 0194 0198 0200 0201 0207 0210 0211 0212 0215 0217 0218 0219 0220 0222 0223 0224 0225 0226 0227 0229 0230 0231 0232 0233 0234 0235 0237 0240 0241 0242 0243 0244 0245 ================================================ FILE: documentation/misc/dvm_background_train_clips.txt ================================================ 0000 0002 0003 0004 0005 0006 0007 0009 0010 0012 0013 0014 0015 0016 0019 0021 0022 0023 0024 0025 0028 0029 0030 0031 0032 0034 0035 0036 0037 0039 0040 0041 0042 0043 0044 0046 0047 0048 0049 0050 0051 0052 0053 0054 0055 0056 0057 0058 0060 0061 0062 0063 0064 0065 0066 0067 0068 0069 0070 0071 0073 0074 0075 0076 0077 0078 0079 0081 0082 0087 0088 0099 0100 0101 0104 0105 0107 0108 0109 0110 0111 0112 0113 0114 0115 0117 0118 0119 0120 0122 0123 0124 0125 0127 0128 0129 0130 0131 0132 0133 0134 0135 0136 0137 0138 0139 0140 0141 0142 0144 0146 0147 0148 0150 0151 0152 0153 0154 0155 0156 0157 0158 0159 0160 0161 0163 0164 0165 0167 0168 0169 0170 0171 0172 0174 0175 0176 0177 0178 0180 0181 0182 0184 0185 0187 0188 0189 0190 0192 0193 0194 0195 0196 0197 0198 0199 0200 0202 0203 0204 0205 0206 0207 0208 0209 0210 0211 0212 0213 0214 0215 0217 0218 0219 0220 0221 0222 0223 0224 0225 0226 0227 0229 0230 0231 0233 0234 0235 0236 0237 0238 0240 0241 0242 0243 0244 0245 0246 0247 0248 0249 0250 0251 0252 0253 0254 0255 0256 0257 0258 0259 0260 0261 0262 0263 0264 0265 0266 0267 0268 0269 0270 0271 0272 0273 0274 0275 0276 0277 0278 0279 0280 0281 0282 0283 0284 0285 0286 0287 0288 0289 0290 0291 0292 0293 0294 0297 0298 0299 0300 0301 0302 0303 0304 0305 0306 0308 0309 0310 0311 0312 0313 0314 0315 0316 0317 0319 0320 0321 0322 0323 0324 0325 0326 0327 
0328 0329 0330 0331 0332 0333 0335 0336 0337 0338 0339 0341 0342 0344 0345 0346 0348 0349 0352 0353 0356 0357 0358 0359 0360 0361 0362 0363 0364 0365 0366 0368 0369 0370 0371 0372 0373 0374 0375 0376 0377 0378 0379 0380 0381 0382 0383 0384 0385 0386 0387 0388 0389 0391 0392 0393 0394 0395 0397 0398 0399 0400 0401 0402 0403 0404 0405 0406 0407 0408 0409 0410 0411 0413 0414 0415 0416 0417 0419 0420 0421 0422 0423 0424 0425 0426 0427 0428 0429 0431 0433 0434 0435 0436 0437 0438 0439 0440 0441 0442 0443 0445 0446 0447 0448 0449 0450 0451 0452 0453 0454 0456 0457 0458 0459 0462 0463 0464 0465 0466 0467 0468 0469 0470 0471 0472 0473 0474 0475 0476 0477 0478 0479 0480 0481 0482 0483 0484 0485 0486 0487 0488 0489 0490 0491 0492 0493 0494 0499 0501 0502 0503 0504 0505 0506 0507 0509 0510 0511 0512 0513 0514 0515 0517 0518 0519 0520 0521 0522 0524 0526 0527 0529 0530 0534 0535 0536 0538 0539 0541 0542 0543 0544 0545 0546 0548 0549 0550 0552 0554 0555 0556 0557 0558 0559 0560 0561 0562 0563 0564 0565 0566 0567 0568 0571 0572 0573 0574 0575 0576 0577 0578 0579 0580 0581 0582 0583 0584 0586 0587 0589 0590 0591 0592 0594 0595 0596 0597 0598 0600 0601 0602 0603 0604 0605 0606 0608 0609 0610 0611 0612 0613 0614 0615 0616 0617 0618 0619 0620 0624 0625 0626 0627 0628 0629 0630 0631 0634 0635 0636 0637 0638 0639 0640 0641 0642 0643 0644 0645 0646 0647 0648 0650 0651 0652 0654 0655 0656 0658 0659 0660 0661 0662 0663 0664 0665 0666 0667 0669 0670 0671 0672 0673 0674 0675 0676 0677 0678 0679 0680 0681 0682 0683 0684 0685 0686 0687 0689 0690 0691 0692 0693 0694 0695 0696 0697 0698 0699 0700 0701 0702 0703 0704 0705 0706 0707 0708 0709 0710 0711 0712 0713 0714 0715 0716 0717 0718 0719 0720 0721 0723 0724 0725 0726 0727 0729 0730 0731 0732 0733 0734 0735 0736 0738 0740 0741 0742 0743 0744 0746 0747 0748 0749 0750 0752 0753 0754 0755 0756 0757 0758 0759 0760 0762 0763 0764 0765 0766 0767 0768 0770 0771 0772 0773 0774 0775 0776 0777 0778 0779 0780 0781 0782 0783 0784 0786 0787 0788 0789 0790 
0791 0792 0793 0794 0795 0796 0797 0798 0800 0801 0804 0806 0808 0809 0811 0812 0813 0814 0815 0816 0817 0819 0823 0824 0825 0827 0828 0829 0830 0831 0832 0833 0834 0835 0836 0837 0840 0841 0842 0847 0848 0850 0851 0852 0853 0854 0855 0856 0857 0858 0859 0860 0861 0862 0864 0867 0868 0869 0870 0871 0872 0873 0874 0876 0877 0878 0879 0880 0881 0882 0883 0885 0886 0887 0889 0890 0891 0892 0893 0894 0895 0896 0899 0900 0901 0902 0903 0904 0905 0906 0907 0908 0909 0910 0911 0912 0913 0914 0915 0916 0917 0918 0919 0921 0922 0923 0924 0925 0926 0927 0929 0930 0931 0932 0933 0934 0935 0936 0937 0939 0940 0941 0942 0945 0946 0947 0948 0949 0950 0951 0952 0953 0954 0955 0956 0957 0958 0960 0962 0963 0964 0965 0966 0967 0968 0969 0970 0971 0973 0974 0976 0977 0978 0979 0980 0981 0982 0983 0984 0985 0986 0987 0988 0989 0990 0991 0992 0993 0994 0995 0996 0997 0998 0999 1000 1001 1002 1003 1005 1006 1008 1009 1010 1011 1012 1013 1014 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1028 1029 1030 1032 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1052 1053 1054 1055 1056 1057 1058 1061 1062 1063 1064 1065 1066 1067 1068 1069 1072 1075 1076 1077 1078 1079 1081 1082 1083 1084 1087 1088 1089 1090 1096 1097 1098 1099 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1128 1129 1130 1131 1132 1134 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1165 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1199 1200 1201 1202 1203 1204 1206 1207 1208 1211 1212 1213 1215 1216 1217 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1237 1238 1239 1240 1241 1242 1245 1246 1248 1249 1252 1253 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 
1272 1273 1274 1275 1277 1278 1279 1280 1281 1282 1283 1284 1287 1288 1289 1290 1291 1293 1294 1295 1296 1297 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1350 1351 1352 1353 1354 1355 1356 1357 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1378 1379 1380 1381 1382 1383 1385 1386 1387 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1402 1403 1405 1406 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1420 1421 1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1493 1494 1495 1496 1497 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557 1558 1560 1561 1562 1563 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596 1597 1598 1599 1600 1601 1602 1603 1604 1605 1606 1608 1609 1611 1612 1613 1614 1615 1616 1617 1618 1619 1620 1621 1622 1623 1624 1625 1626 1627 1628 1629 1630 1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642 1643 1644 1645 1646 1649 1650 1651 1652 1653 1654 1655 1656 1657 1658 1659 1660 1661 1662 1663 1664 1665 1666 1667 1668 1669 1670 1671 1672 1673 1675 1676 1677 1678 1679 1680 1681 1682 1683 1684 1685 1686 1687 1688 1689 1690 1691 1692 1693 1694 1695 1696 1697 1698 1699 1700 1701 1702 1703 1704 
1705 1706 1707 1708 1709 1710 1712 1713 1714 1715 1716 1717 1718 1719 1720 1721 1722 1723 1724 1725 1726 1727 1728 1729 1730 1731 1732 1733 1734 1735 1736 1737 1738 1739 1740 1741 1742 1743 1744 1745 1746 1747 1748 1749 1750 1751 1752 1753 1754 1755 1756 1757 1758 1759 1760 1761 1762 1763 1764 1765 1766 1767 1768 1769 1770 1771 1772 1773 1774 1775 1776 1778 1779 1780 1781 1782 1784 1785 1786 1787 1788 1789 1790 1791 1792 1793 1794 1795 1796 1797 1798 1799 1800 1804 1806 1807 1808 1809 1811 1812 1813 1814 1815 1816 1817 1818 1819 1820 1821 1822 1823 1825 1826 1827 1828 1829 1831 1833 1834 1835 1836 1837 1838 1839 1840 1842 1843 1844 1845 1846 1847 1848 1849 1850 1851 1852 1853 1854 1855 1856 1857 1858 1859 1861 1862 1863 1864 1865 1866 1867 1868 1869 1870 1871 1872 1873 1874 1875 1876 1877 1879 1880 1881 1882 1886 1887 1889 1891 1892 1893 1894 1896 1897 1898 1899 1900 1901 1902 1903 1904 1905 1906 1907 1908 1909 1910 1911 1912 1913 1914 1915 1916 1917 1918 1919 1920 1921 1922 1923 1924 1925 1926 1927 1928 1929 1930 1931 1932 1933 1934 1935 1936 1937 1938 1939 1940 1941 1942 1943 1944 1945 1946 1947 1948 1949 1950 1952 1953 1954 1955 1956 1958 1959 1960 1961 1962 1963 1964 1965 1966 1967 1968 1969 1970 1971 1972 1973 1974 1975 1976 1977 1978 1980 1981 1982 1983 1984 1985 1986 1988 1989 1990 1991 1992 1993 1994 1995 1997 1998 1999 2000 2002 2003 2004 2005 2007 2008 2010 2011 2012 2013 2014 2015 2016 2017 2018 2019 2020 2021 2022 2024 2025 2026 2027 2028 2029 2030 2031 2032 2035 2036 2037 2038 2039 2040 2041 2042 2043 2045 2046 2047 2049 2050 2052 2053 2054 2056 2057 2059 2060 2061 2064 2066 2068 2069 2070 2071 2072 2073 2074 2075 2077 2078 2079 2080 2081 2082 2083 2084 2085 2086 2087 2088 2089 2090 2091 2092 2093 2094 2096 2097 2098 2099 2100 2101 2102 2103 2104 2105 2107 2108 2109 2111 2112 2113 2114 2115 2116 2117 2119 2120 2121 2122 2123 2124 2125 2126 2127 2128 2129 2130 2131 2132 2133 2134 2135 2136 2137 2138 2139 2140 2141 2143 2144 2145 2146 2147 2148 2149 2152 
2153 2155 2158 2159 2160 2161 2162 2163 2164 2165 2166 2167 2168 2169 2170 2171 2174 2176 2177 2178 2179 2180 2181 2182 2183 2184 2186 2187 2188 2189 2190 2191 2192 2193 2195 2197 2198 2199 2200 2201 2202 2203 2204 2205 2206 2207 2208 2209 2210 2211 2212 2213 2214 2215 2216 2217 2218 2219 2220 2221 2222 2223 2225 2226 2227 2228 2229 2230 2231 2232 2233 2234 2235 2236 2238 2239 2240 2241 2242 2243 2244 2245 2246 2247 2248 2249 2250 2251 2252 2253 2255 2256 2257 2258 2259 2260 2261 2263 2264 2265 2267 2268 2269 2270 2271 2272 2273 2274 2275 2276 2277 2278 2279 2280 2281 2282 2283 2284 2285 2286 2287 2288 2290 2291 2292 2293 2294 2295 2297 2299 2300 2301 2302 2303 2304 2305 2311 2312 2313 2314 2315 2316 2317 2318 2319 2320 2322 2324 2325 2326 2329 2331 2332 2334 2335 2336 2337 2338 2339 2340 2341 2342 2343 2344 2345 2347 2349 2350 2351 2352 2353 2355 2356 2358 2359 2360 2361 2362 2363 2364 2365 2367 2368 2369 2370 2372 2373 2374 2375 2376 2377 2378 2379 2380 2381 2382 2383 2384 2386 2389 2390 2391 2392 2393 2394 2395 2396 2397 2398 2399 2400 2401 2402 2403 2404 2406 2407 2408 2409 2410 2411 2412 2413 2414 2416 2417 2418 2419 2420 2421 2422 2423 2424 2425 2426 2427 2428 2429 2431 2432 2433 2434 2437 2439 2440 2441 2442 2443 2444 2445 2446 2448 2449 2450 2451 2452 2453 2454 2455 2456 2457 2460 2461 2462 2463 2464 2465 2466 2467 2468 2469 2470 2472 2474 2475 2479 2480 2481 2482 2483 2484 2485 2486 2487 2488 2489 2490 2491 2492 2493 2494 2495 2496 2497 2499 2500 2501 2502 2503 2504 2505 2506 2507 2508 2509 2510 2511 2512 2513 2514 2515 2516 2517 2518 2519 2520 2521 2522 2523 2524 2525 2530 2534 2536 2537 2539 2540 2541 2542 2543 2544 2545 2546 2548 2549 2550 2551 2553 2554 2555 2556 2559 2560 2562 2567 2570 2573 2575 2576 2578 2579 2582 2585 2588 2590 2593 2594 2595 2596 2597 2598 2600 2601 2602 2603 2604 2605 2606 2607 2612 2613 2615 2616 2617 2618 2619 2620 2622 2623 2624 2625 2629 2630 2633 2634 2635 2637 2638 2639 2641 2643 2644 2646 2648 2649 2650 2652 2654 2656 2659 
2661 2662 2663 2664 2665 2667 2669 2670 2671 2672 2674 2675 2677 2679 2680 2682 2684 2685 2687 2689 2691 2692 2693 2694 2695 2696 2697 2698 2699 2700 2701 2702 2703 2705 2706 2707 2710 2713 2714 2715 2717 2718 2720 2721 2722 2726 2727 2729 2733 2734 2737 2739 2740 2741 2742 2743 2745 2747 2748 2749 2755 2756 2757 2758 2759 2761 2762 2763 2765 2767 2768 2769 2770 2773 2775 2776 2780 2781 2782 2785 2787 2790 2791 2793 2794 2795 2796 2797 2798 2800 2801 2802 2803 2806 2808 2810 2812 2813 2814 2815 2816 2818 2819 2820 2821 2822 2823 2825 2827 2828 2829 2830 2831 2832 2833 2834 2835 2837 2838 2842 2843 2844 2845 2847 2848 2849 2850 2851 2852 2853 2854 2856 2857 2858 2860 2861 2862 2863 2864 2865 2866 2868 2869 2871 2874 2875 2877 2879 2881 2882 2884 2885 2887 2888 2889 2890 2891 2893 2894 2895 2896 2900 2902 2903 2904 2905 2906 2907 2908 2911 2912 2914 2915 2916 2917 2918 2919 2921 2923 2924 2925 2926 2928 2929 2930 2931 2932 2933 2934 2935 2936 2938 2939 2940 2941 2944 2945 2946 2947 2952 2954 2957 2958 2960 2962 2963 2966 2967 2968 2969 2970 2971 2972 2973 2974 2975 2976 2981 2982 2984 2985 2986 2988 2993 2994 2996 2998 2999 3000 3001 3002 3004 3006 3007 3008 3010 3015 3016 3017 3018 3019 3020 3021 3022 3024 3026 3027 3029 3033 3034 3035 3036 3037 3040 3041 3042 3043 3044 3045 3046 3047 3048 3049 3050 3051 3053 3056 3057 3058 3059 3060 3061 3063 3065 3066 3068 3069 3070 3074 3077 3078 3079 3080 3081 3082 3083 3084 3085 3086 3087 3094 3095 3097 3098 3100 3101 3102 3103 3105 3106 3108 3109 3110 3111 3112 3113 3114 3115 3116 3117 3118 3119 3121 3123 3124 3125 3127 3128 3130 3131 3133 3134 3136 3137 3140 3142 3143 3144 3145 3147 3148 3152 3153 3155 3156 3157 3159 3160 3162 3164 3165 3166 3168 3170 3171 3173 3176 3178 3179 3180 3181 3184 3187 3188 3190 3191 3194 3197 3199 3200 3201 3202 3204 3209 3210 3211 3212 3214 3215 3217 3218 3219 3221 3222 3223 3224 3226 3228 3229 3230 3232 3237 3238 3239 3240 3241 3242 3243 3245 3247 3248 3250 3251 3252 3253 3254 3257 3258 3259 3260 
3262 3264 3266 3267 3269 3270 3275 3278 3279 3280 3282 3284 3285 3286 3287 3288 3289 3292 3293 3295 3296 3298 3299 3300 3301 3302 3303 3304 3309 3311 3312 3315 3316 3317 3318 3319 3322 3325 3328 3332 3334 3339 3340 3342 3346 3348 3349 3350 3351 3352 3354 3355 3357 3358 3361 3363 3364 3365 3366 3367 3368 3369 3370 3373 3374 3377 3378 3380 3381 3382 3383 3384 3385 3386 3391 3393 3395 3396 3397 3398 3399 3400 3401 3402 3403 3404 3405 3407 3408 3409 3410 3411 3415 3416 3418 3420 3422 3423 3424 3427 3428 3431 3433 3435 3437 3438 3439 3440 3441 3442 3443 3444 3446 3449 3450 3451 3452 3453 3454 3455 3456 3457 3460 3462 3463 3464 3465 3466 3467 3468 3470 3475 3476 3477 3483 3484 3486 3487 3489 3492 3496 3497 3498 3500 3501 3502 3505 3507 3508 3509 3510 3511 3512 3513 3514 3515 3517 3518 3519 3521 3524 3525 3526 3528 3529 3532 3536 3541 3542 3543 3544 3545 3546 3549 3550 3551 3552 3554 3556 3557 3558 3559 3560 3562 3563 3564 3566 3567 3568 3571 3572 3573 3574 3575 3576 3578 3579 3580 3582 3584 3585 3588 3590 3592 3593 3594 3599 3602 3605 3606 3608 3611 3612 3615 3617 3620 3621 3622 3623 3624 3625 3626 3629 3630 3632 3633 3636 3637 3638 3641 3644 3646 3647 3648 3649 3654 3655 3656 3657 3660 3662 3667 3671 3672 3673 3674 3675 3676 3678 3679 3680 3681 3682 3683 3685 3687 3691 3692 3694 3695 3697 3698 3699 3700 3701 3703 3704 3705 3707 3709 3711 3712 3715 3717 3718 3719 3720 3721 3723 3724 3725 3726 3728 3729 3733 3734 3735 3737 3738 3739 3741 3743 3746 3748 3750 3753 3754 3756 3757 3759 3760 3762 3764 3765 3766 3768 3772 3773 3774 3776 3777 3779 3780 3782 3783 3784 3785 3786 3790 3792 3793 3794 3798 3799 3800 3801 3802 3803 3804 3807 3809 3813 3814 3816 3819 3822 3823 3824 3826 3827 3828 3829 3830 3831 3832 3833 3834 3836 3837 3838 3839 3840 3841 3842 3844 3845 3847 3848 3849 3850 3852 3854 3855 3856 3857 3858 3861 3862 3863 3865 3869 3872 3873 3874 3879 3880 3883 3885 3886 3887 3888 3889 3890 3891 3893 3894 3895 3897 3898 3899 3900 3901 3903 3904 3906 3914 3915 3916 3920 3923 
3924 3925 3926 3928 3929 3931 3932 3934 3935 3937 3939 3942 3943 3945 3949 3950 3952 3955 3956 3963 3965 3966 3967 3969 3970 3971 3974 3976 3977 3978 3980 3981 3982 3983 3984 3987 3988 3992 3995 3996 3997 3999 ================================================ FILE: documentation/misc/imagematte_train.txt ================================================ 10743257206_18e7f44f2e_b.jpg 10845279884_d2d4c7b4d1_b.jpg 1-1252426161dfXY.jpg 1-1255621189mTnS.jpg 1-1259162624NMFK.jpg 1-1259245823Un3j.jpg 11363165393_05d7a21d76_b.jpg 131686738165901828.jpg 13564741125_753939e9ce_o.jpg 14731860273_5b40b19b51_o.jpg 16087-a-young-woman-showing-a-bitten-green-apple-pv.jpg 1609484818_b9bb12b.jpg 17620-a-beautiful-woman-in-a-bikini-pv.jpg 20672673163_20c8467827_b.jpg 3262986095_2d5afe583c_b.jpg 3588101233_f91aa5e3a3.jpg 3858897226_cae5b75963_o.jpg 4889657410_2d503ca287_o.jpg 4981835627_c4e6c4ffa8_o.jpg 5025666458_576b974455_o.jpg 5149410930_3a943dc43f_b.jpg 539641011387760661.jpg 5892503248_4b882863c7_o.jpg 604673748289192179.jpg 606189768665464996.jpg 624753897218113578.jpg 657454154710122500.jpg 664308724952072193.jpg 7669262460_e4be408343_b.jpg 8244818049_dfa59a3eb8_b.jpg 8688417335_01f3bafbe5_o.jpg 9434599749_e7ccfc7812_b.jpg Aaron_Friedman_Headshot.jpg arrgh___r___28_by_mjranum_stock.jpg arrgh___r___29_by_mjranum_stock.jpg arrgh___r___30_by_mjranum_stock.jpg a-single-person-1084191_960_720.jpg ballerina-855652_1920.jpg beautiful-19075_960_720.jpg boy-454633_1920.jpg bride-2819673_1920.jpg bride-442894_1920.jpg face-1223346_960_720.jpg fashion-model-portrait.jpg fashion-model-pose.jpg girl-1535859_1920.jpg Girl_in_front_of_a_green_background.jpg goth_by_bugidifino-d4w7zms.jpg h_0.jpg h_100.jpg h_101.jpg h_102.jpg h_103.jpg h_104.jpg h_105.jpg h_106.jpg h_107.jpg h_108.jpg h_109.jpg h_10.jpg h_111.jpg h_112.jpg h_113.jpg h_114.jpg h_115.jpg h_116.jpg h_117.jpg h_118.jpg h_119.jpg h_11.jpg h_120.jpg h_121.jpg h_122.jpg h_123.jpg h_124.jpg h_125.jpg h_126.jpg h_127.jpg h_128.jpg 
h_129.jpg h_12.jpg h_130.jpg h_131.jpg h_132.jpg h_133.jpg h_134.jpg h_135.jpg h_136.jpg h_137.jpg h_138.jpg h_139.jpg h_13.jpg h_140.jpg h_141.jpg h_142.jpg h_143.jpg h_144.jpg h_145.jpg h_146.jpg h_147.jpg h_148.jpg h_149.jpg h_14.jpg h_151.jpg h_152.jpg h_153.jpg h_154.jpg h_155.jpg h_156.jpg h_157.jpg h_158.jpg h_159.jpg h_15.jpg h_160.jpg h_161.jpg h_162.jpg h_163.jpg h_164.jpg h_165.jpg h_166.jpg h_167.jpg h_168.jpg h_169.jpg h_170.jpg h_171.jpg h_172.jpg h_173.jpg h_174.jpg h_175.jpg h_176.jpg h_177.jpg h_178.jpg h_179.jpg h_17.jpg h_180.jpg h_181.jpg h_182.jpg h_183.jpg h_184.jpg h_185.jpg h_186.jpg h_187.jpg h_188.jpg h_189.jpg h_18.jpg h_190.jpg h_191.jpg h_192.jpg h_193.jpg h_194.jpg h_195.jpg h_196.jpg h_197.jpg h_198.jpg h_199.jpg h_19.jpg h_1.jpg h_200.jpg h_201.jpg h_202.jpg h_204.jpg h_205.jpg h_206.jpg h_207.jpg h_208.jpg h_209.jpg h_20.jpg h_210.jpg h_211.jpg h_212.jpg h_213.jpg h_214.jpg h_215.jpg h_216.jpg h_217.jpg h_218.jpg h_219.jpg h_21.jpg h_220.jpg h_221.jpg h_222.jpg h_223.jpg h_224.jpg h_225.jpg h_226.jpg h_227.jpg h_228.jpg h_229.jpg h_22.jpg h_230.jpg h_231.jpg h_232.jpg h_233.jpg h_234.jpg h_235.jpg h_236.jpg h_237.jpg h_238.jpg h_239.jpg h_23.jpg h_240.jpg h_241.jpg h_242.jpg h_243.jpg h_244.jpg h_245.jpg h_247.jpg h_248.jpg h_249.jpg h_24.jpg h_250.jpg h_251.jpg h_252.jpg h_253.jpg h_254.jpg h_255.jpg h_256.jpg h_257.jpg h_258.jpg h_259.jpg h_25.jpg h_260.jpg h_261.jpg h_262.jpg h_263.jpg h_264.jpg h_265.jpg h_266.jpg h_268.jpg h_269.jpg h_26.jpg h_270.jpg h_271.jpg h_272.jpg h_273.jpg h_274.jpg h_276.jpg h_277.jpg h_278.jpg h_279.jpg h_27.jpg h_280.jpg h_281.jpg h_282.jpg h_283.jpg h_284.jpg h_285.jpg h_286.jpg h_287.jpg h_288.jpg h_289.jpg h_28.jpg h_290.jpg h_291.jpg h_292.jpg h_293.jpg h_294.jpg h_295.jpg h_296.jpg h_297.jpg h_298.jpg h_299.jpg h_29.jpg h_300.jpg h_301.jpg h_302.jpg h_303.jpg h_304.jpg h_305.jpg h_307.jpg h_308.jpg h_309.jpg h_30.jpg h_310.jpg h_311.jpg h_312.jpg h_313.jpg h_314.jpg h_315.jpg h_316.jpg h_317.jpg 
h_318.jpg h_319.jpg h_31.jpg h_320.jpg h_321.jpg h_322.jpg h_323.jpg h_324.jpg h_325.jpg h_326.jpg h_327.jpg h_329.jpg h_32.jpg h_33.jpg h_34.jpg h_35.jpg h_36.jpg h_37.jpg h_38.jpg h_39.jpg h_3.jpg h_40.jpg h_41.jpg h_42.jpg h_43.jpg h_44.jpg h_45.jpg h_46.jpg h_47.jpg h_48.jpg h_49.jpg h_4.jpg h_50.jpg h_51.jpg h_52.jpg h_53.jpg h_54.jpg h_55.jpg h_56.jpg h_57.jpg h_58.jpg h_59.jpg h_5.jpg h_60.jpg h_61.jpg h_62.jpg h_63.jpg h_65.jpg h_67.jpg h_68.jpg h_69.jpg h_6.jpg h_70.jpg h_71.jpg h_72.jpg h_73.jpg h_74.jpg h_75.jpg h_76.jpg h_77.jpg h_78.jpg h_79.jpg h_7.jpg h_80.jpg h_81.jpg h_82.jpg h_83.jpg h_84.jpg h_85.jpg h_86.jpg h_87.jpg h_88.jpg h_89.jpg h_8.jpg h_90.jpg h_91.jpg h_92.jpg h_93.jpg h_94.jpg h_95.jpg h_96.jpg h_97.jpg h_98.jpg h_99.jpg h_9.jpg hair-flying-142210_1920.jpg headshotid_by_bokogreat_stock-d355xf3.jpg lil_white_goth_grl___23_by_mjranum_stock.jpg lil_white_goth_grl___26_by_mjranum_stock.jpg man-388104_960_720.jpg man_headshot.jpg MFettes-headshot.jpg model-429733_960_720.jpg model-610352_960_720.jpg model-858753_960_720.jpg model-858755_960_720.jpg model-873675_960_720.jpg model-873678_960_720.jpg model-873690_960_720.jpg model-881425_960_720.jpg model-881431_960_720.jpg model-female-girl-beautiful-51969.jpg Model_in_green_dress_3.jpg Modern_shingle_bob_haircut.jpg Motivate_(Fitness_model).jpg Official_portrait_of_Barack_Obama.jpg person-woman-eyes-face.jpg pink-hair-855660_960_720.jpg portrait-750774_1920.jpg Professor_Steven_Chu_ForMemRS_headshot.jpg sailor_flying_4_by_senshistock-d4k2wmr.jpg skin-care-937667_960_720.jpg sorcery___8_by_mjranum_stock.jpg t_62.jpg t_65.jpg test_32.jpg test_8.jpg train_245.jpg train_246.jpg train_255.jpg train_304.jpg train_333.jpg train_361.jpg train_395.jpg train_480.jpg train_488.jpg train_539.jpg wedding-846926_1920.jpg Wild_hair.jpg with_wings___pose_reference_by_senshistock-d6by42n_2.jpg with_wings___pose_reference_by_senshistock-d6by42n.jpg woman-1138435_960_720.jpg woman1.jpg woman2.jpg 
woman-659354_960_720.jpg woman-804072_960_720.jpg woman-868519_960_720.jpg Woman_in_white_shirt_on_August_2009_02.jpg women-878869_1920.jpg ================================================ FILE: documentation/misc/imagematte_valid.txt ================================================ 13564741125_753939e9ce_o.jpg 3858897226_cae5b75963_o.jpg 538724499685900405.jpg ballerina-855652_1920.jpg boy-454633_1920.jpg h_110.jpg h_150.jpg h_16.jpg h_246.jpg h_267.jpg h_275.jpg h_306.jpg h_328.jpg model-610352_960_720.jpg t_66.jpg ================================================ FILE: documentation/misc/spd_preprocess.py ================================================ # pip install supervisely import supervisely_lib as sly import numpy as np import os from PIL import Image from tqdm import tqdm # Download dataset from project_root = 'PATH_TO/Supervisely Person Dataset' # <-- Configure input project = sly.Project(project_root, sly.OpenMode.READ) output_path = 'OUTPUT_DIR' # <-- Configure output os.makedirs(os.path.join(output_path, 'train', 'src')) os.makedirs(os.path.join(output_path, 'train', 'msk')) os.makedirs(os.path.join(output_path, 'valid', 'src')) os.makedirs(os.path.join(output_path, 'valid', 'msk')) max_size = 2048 # <-- Configure max size for dataset in project.datasets: for item in tqdm(dataset): ann = sly.Annotation.load_json_file(dataset.get_ann_path(item), project.meta) msk = np.zeros(ann.img_size, dtype=np.uint8) for label in ann.labels: label.geometry.draw(msk, color=[255]) msk = Image.fromarray(msk) img = Image.open(dataset.get_img_path(item)).convert('RGB') if img.size[0] > max_size or img.size[1] > max_size: scale = max_size / max(img.size) img = img.resize((int(img.size[0] * scale), int(img.size[1] * scale)), Image.BILINEAR) msk = msk.resize((int(msk.size[0] * scale), int(msk.size[1] * scale)), Image.NEAREST) img.save(os.path.join(output_path, 'train', 'src', item.replace('.png', '.jpg'))) msk.save(os.path.join(output_path, 'train', 'msk', 
item.replace('.png', '.jpg'))) # Move first 100 to validation set names = os.listdir(os.path.join(output_path, 'train', 'src')) for name in tqdm(names[:100]): os.rename( os.path.join(output_path, 'train', 'src', name), os.path.join(output_path, 'valid', 'src', name)) os.rename( os.path.join(output_path, 'train', 'msk', name), os.path.join(output_path, 'valid', 'msk', name)) ================================================ FILE: documentation/training.md ================================================ # Training Documentation This documentation only shows the way to re-produce our [paper](https://peterl1n.github.io/RobustVideoMatting/). If you would like to remove or add a dataset to the training, you are responsible for adapting the training code yourself. ## Datasets The following datasets are used during our training. **IMPORTANT: If you choose to download our preprocessed versions. Please avoid repeated downloads and cache the data locally. All traffics cost our expense. Please be responsible. We may only provide the preprocessed version of a limited time.** ### Matting Datasets * [VideoMatte240K](https://grail.cs.washington.edu/projects/background-matting-v2/#/datasets) * Download JPEG SD version (6G) for stage 1 and 2. * Download JPEG HD version (60G) for stage 3 and 4. * Manually move clips `0000`, `0100`, `0200`, `0300` from the training set to a validation set. * ImageMatte * ImageMatte consists of [Distinctions-646](https://wukaoliu.github.io/HAttMatting/) and [Adobe Image Matting](https://sites.google.com/view/deepimagematting) datasets. * Only needed for stage 4. * You need to contact their authors to acquire. * After downloading both datasets, merge their samples together to form ImageMatte dataset. * Only keep samples of humans. 
* Full list of images we used in ImageMatte for training: * [imagematte_train.txt](/documentation/misc/imagematte_train.txt) * [imagematte_valid.txt](/documentation/misc/imagematte_valid.txt) * Full list of images we used for evaluation: * [aim_test.txt](/documentation/misc/aim_test.txt) * [d646_test.txt](/documentation/misc/d646_test.txt) ### Background Datasets * Video Backgrounds * We processed the [DVM Background Set](https://github.com/nowsyn/DVM) by selecting clips without humans and extracting only the first 100 frames as JPEG sequences. * Full list of clips we used: * [dvm_background_train_clips.txt](/documentation/misc/dvm_background_train_clips.txt) * [dvm_background_test_clips.txt](/documentation/misc/dvm_background_test_clips.txt) * You can download our preprocessed versions: * [Train set (14.6G)](https://robustvideomatting.blob.core.windows.net/data/BackgroundVideosTrain.tar) (Manually move some clips to a validation set) * [Test set (936M)](https://robustvideomatting.blob.core.windows.net/data/BackgroundVideosTest.tar) (Not needed for training. Only used for making synthetic test samples for evaluation) * Image Backgrounds * Train set: * We crawled 8000 suitable images from Google and Flickr. * We will not publish these images. * [Test set](https://grail.cs.washington.edu/projects/background-matting-v2/#/datasets) * We use the validation background set from the [BGMv2](https://grail.cs.washington.edu/projects/background-matting-v2/) project. * It contains about 200 images. * It is not used in our training; it is only used for making synthetic test samples for evaluation. * But if you just want to quickly try out training, you may use this as a temporary substitute for the train set.
### Segmentation Datasets * [COCO](https://cocodataset.org/#download) * Download [train2017.zip (18G)](http://images.cocodataset.org/zips/train2017.zip) * Download [panoptic_annotations_trainval2017.zip (821M)](http://images.cocodataset.org/annotations/panoptic_annotations_trainval2017.zip) * Note that our train script expects the panoptic version. * [YouTubeVIS 2021](https://youtube-vos.org/dataset/vis/) * Download the train set. No preprocessing needed. * [Supervisely Person Dataset](https://supervise.ly/explore/projects/supervisely-person-dataset-23304/datasets) * We used the supervisely library to convert their encoding to bitmap masks before using our script. We also downsized some of the large images to avoid a disk-loading bottleneck. * You can refer to [spd_preprocess.py](/documentation/misc/spd_preprocess.py) * Or, you can download our [preprocessed version (800M)](https://robustvideomatting.blob.core.windows.net/data/SuperviselyPersonDataset.tar) ## Training For reference, our training was done on data center machines with 48 CPU cores, 300G CPU memory, and 4 Nvidia V100 32G GPUs. During our official training, the code contained custom logic for our infrastructure. For release, the script has been cleaned up. This version of the code may contain bugs that were not present in our official training. If you find problems, please file an issue. After you have downloaded the datasets, please configure `train_config.py` to provide paths to your datasets. The training consists of 4 stages. For details, please refer to the [paper](https://peterl1n.github.io/RobustVideoMatting/).
### Stage 1 ```sh python train.py \ --model-variant mobilenetv3 \ --dataset videomatte \ --resolution-lr 512 \ --seq-length-lr 15 \ --learning-rate-backbone 0.0001 \ --learning-rate-aspp 0.0002 \ --learning-rate-decoder 0.0002 \ --learning-rate-refiner 0 \ --checkpoint-dir checkpoint/stage1 \ --log-dir log/stage1 \ --epoch-start 0 \ --epoch-end 20 ``` ### Stage 2 ```sh python train.py \ --model-variant mobilenetv3 \ --dataset videomatte \ --resolution-lr 512 \ --seq-length-lr 50 \ --learning-rate-backbone 0.00005 \ --learning-rate-aspp 0.0001 \ --learning-rate-decoder 0.0001 \ --learning-rate-refiner 0 \ --checkpoint checkpoint/stage1/epoch-19.pth \ --checkpoint-dir checkpoint/stage2 \ --log-dir log/stage2 \ --epoch-start 20 \ --epoch-end 22 ``` ### Stage 3 ```sh python train.py \ --model-variant mobilenetv3 \ --dataset videomatte \ --train-hr \ --resolution-lr 512 \ --resolution-hr 2048 \ --seq-length-lr 40 \ --seq-length-hr 6 \ --learning-rate-backbone 0.00001 \ --learning-rate-aspp 0.00001 \ --learning-rate-decoder 0.00001 \ --learning-rate-refiner 0.0002 \ --checkpoint checkpoint/stage2/epoch-21.pth \ --checkpoint-dir checkpoint/stage3 \ --log-dir log/stage3 \ --epoch-start 22 \ --epoch-end 23 ``` ### Stage 4 ```sh python train.py \ --model-variant mobilenetv3 \ --dataset imagematte \ --train-hr \ --resolution-lr 512 \ --resolution-hr 2048 \ --seq-length-lr 40 \ --seq-length-hr 6 \ --learning-rate-backbone 0.00001 \ --learning-rate-aspp 0.00001 \ --learning-rate-decoder 0.00005 \ --learning-rate-refiner 0.0002 \ --checkpoint checkpoint/stage3/epoch-22.pth \ --checkpoint-dir checkpoint/stage4 \ --log-dir log/stage4 \ --epoch-start 23 \ --epoch-end 28 ```


## Evaluation

We synthetically composite test samples onto both image and video backgrounds. Image samples (from D646, AIM) are augmented with synthetic motion.

We only provide the composited VideoMatte240K test set. They are used in our paper evaluation. For D646 and AIM, you need to acquire the data from their authors and composite them yourself. The composition scripts we used are saved in the `/evaluation` folder as a reference backup. You need to modify them based on your setup.

* [videomatte_512x288.tar (PNG 1.8G)](https://robustvideomatting.blob.core.windows.net/eval/videomatte_512x288.tar)
* [videomatte_1920x1080.tar (JPG 2.2G)](https://robustvideomatting.blob.core.windows.net/eval/videomatte_1920x1080.tar)

Evaluation scripts are provided in the `/evaluation` folder.

================================================
FILE: evaluation/evaluate_hr.py
================================================
"""
HR (High-Resolution) evaluation. We found using numpy is very slow for high resolution, so we moved it to PyTorch using CUDA.

Note, the script only does evaluation. You will first need to run inference yourself and save the results to disk.
Expected directory format for both prediction and ground-truth is:

    videomatte_1920x1080
        ├── videomatte_motion
            ├── pha
                ├── 0000
                    ├── 0000.png
            ├── fgr
                ├── 0000
                    ├── 0000.png
        ├── videomatte_static
            ├── pha
                ├── 0000
                    ├── 0000.png
            ├── fgr
                ├── 0000
                    ├── 0000.png

Prediction must have the exact file structure and file name as the ground-truth,
meaning that if the ground-truth is png/jpg, prediction should be png/jpg.
Example usage: python evaluate.py \ --pred-dir pred/videomatte_1920x1080 \ --true-dir true/videomatte_1920x1080 An excel sheet with evaluation results will be written to "pred/videomatte_1920x1080/videomatte_1920x1080.xlsx" """ import argparse import os import cv2 import kornia import numpy as np import xlsxwriter import torch from concurrent.futures import ThreadPoolExecutor from tqdm import tqdm class Evaluator: def __init__(self): self.parse_args() self.init_metrics() self.evaluate() self.write_excel() def parse_args(self): parser = argparse.ArgumentParser() parser.add_argument('--pred-dir', type=str, required=True) parser.add_argument('--true-dir', type=str, required=True) parser.add_argument('--num-workers', type=int, default=48) parser.add_argument('--metrics', type=str, nargs='+', default=[ 'pha_mad', 'pha_mse', 'pha_grad', 'pha_dtssd', 'fgr_mse']) self.args = parser.parse_args() def init_metrics(self): self.mad = MetricMAD() self.mse = MetricMSE() self.grad = MetricGRAD() self.dtssd = MetricDTSSD() def evaluate(self): tasks = [] position = 0 with ThreadPoolExecutor(max_workers=self.args.num_workers) as executor: for dataset in sorted(os.listdir(self.args.pred_dir)): if os.path.isdir(os.path.join(self.args.pred_dir, dataset)): for clip in sorted(os.listdir(os.path.join(self.args.pred_dir, dataset))): future = executor.submit(self.evaluate_worker, dataset, clip, position) tasks.append((dataset, clip, future)) position += 1 self.results = [(dataset, clip, future.result()) for dataset, clip, future in tasks] def write_excel(self): workbook = xlsxwriter.Workbook(os.path.join(self.args.pred_dir, f'{os.path.basename(self.args.pred_dir)}.xlsx')) summarysheet = workbook.add_worksheet('summary') metricsheets = [workbook.add_worksheet(metric) for metric in self.results[0][2].keys()] for i, metric in enumerate(self.results[0][2].keys()): summarysheet.write(i, 0, metric) summarysheet.write(i, 1, f'={metric}!B2') for row, (dataset, clip, metrics) in 
enumerate(self.results): for metricsheet, metric in zip(metricsheets, metrics.values()): # Write the header if row == 0: metricsheet.write(1, 0, 'Average') metricsheet.write(1, 1, f'=AVERAGE(C2:ZZ2)') for col in range(len(metric)): metricsheet.write(0, col + 2, col) colname = xlsxwriter.utility.xl_col_to_name(col + 2) metricsheet.write(1, col + 2, f'=AVERAGE({colname}3:{colname}9999)') metricsheet.write(row + 2, 0, dataset) metricsheet.write(row + 2, 1, clip) metricsheet.write_row(row + 2, 2, metric) workbook.close() def evaluate_worker(self, dataset, clip, position): framenames = sorted(os.listdir(os.path.join(self.args.pred_dir, dataset, clip, 'pha'))) metrics = {metric_name : [] for metric_name in self.args.metrics} pred_pha_tm1 = None true_pha_tm1 = None for i, framename in enumerate(tqdm(framenames, desc=f'{dataset} {clip}', position=position, dynamic_ncols=True)): true_pha = cv2.imread(os.path.join(self.args.true_dir, dataset, clip, 'pha', framename), cv2.IMREAD_GRAYSCALE) pred_pha = cv2.imread(os.path.join(self.args.pred_dir, dataset, clip, 'pha', framename), cv2.IMREAD_GRAYSCALE) true_pha = torch.from_numpy(true_pha).cuda(non_blocking=True).float().div_(255) pred_pha = torch.from_numpy(pred_pha).cuda(non_blocking=True).float().div_(255) if 'pha_mad' in self.args.metrics: metrics['pha_mad'].append(self.mad(pred_pha, true_pha)) if 'pha_mse' in self.args.metrics: metrics['pha_mse'].append(self.mse(pred_pha, true_pha)) if 'pha_grad' in self.args.metrics: metrics['pha_grad'].append(self.grad(pred_pha, true_pha)) if 'pha_conn' in self.args.metrics: metrics['pha_conn'].append(self.conn(pred_pha, true_pha)) if 'pha_dtssd' in self.args.metrics: if i == 0: metrics['pha_dtssd'].append(0) else: metrics['pha_dtssd'].append(self.dtssd(pred_pha, pred_pha_tm1, true_pha, true_pha_tm1)) pred_pha_tm1 = pred_pha true_pha_tm1 = true_pha if 'fgr_mse' in self.args.metrics: true_fgr = cv2.imread(os.path.join(self.args.true_dir, dataset, clip, 'fgr', framename), cv2.IMREAD_COLOR) 
pred_fgr = cv2.imread(os.path.join(self.args.pred_dir, dataset, clip, 'fgr', framename), cv2.IMREAD_COLOR) true_fgr = torch.from_numpy(true_fgr).float().div_(255) pred_fgr = torch.from_numpy(pred_fgr).float().div_(255) true_msk = true_pha > 0 metrics['fgr_mse'].append(self.mse(pred_fgr[true_msk], true_fgr[true_msk])) return metrics class MetricMAD: def __call__(self, pred, true): return (pred - true).abs_().mean() * 1e3 class MetricMSE: def __call__(self, pred, true): return ((pred - true) ** 2).mean() * 1e3 class MetricGRAD: def __init__(self, sigma=1.4): self.filter_x, self.filter_y = self.gauss_filter(sigma) self.filter_x = torch.from_numpy(self.filter_x).unsqueeze(0).cuda() self.filter_y = torch.from_numpy(self.filter_y).unsqueeze(0).cuda() def __call__(self, pred, true): true_grad = self.gauss_gradient(true) pred_grad = self.gauss_gradient(pred) return ((true_grad - pred_grad) ** 2).sum() / 1000 def gauss_gradient(self, img): img_filtered_x = kornia.filters.filter2D(img[None, None, :, :], self.filter_x, border_type='replicate')[0, 0] img_filtered_y = kornia.filters.filter2D(img[None, None, :, :], self.filter_y, border_type='replicate')[0, 0] return (img_filtered_x**2 + img_filtered_y**2).sqrt() @staticmethod def gauss_filter(sigma, epsilon=1e-2): half_size = np.ceil(sigma * np.sqrt(-2 * np.log(np.sqrt(2 * np.pi) * sigma * epsilon))) size = np.int(2 * half_size + 1) # create filter in x axis filter_x = np.zeros((size, size)) for i in range(size): for j in range(size): filter_x[i, j] = MetricGRAD.gaussian(i - half_size, sigma) * MetricGRAD.dgaussian( j - half_size, sigma) # normalize filter norm = np.sqrt((filter_x**2).sum()) filter_x = filter_x / norm filter_y = np.transpose(filter_x) return filter_x, filter_y @staticmethod def gaussian(x, sigma): return np.exp(-x**2 / (2 * sigma**2)) / (sigma * np.sqrt(2 * np.pi)) @staticmethod def dgaussian(x, sigma): return -x * MetricGRAD.gaussian(x, sigma) / sigma**2 class MetricDTSSD: def __call__(self, pred_t, pred_tm1, 
true_t, true_tm1): dtSSD = ((pred_t - pred_tm1) - (true_t - true_tm1)) ** 2 dtSSD = dtSSD.sum() / true_t.numel() dtSSD = dtSSD.sqrt() return dtSSD * 1e2 if __name__ == '__main__': Evaluator() ================================================ FILE: evaluation/evaluate_lr.py ================================================ """ LR (Low-Resolution) evaluation. Note, the script only does evaluation. You will need to first inference yourself and save the results to disk Expected directory format for both prediction and ground-truth is: videomatte_512x288 ├── videomatte_motion ├── pha ├── 0000 ├── 0000.png ├── fgr ├── 0000 ├── 0000.png ├── videomatte_static ├── pha ├── 0000 ├── 0000.png ├── fgr ├── 0000 ├── 0000.png Prediction must have the exact file structure and file name as the ground-truth, meaning that if the ground-truth is png/jpg, prediction should be png/jpg. Example usage: python evaluate.py \ --pred-dir PATH_TO_PREDICTIONS/videomatte_512x288 \ --true-dir PATH_TO_GROUNDTURTH/videomatte_512x288 An excel sheet with evaluation results will be written to "PATH_TO_PREDICTIONS/videomatte_512x288/videomatte_512x288.xlsx" """ import argparse import os import cv2 import numpy as np import xlsxwriter from concurrent.futures import ThreadPoolExecutor from tqdm import tqdm class Evaluator: def __init__(self): self.parse_args() self.init_metrics() self.evaluate() self.write_excel() def parse_args(self): parser = argparse.ArgumentParser() parser.add_argument('--pred-dir', type=str, required=True) parser.add_argument('--true-dir', type=str, required=True) parser.add_argument('--num-workers', type=int, default=48) parser.add_argument('--metrics', type=str, nargs='+', default=[ 'pha_mad', 'pha_mse', 'pha_grad', 'pha_conn', 'pha_dtssd', 'fgr_mad', 'fgr_mse']) self.args = parser.parse_args() def init_metrics(self): self.mad = MetricMAD() self.mse = MetricMSE() self.grad = MetricGRAD() self.conn = MetricCONN() self.dtssd = MetricDTSSD() def evaluate(self): tasks = [] position = 0 
with ThreadPoolExecutor(max_workers=self.args.num_workers) as executor: for dataset in sorted(os.listdir(self.args.pred_dir)): if os.path.isdir(os.path.join(self.args.pred_dir, dataset)): for clip in sorted(os.listdir(os.path.join(self.args.pred_dir, dataset))): future = executor.submit(self.evaluate_worker, dataset, clip, position) tasks.append((dataset, clip, future)) position += 1 self.results = [(dataset, clip, future.result()) for dataset, clip, future in tasks] def write_excel(self): workbook = xlsxwriter.Workbook(os.path.join(self.args.pred_dir, f'{os.path.basename(self.args.pred_dir)}.xlsx')) summarysheet = workbook.add_worksheet('summary') metricsheets = [workbook.add_worksheet(metric) for metric in self.results[0][2].keys()] for i, metric in enumerate(self.results[0][2].keys()): summarysheet.write(i, 0, metric) summarysheet.write(i, 1, f'={metric}!B2') for row, (dataset, clip, metrics) in enumerate(self.results): for metricsheet, metric in zip(metricsheets, metrics.values()): # Write the header if row == 0: metricsheet.write(1, 0, 'Average') metricsheet.write(1, 1, f'=AVERAGE(C2:ZZ2)') for col in range(len(metric)): metricsheet.write(0, col + 2, col) colname = xlsxwriter.utility.xl_col_to_name(col + 2) metricsheet.write(1, col + 2, f'=AVERAGE({colname}3:{colname}9999)') metricsheet.write(row + 2, 0, dataset) metricsheet.write(row + 2, 1, clip) metricsheet.write_row(row + 2, 2, metric) workbook.close() def evaluate_worker(self, dataset, clip, position): framenames = sorted(os.listdir(os.path.join(self.args.pred_dir, dataset, clip, 'pha'))) metrics = {metric_name : [] for metric_name in self.args.metrics} pred_pha_tm1 = None true_pha_tm1 = None for i, framename in enumerate(tqdm(framenames, desc=f'{dataset} {clip}', position=position, dynamic_ncols=True)): true_pha = cv2.imread(os.path.join(self.args.true_dir, dataset, clip, 'pha', framename), cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255 pred_pha = cv2.imread(os.path.join(self.args.pred_dir, dataset, 
clip, 'pha', framename), cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255 if 'pha_mad' in self.args.metrics: metrics['pha_mad'].append(self.mad(pred_pha, true_pha)) if 'pha_mse' in self.args.metrics: metrics['pha_mse'].append(self.mse(pred_pha, true_pha)) if 'pha_grad' in self.args.metrics: metrics['pha_grad'].append(self.grad(pred_pha, true_pha)) if 'pha_conn' in self.args.metrics: metrics['pha_conn'].append(self.conn(pred_pha, true_pha)) if 'pha_dtssd' in self.args.metrics: if i == 0: metrics['pha_dtssd'].append(0) else: metrics['pha_dtssd'].append(self.dtssd(pred_pha, pred_pha_tm1, true_pha, true_pha_tm1)) pred_pha_tm1 = pred_pha true_pha_tm1 = true_pha if 'fgr_mse' in self.args.metrics or 'fgr_mad' in self.args.metrics: true_fgr = cv2.imread(os.path.join(self.args.true_dir, dataset, clip, 'fgr', framename), cv2.IMREAD_COLOR).astype(np.float32) / 255 pred_fgr = cv2.imread(os.path.join(self.args.pred_dir, dataset, clip, 'fgr', framename), cv2.IMREAD_COLOR).astype(np.float32) / 255 true_msk = true_pha > 0 if 'fgr_mse' in self.args.metrics: metrics['fgr_mse'].append(self.mse(pred_fgr[true_msk], true_fgr[true_msk])) if 'fgr_mad' in self.args.metrics: metrics['fgr_mad'].append(self.mad(pred_fgr[true_msk], true_fgr[true_msk])) return metrics class MetricMAD: def __call__(self, pred, true): return np.abs(pred - true).mean() * 1e3 class MetricMSE: def __call__(self, pred, true): return ((pred - true) ** 2).mean() * 1e3 class MetricGRAD: def __init__(self, sigma=1.4): self.filter_x, self.filter_y = self.gauss_filter(sigma) def __call__(self, pred, true): pred_normed = np.zeros_like(pred) true_normed = np.zeros_like(true) cv2.normalize(pred, pred_normed, 1., 0., cv2.NORM_MINMAX) cv2.normalize(true, true_normed, 1., 0., cv2.NORM_MINMAX) true_grad = self.gauss_gradient(true_normed).astype(np.float32) pred_grad = self.gauss_gradient(pred_normed).astype(np.float32) grad_loss = ((true_grad - pred_grad) ** 2).sum() return grad_loss / 1000 def gauss_gradient(self, img): 
img_filtered_x = cv2.filter2D(img, -1, self.filter_x, borderType=cv2.BORDER_REPLICATE) img_filtered_y = cv2.filter2D(img, -1, self.filter_y, borderType=cv2.BORDER_REPLICATE) return np.sqrt(img_filtered_x**2 + img_filtered_y**2) @staticmethod def gauss_filter(sigma, epsilon=1e-2): half_size = np.ceil(sigma * np.sqrt(-2 * np.log(np.sqrt(2 * np.pi) * sigma * epsilon))) size = np.int(2 * half_size + 1) # create filter in x axis filter_x = np.zeros((size, size)) for i in range(size): for j in range(size): filter_x[i, j] = MetricGRAD.gaussian(i - half_size, sigma) * MetricGRAD.dgaussian( j - half_size, sigma) # normalize filter norm = np.sqrt((filter_x**2).sum()) filter_x = filter_x / norm filter_y = np.transpose(filter_x) return filter_x, filter_y @staticmethod def gaussian(x, sigma): return np.exp(-x**2 / (2 * sigma**2)) / (sigma * np.sqrt(2 * np.pi)) @staticmethod def dgaussian(x, sigma): return -x * MetricGRAD.gaussian(x, sigma) / sigma**2 class MetricCONN: def __call__(self, pred, true): step=0.1 thresh_steps = np.arange(0, 1 + step, step) round_down_map = -np.ones_like(true) for i in range(1, len(thresh_steps)): true_thresh = true >= thresh_steps[i] pred_thresh = pred >= thresh_steps[i] intersection = (true_thresh & pred_thresh).astype(np.uint8) # connected components _, output, stats, _ = cv2.connectedComponentsWithStats( intersection, connectivity=4) # start from 1 in dim 0 to exclude background size = stats[1:, -1] # largest connected component of the intersection omega = np.zeros_like(true) if len(size) != 0: max_id = np.argmax(size) # plus one to include background omega[output == max_id + 1] = 1 mask = (round_down_map == -1) & (omega == 0) round_down_map[mask] = thresh_steps[i - 1] round_down_map[round_down_map == -1] = 1 true_diff = true - round_down_map pred_diff = pred - round_down_map # only calculate difference larger than or equal to 0.15 true_phi = 1 - true_diff * (true_diff >= 0.15) pred_phi = 1 - pred_diff * (pred_diff >= 0.15) connectivity_error = 
np.sum(np.abs(true_phi - pred_phi)) return connectivity_error / 1000 class MetricDTSSD: def __call__(self, pred_t, pred_tm1, true_t, true_tm1): dtSSD = ((pred_t - pred_tm1) - (true_t - true_tm1)) ** 2 dtSSD = np.sum(dtSSD) / true_t.size dtSSD = np.sqrt(dtSSD) return dtSSD * 1e2 if __name__ == '__main__': Evaluator() ================================================ FILE: evaluation/generate_imagematte_with_background_image.py ================================================ """ python generate_imagematte_with_background_image.py \ --imagematte-dir ../matting-data/Distinctions/test \ --background-dir ../matting-data/Backgrounds/valid \ --resolution 512 \ --out-dir ../matting-data/evaluation/distinction_static_sd/ \ --random-seed 10 Seed: 10 - distinction-static 11 - distinction-motion 12 - adobe-static 13 - adobe-motion """ import argparse import os import pims import numpy as np import random from PIL import Image from tqdm import tqdm from tqdm.contrib.concurrent import process_map from torchvision import transforms from torchvision.transforms import functional as F parser = argparse.ArgumentParser() parser.add_argument('--imagematte-dir', type=str, required=True) parser.add_argument('--background-dir', type=str, required=True) parser.add_argument('--num-samples', type=int, default=20) parser.add_argument('--num-frames', type=int, default=100) parser.add_argument('--resolution', type=int, required=True) parser.add_argument('--out-dir', type=str, required=True) parser.add_argument('--random-seed', type=int) parser.add_argument('--extension', type=str, default='.png') args = parser.parse_args() random.seed(args.random_seed) imagematte_filenames = os.listdir(os.path.join(args.imagematte_dir, 'fgr')) background_filenames = os.listdir(args.background_dir) random.shuffle(imagematte_filenames) random.shuffle(background_filenames) def lerp(a, b, percentage): return a * (1 - percentage) + b * percentage def motion_affine(*imgs): config = dict(degrees=(-10, 10), 
translate=(0.1, 0.1), scale_ranges=(0.9, 1.1), shears=(-5, 5), img_size=imgs[0][0].size) angleA, (transXA, transYA), scaleA, (shearXA, shearYA) = transforms.RandomAffine.get_params(**config) angleB, (transXB, transYB), scaleB, (shearXB, shearYB) = transforms.RandomAffine.get_params(**config) T = len(imgs[0]) variation_over_time = random.random() for t in range(T): percentage = (t / (T - 1)) * variation_over_time angle = lerp(angleA, angleB, percentage) transX = lerp(transXA, transXB, percentage) transY = lerp(transYA, transYB, percentage) scale = lerp(scaleA, scaleB, percentage) shearX = lerp(shearXA, shearXB, percentage) shearY = lerp(shearYA, shearYB, percentage) for img in imgs: img[t] = F.affine(img[t], angle, (transX, transY), scale, (shearX, shearY), F.InterpolationMode.BILINEAR) return imgs def process(i): imagematte_filename = imagematte_filenames[i % len(imagematte_filenames)] background_filename = background_filenames[i % len(background_filenames)] out_path = os.path.join(args.out_dir, str(i).zfill(4)) os.makedirs(os.path.join(out_path, 'fgr'), exist_ok=True) os.makedirs(os.path.join(out_path, 'pha'), exist_ok=True) os.makedirs(os.path.join(out_path, 'com'), exist_ok=True) os.makedirs(os.path.join(out_path, 'bgr'), exist_ok=True) with Image.open(os.path.join(args.background_dir, background_filename)) as bgr: bgr = bgr.convert('RGB') w, h = bgr.size scale = args.resolution / min(h, w) w, h = int(w * scale), int(h * scale) bgr = bgr.resize((w, h)) bgr = F.center_crop(bgr, (args.resolution, args.resolution)) with Image.open(os.path.join(args.imagematte_dir, 'fgr', imagematte_filename)) as fgr, \ Image.open(os.path.join(args.imagematte_dir, 'pha', imagematte_filename)) as pha: fgr = fgr.convert('RGB') pha = pha.convert('L') fgrs = [fgr] * args.num_frames phas = [pha] * args.num_frames fgrs, phas = motion_affine(fgrs, phas) for t in tqdm(range(args.num_frames), desc=str(i).zfill(4)): fgr = fgrs[t] pha = phas[t] w, h = fgr.size scale = args.resolution / max(h, 
w) w, h = int(w * scale), int(h * scale) fgr = fgr.resize((w, h)) pha = pha.resize((w, h)) if h < args.resolution: pt = (args.resolution - h) // 2 pb = args.resolution - h - pt else: pt = 0 pb = 0 if w < args.resolution: pl = (args.resolution - w) // 2 pr = args.resolution - w - pl else: pl = 0 pr = 0 fgr = F.pad(fgr, [pl, pt, pr, pb]) pha = F.pad(pha, [pl, pt, pr, pb]) if i // len(imagematte_filenames) % 2 == 1: fgr = fgr.transpose(Image.FLIP_LEFT_RIGHT) pha = pha.transpose(Image.FLIP_LEFT_RIGHT) fgr.save(os.path.join(out_path, 'fgr', str(t).zfill(4) + args.extension)) pha.save(os.path.join(out_path, 'pha', str(t).zfill(4) + args.extension)) if t == 0: bgr.save(os.path.join(out_path, 'bgr', str(t).zfill(4) + args.extension)) else: os.symlink(str(0).zfill(4) + args.extension, os.path.join(out_path, 'bgr', str(t).zfill(4) + args.extension)) pha = np.asarray(pha).astype(float)[:, :, None] / 255 com = Image.fromarray(np.uint8(np.asarray(fgr) * pha + np.asarray(bgr) * (1 - pha))) com.save(os.path.join(out_path, 'com', str(t).zfill(4) + args.extension)) if __name__ == '__main__': r = process_map(process, range(args.num_samples), max_workers=32) ================================================ FILE: evaluation/generate_imagematte_with_background_video.py ================================================ """ python generate_imagematte_with_background_video.py \ --imagematte-dir ../matting-data/Distinctions/test \ --background-dir ../matting-data/BackgroundVideos_mp4/test \ --resolution 512 \ --out-dir ../matting-data/evaluation/distinction_motion_sd/ \ --random-seed 11 Seed: 10 - distinction-static 11 - distinction-motion 12 - adobe-static 13 - adobe-motion """ import argparse import os import pims import numpy as np import random from multiprocessing import Pool from PIL import Image # from tqdm import tqdm from tqdm.contrib.concurrent import process_map from torchvision import transforms from torchvision.transforms import functional as F parser = 
argparse.ArgumentParser() parser.add_argument('--imagematte-dir', type=str, required=True) parser.add_argument('--background-dir', type=str, required=True) parser.add_argument('--num-samples', type=int, default=20) parser.add_argument('--num-frames', type=int, default=100) parser.add_argument('--resolution', type=int, required=True) parser.add_argument('--out-dir', type=str, required=True) parser.add_argument('--random-seed', type=int) parser.add_argument('--extension', type=str, default='.png') args = parser.parse_args() random.seed(args.random_seed) imagematte_filenames = os.listdir(os.path.join(args.imagematte_dir, 'fgr')) random.shuffle(imagematte_filenames) background_filenames = [ "0000.mp4", "0007.mp4", "0008.mp4", "0010.mp4", "0013.mp4", "0015.mp4", "0016.mp4", "0018.mp4", "0021.mp4", "0029.mp4", "0033.mp4", "0035.mp4", "0039.mp4", "0050.mp4", "0052.mp4", "0055.mp4", "0060.mp4", "0063.mp4", "0087.mp4", "0086.mp4", "0090.mp4", "0101.mp4", "0110.mp4", "0117.mp4", "0120.mp4", "0122.mp4", "0123.mp4", "0125.mp4", "0128.mp4", "0131.mp4", "0172.mp4", "0176.mp4", "0181.mp4", "0187.mp4", "0193.mp4", "0198.mp4", "0220.mp4", "0221.mp4", "0224.mp4", "0229.mp4", "0233.mp4", "0238.mp4", "0241.mp4", "0245.mp4", "0246.mp4" ] random.shuffle(background_filenames) def lerp(a, b, percentage): return a * (1 - percentage) + b * percentage def motion_affine(*imgs): config = dict(degrees=(-10, 10), translate=(0.1, 0.1), scale_ranges=(0.9, 1.1), shears=(-5, 5), img_size=imgs[0][0].size) angleA, (transXA, transYA), scaleA, (shearXA, shearYA) = transforms.RandomAffine.get_params(**config) angleB, (transXB, transYB), scaleB, (shearXB, shearYB) = transforms.RandomAffine.get_params(**config) T = len(imgs[0]) variation_over_time = random.random() for t in range(T): percentage = (t / (T - 1)) * variation_over_time angle = lerp(angleA, angleB, percentage) transX = lerp(transXA, transXB, percentage) transY = lerp(transYA, transYB, percentage) scale = lerp(scaleA, scaleB, percentage) shearX 
= lerp(shearXA, shearXB, percentage) shearY = lerp(shearYA, shearYB, percentage) for img in imgs: img[t] = F.affine(img[t], angle, (transX, transY), scale, (shearX, shearY), F.InterpolationMode.BILINEAR) return imgs def process(i): imagematte_filename = imagematte_filenames[i % len(imagematte_filenames)] background_filename = background_filenames[i % len(background_filenames)] bgrs = pims.PyAVVideoReader(os.path.join(args.background_dir, background_filename)) out_path = os.path.join(args.out_dir, str(i).zfill(4)) os.makedirs(os.path.join(out_path, 'fgr'), exist_ok=True) os.makedirs(os.path.join(out_path, 'pha'), exist_ok=True) os.makedirs(os.path.join(out_path, 'com'), exist_ok=True) os.makedirs(os.path.join(out_path, 'bgr'), exist_ok=True) with Image.open(os.path.join(args.imagematte_dir, 'fgr', imagematte_filename)) as fgr, \ Image.open(os.path.join(args.imagematte_dir, 'pha', imagematte_filename)) as pha: fgr = fgr.convert('RGB') pha = pha.convert('L') fgrs = [fgr] * args.num_frames phas = [pha] * args.num_frames fgrs, phas = motion_affine(fgrs, phas) for t in range(args.num_frames): fgr = fgrs[t] pha = phas[t] w, h = fgr.size scale = args.resolution / max(h, w) w, h = int(w * scale), int(h * scale) fgr = fgr.resize((w, h)) pha = pha.resize((w, h)) if h < args.resolution: pt = (args.resolution - h) // 2 pb = args.resolution - h - pt else: pt = 0 pb = 0 if w < args.resolution: pl = (args.resolution - w) // 2 pr = args.resolution - w - pl else: pl = 0 pr = 0 fgr = F.pad(fgr, [pl, pt, pr, pb]) pha = F.pad(pha, [pl, pt, pr, pb]) if i // len(imagematte_filenames) % 2 == 1: fgr = fgr.transpose(Image.FLIP_LEFT_RIGHT) pha = pha.transpose(Image.FLIP_LEFT_RIGHT) fgr.save(os.path.join(out_path, 'fgr', str(t).zfill(4) + args.extension)) pha.save(os.path.join(out_path, 'pha', str(t).zfill(4) + args.extension)) bgr = Image.fromarray(bgrs[t]).convert('RGB') w, h = bgr.size scale = args.resolution / min(h, w) w, h = int(w * scale), int(h * scale) bgr = bgr.resize((w, h)) bgr = 
F.center_crop(bgr, (args.resolution, args.resolution)) bgr.save(os.path.join(out_path, 'bgr', str(t).zfill(4) + args.extension)) pha = np.asarray(pha).astype(float)[:, :, None] / 255 com = Image.fromarray(np.uint8(np.asarray(fgr) * pha + np.asarray(bgr) * (1 - pha))) com.save(os.path.join(out_path, 'com', str(t).zfill(4) + args.extension)) if __name__ == '__main__': r = process_map(process, range(args.num_samples), max_workers=10) ================================================ FILE: evaluation/generate_videomatte_with_background_image.py ================================================ """ python generate_videomatte_with_background_image.py \ --videomatte-dir ../matting-data/VideoMatte240K_JPEG_HD/test \ --background-dir ../matting-data/Backgrounds/valid \ --num-samples 25 \ --resize 512 288 \ --out-dir ../matting-data/evaluation/vidematte_static_sd/ """ import argparse import os import pims import numpy as np import random from PIL import Image from tqdm import tqdm parser = argparse.ArgumentParser() parser.add_argument('--videomatte-dir', type=str, required=True) parser.add_argument('--background-dir', type=str, required=True) parser.add_argument('--num-samples', type=int, default=20) parser.add_argument('--num-frames', type=int, default=100) parser.add_argument('--resize', type=int, default=None, nargs=2) parser.add_argument('--out-dir', type=str, required=True) parser.add_argument('--extension', type=str, default='.png') args = parser.parse_args() random.seed(10) videomatte_filenames = [(clipname, sorted(os.listdir(os.path.join(args.videomatte_dir, 'fgr', clipname)))) for clipname in sorted(os.listdir(os.path.join(args.videomatte_dir, 'fgr')))] background_filenames = os.listdir(args.background_dir) random.shuffle(background_filenames) for i in range(args.num_samples): clipname, framenames = videomatte_filenames[i % len(videomatte_filenames)] out_path = os.path.join(args.out_dir, str(i).zfill(4)) os.makedirs(os.path.join(out_path, 'fgr'), exist_ok=True) 
os.makedirs(os.path.join(out_path, 'pha'), exist_ok=True) os.makedirs(os.path.join(out_path, 'com'), exist_ok=True) os.makedirs(os.path.join(out_path, 'bgr'), exist_ok=True) with Image.open(os.path.join(args.background_dir, background_filenames[i])) as bgr: bgr = bgr.convert('RGB') base_t = random.choice(range(len(framenames) - args.num_frames)) for t in tqdm(range(args.num_frames), desc=str(i).zfill(4)): with Image.open(os.path.join(args.videomatte_dir, 'fgr', clipname, framenames[base_t + t])) as fgr, \ Image.open(os.path.join(args.videomatte_dir, 'pha', clipname, framenames[base_t + t])) as pha: fgr = fgr.convert('RGB') pha = pha.convert('L') if args.resize is not None: fgr = fgr.resize(args.resize, Image.BILINEAR) pha = pha.resize(args.resize, Image.BILINEAR) if i // len(videomatte_filenames) % 2 == 1: fgr = fgr.transpose(Image.FLIP_LEFT_RIGHT) pha = pha.transpose(Image.FLIP_LEFT_RIGHT) fgr.save(os.path.join(out_path, 'fgr', str(t).zfill(4) + args.extension)) pha.save(os.path.join(out_path, 'pha', str(t).zfill(4) + args.extension)) if t == 0: bgr = bgr.resize(fgr.size, Image.BILINEAR) bgr.save(os.path.join(out_path, 'bgr', str(t).zfill(4) + args.extension)) else: os.symlink(str(0).zfill(4) + args.extension, os.path.join(out_path, 'bgr', str(t).zfill(4) + args.extension)) pha = np.asarray(pha).astype(float)[:, :, None] / 255 com = Image.fromarray(np.uint8(np.asarray(fgr) * pha + np.asarray(bgr) * (1 - pha))) com.save(os.path.join(out_path, 'com', str(t).zfill(4) + args.extension)) ================================================ FILE: evaluation/generate_videomatte_with_background_video.py ================================================ """ python generate_videomatte_with_background_video.py \ --videomatte-dir ../matting-data/VideoMatte240K_JPEG_HD/test \ --background-dir ../matting-data/BackgroundVideos_mp4/test \ --resize 512 288 \ --out-dir ../matting-data/evaluation/vidematte_motion_sd/ """ import argparse import os import pims import numpy as np import 
random from PIL import Image from tqdm import tqdm parser = argparse.ArgumentParser() parser.add_argument('--videomatte-dir', type=str, required=True) parser.add_argument('--background-dir', type=str, required=True) parser.add_argument('--num-samples', type=int, default=20) parser.add_argument('--num-frames', type=int, default=100) parser.add_argument('--resize', type=int, default=None, nargs=2) parser.add_argument('--out-dir', type=str, required=True) args = parser.parse_args() # Hand selected a list of videos background_filenames = [ "0000.mp4", "0007.mp4", "0008.mp4", "0010.mp4", "0013.mp4", "0015.mp4", "0016.mp4", "0018.mp4", "0021.mp4", "0029.mp4", "0033.mp4", "0035.mp4", "0039.mp4", "0050.mp4", "0052.mp4", "0055.mp4", "0060.mp4", "0063.mp4", "0087.mp4", "0086.mp4", "0090.mp4", "0101.mp4", "0110.mp4", "0117.mp4", "0120.mp4", "0122.mp4", "0123.mp4", "0125.mp4", "0128.mp4", "0131.mp4", "0172.mp4", "0176.mp4", "0181.mp4", "0187.mp4", "0193.mp4", "0198.mp4", "0220.mp4", "0221.mp4", "0224.mp4", "0229.mp4", "0233.mp4", "0238.mp4", "0241.mp4", "0245.mp4", "0246.mp4" ] random.seed(10) videomatte_filenames = [(clipname, sorted(os.listdir(os.path.join(args.videomatte_dir, 'fgr', clipname)))) for clipname in sorted(os.listdir(os.path.join(args.videomatte_dir, 'fgr')))] random.shuffle(background_filenames) for i in range(args.num_samples): bgrs = pims.PyAVVideoReader(os.path.join(args.background_dir, background_filenames[i % len(background_filenames)])) clipname, framenames = videomatte_filenames[i % len(videomatte_filenames)] out_path = os.path.join(args.out_dir, str(i).zfill(4)) os.makedirs(os.path.join(out_path, 'fgr'), exist_ok=True) os.makedirs(os.path.join(out_path, 'pha'), exist_ok=True) os.makedirs(os.path.join(out_path, 'com'), exist_ok=True) os.makedirs(os.path.join(out_path, 'bgr'), exist_ok=True) base_t = random.choice(range(len(framenames) - args.num_frames)) for t in tqdm(range(args.num_frames), desc=str(i).zfill(4)): with 
Image.open(os.path.join(args.videomatte_dir, 'fgr', clipname, framenames[base_t + t])) as fgr, \ Image.open(os.path.join(args.videomatte_dir, 'pha', clipname, framenames[base_t + t])) as pha: fgr = fgr.convert('RGB') pha = pha.convert('L') if args.resize is not None: fgr = fgr.resize(args.resize, Image.BILINEAR) pha = pha.resize(args.resize, Image.BILINEAR) if i // len(videomatte_filenames) % 2 == 1: fgr = fgr.transpose(Image.FLIP_LEFT_RIGHT) pha = pha.transpose(Image.FLIP_LEFT_RIGHT) fgr.save(os.path.join(out_path, 'fgr', str(t).zfill(4) + '.png')) pha.save(os.path.join(out_path, 'pha', str(t).zfill(4) + '.png')) bgr = Image.fromarray(bgrs[t]) bgr = bgr.resize(fgr.size, Image.BILINEAR) bgr.save(os.path.join(out_path, 'bgr', str(t).zfill(4) + '.png')) pha = np.asarray(pha).astype(float)[:, :, None] / 255 com = Image.fromarray(np.uint8(np.asarray(fgr) * pha + np.asarray(bgr) * (1 - pha))) com.save(os.path.join(out_path, 'com', str(t).zfill(4) + '.png')) ================================================ FILE: hubconf.py ================================================ """ Loading model model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3") model = torch.hub.load("PeterL1n/RobustVideoMatting", "resnet50") Converter API convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter") """ dependencies = ['torch', 'torchvision'] import torch from model import MattingNetwork def mobilenetv3(pretrained: bool = True, progress: bool = True): model = MattingNetwork('mobilenetv3') if pretrained: url = 'https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_mobilenetv3.pth' model.load_state_dict(torch.hub.load_state_dict_from_url(url, map_location='cpu', progress=progress)) return model def resnet50(pretrained: bool = True, progress: bool = True): model = MattingNetwork('resnet50') if pretrained: url = 'https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_resnet50.pth' 
model.load_state_dict(torch.hub.load_state_dict_from_url(url, map_location='cpu', progress=progress)) return model def converter(): try: from inference import convert_video return convert_video except ModuleNotFoundError as error: print(error) print('Please run "pip install av tqdm pims"') ================================================ FILE: inference.py ================================================ """ python inference.py \ --variant mobilenetv3 \ --checkpoint "CHECKPOINT" \ --device cuda \ --input-source "input.mp4" \ --output-type video \ --output-composition "composition.mp4" \ --output-alpha "alpha.mp4" \ --output-foreground "foreground.mp4" \ --output-video-mbps 4 \ --seq-chunk 1 """ import torch import os from torch.utils.data import DataLoader from torchvision import transforms from typing import Optional, Tuple from tqdm.auto import tqdm from inference_utils import VideoReader, VideoWriter, ImageSequenceReader, ImageSequenceWriter def convert_video(model, input_source: str, input_resize: Optional[Tuple[int, int]] = None, downsample_ratio: Optional[float] = None, output_type: str = 'video', output_composition: Optional[str] = None, output_alpha: Optional[str] = None, output_foreground: Optional[str] = None, output_video_mbps: Optional[float] = None, seq_chunk: int = 1, num_workers: int = 0, progress: bool = True, device: Optional[str] = None, dtype: Optional[torch.dtype] = None): """ Args: input_source:A video file, or an image sequence directory. Images must be sorted in accending order, support png and jpg. input_resize: If provided, the input are first resized to (w, h). downsample_ratio: The model's downsample_ratio hyperparameter. If not provided, model automatically set one. output_type: Options: ["video", "png_sequence"]. output_composition: The composition output path. File path if output_type == 'video'. Directory path if output_type == 'png_sequence'. If output_type == 'video', the composition has green screen background. 
If output_type == 'png_sequence'. the composition is RGBA png images. output_alpha: The alpha output from the model. output_foreground: The foreground output from the model. seq_chunk: Number of frames to process at once. Increase it for better parallelism. num_workers: PyTorch's DataLoader workers. Only use >0 for image input. progress: Show progress bar. device: Only need to manually provide if model is a TorchScript freezed model. dtype: Only need to manually provide if model is a TorchScript freezed model. """ assert downsample_ratio is None or (downsample_ratio > 0 and downsample_ratio <= 1), 'Downsample ratio must be between 0 (exclusive) and 1 (inclusive).' assert any([output_composition, output_alpha, output_foreground]), 'Must provide at least one output.' assert output_type in ['video', 'png_sequence'], 'Only support "video" and "png_sequence" output modes.' assert seq_chunk >= 1, 'Sequence chunk must be >= 1' assert num_workers >= 0, 'Number of workers must be >= 0' # Initialize transform if input_resize is not None: transform = transforms.Compose([ transforms.Resize(input_resize[::-1]), transforms.ToTensor() ]) else: transform = transforms.ToTensor() # Initialize reader if os.path.isfile(input_source): source = VideoReader(input_source, transform) else: source = ImageSequenceReader(input_source, transform) reader = DataLoader(source, batch_size=seq_chunk, pin_memory=True, num_workers=num_workers) # Initialize writers if output_type == 'video': frame_rate = source.frame_rate if isinstance(source, VideoReader) else 30 output_video_mbps = 1 if output_video_mbps is None else output_video_mbps if output_composition is not None: writer_com = VideoWriter( path=output_composition, frame_rate=frame_rate, bit_rate=int(output_video_mbps * 1000000)) if output_alpha is not None: writer_pha = VideoWriter( path=output_alpha, frame_rate=frame_rate, bit_rate=int(output_video_mbps * 1000000)) if output_foreground is not None: writer_fgr = VideoWriter( 
path=output_foreground, frame_rate=frame_rate, bit_rate=int(output_video_mbps * 1000000)) else: if output_composition is not None: writer_com = ImageSequenceWriter(output_composition, 'png') if output_alpha is not None: writer_pha = ImageSequenceWriter(output_alpha, 'png') if output_foreground is not None: writer_fgr = ImageSequenceWriter(output_foreground, 'png') # Inference model = model.eval() if device is None or dtype is None: param = next(model.parameters()) dtype = param.dtype device = param.device if (output_composition is not None) and (output_type == 'video'): bgr = torch.tensor([120, 255, 155], device=device, dtype=dtype).div(255).view(1, 1, 3, 1, 1) try: with torch.no_grad(): bar = tqdm(total=len(source), disable=not progress, dynamic_ncols=True) rec = [None] * 4 for src in reader: if downsample_ratio is None: downsample_ratio = auto_downsample_ratio(*src.shape[2:]) src = src.to(device, dtype, non_blocking=True).unsqueeze(0) # [B, T, C, H, W] fgr, pha, *rec = model(src, *rec, downsample_ratio) if output_foreground is not None: writer_fgr.write(fgr[0]) if output_alpha is not None: writer_pha.write(pha[0]) if output_composition is not None: if output_type == 'video': com = fgr * pha + bgr * (1 - pha) else: fgr = fgr * pha.gt(0) com = torch.cat([fgr, pha], dim=-3) writer_com.write(com[0]) bar.update(src.size(1)) finally: # Clean up if output_composition is not None: writer_com.close() if output_alpha is not None: writer_pha.close() if output_foreground is not None: writer_fgr.close() def auto_downsample_ratio(h, w): """ Automatically find a downsample ratio so that the largest side of the resolution be 512px. 
""" return min(512 / max(h, w), 1) class Converter: def __init__(self, variant: str, checkpoint: str, device: str): self.model = MattingNetwork(variant).eval().to(device) self.model.load_state_dict(torch.load(checkpoint, map_location=device)) self.model = torch.jit.script(self.model) self.model = torch.jit.freeze(self.model) self.device = device def convert(self, *args, **kwargs): convert_video(self.model, device=self.device, dtype=torch.float32, *args, **kwargs) if __name__ == '__main__': import argparse from model import MattingNetwork parser = argparse.ArgumentParser() parser.add_argument('--variant', type=str, required=True, choices=['mobilenetv3', 'resnet50']) parser.add_argument('--checkpoint', type=str, required=True) parser.add_argument('--device', type=str, required=True) parser.add_argument('--input-source', type=str, required=True) parser.add_argument('--input-resize', type=int, default=None, nargs=2) parser.add_argument('--downsample-ratio', type=float) parser.add_argument('--output-composition', type=str) parser.add_argument('--output-alpha', type=str) parser.add_argument('--output-foreground', type=str) parser.add_argument('--output-type', type=str, required=True, choices=['video', 'png_sequence']) parser.add_argument('--output-video-mbps', type=int, default=1) parser.add_argument('--seq-chunk', type=int, default=1) parser.add_argument('--num-workers', type=int, default=0) parser.add_argument('--disable-progress', action='store_true') args = parser.parse_args() converter = Converter(args.variant, args.checkpoint, args.device) converter.convert( input_source=args.input_source, input_resize=args.input_resize, downsample_ratio=args.downsample_ratio, output_type=args.output_type, output_composition=args.output_composition, output_alpha=args.output_alpha, output_foreground=args.output_foreground, output_video_mbps=args.output_video_mbps, seq_chunk=args.seq_chunk, num_workers=args.num_workers, progress=not args.disable_progress ) 
================================================
FILE: inference_speed_test.py
================================================
"""
python inference_speed_test.py \
    --model-variant mobilenetv3 \
    --resolution 1920 1080 \
    --downsample-ratio 0.25 \
    --precision float32
"""

import argparse
import torch
from tqdm import tqdm

from model.model import MattingNetwork

# Let cuDNN pick the fastest convolution algorithms; the input size is fixed for the whole run.
torch.backends.cudnn.benchmark = True


class InferenceSpeedTest:
    """Benchmark a scripted, frozen MattingNetwork on CUDA with a random input for 1000 iterations."""

    def __init__(self):
        self.parse_args()
        self.init_model()
        self.loop()

    def parse_args(self):
        parser = argparse.ArgumentParser()
        parser.add_argument('--model-variant', type=str, required=True)
        # Given as W H.
        parser.add_argument('--resolution', type=int, required=True, nargs=2)
        parser.add_argument('--downsample-ratio', type=float, required=True)
        parser.add_argument('--precision', type=str, default='float32')
        # NOTE(review): parsed but not read anywhere in this class.
        parser.add_argument('--disable-refiner', action='store_true')
        self.args = parser.parse_args()

    def init_model(self):
        self.device = 'cuda'
        self.precision = {'float32': torch.float32, 'float16': torch.float16}[self.args.precision]
        self.model = MattingNetwork(self.args.model_variant)
        self.model = self.model.to(device=self.device, dtype=self.precision).eval()
        self.model = torch.jit.script(self.model)
        self.model = torch.jit.freeze(self.model)

    def loop(self):
        w, h = self.args.resolution
        # Single-frame 4D input [1, 3, H, W], so the modules take their single-frame paths.
        src = torch.randn((1, 3, h, w), device=self.device, dtype=self.precision)
        with torch.no_grad():
            # Recurrent states r1..r4 start empty and are fed back every iteration.
            rec = None, None, None, None
            for _ in tqdm(range(1000)):
                fgr, pha, *rec = self.model(src, *rec, self.args.downsample_ratio)
                # Wait for the GPU so the progress bar reflects real per-frame latency.
                torch.cuda.synchronize()

if __name__ == '__main__':
    InferenceSpeedTest()

================================================
FILE: inference_utils.py
================================================
import av
import os
import pims
import numpy as np
from torch.utils.data import Dataset
from torchvision.transforms.functional import to_pil_image
from PIL import Image


class VideoReader(Dataset):
    """Dataset over the frames of a video file (via pims.PyAVVideoReader).

    Items are PIL images, or `transform(image)` when a transform is given.
    """

    def __init__(self, path, transform=None):
        self.video = pims.PyAVVideoReader(path)
        self.rate = self.video.frame_rate
        self.transform = transform

    @property
    def frame_rate(self):
        return self.rate

    def __len__(self):
        return len(self.video)

    def __getitem__(self, idx):
        frame = self.video[idx]
        frame = Image.fromarray(np.asarray(frame))
        if self.transform is not None:
            frame = self.transform(frame)
        return frame


class VideoWriter:
    """Encode frame tensors into an H.264 (yuv420p) video file with PyAV."""

    def __init__(self, path, frame_rate, bit_rate=1000000):
        self.container = av.open(path, mode='w')
        self.stream = self.container.add_stream('h264', rate=f'{frame_rate:.4f}')
        self.stream.pix_fmt = 'yuv420p'
        self.stream.bit_rate = bit_rate

    def write(self, frames):
        # frames: [T, C, H, W]
        # NOTE(review): values are assumed to be in [0, 1]; mul(255).byte() does not clamp.
        # The stream size is taken from the frames on every call.
        self.stream.width = frames.size(3)
        self.stream.height = frames.size(2)
        if frames.size(1) == 1:
            frames = frames.repeat(1, 3, 1, 1) # convert grayscale to RGB
        frames = frames.mul(255).byte().cpu().permute(0, 2, 3, 1).numpy()
        for t in range(frames.shape[0]):
            frame = frames[t]
            frame = av.VideoFrame.from_ndarray(frame, format='rgb24')
            self.container.mux(self.stream.encode(frame))

    def close(self):
        # encode() with no frame flushes the packets still buffered in the encoder.
        self.container.mux(self.stream.encode())
        self.container.close()


class ImageSequenceReader(Dataset):
    """Dataset over the image files of a directory, in sorted file-name order.

    Every directory entry is opened as an image (there is no extension filtering).
    """

    def __init__(self, path, transform=None):
        self.path = path
        self.files = sorted(os.listdir(path))
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        with Image.open(os.path.join(self.path, self.files[idx])) as img:
            # Force the pixel data to load before the file handle is released.
            img.load()
        if self.transform is not None:
            return self.transform(img)
        return img


class ImageSequenceWriter:
    """Write frames as sequentially numbered images: 0000.<ext>, 0001.<ext>, ..."""

    def __init__(self, path, extension='jpg'):
        self.path = path
        self.extension = extension
        self.counter = 0
        os.makedirs(path, exist_ok=True)

    def write(self, frames):
        # frames: [T, C, H, W]
        # The counter persists across calls, so numbering continues over chunks.
        for t in range(frames.shape[0]):
            to_pil_image(frames[t]).save(os.path.join(
                self.path, str(self.counter).zfill(4) + '.' + self.extension))
            self.counter += 1

    def close(self):
        # Nothing to flush; present for interface parity with VideoWriter.
        pass

================================================
FILE: model/__init__.py
================================================
from .model import MattingNetwork

================================================
FILE: model/decoder.py
================================================
import torch
from torch import Tensor
from torch import nn
from torch.nn import functional as F
from typing import Tuple, Optional

class RecurrentDecoder(nn.Module):
    """Recurrent decoder over encoder features f1..f4 and the source image.

    feature_channels: channel counts of f1, f2, f3, f4.
    decoder_channels: output channels of decode3, decode2, decode1 and decode0.
    All blocks accept single-frame [B, C, H, W] or time-series [B, T, C, H, W] inputs.
    """

    def __init__(self, feature_channels, decoder_channels):
        super().__init__()
        self.avgpool = AvgPool()
        self.decode4 = BottleneckBlock(feature_channels[3])
        self.decode3 = UpsamplingBlock(feature_channels[3], feature_channels[2], 3, decoder_channels[0])
        self.decode2 = UpsamplingBlock(decoder_channels[0], feature_channels[1], 3, decoder_channels[1])
        self.decode1 = UpsamplingBlock(decoder_channels[1], feature_channels[0], 3, decoder_channels[2])
        self.decode0 = OutputBlock(decoder_channels[2], 3, decoder_channels[3])

    def forward(self,
                s0: Tensor, f1: Tensor, f2: Tensor, f3: Tensor, f4: Tensor,
                r1: Optional[Tensor], r2: Optional[Tensor],
                r3: Optional[Tensor], r4: Optional[Tensor]):
        """Decode from coarse to fine and return (x0, r1, r2, r3, r4) with updated recurrent states."""
        # s1, s2, s3: the source pooled to 1/2, 1/4 and 1/8 of s0's resolution.
        s1, s2, s3 = self.avgpool(s0)
        x4, r4 = self.decode4(f4, r4)
        x3, r3 = self.decode3(x4, f3, s3, r3)
        x2, r2 = self.decode2(x3, f2, s2, r2)
        x1, r1 = self.decode1(x2, f1, s1, r1)
        x0 = self.decode0(x1, s0)
        return x0, r1, r2, r3, r4


class AvgPool(nn.Module):
    """Produce 1/2, 1/4 and 1/8 resolution copies of the input by repeated 2x2 average pooling."""

    def __init__(self):
        super().__init__()
        # ceil_mode=True keeps a trailing row/column when a size is odd.
        self.avgpool = nn.AvgPool2d(2, 2, count_include_pad=False, ceil_mode=True)

    def forward_single_frame(self, s0):
        s1 = self.avgpool(s0)
        s2 = self.avgpool(s1)
        s3 = self.avgpool(s2)
        return s1, s2, s3

    def forward_time_series(self, s0):
        # Fold time into the batch dimension, pool, then unfold.
        B, T = s0.shape[:2]
        s0 = s0.flatten(0, 1)
        s1, s2, s3 = self.forward_single_frame(s0)
        s1 = s1.unflatten(0, (B, T))
        s2 = s2.unflatten(0, (B, T))
        s3 = s3.unflatten(0, (B, T))
        return s1, s2, s3

    def forward(self, s0):
        if s0.ndim == 5:
            return self.forward_time_series(s0)
        else:
            return self.forward_single_frame(s0)


class BottleneckBlock(nn.Module):
    """Run a ConvGRU on the second half of the channels; the first half passes through unchanged."""

    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.gru = ConvGRU(channels // 2)

    def forward(self, x, r: Optional[Tensor]):
        # dim=-3 is the channel axis for both [B, C, H, W] and [B, T, C, H, W].
        a, b = x.split(self.channels // 2, dim=-3)
        b, r = self.gru(b, r)
        x = torch.cat([a, b], dim=-3)
        return x, r


class UpsamplingBlock(nn.Module):
    """Upsample x by 2 and fuse with skip features and the source, then apply a half-channel ConvGRU."""

    def __init__(self, in_channels, skip_channels, src_channels, out_channels):
        super().__init__()
        self.out_channels = out_channels
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels + skip_channels + src_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
        )
        self.gru = ConvGRU(out_channels // 2)

    def forward_single_frame(self, x, f, s, r: Optional[Tensor]):
        x = self.upsample(x)
        # Crop to the source size; the 2x upsample can overshoot when sizes are odd.
        x = x[:, :, :s.size(2), :s.size(3)]
        x = torch.cat([x, f, s], dim=1)
        x = self.conv(x)
        a, b = x.split(self.out_channels // 2, dim=1)
        b, r = self.gru(b, r)
        x = torch.cat([a, b], dim=1)
        return x, r

    def forward_time_series(self, x, f, s, r: Optional[Tensor]):
        B, T, _, H, W = s.shape
        # Convolutions run per frame with time folded into batch; the GRU then runs over T.
        x = x.flatten(0, 1)
        f = f.flatten(0, 1)
        s = s.flatten(0, 1)
        x = self.upsample(x)
        x = x[:, :, :H, :W]
        x = torch.cat([x, f, s], dim=1)
        x = self.conv(x)
        x = x.unflatten(0, (B, T))
        a, b = x.split(self.out_channels // 2, dim=2)
        b, r = self.gru(b, r)
        x = torch.cat([a, b], dim=2)
        return x, r

    def forward(self, x, f, s, r: Optional[Tensor]):
        if x.ndim == 5:
            return self.forward_time_series(x, f, s, r)
        else:
            return self.forward_single_frame(x, f, s, r)


class OutputBlock(nn.Module):
    """Final 2x upsample, fuse with the source and apply two conv-BN-ReLU layers (no recurrence)."""

    def __init__(self, in_channels, src_channels, out_channels):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels + src_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
        )

    def forward_single_frame(self, x, s):
        x = self.upsample(x)
        # Crop to the source size; the 2x upsample can overshoot when sizes are odd.
        x = x[:, :, :s.size(2), :s.size(3)]
        x = torch.cat([x, s], dim=1)
        x = self.conv(x)
        return x

    def forward_time_series(self, x, s):
        B, T, _, H, W = s.shape
        x = x.flatten(0, 1)
        s = s.flatten(0, 1)
        x = self.upsample(x)
        x = x[:, :, :H, :W]
        x = torch.cat([x, s], dim=1)
        x = self.conv(x)
        x = x.unflatten(0, (B, T))
        return x

    def forward(self, x, s):
        if x.ndim == 5:
            return self.forward_time_series(x, s)
        else:
            return self.forward_single_frame(x, s)


class ConvGRU(nn.Module):
    """Convolutional GRU cell.

    Gates r, z = sigmoid(conv([x, h])); candidate c = tanh(conv([x, r * h]));
    new state h = (1 - z) * h + z * c.
    """

    def __init__(self,
                 channels: int,
                 kernel_size: int = 3,
                 padding: int = 1):
        super().__init__()
        self.channels = channels
        self.ih = nn.Sequential(
            nn.Conv2d(channels * 2, channels * 2, kernel_size, padding=padding),
            nn.Sigmoid()
        )
        self.hh = nn.Sequential(
            nn.Conv2d(channels * 2, channels, kernel_size, padding=padding),
            nn.Tanh()
        )

    def forward_single_frame(self, x, h):
        r, z = self.ih(torch.cat([x, h], dim=1)).split(self.channels, dim=1)
        c = self.hh(torch.cat([x, r * h], dim=1))
        h = (1 - z) * h + z * c
        # For a single step, the output equals the new hidden state.
        return h, h

    def forward_time_series(self, x, h):
        # Step through the time axis (dim 1) and stack the per-step outputs.
        o = []
        for xt in x.unbind(dim=1):
            ot, h = self.forward_single_frame(xt, h)
            o.append(ot)
        o = torch.stack(o, dim=1)
        return o, h

    def forward(self, x, h: Optional[Tensor]):
        # No previous state: start from zeros shaped [B, C, H, W] (no time axis).
        if h is None:
            h = torch.zeros((x.size(0), x.size(-3), x.size(-2), x.size(-1)),
                            device=x.device, dtype=x.dtype)

        if x.ndim == 5:
            return self.forward_time_series(x, h)
        else:
            return self.forward_single_frame(x, h)


class Projection(nn.Module):
    """1x1 convolution applied per frame; accepts 4D or 5D input."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 1)

    def forward_single_frame(self, x):
        return self.conv(x)

    def forward_time_series(self, x):
        B, T = x.shape[:2]
        return self.conv(x.flatten(0, 1)).unflatten(0, (B, T))

    def forward(self, x):
        if x.ndim == 5:
            return self.forward_time_series(x)
        else:
            return self.forward_single_frame(x)

================================================
FILE:
model/deep_guided_filter.py
================================================
import torch
from torch import nn
from torch.nn import functional as F

# NOTE(review): the attribution in the string below appears to have lost its source link in this copy.
""" Adopted from """

class DeepGuidedFilterRefiner(nn.Module):
    """Learned guided-filter refiner that upsamples low-resolution fgr/pha to the full-resolution source.

    Unlike a classic guided filter (A = cov_xy / (var_x + eps)), the coefficient A is
    predicted by a small 1x1 conv network from cov_xy, var_x and the decoder's hidden features.
    """

    def __init__(self, hid_channels=16):
        super().__init__()
        # Depthwise 3x3 conv over 4 channels, initialised as a mean (box) filter.
        self.box_filter = nn.Conv2d(4, 4, kernel_size=3, padding=1, bias=False, groups=4)
        self.box_filter.weight.data[...] = 1 / 9
        self.conv = nn.Sequential(
            nn.Conv2d(4 * 2 + hid_channels, hid_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(hid_channels),
            nn.ReLU(True),
            nn.Conv2d(hid_channels, hid_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(hid_channels),
            nn.ReLU(True),
            nn.Conv2d(hid_channels, 4, kernel_size=1, bias=True)
        )

    def forward_single_frame(self, fine_src, base_src, base_fgr, base_pha, base_hid):
        # Guide: RGB source plus its channel mean (4 channels, matching box_filter).
        fine_x = torch.cat([fine_src, fine_src.mean(1, keepdim=True)], dim=1)
        base_x = torch.cat([base_src, base_src.mean(1, keepdim=True)], dim=1)
        # Target to refine: 3 fgr channels + 1 alpha channel.
        base_y = torch.cat([base_fgr, base_pha], dim=1)

        # Local statistics at low resolution.
        mean_x = self.box_filter(base_x)
        mean_y = self.box_filter(base_y)
        cov_xy = self.box_filter(base_x * base_y) - mean_x * mean_y
        var_x = self.box_filter(base_x * base_x) - mean_x * mean_x

        A = self.conv(torch.cat([cov_xy, var_x, base_hid], dim=1))
        b = mean_y - A * mean_x

        # Upsample the linear coefficients and apply them to the full-resolution guide.
        H, W = fine_src.shape[2:]
        A = F.interpolate(A, (H, W), mode='bilinear', align_corners=False)
        b = F.interpolate(b, (H, W), mode='bilinear', align_corners=False)

        out = A * fine_x + b
        fgr, pha = out.split([3, 1], dim=1)
        return fgr, pha

    def forward_time_series(self, fine_src, base_src, base_fgr, base_pha, base_hid):
        # Fold time into the batch dimension, refine, then unfold.
        B, T = fine_src.shape[:2]
        fgr, pha = self.forward_single_frame(
            fine_src.flatten(0, 1),
            base_src.flatten(0, 1),
            base_fgr.flatten(0, 1),
            base_pha.flatten(0, 1),
            base_hid.flatten(0, 1))
        fgr = fgr.unflatten(0, (B, T))
        pha = pha.unflatten(0, (B, T))
        return fgr, pha

    def forward(self, fine_src, base_src, base_fgr, base_pha, base_hid):
        if fine_src.ndim == 5:
            return self.forward_time_series(fine_src, base_src, base_fgr, base_pha, base_hid)
        else:
            return self.forward_single_frame(fine_src, base_src, base_fgr, base_pha, base_hid)

================================================
FILE: model/fast_guided_filter.py
================================================
import torch
from torch import nn
from torch.nn import functional as F

# NOTE(review): the attribution in the string below appears to have lost its source link in this copy.
""" Adopted from """

class FastGuidedFilterRefiner(nn.Module):
    """Non-learned refiner: a fast guided filter (radius 1) guided by the full-resolution source.

    Extra constructor arguments and the `base_hid` forward argument are accepted and
    ignored, so it can be swapped in for DeepGuidedFilterRefiner.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        # Attribute name spelling ("guilded") kept as-is.
        self.guilded_filter = FastGuidedFilter(1)

    def forward_single_frame(self, fine_src, base_src, base_fgr, base_pha):
        fine_src_gray = fine_src.mean(1, keepdim=True)
        base_src_gray = base_src.mean(1, keepdim=True)

        # Filter the 4-channel [fgr, pha] target guided by [src, gray(src)], then split back.
        fgr, pha = self.guilded_filter(
            torch.cat([base_src, base_src_gray], dim=1),
            torch.cat([base_fgr, base_pha], dim=1),
            torch.cat([fine_src, fine_src_gray], dim=1)).split([3, 1], dim=1)

        return fgr, pha

    def forward_time_series(self, fine_src, base_src, base_fgr, base_pha):
        B, T = fine_src.shape[:2]
        fgr, pha = self.forward_single_frame(
            fine_src.flatten(0, 1),
            base_src.flatten(0, 1),
            base_fgr.flatten(0, 1),
            base_pha.flatten(0, 1))
        fgr = fgr.unflatten(0, (B, T))
        pha = pha.unflatten(0, (B, T))
        return fgr, pha

    def forward(self, fine_src, base_src, base_fgr, base_pha, base_hid):
        # base_hid is unused here; it is accepted to match DeepGuidedFilterRefiner's call signature.
        if fine_src.ndim == 5:
            return self.forward_time_series(fine_src, base_src, base_fgr, base_pha)
        else:
            return self.forward_single_frame(fine_src, base_src, base_fgr, base_pha)


class FastGuidedFilter(nn.Module):
    """Guided filter whose coefficients are computed at low resolution and upsampled.

    r: box-filter radius; eps: regulariser added to var_x.
    """

    def __init__(self, r: int, eps: float = 1e-5):
        super().__init__()
        self.r = r
        self.eps = eps
        self.boxfilter = BoxFilter(r)

    def forward(self, lr_x, lr_y, hr_x):
        mean_x = self.boxfilter(lr_x)
        mean_y = self.boxfilter(lr_y)
        cov_xy = self.boxfilter(lr_x * lr_y) - mean_x * mean_y
        var_x = self.boxfilter(lr_x * lr_x) - mean_x * mean_x
        A = cov_xy / (var_x + self.eps)
        b = mean_y - A * mean_x
        # Bilinearly upsample the coefficients to the guide's resolution and apply them.
        A = F.interpolate(A, hr_x.shape[2:], mode='bilinear', align_corners=False)
        b = F.interpolate(b, hr_x.shape[2:], mode='bilinear', align_corners=False)
        return A * hr_x + b


class BoxFilter(nn.Module):
    """Separable (2r+1) x (2r+1) mean filter implemented as two depthwise convolutions."""

    def __init__(self, r):
        super(BoxFilter, self).__init__()
        self.r = r

    def forward(self, x):
        # Note: The original implementation uses a faster box blur. However, it may
        # not be friendly for ONNX export, so a simple convolution is used instead.
        # NOTE(review): zero padding with a fixed 1/kernel_size weight means means near
        # the borders include the padded zeros.
        kernel_size = 2 * self.r + 1
        # One horizontal and one vertical kernel per channel (groups = channel count).
        kernel_x = torch.full((x.data.shape[1], 1, 1, kernel_size), 1 / kernel_size, device=x.device, dtype=x.dtype)
        kernel_y = torch.full((x.data.shape[1], 1, kernel_size, 1), 1 / kernel_size, device=x.device, dtype=x.dtype)
        x = F.conv2d(x, kernel_x, padding=(0, self.r), groups=x.data.shape[1])
        x = F.conv2d(x, kernel_y, padding=(self.r, 0), groups=x.data.shape[1])
        return x

================================================
FILE: model/lraspp.py
================================================
from torch import nn

class LRASPP(nn.Module):
    """Lite R-ASPP head: a 1x1 conv-BN-ReLU branch gated by a global-pooled sigmoid branch."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.aspp1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )
        # Produces one gate value per channel (spatially pooled to 1x1).
        self.aspp2 = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.Sigmoid()
        )

    def forward_single_frame(self, x):
        return self.aspp1(x) * self.aspp2(x)

    def forward_time_series(self, x):
        B, T = x.shape[:2]
        x = self.forward_single_frame(x.flatten(0, 1)).unflatten(0, (B, T))
        return x

    def forward(self, x):
        if x.ndim == 5:
            return self.forward_time_series(x)
        else:
            return self.forward_single_frame(x)

================================================
FILE: model/mobilenetv3.py
================================================
import torch
from torch import nn
from torchvision.models.mobilenetv3 import MobileNetV3, InvertedResidualConfig
from torchvision.transforms.functional import normalize

class MobileNetV3LargeEncoder(MobileNetV3):
    """MobileNetV3-Large feature extractor returning four intermediate feature maps [f1, f2, f3, f4].

    The last three blocks use dilation 2 (the final InvertedResidualConfig field before width_mult).
    The avgpool and classifier heads are removed.
    """

    def __init__(self, pretrained: bool = False):
        super().__init__(
            inverted_residual_setting=[
                InvertedResidualConfig( 16, 3,  16,  16, False, "RE", 1, 1, 1),
                InvertedResidualConfig( 16, 3,  64,  24, False, "RE", 2, 1, 1),  # C1
                InvertedResidualConfig( 24, 3,  72,  24, False, "RE", 1, 1, 1),
                InvertedResidualConfig( 24, 5,  72,  40,  True, "RE", 2, 1, 1),  # C2
                InvertedResidualConfig( 40, 5, 120,  40,  True, "RE", 1, 1, 1),
                InvertedResidualConfig( 40, 5, 120,  40,  True, "RE", 1, 1, 1),
                InvertedResidualConfig( 40, 3, 240,  80, False, "HS", 2, 1, 1),  # C3
                InvertedResidualConfig( 80, 3, 200,  80, False, "HS", 1, 1, 1),
                InvertedResidualConfig( 80, 3, 184,  80, False, "HS", 1, 1, 1),
                InvertedResidualConfig( 80, 3, 184,  80, False, "HS", 1, 1, 1),
                InvertedResidualConfig( 80, 3, 480, 112,  True, "HS", 1, 1, 1),
                InvertedResidualConfig(112, 3, 672, 112,  True, "HS", 1, 1, 1),
                InvertedResidualConfig(112, 5, 672, 160,  True, "HS", 2, 2, 1),  # C4
                InvertedResidualConfig(160, 5, 960, 160,  True, "HS", 1, 2, 1),
                InvertedResidualConfig(160, 5, 960, 160,  True, "HS", 1, 2, 1),
            ],
            last_channel=1280
        )

        if pretrained:
            self.load_state_dict(torch.hub.load_state_dict_from_url(
                'https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth'))

        del self.avgpool
        del self.classifier

    def forward_single_frame(self, x):
        # Standard ImageNet mean/std normalisation; x is presumably RGB in [0, 1].
        x = normalize(x, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

        x = self.features[0](x)
        x = self.features[1](x)
        f1 = x  # 16 channels (first config)
        x = self.features[2](x)
        x = self.features[3](x)
        f2 = x  # 24 channels
        x = self.features[4](x)
        x = self.features[5](x)
        x = self.features[6](x)
        f3 = x  # 40 channels
        x = self.features[7](x)
        x = self.features[8](x)
        x = self.features[9](x)
        x = self.features[10](x)
        x = self.features[11](x)
        x = self.features[12](x)
        x = self.features[13](x)
        x = self.features[14](x)
        x = self.features[15](x)
        x = self.features[16](x)
        f4 = x  # output of the final features layer (fed to LRASPP(960, ...) in MattingNetwork)
        return [f1, f2, f3, f4]

    def forward_time_series(self, x):
        B, T = x.shape[:2]
        features = self.forward_single_frame(x.flatten(0, 1))
        features = [f.unflatten(0, (B, T)) for f in features]
        return features

    def forward(self, x):
        if x.ndim == 5:
            return self.forward_time_series(x)
        else:
            return self.forward_single_frame(x)

================================================
FILE: model/model.py
================================================
import torch
from torch import Tensor
from torch import nn
from torch.nn import functional as F
from typing import Optional, List

from .mobilenetv3 import MobileNetV3LargeEncoder
from .resnet import ResNet50Encoder
from .lraspp import LRASPP
from .decoder import RecurrentDecoder, Projection
from .fast_guided_filter import FastGuidedFilterRefiner
from .deep_guided_filter import DeepGuidedFilterRefiner


class MattingNetwork(nn.Module):
    """Recurrent video matting network: encoder -> LRASPP -> recurrent decoder -> projection (+ refiner).

    Args:
        variant: 'mobilenetv3' or 'resnet50' backbone.
        refiner: 'deep_guided_filter' or 'fast_guided_filter'; used only when downsample_ratio != 1.
        pretrained_backbone: Load the torchvision pretrained backbone weights.
    """

    def __init__(self,
                 variant: str = 'mobilenetv3',
                 refiner: str = 'deep_guided_filter',
                 pretrained_backbone: bool = False):
        super().__init__()
        assert variant in ['mobilenetv3', 'resnet50']
        assert refiner in ['fast_guided_filter', 'deep_guided_filter']

        if variant == 'mobilenetv3':
            self.backbone = MobileNetV3LargeEncoder(pretrained_backbone)
            self.aspp = LRASPP(960, 128)
            self.decoder = RecurrentDecoder([16, 24, 40, 128], [80, 40, 32, 16])
        else:
            self.backbone = ResNet50Encoder(pretrained_backbone)
            self.aspp = LRASPP(2048, 256)
            self.decoder = RecurrentDecoder([64, 256, 512, 256], [128, 64, 32, 16])

        # 4 = 3 foreground-residual channels + 1 alpha channel.
        self.project_mat = Projection(16, 4)
        self.project_seg = Projection(16, 1)

        if refiner == 'deep_guided_filter':
            self.refiner = DeepGuidedFilterRefiner()
        else:
            self.refiner = FastGuidedFilterRefiner()

    def forward(self,
                src: Tensor,
                r1: Optional[Tensor] = None,
                r2: Optional[Tensor] = None,
                r3: Optional[Tensor] = None,
                r4: Optional[Tensor] = None,
                downsample_ratio: float = 1,
                segmentation_pass: bool = False):
        """Run one chunk of frames.

        src may be [B, C, H, W] or [B, T, C, H, W]. r1..r4 are the recurrent states
        returned by the previous call (None to start fresh).

        Returns:
            [fgr, pha, r1, r2, r3, r4], or [seg, r1, r2, r3, r4] when segmentation_pass
            is True. seg is the raw project_seg output; no activation is applied.
        """

        # Run the backbone/decoder at reduced resolution; the refiner restores full resolution.
        if downsample_ratio != 1:
            src_sm = self._interpolate(src, scale_factor=downsample_ratio)
        else:
            src_sm = src

        f1, f2, f3, f4 = self.backbone(src_sm)
        f4 = self.aspp(f4)
        hid, *rec = self.decoder(src_sm, f1, f2, f3, f4, r1, r2, r3, r4)

        if not segmentation_pass:
            # dim=-3 is the channel axis for both 4D and 5D inputs.
            fgr_residual, pha = self.project_mat(hid).split([3, 1], dim=-3)
            if downsample_ratio != 1:
                fgr_residual, pha = self.refiner(src, src_sm, fgr_residual, pha, hid)
            # The foreground is predicted as a residual on top of the source image.
            fgr = fgr_residual + src
            fgr = fgr.clamp(0., 1.)
            pha = pha.clamp(0., 1.)
            return [fgr, pha, *rec]
        else:
            seg = self.project_seg(hid)
            return [seg, *rec]

    def _interpolate(self, x: Tensor, scale_factor: float):
        """Bilinearly rescale 4D or 5D input by scale_factor (time folded into batch for 5D)."""
        if x.ndim == 5:
            B, T = x.shape[:2]
            x = F.interpolate(x.flatten(0, 1), scale_factor=scale_factor,
                mode='bilinear', align_corners=False, recompute_scale_factor=False)
            x = x.unflatten(0, (B, T))
        else:
            x = F.interpolate(x, scale_factor=scale_factor,
                mode='bilinear', align_corners=False, recompute_scale_factor=False)
        return x

================================================
FILE: model/resnet.py
================================================
import torch
from torch import nn
from torchvision.models.resnet import ResNet, Bottleneck

class ResNet50Encoder(ResNet):
    """ResNet-50 feature extractor returning [f1, f2, f3, f4].

    layer4's stride is replaced by dilation, so f4 stays at 1/16 resolution.
    The avgpool and fc heads are removed.
    NOTE(review): unlike MobileNetV3LargeEncoder, no input normalisation is applied here.
    """

    def __init__(self, pretrained: bool = False):
        super().__init__(
            block=Bottleneck,
            layers=[3, 4, 6, 3],
            replace_stride_with_dilation=[False, False, True],
            norm_layer=None)

        if pretrained:
            self.load_state_dict(torch.hub.load_state_dict_from_url(
                'https://download.pytorch.org/models/resnet50-0676ba61.pth'))

        del self.avgpool
        del self.fc

    def forward_single_frame(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        f1 = x  # 1/2
        x = self.maxpool(x)
        x = self.layer1(x)
        f2 = x  # 1/4
        x = self.layer2(x)
        f3 = x  # 1/8
        x = self.layer3(x)
        x = self.layer4(x)
        f4 = x  # 1/16
        return [f1, f2, f3, f4]

    def forward_time_series(self, x):
        B, T = x.shape[:2]
        features = self.forward_single_frame(x.flatten(0, 1))
        features = [f.unflatten(0, (B, T)) for f in features]
        return features

    def forward(self, x):
        if x.ndim == 5:
            return self.forward_time_series(x)
        else:
            return self.forward_single_frame(x)

================================================
FILE: requirements_inference.txt
================================================
av==8.0.3
torch==1.9.0
torchvision==0.10.0
tqdm==4.61.1
pims==0.5

================================================
FILE: requirements_training.txt
================================================
easing_functions==1.0.4
tensorboard==2.5.0
torch==1.9.0
torchvision==0.10.0
tqdm==4.61.1
================================================ FILE: train.py ================================================ """ # First update `train_config.py` to set paths to your dataset locations. # You may want to change `--num-workers` according to your machine's memory. # The default num-workers=8 may cause dataloader to exit unexpectedly when # machine is out of memory. # Stage 1 python train.py \ --model-variant mobilenetv3 \ --dataset videomatte \ --resolution-lr 512 \ --seq-length-lr 15 \ --learning-rate-backbone 0.0001 \ --learning-rate-aspp 0.0002 \ --learning-rate-decoder 0.0002 \ --learning-rate-refiner 0 \ --checkpoint-dir checkpoint/stage1 \ --log-dir log/stage1 \ --epoch-start 0 \ --epoch-end 20 # Stage 2 python train.py \ --model-variant mobilenetv3 \ --dataset videomatte \ --resolution-lr 512 \ --seq-length-lr 50 \ --learning-rate-backbone 0.00005 \ --learning-rate-aspp 0.0001 \ --learning-rate-decoder 0.0001 \ --learning-rate-refiner 0 \ --checkpoint checkpoint/stage1/epoch-19.pth \ --checkpoint-dir checkpoint/stage2 \ --log-dir log/stage2 \ --epoch-start 20 \ --epoch-end 22 # Stage 3 python train.py \ --model-variant mobilenetv3 \ --dataset videomatte \ --train-hr \ --resolution-lr 512 \ --resolution-hr 2048 \ --seq-length-lr 40 \ --seq-length-hr 6 \ --learning-rate-backbone 0.00001 \ --learning-rate-aspp 0.00001 \ --learning-rate-decoder 0.00001 \ --learning-rate-refiner 0.0002 \ --checkpoint checkpoint/stage2/epoch-21.pth \ --checkpoint-dir checkpoint/stage3 \ --log-dir log/stage3 \ --epoch-start 22 \ --epoch-end 23 # Stage 4 python train.py \ --model-variant mobilenetv3 \ --dataset imagematte \ --train-hr \ --resolution-lr 512 \ --resolution-hr 2048 \ --seq-length-lr 40 \ --seq-length-hr 6 \ --learning-rate-backbone 0.00001 \ --learning-rate-aspp 0.00001 \ --learning-rate-decoder 0.00005 \ --learning-rate-refiner 0.0002 \ --checkpoint checkpoint/stage3/epoch-22.pth \ --checkpoint-dir checkpoint/stage4 \ --log-dir log/stage4 \ --epoch-start 23 \ 
--epoch-end 28 """ import argparse import torch import random import os from torch import nn from torch import distributed as dist from torch import multiprocessing as mp from torch.nn import functional as F from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Adam from torch.cuda.amp import autocast, GradScaler from torch.utils.data import DataLoader, ConcatDataset from torch.utils.data.distributed import DistributedSampler from torch.utils.tensorboard import SummaryWriter from torchvision.utils import make_grid from torchvision.transforms.functional import center_crop from tqdm import tqdm from dataset.videomatte import ( VideoMatteDataset, VideoMatteTrainAugmentation, VideoMatteValidAugmentation, ) from dataset.imagematte import ( ImageMatteDataset, ImageMatteAugmentation ) from dataset.coco import ( CocoPanopticDataset, CocoPanopticTrainAugmentation, ) from dataset.spd import ( SuperviselyPersonDataset ) from dataset.youtubevis import ( YouTubeVISDataset, YouTubeVISAugmentation ) from dataset.augmentation import ( TrainFrameSampler, ValidFrameSampler ) from model import MattingNetwork from train_config import DATA_PATHS from train_loss import matting_loss, segmentation_loss class Trainer: def __init__(self, rank, world_size): self.parse_args() self.init_distributed(rank, world_size) self.init_datasets() self.init_model() self.init_writer() self.train() self.cleanup() def parse_args(self): parser = argparse.ArgumentParser() # Model parser.add_argument('--model-variant', type=str, required=True, choices=['mobilenetv3', 'resnet50']) # Matting dataset parser.add_argument('--dataset', type=str, required=True, choices=['videomatte', 'imagematte']) # Learning rate parser.add_argument('--learning-rate-backbone', type=float, required=True) parser.add_argument('--learning-rate-aspp', type=float, required=True) parser.add_argument('--learning-rate-decoder', type=float, required=True) parser.add_argument('--learning-rate-refiner', 
type=float, required=True) # Training setting parser.add_argument('--train-hr', action='store_true') parser.add_argument('--resolution-lr', type=int, default=512) parser.add_argument('--resolution-hr', type=int, default=2048) parser.add_argument('--seq-length-lr', type=int, required=True) parser.add_argument('--seq-length-hr', type=int, default=6) parser.add_argument('--downsample-ratio', type=float, default=0.25) parser.add_argument('--batch-size-per-gpu', type=int, default=1) parser.add_argument('--num-workers', type=int, default=8) parser.add_argument('--epoch-start', type=int, default=0) parser.add_argument('--epoch-end', type=int, default=16) # Tensorboard logging parser.add_argument('--log-dir', type=str, required=True) parser.add_argument('--log-train-loss-interval', type=int, default=20) parser.add_argument('--log-train-images-interval', type=int, default=500) # Checkpoint loading and saving parser.add_argument('--checkpoint', type=str) parser.add_argument('--checkpoint-dir', type=str, required=True) parser.add_argument('--checkpoint-save-interval', type=int, default=500) # Distributed parser.add_argument('--distributed-addr', type=str, default='localhost') parser.add_argument('--distributed-port', type=str, default='12355') # Debugging parser.add_argument('--disable-progress-bar', action='store_true') parser.add_argument('--disable-validation', action='store_true') parser.add_argument('--disable-mixed-precision', action='store_true') self.args = parser.parse_args() def init_distributed(self, rank, world_size): self.rank = rank self.world_size = world_size self.log('Initializing distributed') os.environ['MASTER_ADDR'] = self.args.distributed_addr os.environ['MASTER_PORT'] = self.args.distributed_port dist.init_process_group("nccl", rank=rank, world_size=world_size) def init_datasets(self): self.log('Initializing matting datasets') size_hr = (self.args.resolution_hr, self.args.resolution_hr) size_lr = (self.args.resolution_lr, self.args.resolution_lr) # 
Matting datasets: if self.args.dataset == 'videomatte': self.dataset_lr_train = VideoMatteDataset( videomatte_dir=DATA_PATHS['videomatte']['train'], background_image_dir=DATA_PATHS['background_images']['train'], background_video_dir=DATA_PATHS['background_videos']['train'], size=self.args.resolution_lr, seq_length=self.args.seq_length_lr, seq_sampler=TrainFrameSampler(), transform=VideoMatteTrainAugmentation(size_lr)) if self.args.train_hr: self.dataset_hr_train = VideoMatteDataset( videomatte_dir=DATA_PATHS['videomatte']['train'], background_image_dir=DATA_PATHS['background_images']['train'], background_video_dir=DATA_PATHS['background_videos']['train'], size=self.args.resolution_hr, seq_length=self.args.seq_length_hr, seq_sampler=TrainFrameSampler(), transform=VideoMatteTrainAugmentation(size_hr)) self.dataset_valid = VideoMatteDataset( videomatte_dir=DATA_PATHS['videomatte']['valid'], background_image_dir=DATA_PATHS['background_images']['valid'], background_video_dir=DATA_PATHS['background_videos']['valid'], size=self.args.resolution_hr if self.args.train_hr else self.args.resolution_lr, seq_length=self.args.seq_length_hr if self.args.train_hr else self.args.seq_length_lr, seq_sampler=ValidFrameSampler(), transform=VideoMatteValidAugmentation(size_hr if self.args.train_hr else size_lr)) else: self.dataset_lr_train = ImageMatteDataset( imagematte_dir=DATA_PATHS['imagematte']['train'], background_image_dir=DATA_PATHS['background_images']['train'], background_video_dir=DATA_PATHS['background_videos']['train'], size=self.args.resolution_lr, seq_length=self.args.seq_length_lr, seq_sampler=TrainFrameSampler(), transform=ImageMatteAugmentation(size_lr)) if self.args.train_hr: self.dataset_hr_train = ImageMatteDataset( imagematte_dir=DATA_PATHS['imagematte']['train'], background_image_dir=DATA_PATHS['background_images']['train'], background_video_dir=DATA_PATHS['background_videos']['train'], size=self.args.resolution_hr, seq_length=self.args.seq_length_hr, 
seq_sampler=TrainFrameSampler(), transform=ImageMatteAugmentation(size_hr)) self.dataset_valid = ImageMatteDataset( imagematte_dir=DATA_PATHS['imagematte']['valid'], background_image_dir=DATA_PATHS['background_images']['valid'], background_video_dir=DATA_PATHS['background_videos']['valid'], size=self.args.resolution_hr if self.args.train_hr else self.args.resolution_lr, seq_length=self.args.seq_length_hr if self.args.train_hr else self.args.seq_length_lr, seq_sampler=ValidFrameSampler(), transform=ImageMatteAugmentation(size_hr if self.args.train_hr else size_lr)) # Matting dataloaders: self.datasampler_lr_train = DistributedSampler( dataset=self.dataset_lr_train, rank=self.rank, num_replicas=self.world_size, shuffle=True) self.dataloader_lr_train = DataLoader( dataset=self.dataset_lr_train, batch_size=self.args.batch_size_per_gpu, num_workers=self.args.num_workers, sampler=self.datasampler_lr_train, pin_memory=True) if self.args.train_hr: self.datasampler_hr_train = DistributedSampler( dataset=self.dataset_hr_train, rank=self.rank, num_replicas=self.world_size, shuffle=True) self.dataloader_hr_train = DataLoader( dataset=self.dataset_hr_train, batch_size=self.args.batch_size_per_gpu, num_workers=self.args.num_workers, sampler=self.datasampler_hr_train, pin_memory=True) self.dataloader_valid = DataLoader( dataset=self.dataset_valid, batch_size=self.args.batch_size_per_gpu, num_workers=self.args.num_workers, pin_memory=True) # Segementation datasets self.log('Initializing image segmentation datasets') self.dataset_seg_image = ConcatDataset([ CocoPanopticDataset( imgdir=DATA_PATHS['coco_panoptic']['imgdir'], anndir=DATA_PATHS['coco_panoptic']['anndir'], annfile=DATA_PATHS['coco_panoptic']['annfile'], transform=CocoPanopticTrainAugmentation(size_lr)), SuperviselyPersonDataset( imgdir=DATA_PATHS['spd']['imgdir'], segdir=DATA_PATHS['spd']['segdir'], transform=CocoPanopticTrainAugmentation(size_lr)) ]) self.datasampler_seg_image = DistributedSampler( 
dataset=self.dataset_seg_image, rank=self.rank, num_replicas=self.world_size, shuffle=True) self.dataloader_seg_image = DataLoader( dataset=self.dataset_seg_image, batch_size=self.args.batch_size_per_gpu * self.args.seq_length_lr, num_workers=self.args.num_workers, sampler=self.datasampler_seg_image, pin_memory=True) self.log('Initializing video segmentation datasets') self.dataset_seg_video = YouTubeVISDataset( videodir=DATA_PATHS['youtubevis']['videodir'], annfile=DATA_PATHS['youtubevis']['annfile'], size=self.args.resolution_lr, seq_length=self.args.seq_length_lr, seq_sampler=TrainFrameSampler(speed=[1]), transform=YouTubeVISAugmentation(size_lr)) self.datasampler_seg_video = DistributedSampler( dataset=self.dataset_seg_video, rank=self.rank, num_replicas=self.world_size, shuffle=True) self.dataloader_seg_video = DataLoader( dataset=self.dataset_seg_video, batch_size=self.args.batch_size_per_gpu, num_workers=self.args.num_workers, sampler=self.datasampler_seg_video, pin_memory=True) def init_model(self): self.log('Initializing model') self.model = MattingNetwork(self.args.model_variant, pretrained_backbone=True).to(self.rank) if self.args.checkpoint: self.log(f'Restoring from checkpoint: {self.args.checkpoint}') self.log(self.model.load_state_dict( torch.load(self.args.checkpoint, map_location=f'cuda:{self.rank}'))) self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model) self.model_ddp = DDP(self.model, device_ids=[self.rank], broadcast_buffers=False, find_unused_parameters=True) self.optimizer = Adam([ {'params': self.model.backbone.parameters(), 'lr': self.args.learning_rate_backbone}, {'params': self.model.aspp.parameters(), 'lr': self.args.learning_rate_aspp}, {'params': self.model.decoder.parameters(), 'lr': self.args.learning_rate_decoder}, {'params': self.model.project_mat.parameters(), 'lr': self.args.learning_rate_decoder}, {'params': self.model.project_seg.parameters(), 'lr': self.args.learning_rate_decoder}, {'params': 
self.model.refiner.parameters(), 'lr': self.args.learning_rate_refiner}, ]) self.scaler = GradScaler() def init_writer(self): if self.rank == 0: self.log('Initializing writer') self.writer = SummaryWriter(self.args.log_dir) def train(self): for epoch in range(self.args.epoch_start, self.args.epoch_end): self.epoch = epoch self.step = epoch * len(self.dataloader_lr_train) if not self.args.disable_validation: self.validate() self.log(f'Training epoch: {epoch}') for true_fgr, true_pha, true_bgr in tqdm(self.dataloader_lr_train, disable=self.args.disable_progress_bar, dynamic_ncols=True): # Low resolution pass self.train_mat(true_fgr, true_pha, true_bgr, downsample_ratio=1, tag='lr') # High resolution pass if self.args.train_hr: true_fgr, true_pha, true_bgr = self.load_next_mat_hr_sample() self.train_mat(true_fgr, true_pha, true_bgr, downsample_ratio=self.args.downsample_ratio, tag='hr') # Segmentation pass if self.step % 2 == 0: true_img, true_seg = self.load_next_seg_video_sample() self.train_seg(true_img, true_seg, log_label='seg_video') else: true_img, true_seg = self.load_next_seg_image_sample() self.train_seg(true_img.unsqueeze(1), true_seg.unsqueeze(1), log_label='seg_image') if self.step % self.args.checkpoint_save_interval == 0: self.save() self.step += 1 def train_mat(self, true_fgr, true_pha, true_bgr, downsample_ratio, tag): true_fgr = true_fgr.to(self.rank, non_blocking=True) true_pha = true_pha.to(self.rank, non_blocking=True) true_bgr = true_bgr.to(self.rank, non_blocking=True) true_fgr, true_pha, true_bgr = self.random_crop(true_fgr, true_pha, true_bgr) true_src = true_fgr * true_pha + true_bgr * (1 - true_pha) with autocast(enabled=not self.args.disable_mixed_precision): pred_fgr, pred_pha = self.model_ddp(true_src, downsample_ratio=downsample_ratio)[:2] loss = matting_loss(pred_fgr, pred_pha, true_fgr, true_pha) self.scaler.scale(loss['total']).backward() self.scaler.step(self.optimizer) self.scaler.update() self.optimizer.zero_grad() if self.rank == 
0 and self.step % self.args.log_train_loss_interval == 0: for loss_name, loss_value in loss.items(): self.writer.add_scalar(f'train_{tag}_{loss_name}', loss_value, self.step) if self.rank == 0 and self.step % self.args.log_train_images_interval == 0: self.writer.add_image(f'train_{tag}_pred_fgr', make_grid(pred_fgr.flatten(0, 1), nrow=pred_fgr.size(1)), self.step) self.writer.add_image(f'train_{tag}_pred_pha', make_grid(pred_pha.flatten(0, 1), nrow=pred_pha.size(1)), self.step) self.writer.add_image(f'train_{tag}_true_fgr', make_grid(true_fgr.flatten(0, 1), nrow=true_fgr.size(1)), self.step) self.writer.add_image(f'train_{tag}_true_pha', make_grid(true_pha.flatten(0, 1), nrow=true_pha.size(1)), self.step) self.writer.add_image(f'train_{tag}_true_src', make_grid(true_src.flatten(0, 1), nrow=true_src.size(1)), self.step) def train_seg(self, true_img, true_seg, log_label): true_img = true_img.to(self.rank, non_blocking=True) true_seg = true_seg.to(self.rank, non_blocking=True) true_img, true_seg = self.random_crop(true_img, true_seg) with autocast(enabled=not self.args.disable_mixed_precision): pred_seg = self.model_ddp(true_img, segmentation_pass=True)[0] loss = segmentation_loss(pred_seg, true_seg) self.scaler.scale(loss).backward() self.scaler.step(self.optimizer) self.scaler.update() self.optimizer.zero_grad() if self.rank == 0 and (self.step - self.step % 2) % self.args.log_train_loss_interval == 0: self.writer.add_scalar(f'{log_label}_loss', loss, self.step) if self.rank == 0 and (self.step - self.step % 2) % self.args.log_train_images_interval == 0: self.writer.add_image(f'{log_label}_pred_seg', make_grid(pred_seg.flatten(0, 1).float().sigmoid(), nrow=self.args.seq_length_lr), self.step) self.writer.add_image(f'{log_label}_true_seg', make_grid(true_seg.flatten(0, 1), nrow=self.args.seq_length_lr), self.step) self.writer.add_image(f'{log_label}_true_img', make_grid(true_img.flatten(0, 1), nrow=self.args.seq_length_lr), self.step) def 
load_next_mat_hr_sample(self): try: sample = next(self.dataiterator_mat_hr) except: self.datasampler_hr_train.set_epoch(self.datasampler_hr_train.epoch + 1) self.dataiterator_mat_hr = iter(self.dataloader_hr_train) sample = next(self.dataiterator_mat_hr) return sample def load_next_seg_video_sample(self): try: sample = next(self.dataiterator_seg_video) except: self.datasampler_seg_video.set_epoch(self.datasampler_seg_video.epoch + 1) self.dataiterator_seg_video = iter(self.dataloader_seg_video) sample = next(self.dataiterator_seg_video) return sample def load_next_seg_image_sample(self): try: sample = next(self.dataiterator_seg_image) except: self.datasampler_seg_image.set_epoch(self.datasampler_seg_image.epoch + 1) self.dataiterator_seg_image = iter(self.dataloader_seg_image) sample = next(self.dataiterator_seg_image) return sample def validate(self): if self.rank == 0: self.log(f'Validating at the start of epoch: {self.epoch}') self.model_ddp.eval() total_loss, total_count = 0, 0 with torch.no_grad(): with autocast(enabled=not self.args.disable_mixed_precision): for true_fgr, true_pha, true_bgr in tqdm(self.dataloader_valid, disable=self.args.disable_progress_bar, dynamic_ncols=True): true_fgr = true_fgr.to(self.rank, non_blocking=True) true_pha = true_pha.to(self.rank, non_blocking=True) true_bgr = true_bgr.to(self.rank, non_blocking=True) true_src = true_fgr * true_pha + true_bgr * (1 - true_pha) batch_size = true_src.size(0) pred_fgr, pred_pha = self.model(true_src)[:2] total_loss += matting_loss(pred_fgr, pred_pha, true_fgr, true_pha)['total'].item() * batch_size total_count += batch_size avg_loss = total_loss / total_count self.log(f'Validation set average loss: {avg_loss}') self.writer.add_scalar('valid_loss', avg_loss, self.step) self.model_ddp.train() dist.barrier() def random_crop(self, *imgs): h, w = imgs[0].shape[-2:] w = random.choice(range(w // 2, w)) h = random.choice(range(h // 2, h)) results = [] for img in imgs: B, T = img.shape[:2] img = 
img.flatten(0, 1) img = F.interpolate(img, (max(h, w), max(h, w)), mode='bilinear', align_corners=False) img = center_crop(img, (h, w)) img = img.reshape(B, T, *img.shape[1:]) results.append(img) return results def save(self): if self.rank == 0: os.makedirs(self.args.checkpoint_dir, exist_ok=True) torch.save(self.model.state_dict(), os.path.join(self.args.checkpoint_dir, f'epoch-{self.epoch}.pth')) self.log('Model saved') dist.barrier() def cleanup(self): dist.destroy_process_group() def log(self, msg): print(f'[GPU{self.rank}] {msg}') if __name__ == '__main__': world_size = torch.cuda.device_count() mp.spawn( Trainer, nprocs=world_size, args=(world_size,), join=True) ================================================ FILE: train_config.py ================================================ """ Expected directory format: VideoMatte Train/Valid: ├──fgr/ ├── 0001/ ├── 00000.jpg ├── 00001.jpg ├── pha/ ├── 0001/ ├── 00000.jpg ├── 00001.jpg ImageMatte Train/Valid: ├── fgr/ ├── sample1.jpg ├── sample2.jpg ├── pha/ ├── sample1.jpg ├── sample2.jpg Background Image Train/Valid ├── sample1.png ├── sample2.png Background Video Train/Valid ├── 0000/ ├── 0000.jpg/ ├── 0001.jpg/ """ DATA_PATHS = { 'videomatte': { 'train': '../matting-data/VideoMatte240K_JPEG_SD/train', 'valid': '../matting-data/VideoMatte240K_JPEG_SD/valid', }, 'imagematte': { 'train': '../matting-data/ImageMatte/train', 'valid': '../matting-data/ImageMatte/valid', }, 'background_images': { 'train': '../matting-data/Backgrounds/train', 'valid': '../matting-data/Backgrounds/valid', }, 'background_videos': { 'train': '../matting-data/BackgroundVideos/train', 'valid': '../matting-data/BackgroundVideos/valid', }, 'coco_panoptic': { 'imgdir': '../matting-data/coco/train2017/', 'anndir': '../matting-data/coco/panoptic_train2017/', 'annfile': '../matting-data/coco/annotations/panoptic_train2017.json', }, 'spd': { 'imgdir': '../matting-data/SuperviselyPersonDataset/img', 'segdir': 
'../matting-data/SuperviselyPersonDataset/seg', }, 'youtubevis': { 'videodir': '../matting-data/YouTubeVIS/train/JPEGImages', 'annfile': '../matting-data/YouTubeVIS/train/instances.json', } } ================================================ FILE: train_loss.py ================================================ import torch from torch.nn import functional as F # --------------------------------------------------------------------------------- Train Loss def matting_loss(pred_fgr, pred_pha, true_fgr, true_pha): """ Args: pred_fgr: Shape(B, T, 3, H, W) pred_pha: Shape(B, T, 1, H, W) true_fgr: Shape(B, T, 3, H, W) true_pha: Shape(B, T, 1, H, W) """ loss = dict() # Alpha losses loss['pha_l1'] = F.l1_loss(pred_pha, true_pha) loss['pha_laplacian'] = laplacian_loss(pred_pha.flatten(0, 1), true_pha.flatten(0, 1)) loss['pha_coherence'] = F.mse_loss(pred_pha[:, 1:] - pred_pha[:, :-1], true_pha[:, 1:] - true_pha[:, :-1]) * 5 # Foreground losses true_msk = true_pha.gt(0) pred_fgr = pred_fgr * true_msk true_fgr = true_fgr * true_msk loss['fgr_l1'] = F.l1_loss(pred_fgr, true_fgr) loss['fgr_coherence'] = F.mse_loss(pred_fgr[:, 1:] - pred_fgr[:, :-1], true_fgr[:, 1:] - true_fgr[:, :-1]) * 5 # Total loss['total'] = loss['pha_l1'] + loss['pha_coherence'] + loss['pha_laplacian'] \ + loss['fgr_l1'] + loss['fgr_coherence'] return loss def segmentation_loss(pred_seg, true_seg): """ Args: pred_seg: Shape(B, T, 1, H, W) true_seg: Shape(B, T, 1, H, W) """ return F.binary_cross_entropy_with_logits(pred_seg, true_seg) # ----------------------------------------------------------------------------- Laplacian Loss def laplacian_loss(pred, true, max_levels=5): kernel = gauss_kernel(device=pred.device, dtype=pred.dtype) pred_pyramid = laplacian_pyramid(pred, kernel, max_levels) true_pyramid = laplacian_pyramid(true, kernel, max_levels) loss = 0 for level in range(max_levels): loss += (2 ** level) * F.l1_loss(pred_pyramid[level], true_pyramid[level]) return loss / max_levels def 
laplacian_pyramid(img, kernel, max_levels): current = img pyramid = [] for _ in range(max_levels): current = crop_to_even_size(current) down = downsample(current, kernel) up = upsample(down, kernel) diff = current - up pyramid.append(diff) current = down return pyramid def gauss_kernel(device='cpu', dtype=torch.float32): kernel = torch.tensor([[1, 4, 6, 4, 1], [4, 16, 24, 16, 4], [6, 24, 36, 24, 6], [4, 16, 24, 16, 4], [1, 4, 6, 4, 1]], device=device, dtype=dtype) kernel /= 256 kernel = kernel[None, None, :, :] return kernel def gauss_convolution(img, kernel): B, C, H, W = img.shape img = img.reshape(B * C, 1, H, W) img = F.pad(img, (2, 2, 2, 2), mode='reflect') img = F.conv2d(img, kernel) img = img.reshape(B, C, H, W) return img def downsample(img, kernel): img = gauss_convolution(img, kernel) img = img[:, :, ::2, ::2] return img def upsample(img, kernel): B, C, H, W = img.shape out = torch.zeros((B, C, H * 2, W * 2), device=img.device, dtype=img.dtype) out[:, :, ::2, ::2] = img * 4 out = gauss_convolution(out, kernel) return out def crop_to_even_size(img): H, W = img.shape[2:] H = H - H % 2 W = W - W % 2 return img[:, :, :H, :W]