Repository: MarkMoHR/virtual_sketching
Branch: main
Commit: 229fefae08e4
Files: 28
Total size: 431.9 KB

Directory structure:
gitextract_xyepypzb/

├── .gitignore
├── LICENSE
├── README.md
├── README_CN.md
├── WINDOWS_INSTALL_GUIDE.md
├── dataset_utils.py
├── docs/
│   ├── assets/
│   │   ├── font.css
│   │   └── style.css
│   └── index.html
├── hyper_parameters.py
├── launch_gui.bat
├── model_common_test.py
├── model_common_train.py
├── rasterization_utils/
│   ├── NeuralRenderer.py
│   └── RealRenderer.py
├── rnn.py
├── subnet_tf_utils.py
├── test_photograph_to_line.py
├── test_rough_sketch_simplification.py
├── test_vectorization.py
├── tools/
│   ├── gif_making.py
│   ├── svg_conversion.py
│   └── visualize_drawing.py
├── train_rough_photograph.py
├── train_vectorization.py
├── utils.py
├── vgg_utils/
│   └── VGG16.py
└── virtual_sketch_gui.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
.idea
.idea/
data/
datas/
dataset/
datasets/
model/
models/
testData/
output/
outputs/

*.csv

# temporary files
*.txt~
*.pyc
.DS_Store
.gitignore~

*.h5

================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
# General Virtual Sketching Framework for Vector Line Art - SIGGRAPH 2021

[[Paper]](https://esslab.jp/publications/HaoranSIGRAPH2021.pdf) | [[Project Page]](https://markmohr.github.io/virtual_sketching/) | [[Readme (Chinese)]](/README_CN.md) | [[Paper Introduction (Chinese)]](https://blog.csdn.net/qq_33000225/article/details/118883153)

This code is used for **line drawing vectorization**, **rough sketch simplification** and **photograph to vector line drawing**.

<img src='docs/figures/muten.png' height=300><img src='docs/figures/muten-black-full-simplest.gif' height=300>

<img src='docs/figures/rocket.png' height=150><img src='docs/figures/rocket-blue-simplest.gif' height=150>&nbsp;&nbsp;&nbsp;&nbsp;<img src='docs/figures/1390.png' height=150><img src='docs/figures/face-blue-1390-simplest.gif' height=150>

## Outline
- [Dependencies](#dependencies)
- [Testing with Trained Weights](#testing-with-trained-weights)
- [Training](#training)
- [Citation](#citation)
- [Projects Using this Model/Method](#projects-using-this-modelmethod)
- [Blogs Mentioning this Paper](#blogs-mentioning-this-paper)
- [For Windows users](#-windows-users)

## Dependencies
 - [TensorFlow](https://www.tensorflow.org/) (1.12.0 <= version <= 1.15.0)
 - [opencv](https://opencv.org/) == 3.4.2
 - [pillow](https://pillow.readthedocs.io/en/latest/index.html) == 6.2.0
 - [scipy](https://www.scipy.org/) == 1.5.2
 - [gizeh](https://github.com/Zulko/gizeh) == 0.1.11

## Testing with Trained Weights
### Model Preparation

Download the models [here](https://drive.google.com/drive/folders/1-hi2cl8joZ6oMOp4yvk_hObJGAK6ELHB?usp=sharing): 
  - `pretrain_clean_line_drawings` (105 MB): for vectorization
  - `pretrain_rough_sketches` (105 MB): for rough sketch simplification
  - `pretrain_faces` (105 MB): for photograph to line drawing

Then, place them in this file structure:
```
outputs/
    snapshot/
        pretrain_clean_line_drawings/
        pretrain_rough_sketches/
        pretrain_faces/
```
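Before running the test scripts, a small helper like the one below can confirm that the downloaded models ended up in the layout shown above. It is a minimal sketch that only checks that the three folders exist under `outputs/snapshot/`; it does not validate their contents. The helper name `missing_snapshots` is hypothetical and not part of the repository.

```python
from pathlib import Path

# Folder names taken from the layout above; trim the list if you only need some tasks.
EXPECTED_SNAPSHOTS = [
    "pretrain_clean_line_drawings",
    "pretrain_rough_sketches",
    "pretrain_faces",
]


def missing_snapshots(repo_root=".", names=EXPECTED_SNAPSHOTS):
    """Return the expected model folders that are absent from outputs/snapshot/."""
    snapshot_dir = Path(repo_root) / "outputs" / "snapshot"
    return [name for name in names if not (snapshot_dir / name).is_dir()]
```

Run `print(missing_snapshots())` from the repository root; an empty list means all three folders are in place.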

### Usage
Choose an image from the `sample_inputs/` directory and run the command for the corresponding task. The results are written to `outputs/sampling/`.

``` python
python3 test_vectorization.py --input muten.png

python3 test_rough_sketch_simplification.py --input rocket.png

python3 test_photograph_to_line.py --input 1390.png
```

**Note!!!** Our approach starts drawing from a randomly selected initial position, so it produces a different result on every run (some may be fine, others not good enough). We recommend running several trials and selecting the visually best result. The number of outputs is set by the `--sample` argument:

``` python
python3 test_vectorization.py --input muten.png --sample 10

python3 test_rough_sketch_simplification.py --input rocket.png --sample 10

python3 test_photograph_to_line.py --input 1390.png --sample 10
```

**Reproducing Paper Figures:** our results (download from [here](https://drive.google.com/drive/folders/1-hi2cl8joZ6oMOp4yvk_hObJGAK6ELHB?usp=sharing)) were selected from a number of trials. Reproducing them exactly requires the same initial drawing positions.
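The random start appears to come from NumPy: in `dataset_utils.py`, `gen_init_cursor_single` draws candidate centers with `np.random.randint` until the crop around the center contains stroke pixels. The sketch below mirrors that logic in a standalone function to show that seeding the generator makes the start position repeatable. The name `sample_start` is hypothetical, the crop is simplified, and whether the test scripts expose a seed option is not stated here; this also assumes no other source of randomness in the drawing loop.

```python
import numpy as np


def sample_start(sketch, crop_size, rng):
    """Pick a random start position whose surrounding crop contains strokes.

    sketch: 2-D float array, 0.0 = stroke, 1.0 = background (the convention
    used in dataset_utils.py). rng: a np.random.RandomState.
    Returns (x, y) normalized to [0, 1).
    """
    size = sketch.shape[0]
    ink = 1.0 - sketch
    if ink.sum() == 0:
        return np.zeros(2, dtype=np.float32)  # blank image: start at the corner
    half = crop_size // 2
    while True:
        center = rng.randint(0, size, size=2)  # (x, y) in pixels
        x0, y0 = max(0, center[0] - half), max(0, center[1] - half)
        x1 = min(size, center[0] - half + crop_size)
        y1 = min(size, center[1] - half + crop_size)
        if ink[y0:y1, x0:x1].sum() != 0:  # crop touches at least one stroke pixel
            return center.astype(np.float32) / float(size)
```

Calling it twice with `np.random.RandomState(7)` on the same image returns the same start position; a different seed generally gives a different one.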

### Additional Tools

#### a) Visualization

Our vector output is stored in an `npz` package. Run the following command to obtain the rendered output and the drawing order. The results are saved in the same directory as the `npz` file.
``` python
python3 tools/visualize_drawing.py --file path/to/the/result.npz 
```
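The array names inside the result file are not listed in this README, so rather than guess them, the sketch below simply lists every array in an `npz` file with its shape and dtype. `describe_npz` is a hypothetical helper, not part of `tools/`.

```python
import numpy as np


def describe_npz(path):
    """Return {array_name: (shape, dtype)} for every array stored in an .npz file."""
    # allow_pickle=True in case a result file stores object arrays;
    # only load files you trust.
    with np.load(path, allow_pickle=True) as data:
        return {name: (data[name].shape, str(data[name].dtype)) for name in data.files}
```

For example, `print(describe_npz("path/to/the/result.npz"))` shows what the visualization tools have to work with.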

#### b) GIF Making

To see the dynamic drawing procedure, run the following command to obtain the `gif`. The result is saved in the same directory as the `npz` file.
``` python
python3 tools/gif_making.py --file path/to/the/result.npz 
```


#### c) Conversion to SVG

Our vector output in the `npz` package follows the format of Eq. (1) in the main paper. Run the following command to convert it to the `svg` format. The result is saved in the same directory as the `npz` file.

``` python
python3 tools/svg_conversion.py --file path/to/the/result.npz 
```
  - The conversion is implemented in two modes (by setting the `--svg_type` argument):
    - `single` (default): each stroke (a single segment) forms a path in the SVG file
    - `cluster`: each continuous curve (with multiple strokes) forms a path in the SVG file
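The two modes differ only in how strokes are grouped into `<path>` elements. The sketch below illustrates that grouping under two assumptions that are not taken from `tools/svg_conversion.py`: each stroke is a quadratic Bézier segment given as (start, control, end) points, and a new continuous curve begins wherever a stroke's start point does not coincide with the previous stroke's end point. The function name `strokes_to_paths` is hypothetical.

```python
def strokes_to_paths(strokes, mode="single", tol=1e-6):
    """Build SVG path `d` strings from quadratic Bezier strokes.

    strokes: list of ((x0, y0), (cx, cy), (x1, y1)) tuples (assumed layout).
    mode="single": one path per stroke.
    mode="cluster": consecutive strokes joined end-to-start share one path.
    """
    def connected(prev, cur):
        (px, py), (sx, sy) = prev[2], cur[0]
        return abs(px - sx) <= tol and abs(py - sy) <= tol

    groups = []
    for stroke in strokes:
        if mode == "cluster" and groups and connected(groups[-1][-1], stroke):
            groups[-1].append(stroke)  # continue the current curve
        else:
            groups.append([stroke])  # start a new path

    paths = []
    for group in groups:
        x0, y0 = group[0][0]
        d = "M {:g} {:g}".format(x0, y0)
        for _, (cx, cy), (x1, y1) in group:
            d += " Q {:g} {:g} {:g} {:g}".format(cx, cy, x1, y1)
        paths.append(d)
    return paths
```

With two connected strokes followed by a separate one, `single` yields three paths and `cluster` yields two; in `cluster` mode every stroke in a merged path must then share one stroke width.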

**Important Notes**

In the SVG format, all segments on a path share the same *stroke-width*. In our stroke design, however, strokes on the same curve can have different widths, and within a single stroke (one segment) the width changes linearly from one endpoint to the other.
Therefore, neither of the two conversion methods above produces results that look exactly like those in our paper.
*(Please mention this issue in your paper if you do qualitative comparisons with our results in SVG format.)*
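To see why the mismatch is unavoidable, consider a stroke whose width varies linearly from `w0` at one endpoint to `w1` at the other, w(t) = w0 + t (w1 - w0) for t in [0, 1]. An SVG path carries only one `stroke-width`, so a converter has to collapse this profile to a single number. The sketch below uses the mean width (the average of the linear profile over t); that choice and the helper names are assumptions for illustration, not necessarily what `tools/svg_conversion.py` does.

```python
def width_at(t, w0, w1):
    """Width at parameter t in [0, 1] for a linearly varying stroke."""
    return w0 + t * (w1 - w0)


def svg_path_element(d, w0, w1, color="black"):
    """Emit one <path> with a single stroke-width: the mean of the linear profile."""
    mean_width = (w0 + w1) / 2.0  # integral of width_at over [0, 1]
    return ('<path d="{}" fill="none" stroke="{}" stroke-width="{:g}" '
            'stroke-linecap="round"/>').format(d, color, mean_width)


def max_width_error(w0, w1):
    """Largest gap between the true width and the single SVG width (at the endpoints)."""
    return abs(w1 - w0) / 2.0
```

For a stroke that tapers from 1 px to 3 px, the SVG width is 2 px and both endpoints are off by 1 px.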


<br>

## Training

### Preparations

Download the models [here](https://drive.google.com/drive/folders/1-hi2cl8joZ6oMOp4yvk_hObJGAK6ELHB?usp=sharing): 
  - `pretrain_neural_renderer` (40 MB): the pre-trained neural renderer
  - `pretrain_perceptual_model` (691 MB): the pre-trained perceptual model for raster loss

Download the datasets [here](https://drive.google.com/drive/folders/1-hi2cl8joZ6oMOp4yvk_hObJGAK6ELHB?usp=sharing): 
  - `QuickDraw-clean` (14 MB): for clean line drawing vectorization. Taken from [QuickDraw](https://github.com/googlecreativelab/quickdraw-dataset) dataset.
  - `QuickDraw-rough` (361 MB): for rough sketch simplification. Synthesized by the pencil drawing generation method from [Sketch Simplification](https://github.com/bobbens/sketch_simplification#pencil-drawing-generation).
  - `CelebAMask-faces` (370 MB): for photograph to line drawing. Processed from the [CelebAMask-HQ](https://github.com/switchablenorms/CelebAMask-HQ) dataset.

Then, place them in this file structure:
```
datasets/
    QuickDraw-clean/
    QuickDraw-rough/
    CelebAMask-faces/
outputs/
    snapshot/
        pretrain_neural_renderer/
        pretrain_perceptual_model/
```

### Running

Multi-GPU training is recommended. We train each task on 2 GPUs with 11 GB of memory each.

``` python
python3 train_vectorization.py

python3 train_rough_photograph.py --data rough

python3 train_rough_photograph.py --data face
```

<br>

## Citation

If you use the code and models, please cite:

```
@article{mo2021virtualsketching,
  title   = {General Virtual Sketching Framework for Vector Line Art},
  author  = {Mo, Haoran and Simo-Serra, Edgar and Gao, Chengying and Zou, Changqing and Wang, Ruomei},
  journal = {ACM Transactions on Graphics (Proceedings of ACM SIGGRAPH 2021)},
  year    = {2021},
  volume  = {40},
  number  = {4},
  pages   = {51:1--51:14}
}
```

<br>

## Projects Using this Model/Method

| **[Painterly style transfer](https://github.com/xch-liu/Painterly-Style-Transfer) (TVCG 2023)**  | **[Robot calligraphy](https://github.com/LoYuXr/CalliRewrite) (ICRA 2024)** | 
|:-------------:|:-------------------:|
| <img src="docs/figures/applications/Painterly-Style-Transfer.png" style="height: 170px"> | <img src="docs/figures/applications/robot-calligraphy.png" style="height: 170px"> |
| **[Geometrized cartoon line inbetweening](https://github.com/lisiyao21/AnimeInbet) (ICCV 2023)**  | **[Stroke correspondence and inbetweening](https://github.com/MarkMoHR/JoSTC) (TOG 2024)** | 
| <img src="docs/figures/applications/Geometrized-Cartoon-Line-Inbetweening.png" style="height: 170px"> | <img src="docs/figures/applications/Vector-Line-Inbetweening2.png" style="height: 170px"><img src="docs/figures/applications/Vector-Line-Inbetweening-dynamic1.gif" style="height: 170px"> |
| **[Modelling complex vector drawings](https://github.com/Co-do/Stroke-Cloud) (ICLR 2024)**  | **[Sketch-to-Image Generation](https://github.com/BlockDetail/Block-and-Detail) (UIST 2024)** | 
| <img src="docs/figures/applications/complex-vector-drawings.png" style="height: 170px"> | <img src="docs/figures/applications/sketch-to-image.png" style="height: 170px"> |

## Blogs Mentioning this Paper

- [The state of AI for hand-drawn animation inbetweening](https://yosefk.com/blog/the-state-of-ai-for-hand-drawn-animation-inbetweening.html)

## 🪟 Windows users

See [WINDOWS_INSTALL_GUIDE.md](WINDOWS_INSTALL_GUIDE.md) for a complete Windows installation guide and GUI usage.



================================================
FILE: README_CN.md
================================================
# General Virtual Sketching Framework for Vector Line Art - SIGGRAPH 2021

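[[Paper]](https://esslab.jp/publications/HaoranSIGRAPH2021.pdf) | [[Project Page]](https://markmohr.github.io/virtual_sketching/)

This code can be used for **line drawing vectorization**, **rough sketch simplification** and **photograph to vector line drawing**.

<img src='docs/figures/muten.png' height=300><img src='docs/figures/muten-black-full-simplest.gif' height=300>

<img src='docs/figures/rocket.png' height=150><img src='docs/figures/rocket-blue-simplest.gif' height=150>&nbsp;&nbsp;&nbsp;&nbsp;<img src='docs/figures/1390.png' height=150><img src='docs/figures/face-blue-1390-simplest.gif' height=150>

## Outline
- [Dependencies](#dependencies)
- [Testing with Pre-trained Models](#testing-with-pre-trained-models)
- [Training](#training)
- [Citation](#citation)

## Dependencies
 - [TensorFlow](https://www.tensorflow.org/) (1.12.0 <= version <= 1.15.0)
 - [opencv](https://opencv.org/) == 3.4.2
 - [pillow](https://pillow.readthedocs.io/en/latest/index.html) == 6.2.0
 - [scipy](https://www.scipy.org/) == 1.5.2
 - [gizeh](https://github.com/Zulko/gizeh) == 0.1.11

## Testing with Pre-trained Models
### Model Download and Preparation

Download the models [here](https://drive.google.com/drive/folders/1-hi2cl8joZ6oMOp4yvk_hObJGAK6ELHB?usp=sharing):
  - `pretrain_clean_line_drawings` (105 MB): for line drawing vectorization
  - `pretrain_rough_sketches` (105 MB): for rough sketch simplification
  - `pretrain_faces` (105 MB): for photograph to vector line drawing

Then, place the models in the following structure:
```
outputs/
    snapshot/
        pretrain_clean_line_drawings/
        pretrain_rough_sketches/
        pretrain_faces/
```

### Testing
Choose an image from the `sample_inputs/` folder, then run one of the commands below according to the task. The generated results can be found in the `outputs/sampling/` directory.

``` python
python3 test_vectorization.py --input muten.png

python3 test_rough_sketch_simplification.py --input rocket.png

python3 test_photograph_to_line.py --input 1390.png
```

**Note!!!** Our method starts drawing from a randomly selected initial position, so each test run will in principle produce a different result (it may be good, or it may not be very good). We therefore recommend running the test several times and picking the result that looks best. You can also set the `--sample` argument to define how many (different) results a single test run outputs:

``` python
python3 test_vectorization.py --input muten.png --sample 10

python3 test_rough_sketch_simplification.py --input rocket.png --sample 10

python3 test_photograph_to_line.py --input 1390.png --sample 10
```

**How to reproduce the results shown in the paper?** The results shown in the paper can be downloaded from [here](https://drive.google.com/drive/folders/1-hi2cl8joZ6oMOp4yvk_hObJGAK6ELHB?usp=sharing). We selected them as the best among the different outputs of several test runs. Reproducing them exactly requires starting the drawing from the same initial positions.

### Other Tools

#### a) Visualization

Our vector outputs are all stored in `npz` packages. Run the following command to obtain the rendered result and the drawing order. These visualization results are saved in the same directory as the `npz` package.
``` python
python3 tools/visualize_drawing.py --file path/to/the/result.npz 
```

#### b) GIF Making

To see the dynamic drawing process, run the following command to obtain the `gif`. The result is saved in the same directory as the `npz` package.
``` python
python3 tools/gif_making.py --file path/to/the/result.npz 
```


#### c) Conversion to SVG

The vector results in the `npz` package are all stored in the format of Eq. (1) in the paper. Run the following command to convert them to the `svg` format. The result is saved in the same directory as the `npz` package.

``` python
python3 tools/svg_conversion.py --file path/to/the/result.npz 
```
  - The conversion is implemented in two modes (set by the `--svg_type` argument):
    - `single` (default): each stroke (a single curve) forms one path in the SVG file
    - `cluster`: each continuous curve (multiple strokes) forms one path in the SVG file

**Important Notes**

In the SVG format, all segments on a path share a single line width (*stroke-width*). In our paper, however, the strokes on a continuous curve can have different widths. Moreover, within a single stroke (a Bézier curve), the width increases or decreases linearly from one endpoint to the other.

Therefore, neither of the two conversion methods above can guarantee SVG results that look exactly the same as those in the paper. (*If you use the converted SVG results for visual comparisons in your paper, please mention this issue.*)


<br>

## Training

### Training Preparation

Download the models [here](https://drive.google.com/drive/folders/1-hi2cl8joZ6oMOp4yvk_hObJGAK6ELHB?usp=sharing):
  - `pretrain_neural_renderer` (40 MB): the pre-trained neural renderer
  - `pretrain_perceptual_model` (691 MB): the pre-trained perceptual model, used to compute the raster loss

Download the training datasets [here](https://drive.google.com/drive/folders/1-hi2cl8joZ6oMOp4yvk_hObJGAK6ELHB?usp=sharing):
  - `QuickDraw-clean` (14 MB): for line drawing vectorization. Taken from the [QuickDraw](https://github.com/googlecreativelab/quickdraw-dataset) dataset.
  - `QuickDraw-rough` (361 MB): for rough sketch simplification. Synthesized with the pencil drawing generation method from [Sketch Simplification](https://github.com/bobbens/sketch_simplification#pencil-drawing-generation).
  - `CelebAMask-faces` (370 MB): for photograph to vector line drawing. Obtained by processing the [CelebAMask-HQ](https://github.com/switchablenorms/CelebAMask-HQ) dataset.

Then, place the datasets in the following structure:
```
datasets/
    QuickDraw-clean/
    QuickDraw-rough/
    CelebAMask-faces/
outputs/
    snapshot/
        pretrain_neural_renderer/
        pretrain_perceptual_model/
```

### Training Commands

Multi-GPU training is recommended. For each task, we train on 2 GPUs (11 GB each).

``` python
python3 train_vectorization.py

python3 train_rough_photograph.py --data rough

python3 train_rough_photograph.py --data face
```

<br>

## Citation

If you use this code and the models, please cite our work. Thank you!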

```
@article{mo2021virtualsketching,
  title   = {General Virtual Sketching Framework for Vector Line Art},
  author  = {Mo, Haoran and Simo-Serra, Edgar and Gao, Chengying and Zou, Changqing and Wang, Ruomei},
  journal = {ACM Transactions on Graphics (Proceedings of ACM SIGGRAPH 2021)},
  year    = {2021},
  volume  = {40},
  number  = {4},
  pages   = {51:1--51:14}
}
```



================================================
FILE: WINDOWS_INSTALL_GUIDE.md
================================================
# 🪟 Windows Installation Guide for Virtual Sketching

This guide provides step-by-step instructions to set up and run the [Virtual Sketching](https://github.com/MarkMoHR/virtual_sketching) project on Windows using Anaconda and Python 3.6.

## ✅ Requirements

- Windows 10 or newer
- Anaconda installed
- Git installed (optional, but recommended)

---

## 📦 Step 1: Create and Activate Conda Environment

```bash
conda create -n virtual_sketching python=3.6 -y
conda activate virtual_sketching
```

## 📂 Step 2: Clone the Repository

```bash
cd D:\
git clone https://github.com/MarkMoHR/virtual_sketching.git
cd virtual_sketching
```

(If you plan to contribute, consider forking the repo and using your own URL.)

---

## 🔧 Step 3: Install Required Packages

### From `conda`:
```bash
conda install opencv=3.4.2 pillow=6.2.0 scipy=1.5.2 -y
conda install -c conda-forge pycairo gtk3 cffi -y
```

### Then remove default TensorFlow (if installed via conda):
```bash
conda remove tensorflow
```

### Install required packages via `pip`:
```bash
pip install tensorflow==1.15.0
pip install numpy gizeh cairocffi matplotlib svgwrite
```

> ⚠️ Do not upgrade `pillow`, `scipy`, or `tensorflow` — newer versions are incompatible.
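Because the pins matter, a short check like the one below can confirm what actually got installed in the active environment. It is a sketch that compares each module's `__version__` against the versions pinned in this guide; `installed_version` and `version_mismatches` are hypothetical helper names, and gizeh is left out because this guide does not pin its version.

```python
import importlib

# Versions pinned by the install commands above (import name -> version).
PINNED = {"tensorflow": "1.15.0", "PIL": "6.2.0", "scipy": "1.5.2", "cv2": "3.4.2"}


def installed_version(module_name):
    """Return module.__version__, or None if the module cannot be imported."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    return getattr(module, "__version__", None)


def version_mismatches(pinned=PINNED, lookup=installed_version):
    """Return {module: (expected, found)} for every module not at its pinned version."""
    problems = {}
    for name, expected in pinned.items():
        found = lookup(name)
        if found != expected:
            problems[name] = (expected, found)
    return problems
```

Run `print(version_mismatches())` inside the `virtual_sketching` environment; an empty dict means every pinned package matches.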

---

## 🛠️ Step 4: Fix Backend Compatibility

In `utils.py`, near the top, ensure the following:

```python
from PIL import Image
import matplotlib
matplotlib.use('TkAgg')  # force compatible backend
import matplotlib.pyplot as plt
```

This ensures proper rendering with tkinter on Windows.

---

## 🚀 Step 5: Run a Demo

From the project directory:

```bash
python test_vectorization.py --input sample_inputs\muten.png --sample 5
```

> ⚠️ This script only generates `.npz` and `.png` files. To convert to `.svg`, see the next section.

---

## 🖼️ Optional: Use GUI Tool (Windows Only)

### Step 1: Launch GUI with provided batch file

Use the `runme.bat` file to activate the conda environment and launch the Python GUI:

```bat
rem runme.bat
set CONDAPATH=C:\ProgramData\anaconda3
set ENVNAME=virtual_sketching
call %CONDAPATH%\Scripts\activate.bat %ENVNAME%
python virtual_sketch_gui.py
pause
```

### Step 2: Select an input image and model

The GUI lets you:
- Choose an input image (PNG, JPEG, BMP, etc.)
- Select one of the three models

It then runs the processing automatically, converts the results to SVG, and saves all outputs into a `sketches/` subfolder of the input image's directory.

---

## 🧠 Known Compatibility Notes

- This guide uses Python 3.6 with TensorFlow 1.15; TensorFlow 1.x is required because the code uses `tensorflow.contrib`, which was removed in TensorFlow 2.0
- Windows support requires manual setup of the Gizeh and Cairo backends
- GPU acceleration is optional; using it with TensorFlow 1.15 requires CUDA 10.0 and cuDNN 7

---

## 🧩 Troubleshooting Tips

- ❌ `No module named '_cffi_backend'` → Run: `conda install -c conda-forge cffi`
- ❌ `ImportError: cannot import name 'draw_svg_from_npz'` → Use `svg_conversion.py` instead
- ❌ `.svg` looks wrong → Make sure you're using the official `svg_conversion.py` from `tools/`
- ❌ Missing `.svg`? → Use `virtual_sketch_gui.py` or run `tools/svg_conversion.py` manually on `.npz`

---

## 🤝 Contributing

Feel free to open issues or pull requests if you encounter bugs or want to improve the Windows support!

---

## 📁 Folder Structure Suggestion

```
virtual_sketching/
├── sample_inputs/
├── tools/
├── outputs/
├── virtual_sketch_gui.py
├── runme.bat
├── README.md
└── WINDOWS_INSTALL_GUIDE.md   ← You are here
```

---

Made with ❤️ by the community to help Windows users get started!



================================================
FILE: dataset_utils.py
================================================
import os
import math
import random
import scipy.io
import numpy as np
import tensorflow as tf
from PIL import Image

from rasterization_utils.RealRenderer import GizehRasterizor as RealRenderer


def copy_hparams(hparams):
    """Return a copy of an HParams instance."""
    return tf.contrib.training.HParams(**hparams.values())


class GeneralRawDataLoader(object):
    def __init__(self,
                 image_path,
                 raster_size,
                 test_dataset):
        self.image_path = image_path
        self.raster_size = raster_size
        self.test_dataset = test_dataset

    def get_test_image(self, random_cursor=True, init_cursor_on_undrawn_pixel=False, init_cursor_num=1):
        input_image_data, image_size_test = self.gen_input_images(self.image_path)
        input_image_data = np.array(input_image_data,
                                    dtype=np.float32)  # (1, image_size, image_size, (3)), [0.0-strokes, 1.0-BG]

        return input_image_data, \
               self.gen_init_cursors(input_image_data, random_cursor, init_cursor_on_undrawn_pixel, init_cursor_num), \
               image_size_test

    def gen_input_images(self, image_path):
        img = Image.open(image_path).convert('RGB')
        height, width = img.height, img.width
        max_dim = max(height, width)

        img = np.array(img, dtype=np.uint8)

        if height != width:
            # Padding to a square image
            if self.test_dataset == 'clean_line_drawings':
                pad_value = [255, 255, 255]
            elif self.test_dataset == 'faces':
                pad_value = [0, 0, 0]
            else:
                # TODO: find better padding pixel value
                pad_value = img[height - 10, width - 10]

            img_r, img_g, img_b = img[:, :, 0], img[:, :, 1], img[:, :, 2]
            pad_width = max_dim - width
            pad_height = max_dim - height

            pad_img_r = np.pad(img_r, ((0, pad_height), (0, pad_width)), 'constant', constant_values=pad_value[0])
            pad_img_g = np.pad(img_g, ((0, pad_height), (0, pad_width)), 'constant', constant_values=pad_value[1])
            pad_img_b = np.pad(img_b, ((0, pad_height), (0, pad_width)), 'constant', constant_values=pad_value[2])
            image_array = np.stack([pad_img_r, pad_img_g, pad_img_b], axis=-1)
        else:
            image_array = img

        if self.test_dataset == 'faces' and max_dim != 256:
            image_array_resize = Image.fromarray(image_array, 'RGB')
            image_array_resize = image_array_resize.resize(size=(256, 256), resample=Image.BILINEAR)
            image_array = np.array(image_array_resize, dtype=np.uint8)

        assert image_array.shape[0] == image_array.shape[1]
        img_size = image_array.shape[0]
        image_array = image_array.astype(np.float32)
        if self.test_dataset == 'clean_line_drawings':
            image_array = image_array[:, :, 0] / 255.0  # [0.0-stroke, 1.0-BG]
        else:
            image_array = image_array / 255.0  # [0.0-stroke, 1.0-BG]
        image_array = np.expand_dims(image_array, axis=0)
        return image_array, img_size

    def crop_patch(self, image, center, image_size, crop_size):
        x0 = center[0] - crop_size // 2
        x1 = x0 + crop_size
        y0 = center[1] - crop_size // 2
        y1 = y0 + crop_size
        x0 = max(0, min(x0, image_size))
        y0 = max(0, min(y0, image_size))
        x1 = max(0, min(x1, image_size))
        y1 = max(0, min(y1, image_size))
        patch = image[y0:y1, x0:x1]
        return patch

    def gen_init_cursor_single(self, sketch_image, init_cursor_on_undrawn_pixel, misalign_size=3):
        # sketch_image: [0.0-stroke, 1.0-BG]
        image_size = sketch_image.shape[0]
        if np.sum(1.0 - sketch_image) == 0:
            center = np.zeros((2), dtype=np.int32)
            return center
        else:
            while True:
                center = np.random.randint(0, image_size, size=(2))  # (2), in large size
                patch = 1.0 - self.crop_patch(sketch_image, center, image_size, self.raster_size)
                if np.sum(patch) != 0:
                    if not init_cursor_on_undrawn_pixel:
                        return center.astype(np.float32) / float(image_size)  # (2), in size [0.0, 1.0)
                    else:
                        center_patch = 1.0 - self.crop_patch(sketch_image, center, image_size, misalign_size)
                        if np.sum(center_patch) != 0:
                            return center.astype(np.float32) / float(image_size)  # (2), in size [0.0, 1.0)

    def gen_init_cursors(self, sketch_data, random_pos=True, init_cursor_on_undrawn_pixel=False, init_cursor_num=1):
        init_cursor_batch_list = []
        for cursor_i in range(init_cursor_num):
            if random_pos:
                init_cursor_batch = []
                for i in range(len(sketch_data)):
                    sketch_image = sketch_data[i].copy().astype(np.float32)  # [0.0-stroke, 1.0-BG]
                    center = self.gen_init_cursor_single(sketch_image, init_cursor_on_undrawn_pixel)
                    init_cursor_batch.append(center)

                init_cursor_batch = np.stack(init_cursor_batch, axis=0)  # (N, 2)
            else:
                raise Exception('Not finished')
            init_cursor_batch_list.append(init_cursor_batch)

        if init_cursor_num == 1:
            init_cursor_batch = init_cursor_batch_list[0]
            init_cursor_batch = np.expand_dims(init_cursor_batch, axis=1).astype(np.float32)  # (N, 1, 2)
        else:
            init_cursor_batch = np.stack(init_cursor_batch_list, axis=1)  # (N, init_cursor_num, 2)
            init_cursor_batch = np.expand_dims(init_cursor_batch, axis=2).astype(
                np.float32)  # (N, init_cursor_num, 1, 2)

        return init_cursor_batch

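
import numpy as np


# Illustrative sketch (hypothetical helper, not used by the loaders in this file).
# The padding branch of GeneralRawDataLoader.gen_input_images splits the image into
# R/G/B planes because np.pad's constant_values gives one constant per axis, not one
# per channel. Assuming an (H, W, C) uint8 array and a length-C pad colour, the same
# bottom/right padding can be written by filling a square canvas with the pad colour
# and copying the image into its top-left corner.
def _pad_to_square_sketch(img, pad_value):
    height, width = img.shape[0], img.shape[1]
    max_dim = max(height, width)
    # Square canvas filled with the pad colour (broadcast over the channel axis).
    padded = np.empty((max_dim, max_dim, img.shape[2]), dtype=img.dtype)
    padded[...] = np.asarray(pad_value, dtype=img.dtype)
    # Original pixels keep their position; padding ends up at the bottom and right,
    # matching ((0, pad_height), (0, pad_width)) in gen_input_images.
    padded[:height, :width] = img
    return padded
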

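
import numpy as np


# Illustrative sketch (hypothetical helper): the final reshaping step shared by the
# gen_init_cursors methods. Each pass yields one (N, 2) array of cursor positions
# (x, y), normalized by the image size, with x indexing columns as in crop_patch.
# A single pass is returned as (N, 1, 2); several passes are stacked along axis 1
# and given an extra axis, giving (N, init_cursor_num, 1, 2).
def _stack_init_cursors_sketch(cursor_batch_list):
    if len(cursor_batch_list) == 1:
        return np.expand_dims(cursor_batch_list[0], axis=1).astype(np.float32)
    stacked = np.stack(cursor_batch_list, axis=1)  # (N, init_cursor_num, 2)
    return np.expand_dims(stacked, axis=2).astype(np.float32)
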
def load_dataset_testing(test_data_base_dir, test_dataset, test_img_name, model_params):
    assert test_dataset in ['clean_line_drawings', 'rough_sketches', 'faces']
    img_path = os.path.join(test_data_base_dir, test_dataset, test_img_name)
    print('Loaded {} from {}'.format(img_path, test_dataset))

    eval_model_params = copy_hparams(model_params)
    eval_model_params.use_input_dropout = 0
    eval_model_params.use_recurrent_dropout = 0
    eval_model_params.use_output_dropout = 0
    eval_model_params.batch_size = 1
    eval_model_params.model_mode = 'sample'

    sample_model_params = copy_hparams(eval_model_params)
    sample_model_params.batch_size = 1  # only sample one at a time
    sample_model_params.max_seq_len = 1  # sample one point at a time

    test_set = GeneralRawDataLoader(img_path, eval_model_params.raster_size, test_dataset=test_dataset)

    result = [test_set, eval_model_params, sample_model_params]
    return result


class GeneralMultiObjectDataLoader(object):
    def __init__(self,
                 stroke3_data,
                 batch_size,
                 raster_size,
                 image_size_small,
                 image_size_large,
                 is_bin,
                 is_train):
        self.batch_size = batch_size  # minibatch size
        self.raster_size = raster_size
        self.image_size_small = image_size_small
        self.image_size_large = image_size_large
        self.is_bin = is_bin
        self.is_train = is_train

        self.num_batches = len(stroke3_data) // self.batch_size
        self.batch_idx = -1
        print('batch_size', batch_size, ', num_batches', self.num_batches)

        self.rasterizor = RealRenderer()
        self.memory_sketch_data_batch = []

        assert type(stroke3_data) is list
        self.preprocess_rand_data(stroke3_data)

    def preprocess_rand_data(self, stroke3):
        if self.is_train:
            random.shuffle(stroke3)
        self.stroke3_data = stroke3

    def cal_dist(self, posA, posB):
        return np.sqrt(np.sum(np.power(posA - posB, 2)))

    def invalid_position(self, pos, obj_size, pos_list, size_list):
        if len(pos_list) == 0:
            return False

        pos_a = pos
        size_a = obj_size
        for i in range(len(pos_list)):
            pos_b = pos_list[i]
            size_b = size_list[i]

            if self.cal_dist(pos_a, pos_b) < ((size_a + size_b) // 4):
                return True

        return False

    def get_object_info(self, image_size, vary_thickness=True, try_total_times=3):
        if image_size <= 172:
            obj_num = 1
            obj_thickness_list = [3]
        elif image_size <= 225:
            obj_num = random.randint(1, 2)
            obj_thickness_list = np.random.randint(3, 4 + 1, size=(obj_num))
        elif image_size <= 278:
            obj_num = 2
            obj_thickness_list = np.random.randint(3, 4 + 1, size=(obj_num))
        elif image_size <= 331:
            obj_num = random.randint(2, 3)
            while True:
                obj_thickness_list = np.random.randint(3, 5 + 1, size=(obj_num))
                if np.sum(obj_thickness_list) / obj_num != 5 and np.sum(obj_thickness_list) < 13:
                    break
        elif image_size <= 384:
            obj_num = 3
            while True:
                obj_thickness_list = np.random.randint(3, 5 + 1, size=(obj_num))
                if np.sum(obj_thickness_list) / obj_num != 5 and np.sum(obj_thickness_list) < 13:
                    break
        else:
            raise Exception('Invalid image_size', image_size)

        if not vary_thickness:
            num_item = len(obj_thickness_list)
            obj_thickness_list = [3 for _ in range(num_item)]

        obj_pos_list = []
        obj_size_list = []
        if obj_num == 1:
            obj_size_list.append(image_size)
            center = (image_size // 2, image_size // 2)
            obj_pos_list.append(center)
        else:
            for obj_i in range(obj_num):
                for try_i in range(try_total_times):
                    obj_size = random.randint(128, image_size * 3 // 4)
                    obj_center = np.random.randint(obj_size // 3, image_size - (obj_size // 3) + 1, size=(2))

                    if not self.invalid_position(obj_center, obj_size, obj_pos_list,
                                                 obj_size_list) or try_i == try_total_times - 1:
                        obj_pos_list.append(obj_center)
                        obj_size_list.append(obj_size)
                        break

        assert len(obj_size_list) == len(obj_pos_list) == len(obj_thickness_list) == obj_num
        return obj_num, obj_size_list, obj_pos_list, obj_thickness_list

    def object_pasting(self, obj_img, canvas_img, center):
        c_y, c_x = center[0], center[1]
        obj_size = obj_img.shape[0]
        canvas_size = canvas_img.shape[0]
        box_left = max(0, c_x - obj_size // 2)
        box_right = min(canvas_size, c_x + obj_size // 2)
        box_up = max(0, c_y - obj_size // 2)
        box_bottom = min(canvas_size, c_y + obj_size // 2)

        obj_box_up = box_up - (c_y - obj_size // 2)
        obj_box_left = box_left - (c_x - obj_size // 2)
        box_obj = obj_img[obj_box_up: obj_box_up + (box_bottom - box_up),
                  obj_box_left: obj_box_left + (box_right - box_left)]

        # Add onto a copy: slicing canvas_img returns a view, so an in-place `+=` on
        # that slice would silently modify the caller's canvas as well.
        rst_canvas = np.copy(canvas_img)
        rst_canvas[box_up: box_bottom, box_left: box_right] += box_obj
        rst_canvas = np.clip(rst_canvas, 0.0, 1.0)

        return rst_canvas

    def get_multi_object_image(self, img_size, vary_thickness):
        object_num, object_size_list, object_pos_list, object_thickness_list = self.get_object_info(
            img_size, vary_thickness=vary_thickness)

        canvas = np.zeros(shape=(img_size, img_size), dtype=np.float32)

        for obj_i in range(object_num):
            rand_idx = np.random.randint(0, len(self.stroke3_data))
            rand_stroke3 = self.stroke3_data[rand_idx]  # (N_points, 3)

            object_size = object_size_list[obj_i]
            object_center = object_pos_list[obj_i]
            object_thickness = object_thickness_list[obj_i]

            stroke_image = self.gen_stroke_images([rand_stroke3], object_size, object_thickness)
            stroke_image = 1.0 - stroke_image[0]  # (object_size, object_size), [0.0-BG, 1.0-strokes]

            canvas = self.object_pasting(stroke_image, canvas, object_center)  # [0.0-BG, 1.0-strokes]

        canvas = 1.0 - canvas  # [0.0-strokes, 1.0-BG]
        return canvas

    def get_batch_from_memory(self, memory_idx, vary_thickness, fixed_image_size=-1, random_cursor=True,
                              init_cursor_on_undrawn_pixel=False, init_cursor_num=1):
        if len(self.memory_sketch_data_batch) >= memory_idx + 1:
            sketch_data_batch = self.memory_sketch_data_batch[memory_idx]
            sketch_data_batch = np.expand_dims(sketch_data_batch,
                                               axis=0)  # (1, image_size, image_size), [0.0-strokes, 1.0-BG]
            image_size_rand = sketch_data_batch.shape[1]
        else:
            if fixed_image_size == -1:
                image_size_rand = random.randint(self.image_size_small, self.image_size_large)
            else:
                image_size_rand = fixed_image_size

            multi_obj_image = self.get_multi_object_image(image_size_rand, vary_thickness)  # [0.0-strokes, 1.0-BG]
            self.memory_sketch_data_batch.append(multi_obj_image)
            sketch_data_batch = np.expand_dims(multi_obj_image,
                                               axis=0)  # (1, image_size, image_size), [0.0-strokes, 1.0-BG]

        return None, sketch_data_batch, \
               self.gen_init_cursors(sketch_data_batch, random_cursor, init_cursor_on_undrawn_pixel, init_cursor_num), \
               image_size_rand

    def get_batch_multi_res(self, loop_num, vary_thickness, random_cursor=True,
                            init_cursor_on_undrawn_pixel=False, init_cursor_num=1):
        sketch_data_batch = []
        init_cursors_batch = []
        image_size_batch = []
        batch_size_per_loop = self.batch_size // loop_num
        for loop_i in range(loop_num):
            image_size_rand = random.randint(self.image_size_small, self.image_size_large)
            sketch_data_sub_batch = []
            for batch_i in range(batch_size_per_loop):
                multi_obj_image = self.get_multi_object_image(image_size_rand, vary_thickness)  # [0.0-strokes, 1.0-BG]
                sketch_data_sub_batch.append(multi_obj_image)
            sketch_data_sub_batch = np.stack(sketch_data_sub_batch,
                                             axis=0)  # (N, image_size, image_size), [0.0-strokes, 1.0-BG]

            init_cursors_sub_batch = self.gen_init_cursors(sketch_data_sub_batch, random_cursor,
                                                           init_cursor_on_undrawn_pixel, init_cursor_num)
            sketch_data_batch.append(sketch_data_sub_batch)
            init_cursors_batch.append(init_cursors_sub_batch)
            image_size_batch.append(image_size_rand)

        return None, \
               sketch_data_batch, \
               init_cursors_batch, \
               image_size_batch

    def gen_stroke_images(self, stroke3_list, image_size, stroke_width):
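        """
        Rasterize a batch of stroke-3 sequences into grayscale images.

        :param stroke3_list: list of (batch_size,), each with (N_points, 3)
        :param image_size: side length, in pixels, of the square output images
        :param stroke_width: stroke thickness passed to the rasterizer
        :return: (batch_size, image_size, image_size) array, [0.0-strokes, 1.0-BG]
        """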
        gt_image_array = self.rasterizor.raster_func(stroke3_list, image_size, stroke_width=stroke_width,
                                                     is_bin=self.is_bin, version='v2')
        gt_image_array = np.stack(gt_image_array, axis=0)
        gt_image_array = 1.0 - gt_image_array  # (batch_size, image_size, image_size), [0.0-strokes, 1.0-BG]
        return gt_image_array

    def crop_patch(self, image, center, image_size, crop_size):
        x0 = center[0] - crop_size // 2
        x1 = x0 + crop_size
        y0 = center[1] - crop_size // 2
        y1 = y0 + crop_size
        x0 = max(0, min(x0, image_size))
        y0 = max(0, min(y0, image_size))
        x1 = max(0, min(x1, image_size))
        y1 = max(0, min(y1, image_size))
        patch = image[y0:y1, x0:x1]
        return patch

    def gen_init_cursor_single(self, sketch_image, init_cursor_on_undrawn_pixel, misalign_size=3):
        # sketch_image: [0.0-stroke, 1.0-BG]
        image_size = sketch_image.shape[0]
        if np.sum(1.0 - sketch_image) == 0:
            center = np.zeros((2), dtype=np.int32)
            return center
        else:
            while True:
                center = np.random.randint(0, image_size, size=(2))  # (2), in large size
                patch = 1.0 - self.crop_patch(sketch_image, center, image_size, self.raster_size)
                if np.sum(patch) != 0:
                    if not init_cursor_on_undrawn_pixel:
                        return center.astype(np.float32) / float(image_size)  # (2), in size [0.0, 1.0)
                    else:
                        center_patch = 1.0 - self.crop_patch(sketch_image, center, image_size, misalign_size)
                        if np.sum(center_patch) != 0:
                            return center.astype(np.float32) / float(image_size)  # (2), in size [0.0, 1.0)

    def gen_init_cursors(self, sketch_data, random_pos=True, init_cursor_on_undrawn_pixel=False, init_cursor_num=1):
        init_cursor_batch_list = []
        for cursor_i in range(init_cursor_num):
            if random_pos:
                init_cursor_batch = []
                for i in range(len(sketch_data)):
                    sketch_image = sketch_data[i].copy().astype(np.float32)  # [0.0-stroke, 1.0-BG]
                    center = self.gen_init_cursor_single(sketch_image, init_cursor_on_undrawn_pixel)
                    init_cursor_batch.append(center)

                init_cursor_batch = np.stack(init_cursor_batch, axis=0)  # (N, 2)
            else:
                raise Exception('Not finished')
            init_cursor_batch_list.append(init_cursor_batch)

        if init_cursor_num == 1:
            init_cursor_batch = init_cursor_batch_list[0]
            init_cursor_batch = np.expand_dims(init_cursor_batch, axis=1).astype(np.float32)  # (N, 1, 2)
        else:
            init_cursor_batch = np.stack(init_cursor_batch_list, axis=1)  # (N, init_cursor_num, 2)
            init_cursor_batch = np.expand_dims(init_cursor_batch, axis=2).astype(
                np.float32)  # (N, init_cursor_num, 1, 2)

        return init_cursor_batch

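
import numpy as np


# Illustrative sketch (hypothetical helper, not used by the loaders): the border
# clipping done by GeneralMultiObjectDataLoader.object_pasting, reduced to its index
# arithmetic. Assuming a square object image whose centre is given as (row, col) in
# canvas coordinates, only the part that falls inside the canvas is added onto a
# copy of the canvas, and the sum is clipped to [0.0, 1.0].
def _paste_clipped_sketch(obj_img, canvas_img, center):
    c_y, c_x = int(center[0]), int(center[1])
    half = obj_img.shape[0] // 2
    canvas_size = canvas_img.shape[0]
    # Canvas window covered by the object, clipped to the canvas borders.
    top, bottom = max(0, c_y - half), min(canvas_size, c_y + half)
    left, right = max(0, c_x - half), min(canvas_size, c_x + half)
    # Matching window inside the object image, shifted by whatever was clipped away.
    obj_top, obj_left = top - (c_y - half), left - (c_x - half)
    result = np.copy(canvas_img)
    result[top:bottom, left:right] += obj_img[obj_top: obj_top + (bottom - top),
                                              obj_left: obj_left + (right - left)]
    return np.clip(result, 0.0, 1.0)
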

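
import numpy as np


# Illustrative sketch (hypothetical helper): the placement rule checked by
# GeneralMultiObjectDataLoader.invalid_position, written in vectorized form. A new
# object is rejected when its centre lies closer to any already-placed centre than
# a quarter of the two objects' summed sizes (integer division, as in the original).
def _too_close_sketch(pos, obj_size, pos_list, size_list):
    if len(pos_list) == 0:
        return False
    placed = np.asarray(pos_list, dtype=np.float64)  # (K, 2) existing centres
    dists = np.linalg.norm(placed - np.asarray(pos, dtype=np.float64), axis=1)
    limits = (obj_size + np.asarray(size_list)) // 4  # per-pair minimum distance
    return bool(np.any(dists < limits))
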
def load_dataset_multi_object(dataset_base_dir, model_params):
    train_stroke3_data = []
    val_stroke3_data = []

    if model_params.data_set == 'clean_line_drawings':
        def load_qd_npz_data(npz_path):
            data = np.load(npz_path, encoding='latin1', allow_pickle=True)
            selected_strokes3 = data['stroke3']  # (N_sketches,), each with (N_points, 3)
            selected_strokes3 = selected_strokes3.tolist()
            return selected_strokes3

        base_dir_clean = 'QuickDraw-clean'
        cates = ['airplane', 'bus', 'car', 'sailboat', 'bird', 'cat', 'dog',
                 # 'rabbit',
                 'tree', 'flower',
                 # 'circle', 'line',
                 'zigzag'
                 ]

        for cate in cates:
            train_cate_sketch_data_npz_path = os.path.join(dataset_base_dir, base_dir_clean, 'train', cate + '.npz')
            val_cate_sketch_data_npz_path = os.path.join(dataset_base_dir, base_dir_clean, 'test', cate + '.npz')
            print(train_cate_sketch_data_npz_path)

            train_cate_stroke3_data = load_qd_npz_data(
                train_cate_sketch_data_npz_path)  # list of (N_sketches,), each with (N_points, 3)
            val_cate_stroke3_data = load_qd_npz_data(val_cate_sketch_data_npz_path)
            train_stroke3_data += train_cate_stroke3_data
            val_stroke3_data += val_cate_stroke3_data
    else:
        raise Exception('Unknown data type:', model_params.data_set)

    print('Loaded {}/{} from {}'.format(len(train_stroke3_data), len(val_stroke3_data), model_params.data_set))
    print('model_params.max_seq_len %i.' % model_params.max_seq_len)

    eval_sample_model_params = copy_hparams(model_params)
    eval_sample_model_params.use_input_dropout = 0
    eval_sample_model_params.use_recurrent_dropout = 0
    eval_sample_model_params.use_output_dropout = 0
    eval_sample_model_params.batch_size = 1  # only sample one at a time
    eval_sample_model_params.model_mode = 'eval_sample'

    train_set = GeneralMultiObjectDataLoader(train_stroke3_data,
                                             model_params.batch_size, model_params.raster_size,
                                             model_params.image_size_small, model_params.image_size_large,
                                             model_params.bin_gt, is_train=True)
    val_set = GeneralMultiObjectDataLoader(val_stroke3_data,
                                           eval_sample_model_params.batch_size, eval_sample_model_params.raster_size,
                                           eval_sample_model_params.image_size_small,
                                           eval_sample_model_params.image_size_large,
                                           eval_sample_model_params.bin_gt, is_train=False)

    result = [train_set, val_set, model_params, eval_sample_model_params]
    return result

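
import io

import numpy as np


# Illustrative sketch (hypothetical helper): load_qd_npz_data expects each
# QuickDraw-clean .npz file to hold a 'stroke3' entry, a 1-D object array whose
# elements are per-sketch (N_points, 3) arrays. np.savez pickles object arrays,
# which is why np.load needs allow_pickle=True. This round trip through an
# in-memory buffer builds such an entry and reads it back the same way.
def _stroke3_npz_roundtrip_sketch(sketches):
    # Fill the object array element by element so NumPy never tries to merge
    # sketches of equal length into a single 3-D array.
    stroke3 = np.empty(len(sketches), dtype=object)
    for i, sketch in enumerate(sketches):
        stroke3[i] = np.asarray(sketch)
    buffer = io.BytesIO()
    np.savez(buffer, stroke3=stroke3)
    buffer.seek(0)
    with np.load(buffer, encoding='latin1', allow_pickle=True) as data:
        return data['stroke3'].tolist()  # list of (N_points, 3) arrays
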

class GeneralDataLoaderMultiObjectRough(object):
    def __init__(self,
                 photo_data,
                 sketch_data,
                 texture_data,
                 shadow_data,
                 batch_size,
                 raster_size,
                 image_size_small,
                 image_size_large,
                 is_train):
        self.batch_size = batch_size  # minibatch size
        self.raster_size = raster_size
        self.image_size_small = image_size_small
        self.image_size_large = image_size_large
        self.is_train = is_train

        assert photo_data is not None
        assert len(photo_data) == len(sketch_data)
        # self.num_batches = len(sketch_data) // self.batch_size
        self.batch_idx = -1
        print('batch_size', batch_size)

        assert type(photo_data) is list
        assert type(sketch_data) is list
        assert type(texture_data) is list and len(texture_data) > 0
        assert type(shadow_data) is list and len(shadow_data) > 0
        self.photo_data = photo_data
        self.sketch_data = sketch_data
        self.texture_data = texture_data  # list of (H, W, 3), [0, 255], uint8
        self.shadow_data = shadow_data  # list of (H, W), [0, 255], uint8

        self.memory_photo_data_batch = []
        self.memory_sketch_data_batch = []

    def rough_augmentation(self, raw_photo, texture_prob=0.20, noise_prob=0.15, shadow_prob=0.20):
        # raw_photo: (H, W), [0.0-stroke, 1.0-BG]
        aug_photo_rgb = np.stack([raw_photo for _ in range(3)], axis=-1)

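        def texture_generation(texture_list, image_shape):
            i_h, i_w = image_shape[0], image_shape[1]
            # Only textures at least as large as the image can be cropped. Choosing uniformly among
            # them gives the same distribution as retrying random picks, but fails loudly instead of
            # looping forever when no texture is large enough.
            candidates = [t for t in texture_list if t.shape[0] >= i_h and t.shape[1] >= i_w]
            assert len(candidates) > 0, 'No texture is large enough for image shape {}'.format(image_shape)
            texture_large = np.copy(random.choice(candidates)).astype(np.float32)
            t_h, t_w = texture_large.shape[0], texture_large.shape[1]
            crop_y = random.randint(0, t_h - i_h)
            crop_x = random.randint(0, t_w - i_w)
            crop_texture = texture_large[crop_y: crop_y + i_h, crop_x: crop_x + i_w, :]
            return crop_texture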

        def texture_change(rough_img_, all_textures):
            # rough_img_: (H, W, 3), [0.0-stroke, 1.0-BG]

            texture_image = texture_generation(all_textures, rough_img_.shape)  # (h, w, 3)
            texture_image /= 255.0

            rand_b = np.random.uniform(1.0, 2.0, size=rough_img_.shape)
            textured_img = rough_img_ * (texture_image / rand_b + (rand_b - 1.0) / rand_b)  # [0.0, 1.0]
            return textured_img

        def noise_change(rough_img_, noise_scale=25):
            # rough_img_: (H, W, 3), [0.0, 1.0]
            rough_img_255 = rough_img_ * 255.0

            rand_noise = np.random.uniform(-1.0, 1.0, size=rough_img_255.shape) * noise_scale
            # rand_noise = np.random.normal(size=rough_img_.shape) * noise_scale
            noise_img = rough_img_255 + rand_noise
            noise_img = np.clip(noise_img, 0.0, 255.0)
            noise_img /= 255.0
            return noise_img

        def shadow_change(rough_img_, all_shadows):
            # rough_img_: (H, W, 3), [0.0, 1.0]
            rough_img_255 = rough_img_ * 255.0

            shadow_i = random.randint(0, len(all_shadows) - 1)
            shadow_full = all_shadows[shadow_i]  # (H, W), [0, 255]
            shadow_img_size = shadow_full.shape[0]

            while True:
                position = np.random.randint(-shadow_img_size // 2, shadow_img_size // 2, (2))
                if abs(position[0]) > (shadow_img_size // 8) and abs(position[1]) > (shadow_img_size // 8):
                    break
            position += (shadow_img_size // 2)

            crop_up = shadow_img_size - position[0]
            crop_left = shadow_img_size - position[1]

            shadow_image_large = shadow_full[crop_up: crop_up + shadow_img_size, crop_left: crop_left + shadow_img_size]
            shadow_bg = Image.fromarray(shadow_image_large, 'L')
            shadow_bg = shadow_bg.resize(size=(rough_img_255.shape[1], rough_img_255.shape[0]), resample=Image.BILINEAR)
            shadow_bg = np.array(shadow_bg, dtype=np.float32) / 255.0  # [0.0-shadow, 1.0-BG]
            shadow_bg = np.stack([shadow_bg for _ in range(3)], axis=-1)

            shadow_img = rough_img_255 * shadow_bg
            shadow_img /= 255.0
            return shadow_img

        if random.random() <= texture_prob:
            aug_photo_rgb = texture_change(aug_photo_rgb, self.texture_data)  # (H, W, 3), [0.0, 1.0]
        if random.random() <= noise_prob:
            aug_photo_rgb = noise_change(aug_photo_rgb)  # (H, W, 3), [0.0, 1.0]
        if random.random() <= shadow_prob:
            aug_photo_rgb = shadow_change(aug_photo_rgb, self.shadow_data)  # (H, W, 3), [0.0, 1.0]

        return aug_photo_rgb

    def image_interpolation(self, photo, sketch, photo_prob):
        interp_photo = photo * photo_prob + sketch * (1.0 - photo_prob)
        interp_photo = np.clip(interp_photo, 0.0, 1.0)
        return interp_photo

    def get_batch_from_memory(self, memory_idx, interpolate_type, fixed_image_size=-1, random_cursor=True,
                              photo_prob=1.0, init_cursor_num=1):
        if len(self.memory_sketch_data_batch) >= memory_idx + 1:
            photo_data_batch = self.memory_photo_data_batch[memory_idx]
            sketch_data_batch = self.memory_sketch_data_batch[memory_idx]
            image_size_rand = sketch_data_batch.shape[1]
        else:
            if fixed_image_size == -1:
                image_size_rand = random.randint(self.image_size_small, self.image_size_large)
            else:
                image_size_rand = fixed_image_size

            # photo_prob = 0.0 if photo_prob_type == 'zero' else 1.0
            photo_data_batch, sketch_data_batch = self.select_sketch(
                image_size_rand)  # both: (H, W), [0.0-stroke, 1.0-BG]
            photo_data_batch = self.rough_augmentation(photo_data_batch)  # (H, W, 3), [0.0-stroke, 1.0-BG]

            self.memory_photo_data_batch.append(photo_data_batch)
            self.memory_sketch_data_batch.append(sketch_data_batch)

        if interpolate_type == 'prob':
            if random.random() >= photo_prob:
                photo_data_batch = np.stack([sketch_data_batch for _ in range(3)],
                                            axis=-1)  # (H, W, 3), [0.0-stroke, 1.0-BG]
        elif interpolate_type == 'image':
            photo_data_batch = self.image_interpolation(
                photo_data_batch, np.stack([sketch_data_batch for _ in range(3)], axis=-1), photo_prob)
        else:
            raise Exception('Unknown interpolate_type', interpolate_type)

        photo_data_batch = np.expand_dims(photo_data_batch, axis=0)  # (1, image_size, image_size, 3)
        sketch_data_batch = np.expand_dims(sketch_data_batch,
                                           axis=0)  # (1, image_size, image_size), [0.0-strokes, 1.0-BG]

        return photo_data_batch, sketch_data_batch, \
               self.gen_init_cursors(sketch_data_batch, random_cursor, init_cursor_num), image_size_rand

    def select_sketch(self, image_size_rand):
        resolution_idx = image_size_rand - self.image_size_small
        assert len(self.sketch_data[resolution_idx]) > 0, 'No sketches at resolution {}'.format(image_size_rand)
        img_idx = random.randint(0, len(self.sketch_data[resolution_idx]) - 1)

        selected_sketch = self.sketch_data[resolution_idx][img_idx]  # [0-stroke, 255-BG], uint8
        selected_photo = self.photo_data[resolution_idx][img_idx]  # [0-stroke, 255-BG], uint8

        rst_sketch_image = selected_sketch.astype(np.float32) / 255.0  # [0.0-stroke, 1.0-BG]
        rst_photo_image = selected_photo.astype(np.float32) / 255.0  # [0.0-stroke, 1.0-BG]

        return rst_photo_image, rst_sketch_image

    def get_batch_multi_res(self, loop_num, interpolate_type, random_cursor=True, init_cursor_num=1, photo_prob=1.0):
        photo_data_batch = []
        sketch_data_batch = []
        init_cursors_batch = []
        image_size_batch = []
        batch_size_per_loop = self.batch_size // loop_num
        for loop_i in range(loop_num):
            image_size_rand = random.randint(self.image_size_small, self.image_size_large)

            photo_data_sub_batch = []
            sketch_data_sub_batch = []
            for img_i in range(batch_size_per_loop):
                photo_patch, sketch_patch = self.select_sketch(image_size_rand)  # both: (H, W), [0.0-stroke, 1.0-BG]
                photo_patch = self.rough_augmentation(photo_patch)  # (H, W, 3), [0.0-stroke, 1.0-BG]

                if interpolate_type == 'prob':
                    if random.random() >= photo_prob:
                        photo_patch = np.stack([sketch_patch for _ in range(3)],
                                               axis=-1)  # (H, W, 3), [0.0-stroke, 1.0-BG]
                elif interpolate_type == 'image':
                    photo_patch = self.image_interpolation(
                        photo_patch, np.stack([sketch_patch for _ in range(3)], axis=-1), photo_prob)
                else:
                    raise Exception('Unknown interpolate_type', interpolate_type)

                photo_data_sub_batch.append(photo_patch)
                sketch_data_sub_batch.append(sketch_patch)

            photo_data_sub_batch = np.stack(photo_data_sub_batch,
                                            axis=0)  # (N, image_size, image_size, 3), [0.0-strokes, 1.0-BG]
            sketch_data_sub_batch = np.stack(sketch_data_sub_batch,
                                             axis=0)  # (N, image_size, image_size), [0.0-strokes, 1.0-BG]
            init_cursors_sub_batch = self.gen_init_cursors(sketch_data_sub_batch, random_cursor, init_cursor_num)
            photo_data_batch.append(photo_data_sub_batch)
            sketch_data_batch.append(sketch_data_sub_batch)
            init_cursors_batch.append(init_cursors_sub_batch)
            image_size_batch.append(image_size_rand)

        return photo_data_batch, sketch_data_batch, init_cursors_batch, image_size_batch

    def crop_patch(self, image, center, image_size, crop_size):
        x0 = center[0] - crop_size // 2
        x1 = x0 + crop_size
        y0 = center[1] - crop_size // 2
        y1 = y0 + crop_size
        x0 = max(0, min(x0, image_size))
        y0 = max(0, min(y0, image_size))
        x1 = max(0, min(x1, image_size))
        y1 = max(0, min(y1, image_size))
        patch = image[y0:y1, x0:x1]
        return patch

    def gen_init_cursor_single(self, sketch_image):
        # sketch_image: [0.0-stroke, 1.0-BG]
        image_size = sketch_image.shape[0]
        if np.sum(1.0 - sketch_image) == 0:
            center = np.zeros((2), dtype=np.int32)
            return center
        else:
            while True:
                center = np.random.randint(0, image_size, size=(2))  # (2), in large size
                patch = 1.0 - self.crop_patch(sketch_image, center, image_size, self.raster_size)
                if np.sum(patch) != 0:
                    return center.astype(np.float32) / float(image_size)  # (2), in size [0.0, 1.0)

    def gen_init_cursors(self, sketch_data, random_pos=True, init_cursor_num=1):
        init_cursor_batch_list = []
        for cursor_i in range(init_cursor_num):
            if random_pos:
                init_cursor_batch = []
                for i in range(len(sketch_data)):
                    sketch_image = sketch_data[i].copy().astype(np.float32)  # [0.0-stroke, 1.0-BG]
                    center = self.gen_init_cursor_single(sketch_image)
                    init_cursor_batch.append(center)

                init_cursor_batch = np.stack(init_cursor_batch, axis=0)  # (N, 2)
            else:
                raise Exception('Not finished')
            init_cursor_batch_list.append(init_cursor_batch)

        if init_cursor_num == 1:
            init_cursor_batch = init_cursor_batch_list[0]
            init_cursor_batch = np.expand_dims(init_cursor_batch, axis=1).astype(np.float32)  # (N, 1, 2)
        else:
            init_cursor_batch = np.stack(init_cursor_batch_list, axis=1)  # (N, init_cursor_num, 2)
            init_cursor_batch = np.expand_dims(init_cursor_batch, axis=2).astype(
                np.float32)  # (N, init_cursor_num, 1, 2)

        return init_cursor_batch

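
import numpy as np


# Illustrative sketch (hypothetical helper): the blend used in texture_change,
#     rough * (texture / b + (b - 1) / b),  with b ~ U[1.0, 2.0] per pixel,
# is a linear interpolation between the texture and plain white with texture
# weight w = 1 / b in [0.5, 1.0], because texture / b + (b - 1) / b = w * texture + (1 - w).
# Multiplying by the rough image keeps strokes (0.0) black while the background (1.0)
# takes on a lightened texture. Here the weight map is passed in explicitly.
def _texture_blend_sketch(rough_img, texture, texture_weight):
    # rough_img, texture: float arrays in [0.0, 1.0]; texture_weight: scalar or
    # array broadcastable to rough_img, expected in [0.5, 1.0].
    paper = texture_weight * texture + (1.0 - texture_weight)
    return rough_img * paper
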

def load_dataset_multi_object_rough(dataset_base_dir, model_params):
    train_photo_data = []
    train_sketch_data = []
    val_photo_data = []
    val_sketch_data = []
    texture_data = []
    shadow_data = []

    if model_params.data_set == 'rough_sketches':
        base_dir_rough = 'QuickDraw-rough'

        def load_sketch_data(mat_path):
            sketch_data_mat = scipy.io.loadmat(mat_path)
            sketch_data = sketch_data_mat['sketch_array']
            sketch_data = np.array(sketch_data, dtype=np.uint8)  # (N, resolution, resolution), [0-strokes, 255-BG]
            return sketch_data

        def load_photo_data(mat_path):
            photo_data_mat = scipy.io.loadmat(mat_path)
            photo_data = photo_data_mat['image_array']
            photo_data = np.array(photo_data, dtype=np.uint8)  # (N, resolution, resolution), [0-strokes, 255-BG]
            return photo_data

        def load_normal_data(img_path):
            assert '.png' in img_path or '.jpg' in img_path
            img = Image.open(img_path).convert('RGB')
            img = np.array(img, dtype=np.uint8)  # (H, W, 3), [0-stroke, 255-BG], uint8
            return img

        ## Texture
        texture_base = os.path.join(dataset_base_dir, base_dir_rough, 'texture')
        all_texture = os.listdir(texture_base)
        all_texture.sort()

        for file_name in all_texture:
            texture_path = os.path.join(texture_base, file_name)
            texture_uint8 = load_normal_data(texture_path)
            texture_data.append(texture_uint8)

        ## shadow
        def process_angle(img, temp_size):
            padded_img = img.copy()
            padded_img[0, 0:temp_size] -= 1
            padded_img[0, -(temp_size + 1):-1] -= 1
            padded_img[-1, 0:temp_size] -= 1
            padded_img[-1, -(temp_size + 1):-1] -= 1

            padded_img[0:temp_size, 0] -= 1
            padded_img[0:temp_size, -1] -= 1
            padded_img[-(temp_size + 1):-1, 0] -= 1
            padded_img[-(temp_size + 1):-1, -1] -= 1
            return padded_img

        def pad_img(ori_img, pad_value):
            padded_img = np.pad(ori_img, 1, constant_values=pad_value)
            img_h, img_w = padded_img.shape[0], padded_img.shape[1]

            temp_size = img_h // 3
            padded_img = process_angle(padded_img, temp_size)

            temp_size = img_h // 9
            padded_img = process_angle(padded_img, temp_size)

            temp_size = img_h // 15
            padded_img = process_angle(padded_img, temp_size)

            temp_size = img_h // 21
            padded_img = process_angle(padded_img, temp_size)

            padded_img = np.clip(padded_img, 0, 255)

            return padded_img

        def shadow_generation(transparency, shadow_img_size=1024):
            deepest_value = int(255 * transparency)

            center_patch = np.zeros((shadow_img_size // 2, shadow_img_size // 2), dtype=np.uint8)
            center_patch.fill(255)

            pad_gap = shadow_img_size // 2
            shadow_patch = center_patch.copy()
            for i in range(pad_gap):
                curr_pad_value = 255.0 - float(255.0 - deepest_value) / float(pad_gap) * (i + 1)
                shadow_patch = pad_img(shadow_patch, pad_value=curr_pad_value)

            for i in range(shadow_img_size // 4):
                shadow_patch = pad_img(shadow_patch, pad_value=deepest_value)

            assert shadow_patch.shape[0] == shadow_img_size * 2, shadow_patch.shape[0]
            return shadow_patch

        for transparency_ in range(90, 95 + 1):
            transparency = transparency_ / 100.0
            shadow_full = shadow_generation(transparency)
            shadow_data.append(shadow_full)

        splits = ['train', 'test']

        resolutions = [model_params.image_size_small, model_params.image_size_large]

        for resolution in range(resolutions[0], resolutions[1] + 1):
            for split in splits:
                sketch_mat1_path = os.path.join(dataset_base_dir, base_dir_rough, 'model_pencil1',
                                                'sketch', split, 'res_' + str(resolution) + '.mat')
                photo_mat1_path = os.path.join(dataset_base_dir, base_dir_rough, 'model_pencil1',
                                               'photo', split, 'res_' + str(resolution) + '.mat')
                sketch_data1_uint8 = load_sketch_data(
                    sketch_mat1_path)  # (N, resolution, resolution), [0-strokes, 255-BG]
                photo_data1_uint8 = load_photo_data(photo_mat1_path)  # (N, resolution, resolution), [0-strokes, 255-BG]

                sketch_mat2_path = os.path.join(dataset_base_dir, base_dir_rough, 'model_pencil2',
                                                'sketch', split, 'res_' + str(resolution) + '.mat')
                photo_mat2_path = os.path.join(dataset_base_dir, base_dir_rough, 'model_pencil2',
                                               'photo', split, 'res_' + str(resolution) + '.mat')
                sketch_data2_uint8 = load_sketch_data(
                    sketch_mat2_path)  # (N, resolution, resolution), [0-strokes, 255-BG]
                photo_data2_uint8 = load_photo_data(photo_mat2_path)  # (N, resolution, resolution), [0-strokes, 255-BG]

                sketch_data_uint8 = np.concatenate([sketch_data1_uint8, sketch_data2_uint8],
                                                   axis=0)  # (N, resolution, resolution), [0-strokes, 255-BG]
                photo_data_uint8 = np.concatenate([photo_data1_uint8, photo_data2_uint8],
                                                  axis=0)  # (N, resolution, resolution), [0-strokes, 255-BG]

                if split == 'train':
                    train_photo_data.append(photo_data_uint8)
                    train_sketch_data.append(sketch_data_uint8)
                else:
                    val_photo_data.append(photo_data_uint8)
                    val_sketch_data.append(sketch_data_uint8)

        assert len(train_sketch_data) == len(train_photo_data)
        assert len(val_sketch_data) == len(val_photo_data)
    else:
        raise Exception('Unknown data type: {}'.format(model_params.data_set))

    print('Loaded {}/{} from {}'.format(len(train_sketch_data), len(val_sketch_data), model_params.data_set))
    print('model_params.max_seq_len %i.' % model_params.max_seq_len)

    eval_sample_model_params = copy_hparams(model_params)
    eval_sample_model_params.use_input_dropout = 0
    eval_sample_model_params.use_recurrent_dropout = 0
    eval_sample_model_params.use_output_dropout = 0
    eval_sample_model_params.batch_size = 1  # only sample one at a time
    eval_sample_model_params.model_mode = 'eval_sample'

    train_set = GeneralDataLoaderMultiObjectRough(train_photo_data, train_sketch_data,
                                                  texture_data, shadow_data,
                                                  model_params.batch_size, model_params.raster_size,
                                                  model_params.image_size_small, model_params.image_size_large,
                                                  is_train=True)
    val_set = GeneralDataLoaderMultiObjectRough(val_photo_data, val_sketch_data,
                                                texture_data, shadow_data,
                                                eval_sample_model_params.batch_size,
                                                eval_sample_model_params.raster_size,
                                                eval_sample_model_params.image_size_small,
                                                eval_sample_model_params.image_size_large,
                                                is_train=False)

    result = [
        train_set, val_set, model_params, eval_sample_model_params
    ]
    return result


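# Illustrative sketch (not part of the original repository): the 1-D profile that
# shadow_generation builds. It ignores the corner rounding done by process_angle
# and the uint8 storage of the padded image. Starting from a white square of side
# shadow_img_size // 2, each pad_img call adds a 1-pixel ring. The first
# shadow_img_size // 2 rings ramp linearly from just below 255 down to
# int(255 * transparency), and the next shadow_img_size // 4 rings stay at that
# value. Each ring grows the side length by 2, which is why the assert expects a
# final side of shadow_img_size * 2. The function name is hypothetical.


def shadow_profile_sketch(transparency, shadow_img_size=1024):
    deepest_value = int(255 * transparency)
    pad_gap = shadow_img_size // 2
    # Ring values from the innermost ring outwards, same formula as shadow_generation
    ramp = [255.0 - float(255.0 - deepest_value) / float(pad_gap) * (i + 1)
            for i in range(pad_gap)]
    flat = [float(deepest_value)] * (shadow_img_size // 4)
    final_size = shadow_img_size // 2 + 2 * (len(ramp) + len(flat))
    return ramp + flat, final_size

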
class GeneralDataLoaderNormalImageLinear(object):
    def __init__(self,
                 photo_data,
                 sketch_data,
                 sketch_shape,
                 batch_size,
                 raster_size,
                 image_size_small,
                 image_size_large,
                 random_image_size,
                 flip_prob,
                 rotate_prob,
                 is_train):
        self.batch_size = batch_size  # minibatch size
        self.raster_size = raster_size
        self.image_size_small = image_size_small
        self.image_size_large = image_size_large
        self.random_image_size = random_image_size
        self.is_train = is_train

        assert photo_data is not None
        assert len(photo_data) == len(sketch_data)
        self.num_batches = len(sketch_data) // self.batch_size
        self.batch_idx = -1
        print('batch_size', batch_size, ', num_batches', self.num_batches)

        self.flip_prob = flip_prob
        self.rotate_prob = rotate_prob

        assert type(photo_data) is list
        assert type(sketch_data) is list
        self.photo_data = photo_data
        self.sketch_data = sketch_data
        self.sketch_shape = sketch_shape

    def get_batch_from_memory(self, memory_idx, interpolate_type, fixed_image_size=-1, random_cursor=True,
                              photo_prob=1.0,
                              init_cursor_num=1):
        if self.random_image_size:
            image_size_rand = fixed_image_size
        else:
            image_size_rand = self.image_size_large

        photo_data_batch, sketch_data_batch = self.select_sketch_and_crop(
            image_size_rand, data_idx=memory_idx, photo_prob=photo_prob,
            interpolate_type=interpolate_type)  # sketch_patch: [0.0-stroke, 1.0-BG]

        photo_data_batch = np.expand_dims(photo_data_batch, axis=0)  # (1, image_size, image_size, 3)
        sketch_data_batch = np.expand_dims(sketch_data_batch,
                                           axis=0)  # (1, image_size, image_size), [0.0-strokes, 1.0-BG]
        image_size_rand = sketch_data_batch.shape[1]

        return photo_data_batch, sketch_data_batch, \
               self.gen_init_cursors(sketch_data_batch, random_cursor, init_cursor_num), image_size_rand

    def crop_and_augment(self, photo, sketch, shape, crop_size, rotate_angle, stroke_cover=0.01):
        # img: [0-stroke, 255-BG], uint8

        def angle_convert(angle):
            return angle / 180.0 * math.pi

        img_h, img_w = shape[0], shape[1]

        if self.is_train:
            crop_up = random.randint(0, img_h - crop_size)
            crop_left = random.randint(0, img_w - crop_size)
        else:
            crop_up = (img_h - crop_size) // 2
            crop_left = (img_w - crop_size) // 2

        assert crop_up >= 0
        assert crop_left >= 0

        crop_box = (crop_left, crop_up, crop_left + crop_size, crop_up + crop_size)
        rst_sketch_image = sketch.crop(crop_box)
        rst_photo_image = photo.crop(crop_box)

        if random.random() <= self.flip_prob and self.is_train:
            rst_sketch_image = rst_sketch_image.transpose(Image.FLIP_LEFT_RIGHT)
            rst_photo_image = rst_photo_image.transpose(Image.FLIP_LEFT_RIGHT)

        if rotate_angle != 0 and self.is_train:
            rst_sketch_image = rst_sketch_image.rotate(rotate_angle, resample=Image.BILINEAR)
            rst_photo_image = rst_photo_image.rotate(rotate_angle, resample=Image.BILINEAR)
            rst_sketch_image = np.array(rst_sketch_image, dtype=np.uint8)
            rst_photo_image = np.array(rst_photo_image, dtype=np.uint8)

            center = rst_photo_image.shape[0] // 2

            new_dim = float(crop_size) / (
                        math.sin(angle_convert(abs(rotate_angle))) + math.cos(angle_convert(abs(rotate_angle))))
            new_dim = int(round(new_dim))

            start_pos = center - new_dim // 2
            end_pos = start_pos + new_dim
            rst_sketch_image = rst_sketch_image[start_pos:end_pos, start_pos:end_pos, :]
            rst_photo_image = rst_photo_image[start_pos:end_pos, start_pos:end_pos, :]

        rst_sketch_image = np.array(rst_sketch_image, dtype=np.float32) / 255.0  # [0.0-stroke, 1.0-BG]
        rst_sketch_image = rst_sketch_image[:, :, 0]
        rst_photo_image = np.array(rst_photo_image, dtype=np.float32) / 255.0  # [0.0-stroke, 1.0-BG]

        percentage = np.mean(1.0 - rst_sketch_image)
        valid = True
        if percentage < stroke_cover:
            valid = False

        return rst_photo_image, rst_sketch_image, valid

    def image_interpolation(self, photo, sketch, photo_prob):
        interp_photo = photo * photo_prob + sketch * (1.0 - photo_prob)
        interp_photo = np.clip(interp_photo, 0.0, 1.0)
        return interp_photo

    def select_sketch_and_crop(self, image_size_rand, interpolate_type, rotate_angle=0, photo_prob=1.0,
                               data_idx=-1, trial_times=10):
        if self.is_train:
            while True:
                rand_img_idx = random.randint(0, len(self.sketch_data) - 1)
                selected_sketch_shape = self.sketch_shape[rand_img_idx]
                if selected_sketch_shape[0] >= image_size_rand and selected_sketch_shape[1] >= image_size_rand:
                    img_idx = rand_img_idx
                    break
        else:
            assert data_idx != -1
            img_idx = data_idx

        assert img_idx != -1
        selected_sketch = self.sketch_data[img_idx]
        selected_photo = self.photo_data[img_idx]
        selected_shape = self.sketch_shape[img_idx]

        assert interpolate_type in ['prob', 'image']

        if interpolate_type == 'prob' and random.random() >= photo_prob:
            selected_photo = self.sketch_data[img_idx]

        for trial_i in range(trial_times):
            cropped_photo, cropped_sketch, valid = \
                self.crop_and_augment(selected_photo, selected_sketch, selected_shape, image_size_rand, rotate_angle)
            # cropped_photo, cropped_sketch: [0.0-stroke, 1.0-BG]

            if valid or trial_i == trial_times - 1:
                if interpolate_type == 'image':
                    cropped_photo = self.image_interpolation(cropped_photo,
                                                             np.stack([cropped_sketch for _ in range(3)], axis=-1),
                                                             photo_prob)

                return cropped_photo, cropped_sketch

    def get_batch_multi_res(self, loop_num, interpolate_type, random_cursor=True, init_cursor_num=1, photo_prob=1.0):
        photo_data_batch = []
        sketch_data_batch = []
        init_cursors_batch = []
        image_size_batch = []
        batch_size_per_loop = self.batch_size // loop_num
        for loop_i in range(loop_num):
            if self.random_image_size:
                image_size_rand = random.randint(self.image_size_small, self.image_size_large)
            else:
                image_size_rand = self.image_size_large

            rotate_angle = 0
            if random.random() <= self.rotate_prob:
                rotate_angle = random.randint(-45, 45)

            photo_data_sub_batch = []
            sketch_data_sub_batch = []
            for img_i in range(batch_size_per_loop):
                photo_patch, sketch_patch = \
                    self.select_sketch_and_crop(image_size_rand, rotate_angle=rotate_angle, photo_prob=photo_prob,
                                                interpolate_type=interpolate_type)  # sketch_patch: [0.0-stroke, 1.0-BG]
                photo_data_sub_batch.append(photo_patch)
                sketch_data_sub_batch.append(sketch_patch)

            photo_data_sub_batch = np.stack(photo_data_sub_batch,
                                            axis=0)  # (N, image_size, image_size, 3), [0.0-strokes, 1.0-BG]
            sketch_data_sub_batch = np.stack(sketch_data_sub_batch,
                                             axis=0)  # (N, image_size, image_size), [0.0-strokes, 1.0-BG]
            init_cursors_sub_batch = self.gen_init_cursors(sketch_data_sub_batch, random_cursor, init_cursor_num)

            photo_data_batch.append(photo_data_sub_batch)
            sketch_data_batch.append(sketch_data_sub_batch)
            init_cursors_batch.append(init_cursors_sub_batch)

            image_size_rand = photo_data_sub_batch.shape[1]
            image_size_batch.append(image_size_rand)

        return photo_data_batch, sketch_data_batch, init_cursors_batch, image_size_batch

    def crop_patch(self, image, center, image_size, crop_size):
        x0 = center[0] - crop_size // 2
        x1 = x0 + crop_size
        y0 = center[1] - crop_size // 2
        y1 = y0 + crop_size
        x0 = max(0, min(x0, image_size))
        y0 = max(0, min(y0, image_size))
        x1 = max(0, min(x1, image_size))
        y1 = max(0, min(y1, image_size))
        patch = image[y0:y1, x0:x1]
        return patch

    def gen_init_cursor_single(self, sketch_image):
        # sketch_image: [0.0-stroke, 1.0-BG]
        image_size = sketch_image.shape[0]
        if np.sum(1.0 - sketch_image) == 0:
            center = np.zeros((2), dtype=np.int32)
            return center
        else:
            while True:
                center = np.random.randint(0, image_size, size=(2))  # (2), in large size
                patch = 1.0 - self.crop_patch(sketch_image, center, image_size, self.raster_size)
                if np.sum(patch) != 0:
                    return center.astype(np.float32) / float(image_size)  # (2), in size [0.0, 1.0)

    def gen_init_cursors(self, sketch_data, random_pos=True, init_cursor_num=1):
        init_cursor_batch_list = []
        for cursor_i in range(init_cursor_num):
            if random_pos:
                init_cursor_batch = []
                for i in range(len(sketch_data)):
                    sketch_image = sketch_data[i].copy().astype(np.float32)  # [0.0-stroke, 1.0-BG]
                    center = self.gen_init_cursor_single(sketch_image)
                    init_cursor_batch.append(center)

                init_cursor_batch = np.stack(init_cursor_batch, axis=0)  # (N, 2)
            else:
                raise NotImplementedError('Only random initial cursor positions are implemented')
            init_cursor_batch_list.append(init_cursor_batch)

        if init_cursor_num == 1:
            init_cursor_batch = init_cursor_batch_list[0]
            init_cursor_batch = np.expand_dims(init_cursor_batch, axis=1).astype(np.float32)  # (N, 1, 2)
        else:
            init_cursor_batch = np.stack(init_cursor_batch_list, axis=1)  # (N, init_cursor_num, 2)
            init_cursor_batch = np.expand_dims(init_cursor_batch, axis=2).astype(
                np.float32)  # (N, init_cursor_num, 1, 2)

        return init_cursor_batch


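# Illustrative sketch (not part of the original repository): the size rule used by
# the rotation branch of crop_and_augment. PIL's rotate keeps the canvas at
# crop_size, so the rotated content leaves empty corners. An axis-aligned square
# of side a lies inside the rotated crop exactly when, viewed in the crop's own
# frame, its bounding box of side a * (|sin(theta)| + |cos(theta)|) fits within
# crop_size. The largest such square therefore has
# a = crop_size / (|sin(theta)| + |cos(theta)|). The function name is hypothetical.
import math


def inscribed_square_size(crop_size, rotate_angle):
    theta = math.radians(abs(rotate_angle))  # degrees, as passed to crop_and_augment
    return int(round(float(crop_size) / (math.sin(theta) + math.cos(theta))))

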
def load_dataset_normal_images(dataset_base_dir, model_params):
    train_photo_data = []
    train_sketch_data = []
    train_data_shape = []
    val_photo_data = []
    val_sketch_data = []
    val_data_shape = []

    if model_params.data_set == 'faces':
        random_training_image_size = False
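        # Negative probabilities disable flip and rotation augmentation for faces,
        # because random.random() is never <= -0.1.
        flip_prob = -0.1
        rotate_prob = -0.1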

        splits = ['train', 'val']

        database = os.path.join(dataset_base_dir, 'CelebAMask-faces')
        photo_base = os.path.join(database, 'CelebA-HQ-img256')
        edge_base = os.path.join(database, 'CelebAMask-HQ-edge256')

        train_split_txt_save_path = os.path.join(database, 'train.txt')
        val_split_txt_save_path = os.path.join(database, 'val.txt')
        celeba_train_txt = np.loadtxt(train_split_txt_save_path, dtype=str)
        celeba_val_txt = np.loadtxt(val_split_txt_save_path, dtype=str)
        splits_indices_map = {'train': celeba_train_txt, 'val': celeba_val_txt}

        for split in splits:
            split_indices = splits_indices_map[split]

            for i in range(len(split_indices)):
                file_idx = split_indices[i]
                img_file_path = os.path.join(photo_base, str(file_idx) + '.jpg')
                edge_img_path = os.path.join(edge_base, str(file_idx) + '.png')

                img_data = Image.open(img_file_path).convert('RGB')
                edge_data = Image.open(edge_img_path).convert('RGB')

                if split == 'train':
                    train_photo_data.append(img_data)
                    train_sketch_data.append(edge_data)
                    train_data_shape.append((img_data.height, img_data.width))
                else:  # split == 'val'
                    val_photo_data.append(img_data)
                    val_sketch_data.append(edge_data)
                    val_data_shape.append((img_data.height, img_data.width))

        assert len(train_sketch_data) == len(train_data_shape) == len(train_photo_data)
        assert len(val_sketch_data) == len(val_data_shape) == len(val_photo_data)
    else:
        raise Exception('Unknown data type: {}'.format(model_params.data_set))

    print('Loaded {}/{} from {}'.format(len(train_sketch_data), len(val_sketch_data), model_params.data_set))
    print('model_params.max_seq_len %i.' % model_params.max_seq_len)

    eval_sample_model_params = copy_hparams(model_params)
    eval_sample_model_params.use_input_dropout = 0
    eval_sample_model_params.use_recurrent_dropout = 0
    eval_sample_model_params.use_output_dropout = 0
    eval_sample_model_params.batch_size = 1  # only sample one at a time
    eval_sample_model_params.model_mode = 'eval_sample'

    train_set = GeneralDataLoaderNormalImageLinear(train_photo_data, train_sketch_data, train_data_shape,
                                                   model_params.batch_size, model_params.raster_size,
                                                   image_size_small=model_params.image_size_small,
                                                   image_size_large=model_params.image_size_large,
                                                   random_image_size=random_training_image_size,
                                                   flip_prob=flip_prob, rotate_prob=rotate_prob,
                                                   is_train=True)
    val_set = GeneralDataLoaderNormalImageLinear(val_photo_data, val_sketch_data, val_data_shape,
                                                 eval_sample_model_params.batch_size,
                                                 eval_sample_model_params.raster_size,
                                                 image_size_small=eval_sample_model_params.image_size_small,
                                                 image_size_large=eval_sample_model_params.image_size_large,
                                                 random_image_size=random_training_image_size,
                                                 flip_prob=flip_prob, rotate_prob=rotate_prob,
                                                 is_train=False)

    result = [
        train_set, val_set, model_params, eval_sample_model_params
    ]
    return result


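# Illustrative sketch (not part of the original repository): the two ways
# GeneralDataLoaderNormalImageLinear.select_sketch_and_crop mixes a photo with its
# target sketch. With interpolate_type='prob', the whole input is replaced by the
# sketch with probability 1 - photo_prob. With 'image', the input is the
# pixel-wise blend photo_prob * photo + (1 - photo_prob) * sketch, clipped to
# [0.0, 1.0]. The original swaps before cropping and blends after cropping; this
# sketch assumes already-cropped arrays. The function name and the injected
# `rand` argument (standing in for random.random()) are hypothetical.
import numpy as np


def mix_photo_and_sketch(photo, sketch_rgb, photo_prob, interpolate_type, rand):
    # photo, sketch_rgb: (H, W, 3) float arrays in [0.0, 1.0]; rand in [0.0, 1.0)
    assert interpolate_type in ['prob', 'image']
    if interpolate_type == 'prob':
        return sketch_rgb if rand >= photo_prob else photo
    return np.clip(photo * photo_prob + sketch_rgb * (1.0 - photo_prob), 0.0, 1.0)

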
def load_dataset_training(dataset_base_dir, model_params):
    if model_params.data_set == 'clean_line_drawings':
        return load_dataset_multi_object(dataset_base_dir, model_params)
    elif model_params.data_set == 'rough_sketches':
        return load_dataset_multi_object_rough(dataset_base_dir, model_params)
    elif model_params.data_set == 'faces':
        return load_dataset_normal_images(dataset_base_dir, model_params)
    else:
        raise Exception('Unknown data_set: {}'.format(model_params.data_set))


================================================
FILE: docs/assets/font.css
================================================
/* Homepage Font */

/* latin-ext */
@font-face {
  font-family: 'Lato';
  font-style: normal;
  font-weight: 400;
  src: local('Lato Regular'), local('Lato-Regular'), url(https://fonts.gstatic.com/s/lato/v16/S6uyw4BMUTPHjxAwXjeu.woff2) format('woff2');
  unicode-range: U+0100-024F, U+0259, U+1E00-1EFF, U+2020, U+20A0-20AB, U+20AD-20CF, U+2113, U+2C60-2C7F, U+A720-A7FF;
}

/* latin */
@font-face {
  font-family: 'Lato';
  font-style: normal;
  font-weight: 400;
  src: local('Lato Regular'), local('Lato-Regular'), url(https://fonts.gstatic.com/s/lato/v16/S6uyw4BMUTPHjx4wXg.woff2) format('woff2');
  unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+2000-206F, U+2074, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD;
}

/* latin-ext */
@font-face {
  font-family: 'Lato';
  font-style: normal;
  font-weight: 700;
  src: local('Lato Bold'), local('Lato-Bold'), url(https://fonts.gstatic.com/s/lato/v16/S6u9w4BMUTPHh6UVSwaPGR_p.woff2) format('woff2');
  unicode-range: U+0100-024F, U+0259, U+1E00-1EFF, U+2020, U+20A0-20AB, U+20AD-20CF, U+2113, U+2C60-2C7F, U+A720-A7FF;
}

/* latin */
@font-face {
  font-family: 'Lato';
  font-style: normal;
  font-weight: 700;
  src: local('Lato Bold'), local('Lato-Bold'), url(https://fonts.gstatic.com/s/lato/v16/S6u9w4BMUTPHh6UVSwiPGQ.woff2) format('woff2');
  unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+2000-206F, U+2074, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD;
}


================================================
FILE: docs/assets/style.css
================================================
/* Body */
body {
  background: #e3e5e8;
  color: #ffffff;
  font-family: 'Lato', Verdana, Helvetica, sans-serif;
  font-weight: 300;
  font-size: 14pt;
}

/* Hyperlinks */
a {text-decoration: none;}
a:link {color: #1772d0;}
a:visited {color: #1772d0;}
a:active {color: red;}
a:hover {color: #f09228;}

/* Pre-formatted Text */
pre {
  margin: 5pt 0;
  border: 0;
  font-size: 12pt;
  background: #fcfcfc;
}

/* Project Page Style */
/* Section */
.section {
  width: 768pt;
  min-height: 100pt;
  margin: 15pt auto;
  padding: 20pt 30pt;
  border: 1pt hidden #000;
  text-align: justify;
  color: #000000;
  background: #ffffff;
}

/* Header (Title and Logo) */
.section .header {
  min-height: 80pt;
  margin-top: 30pt;
}
.section .header .logo {
  width: 80pt;
  margin-left: 10pt;
  float: left;
}
.section .header .logo img {
  width: 80pt;
  object-fit: cover;
}
.section .header .title {
  margin: 0 120pt;
  text-align: center;
  font-size: 22pt;
}

/* Author */
.section .author {
  margin: 5pt 0;
  text-align: center;
  font-size: 16pt;
}

/* Institution */
.section .institution {
  margin: 5pt 0;
  text-align: center;
  font-size: 16pt;
}

/* Hyperlink (such as Paper and Code) */
.section .link {
  margin: 5pt 0;
  text-align: center;
  font-size: 16pt;
}

/* Teaser */
.section .teaser {
  margin: 20pt 0;
  text-align: left;
}
.section .teaser img {
  width: 95%;
}

/* Section Title */
.section .title {
  text-align: center;
  font-size: 22pt;
  margin: 5pt 0 15pt 0;  /* top right bottom left */
}

/* Section Body */
.section .body {
  margin-bottom: 15pt;
  text-align: justify;
  font-size: 14pt;
}

/* BibTeX */
.section .bibtex {
  margin: 5pt 0;
  text-align: left;
  font-size: 22pt;
}

/* Related Work */
.section .ref {
  margin: 20pt 0 10pt 0;  /* top right bottom left */
  text-align: left;
  font-size: 18pt;
  font-weight: bold;
}

/* Citation */
.section .citation {
  min-height: 60pt;
  margin: 10pt 0;
}
.section .citation .image {
  width: 120pt;
  float: left;
}
.section .citation .image img {
  max-height: 60pt;
  width: 120pt;
  object-fit: cover;
}
.section .citation .comment {
  margin-left: 0pt;
  text-align: left;
  font-size: 14pt;
}


================================================
FILE: docs/index.html
================================================
<!doctype html>
<html lang="en">


<!-- === Header Starts === -->
<head>
  <meta http-equiv="Content-Type" content="text/html; charset=UTF-8">

  <title>General Virtual Sketching Framework for Vector Line Art</title>

  <link href="./assets/bootstrap.min.css" rel="stylesheet">
  <link href="./assets/font.css" rel="stylesheet" type="text/css">
  <link href="./assets/style.css" rel="stylesheet" type="text/css">
</head>
<!-- === Header Ends === -->


<body>


<!-- === Home Section Starts === -->
<div class="section">
  <!-- === Title Starts === -->
    <div class="title">
      <b>General Virtual Sketching Framework for Vector Line Art</b>
    </div>
  <!-- === Title Ends === -->
  <div class="author">
    <a href="http://mo-haoran.com/" target="_blank">Haoran Mo</a><sup>1</sup>,&nbsp;
    <a href="https://esslab.jp/~ess/en/" target="_blank">Edgar Simo-Serra</a><sup>2</sup>,&nbsp;
    <a href="http://cse.sysu.edu.cn/content/2537" target="_blank">Chengying Gao</a><sup>*1</sup>,&nbsp;
    <a href="https://changqingzou.weebly.com/" target="_blank">Changqing Zou</a><sup>3</sup>,&nbsp;
    <a href="http://cse.sysu.edu.cn/content/2523" target="_blank">Ruomei Wang</a><sup>1</sup>
  </div>
  <div class="institution">
    <sup>1</sup>Sun Yat-sen University,&nbsp;
    <sup>2</sup>Waseda University,&nbsp;
    <br>
    <sup>3</sup>Huawei Technologies Canada
  </div>
  <br>
  <div class="institution">
    Accepted by <a href="https://s2021.siggraph.org/" target="_blank">ACM SIGGRAPH 2021</a>
  </div>
  <div class="link">
    <a href="https://esslab.jp/publications/HaoranSIGRAPH2021.pdf" target="_blank">[Paper]</a>&nbsp;
    <a href="https://github.com/MarkMoHR/virtual_sketching" target="_blank">[Code]</a>
  </div>
  <div class="teaser">
    <img src="https://cdn.jsdelivr.net/gh/mark-cdn/CDN-for-works@1.4/files/SIG21/teaser6.png" style="width: 100%;">
    <br>
    <br>
    <font size="3">
      Given clean line drawings, rough sketches or photographs of arbitrary resolution as input, our framework generates the corresponding vector line drawings directly. As shown in (b), the framework models a virtual pen surrounded by a dynamic window (red boxes), which moves while drawing the strokes. It learns to move around by scaling the window and sliding to an undrawn area to restart drawing (bottom example; the sliding trajectory is shown as a blue arrow). With our proposed stroke regularization mechanism, the framework can enlarge the window and draw long strokes, which simplifies the result (top example).
    </font>
  </div>
</div>
<!-- === Home Section Ends === -->


<!-- === Overview Section Starts === -->
<div class="section">
  <div class="title">Abstract</div>
  <div class="body">
    Vector line art plays an important role in graphic design; however, it is tedious to create manually.
    We introduce a general framework to produce line drawings from a wide variety of images,
    by learning a mapping from raster image space to vector image space.
    Our approach is based on a recurrent neural network that draws the lines one by one.
    A differentiable rasterization module allows for training with only supervised raster data.
    We use a dynamic window around a virtual pen while drawing lines,
    implemented with the proposed aligned cropping and differentiable pasting modules.
    Furthermore, we develop a stroke regularization loss
    that encourages the model to use fewer and longer strokes to simplify the resulting vector image.
    Ablation studies and comparisons with existing methods corroborate the efficiency of our approach,
    which is able to generate visually better results in less computation time,
    while generalizing better to a diversity of images and applications.
  </div>
  <div class="link">
    <a href="https://esslab.jp/publications/HaoranSIGRAPH2021.pdf" target="_blank">[Paper]</a>&nbsp; &nbsp;
    <a href="https://dl.acm.org/doi/abs/10.1145/3450626.3459833" target="_blank">[Paper (ACM)]</a>&nbsp; &nbsp;
    <a href="https://markmohr.github.io/files/SIG2021/SketchVectorization_SIG2021_supplemental.pdf" target="_blank">[Supplementary]</a>&nbsp; &nbsp;
	  <a href="https://github.com/MarkMoHR/virtual_sketching" target="_blank">[Code]</a>&nbsp; &nbsp;
    <a href="https://drive.google.com/drive/folders/1-hi2cl8joZ6oMOp4yvk_hObJGAK6ELHB?usp=sharing" target="_blank">[All Precomputed Results]</a>
	  <!-- <a href="" target="_blank">[Presentation (TBD)]</a>&nbsp; &nbsp; -->
  </div>
</div>
<!-- === Overview Section Ends === -->


<!-- === Result Section Starts === -->
<div class="section">
  <div class="title">Method</div>
  <br>
  <div class="body">
    <p style="text-align:center; font-size:23px; font-weight:bold">Framework Overview</p>
    <img src="https://cdn.jsdelivr.net/gh/mark-cdn/CDN-for-works@1.4/files/SIG21/framework6.png" width="100%">
    <br>
    <br>
    <font size="4">
      Our framework generates the parametrized strokes step by step in a recurrent manner.
      It uses a dynamic window (dashed red boxes) around a virtual pen to draw the strokes,
      and can both move and change the size of the window.
      (a) Four main modules at each time step: aligned cropping, stroke generation, differentiable rendering and differentiable pasting.
      (b) Architecture of the stroke generation module.
      (c) Structural strokes predicted at each step;
      blue arrows indicate movement-only steps, during which no stroke is drawn on the canvas.
    </font>
    <br>
    <br>

    <p style="text-align:center; font-size:23px; font-weight:bold">
      Overall Introduction
    </p>
    <p style="text-align:center; font-size:20px">
      (Or watch on <a href="https://www.bilibili.com/video/BV1gM4y1V7i7/" target="_blank">Bilibili</a>)
      <br>
      👇
    </p>
    <!-- Adjust the frame size based on the demo (EVERY project differs). -->
    <div style="position: relative; padding-top: 50%; text-align: center;">
      <iframe src="https://www.youtube.com/embed/gXk3TMceByY" frameborder=0
              style="position: absolute; top: 1%; left: 5%; width: 90%; height: 100%;"
              allow="accelerometer; autoplay; encrypted-media; gyroscope; picture-in-picture"
              allowfullscreen></iframe>
    </div>

  </div>
</div>
<!-- === Result Section Ends === -->

<!-- === Result Section Starts === -->
<div class="section">
  <div class="title">Results</div>
  <div class="body">
    Our framework applies to a wide variety of image types, such as clean line drawings, rough sketches and photographs.

    <p style="margin-top: 10pt; text-align:center; font-size:23px; font-weight:bold">Vectorization</p>
    <table width="100%" style="margin: 0pt auto; text-align: center; border-collapse: separate; border-spacing: 5pt;">
      <tr>
        <td width="45%"><img src="https://cdn.jsdelivr.net/gh/mark-cdn/CDN-for-works@1.4/files/SIG21/gifs/clean/muten.png" width="100%"></td>
        <td width="10%"></td>
        <td width="45%"><img src="https://cdn.jsdelivr.net/gh/mark-cdn/CDN-for-works@1.4/files/SIG21/gifs/clean/muten-black-full-simplest.gif" width="100%"></td>
      </tr>
    </table>
    <br>

    <p style="margin-top: 10pt; text-align:center; font-size:23px; font-weight:bold">Rough sketch simplification</p>
    <table width="100%" style="margin: 0pt auto; text-align: center; border-collapse: separate; border-spacing: 5pt;">
      <tr>
        <td width="26%"><img src="https://cdn.jsdelivr.net/gh/mark-cdn/CDN-for-works@1.4/files/SIG21/gifs/rough/rocket.png" width="100%"></td>
        <td width="26%"><img src="https://cdn.jsdelivr.net/gh/mark-cdn/CDN-for-works@1.4/files/SIG21/gifs/rough/rocket-blue-simplest.gif" width="100%"></td>
        <td width="4%"></td>
        <td width="14%"><img src="https://cdn.jsdelivr.net/gh/mark-cdn/CDN-for-works@1.4/files/SIG21/gifs/rough/penguin.png" width="100%"></td>
        <td width="14%"><img src="https://cdn.jsdelivr.net/gh/mark-cdn/CDN-for-works@1.4/files/SIG21/gifs/rough/penguin-blue-simplest.gif" width="100%"></td>
      </tr>
    </table>
    <br>

    <p style="margin-top: 10pt; text-align:center; font-size:23px; font-weight:bold">Photograph to line drawing</p>
    <table width="100%" style="margin: 0pt auto; text-align: center; border-collapse: separate; border-spacing: 5pt;">
      <tr>
        <td width="23%"><img src="https://cdn.jsdelivr.net/gh/mark-cdn/CDN-for-works@1.4/files/SIG21/gifs/face/1390_input.png" width="100%"></td>
        <td width="23%"><img src="https://cdn.jsdelivr.net/gh/mark-cdn/CDN-for-works@1.4/files/SIG21/gifs/face/face-blue-1390-simplest.gif" width="100%"></td>
        <td width="8%"></td>
        <td width="23%"><img src="https://cdn.jsdelivr.net/gh/mark-cdn/CDN-for-works@1.4/files/SIG21/gifs/face/1190_input.png" width="100%"></td>
        <td width="23%"><img src="https://cdn.jsdelivr.net/gh/mark-cdn/CDN-for-works@1.4/files/SIG21/gifs/face/face-blue-1190-simplest.gif" width="100%"></td>
      </tr>
    </table>
    <br>

    <p style="margin-top: 10pt; text-align:center; font-size:23px; font-weight:bold">
      More Results
    </p>
    <p style="text-align:center; font-size:20px">
      (Or watch on <a href="https://www.bilibili.com/video/BV1pv411N7Yx/" target="_blank">Bilibili</a>)
      <br>
      👇
    </p>
    <!-- Adjust the frame size based on the demo (EVERY project differs). -->
    <div style="position: relative; padding-top: 50%; text-align: center;">
      <iframe src="https://www.youtube.com/embed/Pr6mK9ddXkQ" frameborder=0
              style="position: absolute; top: 1%; left: 5%; width: 90%; height: 100%;"
              allow="accelerometer; autoplay; encrypted-media; gyroscope; picture-in-picture"
              allowfullscreen></iframe>
    </div>
    <br>

    <div class="link">
      <a href="https://drive.google.com/drive/folders/1-hi2cl8joZ6oMOp4yvk_hObJGAK6ELHB?usp=sharing" target="_blank">
      [Download Our Precomputed Output Results (7MB)]</a>
    </div>

  </div>
</div>
<!-- === Result Section Ends === -->


<!-- === Result Section Starts === -->
<div class="section">
  <div class="title">Presentations</div>
  <div class="body">

    <p style="margin-top: 10pt; text-align:center; font-size:23px; font-weight:bold">
      3-5 minute presentation
    </p>
    <p style="text-align:center; font-size:20px">
      (Or watch on <a href="https://www.bilibili.com/video/BV1S3411q7VX/" target="_blank">Bilibili</a>)
      <br>
      👇
    </p>
    <!-- Adjust the frame size based on the demo (EVERY project differs). -->
    <div style="position: relative; padding-top: 50%; text-align: center;">
      <iframe src="https://www.youtube.com/embed/BSJN1ixacts" frameborder=0
              style="position: absolute; top: 1%; left: 5%; width: 90%; height: 100%;"
              allow="accelerometer; autoplay; encrypted-media; gyroscope; picture-in-picture"
              allowfullscreen></iframe>
    </div>
    <br>

    <div class="link">
      👉 15-20 minute presentation:
      <a href="https://youtu.be/D_U4e1qh5qc" target="_blank">[YouTube]</a>
      <a href="https://www.bilibili.com/video/BV1uU4y1E7Wg/" target="_blank">[Bilibili]</a>
    </div>

    <div class="link">
      👉 30-second fast forward:
      <a href="https://youtu.be/d0EbSU_EeFg" target="_blank">[YouTube]</a>
      <a href="https://www.bilibili.com/video/BV1vq4y1M7j1/" target="_blank">[Bilibili]</a>
    </div>

  </div>
</div>
<!-- === Result Section Ends === -->


<!-- === Reference Section Starts === -->
<div class="section">
  <div class="bibtex">BibTeX</div>
<pre>
@article{mo2021virtualsketching,
    title   = {General Virtual Sketching Framework for Vector Line Art},
    author  = {Mo, Haoran and Simo-Serra, Edgar and Gao, Chengying and Zou, Changqing and Wang, Ruomei},
    journal = {ACM Transactions on Graphics (Proceedings of ACM SIGGRAPH 2021)},
    year    = {2021},
    volume  = {40},
    number  = {4},
    pages   = {51:1--51:14}
}
</pre>

  <br>
  <div class="bibtex">Related Work</div>
  <div class="citation">
    <div class="comment">
      Jean-Dominique Favreau, Florent Lafarge and Adrien Bousseau.
      <strong>Fidelity vs. Simplicity: a Global Approach to Line Drawing Vectorization</strong>. SIGGRAPH 2016.
      [<a href="https://www-sop.inria.fr/reves/Basilic/2016/FLB16/fidelity_simplicity.pdf">Paper</a>]
      [<a href="https://www-sop.inria.fr/reves/Basilic/2016/FLB16/">Webpage</a>]
      <br><br>
    </div>

    <div class="comment">
      Mikhail Bessmeltsev and Justin Solomon. 
      <strong>Vectorization of Line Drawings via PolyVector Fields</strong>. SIGGRAPH 2019. 
      [<a href="https://arxiv.org/abs/1801.01922">Paper</a>]
      [<a href="https://github.com/bmpix/PolyVectorization">Code</a>]
      <br><br>
    </div>

    <div class="comment">
      Edgar Simo-Serra, Satoshi Iizuka and Hiroshi Ishikawa. 
      <strong>Mastering Sketching: Adversarial Augmentation for Structured Prediction</strong>. SIGGRAPH 2018. 
      [<a href="https://esslab.jp/~ess/publications/SimoSerraTOG2018.pdf">Paper</a>]
      [<a href="https://esslab.jp/~ess/en/research/sketch_master/">Webpage</a>]
      [<a href="https://github.com/bobbens/sketch_simplification">Code</a>]
      <br><br>
    </div>

    <div class="comment">
      Zhewei Huang, Wen Heng and Shuchang Zhou. 
      <strong>Learning to Paint With Model-based Deep Reinforcement Learning</strong>. ICCV 2019. 
      [<a href="https://openaccess.thecvf.com/content_ICCV_2019/papers/Huang_Learning_to_Paint_With_Model-Based_Deep_Reinforcement_Learning_ICCV_2019_paper.pdf">Paper</a>]
      [<a href="https://github.com/megvii-research/ICCV2019-LearningToPaint">Code</a>]
      <br><br>
    </div>
  </div>
</div>
<!-- === Reference Section Ends === -->


</body>
</html>


================================================
FILE: hyper_parameters.py
================================================
import tensorflow as tf


#############################################
# Common parameters
#############################################

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string(
    'dataset_dir',
    'datasets',
    'The directory of sketch data of the dataset.')
tf.app.flags.DEFINE_string(
    'log_root',
    'outputs/log',
    'Directory to store tensorboard.')
tf.app.flags.DEFINE_string(
    'log_img_root',
    'outputs/log_img',
    'Directory to store intermediate output images.')
tf.app.flags.DEFINE_string(
    'snapshot_root',
    'outputs/snapshot',
    'Directory to store model checkpoints.')
tf.app.flags.DEFINE_string(
    'neural_renderer_path',
    'outputs/snapshot/pretrain_neural_renderer/renderer_300000.tfmodel',
    'Path to the neural renderer model.')
tf.app.flags.DEFINE_string(
    'perceptual_model_root',
    'outputs/snapshot/pretrain_perceptual_model',
    'Directory to store perceptual model.')
tf.app.flags.DEFINE_string(
    'data',
    '',
    'The dataset type.')


def get_default_hparams_clean():
    """Return default HParams for vectorizing clean line drawings."""
    hparams = tf.contrib.training.HParams(
        program_name='new_train_clean_line_drawings',
        data_set='clean_line_drawings',  # Our dataset.

        input_channel=1,

        num_steps=75040,  # Total number of steps of training.
        save_every=75000,
        eval_every=5000,

        max_seq_len=48,
        batch_size=20,
        gpus=[0, 1],
        loop_per_gpu=1,

        sn_loss_type='increasing',  # ['decreasing', 'fixed', 'increasing']
        stroke_num_loss_weight=0.02,
        stroke_num_loss_weight_end=0.0,
        increase_start_steps=25000,
        decrease_stop_steps=40000,

        perc_loss_layers=['ReLU1_2', 'ReLU2_2', 'ReLU3_3', 'ReLU5_1'],
        perc_loss_fuse_type='add',  # ['max', 'add', 'raw_add', 'weighted_sum']

        init_cursor_on_undrawn_pixel=False,

        early_pen_loss_type='move',  # ['head', 'tail', 'move']
        early_pen_loss_weight=0.1,
        early_pen_length=7,

        min_width=0.01,
        min_window_size=32,
        max_scaling=2.0,

        encode_cursor_type='value',

        image_size_small=128,
        image_size_large=278,

        cropping_type='v3',  # ['v2', 'v3']
        pasting_type='v3',  # ['v2', 'v3']
        pasting_diff=True,

        concat_win_size=True,

        encoder_type='conv13_c3',
        # ['conv10', 'conv10_deep', 'conv13', 'conv10_c3', 'conv10_deep_c3', 'conv13_c3']
        # ['conv13_c3_attn']
        # ['combine33', 'combine43', 'combine53', 'combineFC']
        vary_thickness=False,

        outside_loss_weight=10.0,
        win_size_outside_loss_weight=10.0,

        resize_method='AREA',  # ['BILINEAR', 'NEAREST_NEIGHBOR', 'BICUBIC', 'AREA']

        concat_cursor=True,

        use_softargmax=True,
        soft_beta=10,  # value for the soft argmax

        raster_loss_weight=1.0,

        dec_rnn_size=256,  # Size of decoder.
        dec_model='hyper',  # Decoder: lstm, layer_norm or hyper.
        # z_size=128,  # Size of latent vector z. Recommend 32, 64 or 128.
        bin_gt=True,

        stop_accu_grad=True,

        random_cursor=True,
        cursor_type='next',

        raster_size=128,

        pix_drop_kp=1.0,  # Dropout keep rate
        add_coordconv=True,
        position_format='abs',
        raster_loss_base_type='perceptual',  # [l1, mse, perceptual]

        grad_clip=1.0,  # Gradient clipping. Recommend leaving at 1.0.

        learning_rate=0.0001,  # Learning rate.
        decay_rate=0.9999,  # Learning rate decay per minibatch.
        decay_power=0.9,
        min_learning_rate=0.000001,  # Minimum learning rate.

        use_recurrent_dropout=True,  # Recurrent dropout without memory loss. Recommended.
        recurrent_dropout_prob=0.90,  # Keep probability for recurrent dropout.
        use_input_dropout=False,  # Input dropout. Recommend leaving False.
        input_dropout_prob=0.90,  # Keep probability for input dropout.
        use_output_dropout=False,  # Output dropout. Recommend leaving False.
        output_dropout_prob=0.90,  # Keep probability for output dropout.

        model_mode='train'  # ['train', 'eval', 'sample']
    )
    return hparams


def get_default_hparams_rough():
    """Return default HParams for rough sketch simplification."""
    hparams = tf.contrib.training.HParams(
        program_name='new_train_rough_sketches',
        data_set='rough_sketches',  # ['rough_sketches', 'faces']

        input_channel=3,

        num_steps=90040,  # Total number of steps of training.
        save_every=90000,
        eval_every=5000,

        max_seq_len=48,
        batch_size=20,
        gpus=[0, 1],
        loop_per_gpu=1,

        sn_loss_type='increasing',  # ['decreasing', 'fixed', 'increasing']
        stroke_num_loss_weight=0.1,
        stroke_num_loss_weight_end=0.0,
        increase_start_steps=25000,
        decrease_stop_steps=40000,

        photo_prob_type='one',  # ['increasing', 'zero', 'one']
        photo_prob_start_step=35000,

        perc_loss_layers=['ReLU2_2', 'ReLU3_3', 'ReLU5_1'],
        perc_loss_fuse_type='add',  # ['max', 'add', 'raw_add', 'weighted_sum']

        early_pen_loss_type='move',  # ['head', 'tail', 'move']
        early_pen_loss_weight=0.2,
        early_pen_length=7,

        min_width=0.01,
        min_window_size=32,
        max_scaling=2.0,

        encode_cursor_type='value',

        image_size_small=128,
        image_size_large=278,

        cropping_type='v3',  # ['v2', 'v3']
        pasting_type='v3',  # ['v2', 'v3']
        pasting_diff=True,

        concat_win_size=True,

        encoder_type='conv13_c3',
        # ['conv10', 'conv10_deep', 'conv13', 'conv10_c3', 'conv10_deep_c3', 'conv13_c3']
        # ['conv13_c3_attn']
        # ['combine33', 'combine43', 'combine53', 'combineFC']

        outside_loss_weight=10.0,
        win_size_outside_loss_weight=10.0,

        resize_method='AREA',  # ['BILINEAR', 'NEAREST_NEIGHBOR', 'BICUBIC', 'AREA']

        concat_cursor=True,

        use_softargmax=True,
        soft_beta=10,  # value for the soft argmax

        raster_loss_weight=1.0,

        dec_rnn_size=256,  # Size of decoder.
        dec_model='hyper',  # Decoder: lstm, layer_norm or hyper.
        # z_size=128,  # Size of latent vector z. Recommend 32, 64 or 128.
        bin_gt=True,

        stop_accu_grad=True,

        random_cursor=True,
        cursor_type='next',

        raster_size=128,

        pix_drop_kp=1.0,  # Dropout keep rate
        add_coordconv=True,
        position_format='abs',
        raster_loss_base_type='perceptual',  # [l1, mse, perceptual]

        grad_clip=1.0,  # Gradient clipping. Recommend leaving at 1.0.

        learning_rate=0.0001,  # Learning rate.
        decay_rate=0.9999,  # Learning rate decay per minibatch.
        decay_power=0.9,
        min_learning_rate=0.000001,  # Minimum learning rate.

        use_recurrent_dropout=True,  # Recurrent dropout without memory loss. Recommended.
        recurrent_dropout_prob=0.90,  # Keep probability for recurrent dropout.
        use_input_dropout=False,  # Input dropout. Recommend leaving False.
        input_dropout_prob=0.90,  # Keep probability for input dropout.
        use_output_dropout=False,  # Output dropout. Recommend leaving False.
        output_dropout_prob=0.90,  # Keep probability for output dropout.

        model_mode='train'  # ['train', 'eval', 'sample']
    )
    return hparams


def get_default_hparams_normal():
    """Return default HParams for photograph-to-line-drawing conversion (faces)."""
    hparams = tf.contrib.training.HParams(
        program_name='new_train_faces',
        data_set='faces',  # ['rough_sketches', 'faces']

        input_channel=3,

        num_steps=90040,  # Total number of steps of training.
        save_every=90000,
        eval_every=5000,

        max_seq_len=48,
        batch_size=20,
        gpus=[0, 1],
        loop_per_gpu=1,

        sn_loss_type='fixed',  # ['decreasing', 'fixed', 'increasing']
        stroke_num_loss_weight=0.0,
        stroke_num_loss_weight_end=0.0,
        increase_start_steps=0,
        decrease_stop_steps=40000,

        photo_prob_type='interpolate',  # ['increasing', 'zero', 'one', 'interpolate']
        photo_prob_start_step=30000,
        photo_prob_end_step=60000,

        perc_loss_layers=['ReLU2_2', 'ReLU3_3', 'ReLU4_2', 'ReLU5_1'],
        perc_loss_fuse_type='add',  # ['max', 'add', 'raw_add', 'weighted_sum']

        early_pen_loss_type='move',  # ['head', 'tail', 'move']
        early_pen_loss_weight=0.2,
        early_pen_length=7,

        min_width=0.01,
        min_window_size=32,
        max_scaling=2.0,

        encode_cursor_type='value',

        image_size_small=128,
        image_size_large=256,

        cropping_type='v3',  # ['v2', 'v3']
        pasting_type='v3',  # ['v2', 'v3']
        pasting_diff=True,

        concat_win_size=True,

        encoder_type='conv13_c3',
        # ['conv10', 'conv10_deep', 'conv13', 'conv10_c3', 'conv10_deep_c3', 'conv13_c3']
        # ['conv13_c3_attn']
        # ['combine33', 'combine43', 'combine53', 'combineFC']

        outside_loss_weight=10.0,
        win_size_outside_loss_weight=10.0,

        resize_method='AREA',  # ['BILINEAR', 'NEAREST_NEIGHBOR', 'BICUBIC', 'AREA']

        concat_cursor=True,

        use_softargmax=True,
        soft_beta=10,  # value for the soft argmax

        raster_loss_weight=1.0,

        dec_rnn_size=256,  # Size of decoder.
        dec_model='hyper',  # Decoder: lstm, layer_norm or hyper.
        # z_size=128,  # Size of latent vector z. Recommend 32, 64 or 128.
        bin_gt=True,

        stop_accu_grad=True,

        random_cursor=True,
        cursor_type='next',

        raster_size=128,

        pix_drop_kp=1.0,  # Dropout keep rate
        add_coordconv=True,
        position_format='abs',
        raster_loss_base_type='perceptual',  # [l1, mse, perceptual]

        grad_clip=1.0,  # Gradient clipping. Recommend leaving at 1.0.

        learning_rate=0.0001,  # Learning rate.
        decay_rate=0.9999,  # Learning rate decay per minibatch.
        decay_power=0.9,
        min_learning_rate=0.000001,  # Minimum learning rate.

        use_recurrent_dropout=True,  # Recurrent dropout without memory loss. Recommended.
        recurrent_dropout_prob=0.90,  # Keep probability for recurrent dropout.
        use_input_dropout=False,  # Input dropout. Recommend leaving False.
        input_dropout_prob=0.90,  # Keep probability for input dropout.
        use_output_dropout=False,  # Output dropout. Recommend leaving False.
        output_dropout_prob=0.90,  # Keep probability for output dropout.

        model_mode='train'  # ['train', 'eval', 'sample']
    )
    return hparams
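
The `use_softargmax` / `soft_beta` settings above control how `build_model` in `model_common_test.py` turns the two-way pen-state distribution `pen_ras` into a differentiable scalar via `differentiable_argmax`, whose body is not shown here. The sketch below is a minimal pure-Python illustration that assumes the standard soft-argmax formulation (softmax of the `beta`-scaled probabilities, then the expected index); `soft_argmax` is a hypothetical name, and the repository's implementation may differ in detail.

```python
import math
from typing import Sequence


def soft_argmax(probs: Sequence[float], beta: float = 10.0) -> float:
    """Differentiable stand-in for argmax over one distribution.

    Hypothetical helper (assumed standard formulation, not the repo's code):
    scale `probs` by `beta` (cf. `soft_beta`), apply a softmax, and return
    the expected index. Larger `beta` pushes the result toward the hard argmax.
    """
    scaled = [beta * p for p in probs]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Expected index under the sharpened distribution.
    return sum(i * w for i, w in enumerate(weights))


if __name__ == "__main__":
    # pen_ras-style input: [P(pen state 0), P(pen state 1)] after a softmax.
    for beta in (1.0, 10.0, 100.0):
        print(beta, round(soft_argmax([0.2, 0.8], beta), 4))
```

With `beta` at the default `soft_beta` of 10, a 0.2 / 0.8 split already maps close to index 1, a larger `beta` approaches a hard argmax, and a tie always maps to 0.5.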


================================================
FILE: launch_gui.bat
================================================
@echo OFF

REM === Path to the Anaconda installation ===
set "CONDAPATH=C:\ProgramData\anaconda3"

REM === Environment name and path ===
set "ENVNAME=virtual_sketching"
set "ENVPATH=%USERPROFILE%\.conda\envs\%ENVNAME%"

REM === Activate the environment ===
call "%CONDAPATH%\Scripts\activate.bat" "%ENVPATH%"

REM === Launch the GUI ===
python virtual_sketch_gui.py

REM === Pause after the GUI exits ===
echo.
pause

REM === Deactivate the environment ===
call conda deactivate


================================================
FILE: model_common_test.py
================================================
import rnn
import tensorflow as tf

from subnet_tf_utils import generative_cnn_encoder, generative_cnn_encoder_deeper, generative_cnn_encoder_deeper13, \
    generative_cnn_c3_encoder, generative_cnn_c3_encoder_deeper, generative_cnn_c3_encoder_deeper13, \
    generative_cnn_c3_encoder_combine33, generative_cnn_c3_encoder_combine43, \
    generative_cnn_c3_encoder_combine53, generative_cnn_c3_encoder_combineFC, \
    generative_cnn_c3_encoder_deeper13_attn


class DiffPastingV3(object):
    def __init__(self, raster_size):
        self.patch_canvas = tf.placeholder(dtype=tf.float32,
                                           shape=(None, None, 1))  # (raster_size, raster_size, 1), [0.0-BG, 1.0-stroke]
        self.cursor_pos_a = tf.placeholder(dtype=tf.float32, shape=(2))  # (2), float32, in large size
        self.image_size_a = tf.placeholder(dtype=tf.int32, shape=())  # ()
        self.window_size_a = tf.placeholder(dtype=tf.float32, shape=())  # (), float32, with grad
        self.raster_size_a = float(raster_size)

        self.pasted_image = self.image_pasting_sampling_v3()
        # (image_size, image_size, 1), [0.0-BG, 1.0-stroke]

    def image_pasting_sampling_v3(self):
        padding_size = tf.cast(tf.ceil(self.window_size_a / 2.0), tf.int32)

        x1y1_a = self.cursor_pos_a - self.window_size_a / 2.0  # (2), float32
        x2y2_a = self.cursor_pos_a + self.window_size_a / 2.0  # (2), float32

        x1y1_a_floor = tf.floor(x1y1_a)  # (2)
        x2y2_a_ceil = tf.ceil(x2y2_a)  # (2)

        cursor_pos_b_oricoord = (x1y1_a_floor + x2y2_a_ceil) / 2.0  # (2)
        cursor_pos_b = (cursor_pos_b_oricoord - x1y1_a) / self.window_size_a * self.raster_size_a  # (2)
        raster_size_b = (x2y2_a_ceil - x1y1_a_floor)  # (x, y)
        image_size_b = self.raster_size_a
        window_size_b = self.raster_size_a * (raster_size_b / self.window_size_a)  # (x, y)

        cursor_b_x, cursor_b_y = tf.split(cursor_pos_b, 2, axis=-1)  # (1)

        y1_b = cursor_b_y - (window_size_b[1] - 1.) / 2.
        x1_b = cursor_b_x - (window_size_b[0] - 1.) / 2.
        y2_b = y1_b + (window_size_b[1] - 1.)
        x2_b = x1_b + (window_size_b[0] - 1.)
        boxes_b = tf.concat([y1_b, x1_b, y2_b, x2_b], axis=-1)  # (4)
        boxes_b = boxes_b / tf.cast(image_size_b - 1, tf.float32)  # with grad to window_size_a

        box_ind_b = tf.ones((1), dtype=tf.int32)  # (1)
        box_ind_b = tf.cumsum(box_ind_b) - 1

        patch_canvas = tf.expand_dims(self.patch_canvas,
                                      axis=0)  # (1, raster_size, raster_size, 1), [0.0-BG, 1.0-stroke]
        boxes_b = tf.expand_dims(boxes_b, axis=0)  # (1, 4)

        valid_canvas = tf.image.crop_and_resize(patch_canvas, boxes_b, box_ind_b,
                                                crop_size=[raster_size_b[1], raster_size_b[0]])
        valid_canvas = valid_canvas[0]  # (raster_size_b, raster_size_b, 1)

        pad_up = tf.cast(x1y1_a_floor[1], tf.int32) + padding_size
        pad_down = self.image_size_a + padding_size - tf.cast(x2y2_a_ceil[1], tf.int32)
        pad_left = tf.cast(x1y1_a_floor[0], tf.int32) + padding_size
        pad_right = self.image_size_a + padding_size - tf.cast(x2y2_a_ceil[0], tf.int32)

        paddings = [[pad_up, pad_down],
                    [pad_left, pad_right],
                    [0, 0]]
        pad_img = tf.pad(valid_canvas, paddings=paddings, mode='CONSTANT',
                         constant_values=0.0)  # (H_p, W_p, 1), [0.0-BG, 1.0-stroke]

        pasted_image = pad_img[padding_size: padding_size + self.image_size_a,
                       padding_size: padding_size + self.image_size_a, :]
        # (image_size, image_size, 1), [0.0-BG, 1.0-stroke]
        return pasted_image
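
For readers tracing `image_pasting_sampling_v3`, the following pure-Python transcription reproduces its box and padding arithmetic for a single window, without TensorFlow or gradients. It is an illustrative sketch: `pasting_geometry` and its dictionary output are hypothetical, and it only mirrors the formulas in the class above. The window is snapped outward to integer pixel bounds with floor/ceil, the matching sub-box of the rendered patch is expressed as a normalized `[y1, x1, y2, x2]` box for `tf.image.crop_and_resize`, and the crop is then placed on the canvas with integer padding.

```python
import math
from typing import Dict, Tuple


def pasting_geometry(cursor_xy: Tuple[float, float], window_size: float,
                     raster_size: int, image_size: int) -> Dict[str, object]:
    """Mirror the box/padding arithmetic of DiffPastingV3 for one window.

    Hypothetical helper. cursor_xy and window_size are in large-image pixels;
    the rendered patch is raster_size x raster_size; the canvas is
    image_size x image_size.
    """
    r = float(raster_size)
    padding_size = int(math.ceil(window_size / 2.0))

    # Window corners, then snapped outward to integer pixel bounds.
    x1y1 = [c - window_size / 2.0 for c in cursor_xy]
    x2y2 = [c + window_size / 2.0 for c in cursor_xy]
    x1y1_floor = [math.floor(v) for v in x1y1]
    x2y2_ceil = [math.ceil(v) for v in x2y2]

    # Centre of the snapped region, re-expressed in patch pixels.
    cursor_b = [((f + c) / 2.0 - a) / window_size * r
                for f, c, a in zip(x1y1_floor, x2y2_ceil, x1y1)]
    raster_size_b = [c - f for f, c in zip(x1y1_floor, x2y2_ceil)]  # (x, y)
    window_size_b = [r * (s / window_size) for s in raster_size_b]  # (x, y)

    y1 = cursor_b[1] - (window_size_b[1] - 1.0) / 2.0
    x1 = cursor_b[0] - (window_size_b[0] - 1.0) / 2.0
    y2 = y1 + (window_size_b[1] - 1.0)
    x2 = x1 + (window_size_b[0] - 1.0)
    # Normalized [y1, x1, y2, x2] box, divided by (raster_size - 1) as above.
    box = [v / (r - 1.0) for v in (y1, x1, y2, x2)]

    paddings = [
        [x1y1_floor[1] + padding_size, image_size + padding_size - x2y2_ceil[1]],  # up, down
        [x1y1_floor[0] + padding_size, image_size + padding_size - x2y2_ceil[0]],  # left, right
    ]
    return {"box": box, "crop_size": [raster_size_b[1], raster_size_b[0]],
            "paddings": paddings, "padding_size": padding_size}
```

Because the crop height plus the top and bottom padding always equals `image_size + 2 * padding_size` (and likewise for the width), slicing `padding_size : padding_size + image_size`, as the class does, recovers a canvas of exactly `image_size` pixels per side.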


class VirtualSketchingModel(object):
    def __init__(self, hps, gpu_mode=True, reuse=False):
        """Initializer for the model.

        Args:
            hps: an HParams object containing model hyperparameters.
            gpu_mode: a boolean that, when True, builds the model in GPU mode.
            reuse: a boolean that, when True, attempts to reuse variables
                (currently unused, since the variable_scope below is commented out).
        """
        self.hps = hps
        assert hps.model_mode in ['train', 'eval', 'eval_sample', 'sample']
        # with tf.variable_scope('SCC', reuse=reuse):
        if not gpu_mode:
            with tf.device('/cpu:0'):
                print('Model using cpu.')
                self.build_model()
        else:
            print('-' * 100)
            print('model_mode:', hps.model_mode)
            print('Model using gpu.')
            self.build_model()

    def build_model(self):
        """Define model architecture."""
        self.config_model()

        initial_state = self.get_decoder_inputs()
        self.initial_state = initial_state

        ## use pred as the prev points
        other_params, pen_ras, final_state = self.get_points_and_raster_image(self.image_size)
        # other_params: (N * max_seq_len, 6)
        # pen_ras: (N * max_seq_len, 2), after softmax

        self.other_params = other_params  # (N * max_seq_len, 6)
        self.pen_ras = pen_ras  # (N * max_seq_len, 2), after softmax
        self.final_state = final_state

        if not self.hps.use_softargmax:
            pen_state_soft = pen_ras[:, 1:2]  # (N * max_seq_len, 1)
        else:
            pen_state_soft = self.differentiable_argmax(pen_ras, self.hps.soft_beta)  # (N * max_seq_len, 1)

        pred_params = tf.concat([pen_state_soft, other_params], axis=1)  # (N * max_seq_len, 7)
        self.pred_params = tf.reshape(pred_params, shape=[-1, self.hps.max_seq_len, 7])  # (N, max_seq_len, 7)
        # pred_params: (N, max_seq_len, 7)

    def config_model(self):
        if self.hps.model_mode == 'train':
            self.global_step = tf.Variable(0, name='global_step', trainable=False)

        if self.hps.dec_model == 'lstm':
            dec_cell_fn = rnn.LSTMCell
        elif self.hps.dec_model == 'layer_norm':
            dec_cell_fn = rnn.LayerNormLSTMCell
        elif self.hps.dec_model == 'hyper':
            dec_cell_fn = rnn.HyperLSTMCell
        else:
            raise ValueError('Unknown dec_model: %s (expected lstm, layer_norm or hyper)' % self.hps.dec_model)

        use_recurrent_dropout = self.hps.use_recurrent_dropout
        use_input_dropout = self.hps.use_input_dropout
        use_output_dropout = self.hps.use_output_dropout

        dec_cell = dec_cell_fn(
            self.hps.dec_rnn_size,
            use_recurrent_dropout=use_recurrent_dropout,
            dropout_keep_prob=self.hps.recurrent_dropout_prob)

        # dropout:
        # print('Input dropout mode = %s.' % use_input_dropout)
        # print('Output dropout mode = %s.' % use_output_dropout)
        # print('Recurrent dropout mode = %s.' % use_recurrent_dropout)
        if use_input_dropout:
            print('Dropout to input w/ keep_prob = %4.4f.' % self.hps.input_dropout_prob)
            dec_cell = tf.contrib.rnn.DropoutWrapper(
                dec_cell, input_keep_prob=self.hps.input_dropout_prob)
        if use_output_dropout:
            print('Dropout to output w/ keep_prob = %4.4f.' % self.hps.output_dropout_prob)
            dec_cell = tf.contrib.rnn.DropoutWrapper(
                dec_cell, output_keep_prob=self.hps.output_dropout_prob)
        self.dec_cell = dec_cell

        self.input_photo = tf.placeholder(dtype=tf.float32,
                                          shape=[self.hps.batch_size, None, None, self.hps.input_channel])  # [0.0-stroke, 1.0-BG]
        self.init_cursor = tf.placeholder(
            dtype=tf.float32,
            shape=[self.hps.batch_size, 1, 2])  # (N, 1, 2), normalized to [0.0, 1.0)
        self.init_width = tf.placeholder(
            dtype=tf.float32,
            shape=[self.hps.batch_size])  # (N), in [0.0, 1.0]
        self.init_scaling = tf.placeholder(
            dtype=tf.float32,
            shape=[self.hps.batch_size])  # (N), in [0.0, 1.0]
        self.init_window_size = tf.placeholder(
            dtype=tf.float32,
            shape=[self.hps.batch_size])  # (N)
        self.image_size = tf.placeholder(dtype=tf.int32, shape=())  # ()

    ###########################

    def normalize_image_m1to1(self, in_img_0to1):
        norm_img_m1to1 = tf.multiply(in_img_0to1, 2.0)
        norm_img_m1to1 = tf.subtract(norm_img_m1to1, 1.0)
        return norm_img_m1to1

    def add_coords(self, input_tensor):
        batch_size_tensor = tf.shape(input_tensor)[0]  # get N size

        xx_ones = tf.ones([batch_size_tensor, self.hps.raster_size], dtype=tf.int32)  # e.g. (N, raster_size)
        xx_ones = tf.expand_dims(xx_ones, -1)  # e.g. (N, raster_size, 1)
        xx_range = tf.tile(tf.expand_dims(tf.range(self.hps.raster_size), 0),
                           [batch_size_tensor, 1])  # e.g. (N, raster_size)
        xx_range = tf.expand_dims(xx_range, 1)  # e.g. (N, 1, raster_size)

        xx_channel = tf.matmul(xx_ones, xx_range)  # e.g. (N, raster_size, raster_size)
        xx_channel = tf.expand_dims(xx_channel, -1)  # e.g. (N, raster_size, raster_size, 1)

        yy_ones = tf.ones([batch_size_tensor, self.hps.raster_size], dtype=tf.int32)  # e.g. (N, raster_size)
        yy_ones = tf.expand_dims(yy_ones, 1)  # e.g. (N, 1, raster_size)
        yy_range = tf.tile(tf.expand_dims(tf.range(self.hps.raster_size), 0),
                           [batch_size_tensor, 1])  # (N, raster_size)
        yy_range = tf.expand_dims(yy_range, -1)  # e.g. (N, raster_size, 1)

        yy_channel = tf.matmul(yy_range, yy_ones)  # e.g. (N, raster_size, raster_size)
        yy_channel = tf.expand_dims(yy_channel, -1)  # e.g. (N, raster_size, raster_size, 1)

        xx_channel = tf.cast(xx_channel, 'float32') / (self.hps.raster_size - 1)
        yy_channel = tf.cast(yy_channel, 'float32') / (self.hps.raster_size - 1)
        # xx_channel = xx_channel * 2 - 1  # [-1, 1]
        # yy_channel = yy_channel * 2 - 1

        ret = tf.concat([
            input_tensor,
            xx_channel,
            yy_channel,
        ], axis=-1)  # e.g. (N, raster_size, raster_size, 4)

        return ret
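
`add_coords` builds CoordConv-style coordinate channels with batched `tf.matmul` calls. The small pure-Python sketch below (hypothetical helper `coord_channels`, single image, no batch axis) shows the per-pixel values those two extra channels carry: the x channel is the column index and the y channel the row index, each divided by `raster_size - 1` so both span [0, 1]. It assumes `raster_size > 1`.

```python
from typing import List


def coord_channels(raster_size: int) -> List[List[List[float]]]:
    """Return a raster_size x raster_size grid of [x, y] coordinate values.

    Hypothetical helper mirroring add_coords for one image: grid[row][col]
    holds [col / (raster_size - 1), row / (raster_size - 1)], i.e. the x
    channel varies along the width and the y channel along the height.
    """
    denom = float(raster_size - 1)
    return [[[col / denom, row / denom] for col in range(raster_size)]
            for row in range(raster_size)]


if __name__ == "__main__":
    for row in coord_channels(3):
        print(row)
```

In the model these values are appended after the input channels, in the order x then y, which is why each `add_coords` call grows the channel dimension by two.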

    def build_combined_encoder(self, patch_canvas, patch_photo, entire_canvas, entire_photo, cursor_pos,
                               image_size, window_size):
        """
        :param patch_canvas: (N, raster_size, raster_size, 1), [-1.0-stroke, 1.0-BG]
        :param patch_photo: (N, raster_size, raster_size, 1/3), [-1.0-stroke, 1.0-BG]
        :param entire_canvas: (N, image_size, image_size, 1), [0.0-stroke, 1.0-BG]
        :param entire_photo: (N, image_size, image_size, 1/3), [0.0-stroke, 1.0-BG]
        :param cursor_pos: (N, 1, 2), normalized to [0.0, 1.0)
        :param window_size: (N, 1, 1), float, in large size
        :return:
        """
        if self.hps.resize_method == 'BILINEAR':
            resize_method = tf.image.ResizeMethod.BILINEAR
        elif self.hps.resize_method == 'NEAREST_NEIGHBOR':
            resize_method = tf.image.ResizeMethod.NEAREST_NEIGHBOR
        elif self.hps.resize_method == 'BICUBIC':
            resize_method = tf.image.ResizeMethod.BICUBIC
        elif self.hps.resize_method == 'AREA':
            resize_method = tf.image.ResizeMethod.AREA
        else:
            raise Exception('unknown resize_method', self.hps.resize_method)

        patch_photo = tf.stop_gradient(patch_photo)
        patch_canvas = tf.stop_gradient(patch_canvas)
        cursor_pos = tf.stop_gradient(cursor_pos)
        window_size = tf.stop_gradient(window_size)

        entire_photo_small = tf.stop_gradient(tf.image.resize_images(entire_photo,
                                                                      (self.hps.raster_size, self.hps.raster_size),
                                                                      method=resize_method))
        entire_canvas_small = tf.stop_gradient(tf.image.resize_images(entire_canvas,
                                                                      (self.hps.raster_size, self.hps.raster_size),
                                                                      method=resize_method))
        entire_photo_small = self.normalize_image_m1to1(entire_photo_small)  # [-1.0-stroke, 1.0-BG]
        entire_canvas_small = self.normalize_image_m1to1(entire_canvas_small)  # [-1.0-stroke, 1.0-BG]

        if self.hps.encode_cursor_type == 'value':
            cursor_pos_norm = tf.expand_dims(cursor_pos, axis=1)  # (N, 1, 1, 2)
            cursor_pos_norm = tf.tile(cursor_pos_norm, [1, self.hps.raster_size, self.hps.raster_size, 1])
            cursor_info = cursor_pos_norm
        else:
            raise Exception('Unknown encode_cursor_type', self.hps.encode_cursor_type)

        batch_input_combined = tf.concat([patch_photo, patch_canvas, entire_photo_small, entire_canvas_small, cursor_info],
                                         axis=-1)  # [N, raster_size, raster_size, 6/10]
        batch_input_local = tf.concat([patch_photo, patch_canvas], axis=-1)  # [N, raster_size, raster_size, 2/4]
        batch_input_global = tf.concat([entire_photo_small, entire_canvas_small, cursor_info],
                                       axis=-1)  # [N, raster_size, raster_size, 4/6]

        if self.hps.model_mode == 'train':
            is_training = True
            dropout_keep_prob = self.hps.pix_drop_kp
        else:
            is_training = False
            dropout_keep_prob = 1.0

        if self.hps.add_coordconv:
            batch_input_combined = self.add_coords(batch_input_combined)  # (N, in_H, in_W, in_dim + 2)
            batch_input_local = self.add_coords(batch_input_local)  # (N, in_H, in_W, in_dim + 2)
            batch_input_global = self.add_coords(batch_input_global)  # (N, in_H, in_W, in_dim + 2)

        if 'combine' in self.hps.encoder_type:
            if self.hps.encoder_type == 'combine33':
                image_embedding, _ = generative_cnn_c3_encoder_combine33(batch_input_local, batch_input_global,
                                                                         is_training, dropout_keep_prob)  # (N, 128)
            elif self.hps.encoder_type == 'combine43':
                image_embedding, _ = generative_cnn_c3_encoder_combine43(batch_input_local, batch_input_global,
                                                                         is_training, dropout_keep_prob)  # (N, 128)
            elif self.hps.encoder_type == 'combine53':
                image_embedding, _ = generative_cnn_c3_encoder_combine53(batch_input_local, batch_input_global,
                                                                         is_training, dropout_keep_prob)  # (N, 128)
            elif self.hps.encoder_type == 'combineFC':
                image_embedding, _ = generative_cnn_c3_encoder_combineFC(batch_input_local, batch_input_global,
                                                                         is_training, dropout_keep_prob)  # (N, 256)
            else:
                raise Exception('Unknown encoder_type', self.hps.encoder_type)
        else:
            with tf.variable_scope('Combined_Encoder', reuse=tf.AUTO_REUSE):
                if self.hps.encoder_type == 'conv10':
                    image_embedding, _ = generative_cnn_encoder(batch_input_combined, is_training, dropout_keep_prob)  # (N, 128)
                elif self.hps.encoder_type == 'conv10_deep':
                    image_embedding, _ = generative_cnn_encoder_deeper(batch_input_combined, is_training, dropout_keep_prob)  # (N, 512)
                elif self.hps.encoder_type == 'conv13':
                    image_embedding, _ = generative_cnn_encoder_deeper13(batch_input_combined, is_training, dropout_keep_prob)  # (N, 128)
                elif self.hps.encoder_type == 'conv10_c3':
                    image_embedding, _ = generative_cnn_c3_encoder(batch_input_combined, is_training, dropout_keep_prob)  # (N, 128)
                elif self.hps.encoder_type == 'conv10_deep_c3':
                    image_embedding, _ = generative_cnn_c3_encoder_deeper(batch_input_combined, is_training, dropout_keep_prob)  # (N, 512)
                elif self.hps.encoder_type == 'conv13_c3':
                    image_embedding, _ = generative_cnn_c3_encoder_deeper13(batch_input_combined, is_training, dropout_keep_prob)  # (N, 128)
                elif self.hps.encoder_type == 'conv13_c3_attn':
                    image_embedding, _ = generative_cnn_c3_encoder_deeper13_attn(batch_input_combined, is_training, dropout_keep_prob)  # (N, 128)
                else:
                    raise Exception('Unknown encoder_type', self.hps.encoder_type)
        return image_embedding

    def build_seq_decoder(self, dec_cell, actual_input_x, initial_state):
        rnn_output, last_state = self.rnn_decoder(dec_cell, initial_state, actual_input_x)
        rnn_output_flat = tf.reshape(rnn_output, [-1, self.hps.dec_rnn_size])

        pen_n_out = 2
        params_n_out = 6

        with tf.variable_scope('DEC_RNN_out_pen', reuse=tf.AUTO_REUSE):
            output_w_pen = tf.get_variable('output_w', [self.hps.dec_rnn_size, pen_n_out])
            output_b_pen = tf.get_variable('output_b', [pen_n_out], initializer=tf.constant_initializer(0.0))
            output_pen = tf.nn.xw_plus_b(rnn_output_flat, output_w_pen, output_b_pen)  # (N, pen_n_out)

        with tf.variable_scope('DEC_RNN_out_params', reuse=tf.AUTO_REUSE):
            output_w_params = tf.get_variable('output_w', [self.hps.dec_rnn_size, params_n_out])
            output_b_params = tf.get_variable('output_b', [params_n_out], initializer=tf.constant_initializer(0.0))
            output_params = tf.nn.xw_plus_b(rnn_output_flat, output_w_params, output_b_params)  # (N, params_n_out)

        output = tf.concat([output_pen, output_params], axis=1)  # (N, n_out)

        return output, last_state

    def get_mixture_coef(self, outputs):
        z = outputs
        z_pen_logits = z[:, 0:2]  # (N, 2), pen states
        z_other_params_logits = z[:, 2:]  # (N, 6)

        z_pen = tf.nn.softmax(z_pen_logits)  # (N, 2)
        if self.hps.position_format == 'abs':
            x1y1 = tf.nn.sigmoid(z_other_params_logits[:, 0:2])  # (N, 2)
            x2y2 = tf.tanh(z_other_params_logits[:, 2:4])  # (N, 2)
            widths = tf.nn.sigmoid(z_other_params_logits[:, 4:5])  # (N, 1)
            widths = tf.add(tf.multiply(widths, 1.0 - self.hps.min_width), self.hps.min_width)
            scaling = tf.nn.sigmoid(z_other_params_logits[:, 5:6]) * self.hps.max_scaling  # (N, 1), [0.0, max_scaling]
            # scaling = tf.add(tf.multiply(scaling, (self.hps.max_scaling - self.hps.min_scaling) / self.hps.max_scaling),
            #                  self.hps.min_scaling)
            z_other_params = tf.concat([x1y1, x2y2, widths, scaling], axis=-1)  # (N, 6)
        else:  # "rel"
            raise Exception('Unknown position_format', self.hps.position_format)

        r = [z_other_params, z_pen]
        return r

    ###########################

    def get_decoder_inputs(self):
        initial_state = self.dec_cell.zero_state(batch_size=self.hps.batch_size, dtype=tf.float32)
        return initial_state

    def rnn_decoder(self, dec_cell, initial_state, actual_input_x):
        with tf.variable_scope("RNN_DEC", reuse=tf.AUTO_REUSE):
            output, last_state = tf.nn.dynamic_rnn(
                dec_cell,
                actual_input_x,
                initial_state=initial_state,
                time_major=False,
                swap_memory=True,
                dtype=tf.float32)
        return output, last_state

    ###########################

    def image_padding(self, ori_image, window_size, pad_value):
        """
        Pad with (bg)
        :param ori_image:
        :return:
        """
        paddings = [[0, 0],
                    [window_size // 2, window_size // 2],
                    [window_size // 2, window_size // 2],
                    [0, 0]]
        pad_img = tf.pad(ori_image, paddings=paddings, mode='CONSTANT', constant_values=pad_value)  # (N, H_p, W_p, k)
        return pad_img

    def image_cropping_fn(self, fn_inputs):
        """
        crop the patch
        :return:
        """
        index_offset = self.hps.input_channel - 1
        input_image = fn_inputs[:, :, 0:2 + index_offset]  # (image_size, image_size, -), [0.0-BG, 1.0-stroke]
        cursor_pos = fn_inputs[0, 0, 2 + index_offset:4 + index_offset]  # (2), in [0.0, 1.0)
        image_size = fn_inputs[0, 0, 4 + index_offset]  # (), float32
        window_size = tf.cast(fn_inputs[0, 0, 5 + index_offset], tf.int32)  # ()

        input_img_reshape = tf.expand_dims(input_image, axis=0)
        pad_img = self.image_padding(input_img_reshape, window_size, pad_value=0.0)

        cursor_pos = tf.cast(tf.round(tf.multiply(cursor_pos, image_size)), dtype=tf.int32)
        x0, x1 = cursor_pos[0], cursor_pos[0] + window_size  # ()
        y0, y1 = cursor_pos[1], cursor_pos[1] + window_size  # ()
        patch_image = pad_img[:, y0:y1, x0:x1, :]  # (1, window_size, window_size, 2/4)

        # resize to raster_size
        patch_image_scaled = tf.image.resize_images(patch_image, (self.hps.raster_size, self.hps.raster_size),
                                                    method=tf.image.ResizeMethod.AREA)
        patch_image_scaled = tf.squeeze(patch_image_scaled, axis=0)
        # patch_image_scaled: (raster_size, raster_size, 2/4), [0.0-BG, 1.0-stroke]

        return patch_image_scaled

    def image_cropping(self, cursor_position, input_img, image_size, window_sizes):
        """
        :param cursor_position: (N, 1, 2), float type, in size [0.0, 1.0)
        :param input_img: (N, image_size, image_size, 2/4), [0.0-BG, 1.0-stroke]
        :param window_sizes: (N, 1, 1), float32, with grad
        """
        input_img_ = input_img
        window_sizes_non_grad = tf.stop_gradient(tf.round(window_sizes))  # (N, 1, 1), no grad

        cursor_position_ = tf.reshape(cursor_position, (-1, 1, 1, 2))  # (N, 1, 1, 2)
        cursor_position_ = tf.tile(cursor_position_, [1, image_size, image_size, 1])  # (N, image_size, image_size, 2)

        image_size_ = tf.reshape(tf.cast(image_size, tf.float32), (1, 1, 1, 1))  # (1, 1, 1, 1)
        image_size_ = tf.tile(image_size_, [self.hps.batch_size, image_size, image_size, 1])

        window_sizes_ = tf.reshape(window_sizes_non_grad, (-1, 1, 1, 1))  # (N, 1, 1, 1)
        window_sizes_ = tf.tile(window_sizes_, [1, image_size, image_size, 1])  # (N, image_size, image_size, 1)

        fn_inputs = tf.concat([input_img_, cursor_position_, image_size_, window_sizes_],
                              axis=-1)  # (N, image_size, image_size, 2/4 + 4)
        curr_patch_imgs = tf.map_fn(self.image_cropping_fn, fn_inputs, parallel_iterations=32)  # (N, raster_size, raster_size, -)
        return curr_patch_imgs

    def image_cropping_v3(self, cursor_position, input_img, image_size, window_sizes):
        """
        :param cursor_position: (N, 1, 2), float type, in size [0.0, 1.0)
        :param input_img: (N, image_size, image_size, k), [0.0-BG, 1.0-stroke]
        :param window_sizes: (N, 1, 1), float32, with grad
        """
        window_sizes_non_grad = tf.stop_gradient(window_sizes)  # (N, 1, 1), no grad

        cursor_pos = tf.multiply(cursor_position, tf.cast(image_size, tf.float32))
        cursor_x, cursor_y = tf.split(cursor_pos, 2, axis=-1)  # (N, 1, 1)

        y1 = cursor_y - (window_sizes_non_grad - 1.0) / 2
        x1 = cursor_x - (window_sizes_non_grad - 1.0) / 2
        y2 = y1 + (window_sizes_non_grad - 1.0)
        x2 = x1 + (window_sizes_non_grad - 1.0)
        boxes = tf.concat([y1, x1, y2, x2], axis=-1)  # (N, 1, 4)
        boxes = tf.squeeze(boxes, axis=1)  # (N, 4)
        boxes = boxes / tf.cast(image_size - 1, tf.float32)

        box_ind = tf.ones_like(cursor_x)[:, 0, 0]  # (N)
        box_ind = tf.cast(box_ind, dtype=tf.int32)
        box_ind = tf.cumsum(box_ind) - 1  # [0, 1, ..., N-1]: box i is cropped from image i

        curr_patch_imgs = tf.image.crop_and_resize(input_img, boxes, box_ind,
                                                   crop_size=[self.hps.raster_size, self.hps.raster_size])
        #  (N, raster_size, raster_size, k), [0.0-BG, 1.0-stroke]
        return curr_patch_imgs

    def get_points_and_raster_image(self, image_size):
        ## Roll out the decoder for max_seq_len steps, collecting stroke parameters (other_params) and pen states (pen_ras)
        prev_state = self.initial_state  # (N, dec_rnn_size * 3)

        prev_width = self.init_width  # (N)
        prev_width = tf.expand_dims(tf.expand_dims(prev_width, axis=-1), axis=-1)  # (N, 1, 1)

        prev_scaling = self.init_scaling  # (N)
        prev_scaling = tf.reshape(prev_scaling, (-1, 1, 1))  # (N, 1, 1)

        prev_window_size = self.init_window_size  # (N)
        prev_window_size = tf.reshape(prev_window_size, (-1, 1, 1))  # (N, 1, 1)

        cursor_position_temp = self.init_cursor
        self.cursor_position = cursor_position_temp  # (N, 1, 2), in size [0.0, 1.0)
        cursor_position_loop = self.cursor_position

        other_params_list = []
        pen_ras_list = []

        curr_canvas_soft = tf.zeros_like(self.input_photo[:, :, :, 0])  # (N, image_size, image_size), [0.0-BG, 1.0-stroke]
        curr_canvas_hard = tf.zeros_like(curr_canvas_soft)  # [0.0-BG, 1.0-stroke]

        #### sampling part - start ####
        self.curr_canvas_hard = curr_canvas_hard

        if self.hps.cropping_type == 'v3':
            cropping_func = self.image_cropping_v3
        # elif self.hps.cropping_type == 'v2':
        #     cropping_func = self.image_cropping
        else:
            raise Exception('Unknown cropping_type', self.hps.cropping_type)

        for time_i in range(self.hps.max_seq_len):
            cursor_position_non_grad = tf.stop_gradient(cursor_position_loop)  # (N, 1, 2), in size [0.0, 1.0)

            curr_window_size = tf.multiply(prev_scaling, tf.stop_gradient(prev_window_size))  # float, with grad
            curr_window_size = tf.maximum(curr_window_size, tf.cast(self.hps.min_window_size, tf.float32))
            curr_window_size = tf.minimum(curr_window_size, tf.cast(image_size, tf.float32))

            ## patch-level encoding
            # Stop gradients into curr_canvas_hard so the encoder's gradients do not propagate recurrently through earlier steps.
            curr_canvas_hard_non_grad = tf.stop_gradient(self.curr_canvas_hard)
            curr_canvas_hard_non_grad = tf.expand_dims(curr_canvas_hard_non_grad, axis=-1)

            # input_photo: (N, image_size, image_size, 1/3), [0.0-stroke, 1.0-BG]
            crop_inputs = tf.concat([1.0 - self.input_photo, curr_canvas_hard_non_grad], axis=-1)  # (N, image_size, image_size, 1/3 + 1), [0.0-BG, 1.0-stroke]

            cropped_outputs = cropping_func(cursor_position_non_grad, crop_inputs, image_size, curr_window_size)
            index_offset = self.hps.input_channel - 1
            curr_patch_inputs = cropped_outputs[:, :, :, 0:1 + index_offset]  # [0.0-BG, 1.0-stroke]
            curr_patch_canvas_hard_non_grad = cropped_outputs[:, :, :, 1 + index_offset:2 + index_offset]
            # (N, raster_size, raster_size, 1/3), [0.0-BG, 1.0-stroke]

            curr_patch_inputs = 1.0 - curr_patch_inputs  # [0.0-stroke, 1.0-BG]
            curr_patch_inputs = self.normalize_image_m1to1(curr_patch_inputs)
            # (N, raster_size, raster_size, 1/3), [-1.0-stroke, 1.0-BG]

            # Normalizing image
            curr_patch_canvas_hard_non_grad = 1.0 - curr_patch_canvas_hard_non_grad  # [0.0-stroke, 1.0-BG]
            curr_patch_canvas_hard_non_grad = self.normalize_image_m1to1(curr_patch_canvas_hard_non_grad)  # [-1.0-stroke, 1.0-BG]

            ## image-level encoding
            combined_z = self.build_combined_encoder(
                curr_patch_canvas_hard_non_grad,
                curr_patch_inputs,
                1.0 - curr_canvas_hard_non_grad,
                self.input_photo,
                cursor_position_non_grad,
                image_size,
                curr_window_size)  # (N, z_size)
            combined_z = tf.expand_dims(combined_z, axis=1)  # (N, 1, z_size)

            curr_window_size_top_side_norm_non_grad = \
                tf.stop_gradient(curr_window_size / tf.cast(image_size, tf.float32))
            curr_window_size_bottom_side_norm_non_grad = \
                tf.stop_gradient(curr_window_size / tf.cast(self.hps.min_window_size, tf.float32))
            if not self.hps.concat_win_size:
                combined_z = tf.concat([tf.stop_gradient(prev_width), combined_z], 2)  # (N, 1, 1+z_size)
            else:
                combined_z = tf.concat([tf.stop_gradient(prev_width),
                                        curr_window_size_top_side_norm_non_grad,
                                        curr_window_size_bottom_side_norm_non_grad,
                                        combined_z],
                                       2)  # (N, 1, 3+z_size)

            if self.hps.concat_cursor:
                prev_input_x = tf.concat([cursor_position_non_grad, combined_z], 2)  # (N, 1, 2+1+z_size) or (N, 1, 2+3+z_size)
            else:
                prev_input_x = combined_z  # (N, 1, 1+z_size) or (N, 1, 3+z_size)

            h_output, next_state = self.build_seq_decoder(self.dec_cell, prev_input_x, prev_state)
            # h_output: (N * 1, n_out), next_state: (N, dec_rnn_size * 3)
            [o_other_params, o_pen_ras] = self.get_mixture_coef(h_output)
            # o_other_params: (N * 1, 6)
            # o_pen_ras: (N * 1, 2), after softmax

            o_other_params = tf.reshape(o_other_params, [-1, 1, 6])  # (N, 1, 6)
            o_pen_ras_raw = tf.reshape(o_pen_ras, [-1, 1, 2])  # (N, 1, 2)

            other_params_list.append(o_other_params)
            pen_ras_list.append(o_pen_ras_raw)

            #### sampling part - end ####

            prev_state = next_state

        other_params_ = tf.reshape(tf.concat(other_params_list, axis=1), [-1, 6])  # (N * max_seq_len, 6)
        pen_ras_ = tf.reshape(tf.concat(pen_ras_list, axis=1), [-1, 2])  # (N * max_seq_len, 2)

        return other_params_, pen_ras_, prev_state

    def differentiable_argmax(self, input_pen, soft_beta):
        """
        Differentiable argmax trick.
        :param input_pen: (N, n_class)
        :return: pen_state: (N, 1)
        """
        def sign_onehot(x):
            """
            :param x: (N, n_class)
            :return:  (N, n_class)
            """
            y = tf.sign(tf.reduce_max(x, axis=-1, keepdims=True) - x)
            y = (y - 1) * (-1)
            return y

        def softargmax(x, beta=1e2):
            """
            :param x: (N, n_class)
            :param beta: 1e10 is the best. 1e2 is acceptable.
            :return:  (N)
            """
            x_range = tf.cumsum(tf.ones_like(x), axis=1)  # (N, 2)
            return tf.reduce_sum(tf.nn.softmax(x * beta) * x_range, axis=1) - 1

        ## Prefer softargmax (e.g. beta=1e2): sign_onehot is built on tf.sign, whose gradient is zero.
        # pen_onehot = sign_onehot(input_pen)  # one-hot form, (N * max_seq_len, 2)
        # pen_state = pen_onehot[:, 1:2]  # (N * max_seq_len, 1)
        pen_state = softargmax(input_pen, soft_beta)
        pen_state = tf.expand_dims(pen_state, axis=1)  # (N * max_seq_len, 1)
        return pen_state
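
# Illustrative sketch (not part of the original model): a minimal pure-Python version
# of the softargmax used by differentiable_argmax above, for one row of logits.
# `_sketch_softargmax` is a hypothetical name. Like the TF code, it weights
# softmax(beta * x) by the 1-based class indices and subtracts 1, so the result is a
# soft class index in [0, n_class - 1] that approaches the hard argmax as beta grows.
import math


def _sketch_softargmax(logits, beta=1e2):
    """Hypothetical helper: soft argmax of a list of floats, returned as a float index."""
    scaled = [beta * v for v in logits]
    peak = max(scaled)  # subtracting the max leaves the softmax unchanged but avoids overflow
    exps = [math.exp(v - peak) for v in scaled]
    total = sum(exps)
    # 1-based index weights, matching tf.cumsum(tf.ones_like(x), axis=1)
    return sum((e / total) * (i + 1) for i, e in enumerate(exps)) - 1.0

# With two pen classes and beta = 1e2, _sketch_softargmax([0.2, 0.8]) is close to 1.0,
# whereas beta = 1.0 gives an intermediate value (about 0.65).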

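
# Illustrative sketch (not part of the original model): the box arithmetic of
# image_cropping_v3 for one sample, in plain Python. `_sketch_crop_box` is a
# hypothetical name. A window of w pixels centered at cursor pixel c spans pixel
# centers [c - (w - 1) / 2, c + (w - 1) / 2]. tf.image.crop_and_resize expects
# [y1, x1, y2, x2] normalized so that coordinate v maps to pixel v * (H - 1),
# hence the division by (image_size - 1).


def _sketch_crop_box(cursor_xy, image_size, window_size):
    """Hypothetical helper: cursor_xy is (x, y) in [0.0, 1.0); returns normalized [y1, x1, y2, x2]."""
    cursor_x = cursor_xy[0] * image_size
    cursor_y = cursor_xy[1] * image_size
    half_span = (window_size - 1.0) / 2
    y1 = cursor_y - half_span
    x1 = cursor_x - half_span
    y2 = y1 + (window_size - 1.0)
    x2 = x1 + (window_size - 1.0)
    scale = float(image_size - 1)
    return [y1 / scale, x1 / scale, y2 / scale, x2 / scale]

# Near the border the box can extend outside [0, 1]; crop_and_resize fills those
# regions with extrapolation_value (0 by default), which is background in the
# [0.0-BG, 1.0-stroke] convention used by the cropping inputs.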

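
# Illustrative sketch (not part of the original model): how get_mixture_coef maps the
# six raw parameter logits of one step to stroke parameters when position_format == 'abs'.
# `_sketch_decode_params` and any hyperparameter values passed to it are hypothetical.
import math


def _sketch_decode_params(logits6, min_width, max_scaling):
    """Hypothetical helper: returns (x1, y1, x2, y2, width, scaling) from six raw logits."""
    def sigmoid(v):
        return 1.0 / (1.0 + math.exp(-v))

    x1, y1 = sigmoid(logits6[0]), sigmoid(logits6[1])  # in (0, 1)
    x2, y2 = math.tanh(logits6[2]), math.tanh(logits6[3])  # in (-1, 1)
    width = sigmoid(logits6[4]) * (1.0 - min_width) + min_width  # in (min_width, 1)
    scaling = sigmoid(logits6[5]) * max_scaling  # in (0, max_scaling)
    return x1, y1, x2, y2, width, scaling
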
================================================
FILE: model_common_train.py
================================================
import rnn
import tensorflow as tf

from subnet_tf_utils import generative_cnn_encoder, generative_cnn_encoder_deeper, generative_cnn_encoder_deeper13, \
    generative_cnn_c3_encoder, generative_cnn_c3_encoder_deeper, generative_cnn_c3_encoder_deeper13, \
    generative_cnn_c3_encoder_combine33, generative_cnn_c3_encoder_combine43, \
    generative_cnn_c3_encoder_combine53, generative_cnn_c3_encoder_combineFC, \
    generative_cnn_c3_encoder_deeper13_attn
from rasterization_utils.NeuralRenderer import NeuralRasterizorStep
from vgg_utils.VGG16 import vgg_net_slim
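
# Illustrative sketch (not part of the original model): the two CoordConv channels that
# add_coords appends, for a single raster_size x raster_size map, in plain Python.
# `_sketch_coord_channels` is a hypothetical name. xx varies along the width (columns)
# and yy along the height (rows); both are divided by (raster_size - 1), so they span
# [0.0, 1.0], matching the TF code before its commented-out [-1, 1] rescaling.


def _sketch_coord_channels(raster_size):
    """Hypothetical helper: returns (xx, yy) as nested lists indexed [row][col]."""
    denom = float(raster_size - 1)
    xx = [[col / denom for col in range(raster_size)] for _ in range(raster_size)]
    yy = [[row / denom for _ in range(raster_size)] for row in range(raster_size)]
    return xx, yy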

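# Illustrative sketch (not part of the original model): how build_model assigns its
# total_loop = len(gpus) * loop_per_gpu splits. Split t_i runs on gpus[t_i // loop_per_gpu]
# as loop (t_i % loop_per_gpu) on that GPU; split 0 creates the variables and later
# splits reuse them. Because init_cursor is divided with tf.split(..., total_loop, axis=0),
# batch_size must be divisible by total_loop. `_sketch_split_schedule` is a hypothetical name.


def _sketch_split_schedule(gpus, loop_per_gpu, batch_size):
    """Hypothetical helper: returns (split_index, gpu_id, loop_on_gpu, samples_per_split) tuples."""
    total_loop = len(gpus) * loop_per_gpu
    if batch_size % total_loop != 0:
        raise ValueError('batch_size must be divisible by total_loop for an even tf.split')
    per_split = batch_size // total_loop
    schedule = []
    for t_i in range(total_loop):
        gpu_id = gpus[t_i // loop_per_gpu]
        schedule.append((t_i, gpu_id, t_i % loop_per_gpu, per_split))
    return schedule
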

class VirtualSketchingModel(object):
    def __init__(self, hps, gpu_mode=True, reuse=False):
        """Initializer for the model.

    Args:
       hps: a HParams object containing model hyperparameters
       gpu_mode: a boolean that when True, uses GPU mode.
       reuse: a boolean that when true, attemps to reuse variables.
    """
        self.hps = hps
        assert hps.model_mode in ['train', 'eval', 'eval_sample', 'sample']
        # with tf.variable_scope('SCC', reuse=reuse):
        if not gpu_mode:
            with tf.device('/cpu:0'):
                print('Model using cpu.')
                self.build_model()
        else:
            print('-' * 100)
            print('model_mode:', hps.model_mode)
            print('Model using gpu.')
            self.build_model()

    def build_model(self):
        """Define model architecture."""
        self.config_model()

        initial_state = self.get_decoder_inputs()
        self.initial_state = initial_state
        self.initial_state_list = tf.split(self.initial_state, self.total_loop, axis=0)

        total_loss_list = []
        ras_loss_list = []
        perc_relu_raw_list = []
        perc_relu_norm_list = []
        sn_loss_list = []
        cursor_outside_loss_list = []
        win_size_outside_loss_list = []
        early_state_loss_list = []

        tower_grads = []

        pred_raster_imgs_list = []
        pred_raster_imgs_rgb_list = []

        for t_i in range(self.total_loop):
            gpu_idx = t_i // self.hps.loop_per_gpu
            gpu_i = self.hps.gpus[gpu_idx]
            print(self.hps.model_mode, 'model, gpu:', gpu_i, ', loop:', t_i % self.hps.loop_per_gpu)
            with tf.device('/gpu:%d' % gpu_i):
                with tf.name_scope('GPU_%d' % gpu_i) as scope:
                    if t_i > 0:
                        tf.get_variable_scope().reuse_variables()
                    else:
                        total_loss_list.clear()
                        ras_loss_list.clear()
                        perc_relu_raw_list.clear()
                        perc_relu_norm_list.clear()
                        sn_loss_list.clear()
                        cursor_outside_loss_list.clear()
                        win_size_outside_loss_list.clear()
                        early_state_loss_list.clear()
                        tower_grads.clear()
                        pred_raster_imgs_list.clear()
                        pred_raster_imgs_rgb_list.clear()

                    split_input_photo = self.input_photo_list[t_i]
                    split_image_size = self.image_size[t_i]
                    split_init_cursor = self.init_cursor_list[t_i]
                    split_initial_state = self.initial_state_list[t_i]
                    if self.hps.input_channel == 1:
                        split_target_sketch = split_input_photo
                    else:
                        split_target_sketch = self.target_sketch_list[t_i]

                    ## use pred as the prev points
                    other_params, pen_ras, final_state, pred_raster_images, pred_raster_images_rgb, \
                    pos_before_max_min, win_size_before_max_min \
                        = self.get_points_and_raster_image(split_initial_state, split_init_cursor, split_input_photo,
                                                           split_image_size)
                    # other_params: (N * max_seq_len, 6)
                    # pen_ras: (N * max_seq_len, 2), after softmax
                    # pos_before_max_min: (N, max_seq_len, 2), in image_size
                    # win_size_before_max_min: (N, max_seq_len, 1), in image_size

                    pred_raster_imgs = 1.0 - pred_raster_images  # (N, image_size, image_size), [0.0-stroke, 1.0-BG]
                    pred_raster_imgs_rgb = 1.0 - pred_raster_images_rgb  # (N, image_size, image_size, 3)
                    pred_raster_imgs_list.append(pred_raster_imgs)
                    pred_raster_imgs_rgb_list.append(pred_raster_imgs_rgb)

                    if not self.hps.use_softargmax:
                        pen_state_soft = pen_ras[:, 1:2]  # (N * max_seq_len, 1)
                    else:
                        pen_state_soft = self.differentiable_argmax(pen_ras, self.hps.soft_beta)  # (N * max_seq_len, 1)

                    pred_params = tf.concat([pen_state_soft, other_params], axis=1)  # (N * max_seq_len, 7)
                    pred_params = tf.reshape(pred_params, shape=[-1, self.hps.max_seq_len, 7])  # (N, max_seq_len, 7)
                    # pred_params: (N, max_seq_len, 7)

                    if self.hps.model_mode == 'train' or self.hps.model_mode == 'eval':
                        raster_cost, sn_cost, cursor_outside_cost, winsize_outside_cost, \
                        early_pen_states_cost, \
                        perc_relu_loss_raw, perc_relu_loss_norm = \
                            self.build_losses(split_target_sketch, pred_raster_imgs, pred_params,
                                              pos_before_max_min, win_size_before_max_min,
                                              split_image_size)
                        # perc_relu_loss_raw, perc_relu_loss_norm: (n_layers)

                        ras_loss_list.append(raster_cost)
                        perc_relu_raw_list.append(perc_relu_loss_raw)
                        perc_relu_norm_list.append(perc_relu_loss_norm)
                        sn_loss_list.append(sn_cost)
                        cursor_outside_loss_list.append(cursor_outside_cost)
                        win_size_outside_loss_list.append(winsize_outside_cost)
                        early_state_loss_list.append(early_pen_states_cost)

                        if self.hps.model_mode == 'train':
                            total_cost_split, grads_and_vars_split = self.build_training_op_split(
                                raster_cost, sn_cost, cursor_outside_cost, winsize_outside_cost,
                                early_pen_states_cost)
                            total_loss_list.append(total_cost_split)
                            tower_grads.append(grads_and_vars_split)

        self.raster_cost = tf.reduce_mean(tf.stack(ras_loss_list, axis=0))
        self.perc_relu_losses_raw = tf.reduce_mean(tf.stack(perc_relu_raw_list, axis=0), axis=0)  # (n_layers)
        self.perc_relu_losses_norm = tf.reduce_mean(tf.stack(perc_relu_norm_list, axis=0), axis=0)  # (n_layers)
        self.stroke_num_cost = tf.reduce_mean(tf.stack(sn_loss_list, axis=0))
        self.pos_outside_cost = tf.reduce_mean(tf.stack(cursor_outside_loss_list, axis=0))
        self.win_size_outside_cost = tf.reduce_mean(tf.stack(win_size_outside_loss_list, axis=0))
        self.early_pen_states_cost = tf.reduce_mean(tf.stack(early_state_loss_list, axis=0))
        self.cost = tf.reduce_mean(tf.stack(total_loss_list, axis=0))

        self.pred_raster_imgs = tf.concat(pred_raster_imgs_list, axis=0)  # (N, image_size, image_size), [0.0-stroke, 1.0-BG]
        self.pred_raster_imgs_rgb = tf.concat(pred_raster_imgs_rgb_list, axis=0)  # (N, image_size, image_size, 3)

        if self.hps.model_mode == 'train':
            self.build_training_op(tower_grads)

    def config_model(self):
        if self.hps.model_mode == 'train':
            self.global_step = tf.Variable(0, name='global_step', trainable=False)

        if self.hps.dec_model == 'lstm':
            dec_cell_fn = rnn.LSTMCell
        elif self.hps.dec_model == 'layer_norm':
            dec_cell_fn = rnn.LayerNormLSTMCell
        elif self.hps.dec_model == 'hyper':
            dec_cell_fn = rnn.HyperLSTMCell
        else:
            assert False, 'please choose a respectable cell'

        use_recurrent_dropout = self.hps.use_recurrent_dropout
        use_input_dropout = self.hps.use_input_dropout
        use_output_dropout = self.hps.use_output_dropout

        dec_cell = dec_cell_fn(
            self.hps.dec_rnn_size,
            use_recurrent_dropout=use_recurrent_dropout,
            dropout_keep_prob=self.hps.recurrent_dropout_prob)

        # dropout:
        # print('Input dropout mode = %s.' % use_input_dropout)
        # print('Output dropout mode = %s.' % use_output_dropout)
        # print('Recurrent dropout mode = %s.' % use_recurrent_dropout)
        if use_input_dropout:
            print('Dropout to input w/ keep_prob = %4.4f.' % self.hps.input_dropout_prob)
            dec_cell = tf.contrib.rnn.DropoutWrapper(
                dec_cell, input_keep_prob=self.hps.input_dropout_prob)
        if use_output_dropout:
            print('Dropout to output w/ keep_prob = %4.4f.' % self.hps.output_dropout_prob)
            dec_cell = tf.contrib.rnn.DropoutWrapper(
                dec_cell, output_keep_prob=self.hps.output_dropout_prob)
        self.dec_cell = dec_cell

        self.total_loop = len(self.hps.gpus) * self.hps.loop_per_gpu

        self.init_cursor = tf.placeholder(
            dtype=tf.float32,
            shape=[self.hps.batch_size, 1, 2])  # (N, 1, 2), in size [0.0, 1.0)
        self.init_width = tf.placeholder(
            dtype=tf.float32,
            shape=[1])  # (1), in [0.0, 1.0]
        self.image_size = tf.placeholder(dtype=tf.int32, shape=(self.total_loop,))  # (total_loop), one image size per split

        self.init_cursor_list = tf.split(self.init_cursor, self.total_loop, axis=0)
        self.input_photo_list = []
        for loop_i in range(self.total_loop):
            input_photo_i = tf.placeholder(dtype=tf.float32, shape=[None, None, None, self.hps.input_channel])  # [0.0-stroke, 1.0-BG]
            self.input_photo_list.append(input_photo_i)

        if self.hps.input_channel == 3:
            self.target_sketch_list = []
            for loop_i in range(self.total_loop):
                target_sketch_i = tf.placeholder(dtype=tf.float32, shape=[None, None, None, 1])  # [0.0-stroke, 1.0-BG]
                self.target_sketch_list.append(target_sketch_i)

        if self.hps.model_mode == 'train' or self.hps.model_mode == 'eval':
            self.stroke_num_loss_weight = tf.Variable(0.0, trainable=False)
            self.early_pen_loss_start_idx = tf.Variable(0, dtype=tf.int32, trainable=False)
            self.early_pen_loss_end_idx = tf.Variable(0, dtype=tf.int32, trainable=False)

        if self.hps.model_mode == 'train':
            self.perc_loss_mean_list = []
            for loop_i in range(len(self.hps.perc_loss_layers)):
                relu_loss_mean = tf.Variable(0.0, trainable=False)
                self.perc_loss_mean_list.append(relu_loss_mean)
            self.last_step_num = tf.Variable(0.0, trainable=False)

            with tf.variable_scope('train_op', reuse=tf.AUTO_REUSE):
                self.lr = tf.Variable(self.hps.learning_rate, trainable=False)
                self.optimizer = tf.train.AdamOptimizer(self.lr)

    ###########################

    def normalize_image_m1to1(self, in_img_0to1):
        norm_img_m1to1 = tf.multiply(in_img_0to1, 2.0)
        norm_img_m1to1 = tf.subtract(norm_img_m1to1, 1.0)
        return norm_img_m1to1

    def add_coords(self, input_tensor):
        batch_size_tensor = tf.shape(input_tensor)[0]  # get N size

        xx_ones = tf.ones([batch_size_tensor, self.hps.raster_size], dtype=tf.int32)  # e.g. (N, raster_size)
        xx_ones = tf.expand_dims(xx_ones, -1)  # e.g. (N, raster_size, 1)
        xx_range = tf.tile(tf.expand_dims(tf.range(self.hps.raster_size), 0),
                           [batch_size_tensor, 1])  # e.g. (N, raster_size)
        xx_range = tf.expand_dims(xx_range, 1)  # e.g. (N, 1, raster_size)

        xx_channel = tf.matmul(xx_ones, xx_range)  # e.g. (N, raster_size, raster_size)
        xx_channel = tf.expand_dims(xx_channel, -1)  # e.g. (N, raster_size, raster_size, 1)

        yy_ones = tf.ones([batch_size_tensor, self.hps.raster_size], dtype=tf.int32)  # e.g. (N, raster_size)
        yy_ones = tf.expand_dims(yy_ones, 1)  # e.g. (N, 1, raster_size)
        yy_range = tf.tile(tf.expand_dims(tf.range(self.hps.raster_size), 0),
                           [batch_size_tensor, 1])  # (N, raster_size)
        yy_range = tf.expand_dims(yy_range, -1)  # e.g. (N, raster_size, 1)

        yy_channel = tf.matmul(yy_range, yy_ones)  # e.g. (N, raster_size, raster_size)
        yy_channel = tf.expand_dims(yy_channel, -1)  # e.g. (N, raster_size, raster_size, 1)

        xx_channel = tf.cast(xx_channel, 'float32') / (self.hps.raster_size - 1)
        yy_channel = tf.cast(yy_channel, 'float32') / (self.hps.raster_size - 1)
        # xx_channel = xx_channel * 2 - 1  # [-1, 1]
        # yy_channel = yy_channel * 2 - 1

        ret = tf.concat([
            input_tensor,
            xx_channel,
            yy_channel,
        ], axis=-1)  # (N, raster_size, raster_size, C + 2), C = channels of input_tensor

        return ret
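
    # Reference sketch (hypothetical helper, not part of the original model):
    # a NumPy version of add_coords() for checking the CoordConv channels
    # outside the TF graph. It assumes a square NHWC float batch of side R and
    # appends x = column / (R - 1) and y = row / (R - 1), both in [0, 1].
    @staticmethod
    def _add_coords_reference_np(images):
        import numpy as np
        n, h, w, _ = images.shape
        assert h == w, 'add_coords() assumes square raster_size x raster_size inputs'
        coords = np.arange(w, dtype=np.float32) / (w - 1)
        # xx varies along the width axis (columns), yy along the height axis (rows),
        # matching the two tf.matmul outer products in add_coords()
        xx = np.broadcast_to(coords[None, None, :, None], (n, h, w, 1))
        yy = np.broadcast_to(coords[None, :, None, None], (n, h, w, 1))
        return np.concatenate([images.astype(np.float32), xx, yy], axis=-1)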

    def build_combined_encoder(self, patch_canvas, patch_photo, entire_canvas, entire_photo, cursor_pos,
                               image_size, window_size):
        """
        :param patch_canvas: (N, raster_size, raster_size, 1), [-1.0-stroke, 1.0-BG]
        :param patch_photo: (N, raster_size, raster_size, 1/3), [-1.0-stroke, 1.0-BG]
        :param entire_canvas: (N, image_size, image_size, 1), [0.0-stroke, 1.0-BG]
        :param entire_photo: (N, image_size, image_size, 1/3), [0.0-stroke, 1.0-BG]
        :param cursor_pos: (N, 1, 2), in size [0.0, 1.0)
        :param window_size: (N, 1, 1), float, in large size
        :return:
        """
        if self.hps.resize_method == 'BILINEAR':
            resize_method = tf.image.ResizeMethod.BILINEAR
        elif self.hps.resize_method == 'NEAREST_NEIGHBOR':
            resize_method = tf.image.ResizeMethod.NEAREST_NEIGHBOR
        elif self.hps.resize_method == 'BICUBIC':
            resize_method = tf.image.ResizeMethod.BICUBIC
        elif self.hps.resize_method == 'AREA':
            resize_method = tf.image.ResizeMethod.AREA
        else:
            raise Exception('unknown resize_method', self.hps.resize_method)

        patch_photo = tf.stop_gradient(patch_photo)
        patch_canvas = tf.stop_gradient(patch_canvas)
        cursor_pos = tf.stop_gradient(cursor_pos)
        window_size = tf.stop_gradient(window_size)

        entire_photo_small = tf.stop_gradient(tf.image.resize_images(entire_photo,
                                                                      (self.hps.raster_size, self.hps.raster_size),
                                                                      method=resize_method))
        entire_canvas_small = tf.stop_gradient(tf.image.resize_images(entire_canvas,
                                                                      (self.hps.raster_size, self.hps.raster_size),
                                                                      method=resize_method))
        entire_photo_small = self.normalize_image_m1to1(entire_photo_small)  # [-1.0-stroke, 1.0-BG]
        entire_canvas_small = self.normalize_image_m1to1(entire_canvas_small)  # [-1.0-stroke, 1.0-BG]

        if self.hps.encode_cursor_type == 'value':
            cursor_pos_norm = tf.expand_dims(cursor_pos, axis=1)  # (N, 1, 1, 2)
            cursor_pos_norm = tf.tile(cursor_pos_norm, [1, self.hps.raster_size, self.hps.raster_size, 1])
            cursor_info = cursor_pos_norm
        else:
            raise Exception('Unknown encode_cursor_type', self.hps.encode_cursor_type)

        batch_input_combined = tf.concat([patch_photo, patch_canvas, entire_photo_small, entire_canvas_small, cursor_info],
                                axis=-1)  # [N, raster_size, raster_size, 6/10]
        batch_input_local = tf.concat([patch_photo, patch_canvas], axis=-1)  # [N, raster_size, raster_size, 2/4]
        batch_input_global = tf.concat([entire_photo_small, entire_canvas_small, cursor_info],
                                       axis=-1)  # [N, raster_size, raster_size, 4/6]

        if self.hps.model_mode == 'train':
            is_training = True
            dropout_keep_prob = self.hps.pix_drop_kp
        else:
            is_training = False
            dropout_keep_prob = 1.0

        if self.hps.add_coordconv:
            batch_input_combined = self.add_coords(batch_input_combined)  # (N, in_H, in_W, in_dim + 2)
            batch_input_local = self.add_coords(batch_input_local)  # (N, in_H, in_W, in_dim + 2)
            batch_input_global = self.add_coords(batch_input_global)  # (N, in_H, in_W, in_dim + 2)

        if 'combine' in self.hps.encoder_type:
            if self.hps.encoder_type == 'combine33':
                image_embedding, _ = generative_cnn_c3_encoder_combine33(batch_input_local, batch_input_global,
                                                                         is_training, dropout_keep_prob)  # (N, 128)
            elif self.hps.encoder_type == 'combine43':
                image_embedding, _ = generative_cnn_c3_encoder_combine43(batch_input_local, batch_input_global,
                                                                         is_training, dropout_keep_prob)  # (N, 128)
            elif self.hps.encoder_type == 'combine53':
                image_embedding, _ = generative_cnn_c3_encoder_combine53(batch_input_local, batch_input_global,
                                                                         is_training, dropout_keep_prob)  # (N, 128)
            elif self.hps.encoder_type == 'combineFC':
                image_embedding, _ = generative_cnn_c3_encoder_combineFC(batch_input_local, batch_input_global,
                                                                         is_training, dropout_keep_prob)  # (N, 256)
            else:
                raise Exception('Unknown encoder_type', self.hps.encoder_type)
        else:
            with tf.variable_scope('Combined_Encoder', reuse=tf.AUTO_REUSE):
                if self.hps.encoder_type == 'conv10':
                    image_embedding, _ = generative_cnn_encoder(batch_input_combined, is_training, dropout_keep_prob)  # (N, 128)
                elif self.hps.encoder_type == 'conv10_deep':
                    image_embedding, _ = generative_cnn_encoder_deeper(batch_input_combined, is_training, dropout_keep_prob)  # (N, 512)
                elif self.hps.encoder_type == 'conv13':
                    image_embedding, _ = generative_cnn_encoder_deeper13(batch_input_combined, is_training, dropout_keep_prob)  # (N, 128)
                elif self.hps.encoder_type == 'conv10_c3':
                    image_embedding, _ = generative_cnn_c3_encoder(batch_input_combined, is_training, dropout_keep_prob)  # (N, 128)
                elif self.hps.encoder_type == 'conv10_deep_c3':
                    image_embedding, _ = generative_cnn_c3_encoder_deeper(batch_input_combined, is_training, dropout_keep_prob)  # (N, 512)
                elif self.hps.encoder_type == 'conv13_c3':
                    image_embedding, _ = generative_cnn_c3_encoder_deeper13(batch_input_combined, is_training, dropout_keep_prob)  # (N, 128)
                elif self.hps.encoder_type == 'conv13_c3_attn':
                    image_embedding, _ = generative_cnn_c3_encoder_deeper13_attn(batch_input_combined, is_training, dropout_keep_prob)  # (N, 128)
                else:
                    raise Exception('Unknown encoder_type', self.hps.encoder_type)
        return image_embedding

    def build_seq_decoder(self, dec_cell, actual_input_x, initial_state):
        rnn_output, last_state = self.rnn_decoder(dec_cell, initial_state, actual_input_x)
        rnn_output_flat = tf.reshape(rnn_output, [-1, self.hps.dec_rnn_size])

        pen_n_out = 2
        params_n_out = 6

        with tf.variable_scope('DEC_RNN_out_pen', reuse=tf.AUTO_REUSE):
            output_w_pen = tf.get_variable('output_w', [self.hps.dec_rnn_size, pen_n_out])
            output_b_pen = tf.get_variable('output_b', [pen_n_out], initializer=tf.constant_initializer(0.0))
            output_pen = tf.nn.xw_plus_b(rnn_output_flat, output_w_pen, output_b_pen)  # (N, pen_n_out)

        with tf.variable_scope('DEC_RNN_out_params', reuse=tf.AUTO_REUSE):
            output_w_params = tf.get_variable('output_w', [self.hps.dec_rnn_size, params_n_out])
            output_b_params = tf.get_variable('output_b', [params_n_out], initializer=tf.constant_initializer(0.0))
            output_params = tf.nn.xw_plus_b(rnn_output_flat, output_w_params, output_b_params)  # (N, params_n_out)

        output = tf.concat([output_pen, output_params], axis=1)  # (N, n_out)

        return output, last_state

    def get_mixture_coef(self, outputs):
        z = outputs
        z_pen_logits = z[:, 0:2]  # (N, 2), pen states
        z_other_params_logits = z[:, 2:]  # (N, 6)

        z_pen = tf.nn.softmax(z_pen_logits)  # (N, 2)
        if self.hps.position_format == 'abs':
            x1y1 = tf.nn.sigmoid(z_other_params_logits[:, 0:2])  # (N, 2)
            x2y2 = tf.tanh(z_other_params_logits[:, 2:4])  # (N, 2)
            widths = tf.nn.sigmoid(z_other_params_logits[:, 4:5])  # (N, 1)
            widths = tf.add(tf.multiply(widths, 1.0 - self.hps.min_width), self.hps.min_width)
            scaling = tf.nn.sigmoid(z_other_params_logits[:, 5:6]) * self.hps.max_scaling  # (N, 1), [0.0, max_scaling]
            # scaling = tf.add(tf.multiply(scaling, (self.hps.max_scaling - self.hps.min_scaling) / self.hps.max_scaling),
            #                  self.hps.min_scaling)
            z_other_params = tf.concat([x1y1, x2y2, widths, scaling], axis=-1)  # (N, 6)
        else:  # "rel"
            raise Exception('Unknown position_format', self.hps.position_format)

        r = [z_other_params, z_pen]
        return r
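
    # Reference sketch (hypothetical helper, not part of the original model):
    # the logits-to-parameters mapping of get_mixture_coef() for
    # position_format 'abs', written in NumPy so the output ranges are easy to
    # inspect. min_width and max_scaling stand in for the matching hps fields.
    @staticmethod
    def _decode_stroke_params_np(logits, min_width, max_scaling):
        import numpy as np

        def sigmoid(v):
            return 1.0 / (1.0 + np.exp(-v))

        pen_logits = logits[:, 0:2]
        # numerically stable softmax over the two pen states
        pen = np.exp(pen_logits - pen_logits.max(axis=1, keepdims=True))
        pen = pen / pen.sum(axis=1, keepdims=True)
        params = logits[:, 2:]
        x1y1 = sigmoid(params[:, 0:2])  # (0, 1)
        x2y2 = np.tanh(params[:, 2:4])  # (-1, 1)
        widths = sigmoid(params[:, 4:5]) * (1.0 - min_width) + min_width  # (min_width, 1)
        scaling = sigmoid(params[:, 5:6]) * max_scaling  # (0, max_scaling)
        # same order as get_mixture_coef(): [other_params (N, 6), pen (N, 2)]
        return np.concatenate([x1y1, x2y2, widths, scaling], axis=-1), pen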

    ###########################

    def get_decoder_inputs(self):
        initial_state = self.dec_cell.zero_state(batch_size=self.hps.batch_size, dtype=tf.float32)
        return initial_state

    def rnn_decoder(self, dec_cell, initial_state, actual_input_x):
        with tf.variable_scope("RNN_DEC", reuse=tf.AUTO_REUSE):
            output, last_state = tf.nn.dynamic_rnn(
                dec_cell,
                actual_input_x,
                initial_state=initial_state,
                time_major=False,
                swap_memory=True,
                dtype=tf.float32)
        return output, last_state

    ###########################

    def image_padding(self, ori_image, window_size, pad_value):
        """
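        Pad both spatial dims by window_size // 2 on each side with pad_value (the background value).
        :param ori_image: (N, H, W, k)
        :param window_size: int32 scalar, side length of the crop window in pixels
        :param pad_value: constant fill value
        :return: (N, H + 2 * (window_size // 2), W + 2 * (window_size // 2), k)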
        """
        paddings = [[0, 0],
                    [window_size // 2, window_size // 2],
                    [window_size // 2, window_size // 2],
                    [0, 0]]
        pad_img = tf.pad(ori_image, paddings=paddings, mode='CONSTANT', constant_values=pad_value)  # (N, H_p, W_p, k)
        return pad_img

    def image_cropping_fn(self, fn_inputs):
        """
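        Crop a window_size x window_size patch centred on the cursor from a single image and resize
        it to raster_size x raster_size (AREA). Used as the tf.map_fn body of image_cropping().
        :param fn_inputs: (image_size, image_size, 2/4 + 4): image channels, cursor (x, y) in [0.0, 1.0),
                          image_size and window_size, each broadcast over the spatial dims
        :return: (raster_size, raster_size, 2/4), [0.0-BG, 1.0-stroke]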
        """
        index_offset = self.hps.input_channel - 1
        input_image = fn_inputs[:, :, 0:2 + index_offset]  # (image_size, image_size, 2/4), [0.0-BG, 1.0-stroke]
        cursor_pos = fn_inputs[0, 0, 2 + index_offset:4 + index_offset]  # (2), in [0.0, 1.0)
        image_size = fn_inputs[0, 0, 4 + index_offset]  # (), float32
        window_size = tf.cast(fn_inputs[0, 0, 5 + index_offset], tf.int32)  # ()

        input_img_reshape = tf.expand_dims(input_image, axis=0)
        pad_img = self.image_padding(input_img_reshape, window_size, pad_value=0.0)

        cursor_pos = tf.cast(tf.round(tf.multiply(cursor_pos, image_size)), dtype=tf.int32)
        x0, x1 = cursor_pos[0], cursor_pos[0] + window_size  # ()
        y0, y1 = cursor_pos[1], cursor_pos[1] + window_size  # ()
        patch_image = pad_img[:, y0:y1, x0:x1, :]  # (1, window_size, window_size, 2/4)

        # resize to raster_size
        patch_image_scaled = tf.image.resize_images(patch_image, (self.hps.raster_size, self.hps.raster_size),
                                                    method=tf.image.ResizeMethod.AREA)
        patch_image_scaled = tf.squeeze(patch_image_scaled, axis=0)
        # patch_canvas_scaled: (raster_size, raster_size, 2/4), [0.0-BG, 1.0-stroke]

        return patch_image_scaled
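
    # Reference sketch (hypothetical helper, not part of the original model):
    # the pad-then-slice crop of image_cropping_fn() in NumPy, without the final
    # AREA resize to raster_size. Padding by window_size // 2 on every side lets
    # the slice [c, c + window_size) of the padded image be centred on pixel c
    # of the original image (exactly centred for odd window sizes).
    @staticmethod
    def _crop_patch_reference_np(image, cursor_norm, window_size):
        import numpy as np
        image_size = image.shape[0]
        half = window_size // 2
        padded = np.pad(image, ((half, half), (half, half), (0, 0)), mode='constant')
        # cursor_norm is (x, y) in [0.0, 1.0); round to whole pixels (half to even, like tf.round)
        cx, cy = [int(np.round(v * image_size)) for v in cursor_norm]
        return padded[cy:cy + window_size, cx:cx + window_size]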

    def image_cropping(self, cursor_position, input_img, image_size, window_sizes):
        """
        :param cursor_position: (N, 1, 2), float type, normalized to [0.0, 1.0)
        :param input_img: (N, image_size, image_size, 2/4), [0.0-BG, 1.0-stroke]
        :param image_size: side length of input_img in pixels
        :param window_sizes: (N, 1, 1), float32, with grad (rounded and stop-gradiented here)
        """
        input_img_ = input_img
        window_sizes_non_grad = tf.stop_gradient(tf.round(window_sizes))  # (N, 1, 1), no grad

        cursor_position_ = tf.reshape(cursor_position, (-1, 1, 1, 2))  # (N, 1, 1, 2)
        cursor_position_ = tf.tile(cursor_position_, [1, image_size, image_size, 1])  # (N, image_size, image_size, 2)

        image_size_ = tf.reshape(tf.cast(image_size, tf.float32), (1, 1, 1, 1))  # (1, 1, 1, 1)
        image_size_ = tf.tile(image_size_, [self.hps.batch_size // self.total_loop, image_size, image_size, 1])

        window_sizes_ = tf.reshape(window_sizes_non_grad, (-1, 1, 1, 1))  # (N, 1, 1, 1)
        window_sizes_ = tf.tile(window_sizes_, [1, image_size, image_size, 1])  # (N, image_size, image_size, 1)

        fn_inputs = tf.concat([input_img_, cursor_position_, image_size_, window_sizes_],
                              axis=-1)  # (N, image_size, image_size, 2/4 + 4)
        curr_patch_imgs = tf.map_fn(self.image_cropping_fn, fn_inputs, parallel_iterations=32)  # (N, raster_size, raster_size, -)
        return curr_patch_imgs

    def image_cropping_v3(self, cursor_position, input_img, image_size, window_sizes):
        """
        :param cursor_position: (N, 1, 2), float type, normalized to [0.0, 1.0)
        :param input_img: (N, image_size, image_size, k), [0.0-BG, 1.0-stroke]
        :param image_size: side length of input_img in pixels
        :param window_sizes: (N, 1, 1), float32, with grad (stop-gradiented here)
        """
        window_sizes_non_grad = tf.stop_gradient(window_sizes)  # (N, 1, 1), no grad

        cursor_pos = tf.multiply(cursor_position, tf.cast(image_size, tf.float32))
        cursor_x, cursor_y = tf.split(cursor_pos, 2, axis=-1)  # (N, 1, 1)

        y1 = cursor_y - (window_sizes_non_grad - 1.0) / 2
        x1 = cursor_x - (window_sizes_non_grad - 1.0) / 2
        y2 = y1 + (window_sizes_non_grad - 1.0)
        x2 = x1 + (window_sizes_non_grad - 1.0)
        boxes = tf.concat([y1, x1, y2, x2], axis=-1)  # (N, 1, 4)
        boxes = tf.squeeze(boxes, axis=1)  # (N, 4)
        boxes = boxes / tf.cast(image_size - 1, tf.float32)

        box_ind = tf.ones_like(cursor_x)[:, 0, 0]  # (N)
        box_ind = tf.cast(box_ind, dtype=tf.int32)
        box_ind = tf.cumsum(box_ind) - 1

        curr_patch_imgs = tf.image.crop_and_resize(input_img, boxes, box_ind,
                                                   crop_size=[self.hps.raster_size, self.hps.raster_size])
        #  (N, raster_size, raster_size, k), [0.0-BG, 1.0-stroke]
        return curr_patch_imgs
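
    # Reference sketch (hypothetical helper, not part of the original model):
    # the box arithmetic of image_cropping_v3() in NumPy. A window of side w
    # pixels centred on the cursor c spans [c - (w - 1) / 2, c + (w - 1) / 2];
    # dividing by (image_size - 1) gives the normalized [y1, x1, y2, x2] boxes
    # that tf.image.crop_and_resize expects (0 and 1 map to the first and last
    # pixel centres).
    @staticmethod
    def _crop_boxes_np(cursor_norm, window_sizes, image_size):
        import numpy as np
        cursor = np.asarray(cursor_norm, dtype=np.float64).reshape(-1, 2) * image_size  # (N, 2), (x, y)
        half = (np.asarray(window_sizes, dtype=np.float64).reshape(-1, 1) - 1.0) / 2.0  # (N, 1)
        x1 = cursor[:, 0:1] - half
        y1 = cursor[:, 1:2] - half
        boxes = np.concatenate([y1, x1, y1 + 2.0 * half, x1 + 2.0 * half], axis=-1)  # (N, 4)
        return boxes / (image_size - 1)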

    def get_pixel_value(self, img, x, y):
        """
        Utility function to get pixel values at integer coordinates (x, y) from a 4D image tensor.

        Input
        -----
        - img: tensor of shape (B, H, W, C)
        - x: int32 tensor of shape (B, H', W'), column indices into W
        - y: int32 tensor of shape (B, H', W'), row indices into H

        Returns
        -------
        - output: tensor of shape (B, H', W', C)
        """
        shape = tf.shape(x)
        batch_size = shape[0]
        height = shape[1]
        width = shape[2]

        batch_idx = tf.range(0, batch_size)
        batch_idx = tf.reshape(batch_idx, (batch_size, 1, 1))
        b = tf.tile(batch_idx, (1, height, width))

        indices = tf.stack([b, y, x], 3)

        return tf.gather_nd(img, indices)
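
    # Reference sketch (hypothetical helper, not part of the original model):
    # NumPy advanced indexing performs the same gather as get_pixel_value(),
    # i.e. out[b, i, j, :] = img[b, y[b, i, j], x[b, i, j], :].
    @staticmethod
    def _get_pixel_value_np(img, x, y):
        import numpy as np
        b = np.arange(img.shape[0]).reshape(-1, 1, 1)  # broadcasts against (B, H', W')
        return img[b, y, x]  # (B, H', W', C)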

    def image_pasting_nondiff_single(self, fn_inputs):
        patch_image = fn_inputs[:, :, 0:1]  # (raster_size, raster_size, 1), [0.0-BG, 1.0-stroke]
        cursor_pos = fn_inputs[0, 0, 1:3]  # (2), in large size
        image_size = tf.cast(fn_inputs[0, 0, 3], tf.int32)  # ()
        window_size = tf.cast(fn_inputs[0, 0, 4], tf.int32)  # ()

        patch_image_scaled = tf.expand_dims(patch_image, axis=0)  # (1, raster_size, raster_size, 1)
        patch_image_scaled = tf.image.resize_images(patch_image_scaled, (window_size, window_size),
                                                    method=tf.image.ResizeMethod.BILINEAR)
        patch_image_scaled = tf.squeeze(patch_image_scaled, axis=0)
        # patch_canvas_scaled: (window_size, window_size, 1)

        cursor_pos = tf.cast(tf.round(cursor_pos), dtype=tf.int32)  # (2)
        cursor_x, cursor_y = cursor_pos[0], cursor_pos[1]

        pad_up = cursor_y
        pad_down = image_size - cursor_y
        pad_left = cursor_x
        pad_right = image_size - cursor_x

        paddings = [[pad_up, pad_down],
                    [pad_left, pad_right],
                    [0, 0]]
        pad_img = tf.pad(patch_image_scaled, paddings=paddings, mode='CONSTANT',
                         constant_values=0.0)  # (H_p, W_p, 1), [0.0-BG, 1.0-stroke]

        crop_start = window_size // 2
        pasted_image = pad_img[crop_start: crop_start + image_size, crop_start: crop_start + image_size, :]
        return pasted_image
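
    # Reference sketch (hypothetical helper, not part of the original model):
    # the pad-then-crop paste of image_pasting_nondiff_single() in NumPy,
    # without the resize step (the patch is assumed to be window_size x
    # window_size already). Patch pixel (i, j) lands at canvas pixel
    # (cy + i - ws // 2, cx + j - ws // 2); pixels falling outside are dropped.
    @staticmethod
    def _paste_patch_np(patch, cursor_xy, image_size):
        import numpy as np
        ws = patch.shape[0]
        # cursor_xy is (x, y) in full-image pixels; round half to even like tf.round
        cx, cy = [int(np.round(v)) for v in cursor_xy]
        padded = np.pad(patch, ((cy, image_size - cy), (cx, image_size - cx)), mode='constant')
        start = ws // 2
        return padded[start:start + image_size, start:start + image_size]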

    def image_pasting_diff_single(self, fn_inputs):
        patch_canvas = fn_inputs[:, :, 0:1]  # (raster_size, raster_size, 1), [0.0-BG, 1.0-stroke]
        cursor_pos = fn_inputs[0, 0, 1:3]  # (2), in large size
        image_size = tf.cast(fn_inputs[0, 0, 3], tf.int32)  # ()
        window_size = tf.cast(fn_inputs[0, 0, 4], tf.int32)  # ()
        cursor_x, cursor_y = cursor_pos[0], cursor_pos[1]

        patch_canvas_scaled = tf.expand_dims(patch_canvas, axis=0)  # (1, raster_size, raster_size, 1)
        patch_canvas_scaled = tf.image.resize_images(patch_canvas_scaled, (window_size, window_size),
                                                     method=tf.image.ResizeMethod.BILINEAR)
        # patch_canvas_scaled: (1, window_size, window_size, 1)

        valid_canvas = self.image_pasting_diff_batch(patch_canvas_scaled,
                                                     tf.expand_dims(tf.expand_dims(cursor_pos, axis=0), axis=0),
                                                     window_size)
        valid_canvas = tf.squeeze(valid_canvas, axis=0)
        # (window_size + 1, window_size + 1, 1)

        pad_up = tf.cast(tf.floor(cursor_y), tf.int32)
        pad_down = image_size - 1 - tf.cast(tf.floor(cursor_y), tf.int32)
        pad_left = tf.cast(tf.floor(cursor_x), tf.int32)
        pad_right = image_size - 1 - tf.cast(tf.floor(cursor_x), tf.int32)

        paddings = [[pad_up, pad_down],
                    [pad_left, pad_right],
                    [0, 0]]
        pad_img = tf.pad(valid_canvas, paddings=paddings, mode='CONSTANT',
                         constant_values=0.0)  # (H_p, W_p, 1), [0.0-BG, 1.0-stroke]

        crop_start = window_size // 2
        pasted_image = pad_img[crop_start: crop_start + image_size, crop_start: crop_start + image_size, :]
        return pasted_image

    def image_pasting_diff_single_v3(self, fn_inputs):
        patch_canvas = fn_inputs[:, :, 0:1]  # (raster_size, raster_size, 1), [0.0-BG, 1.0-stroke]
        cursor_pos_a = fn_inputs[0, 0, 1:3]  # (2), float32, in large size
        image_size_a = tf.cast(fn_inputs[0, 0, 3], tf.int32)  # ()
        window_size_a = fn_inputs[0, 0, 4]  # (), float32, with grad
        raster_size_a = float(self.hps.raster_size)

        padding_size = tf.cast(tf.ceil(window_size_a / 2.0), tf.int32)

        x1y1_a = cursor_pos_a - window_size_a / 2.0  # (2), float32
        x2y2_a = cursor_pos_a + window_size_a / 2.0  # (2), float32

        x1y1_a_floor = tf.floor(x1y1_a)  # (2)
        x2y2_a_ceil = tf.ceil(x2y2_a)  # (2)

        cursor_pos_b_oricoord = (x1y1_a_floor + x2y2_a_ceil) / 2.0  # (2)
        cursor_pos_b = (cursor_pos_b_oricoord - x1y1_a) / window_size_a * raster_size_a  # (2)
        raster_size_b = (x2y2_a_ceil - x1y1_a_floor)  # (x, y)
        image_size_b = raster_size_a
        window_size_b = raster_size_a * (raster_size_b / window_size_a)  # (x, y)

        cursor_b_x, cursor_b_y = tf.split(cursor_pos_b, 2, axis=-1)  # (1)

        y1_b = cursor_b_y - (window_size_b[1] - 1.) / 2.
        x1_b = cursor_b_x - (window_size_b[0] - 1.) / 2.
        y2_b = y1_b + (window_size_b[1] - 1.)
        x2_b = x1_b + (window_size_b[0] - 1.)
        boxes_b = tf.concat([y1_b, x1_b, y2_b, x2_b], axis=-1)  # (4)
        boxes_b = boxes_b / tf.cast(image_size_b - 1, tf.float32)  # with grad to window_size_a

        box_ind_b = tf.ones((1,), dtype=tf.int32)  # (1)
        box_ind_b = tf.cumsum(box_ind_b) - 1

        patch_canvas = tf.expand_dims(patch_canvas, axis=0)  # (1, raster_size, raster_size, 1), [0.0-BG, 1.0-stroke]
        boxes_b = tf.expand_dims(boxes_b, axis=0)  # (1, 4)

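        # crop_size must be int32; raster_size_b holds whole numbers (ceil - floor), so the cast is exact
        valid_canvas = tf.image.crop_and_resize(patch_canvas, boxes_b, box_ind_b,
                                                crop_size=[tf.cast(raster_size_b[1], tf.int32),
                                                           tf.cast(raster_size_b[0], tf.int32)])
        valid_canvas = valid_canvas[0]  # (raster_size_b[1], raster_size_b[0], 1), i.e. (H, W, 1)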

        pad_up = tf.cast(x1y1_a_floor[1], tf.int32) + padding_size
        pad_down = image_size_a + padding_size - tf.cast(x2y2_a_ceil[1], tf.int32)
        pad_left = tf.cast(x1y1_a_floor[0], tf.int32) + padding_size
        pad_right = image_size_a + padding_size - tf.cast(x2y2_a_ceil[0], tf.int32)

        paddings = [[pad_up, pad_down],
                    [pad_left, pad_right],
                    [0, 0]]
        pad_img = tf.pad(valid_canvas, paddings=paddings, mode='CONSTANT',
                         constant_values=0.0)  # (H_p, W_p, 1), [0.0-BG, 1.0-stroke]

        pasted_image = pad_img[padding_size: padding_size + image_size_a, padding_size: padding_size + image_size_a, :]
        return pasted_image

    def image_pasting_diff_batch(self, patch_image, cursor_position, window_size):
        """
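        :param patch_image: (N, window_size, window_size, 1), [0.0-BG, 1.0-stroke]
        :param cursor_position: (N, 1, 2), in large size
        :param window_size: int32 scalar, side length of patch_image
        :return: (N, window_size + 1, window_size + 1, 1), patch_image bilinearly resampled at the
                 sub-pixel (fractional) offset of the cursor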
        """
        paddings1 = [[0, 0],
                     [1, 1],
                     [1, 1],
                     [0, 0]]
        patch_image_pad1 = tf.pad(patch_image, paddings=paddings1, mode='CONSTANT',
                                  constant_values=0.0)  # (N, window_size+2, window_size+2, 1), [0.0-BG, 1.0-stroke]

        cursor_x, cursor_y = cursor_position[:, :, 0:1], cursor_position[:, :, 1:2]  # (N, 1, 1)
        cursor_x_f, cursor_y_f = tf.floor(cursor_x), tf.floor(cursor_y)
        patch_x, patch_y = 1.0 - (cursor_x - cursor_x_f), 1.0 - (cursor_y - cursor_y_f)  # (N, 1, 1)

        x_ones = tf.ones_like(patch_x, dtype=tf.float32)  # (N, 1, 1)
        x_ones = tf.tile(x_ones, [1, 1, window_size])  # (N, 1, window_size)
        patch_x = tf.concat([patch_x, x_ones], axis=-1)  # (N, 1, window_size + 1)
        patch_x = tf.tile(patch_x, [1, window_size + 1, 1])  # (N, window_size + 1, window_size + 1)
        patch_x = tf.cumsum(patch_x, axis=-1)  # (N, window_size + 1, window_size + 1)
        patch_x0 = tf.cast(tf.floor(patch_x), tf.int32)  # (N, window_size + 1, window_size + 1)
        patch_x1 = patch_x0 + 1  # (N, window_size + 1, window_size + 1)

        y_ones = tf.ones_like(patch_y, dtype=tf.float32)  # (N, 1, 1)
        y_ones = tf.tile(y_ones, [1, window_size, 1])  # (N, window_size, 1)
        patch_y = tf.concat([patch_y, y_ones], axis=1)  # (N, window_size + 1, 1)
        patch_y = tf.tile(patch_y, [1, 1, window_size + 1])  # (N, window_size + 1, window_size + 1)
        patch_y = tf.cumsum(patch_y, axis=1)  # (N, window_size + 1, window_size + 1)
        patch_y0 = tf.cast(tf.floor(patch_y), tf.int32)  # (N, window_size + 1, window_size + 1)
        patch_y1 = patch_y0 + 1  # (N, window_size + 1, window_size + 1)

        # get pixel value at corner coords
        valid_canvas_patch_a = self.get_pixel_value(patch_image_pad1, patch_x0, patch_y0)
        valid_canvas_patch_b = self.get_pixel_value(patch_image_pad1, patch_x0, patch_y1)
        valid_canvas_patch_c = self.get_pixel_value(patch_image_pad1, patch_x1, patch_y0)
        valid_canvas_patch_d = self.get_pixel_value(patch_image_pad1, patch_x1, patch_y1)
        # (N, window_size + 1, window_size + 1, 1)

        patch_x0 = tf.cast(patch_x0, tf.float32)
        patch_x1 = tf.cast(patch_x1, tf.float32)
        patch_y0 = tf.cast(patch_y0, tf.float32)
        patch_y1 = tf.cast(patch_y1, tf.float32)

        # calculate deltas
        wa = (patch_x1 - patch_x) * (patch_y1 - patch_y)
        wb = (patch_x1 - patch_x) * (patch_y - patch_y0)
        wc = (patch_x - patch_x0) * (patch_y1 - patch_y)
        wd = (patch_x - patch_x0) * (patch_y - patch_y0)
        # (N, window_size + 1, window_size + 1)

        # add dimension for addition
        wa = tf.expand_dims(wa, axis=3)
        wb = tf.expand_dims(wb, axis=3)
        wc = tf.expand_dims(wc, axis=3)
        wd = tf.expand_dims(wd, axis=3)
        # (N, window_size + 1, window_size + 1, 1)

        # compute output
        valid_canvas_patch_ = tf.add_n([wa * valid_canvas_patch_a,
                                        wb * valid_canvas_patch_b,
                                        wc * valid_canvas_patch_c,
                                        wd * valid_canvas_patch_d])  # (N, window_size + 1, window_size + 1, 1)
        return valid_canvas_patch_
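
    # Reference sketch (hypothetical helper, not part of the original model):
    # a single-image, single-channel NumPy version of image_pasting_diff_batch().
    # The patch is sampled at (1 - frac) + k, k = 0..ws, in zero-padded
    # coordinates, i.e. shifted by the fractional part of the cursor, using
    # the same four bilinear weights wa..wd as above.
    @staticmethod
    def _subpixel_shift_np(patch, cursor_xy):
        import numpy as np
        ws = patch.shape[0]
        # zero border as in image_pasting_diff_batch, plus one extra trailing row/column
        # so x1/y1 stay in bounds when the fractional part is 0 (their weights are then 0)
        padded = np.pad(patch.astype(np.float64), ((1, 2), (1, 2)), mode='constant')
        fx = cursor_xy[0] - np.floor(cursor_xy[0])
        fy = cursor_xy[1] - np.floor(cursor_xy[1])
        # px varies along columns, py along rows, both of shape (ws + 1, ws + 1)
        px, py = np.meshgrid((1.0 - fx) + np.arange(ws + 1), (1.0 - fy) + np.arange(ws + 1))
        x0 = np.floor(px).astype(int)
        y0 = np.floor(py).astype(int)
        x1, y1 = x0 + 1, y0 + 1
        # bilinear weights of the four neighbours; they sum to 1 at every output pixel
        wa = (x1 - px) * (y1 - py)
        wb = (x1 - px) * (py - y0)
        wc = (px - x0) * (y1 - py)
        wd = (px - x0) * (py - y0)
        return (wa * padded[y0, x0] + wb * padded[y1, x0]
                + wc * padded[y0, x1] + wd * padded[y1, x1])  # (ws + 1, ws + 1)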

    def image_pasting(self, cursor_position_norm, patch_img, image_size, window_sizes, is_differentiable=False):
        """
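        Paste each patch back onto an image_size x image_size canvas, centred at cursor_position.
        :param cursor_position_norm: (N, 1, 2), float type, normalized to [0.0, 1.0)
        :param patch_img: (N, raster_size, raster_size), [0.0-BG, 1.0-stroke]
        :param image_size: side length of the output canvas in pixels
        :param window_sizes: (N, 1, 1), float32, with grad (rounded here, so no gradient flows through it)
        :param is_differentiable: if True, paste with bilinear sub-pixel interpolation
                                  (image_pasting_diff_single); otherwise round the cursor to whole pixels
                                  (image_pasting_nondiff_single)
        :return: (N, image_size, image_size), [0.0-BG, 1.0-stroke]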
        """
        cursor_position = tf.multiply(cursor_position_norm, tf.cast(image_size, tf.float32))  # in large size
        window_sizes_r = tf.round(window_sizes)  # (N, 1, 1), no grad

        patch_img_ = tf.expand_dims(patch_img, axis=-1)  # (N, raster_size, raster_size, 1)
        cursor_position_step = tf.reshape(cursor_position, (-1, 1, 1, 2))  # (N, 1, 1, 2)
        cursor_position_step = tf.tile(cursor_position_step, [1, self.hps.raster_size, self.hps.raster_size,
                                                              1])  # (N, raster_size, raster_size, 2)
        image_size_tile = tf.reshape(tf.cast(image_size, tf.float32), (1, 1, 1, 1))  # (1, 1, 1, 1)
        image_size_tile = tf.tile(image_size_tile, [self.hps.batch_size // self.total_loop, self.hps.raster_size,
                                                    self.hps.raster_size, 1])
        window_sizes_tile = tf.reshape(window_sizes_r, (-1, 1, 1, 1))  # (N, 1, 1, 1)
        window_sizes_tile = tf.tile(window_sizes_tile, [1, self.hps.raster_size, self.hps.raster_size, 1])

        pasting_inputs = tf.concat([patch_img_, cursor_position_step, image_size_tile, window_sizes_tile],
                                   axis=-1)  # (N, raster_size, raster_size, 5)

        if is_differentiable:
            curr_paste_imgs = tf.map_fn(self.image_pasting_diff_single, pasting_inputs,
                                        parallel_iterations=32)  # (N, image_size, image_size, 1)
        else:
            curr_paste_imgs = tf.map_fn(self.image_pasting_nondiff_single, pasting_inputs,
                                        parallel_iterations=32)  # (N, image_size, image_size, 1)
        curr_paste_imgs = tf.squeeze(curr_paste_imgs, axis=-1)  # (N, image_size, image_size)
        return curr_paste_imgs

    def image_pasting_v3(self, cursor_position_norm, patch_img, image_size, window_sizes, is_differentiable=False):
        """
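        Paste each patch back onto an image_size x image_size canvas, centred at cursor_position,
        keeping gradients w.r.t. window_sizes.
        :param cursor_position_norm: (N, 1, 2), float type, normalized to [0.0, 1.0)
        :param patch_img: (N, raster_size, raster_size), [0.0-BG, 1.0-stroke]
        :param image_size: side length of the output canvas in pixels
        :param window_sizes: (N, 1, 1), float32, with grad
        :param is_differentiable: only True is supported
        :return: (N, image_size, image_size), [0.0-BG, 1.0-stroke]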
        """
        cursor_position = tf.multiply(cursor_position_norm, tf.cast(image_size, tf.float32))  # in large size

        if is_differentiable:
            patch_img_ = tf.expand_dims(patch_img, axis=-1)  # (N, raster_size, raster_size, 1)
            cursor_position_step = tf.reshape(cursor_position, (-1, 1, 1, 2))  # (N, 1, 1, 2)
            cursor_position_step = tf.tile(cursor_position_step, [1, self.hps.raster_size, self.hps.raster_size,
                                           1])  # (N, raster_size, raster_size, 2)
            image_size_tile = tf.reshape(tf.cast(image_size, tf.float32), (1, 1, 1, 1))  # (1, 1, 1, 1)
            image_size_tile = tf.tile(image_size_tile, [self.hps.batch_size // self.total_loop, self.hps.raster_size,
                                      self.hps.raster_size, 1])
            window_sizes_tile = tf.reshape(window_sizes, (-1, 1, 1, 1))  # (N, 1, 1, 1)
            window_sizes_tile = tf.tile(window_sizes_tile, [1, self.hps.raster_size, self.hps.raster_size, 1])

            pasting_inputs = tf.concat([patch_img_, cursor_position_step, image_size_tile, window_sizes_tile],
                                       axis=-1)  # (N, raster_size, raster_size, 5)
            curr_paste_imgs = tf.map_fn(self.image_pasting_diff_single_v3, pasting_inputs,
                                        parallel_iterations=32)  # (N, image_size, image_size, 1)
        else:
            raise NotImplementedError('image_pasting_v3 only supports is_differentiable=True')
        curr_paste_imgs = tf.squeeze(curr_paste_imgs, axis=-1)  # (N, image_size, image_size)
        return curr_paste_imgs

    def get_points_and_raster_image(self, initial_state, init_cursor, input_photo, image_size):
        ## generate the other_params and pen_ras and raster image for raster loss
        prev_state = initial_state  # (N, dec_rnn_size * 3)

        prev_width = self.init_width  # (1)
        prev_width = tf.expand_dims(tf.expand_dims(prev_width, axis=0), axis=0)  # (1, 1, 1)
        prev_width = tf.tile(prev_width, [self.hps.batch_size // self.total_loop, 1, 1])  # (N, 1, 1)

        prev_scaling = tf.ones((self.hps.batch_size // self.total_loop, 1, 1))  # (N, 1, 1)
        prev_window_size = tf.ones((self.hps.batch_size // self.total_loop, 1, 1),
                                   dtype=tf.float32) * float(self.hps.raster_size)  # (N, 1, 1)

        cursor_position_temp = init_cursor
        self.cursor_position = cursor_position_temp  # (N, 1, 2), in size [0.0, 1.0)
        cursor_position_loop = self.cursor_position

        other_params_list = []
        pen_ras_list = []

        pos_before_max_min_list = []
        win_size_before_max_min_list = []

        curr_canvas_soft = tf.zeros_like(input_photo[:, :, :, 0])  # (N, image_size, image_size), [0.0-BG, 1.0-stroke]
        curr_canvas_soft_rgb = tf.tile(tf.zeros_like(input_photo[:, :, :, 0:1]), [1, 1, 1, 3])  # (N, image_size, image_size, 3), [0.0-BG, 1.0-stroke]
        curr_canvas_hard = tf.zeros_like(curr_canvas_soft)  # [0.0-BG, 1.0-stroke]

        #### sampling part - start ####
        self.curr_canvas_hard = curr_canvas_hard

        rasterizor_st = NeuralRasterizorStep(
            raster_size=self.hps.raster_size,
            position_format=self.hps.position_format)

        if self.hps.cropping_type == 'v3':
            cropping_func = self.image_cropping_v3
        # elif self.hps.cropping_type == 'v2':
        #     cropping_func = self.image_cropping
        else:
            raise Exception('Unknown cropping_type', self.hps.cropping_type)

        if self.hps.pasting_type == 'v3':
            pasting_func = self.image_pasting_v3
        # elif self.hps.pasting_type == 'v2':
        #     pasting_func = self.image_pasting
        else:
            raise Exception('Unknown pasting_type', self.hps.pasting_type)

        for time_i in range(self.hps.max_seq_len):
            cursor_position_non_grad = tf.stop_gradient(cursor_position_loop)  # (N, 1, 2), in range [0.0, 1.0)

            curr_window_size = tf.multiply(prev_scaling, tf.stop_gradient(prev_window_size))  # float, with grad
            curr_window_size = tf.maximum(curr_window_size, tf.cast(self.hps.min_window_size, tf.float32))
            curr_window_size = tf.minimum(curr_window_size, tf.cast(image_size, tf.float32))

            ## patch-level encoding
            # Block gradients from canvas_z to curr_canvas_hard (they become None) to avoid recurrent gradient propagation.
            curr_canvas_hard_non_grad = tf.stop_gradient(self.curr_canvas_hard)
            curr_canvas_hard_non_grad = tf.expand_dims(curr_canvas_hard_non_grad, axis=-1)

            # input_photo: (N, image_size, image_size, 1/3), [0.0-stroke, 1.0-BG]
            crop_inputs = tf.concat([1.0 - input_photo, curr_canvas_hard_non_grad], axis=-1)  # (N, H_p, W_p, 1/3+1)

            cropped_outputs = cropping_func(cursor_position_non_grad, crop_inputs, image_size, curr_window_size)
            index_offset = self.hps.input_channel - 1
            curr_patch_inputs = cropped_outputs[:, :, :, 0:1 + index_offset]  # [0.0-BG, 1.0-stroke]
            curr_patch_canvas_hard_non_grad = cropped_outputs[:, :, :, 1 + index_offset:2 + index_offset]
            # (N, raster_size, raster_size, 1), [0.0-BG, 1.0-stroke]

            curr_patch_inputs = 1.0 - curr_patch_inputs  # [0.0-stroke, 1.0-BG]
            curr_patch_inputs = self.normalize_image_m1to1(curr_patch_inputs)
            # (N, raster_size, raster_size, 1/3), [-1.0-stroke, 1.0-BG]

            # Normalizing image
            curr_patch_canvas_hard_non_grad = 1.0 - curr_patch_canvas_hard_non_grad  # [0.0-stroke, 1.0-BG]
            curr_patch_canvas_hard_non_grad = self.normalize_image_m1to1(curr_patch_canvas_hard_non_grad)  # [-1.0-stroke, 1.0-BG]

            ## image-level encoding
            combined_z = self.build_combined_encoder(
                curr_patch_canvas_hard_non_grad,
                curr_patch_inputs,
                1.0 - curr_canvas_hard_non_grad,
                input_photo,
                cursor_position_non_grad,
                image_size,
                curr_window_size)  # (N, z_size)
            combined_z = tf.expand_dims(combined_z, axis=1)  # (N, 1, z_size)

            curr_window_size_top_side_norm_non_grad = \
                tf.stop_gradient(curr_window_size / tf.cast(image_size, tf.float32))
            curr_window_size_bottom_side_norm_non_grad = \
                tf.stop_gradient(curr_window_size / tf.cast(self.hps.min_window_size, tf.float32))
            if not self.hps.concat_win_size:
                combined_z = tf.concat([tf.stop_gradient(prev_width), combined_z], 2)  # (N, 1, 1+z_size)
            else:
                combined_z = tf.concat([tf.stop_gradient(prev_width),
                                        curr_window_size_top_side_norm_non_grad,
                                        curr_window_size_bottom_side_norm_non_grad,
                                        combined_z],
                                       2)  # (N, 1, 3+z_size)

            if self.hps.concat_cursor:
                prev_input_x = tf.concat([cursor_position_non_grad, combined_z], 2)  # (N, 1, 2+1+z_size) or (N, 1, 2+3+z_size)
            else:
                prev_input_x = combined_z  # (N, 1, 1+z_size) or (N, 1, 3+z_size)

            h_output, next_state = self.build_seq_decoder(self.dec_cell, prev_input_x, prev_state)
            # h_output: (N * 1, n_out), next_state: (N, dec_rnn_size * 3)
            [o_other_params, o_pen_ras] = self.get_mixture_coef(h_output)
            # o_other_params: (N * 1, 6)
            # o_pen_ras: (N * 1, 2), after softmax

            o_other_params = tf.reshape(o_other_params, [-1, 1, 6])  # (N, 1, 6)
            o_pen_ras_raw = tf.reshape(o_pen_ras, [-1, 1, 2])  # (N, 1, 2)

            other_params_list.append(o_other_params)
            pen_ras_list.append(o_pen_ras_raw)

            #### sampling part - end ####

            if self.hps.model_mode == 'train' or self.hps.model_mode == 'eval' or self.hps.model_mode == 'eval_sample':
                # use renderer here to convert the strokes to image
                curr_other_params = tf.squeeze(o_other_params, axis=1)  # (N, 6), (x1, y1)=[0.0, 1.0], (x2, y2)=[-1.0, 1.0]
                x1y1, x2y2, width2, scaling = curr_other_params[:, 0:2], curr_other_params[:, 2:4],\
                                              curr_other_params[:, 4:5], curr_other_params[:, 5:6]
                x0y0 = tf.zeros_like(x2y2)  # (N, 2), [-1.0, 1.0]
                x0y0 = tf.div(tf.add(x0y0, 1.0), 2.0)  # (N, 2), [0.0, 1.0]
                x2y2 = tf.div(tf.add(x2y2, 1.0), 2.0)  # (N, 2), [0.0, 1.0]
                widths = tf.concat([tf.squeeze(prev_width, axis=1), width2], axis=1)  # (N, 2)
                curr_other_params = tf.concat([x0y0, x1y1, x2y2, widths], axis=-1)  # (N, 8), (x0, y0)&(x2, y2)=[0.0, 1.0]
                curr_stroke_image = rasterizor_st.raster_func_stroke_abs(curr_other_params)
                # (N, raster_size, raster_size), [0.0-BG, 1.0-stroke]

                curr_stroke_image_large = pasting_func(cursor_position_loop, curr_stroke_image,
                                                       image_size, curr_window_size,
                                                       is_differentiable=self.hps.pasting_diff)
                # (N, image_size, image_size), [0.0-BG, 1.0-stroke]

                ## soft
                if not self.hps.use_softargmax:
                    curr_state_soft = o_pen_ras[:, 1:2]  # (N, 1)
                else:
                    curr_state_soft = self.differentiable_argmax(o_pen_ras, self.hps.soft_beta)  # (N, 1)

                curr_state_soft = tf.expand_dims(curr_state_soft, axis=1)  # (N, 1, 1)

                filter_curr_stroke_image_soft = tf.multiply(tf.subtract(1.0, curr_state_soft), curr_stroke_image_large)
                # (N, image_size, image_size), [0.0-BG, 1.0-stroke]
                curr_canvas_soft = tf.add(curr_canvas_soft, filter_curr_stroke_image_soft)  # [0.0-BG, 1.0-stroke]

                ## hard
                curr_state_hard = tf.expand_dims(tf.cast(tf.argmax(o_pen_ras_raw, axis=-1), dtype=tf.float32),
                                                 axis=-1)  # (N, 1, 1)
                filter_curr_stroke_image_hard = tf.multiply(tf.subtract(1.0, curr_state_hard), curr_stroke_image_large)
                # (N, image_size, image_size), [0.0-BG, 1.0-stroke]
                self.curr_canvas_hard = tf.add(self.curr_canvas_hard, filter_curr_stroke_image_hard)  # [0.0-BG, 1.0-stroke]
                self.curr_canvas_hard = tf.clip_by_value(self.curr_canvas_hard, 0.0, 1.0)  # [0.0-BG, 1.0-stroke]

            next_width = o_other_params[:, :, 4:5]
            next_scaling = o_other_params[:, :, 5:6]
            next_window_size = tf.multiply(next_scaling, tf.stop_gradient(curr_window_size))  # float, with grad
            window_size_before_max_min = next_window_size  # (N, 1, 1), large-level
            win_size_before_max_min_list.append(window_size_before_max_min)
            next_window_size = tf.maximum(next_window_size, tf.cast(self.hps.min_window_size, tf.float32))
            next_window_size = tf.minimum(next_window_size, tf.cast(image_size, tf.float32))

            prev_state = next_state
            prev_width = next_width * curr_window_size / next_window_size  # (N, 1, 1)
            prev_scaling = next_scaling  # (N, 1, 1)
            prev_window_size = curr_window_size

            # update the cursor position
            new_cursor_offsets = tf.multiply(o_other_params[:, :, 2:4],
                                             tf.divide(curr_window_size, 2.0))  # (N, 1, 2), window-level
            new_cursor_offset_next = new_cursor_offsets
            new_cursor_offset_next = tf.concat([new_cursor_offset_next[:, :, 1:2], new_cursor_offset_next[:, :, 0:1]], axis=-1)

            cursor_position_loop_large = tf.multiply(cursor_position_loop, tf.cast(image_size, tf.float32))

            if self.hps.stop_accu_grad:
                stroke_position_next = tf.stop_gradient(cursor_position_loop_large) + new_cursor_offset_next  # (N, 1, 2), large-level
            else:
                stroke_position_next = cursor_position_loop_large + new_cursor_of
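The excerpt above is cut off partway through `get_points_and_raster_image` in model_common_train.py. The per-step bookkeeping it does show (window resizing with clamping, width rescaling, and the cursor offset) can be sketched in NumPy as below. This is an illustrative re-implementation, not code from the repository: the function name `step_window_and_cursor` is hypothetical, gradients and `tf.stop_gradient` are ignored, and the sketch stops at the pixel-space `stroke_position_next` because the rest of the cursor update is not shown. The swap of the last axis mirrors the `tf.concat` of `[1:2]` and `[0:1]`; reading this as the cursor being stored in (y, x) order is an inference.

```python
import numpy as np


def step_window_and_cursor(cursor_norm, curr_window_size, params,
                           image_size, min_window_size):
    """Hypothetical NumPy sketch of one step's window/width/cursor bookkeeping.

    cursor_norm:      (N, 1, 2) cursor position in [0, 1)
    curr_window_size: (N, 1, 1) current window side length in pixels
    params:           (N, 1, 6) predicted (x1, y1, x2, y2, width, scaling)
    """
    next_width = params[:, :, 4:5]
    next_scaling = params[:, :, 5:6]

    # Next window = scaling * current window, clamped to [min_window_size, image_size]
    # (same as tf.maximum followed by tf.minimum).
    next_window_size = np.clip(next_scaling * curr_window_size,
                               float(min_window_size), float(image_size))

    # The width is relative to the window, so rescale it to the new window size.
    prev_width = next_width * curr_window_size / next_window_size

    # (x2, y2) in [-1, 1] is an offset measured in half-window units.
    offsets = params[:, :, 2:4] * (curr_window_size / 2.0)
    # Swap the two coordinates, as the TF code does with concat([1:2], [0:1]).
    offsets = offsets[:, :, ::-1]

    # Next stroke position in pixels ("large-level"), before any normalization.
    stroke_position_next = cursor_norm * float(image_size) + offsets
    return next_window_size, prev_width, stroke_position_next
```

For example, with a 64-pixel window, `scaling = 0.25` and `min_window_size = 32`, the next window is clamped up to 32 pixels, so a predicted width of 0.2 becomes 0.4 relative to the smaller window.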
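Inside the same loop of `get_points_and_raster_image`, each predicted stroke is packed into an 8-value vector for `NeuralRasterizorStep.raster_func_stroke_abs` and, after pasting, added to a soft and a hard canvas. The NumPy sketch below reproduces those two steps under stated assumptions: the helper names `pack_stroke_params` and `update_canvases` are hypothetical, the stroke is assumed to be already pasted into an image-sized array, and the `use_softargmax` branch (`differentiable_argmax`) is not reproduced. Pen-state index 1 acts as a "do not draw" state because the code multiplies the stroke by one minus it.

```python
import numpy as np


def pack_stroke_params(prev_width, params):
    """Hypothetical sketch: build the (N, 8) rasterizer input for one step.

    prev_width: (N, 1, 1); params: (N, 6) = (x1, y1, x2, y2, width2, scaling)
    """
    x1y1 = params[:, 0:2]                  # control point, already in [0, 1]
    x2y2 = (params[:, 2:4] + 1.0) / 2.0    # end point, [-1, 1] -> [0, 1]
    x0y0 = np.full_like(x1y1, 0.5)         # start point is the window centre (0 -> 0.5)
    widths = np.concatenate([prev_width[:, 0, :], params[:, 4:5]], axis=1)  # (N, 2)
    return np.concatenate([x0y0, x1y1, x2y2, widths], axis=-1)             # (N, 8)


def update_canvases(canvas_soft, canvas_hard, stroke_large, pen_probs):
    """Hypothetical sketch of the soft/hard canvas update.

    canvas_*, stroke_large: (N, H, W) in [0.0-BG, 1.0-stroke]
    pen_probs: (N, 2) softmax output of the pen-state head
    """
    # Soft: weight the stroke by 1 - P(state 1); this canvas is not clipped.
    p_state1 = pen_probs[:, 1:2][:, :, None]                       # (N, 1, 1)
    canvas_soft = canvas_soft + (1.0 - p_state1) * stroke_large
    # Hard: use the argmax state, then clip the accumulated canvas to [0, 1].
    hard_state = np.argmax(pen_probs, axis=-1).astype(np.float32)[:, None, None]
    canvas_hard = np.clip(canvas_hard + (1.0 - hard_state) * stroke_large, 0.0, 1.0)
    return canvas_soft, canvas_hard
```

Keeping the soft canvas unclipped while clipping the hard one follows the excerpt; only the hard canvas is fed back (through `tf.stop_gradient`) as input for the next step.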
SYMBOL INDEX (226 symbols across 19 files)

FILE: dataset_utils.py
  function copy_hparams (line 12) | def copy_hparams(hparams):
  class GeneralRawDataLoader (line 17) | class GeneralRawDataLoader(object):
    method __init__ (line 18) | def __init__(self,
    method get_test_image (line 26) | def get_test_image(self, random_cursor=True, init_cursor_on_undrawn_pi...
    method gen_input_images (line 35) | def gen_input_images(self, image_path):
    method crop_patch (line 78) | def crop_patch(self, image, center, image_size, crop_size):
    method gen_init_cursor_single (line 90) | def gen_init_cursor_single(self, sketch_image, init_cursor_on_undrawn_...
    method gen_init_cursors (line 108) | def gen_init_cursors(self, sketch_data, random_pos=True, init_cursor_o...
  function load_dataset_testing (line 134) | def load_dataset_testing(test_data_base_dir, test_dataset, test_img_name...
  class GeneralMultiObjectDataLoader (line 156) | class GeneralMultiObjectDataLoader(object):
    method __init__ (line 157) | def __init__(self,
    method preprocess_rand_data (line 182) | def preprocess_rand_data(self, stroke3):
    method cal_dist (line 187) | def cal_dist(self, posA, posB):
    method invalid_position (line 190) | def invalid_position(self, pos, obj_size, pos_list, size_list):
    method get_object_info (line 205) | def get_object_info(self, image_size, vary_thickness=True, try_total_t...
    method object_pasting (line 255) | def object_pasting(self, obj_img, canvas_img, center):
    method get_multi_object_image (line 279) | def get_multi_object_image(self, img_size, vary_thickness):
    method get_batch_from_memory (line 301) | def get_batch_from_memory(self, memory_idx, vary_thickness, fixed_imag...
    method get_batch_multi_res (line 323) | def get_batch_multi_res(self, loop_num, vary_thickness, random_cursor=...
    method gen_stroke_images (line 349) | def gen_stroke_images(self, stroke3_list, image_size, stroke_width):
    method crop_patch (line 361) | def crop_patch(self, image, center, image_size, crop_size):
    method gen_init_cursor_single (line 373) | def gen_init_cursor_single(self, sketch_image, init_cursor_on_undrawn_...
    method gen_init_cursors (line 391) | def gen_init_cursors(self, sketch_data, random_pos=True, init_cursor_o...
  function load_dataset_multi_object (line 417) | def load_dataset_multi_object(dataset_base_dir, model_params):
  class GeneralDataLoaderMultiObjectRough (line 473) | class GeneralDataLoaderMultiObjectRough(object):
    method __init__ (line 474) | def __init__(self,
    method rough_augmentation (line 508) | def rough_augmentation(self, raw_photo, texture_prob=0.20, noise_prob=...
    method image_interpolation (line 583) | def image_interpolation(self, photo, sketch, photo_prob):
    method get_batch_from_memory (line 588) | def get_batch_from_memory(self, memory_idx, interpolate_type, fixed_im...
    method select_sketch (line 625) | def select_sketch(self, image_size_rand):
    method get_batch_multi_res (line 638) | def get_batch_multi_res(self, loop_num, interpolate_type, random_curso...
    method crop_patch (line 678) | def crop_patch(self, image, center, image_size, crop_size):
    method gen_init_cursor_single (line 690) | def gen_init_cursor_single(self, sketch_image):
    method gen_init_cursors (line 703) | def gen_init_cursors(self, sketch_data, random_pos=True, init_cursor_n...
  function load_dataset_multi_object_rough (line 729) | def load_dataset_multi_object_rough(dataset_base_dir, model_params):
  class GeneralDataLoaderNormalImageLinear (line 893) | class GeneralDataLoaderNormalImageLinear(object):
    method __init__ (line 894) | def __init__(self,
    method get_batch_from_memory (line 928) | def get_batch_from_memory(self, memory_idx, interpolate_type, fixed_im...
    method crop_and_augment (line 948) | def crop_and_augment(self, photo, sketch, shape, crop_size, rotate_ang...
    method image_interpolation (line 1002) | def image_interpolation(self, photo, sketch, photo_prob):
    method select_sketch_and_crop (line 1007) | def select_sketch_and_crop(self, image_size_rand, interpolate_type, ro...
    method get_batch_multi_res (line 1043) | def get_batch_multi_res(self, loop_num, interpolate_type, random_curso...
    method crop_patch (line 1083) | def crop_patch(self, image, center, image_size, crop_size):
    method gen_init_cursor_single (line 1095) | def gen_init_cursor_single(self, sketch_image):
    method gen_init_cursors (line 1108) | def gen_init_cursors(self, sketch_data, random_pos=True, init_cursor_n...
  function load_dataset_normal_images (line 1134) | def load_dataset_normal_images(dataset_base_dir, model_params):
  function load_dataset_training (line 1216) | def load_dataset_training(dataset_base_dir, model_params):

FILE: hyper_parameters.py
  function get_default_hparams_clean (line 40) | def get_default_hparams_clean():
  function get_default_hparams_rough (line 141) | def get_default_hparams_rough():
  function get_default_hparams_normal (line 242) | def get_default_hparams_normal():

FILE: model_common_test.py
  class DiffPastingV3 (line 11) | class DiffPastingV3(object):
    method __init__ (line 12) | def __init__(self, raster_size):
    method image_pasting_sampling_v3 (line 23) | def image_pasting_sampling_v3(self):
  class VirtualSketchingModel (line 75) | class VirtualSketchingModel(object):
    method __init__ (line 76) | def __init__(self, hps, gpu_mode=True, reuse=False):
    method build_model (line 97) | def build_model(self):
    method config_model (line 122) | def config_model(self):
    method normalize_image_m1to1 (line 176) | def normalize_image_m1to1(self, in_img_0to1):
    method add_coords (line 181) | def add_coords(self, input_tensor):
    method build_combined_encoder (line 215) | def build_combined_encoder(self, patch_canvas, patch_photo, entire_can...
    method build_seq_decoder (line 311) | def build_seq_decoder(self, dec_cell, actual_input_x, initial_state):
    method get_mixture_coef (line 332) | def get_mixture_coef(self, outputs):
    method get_decoder_inputs (line 355) | def get_decoder_inputs(self):
    method rnn_decoder (line 359) | def rnn_decoder(self, dec_cell, initial_state, actual_input_x):
    method image_padding (line 372) | def image_padding(self, ori_image, window_size, pad_value):
    method image_cropping_fn (line 385) | def image_cropping_fn(self, fn_inputs):
    method image_cropping (line 412) | def image_cropping(self, cursor_position, input_img, image_size, windo...
    method image_cropping_v3 (line 435) | def image_cropping_v3(self, cursor_position, input_img, image_size, wi...
    method get_points_and_raster_image (line 463) | def get_points_and_raster_image(self, image_size):
    method differentiable_argmax (line 575) | def differentiable_argmax(self, input_pen, soft_beta):

FILE: model_common_train.py
  class VirtualSketchingModel (line 13) | class VirtualSketchingModel(object):
    method __init__ (line 14) | def __init__(self, hps, gpu_mode=True, reuse=False):
    method build_model (line 35) | def build_model(self):
    method config_model (line 150) | def config_model(self):
    method normalize_image_m1to1 (line 226) | def normalize_image_m1to1(self, in_img_0to1):
    method add_coords (line 231) | def add_coords(self, input_tensor):
    method build_combined_encoder (line 265) | def build_combined_encoder(self, patch_canvas, patch_photo, entire_can...
    method build_seq_decoder (line 361) | def build_seq_decoder(self, dec_cell, actual_input_x, initial_state):
    method get_mixture_coef (line 382) | def get_mixture_coef(self, outputs):
    method get_decoder_inputs (line 405) | def get_decoder_inputs(self):
    method rnn_decoder (line 409) | def rnn_decoder(self, dec_cell, initial_state, actual_input_x):
    method image_padding (line 422) | def image_padding(self, ori_image, window_size, pad_value):
    method image_cropping_fn (line 435) | def image_cropping_fn(self, fn_inputs):
    method image_cropping (line 462) | def image_cropping(self, cursor_position, input_img, image_size, windo...
    method image_cropping_v3 (line 485) | def image_cropping_v3(self, cursor_position, input_img, image_size, wi...
    method get_pixel_value (line 513) | def get_pixel_value(self, img, x, y):
    method image_pasting_nondiff_single (line 540) | def image_pasting_nondiff_single(self, fn_inputs):
    method image_pasting_diff_single (line 570) | def image_pasting_diff_single(self, fn_inputs):
    method image_pasting_diff_single_v3 (line 603) | def image_pasting_diff_single_v3(self, fn_inputs):
    method image_pasting_diff_batch (line 657) | def image_pasting_diff_batch(self, patch_image, cursor_position, windo...
    method image_pasting (line 723) | def image_pasting(self, cursor_position_norm, patch_img, image_size, w...
    method image_pasting_v3 (line 756) | def image_pasting_v3(self, cursor_position_norm, patch_img, image_size...
    method get_points_and_raster_image (line 786) | def get_points_and_raster_image(self, initial_state, init_cursor, inpu...
    method differentiable_argmax (line 991) | def differentiable_argmax(self, input_pen, soft_beta):
    method build_losses (line 1022) | def build_losses(self, target_sketch, pred_raster_imgs, pred_params,
    method build_training_op_split (line 1146) | def build_training_op_split(self, raster_cost, sn_cost, cursor_outside...
    method build_training_op (line 1159) | def build_training_op(self, grad_list):
    method average_gradients (line 1174) | def average_gradients(self, grads_list):

FILE: rasterization_utils/NeuralRenderer.py
  class RasterUnit (line 4) | class RasterUnit(object):
    method __init__ (line 5) | def __init__(self,
    method build_unit (line 15) | def build_unit(self):
    method conv2d (line 47) | def conv2d(self, input_tensor, out_channels, kernel_size, stride, scop...
    method fully_connected (line 54) | def fully_connected(self, input_tensor, in_dim, out_dim, scope, reuse=...
    method pixel_shuffle (line 63) | def pixel_shuffle(self, input_tensor, upscale_factor):
  class NeuralRasterizor (line 73) | class NeuralRasterizor(object):
    method __init__ (line 74) | def __init__(self,
    method raster_func_abs (line 88) | def raster_func_abs(self, input_data, raster_seq_len=None):
    method stroke_drawer_with_raster_unit (line 111) | def stroke_drawer_with_raster_unit(self, params_batch):
  class NeuralRasterizorStep (line 126) | class NeuralRasterizorStep(object):
    method __init__ (line 127) | def __init__(self,
    method raster_func_stroke_abs (line 135) | def raster_func_stroke_abs(self, input_data):
    method mask_ending_state (line 148) | def mask_ending_state(self, input_states):
    method stroke_drawer_with_raster_unit (line 159) | def stroke_drawer_with_raster_unit(self, params_batch):

FILE: rasterization_utils/RealRenderer.py
  class GizehRasterizor (line 5) | class GizehRasterizor(object):
    method __init__ (line 6) | def __init__(self):
    method get_line_array_v2 (line 9) | def get_line_array_v2(self, image_size, seq_strokes, stroke_width, is_...
    method get_line_array (line 39) | def get_line_array(self, p1, p2, image_size, stroke_width, is_bin=True):
    method load_sketch_images_on_the_fly_v2 (line 60) | def load_sketch_images_on_the_fly_v2(self, image_size, norm_strokes3, ...
    method load_sketch_images_on_the_fly (line 75) | def load_sketch_images_on_the_fly(self, image_size, norm_strokes3, str...
    method normalize_coordinate_np (line 102) | def normalize_coordinate_np(self, sx, sy, image_size, raster_padding=1...
    method normalize_strokes_np (line 148) | def normalize_strokes_np(self, strokes_list, image_size):
    method raster_func (line 166) | def raster_func(self, input_data, image_size, stroke_width, is_bin=Tru...

FILE: rnn.py
  function orthogonal (line 25) | def orthogonal(shape):
  function orthogonal_initializer (line 34) | def orthogonal_initializer(scale=1.0):
  function lstm_ortho_initializer (line 44) | def lstm_ortho_initializer(scale=1.0):
  class LSTMCell (line 61) | class LSTMCell(tf.contrib.rnn.RNNCell):
    method __init__ (line 68) | def __init__(self,
    method state_size (line 79) | def state_size(self):
    method output_size (line 83) | def output_size(self):
    method get_output (line 86) | def get_output(self, state):
    method __call__ (line 90) | def __call__(self, x, state, scope=None):
  function layer_norm_all (line 126) | def layer_norm_all(h,
  function layer_norm (line 160) | def layer_norm(x,
  function raw_layer_norm (line 188) | def raw_layer_norm(x, epsilon=1e-3):
  function super_linear (line 197) | def super_linear(x,
  class LayerNormLSTMCell (line 237) | class LayerNormLSTMCell(tf.contrib.rnn.RNNCell):
    method __init__ (line 244) | def __init__(self,
    method input_size (line 263) | def input_size(self):
    method output_size (line 267) | def output_size(self):
    method state_size (line 271) | def state_size(self):
    method get_output (line 274) | def get_output(self, state):
    method __call__ (line 278) | def __call__(self, x, state, timestep=0, scope=None):
  class HyperLSTMCell (line 314) | class HyperLSTMCell(tf.contrib.rnn.RNNCell):
    method __init__ (line 321) | def __init__(self,
    method input_size (line 368) | def input_size(self):
    method output_size (line 372) | def output_size(self):
    method state_size (line 376) | def state_size(self):
    method get_output (line 379) | def get_output(self, state):
    method hyper_norm (line 384) | def hyper_norm(self, layer, scope='hyper', use_bias=True):
    method __call__ (line 425) | def __call__(self, x, state, timestep=0, scope=None):

FILE: subnet_tf_utils.py
  function get_initializer (line 4) | def get_initializer(init_method):
  function lrelu (line 22) | def lrelu(x, leak=0.2, name="lrelu", alt_relu_impl=False):
  function batchnorm (line 32) | def batchnorm(input, name='batch_norm', init_method=None):
  function layernorm (line 52) | def layernorm(input, name='layer_norm', init_method=None):
  function instance_norm (line 71) | def instance_norm(input, name="instance_norm", init_method=None):
  function linear1d (line 88) | def linear1d(inputlin, inputdim, outputdim, name="linear1d", init_method...
  function general_conv2d (line 100) | def general_conv2d(inputconv, output_dim=64, filter_height=4, filter_wid...
  function generative_cnn_c3_encoder (line 135) | def generative_cnn_c3_encoder(inputs, is_training=True, drop_keep_prob=0...
  function generative_cnn_c3_encoder_deeper (line 172) | def generative_cnn_c3_encoder_deeper(inputs, is_training=True, drop_keep...
  function generative_cnn_c3_encoder_combine33 (line 209) | def generative_cnn_c3_encoder_combine33(local_inputs, global_inputs, is_...
  function generative_cnn_c3_encoder_combine43 (line 276) | def generative_cnn_c3_encoder_combine43(local_inputs, global_inputs, is_...
  function generative_cnn_c3_encoder_combine53 (line 350) | def generative_cnn_c3_encoder_combine53(local_inputs, global_inputs, is_...
  function generative_cnn_c3_encoder_combineFC (line 431) | def generative_cnn_c3_encoder_combineFC(local_inputs, global_inputs, is_...
  function generative_cnn_c3_encoder_combineFC_jointAttn (line 516) | def generative_cnn_c3_encoder_combineFC_jointAttn(local_inputs, global_i...
  function generative_cnn_c3_encoder_combineFC_sepAttn (line 618) | def generative_cnn_c3_encoder_combineFC_sepAttn(local_inputs, global_inp...
  function generative_cnn_c3_encoder_deeper13 (line 730) | def generative_cnn_c3_encoder_deeper13(inputs, is_training=True, drop_ke...
  function generative_cnn_c3_encoder_deeper13_attn (line 775) | def generative_cnn_c3_encoder_deeper13_attn(inputs, is_training=True, dr...
  function generative_cnn_encoder (line 822) | def generative_cnn_encoder(inputs, is_training=True, drop_keep_prob=0.5,...
  function generative_cnn_encoder_deeper (line 854) | def generative_cnn_encoder_deeper(inputs, is_training=True, drop_keep_pr...
  function generative_cnn_encoder_deeper13 (line 886) | def generative_cnn_encoder_deeper13(inputs, is_training=True, drop_keep_...
  function max_pooling (line 931) | def max_pooling(x) :
  function hw_flatten (line 935) | def hw_flatten(x) :
  function self_attention (line 939) | def self_attention(x, in_channel, name='self_attention'):
  function global_avg_pooling (line 966) | def global_avg_pooling(x):
  function cnn_discriminator_wgan_gp (line 971) | def cnn_discriminator_wgan_gp(discrim_inputs, discrim_targets, init_meth...

FILE: test_photograph_to_line.py
  function sample (line 17) | def sample(sess, model, input_photos, init_cursor, image_size, init_len,...
  function main_testing (line 128) | def main_testing(test_image_base_dir, test_dataset, test_image_name,
  function main (line 228) | def main(model_name, test_image_name, sampling_num):

FILE: test_rough_sketch_simplification.py
  function move_cursor_to_undrawn (line 17) | def move_cursor_to_undrawn(current_pos_list, input_image_, patch_size,
  function sample (line 81) | def sample(sess, model, input_photos, init_cursor, image_size, init_len,...
  function main_testing (line 220) | def main_testing(test_image_base_dir, test_dataset, test_image_name,
  function main (line 324) | def main(model_name, test_image_name, sampling_num):

FILE: test_vectorization.py
  function move_cursor_to_undrawn (line 19) | def move_cursor_to_undrawn(current_canvas_list, input_image_, last_min_a...
  function sample (line 154) | def sample(sess, model, input_photos, init_cursor, image_size, init_len,...
  function main_testing (line 300) | def main_testing(test_image_base_dir, test_dataset, test_image_name,
  function main (line 428) | def main(model_name, test_image_name, sampling_num):

FILE: tools/gif_making.py
  function add_scaling_visualization (line 15) | def add_scaling_visualization(canvas_images, cursor, window_size, image_...
  function make_gif (line 56) | def make_gif(sess, pasting_func, data, init_cursor, image_size, infer_le...
  function gif_making (line 159) | def gif_making(npz_path):

FILE: tools/svg_conversion.py
  function write_svg_1 (line 7) | def write_svg_1(path_list, img_size, save_path):
  function write_svg_2 (line 62) | def write_svg_2(path_list, img_size, save_path):
  function convert_strokes_to_svg (line 116) | def convert_strokes_to_svg(data, init_cursor, image_size, infer_lengths,...
  function data_convert_to_absolute (line 198) | def data_convert_to_absolute(npz_path, svg_type):

FILE: tools/visualize_drawing.py
  function display_strokes_final (line 15) | def display_strokes_final(sess, pasting_func, data, init_cursor, image_s...
  function visualize_drawing (line 172) | def visualize_drawing(npz_path):

FILE: train_rough_photograph.py
  function should_save_log_img (line 20) | def should_save_log_img(step_):
  function save_log_images (line 27) | def save_log_images(sess, model, data_set, save_root, step_num, curr_pho...
  function train (line 90) | def train(sess, train_model, eval_sample_model, train_set, valid_set, su...
  function trainer (line 279) | def trainer(model_params):
  function main (line 326) | def main(dataset_type):

FILE: train_vectorization.py
  function should_save_log_img (line 19) | def should_save_log_img(step_):
  function save_log_images (line 26) | def save_log_images(sess, model, data_set, save_root, step_num, save_num...
  function train (line 87) | def train(sess, train_model, eval_sample_model, train_set, val_set, sub_...
  function trainer (line 260) | def trainer(model_params):
  function main (line 307) | def main():

FILE: utils.py
  function reset_graph (line 14) | def reset_graph():
  function load_checkpoint (line 22) | def load_checkpoint(sess, checkpoint_path, ras_only=False, perceptual_on...
  function create_summary (line 64) | def create_summary(summary_writer, summ_map, step):
  function save_model (line 73) | def save_model(sess, saver, model_save_path, global_step):
  function normal (line 85) | def normal(x, width):
  function draw (line 89) | def draw(f, width=128):
  function rgb_trans (line 113) | def rgb_trans(split_num, break_values):
  function get_colors (line 132) | def get_colors(color_num):
  function save_seq_data (line 154) | def save_seq_data(save_root, save_filename, strokes_data, init_cursors, ...
  function image_pasting_v3_testing (line 162) | def image_pasting_v3_testing(patch_image, cursor, image_size, window_siz...
  function draw_strokes (line 180) | def draw_strokes(data, save_root, save_filename, input_img, image_size, ...
  function update_hyperparams (line 363) | def update_hyperparams(model_params, model_base_dir, model_name, infer_d...

FILE: vgg_utils/VGG16.py
  function vgg_net (line 4) | def vgg_net(x, n_classes, img_size, reuse, is_train=True, dropout_rate=0...
  function vgg_net_slim (line 68) | def vgg_net_slim(x, img_size):

FILE: virtual_sketch_gui.py
  class VirtualSketchApp (line 19) | class VirtualSketchApp:
    method __init__ (line 20) | def __init__(self, root):
    method build_ui (line 28) | def build_ui(self):
    method choose_file (line 44) | def choose_file(self):
    method run_processing (line 57) | def run_processing(self):
    method move_outputs_to_sketches (line 77) | def move_outputs_to_sketches(self):
    method run_svg_conversion (line 91) | def run_svg_conversion(self):
Condensed preview: 28 files, each showing path, character count, and a content snippet.
[
  {
    "path": ".gitignore",
    "chars": 153,
    "preview": ".idea\n.idea/\ndata/\ndatas/\ndataset/\ndatasets/\nmodel/\nmodels/\ntestData/\noutput/\noutputs/\n\n*.csv\n\n# temporary files\n*.txt~\n"
  },
  {
    "path": "LICENSE",
    "chars": 11357,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "README.md",
    "chars": 8265,
    "preview": "# General Virtual Sketching Framework for Vector Line Art - SIGGRAPH 2021\n\n[[Paper]](https://esslab.jp/publications/Haor"
  },
  {
    "path": "README_CN.md",
    "chars": 4752,
    "preview": "# General Virtual Sketching Framework for Vector Line Art - SIGGRAPH 2021\r\n\r\n[[论文]](https://esslab.jp/publications/Haora"
  },
  {
    "path": "WINDOWS_INSTALL_GUIDE.md",
    "chars": 3478,
    "preview": "# 🪟 Windows Installation Guide for Virtual Sketching\n\nThis guide provides step-by-step instructions to set up and run th"
  },
  {
    "path": "dataset_utils.py",
    "chars": 57414,
    "preview": "import os\nimport math\nimport random\nimport scipy.io\nimport numpy as np\nimport tensorflow as tf\nfrom PIL import Image\n\nfr"
  },
  {
    "path": "docs/assets/font.css",
    "chars": 1532,
    "preview": "/* Homepage Font */\n\n/* latin-ext */\n@font-face {\n  font-family: 'Lato';\n  font-style: normal;\n  font-weight: 400;\n  src"
  },
  {
    "path": "docs/assets/style.css",
    "chars": 2185,
    "preview": "/* Body */\nbody {\n  background: #e3e5e8;\n  color: #ffffff;\n  font-family: 'Lato', Verdana, Helvetica, sans-serif;\n  font"
  },
  {
    "path": "docs/index.html",
    "chars": 14021,
    "preview": "<!doctype html>\r\n<html lang=\"en\">\r\n\r\n\r\n<!-- === Header Starts === -->\r\n<head>\r\n  <meta http-equiv=\"Content-Type\" content"
  },
  {
    "path": "hyper_parameters.py",
    "chars": 10723,
    "preview": "import tensorflow as tf\n\n\n#############################################\n# Common parameters\n############################"
  },
  {
    "path": "launch_gui.bat",
    "chars": 456,
    "preview": "@echo OFF\n\nREM === Cesta k instalaci Anacondy ===\nset \"CONDAPATH=C:\\ProgramData\\anaconda3\"\n\nREM === Název a cesta k pros"
  },
  {
    "path": "model_common_test.py",
    "chars": 31154,
    "preview": "import rnn\nimport tensorflow as tf\n\nfrom subnet_tf_utils import generative_cnn_encoder, generative_cnn_encoder_deeper, g"
  },
  {
    "path": "model_common_train.py",
    "chars": 64864,
    "preview": "import rnn\nimport tensorflow as tf\n\nfrom subnet_tf_utils import generative_cnn_encoder, generative_cnn_encoder_deeper, g"
  },
  {
    "path": "rasterization_utils/NeuralRenderer.py",
    "chars": 7285,
    "preview": "import tensorflow as tf\n\n\nclass RasterUnit(object):\n    def __init__(self,\n                 raster_size,\n               "
  },
  {
    "path": "rasterization_utils/RealRenderer.py",
    "chars": 7945,
    "preview": "import numpy as np\nimport gizeh\n\n\nclass GizehRasterizor(object):\n    def __init__(self):\n        self.name = 'GizehRaste"
  },
  {
    "path": "rnn.py",
    "chars": 18018,
    "preview": "# Copyright 2019 The Magenta Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not "
  },
  {
    "path": "subnet_tf_utils.py",
    "chars": 65450,
    "preview": "import tensorflow as tf\n\n\ndef get_initializer(init_method):\n    if init_method == 'xavier_normal':\n        initializer ="
  },
  {
    "path": "test_photograph_to_line.py",
    "chars": 11899,
    "preview": "import numpy as np\nimport os\nimport tensorflow as tf\nfrom six.moves import range\nfrom PIL import Image\nimport argparse\n\n"
  },
  {
    "path": "test_rough_sketch_simplification.py",
    "chars": 16557,
    "preview": "import numpy as np\nimport os\nimport tensorflow as tf\nfrom six.moves import range\nfrom PIL import Image\nimport argparse\n\n"
  },
  {
    "path": "test_vectorization.py",
    "chars": 21177,
    "preview": "import numpy as np\nimport random\nimport os\nimport tensorflow as tf\nfrom six.moves import range\nfrom PIL import Image\nimp"
  },
  {
    "path": "tools/gif_making.py",
    "chars": 7842,
    "preview": "import os\nimport sys\nimport argparse\nimport numpy as np\nfrom PIL import Image\nimport tensorflow as tf\n\nsys.path.append('"
  },
  {
    "path": "tools/svg_conversion.py",
    "chars": 9249,
    "preview": "import os\nimport argparse\nimport numpy as np\nfrom xml.dom import minidom\n\n\ndef write_svg_1(path_list, img_size, save_pat"
  },
  {
    "path": "tools/visualize_drawing.py",
    "chars": 9862,
    "preview": "import os\nimport sys\nimport argparse\nimport numpy as np\nfrom PIL import Image\nimport tensorflow as tf\n\nsys.path.append('"
  },
  {
    "path": "train_rough_photograph.py",
    "chars": 15825,
    "preview": "import json\nimport os\nimport time\nimport numpy as np\nimport six\nimport tensorflow as tf\nfrom PIL import Image\nimport arg"
  },
  {
    "path": "train_vectorization.py",
    "chars": 14317,
    "preview": "import json\nimport os\nimport time\nimport numpy as np\nimport six\nimport tensorflow as tf\nfrom PIL import Image\n\nimport mo"
  },
  {
    "path": "utils.py",
    "chars": 15630,
    "preview": "import os\nimport cv2\nimport json\nimport numpy as np\nimport tensorflow as tf\nfrom PIL import Image\nimport matplotlib.pypl"
  },
  {
    "path": "vgg_utils/VGG16.py",
    "chars": 6486,
    "preview": "import tensorflow as tf\n\n\ndef vgg_net(x, n_classes, img_size, reuse, is_train=True, dropout_rate=0.5):\n    # Define a sc"
  },
  {
    "path": "virtual_sketch_gui.py",
    "chars": 4388,
    "preview": "import tkinter as tk\nfrom tkinter import filedialog, messagebox\nimport subprocess\nimport os\nimport threading\nimport glob"
  }
]
