Repository: MarkMoHR/virtual_sketching
Branch: main
Commit: 229fefae08e4
Files: 28
Total size: 431.9 KB
Directory structure:
gitextract_xyepypzb/
├── .gitignore
├── LICENSE
├── README.md
├── README_CN.md
├── WINDOWS_INSTALL_GUIDE.md
├── dataset_utils.py
├── docs/
│ ├── assets/
│ │ ├── font.css
│ │ └── style.css
│ └── index.html
├── hyper_parameters.py
├── launch_gui.bat
├── model_common_test.py
├── model_common_train.py
├── rasterization_utils/
│ ├── NeuralRenderer.py
│ └── RealRenderer.py
├── rnn.py
├── subnet_tf_utils.py
├── test_photograph_to_line.py
├── test_rough_sketch_simplification.py
├── test_vectorization.py
├── tools/
│ ├── gif_making.py
│ ├── svg_conversion.py
│ └── visualize_drawing.py
├── train_rough_photograph.py
├── train_vectorization.py
├── utils.py
├── vgg_utils/
│ └── VGG16.py
└── virtual_sketch_gui.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
.idea
.idea/
data/
datas/
dataset/
datasets/
model/
models/
testData/
output/
outputs/
*.csv
# temporary files
*.txt~
*.pyc
.DS_Store
.gitignore~
*.h5
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# General Virtual Sketching Framework for Vector Line Art - SIGGRAPH 2021
[[Paper]](https://esslab.jp/publications/HaoranSIGRAPH2021.pdf) | [[Project Page]](https://markmohr.github.io/virtual_sketching/) | [[中文Readme]](/README_CN.md) | [[中文论文介绍]](https://blog.csdn.net/qq_33000225/article/details/118883153)
This code is used for **line drawing vectorization**, **rough sketch simplification** and **photograph to vector line drawing**.
## Outline
- [Dependencies](#dependencies)
- [Testing with Trained Weights](#testing-with-trained-weights)
- [Training](#training)
- [Citation](#citation)
- [Projects Using this Model/Method](#projects-using-this-modelmethod)
- [Blogs Mentioning this Paper](#blogs-mentioning-this-paper)
- [For Windows users](#-windows-users)
## Dependencies
- [Tensorflow](https://www.tensorflow.org/) (1.12.0 <= version <= 1.15.0)
- [opencv](https://opencv.org/) == 3.4.2
- [pillow](https://pillow.readthedocs.io/en/latest/index.html) == 6.2.0
- [scipy](https://www.scipy.org/) == 1.5.2
- [gizeh](https://github.com/Zulko/gizeh) == 0.1.11
## Testing with Trained Weights
### Model Preparation
Download the models [here](https://drive.google.com/drive/folders/1-hi2cl8joZ6oMOp4yvk_hObJGAK6ELHB?usp=sharing):
- `pretrain_clean_line_drawings` (105 MB): for vectorization
- `pretrain_rough_sketches` (105 MB): for rough sketch simplification
- `pretrain_faces` (105 MB): for photograph to line drawing
Then, place them in this file structure:
```
outputs/
snapshot/
pretrain_clean_line_drawings/
pretrain_rough_sketches/
pretrain_faces/
```
### Usage
Choose an image from the `sample_inputs/` directory and run one of the following commands, depending on the task. The results will be written to `outputs/sampling/`.
```bash
python3 test_vectorization.py --input muten.png
python3 test_rough_sketch_simplification.py --input rocket.png
python3 test_photograph_to_line.py --input 1390.png
```
**Note!!!** Our approach starts drawing from a randomly selected initial position, so it produces a different result in every testing trial (some may look fine and some may not). It is recommended to run several trials and select the visually best result. The number of outputs per run can be set with the `--sample` argument:
```bash
python3 test_vectorization.py --input muten.png --sample 10
python3 test_rough_sketch_simplification.py --input rocket.png --sample 10
python3 test_photograph_to_line.py --input 1390.png --sample 10
```
**Reproducing Paper Figures:** our results (download from [here](https://drive.google.com/drive/folders/1-hi2cl8joZ6oMOp4yvk_hObJGAK6ELHB?usp=sharing)) were selected from a number of trials. Note that reproducing them exactly requires using the same initial drawing positions.
### Additional Tools
#### a) Visualization
Our vector output is stored in an `npz` package. Run the following command to obtain the rendered output and the drawing order. The results will be placed in the same directory as the `npz` file.
```bash
python3 tools/visualize_drawing.py --file path/to/the/result.npz
```
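For quick inspection without the visualization tool, the `npz` package can also be opened directly with NumPy. The snippet below builds a stand-in `npz` in memory; the key name `strokes` is hypothetical, so list `result.files` to see the actual keys your output contains:

```python
import io
import numpy as np

# Build a small in-memory .npz standing in for a real result file;
# the key name "strokes" is hypothetical -- check result.files to see
# what the repo's scripts actually store.
buf = io.BytesIO()
np.savez(buf, strokes=np.zeros((5, 7), dtype=np.float32))
buf.seek(0)

result = np.load(buf)
print(result.files)             # -> ['strokes']
print(result["strokes"].shape)  # -> (5, 7)
```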
#### b) GIF Making
To see the dynamic drawing procedure, run the following command to obtain a `gif`. The result will be placed in the same directory as the `npz` file.
```bash
python3 tools/gif_making.py --file path/to/the/result.npz
```
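Independently of the provided tool, an animation can also be assembled from rendered frames with Pillow (already a dependency). This is a minimal sketch with placeholder frames, not the repository's implementation:

```python
import io
from PIL import Image

# Two placeholder grayscale frames standing in for successive
# renderings of the drawing; in practice these would be the rendered
# intermediate states from the npz result.
frames = [Image.new("L", (64, 64), color=c) for c in (255, 128)]

buf = io.BytesIO()  # or a real path such as "drawing.gif"
frames[0].save(buf, format="GIF", save_all=True,
               append_images=frames[1:], duration=200, loop=0)
print(buf.getvalue()[:3])  # GIF header bytes
```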
#### c) Conversion to SVG
Our vector output in an `npz` package is stored in the format of Eq. (1) in the main paper. Run the following command to convert it to the `svg` format. The result will be placed in the same directory as the `npz` file.
```bash
python3 tools/svg_conversion.py --file path/to/the/result.npz
```
- The conversion is implemented in two modes (by setting the `--svg_type` argument):
- `single` (default): each stroke (a single segment) forms a path in the SVG file
- `cluster`: each continuous curve (with multiple strokes) forms a path in the SVG file
**Important Notes**
In the SVG format, all segments on a path share the same *stroke-width*, whereas in our stroke design, strokes on a common curve have different widths. Within a single stroke (one segment), the thickness also changes linearly from one endpoint to the other.
Therefore, neither of the two conversion methods above generates results visually identical to the ones in our paper.
*(Please mention this issue in your paper if you do qualitative comparisons with our results in SVG format.)*
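If an approximation is acceptable, the varying thickness can be imitated outside of this repository by subdividing each stroke into short constant-width sub-segments, one SVG path each, with the width interpolated linearly. A minimal sketch, with all names illustrative (this is not the repository's converter):

```python
# Approximate a linearly varying stroke width in SVG by splitting one
# straight segment into n constant-width sub-segments.

def varying_width_paths(p0, p1, w0, w1, n=8):
    """Return SVG <path> strings from p0 to p1 with width w0 -> w1."""
    paths = []
    for i in range(n):
        t0, t1 = i / n, (i + 1) / n
        a = (p0[0] + (p1[0] - p0[0]) * t0, p0[1] + (p1[1] - p0[1]) * t0)
        b = (p0[0] + (p1[0] - p0[0]) * t1, p0[1] + (p1[1] - p0[1]) * t1)
        w = w0 + (w1 - w0) * (t0 + t1) / 2  # width at segment midpoint
        paths.append(
            '<path d="M {:.2f} {:.2f} L {:.2f} {:.2f}" '
            'stroke="black" stroke-width="{:.2f}" '
            'stroke-linecap="round" fill="none"/>'.format(*a, *b, w)
        )
    return paths

paths = varying_width_paths((0, 0), (100, 0), w0=1.0, w1=5.0)
print(len(paths))  # -> 8
print(paths[0])
```

Round line caps hide the small width jumps between adjacent sub-segments; increasing `n` makes the taper smoother at the cost of a larger file.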
## Training
### Preparations
Download the models [here](https://drive.google.com/drive/folders/1-hi2cl8joZ6oMOp4yvk_hObJGAK6ELHB?usp=sharing):
- `pretrain_neural_renderer` (40 MB): the pre-trained neural renderer
- `pretrain_perceptual_model` (691 MB): the pre-trained perceptual model for raster loss
Download the datasets [here](https://drive.google.com/drive/folders/1-hi2cl8joZ6oMOp4yvk_hObJGAK6ELHB?usp=sharing):
- `QuickDraw-clean` (14 MB): for clean line drawing vectorization. Taken from [QuickDraw](https://github.com/googlecreativelab/quickdraw-dataset) dataset.
- `QuickDraw-rough` (361 MB): for rough sketch simplification. Synthesized by the pencil drawing generation method from [Sketch Simplification](https://github.com/bobbens/sketch_simplification#pencil-drawing-generation).
- `CelebAMask-faces` (370 MB): for photograph to line drawing. Processed from the [CelebAMask-HQ](https://github.com/switchablenorms/CelebAMask-HQ) dataset.
Then, place them in this file structure:
```
datasets/
QuickDraw-clean/
QuickDraw-rough/
CelebAMask-faces/
outputs/
snapshot/
pretrain_neural_renderer/
pretrain_perceptual_model/
```
### Running
It is recommended to train with multiple GPUs. We train each task with 2 GPUs (11 GB each).
```bash
python3 train_vectorization.py
python3 train_rough_photograph.py --data rough
python3 train_rough_photograph.py --data face
```
## Citation
If you use the code or models, please cite:
```
@article{mo2021virtualsketching,
title = {General Virtual Sketching Framework for Vector Line Art},
author = {Mo, Haoran and Simo-Serra, Edgar and Gao, Chengying and Zou, Changqing and Wang, Ruomei},
journal = {ACM Transactions on Graphics (Proceedings of ACM SIGGRAPH 2021)},
year = {2021},
volume = {40},
number = {4},
pages = {51:1--51:14}
}
```
## Projects Using this Model/Method
| **[Painterly style transfer](https://github.com/xch-liu/Painterly-Style-Transfer) (TVCG 2023)** | **[Robot calligraphy](https://github.com/LoYuXr/CalliRewrite) (ICRA 2024)** |
|:-------------:|:-------------------:|
| | |
| **[Geometrized cartoon line inbetweening](https://github.com/lisiyao21/AnimeInbet) (ICCV 2023)** | **[Stroke correspondence and inbetweening](https://github.com/MarkMoHR/JoSTC) (TOG 2024)** |
| | |
| **[Modelling complex vector drawings](https://github.com/Co-do/Stroke-Cloud) (ICLR 2024)** | **[Sketch-to-Image Generation](https://github.com/BlockDetail/Block-and-Detail) (UIST 2024)** |
| | |
## Blogs Mentioning this Paper
- [The state of AI for hand-drawn animation inbetweening](https://yosefk.com/blog/the-state-of-ai-for-hand-drawn-animation-inbetweening.html)
## 🪟 Windows users
See [WINDOWS_INSTALL_GUIDE.md](WINDOWS_INSTALL_GUIDE.md) for a complete Windows installation guide and GUI usage.
================================================
FILE: README_CN.md
================================================
# General Virtual Sketching Framework for Vector Line Art - SIGGRAPH 2021
[[Paper]](https://esslab.jp/publications/HaoranSIGRAPH2021.pdf) | [[Project Page]](https://markmohr.github.io/virtual_sketching/)
This code can be used for **line drawing vectorization**, **rough sketch simplification** and **photograph to vector line drawing**.
## Outline
- [Dependencies](#dependencies)
- [Testing with Trained Weights](#testing-with-trained-weights)
- [Training](#training)
- [Citation](#citation)
## Dependencies
- [Tensorflow](https://www.tensorflow.org/) (1.12.0 <= version <= 1.15.0)
- [opencv](https://opencv.org/) == 3.4.2
- [pillow](https://pillow.readthedocs.io/en/latest/index.html) == 6.2.0
- [scipy](https://www.scipy.org/) == 1.5.2
- [gizeh](https://github.com/Zulko/gizeh) == 0.1.11
## Testing with Trained Weights
### Model Preparation
Download the models [here](https://drive.google.com/drive/folders/1-hi2cl8joZ6oMOp4yvk_hObJGAK6ELHB?usp=sharing):
- `pretrain_clean_line_drawings` (105 MB): for line drawing vectorization
- `pretrain_rough_sketches` (105 MB): for rough sketch simplification
- `pretrain_faces` (105 MB): for photograph to vector line drawing
Then, place them in this file structure:
```
outputs/
    snapshot/
        pretrain_clean_line_drawings/
        pretrain_rough_sketches/
        pretrain_faces/
```
### Usage
Choose an image from the `sample_inputs/` directory and run one of the following commands, depending on the task. The results will be written to `outputs/sampling/`.
```bash
python3 test_vectorization.py --input muten.png
python3 test_rough_sketch_simplification.py --input rocket.png
python3 test_photograph_to_line.py --input 1390.png
```
**Note!!!** Our method starts drawing from a randomly selected initial position, so every testing trial produces a different result (some may look fine and some may not). It is therefore recommended to run several trials and pick the visually best result. The number of outputs produced by a single run can be set with the `--sample` argument:
```bash
python3 test_vectorization.py --input muten.png --sample 10
python3 test_rough_sketch_simplification.py --input rocket.png --sample 10
python3 test_photograph_to_line.py --input 1390.png --sample 10
```
**How to reproduce the results shown in the paper?** The results shown in the paper can be downloaded [here](https://drive.google.com/drive/folders/1-hi2cl8joZ6oMOp4yvk_hObJGAK6ELHB?usp=sharing). They are the best outputs selected from a number of testing trials. Note that reproducing them exactly requires starting the drawing from the same initial positions.
### Additional Tools
#### a) Visualization
Our vector outputs are all stored in `npz` packages. Run the following command to obtain the rendered result and the drawing order. The visualizations will be placed in the same directory as the `npz` file.
```bash
python3 tools/visualize_drawing.py --file path/to/the/result.npz
```
#### b) GIF Making
To see the dynamic drawing procedure, run the following command to obtain a `gif`. The result will be placed in the same directory as the `npz` file.
```bash
python3 tools/gif_making.py --file path/to/the/result.npz
```
#### c) Conversion to SVG
The vector results in the `npz` packages are stored in the format of Eq. (1) in the paper. Run the following command to convert them to the `svg` format. The result will be placed in the same directory as the `npz` file.
```bash
python3 tools/svg_conversion.py --file path/to/the/result.npz
```
- The conversion is implemented in two modes (set with the `--svg_type` argument):
  - `single` (default): each stroke (a single segment) forms one path in the SVG file
  - `cluster`: each continuous curve (with multiple strokes) forms one path in the SVG file
**Important Notes**
In the SVG format, all segments on a path share a single line width (*stroke-width*). In our paper, however, the strokes on a continuous curve are defined with different widths, and within a single stroke (a Bézier curve) the width is defined to increase or decrease linearly from one endpoint to the other.
Therefore, neither of the two conversion methods above is guaranteed to produce SVG results that look exactly the same as those in the paper. (*If you use the converted SVG results for visual comparisons in your paper, please mention this issue.*)
## Training
### Preparations
Download the models [here](https://drive.google.com/drive/folders/1-hi2cl8joZ6oMOp4yvk_hObJGAK6ELHB?usp=sharing):
- `pretrain_neural_renderer` (40 MB): the pre-trained neural renderer
- `pretrain_perceptual_model` (691 MB): the pre-trained perceptual model used for the raster loss
Download the training datasets [here](https://drive.google.com/drive/folders/1-hi2cl8joZ6oMOp4yvk_hObJGAK6ELHB?usp=sharing):
- `QuickDraw-clean` (14 MB): for line drawing vectorization. Taken from the [QuickDraw](https://github.com/googlecreativelab/quickdraw-dataset) dataset.
- `QuickDraw-rough` (361 MB): for rough sketch simplification. Synthesized with the pencil drawing generation method from [Sketch Simplification](https://github.com/bobbens/sketch_simplification#pencil-drawing-generation).
- `CelebAMask-faces` (370 MB): for photograph to vector line drawing. Processed from the [CelebAMask-HQ](https://github.com/switchablenorms/CelebAMask-HQ) dataset.
Then, place the datasets in this file structure:
```
datasets/
    QuickDraw-clean/
    QuickDraw-rough/
    CelebAMask-faces/
outputs/
    snapshot/
        pretrain_neural_renderer/
        pretrain_perceptual_model/
```
### Running
Training with multiple GPUs is recommended. We train each task with 2 GPUs (11 GB each).
```bash
python3 train_vectorization.py
python3 train_rough_photograph.py --data rough
python3 train_rough_photograph.py --data face
```
## Citation
If you use the code or models, please cite this work:
```
@article{mo2021virtualsketching,
  title = {General Virtual Sketching Framework for Vector Line Art},
  author = {Mo, Haoran and Simo-Serra, Edgar and Gao, Chengying and Zou, Changqing and Wang, Ruomei},
  journal = {ACM Transactions on Graphics (Proceedings of ACM SIGGRAPH 2021)},
  year = {2021},
  volume = {40},
  number = {4},
  pages = {51:1--51:14}
}
```
================================================
FILE: WINDOWS_INSTALL_GUIDE.md
================================================
# 🪟 Windows Installation Guide for Virtual Sketching
This guide provides step-by-step instructions to set up and run the [Virtual Sketching](https://github.com/MarkMoHR/virtual_sketching) project on Windows using Anaconda and Python 3.6.
## ✅ Requirements
- Windows 10 or newer
- Anaconda installed
- Git installed (optional, but recommended)
---
## 📦 Step 1: Create and Activate Conda Environment
```bash
conda create -n virtual_sketching python=3.6 -y
conda activate virtual_sketching
```
## 📂 Step 2: Clone the Repository
```bash
cd D:\
git clone https://github.com/MarkMoHR/virtual_sketching.git
cd virtual_sketching
```
(If you plan to contribute, consider forking the repo and using your own URL.)
---
## 🔧 Step 3: Install Required Packages
### From `conda`:
```bash
conda install opencv=3.4.2 pillow=6.2.0 scipy=1.5.2 -y
conda install -c conda-forge pycairo gtk3 cffi -y
```
### Then remove default TensorFlow (if installed via conda):
```bash
conda remove tensorflow
```
### Install required packages via `pip`:
```bash
pip install tensorflow==1.15.0
pip install numpy gizeh cairocffi matplotlib svgwrite
```
> ⚠️ Do not upgrade `pillow`, `scipy`, or `tensorflow` — newer versions are incompatible.
---
## 🛠️ Step 4: Fix Backend Compatibility
In `utils.py`, near the top, ensure the following:
```python
from PIL import Image
import matplotlib
matplotlib.use('TkAgg') # force compatible backend
import matplotlib.pyplot as plt
```
This ensures proper rendering with tkinter on Windows.
---
## 🚀 Step 5: Run a Demo
From the project directory:
```bash
python test_vectorization.py --input sample_inputs\muten.png --sample 5
```
> ⚠️ This script only generates `.npz` and `.png` files. To convert to `.svg`, see next section.
---
## 🖼️ Optional: Use GUI Tool (Windows Only)
### Step 1: Launch GUI with provided batch file
Use the provided `launch_gui.bat` file to activate the conda environment and launch the Python GUI:
```bat
rem launch_gui.bat
set CONDAPATH=C:\ProgramData\anaconda3
set ENVNAME=virtual_sketching
call %CONDAPATH%\Scripts\activate.bat %ENVNAME%
python virtual_sketch_gui.py
pause
```
### Step 2: Select an input image and model
The GUI allows you to:
- Choose an input image (PNG, JPEG, BMP, etc.)
- Select one of the three models
- Run processing automatically and convert the results to SVG
- Save all outputs into a `sketches/` subfolder of the input image directory
---
## 🧠 Known Compatibility Notes
- Python 3.6 and TensorFlow 1.15 are required (due to use of `tensorflow.contrib`)
- Windows support requires manual setup of Gizeh and Cairo backends
- GPU usage is optional; GPU acceleration with TensorFlow 1.15 requires CUDA 10.0 and cuDNN 7
---
## 🧩 Troubleshooting Tips
- ❌ `No module named '_cffi_backend'` → Run: `conda install -c conda-forge cffi`
- ❌ `ImportError: cannot import name 'draw_svg_from_npz'` → Use `svg_conversion.py` instead
- ❌ `.svg` looks wrong → Make sure you're using the official `svg_conversion.py` from `tools/`
- ❌ Missing `.svg`? → Use `virtual_sketch_gui.py` or run `tools/svg_conversion.py` manually on `.npz`
---
## 🤝 Contributing
Feel free to open issues or pull requests if you encounter bugs or want to improve the Windows support!
---
## 📁 Folder Structure Suggestion
```
virtual_sketching/
├── sample_inputs/
├── tools/
├── outputs/
├── virtual_sketch_gui.py
├── launch_gui.bat
├── README.md
└── WINDOWS_INSTALL_GUIDE.md ← You are here
```
---
Made with ❤️ by the community to help Windows users get started!
================================================
FILE: dataset_utils.py
================================================
import os
import math
import random
import scipy.io
import numpy as np
import tensorflow as tf
from PIL import Image
from rasterization_utils.RealRenderer import GizehRasterizor as RealRenderer
def copy_hparams(hparams):
"""Return a copy of an HParams instance."""
return tf.contrib.training.HParams(**hparams.values())
class GeneralRawDataLoader(object):
def __init__(self,
image_path,
raster_size,
test_dataset):
self.image_path = image_path
self.raster_size = raster_size
self.test_dataset = test_dataset
def get_test_image(self, random_cursor=True, init_cursor_on_undrawn_pixel=False, init_cursor_num=1):
input_image_data, image_size_test = self.gen_input_images(self.image_path)
input_image_data = np.array(input_image_data,
dtype=np.float32) # (1, image_size, image_size, (3)), [0.0-strokes, 1.0-BG]
return input_image_data, \
self.gen_init_cursors(input_image_data, random_cursor, init_cursor_on_undrawn_pixel, init_cursor_num), \
image_size_test
def gen_input_images(self, image_path):
img = Image.open(image_path).convert('RGB')
height, width = img.height, img.width
max_dim = max(height, width)
img = np.array(img, dtype=np.uint8)
if height != width:
# Padding to a square image
if self.test_dataset == 'clean_line_drawings':
pad_value = [255, 255, 255]
elif self.test_dataset == 'faces':
pad_value = [0, 0, 0]
else:
# TODO: find better padding pixel value
pad_value = img[height - 10, width - 10]
img_r, img_g, img_b = img[:, :, 0], img[:, :, 1], img[:, :, 2]
pad_width = max_dim - width
pad_height = max_dim - height
pad_img_r = np.pad(img_r, ((0, pad_height), (0, pad_width)), 'constant', constant_values=pad_value[0])
pad_img_g = np.pad(img_g, ((0, pad_height), (0, pad_width)), 'constant', constant_values=pad_value[1])
pad_img_b = np.pad(img_b, ((0, pad_height), (0, pad_width)), 'constant', constant_values=pad_value[2])
image_array = np.stack([pad_img_r, pad_img_g, pad_img_b], axis=-1)
else:
image_array = img
if self.test_dataset == 'faces' and max_dim != 256:
image_array_resize = Image.fromarray(image_array, 'RGB')
image_array_resize = image_array_resize.resize(size=(256, 256), resample=Image.BILINEAR)
image_array = np.array(image_array_resize, dtype=np.uint8)
assert image_array.shape[0] == image_array.shape[1]
img_size = image_array.shape[0]
image_array = image_array.astype(np.float32)
if self.test_dataset == 'clean_line_drawings':
image_array = image_array[:, :, 0] / 255.0 # [0.0-stroke, 1.0-BG]
else:
image_array = image_array / 255.0 # [0.0-stroke, 1.0-BG]
image_array = np.expand_dims(image_array, axis=0)
return image_array, img_size
    def crop_patch(self, image, center, image_size, crop_size):
        """Crop a crop_size x crop_size patch centered at `center`.

        The corner coordinates are clamped to the image bounds, so a
        patch near the border may be smaller than crop_size x crop_size.
        """
        x0 = center[0] - crop_size // 2
        x1 = x0 + crop_size
        y0 = center[1] - crop_size // 2
        y1 = y0 + crop_size
        # Clamp the crop window to valid image coordinates
        x0 = max(0, min(x0, image_size))
        y0 = max(0, min(y0, image_size))
        x1 = max(0, min(x1, image_size))
        y1 = max(0, min(y1, image_size))
        patch = image[y0:y1, x0:x1]
        return patch
def gen_init_cursor_single(self, sketch_image, init_cursor_on_undrawn_pixel, misalign_size=3):
# sketch_image: [0.0-stroke, 1.0-BG]
image_size = sketch_image.shape[0]
if np.sum(1.0 - sketch_image) == 0:
center = np.zeros((2), dtype=np.int32)
return center
else:
while True:
center = np.random.randint(0, image_size, size=(2)) # (2), in large size
patch = 1.0 - self.crop_patch(sketch_image, center, image_size, self.raster_size)
if np.sum(patch) != 0:
if not init_cursor_on_undrawn_pixel:
return center.astype(np.float32) / float(image_size) # (2), in size [0.0, 1.0)
else:
center_patch = 1.0 - self.crop_patch(sketch_image, center, image_size, misalign_size)
if np.sum(center_patch) != 0:
return center.astype(np.float32) / float(image_size) # (2), in size [0.0, 1.0)
def gen_init_cursors(self, sketch_data, random_pos=True, init_cursor_on_undrawn_pixel=False, init_cursor_num=1):
init_cursor_batch_list = []
for cursor_i in range(init_cursor_num):
if random_pos:
init_cursor_batch = []
for i in range(len(sketch_data)):
sketch_image = sketch_data[i].copy().astype(np.float32) # [0.0-stroke, 1.0-BG]
center = self.gen_init_cursor_single(sketch_image, init_cursor_on_undrawn_pixel)
init_cursor_batch.append(center)
init_cursor_batch = np.stack(init_cursor_batch, axis=0) # (N, 2)
else:
raise Exception('Not finished')
init_cursor_batch_list.append(init_cursor_batch)
if init_cursor_num == 1:
init_cursor_batch = init_cursor_batch_list[0]
init_cursor_batch = np.expand_dims(init_cursor_batch, axis=1).astype(np.float32) # (N, 1, 2)
else:
init_cursor_batch = np.stack(init_cursor_batch_list, axis=1) # (N, init_cursor_num, 2)
init_cursor_batch = np.expand_dims(init_cursor_batch, axis=2).astype(
np.float32) # (N, init_cursor_num, 1, 2)
return init_cursor_batch
def load_dataset_testing(test_data_base_dir, test_dataset, test_img_name, model_params):
assert test_dataset in ['clean_line_drawings', 'rough_sketches', 'faces']
img_path = os.path.join(test_data_base_dir, test_dataset, test_img_name)
print('Loaded {} from {}'.format(img_path, test_dataset))
eval_model_params = copy_hparams(model_params)
eval_model_params.use_input_dropout = 0
eval_model_params.use_recurrent_dropout = 0
eval_model_params.use_output_dropout = 0
eval_model_params.batch_size = 1
eval_model_params.model_mode = 'sample'
sample_model_params = copy_hparams(eval_model_params)
sample_model_params.batch_size = 1 # only sample one at a time
sample_model_params.max_seq_len = 1 # sample one point at a time
test_set = GeneralRawDataLoader(img_path, eval_model_params.raster_size, test_dataset=test_dataset)
result = [test_set, eval_model_params, sample_model_params]
return result
class GeneralMultiObjectDataLoader(object):
def __init__(self,
stroke3_data,
batch_size,
raster_size,
image_size_small,
image_size_large,
is_bin,
is_train):
self.batch_size = batch_size # minibatch size
self.raster_size = raster_size
self.image_size_small = image_size_small
self.image_size_large = image_size_large
self.is_bin = is_bin
self.is_train = is_train
self.num_batches = len(stroke3_data) // self.batch_size
self.batch_idx = -1
print('batch_size', batch_size, ', num_batches', self.num_batches)
self.rasterizor = RealRenderer()
self.memory_sketch_data_batch = []
assert type(stroke3_data) is list
self.preprocess_rand_data(stroke3_data)
def preprocess_rand_data(self, stroke3):
if self.is_train:
random.shuffle(stroke3)
self.stroke3_data = stroke3
def cal_dist(self, posA, posB):
return np.linalg.norm(posA - posB)
def invalid_position(self, pos, obj_size, pos_list, size_list):
if len(pos_list) == 0:
return False
pos_a = pos
size_a = obj_size
for i in range(len(pos_list)):
pos_b = pos_list[i]
size_b = size_list[i]
if self.cal_dist(pos_a, pos_b) < ((size_a + size_b) // 4):
return True
return False
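The rejection test above treats a candidate placement as invalid when its center lies closer than a quarter of the two objects' combined sizes to any already-placed object. A minimal standalone sketch of the same check (the name `too_close`, the `min_frac` parameter, and the list-of-pairs argument are illustrative only; the method above uses parallel position/size lists and integer division by 4):

```python
import numpy as np

def too_close(pos_a, size_a, placed, min_frac=0.25):
    # Reject pos_a if its distance to any placed center is below
    # min_frac * (size_a + size_b), mirroring invalid_position() above.
    for pos_b, size_b in placed:
        dist = np.linalg.norm(np.asarray(pos_a, dtype=np.float64) -
                              np.asarray(pos_b, dtype=np.float64))
        if dist < (size_a + size_b) * min_frac:
            return True
    return False
```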
def get_object_info(self, image_size, vary_thickness=True, try_total_times=3):
if image_size <= 172:
obj_num = 1
obj_thickness_list = [3]
elif image_size <= 225:
obj_num = random.randint(1, 2)
obj_thickness_list = np.random.randint(3, 4 + 1, size=(obj_num))
elif image_size <= 278:
obj_num = 2
obj_thickness_list = np.random.randint(3, 4 + 1, size=(obj_num))
elif image_size <= 331:
obj_num = random.randint(2, 3)
while True:
obj_thickness_list = np.random.randint(3, 5 + 1, size=(obj_num))
# Thickness values lie in [3, 5], so a mean of exactly 5 means every stroke is
# maximally thick; reject that case and overly thick totals (>= 13).
if np.sum(obj_thickness_list) / obj_num != 5 and np.sum(obj_thickness_list) < 13:
break
elif image_size <= 384:
obj_num = 3
while True:
obj_thickness_list = np.random.randint(3, 5 + 1, size=(obj_num))
if np.sum(obj_thickness_list) / obj_num != 5 and np.sum(obj_thickness_list) < 13:
break
else:
raise ValueError('Invalid image_size: {}'.format(image_size))
if not vary_thickness:
num_item = len(obj_thickness_list)
obj_thickness_list = [3 for _ in range(num_item)]
obj_pos_list = []
obj_size_list = []
if obj_num == 1:
obj_size_list.append(image_size)
center = (image_size // 2, image_size // 2)
obj_pos_list.append(center)
else:
for obj_i in range(obj_num):
for try_i in range(try_total_times):
obj_size = random.randint(128, image_size * 3 // 4)
obj_center = np.random.randint(obj_size // 3, image_size - (obj_size // 3) + 1, size=(2))
if not self.invalid_position(obj_center, obj_size, obj_pos_list,
obj_size_list) or try_i == try_total_times - 1:
obj_pos_list.append(obj_center)
obj_size_list.append(obj_size)
break
assert len(obj_size_list) == len(obj_pos_list) == len(obj_thickness_list) == obj_num
return obj_num, obj_size_list, obj_pos_list, obj_thickness_list
def object_pasting(self, obj_img, canvas_img, center):
c_y, c_x = center[0], center[1]
obj_size = obj_img.shape[0]
canvas_size = canvas_img.shape[0]
box_left = max(0, c_x - obj_size // 2)
box_right = min(canvas_size, c_x + obj_size // 2)
box_up = max(0, c_y - obj_size // 2)
box_bottom = min(canvas_size, c_y + obj_size // 2)
box_canvas = canvas_img[box_up: box_bottom, box_left: box_right]
obj_box_up = box_up - (c_y - obj_size // 2)
obj_box_left = box_left - (c_x - obj_size // 2)
box_obj = obj_img[obj_box_up: obj_box_up + (box_bottom - box_up),
obj_box_left: obj_box_left + (box_right - box_left)]
# Avoid in-place `+=` here: box_canvas is a view into canvas_img, so an
# in-place add would mutate the input canvas as a side effect.
box_canvas = box_canvas + box_obj
rst_canvas = np.copy(canvas_img)
rst_canvas[box_up: box_bottom, box_left: box_right] = box_canvas
rst_canvas = np.clip(rst_canvas, 0.0, 1.0)
return rst_canvas
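object_pasting adds an object patch onto the canvas additively, clamping the paste box to the canvas borders so out-of-range rows and columns are silently cropped rather than raising an index error. A standalone sketch of the same box arithmetic (the name `paste_add` is illustrative, not part of this module):

```python
import numpy as np

def paste_add(obj_img, canvas_img, center):
    # center is (y, x); the object is added (not overwritten) and the result
    # is clipped to the [0, 1] image range, as in object_pasting() above.
    c_y, c_x = center
    obj_size, canvas_size = obj_img.shape[0], canvas_img.shape[0]
    top = max(0, c_y - obj_size // 2)
    bottom = min(canvas_size, c_y + obj_size // 2)
    left = max(0, c_x - obj_size // 2)
    right = min(canvas_size, c_x + obj_size // 2)
    # Offsets into the object for the part that survived clamping.
    o_top = top - (c_y - obj_size // 2)
    o_left = left - (c_x - obj_size // 2)
    out = canvas_img.copy()  # leave the input canvas untouched
    out[top:bottom, left:right] += obj_img[o_top:o_top + (bottom - top),
                                           o_left:o_left + (right - left)]
    return np.clip(out, 0.0, 1.0)
```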
def get_multi_object_image(self, img_size, vary_thickness):
object_num, object_size_list, object_pos_list, object_thickness_list = self.get_object_info(
img_size, vary_thickness=vary_thickness)
canvas = np.zeros(shape=(img_size, img_size), dtype=np.float32)
for obj_i in range(object_num):
rand_idx = np.random.randint(0, len(self.stroke3_data))
rand_stroke3 = self.stroke3_data[rand_idx] # (N_points, 3)
object_size = object_size_list[obj_i]
object_center = object_pos_list[obj_i]
object_thickness = object_thickness_list[obj_i]
stroke_image = self.gen_stroke_images([rand_stroke3], object_size, object_thickness)
stroke_image = 1.0 - stroke_image[0]  # (image_size, image_size), [0.0-BG, 1.0-strokes]
canvas = self.object_pasting(stroke_image, canvas, object_center)  # [0.0-BG, 1.0-strokes]
canvas = 1.0 - canvas # [0.0-strokes, 1.0-BG]
return canvas
def get_batch_from_memory(self, memory_idx, vary_thickness, fixed_image_size=-1, random_cursor=True,
init_cursor_on_undrawn_pixel=False, init_cursor_num=1):
if len(self.memory_sketch_data_batch) >= memory_idx + 1:
sketch_data_batch = self.memory_sketch_data_batch[memory_idx]
sketch_data_batch = np.expand_dims(sketch_data_batch,
axis=0) # (1, image_size, image_size), [0.0-strokes, 1.0-BG]
image_size_rand = sketch_data_batch.shape[1]
else:
if fixed_image_size == -1:
image_size_rand = random.randint(self.image_size_small, self.image_size_large)
else:
image_size_rand = fixed_image_size
multi_obj_image = self.get_multi_object_image(image_size_rand, vary_thickness) # [0.0-strokes, 1.0-BG]
self.memory_sketch_data_batch.append(multi_obj_image)
sketch_data_batch = np.expand_dims(multi_obj_image,
axis=0) # (1, image_size, image_size), [0.0-strokes, 1.0-BG]
return None, sketch_data_batch, \
self.gen_init_cursors(sketch_data_batch, random_cursor, init_cursor_on_undrawn_pixel, init_cursor_num), \
image_size_rand
def get_batch_multi_res(self, loop_num, vary_thickness, random_cursor=True,
init_cursor_on_undrawn_pixel=False, init_cursor_num=1):
sketch_data_batch = []
init_cursors_batch = []
image_size_batch = []
batch_size_per_loop = self.batch_size // loop_num
for loop_i in range(loop_num):
image_size_rand = random.randint(self.image_size_small, self.image_size_large)
sketch_data_sub_batch = []
for batch_i in range(batch_size_per_loop):
multi_obj_image = self.get_multi_object_image(image_size_rand, vary_thickness) # [0.0-strokes, 1.0-BG]
sketch_data_sub_batch.append(multi_obj_image)
sketch_data_sub_batch = np.stack(sketch_data_sub_batch,
axis=0) # (N, image_size, image_size), [0.0-strokes, 1.0-BG]
init_cursors_sub_batch = self.gen_init_cursors(sketch_data_sub_batch, random_cursor,
init_cursor_on_undrawn_pixel, init_cursor_num)
sketch_data_batch.append(sketch_data_sub_batch)
init_cursors_batch.append(init_cursors_sub_batch)
image_size_batch.append(image_size_rand)
return None, \
sketch_data_batch, \
init_cursors_batch, \
image_size_batch
def gen_stroke_images(self, stroke3_list, image_size, stroke_width):
"""
:param stroke3_list: list of (batch_size,), each with (N_points, 3)
:param image_size:
:return:
"""
gt_image_array = self.rasterizor.raster_func(stroke3_list, image_size, stroke_width=stroke_width,
is_bin=self.is_bin, version='v2')
gt_image_array = np.stack(gt_image_array, axis=0)
gt_image_array = 1.0 - gt_image_array # (batch_size, image_size, image_size), [0.0-strokes, 1.0-BG]
return gt_image_array
def crop_patch(self, image, center, image_size, crop_size):
x0 = center[0] - crop_size // 2
x1 = x0 + crop_size
y0 = center[1] - crop_size // 2
y1 = y0 + crop_size
x0 = max(0, min(x0, image_size))
y0 = max(0, min(y0, image_size))
x1 = max(0, min(x1, image_size))
y1 = max(0, min(y1, image_size))
patch = image[y0:y1, x0:x1]
return patch
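crop_patch clamps all four box edges to the image bounds, so a cursor near a border yields a smaller (possibly empty) patch instead of an index error. A standalone illustration of the same clamping arithmetic (the name `clamped_crop` is illustrative):

```python
import numpy as np

def clamped_crop(image, center, crop_size):
    # center is (x, y); both corners of the crop box are clamped into
    # [0, image_size], mirroring crop_patch() above.
    image_size = image.shape[0]
    x0 = max(0, min(center[0] - crop_size // 2, image_size))
    x1 = max(0, min(center[0] - crop_size // 2 + crop_size, image_size))
    y0 = max(0, min(center[1] - crop_size // 2, image_size))
    y1 = max(0, min(center[1] - crop_size // 2 + crop_size, image_size))
    return image[y0:y1, x0:x1]
```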
def gen_init_cursor_single(self, sketch_image, init_cursor_on_undrawn_pixel, misalign_size=3):
# sketch_image: [0.0-stroke, 1.0-BG]
image_size = sketch_image.shape[0]
if np.sum(1.0 - sketch_image) == 0:
# Blank image: fall back to a zero (top-left) integer cursor.
center = np.zeros((2), dtype=np.int32)
return center
else:
while True:
center = np.random.randint(0, image_size, size=(2)) # (2), in large size
patch = 1.0 - self.crop_patch(sketch_image, center, image_size, self.raster_size)
if np.sum(patch) != 0:
if not init_cursor_on_undrawn_pixel:
return center.astype(np.float32) / float(image_size) # (2), in size [0.0, 1.0)
else:
center_patch = 1.0 - self.crop_patch(sketch_image, center, image_size, misalign_size)
if np.sum(center_patch) != 0:
return center.astype(np.float32) / float(image_size) # (2), in size [0.0, 1.0)
def gen_init_cursors(self, sketch_data, random_pos=True, init_cursor_on_undrawn_pixel=False, init_cursor_num=1):
init_cursor_batch_list = []
for cursor_i in range(init_cursor_num):
if random_pos:
init_cursor_batch = []
for i in range(len(sketch_data)):
sketch_image = sketch_data[i].copy().astype(np.float32) # [0.0-stroke, 1.0-BG]
center = self.gen_init_cursor_single(sketch_image, init_cursor_on_undrawn_pixel)
init_cursor_batch.append(center)
init_cursor_batch = np.stack(init_cursor_batch, axis=0) # (N, 2)
else:
raise NotImplementedError('non-random initial cursor placement is not implemented')
init_cursor_batch_list.append(init_cursor_batch)
if init_cursor_num == 1:
init_cursor_batch = init_cursor_batch_list[0]
init_cursor_batch = np.expand_dims(init_cursor_batch, axis=1).astype(np.float32) # (N, 1, 2)
else:
init_cursor_batch = np.stack(init_cursor_batch_list, axis=1) # (N, init_cursor_num, 2)
init_cursor_batch = np.expand_dims(init_cursor_batch, axis=2).astype(
np.float32) # (N, init_cursor_num, 1, 2)
return init_cursor_batch
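gen_init_cursors returns normalized cursor coordinates in [0.0, 1.0) with shape (N, 1, 2) for a single cursor, or (N, init_cursor_num, 1, 2) for multiple. A shape-only sketch of that batching logic (the name `batch_cursors` is illustrative):

```python
import numpy as np

def batch_cursors(centers_per_cursor):
    # centers_per_cursor: list of length init_cursor_num, each an (N, 2) array.
    if len(centers_per_cursor) == 1:
        # Single cursor: (N, 2) -> (N, 1, 2)
        return np.expand_dims(centers_per_cursor[0], axis=1).astype(np.float32)
    stacked = np.stack(centers_per_cursor, axis=1)  # (N, init_cursor_num, 2)
    # Multiple cursors: (N, init_cursor_num, 2) -> (N, init_cursor_num, 1, 2)
    return np.expand_dims(stacked, axis=2).astype(np.float32)
```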
def load_dataset_multi_object(dataset_base_dir, model_params):
train_stroke3_data = []
val_stroke3_data = []
if model_params.data_set == 'clean_line_drawings':
def load_qd_npz_data(npz_path):
data = np.load(npz_path, encoding='latin1', allow_pickle=True)
selected_strokes3 = data['stroke3'] # (N_sketches,), each with (N_points, 3)
selected_strokes3 = selected_strokes3.tolist()
return selected_strokes3
base_dir_clean = 'QuickDraw-clean'
cates = ['airplane', 'bus', 'car', 'sailboat', 'bird', 'cat', 'dog',
# 'rabbit',
'tree', 'flower',
# 'circle', 'line',
'zigzag'
]
for cate in cates:
train_cate_sketch_data_npz_path = os.path.join(dataset_base_dir, base_dir_clean, 'train', cate + '.npz')
val_cate_sketch_data_npz_path = os.path.join(dataset_base_dir, base_dir_clean, 'test', cate + '.npz')
print(train_cate_sketch_data_npz_path)
train_cate_stroke3_data = load_qd_npz_data(
train_cate_sketch_data_npz_path) # list of (N_sketches,), each with (N_points, 3)
val_cate_stroke3_data = load_qd_npz_data(val_cate_sketch_data_npz_path)
train_stroke3_data += train_cate_stroke3_data
val_stroke3_data += val_cate_stroke3_data
else:
raise ValueError('Unknown data type: {}'.format(model_params.data_set))
print('Loaded {}/{} from {}'.format(len(train_stroke3_data), len(val_stroke3_data), model_params.data_set))
print('model_params.max_seq_len %i.' % model_params.max_seq_len)
eval_sample_model_params = copy_hparams(model_params)
eval_sample_model_params.use_input_dropout = 0
eval_sample_model_params.use_recurrent_dropout = 0
eval_sample_model_params.use_output_dropout = 0
eval_sample_model_params.batch_size = 1 # only sample one at a time
eval_sample_model_params.model_mode = 'eval_sample'
train_set = GeneralMultiObjectDataLoader(train_stroke3_data,
model_params.batch_size, model_params.raster_size,
model_params.image_size_small, model_params.image_size_large,
model_params.bin_gt, is_train=True)
val_set = GeneralMultiObjectDataLoader(val_stroke3_data,
eval_sample_model_params.batch_size, eval_sample_model_params.raster_size,
eval_sample_model_params.image_size_small,
eval_sample_model_params.image_size_large,
eval_sample_model_params.bin_gt, is_train=False)
result = [train_set, val_set, model_params, eval_sample_model_params]
return result
class GeneralDataLoaderMultiObjectRough(object):
def __init__(self,
photo_data,
sketch_data,
texture_data,
shadow_data,
batch_size,
raster_size,
image_size_small,
image_size_large,
is_train):
self.batch_size = batch_size # minibatch size
self.raster_size = raster_size
self.image_size_small = image_size_small
self.image_size_large = image_size_large
self.is_train = is_train
assert photo_data is not None
assert len(photo_data) == len(sketch_data)
# self.num_batches = len(sketch_data) // self.batch_size
self.batch_idx = -1
print('batch_size', batch_size)
assert type(photo_data) is list
assert type(sketch_data) is list
assert type(texture_data) is list and len(texture_data) > 0
assert type(shadow_data) is list and len(shadow_data) > 0
self.photo_data = photo_data
self.sketch_data = sketch_data
self.texture_data = texture_data # list of (H, W, 3), [0, 255], uint8
self.shadow_data = shadow_data # list of (H, W), [0, 255], uint8
self.memory_photo_data_batch = []
self.memory_sketch_data_batch = []
def rough_augmentation(self, raw_photo, texture_prob=0.20, noise_prob=0.15, shadow_prob=0.20):
# raw_photo: (H, W), [0.0-stroke, 1.0-BG]
aug_photo_rgb = np.stack([raw_photo for _ in range(3)], axis=-1)
def texture_generation(texture_list, image_shape):
while True:
random_texture_id = random.randint(0, len(texture_list) - 1)
texture_large = texture_list[random_texture_id]
t_w, t_h = texture_large.shape[1], texture_large.shape[0]
i_w, i_h = image_shape[1], image_shape[0]
if t_h >= i_h and t_w >= i_w:
texture_large = np.copy(texture_large).astype(np.float32)
crop_y = random.randint(0, t_h - i_h)
crop_x = random.randint(0, t_w - i_w)
crop_texture = texture_large[crop_y: crop_y + i_h, crop_x: crop_x + i_w, :]
return crop_texture
def texture_change(rough_img_, all_textures):
# rough_img_: (H, W, 3), [0.0-stroke, 1.0-BG]
texture_image = texture_generation(all_textures, rough_img_.shape) # (h, w, 3)
texture_image /= 255.0
rand_b = np.random.uniform(1.0, 2.0, size=rough_img_.shape)
textured_img = rough_img_ * (texture_image / rand_b + (rand_b - 1.0) / rand_b) # [0.0, 1.0]
return textured_img
def noise_change(rough_img_, noise_scale=25):
# rough_img_: (H, W, 3), [0.0, 1.0]
rough_img_255 = rough_img_ * 255.0
rand_noise = np.random.uniform(-1.0, 1.0, size=rough_img_255.shape) * noise_scale
# rand_noise = np.random.normal(size=rough_img.shape) * noise_scale
noise_img = rough_img_255 + rand_noise
noise_img = np.clip(noise_img, 0.0, 255.0)
noise_img /= 255.0
return noise_img
def shadow_change(rough_img_, all_shadows):
# rough_img_: (H, W, 3), [0.0, 1.0]
rough_img_255 = rough_img_ * 255.0
shadow_i = random.randint(0, len(all_shadows) - 1)
shadow_full = all_shadows[shadow_i] # (H, W), [0, 255]
shadow_img_size = shadow_full.shape[0]
while True:
position = np.random.randint(-shadow_img_size // 2, shadow_img_size // 2, (2))
if abs(position[0]) > (shadow_img_size // 8) and abs(position[1]) > (shadow_img_size // 8):
break
position += (shadow_img_size // 2)
crop_up = shadow_img_size - position[0]
crop_left = shadow_img_size - position[1]
shadow_image_large = shadow_full[crop_up: crop_up + shadow_img_size, crop_left: crop_left + shadow_img_size]
shadow_bg = Image.fromarray(shadow_image_large, 'L')
shadow_bg = shadow_bg.resize(size=(rough_img_255.shape[1], rough_img_255.shape[0]), resample=Image.BILINEAR)
shadow_bg = np.array(shadow_bg, dtype=np.float32) / 255.0 # [0.0-shadow, 1.0-BG]
shadow_bg = np.stack([shadow_bg for _ in range(3)], axis=-1)
shadow_img = rough_img_255 * shadow_bg
shadow_img /= 255.0
return shadow_img
if random.random() <= texture_prob:
aug_photo_rgb = texture_change(aug_photo_rgb, self.texture_data) # (H, W, 3), [0.0, 1.0]
if random.random() <= noise_prob:
aug_photo_rgb = noise_change(aug_photo_rgb) # (H, W, 3), [0.0, 1.0]
if random.random() <= shadow_prob:
aug_photo_rgb = shadow_change(aug_photo_rgb, self.shadow_data) # (H, W, 3), [0.0, 1.0]
return aug_photo_rgb
def image_interpolation(self, photo, sketch, photo_prob):
interp_photo = photo * photo_prob + sketch * (1.0 - photo_prob)
interp_photo = np.clip(interp_photo, 0.0, 1.0)
return interp_photo
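image_interpolation blends the rough photo toward the clean sketch as a convex combination: photo_prob=1.0 keeps the photo unchanged, photo_prob=0.0 replaces it with the sketch. A one-function numpy sketch of the same blend (the name `blend` is illustrative):

```python
import numpy as np

def blend(photo, sketch, photo_prob):
    # Convex combination, clipped to the [0, 1] image range
    # (same arithmetic as image_interpolation() above).
    return np.clip(photo * photo_prob + sketch * (1.0 - photo_prob), 0.0, 1.0)
```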
def get_batch_from_memory(self, memory_idx, interpolate_type, fixed_image_size=-1, random_cursor=True,
photo_prob=1.0, init_cursor_num=1):
if len(self.memory_sketch_data_batch) >= memory_idx + 1:
photo_data_batch = self.memory_photo_data_batch[memory_idx]
sketch_data_batch = self.memory_sketch_data_batch[memory_idx]
image_size_rand = sketch_data_batch.shape[1]
else:
if fixed_image_size == -1:
image_size_rand = random.randint(self.image_size_small, self.image_size_large)
else:
image_size_rand = fixed_image_size
# photo_prob = 0.0 if photo_prob_type == 'zero' else 1.0
photo_data_batch, sketch_data_batch = self.select_sketch(
image_size_rand) # both: (H, W), [0.0-stroke, 1.0-BG]
photo_data_batch = self.rough_augmentation(photo_data_batch) # (H, W, 3), [0.0-stroke, 1.0-BG]
self.memory_photo_data_batch.append(photo_data_batch)
self.memory_sketch_data_batch.append(sketch_data_batch)
if interpolate_type == 'prob':
if random.random() >= photo_prob:
photo_data_batch = np.stack([sketch_data_batch for _ in range(3)],
axis=-1) # (H, W, 3), [0.0-stroke, 1.0-BG]
elif interpolate_type == 'image':
photo_data_batch = self.image_interpolation(
photo_data_batch, np.stack([sketch_data_batch for _ in range(3)], axis=-1), photo_prob)
else:
raise ValueError('Unknown interpolate_type: {}'.format(interpolate_type))
photo_data_batch = np.expand_dims(photo_data_batch, axis=0) # (1, image_size, image_size, 3)
sketch_data_batch = np.expand_dims(sketch_data_batch,
axis=0) # (1, image_size, image_size), [0.0-strokes, 1.0-BG]
return photo_data_batch, sketch_data_batch, \
self.gen_init_cursors(sketch_data_batch, random_cursor, init_cursor_num), image_size_rand
def select_sketch(self, image_size_rand):
resolution_idx = image_size_rand - self.image_size_small  # data lists are indexed by resolution offset from image_size_small
img_idx = random.randint(0, len(self.sketch_data[resolution_idx]) - 1)
assert img_idx != -1
selected_sketch = self.sketch_data[resolution_idx][img_idx] # [0-stroke, 255-BG], uint8
selected_photo = self.photo_data[resolution_idx][img_idx] # [0-stroke, 255-BG], uint8
rst_sketch_image = selected_sketch.astype(np.float32) / 255.0 # [0.0-stroke, 1.0-BG]
rst_photo_image = selected_photo.astype(np.float32) / 255.0 # [0.0-stroke, 1.0-BG]
return rst_photo_image, rst_sketch_image
def get_batch_multi_res(self, loop_num, interpolate_type, random_cursor=True, init_cursor_num=1, photo_prob=1.0):
photo_data_batch = []
sketch_data_batch = []
init_cursors_batch = []
image_size_batch = []
batch_size_per_loop = self.batch_size // loop_num
for loop_i in range(loop_num):
image_size_rand = random.randint(self.image_size_small, self.image_size_large)
photo_data_sub_batch = []
sketch_data_sub_batch = []
for img_i in range(batch_size_per_loop):
photo_patch, sketch_patch = self.select_sketch(image_size_rand) # both: (H, W), [0.0-stroke, 1.0-BG]
photo_patch = self.rough_augmentation(photo_patch) # (H, W, 3), [0.0-stroke, 1.0-BG]
if interpolate_type == 'prob':
if random.random() >= photo_prob:
photo_patch = np.stack([sketch_patch for _ in range(3)],
axis=-1) # (H, W, 3), [0.0-stroke, 1.0-BG]
elif interpolate_type == 'image':
photo_patch = self.image_interpolation(
photo_patch, np.stack([sketch_patch for _ in range(3)], axis=-1), photo_prob)
else:
raise ValueError('Unknown interpolate_type: {}'.format(interpolate_type))
photo_data_sub_batch.append(photo_patch)
sketch_data_sub_batch.append(sketch_patch)
photo_data_sub_batch = np.stack(photo_data_sub_batch,
axis=0) # (N, image_size, image_size, 3), [0.0-strokes, 1.0-BG]
sketch_data_sub_batch = np.stack(sketch_data_sub_batch,
axis=0) # (N, image_size, image_size), [0.0-strokes, 1.0-BG]
init_cursors_sub_batch = self.gen_init_cursors(sketch_data_sub_batch, random_cursor, init_cursor_num)
photo_data_batch.append(photo_data_sub_batch)
sketch_data_batch.append(sketch_data_sub_batch)
init_cursors_batch.append(init_cursors_sub_batch)
image_size_batch.append(image_size_rand)
return photo_data_batch, sketch_data_batch, init_cursors_batch, image_size_batch
def crop_patch(self, image, center, image_size, crop_size):
x0 = center[0] - crop_size // 2
x1 = x0 + crop_size
y0 = center[1] - crop_size // 2
y1 = y0 + crop_size
x0 = max(0, min(x0, image_size))
y0 = max(0, min(y0, image_size))
x1 = max(0, min(x1, image_size))
y1 = max(0, min(y1, image_size))
patch = image[y0:y1, x0:x1]
return patch
def gen_init_cursor_single(self, sketch_image):
# sketch_image: [0.0-stroke, 1.0-BG]
image_size = sketch_image.shape[0]
if np.sum(1.0 - sketch_image) == 0:
# Blank image: fall back to a zero (top-left) integer cursor.
center = np.zeros((2), dtype=np.int32)
return center
else:
while True:
center = np.random.randint(0, image_size, size=(2)) # (2), in large size
patch = 1.0 - self.crop_patch(sketch_image, center, image_size, self.raster_size)
if np.sum(patch) != 0:
return center.astype(np.float32) / float(image_size) # (2), in size [0.0, 1.0)
def gen_init_cursors(self, sketch_data, random_pos=True, init_cursor_num=1):
init_cursor_batch_list = []
for cursor_i in range(init_cursor_num):
if random_pos:
init_cursor_batch = []
for i in range(len(sketch_data)):
sketch_image = sketch_data[i].copy().astype(np.float32) # [0.0-stroke, 1.0-BG]
center = self.gen_init_cursor_single(sketch_image)
init_cursor_batch.append(center)
init_cursor_batch = np.stack(init_cursor_batch, axis=0) # (N, 2)
else:
raise NotImplementedError('non-random initial cursor placement is not implemented')
init_cursor_batch_list.append(init_cursor_batch)
if init_cursor_num == 1:
init_cursor_batch = init_cursor_batch_list[0]
init_cursor_batch = np.expand_dims(init_cursor_batch, axis=1).astype(np.float32) # (N, 1, 2)
else:
init_cursor_batch = np.stack(init_cursor_batch_list, axis=1) # (N, init_cursor_num, 2)
init_cursor_batch = np.expand_dims(init_cursor_batch, axis=2).astype(
np.float32) # (N, init_cursor_num, 1, 2)
return init_cursor_batch
def load_dataset_multi_object_rough(dataset_base_dir, model_params):
train_photo_data = []
train_sketch_data = []
val_photo_data = []
val_sketch_data = []
texture_data = []
shadow_data = []
if model_params.data_set == 'rough_sketches':
base_dir_rough = 'QuickDraw-rough'
def load_sketch_data(mat_path):
sketch_data_mat = scipy.io.loadmat(mat_path)
sketch_data = sketch_data_mat['sketch_array']
sketch_data = np.array(sketch_data, dtype=np.uint8) # (N, resolution, resolution), [0-strokes, 255-BG]
return sketch_data
def load_photo_data(mat_path):
photo_data_mat = scipy.io.loadmat(mat_path)
photo_data = photo_data_mat['image_array']
photo_data = np.array(photo_data, dtype=np.uint8) # (N, resolution, resolution), [0-strokes, 255-BG]
return photo_data
def load_normal_data(img_path):
assert '.png' in img_path or '.jpg' in img_path
img = Image.open(img_path).convert('RGB')
img = np.array(img, dtype=np.uint8) # (H, W, 3), [0-stroke, 255-BG], uint8
return img
## Texture
texture_base = os.path.join(dataset_base_dir, base_dir_rough, 'texture')
all_texture = os.listdir(texture_base)
all_texture.sort()
for file_name in all_texture:
texture_path = os.path.join(texture_base, file_name)
texture_uint8 = load_normal_data(texture_path)
texture_data.append(texture_uint8)
## shadow
def process_angle(img, temp_size):
padded_img = img.copy()
padded_img[0, 0:temp_size] -= 1
padded_img[0, -(temp_size + 1):-1] -= 1
padded_img[-1, 0:temp_size] -= 1
padded_img[-1, -(temp_size + 1):-1] -= 1
padded_img[0:temp_size, 0] -= 1
padded_img[0:temp_size, -1] -= 1
padded_img[-(temp_size + 1):-1, 0] -= 1
padded_img[-(temp_size + 1):-1, -1] -= 1
return padded_img
def pad_img(ori_img, pad_value):
padded_img = np.pad(ori_img, 1, constant_values=pad_value)
img_h, img_w = padded_img.shape[0], padded_img.shape[1]
temp_size = img_h // 3
padded_img = process_angle(padded_img, temp_size)
temp_size = img_h // 9
padded_img = process_angle(padded_img, temp_size)
temp_size = img_h // 15
padded_img = process_angle(padded_img, temp_size)
temp_size = img_h // 21
padded_img = process_angle(padded_img, temp_size)
padded_img = np.clip(padded_img, 0, 255)
return padded_img
def shadow_generation(transparency, shadow_img_size=1024):
deepest_value = int(255 * transparency)
center_patch = np.zeros((shadow_img_size // 2, shadow_img_size // 2), dtype=np.uint8)
center_patch.fill(255)
pad_gap = shadow_img_size // 2
shadow_patch = center_patch.copy()
for i in range(pad_gap):
curr_pad_value = 255.0 - float(255.0 - deepest_value) / float(pad_gap) * (i + 1)
shadow_patch = pad_img(shadow_patch, pad_value=curr_pad_value)
for i in range(shadow_img_size // 4):
shadow_patch = pad_img(shadow_patch, pad_value=deepest_value)
assert shadow_patch.shape[0] == shadow_img_size * 2, shadow_patch.shape[0]
return shadow_patch
for transparency_ in range(90, 95 + 1):
transparency = transparency_ / 100.0
shadow_full = shadow_generation(transparency)
shadow_data.append(shadow_full)
splits = ['train', 'test']
resolutions = [model_params.image_size_small, model_params.image_size_large]
for resolution in range(resolutions[0], resolutions[1] + 1):
for split in splits:
sketch_mat1_path = os.path.join(dataset_base_dir, base_dir_rough, 'model_pencil1',
'sketch', split, 'res_' + str(resolution) + '.mat')
photo_mat1_path = os.path.join(dataset_base_dir, base_dir_rough, 'model_pencil1',
'photo', split, 'res_' + str(resolution) + '.mat')
sketch_data1_uint8 = load_sketch_data(
sketch_mat1_path) # (N, resolution, resolution), [0-strokes, 255-BG]
photo_data1_uint8 = load_photo_data(photo_mat1_path) # (N, resolution, resolution), [0-strokes, 255-BG]
sketch_mat2_path = os.path.join(dataset_base_dir, base_dir_rough, 'model_pencil2',
'sketch', split, 'res_' + str(resolution) + '.mat')
photo_mat2_path = os.path.join(dataset_base_dir, base_dir_rough, 'model_pencil2',
'photo', split, 'res_' + str(resolution) + '.mat')
sketch_data2_uint8 = load_sketch_data(
sketch_mat2_path) # (N, resolution, resolution), [0-strokes, 255-BG]
photo_data2_uint8 = load_photo_data(photo_mat2_path) # (N, resolution, resolution), [0-strokes, 255-BG]
sketch_data_uint8 = np.concatenate([sketch_data1_uint8, sketch_data2_uint8],
axis=0) # (N, resolution, resolution), [0-strokes, 255-BG]
photo_data_uint8 = np.concatenate([photo_data1_uint8, photo_data2_uint8],
axis=0) # (N, resolution, resolution), [0-strokes, 255-BG]
if split == 'train':
train_photo_data.append(photo_data_uint8)
train_sketch_data.append(sketch_data_uint8)
else:
val_photo_data.append(photo_data_uint8)
val_sketch_data.append(sketch_data_uint8)
assert len(train_sketch_data) == len(train_photo_data)
assert len(val_sketch_data) == len(val_photo_data)
else:
raise ValueError('Unknown data type: {}'.format(model_params.data_set))
print('Loaded {}/{} from {}'.format(len(train_sketch_data), len(val_sketch_data), model_params.data_set))
print('model_params.max_seq_len %i.' % model_params.max_seq_len)
eval_sample_model_params = copy_hparams(model_params)
eval_sample_model_params.use_input_dropout = 0
eval_sample_model_params.use_recurrent_dropout = 0
eval_sample_model_params.use_output_dropout = 0
eval_sample_model_params.batch_size = 1 # only sample one at a time
eval_sample_model_params.model_mode = 'eval_sample'
train_set = GeneralDataLoaderMultiObjectRough(train_photo_data, train_sketch_data,
texture_data, shadow_data,
model_params.batch_size, model_params.raster_size,
model_params.image_size_small, model_params.image_size_large,
is_train=True)
val_set = GeneralDataLoaderMultiObjectRough(val_photo_data, val_sketch_data,
texture_data, shadow_data,
eval_sample_model_params.batch_size,
eval_sample_model_params.raster_size,
eval_sample_model_params.image_size_small,
eval_sample_model_params.image_size_large,
is_train=False)
result = [
train_set, val_set, model_params, eval_sample_model_params
]
return result
class GeneralDataLoaderNormalImageLinear(object):
def __init__(self,
photo_data,
sketch_data,
sketch_shape,
batch_size,
raster_size,
image_size_small,
image_size_large,
random_image_size,
flip_prob,
rotate_prob,
is_train):
self.batch_size = batch_size # minibatch size
self.raster_size = raster_size
self.image_size_small = image_size_small
self.image_size_large = image_size_large
self.random_image_size = random_image_size
self.is_train = is_train
assert photo_data is not None
assert len(photo_data) == len(sketch_data)
self.num_batches = len(sketch_data) // self.batch_size
self.batch_idx = -1
print('batch_size', batch_size, ', num_batches', self.num_batches)
self.flip_prob = flip_prob
self.rotate_prob = rotate_prob
assert type(photo_data) is list
assert type(sketch_data) is list
self.photo_data = photo_data
self.sketch_data = sketch_data
self.sketch_shape = sketch_shape
def get_batch_from_memory(self, memory_idx, interpolate_type, fixed_image_size=-1, random_cursor=True,
photo_prob=1.0,
init_cursor_num=1):
if self.random_image_size:
image_size_rand = fixed_image_size
else:
image_size_rand = self.image_size_large
photo_data_batch, sketch_data_batch = self.select_sketch_and_crop(
image_size_rand, data_idx=memory_idx, photo_prob=photo_prob,
interpolate_type=interpolate_type) # sketch_patch: [0.0-stroke, 1.0-BG]
photo_data_batch = np.expand_dims(photo_data_batch, axis=0) # (1, image_size, image_size, 3)
sketch_data_batch = np.expand_dims(sketch_data_batch,
axis=0) # (1, image_size, image_size), [0.0-strokes, 1.0-BG]
image_size_rand = sketch_data_batch.shape[1]
return photo_data_batch, sketch_data_batch, \
self.gen_init_cursors(sketch_data_batch, random_cursor, init_cursor_num), image_size_rand
def crop_and_augment(self, photo, sketch, shape, crop_size, rotate_angle, stroke_cover=0.01):
# photo, sketch: PIL Images, [0-stroke, 255-BG]
def angle_convert(angle):
return angle / 180.0 * math.pi
img_h, img_w = shape[0], shape[1]
if self.is_train:
crop_up = random.randint(0, img_h - crop_size)
crop_left = random.randint(0, img_w - crop_size)
else:
crop_up = (img_h - crop_size) // 2
crop_left = (img_w - crop_size) // 2
assert crop_up >= 0
assert crop_left >= 0
crop_box = (crop_left, crop_up, crop_left + crop_size, crop_up + crop_size)
rst_sketch_image = sketch.crop(crop_box)
rst_photo_image = photo.crop(crop_box)
if random.random() <= self.flip_prob and self.is_train:
rst_sketch_image = rst_sketch_image.transpose(Image.FLIP_LEFT_RIGHT)
rst_photo_image = rst_photo_image.transpose(Image.FLIP_LEFT_RIGHT)
if rotate_angle != 0 and self.is_train:
rst_sketch_image = rst_sketch_image.rotate(rotate_angle, resample=Image.BILINEAR)
rst_photo_image = rst_photo_image.rotate(rotate_angle, resample=Image.BILINEAR)
rst_sketch_image = np.array(rst_sketch_image, dtype=np.uint8)
rst_photo_image = np.array(rst_photo_image, dtype=np.uint8)
center = rst_photo_image.shape[0] // 2
new_dim = float(crop_size) / (
math.sin(angle_convert(abs(rotate_angle))) + math.cos(angle_convert(abs(rotate_angle))))
new_dim = int(round(new_dim))
start_pos = center - new_dim // 2
end_pos = start_pos + new_dim
rst_sketch_image = rst_sketch_image[start_pos:end_pos, start_pos:end_pos, :]
rst_photo_image = rst_photo_image[start_pos:end_pos, start_pos:end_pos, :]
rst_sketch_image = np.array(rst_sketch_image, dtype=np.float32) / 255.0 # [0.0-stroke, 1.0-BG]
rst_sketch_image = rst_sketch_image[:, :, 0]
rst_photo_image = np.array(rst_photo_image, dtype=np.float32) / 255.0 # [0.0-stroke, 1.0-BG]
stroke_ratio = np.mean(1.0 - rst_sketch_image)
valid = stroke_ratio >= stroke_cover  # reject crops with too few stroke pixels
return rst_photo_image, rst_sketch_image, valid
def image_interpolation(self, photo, sketch, photo_prob):
interp_photo = photo * photo_prob + sketch * (1.0 - photo_prob)
interp_photo = np.clip(interp_photo, 0.0, 1.0)
return interp_photo
def select_sketch_and_crop(self, image_size_rand, interpolate_type, rotate_angle=0, photo_prob=1.0,
data_idx=-1, trial_times=10):
if self.is_train:
while True:
rand_img_idx = random.randint(0, len(self.sketch_data) - 1)
selected_sketch_shape = self.sketch_shape[rand_img_idx]
if selected_sketch_shape[0] >= image_size_rand and selected_sketch_shape[1] >= image_size_rand:
img_idx = rand_img_idx
break
else:
assert data_idx != -1
img_idx = data_idx
assert img_idx != -1
selected_sketch = self.sketch_data[img_idx]
selected_photo = self.photo_data[img_idx]
selected_shape = self.sketch_shape[img_idx]
assert interpolate_type in ['prob', 'image']
if interpolate_type == 'prob' and random.random() >= photo_prob:
selected_photo = self.sketch_data[img_idx]
for trial_i in range(trial_times):
cropped_photo, cropped_sketch, valid = \
self.crop_and_augment(selected_photo, selected_sketch, selected_shape, image_size_rand, rotate_angle)
# cropped_photo, cropped_sketch: [0.0-stroke, 1.0-BG]
if valid or trial_i == trial_times - 1:
if interpolate_type == 'image':
cropped_photo = self.image_interpolation(cropped_photo,
np.stack([cropped_sketch for _ in range(3)], axis=-1),
photo_prob)
return cropped_photo, cropped_sketch
def get_batch_multi_res(self, loop_num, interpolate_type, random_cursor=True, init_cursor_num=1, photo_prob=1.0):
photo_data_batch = []
sketch_data_batch = []
init_cursors_batch = []
image_size_batch = []
batch_size_per_loop = self.batch_size // loop_num
for loop_i in range(loop_num):
if self.random_image_size:
image_size_rand = random.randint(self.image_size_small, self.image_size_large)
else:
image_size_rand = self.image_size_large
rotate_angle = 0
if random.random() <= self.rotate_prob:
rotate_angle = random.randint(-45, 45)
photo_data_sub_batch = []
sketch_data_sub_batch = []
for img_i in range(batch_size_per_loop):
photo_patch, sketch_patch = \
self.select_sketch_and_crop(image_size_rand, rotate_angle=rotate_angle, photo_prob=photo_prob,
interpolate_type=interpolate_type) # sketch_patch: [0.0-stroke, 1.0-BG]
photo_data_sub_batch.append(photo_patch)
sketch_data_sub_batch.append(sketch_patch)
photo_data_sub_batch = np.stack(photo_data_sub_batch,
axis=0) # (N, image_size, image_size, 3), [0.0-strokes, 1.0-BG]
sketch_data_sub_batch = np.stack(sketch_data_sub_batch,
axis=0) # (N, image_size, image_size), [0.0-strokes, 1.0-BG]
init_cursors_sub_batch = self.gen_init_cursors(sketch_data_sub_batch, random_cursor, init_cursor_num)
photo_data_batch.append(photo_data_sub_batch)
sketch_data_batch.append(sketch_data_sub_batch)
init_cursors_batch.append(init_cursors_sub_batch)
image_size_rand = photo_data_sub_batch.shape[1]
image_size_batch.append(image_size_rand)
return photo_data_batch, sketch_data_batch, init_cursors_batch, image_size_batch
def crop_patch(self, image, center, image_size, crop_size):
x0 = center[0] - crop_size // 2
x1 = x0 + crop_size
y0 = center[1] - crop_size // 2
y1 = y0 + crop_size
x0 = max(0, min(x0, image_size))
y0 = max(0, min(y0, image_size))
x1 = max(0, min(x1, image_size))
y1 = max(0, min(y1, image_size))
patch = image[y0:y1, x0:x1]
return patch
def gen_init_cursor_single(self, sketch_image):
# sketch_image: [0.0-stroke, 1.0-BG]
image_size = sketch_image.shape[0]
if np.sum(1.0 - sketch_image) == 0:
center = np.zeros((2), dtype=np.float32)  # blank image: return normalized origin
return center
else:
while True:
center = np.random.randint(0, image_size, size=(2)) # (2), in large size
patch = 1.0 - self.crop_patch(sketch_image, center, image_size, self.raster_size)
if np.sum(patch) != 0:
return center.astype(np.float32) / float(image_size) # (2), in size [0.0, 1.0)
def gen_init_cursors(self, sketch_data, random_pos=True, init_cursor_num=1):
init_cursor_batch_list = []
for cursor_i in range(init_cursor_num):
if random_pos:
init_cursor_batch = []
for i in range(len(sketch_data)):
sketch_image = sketch_data[i].copy().astype(np.float32) # [0.0-stroke, 1.0-BG]
center = self.gen_init_cursor_single(sketch_image)
init_cursor_batch.append(center)
init_cursor_batch = np.stack(init_cursor_batch, axis=0) # (N, 2)
else:
raise Exception('Not finished')
init_cursor_batch_list.append(init_cursor_batch)
if init_cursor_num == 1:
init_cursor_batch = init_cursor_batch_list[0]
init_cursor_batch = np.expand_dims(init_cursor_batch, axis=1).astype(np.float32) # (N, 1, 2)
else:
init_cursor_batch = np.stack(init_cursor_batch_list, axis=1) # (N, init_cursor_num, 2)
init_cursor_batch = np.expand_dims(init_cursor_batch, axis=2).astype(
np.float32) # (N, init_cursor_num, 1, 2)
return init_cursor_batch
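The cursor initialization above rejection-samples positions until the surrounding patch contains at least one stroke pixel. A standalone NumPy sketch of the same idea (function and variable names here are hypothetical, not the repo's API):

```python
import numpy as np

def sample_init_cursor(sketch_image, patch_size, rng):
    """Rejection-sample a cursor position whose surrounding patch contains ink.

    sketch_image: (H, W) float array, 0.0 = stroke, 1.0 = background.
    Returns a normalized position in [0, 1), or (0, 0) for a blank image.
    """
    size = sketch_image.shape[0]
    if np.sum(1.0 - sketch_image) == 0:          # completely blank: nothing to draw
        return np.zeros(2, dtype=np.float32)
    while True:
        center = rng.integers(0, size, size=2)
        y0, x0 = np.maximum(center - patch_size // 2, 0)
        patch = 1.0 - sketch_image[y0:y0 + patch_size, x0:x0 + patch_size]
        if np.sum(patch) > 0:                    # patch touches at least one stroke pixel
            return center.astype(np.float32) / size

rng = np.random.default_rng(0)
img = np.ones((64, 64), dtype=np.float32)
img[30:34, 30:34] = 0.0                          # a single small stroke blob
cursor = sample_init_cursor(img, 16, rng)
assert 0.0 <= cursor.min() and cursor.max() < 1.0
```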
def load_dataset_normal_images(dataset_base_dir, model_params):
train_photo_data = []
train_sketch_data = []
train_data_shape = []
val_photo_data = []
val_sketch_data = []
val_data_shape = []
if model_params.data_set == 'faces':
random_training_image_size = False
flip_prob = -0.1
rotate_prob = -0.1
splits = ['train', 'val']
database = os.path.join(dataset_base_dir, 'CelebAMask-faces')
photo_base = os.path.join(database, 'CelebA-HQ-img256')
edge_base = os.path.join(database, 'CelebAMask-HQ-edge256')
train_split_txt_save_path = os.path.join(database, 'train.txt')
val_split_txt_save_path = os.path.join(database, 'val.txt')
celeba_train_txt = np.loadtxt(train_split_txt_save_path, dtype=str)
celeba_val_txt = np.loadtxt(val_split_txt_save_path, dtype=str)
splits_indices_map = {'train': celeba_train_txt, 'val': celeba_val_txt}
for split in splits:
split_indices = splits_indices_map[split]
for i in range(len(split_indices)):
file_idx = split_indices[i]
img_file_path = os.path.join(photo_base, str(file_idx) + '.jpg')
edge_img_path = os.path.join(edge_base, str(file_idx) + '.png')
img_data = Image.open(img_file_path).convert('RGB')
edge_data = Image.open(edge_img_path).convert('RGB')
if split == 'train':
train_photo_data.append(img_data)
train_sketch_data.append(edge_data)
train_data_shape.append((img_data.height, img_data.width))
else: # split == 'val'
val_photo_data.append(img_data)
val_sketch_data.append(edge_data)
val_data_shape.append((img_data.height, img_data.width))
assert len(train_sketch_data) == len(train_data_shape) == len(train_photo_data)
assert len(val_sketch_data) == len(val_data_shape) == len(val_photo_data)
else:
raise Exception('Unknown data type:', model_params.data_set)
print('Loaded {}/{} from {}'.format(len(train_sketch_data), len(val_sketch_data), model_params.data_set))
print('model_params.max_seq_len %i.' % model_params.max_seq_len)
eval_sample_model_params = copy_hparams(model_params)
eval_sample_model_params.use_input_dropout = 0
eval_sample_model_params.use_recurrent_dropout = 0
eval_sample_model_params.use_output_dropout = 0
eval_sample_model_params.batch_size = 1 # only sample one at a time
eval_sample_model_params.model_mode = 'eval_sample'
train_set = GeneralDataLoaderNormalImageLinear(train_photo_data, train_sketch_data, train_data_shape,
model_params.batch_size, model_params.raster_size,
image_size_small=model_params.image_size_small,
image_size_large=model_params.image_size_large,
random_image_size=random_training_image_size,
flip_prob=flip_prob, rotate_prob=rotate_prob,
is_train=True)
val_set = GeneralDataLoaderNormalImageLinear(val_photo_data, val_sketch_data, val_data_shape,
eval_sample_model_params.batch_size,
eval_sample_model_params.raster_size,
image_size_small=eval_sample_model_params.image_size_small,
image_size_large=eval_sample_model_params.image_size_large,
random_image_size=random_training_image_size,
flip_prob=flip_prob, rotate_prob=rotate_prob,
is_train=False)
result = [
train_set, val_set, model_params, eval_sample_model_params
]
return result
def load_dataset_training(dataset_base_dir, model_params):
if model_params.data_set == 'clean_line_drawings':
return load_dataset_multi_object(dataset_base_dir, model_params)
elif model_params.data_set == 'rough_sketches':
return load_dataset_multi_object_rough(dataset_base_dir, model_params)
elif model_params.data_set == 'faces':
return load_dataset_normal_images(dataset_base_dir, model_params)
else:
raise Exception('Unknown data_set', model_params.data_set)
================================================
FILE: docs/assets/font.css
================================================
/* Homepage Font */
/* latin-ext */
@font-face {
font-family: 'Lato';
font-style: normal;
font-weight: 400;
src: local('Lato Regular'), local('Lato-Regular'), url(https://fonts.gstatic.com/s/lato/v16/S6uyw4BMUTPHjxAwXjeu.woff2) format('woff2');
unicode-range: U+0100-024F, U+0259, U+1E00-1EFF, U+2020, U+20A0-20AB, U+20AD-20CF, U+2113, U+2C60-2C7F, U+A720-A7FF;
}
/* latin */
@font-face {
font-family: 'Lato';
font-style: normal;
font-weight: 400;
src: local('Lato Regular'), local('Lato-Regular'), url(https://fonts.gstatic.com/s/lato/v16/S6uyw4BMUTPHjx4wXg.woff2) format('woff2');
unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+2000-206F, U+2074, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD;
}
/* latin-ext */
@font-face {
font-family: 'Lato';
font-style: normal;
font-weight: 700;
src: local('Lato Bold'), local('Lato-Bold'), url(https://fonts.gstatic.com/s/lato/v16/S6u9w4BMUTPHh6UVSwaPGR_p.woff2) format('woff2');
unicode-range: U+0100-024F, U+0259, U+1E00-1EFF, U+2020, U+20A0-20AB, U+20AD-20CF, U+2113, U+2C60-2C7F, U+A720-A7FF;
}
/* latin */
@font-face {
font-family: 'Lato';
font-style: normal;
font-weight: 700;
src: local('Lato Bold'), local('Lato-Bold'), url(https://fonts.gstatic.com/s/lato/v16/S6u9w4BMUTPHh6UVSwiPGQ.woff2) format('woff2');
unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+2000-206F, U+2074, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD;
}
================================================
FILE: docs/assets/style.css
================================================
/* Body */
body {
background: #e3e5e8;
color: #ffffff;
font-family: 'Lato', Verdana, Helvetica, sans-serif;
font-weight: 300;
font-size: 14pt;
}
/* Hyperlinks */
a {text-decoration: none;}
a:link {color: #1772d0;}
a:visited {color: #1772d0;}
a:active {color: red;}
a:hover {color: #f09228;}
/* Pre-formatted Text */
pre {
margin: 5pt 0;
border: 0;
font-size: 12pt;
background: #fcfcfc;
}
/* Project Page Style */
/* Section */
.section {
width: 768pt;
min-height: 100pt;
margin: 15pt auto;
padding: 20pt 30pt;
border: 1pt hidden #000;
text-align: justify;
color: #000000;
background: #ffffff;
}
/* Header (Title and Logo) */
.section .header {
min-height: 80pt;
margin-top: 30pt;
}
.section .header .logo {
width: 80pt;
margin-left: 10pt;
float: left;
}
.section .header .logo img {
width: 80pt;
object-fit: cover;
}
.section .header .title {
margin: 0 120pt;
text-align: center;
font-size: 22pt;
}
/* Author */
.section .author {
margin: 5pt 0;
text-align: center;
font-size: 16pt;
}
/* Institution */
.section .institution {
margin: 5pt 0;
text-align: center;
font-size: 16pt;
}
/* Hyperlink (such as Paper and Code) */
.section .link {
margin: 5pt 0;
text-align: center;
font-size: 16pt;
}
/* Teaser */
.section .teaser {
margin: 20pt 0;
text-align: left;
}
.section .teaser img {
width: 95%;
}
/* Section Title */
.section .title {
text-align: center;
font-size: 22pt;
margin: 5pt 0 15pt 0; /* top right bottom left */
}
/* Section Body */
.section .body {
margin-bottom: 15pt;
text-align: justify;
font-size: 14pt;
}
/* BibTeX */
.section .bibtex {
margin: 5pt 0;
text-align: left;
font-size: 22pt;
}
/* Related Work */
.section .ref {
margin: 20pt 0 10pt 0; /* top right bottom left */
text-align: left;
font-size: 18pt;
font-weight: bold;
}
/* Citation */
.section .citation {
min-height: 60pt;
margin: 10pt 0;
}
.section .citation .image {
width: 120pt;
float: left;
}
.section .citation .image img {
max-height: 60pt;
width: 120pt;
object-fit: cover;
}
.section .citation .comment{
margin-left: 0pt;
text-align: left;
font-size: 14pt;
}
================================================
FILE: docs/index.html
================================================
General Virtual Sketching Framework for Vector Line Art
1 Sun Yat-sen University,
2 Waseda University,
3 Huawei Technologies Canada
Given clean line drawings, rough sketches or photographs of arbitrary resolution as input, our framework generates the corresponding vector line drawings directly. As shown in (b), the framework models a virtual pen surrounded by a dynamic window (red boxes), which moves while drawing the strokes. It learns to move around by scaling the window and sliding to an undrawn area to restart the drawing (bottom example; sliding trajectory shown as blue arrows). With our proposed stroke regularization mechanism, the framework is able to enlarge the window and draw long strokes for simplicity (top example).
Abstract
Vector line art plays an important role in graphic design, however, it is tedious to manually create.
We introduce a general framework to produce line drawings from a wide variety of images,
by learning a mapping from raster image space to vector image space.
Our approach is based on a recurrent neural network that draws the lines one by one.
A differentiable rasterization module allows for training with only supervised raster data.
We use a dynamic window around a virtual pen while drawing lines,
implemented with a proposed aligned cropping and differentiable pasting modules.
Furthermore, we develop a stroke regularization loss
that encourages the model to use fewer and longer strokes to simplify the resulting vector image.
Ablation studies and comparisons with existing methods corroborate the efficiency of our approach,
which is able to generate visually better results in less computation time,
while generalizing better to a diversity of images and applications.
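The stroke regularization loss mentioned above penalizes the number of drawn strokes so the model prefers fewer, longer ones. One plausible form is a weighted mean of the soft pen-down probabilities; this is a hedged illustration, not the paper's exact implementation:

```python
import numpy as np

def stroke_number_loss(pen_down_probs, weight):
    # pen_down_probs: (N, T) soft probabilities that a stroke is drawn at each step.
    # Penalizing their mean pushes the model toward fewer (hence longer) strokes.
    return weight * np.mean(pen_down_probs)

probs_many = np.full((2, 48), 0.9)   # drawing at almost every step
probs_few = np.full((2, 48), 0.2)    # drawing rarely, mostly moving
assert stroke_number_loss(probs_many, 0.1) > stroke_number_loss(probs_few, 0.1)
```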
Method
Framework Overview
Our framework generates the parametrized strokes step by step in a recurrent manner.
It uses a dynamic window (dashed red boxes) around a virtual pen to draw the strokes,
and can both move and change the size of the window.
(a) Four main modules at each time step: aligned cropping, stroke generation, differentiable rendering and differentiable pasting.
(b) Architecture of the stroke generation module.
(c) Structural strokes predicted at each step;
movement only is illustrated by blue arrows during which no stroke is drawn on the canvas.
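The four modules above can be summarized as one recurrent step. The following runnable Python pseudostructure sketches the control flow; every function here is a trivial hypothetical placeholder, not the repo's implementation:

```python
import numpy as np

# Hypothetical stand-ins for the four modules; each is a trivial placeholder
# so the control flow below actually runs.
def aligned_crop(canvas, cursor, window):        # crop a window-sized patch around the cursor
    return np.zeros((8, 8))
def generate_stroke(patch, state):               # RNN predicts stroke params + new state
    return {"dx": 0.1, "dy": 0.0, "scale": 1.0}, state
def render_stroke(params):                       # differentiable rasterization of one stroke
    return np.zeros((8, 8))
def paste(canvas, patch, cursor, window):        # differentiable pasting back onto the canvas
    return canvas

canvas = np.ones((64, 64))                       # white canvas
cursor, window, state = np.array([0.5, 0.5]), 8.0, None
for _ in range(4):                               # a few recurrent drawing steps
    patch = aligned_crop(canvas, cursor, window)
    params, state = generate_stroke(patch, state)
    stroke = render_stroke(params)
    canvas = paste(canvas, stroke, cursor, window)
    cursor = cursor + np.array([params["dx"], params["dy"]])  # pen moves each step
    window = window * params["scale"]                         # window may rescale
assert cursor[0] > 0.5 and canvas.shape == (64, 64)
```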
Overall Introduction
(Or watch on Bilibili )
Results
Our framework is applicable to a diversity of image types, such as clean line drawing images, rough sketches and photographs.
Vectorization
Rough sketch simplification
Photograph to line drawing
More Results
(Or watch on Bilibili )
Presentations
3-5 minute presentation
(Or watch on Bilibili )
BibTeX
@article{mo2021virtualsketching,
title = {General Virtual Sketching Framework for Vector Line Art},
author = {Mo, Haoran and Simo-Serra, Edgar and Gao, Chengying and Zou, Changqing and Wang, Ruomei},
journal = {ACM Transactions on Graphics (Proceedings of ACM SIGGRAPH 2021)},
year = {2021},
volume = {40},
number = {4},
pages = {51:1--51:14}
}
Related Work
================================================
FILE: hyper_parameters.py
================================================
import tensorflow as tf
#############################################
# Common parameters
#############################################
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string(
'dataset_dir',
'datasets',
'The directory of sketch data of the dataset.')
tf.app.flags.DEFINE_string(
'log_root',
'outputs/log',
'Directory to store tensorboard.')
tf.app.flags.DEFINE_string(
'log_img_root',
'outputs/log_img',
'Directory to store intermediate output images.')
tf.app.flags.DEFINE_string(
'snapshot_root',
'outputs/snapshot',
'Directory to store model checkpoints.')
tf.app.flags.DEFINE_string(
'neural_renderer_path',
'outputs/snapshot/pretrain_neural_renderer/renderer_300000.tfmodel',
'Path to the neural renderer model.')
tf.app.flags.DEFINE_string(
'perceptual_model_root',
'outputs/snapshot/pretrain_perceptual_model',
'Directory to store perceptual model.')
tf.app.flags.DEFINE_string(
'data',
'',
'The dataset type.')
def get_default_hparams_clean():
"""Return default HParams for sketch-rnn."""
hparams = tf.contrib.training.HParams(
program_name='new_train_clean_line_drawings',
data_set='clean_line_drawings', # Our dataset.
input_channel=1,
num_steps=75040, # Total number of steps of training.
save_every=75000,
eval_every=5000,
max_seq_len=48,
batch_size=20,
gpus=[0, 1],
loop_per_gpu=1,
sn_loss_type='increasing', # ['decreasing', 'fixed', 'increasing']
stroke_num_loss_weight=0.02,
stroke_num_loss_weight_end=0.0,
increase_start_steps=25000,
decrease_stop_steps=40000,
perc_loss_layers=['ReLU1_2', 'ReLU2_2', 'ReLU3_3', 'ReLU5_1'],
perc_loss_fuse_type='add', # ['max', 'add', 'raw_add', 'weighted_sum']
init_cursor_on_undrawn_pixel=False,
early_pen_loss_type='move', # ['head', 'tail', 'move']
early_pen_loss_weight=0.1,
early_pen_length=7,
min_width=0.01,
min_window_size=32,
max_scaling=2.0,
encode_cursor_type='value',
image_size_small=128,
image_size_large=278,
cropping_type='v3', # ['v2', 'v3']
pasting_type='v3', # ['v2', 'v3']
pasting_diff=True,
concat_win_size=True,
encoder_type='conv13_c3',
# ['conv10', 'conv10_deep', 'conv13', 'conv10_c3', 'conv10_deep_c3', 'conv13_c3']
# ['conv13_c3_attn']
# ['combine33', 'combine43', 'combine53', 'combineFC']
vary_thickness=False,
outside_loss_weight=10.0,
win_size_outside_loss_weight=10.0,
resize_method='AREA', # ['BILINEAR', 'NEAREST_NEIGHBOR', 'BICUBIC', 'AREA']
concat_cursor=True,
use_softargmax=True,
soft_beta=10, # value for the soft argmax
raster_loss_weight=1.0,
dec_rnn_size=256, # Size of decoder.
dec_model='hyper', # Decoder: lstm, layer_norm or hyper.
# z_size=128, # Size of latent vector z. Recommend 32, 64 or 128.
bin_gt=True,
stop_accu_grad=True,
random_cursor=True,
cursor_type='next',
raster_size=128,
pix_drop_kp=1.0, # Dropout keep rate
add_coordconv=True,
position_format='abs',
raster_loss_base_type='perceptual', # [l1, mse, perceptual]
grad_clip=1.0, # Gradient clipping. Recommend leaving at 1.0.
learning_rate=0.0001, # Learning rate.
decay_rate=0.9999, # Learning rate decay per minibatch.
decay_power=0.9,
min_learning_rate=0.000001, # Minimum learning rate.
use_recurrent_dropout=True, # Dropout with memory loss. Recommended
recurrent_dropout_prob=0.90, # Probability of recurrent dropout keep.
use_input_dropout=False, # Input dropout. Recommend leaving False.
input_dropout_prob=0.90, # Probability of input dropout keep.
use_output_dropout=False, # Output dropout. Recommend leaving False.
output_dropout_prob=0.90, # Probability of output dropout keep.
model_mode='train' # ['train', 'eval', 'sample']
)
return hparams
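For the 'increasing' sn_loss_type above, the stroke-number loss weight presumably ramps from 0 toward stroke_num_loss_weight after increase_start_steps. The exact schedule lives in the training code (not shown in this chunk); the following linear-ramp function is only an illustrative guess:

```python
def increasing_weight(step, start_steps, total_steps, end_weight):
    # Linear ramp: 0 before start_steps, then grows to end_weight at total_steps.
    # This is an illustrative guess at the schedule, not the repo's exact code.
    if step <= start_steps:
        return 0.0
    frac = (step - start_steps) / float(total_steps - start_steps)
    return min(frac, 1.0) * end_weight

assert increasing_weight(0, 25000, 75040, 0.02) == 0.0
assert 0.0 < increasing_weight(50000, 25000, 75040, 0.02) < 0.02
assert increasing_weight(75040, 25000, 75040, 0.02) == 0.02
```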
def get_default_hparams_rough():
"""Return default HParams for sketch-rnn."""
hparams = tf.contrib.training.HParams(
program_name='new_train_rough_sketches',
data_set='rough_sketches', # ['rough_sketches', 'faces']
input_channel=3,
num_steps=90040, # Total number of steps of training.
save_every=90000,
eval_every=5000,
max_seq_len=48,
batch_size=20,
gpus=[0, 1],
loop_per_gpu=1,
sn_loss_type='increasing', # ['decreasing', 'fixed', 'increasing']
stroke_num_loss_weight=0.1,
stroke_num_loss_weight_end=0.0,
increase_start_steps=25000,
decrease_stop_steps=40000,
photo_prob_type='one', # ['increasing', 'zero', 'one']
photo_prob_start_step=35000,
perc_loss_layers=['ReLU2_2', 'ReLU3_3', 'ReLU5_1'],
perc_loss_fuse_type='add', # ['max', 'add', 'raw_add', 'weighted_sum']
early_pen_loss_type='move', # ['head', 'tail', 'move']
early_pen_loss_weight=0.2,
early_pen_length=7,
min_width=0.01,
min_window_size=32,
max_scaling=2.0,
encode_cursor_type='value',
image_size_small=128,
image_size_large=278,
cropping_type='v3', # ['v2', 'v3']
pasting_type='v3', # ['v2', 'v3']
pasting_diff=True,
concat_win_size=True,
encoder_type='conv13_c3',
# ['conv10', 'conv10_deep', 'conv13', 'conv10_c3', 'conv10_deep_c3', 'conv13_c3']
# ['conv13_c3_attn']
# ['combine33', 'combine43', 'combine53', 'combineFC']
outside_loss_weight=10.0,
win_size_outside_loss_weight=10.0,
resize_method='AREA', # ['BILINEAR', 'NEAREST_NEIGHBOR', 'BICUBIC', 'AREA']
concat_cursor=True,
use_softargmax=True,
soft_beta=10, # value for the soft argmax
raster_loss_weight=1.0,
dec_rnn_size=256, # Size of decoder.
dec_model='hyper', # Decoder: lstm, layer_norm or hyper.
# z_size=128, # Size of latent vector z. Recommend 32, 64 or 128.
bin_gt=True,
stop_accu_grad=True,
random_cursor=True,
cursor_type='next',
raster_size=128,
pix_drop_kp=1.0, # Dropout keep rate
add_coordconv=True,
position_format='abs',
raster_loss_base_type='perceptual', # [l1, mse, perceptual]
grad_clip=1.0, # Gradient clipping. Recommend leaving at 1.0.
learning_rate=0.0001, # Learning rate.
decay_rate=0.9999, # Learning rate decay per minibatch.
decay_power=0.9,
min_learning_rate=0.000001, # Minimum learning rate.
use_recurrent_dropout=True, # Dropout with memory loss. Recommended
recurrent_dropout_prob=0.90, # Probability of recurrent dropout keep.
use_input_dropout=False, # Input dropout. Recommend leaving False.
input_dropout_prob=0.90, # Probability of input dropout keep.
use_output_dropout=False, # Output dropout. Recommend leaving False.
output_dropout_prob=0.90, # Probability of output dropout keep.
model_mode='train' # ['train', 'eval', 'sample']
)
return hparams
def get_default_hparams_normal():
"""Return default HParams for sketch-rnn."""
hparams = tf.contrib.training.HParams(
program_name='new_train_faces',
data_set='faces', # ['rough_sketches', 'faces']
input_channel=3,
num_steps=90040, # Total number of steps of training.
save_every=90000,
eval_every=5000,
max_seq_len=48,
batch_size=20,
gpus=[0, 1],
loop_per_gpu=1,
sn_loss_type='fixed', # ['decreasing', 'fixed', 'increasing']
stroke_num_loss_weight=0.0,
stroke_num_loss_weight_end=0.0,
increase_start_steps=0,
decrease_stop_steps=40000,
photo_prob_type='interpolate', # ['increasing', 'zero', 'one', 'interpolate']
photo_prob_start_step=30000,
photo_prob_end_step=60000,
perc_loss_layers=['ReLU2_2', 'ReLU3_3', 'ReLU4_2', 'ReLU5_1'],
perc_loss_fuse_type='add', # ['max', 'add', 'raw_add', 'weighted_sum']
early_pen_loss_type='move', # ['head', 'tail', 'move']
early_pen_loss_weight=0.2,
early_pen_length=7,
min_width=0.01,
min_window_size=32,
max_scaling=2.0,
encode_cursor_type='value',
image_size_small=128,
image_size_large=256,
cropping_type='v3', # ['v2', 'v3']
pasting_type='v3', # ['v2', 'v3']
pasting_diff=True,
concat_win_size=True,
encoder_type='conv13_c3',
# ['conv10', 'conv10_deep', 'conv13', 'conv10_c3', 'conv10_deep_c3', 'conv13_c3']
# ['conv13_c3_attn']
# ['combine33', 'combine43', 'combine53', 'combineFC']
outside_loss_weight=10.0,
win_size_outside_loss_weight=10.0,
resize_method='AREA', # ['BILINEAR', 'NEAREST_NEIGHBOR', 'BICUBIC', 'AREA']
concat_cursor=True,
use_softargmax=True,
soft_beta=10, # value for the soft argmax
raster_loss_weight=1.0,
dec_rnn_size=256, # Size of decoder.
dec_model='hyper', # Decoder: lstm, layer_norm or hyper.
# z_size=128, # Size of latent vector z. Recommend 32, 64 or 128.
bin_gt=True,
stop_accu_grad=True,
random_cursor=True,
cursor_type='next',
raster_size=128,
pix_drop_kp=1.0, # Dropout keep rate
add_coordconv=True,
position_format='abs',
raster_loss_base_type='perceptual', # [l1, mse, perceptual]
grad_clip=1.0, # Gradient clipping. Recommend leaving at 1.0.
learning_rate=0.0001, # Learning rate.
decay_rate=0.9999, # Learning rate decay per minibatch.
decay_power=0.9,
min_learning_rate=0.000001, # Minimum learning rate.
use_recurrent_dropout=True, # Dropout with memory loss. Recommended
recurrent_dropout_prob=0.90, # Probability of recurrent dropout keep.
use_input_dropout=False, # Input dropout. Recommend leaving False.
input_dropout_prob=0.90, # Probability of input dropout keep.
use_output_dropout=False, # Output dropout. Recommend leaving False.
output_dropout_prob=0.90, # Probability of output dropout keep.
model_mode='train' # ['train', 'eval', 'sample']
)
return hparams
================================================
FILE: launch_gui.bat
================================================
@echo OFF
REM === Path to the Anaconda installation ===
set "CONDAPATH=C:\ProgramData\anaconda3"
REM === Environment name and path ===
set "ENVNAME=virtual_sketching"
set "ENVPATH=%USERPROFILE%\.conda\envs\%ENVNAME%"
REM === Activate the environment ===
call "%CONDAPATH%\Scripts\activate.bat" "%ENVPATH%"
REM === Launch the GUI ===
python virtual_sketch_gui.py
REM === Pause after exit ===
echo.
pause
REM === Deactivate the environment ===
call conda deactivate
================================================
FILE: model_common_test.py
================================================
import rnn
import tensorflow as tf
from subnet_tf_utils import generative_cnn_encoder, generative_cnn_encoder_deeper, generative_cnn_encoder_deeper13, \
generative_cnn_c3_encoder, generative_cnn_c3_encoder_deeper, generative_cnn_c3_encoder_deeper13, \
generative_cnn_c3_encoder_combine33, generative_cnn_c3_encoder_combine43, \
generative_cnn_c3_encoder_combine53, generative_cnn_c3_encoder_combineFC, \
generative_cnn_c3_encoder_deeper13_attn
class DiffPastingV3(object):
def __init__(self, raster_size):
self.patch_canvas = tf.placeholder(dtype=tf.float32,
shape=(None, None, 1)) # (raster_size, raster_size, 1), [0.0-BG, 1.0-stroke]
self.cursor_pos_a = tf.placeholder(dtype=tf.float32, shape=(2)) # (2), float32, in large size
self.image_size_a = tf.placeholder(dtype=tf.int32, shape=()) # ()
self.window_size_a = tf.placeholder(dtype=tf.float32, shape=()) # (), float32, with grad
self.raster_size_a = float(raster_size)
self.pasted_image = self.image_pasting_sampling_v3()
# (image_size, image_size, 1), [0.0-BG, 1.0-stroke]
def image_pasting_sampling_v3(self):
padding_size = tf.cast(tf.ceil(self.window_size_a / 2.0), tf.int32)
x1y1_a = self.cursor_pos_a - self.window_size_a / 2.0 # (2), float32
x2y2_a = self.cursor_pos_a + self.window_size_a / 2.0 # (2), float32
x1y1_a_floor = tf.floor(x1y1_a) # (2)
x2y2_a_ceil = tf.ceil(x2y2_a) # (2)
cursor_pos_b_oricoord = (x1y1_a_floor + x2y2_a_ceil) / 2.0 # (2)
cursor_pos_b = (cursor_pos_b_oricoord - x1y1_a) / self.window_size_a * self.raster_size_a # (2)
raster_size_b = (x2y2_a_ceil - x1y1_a_floor) # (x, y)
image_size_b = self.raster_size_a
window_size_b = self.raster_size_a * (raster_size_b / self.window_size_a) # (x, y)
cursor_b_x, cursor_b_y = tf.split(cursor_pos_b, 2, axis=-1) # (1)
y1_b = cursor_b_y - (window_size_b[1] - 1.) / 2.
x1_b = cursor_b_x - (window_size_b[0] - 1.) / 2.
y2_b = y1_b + (window_size_b[1] - 1.)
x2_b = x1_b + (window_size_b[0] - 1.)
boxes_b = tf.concat([y1_b, x1_b, y2_b, x2_b], axis=-1) # (4)
boxes_b = boxes_b / tf.cast(image_size_b - 1, tf.float32) # with grad to window_size_a
box_ind_b = tf.ones((1), dtype=tf.int32) # (1)
box_ind_b = tf.cumsum(box_ind_b) - 1
patch_canvas = tf.expand_dims(self.patch_canvas,
axis=0) # (1, raster_size, raster_size, 1), [0.0-BG, 1.0-stroke]
boxes_b = tf.expand_dims(boxes_b, axis=0) # (1, 4)
valid_canvas = tf.image.crop_and_resize(patch_canvas, boxes_b, box_ind_b,
crop_size=[raster_size_b[1], raster_size_b[0]])
valid_canvas = valid_canvas[0] # (raster_size_b, raster_size_b, 1)
pad_up = tf.cast(x1y1_a_floor[1], tf.int32) + padding_size
pad_down = self.image_size_a + padding_size - tf.cast(x2y2_a_ceil[1], tf.int32)
pad_left = tf.cast(x1y1_a_floor[0], tf.int32) + padding_size
pad_right = self.image_size_a + padding_size - tf.cast(x2y2_a_ceil[0], tf.int32)
paddings = [[pad_up, pad_down],
[pad_left, pad_right],
[0, 0]]
pad_img = tf.pad(valid_canvas, paddings=paddings, mode='CONSTANT',
constant_values=0.0) # (H_p, W_p, 1), [0.0-BG, 1.0-stroke]
pasted_image = pad_img[padding_size: padding_size + self.image_size_a,
padding_size: padding_size + self.image_size_a, :]
# (image_size, image_size, 1), [0.0-BG, 1.0-stroke]
return pasted_image
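The geometry of the pasting above can be seen more plainly without TensorFlow: resize the raster-size patch to the window size, then write it onto the canvas around the cursor, clipping at the borders. This NumPy analogue is non-differentiable and purely illustrative (names and the nearest-neighbor resize are assumptions); the TF version uses crop_and_resize to keep gradients:

```python
import numpy as np

def paste_patch(canvas, patch, cursor_xy, window_size):
    """Write a resized patch onto the canvas around cursor_xy, clipped at borders."""
    h = w = int(round(window_size))
    # nearest-neighbor resize of the patch to the window size
    idx = np.arange(h) * patch.shape[0] // h
    resized = patch[idx][:, idx]
    x0 = int(cursor_xy[0] - w // 2)
    y0 = int(cursor_xy[1] - h // 2)
    ys = slice(max(y0, 0), min(y0 + h, canvas.shape[0]))
    xs = slice(max(x0, 0), min(x0 + w, canvas.shape[1]))
    py = slice(ys.start - y0, ys.stop - y0)    # matching region inside the patch
    px = slice(xs.start - x0, xs.stop - x0)
    canvas[ys, xs] = np.maximum(canvas[ys, xs], resized[py, px])
    return canvas

canvas = np.zeros((32, 32))
patch = np.ones((8, 8))                 # a fully inked raster patch
canvas = paste_patch(canvas, patch, (16, 16), 16.0)
assert canvas.sum() == 16 * 16          # the patch landed as a 16x16 block
```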
class VirtualSketchingModel(object):
def __init__(self, hps, gpu_mode=True, reuse=False):
"""Initializer for the model.
Args:
hps: a HParams object containing model hyperparameters
gpu_mode: a boolean that when True, uses GPU mode.
reuse: a boolean that, when True, attempts to reuse variables.
"""
self.hps = hps
assert hps.model_mode in ['train', 'eval', 'eval_sample', 'sample']
# with tf.variable_scope('SCC', reuse=reuse):
if not gpu_mode:
with tf.device('/cpu:0'):
print('Model using cpu.')
self.build_model()
else:
print('-' * 100)
print('model_mode:', hps.model_mode)
print('Model using gpu.')
self.build_model()
def build_model(self):
"""Define model architecture."""
self.config_model()
initial_state = self.get_decoder_inputs()
self.initial_state = initial_state
## use pred as the prev points
other_params, pen_ras, final_state = self.get_points_and_raster_image(self.image_size)
# other_params: (N * max_seq_len, 6)
# pen_ras: (N * max_seq_len, 2), after softmax
self.other_params = other_params # (N * max_seq_len, 6)
self.pen_ras = pen_ras # (N * max_seq_len, 2), after softmax
self.final_state = final_state
if not self.hps.use_softargmax:
pen_state_soft = pen_ras[:, 1:2] # (N * max_seq_len, 1)
else:
pen_state_soft = self.differentiable_argmax(pen_ras, self.hps.soft_beta) # (N * max_seq_len, 1)
pred_params = tf.concat([pen_state_soft, other_params], axis=1) # (N * max_seq_len, 7)
self.pred_params = tf.reshape(pred_params, shape=[-1, self.hps.max_seq_len, 7]) # (N, max_seq_len, 7)
# pred_params: (N, max_seq_len, 7)
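The differentiable_argmax used above (when use_softargmax is set) can be sketched in NumPy as a temperature-scaled softmax dotted with the class indices; soft_beta controls the sharpness. This is a generic soft-argmax sketch, not a copy of the model's method:

```python
import numpy as np

def soft_argmax(logits, beta):
    # Softmax with temperature beta, then expectation over class indices.
    # Large beta -> output approaches the hard argmax, but stays differentiable.
    e = np.exp(beta * (logits - logits.max(axis=-1, keepdims=True)))
    probs = e / e.sum(axis=-1, keepdims=True)
    return probs @ np.arange(logits.shape[-1], dtype=np.float64)

logits = np.array([[0.2, 2.0], [3.0, 0.1]])  # pen-state logits, two classes
out = soft_argmax(logits, beta=10.0)
assert out[0] > 0.99 and out[1] < 0.01       # near-hard decisions at beta=10
```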
def config_model(self):
if self.hps.model_mode == 'train':
self.global_step = tf.Variable(0, name='global_step', trainable=False)
if self.hps.dec_model == 'lstm':
dec_cell_fn = rnn.LSTMCell
elif self.hps.dec_model == 'layer_norm':
dec_cell_fn = rnn.LayerNormLSTMCell
elif self.hps.dec_model == 'hyper':
dec_cell_fn = rnn.HyperLSTMCell
else:
assert False, 'dec_model must be one of: lstm, layer_norm, hyper'
use_recurrent_dropout = self.hps.use_recurrent_dropout
use_input_dropout = self.hps.use_input_dropout
use_output_dropout = self.hps.use_output_dropout
dec_cell = dec_cell_fn(
self.hps.dec_rnn_size,
use_recurrent_dropout=use_recurrent_dropout,
dropout_keep_prob=self.hps.recurrent_dropout_prob)
# dropout:
# print('Input dropout mode = %s.' % use_input_dropout)
# print('Output dropout mode = %s.' % use_output_dropout)
# print('Recurrent dropout mode = %s.' % use_recurrent_dropout)
if use_input_dropout:
print('Dropout to input w/ keep_prob = %4.4f.' % self.hps.input_dropout_prob)
dec_cell = tf.contrib.rnn.DropoutWrapper(
dec_cell, input_keep_prob=self.hps.input_dropout_prob)
if use_output_dropout:
print('Dropout to output w/ keep_prob = %4.4f.' % self.hps.output_dropout_prob)
dec_cell = tf.contrib.rnn.DropoutWrapper(
dec_cell, output_keep_prob=self.hps.output_dropout_prob)
self.dec_cell = dec_cell
self.input_photo = tf.placeholder(dtype=tf.float32,
shape=[self.hps.batch_size, None, None, self.hps.input_channel]) # [0.0-stroke, 1.0-BG]
self.init_cursor = tf.placeholder(
dtype=tf.float32,
shape=[self.hps.batch_size, 1, 2]) # (N, 1, 2), in size [0.0, 1.0)
self.init_width = tf.placeholder(
dtype=tf.float32,
shape=[self.hps.batch_size]) # (N), in [0.0, 1.0]
self.init_scaling = tf.placeholder(
dtype=tf.float32,
shape=[self.hps.batch_size]) # (N), in [0.0, 1.0]
self.init_window_size = tf.placeholder(
dtype=tf.float32,
shape=[self.hps.batch_size]) # (N)
self.image_size = tf.placeholder(dtype=tf.int32, shape=()) # ()
###########################
def normalize_image_m1to1(self, in_img_0to1):
norm_img_m1to1 = tf.multiply(in_img_0to1, 2.0)
norm_img_m1to1 = tf.subtract(norm_img_m1to1, 1.0)
return norm_img_m1to1
def add_coords(self, input_tensor):
batch_size_tensor = tf.shape(input_tensor)[0] # get N size
xx_ones = tf.ones([batch_size_tensor, self.hps.raster_size], dtype=tf.int32) # e.g. (N, raster_size)
xx_ones = tf.expand_dims(xx_ones, -1) # e.g. (N, raster_size, 1)
xx_range = tf.tile(tf.expand_dims(tf.range(self.hps.raster_size), 0),
[batch_size_tensor, 1]) # e.g. (N, raster_size)
xx_range = tf.expand_dims(xx_range, 1) # e.g. (N, 1, raster_size)
xx_channel = tf.matmul(xx_ones, xx_range) # e.g. (N, raster_size, raster_size)
xx_channel = tf.expand_dims(xx_channel, -1) # e.g. (N, raster_size, raster_size, 1)
yy_ones = tf.ones([batch_size_tensor, self.hps.raster_size], dtype=tf.int32) # e.g. (N, raster_size)
yy_ones = tf.expand_dims(yy_ones, 1) # e.g. (N, 1, raster_size)
yy_range = tf.tile(tf.expand_dims(tf.range(self.hps.raster_size), 0),
[batch_size_tensor, 1]) # (N, raster_size)
yy_range = tf.expand_dims(yy_range, -1) # e.g. (N, raster_size, 1)
yy_channel = tf.matmul(yy_range, yy_ones) # e.g. (N, raster_size, raster_size)
yy_channel = tf.expand_dims(yy_channel, -1) # e.g. (N, raster_size, raster_size, 1)
xx_channel = tf.cast(xx_channel, 'float32') / (self.hps.raster_size - 1)
yy_channel = tf.cast(yy_channel, 'float32') / (self.hps.raster_size - 1)
# xx_channel = xx_channel * 2 - 1 # [-1, 1]
# yy_channel = yy_channel * 2 - 1
ret = tf.concat([
input_tensor,
xx_channel,
yy_channel,
], axis=-1) # e.g. (N, raster_size, raster_size, 4)
return ret
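add_coords above builds CoordConv-style coordinate channels: two extra channels holding the normalized x and y position of each pixel. An equivalent NumPy construction (the helper name is hypothetical) makes the layout explicit:

```python
import numpy as np

def add_coord_channels(images):
    """Append normalized x and y coordinate channels, as in CoordConv.
    images: (N, S, S, C) array; returns (N, S, S, C + 2)."""
    n, s = images.shape[0], images.shape[1]
    xx = np.tile(np.arange(s, dtype=np.float32) / (s - 1), (s, 1))  # varies along width
    yy = xx.T                                                       # varies along height
    coords = np.stack([xx, yy], axis=-1)                            # (S, S, 2)
    coords = np.broadcast_to(coords, (n, s, s, 2))
    return np.concatenate([images, coords], axis=-1)

imgs = np.zeros((2, 4, 4, 1), dtype=np.float32)
out = add_coord_channels(imgs)
assert out.shape == (2, 4, 4, 3)
assert out[0, 0, 3, 1] == 1.0 and out[0, 3, 0, 2] == 1.0  # far corners normalize to 1
```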
def build_combined_encoder(self, patch_canvas, patch_photo, entire_canvas, entire_photo, cursor_pos,
image_size, window_size):
"""
:param patch_canvas: (N, raster_size, raster_size, 1), [-1.0-stroke, 1.0-BG]
:param patch_photo: (N, raster_size, raster_size, 1/3), [-1.0-stroke, 1.0-BG]
:param entire_canvas: (N, image_size, image_size, 1), [0.0-stroke, 1.0-BG]
:param entire_photo: (N, image_size, image_size, 1/3), [0.0-stroke, 1.0-BG]
:param cursor_pos: (N, 1, 2), normalized to [0.0, 1.0)
:param window_size: (N, 1, 1), float, in full-image pixel units
:return:
"""
if self.hps.resize_method == 'BILINEAR':
resize_method = tf.image.ResizeMethod.BILINEAR
elif self.hps.resize_method == 'NEAREST_NEIGHBOR':
resize_method = tf.image.ResizeMethod.NEAREST_NEIGHBOR
elif self.hps.resize_method == 'BICUBIC':
resize_method = tf.image.ResizeMethod.BICUBIC
elif self.hps.resize_method == 'AREA':
resize_method = tf.image.ResizeMethod.AREA
else:
raise Exception('unknown resize_method', self.hps.resize_method)
patch_photo = tf.stop_gradient(patch_photo)
patch_canvas = tf.stop_gradient(patch_canvas)
cursor_pos = tf.stop_gradient(cursor_pos)
window_size = tf.stop_gradient(window_size)
entire_photo_small = tf.stop_gradient(tf.image.resize_images(entire_photo,
(self.hps.raster_size, self.hps.raster_size),
method=resize_method))
entire_canvas_small = tf.stop_gradient(tf.image.resize_images(entire_canvas,
(self.hps.raster_size, self.hps.raster_size),
method=resize_method))
entire_photo_small = self.normalize_image_m1to1(entire_photo_small) # [-1.0-stroke, 1.0-BG]
entire_canvas_small = self.normalize_image_m1to1(entire_canvas_small) # [-1.0-stroke, 1.0-BG]
if self.hps.encode_cursor_type == 'value':
cursor_pos_norm = tf.expand_dims(cursor_pos, axis=1) # (N, 1, 1, 2)
cursor_pos_norm = tf.tile(cursor_pos_norm, [1, self.hps.raster_size, self.hps.raster_size, 1])
cursor_info = cursor_pos_norm
else:
raise Exception('Unknown encode_cursor_type', self.hps.encode_cursor_type)
batch_input_combined = tf.concat([patch_photo, patch_canvas, entire_photo_small, entire_canvas_small, cursor_info],
axis=-1) # [N, raster_size, raster_size, 6/10]
batch_input_local = tf.concat([patch_photo, patch_canvas], axis=-1) # [N, raster_size, raster_size, 2/4]
batch_input_global = tf.concat([entire_photo_small, entire_canvas_small, cursor_info],
axis=-1) # [N, raster_size, raster_size, 4/6]
if self.hps.model_mode == 'train':
is_training = True
dropout_keep_prob = self.hps.pix_drop_kp
else:
is_training = False
dropout_keep_prob = 1.0
if self.hps.add_coordconv:
batch_input_combined = self.add_coords(batch_input_combined) # (N, in_H, in_W, in_dim + 2)
batch_input_local = self.add_coords(batch_input_local) # (N, in_H, in_W, in_dim + 2)
batch_input_global = self.add_coords(batch_input_global) # (N, in_H, in_W, in_dim + 2)
if 'combine' in self.hps.encoder_type:
if self.hps.encoder_type == 'combine33':
image_embedding, _ = generative_cnn_c3_encoder_combine33(batch_input_local, batch_input_global,
is_training, dropout_keep_prob) # (N, 128)
elif self.hps.encoder_type == 'combine43':
image_embedding, _ = generative_cnn_c3_encoder_combine43(batch_input_local, batch_input_global,
is_training, dropout_keep_prob) # (N, 128)
elif self.hps.encoder_type == 'combine53':
image_embedding, _ = generative_cnn_c3_encoder_combine53(batch_input_local, batch_input_global,
is_training, dropout_keep_prob) # (N, 128)
elif self.hps.encoder_type == 'combineFC':
image_embedding, _ = generative_cnn_c3_encoder_combineFC(batch_input_local, batch_input_global,
is_training, dropout_keep_prob) # (N, 256)
else:
raise Exception('Unknown encoder_type', self.hps.encoder_type)
else:
with tf.variable_scope('Combined_Encoder', reuse=tf.AUTO_REUSE):
if self.hps.encoder_type == 'conv10':
image_embedding, _ = generative_cnn_encoder(batch_input_combined, is_training, dropout_keep_prob) # (N, 128)
elif self.hps.encoder_type == 'conv10_deep':
image_embedding, _ = generative_cnn_encoder_deeper(batch_input_combined, is_training, dropout_keep_prob) # (N, 512)
elif self.hps.encoder_type == 'conv13':
image_embedding, _ = generative_cnn_encoder_deeper13(batch_input_combined, is_training, dropout_keep_prob) # (N, 128)
elif self.hps.encoder_type == 'conv10_c3':
image_embedding, _ = generative_cnn_c3_encoder(batch_input_combined, is_training, dropout_keep_prob) # (N, 128)
elif self.hps.encoder_type == 'conv10_deep_c3':
image_embedding, _ = generative_cnn_c3_encoder_deeper(batch_input_combined, is_training, dropout_keep_prob) # (N, 512)
elif self.hps.encoder_type == 'conv13_c3':
image_embedding, _ = generative_cnn_c3_encoder_deeper13(batch_input_combined, is_training, dropout_keep_prob) # (N, 128)
elif self.hps.encoder_type == 'conv13_c3_attn':
image_embedding, _ = generative_cnn_c3_encoder_deeper13_attn(batch_input_combined, is_training, dropout_keep_prob) # (N, 128)
else:
raise Exception('Unknown encoder_type', self.hps.encoder_type)
return image_embedding
def build_seq_decoder(self, dec_cell, actual_input_x, initial_state):
rnn_output, last_state = self.rnn_decoder(dec_cell, initial_state, actual_input_x)
rnn_output_flat = tf.reshape(rnn_output, [-1, self.hps.dec_rnn_size])
pen_n_out = 2
params_n_out = 6
with tf.variable_scope('DEC_RNN_out_pen', reuse=tf.AUTO_REUSE):
output_w_pen = tf.get_variable('output_w', [self.hps.dec_rnn_size, pen_n_out])
output_b_pen = tf.get_variable('output_b', [pen_n_out], initializer=tf.constant_initializer(0.0))
output_pen = tf.nn.xw_plus_b(rnn_output_flat, output_w_pen, output_b_pen) # (N, pen_n_out)
with tf.variable_scope('DEC_RNN_out_params', reuse=tf.AUTO_REUSE):
output_w_params = tf.get_variable('output_w', [self.hps.dec_rnn_size, params_n_out])
output_b_params = tf.get_variable('output_b', [params_n_out], initializer=tf.constant_initializer(0.0))
output_params = tf.nn.xw_plus_b(rnn_output_flat, output_w_params, output_b_params) # (N, params_n_out)
output = tf.concat([output_pen, output_params], axis=1) # (N, n_out)
return output, last_state
def get_mixture_coef(self, outputs):
z = outputs
z_pen_logits = z[:, 0:2] # (N, 2), pen states
z_other_params_logits = z[:, 2:] # (N, 6)
z_pen = tf.nn.softmax(z_pen_logits) # (N, 2)
if self.hps.position_format == 'abs':
x1y1 = tf.nn.sigmoid(z_other_params_logits[:, 0:2]) # (N, 2)
x2y2 = tf.tanh(z_other_params_logits[:, 2:4]) # (N, 2)
widths = tf.nn.sigmoid(z_other_params_logits[:, 4:5]) # (N, 1)
widths = tf.add(tf.multiply(widths, 1.0 - self.hps.min_width), self.hps.min_width)
scaling = tf.nn.sigmoid(z_other_params_logits[:, 5:6]) * self.hps.max_scaling # (N, 1), [0.0, max_scaling]
# scaling = tf.add(tf.multiply(scaling, (self.hps.max_scaling - self.hps.min_scaling) / self.hps.max_scaling),
# self.hps.min_scaling)
z_other_params = tf.concat([x1y1, x2y2, widths, scaling], axis=-1) # (N, 6)
else: # "rel"
raise Exception('Unknown position_format', self.hps.position_format)
r = [z_other_params, z_pen]
return r
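The range mapping in get_mixture_coef can be sketched in plain Python. This is a hypothetical illustration with placeholder min_width/max_scaling values, mirroring the 'abs' branch: softmax for the pen state, sigmoid for x1y1 in [0, 1], tanh for x2y2 in [-1, 1], width rescaled into [min_width, 1], and scaling into [0, max_scaling].

```python
import math

def split_params(logits, min_width=0.01, max_scaling=2.0):
    """Map 8 raw decoder outputs to their ranges, mirroring get_mixture_coef
    for position_format='abs'. min_width/max_scaling are placeholder values."""
    sigmoid = lambda v: 1.0 / (1.0 + math.exp(-v))
    pen_exp = [math.exp(v) for v in logits[0:2]]
    pen = [v / sum(pen_exp) for v in pen_exp]                    # softmax pen state
    x1y1 = [sigmoid(v) for v in logits[2:4]]                     # [0, 1]
    x2y2 = [math.tanh(v) for v in logits[4:6]]                   # [-1, 1]
    width = sigmoid(logits[6]) * (1.0 - min_width) + min_width   # [min_width, 1]
    scaling = sigmoid(logits[7]) * max_scaling                   # [0, max_scaling]
    return pen, x1y1 + x2y2 + [width, scaling]
```

With all-zero logits this yields a 0.5/0.5 pen state, x1y1 at 0.5, x2y2 at 0, width at the midpoint of [min_width, 1], and scaling at max_scaling / 2.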
###########################
def get_decoder_inputs(self):
initial_state = self.dec_cell.zero_state(batch_size=self.hps.batch_size, dtype=tf.float32)
return initial_state
def rnn_decoder(self, dec_cell, initial_state, actual_input_x):
with tf.variable_scope("RNN_DEC", reuse=tf.AUTO_REUSE):
output, last_state = tf.nn.dynamic_rnn(
dec_cell,
actual_input_x,
initial_state=initial_state,
time_major=False,
swap_memory=True,
dtype=tf.float32)
return output, last_state
###########################
def image_padding(self, ori_image, window_size, pad_value):
"""
Pad the image with window_size // 2 pixels of the background value on each spatial side.
:param ori_image: (N, H, W, k)
:return: padded image, (N, H_p, W_p, k)
"""
paddings = [[0, 0],
[window_size // 2, window_size // 2],
[window_size // 2, window_size // 2],
[0, 0]]
pad_img = tf.pad(ori_image, paddings=paddings, mode='CONSTANT', constant_values=pad_value) # (N, H_p, W_p, k)
return pad_img
def image_cropping_fn(self, fn_inputs):
"""
Crop a window_size patch centered at the cursor and resize it to raster_size.
:return: (raster_size, raster_size, -), [0.0-BG, 1.0-stroke]
"""
index_offset = self.hps.input_channel - 1
input_image = fn_inputs[:, :, 0:2 + index_offset] # (image_size, image_size, -), [0.0-BG, 1.0-stroke]
cursor_pos = fn_inputs[0, 0, 2 + index_offset:4 + index_offset] # (2), in [0.0, 1.0)
image_size = fn_inputs[0, 0, 4 + index_offset] # (), float32
window_size = tf.cast(fn_inputs[0, 0, 5 + index_offset], tf.int32) # ()
input_img_reshape = tf.expand_dims(input_image, axis=0)
pad_img = self.image_padding(input_img_reshape, window_size, pad_value=0.0)
cursor_pos = tf.cast(tf.round(tf.multiply(cursor_pos, image_size)), dtype=tf.int32)
x0, x1 = cursor_pos[0], cursor_pos[0] + window_size # ()
y0, y1 = cursor_pos[1], cursor_pos[1] + window_size # ()
patch_image = pad_img[:, y0:y1, x0:x1, :] # (1, window_size, window_size, 2/4)
# resize to raster_size
patch_image_scaled = tf.image.resize_images(patch_image, (self.hps.raster_size, self.hps.raster_size),
method=tf.image.ResizeMethod.AREA)
patch_image_scaled = tf.squeeze(patch_image_scaled, axis=0)
# patch_image_scaled: (raster_size, raster_size, 2/4), [0.0-BG, 1.0-stroke]
return patch_image_scaled
def image_cropping(self, cursor_position, input_img, image_size, window_sizes):
"""
:param cursor_position: (N, 1, 2), float type, normalized to [0.0, 1.0)
:param input_img: (N, image_size, image_size, 2/4), [0.0-BG, 1.0-stroke]
:param window_sizes: (N, 1, 1), float32, with grad
"""
input_img_ = input_img
window_sizes_non_grad = tf.stop_gradient(tf.round(window_sizes)) # (N, 1, 1), no grad
cursor_position_ = tf.reshape(cursor_position, (-1, 1, 1, 2)) # (N, 1, 1, 2)
cursor_position_ = tf.tile(cursor_position_, [1, image_size, image_size, 1]) # (N, image_size, image_size, 2)
image_size_ = tf.reshape(tf.cast(image_size, tf.float32), (1, 1, 1, 1)) # (1, 1, 1, 1)
image_size_ = tf.tile(image_size_, [self.hps.batch_size, image_size, image_size, 1])
window_sizes_ = tf.reshape(window_sizes_non_grad, (-1, 1, 1, 1)) # (N, 1, 1, 1)
window_sizes_ = tf.tile(window_sizes_, [1, image_size, image_size, 1]) # (N, image_size, image_size, 1)
fn_inputs = tf.concat([input_img_, cursor_position_, image_size_, window_sizes_],
axis=-1) # (N, image_size, image_size, 2/4 + 4)
curr_patch_imgs = tf.map_fn(self.image_cropping_fn, fn_inputs, parallel_iterations=32) # (N, raster_size, raster_size, -)
return curr_patch_imgs
def image_cropping_v3(self, cursor_position, input_img, image_size, window_sizes):
"""
:param cursor_position: (N, 1, 2), float type, normalized to [0.0, 1.0)
:param input_img: (N, image_size, image_size, k), [0.0-BG, 1.0-stroke]
:param window_sizes: (N, 1, 1), float32, with grad
"""
window_sizes_non_grad = tf.stop_gradient(window_sizes) # (N, 1, 1), no grad
cursor_pos = tf.multiply(cursor_position, tf.cast(image_size, tf.float32))
cursor_x, cursor_y = tf.split(cursor_pos, 2, axis=-1) # (N, 1, 1)
y1 = cursor_y - (window_sizes_non_grad - 1.0) / 2
x1 = cursor_x - (window_sizes_non_grad - 1.0) / 2
y2 = y1 + (window_sizes_non_grad - 1.0)
x2 = x1 + (window_sizes_non_grad - 1.0)
boxes = tf.concat([y1, x1, y2, x2], axis=-1) # (N, 1, 4)
boxes = tf.squeeze(boxes, axis=1) # (N, 4)
boxes = boxes / tf.cast(image_size - 1, tf.float32)
box_ind = tf.ones_like(cursor_x)[:, 0, 0] # (N)
box_ind = tf.cast(box_ind, dtype=tf.int32)
box_ind = tf.cumsum(box_ind) - 1
curr_patch_imgs = tf.image.crop_and_resize(input_img, boxes, box_ind,
crop_size=[self.hps.raster_size, self.hps.raster_size])
# (N, raster_size, raster_size, k), [0.0-BG, 1.0-stroke]
return curr_patch_imgs
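The box arithmetic used by image_cropping_v3 can be checked with a pure-Python sketch (a hypothetical helper, not part of the repo): a window of window_size pixels is centered on the cursor and expressed in the normalized [y1, x1, y2, x2] convention that tf.image.crop_and_resize expects, where 0 and 1 map to pixel 0 and pixel image_size - 1.

```python
def crop_box(cursor_xy, window_size, image_size):
    """Normalized [y1, x1, y2, x2] box for tf.image.crop_and_resize,
    mirroring image_cropping_v3: cursor_xy in [0, 1), window_size in pixels."""
    cx = cursor_xy[0] * image_size
    cy = cursor_xy[1] * image_size
    half = (window_size - 1.0) / 2.0
    y1, x1 = cy - half, cx - half
    y2, x2 = y1 + (window_size - 1.0), x1 + (window_size - 1.0)
    return [v / (image_size - 1.0) for v in (y1, x1, y2, x2)]
```

For a 128-pixel window centered in a 256-pixel image, the box spans 127 normalized-pixel units per side, centered on pixel 128.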
def get_points_and_raster_image(self, image_size):
## Generate the other_params and pen_ras sequences for the whole drawing loop.
prev_state = self.initial_state # (N, dec_rnn_size * 3)
prev_width = self.init_width # (N)
prev_width = tf.expand_dims(tf.expand_dims(prev_width, axis=-1), axis=-1) # (N, 1, 1)
prev_scaling = self.init_scaling # (N)
prev_scaling = tf.reshape(prev_scaling, (-1, 1, 1)) # (N, 1, 1)
prev_window_size = self.init_window_size # (N)
prev_window_size = tf.reshape(prev_window_size, (-1, 1, 1)) # (N, 1, 1)
cursor_position_temp = self.init_cursor
self.cursor_position = cursor_position_temp # (N, 1, 2), in size [0.0, 1.0)
cursor_position_loop = self.cursor_position
other_params_list = []
pen_ras_list = []
curr_canvas_soft = tf.zeros_like(self.input_photo[:, :, :, 0]) # (N, image_size, image_size), [0.0-BG, 1.0-stroke]
curr_canvas_hard = tf.zeros_like(curr_canvas_soft) # [0.0-BG, 1.0-stroke]
#### sampling part - start ####
self.curr_canvas_hard = curr_canvas_hard
if self.hps.cropping_type == 'v3':
cropping_func = self.image_cropping_v3
# elif self.hps.cropping_type == 'v2':
# cropping_func = self.image_cropping
else:
raise Exception('Unknown cropping_type', self.hps.cropping_type)
for time_i in range(self.hps.max_seq_len):
cursor_position_non_grad = tf.stop_gradient(cursor_position_loop) # (N, 1, 2), in size [0.0, 1.0)
curr_window_size = tf.multiply(prev_scaling, tf.stop_gradient(prev_window_size)) # float, with grad
curr_window_size = tf.maximum(curr_window_size, tf.cast(self.hps.min_window_size, tf.float32))
curr_window_size = tf.minimum(curr_window_size, tf.cast(image_size, tf.float32))
## patch-level encoding
# Stop gradients into curr_canvas_hard to avoid recurrent gradient propagation through the canvas.
curr_canvas_hard_non_grad = tf.stop_gradient(self.curr_canvas_hard)
curr_canvas_hard_non_grad = tf.expand_dims(curr_canvas_hard_non_grad, axis=-1)
# input_photo: (N, image_size, image_size, 1/3), [0.0-stroke, 1.0-BG]
crop_inputs = tf.concat([1.0 - self.input_photo, curr_canvas_hard_non_grad], axis=-1) # (N, H_p, W_p, 1+1)
cropped_outputs = cropping_func(cursor_position_non_grad, crop_inputs, image_size, curr_window_size)
index_offset = self.hps.input_channel - 1
curr_patch_inputs = cropped_outputs[:, :, :, 0:1 + index_offset] # [0.0-BG, 1.0-stroke]
curr_patch_canvas_hard_non_grad = cropped_outputs[:, :, :, 1 + index_offset:2 + index_offset]
# (N, raster_size, raster_size, 1/3), [0.0-BG, 1.0-stroke]
curr_patch_inputs = 1.0 - curr_patch_inputs # [0.0-stroke, 1.0-BG]
curr_patch_inputs = self.normalize_image_m1to1(curr_patch_inputs)
# (N, raster_size, raster_size, 1/3), [-1.0-stroke, 1.0-BG]
# Normalizing image
curr_patch_canvas_hard_non_grad = 1.0 - curr_patch_canvas_hard_non_grad # [0.0-stroke, 1.0-BG]
curr_patch_canvas_hard_non_grad = self.normalize_image_m1to1(curr_patch_canvas_hard_non_grad) # [-1.0-stroke, 1.0-BG]
## image-level encoding
combined_z = self.build_combined_encoder(
curr_patch_canvas_hard_non_grad,
curr_patch_inputs,
1.0 - curr_canvas_hard_non_grad,
self.input_photo,
cursor_position_non_grad,
image_size,
curr_window_size) # (N, z_size)
combined_z = tf.expand_dims(combined_z, axis=1) # (N, 1, z_size)
curr_window_size_top_side_norm_non_grad = \
tf.stop_gradient(curr_window_size / tf.cast(image_size, tf.float32))
curr_window_size_bottom_side_norm_non_grad = \
tf.stop_gradient(curr_window_size / tf.cast(self.hps.min_window_size, tf.float32))
if not self.hps.concat_win_size:
combined_z = tf.concat([tf.stop_gradient(prev_width), combined_z], 2) # (N, 1, 2+z_size)
else:
combined_z = tf.concat([tf.stop_gradient(prev_width),
curr_window_size_top_side_norm_non_grad,
curr_window_size_bottom_side_norm_non_grad,
combined_z],
2) # (N, 1, 2+z_size)
if self.hps.concat_cursor:
prev_input_x = tf.concat([cursor_position_non_grad, combined_z], 2) # (N, 1, 2+2+z_size)
else:
prev_input_x = combined_z # (N, 1, 2+z_size)
h_output, next_state = self.build_seq_decoder(self.dec_cell, prev_input_x, prev_state)
# h_output: (N * 1, n_out), next_state: (N, dec_rnn_size * 3)
[o_other_params, o_pen_ras] = self.get_mixture_coef(h_output)
# o_other_params: (N * 1, 6)
# o_pen_ras: (N * 1, 2), after softmax
o_other_params = tf.reshape(o_other_params, [-1, 1, 6]) # (N, 1, 6)
o_pen_ras_raw = tf.reshape(o_pen_ras, [-1, 1, 2]) # (N, 1, 2)
other_params_list.append(o_other_params)
pen_ras_list.append(o_pen_ras_raw)
#### sampling part - end ####
prev_state = next_state
other_params_ = tf.reshape(tf.concat(other_params_list, axis=1), [-1, 6]) # (N * max_seq_len, 6)
pen_ras_ = tf.reshape(tf.concat(pen_ras_list, axis=1), [-1, 2]) # (N * max_seq_len, 2)
return other_params_, pen_ras_, prev_state
def differentiable_argmax(self, input_pen, soft_beta):
"""
Differentiable argmax trick.
:param input_pen: (N, n_class)
:return: pen_state: (N, 1)
"""
def sign_onehot(x):
"""
:param x: (N, n_class)
:return: (N, n_class)
"""
y = tf.sign(tf.reduce_max(x, axis=-1, keepdims=True) - x)
y = (y - 1) * (-1)
return y
def softargmax(x, beta=1e2):
"""
:param x: (N, n_class)
:param beta: softmax sharpness; 1e10 is effectively exact, while 1e2 is acceptable.
:return: (N)
"""
x_range = tf.cumsum(tf.ones_like(x), axis=1) # (N, 2)
return tf.reduce_sum(tf.nn.softmax(x * beta) * x_range, axis=1) - 1
## Better to use softargmax(beta=1e2). The sign_onehot's gradient is close to zero.
# pen_onehot = sign_onehot(input_pen) # one-hot form, (N * max_seq_len, 2)
# pen_state = pen_onehot[:, 1:2] # (N * max_seq_len, 1)
pen_state = softargmax(input_pen, soft_beta)
pen_state = tf.expand_dims(pen_state, axis=1) # (N * max_seq_len, 1)
return pen_state
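The softargmax trick above can be demonstrated with a stand-alone pure-Python sketch (hypothetical, with a max-subtraction added for numerical stability, which the graph version omits): as beta grows, the weighted index sum approaches the hard argmax.

```python
import math

def softargmax_py(x, beta=1e2):
    """Differentiable argmax over a list of logits, mirroring the softargmax above:
    softmax(beta * x) weighted by 1-based index, minus 1. The max subtraction is
    added here only to keep math.exp from overflowing at large beta."""
    m = max(x)
    exps = [math.exp(beta * (v - m)) for v in x]  # stabilized softmax numerators
    s = sum(exps)
    return sum((i + 1) * e / s for i, e in enumerate(exps)) - 1.0
```

At beta = 1e2 the result is essentially the hard argmax index; at beta = 1 with tied logits it falls between the indices, which shows why a large beta is preferred.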
================================================
FILE: model_common_train.py
================================================
import rnn
import tensorflow as tf
from subnet_tf_utils import generative_cnn_encoder, generative_cnn_encoder_deeper, generative_cnn_encoder_deeper13, \
generative_cnn_c3_encoder, generative_cnn_c3_encoder_deeper, generative_cnn_c3_encoder_deeper13, \
generative_cnn_c3_encoder_combine33, generative_cnn_c3_encoder_combine43, \
generative_cnn_c3_encoder_combine53, generative_cnn_c3_encoder_combineFC, \
generative_cnn_c3_encoder_deeper13_attn
from rasterization_utils.NeuralRenderer import NeuralRasterizorStep
from vgg_utils.VGG16 import vgg_net_slim
class VirtualSketchingModel(object):
def __init__(self, hps, gpu_mode=True, reuse=False):
"""Initializer for the model.
Args:
hps: a HParams object containing model hyperparameters
gpu_mode: a boolean that when True, uses GPU mode.
reuse: a boolean that when True, attempts to reuse variables.
"""
self.hps = hps
assert hps.model_mode in ['train', 'eval', 'eval_sample', 'sample']
# with tf.variable_scope('SCC', reuse=reuse):
if not gpu_mode:
with tf.device('/cpu:0'):
print('Model using cpu.')
self.build_model()
else:
print('-' * 100)
print('model_mode:', hps.model_mode)
print('Model using gpu.')
self.build_model()
def build_model(self):
"""Define model architecture."""
self.config_model()
initial_state = self.get_decoder_inputs()
self.initial_state = initial_state
self.initial_state_list = tf.split(self.initial_state, self.total_loop, axis=0)
total_loss_list = []
ras_loss_list = []
perc_relu_raw_list = []
perc_relu_norm_list = []
sn_loss_list = []
cursor_outside_loss_list = []
win_size_outside_loss_list = []
early_state_loss_list = []
tower_grads = []
pred_raster_imgs_list = []
pred_raster_imgs_rgb_list = []
for t_i in range(self.total_loop):
gpu_idx = t_i // self.hps.loop_per_gpu
gpu_i = self.hps.gpus[gpu_idx]
print(self.hps.model_mode, 'model, gpu:', gpu_i, ', loop:', t_i % self.hps.loop_per_gpu)
with tf.device('/gpu:%d' % gpu_i):
with tf.name_scope('GPU_%d' % gpu_i) as scope:
if t_i > 0:
tf.get_variable_scope().reuse_variables()
else:
total_loss_list.clear()
ras_loss_list.clear()
perc_relu_raw_list.clear()
perc_relu_norm_list.clear()
sn_loss_list.clear()
cursor_outside_loss_list.clear()
win_size_outside_loss_list.clear()
early_state_loss_list.clear()
tower_grads.clear()
pred_raster_imgs_list.clear()
pred_raster_imgs_rgb_list.clear()
split_input_photo = self.input_photo_list[t_i]
split_image_size = self.image_size[t_i]
split_init_cursor = self.init_cursor_list[t_i]
split_initial_state = self.initial_state_list[t_i]
if self.hps.input_channel == 1:
split_target_sketch = split_input_photo
else:
split_target_sketch = self.target_sketch_list[t_i]
## use pred as the prev points
other_params, pen_ras, final_state, pred_raster_images, pred_raster_images_rgb, \
pos_before_max_min, win_size_before_max_min \
= self.get_points_and_raster_image(split_initial_state, split_init_cursor, split_input_photo,
split_image_size)
# other_params: (N * max_seq_len, 6)
# pen_ras: (N * max_seq_len, 2), after softmax
# pos_before_max_min: (N, max_seq_len, 2), in image_size
# win_size_before_max_min: (N, max_seq_len, 1), in image_size
pred_raster_imgs = 1.0 - pred_raster_images # (N, image_size, image_size), [0.0-stroke, 1.0-BG]
pred_raster_imgs_rgb = 1.0 - pred_raster_images_rgb # (N, image_size, image_size, 3)
pred_raster_imgs_list.append(pred_raster_imgs)
pred_raster_imgs_rgb_list.append(pred_raster_imgs_rgb)
if not self.hps.use_softargmax:
pen_state_soft = pen_ras[:, 1:2] # (N * max_seq_len, 1)
else:
pen_state_soft = self.differentiable_argmax(pen_ras, self.hps.soft_beta) # (N * max_seq_len, 1)
pred_params = tf.concat([pen_state_soft, other_params], axis=1) # (N * max_seq_len, 7)
pred_params = tf.reshape(pred_params, shape=[-1, self.hps.max_seq_len, 7]) # (N, max_seq_len, 7)
# pred_params: (N, max_seq_len, 7)
if self.hps.model_mode == 'train' or self.hps.model_mode == 'eval':
raster_cost, sn_cost, cursor_outside_cost, winsize_outside_cost, \
early_pen_states_cost, \
perc_relu_loss_raw, perc_relu_loss_norm = \
self.build_losses(split_target_sketch, pred_raster_imgs, pred_params,
pos_before_max_min, win_size_before_max_min,
split_image_size)
# perc_relu_loss_raw, perc_relu_loss_norm: (n_layers)
ras_loss_list.append(raster_cost)
perc_relu_raw_list.append(perc_relu_loss_raw)
perc_relu_norm_list.append(perc_relu_loss_norm)
sn_loss_list.append(sn_cost)
cursor_outside_loss_list.append(cursor_outside_cost)
win_size_outside_loss_list.append(winsize_outside_cost)
early_state_loss_list.append(early_pen_states_cost)
if self.hps.model_mode == 'train':
total_cost_split, grads_and_vars_split = self.build_training_op_split(
raster_cost, sn_cost, cursor_outside_cost, winsize_outside_cost,
early_pen_states_cost)
total_loss_list.append(total_cost_split)
tower_grads.append(grads_and_vars_split)
self.raster_cost = tf.reduce_mean(tf.stack(ras_loss_list, axis=0))
self.perc_relu_losses_raw = tf.reduce_mean(tf.stack(perc_relu_raw_list, axis=0), axis=0) # (n_layers)
self.perc_relu_losses_norm = tf.reduce_mean(tf.stack(perc_relu_norm_list, axis=0), axis=0) # (n_layers)
self.stroke_num_cost = tf.reduce_mean(tf.stack(sn_loss_list, axis=0))
self.pos_outside_cost = tf.reduce_mean(tf.stack(cursor_outside_loss_list, axis=0))
self.win_size_outside_cost = tf.reduce_mean(tf.stack(win_size_outside_loss_list, axis=0))
self.early_pen_states_cost = tf.reduce_mean(tf.stack(early_state_loss_list, axis=0))
self.cost = tf.reduce_mean(tf.stack(total_loss_list, axis=0))
self.pred_raster_imgs = tf.concat(pred_raster_imgs_list, axis=0) # (N, image_size, image_size), [0.0-stroke, 1.0-BG]
self.pred_raster_imgs_rgb = tf.concat(pred_raster_imgs_rgb_list, axis=0) # (N, image_size, image_size, 3)
if self.hps.model_mode == 'train':
self.build_training_op(tower_grads)
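build_training_op (defined later in this file) consumes the per-tower gradients collected above. The usual multi-GPU pattern is to average each variable's gradient across towers; the following is a hypothetical sketch with plain floats standing in for tensors, not the repo's implementation.

```python
def average_tower_grads(tower_grads):
    """Average per-variable gradients across towers.
    tower_grads: list (one entry per tower) of lists of (grad, var) pairs,
    with the same variable order in every tower."""
    averaged = []
    for grad_var_pairs in zip(*tower_grads):   # pair i from every tower
        grads = [g for g, _ in grad_var_pairs]
        var = grad_var_pairs[0][1]             # the variable is shared across towers
        averaged.append((sum(grads) / len(grads), var))
    return averaged
```

For two towers with gradients [(1.0, 'w'), (2.0, 'b')] and [(3.0, 'w'), (4.0, 'b')], this yields [(2.0, 'w'), (3.0, 'b')].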
def config_model(self):
if self.hps.model_mode == 'train':
self.global_step = tf.Variable(0, name='global_step', trainable=False)
if self.hps.dec_model == 'lstm':
dec_cell_fn = rnn.LSTMCell
elif self.hps.dec_model == 'layer_norm':
dec_cell_fn = rnn.LayerNormLSTMCell
elif self.hps.dec_model == 'hyper':
dec_cell_fn = rnn.HyperLSTMCell
else:
assert False, 'please choose a respectable cell'
use_recurrent_dropout = self.hps.use_recurrent_dropout
use_input_dropout = self.hps.use_input_dropout
use_output_dropout = self.hps.use_output_dropout
dec_cell = dec_cell_fn(
self.hps.dec_rnn_size,
use_recurrent_dropout=use_recurrent_dropout,
dropout_keep_prob=self.hps.recurrent_dropout_prob)
# dropout:
# print('Input dropout mode = %s.' % use_input_dropout)
# print('Output dropout mode = %s.' % use_output_dropout)
# print('Recurrent dropout mode = %s.' % use_recurrent_dropout)
if use_input_dropout:
print('Dropout to input w/ keep_prob = %4.4f.' % self.hps.input_dropout_prob)
dec_cell = tf.contrib.rnn.DropoutWrapper(
dec_cell, input_keep_prob=self.hps.input_dropout_prob)
if use_output_dropout:
print('Dropout to output w/ keep_prob = %4.4f.' % self.hps.output_dropout_prob)
dec_cell = tf.contrib.rnn.DropoutWrapper(
dec_cell, output_keep_prob=self.hps.output_dropout_prob)
self.dec_cell = dec_cell
self.total_loop = len(self.hps.gpus) * self.hps.loop_per_gpu
self.init_cursor = tf.placeholder(
dtype=tf.float32,
shape=[self.hps.batch_size, 1, 2]) # (N, 1, 2), in size [0.0, 1.0)
self.init_width = tf.placeholder(
dtype=tf.float32,
shape=[1]) # (1), in [0.0, 1.0]
self.image_size = tf.placeholder(dtype=tf.int32, shape=(self.total_loop)) # (total_loop)
self.init_cursor_list = tf.split(self.init_cursor, self.total_loop, axis=0)
self.input_photo_list = []
for loop_i in range(self.total_loop):
input_photo_i = tf.placeholder(dtype=tf.float32, shape=[None, None, None, self.hps.input_channel]) # [0.0-stroke, 1.0-BG]
self.input_photo_list.append(input_photo_i)
if self.hps.input_channel == 3:
self.target_sketch_list = []
for loop_i in range(self.total_loop):
target_sketch_i = tf.placeholder(dtype=tf.float32, shape=[None, None, None, 1]) # [0.0-stroke, 1.0-BG]
self.target_sketch_list.append(target_sketch_i)
if self.hps.model_mode == 'train' or self.hps.model_mode == 'eval':
self.stroke_num_loss_weight = tf.Variable(0.0, trainable=False)
self.early_pen_loss_start_idx = tf.Variable(0, dtype=tf.int32, trainable=False)
self.early_pen_loss_end_idx = tf.Variable(0, dtype=tf.int32, trainable=False)
if self.hps.model_mode == 'train':
self.perc_loss_mean_list = []
for loop_i in range(len(self.hps.perc_loss_layers)):
relu_loss_mean = tf.Variable(0.0, trainable=False)
self.perc_loss_mean_list.append(relu_loss_mean)
self.last_step_num = tf.Variable(0.0, trainable=False)
with tf.variable_scope('train_op', reuse=tf.AUTO_REUSE):
self.lr = tf.Variable(self.hps.learning_rate, trainable=False)
self.optimizer = tf.train.AdamOptimizer(self.lr)
###########################
def normalize_image_m1to1(self, in_img_0to1):
norm_img_m1to1 = tf.multiply(in_img_0to1, 2.0)
norm_img_m1to1 = tf.subtract(norm_img_m1to1, 1.0)
return norm_img_m1to1
def add_coords(self, input_tensor):
batch_size_tensor = tf.shape(input_tensor)[0] # get N size
xx_ones = tf.ones([batch_size_tensor, self.hps.raster_size], dtype=tf.int32) # e.g. (N, raster_size)
xx_ones = tf.expand_dims(xx_ones, -1) # e.g. (N, raster_size, 1)
xx_range = tf.tile(tf.expand_dims(tf.range(self.hps.raster_size), 0),
[batch_size_tensor, 1]) # e.g. (N, raster_size)
xx_range = tf.expand_dims(xx_range, 1) # e.g. (N, 1, raster_size)
xx_channel = tf.matmul(xx_ones, xx_range) # e.g. (N, raster_size, raster_size)
xx_channel = tf.expand_dims(xx_channel, -1) # e.g. (N, raster_size, raster_size, 1)
yy_ones = tf.ones([batch_size_tensor, self.hps.raster_size], dtype=tf.int32) # e.g. (N, raster_size)
yy_ones = tf.expand_dims(yy_ones, 1) # e.g. (N, 1, raster_size)
yy_range = tf.tile(tf.expand_dims(tf.range(self.hps.raster_size), 0),
[batch_size_tensor, 1]) # (N, raster_size)
yy_range = tf.expand_dims(yy_range, -1) # e.g. (N, raster_size, 1)
yy_channel = tf.matmul(yy_range, yy_ones) # e.g. (N, raster_size, raster_size)
yy_channel = tf.expand_dims(yy_channel, -1) # e.g. (N, raster_size, raster_size, 1)
xx_channel = tf.cast(xx_channel, 'float32') / (self.hps.raster_size - 1)
yy_channel = tf.cast(yy_channel, 'float32') / (self.hps.raster_size - 1)
# xx_channel = xx_channel * 2 - 1 # [-1, 1]
# yy_channel = yy_channel * 2 - 1
ret = tf.concat([
input_tensor,
xx_channel,
yy_channel,
], axis=-1) # e.g. (N, raster_size, raster_size, 4)
return ret
def build_combined_encoder(self, patch_canvas, patch_photo, entire_canvas, entire_photo, cursor_pos,
image_size, window_size):
"""
:param patch_canvas: (N, raster_size, raster_size, 1), [-1.0-stroke, 1.0-BG]
:param patch_photo: (N, raster_size, raster_size, 1/3), [-1.0-stroke, 1.0-BG]
:param entire_canvas: (N, image_size, image_size, 1), [0.0-stroke, 1.0-BG]
:param entire_photo: (N, image_size, image_size, 1/3), [0.0-stroke, 1.0-BG]
:param cursor_pos: (N, 1, 2), normalized to [0.0, 1.0)
:param window_size: (N, 1, 1), float, in full-image pixel units
:return:
"""
if self.hps.resize_method == 'BILINEAR':
resize_method = tf.image.ResizeMethod.BILINEAR
elif self.hps.resize_method == 'NEAREST_NEIGHBOR':
resize_method = tf.image.ResizeMethod.NEAREST_NEIGHBOR
elif self.hps.resize_method == 'BICUBIC':
resize_method = tf.image.ResizeMethod.BICUBIC
elif self.hps.resize_method == 'AREA':
resize_method = tf.image.ResizeMethod.AREA
else:
raise Exception('unknown resize_method', self.hps.resize_method)
patch_photo = tf.stop_gradient(patch_photo)
patch_canvas = tf.stop_gradient(patch_canvas)
cursor_pos = tf.stop_gradient(cursor_pos)
window_size = tf.stop_gradient(window_size)
entire_photo_small = tf.stop_gradient(tf.image.resize_images(entire_photo,
(self.hps.raster_size, self.hps.raster_size),
method=resize_method))
entire_canvas_small = tf.stop_gradient(tf.image.resize_images(entire_canvas,
(self.hps.raster_size, self.hps.raster_size),
method=resize_method))
entire_photo_small = self.normalize_image_m1to1(entire_photo_small) # [-1.0-stroke, 1.0-BG]
entire_canvas_small = self.normalize_image_m1to1(entire_canvas_small) # [-1.0-stroke, 1.0-BG]
if self.hps.encode_cursor_type == 'value':
cursor_pos_norm = tf.expand_dims(cursor_pos, axis=1) # (N, 1, 1, 2)
cursor_pos_norm = tf.tile(cursor_pos_norm, [1, self.hps.raster_size, self.hps.raster_size, 1])
cursor_info = cursor_pos_norm
else:
raise Exception('Unknown encode_cursor_type', self.hps.encode_cursor_type)
batch_input_combined = tf.concat([patch_photo, patch_canvas, entire_photo_small, entire_canvas_small, cursor_info],
axis=-1) # [N, raster_size, raster_size, 6/10]
batch_input_local = tf.concat([patch_photo, patch_canvas], axis=-1) # [N, raster_size, raster_size, 2/4]
batch_input_global = tf.concat([entire_photo_small, entire_canvas_small, cursor_info],
axis=-1) # [N, raster_size, raster_size, 4/6]
if self.hps.model_mode == 'train':
is_training = True
dropout_keep_prob = self.hps.pix_drop_kp
else:
is_training = False
dropout_keep_prob = 1.0
if self.hps.add_coordconv:
batch_input_combined = self.add_coords(batch_input_combined) # (N, in_H, in_W, in_dim + 2)
batch_input_local = self.add_coords(batch_input_local) # (N, in_H, in_W, in_dim + 2)
batch_input_global = self.add_coords(batch_input_global) # (N, in_H, in_W, in_dim + 2)
if 'combine' in self.hps.encoder_type:
if self.hps.encoder_type == 'combine33':
image_embedding, _ = generative_cnn_c3_encoder_combine33(batch_input_local, batch_input_global,
is_training, dropout_keep_prob) # (N, 128)
elif self.hps.encoder_type == 'combine43':
image_embedding, _ = generative_cnn_c3_encoder_combine43(batch_input_local, batch_input_global,
is_training, dropout_keep_prob) # (N, 128)
elif self.hps.encoder_type == 'combine53':
image_embedding, _ = generative_cnn_c3_encoder_combine53(batch_input_local, batch_input_global,
is_training, dropout_keep_prob) # (N, 128)
elif self.hps.encoder_type == 'combineFC':
image_embedding, _ = generative_cnn_c3_encoder_combineFC(batch_input_local, batch_input_global,
is_training, dropout_keep_prob) # (N, 256)
else:
raise Exception('Unknown encoder_type', self.hps.encoder_type)
else:
with tf.variable_scope('Combined_Encoder', reuse=tf.AUTO_REUSE):
if self.hps.encoder_type == 'conv10':
image_embedding, _ = generative_cnn_encoder(batch_input_combined, is_training, dropout_keep_prob) # (N, 128)
elif self.hps.encoder_type == 'conv10_deep':
image_embedding, _ = generative_cnn_encoder_deeper(batch_input_combined, is_training, dropout_keep_prob) # (N, 512)
elif self.hps.encoder_type == 'conv13':
image_embedding, _ = generative_cnn_encoder_deeper13(batch_input_combined, is_training, dropout_keep_prob) # (N, 128)
elif self.hps.encoder_type == 'conv10_c3':
image_embedding, _ = generative_cnn_c3_encoder(batch_input_combined, is_training, dropout_keep_prob) # (N, 128)
elif self.hps.encoder_type == 'conv10_deep_c3':
image_embedding, _ = generative_cnn_c3_encoder_deeper(batch_input_combined, is_training, dropout_keep_prob) # (N, 512)
elif self.hps.encoder_type == 'conv13_c3':
image_embedding, _ = generative_cnn_c3_encoder_deeper13(batch_input_combined, is_training, dropout_keep_prob) # (N, 128)
elif self.hps.encoder_type == 'conv13_c3_attn':
image_embedding, _ = generative_cnn_c3_encoder_deeper13_attn(batch_input_combined, is_training, dropout_keep_prob) # (N, 128)
else:
raise Exception('Unknown encoder_type', self.hps.encoder_type)
return image_embedding
def build_seq_decoder(self, dec_cell, actual_input_x, initial_state):
rnn_output, last_state = self.rnn_decoder(dec_cell, initial_state, actual_input_x)
rnn_output_flat = tf.reshape(rnn_output, [-1, self.hps.dec_rnn_size])
pen_n_out = 2
params_n_out = 6
with tf.variable_scope('DEC_RNN_out_pen', reuse=tf.AUTO_REUSE):
output_w_pen = tf.get_variable('output_w', [self.hps.dec_rnn_size, pen_n_out])
output_b_pen = tf.get_variable('output_b', [pen_n_out], initializer=tf.constant_initializer(0.0))
output_pen = tf.nn.xw_plus_b(rnn_output_flat, output_w_pen, output_b_pen) # (N, pen_n_out)
with tf.variable_scope('DEC_RNN_out_params', reuse=tf.AUTO_REUSE):
output_w_params = tf.get_variable('output_w', [self.hps.dec_rnn_size, params_n_out])
output_b_params = tf.get_variable('output_b', [params_n_out], initializer=tf.constant_initializer(0.0))
output_params = tf.nn.xw_plus_b(rnn_output_flat, output_w_params, output_b_params) # (N, params_n_out)
output = tf.concat([output_pen, output_params], axis=1) # (N, n_out)
return output, last_state
def get_mixture_coef(self, outputs):
z = outputs
z_pen_logits = z[:, 0:2] # (N, 2), pen states
z_other_params_logits = z[:, 2:] # (N, 6)
z_pen = tf.nn.softmax(z_pen_logits) # (N, 2)
if self.hps.position_format == 'abs':
x1y1 = tf.nn.sigmoid(z_other_params_logits[:, 0:2]) # (N, 2)
x2y2 = tf.tanh(z_other_params_logits[:, 2:4]) # (N, 2)
widths = tf.nn.sigmoid(z_other_params_logits[:, 4:5]) # (N, 1)
widths = tf.add(tf.multiply(widths, 1.0 - self.hps.min_width), self.hps.min_width)
scaling = tf.nn.sigmoid(z_other_params_logits[:, 5:6]) * self.hps.max_scaling # (N, 1), [0.0, max_scaling]
# scaling = tf.add(tf.multiply(scaling, (self.hps.max_scaling - self.hps.min_scaling) / self.hps.max_scaling),
# self.hps.min_scaling)
z_other_params = tf.concat([x1y1, x2y2, widths, scaling], axis=-1) # (N, 6)
else: # "rel"
raise Exception('Unknown position_format', self.hps.position_format)
r = [z_other_params, z_pen]
return r
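# The squashing above maps raw decoder logits into working ranges: sigmoid for
# positions and widths (widths then affine-mapped into [min_width, 1.0]), tanh
# for offsets, and sigmoid scaled into [0.0, max_scaling]. A pure-Python sketch
# of the width/scaling mapping (illustrative only; the min_width and
# max_scaling defaults here are placeholders, not the repo's hyperparameters):

```python
import math

def _sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def squash_width(logit, min_width=0.01):
    # sigmoid into (0, 1), then affine-map into [min_width, 1.0]
    return _sigmoid(logit) * (1.0 - min_width) + min_width

def squash_scaling(logit, max_scaling=2.0):
    # sigmoid into (0, 1), then scale into (0, max_scaling)
    return _sigmoid(logit) * max_scaling
```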
###########################
def get_decoder_inputs(self):
initial_state = self.dec_cell.zero_state(batch_size=self.hps.batch_size, dtype=tf.float32)
return initial_state
def rnn_decoder(self, dec_cell, initial_state, actual_input_x):
with tf.variable_scope("RNN_DEC", reuse=tf.AUTO_REUSE):
output, last_state = tf.nn.dynamic_rnn(
dec_cell,
actual_input_x,
initial_state=initial_state,
time_major=False,
swap_memory=True,
dtype=tf.float32)
return output, last_state
###########################
def image_padding(self, ori_image, window_size, pad_value):
"""
Pad with (bg)
:param ori_image:
:return:
"""
paddings = [[0, 0],
[window_size // 2, window_size // 2],
[window_size // 2, window_size // 2],
[0, 0]]
pad_img = tf.pad(ori_image, paddings=paddings, mode='CONSTANT', constant_values=pad_value) # (N, H_p, W_p, k)
return pad_img
def image_cropping_fn(self, fn_inputs):
"""
crop the patch
:return:
"""
index_offset = self.hps.input_channel - 1
input_image = fn_inputs[:, :, 0:2 + index_offset] # (image_size, image_size, 2/4), [0.0-BG, 1.0-stroke]
cursor_pos = fn_inputs[0, 0, 2 + index_offset:4 + index_offset] # (2), in [0.0, 1.0)
image_size = fn_inputs[0, 0, 4 + index_offset] # (), float32
window_size = tf.cast(fn_inputs[0, 0, 5 + index_offset], tf.int32) # ()
input_img_reshape = tf.expand_dims(input_image, axis=0)
pad_img = self.image_padding(input_img_reshape, window_size, pad_value=0.0)
cursor_pos = tf.cast(tf.round(tf.multiply(cursor_pos, image_size)), dtype=tf.int32)
x0, x1 = cursor_pos[0], cursor_pos[0] + window_size # ()
y0, y1 = cursor_pos[1], cursor_pos[1] + window_size # ()
patch_image = pad_img[:, y0:y1, x0:x1, :] # (1, window_size, window_size, 2/4)
# resize to raster_size
patch_image_scaled = tf.image.resize_images(patch_image, (self.hps.raster_size, self.hps.raster_size),
method=tf.image.ResizeMethod.AREA)
patch_image_scaled = tf.squeeze(patch_image_scaled, axis=0)
# patch_image_scaled: (raster_size, raster_size, 2/4), [0.0-BG, 1.0-stroke]
return patch_image_scaled
def image_cropping(self, cursor_position, input_img, image_size, window_sizes):
"""
:param cursor_position: (N, 1, 2), float, in range [0.0, 1.0)
:param input_img: (N, image_size, image_size, 2/4), [0.0-BG, 1.0-stroke]
:param window_sizes: (N, 1, 1), float32, with grad
"""
input_img_ = input_img
window_sizes_non_grad = tf.stop_gradient(tf.round(window_sizes)) # (N, 1, 1), no grad
cursor_position_ = tf.reshape(cursor_position, (-1, 1, 1, 2)) # (N, 1, 1, 2)
cursor_position_ = tf.tile(cursor_position_, [1, image_size, image_size, 1]) # (N, image_size, image_size, 2)
image_size_ = tf.reshape(tf.cast(image_size, tf.float32), (1, 1, 1, 1)) # (1, 1, 1, 1)
image_size_ = tf.tile(image_size_, [self.hps.batch_size // self.total_loop, image_size, image_size, 1])
window_sizes_ = tf.reshape(window_sizes_non_grad, (-1, 1, 1, 1)) # (N, 1, 1, 1)
window_sizes_ = tf.tile(window_sizes_, [1, image_size, image_size, 1]) # (N, image_size, image_size, 1)
fn_inputs = tf.concat([input_img_, cursor_position_, image_size_, window_sizes_],
axis=-1) # (N, image_size, image_size, 2/4 + 4)
curr_patch_imgs = tf.map_fn(self.image_cropping_fn, fn_inputs, parallel_iterations=32) # (N, raster_size, raster_size, -)
return curr_patch_imgs
def image_cropping_v3(self, cursor_position, input_img, image_size, window_sizes):
"""
:param cursor_position: (N, 1, 2), float, in range [0.0, 1.0)
:param input_img: (N, image_size, image_size, k), [0.0-BG, 1.0-stroke]
:param window_sizes: (N, 1, 1), float32, with grad
"""
window_sizes_non_grad = tf.stop_gradient(window_sizes) # (N, 1, 1), no grad
cursor_pos = tf.multiply(cursor_position, tf.cast(image_size, tf.float32))
cursor_x, cursor_y = tf.split(cursor_pos, 2, axis=-1) # (N, 1, 1)
y1 = cursor_y - (window_sizes_non_grad - 1.0) / 2
x1 = cursor_x - (window_sizes_non_grad - 1.0) / 2
y2 = y1 + (window_sizes_non_grad - 1.0)
x2 = x1 + (window_sizes_non_grad - 1.0)
boxes = tf.concat([y1, x1, y2, x2], axis=-1) # (N, 1, 4)
boxes = tf.squeeze(boxes, axis=1) # (N, 4)
boxes = boxes / tf.cast(image_size - 1, tf.float32)
box_ind = tf.ones_like(cursor_x)[:, 0, 0] # (N)
box_ind = tf.cast(box_ind, dtype=tf.int32)
box_ind = tf.cumsum(box_ind) - 1 # equivalent to tf.range(N)
curr_patch_imgs = tf.image.crop_and_resize(input_img, boxes, box_ind,
crop_size=[self.hps.raster_size, self.hps.raster_size])
# (N, raster_size, raster_size, k), [0.0-BG, 1.0-stroke]
return curr_patch_imgs
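# The box construction in image_cropping_v3 above can be checked with a tiny
# pure-Python helper (an illustrative sketch, not the TF code): a window of
# side window_size centred on the cursor becomes corner coordinates
# normalized by (image_size - 1), the convention tf.image.crop_and_resize expects.

```python
def crop_box(cursor_x, cursor_y, window_size, image_size):
    # Corner coordinates [y1, x1, y2, x2] of a window_size box centred on
    # the cursor, normalized by (image_size - 1) for crop_and_resize.
    half = (window_size - 1.0) / 2.0
    y1, x1 = cursor_y - half, cursor_x - half
    y2, x2 = y1 + (window_size - 1.0), x1 + (window_size - 1.0)
    s = image_size - 1.0
    return [y1 / s, x1 / s, y2 / s, x2 / s]
```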
def get_pixel_value(self, img, x, y):
"""
Utility function to get pixel value for coordinate vectors x and y from a 4D tensor image.
Input
-----
- img: tensor of shape (B, H, W, C)
- x: tensor of shape (B, H', W') of integer x coordinates
- y: tensor of shape (B, H', W') of integer y coordinates
Returns
-------
- output: tensor of shape (B, H', W', C)
"""
shape = tf.shape(x)
batch_size = shape[0]
height = shape[1]
width = shape[2]
batch_idx = tf.range(0, batch_size)
batch_idx = tf.reshape(batch_idx, (batch_size, 1, 1))
b = tf.tile(batch_idx, (1, height, width))
indices = tf.stack([b, y, x], 3)
return tf.gather_nd(img, indices)
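# get_pixel_value is a batched gather: output[b, i, j] = img[b, y[b,i,j], x[b,i,j]].
# A hypothetical nested-list analogue of the same indexing (sketch only,
# scalars instead of channel vectors):

```python
def gather_pixels(img, xs, ys):
    # img: [B][H][W] nested lists; xs, ys: [B][H'][W'] integer coordinates.
    # Returns out[b][i][j] = img[b][ys[b][i][j]][xs[b][i][j]].
    return [[[img[b][ys[b][i][j]][xs[b][i][j]]
              for j in range(len(xs[b][i]))]
             for i in range(len(xs[b]))]
            for b in range(len(xs))]
```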
def image_pasting_nondiff_single(self, fn_inputs):
patch_image = fn_inputs[:, :, 0:1] # (raster_size, raster_size, 1), [0.0-BG, 1.0-stroke]
cursor_pos = fn_inputs[0, 0, 1:3] # (2), in large size
image_size = tf.cast(fn_inputs[0, 0, 3], tf.int32) # ()
window_size = tf.cast(fn_inputs[0, 0, 4], tf.int32) # ()
patch_image_scaled = tf.expand_dims(patch_image, axis=0) # (1, raster_size, raster_size, 1)
patch_image_scaled = tf.image.resize_images(patch_image_scaled, (window_size, window_size),
method=tf.image.ResizeMethod.BILINEAR)
patch_image_scaled = tf.squeeze(patch_image_scaled, axis=0)
# patch_image_scaled: (window_size, window_size, 1)
cursor_pos = tf.cast(tf.round(cursor_pos), dtype=tf.int32) # (2)
cursor_x, cursor_y = cursor_pos[0], cursor_pos[1]
pad_up = cursor_y
pad_down = image_size - cursor_y
pad_left = cursor_x
pad_right = image_size - cursor_x
paddings = [[pad_up, pad_down],
[pad_left, pad_right],
[0, 0]]
pad_img = tf.pad(patch_image_scaled, paddings=paddings, mode='CONSTANT',
constant_values=0.0) # (H_p, W_p, 1), [0.0-BG, 1.0-stroke]
crop_start = window_size // 2
pasted_image = pad_img[crop_start: crop_start + image_size, crop_start: crop_start + image_size, :]
return pasted_image
def image_pasting_diff_single(self, fn_inputs):
patch_canvas = fn_inputs[:, :, 0:1] # (raster_size, raster_size, 1), [0.0-BG, 1.0-stroke]
cursor_pos = fn_inputs[0, 0, 1:3] # (2), in large size
image_size = tf.cast(fn_inputs[0, 0, 3], tf.int32) # ()
window_size = tf.cast(fn_inputs[0, 0, 4], tf.int32) # ()
cursor_x, cursor_y = cursor_pos[0], cursor_pos[1]
patch_canvas_scaled = tf.expand_dims(patch_canvas, axis=0) # (1, raster_size, raster_size, 1)
patch_canvas_scaled = tf.image.resize_images(patch_canvas_scaled, (window_size, window_size),
method=tf.image.ResizeMethod.BILINEAR)
# patch_canvas_scaled: (1, window_size, window_size, 1)
valid_canvas = self.image_pasting_diff_batch(patch_canvas_scaled,
tf.expand_dims(tf.expand_dims(cursor_pos, axis=0), axis=0),
window_size)
valid_canvas = tf.squeeze(valid_canvas, axis=0)
# (window_size + 1, window_size + 1, 1)
pad_up = tf.cast(tf.floor(cursor_y), tf.int32)
pad_down = image_size - 1 - tf.cast(tf.floor(cursor_y), tf.int32)
pad_left = tf.cast(tf.floor(cursor_x), tf.int32)
pad_right = image_size - 1 - tf.cast(tf.floor(cursor_x), tf.int32)
paddings = [[pad_up, pad_down],
[pad_left, pad_right],
[0, 0]]
pad_img = tf.pad(valid_canvas, paddings=paddings, mode='CONSTANT',
constant_values=0.0) # (H_p, W_p, 1), [0.0-BG, 1.0-stroke]
crop_start = window_size // 2
pasted_image = pad_img[crop_start: crop_start + image_size, crop_start: crop_start + image_size, :]
return pasted_image
def image_pasting_diff_single_v3(self, fn_inputs):
patch_canvas = fn_inputs[:, :, 0:1] # (raster_size, raster_size, 1), [0.0-BG, 1.0-stroke]
cursor_pos_a = fn_inputs[0, 0, 1:3] # (2), float32, in large size
image_size_a = tf.cast(fn_inputs[0, 0, 3], tf.int32) # ()
window_size_a = fn_inputs[0, 0, 4] # (), float32, with grad
raster_size_a = float(self.hps.raster_size)
padding_size = tf.cast(tf.ceil(window_size_a / 2.0), tf.int32)
x1y1_a = cursor_pos_a - window_size_a / 2.0 # (2), float32
x2y2_a = cursor_pos_a + window_size_a / 2.0 # (2), float32
x1y1_a_floor = tf.floor(x1y1_a) # (2)
x2y2_a_ceil = tf.ceil(x2y2_a) # (2)
cursor_pos_b_oricoord = (x1y1_a_floor + x2y2_a_ceil) / 2.0 # (2)
cursor_pos_b = (cursor_pos_b_oricoord - x1y1_a) / window_size_a * raster_size_a # (2)
raster_size_b = (x2y2_a_ceil - x1y1_a_floor) # (x, y)
image_size_b = raster_size_a
window_size_b = raster_size_a * (raster_size_b / window_size_a) # (x, y)
cursor_b_x, cursor_b_y = tf.split(cursor_pos_b, 2, axis=-1) # (1)
y1_b = cursor_b_y - (window_size_b[1] - 1.) / 2.
x1_b = cursor_b_x - (window_size_b[0] - 1.) / 2.
y2_b = y1_b + (window_size_b[1] - 1.)
x2_b = x1_b + (window_size_b[0] - 1.)
boxes_b = tf.concat([y1_b, x1_b, y2_b, x2_b], axis=-1) # (4)
boxes_b = boxes_b / tf.cast(image_size_b - 1, tf.float32) # with grad to window_size_a
box_ind_b = tf.ones((1), dtype=tf.int32) # (1)
box_ind_b = tf.cumsum(box_ind_b) - 1
patch_canvas = tf.expand_dims(patch_canvas, axis=0) # (1, raster_size, raster_size, 1), [0.0-BG, 1.0-stroke]
boxes_b = tf.expand_dims(boxes_b, axis=0) # (1, 4)
valid_canvas = tf.image.crop_and_resize(patch_canvas, boxes_b, box_ind_b,
crop_size=[raster_size_b[1], raster_size_b[0]])
valid_canvas = valid_canvas[0] # (raster_size_b, raster_size_b, 1)
pad_up = tf.cast(x1y1_a_floor[1], tf.int32) + padding_size
pad_down = image_size_a + padding_size - tf.cast(x2y2_a_ceil[1], tf.int32)
pad_left = tf.cast(x1y1_a_floor[0], tf.int32) + padding_size
pad_right = image_size_a + padding_size - tf.cast(x2y2_a_ceil[0], tf.int32)
paddings = [[pad_up, pad_down],
[pad_left, pad_right],
[0, 0]]
pad_img = tf.pad(valid_canvas, paddings=paddings, mode='CONSTANT',
constant_values=0.0) # (H_p, W_p, 1), [0.0-BG, 1.0-stroke]
pasted_image = pad_img[padding_size: padding_size + image_size_a, padding_size: padding_size + image_size_a, :]
return pasted_image
def image_pasting_diff_batch(self, patch_image, cursor_position, window_size):
"""
:param patch_img: (N, window_size, window_size, 1), [0.0-BG, 1.0-stroke]
:param cursor_position: (N, 1, 2), in large size
:return:
"""
paddings1 = [[0, 0],
[1, 1],
[1, 1],
[0, 0]]
patch_image_pad1 = tf.pad(patch_image, paddings=paddings1, mode='CONSTANT',
constant_values=0.0) # (N, window_size+2, window_size+2, 1), [0.0-BG, 1.0-stroke]
cursor_x, cursor_y = cursor_position[:, :, 0:1], cursor_position[:, :, 1:2] # (N, 1, 1)
cursor_x_f, cursor_y_f = tf.floor(cursor_x), tf.floor(cursor_y)
patch_x, patch_y = 1.0 - (cursor_x - cursor_x_f), 1.0 - (cursor_y - cursor_y_f) # (N, 1, 1)
x_ones = tf.ones_like(patch_x, dtype=tf.float32) # (N, 1, 1)
x_ones = tf.tile(x_ones, [1, 1, window_size]) # (N, 1, window_size)
patch_x = tf.concat([patch_x, x_ones], axis=-1) # (N, 1, window_size + 1)
patch_x = tf.tile(patch_x, [1, window_size + 1, 1]) # (N, window_size + 1, window_size + 1)
patch_x = tf.cumsum(patch_x, axis=-1) # (N, window_size + 1, window_size + 1)
patch_x0 = tf.cast(tf.floor(patch_x), tf.int32) # (N, window_size + 1, window_size + 1)
patch_x1 = patch_x0 + 1 # (N, window_size + 1, window_size + 1)
y_ones = tf.ones_like(patch_y, dtype=tf.float32) # (N, 1, 1)
y_ones = tf.tile(y_ones, [1, window_size, 1]) # (N, window_size, 1)
patch_y = tf.concat([patch_y, y_ones], axis=1) # (N, window_size + 1, 1)
patch_y = tf.tile(patch_y, [1, 1, window_size + 1]) # (N, window_size + 1, window_size + 1)
patch_y = tf.cumsum(patch_y, axis=1) # (N, window_size + 1, window_size + 1)
patch_y0 = tf.cast(tf.floor(patch_y), tf.int32) # (N, window_size + 1, window_size + 1)
patch_y1 = patch_y0 + 1 # (N, window_size + 1, window_size + 1)
# get pixel value at corner coords
valid_canvas_patch_a = self.get_pixel_value(patch_image_pad1, patch_x0, patch_y0)
valid_canvas_patch_b = self.get_pixel_value(patch_image_pad1, patch_x0, patch_y1)
valid_canvas_patch_c = self.get_pixel_value(patch_image_pad1, patch_x1, patch_y0)
valid_canvas_patch_d = self.get_pixel_value(patch_image_pad1, patch_x1, patch_y1)
# (N, window_size + 1, window_size + 1, 1)
patch_x0 = tf.cast(patch_x0, tf.float32)
patch_x1 = tf.cast(patch_x1, tf.float32)
patch_y0 = tf.cast(patch_y0, tf.float32)
patch_y1 = tf.cast(patch_y1, tf.float32)
# calculate deltas
wa = (patch_x1 - patch_x) * (patch_y1 - patch_y)
wb = (patch_x1 - patch_x) * (patch_y - patch_y0)
wc = (patch_x - patch_x0) * (patch_y1 - patch_y)
wd = (patch_x - patch_x0) * (patch_y - patch_y0)
# (N, window_size + 1, window_size + 1)
# add dimension for addition
wa = tf.expand_dims(wa, axis=3)
wb = tf.expand_dims(wb, axis=3)
wc = tf.expand_dims(wc, axis=3)
wd = tf.expand_dims(wd, axis=3)
# (N, window_size + 1, window_size + 1, 1)
# compute output
valid_canvas_patch_ = tf.add_n([wa * valid_canvas_patch_a,
wb * valid_canvas_patch_b,
wc * valid_canvas_patch_c,
wd * valid_canvas_patch_d]) # (N, window_size + 1, window_size + 1, 1)
return valid_canvas_patch_
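# The wa/wb/wc/wd terms above are standard bilinear-interpolation weights.
# A pure-Python sketch of the per-pixel weight computation (illustrative only;
# it operates on one fractional coordinate instead of tensors):

```python
import math

def bilinear_weights(x, y):
    # Weights for the four integer neighbours of a fractional (x, y);
    # matches the (x1 - x) * (y1 - y) pattern above and sums to 1.
    x0, y0 = math.floor(x), math.floor(y)
    x1, y1 = x0 + 1, y0 + 1
    wa = (x1 - x) * (y1 - y)  # top-left
    wb = (x1 - x) * (y - y0)  # bottom-left
    wc = (x - x0) * (y1 - y)  # top-right
    wd = (x - x0) * (y - y0)  # bottom-right
    return wa, wb, wc, wd
```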
def image_pasting(self, cursor_position_norm, patch_img, image_size, window_sizes, is_differentiable=False):
"""
Paste patch_img onto the full-size canvas at cursor_position.
:param cursor_position_norm: (N, 1, 2), float, in range [0.0, 1.0)
:param patch_img: (N, raster_size, raster_size), [0.0-BG, 1.0-stroke]
:param window_sizes: (N, 1, 1), float32, with grad
:return:
"""
cursor_position = tf.multiply(cursor_position_norm, tf.cast(image_size, tf.float32)) # in large size
window_sizes_r = tf.round(window_sizes) # (N, 1, 1), no grad
patch_img_ = tf.expand_dims(patch_img, axis=-1) # (N, raster_size, raster_size, 1)
cursor_position_step = tf.reshape(cursor_position, (-1, 1, 1, 2)) # (N, 1, 1, 2)
cursor_position_step = tf.tile(cursor_position_step, [1, self.hps.raster_size, self.hps.raster_size,
1]) # (N, raster_size, raster_size, 2)
image_size_tile = tf.reshape(tf.cast(image_size, tf.float32), (1, 1, 1, 1)) # (1, 1, 1, 1)
image_size_tile = tf.tile(image_size_tile, [self.hps.batch_size // self.total_loop, self.hps.raster_size,
self.hps.raster_size, 1])
window_sizes_tile = tf.reshape(window_sizes_r, (-1, 1, 1, 1)) # (N, 1, 1, 1)
window_sizes_tile = tf.tile(window_sizes_tile, [1, self.hps.raster_size, self.hps.raster_size, 1])
pasting_inputs = tf.concat([patch_img_, cursor_position_step, image_size_tile, window_sizes_tile],
axis=-1) # (N, raster_size, raster_size, 5)
if is_differentiable:
curr_paste_imgs = tf.map_fn(self.image_pasting_diff_single, pasting_inputs,
parallel_iterations=32) # (N, image_size, image_size, 1)
else:
curr_paste_imgs = tf.map_fn(self.image_pasting_nondiff_single, pasting_inputs,
parallel_iterations=32) # (N, image_size, image_size, 1)
curr_paste_imgs = tf.squeeze(curr_paste_imgs, axis=-1) # (N, image_size, image_size)
return curr_paste_imgs
def image_pasting_v3(self, cursor_position_norm, patch_img, image_size, window_sizes, is_differentiable=False):
"""
Paste patch_img onto the full-size canvas at cursor_position.
:param cursor_position_norm: (N, 1, 2), float, in range [0.0, 1.0)
:param patch_img: (N, raster_size, raster_size), [0.0-BG, 1.0-stroke]
:param window_sizes: (N, 1, 1), float32, with grad
:return:
"""
cursor_position = tf.multiply(cursor_position_norm, tf.cast(image_size, tf.float32)) # in large size
if is_differentiable:
patch_img_ = tf.expand_dims(patch_img, axis=-1) # (N, raster_size, raster_size, 1)
cursor_position_step = tf.reshape(cursor_position, (-1, 1, 1, 2)) # (N, 1, 1, 2)
cursor_position_step = tf.tile(cursor_position_step, [1, self.hps.raster_size, self.hps.raster_size,
1]) # (N, raster_size, raster_size, 2)
image_size_tile = tf.reshape(tf.cast(image_size, tf.float32), (1, 1, 1, 1)) # (1, 1, 1, 1)
image_size_tile = tf.tile(image_size_tile, [self.hps.batch_size // self.total_loop, self.hps.raster_size,
self.hps.raster_size, 1])
window_sizes_tile = tf.reshape(window_sizes, (-1, 1, 1, 1)) # (N, 1, 1, 1)
window_sizes_tile = tf.tile(window_sizes_tile, [1, self.hps.raster_size, self.hps.raster_size, 1])
pasting_inputs = tf.concat([patch_img_, cursor_position_step, image_size_tile, window_sizes_tile],
axis=-1) # (N, raster_size, raster_size, 5)
curr_paste_imgs = tf.map_fn(self.image_pasting_diff_single_v3, pasting_inputs,
parallel_iterations=32) # (N, image_size, image_size, 1)
else:
raise Exception('Unfinished...')
curr_paste_imgs = tf.squeeze(curr_paste_imgs, axis=-1) # (N, image_size, image_size)
return curr_paste_imgs
def get_points_and_raster_image(self, initial_state, init_cursor, input_photo, image_size):
## Generate other_params, pen_ras, and the rasterized image for the raster loss.
prev_state = initial_state # (N, dec_rnn_size * 3)
prev_width = self.init_width # (1)
prev_width = tf.expand_dims(tf.expand_dims(prev_width, axis=0), axis=0) # (1, 1, 1)
prev_width = tf.tile(prev_width, [self.hps.batch_size // self.total_loop, 1, 1]) # (N, 1, 1)
prev_scaling = tf.ones((self.hps.batch_size // self.total_loop, 1, 1)) # (N, 1, 1)
prev_window_size = tf.ones((self.hps.batch_size // self.total_loop, 1, 1),
dtype=tf.float32) * float(self.hps.raster_size) # (N, 1, 1)
cursor_position_temp = init_cursor
self.cursor_position = cursor_position_temp # (N, 1, 2), in size [0.0, 1.0)
cursor_position_loop = self.cursor_position
other_params_list = []
pen_ras_list = []
pos_before_max_min_list = []
win_size_before_max_min_list = []
curr_canvas_soft = tf.zeros_like(input_photo[:, :, :, 0]) # (N, image_size, image_size), [0.0-BG, 1.0-stroke]
curr_canvas_soft_rgb = tf.tile(tf.zeros_like(input_photo[:, :, :, 0:1]), [1, 1, 1, 3]) # (N, image_size, image_size, 3), [0.0-BG, 1.0-stroke]
curr_canvas_hard = tf.zeros_like(curr_canvas_soft) # [0.0-BG, 1.0-stroke]
#### sampling part - start ####
self.curr_canvas_hard = curr_canvas_hard
rasterizor_st = NeuralRasterizorStep(
raster_size=self.hps.raster_size,
position_format=self.hps.position_format)
if self.hps.cropping_type == 'v3':
cropping_func = self.image_cropping_v3
# elif self.hps.cropping_type == 'v2':
# cropping_func = self.image_cropping
else:
raise Exception('Unknown cropping_type', self.hps.cropping_type)
if self.hps.pasting_type == 'v3':
pasting_func = self.image_pasting_v3
# elif self.hps.pasting_type == 'v2':
# pasting_func = self.image_pasting
else:
raise Exception('Unknown pasting_type', self.hps.pasting_type)
for time_i in range(self.hps.max_seq_len):
cursor_position_non_grad = tf.stop_gradient(cursor_position_loop) # (N, 1, 2), in size [0.0, 1.0)
curr_window_size = tf.multiply(prev_scaling, tf.stop_gradient(prev_window_size)) # float, with grad
curr_window_size = tf.maximum(curr_window_size, tf.cast(self.hps.min_window_size, tf.float32))
curr_window_size = tf.minimum(curr_window_size, tf.cast(image_size, tf.float32))
## patch-level encoding
# Here, we make the gradients from canvas_z to curr_canvas_hard be None to avoid recurrent gradient propagation.
curr_canvas_hard_non_grad = tf.stop_gradient(self.curr_canvas_hard)
curr_canvas_hard_non_grad = tf.expand_dims(curr_canvas_hard_non_grad, axis=-1)
# input_photo: (N, image_size, image_size, 1/3), [0.0-stroke, 1.0-BG]
crop_inputs = tf.concat([1.0 - input_photo, curr_canvas_hard_non_grad], axis=-1) # (N, H_p, W_p, 1/3+1)
cropped_outputs = cropping_func(cursor_position_non_grad, crop_inputs, image_size, curr_window_size)
index_offset = self.hps.input_channel - 1
curr_patch_inputs = cropped_outputs[:, :, :, 0:1 + index_offset] # [0.0-BG, 1.0-stroke]
curr_patch_canvas_hard_non_grad = cropped_outputs[:, :, :, 1 + index_offset:2 + index_offset]
# (N, raster_size, raster_size, 1), [0.0-BG, 1.0-stroke]
curr_patch_inputs = 1.0 - curr_patch_inputs # [0.0-stroke, 1.0-BG]
curr_patch_inputs = self.normalize_image_m1to1(curr_patch_inputs)
# (N, raster_size, raster_size, 1/3), [-1.0-stroke, 1.0-BG]
# Normalizing image
curr_patch_canvas_hard_non_grad = 1.0 - curr_patch_canvas_hard_non_grad # [0.0-stroke, 1.0-BG]
curr_patch_canvas_hard_non_grad = self.normalize_image_m1to1(curr_patch_canvas_hard_non_grad) # [-1.0-stroke, 1.0-BG]
## image-level encoding
combined_z = self.build_combined_encoder(
curr_patch_canvas_hard_non_grad,
curr_patch_inputs,
1.0 - curr_canvas_hard_non_grad,
input_photo,
cursor_position_non_grad,
image_size,
curr_window_size) # (N, z_size)
combined_z = tf.expand_dims(combined_z, axis=1) # (N, 1, z_size)
curr_window_size_top_side_norm_non_grad = \
tf.stop_gradient(curr_window_size / tf.cast(image_size, tf.float32))
curr_window_size_bottom_side_norm_non_grad = \
tf.stop_gradient(curr_window_size / tf.cast(self.hps.min_window_size, tf.float32))
if not self.hps.concat_win_size:
combined_z = tf.concat([tf.stop_gradient(prev_width), combined_z], 2) # (N, 1, 2+z_size)
else:
combined_z = tf.concat([tf.stop_gradient(prev_width),
curr_window_size_top_side_norm_non_grad,
curr_window_size_bottom_side_norm_non_grad,
combined_z],
2) # (N, 1, 2+z_size)
if self.hps.concat_cursor:
prev_input_x = tf.concat([cursor_position_non_grad, combined_z], 2) # (N, 1, 2+2+z_size)
else:
prev_input_x = combined_z # (N, 1, 2+z_size)
h_output, next_state = self.build_seq_decoder(self.dec_cell, prev_input_x, prev_state)
# h_output: (N * 1, n_out), next_state: (N, dec_rnn_size * 3)
[o_other_params, o_pen_ras] = self.get_mixture_coef(h_output)
# o_other_params: (N * 1, 6)
# o_pen_ras: (N * 1, 2), after softmax
o_other_params = tf.reshape(o_other_params, [-1, 1, 6]) # (N, 1, 6)
o_pen_ras_raw = tf.reshape(o_pen_ras, [-1, 1, 2]) # (N, 1, 2)
other_params_list.append(o_other_params)
pen_ras_list.append(o_pen_ras_raw)
#### sampling part - end ####
if self.hps.model_mode == 'train' or self.hps.model_mode == 'eval' or self.hps.model_mode == 'eval_sample':
# use renderer here to convert the strokes to image
curr_other_params = tf.squeeze(o_other_params, axis=1) # (N, 6), (x1, y1)=[0.0, 1.0], (x2, y2)=[-1.0, 1.0]
x1y1, x2y2, width2, scaling = curr_other_params[:, 0:2], curr_other_params[:, 2:4],\
curr_other_params[:, 4:5], curr_other_params[:, 5:6]
x0y0 = tf.zeros_like(x2y2) # (N, 2), [-1.0, 1.0]
x0y0 = tf.div(tf.add(x0y0, 1.0), 2.0) # (N, 2), [0.0, 1.0]
x2y2 = tf.div(tf.add(x2y2, 1.0), 2.0) # (N, 2), [0.0, 1.0]
widths = tf.concat([tf.squeeze(prev_width, axis=1), width2], axis=1) # (N, 2)
curr_other_params = tf.concat([x0y0, x1y1, x2y2, widths], axis=-1) # (N, 8), (x0, y0)&(x2, y2)=[0.0, 1.0]
curr_stroke_image = rasterizor_st.raster_func_stroke_abs(curr_other_params)
# (N, raster_size, raster_size), [0.0-BG, 1.0-stroke]
curr_stroke_image_large = pasting_func(cursor_position_loop, curr_stroke_image,
image_size, curr_window_size,
is_differentiable=self.hps.pasting_diff)
# (N, image_size, image_size), [0.0-BG, 1.0-stroke]
## soft
if not self.hps.use_softargmax:
curr_state_soft = o_pen_ras[:, 1:2] # (N, 1)
else:
curr_state_soft = self.differentiable_argmax(o_pen_ras, self.hps.soft_beta) # (N, 1)
curr_state_soft = tf.expand_dims(curr_state_soft, axis=1) # (N, 1, 1)
filter_curr_stroke_image_soft = tf.multiply(tf.subtract(1.0, curr_state_soft), curr_stroke_image_large)
# (N, image_size, image_size), [0.0-BG, 1.0-stroke]
curr_canvas_soft = tf.add(curr_canvas_soft, filter_curr_stroke_image_soft) # [0.0-BG, 1.0-stroke]
## hard
curr_state_hard = tf.expand_dims(tf.cast(tf.argmax(o_pen_ras_raw, axis=-1), dtype=tf.float32),
axis=-1) # (N, 1, 1)
filter_curr_stroke_image_hard = tf.multiply(tf.subtract(1.0, curr_state_hard), curr_stroke_image_large)
# (N, image_size, image_size), [0.0-BG, 1.0-stroke]
self.curr_canvas_hard = tf.add(self.curr_canvas_hard, filter_curr_stroke_image_hard) # [0.0-BG, 1.0-stroke]
self.curr_canvas_hard = tf.clip_by_value(self.curr_canvas_hard, 0.0, 1.0) # [0.0-BG, 1.0-stroke]
next_width = o_other_params[:, :, 4:5]
next_scaling = o_other_params[:, :, 5:6]
next_window_size = tf.multiply(next_scaling, tf.stop_gradient(curr_window_size)) # float, with grad
window_size_before_max_min = next_window_size # (N, 1, 1), large-level
win_size_before_max_min_list.append(window_size_before_max_min)
next_window_size = tf.maximum(next_window_size, tf.cast(self.hps.min_window_size, tf.float32))
next_window_size = tf.minimum(next_window_size, tf.cast(image_size, tf.float32))
prev_state = next_state
prev_width = next_width * curr_window_size / next_window_size # (N, 1, 1)
prev_scaling = next_scaling # (N, 1, 1)
prev_window_size = curr_window_size
# update the cursor position
new_cursor_offsets = tf.multiply(o_other_params[:, :, 2:4],
tf.divide(curr_window_size, 2.0)) # (N, 1, 2), window-level
new_cursor_offset_next = new_cursor_offsets
new_cursor_offset_next = tf.concat([new_cursor_offset_next[:, :, 1:2], new_cursor_offset_next[:, :, 0:1]], axis=-1)
cursor_position_loop_large = tf.multiply(cursor_position_loop, tf.cast(image_size, tf.float32))
if self.hps.stop_accu_grad:
stroke_position_next = tf.stop_gradient(cursor_position_loop_large) + new_cursor_offset_next # (N, 1, 2), large-level
else:
stroke_position_next = cursor_position_loop_large + new_cursor_offset_next # (N, 1, 2), large-level
stroke_position_before_max_min = stroke_position_next # (N, 1, 2), large-level
pos_before_max_min_list.append(stroke_position_before_max_min)
if self.hps.cursor_type == 'next':
cursor_position_loop_large = stroke_position_next # (N, 1, 2), large-level
else:
raise Exception('Unknown cursor_type')
cursor_position_loop_large = tf.maximum(cursor_position_loop_large, 0.0)
cursor_position_loop_large = tf.minimum(cursor_position_loop_large, tf.cast(image_size - 1, tf.float32))
cursor_position_loop = tf.div(cursor_position_loop_large, tf.cast(image_size, tf.float32))
curr_canvas_soft = tf.clip_by_value(curr_canvas_soft, 0.0, 1.0) # (N, raster_size, raster_size), [0.0-BG, 1.0-stroke]
other_params_ = tf.reshape(tf.concat(other_params_list, axis=1), [-1, 6]) # (N * max_seq_len, 6)
pen_ras_ = tf.reshape(tf.concat(pen_ras_list, axis=1), [-1, 2]) # (N * max_seq_len, 2)
pos_before_max_min_ = tf.concat(pos_before_max_min_list, axis=1) # (N, max_seq_len, 2)
win_size_before_max_min_ = tf.concat(win_size_before_max_min_list, axis=1) # (N, max_seq_len, 1)
return other_params_, pen_ras_, prev_state, curr_canvas_soft, curr_canvas_soft_rgb, \
pos_before_max_min_, win_size_before_max_min_
def differentiable_argmax(self, input_pen, soft_beta):
"""
Differentiable argmax trick.
:param input_pen: (N, n_class)
:return: pen_state: (N, 1)
"""
def sign_onehot(x):
"""
:param x: (N, n_class)
:return: (N, n_class)
"""
y = tf.sign(tf.reduce_max(x, axis=-1, keepdims=True) - x)
y = (y - 1) * (-1)
return y
def softargmax(x, beta=1e2):
"""
:param x: (N, n_class)
:param beta: temperature; 1e10 gives the sharpest argmax, 1e2 is acceptable.
:return: (N)
"""
x_range = tf.cumsum(tf.ones_like(x), axis=1) # (N, 2)
return tf.reduce_sum(tf.nn.softmax(x * beta) * x_range, axis=1) - 1
## Prefer softargmax(beta=1e2); sign_onehot's gradient is close to zero almost everywhere.
# pen_onehot = sign_onehot(input_pen) # one-hot form, (N * max_seq_len, 2)
# pen_state = pen_onehot[:, 1:2] # (N * max_seq_len, 1)
pen_state = softargmax(input_pen, soft_beta)
pen_state = tf.expand_dims(pen_state, axis=1) # (N * max_seq_len, 1)
return pen_state
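# The soft-argmax trick can be reproduced outside TF; a pure-Python sketch
# (illustrative only; it returns the 0-based index directly rather than the
# cumsum-minus-one form used above):

```python
import math

def softargmax(logits, beta=1e2):
    # Softmax-weighted expectation of the index; approaches hard argmax
    # as beta grows while remaining differentiable in the logits.
    m = max(logits)  # subtract max for numerical stability
    exps = [math.exp(beta * (v - m)) for v in logits]
    total = sum(exps)
    return sum(i * e for i, e in enumerate(exps)) / total
```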
def build_losses(self, target_sketch, pred_raster_imgs, pred_params,
pos_before_max_min, win_size_before_max_min, image_size):
def get_raster_loss(pred_imgs, gt_imgs, loss_type):
perc_layer_losses_raw = []
perc_layer_losses_weighted = []
perc_layer_losses_norm = []
if loss_type == 'l1':
ras_cost = tf.reduce_mean(tf.abs(tf.subtract(gt_imgs, pred_imgs))) # ()
elif loss_type == 'l1_small':
gt_imgs_small = tf.image.resize_images(tf.expand_dims(gt_imgs, axis=3), (32, 32))
pred_imgs_small = tf.image.resize_images(tf.expand_dims(pred_imgs, axis=3), (32, 32))
ras_cost = tf.reduce_mean(tf.abs(tf.subtract(gt_imgs_small, pred_imgs_small))) # ()
elif loss_type == 'mse':
ras_cost = tf.reduce_mean(tf.pow(tf.subtract(gt_imgs, pred_imgs), 2)) # ()
elif loss_type == 'perceptual':
return_map_pred = vgg_net_slim(pred_imgs, image_size)
return_map_gt = vgg_net_slim(gt_imgs, image_size)
perc_loss_type = 'l1' # [l1, mse]
weighted_map = {'ReLU1_1': 100.0, 'ReLU1_2': 100.0,
'ReLU2_1': 100.0, 'ReLU2_2': 100.0,
'ReLU3_1': 10.0, 'ReLU3_2': 10.0, 'ReLU3_3': 10.0,
'ReLU4_1': 1.0, 'ReLU4_2': 1.0, 'ReLU4_3': 1.0,
'ReLU5_1': 1.0, 'ReLU5_2': 1.0, 'ReLU5_3': 1.0}
for perc_layer in self.hps.perc_loss_layers:
if perc_loss_type == 'l1':
perc_layer_loss = tf.reduce_mean(tf.abs(tf.subtract(return_map_pred[perc_layer],
return_map_gt[perc_layer]))) # ()
elif perc_loss_type == 'mse':
perc_layer_loss = tf.reduce_mean(tf.pow(tf.subtract(return_map_pred[perc_layer],
return_map_gt[perc_layer]), 2)) # ()
else:
raise NameError('Unknown perceptual loss type:', perc_loss_type)
perc_layer_losses_raw.append(perc_layer_loss)
assert perc_layer in weighted_map
perc_layer_losses_weighted.append(perc_layer_loss * weighted_map[perc_layer])
for loop_i in range(len(self.hps.perc_loss_layers)):
perc_relu_loss_raw = perc_layer_losses_raw[loop_i] # ()
if self.hps.model_mode == 'train':
curr_relu_mean = (self.perc_loss_mean_list[loop_i] * self.last_step_num + perc_relu_loss_raw) / (self.last_step_num + 1.0)
relu_cost_norm = perc_relu_loss_raw / curr_relu_mean
else:
relu_cost_norm = perc_relu_loss_raw
perc_layer_losses_norm.append(relu_cost_norm)
perc_layer_losses_raw = tf.stack(perc_layer_losses_raw, axis=0)
perc_layer_losses_norm = tf.stack(perc_layer_losses_norm, axis=0)
if self.hps.perc_loss_fuse_type == 'max':
ras_cost = tf.reduce_max(perc_layer_losses_norm)
elif self.hps.perc_loss_fuse_type == 'add':
ras_cost = tf.reduce_mean(perc_layer_losses_norm)
elif self.hps.perc_loss_fuse_type == 'raw_add':
ras_cost = tf.reduce_mean(perc_layer_losses_raw)
elif self.hps.perc_loss_fuse_type == 'weighted_sum':
ras_cost = tf.reduce_mean(perc_layer_losses_weighted)
else:
raise NameError('Unknown perc_loss_fuse_type:', self.hps.perc_loss_fuse_type)
elif loss_type == 'triplet':
raise Exception('Solution for triplet loss is coming soon.')
else:
raise NameError('Unknown loss type:', loss_type)
if loss_type != 'perceptual':
for perc_layer_i in self.hps.perc_loss_layers:
perc_layer_losses_raw.append(tf.constant(0.0))
perc_layer_losses_norm.append(tf.constant(0.0))
perc_layer_losses_raw = tf.stack(perc_layer_losses_raw, axis=0)
perc_layer_losses_norm = tf.stack(perc_layer_losses_norm, axis=0)
return ras_cost, perc_layer_losses_raw, perc_layer_losses_norm
gt_raster_images = tf.squeeze(target_sketch, axis=3) # (N, raster_h, raster_w), [0.0-stroke, 1.0-BG]
raster_cost, perc_relu_losses_raw, perc_relu_losses_norm = \
get_raster_loss(pred_raster_imgs, gt_raster_images, loss_type=self.hps.raster_loss_base_type)
def get_stroke_num_loss(input_strokes):
ending_state = input_strokes[:, :, 0] # (N, seq_len)
stroke_num_loss_pre = tf.reduce_mean(ending_state) # larger is better, [0.0, 1.0]
stroke_num_loss = 1.0 - stroke_num_loss_pre # lower is better, [0.0, 1.0]
return stroke_num_loss
stroke_num_cost = get_stroke_num_loss(pred_params) # lower is better
def get_pos_outside_loss(pos_before_max_min_):
pos_after_max_min = tf.maximum(pos_before_max_min_, 0.0)
pos_after_max_min = tf.minimum(pos_after_max_min, tf.cast(image_size - 1, tf.float32)) # (N, max_seq_len, 2)
pos_outside_loss = tf.reduce_mean(tf.abs(pos_before_max_min_ - pos_after_max_min))
return pos_outside_loss
pos_outside_cost = get_pos_outside_loss(pos_before_max_min) # lower is better
def get_win_size_outside_loss(win_size_before_max_min_, min_window_size):
win_size_outside_top_loss = tf.divide(
tf.maximum(win_size_before_max_min_ - tf.cast(image_size, tf.float32), 0.0),
tf.cast(image_size, tf.float32)) # (N, max_seq_len, 1)
win_size_outside_bottom_loss = tf.divide(
tf.maximum(tf.cast(min_window_size, tf.float32) - win_size_before_max_min_, 0.0),
tf.cast(min_window_size, tf.float32)) # (N, max_seq_len, 1)
win_size_outside_loss = tf.reduce_mean(win_size_outside_top_loss + win_size_outside_bottom_loss)
return win_size_outside_loss
win_size_outside_cost = get_win_size_outside_loss(win_size_before_max_min, self.hps.min_window_size) # lower is better
def get_early_pen_states_loss(input_strokes, curr_start, curr_end):
# input_strokes: (N, max_seq_len, 7)
pred_early_pen_states = input_strokes[:, curr_start:curr_end, 0] # (N, curr_early_len)
pred_early_pen_states_min = tf.reduce_min(pred_early_pen_states, axis=1) # (N), should not be 1
early_pen_states_loss = tf.reduce_mean(pred_early_pen_states_min) # lower is better
return early_pen_states_loss
early_pen_states_cost = get_early_pen_states_loss(pred_params,
self.early_pen_loss_start_idx, self.early_pen_loss_end_idx)
return raster_cost, stroke_num_cost, pos_outside_cost, win_size_outside_cost, \
early_pen_states_cost, \
perc_relu_losses_raw, perc_relu_losses_norm
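The cursor-outside and window-size penalties above are plain clamp-and-distance terms. A minimal NumPy sketch (with made-up sizes and values) illustrates the arithmetic:

```python
import numpy as np

def pos_outside_loss(pos, image_size):
    # Clamp positions into [0, image_size - 1]; the penalty is the mean
    # distance each coordinate had to be moved to get back on the canvas.
    clamped = np.clip(pos, 0.0, image_size - 1.0)
    return np.abs(pos - clamped).mean()

def win_size_outside_loss(win, image_size, min_window_size):
    # Penalize window sizes above image_size or below min_window_size,
    # each overshoot normalized by its own bound.
    over = np.maximum(win - image_size, 0.0) / image_size
    under = np.maximum(min_window_size - win, 0.0) / min_window_size
    return (over + under).mean()

pos = np.array([[-2.0, 5.0], [130.0, 10.0]])  # some cursors outside a 128px canvas
print(pos_outside_loss(pos, 128))  # mean of [2, 0, 3, 0] = 1.25
```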
def build_training_op_split(self, raster_cost, sn_cost, cursor_outside_cost, win_size_outside_cost,
early_pen_states_cost):
total_cost = self.hps.raster_loss_weight * raster_cost + \
self.hps.early_pen_loss_weight * early_pen_states_cost + \
self.stroke_num_loss_weight * sn_cost + \
self.hps.outside_loss_weight * cursor_outside_cost + \
self.hps.win_size_outside_loss_weight * win_size_outside_cost
tvars = [var for var in tf.trainable_variables()
if 'raster_unit' not in var.op.name and 'VGG16' not in var.op.name]
gvs = self.optimizer.compute_gradients(total_cost, var_list=tvars)
return total_cost, gvs
def build_training_op(self, grad_list):
with tf.variable_scope('train_op', reuse=tf.AUTO_REUSE):
gvs = self.average_gradients(grad_list)
g = self.hps.grad_clip
for grad, var in gvs:
print('>>', var.op.name)
if grad is None:
print(' >> None value')
capped_gvs = [(tf.clip_by_value(grad, -g, g), var) for grad, var in gvs if grad is not None]
self.train_op = self.optimizer.apply_gradients(
capped_gvs, global_step=self.global_step, name='train_step')
def average_gradients(self, grads_list):
"""
Compute the average gradients.
:param grads_list: list (of length N_GPU) of lists of (grad, var) tuples
:return: list of (averaged_grad, var) tuples, one per variable
"""
avg_grads = []
for grad_and_vars in zip(*grads_list):
grads = []
for g, _ in grad_and_vars:
expanded_g = tf.expand_dims(g, 0)
grads.append(expanded_g)
grad = tf.concat(grads, axis=0)
grad = tf.reduce_mean(grad, axis=0)
v = grad_and_vars[0][1]
grad_and_var = (grad, v)
avg_grads.append(grad_and_var)
return avg_grads
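average_gradients stacks each variable's per-GPU gradients and takes the mean; the same arithmetic in plain NumPy, with hypothetical toy gradients:

```python
import numpy as np

def average_gradients(grads_list):
    # grads_list: one list of per-variable gradients per GPU.
    # zip(*grads_list) groups the gradients of the same variable together.
    return [np.mean(np.stack(per_var, axis=0), axis=0)
            for per_var in zip(*grads_list)]

gpu0 = [np.array([1.0, 2.0]), np.array([[1.0]])]
gpu1 = [np.array([3.0, 4.0]), np.array([[3.0]])]
avg = average_gradients([gpu0, gpu1])
print(avg[0])  # [2. 3.]
```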
================================================
FILE: rasterization_utils/NeuralRenderer.py
================================================
import tensorflow as tf
class RasterUnit(object):
def __init__(self,
raster_size,
input_params, # (N, 10)
reuse=False):
self.raster_size = raster_size
self.input_params = input_params
with tf.variable_scope("raster_unit", reuse=reuse):
self.build_unit()
def build_unit(self):
x = self.input_params # (N, 10)
x = self.fully_connected(x, 10, 512, scope='fc1') # (N, 512)
x = tf.nn.relu(x)
x = self.fully_connected(x, 512, 1024, scope='fc2') # (N, 1024)
x = tf.nn.relu(x)
x = self.fully_connected(x, 1024, 2048, scope='fc3') # (N, 2048)
x = tf.nn.relu(x)
x = self.fully_connected(x, 2048, 4096, scope='fc4') # (N, 4096)
x = tf.nn.relu(x)
x = tf.reshape(x, (-1, 16, 16, 16)) # (N, 16, 16, 16)
x = tf.transpose(x, (0, 2, 3, 1))
x = self.conv2d(x, 32, 3, 1, scope='conv1') # (N, 16, 16, 32)
x = tf.nn.relu(x)
x = self.conv2d(x, 32, 3, 1, scope='conv2') # (N, 16, 16, 32)
x = self.pixel_shuffle(x, upscale_factor=2) # (N, 32, 32, 8)
x = self.conv2d(x, 16, 3, 1, scope='conv3') # (N, 32, 32, 16)
x = tf.nn.relu(x)
x = self.conv2d(x, 16, 3, 1, scope='conv4') # (N, 32, 32, 16)
x = self.pixel_shuffle(x, upscale_factor=2) # (N, 64, 64, 4)
x = self.conv2d(x, 8, 3, 1, scope='conv5') # (N, 64, 64, 8)
x = tf.nn.relu(x)
x = self.conv2d(x, 4, 3, 1, scope='conv6') # (N, 64, 64, 4)
x = self.pixel_shuffle(x, upscale_factor=2) # (N, 128, 128, 1)
x = tf.sigmoid(x)
# (N, 128, 128), [0.0-stroke, 1.0-BG]
self.stroke_image = 1.0 - tf.reshape(x, (-1, self.raster_size, self.raster_size))
def conv2d(self, input_tensor, out_channels, kernel_size, stride, scope, reuse=False):
with tf.variable_scope(scope, reuse=reuse):
output_tensor = tf.layers.conv2d(input_tensor, out_channels, kernel_size=kernel_size,
strides=(stride, stride),
padding="same", kernel_initializer=tf.keras.initializers.he_normal())
return output_tensor
def fully_connected(self, input_tensor, in_dim, out_dim, scope, reuse=False):
with tf.variable_scope(scope, reuse=reuse):
weight = tf.get_variable("weight", [in_dim, out_dim], dtype=tf.float32,
initializer=tf.random_normal_initializer())
bias = tf.get_variable("bias", [out_dim], dtype=tf.float32,
initializer=tf.random_normal_initializer())
output_tensor = tf.matmul(input_tensor, weight) + bias
return output_tensor
def pixel_shuffle(self, input_tensor, upscale_factor):
params_shape = input_tensor.get_shape()
n, h, w, c = params_shape
input_tensor_proc = tf.reshape(input_tensor, (n, h, w, c // 4, 4))
input_tensor_proc = tf.transpose(input_tensor_proc, (0, 1, 2, 4, 3))
input_tensor_proc = tf.reshape(input_tensor_proc, (n, h, w, -1))
output_tensor = tf.depth_to_space(input_tensor_proc, block_size=upscale_factor)
return output_tensor
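pixel_shuffle ultimately relies on tf.depth_to_space; a NumPy re-implementation of that rearrangement (a sketch for block size 2) shows how 4 channels become one 2x2 spatial block:

```python
import numpy as np

def depth_to_space(x, block):
    # x: (N, H, W, C) with C divisible by block**2.
    n, h, w, c = x.shape
    co = c // (block * block)
    # Split the channels into a (block, block, co) grid and interleave
    # the grid into the spatial dimensions.
    x = x.reshape(n, h, w, block, block, co)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (N, H, block, W, block, co)
    return x.reshape(n, h * block, w * block, co)

x = np.arange(4.0).reshape(1, 1, 1, 4)   # channels [0, 1, 2, 3]
print(depth_to_space(x, 2)[0, :, :, 0])  # [[0. 1.] [2. 3.]]
```

Note that the reshape/transpose preamble in pixel_shuffle hardcodes a factor of 4 (= 2**2), so it only matches this ordering for upscale_factor == 2.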
class NeuralRasterizor(object):
def __init__(self,
raster_size,
seq_len,
position_format='abs',
raster_padding=10,
strokes_format=3):
self.raster_size = raster_size
self.seq_len = seq_len
self.position_format = position_format
self.raster_padding = raster_padding
self.strokes_format = strokes_format
assert position_format in ['abs', 'rel']
def raster_func_abs(self, input_data, raster_seq_len=None):
"""
x and y in absolute position.
:param input_data: (N, seq_len, 10): [x0, y0, x1, y1, x2, y2, r0, r2, w0, w2]. All in [0.0, 1.0]
:return: stroke_images (N, raster_size, raster_size) in [0.0-BG, 1.0-stroke], the unclipped sum, and per-step stroke images
"""
seq_len = raster_seq_len if raster_seq_len is not None else self.seq_len
raster_params = tf.transpose(input_data, [1, 0, 2]) # (seq_len, N, 10)
seq_stroke_images = tf.map_fn(self.stroke_drawer_with_raster_unit, raster_params,
parallel_iterations=32) # (seq_len, N, raster_size, raster_size)
seq_stroke_images = tf.transpose(seq_stroke_images, [1, 2, 3, 0])
# (N, raster_size, raster_size, seq_len), [0.0-stroke, 1.0-BG]
filter_seq_stroke_images = 1.0 - seq_stroke_images
# (N, raster_size, raster_size, seq_len), [0.0-BG, 1.0-stroke]
# stacking
stroke_images_unclip = tf.reduce_sum(filter_seq_stroke_images, axis=-1) # (N, raster_size, raster_size)
stroke_images = tf.clip_by_value(stroke_images_unclip, 0.0, 1.0) # [0.0-BG, 1.0-stroke]
return stroke_images, stroke_images_unclip, seq_stroke_images
def stroke_drawer_with_raster_unit(self, params_batch):
"""
Convert a 10-d stroke parameter vector (three control points plus radii and widths) into a raster stroke image with RasterUnit.
:param params_batch: (N, 10)
:return: (N, raster_size, raster_size)
"""
raster_unit = RasterUnit(
raster_size=self.raster_size,
input_params=params_batch,
reuse=tf.AUTO_REUSE
)
stroke_image = raster_unit.stroke_image # (N, raster_size, raster_size), [0.0-stroke, 1.0-BG]
return stroke_image
class NeuralRasterizorStep(object):
def __init__(self,
raster_size,
position_format='abs'):
self.raster_size = raster_size
self.position_format = position_format
assert position_format in ['abs', 'rel']
def raster_func_stroke_abs(self, input_data):
"""
x and y in absolute position.
:param input_data: (N, 8): [x0, y0, x1, y1, x2, y2, r0, r2]. All in [0.0, 1.0]
:return: (N, raster_size, raster_size), [0.0-BG, 1.0-stroke]
"""
w_in = tf.ones(shape=(input_data.shape[0], 2), dtype=tf.float32)
raster_params = tf.concat([input_data, w_in], axis=-1) # (N, 10)
stroke_image = self.stroke_drawer_with_raster_unit(raster_params) # (N, raster_size, raster_size), [0.0-stroke, 1.0-BG]
stroke_image = 1.0 - stroke_image # [0.0-BG, 1.0-stroke]
return stroke_image
def mask_ending_state(self, input_states):
"""
Accumulate and clip the ending states so that every step after the first ending flag stays at 1.
:param input_states: (N, seq_len, 1), ending states in offset manner
:return: (N, seq_len, 1), cumulative ending states clipped to [0.0, 1.0]
"""
ending_state_accu = tf.cumsum(input_states, axis=1) # (N, seq_len, 1)
ending_state_clip = tf.clip_by_value(ending_state_accu, 0.0, 1.0) # (N, seq_len, 1)
return ending_state_clip
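The cumsum-then-clip trick in mask_ending_state is easy to see on a toy sequence in NumPy:

```python
import numpy as np

def mask_ending_state(states):
    # Once an ending flag fires, the cumulative sum keeps every later
    # step at >= 1; clipping pins the mask to {0, 1}.
    return np.clip(np.cumsum(states, axis=1), 0.0, 1.0)

states = np.array([[0.0, 0.0, 1.0, 0.0, 1.0]])
print(mask_ending_state(states))  # [[0. 0. 1. 1. 1.]]
```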
def stroke_drawer_with_raster_unit(self, params_batch):
"""
Convert a 10-d stroke parameter vector (three control points plus radii and widths) into a raster stroke image with RasterUnit.
:param params_batch: (N, 10)
:return: (N, raster_size, raster_size)
"""
raster_unit = RasterUnit(
raster_size=self.raster_size,
input_params=params_batch,
reuse=tf.AUTO_REUSE
)
stroke_image = raster_unit.stroke_image # (N, raster_size, raster_size), [0.0-stroke, 1.0-BG]
return stroke_image
================================================
FILE: rasterization_utils/RealRenderer.py
================================================
import numpy as np
import gizeh
class GizehRasterizor(object):
def __init__(self):
self.name = 'GizehRasterizor'
def get_line_array_v2(self, image_size, seq_strokes, stroke_width, is_bin=True):
"""
:param seq_strokes: (N_points, 3), each row [x, y, pen_state] in absolute coordinates
:return: line_arr: (image_size, image_size), {0, 1}, 0 for BG and 1 for strokes
"""
surface = gizeh.Surface(width=image_size, height=image_size) # in pixels
shape_list = []
for seq_i in range(len(seq_strokes) - 1):
p1, p2 = seq_strokes[seq_i, :2], seq_strokes[seq_i + 1, :2]
pen_state = seq_strokes[seq_i, 2]
if pen_state == 0.0:
line = gizeh.polyline(points=[p1, p2], stroke_width=stroke_width, stroke=(1, 1, 1), fill=(0, 0, 0))
shape_list.append(line)
group = gizeh.Group(shape_list)
group.draw(surface)
# Now export the surface
line_arr = surface.get_npimage()[:, :, 0] # returns a (width x height x 3) numpy array
if is_bin:
line_arr[line_arr <= 128] = 0
line_arr[line_arr != 0] = 1 # (image_size, image_size)
else:
line_arr = np.array(line_arr, dtype=np.float32) / 255.0
return line_arr
def get_line_array(self, p1, p2, image_size, stroke_width, is_bin=True):
"""
:param p1: (x, y)
:param p2: (x, y)
:return: line_arr: (image_size, image_size), {0, 1}, 0 for BG and 1 for strokes
"""
surface = gizeh.Surface(width=image_size, height=image_size) # in pixels
line = gizeh.polyline(points=[p1, p2], stroke_width=stroke_width, stroke=(1, 1, 1), fill=(0, 0, 0))
line.draw(surface)
# Now export the surface
line_arr = surface.get_npimage()[:, :, 0] # returns a (width x height x 3) numpy array
if is_bin:
line_arr[line_arr <= 128] = 0
line_arr[line_arr != 0] = 1 # (image_size, image_size)
else:
line_arr = np.array(line_arr, dtype=np.float32) / 255.0
return line_arr
def load_sketch_images_on_the_fly_v2(self, image_size, norm_strokes3, stroke_width, is_bin=True):
"""
:param norm_strokes3: list (N_sketches,), each with (N_points, 3)
:return: list (N_sketches,), each with (raster_size, raster_size), 0-BG and 1-strokes
"""
assert type(norm_strokes3) is list
sketch_imgs_list = []
for stroke_i in range(len(norm_strokes3)):
seq_strokes3 = norm_strokes3[stroke_i] # (N_points, 3)
sketch_img = self.get_line_array_v2(image_size, seq_strokes3, stroke_width=stroke_width, is_bin=is_bin)
sketch_img = np.clip(sketch_img, 0.0, 1.0) # (image_size, image_size), 0 for BG and 1 for strokes
sketch_imgs_list.append(sketch_img)
return sketch_imgs_list
def load_sketch_images_on_the_fly(self, image_size, norm_strokes3, stroke_width, is_bin=True):
"""
:param norm_strokes3: list (N_sketches,), each with (N_points, 3)
:return: list (N_sketches,), each with (raster_size, raster_size), 0-BG and 1-strokes
"""
assert type(norm_strokes3) is list
sketch_imgs_list = []
for stroke_i in range(len(norm_strokes3)):
seq_strokes3 = norm_strokes3[stroke_i] # (N_points, 3)
seq_len = len(seq_strokes3)
stroke_imgs_list = []
for seq_i in range(seq_len - 1):
stroke_img = self.get_line_array(seq_strokes3[seq_i, :2], seq_strokes3[seq_i + 1, :2], image_size,
stroke_width=stroke_width, is_bin=is_bin)
pen_state = seq_strokes3[seq_i, 2]
stroke_img = stroke_img.astype(np.float32) * (1. - pen_state)
stroke_imgs_list.append(stroke_img)
stroke_imgs_list = np.stack(stroke_imgs_list,
axis=-1) # (image_size, image_size, seq_len-1), 0 for BG and 1 for strokes
stroke_imgs_list = np.sum(stroke_imgs_list, axis=-1)
stroke_imgs_list = np.clip(stroke_imgs_list, 0.0, 1.0) # (image_size, image_size), 0 for BG and 1 for strokes
sketch_imgs_list.append(stroke_imgs_list)
return sketch_imgs_list
def normalize_coordinate_np(self, sx, sy, image_size, raster_padding=10.0):
"""
Convert offset to normalized absolute points. The numpy version as in NeuralRasterizor.
:param sx: (N, seq_len)
:param sy: (N, seq_len)
:return: abs_x, abs_y: (N, seq_len), normalized absolute coordinates
"""
seq_len = sx.shape[1]
# transfer to abs points
abs_x = np.cumsum(sx, axis=1) # (N, seq_len)
abs_y = np.cumsum(sy, axis=1)
min_x = np.min(abs_x, axis=1, keepdims=True) # (N, 1)
max_x = np.max(abs_x, axis=1, keepdims=True)
min_y = np.min(abs_y, axis=1, keepdims=True)
max_y = np.max(abs_y, axis=1, keepdims=True)
# transform to positive coordinate
abs_x = np.subtract(abs_x, np.tile(min_x, [1, seq_len])) # (N, seq_len)
abs_y = np.subtract(abs_y, np.tile(min_y, [1, seq_len]))
# scaling to [0.0, image_size - 2 * padding - 1]
bbox_w = np.squeeze(np.subtract(max_x, min_x), axis=-1) # (N)
bbox_h = np.squeeze(np.subtract(max_y, min_y), axis=-1)
unpad_raster_size = (image_size - 1.0) - 2.0 * raster_padding
scaling = np.divide(unpad_raster_size, np.maximum(bbox_w, bbox_h)) # (N)
scaling_tile = np.tile(np.expand_dims(scaling, axis=-1), [1, seq_len]) # (N, seq_len)
abs_x = np.multiply(abs_x, scaling_tile) # (N, seq_len)
abs_y = np.multiply(abs_y, scaling_tile)
# add padding
abs_x = np.add(abs_x, raster_padding) # (N, seq_len)
abs_y = np.add(abs_y, raster_padding)
# transform to the middle
trans_x = np.divide(np.subtract(unpad_raster_size, np.multiply(bbox_w, scaling)), 2.0) # (N)
trans_y = np.divide(np.subtract(unpad_raster_size, np.multiply(bbox_h, scaling)), 2.0)
trans_x = np.tile(np.expand_dims(trans_x, axis=-1), [1, seq_len]) # (N, seq_len)
trans_y = np.tile(np.expand_dims(trans_y, axis=-1), [1, seq_len]) # (N, seq_len)
abs_x = np.add(abs_x, trans_x) # (N, seq_len)
abs_y = np.add(abs_y, trans_y)
return abs_x, abs_y
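The shift/scale/center pipeline above can be condensed into a single-sequence NumPy sketch (batch dimension dropped for clarity; the toy offsets are made up):

```python
import numpy as np

def normalize_coords(dx, dy, image_size, padding=10.0):
    # Offsets -> absolute points, shifted so the minimum sits at the origin.
    ax, ay = np.cumsum(dx), np.cumsum(dy)
    ax, ay = ax - ax.min(), ay - ay.min()
    bw, bh = ax.max(), ay.max()
    # Scale the longer bounding-box side to fill the padded canvas.
    side = (image_size - 1.0) - 2.0 * padding
    s = side / max(bw, bh)
    ax, ay = ax * s + padding, ay * s + padding
    # Center the shorter side inside the canvas.
    ax += (side - bw * s) / 2.0
    ay += (side - bh * s) / 2.0
    return ax, ay

ax, ay = normalize_coords([0.0, 2.0, 2.0], [0.0, 1.0, 1.0], 128)
print(ax.min(), ax.max())  # 10.0 117.0
```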
def normalize_strokes_np(self, strokes_list, image_size):
"""
:param strokes_list: list (N_sketches,), each with (N_points, 3)
:return:
"""
assert type(strokes_list) is list
rst_list = []
for i in range(len(strokes_list)):
strokes_data = strokes_list[i] # (N_points, 3)
norm_x, norm_y = self.normalize_coordinate_np(np.expand_dims(strokes_data[:, 0], axis=0),
np.expand_dims(strokes_data[:, 1], axis=0),
image_size) # (1, N_points)
norm_strokes_data = np.stack([norm_x[0], norm_y[0], strokes_data[:, 2]], axis=-1) # (N_points, 3)
rst_list.append(norm_strokes_data)
return rst_list
def raster_func(self, input_data, image_size, stroke_width, is_bin=True, version='v2'):
"""
:param input_data: (N_sketches,), each with (N_points, 3)
:return: raster_image_array: list (N_sketches,), each with (raster_size, raster_size), 0-BG and 1-strokes
"""
norm_test_strokes3 = self.normalize_strokes_np(input_data, image_size)
if version == 'v1':
raster_image_array = self.load_sketch_images_on_the_fly(image_size, norm_test_strokes3, stroke_width, is_bin=is_bin)
else:
raster_image_array = self.load_sketch_images_on_the_fly_v2(image_size, norm_test_strokes3, stroke_width, is_bin=is_bin)
return raster_image_array
================================================
FILE: rnn.py
================================================
# Copyright 2019 The Magenta Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SketchRNN RNN definition."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
def orthogonal(shape):
"""Orthogonal initilaizer."""
flat_shape = (shape[0], np.prod(shape[1:]))
a = np.random.normal(0.0, 1.0, flat_shape)
u, _, v = np.linalg.svd(a, full_matrices=False)
q = u if u.shape == flat_shape else v
return q.reshape(shape)
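The orthogonal initializer works because the SVD factors of a random Gaussian matrix are orthonormal; a quick NumPy check of that property:

```python
import numpy as np

np.random.seed(0)

def orthogonal(shape):
    # SVD of a Gaussian matrix yields orthonormal rows/columns.
    flat = (shape[0], int(np.prod(shape[1:])))
    u, _, v = np.linalg.svd(np.random.normal(0.0, 1.0, flat), full_matrices=False)
    q = u if u.shape == flat else v
    return q.reshape(shape)

q = orthogonal((4, 4))
print(np.round(q @ q.T, 6))  # identity, up to floating-point error
```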
def orthogonal_initializer(scale=1.0):
"""Orthogonal initializer."""
def _initializer(shape, dtype=tf.float32,
partition_info=None): # pylint: disable=unused-argument
return tf.constant(orthogonal(shape) * scale, dtype)
return _initializer
def lstm_ortho_initializer(scale=1.0):
"""LSTM orthogonal initializer."""
def _initializer(shape, dtype=tf.float32,
partition_info=None): # pylint: disable=unused-argument
size_x = shape[0]
size_h = shape[1] // 4 # assumes lstm.
t = np.zeros(shape)
t[:, :size_h] = orthogonal([size_x, size_h]) * scale
t[:, size_h:size_h * 2] = orthogonal([size_x, size_h]) * scale
t[:, size_h * 2:size_h * 3] = orthogonal([size_x, size_h]) * scale
t[:, size_h * 3:] = orthogonal([size_x, size_h]) * scale
return tf.constant(t, dtype)
return _initializer
class LSTMCell(tf.contrib.rnn.RNNCell):
"""Vanilla LSTM cell.
Uses ortho initializer, and also recurrent dropout without memory loss
(https://arxiv.org/abs/1603.05118)
"""
def __init__(self,
num_units,
forget_bias=1.0,
use_recurrent_dropout=False,
dropout_keep_prob=0.9):
self.num_units = num_units
self.forget_bias = forget_bias
self.use_recurrent_dropout = use_recurrent_dropout
self.dropout_keep_prob = dropout_keep_prob
@property
def state_size(self):
return 2 * self.num_units
@property
def output_size(self):
return self.num_units
def get_output(self, state):
unused_c, h = tf.split(state, 2, 1)
return h
def __call__(self, x, state, scope=None):
with tf.variable_scope(scope or type(self).__name__):
c, h = tf.split(state, 2, 1)
x_size = x.get_shape().as_list()[1]
w_init = None # uniform
h_init = lstm_ortho_initializer(1.0)
# Keep W_xh and W_hh separate here as well to use different init methods.
w_xh = tf.get_variable(
'W_xh', [x_size, 4 * self.num_units], initializer=w_init)
w_hh = tf.get_variable(
'W_hh', [self.num_units, 4 * self.num_units], initializer=h_init)
bias = tf.get_variable(
'bias', [4 * self.num_units],
initializer=tf.constant_initializer(0.0))
concat = tf.concat([x, h], 1)
w_full = tf.concat([w_xh, w_hh], 0)
hidden = tf.matmul(concat, w_full) + bias
i, j, f, o = tf.split(hidden, 4, 1)
if self.use_recurrent_dropout:
g = tf.nn.dropout(tf.tanh(j), self.dropout_keep_prob)
else:
g = tf.tanh(j)
new_c = c * tf.sigmoid(f + self.forget_bias) + tf.sigmoid(i) * g
new_h = tf.tanh(new_c) * tf.sigmoid(o)
return new_h, tf.concat([new_c, new_h], 1)  # concat the state instead of using a tuple.
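LSTMCell's forward pass is one fused matmul followed by a four-way gate split. A NumPy sketch of the same step, with zero-initialized toy weights (sizes are made up):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, h, c, w_full, bias, forget_bias=1.0):
    # Single fused matmul over [x, h], then split into the four gate
    # pre-activations, mirroring LSTMCell.__call__ above.
    hidden = np.concatenate([x, h], axis=1) @ w_full + bias
    i, j, f, o = np.split(hidden, 4, axis=1)
    new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
    new_h = np.tanh(new_c) * sigmoid(o)
    return new_h, new_c

n_units, x_size = 3, 2
w = np.zeros((x_size + n_units, 4 * n_units))
b = np.zeros(4 * n_units)
h, c = lstm_step(np.ones((1, x_size)), np.zeros((1, n_units)),
                 np.zeros((1, n_units)), w, b)
print(h.shape, c.shape)  # (1, 3) (1, 3)
```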
def layer_norm_all(h,
batch_size,
base,
num_units,
scope='layer_norm',
reuse=False,
gamma_start=1.0,
epsilon=1e-3,
use_bias=True):
"""Layer Norm (faster version, but not using defun)."""
# Performs layer norm on multiple base at once (ie, i, g, j, o for lstm)
# Reshapes h in to perform layer norm in parallel
h_reshape = tf.reshape(h, [batch_size, base, num_units])
mean = tf.reduce_mean(h_reshape, [2], keep_dims=True)
var = tf.reduce_mean(tf.square(h_reshape - mean), [2], keep_dims=True)
epsilon = tf.constant(epsilon)
rstd = tf.rsqrt(var + epsilon)
h_reshape = (h_reshape - mean) * rstd
# reshape back to original
h = tf.reshape(h_reshape, [batch_size, base * num_units])
with tf.variable_scope(scope):
if reuse:
tf.get_variable_scope().reuse_variables()
gamma = tf.get_variable(
'ln_gamma', [4 * num_units],
initializer=tf.constant_initializer(gamma_start))
if use_bias:
beta = tf.get_variable(
'ln_beta', [4 * num_units], initializer=tf.constant_initializer(0.0))
if use_bias:
return gamma * h + beta
return gamma * h
def layer_norm(x,
num_units,
scope='layer_norm',
reuse=False,
gamma_start=1.0,
epsilon=1e-3,
use_bias=True):
"""Calculate layer norm."""
axes = [1]
mean = tf.reduce_mean(x, axes, keep_dims=True)
x_shifted = x - mean
var = tf.reduce_mean(tf.square(x_shifted), axes, keep_dims=True)
inv_std = tf.rsqrt(var + epsilon)
with tf.variable_scope(scope):
if reuse:
tf.get_variable_scope().reuse_variables()
gamma = tf.get_variable(
'ln_gamma', [num_units],
initializer=tf.constant_initializer(gamma_start))
if use_bias:
beta = tf.get_variable(
'ln_beta', [num_units], initializer=tf.constant_initializer(0.0))
output = gamma * (x_shifted) * inv_std
if use_bias:
output += beta
return output
def raw_layer_norm(x, epsilon=1e-3):
axes = [1]
mean = tf.reduce_mean(x, axes, keep_dims=True)
std = tf.sqrt(
tf.reduce_mean(tf.square(x - mean), axes, keep_dims=True) + epsilon)
output = (x - mean) / (std)
return output
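layer_norm and raw_layer_norm above compute the same per-row statistics; the core computation in NumPy:

```python
import numpy as np

def layer_norm(x, gamma=1.0, beta=0.0, epsilon=1e-3):
    # Normalize each row (feature axis) to zero mean / unit variance,
    # then apply the learned scale and shift.
    mean = x.mean(axis=1, keepdims=True)
    var = ((x - mean) ** 2).mean(axis=1, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + epsilon) + beta

x = np.array([[1.0, 2.0, 3.0, 4.0]])
y = layer_norm(x)
print(np.round(y.mean(), 6))  # 0.0 (per-row mean is removed)
```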
def super_linear(x,
output_size,
scope=None,
reuse=False,
init_w='ortho',
weight_start=0.0,
use_bias=True,
bias_start=0.0,
input_size=None):
"""Performs linear operation. Uses ortho init defined earlier."""
shape = x.get_shape().as_list()
with tf.variable_scope(scope or 'linear'):
if reuse:
tf.get_variable_scope().reuse_variables()
w_init = None # uniform
if input_size is None:
x_size = shape[1]
else:
x_size = input_size
if init_w == 'zeros':
w_init = tf.constant_initializer(0.0)
elif init_w == 'constant':
w_init = tf.constant_initializer(weight_start)
elif init_w == 'gaussian':
w_init = tf.random_normal_initializer(stddev=weight_start)
elif init_w == 'ortho':
w_init = lstm_ortho_initializer(1.0)
w = tf.get_variable(
'super_linear_w', [x_size, output_size], tf.float32, initializer=w_init)
if use_bias:
b = tf.get_variable(
'super_linear_b', [output_size],
tf.float32,
initializer=tf.constant_initializer(bias_start))
return tf.matmul(x, w) + b
return tf.matmul(x, w)
class LayerNormLSTMCell(tf.contrib.rnn.RNNCell):
"""Layer-Norm, with Ortho Init. and Recurrent Dropout without Memory Loss.
https://arxiv.org/abs/1607.06450 - Layer Norm
https://arxiv.org/abs/1603.05118 - Recurrent Dropout without Memory Loss
"""
def __init__(self,
num_units,
forget_bias=1.0,
use_recurrent_dropout=False,
dropout_keep_prob=0.90):
"""Initialize the Layer Norm LSTM cell.
Args:
num_units: int, The number of units in the LSTM cell.
forget_bias: float, The bias added to forget gates (default 1.0).
use_recurrent_dropout: Whether to use Recurrent Dropout (default False)
dropout_keep_prob: float, dropout keep probability (default 0.90)
"""
self.num_units = num_units
self.forget_bias = forget_bias
self.use_recurrent_dropout = use_recurrent_dropout
self.dropout_keep_prob = dropout_keep_prob
@property
def input_size(self):
return self.num_units
@property
def output_size(self):
return self.num_units
@property
def state_size(self):
return 2 * self.num_units
def get_output(self, state):
h, unused_c = tf.split(state, 2, 1)
return h
def __call__(self, x, state, timestep=0, scope=None):
with tf.variable_scope(scope or type(self).__name__):
h, c = tf.split(state, 2, 1)
h_size = self.num_units
x_size = x.get_shape().as_list()[1]
batch_size = x.get_shape().as_list()[0]
w_init = None # uniform
h_init = lstm_ortho_initializer(1.0)
w_xh = tf.get_variable(
'W_xh', [x_size, 4 * self.num_units], initializer=w_init)
w_hh = tf.get_variable(
'W_hh', [self.num_units, 4 * self.num_units], initializer=h_init)
concat = tf.concat([x, h], 1) # concat for speed.
w_full = tf.concat([w_xh, w_hh], 0)
concat = tf.matmul(concat, w_full)  # bias is omitted; layer_norm_all below adds its own beta.
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
concat = layer_norm_all(concat, batch_size, 4, h_size, 'ln_all')
i, j, f, o = tf.split(concat, 4, 1)
if self.use_recurrent_dropout:
g = tf.nn.dropout(tf.tanh(j), self.dropout_keep_prob)
else:
g = tf.tanh(j)
new_c = c * tf.sigmoid(f + self.forget_bias) + tf.sigmoid(i) * g
new_h = tf.tanh(layer_norm(new_c, h_size, 'ln_c')) * tf.sigmoid(o)
return new_h, tf.concat([new_h, new_c], 1)
class HyperLSTMCell(tf.contrib.rnn.RNNCell):
"""HyperLSTM with Ortho Init, Layer Norm, Recurrent Dropout, no Memory Loss.
https://arxiv.org/abs/1609.09106
http://blog.otoro.net/2016/09/28/hyper-networks/
"""
def __init__(self,
num_units,
forget_bias=1.0,
use_recurrent_dropout=False,
dropout_keep_prob=0.90,
use_layer_norm=True,
hyper_num_units=256,
hyper_embedding_size=32,
hyper_use_recurrent_dropout=False):
"""Initialize the Layer Norm HyperLSTM cell.
Args:
num_units: int, The number of units in the LSTM cell.
forget_bias: float, The bias added to forget gates (default 1.0).
use_recurrent_dropout: Whether to use Recurrent Dropout (default False)
dropout_keep_prob: float, dropout keep probability (default 0.90)
use_layer_norm: boolean. (default True)
Controls whether we use LayerNorm layers in main LSTM & HyperLSTM cell.
hyper_num_units: int, number of units in HyperLSTM cell.
(default is 256; recommend experimenting with larger values for larger tasks)
hyper_embedding_size: int, size of signals emitted from HyperLSTM cell.
(default is 32; recommend trying larger values for large datasets)
hyper_use_recurrent_dropout: boolean. (default False)
Controls whether HyperLSTM cell also uses recurrent dropout.
Recommend turning this on only if hyper_num_units becomes large (>= 512)
"""
self.num_units = num_units
self.forget_bias = forget_bias
self.use_recurrent_dropout = use_recurrent_dropout
self.dropout_keep_prob = dropout_keep_prob
self.use_layer_norm = use_layer_norm
self.hyper_num_units = hyper_num_units
self.hyper_embedding_size = hyper_embedding_size
self.hyper_use_recurrent_dropout = hyper_use_recurrent_dropout
self.total_num_units = self.num_units + self.hyper_num_units
if self.use_layer_norm:
cell_fn = LayerNormLSTMCell
else:
cell_fn = LSTMCell
self.hyper_cell = cell_fn(
hyper_num_units,
use_recurrent_dropout=hyper_use_recurrent_dropout,
dropout_keep_prob=dropout_keep_prob)
@property
def input_size(self):
return self._input_size
@property
def output_size(self):
return self.num_units
@property
def state_size(self):
return 2 * self.total_num_units
def get_output(self, state):
total_h, unused_total_c = tf.split(state, 2, 1)
h = total_h[:, 0:self.num_units]
return h
def hyper_norm(self, layer, scope='hyper', use_bias=True):
num_units = self.num_units
embedding_size = self.hyper_embedding_size
# recurrent batch norm init trick (https://arxiv.org/abs/1603.09025).
init_gamma = 0.10  # value suggested by Cooijmans et al.
with tf.variable_scope(scope):
zw = super_linear(
self.hyper_output,
embedding_size,
init_w='constant',
weight_start=0.00,
use_bias=True,
bias_start=1.0,
scope='zw')
alpha = super_linear(
zw,
num_units,
init_w='constant',
weight_start=init_gamma / embedding_size,
use_bias=False,
scope='alpha')
result = tf.multiply(alpha, layer)
if use_bias:
zb = super_linear(
self.hyper_output,
embedding_size,
init_w='gaussian',
weight_start=0.01,
use_bias=False,
bias_start=0.0,
scope='zb')
beta = super_linear(
zb,
num_units,
init_w='constant',
weight_start=0.00,
use_bias=False,
scope='beta')
result += beta
return result
def __call__(self, x, state, timestep=0, scope=None):
with tf.variable_scope(scope or type(self).__name__):
total_h, total_c = tf.split(state, 2, 1)
h = total_h[:, 0:self.num_units]
c = total_c[:, 0:self.num_units]
self.hyper_state = tf.concat(
[total_h[:, self.num_units:], total_c[:, self.num_units:]], 1)
batch_size = x.get_shape().as_list()[0]
x_size = x.get_shape().as_list()[1]
self._input_size = x_size
w_init = None # uniform
h_init = lstm_ortho_initializer(1.0)
w_xh = tf.get_variable(
'W_xh', [x_size, 4 * self.num_units], initializer=w_init)
w_hh = tf.get_variable(
'W_hh', [self.num_units, 4 * self.num_units], initializer=h_init)
bias = tf.get_variable(
'bias', [4 * self.num_units],
initializer=tf.constant_initializer(0.0))
# concatenate the input and hidden states for hyperlstm input
hyper_input = tf.concat([x, h], 1)
hyper_output, hyper_new_state = self.hyper_cell(hyper_input,
self.hyper_state)
self.hyper_output = hyper_output
self.hyper_state = hyper_new_state
xh = tf.matmul(x, w_xh)
hh = tf.matmul(h, w_hh)
# split Wxh contributions
ix, jx, fx, ox = tf.split(xh, 4, 1)
ix = self.hyper_norm(ix, 'hyper_ix', use_bias=False)
jx = self.hyper_norm(jx, 'hyper_jx', use_bias=False)
fx = self.hyper_norm(fx, 'hyper_fx', use_bias=False)
ox = self.hyper_norm(ox, 'hyper_ox', use_bias=False)
# split Whh contributions
ih, jh, fh, oh = tf.split(hh, 4, 1)
ih = self.hyper_norm(ih, 'hyper_ih', use_bias=True)
jh = self.hyper_norm(jh, 'hyper_jh', use_bias=True)
fh = self.hyper_norm(fh, 'hyper_fh', use_bias=True)
oh = self.hyper_norm(oh, 'hyper_oh', use_bias=True)
# split bias
ib, jb, fb, ob = tf.split(bias, 4, 0) # bias is to be broadcasted.
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
i = ix + ih + ib
j = jx + jh + jb
f = fx + fh + fb
o = ox + oh + ob
if self.use_layer_norm:
concat = tf.concat([i, j, f, o], 1)
concat = layer_norm_all(concat, batch_size, 4, self.num_units, 'ln_all')
i, j, f, o = tf.split(concat, 4, 1)
if self.use_recurrent_dropout:
g = tf.nn.dropout(tf.tanh(j), self.dropout_keep_prob)
else:
g = tf.tanh(j)
new_c = c * tf.sigmoid(f + self.forget_bias) + tf.sigmoid(i) * g
new_h = tf.tanh(layer_norm(new_c, self.num_units, 'ln_c')) * tf.sigmoid(o)
hyper_h, hyper_c = tf.split(hyper_new_state, 2, 1)
new_total_h = tf.concat([new_h, hyper_h], 1)
new_total_c = tf.concat([new_c, hyper_c], 1)
new_total_state = tf.concat([new_total_h, new_total_c], 1)
return new_h, new_total_state
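hyper_norm's initialization trick is easiest to see in NumPy: with the constant/zero initializers used above (zero weights with bias 1.0 for zw, a constant init_gamma / embedding_size for alpha), the modulation starts out as a plain scale of init_gamma. A sketch with made-up sizes:

```python
import numpy as np

def hyper_norm(layer, hyper_h, w_z, b_z, w_alpha):
    # Project the hypernetwork state to an embedding, then to a per-unit
    # scale that multiplies the main LSTM's gate pre-activations.
    z = hyper_h @ w_z + b_z      # (N, embedding_size)
    alpha = z @ w_alpha          # (N, num_units)
    return alpha * layer

num_units, embed, hyper = 8, 4, 16
init_gamma = 0.10
# Initialization as in hyper_norm above: zero weights with bias 1 for z,
# and a constant init_gamma / embed for alpha.
w_z, b_z = np.zeros((hyper, embed)), np.ones(embed)
w_alpha = np.full((embed, num_units), init_gamma / embed)
layer = np.ones((1, num_units))
out = hyper_norm(layer, np.random.randn(1, hyper), w_z, b_z, w_alpha)
print(out[0, 0])  # 0.1 -- a constant scale of init_gamma at initialization
```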
================================================
FILE: subnet_tf_utils.py
================================================
import tensorflow as tf
def get_initializer(init_method):
if init_method == 'xavier_normal':
initializer = tf.glorot_normal_initializer()
elif init_method == 'xavier_uniform':
initializer = tf.glorot_uniform_initializer()
elif init_method == 'he_normal':
initializer = tf.keras.initializers.he_normal()
elif init_method == 'he_uniform':
initializer = tf.keras.initializers.he_uniform()
elif init_method == 'lecun_normal':
initializer = tf.keras.initializers.lecun_normal()
elif init_method == 'lecun_uniform':
initializer = tf.keras.initializers.lecun_uniform()
else:
raise Exception('Unknown initializer:', init_method)
return initializer
def lrelu(x, leak=0.2, name="lrelu", alt_relu_impl=False):
with tf.variable_scope(name) as scope:
if alt_relu_impl:
f1 = 0.5 * (1 + leak)
f2 = 0.5 * (1 - leak)
return f1 * x + f2 * abs(x)
else:
return tf.maximum(x, leak * x)
def batchnorm(input, name='batch_norm', init_method=None):
    if init_method is not None:
        initializer = get_initializer(init_method)
    else:
        initializer = tf.random_normal_initializer(1.0, 0.02, dtype=tf.float32)
    with tf.variable_scope(name):
        # this block looks like it has 3 inputs on the graph unless we do this
        input = tf.identity(input)
        channels = input.get_shape()[3]
        offset = tf.get_variable("offset", [channels], dtype=tf.float32, initializer=tf.zeros_initializer())
        scale = tf.get_variable("scale", [channels], dtype=tf.float32,
                                initializer=initializer)
        mean, variance = tf.nn.moments(input, axes=[0, 1, 2], keep_dims=False)
        variance_epsilon = 1e-5
        normalized = tf.nn.batch_normalization(input, mean, variance, offset, scale, variance_epsilon=variance_epsilon)
        return normalized
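`batchnorm` reduces over the batch and both spatial axes (`axes=[0, 1, 2]`), leaving one mean/variance pair per channel. An equivalent NumPy computation (illustrative only):

```python
import numpy as np

def batchnorm_np(x, offset, scale, eps=1e-5):
    # one statistic per channel, shared across the batch and spatial positions
    mean = x.mean(axis=(0, 1, 2))
    var = x.var(axis=(0, 1, 2))
    return scale * (x - mean) / np.sqrt(var + eps) + offset

x = np.random.RandomState(0).randn(2, 8, 8, 3)
y = batchnorm_np(x, offset=np.zeros(3), scale=np.ones(3))
# each output channel is (near) zero-mean, unit-variance
```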
def layernorm(input, name='layer_norm', init_method=None):
    if init_method is not None:
        initializer = get_initializer(init_method)
    else:
        initializer = tf.random_normal_initializer(1.0, 0.02, dtype=tf.float32)
    with tf.variable_scope(name):
        n_neurons = input.get_shape()[3]
        offset = tf.get_variable("offset", [n_neurons], dtype=tf.float32, initializer=tf.zeros_initializer())
        scale = tf.get_variable("scale", [n_neurons], dtype=tf.float32,
                                initializer=initializer)
        offset = tf.reshape(offset, [1, 1, -1])
        scale = tf.reshape(scale, [1, 1, -1])
        mean, variance = tf.nn.moments(input, axes=[1, 2, 3], keep_dims=True)
        variance_epsilon = 1e-5
        normalized = tf.nn.batch_normalization(input, mean, variance, offset, scale, variance_epsilon=variance_epsilon)
        return normalized
def instance_norm(input, name="instance_norm", init_method=None):
    if init_method is not None:
        initializer = get_initializer(init_method)
    else:
        initializer = tf.random_normal_initializer(1.0, 0.02, dtype=tf.float32)
    with tf.variable_scope(name):
        depth = input.get_shape()[3]
        scale = tf.get_variable("scale", [depth], initializer=initializer)
        offset = tf.get_variable("offset", [depth], initializer=tf.constant_initializer(0.0))
        mean, variance = tf.nn.moments(input, axes=[1, 2], keep_dims=True)
        epsilon = 1e-5
        inv = tf.rsqrt(variance + epsilon)
        normalized = (input - mean) * inv
        return scale * normalized + offset
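`instance_norm`, by contrast, keeps a separate mean/variance per sample *and* per channel (`axes=[1, 2]` with `keep_dims=True`), so statistics never mix across the batch. NumPy equivalent (illustrative):

```python
import numpy as np

def instance_norm_np(x, scale, offset, eps=1e-5):
    # statistics over the spatial axes only: shapes broadcast as (N, 1, 1, C)
    mean = x.mean(axis=(1, 2), keepdims=True)
    var = x.var(axis=(1, 2), keepdims=True)
    return scale * (x - mean) * (1.0 / np.sqrt(var + eps)) + offset

x = np.random.RandomState(1).randn(2, 8, 8, 3)
y = instance_norm_np(x, scale=np.ones(3), offset=np.zeros(3))
# every (sample, channel) slice is independently normalized
```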
def linear1d(inputlin, inputdim, outputdim, name="linear1d", init_method=None):
    if init_method is not None:
        initializer = get_initializer(init_method)
    else:
        initializer = None
    with tf.variable_scope(name):
        weight = tf.get_variable("weight", [inputdim, outputdim], initializer=initializer)
        bias = tf.get_variable("bias", [outputdim], dtype=tf.float32, initializer=tf.constant_initializer(0.0))
        return tf.matmul(inputlin, weight) + bias
def general_conv2d(inputconv, output_dim=64, filter_height=4, filter_width=4, stride_height=2, stride_width=2,
                   stddev=0.02, padding="SAME", name="conv2d", do_norm=True, norm_type='instance_norm', do_relu=True,
                   relufactor=0, is_training=True, init_method=None):
    if init_method is not None:
        initializer = get_initializer(init_method)
    else:
        initializer = tf.truncated_normal_initializer(stddev=stddev)
    with tf.variable_scope(name):
        # NOTE: tf.contrib.layers.conv2d takes kernel_size / stride as [height, width];
        # width and height are swapped here, which is harmless since every call site
        # in this file uses square kernels and strides.
        conv = tf.contrib.layers.conv2d(inputconv, output_dim, [filter_width, filter_height],
                                        [stride_width, stride_height], padding, activation_fn=None,
                                        weights_initializer=initializer,
                                        biases_initializer=tf.constant_initializer(0.0))
        if do_norm:
            if norm_type == 'instance_norm':
                conv = instance_norm(conv, init_method=init_method)
                # conv = tf.contrib.layers.instance_norm(conv, epsilon=1e-05, center=True, scale=True,
                #                                        scope='instance_norm')
            elif norm_type == 'batch_norm':
                # conv = batchnorm(conv, init_method=init_method)
                conv = tf.contrib.layers.batch_norm(conv, decay=0.9, is_training=is_training, updates_collections=None,
                                                    epsilon=1e-5, center=True, scale=True, scope="batch_norm")
            elif norm_type == 'layer_norm':
                # conv = layernorm(conv, init_method=init_method)
                conv = tf.contrib.layers.layer_norm(conv, center=True, scale=True, scope='layer_norm')
        if do_relu:
            if relufactor == 0:
                conv = tf.nn.relu(conv, "relu")
            else:
                conv = lrelu(conv, relufactor, "lrelu")
        return conv
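With `padding="SAME"`, the output spatial size of `general_conv2d` is `ceil(input / stride)` regardless of kernel size, so the stride-2/stride-1 pairs in the encoders below halve the resolution five times in total. A quick sanity check (illustrative, not repo code):

```python
import math

def same_conv_out(size, stride):
    # "SAME" padding: output = ceil(input / stride), independent of kernel size
    return math.ceil(size / stride)

size = 128  # the encoders assume 128x128 inputs
for stride in [2, 1] * 5:  # five (stride-2, stride-1) conv pairs
    size = same_conv_out(size, stride)
print(size)  # 4, matching the 4 * 4 factor in the reshape/linear1d calls
```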
def generative_cnn_c3_encoder(inputs, is_training=True, drop_keep_prob=0.5, init_method=None):
    tensor_x = inputs
    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        tensor_x = general_conv2d(tensor_x, 32, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_1", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 32, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_1_2", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 64, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_2", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 64, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_2_2", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 128, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_3", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_3_2", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 256, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_4", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_4_2", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 256, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_5", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_5_2", init_method=init_method)
        tensor_x_sp = tensor_x  # [N, h, w, 256]
        tensor_x = tf.reshape(tensor_x, (-1, 256 * 4 * 4))
        tensor_x = linear1d(tensor_x, 256 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method)
        if is_training:
            tensor_x = tf.nn.dropout(tensor_x, drop_keep_prob)
    return tensor_x, tensor_x_sp
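The training-only dropout at the end of each encoder uses TF1 `tf.nn.dropout(x, keep_prob)` semantics: each unit survives with probability `keep_prob` and survivors are scaled by `1/keep_prob`, keeping the expected activation unchanged. A NumPy sketch (illustrative only):

```python
import numpy as np

def dropout_np(x, keep_prob, rng):
    # TF1 semantics: keep with prob keep_prob, scale survivors by 1/keep_prob
    # so the expected activation is unchanged
    mask = rng.uniform(size=x.shape) < keep_prob
    return x * mask / keep_prob

rng = np.random.RandomState(0)
y = dropout_np(np.ones((1000, 128)), keep_prob=0.5, rng=rng)
# surviving entries are 2.0, dropped ones 0.0; the mean stays near 1.0
```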
def generative_cnn_c3_encoder_deeper(inputs, is_training=True, drop_keep_prob=0.5, init_method=None):
    tensor_x = inputs
    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        tensor_x = general_conv2d(tensor_x, 32, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_1", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 32, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_1_2", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 64, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_2", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 64, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_2_2", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 128, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_3", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_3_2", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 256, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_4", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_4_2", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 512, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_5", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_5_2", init_method=init_method)
        tensor_x_sp = tensor_x  # [N, h, w, 512]
        tensor_x = tf.reshape(tensor_x, (-1, 512 * 4 * 4))
        tensor_x = linear1d(tensor_x, 512 * 4 * 4, 512, name='CNN_ENC_FC', init_method=init_method)
        if is_training:
            tensor_x = tf.nn.dropout(tensor_x, drop_keep_prob)
    return tensor_x, tensor_x_sp
def generative_cnn_c3_encoder_combine33(local_inputs, global_inputs, is_training=True, drop_keep_prob=0.5, init_method=None):
    local_x = local_inputs
    global_x = global_inputs
    with tf.variable_scope('Local_Encoder', reuse=tf.AUTO_REUSE):
        local_x = general_conv2d(local_x, 32, filter_height=3, filter_width=3,
                                 is_training=is_training, name="CNN_ENC_1", init_method=init_method)
        local_x = general_conv2d(local_x, 32, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_1_2", init_method=init_method)
        local_x = general_conv2d(local_x, 64, filter_height=3, filter_width=3,
                                 is_training=is_training, name="CNN_ENC_2", init_method=init_method)
        local_x = general_conv2d(local_x, 64, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_2_2", init_method=init_method)
        local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3,
                                 is_training=is_training, name="CNN_ENC_3", init_method=init_method)
        local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_3_2", init_method=init_method)
        local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_3_3", init_method=init_method)
    with tf.variable_scope('Global_Encoder', reuse=tf.AUTO_REUSE):
        global_x = general_conv2d(global_x, 32, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_1", init_method=init_method)
        global_x = general_conv2d(global_x, 32, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_1_2", init_method=init_method)
        global_x = general_conv2d(global_x, 64, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_2", init_method=init_method)
        global_x = general_conv2d(global_x, 64, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_2_2", init_method=init_method)
        global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_3", init_method=init_method)
        global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_3_2", init_method=init_method)
        global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_3_3", init_method=init_method)
    tensor_x = tf.concat([local_x, global_x], axis=-1)
    with tf.variable_scope('Combined_Encoder', reuse=tf.AUTO_REUSE):
        tensor_x = general_conv2d(tensor_x, 256, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_4", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_4_2", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_4_3", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 512, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_5", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_5_2", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_5_3", init_method=init_method)
        tensor_x_sp = tensor_x  # [N, h, w, 512]
        tensor_x = tf.reshape(tensor_x, (-1, 512 * 4 * 4))
        tensor_x = linear1d(tensor_x, 512 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method)
        if is_training:
            tensor_x = tf.nn.dropout(tensor_x, drop_keep_prob)
    return tensor_x, tensor_x_sp
def generative_cnn_c3_encoder_combine43(local_inputs, global_inputs, is_training=True, drop_keep_prob=0.5, init_method=None):
    local_x = local_inputs
    global_x = global_inputs
    with tf.variable_scope('Local_Encoder', reuse=tf.AUTO_REUSE):
        local_x = general_conv2d(local_x, 32, filter_height=3, filter_width=3,
                                 is_training=is_training, name="CNN_ENC_1", init_method=init_method)
        local_x = general_conv2d(local_x, 32, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_1_2", init_method=init_method)
        local_x = general_conv2d(local_x, 64, filter_height=3, filter_width=3,
                                 is_training=is_training, name="CNN_ENC_2", init_method=init_method)
        local_x = general_conv2d(local_x, 64, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_2_2", init_method=init_method)
        local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3,
                                 is_training=is_training, name="CNN_ENC_3", init_method=init_method)
        local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_3_2", init_method=init_method)
        local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_3_3", init_method=init_method)
        local_x = general_conv2d(local_x, 256, filter_height=3, filter_width=3,
                                 is_training=is_training, name="CNN_ENC_4", init_method=init_method)
        local_x = general_conv2d(local_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_4_2", init_method=init_method)
        local_x = general_conv2d(local_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_4_3", init_method=init_method)
    with tf.variable_scope('Global_Encoder', reuse=tf.AUTO_REUSE):
        global_x = general_conv2d(global_x, 32, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_1", init_method=init_method)
        global_x = general_conv2d(global_x, 32, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_1_2", init_method=init_method)
        global_x = general_conv2d(global_x, 64, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_2", init_method=init_method)
        global_x = general_conv2d(global_x, 64, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_2_2", init_method=init_method)
        global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_3", init_method=init_method)
        global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_3_2", init_method=init_method)
        global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_3_3", init_method=init_method)
        global_x = general_conv2d(global_x, 256, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_4", init_method=init_method)
        global_x = general_conv2d(global_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_4_2", init_method=init_method)
        global_x = general_conv2d(global_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_4_3", init_method=init_method)
    tensor_x = tf.concat([local_x, global_x], axis=-1)
    with tf.variable_scope('Combined_Encoder', reuse=tf.AUTO_REUSE):
        tensor_x = general_conv2d(tensor_x, 512, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_5", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_5_2", init_method=init_method)
        tensor_x = general_conv2d(tensor_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_5_3", init_method=init_method)
        tensor_x_sp = tensor_x  # [N, h, w, 512]
        tensor_x = tf.reshape(tensor_x, (-1, 512 * 4 * 4))
        tensor_x = linear1d(tensor_x, 512 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method)
        if is_training:
            tensor_x = tf.nn.dropout(tensor_x, drop_keep_prob)
    return tensor_x, tensor_x_sp
def generative_cnn_c3_encoder_combine53(local_inputs, global_inputs, is_training=True, drop_keep_prob=0.5, init_method=None):
    local_x = local_inputs
    global_x = global_inputs
    with tf.variable_scope('Local_Encoder', reuse=tf.AUTO_REUSE):
        local_x = general_conv2d(local_x, 32, filter_height=3, filter_width=3,
                                 is_training=is_training, name="CNN_ENC_1", init_method=init_method)
        local_x = general_conv2d(local_x, 32, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_1_2", init_method=init_method)
        local_x = general_conv2d(local_x, 64, filter_height=3, filter_width=3,
                                 is_training=is_training, name="CNN_ENC_2", init_method=init_method)
        local_x = general_conv2d(local_x, 64, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_2_2", init_method=init_method)
        local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3,
                                 is_training=is_training, name="CNN_ENC_3", init_method=init_method)
        local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_3_2", init_method=init_method)
        local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_3_3", init_method=init_method)
        local_x = general_conv2d(local_x, 256, filter_height=3, filter_width=3,
                                 is_training=is_training, name="CNN_ENC_4", init_method=init_method)
        local_x = general_conv2d(local_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_4_2", init_method=init_method)
        local_x = general_conv2d(local_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_4_3", init_method=init_method)
        local_x = general_conv2d(local_x, 512, filter_height=3, filter_width=3,
                                 is_training=is_training, name="CNN_ENC_5", init_method=init_method)
        local_x = general_conv2d(local_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_5_2", init_method=init_method)
        local_x = general_conv2d(local_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_5_3", init_method=init_method)
    with tf.variable_scope('Global_Encoder', reuse=tf.AUTO_REUSE):
        global_x = general_conv2d(global_x, 32, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_1", init_method=init_method)
        global_x = general_conv2d(global_x, 32, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_1_2", init_method=init_method)
        global_x = general_conv2d(global_x, 64, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_2", init_method=init_method)
        global_x = general_conv2d(global_x, 64, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_2_2", init_method=init_method)
        global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_3", init_method=init_method)
        global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_3_2", init_method=init_method)
        global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_3_3", init_method=init_method)
        global_x = general_conv2d(global_x, 256, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_4", init_method=init_method)
        global_x = general_conv2d(global_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_4_2", init_method=init_method)
        global_x = general_conv2d(global_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_4_3", init_method=init_method)
        global_x = general_conv2d(global_x, 512, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_5", init_method=init_method)
        global_x = general_conv2d(global_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_5_2", init_method=init_method)
        global_x = general_conv2d(global_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_5_3", init_method=init_method)
    tensor_x = tf.concat([local_x, global_x], axis=-1)
    with tf.variable_scope('Combined_Encoder', reuse=tf.AUTO_REUSE):
        tensor_x_sp = tensor_x  # [N, h, w, 1024]
        tensor_x = tf.reshape(tensor_x, (-1, 1024 * 4 * 4))
        tensor_x = linear1d(tensor_x, 1024 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method)
        if is_training:
            tensor_x = tf.nn.dropout(tensor_x, drop_keep_prob)
    return tensor_x, tensor_x_sp
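The `combine*` variants differ only in how deep the local and global streams run before concatenation, and that choice fixes the width of the final `CNN_ENC_FC` input. The arithmetic (illustrative, not repo code):

```python
# All variants end on a 4x4 spatial map (for 128x128 inputs).
# combine33: fuse after stage 3 (128+128 ch), re-encode to 512 -> FC in = 512*4*4
# combine43: fuse after stage 4 (256+256 ch), re-encode to 512 -> FC in = 512*4*4
# combine53: fuse after stage 5 (512+512 ch), no re-encoding  -> FC in = 1024*4*4
fc_in = {
    'combine33': 512 * 4 * 4,
    'combine43': 512 * 4 * 4,
    'combine53': (512 + 512) * 4 * 4,
}
print(fc_in['combine53'])  # 16384
```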
def generative_cnn_c3_encoder_combineFC(local_inputs, global_inputs, is_training=True, drop_keep_prob=0.5, init_method=None):
    local_x = local_inputs
    global_x = global_inputs
    with tf.variable_scope('Local_Encoder', reuse=tf.AUTO_REUSE):
        local_x = general_conv2d(local_x, 32, filter_height=3, filter_width=3,
                                 is_training=is_training, name="CNN_ENC_1", init_method=init_method)
        local_x = general_conv2d(local_x, 32, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_1_2", init_method=init_method)
        local_x = general_conv2d(local_x, 64, filter_height=3, filter_width=3,
                                 is_training=is_training, name="CNN_ENC_2", init_method=init_method)
        local_x = general_conv2d(local_x, 64, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_2_2", init_method=init_method)
        local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3,
                                 is_training=is_training, name="CNN_ENC_3", init_method=init_method)
        local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_3_2", init_method=init_method)
        local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_3_3", init_method=init_method)
        local_x = general_conv2d(local_x, 256, filter_height=3, filter_width=3,
                                 is_training=is_training, name="CNN_ENC_4", init_method=init_method)
        local_x = general_conv2d(local_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_4_2", init_method=init_method)
        local_x = general_conv2d(local_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_4_3", init_method=init_method)
        local_x = general_conv2d(local_x, 512, filter_height=3, filter_width=3,
                                 is_training=is_training, name="CNN_ENC_5", init_method=init_method)
        local_x = general_conv2d(local_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_5_2", init_method=init_method)
        local_x = general_conv2d(local_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_5_3", init_method=init_method)
        local_x = tf.reshape(local_x, (-1, 512 * 4 * 4))
        local_x = linear1d(local_x, 512 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method)
        if is_training:
            local_x = tf.nn.dropout(local_x, drop_keep_prob)
    with tf.variable_scope('Global_Encoder', reuse=tf.AUTO_REUSE):
        global_x = general_conv2d(global_x, 32, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_1", init_method=init_method)
        global_x = general_conv2d(global_x, 32, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_1_2", init_method=init_method)
        global_x = general_conv2d(global_x, 64, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_2", init_method=init_method)
        global_x = general_conv2d(global_x, 64, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_2_2", init_method=init_method)
        global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_3", init_method=init_method)
        global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_3_2", init_method=init_method)
        global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_3_3", init_method=init_method)
        global_x = general_conv2d(global_x, 256, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_4", init_method=init_method)
        global_x = general_conv2d(global_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_4_2", init_method=init_method)
        global_x = general_conv2d(global_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_4_3", init_method=init_method)
        global_x = general_conv2d(global_x, 512, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_5", init_method=init_method)
        global_x = general_conv2d(global_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_5_2", init_method=init_method)
        global_x = general_conv2d(global_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_5_3", init_method=init_method)
        global_x = tf.reshape(global_x, (-1, 512 * 4 * 4))
        global_x = linear1d(global_x, 512 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method)
        if is_training:
            global_x = tf.nn.dropout(global_x, drop_keep_prob)
    tensor_x_sp = None
    tensor_x = tf.concat([local_x, global_x], axis=-1)
    return tensor_x, tensor_x_sp
def generative_cnn_c3_encoder_combineFC_jointAttn(local_inputs, global_inputs, is_training=True, drop_keep_prob=0.5,
                                                  init_method=None, combine_manner='attn'):
    local_x = local_inputs
    global_x = global_inputs
    with tf.variable_scope('Local_Encoder', reuse=tf.AUTO_REUSE):
        local_x = general_conv2d(local_x, 32, filter_height=3, filter_width=3,
                                 is_training=is_training, name="CNN_ENC_1", init_method=init_method)
        local_x = general_conv2d(local_x, 32, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_1_2", init_method=init_method)
        local_x = general_conv2d(local_x, 64, filter_height=3, filter_width=3,
                                 is_training=is_training, name="CNN_ENC_2", init_method=init_method)
        local_x = general_conv2d(local_x, 64, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_2_2", init_method=init_method)
        local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3,
                                 is_training=is_training, name="CNN_ENC_3", init_method=init_method)
        local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_3_2", init_method=init_method)
        local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_3_3", init_method=init_method)
        share_x = local_x
        with tf.variable_scope('Attn_branch', reuse=tf.AUTO_REUSE):
            attn_x = general_conv2d(share_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                    is_training=is_training, name="CNN_ENC_1", init_method=init_method)
            attn_x = general_conv2d(attn_x, 32, filter_height=1, filter_width=1, stride_height=1, stride_width=1,
                                    is_training=is_training, name="CNN_ENC_2", init_method=init_method)
            attn_x = general_conv2d(attn_x, 1, filter_height=1, filter_width=1, stride_height=1, stride_width=1,
                                    is_training=is_training, name="CNN_ENC_3", init_method=init_method)
            attn_map = tf.nn.sigmoid(attn_x)  # (N, H/8, W/8, 1), [0.0, 1.0]
        if combine_manner == 'attn':
            attn_feat = attn_map * share_x + share_x
        else:
            raise Exception('Unknown combine_manner', combine_manner)
        local_x = general_conv2d(attn_feat, 256, filter_height=3, filter_width=3,
                                 is_training=is_training, name="CNN_ENC_4", init_method=init_method)
        local_x = general_conv2d(local_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_4_2", init_method=init_method)
        local_x = general_conv2d(local_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_4_3", init_method=init_method)
        local_x = general_conv2d(local_x, 512, filter_height=3, filter_width=3,
                                 is_training=is_training, name="CNN_ENC_5", init_method=init_method)
        local_x = general_conv2d(local_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_5_2", init_method=init_method)
        local_x = general_conv2d(local_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_5_3", init_method=init_method)
        local_x = tf.reshape(local_x, (-1, 512 * 4 * 4))
        local_x = linear1d(local_x, 512 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method)
        if is_training:
            local_x = tf.nn.dropout(local_x, drop_keep_prob)
    with tf.variable_scope('Global_Encoder', reuse=tf.AUTO_REUSE):
        global_x = general_conv2d(global_x, 32, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_1", init_method=init_method)
        global_x = general_conv2d(global_x, 32, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_1_2", init_method=init_method)
        global_x = general_conv2d(global_x, 64, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_2", init_method=init_method)
        global_x = general_conv2d(global_x, 64, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_2_2", init_method=init_method)
        global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_3", init_method=init_method)
        global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_3_2", init_method=init_method)
        global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_3_3", init_method=init_method)
        global_x = general_conv2d(global_x, 256, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_4", init_method=init_method)
        global_x = general_conv2d(global_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_4_2", init_method=init_method)
        global_x = general_conv2d(global_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_4_3", init_method=init_method)
        global_x = general_conv2d(global_x, 512, filter_height=3, filter_width=3,
                                  is_training=is_training, name="CNN_ENC_5", init_method=init_method)
        global_x = general_conv2d(global_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_5_2", init_method=init_method)
        global_x = general_conv2d(global_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                  is_training=is_training, name="CNN_ENC_5_3", init_method=init_method)
        global_x = tf.reshape(global_x, (-1, 512 * 4 * 4))
        global_x = linear1d(global_x, 512 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method)
        if is_training:
            global_x = tf.nn.dropout(global_x, drop_keep_prob)
    tensor_x_sp = None
    tensor_x = tf.concat([local_x, global_x], axis=-1)
    return tensor_x, tensor_x_sp, attn_map
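The `attn_feat = attn_map * share_x + share_x` step above is a residual attention gate: because `attn_map` comes from a sigmoid, each spatial position is scaled by a factor in `(1, 2)`, so attended regions are amplified but no feature is ever zeroed out. A NumPy sketch (illustrative, not repo code):

```python
import numpy as np

def residual_attention(feat, attn_logits):
    # gate in (0, 1) via sigmoid; "gate * feat + feat" scales features by (1, 2)
    gate = 1.0 / (1.0 + np.exp(-attn_logits))
    return gate * feat + feat

feat = np.ones((1, 4, 4, 8))
out = residual_attention(feat, np.zeros((1, 4, 4, 1)))  # sigmoid(0) = 0.5
print(out[0, 0, 0, 0])  # 1.5
```

The (N, H, W, 1) gate broadcasts across all channels, mirroring the shape of `attn_map` in the graph above.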
def generative_cnn_c3_encoder_combineFC_sepAttn(local_inputs, global_inputs, is_training=True, drop_keep_prob=0.5,
                                                init_method=None, combine_manner='attn'):
    local_x = local_inputs
    global_x = global_inputs
    with tf.variable_scope('Attn_branch', reuse=tf.AUTO_REUSE):
        attn_x = general_conv2d(local_x, 32, filter_height=3, filter_width=3,
                                is_training=is_training, name="CNN_ENC_1", init_method=init_method)
        attn_x = general_conv2d(attn_x, 32, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                is_training=is_training, name="CNN_ENC_1_2", init_method=init_method)
        attn_x = general_conv2d(attn_x, 64, filter_height=3, filter_width=3,
                                is_training=is_training, name="CNN_ENC_2", init_method=init_method)
        attn_x = general_conv2d(attn_x, 64, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                is_training=is_training, name="CNN_ENC_2_2", init_method=init_method)
        attn_x = general_conv2d(attn_x, 128, filter_height=3, filter_width=3,
                                is_training=is_training, name="CNN_ENC_3", init_method=init_method)
        attn_x = general_conv2d(attn_x, 32, filter_height=1, filter_width=1, stride_height=1, stride_width=1,
                                is_training=is_training, name="CNN_ENC_3_2", init_method=init_method)
        attn_x = general_conv2d(attn_x, 1, filter_height=1, filter_width=1, stride_height=1, stride_width=1,
                                is_training=is_training, name="CNN_ENC_3_3", init_method=init_method)
        attn_map = tf.nn.sigmoid(attn_x)  # (N, H/8, W/8, 1), [0.0, 1.0]
    with tf.variable_scope('Local_Encoder', reuse=tf.AUTO_REUSE):
        local_x = general_conv2d(local_x, 32, filter_height=3, filter_width=3,
                                 is_training=is_training, name="CNN_ENC_1", init_method=init_method)
        local_x = general_conv2d(local_x, 32, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_1_2", init_method=init_method)
        local_x = general_conv2d(local_x, 64, filter_height=3, filter_width=3,
                                 is_training=is_training, name="CNN_ENC_2", init_method=init_method)
        local_x = general_conv2d(local_x, 64, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_2_2", init_method=init_method)
        local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3,
                                 is_training=is_training, name="CNN_ENC_3", init_method=init_method)
        local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
                                 is_training=is_training, name="CNN_ENC_3_2", init_method=init_method)
        local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
is_training=is_training, name="CNN_ENC_3_3", init_method=init_method)
if combine_manner == 'attn':
attn_feat = attn_map * local_x + local_x
elif combine_manner == 'channel':
attn_feat = tf.concat([local_x, attn_map], axis=-1)
else:
raise Exception('Unknown combine_manner', combine_manner)
local_x = general_conv2d(attn_feat, 256, filter_height=3, filter_width=3,
is_training=is_training, name="CNN_ENC_4", init_method=init_method)
local_x = general_conv2d(local_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
is_training=is_training, name="CNN_ENC_4_2", init_method=init_method)
local_x = general_conv2d(local_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
is_training=is_training, name="CNN_ENC_4_3", init_method=init_method)
local_x = general_conv2d(local_x, 512, filter_height=3, filter_width=3,
is_training=is_training, name="CNN_ENC_5", init_method=init_method)
local_x = general_conv2d(local_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
is_training=is_training, name="CNN_ENC_5_2", init_method=init_method)
local_x = general_conv2d(local_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
is_training=is_training, name="CNN_ENC_5_3", init_method=init_method)
local_x = tf.reshape(local_x, (-1, 512 * 4 * 4))
local_x = linear1d(local_x, 512 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method)
if is_training:
local_x = tf.nn.dropout(local_x, drop_keep_prob)
with tf.variable_scope('Global_Encoder', reuse=tf.AUTO_REUSE):
global_x = general_conv2d(global_x, 32, filter_height=3, filter_width=3,
is_training=is_training, name="CNN_ENC_1", init_method=init_method)
global_x = general_conv2d(global_x, 32, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
is_training=is_training, name="CNN_ENC_1_2", init_method=init_method)
global_x = general_conv2d(global_x, 64, filter_height=3, filter_width=3,
is_training=is_training, name="CNN_ENC_2", init_method=init_method)
global_x = general_conv2d(global_x, 64, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
is_training=is_training, name="CNN_ENC_2_2", init_method=init_method)
global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3,
is_training=is_training, name="CNN_ENC_3", init_method=init_method)
global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
is_training=is_training, name="CNN_ENC_3_2", init_method=init_method)
global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
is_training=is_training, name="CNN_ENC_3_3", init_method=init_method)
global_x = general_conv2d(global_x, 256, filter_height=3, filter_width=3,
is_training=is_training, name="CNN_ENC_4", init_method=init_method)
global_x = general_conv2d(global_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
is_training=is_training, name="CNN_ENC_4_2", init_method=init_method)
global_x = general_conv2d(global_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
is_training=is_training, name="CNN_ENC_4_3", init_method=init_method)
global_x = general_conv2d(global_x, 512, filter_height=3, filter_width=3,
is_training=is_training, name="CNN_ENC_5", init_method=init_method)
global_x = general_conv2d(global_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
is_training=is_training, name="CNN_ENC_5_2", init_method=init_method)
global_x = general_conv2d(global_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
is_training=is_training, name="CNN_ENC_5_3", init_method=init_method)
global_x = tf.reshape(global_x, (-1, 512 * 4 * 4))
global_x = linear1d(global_x, 512 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method)
if is_training:
global_x = tf.nn.dropout(global_x, drop_keep_prob)
tensor_x_sp = None
tensor_x = tf.concat([local_x, global_x], axis=-1)
return tensor_x, tensor_x_sp, attn_map
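The two `combine_manner` branches above differ only in how the attention map enters the feature stream: `'attn'` gates the local features residually, while `'channel'` appends the map as an extra channel. A minimal NumPy sketch of the shape arithmetic (batch and spatial sizes are illustrative, not taken from the model):

```python
import numpy as np

# illustrative shapes: batch 2, 16x16 feature map, 128 channels
local_x = np.random.rand(2, 16, 16, 128).astype(np.float32)
attn_map = np.random.rand(2, 16, 16, 1).astype(np.float32)  # sigmoid output in [0, 1]

# 'attn': residual attention gating -- features are scaled and re-added
attn_feat = attn_map * local_x + local_x                     # (2, 16, 16, 128)

# 'channel': the map is appended as an extra feature channel
channel_feat = np.concatenate([local_x, attn_map], axis=-1)  # (2, 16, 16, 129)

assert attn_feat.shape == (2, 16, 16, 128)
assert channel_feat.shape == (2, 16, 16, 129)
```

Note that only the `'channel'` variant changes the input channel count of the following `CNN_ENC_4` convolution.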
def generative_cnn_c3_encoder_deeper13(inputs, is_training=True, drop_keep_prob=0.5, init_method=None):
tensor_x = inputs
with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE) as scope:
tensor_x = general_conv2d(tensor_x, 32, filter_height=3, filter_width=3,
is_training=is_training, name="CNN_ENC_1", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 32, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
is_training=is_training, name="CNN_ENC_1_2", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 64, filter_height=3, filter_width=3,
is_training=is_training, name="CNN_ENC_2", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 64, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
is_training=is_training, name="CNN_ENC_2_2", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 128, filter_height=3, filter_width=3,
is_training=is_training, name="CNN_ENC_3", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
is_training=is_training, name="CNN_ENC_3_2", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
is_training=is_training, name="CNN_ENC_3_3", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 256, filter_height=3, filter_width=3,
is_training=is_training, name="CNN_ENC_4", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
is_training=is_training, name="CNN_ENC_4_2", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
is_training=is_training, name="CNN_ENC_4_3", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 512, filter_height=3, filter_width=3,
is_training=is_training, name="CNN_ENC_5", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
is_training=is_training, name="CNN_ENC_5_2", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
is_training=is_training, name="CNN_ENC_5_3", init_method=init_method)
tensor_x_sp = tensor_x # [N, h, w, 512]
tensor_x = tf.reshape(tensor_x, (-1, 512 * 4 * 4))
tensor_x = linear1d(tensor_x, 512 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method)
if is_training:
tensor_x = tf.nn.dropout(tensor_x, drop_keep_prob)
return tensor_x, tensor_x_sp
def generative_cnn_c3_encoder_deeper13_attn(inputs, is_training=True, drop_keep_prob=0.5, init_method=None):
tensor_x = inputs
with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE) as scope:
tensor_x = general_conv2d(tensor_x, 32, filter_height=3, filter_width=3,
is_training=is_training, name="CNN_ENC_1", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 32, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
is_training=is_training, name="CNN_ENC_1_2", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 64, filter_height=3, filter_width=3,
is_training=is_training, name="CNN_ENC_2", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 64, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
is_training=is_training, name="CNN_ENC_2_2", init_method=init_method)
tensor_x = self_attention(tensor_x, 64)
tensor_x = general_conv2d(tensor_x, 128, filter_height=3, filter_width=3,
is_training=is_training, name="CNN_ENC_3", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
is_training=is_training, name="CNN_ENC_3_2", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
is_training=is_training, name="CNN_ENC_3_3", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 256, filter_height=3, filter_width=3,
is_training=is_training, name="CNN_ENC_4", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
is_training=is_training, name="CNN_ENC_4_2", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
is_training=is_training, name="CNN_ENC_4_3", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 512, filter_height=3, filter_width=3,
is_training=is_training, name="CNN_ENC_5", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
is_training=is_training, name="CNN_ENC_5_2", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1,
is_training=is_training, name="CNN_ENC_5_3", init_method=init_method)
tensor_x_sp = tensor_x # [N, h, w, 512]
tensor_x = tf.reshape(tensor_x, (-1, 512 * 4 * 4))
tensor_x = linear1d(tensor_x, 512 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method)
if is_training:
tensor_x = tf.nn.dropout(tensor_x, drop_keep_prob)
return tensor_x, tensor_x_sp
def generative_cnn_encoder(inputs, is_training=True, drop_keep_prob=0.5, init_method=None):
tensor_x = inputs
with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE) as scope:
tensor_x = general_conv2d(tensor_x, 32, is_training=is_training, name="CNN_ENC_1", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 32, stride_height=1, stride_width=1, is_training=is_training,
name="CNN_ENC_1_2", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 64, is_training=is_training, name="CNN_ENC_2", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 64, stride_height=1, stride_width=1, is_training=is_training,
name="CNN_ENC_2_2", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 128, is_training=is_training, name="CNN_ENC_3", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 128, stride_height=1, stride_width=1, is_training=is_training,
name="CNN_ENC_3_2", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 256, is_training=is_training, name="CNN_ENC_4", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 256, stride_height=1, stride_width=1, is_training=is_training,
name="CNN_ENC_4_2", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 256, is_training=is_training, name="CNN_ENC_5", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 256, stride_height=1, stride_width=1, is_training=is_training,
name="CNN_ENC_5_2", init_method=init_method)
tensor_x_sp = tensor_x # [N, h, w, 256]
tensor_x = tf.reshape(tensor_x, (-1, 256 * 4 * 4))
tensor_x = linear1d(tensor_x, 256 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method)
if is_training:
tensor_x = tf.nn.dropout(tensor_x, drop_keep_prob)
return tensor_x, tensor_x_sp
def generative_cnn_encoder_deeper(inputs, is_training=True, drop_keep_prob=0.5, init_method=None):
tensor_x = inputs
with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE) as scope:
tensor_x = general_conv2d(tensor_x, 32, is_training=is_training, name="CNN_ENC_1", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 32, stride_height=1, stride_width=1, is_training=is_training,
name="CNN_ENC_1_2", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 64, is_training=is_training, name="CNN_ENC_2", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 64, stride_height=1, stride_width=1, is_training=is_training,
name="CNN_ENC_2_2", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 128, is_training=is_training, name="CNN_ENC_3", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 128, stride_height=1, stride_width=1, is_training=is_training,
name="CNN_ENC_3_2", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 256, is_training=is_training, name="CNN_ENC_4", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 256, stride_height=1, stride_width=1, is_training=is_training,
name="CNN_ENC_4_2", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 512, is_training=is_training, name="CNN_ENC_5", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 512, stride_height=1, stride_width=1, is_training=is_training,
name="CNN_ENC_5_2", init_method=init_method)
tensor_x_sp = tensor_x # [N, h, w, 512]
tensor_x = tf.reshape(tensor_x, (-1, 512 * 4 * 4))
tensor_x = linear1d(tensor_x, 512 * 4 * 4, 512, name='CNN_ENC_FC', init_method=init_method)
if is_training:
tensor_x = tf.nn.dropout(tensor_x, drop_keep_prob)
return tensor_x, tensor_x_sp
def generative_cnn_encoder_deeper13(inputs, is_training=True, drop_keep_prob=0.5, init_method=None):
tensor_x = inputs
with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE) as scope:
tensor_x = general_conv2d(tensor_x, 32, is_training=is_training,
name="CNN_ENC_1", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 32, stride_height=1, stride_width=1, is_training=is_training,
name="CNN_ENC_1_2", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 64, is_training=is_training,
name="CNN_ENC_2", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 64, stride_height=1, stride_width=1, is_training=is_training,
name="CNN_ENC_2_2", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 128, is_training=is_training,
name="CNN_ENC_3", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 128, stride_height=1, stride_width=1, is_training=is_training,
name="CNN_ENC_3_2", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 128, stride_height=1, stride_width=1, is_training=is_training,
name="CNN_ENC_3_3", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 256, is_training=is_training,
name="CNN_ENC_4", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 256, stride_height=1, stride_width=1, is_training=is_training,
name="CNN_ENC_4_2", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 256, stride_height=1, stride_width=1, is_training=is_training,
name="CNN_ENC_4_3", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 256, is_training=is_training,
name="CNN_ENC_5", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 256, stride_height=1, stride_width=1, is_training=is_training,
name="CNN_ENC_5_2", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 256, stride_height=1, stride_width=1, is_training=is_training,
name="CNN_ENC_5_3", init_method=init_method)
tensor_x_sp = tensor_x # [N, h, w, 256]
tensor_x = tf.reshape(tensor_x, (-1, 256 * 4 * 4))
tensor_x = linear1d(tensor_x, 256 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method)
if is_training:
tensor_x = tf.nn.dropout(tensor_x, drop_keep_prob)
return tensor_x, tensor_x_sp
def max_pooling(x):
return tf.layers.max_pooling2d(x, pool_size=2, strides=2, padding='SAME')
def hw_flatten(x):
return tf.reshape(x, shape=[x.shape[0], -1, x.shape[-1]])
def self_attention(x, in_channel, name='self_attention'):
with tf.variable_scope(name) as scope:
f = general_conv2d(x, in_channel // 8, filter_height=1, filter_width=1, stride_height=1, stride_width=1,
do_norm=False, do_relu=False, name='f_conv') # (N, h, w, c')
f = max_pooling(f) # (N, h', w', c')
g = general_conv2d(x, in_channel // 8, filter_height=1, filter_width=1, stride_height=1, stride_width=1,
do_norm=False, do_relu=False, name='g_conv') # (N, h, w, c')
h = general_conv2d(x, in_channel, filter_height=1, filter_width=1, stride_height=1, stride_width=1,
do_norm=False, do_relu=False, name='h_conv') # (N, h, w, c)
h = max_pooling(h) # (N, h', w', c)
# M = h * w, M' = h' * w'
s = tf.matmul(hw_flatten(g), hw_flatten(f), transpose_b=True) # (N, M, M')
beta = tf.nn.softmax(s) # attention map
o = tf.matmul(beta, hw_flatten(h)) # (N, M, c)
gamma = tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0))
o = tf.reshape(o, shape=x.shape) # (N, h, w, c)
o = general_conv2d(o, in_channel, filter_height=1, filter_width=1, stride_height=1, stride_width=1,
do_norm=False, do_relu=False, name='attn_conv')
x = gamma * o + x
return x
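The tensor shapes flowing through `self_attention` can be checked with a plain NumPy sketch (dimensions here are illustrative; the manual softmax stands in for `tf.nn.softmax` over the last axis):

```python
import numpy as np

N, H, W, C = 1, 8, 8, 64
Hp, Wp = H // 2, W // 2              # spatial size after 2x2 max pooling
M, Mp, Cp = H * W, Hp * Wp, C // 8   # M = h*w, M' = h'*w', c' = c/8

f = np.random.rand(N, Mp, Cp)        # pooled "key" features
g = np.random.rand(N, M, Cp)         # "query" features
h = np.random.rand(N, Mp, C)         # pooled "value" features

s = g @ f.transpose(0, 2, 1)         # (N, M, M') attention logits
beta = np.exp(s) / np.exp(s).sum(axis=-1, keepdims=True)  # softmax over M'
o = beta @ h                         # (N, M, C) attended features
assert o.shape == (N, M, C)
```

Pooling `f` and `h` shrinks the attention matrix from (M, M) to (M, M'), which cuts memory by a factor of four here; the learned `gamma` (initialized to 0) lets the residual path dominate early in training.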
def global_avg_pooling(x):
gap = tf.reduce_mean(x, axis=[1, 2])
return gap
def cnn_discriminator_wgan_gp(discrim_inputs, discrim_targets, init_method=None):
tensor_x = tf.concat([discrim_inputs, discrim_targets], axis=3) # (N, H, W, 3 + 1)
with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE) as scope:
tensor_x = general_conv2d(tensor_x, 32, filter_height=3, filter_width=3,
is_training=True, name="CNN_ENC_1", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 64, filter_height=3, filter_width=3,
is_training=True, name="CNN_ENC_2", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 128, filter_height=3, filter_width=3,
is_training=True, name="CNN_ENC_3", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 128, filter_height=3, filter_width=3,
is_training=True, name="CNN_ENC_4", init_method=init_method)
tensor_x = general_conv2d(tensor_x, 1, filter_height=3, filter_width=3,
is_training=True, name="CNN_ENC_5", init_method=init_method)
# (N, H/32, W/32, 1)
d_out = global_avg_pooling(tensor_x) # (N, 1)
return d_out
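The critic's output path, channel-wise concatenation followed by global average pooling over a patch score map, can be sketched in NumPy (shapes are illustrative; the random `score_map` stands in for the five stride-2 convolutions):

```python
import numpy as np

# illustrative batch: 3-channel input image and 1-channel target sketch
inputs = np.random.rand(2, 128, 128, 3).astype(np.float32)
targets = np.random.rand(2, 128, 128, 1).astype(np.float32)

x = np.concatenate([inputs, targets], axis=3)   # (2, 128, 128, 4)

# five stride-2 convs reduce 128 -> 4 spatially; emulate with a stand-in map
score_map = np.random.rand(2, 4, 4, 1)
d_out = score_map.mean(axis=(1, 2))             # (2, 1): one critic score per sample

assert x.shape == (2, 128, 128, 4)
assert d_out.shape == (2, 1)
```

Because WGAN-GP needs an unbounded critic score rather than a probability, there is no sigmoid on `d_out`.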
================================================
FILE: test_photograph_to_line.py
================================================
import numpy as np
import os
import tensorflow as tf
from six.moves import range
from PIL import Image
import argparse
import hyper_parameters as hparams
from model_common_test import DiffPastingV3, VirtualSketchingModel
from utils import reset_graph, load_checkpoint, update_hyperparams, draw, \
save_seq_data, image_pasting_v3_testing, draw_strokes
from dataset_utils import load_dataset_testing
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
def sample(sess, model, input_photos, init_cursor, image_size, init_len, seq_len, state_dependent,
pasting_func):
"""Samples a sequence from a pre-trained model."""
select_times = 1
cursor_pos = np.squeeze(init_cursor, axis=0) # (select_times, 1, 2)
curr_canvas = np.zeros(dtype=np.float32,
shape=(select_times, image_size, image_size)) # [0.0-BG, 1.0-stroke]
initial_state = sess.run(model.initial_state)
prev_state = initial_state
prev_width = np.stack([model.hps.min_width for _ in range(select_times)], axis=0)
prev_scaling = np.ones((select_times), dtype=np.float32) # (N)
prev_window_size = np.ones((select_times), dtype=np.float32) * model.hps.raster_size # (N)
params_list = [[] for _ in range(select_times)]
state_raw_list = [[] for _ in range(select_times)]
state_soft_list = [[] for _ in range(select_times)]
window_size_list = [[] for _ in range(select_times)]
input_photos_tiles = np.tile(input_photos, (select_times, 1, 1, 1))
for i in range(seq_len):
if not state_dependent and i % init_len == 0:
prev_state = initial_state
curr_window_size = prev_scaling * prev_window_size # (N)
curr_window_size = np.maximum(curr_window_size, model.hps.min_window_size)
curr_window_size = np.minimum(curr_window_size, image_size)
feed = {
model.initial_state: prev_state,
model.input_photo: input_photos_tiles,
model.curr_canvas_hard: curr_canvas.copy(),
model.cursor_position: cursor_pos,
model.image_size: image_size,
model.init_width: prev_width,
model.init_scaling: prev_scaling,
model.init_window_size: prev_window_size,
}
o_other_params_list, o_pen_list, o_pred_params_list, next_state_list = \
sess.run([model.other_params, model.pen_ras, model.pred_params, model.final_state], feed_dict=feed)
# o_other_params: (N, 6), o_pen: (N, 2), pred_params: (N, 1, 7), next_state: (N, 1024)
# o_other_params: [tanh*2, sigmoid*2, tanh*2, sigmoid*2]
idx_eos_list = np.argmax(o_pen_list, axis=1) # (N)
for output_i in range(idx_eos_list.shape[0]):
idx_eos = idx_eos_list[output_i]
eos = [0, 0]
eos[idx_eos] = 1
other_params = o_other_params_list[output_i].tolist() # (6)
params_list[output_i].append([eos[1]] + other_params)
state_raw_list[output_i].append(o_pen_list[output_i][1])
state_soft_list[output_i].append(o_pred_params_list[output_i, 0, 0])
window_size_list[output_i].append(curr_window_size[output_i])
# draw the stroke and add to the canvas
x1y1, x2y2, width2 = o_other_params_list[output_i, 0:2], o_other_params_list[output_i, 2:4], \
o_other_params_list[output_i, 4]
x0y0 = np.zeros_like(x2y2) # (2), [-1.0, 1.0]
x0y0 = np.divide(np.add(x0y0, 1.0), 2.0) # (2), [0.0, 1.0]
x2y2 = np.divide(np.add(x2y2, 1.0), 2.0) # (2), [0.0, 1.0]
widths = np.stack([prev_width[output_i], width2], axis=0) # (2)
o_other_params_proc = np.concatenate([x0y0, x1y1, x2y2, widths], axis=-1).tolist() # (8)
if idx_eos == 0:
f = o_other_params_proc + [1.0, 1.0]
pred_stroke_img = draw(f) # (raster_size, raster_size), [0.0-stroke, 1.0-BG]
pred_stroke_img_large = image_pasting_v3_testing(1.0 - pred_stroke_img, cursor_pos[output_i, 0],
image_size,
curr_window_size[output_i],
pasting_func, sess) # [0.0-BG, 1.0-stroke]
curr_canvas[output_i] += pred_stroke_img_large # [0.0-BG, 1.0-stroke]
curr_canvas = np.clip(curr_canvas, 0.0, 1.0)
next_width = o_other_params_list[:, 4] # (N)
next_scaling = o_other_params_list[:, 5]
next_window_size = next_scaling * curr_window_size # (N)
next_window_size = np.maximum(next_window_size, model.hps.min_window_size)
next_window_size = np.minimum(next_window_size, image_size)
prev_state = next_state_list
prev_width = next_width * curr_window_size / next_window_size # (N,)
prev_scaling = next_scaling # (N)
prev_window_size = curr_window_size
# update cursor_pos based on hps.cursor_type
new_cursor_offsets = o_other_params_list[:, 2:4] * (np.expand_dims(curr_window_size, axis=-1) / 2.0) # (N, 2), patch-level
new_cursor_offset_next = new_cursor_offsets
# important: the predicted offset is ordered (x, y); swap the components to (y, x) to match row/column image indexing
new_cursor_offset_next = np.concatenate([new_cursor_offset_next[:, 1:2], new_cursor_offset_next[:, 0:1]], axis=-1)
cursor_pos_large = cursor_pos * float(image_size)
stroke_position_next = cursor_pos_large[:, 0, :] + new_cursor_offset_next # (N, 2), large-level
if model.hps.cursor_type == 'next':
cursor_pos_large = stroke_position_next # (N, 2), large-level
else:
raise Exception('Unknown cursor_type')
cursor_pos_large = np.minimum(np.maximum(cursor_pos_large, 0.0), float(image_size - 1)) # (N, 2), large-level
cursor_pos_large = np.expand_dims(cursor_pos_large, axis=1) # (N, 1, 2)
cursor_pos = cursor_pos_large / float(image_size)
return params_list, state_raw_list, state_soft_list, curr_canvas, window_size_list
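The per-step window rescaling and cursor update at the end of the loop above can be reproduced in isolation. A hedged NumPy sketch (the constants and predicted values are made up for illustration):

```python
import numpy as np

image_size, min_window_size = 256, 32    # illustrative constants
prev_window = np.array([128.0])          # previous window size (pixels)
scaling = np.array([0.5])                # predicted scaling factor
offset = np.array([[0.25, -0.5]])        # predicted (x2, y2) in [-1, 1]
cursor = np.array([[0.5, 0.5]])          # normalized cursor position

# rescale the window each step and clamp it to the valid range
window = np.clip(scaling * prev_window, min_window_size, image_size)

# the offset is patch-relative (x, y); swap to (y, x), convert to pixels, clamp
offset_yx = offset[:, ::-1] * (window[:, None] / 2.0)
cursor_px = np.clip(cursor * image_size + offset_yx, 0.0, image_size - 1)
cursor = cursor_px / image_size

assert window[0] == 64.0
assert np.allclose(cursor, [[0.4375, 0.53125]])
```

Clamping the window to `[min_window_size, image_size]` is what keeps the dynamic window from collapsing to a point or growing past the canvas.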
def main_testing(test_image_base_dir, test_dataset, test_image_name,
sampling_base_dir, model_base_dir, model_name,
sampling_num,
draw_seq=False, draw_order=False,
state_dependent=True, longer_infer_len=-1):
model_params_default = hparams.get_default_hparams_normal()
model_params = update_hyperparams(model_params_default, model_base_dir, model_name, infer_dataset=test_dataset)
[test_set, eval_hps_model, sample_hps_model] = \
load_dataset_testing(test_image_base_dir, test_dataset, test_image_name, model_params)
test_image_raw_name = test_image_name[:test_image_name.find('.')]
model_dir = os.path.join(model_base_dir, model_name)
reset_graph()
sampling_model = VirtualSketchingModel(sample_hps_model)
# differentiable pasting graph
paste_v3_func = DiffPastingV3(sample_hps_model.raster_size)
tfconfig = tf.ConfigProto()
tfconfig.gpu_options.allow_growth = True
sess = tf.InteractiveSession(config=tfconfig)
sess.run(tf.global_variables_initializer())
# loads the weights from checkpoint into our model
snapshot_step = load_checkpoint(sess, model_dir, gen_model_pretrain=True)
print('snapshot_step', snapshot_step)
sampling_dir = os.path.join(sampling_base_dir, test_dataset + '__' + model_name)
os.makedirs(sampling_dir, exist_ok=True)
if longer_infer_len == -1:
tmp_max_len = eval_hps_model.max_seq_len
else:
tmp_max_len = longer_infer_len
for sampling_i in range(sampling_num):
input_photos, init_cursors, test_image_size = test_set.get_test_image()
# input_photos: (1, image_size, image_size, 3), [0-stroke, 1-BG]
# init_cursors: (N, 1, 2), in size [0.0, 1.0)
print()
print(test_image_name, ', image_size:', test_image_size, ', sampling_i:', sampling_i)
print('Processing ...')
if init_cursors.ndim == 3:
init_cursors = np.expand_dims(init_cursors, axis=0)
input_photos = input_photos[0:1, :, :, :]
ori_img = (input_photos.copy()[0] * 255.0).astype(np.uint8)
ori_img_png = Image.fromarray(ori_img, 'RGB')
ori_img_png.save(os.path.join(sampling_dir, test_image_raw_name + '_input.png'), 'PNG')
# decoding for sampling
strokes_raw_out_list, states_raw_out_list, states_soft_out_list, pred_imgs_out, window_size_out_list = sample(
sess, sampling_model, input_photos, init_cursors, test_image_size,
eval_hps_model.max_seq_len, tmp_max_len, state_dependent, paste_v3_func)
# pred_imgs_out: (N, H, W), [0.0-BG, 1.0-stroke]
output_i = 0
strokes_raw_out = np.stack(strokes_raw_out_list[output_i], axis=0)
states_raw_out = states_raw_out_list[output_i]
states_soft_out = states_soft_out_list[output_i]
window_size_out = window_size_out_list[output_i]
round_new_lengths = [tmp_max_len]
multi_cursors = [init_cursors[0, output_i, 0, :]]
print('strokes_raw_out', strokes_raw_out.shape)
clean_states_soft_out = np.array(states_soft_out) # (N)
flag_list = strokes_raw_out[:, 0].astype(np.int32) # (N)
drawing_len = len(flag_list) - np.sum(flag_list)
assert drawing_len >= 0
# print(' flag raw\t soft\t x1\t\t y1\t\t x2\t\t y2\t\t r2\t\t s2')
for i in range(strokes_raw_out.shape[0]):
flag, x1, y1, x2, y2, r2, s2 = strokes_raw_out[i]
win_size = window_size_out[i]
out_format = '#%d: %d | %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f'
out_values = (i, flag, states_raw_out[i], clean_states_soft_out[i], x1, y1, x2, y2, r2, s2)
out_log = out_format % out_values
# print(out_log)
print('Saving results ...')
save_seq_data(sampling_dir, test_image_raw_name + '_' + str(sampling_i),
strokes_raw_out, init_cursors[0, output_i, 0, :],
test_image_size, tmp_max_len, eval_hps_model.min_width)
draw_strokes(strokes_raw_out, sampling_dir, test_image_raw_name + '_' + str(sampling_i) + '_pred.png',
ori_img, test_image_size,
multi_cursors, round_new_lengths, eval_hps_model.min_width, eval_hps_model.cursor_type,
sample_hps_model.raster_size, sample_hps_model.min_window_size,
sess,
pasting_func=paste_v3_func,
save_seq=draw_seq, draw_order=draw_order)
def main(model_name, test_image_name, sampling_num):
test_dataset = 'faces'
test_image_base_dir = 'sample_inputs'
sampling_base_dir = 'outputs/sampling'
model_base_dir = 'outputs/snapshot'
state_dependent = False
longer_infer_len = 100
draw_seq = False
draw_color_order = True
# set numpy output to something sensible
np.set_printoptions(precision=8, edgeitems=6, linewidth=200, suppress=True)
main_testing(test_image_base_dir, test_dataset, test_image_name,
sampling_base_dir, model_base_dir, model_name, sampling_num,
draw_seq=draw_seq, draw_order=draw_color_order,
state_dependent=state_dependent, longer_infer_len=longer_infer_len)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--input', '-i', type=str, default='', help="The test image name.")
parser.add_argument('--model', '-m', type=str, default='pretrain_faces', help="The trained model.")
parser.add_argument('--sample', '-s', type=int, default=1, help="The number of outputs.")
args = parser.parse_args()
assert args.input != ''
assert args.sample > 0
main(args.model, args.input, args.sample)
================================================
FILE: test_rough_sketch_simplification.py
================================================
import numpy as np
import os
import tensorflow as tf
from six.moves import range
from PIL import Image
import argparse
import hyper_parameters as hparams
from model_common_test import DiffPastingV3, VirtualSketchingModel
from utils import reset_graph, load_checkpoint, update_hyperparams, draw, \
save_seq_data, image_pasting_v3_testing, draw_strokes
from dataset_utils import load_dataset_testing
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
def move_cursor_to_undrawn(current_pos_list, input_image_, patch_size,
move_min_dist, move_max_dist, trial_times=20):
"""
:param current_pos_list: (select_times, 1, 2), [0.0, 1.0)
:param input_image_: (1, image_size, image_size, 3), [0-stroke, 1-BG]
:return: new_cursor_pos: (select_times, 1, 2), [0.0, 1.0)
"""
def crop_patch(image, center, image_size, crop_size):
x0 = center[0] - crop_size // 2
x1 = x0 + crop_size
y0 = center[1] - crop_size // 2
y1 = y0 + crop_size
x0 = max(0, min(x0, image_size))
y0 = max(0, min(y0, image_size))
x1 = max(0, min(x1, image_size))
y1 = max(0, min(y1, image_size))
patch = image[y0:y1, x0:x1]
return patch
def isvalid_cursor(input_img, cursor, raster_size, image_size):
# input_img: (image_size, image_size, 3), [0.0-BG, 1.0-stroke]
cursor_large = cursor * float(image_size)
cursor_large = np.round(cursor_large).astype(np.int32)
input_crop_patch = crop_patch(input_img, cursor_large, image_size, raster_size)
if np.sum(input_crop_patch) > 0.0:
return True
else:
return False
def randomly_move_cursor(cursor_position, img_size, min_dist_p, max_dist_p):
# cursor_position: (2), [0.0, 1.0)
cursor_pos_large = cursor_position * img_size
min_dist = int(min_dist_p / 2.0 * img_size)
max_dist = int(max_dist_p / 2.0 * img_size)
rand_cursor_offset = np.random.randint(min_dist, max_dist, size=cursor_pos_large.shape)
rand_cursor_offset_sign = np.random.randint(0, 1 + 1, size=cursor_pos_large.shape)
rand_cursor_offset_sign[rand_cursor_offset_sign == 0] = -1
rand_cursor_offset = rand_cursor_offset * rand_cursor_offset_sign
new_cursor_pos_large = cursor_pos_large + rand_cursor_offset
new_cursor_pos_large = np.minimum(np.maximum(new_cursor_pos_large, 0), img_size - 1) # (2), large-level
new_cursor_pos = new_cursor_pos_large.astype(np.float32) / float(img_size)
return new_cursor_pos
input_image = 1.0 - input_image_[0] # (image_size, image_size, 3), [0-BG, 1-stroke]
img_size = input_image.shape[0]
new_cursor_pos = []
for cursor_i in range(current_pos_list.shape[0]):
curr_cursor = current_pos_list[cursor_i][0]
for trial_i in range(trial_times):
new_cursor = randomly_move_cursor(curr_cursor, img_size, move_min_dist, move_max_dist) # (2), [0.0, 1.0)
if isvalid_cursor(input_image, new_cursor, patch_size, img_size) or trial_i == trial_times - 1:
new_cursor_pos.append(new_cursor)
break
assert len(new_cursor_pos) == current_pos_list.shape[0]
new_cursor_pos = np.expand_dims(np.stack(new_cursor_pos, axis=0), axis=1) # (select_times, 1, 2), [0.0, 1.0)
return new_cursor_pos
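The random cursor jump above can be sketched in isolation. This is a minimal, self-contained illustration (the function name and default sizes are made up for the demo): each axis gets an offset magnitude drawn from `[min_dist_p/2, max_dist_p/2)` of the image size with a random sign, and the result is clamped to the canvas and renormalized to `[0.0, 1.0)`.

```python
import numpy as np

def toy_random_move(cursor, img_size=256, min_dist_p=0.3, max_dist_p=0.9, seed=0):
    """Illustrative sketch of randomly_move_cursor (names are hypothetical)."""
    rng = np.random.RandomState(seed)
    pos = np.asarray(cursor, dtype=np.float32) * img_size
    min_d = int(min_dist_p / 2.0 * img_size)
    max_d = int(max_dist_p / 2.0 * img_size)
    offset = rng.randint(min_d, max_d, size=pos.shape)   # per-axis magnitude
    sign = rng.randint(0, 2, size=pos.shape) * 2 - 1     # -1 or +1 per axis
    new_pos = np.clip(pos + offset * sign, 0, img_size - 1)
    return new_pos / float(img_size)

new_cursor = toy_random_move([0.5, 0.5])
assert new_cursor.shape == (2,)
assert np.all(new_cursor >= 0.0) and np.all(new_cursor < 1.0)
```

In the real function the jump is retried up to `trial_times` until the patch under the new cursor contains undrawn stroke pixels.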
def sample(sess, model, input_photos, init_cursor, image_size, init_len, seq_lens,
state_dependent, pasting_func, round_stop_state_num,
min_dist_p, max_dist_p):
"""Samples a sequence from a pre-trained model."""
select_times = 1
curr_canvas = np.zeros(dtype=np.float32,
shape=(select_times, image_size, image_size)) # [0.0-BG, 1.0-stroke]
initial_state = sess.run(model.initial_state)
params_list = [[] for _ in range(select_times)]
state_raw_list = [[] for _ in range(select_times)]
state_soft_list = [[] for _ in range(select_times)]
window_size_list = [[] for _ in range(select_times)]
round_cursor_list = []
round_length_real_list = []
input_photos_tiles = np.tile(input_photos, (select_times, 1, 1, 1))
for cursor_i, seq_len in enumerate(seq_lens):
if cursor_i == 0:
cursor_pos = np.squeeze(init_cursor, axis=0) # (select_times, 1, 2)
else:
cursor_pos = move_cursor_to_undrawn(cursor_pos, input_photos, model.hps.raster_size,
min_dist_p, max_dist_p) # (select_times, 1, 2)
round_cursor_list.append(cursor_pos)
prev_state = initial_state
prev_width = np.stack([model.hps.min_width for _ in range(select_times)], axis=0)
prev_scaling = np.ones((select_times), dtype=np.float32) # (N)
prev_window_size = np.ones((select_times), dtype=np.float32) * model.hps.raster_size # (N)
continuous_one_state_num = 0
for i in range(seq_len):
if not state_dependent and i % init_len == 0:
prev_state = initial_state
curr_window_size = prev_scaling * prev_window_size # (N)
curr_window_size = np.maximum(curr_window_size, model.hps.min_window_size)
curr_window_size = np.minimum(curr_window_size, image_size)
feed = {
model.initial_state: prev_state,
model.input_photo: input_photos_tiles,
model.curr_canvas_hard: curr_canvas.copy(),
model.cursor_position: cursor_pos,
model.image_size: image_size,
model.init_width: prev_width,
model.init_scaling: prev_scaling,
model.init_window_size: prev_window_size,
}
o_other_params_list, o_pen_list, o_pred_params_list, next_state_list = \
sess.run([model.other_params, model.pen_ras, model.pred_params, model.final_state], feed_dict=feed)
# o_other_params: (N, 6), o_pen: (N, 2), pred_params: (N, 1, 7), next_state: (N, 1024)
# o_other_params: [tanh*2, sigmoid*2, tanh*2, sigmoid*2]
idx_eos_list = np.argmax(o_pen_list, axis=1) # (N)
output_i = 0
idx_eos = idx_eos_list[output_i]
eos = [0, 0]
eos[idx_eos] = 1
other_params = o_other_params_list[output_i].tolist() # (6)
params_list[output_i].append([eos[1]] + other_params)
state_raw_list[output_i].append(o_pen_list[output_i][1])
state_soft_list[output_i].append(o_pred_params_list[output_i, 0, 0])
window_size_list[output_i].append(curr_window_size[output_i])
# draw the stroke and add to the canvas
x1y1, x2y2, width2 = o_other_params_list[output_i, 0:2], o_other_params_list[output_i, 2:4], \
o_other_params_list[output_i, 4]
x0y0 = np.zeros_like(x2y2) # (2), [-1.0, 1.0]
x0y0 = np.divide(np.add(x0y0, 1.0), 2.0) # (2), [0.0, 1.0]
x2y2 = np.divide(np.add(x2y2, 1.0), 2.0) # (2), [0.0, 1.0]
widths = np.stack([prev_width[output_i], width2], axis=0) # (2)
o_other_params_proc = np.concatenate([x0y0, x1y1, x2y2, widths], axis=-1).tolist() # (8)
if idx_eos == 0:
f = o_other_params_proc + [1.0, 1.0]
pred_stroke_img = draw(f) # (raster_size, raster_size), [0.0-stroke, 1.0-BG]
pred_stroke_img_large = image_pasting_v3_testing(1.0 - pred_stroke_img,
cursor_pos[output_i, 0],
image_size,
curr_window_size[output_i],
pasting_func, sess) # [0.0-BG, 1.0-stroke]
curr_canvas[output_i] += pred_stroke_img_large # [0.0-BG, 1.0-stroke]
continuous_one_state_num = 0
else:
continuous_one_state_num += 1
curr_canvas = np.clip(curr_canvas, 0.0, 1.0)
next_width = o_other_params_list[:, 4] # (N)
next_scaling = o_other_params_list[:, 5]
next_window_size = next_scaling * curr_window_size # (N)
next_window_size = np.maximum(next_window_size, model.hps.min_window_size)
next_window_size = np.minimum(next_window_size, image_size)
prev_state = next_state_list
prev_width = next_width * curr_window_size / next_window_size # (N,)
prev_scaling = next_scaling # (N)
prev_window_size = curr_window_size
# update cursor_pos based on hps.cursor_type
new_cursor_offsets = o_other_params_list[:, 2:4] * (
np.expand_dims(curr_window_size, axis=-1) / 2.0) # (N, 2), patch-level
new_cursor_offset_next = new_cursor_offsets
# important: swap the two offset coordinates (the predicted offsets and the stored cursor use opposite axis orders)
new_cursor_offset_next = np.concatenate([new_cursor_offset_next[:, 1:2], new_cursor_offset_next[:, 0:1]],
axis=-1)
cursor_pos_large = cursor_pos * float(image_size)
stroke_position_next = cursor_pos_large[:, 0, :] + new_cursor_offset_next # (N, 2), large-level
if model.hps.cursor_type == 'next':
cursor_pos_large = stroke_position_next # (N, 2), large-level
else:
raise Exception('Unknown cursor_type')
cursor_pos_large = np.minimum(np.maximum(cursor_pos_large, 0.0),
float(image_size - 1)) # (N, 2), large-level
cursor_pos_large = np.expand_dims(cursor_pos_large, axis=1) # (N, 1, 2)
cursor_pos = cursor_pos_large / float(image_size)
if continuous_one_state_num >= round_stop_state_num or i == seq_len - 1:
round_length_real_list.append(i + 1)
break
return params_list, state_raw_list, state_soft_list, curr_canvas, window_size_list, \
round_cursor_list, round_length_real_list
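The window-size recurrence inside `sample` can be traced with made-up numbers. This sketch shows the three clamped updates per step: the current window is the previous window rescaled by the previous `scaling` output, the next window rescales again by the newly predicted `scaling`, and the stroke width is renormalized so it keeps the same size in image space across the rescale (all values below are hypothetical).

```python
import numpy as np

# hypothetical sizes and network outputs for one decoding step
image_size, min_window_size = 256.0, 32.0
prev_scaling, prev_window = 1.0, 128.0
width2, scaling2 = 0.5, 1.5

curr_window = np.clip(prev_scaling * prev_window, min_window_size, image_size)
next_window = np.clip(scaling2 * curr_window, min_window_size, image_size)
prev_width = width2 * curr_window / next_window   # width carried to the next step
assert curr_window == 128.0
assert next_window == 192.0
assert abs(prev_width - 0.5 * 128.0 / 192.0) < 1e-9
```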
def main_testing(test_image_base_dir, test_dataset, test_image_name,
sampling_base_dir, model_base_dir, model_name,
sampling_num,
min_dist_p, max_dist_p,
longer_infer_lens, round_stop_state_num,
draw_seq=False, draw_order=False,
state_dependent=True):
model_params_default = hparams.get_default_hparams_rough()
model_params = update_hyperparams(model_params_default, model_base_dir, model_name, infer_dataset=test_dataset)
[test_set, eval_hps_model, sample_hps_model] = \
load_dataset_testing(test_image_base_dir, test_dataset, test_image_name, model_params)
test_image_raw_name = test_image_name[:test_image_name.find('.')]
model_dir = os.path.join(model_base_dir, model_name)
reset_graph()
sampling_model = VirtualSketchingModel(sample_hps_model)
# differentiable pasting graph
paste_v3_func = DiffPastingV3(sample_hps_model.raster_size)
tfconfig = tf.ConfigProto()
tfconfig.gpu_options.allow_growth = True
sess = tf.InteractiveSession(config=tfconfig)
sess.run(tf.global_variables_initializer())
# loads the weights from checkpoint into our model
snapshot_step = load_checkpoint(sess, model_dir, gen_model_pretrain=True)
print('snapshot_step', snapshot_step)
sampling_dir = os.path.join(sampling_base_dir, test_dataset + '__' + model_name)
os.makedirs(sampling_dir, exist_ok=True)
for sampling_i in range(sampling_num):
input_photos, init_cursors, test_image_size = test_set.get_test_image()
# input_photos: (1, image_size, image_size, 3), [0-stroke, 1-BG]
# init_cursors: (N, 1, 2), in size [0.0, 1.0)
print()
print(test_image_name, ', image_size:', test_image_size, ', sampling_i:', sampling_i)
print('Processing ...')
if init_cursors.ndim == 3:
init_cursors = np.expand_dims(init_cursors, axis=0)
input_photos = input_photos[0:1, :, :, :]
ori_img = (input_photos.copy()[0] * 255.0).astype(np.uint8)
ori_img_png = Image.fromarray(ori_img, 'RGB')
ori_img_png.save(os.path.join(sampling_dir, test_image_raw_name + '_input.png'), 'PNG')
# decoding for sampling
strokes_raw_out_list, states_raw_out_list, states_soft_out_list, pred_imgs_out, \
window_size_out_list, round_new_cursors, round_new_lengths = sample(
sess, sampling_model, input_photos, init_cursors, test_image_size,
eval_hps_model.max_seq_len, longer_infer_lens, state_dependent, paste_v3_func,
round_stop_state_num, min_dist_p, max_dist_p)
# pred_imgs_out: (N, H, W), [0.0-BG, 1.0-stroke]
print('## round_lengths:', len(round_new_lengths), ':', round_new_lengths)
output_i = 0
strokes_raw_out = np.stack(strokes_raw_out_list[output_i], axis=0)
states_raw_out = states_raw_out_list[output_i]
states_soft_out = states_soft_out_list[output_i]
window_size_out = window_size_out_list[output_i]
multi_cursors = [init_cursors[0, output_i, 0]]
for c_i in range(len(round_new_cursors)):
best_cursor = round_new_cursors[c_i][output_i, 0] # (2)
multi_cursors.append(best_cursor)
assert len(multi_cursors) == len(round_new_lengths)
print('strokes_raw_out', strokes_raw_out.shape)
clean_states_soft_out = np.array(states_soft_out) # (N)
flag_list = strokes_raw_out[:, 0].astype(np.int32) # (N)
drawing_len = len(flag_list) - np.sum(flag_list)
assert drawing_len >= 0
# print(' flag raw\t soft\t x1\t\t y1\t\t x2\t\t y2\t\t r2\t\t s2')
for i in range(strokes_raw_out.shape[0]):
flag, x1, y1, x2, y2, r2, s2 = strokes_raw_out[i]
win_size = window_size_out[i]
out_format = '#%d: %d | %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f'
out_values = (i, flag, states_raw_out[i], clean_states_soft_out[i], x1, y1, x2, y2, r2, s2)
out_log = out_format % out_values
# print(out_log)
print('Saving results ...')
save_seq_data(sampling_dir, test_image_raw_name + '_' + str(sampling_i),
strokes_raw_out, multi_cursors,
test_image_size, round_new_lengths, eval_hps_model.min_width)
draw_strokes(strokes_raw_out, sampling_dir, test_image_raw_name + '_' + str(sampling_i) + '_pred.png',
ori_img, test_image_size,
multi_cursors, round_new_lengths, eval_hps_model.min_width, eval_hps_model.cursor_type,
sample_hps_model.raster_size, sample_hps_model.min_window_size,
sess,
pasting_func=paste_v3_func,
save_seq=draw_seq, draw_order=draw_order)
def main(model_name, test_image_name, sampling_num):
test_dataset = 'rough_sketches'
test_image_base_dir = 'sample_inputs'
sampling_base_dir = 'outputs/sampling'
model_base_dir = 'outputs/snapshot'
state_dependent = False
longer_infer_lens = [128 for _ in range(10)]
round_stop_state_num = 12
min_dist_p = 0.3
max_dist_p = 0.9
draw_seq = False
draw_color_order = True
# set numpy output to something sensible
np.set_printoptions(precision=8, edgeitems=6, linewidth=200, suppress=True)
main_testing(test_image_base_dir, test_dataset, test_image_name,
sampling_base_dir, model_base_dir, model_name, sampling_num,
min_dist_p=min_dist_p, max_dist_p=max_dist_p,
draw_seq=draw_seq, draw_order=draw_color_order,
state_dependent=state_dependent, longer_infer_lens=longer_infer_lens,
round_stop_state_num=round_stop_state_num)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--input', '-i', type=str, default='', help="The test image name.")
parser.add_argument('--model', '-m', type=str, default='pretrain_rough_sketches', help="The trained model.")
parser.add_argument('--sample', '-s', type=int, default=1, help="The number of outputs.")
args = parser.parse_args()
assert args.input != ''
assert args.sample > 0
main(args.model, args.input, args.sample)
================================================
FILE: test_vectorization.py
================================================
import numpy as np
import random
import os
import tensorflow as tf
from six.moves import range
from PIL import Image
import time
import argparse
import hyper_parameters as hparams
from model_common_test import DiffPastingV3, VirtualSketchingModel
from utils import reset_graph, load_checkpoint, update_hyperparams, draw, \
save_seq_data, image_pasting_v3_testing, draw_strokes
from dataset_utils import load_dataset_testing
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
def move_cursor_to_undrawn(current_canvas_list, input_image_, last_min_acc_list, grid_patch_size=128,
stroke_acc_threshold=0.95, stroke_num_threshold=5, continuous_min_acc_threshold=2):
"""
:param current_canvas_list: (select_times, image_size, image_size), [0.0-BG, 1.0-stroke]
:param input_image_: (1, image_size, image_size), [0-stroke, 1-BG]
:return: new_cursor_pos: (select_times, 1, 2), [0.0, 1.0)
"""
def split_images(in_img, image_size, grid_size):
if image_size % grid_size == 0:
paddings_ = 0
else:
paddings_ = grid_size - image_size % grid_size
paddings = [[0, paddings_],
[0, paddings_]]
image_pad = np.pad(in_img, paddings, mode='constant', constant_values=0.0) # (H_p, W_p), [0.0-BG, 1.0-stroke]
assert image_pad.shape[0] % grid_size == 0
split_num = image_pad.shape[0] // grid_size
images_h = np.hsplit(image_pad, split_num)
image_patches = []
for image_h in images_h:
images_v = np.vsplit(image_h, split_num)
image_patches += images_v
image_patches = np.array(image_patches, dtype=np.float32)
return image_patches, split_num
def line_drawing_rounding(line_drawing):
line_drawing_r = np.copy(line_drawing) # [0.0-BG, 1.0-stroke]
line_drawing_r[line_drawing_r != 0.0] = 1.0
return line_drawing_r
def cal_undrawn_pixels(in_canvas, in_sketch):
in_canvas_round = line_drawing_rounding(in_canvas).astype(np.int32) # (N, H, W), [0.0-BG, 1.0-stroke]
in_sketch_round = line_drawing_rounding(in_sketch).astype(np.int32)
intersection = np.bitwise_and(in_canvas_round, in_sketch_round)
intersection_sum = np.sum(intersection, axis=(1, 2))
gt_sum = np.sum(in_sketch_round, axis=(1, 2)) # (N)
undrawn_num = gt_sum - intersection_sum
return undrawn_num
def cal_stroke_acc(in_canvas, in_sketch):
in_canvas_round = line_drawing_rounding(in_canvas).astype(np.int32) # (N, H, W), [0.0-BG, 1.0-stroke]
in_sketch_round = line_drawing_rounding(in_sketch).astype(np.int32)
intersection = np.bitwise_and(in_canvas_round, in_sketch_round)
intersection_sum = np.sum(intersection, axis=(1, 2)).astype(np.float32)
gt_sum = np.sum(in_sketch_round, axis=(1, 2)).astype(np.float32) # (N)
undrawn_num = gt_sum - intersection_sum # (N)
stroke_acc = intersection_sum / gt_sum # (N)
stroke_acc[gt_sum == 0.0] = 1.0
stroke_acc[undrawn_num <= stroke_num_threshold] = 1.0
return stroke_acc
def get_cursor(patch_idx, img_size, grid_size, split_num):
y_pos = patch_idx % split_num
x_pos = patch_idx // split_num
y_top = y_pos * grid_size + grid_size // 4
y_bottom = y_top + grid_size // 2
x_left = x_pos * grid_size + grid_size // 4
x_right = x_left + grid_size // 2
cursor_y = random.randint(y_top, y_bottom)
cursor_x = random.randint(x_left, x_right)
cursor_y = max(0, min(cursor_y, img_size - 1))
cursor_x = max(0, min(cursor_x, img_size - 1)) # (2), in large size
center = np.array([cursor_x, cursor_y], dtype=np.float32)
return center / float(img_size) # (2), in size [0.0, 1.0)
input_image = 1.0 - input_image_[0] # (image_size, image_size), [0-BG, 1-stroke]
img_size = input_image.shape[0]
input_image_patches, split_number = split_images(input_image, img_size, grid_patch_size) # (N, grid_size, grid_size)
new_cursor_pos = []
last_min_acc_list_new = [item for item in last_min_acc_list]
for canvas_i in range(current_canvas_list.shape[0]):
curr_canvas = current_canvas_list[canvas_i] # (image_size, image_size), [0.0-BG, 1.0-stroke]
curr_canvas_patches, _ = split_images(curr_canvas, img_size, grid_patch_size) # (N, grid_size, grid_size)
# 1. detect ending flag by stroke accuracy
stroke_accuracy = cal_stroke_acc(curr_canvas_patches, input_image_patches)
min_acc_idx = np.argmin(stroke_accuracy)
min_acc = stroke_accuracy[min_acc_idx]
# print('min_acc_idx', min_acc_idx, ' | ', 'min_acc', min_acc)
if min_acc >= stroke_acc_threshold: # end of drawing
return None, None
# 2. detect undrawn pixels
undrawn_pixel_num = cal_undrawn_pixels(curr_canvas_patches, input_image_patches)
# undrawn_pixel_num_dis = np.reshape(undrawn_pixel_num, (split_number, split_number)).T
# print('undrawn_pixel_num_dis')
# print(undrawn_pixel_num_dis)
max_undrawn_idx = np.argmax(undrawn_pixel_num)
# max_undrawn = undrawn_pixel_num[max_undrawn_idx]
# print('max_undrawn_idx', max_undrawn_idx, ' | ', 'max_undrawn', max_undrawn)
# 3. select a random position
last_min_acc_idx, last_min_acc_times = last_min_acc_list[canvas_i]
if last_min_acc_times >= continuous_min_acc_threshold and last_min_acc_idx == min_acc_idx:
selected_patch_idx = last_min_acc_idx
new_min_acc_times = 1
else:
selected_patch_idx = max_undrawn_idx
if min_acc_idx == last_min_acc_idx:
new_min_acc_times = last_min_acc_times + 1
else:
new_min_acc_times = 1
new_min_acc_idx = min_acc_idx
last_min_acc_list_new[canvas_i] = (new_min_acc_idx, new_min_acc_times)
# print('selected_patch_idx', selected_patch_idx)
# 4. get cursor according to the selected_patch_idx
rand_cursor = get_cursor(selected_patch_idx, img_size, grid_patch_size, split_number) # (2), in size [0.0, 1.0)
new_cursor_pos.append(rand_cursor)
assert len(new_cursor_pos) == current_canvas_list.shape[0]
new_cursor_pos = np.expand_dims(np.stack(new_cursor_pos, axis=0), axis=1) # (select_times, 1, 2), [0.0, 1.0)
return new_cursor_pos, last_min_acc_list_new
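The patch-based cursor selection above can be illustrated on a toy canvas (sizes and the helper below are made up for the demo): binarize the target sketch and the current canvas, count per-patch stroke pixels that the canvas has not covered yet, and jump into the patch with the most missing ink.

```python
import numpy as np

grid = 4
target = np.zeros((8, 8), dtype=np.float32)
target[1, :] = 1.0               # a horizontal stroke near the top
canvas = np.zeros_like(target)   # nothing drawn yet

def patches(img, grid):
    # column-major patch order, matching the hsplit-then-vsplit layout above
    n = img.shape[0] // grid
    return np.array([img[y * grid:(y + 1) * grid, x * grid:(x + 1) * grid]
                     for x in range(n) for y in range(n)]).astype(np.int32)

t, c = patches(target, grid), patches(canvas, grid)
undrawn = np.sum(t, axis=(1, 2)) - np.sum(np.bitwise_and(t, c), axis=(1, 2))
assert undrawn.tolist() == [4, 0, 4, 0]   # patches 0 and 2 each miss 4 pixels
assert int(np.argmax(undrawn)) == 0       # the cursor jumps into patch 0
```

The real function additionally tracks a per-canvas `(min_acc_idx, times)` pair so that a patch whose stroke accuracy stays lowest for several rounds gets revisited directly.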
def sample(sess, model, input_photos, init_cursor, image_size, init_len, seq_lens, state_dependent,
pasting_func, round_stop_state_num, stroke_acc_threshold):
"""Samples a sequence from a pre-trained model."""
select_times = 1
curr_canvas = np.zeros(dtype=np.float32,
shape=(select_times, image_size, image_size)) # [0.0-BG, 1.0-stroke]
initial_state = sess.run(model.initial_state)
prev_width = np.stack([model.hps.min_width for _ in range(select_times)], axis=0)
params_list = [[] for _ in range(select_times)]
state_raw_list = [[] for _ in range(select_times)]
state_soft_list = [[] for _ in range(select_times)]
window_size_list = [[] for _ in range(select_times)]
last_min_stroke_acc_list = [(-1, 0) for _ in range(select_times)]
round_cursor_list = []
round_length_real_list = []
input_photos_tiles = np.tile(input_photos, (select_times, 1, 1))
for cursor_i, seq_len in enumerate(seq_lens):
# print('\n')
# print('@@ Round', cursor_i + 1)
if cursor_i == 0:
cursor_pos = np.squeeze(init_cursor, axis=0) # (select_times, 1, 2)
else:
cursor_pos, last_min_stroke_acc_list_updated = \
move_cursor_to_undrawn(curr_canvas, input_photos, last_min_stroke_acc_list,
grid_patch_size=model.hps.raster_size,
stroke_acc_threshold=stroke_acc_threshold) # (select_times, 1, 2)
if cursor_pos is not None:
round_cursor_list.append(cursor_pos)
last_min_stroke_acc_list = last_min_stroke_acc_list_updated
else:
break
prev_state = initial_state
if not model.hps.init_cursor_on_undrawn_pixel:
prev_width = np.stack([model.hps.min_width for _ in range(select_times)], axis=0)
prev_scaling = np.ones((select_times), dtype=np.float32) # (N)
prev_window_size = np.ones((select_times), dtype=np.float32) * model.hps.raster_size # (N)
continuous_one_state_num = 0
for i in range(seq_len):
if not state_dependent and i % init_len == 0:
prev_state = initial_state
curr_window_size = prev_scaling * prev_window_size # (N)
curr_window_size = np.maximum(curr_window_size, model.hps.min_window_size)
curr_window_size = np.minimum(curr_window_size, image_size)
feed = {
model.initial_state: prev_state,
model.input_photo: np.expand_dims(input_photos_tiles, axis=-1),
model.curr_canvas_hard: curr_canvas.copy(),
model.cursor_position: cursor_pos,
model.image_size: image_size,
model.init_width: prev_width,
model.init_scaling: prev_scaling,
model.init_window_size: prev_window_size,
}
o_other_params_list, o_pen_list, o_pred_params_list, next_state_list = \
sess.run([model.other_params, model.pen_ras, model.pred_params, model.final_state], feed_dict=feed)
# o_other_params: (N, 6), o_pen: (N, 2), pred_params: (N, 1, 7), next_state: (N, 1024)
# o_other_params: [tanh*2, sigmoid*2, tanh*2, sigmoid*2]
idx_eos_list = np.argmax(o_pen_list, axis=1) # (N)
output_i = 0
idx_eos = idx_eos_list[output_i]
eos = [0, 0]
eos[idx_eos] = 1
other_params = o_other_params_list[output_i].tolist() # (6)
params_list[output_i].append([eos[1]] + other_params)
state_raw_list[output_i].append(o_pen_list[output_i][1])
state_soft_list[output_i].append(o_pred_params_list[output_i, 0, 0])
window_size_list[output_i].append(curr_window_size[output_i])
# draw the stroke and add to the canvas
x1y1, x2y2, width2 = o_other_params_list[output_i, 0:2], o_other_params_list[output_i, 2:4], \
o_other_params_list[output_i, 4]
x0y0 = np.zeros_like(x2y2) # (2), [-1.0, 1.0]
x0y0 = np.divide(np.add(x0y0, 1.0), 2.0) # (2), [0.0, 1.0]
x2y2 = np.divide(np.add(x2y2, 1.0), 2.0) # (2), [0.0, 1.0]
widths = np.stack([prev_width[output_i], width2], axis=0) # (2)
o_other_params_proc = np.concatenate([x0y0, x1y1, x2y2, widths], axis=-1).tolist() # (8)
if idx_eos == 0:
f = o_other_params_proc + [1.0, 1.0]
pred_stroke_img = draw(f) # (raster_size, raster_size), [0.0-stroke, 1.0-BG]
pred_stroke_img_large = image_pasting_v3_testing(1.0 - pred_stroke_img, cursor_pos[output_i, 0],
image_size,
curr_window_size[output_i],
pasting_func, sess) # [0.0-BG, 1.0-stroke]
curr_canvas[output_i] += pred_stroke_img_large # [0.0-BG, 1.0-stroke]
continuous_one_state_num = 0
else:
continuous_one_state_num += 1
curr_canvas = np.clip(curr_canvas, 0.0, 1.0)
next_width = o_other_params_list[:, 4] # (N)
next_scaling = o_other_params_list[:, 5]
next_window_size = next_scaling * curr_window_size # (N)
next_window_size = np.maximum(next_window_size, model.hps.min_window_size)
next_window_size = np.minimum(next_window_size, image_size)
prev_state = next_state_list
prev_width = next_width * curr_window_size / next_window_size # (N,)
prev_scaling = next_scaling # (N)
prev_window_size = curr_window_size
# update cursor_pos based on hps.cursor_type
new_cursor_offsets = o_other_params_list[:, 2:4] * (np.expand_dims(curr_window_size, axis=-1) / 2.0) # (N, 2), patch-level
new_cursor_offset_next = new_cursor_offsets
# important: swap the two offset coordinates (the predicted offsets and the stored cursor use opposite axis orders)
new_cursor_offset_next = np.concatenate([new_cursor_offset_next[:, 1:2], new_cursor_offset_next[:, 0:1]], axis=-1)
cursor_pos_large = cursor_pos * float(image_size)
stroke_position_next = cursor_pos_large[:, 0, :] + new_cursor_offset_next # (N, 2), large-level
if model.hps.cursor_type == 'next':
cursor_pos_large = stroke_position_next # (N, 2), large-level
else:
raise Exception('Unknown cursor_type')
cursor_pos_large = np.minimum(np.maximum(cursor_pos_large, 0.0), float(image_size - 1)) # (N, 2), large-level
cursor_pos_large = np.expand_dims(cursor_pos_large, axis=1) # (N, 1, 2)
cursor_pos = cursor_pos_large / float(image_size)
if continuous_one_state_num >= round_stop_state_num or i == seq_len - 1:
round_length_real_list.append(i + 1)
break
return params_list, state_raw_list, state_soft_list, curr_canvas, window_size_list, \
round_cursor_list, round_length_real_list
def main_testing(test_image_base_dir, test_dataset, test_image_name,
sampling_base_dir, model_base_dir, model_name,
sampling_num,
longer_infer_lens,
round_stop_state_num, stroke_acc_threshold,
draw_seq=False, draw_order=False,
state_dependent=True):
model_params_default = hparams.get_default_hparams_clean()
model_params = update_hyperparams(model_params_default, model_base_dir, model_name, infer_dataset=test_dataset)
[test_set, eval_hps_model, sample_hps_model] \
= load_dataset_testing(test_image_base_dir, test_dataset, test_image_name, model_params)
test_image_raw_name = test_image_name[:test_image_name.find('.')]
model_dir = os.path.join(model_base_dir, model_name)
reset_graph()
sampling_model = VirtualSketchingModel(sample_hps_model)
# differentiable pasting graph
paste_v3_func = DiffPastingV3(sample_hps_model.raster_size)
tfconfig = tf.ConfigProto()
tfconfig.gpu_options.allow_growth = True
sess = tf.InteractiveSession(config=tfconfig)
sess.run(tf.global_variables_initializer())
# loads the weights from checkpoint into our model
snapshot_step = load_checkpoint(sess, model_dir, gen_model_pretrain=True)
print('snapshot_step', snapshot_step)
sampling_dir = os.path.join(sampling_base_dir, test_dataset + '__' + model_name)
os.makedirs(sampling_dir, exist_ok=True)
stroke_number_list = []
compute_time_list = []
for sampling_i in range(sampling_num):
start_time_point = time.time()
input_photos, init_cursors, test_image_size = test_set.get_test_image()
# input_photos: (1, image_size, image_size), [0-stroke, 1-BG]
# init_cursors: (1, 1, 2), in size [0.0, 1.0)
print()
print(test_image_name, ', image_size:', test_image_size, ', sampling_i:', sampling_i)
print('Processing ...')
if init_cursors.ndim == 3:
init_cursors = np.expand_dims(init_cursors, axis=0)
input_photos = input_photos[0:1, :, :]
ori_img = (input_photos.copy()[0] * 255.0).astype(np.uint8)
ori_img = np.stack([ori_img for _ in range(3)], axis=2)
ori_img_png = Image.fromarray(ori_img, 'RGB')
ori_img_png.save(os.path.join(sampling_dir, test_image_raw_name + '_input.png'), 'PNG')
data_loading_time_point = time.time()
# decoding for sampling
strokes_raw_out_list, states_raw_out_list, states_soft_out_list, pred_imgs_out, \
window_size_out_list, round_new_cursors, round_new_lengths = sample(
sess, sampling_model, input_photos, init_cursors, test_image_size,
eval_hps_model.max_seq_len, longer_infer_lens, state_dependent,
paste_v3_func, round_stop_state_num, stroke_acc_threshold)
# pred_imgs_out: [0.0-BG, 1.0-stroke]
print('## round_lengths:', len(round_new_lengths), ':', round_new_lengths)
sampling_time_point = time.time()
data_loading_time = data_loading_time_point - start_time_point
sampling_time_total = sampling_time_point - start_time_point
sampling_time_wo_data_loading = sampling_time_point - data_loading_time_point
compute_time_list.append(sampling_time_total)
# print(' >>> data_loading_time', data_loading_time)
print(' >>> sampling_time_total', sampling_time_total)
# print(' >>> sampling_time_wo_data_loading', sampling_time_wo_data_loading)
best_result_idx = 0
strokes_raw_out = np.stack(strokes_raw_out_list[best_result_idx], axis=0)
states_raw_out = states_raw_out_list[best_result_idx]
states_soft_out = states_soft_out_list[best_result_idx]
window_size_out = window_size_out_list[best_result_idx]
multi_cursors = [init_cursors[0, best_result_idx, 0]]
for c_i in range(len(round_new_cursors)):
best_cursor = round_new_cursors[c_i][best_result_idx, 0] # (2)
multi_cursors.append(best_cursor)
assert len(multi_cursors) == len(round_new_lengths)
print('strokes_raw_out', strokes_raw_out.shape)
stroke_number_list.append(strokes_raw_out.shape[0])
clean_states_soft_out = np.array(states_soft_out) # (N)
flag_list = strokes_raw_out[:, 0].astype(np.int32) # (N)
drawing_len = len(flag_list) - np.sum(flag_list)
assert drawing_len >= 0
# print(' flag raw\t soft\t x1\t\t y1\t\t x2\t\t y2\t\t r2\t\t s2')
for i in range(strokes_raw_out.shape[0]):
flag, x1, y1, x2, y2, r2, s2 = strokes_raw_out[i]
win_size = window_size_out[i]
out_format = '#%d: %d | %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f'
out_values = (i, flag, states_raw_out[i], clean_states_soft_out[i], x1, y1, x2, y2, r2, s2)
out_log = out_format % out_values
# print(out_log)
print('Saving results ...')
save_seq_data(sampling_dir, test_image_raw_name + '_' + str(sampling_i),
strokes_raw_out, multi_cursors,
test_image_size, round_new_lengths, eval_hps_model.min_width)
draw_strokes(strokes_raw_out, sampling_dir, test_image_raw_name + '_' + str(sampling_i) + '_pred.png',
ori_img, test_image_size,
multi_cursors, round_new_lengths, eval_hps_model.min_width, eval_hps_model.cursor_type,
sample_hps_model.raster_size, sample_hps_model.min_window_size,
sess,
pasting_func=paste_v3_func,
save_seq=draw_seq, draw_order=draw_order)
average_stroke_number = np.mean(stroke_number_list)
average_compute_time = np.mean(compute_time_list)
print()
print('@@@ Total summary:')
print(' >>> average_stroke_number', average_stroke_number)
print(' >>> average_compute_time', average_compute_time)
def main(model_name, test_image_name, sampling_num):
test_dataset = 'clean_line_drawings'
test_image_base_dir = 'sample_inputs'
sampling_base_dir = 'outputs/sampling'
model_base_dir = 'outputs/snapshot'
state_dependent = False
longer_infer_lens = [500 for _ in range(10)]
round_stop_state_num = 12
stroke_acc_threshold = 0.95
draw_seq = False
draw_color_order = True
# set numpy output to something sensible
np.set_printoptions(precision=8, edgeitems=6, linewidth=200, suppress=True)
main_testing(test_image_base_dir, test_dataset, test_image_name,
sampling_base_dir, model_base_dir, model_name, sampling_num,
draw_seq=draw_seq, draw_order=draw_color_order,
state_dependent=state_dependent, longer_infer_lens=longer_infer_lens,
round_stop_state_num=round_stop_state_num, stroke_acc_threshold=stroke_acc_threshold)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--input', '-i', type=str, default='', help="The test image name.")
parser.add_argument('--model', '-m', type=str, default='pretrain_clean_line_drawings', help="The trained model.")
parser.add_argument('--sample', '-s', type=int, default=1, help="The number of outputs.")
args = parser.parse_args()
assert args.input != ''
assert args.sample > 0
main(args.model, args.input, args.sample)
================================================
FILE: tools/gif_making.py
================================================
import os
import sys
import argparse
import numpy as np
from PIL import Image
import tensorflow as tf
sys.path.append('./')
from utils import draw, image_pasting_v3_testing
from model_common_test import DiffPastingV3
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
def add_scaling_visualization(canvas_images, cursor, window_size, image_size):
"""
:param canvas_images: (N, H, W, 3)
:param cursor:
:param window_size:
:param image_size:
:return:
"""
cursor_pos = cursor * float(image_size)
cursor_x, cursor_y = int(round(cursor_pos[0])), int(round(cursor_pos[1])) # in large size
vis_color = [255, 0, 0]
cursor_width = 3
box_width = 2
canvas_imgs = 255 - np.round(canvas_images * 255.0).astype(np.uint8)
# add cursor visualization
canvas_imgs[:, cursor_y - cursor_width: cursor_y + cursor_width, cursor_x - cursor_width: cursor_x + cursor_width, :] = vis_color
# add box visualization
up = max(0, cursor_y - window_size // 2)
down = min(image_size, cursor_y + window_size // 2)
left = max(0, cursor_x - window_size // 2)
right = min(image_size, cursor_x + window_size // 2)
if up > 0:
canvas_imgs[:, up: up + box_width, left: right, :] = vis_color
if down < image_size:
canvas_imgs[:, down - box_width: down, left: right, :] = vis_color
if left > 0:
canvas_imgs[:, up: down, left: left + box_width, :] = vis_color
if right < image_size:
canvas_imgs[:, up: down, right - box_width: right, :] = vis_color
return canvas_imgs
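The box drawing above fills four thin edge slabs rather than a solid rectangle. A minimal check with illustrative sizes (all coordinates below are made up) confirms the result is a hollow frame with the interior untouched:

```python
import numpy as np

img = np.zeros((1, 64, 64, 3), dtype=np.uint8)
color, w = [255, 0, 0], 2                 # red edges, 2 px wide
up, down, left, right = 16, 48, 16, 48    # hypothetical window bounds
img[:, up:up + w, left:right, :] = color          # top edge
img[:, down - w:down, left:right, :] = color      # bottom edge
img[:, up:down, left:left + w, :] = color         # left edge
img[:, up:down, right - w:right, :] = color       # right edge
assert np.all(img[0, 16, 20] == [255, 0, 0])   # on the top edge: red
assert np.all(img[0, 32, 32] == [0, 0, 0])     # interior stays black
```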
def make_gif(sess, pasting_func, data, init_cursor, image_size, infer_lengths, init_width,
save_base,
cursor_type='next', min_window_size=32, raster_size=128, add_box=True):
"""
:param data: (N_strokes, 9): flag, x0, y0, x1, y1, x2, y2, r0, r2
:return:
"""
canvas = np.zeros((image_size, image_size), dtype=np.float32) # [0.0-BG, 1.0-stroke]
gif_frames = []
cursor_idx = 0
if init_cursor.ndim == 1:
init_cursor = [init_cursor]
for round_idx in range(len(infer_lengths)):
print('Making progress', round_idx + 1, '/', len(infer_lengths))
round_length = infer_lengths[round_idx]
cursor_pos = init_cursor[cursor_idx] # (2)
cursor_idx += 1
prev_width = init_width
prev_scaling = 1.0
prev_window_size = float(raster_size) # (1)
for round_inner_i in range(round_length):
stroke_idx = np.sum(infer_lengths[:round_idx]).astype(np.int32) + round_inner_i
curr_window_size_raw = prev_scaling * prev_window_size
curr_window_size_raw = np.maximum(curr_window_size_raw, min_window_size)
curr_window_size_raw = np.minimum(curr_window_size_raw, image_size)
curr_window_size = int(round(curr_window_size_raw)) # ()
pen_state = data[stroke_idx, 0]
stroke_params = data[stroke_idx, 1:] # (8)
x1y1, x2y2, width2, scaling2 = stroke_params[0:2], stroke_params[2:4], stroke_params[4], stroke_params[5]
x0y0 = np.zeros_like(x2y2) # (2), [-1.0, 1.0]
x0y0 = np.divide(np.add(x0y0, 1.0), 2.0) # (2), [0.0, 1.0]
x2y2 = np.divide(np.add(x2y2, 1.0), 2.0) # (2), [0.0, 1.0]
widths = np.stack([prev_width, width2], axis=0) # (2)
stroke_params_proc = np.concatenate([x0y0, x1y1, x2y2, widths], axis=-1) # (8)
next_width = stroke_params[4]
next_scaling = stroke_params[5]
next_window_size = next_scaling * curr_window_size_raw
next_window_size = np.maximum(next_window_size, min_window_size)
next_window_size = np.minimum(next_window_size, image_size)
prev_width = next_width * curr_window_size_raw / next_window_size
prev_scaling = next_scaling
prev_window_size = curr_window_size_raw
f = stroke_params_proc.tolist() # (8)
f += [1.0, 1.0]
gt_stroke_img = draw(f) # (H, W), [0.0-stroke, 1.0-BG]
gt_stroke_img_large = image_pasting_v3_testing(1.0 - gt_stroke_img, cursor_pos,
image_size,
curr_window_size_raw,
pasting_func, sess) # [0.0-BG, 1.0-stroke]
if pen_state == 0:
canvas += gt_stroke_img_large # [0.0-BG, 1.0-stroke]
canvas_rgb = np.stack([np.clip(canvas, 0.0, 1.0) for _ in range(3)], axis=-1)
if add_box:
vis_inputs = np.expand_dims(canvas_rgb, axis=0)
vis_outputs = add_scaling_visualization(vis_inputs, cursor_pos, curr_window_size, image_size)
canvas_vis = vis_outputs[0]
else:
canvas_vis = canvas_rgb
canvas_vis_png = Image.fromarray((np.clip(canvas_vis, 0.0, 1.0) * 255.0).astype(np.uint8), 'RGB')
gif_frames.append(canvas_vis_png)
# update cursor_pos based on hps.cursor_type
new_cursor_offsets = stroke_params[2:4] * (float(curr_window_size_raw) / 2.0) # (2), patch-level
new_cursor_offset_next = new_cursor_offsets
# important: swap the (x, y) offset to (y, x) to match the cursor's (row, col) coordinate order
new_cursor_offset_next = np.concatenate([new_cursor_offset_next[1:2], new_cursor_offset_next[0:1]], axis=-1)
cursor_pos_large = cursor_pos * float(image_size)
stroke_position_next = cursor_pos_large + new_cursor_offset_next # (2), large-level
if cursor_type == 'next':
cursor_pos_large = stroke_position_next # (2), large-level
else:
raise Exception('Unknown cursor_type')
cursor_pos_large = np.minimum(np.maximum(cursor_pos_large, 0.0), float(image_size - 1)) # (2), large-level
cursor_pos = cursor_pos_large / float(image_size)
print('Saving to GIF ...')
save_path = os.path.join(save_base, 'dynamic.gif')
first_frame = gif_frames[0]
first_frame.save(save_path, save_all=True, append_images=gif_frames[1:], loop=0, duration=100)  # skip frame 0 in append_images to avoid duplicating it; duration is per-frame in ms
def gif_making(npz_path):
assert npz_path != ''
min_window_size = 32
raster_size = 128
split_idx = npz_path.rfind('/')
if split_idx == -1:
    file_base = './'
    file_name = npz_path[:-4]
else:
    file_base = npz_path[:split_idx]
    file_name = npz_path[split_idx + 1: -4]
gif_base = os.path.join(file_base, file_name)
os.makedirs(gif_base, exist_ok=True)
# differentiable pasting graph
paste_v3_func = DiffPastingV3(raster_size)
tfconfig = tf.ConfigProto()
tfconfig.gpu_options.allow_growth = True
sess = tf.InteractiveSession(config=tfconfig)
sess.run(tf.global_variables_initializer())
data = np.load(npz_path, encoding='latin1', allow_pickle=True)
strokes_data = data['strokes_data']
init_cursors = data['init_cursors']
image_size = data['image_size']
round_length = data['round_length']
init_width = data['init_width']
if round_length.ndim == 0:
round_lengths = [round_length]
else:
round_lengths = round_length
# print('round_lengths', round_lengths)
make_gif(sess, paste_v3_func,
strokes_data, init_cursors, image_size, round_lengths, init_width,
gif_base,
min_window_size=min_window_size, raster_size=raster_size)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--file', '-f', type=str, default='', help="path to an .npz result file")
args = parser.parse_args()
gif_making(args.file)
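The window-size and cursor updates inside `make_gif` can be isolated into small helpers for clarity. The sketch below is illustrative (the names `update_window_size` and `update_cursor` are not part of the repository); it reproduces the clamping and the (x, y) → (row, col) swap performed before the offset is applied:

```python
import numpy as np

def update_window_size(prev_scaling, prev_window_size, min_window_size=32, image_size=256):
    """Scale the previous window and clamp it to [min_window_size, image_size]."""
    raw = prev_scaling * prev_window_size
    return float(np.clip(raw, min_window_size, image_size))

def update_cursor(cursor_pos, offset_xy, window_size, image_size):
    """Move the normalized cursor by a patch-level (x, y) offset.

    The model outputs offsets in (x, y) order while the cursor lives in
    (row, col) order, so the two components are swapped before being added,
    mirroring the concatenation trick in make_gif."""
    offset = np.asarray(offset_xy, dtype=np.float32) * (window_size / 2.0)
    offset_rc = offset[::-1]  # (x, y) -> (row, col)
    pos_large = np.asarray(cursor_pos, dtype=np.float32) * image_size + offset_rc
    pos_large = np.clip(pos_large, 0.0, image_size - 1)
    return pos_large / image_size
```

For example, a cursor at the image center with a pure-x offset moves along the column axis, not the row axis.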
================================================
FILE: tools/svg_conversion.py
================================================
import os
import argparse
import numpy as np
from xml.dom import minidom
def write_svg_1(path_list, img_size, save_path):
''' A long curve consisting of several strokes forms a path. '''
impl = minidom.getDOMImplementation()
doc = impl.createDocument(None, None, None)
rootElement = doc.createElement('svg')
rootElement.setAttribute('xmlns', 'http://www.w3.org/2000/svg')
rootElement.setAttribute('height', str(img_size))
rootElement.setAttribute('width', str(img_size))
path_num = len(path_list)
for path_i in range(path_num):
path_items = path_list[path_i]
assert len(path_items) > 0
if len(path_items) == 1:
continue
childElement = doc.createElement('path')
childElement.setAttribute('id', 'curve_' + str(path_i))
childElement.setAttribute('stroke', '#000000')
childElement.setAttribute('stroke-width', '3.5')
childElement.setAttribute('stroke-linejoin', 'round')
childElement.setAttribute('stroke-linecap', 'round')
childElement.setAttribute('fill', 'none')
command_str = ''
for stroke_i, stroke_item in enumerate(path_items):
if stroke_i == 0:
command_str += 'M '
stroke_position = stroke_item[0]
command_str += str(stroke_position[0]) + ', ' + str(stroke_position[1]) + ' '
else:
command_str += 'Q '
ctrl_position, stroke_position, stroke_width = stroke_item[0], stroke_item[1], stroke_item[2]
ctrl_position_0 = last_position[0] + (stroke_position[0] - last_position[0]) * ctrl_position[1]
ctrl_position_1 = last_position[1] + (stroke_position[1] - last_position[1]) * ctrl_position[0]
command_str += str(ctrl_position_0) + ', ' + str(ctrl_position_1) + ', ' + \
str(stroke_position[0]) + ', ' + str(stroke_position[1]) + ' '
last_position = stroke_position
childElement.setAttribute('d', command_str)
rootElement.appendChild(childElement)
doc.appendChild(rootElement)
with open(save_path, 'w') as f:
    doc.writexml(f, addindent=' ', newl='\n')
def write_svg_2(path_list, img_size, save_path):
''' A single stroke forms a path. '''
impl = minidom.getDOMImplementation()
doc = impl.createDocument(None, None, None)
rootElement = doc.createElement('svg')
rootElement.setAttribute('xmlns', 'http://www.w3.org/2000/svg')
rootElement.setAttribute('height', str(img_size))
rootElement.setAttribute('width', str(img_size))
path_num = len(path_list)
for path_i in range(path_num):
path_items = path_list[path_i]
assert len(path_items) > 0
if len(path_items) == 1:
continue
for stroke_i, stroke_item in enumerate(path_items):
if stroke_i == 0:
last_position = stroke_item[0]
else:
childElement = doc.createElement('path')
childElement.setAttribute('id', 'curve_%d_%d' % (path_i, stroke_i))  # unique id per stroke path
childElement.setAttribute('stroke', '#000000')
childElement.setAttribute('stroke-linejoin', 'round')
childElement.setAttribute('stroke-linecap', 'round')
childElement.setAttribute('fill', 'none')
command_str = 'M ' + str(last_position[0]) + ', ' + str(last_position[1]) + ' '
command_str += 'Q '
ctrl_position, stroke_position, stroke_width = stroke_item[0], stroke_item[1], stroke_item[2]
ctrl_position_0 = last_position[0] + (stroke_position[0] - last_position[0]) * ctrl_position[1]
ctrl_position_1 = last_position[1] + (stroke_position[1] - last_position[1]) * ctrl_position[0]
command_str += str(ctrl_position_0) + ', ' + str(ctrl_position_1) + ', ' + \
str(stroke_position[0]) + ', ' + str(stroke_position[1]) + ' '
last_position = stroke_position
childElement.setAttribute('d', command_str)
childElement.setAttribute('stroke-width', str(stroke_width * img_size / 1.66))
rootElement.appendChild(childElement)
doc.appendChild(rootElement)
with open(save_path, 'w') as f:
    doc.writexml(f, addindent=' ', newl='\n')
def convert_strokes_to_svg(data, init_cursor, image_size, infer_lengths, init_width, save_path, svg_type,
cursor_type='next', min_window_size=32, raster_size=128):
"""
:param data: (N_strokes, 7): flag, x_c, y_c, dx, dy, r, ds
:return:
"""
cursor_idx = 0
absolute_strokes = []
absolute_strokes_path = []
if init_cursor.ndim == 1:
init_cursor = [init_cursor]
for round_idx in range(len(infer_lengths)):
round_length = infer_lengths[round_idx]
cursor_pos = init_cursor[cursor_idx] # (2)
cursor_idx += 1
cursor_pos_large = cursor_pos * float(image_size)
if len(absolute_strokes_path) > 0:
absolute_strokes.append(absolute_strokes_path)
absolute_strokes_path = [[cursor_pos_large]]
prev_width = init_width
prev_scaling = 1.0
prev_window_size = float(raster_size) # (1)
for round_inner_i in range(round_length):
stroke_idx = np.sum(infer_lengths[:round_idx]).astype(np.int32) + round_inner_i
curr_window_size_raw = prev_scaling * prev_window_size
curr_window_size_raw = np.maximum(curr_window_size_raw, min_window_size)
curr_window_size_raw = np.minimum(curr_window_size_raw, image_size)
# curr_window_size = int(round(curr_window_size_raw)) # ()
stroke_params = data[stroke_idx, 1:] # (6)
pen_state = data[stroke_idx, 0]
next_width = stroke_params[4]
next_scaling = stroke_params[5]
next_width_abs = next_width * curr_window_size_raw / float(image_size)
prev_scaling = next_scaling
prev_window_size = curr_window_size_raw
# update cursor_pos based on hps.cursor_type
new_cursor_offsets = stroke_params[2:4] * (float(curr_window_size_raw) / 2.0) # (2), patch-level
new_cursor_offset_next = new_cursor_offsets
# important: swap the (x, y) offset to (y, x) to match the cursor's (row, col) coordinate order
new_cursor_offset_next = np.concatenate([new_cursor_offset_next[1:2], new_cursor_offset_next[0:1]], axis=-1)
cursor_pos_large = cursor_pos * float(image_size)
stroke_position_next = cursor_pos_large + new_cursor_offset_next # (2), large-level
if pen_state == 0:
absolute_strokes_path.append([stroke_params[0:2], stroke_position_next, next_width_abs])
else:
absolute_strokes.append(absolute_strokes_path)
absolute_strokes_path = [[stroke_position_next]]
if cursor_type == 'next':
cursor_pos_large = stroke_position_next # (2), large-level
else:
raise Exception('Unknown cursor_type')
cursor_pos_large = np.minimum(np.maximum(cursor_pos_large, 0.0), float(image_size - 1)) # (2), large-level
cursor_pos = cursor_pos_large / float(image_size)
absolute_strokes.append(absolute_strokes_path)
if svg_type == 'cluster':
write_svg_1(absolute_strokes, image_size, save_path)
elif svg_type == 'single':
write_svg_2(absolute_strokes, image_size, save_path)
else:
raise Exception('Unknown svg_type', svg_type)
def data_convert_to_absolute(npz_path, svg_type):
assert npz_path != ''
assert svg_type in ['single', 'cluster']
min_window_size = 32
raster_size = 128
split_idx = npz_path.rfind('/')
if split_idx == -1:
file_base = './'
file_name = npz_path[:-4]
else:
file_base = npz_path[:npz_path.rfind('/')]
file_name = npz_path[npz_path.rfind('/') + 1: -4]
svg_data_base = os.path.join(file_base, file_name)
os.makedirs(svg_data_base, exist_ok=True)
data = np.load(npz_path, encoding='latin1', allow_pickle=True)
strokes_data = data['strokes_data']
init_cursors = data['init_cursors']
image_size = data['image_size']
round_length = data['round_length']
init_width = data['init_width']
if round_length.ndim == 0:
round_lengths = [round_length]
else:
round_lengths = round_length
save_path = os.path.join(svg_data_base, str(svg_type) + '.svg')
convert_strokes_to_svg(strokes_data, init_cursors, image_size, round_lengths, init_width,
min_window_size=min_window_size, raster_size=raster_size, save_path=save_path,
svg_type=svg_type)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--file', '-f', type=str, default='', help="path to an .npz result file")
parser.add_argument('--svg_type', '-st', type=str, choices=['single', 'cluster'], default='single',
help="svg type")
args = parser.parse_args()
data_convert_to_absolute(args.file, args.svg_type)
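As a sanity check of the `M x0, y0 Q cx, cy, x1, y1` command-string layout that both writers above emit, the snippet below builds a one-curve SVG with `minidom`. It is a minimal sketch (the function name and the sample coordinates are illustrative only):

```python
from xml.dom import minidom

def make_single_curve_svg(start, ctrl, end, img_size=256):
    """Build an SVG containing one quadratic Bezier path, mirroring write_svg_1."""
    doc = minidom.getDOMImplementation().createDocument(None, None, None)
    root = doc.createElement('svg')
    root.setAttribute('xmlns', 'http://www.w3.org/2000/svg')
    root.setAttribute('width', str(img_size))
    root.setAttribute('height', str(img_size))
    path = doc.createElement('path')
    path.setAttribute('stroke', '#000000')
    path.setAttribute('fill', 'none')
    # same command layout as the converters above: move-to, then one quadratic segment
    path.setAttribute('d', 'M %s, %s Q %s, %s, %s, %s'
                      % (start[0], start[1], ctrl[0], ctrl[1], end[0], end[1]))
    root.appendChild(path)
    doc.appendChild(root)
    return doc.toxml()
```

Any SVG viewer should render the result as a single curved stroke from `start` to `end`, bent toward `ctrl`.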
================================================
FILE: tools/visualize_drawing.py
================================================
import os
import sys
import argparse
import numpy as np
from PIL import Image
import tensorflow as tf
sys.path.append('./')
from utils import get_colors, draw, image_pasting_v3_testing
from model_common_test import DiffPastingV3
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
def display_strokes_final(sess, pasting_func, data, init_cursor, image_size, infer_lengths, init_width,
save_base,
cursor_type='next', min_window_size=32, raster_size=128):
"""
:param data: (N_strokes, 7): flag, x1, y1, x2, y2, r2, s2
:return:
"""
canvas = np.zeros((image_size, image_size), dtype=np.float32) # [0.0-BG, 1.0-stroke]
drawn_region = np.zeros_like(canvas)
overlap_region = np.zeros_like(canvas)
canvas_color_with_overlap = np.zeros((image_size, image_size, 3), dtype=np.float32)
canvas_color_wo_overlap = np.zeros((image_size, image_size, 3), dtype=np.float32)
canvas_color_with_moving = np.zeros((image_size, image_size, 3), dtype=np.float32)
cursor_idx = 0
if init_cursor.ndim == 1:
init_cursor = [init_cursor]
stroke_count = len(data)
color_rgb_set = get_colors(stroke_count) # list of (3,) in [0, 255]
color_idx = 0
valid_stroke_count = stroke_count - np.sum(data[:, 0]).astype(np.int32) + len(init_cursor)
valid_color_rgb_set = get_colors(valid_stroke_count) # list of (3,) in [0, 255]
valid_color_idx = -1
# print('Drawn stroke number', valid_stroke_count)
# print(' flag x1\t\t y1\t\t x2\t\t y2\t\t r2\t\t s2')
for round_idx in range(len(infer_lengths)):
round_length = infer_lengths[round_idx]
cursor_pos = init_cursor[cursor_idx] # (2)
cursor_idx += 1
prev_width = init_width
prev_scaling = 1.0
prev_window_size = float(raster_size) # (1)
for round_inner_i in range(round_length):
stroke_idx = np.sum(infer_lengths[:round_idx]).astype(np.int32) + round_inner_i
curr_window_size_raw = prev_scaling * prev_window_size
curr_window_size_raw = np.maximum(curr_window_size_raw, min_window_size)
curr_window_size_raw = np.minimum(curr_window_size_raw, image_size)
pen_state = data[stroke_idx, 0]
stroke_params = data[stroke_idx, 1:] # (6)
x1y1, x2y2, width2, scaling2 = stroke_params[0:2], stroke_params[2:4], stroke_params[4], stroke_params[5]
x0y0 = np.zeros_like(x2y2) # (2), [-1.0, 1.0]
x0y0 = np.divide(np.add(x0y0, 1.0), 2.0) # (2), [0.0, 1.0]
x2y2 = np.divide(np.add(x2y2, 1.0), 2.0) # (2), [0.0, 1.0]
widths = np.stack([prev_width, width2], axis=0) # (2)
stroke_params_proc = np.concatenate([x0y0, x1y1, x2y2, widths], axis=-1) # (8)
next_width = stroke_params[4]
next_scaling = stroke_params[5]
next_window_size = next_scaling * curr_window_size_raw
next_window_size = np.maximum(next_window_size, min_window_size)
next_window_size = np.minimum(next_window_size, image_size)
prev_width = next_width * curr_window_size_raw / next_window_size
prev_scaling = next_scaling
prev_window_size = curr_window_size_raw
f = stroke_params_proc.tolist() # (8)
f += [1.0, 1.0]
gt_stroke_img = draw(f) # (H, W), [0.0-stroke, 1.0-BG]
gt_stroke_img_large = image_pasting_v3_testing(1.0 - gt_stroke_img, cursor_pos,
image_size,
curr_window_size_raw,
pasting_func, sess) # [0.0-BG, 1.0-stroke]
is_overlap = False
if pen_state == 0:
canvas += gt_stroke_img_large # [0.0-BG, 1.0-stroke]
curr_drawn_stroke_region = np.zeros_like(gt_stroke_img_large)
curr_drawn_stroke_region[gt_stroke_img_large > 0.5] = 1
intersection = drawn_region * curr_drawn_stroke_region
# regard a stroke with >50% overlapping area as an overlapping stroke
if np.sum(intersection) / np.sum(curr_drawn_stroke_region) > 0.5:
# enlarge the stroke a bit for better visualization
overlap_region[gt_stroke_img_large > 0] += 1
is_overlap = True
drawn_region[gt_stroke_img_large > 0.5] = 1
color_rgb = color_rgb_set[color_idx] # (3) in [0, 255]
color_idx += 1
color_rgb = np.reshape(color_rgb, (1, 1, 3)).astype(np.float32)
color_stroke = np.expand_dims(gt_stroke_img_large, axis=-1) * (1.0 - color_rgb / 255.0)
canvas_color_with_moving = canvas_color_with_moving * np.expand_dims((1.0 - gt_stroke_img_large),
axis=-1) + color_stroke # (H, W, 3)
if pen_state == 0:
valid_color_idx += 1
if pen_state == 0:
valid_color_rgb = valid_color_rgb_set[valid_color_idx] # (3) in [0, 255]
# valid_color_idx += 1
valid_color_rgb = np.reshape(valid_color_rgb, (1, 1, 3)).astype(np.float32)
valid_color_stroke = np.expand_dims(gt_stroke_img_large, axis=-1) * (1.0 - valid_color_rgb / 255.0)
canvas_color_with_overlap = canvas_color_with_overlap * np.expand_dims((1.0 - gt_stroke_img_large),
axis=-1) + valid_color_stroke # (H, W, 3)
if not is_overlap:
canvas_color_wo_overlap = canvas_color_wo_overlap * np.expand_dims((1.0 - gt_stroke_img_large),
axis=-1) + valid_color_stroke # (H, W, 3)
# update cursor_pos based on hps.cursor_type
new_cursor_offsets = stroke_params[2:4] * (float(curr_window_size_raw) / 2.0) # (2), patch-level
new_cursor_offset_next = new_cursor_offsets
# important: swap the (x, y) offset to (y, x) to match the cursor's (row, col) coordinate order
new_cursor_offset_next = np.concatenate([new_cursor_offset_next[1:2], new_cursor_offset_next[0:1]], axis=-1)
cursor_pos_large = cursor_pos * float(image_size)
stroke_position_next = cursor_pos_large + new_cursor_offset_next # (2), large-level
if cursor_type == 'next':
cursor_pos_large = stroke_position_next # (2), large-level
else:
raise Exception('Unknown cursor_type')
cursor_pos_large = np.minimum(np.maximum(cursor_pos_large, 0.0), float(image_size - 1)) # (2), large-level
cursor_pos = cursor_pos_large / float(image_size)
canvas_rgb = np.stack([np.clip(canvas, 0.0, 1.0) for _ in range(3)], axis=-1)
canvas_black = 255 - np.round(canvas_rgb * 255.0).astype(np.uint8)
canvas_color_with_overlap = 255 - np.round(canvas_color_with_overlap * 255.0).astype(np.uint8)
canvas_color_wo_overlap = 255 - np.round(canvas_color_wo_overlap * 255.0).astype(np.uint8)
canvas_color_with_moving = 255 - np.round(canvas_color_with_moving * 255.0).astype(np.uint8)
canvas_black_png = Image.fromarray(canvas_black, 'RGB')
canvas_black_save_path = os.path.join(save_base, 'output_rendered.png')
canvas_black_png.save(canvas_black_save_path, 'PNG')
canvas_color_png = Image.fromarray(canvas_color_with_overlap, 'RGB')
canvas_color_save_path = os.path.join(save_base, 'output_order_with_overlap.png')
canvas_color_png.save(canvas_color_save_path, 'PNG')
canvas_color_wo_png = Image.fromarray(canvas_color_wo_overlap, 'RGB')
canvas_color_wo_save_path = os.path.join(save_base, 'output_order_wo_overlap.png')
canvas_color_wo_png.save(canvas_color_wo_save_path, 'PNG')
canvas_color_m_png = Image.fromarray(canvas_color_with_moving, 'RGB')
canvas_color_m_save_path = os.path.join(save_base, 'output_order_with_moving.png')
canvas_color_m_png.save(canvas_color_m_save_path, 'PNG')
def visualize_drawing(npz_path):
assert npz_path != ''
min_window_size = 32
raster_size = 128
split_idx = npz_path.rfind('/')
if split_idx == -1:
    file_base = './'
    file_name = npz_path[:-4]
else:
    file_base = npz_path[:split_idx]
    file_name = npz_path[split_idx + 1: -4]
regenerate_base = os.path.join(file_base, file_name)
os.makedirs(regenerate_base, exist_ok=True)
# differentiable pasting graph
paste_v3_func = DiffPastingV3(raster_size)
tfconfig = tf.ConfigProto()
tfconfig.gpu_options.allow_growth = True
sess = tf.InteractiveSession(config=tfconfig)
sess.run(tf.global_variables_initializer())
data = np.load(npz_path, encoding='latin1', allow_pickle=True)
strokes_data = data['strokes_data']
init_cursors = data['init_cursors']
image_size = data['image_size']
round_length = data['round_length']
init_width = data['init_width']
if round_length.ndim == 0:
round_lengths = [round_length]
else:
round_lengths = round_length
# print('round_lengths', round_lengths)
print('Processing ...')
display_strokes_final(sess, paste_v3_func,
strokes_data, init_cursors, image_size, round_lengths, init_width,
regenerate_base,
min_window_size=min_window_size, raster_size=raster_size)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--file', '-f', type=str, default='', help="path to an .npz result file")
args = parser.parse_args()
visualize_drawing(args.file)
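The overlap test in `display_strokes_final` counts a stroke as overlapping when more than half of its inked pixels fall on already-drawn ones. That criterion can be checked in isolation; the helper below is a minimal numpy sketch (`is_mostly_overlapping` is an illustrative name, not a repository function):

```python
import numpy as np

def is_mostly_overlapping(drawn_region, stroke_img, threshold=0.5):
    """Return True if more than `threshold` of the stroke's pixels are already drawn.

    drawn_region: binary (H, W) mask of previously drawn pixels.
    stroke_img:   (H, W) stroke intensities in [0, 1]; values > 0.5 count as ink.
    """
    stroke_mask = (stroke_img > 0.5).astype(np.float32)
    if stroke_mask.sum() == 0:
        return False  # an empty stroke cannot overlap anything
    intersection = drawn_region * stroke_mask
    return intersection.sum() / stroke_mask.sum() > threshold
```

Note the strict inequality: a stroke that overlaps exactly half of its pixels is not flagged, matching the `> 0.5` comparison in the visualization code.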
================================================
FILE: train_rough_photograph.py
================================================
import json
import os
import time
import numpy as np
import six
import tensorflow as tf
from PIL import Image
import argparse
import model_common_train as sketch_image_model
from hyper_parameters import FLAGS, get_default_hparams_rough, get_default_hparams_normal
from utils import create_summary, save_model, reset_graph, load_checkpoint
from dataset_utils import load_dataset_training
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
tf.logging.set_verbosity(tf.logging.INFO)
def should_save_log_img(step_):
    return step_ % 500 == 0
def save_log_images(sess, model, data_set, save_root, step_num, curr_photo_prob, interpolate_type, save_num=10):
res_gap = (model.hps.image_size_large - model.hps.image_size_small) // (save_num - 1)
log_img_resolutions = []
for ii in range(save_num - 1):
log_img_resolutions.append(model.hps.image_size_small + ii * res_gap)
log_img_resolutions.append(model.hps.image_size_large)
for res_i in range(len(log_img_resolutions)):
resolution = log_img_resolutions[res_i]
sub_save_root = os.path.join(save_root, 'res_' + str(resolution))
os.makedirs(sub_save_root, exist_ok=True)
input_photos, target_sketches, init_cursors, image_size_rand = \
data_set.get_batch_from_memory(memory_idx=res_i,
fixed_image_size=resolution,
random_cursor=model.hps.random_cursor,
photo_prob=curr_photo_prob,
interpolate_type=interpolate_type)
# input_photos: (N, image_size, image_size, 3), [0-stroke, 1-BG]
# target_sketches: (N, image_size, image_size), [0-stroke, 1-BG]
# init_cursors: (N, 1, 2), in size [0.0, 1.0)
input_photo_val = input_photos
init_cursor_input = [init_cursors for _ in range(model.total_loop)]
init_cursor_input = np.concatenate(init_cursor_input, axis=0)
image_size_input = [image_size_rand for _ in range(model.total_loop)]
image_size_input = np.stack(image_size_input, axis=0)
feed = {
model.init_cursor: init_cursor_input,
model.image_size: image_size_input,
model.init_width: [model.hps.min_width],
}
for loop_i in range(model.total_loop):
feed[model.input_photo_list[loop_i]] = input_photo_val
raster_images_pred, raster_images_pred_rgb = sess.run([model.pred_raster_imgs, model.pred_raster_imgs_rgb],
feed) # (N, image_size, image_size), [0.0-stroke, 1.0-BG]
raster_images_pred = (np.array(raster_images_pred[0]) * 255.0).astype(np.uint8)
input_photo = (np.array(input_photo_val[0, :, :, :]) * 255.0).astype(np.uint8)
target_sketch = (np.array(target_sketches[0]) * 255.0).astype(np.uint8)
raster_images_pred_rgb = (np.array(raster_images_pred_rgb[0]) * 255.0).astype(np.uint8)
pred_save_path = os.path.join(sub_save_root, str(step_num) + '.png')
input_save_path = os.path.join(sub_save_root, 'input.png')
target_save_path = os.path.join(sub_save_root, 'gt.png')
pred_rgb_save_root = os.path.join(sub_save_root, 'rgb')
os.makedirs(pred_rgb_save_root, exist_ok=True)
pred_rgb_save_path = os.path.join(pred_rgb_save_root, str(step_num) + '.png')
raster_images_pred = Image.fromarray(raster_images_pred, 'L')
raster_images_pred.save(pred_save_path, 'PNG')
input_photo = Image.fromarray(input_photo, 'RGB')
input_photo.save(input_save_path, 'PNG')
target_sketch = Image.fromarray(target_sketch, 'L')
target_sketch.save(target_save_path, 'PNG')
raster_images_pred_rgb = Image.fromarray(raster_images_pred_rgb, 'RGB')
raster_images_pred_rgb.save(pred_rgb_save_path, 'PNG')
def train(sess, train_model, eval_sample_model, train_set, valid_set, sub_log_root, sub_snapshot_root, sub_log_img_root):
# Setup summary writer.
summary_writer = tf.summary.FileWriter(sub_log_root)
print('-' * 100)
# Calculate trainable params.
t_vars = tf.trainable_variables()
count_t_vars = 0
for var in t_vars:
num_param = np.prod(var.get_shape().as_list())
count_t_vars += num_param
print('%s | shape: %s | num_param: %i' % (var.name, str(var.get_shape()), num_param))
print('Total trainable variables %i.' % count_t_vars)
print('-' * 100)
# main train loop
hps = train_model.hps
start = time.time()
# create saver
snapshot_save_vars = [var for var in tf.global_variables()
if 'raster_unit' not in var.op.name and 'VGG16' not in var.op.name]
saver = tf.train.Saver(var_list=snapshot_save_vars, max_to_keep=20)
start_step = 1
print('start_step', start_step)
mean_perc_relu_losses = [0.0 for _ in range(len(hps.perc_loss_layers))]
for _ in range(start_step, hps.num_steps + 1):
step = sess.run(train_model.global_step) # start from 0
count_step = min(step, hps.num_steps)
curr_learning_rate = ((hps.learning_rate - hps.min_learning_rate) *
(1 - count_step / hps.num_steps) ** hps.decay_power + hps.min_learning_rate)
if hps.sn_loss_type == 'decreasing':
assert hps.decrease_stop_steps <= hps.num_steps
assert hps.stroke_num_loss_weight_end <= hps.stroke_num_loss_weight
curr_sn_k = (hps.stroke_num_loss_weight - hps.stroke_num_loss_weight_end) / float(hps.decrease_stop_steps)
curr_stroke_num_loss_weight = hps.stroke_num_loss_weight - count_step * curr_sn_k
curr_stroke_num_loss_weight = max(curr_stroke_num_loss_weight, hps.stroke_num_loss_weight_end)
elif hps.sn_loss_type == 'fixed':
curr_stroke_num_loss_weight = hps.stroke_num_loss_weight
elif hps.sn_loss_type == 'increasing':
curr_sn_k = hps.stroke_num_loss_weight / float(hps.num_steps - hps.increase_start_steps)
curr_stroke_num_loss_weight = max(count_step - hps.increase_start_steps, 0) * curr_sn_k
else:
raise Exception('Unknown sn_loss_type', hps.sn_loss_type)
if hps.early_pen_loss_type == 'head':
curr_early_pen_k = (hps.max_seq_len - hps.early_pen_length) / float(hps.num_steps)
curr_early_pen_loss_len = count_step * curr_early_pen_k + hps.early_pen_length
curr_early_pen_loss_start = 1
curr_early_pen_loss_end = curr_early_pen_loss_len
elif hps.early_pen_loss_type == 'tail':
curr_early_pen_k = (hps.max_seq_len // 2 - 1) / float(hps.num_steps)
curr_early_pen_loss_len = count_step * curr_early_pen_k + hps.max_seq_len // 2
curr_early_pen_loss_end = hps.max_seq_len
curr_early_pen_loss_start = curr_early_pen_loss_end - curr_early_pen_loss_len
elif hps.early_pen_loss_type == 'move':
curr_early_pen_k = (hps.max_seq_len // 2 - 1) / float(hps.num_steps)
curr_early_pen_loss_len = count_step * curr_early_pen_k + hps.max_seq_len // 2
curr_early_pen_loss_start = hps.max_seq_len - curr_early_pen_loss_len
curr_early_pen_loss_end = curr_early_pen_loss_start + hps.max_seq_len // 2
else:
raise Exception('Unknown early_pen_loss_type', hps.early_pen_loss_type)
curr_early_pen_loss_start = int(round(curr_early_pen_loss_start))
curr_early_pen_loss_end = int(round(curr_early_pen_loss_end))
if hps.photo_prob_type == 'increasing' or hps.photo_prob_type == 'interpolate':
assert hps.photo_prob_end_step >= hps.photo_prob_start_step
curr_photo_prob_k = 1.0 / float(hps.photo_prob_end_step - hps.photo_prob_start_step)
curr_photo_prob = (count_step - hps.photo_prob_start_step) * curr_photo_prob_k
curr_photo_prob = max(0.0, curr_photo_prob)
curr_photo_prob = min(1.0, curr_photo_prob)
interpolate_type = 'prob' if hps.photo_prob_type == 'increasing' else 'image'
elif hps.photo_prob_type == 'zero':
curr_photo_prob = 0.0
interpolate_type = 'prob'
elif hps.photo_prob_type == 'one':
curr_photo_prob = 1.0
interpolate_type = 'prob'
else:
raise Exception('Unknown photo_prob_type', hps.photo_prob_type)
input_photos, target_sketches, init_cursors, image_sizes = \
train_set.get_batch_multi_res(loop_num=train_model.total_loop,
random_cursor=hps.random_cursor,
photo_prob=curr_photo_prob,
interpolate_type=interpolate_type)
# input_photos: list of (N, image_size, image_size, 3), [0-stroke, 1-BG]
# target_sketches: list of (N, image_size, image_size), [0-stroke, 1-BG]
# init_cursors: list of (N, 1, 2), in size [0.0, 1.0)
init_cursors_input = np.concatenate(init_cursors, axis=0)
image_size_input = np.stack(image_sizes, axis=0)
feed = {
train_model.init_cursor: init_cursors_input,
train_model.image_size: image_size_input,
train_model.init_width: [hps.min_width],
train_model.lr: curr_learning_rate,
train_model.stroke_num_loss_weight: curr_stroke_num_loss_weight,
train_model.early_pen_loss_start_idx: curr_early_pen_loss_start,
train_model.early_pen_loss_end_idx: curr_early_pen_loss_end,
train_model.last_step_num: float(step),
}
for layer_i in range(len(hps.perc_loss_layers)):
feed[train_model.perc_loss_mean_list[layer_i]] = mean_perc_relu_losses[layer_i]
for loop_i in range(train_model.total_loop):
input_photo_val = input_photos[loop_i]
target_sketch_val = target_sketches[loop_i]
feed[train_model.input_photo_list[loop_i]] = input_photo_val
feed[train_model.target_sketch_list[loop_i]] = np.expand_dims(target_sketch_val, axis=-1)
(train_cost, raster_cost, perc_relu_costs_raw, perc_relu_costs_norm,
stroke_num_cost, early_pen_states_cost,
pos_outside_cost, win_size_outside_cost,
train_step) = sess.run([
train_model.cost, train_model.raster_cost,
train_model.perc_relu_losses_raw, train_model.perc_relu_losses_norm,
train_model.stroke_num_cost,
train_model.early_pen_states_cost,
train_model.pos_outside_cost, train_model.win_size_outside_cost,
train_model.global_step
], feed)
# update the running mean of each perceptual (relu) loss
for layer_i in range(len(hps.perc_loss_layers)):
perc_relu_cost_raw = perc_relu_costs_raw[layer_i]
mean_perc_relu_loss = mean_perc_relu_losses[layer_i]
mean_perc_relu_loss = (mean_perc_relu_loss * step + perc_relu_cost_raw) / float(step + 1)
mean_perc_relu_losses[layer_i] = mean_perc_relu_loss
_ = sess.run(train_model.train_op, feed)
if step % 20 == 0 and step > 0:
end = time.time()
time_taken = end - start
train_summary_map = {
'Train_Cost': train_cost,
'Train_raster_Cost': raster_cost,
'Train_stroke_num_Cost': stroke_num_cost,
'Train_early_pen_states_cost': early_pen_states_cost,
'Train_pos_outside_Cost': pos_outside_cost,
'Train_win_size_outside_Cost': win_size_outside_cost,
'Learning_Rate': curr_learning_rate,
'Time_Taken_Train': time_taken
}
for layer_i in range(len(hps.perc_loss_layers)):
layer_name = hps.perc_loss_layers[layer_i]
train_summary_map['Train_raster_Cost_' + layer_name] = perc_relu_costs_raw[layer_i]
create_summary(summary_writer, train_summary_map, train_step)
output_format = ('step: %d, lr: %.6f, '
'snw: %.3f, '
'cost: %.4f, '
'ras: %.4f, stroke_num: %.4f, early_pen: %.4f, '
'pos_outside: %.4f, win_outside: %.4f, '
'train_time_taken: %.1f')
output_values = (step, curr_learning_rate,
curr_stroke_num_loss_weight,
train_cost,
raster_cost, stroke_num_cost, early_pen_states_cost,
pos_outside_cost, win_size_outside_cost,
time_taken)
output_log = output_format % output_values
# print(output_log)
tf.logging.info(output_log)
start = time.time()
if should_save_log_img(step) and step > 0:
save_log_images(sess, eval_sample_model, valid_set, sub_log_img_root, step, curr_photo_prob, interpolate_type)
if step % hps.save_every == 0 and step > 0:
save_model(sess, saver, sub_snapshot_root, step)
def trainer(model_params):
np.set_printoptions(precision=8, edgeitems=6, linewidth=200, suppress=True)
print('Hyperparams:')
for key, val in six.iteritems(model_params.values()):
print('%s = %s' % (key, str(val)))
print('Loading data files.')
print('-' * 100)
datasets = load_dataset_training(FLAGS.dataset_dir, model_params)
sub_snapshot_root = os.path.join(FLAGS.snapshot_root, model_params.program_name)
sub_log_root = os.path.join(FLAGS.log_root, model_params.program_name)
sub_log_img_root = os.path.join(FLAGS.log_img_root, model_params.program_name)
train_set = datasets[0]
valid_set = datasets[1]
train_model_params = datasets[2]
eval_sample_model_params = datasets[3]
eval_sample_model_params.loop_per_gpu = 1
eval_sample_model_params.batch_size = len(eval_sample_model_params.gpus) * eval_sample_model_params.loop_per_gpu
reset_graph()
train_model = sketch_image_model.VirtualSketchingModel(train_model_params)
eval_sample_model = sketch_image_model.VirtualSketchingModel(eval_sample_model_params, reuse=True)
tfconfig = tf.ConfigProto(allow_soft_placement=True)
tfconfig.gpu_options.allow_growth = True
sess = tf.InteractiveSession(config=tfconfig)
sess.run(tf.global_variables_initializer())
load_checkpoint(sess, FLAGS.neural_renderer_path, ras_only=True)
if train_model_params.raster_loss_base_type == 'perceptual':
load_checkpoint(sess, FLAGS.perceptual_model_root, perceptual_only=True)
# Write config file to json file.
os.makedirs(sub_log_root, exist_ok=True)
os.makedirs(sub_log_img_root, exist_ok=True)
os.makedirs(sub_snapshot_root, exist_ok=True)
with tf.gfile.Open(os.path.join(sub_snapshot_root, 'model_config.json'), 'w') as f:
json.dump(train_model_params.values(), f, indent=True)
train(sess, train_model, eval_sample_model, train_set, valid_set,
sub_log_root, sub_snapshot_root, sub_log_img_root)
def main(dataset_type):
if dataset_type == 'rough':
model_params = get_default_hparams_rough()
elif dataset_type == 'face':
model_params = get_default_hparams_normal()
else:
raise Exception('Unknown dataset_type:', dataset_type)
trainer(model_params)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--data', '-d', type=str, default='rough', choices=['rough', 'face'], help="The dataset type.")
args = parser.parse_args()
main(args.data)
================================================
FILE: train_vectorization.py
================================================
import json
import os
import time
import numpy as np
import six
import tensorflow as tf
from PIL import Image
import model_common_train as sketch_vector_model
from hyper_parameters import FLAGS, get_default_hparams_clean
from utils import create_summary, save_model, reset_graph, load_checkpoint
from dataset_utils import load_dataset_training
os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1'
tf.logging.set_verbosity(tf.logging.INFO)
def should_save_log_img(step_):
    return step_ % 500 == 0
def save_log_images(sess, model, data_set, save_root, step_num, save_num=10):
res_gap = (model.hps.image_size_large - model.hps.image_size_small) // (save_num - 1)
log_img_resolutions = []
for ii in range(save_num - 1):
log_img_resolutions.append(model.hps.image_size_small + ii * res_gap)
log_img_resolutions.append(model.hps.image_size_large)
for res_i in range(len(log_img_resolutions)):
resolution = log_img_resolutions[res_i]
sub_save_root = os.path.join(save_root, 'res_' + str(resolution))
os.makedirs(sub_save_root, exist_ok=True)
input_photos, target_sketches, init_cursors, image_size_rand = \
data_set.get_batch_from_memory(memory_idx=res_i, vary_thickness=model.hps.vary_thickness,
fixed_image_size=resolution,
random_cursor=model.hps.random_cursor,
init_cursor_on_undrawn_pixel=model.hps.init_cursor_on_undrawn_pixel)
# input_photos: (N, image_size, image_size), [0-stroke, 1-BG]
# target_sketches: (N, image_size, image_size), [0-stroke, 1-BG]
# init_cursors: (N, 1, 2), in size [0.0, 1.0)
if input_photos is not None:
input_photo_val = np.expand_dims(input_photos, axis=-1)
else:
input_photo_val = np.expand_dims(target_sketches, axis=-1)
init_cursor_input = [init_cursors for _ in range(model.total_loop)]
init_cursor_input = np.concatenate(init_cursor_input, axis=0)
image_size_input = [image_size_rand for _ in range(model.total_loop)]
image_size_input = np.stack(image_size_input, axis=0)
feed = {
model.init_cursor: init_cursor_input,
model.image_size: image_size_input,
model.init_width: [model.hps.min_width],
}
for loop_i in range(model.total_loop):
feed[model.input_photo_list[loop_i]] = input_photo_val
raster_images_pred, raster_images_pred_rgb = sess.run([model.pred_raster_imgs, model.pred_raster_imgs_rgb],
feed) # (N, image_size, image_size), [0.0-stroke, 1.0-BG]
raster_images_pred = (np.array(raster_images_pred[0]) * 255.0).astype(np.uint8)
input_sketch = (np.array(target_sketches[0]) * 255.0).astype(np.uint8)
raster_images_pred_rgb = (np.array(raster_images_pred_rgb[0]) * 255.0).astype(np.uint8)
pred_save_path = os.path.join(sub_save_root, str(step_num) + '.png')
target_save_path = os.path.join(sub_save_root, 'gt.png')
pred_rgb_save_root = os.path.join(sub_save_root, 'rgb')
os.makedirs(pred_rgb_save_root, exist_ok=True)
pred_rgb_save_path = os.path.join(pred_rgb_save_root, str(step_num) + '.png')
raster_images_pred = Image.fromarray(raster_images_pred, 'L')
raster_images_pred.save(pred_save_path, 'PNG')
input_sketch = Image.fromarray(input_sketch, 'L')
input_sketch.save(target_save_path, 'PNG')
raster_images_pred_rgb = Image.fromarray(raster_images_pred_rgb, 'RGB')
raster_images_pred_rgb.save(pred_rgb_save_path, 'PNG')
def train(sess, train_model, eval_sample_model, train_set, val_set, sub_log_root, sub_snapshot_root, sub_log_img_root):
# Setup summary writer.
summary_writer = tf.summary.FileWriter(sub_log_root)
print('-' * 100)
# Calculate trainable params.
t_vars = tf.trainable_variables()
count_t_vars = 0
for var in t_vars:
num_param = np.prod(var.get_shape().as_list())
count_t_vars += num_param
print('%s | shape: %s | num_param: %i' % (var.name, str(var.get_shape()), num_param))
print('Total trainable variables %i.' % count_t_vars)
print('-' * 100)
# main train loop
hps = train_model.hps
start = time.time()
# create saver
snapshot_save_vars = [var for var in tf.global_variables()
if 'raster_unit' not in var.op.name and 'VGG16' not in var.op.name]
saver = tf.train.Saver(var_list=snapshot_save_vars, max_to_keep=20)
start_step = 1
print('start_step', start_step)
mean_perc_relu_losses = [0.0 for _ in range(len(hps.perc_loss_layers))]
for _ in range(start_step, hps.num_steps + 1):
step = sess.run(train_model.global_step) # start from 0
count_step = min(step, hps.num_steps)
curr_learning_rate = ((hps.learning_rate - hps.min_learning_rate) *
(1 - count_step / hps.num_steps) ** hps.decay_power + hps.min_learning_rate)
if hps.sn_loss_type == 'decreasing':
assert hps.decrease_stop_steps <= hps.num_steps
assert hps.stroke_num_loss_weight_end <= hps.stroke_num_loss_weight
curr_sn_k = (hps.stroke_num_loss_weight - hps.stroke_num_loss_weight_end) / float(hps.decrease_stop_steps)
curr_stroke_num_loss_weight = hps.stroke_num_loss_weight - count_step * curr_sn_k
curr_stroke_num_loss_weight = max(curr_stroke_num_loss_weight, hps.stroke_num_loss_weight_end)
elif hps.sn_loss_type == 'fixed':
curr_stroke_num_loss_weight = hps.stroke_num_loss_weight
elif hps.sn_loss_type == 'increasing':
curr_sn_k = hps.stroke_num_loss_weight / float(hps.num_steps - hps.increase_start_steps)
curr_stroke_num_loss_weight = max(count_step - hps.increase_start_steps, 0) * curr_sn_k
else:
raise Exception('Unknown sn_loss_type', hps.sn_loss_type)
if hps.early_pen_loss_type == 'head':
curr_early_pen_k = (hps.max_seq_len - hps.early_pen_length) / float(hps.num_steps)
curr_early_pen_loss_len = count_step * curr_early_pen_k + hps.early_pen_length
curr_early_pen_loss_start = 1
curr_early_pen_loss_end = curr_early_pen_loss_len
elif hps.early_pen_loss_type == 'tail':
curr_early_pen_k = (hps.max_seq_len // 2 - 1) / float(hps.num_steps)
curr_early_pen_loss_len = count_step * curr_early_pen_k + hps.max_seq_len // 2
curr_early_pen_loss_end = hps.max_seq_len
curr_early_pen_loss_start = curr_early_pen_loss_end - curr_early_pen_loss_len
elif hps.early_pen_loss_type == 'move':
curr_early_pen_k = (hps.max_seq_len // 2 - 1) / float(hps.num_steps)
curr_early_pen_loss_len = count_step * curr_early_pen_k + hps.max_seq_len // 2
curr_early_pen_loss_start = hps.max_seq_len - curr_early_pen_loss_len
curr_early_pen_loss_end = curr_early_pen_loss_start + hps.max_seq_len // 2
else:
raise Exception('Unknown early_pen_loss_type', hps.early_pen_loss_type)
curr_early_pen_loss_start = int(round(curr_early_pen_loss_start))
curr_early_pen_loss_end = int(round(curr_early_pen_loss_end))
input_photos, target_sketches, init_cursors, image_sizes = \
train_set.get_batch_multi_res(loop_num=train_model.total_loop, vary_thickness=hps.vary_thickness,
random_cursor=hps.random_cursor,
init_cursor_on_undrawn_pixel=hps.init_cursor_on_undrawn_pixel)
# input_photos: list of (N, image_size, image_size), [0-stroke, 1-BG]
# target_sketches: list of (N, image_size, image_size), [0-stroke, 1-BG]
# init_cursors: list of (N, 1, 2), in size [0.0, 1.0)
init_cursors_input = np.concatenate(init_cursors, axis=0)
image_size_input = np.stack(image_sizes, axis=0)
feed = {
train_model.init_cursor: init_cursors_input,
train_model.image_size: image_size_input,
train_model.init_width: [hps.min_width],
train_model.lr: curr_learning_rate,
train_model.stroke_num_loss_weight: curr_stroke_num_loss_weight,
train_model.early_pen_loss_start_idx: curr_early_pen_loss_start,
train_model.early_pen_loss_end_idx: curr_early_pen_loss_end,
train_model.last_step_num: float(step),
}
for layer_i in range(len(hps.perc_loss_layers)):
feed[train_model.perc_loss_mean_list[layer_i]] = mean_perc_relu_losses[layer_i]
for loop_i in range(train_model.total_loop):
if input_photos is not None:
input_photo_val = np.expand_dims(input_photos[loop_i], axis=-1)
else:
input_photo_val = np.expand_dims(target_sketches[loop_i], axis=-1)
feed[train_model.input_photo_list[loop_i]] = input_photo_val
(train_cost, raster_cost, perc_relu_costs_raw, perc_relu_costs_norm,
stroke_num_cost, early_pen_states_cost,
pos_outside_cost, win_size_outside_cost,
train_step) = sess.run([
train_model.cost, train_model.raster_cost,
train_model.perc_relu_losses_raw, train_model.perc_relu_losses_norm,
train_model.stroke_num_cost,
train_model.early_pen_states_cost,
train_model.pos_outside_cost, train_model.win_size_outside_cost,
train_model.global_step
], feed)
## update mean_raster_loss
for layer_i in range(len(hps.perc_loss_layers)):
perc_relu_cost_raw = perc_relu_costs_raw[layer_i]
mean_perc_relu_loss = mean_perc_relu_losses[layer_i]
mean_perc_relu_loss = (mean_perc_relu_loss * step + perc_relu_cost_raw) / float(step + 1)
mean_perc_relu_losses[layer_i] = mean_perc_relu_loss
_ = sess.run(train_model.train_op, feed)
if step % 20 == 0 and step > 0:
end = time.time()
time_taken = end - start
train_summary_map = {
'Train_Cost': train_cost,
'Train_raster_Cost': raster_cost,
'Train_stroke_num_Cost': stroke_num_cost,
'Train_early_pen_states_cost': early_pen_states_cost,
'Train_pos_outside_Cost': pos_outside_cost,
'Train_win_size_outside_Cost': win_size_outside_cost,
'Learning_Rate': curr_learning_rate,
'Time_Taken_Train': time_taken
}
for layer_i in range(len(hps.perc_loss_layers)):
layer_name = hps.perc_loss_layers[layer_i]
train_summary_map['Train_raster_Cost_' + layer_name] = perc_relu_costs_raw[layer_i]
create_summary(summary_writer, train_summary_map, train_step)
output_format = ('step: %d, lr: %.6f, '
'snw: %.3f, '
'cost: %.4f, '
'ras: %.4f, stroke_num: %.4f, early_pen: %.4f, '
'pos_outside: %.4f, win_outside: %.4f, '
'train_time_taken: %.1f')
output_values = (step, curr_learning_rate,
curr_stroke_num_loss_weight,
train_cost,
raster_cost, stroke_num_cost, early_pen_states_cost,
pos_outside_cost, win_size_outside_cost,
time_taken)
output_log = output_format % output_values
# print(output_log)
tf.logging.info(output_log)
start = time.time()
if should_save_log_img(step) and step > 0:
save_log_images(sess, eval_sample_model, val_set, sub_log_img_root, step)
if step % hps.save_every == 0 and step > 0:
save_model(sess, saver, sub_snapshot_root, step)
def trainer(model_params):
np.set_printoptions(precision=8, edgeitems=6, linewidth=200, suppress=True)
print('Hyperparams:')
for key, val in six.iteritems(model_params.values()):
print('%s = %s' % (key, str(val)))
print('Loading data files.')
print('-' * 100)
datasets = load_dataset_training(FLAGS.dataset_dir, model_params)
sub_snapshot_root = os.path.join(FLAGS.snapshot_root, model_params.program_name)
sub_log_root = os.path.join(FLAGS.log_root, model_params.program_name)
sub_log_img_root = os.path.join(FLAGS.log_img_root, model_params.program_name)
train_set = datasets[0]
val_set = datasets[1]
train_model_params = datasets[2]
eval_sample_model_params = datasets[3]
eval_sample_model_params.loop_per_gpu = 1
eval_sample_model_params.batch_size = len(eval_sample_model_params.gpus) * eval_sample_model_params.loop_per_gpu
reset_graph()
train_model = sketch_vector_model.VirtualSketchingModel(train_model_params)
eval_sample_model = sketch_vector_model.VirtualSketchingModel(eval_sample_model_params, reuse=True)
tfconfig = tf.ConfigProto(allow_soft_placement=True)
tfconfig.gpu_options.allow_growth = True
sess = tf.InteractiveSession(config=tfconfig)
sess.run(tf.global_variables_initializer())
load_checkpoint(sess, FLAGS.neural_renderer_path, ras_only=True)
if train_model_params.raster_loss_base_type == 'perceptual':
load_checkpoint(sess, FLAGS.perceptual_model_root, perceptual_only=True)
# Write config file to json file.
os.makedirs(sub_log_root, exist_ok=True)
os.makedirs(sub_log_img_root, exist_ok=True)
os.makedirs(sub_snapshot_root, exist_ok=True)
with tf.gfile.Open(os.path.join(sub_snapshot_root, 'model_config.json'), 'w') as f:
json.dump(train_model_params.values(), f, indent=True)
train(sess, train_model, eval_sample_model, train_set, val_set,
sub_log_root, sub_snapshot_root, sub_log_img_root)
def main():
model_params = get_default_hparams_clean()
trainer(model_params)
if __name__ == '__main__':
main()
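The training loop above anneals the learning rate with a polynomial-decay schedule (the `curr_learning_rate` expression in `train()`). A minimal, dependency-free sketch of that schedule is below; the concrete values (`90000` steps, `1e-4` base rate, `1e-5` floor, power `0.9`) are illustrative placeholders, not values taken from `hyper_parameters.py`.

```python
def polynomial_lr(step, num_steps, lr, min_lr, decay_power):
    """Polynomial decay from lr down to min_lr, mirroring the schedule in train()."""
    count_step = min(step, num_steps)  # clamp, as the loop does with count_step
    return (lr - min_lr) * (1 - count_step / num_steps) ** decay_power + min_lr

# Starts at the base rate, bottoms out at min_lr, and stays there after num_steps.
start_lr = polynomial_lr(0, 90000, 1e-4, 1e-5, 0.9)
final_lr = polynomial_lr(90000, 90000, 1e-4, 1e-5, 0.9)
```

The clamp means the rate never dips below `min_lr` even if training runs past `num_steps`.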
================================================
FILE: utils.py
================================================
import os
import cv2
import json
import numpy as np
import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt
#############################################
# Tensorflow utils
#############################################
def reset_graph():
"""Closes the current default session and resets the graph."""
sess = tf.get_default_session()
if sess:
sess.close()
tf.reset_default_graph()
def load_checkpoint(sess, checkpoint_path, ras_only=False, perceptual_only=False, gen_model_pretrain=False,
train_entire=False):
if ras_only:
load_var = {var.op.name: var for var in tf.global_variables() if 'raster_unit' in var.op.name}
elif perceptual_only:
load_var = {var.op.name: var for var in tf.global_variables() if 'VGG16' in var.op.name}
elif train_entire:
load_var = {var.op.name: var for var in tf.global_variables()
if 'discriminator' not in var.op.name
and 'raster_unit' not in var.op.name
and 'VGG16' not in var.op.name
and 'beta1' not in var.op.name
and 'beta2' not in var.op.name
and 'global_step' not in var.op.name
and 'Entire' not in var.op.name
}
else:
if gen_model_pretrain:
load_var = {var.op.name: var for var in tf.global_variables()
if 'discriminator' not in var.op.name
and 'raster_unit' not in var.op.name
and 'VGG16' not in var.op.name
and 'beta1' not in var.op.name
and 'beta2' not in var.op.name
# and 'global_step' not in var.op.name
}
else:
load_var = tf.global_variables()
restorer = tf.train.Saver(load_var)
    if not ras_only:
        ckpt = tf.train.get_checkpoint_state(checkpoint_path)
        assert ckpt is not None, 'No checkpoint found in %s' % checkpoint_path
        model_checkpoint_path = ckpt.model_checkpoint_path
    else:
        model_checkpoint_path = checkpoint_path
print('Loading model %s' % model_checkpoint_path)
restorer.restore(sess, model_checkpoint_path)
snapshot_step = model_checkpoint_path[model_checkpoint_path.rfind('-') + 1:]
return snapshot_step
def create_summary(summary_writer, summ_map, step):
for summ_key in summ_map:
summ_value = summ_map[summ_key]
summ = tf.summary.Summary()
summ.value.add(tag=summ_key, simple_value=float(summ_value))
summary_writer.add_summary(summ, step)
summary_writer.flush()
def save_model(sess, saver, model_save_path, global_step):
checkpoint_path = os.path.join(model_save_path, 'p2s')
print('saving model %s.' % checkpoint_path)
print('global_step %i.' % global_step)
saver.save(sess, checkpoint_path, global_step=global_step)
#############################################
# Utils for basic image processing
#############################################
def normal(x, width):
    return int(x * (width - 1) + 0.5)
def draw(f, width=128):
x0, y0, x1, y1, x2, y2, z0, z2, w0, w2 = f
x1 = x0 + (x2 - x0) * x1
y1 = y0 + (y2 - y0) * y1
x0 = normal(x0, width * 2)
x1 = normal(x1, width * 2)
x2 = normal(x2, width * 2)
y0 = normal(y0, width * 2)
y1 = normal(y1, width * 2)
y2 = normal(y2, width * 2)
    z0 = int(1 + z0 * width // 2)
    z2 = int(1 + z2 * width // 2)
canvas = np.zeros([width * 2, width * 2]).astype('float32')
tmp = 1. / 100
for i in range(100):
t = i * tmp
        x = int((1 - t) * (1 - t) * x0 + 2 * t * (1 - t) * x1 + t * t * x2)
        y = int((1 - t) * (1 - t) * y0 + 2 * t * (1 - t) * y1 + t * t * y2)
        z = int((1 - t) * z0 + t * z2)
w = (1-t) * w0 + t * w2
cv2.circle(canvas, (y, x), z, w, -1)
return 1 - cv2.resize(canvas, dsize=(width, width))
def rgb_trans(split_num, break_values):
slice_per_split = split_num // 8
break_values_head, break_values_tail = break_values[:-1], break_values[1:]
results = []
for split_i in range(8):
break_value_head = break_values_head[split_i]
break_value_tail = break_values_tail[split_i]
slice_gap = float(break_value_tail - break_value_head) / float(slice_per_split)
for slice_i in range(slice_per_split):
slice_val = break_value_head + slice_gap * slice_i
slice_val = int(round(slice_val))
results.append(slice_val)
return results
def get_colors(color_num):
split_num = (color_num // 8 + 1) * 8
r_break_values = [0, 0, 0, 0, 128, 255, 255, 255, 128]
g_break_values = [0, 0, 128, 255, 255, 255, 128, 0, 0]
b_break_values = [128, 255, 255, 255, 128, 0, 0, 0, 0]
r_rst_list = rgb_trans(split_num, r_break_values)
g_rst_list = rgb_trans(split_num, g_break_values)
b_rst_list = rgb_trans(split_num, b_break_values)
assert len(r_rst_list) == len(g_rst_list)
assert len(b_rst_list) == len(g_rst_list)
rgb_color_list = [(r_rst_list[i], g_rst_list[i], b_rst_list[i]) for i in range(len(r_rst_list))]
return rgb_color_list
#############################################
# Utils for testing
#############################################
def save_seq_data(save_root, save_filename, strokes_data, init_cursors, image_size, round_length, init_width):
seq_save_root = os.path.join(save_root, 'seq_data')
os.makedirs(seq_save_root, exist_ok=True)
save_npz_path = os.path.join(seq_save_root, save_filename + '.npz')
np.savez(save_npz_path, strokes_data=strokes_data, init_cursors=init_cursors,
image_size=image_size, round_length=round_length, init_width=init_width)
def image_pasting_v3_testing(patch_image, cursor, image_size, window_size_f, pasting_func, sess):
"""
:param patch_image: (raster_size, raster_size), [0.0-BG, 1.0-stroke]
:param cursor: (2), in size [0.0, 1.0)
:param window_size_f: (), float32, [0.0, image_size)
:return: (image_size, image_size), [0.0-BG, 1.0-stroke]
"""
cursor_pos = cursor * float(image_size)
pasted_image = sess.run(pasting_func.pasted_image,
feed_dict={pasting_func.patch_canvas: np.expand_dims(patch_image, axis=-1),
pasting_func.cursor_pos_a: cursor_pos,
pasting_func.image_size_a: image_size,
pasting_func.window_size_a: window_size_f})
# (image_size, image_size, 1), [0.0-BG, 1.0-stroke]
pasted_image = pasted_image[:, :, 0]
return pasted_image
def draw_strokes(data, save_root, save_filename, input_img, image_size, init_cursor, infer_lengths, init_width,
cursor_type, raster_size, min_window_size,
sess,
pasting_func=None,
save_seq=False, draw_order=False):
"""
:param data: (N_strokes, 9): flag, x1, y1, x2, y2, r2, s2
:return:
"""
canvas = np.zeros((image_size, image_size), dtype=np.float32) # [0.0-BG, 1.0-stroke]
canvas_color = np.zeros((image_size, image_size, 3), dtype=np.float32)
canvas_color_with_moving = np.zeros((image_size, image_size, 3), dtype=np.float32)
frames = []
cursor_idx = 0
stroke_count = len(data)
color_rgb_set = get_colors(stroke_count) # list of (3,) in [0, 255]
color_idx = 0
for round_idx in range(len(infer_lengths)):
round_length = infer_lengths[round_idx]
cursor_pos = init_cursor[cursor_idx] # (2)
cursor_idx += 1
prev_width = init_width
prev_scaling = 1.0
prev_window_size = raster_size # (1)
for round_inner_i in range(round_length):
stroke_idx = np.sum(infer_lengths[:round_idx]).astype(np.int32) + round_inner_i
curr_window_size = prev_scaling * prev_window_size
curr_window_size = np.maximum(curr_window_size, min_window_size)
curr_window_size = np.minimum(curr_window_size, image_size)
pen_state = data[stroke_idx, 0]
stroke_params = data[stroke_idx, 1:] # (8)
x1y1, x2y2, width2, scaling2 = stroke_params[0:2], stroke_params[2:4], stroke_params[4], stroke_params[5]
x0y0 = np.zeros_like(x2y2) # (2), [-1.0, 1.0]
x0y0 = np.divide(np.add(x0y0, 1.0), 2.0) # (2), [0.0, 1.0]
x2y2 = np.divide(np.add(x2y2, 1.0), 2.0) # (2), [0.0, 1.0]
widths = np.stack([prev_width, width2], axis=0) # (2)
stroke_params_proc = np.concatenate([x0y0, x1y1, x2y2, widths], axis=-1) # (8)
next_width = stroke_params[4]
next_scaling = stroke_params[5]
next_window_size = next_scaling * curr_window_size
next_window_size = np.maximum(next_window_size, min_window_size)
next_window_size = np.minimum(next_window_size, image_size)
prev_width = next_width * curr_window_size / next_window_size
prev_scaling = next_scaling
prev_window_size = curr_window_size
f = stroke_params_proc.tolist() # (8)
f += [1.0, 1.0]
gt_stroke_img = draw(f) # (raster_size, raster_size), [0.0-stroke, 1.0-BG]
gt_stroke_img_large = image_pasting_v3_testing(1.0 - gt_stroke_img, cursor_pos, image_size,
curr_window_size,
pasting_func, sess) # [0.0-BG, 1.0-stroke]
if pen_state == 0:
canvas += gt_stroke_img_large # [0.0-BG, 1.0-stroke]
if draw_order:
color_rgb = color_rgb_set[color_idx] # (3) in [0, 255]
color_idx += 1
color_rgb = np.reshape(color_rgb, (1, 1, 3)).astype(np.float32)
color_stroke = np.expand_dims(gt_stroke_img_large, axis=-1) * (1.0 - color_rgb / 255.0)
canvas_color_with_moving = canvas_color_with_moving * np.expand_dims((1.0 - gt_stroke_img_large),
axis=-1) + color_stroke # (H, W, 3)
if pen_state == 0:
canvas_color = canvas_color * np.expand_dims((1.0 - gt_stroke_img_large),
axis=-1) + color_stroke # (H, W, 3)
# update cursor_pos based on hps.cursor_type
            new_cursor_offsets = stroke_params[2:4] * (curr_window_size / 2.0)  # (2), patch-level
new_cursor_offset_next = new_cursor_offsets
# important!!!
new_cursor_offset_next = np.concatenate([new_cursor_offset_next[1:2], new_cursor_offset_next[0:1]], axis=-1)
cursor_pos_large = cursor_pos * float(image_size)
stroke_position_next = cursor_pos_large + new_cursor_offset_next # (2), large-level
if cursor_type == 'next':
cursor_pos_large = stroke_position_next # (2), large-level
else:
raise Exception('Unknown cursor_type')
cursor_pos_large = np.minimum(np.maximum(cursor_pos_large, 0.0), float(image_size - 1)) # (2), large-level
cursor_pos = cursor_pos_large / float(image_size)
frames.append(canvas.copy())
canvas = np.clip(canvas, 0.0, 1.0)
canvas = np.round((1.0 - canvas) * 255.0).astype(np.uint8) # [0-stroke, 255-BG]
os.makedirs(save_root, exist_ok=True)
save_path = os.path.join(save_root, save_filename)
canvas_img = Image.fromarray(canvas, 'L')
canvas_img.save(save_path, 'PNG')
if save_seq:
seq_save_root = os.path.join(save_root, 'seq', save_filename[:-4])
os.makedirs(seq_save_root, exist_ok=True)
for len_i in range(len(frames)):
frame = frames[len_i]
frame = np.round((1.0 - frame) * 255.0).astype(np.uint8)
save_path = os.path.join(seq_save_root, str(len_i) + '.png')
frame_img = Image.fromarray(frame, 'L')
frame_img.save(save_path, 'PNG')
if draw_order:
order_save_root = os.path.join(save_root, 'order')
order_comp_save_root = os.path.join(save_root, 'order-compare')
os.makedirs(order_save_root, exist_ok=True)
os.makedirs(order_comp_save_root, exist_ok=True)
canvas_color = 255 - np.round(canvas_color * 255.0).astype(np.uint8)
canvas_color_img = Image.fromarray(canvas_color, 'RGB')
save_path = os.path.join(order_save_root, save_filename)
canvas_color_img.save(save_path, 'PNG')
canvas_color_with_moving = 255 - np.round(canvas_color_with_moving * 255.0).astype(np.uint8)
        # comparisons
rows = 2
cols = 3
plt.figure(figsize=(5 * cols, 5 * rows))
plt.subplot(rows, cols, 1)
plt.title('Input', fontsize=12)
# plt.axis('off')
input_rgb = input_img
plt.imshow(input_rgb)
# plt.subplot(rows, cols, 2)
# plt.title('GT', fontsize=12)
# # plt.axis('off')
# gt_rgb = np.stack([gt_img for _ in range(3)], axis=2)
# plt.imshow(gt_rgb)
plt.subplot(rows, cols, 2)
plt.title('Sketch', fontsize=12)
# plt.axis('off')
canvas_rgb = np.stack([canvas for _ in range(3)], axis=2)
plt.imshow(canvas_rgb)
plt.subplot(rows, cols, 4)
plt.title('Sketch Order', fontsize=12)
# plt.axis('off')
plt.imshow(canvas_color)
plt.subplot(rows, cols, 5)
plt.title('Sketch Order with moving', fontsize=12)
# plt.axis('off')
plt.imshow(canvas_color_with_moving)
plt.subplot(rows, cols, 6)
plt.title('Order', fontsize=12)
plt.axis('off')
img_h = 5
img_w = 10
color_array = np.zeros([len(color_rgb_set) * img_h, img_w, 3], dtype=np.uint8)
for i in range(len(color_rgb_set)):
color_array[i * img_h: i * img_h + img_h, :, :] = color_rgb_set[i]
plt.imshow(color_array)
comp_save_path = os.path.join(order_comp_save_root, save_filename)
plt.savefig(comp_save_path)
plt.close()
# plt.show()
def update_hyperparams(model_params, model_base_dir, model_name, infer_dataset):
with tf.gfile.Open(os.path.join(model_base_dir, model_name, 'model_config.json'), 'r') as f:
data = json.load(f)
ignored_keys = ['image_size_small', 'image_size_large', 'z_size', 'raster_perc_loss_layer', 'raster_loss_wk',
'decreasing_sn', 'raster_loss_weight']
for name in model_params._hparam_types.keys():
if name not in data and name not in ignored_keys:
raise Exception(name, 'not in model_config.json')
assert data['resize_method'] == 'AREA'
data['data_set'] = infer_dataset
fix_list = ['use_input_dropout', 'use_output_dropout', 'use_recurrent_dropout']
for fix in fix_list:
data[fix] = (data[fix] == 1)
pop_keys = ['gpus', 'image_size', 'resolution_type', 'loop_per_gpu', 'stroke_num_loss_weight_end',
'perc_loss_fuse_type',
'early_pen_length', 'early_pen_loss_type', 'early_pen_loss_weight',
'increase_start_steps', 'perc_loss_layers', 'sn_loss_type', 'photo_prob_end_step',
'sup_weight', 'gan_weight', 'base_raster_loss_base_type']
for pop_key in pop_keys:
if pop_key in data.keys():
data.pop(pop_key)
model_params.parse_json(json.dumps(data))
return model_params
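The `draw()` helper above rasterizes each stroke by sampling a quadratic Bézier curve at 100 parameter values, after remapping the control point `(x1, y1)` as a fraction of the chord from `(x0, y0)` to `(x2, y2)`. A small standalone sketch of that math (pure Python, no OpenCV; `bezier_point` and `control_point` are illustrative names, not functions from this repo):

```python
def control_point(p0, p2, rel):
    """draw() treats (x1, y1) as a fraction of the chord from p0 to p2."""
    return (p0[0] + (p2[0] - p0[0]) * rel[0],
            p0[1] + (p2[1] - p0[1]) * rel[1])

def bezier_point(t, p0, p1, p2):
    """One quadratic Bezier sample, as evaluated inside draw()'s loop."""
    x = (1 - t) * (1 - t) * p0[0] + 2 * t * (1 - t) * p1[0] + t * t * p2[0]
    y = (1 - t) * (1 - t) * p0[1] + 2 * t * (1 - t) * p1[1] + t * t * p2[1]
    return x, y
```

At `t = 0` the curve sits on the start point and at `t = 1` on the end point, with the control point only bending the interior; stroke width and color are interpolated linearly along the same `t`.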
================================================
FILE: vgg_utils/VGG16.py
================================================
import tensorflow as tf
def vgg_net(x, n_classes, img_size, reuse, is_train=True, dropout_rate=0.5):
# Define a scope for reusing the variables
with tf.variable_scope('VGG16', reuse=reuse):
x = tf.reshape(x, [-1, img_size, img_size, 1])
x = tf.layers.conv2d(inputs=x, filters=64, kernel_size=[3, 3], strides=1,
padding='SAME', activation=tf.nn.relu)
x = tf.layers.conv2d(inputs=x, filters=64, kernel_size=[3, 3], strides=1,
padding='SAME', activation=tf.nn.relu)
x = tf.layers.max_pooling2d(inputs=x, pool_size=[2, 2], strides=2)
print('#1', x.shape)
x = tf.layers.conv2d(inputs=x, filters=128, kernel_size=[3, 3], strides=1,
padding='SAME', activation=tf.nn.relu)
x = tf.layers.conv2d(inputs=x, filters=128, kernel_size=[3, 3], strides=1,
padding='SAME', activation=tf.nn.relu)
x = tf.layers.max_pooling2d(inputs=x, pool_size=[2, 2], strides=2)
print('#2', x.shape)
x = tf.layers.conv2d(inputs=x, filters=256, kernel_size=[3, 3], strides=1,
padding='SAME', activation=tf.nn.relu)
x = tf.layers.conv2d(inputs=x, filters=256, kernel_size=[3, 3], strides=1,
padding='SAME', activation=tf.nn.relu)
x = tf.layers.conv2d(inputs=x, filters=256, kernel_size=[3, 3], strides=1,
padding='SAME', activation=tf.nn.relu)
x = tf.layers.max_pooling2d(inputs=x, pool_size=[2, 2], strides=2)
print('#3', x.shape)
x = tf.layers.conv2d(inputs=x, filters=512, kernel_size=[3, 3], strides=1,
padding='SAME', activation=tf.nn.relu)
x = tf.layers.conv2d(inputs=x, filters=512, kernel_size=[3, 3], strides=1,
padding='SAME', activation=tf.nn.relu)
x = tf.layers.conv2d(inputs=x, filters=512, kernel_size=[3, 3], strides=1,
padding='SAME', activation=tf.nn.relu)
x = tf.layers.max_pooling2d(inputs=x, pool_size=[2, 2], strides=2)
print('#4', x.shape)
x = tf.layers.conv2d(inputs=x, filters=512, kernel_size=[3, 3], strides=1,
padding='SAME', activation=tf.nn.relu)
x = tf.layers.conv2d(inputs=x, filters=512, kernel_size=[3, 3], strides=1,
padding='SAME', activation=tf.nn.relu)
x = tf.layers.conv2d(inputs=x, filters=512, kernel_size=[3, 3], strides=1,
padding='SAME', activation=tf.nn.relu)
x = tf.layers.max_pooling2d(inputs=x, pool_size=[2, 2], strides=2)
print('#5', x.shape)
x_shape = x.get_shape().as_list()
nodes = x_shape[1] * x_shape[2] * x_shape[3]
x = tf.reshape(x, [-1, nodes])
x = tf.layers.dense(x, 4096, activation=tf.nn.relu)
if is_train:
x = tf.layers.dropout(x, dropout_rate)
x = tf.layers.dense(x, 4096, activation=tf.nn.relu)
if is_train:
x = tf.layers.dropout(x, dropout_rate)
out = tf.layers.dense(x, n_classes)
print(out)
return out
def vgg_net_slim(x, img_size):
return_map = {}
# Define a scope for reusing the variables
with tf.variable_scope('VGG16', reuse=tf.AUTO_REUSE):
x = tf.reshape(x, [-1, img_size, img_size, 1])
x = tf.layers.conv2d(inputs=x, filters=64, kernel_size=[3, 3], strides=1,
padding='SAME', activation=tf.nn.relu)
return_map['ReLU1_1'] = x
x = tf.layers.conv2d(inputs=x, filters=64, kernel_size=[3, 3], strides=1,
padding='SAME', activation=tf.nn.relu)
return_map['ReLU1_2'] = x
x = tf.layers.max_pooling2d(inputs=x, pool_size=[2, 2], strides=2)
print('#1', x.shape) #1 (?, 64, 64, 64)
x = tf.layers.conv2d(inputs=x, filters=128, kernel_size=[3, 3], strides=1,
padding='SAME', activation=tf.nn.relu)
return_map['ReLU2_1'] = x
x = tf.layers.conv2d(inputs=x, filters=128, kernel_size=[3, 3], strides=1,
padding='SAME', activation=tf.nn.relu)
return_map['ReLU2_2'] = x
x = tf.layers.max_pooling2d(inputs=x, pool_size=[2, 2], strides=2)
print('#2', x.shape) #2 (?, 32, 32, 128)
x = tf.layers.conv2d(inputs=x, filters=256, kernel_size=[3, 3], strides=1,
padding='SAME', activation=tf.nn.relu)
return_map['ReLU3_1'] = x
x = tf.layers.conv2d(inputs=x, filters=256, kernel_size=[3, 3], strides=1,
padding='SAME', activation=tf.nn.relu)
return_map['ReLU3_2'] = x
x = tf.layers.conv2d(inputs=x, filters=256, kernel_size=[3, 3], strides=1,
padding='SAME', activation=tf.nn.relu)
return_map['ReLU3_3'] = x
x = tf.layers.max_pooling2d(inputs=x, pool_size=[2, 2], strides=2)
print('#3', x.shape) #3 (?, 16, 16, 256)
x = tf.layers.conv2d(inputs=x, filters=512, kernel_size=[3, 3], strides=1,
padding='SAME', activation=tf.nn.relu)
return_map['ReLU4_1'] = x
x = tf.layers.conv2d(inputs=x, filters=512, kernel_size=[3, 3], strides=1,
padding='SAME', activation=tf.nn.relu)
return_map['ReLU4_2'] = x
x = tf.layers.conv2d(inputs=x, filters=512, kernel_size=[3, 3], strides=1,
padding='SAME', activation=tf.nn.relu)
return_map['ReLU4_3'] = x
x = tf.layers.max_pooling2d(inputs=x, pool_size=[2, 2], strides=2)
print('#4', x.shape) #4 (?, 8, 8, 512)
x = tf.layers.conv2d(inputs=x, filters=512, kernel_size=[3, 3], strides=1,
padding='SAME', activation=tf.nn.relu)
return_map['ReLU5_1'] = x
x = tf.layers.conv2d(inputs=x, filters=512, kernel_size=[3, 3], strides=1,
padding='SAME', activation=tf.nn.relu)
return_map['ReLU5_2'] = x
x = tf.layers.conv2d(inputs=x, filters=512, kernel_size=[3, 3], strides=1,
padding='SAME', activation=tf.nn.relu)
return_map['ReLU5_3'] = x
x = tf.layers.max_pooling2d(inputs=x, pool_size=[2, 2], strides=2)
print('#5', x.shape) #5 (?, 4, 4, 512)
return return_map
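The shape comments printed in `vgg_net_slim` (`#1 (?, 64, 64, 64)` through `#5 (?, 4, 4, 512)`) correspond to a 128x128 input: each 2x2/stride-2 max-pool halves the spatial side. A tiny sketch of that bookkeeping (`vgg_spatial_sizes` is an illustrative helper, not part of this repo):

```python
def vgg_spatial_sizes(img_size, num_pools=5):
    """Spatial side length after each 2x2/stride-2 max-pool stage."""
    sizes = []
    side = img_size
    for _ in range(num_pools):
        side //= 2  # SAME-padded 3x3 convs keep the size; only pooling shrinks it
        sizes.append(side)
    return sizes
```

This is why the perceptual-loss layers (`ReLU1_1` ... `ReLU5_3`) compare features at progressively coarser resolutions.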
================================================
FILE: virtual_sketch_gui.py
================================================
import tkinter as tk
from tkinter import filedialog, messagebox
import subprocess
import os
import threading
import glob
import sys
import shutil
# ==== Script path settings ====
MODEL_OPTIONS = {
"Rough → Clean Sketch": "test_rough_sketch_simplification.py",
"Photo → Line Drawing": "test_photograph_to_line.py",
"Clean Sketch → Vector": "test_vectorization.py",
}
SVG_CONVERTER = os.path.join("tools", "svg_conversion.py")
class VirtualSketchApp:
def __init__(self, root):
self.root = root
self.root.title("Virtual Sketching GUI")
self.input_file = None
self.model_script = tk.StringVar(value=list(MODEL_OPTIONS.values())[0])
self.build_ui()
def build_ui(self):
tk.Label(self.root, text="1. Vyber vstupní obrázek:").pack(anchor="w")
tk.Button(self.root, text="Vybrat obrázek", command=self.choose_file).pack(fill="x")
self.file_label = tk.Label(self.root, text="Žádný soubor nevybrán", fg="gray")
self.file_label.pack(anchor="w")
tk.Label(self.root, text="2. Zvol model:").pack(anchor="w")
for name, script in MODEL_OPTIONS.items():
tk.Radiobutton(self.root, text=name, variable=self.model_script, value=script).pack(anchor="w")
tk.Button(self.root, text="3. Spustit zpracování", command=self.run_processing).pack(pady=10, fill="x")
self.status = tk.Label(self.root, text="Připraven", fg="green")
self.status.pack(anchor="w")
def choose_file(self):
path = filedialog.askopenfilename(filetypes=[
("Obrázkové soubory", "*.png *.jpg *.jpeg *.bmp *.gif *.tif *.tiff"),
("PNG", "*.png"),
("JPEG", "*.jpg;*.jpeg"),
("BMP", "*.bmp"),
("GIF", "*.gif"),
("TIFF", "*.tif;*.tiff")
])
if path:
self.input_file = path
self.file_label.config(text=os.path.basename(path), fg="black")
    def run_processing(self):
        if not self.input_file:
            messagebox.showerror("Error", "Select an image first.")
            return
        script = self.model_script.get()
        cmd = [sys.executable, script, "--input", self.input_file]
        def task():
            # NOTE: Tk widgets are touched from this worker thread; for strict
            # thread safety these updates should be scheduled on the main loop
            # via self.root.after().
            self.status.config(text="Processing...", fg="blue")
            try:
                subprocess.run(cmd, check=True)
                self.status.config(text="✅ Done", fg="green")
                self.move_outputs_to_sketches()
                self.run_svg_conversion()
            except subprocess.CalledProcessError:
                self.status.config(text="❌ Script failed", fg="red")
        # daemon thread so a running job cannot block app exit
        threading.Thread(target=task, daemon=True).start()
def move_outputs_to_sketches(self):
if not self.input_file:
return
input_dir = os.path.dirname(self.input_file)
input_base = os.path.splitext(os.path.basename(self.input_file))[0]
sketches_dir = os.path.join(input_dir, "sketches")
os.makedirs(sketches_dir, exist_ok=True)
for ext in ["_0.npz", "_0_pred.png", "_input.png", "_0.svg"]:
candidate = os.path.join(input_dir, f"{input_base}{ext}")
if os.path.isfile(candidate):
shutil.move(candidate, os.path.join(sketches_dir, os.path.basename(candidate)))
    def run_svg_conversion(self):
        if not self.input_file:
            return
        input_dir = os.path.dirname(self.input_file)
        input_base = os.path.splitext(os.path.basename(self.input_file))[0]
        sketches_dir = os.path.join(input_dir, "sketches")
        npz_file_in_sketches = os.path.join(sketches_dir, f"{input_base}_0.npz")
        if not os.path.isfile(npz_file_in_sketches):
            print("⚠️ No .npz file found for SVG conversion.")
            return
cmd = [sys.executable, SVG_CONVERTER, "--file", npz_file_in_sketches, "--svg_type", "single"]
try:
subprocess.run(cmd, check=True)
svg_path = os.path.join(sketches_dir, f"{input_base}_0.svg")
if os.path.isfile(svg_path):
print(f"✅ SVG vytvořeno: {svg_path}")
        except subprocess.CalledProcessError:
            print("⚠️ SVG conversion failed")
if __name__ == "__main__":
root = tk.Tk()
app = VirtualSketchApp(root)
root.mainloop()
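`move_outputs_to_sketches` above collects the files the test scripts drop next to the input (`_0.npz`, `_0_pred.png`, `_input.png`, `_0.svg`) into a `sketches/` subfolder. A self-contained `pathlib` sketch of the same logic, easy to exercise without the GUI (`move_outputs` is an illustrative standalone name, not the method itself):

```python
import shutil
from pathlib import Path

# Suffixes the test scripts append to the input's base name.
OUTPUT_SUFFIXES = ("_0.npz", "_0_pred.png", "_input.png", "_0.svg")

def move_outputs(input_file, suffixes=OUTPUT_SUFFIXES):
    """Move generated outputs next to input_file into a 'sketches' subfolder."""
    input_path = Path(input_file)
    sketches_dir = input_path.parent / "sketches"
    sketches_dir.mkdir(exist_ok=True)
    moved = []
    for suffix in suffixes:
        candidate = input_path.parent / (input_path.stem + suffix)
        if candidate.is_file():
            shutil.move(str(candidate), str(sketches_dir / candidate.name))
            moved.append(candidate.name)
    return moved
```

Missing outputs are simply skipped, so the helper is safe to call regardless of which test script ran.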