Repository: MarkMoHR/virtual_sketching Branch: main Commit: 229fefae08e4 Files: 28 Total size: 431.9 KB Directory structure: gitextract_xyepypzb/ ├── .gitignore ├── LICENSE ├── README.md ├── README_CN.md ├── WINDOWS_INSTALL_GUIDE.md ├── dataset_utils.py ├── docs/ │ ├── assets/ │ │ ├── font.css │ │ └── style.css │ └── index.html ├── hyper_parameters.py ├── launch_gui.bat ├── model_common_test.py ├── model_common_train.py ├── rasterization_utils/ │ ├── NeuralRenderer.py │ └── RealRenderer.py ├── rnn.py ├── subnet_tf_utils.py ├── test_photograph_to_line.py ├── test_rough_sketch_simplification.py ├── test_vectorization.py ├── tools/ │ ├── gif_making.py │ ├── svg_conversion.py │ └── visualize_drawing.py ├── train_rough_photograph.py ├── train_vectorization.py ├── utils.py ├── vgg_utils/ │ └── VGG16.py └── virtual_sketch_gui.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ .idea .idea/ data/ datas/ dataset/ datasets/ model/ models/ testData/ output/ outputs/ *.csv # temporary files *.txt~ *.pyc .DS_Store .gitignore~ *.h5 ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. 
For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. 
If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================ # General Virtual Sketching Framework for Vector Line Art - SIGGRAPH 2021 [[Paper]](https://esslab.jp/publications/HaoranSIGRAPH2021.pdf) | [[Project Page]](https://markmohr.github.io/virtual_sketching/) | [[中文Readme]](/README_CN.md) | [[中文论文介绍]](https://blog.csdn.net/qq_33000225/article/details/118883153) This code is used for **line drawing vectorization**, **rough sketch simplification** and **photograph to vector line drawing**.      
## Outline - [Dependencies](#dependencies) - [Testing with Trained Weights](#testing-with-trained-weights) - [Training](#training) - [Citation](#citation) - [Projects Using this Model/Method](#projects-using-this-modelmethod) - [Blogs Mentioning this Paper](#blogs-mentioning-this-paper) - [For Windows users](#-windows-users) ## Dependencies - [Tensorflow](https://www.tensorflow.org/) (1.12.0 <= version <=1.15.0) - [opencv](https://opencv.org/) == 3.4.2 - [pillow](https://pillow.readthedocs.io/en/latest/index.html) == 6.2.0 - [scipy](https://www.scipy.org/) == 1.5.2 - [gizeh](https://github.com/Zulko/gizeh) == 0.1.11 ## Testing with Trained Weights ### Model Preparation Download the models [here](https://drive.google.com/drive/folders/1-hi2cl8joZ6oMOp4yvk_hObJGAK6ELHB?usp=sharing): - `pretrain_clean_line_drawings` (105 MB): for vectorization - `pretrain_rough_sketches` (105 MB): for rough sketch simplification - `pretrain_faces` (105 MB): for photograph to line drawing Then, place them in this file structure: ``` outputs/ snapshot/ pretrain_clean_line_drawings/ pretrain_rough_sketches/ pretrain_faces/ ``` ### Usage Choose the image in the `sample_inputs/` directory, and run one of the following commands for each task. The results will be under `outputs/sampling/`. ``` python python3 test_vectorization.py --input muten.png python3 test_rough_sketch_simplification.py --input rocket.png python3 test_photograph_to_line.py --input 1390.png ``` **Note!!!** Our approach starts drawing from a randomly selected initial position, so it outputs different results in every testing trial (some might be fine and some might not be good enough). It is recommended to do several trials to select the visually best result. 
The number of outputs can be defined by the `--sample` argument: ``` python python3 test_vectorization.py --input muten.png --sample 10 python3 test_rough_sketch_simplification.py --input rocket.png --sample 10 python3 test_photograph_to_line.py --input 1390.png --sample 10 ``` **Reproducing Paper Figures:** our results (download from [here](https://drive.google.com/drive/folders/1-hi2cl8joZ6oMOp4yvk_hObJGAK6ELHB?usp=sharing)) are selected by doing a certain number of trials. Apparently, it is required to use the same initial drawing positions to reproduce our results. ### Additional Tools #### a) Visualization Our vector output is stored in a `npz` package. Run the following command to obtain the rendered output and the drawing order. Results will be under the same directory of the `npz` file. ``` python python3 tools/visualize_drawing.py --file path/to/the/result.npz ``` #### b) GIF Making To see the dynamic drawing procedure, run the following command to obtain the `gif`. Result will be under the same directory of the `npz` file. ``` python python3 tools/gif_making.py --file path/to/the/result.npz ``` #### c) Conversion to SVG Our vector output in a `npz` package is stored as Eq.(1) in the main paper. Run the following command to convert it to the `svg` format. Result will be under the same directory of the `npz` file. ``` python python3 tools/svg_conversion.py --file path/to/the/result.npz ``` - The conversion is implemented in two modes (by setting the `--svg_type` argument): - `single` (default): each stroke (a single segment) forms a path in the SVG file - `cluster`: each continuous curve (with multiple strokes) forms a path in the SVG file **Important Notes** In SVG format, all the segments on a path share the same *stroke-width*. While in our stroke design, strokes on a common curve have different widths. Inside a stroke (a single segment), the thickness also changes linearly from an endpoint to another. 
Therefore, neither of the two conversion methods above generates visually the same results as the ones in our paper. *(Please mention this issue in your paper if you do qualitative comparisons with our results in SVG format.)*
## Training ### Preparations Download the models [here](https://drive.google.com/drive/folders/1-hi2cl8joZ6oMOp4yvk_hObJGAK6ELHB?usp=sharing): - `pretrain_neural_renderer` (40 MB): the pre-trained neural renderer - `pretrain_perceptual_model` (691 MB): the pre-trained perceptual model for raster loss Download the datasets [here](https://drive.google.com/drive/folders/1-hi2cl8joZ6oMOp4yvk_hObJGAK6ELHB?usp=sharing): - `QuickDraw-clean` (14 MB): for clean line drawing vectorization. Taken from [QuickDraw](https://github.com/googlecreativelab/quickdraw-dataset) dataset. - `QuickDraw-rough` (361 MB): for rough sketch simplification. Synthesized by the pencil drawing generation method from [Sketch Simplification](https://github.com/bobbens/sketch_simplification#pencil-drawing-generation). - `CelebAMask-faces` (370 MB): for photograph to line drawing. Processed from the [CelebAMask-HQ](https://github.com/switchablenorms/CelebAMask-HQ) dataset. Then, place them in this file structure: ``` datasets/ QuickDraw-clean/ QuickDraw-rough/ CelebAMask-faces/ outputs/ snapshot/ pretrain_neural_renderer/ pretrain_perceptual_model/ ``` ### Running It is recommended to train with multi-GPU. We train each task with 2 GPUs (each with 11 GB). ``` python python3 train_vectorization.py python3 train_rough_photograph.py --data rough python3 train_rough_photograph.py --data face ```
## Citation If you use the code and models please cite: ``` @article{mo2021virtualsketching, title = {General Virtual Sketching Framework for Vector Line Art}, author = {Mo, Haoran and Simo-Serra, Edgar and Gao, Chengying and Zou, Changqing and Wang, Ruomei}, journal = {ACM Transactions on Graphics (Proceedings of ACM SIGGRAPH 2021)}, year = {2021}, volume = {40}, number = {4}, pages = {51:1--51:14} } ```
## Projects Using this Model/Method | **[Painterly style transfer](https://github.com/xch-liu/Painterly-Style-Transfer) (TVCG 2023)** | **[Robot calligraphy](https://github.com/LoYuXr/CalliRewrite) (ICRA 2024)** | |:-------------:|:-------------------:| | | | | **[Geometrized cartoon line inbetweening](https://github.com/lisiyao21/AnimeInbet) (ICCV 2023)** | **[Stroke correspondence and inbetweening](https://github.com/MarkMoHR/JoSTC) (TOG 2024)** | | | | | **[Modelling complex vector drawings](https://github.com/Co-do/Stroke-Cloud) (ICLR 2024)** | **[Sketch-to-Image Generation](https://github.com/BlockDetail/Block-and-Detail) (UIST 2024)** | | | | ## Blogs Mentioning this Paper - [The state of AI for hand-drawn animation inbetweening](https://yosefk.com/blog/the-state-of-ai-for-hand-drawn-animation-inbetweening.html) ## 🪟 Windows users See [WINDOWS_INSTALL_GUIDE.md](WINDOWS_INSTALL_GUIDE.md) for a complete Windows installation guide and GUI usage. ================================================ FILE: README_CN.md ================================================ # General Virtual Sketching Framework for Vector Line Art - SIGGRAPH 2021 [[论文]](https://esslab.jp/publications/HaoranSIGRAPH2021.pdf) | [[项目主页]](https://markmohr.github.io/virtual_sketching/) 这份代码能用于实现:**线稿矢量化**、**粗糙草图简化**和**自然图像到矢量草图转换**。      ## 目录 - [环境依赖](#环境依赖) - [使用预训练模型测试](#使用预训练模型测试) - [重新训练](#重新训练) - [引用](#引用) ## 环境依赖 - [Tensorflow](https://www.tensorflow.org/) (1.12.0 <= 版本 <=1.15.0) - [opencv](https://opencv.org/) == 3.4.2 - [pillow](https://pillow.readthedocs.io/en/latest/index.html) == 6.2.0 - [scipy](https://www.scipy.org/) == 1.5.2 - [gizeh](https://github.com/Zulko/gizeh) == 0.1.11 ## 使用预训练模型测试 ### 模型下载与准备 在[这里](https://drive.google.com/drive/folders/1-hi2cl8joZ6oMOp4yvk_hObJGAK6ELHB?usp=sharing)下载模型: - `pretrain_clean_line_drawings` (105 MB): 用于线稿矢量化 - `pretrain_rough_sketches` (105 MB): 用于粗糙草图简化 - `pretrain_faces` (105 MB): 用于自然图像到矢量草图转换 然后,按照如下结构放置模型: ``` outputs/ snapshot/ 
pretrain_clean_line_drawings/ pretrain_rough_sketches/ pretrain_faces/ ``` ### 测试方法 在`sample_inputs/`文件夹下选择图像,然后根据任务类型运行下面其中一个命令。生成结果会在`outputs/sampling/`目录下看到。 ``` python python3 test_vectorization.py --input muten.png python3 test_rough_sketch_simplification.py --input rocket.png python3 test_photograph_to_line.py --input 1390.png ``` **注意!!!** 我们的方法从一个随机挑选的初始位置启动绘制,所以每跑一次测试理论上都会得到一个不同的结果(有可能效果不错,但也可能效果不是很好)。因此,建议做多几次测试来挑选看上去最好的结果。也可以通过设置 `--sample`参数来定义跑一次测试代码同时输出(不同结果)的数量: ``` python python3 test_vectorization.py --input muten.png --sample 10 python3 test_rough_sketch_simplification.py --input rocket.png --sample 10 python3 test_photograph_to_line.py --input 1390.png --sample 10 ``` **如何复现论文展示的结果?** 可以从[这里](https://drive.google.com/drive/folders/1-hi2cl8joZ6oMOp4yvk_hObJGAK6ELHB?usp=sharing)下载论文展示的结果。这些是我们通过若干次测试得到不同输出后挑选的最好的结果。显然,若要复现这些结果,需要使用相同的初始位置启动绘制。 ### 其他工具 #### a) 可视化 我们的矢量输出均使用`npz` 文件包存储。运行以下的命令可以得到渲染后的结果以及绘制顺序。可以在`npz` 文件包相同的目录下找到这些可视化结果。 ``` python python3 tools/visualize_drawing.py --file path/to/the/result.npz ``` #### b) GIF制作 若要看到动态的绘制过程,可以运行以下命令来得到 `gif`。结果在`npz` 文件包相同的目录下。 ``` python python3 tools/gif_making.py --file path/to/the/result.npz ``` #### c) 转化为SVG `npz` 文件包中的矢量结果均按照论文里面的公式(1)格式存储。可以运行以下命令行,来将其转化为 `svg` 文件格式。结果在`npz` 文件包相同的目录下。 ``` python python3 tools/svg_conversion.py --file path/to/the/result.npz ``` - 转化过程以两种模式实现(设置`--svg_type`参数): - `single` (默认模式): 每个笔划(一根单独的曲线)构成SVG文件中的一个path路径 - `cluster`: 每个连续曲线(多个笔划)构成SVG文件中的一个path路径 **重要注意事项** 在SVG文件格式中,一个path上的所有线段均只有同一个线宽(*stroke-width*)。然而在我们论文里面,定义一个连续曲线上所有的笔划可以有不同的线宽。同时,对于一个单独的笔划(贝塞尔曲线),定义其线宽从一个端点到另一个端点线性递增或者递减。 因此,上述两个转化方法得到的SVG结果理论上都无法保证跟论文里面的结果在视觉上完全一致。(*假如你在论文里面使用这里转化后的SVG结果进行视觉上的对比,请提及此问题。*)
## 重新训练 ### 训练准备 在[这里](https://drive.google.com/drive/folders/1-hi2cl8joZ6oMOp4yvk_hObJGAK6ELHB?usp=sharing)下载模型: - `pretrain_neural_renderer` (40 MB): 预训练好的神经网络渲染器 - `pretrain_perceptual_model` (691 MB): 预训练好的perceptual model,用于算 raster loss 在[这里](https://drive.google.com/drive/folders/1-hi2cl8joZ6oMOp4yvk_hObJGAK6ELHB?usp=sharing)下载训练数据集: - `QuickDraw-clean` (14 MB): 用于线稿矢量化。来自 [QuickDraw](https://github.com/googlecreativelab/quickdraw-dataset)数据集。 - `QuickDraw-rough` (361 MB): 用于粗糙草图简化。利用[Sketch Simplification](https://github.com/bobbens/sketch_simplification#pencil-drawing-generation)里面的铅笔画图像生成方法合成。 - `CelebAMask-faces` (370 MB): 用于自然图像到矢量草图转换。使用[CelebAMask-HQ](https://github.com/switchablenorms/CelebAMask-HQ)数据集进行处理后得到。 然后,按照如下结构放置数据集: ``` datasets/ QuickDraw-clean/ QuickDraw-rough/ CelebAMask-faces/ outputs/ snapshot/ pretrain_neural_renderer/ pretrain_perceptual_model/ ``` ### 训练方法 建议使用多GPU进行训练。每个任务,我们均使用2个GPU(每个11 GB)来训练。 ``` python python3 train_vectorization.py python3 train_rough_photograph.py --data rough python3 train_rough_photograph.py --data face ```
## 引用 若使用此代码和模型,请引用本工作,谢谢! ``` @article{mo2021virtualsketching, title = {General Virtual Sketching Framework for Vector Line Art}, author = {Mo, Haoran and Simo-Serra, Edgar and Gao, Chengying and Zou, Changqing and Wang, Ruomei}, journal = {ACM Transactions on Graphics (Proceedings of ACM SIGGRAPH 2021)}, year = {2021}, volume = {40}, number = {4}, pages = {51:1--51:14} } ``` ================================================ FILE: WINDOWS_INSTALL_GUIDE.md ================================================ # 🪟 Windows Installation Guide for Virtual Sketching This guide provides step-by-step instructions to set up and run the [Virtual Sketching](https://github.com/MarkMoHR/virtual_sketching) project on Windows using Anaconda and Python 3.6. ## ✅ Requirements - Windows 10 or newer - Anaconda installed - Git installed (optional, but recommended) --- ## 📦 Step 1: Create and Activate Conda Environment ```bash conda create -n virtual_sketching python=3.6 -y conda activate virtual_sketching ``` ## 📂 Step 2: Clone the Repository ```bash cd D:\ git clone https://github.com/MarkMoHR/virtual_sketching.git cd virtual_sketching ``` (If you plan to contribute, consider forking the repo and using your own URL.) --- ## 🔧 Step 3: Install Required Packages ### From `conda`: ```bash conda install opencv=3.4.2 pillow=6.2.0 scipy=1.5.2 -y conda install -c conda-forge pycairo gtk3 cffi -y ``` ### Then remove default TensorFlow (if installed via conda): ```bash conda remove tensorflow ``` ### Install required packages via `pip`: ```bash pip install tensorflow==1.15.0 pip install numpy gizeh cairocffi matplotlib svgwrite ``` > ⚠️ Do not upgrade `pillow`, `scipy`, or `tensorflow` — newer versions are incompatible. 
--- ## 🛠️ Step 4: Fix Backend Compatibility In `utils.py`, near the top, ensure the following: ```python from PIL import Image import matplotlib matplotlib.use('TkAgg') # force compatible backend import matplotlib.pyplot as plt ``` This ensures proper rendering with tkinter on Windows. --- ## 🚀 Step 5: Run a Demo From the project directory: ```bash python test_vectorization.py --input sample_inputs\muten.png --sample 5 ``` > ⚠️ This script only generates `.npz` and `.png` files. To convert to `.svg`, see next section. --- ## 🖼️ Optional: Use GUI Tool (Windows Only) ### Step 1: Launch GUI with provided batch file Use the `runme.bat` file to activate the conda environment and launch the Python GUI: ```bat rem runme.bat set CONDAPATH=C:\ProgramData\anaconda3 set ENVNAME=virtual_sketching call %CONDAPATH%\Scripts\activate.bat %ENVNAME% python virtual_sketch_gui.py pause ``` ### Step 2: Select an input image and model The GUI allows you to: - Choose input image (PNG, JPEG, BMP, etc.) - Select one of the three models - Automatically runs processing and converts results to SVG - Saves all outputs into a `sketches/` subfolder of the input image directory --- ## 🧠 Known Compatibility Notes - Python 3.6 and TensorFlow 1.15 are required (due to use of `tensorflow.contrib`) - Windows support requires manual setup of Gizeh and Cairo backends - GPU usage optional — TensorFlow 1.15 requires CUDA 10.0 and cuDNN 7 --- ## 🧩 Troubleshooting Tips - ❌ `No module named '_cffi_backend'` → Run: `conda install -c conda-forge cffi` - ❌ `ImportError: cannot import name 'draw_svg_from_npz'` → Use `svg_conversion.py` instead - ❌ `.svg` looks wrong → Make sure you're using the official `svg_conversion.py` from `tools/` - ❌ Missing `.svg`? → Use `virtual_sketch_gui.py` or run `tools/svg_conversion.py` manually on `.npz` --- ## 🤝 Contributing Feel free to open issues or pull requests if you encounter bugs or want to improve the Windows support! 
--- ## 📁 Folder Structure Suggestion ``` virtual_sketching/ ├── sample_inputs/ ├── tools/ ├── outputs/ ├── virtual_sketch_gui.py ├── runme.bat ├── README.md └── WINDOWS_INSTALL_GUIDE.md ← You are here ``` --- Made with ❤️ by the community to help Windows users get started! ================================================ FILE: dataset_utils.py ================================================ import os import math import random import scipy.io import numpy as np import tensorflow as tf from PIL import Image from rasterization_utils.RealRenderer import GizehRasterizor as RealRenderer def copy_hparams(hparams): """Return a copy of an HParams instance.""" return tf.contrib.training.HParams(**hparams.values()) class GeneralRawDataLoader(object): def __init__(self, image_path, raster_size, test_dataset): self.image_path = image_path self.raster_size = raster_size self.test_dataset = test_dataset def get_test_image(self, random_cursor=True, init_cursor_on_undrawn_pixel=False, init_cursor_num=1): input_image_data, image_size_test = self.gen_input_images(self.image_path) input_image_data = np.array(input_image_data, dtype=np.float32) # (1, image_size, image_size, (3)), [0.0-strokes, 1.0-BG] return input_image_data, \ self.gen_init_cursors(input_image_data, random_cursor, init_cursor_on_undrawn_pixel, init_cursor_num), \ image_size_test def gen_input_images(self, image_path): img = Image.open(image_path).convert('RGB') height, width = img.height, img.width max_dim = max(height, width) img = np.array(img, dtype=np.uint8) if height != width: # Padding to a square image if self.test_dataset == 'clean_line_drawings': pad_value = [255, 255, 255] elif self.test_dataset == 'faces': pad_value = [0, 0, 0] else: # TODO: find better padding pixel value pad_value = img[height - 10, width - 10] img_r, img_g, img_b = img[:, :, 0], img[:, :, 1], img[:, :, 2] pad_width = max_dim - width pad_height = max_dim - height pad_img_r = np.pad(img_r, ((0, pad_height), (0, pad_width)), 'constant', 
constant_values=pad_value[0]) pad_img_g = np.pad(img_g, ((0, pad_height), (0, pad_width)), 'constant', constant_values=pad_value[1]) pad_img_b = np.pad(img_b, ((0, pad_height), (0, pad_width)), 'constant', constant_values=pad_value[2]) image_array = np.stack([pad_img_r, pad_img_g, pad_img_b], axis=-1) else: image_array = img if self.test_dataset == 'faces' and max_dim != 256: image_array_resize = Image.fromarray(image_array, 'RGB') image_array_resize = image_array_resize.resize(size=(256, 256), resample=Image.BILINEAR) image_array = np.array(image_array_resize, dtype=np.uint8) assert image_array.shape[0] == image_array.shape[1] img_size = image_array.shape[0] image_array = image_array.astype(np.float32) if self.test_dataset == 'clean_line_drawings': image_array = image_array[:, :, 0] / 255.0 # [0.0-stroke, 1.0-BG] else: image_array = image_array / 255.0 # [0.0-stroke, 1.0-BG] image_array = np.expand_dims(image_array, axis=0) return image_array, img_size def crop_patch(self, image, center, image_size, crop_size): x0 = center[0] - crop_size // 2 x1 = x0 + crop_size y0 = center[1] - crop_size // 2 y1 = y0 + crop_size x0 = max(0, min(x0, image_size)) y0 = max(0, min(y0, image_size)) x1 = max(0, min(x1, image_size)) y1 = max(0, min(y1, image_size)) patch = image[y0:y1, x0:x1] return patch def gen_init_cursor_single(self, sketch_image, init_cursor_on_undrawn_pixel, misalign_size=3): # sketch_image: [0.0-stroke, 1.0-BG] image_size = sketch_image.shape[0] if np.sum(1.0 - sketch_image) == 0: center = np.zeros((2), dtype=np.int32) return center else: while True: center = np.random.randint(0, image_size, size=(2)) # (2), in large size patch = 1.0 - self.crop_patch(sketch_image, center, image_size, self.raster_size) if np.sum(patch) != 0: if not init_cursor_on_undrawn_pixel: return center.astype(np.float32) / float(image_size) # (2), in size [0.0, 1.0) else: center_patch = 1.0 - self.crop_patch(sketch_image, center, image_size, misalign_size) if np.sum(center_patch) != 0: 
return center.astype(np.float32) / float(image_size) # (2), in size [0.0, 1.0) def gen_init_cursors(self, sketch_data, random_pos=True, init_cursor_on_undrawn_pixel=False, init_cursor_num=1): init_cursor_batch_list = [] for cursor_i in range(init_cursor_num): if random_pos: init_cursor_batch = [] for i in range(len(sketch_data)): sketch_image = sketch_data[i].copy().astype(np.float32) # [0.0-stroke, 1.0-BG] center = self.gen_init_cursor_single(sketch_image, init_cursor_on_undrawn_pixel) init_cursor_batch.append(center) init_cursor_batch = np.stack(init_cursor_batch, axis=0) # (N, 2) else: raise Exception('Not finished') init_cursor_batch_list.append(init_cursor_batch) if init_cursor_num == 1: init_cursor_batch = init_cursor_batch_list[0] init_cursor_batch = np.expand_dims(init_cursor_batch, axis=1).astype(np.float32) # (N, 1, 2) else: init_cursor_batch = np.stack(init_cursor_batch_list, axis=1) # (N, init_cursor_num, 2) init_cursor_batch = np.expand_dims(init_cursor_batch, axis=2).astype( np.float32) # (N, init_cursor_num, 1, 2) return init_cursor_batch def load_dataset_testing(test_data_base_dir, test_dataset, test_img_name, model_params): assert test_dataset in ['clean_line_drawings', 'rough_sketches', 'faces'] img_path = os.path.join(test_data_base_dir, test_dataset, test_img_name) print('Loaded {} from {}'.format(img_path, test_dataset)) eval_model_params = copy_hparams(model_params) eval_model_params.use_input_dropout = 0 eval_model_params.use_recurrent_dropout = 0 eval_model_params.use_output_dropout = 0 eval_model_params.batch_size = 1 eval_model_params.model_mode = 'sample' sample_model_params = copy_hparams(eval_model_params) sample_model_params.batch_size = 1 # only sample one at a time sample_model_params.max_seq_len = 1 # sample one point at a time test_set = GeneralRawDataLoader(img_path, eval_model_params.raster_size, test_dataset=test_dataset) result = [test_set, eval_model_params, sample_model_params] return result class 
class GeneralMultiObjectDataLoader(object):
    """Composes several QuickDraw stroke-3 sketches into a single multi-object
    raster image at a randomly chosen canvas resolution."""

    def __init__(self, stroke3_data, batch_size, raster_size, image_size_small, image_size_large,
                 is_bin, is_train):
        self.batch_size = batch_size  # minibatch size
        self.raster_size = raster_size
        self.image_size_small = image_size_small
        self.image_size_large = image_size_large
        self.is_bin = is_bin
        self.is_train = is_train

        self.num_batches = len(stroke3_data) // self.batch_size
        self.batch_idx = -1
        print('batch_size', batch_size, ', num_batches', self.num_batches)

        self.rasterizor = RealRenderer()
        self.memory_sketch_data_batch = []

        assert type(stroke3_data) is list
        self.preprocess_rand_data(stroke3_data)

    def preprocess_rand_data(self, stroke3):
        # Shuffle in place for training so every epoch sees a fresh order.
        if self.is_train:
            random.shuffle(stroke3)
        self.stroke3_data = stroke3

    def cal_dist(self, posA, posB):
        """Euclidean distance between two 2-D positions."""
        return np.sqrt(np.sum(np.power(posA - posB, 2)))

    def invalid_position(self, pos, obj_size, pos_list, size_list):
        """Return True when `pos` lies too close to an already-placed object."""
        if len(pos_list) == 0:
            return False
        for placed_pos, placed_size in zip(pos_list, size_list):
            if self.cal_dist(pos, placed_pos) < ((obj_size + placed_size) // 4):
                return True
        return False

    def get_object_info(self, image_size, vary_thickness=True, try_total_times=3):
        """Decide how many objects to draw plus their sizes, positions and
        stroke thicknesses for a canvas of side `image_size`."""
        if image_size <= 172:
            obj_num = 1
            obj_thickness_list = [3]
        elif image_size <= 225:
            obj_num = random.randint(1, 2)
            obj_thickness_list = np.random.randint(3, 4 + 1, size=(obj_num))
        elif image_size <= 278:
            obj_num = 2
            obj_thickness_list = np.random.randint(3, 4 + 1, size=(obj_num))
        elif image_size <= 331:
            obj_num = random.randint(2, 3)
            # Resample until the thicknesses are neither all-5 nor too thick overall.
            while True:
                obj_thickness_list = np.random.randint(3, 5 + 1, size=(obj_num))
                if np.sum(obj_thickness_list) / obj_num != 5 and np.sum(obj_thickness_list) < 13:
                    break
        elif image_size <= 384:
            obj_num = 3
            while True:
                obj_thickness_list = np.random.randint(3, 5 + 1, size=(obj_num))
                if np.sum(obj_thickness_list) / obj_num != 5 and np.sum(obj_thickness_list) < 13:
                    break
        else:
            raise Exception('Invalid image_size', image_size)

        if not vary_thickness:
            obj_thickness_list = [3 for _ in range(len(obj_thickness_list))]

        obj_pos_list = []
        obj_size_list = []
        if obj_num == 1:
            obj_size_list.append(image_size)
            obj_pos_list.append((image_size // 2, image_size // 2))
        else:
            for obj_i in range(obj_num):
                for try_i in range(try_total_times):
                    obj_size = random.randint(128, image_size * 3 // 4)
                    obj_center = np.random.randint(obj_size // 3, image_size - (obj_size // 3) + 1, size=(2))
                    # Accept a non-overlapping position, or give up on the last try.
                    if not self.invalid_position(obj_center, obj_size, obj_pos_list, obj_size_list) \
                            or try_i == try_total_times - 1:
                        obj_pos_list.append(obj_center)
                        obj_size_list.append(obj_size)
                        break

        assert len(obj_size_list) == len(obj_pos_list) == len(obj_thickness_list) == obj_num
        return obj_num, obj_size_list, obj_pos_list, obj_thickness_list

    def object_pasting(self, obj_img, canvas_img, center):
        """Additively paste `obj_img` onto `canvas_img` centered at (y, x),
        clipping at the canvas borders; result is clipped to [0.0, 1.0]."""
        c_y, c_x = center[0], center[1]
        obj_size = obj_img.shape[0]
        canvas_size = canvas_img.shape[0]

        # Destination box on the canvas, clipped to the canvas bounds.
        x_lo = max(0, c_x - obj_size // 2)
        x_hi = min(canvas_size, c_x + obj_size // 2)
        y_lo = max(0, c_y - obj_size // 2)
        y_hi = min(canvas_size, c_y + obj_size // 2)
        box_canvas = canvas_img[y_lo: y_hi, x_lo: x_hi]

        # Matching source box inside the object image.
        src_y = y_lo - (c_y - obj_size // 2)
        src_x = x_lo - (c_x - obj_size // 2)
        box_obj = obj_img[src_y: src_y + (y_hi - y_lo), src_x: src_x + (x_hi - x_lo)]

        box_canvas += box_obj  # in-place add, as in the original implementation

        rst_canvas = np.copy(canvas_img)
        rst_canvas[y_lo: y_hi, x_lo: x_hi] = box_canvas
        return np.clip(rst_canvas, 0.0, 1.0)

    def get_multi_object_image(self, img_size, vary_thickness):
        """Rasterize several random sketches onto one canvas.
        Returns (img_size, img_size), [0.0-strokes, 1.0-BG]."""
        object_num, object_size_list, object_pos_list, object_thickness_list = self.get_object_info(
            img_size, vary_thickness=vary_thickness)

        canvas = np.zeros(shape=(img_size, img_size), dtype=np.float32)
        for obj_i in range(object_num):
            rand_idx = np.random.randint(0, len(self.stroke3_data))
            rand_stroke3 = self.stroke3_data[rand_idx]  # (N_points, 3)
            stroke_image = self.gen_stroke_images([rand_stroke3],
                                                  object_size_list[obj_i],
                                                  object_thickness_list[obj_i])
            stroke_image = 1.0 - stroke_image[0]  # (size, size), [0.0-BG, 1.0-strokes]
            canvas = self.object_pasting(stroke_image, canvas, object_pos_list[obj_i])
        return 1.0 - canvas  # [0.0-strokes, 1.0-BG]

    def get_batch_from_memory(self, memory_idx, vary_thickness, fixed_image_size=-1,
                              random_cursor=True, init_cursor_on_undrawn_pixel=False, init_cursor_num=1):
        """Return (and cache) a deterministic single-image batch for evaluation."""
        if len(self.memory_sketch_data_batch) >= memory_idx + 1:
            cached = self.memory_sketch_data_batch[memory_idx]
            sketch_data_batch = np.expand_dims(cached, axis=0)  # (1, size, size), [0.0-strokes, 1.0-BG]
            image_size_rand = sketch_data_batch.shape[1]
        else:
            if fixed_image_size == -1:
                image_size_rand = random.randint(self.image_size_small, self.image_size_large)
            else:
                image_size_rand = fixed_image_size
            multi_obj_image = self.get_multi_object_image(image_size_rand, vary_thickness)
            self.memory_sketch_data_batch.append(multi_obj_image)
            sketch_data_batch = np.expand_dims(multi_obj_image, axis=0)  # (1, size, size)

        return None, sketch_data_batch, \
            self.gen_init_cursors(sketch_data_batch, random_cursor,
                                  init_cursor_on_undrawn_pixel, init_cursor_num), \
            image_size_rand

    def get_batch_multi_res(self, loop_num, vary_thickness, random_cursor=True,
                            init_cursor_on_undrawn_pixel=False, init_cursor_num=1):
        """Build `loop_num` sub-batches, each at its own random resolution."""
        sketch_data_batch = []
        init_cursors_batch = []
        image_size_batch = []

        batch_size_per_loop = self.batch_size // loop_num
        for loop_i in range(loop_num):
            image_size_rand = random.randint(self.image_size_small, self.image_size_large)
            sketch_data_sub_batch = []
            for batch_i in range(batch_size_per_loop):
                sketch_data_sub_batch.append(
                    self.get_multi_object_image(image_size_rand, vary_thickness))  # [0.0-strokes, 1.0-BG]
            sketch_data_sub_batch = np.stack(sketch_data_sub_batch, axis=0)  # (N, size, size)
            init_cursors_sub_batch = self.gen_init_cursors(sketch_data_sub_batch, random_cursor,
                                                           init_cursor_on_undrawn_pixel, init_cursor_num)
            sketch_data_batch.append(sketch_data_sub_batch)
            init_cursors_batch.append(init_cursors_sub_batch)
            image_size_batch.append(image_size_rand)

        return None, sketch_data_batch, init_cursors_batch, image_size_batch

    def gen_stroke_images(self, stroke3_list, image_size, stroke_width):
        """
        :param stroke3_list: list of (batch_size,), each with (N_points, 3)
        :param image_size:
        :return: (batch_size, image_size, image_size), [0.0-strokes, 1.0-BG]
        """
        gt_image_array = self.rasterizor.raster_func(stroke3_list, image_size, stroke_width=stroke_width,
                                                     is_bin=self.is_bin, version='v2')
        gt_image_array = np.stack(gt_image_array, axis=0)
        return 1.0 - gt_image_array

    def crop_patch(self, image, center, image_size, crop_size):
        """Crop a crop_size x crop_size patch centered at `center`, clipped to the image."""
        x0 = center[0] - crop_size // 2
        x1 = x0 + crop_size
        y0 = center[1] - crop_size // 2
        y1 = y0 + crop_size
        x0 = max(0, min(x0, image_size))
        y0 = max(0, min(y0, image_size))
        x1 = max(0, min(x1, image_size))
        y1 = max(0, min(y1, image_size))
        return image[y0:y1, x0:x1]

    def gen_init_cursor_single(self, sketch_image, init_cursor_on_undrawn_pixel, misalign_size=3):
        """Sample a cursor whose raster window contains stroke pixels.

        sketch_image: [0.0-stroke, 1.0-BG]. Returns a normalized (2,) position,
        or integer zeros for a blank image.
        """
        image_size = sketch_image.shape[0]
        if np.sum(1.0 - sketch_image) == 0:
            return np.zeros((2), dtype=np.int32)
        while True:
            center = np.random.randint(0, image_size, size=(2))  # (2), in large size
            patch = 1.0 - self.crop_patch(sketch_image, center, image_size, self.raster_size)
            if np.sum(patch) != 0:
                if not init_cursor_on_undrawn_pixel:
                    return center.astype(np.float32) / float(image_size)  # (2), in [0.0, 1.0)
                center_patch = 1.0 - self.crop_patch(sketch_image, center, image_size, misalign_size)
                if np.sum(center_patch) != 0:
                    return center.astype(np.float32) / float(image_size)  # (2), in [0.0, 1.0)
np.sum(center_patch) != 0: return center.astype(np.float32) / float(image_size) # (2), in size [0.0, 1.0) def gen_init_cursors(self, sketch_data, random_pos=True, init_cursor_on_undrawn_pixel=False, init_cursor_num=1): init_cursor_batch_list = [] for cursor_i in range(init_cursor_num): if random_pos: init_cursor_batch = [] for i in range(len(sketch_data)): sketch_image = sketch_data[i].copy().astype(np.float32) # [0.0-stroke, 1.0-BG] center = self.gen_init_cursor_single(sketch_image, init_cursor_on_undrawn_pixel) init_cursor_batch.append(center) init_cursor_batch = np.stack(init_cursor_batch, axis=0) # (N, 2) else: raise Exception('Not finished') init_cursor_batch_list.append(init_cursor_batch) if init_cursor_num == 1: init_cursor_batch = init_cursor_batch_list[0] init_cursor_batch = np.expand_dims(init_cursor_batch, axis=1).astype(np.float32) # (N, 1, 2) else: init_cursor_batch = np.stack(init_cursor_batch_list, axis=1) # (N, init_cursor_num, 2) init_cursor_batch = np.expand_dims(init_cursor_batch, axis=2).astype( np.float32) # (N, init_cursor_num, 1, 2) return init_cursor_batch def load_dataset_multi_object(dataset_base_dir, model_params): train_stroke3_data = [] val_stroke3_data = [] if model_params.data_set == 'clean_line_drawings': def load_qd_npz_data(npz_path): data = np.load(npz_path, encoding='latin1', allow_pickle=True) selected_strokes3 = data['stroke3'] # (N_sketches,), each with (N_points, 3) selected_strokes3 = selected_strokes3.tolist() return selected_strokes3 base_dir_clean = 'QuickDraw-clean' cates = ['airplane', 'bus', 'car', 'sailboat', 'bird', 'cat', 'dog', # 'rabbit', 'tree', 'flower', # 'circle', 'line', 'zigzag' ] for cate in cates: train_cate_sketch_data_npz_path = os.path.join(dataset_base_dir, base_dir_clean, 'train', cate + '.npz') val_cate_sketch_data_npz_path = os.path.join(dataset_base_dir, base_dir_clean, 'test', cate + '.npz') print(train_cate_sketch_data_npz_path) train_cate_stroke3_data = load_qd_npz_data( 
train_cate_sketch_data_npz_path) # list of (N_sketches,), each with (N_points, 3) val_cate_stroke3_data = load_qd_npz_data(val_cate_sketch_data_npz_path) train_stroke3_data += train_cate_stroke3_data val_stroke3_data += val_cate_stroke3_data else: raise Exception('Unknown data type:', model_params.data_set) print('Loaded {}/{} from {}'.format(len(train_stroke3_data), len(val_stroke3_data), model_params.data_set)) print('model_params.max_seq_len %i.' % model_params.max_seq_len) eval_sample_model_params = copy_hparams(model_params) eval_sample_model_params.use_input_dropout = 0 eval_sample_model_params.use_recurrent_dropout = 0 eval_sample_model_params.use_output_dropout = 0 eval_sample_model_params.batch_size = 1 # only sample one at a time eval_sample_model_params.model_mode = 'eval_sample' train_set = GeneralMultiObjectDataLoader(train_stroke3_data, model_params.batch_size, model_params.raster_size, model_params.image_size_small, model_params.image_size_large, model_params.bin_gt, is_train=True) val_set = GeneralMultiObjectDataLoader(val_stroke3_data, eval_sample_model_params.batch_size, eval_sample_model_params.raster_size, eval_sample_model_params.image_size_small, eval_sample_model_params.image_size_large, eval_sample_model_params.bin_gt, is_train=False) result = [train_set, val_set, model_params, eval_sample_model_params] return result class GeneralDataLoaderMultiObjectRough(object): def __init__(self, photo_data, sketch_data, texture_data, shadow_data, batch_size, raster_size, image_size_small, image_size_large, is_train): self.batch_size = batch_size # minibatch size self.raster_size = raster_size self.image_size_small = image_size_small self.image_size_large = image_size_large self.is_train = is_train assert photo_data is not None assert len(photo_data) == len(sketch_data) # self.num_batches = len(sketch_data) // self.batch_size self.batch_idx = -1 print('batch_size', batch_size) assert type(photo_data) is list assert type(sketch_data) is list assert 
class GeneralDataLoaderMultiObjectRough(object):
    """Loader pairing rough 'photo' renderings with clean sketches, with
    optional texture / noise / shadow augmentation applied to the photo side."""

    def __init__(self, photo_data, sketch_data, texture_data, shadow_data,
                 batch_size, raster_size, image_size_small, image_size_large, is_train):
        self.batch_size = batch_size  # minibatch size
        self.raster_size = raster_size
        self.image_size_small = image_size_small
        self.image_size_large = image_size_large
        self.is_train = is_train

        assert photo_data is not None
        assert len(photo_data) == len(sketch_data)

        self.batch_idx = -1
        print('batch_size', batch_size)

        assert type(photo_data) is list
        assert type(sketch_data) is list
        assert type(texture_data) is list and len(texture_data) > 0
        assert type(shadow_data) is list and len(shadow_data) > 0
        self.photo_data = photo_data
        self.sketch_data = sketch_data
        self.texture_data = texture_data  # list of (H, W, 3), [0, 255], uint8
        self.shadow_data = shadow_data  # list of (H, W), [0, 255], uint8

        self.memory_photo_data_batch = []
        self.memory_sketch_data_batch = []

    def rough_augmentation(self, raw_photo, texture_prob=0.20, noise_prob=0.15, shadow_prob=0.20):
        """Independently apply paper texture, pixel noise and cast shadows.

        raw_photo: (H, W), [0.0-stroke, 1.0-BG]. Returns (H, W, 3) in [0.0, 1.0].
        """
        aug_photo_rgb = np.stack([raw_photo for _ in range(3)], axis=-1)

        def texture_generation(texture_list, image_shape):
            # Resample textures until one is at least as large as the image,
            # then take a random crop of exactly the image size.
            # NOTE(review): loops forever if no texture is large enough.
            while True:
                random_texture_id = random.randint(0, len(texture_list) - 1)
                texture_large = texture_list[random_texture_id]
                t_w, t_h = texture_large.shape[1], texture_large.shape[0]
                i_w, i_h = image_shape[1], image_shape[0]
                if t_h >= i_h and t_w >= i_w:
                    texture_large = np.copy(texture_large).astype(np.float32)
                    crop_y = random.randint(0, t_h - i_h)
                    crop_x = random.randint(0, t_w - i_w)
                    return texture_large[crop_y: crop_y + i_h, crop_x: crop_x + i_w, :]

        def texture_change(rough_img_, all_textures):
            # rough_img_: (H, W, 3), [0.0-stroke, 1.0-BG]
            texture_image = texture_generation(all_textures, rough_img_.shape)  # (h, w, 3)
            texture_image /= 255.0
            # Per-pixel random blend weight between the texture and white.
            rand_b = np.random.uniform(1.0, 2.0, size=rough_img_.shape)
            return rough_img_ * (texture_image / rand_b + (rand_b - 1.0) / rand_b)  # [0.0, 1.0]

        def noise_change(rough_img_, noise_scale=25):
            # rough_img_: (H, W, 3), [0.0, 1.0]
            rough_img_255 = rough_img_ * 255.0
            rand_noise = np.random.uniform(-1.0, 1.0, size=rough_img_255.shape) * noise_scale
            noise_img = np.clip(rough_img_255 + rand_noise, 0.0, 255.0)
            return noise_img / 255.0

        def shadow_change(rough_img_, all_shadows):
            # rough_img_: (H, W, 3), [0.0, 1.0]
            rough_img_255 = rough_img_ * 255.0
            shadow_i = random.randint(0, len(all_shadows) - 1)
            shadow_full = all_shadows[shadow_i]  # (H, W), [0, 255]
            shadow_img_size = shadow_full.shape[0]
            # Pick a crop offset well away from the center of the shadow map.
            while True:
                position = np.random.randint(-shadow_img_size // 2, shadow_img_size // 2, (2))
                if abs(position[0]) > (shadow_img_size // 8) and abs(position[1]) > (shadow_img_size // 8):
                    break
            position += (shadow_img_size // 2)
            crop_up = shadow_img_size - position[0]
            crop_left = shadow_img_size - position[1]
            shadow_image_large = shadow_full[crop_up: crop_up + shadow_img_size,
                                             crop_left: crop_left + shadow_img_size]
            shadow_bg = Image.fromarray(shadow_image_large, 'L')
            shadow_bg = shadow_bg.resize(size=(rough_img_255.shape[1], rough_img_255.shape[0]),
                                         resample=Image.BILINEAR)
            shadow_bg = np.array(shadow_bg, dtype=np.float32) / 255.0  # [0.0-shadow, 1.0-BG]
            shadow_bg = np.stack([shadow_bg for _ in range(3)], axis=-1)
            return rough_img_255 * shadow_bg / 255.0

        if random.random() <= texture_prob:
            aug_photo_rgb = texture_change(aug_photo_rgb, self.texture_data)  # (H, W, 3), [0.0, 1.0]
        if random.random() <= noise_prob:
            aug_photo_rgb = noise_change(aug_photo_rgb)  # (H, W, 3), [0.0, 1.0]
        if random.random() <= shadow_prob:
            aug_photo_rgb = shadow_change(aug_photo_rgb, self.shadow_data)  # (H, W, 3), [0.0, 1.0]
        return aug_photo_rgb

    def image_interpolation(self, photo, sketch, photo_prob):
        """Linear blend between photo and sketch, weighted by `photo_prob`."""
        interp_photo = photo * photo_prob + sketch * (1.0 - photo_prob)
        return np.clip(interp_photo, 0.0, 1.0)

    def get_batch_from_memory(self, memory_idx, interpolate_type, fixed_image_size=-1,
                              random_cursor=True, photo_prob=1.0, init_cursor_num=1):
        """Return (and cache) a single photo/sketch pair for evaluation."""
        if len(self.memory_sketch_data_batch) >= memory_idx + 1:
            photo_data_batch = self.memory_photo_data_batch[memory_idx]
            sketch_data_batch = self.memory_sketch_data_batch[memory_idx]
            image_size_rand = sketch_data_batch.shape[1]
        else:
            if fixed_image_size == -1:
                image_size_rand = random.randint(self.image_size_small, self.image_size_large)
            else:
                image_size_rand = fixed_image_size
            photo_data_batch, sketch_data_batch = self.select_sketch(
                image_size_rand)  # both: (H, W), [0.0-stroke, 1.0-BG]
            photo_data_batch = self.rough_augmentation(photo_data_batch)  # (H, W, 3)
            self.memory_photo_data_batch.append(photo_data_batch)
            self.memory_sketch_data_batch.append(sketch_data_batch)

        # NOTE(review): the cache stores the pre-interpolation photo, so the
        # blend below is re-applied on every call — reconstructed from the
        # mangled source; confirm against the original layout.
        if interpolate_type == 'prob':
            if random.random() >= photo_prob:
                photo_data_batch = np.stack([sketch_data_batch for _ in range(3)],
                                            axis=-1)  # (H, W, 3), [0.0-stroke, 1.0-BG]
        elif interpolate_type == 'image':
            photo_data_batch = self.image_interpolation(
                photo_data_batch, np.stack([sketch_data_batch for _ in range(3)], axis=-1), photo_prob)
        else:
            raise Exception('Unknown interpolate_type', interpolate_type)

        photo_data_batch = np.expand_dims(photo_data_batch, axis=0)  # (1, size, size, 3)
        sketch_data_batch = np.expand_dims(sketch_data_batch, axis=0)  # (1, size, size)
        return photo_data_batch, sketch_data_batch, \
            self.gen_init_cursors(sketch_data_batch, random_cursor, init_cursor_num), image_size_rand

    def select_sketch(self, image_size_rand):
        """Pick a random photo/sketch pair at the requested resolution."""
        resolution_idx = image_size_rand - self.image_size_small
        img_idx = random.randint(0, len(self.sketch_data[resolution_idx]) - 1)
        assert img_idx != -1
        selected_sketch = self.sketch_data[resolution_idx][img_idx]  # [0-stroke, 255-BG], uint8
        selected_photo = self.photo_data[resolution_idx][img_idx]  # [0-stroke, 255-BG], uint8
        rst_sketch_image = selected_sketch.astype(np.float32) / 255.0  # [0.0-stroke, 1.0-BG]
        rst_photo_image = selected_photo.astype(np.float32) / 255.0  # [0.0-stroke, 1.0-BG]
        return rst_photo_image, rst_sketch_image

    def get_batch_multi_res(self, loop_num, interpolate_type, random_cursor=True,
                            init_cursor_num=1, photo_prob=1.0):
        """Build `loop_num` sub-batches, each at its own random resolution."""
        photo_data_batch = []
        sketch_data_batch = []
        init_cursors_batch = []
        image_size_batch = []

        batch_size_per_loop = self.batch_size // loop_num
        for loop_i in range(loop_num):
            image_size_rand = random.randint(self.image_size_small, self.image_size_large)
            photo_data_sub_batch = []
            sketch_data_sub_batch = []
            for img_i in range(batch_size_per_loop):
                photo_patch, sketch_patch = self.select_sketch(
                    image_size_rand)  # both: (H, W), [0.0-stroke, 1.0-BG]
                photo_patch = self.rough_augmentation(photo_patch)  # (H, W, 3)
                if interpolate_type == 'prob':
                    if random.random() >= photo_prob:
                        photo_patch = np.stack([sketch_patch for _ in range(3)], axis=-1)
                elif interpolate_type == 'image':
                    photo_patch = self.image_interpolation(
                        photo_patch, np.stack([sketch_patch for _ in range(3)], axis=-1), photo_prob)
                else:
                    raise Exception('Unknown interpolate_type', interpolate_type)
                photo_data_sub_batch.append(photo_patch)
                sketch_data_sub_batch.append(sketch_patch)

            photo_data_sub_batch = np.stack(photo_data_sub_batch, axis=0)  # (N, size, size, 3)
            sketch_data_sub_batch = np.stack(sketch_data_sub_batch, axis=0)  # (N, size, size)
            init_cursors_sub_batch = self.gen_init_cursors(sketch_data_sub_batch, random_cursor,
                                                           init_cursor_num)
            photo_data_batch.append(photo_data_sub_batch)
            sketch_data_batch.append(sketch_data_sub_batch)
            init_cursors_batch.append(init_cursors_sub_batch)
            image_size_batch.append(image_size_rand)

        return photo_data_batch, sketch_data_batch, init_cursors_batch, image_size_batch

    def crop_patch(self, image, center, image_size, crop_size):
        """Crop a crop_size x crop_size patch centered at `center`, clipped to the image."""
        x0 = center[0] - crop_size // 2
        x1 = x0 + crop_size
        y0 = center[1] - crop_size // 2
        y1 = y0 + crop_size
        x0 = max(0, min(x0, image_size))
        y0 = max(0, min(y0, image_size))
        x1 = max(0, min(x1, image_size))
        y1 = max(0, min(y1, image_size))
        return image[y0:y1, x0:x1]

    def gen_init_cursor_single(self, sketch_image):
        """Sample a cursor whose raster window covers stroke pixels.

        sketch_image: [0.0-stroke, 1.0-BG]. Returns a normalized (2,) position,
        or integer zeros for a blank image.
        """
        image_size = sketch_image.shape[0]
        if np.sum(1.0 - sketch_image) == 0:
            return np.zeros((2), dtype=np.int32)
        while True:
            center = np.random.randint(0, image_size, size=(2))  # (2), in large size
            patch = 1.0 - self.crop_patch(sketch_image, center, image_size, self.raster_size)
            if np.sum(patch) != 0:
                return center.astype(np.float32) / float(image_size)  # (2), in [0.0, 1.0)

    def gen_init_cursors(self, sketch_data, random_pos=True, init_cursor_num=1):
        """Sample `init_cursor_num` normalized initial cursors per image in the batch."""
        init_cursor_batch_list = []
        for cursor_i in range(init_cursor_num):
            if not random_pos:
                raise Exception('Not finished')
            init_cursor_batch = []
            for i in range(len(sketch_data)):
                sketch_image = sketch_data[i].copy().astype(np.float32)  # [0.0-stroke, 1.0-BG]
                init_cursor_batch.append(self.gen_init_cursor_single(sketch_image))
            init_cursor_batch_list.append(np.stack(init_cursor_batch, axis=0))  # (N, 2)

        if init_cursor_num == 1:
            init_cursor_batch = np.expand_dims(init_cursor_batch_list[0], axis=1).astype(np.float32)  # (N, 1, 2)
        else:
            init_cursor_batch = np.stack(init_cursor_batch_list, axis=1)  # (N, init_cursor_num, 2)
            init_cursor_batch = np.expand_dims(init_cursor_batch, axis=2).astype(
                np.float32)  # (N, init_cursor_num, 1, 2)
        return init_cursor_batch
def load_dataset_multi_object_rough(dataset_base_dir, model_params):
    """Load QuickDraw-rough photo/sketch pairs plus the texture and shadow
    assets used for rough-photo augmentation, and build train/val loaders.

    Returns [train_set, val_set, model_params, eval_sample_model_params].
    """
    train_photo_data = []
    train_sketch_data = []
    val_photo_data = []
    val_sketch_data = []
    texture_data = []
    shadow_data = []

    if model_params.data_set == 'rough_sketches':
        base_dir_rough = 'QuickDraw-rough'

        def load_sketch_data(mat_path):
            sketch_data_mat = scipy.io.loadmat(mat_path)
            sketch_data = sketch_data_mat['sketch_array']
            # (N, resolution, resolution), [0-strokes, 255-BG]
            sketch_data = np.array(sketch_data, dtype=np.uint8)
            return sketch_data

        def load_photo_data(mat_path):
            photo_data_mat = scipy.io.loadmat(mat_path)
            photo_data = photo_data_mat['image_array']
            # (N, resolution, resolution), [0-strokes, 255-BG]
            photo_data = np.array(photo_data, dtype=np.uint8)
            return photo_data

        def load_normal_data(img_path):
            # BUG FIX: the original `assert '.png' in img_path or '.jpg'` was a
            # no-op — the literal '.jpg' is always truthy, so the assert could
            # never fail. Check both extensions against the path.
            assert '.png' in img_path or '.jpg' in img_path
            img = Image.open(img_path).convert('RGB')
            img = np.array(img, dtype=np.uint8)  # (H, W, 3), [0-stroke, 255-BG], uint8
            return img

        ## Texture
        texture_base = os.path.join(dataset_base_dir, base_dir_rough, 'texture')
        all_texture = os.listdir(texture_base)
        all_texture.sort()  # deterministic order across file systems
        for file_name in all_texture:
            texture_path = os.path.join(texture_base, file_name)
            texture_uint8 = load_normal_data(texture_path)
            texture_data.append(texture_uint8)

        ## Shadow
        def process_angle(img, temp_size):
            # Darken runs of `temp_size` pixels along each border/corner by 1.
            padded_img = img.copy()
            padded_img[0, 0:temp_size] -= 1
            padded_img[0, -(temp_size + 1):-1] -= 1
            padded_img[-1, 0:temp_size] -= 1
            padded_img[-1, -(temp_size + 1):-1] -= 1
            padded_img[0:temp_size, 0] -= 1
            padded_img[0:temp_size, -1] -= 1
            padded_img[-(temp_size + 1):-1, 0] -= 1
            padded_img[-(temp_size + 1):-1, -1] -= 1
            return padded_img

        def pad_img(ori_img, pad_value):
            # Grow the map by a 1-pixel ring of `pad_value`, then soften the
            # corners at several scales.
            padded_img = np.pad(ori_img, 1, constant_values=pad_value)
            img_h, img_w = padded_img.shape[0], padded_img.shape[1]
            temp_size = img_h // 3
            padded_img = process_angle(padded_img, temp_size)
            temp_size = img_h // 9
            padded_img = process_angle(padded_img, temp_size)
            temp_size = img_h // 15
            padded_img = process_angle(padded_img, temp_size)
            temp_size = img_h // 21
            padded_img = process_angle(padded_img, temp_size)
            padded_img = np.clip(padded_img, 0, 255)
            return padded_img

        def shadow_generation(transparency, shadow_img_size=1024):
            # Build a radial-ish shadow map: a bright center patch surrounded by
            # rings fading towards `deepest_value`.
            deepest_value = int(255 * transparency)
            center_patch = np.zeros((shadow_img_size // 2, shadow_img_size // 2), dtype=np.uint8)
            center_patch.fill(255)
            pad_gap = shadow_img_size // 2
            shadow_patch = center_patch.copy()
            for i in range(pad_gap):
                curr_pad_value = 255.0 - float(255.0 - deepest_value) / float(pad_gap) * (i + 1)
                shadow_patch = pad_img(shadow_patch, pad_value=curr_pad_value)
            for i in range(shadow_img_size // 4):
                shadow_patch = pad_img(shadow_patch, pad_value=deepest_value)
            assert shadow_patch.shape[0] == shadow_img_size * 2, shadow_patch.shape[0]
            return shadow_patch

        for transparency_ in range(90, 95 + 1):
            transparency = transparency_ / 100.0
            shadow_full = shadow_generation(transparency)
            shadow_data.append(shadow_full)

        splits = ['train', 'test']
        resolutions = [model_params.image_size_small, model_params.image_size_large]
        for resolution in range(resolutions[0], resolutions[1] + 1):
            for split in splits:
                sketch_mat1_path = os.path.join(dataset_base_dir, base_dir_rough, 'model_pencil1',
                                                'sketch', split, 'res_' + str(resolution) + '.mat')
                photo_mat1_path = os.path.join(dataset_base_dir, base_dir_rough, 'model_pencil1',
                                               'photo', split, 'res_' + str(resolution) + '.mat')
                sketch_data1_uint8 = load_sketch_data(
                    sketch_mat1_path)  # (N, resolution, resolution), [0-strokes, 255-BG]
                photo_data1_uint8 = load_photo_data(photo_mat1_path)

                sketch_mat2_path = os.path.join(dataset_base_dir, base_dir_rough, 'model_pencil2',
                                                'sketch', split, 'res_' + str(resolution) + '.mat')
                photo_mat2_path = os.path.join(dataset_base_dir, base_dir_rough, 'model_pencil2',
                                               'photo', split, 'res_' + str(resolution) + '.mat')
                sketch_data2_uint8 = load_sketch_data(sketch_mat2_path)
                photo_data2_uint8 = load_photo_data(photo_mat2_path)

                # Merge both pencil styles into one pool per resolution/split.
                sketch_data_uint8 = np.concatenate([sketch_data1_uint8, sketch_data2_uint8], axis=0)
                photo_data_uint8 = np.concatenate([photo_data1_uint8, photo_data2_uint8], axis=0)

                if split == 'train':
                    train_photo_data.append(photo_data_uint8)
                    train_sketch_data.append(sketch_data_uint8)
                else:
                    val_photo_data.append(photo_data_uint8)
                    val_sketch_data.append(sketch_data_uint8)

        assert len(train_sketch_data) == len(train_photo_data)
        assert len(val_sketch_data) == len(val_photo_data)
    else:
        raise Exception('Unknown data type:', model_params.data_set)

    print('Loaded {}/{} from {}'.format(len(train_sketch_data), len(val_sketch_data),
                                        model_params.data_set))
    print('model_params.max_seq_len %i.' % model_params.max_seq_len)

    eval_sample_model_params = copy_hparams(model_params)
    eval_sample_model_params.use_input_dropout = 0
    eval_sample_model_params.use_recurrent_dropout = 0
    eval_sample_model_params.use_output_dropout = 0
    eval_sample_model_params.batch_size = 1  # only sample one at a time
    eval_sample_model_params.model_mode = 'eval_sample'

    train_set = GeneralDataLoaderMultiObjectRough(train_photo_data, train_sketch_data,
                                                  texture_data, shadow_data,
                                                  model_params.batch_size, model_params.raster_size,
                                                  model_params.image_size_small,
                                                  model_params.image_size_large,
                                                  is_train=True)
    val_set = GeneralDataLoaderMultiObjectRough(val_photo_data, val_sketch_data,
                                                texture_data, shadow_data,
                                                eval_sample_model_params.batch_size,
                                                eval_sample_model_params.raster_size,
                                                eval_sample_model_params.image_size_small,
                                                eval_sample_model_params.image_size_large,
                                                is_train=False)
    result = [train_set, val_set, model_params, eval_sample_model_params]
    return result
photo_prob=1.0, init_cursor_num=1): if self.random_image_size: image_size_rand = fixed_image_size else: image_size_rand = self.image_size_large photo_data_batch, sketch_data_batch = self.select_sketch_and_crop( image_size_rand, data_idx=memory_idx, photo_prob=photo_prob, interpolate_type=interpolate_type) # sketch_patch: [0.0-stroke, 1.0-BG] photo_data_batch = np.expand_dims(photo_data_batch, axis=0) # (1, image_size, image_size, 3) sketch_data_batch = np.expand_dims(sketch_data_batch, axis=0) # (1, image_size, image_size), [0.0-strokes, 1.0-BG] image_size_rand = sketch_data_batch.shape[1] return photo_data_batch, sketch_data_batch, \ self.gen_init_cursors(sketch_data_batch, random_cursor, init_cursor_num), image_size_rand def crop_and_augment(self, photo, sketch, shape, crop_size, rotate_angle, stroke_cover=0.01): # img: [0-stroke, 255-BG], uint8 def angle_convert(angle): return angle / 180.0 * math.pi img_h, img_w = shape[0], shape[1] if self.is_train: crop_up = random.randint(0, img_h - crop_size) crop_left = random.randint(0, img_w - crop_size) else: crop_up = (img_h - crop_size) // 2 crop_left = (img_w - crop_size) // 2 assert crop_up >= 0 assert crop_left >= 0 crop_box = (crop_left, crop_up, crop_left + crop_size, crop_up + crop_size) rst_sketch_image = sketch.crop(crop_box) rst_photo_image = photo.crop(crop_box) if random.random() <= self.flip_prob and self.is_train: rst_sketch_image = rst_sketch_image.transpose(Image.FLIP_LEFT_RIGHT) rst_photo_image = rst_photo_image.transpose(Image.FLIP_LEFT_RIGHT) if rotate_angle != 0 and self.is_train: rst_sketch_image = rst_sketch_image.rotate(rotate_angle, resample=Image.BILINEAR) rst_photo_image = rst_photo_image.rotate(rotate_angle, resample=Image.BILINEAR) rst_sketch_image = np.array(rst_sketch_image, dtype=np.uint8) rst_photo_image = np.array(rst_photo_image, dtype=np.uint8) center = rst_photo_image.shape[0] // 2 new_dim = float(crop_size) / ( math.sin(angle_convert(abs(rotate_angle))) + 
math.cos(angle_convert(abs(rotate_angle)))) new_dim = int(round(new_dim)) start_pos = center - new_dim // 2 end_pos = start_pos + new_dim rst_sketch_image = rst_sketch_image[start_pos:end_pos, start_pos:end_pos, :] rst_photo_image = rst_photo_image[start_pos:end_pos, start_pos:end_pos, :] rst_sketch_image = np.array(rst_sketch_image, dtype=np.float32) / 255.0 # [0.0-stroke, 1.0-BG] rst_sketch_image = rst_sketch_image[:, :, 0] rst_photo_image = np.array(rst_photo_image, dtype=np.float32) / 255.0 # [0.0-stroke, 1.0-BG] percentage = np.mean(1.0 - rst_sketch_image) valid = True if percentage < stroke_cover: valid = False return rst_photo_image, rst_sketch_image, valid def image_interpolation(self, photo, sketch, photo_prob): interp_photo = photo * photo_prob + sketch * (1.0 - photo_prob) interp_photo = np.clip(interp_photo, 0.0, 1.0) return interp_photo def select_sketch_and_crop(self, image_size_rand, interpolate_type, rotate_angle=0, photo_prob=1.0, data_idx=-1, trial_times=10): if self.is_train: while True: rand_img_idx = random.randint(0, len(self.sketch_data) - 1) selected_sketch_shape = self.sketch_shape[rand_img_idx] if selected_sketch_shape[0] >= image_size_rand and selected_sketch_shape[1] >= image_size_rand: img_idx = rand_img_idx break else: assert data_idx != -1 img_idx = data_idx assert img_idx != -1 selected_sketch = self.sketch_data[img_idx] selected_photo = self.photo_data[img_idx] selected_shape = self.sketch_shape[img_idx] assert interpolate_type in ['prob', 'image'] if interpolate_type == 'prob' and random.random() >= photo_prob: selected_photo = self.sketch_data[img_idx] for trial_i in range(trial_times): cropped_photo, cropped_sketch, valid = \ self.crop_and_augment(selected_photo, selected_sketch, selected_shape, image_size_rand, rotate_angle) # cropped_photo, cropped_sketch: [0.0-stroke, 1.0-BG] if valid or trial_i == trial_times - 1: if interpolate_type == 'image': cropped_photo = self.image_interpolation(cropped_photo, np.stack([cropped_sketch 
for _ in range(3)], axis=-1), photo_prob) return cropped_photo, cropped_sketch def get_batch_multi_res(self, loop_num, interpolate_type, random_cursor=True, init_cursor_num=1, photo_prob=1.0): photo_data_batch = [] sketch_data_batch = [] init_cursors_batch = [] image_size_batch = [] batch_size_per_loop = self.batch_size // loop_num for loop_i in range(loop_num): if self.random_image_size: image_size_rand = random.randint(self.image_size_small, self.image_size_large) else: image_size_rand = self.image_size_large rotate_angle = 0 if random.random() <= self.rotate_prob: rotate_angle = random.randint(-45, 45) photo_data_sub_batch = [] sketch_data_sub_batch = [] for img_i in range(batch_size_per_loop): photo_patch, sketch_patch = \ self.select_sketch_and_crop(image_size_rand, rotate_angle=rotate_angle, photo_prob=photo_prob, interpolate_type=interpolate_type) # sketch_patch: [0.0-stroke, 1.0-BG] photo_data_sub_batch.append(photo_patch) sketch_data_sub_batch.append(sketch_patch) photo_data_sub_batch = np.stack(photo_data_sub_batch, axis=0) # (N, image_size, image_size, 3), [0.0-strokes, 1.0-BG] sketch_data_sub_batch = np.stack(sketch_data_sub_batch, axis=0) # (N, image_size, image_size), [0.0-strokes, 1.0-BG] init_cursors_sub_batch = self.gen_init_cursors(sketch_data_sub_batch, random_cursor, init_cursor_num) photo_data_batch.append(photo_data_sub_batch) sketch_data_batch.append(sketch_data_sub_batch) init_cursors_batch.append(init_cursors_sub_batch) image_size_rand = photo_data_sub_batch.shape[1] image_size_batch.append(image_size_rand) return photo_data_batch, sketch_data_batch, init_cursors_batch, image_size_batch def crop_patch(self, image, center, image_size, crop_size): x0 = center[0] - crop_size // 2 x1 = x0 + crop_size y0 = center[1] - crop_size // 2 y1 = y0 + crop_size x0 = max(0, min(x0, image_size)) y0 = max(0, min(y0, image_size)) x1 = max(0, min(x1, image_size)) y1 = max(0, min(y1, image_size)) patch = image[y0:y1, x0:x1] return patch def 
    def gen_init_cursors(self, sketch_data, random_pos=True, init_cursor_num=1):
        """Sample ``init_cursor_num`` starting cursor positions per sketch.

        Args:
            sketch_data: iterable of sketch images, [0.0-stroke, 1.0-BG];
                each is copied and cast to float32 before sampling.
            random_pos: only True is implemented; False raises.
            init_cursor_num: number of cursors sampled per image.

        Returns:
            float32 array of shape (N, 1, 2) when ``init_cursor_num == 1``,
            otherwise (N, init_cursor_num, 1, 2). Values come from
            ``gen_init_cursor_single`` (normalized by image size).
        """
        init_cursor_batch_list = []
        for cursor_i in range(init_cursor_num):
            if random_pos:
                init_cursor_batch = []
                for i in range(len(sketch_data)):
                    sketch_image = sketch_data[i].copy().astype(np.float32)  # [0.0-stroke, 1.0-BG]
                    center = self.gen_init_cursor_single(sketch_image)
                    init_cursor_batch.append(center)
                init_cursor_batch = np.stack(init_cursor_batch, axis=0)  # (N, 2)
            else:
                # Deterministic placement was never implemented.
                raise Exception('Not finished')
            init_cursor_batch_list.append(init_cursor_batch)
        if init_cursor_num == 1:
            init_cursor_batch = init_cursor_batch_list[0]
            init_cursor_batch = np.expand_dims(init_cursor_batch, axis=1).astype(np.float32)  # (N, 1, 2)
        else:
            init_cursor_batch = np.stack(init_cursor_batch_list, axis=1)  # (N, init_cursor_num, 2)
            init_cursor_batch = np.expand_dims(init_cursor_batch, axis=2).astype(
                np.float32)  # (N, init_cursor_num, 1, 2)
        return init_cursor_batch
os.path.join(database, 'val.txt') celeba_train_txt = np.loadtxt(train_split_txt_save_path, dtype=str) celeba_val_txt = np.loadtxt(val_split_txt_save_path, dtype=str) splits_indices_map = {'train': celeba_train_txt, 'val': celeba_val_txt} for split in splits: split_indices = splits_indices_map[split] for i in range(len(split_indices)): file_idx = split_indices[i] img_file_path = os.path.join(photo_base, str(file_idx) + '.jpg') edge_img_path = os.path.join(edge_base, str(file_idx) + '.png') img_data = Image.open(img_file_path).convert('RGB') edge_data = Image.open(edge_img_path).convert('RGB') if split == 'train': train_photo_data.append(img_data) train_sketch_data.append(edge_data) train_data_shape.append((img_data.height, img_data.width)) else: # split == 'val' val_photo_data.append(img_data) val_sketch_data.append(edge_data) val_data_shape.append((img_data.height, img_data.width)) assert len(train_sketch_data) == len(train_data_shape) == len(train_photo_data) assert len(val_sketch_data) == len(val_data_shape) == len(val_photo_data) else: raise Exception('Unknown data type:', model_params.data_set) print('Loaded {}/{} from {}'.format(len(train_sketch_data), len(val_sketch_data), model_params.data_set)) print('model_params.max_seq_len %i.' 
def load_dataset_training(dataset_base_dir, model_params):
    """Route to the dataset loader matching ``model_params.data_set``.

    Args:
        dataset_base_dir: root directory of the datasets.
        model_params: hyper-parameter object carrying ``data_set``.

    Returns:
        Whatever the selected loader returns (train/val sets + params).

    Raises:
        Exception: if ``data_set`` names no known dataset type.
    """
    data_set = model_params.data_set
    if data_set == 'clean_line_drawings':
        return load_dataset_multi_object(dataset_base_dir, model_params)
    if data_set == 'rough_sketches':
        return load_dataset_multi_object_rough(dataset_base_dir, model_params)
    if data_set == 'faces':
        return load_dataset_normal_images(dataset_base_dir, model_params)
    raise Exception('Unknown data_set', data_set)
url(https://fonts.gstatic.com/s/lato/v16/S6uyw4BMUTPHjxAwXjeu.woff2) format('woff2'); unicode-range: U+0100-024F, U+0259, U+1E00-1EFF, U+2020, U+20A0-20AB, U+20AD-20CF, U+2113, U+2C60-2C7F, U+A720-A7FF; } /* latin */ @font-face { font-family: 'Lato'; font-style: normal; font-weight: 400; src: local('Lato Regular'), local('Lato-Regular'), url(https://fonts.gstatic.com/s/lato/v16/S6uyw4BMUTPHjx4wXg.woff2) format('woff2'); unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+2000-206F, U+2074, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD; } /* latin-ext */ @font-face { font-family: 'Lato'; font-style: normal; font-weight: 700; src: local('Lato Bold'), local('Lato-Bold'), url(https://fonts.gstatic.com/s/lato/v16/S6u9w4BMUTPHh6UVSwaPGR_p.woff2) format('woff2'); unicode-range: U+0100-024F, U+0259, U+1E00-1EFF, U+2020, U+20A0-20AB, U+20AD-20CF, U+2113, U+2C60-2C7F, U+A720-A7FF; } /* latin */ @font-face { font-family: 'Lato'; font-style: normal; font-weight: 700; src: local('Lato Bold'), local('Lato-Bold'), url(https://fonts.gstatic.com/s/lato/v16/S6u9w4BMUTPHh6UVSwiPGQ.woff2) format('woff2'); unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+2000-206F, U+2074, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD; } ================================================ FILE: docs/assets/style.css ================================================ /* Body */ body { background: #e3e5e8; color: #ffffff; font-family: 'Lato', Verdana, Helvetica, sans-serif; font-weight: 300; font-size: 14pt; } /* Hyperlinks */ a {text-decoration: none;} a:link {color: #1772d0;} a:visited {color: #1772d0;} a:active {color: red;} a:hover {color: #f09228;} /* Pre-formatted Text */ pre { margin: 5pt 0; border: 0; font-size: 12pt; background: #fcfcfc; } /* Project Page Style */ /* Section */ .section { width: 768pt; min-height: 100pt; margin: 15pt auto; padding: 20pt 30pt; border: 1pt hidden #000; 
text-align: justify; color: #000000; background: #ffffff; } /* Header (Title and Logo) */ .section .header { min-height: 80pt; margin-top: 30pt; } .section .header .logo { width: 80pt; margin-left: 10pt; float: left; } .section .header .logo img { width: 80pt; object-fit: cover; } .section .header .title { margin: 0 120pt; text-align: center; font-size: 22pt; } /* Author */ .section .author { margin: 5pt 0; text-align: center; font-size: 16pt; } /* Institution */ .section .institution { margin: 5pt 0; text-align: center; font-size: 16pt; } /* Hyperlink (such as Paper and Code) */ .section .link { margin: 5pt 0; text-align: center; font-size: 16pt; } /* Teaser */ .section .teaser { margin: 20pt 0; text-align: left; } .section .teaser img { width: 95%; } /* Section Title */ .section .title { text-align: center; font-size: 22pt; margin: 5pt 0 15pt 0; /* top right bottom left */ } /* Section Body */ .section .body { margin-bottom: 15pt; text-align: justify; font-size: 14pt; } /* BibTeX */ .section .bibtex { margin: 5pt 0; text-align: left; font-size: 22pt; } /* Related Work */ .section .ref { margin: 20pt 0 10pt 0; /* top right bottom left */ text-align: left; font-size: 18pt; font-weight: bold; } /* Citation */ .section .citation { min-height: 60pt; margin: 10pt 0; } .section .citation .image { width: 120pt; float: left; } .section .citation .image img { max-height: 60pt; width: 120pt; object-fit: cover; } .section .citation .comment{ margin-left: 0pt; text-align: left; font-size: 14pt; } ================================================ FILE: docs/index.html ================================================ General Virtual Sketching Framework for Vector Line Art
General Virtual Sketching Framework for Vector Line Art
Haoran Mo¹, Edgar Simo-Serra², Chengying Gao*¹, Changqing Zou³, Ruomei Wang¹
1Sun Yat-sen University,  2Waseda University, 
3Huawei Technologies Canada

Accepted by ACM SIGGRAPH 2021


Given clean line drawings, rough sketches or photographs of arbitrary resolution as input, our framework generates the corresponding vector line drawings directly. As shown in (b), the framework models a virtual pen surrounded by a dynamic window (red boxes), which moves while drawing the strokes. It learns to move around by scaling the window and sliding to an undrawn area for restarting the drawing (bottom example; sliding trajectory in blue arrow). With our proposed stroke regularization mechanism, the framework is able to enlarge the window and draw long strokes for simplicity (top example).
Abstract
Vector line art plays an important role in graphic design; however, it is tedious to create manually. We introduce a general framework to produce line drawings from a wide variety of images, by learning a mapping from raster image space to vector image space. Our approach is based on a recurrent neural network that draws the lines one by one. A differentiable rasterization module allows for training with only supervised raster data. We use a dynamic window around a virtual pen while drawing lines, implemented with a proposed aligned cropping and differentiable pasting modules. Furthermore, we develop a stroke regularization loss that encourages the model to use fewer and longer strokes to simplify the resulting vector image. Ablation studies and comparisons with existing methods corroborate the efficiency of our approach, which is able to generate visually better results in less computation time, while generalizing better to a diversity of images and applications.
Method

Framework Overview



Our framework generates the parametrized strokes step by step in a recurrent manner. It uses a dynamic window (dashed red boxes) around a virtual pen to draw the strokes, and can both move and change the size of the window. (a) Four main modules at each time step: aligned cropping, stroke generation, differentiable rendering and differentiable pasting. (b) Architecture of the stroke generation module. (c) Structural strokes predicted at each step; movement-only steps are illustrated by blue arrows, during which no stroke is drawn on the canvas.

Overall Introduction

(Or watch on Bilibili)
👇

Results
Our framework is applicable to a diversity of image types, such as clean line drawing images, rough sketches and photographs.

Vectorization


Rough sketch simplification


Photograph to line drawing


More Results

(Or watch on Bilibili)
👇


Presentations

3-5 minute presentation

(Or watch on Bilibili)
👇


BibTeX
@article{mo2021virtualsketching,
    title   = {General Virtual Sketching Framework for Vector Line Art},
    author  = {Mo, Haoran and Simo-Serra, Edgar and Gao, Chengying and Zou, Changqing and Wang, Ruomei},
    journal = {ACM Transactions on Graphics (Proceedings of ACM SIGGRAPH 2021)},
    year    = {2021},
    volume  = {40},
    number  = {4},
    pages   = {51:1--51:14}
}

Related Work
Jean-Dominique Favreau, Florent Lafarge and Adrien Bousseau. Fidelity vs. Simplicity: a Global Approach to Line Drawing Vectorization. SIGGRAPH 2016. [Paper] [Webpage]

Mikhail Bessmeltsev and Justin Solomon. Vectorization of Line Drawings via PolyVector Fields. SIGGRAPH 2019. [Paper] [Code]

Edgar Simo-Serra, Satoshi Iizuka and Hiroshi Ishikawa. Mastering Sketching: Adversarial Augmentation for Structured Prediction. SIGGRAPH 2018. [Paper] [Webpage] [Code]

Zhewei Huang, Wen Heng and Shuchang Zhou. Learning to Paint With Model-based Deep Reinforcement Learning. ICCV 2019. [Paper] [Code]

================================================ FILE: hyper_parameters.py ================================================ import tensorflow as tf ############################################# # Common parameters ############################################# FLAGS = tf.app.flags.FLAGS tf.app.flags.DEFINE_string( 'dataset_dir', 'datasets', 'The directory of sketch data of the dataset.') tf.app.flags.DEFINE_string( 'log_root', 'outputs/log', 'Directory to store tensorboard.') tf.app.flags.DEFINE_string( 'log_img_root', 'outputs/log_img', 'Directory to store intermediate output images.') tf.app.flags.DEFINE_string( 'snapshot_root', 'outputs/snapshot', 'Directory to store model checkpoints.') tf.app.flags.DEFINE_string( 'neural_renderer_path', 'outputs/snapshot/pretrain_neural_renderer/renderer_300000.tfmodel', 'Path to the neural renderer model.') tf.app.flags.DEFINE_string( 'perceptual_model_root', 'outputs/snapshot/pretrain_perceptual_model', 'Directory to store perceptual model.') tf.app.flags.DEFINE_string( 'data', '', 'The dataset type.') def get_default_hparams_clean(): """Return default HParams for sketch-rnn.""" hparams = tf.contrib.training.HParams( program_name='new_train_clean_line_drawings', data_set='clean_line_drawings', # Our dataset. input_channel=1, num_steps=75040, # Total number of steps of training. 
save_every=75000, eval_every=5000, max_seq_len=48, batch_size=20, gpus=[0, 1], loop_per_gpu=1, sn_loss_type='increasing', # ['decreasing', 'fixed', 'increasing'] stroke_num_loss_weight=0.02, stroke_num_loss_weight_end=0.0, increase_start_steps=25000, decrease_stop_steps=40000, perc_loss_layers=['ReLU1_2', 'ReLU2_2', 'ReLU3_3', 'ReLU5_1'], perc_loss_fuse_type='add', # ['max', 'add', 'raw_add', 'weighted_sum'] init_cursor_on_undrawn_pixel=False, early_pen_loss_type='move', # ['head', 'tail', 'move'] early_pen_loss_weight=0.1, early_pen_length=7, min_width=0.01, min_window_size=32, max_scaling=2.0, encode_cursor_type='value', image_size_small=128, image_size_large=278, cropping_type='v3', # ['v2', 'v3'] pasting_type='v3', # ['v2', 'v3'] pasting_diff=True, concat_win_size=True, encoder_type='conv13_c3', # ['conv10', 'conv10_deep', 'conv13', 'conv10_c3', 'conv10_deep_c3', 'conv13_c3'] # ['conv13_c3_attn'] # ['combine33', 'combine43', 'combine53', 'combineFC'] vary_thickness=False, outside_loss_weight=10.0, win_size_outside_loss_weight=10.0, resize_method='AREA', # ['BILINEAR', 'NEAREST_NEIGHBOR', 'BICUBIC', 'AREA'] concat_cursor=True, use_softargmax=True, soft_beta=10, # value for the soft argmax raster_loss_weight=1.0, dec_rnn_size=256, # Size of decoder. dec_model='hyper', # Decoder: lstm, layer_norm or hyper. # z_size=128, # Size of latent vector z. Recommend 32, 64 or 128. bin_gt=True, stop_accu_grad=True, random_cursor=True, cursor_type='next', raster_size=128, pix_drop_kp=1.0, # Dropout keep rate add_coordconv=True, position_format='abs', raster_loss_base_type='perceptual', # [l1, mse, perceptual] grad_clip=1.0, # Gradient clipping. Recommend leaving at 1.0. learning_rate=0.0001, # Learning rate. decay_rate=0.9999, # Learning rate decay per minibatch. decay_power=0.9, min_learning_rate=0.000001, # Minimum learning rate. use_recurrent_dropout=True, # Dropout with memory loss. Recommended recurrent_dropout_prob=0.90, # Probability of recurrent dropout keep. 
use_input_dropout=False, # Input dropout. Recommend leaving False. input_dropout_prob=0.90, # Probability of input dropout keep. use_output_dropout=False, # Output dropout. Recommend leaving False. output_dropout_prob=0.90, # Probability of output dropout keep. model_mode='train' # ['train', 'eval', 'sample'] ) return hparams def get_default_hparams_rough(): """Return default HParams for sketch-rnn.""" hparams = tf.contrib.training.HParams( program_name='new_train_rough_sketches', data_set='rough_sketches', # ['rough_sketches', 'faces'] input_channel=3, num_steps=90040, # Total number of steps of training. save_every=90000, eval_every=5000, max_seq_len=48, batch_size=20, gpus=[0, 1], loop_per_gpu=1, sn_loss_type='increasing', # ['decreasing', 'fixed', 'increasing'] stroke_num_loss_weight=0.1, stroke_num_loss_weight_end=0.0, increase_start_steps=25000, decrease_stop_steps=40000, photo_prob_type='one', # ['increasing', 'zero', 'one'] photo_prob_start_step=35000, perc_loss_layers=['ReLU2_2', 'ReLU3_3', 'ReLU5_1'], perc_loss_fuse_type='add', # ['max', 'add', 'raw_add', 'weighted_sum'] early_pen_loss_type='move', # ['head', 'tail', 'move'] early_pen_loss_weight=0.2, early_pen_length=7, min_width=0.01, min_window_size=32, max_scaling=2.0, encode_cursor_type='value', image_size_small=128, image_size_large=278, cropping_type='v3', # ['v2', 'v3'] pasting_type='v3', # ['v2', 'v3'] pasting_diff=True, concat_win_size=True, encoder_type='conv13_c3', # ['conv10', 'conv10_deep', 'conv13', 'conv10_c3', 'conv10_deep_c3', 'conv13_c3'] # ['conv13_c3_attn'] # ['combine33', 'combine43', 'combine53', 'combineFC'] outside_loss_weight=10.0, win_size_outside_loss_weight=10.0, resize_method='AREA', # ['BILINEAR', 'NEAREST_NEIGHBOR', 'BICUBIC', 'AREA'] concat_cursor=True, use_softargmax=True, soft_beta=10, # value for the soft argmax raster_loss_weight=1.0, dec_rnn_size=256, # Size of decoder. dec_model='hyper', # Decoder: lstm, layer_norm or hyper. # z_size=128, # Size of latent vector z. 
Recommend 32, 64 or 128. bin_gt=True, stop_accu_grad=True, random_cursor=True, cursor_type='next', raster_size=128, pix_drop_kp=1.0, # Dropout keep rate add_coordconv=True, position_format='abs', raster_loss_base_type='perceptual', # [l1, mse, perceptual] grad_clip=1.0, # Gradient clipping. Recommend leaving at 1.0. learning_rate=0.0001, # Learning rate. decay_rate=0.9999, # Learning rate decay per minibatch. decay_power=0.9, min_learning_rate=0.000001, # Minimum learning rate. use_recurrent_dropout=True, # Dropout with memory loss. Recommended recurrent_dropout_prob=0.90, # Probability of recurrent dropout keep. use_input_dropout=False, # Input dropout. Recommend leaving False. input_dropout_prob=0.90, # Probability of input dropout keep. use_output_dropout=False, # Output dropout. Recommend leaving False. output_dropout_prob=0.90, # Probability of output dropout keep. model_mode='train' # ['train', 'eval', 'sample'] ) return hparams def get_default_hparams_normal(): """Return default HParams for sketch-rnn.""" hparams = tf.contrib.training.HParams( program_name='new_train_faces', data_set='faces', # ['rough_sketches', 'faces'] input_channel=3, num_steps=90040, # Total number of steps of training. 
save_every=90000, eval_every=5000, max_seq_len=48, batch_size=20, gpus=[0, 1], loop_per_gpu=1, sn_loss_type='fixed', # ['decreasing', 'fixed', 'increasing'] stroke_num_loss_weight=0.0, stroke_num_loss_weight_end=0.0, increase_start_steps=0, decrease_stop_steps=40000, photo_prob_type='interpolate', # ['increasing', 'zero', 'one', 'interpolate'] photo_prob_start_step=30000, photo_prob_end_step=60000, perc_loss_layers=['ReLU2_2', 'ReLU3_3', 'ReLU4_2', 'ReLU5_1'], perc_loss_fuse_type='add', # ['max', 'add', 'raw_add', 'weighted_sum'] early_pen_loss_type='move', # ['head', 'tail', 'move'] early_pen_loss_weight=0.2, early_pen_length=7, min_width=0.01, min_window_size=32, max_scaling=2.0, encode_cursor_type='value', image_size_small=128, image_size_large=256, cropping_type='v3', # ['v2', 'v3'] pasting_type='v3', # ['v2', 'v3'] pasting_diff=True, concat_win_size=True, encoder_type='conv13_c3', # ['conv10', 'conv10_deep', 'conv13', 'conv10_c3', 'conv10_deep_c3', 'conv13_c3'] # ['conv13_c3_attn'] # ['combine33', 'combine43', 'combine53', 'combineFC'] outside_loss_weight=10.0, win_size_outside_loss_weight=10.0, resize_method='AREA', # ['BILINEAR', 'NEAREST_NEIGHBOR', 'BICUBIC', 'AREA'] concat_cursor=True, use_softargmax=True, soft_beta=10, # value for the soft argmax raster_loss_weight=1.0, dec_rnn_size=256, # Size of decoder. dec_model='hyper', # Decoder: lstm, layer_norm or hyper. # z_size=128, # Size of latent vector z. Recommend 32, 64 or 128. bin_gt=True, stop_accu_grad=True, random_cursor=True, cursor_type='next', raster_size=128, pix_drop_kp=1.0, # Dropout keep rate add_coordconv=True, position_format='abs', raster_loss_base_type='perceptual', # [l1, mse, perceptual] grad_clip=1.0, # Gradient clipping. Recommend leaving at 1.0. learning_rate=0.0001, # Learning rate. decay_rate=0.9999, # Learning rate decay per minibatch. decay_power=0.9, min_learning_rate=0.000001, # Minimum learning rate. use_recurrent_dropout=True, # Dropout with memory loss. 
Recommended recurrent_dropout_prob=0.90, # Probability of recurrent dropout keep. use_input_dropout=False, # Input dropout. Recommend leaving False. input_dropout_prob=0.90, # Probability of input dropout keep. use_output_dropout=False, # Output dropout. Recommend leaving False. output_dropout_prob=0.90, # Probability of output dropout keep. model_mode='train' # ['train', 'eval', 'sample'] ) return hparams ================================================ FILE: launch_gui.bat ================================================ @echo OFF REM === Cesta k instalaci Anacondy === set "CONDAPATH=C:\ProgramData\anaconda3" REM === Název a cesta k prostředí === set "ENVNAME=virtual_sketching" set "ENVPATH=%USERPROFILE%\.conda\envs\%ENVNAME%" REM === Aktivace prostředí === call "%CONDAPATH%\Scripts\activate.bat" "%ENVPATH%" REM === Spuštění GUI === python virtual_sketch_gui.py REM === Pozastavení po ukončení === echo. pause REM === Deaktivace prostředí === call conda deactivate ================================================ FILE: model_common_test.py ================================================ import rnn import tensorflow as tf from subnet_tf_utils import generative_cnn_encoder, generative_cnn_encoder_deeper, generative_cnn_encoder_deeper13, \ generative_cnn_c3_encoder, generative_cnn_c3_encoder_deeper, generative_cnn_c3_encoder_deeper13, \ generative_cnn_c3_encoder_combine33, generative_cnn_c3_encoder_combine43, \ generative_cnn_c3_encoder_combine53, generative_cnn_c3_encoder_combineFC, \ generative_cnn_c3_encoder_deeper13_attn class DiffPastingV3(object): def __init__(self, raster_size): self.patch_canvas = tf.placeholder(dtype=tf.float32, shape=(None, None, 1)) # (raster_size, raster_size, 1), [0.0-BG, 1.0-stroke] self.cursor_pos_a = tf.placeholder(dtype=tf.float32, shape=(2)) # (2), float32, in large size self.image_size_a = tf.placeholder(dtype=tf.int32, shape=()) # () self.window_size_a = tf.placeholder(dtype=tf.float32, shape=()) # (), float32, with grad 
    def image_pasting_sampling_v3(self):
        """Paste the raster patch back onto a blank full-size canvas.

        Resamples ``self.patch_canvas`` (raster_size x raster_size) onto the
        integer pixel grid covered by the floating-point window
        ``[cursor_pos_a - window_size_a/2, cursor_pos_a + window_size_a/2]``
        using ``tf.image.crop_and_resize``, then zero-pads and crops back to
        (image_size_a, image_size_a, 1), [0.0-BG, 1.0-stroke]. The box
        coordinates depend on ``window_size_a``, so gradients flow to the
        window size. NOTE(review): assumes the floored/ceiled window stays
        within ``padding_size`` of the image bounds — confirm with callers.
        """
        padding_size = tf.cast(tf.ceil(self.window_size_a / 2.0), tf.int32)

        # Window corners in the original (large) image coordinate frame.
        x1y1_a = self.cursor_pos_a - self.window_size_a / 2.0  # (2), float32
        x2y2_a = self.cursor_pos_a + self.window_size_a / 2.0  # (2), float32

        x1y1_a_floor = tf.floor(x1y1_a)  # (2)
        x2y2_a_ceil = tf.ceil(x2y2_a)  # (2)

        # Re-express the snapped (integer-aligned) window in patch coordinates.
        cursor_pos_b_oricoord = (x1y1_a_floor + x2y2_a_ceil) / 2.0  # (2)
        cursor_pos_b = (cursor_pos_b_oricoord - x1y1_a) / self.window_size_a * self.raster_size_a  # (2)
        raster_size_b = (x2y2_a_ceil - x1y1_a_floor)  # (x, y)
        image_size_b = self.raster_size_a
        window_size_b = self.raster_size_a * (raster_size_b / self.window_size_a)  # (x, y)

        cursor_b_x, cursor_b_y = tf.split(cursor_pos_b, 2, axis=-1)  # (1)

        # crop_and_resize expects normalized [y1, x1, y2, x2] boxes.
        y1_b = cursor_b_y - (window_size_b[1] - 1.) / 2.
        x1_b = cursor_b_x - (window_size_b[0] - 1.) / 2.
        y2_b = y1_b + (window_size_b[1] - 1.)
        x2_b = x1_b + (window_size_b[0] - 1.)
        boxes_b = tf.concat([y1_b, x1_b, y2_b, x2_b], axis=-1)  # (4)
        boxes_b = boxes_b / tf.cast(image_size_b - 1, tf.float32)  # with grad to window_size_a

        box_ind_b = tf.ones((1), dtype=tf.int32)  # (1)
        box_ind_b = tf.cumsum(box_ind_b) - 1

        patch_canvas = tf.expand_dims(self.patch_canvas, axis=0)  # (1, raster_size, raster_size, 1), [0.0-BG, 1.0-stroke]
        boxes_b = tf.expand_dims(boxes_b, axis=0)  # (1, 4)

        valid_canvas = tf.image.crop_and_resize(patch_canvas, boxes_b, box_ind_b,
                                                crop_size=[raster_size_b[1], raster_size_b[0]])
        valid_canvas = valid_canvas[0]  # (raster_size_b, raster_size_b, 1)

        # Zero-pad so the resampled window lands at its absolute position,
        # then crop the padded result back to the full image extent.
        pad_up = tf.cast(x1y1_a_floor[1], tf.int32) + padding_size
        pad_down = self.image_size_a + padding_size - tf.cast(x2y2_a_ceil[1], tf.int32)
        pad_left = tf.cast(x1y1_a_floor[0], tf.int32) + padding_size
        pad_right = self.image_size_a + padding_size - tf.cast(x2y2_a_ceil[0], tf.int32)
        paddings = [[pad_up, pad_down], [pad_left, pad_right], [0, 0]]
        pad_img = tf.pad(valid_canvas, paddings=paddings, mode='CONSTANT',
                         constant_values=0.0)  # (H_p, W_p, 1), [0.0-BG, 1.0-stroke]

        pasted_image = pad_img[padding_size: padding_size + self.image_size_a,
                               padding_size: padding_size + self.image_size_a, :]
        # (image_size, image_size, 1), [0.0-BG, 1.0-stroke]
        return pasted_image
""" self.hps = hps assert hps.model_mode in ['train', 'eval', 'eval_sample', 'sample'] # with tf.variable_scope('SCC', reuse=reuse): if not gpu_mode: with tf.device('/cpu:0'): print('Model using cpu.') self.build_model() else: print('-' * 100) print('model_mode:', hps.model_mode) print('Model using gpu.') self.build_model() def build_model(self): """Define model architecture.""" self.config_model() initial_state = self.get_decoder_inputs() self.initial_state = initial_state ## use pred as the prev points other_params, pen_ras, final_state = self.get_points_and_raster_image(self.image_size) # other_params: (N * max_seq_len, 6) # pen_ras: (N * max_seq_len, 2), after softmax self.other_params = other_params # (N * max_seq_len, 6) self.pen_ras = pen_ras # (N * max_seq_len, 2), after softmax self.final_state = final_state if not self.hps.use_softargmax: pen_state_soft = pen_ras[:, 1:2] # (N * max_seq_len, 1) else: pen_state_soft = self.differentiable_argmax(pen_ras, self.hps.soft_beta) # (N * max_seq_len, 1) pred_params = tf.concat([pen_state_soft, other_params], axis=1) # (N * max_seq_len, 7) self.pred_params = tf.reshape(pred_params, shape=[-1, self.hps.max_seq_len, 7]) # (N, max_seq_len, 7) # pred_params: (N, max_seq_len, 7) def config_model(self): if self.hps.model_mode == 'train': self.global_step = tf.Variable(0, name='global_step', trainable=False) if self.hps.dec_model == 'lstm': dec_cell_fn = rnn.LSTMCell elif self.hps.dec_model == 'layer_norm': dec_cell_fn = rnn.LayerNormLSTMCell elif self.hps.dec_model == 'hyper': dec_cell_fn = rnn.HyperLSTMCell else: assert False, 'please choose a respectable cell' use_recurrent_dropout = self.hps.use_recurrent_dropout use_input_dropout = self.hps.use_input_dropout use_output_dropout = self.hps.use_output_dropout dec_cell = dec_cell_fn( self.hps.dec_rnn_size, use_recurrent_dropout=use_recurrent_dropout, dropout_keep_prob=self.hps.recurrent_dropout_prob) # dropout: # print('Input dropout mode = %s.' 
% use_input_dropout) # print('Output dropout mode = %s.' % use_output_dropout) # print('Recurrent dropout mode = %s.' % use_recurrent_dropout) if use_input_dropout: print('Dropout to input w/ keep_prob = %4.4f.' % self.hps.input_dropout_prob) dec_cell = tf.contrib.rnn.DropoutWrapper( dec_cell, input_keep_prob=self.hps.input_dropout_prob) if use_output_dropout: print('Dropout to output w/ keep_prob = %4.4f.' % self.hps.output_dropout_prob) dec_cell = tf.contrib.rnn.DropoutWrapper( dec_cell, output_keep_prob=self.hps.output_dropout_prob) self.dec_cell = dec_cell self.input_photo = tf.placeholder(dtype=tf.float32, shape=[self.hps.batch_size, None, None, self.hps.input_channel]) # [0.0-stroke, 1.0-BG] self.init_cursor = tf.placeholder( dtype=tf.float32, shape=[self.hps.batch_size, 1, 2]) # (N, 1, 2), in size [0.0, 1.0) self.init_width = tf.placeholder( dtype=tf.float32, shape=[self.hps.batch_size]) # (1), in [0.0, 1.0] self.init_scaling = tf.placeholder( dtype=tf.float32, shape=[self.hps.batch_size]) # (N), in [0.0, 1.0] self.init_window_size = tf.placeholder( dtype=tf.float32, shape=[self.hps.batch_size]) # (N) self.image_size = tf.placeholder(dtype=tf.int32, shape=()) # () ########################### def normalize_image_m1to1(self, in_img_0to1): norm_img_m1to1 = tf.multiply(in_img_0to1, 2.0) norm_img_m1to1 = tf.subtract(norm_img_m1to1, 1.0) return norm_img_m1to1 def add_coords(self, input_tensor): batch_size_tensor = tf.shape(input_tensor)[0] # get N size xx_ones = tf.ones([batch_size_tensor, self.hps.raster_size], dtype=tf.int32) # e.g. (N, raster_size) xx_ones = tf.expand_dims(xx_ones, -1) # e.g. (N, raster_size, 1) xx_range = tf.tile(tf.expand_dims(tf.range(self.hps.raster_size), 0), [batch_size_tensor, 1]) # e.g. (N, raster_size) xx_range = tf.expand_dims(xx_range, 1) # e.g. (N, 1, raster_size) xx_channel = tf.matmul(xx_ones, xx_range) # e.g. (N, raster_size, raster_size) xx_channel = tf.expand_dims(xx_channel, -1) # e.g. 
        # --- tail of add_coords(); the method's opening lines precede this chunk ---
        # e.g. (N, raster_size, raster_size, 1)
        yy_ones = tf.ones([batch_size_tensor, self.hps.raster_size], dtype=tf.int32)  # e.g. (N, raster_size)
        yy_ones = tf.expand_dims(yy_ones, 1)  # e.g. (N, 1, raster_size)
        yy_range = tf.tile(tf.expand_dims(tf.range(self.hps.raster_size), 0),
                           [batch_size_tensor, 1])  # (N, raster_size)
        yy_range = tf.expand_dims(yy_range, -1)  # e.g. (N, raster_size, 1)
        yy_channel = tf.matmul(yy_range, yy_ones)  # e.g. (N, raster_size, raster_size)
        yy_channel = tf.expand_dims(yy_channel, -1)  # e.g. (N, raster_size, raster_size, 1)

        # Normalize both coordinate channels into [0.0, 1.0].
        xx_channel = tf.cast(xx_channel, 'float32') / (self.hps.raster_size - 1)
        yy_channel = tf.cast(yy_channel, 'float32') / (self.hps.raster_size - 1)
        # xx_channel = xx_channel * 2 - 1  # [-1, 1]
        # yy_channel = yy_channel * 2 - 1

        ret = tf.concat([
            input_tensor,
            xx_channel,
            yy_channel,
        ], axis=-1)  # e.g. (N, raster_size, raster_size, 4)
        return ret

    def build_combined_encoder(self, patch_canvas, patch_photo, entire_canvas, entire_photo, cursor_pos,
                               image_size, window_size):
        """
        Encode local (patch) and global (whole-image) observations into one embedding.

        :param patch_canvas: (N, raster_size, raster_size, 1), [-1.0-stroke, 1.0-BG]
        :param patch_photo: (N, raster_size, raster_size, 1/3), [-1.0-stroke, 1.0-BG]
        :param entire_canvas: (N, image_size, image_size, 1), [0.0-stroke, 1.0-BG]
        :param entire_photo: (N, image_size, image_size, 1/3), [0.0-stroke, 1.0-BG]
        :param cursor_pos: (N, 1, 2), in size [0.0, 1.0)
        :param window_size: (N, 1, 1), float, in large size
        :return: image_embedding — (N, 128), (N, 256) or (N, 512) depending on hps.encoder_type
        """
        # Map the hyper-parameter string onto a tf resize method; fail fast on unknown values.
        if self.hps.resize_method == 'BILINEAR':
            resize_method = tf.image.ResizeMethod.BILINEAR
        elif self.hps.resize_method == 'NEAREST_NEIGHBOR':
            resize_method = tf.image.ResizeMethod.NEAREST_NEIGHBOR
        elif self.hps.resize_method == 'BICUBIC':
            resize_method = tf.image.ResizeMethod.BICUBIC
        elif self.hps.resize_method == 'AREA':
            resize_method = tf.image.ResizeMethod.AREA
        else:
            raise Exception('unknown resize_method', self.hps.resize_method)

        # The encoder treats all of its observations as constants: no gradients
        # flow back into the inputs (they come from earlier decoding steps).
        patch_photo = tf.stop_gradient(patch_photo)
        patch_canvas = tf.stop_gradient(patch_canvas)
        cursor_pos = tf.stop_gradient(cursor_pos)
        window_size = tf.stop_gradient(window_size)

        # Down-sample the full photo/canvas to raster_size for the global branch.
        entire_photo_small = tf.stop_gradient(
            tf.image.resize_images(entire_photo, (self.hps.raster_size, self.hps.raster_size),
                                   method=resize_method))
        entire_canvas_small = tf.stop_gradient(
            tf.image.resize_images(entire_canvas, (self.hps.raster_size, self.hps.raster_size),
                                   method=resize_method))
        entire_photo_small = self.normalize_image_m1to1(entire_photo_small)  # [-1.0-stroke, 1.0-BG]
        entire_canvas_small = self.normalize_image_m1to1(entire_canvas_small)  # [-1.0-stroke, 1.0-BG]

        # Broadcast the cursor position to a per-pixel two-channel map.
        if self.hps.encode_cursor_type == 'value':
            cursor_pos_norm = tf.expand_dims(cursor_pos, axis=1)  # (N, 1, 1, 2)
            cursor_pos_norm = tf.tile(cursor_pos_norm, [1, self.hps.raster_size, self.hps.raster_size, 1])
            cursor_info = cursor_pos_norm
        else:
            raise Exception('Unknown encode_cursor_type', self.hps.encode_cursor_type)

        batch_input_combined = tf.concat([patch_photo, patch_canvas, entire_photo_small, entire_canvas_small,
                                          cursor_info], axis=-1)  # [N, raster_size, raster_size, 6/10]
        batch_input_local = tf.concat([patch_photo, patch_canvas],
                                      axis=-1)  # [N, raster_size, raster_size, 2/4]
        batch_input_global = tf.concat([entire_photo_small, entire_canvas_small, cursor_info],
                                       axis=-1)  # [N, raster_size, raster_size, 4/6]

        if self.hps.model_mode == 'train':
            is_training = True
            dropout_keep_prob = self.hps.pix_drop_kp
        else:
            is_training = False
            dropout_keep_prob = 1.0

        # Optional CoordConv: append normalized x/y coordinate channels.
        if self.hps.add_coordconv:
            batch_input_combined = self.add_coords(batch_input_combined)  # (N, in_H, in_W, in_dim + 2)
            batch_input_local = self.add_coords(batch_input_local)  # (N, in_H, in_W, in_dim + 2)
            batch_input_global = self.add_coords(batch_input_global)  # (N, in_H, in_W, in_dim + 2)

        # Two encoder families: two-branch 'combine*' encoders take the local and
        # global stacks separately; the single-stream encoders take the combined stack.
        if 'combine' in self.hps.encoder_type:
            if self.hps.encoder_type == 'combine33':
                image_embedding, _ = generative_cnn_c3_encoder_combine33(
                    batch_input_local, batch_input_global, is_training, dropout_keep_prob)  # (N, 128)
            elif self.hps.encoder_type == 'combine43':
                image_embedding, _ = generative_cnn_c3_encoder_combine43(
                    batch_input_local, batch_input_global, is_training, dropout_keep_prob)  # (N, 128)
            elif self.hps.encoder_type == 'combine53':
                image_embedding, _ = generative_cnn_c3_encoder_combine53(
                    batch_input_local, batch_input_global, is_training, dropout_keep_prob)  # (N, 128)
            elif self.hps.encoder_type == 'combineFC':
                image_embedding, _ = generative_cnn_c3_encoder_combineFC(
                    batch_input_local, batch_input_global, is_training, dropout_keep_prob)  # (N, 256)
            else:
                raise Exception('Unknown encoder_type', self.hps.encoder_type)
        else:
            with tf.variable_scope('Combined_Encoder', reuse=tf.AUTO_REUSE):
                if self.hps.encoder_type == 'conv10':
                    image_embedding, _ = generative_cnn_encoder(
                        batch_input_combined, is_training, dropout_keep_prob)  # (N, 128)
                elif self.hps.encoder_type == 'conv10_deep':
                    image_embedding, _ = generative_cnn_encoder_deeper(
                        batch_input_combined, is_training, dropout_keep_prob)  # (N, 512)
                elif self.hps.encoder_type == 'conv13':
                    image_embedding, _ = generative_cnn_encoder_deeper13(
                        batch_input_combined, is_training, dropout_keep_prob)  # (N, 128)
                elif self.hps.encoder_type == 'conv10_c3':
                    image_embedding, _ = generative_cnn_c3_encoder(
                        batch_input_combined, is_training, dropout_keep_prob)  # (N, 128)
                elif self.hps.encoder_type == 'conv10_deep_c3':
                    image_embedding, _ = generative_cnn_c3_encoder_deeper(
                        batch_input_combined, is_training, dropout_keep_prob)  # (N, 512)
                elif self.hps.encoder_type == 'conv13_c3':
                    image_embedding, _ = generative_cnn_c3_encoder_deeper13(
                        batch_input_combined, is_training, dropout_keep_prob)  # (N, 128)
                elif self.hps.encoder_type == 'conv13_c3_attn':
                    image_embedding, _ = generative_cnn_c3_encoder_deeper13_attn(
                        batch_input_combined, is_training, dropout_keep_prob)  # (N, 128)
                else:
                    raise Exception('Unknown encoder_type', self.hps.encoder_type)
        return image_embedding

    def build_seq_decoder(self, dec_cell, actual_input_x, initial_state):
        # Runs the decoder RNN one step-window and projects its output to
        # pen logits + stroke parameters (body continues in the next chunk).
        rnn_output, last_state = self.rnn_decoder(dec_cell, initial_state, actual_input_x)
        # --- continuation of build_seq_decoder(); its signature is in the previous chunk ---
        rnn_output_flat = tf.reshape(rnn_output, [-1, self.hps.dec_rnn_size])

        pen_n_out = 2
        params_n_out = 6

        # Two independent linear heads on top of the RNN output:
        # one for the binary pen state, one for the 6 stroke parameters.
        with tf.variable_scope('DEC_RNN_out_pen', reuse=tf.AUTO_REUSE):
            output_w_pen = tf.get_variable('output_w', [self.hps.dec_rnn_size, pen_n_out])
            output_b_pen = tf.get_variable('output_b', [pen_n_out], initializer=tf.constant_initializer(0.0))
            output_pen = tf.nn.xw_plus_b(rnn_output_flat, output_w_pen, output_b_pen)  # (N, pen_n_out)

        with tf.variable_scope('DEC_RNN_out_params', reuse=tf.AUTO_REUSE):
            output_w_params = tf.get_variable('output_w', [self.hps.dec_rnn_size, params_n_out])
            output_b_params = tf.get_variable('output_b', [params_n_out], initializer=tf.constant_initializer(0.0))
            output_params = tf.nn.xw_plus_b(rnn_output_flat, output_w_params, output_b_params)  # (N, params_n_out)

        output = tf.concat([output_pen, output_params], axis=1)  # (N, n_out)
        return output, last_state

    def get_mixture_coef(self, outputs):
        """Split raw decoder head outputs into softmaxed pen states and squashed stroke params.

        :param outputs: (N, 8) raw outputs — 2 pen logits followed by 6 parameter logits.
        :return: [z_other_params (N, 6), z_pen (N, 2)]
        """
        z = outputs
        z_pen_logits = z[:, 0:2]  # (N, 2), pen states
        z_other_params_logits = z[:, 2:]  # (N, 6)

        z_pen = tf.nn.softmax(z_pen_logits)  # (N, 2)

        if self.hps.position_format == 'abs':
            x1y1 = tf.nn.sigmoid(z_other_params_logits[:, 0:2])  # (N, 2)
            x2y2 = tf.tanh(z_other_params_logits[:, 2:4])  # (N, 2)
            widths = tf.nn.sigmoid(z_other_params_logits[:, 4:5])  # (N, 1)
            # Affinely map the sigmoid output into [min_width, 1.0].
            widths = tf.add(tf.multiply(widths, 1.0 - self.hps.min_width), self.hps.min_width)
            scaling = tf.nn.sigmoid(z_other_params_logits[:, 5:6]) * self.hps.max_scaling
            # (N, 1), [0.0, max_scaling]
            # scaling = tf.add(tf.multiply(scaling, (self.hps.max_scaling - self.hps.min_scaling) / self.hps.max_scaling),
            #                  self.hps.min_scaling)
            z_other_params = tf.concat([x1y1, x2y2, widths, scaling], axis=-1)  # (N, 6)
        else:  # "rel"
            raise Exception('Unknown position_format', self.hps.position_format)

        r = [z_other_params, z_pen]
        return r

    ###########################

    def get_decoder_inputs(self):
        """Return the decoder cell's zero initial state for a full batch."""
        initial_state = self.dec_cell.zero_state(batch_size=self.hps.batch_size, dtype=tf.float32)
        return initial_state

    def rnn_decoder(self, dec_cell, initial_state, actual_input_x):
        """Unroll dec_cell over actual_input_x (batch-major) starting from initial_state."""
        with tf.variable_scope("RNN_DEC", reuse=tf.AUTO_REUSE):
            output, last_state = tf.nn.dynamic_rnn(
                dec_cell,
                actual_input_x,
                initial_state=initial_state,
                time_major=False,
                swap_memory=True,
                dtype=tf.float32)
        return output, last_state

    ###########################

    def image_padding(self, ori_image, window_size, pad_value):
        """
        Pad with (bg)
        :param ori_image: (N, H, W, k) image batch to pad on both spatial sides.
        :return: padded batch, (N, H_p, W_p, k)
        """
        # Half a window on every spatial side so that any cursor-centred crop stays in bounds.
        paddings = [[0, 0],
                    [window_size // 2, window_size // 2],
                    [window_size // 2, window_size // 2],
                    [0, 0]]
        pad_img = tf.pad(ori_image, paddings=paddings, mode='CONSTANT',
                         constant_values=pad_value)  # (N, H_p, W_p, k)
        return pad_img

    def image_cropping_fn(self, fn_inputs):
        """
        crop the patch
        Per-example body for tf.map_fn: the crop parameters are packed into the
        trailing channels of fn_inputs (see image_cropping below).
        :return: (raster_size, raster_size, 2/4) patch, [0.0-BG, 1.0-stroke]
        """
        index_offset = self.hps.input_channel - 1
        input_image = fn_inputs[:, :, 0:2 + index_offset]  # (image_size, image_size, -), [0.0-BG, 1.0-stroke]
        cursor_pos = fn_inputs[0, 0, 2 + index_offset:4 + index_offset]  # (2), in [0.0, 1.0)
        image_size = fn_inputs[0, 0, 4 + index_offset]  # (), float32
        window_size = tf.cast(fn_inputs[0, 0, 5 + index_offset], tf.int32)  # ()

        input_img_reshape = tf.expand_dims(input_image, axis=0)
        pad_img = self.image_padding(input_img_reshape, window_size, pad_value=0.0)

        # Convert the normalized cursor to integer pixel coordinates; thanks to the
        # half-window padding, the [y0:y1, x0:x1] slice is always window_size wide.
        cursor_pos = tf.cast(tf.round(tf.multiply(cursor_pos, image_size)), dtype=tf.int32)
        x0, x1 = cursor_pos[0], cursor_pos[0] + window_size  # ()
        y0, y1 = cursor_pos[1], cursor_pos[1] + window_size  # ()
        patch_image = pad_img[:, y0:y1, x0:x1, :]  # (1, window_size, window_size, 2/4)

        # resize to raster_size
        patch_image_scaled = tf.image.resize_images(patch_image,
                                                    (self.hps.raster_size, self.hps.raster_size),
                                                    method=tf.image.ResizeMethod.AREA)
        patch_image_scaled = tf.squeeze(patch_image_scaled, axis=0)
        # patch_canvas_scaled: (raster_size, raster_size, 2/4), [0.0-BG, 1.0-stroke]
        return patch_image_scaled

    def image_cropping(self, cursor_position, input_img, image_size, window_sizes):
        """
        :param cursor_position: (N, 1, 2), float type, in size [0.0, 1.0)
        :param input_img: (N, image_size,
        image_size, 2/4), [0.0-BG, 1.0-stroke]
        :param window_sizes: (N, 1, 1), float32, with grad
        """
        input_img_ = input_img
        window_sizes_non_grad = tf.stop_gradient(tf.round(window_sizes))  # (N, 1, 1), no grad

        cursor_position_ = tf.reshape(cursor_position, (-1, 1, 1, 2))  # (N, 1, 1, 2)
        cursor_position_ = tf.tile(cursor_position_,
                                   [1, image_size, image_size, 1])  # (N, image_size, image_size, 2)

        image_size_ = tf.reshape(tf.cast(image_size, tf.float32), (1, 1, 1, 1))  # (1, 1, 1, 1)
        image_size_ = tf.tile(image_size_, [self.hps.batch_size, image_size, image_size, 1])

        window_sizes_ = tf.reshape(window_sizes_non_grad, (-1, 1, 1, 1))  # (N, 1, 1, 1)
        window_sizes_ = tf.tile(window_sizes_,
                                [1, image_size, image_size, 1])  # (N, image_size, image_size, 1)

        # Pack image + crop parameters per pixel so a single map_fn call can crop each example.
        fn_inputs = tf.concat([input_img_, cursor_position_, image_size_, window_sizes_],
                              axis=-1)  # (N, image_size, image_size, 2/4 + 4)
        curr_patch_imgs = tf.map_fn(self.image_cropping_fn, fn_inputs,
                                    parallel_iterations=32)  # (N, raster_size, raster_size, -)
        return curr_patch_imgs

    def image_cropping_v3(self, cursor_position, input_img, image_size, window_sizes):
        """
        Crop a cursor-centred window per example via tf.image.crop_and_resize.
        :param cursor_position: (N, 1, 2), float type, in size [0.0, 1.0)
        :param input_img: (N, image_size, image_size, k), [0.0-BG, 1.0-stroke]
        :param window_sizes: (N, 1, 1), float32, with grad
        """
        window_sizes_non_grad = tf.stop_gradient(window_sizes)  # (N, 1, 1), no grad

        cursor_pos = tf.multiply(cursor_position, tf.cast(image_size, tf.float32))
        cursor_x, cursor_y = tf.split(cursor_pos, 2, axis=-1)  # (N, 1, 1)

        # Window corners centred on the cursor, in pixel coordinates.
        y1 = cursor_y - (window_sizes_non_grad - 1.0) / 2
        x1 = cursor_x - (window_sizes_non_grad - 1.0) / 2
        y2 = y1 + (window_sizes_non_grad - 1.0)
        x2 = x1 + (window_sizes_non_grad - 1.0)
        boxes = tf.concat([y1, x1, y2, x2], axis=-1)  # (N, 1, 4)
        boxes = tf.squeeze(boxes, axis=1)  # (N, 4)
        # crop_and_resize expects boxes normalized to [0, 1] over (image_size - 1).
        boxes = boxes / tf.cast(image_size - 1, tf.float32)

        # box_ind = [0, 1, ..., N-1]: crop the i-th box from the i-th image.
        box_ind = tf.ones_like(cursor_x)[:, 0, 0]  # (N)
        box_ind = tf.cast(box_ind, dtype=tf.int32)
        box_ind = tf.cumsum(box_ind) - 1

        curr_patch_imgs = tf.image.crop_and_resize(
            input_img, boxes, box_ind,
            crop_size=[self.hps.raster_size, self.hps.raster_size])
        # (N, raster_size, raster_size, k), [0.0-BG, 1.0-stroke]
        return curr_patch_imgs

    def get_points_and_raster_image(self, image_size):
        ## generate the other_params and pen_ras and raster image for raster loss
        prev_state = self.initial_state  # (N, dec_rnn_size * 3)

        prev_width = self.init_width  # (N)
        prev_width = tf.expand_dims(tf.expand_dims(prev_width, axis=-1), axis=-1)  # (N, 1, 1)

        prev_scaling = self.init_scaling  # (N)
        prev_scaling = tf.reshape(prev_scaling, (-1, 1, 1))  # (N, 1, 1)

        prev_window_size = self.init_window_size  # (N)
        prev_window_size = tf.reshape(prev_window_size, (-1, 1, 1))  # (N, 1, 1)

        cursor_position_temp = self.init_cursor
        self.cursor_position = cursor_position_temp  # (N, 1, 2), in size [0.0, 1.0)
        cursor_position_loop = self.cursor_position

        other_params_list = []
        pen_ras_list = []

        curr_canvas_soft = tf.zeros_like(self.input_photo[:, :, :, 0])
        # (N, image_size, image_size), [0.0-BG, 1.0-stroke]
        curr_canvas_hard = tf.zeros_like(curr_canvas_soft)  # [0.0-BG, 1.0-stroke]

        #### sampling part - start ####
        self.curr_canvas_hard = curr_canvas_hard

        if self.hps.cropping_type == 'v3':
            cropping_func = self.image_cropping_v3
        # elif self.hps.cropping_type == 'v2':
        #     cropping_func = self.image_cropping
        else:
            raise Exception('Unknown cropping_type', self.hps.cropping_type)

        for time_i in range(self.hps.max_seq_len):
            cursor_position_non_grad = tf.stop_gradient(cursor_position_loop)  # (N, 1, 2), in size [0.0, 1.0)

            # Current attention-window size, clamped to [min_window_size, image_size].
            curr_window_size = tf.multiply(prev_scaling, tf.stop_gradient(prev_window_size))  # float, with grad
            curr_window_size = tf.maximum(curr_window_size, tf.cast(self.hps.min_window_size, tf.float32))
            curr_window_size = tf.minimum(curr_window_size, tf.cast(image_size, tf.float32))

            ## patch-level encoding
            # Here, we make the gradients from canvas_z to curr_canvas_hard be None
            # to avoid recurrent gradient propagation.
            curr_canvas_hard_non_grad = tf.stop_gradient(self.curr_canvas_hard)
            curr_canvas_hard_non_grad = tf.expand_dims(curr_canvas_hard_non_grad, axis=-1)

            # input_photo: (N, image_size, image_size, 1/3), [0.0-stroke, 1.0-BG]
            crop_inputs = tf.concat([1.0 - self.input_photo, curr_canvas_hard_non_grad],
                                    axis=-1)  # (N, H_p, W_p, 1+1)

            cropped_outputs = cropping_func(cursor_position_non_grad, crop_inputs, image_size, curr_window_size)
            index_offset = self.hps.input_channel - 1
            curr_patch_inputs = cropped_outputs[:, :, :, 0:1 + index_offset]  # [0.0-BG, 1.0-stroke]
            curr_patch_canvas_hard_non_grad = cropped_outputs[:, :, :, 1 + index_offset:2 + index_offset]
            # (N, raster_size, raster_size, 1/3), [0.0-BG, 1.0-stroke]

            # Flip back to photo convention and normalize to [-1, 1] for the encoder.
            curr_patch_inputs = 1.0 - curr_patch_inputs  # [0.0-stroke, 1.0-BG]
            curr_patch_inputs = self.normalize_image_m1to1(curr_patch_inputs)
            # (N, raster_size, raster_size, 1/3), [-1.0-stroke, 1.0-BG]

            # Normalizing image
            curr_patch_canvas_hard_non_grad = 1.0 - curr_patch_canvas_hard_non_grad  # [0.0-stroke, 1.0-BG]
            curr_patch_canvas_hard_non_grad = self.normalize_image_m1to1(curr_patch_canvas_hard_non_grad)
            # [-1.0-stroke, 1.0-BG]

            ## image-level encoding
            combined_z = self.build_combined_encoder(
                curr_patch_canvas_hard_non_grad,
                curr_patch_inputs,
                1.0 - curr_canvas_hard_non_grad,
                self.input_photo,
                cursor_position_non_grad,
                image_size,
                curr_window_size)  # (N, z_size)
            combined_z = tf.expand_dims(combined_z, axis=1)  # (N, 1, z_size)

            # Window size normalized against its upper (image) and lower (min window) bounds.
            curr_window_size_top_side_norm_non_grad = \
                tf.stop_gradient(curr_window_size / tf.cast(image_size, tf.float32))
            curr_window_size_bottom_side_norm_non_grad = \
                tf.stop_gradient(curr_window_size / tf.cast(self.hps.min_window_size, tf.float32))
            if not self.hps.concat_win_size:
                combined_z = tf.concat([tf.stop_gradient(prev_width), combined_z], 2)  # (N, 1, 2+z_size)
            else:
                combined_z = tf.concat([tf.stop_gradient(prev_width),
                                        curr_window_size_top_side_norm_non_grad,
                                        curr_window_size_bottom_side_norm_non_grad,
                                        combined_z], 2)  # (N, 1, 2+z_size)

            if self.hps.concat_cursor:
                prev_input_x = tf.concat([cursor_position_non_grad, combined_z], 2)  # (N, 1, 2+2+z_size)
            else:
                prev_input_x = combined_z  # (N, 1, 2+z_size)

            h_output, next_state = self.build_seq_decoder(self.dec_cell, prev_input_x, prev_state)
            # h_output: (N * 1, n_out), next_state: (N, dec_rnn_size * 3)
            [o_other_params, o_pen_ras] = self.get_mixture_coef(h_output)
            # o_other_params: (N * 1, 6)
            # o_pen_ras: (N * 1, 2), after softmax

            o_other_params = tf.reshape(o_other_params, [-1, 1, 6])  # (N, 1, 6)
            o_pen_ras_raw = tf.reshape(o_pen_ras, [-1, 1, 2])  # (N, 1, 2)

            other_params_list.append(o_other_params)
            pen_ras_list.append(o_pen_ras_raw)

            #### sampling part - end ####

            prev_state = next_state

        # Stack the per-step outputs over time, then flatten batch and time together.
        other_params_ = tf.reshape(tf.concat(other_params_list, axis=1), [-1, 6])  # (N * max_seq_len, 6)
        pen_ras_ = tf.reshape(tf.concat(pen_ras_list, axis=1), [-1, 2])  # (N * max_seq_len, 2)

        return other_params_, pen_ras_, prev_state

    def differentiable_argmax(self, input_pen, soft_beta):
        """ Differentiable argmax trick.
        :param input_pen: (N, n_class)
        :return: pen_state: (N, 1)
        """
        def sign_onehot(x):
            """
            :param x: (N, n_class)
            :return: (N, n_class)
            """
            y = tf.sign(tf.reduce_max(x, axis=-1, keepdims=True) - x)
            y = (y - 1) * (-1)
            return y

        def softargmax(x, beta=1e2):
            """
            :param x: (N, n_class)
            :param beta: 1e10 is the best. 1e2 is acceptable.
            :return: (N)
            """
            x_range = tf.cumsum(tf.ones_like(x), axis=1)  # (N, 2)
            return tf.reduce_sum(tf.nn.softmax(x * beta) * x_range, axis=1) - 1

        ## Better to use softargmax(beta=1e2). The sign_onehot's gradient is close to zero.
        # pen_onehot = sign_onehot(input_pen)  # one-hot form, (N * max_seq_len, 2)
        # pen_state = pen_onehot[:, 1:2]  # (N * max_seq_len, 1)
        pen_state = softargmax(input_pen, soft_beta)
        pen_state = tf.expand_dims(pen_state, axis=1)  # (N * max_seq_len, 1)
        return pen_state


================================================
FILE: model_common_train.py
================================================
import rnn
import tensorflow as tf
from subnet_tf_utils import generative_cnn_encoder, generative_cnn_encoder_deeper, generative_cnn_encoder_deeper13, \
    generative_cnn_c3_encoder, generative_cnn_c3_encoder_deeper, generative_cnn_c3_encoder_deeper13, \
    generative_cnn_c3_encoder_combine33, generative_cnn_c3_encoder_combine43, \
    generative_cnn_c3_encoder_combine53, generative_cnn_c3_encoder_combineFC, \
    generative_cnn_c3_encoder_deeper13_attn
from rasterization_utils.NeuralRenderer import NeuralRasterizorStep
from vgg_utils.VGG16 import vgg_net_slim


class VirtualSketchingModel(object):
    def __init__(self, hps, gpu_mode=True, reuse=False):
        """Initializer for the model.

        Args:
           hps: a HParams object containing model hyperparameters
           gpu_mode: a boolean that when True, uses GPU mode.
           reuse: a boolean that when true, attempts to reuse variables.
""" self.hps = hps assert hps.model_mode in ['train', 'eval', 'eval_sample', 'sample'] # with tf.variable_scope('SCC', reuse=reuse): if not gpu_mode: with tf.device('/cpu:0'): print('Model using cpu.') self.build_model() else: print('-' * 100) print('model_mode:', hps.model_mode) print('Model using gpu.') self.build_model() def build_model(self): """Define model architecture.""" self.config_model() initial_state = self.get_decoder_inputs() self.initial_state = initial_state self.initial_state_list = tf.split(self.initial_state, self.total_loop, axis=0) total_loss_list = [] ras_loss_list = [] perc_relu_raw_list = [] perc_relu_norm_list = [] sn_loss_list = [] cursor_outside_loss_list = [] win_size_outside_loss_list = [] early_state_loss_list = [] tower_grads = [] pred_raster_imgs_list = [] pred_raster_imgs_rgb_list = [] for t_i in range(self.total_loop): gpu_idx = t_i // self.hps.loop_per_gpu gpu_i = self.hps.gpus[gpu_idx] print(self.hps.model_mode, 'model, gpu:', gpu_i, ', loop:', t_i % self.hps.loop_per_gpu) with tf.device('/gpu:%d' % gpu_i): with tf.name_scope('GPU_%d' % gpu_i) as scope: if t_i > 0: tf.get_variable_scope().reuse_variables() else: total_loss_list.clear() ras_loss_list.clear() perc_relu_raw_list.clear() perc_relu_norm_list.clear() sn_loss_list.clear() cursor_outside_loss_list.clear() win_size_outside_loss_list.clear() early_state_loss_list.clear() tower_grads.clear() pred_raster_imgs_list.clear() pred_raster_imgs_rgb_list.clear() split_input_photo = self.input_photo_list[t_i] split_image_size = self.image_size[t_i] split_init_cursor = self.init_cursor_list[t_i] split_initial_state = self.initial_state_list[t_i] if self.hps.input_channel == 1: split_target_sketch = split_input_photo else: split_target_sketch = self.target_sketch_list[t_i] ## use pred as the prev points other_params, pen_ras, final_state, pred_raster_images, pred_raster_images_rgb, \ pos_before_max_min, win_size_before_max_min \ = self.get_points_and_raster_image(split_initial_state, 
split_init_cursor, split_input_photo, split_image_size) # other_params: (N * max_seq_len, 6) # pen_ras: (N * max_seq_len, 2), after softmax # pos_before_max_min: (N, max_seq_len, 2), in image_size # win_size_before_max_min: (N, max_seq_len, 1), in image_size pred_raster_imgs = 1.0 - pred_raster_images # (N, image_size, image_size), [0.0-stroke, 1.0-BG] pred_raster_imgs_rgb = 1.0 - pred_raster_images_rgb # (N, image_size, image_size, 3) pred_raster_imgs_list.append(pred_raster_imgs) pred_raster_imgs_rgb_list.append(pred_raster_imgs_rgb) if not self.hps.use_softargmax: pen_state_soft = pen_ras[:, 1:2] # (N * max_seq_len, 1) else: pen_state_soft = self.differentiable_argmax(pen_ras, self.hps.soft_beta) # (N * max_seq_len, 1) pred_params = tf.concat([pen_state_soft, other_params], axis=1) # (N * max_seq_len, 7) pred_params = tf.reshape(pred_params, shape=[-1, self.hps.max_seq_len, 7]) # (N, max_seq_len, 7) # pred_params: (N, max_seq_len, 7) if self.hps.model_mode == 'train' or self.hps.model_mode == 'eval': raster_cost, sn_cost, cursor_outside_cost, winsize_outside_cost, \ early_pen_states_cost, \ perc_relu_loss_raw, perc_relu_loss_norm = \ self.build_losses(split_target_sketch, pred_raster_imgs, pred_params, pos_before_max_min, win_size_before_max_min, split_image_size) # perc_relu_loss_raw, perc_relu_loss_norm: (n_layers) ras_loss_list.append(raster_cost) perc_relu_raw_list.append(perc_relu_loss_raw) perc_relu_norm_list.append(perc_relu_loss_norm) sn_loss_list.append(sn_cost) cursor_outside_loss_list.append(cursor_outside_cost) win_size_outside_loss_list.append(winsize_outside_cost) early_state_loss_list.append(early_pen_states_cost) if self.hps.model_mode == 'train': total_cost_split, grads_and_vars_split = self.build_training_op_split( raster_cost, sn_cost, cursor_outside_cost, winsize_outside_cost, early_pen_states_cost) total_loss_list.append(total_cost_split) tower_grads.append(grads_and_vars_split) self.raster_cost = tf.reduce_mean(tf.stack(ras_loss_list, 
axis=0)) self.perc_relu_losses_raw = tf.reduce_mean(tf.stack(perc_relu_raw_list, axis=0), axis=0) # (n_layers) self.perc_relu_losses_norm = tf.reduce_mean(tf.stack(perc_relu_norm_list, axis=0), axis=0) # (n_layers) self.stroke_num_cost = tf.reduce_mean(tf.stack(sn_loss_list, axis=0)) self.pos_outside_cost = tf.reduce_mean(tf.stack(cursor_outside_loss_list, axis=0)) self.win_size_outside_cost = tf.reduce_mean(tf.stack(win_size_outside_loss_list, axis=0)) self.early_pen_states_cost = tf.reduce_mean(tf.stack(early_state_loss_list, axis=0)) self.cost = tf.reduce_mean(tf.stack(total_loss_list, axis=0)) self.pred_raster_imgs = tf.concat(pred_raster_imgs_list, axis=0) # (N, image_size, image_size), [0.0-stroke, 1.0-BG] self.pred_raster_imgs_rgb = tf.concat(pred_raster_imgs_rgb_list, axis=0) # (N, image_size, image_size, 3) if self.hps.model_mode == 'train': self.build_training_op(tower_grads) def config_model(self): if self.hps.model_mode == 'train': self.global_step = tf.Variable(0, name='global_step', trainable=False) if self.hps.dec_model == 'lstm': dec_cell_fn = rnn.LSTMCell elif self.hps.dec_model == 'layer_norm': dec_cell_fn = rnn.LayerNormLSTMCell elif self.hps.dec_model == 'hyper': dec_cell_fn = rnn.HyperLSTMCell else: assert False, 'please choose a respectable cell' use_recurrent_dropout = self.hps.use_recurrent_dropout use_input_dropout = self.hps.use_input_dropout use_output_dropout = self.hps.use_output_dropout dec_cell = dec_cell_fn( self.hps.dec_rnn_size, use_recurrent_dropout=use_recurrent_dropout, dropout_keep_prob=self.hps.recurrent_dropout_prob) # dropout: # print('Input dropout mode = %s.' % use_input_dropout) # print('Output dropout mode = %s.' % use_output_dropout) # print('Recurrent dropout mode = %s.' % use_recurrent_dropout) if use_input_dropout: print('Dropout to input w/ keep_prob = %4.4f.' 
% self.hps.input_dropout_prob) dec_cell = tf.contrib.rnn.DropoutWrapper( dec_cell, input_keep_prob=self.hps.input_dropout_prob) if use_output_dropout: print('Dropout to output w/ keep_prob = %4.4f.' % self.hps.output_dropout_prob) dec_cell = tf.contrib.rnn.DropoutWrapper( dec_cell, output_keep_prob=self.hps.output_dropout_prob) self.dec_cell = dec_cell self.total_loop = len(self.hps.gpus) * self.hps.loop_per_gpu self.init_cursor = tf.placeholder( dtype=tf.float32, shape=[self.hps.batch_size, 1, 2]) # (N, 1, 2), in size [0.0, 1.0) self.init_width = tf.placeholder( dtype=tf.float32, shape=[1]) # (1), in [0.0, 1.0] self.image_size = tf.placeholder(dtype=tf.int32, shape=(self.total_loop)) # () self.init_cursor_list = tf.split(self.init_cursor, self.total_loop, axis=0) self.input_photo_list = [] for loop_i in range(self.total_loop): input_photo_i = tf.placeholder(dtype=tf.float32, shape=[None, None, None, self.hps.input_channel]) # [0.0-stroke, 1.0-BG] self.input_photo_list.append(input_photo_i) if self.hps.input_channel == 3: self.target_sketch_list = [] for loop_i in range(self.total_loop): target_sketch_i = tf.placeholder(dtype=tf.float32, shape=[None, None, None, 1]) # [0.0-stroke, 1.0-BG] self.target_sketch_list.append(target_sketch_i) if self.hps.model_mode == 'train' or self.hps.model_mode == 'eval': self.stroke_num_loss_weight = tf.Variable(0.0, trainable=False) self.early_pen_loss_start_idx = tf.Variable(0, dtype=tf.int32, trainable=False) self.early_pen_loss_end_idx = tf.Variable(0, dtype=tf.int32, trainable=False) if self.hps.model_mode == 'train': self.perc_loss_mean_list = [] for loop_i in range(len(self.hps.perc_loss_layers)): relu_loss_mean = tf.Variable(0.0, trainable=False) self.perc_loss_mean_list.append(relu_loss_mean) self.last_step_num = tf.Variable(0.0, trainable=False) with tf.variable_scope('train_op', reuse=tf.AUTO_REUSE): self.lr = tf.Variable(self.hps.learning_rate, trainable=False) self.optimizer = tf.train.AdamOptimizer(self.lr) 
########################### def normalize_image_m1to1(self, in_img_0to1): norm_img_m1to1 = tf.multiply(in_img_0to1, 2.0) norm_img_m1to1 = tf.subtract(norm_img_m1to1, 1.0) return norm_img_m1to1 def add_coords(self, input_tensor): batch_size_tensor = tf.shape(input_tensor)[0] # get N size xx_ones = tf.ones([batch_size_tensor, self.hps.raster_size], dtype=tf.int32) # e.g. (N, raster_size) xx_ones = tf.expand_dims(xx_ones, -1) # e.g. (N, raster_size, 1) xx_range = tf.tile(tf.expand_dims(tf.range(self.hps.raster_size), 0), [batch_size_tensor, 1]) # e.g. (N, raster_size) xx_range = tf.expand_dims(xx_range, 1) # e.g. (N, 1, raster_size) xx_channel = tf.matmul(xx_ones, xx_range) # e.g. (N, raster_size, raster_size) xx_channel = tf.expand_dims(xx_channel, -1) # e.g. (N, raster_size, raster_size, 1) yy_ones = tf.ones([batch_size_tensor, self.hps.raster_size], dtype=tf.int32) # e.g. (N, raster_size) yy_ones = tf.expand_dims(yy_ones, 1) # e.g. (N, 1, raster_size) yy_range = tf.tile(tf.expand_dims(tf.range(self.hps.raster_size), 0), [batch_size_tensor, 1]) # (N, raster_size) yy_range = tf.expand_dims(yy_range, -1) # e.g. (N, raster_size, 1) yy_channel = tf.matmul(yy_range, yy_ones) # e.g. (N, raster_size, raster_size) yy_channel = tf.expand_dims(yy_channel, -1) # e.g. (N, raster_size, raster_size, 1) xx_channel = tf.cast(xx_channel, 'float32') / (self.hps.raster_size - 1) yy_channel = tf.cast(yy_channel, 'float32') / (self.hps.raster_size - 1) # xx_channel = xx_channel * 2 - 1 # [-1, 1] # yy_channel = yy_channel * 2 - 1 ret = tf.concat([ input_tensor, xx_channel, yy_channel, ], axis=-1) # e.g. 
(N, raster_size, raster_size, 4) return ret def build_combined_encoder(self, patch_canvas, patch_photo, entire_canvas, entire_photo, cursor_pos, image_size, window_size): """ :param patch_canvas: (N, raster_size, raster_size, 1), [-1.0-stroke, 1.0-BG] :param patch_photo: (N, raster_size, raster_size, 1/3), [-1.0-stroke, 1.0-BG] :param entire_canvas: (N, image_size, image_size, 1), [0.0-stroke, 1.0-BG] :param entire_photo: (N, image_size, image_size, 1/3), [0.0-stroke, 1.0-BG] :param cursor_pos: (N, 1, 2), in size [0.0, 1.0) :param window_size: (N, 1, 1), float, in large size :return: """ if self.hps.resize_method == 'BILINEAR': resize_method = tf.image.ResizeMethod.BILINEAR elif self.hps.resize_method == 'NEAREST_NEIGHBOR': resize_method = tf.image.ResizeMethod.NEAREST_NEIGHBOR elif self.hps.resize_method == 'BICUBIC': resize_method = tf.image.ResizeMethod.BICUBIC elif self.hps.resize_method == 'AREA': resize_method = tf.image.ResizeMethod.AREA else: raise Exception('unknown resize_method', self.hps.resize_method) patch_photo = tf.stop_gradient(patch_photo) patch_canvas = tf.stop_gradient(patch_canvas) cursor_pos = tf.stop_gradient(cursor_pos) window_size = tf.stop_gradient(window_size) entire_photo_small = tf.stop_gradient(tf.image.resize_images(entire_photo, (self.hps.raster_size, self.hps.raster_size), method=resize_method)) entire_canvas_small = tf.stop_gradient(tf.image.resize_images(entire_canvas, (self.hps.raster_size, self.hps.raster_size), method=resize_method)) entire_photo_small = self.normalize_image_m1to1(entire_photo_small) # [-1.0-stroke, 1.0-BG] entire_canvas_small = self.normalize_image_m1to1(entire_canvas_small) # [-1.0-stroke, 1.0-BG] if self.hps.encode_cursor_type == 'value': cursor_pos_norm = tf.expand_dims(cursor_pos, axis=1) # (N, 1, 1, 2) cursor_pos_norm = tf.tile(cursor_pos_norm, [1, self.hps.raster_size, self.hps.raster_size, 1]) cursor_info = cursor_pos_norm else: raise Exception('Unknown encode_cursor_type', self.hps.encode_cursor_type) 
batch_input_combined = tf.concat([patch_photo, patch_canvas, entire_photo_small, entire_canvas_small, cursor_info], axis=-1) # [N, raster_size, raster_size, 6/10] batch_input_local = tf.concat([patch_photo, patch_canvas], axis=-1) # [N, raster_size, raster_size, 2/4] batch_input_global = tf.concat([entire_photo_small, entire_canvas_small, cursor_info], axis=-1) # [N, raster_size, raster_size, 4/6] if self.hps.model_mode == 'train': is_training = True dropout_keep_prob = self.hps.pix_drop_kp else: is_training = False dropout_keep_prob = 1.0 if self.hps.add_coordconv: batch_input_combined = self.add_coords(batch_input_combined) # (N, in_H, in_W, in_dim + 2) batch_input_local = self.add_coords(batch_input_local) # (N, in_H, in_W, in_dim + 2) batch_input_global = self.add_coords(batch_input_global) # (N, in_H, in_W, in_dim + 2) if 'combine' in self.hps.encoder_type: if self.hps.encoder_type == 'combine33': image_embedding, _ = generative_cnn_c3_encoder_combine33(batch_input_local, batch_input_global, is_training, dropout_keep_prob) # (N, 128) elif self.hps.encoder_type == 'combine43': image_embedding, _ = generative_cnn_c3_encoder_combine43(batch_input_local, batch_input_global, is_training, dropout_keep_prob) # (N, 128) elif self.hps.encoder_type == 'combine53': image_embedding, _ = generative_cnn_c3_encoder_combine53(batch_input_local, batch_input_global, is_training, dropout_keep_prob) # (N, 128) elif self.hps.encoder_type == 'combineFC': image_embedding, _ = generative_cnn_c3_encoder_combineFC(batch_input_local, batch_input_global, is_training, dropout_keep_prob) # (N, 256) else: raise Exception('Unknown encoder_type', self.hps.encoder_type) else: with tf.variable_scope('Combined_Encoder', reuse=tf.AUTO_REUSE): if self.hps.encoder_type == 'conv10': image_embedding, _ = generative_cnn_encoder(batch_input_combined, is_training, dropout_keep_prob) # (N, 128) elif self.hps.encoder_type == 'conv10_deep': image_embedding, _ = 
generative_cnn_encoder_deeper(batch_input_combined, is_training, dropout_keep_prob) # (N, 512) elif self.hps.encoder_type == 'conv13': image_embedding, _ = generative_cnn_encoder_deeper13(batch_input_combined, is_training, dropout_keep_prob) # (N, 128) elif self.hps.encoder_type == 'conv10_c3': image_embedding, _ = generative_cnn_c3_encoder(batch_input_combined, is_training, dropout_keep_prob) # (N, 128) elif self.hps.encoder_type == 'conv10_deep_c3': image_embedding, _ = generative_cnn_c3_encoder_deeper(batch_input_combined, is_training, dropout_keep_prob) # (N, 512) elif self.hps.encoder_type == 'conv13_c3': image_embedding, _ = generative_cnn_c3_encoder_deeper13(batch_input_combined, is_training, dropout_keep_prob) # (N, 128) elif self.hps.encoder_type == 'conv13_c3_attn': image_embedding, _ = generative_cnn_c3_encoder_deeper13_attn(batch_input_combined, is_training, dropout_keep_prob) # (N, 128) else: raise Exception('Unknown encoder_type', self.hps.encoder_type) return image_embedding def build_seq_decoder(self, dec_cell, actual_input_x, initial_state): rnn_output, last_state = self.rnn_decoder(dec_cell, initial_state, actual_input_x) rnn_output_flat = tf.reshape(rnn_output, [-1, self.hps.dec_rnn_size]) pen_n_out = 2 params_n_out = 6 with tf.variable_scope('DEC_RNN_out_pen', reuse=tf.AUTO_REUSE): output_w_pen = tf.get_variable('output_w', [self.hps.dec_rnn_size, pen_n_out]) output_b_pen = tf.get_variable('output_b', [pen_n_out], initializer=tf.constant_initializer(0.0)) output_pen = tf.nn.xw_plus_b(rnn_output_flat, output_w_pen, output_b_pen) # (N, pen_n_out) with tf.variable_scope('DEC_RNN_out_params', reuse=tf.AUTO_REUSE): output_w_params = tf.get_variable('output_w', [self.hps.dec_rnn_size, params_n_out]) output_b_params = tf.get_variable('output_b', [params_n_out], initializer=tf.constant_initializer(0.0)) output_params = tf.nn.xw_plus_b(rnn_output_flat, output_w_params, output_b_params) # (N, params_n_out) output = tf.concat([output_pen, output_params], 
axis=1) # (N, n_out) return output, last_state def get_mixture_coef(self, outputs): z = outputs z_pen_logits = z[:, 0:2] # (N, 2), pen states z_other_params_logits = z[:, 2:] # (N, 6) z_pen = tf.nn.softmax(z_pen_logits) # (N, 2) if self.hps.position_format == 'abs': x1y1 = tf.nn.sigmoid(z_other_params_logits[:, 0:2]) # (N, 2) x2y2 = tf.tanh(z_other_params_logits[:, 2:4]) # (N, 2) widths = tf.nn.sigmoid(z_other_params_logits[:, 4:5]) # (N, 1) widths = tf.add(tf.multiply(widths, 1.0 - self.hps.min_width), self.hps.min_width) scaling = tf.nn.sigmoid(z_other_params_logits[:, 5:6]) * self.hps.max_scaling # (N, 1), [0.0, max_scaling] # scaling = tf.add(tf.multiply(scaling, (self.hps.max_scaling - self.hps.min_scaling) / self.hps.max_scaling), # self.hps.min_scaling) z_other_params = tf.concat([x1y1, x2y2, widths, scaling], axis=-1) # (N, 6) else: # "rel" raise Exception('Unknown position_format', self.hps.position_format) r = [z_other_params, z_pen] return r ########################### def get_decoder_inputs(self): initial_state = self.dec_cell.zero_state(batch_size=self.hps.batch_size, dtype=tf.float32) return initial_state def rnn_decoder(self, dec_cell, initial_state, actual_input_x): with tf.variable_scope("RNN_DEC", reuse=tf.AUTO_REUSE): output, last_state = tf.nn.dynamic_rnn( dec_cell, actual_input_x, initial_state=initial_state, time_major=False, swap_memory=True, dtype=tf.float32) return output, last_state ########################### def image_padding(self, ori_image, window_size, pad_value): """ Pad with (bg) :param ori_image: :return: """ paddings = [[0, 0], [window_size // 2, window_size // 2], [window_size // 2, window_size // 2], [0, 0]] pad_img = tf.pad(ori_image, paddings=paddings, mode='CONSTANT', constant_values=pad_value) # (N, H_p, W_p, k) return pad_img def image_cropping_fn(self, fn_inputs): """ crop the patch :return: """ index_offset = self.hps.input_channel - 1 input_image = fn_inputs[:, :, 0:2 + index_offset] # (image_size, image_size, 2), 
        # [0.0-BG, 1.0-stroke] — continuation of image_cropping_fn
        cursor_pos = fn_inputs[0, 0, 2 + index_offset:4 + index_offset]  # (2), in [0.0, 1.0)
        image_size = fn_inputs[0, 0, 4 + index_offset]  # (), float32
        window_size = tf.cast(fn_inputs[0, 0, 5 + index_offset], tf.int32)  # ()

        input_img_reshape = tf.expand_dims(input_image, axis=0)
        pad_img = self.image_padding(input_img_reshape, window_size, pad_value=0.0)

        # convert the normalized cursor to integer pixel coordinates
        cursor_pos = tf.cast(tf.round(tf.multiply(cursor_pos, image_size)), dtype=tf.int32)

        x0, x1 = cursor_pos[0], cursor_pos[0] + window_size  # ()
        y0, y1 = cursor_pos[1], cursor_pos[1] + window_size  # ()
        patch_image = pad_img[:, y0:y1, x0:x1, :]  # (1, window_size, window_size, 2/4)

        # resize to raster_size
        patch_image_scaled = tf.image.resize_images(patch_image,
                                                    (self.hps.raster_size, self.hps.raster_size),
                                                    method=tf.image.ResizeMethod.AREA)
        patch_image_scaled = tf.squeeze(patch_image_scaled, axis=0)
        # patch_canvas_scaled: (raster_size, raster_size, 2/4), [0.0-BG, 1.0-stroke]
        return patch_image_scaled

    def image_cropping(self, cursor_position, input_img, image_size, window_sizes):
        """Crop a window-sized patch around each sample's cursor via tf.map_fn.

        All per-sample scalars are tiled to full spatial maps and concatenated onto
        the image so a single map_fn body can unpack them.

        :param cursor_position: (N, 1, 2), float type, in size [0.0, 1.0)
        :param input_img: (N, image_size, image_size, 2/4), [0.0-BG, 1.0-stroke]
        :param window_sizes: (N, 1, 1), float32, with grad
        :return: (N, raster_size, raster_size, -) cropped, resized patches.
        """
        input_img_ = input_img
        window_sizes_non_grad = tf.stop_gradient(tf.round(window_sizes))  # (N, 1, 1), no grad

        cursor_position_ = tf.reshape(cursor_position, (-1, 1, 1, 2))  # (N, 1, 1, 2)
        cursor_position_ = tf.tile(cursor_position_, [1, image_size, image_size, 1])  # (N, image_size, image_size, 2)

        image_size_ = tf.reshape(tf.cast(image_size, tf.float32), (1, 1, 1, 1))  # (1, 1, 1, 1)
        image_size_ = tf.tile(image_size_, [self.hps.batch_size // self.total_loop, image_size, image_size, 1])

        window_sizes_ = tf.reshape(window_sizes_non_grad, (-1, 1, 1, 1))  # (N, 1, 1, 1)
        window_sizes_ = tf.tile(window_sizes_, [1, image_size, image_size, 1])  # (N, image_size, image_size, 1)

        fn_inputs = tf.concat([input_img_, cursor_position_, image_size_,
                               window_sizes_], axis=-1)  # (N, image_size, image_size, 2/4 + 4)
        curr_patch_imgs = tf.map_fn(self.image_cropping_fn, fn_inputs,
                                    parallel_iterations=32)  # (N, raster_size, raster_size, -)
        return curr_patch_imgs

    def image_cropping_v3(self, cursor_position, input_img, image_size, window_sizes):
        """Crop a window around each cursor with tf.image.crop_and_resize (batched, no map_fn).

        :param cursor_position: (N, 1, 2), float type, in size [0.0, 1.0)
        :param input_img: (N, image_size, image_size, k), [0.0-BG, 1.0-stroke]
        :param window_sizes: (N, 1, 1), float32, with grad
        :return: (N, raster_size, raster_size, k) patches, [0.0-BG, 1.0-stroke].
        """
        window_sizes_non_grad = tf.stop_gradient(window_sizes)  # (N, 1, 1), no grad

        cursor_pos = tf.multiply(cursor_position, tf.cast(image_size, tf.float32))
        cursor_x, cursor_y = tf.split(cursor_pos, 2, axis=-1)  # (N, 1, 1)

        # box corners centered on the cursor, in pixel coordinates
        y1 = cursor_y - (window_sizes_non_grad - 1.0) / 2
        x1 = cursor_x - (window_sizes_non_grad - 1.0) / 2
        y2 = y1 + (window_sizes_non_grad - 1.0)
        x2 = x1 + (window_sizes_non_grad - 1.0)
        boxes = tf.concat([y1, x1, y2, x2], axis=-1)  # (N, 1, 4)
        boxes = tf.squeeze(boxes, axis=1)  # (N, 4)
        # crop_and_resize expects normalized box coordinates
        boxes = boxes / tf.cast(image_size - 1, tf.float32)

        # box_ind = [0, 1, ..., N-1]: one box per batch element
        box_ind = tf.ones_like(cursor_x)[:, 0, 0]  # (N)
        box_ind = tf.cast(box_ind, dtype=tf.int32)
        box_ind = tf.cumsum(box_ind) - 1

        curr_patch_imgs = tf.image.crop_and_resize(input_img, boxes, box_ind,
                                                   crop_size=[self.hps.raster_size, self.hps.raster_size])
        # (N, raster_size, raster_size, k), [0.0-BG, 1.0-stroke]
        return curr_patch_imgs

    def get_pixel_value(self, img, x, y):
        """
        Utility function to get pixel value for coordinate
        vectors x and y from a 4D tensor image.
        Input
        -----
        - img: tensor of shape (B, H, W, C)
        - x: flattened tensor of shape (B, H', W')
        - y: flattened tensor of shape (B, H', W')

        Returns
        -------
        - output: tensor of shape (B, H', W', C)
        """
        shape = tf.shape(x)
        batch_size = shape[0]
        height = shape[1]
        width = shape[2]

        # build (b, y, x) index triples for gather_nd
        batch_idx = tf.range(0, batch_size)
        batch_idx = tf.reshape(batch_idx, (batch_size, 1, 1))
        b = tf.tile(batch_idx, (1, height, width))

        indices = tf.stack([b, y, x], 3)
        return tf.gather_nd(img, indices)

    def image_pasting_nondiff_single(self, fn_inputs):
        """Paste one patch back onto the full canvas at the (rounded) cursor — non-differentiable.

        fn_inputs packs patch, cursor position (large-image coords), image size and
        window size along the channel axis (tf.map_fn body).
        :return: (image_size, image_size, 1) pasted canvas, [0.0-BG, 1.0-stroke].
        """
        patch_image = fn_inputs[:, :, 0:1]  # (raster_size, raster_size, 1), [0.0-BG, 1.0-stroke]
        cursor_pos = fn_inputs[0, 0, 1:3]   # (2), in large size
        image_size = tf.cast(fn_inputs[0, 0, 3], tf.int32)   # ()
        window_size = tf.cast(fn_inputs[0, 0, 4], tf.int32)  # ()

        # scale the raster-sized patch up/down to the actual window size
        patch_image_scaled = tf.expand_dims(patch_image, axis=0)  # (1, raster_size, raster_size, 1)
        patch_image_scaled = tf.image.resize_images(patch_image_scaled, (window_size, window_size),
                                                    method=tf.image.ResizeMethod.BILINEAR)
        patch_image_scaled = tf.squeeze(patch_image_scaled, axis=0)
        # patch_canvas_scaled: (window_size, window_size, 1)

        # rounding here is what makes this variant non-differentiable w.r.t. the cursor
        cursor_pos = tf.cast(tf.round(cursor_pos), dtype=tf.int32)  # (2)
        cursor_x, cursor_y = cursor_pos[0], cursor_pos[1]

        pad_up = cursor_y
        pad_down = image_size - cursor_y
        pad_left = cursor_x
        pad_right = image_size - cursor_x

        paddings = [[pad_up, pad_down], [pad_left, pad_right], [0, 0]]
        pad_img = tf.pad(patch_image_scaled, paddings=paddings, mode='CONSTANT',
                         constant_values=0.0)  # (H_p, W_p, 1), [0.0-BG, 1.0-stroke]

        crop_start = window_size // 2
        pasted_image = pad_img[crop_start: crop_start + image_size,
                               crop_start: crop_start + image_size, :]
        return pasted_image

    def image_pasting_diff_single(self, fn_inputs):
        """Paste one patch at a fractional cursor position — differentiable variant (tf.map_fn body).

        Sub-pixel placement is handled by image_pasting_diff_batch (bilinear splat);
        the integer part of the cursor is applied by padding.
        """
        patch_canvas = fn_inputs[:, :, 0:1]  # (raster_size, raster_size, 1), [0.0-BG, 1.0-stroke]
        cursor_pos = fn_inputs[0, 0, 1:3]    # (2), in large size
        image_size = tf.cast(fn_inputs[0, 0, 3], tf.int32)   # ()
        window_size = tf.cast(fn_inputs[0, 0, 4], tf.int32)  # ()
        cursor_x, cursor_y = cursor_pos[0], cursor_pos[1]

        patch_canvas_scaled = tf.expand_dims(patch_canvas, axis=0)  # (1, raster_size, raster_size, 1)
        patch_canvas_scaled = tf.image.resize_images(patch_canvas_scaled, (window_size, window_size),
                                                     method=tf.image.ResizeMethod.BILINEAR)
        # patch_canvas_scaled: (1, window_size, window_size, 1)

        valid_canvas = self.image_pasting_diff_batch(patch_canvas_scaled,
                                                     tf.expand_dims(tf.expand_dims(cursor_pos, axis=0), axis=0),
                                                     window_size)
        valid_canvas = tf.squeeze(valid_canvas, axis=0)  # (window_size + 1, window_size + 1, 1)

        pad_up = tf.cast(tf.floor(cursor_y), tf.int32)
        pad_down = image_size - 1 - tf.cast(tf.floor(cursor_y), tf.int32)
        pad_left = tf.cast(tf.floor(cursor_x), tf.int32)
        pad_right = image_size - 1 - tf.cast(tf.floor(cursor_x), tf.int32)

        paddings = [[pad_up, pad_down], [pad_left, pad_right], [0, 0]]
        pad_img = tf.pad(valid_canvas, paddings=paddings, mode='CONSTANT',
                         constant_values=0.0)  # (H_p, W_p, 1), [0.0-BG, 1.0-stroke]

        crop_start = window_size // 2
        pasted_image = pad_img[crop_start: crop_start + image_size,
                               crop_start: crop_start + image_size, :]
        return pasted_image

    def image_pasting_diff_single_v3(self, fn_inputs):
        """Differentiable pasting via an inverse crop_and_resize (tf.map_fn body).

        Coordinate frames: '_a' suffix = large-image frame, '_b' suffix = patch
        (raster) frame. The patch is re-sampled so its integer-aligned bounding box
        can be placed by plain padding, keeping gradients w.r.t. window size.
        """
        patch_canvas = fn_inputs[:, :, 0:1]  # (raster_size, raster_size, 1), [0.0-BG, 1.0-stroke]
        cursor_pos_a = fn_inputs[0, 0, 1:3]  # (2), float32, in large size
        image_size_a = tf.cast(fn_inputs[0, 0, 3], tf.int32)  # ()
        window_size_a = fn_inputs[0, 0, 4]  # (), float32, with grad
        raster_size_a = float(self.hps.raster_size)

        padding_size = tf.cast(tf.ceil(window_size_a / 2.0), tf.int32)

        # window extent in the large image, and its integer-aligned hull
        x1y1_a = cursor_pos_a - window_size_a / 2.0  # (2), float32
        x2y2_a = cursor_pos_a + window_size_a / 2.0  # (2), float32
        x1y1_a_floor = tf.floor(x1y1_a)  # (2)
        x2y2_a_ceil = tf.ceil(x2y2_a)    # (2)

        cursor_pos_b_oricoord = (x1y1_a_floor + x2y2_a_ceil) / 2.0  # (2)
        cursor_pos_b = (cursor_pos_b_oricoord - x1y1_a) / window_size_a * raster_size_a  # (2)

        raster_size_b = (x2y2_a_ceil - x1y1_a_floor)  # (x, y)
        image_size_b = raster_size_a
        window_size_b = \
            raster_size_a * (raster_size_b / window_size_a)  # (x, y)

        cursor_b_x, cursor_b_y = tf.split(cursor_pos_b, 2, axis=-1)  # (1)

        y1_b = cursor_b_y - (window_size_b[1] - 1.) / 2.
        x1_b = cursor_b_x - (window_size_b[0] - 1.) / 2.
        y2_b = y1_b + (window_size_b[1] - 1.)
        x2_b = x1_b + (window_size_b[0] - 1.)
        boxes_b = tf.concat([y1_b, x1_b, y2_b, x2_b], axis=-1)  # (4)
        boxes_b = boxes_b / tf.cast(image_size_b - 1, tf.float32)  # with grad to window_size_a

        box_ind_b = tf.ones((1), dtype=tf.int32)  # (1)
        box_ind_b = tf.cumsum(box_ind_b) - 1

        patch_canvas = tf.expand_dims(patch_canvas, axis=0)  # (1, raster_size, raster_size, 1), [0.0-BG, 1.0-stroke]
        boxes_b = tf.expand_dims(boxes_b, axis=0)  # (1, 4)
        valid_canvas = tf.image.crop_and_resize(patch_canvas, boxes_b, box_ind_b,
                                                crop_size=[raster_size_b[1], raster_size_b[0]])
        valid_canvas = valid_canvas[0]  # (raster_size_b, raster_size_b, 1)

        # place the integer-aligned patch on the padded canvas, then crop the padding off
        pad_up = tf.cast(x1y1_a_floor[1], tf.int32) + padding_size
        pad_down = image_size_a + padding_size - tf.cast(x2y2_a_ceil[1], tf.int32)
        pad_left = tf.cast(x1y1_a_floor[0], tf.int32) + padding_size
        pad_right = image_size_a + padding_size - tf.cast(x2y2_a_ceil[0], tf.int32)

        paddings = [[pad_up, pad_down], [pad_left, pad_right], [0, 0]]
        pad_img = tf.pad(valid_canvas, paddings=paddings, mode='CONSTANT',
                         constant_values=0.0)  # (H_p, W_p, 1), [0.0-BG, 1.0-stroke]

        pasted_image = pad_img[padding_size: padding_size + image_size_a,
                               padding_size: padding_size + image_size_a, :]
        return pasted_image

    def image_pasting_diff_batch(self, patch_image, cursor_position, window_size):
        """Bilinearly splat a patch onto a (window_size + 1)^2 grid at a fractional offset.

        Each output pixel is a weighted sum of the 4 neighboring patch pixels, so the
        result is differentiable w.r.t. the fractional part of the cursor.

        :param patch_image: (N, window_size, window_size, 1), [0.0-BG, 1.0-stroke]
        :param cursor_position: (N, 1, 2), in large size
        :return: (N, window_size + 1, window_size + 1, 1)
        """
        paddings1 = [[0, 0], [1, 1], [1, 1], [0, 0]]
        patch_image_pad1 = tf.pad(patch_image, paddings=paddings1, mode='CONSTANT',
                                  constant_values=0.0)  # (N, window_size+2, window_size+2, 1), [0.0-BG, 1.0-stroke]

        cursor_x, cursor_y = cursor_position[:, :, 0:1], cursor_position[:, :, 1:2]  # (N, 1, 1)
        cursor_x_f, cursor_y_f = tf.floor(cursor_x), tf.floor(cursor_y)
        # fractional offsets of the sampling grid inside the padded patch
        patch_x, patch_y = 1.0 - (cursor_x - cursor_x_f), 1.0 - (cursor_y - cursor_y_f)  # (N, 1, 1)

        # build the x sampling coordinates: first column fractional, then unit steps
        x_ones = tf.ones_like(patch_x, dtype=tf.float32)  # (N, 1, 1)
        x_ones = tf.tile(x_ones, [1, 1, window_size])  # (N, 1, window_size)
        patch_x = tf.concat([patch_x, x_ones], axis=-1)  # (N, 1, window_size + 1)
        patch_x = tf.tile(patch_x, [1, window_size + 1, 1])  # (N, window_size + 1, window_size + 1)
        patch_x = tf.cumsum(patch_x, axis=-1)  # (N, window_size + 1, window_size + 1)
        patch_x0 = tf.cast(tf.floor(patch_x), tf.int32)  # (N, window_size + 1, window_size + 1)
        patch_x1 = patch_x0 + 1  # (N, window_size + 1, window_size + 1)

        # same construction along y
        y_ones = tf.ones_like(patch_y, dtype=tf.float32)  # (N, 1, 1)
        y_ones = tf.tile(y_ones, [1, window_size, 1])  # (N, window_size, 1)
        patch_y = tf.concat([patch_y, y_ones], axis=1)  # (N, window_size + 1, 1)
        patch_y = tf.tile(patch_y, [1, 1, window_size + 1])  # (N, window_size + 1, window_size + 1)
        patch_y = tf.cumsum(patch_y, axis=1)  # (N, window_size + 1, window_size + 1)
        patch_y0 = tf.cast(tf.floor(patch_y), tf.int32)  # (N, window_size + 1, window_size + 1)
        patch_y1 = patch_y0 + 1  # (N, window_size + 1, window_size + 1)

        # get pixel value at corner coords
        valid_canvas_patch_a = self.get_pixel_value(patch_image_pad1, patch_x0, patch_y0)
        valid_canvas_patch_b = self.get_pixel_value(patch_image_pad1, patch_x0, patch_y1)
        valid_canvas_patch_c = self.get_pixel_value(patch_image_pad1, patch_x1, patch_y0)
        valid_canvas_patch_d = self.get_pixel_value(patch_image_pad1, patch_x1, patch_y1)
        # (N, window_size + 1, window_size + 1, 1)

        patch_x0 = tf.cast(patch_x0, tf.float32)
        patch_x1 = tf.cast(patch_x1, tf.float32)
        patch_y0 = tf.cast(patch_y0, tf.float32)
        patch_y1 = tf.cast(patch_y1, tf.float32)

        # calculate bilinear interpolation weights
        wa = (patch_x1 - patch_x) * (patch_y1 - patch_y)
        wb = (patch_x1 - patch_x) * (patch_y - patch_y0)
        wc = (patch_x - patch_x0) * (patch_y1 - patch_y)
        wd = (patch_x - patch_x0) * (patch_y - patch_y0)
        # (N, window_size + 1, window_size + 1)

        # add channel dimension for the weighted sum
        wa = tf.expand_dims(wa, axis=3)
        wb = tf.expand_dims(wb, axis=3)
        wc = tf.expand_dims(wc, axis=3)
        wd = tf.expand_dims(wd, axis=3)
        # (N, window_size + 1, window_size + 1, 1)

        # compute output
        valid_canvas_patch_ = tf.add_n([wa * valid_canvas_patch_a, wb * valid_canvas_patch_b,
                                        wc * valid_canvas_patch_c, wd * valid_canvas_patch_d])
        # (N, window_size + 1, window_size + 1, 1)
        return valid_canvas_patch_

    def image_pasting(self, cursor_position_norm, patch_img, image_size, window_sizes,
                      is_differentiable=False):
        """Paste each patch_img onto the full canvas at its cursor position (per-sample map_fn).

        :param cursor_position_norm: (N, 1, 2), float type, in size [0.0, 1.0)
        :param patch_img: (N, raster_size, raster_size), [0.0-BG, 1.0-stroke]
        :param window_sizes: (N, 1, 1), float32, with grad
        :return: (N, image_size, image_size) pasted canvases.
        """
        cursor_position = tf.multiply(cursor_position_norm, tf.cast(image_size, tf.float32))  # in large size
        window_sizes_r = tf.round(window_sizes)  # (N, 1, 1), no grad

        patch_img_ = tf.expand_dims(patch_img, axis=-1)  # (N, raster_size, raster_size, 1)

        cursor_position_step = tf.reshape(cursor_position, (-1, 1, 1, 2))  # (N, 1, 1, 2)
        cursor_position_step = tf.tile(cursor_position_step,
                                       [1, self.hps.raster_size, self.hps.raster_size, 1])
        # (N, raster_size, raster_size, 2)

        image_size_tile = tf.reshape(tf.cast(image_size, tf.float32), (1, 1, 1, 1))  # (N, 1, 1, 1)
        image_size_tile = tf.tile(image_size_tile,
                                  [self.hps.batch_size // self.total_loop,
                                   self.hps.raster_size, self.hps.raster_size, 1])

        window_sizes_tile = tf.reshape(window_sizes_r, (-1, 1, 1, 1))  # (N, 1, 1, 1)
        window_sizes_tile = tf.tile(window_sizes_tile,
                                    [1, self.hps.raster_size, self.hps.raster_size, 1])

        pasting_inputs = tf.concat([patch_img_, cursor_position_step, image_size_tile,
                                    window_sizes_tile], axis=-1)  # (N, raster_size, raster_size, 5)

        if is_differentiable:
            curr_paste_imgs = tf.map_fn(self.image_pasting_diff_single, pasting_inputs,
                                        parallel_iterations=32)  # (N, image_size, image_size, 1)
        else:
curr_paste_imgs = tf.map_fn(self.image_pasting_nondiff_single, pasting_inputs, parallel_iterations=32) # (N, image_size, image_size, 1) curr_paste_imgs = tf.squeeze(curr_paste_imgs, axis=-1) # (N, image_size, image_size) return curr_paste_imgs def image_pasting_v3(self, cursor_position_norm, patch_img, image_size, window_sizes, is_differentiable=False): """ paste the patch_img to padded size based on cursor_position :param cursor_position_norm: (N, 1, 2), float type, in size [0.0, 1.0) :param patch_img: (N, raster_size, raster_size), [0.0-BG, 1.0-stroke] :param window_sizes: (N, 1, 1), float32, with grad :return: """ cursor_position = tf.multiply(cursor_position_norm, tf.cast(image_size, tf.float32)) # in large size if is_differentiable: patch_img_ = tf.expand_dims(patch_img, axis=-1) # (N, raster_size, raster_size, 1) cursor_position_step = tf.reshape(cursor_position, (-1, 1, 1, 2)) # (N, 1, 1, 2) cursor_position_step = tf.tile(cursor_position_step, [1, self.hps.raster_size, self.hps.raster_size, 1]) # (N, raster_size, raster_size, 2) image_size_tile = tf.reshape(tf.cast(image_size, tf.float32), (1, 1, 1, 1)) # (N, 1, 1, 1) image_size_tile = tf.tile(image_size_tile, [self.hps.batch_size // self.total_loop, self.hps.raster_size, self.hps.raster_size, 1]) window_sizes_tile = tf.reshape(window_sizes, (-1, 1, 1, 1)) # (N, 1, 1, 1) window_sizes_tile = tf.tile(window_sizes_tile, [1, self.hps.raster_size, self.hps.raster_size, 1]) pasting_inputs = tf.concat([patch_img_, cursor_position_step, image_size_tile, window_sizes_tile], axis=-1) # (N, raster_size, raster_size, 5) curr_paste_imgs = tf.map_fn(self.image_pasting_diff_single_v3, pasting_inputs, parallel_iterations=32) # (N, image_size, image_size, 1) else: raise Exception('Unfinished...') curr_paste_imgs = tf.squeeze(curr_paste_imgs, axis=-1) # (N, image_size, image_size) return curr_paste_imgs def get_points_and_raster_image(self, initial_state, init_cursor, input_photo, image_size): ## generate the other_params and 
        ## pen_ras and raster image for raster loss.
        ## Unrolls the decoder for hps.max_seq_len steps: at each step it crops a
        ## window around the cursor, encodes it, decodes one stroke, rasters the
        ## stroke and pastes it back; the cursor/window are updated from the
        ## predicted offsets and scaling.
        prev_state = initial_state  # (N, dec_rnn_size * 3)

        prev_width = self.init_width  # (1)
        prev_width = tf.expand_dims(tf.expand_dims(prev_width, axis=0), axis=0)  # (1, 1, 1)
        prev_width = tf.tile(prev_width, [self.hps.batch_size // self.total_loop, 1, 1])  # (N, 1, 1)

        prev_scaling = tf.ones((self.hps.batch_size // self.total_loop, 1, 1))  # (N, 1, 1)
        prev_window_size = tf.ones((self.hps.batch_size // self.total_loop, 1, 1),
                                   dtype=tf.float32) * float(self.hps.raster_size)  # (N, 1, 1)

        cursor_position_temp = init_cursor
        self.cursor_position = cursor_position_temp  # (N, 1, 2), in size [0.0, 1.0)
        cursor_position_loop = self.cursor_position

        other_params_list = []
        pen_ras_list = []
        pos_before_max_min_list = []
        win_size_before_max_min_list = []

        curr_canvas_soft = tf.zeros_like(input_photo[:, :, :, 0])  # (N, image_size, image_size), [0.0-BG, 1.0-stroke]
        curr_canvas_soft_rgb = tf.tile(tf.zeros_like(input_photo[:, :, :, 0:1]),
                                       [1, 1, 1, 3])  # (N, image_size, image_size, 3), [0.0-BG, 1.0-stroke]
        curr_canvas_hard = tf.zeros_like(curr_canvas_soft)  # [0.0-BG, 1.0-stroke]

        #### sampling part - start ####
        self.curr_canvas_hard = curr_canvas_hard

        rasterizor_st = NeuralRasterizorStep(
            raster_size=self.hps.raster_size,
            position_format=self.hps.position_format)

        if self.hps.cropping_type == 'v3':
            cropping_func = self.image_cropping_v3
        # elif self.hps.cropping_type == 'v2':
        #     cropping_func = self.image_cropping
        else:
            raise Exception('Unknown cropping_type', self.hps.cropping_type)

        if self.hps.pasting_type == 'v3':
            pasting_func = self.image_pasting_v3
        # elif self.hps.pasting_type == 'v2':
        #     pasting_func = self.image_pasting
        else:
            raise Exception('Unknown pasting_type', self.hps.pasting_type)

        for time_i in range(self.hps.max_seq_len):
            cursor_position_non_grad = tf.stop_gradient(cursor_position_loop)  # (N, 1, 2), in size [0.0, 1.0)

            curr_window_size = tf.multiply(prev_scaling, tf.stop_gradient(prev_window_size))  # float, with grad
            curr_window_size = \
                tf.maximum(curr_window_size, tf.cast(self.hps.min_window_size, tf.float32))
            curr_window_size = tf.minimum(curr_window_size, tf.cast(image_size, tf.float32))

            ## patch-level encoding
            # Here, we make the gradients from canvas_z to curr_canvas_hard be None
            # to avoid recurrent gradient propagation.
            curr_canvas_hard_non_grad = tf.stop_gradient(self.curr_canvas_hard)
            curr_canvas_hard_non_grad = tf.expand_dims(curr_canvas_hard_non_grad, axis=-1)

            # input_photo: (N, image_size, image_size, 1/3), [0.0-stroke, 1.0-BG]
            crop_inputs = tf.concat([1.0 - input_photo, curr_canvas_hard_non_grad], axis=-1)  # (N, H_p, W_p, 1/3+1)

            cropped_outputs = cropping_func(cursor_position_non_grad, crop_inputs, image_size, curr_window_size)
            index_offset = self.hps.input_channel - 1
            curr_patch_inputs = cropped_outputs[:, :, :, 0:1 + index_offset]  # [0.0-BG, 1.0-stroke]
            curr_patch_canvas_hard_non_grad = cropped_outputs[:, :, :, 1 + index_offset:2 + index_offset]
            # (N, raster_size, raster_size, 1), [0.0-BG, 1.0-stroke]

            curr_patch_inputs = 1.0 - curr_patch_inputs  # [0.0-stroke, 1.0-BG]
            curr_patch_inputs = self.normalize_image_m1to1(curr_patch_inputs)
            # (N, raster_size, raster_size, 1/3), [-1.0-stroke, 1.0-BG]

            # Normalizing image
            curr_patch_canvas_hard_non_grad = 1.0 - curr_patch_canvas_hard_non_grad  # [0.0-stroke, 1.0-BG]
            curr_patch_canvas_hard_non_grad = self.normalize_image_m1to1(curr_patch_canvas_hard_non_grad)  # [-1.0-stroke, 1.0-BG]

            ## image-level encoding
            combined_z = self.build_combined_encoder(
                curr_patch_canvas_hard_non_grad,
                curr_patch_inputs,
                1.0 - curr_canvas_hard_non_grad,
                input_photo,
                cursor_position_non_grad,
                image_size,
                curr_window_size)  # (N, z_size)
            combined_z = tf.expand_dims(combined_z, axis=1)  # (N, 1, z_size)

            curr_window_size_top_side_norm_non_grad = \
                tf.stop_gradient(curr_window_size / tf.cast(image_size, tf.float32))
            curr_window_size_bottom_side_norm_non_grad = \
                tf.stop_gradient(curr_window_size / tf.cast(self.hps.min_window_size, tf.float32))
            if not self.hps.concat_win_size:
                combined_z = tf.concat([tf.stop_gradient(prev_width), combined_z], 2)  # (N, 1, 2+z_size)
            else:
                combined_z = tf.concat([tf.stop_gradient(prev_width),
                                        curr_window_size_top_side_norm_non_grad,
                                        curr_window_size_bottom_side_norm_non_grad,
                                        combined_z],
                                       2)  # (N, 1, 2+z_size)

            if self.hps.concat_cursor:
                prev_input_x = tf.concat([cursor_position_non_grad, combined_z], 2)  # (N, 1, 2+2+z_size)
            else:
                prev_input_x = combined_z  # (N, 1, 2+z_size)

            h_output, next_state = self.build_seq_decoder(self.dec_cell, prev_input_x, prev_state)
            # h_output: (N * 1, n_out), next_state: (N, dec_rnn_size * 3)
            [o_other_params, o_pen_ras] = self.get_mixture_coef(h_output)
            # o_other_params: (N * 1, 6)
            # o_pen_ras: (N * 1, 2), after softmax

            o_other_params = tf.reshape(o_other_params, [-1, 1, 6])  # (N, 1, 6)
            o_pen_ras_raw = tf.reshape(o_pen_ras, [-1, 1, 2])  # (N, 1, 2)

            other_params_list.append(o_other_params)
            pen_ras_list.append(o_pen_ras_raw)

            #### sampling part - end ####

            if self.hps.model_mode == 'train' or self.hps.model_mode == 'eval' or self.hps.model_mode == 'eval_sample':
                # use renderer here to convert the strokes to image
                curr_other_params = tf.squeeze(o_other_params, axis=1)  # (N, 6), (x1, y1)=[0.0, 1.0], (x2, y2)=[-1.0, 1.0]
                x1y1, x2y2, width2, scaling = curr_other_params[:, 0:2], curr_other_params[:, 2:4],\
                    curr_other_params[:, 4:5], curr_other_params[:, 5:6]
                x0y0 = tf.zeros_like(x2y2)  # (N, 2), [-1.0, 1.0]
                x0y0 = tf.div(tf.add(x0y0, 1.0), 2.0)  # (N, 2), [0.0, 1.0]
                x2y2 = tf.div(tf.add(x2y2, 1.0), 2.0)  # (N, 2), [0.0, 1.0]
                widths = tf.concat([tf.squeeze(prev_width, axis=1), width2], axis=1)  # (N, 2)
                curr_other_params = tf.concat([x0y0, x1y1, x2y2, widths], axis=-1)
                # (N, 8), (x0, y0)&(x2, y2)=[0.0, 1.0]
                curr_stroke_image = rasterizor_st.raster_func_stroke_abs(curr_other_params)
                # (N, raster_size, raster_size), [0.0-BG, 1.0-stroke]

                curr_stroke_image_large = pasting_func(cursor_position_loop, curr_stroke_image,
                                                       image_size, curr_window_size,
                                                       is_differentiable=self.hps.pasting_diff)
                # (N, image_size, image_size), [0.0-BG, 1.0-stroke]

                ## soft canvas accumulation (soft/differentiable pen gating)
                if not self.hps.use_softargmax:
                    curr_state_soft = o_pen_ras[:, 1:2]  # (N, 1)
                else:
                    curr_state_soft = self.differentiable_argmax(o_pen_ras, self.hps.soft_beta)  # (N, 1)

                curr_state_soft = tf.expand_dims(curr_state_soft, axis=1)  # (N, 1, 1)

                filter_curr_stroke_image_soft = tf.multiply(tf.subtract(1.0, curr_state_soft),
                                                            curr_stroke_image_large)
                # (N, image_size, image_size), [0.0-BG, 1.0-stroke]
                curr_canvas_soft = tf.add(curr_canvas_soft, filter_curr_stroke_image_soft)  # [0.0-BG, 1.0-stroke]

                ## hard canvas accumulation (argmax pen gating, no gradient)
                curr_state_hard = tf.expand_dims(tf.cast(tf.argmax(o_pen_ras_raw, axis=-1),
                                                         dtype=tf.float32), axis=-1)  # (N, 1, 1)
                filter_curr_stroke_image_hard = tf.multiply(tf.subtract(1.0, curr_state_hard),
                                                            curr_stroke_image_large)
                # (N, image_size, image_size), [0.0-BG, 1.0-stroke]
                self.curr_canvas_hard = tf.add(self.curr_canvas_hard, filter_curr_stroke_image_hard)  # [0.0-BG, 1.0-stroke]
                self.curr_canvas_hard = tf.clip_by_value(self.curr_canvas_hard, 0.0, 1.0)  # [0.0-BG, 1.0-stroke]

            next_width = o_other_params[:, :, 4:5]
            next_scaling = o_other_params[:, :, 5:6]
            next_window_size = tf.multiply(next_scaling, tf.stop_gradient(curr_window_size))  # float, with grad

            window_size_before_max_min = next_window_size  # (N, 1, 1), large-level
            win_size_before_max_min_list.append(window_size_before_max_min)

            next_window_size = tf.maximum(next_window_size, tf.cast(self.hps.min_window_size, tf.float32))
            next_window_size = tf.minimum(next_window_size, tf.cast(image_size, tf.float32))

            prev_state = next_state
            # width is stored window-relative, so rescale when the window changes
            prev_width = next_width * curr_window_size / next_window_size  # (N, 1, 1)
            prev_scaling = next_scaling  # (N, 1, 1)
            prev_window_size = curr_window_size

            # update the cursor position
            new_cursor_offsets = tf.multiply(o_other_params[:, :, 2:4],
                                             tf.divide(curr_window_size, 2.0))  # (N, 1, 2), window-level
            new_cursor_offset_next = new_cursor_offsets

            # swap to (y, x) ordering
            new_cursor_offset_next = tf.concat([new_cursor_offset_next[:, :, 1:2],
                                                new_cursor_offset_next[:, :, 0:1]], axis=-1)

            cursor_position_loop_large = tf.multiply(cursor_position_loop, tf.cast(image_size, tf.float32))

            if self.hps.stop_accu_grad:
                stroke_position_next = tf.stop_gradient(cursor_position_loop_large) + new_cursor_offset_next  # (N, 1, 2), large-level
            else:
                stroke_position_next = cursor_position_loop_large + new_cursor_offset_next  # (N, 1, 2), large-level

            stroke_position_before_max_min = stroke_position_next  # (N, 1, 2), large-level
            pos_before_max_min_list.append(stroke_position_before_max_min)

            if self.hps.cursor_type == 'next':
                cursor_position_loop_large = stroke_position_next  # (N, 1, 2), large-level
            else:
                raise Exception('Unknown cursor_type')

            # clamp the cursor into the image, then renormalize to [0.0, 1.0)
            cursor_position_loop_large = tf.maximum(cursor_position_loop_large, 0.0)
            cursor_position_loop_large = tf.minimum(cursor_position_loop_large,
                                                    tf.cast(image_size - 1, tf.float32))
            cursor_position_loop = tf.div(cursor_position_loop_large, tf.cast(image_size, tf.float32))

        curr_canvas_soft = tf.clip_by_value(curr_canvas_soft, 0.0, 1.0)  # (N, raster_size, raster_size), [0.0-BG, 1.0-stroke]

        other_params_ = tf.reshape(tf.concat(other_params_list, axis=1), [-1, 6])  # (N * max_seq_len, 6)
        pen_ras_ = tf.reshape(tf.concat(pen_ras_list, axis=1), [-1, 2])  # (N * max_seq_len, 2)
        pos_before_max_min_ = tf.concat(pos_before_max_min_list, axis=1)  # (N, max_seq_len, 2)
        win_size_before_max_min_ = tf.concat(win_size_before_max_min_list, axis=1)  # (N, max_seq_len, 1)

        return other_params_, pen_ras_, prev_state, curr_canvas_soft, curr_canvas_soft_rgb, \
            pos_before_max_min_, win_size_before_max_min_

    def differentiable_argmax(self, input_pen, soft_beta):
        """
        Differentiable argmax trick.
        :param input_pen: (N, n_class)
        :return: pen_state: (N, 1)
        """
        def sign_onehot(x):
            """
            :param x: (N, n_class)
            :return: (N, n_class)
            """
            y = tf.sign(tf.reduce_max(x, axis=-1, keepdims=True) - x)
            y = (y - 1) * (-1)
            return y

        def softargmax(x, beta=1e2):
            """
            :param x: (N, n_class)
            :param beta: 1e10 is the best. 1e2 is acceptable.
            :return: (N)
            """
            x_range = tf.cumsum(tf.ones_like(x), axis=1)  # (N, 2)
            return tf.reduce_sum(tf.nn.softmax(x * beta) * x_range, axis=1) - 1

        ## Better to use softargmax(beta=1e2). The sign_onehot's gradient is close to zero.
        # pen_onehot = sign_onehot(input_pen)  # one-hot form, (N * max_seq_len, 2)
        # pen_state = pen_onehot[:, 1:2]  # (N * max_seq_len, 1)
        pen_state = softargmax(input_pen, soft_beta)
        pen_state = tf.expand_dims(pen_state, axis=1)  # (N * max_seq_len, 1)
        return pen_state

    def build_losses(self, target_sketch, pred_raster_imgs, pred_params,
                     pos_before_max_min, win_size_before_max_min, image_size):
        """Assemble all training losses.

        :param target_sketch: (N, H, W, 1) ground-truth raster, [0.0-stroke, 1.0-BG].
        :param pred_raster_imgs: (N, H, W) predicted raster.
        :param pred_params: predicted stroke parameters (channel 0 = pen/ending state).
        :param pos_before_max_min: (N, max_seq_len, 2) cursor positions before clamping.
        :param win_size_before_max_min: (N, max_seq_len, 1) window sizes before clamping.
        :param image_size: side length of the full canvas.
        :return: (raster, stroke-num, pos-outside, win-size-outside, early-pen) costs
                 plus the raw/normalized per-layer perceptual losses.
        """
        def get_raster_loss(pred_imgs, gt_imgs, loss_type):
            # Pixel/perceptual loss between predicted and ground-truth rasters.
            perc_layer_losses_raw = []
            perc_layer_losses_weighted = []
            perc_layer_losses_norm = []

            if loss_type == 'l1':
                ras_cost = tf.reduce_mean(tf.abs(tf.subtract(gt_imgs, pred_imgs)))  # ()
            elif loss_type == 'l1_small':
                gt_imgs_small = tf.image.resize_images(tf.expand_dims(gt_imgs, axis=3), (32, 32))
                pred_imgs_small = tf.image.resize_images(tf.expand_dims(pred_imgs, axis=3), (32, 32))
                ras_cost = tf.reduce_mean(tf.abs(tf.subtract(gt_imgs_small, pred_imgs_small)))  # ()
            elif loss_type == 'mse':
                ras_cost = tf.reduce_mean(tf.pow(tf.subtract(gt_imgs, pred_imgs), 2))  # ()
            elif loss_type == 'perceptual':
                return_map_pred = vgg_net_slim(pred_imgs, image_size)
                return_map_gt = vgg_net_slim(gt_imgs, image_size)
                perc_loss_type = 'l1'  # [l1, mse]

                weighted_map = {'ReLU1_1': 100.0, 'ReLU1_2': 100.0,
                                'ReLU2_1': 100.0, 'ReLU2_2': 100.0,
                                'ReLU3_1': 10.0, 'ReLU3_2': 10.0, 'ReLU3_3': 10.0,
                                'ReLU4_1': 1.0, 'ReLU4_2': 1.0, 'ReLU4_3': 1.0,
                                'ReLU5_1': 1.0, 'ReLU5_2': 1.0, 'ReLU5_3': 1.0}

                for perc_layer in self.hps.perc_loss_layers:
                    if perc_loss_type == 'l1':
                        perc_layer_loss = tf.reduce_mean(
                            tf.abs(tf.subtract(return_map_pred[perc_layer], return_map_gt[perc_layer])))  # ()
                    elif perc_loss_type == 'mse':
                        perc_layer_loss = tf.reduce_mean(
                            tf.pow(tf.subtract(return_map_pred[perc_layer], return_map_gt[perc_layer]), 2))  # ()
                    else:
                        raise NameError('Unknown perceptual loss type:', perc_loss_type)
                    perc_layer_losses_raw.append(perc_layer_loss)
                    assert perc_layer in weighted_map
                    perc_layer_losses_weighted.append(perc_layer_loss * weighted_map[perc_layer])

                for loop_i in range(len(self.hps.perc_loss_layers)):
                    perc_relu_loss_raw = perc_layer_losses_raw[loop_i]  # ()
                    if self.hps.model_mode == 'train':
                        # normalize each layer by its running-mean loss magnitude
                        curr_relu_mean = (self.perc_loss_mean_list[loop_i] * self.last_step_num
                                          + perc_relu_loss_raw) / (self.last_step_num + 1.0)
                        relu_cost_norm = perc_relu_loss_raw / curr_relu_mean
                    else:
                        relu_cost_norm = perc_relu_loss_raw
                    perc_layer_losses_norm.append(relu_cost_norm)

                perc_layer_losses_raw = tf.stack(perc_layer_losses_raw, axis=0)
                perc_layer_losses_norm = tf.stack(perc_layer_losses_norm, axis=0)

                if self.hps.perc_loss_fuse_type == 'max':
                    ras_cost = tf.reduce_max(perc_layer_losses_norm)
                elif self.hps.perc_loss_fuse_type == 'add':
                    ras_cost = tf.reduce_mean(perc_layer_losses_norm)
                elif self.hps.perc_loss_fuse_type == 'raw_add':
                    ras_cost = tf.reduce_mean(perc_layer_losses_raw)
                elif self.hps.perc_loss_fuse_type == 'weighted_sum':
                    ras_cost = tf.reduce_mean(perc_layer_losses_weighted)
                else:
                    raise NameError('Unknown perc_loss_fuse_type:', self.hps.perc_loss_fuse_type)
            elif loss_type == 'triplet':
                raise Exception('Solution for triplet loss is coming soon.')
            else:
                raise NameError('Unknown loss type:', loss_type)

            if loss_type != 'perceptual':
                # keep the return arity identical across loss types
                for perc_layer_i in self.hps.perc_loss_layers:
                    perc_layer_losses_raw.append(tf.constant(0.0))
                    perc_layer_losses_norm.append(tf.constant(0.0))
                perc_layer_losses_raw = tf.stack(perc_layer_losses_raw, axis=0)
                perc_layer_losses_norm = tf.stack(perc_layer_losses_norm, axis=0)

            return ras_cost, perc_layer_losses_raw, perc_layer_losses_norm

        gt_raster_images = tf.squeeze(target_sketch, axis=3)  # (N, raster_h, raster_w), [0.0-stroke, 1.0-BG]
        raster_cost, perc_relu_losses_raw, perc_relu_losses_norm = \
            get_raster_loss(pred_raster_imgs, gt_raster_images, loss_type=self.hps.raster_loss_base_type)

        def get_stroke_num_loss(input_strokes):
            # Encourage early pen-up (ending) states, i.e. fewer strokes.
            ending_state = input_strokes[:, :, 0]  # (N, seq_len)
            stroke_num_loss_pre = tf.reduce_mean(ending_state)  # larger is better, [0.0, 1.0]
            stroke_num_loss = 1.0 - stroke_num_loss_pre  # lower is better, [0.0, 1.0]
            return stroke_num_loss

        stroke_num_cost = get_stroke_num_loss(pred_params)  # lower is better

        def get_pos_outside_loss(pos_before_max_min_):
            # Penalize cursor positions predicted outside the canvas.
            pos_after_max_min = tf.maximum(pos_before_max_min_, 0.0)
            pos_after_max_min = tf.minimum(pos_after_max_min,
                                           tf.cast(image_size - 1, tf.float32))  # (N, max_seq_len, 2)
            pos_outside_loss = tf.reduce_mean(tf.abs(pos_before_max_min_ - pos_after_max_min))
            return pos_outside_loss

        pos_outside_cost = get_pos_outside_loss(pos_before_max_min)  # lower is better

        def get_win_size_outside_loss(win_size_before_max_min_, min_window_size):
            # Penalize window sizes predicted outside [min_window_size, image_size].
            win_size_outside_top_loss = tf.divide(
                tf.maximum(win_size_before_max_min_ - tf.cast(image_size, tf.float32), 0.0),
                tf.cast(image_size, tf.float32))  # (N, max_seq_len, 1)
            win_size_outside_bottom_loss = tf.divide(
                tf.maximum(tf.cast(min_window_size, tf.float32) - win_size_before_max_min_, 0.0),
                tf.cast(min_window_size, tf.float32))  # (N, max_seq_len, 1)
            win_size_outside_loss = tf.reduce_mean(win_size_outside_top_loss + win_size_outside_bottom_loss)
            return win_size_outside_loss

        win_size_outside_cost = get_win_size_outside_loss(win_size_before_max_min,
                                                          self.hps.min_window_size)  # lower is better

        def get_early_pen_states_loss(input_strokes, curr_start, curr_end):
            # input_strokes: (N, max_seq_len, 7)
            # Discourage ending the drawing within the early steps of the sequence.
            pred_early_pen_states = input_strokes[:, curr_start:curr_end, 0]  # (N, curr_early_len)
            pred_early_pen_states_min = tf.reduce_min(pred_early_pen_states, axis=1)  # (N), should not be 1
            early_pen_states_loss = tf.reduce_mean(pred_early_pen_states_min)  # lower is better
            return early_pen_states_loss

        early_pen_states_cost = get_early_pen_states_loss(pred_params,
                                                          self.early_pen_loss_start_idx,
                                                          self.early_pen_loss_end_idx)

        return raster_cost, stroke_num_cost, pos_outside_cost, win_size_outside_cost, \
            early_pen_states_cost, \
            perc_relu_losses_raw, perc_relu_losses_norm

    def build_training_op_split(self, raster_cost, sn_cost, cursor_outside_cost,
                                win_size_outside_cost, early_pen_states_cost):
        """Weight and sum the individual costs, then compute per-tower gradients.

        The neural renderer ('raster_unit') and VGG16 perceptual net are frozen:
        their variables are excluded from the trainable set.
        :return: (total_cost, gradient-variable pairs).
        """
        total_cost = self.hps.raster_loss_weight * raster_cost + \
            self.hps.early_pen_loss_weight * early_pen_states_cost + \
            self.stroke_num_loss_weight * sn_cost + \
            self.hps.outside_loss_weight * cursor_outside_cost + \
            self.hps.win_size_outside_loss_weight * win_size_outside_cost
        tvars = [var for var in tf.trainable_variables()
                 if 'raster_unit' not in var.op.name and 'VGG16' not in var.op.name]
        gvs = self.optimizer.compute_gradients(total_cost, var_list=tvars)
        return total_cost, gvs

    def build_training_op(self, grad_list):
        """Average multi-tower gradients, clip them and create the apply-gradients op."""
        with tf.variable_scope('train_op', reuse=tf.AUTO_REUSE):
            gvs = self.average_gradients(grad_list)
            g = self.hps.grad_clip
            for grad, var in gvs:
                print('>>', var.op.name)
                if grad is None:
                    print(' >> None value')
            capped_gvs = [(tf.clip_by_value(grad, -g, g), var) for grad, var in gvs]
            self.train_op = self.optimizer.apply_gradients(
                capped_gvs, global_step=self.global_step, name='train_step')

    def average_gradients(self, grads_list):
        """
        Compute the average gradients.
        :param grads_list: list(of length N_GPU) of list(grad, var)
        :return: the per-variable gradients averaged across towers.
        """
        avg_grads = []

        for grad_and_vars in zip(*grads_list):
            grads = []
            for g, _ in grad_and_vars:
                expanded_g = tf.expand_dims(g, 0)
                grads.append(expanded_g)

            grad = tf.concat(grads, axis=0)
            grad = tf.reduce_mean(grad, axis=0)

            # all towers share the same variable; take it from the first tower
            v = grad_and_vars[0][1]
            grad_and_var = (grad, v)
            avg_grads.append(grad_and_var)

        return avg_grads



================================================
FILE: rasterization_utils/NeuralRenderer.py
================================================
import tensorflow as tf


class RasterUnit(object):
    """Pretrained neural renderer: maps 10 stroke parameters to a raster stroke image."""

    def __init__(self,
                 raster_size,
                 input_params,  # (N, 10)
                 reuse=False):
        # raster_size: output image side length
        # input_params: (N, 10) stroke parameter batch
        self.raster_size = raster_size
        self.input_params = input_params

        with tf.variable_scope("raster_unit", reuse=reuse):
            self.build_unit()

    def build_unit(self):
        """Build the FC + pixel-shuffle upsampling network; sets self.stroke_image.

        Output: (N, raster_size, raster_size), [0.0-BG, 1.0-stroke].
        NOTE(review): the architecture upsamples 16x16 -> 128x128, so raster_size is
        presumably 128 — confirm against the callers' hyper-parameters.
        """
        x = self.input_params  # (N, 10)

        x = self.fully_connected(x, 10, 512, scope='fc1')  # (N, 512)
        x = tf.nn.relu(x)
        x = self.fully_connected(x, 512, 1024, scope='fc2')  # (N, 1024)
        x = tf.nn.relu(x)
        x = self.fully_connected(x, 1024, 2048, scope='fc3')  # (N, 2048)
        x = tf.nn.relu(x)
        x = self.fully_connected(x, 2048, 4096, scope='fc4')  # (N, 4096)
        x = tf.nn.relu(x)

        x = tf.reshape(x, (-1, 16, 16, 16))  # (N, 16, 16, 16)
        x = tf.transpose(x, (0, 2, 3, 1))

        x = self.conv2d(x, 32, 3, 1, scope='conv1')  # (N, 16, 16, 32)
        x = tf.nn.relu(x)
        x = self.conv2d(x, 32, 3, 1, scope='conv2')  # (N, 16, 16, 32)
        x = self.pixel_shuffle(x, upscale_factor=2)  # (N, 32, 32, 8)

        x = self.conv2d(x, 16, 3, 1, scope='conv3')  # (N, 32, 32, 16)
        x = tf.nn.relu(x)
        x = self.conv2d(x, 16, 3, 1, scope='conv4')  # (N, 32, 32, 16)
        x = self.pixel_shuffle(x, upscale_factor=2)  # (N, 64, 64, 4)

        x = self.conv2d(x, 8, 3, 1, scope='conv5')  # (N, 64, 64, 8)
        x = tf.nn.relu(x)
        x = self.conv2d(x, 4, 3, 1, scope='conv6')  # (N, 64, 64, 4)
        x = self.pixel_shuffle(x, upscale_factor=2)  # (N, 128, 128, 1)

        x = tf.sigmoid(x)  # (N, 128, 128), [0.0-stroke, 1.0-BG]
        # invert so strokes are 1.0 on a 0.0 background
        self.stroke_image = 1.0 - tf.reshape(x, (-1, self.raster_size, self.raster_size))
def conv2d(self, input_tensor, out_channels, kernel_size, stride, scope, reuse=False): with tf.variable_scope(scope, reuse=reuse): output_tensor = tf.layers.conv2d(input_tensor, out_channels, kernel_size=kernel_size, strides=(stride, stride), padding="same", kernel_initializer=tf.keras.initializers.he_normal()) return output_tensor def fully_connected(self, input_tensor, in_dim, out_dim, scope, reuse=False): with tf.variable_scope(scope, reuse=reuse): weight = tf.get_variable("weight", [in_dim, out_dim], dtype=tf.float32, initializer=tf.random_normal_initializer()) bias = tf.get_variable("bias", [out_dim], dtype=tf.float32, initializer=tf.random_normal_initializer()) output_tensor = tf.matmul(input_tensor, weight) + bias return output_tensor def pixel_shuffle(self, input_tensor, upscale_factor): params_shape = input_tensor.get_shape() n, h, w, c = params_shape input_tensor_proc = tf.reshape(input_tensor, (n, h, w, c // 4, 4)) input_tensor_proc = tf.transpose(input_tensor_proc, (0, 1, 2, 4, 3)) input_tensor_proc = tf.reshape(input_tensor_proc, (n, h, w, -1)) output_tensor = tf.depth_to_space(input_tensor_proc, block_size=upscale_factor) return output_tensor class NeuralRasterizor(object): def __init__(self, raster_size, seq_len, position_format='abs', raster_padding=10, strokes_format=3): self.raster_size = raster_size self.seq_len = seq_len self.position_format = position_format self.raster_padding = raster_padding self.strokes_format = strokes_format assert position_format in ['abs', 'rel'] def raster_func_abs(self, input_data, raster_seq_len=None): """ x and y in absolute position. :param input_data: (N, seq_len, 10): [x0, y0, x1, y1, x2, y2, r0, r2, w0, w2]. 
All in [0.0, 1.0] :return: """ seq_len = raster_seq_len if raster_seq_len is not None else self.seq_len raster_params = tf.transpose(input_data, [1, 0, 2]) # (seq_len, N, 10) seq_stroke_images = tf.map_fn(self.stroke_drawer_with_raster_unit, raster_params, parallel_iterations=32) # (seq_len, N, raster_size, raster_size) seq_stroke_images = tf.transpose(seq_stroke_images, [1, 2, 3, 0]) # (N, raster_size, raster_size, seq_len), [0.0-stroke, 1.0-BG] filter_seq_stroke_images = 1.0 - seq_stroke_images # (N, raster_size, raster_size, seq_len), [0.0-BG, 1.0-stroke] # stacking stroke_images_unclip = tf.reduce_sum(filter_seq_stroke_images, axis=-1) # (N, raster_size, raster_size) stroke_images = tf.clip_by_value(stroke_images_unclip, 0.0, 1.0) # [0.0-BG, 1.0-stroke] return stroke_images, stroke_images_unclip, seq_stroke_images def stroke_drawer_with_raster_unit(self, params_batch): """ Convert two points into a raster stroke image with RasterUnit. :param params_batch: (N, 10) :return: (N, raster_size, raster_size) """ raster_unit = RasterUnit( raster_size=self.raster_size, input_params=params_batch, reuse=tf.AUTO_REUSE ) stroke_image = raster_unit.stroke_image # (N, raster_size, raster_size), [0.0-stroke, 1.0-BG] return stroke_image class NeuralRasterizorStep(object): def __init__(self, raster_size, position_format='abs'): self.raster_size = raster_size self.position_format = position_format assert position_format in ['abs', 'rel'] def raster_func_stroke_abs(self, input_data): """ x and y in absolute position. :param input_data: (N, 8): [x0, y0, x1, y1, x2, y2, r0, r2]. 
All in [0.0, 1.0] :return: """ w_in = tf.ones(shape=(input_data.shape[0], 2), dtype=tf.float32) raster_params = tf.concat([input_data, w_in], axis=-1) # (N, 10) stroke_image = self.stroke_drawer_with_raster_unit(raster_params) # (N, raster_size, raster_size), [0.0-stroke, 1.0-BG] stroke_image = 1.0 - stroke_image # [0.0-BG, 1.0-stroke] return stroke_image def mask_ending_state(self, input_states): """ Mask the ending state to be 1 :param input_states: (N, seq_len, 1) in offset manner :param seq_len: :return: """ ending_state_accu = tf.cumsum(input_states, axis=1) # (N, seq_len, 1) ending_state_clip = tf.clip_by_value(ending_state_accu, 0.0, 1.0) # (N, seq_len, 1) return ending_state_clip def stroke_drawer_with_raster_unit(self, params_batch): """ Convert two points into a raster stroke image with RasterUnit. :param params_batch: (N, 10) :return: (N, raster_size, raster_size) """ raster_unit = RasterUnit( raster_size=self.raster_size, input_params=params_batch, reuse=tf.AUTO_REUSE ) stroke_image = raster_unit.stroke_image # (N, raster_size, raster_size), [0.0-stroke, 1.0-BG] return stroke_image ================================================ FILE: rasterization_utils/RealRenderer.py ================================================ import numpy as np import gizeh class GizehRasterizor(object): def __init__(self): self.name = 'GizehRasterizor' def get_line_array_v2(self, image_size, seq_strokes, stroke_width, is_bin=True): """ :param p1: (x, y) :param p2: (x, y) :return: line_arr: (image_size, image_size), {0, 1}, 0 for BG and 1 for strokes """ surface = gizeh.Surface(width=image_size, height=image_size) # in pixels shape_list = [] for seq_i in range(len(seq_strokes) - 1): p1, p2 = seq_strokes[seq_i, :2], seq_strokes[seq_i + 1, :2] pen_state = seq_strokes[seq_i, 2] if pen_state == 0.0: line = gizeh.polyline(points=[p1, p2], stroke_width=stroke_width, stroke=(1, 1, 1), fill=(0, 0, 0)) shape_list.append(line) group = gizeh.Group(shape_list) group.draw(surface) # Now 
        line_arr = surface.get_npimage()[:, :, 0]  # get_npimage() returns (height, width, 3); keep one channel
        if is_bin:
            # Threshold into a {0, 1} binary mask.
            line_arr[line_arr <= 128] = 0
            line_arr[line_arr != 0] = 1  # (image_size, image_size)
        else:
            line_arr = np.array(line_arr, dtype=np.float32) / 255.0
        return line_arr

    def get_line_array(self, p1, p2, image_size, stroke_width, is_bin=True):
        """
        Rasterize a single line segment.
        :param p1: (x, y)
        :param p2: (x, y)
        :return: line_arr: (image_size, image_size), {0, 1}, 0 for BG and 1 for strokes
        """
        surface = gizeh.Surface(width=image_size, height=image_size)  # in pixels
        line = gizeh.polyline(points=[p1, p2], stroke_width=stroke_width,
                              stroke=(1, 1, 1), fill=(0, 0, 0))
        line.draw(surface)

        # Now export the surface
        line_arr = surface.get_npimage()[:, :, 0]  # get_npimage() returns (height, width, 3); keep one channel
        if is_bin:
            line_arr[line_arr <= 128] = 0
            line_arr[line_arr != 0] = 1  # (image_size, image_size)
        else:
            line_arr = np.array(line_arr, dtype=np.float32) / 255.0
        return line_arr

    def load_sketch_images_on_the_fly_v2(self, image_size, norm_strokes3, stroke_width, is_bin=True):
        """
        Rasterize each sketch as one gizeh group (faster than per-segment drawing).
        :param norm_strokes3: list (N_sketches,), each with (N_points, 3)
        :return: list (N_sketches,), each with (raster_size, raster_size), 0-BG and 1-strokes
        """
        assert type(norm_strokes3) is list
        sketch_imgs_list = []
        for stroke_i in range(len(norm_strokes3)):
            seq_strokes3 = norm_strokes3[stroke_i]  # (N_points, 3)
            sketch_img = self.get_line_array_v2(image_size, seq_strokes3,
                                                stroke_width=stroke_width, is_bin=is_bin)
            sketch_img = np.clip(sketch_img, 0.0, 1.0)  # (image_size, image_size), 0 for BG and 1 for strokes
            sketch_imgs_list.append(sketch_img)
        return sketch_imgs_list

    def load_sketch_images_on_the_fly(self, image_size, norm_strokes3, stroke_width, is_bin=True):
        """
        Rasterize each sketch segment-by-segment and stack the results (v1 path).
        :param norm_strokes3: list (N_sketches,), each with (N_points, 3)
        :return: list (N_sketches,), each with (raster_size, raster_size), 0-BG and 1-strokes
        """
        assert type(norm_strokes3) is list
        sketch_imgs_list = []
        for stroke_i in range(len(norm_strokes3)):
            seq_strokes3 = norm_strokes3[stroke_i]  # (N_points, 3)
            seq_len = len(seq_strokes3)
            # NOTE(review): a single-point sequence (seq_len == 1) leaves this
            # list empty and np.stack below would raise -- presumably callers
            # always pass >= 2 points.
            stroke_imgs_list = []
            for seq_i in range(seq_len - 1):
                stroke_img = self.get_line_array(seq_strokes3[seq_i, :2], seq_strokes3[seq_i + 1, :2],
                                                 image_size, stroke_width=stroke_width, is_bin=is_bin)
                pen_state = seq_strokes3[seq_i, 2]
                # Zero out segments drawn while the pen is lifted (pen_state == 1).
                stroke_img = stroke_img.astype(np.float32) * (1. - pen_state)
                stroke_imgs_list.append(stroke_img)
            stroke_imgs_list = np.stack(stroke_imgs_list, axis=-1)
            # (image_size, image_size, seq_len-1), 0 for BG and 1 for strokes
            stroke_imgs_list = np.sum(stroke_imgs_list, axis=-1)
            stroke_imgs_list = np.clip(stroke_imgs_list, 0.0, 1.0)
            # (image_size, image_size), 0 for BG and 1 for strokes
            sketch_imgs_list.append(stroke_imgs_list)
        return sketch_imgs_list

    def normalize_coordinate_np(self, sx, sy, image_size, raster_padding=10.0):
        """
        Convert offset to normalized absolute points. The numpy version as in NeuralRasterizor.
        Centers the sketch bounding box inside the padded canvas, preserving
        aspect ratio.
        :param sx: (N, seq_len) x offsets
        :param sy: (N, seq_len) y offsets
        :return: (abs_x, abs_y), each (N, seq_len), in pixel coordinates
        """
        seq_len = sx.shape[1]

        # transfer to abs points
        abs_x = np.cumsum(sx, axis=1)  # (N, seq_len)
        abs_y = np.cumsum(sy, axis=1)

        min_x = np.min(abs_x, axis=1, keepdims=True)  # (N, 1)
        max_x = np.max(abs_x, axis=1, keepdims=True)
        min_y = np.min(abs_y, axis=1, keepdims=True)
        max_y = np.max(abs_y, axis=1, keepdims=True)

        # transform to positive coordinate
        abs_x = np.subtract(abs_x, np.tile(min_x, [1, seq_len]))  # (N, seq_len)
        abs_y = np.subtract(abs_y, np.tile(min_y, [1, seq_len]))

        # scaling to [0.0, raster_size - 2 * padding - 1]
        bbox_w = np.squeeze(np.subtract(max_x, min_x), axis=-1)  # (N)
        bbox_h = np.squeeze(np.subtract(max_y, min_y), axis=-1)
        unpad_raster_size = (image_size - 1.0) - 2.0 * raster_padding
        # NOTE(review): a degenerate sketch with zero-size bbox divides by zero here.
        scaling = np.divide(unpad_raster_size, np.maximum(bbox_w, bbox_h))  # (N)
        scaling_tile = np.tile(np.expand_dims(scaling, axis=-1), [1, seq_len])  # (N, seq_len)
        abs_x = np.multiply(abs_x, scaling_tile)  # (N, seq_len)
        abs_y = np.multiply(abs_y, scaling_tile)

        # add padding
        abs_x = np.add(abs_x, raster_padding)  # (N, seq_len)
        abs_y = np.add(abs_y, raster_padding)

        # transform to the middle: center the (scaled) bbox inside the canvas
        trans_x = np.divide(np.subtract(unpad_raster_size, np.multiply(bbox_w, scaling)), 2.0)  # (N)
        trans_y = np.divide(np.subtract(unpad_raster_size, np.multiply(bbox_h, scaling)), 2.0)
        trans_x = np.tile(np.expand_dims(trans_x, axis=-1), [1, seq_len])  # (N, seq_len)
        trans_y = np.tile(np.expand_dims(trans_y, axis=-1), [1, seq_len])  # (N, seq_len)
        abs_x = np.add(abs_x, trans_x)  # (N, seq_len)
        abs_y = np.add(abs_y, trans_y)

        return abs_x, abs_y

    def normalize_strokes_np(self, strokes_list, image_size):
        """
        Normalize each sketch's offsets into absolute pixel coordinates.
        :param strokes_list: list (N_sketches,), each with (N_points, 3)
        :return: list (N_sketches,), each with (N_points, 3); pen states untouched
        """
        assert type(strokes_list) is list
        rst_list = []
        for i in range(len(strokes_list)):
            strokes_data = strokes_list[i]  # (N_points, 3)
            norm_x, norm_y = self.normalize_coordinate_np(np.expand_dims(strokes_data[:, 0], axis=0),
                                                          np.expand_dims(strokes_data[:, 1], axis=0),
                                                          image_size)  # (1, N_points)
            norm_strokes_data = np.stack([norm_x[0], norm_y[0], strokes_data[:, 2]], axis=-1)  # (N_points, 3)
            rst_list.append(norm_strokes_data)
        return rst_list

    def raster_func(self, input_data, image_size, stroke_width, is_bin=True, version='v2'):
        """
        Normalize then rasterize a batch of sketches.
        :param input_data: (N_sketches,), each with (N_points, 3)
        :param version: 'v1' draws segment-by-segment; anything else uses the
            grouped 'v2' path.
        :return: raster_image_array: list (N_sketches,), each with
            (raster_size, raster_size), 0-BG and 1-strokes
        """
        norm_test_strokes3 = self.normalize_strokes_np(input_data, image_size)
        if version == 'v1':
            raster_image_array = self.load_sketch_images_on_the_fly(image_size, norm_test_strokes3,
                                                                    stroke_width, is_bin=is_bin)
        else:
            raster_image_array = self.load_sketch_images_on_the_fly_v2(image_size, norm_test_strokes3,
                                                                       stroke_width, is_bin=is_bin)
        return raster_image_array



================================================
FILE: rnn.py
================================================
# Copyright 2019 The Magenta Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""SketchRNN RNN definition."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf


def orthogonal(shape):
    """Orthogonal initilaizer.

    Returns an orthonormal matrix of the given shape, obtained from the SVD
    of a random Gaussian matrix (rows/columns are orthonormal depending on
    which factor matches the flattened shape).
    """
    flat_shape = (shape[0], np.prod(shape[1:]))
    a = np.random.normal(0.0, 1.0, flat_shape)
    u, _, v = np.linalg.svd(a, full_matrices=False)
    # Pick whichever SVD factor has the requested (flattened) shape.
    q = u if u.shape == flat_shape else v
    return q.reshape(shape)


def orthogonal_initializer(scale=1.0):
    """Orthogonal initializer (tf.get_variable-compatible)."""
    def _initializer(shape, dtype=tf.float32, partition_info=None):  # pylint: disable=unused-argument
        return tf.constant(orthogonal(shape) * scale, dtype)
    return _initializer


def lstm_ortho_initializer(scale=1.0):
    """LSTM orthogonal initializer.

    Initializes each of the four gate sub-matrices (i, g, j, o) with an
    independent orthogonal matrix, rather than one orthogonal matrix for the
    whole [size_x, 4 * size_h] weight.
    """
    def _initializer(shape, dtype=tf.float32, partition_info=None):  # pylint: disable=unused-argument
        size_x = shape[0]
        size_h = shape[1] // 4  # assumes lstm.
        t = np.zeros(shape)
        t[:, :size_h] = orthogonal([size_x, size_h]) * scale
        t[:, size_h:size_h * 2] = orthogonal([size_x, size_h]) * scale
        t[:, size_h * 2:size_h * 3] = orthogonal([size_x, size_h]) * scale
        t[:, size_h * 3:] = orthogonal([size_x, size_h]) * scale
        return tf.constant(t, dtype)
    return _initializer


class LSTMCell(tf.contrib.rnn.RNNCell):
    """Vanilla LSTM cell.

    Uses ortho initializer, and also recurrent dropout without memory loss
    (https://arxiv.org/abs/1603.05118)
    """

    def __init__(self,
                 num_units,
                 forget_bias=1.0,
                 use_recurrent_dropout=False,
                 dropout_keep_prob=0.9):
        self.num_units = num_units
        self.forget_bias = forget_bias
        self.use_recurrent_dropout = use_recurrent_dropout
        self.dropout_keep_prob = dropout_keep_prob

    @property
    def state_size(self):
        # State is [c, h] concatenated along the feature axis.
        return 2 * self.num_units

    @property
    def output_size(self):
        return self.num_units

    def get_output(self, state):
        # State layout for this cell is [c, h]; return h.
        unused_c, h = tf.split(state, 2, 1)
        return h

    def __call__(self, x, state, scope=None):
        with tf.variable_scope(scope or type(self).__name__):
            c, h = tf.split(state, 2, 1)

            x_size = x.get_shape().as_list()[1]

            w_init = None  # uniform

            h_init = lstm_ortho_initializer(1.0)

            # Keep W_xh and W_hh separate here as well to use different init methods.
            w_xh = tf.get_variable(
                'W_xh', [x_size, 4 * self.num_units], initializer=w_init)
            w_hh = tf.get_variable(
                'W_hh', [self.num_units, 4 * self.num_units], initializer=h_init)
            bias = tf.get_variable(
                'bias', [4 * self.num_units],
                initializer=tf.constant_initializer(0.0))

            concat = tf.concat([x, h], 1)
            w_full = tf.concat([w_xh, w_hh], 0)
            hidden = tf.matmul(concat, w_full) + bias

            i, j, f, o = tf.split(hidden, 4, 1)

            if self.use_recurrent_dropout:
                # Recurrent dropout without memory loss: drop only the candidate.
                g = tf.nn.dropout(tf.tanh(j), self.dropout_keep_prob)
            else:
                g = tf.tanh(j)

            new_c = c * tf.sigmoid(f + self.forget_bias) + tf.sigmoid(i) * g
            new_h = tf.tanh(new_c) * tf.sigmoid(o)

        # State is packed back as [c, h] (matches get_output's split order).
        return new_h, tf.concat([new_c, new_h], 1)
def layer_norm_all(h, batch_size, base, num_units, scope='layer_norm', reuse=False, gamma_start=1.0, epsilon=1e-3, use_bias=True): """Layer Norm (faster version, but not using defun).""" # Performs layer norm on multiple base at once (ie, i, g, j, o for lstm) # Reshapes h in to perform layer norm in parallel h_reshape = tf.reshape(h, [batch_size, base, num_units]) mean = tf.reduce_mean(h_reshape, [2], keep_dims=True) var = tf.reduce_mean(tf.square(h_reshape - mean), [2], keep_dims=True) epsilon = tf.constant(epsilon) rstd = tf.rsqrt(var + epsilon) h_reshape = (h_reshape - mean) * rstd # reshape back to original h = tf.reshape(h_reshape, [batch_size, base * num_units]) with tf.variable_scope(scope): if reuse: tf.get_variable_scope().reuse_variables() gamma = tf.get_variable( 'ln_gamma', [4 * num_units], initializer=tf.constant_initializer(gamma_start)) if use_bias: beta = tf.get_variable( 'ln_beta', [4 * num_units], initializer=tf.constant_initializer(0.0)) if use_bias: return gamma * h + beta return gamma * h def layer_norm(x, num_units, scope='layer_norm', reuse=False, gamma_start=1.0, epsilon=1e-3, use_bias=True): """Calculate layer norm.""" axes = [1] mean = tf.reduce_mean(x, axes, keep_dims=True) x_shifted = x - mean var = tf.reduce_mean(tf.square(x_shifted), axes, keep_dims=True) inv_std = tf.rsqrt(var + epsilon) with tf.variable_scope(scope): if reuse: tf.get_variable_scope().reuse_variables() gamma = tf.get_variable( 'ln_gamma', [num_units], initializer=tf.constant_initializer(gamma_start)) if use_bias: beta = tf.get_variable( 'ln_beta', [num_units], initializer=tf.constant_initializer(0.0)) output = gamma * (x_shifted) * inv_std if use_bias: output += beta return output def raw_layer_norm(x, epsilon=1e-3): axes = [1] mean = tf.reduce_mean(x, axes, keep_dims=True) std = tf.sqrt( tf.reduce_mean(tf.square(x - mean), axes, keep_dims=True) + epsilon) output = (x - mean) / (std) return output def super_linear(x, output_size, scope=None, reuse=False, 
init_w='ortho', weight_start=0.0, use_bias=True, bias_start=0.0, input_size=None): """Performs linear operation. Uses ortho init defined earlier.""" shape = x.get_shape().as_list() with tf.variable_scope(scope or 'linear'): if reuse: tf.get_variable_scope().reuse_variables() w_init = None # uniform if input_size is None: x_size = shape[1] else: x_size = input_size if init_w == 'zeros': w_init = tf.constant_initializer(0.0) elif init_w == 'constant': w_init = tf.constant_initializer(weight_start) elif init_w == 'gaussian': w_init = tf.random_normal_initializer(stddev=weight_start) elif init_w == 'ortho': w_init = lstm_ortho_initializer(1.0) w = tf.get_variable( 'super_linear_w', [x_size, output_size], tf.float32, initializer=w_init) if use_bias: b = tf.get_variable( 'super_linear_b', [output_size], tf.float32, initializer=tf.constant_initializer(bias_start)) return tf.matmul(x, w) + b return tf.matmul(x, w) class LayerNormLSTMCell(tf.contrib.rnn.RNNCell): """Layer-Norm, with Ortho Init. and Recurrent Dropout without Memory Loss. https://arxiv.org/abs/1607.06450 - Layer Norm https://arxiv.org/abs/1603.05118 - Recurrent Dropout without Memory Loss """ def __init__(self, num_units, forget_bias=1.0, use_recurrent_dropout=False, dropout_keep_prob=0.90): """Initialize the Layer Norm LSTM cell. Args: num_units: int, The number of units in the LSTM cell. forget_bias: float, The bias added to forget gates (default 1.0). 
use_recurrent_dropout: Whether to use Recurrent Dropout (default False) dropout_keep_prob: float, dropout keep probability (default 0.90) """ self.num_units = num_units self.forget_bias = forget_bias self.use_recurrent_dropout = use_recurrent_dropout self.dropout_keep_prob = dropout_keep_prob @property def input_size(self): return self.num_units @property def output_size(self): return self.num_units @property def state_size(self): return 2 * self.num_units def get_output(self, state): h, unused_c = tf.split(state, 2, 1) return h def __call__(self, x, state, timestep=0, scope=None): with tf.variable_scope(scope or type(self).__name__): h, c = tf.split(state, 2, 1) h_size = self.num_units x_size = x.get_shape().as_list()[1] batch_size = x.get_shape().as_list()[0] w_init = None # uniform h_init = lstm_ortho_initializer(1.0) w_xh = tf.get_variable( 'W_xh', [x_size, 4 * self.num_units], initializer=w_init) w_hh = tf.get_variable( 'W_hh', [self.num_units, 4 * self.num_units], initializer=h_init) concat = tf.concat([x, h], 1) # concat for speed. w_full = tf.concat([w_xh, w_hh], 0) concat = tf.matmul(concat, w_full) # + bias # live life without garbage. # i = input_gate, j = new_input, f = forget_gate, o = output_gate concat = layer_norm_all(concat, batch_size, 4, h_size, 'ln_all') i, j, f, o = tf.split(concat, 4, 1) if self.use_recurrent_dropout: g = tf.nn.dropout(tf.tanh(j), self.dropout_keep_prob) else: g = tf.tanh(j) new_c = c * tf.sigmoid(f + self.forget_bias) + tf.sigmoid(i) * g new_h = tf.tanh(layer_norm(new_c, h_size, 'ln_c')) * tf.sigmoid(o) return new_h, tf.concat([new_h, new_c], 1) class HyperLSTMCell(tf.contrib.rnn.RNNCell): """HyperLSTM with Ortho Init, Layer Norm, Recurrent Dropout, no Memory Loss. 
    https://arxiv.org/abs/1609.09106
    http://blog.otoro.net/2016/09/28/hyper-networks/
    """

    def __init__(self,
                 num_units,
                 forget_bias=1.0,
                 use_recurrent_dropout=False,
                 dropout_keep_prob=0.90,
                 use_layer_norm=True,
                 hyper_num_units=256,
                 hyper_embedding_size=32,
                 hyper_use_recurrent_dropout=False):
        """Initialize the Layer Norm HyperLSTM cell.

        Args:
            num_units: int, The number of units in the LSTM cell.
            forget_bias: float, The bias added to forget gates (default 1.0).
            use_recurrent_dropout: Whether to use Recurrent Dropout (default False)
            dropout_keep_prob: float, dropout keep probability (default 0.90)
            use_layer_norm: boolean. (default True)
                Controls whether we use LayerNorm layers in main LSTM & HyperLSTM cell.
            hyper_num_units: int, number of units in HyperLSTM cell.
                (the signature default here is 256; the upstream docstring said 128)
            hyper_embedding_size: int, size of signals emitted from HyperLSTM cell.
                (the signature default here is 32; the upstream docstring said 16)
            hyper_use_recurrent_dropout: boolean. (default False)
                Controls whether HyperLSTM cell also uses recurrent dropout.
                Recommend turning this on only if hyper_num_units becomes large (>= 512)
        """
        self.num_units = num_units
        self.forget_bias = forget_bias
        self.use_recurrent_dropout = use_recurrent_dropout
        self.dropout_keep_prob = dropout_keep_prob
        self.use_layer_norm = use_layer_norm
        self.hyper_num_units = hyper_num_units
        self.hyper_embedding_size = hyper_embedding_size
        self.hyper_use_recurrent_dropout = hyper_use_recurrent_dropout

        # Combined state width: main cell units + hyper cell units.
        self.total_num_units = self.num_units + self.hyper_num_units

        # NOTE(review): LayerNormLSTMCell packs state as [h, c] while LSTMCell
        # packs it as [c, h]; this class feeds the hyper cell an [h, c]-ordered
        # state, so with use_layer_norm=False the orders look inconsistent --
        # confirm before using the non-layer-norm path.
        if self.use_layer_norm:
            cell_fn = LayerNormLSTMCell
        else:
            cell_fn = LSTMCell
        self.hyper_cell = cell_fn(
            hyper_num_units,
            use_recurrent_dropout=hyper_use_recurrent_dropout,
            dropout_keep_prob=dropout_keep_prob)

    @property
    def input_size(self):
        return self._input_size

    @property
    def output_size(self):
        return self.num_units

    @property
    def state_size(self):
        return 2 * self.total_num_units

    def get_output(self, state):
        total_h, unused_total_c = tf.split(state, 2, 1)
        h = total_h[:, 0:self.num_units]
        return h

    def hyper_norm(self, layer, scope='hyper', use_bias=True):
        """Modulate a gate pre-activation with scales/biases emitted by the hyper cell."""
        num_units = self.num_units
        embedding_size = self.hyper_embedding_size
        # recurrent batch norm init trick (https://arxiv.org/abs/1603.09025).
        init_gamma = 0.10  # cooijmans' da man.
        with tf.variable_scope(scope):
            zw = super_linear(
                self.hyper_output,
                embedding_size,
                init_w='constant',
                weight_start=0.00,
                use_bias=True,
                bias_start=1.0,
                scope='zw')
            alpha = super_linear(
                zw,
                num_units,
                init_w='constant',
                weight_start=init_gamma / embedding_size,
                use_bias=False,
                scope='alpha')
            result = tf.multiply(alpha, layer)
            if use_bias:
                zb = super_linear(
                    self.hyper_output,
                    embedding_size,
                    init_w='gaussian',
                    weight_start=0.01,
                    use_bias=False,
                    bias_start=0.0,
                    scope='zb')
                beta = super_linear(
                    zb,
                    num_units,
                    init_w='constant',
                    weight_start=0.00,
                    use_bias=False,
                    scope='beta')
                result += beta
        return result

    def __call__(self, x, state, timestep=0, scope=None):
        with tf.variable_scope(scope or type(self).__name__):
            # State layout: first num_units of h/c belong to the main cell,
            # the remainder to the hyper cell.
            total_h, total_c = tf.split(state, 2, 1)
            h = total_h[:, 0:self.num_units]
            c = total_c[:, 0:self.num_units]
            self.hyper_state = tf.concat(
                [total_h[:, self.num_units:], total_c[:, self.num_units:]], 1)

            batch_size = x.get_shape().as_list()[0]
            x_size = x.get_shape().as_list()[1]
            self._input_size = x_size

            w_init = None  # uniform

            h_init = lstm_ortho_initializer(1.0)

            w_xh = tf.get_variable(
                'W_xh', [x_size, 4 * self.num_units], initializer=w_init)
            w_hh = tf.get_variable(
                'W_hh', [self.num_units, 4 * self.num_units], initializer=h_init)
            bias = tf.get_variable(
                'bias', [4 * self.num_units],
                initializer=tf.constant_initializer(0.0))

            # concatenate the input and hidden states for hyperlstm input
            hyper_input = tf.concat([x, h], 1)
            hyper_output, hyper_new_state = self.hyper_cell(hyper_input,
                                                            self.hyper_state)
            self.hyper_output = hyper_output
            self.hyper_state = hyper_new_state

            xh = tf.matmul(x, w_xh)
            hh = tf.matmul(h, w_hh)

            # split Wxh contributions
            ix, jx, fx, ox = tf.split(xh, 4, 1)
            ix = self.hyper_norm(ix, 'hyper_ix', use_bias=False)
            jx = self.hyper_norm(jx, 'hyper_jx', use_bias=False)
            fx = self.hyper_norm(fx, 'hyper_fx', use_bias=False)
            ox = self.hyper_norm(ox, 'hyper_ox', use_bias=False)

            # split Whh contributions
            ih, jh, fh, oh = tf.split(hh, 4, 1)
            ih = self.hyper_norm(ih, 'hyper_ih', use_bias=True)
            jh = self.hyper_norm(jh, 'hyper_jh', use_bias=True)
            fh = self.hyper_norm(fh, 'hyper_fh', use_bias=True)
            oh = self.hyper_norm(oh, 'hyper_oh', use_bias=True)

            # split bias
            ib, jb, fb, ob = tf.split(bias, 4, 0)  # bias is to be broadcasted.

            # i = input_gate, j = new_input, f = forget_gate, o = output_gate
            i = ix + ih + ib
            j = jx + jh + jb
            f = fx + fh + fb
            o = ox + oh + ob

            if self.use_layer_norm:
                concat = tf.concat([i, j, f, o], 1)
                concat = layer_norm_all(concat, batch_size, 4, self.num_units, 'ln_all')
                i, j, f, o = tf.split(concat, 4, 1)

            if self.use_recurrent_dropout:
                g = tf.nn.dropout(tf.tanh(j), self.dropout_keep_prob)
            else:
                g = tf.tanh(j)

            new_c = c * tf.sigmoid(f + self.forget_bias) + tf.sigmoid(i) * g
            new_h = tf.tanh(layer_norm(new_c, self.num_units, 'ln_c')) * tf.sigmoid(o)

            hyper_h, hyper_c = tf.split(hyper_new_state, 2, 1)
            new_total_h = tf.concat([new_h, hyper_h], 1)
            new_total_c = tf.concat([new_c, hyper_c], 1)
            new_total_state = tf.concat([new_total_h, new_total_c], 1)
        return new_h, new_total_state



================================================
FILE: subnet_tf_utils.py
================================================
import tensorflow as tf


def get_initializer(init_method):
    """Map an initializer name to a TF initializer object; raises on unknown names."""
    if init_method == 'xavier_normal':
        initializer = tf.glorot_normal_initializer()
    elif init_method == 'xavier_uniform':
        initializer = tf.glorot_uniform_initializer()
    elif init_method == 'he_normal':
        initializer = tf.keras.initializers.he_normal()
    elif init_method == 'he_uniform':
        initializer = tf.keras.initializers.he_uniform()
    elif init_method == 'lecun_normal':
        initializer = tf.keras.initializers.lecun_normal()
    elif init_method == 'lecun_uniform':
        initializer = tf.keras.initializers.lecun_uniform()
    else:
        raise Exception('Unknown initializer:', init_method)
    return initializer


def lrelu(x, leak=0.2, name="lrelu", alt_relu_impl=False):
    """Leaky ReLU; alt_relu_impl computes the same function via abs()."""
    with tf.variable_scope(name) as scope:
        if alt_relu_impl:
            f1 = 0.5 * (1 + leak)
            f2 = 0.5 * (1 - leak)
            # f1*x + f2*|x| == x for x>=0 and leak*x for x<0.
            return f1 * x + f2 * abs(x)
        else:
            return tf.maximum(x, leak * x)


def batchnorm(input, name='batch_norm', init_method=None):
    """Hand-rolled batch norm over (N, H, W) per channel.

    NOTE(review): always uses the current batch's statistics -- there are no
    moving averages, so train/inference behavior is identical. Confirm this is
    intended before using it at inference time.
    """
    if init_method is not None:
        initializer = get_initializer(init_method)
    else:
        initializer = tf.random_normal_initializer(1.0, 0.02, dtype=tf.float32)
    with tf.variable_scope(name):
        # this block looks like it has 3 inputs on the graph unless we do this
        input = tf.identity(input)

        channels = input.get_shape()[3]
        offset = tf.get_variable("offset", [channels], dtype=tf.float32,
                                 initializer=tf.zeros_initializer())
        scale = tf.get_variable("scale", [channels], dtype=tf.float32,
                                initializer=initializer)
        mean, variance = tf.nn.moments(input, axes=[0, 1, 2], keep_dims=False)
        variance_epsilon = 1e-5
        normalized = tf.nn.batch_normalization(input, mean, variance, offset, scale,
                                               variance_epsilon=variance_epsilon)
        return normalized


def layernorm(input, name='layer_norm', init_method=None):
    """Hand-rolled layer norm over (H, W, C) per sample, with per-channel affine."""
    if init_method is not None:
        initializer = get_initializer(init_method)
    else:
        initializer = tf.random_normal_initializer(1.0, 0.02, dtype=tf.float32)
    with tf.variable_scope(name):
        n_neurons = input.get_shape()[3]
        offset = tf.get_variable("offset", [n_neurons], dtype=tf.float32,
                                 initializer=tf.zeros_initializer())
        scale = tf.get_variable("scale", [n_neurons], dtype=tf.float32,
                                initializer=initializer)
        # Reshape to (1, 1, C) so the affine params broadcast over N, H, W.
        offset = tf.reshape(offset, [1, 1, -1])
        scale = tf.reshape(scale, [1, 1, -1])
        mean, variance = tf.nn.moments(input, axes=[1, 2, 3], keep_dims=True)
        variance_epsilon = 1e-5
        normalized = tf.nn.batch_normalization(input, mean, variance, offset, scale,
                                               variance_epsilon=variance_epsilon)
        return normalized


def instance_norm(input, name="instance_norm", init_method=None):
    """Instance norm: normalize over (H, W) per sample and channel, then affine."""
    if init_method is not None:
        initializer = get_initializer(init_method)
    else:
        initializer = tf.random_normal_initializer(1.0, 0.02, dtype=tf.float32)
    with tf.variable_scope(name):
        depth = input.get_shape()[3]
        scale = tf.get_variable("scale", [depth], initializer=initializer)
        offset = tf.get_variable("offset", [depth],
                                 initializer=tf.constant_initializer(0.0))
        mean, variance = tf.nn.moments(input, axes=[1, 2], keep_dims=True)
        epsilon = 1e-5
        inv = tf.rsqrt(variance + epsilon)
        normalized = (input - mean) * inv
        return scale * normalized + offset


def linear1d(inputlin, inputdim, outputdim, name="linear1d", init_method=None):
    """Fully-connected layer; a None initializer falls back to TF's default."""
    if init_method is not None:
        initializer = get_initializer(init_method)
    else:
        initializer = None
    with tf.variable_scope(name) as scope:
        weight = tf.get_variable("weight", [inputdim, outputdim], initializer=initializer)
        bias = tf.get_variable("bias", [outputdim], dtype=tf.float32,
                               initializer=tf.constant_initializer(0.0))
        return tf.matmul(inputlin, weight) + bias


def general_conv2d(inputconv, output_dim=64, filter_height=4, filter_width=4,
                   stride_height=2, stride_width=2, stddev=0.02,
                   padding="SAME", name="conv2d", do_norm=True, norm_type='instance_norm',
                   do_relu=True, relufactor=0, is_training=True, init_method=None):
    """Conv + optional normalization + optional (leaky) ReLU.

    NOTE(review): the kernel/stride lists are passed as [width, height] while
    tf.contrib.layers.conv2d documents [height, width]; all call sites here use
    square kernels/strides, so this is harmless in practice.
    """
    if init_method is not None:
        initializer = get_initializer(init_method)
    else:
        initializer = tf.truncated_normal_initializer(stddev=stddev)
    with tf.variable_scope(name) as scope:
        conv = tf.contrib.layers.conv2d(inputconv, output_dim, [filter_width, filter_height],
                                        [stride_width, stride_height], padding,
                                        activation_fn=None,
                                        weights_initializer=initializer,
                                        biases_initializer=tf.constant_initializer(0.0))
        if do_norm:
            if norm_type == 'instance_norm':
                conv = instance_norm(conv, init_method=init_method)
                # conv = tf.contrib.layers.instance_norm(conv, epsilon=1e-05, center=True, scale=True,
                #                                        scope='instance_norm')
            elif norm_type == 'batch_norm':
                # conv = batchnorm(conv, init_method=init_method)
                conv = tf.contrib.layers.batch_norm(conv, decay=0.9, is_training=is_training,
                                                    updates_collections=None, epsilon=1e-5,
                                                    center=True, scale=True, scope="batch_norm")
            elif norm_type == 'layer_norm':
                # conv = layernorm(conv, init_method=init_method)
                conv = tf.contrib.layers.layer_norm(conv, center=True, scale=True, scope='layer_norm')
        if do_relu:
            if relufactor == 0:
                conv = tf.nn.relu(conv, "relu")
            else:
                conv = lrelu(conv, relufactor, "lrelu")

        return conv


def generative_cnn_c3_encoder(inputs, is_training=True, drop_keep_prob=0.5, init_method=None):
    """CNN image encoder: 5 stride-2 stages (each followed by a stride-1 conv),
    ending in a 128-D feature vector plus the last spatial feature map.

    NOTE(review): the final reshape assumes the last feature map is 4x4x256,
    i.e. an input of 128x128 -- confirm with callers.

    Returns:
        (tensor_x, tensor_x_sp): the 128-D embedding and the [N, h, w, 256]
        spatial features before flattening.
    """
    tensor_x = inputs
    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE) as scope:
        tensor_x = general_conv2d(tensor_x, 32, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_1", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 32, filter_height=3, filter_width=3,
                                  stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_1_2", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 64, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_2", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 64, filter_height=3, filter_width=3,
                                  stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_2_2", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 128, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_3", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 128, filter_height=3, filter_width=3,
                                  stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_3_2", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 256, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_4", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 256, filter_height=3, filter_width=3,
                                  stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_4_2", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 256, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_5", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 256, filter_height=3, filter_width=3,
                                  stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_5_2", init_method=init_method)
        tensor_x_sp = tensor_x  # [N, h, w, 256]

        tensor_x = tf.reshape(tensor_x, (-1, 256 * 4 * 4))
        tensor_x = linear1d(tensor_x, 256 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method)

        # Dropout only at training time (TF1-style python conditional).
        if is_training:
            tensor_x = tf.nn.dropout(tensor_x, drop_keep_prob)

        return tensor_x, tensor_x_sp


def generative_cnn_c3_encoder_deeper(inputs, is_training=True, drop_keep_prob=0.5, init_method=None):
    """Wider variant of generative_cnn_c3_encoder: 512 channels in the last two
    stages and a 512-D embedding. Same structure otherwise.
    """
    tensor_x = inputs
    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE) as scope:
        tensor_x = general_conv2d(tensor_x, 32, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_1", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 32, filter_height=3, filter_width=3,
                                  stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_1_2", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 64, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_2", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 64, filter_height=3, filter_width=3,
                                  stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_2_2", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 128, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_3", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 128, filter_height=3, filter_width=3,
                                  stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_3_2", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 256, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_4", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 256, filter_height=3, filter_width=3,
                                  stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_4_2", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 512, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_5", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 512, filter_height=3, filter_width=3,
                                  stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_5_2", init_method=init_method)
        tensor_x_sp = tensor_x  # [N, h, w, 512]

        tensor_x = tf.reshape(tensor_x, (-1, 512 * 4 * 4))
        tensor_x = linear1d(tensor_x, 512 * 4 * 4, 512, name='CNN_ENC_FC', init_method=init_method)

        if is_training:
            tensor_x = tf.nn.dropout(tensor_x, drop_keep_prob)

        return tensor_x, tensor_x_sp


def generative_cnn_c3_encoder_combine33(local_inputs, global_inputs,
                                        is_training=True, drop_keep_prob=0.5, init_method=None):
    """Two-branch encoder combining a local and a global image stream.
    (Definition continues beyond this extract.)
    """
    local_x = local_inputs
    global_x = global_inputs

    with tf.variable_scope('Local_Encoder', reuse=tf.AUTO_REUSE):
        local_x = general_conv2d(local_x, 32, filter_height=3, filter_width=3,
                                 is_training=is_training, name="CNN_ENC_1", init_method=init_method)
        local_x = general_conv2d(local_x, 32, filter_height=3, filter_width=3,
                                 stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_1_2", init_method=init_method)
        local_x = general_conv2d(local_x, 64, filter_height=3, filter_width=3,
                                 is_training=is_training, name="CNN_ENC_2", init_method=init_method)
        local_x = general_conv2d(local_x, 64, filter_height=3, filter_width=3,
                                 stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_2_2", init_method=init_method)
        local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3,
                                 is_training=is_training, name="CNN_ENC_3", init_method=init_method)
        local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3,
                                 stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_3_2", init_method=init_method)
        local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3,
                                 stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_3_3", init_method=init_method)

    with tf.variable_scope('Global_Encoder', reuse=tf.AUTO_REUSE):
        global_x = general_conv2d(global_x, 32, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_1", init_method=init_method)
        global_x = general_conv2d(global_x, 32, filter_height=3, filter_width=3,
                                  stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_1_2", init_method=init_method)
        global_x = general_conv2d(global_x, 64,
filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_2", init_method=init_method) global_x = general_conv2d(global_x, 64, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_2_2", init_method=init_method) global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_3", init_method=init_method) global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_3_2", init_method=init_method) global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_3_3", init_method=init_method) tensor_x = tf.concat([local_x, global_x], axis=-1) with tf.variable_scope('Combined_Encoder', reuse=tf.AUTO_REUSE): tensor_x = general_conv2d(tensor_x, 256, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_4", init_method=init_method) tensor_x = general_conv2d(tensor_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_4_2", init_method=init_method) tensor_x = general_conv2d(tensor_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_4_3", init_method=init_method) tensor_x = general_conv2d(tensor_x, 512, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_5", init_method=init_method) tensor_x = general_conv2d(tensor_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_5_2", init_method=init_method) tensor_x = general_conv2d(tensor_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_5_3", init_method=init_method) tensor_x_sp = tensor_x # [N, h, w, 256] tensor_x = tf.reshape(tensor_x, (-1, 512 * 4 * 4)) tensor_x = 
linear1d(tensor_x, 512 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method) if is_training: tensor_x = tf.nn.dropout(tensor_x, drop_keep_prob) return tensor_x, tensor_x_sp def generative_cnn_c3_encoder_combine43(local_inputs, global_inputs, is_training=True, drop_keep_prob=0.5, init_method=None): local_x = local_inputs global_x = global_inputs with tf.variable_scope('Local_Encoder', reuse=tf.AUTO_REUSE): local_x = general_conv2d(local_x, 32, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_1", init_method=init_method) local_x = general_conv2d(local_x, 32, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_1_2", init_method=init_method) local_x = general_conv2d(local_x, 64, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_2", init_method=init_method) local_x = general_conv2d(local_x, 64, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_2_2", init_method=init_method) local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_3", init_method=init_method) local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_3_2", init_method=init_method) local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_3_3", init_method=init_method) local_x = general_conv2d(local_x, 256, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_4", init_method=init_method) local_x = general_conv2d(local_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_4_2", init_method=init_method) local_x = general_conv2d(local_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_4_3", 
init_method=init_method) with tf.variable_scope('Global_Encoder', reuse=tf.AUTO_REUSE): global_x = general_conv2d(global_x, 32, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_1", init_method=init_method) global_x = general_conv2d(global_x, 32, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_1_2", init_method=init_method) global_x = general_conv2d(global_x, 64, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_2", init_method=init_method) global_x = general_conv2d(global_x, 64, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_2_2", init_method=init_method) global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_3", init_method=init_method) global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_3_2", init_method=init_method) global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_3_3", init_method=init_method) global_x = general_conv2d(global_x, 256, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_4", init_method=init_method) global_x = general_conv2d(global_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_4_2", init_method=init_method) global_x = general_conv2d(global_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_4_3", init_method=init_method) tensor_x = tf.concat([local_x, global_x], axis=-1) with tf.variable_scope('Combined_Encoder', reuse=tf.AUTO_REUSE): tensor_x = general_conv2d(tensor_x, 512, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_5", init_method=init_method) tensor_x = 
general_conv2d(tensor_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_5_2", init_method=init_method) tensor_x = general_conv2d(tensor_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_5_3", init_method=init_method) tensor_x_sp = tensor_x # [N, h, w, 256] tensor_x = tf.reshape(tensor_x, (-1, 512 * 4 * 4)) tensor_x = linear1d(tensor_x, 512 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method) if is_training: tensor_x = tf.nn.dropout(tensor_x, drop_keep_prob) return tensor_x, tensor_x_sp def generative_cnn_c3_encoder_combine53(local_inputs, global_inputs, is_training=True, drop_keep_prob=0.5, init_method=None): local_x = local_inputs global_x = global_inputs with tf.variable_scope('Local_Encoder', reuse=tf.AUTO_REUSE): local_x = general_conv2d(local_x, 32, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_1", init_method=init_method) local_x = general_conv2d(local_x, 32, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_1_2", init_method=init_method) local_x = general_conv2d(local_x, 64, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_2", init_method=init_method) local_x = general_conv2d(local_x, 64, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_2_2", init_method=init_method) local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_3", init_method=init_method) local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_3_2", init_method=init_method) local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_3_3", init_method=init_method) local_x = 
general_conv2d(local_x, 256, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_4", init_method=init_method) local_x = general_conv2d(local_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_4_2", init_method=init_method) local_x = general_conv2d(local_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_4_3", init_method=init_method) local_x = general_conv2d(local_x, 512, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_5", init_method=init_method) local_x = general_conv2d(local_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_5_2", init_method=init_method) local_x = general_conv2d(local_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_5_3", init_method=init_method) with tf.variable_scope('Global_Encoder', reuse=tf.AUTO_REUSE): global_x = general_conv2d(global_x, 32, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_1", init_method=init_method) global_x = general_conv2d(global_x, 32, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_1_2", init_method=init_method) global_x = general_conv2d(global_x, 64, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_2", init_method=init_method) global_x = general_conv2d(global_x, 64, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_2_2", init_method=init_method) global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_3", init_method=init_method) global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_3_2", init_method=init_method) global_x = 
general_conv2d(global_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_3_3", init_method=init_method) global_x = general_conv2d(global_x, 256, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_4", init_method=init_method) global_x = general_conv2d(global_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_4_2", init_method=init_method) global_x = general_conv2d(global_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_4_3", init_method=init_method) global_x = general_conv2d(global_x, 512, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_5", init_method=init_method) global_x = general_conv2d(global_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_5_2", init_method=init_method) global_x = general_conv2d(global_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_5_3", init_method=init_method) tensor_x = tf.concat([local_x, global_x], axis=-1) with tf.variable_scope('Combined_Encoder', reuse=tf.AUTO_REUSE): tensor_x_sp = tensor_x # [N, h, w, 256] tensor_x = tf.reshape(tensor_x, (-1, 1024 * 4 * 4)) tensor_x = linear1d(tensor_x, 1024 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method) if is_training: tensor_x = tf.nn.dropout(tensor_x, drop_keep_prob) return tensor_x, tensor_x_sp def generative_cnn_c3_encoder_combineFC(local_inputs, global_inputs, is_training=True, drop_keep_prob=0.5, init_method=None): local_x = local_inputs global_x = global_inputs with tf.variable_scope('Local_Encoder', reuse=tf.AUTO_REUSE): local_x = general_conv2d(local_x, 32, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_1", init_method=init_method) local_x = general_conv2d(local_x, 32, filter_height=3, 
filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_1_2", init_method=init_method) local_x = general_conv2d(local_x, 64, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_2", init_method=init_method) local_x = general_conv2d(local_x, 64, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_2_2", init_method=init_method) local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_3", init_method=init_method) local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_3_2", init_method=init_method) local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_3_3", init_method=init_method) local_x = general_conv2d(local_x, 256, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_4", init_method=init_method) local_x = general_conv2d(local_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_4_2", init_method=init_method) local_x = general_conv2d(local_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_4_3", init_method=init_method) local_x = general_conv2d(local_x, 512, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_5", init_method=init_method) local_x = general_conv2d(local_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_5_2", init_method=init_method) local_x = general_conv2d(local_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_5_3", init_method=init_method) local_x = tf.reshape(local_x, (-1, 512 * 4 * 4)) local_x = linear1d(local_x, 512 * 4 * 4, 128, 
name='CNN_ENC_FC', init_method=init_method) if is_training: local_x = tf.nn.dropout(local_x, drop_keep_prob) with tf.variable_scope('Global_Encoder', reuse=tf.AUTO_REUSE): global_x = general_conv2d(global_x, 32, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_1", init_method=init_method) global_x = general_conv2d(global_x, 32, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_1_2", init_method=init_method) global_x = general_conv2d(global_x, 64, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_2", init_method=init_method) global_x = general_conv2d(global_x, 64, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_2_2", init_method=init_method) global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_3", init_method=init_method) global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_3_2", init_method=init_method) global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_3_3", init_method=init_method) global_x = general_conv2d(global_x, 256, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_4", init_method=init_method) global_x = general_conv2d(global_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_4_2", init_method=init_method) global_x = general_conv2d(global_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_4_3", init_method=init_method) global_x = general_conv2d(global_x, 512, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_5", init_method=init_method) global_x = general_conv2d(global_x, 512, filter_height=3, 
filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_5_2", init_method=init_method) global_x = general_conv2d(global_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_5_3", init_method=init_method) global_x = tf.reshape(global_x, (-1, 512 * 4 * 4)) global_x = linear1d(global_x, 512 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method) if is_training: global_x = tf.nn.dropout(global_x, drop_keep_prob) tensor_x_sp = None tensor_x = tf.concat([local_x, global_x], axis=-1) return tensor_x, tensor_x_sp def generative_cnn_c3_encoder_combineFC_jointAttn(local_inputs, global_inputs, is_training=True, drop_keep_prob=0.5, init_method=None, combine_manner='attn'): local_x = local_inputs global_x = global_inputs with tf.variable_scope('Local_Encoder', reuse=tf.AUTO_REUSE): local_x = general_conv2d(local_x, 32, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_1", init_method=init_method) local_x = general_conv2d(local_x, 32, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_1_2", init_method=init_method) local_x = general_conv2d(local_x, 64, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_2", init_method=init_method) local_x = general_conv2d(local_x, 64, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_2_2", init_method=init_method) local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_3", init_method=init_method) local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_3_2", init_method=init_method) local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_3_3", init_method=init_method) share_x = 
local_x with tf.variable_scope('Attn_branch', reuse=tf.AUTO_REUSE): attn_x = general_conv2d(share_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_1", init_method=init_method) attn_x = general_conv2d(attn_x, 32, filter_height=1, filter_width=1, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_2", init_method=init_method) attn_x = general_conv2d(attn_x, 1, filter_height=1, filter_width=1, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_3", init_method=init_method) attn_map = tf.nn.sigmoid(attn_x) # (N, H/8, W/8, 1), [0.0, 1.0] if combine_manner == 'attn': attn_feat = attn_map * share_x + share_x else: raise Exception('Unknown combine_manner', combine_manner) local_x = general_conv2d(attn_feat, 256, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_4", init_method=init_method) local_x = general_conv2d(local_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_4_2", init_method=init_method) local_x = general_conv2d(local_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_4_3", init_method=init_method) local_x = general_conv2d(local_x, 512, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_5", init_method=init_method) local_x = general_conv2d(local_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_5_2", init_method=init_method) local_x = general_conv2d(local_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_5_3", init_method=init_method) local_x = tf.reshape(local_x, (-1, 512 * 4 * 4)) local_x = linear1d(local_x, 512 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method) if is_training: local_x = tf.nn.dropout(local_x, drop_keep_prob) with 
tf.variable_scope('Global_Encoder', reuse=tf.AUTO_REUSE): global_x = general_conv2d(global_x, 32, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_1", init_method=init_method) global_x = general_conv2d(global_x, 32, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_1_2", init_method=init_method) global_x = general_conv2d(global_x, 64, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_2", init_method=init_method) global_x = general_conv2d(global_x, 64, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_2_2", init_method=init_method) global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_3", init_method=init_method) global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_3_2", init_method=init_method) global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_3_3", init_method=init_method) global_x = general_conv2d(global_x, 256, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_4", init_method=init_method) global_x = general_conv2d(global_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_4_2", init_method=init_method) global_x = general_conv2d(global_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_4_3", init_method=init_method) global_x = general_conv2d(global_x, 512, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_5", init_method=init_method) global_x = general_conv2d(global_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_5_2", 
init_method=init_method) global_x = general_conv2d(global_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_5_3", init_method=init_method) global_x = tf.reshape(global_x, (-1, 512 * 4 * 4)) global_x = linear1d(global_x, 512 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method) if is_training: global_x = tf.nn.dropout(global_x, drop_keep_prob) tensor_x_sp = None tensor_x = tf.concat([local_x, global_x], axis=-1) return tensor_x, tensor_x_sp, attn_map def generative_cnn_c3_encoder_combineFC_sepAttn(local_inputs, global_inputs, is_training=True, drop_keep_prob=0.5, init_method=None, combine_manner='attn'): local_x = local_inputs global_x = global_inputs with tf.variable_scope('Attn_branch', reuse=tf.AUTO_REUSE): attn_x = general_conv2d(local_x, 32, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_1", init_method=init_method) attn_x = general_conv2d(attn_x, 32, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_1_2", init_method=init_method) attn_x = general_conv2d(attn_x, 64, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_2", init_method=init_method) attn_x = general_conv2d(attn_x, 64, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_2_2", init_method=init_method) attn_x = general_conv2d(attn_x, 128, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_3", init_method=init_method) attn_x = general_conv2d(attn_x, 32, filter_height=1, filter_width=1, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_3_2", init_method=init_method) attn_x = general_conv2d(attn_x, 1, filter_height=1, filter_width=1, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_3_3", init_method=init_method) attn_map = tf.nn.sigmoid(attn_x) # (N, H/8, W/8, 1), [0.0, 1.0] with tf.variable_scope('Local_Encoder', 
reuse=tf.AUTO_REUSE): local_x = general_conv2d(local_x, 32, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_1", init_method=init_method) local_x = general_conv2d(local_x, 32, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_1_2", init_method=init_method) local_x = general_conv2d(local_x, 64, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_2", init_method=init_method) local_x = general_conv2d(local_x, 64, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_2_2", init_method=init_method) local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_3", init_method=init_method) local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_3_2", init_method=init_method) local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_3_3", init_method=init_method) if combine_manner == 'attn': attn_feat = attn_map * local_x + local_x elif combine_manner == 'channel': attn_feat = tf.concat([local_x, attn_map], axis=-1) else: raise Exception('Unknown combine_manner', combine_manner) local_x = general_conv2d(attn_feat, 256, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_4", init_method=init_method) local_x = general_conv2d(local_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_4_2", init_method=init_method) local_x = general_conv2d(local_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_4_3", init_method=init_method) local_x = general_conv2d(local_x, 512, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_5", init_method=init_method) local_x = 
general_conv2d(local_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_5_2", init_method=init_method) local_x = general_conv2d(local_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_5_3", init_method=init_method) local_x = tf.reshape(local_x, (-1, 512 * 4 * 4)) local_x = linear1d(local_x, 512 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method) if is_training: local_x = tf.nn.dropout(local_x, drop_keep_prob) with tf.variable_scope('Global_Encoder', reuse=tf.AUTO_REUSE): global_x = general_conv2d(global_x, 32, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_1", init_method=init_method) global_x = general_conv2d(global_x, 32, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_1_2", init_method=init_method) global_x = general_conv2d(global_x, 64, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_2", init_method=init_method) global_x = general_conv2d(global_x, 64, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_2_2", init_method=init_method) global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_3", init_method=init_method) global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_3_2", init_method=init_method) global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1, is_training=is_training, name="CNN_ENC_3_3", init_method=init_method) global_x = general_conv2d(global_x, 256, filter_height=3, filter_width=3, is_training=is_training, name="CNN_ENC_4", init_method=init_method) global_x = general_conv2d(global_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1, 
def generative_cnn_c3_encoder_deeper13(inputs, is_training=True, drop_keep_prob=0.5, init_method=None):
    """13-layer 3x3-conv encoder: image -> (128-d feature, last spatial feature map)."""
    # (out_channels, layer_name, downsample?) — downsampling layers use the
    # default stride of general_conv2d; the others force stride 1.
    conv_plan = [
        (32, "CNN_ENC_1", True), (32, "CNN_ENC_1_2", False),
        (64, "CNN_ENC_2", True), (64, "CNN_ENC_2_2", False),
        (128, "CNN_ENC_3", True), (128, "CNN_ENC_3_2", False), (128, "CNN_ENC_3_3", False),
        (256, "CNN_ENC_4", True), (256, "CNN_ENC_4_2", False), (256, "CNN_ENC_4_3", False),
        (512, "CNN_ENC_5", True), (512, "CNN_ENC_5_2", False), (512, "CNN_ENC_5_3", False),
    ]
    tensor_x = inputs
    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE) as scope:
        for out_ch, layer_name, downsample in conv_plan:
            if downsample:
                tensor_x = general_conv2d(tensor_x, out_ch, filter_height=3, filter_width=3,
                                          is_training=is_training, name=layer_name,
                                          init_method=init_method)
            else:
                tensor_x = general_conv2d(tensor_x, out_ch, filter_height=3, filter_width=3,
                                          stride_height=1, stride_width=1,
                                          is_training=is_training, name=layer_name,
                                          init_method=init_method)

        tensor_x_sp = tensor_x  # last conv feature map; assumes (N, 4, 4, 512) here — TODO confirm input size

        tensor_x = tf.reshape(tensor_x, (-1, 512 * 4 * 4))
        tensor_x = linear1d(tensor_x, 512 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method)
        if is_training:
            tensor_x = tf.nn.dropout(tensor_x, drop_keep_prob)
    return tensor_x, tensor_x_sp


def generative_cnn_c3_encoder_deeper13_attn(inputs, is_training=True, drop_keep_prob=0.5, init_method=None):
    """Same 13-layer 3x3-conv encoder, with a self-attention block after CNN_ENC_2_2."""
    conv_plan = [
        (32, "CNN_ENC_1", True), (32, "CNN_ENC_1_2", False),
        (64, "CNN_ENC_2", True), (64, "CNN_ENC_2_2", False),
        (128, "CNN_ENC_3", True), (128, "CNN_ENC_3_2", False), (128, "CNN_ENC_3_3", False),
        (256, "CNN_ENC_4", True), (256, "CNN_ENC_4_2", False), (256, "CNN_ENC_4_3", False),
        (512, "CNN_ENC_5", True), (512, "CNN_ENC_5_2", False), (512, "CNN_ENC_5_3", False),
    ]
    tensor_x = inputs
    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE) as scope:
        for out_ch, layer_name, downsample in conv_plan:
            if downsample:
                tensor_x = general_conv2d(tensor_x, out_ch, filter_height=3, filter_width=3,
                                          is_training=is_training, name=layer_name,
                                          init_method=init_method)
            else:
                tensor_x = general_conv2d(tensor_x, out_ch, filter_height=3, filter_width=3,
                                          stride_height=1, stride_width=1,
                                          is_training=is_training, name=layer_name,
                                          init_method=init_method)
            # attention on the 64-channel features, right after CNN_ENC_2_2
            if layer_name == "CNN_ENC_2_2":
                tensor_x = self_attention(tensor_x, 64)

        tensor_x_sp = tensor_x  # last conv feature map before flattening

        tensor_x = tf.reshape(tensor_x, (-1, 512 * 4 * 4))
        tensor_x = linear1d(tensor_x, 512 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method)
        if is_training:
            tensor_x = tf.nn.dropout(tensor_x, drop_keep_prob)
    return tensor_x, tensor_x_sp
def generative_cnn_encoder(inputs, is_training=True, drop_keep_prob=0.5, init_method=None):
    """10-layer conv encoder (default kernel size): image -> (128-d feature, last feature map)."""
    # (out_channels, layer_name, downsample?) — non-downsampling layers force stride 1.
    conv_plan = [
        (32, "CNN_ENC_1", True), (32, "CNN_ENC_1_2", False),
        (64, "CNN_ENC_2", True), (64, "CNN_ENC_2_2", False),
        (128, "CNN_ENC_3", True), (128, "CNN_ENC_3_2", False),
        (256, "CNN_ENC_4", True), (256, "CNN_ENC_4_2", False),
        (256, "CNN_ENC_5", True), (256, "CNN_ENC_5_2", False),
    ]
    tensor_x = inputs
    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE) as scope:
        for out_ch, layer_name, downsample in conv_plan:
            if downsample:
                tensor_x = general_conv2d(tensor_x, out_ch, is_training=is_training,
                                          name=layer_name, init_method=init_method)
            else:
                tensor_x = general_conv2d(tensor_x, out_ch, stride_height=1, stride_width=1,
                                          is_training=is_training, name=layer_name,
                                          init_method=init_method)

        tensor_x_sp = tensor_x  # [N, h, w, 256]

        tensor_x = tf.reshape(tensor_x, (-1, 256 * 4 * 4))
        tensor_x = linear1d(tensor_x, 256 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method)
        if is_training:
            tensor_x = tf.nn.dropout(tensor_x, drop_keep_prob)
    return tensor_x, tensor_x_sp


def generative_cnn_encoder_deeper(inputs, is_training=True, drop_keep_prob=0.5, init_method=None):
    """Deeper variant: widens to 512 channels and outputs a 512-d feature."""
    conv_plan = [
        (32, "CNN_ENC_1", True), (32, "CNN_ENC_1_2", False),
        (64, "CNN_ENC_2", True), (64, "CNN_ENC_2_2", False),
        (128, "CNN_ENC_3", True), (128, "CNN_ENC_3_2", False),
        (256, "CNN_ENC_4", True), (256, "CNN_ENC_4_2", False),
        (512, "CNN_ENC_5", True), (512, "CNN_ENC_5_2", False),
    ]
    tensor_x = inputs
    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE) as scope:
        for out_ch, layer_name, downsample in conv_plan:
            if downsample:
                tensor_x = general_conv2d(tensor_x, out_ch, is_training=is_training,
                                          name=layer_name, init_method=init_method)
            else:
                tensor_x = general_conv2d(tensor_x, out_ch, stride_height=1, stride_width=1,
                                          is_training=is_training, name=layer_name,
                                          init_method=init_method)

        tensor_x_sp = tensor_x  # [N, h, w, 512]

        tensor_x = tf.reshape(tensor_x, (-1, 512 * 4 * 4))
        tensor_x = linear1d(tensor_x, 512 * 4 * 4, 512, name='CNN_ENC_FC', init_method=init_method)
        if is_training:
            tensor_x = tf.nn.dropout(tensor_x, drop_keep_prob)
    return tensor_x, tensor_x_sp


def generative_cnn_encoder_deeper13(inputs, is_training=True, drop_keep_prob=0.5, init_method=None):
    """13-layer variant (default kernel size), capped at 256 channels; outputs a 128-d feature."""
    conv_plan = [
        (32, "CNN_ENC_1", True), (32, "CNN_ENC_1_2", False),
        (64, "CNN_ENC_2", True), (64, "CNN_ENC_2_2", False),
        (128, "CNN_ENC_3", True), (128, "CNN_ENC_3_2", False), (128, "CNN_ENC_3_3", False),
        (256, "CNN_ENC_4", True), (256, "CNN_ENC_4_2", False), (256, "CNN_ENC_4_3", False),
        (256, "CNN_ENC_5", True), (256, "CNN_ENC_5_2", False), (256, "CNN_ENC_5_3", False),
    ]
    tensor_x = inputs
    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE) as scope:
        for out_ch, layer_name, downsample in conv_plan:
            if downsample:
                tensor_x = general_conv2d(tensor_x, out_ch, is_training=is_training,
                                          name=layer_name, init_method=init_method)
            else:
                tensor_x = general_conv2d(tensor_x, out_ch, stride_height=1, stride_width=1,
                                          is_training=is_training, name=layer_name,
                                          init_method=init_method)

        tensor_x_sp = tensor_x  # [N, h, w, 256]

        tensor_x = tf.reshape(tensor_x, (-1, 256 * 4 * 4))
        tensor_x = linear1d(tensor_x, 256 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method)
        if is_training:
            tensor_x = tf.nn.dropout(tensor_x, drop_keep_prob)
    return tensor_x, tensor_x_sp
def max_pooling(x):
    """2x2 max-pooling with stride 2 and SAME padding."""
    return tf.layers.max_pooling2d(x, pool_size=2, strides=2, padding='SAME')


def hw_flatten(x):
    """Collapse the spatial dims: (N, h, w, c) -> (N, h*w, c)."""
    return tf.reshape(x, shape=[x.shape[0], -1, x.shape[-1]])


def self_attention(x, in_channel, name='self_attention'):
    """Self-attention block with a zero-initialized residual gate (gamma).

    'f' and 'h' projections are max-pooled so the attention matrix is
    (M, M') with M' = M/4, which keeps memory manageable.
    """
    with tf.variable_scope(name) as scope:
        proj_f = general_conv2d(x, in_channel // 8, filter_height=1, filter_width=1,
                                stride_height=1, stride_width=1,
                                do_norm=False, do_relu=False, name='f_conv')  # (N, h, w, c')
        proj_f = max_pooling(proj_f)  # (N, h', w', c')

        proj_g = general_conv2d(x, in_channel // 8, filter_height=1, filter_width=1,
                                stride_height=1, stride_width=1,
                                do_norm=False, do_relu=False, name='g_conv')  # (N, h, w, c')

        proj_h = general_conv2d(x, in_channel, filter_height=1, filter_width=1,
                                stride_height=1, stride_width=1,
                                do_norm=False, do_relu=False, name='h_conv')  # (N, h, w, c)
        proj_h = max_pooling(proj_h)  # (N, h', w', c)

        # M = h * w, M' = h' * w'
        energy = tf.matmul(hw_flatten(proj_g), hw_flatten(proj_f), transpose_b=True)  # (N, M, M')
        attn_weights = tf.nn.softmax(energy)  # attention map
        attended = tf.matmul(attn_weights, hw_flatten(proj_h))  # (N, M, c)

        gamma = tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0))

        attended = tf.reshape(attended, shape=x.shape)  # (N, h, w, c)
        attended = general_conv2d(attended, in_channel, filter_height=1, filter_width=1,
                                  stride_height=1, stride_width=1,
                                  do_norm=False, do_relu=False, name='attn_conv')
        # gamma starts at 0, so the block is initially an identity mapping
        return gamma * attended + x


def global_avg_pooling(x):
    """Global average pooling over the spatial dims: (N, h, w, c) -> (N, c)."""
    return tf.reduce_mean(x, axis=[1, 2])
def cnn_discriminator_wgan_gp(discrim_inputs, discrim_targets, init_method=None):
    """WGAN-GP critic: concat input/target along channels, 5 strided 3x3 convs, scalar per sample."""
    tensor_x = tf.concat([discrim_inputs, discrim_targets], axis=3)  # (N, H, W, 3 + 1)

    # (out_channels, layer_name); every layer downsamples via the default stride.
    critic_plan = [
        (32, "CNN_ENC_1"),
        (64, "CNN_ENC_2"),
        (128, "CNN_ENC_3"),
        (128, "CNN_ENC_4"),
        (1, "CNN_ENC_5"),
    ]
    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE) as scope:
        for out_ch, layer_name in critic_plan:
            tensor_x = general_conv2d(tensor_x, out_ch, filter_height=3, filter_width=3,
                                      is_training=True, name=layer_name, init_method=init_method)
        # tensor_x: (N, H/32, W/32, 1)
        d_out = global_avg_pooling(tensor_x)  # (N, 1)
    return d_out
def sample(sess, model, input_photos, init_cursor, image_size, init_len, seq_len, state_dependent, pasting_func):
    """Samples a sequence from a pre-trained model.

    Autoregressively decodes `seq_len` stroke steps, pasting each predicted
    stroke patch onto a full-size canvas and moving the cursor after each step.
    Returns (params_list, state_raw_list, state_soft_list, curr_canvas,
    window_size_list); the per-step lists have one sub-list per sampling run.
    """
    select_times = 1
    cursor_pos = np.squeeze(init_cursor, axis=0)  # (select_times, 1, 2)
    curr_canvas = np.zeros(dtype=np.float32, shape=(select_times, image_size, image_size))  # [0.0-BG, 1.0-stroke]

    initial_state = sess.run(model.initial_state)
    prev_state = initial_state
    prev_width = np.stack([model.hps.min_width for _ in range(select_times)], axis=0)
    prev_scaling = np.ones((select_times), dtype=np.float32)  # (N)
    prev_window_size = np.ones((select_times), dtype=np.float32) * model.hps.raster_size  # (N)

    params_list = [[] for _ in range(select_times)]
    state_raw_list = [[] for _ in range(select_times)]
    state_soft_list = [[] for _ in range(select_times)]
    window_size_list = [[] for _ in range(select_times)]

    input_photos_tiles = np.tile(input_photos, (select_times, 1, 1, 1))

    for i in range(seq_len):
        # state-independent mode: reset the RNN state at the start of each segment
        if not state_dependent and i % init_len == 0:
            prev_state = initial_state

        # window size for this step, clamped to [min_window_size, image_size]
        curr_window_size = prev_scaling * prev_window_size  # (N)
        curr_window_size = np.maximum(curr_window_size, model.hps.min_window_size)
        curr_window_size = np.minimum(curr_window_size, image_size)

        feed = {
            model.initial_state: prev_state,
            model.input_photo: input_photos_tiles,
            model.curr_canvas_hard: curr_canvas.copy(),
            model.cursor_position: cursor_pos,
            model.image_size: image_size,
            model.init_width: prev_width,
            model.init_scaling: prev_scaling,
            model.init_window_size: prev_window_size,
        }
        o_other_params_list, o_pen_list, o_pred_params_list, next_state_list = \
            sess.run([model.other_params, model.pen_ras, model.pred_params, model.final_state], feed_dict=feed)
        # o_other_params: (N, 6), o_pen: (N, 2), pred_params: (N, 1, 7), next_state: (N, 1024)
        # o_other_params: [tanh*2, sigmoid*2, tanh*2, sigmoid*2]

        idx_eos_list = np.argmax(o_pen_list, axis=1)  # (N); 0 = draw this step, 1 = pen lifted

        for output_i in range(idx_eos_list.shape[0]):
            idx_eos = idx_eos_list[output_i]
            eos = [0, 0]
            eos[idx_eos] = 1
            other_params = o_other_params_list[output_i].tolist()  # (6)
            params_list[output_i].append([eos[1]] + other_params)
            state_raw_list[output_i].append(o_pen_list[output_i][1])
            state_soft_list[output_i].append(o_pred_params_list[output_i, 0, 0])
            window_size_list[output_i].append(curr_window_size[output_i])

            # draw the stroke and add to the canvas
            x1y1, x2y2, width2 = o_other_params_list[output_i, 0:2], o_other_params_list[output_i, 2:4], \
                o_other_params_list[output_i, 4]
            x0y0 = np.zeros_like(x2y2)  # (2), [-1.0, 1.0]; stroke always starts at the patch center
            x0y0 = np.divide(np.add(x0y0, 1.0), 2.0)  # (2), [0.0, 1.0]
            x2y2 = np.divide(np.add(x2y2, 1.0), 2.0)  # (2), [0.0, 1.0]
            widths = np.stack([prev_width[output_i], width2], axis=0)  # (2)
            o_other_params_proc = np.concatenate([x0y0, x1y1, x2y2, widths], axis=-1).tolist()  # (8)

            if idx_eos == 0:
                f = o_other_params_proc + [1.0, 1.0]
                pred_stroke_img = draw(f)  # (raster_size, raster_size), [0.0-stroke, 1.0-BG]
                # paste the patch-sized stroke into the full canvas at the cursor
                pred_stroke_img_large = image_pasting_v3_testing(1.0 - pred_stroke_img, cursor_pos[output_i, 0],
                                                                image_size,
                                                                curr_window_size[output_i],
                                                                pasting_func, sess)  # [0.0-BG, 1.0-stroke]
                curr_canvas[output_i] += pred_stroke_img_large  # [0.0-BG, 1.0-stroke]

        curr_canvas = np.clip(curr_canvas, 0.0, 1.0)

        next_width = o_other_params_list[:, 4]  # (N)
        next_scaling = o_other_params_list[:, 5]
        next_window_size = next_scaling * curr_window_size  # (N)
        next_window_size = np.maximum(next_window_size, model.hps.min_window_size)
        next_window_size = np.minimum(next_window_size, image_size)

        prev_state = next_state_list
        # re-express the width relative to the next window's scale
        prev_width = next_width * curr_window_size / next_window_size  # (N,)
        prev_scaling = next_scaling  # (N)
        prev_window_size = curr_window_size

        # update cursor_pos based on hps.cursor_type
        new_cursor_offsets = o_other_params_list[:, 2:4] * (np.expand_dims(curr_window_size, axis=-1) / 2.0)  # (N, 2), patch-level
        new_cursor_offset_next = new_cursor_offsets

        # important!!! swap (x, y) -> (y, x) to match image coordinates
        new_cursor_offset_next = np.concatenate([new_cursor_offset_next[:, 1:2], new_cursor_offset_next[:, 0:1]], axis=-1)

        cursor_pos_large = cursor_pos * float(image_size)
        stroke_position_next = cursor_pos_large[:, 0, :] + new_cursor_offset_next  # (N, 2), large-level

        if model.hps.cursor_type == 'next':
            cursor_pos_large = stroke_position_next  # (N, 2), large-level
        else:
            raise Exception('Unknown cursor_type')

        cursor_pos_large = np.minimum(np.maximum(cursor_pos_large, 0.0), float(image_size - 1))  # (N, 2), large-level
        cursor_pos_large = np.expand_dims(cursor_pos_large, axis=1)  # (N, 1, 2)
        cursor_pos = cursor_pos_large / float(image_size)

    return params_list, state_raw_list, state_soft_list, curr_canvas, window_size_list
def main_testing(test_image_base_dir, test_dataset, test_image_name, sampling_base_dir, model_base_dir, model_name,
                 sampling_num, draw_seq=False, draw_order=False, state_dependent=True, longer_infer_len=-1):
    """Load a trained model and vectorize one test photograph.

    Saves the input image, the predicted stroke sequence data, and a rendering
    of the strokes under `sampling_base_dir`.

    :param longer_infer_len: -1 means use the model's max_seq_len; any other
        value overrides the number of sampled steps.
    """
    model_params_default = hparams.get_default_hparams_normal()
    model_params = update_hyperparams(model_params_default, model_base_dir, model_name, infer_dataset=test_dataset)
    [test_set, eval_hps_model, sample_hps_model] = \
        load_dataset_testing(test_image_base_dir, test_dataset, test_image_name, model_params)

    # Strip the extension for output naming. FIX: the old slice
    # `test_image_name[:test_image_name.find('.')]` dropped the final character
    # when the name has no '.' (str.find returns -1); keep the full name instead.
    dot_pos = test_image_name.find('.')
    test_image_raw_name = test_image_name[:dot_pos] if dot_pos != -1 else test_image_name

    model_dir = os.path.join(model_base_dir, model_name)

    reset_graph()
    sampling_model = VirtualSketchingModel(sample_hps_model)

    # differentiable pasting graph
    paste_v3_func = DiffPastingV3(sample_hps_model.raster_size)

    tfconfig = tf.ConfigProto()
    tfconfig.gpu_options.allow_growth = True
    sess = tf.InteractiveSession(config=tfconfig)
    sess.run(tf.global_variables_initializer())

    # loads the weights from checkpoint into our model
    snapshot_step = load_checkpoint(sess, model_dir, gen_model_pretrain=True)
    print('snapshot_step', snapshot_step)

    sampling_dir = os.path.join(sampling_base_dir, test_dataset + '__' + model_name)
    os.makedirs(sampling_dir, exist_ok=True)

    if longer_infer_len == -1:
        tmp_max_len = eval_hps_model.max_seq_len
    else:
        tmp_max_len = longer_infer_len

    for sampling_i in range(sampling_num):
        input_photos, init_cursors, test_image_size = test_set.get_test_image()
        # input_photos: (1, image_size, image_size, 3), [0-stroke, 1-BG]
        # init_cursors: (N, 1, 2), in size [0.0, 1.0)

        print()
        print(test_image_name, ', image_size:', test_image_size, ', sampling_i:', sampling_i)
        print('Processing ...')

        if init_cursors.ndim == 3:
            init_cursors = np.expand_dims(init_cursors, axis=0)

        input_photos = input_photos[0:1, :, :, :]
        ori_img = (input_photos.copy()[0] * 255.0).astype(np.uint8)
        ori_img_png = Image.fromarray(ori_img, 'RGB')
        ori_img_png.save(os.path.join(sampling_dir, test_image_raw_name + '_input.png'), 'PNG')

        # decoding for sampling
        strokes_raw_out_list, states_raw_out_list, states_soft_out_list, pred_imgs_out, window_size_out_list = sample(
            sess, sampling_model, input_photos, init_cursors, test_image_size, eval_hps_model.max_seq_len,
            tmp_max_len, state_dependent, paste_v3_func)
        # pred_imgs_out: (N, H, W), [0.0-BG, 1.0-stroke]

        output_i = 0
        strokes_raw_out = np.stack(strokes_raw_out_list[output_i], axis=0)
        states_raw_out = states_raw_out_list[output_i]
        states_soft_out = states_soft_out_list[output_i]
        window_size_out = window_size_out_list[output_i]
        round_new_lengths = [tmp_max_len]
        multi_cursors = [init_cursors[0, output_i, 0, :]]

        print('strokes_raw_out', strokes_raw_out.shape)

        clean_states_soft_out = np.array(states_soft_out)  # (N)
        flag_list = strokes_raw_out[:, 0].astype(np.int32)  # (N)
        drawing_len = len(flag_list) - np.sum(flag_list)
        assert drawing_len >= 0

        # per-step debug log (kept but muted, as in the original)
        # print('  flag  raw\t soft\t x1\t\t y1\t\t x2\t\t y2\t\t r2\t\t s2')
        for i in range(strokes_raw_out.shape[0]):
            flag, x1, y1, x2, y2, r2, s2 = strokes_raw_out[i]
            win_size = window_size_out[i]
            out_format = '#%d: %d | %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f'
            out_values = (i, flag, states_raw_out[i], clean_states_soft_out[i], x1, y1, x2, y2, r2, s2)
            out_log = out_format % out_values
            # print(out_log)

        print('Saving results ...')
        save_seq_data(sampling_dir, test_image_raw_name + '_' + str(sampling_i),
                      strokes_raw_out, init_cursors[0, output_i, 0, :],
                      test_image_size, tmp_max_len, eval_hps_model.min_width)

        draw_strokes(strokes_raw_out, sampling_dir, test_image_raw_name + '_' + str(sampling_i) + '_pred.png',
                     ori_img, test_image_size, multi_cursors, round_new_lengths, eval_hps_model.min_width,
                     eval_hps_model.cursor_type, sample_hps_model.raster_size, sample_hps_model.min_window_size,
                     sess, pasting_func=paste_v3_func, save_seq=draw_seq, draw_order=draw_order)
def main(model_name, test_image_name, sampling_num):
    """Run photograph-to-line sampling with the fixed demo configuration ('faces' model)."""
    # set numpy output to something sensible
    np.set_printoptions(precision=8, edgeitems=6, linewidth=200, suppress=True)

    main_testing('sample_inputs', 'faces', test_image_name, 'outputs/sampling', 'outputs/snapshot',
                 model_name, sampling_num,
                 draw_seq=False, draw_order=True,
                 state_dependent=False, longer_infer_len=100)


if __name__ == '__main__':
    cli = argparse.ArgumentParser()
    cli.add_argument('--input', '-i', type=str, default='', help="The test image name.")
    cli.add_argument('--model', '-m', type=str, default='pretrain_faces', help="The trained model.")
    cli.add_argument('--sample', '-s', type=int, default=1, help="The number of outputs.")
    args = cli.parse_args()

    assert args.input != ''
    assert args.sample > 0
    main(args.model, args.input, args.sample)
def move_cursor_to_undrawn(current_pos_list, input_image_, patch_size, move_min_dist, move_max_dist, trial_times=20):
    """Randomly jitter each cursor until it lands near stroke pixels (or trials run out).

    :param current_pos_list: (select_times, 1, 2), [0.0, 1.0)
    :param input_image_: (1, image_size, image_size, 3), [0-stroke, 1-BG]
    :return: new_cursor_pos: (select_times, 1, 2), [0.0, 1.0)
    """

    def _patch_around(image, center, image_size, crop_size):
        # Clip a crop_size x crop_size window centered on `center` to the image bounds.
        left = center[0] - crop_size // 2
        right = left + crop_size
        top = center[1] - crop_size // 2
        bottom = top + crop_size
        left, right = max(0, min(left, image_size)), max(0, min(right, image_size))
        top, bottom = max(0, min(top, image_size)), max(0, min(bottom, image_size))
        return image[top:bottom, left:right]

    def _has_ink(ink_img, cursor, raster_size, image_size):
        # True when the window around `cursor` contains any stroke pixels.
        # ink_img: (image_size, image_size, 3), [0.0-BG, 1.0-stroke]
        center = np.round(cursor * float(image_size)).astype(np.int32)
        return np.sum(_patch_around(ink_img, center, image_size, raster_size)) > 0.0

    def _jittered(cursor_position, img_size, min_dist_p, max_dist_p):
        # Offset each coordinate by a random distance with a random sign,
        # then clamp back into the canvas. cursor_position: (2), [0.0, 1.0)
        pos = cursor_position * img_size
        lo = int(min_dist_p / 2.0 * img_size)
        hi = int(max_dist_p / 2.0 * img_size)
        offset = np.random.randint(lo, hi, size=pos.shape)
        sign = np.random.randint(0, 1 + 1, size=pos.shape)
        sign[sign == 0] = -1
        moved = pos + offset * sign
        moved = np.minimum(np.maximum(moved, 0), img_size - 1)  # (2), large-level
        return moved.astype(np.float32) / float(img_size)

    ink_image = 1.0 - input_image_[0]  # (image_size, image_size, 3), [0-BG, 1-stroke]
    img_size = ink_image.shape[0]

    moved_cursors = []
    for cursor_i in range(current_pos_list.shape[0]):
        curr_cursor = current_pos_list[cursor_i][0]
        for trial_i in range(trial_times):
            candidate = _jittered(curr_cursor, img_size, move_min_dist, move_max_dist)  # (2), [0.0, 1.0)
            # keep a valid candidate, or whatever we got on the final trial
            if _has_ink(ink_image, candidate, patch_size, img_size) or trial_i == trial_times - 1:
                moved_cursors.append(candidate)
                break

    assert len(moved_cursors) == current_pos_list.shape[0]
    return np.expand_dims(np.stack(moved_cursors, axis=0), axis=1)  # (select_times, 1, 2), [0.0, 1.0)
def sample(sess, model, input_photos, init_cursor, image_size, init_len, seq_lens, state_dependent, pasting_func,
           round_stop_state_num, min_dist_p, max_dist_p):
    """Samples a sequence from a pre-trained model.

    Multi-round decoding: each entry of `seq_lens` is one round. Round 0 starts
    at `init_cursor`; later rounds jump the cursor to a random undrawn area.
    A round ends early after `round_stop_state_num` consecutive pen-up steps.
    Returns the per-step outputs plus the per-round restart cursors and the
    number of steps actually used in each round.
    """
    select_times = 1
    curr_canvas = np.zeros(dtype=np.float32, shape=(select_times, image_size, image_size))  # [0.0-BG, 1.0-stroke]

    initial_state = sess.run(model.initial_state)

    params_list = [[] for _ in range(select_times)]
    state_raw_list = [[] for _ in range(select_times)]
    state_soft_list = [[] for _ in range(select_times)]
    window_size_list = [[] for _ in range(select_times)]

    round_cursor_list = []
    round_length_real_list = []

    input_photos_tiles = np.tile(input_photos, (select_times, 1, 1, 1))

    for cursor_i, seq_len in enumerate(seq_lens):
        if cursor_i == 0:
            cursor_pos = np.squeeze(init_cursor, axis=0)  # (select_times, 1, 2)
        else:
            # restart the pen at a random position near undrawn strokes
            cursor_pos = move_cursor_to_undrawn(cursor_pos, input_photos, model.hps.raster_size,
                                                min_dist_p, max_dist_p)  # (select_times, 1, 2)
            round_cursor_list.append(cursor_pos)

        # the RNN state and stroke geometry are reset at the start of each round
        prev_state = initial_state
        prev_width = np.stack([model.hps.min_width for _ in range(select_times)], axis=0)
        prev_scaling = np.ones((select_times), dtype=np.float32)  # (N)
        prev_window_size = np.ones((select_times), dtype=np.float32) * model.hps.raster_size  # (N)

        continuous_one_state_num = 0  # consecutive pen-up steps; used for early stopping

        for i in range(seq_len):
            if not state_dependent and i % init_len == 0:
                prev_state = initial_state

            curr_window_size = prev_scaling * prev_window_size  # (N)
            curr_window_size = np.maximum(curr_window_size, model.hps.min_window_size)
            curr_window_size = np.minimum(curr_window_size, image_size)

            feed = {
                model.initial_state: prev_state,
                model.input_photo: input_photos_tiles,
                model.curr_canvas_hard: curr_canvas.copy(),
                model.cursor_position: cursor_pos,
                model.image_size: image_size,
                model.init_width: prev_width,
                model.init_scaling: prev_scaling,
                model.init_window_size: prev_window_size,
            }
            o_other_params_list, o_pen_list, o_pred_params_list, next_state_list = \
                sess.run([model.other_params, model.pen_ras, model.pred_params, model.final_state], feed_dict=feed)
            # o_other_params: (N, 6), o_pen: (N, 2), pred_params: (N, 1, 7), next_state: (N, 1024)
            # o_other_params: [tanh*2, sigmoid*2, tanh*2, sigmoid*2]

            idx_eos_list = np.argmax(o_pen_list, axis=1)  # (N); 0 = draw, 1 = pen lifted

            output_i = 0
            idx_eos = idx_eos_list[output_i]
            eos = [0, 0]
            eos[idx_eos] = 1
            other_params = o_other_params_list[output_i].tolist()  # (6)
            params_list[output_i].append([eos[1]] + other_params)
            state_raw_list[output_i].append(o_pen_list[output_i][1])
            state_soft_list[output_i].append(o_pred_params_list[output_i, 0, 0])
            window_size_list[output_i].append(curr_window_size[output_i])

            # draw the stroke and add to the canvas
            x1y1, x2y2, width2 = o_other_params_list[output_i, 0:2], o_other_params_list[output_i, 2:4], \
                o_other_params_list[output_i, 4]
            x0y0 = np.zeros_like(x2y2)  # (2), [-1.0, 1.0]
            x0y0 = np.divide(np.add(x0y0, 1.0), 2.0)  # (2), [0.0, 1.0]
            x2y2 = np.divide(np.add(x2y2, 1.0), 2.0)  # (2), [0.0, 1.0]
            widths = np.stack([prev_width[output_i], width2], axis=0)  # (2)
            o_other_params_proc = np.concatenate([x0y0, x1y1, x2y2, widths], axis=-1).tolist()  # (8)

            if idx_eos == 0:
                f = o_other_params_proc + [1.0, 1.0]
                pred_stroke_img = draw(f)  # (raster_size, raster_size), [0.0-stroke, 1.0-BG]
                pred_stroke_img_large = image_pasting_v3_testing(1.0 - pred_stroke_img, cursor_pos[output_i, 0],
                                                                image_size,
                                                                curr_window_size[output_i],
                                                                pasting_func, sess)  # [0.0-BG, 1.0-stroke]
                curr_canvas[output_i] += pred_stroke_img_large  # [0.0-BG, 1.0-stroke]
                continuous_one_state_num = 0
            else:
                continuous_one_state_num += 1

            curr_canvas = np.clip(curr_canvas, 0.0, 1.0)

            next_width = o_other_params_list[:, 4]  # (N)
            next_scaling = o_other_params_list[:, 5]
            next_window_size = next_scaling * curr_window_size  # (N)
            next_window_size = np.maximum(next_window_size, model.hps.min_window_size)
            next_window_size = np.minimum(next_window_size, image_size)

            prev_state = next_state_list
            prev_width = next_width * curr_window_size / next_window_size  # (N,)
            prev_scaling = next_scaling  # (N)
            prev_window_size = curr_window_size

            # update cursor_pos based on hps.cursor_type
            new_cursor_offsets = o_other_params_list[:, 2:4] * (
                np.expand_dims(curr_window_size, axis=-1) / 2.0)  # (N, 2), patch-level
            new_cursor_offset_next = new_cursor_offsets

            # important!!! swap (x, y) -> (y, x) to match image coordinates
            new_cursor_offset_next = np.concatenate([new_cursor_offset_next[:, 1:2], new_cursor_offset_next[:, 0:1]], axis=-1)

            cursor_pos_large = cursor_pos * float(image_size)
            stroke_position_next = cursor_pos_large[:, 0, :] + new_cursor_offset_next  # (N, 2), large-level

            if model.hps.cursor_type == 'next':
                cursor_pos_large = stroke_position_next  # (N, 2), large-level
            else:
                raise Exception('Unknown cursor_type')

            cursor_pos_large = np.minimum(np.maximum(cursor_pos_large, 0.0), float(image_size - 1))  # (N, 2), large-level
            cursor_pos_large = np.expand_dims(cursor_pos_large, axis=1)  # (N, 1, 2)
            cursor_pos = cursor_pos_large / float(image_size)

            # end the round when the pen has idled long enough or steps ran out
            if continuous_one_state_num >= round_stop_state_num or i == seq_len - 1:
                round_length_real_list.append(i + 1)
                break

    return params_list, state_raw_list, state_soft_list, curr_canvas, window_size_list, \
        round_cursor_list, round_length_real_list
def main_testing(test_image_base_dir, test_dataset, test_image_name, sampling_base_dir, model_base_dir, model_name,
                 sampling_num, min_dist_p, max_dist_p, longer_infer_lens, round_stop_state_num,
                 draw_seq=False, draw_order=False, state_dependent=True):
    """Load a trained model and simplify one rough sketch with multi-round sampling.

    Saves the input image, the predicted stroke sequence data, and a rendering
    of the strokes under `sampling_base_dir`.
    """
    model_params_default = hparams.get_default_hparams_rough()
    model_params = update_hyperparams(model_params_default, model_base_dir, model_name, infer_dataset=test_dataset)
    [test_set, eval_hps_model, sample_hps_model] = \
        load_dataset_testing(test_image_base_dir, test_dataset, test_image_name, model_params)

    # Strip the extension for output naming. FIX: the old slice
    # `test_image_name[:test_image_name.find('.')]` dropped the final character
    # when the name has no '.' (str.find returns -1); keep the full name instead.
    dot_pos = test_image_name.find('.')
    test_image_raw_name = test_image_name[:dot_pos] if dot_pos != -1 else test_image_name

    model_dir = os.path.join(model_base_dir, model_name)

    reset_graph()
    sampling_model = VirtualSketchingModel(sample_hps_model)

    # differentiable pasting graph
    paste_v3_func = DiffPastingV3(sample_hps_model.raster_size)

    tfconfig = tf.ConfigProto()
    tfconfig.gpu_options.allow_growth = True
    sess = tf.InteractiveSession(config=tfconfig)
    sess.run(tf.global_variables_initializer())

    # loads the weights from checkpoint into our model
    snapshot_step = load_checkpoint(sess, model_dir, gen_model_pretrain=True)
    print('snapshot_step', snapshot_step)

    sampling_dir = os.path.join(sampling_base_dir, test_dataset + '__' + model_name)
    os.makedirs(sampling_dir, exist_ok=True)

    for sampling_i in range(sampling_num):
        input_photos, init_cursors, test_image_size = test_set.get_test_image()
        # input_photos: (1, image_size, image_size, 3), [0-stroke, 1-BG]
        # init_cursors: (N, 1, 2), in size [0.0, 1.0)

        print()
        print(test_image_name, ', image_size:', test_image_size, ', sampling_i:', sampling_i)
        print('Processing ...')

        if init_cursors.ndim == 3:
            init_cursors = np.expand_dims(init_cursors, axis=0)

        input_photos = input_photos[0:1, :, :, :]
        ori_img = (input_photos.copy()[0] * 255.0).astype(np.uint8)
        ori_img_png = Image.fromarray(ori_img, 'RGB')
        ori_img_png.save(os.path.join(sampling_dir, test_image_raw_name + '_input.png'), 'PNG')

        # decoding for sampling
        strokes_raw_out_list, states_raw_out_list, states_soft_out_list, pred_imgs_out, \
            window_size_out_list, round_new_cursors, round_new_lengths = sample(
                sess, sampling_model, input_photos, init_cursors, test_image_size, eval_hps_model.max_seq_len,
                longer_infer_lens, state_dependent, paste_v3_func, round_stop_state_num, min_dist_p, max_dist_p)
        # pred_imgs_out: (N, H, W), [0.0-BG, 1.0-stroke]

        print('## round_lengths:', len(round_new_lengths), ':', round_new_lengths)

        output_i = 0
        strokes_raw_out = np.stack(strokes_raw_out_list[output_i], axis=0)
        states_raw_out = states_raw_out_list[output_i]
        states_soft_out = states_soft_out_list[output_i]
        window_size_out = window_size_out_list[output_i]

        # one starting cursor per round: the initial one plus each round's restart
        multi_cursors = [init_cursors[0, output_i, 0]]
        for c_i in range(len(round_new_cursors)):
            best_cursor = round_new_cursors[c_i][output_i, 0]  # (2)
            multi_cursors.append(best_cursor)
        assert len(multi_cursors) == len(round_new_lengths)

        print('strokes_raw_out', strokes_raw_out.shape)

        clean_states_soft_out = np.array(states_soft_out)  # (N)
        flag_list = strokes_raw_out[:, 0].astype(np.int32)  # (N)
        drawing_len = len(flag_list) - np.sum(flag_list)
        assert drawing_len >= 0

        # per-step debug log (kept but muted, as in the original)
        # print('  flag  raw\t soft\t x1\t\t y1\t\t x2\t\t y2\t\t r2\t\t s2')
        for i in range(strokes_raw_out.shape[0]):
            flag, x1, y1, x2, y2, r2, s2 = strokes_raw_out[i]
            win_size = window_size_out[i]
            out_format = '#%d: %d | %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f'
            out_values = (i, flag, states_raw_out[i], clean_states_soft_out[i], x1, y1, x2, y2, r2, s2)
            out_log = out_format % out_values
            # print(out_log)

        print('Saving results ...')
        save_seq_data(sampling_dir, test_image_raw_name + '_' + str(sampling_i),
                      strokes_raw_out, multi_cursors, test_image_size, round_new_lengths, eval_hps_model.min_width)

        draw_strokes(strokes_raw_out, sampling_dir, test_image_raw_name + '_' + str(sampling_i) + '_pred.png',
                     ori_img, test_image_size, multi_cursors, round_new_lengths, eval_hps_model.min_width,
                     eval_hps_model.cursor_type, sample_hps_model.raster_size, sample_hps_model.min_window_size,
                     sess, pasting_func=paste_v3_func, save_seq=draw_seq, draw_order=draw_order)


def main(model_name, test_image_name, sampling_num):
    """Run rough-sketch simplification with the fixed demo configuration."""
    test_dataset = 'rough_sketches'
    test_image_base_dir = 'sample_inputs'
    sampling_base_dir = 'outputs/sampling'
    model_base_dir = 'outputs/snapshot'

    state_dependent = False
    longer_infer_lens = [128 for _ in range(10)]  # up to 10 rounds of 128 steps
    round_stop_state_num = 12
    min_dist_p = 0.3
    max_dist_p = 0.9

    draw_seq = False
    draw_color_order = True

    # set numpy output to something sensible
    np.set_printoptions(precision=8, edgeitems=6, linewidth=200, suppress=True)

    main_testing(test_image_base_dir, test_dataset, test_image_name, sampling_base_dir, model_base_dir, model_name,
                 sampling_num, min_dist_p=min_dist_p, max_dist_p=max_dist_p,
                 draw_seq=draw_seq, draw_order=draw_color_order, state_dependent=state_dependent,
                 longer_infer_lens=longer_infer_lens, round_stop_state_num=round_stop_state_num)
def move_cursor_to_undrawn(current_canvas_list, input_image_, last_min_acc_list,
                           grid_patch_size=128, stroke_acc_threshold=0.95,
                           stroke_num_threshold=5, continuous_min_acc_threshold=2):
    """Pick a new cursor position inside the grid patch with the most undrawn pixels.

    The image is split into a grid of `grid_patch_size` patches. If every patch is
    already drawn accurately enough, drawing is finished and (None, None) is
    returned. Otherwise a random position inside the selected patch is returned.
    To avoid getting stuck, if the same patch has had the minimum accuracy for
    `continuous_min_acc_threshold` consecutive calls, that patch is selected
    instead of the patch with the most undrawn pixels.

    :param current_canvas_list: (select_times, image_size, image_size), [0.0-BG, 1.0-stroke]
    :param input_image_: (1, image_size, image_size), [0-stroke, 1-BG]
    :param last_min_acc_list: list of (patch_idx, consecutive_times) per canvas
    :return: new_cursor_pos: (select_times, 1, 2), [0.0, 1.0), or None when finished;
             updated last_min_acc_list (or None when finished)
    """
    def split_images(in_img, image_size, grid_size):
        # Pad so the side length is a multiple of grid_size, then cut into patches.
        if image_size % grid_size == 0:
            paddings_ = 0
        else:
            paddings_ = grid_size - image_size % grid_size
        paddings = [[0, paddings_],
                    [0, paddings_]]
        image_pad = np.pad(in_img, paddings, mode='constant', constant_values=0.0)  # (H_p, W_p), [0.0-BG, 1.0-stroke]
        assert image_pad.shape[0] % grid_size == 0
        split_num = image_pad.shape[0] // grid_size

        # hsplit then vsplit: patch index runs down columns first (see get_cursor).
        images_h = np.hsplit(image_pad, split_num)
        image_patches = []
        for image_h in images_h:
            images_v = np.vsplit(image_h, split_num)
            image_patches += images_v
        image_patches = np.array(image_patches, dtype=np.float32)
        return image_patches, split_num

    def line_drawing_rounding(line_drawing):
        # Binarize: any non-zero intensity counts as stroke.
        line_drawing_r = np.copy(line_drawing)  # [0.0-BG, 1.0-stroke]
        line_drawing_r[line_drawing_r != 0.0] = 1.0
        return line_drawing_r

    def cal_undrawn_pixels(in_canvas, in_sketch):
        # Per patch: number of sketch pixels not yet covered by the canvas.
        in_canvas_round = line_drawing_rounding(in_canvas).astype(np.int32)  # (N, H, W), [0.0-BG, 1.0-stroke]
        in_sketch_round = line_drawing_rounding(in_sketch).astype(np.int32)

        intersection = np.bitwise_and(in_canvas_round, in_sketch_round)
        intersection_sum = np.sum(intersection, axis=(1, 2))

        gt_sum = np.sum(in_sketch_round, axis=(1, 2))  # (N)
        undrawn_num = gt_sum - intersection_sum
        return undrawn_num

    def cal_stroke_acc(in_canvas, in_sketch):
        # Per patch: fraction of sketch pixels covered by the canvas.
        # Empty patches and patches with few remaining pixels count as done (acc 1.0).
        in_canvas_round = line_drawing_rounding(in_canvas).astype(np.int32)  # (N, H, W), [0.0-BG, 1.0-stroke]
        in_sketch_round = line_drawing_rounding(in_sketch).astype(np.int32)

        intersection = np.bitwise_and(in_canvas_round, in_sketch_round)
        intersection_sum = np.sum(intersection, axis=(1, 2)).astype(np.float32)

        gt_sum = np.sum(in_sketch_round, axis=(1, 2)).astype(np.float32)  # (N)
        undrawn_num = gt_sum - intersection_sum  # (N)

        stroke_acc = intersection_sum / gt_sum  # (N)
        stroke_acc[gt_sum == 0.0] = 1.0
        stroke_acc[undrawn_num <= stroke_num_threshold] = 1.0
        return stroke_acc

    def get_cursor(patch_idx, img_size, grid_size, split_num):
        # Sample a random position inside the central half of the selected patch.
        y_pos = patch_idx % split_num
        x_pos = patch_idx // split_num

        y_top = y_pos * grid_size + grid_size // 4
        y_bottom = y_top + grid_size // 2
        x_left = x_pos * grid_size + grid_size // 4
        x_right = x_left + grid_size // 2

        cursor_y = random.randint(y_top, y_bottom)
        cursor_x = random.randint(x_left, x_right)
        # Clamp into image bounds (padding can push patches past the edge).
        cursor_y = max(0, min(cursor_y, img_size - 1))
        cursor_x = max(0, min(cursor_x, img_size - 1))

        # (2), in large size
        center = np.array([cursor_x, cursor_y], dtype=np.float32)
        return center / float(img_size)  # (2), in size [0.0, 1.0)

    input_image = 1.0 - input_image_[0]  # (image_size, image_size), [0-BG, 1-stroke]
    img_size = input_image.shape[0]
    input_image_patches, split_number = split_images(input_image, img_size, grid_patch_size)  # (N, grid_size, grid_size)

    new_cursor_pos = []
    last_min_acc_list_new = [item for item in last_min_acc_list]

    for canvas_i in range(current_canvas_list.shape[0]):
        curr_canvas = current_canvas_list[canvas_i]  # (image_size, image_size), [0.0-BG, 1.0-stroke]
        curr_canvas_patches, _ = split_images(curr_canvas, img_size, grid_patch_size)  # (N, grid_size, grid_size)

        # 1. detect ending flag by stroke accuracy
        stroke_accuracy = cal_stroke_acc(curr_canvas_patches, input_image_patches)
        min_acc_idx = np.argmin(stroke_accuracy)
        min_acc = stroke_accuracy[min_acc_idx]

        if min_acc >= stroke_acc_threshold:
            # end of drawing
            return None, None

        # 2. detect undrawn pixels
        undrawn_pixel_num = cal_undrawn_pixels(curr_canvas_patches, input_image_patches)
        max_undrawn_idx = np.argmax(undrawn_pixel_num)

        # 3. select a patch: usually the one with the most undrawn pixels, but if
        #    the same patch has been stuck at minimum accuracy for several calls,
        #    jump to that patch instead and reset its counter.
        last_min_acc_idx, last_min_acc_times = last_min_acc_list[canvas_i]
        if last_min_acc_times >= continuous_min_acc_threshold and last_min_acc_idx == min_acc_idx:
            selected_patch_idx = last_min_acc_idx
            new_min_acc_times = 1
        else:
            selected_patch_idx = max_undrawn_idx
            if min_acc_idx == last_min_acc_idx:
                new_min_acc_times = last_min_acc_times + 1
            else:
                new_min_acc_times = 1
        new_min_acc_idx = min_acc_idx
        last_min_acc_list_new[canvas_i] = (new_min_acc_idx, new_min_acc_times)

        # 4. get cursor according to the selected_patch_idx
        rand_cursor = get_cursor(selected_patch_idx, img_size, grid_patch_size, split_number)  # (2), in size [0.0, 1.0)
        new_cursor_pos.append(rand_cursor)

    assert len(new_cursor_pos) == current_canvas_list.shape[0]
    new_cursor_pos = np.expand_dims(np.stack(new_cursor_pos, axis=0), axis=1)  # (select_times, 1, 2), [0.0, 1.0)
    return new_cursor_pos, last_min_acc_list_new
def sample(sess, model, input_photos, init_cursor, image_size, init_len, seq_lens, state_dependent,
           pasting_func, round_stop_state_num, stroke_acc_threshold):
    """Samples a sequence from a pre-trained model.

    Runs multiple drawing rounds; each round starts from a cursor position (the
    initial one, then positions returned by move_cursor_to_undrawn) and draws up
    to seq_len strokes, pasting each predicted stroke onto a growing canvas.
    A round stops early after `round_stop_state_num` consecutive pen-up states.

    Returns per-sample stroke parameter lists, raw/soft pen states, the final
    canvas, per-step window sizes, the per-round start cursors, and the real
    length of each round.
    """
    select_times = 1
    curr_canvas = np.zeros(dtype=np.float32, shape=(select_times, image_size, image_size))  # [0.0-BG, 1.0-stroke]

    initial_state = sess.run(model.initial_state)

    prev_width = np.stack([model.hps.min_width for _ in range(select_times)], axis=0)

    params_list = [[] for _ in range(select_times)]
    state_raw_list = [[] for _ in range(select_times)]
    state_soft_list = [[] for _ in range(select_times)]
    window_size_list = [[] for _ in range(select_times)]
    # (patch_idx, consecutive_times) state threaded through move_cursor_to_undrawn
    last_min_stroke_acc_list = [(-1, 0) for _ in range(select_times)]
    round_cursor_list = []
    round_length_real_list = []

    input_photos_tiles = np.tile(input_photos, (select_times, 1, 1))

    for cursor_i, seq_len in enumerate(seq_lens):
        if cursor_i == 0:
            cursor_pos = np.squeeze(init_cursor, axis=0)  # (select_times, 1, 2)
        else:
            # Move the cursor to the patch with the most undrawn pixels;
            # a None cursor means the sketch is complete.
            cursor_pos, last_min_stroke_acc_list_updated = \
                move_cursor_to_undrawn(curr_canvas, input_photos, last_min_stroke_acc_list,
                                       grid_patch_size=model.hps.raster_size,
                                       stroke_acc_threshold=stroke_acc_threshold)  # (select_times, 1, 2)
            if cursor_pos is not None:
                round_cursor_list.append(cursor_pos)
                last_min_stroke_acc_list = last_min_stroke_acc_list_updated
            else:
                break

        # Reset per-round decoder state: width, scaling and window size.
        prev_state = initial_state
        if not model.hps.init_cursor_on_undrawn_pixel:
            prev_width = np.stack([model.hps.min_width for _ in range(select_times)], axis=0)
        prev_scaling = np.ones((select_times), dtype=np.float32)  # (N)
        prev_window_size = np.ones((select_times), dtype=np.float32) * model.hps.raster_size  # (N)

        continuous_one_state_num = 0

        for i in range(seq_len):
            # In state-independent mode the RNN state is reset every init_len steps.
            if not state_dependent and i % init_len == 0:
                prev_state = initial_state

            curr_window_size = prev_scaling * prev_window_size  # (N)
            curr_window_size = np.maximum(curr_window_size, model.hps.min_window_size)
            curr_window_size = np.minimum(curr_window_size, image_size)

            feed = {
                model.initial_state: prev_state,
                model.input_photo: np.expand_dims(input_photos_tiles, axis=-1),
                model.curr_canvas_hard: curr_canvas.copy(),
                model.cursor_position: cursor_pos,
                model.image_size: image_size,
                model.init_width: prev_width,
                model.init_scaling: prev_scaling,
                model.init_window_size: prev_window_size,
            }

            o_other_params_list, o_pen_list, o_pred_params_list, next_state_list = \
                sess.run([model.other_params, model.pen_ras, model.pred_params, model.final_state],
                         feed_dict=feed)
            # o_other_params: (N, 6), o_pen: (N, 2), pred_params: (N, 1, 7), next_state: (N, 1024)
            # o_other_params: [tanh*2, sigmoid*2, tanh*2, sigmoid*2]

            idx_eos_list = np.argmax(o_pen_list, axis=1)  # (N)

            output_i = 0
            idx_eos = idx_eos_list[output_i]

            eos = [0, 0]
            eos[idx_eos] = 1

            other_params = o_other_params_list[output_i].tolist()  # (6)
            params_list[output_i].append([eos[1]] + other_params)
            state_raw_list[output_i].append(o_pen_list[output_i][1])
            state_soft_list[output_i].append(o_pred_params_list[output_i, 0, 0])
            window_size_list[output_i].append(curr_window_size[output_i])

            # draw the stroke and add to the canvas
            x1y1, x2y2, width2 = o_other_params_list[output_i, 0:2], o_other_params_list[output_i, 2:4], \
                                 o_other_params_list[output_i, 4]
            x0y0 = np.zeros_like(x2y2)  # (2), [-1.0, 1.0]
            x0y0 = np.divide(np.add(x0y0, 1.0), 2.0)  # (2), [0.0, 1.0]
            x2y2 = np.divide(np.add(x2y2, 1.0), 2.0)  # (2), [0.0, 1.0]
            widths = np.stack([prev_width[output_i], width2], axis=0)  # (2)
            o_other_params_proc = np.concatenate([x0y0, x1y1, x2y2, widths], axis=-1).tolist()  # (8)

            if idx_eos == 0:
                # pen down: rasterize the stroke patch and paste it at the cursor
                f = o_other_params_proc + [1.0, 1.0]
                pred_stroke_img = draw(f)  # (raster_size, raster_size), [0.0-stroke, 1.0-BG]
                pred_stroke_img_large = image_pasting_v3_testing(1.0 - pred_stroke_img,
                                                                cursor_pos[output_i, 0],
                                                                image_size,
                                                                curr_window_size[output_i],
                                                                pasting_func, sess)  # [0.0-BG, 1.0-stroke]
                curr_canvas[output_i] += pred_stroke_img_large  # [0.0-BG, 1.0-stroke]
                continuous_one_state_num = 0
            else:
                continuous_one_state_num += 1

            curr_canvas = np.clip(curr_canvas, 0.0, 1.0)

            next_width = o_other_params_list[:, 4]  # (N)
            next_scaling = o_other_params_list[:, 5]
            next_window_size = next_scaling * curr_window_size  # (N)
            next_window_size = np.maximum(next_window_size, model.hps.min_window_size)
            next_window_size = np.minimum(next_window_size, image_size)

            prev_state = next_state_list
            # width is stored relative to the window, so rescale for the next window
            prev_width = next_width * curr_window_size / next_window_size  # (N,)
            prev_scaling = next_scaling  # (N)
            prev_window_size = curr_window_size

            # update cursor_pos based on hps.cursor_type
            new_cursor_offsets = o_other_params_list[:, 2:4] * (np.expand_dims(curr_window_size, axis=-1) / 2.0)  # (N, 2), patch-level
            new_cursor_offset_next = new_cursor_offsets

            # important!!! the offset is stored (x, y) but the cursor is (y, x)
            new_cursor_offset_next = np.concatenate([new_cursor_offset_next[:, 1:2], new_cursor_offset_next[:, 0:1]], axis=-1)

            cursor_pos_large = cursor_pos * float(image_size)
            stroke_position_next = cursor_pos_large[:, 0, :] + new_cursor_offset_next  # (N, 2), large-level

            if model.hps.cursor_type == 'next':
                cursor_pos_large = stroke_position_next  # (N, 2), large-level
            else:
                raise Exception('Unknown cursor_type')

            cursor_pos_large = np.minimum(np.maximum(cursor_pos_large, 0.0), float(image_size - 1))  # (N, 2), large-level
            cursor_pos_large = np.expand_dims(cursor_pos_large, axis=1)  # (N, 1, 2)
            cursor_pos = cursor_pos_large / float(image_size)

            # stop the round after too many consecutive pen-up states
            if continuous_one_state_num >= round_stop_state_num or i == seq_len - 1:
                round_length_real_list.append(i + 1)
                break

    return params_list, state_raw_list, state_soft_list, curr_canvas, window_size_list, \
        round_cursor_list, round_length_real_list
def main_testing(test_image_base_dir, test_dataset, test_image_name,
                 sampling_base_dir, model_base_dir, model_name, sampling_num,
                 longer_infer_lens, round_stop_state_num, stroke_acc_threshold,
                 draw_seq=False, draw_order=False, state_dependent=True):
    """Load a trained model and vectorize one test image `sampling_num` times.

    For every sampling run, the input image, the predicted stroke sequence (npz)
    and a rendered prediction image are written under
    `<sampling_base_dir>/<test_dataset>__<model_name>/`. Average stroke count
    and compute time over all runs are printed at the end.
    """
    model_params_default = hparams.get_default_hparams_clean()
    model_params = update_hyperparams(model_params_default, model_base_dir, model_name,
                                      infer_dataset=test_dataset)

    [test_set, eval_hps_model, sample_hps_model] \
        = load_dataset_testing(test_image_base_dir, test_dataset, test_image_name, model_params)

    test_image_raw_name = test_image_name[:test_image_name.find('.')]
    model_dir = os.path.join(model_base_dir, model_name)

    reset_graph()
    sampling_model = VirtualSketchingModel(sample_hps_model)

    # differentiable pasting graph
    paste_v3_func = DiffPastingV3(sample_hps_model.raster_size)

    tfconfig = tf.ConfigProto()
    tfconfig.gpu_options.allow_growth = True
    sess = tf.InteractiveSession(config=tfconfig)
    sess.run(tf.global_variables_initializer())

    # loads the weights from checkpoint into our model
    snapshot_step = load_checkpoint(sess, model_dir, gen_model_pretrain=True)
    print('snapshot_step', snapshot_step)

    sampling_dir = os.path.join(sampling_base_dir, test_dataset + '__' + model_name)
    os.makedirs(sampling_dir, exist_ok=True)

    stroke_number_list = []
    compute_time_list = []

    for sampling_i in range(sampling_num):
        start_time_point = time.time()
        input_photos, init_cursors, test_image_size = test_set.get_test_image()
        # input_photos: (1, image_size, image_size), [0-stroke, 1-BG]
        # init_cursors: (1, 1, 2), in size [0.0, 1.0)

        print()
        print(test_image_name, ', image_size:', test_image_size, ', sampling_i:', sampling_i)
        print('Processing ...')

        if init_cursors.ndim == 3:
            init_cursors = np.expand_dims(init_cursors, axis=0)

        input_photos = input_photos[0:1, :, :]
        # save the grayscale input replicated to RGB for later visualization
        ori_img = (input_photos.copy()[0] * 255.0).astype(np.uint8)
        ori_img = np.stack([ori_img for _ in range(3)], axis=2)
        ori_img_png = Image.fromarray(ori_img, 'RGB')
        ori_img_png.save(os.path.join(sampling_dir, test_image_raw_name + '_input.png'), 'PNG')

        data_loading_time_point = time.time()

        # decoding for sampling
        strokes_raw_out_list, states_raw_out_list, states_soft_out_list, pred_imgs_out, \
            window_size_out_list, round_new_cursors, round_new_lengths = sample(
                sess, sampling_model, input_photos, init_cursors, test_image_size,
                eval_hps_model.max_seq_len, longer_infer_lens, state_dependent, paste_v3_func,
                round_stop_state_num, stroke_acc_threshold)
        # pred_imgs_out: [0.0-BG, 1.0-stroke]
        print('## round_lengths:', len(round_new_lengths), ':', round_new_lengths)

        sampling_time_point = time.time()
        data_loading_time = data_loading_time_point - start_time_point
        sampling_time_total = sampling_time_point - start_time_point
        sampling_time_wo_data_loading = sampling_time_point - data_loading_time_point
        compute_time_list.append(sampling_time_total)
        print(' >>> sampling_time_total', sampling_time_total)

        # only one candidate is sampled (select_times == 1 in sample())
        best_result_idx = 0
        strokes_raw_out = np.stack(strokes_raw_out_list[best_result_idx], axis=0)
        states_raw_out = states_raw_out_list[best_result_idx]
        states_soft_out = states_soft_out_list[best_result_idx]
        window_size_out = window_size_out_list[best_result_idx]

        # start cursor of every round: initial cursor plus the per-round new cursors
        multi_cursors = [init_cursors[0, best_result_idx, 0]]
        for c_i in range(len(round_new_cursors)):
            best_cursor = round_new_cursors[c_i][best_result_idx, 0]  # (2)
            multi_cursors.append(best_cursor)
        assert len(multi_cursors) == len(round_new_lengths)

        print('strokes_raw_out', strokes_raw_out.shape)
        stroke_number_list.append(strokes_raw_out.shape[0])

        clean_states_soft_out = np.array(states_soft_out)  # (N)

        flag_list = strokes_raw_out[:, 0].astype(np.int32)  # (N)
        drawing_len = len(flag_list) - np.sum(flag_list)
        assert drawing_len >= 0

        # Debug formatting of each stroke; the formatted line is built but the
        # print statement is intentionally left disabled.
        for i in range(strokes_raw_out.shape[0]):
            flag, x1, y1, x2, y2, r2, s2 = strokes_raw_out[i]
            win_size = window_size_out[i]
            out_format = '#%d: %d | %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f'
            out_values = (i, flag, states_raw_out[i], clean_states_soft_out[i], x1, y1, x2, y2, r2, s2)
            out_log = out_format % out_values
            # print(out_log)

        print('Saving results ...')
        save_seq_data(sampling_dir, test_image_raw_name + '_' + str(sampling_i),
                      strokes_raw_out, multi_cursors, test_image_size, round_new_lengths,
                      eval_hps_model.min_width)

        draw_strokes(strokes_raw_out, sampling_dir, test_image_raw_name + '_' + str(sampling_i) + '_pred.png',
                     ori_img, test_image_size, multi_cursors, round_new_lengths, eval_hps_model.min_width,
                     eval_hps_model.cursor_type, sample_hps_model.raster_size, sample_hps_model.min_window_size,
                     sess, pasting_func=paste_v3_func, save_seq=draw_seq, draw_order=draw_order)

    average_stroke_number = np.mean(stroke_number_list)
    average_compute_time = np.mean(compute_time_list)
    print()
    print('@@@ Total summary:')
    print(' >>> average_stroke_number', average_stroke_number)
    print(' >>> average_compute_time', average_compute_time)
def main(model_name, test_image_name, sampling_num):
    """Vectorize a clean line drawing with a trained model."""
    # fixed I/O layout of the project
    dataset = 'clean_line_drawings'
    image_dir = 'sample_inputs'
    out_dir = 'outputs/sampling'
    snapshot_dir = 'outputs/snapshot'

    # inference configuration: up to 10 rounds of 500 steps each
    infer_lens = [500] * 10

    # set numpy output to something sensible
    np.set_printoptions(precision=8, edgeitems=6, linewidth=200, suppress=True)

    main_testing(image_dir, dataset, test_image_name, out_dir,
                 snapshot_dir, model_name, sampling_num,
                 draw_seq=False, draw_order=True,
                 state_dependent=False, longer_infer_lens=infer_lens,
                 round_stop_state_num=12, stroke_acc_threshold=0.95)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', '-i', type=str, default='', help="The test image name.")
    parser.add_argument('--model', '-m', type=str, default='pretrain_clean_line_drawings',
                        help="The trained model.")
    parser.add_argument('--sample', '-s', type=int, default=1, help="The number of outputs.")
    args = parser.parse_args()

    assert args.input != ''
    assert args.sample > 0
    main(args.model, args.input, args.sample)
def add_scaling_visualization(canvas_images, cursor, window_size, image_size):
    """Overlay the cursor position and the attention-window box on canvases.

    The float canvases are inverted and converted to uint8 images, then a red
    square marks the cursor and a red rectangle outlines the current window.

    :param canvas_images: (N, H, W, 3), float in [0.0, 1.0]
    :param cursor: (2), cursor position in [0.0, 1.0)
    :param window_size: window side length in pixels (int)
    :param image_size: canvas side length in pixels
    :return: (N, H, W, 3) uint8 images with the visualization drawn in red
    """
    cursor_pos = cursor * float(image_size)
    cursor_x, cursor_y = int(round(cursor_pos[0])), int(round(cursor_pos[1]))  # in large size
    vis_color = [255, 0, 0]
    cursor_width = 3
    box_width = 2

    # invert: strokes become dark on a white background
    canvas_imgs = 255 - np.round(canvas_images * 255.0).astype(np.uint8)

    # add cursor visualization; clamp the slice bounds so a cursor near the
    # border does not produce negative indices (which would wrap around)
    marker_top = max(0, cursor_y - cursor_width)
    marker_bottom = min(image_size, cursor_y + cursor_width)
    marker_left = max(0, cursor_x - cursor_width)
    marker_right = min(image_size, cursor_x + cursor_width)
    canvas_imgs[:, marker_top: marker_bottom, marker_left: marker_right, :] = vis_color

    # add box visualization (edges are omitted when the box is cut by the border)
    up = max(0, cursor_y - window_size // 2)
    down = min(image_size, cursor_y + window_size // 2)
    left = max(0, cursor_x - window_size // 2)
    right = min(image_size, cursor_x + window_size // 2)

    if up > 0:
        canvas_imgs[:, up: up + box_width, left: right, :] = vis_color
    if down < image_size:
        canvas_imgs[:, max(0, down - box_width): down, left: right, :] = vis_color
    if left > 0:
        canvas_imgs[:, up: down, left: left + box_width, :] = vis_color
    if right < image_size:
        canvas_imgs[:, up: down, max(0, right - box_width): right, :] = vis_color

    return canvas_imgs
def make_gif(sess, pasting_func, data, init_cursor, image_size, infer_lengths, init_width,
             save_base, cursor_type='next', min_window_size=32, raster_size=128, add_box=True):
    """Replay a stored stroke sequence and save the drawing process as a GIF.

    :param data: (N_strokes, 9): flag, x0, y0, x1, y1, x2, y2, r0, r2
    :param init_cursor: (2) or list of (2) round start cursors in [0.0, 1.0)
    :param infer_lengths: per-round stroke counts
    :param add_box: when True, overlay cursor/window visualization on each frame
    :return: None; writes <save_base>/dynamic.gif
    """
    canvas = np.zeros((image_size, image_size), dtype=np.float32)  # [0.0-BG, 1.0-stroke]
    gif_frames = []

    cursor_idx = 0

    if init_cursor.ndim == 1:
        init_cursor = [init_cursor]

    for round_idx in range(len(infer_lengths)):
        print('Making progress', round_idx + 1, '/', len(infer_lengths))
        round_length = infer_lengths[round_idx]

        cursor_pos = init_cursor[cursor_idx]  # (2)
        cursor_idx += 1

        prev_width = init_width
        prev_scaling = 1.0
        prev_window_size = float(raster_size)  # (1)

        for round_inner_i in range(round_length):
            stroke_idx = np.sum(infer_lengths[:round_idx]).astype(np.int32) + round_inner_i

            curr_window_size_raw = prev_scaling * prev_window_size
            curr_window_size_raw = np.maximum(curr_window_size_raw, min_window_size)
            curr_window_size_raw = np.minimum(curr_window_size_raw, image_size)
            curr_window_size = int(round(curr_window_size_raw))  # ()

            pen_state = data[stroke_idx, 0]
            stroke_params = data[stroke_idx, 1:]  # (8)

            x1y1, x2y2, width2, scaling2 = stroke_params[0:2], stroke_params[2:4], stroke_params[4], stroke_params[5]
            x0y0 = np.zeros_like(x2y2)  # (2), [-1.0, 1.0]
            x0y0 = np.divide(np.add(x0y0, 1.0), 2.0)  # (2), [0.0, 1.0]
            x2y2 = np.divide(np.add(x2y2, 1.0), 2.0)  # (2), [0.0, 1.0]
            widths = np.stack([prev_width, width2], axis=0)  # (2)
            stroke_params_proc = np.concatenate([x0y0, x1y1, x2y2, widths], axis=-1)  # (8)

            next_width = stroke_params[4]
            next_scaling = stroke_params[5]
            next_window_size = next_scaling * curr_window_size_raw
            next_window_size = np.maximum(next_window_size, min_window_size)
            next_window_size = np.minimum(next_window_size, image_size)

            # width is relative to the window, so rescale it for the next window
            prev_width = next_width * curr_window_size_raw / next_window_size
            prev_scaling = next_scaling
            prev_window_size = curr_window_size_raw

            f = stroke_params_proc.tolist()  # (8)
            f += [1.0, 1.0]
            gt_stroke_img = draw(f)  # (H, W), [0.0-stroke, 1.0-BG]
            gt_stroke_img_large = image_pasting_v3_testing(1.0 - gt_stroke_img, cursor_pos, image_size,
                                                          curr_window_size_raw,
                                                          pasting_func, sess)  # [0.0-BG, 1.0-stroke]

            if pen_state == 0:
                canvas += gt_stroke_img_large  # [0.0-BG, 1.0-stroke]

            canvas_rgb = np.stack([np.clip(canvas, 0.0, 1.0) for _ in range(3)], axis=-1)

            if add_box:
                vis_inputs = np.expand_dims(canvas_rgb, axis=0)
                vis_outputs = add_scaling_visualization(vis_inputs, cursor_pos, curr_window_size, image_size)
                canvas_vis = vis_outputs[0]
            else:
                # NOTE(review): without add_box this stays a float array, which
                # Image.fromarray(..., 'RGB') does not accept — confirm add_box=False
                # is ever used before relying on this branch.
                canvas_vis = canvas_rgb

            canvas_vis_png = Image.fromarray(canvas_vis, 'RGB')
            gif_frames.append(canvas_vis_png)

            # update cursor_pos based on hps.cursor_type
            new_cursor_offsets = stroke_params[2:4] * (float(curr_window_size_raw) / 2.0)  # patch-level
            new_cursor_offset_next = new_cursor_offsets

            # important!!! the offset is stored (x, y) but the cursor is (y, x)
            new_cursor_offset_next = np.concatenate([new_cursor_offset_next[1:2], new_cursor_offset_next[0:1]], axis=-1)

            cursor_pos_large = cursor_pos * float(image_size)
            stroke_position_next = cursor_pos_large + new_cursor_offset_next  # (2), large-level

            if cursor_type == 'next':
                cursor_pos_large = stroke_position_next  # (2), large-level
            else:
                raise Exception('Unknown cursor_type')

            cursor_pos_large = np.minimum(np.maximum(cursor_pos_large, 0.0), float(image_size - 1))  # (2), large-level
            cursor_pos = cursor_pos_large / float(image_size)

    print('Saving to GIF ...')
    save_path = os.path.join(save_base, 'dynamic.gif')
    first_frame = gif_frames[0]
    # append only the remaining frames: appending the full list would duplicate
    # the first frame in the resulting GIF
    first_frame.save(save_path, save_all=True, append_images=gif_frames[1:],
                     loop=0, duration=0.01)
def gif_making(npz_path):
    """Render the drawing sequence stored in an npz file into a GIF."""
    assert npz_path != ''

    min_window_size = 32
    raster_size = 128

    # output directory is <dir>/<npz stem>/ next to the input file
    sep_idx = npz_path.rfind('/')
    if sep_idx == -1:
        out_parent = './'
        out_stem = npz_path[:-4]
    else:
        out_parent = npz_path[:npz_path.rfind('/')]
        out_stem = npz_path[npz_path.rfind('/') + 1: -4]
    gif_base = os.path.join(out_parent, out_stem)
    os.makedirs(gif_base, exist_ok=True)

    # differentiable pasting graph
    paste_v3_func = DiffPastingV3(raster_size)

    tfconfig = tf.ConfigProto()
    tfconfig.gpu_options.allow_growth = True
    sess = tf.InteractiveSession(config=tfconfig)
    sess.run(tf.global_variables_initializer())

    archive = np.load(npz_path, encoding='latin1', allow_pickle=True)
    strokes_data = archive['strokes_data']
    init_cursors = archive['init_cursors']
    image_size = archive['image_size']
    round_length = archive['round_length']
    init_width = archive['init_width']

    # a scalar round_length means a single drawing round
    round_lengths = [round_length] if round_length.ndim == 0 else round_length

    make_gif(sess, paste_v3_func, strokes_data, init_cursors, image_size, round_lengths,
             init_width, gif_base, min_window_size=min_window_size, raster_size=raster_size)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--file', '-f', type=str, default='', help="define a npz path")
    args = parser.parse_args()
    gif_making(args.file)
def write_svg_1(path_list, img_size, save_path):
    '''
    A long curve consisting of several strokes forms a path.
    '''
    # build an empty SVG document
    dom = minidom.getDOMImplementation().createDocument(None, None, None)
    svg_root = dom.createElement('svg')
    svg_root.setAttribute('xmlns', 'http://www.w3.org/2000/svg')
    svg_root.setAttribute('height', str(img_size))
    svg_root.setAttribute('width', str(img_size))

    for curve_idx, curve_items in enumerate(path_list):
        assert len(curve_items) > 0
        # a path with only a start position has nothing drawable
        if len(curve_items) == 1:
            continue

        path_el = dom.createElement('path')
        path_el.setAttribute('id', 'curve_' + str(curve_idx))
        path_el.setAttribute('stroke', '#000000')
        path_el.setAttribute('stroke-width', '3.5')
        path_el.setAttribute('stroke-linejoin', 'round')
        path_el.setAttribute('stroke-linecap', 'round')
        path_el.setAttribute('fill', 'none')

        d_cmd = ''
        prev_pt = None
        for seg_idx, seg in enumerate(curve_items):
            if seg_idx == 0:
                # first item holds only the absolute start position
                end_pt = seg[0]
                d_cmd += 'M ' + str(end_pt[0]) + ', ' + str(end_pt[1]) + ' '
            else:
                ctrl, end_pt, _seg_width = seg[0], seg[1], seg[2]
                # deliberate axis swap: ctrl[1] interpolates x, ctrl[0] interpolates y
                cx = prev_pt[0] + (end_pt[0] - prev_pt[0]) * ctrl[1]
                cy = prev_pt[1] + (end_pt[1] - prev_pt[1]) * ctrl[0]
                d_cmd += 'Q ' + str(cx) + ', ' + str(cy) + ', ' + \
                         str(end_pt[0]) + ', ' + str(end_pt[1]) + ' '
            prev_pt = end_pt

        path_el.setAttribute('d', d_cmd)
        svg_root.appendChild(path_el)

    dom.appendChild(svg_root)
    with open(save_path, 'w') as fp:
        dom.writexml(fp, addindent=' ', newl='\n')
def write_svg_2(path_list, img_size, save_path):
    '''
    A single stroke forms a path.
    '''
    # build an empty SVG document
    dom = minidom.getDOMImplementation().createDocument(None, None, None)
    svg_root = dom.createElement('svg')
    svg_root.setAttribute('xmlns', 'http://www.w3.org/2000/svg')
    svg_root.setAttribute('height', str(img_size))
    svg_root.setAttribute('width', str(img_size))

    for curve_idx, curve_items in enumerate(path_list):
        assert len(curve_items) > 0
        # a path with only a start position has nothing drawable
        if len(curve_items) == 1:
            continue

        prev_pt = None
        for seg_idx, seg in enumerate(curve_items):
            if seg_idx == 0:
                # first item holds only the absolute start position
                prev_pt = seg[0]
                continue

            path_el = dom.createElement('path')
            # strokes belonging to the same curve share one id
            path_el.setAttribute('id', 'curve_' + str(curve_idx))
            path_el.setAttribute('stroke', '#000000')
            path_el.setAttribute('stroke-linejoin', 'round')
            path_el.setAttribute('stroke-linecap', 'round')
            path_el.setAttribute('fill', 'none')

            ctrl, end_pt, seg_width = seg[0], seg[1], seg[2]
            # deliberate axis swap: ctrl[1] interpolates x, ctrl[0] interpolates y
            cx = prev_pt[0] + (end_pt[0] - prev_pt[0]) * ctrl[1]
            cy = prev_pt[1] + (end_pt[1] - prev_pt[1]) * ctrl[0]
            d_cmd = 'M ' + str(prev_pt[0]) + ', ' + str(prev_pt[1]) + ' ' + \
                    'Q ' + str(cx) + ', ' + str(cy) + ', ' + \
                    str(end_pt[0]) + ', ' + str(end_pt[1]) + ' '
            prev_pt = end_pt

            path_el.setAttribute('d', d_cmd)
            # width is stored relative to the image; scale back to pixels
            path_el.setAttribute('stroke-width', str(seg_width * img_size / 1.66))
            svg_root.appendChild(path_el)

    dom.appendChild(svg_root)
    with open(save_path, 'w') as fp:
        dom.writexml(fp, addindent=' ', newl='\n')
def convert_strokes_to_svg(data, init_cursor, image_size, infer_lengths, init_width, save_path,
                           svg_type, cursor_type='next', min_window_size=32, raster_size=128):
    """Convert a relative stroke sequence to absolute coordinates and save as SVG.

    Strokes are grouped into continuous paths: a pen-up stroke (flag != 0) ends
    the current path and starts a new one at the moved-to position. The result
    is written via write_svg_1 ('cluster': one SVG path per curve) or
    write_svg_2 ('single': one SVG path per stroke).

    :param data: (N_strokes, 7): flag, x_c, y_c, dx, dy, r, ds
    :param init_cursor: (2) or list of (2) round start cursors in [0.0, 1.0)
    :param infer_lengths: per-round stroke counts
    :return: None; writes the SVG to save_path
    """
    cursor_idx = 0

    absolute_strokes = []
    absolute_strokes_path = []

    if init_cursor.ndim == 1:
        init_cursor = [init_cursor]

    for round_idx in range(len(infer_lengths)):
        round_length = infer_lengths[round_idx]

        cursor_pos = init_cursor[cursor_idx]  # (2)
        cursor_idx += 1
        cursor_pos_large = cursor_pos * float(image_size)

        # each round starts a new path at its own start cursor
        if len(absolute_strokes_path) > 0:
            absolute_strokes.append(absolute_strokes_path)
        absolute_strokes_path = [[cursor_pos_large]]

        prev_width = init_width
        prev_scaling = 1.0
        prev_window_size = float(raster_size)  # (1)

        for round_inner_i in range(round_length):
            stroke_idx = np.sum(infer_lengths[:round_idx]).astype(np.int32) + round_inner_i

            curr_window_size_raw = prev_scaling * prev_window_size
            curr_window_size_raw = np.maximum(curr_window_size_raw, min_window_size)
            curr_window_size_raw = np.minimum(curr_window_size_raw, image_size)

            stroke_params = data[stroke_idx, 1:]  # (6)
            pen_state = data[stroke_idx, 0]

            next_width = stroke_params[4]
            next_scaling = stroke_params[5]
            # stroke width in absolute (image-relative) units
            next_width_abs = next_width * curr_window_size_raw / float(image_size)

            prev_scaling = next_scaling
            prev_window_size = curr_window_size_raw

            # update cursor_pos based on hps.cursor_type
            new_cursor_offsets = stroke_params[2:4] * (float(curr_window_size_raw) / 2.0)  # patch-level
            new_cursor_offset_next = new_cursor_offsets

            # important!!! the offset is stored (x, y) but the cursor is (y, x)
            new_cursor_offset_next = np.concatenate([new_cursor_offset_next[1:2], new_cursor_offset_next[0:1]], axis=-1)

            cursor_pos_large = cursor_pos * float(image_size)
            stroke_position_next = cursor_pos_large + new_cursor_offset_next  # (2), large-level

            if pen_state == 0:
                # pen down: extend the current path with (control, end, width)
                absolute_strokes_path.append([stroke_params[0:2], stroke_position_next, next_width_abs])
            else:
                # pen up: close the current path and start a new one
                absolute_strokes.append(absolute_strokes_path)
                absolute_strokes_path = [[stroke_position_next]]

            if cursor_type == 'next':
                cursor_pos_large = stroke_position_next  # (2), large-level
            else:
                raise Exception('Unknown cursor_type')

            cursor_pos_large = np.minimum(np.maximum(cursor_pos_large, 0.0), float(image_size - 1))  # (2), large-level
            cursor_pos = cursor_pos_large / float(image_size)

    absolute_strokes.append(absolute_strokes_path)

    if svg_type == 'cluster':
        write_svg_1(absolute_strokes, image_size, save_path)
    elif svg_type == 'single':
        write_svg_2(absolute_strokes, image_size, save_path)
    else:
        raise Exception('Unknown svg_type', svg_type)
def data_convert_to_absolute(npz_path, svg_type):
    """Load a stroke-sequence npz file and export it as an SVG file."""
    assert npz_path != ''
    assert svg_type in ['single', 'cluster']

    min_window_size = 32
    raster_size = 128

    # output directory is <dir>/<npz stem>/ next to the input file
    sep_idx = npz_path.rfind('/')
    if sep_idx == -1:
        out_parent = './'
        out_stem = npz_path[:-4]
    else:
        out_parent = npz_path[:npz_path.rfind('/')]
        out_stem = npz_path[npz_path.rfind('/') + 1: -4]
    svg_data_base = os.path.join(out_parent, out_stem)
    os.makedirs(svg_data_base, exist_ok=True)

    archive = np.load(npz_path, encoding='latin1', allow_pickle=True)
    strokes_data = archive['strokes_data']
    init_cursors = archive['init_cursors']
    image_size = archive['image_size']
    round_length = archive['round_length']
    init_width = archive['init_width']

    # a scalar round_length means a single drawing round
    round_lengths = [round_length] if round_length.ndim == 0 else round_length

    save_path = os.path.join(svg_data_base, str(svg_type) + '.svg')
    convert_strokes_to_svg(strokes_data, init_cursors, image_size, round_lengths, init_width,
                           min_window_size=min_window_size, raster_size=raster_size,
                           save_path=save_path, svg_type=svg_type)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--file', '-f', type=str, default='', help="define a npz path")
    parser.add_argument('--svg_type', '-st', type=str, choices=['single', 'cluster'],
                        default='single', help="svg type")
    args = parser.parse_args()

    data_convert_to_absolute(args.file, args.svg_type)
init_cursor[cursor_idx] # (2) cursor_idx += 1 prev_width = init_width prev_scaling = 1.0 prev_window_size = float(raster_size) # (1) for round_inner_i in range(round_length): stroke_idx = np.sum(infer_lengths[:round_idx]).astype(np.int32) + round_inner_i curr_window_size_raw = prev_scaling * prev_window_size curr_window_size_raw = np.maximum(curr_window_size_raw, min_window_size) curr_window_size_raw = np.minimum(curr_window_size_raw, image_size) pen_state = data[stroke_idx, 0] stroke_params = data[stroke_idx, 1:] # (8) x1y1, x2y2, width2, scaling2 = stroke_params[0:2], stroke_params[2:4], stroke_params[4], stroke_params[5] x0y0 = np.zeros_like(x2y2) # (2), [-1.0, 1.0] x0y0 = np.divide(np.add(x0y0, 1.0), 2.0) # (2), [0.0, 1.0] x2y2 = np.divide(np.add(x2y2, 1.0), 2.0) # (2), [0.0, 1.0] widths = np.stack([prev_width, width2], axis=0) # (2) stroke_params_proc = np.concatenate([x0y0, x1y1, x2y2, widths], axis=-1) # (8) next_width = stroke_params[4] next_scaling = stroke_params[5] next_window_size = next_scaling * curr_window_size_raw next_window_size = np.maximum(next_window_size, min_window_size) next_window_size = np.minimum(next_window_size, image_size) prev_width = next_width * curr_window_size_raw / next_window_size prev_scaling = next_scaling prev_window_size = curr_window_size_raw f = stroke_params_proc.tolist() # (8) f += [1.0, 1.0] gt_stroke_img = draw(f) # (H, W), [0.0-stroke, 1.0-BG] gt_stroke_img_large = image_pasting_v3_testing(1.0 - gt_stroke_img, cursor_pos, image_size, curr_window_size_raw, pasting_func, sess) # [0.0-BG, 1.0-stroke] is_overlap = False if pen_state == 0: canvas += gt_stroke_img_large # [0.0-BG, 1.0-stroke] curr_drawn_stroke_region = np.zeros_like(gt_stroke_img_large) curr_drawn_stroke_region[gt_stroke_img_large > 0.5] = 1 intersection = drawn_region * curr_drawn_stroke_region # regard stroke with >50% overlap area as overlaped stroke if np.sum(intersection) / np.sum(curr_drawn_stroke_region) > 0.5: # enlarge the stroke a bit for better 
visualization overlap_region[gt_stroke_img_large > 0] += 1 is_overlap = True drawn_region[gt_stroke_img_large > 0.5] = 1 color_rgb = color_rgb_set[color_idx] # (3) in [0, 255] color_idx += 1 color_rgb = np.reshape(color_rgb, (1, 1, 3)).astype(np.float32) color_stroke = np.expand_dims(gt_stroke_img_large, axis=-1) * (1.0 - color_rgb / 255.0) canvas_color_with_moving = canvas_color_with_moving * np.expand_dims((1.0 - gt_stroke_img_large), axis=-1) + color_stroke # (H, W, 3) if pen_state == 0: valid_color_idx += 1 if pen_state == 0: valid_color_rgb = valid_color_rgb_set[valid_color_idx] # (3) in [0, 255] # valid_color_idx += 1 valid_color_rgb = np.reshape(valid_color_rgb, (1, 1, 3)).astype(np.float32) valid_color_stroke = np.expand_dims(gt_stroke_img_large, axis=-1) * (1.0 - valid_color_rgb / 255.0) canvas_color_with_overlap = canvas_color_with_overlap * np.expand_dims((1.0 - gt_stroke_img_large), axis=-1) + valid_color_stroke # (H, W, 3) if not is_overlap: canvas_color_wo_overlap = canvas_color_wo_overlap * np.expand_dims((1.0 - gt_stroke_img_large), axis=-1) + valid_color_stroke # (H, W, 3) # update cursor_pos based on hps.cursor_type new_cursor_offsets = stroke_params[2:4] * (float(curr_window_size_raw) / 2.0) # (1, 6), patch-level new_cursor_offset_next = new_cursor_offsets # important!!! 
def visualize_drawing(npz_path):
    """Render the stroke sequence stored in *npz_path* as visualization PNGs.

    Builds the differentiable pasting graph, opens a TF session and hands the
    loaded data to ``display_strokes_final``, which writes the images into a
    directory named after the npz file (without extension), next to it.

    :param npz_path: path to a ``.npz`` written at test time (keys:
        strokes_data, init_cursors, image_size, round_length, init_width).
    """
    assert npz_path != ''
    min_window_size = 32
    raster_size = 128

    # Portable path handling: os.path.split works with both '/' and the
    # OS-specific separator (the original rfind('/') broke on Windows paths),
    # and splitext strips whatever extension is present instead of assuming
    # it is exactly 4 characters long ('.npz').
    file_base, base_name = os.path.split(npz_path)
    if file_base == '':
        file_base = './'
    file_name = os.path.splitext(base_name)[0]

    regenerate_base = os.path.join(file_base, file_name)
    os.makedirs(regenerate_base, exist_ok=True)

    # differentiable pasting graph (must be built before the session is used)
    paste_v3_func = DiffPastingV3(raster_size)

    tfconfig = tf.ConfigProto()
    tfconfig.gpu_options.allow_growth = True
    sess = tf.InteractiveSession(config=tfconfig)
    sess.run(tf.global_variables_initializer())

    # latin1 encoding + allow_pickle: npz may have been written under Python 2
    data = np.load(npz_path, encoding='latin1', allow_pickle=True)
    strokes_data = data['strokes_data']
    init_cursors = data['init_cursors']
    image_size = data['image_size']
    round_length = data['round_length']
    init_width = data['init_width']

    # a scalar round_length means a single drawing round
    if round_length.ndim == 0:
        round_lengths = [round_length]
    else:
        round_lengths = round_length

    print('Processing ...')
    display_strokes_final(sess, paste_v3_func,
                          strokes_data, init_cursors, image_size, round_lengths, init_width,
                          regenerate_base,
                          min_window_size=min_window_size, raster_size=raster_size)
def should_save_log_img(step_):
    """Return True when training step *step_* is a log-image step (every 500 steps)."""
    # The comparison already yields the desired bool; no if/else needed.
    return step_ % 500 == 0
def save_log_images(sess, model, data_set, save_root, step_num, curr_photo_prob, interpolate_type, save_num=10):
    """Run the eval model on cached validation batches and save log images.

    Evaluates at *save_num* resolutions linearly spaced between
    ``hps.image_size_small`` and ``hps.image_size_large``; for each, writes
    under ``save_root/res_<resolution>/``: the prediction (``<step>.png``),
    the input photo (``input.png``), the target (``gt.png``) and an RGB
    prediction in a ``rgb/`` subfolder.

    :param sess: TF session holding the model variables.
    :param model: eval-mode VirtualSketchingModel.
    :param data_set: validation set providing ``get_batch_from_memory``.
    :param curr_photo_prob: current photo/sketch mixing probability.
    :param interpolate_type: 'prob' or 'image' (forwarded to the data set).
    :param save_num: number of logged resolutions.
    """
    # evenly spaced resolutions from small to large (inclusive)
    res_gap = (model.hps.image_size_large - model.hps.image_size_small) // (save_num - 1)
    log_img_resolutions = []
    for ii in range(save_num - 1):
        log_img_resolutions.append(model.hps.image_size_small + ii * res_gap)
    log_img_resolutions.append(model.hps.image_size_large)

    for res_i in range(len(log_img_resolutions)):
        resolution = log_img_resolutions[res_i]
        sub_save_root = os.path.join(save_root, 'res_' + str(resolution))
        os.makedirs(sub_save_root, exist_ok=True)

        input_photos, target_sketches, init_cursors, image_size_rand = \
            data_set.get_batch_from_memory(memory_idx=res_i,
                                           fixed_image_size=resolution,
                                           random_cursor=model.hps.random_cursor,
                                           photo_prob=curr_photo_prob,
                                           interpolate_type=interpolate_type)
        # input_photos: (N, image_size, image_size, 3), [0-stroke, 1-BG]
        # target_sketches: (N, image_size, image_size), [0-stroke, 1-BG]
        # init_cursors: (N, 1, 2), in size [0.0, 1.0)

        input_photo_val = input_photos

        # the same batch is replicated once per model loop (multi-GPU layout)
        init_cursor_input = [init_cursors for _ in range(model.total_loop)]
        init_cursor_input = np.concatenate(init_cursor_input, axis=0)

        image_size_input = [image_size_rand for _ in range(model.total_loop)]
        image_size_input = np.stack(image_size_input, axis=0)

        feed = {
            model.init_cursor: init_cursor_input,
            model.image_size: image_size_input,
            model.init_width: [model.hps.min_width],
        }
        for loop_i in range(model.total_loop):
            feed[model.input_photo_list[loop_i]] = input_photo_val

        raster_images_pred, raster_images_pred_rgb = sess.run([model.pred_raster_imgs, model.pred_raster_imgs_rgb],
                                                              feed)
        # (N, image_size, image_size), [0.0-stroke, 1.0-BG]

        # take the first batch element and convert [0, 1] floats to uint8
        raster_images_pred = (np.array(raster_images_pred[0]) * 255.0).astype(np.uint8)
        input_photo = (np.array(input_photo_val[0, :, :, :]) * 255.0).astype(np.uint8)
        target_sketch = (np.array(target_sketches[0]) * 255.0).astype(np.uint8)
        raster_images_pred_rgb = (np.array(raster_images_pred_rgb[0]) * 255.0).astype(np.uint8)

        pred_save_path = os.path.join(sub_save_root, str(step_num) + '.png')
        input_save_path = os.path.join(sub_save_root, 'input.png')
        target_save_path = os.path.join(sub_save_root, 'gt.png')

        pred_rgb_save_root = os.path.join(sub_save_root, 'rgb')
        os.makedirs(pred_rgb_save_root, exist_ok=True)
        pred_rgb_save_path = os.path.join(pred_rgb_save_root, str(step_num) + '.png')

        raster_images_pred = Image.fromarray(raster_images_pred, 'L')
        raster_images_pred.save(pred_save_path, 'PNG')
        input_photo = Image.fromarray(input_photo, 'RGB')
        input_photo.save(input_save_path, 'PNG')
        target_sketch = Image.fromarray(target_sketch, 'L')
        target_sketch.save(target_save_path, 'PNG')
        raster_images_pred_rgb = Image.fromarray(raster_images_pred_rgb, 'RGB')
        raster_images_pred_rgb.save(pred_rgb_save_path, 'PNG')
def train(sess, train_model, eval_sample_model, train_set, valid_set, sub_log_root, sub_snapshot_root,
          sub_log_img_root):
    """Main training loop for the rough-sketch / photograph model.

    Runs ``hps.num_steps`` optimization steps with scheduled learning rate,
    stroke-number loss weight, early-pen-state loss window and photo
    probability; logs scalars every 20 steps, saves log images periodically
    (``should_save_log_img``) and checkpoints every ``hps.save_every`` steps.
    """
    # Setup summary writer.
    summary_writer = tf.summary.FileWriter(sub_log_root)

    print('-' * 100)

    # Calculate trainable params.
    t_vars = tf.trainable_variables()
    count_t_vars = 0
    for var in t_vars:
        num_param = np.prod(var.get_shape().as_list())
        count_t_vars += num_param
        print('%s | shape: %s | num_param: %i' % (var.name, str(var.get_shape()), num_param))
    print('Total trainable variables %i.' % count_t_vars)
    print('-' * 100)

    # main train loop
    hps = train_model.hps
    start = time.time()

    # create saver; frozen renderer and VGG16 weights are excluded from snapshots
    snapshot_save_vars = [var for var in tf.global_variables()
                          if 'raster_unit' not in var.op.name and 'VGG16' not in var.op.name]
    saver = tf.train.Saver(var_list=snapshot_save_vars, max_to_keep=20)

    start_step = 1
    print('start_step', start_step)

    # running mean of each perceptual-layer loss, fed back for normalization
    mean_perc_relu_losses = [0.0 for _ in range(len(hps.perc_loss_layers))]

    for _ in range(start_step, hps.num_steps + 1):
        step = sess.run(train_model.global_step)  # start from 0
        count_step = min(step, hps.num_steps)
        # polynomial decay from learning_rate to min_learning_rate
        curr_learning_rate = ((hps.learning_rate - hps.min_learning_rate) *
                              (1 - count_step / hps.num_steps) ** hps.decay_power + hps.min_learning_rate)

        # schedule the stroke-number loss weight
        if hps.sn_loss_type == 'decreasing':
            assert hps.decrease_stop_steps <= hps.num_steps
            assert hps.stroke_num_loss_weight_end <= hps.stroke_num_loss_weight
            curr_sn_k = (hps.stroke_num_loss_weight - hps.stroke_num_loss_weight_end) / float(hps.decrease_stop_steps)
            curr_stroke_num_loss_weight = hps.stroke_num_loss_weight - count_step * curr_sn_k
            curr_stroke_num_loss_weight = max(curr_stroke_num_loss_weight, hps.stroke_num_loss_weight_end)
        elif hps.sn_loss_type == 'fixed':
            curr_stroke_num_loss_weight = hps.stroke_num_loss_weight
        elif hps.sn_loss_type == 'increasing':
            curr_sn_k = hps.stroke_num_loss_weight / float(hps.num_steps - hps.increase_start_steps)
            curr_stroke_num_loss_weight = max(count_step - hps.increase_start_steps, 0) * curr_sn_k
        else:
            raise Exception('Unknown sn_loss_type', hps.sn_loss_type)

        # schedule the [start, end) window of the early-pen-state loss
        if hps.early_pen_loss_type == 'head':
            curr_early_pen_k = (hps.max_seq_len - hps.early_pen_length) / float(hps.num_steps)
            curr_early_pen_loss_len = count_step * curr_early_pen_k + hps.early_pen_length
            curr_early_pen_loss_start = 1
            curr_early_pen_loss_end = curr_early_pen_loss_len
        elif hps.early_pen_loss_type == 'tail':
            curr_early_pen_k = (hps.max_seq_len // 2 - 1) / float(hps.num_steps)
            curr_early_pen_loss_len = count_step * curr_early_pen_k + hps.max_seq_len // 2
            curr_early_pen_loss_end = hps.max_seq_len
            curr_early_pen_loss_start = curr_early_pen_loss_end - curr_early_pen_loss_len
        elif hps.early_pen_loss_type == 'move':
            curr_early_pen_k = (hps.max_seq_len // 2 - 1) / float(hps.num_steps)
            curr_early_pen_loss_len = count_step * curr_early_pen_k + hps.max_seq_len // 2
            curr_early_pen_loss_start = hps.max_seq_len - curr_early_pen_loss_len
            curr_early_pen_loss_end = curr_early_pen_loss_start + hps.max_seq_len // 2
        else:
            raise Exception('Unknown early_pen_loss_type', hps.early_pen_loss_type)
        curr_early_pen_loss_start = int(round(curr_early_pen_loss_start))
        curr_early_pen_loss_end = int(round(curr_early_pen_loss_end))

        # schedule the probability of feeding photos (vs. sketches) as input
        if hps.photo_prob_type == 'increasing' or hps.photo_prob_type == 'interpolate':
            assert hps.photo_prob_end_step >= hps.photo_prob_start_step
            curr_photo_prob_k = 1.0 / float(hps.photo_prob_end_step - hps.photo_prob_start_step)
            curr_photo_prob = (count_step - hps.photo_prob_start_step) * curr_photo_prob_k
            curr_photo_prob = max(0.0, curr_photo_prob)
            curr_photo_prob = min(1.0, curr_photo_prob)
            interpolate_type = 'prob' if hps.photo_prob_type == 'increasing' else 'image'
        elif hps.photo_prob_type == 'zero':
            curr_photo_prob = 0.0
            interpolate_type = 'prob'
        elif hps.photo_prob_type == 'one':
            curr_photo_prob = 1.0
            interpolate_type = 'prob'
        else:
            raise Exception('Unknown photo_prob_type', hps.photo_prob_type)

        input_photos, target_sketches, init_cursors, image_sizes = \
            train_set.get_batch_multi_res(loop_num=train_model.total_loop, random_cursor=hps.random_cursor,
                                          photo_prob=curr_photo_prob,
                                          interpolate_type=interpolate_type)
        # input_photos: list of (N, image_size, image_size, 3), [0-stroke, 1-BG]
        # target_sketches: list of (N, image_size, image_size), [0-stroke, 1-BG]
        # init_cursors: list of (N, 1, 2), in size [0.0, 1.0)

        init_cursors_input = np.concatenate(init_cursors, axis=0)
        image_size_input = np.stack(image_sizes, axis=0)

        feed = {
            train_model.init_cursor: init_cursors_input,
            train_model.image_size: image_size_input,
            train_model.init_width: [hps.min_width],
            train_model.lr: curr_learning_rate,
            train_model.stroke_num_loss_weight: curr_stroke_num_loss_weight,
            train_model.early_pen_loss_start_idx: curr_early_pen_loss_start,
            train_model.early_pen_loss_end_idx: curr_early_pen_loss_end,
            train_model.last_step_num: float(step),
        }
        for layer_i in range(len(hps.perc_loss_layers)):
            feed[train_model.perc_loss_mean_list[layer_i]] = mean_perc_relu_losses[layer_i]

        for loop_i in range(train_model.total_loop):
            input_photo_val = input_photos[loop_i]
            target_sketch_val = target_sketches[loop_i]
            feed[train_model.input_photo_list[loop_i]] = input_photo_val
            feed[train_model.target_sketch_list[loop_i]] = np.expand_dims(target_sketch_val, axis=-1)

        (train_cost, raster_cost, perc_relu_costs_raw, perc_relu_costs_norm, stroke_num_cost,
         early_pen_states_cost, pos_outside_cost, win_size_outside_cost, train_step) = sess.run([
            train_model.cost, train_model.raster_cost, train_model.perc_relu_losses_raw,
            train_model.perc_relu_losses_norm, train_model.stroke_num_cost, train_model.early_pen_states_cost,
            train_model.pos_outside_cost, train_model.win_size_outside_cost, train_model.global_step
        ], feed)

        ## update mean_raster_loss (incremental running mean over all steps)
        for layer_i in range(len(hps.perc_loss_layers)):
            perc_relu_cost_raw = perc_relu_costs_raw[layer_i]
            mean_perc_relu_loss = mean_perc_relu_losses[layer_i]
            mean_perc_relu_loss = (mean_perc_relu_loss * step + perc_relu_cost_raw) / float(step + 1)
            mean_perc_relu_losses[layer_i] = mean_perc_relu_loss

        _ = sess.run(train_model.train_op, feed)

        if step % 20 == 0 and step > 0:
            end = time.time()
            time_taken = end - start

            train_summary_map = {
                'Train_Cost': train_cost,
                'Train_raster_Cost': raster_cost,
                'Train_stroke_num_Cost': stroke_num_cost,
                'Train_early_pen_states_cost': early_pen_states_cost,
                'Train_pos_outside_Cost': pos_outside_cost,
                'Train_win_size_outside_Cost': win_size_outside_cost,
                'Learning_Rate': curr_learning_rate,
                'Time_Taken_Train': time_taken
            }
            for layer_i in range(len(hps.perc_loss_layers)):
                layer_name = hps.perc_loss_layers[layer_i]
                train_summary_map['Train_raster_Cost_' + layer_name] = perc_relu_costs_raw[layer_i]
            create_summary(summary_writer, train_summary_map, train_step)

            output_format = ('step: %d, lr: %.6f, '
                             'snw: %.3f, '
                             'cost: %.4f, '
                             'ras: %.4f, stroke_num: %.4f, early_pen: %.4f, '
                             'pos_outside: %.4f, win_outside: %.4f, '
                             'train_time_taken: %.1f')
            output_values = (step, curr_learning_rate, curr_stroke_num_loss_weight, train_cost, raster_cost,
                             stroke_num_cost, early_pen_states_cost, pos_outside_cost, win_size_outside_cost,
                             time_taken)
            output_log = output_format % output_values
            # print(output_log)
            tf.logging.info(output_log)

            start = time.time()

        if should_save_log_img(step) and step > 0:
            save_log_images(sess, eval_sample_model, valid_set, sub_log_img_root, step, curr_photo_prob,
                            interpolate_type)

        if step % hps.save_every == 0 and step > 0:
            save_model(sess, saver, sub_snapshot_root, step)
def trainer(model_params):
    """Build models and data sets from *model_params*, then run training.

    Loads the data sets, constructs the train and (reused-weights) eval
    models, restores the frozen neural renderer (and VGG16 when the raster
    loss is perceptual), dumps the config to JSON and enters ``train``.
    """
    np.set_printoptions(precision=8, edgeitems=6, linewidth=200, suppress=True)

    print('Hyperparams:')
    for key, val in six.iteritems(model_params.values()):
        print('%s = %s' % (key, str(val)))
    print('Loading data files.')
    print('-' * 100)

    datasets = load_dataset_training(FLAGS.dataset_dir, model_params)

    sub_snapshot_root = os.path.join(FLAGS.snapshot_root, model_params.program_name)
    sub_log_root = os.path.join(FLAGS.log_root, model_params.program_name)
    sub_log_img_root = os.path.join(FLAGS.log_img_root, model_params.program_name)

    train_set = datasets[0]
    valid_set = datasets[1]
    train_model_params = datasets[2]
    eval_sample_model_params = datasets[3]

    # eval model runs one loop per GPU with a matching batch size
    eval_sample_model_params.loop_per_gpu = 1
    eval_sample_model_params.batch_size = len(eval_sample_model_params.gpus) * eval_sample_model_params.loop_per_gpu

    reset_graph()
    train_model = sketch_image_model.VirtualSketchingModel(train_model_params)
    # reuse=True: the eval model shares variables with the train model
    eval_sample_model = sketch_image_model.VirtualSketchingModel(eval_sample_model_params, reuse=True)

    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True
    sess = tf.InteractiveSession(config=tfconfig)
    sess.run(tf.global_variables_initializer())

    # restore the frozen differentiable renderer; VGG16 only for perceptual loss
    load_checkpoint(sess, FLAGS.neural_renderer_path, ras_only=True)
    if train_model_params.raster_loss_base_type == 'perceptual':
        load_checkpoint(sess, FLAGS.perceptual_model_root, perceptual_only=True)

    # Write config file to json file.
    os.makedirs(sub_log_root, exist_ok=True)
    os.makedirs(sub_log_img_root, exist_ok=True)
    os.makedirs(sub_snapshot_root, exist_ok=True)
    with tf.gfile.Open(os.path.join(sub_snapshot_root, 'model_config.json'), 'w') as f:
        json.dump(train_model_params.values(), f, indent=True)

    train(sess, train_model, eval_sample_model, train_set, valid_set,
          sub_log_root, sub_snapshot_root, sub_log_img_root)
def save_log_images(sess, model, data_set, save_root, step_num, save_num=10):
    """Run the eval model on cached validation batches and save log images.

    Vectorization variant: inputs are single-channel sketches. Evaluates at
    *save_num* resolutions between ``hps.image_size_small`` and
    ``hps.image_size_large``; per resolution writes the prediction
    (``<step>.png``), the target (``gt.png``) and an RGB prediction in a
    ``rgb/`` subfolder under ``save_root/res_<resolution>/``.
    """
    # evenly spaced resolutions from small to large (inclusive)
    res_gap = (model.hps.image_size_large - model.hps.image_size_small) // (save_num - 1)
    log_img_resolutions = []
    for ii in range(save_num - 1):
        log_img_resolutions.append(model.hps.image_size_small + ii * res_gap)
    log_img_resolutions.append(model.hps.image_size_large)

    for res_i in range(len(log_img_resolutions)):
        resolution = log_img_resolutions[res_i]
        sub_save_root = os.path.join(save_root, 'res_' + str(resolution))
        os.makedirs(sub_save_root, exist_ok=True)

        input_photos, target_sketches, init_cursors, image_size_rand = \
            data_set.get_batch_from_memory(memory_idx=res_i,
                                           vary_thickness=model.hps.vary_thickness,
                                           fixed_image_size=resolution,
                                           random_cursor=model.hps.random_cursor,
                                           init_cursor_on_undrawn_pixel=model.hps.init_cursor_on_undrawn_pixel)
        # input_photos: (N, image_size, image_size), [0-stroke, 1-BG]
        # target_sketches: (N, image_size, image_size), [0-stroke, 1-BG]
        # init_cursors: (N, 1, 2), in size [0.0, 1.0)

        # fall back to the target sketch as input when no distorted input exists
        if input_photos is not None:
            input_photo_val = np.expand_dims(input_photos, axis=-1)
        else:
            input_photo_val = np.expand_dims(target_sketches, axis=-1)

        # the same batch is replicated once per model loop (multi-GPU layout)
        init_cursor_input = [init_cursors for _ in range(model.total_loop)]
        init_cursor_input = np.concatenate(init_cursor_input, axis=0)

        image_size_input = [image_size_rand for _ in range(model.total_loop)]
        image_size_input = np.stack(image_size_input, axis=0)

        feed = {
            model.init_cursor: init_cursor_input,
            model.image_size: image_size_input,
            model.init_width: [model.hps.min_width],
        }
        for loop_i in range(model.total_loop):
            feed[model.input_photo_list[loop_i]] = input_photo_val

        raster_images_pred, raster_images_pred_rgb = sess.run([model.pred_raster_imgs, model.pred_raster_imgs_rgb],
                                                              feed)
        # (N, image_size, image_size), [0.0-stroke, 1.0-BG]

        # take the first batch element and convert [0, 1] floats to uint8
        raster_images_pred = (np.array(raster_images_pred[0]) * 255.0).astype(np.uint8)
        input_sketch = (np.array(target_sketches[0]) * 255.0).astype(np.uint8)
        raster_images_pred_rgb = (np.array(raster_images_pred_rgb[0]) * 255.0).astype(np.uint8)

        pred_save_path = os.path.join(sub_save_root, str(step_num) + '.png')
        target_save_path = os.path.join(sub_save_root, 'gt.png')

        pred_rgb_save_root = os.path.join(sub_save_root, 'rgb')
        os.makedirs(pred_rgb_save_root, exist_ok=True)
        pred_rgb_save_path = os.path.join(pred_rgb_save_root, str(step_num) + '.png')

        raster_images_pred = Image.fromarray(raster_images_pred, 'L')
        raster_images_pred.save(pred_save_path, 'PNG')
        input_sketch = Image.fromarray(input_sketch, 'L')
        input_sketch.save(target_save_path, 'PNG')
        raster_images_pred_rgb = Image.fromarray(raster_images_pred_rgb, 'RGB')
        raster_images_pred_rgb.save(pred_rgb_save_path, 'PNG')
def train(sess, train_model, eval_sample_model, train_set, val_set, sub_log_root, sub_snapshot_root,
          sub_log_img_root):
    """Main training loop for the line-drawing vectorization model.

    Runs ``hps.num_steps`` optimization steps with scheduled learning rate,
    stroke-number loss weight and early-pen-state loss window; logs scalars
    every 20 steps, saves log images periodically (``should_save_log_img``)
    and checkpoints every ``hps.save_every`` steps.
    """
    # Setup summary writer.
    summary_writer = tf.summary.FileWriter(sub_log_root)

    print('-' * 100)

    # Calculate trainable params.
    t_vars = tf.trainable_variables()
    count_t_vars = 0
    for var in t_vars:
        num_param = np.prod(var.get_shape().as_list())
        count_t_vars += num_param
        print('%s | shape: %s | num_param: %i' % (var.name, str(var.get_shape()), num_param))
    print('Total trainable variables %i.' % count_t_vars)
    print('-' * 100)

    # main train loop
    hps = train_model.hps
    start = time.time()

    # create saver; frozen renderer and VGG16 weights are excluded from snapshots
    snapshot_save_vars = [var for var in tf.global_variables()
                          if 'raster_unit' not in var.op.name and 'VGG16' not in var.op.name]
    saver = tf.train.Saver(var_list=snapshot_save_vars, max_to_keep=20)

    start_step = 1
    print('start_step', start_step)

    # running mean of each perceptual-layer loss, fed back for normalization
    mean_perc_relu_losses = [0.0 for _ in range(len(hps.perc_loss_layers))]

    for _ in range(start_step, hps.num_steps + 1):
        step = sess.run(train_model.global_step)  # start from 0
        count_step = min(step, hps.num_steps)
        # polynomial decay from learning_rate to min_learning_rate
        curr_learning_rate = ((hps.learning_rate - hps.min_learning_rate) *
                              (1 - count_step / hps.num_steps) ** hps.decay_power + hps.min_learning_rate)

        # schedule the stroke-number loss weight
        if hps.sn_loss_type == 'decreasing':
            assert hps.decrease_stop_steps <= hps.num_steps
            assert hps.stroke_num_loss_weight_end <= hps.stroke_num_loss_weight
            curr_sn_k = (hps.stroke_num_loss_weight - hps.stroke_num_loss_weight_end) / float(hps.decrease_stop_steps)
            curr_stroke_num_loss_weight = hps.stroke_num_loss_weight - count_step * curr_sn_k
            curr_stroke_num_loss_weight = max(curr_stroke_num_loss_weight, hps.stroke_num_loss_weight_end)
        elif hps.sn_loss_type == 'fixed':
            curr_stroke_num_loss_weight = hps.stroke_num_loss_weight
        elif hps.sn_loss_type == 'increasing':
            curr_sn_k = hps.stroke_num_loss_weight / float(hps.num_steps - hps.increase_start_steps)
            curr_stroke_num_loss_weight = max(count_step - hps.increase_start_steps, 0) * curr_sn_k
        else:
            raise Exception('Unknown sn_loss_type', hps.sn_loss_type)

        # schedule the [start, end) window of the early-pen-state loss
        if hps.early_pen_loss_type == 'head':
            curr_early_pen_k = (hps.max_seq_len - hps.early_pen_length) / float(hps.num_steps)
            curr_early_pen_loss_len = count_step * curr_early_pen_k + hps.early_pen_length
            curr_early_pen_loss_start = 1
            curr_early_pen_loss_end = curr_early_pen_loss_len
        elif hps.early_pen_loss_type == 'tail':
            curr_early_pen_k = (hps.max_seq_len // 2 - 1) / float(hps.num_steps)
            curr_early_pen_loss_len = count_step * curr_early_pen_k + hps.max_seq_len // 2
            curr_early_pen_loss_end = hps.max_seq_len
            curr_early_pen_loss_start = curr_early_pen_loss_end - curr_early_pen_loss_len
        elif hps.early_pen_loss_type == 'move':
            curr_early_pen_k = (hps.max_seq_len // 2 - 1) / float(hps.num_steps)
            curr_early_pen_loss_len = count_step * curr_early_pen_k + hps.max_seq_len // 2
            curr_early_pen_loss_start = hps.max_seq_len - curr_early_pen_loss_len
            curr_early_pen_loss_end = curr_early_pen_loss_start + hps.max_seq_len // 2
        else:
            raise Exception('Unknown early_pen_loss_type', hps.early_pen_loss_type)
        curr_early_pen_loss_start = int(round(curr_early_pen_loss_start))
        curr_early_pen_loss_end = int(round(curr_early_pen_loss_end))

        input_photos, target_sketches, init_cursors, image_sizes = \
            train_set.get_batch_multi_res(loop_num=train_model.total_loop, vary_thickness=hps.vary_thickness,
                                          random_cursor=hps.random_cursor,
                                          init_cursor_on_undrawn_pixel=hps.init_cursor_on_undrawn_pixel)
        # input_photos: list of (N, image_size, image_size), [0-stroke, 1-BG]
        # target_sketches: list of (N, image_size, image_size), [0-stroke, 1-BG]
        # init_cursors: list of (N, 1, 2), in size [0.0, 1.0)

        init_cursors_input = np.concatenate(init_cursors, axis=0)
        image_size_input = np.stack(image_sizes, axis=0)

        feed = {
            train_model.init_cursor: init_cursors_input,
            train_model.image_size: image_size_input,
            train_model.init_width: [hps.min_width],
            train_model.lr: curr_learning_rate,
            train_model.stroke_num_loss_weight: curr_stroke_num_loss_weight,
            train_model.early_pen_loss_start_idx: curr_early_pen_loss_start,
            train_model.early_pen_loss_end_idx: curr_early_pen_loss_end,
            train_model.last_step_num: float(step),
        }
        for layer_i in range(len(hps.perc_loss_layers)):
            feed[train_model.perc_loss_mean_list[layer_i]] = mean_perc_relu_losses[layer_i]

        for loop_i in range(train_model.total_loop):
            # fall back to the target sketch as input when no distorted input exists
            if input_photos is not None:
                input_photo_val = np.expand_dims(input_photos[loop_i], axis=-1)
            else:
                input_photo_val = np.expand_dims(target_sketches[loop_i], axis=-1)
            feed[train_model.input_photo_list[loop_i]] = input_photo_val

        (train_cost, raster_cost, perc_relu_costs_raw, perc_relu_costs_norm, stroke_num_cost,
         early_pen_states_cost, pos_outside_cost, win_size_outside_cost, train_step) = sess.run([
            train_model.cost, train_model.raster_cost, train_model.perc_relu_losses_raw,
            train_model.perc_relu_losses_norm, train_model.stroke_num_cost, train_model.early_pen_states_cost,
            train_model.pos_outside_cost, train_model.win_size_outside_cost, train_model.global_step
        ], feed)

        ## update mean_raster_loss (incremental running mean over all steps)
        for layer_i in range(len(hps.perc_loss_layers)):
            perc_relu_cost_raw = perc_relu_costs_raw[layer_i]
            mean_perc_relu_loss = mean_perc_relu_losses[layer_i]
            mean_perc_relu_loss = (mean_perc_relu_loss * step + perc_relu_cost_raw) / float(step + 1)
            mean_perc_relu_losses[layer_i] = mean_perc_relu_loss

        _ = sess.run(train_model.train_op, feed)

        if step % 20 == 0 and step > 0:
            end = time.time()
            time_taken = end - start

            train_summary_map = {
                'Train_Cost': train_cost,
                'Train_raster_Cost': raster_cost,
                'Train_stroke_num_Cost': stroke_num_cost,
                'Train_early_pen_states_cost': early_pen_states_cost,
                'Train_pos_outside_Cost': pos_outside_cost,
                'Train_win_size_outside_Cost': win_size_outside_cost,
                'Learning_Rate': curr_learning_rate,
                'Time_Taken_Train': time_taken
            }
            for layer_i in range(len(hps.perc_loss_layers)):
                layer_name = hps.perc_loss_layers[layer_i]
                train_summary_map['Train_raster_Cost_' + layer_name] = perc_relu_costs_raw[layer_i]
            create_summary(summary_writer, train_summary_map, train_step)

            output_format = ('step: %d, lr: %.6f, '
                             'snw: %.3f, '
                             'cost: %.4f, '
                             'ras: %.4f, stroke_num: %.4f, early_pen: %.4f, '
                             'pos_outside: %.4f, win_outside: %.4f, '
                             'train_time_taken: %.1f')
            output_values = (step, curr_learning_rate, curr_stroke_num_loss_weight, train_cost, raster_cost,
                             stroke_num_cost, early_pen_states_cost, pos_outside_cost, win_size_outside_cost,
                             time_taken)
            output_log = output_format % output_values
            # print(output_log)
            tf.logging.info(output_log)

            start = time.time()

        if should_save_log_img(step) and step > 0:
            save_log_images(sess, eval_sample_model, val_set, sub_log_img_root, step)

        if step % hps.save_every == 0 and step > 0:
            save_model(sess, saver, sub_snapshot_root, step)
def trainer(model_params):
    """Build models and data sets from *model_params*, then run training.

    Vectorization variant: loads the data sets, constructs the train and
    (reused-weights) eval models, restores the frozen neural renderer (and
    VGG16 when the raster loss is perceptual), dumps the config to JSON and
    enters ``train``.
    """
    np.set_printoptions(precision=8, edgeitems=6, linewidth=200, suppress=True)

    print('Hyperparams:')
    for key, val in six.iteritems(model_params.values()):
        print('%s = %s' % (key, str(val)))
    print('Loading data files.')
    print('-' * 100)

    datasets = load_dataset_training(FLAGS.dataset_dir, model_params)

    sub_snapshot_root = os.path.join(FLAGS.snapshot_root, model_params.program_name)
    sub_log_root = os.path.join(FLAGS.log_root, model_params.program_name)
    sub_log_img_root = os.path.join(FLAGS.log_img_root, model_params.program_name)

    train_set = datasets[0]
    val_set = datasets[1]
    train_model_params = datasets[2]
    eval_sample_model_params = datasets[3]

    # eval model runs one loop per GPU with a matching batch size
    eval_sample_model_params.loop_per_gpu = 1
    eval_sample_model_params.batch_size = len(eval_sample_model_params.gpus) * eval_sample_model_params.loop_per_gpu

    reset_graph()
    train_model = sketch_vector_model.VirtualSketchingModel(train_model_params)
    # reuse=True: the eval model shares variables with the train model
    eval_sample_model = sketch_vector_model.VirtualSketchingModel(eval_sample_model_params, reuse=True)

    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True
    sess = tf.InteractiveSession(config=tfconfig)
    sess.run(tf.global_variables_initializer())

    # restore the frozen differentiable renderer; VGG16 only for perceptual loss
    load_checkpoint(sess, FLAGS.neural_renderer_path, ras_only=True)
    if train_model_params.raster_loss_base_type == 'perceptual':
        load_checkpoint(sess, FLAGS.perceptual_model_root, perceptual_only=True)

    # Write config file to json file.
    os.makedirs(sub_log_root, exist_ok=True)
    os.makedirs(sub_log_img_root, exist_ok=True)
    os.makedirs(sub_snapshot_root, exist_ok=True)
    with tf.gfile.Open(os.path.join(sub_snapshot_root, 'model_config.json'), 'w') as f:
        json.dump(train_model_params.values(), f, indent=True)

    train(sess, train_model, eval_sample_model, train_set, val_set,
          sub_log_root, sub_snapshot_root, sub_log_img_root)
def load_checkpoint(sess, checkpoint_path, ras_only=False, perceptual_only=False, gen_model_pretrain=False,
                    train_entire=False):
    """Restore a subset of the graph's variables from a checkpoint.

    The flags select which variables are restored, by substring of the
    variable op name:
      - ras_only: only the frozen neural-renderer vars ('raster_unit');
        *checkpoint_path* is then a direct checkpoint file, not a directory.
      - perceptual_only: only the VGG16 vars.
      - train_entire / gen_model_pretrain: the generator vars, excluding
        discriminator, renderer, VGG16 and Adam moment ('beta1'/'beta2')
        slots (and, for train_entire, 'global_step' and 'Entire' scopes).
      - default: every global variable.

    :param sess: TF session to restore into.
    :param checkpoint_path: checkpoint directory (or file when ras_only).
    :return: the step suffix parsed from the checkpoint filename (string).
    """
    if ras_only:
        load_var = {var.op.name: var for var in tf.global_variables() if 'raster_unit' in var.op.name}
    elif perceptual_only:
        load_var = {var.op.name: var for var in tf.global_variables() if 'VGG16' in var.op.name}
    elif train_entire:
        load_var = {var.op.name: var for var in tf.global_variables()
                    if 'discriminator' not in var.op.name
                    and 'raster_unit' not in var.op.name
                    and 'VGG16' not in var.op.name
                    and 'beta1' not in var.op.name
                    and 'beta2' not in var.op.name
                    and 'global_step' not in var.op.name
                    and 'Entire' not in var.op.name
                    }
    else:
        if gen_model_pretrain:
            load_var = {var.op.name: var for var in tf.global_variables()
                        if 'discriminator' not in var.op.name
                        and 'raster_unit' not in var.op.name
                        and 'VGG16' not in var.op.name
                        and 'beta1' not in var.op.name
                        and 'beta2' not in var.op.name
                        # and 'global_step' not in var.op.name
                        }
        else:
            load_var = tf.global_variables()

    restorer = tf.train.Saver(load_var)
    if not ras_only:
        # checkpoint_path is a directory: look up its latest checkpoint file
        ckpt = tf.train.get_checkpoint_state(checkpoint_path)
        model_checkpoint_path = ckpt.model_checkpoint_path
    else:
        model_checkpoint_path = checkpoint_path
    print('Loading model %s' % model_checkpoint_path)
    restorer.restore(sess, model_checkpoint_path)
    # checkpoint files are named '<prefix>-<step>'; return the step part
    snapshot_step = model_checkpoint_path[model_checkpoint_path.rfind('-') + 1:]
    return snapshot_step
/ 100 for i in range(100): t = i * tmp x = (int)((1-t) * (1-t) * x0 + 2 * t * (1-t) * x1 + t * t * x2) y = (int)((1-t) * (1-t) * y0 + 2 * t * (1-t) * y1 + t * t * y2) z = (int)((1-t) * z0 + t * z2) w = (1-t) * w0 + t * w2 cv2.circle(canvas, (y, x), z, w, -1) return 1 - cv2.resize(canvas, dsize=(width, width)) def rgb_trans(split_num, break_values): slice_per_split = split_num // 8 break_values_head, break_values_tail = break_values[:-1], break_values[1:] results = [] for split_i in range(8): break_value_head = break_values_head[split_i] break_value_tail = break_values_tail[split_i] slice_gap = float(break_value_tail - break_value_head) / float(slice_per_split) for slice_i in range(slice_per_split): slice_val = break_value_head + slice_gap * slice_i slice_val = int(round(slice_val)) results.append(slice_val) return results def get_colors(color_num): split_num = (color_num // 8 + 1) * 8 r_break_values = [0, 0, 0, 0, 128, 255, 255, 255, 128] g_break_values = [0, 0, 128, 255, 255, 255, 128, 0, 0] b_break_values = [128, 255, 255, 255, 128, 0, 0, 0, 0] r_rst_list = rgb_trans(split_num, r_break_values) g_rst_list = rgb_trans(split_num, g_break_values) b_rst_list = rgb_trans(split_num, b_break_values) assert len(r_rst_list) == len(g_rst_list) assert len(b_rst_list) == len(g_rst_list) rgb_color_list = [(r_rst_list[i], g_rst_list[i], b_rst_list[i]) for i in range(len(r_rst_list))] return rgb_color_list ############################################# # Utils for testing ############################################# def save_seq_data(save_root, save_filename, strokes_data, init_cursors, image_size, round_length, init_width): seq_save_root = os.path.join(save_root, 'seq_data') os.makedirs(seq_save_root, exist_ok=True) save_npz_path = os.path.join(seq_save_root, save_filename + '.npz') np.savez(save_npz_path, strokes_data=strokes_data, init_cursors=init_cursors, image_size=image_size, round_length=round_length, init_width=init_width) def image_pasting_v3_testing(patch_image, 
# NOTE(review): the first lines complete image_pasting_v3_testing(), whose
# 'def' line falls in the previous chunk; reformatted only, tokens unchanged.
                             cursor, image_size, window_size_f, pasting_func, sess):
    """
    :param patch_image: (raster_size, raster_size), [0.0-BG, 1.0-stroke]
    :param cursor: (2), in size [0.0, 1.0)
    :param window_size_f: (), float32, [0.0, image_size)
    :return: (image_size, image_size), [0.0-BG, 1.0-stroke]
    """
    # Convert the normalized cursor to absolute pixel coordinates.
    cursor_pos = cursor * float(image_size)
    pasted_image = sess.run(pasting_func.pasted_image,
                            feed_dict={pasting_func.patch_canvas: np.expand_dims(patch_image, axis=-1),
                                       pasting_func.cursor_pos_a: cursor_pos,
                                       pasting_func.image_size_a: image_size,
                                       pasting_func.window_size_a: window_size_f})
    # (image_size, image_size, 1), [0.0-BG, 1.0-stroke]
    pasted_image = pasted_image[:, :, 0]
    return pasted_image


def draw_strokes(data, save_root, save_filename, input_img, image_size, init_cursor, infer_lengths, init_width,
                 cursor_type, raster_size, min_window_size, sess, pasting_func=None, save_seq=False, draw_order=False):
    """Render a predicted stroke sequence onto a full-resolution canvas and save PNGs.

    :param data: (N_strokes, 9): flag, x1, y1, x2, y2, r2, s2
        NOTE(review): only 7 fields are named for 9 columns; the code reads
        data[:, 0] as the pen state and data[:, 1:] as 8 stroke parameters —
        confirm the intended field list.
    :return:
    """
    canvas = np.zeros((image_size, image_size), dtype=np.float32)  # [0.0-BG, 1.0-stroke]
    canvas_color = np.zeros((image_size, image_size, 3), dtype=np.float32)
    canvas_color_with_moving = np.zeros((image_size, image_size, 3), dtype=np.float32)
    frames = []

    cursor_idx = 0

    stroke_count = len(data)
    color_rgb_set = get_colors(stroke_count)  # list of (3,) in [0, 255]
    color_idx = 0

    for round_idx in range(len(infer_lengths)):
        round_length = infer_lengths[round_idx]

        # Each round restarts from its own initial cursor, width and scaling.
        cursor_pos = init_cursor[cursor_idx]  # (2)
        cursor_idx += 1

        prev_width = init_width
        prev_scaling = 1.0
        prev_window_size = raster_size  # (1)

        for round_inner_i in range(round_length):
            stroke_idx = np.sum(infer_lengths[:round_idx]).astype(np.int32) + round_inner_i

            # Current window size: previous window scaled, clamped to valid range.
            curr_window_size = prev_scaling * prev_window_size
            curr_window_size = np.maximum(curr_window_size, min_window_size)
            curr_window_size = np.minimum(curr_window_size, image_size)

            pen_state = data[stroke_idx, 0]
            stroke_params = data[stroke_idx, 1:]  # (8)
            x1y1, x2y2, width2, scaling2 = stroke_params[0:2], stroke_params[2:4], stroke_params[4], stroke_params[5]
            x0y0 = np.zeros_like(x2y2)  # (2), [-1.0, 1.0]
            x0y0 = np.divide(np.add(x0y0, 1.0), 2.0)  # (2), [0.0, 1.0]
            x2y2 = np.divide(np.add(x2y2, 1.0), 2.0)  # (2), [0.0, 1.0]
            widths = np.stack([prev_width, width2], axis=0)  # (2)
            stroke_params_proc = np.concatenate([x0y0, x1y1, x2y2, widths], axis=-1)  # (8)

            next_width = stroke_params[4]
            next_scaling = stroke_params[5]
            next_window_size = next_scaling * curr_window_size
            next_window_size = np.maximum(next_window_size, min_window_size)
            next_window_size = np.minimum(next_window_size, image_size)

            # Carry the line width over to the next step, rescaled to the new window.
            prev_width = next_width * curr_window_size / next_window_size
            prev_scaling = next_scaling
            prev_window_size = curr_window_size

            f = stroke_params_proc.tolist()  # (8)
            f += [1.0, 1.0]
            gt_stroke_img = draw(f)  # (raster_size, raster_size), [0.0-stroke, 1.0-BG]
            gt_stroke_img_large = image_pasting_v3_testing(1.0 - gt_stroke_img, cursor_pos, image_size,
                                                           curr_window_size, pasting_func,
                                                           sess)  # [0.0-BG, 1.0-stroke]

            # pen_state == 0 means "pen down": only then is the stroke drawn.
            if pen_state == 0:
                canvas += gt_stroke_img_large  # [0.0-BG, 1.0-stroke]

            if draw_order:
                color_rgb = color_rgb_set[color_idx]  # (3) in [0, 255]
                color_idx += 1
                color_rgb = np.reshape(color_rgb, (1, 1, 3)).astype(np.float32)
                color_stroke = np.expand_dims(gt_stroke_img_large, axis=-1) * (1.0 - color_rgb / 255.0)
                canvas_color_with_moving = canvas_color_with_moving * np.expand_dims((1.0 - gt_stroke_img_large),
                                                                                     axis=-1) + color_stroke  # (H, W, 3)
                if pen_state == 0:
                    canvas_color = canvas_color * np.expand_dims((1.0 - gt_stroke_img_large),
                                                                 axis=-1) + color_stroke  # (H, W, 3)

            # update cursor_pos based on hps.cursor_type
            new_cursor_offsets = stroke_params[2:4] * (curr_window_size / 2.0)  # (1, 6), patch-level
            new_cursor_offset_next = new_cursor_offsets

            # important!!!  (swap (x, y) -> (y, x) before adding to the cursor)
            new_cursor_offset_next = np.concatenate([new_cursor_offset_next[1:2], new_cursor_offset_next[0:1]], axis=-1)

            cursor_pos_large = cursor_pos * float(image_size)
            stroke_position_next = cursor_pos_large + new_cursor_offset_next  # (2), large-level

            if cursor_type == 'next':
                cursor_pos_large = stroke_position_next  # (2), large-level
            else:
                raise Exception('Unknown cursor_type')

            cursor_pos_large = np.minimum(np.maximum(cursor_pos_large, 0.0), float(image_size - 1))  # (2), large-level
            cursor_pos = cursor_pos_large / float(image_size)

            frames.append(canvas.copy())

    canvas = np.clip(canvas, 0.0, 1.0)
    canvas = np.round((1.0 - canvas) * 255.0).astype(np.uint8)  # [0-stroke, 255-BG]

    os.makedirs(save_root, exist_ok=True)
    save_path = os.path.join(save_root, save_filename)
    canvas_img = Image.fromarray(canvas, 'L')
    canvas_img.save(save_path, 'PNG')

    if save_seq:
        # Save every intermediate canvas as a numbered frame.
        seq_save_root = os.path.join(save_root, 'seq', save_filename[:-4])
        os.makedirs(seq_save_root, exist_ok=True)
        for len_i in range(len(frames)):
            frame = frames[len_i]
            frame = np.round((1.0 - frame) * 255.0).astype(np.uint8)
            save_path = os.path.join(seq_save_root, str(len_i) + '.png')
            frame_img = Image.fromarray(frame, 'L')
            frame_img.save(save_path, 'PNG')

    if draw_order:
        order_save_root = os.path.join(save_root, 'order')
        order_comp_save_root = os.path.join(save_root, 'order-compare')
        os.makedirs(order_save_root, exist_ok=True)
        os.makedirs(order_comp_save_root, exist_ok=True)

        canvas_color = 255 - np.round(canvas_color * 255.0).astype(np.uint8)
        canvas_color_img = Image.fromarray(canvas_color, 'RGB')
        save_path = os.path.join(order_save_root, save_filename)
        canvas_color_img.save(save_path, 'PNG')

        canvas_color_with_moving = 255 - np.round(canvas_color_with_moving * 255.0).astype(np.uint8)

        # comparisons
        rows = 2
        cols = 3

        plt.figure(figsize=(5 * cols, 5 * rows))

        plt.subplot(rows, cols, 1)
        plt.title('Input', fontsize=12)
        # plt.axis('off')
        input_rgb = input_img
        plt.imshow(input_rgb)

        # plt.subplot(rows, cols, 2)
        # plt.title('GT', fontsize=12)
        # # plt.axis('off')
        # gt_rgb = np.stack([gt_img for _ in range(3)], axis=2)
        # plt.imshow(gt_rgb)

        plt.subplot(rows, cols, 2)
        plt.title('Sketch', fontsize=12)
        # plt.axis('off')
        canvas_rgb = np.stack([canvas for _ in range(3)], axis=2)
        plt.imshow(canvas_rgb)

        plt.subplot(rows, cols, 4)
        plt.title('Sketch Order', fontsize=12)
        # plt.axis('off')
        plt.imshow(canvas_color)

        plt.subplot(rows, cols, 5)
        plt.title('Sketch Order with moving', fontsize=12)
        # plt.axis('off')
        plt.imshow(canvas_color_with_moving)

        plt.subplot(rows, cols, 6)
        plt.title('Order', fontsize=12)
        plt.axis('off')
        # Color-bar legend: one horizontal band per stroke color, in drawing order.
        img_h = 5
        img_w = 10
        color_array = np.zeros([len(color_rgb_set) * img_h, img_w, 3], dtype=np.uint8)
        for i in range(len(color_rgb_set)):
            color_array[i * img_h: i * img_h + img_h, :, :] = color_rgb_set[i]
        plt.imshow(color_array)

        comp_save_path = os.path.join(order_comp_save_root, save_filename)
        plt.savefig(comp_save_path)
        plt.close()
        # plt.show()


def update_hyperparams(model_params, model_base_dir, model_name, infer_dataset):
    """Overwrite model_params with the values stored in a trained model's
    model_config.json, dropping training-only keys.

    :param model_params: default HParams updated via parse_json.
    :param model_base_dir: directory holding one sub-directory per trained model.
    :param model_name: name of the trained model sub-directory.
    :param infer_dataset: dataset name written into data['data_set'].
    :return: the updated model_params.
    """
    with tf.gfile.Open(os.path.join(model_base_dir, model_name, 'model_config.json'), 'r') as f:
        data = json.load(f)

    # Keys that may legitimately be absent from older saved configs.
    ignored_keys = ['image_size_small', 'image_size_large', 'z_size', 'raster_perc_loss_layer',
                    'raster_loss_wk', 'decreasing_sn', 'raster_loss_weight']
    for name in model_params._hparam_types.keys():
        if name not in data and name not in ignored_keys:
            raise Exception(name, 'not in model_config.json')

    assert data['resize_method'] == 'AREA'
    data['data_set'] = infer_dataset
    # JSON stores these booleans as 0/1; convert back to bool.
    fix_list = ['use_input_dropout', 'use_output_dropout', 'use_recurrent_dropout']
    for fix in fix_list:
        data[fix] = (data[fix] == 1)

    # Training-only keys are removed before parsing into the inference HParams.
    pop_keys = ['gpus', 'image_size', 'resolution_type', 'loop_per_gpu', 'stroke_num_loss_weight_end',
                'perc_loss_fuse_type', 'early_pen_length', 'early_pen_loss_type', 'early_pen_loss_weight',
                'increase_start_steps', 'perc_loss_layers', 'sn_loss_type', 'photo_prob_end_step',
                'sup_weight', 'gan_weight', 'base_raster_loss_base_type']
    for pop_key in pop_keys:
        if pop_key in data.keys():
            data.pop(pop_key)

    model_params.parse_json(json.dumps(data))
    return model_params


================================================ FILE: vgg_utils/VGG16.py ================================================

import tensorflow as tf


# NOTE(review): vgg_net() continues in the next chunk line.
def vgg_net(x, n_classes, img_size, reuse, is_train=True, dropout_rate=0.5):
    """VGG16 classifier over single-channel img_size x img_size inputs.

    :param x: input batch, reshaped to (-1, img_size, img_size, 1).
    :param n_classes: size of the final logits layer.
    :param reuse: variable-scope reuse flag for the 'VGG16' scope.
    :param is_train: apply dropout after the dense layers when True.
    :return: (batch, n_classes) logits.
    """
    # Define a scope for reusing the variables
    with tf.variable_scope('VGG16', reuse=reuse):
        x = tf.reshape(x, [-1, img_size, img_size, 1])

        x = tf.layers.conv2d(inputs=x, filters=64, kernel_size=[3, 3], strides=1, padding='SAME',
                             activation=tf.nn.relu)
        x = tf.layers.conv2d(inputs=x, filters=64, kernel_size=[3, 3], strides=1, padding='SAME',
                             activation=tf.nn.relu)
        x = tf.layers.max_pooling2d(inputs=x, pool_size=[2, 2], strides=2)
        print('#1', x.shape)

        x = tf.layers.conv2d(inputs=x, filters=128, kernel_size=[3, 3], strides=1, padding='SAME',
                             activation=tf.nn.relu)
        x = tf.layers.conv2d(inputs=x, filters=128, kernel_size=[3, 3], strides=1, padding='SAME',
                             activation=tf.nn.relu)
        x = tf.layers.max_pooling2d(inputs=x, pool_size=[2, 2], strides=2)
        print('#2', x.shape)

        x = tf.layers.conv2d(inputs=x, filters=256, kernel_size=[3, 3], strides=1, padding='SAME',
                             activation=tf.nn.relu)
        x = tf.layers.conv2d(inputs=x, filters=256, kernel_size=[3, 3], strides=1, padding='SAME',
                             activation=tf.nn.relu)
        x = tf.layers.conv2d(inputs=x, filters=256, kernel_size=[3, 3], strides=1, padding='SAME',
                             activation=tf.nn.relu)
        x = tf.layers.max_pooling2d(inputs=x, pool_size=[2, 2], strides=2)
        print('#3', x.shape)

        x = tf.layers.conv2d(inputs=x, filters=512, kernel_size=[3, 3], strides=1, padding='SAME',
                             activation=tf.nn.relu)
        x = tf.layers.conv2d(inputs=x, filters=512, kernel_size=[3, 3], strides=1, padding='SAME',
                             activation=tf.nn.relu)
        x = tf.layers.conv2d(inputs=x, filters=512, kernel_size=[3, 3], strides=1, padding='SAME',
                             activation=tf.nn.relu)
        x = tf.layers.max_pooling2d(inputs=x, pool_size=[2, 2], strides=2)
        print('#4', x.shape)

        x = tf.layers.conv2d(inputs=x, filters=512, kernel_size=[3, 3],
# NOTE(review): the lines below complete vgg_net(), whose definition starts in
# the previous chunk; reformatted only, tokens unchanged.
                             strides=1, padding='SAME', activation=tf.nn.relu)
        x = tf.layers.conv2d(inputs=x, filters=512, kernel_size=[3, 3], strides=1, padding='SAME',
                             activation=tf.nn.relu)
        x = tf.layers.conv2d(inputs=x, filters=512, kernel_size=[3, 3], strides=1, padding='SAME',
                             activation=tf.nn.relu)
        x = tf.layers.max_pooling2d(inputs=x, pool_size=[2, 2], strides=2)
        print('#5', x.shape)

        # Flatten the final feature map for the fully-connected head.
        x_shape = x.get_shape().as_list()
        nodes = x_shape[1] * x_shape[2] * x_shape[3]
        x = tf.reshape(x, [-1, nodes])

        x = tf.layers.dense(x, 4096, activation=tf.nn.relu)
        if is_train:
            x = tf.layers.dropout(x, dropout_rate)
        x = tf.layers.dense(x, 4096, activation=tf.nn.relu)
        if is_train:
            x = tf.layers.dropout(x, dropout_rate)
        out = tf.layers.dense(x, n_classes)
        print(out)
    return out


def vgg_net_slim(x, img_size):
    """VGG16 convolutional trunk that exposes every ReLU activation.

    :param x: input batch, reshaped to (-1, img_size, img_size, 1).
    :param img_size: spatial side length of the single-channel input.
    :return: dict mapping 'ReLU<block>_<layer>' names to activation tensors.
    """
    return_map = {}
    # Define a scope for reusing the variables
    with tf.variable_scope('VGG16', reuse=tf.AUTO_REUSE):
        x = tf.reshape(x, [-1, img_size, img_size, 1])

        x = tf.layers.conv2d(inputs=x, filters=64, kernel_size=[3, 3], strides=1, padding='SAME',
                             activation=tf.nn.relu)
        return_map['ReLU1_1'] = x
        x = tf.layers.conv2d(inputs=x, filters=64, kernel_size=[3, 3], strides=1, padding='SAME',
                             activation=tf.nn.relu)
        return_map['ReLU1_2'] = x
        x = tf.layers.max_pooling2d(inputs=x, pool_size=[2, 2], strides=2)
        print('#1', x.shape)  #1 (?, 64, 64, 64)

        x = tf.layers.conv2d(inputs=x, filters=128, kernel_size=[3, 3], strides=1, padding='SAME',
                             activation=tf.nn.relu)
        return_map['ReLU2_1'] = x
        x = tf.layers.conv2d(inputs=x, filters=128, kernel_size=[3, 3], strides=1, padding='SAME',
                             activation=tf.nn.relu)
        return_map['ReLU2_2'] = x
        x = tf.layers.max_pooling2d(inputs=x, pool_size=[2, 2], strides=2)
        print('#2', x.shape)  #2 (?, 32, 32, 128)

        x = tf.layers.conv2d(inputs=x, filters=256, kernel_size=[3, 3], strides=1, padding='SAME',
                             activation=tf.nn.relu)
        return_map['ReLU3_1'] = x
        x = tf.layers.conv2d(inputs=x, filters=256, kernel_size=[3, 3], strides=1, padding='SAME',
                             activation=tf.nn.relu)
        return_map['ReLU3_2'] = x
        x = tf.layers.conv2d(inputs=x, filters=256, kernel_size=[3, 3], strides=1, padding='SAME',
                             activation=tf.nn.relu)
        return_map['ReLU3_3'] = x
        x = tf.layers.max_pooling2d(inputs=x, pool_size=[2, 2], strides=2)
        print('#3', x.shape)  #3 (?, 16, 16, 256)

        x = tf.layers.conv2d(inputs=x, filters=512, kernel_size=[3, 3], strides=1, padding='SAME',
                             activation=tf.nn.relu)
        return_map['ReLU4_1'] = x
        x = tf.layers.conv2d(inputs=x, filters=512, kernel_size=[3, 3], strides=1, padding='SAME',
                             activation=tf.nn.relu)
        return_map['ReLU4_2'] = x
        x = tf.layers.conv2d(inputs=x, filters=512, kernel_size=[3, 3], strides=1, padding='SAME',
                             activation=tf.nn.relu)
        return_map['ReLU4_3'] = x
        x = tf.layers.max_pooling2d(inputs=x, pool_size=[2, 2], strides=2)
        print('#4', x.shape)  #4 (?, 8, 8, 512)

        x = tf.layers.conv2d(inputs=x, filters=512, kernel_size=[3, 3], strides=1, padding='SAME',
                             activation=tf.nn.relu)
        return_map['ReLU5_1'] = x
        x = tf.layers.conv2d(inputs=x, filters=512, kernel_size=[3, 3], strides=1, padding='SAME',
                             activation=tf.nn.relu)
        return_map['ReLU5_2'] = x
        x = tf.layers.conv2d(inputs=x, filters=512, kernel_size=[3, 3], strides=1, padding='SAME',
                             activation=tf.nn.relu)
        return_map['ReLU5_3'] = x
        x = tf.layers.max_pooling2d(inputs=x, pool_size=[2, 2], strides=2)
        print('#5', x.shape)  #5 (?, 4, 4, 512)

    return return_map


================================================ FILE: virtual_sketch_gui.py ================================================

import tkinter as tk
from tkinter import filedialog, messagebox
import subprocess
import os
import threading
import glob
import sys
import shutil

# ==== Paths to the model scripts ====
# Maps the human-readable model name (shown in the UI) to its test script.
MODEL_OPTIONS = {
    "Rough → Clean Sketch": "test_rough_sketch_simplification.py",
    "Photo → Line Drawing": "test_photograph_to_line.py",
    "Clean Sketch → Vector": "test_vectorization.py",
}
SVG_CONVERTER = os.path.join("tools", "svg_conversion.py")


class VirtualSketchApp:
    """Small Tkinter front-end: runs one of the test scripts on a chosen image,
    moves the outputs into a 'sketches' folder and converts the .npz to SVG.

    All user-facing strings are in Czech, as in the original UI.
    """

    def __init__(self, root):
        self.root = root
        self.root.title("Virtual Sketching GUI")
        self.input_file = None  # absolute path of the chosen input image
        # Currently selected model script (defaults to the first option).
        self.model_script = tk.StringVar(value=list(MODEL_OPTIONS.values())[0])
        self.build_ui()

    def build_ui(self):
        """Create all widgets: file picker, model radio buttons, run button, status."""
        tk.Label(self.root, text="1. Vyber vstupní obrázek:").pack(anchor="w")
        tk.Button(self.root, text="Vybrat obrázek", command=self.choose_file).pack(fill="x")
        self.file_label = tk.Label(self.root, text="Žádný soubor nevybrán", fg="gray")
        self.file_label.pack(anchor="w")

        tk.Label(self.root, text="2. Zvol model:").pack(anchor="w")
        for name, script in MODEL_OPTIONS.items():
            tk.Radiobutton(self.root, text=name, variable=self.model_script, value=script).pack(anchor="w")

        tk.Button(self.root, text="3. Spustit zpracování", command=self.run_processing).pack(pady=10, fill="x")
        self.status = tk.Label(self.root, text="Připraven", fg="green")
        self.status.pack(anchor="w")

    def choose_file(self):
        """Open a file dialog and remember the selected image path."""
        path = filedialog.askopenfilename(filetypes=[
            ("Obrázkové soubory", "*.png *.jpg *.jpeg *.bmp *.gif *.tif *.tiff"),
            ("PNG", "*.png"),
            ("JPEG", "*.jpg;*.jpeg"),
            ("BMP", "*.bmp"),
            ("GIF", "*.gif"),
            ("TIFF", "*.tif;*.tiff")
        ])
        if path:
            self.input_file = path
            self.file_label.config(text=os.path.basename(path), fg="black")

    def run_processing(self):
        """Run the selected model script on the chosen image in a worker thread."""
        if not self.input_file:
            messagebox.showerror("Chyba", "Nejprve vyber obrázek.")
            return

        script = self.model_script.get()
        cmd = [sys.executable, script, "--input", self.input_file]

        def task():
            # Worker thread: run the script, then post-process on success.
            self.status.config(text="Zpracovávám...", fg="blue")
            try:
                subprocess.run(cmd, check=True)
                self.status.config(text="✅ Hotovo", fg="green")
                self.move_outputs_to_sketches()
                self.run_svg_conversion()
            except subprocess.CalledProcessError:
                self.status.config(text="❌ Chyba při běhu skriptu", fg="red")

        threading.Thread(target=task).start()

    def move_outputs_to_sketches(self):
        """Move the generated output files into a 'sketches' sub-directory
        next to the input image."""
        if not self.input_file:
            return
        input_dir = os.path.dirname(self.input_file)
        input_base = os.path.splitext(os.path.basename(self.input_file))[0]
        sketches_dir = os.path.join(input_dir, "sketches")
        os.makedirs(sketches_dir, exist_ok=True)

        # Known output suffixes produced by the test scripts.
        for ext in ["_0.npz", "_0_pred.png", "_input.png", "_0.svg"]:
            candidate = os.path.join(input_dir, f"{input_base}{ext}")
            if os.path.isfile(candidate):
                shutil.move(candidate, os.path.join(sketches_dir, os.path.basename(candidate)))

    def run_svg_conversion(self):
        """Convert the moved .npz stroke sequence to a single SVG via tools/svg_conversion.py."""
        if not self.input_file:
            return
        input_dir = os.path.dirname(self.input_file)
        input_base = os.path.splitext(os.path.basename(self.input_file))[0]
        npz_file = os.path.join(input_dir, f"{input_base}_0.npz")  # NOTE(review): unused local, kept as-is
        sketches_dir = os.path.join(input_dir, "sketches")
        npz_file_in_sketches = os.path.join(sketches_dir, f"{input_base}_0.npz")

        if not os.path.isfile(npz_file_in_sketches):
            print("⚠️ .npz soubor nebyl nalezen pro SVG konverzi.")
            return

        cmd = [sys.executable, SVG_CONVERTER, "--file", npz_file_in_sketches, "--svg_type", "single"]
        try:
            subprocess.run(cmd, check=True)
            svg_path = os.path.join(sketches_dir, f"{input_base}_0.svg")
            if os.path.isfile(svg_path):
                print(f"✅ SVG vytvořeno: {svg_path}")
        except subprocess.CalledProcessError:
            print("⚠️ Chyba při SVG konverzi")


if __name__ == "__main__":
    root = tk.Tk()
    app = VirtualSketchApp(root)
    root.mainloop()