[
  {
    "path": ".gitignore",
    "content": ".idea\n.idea/\ndata/\ndatas/\ndataset/\ndatasets/\nmodel/\nmodels/\ntestData/\noutput/\noutputs/\n\n*.csv\n\n# temporary files\n*.txt~\n*.pyc\n.DS_Store\n.gitignore~\n\n*.h5"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "# General Virtual Sketching Framework for Vector Line Art - SIGGRAPH 2021\n\n[[Paper]](https://esslab.jp/publications/HaoranSIGRAPH2021.pdf) | [[Project Page]](https://markmohr.github.io/virtual_sketching/) | [[中文Readme]](/README_CN.md) | [[中文论文介绍]](https://blog.csdn.net/qq_33000225/article/details/118883153)\n\nThis code is used for **line drawing vectorization**, **rough sketch simplification** and **photograph to vector line drawing**.\n\n<img src='docs/figures/muten.png' height=300><img src='docs/figures/muten-black-full-simplest.gif' height=300>\n\n<img src='docs/figures/rocket.png' height=150><img src='docs/figures/rocket-blue-simplest.gif' height=150>&nbsp;&nbsp;&nbsp;&nbsp;<img src='docs/figures/1390.png' height=150><img src='docs/figures/face-blue-1390-simplest.gif' height=150>\n\n## Outline\n- [Dependencies](#dependencies)\n- [Testing with Trained Weights](#testing-with-trained-weights)\n- [Training](#training)\n- [Citation](#citation)\n- [Projects Using this Model/Method](#projects-using-this-modelmethod)\n- [Blogs Mentioning this Paper](#blogs-mentioning-this-paper)\n- [For Windows users](#-windows-users)\n\n## Dependencies\n - [Tensorflow](https://www.tensorflow.org/) (1.12.0 <= version <=1.15.0)\n - [opencv](https://opencv.org/) == 3.4.2\n - [pillow](https://pillow.readthedocs.io/en/latest/index.html) == 6.2.0\n - [scipy](https://www.scipy.org/) == 1.5.2\n - [gizeh](https://github.com/Zulko/gizeh) == 0.1.11\n\n## Testing with Trained Weights\n### Model Preparation\n\nDownload the models [here](https://drive.google.com/drive/folders/1-hi2cl8joZ6oMOp4yvk_hObJGAK6ELHB?usp=sharing): \n  - `pretrain_clean_line_drawings` (105 MB): for vectorization\n  - `pretrain_rough_sketches` (105 MB): for rough sketch simplification\n  - `pretrain_faces` (105 MB): for photograph to line drawing\n\nThen, place them in this file structure:\n```\noutputs/\n    snapshot/\n        pretrain_clean_line_drawings/\n        pretrain_rough_sketches/\n        
pretrain_faces/\n```\n\n### Usage\nChoose the image in the `sample_inputs/` directory, and run one of the following commands for each task. The results will be under `outputs/sampling/`.\n\n``` python\npython3 test_vectorization.py --input muten.png\n\npython3 test_rough_sketch_simplification.py --input rocket.png\n\npython3 test_photograph_to_line.py --input 1390.png\n```\n\n**Note!!!** Our approach starts drawing from a randomly selected initial position, so it outputs different results in every testing trial (some might be fine and some might not be good enough). It is recommended to do several trials to select the visually best result. The number of outputs can be defined by the `--sample` argument:\n\n``` python\npython3 test_vectorization.py --input muten.png --sample 10\n\npython3 test_rough_sketch_simplification.py --input rocket.png --sample 10\n\npython3 test_photograph_to_line.py --input 1390.png --sample 10\n```\n\n**Reproducing Paper Figures:** our results (download from [here](https://drive.google.com/drive/folders/1-hi2cl8joZ6oMOp4yvk_hObJGAK6ELHB?usp=sharing)) are selected by doing a certain number of trials. Apparently, it is required to use the same initial drawing positions to reproduce our results.\n\n### Additional Tools\n\n#### a) Visualization\n\nOur vector output is stored in a `npz` package. Run the following command to obtain the rendered output and the drawing order. Results will be under the same directory of the `npz` file.\n``` python\npython3 tools/visualize_drawing.py --file path/to/the/result.npz \n```\n\n#### b) GIF Making\n\nTo see the dynamic drawing procedure, run the following command to obtain the `gif`. Result will be under the same directory of the `npz` file.\n``` python\npython3 tools/gif_making.py --file path/to/the/result.npz \n```\n\n\n#### c) Conversion to SVG\n\nOur vector output in a `npz` package is stored as Eq.(1) in the main paper. Run the following command to convert it to the `svg` format. 
Result will be under the same directory of the `npz` file.\n\n``` python\npython3 tools/svg_conversion.py --file path/to/the/result.npz \n```\n  - The conversion is implemented in two modes (by setting the `--svg_type` argument):\n    - `single` (default): each stroke (a single segment) forms a path in the SVG file\n    - `cluster`: each continuous curve (with multiple strokes) forms a path in the SVG file\n\n**Important Notes**\n\nIn SVG format, all the segments on a path share the same *stroke-width*. While in our stroke design, strokes on a common curve have different widths. Inside a stroke (a single segment), the thickness also changes linearly from one endpoint to the other. \nTherefore, neither of the two conversion methods above generates visually the same results as the ones in our paper.\n*(Please mention this issue in your paper if you do qualitative comparisons with our results in SVG format.)*\n\n\n<br>\n\n## Training\n\n### Preparations\n\nDownload the models [here](https://drive.google.com/drive/folders/1-hi2cl8joZ6oMOp4yvk_hObJGAK6ELHB?usp=sharing): \n  - `pretrain_neural_renderer` (40 MB): the pre-trained neural renderer\n  - `pretrain_perceptual_model` (691 MB): the pre-trained perceptual model for raster loss\n\nDownload the datasets [here](https://drive.google.com/drive/folders/1-hi2cl8joZ6oMOp4yvk_hObJGAK6ELHB?usp=sharing): \n  - `QuickDraw-clean` (14 MB): for clean line drawing vectorization. Taken from [QuickDraw](https://github.com/googlecreativelab/quickdraw-dataset) dataset.\n  - `QuickDraw-rough` (361 MB): for rough sketch simplification. Synthesized by the pencil drawing generation method from [Sketch Simplification](https://github.com/bobbens/sketch_simplification#pencil-drawing-generation).\n  - `CelebAMask-faces` (370 MB): for photograph to line drawing. 
Processed from the [CelebAMask-HQ](https://github.com/switchablenorms/CelebAMask-HQ) dataset.\n\nThen, place them in this file structure:\n```\ndatasets/\n    QuickDraw-clean/\n    QuickDraw-rough/\n    CelebAMask-faces/\noutputs/\n    snapshot/\n        pretrain_neural_renderer/\n        pretrain_perceptual_model/\n```\n\n### Running\n\nIt is recommended to train with multi-GPU. We train each task with 2 GPUs (each with 11 GB).\n\n``` python\npython3 train_vectorization.py\n\npython3 train_rough_photograph.py --data rough\n\npython3 train_rough_photograph.py --data face\n```\n\n<br>\n\n## Citation\n\nIf you use the code and models please cite:\n\n```\n@article{mo2021virtualsketching,\n  title   = {General Virtual Sketching Framework for Vector Line Art},\n  author  = {Mo, Haoran and Simo-Serra, Edgar and Gao, Chengying and Zou, Changqing and Wang, Ruomei},\n  journal = {ACM Transactions on Graphics (Proceedings of ACM SIGGRAPH 2021)},\n  year    = {2021},\n  volume  = {40},\n  number  = {4},\n  pages   = {51:1--51:14}\n}\n```\n\n<br>\n\n## Projects Using this Model/Method\n\n| **[Painterly style transfer](https://github.com/xch-liu/Painterly-Style-Transfer) (TVCG 2023)**  | **[Robot calligraphy](https://github.com/LoYuXr/CalliRewrite) (ICRA 2024)** | \n|:-------------:|:-------------------:|\n| <img src=\"docs/figures/applications/Painterly-Style-Transfer.png\" style=\"height: 170px\"> | <img src=\"docs/figures/applications/robot-calligraphy.png\" style=\"height: 170px\"> |\n| **[Geometrized cartoon line inbetweening](https://github.com/lisiyao21/AnimeInbet) (ICCV 2023)**  | **[Stroke correspondence and inbetweening](https://github.com/MarkMoHR/JoSTC) (TOG 2024)** | \n| <img src=\"docs/figures/applications/Geometrized-Cartoon-Line-Inbetweening.png\" style=\"height: 170px\"> | <img src=\"docs/figures/applications/Vector-Line-Inbetweening2.png\" style=\"height: 170px\"><img src=\"docs/figures/applications/Vector-Line-Inbetweening-dynamic1.gif\" style=\"height: 
170px\"> |\n| **[Modelling complex vector drawings](https://github.com/Co-do/Stroke-Cloud) (ICLR 2024)**  | **[Sketch-to-Image Generation](https://github.com/BlockDetail/Block-and-Detail) (UIST 2024)** | \n| <img src=\"docs/figures/applications/complex-vector-drawings.png\" style=\"height: 170px\"> | <img src=\"docs/figures/applications/sketch-to-image.png\" style=\"height: 170px\"> |\n\n## Blogs Mentioning this Paper\n\n- [The state of AI for hand-drawn animation inbetweening](https://yosefk.com/blog/the-state-of-ai-for-hand-drawn-animation-inbetweening.html)\n\n## 🪟 Windows users\n\nSee [WINDOWS_INSTALL_GUIDE.md](WINDOWS_INSTALL_GUIDE.md) for a complete Windows installation guide and GUI usage.\n\n"
  },
  {
    "path": "README_CN.md",
    "content": "# General Virtual Sketching Framework for Vector Line Art - SIGGRAPH 2021\r\n\r\n[[论文]](https://esslab.jp/publications/HaoranSIGRAPH2021.pdf) | [[项目主页]](https://markmohr.github.io/virtual_sketching/)\r\n\r\n这份代码能用于实现：**线稿矢量化**、**粗糙草图简化**和**自然图像到矢量草图转换**。\r\n\r\n<img src='docs/figures/muten.png' height=300><img src='docs/figures/muten-black-full-simplest.gif' height=300>\r\n\r\n<img src='docs/figures/rocket.png' height=150><img src='docs/figures/rocket-blue-simplest.gif' height=150>&nbsp;&nbsp;&nbsp;&nbsp;<img src='docs/figures/1390.png' height=150><img src='docs/figures/face-blue-1390-simplest.gif' height=150>\r\n\r\n## 目录\r\n- [环境依赖](#环境依赖)\r\n- [使用预训练模型测试](#使用预训练模型测试)\r\n- [重新训练](#重新训练)\r\n- [引用](#引用)\r\n\r\n## 环境依赖\r\n - [Tensorflow](https://www.tensorflow.org/) (1.12.0 <= 版本 <=1.15.0)\r\n - [opencv](https://opencv.org/) == 3.4.2\r\n - [pillow](https://pillow.readthedocs.io/en/latest/index.html) == 6.2.0\r\n - [scipy](https://www.scipy.org/) == 1.5.2\r\n - [gizeh](https://github.com/Zulko/gizeh) == 0.1.11\r\n\r\n## 使用预训练模型测试\r\n### 模型下载与准备\r\n\r\n在[这里](https://drive.google.com/drive/folders/1-hi2cl8joZ6oMOp4yvk_hObJGAK6ELHB?usp=sharing)下载模型：\r\n  - `pretrain_clean_line_drawings` (105 MB): 用于线稿矢量化\r\n  - `pretrain_rough_sketches` (105 MB): 用于粗糙草图简化\r\n  - `pretrain_faces` (105 MB): 用于自然图像到矢量草图转换\r\n\r\n然后，按照如下结构放置模型：\r\n```\r\noutputs/\r\n    snapshot/\r\n        pretrain_clean_line_drawings/\r\n        pretrain_rough_sketches/\r\n        pretrain_faces/\r\n```\r\n\r\n### 测试方法\r\n在`sample_inputs/`文件夹下选择图像，然后根据任务类型运行下面其中一个命令。生成结果会在`outputs/sampling/`目录下看到。\r\n\r\n``` python\r\npython3 test_vectorization.py --input muten.png\r\n\r\npython3 test_rough_sketch_simplification.py --input rocket.png\r\n\r\npython3 test_photograph_to_line.py --input 1390.png\r\n```\r\n\r\n**注意!!!** 我们的方法从一个随机挑选的初始位置启动绘制，所以每跑一次测试理论上都会得到一个不同的结果（有可能效果不错，但也可能效果不是很好）。因此，建议做多几次测试来挑选看上去最好的结果。也可以通过设置 `--sample`参数来定义跑一次测试代码同时输出（不同结果）的数量：\r\n\r\n``` python\r\npython3 
test_vectorization.py --input muten.png --sample 10\r\n\r\npython3 test_rough_sketch_simplification.py --input rocket.png --sample 10\r\n\r\npython3 test_photograph_to_line.py --input 1390.png --sample 10\r\n```\r\n\r\n**如何复现论文展示的结果？** 可以从[这里](https://drive.google.com/drive/folders/1-hi2cl8joZ6oMOp4yvk_hObJGAK6ELHB?usp=sharing)下载论文展示的结果。这些是我们通过若干次测试得到不同输出后挑选的最好的结果。显然，若要复现这些结果，需要使用相同的初始位置启动绘制。\r\n\r\n### 其他工具\r\n\r\n#### a) 可视化\r\n\r\n我们的矢量输出均使用`npz` 文件包存储。运行以下的命令可以得到渲染后的结果以及绘制顺序。可以在`npz` 文件包相同的目录下找到这些可视化结果。\r\n``` python\r\npython3 tools/visualize_drawing.py --file path/to/the/result.npz \r\n```\r\n\r\n#### b) GIF制作\r\n\r\n若要看到动态的绘制过程，可以运行以下命令来得到 `gif`。结果在`npz` 文件包相同的目录下。\r\n``` python\r\npython3 tools/gif_making.py --file path/to/the/result.npz \r\n```\r\n\r\n\r\n#### c) 转化为SVG\r\n\r\n`npz` 文件包中的矢量结果均按照论文里面的公式(1)格式存储。可以运行以下命令行，来将其转化为 `svg` 文件格式。结果在`npz` 文件包相同的目录下。\r\n\r\n``` python\r\npython3 tools/svg_conversion.py --file path/to/the/result.npz \r\n```\r\n  - 转化过程以两种模式实现（设置`--svg_type`参数）：\r\n    - `single` (默认模式): 每个笔划（一根单独的曲线）构成SVG文件中的一个path路径\r\n    - `cluster`: 每个连续曲线（多个笔划）构成SVG文件中的一个path路径\r\n\r\n**重要注意事项**\r\n\r\n在SVG文件格式中，一个path上的所有线段均只有同一个线宽（*stroke-width*）。然而在我们论文里面，定义一个连续曲线上所有的笔划可以有不同的线宽。同时，对于一个单独的笔划（贝塞尔曲线），定义其线宽从一个端点到另一个端点线性递增或者递减。\r\n\r\n因此，上述两个转化方法得到的SVG结果理论上都无法保证跟论文里面的结果在视觉上完全一致。（*假如你在论文里面使用这里转化后的SVG结果进行视觉上的对比，请提及此问题。*）\r\n\r\n\r\n<br>\r\n\r\n## 重新训练\r\n\r\n### 训练准备\r\n\r\n在[这里](https://drive.google.com/drive/folders/1-hi2cl8joZ6oMOp4yvk_hObJGAK6ELHB?usp=sharing)下载模型：\r\n  - `pretrain_neural_renderer` (40 MB): 预训练好的神经网络渲染器\r\n  - `pretrain_perceptual_model` (691 MB): 预训练好的perceptual model，用于算 raster loss\r\n\r\n在[这里](https://drive.google.com/drive/folders/1-hi2cl8joZ6oMOp4yvk_hObJGAK6ELHB?usp=sharing)下载训练数据集：\r\n  - `QuickDraw-clean` (14 MB): 用于线稿矢量化。来自 [QuickDraw](https://github.com/googlecreativelab/quickdraw-dataset)数据集。\r\n  - `QuickDraw-rough` (361 MB): 用于粗糙草图简化。利用[Sketch 
Simplification](https://github.com/bobbens/sketch_simplification#pencil-drawing-generation)里面的铅笔画图像生成方法合成。\r\n  - `CelebAMask-faces` (370 MB): 用于自然图像到矢量草图转换。使用[CelebAMask-HQ](https://github.com/switchablenorms/CelebAMask-HQ)数据集进行处理后得到。\r\n\r\n然后，按照如下结构放置数据集：\r\n```\r\ndatasets/\r\n    QuickDraw-clean/\r\n    QuickDraw-rough/\r\n    CelebAMask-faces/\r\noutputs/\r\n    snapshot/\r\n        pretrain_neural_renderer/\r\n        pretrain_perceptual_model/\r\n```\r\n\r\n### 训练方法\r\n\r\n建议使用多GPU进行训练。每个任务，我们均使用2个GPU（每个11 GB）来训练。\r\n\r\n``` python\r\npython3 train_vectorization.py\r\n\r\npython3 train_rough_photograph.py --data rough\r\n\r\npython3 train_rough_photograph.py --data face\r\n```\r\n\r\n<br>\r\n\r\n## 引用\r\n\r\n若使用此代码和模型，请引用本工作，谢谢！\r\n\r\n```\r\n@article{mo2021virtualsketching,\r\n  title   = {General Virtual Sketching Framework for Vector Line Art},\r\n  author  = {Mo, Haoran and Simo-Serra, Edgar and Gao, Chengying and Zou, Changqing and Wang, Ruomei},\r\n  journal = {ACM Transactions on Graphics (Proceedings of ACM SIGGRAPH 2021)},\r\n  year    = {2021},\r\n  volume  = {40},\r\n  number  = {4},\r\n  pages   = {51:1--51:14}\r\n}\r\n```\r\n\r\n"
  },
  {
    "path": "WINDOWS_INSTALL_GUIDE.md",
    "content": "# 🪟 Windows Installation Guide for Virtual Sketching\n\nThis guide provides step-by-step instructions to set up and run the [Virtual Sketching](https://github.com/MarkMoHR/virtual_sketching) project on Windows using Anaconda and Python 3.6.\n\n## ✅ Requirements\n\n- Windows 10 or newer\n- Anaconda installed\n- Git installed (optional, but recommended)\n\n---\n\n## 📦 Step 1: Create and Activate Conda Environment\n\n```bash\nconda create -n virtual_sketching python=3.6 -y\nconda activate virtual_sketching\n```\n\n## 📂 Step 2: Clone the Repository\n\n```bash\ncd D:\\\ngit clone https://github.com/MarkMoHR/virtual_sketching.git\ncd virtual_sketching\n```\n\n(If you plan to contribute, consider forking the repo and using your own URL.)\n\n---\n\n## 🔧 Step 3: Install Required Packages\n\n### From `conda`:\n```bash\nconda install opencv=3.4.2 pillow=6.2.0 scipy=1.5.2 -y\nconda install -c conda-forge pycairo gtk3 cffi -y\n```\n\n### Then remove default TensorFlow (if installed via conda):\n```bash\nconda remove tensorflow\n```\n\n### Install required packages via `pip`:\n```bash\npip install tensorflow==1.15.0\npip install numpy gizeh cairocffi matplotlib svgwrite\n```\n\n> ⚠️ Do not upgrade `pillow`, `scipy`, or `tensorflow` — newer versions are incompatible.\n\n---\n\n## 🛠️ Step 4: Fix Backend Compatibility\n\nIn `utils.py`, near the top, ensure the following:\n\n```python\nfrom PIL import Image\nimport matplotlib\nmatplotlib.use('TkAgg')  # force compatible backend\nimport matplotlib.pyplot as plt\n```\n\nThis ensures proper rendering with tkinter on Windows.\n\n---\n\n## 🚀 Step 5: Run a Demo\n\nFrom the project directory:\n\n```bash\npython test_vectorization.py --input sample_inputs\\muten.png --sample 5\n```\n\n> ⚠️ This script only generates `.npz` and `.png` files. 
To convert to `.svg`, see next section.\n\n---\n\n## 🖼️ Optional: Use GUI Tool (Windows Only)\n\n### Step 1: Launch GUI with provided batch file\n\nUse the `runme.bat` file to activate the conda environment and launch the Python GUI:\n\n```bat\nrem runme.bat\nset CONDAPATH=C:\\ProgramData\\anaconda3\nset ENVNAME=virtual_sketching\ncall %CONDAPATH%\\Scripts\\activate.bat %ENVNAME%\npython virtual_sketch_gui.py\npause\n```\n\n### Step 2: Select an input image and model\n\nThe GUI allows you to:\n- Choose input image (PNG, JPEG, BMP, etc.)\n- Select one of the three models\n- Automatically runs processing and converts results to SVG\n- Saves all outputs into a `sketches/` subfolder of the input image directory\n\n---\n\n## 🧠 Known Compatibility Notes\n\n- Python 3.6 and TensorFlow 1.15 are required (due to use of `tensorflow.contrib`)\n- Windows support requires manual setup of Gizeh and Cairo backends\n- GPU usage optional — TensorFlow 1.15 requires CUDA 10.0 and cuDNN 7\n\n---\n\n## 🧩 Troubleshooting Tips\n\n- ❌ `No module named '_cffi_backend'` → Run: `conda install -c conda-forge cffi`\n- ❌ `ImportError: cannot import name 'draw_svg_from_npz'` → Use `svg_conversion.py` instead\n- ❌ `.svg` looks wrong → Make sure you're using the official `svg_conversion.py` from `tools/`\n- ❌ Missing `.svg`? → Use `virtual_sketch_gui.py` or run `tools/svg_conversion.py` manually on `.npz`\n\n---\n\n## 🤝 Contributing\n\nFeel free to open issues or pull requests if you encounter bugs or want to improve the Windows support!\n\n---\n\n## 📁 Folder Structure Suggestion\n\n```\nvirtual_sketching/\n├── sample_inputs/\n├── tools/\n├── outputs/\n├── virtual_sketch_gui.py\n├── runme.bat\n├── README.md\n└── WINDOWS_INSTALL_GUIDE.md   ← You are here\n```\n\n---\n\nMade with ❤️ by the community to help Windows users get started!\n\n"
  },
  {
    "path": "dataset_utils.py",
    "content": "import os\nimport math\nimport random\nimport scipy.io\nimport numpy as np\nimport tensorflow as tf\nfrom PIL import Image\n\nfrom rasterization_utils.RealRenderer import GizehRasterizor as RealRenderer\n\n\ndef copy_hparams(hparams):\n    \"\"\"Return a copy of an HParams instance.\"\"\"\n    return tf.contrib.training.HParams(**hparams.values())\n\n\nclass GeneralRawDataLoader(object):\n    def __init__(self,\n                 image_path,\n                 raster_size,\n                 test_dataset):\n        self.image_path = image_path\n        self.raster_size = raster_size\n        self.test_dataset = test_dataset\n\n    def get_test_image(self, random_cursor=True, init_cursor_on_undrawn_pixel=False, init_cursor_num=1):\n        input_image_data, image_size_test = self.gen_input_images(self.image_path)\n        input_image_data = np.array(input_image_data,\n                                    dtype=np.float32)  # (1, image_size, image_size, (3)), [0.0-strokes, 1.0-BG]\n\n        return input_image_data, \\\n               self.gen_init_cursors(input_image_data, random_cursor, init_cursor_on_undrawn_pixel, init_cursor_num), \\\n               image_size_test\n\n    def gen_input_images(self, image_path):\n        img = Image.open(image_path).convert('RGB')\n        height, width = img.height, img.width\n        max_dim = max(height, width)\n\n        img = np.array(img, dtype=np.uint8)\n\n        if height != width:\n            # Padding to a square image\n            if self.test_dataset == 'clean_line_drawings':\n                pad_value = [255, 255, 255]\n            elif self.test_dataset == 'faces':\n                pad_value = [0, 0, 0]\n            else:\n                # TODO: find better padding pixel value\n                pad_value = img[height - 10, width - 10]\n\n            img_r, img_g, img_b = img[:, :, 0], img[:, :, 1], img[:, :, 2]\n            pad_width = max_dim - width\n            pad_height = max_dim - height\n\n      
      pad_img_r = np.pad(img_r, ((0, pad_height), (0, pad_width)), 'constant', constant_values=pad_value[0])\n            pad_img_g = np.pad(img_g, ((0, pad_height), (0, pad_width)), 'constant', constant_values=pad_value[1])\n            pad_img_b = np.pad(img_b, ((0, pad_height), (0, pad_width)), 'constant', constant_values=pad_value[2])\n            image_array = np.stack([pad_img_r, pad_img_g, pad_img_b], axis=-1)\n        else:\n            image_array = img\n\n        if self.test_dataset == 'faces' and max_dim != 256:\n            image_array_resize = Image.fromarray(image_array, 'RGB')\n            image_array_resize = image_array_resize.resize(size=(256, 256), resample=Image.BILINEAR)\n            image_array = np.array(image_array_resize, dtype=np.uint8)\n\n        assert image_array.shape[0] == image_array.shape[1]\n        img_size = image_array.shape[0]\n        image_array = image_array.astype(np.float32)\n        if self.test_dataset == 'clean_line_drawings':\n            image_array = image_array[:, :, 0] / 255.0  # [0.0-stroke, 1.0-BG]\n        else:\n            image_array = image_array / 255.0  # [0.0-stroke, 1.0-BG]\n        image_array = np.expand_dims(image_array, axis=0)\n        return image_array, img_size\n\n    def crop_patch(self, image, center, image_size, crop_size):\n        x0 = center[0] - crop_size // 2\n        x1 = x0 + crop_size\n        y0 = center[1] - crop_size // 2\n        y1 = y0 + crop_size\n        x0 = max(0, min(x0, image_size))\n        y0 = max(0, min(y0, image_size))\n        x1 = max(0, min(x1, image_size))\n        y1 = max(0, min(y1, image_size))\n        patch = image[y0:y1, x0:x1]\n        return patch\n\n    def gen_init_cursor_single(self, sketch_image, init_cursor_on_undrawn_pixel, misalign_size=3):\n        # sketch_image: [0.0-stroke, 1.0-BG]\n        image_size = sketch_image.shape[0]\n        if np.sum(1.0 - sketch_image) == 0:\n            center = np.zeros((2), dtype=np.int32)\n            return 
center\n        else:\n            while True:\n                center = np.random.randint(0, image_size, size=(2))  # (2), in large size\n                patch = 1.0 - self.crop_patch(sketch_image, center, image_size, self.raster_size)\n                if np.sum(patch) != 0:\n                    if not init_cursor_on_undrawn_pixel:\n                        return center.astype(np.float32) / float(image_size)  # (2), in size [0.0, 1.0)\n                    else:\n                        center_patch = 1.0 - self.crop_patch(sketch_image, center, image_size, misalign_size)\n                        if np.sum(center_patch) != 0:\n                            return center.astype(np.float32) / float(image_size)  # (2), in size [0.0, 1.0)\n\n    def gen_init_cursors(self, sketch_data, random_pos=True, init_cursor_on_undrawn_pixel=False, init_cursor_num=1):\n        init_cursor_batch_list = []\n        for cursor_i in range(init_cursor_num):\n            if random_pos:\n                init_cursor_batch = []\n                for i in range(len(sketch_data)):\n                    sketch_image = sketch_data[i].copy().astype(np.float32)  # [0.0-stroke, 1.0-BG]\n                    center = self.gen_init_cursor_single(sketch_image, init_cursor_on_undrawn_pixel)\n                    init_cursor_batch.append(center)\n\n                init_cursor_batch = np.stack(init_cursor_batch, axis=0)  # (N, 2)\n            else:\n                raise Exception('Not finished')\n            init_cursor_batch_list.append(init_cursor_batch)\n\n        if init_cursor_num == 1:\n            init_cursor_batch = init_cursor_batch_list[0]\n            init_cursor_batch = np.expand_dims(init_cursor_batch, axis=1).astype(np.float32)  # (N, 1, 2)\n        else:\n            init_cursor_batch = np.stack(init_cursor_batch_list, axis=1)  # (N, init_cursor_num, 2)\n            init_cursor_batch = np.expand_dims(init_cursor_batch, axis=2).astype(\n                np.float32)  # (N, init_cursor_num, 1, 
2)\n\n        return init_cursor_batch\n\n\ndef load_dataset_testing(test_data_base_dir, test_dataset, test_img_name, model_params):\n    assert test_dataset in ['clean_line_drawings', 'rough_sketches', 'faces']\n    img_path = os.path.join(test_data_base_dir, test_dataset, test_img_name)\n    print('Loaded {} from {}'.format(img_path, test_dataset))\n\n    eval_model_params = copy_hparams(model_params)\n    eval_model_params.use_input_dropout = 0\n    eval_model_params.use_recurrent_dropout = 0\n    eval_model_params.use_output_dropout = 0\n    eval_model_params.batch_size = 1\n    eval_model_params.model_mode = 'sample'\n\n    sample_model_params = copy_hparams(eval_model_params)\n    sample_model_params.batch_size = 1  # only sample one at a time\n    sample_model_params.max_seq_len = 1  # sample one point at a time\n\n    test_set = GeneralRawDataLoader(img_path, eval_model_params.raster_size, test_dataset=test_dataset)\n\n    result = [test_set, eval_model_params, sample_model_params]\n    return result\n\n\nclass GeneralMultiObjectDataLoader(object):\n    def __init__(self,\n                 stroke3_data,\n                 batch_size,\n                 raster_size,\n                 image_size_small,\n                 image_size_large,\n                 is_bin,\n                 is_train):\n        self.batch_size = batch_size  # minibatch size\n        self.raster_size = raster_size\n        self.image_size_small = image_size_small\n        self.image_size_large = image_size_large\n        self.is_bin = is_bin\n        self.is_train = is_train\n\n        self.num_batches = len(stroke3_data) // self.batch_size\n        self.batch_idx = -1\n        print('batch_size', batch_size, ', num_batches', self.num_batches)\n\n        self.rasterizor = RealRenderer()\n        self.memory_sketch_data_batch = []\n\n        assert type(stroke3_data) is list\n        self.preprocess_rand_data(stroke3_data)\n\n    def preprocess_rand_data(self, stroke3):\n        if 
self.is_train:\n            random.shuffle(stroke3)\n        self.stroke3_data = stroke3\n\n    def cal_dist(self, posA, posB):\n        return np.sqrt(np.sum(np.power(posA - posB, 2)))\n\n    def invalid_position(self, pos, obj_size, pos_list, size_list):\n        if len(pos_list) == 0:\n            return False\n\n        pos_a = pos\n        size_a = obj_size\n        for i in range(len(pos_list)):\n            pos_b = pos_list[i]\n            size_b = size_list[i]\n\n            if self.cal_dist(pos_a, pos_b) < ((size_a + size_b) // 4):\n                return True\n\n        return False\n\n    def get_object_info(self, image_size, vary_thickness=True, try_total_times=3):\n        if image_size <= 172:\n            obj_num = 1\n            obj_thickness_list = [3]\n        elif image_size <= 225:\n            obj_num = random.randint(1, 2)\n            obj_thickness_list = np.random.randint(3, 4 + 1, size=(obj_num))\n        elif image_size <= 278:\n            obj_num = 2\n            obj_thickness_list = np.random.randint(3, 4 + 1, size=(obj_num))\n        elif image_size <= 331:\n            obj_num = random.randint(2, 3)\n            while True:\n                obj_thickness_list = np.random.randint(3, 5 + 1, size=(obj_num))\n                if np.sum(obj_thickness_list) / obj_num != 5 and np.sum(obj_thickness_list) < 13:\n                    break\n        elif image_size <= 384:\n            obj_num = 3\n            while True:\n                obj_thickness_list = np.random.randint(3, 5 + 1, size=(obj_num))\n                if np.sum(obj_thickness_list) / obj_num != 5 and np.sum(obj_thickness_list) < 13:\n                    break\n        else:\n            raise Exception('Invalid image_size', image_size)\n\n        if not vary_thickness:\n            num_item = len(obj_thickness_list)\n            obj_thickness_list = [3 for _ in range(num_item)]\n\n        obj_pos_list = []\n        obj_size_list = []\n        if obj_num == 1:\n            
obj_size_list.append(image_size)\n            center = (image_size // 2, image_size // 2)\n            obj_pos_list.append(center)\n        else:\n            for obj_i in range(obj_num):\n                for try_i in range(try_total_times):\n                    obj_size = random.randint(128, image_size * 3 // 4)\n                    obj_center = np.random.randint(obj_size // 3, image_size - (obj_size // 3) + 1, size=(2))\n\n                    if not self.invalid_position(obj_center, obj_size, obj_pos_list,\n                                                 obj_size_list) or try_i == try_total_times - 1:\n                        obj_pos_list.append(obj_center)\n                        obj_size_list.append(obj_size)\n                        break\n\n        assert len(obj_size_list) == len(obj_pos_list) == len(obj_thickness_list) == obj_num\n        return obj_num, obj_size_list, obj_pos_list, obj_thickness_list\n\n    def object_pasting(self, obj_img, canvas_img, center):\n        c_y, c_x = center[0], center[1]\n        obj_size = obj_img.shape[0]\n        canvas_size = canvas_img.shape[0]\n        box_left = max(0, c_x - obj_size // 2)\n        box_right = min(canvas_size, c_x + obj_size // 2)\n        box_up = max(0, c_y - obj_size // 2)\n        box_bottom = min(canvas_size, c_y + obj_size // 2)\n\n        box_canvas = canvas_img[box_up: box_bottom, box_left: box_right]\n\n        obj_box_up = box_up - (c_y - obj_size // 2)\n        obj_box_left = box_left - (c_x - obj_size // 2)\n        box_obj = obj_img[obj_box_up: obj_box_up + (box_bottom - box_up),\n                  obj_box_left: obj_box_left + (box_right - box_left)]\n\n        box_canvas += box_obj\n\n        rst_canvas = np.copy(canvas_img)\n        rst_canvas[box_up: box_bottom, box_left: box_right] = box_canvas\n        rst_canvas = np.clip(rst_canvas, 0.0, 1.0)\n\n        return rst_canvas\n\n    def get_multi_object_image(self, img_size, vary_thickness):\n        object_num, object_size_list, 
object_pos_list, object_thickness_list = self.get_object_info(\n            img_size, vary_thickness=vary_thickness)\n\n        canvas = np.zeros(shape=(img_size, img_size), dtype=np.float32)\n\n        for obj_i in range(object_num):\n            rand_idx = np.random.randint(0, len(self.stroke3_data))\n            rand_stroke3 = self.stroke3_data[rand_idx]  # (N_points, 3)\n\n            object_size = object_size_list[obj_i]\n            object_enter = object_pos_list[obj_i]\n            object_thickness = object_thickness_list[obj_i]\n\n            stroke_image = self.gen_stroke_images([rand_stroke3], object_size, object_thickness)\n            stroke_image = 1.0 - stroke_image[0]  # (image_size, image_size), [0.0-BG, 1.0-strokes]\n\n            canvas = self.object_pasting(stroke_image, canvas, object_enter)  # [0.0-BG, 1.0-strokes]\n\n        canvas = 1.0 - canvas  # [0.0-strokes, 1.0-BG]\n        return canvas\n\n    def get_batch_from_memory(self, memory_idx, vary_thickness, fixed_image_size=-1, random_cursor=True,\n                              init_cursor_on_undrawn_pixel=False, init_cursor_num=1):\n        if len(self.memory_sketch_data_batch) >= memory_idx + 1:\n            sketch_data_batch = self.memory_sketch_data_batch[memory_idx]\n            sketch_data_batch = np.expand_dims(sketch_data_batch,\n                                               axis=0)  # (1, image_size, image_size), [0.0-strokes, 1.0-BG]\n            image_size_rand = sketch_data_batch.shape[1]\n        else:\n            if fixed_image_size == -1:\n                image_size_rand = random.randint(self.image_size_small, self.image_size_large)\n            else:\n                image_size_rand = fixed_image_size\n\n            multi_obj_image = self.get_multi_object_image(image_size_rand, vary_thickness)  # [0.0-strokes, 1.0-BG]\n            self.memory_sketch_data_batch.append(multi_obj_image)\n            sketch_data_batch = np.expand_dims(multi_obj_image,\n                          
                     axis=0)  # (1, image_size, image_size), [0.0-strokes, 1.0-BG]\n\n        return None, sketch_data_batch, \\\n               self.gen_init_cursors(sketch_data_batch, random_cursor, init_cursor_on_undrawn_pixel, init_cursor_num), \\\n               image_size_rand\n\n    def get_batch_multi_res(self, loop_num, vary_thickness, random_cursor=True,\n                            init_cursor_on_undrawn_pixel=False, init_cursor_num=1):\n        sketch_data_batch = []\n        init_cursors_batch = []\n        image_size_batch = []\n        batch_size_per_loop = self.batch_size // loop_num\n        for loop_i in range(loop_num):\n            image_size_rand = random.randint(self.image_size_small, self.image_size_large)\n            sketch_data_sub_batch = []\n            for batch_i in range(batch_size_per_loop):\n                multi_obj_image = self.get_multi_object_image(image_size_rand, vary_thickness)  # [0.0-strokes, 1.0-BG]\n                sketch_data_sub_batch.append(multi_obj_image)\n            sketch_data_sub_batch = np.stack(sketch_data_sub_batch,\n                                             axis=0)  # (N, image_size, image_size), [0.0-strokes, 1.0-BG]\n\n            init_cursors_sub_batch = self.gen_init_cursors(sketch_data_sub_batch, random_cursor,\n                                                           init_cursor_on_undrawn_pixel, init_cursor_num)\n            sketch_data_batch.append(sketch_data_sub_batch)\n            init_cursors_batch.append(init_cursors_sub_batch)\n            image_size_batch.append(image_size_rand)\n\n        return None, \\\n               sketch_data_batch, \\\n               init_cursors_batch, \\\n               image_size_batch\n\n    def gen_stroke_images(self, stroke3_list, image_size, stroke_width):\n        \"\"\"\n        :param stroke3_list: list of (batch_size,), each with (N_points, 3)\n        :param image_size:\n        :return:\n        \"\"\"\n        gt_image_array = 
self.rasterizor.raster_func(stroke3_list, image_size, stroke_width=stroke_width,\n                                                     is_bin=self.is_bin, version='v2')\n        gt_image_array = np.stack(gt_image_array, axis=0)\n        gt_image_array = 1.0 - gt_image_array  # (batch_size, image_size, image_size), [0.0-strokes, 1.0-BG]\n        return gt_image_array\n\n    def crop_patch(self, image, center, image_size, crop_size):\n        x0 = center[0] - crop_size // 2\n        x1 = x0 + crop_size\n        y0 = center[1] - crop_size // 2\n        y1 = y0 + crop_size\n        x0 = max(0, min(x0, image_size))\n        y0 = max(0, min(y0, image_size))\n        x1 = max(0, min(x1, image_size))\n        y1 = max(0, min(y1, image_size))\n        patch = image[y0:y1, x0:x1]\n        return patch\n\n    def gen_init_cursor_single(self, sketch_image, init_cursor_on_undrawn_pixel, misalign_size=3):\n        # sketch_image: [0.0-stroke, 1.0-BG]\n        image_size = sketch_image.shape[0]\n        if np.sum(1.0 - sketch_image) == 0:\n            center = np.zeros((2), dtype=np.int32)\n            return center\n        else:\n            while True:\n                center = np.random.randint(0, image_size, size=(2))  # (2), in large size\n                patch = 1.0 - self.crop_patch(sketch_image, center, image_size, self.raster_size)\n                if np.sum(patch) != 0:\n                    if not init_cursor_on_undrawn_pixel:\n                        return center.astype(np.float32) / float(image_size)  # (2), in size [0.0, 1.0)\n                    else:\n                        center_patch = 1.0 - self.crop_patch(sketch_image, center, image_size, misalign_size)\n                        if np.sum(center_patch) != 0:\n                            return center.astype(np.float32) / float(image_size)  # (2), in size [0.0, 1.0)\n\n    def gen_init_cursors(self, sketch_data, random_pos=True, init_cursor_on_undrawn_pixel=False, init_cursor_num=1):\n        
init_cursor_batch_list = []\n        for cursor_i in range(init_cursor_num):\n            if random_pos:\n                init_cursor_batch = []\n                for i in range(len(sketch_data)):\n                    sketch_image = sketch_data[i].copy().astype(np.float32)  # [0.0-stroke, 1.0-BG]\n                    center = self.gen_init_cursor_single(sketch_image, init_cursor_on_undrawn_pixel)\n                    init_cursor_batch.append(center)\n\n                init_cursor_batch = np.stack(init_cursor_batch, axis=0)  # (N, 2)\n            else:\n                raise Exception('Not finished')\n            init_cursor_batch_list.append(init_cursor_batch)\n\n        if init_cursor_num == 1:\n            init_cursor_batch = init_cursor_batch_list[0]\n            init_cursor_batch = np.expand_dims(init_cursor_batch, axis=1).astype(np.float32)  # (N, 1, 2)\n        else:\n            init_cursor_batch = np.stack(init_cursor_batch_list, axis=1)  # (N, init_cursor_num, 2)\n            init_cursor_batch = np.expand_dims(init_cursor_batch, axis=2).astype(\n                np.float32)  # (N, init_cursor_num, 1, 2)\n\n        return init_cursor_batch\n\n\ndef load_dataset_multi_object(dataset_base_dir, model_params):\n    train_stroke3_data = []\n    val_stroke3_data = []\n\n    if model_params.data_set == 'clean_line_drawings':\n        def load_qd_npz_data(npz_path):\n            data = np.load(npz_path, encoding='latin1', allow_pickle=True)\n            selected_strokes3 = data['stroke3']  # (N_sketches,), each with (N_points, 3)\n            selected_strokes3 = selected_strokes3.tolist()\n            return selected_strokes3\n\n        base_dir_clean = 'QuickDraw-clean'\n        cates = ['airplane', 'bus', 'car', 'sailboat', 'bird', 'cat', 'dog',\n                 # 'rabbit',\n                 'tree', 'flower',\n                 # 'circle', 'line',\n                 'zigzag'\n                 ]\n\n        for cate in cates:\n            
train_cate_sketch_data_npz_path = os.path.join(dataset_base_dir, base_dir_clean, 'train', cate + '.npz')\n            val_cate_sketch_data_npz_path = os.path.join(dataset_base_dir, base_dir_clean, 'test', cate + '.npz')\n            print(train_cate_sketch_data_npz_path)\n\n            train_cate_stroke3_data = load_qd_npz_data(\n                train_cate_sketch_data_npz_path)  # list of (N_sketches,), each with (N_points, 3)\n            val_cate_stroke3_data = load_qd_npz_data(val_cate_sketch_data_npz_path)\n            train_stroke3_data += train_cate_stroke3_data\n            val_stroke3_data += val_cate_stroke3_data\n    else:\n        raise Exception('Unknown data type:', model_params.data_set)\n\n    print('Loaded {}/{} from {}'.format(len(train_stroke3_data), len(val_stroke3_data), model_params.data_set))\n    print('model_params.max_seq_len %i.' % model_params.max_seq_len)\n\n    eval_sample_model_params = copy_hparams(model_params)\n    eval_sample_model_params.use_input_dropout = 0\n    eval_sample_model_params.use_recurrent_dropout = 0\n    eval_sample_model_params.use_output_dropout = 0\n    eval_sample_model_params.batch_size = 1  # only sample one at a time\n    eval_sample_model_params.model_mode = 'eval_sample'\n\n    train_set = GeneralMultiObjectDataLoader(train_stroke3_data,\n                                             model_params.batch_size, model_params.raster_size,\n                                             model_params.image_size_small, model_params.image_size_large,\n                                             model_params.bin_gt, is_train=True)\n    val_set = GeneralMultiObjectDataLoader(val_stroke3_data,\n                                           eval_sample_model_params.batch_size, eval_sample_model_params.raster_size,\n                                           eval_sample_model_params.image_size_small,\n                                           eval_sample_model_params.image_size_large,\n                                        
   eval_sample_model_params.bin_gt, is_train=False)\n\n    result = [train_set, val_set, model_params, eval_sample_model_params]\n    return result\n\n\nclass GeneralDataLoaderMultiObjectRough(object):\n    def __init__(self,\n                 photo_data,\n                 sketch_data,\n                 texture_data,\n                 shadow_data,\n                 batch_size,\n                 raster_size,\n                 image_size_small,\n                 image_size_large,\n                 is_train):\n        self.batch_size = batch_size  # minibatch size\n        self.raster_size = raster_size\n        self.image_size_small = image_size_small\n        self.image_size_large = image_size_large\n        self.is_train = is_train\n\n        assert photo_data is not None\n        assert len(photo_data) == len(sketch_data)\n        # self.num_batches = len(sketch_data) // self.batch_size\n        self.batch_idx = -1\n        print('batch_size', batch_size)\n\n        assert type(photo_data) is list\n        assert type(sketch_data) is list\n        assert type(texture_data) is list and len(texture_data) > 0\n        assert type(shadow_data) is list and len(shadow_data) > 0\n        self.photo_data = photo_data\n        self.sketch_data = sketch_data\n        self.texture_data = texture_data  # list of (H, W, 3), [0, 255], uint8\n        self.shadow_data = shadow_data  # list of (H, W), [0, 255], uint8\n\n        self.memory_photo_data_batch = []\n        self.memory_sketch_data_batch = []\n\n    def rough_augmentation(self, raw_photo, texture_prob=0.20, noise_prob=0.15, shadow_prob=0.20):\n        # raw_photo: (H, W), [0.0-stroke, 1.0-BG]\n        aug_photo_rgb = np.stack([raw_photo for _ in range(3)], axis=-1)\n\n        def texture_generation(texture_list, image_shape):\n            while True:\n                random_texture_id = random.randint(0, len(texture_list) - 1)\n                texture_large = texture_list[random_texture_id]\n                t_w, t_h = 
texture_large.shape[1], texture_large.shape[0]\n                i_w, i_h = image_shape[1], image_shape[0]\n\n                if t_h >= i_h and t_w >= i_w:\n                    texture_large = np.copy(texture_large).astype(np.float32)\n                    crop_y = random.randint(0, t_h - i_h)\n                    crop_x = random.randint(0, t_w - i_w)\n                    crop_texture = texture_large[crop_y: crop_y + i_h, crop_x: crop_x + i_w, :]\n                    return crop_texture\n\n        def texture_change(rough_img_, all_textures):\n            # rough_img_: (H, W, 3), [0.0-stroke, 1.0-BG]\n\n            texture_image = texture_generation(all_textures, rough_img_.shape)  # (h, w, 3)\n            texture_image /= 255.0\n\n            rand_b = np.random.uniform(1.0, 2.0, size=rough_img_.shape)\n            textured_img = rough_img_ * (texture_image / rand_b + (rand_b - 1.0) / rand_b)  # [0.0, 1.0]\n            return textured_img\n\n        def noise_change(rough_img_, noise_scale=25):\n            # rough_img_: (H, W, 3), [0.0, 1.0]\n            rough_img_255 = rough_img_ * 255.0\n\n            rand_noise = np.random.uniform(-1.0, 1.0, size=rough_img_255.shape) * noise_scale\n            # rand_noise = np.random.normal(size=rough_img.shape) * noise_scale\n            noise_img = rough_img_255 + rand_noise\n            noise_img = np.clip(noise_img, 0.0, 255.0)\n            noise_img /= 255.0\n            return noise_img\n\n        def shadow_change(rough_img_, all_shadows):\n            # rough_img_: (H, W, 3), [0.0, 1.0]\n            rough_img_255 = rough_img_ * 255.0\n\n            shadow_i = random.randint(0, len(all_shadows) - 1)\n            shadow_full = all_shadows[shadow_i]  # (H, W), [0, 255]\n            shadow_img_size = shadow_full.shape[0]\n\n            while True:\n                position = np.random.randint(-shadow_img_size // 2, shadow_img_size // 2, (2))\n                if abs(position[0]) > (shadow_img_size // 8) and abs(position[1]) > 
(shadow_img_size // 8):\n                    break\n            position += (shadow_img_size // 2)\n\n            crop_up = shadow_img_size - position[0]\n            crop_left = shadow_img_size - position[1]\n\n            shadow_image_large = shadow_full[crop_up: crop_up + shadow_img_size, crop_left: crop_left + shadow_img_size]\n            shadow_bg = Image.fromarray(shadow_image_large, 'L')\n            shadow_bg = shadow_bg.resize(size=(rough_img_255.shape[1], rough_img_255.shape[0]), resample=Image.BILINEAR)\n            shadow_bg = np.array(shadow_bg, dtype=np.float32) / 255.0  # [0.0-shadow, 1.0-BG]\n            shadow_bg = np.stack([shadow_bg for _ in range(3)], axis=-1)\n\n            shadow_img = rough_img_255 * shadow_bg\n            shadow_img /= 255.0\n            return shadow_img\n\n        if random.random() <= texture_prob:\n            aug_photo_rgb = texture_change(aug_photo_rgb, self.texture_data)  # (H, W, 3), [0.0, 1.0]\n        if random.random() <= noise_prob:\n            aug_photo_rgb = noise_change(aug_photo_rgb)  # (H, W, 3), [0.0, 1.0]\n        if random.random() <= shadow_prob:\n            aug_photo_rgb = shadow_change(aug_photo_rgb, self.shadow_data)  # (H, W, 3), [0.0, 1.0]\n\n        return aug_photo_rgb\n\n    def image_interpolation(self, photo, sketch, photo_prob):\n        interp_photo = photo * photo_prob + sketch * (1.0 - photo_prob)\n        interp_photo = np.clip(interp_photo, 0.0, 1.0)\n        return interp_photo\n\n    def get_batch_from_memory(self, memory_idx, interpolate_type, fixed_image_size=-1, random_cursor=True,\n                              photo_prob=1.0, init_cursor_num=1):\n        if len(self.memory_sketch_data_batch) >= memory_idx + 1:\n            photo_data_batch = self.memory_photo_data_batch[memory_idx]\n            sketch_data_batch = self.memory_sketch_data_batch[memory_idx]\n            image_size_rand = sketch_data_batch.shape[1]\n        else:\n            if fixed_image_size == -1:\n            
    image_size_rand = random.randint(self.image_size_small, self.image_size_large)\n            else:\n                image_size_rand = fixed_image_size\n\n            # photo_prob = 0.0 if photo_prob_type == 'zero' else 1.0\n            photo_data_batch, sketch_data_batch = self.select_sketch(\n                image_size_rand)  # both: (H, W), [0.0-stroke, 1.0-BG]\n            photo_data_batch = self.rough_augmentation(photo_data_batch)  # (H, W, 3), [0.0-stroke, 1.0-BG]\n\n            self.memory_photo_data_batch.append(photo_data_batch)\n            self.memory_sketch_data_batch.append(sketch_data_batch)\n\n        if interpolate_type == 'prob':\n            if random.random() >= photo_prob:\n                photo_data_batch = np.stack([sketch_data_batch for _ in range(3)],\n                                            axis=-1)  # (H, W, 3), [0.0-stroke, 1.0-BG]\n        elif interpolate_type == 'image':\n            photo_data_batch = self.image_interpolation(\n                photo_data_batch, np.stack([sketch_data_batch for _ in range(3)], axis=-1), photo_prob)\n        else:\n            raise Exception('Unknown interpolate_type', interpolate_type)\n\n        photo_data_batch = np.expand_dims(photo_data_batch, axis=0)  # (1, image_size, image_size, 3)\n        sketch_data_batch = np.expand_dims(sketch_data_batch,\n                                           axis=0)  # (1, image_size, image_size), [0.0-strokes, 1.0-BG]\n\n        return photo_data_batch, sketch_data_batch, \\\n               self.gen_init_cursors(sketch_data_batch, random_cursor, init_cursor_num), image_size_rand\n\n    def select_sketch(self, image_size_rand):\n        resolution_idx = image_size_rand - self.image_size_small\n        img_idx = random.randint(0, len(self.sketch_data[resolution_idx]) - 1)\n        assert img_idx != -1\n\n        selected_sketch = self.sketch_data[resolution_idx][img_idx]  # [0-stroke, 255-BG], uint8\n        selected_photo = 
self.photo_data[resolution_idx][img_idx]  # [0-stroke, 255-BG], uint8\n\n        rst_sketch_image = selected_sketch.astype(np.float32) / 255.0  # [0.0-stroke, 1.0-BG]\n        rst_photo_image = selected_photo.astype(np.float32) / 255.0  # [0.0-stroke, 1.0-BG]\n\n        return rst_photo_image, rst_sketch_image\n\n    def get_batch_multi_res(self, loop_num, interpolate_type, random_cursor=True, init_cursor_num=1, photo_prob=1.0):\n        photo_data_batch = []\n        sketch_data_batch = []\n        init_cursors_batch = []\n        image_size_batch = []\n        batch_size_per_loop = self.batch_size // loop_num\n        for loop_i in range(loop_num):\n            image_size_rand = random.randint(self.image_size_small, self.image_size_large)\n\n            photo_data_sub_batch = []\n            sketch_data_sub_batch = []\n            for img_i in range(batch_size_per_loop):\n                photo_patch, sketch_patch = self.select_sketch(image_size_rand)  # both: (H, W), [0.0-stroke, 1.0-BG]\n                photo_patch = self.rough_augmentation(photo_patch)  # (H, W, 3), [0.0-stroke, 1.0-BG]\n\n                if interpolate_type == 'prob':\n                    if random.random() >= photo_prob:\n                        photo_patch = np.stack([sketch_patch for _ in range(3)],\n                                               axis=-1)  # (H, W, 3), [0.0-stroke, 1.0-BG]\n                elif interpolate_type == 'image':\n                    photo_patch = self.image_interpolation(\n                        photo_patch, np.stack([sketch_patch for _ in range(3)], axis=-1), photo_prob)\n                else:\n                    raise Exception('Unknown interpolate_type', interpolate_type)\n\n                photo_data_sub_batch.append(photo_patch)\n                sketch_data_sub_batch.append(sketch_patch)\n\n            photo_data_sub_batch = np.stack(photo_data_sub_batch,\n                                            axis=0)  # (N, image_size, image_size, 3), [0.0-strokes, 
1.0-BG]\n            sketch_data_sub_batch = np.stack(sketch_data_sub_batch,\n                                             axis=0)  # (N, image_size, image_size), [0.0-strokes, 1.0-BG]\n            init_cursors_sub_batch = self.gen_init_cursors(sketch_data_sub_batch, random_cursor, init_cursor_num)\n            photo_data_batch.append(photo_data_sub_batch)\n            sketch_data_batch.append(sketch_data_sub_batch)\n            init_cursors_batch.append(init_cursors_sub_batch)\n            image_size_batch.append(image_size_rand)\n\n        return photo_data_batch, sketch_data_batch, init_cursors_batch, image_size_batch\n\n    def crop_patch(self, image, center, image_size, crop_size):\n        x0 = center[0] - crop_size // 2\n        x1 = x0 + crop_size\n        y0 = center[1] - crop_size // 2\n        y1 = y0 + crop_size\n        x0 = max(0, min(x0, image_size))\n        y0 = max(0, min(y0, image_size))\n        x1 = max(0, min(x1, image_size))\n        y1 = max(0, min(y1, image_size))\n        patch = image[y0:y1, x0:x1]\n        return patch\n\n    def gen_init_cursor_single(self, sketch_image):\n        # sketch_image: [0.0-stroke, 1.0-BG]\n        image_size = sketch_image.shape[0]\n        if np.sum(1.0 - sketch_image) == 0:\n            center = np.zeros((2), dtype=np.int32)\n            return center\n        else:\n            while True:\n                center = np.random.randint(0, image_size, size=(2))  # (2), in large size\n                patch = 1.0 - self.crop_patch(sketch_image, center, image_size, self.raster_size)\n                if np.sum(patch) != 0:\n                    return center.astype(np.float32) / float(image_size)  # (2), in size [0.0, 1.0)\n\n    def gen_init_cursors(self, sketch_data, random_pos=True, init_cursor_num=1):\n        init_cursor_batch_list = []\n        for cursor_i in range(init_cursor_num):\n            if random_pos:\n                init_cursor_batch = []\n                for i in range(len(sketch_data)):\n       
             sketch_image = sketch_data[i].copy().astype(np.float32)  # [0.0-stroke, 1.0-BG]\n                    center = self.gen_init_cursor_single(sketch_image)\n                    init_cursor_batch.append(center)\n\n                init_cursor_batch = np.stack(init_cursor_batch, axis=0)  # (N, 2)\n            else:\n                raise Exception('Not finished')\n            init_cursor_batch_list.append(init_cursor_batch)\n\n        if init_cursor_num == 1:\n            init_cursor_batch = init_cursor_batch_list[0]\n            init_cursor_batch = np.expand_dims(init_cursor_batch, axis=1).astype(np.float32)  # (N, 1, 2)\n        else:\n            init_cursor_batch = np.stack(init_cursor_batch_list, axis=1)  # (N, init_cursor_num, 2)\n            init_cursor_batch = np.expand_dims(init_cursor_batch, axis=2).astype(\n                np.float32)  # (N, init_cursor_num, 1, 2)\n\n        return init_cursor_batch\n\n\ndef load_dataset_multi_object_rough(dataset_base_dir, model_params):\n    train_photo_data = []\n    train_sketch_data = []\n    val_photo_data = []\n    val_sketch_data = []\n    texture_data = []\n    shadow_data = []\n\n    if model_params.data_set == 'rough_sketches':\n        base_dir_rough = 'QuickDraw-rough'\n\n        def load_sketch_data(mat_path):\n            sketch_data_mat = scipy.io.loadmat(mat_path)\n            sketch_data = sketch_data_mat['sketch_array']\n            sketch_data = np.array(sketch_data, dtype=np.uint8)  # (N, resolution, resolution), [0-strokes, 255-BG]\n            return sketch_data\n\n        def load_photo_data(mat_path):\n            photo_data_mat = scipy.io.loadmat(mat_path)\n            photo_data = photo_data_mat['image_array']\n            photo_data = np.array(photo_data, dtype=np.uint8)  # (N, resolution, resolution), [0-strokes, 255-BG]\n            return photo_data\n\n        def load_normal_data(img_path):\n            assert '.png' in img_path or '.jpg'\n            img = 
Image.open(img_path).convert('RGB')\n            img = np.array(img, dtype=np.uint8)  # (H, W, 3), [0-stroke, 255-BG], uint8\n            return img\n\n        ## Texture\n        texture_base = os.path.join(dataset_base_dir, base_dir_rough, 'texture')\n        all_texture = os.listdir(texture_base)\n        all_texture.sort()\n\n        for file_name in all_texture:\n            texture_path = os.path.join(texture_base, file_name)\n            texture_uint8 = load_normal_data(texture_path)\n            texture_data.append(texture_uint8)\n\n        ## shadow\n        def process_angle(img, temp_size):\n            padded_img = img.copy()\n            padded_img[0, 0:temp_size] -= 1\n            padded_img[0, -(temp_size + 1):-1] -= 1\n            padded_img[-1, 0:temp_size] -= 1\n            padded_img[-1, -(temp_size + 1):-1] -= 1\n\n            padded_img[0:temp_size, 0] -= 1\n            padded_img[0:temp_size, -1] -= 1\n            padded_img[-(temp_size + 1):-1, 0] -= 1\n            padded_img[-(temp_size + 1):-1, -1] -= 1\n            return padded_img\n\n        def pad_img(ori_img, pad_value):\n            padded_img = np.pad(ori_img, 1, constant_values=pad_value)\n            img_h, img_w = padded_img.shape[0], padded_img.shape[1]\n\n            temp_size = img_h // 3\n            padded_img = process_angle(padded_img, temp_size)\n\n            temp_size = img_h // 9\n            padded_img = process_angle(padded_img, temp_size)\n\n            temp_size = img_h // 15\n            padded_img = process_angle(padded_img, temp_size)\n\n            temp_size = img_h // 21\n            padded_img = process_angle(padded_img, temp_size)\n\n            padded_img = np.clip(padded_img, 0, 255)\n\n            return padded_img\n\n        def shadow_generation(transparency, shadow_img_size=1024):\n            deepest_value = int(255 * transparency)\n\n            center_patch = np.zeros((shadow_img_size // 2, shadow_img_size // 2), dtype=np.uint8)\n            
center_patch.fill(255)\n\n            pad_gap = shadow_img_size // 2\n            shadow_patch = center_patch.copy()\n            for i in range(pad_gap):\n                curr_pad_value = 255.0 - float(255.0 - deepest_value) / float(pad_gap) * (i + 1)\n                shadow_patch = pad_img(shadow_patch, pad_value=curr_pad_value)\n\n            for i in range(shadow_img_size // 4):\n                shadow_patch = pad_img(shadow_patch, pad_value=deepest_value)\n\n            assert shadow_patch.shape[0] == shadow_img_size * 2, shadow_patch.shape[0]\n            return shadow_patch\n\n        for transparency_ in range(90, 95 + 1):\n            transparency = transparency_ / 100.0\n            shadow_full = shadow_generation(transparency)\n            shadow_data.append(shadow_full)\n\n        splits = ['train', 'test']\n\n        resolutions = [model_params.image_size_small, model_params.image_size_large]\n\n        for resolution in range(resolutions[0], resolutions[1] + 1):\n            for split in splits:\n                sketch_mat1_path = os.path.join(dataset_base_dir, base_dir_rough, 'model_pencil1',\n                                                'sketch', split, 'res_' + str(resolution) + '.mat')\n                photo_mat1_path = os.path.join(dataset_base_dir, base_dir_rough, 'model_pencil1',\n                                               'photo', split, 'res_' + str(resolution) + '.mat')\n                sketch_data1_uint8 = load_sketch_data(\n                    sketch_mat1_path)  # (N, resolution, resolution), [0-strokes, 255-BG]\n                photo_data1_uint8 = load_photo_data(photo_mat1_path)  # (N, resolution, resolution), [0-strokes, 255-BG]\n\n                sketch_mat2_path = os.path.join(dataset_base_dir, base_dir_rough, 'model_pencil2',\n                                                'sketch', split, 'res_' + str(resolution) + '.mat')\n                photo_mat2_path = os.path.join(dataset_base_dir, base_dir_rough, 'model_pencil2',\n    
                                           'photo', split, 'res_' + str(resolution) + '.mat')\n                sketch_data2_uint8 = load_sketch_data(\n                    sketch_mat2_path)  # (N, resolution, resolution), [0-strokes, 255-BG]\n                photo_data2_uint8 = load_photo_data(photo_mat2_path)  # (N, resolution, resolution), [0-strokes, 255-BG]\n\n                sketch_data_uint8 = np.concatenate([sketch_data1_uint8, sketch_data2_uint8],\n                                                   axis=0)  # (N, resolution, resolution), [0-strokes, 255-BG]\n                photo_data_uint8 = np.concatenate([photo_data1_uint8, photo_data2_uint8],\n                                                  axis=0)  # (N, resolution, resolution), [0-strokes, 255-BG]\n\n                if split == 'train':\n                    train_photo_data.append(photo_data_uint8)\n                    train_sketch_data.append(sketch_data_uint8)\n                else:\n                    val_photo_data.append(photo_data_uint8)\n                    val_sketch_data.append(sketch_data_uint8)\n\n        assert len(train_sketch_data) == len(train_photo_data)\n        assert len(val_sketch_data) == len(val_photo_data)\n    else:\n        raise Exception('Unknown data type:', model_params.data_set)\n\n    print('Loaded {}/{} from {}'.format(len(train_sketch_data), len(val_sketch_data), model_params.data_set))\n    print('model_params.max_seq_len %i.' 
% model_params.max_seq_len)\n\n    eval_sample_model_params = copy_hparams(model_params)\n    eval_sample_model_params.use_input_dropout = 0\n    eval_sample_model_params.use_recurrent_dropout = 0\n    eval_sample_model_params.use_output_dropout = 0\n    eval_sample_model_params.batch_size = 1  # only sample one at a time\n    eval_sample_model_params.model_mode = 'eval_sample'\n\n    train_set = GeneralDataLoaderMultiObjectRough(train_photo_data, train_sketch_data,\n                                                  texture_data, shadow_data,\n                                                  model_params.batch_size, model_params.raster_size,\n                                                  model_params.image_size_small, model_params.image_size_large,\n                                                  is_train=True)\n    val_set = GeneralDataLoaderMultiObjectRough(val_photo_data, val_sketch_data,\n                                                texture_data, shadow_data,\n                                                eval_sample_model_params.batch_size,\n                                                eval_sample_model_params.raster_size,\n                                                eval_sample_model_params.image_size_small,\n                                                eval_sample_model_params.image_size_large,\n                                                is_train=False)\n\n    result = [\n        train_set, val_set, model_params, eval_sample_model_params\n    ]\n    return result\n\n\nclass GeneralDataLoaderNormalImageLinear(object):\n    def __init__(self,\n                 photo_data,\n                 sketch_data,\n                 sketch_shape,\n                 batch_size,\n                 raster_size,\n                 image_size_small,\n                 image_size_large,\n                 random_image_size,\n                 flip_prob,\n                 rotate_prob,\n                 is_train):\n        self.batch_size = batch_size  # 
minibatch size\n        self.raster_size = raster_size\n        self.image_size_small = image_size_small\n        self.image_size_large = image_size_large\n        self.random_image_size = random_image_size\n        self.is_train = is_train\n\n        assert photo_data is not None\n        assert len(photo_data) == len(sketch_data)\n        self.num_batches = len(sketch_data) // self.batch_size\n        self.batch_idx = -1\n        print('batch_size', batch_size, ', num_batches', self.num_batches)\n\n        self.flip_prob = flip_prob\n        self.rotate_prob = rotate_prob\n\n        assert type(photo_data) is list\n        assert type(sketch_data) is list\n        self.photo_data = photo_data\n        self.sketch_data = sketch_data\n        self.sketch_shape = sketch_shape\n\n    def get_batch_from_memory(self, memory_idx, interpolate_type, fixed_image_size=-1, random_cursor=True,\n                              photo_prob=1.0,\n                              init_cursor_num=1):\n        if self.random_image_size:\n            image_size_rand = fixed_image_size\n        else:\n            image_size_rand = self.image_size_large\n\n        photo_data_batch, sketch_data_batch = self.select_sketch_and_crop(\n            image_size_rand, data_idx=memory_idx, photo_prob=photo_prob,\n            interpolate_type=interpolate_type)  # sketch_patch: [0.0-stroke, 1.0-BG]\n\n        photo_data_batch = np.expand_dims(photo_data_batch, axis=0)  # (1, image_size, image_size, 3)\n        sketch_data_batch = np.expand_dims(sketch_data_batch,\n                                           axis=0)  # (1, image_size, image_size), [0.0-strokes, 1.0-BG]\n        image_size_rand = sketch_data_batch.shape[1]\n\n        return photo_data_batch, sketch_data_batch, \\\n               self.gen_init_cursors(sketch_data_batch, random_cursor, init_cursor_num), image_size_rand\n\n    def crop_and_augment(self, photo, sketch, shape, crop_size, rotate_angle, stroke_cover=0.01):\n        # img: 
[0-stroke, 255-BG], uint8\n\n        def angle_convert(angle):\n            return angle / 180.0 * math.pi\n\n        img_h, img_w = shape[0], shape[1]\n\n        if self.is_train:\n            crop_up = random.randint(0, img_h - crop_size)\n            crop_left = random.randint(0, img_w - crop_size)\n        else:\n            crop_up = (img_h - crop_size) // 2\n            crop_left = (img_w - crop_size) // 2\n\n        assert crop_up >= 0\n        assert crop_left >= 0\n\n        crop_box = (crop_left, crop_up, crop_left + crop_size, crop_up + crop_size)\n        rst_sketch_image = sketch.crop(crop_box)\n        rst_photo_image = photo.crop(crop_box)\n\n        if random.random() <= self.flip_prob and self.is_train:\n            rst_sketch_image = rst_sketch_image.transpose(Image.FLIP_LEFT_RIGHT)\n            rst_photo_image = rst_photo_image.transpose(Image.FLIP_LEFT_RIGHT)\n\n        if rotate_angle != 0 and self.is_train:\n            rst_sketch_image = rst_sketch_image.rotate(rotate_angle, resample=Image.BILINEAR)\n            rst_photo_image = rst_photo_image.rotate(rotate_angle, resample=Image.BILINEAR)\n            rst_sketch_image = np.array(rst_sketch_image, dtype=np.uint8)\n            rst_photo_image = np.array(rst_photo_image, dtype=np.uint8)\n\n            center = rst_photo_image.shape[0] // 2\n\n            new_dim = float(crop_size) / (\n                        math.sin(angle_convert(abs(rotate_angle))) + math.cos(angle_convert(abs(rotate_angle))))\n            new_dim = int(round(new_dim))\n\n            start_pos = center - new_dim // 2\n            end_pos = start_pos + new_dim\n            rst_sketch_image = rst_sketch_image[start_pos:end_pos, start_pos:end_pos, :]\n            rst_photo_image = rst_photo_image[start_pos:end_pos, start_pos:end_pos, :]\n\n        rst_sketch_image = np.array(rst_sketch_image, dtype=np.float32) / 255.0  # [0.0-stroke, 1.0-BG]\n        rst_sketch_image = rst_sketch_image[:, :, 0]\n        rst_photo_image = 
np.array(rst_photo_image, dtype=np.float32) / 255.0  # [0.0-stroke, 1.0-BG]\n\n        percentage = np.mean(1.0 - rst_sketch_image)\n        valid = True\n        if percentage < stroke_cover:\n            valid = False\n\n        return rst_photo_image, rst_sketch_image, valid\n\n    def image_interpolation(self, photo, sketch, photo_prob):\n        interp_photo = photo * photo_prob + sketch * (1.0 - photo_prob)\n        interp_photo = np.clip(interp_photo, 0.0, 1.0)\n        return interp_photo\n\n    def select_sketch_and_crop(self, image_size_rand, interpolate_type, rotate_angle=0, photo_prob=1.0,\n                               data_idx=-1, trial_times=10):\n        if self.is_train:\n            while True:\n                rand_img_idx = random.randint(0, len(self.sketch_data) - 1)\n                selected_sketch_shape = self.sketch_shape[rand_img_idx]\n                if selected_sketch_shape[0] >= image_size_rand and selected_sketch_shape[1] >= image_size_rand:\n                    img_idx = rand_img_idx\n                    break\n        else:\n            assert data_idx != -1\n            img_idx = data_idx\n\n        assert img_idx != -1\n        selected_sketch = self.sketch_data[img_idx]\n        selected_photo = self.photo_data[img_idx]\n        selected_shape = self.sketch_shape[img_idx]\n\n        assert interpolate_type in ['prob', 'image']\n\n        if interpolate_type == 'prob' and random.random() >= photo_prob:\n            selected_photo = self.sketch_data[img_idx]\n\n        for trial_i in range(trial_times):\n            cropped_photo, cropped_sketch, valid = \\\n                self.crop_and_augment(selected_photo, selected_sketch, selected_shape, image_size_rand, rotate_angle)\n            # cropped_photo, cropped_sketch: [0.0-stroke, 1.0-BG]\n\n            if valid or trial_i == trial_times - 1:\n                if interpolate_type == 'image':\n                    cropped_photo = self.image_interpolation(cropped_photo,\n               
                                              np.stack([cropped_sketch for _ in range(3)], axis=-1),\n                                                             photo_prob)\n\n                return cropped_photo, cropped_sketch\n\n    def get_batch_multi_res(self, loop_num, interpolate_type, random_cursor=True, init_cursor_num=1, photo_prob=1.0):\n        photo_data_batch = []\n        sketch_data_batch = []\n        init_cursors_batch = []\n        image_size_batch = []\n        batch_size_per_loop = self.batch_size // loop_num\n        for loop_i in range(loop_num):\n            if self.random_image_size:\n                image_size_rand = random.randint(self.image_size_small, self.image_size_large)\n            else:\n                image_size_rand = self.image_size_large\n\n            rotate_angle = 0\n            if random.random() <= self.rotate_prob:\n                rotate_angle = random.randint(-45, 45)\n\n            photo_data_sub_batch = []\n            sketch_data_sub_batch = []\n            for img_i in range(batch_size_per_loop):\n                photo_patch, sketch_patch = \\\n                    self.select_sketch_and_crop(image_size_rand, rotate_angle=rotate_angle, photo_prob=photo_prob,\n                                                interpolate_type=interpolate_type)  # sketch_patch: [0.0-stroke, 1.0-BG]\n                photo_data_sub_batch.append(photo_patch)\n                sketch_data_sub_batch.append(sketch_patch)\n\n            photo_data_sub_batch = np.stack(photo_data_sub_batch,\n                                            axis=0)  # (N, image_size, image_size, 3), [0.0-strokes, 1.0-BG]\n            sketch_data_sub_batch = np.stack(sketch_data_sub_batch,\n                                             axis=0)  # (N, image_size, image_size), [0.0-strokes, 1.0-BG]\n            init_cursors_sub_batch = self.gen_init_cursors(sketch_data_sub_batch, random_cursor, init_cursor_num)\n\n            
photo_data_batch.append(photo_data_sub_batch)\n            sketch_data_batch.append(sketch_data_sub_batch)\n            init_cursors_batch.append(init_cursors_sub_batch)\n\n            image_size_rand = photo_data_sub_batch.shape[1]\n            image_size_batch.append(image_size_rand)\n\n        return photo_data_batch, sketch_data_batch, init_cursors_batch, image_size_batch\n\n    def crop_patch(self, image, center, image_size, crop_size):\n        x0 = center[0] - crop_size // 2\n        x1 = x0 + crop_size\n        y0 = center[1] - crop_size // 2\n        y1 = y0 + crop_size\n        x0 = max(0, min(x0, image_size))\n        y0 = max(0, min(y0, image_size))\n        x1 = max(0, min(x1, image_size))\n        y1 = max(0, min(y1, image_size))\n        patch = image[y0:y1, x0:x1]\n        return patch\n\n    def gen_init_cursor_single(self, sketch_image):\n        # sketch_image: [0.0-stroke, 1.0-BG]\n        image_size = sketch_image.shape[0]\n        if np.sum(1.0 - sketch_image) == 0:\n            center = np.zeros((2), dtype=np.int32)\n            return center\n        else:\n            while True:\n                center = np.random.randint(0, image_size, size=(2))  # (2), in large size\n                patch = 1.0 - self.crop_patch(sketch_image, center, image_size, self.raster_size)\n                if np.sum(patch) != 0:\n                    return center.astype(np.float32) / float(image_size)  # (2), in size [0.0, 1.0)\n\n    def gen_init_cursors(self, sketch_data, random_pos=True, init_cursor_num=1):\n        init_cursor_batch_list = []\n        for cursor_i in range(init_cursor_num):\n            if random_pos:\n                init_cursor_batch = []\n                for i in range(len(sketch_data)):\n                    sketch_image = sketch_data[i].copy().astype(np.float32)  # [0.0-stroke, 1.0-BG]\n                    center = self.gen_init_cursor_single(sketch_image)\n                    init_cursor_batch.append(center)\n\n                
init_cursor_batch = np.stack(init_cursor_batch, axis=0)  # (N, 2)\n            else:\n                raise Exception('Not finished')\n            init_cursor_batch_list.append(init_cursor_batch)\n\n        if init_cursor_num == 1:\n            init_cursor_batch = init_cursor_batch_list[0]\n            init_cursor_batch = np.expand_dims(init_cursor_batch, axis=1).astype(np.float32)  # (N, 1, 2)\n        else:\n            init_cursor_batch = np.stack(init_cursor_batch_list, axis=1)  # (N, init_cursor_num, 2)\n            init_cursor_batch = np.expand_dims(init_cursor_batch, axis=2).astype(\n                np.float32)  # (N, init_cursor_num, 1, 2)\n\n        return init_cursor_batch\n\n\ndef load_dataset_normal_images(dataset_base_dir, model_params):\n    train_photo_data = []\n    train_sketch_data = []\n    train_data_shape = []\n    val_photo_data = []\n    val_sketch_data = []\n    val_data_shape = []\n\n    if model_params.data_set == 'faces':\n        random_training_image_size = False\n        flip_prob = -0.1\n        rotate_prob = -0.1\n\n        splits = ['train', 'val']\n\n        database = os.path.join(dataset_base_dir, 'CelebAMask-faces')\n        photo_base = os.path.join(database, 'CelebA-HQ-img256')\n        edge_base = os.path.join(database, 'CelebAMask-HQ-edge256')\n\n        train_split_txt_save_path = os.path.join(database, 'train.txt')\n        val_split_txt_save_path = os.path.join(database, 'val.txt')\n        celeba_train_txt = np.loadtxt(train_split_txt_save_path, dtype=str)\n        celeba_val_txt = np.loadtxt(val_split_txt_save_path, dtype=str)\n        splits_indices_map = {'train': celeba_train_txt, 'val': celeba_val_txt}\n\n        for split in splits:\n            split_indices = splits_indices_map[split]\n\n            for i in range(len(split_indices)):\n                file_idx = split_indices[i]\n                img_file_path = os.path.join(photo_base, str(file_idx) + '.jpg')\n                edge_img_path = 
os.path.join(edge_base, str(file_idx) + '.png')\n\n                img_data = Image.open(img_file_path).convert('RGB')\n                edge_data = Image.open(edge_img_path).convert('RGB')\n\n                if split == 'train':\n                    train_photo_data.append(img_data)\n                    train_sketch_data.append(edge_data)\n                    train_data_shape.append((img_data.height, img_data.width))\n                else:  # split == 'val'\n                    val_photo_data.append(img_data)\n                    val_sketch_data.append(edge_data)\n                    val_data_shape.append((img_data.height, img_data.width))\n\n        assert len(train_sketch_data) == len(train_data_shape) == len(train_photo_data)\n        assert len(val_sketch_data) == len(val_data_shape) == len(val_photo_data)\n    else:\n        raise Exception('Unknown data type:', model_params.data_set)\n\n    print('Loaded {}/{} from {}'.format(len(train_sketch_data), len(val_sketch_data), model_params.data_set))\n    print('model_params.max_seq_len %i.' 
% model_params.max_seq_len)\n\n    eval_sample_model_params = copy_hparams(model_params)\n    eval_sample_model_params.use_input_dropout = 0\n    eval_sample_model_params.use_recurrent_dropout = 0\n    eval_sample_model_params.use_output_dropout = 0\n    eval_sample_model_params.batch_size = 1  # only sample one at a time\n    eval_sample_model_params.model_mode = 'eval_sample'\n\n    train_set = GeneralDataLoaderNormalImageLinear(train_photo_data, train_sketch_data, train_data_shape,\n                                                   model_params.batch_size, model_params.raster_size,\n                                                   image_size_small=model_params.image_size_small,\n                                                   image_size_large=model_params.image_size_large,\n                                                   random_image_size=random_training_image_size,\n                                                   flip_prob=flip_prob, rotate_prob=rotate_prob,\n                                                   is_train=True)\n    val_set = GeneralDataLoaderNormalImageLinear(val_photo_data, val_sketch_data, val_data_shape,\n                                                 eval_sample_model_params.batch_size,\n                                                 eval_sample_model_params.raster_size,\n                                                 image_size_small=eval_sample_model_params.image_size_small,\n                                                 image_size_large=eval_sample_model_params.image_size_large,\n                                                 random_image_size=random_training_image_size,\n                                                 flip_prob=flip_prob, rotate_prob=rotate_prob,\n                                                 is_train=False)\n\n    result = [\n        train_set, val_set, model_params, eval_sample_model_params\n    ]\n    return result\n\n\ndef load_dataset_training(dataset_base_dir, model_params):\n    if 
model_params.data_set == 'clean_line_drawings':\n        return load_dataset_multi_object(dataset_base_dir, model_params)\n    elif model_params.data_set == 'rough_sketches':\n        return load_dataset_multi_object_rough(dataset_base_dir, model_params)\n    elif model_params.data_set == 'faces':\n        return load_dataset_normal_images(dataset_base_dir, model_params)\n    else:\n        raise Exception('Unknown data_set', model_params.data_set)\n"
  },
  {
    "path": "docs/assets/font.css",
    "content": "/* Homepage Font */\n\n/* latin-ext */\n@font-face {\n  font-family: 'Lato';\n  font-style: normal;\n  font-weight: 400;\n  src: local('Lato Regular'), local('Lato-Regular'), url(https://fonts.gstatic.com/s/lato/v16/S6uyw4BMUTPHjxAwXjeu.woff2) format('woff2');\n  unicode-range: U+0100-024F, U+0259, U+1E00-1EFF, U+2020, U+20A0-20AB, U+20AD-20CF, U+2113, U+2C60-2C7F, U+A720-A7FF;\n}\n\n/* latin */\n@font-face {\n  font-family: 'Lato';\n  font-style: normal;\n  font-weight: 400;\n  src: local('Lato Regular'), local('Lato-Regular'), url(https://fonts.gstatic.com/s/lato/v16/S6uyw4BMUTPHjx4wXg.woff2) format('woff2');\n  unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+2000-206F, U+2074, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD;\n}\n\n/* latin-ext */\n@font-face {\n  font-family: 'Lato';\n  font-style: normal;\n  font-weight: 700;\n  src: local('Lato Bold'), local('Lato-Bold'), url(https://fonts.gstatic.com/s/lato/v16/S6u9w4BMUTPHh6UVSwaPGR_p.woff2) format('woff2');\n  unicode-range: U+0100-024F, U+0259, U+1E00-1EFF, U+2020, U+20A0-20AB, U+20AD-20CF, U+2113, U+2C60-2C7F, U+A720-A7FF;\n}\n\n/* latin */\n@font-face {\n  font-family: 'Lato';\n  font-style: normal;\n  font-weight: 700;\n  src: local('Lato Bold'), local('Lato-Bold'), url(https://fonts.gstatic.com/s/lato/v16/S6u9w4BMUTPHh6UVSwiPGQ.woff2) format('woff2');\n  unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+2000-206F, U+2074, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD;\n}\n"
  },
  {
    "path": "docs/assets/style.css",
    "content": "/* Body */\nbody {\n  background: #e3e5e8;\n  color: #ffffff;\n  font-family: 'Lato', Verdana, Helvetica, sans-serif;\n  font-weight: 300;\n  font-size: 14pt;\n}\n\n/* Hyperlinks */\na {text-decoration: none;}\na:link {color: #1772d0;}\na:visited {color: #1772d0;}\na:active {color: red;}\na:hover {color: #f09228;}\n\n/* Pre-formatted Text */\npre {\n  margin: 5pt 0;\n  border: 0;\n  font-size: 12pt;\n  background: #fcfcfc;\n}\n\n/* Project Page Style */\n/* Section */\n.section {\n  width: 768pt;\n  min-height: 100pt;\n  margin: 15pt auto;\n  padding: 20pt 30pt;\n  border: 1pt hidden #000;\n  text-align: justify;\n  color: #000000;\n  background: #ffffff;\n}\n\n/* Header (Title and Logo) */\n.section .header {\n  min-height: 80pt;\n  margin-top: 30pt;\n}\n.section .header .logo {\n  width: 80pt;\n  margin-left: 10pt;\n  float: left;\n}\n.section .header .logo img {\n  width: 80pt;\n  object-fit: cover;\n}\n.section .header .title {\n  margin: 0 120pt;\n  text-align: center;\n  font-size: 22pt;\n}\n\n/* Author */\n.section .author {\n  margin: 5pt 0;\n  text-align: center;\n  font-size: 16pt;\n}\n\n/* Institution */\n.section .institution {\n  margin: 5pt 0;\n  text-align: center;\n  font-size: 16pt;\n}\n\n/* Hyperlink (such as Paper and Code) */\n.section .link {\n  margin: 5pt 0;\n  text-align: center;\n  font-size: 16pt;\n}\n\n/* Teaser */\n.section .teaser {\n  margin: 20pt 0;\n  text-align: left;\n}\n.section .teaser img {\n  width: 95%;\n}\n\n/* Section Title */\n.section .title {\n  text-align: center;\n  font-size: 22pt;\n  margin: 5pt 0 15pt 0;  /* top right bottom left */\n}\n\n/* Section Body */\n.section .body {\n  margin-bottom: 15pt;\n  text-align: justify;\n  font-size: 14pt;\n}\n\n/* BibTeX */\n.section .bibtex {\n  margin: 5pt 0;\n  text-align: left;\n  font-size: 22pt;\n}\n\n/* Related Work */\n.section .ref {\n  margin: 20pt 0 10pt 0;  /* top right bottom left */\n  text-align: left;\n  font-size: 18pt;\n  font-weight: 
bold;\n}\n\n/* Citation */\n.section .citation {\n  min-height: 60pt;\n  margin: 10pt 0;\n}\n.section .citation .image {\n  width: 120pt;\n  float: left;\n}\n.section .citation .image img {\n  max-height: 60pt;\n  width: 120pt;\n  object-fit: cover;\n}\n.section .citation .comment{\n  margin-left: 0pt;\n  text-align: left;\n  font-size: 14pt;\n}\n"
  },
  {
    "path": "docs/index.html",
    "content": "<!doctype html>\r\n<html lang=\"en\">\r\n\r\n\r\n<!-- === Header Starts === -->\r\n<head>\r\n  <meta http-equiv=\"Content-Type\" content=\"text/html; charset=UTF-8\">\r\n\r\n  <title>General Virtual Sketching Framework for Vector Line Art</title>\r\n\r\n  <link href=\"./assets/bootstrap.min.css\" rel=\"stylesheet\">\r\n  <link href=\"./assets/font.css\" rel=\"stylesheet\" type=\"text/css\">\r\n  <link href=\"./assets/style.css\" rel=\"stylesheet\" type=\"text/css\">\r\n</head>\r\n<!-- === Header Ends === -->\r\n\r\n\r\n<body>\r\n\r\n\r\n<!-- === Home Section Starts === -->\r\n<div class=\"section\">\r\n  <!-- === Title Starts === -->\r\n    <div class=\"title\">\r\n      <b>General Virtual Sketching Framework for Vector Line Art</b>\r\n    </div>\r\n  <!-- === Title Ends === -->\r\n  <div class=\"author\">\r\n    <a href=\"http://mo-haoran.com/\" target=\"_blank\">Haoran Mo</a><sup>1</sup>,&nbsp;\r\n    <a href=\"https://esslab.jp/~ess/en/\" target=\"_blank\">Edgar Simo-Serra</a><sup>2</sup>,&nbsp;\r\n    <a href=\"http://cse.sysu.edu.cn/content/2537\" target=\"_blank\">Chengying Gao</a><sup>*1</sup>,&nbsp;\r\n    <a href=\"https://changqingzou.weebly.com/\" target=\"_blank\">Changqing Zou</a><sup>3</sup>,&nbsp;\r\n    <a href=\"http://cse.sysu.edu.cn/content/2523\" target=\"_blank\">Ruomei Wang</a><sup>1</sup>\r\n  </div>\r\n  <div class=\"institution\">\r\n    <sup>1</sup>Sun Yat-sen University,&nbsp;\r\n    <sup>2</sup>Waseda University,&nbsp;\r\n    <br>\r\n    <sup>3</sup>Huawei Technologies Canada\r\n  </div>\r\n  <br>\r\n  <div class=\"institution\">\r\n    Accepted by <a href=\"https://s2021.siggraph.org/\" target=\"_blank\">ACM SIGGRAPH 2021</a>\r\n  </div>\r\n  <div class=\"link\">\r\n    <a href=\"https://esslab.jp/publications/HaoranSIGRAPH2021.pdf\" target=\"_blank\">[Paper]</a>&nbsp;\r\n    <a href=\"https://github.com/MarkMoHR/virtual_sketching\" target=\"_blank\">[Code]</a>\r\n  </div>\r\n  <div class=\"teaser\">\r\n    <img 
src=\"https://cdn.jsdelivr.net/gh/mark-cdn/CDN-for-works@1.4/files/SIG21/teaser6.png\" style=\"width: 100%;\">\r\n    <br>\r\n    <br>\r\n    <font size=\"3\">\r\n      Given clean line drawings, rough sketches or photographs of arbitrary resolution as input, our framework generates the corresponding vector line drawings directly. As shown in (b), the framework models a virtual pen surrounded by a dynamic window (red boxes), which moves while drawing the strokes. It learns to move around by scaling the window and sliding to an undrawn area for restarting the drawing (bottom example; sliding trajectory in blue arrow). With our proposed stroke regularization mechanism, the framework is able to enlarge the window and draw long strokes for simplicity (top example).\r\n    </font>\r\n  </div>\r\n</div>\r\n<!-- === Home Section Ends === -->\r\n\r\n\r\n<!-- === Overview Section Starts === -->\r\n<div class=\"section\">\r\n  <div class=\"title\">Abstract</div>\r\n  <div class=\"body\">\r\n    Vector line art plays an important role in graphic design, however, it is tedious to manually create.\r\n    We introduce a general framework to produce line drawings from a wide variety of images,\r\n    by learning a mapping from raster image space to vector image space.\r\n    Our approach is based on a recurrent neural network that draws the lines one by one.\r\n    A differentiable rasterization module allows for training with only supervised raster data.\r\n    We use a dynamic window around a virtual pen while drawing lines,\r\n    implemented with a proposed aligned cropping and differentiable pasting modules.\r\n    Furthermore, we develop a stroke regularization loss\r\n    that encourages the model to use fewer and longer strokes to simplify the resulting vector image.\r\n    Ablation studies and comparisons with existing methods corroborate the efficiency of our approach\r\n    which is able to generate visually better results in less computation time,\r\n    while 
generalizing better to a diversity of images and applications.\r\n  </div>\r\n  <div class=\"link\">\r\n    <a href=\"https://esslab.jp/publications/HaoranSIGRAPH2021.pdf\" target=\"_blank\">[Paper]</a>&nbsp; &nbsp;\r\n    <a href=\"https://dl.acm.org/doi/abs/10.1145/3450626.3459833\" target=\"_blank\">[Paper (ACM)]</a>&nbsp; &nbsp;\r\n    <a href=\"https://markmohr.github.io/files/SIG2021/SketchVectorization_SIG2021_supplemental.pdf\" target=\"_blank\">[Supplementary]</a>&nbsp; &nbsp;\r\n\t  <a href=\"https://github.com/MarkMoHR/virtual_sketching\" target=\"_blank\">[Code]</a>&nbsp; &nbsp;\r\n    <a href=\"https://drive.google.com/drive/folders/1-hi2cl8joZ6oMOp4yvk_hObJGAK6ELHB?usp=sharing\" target=\"_blank\">[All Precomputed Results]</a>\r\n\t  <!-- <a href=\"\" target=\"_blank\">[Presentation (TBD)]</a>&nbsp; &nbsp; -->\r\n  </div>\r\n</div>\r\n<!-- === Overview Section Ends === -->\r\n\r\n\r\n<!-- === Result Section Starts === -->\r\n<div class=\"section\">\r\n  <div class=\"title\">Method</div>\r\n  <br>\r\n  <div class=\"body\">\r\n    <p style=\"text-align:center; font-size:23px; font-weight:bold\">Framework Overview<p>\r\n    <img src=\"https://cdn.jsdelivr.net/gh/mark-cdn/CDN-for-works@1.4/files/SIG21/framework6.png\" width=\"100%\">\r\n    <br>\r\n    <br>\r\n    <font size=\"4\">\r\n      Our framework generates the parametrized strokes step by step in a recurrent manner.\r\n      It uses a dynamic window (dashed red boxes) around a virtual pen to draw the strokes,\r\n      and can both move and change the size of the window.\r\n      (a) Four main modules at each time step: aligned cropping, stroke generation, differentiable rendering and differentiable pasting.\r\n      (b) Architecture of the stroke generation module.\r\n      (c) Structural strokes predicted at each step;\r\n      movement only is illustrated by blue arrows during which no stroke is drawn on the canvas.\r\n    </font>\r\n    <br>\r\n    <br>\r\n\r\n    <p style=\"text-align:center; 
font-size:23px; font-weight:bold\">\r\n      Overall Introduction\r\n    <p>\r\n    <p style=\"text-align:center; font-size:20px\">\r\n      (Or watch on <a href=\"https://www.bilibili.com/video/BV1gM4y1V7i7/\" target=\"_blank\">Bilibili</a>)\r\n      <br>\r\n      👇\r\n    <p>\r\n    <!-- Adjust the frame size based on the demo (EVERY project differs). -->\r\n    <div style=\"position: relative; padding-top: 50%; text-align: center;\">\r\n      <iframe src=\"https://www.youtube.com/embed/gXk3TMceByY\" frameborder=0\r\n              style=\"position: absolute; top: 1%; left: 5%; width: 90%; height: 100%;\"\r\n              allow=\"accelerometer; autoplay; encrypted-media; gyroscope; picture-in-picture\"\r\n              allowfullscreen></iframe>\r\n    </div>\r\n\r\n  </div>\r\n</div>\r\n<!-- === Result Section Ends === -->\r\n\r\n<!-- === Result Section Starts === -->\r\n<div class=\"section\">\r\n  <div class=\"title\">Results</div>\r\n  <div class=\"body\">\r\n    Our framework is applicable to a diversity of image types, such as clean line drawing images, rough sketches and photographs.\r\n\r\n    <p style=\"margin-top: 10pt; text-align:center; font-size:23px; font-weight:bold\">Vectorization<p>\r\n    <table width=\"100%\" style=\"margin: 0pt auto; text-align: center; border-collapse: separate; border-spacing: 5pt;\">\r\n      <tr>\r\n        <td width=\"45%\"><img src=\"https://cdn.jsdelivr.net/gh/mark-cdn/CDN-for-works@1.4/files/SIG21/gifs/clean/muten.png\" width=\"100%\"></td>\r\n        <td width=\"10%\"></td>\r\n        <td width=\"45%\"><img src=\"https://cdn.jsdelivr.net/gh/mark-cdn/CDN-for-works@1.4/files/SIG21/gifs/clean/muten-black-full-simplest.gif\" width=\"100%\"></td>\r\n      </tr>\r\n    </table>\r\n    <br>\r\n\r\n    <p style=\"margin-top: 10pt; text-align:center; font-size:23px; font-weight:bold\">Rough sketch simplification<p>\r\n    <table width=\"100%\" style=\"margin: 0pt auto; text-align: center; border-collapse: separate; 
border-spacing: 5pt;\">\r\n      <tr>\r\n        <td width=\"26%\"><img src=\"https://cdn.jsdelivr.net/gh/mark-cdn/CDN-for-works@1.4/files/SIG21/gifs/rough/rocket.png\" width=\"100%\"></td>\r\n        <td width=\"26%\"><img src=\"https://cdn.jsdelivr.net/gh/mark-cdn/CDN-for-works@1.4/files/SIG21/gifs/rough/rocket-blue-simplest.gif\" width=\"100%\"></td>\r\n        <td width=\"4%\"></td>\r\n        <td width=\"14%\"><img src=\"https://cdn.jsdelivr.net/gh/mark-cdn/CDN-for-works@1.4/files/SIG21/gifs/rough/penguin.png\" width=\"100%\"></td>\r\n        <td width=\"14%\"><img src=\"https://cdn.jsdelivr.net/gh/mark-cdn/CDN-for-works@1.4/files/SIG21/gifs/rough/penguin-blue-simplest.gif\" width=\"100%\"></td>\r\n      </tr>\r\n    </table>\r\n    <br>\r\n\r\n    <p style=\"margin-top: 10pt; text-align:center; font-size:23px; font-weight:bold\">Photograph to line drawing<p>\r\n    <table width=\"100%\" style=\"margin: 0pt auto; text-align: center; border-collapse: separate; border-spacing: 5pt;\">\r\n      <tr>\r\n        <td width=\"23%\"><img src=\"https://cdn.jsdelivr.net/gh/mark-cdn/CDN-for-works@1.4/files/SIG21/gifs/face/1390_input.png\" width=\"100%\"></td>\r\n        <td width=\"23%\"><img src=\"https://cdn.jsdelivr.net/gh/mark-cdn/CDN-for-works@1.4/files/SIG21/gifs/face/face-blue-1390-simplest.gif\" width=\"100%\"></td>\r\n        <td width=\"8%\"></td>\r\n        <td width=\"23%\"><img src=\"https://cdn.jsdelivr.net/gh/mark-cdn/CDN-for-works@1.4/files/SIG21/gifs/face/1190_input.png\" width=\"100%\"></td>\r\n        <td width=\"23%\"><img src=\"https://cdn.jsdelivr.net/gh/mark-cdn/CDN-for-works@1.4/files/SIG21/gifs/face/face-blue-1190-simplest.gif\" width=\"100%\"></td>\r\n      </tr>\r\n    </table>\r\n    <br>\r\n\r\n    <p style=\"margin-top: 10pt; text-align:center; font-size:23px; font-weight:bold\">\r\n      More Results\r\n    <p>\r\n    <p style=\"text-align:center; font-size:20px\">\r\n      (Or watch on <a 
href=\"https://www.bilibili.com/video/BV1pv411N7Yx/\" target=\"_blank\">Bilibili</a>)\r\n      <br>\r\n      👇\r\n    <p>\r\n    <!-- Adjust the frame size based on the demo (EVERY project differs). -->\r\n    <div style=\"position: relative; padding-top: 50%; text-align: center;\">\r\n      <iframe src=\"https://www.youtube.com/embed/Pr6mK9ddXkQ\" frameborder=0\r\n              style=\"position: absolute; top: 1%; left: 5%; width: 90%; height: 100%;\"\r\n              allow=\"accelerometer; autoplay; encrypted-media; gyroscope; picture-in-picture\"\r\n              allowfullscreen></iframe>\r\n    </div>\r\n    <br>\r\n\r\n    <div class=\"link\">\r\n      <a href=\"https://drive.google.com/drive/folders/1-hi2cl8joZ6oMOp4yvk_hObJGAK6ELHB?usp=sharing\" target=\"_blank\">\r\n      [Download Our Precomputed Output Results (7MB)]</a>\r\n    </div>\r\n\r\n  </div>\r\n</div>\r\n<!-- === Result Section Ends === -->\r\n\r\n\r\n<!-- === Result Section Starts === -->\r\n<div class=\"section\">\r\n  <div class=\"title\">Presentations</div>\r\n  <div class=\"body\">\r\n\r\n    <p style=\"margin-top: 10pt; text-align:center; font-size:23px; font-weight:bold\">\r\n      3-5 minute presentation\r\n    <p>\r\n    <p style=\"text-align:center; font-size:20px\">\r\n      (Or watch on <a href=\"https://www.bilibili.com/video/BV1S3411q7VX/\" target=\"_blank\">Bilibili</a>)\r\n      <br>\r\n      👇\r\n    <p>\r\n    <!-- Adjust the frame size based on the demo (EVERY project differs). 
-->\r\n    <div style=\"position: relative; padding-top: 50%; text-align: center;\">\r\n      <iframe src=\"https://www.youtube.com/embed/BSJN1ixacts\" frameborder=0\r\n              style=\"position: absolute; top: 1%; left: 5%; width: 90%; height: 100%;\"\r\n              allow=\"accelerometer; autoplay; encrypted-media; gyroscope; picture-in-picture\"\r\n              allowfullscreen></iframe>\r\n    </div>\r\n    <br>\r\n\r\n    <div class=\"link\">\r\n      👉 15-20 minute presentation:\r\n      <a href=\"https://youtu.be/D_U4e1qh5qc\" target=\"_blank\">[YouTube]</a>\r\n      <a href=\"https://www.bilibili.com/video/BV1uU4y1E7Wg/\" target=\"_blank\">[Bilibili]</a>\r\n    </div>\r\n\r\n    <div class=\"link\">\r\n      👉 30-second fast forward:\r\n      <a href=\"https://youtu.be/d0EbSU_EeFg\" target=\"_blank\">[YouTube]</a>\r\n      <a href=\"https://www.bilibili.com/video/BV1vq4y1M7j1/\" target=\"_blank\">[Bilibili]</a>\r\n    </div>\r\n\r\n  </div>\r\n</div>\r\n<!-- === Result Section Ends === -->\r\n\r\n\r\n<!-- === Reference Section Starts === -->\r\n<div class=\"section\">\r\n  <div class=\"bibtex\">BibTeX</div>\r\n<pre>\r\n@article{mo2021virtualsketching,\r\n    title   = {General Virtual Sketching Framework for Vector Line Art},\r\n    author  = {Mo, Haoran and Simo-Serra, Edgar and Gao, Chengying and Zou, Changqing and Wang, Ruomei},\r\n    journal = {ACM Transactions on Graphics (Proceedings of ACM SIGGRAPH 2021)},\r\n    year    = {2021},\r\n    volume  = {40},\r\n    number  = {4},\r\n    pages   = {51:1--51:14}\r\n}\r\n</pre>\r\n\r\n  <br>\r\n  <div class=\"bibtex\">Related Work</div>\r\n  <div class=\"citation\">\r\n    <div class=\"comment\">\r\n      Jean-Dominique Favreau, Florent Lafarge and Adrien Bousseau.\r\n      <strong>Fidelity vs. Simplicity: a Global Approach to Line Drawing Vectorization</strong>. 
SIGGRAPH 2016.\r\n      [<a href=\"https://www-sop.inria.fr/reves/Basilic/2016/FLB16/fidelity_simplicity.pdf\">Paper</a>]\r\n      [<a href=\"https://www-sop.inria.fr/reves/Basilic/2016/FLB16/\">Webpage</a>]\r\n      <br><br>\r\n    </div>\r\n\r\n    <div class=\"comment\">\r\n      Mikhail Bessmeltsev and Justin Solomon. \r\n      <strong>Vectorization of Line Drawings via PolyVector Fields</strong>. SIGGRAPH 2019. \r\n      [<a href=\"https://arxiv.org/abs/1801.01922\">Paper</a>]\r\n      [<a href=\"https://github.com/bmpix/PolyVectorization\">Code</a>]\r\n      <br><br>\r\n    </div>\r\n\r\n    <div class=\"comment\">\r\n      Edgar Simo-Serra, Satoshi Iizuka and Hiroshi Ishikawa. \r\n      <strong>Mastering Sketching: Adversarial Augmentation for Structured Prediction</strong>. SIGGRAPH 2018. \r\n      [<a href=\"https://esslab.jp/~ess/publications/SimoSerraTOG2018.pdf\">Paper</a>]\r\n      [<a href=\"https://esslab.jp/~ess/en/research/sketch_master/\">Webpage</a>]\r\n      [<a href=\"https://github.com/bobbens/sketch_simplification\">Code</a>]\r\n      <br><br>\r\n    </div>\r\n\r\n    <div class=\"comment\">\r\n      Zhewei Huang, Wen Heng and Shuchang Zhou. \r\n      <strong>Learning to Paint With Model-based Deep Reinforcement Learning</strong>. ICCV 2019. \r\n      [<a href=\"https://openaccess.thecvf.com/content_ICCV_2019/papers/Huang_Learning_to_Paint_With_Model-Based_Deep_Reinforcement_Learning_ICCV_2019_paper.pdf\">Paper</a>]\r\n      [<a href=\"https://github.com/megvii-research/ICCV2019-LearningToPaint\">Code</a>]\r\n      <br><br>\r\n    </div>\r\n  </div>\r\n</div>\r\n<!-- === Reference Section Ends === -->\r\n\r\n\r\n</body>\r\n</html>\r\n"
  },
  {
    "path": "hyper_parameters.py",
    "content": "import tensorflow as tf\n\n\n#############################################\n# Common parameters\n#############################################\n\nFLAGS = tf.app.flags.FLAGS\n\ntf.app.flags.DEFINE_string(\n    'dataset_dir',\n    'datasets',\n    'The directory of sketch data of the dataset.')\ntf.app.flags.DEFINE_string(\n    'log_root',\n    'outputs/log',\n    'Directory to store tensorboard.')\ntf.app.flags.DEFINE_string(\n    'log_img_root',\n    'outputs/log_img',\n    'Directory to store intermediate output images.')\ntf.app.flags.DEFINE_string(\n    'snapshot_root',\n    'outputs/snapshot',\n    'Directory to store model checkpoints.')\ntf.app.flags.DEFINE_string(\n    'neural_renderer_path',\n    'outputs/snapshot/pretrain_neural_renderer/renderer_300000.tfmodel',\n    'Path to the neural renderer model.')\ntf.app.flags.DEFINE_string(\n    'perceptual_model_root',\n    'outputs/snapshot/pretrain_perceptual_model',\n    'Directory to store perceptual model.')\ntf.app.flags.DEFINE_string(\n    'data',\n    '',\n    'The dataset type.')\n\n\ndef get_default_hparams_clean():\n    \"\"\"Return default HParams for sketch-rnn.\"\"\"\n    hparams = tf.contrib.training.HParams(\n        program_name='new_train_clean_line_drawings',\n        data_set='clean_line_drawings',  # Our dataset.\n\n        input_channel=1,\n\n        num_steps=75040,  # Total number of steps of training.\n        save_every=75000,\n        eval_every=5000,\n\n        max_seq_len=48,\n        batch_size=20,\n        gpus=[0, 1],\n        loop_per_gpu=1,\n\n        sn_loss_type='increasing',  # ['decreasing', 'fixed', 'increasing']\n        stroke_num_loss_weight=0.02,\n        stroke_num_loss_weight_end=0.0,\n        increase_start_steps=25000,\n        decrease_stop_steps=40000,\n\n        perc_loss_layers=['ReLU1_2', 'ReLU2_2', 'ReLU3_3', 'ReLU5_1'],\n        perc_loss_fuse_type='add',  # ['max', 'add', 'raw_add', 'weighted_sum']\n\n        
init_cursor_on_undrawn_pixel=False,\n\n        early_pen_loss_type='move',  # ['head', 'tail', 'move']\n        early_pen_loss_weight=0.1,\n        early_pen_length=7,\n\n        min_width=0.01,\n        min_window_size=32,\n        max_scaling=2.0,\n\n        encode_cursor_type='value',\n\n        image_size_small=128,\n        image_size_large=278,\n\n        cropping_type='v3',  # ['v2', 'v3']\n        pasting_type='v3',  # ['v2', 'v3']\n        pasting_diff=True,\n\n        concat_win_size=True,\n\n        encoder_type='conv13_c3',\n        # ['conv10', 'conv10_deep', 'conv13', 'conv10_c3', 'conv10_deep_c3', 'conv13_c3']\n        # ['conv13_c3_attn']\n        # ['combine33', 'combine43', 'combine53', 'combineFC']\n        vary_thickness=False,\n\n        outside_loss_weight=10.0,\n        win_size_outside_loss_weight=10.0,\n\n        resize_method='AREA',  # ['BILINEAR', 'NEAREST_NEIGHBOR', 'BICUBIC', 'AREA']\n\n        concat_cursor=True,\n\n        use_softargmax=True,\n        soft_beta=10,  # value for the soft argmax\n\n        raster_loss_weight=1.0,\n\n        dec_rnn_size=256,  # Size of decoder.\n        dec_model='hyper',  # Decoder: lstm, layer_norm or hyper.\n        # z_size=128,  # Size of latent vector z. Recommend 32, 64 or 128.\n        bin_gt=True,\n\n        stop_accu_grad=True,\n\n        random_cursor=True,\n        cursor_type='next',\n\n        raster_size=128,\n\n        pix_drop_kp=1.0,  # Dropout keep rate\n        add_coordconv=True,\n        position_format='abs',\n        raster_loss_base_type='perceptual',  # [l1, mse, perceptual]\n\n        grad_clip=1.0,  # Gradient clipping. Recommend leaving at 1.0.\n\n        learning_rate=0.0001,  # Learning rate.\n        decay_rate=0.9999,  # Learning rate decay per minibatch.\n        decay_power=0.9,\n        min_learning_rate=0.000001,  # Minimum learning rate.\n\n        use_recurrent_dropout=True,  # Dropout with memory loss. 
Recommended\n        recurrent_dropout_prob=0.90,  # Probability of recurrent dropout keep.\n        use_input_dropout=False,  # Input dropout. Recommend leaving False.\n        input_dropout_prob=0.90,  # Probability of input dropout keep.\n        use_output_dropout=False,  # Output dropout. Recommend leaving False.\n        output_dropout_prob=0.90,  # Probability of output dropout keep.\n\n        model_mode='train'  # ['train', 'eval', 'sample']\n    )\n    return hparams\n\n\ndef get_default_hparams_rough():\n    \"\"\"Return default HParams for sketch-rnn.\"\"\"\n    hparams = tf.contrib.training.HParams(\n        program_name='new_train_rough_sketches',\n        data_set='rough_sketches',  # ['rough_sketches', 'faces']\n\n        input_channel=3,\n\n        num_steps=90040,  # Total number of steps of training.\n        save_every=90000,\n        eval_every=5000,\n\n        max_seq_len=48,\n        batch_size=20,\n        gpus=[0, 1],\n        loop_per_gpu=1,\n\n        sn_loss_type='increasing',  # ['decreasing', 'fixed', 'increasing']\n        stroke_num_loss_weight=0.1,\n        stroke_num_loss_weight_end=0.0,\n        increase_start_steps=25000,\n        decrease_stop_steps=40000,\n\n        photo_prob_type='one',  # ['increasing', 'zero', 'one']\n        photo_prob_start_step=35000,\n\n        perc_loss_layers=['ReLU2_2', 'ReLU3_3', 'ReLU5_1'],\n        perc_loss_fuse_type='add',  # ['max', 'add', 'raw_add', 'weighted_sum']\n\n        early_pen_loss_type='move',  # ['head', 'tail', 'move']\n        early_pen_loss_weight=0.2,\n        early_pen_length=7,\n\n        min_width=0.01,\n        min_window_size=32,\n        max_scaling=2.0,\n\n        encode_cursor_type='value',\n\n        image_size_small=128,\n        image_size_large=278,\n\n        cropping_type='v3',  # ['v2', 'v3']\n        pasting_type='v3',  # ['v2', 'v3']\n        pasting_diff=True,\n\n        concat_win_size=True,\n\n        encoder_type='conv13_c3',\n        # ['conv10', 
'conv10_deep', 'conv13', 'conv10_c3', 'conv10_deep_c3', 'conv13_c3']\n        # ['conv13_c3_attn']\n        # ['combine33', 'combine43', 'combine53', 'combineFC']\n\n        outside_loss_weight=10.0,\n        win_size_outside_loss_weight=10.0,\n\n        resize_method='AREA',  # ['BILINEAR', 'NEAREST_NEIGHBOR', 'BICUBIC', 'AREA']\n\n        concat_cursor=True,\n\n        use_softargmax=True,\n        soft_beta=10,  # value for the soft argmax\n\n        raster_loss_weight=1.0,\n\n        dec_rnn_size=256,  # Size of decoder.\n        dec_model='hyper',  # Decoder: lstm, layer_norm or hyper.\n        # z_size=128,  # Size of latent vector z. Recommend 32, 64 or 128.\n        bin_gt=True,\n\n        stop_accu_grad=True,\n\n        random_cursor=True,\n        cursor_type='next',\n\n        raster_size=128,\n\n        pix_drop_kp=1.0,  # Dropout keep rate\n        add_coordconv=True,\n        position_format='abs',\n        raster_loss_base_type='perceptual',  # [l1, mse, perceptual]\n\n        grad_clip=1.0,  # Gradient clipping. Recommend leaving at 1.0.\n\n        learning_rate=0.0001,  # Learning rate.\n        decay_rate=0.9999,  # Learning rate decay per minibatch.\n        decay_power=0.9,\n        min_learning_rate=0.000001,  # Minimum learning rate.\n\n        use_recurrent_dropout=True,  # Dropout with memory loss. Recommended\n        recurrent_dropout_prob=0.90,  # Probability of recurrent dropout keep.\n        use_input_dropout=False,  # Input dropout. Recommend leaving False.\n        input_dropout_prob=0.90,  # Probability of input dropout keep.\n        use_output_dropout=False,  # Output dropout. 
Recommend leaving False.\n        output_dropout_prob=0.90,  # Probability of output dropout keep.\n\n        model_mode='train'  # ['train', 'eval', 'sample']\n    )\n    return hparams\n\n\ndef get_default_hparams_normal():\n    \"\"\"Return default HParams for sketch-rnn.\"\"\"\n    hparams = tf.contrib.training.HParams(\n        program_name='new_train_faces',\n        data_set='faces',  # ['rough_sketches', 'faces']\n\n        input_channel=3,\n\n        num_steps=90040,  # Total number of steps of training.\n        save_every=90000,\n        eval_every=5000,\n\n        max_seq_len=48,\n        batch_size=20,\n        gpus=[0, 1],\n        loop_per_gpu=1,\n\n        sn_loss_type='fixed',  # ['decreasing', 'fixed', 'increasing']\n        stroke_num_loss_weight=0.0,\n        stroke_num_loss_weight_end=0.0,\n        increase_start_steps=0,\n        decrease_stop_steps=40000,\n\n        photo_prob_type='interpolate',  # ['increasing', 'zero', 'one', 'interpolate']\n        photo_prob_start_step=30000,\n        photo_prob_end_step=60000,\n\n        perc_loss_layers=['ReLU2_2', 'ReLU3_3', 'ReLU4_2', 'ReLU5_1'],\n        perc_loss_fuse_type='add',  # ['max', 'add', 'raw_add', 'weighted_sum']\n\n        early_pen_loss_type='move',  # ['head', 'tail', 'move']\n        early_pen_loss_weight=0.2,\n        early_pen_length=7,\n\n        min_width=0.01,\n        min_window_size=32,\n        max_scaling=2.0,\n\n        encode_cursor_type='value',\n\n        image_size_small=128,\n        image_size_large=256,\n\n        cropping_type='v3',  # ['v2', 'v3']\n        pasting_type='v3',  # ['v2', 'v3']\n        pasting_diff=True,\n\n        concat_win_size=True,\n\n        encoder_type='conv13_c3',\n        # ['conv10', 'conv10_deep', 'conv13', 'conv10_c3', 'conv10_deep_c3', 'conv13_c3']\n        # ['conv13_c3_attn']\n        # ['combine33', 'combine43', 'combine53', 'combineFC']\n\n        outside_loss_weight=10.0,\n        win_size_outside_loss_weight=10.0,\n\n        
resize_method='AREA',  # ['BILINEAR', 'NEAREST_NEIGHBOR', 'BICUBIC', 'AREA']\n\n        concat_cursor=True,\n\n        use_softargmax=True,\n        soft_beta=10,  # value for the soft argmax\n\n        raster_loss_weight=1.0,\n\n        dec_rnn_size=256,  # Size of decoder.\n        dec_model='hyper',  # Decoder: lstm, layer_norm or hyper.\n        # z_size=128,  # Size of latent vector z. Recommend 32, 64 or 128.\n        bin_gt=True,\n\n        stop_accu_grad=True,\n\n        random_cursor=True,\n        cursor_type='next',\n\n        raster_size=128,\n\n        pix_drop_kp=1.0,  # Dropout keep rate\n        add_coordconv=True,\n        position_format='abs',\n        raster_loss_base_type='perceptual',  # [l1, mse, perceptual]\n\n        grad_clip=1.0,  # Gradient clipping. Recommend leaving at 1.0.\n\n        learning_rate=0.0001,  # Learning rate.\n        decay_rate=0.9999,  # Learning rate decay per minibatch.\n        decay_power=0.9,\n        min_learning_rate=0.000001,  # Minimum learning rate.\n\n        use_recurrent_dropout=True,  # Dropout with memory loss. Recommended\n        recurrent_dropout_prob=0.90,  # Probability of recurrent dropout keep.\n        use_input_dropout=False,  # Input dropout. Recommend leaving False.\n        input_dropout_prob=0.90,  # Probability of input dropout keep.\n        use_output_dropout=False,  # Output dropout. Recommend leaving False.\n        output_dropout_prob=0.90,  # Probability of output dropout keep.\n\n        model_mode='train'  # ['train', 'eval', 'sample']\n    )\n    return hparams\n"
  },
  {
    "path": "launch_gui.bat",
    "content": "@echo OFF\n\nREM === Cesta k instalaci Anacondy ===\nset \"CONDAPATH=C:\\ProgramData\\anaconda3\"\n\nREM === Název a cesta k prostředí ===\nset \"ENVNAME=virtual_sketching\"\nset \"ENVPATH=%USERPROFILE%\\.conda\\envs\\%ENVNAME%\"\n\nREM === Aktivace prostředí ===\ncall \"%CONDAPATH%\\Scripts\\activate.bat\" \"%ENVPATH%\"\n\nREM === Spuštění GUI ===\npython virtual_sketch_gui.py\n\nREM === Pozastavení po ukončení ===\necho.\npause\n\nREM === Deaktivace prostředí ===\ncall conda deactivate\n"
  },
  {
    "path": "model_common_test.py",
    "content": "import rnn\nimport tensorflow as tf\n\nfrom subnet_tf_utils import generative_cnn_encoder, generative_cnn_encoder_deeper, generative_cnn_encoder_deeper13, \\\n    generative_cnn_c3_encoder, generative_cnn_c3_encoder_deeper, generative_cnn_c3_encoder_deeper13, \\\n    generative_cnn_c3_encoder_combine33, generative_cnn_c3_encoder_combine43, \\\n    generative_cnn_c3_encoder_combine53, generative_cnn_c3_encoder_combineFC, \\\n    generative_cnn_c3_encoder_deeper13_attn\n\n\nclass DiffPastingV3(object):\n    def __init__(self, raster_size):\n        self.patch_canvas = tf.placeholder(dtype=tf.float32,\n                                           shape=(None, None, 1))  # (raster_size, raster_size, 1), [0.0-BG, 1.0-stroke]\n        self.cursor_pos_a = tf.placeholder(dtype=tf.float32, shape=(2))  # (2), float32, in large size\n        self.image_size_a = tf.placeholder(dtype=tf.int32, shape=())  # ()\n        self.window_size_a = tf.placeholder(dtype=tf.float32, shape=())  # (), float32, with grad\n        self.raster_size_a = float(raster_size)\n\n        self.pasted_image = self.image_pasting_sampling_v3()\n        # (image_size, image_size, 1), [0.0-BG, 1.0-stroke]\n\n    def image_pasting_sampling_v3(self):\n        padding_size = tf.cast(tf.ceil(self.window_size_a / 2.0), tf.int32)\n\n        x1y1_a = self.cursor_pos_a - self.window_size_a / 2.0  # (2), float32\n        x2y2_a = self.cursor_pos_a + self.window_size_a / 2.0  # (2), float32\n\n        x1y1_a_floor = tf.floor(x1y1_a)  # (2)\n        x2y2_a_ceil = tf.ceil(x2y2_a)  # (2)\n\n        cursor_pos_b_oricoord = (x1y1_a_floor + x2y2_a_ceil) / 2.0  # (2)\n        cursor_pos_b = (cursor_pos_b_oricoord - x1y1_a) / self.window_size_a * self.raster_size_a  # (2)\n        raster_size_b = (x2y2_a_ceil - x1y1_a_floor)  # (x, y)\n        image_size_b = self.raster_size_a\n        window_size_b = self.raster_size_a * (raster_size_b / self.window_size_a)  # (x, y)\n\n        cursor_b_x, cursor_b_y = 
tf.split(cursor_pos_b, 2, axis=-1)  # (1)\n\n        y1_b = cursor_b_y - (window_size_b[1] - 1.) / 2.\n        x1_b = cursor_b_x - (window_size_b[0] - 1.) / 2.\n        y2_b = y1_b + (window_size_b[1] - 1.)\n        x2_b = x1_b + (window_size_b[0] - 1.)\n        boxes_b = tf.concat([y1_b, x1_b, y2_b, x2_b], axis=-1)  # (4)\n        boxes_b = boxes_b / tf.cast(image_size_b - 1, tf.float32)  # with grad to window_size_a\n\n        box_ind_b = tf.ones((1), dtype=tf.int32)  # (1)\n        box_ind_b = tf.cumsum(box_ind_b) - 1\n\n        patch_canvas = tf.expand_dims(self.patch_canvas,\n                                      axis=0)  # (1, raster_size, raster_size, 1), [0.0-BG, 1.0-stroke]\n        boxes_b = tf.expand_dims(boxes_b, axis=0)  # (1, 4)\n\n        valid_canvas = tf.image.crop_and_resize(patch_canvas, boxes_b, box_ind_b,\n                                                crop_size=[raster_size_b[1], raster_size_b[0]])\n        valid_canvas = valid_canvas[0]  # (raster_size_b, raster_size_b, 1)\n\n        pad_up = tf.cast(x1y1_a_floor[1], tf.int32) + padding_size\n        pad_down = self.image_size_a + padding_size - tf.cast(x2y2_a_ceil[1], tf.int32)\n        pad_left = tf.cast(x1y1_a_floor[0], tf.int32) + padding_size\n        pad_right = self.image_size_a + padding_size - tf.cast(x2y2_a_ceil[0], tf.int32)\n\n        paddings = [[pad_up, pad_down],\n                    [pad_left, pad_right],\n                    [0, 0]]\n        pad_img = tf.pad(valid_canvas, paddings=paddings, mode='CONSTANT',\n                         constant_values=0.0)  # (H_p, W_p, 1), [0.0-BG, 1.0-stroke]\n\n        pasted_image = pad_img[padding_size: padding_size + self.image_size_a,\n                       padding_size: padding_size + self.image_size_a, :]\n        # (image_size, image_size, 1), [0.0-BG, 1.0-stroke]\n        return pasted_image\n\n\nclass VirtualSketchingModel(object):\n    def __init__(self, hps, gpu_mode=True, reuse=False):\n        \"\"\"Initializer for the 
model.\n\n    Args:\n       hps: a HParams object containing model hyperparameters\n       gpu_mode: a boolean that when True, uses GPU mode.\n       reuse: a boolean that when true, attempts to reuse variables.\n    \"\"\"\n        self.hps = hps\n        assert hps.model_mode in ['train', 'eval', 'eval_sample', 'sample']\n        # with tf.variable_scope('SCC', reuse=reuse):\n        if not gpu_mode:\n            with tf.device('/cpu:0'):\n                print('Model using cpu.')\n                self.build_model()\n        else:\n            print('-' * 100)\n            print('model_mode:', hps.model_mode)\n            print('Model using gpu.')\n            self.build_model()\n\n    def build_model(self):\n        \"\"\"Define model architecture.\"\"\"\n        self.config_model()\n\n        initial_state = self.get_decoder_inputs()\n        self.initial_state = initial_state\n\n        ## use pred as the prev points\n        other_params, pen_ras, final_state = self.get_points_and_raster_image(self.image_size)\n        # other_params: (N * max_seq_len, 6)\n        # pen_ras: (N * max_seq_len, 2), after softmax\n\n        self.other_params = other_params  # (N * max_seq_len, 6)\n        self.pen_ras = pen_ras  # (N * max_seq_len, 2), after softmax\n        self.final_state = final_state\n\n        if not self.hps.use_softargmax:\n            pen_state_soft = pen_ras[:, 1:2]  # (N * max_seq_len, 1)\n        else:\n            pen_state_soft = self.differentiable_argmax(pen_ras, self.hps.soft_beta)  # (N * max_seq_len, 1)\n\n        pred_params = tf.concat([pen_state_soft, other_params], axis=1)  # (N * max_seq_len, 7)\n        self.pred_params = tf.reshape(pred_params, shape=[-1, self.hps.max_seq_len, 7])  # (N, max_seq_len, 7)\n        # pred_params: (N, max_seq_len, 7)\n\n    def config_model(self):\n        if self.hps.model_mode == 'train':\n            self.global_step = tf.Variable(0, name='global_step', trainable=False)\n\n        if self.hps.dec_model == 
'lstm':\n            dec_cell_fn = rnn.LSTMCell\n        elif self.hps.dec_model == 'layer_norm':\n            dec_cell_fn = rnn.LayerNormLSTMCell\n        elif self.hps.dec_model == 'hyper':\n            dec_cell_fn = rnn.HyperLSTMCell\n        else:\n            assert False, 'please choose a respectable cell'\n\n        use_recurrent_dropout = self.hps.use_recurrent_dropout\n        use_input_dropout = self.hps.use_input_dropout\n        use_output_dropout = self.hps.use_output_dropout\n\n        dec_cell = dec_cell_fn(\n            self.hps.dec_rnn_size,\n            use_recurrent_dropout=use_recurrent_dropout,\n            dropout_keep_prob=self.hps.recurrent_dropout_prob)\n\n        # dropout:\n        # print('Input dropout mode = %s.' % use_input_dropout)\n        # print('Output dropout mode = %s.' % use_output_dropout)\n        # print('Recurrent dropout mode = %s.' % use_recurrent_dropout)\n        if use_input_dropout:\n            print('Dropout to input w/ keep_prob = %4.4f.' % self.hps.input_dropout_prob)\n            dec_cell = tf.contrib.rnn.DropoutWrapper(\n                dec_cell, input_keep_prob=self.hps.input_dropout_prob)\n        if use_output_dropout:\n            print('Dropout to output w/ keep_prob = %4.4f.' 
% self.hps.output_dropout_prob)\n            dec_cell = tf.contrib.rnn.DropoutWrapper(\n                dec_cell, output_keep_prob=self.hps.output_dropout_prob)\n        self.dec_cell = dec_cell\n\n        self.input_photo = tf.placeholder(dtype=tf.float32,\n                                          shape=[self.hps.batch_size, None, None, self.hps.input_channel])  # [0.0-stroke, 1.0-BG]\n        self.init_cursor = tf.placeholder(\n            dtype=tf.float32,\n            shape=[self.hps.batch_size, 1, 2])  # (N, 1, 2), in size [0.0, 1.0)\n        self.init_width = tf.placeholder(\n            dtype=tf.float32,\n            shape=[self.hps.batch_size])  # (1), in [0.0, 1.0]\n        self.init_scaling = tf.placeholder(\n            dtype=tf.float32,\n            shape=[self.hps.batch_size])  # (N), in [0.0, 1.0]\n        self.init_window_size = tf.placeholder(\n            dtype=tf.float32,\n            shape=[self.hps.batch_size])  # (N)\n        self.image_size = tf.placeholder(dtype=tf.int32, shape=())  # ()\n\n    ###########################\n\n    def normalize_image_m1to1(self, in_img_0to1):\n        norm_img_m1to1 = tf.multiply(in_img_0to1, 2.0)\n        norm_img_m1to1 = tf.subtract(norm_img_m1to1, 1.0)\n        return norm_img_m1to1\n\n    def add_coords(self, input_tensor):\n        batch_size_tensor = tf.shape(input_tensor)[0]  # get N size\n\n        xx_ones = tf.ones([batch_size_tensor, self.hps.raster_size], dtype=tf.int32)  # e.g. (N, raster_size)\n        xx_ones = tf.expand_dims(xx_ones, -1)  # e.g. (N, raster_size, 1)\n        xx_range = tf.tile(tf.expand_dims(tf.range(self.hps.raster_size), 0),\n                           [batch_size_tensor, 1])  # e.g. (N, raster_size)\n        xx_range = tf.expand_dims(xx_range, 1)  # e.g. (N, 1, raster_size)\n\n        xx_channel = tf.matmul(xx_ones, xx_range)  # e.g. (N, raster_size, raster_size)\n        xx_channel = tf.expand_dims(xx_channel, -1)  # e.g. 
(N, raster_size, raster_size, 1)\n\n        yy_ones = tf.ones([batch_size_tensor, self.hps.raster_size], dtype=tf.int32)  # e.g. (N, raster_size)\n        yy_ones = tf.expand_dims(yy_ones, 1)  # e.g. (N, 1, raster_size)\n        yy_range = tf.tile(tf.expand_dims(tf.range(self.hps.raster_size), 0),\n                           [batch_size_tensor, 1])  # (N, raster_size)\n        yy_range = tf.expand_dims(yy_range, -1)  # e.g. (N, raster_size, 1)\n\n        yy_channel = tf.matmul(yy_range, yy_ones)  # e.g. (N, raster_size, raster_size)\n        yy_channel = tf.expand_dims(yy_channel, -1)  # e.g. (N, raster_size, raster_size, 1)\n\n        xx_channel = tf.cast(xx_channel, 'float32') / (self.hps.raster_size - 1)\n        yy_channel = tf.cast(yy_channel, 'float32') / (self.hps.raster_size - 1)\n        # xx_channel = xx_channel * 2 - 1  # [-1, 1]\n        # yy_channel = yy_channel * 2 - 1\n\n        ret = tf.concat([\n            input_tensor,\n            xx_channel,\n            yy_channel,\n        ], axis=-1)  # e.g. 
(N, raster_size, raster_size, 4)\n\n        return ret\n\n    def build_combined_encoder(self, patch_canvas, patch_photo, entire_canvas, entire_photo, cursor_pos,\n                               image_size, window_size):\n        \"\"\"\n        :param patch_canvas: (N, raster_size, raster_size, 1), [-1.0-stroke, 1.0-BG]\n        :param patch_photo: (N, raster_size, raster_size, 1/3), [-1.0-stroke, 1.0-BG]\n        :param entire_canvas: (N, image_size, image_size, 1), [0.0-stroke, 1.0-BG]\n        :param entire_photo: (N, image_size, image_size, 1/3), [0.0-stroke, 1.0-BG]\n        :param cursor_pos: (N, 1, 2), in size [0.0, 1.0)\n        :param window_size: (N, 1, 1), float, in large size\n        :return:\n        \"\"\"\n        if self.hps.resize_method == 'BILINEAR':\n            resize_method = tf.image.ResizeMethod.BILINEAR\n        elif self.hps.resize_method == 'NEAREST_NEIGHBOR':\n            resize_method = tf.image.ResizeMethod.NEAREST_NEIGHBOR\n        elif self.hps.resize_method == 'BICUBIC':\n            resize_method = tf.image.ResizeMethod.BICUBIC\n        elif self.hps.resize_method == 'AREA':\n            resize_method = tf.image.ResizeMethod.AREA\n        else:\n            raise Exception('unknown resize_method', self.hps.resize_method)\n\n        patch_photo = tf.stop_gradient(patch_photo)\n        patch_canvas = tf.stop_gradient(patch_canvas)\n        cursor_pos = tf.stop_gradient(cursor_pos)\n        window_size = tf.stop_gradient(window_size)\n\n        entire_photo_small = tf.stop_gradient(tf.image.resize_images(entire_photo,\n                                                                      (self.hps.raster_size, self.hps.raster_size),\n                                                                      method=resize_method))\n        entire_canvas_small = tf.stop_gradient(tf.image.resize_images(entire_canvas,\n                                                                      (self.hps.raster_size, self.hps.raster_size),\n        
                                                              method=resize_method))\n        entire_photo_small = self.normalize_image_m1to1(entire_photo_small)  # [-1.0-stroke, 1.0-BG]\n        entire_canvas_small = self.normalize_image_m1to1(entire_canvas_small)  # [-1.0-stroke, 1.0-BG]\n\n        if self.hps.encode_cursor_type == 'value':\n            cursor_pos_norm = tf.expand_dims(cursor_pos, axis=1)  # (N, 1, 1, 2)\n            cursor_pos_norm = tf.tile(cursor_pos_norm, [1, self.hps.raster_size, self.hps.raster_size, 1])\n            cursor_info = cursor_pos_norm\n        else:\n            raise Exception('Unknown encode_cursor_type', self.hps.encode_cursor_type)\n\n        batch_input_combined = tf.concat([patch_photo, patch_canvas, entire_photo_small, entire_canvas_small, cursor_info],\n                                axis=-1)  # [N, raster_size, raster_size, 6/10]\n        batch_input_local = tf.concat([patch_photo, patch_canvas], axis=-1)  # [N, raster_size, raster_size, 2/4]\n        batch_input_global = tf.concat([entire_photo_small, entire_canvas_small, cursor_info],\n                                       axis=-1)  # [N, raster_size, raster_size, 4/6]\n\n        if self.hps.model_mode == 'train':\n            is_training = True\n            dropout_keep_prob = self.hps.pix_drop_kp\n        else:\n            is_training = False\n            dropout_keep_prob = 1.0\n\n        if self.hps.add_coordconv:\n            batch_input_combined = self.add_coords(batch_input_combined)  # (N, in_H, in_W, in_dim + 2)\n            batch_input_local = self.add_coords(batch_input_local)  # (N, in_H, in_W, in_dim + 2)\n            batch_input_global = self.add_coords(batch_input_global)  # (N, in_H, in_W, in_dim + 2)\n\n        if 'combine' in self.hps.encoder_type:\n            if self.hps.encoder_type == 'combine33':\n                image_embedding, _ = generative_cnn_c3_encoder_combine33(batch_input_local, batch_input_global,\n                                   
                                      is_training, dropout_keep_prob)  # (N, 128)\n            elif self.hps.encoder_type == 'combine43':\n                image_embedding, _ = generative_cnn_c3_encoder_combine43(batch_input_local, batch_input_global,\n                                                                         is_training, dropout_keep_prob)  # (N, 128)\n            elif self.hps.encoder_type == 'combine53':\n                image_embedding, _ = generative_cnn_c3_encoder_combine53(batch_input_local, batch_input_global,\n                                                                         is_training, dropout_keep_prob)  # (N, 128)\n            elif self.hps.encoder_type == 'combineFC':\n                image_embedding, _ = generative_cnn_c3_encoder_combineFC(batch_input_local, batch_input_global,\n                                                                         is_training, dropout_keep_prob)  # (N, 256)\n            else:\n                raise Exception('Unknown encoder_type', self.hps.encoder_type)\n        else:\n            with tf.variable_scope('Combined_Encoder', reuse=tf.AUTO_REUSE):\n                if self.hps.encoder_type == 'conv10':\n                    image_embedding, _ = generative_cnn_encoder(batch_input_combined, is_training, dropout_keep_prob)  # (N, 128)\n                elif self.hps.encoder_type == 'conv10_deep':\n                    image_embedding, _ = generative_cnn_encoder_deeper(batch_input_combined, is_training, dropout_keep_prob)  # (N, 512)\n                elif self.hps.encoder_type == 'conv13':\n                    image_embedding, _ = generative_cnn_encoder_deeper13(batch_input_combined, is_training, dropout_keep_prob)  # (N, 128)\n                elif self.hps.encoder_type == 'conv10_c3':\n                    image_embedding, _ = generative_cnn_c3_encoder(batch_input_combined, is_training, dropout_keep_prob)  # (N, 128)\n                elif self.hps.encoder_type == 'conv10_deep_c3':\n                    
image_embedding, _ = generative_cnn_c3_encoder_deeper(batch_input_combined, is_training, dropout_keep_prob)  # (N, 512)\n                elif self.hps.encoder_type == 'conv13_c3':\n                    image_embedding, _ = generative_cnn_c3_encoder_deeper13(batch_input_combined, is_training, dropout_keep_prob)  # (N, 128)\n                elif self.hps.encoder_type == 'conv13_c3_attn':\n                    image_embedding, _ = generative_cnn_c3_encoder_deeper13_attn(batch_input_combined, is_training, dropout_keep_prob)  # (N, 128)\n                else:\n                    raise Exception('Unknown encoder_type', self.hps.encoder_type)\n        return image_embedding\n\n    def build_seq_decoder(self, dec_cell, actual_input_x, initial_state):\n        rnn_output, last_state = self.rnn_decoder(dec_cell, initial_state, actual_input_x)\n        rnn_output_flat = tf.reshape(rnn_output, [-1, self.hps.dec_rnn_size])\n\n        pen_n_out = 2\n        params_n_out = 6\n\n        with tf.variable_scope('DEC_RNN_out_pen', reuse=tf.AUTO_REUSE):\n            output_w_pen = tf.get_variable('output_w', [self.hps.dec_rnn_size, pen_n_out])\n            output_b_pen = tf.get_variable('output_b', [pen_n_out], initializer=tf.constant_initializer(0.0))\n            output_pen = tf.nn.xw_plus_b(rnn_output_flat, output_w_pen, output_b_pen)  # (N, pen_n_out)\n\n        with tf.variable_scope('DEC_RNN_out_params', reuse=tf.AUTO_REUSE):\n            output_w_params = tf.get_variable('output_w', [self.hps.dec_rnn_size, params_n_out])\n            output_b_params = tf.get_variable('output_b', [params_n_out], initializer=tf.constant_initializer(0.0))\n            output_params = tf.nn.xw_plus_b(rnn_output_flat, output_w_params, output_b_params)  # (N, params_n_out)\n\n        output = tf.concat([output_pen, output_params], axis=1)  # (N, n_out)\n\n        return output, last_state\n\n    def get_mixture_coef(self, outputs):\n        z = outputs\n        z_pen_logits = z[:, 0:2]  # (N, 2), pen 
states\n        z_other_params_logits = z[:, 2:]  # (N, 6)\n\n        z_pen = tf.nn.softmax(z_pen_logits)  # (N, 2)\n        if self.hps.position_format == 'abs':\n            x1y1 = tf.nn.sigmoid(z_other_params_logits[:, 0:2])  # (N, 2)\n            x2y2 = tf.tanh(z_other_params_logits[:, 2:4])  # (N, 2)\n            widths = tf.nn.sigmoid(z_other_params_logits[:, 4:5])  # (N, 1)\n            widths = tf.add(tf.multiply(widths, 1.0 - self.hps.min_width), self.hps.min_width)\n            scaling = tf.nn.sigmoid(z_other_params_logits[:, 5:6]) * self.hps.max_scaling  # (N, 1), [0.0, max_scaling]\n            # scaling = tf.add(tf.multiply(scaling, (self.hps.max_scaling - self.hps.min_scaling) / self.hps.max_scaling),\n            #                  self.hps.min_scaling)\n            z_other_params = tf.concat([x1y1, x2y2, widths, scaling], axis=-1)  # (N, 6)\n        else:  # \"rel\"\n            raise Exception('Unknown position_format', self.hps.position_format)\n\n        r = [z_other_params, z_pen]\n        return r\n\n    ###########################\n\n    def get_decoder_inputs(self):\n        initial_state = self.dec_cell.zero_state(batch_size=self.hps.batch_size, dtype=tf.float32)\n        return initial_state\n\n    def rnn_decoder(self, dec_cell, initial_state, actual_input_x):\n        with tf.variable_scope(\"RNN_DEC\", reuse=tf.AUTO_REUSE):\n            output, last_state = tf.nn.dynamic_rnn(\n                dec_cell,\n                actual_input_x,\n                initial_state=initial_state,\n                time_major=False,\n                swap_memory=True,\n                dtype=tf.float32)\n        return output, last_state\n\n    ###########################\n\n    def image_padding(self, ori_image, window_size, pad_value):\n        \"\"\"\n        Pad with (bg)\n        :param ori_image:\n        :return:\n        \"\"\"\n        paddings = [[0, 0],\n                    [window_size // 2, window_size // 2],\n                    [window_size // 
2, window_size // 2],\n                    [0, 0]]\n        pad_img = tf.pad(ori_image, paddings=paddings, mode='CONSTANT', constant_values=pad_value)  # (N, H_p, W_p, k)\n        return pad_img\n\n    def image_cropping_fn(self, fn_inputs):\n        \"\"\"\n        crop the patch\n        :return:\n        \"\"\"\n        index_offset = self.hps.input_channel - 1\n        input_image = fn_inputs[:, :, 0:2 + index_offset]  # (image_size, image_size, -), [0.0-BG, 1.0-stroke]\n        cursor_pos = fn_inputs[0, 0, 2 + index_offset:4 + index_offset]  # (2), in [0.0, 1.0)\n        image_size = fn_inputs[0, 0, 4 + index_offset]  # (), float32\n        window_size = tf.cast(fn_inputs[0, 0, 5 + index_offset], tf.int32)  # ()\n\n        input_img_reshape = tf.expand_dims(input_image, axis=0)\n        pad_img = self.image_padding(input_img_reshape, window_size, pad_value=0.0)\n\n        cursor_pos = tf.cast(tf.round(tf.multiply(cursor_pos, image_size)), dtype=tf.int32)\n        x0, x1 = cursor_pos[0], cursor_pos[0] + window_size  # ()\n        y0, y1 = cursor_pos[1], cursor_pos[1] + window_size  # ()\n        patch_image = pad_img[:, y0:y1, x0:x1, :]  # (1, window_size, window_size, 2/4)\n\n        # resize to raster_size\n        patch_image_scaled = tf.image.resize_images(patch_image, (self.hps.raster_size, self.hps.raster_size),\n                                                    method=tf.image.ResizeMethod.AREA)\n        patch_image_scaled = tf.squeeze(patch_image_scaled, axis=0)\n        # patch_canvas_scaled: (raster_size, raster_size, 2/4), [0.0-BG, 1.0-stroke]\n\n        return patch_image_scaled\n\n    def image_cropping(self, cursor_position, input_img, image_size, window_sizes):\n        \"\"\"\n        :param cursor_position: (N, 1, 2), float type, in size [0.0, 1.0)\n        :param input_img: (N, image_size, image_size, 2/4), [0.0-BG, 1.0-stroke]\n        :param window_sizes: (N, 1, 1), float32, with grad\n        \"\"\"\n        input_img_ = input_img\n       
 window_sizes_non_grad = tf.stop_gradient(tf.round(window_sizes))  # (N, 1, 1), no grad\n\n        cursor_position_ = tf.reshape(cursor_position, (-1, 1, 1, 2))  # (N, 1, 1, 2)\n        cursor_position_ = tf.tile(cursor_position_, [1, image_size, image_size, 1])  # (N, image_size, image_size, 2)\n\n        image_size_ = tf.reshape(tf.cast(image_size, tf.float32), (1, 1, 1, 1))  # (1, 1, 1, 1)\n        image_size_ = tf.tile(image_size_, [self.hps.batch_size, image_size, image_size, 1])\n\n        window_sizes_ = tf.reshape(window_sizes_non_grad, (-1, 1, 1, 1))  # (N, 1, 1, 1)\n        window_sizes_ = tf.tile(window_sizes_, [1, image_size, image_size, 1])  # (N, image_size, image_size, 1)\n\n        fn_inputs = tf.concat([input_img_, cursor_position_, image_size_, window_sizes_],\n                              axis=-1)  # (N, image_size, image_size, 2/4 + 4)\n        curr_patch_imgs = tf.map_fn(self.image_cropping_fn, fn_inputs, parallel_iterations=32)  # (N, raster_size, raster_size, -)\n        return curr_patch_imgs\n\n    def image_cropping_v3(self, cursor_position, input_img, image_size, window_sizes):\n        \"\"\"\n        :param cursor_position: (N, 1, 2), float type, in size [0.0, 1.0)\n        :param input_img: (N, image_size, image_size, k), [0.0-BG, 1.0-stroke]\n        :param window_sizes: (N, 1, 1), float32, with grad\n        \"\"\"\n        window_sizes_non_grad = tf.stop_gradient(window_sizes)  # (N, 1, 1), no grad\n\n        cursor_pos = tf.multiply(cursor_position, tf.cast(image_size, tf.float32))\n        cursor_x, cursor_y = tf.split(cursor_pos, 2, axis=-1)  # (N, 1, 1)\n\n        y1 = cursor_y - (window_sizes_non_grad - 1.0) / 2\n        x1 = cursor_x - (window_sizes_non_grad - 1.0) / 2\n        y2 = y1 + (window_sizes_non_grad - 1.0)\n        x2 = x1 + (window_sizes_non_grad - 1.0)\n        boxes = tf.concat([y1, x1, y2, x2], axis=-1)  # (N, 1, 4)\n        boxes = tf.squeeze(boxes, axis=1)  # (N, 4)\n        boxes = boxes / tf.cast(image_size 
- 1, tf.float32)\n\n        box_ind = tf.ones_like(cursor_x)[:, 0, 0]  # (N)\n        box_ind = tf.cast(box_ind, dtype=tf.int32)\n        box_ind = tf.cumsum(box_ind) - 1\n\n        curr_patch_imgs = tf.image.crop_and_resize(input_img, boxes, box_ind,\n                                                   crop_size=[self.hps.raster_size, self.hps.raster_size])\n        #  (N, raster_size, raster_size, k), [0.0-BG, 1.0-stroke]\n        return curr_patch_imgs\n\n    def get_points_and_raster_image(self, image_size):\n        ## generate the other_params and pen_ras and raster image for raster loss\n        prev_state = self.initial_state  # (N, dec_rnn_size * 3)\n\n        prev_width = self.init_width  # (N)\n        prev_width = tf.expand_dims(tf.expand_dims(prev_width, axis=-1), axis=-1)  # (N, 1, 1)\n\n        prev_scaling = self.init_scaling  # (N)\n        prev_scaling = tf.reshape(prev_scaling, (-1, 1, 1))  # (N, 1, 1)\n\n        prev_window_size = self.init_window_size  # (N)\n        prev_window_size = tf.reshape(prev_window_size, (-1, 1, 1))  # (N, 1, 1)\n\n        cursor_position_temp = self.init_cursor\n        self.cursor_position = cursor_position_temp  # (N, 1, 2), in size [0.0, 1.0)\n        cursor_position_loop = self.cursor_position\n\n        other_params_list = []\n        pen_ras_list = []\n\n        curr_canvas_soft = tf.zeros_like(self.input_photo[:, :, :, 0])  # (N, image_size, image_size), [0.0-BG, 1.0-stroke]\n        curr_canvas_hard = tf.zeros_like(curr_canvas_soft)  # [0.0-BG, 1.0-stroke]\n\n        #### sampling part - start ####\n        self.curr_canvas_hard = curr_canvas_hard\n\n        if self.hps.cropping_type == 'v3':\n            cropping_func = self.image_cropping_v3\n        # elif self.hps.cropping_type == 'v2':\n        #     cropping_func = self.image_cropping\n        else:\n            raise Exception('Unknown cropping_type', self.hps.cropping_type)\n\n        for time_i in range(self.hps.max_seq_len):\n            
cursor_position_non_grad = tf.stop_gradient(cursor_position_loop)  # (N, 1, 2), in size [0.0, 1.0)\n\n            curr_window_size = tf.multiply(prev_scaling, tf.stop_gradient(prev_window_size))  # float, with grad\n            curr_window_size = tf.maximum(curr_window_size, tf.cast(self.hps.min_window_size, tf.float32))\n            curr_window_size = tf.minimum(curr_window_size, tf.cast(image_size, tf.float32))\n\n            ## patch-level encoding\n            # Here, we make the gradients from canvas_z to curr_canvas_hard be None to avoid recurrent gradient propagation.\n            curr_canvas_hard_non_grad = tf.stop_gradient(self.curr_canvas_hard)\n            curr_canvas_hard_non_grad = tf.expand_dims(curr_canvas_hard_non_grad, axis=-1)\n\n            # input_photo: (N, image_size, image_size, 1/3), [0.0-stroke, 1.0-BG]\n            crop_inputs = tf.concat([1.0 - self.input_photo, curr_canvas_hard_non_grad], axis=-1)  # (N, H_p, W_p, 1+1)\n\n            cropped_outputs = cropping_func(cursor_position_non_grad, crop_inputs, image_size, curr_window_size)\n            index_offset = self.hps.input_channel - 1\n            curr_patch_inputs = cropped_outputs[:, :, :, 0:1 + index_offset]  # [0.0-BG, 1.0-stroke]\n            curr_patch_canvas_hard_non_grad = cropped_outputs[:, :, :, 1 + index_offset:2 + index_offset]\n            # (N, raster_size, raster_size, 1/3), [0.0-BG, 1.0-stroke]\n\n            curr_patch_inputs = 1.0 - curr_patch_inputs  # [0.0-stroke, 1.0-BG]\n            curr_patch_inputs = self.normalize_image_m1to1(curr_patch_inputs)\n            # (N, raster_size, raster_size, 1/3), [-1.0-stroke, 1.0-BG]\n\n            # Normalizing image\n            curr_patch_canvas_hard_non_grad = 1.0 - curr_patch_canvas_hard_non_grad  # [0.0-stroke, 1.0-BG]\n            curr_patch_canvas_hard_non_grad = self.normalize_image_m1to1(curr_patch_canvas_hard_non_grad)  # [-1.0-stroke, 1.0-BG]\n\n            ## image-level encoding\n            combined_z = 
self.build_combined_encoder(\n                curr_patch_canvas_hard_non_grad,\n                curr_patch_inputs,\n                1.0 - curr_canvas_hard_non_grad,\n                self.input_photo,\n                cursor_position_non_grad,\n                image_size,\n                curr_window_size)  # (N, z_size)\n            combined_z = tf.expand_dims(combined_z, axis=1)  # (N, 1, z_size)\n\n            curr_window_size_top_side_norm_non_grad = \\\n                tf.stop_gradient(curr_window_size / tf.cast(image_size, tf.float32))\n            curr_window_size_bottom_side_norm_non_grad = \\\n                tf.stop_gradient(curr_window_size / tf.cast(self.hps.min_window_size, tf.float32))\n            if not self.hps.concat_win_size:\n                combined_z = tf.concat([tf.stop_gradient(prev_width), combined_z], 2)  # (N, 1, 2+z_size)\n            else:\n                combined_z = tf.concat([tf.stop_gradient(prev_width),\n                                        curr_window_size_top_side_norm_non_grad,\n                                        curr_window_size_bottom_side_norm_non_grad,\n                                        combined_z],\n                                       2)  # (N, 1, 2+z_size)\n\n            if self.hps.concat_cursor:\n                prev_input_x = tf.concat([cursor_position_non_grad, combined_z], 2)  # (N, 1, 2+2+z_size)\n            else:\n                prev_input_x = combined_z  # (N, 1, 2+z_size)\n\n            h_output, next_state = self.build_seq_decoder(self.dec_cell, prev_input_x, prev_state)\n            # h_output: (N * 1, n_out), next_state: (N, dec_rnn_size * 3)\n            [o_other_params, o_pen_ras] = self.get_mixture_coef(h_output)\n            # o_other_params: (N * 1, 6)\n            # o_pen_ras: (N * 1, 2), after softmax\n\n            o_other_params = tf.reshape(o_other_params, [-1, 1, 6])  # (N, 1, 6)\n            o_pen_ras_raw = tf.reshape(o_pen_ras, [-1, 1, 2])  # (N, 1, 2)\n\n            
other_params_list.append(o_other_params)\n            pen_ras_list.append(o_pen_ras_raw)\n\n            #### sampling part - end ####\n\n            prev_state = next_state\n\n        other_params_ = tf.reshape(tf.concat(other_params_list, axis=1), [-1, 6])  # (N * max_seq_len, 6)\n        pen_ras_ = tf.reshape(tf.concat(pen_ras_list, axis=1), [-1, 2])  # (N * max_seq_len, 2)\n\n        return other_params_, pen_ras_, prev_state\n\n    def differentiable_argmax(self, input_pen, soft_beta):\n        \"\"\"\n        Differentiable argmax trick.\n        :param input_pen: (N, n_class)\n        :return: pen_state: (N, 1)\n        \"\"\"\n        def sign_onehot(x):\n            \"\"\"\n            :param x: (N, n_class)\n            :return:  (N, n_class)\n            \"\"\"\n            y = tf.sign(tf.reduce_max(x, axis=-1, keepdims=True) - x)\n            y = (y - 1) * (-1)\n            return y\n\n        def softargmax(x, beta=1e2):\n            \"\"\"\n            :param x: (N, n_class)\n            :param beta: 1e10 is the best. 1e2 is acceptable.\n            :return:  (N)\n            \"\"\"\n            x_range = tf.cumsum(tf.ones_like(x), axis=1)  # (N, 2)\n            return tf.reduce_sum(tf.nn.softmax(x * beta) * x_range, axis=1) - 1\n\n        ## Better to use softargmax(beta=1e2). The sign_onehot's gradient is close to zero.\n        # pen_onehot = sign_onehot(input_pen)  # one-hot form, (N * max_seq_len, 2)\n        # pen_state = pen_onehot[:, 1:2]  # (N * max_seq_len, 1)\n        pen_state = softargmax(input_pen, soft_beta)\n        pen_state = tf.expand_dims(pen_state, axis=1)  # (N * max_seq_len, 1)\n        return pen_state\n"
  },
  {
    "path": "model_common_train.py",
    "content": "import rnn\nimport tensorflow as tf\n\nfrom subnet_tf_utils import generative_cnn_encoder, generative_cnn_encoder_deeper, generative_cnn_encoder_deeper13, \\\n    generative_cnn_c3_encoder, generative_cnn_c3_encoder_deeper, generative_cnn_c3_encoder_deeper13, \\\n    generative_cnn_c3_encoder_combine33, generative_cnn_c3_encoder_combine43, \\\n    generative_cnn_c3_encoder_combine53, generative_cnn_c3_encoder_combineFC, \\\n    generative_cnn_c3_encoder_deeper13_attn\nfrom rasterization_utils.NeuralRenderer import NeuralRasterizorStep\nfrom vgg_utils.VGG16 import vgg_net_slim\n\n\nclass VirtualSketchingModel(object):\n    def __init__(self, hps, gpu_mode=True, reuse=False):\n        \"\"\"Initializer for the model.\n\n    Args:\n       hps: a HParams object containing model hyperparameters\n       gpu_mode: a boolean that when True, uses GPU mode.\n       reuse: a boolean that when true, attemps to reuse variables.\n    \"\"\"\n        self.hps = hps\n        assert hps.model_mode in ['train', 'eval', 'eval_sample', 'sample']\n        # with tf.variable_scope('SCC', reuse=reuse):\n        if not gpu_mode:\n            with tf.device('/cpu:0'):\n                print('Model using cpu.')\n                self.build_model()\n        else:\n            print('-' * 100)\n            print('model_mode:', hps.model_mode)\n            print('Model using gpu.')\n            self.build_model()\n\n    def build_model(self):\n        \"\"\"Define model architecture.\"\"\"\n        self.config_model()\n\n        initial_state = self.get_decoder_inputs()\n        self.initial_state = initial_state\n        self.initial_state_list = tf.split(self.initial_state, self.total_loop, axis=0)\n\n        total_loss_list = []\n        ras_loss_list = []\n        perc_relu_raw_list = []\n        perc_relu_norm_list = []\n        sn_loss_list = []\n        cursor_outside_loss_list = []\n        win_size_outside_loss_list = []\n        early_state_loss_list = []\n\n        
tower_grads = []\n\n        pred_raster_imgs_list = []\n        pred_raster_imgs_rgb_list = []\n\n        for t_i in range(self.total_loop):\n            gpu_idx = t_i // self.hps.loop_per_gpu\n            gpu_i = self.hps.gpus[gpu_idx]\n            print(self.hps.model_mode, 'model, gpu:', gpu_i, ', loop:', t_i % self.hps.loop_per_gpu)\n            with tf.device('/gpu:%d' % gpu_i):\n                with tf.name_scope('GPU_%d' % gpu_i) as scope:\n                    if t_i > 0:\n                        tf.get_variable_scope().reuse_variables()\n                    else:\n                        total_loss_list.clear()\n                        ras_loss_list.clear()\n                        perc_relu_raw_list.clear()\n                        perc_relu_norm_list.clear()\n                        sn_loss_list.clear()\n                        cursor_outside_loss_list.clear()\n                        win_size_outside_loss_list.clear()\n                        early_state_loss_list.clear()\n                        tower_grads.clear()\n                        pred_raster_imgs_list.clear()\n                        pred_raster_imgs_rgb_list.clear()\n\n                    split_input_photo = self.input_photo_list[t_i]\n                    split_image_size = self.image_size[t_i]\n                    split_init_cursor = self.init_cursor_list[t_i]\n                    split_initial_state = self.initial_state_list[t_i]\n                    if self.hps.input_channel == 1:\n                        split_target_sketch = split_input_photo\n                    else:\n                        split_target_sketch = self.target_sketch_list[t_i]\n\n                    ## use pred as the prev points\n                    other_params, pen_ras, final_state, pred_raster_images, pred_raster_images_rgb, \\\n                    pos_before_max_min, win_size_before_max_min \\\n                        = self.get_points_and_raster_image(split_initial_state, split_init_cursor, split_input_photo,\n     
                                                      split_image_size)\n                    # other_params: (N * max_seq_len, 6)\n                    # pen_ras: (N * max_seq_len, 2), after softmax\n                    # pos_before_max_min: (N, max_seq_len, 2), in image_size\n                    # win_size_before_max_min: (N, max_seq_len, 1), in image_size\n\n                    pred_raster_imgs = 1.0 - pred_raster_images  # (N, image_size, image_size), [0.0-stroke, 1.0-BG]\n                    pred_raster_imgs_rgb = 1.0 - pred_raster_images_rgb  # (N, image_size, image_size, 3)\n                    pred_raster_imgs_list.append(pred_raster_imgs)\n                    pred_raster_imgs_rgb_list.append(pred_raster_imgs_rgb)\n\n                    if not self.hps.use_softargmax:\n                        pen_state_soft = pen_ras[:, 1:2]  # (N * max_seq_len, 1)\n                    else:\n                        pen_state_soft = self.differentiable_argmax(pen_ras, self.hps.soft_beta)  # (N * max_seq_len, 1)\n\n                    pred_params = tf.concat([pen_state_soft, other_params], axis=1)  # (N * max_seq_len, 7)\n                    pred_params = tf.reshape(pred_params, shape=[-1, self.hps.max_seq_len, 7])  # (N, max_seq_len, 7)\n                    # pred_params: (N, max_seq_len, 7)\n\n                    if self.hps.model_mode == 'train' or self.hps.model_mode == 'eval':\n                        raster_cost, sn_cost, cursor_outside_cost, winsize_outside_cost, \\\n                        early_pen_states_cost, \\\n                        perc_relu_loss_raw, perc_relu_loss_norm = \\\n                            self.build_losses(split_target_sketch, pred_raster_imgs, pred_params,\n                                              pos_before_max_min, win_size_before_max_min,\n                                              split_image_size)\n                        # perc_relu_loss_raw, perc_relu_loss_norm: (n_layers)\n\n                        
ras_loss_list.append(raster_cost)\n                        perc_relu_raw_list.append(perc_relu_loss_raw)\n                        perc_relu_norm_list.append(perc_relu_loss_norm)\n                        sn_loss_list.append(sn_cost)\n                        cursor_outside_loss_list.append(cursor_outside_cost)\n                        win_size_outside_loss_list.append(winsize_outside_cost)\n                        early_state_loss_list.append(early_pen_states_cost)\n\n                        if self.hps.model_mode == 'train':\n                            total_cost_split, grads_and_vars_split = self.build_training_op_split(\n                                raster_cost, sn_cost, cursor_outside_cost, winsize_outside_cost,\n                                early_pen_states_cost)\n                            total_loss_list.append(total_cost_split)\n                            tower_grads.append(grads_and_vars_split)\n\n        self.raster_cost = tf.reduce_mean(tf.stack(ras_loss_list, axis=0))\n        self.perc_relu_losses_raw = tf.reduce_mean(tf.stack(perc_relu_raw_list, axis=0), axis=0)  # (n_layers)\n        self.perc_relu_losses_norm = tf.reduce_mean(tf.stack(perc_relu_norm_list, axis=0), axis=0)  # (n_layers)\n        self.stroke_num_cost = tf.reduce_mean(tf.stack(sn_loss_list, axis=0))\n        self.pos_outside_cost = tf.reduce_mean(tf.stack(cursor_outside_loss_list, axis=0))\n        self.win_size_outside_cost = tf.reduce_mean(tf.stack(win_size_outside_loss_list, axis=0))\n        self.early_pen_states_cost = tf.reduce_mean(tf.stack(early_state_loss_list, axis=0))\n        self.cost = tf.reduce_mean(tf.stack(total_loss_list, axis=0))\n\n        self.pred_raster_imgs = tf.concat(pred_raster_imgs_list, axis=0)  # (N, image_size, image_size), [0.0-stroke, 1.0-BG]\n        self.pred_raster_imgs_rgb = tf.concat(pred_raster_imgs_rgb_list, axis=0)  # (N, image_size, image_size, 3)\n\n        if self.hps.model_mode == 'train':\n            
self.build_training_op(tower_grads)\n\n    def config_model(self):\n        if self.hps.model_mode == 'train':\n            self.global_step = tf.Variable(0, name='global_step', trainable=False)\n\n        if self.hps.dec_model == 'lstm':\n            dec_cell_fn = rnn.LSTMCell\n        elif self.hps.dec_model == 'layer_norm':\n            dec_cell_fn = rnn.LayerNormLSTMCell\n        elif self.hps.dec_model == 'hyper':\n            dec_cell_fn = rnn.HyperLSTMCell\n        else:\n            assert False, 'please choose a respectable cell'\n\n        use_recurrent_dropout = self.hps.use_recurrent_dropout\n        use_input_dropout = self.hps.use_input_dropout\n        use_output_dropout = self.hps.use_output_dropout\n\n        dec_cell = dec_cell_fn(\n            self.hps.dec_rnn_size,\n            use_recurrent_dropout=use_recurrent_dropout,\n            dropout_keep_prob=self.hps.recurrent_dropout_prob)\n\n        # dropout:\n        # print('Input dropout mode = %s.' % use_input_dropout)\n        # print('Output dropout mode = %s.' % use_output_dropout)\n        # print('Recurrent dropout mode = %s.' % use_recurrent_dropout)\n        if use_input_dropout:\n            print('Dropout to input w/ keep_prob = %4.4f.' % self.hps.input_dropout_prob)\n            dec_cell = tf.contrib.rnn.DropoutWrapper(\n                dec_cell, input_keep_prob=self.hps.input_dropout_prob)\n        if use_output_dropout:\n            print('Dropout to output w/ keep_prob = %4.4f.' 
% self.hps.output_dropout_prob)\n            dec_cell = tf.contrib.rnn.DropoutWrapper(\n                dec_cell, output_keep_prob=self.hps.output_dropout_prob)\n        self.dec_cell = dec_cell\n\n        self.total_loop = len(self.hps.gpus) * self.hps.loop_per_gpu\n\n        self.init_cursor = tf.placeholder(\n            dtype=tf.float32,\n            shape=[self.hps.batch_size, 1, 2])  # (N, 1, 2), in size [0.0, 1.0)\n        self.init_width = tf.placeholder(\n            dtype=tf.float32,\n            shape=[1])  # (1), in [0.0, 1.0]\n        self.image_size = tf.placeholder(dtype=tf.int32, shape=(self.total_loop))  # ()\n\n        self.init_cursor_list = tf.split(self.init_cursor, self.total_loop, axis=0)\n        self.input_photo_list = []\n        for loop_i in range(self.total_loop):\n            input_photo_i = tf.placeholder(dtype=tf.float32, shape=[None, None, None, self.hps.input_channel])  # [0.0-stroke, 1.0-BG]\n            self.input_photo_list.append(input_photo_i)\n\n        if self.hps.input_channel == 3:\n            self.target_sketch_list = []\n            for loop_i in range(self.total_loop):\n                target_sketch_i = tf.placeholder(dtype=tf.float32, shape=[None, None, None, 1])  # [0.0-stroke, 1.0-BG]\n                self.target_sketch_list.append(target_sketch_i)\n\n        if self.hps.model_mode == 'train' or self.hps.model_mode == 'eval':\n            self.stroke_num_loss_weight = tf.Variable(0.0, trainable=False)\n            self.early_pen_loss_start_idx = tf.Variable(0, dtype=tf.int32, trainable=False)\n            self.early_pen_loss_end_idx = tf.Variable(0, dtype=tf.int32, trainable=False)\n\n        if self.hps.model_mode == 'train':\n            self.perc_loss_mean_list = []\n            for loop_i in range(len(self.hps.perc_loss_layers)):\n                relu_loss_mean = tf.Variable(0.0, trainable=False)\n                self.perc_loss_mean_list.append(relu_loss_mean)\n            self.last_step_num = tf.Variable(0.0, 
trainable=False)\n\n            with tf.variable_scope('train_op', reuse=tf.AUTO_REUSE):\n                self.lr = tf.Variable(self.hps.learning_rate, trainable=False)\n                self.optimizer = tf.train.AdamOptimizer(self.lr)\n\n    ###########################\n\n    def normalize_image_m1to1(self, in_img_0to1):\n        norm_img_m1to1 = tf.multiply(in_img_0to1, 2.0)\n        norm_img_m1to1 = tf.subtract(norm_img_m1to1, 1.0)\n        return norm_img_m1to1\n\n    def add_coords(self, input_tensor):\n        batch_size_tensor = tf.shape(input_tensor)[0]  # get N size\n\n        xx_ones = tf.ones([batch_size_tensor, self.hps.raster_size], dtype=tf.int32)  # e.g. (N, raster_size)\n        xx_ones = tf.expand_dims(xx_ones, -1)  # e.g. (N, raster_size, 1)\n        xx_range = tf.tile(tf.expand_dims(tf.range(self.hps.raster_size), 0),\n                           [batch_size_tensor, 1])  # e.g. (N, raster_size)\n        xx_range = tf.expand_dims(xx_range, 1)  # e.g. (N, 1, raster_size)\n\n        xx_channel = tf.matmul(xx_ones, xx_range)  # e.g. (N, raster_size, raster_size)\n        xx_channel = tf.expand_dims(xx_channel, -1)  # e.g. (N, raster_size, raster_size, 1)\n\n        yy_ones = tf.ones([batch_size_tensor, self.hps.raster_size], dtype=tf.int32)  # e.g. (N, raster_size)\n        yy_ones = tf.expand_dims(yy_ones, 1)  # e.g. (N, 1, raster_size)\n        yy_range = tf.tile(tf.expand_dims(tf.range(self.hps.raster_size), 0),\n                           [batch_size_tensor, 1])  # (N, raster_size)\n        yy_range = tf.expand_dims(yy_range, -1)  # e.g. (N, raster_size, 1)\n\n        yy_channel = tf.matmul(yy_range, yy_ones)  # e.g. (N, raster_size, raster_size)\n        yy_channel = tf.expand_dims(yy_channel, -1)  # e.g. 
(N, raster_size, raster_size, 1)\n\n        xx_channel = tf.cast(xx_channel, 'float32') / (self.hps.raster_size - 1)\n        yy_channel = tf.cast(yy_channel, 'float32') / (self.hps.raster_size - 1)\n        # xx_channel = xx_channel * 2 - 1  # [-1, 1]\n        # yy_channel = yy_channel * 2 - 1\n\n        ret = tf.concat([\n            input_tensor,\n            xx_channel,\n            yy_channel,\n        ], axis=-1)  # e.g. (N, raster_size, raster_size, 4)\n\n        return ret\n\n    def build_combined_encoder(self, patch_canvas, patch_photo, entire_canvas, entire_photo, cursor_pos,\n                               image_size, window_size):\n        \"\"\"\n        :param patch_canvas: (N, raster_size, raster_size, 1), [-1.0-stroke, 1.0-BG]\n        :param patch_photo: (N, raster_size, raster_size, 1/3), [-1.0-stroke, 1.0-BG]\n        :param entire_canvas: (N, image_size, image_size, 1), [0.0-stroke, 1.0-BG]\n        :param entire_photo: (N, image_size, image_size, 1/3), [0.0-stroke, 1.0-BG]\n        :param cursor_pos: (N, 1, 2), in size [0.0, 1.0)\n        :param window_size: (N, 1, 1), float, in large size\n        :return:\n        \"\"\"\n        if self.hps.resize_method == 'BILINEAR':\n            resize_method = tf.image.ResizeMethod.BILINEAR\n        elif self.hps.resize_method == 'NEAREST_NEIGHBOR':\n            resize_method = tf.image.ResizeMethod.NEAREST_NEIGHBOR\n        elif self.hps.resize_method == 'BICUBIC':\n            resize_method = tf.image.ResizeMethod.BICUBIC\n        elif self.hps.resize_method == 'AREA':\n            resize_method = tf.image.ResizeMethod.AREA\n        else:\n            raise Exception('unknown resize_method', self.hps.resize_method)\n\n        patch_photo = tf.stop_gradient(patch_photo)\n        patch_canvas = tf.stop_gradient(patch_canvas)\n        cursor_pos = tf.stop_gradient(cursor_pos)\n        window_size = tf.stop_gradient(window_size)\n\n        entire_photo_small = 
tf.stop_gradient(tf.image.resize_images(entire_photo,\n                                                                      (self.hps.raster_size, self.hps.raster_size),\n                                                                      method=resize_method))\n        entire_canvas_small = tf.stop_gradient(tf.image.resize_images(entire_canvas,\n                                                                      (self.hps.raster_size, self.hps.raster_size),\n                                                                      method=resize_method))\n        entire_photo_small = self.normalize_image_m1to1(entire_photo_small)  # [-1.0-stroke, 1.0-BG]\n        entire_canvas_small = self.normalize_image_m1to1(entire_canvas_small)  # [-1.0-stroke, 1.0-BG]\n\n        if self.hps.encode_cursor_type == 'value':\n            cursor_pos_norm = tf.expand_dims(cursor_pos, axis=1)  # (N, 1, 1, 2)\n            cursor_pos_norm = tf.tile(cursor_pos_norm, [1, self.hps.raster_size, self.hps.raster_size, 1])\n            cursor_info = cursor_pos_norm\n        else:\n            raise Exception('Unknown encode_cursor_type', self.hps.encode_cursor_type)\n\n        batch_input_combined = tf.concat([patch_photo, patch_canvas, entire_photo_small, entire_canvas_small, cursor_info],\n                                axis=-1)  # [N, raster_size, raster_size, 6/10]\n        batch_input_local = tf.concat([patch_photo, patch_canvas], axis=-1)  # [N, raster_size, raster_size, 2/4]\n        batch_input_global = tf.concat([entire_photo_small, entire_canvas_small, cursor_info],\n                                       axis=-1)  # [N, raster_size, raster_size, 4/6]\n\n        if self.hps.model_mode == 'train':\n            is_training = True\n            dropout_keep_prob = self.hps.pix_drop_kp\n        else:\n            is_training = False\n            dropout_keep_prob = 1.0\n\n        if self.hps.add_coordconv:\n            batch_input_combined = self.add_coords(batch_input_combined)  # (N, 
in_H, in_W, in_dim + 2)\n            batch_input_local = self.add_coords(batch_input_local)  # (N, in_H, in_W, in_dim + 2)\n            batch_input_global = self.add_coords(batch_input_global)  # (N, in_H, in_W, in_dim + 2)\n\n        if 'combine' in self.hps.encoder_type:\n            if self.hps.encoder_type == 'combine33':\n                image_embedding, _ = generative_cnn_c3_encoder_combine33(batch_input_local, batch_input_global,\n                                                                         is_training, dropout_keep_prob)  # (N, 128)\n            elif self.hps.encoder_type == 'combine43':\n                image_embedding, _ = generative_cnn_c3_encoder_combine43(batch_input_local, batch_input_global,\n                                                                         is_training, dropout_keep_prob)  # (N, 128)\n            elif self.hps.encoder_type == 'combine53':\n                image_embedding, _ = generative_cnn_c3_encoder_combine53(batch_input_local, batch_input_global,\n                                                                         is_training, dropout_keep_prob)  # (N, 128)\n            elif self.hps.encoder_type == 'combineFC':\n                image_embedding, _ = generative_cnn_c3_encoder_combineFC(batch_input_local, batch_input_global,\n                                                                         is_training, dropout_keep_prob)  # (N, 256)\n            else:\n                raise Exception('Unknown encoder_type', self.hps.encoder_type)\n        else:\n            with tf.variable_scope('Combined_Encoder', reuse=tf.AUTO_REUSE):\n                if self.hps.encoder_type == 'conv10':\n                    image_embedding, _ = generative_cnn_encoder(batch_input_combined, is_training, dropout_keep_prob)  # (N, 128)\n                elif self.hps.encoder_type == 'conv10_deep':\n                    image_embedding, _ = generative_cnn_encoder_deeper(batch_input_combined, is_training, dropout_keep_prob)  # (N, 512)\n 
               elif self.hps.encoder_type == 'conv13':\n                    image_embedding, _ = generative_cnn_encoder_deeper13(batch_input_combined, is_training, dropout_keep_prob)  # (N, 128)\n                elif self.hps.encoder_type == 'conv10_c3':\n                    image_embedding, _ = generative_cnn_c3_encoder(batch_input_combined, is_training, dropout_keep_prob)  # (N, 128)\n                elif self.hps.encoder_type == 'conv10_deep_c3':\n                    image_embedding, _ = generative_cnn_c3_encoder_deeper(batch_input_combined, is_training, dropout_keep_prob)  # (N, 512)\n                elif self.hps.encoder_type == 'conv13_c3':\n                    image_embedding, _ = generative_cnn_c3_encoder_deeper13(batch_input_combined, is_training, dropout_keep_prob)  # (N, 128)\n                elif self.hps.encoder_type == 'conv13_c3_attn':\n                    image_embedding, _ = generative_cnn_c3_encoder_deeper13_attn(batch_input_combined, is_training, dropout_keep_prob)  # (N, 128)\n                else:\n                    raise Exception('Unknown encoder_type', self.hps.encoder_type)\n        return image_embedding\n\n    def build_seq_decoder(self, dec_cell, actual_input_x, initial_state):\n        rnn_output, last_state = self.rnn_decoder(dec_cell, initial_state, actual_input_x)\n        rnn_output_flat = tf.reshape(rnn_output, [-1, self.hps.dec_rnn_size])\n\n        pen_n_out = 2\n        params_n_out = 6\n\n        with tf.variable_scope('DEC_RNN_out_pen', reuse=tf.AUTO_REUSE):\n            output_w_pen = tf.get_variable('output_w', [self.hps.dec_rnn_size, pen_n_out])\n            output_b_pen = tf.get_variable('output_b', [pen_n_out], initializer=tf.constant_initializer(0.0))\n            output_pen = tf.nn.xw_plus_b(rnn_output_flat, output_w_pen, output_b_pen)  # (N, pen_n_out)\n\n        with tf.variable_scope('DEC_RNN_out_params', reuse=tf.AUTO_REUSE):\n            output_w_params = tf.get_variable('output_w', [self.hps.dec_rnn_size, 
params_n_out])\n            output_b_params = tf.get_variable('output_b', [params_n_out], initializer=tf.constant_initializer(0.0))\n            output_params = tf.nn.xw_plus_b(rnn_output_flat, output_w_params, output_b_params)  # (N, params_n_out)\n\n        output = tf.concat([output_pen, output_params], axis=1)  # (N, n_out)\n\n        return output, last_state\n\n    def get_mixture_coef(self, outputs):\n        z = outputs\n        z_pen_logits = z[:, 0:2]  # (N, 2), pen states\n        z_other_params_logits = z[:, 2:]  # (N, 6)\n\n        z_pen = tf.nn.softmax(z_pen_logits)  # (N, 2)\n        if self.hps.position_format == 'abs':\n            x1y1 = tf.nn.sigmoid(z_other_params_logits[:, 0:2])  # (N, 2)\n            x2y2 = tf.tanh(z_other_params_logits[:, 2:4])  # (N, 2)\n            widths = tf.nn.sigmoid(z_other_params_logits[:, 4:5])  # (N, 1)\n            widths = tf.add(tf.multiply(widths, 1.0 - self.hps.min_width), self.hps.min_width)\n            scaling = tf.nn.sigmoid(z_other_params_logits[:, 5:6]) * self.hps.max_scaling  # (N, 1), [0.0, max_scaling]\n            # scaling = tf.add(tf.multiply(scaling, (self.hps.max_scaling - self.hps.min_scaling) / self.hps.max_scaling),\n            #                  self.hps.min_scaling)\n            z_other_params = tf.concat([x1y1, x2y2, widths, scaling], axis=-1)  # (N, 6)\n        else:  # \"rel\"\n            raise Exception('Unknown position_format', self.hps.position_format)\n\n        r = [z_other_params, z_pen]\n        return r\n\n    ###########################\n\n    def get_decoder_inputs(self):\n        initial_state = self.dec_cell.zero_state(batch_size=self.hps.batch_size, dtype=tf.float32)\n        return initial_state\n\n    def rnn_decoder(self, dec_cell, initial_state, actual_input_x):\n        with tf.variable_scope(\"RNN_DEC\", reuse=tf.AUTO_REUSE):\n            output, last_state = tf.nn.dynamic_rnn(\n                dec_cell,\n                actual_input_x,\n                
initial_state=initial_state,\n                time_major=False,\n                swap_memory=True,\n                dtype=tf.float32)\n        return output, last_state\n\n    ###########################\n\n    def image_padding(self, ori_image, window_size, pad_value):\n        \"\"\"\n        Pad with (bg)\n        :param ori_image:\n        :return:\n        \"\"\"\n        paddings = [[0, 0],\n                    [window_size // 2, window_size // 2],\n                    [window_size // 2, window_size // 2],\n                    [0, 0]]\n        pad_img = tf.pad(ori_image, paddings=paddings, mode='CONSTANT', constant_values=pad_value)  # (N, H_p, W_p, k)\n        return pad_img\n\n    def image_cropping_fn(self, fn_inputs):\n        \"\"\"\n        crop the patch\n        :return:\n        \"\"\"\n        index_offset = self.hps.input_channel - 1\n        input_image = fn_inputs[:, :, 0:2 + index_offset]  # (image_size, image_size, 2), [0.0-BG, 1.0-stroke]\n        cursor_pos = fn_inputs[0, 0, 2 + index_offset:4 + index_offset]  # (2), in [0.0, 1.0)\n        image_size = fn_inputs[0, 0, 4 + index_offset]  # (), float32\n        window_size = tf.cast(fn_inputs[0, 0, 5 + index_offset], tf.int32)  # ()\n\n        input_img_reshape = tf.expand_dims(input_image, axis=0)\n        pad_img = self.image_padding(input_img_reshape, window_size, pad_value=0.0)\n\n        cursor_pos = tf.cast(tf.round(tf.multiply(cursor_pos, image_size)), dtype=tf.int32)\n        x0, x1 = cursor_pos[0], cursor_pos[0] + window_size  # ()\n        y0, y1 = cursor_pos[1], cursor_pos[1] + window_size  # ()\n        patch_image = pad_img[:, y0:y1, x0:x1, :]  # (1, window_size, window_size, 2/4)\n\n        # resize to raster_size\n        patch_image_scaled = tf.image.resize_images(patch_image, (self.hps.raster_size, self.hps.raster_size),\n                                                    method=tf.image.ResizeMethod.AREA)\n        patch_image_scaled = tf.squeeze(patch_image_scaled, axis=0)\n  
      # patch_canvas_scaled: (raster_size, raster_size, 2/4), [0.0-BG, 1.0-stroke]\n\n        return patch_image_scaled\n\n    def image_cropping(self, cursor_position, input_img, image_size, window_sizes):\n        \"\"\"\n        :param cursor_position: (N, 1, 2), float type, in size [0.0, 1.0)\n        :param input_img: (N, image_size, image_size, 2/4), [0.0-BG, 1.0-stroke]\n        :param window_sizes: (N, 1, 1), float32, with grad\n        \"\"\"\n        input_img_ = input_img\n        window_sizes_non_grad = tf.stop_gradient(tf.round(window_sizes))  # (N, 1, 1), no grad\n\n        cursor_position_ = tf.reshape(cursor_position, (-1, 1, 1, 2))  # (N, 1, 1, 2)\n        cursor_position_ = tf.tile(cursor_position_, [1, image_size, image_size, 1])  # (N, image_size, image_size, 2)\n\n        image_size_ = tf.reshape(tf.cast(image_size, tf.float32), (1, 1, 1, 1))  # (1, 1, 1, 1)\n        image_size_ = tf.tile(image_size_, [self.hps.batch_size // self.total_loop, image_size, image_size, 1])\n\n        window_sizes_ = tf.reshape(window_sizes_non_grad, (-1, 1, 1, 1))  # (N, 1, 1, 1)\n        window_sizes_ = tf.tile(window_sizes_, [1, image_size, image_size, 1])  # (N, image_size, image_size, 1)\n\n        fn_inputs = tf.concat([input_img_, cursor_position_, image_size_, window_sizes_],\n                              axis=-1)  # (N, image_size, image_size, 2/4 + 4)\n        curr_patch_imgs = tf.map_fn(self.image_cropping_fn, fn_inputs, parallel_iterations=32)  # (N, raster_size, raster_size, -)\n        return curr_patch_imgs\n\n    def image_cropping_v3(self, cursor_position, input_img, image_size, window_sizes):\n        \"\"\"\n        :param cursor_position: (N, 1, 2), float type, in size [0.0, 1.0)\n        :param input_img: (N, image_size, image_size, k), [0.0-BG, 1.0-stroke]\n        :param window_sizes: (N, 1, 1), float32, with grad\n        \"\"\"\n        window_sizes_non_grad = tf.stop_gradient(window_sizes)  # (N, 1, 1), no grad\n\n        cursor_pos = 
tf.multiply(cursor_position, tf.cast(image_size, tf.float32))\n        cursor_x, cursor_y = tf.split(cursor_pos, 2, axis=-1)  # (N, 1, 1)\n\n        y1 = cursor_y - (window_sizes_non_grad - 1.0) / 2\n        x1 = cursor_x - (window_sizes_non_grad - 1.0) / 2\n        y2 = y1 + (window_sizes_non_grad - 1.0)\n        x2 = x1 + (window_sizes_non_grad - 1.0)\n        boxes = tf.concat([y1, x1, y2, x2], axis=-1)  # (N, 1, 4)\n        boxes = tf.squeeze(boxes, axis=1)  # (N, 4)\n        boxes = boxes / tf.cast(image_size - 1, tf.float32)\n\n        box_ind = tf.ones_like(cursor_x)[:, 0, 0]  # (N)\n        box_ind = tf.cast(box_ind, dtype=tf.int32)\n        box_ind = tf.cumsum(box_ind) - 1\n\n        curr_patch_imgs = tf.image.crop_and_resize(input_img, boxes, box_ind,\n                                                   crop_size=[self.hps.raster_size, self.hps.raster_size])\n        #  (N, raster_size, raster_size, k), [0.0-BG, 1.0-stroke]\n        return curr_patch_imgs\n\n    def get_pixel_value(self, img, x, y):\n        \"\"\"\n        Utility function to get pixel value for coordinate vectors x and y from a  4D tensor image.\n\n        Input\n        -----\n        - img: tensor of shape (B, H, W, C)\n        - x: flattened tensor of shape (B, H', W')\n        - y: flattened tensor of shape (B, H', W')\n\n        Returns\n        -------\n        - output: tensor of shape (B, H', W', C)\n        \"\"\"\n        shape = tf.shape(x)\n        batch_size = shape[0]\n        height = shape[1]\n        width = shape[2]\n\n        batch_idx = tf.range(0, batch_size)\n        batch_idx = tf.reshape(batch_idx, (batch_size, 1, 1))\n        b = tf.tile(batch_idx, (1, height, width))\n\n        indices = tf.stack([b, y, x], 3)\n\n        return tf.gather_nd(img, indices)\n\n    def image_pasting_nondiff_single(self, fn_inputs):\n        patch_image = fn_inputs[:, :, 0:1]  # (raster_size, raster_size, 1), [0.0-BG, 1.0-stroke]\n        cursor_pos = fn_inputs[0, 0, 1:3]  # (2), in 
large size\n        image_size = tf.cast(fn_inputs[0, 0, 3], tf.int32)  # ()\n        window_size = tf.cast(fn_inputs[0, 0, 4], tf.int32)  # ()\n\n        patch_image_scaled = tf.expand_dims(patch_image, axis=0)  # (1, raster_size, raster_size, 1)\n        patch_image_scaled = tf.image.resize_images(patch_image_scaled, (window_size, window_size),\n                                                    method=tf.image.ResizeMethod.BILINEAR)\n        patch_image_scaled = tf.squeeze(patch_image_scaled, axis=0)\n        # patch_canvas_scaled: (window_size, window_size, 1)\n\n        cursor_pos = tf.cast(tf.round(cursor_pos), dtype=tf.int32)  # (2)\n        cursor_x, cursor_y = cursor_pos[0], cursor_pos[1]\n\n        pad_up = cursor_y\n        pad_down = image_size - cursor_y\n        pad_left = cursor_x\n        pad_right = image_size - cursor_x\n\n        paddings = [[pad_up, pad_down],\n                    [pad_left, pad_right],\n                    [0, 0]]\n        pad_img = tf.pad(patch_image_scaled, paddings=paddings, mode='CONSTANT',\n                         constant_values=0.0)  # (H_p, W_p, 1), [0.0-BG, 1.0-stroke]\n\n        crop_start = window_size // 2\n        pasted_image = pad_img[crop_start: crop_start + image_size, crop_start: crop_start + image_size, :]\n        return pasted_image\n\n    def image_pasting_diff_single(self, fn_inputs):\n        patch_canvas = fn_inputs[:, :, 0:1]  # (raster_size, raster_size, 1), [0.0-BG, 1.0-stroke]\n        cursor_pos = fn_inputs[0, 0, 1:3]  # (2), in large size\n        image_size = tf.cast(fn_inputs[0, 0, 3], tf.int32)  # ()\n        window_size = tf.cast(fn_inputs[0, 0, 4], tf.int32)  # ()\n        cursor_x, cursor_y = cursor_pos[0], cursor_pos[1]\n\n        patch_canvas_scaled = tf.expand_dims(patch_canvas, axis=0)  # (1, raster_size, raster_size, 1)\n        patch_canvas_scaled = tf.image.resize_images(patch_canvas_scaled, (window_size, window_size),\n                                                     
method=tf.image.ResizeMethod.BILINEAR)\n        # patch_canvas_scaled: (1, window_size, window_size, 1)\n\n        valid_canvas = self.image_pasting_diff_batch(patch_canvas_scaled,\n                                                     tf.expand_dims(tf.expand_dims(cursor_pos, axis=0), axis=0),\n                                                     window_size)\n        valid_canvas = tf.squeeze(valid_canvas, axis=0)\n        # (window_size + 1, window_size + 1, 1)\n\n        pad_up = tf.cast(tf.floor(cursor_y), tf.int32)\n        pad_down = image_size - 1 - tf.cast(tf.floor(cursor_y), tf.int32)\n        pad_left = tf.cast(tf.floor(cursor_x), tf.int32)\n        pad_right = image_size - 1 - tf.cast(tf.floor(cursor_x), tf.int32)\n\n        paddings = [[pad_up, pad_down],\n                    [pad_left, pad_right],\n                    [0, 0]]\n        pad_img = tf.pad(valid_canvas, paddings=paddings, mode='CONSTANT',\n                         constant_values=0.0)  # (H_p, W_p, 1), [0.0-BG, 1.0-stroke]\n\n        crop_start = window_size // 2\n        pasted_image = pad_img[crop_start: crop_start + image_size, crop_start: crop_start + image_size, :]\n        return pasted_image\n\n    def image_pasting_diff_single_v3(self, fn_inputs):\n        patch_canvas = fn_inputs[:, :, 0:1]  # (raster_size, raster_size, 1), [0.0-BG, 1.0-stroke]\n        cursor_pos_a = fn_inputs[0, 0, 1:3]  # (2), float32, in large size\n        image_size_a = tf.cast(fn_inputs[0, 0, 3], tf.int32)  # ()\n        window_size_a = fn_inputs[0, 0, 4]  # (), float32, with grad\n        raster_size_a = float(self.hps.raster_size)\n\n        padding_size = tf.cast(tf.ceil(window_size_a / 2.0), tf.int32)\n\n        x1y1_a = cursor_pos_a - window_size_a / 2.0  # (2), float32\n        x2y2_a = cursor_pos_a + window_size_a / 2.0  # (2), float32\n\n        x1y1_a_floor = tf.floor(x1y1_a)  # (2)\n        x2y2_a_ceil = tf.ceil(x2y2_a)  # (2)\n\n        cursor_pos_b_oricoord = (x1y1_a_floor + x2y2_a_ceil) / 2.0  # 
(2)\n        cursor_pos_b = (cursor_pos_b_oricoord - x1y1_a) / window_size_a * raster_size_a  # (2)\n        raster_size_b = (x2y2_a_ceil - x1y1_a_floor)  # (x, y)\n        image_size_b = raster_size_a\n        window_size_b = raster_size_a * (raster_size_b / window_size_a)  # (x, y)\n\n        cursor_b_x, cursor_b_y = tf.split(cursor_pos_b, 2, axis=-1)  # (1)\n\n        y1_b = cursor_b_y - (window_size_b[1] - 1.) / 2.\n        x1_b = cursor_b_x - (window_size_b[0] - 1.) / 2.\n        y2_b = y1_b + (window_size_b[1] - 1.)\n        x2_b = x1_b + (window_size_b[0] - 1.)\n        boxes_b = tf.concat([y1_b, x1_b, y2_b, x2_b], axis=-1)  # (4)\n        boxes_b = boxes_b / tf.cast(image_size_b - 1, tf.float32)  # with grad to window_size_a\n\n        box_ind_b = tf.ones((1), dtype=tf.int32)  # (1)\n        box_ind_b = tf.cumsum(box_ind_b) - 1\n\n        patch_canvas = tf.expand_dims(patch_canvas, axis=0)  # (1, raster_size, raster_size, 1), [0.0-BG, 1.0-stroke]\n        boxes_b = tf.expand_dims(boxes_b, axis=0)  # (1, 4)\n\n        valid_canvas = tf.image.crop_and_resize(patch_canvas, boxes_b, box_ind_b,\n                                                crop_size=[raster_size_b[1], raster_size_b[0]])\n        valid_canvas = valid_canvas[0]  # (raster_size_b, raster_size_b, 1)\n\n        pad_up = tf.cast(x1y1_a_floor[1], tf.int32) + padding_size\n        pad_down = image_size_a + padding_size - tf.cast(x2y2_a_ceil[1], tf.int32)\n        pad_left = tf.cast(x1y1_a_floor[0], tf.int32) + padding_size\n        pad_right = image_size_a + padding_size - tf.cast(x2y2_a_ceil[0], tf.int32)\n\n        paddings = [[pad_up, pad_down],\n                    [pad_left, pad_right],\n                    [0, 0]]\n        pad_img = tf.pad(valid_canvas, paddings=paddings, mode='CONSTANT',\n                         constant_values=0.0)  # (H_p, W_p, 1), [0.0-BG, 1.0-stroke]\n\n        pasted_image = pad_img[padding_size: padding_size + image_size_a, padding_size: padding_size + image_size_a, 
:]\n        return pasted_image\n\n    def image_pasting_diff_batch(self, patch_image, cursor_position, window_size):\n        \"\"\"\n        :param patch_img: (N, window_size, window_size, 1), [0.0-BG, 1.0-stroke]\n        :param cursor_position: (N, 1, 2), in large size\n        :return:\n        \"\"\"\n        paddings1 = [[0, 0],\n                     [1, 1],\n                     [1, 1],\n                     [0, 0]]\n        patch_image_pad1 = tf.pad(patch_image, paddings=paddings1, mode='CONSTANT',\n                                  constant_values=0.0)  # (N, window_size+2, window_size+2, 1), [0.0-BG, 1.0-stroke]\n\n        cursor_x, cursor_y = cursor_position[:, :, 0:1], cursor_position[:, :, 1:2]  # (N, 1, 1)\n        cursor_x_f, cursor_y_f = tf.floor(cursor_x), tf.floor(cursor_y)\n        patch_x, patch_y = 1.0 - (cursor_x - cursor_x_f), 1.0 - (cursor_y - cursor_y_f)  # (N, 1, 1)\n\n        x_ones = tf.ones_like(patch_x, dtype=tf.float32)  # (N, 1, 1)\n        x_ones = tf.tile(x_ones, [1, 1, window_size])  # (N, 1, window_size)\n        patch_x = tf.concat([patch_x, x_ones], axis=-1)  # (N, 1, window_size + 1)\n        patch_x = tf.tile(patch_x, [1, window_size + 1, 1])  # (N, window_size + 1, window_size + 1)\n        patch_x = tf.cumsum(patch_x, axis=-1)  # (N, window_size + 1, window_size + 1)\n        patch_x0 = tf.cast(tf.floor(patch_x), tf.int32)  # (N, window_size + 1, window_size + 1)\n        patch_x1 = patch_x0 + 1  # (N, window_size + 1, window_size + 1)\n\n        y_ones = tf.ones_like(patch_y, dtype=tf.float32)  # (N, 1, 1)\n        y_ones = tf.tile(y_ones, [1, window_size, 1])  # (N, window_size, 1)\n        patch_y = tf.concat([patch_y, y_ones], axis=1)  # (N, window_size + 1, 1)\n        patch_y = tf.tile(patch_y, [1, 1, window_size + 1])  # (N, window_size + 1, window_size + 1)\n        patch_y = tf.cumsum(patch_y, axis=1)  # (N, window_size + 1, window_size + 1)\n        patch_y0 = tf.cast(tf.floor(patch_y), tf.int32)  # (N, 
window_size + 1, window_size + 1)\n        patch_y1 = patch_y0 + 1  # (N, window_size + 1, window_size + 1)\n\n        # get pixel value at corner coords\n        valid_canvas_patch_a = self.get_pixel_value(patch_image_pad1, patch_x0, patch_y0)\n        valid_canvas_patch_b = self.get_pixel_value(patch_image_pad1, patch_x0, patch_y1)\n        valid_canvas_patch_c = self.get_pixel_value(patch_image_pad1, patch_x1, patch_y0)\n        valid_canvas_patch_d = self.get_pixel_value(patch_image_pad1, patch_x1, patch_y1)\n        # (N, window_size + 1, window_size + 1, 1)\n\n        patch_x0 = tf.cast(patch_x0, tf.float32)\n        patch_x1 = tf.cast(patch_x1, tf.float32)\n        patch_y0 = tf.cast(patch_y0, tf.float32)\n        patch_y1 = tf.cast(patch_y1, tf.float32)\n\n        # calculate deltas\n        wa = (patch_x1 - patch_x) * (patch_y1 - patch_y)\n        wb = (patch_x1 - patch_x) * (patch_y - patch_y0)\n        wc = (patch_x - patch_x0) * (patch_y1 - patch_y)\n        wd = (patch_x - patch_x0) * (patch_y - patch_y0)\n        # (N, window_size + 1, window_size + 1)\n\n        # add dimension for addition\n        wa = tf.expand_dims(wa, axis=3)\n        wb = tf.expand_dims(wb, axis=3)\n        wc = tf.expand_dims(wc, axis=3)\n        wd = tf.expand_dims(wd, axis=3)\n        # (N, window_size + 1, window_size + 1, 1)\n\n        # compute output\n        valid_canvas_patch_ = tf.add_n([wa * valid_canvas_patch_a,\n                                        wb * valid_canvas_patch_b,\n                                        wc * valid_canvas_patch_c,\n                                        wd * valid_canvas_patch_d])  # (N, window_size + 1, window_size + 1, 1)\n        return valid_canvas_patch_\n\n    def image_pasting(self, cursor_position_norm, patch_img, image_size, window_sizes, is_differentiable=False):\n        \"\"\"\n        paste the patch_img to padded size based on cursor_position\n        :param cursor_position_norm: (N, 1, 2), float type, in size [0.0, 
1.0)\n        :param patch_img: (N, raster_size, raster_size), [0.0-BG, 1.0-stroke]\n        :param window_sizes: (N, 1, 1), float32, with grad\n        :return:\n        \"\"\"\n        cursor_position = tf.multiply(cursor_position_norm, tf.cast(image_size, tf.float32))  # in large size\n        window_sizes_r = tf.round(window_sizes)  # (N, 1, 1), no grad\n\n        patch_img_ = tf.expand_dims(patch_img, axis=-1)  # (N, raster_size, raster_size, 1)\n        cursor_position_step = tf.reshape(cursor_position, (-1, 1, 1, 2))  # (N, 1, 1, 2)\n        cursor_position_step = tf.tile(cursor_position_step, [1, self.hps.raster_size, self.hps.raster_size,\n                                                              1])  # (N, raster_size, raster_size, 2)\n        image_size_tile = tf.reshape(tf.cast(image_size, tf.float32), (1, 1, 1, 1))  # (N, 1, 1, 1)\n        image_size_tile = tf.tile(image_size_tile, [self.hps.batch_size // self.total_loop, self.hps.raster_size,\n                                                    self.hps.raster_size, 1])\n        window_sizes_tile = tf.reshape(window_sizes_r, (-1, 1, 1, 1))  # (N, 1, 1, 1)\n        window_sizes_tile = tf.tile(window_sizes_tile, [1, self.hps.raster_size, self.hps.raster_size, 1])\n\n        pasting_inputs = tf.concat([patch_img_, cursor_position_step, image_size_tile, window_sizes_tile],\n                                   axis=-1)  # (N, raster_size, raster_size, 5)\n\n        if is_differentiable:\n            curr_paste_imgs = tf.map_fn(self.image_pasting_diff_single, pasting_inputs,\n                                        parallel_iterations=32)  # (N, image_size, image_size, 1)\n        else:\n            curr_paste_imgs = tf.map_fn(self.image_pasting_nondiff_single, pasting_inputs,\n                                        parallel_iterations=32)  # (N, image_size, image_size, 1)\n        curr_paste_imgs = tf.squeeze(curr_paste_imgs, axis=-1)  # (N, image_size, image_size)\n        return curr_paste_imgs\n\n   
 def image_pasting_v3(self, cursor_position_norm, patch_img, image_size, window_sizes, is_differentiable=False):\n        \"\"\"\n        paste the patch_img to padded size based on cursor_position\n        :param cursor_position_norm: (N, 1, 2), float type, in size [0.0, 1.0)\n        :param patch_img: (N, raster_size, raster_size), [0.0-BG, 1.0-stroke]\n        :param window_sizes: (N, 1, 1), float32, with grad\n        :return:\n        \"\"\"\n        cursor_position = tf.multiply(cursor_position_norm, tf.cast(image_size, tf.float32))  # in large size\n\n        if is_differentiable:\n            patch_img_ = tf.expand_dims(patch_img, axis=-1)  # (N, raster_size, raster_size, 1)\n            cursor_position_step = tf.reshape(cursor_position, (-1, 1, 1, 2))  # (N, 1, 1, 2)\n            cursor_position_step = tf.tile(cursor_position_step, [1, self.hps.raster_size, self.hps.raster_size,\n                                           1])  # (N, raster_size, raster_size, 2)\n            image_size_tile = tf.reshape(tf.cast(image_size, tf.float32), (1, 1, 1, 1))  # (N, 1, 1, 1)\n            image_size_tile = tf.tile(image_size_tile, [self.hps.batch_size // self.total_loop, self.hps.raster_size,\n                                      self.hps.raster_size, 1])\n            window_sizes_tile = tf.reshape(window_sizes, (-1, 1, 1, 1))  # (N, 1, 1, 1)\n            window_sizes_tile = tf.tile(window_sizes_tile, [1, self.hps.raster_size, self.hps.raster_size, 1])\n\n            pasting_inputs = tf.concat([patch_img_, cursor_position_step, image_size_tile, window_sizes_tile],\n                                       axis=-1)  # (N, raster_size, raster_size, 5)\n            curr_paste_imgs = tf.map_fn(self.image_pasting_diff_single_v3, pasting_inputs,\n                                        parallel_iterations=32)  # (N, image_size, image_size, 1)\n        else:\n            raise Exception('Unfinished...')\n        curr_paste_imgs = tf.squeeze(curr_paste_imgs, axis=-1)  # (N, 
image_size, image_size)\n        return curr_paste_imgs\n\n    def get_points_and_raster_image(self, initial_state, init_cursor, input_photo, image_size):\n        ## generate the other_params and pen_ras and raster image for raster loss\n        prev_state = initial_state  # (N, dec_rnn_size * 3)\n\n        prev_width = self.init_width  # (1)\n        prev_width = tf.expand_dims(tf.expand_dims(prev_width, axis=0), axis=0)  # (1, 1, 1)\n        prev_width = tf.tile(prev_width, [self.hps.batch_size // self.total_loop, 1, 1])  # (N, 1, 1)\n\n        prev_scaling = tf.ones((self.hps.batch_size // self.total_loop, 1, 1))  # (N, 1, 1)\n        prev_window_size = tf.ones((self.hps.batch_size // self.total_loop, 1, 1),\n                                   dtype=tf.float32) * float(self.hps.raster_size)  # (N, 1, 1)\n\n        cursor_position_temp = init_cursor\n        self.cursor_position = cursor_position_temp  # (N, 1, 2), in size [0.0, 1.0)\n        cursor_position_loop = self.cursor_position\n\n        other_params_list = []\n        pen_ras_list = []\n\n        pos_before_max_min_list = []\n        win_size_before_max_min_list = []\n\n        curr_canvas_soft = tf.zeros_like(input_photo[:, :, :, 0])  # (N, image_size, image_size), [0.0-BG, 1.0-stroke]\n        curr_canvas_soft_rgb = tf.tile(tf.zeros_like(input_photo[:, :, :, 0:1]), [1, 1, 1, 3])  # (N, image_size, image_size, 3), [0.0-BG, 1.0-stroke]\n        curr_canvas_hard = tf.zeros_like(curr_canvas_soft)  # [0.0-BG, 1.0-stroke]\n\n        #### sampling part - start ####\n        self.curr_canvas_hard = curr_canvas_hard\n\n        rasterizor_st = NeuralRasterizorStep(\n            raster_size=self.hps.raster_size,\n            position_format=self.hps.position_format)\n\n        if self.hps.cropping_type == 'v3':\n            cropping_func = self.image_cropping_v3\n        # elif self.hps.cropping_type == 'v2':\n        #     cropping_func = self.image_cropping\n        else:\n            raise Exception('Unknown 
cropping_type', self.hps.cropping_type)\n\n        if self.hps.pasting_type == 'v3':\n            pasting_func = self.image_pasting_v3\n        # elif self.hps.pasting_type == 'v2':\n        #     pasting_func = self.image_pasting\n        else:\n            raise Exception('Unknown pasting_type', self.hps.pasting_type)\n\n        for time_i in range(self.hps.max_seq_len):\n            cursor_position_non_grad = tf.stop_gradient(cursor_position_loop)  # (N, 1, 2), in size [0.0, 1.0)\n\n            curr_window_size = tf.multiply(prev_scaling, tf.stop_gradient(prev_window_size))  # float, with grad\n            curr_window_size = tf.maximum(curr_window_size, tf.cast(self.hps.min_window_size, tf.float32))\n            curr_window_size = tf.minimum(curr_window_size, tf.cast(image_size, tf.float32))\n\n            ## patch-level encoding\n            # Here, we make the gradients from canvas_z to curr_canvas_hard be None to avoid recurrent gradient propagation.\n            curr_canvas_hard_non_grad = tf.stop_gradient(self.curr_canvas_hard)\n            curr_canvas_hard_non_grad = tf.expand_dims(curr_canvas_hard_non_grad, axis=-1)\n\n            # input_photo: (N, image_size, image_size, 1/3), [0.0-stroke, 1.0-BG]\n            crop_inputs = tf.concat([1.0 - input_photo, curr_canvas_hard_non_grad], axis=-1)  # (N, H_p, W_p, 1/3+1)\n\n            cropped_outputs = cropping_func(cursor_position_non_grad, crop_inputs, image_size, curr_window_size)\n            index_offset = self.hps.input_channel - 1\n            curr_patch_inputs = cropped_outputs[:, :, :, 0:1 + index_offset]  # [0.0-BG, 1.0-stroke]\n            curr_patch_canvas_hard_non_grad = cropped_outputs[:, :, :, 1 + index_offset:2 + index_offset]\n            # (N, raster_size, raster_size, 1), [0.0-BG, 1.0-stroke]\n\n            curr_patch_inputs = 1.0 - curr_patch_inputs  # [0.0-stroke, 1.0-BG]\n            curr_patch_inputs = self.normalize_image_m1to1(curr_patch_inputs)\n            # (N, raster_size, 
raster_size, 1/3), [-1.0-stroke, 1.0-BG]\n\n            # Normalizing image\n            curr_patch_canvas_hard_non_grad = 1.0 - curr_patch_canvas_hard_non_grad  # [0.0-stroke, 1.0-BG]\n            curr_patch_canvas_hard_non_grad = self.normalize_image_m1to1(curr_patch_canvas_hard_non_grad)  # [-1.0-stroke, 1.0-BG]\n\n            ## image-level encoding\n            combined_z = self.build_combined_encoder(\n                curr_patch_canvas_hard_non_grad,\n                curr_patch_inputs,\n                1.0 - curr_canvas_hard_non_grad,\n                input_photo,\n                cursor_position_non_grad,\n                image_size,\n                curr_window_size)  # (N, z_size)\n            combined_z = tf.expand_dims(combined_z, axis=1)  # (N, 1, z_size)\n\n            curr_window_size_top_side_norm_non_grad = \\\n                tf.stop_gradient(curr_window_size / tf.cast(image_size, tf.float32))\n            curr_window_size_bottom_side_norm_non_grad = \\\n                tf.stop_gradient(curr_window_size / tf.cast(self.hps.min_window_size, tf.float32))\n            if not self.hps.concat_win_size:\n                combined_z = tf.concat([tf.stop_gradient(prev_width), combined_z], 2)  # (N, 1, 2+z_size)\n            else:\n                combined_z = tf.concat([tf.stop_gradient(prev_width),\n                                        curr_window_size_top_side_norm_non_grad,\n                                        curr_window_size_bottom_side_norm_non_grad,\n                                        combined_z],\n                                       2)  # (N, 1, 2+z_size)\n\n            if self.hps.concat_cursor:\n                prev_input_x = tf.concat([cursor_position_non_grad, combined_z], 2)  # (N, 1, 2+2+z_size)\n            else:\n                prev_input_x = combined_z  # (N, 1, 2+z_size)\n\n            h_output, next_state = self.build_seq_decoder(self.dec_cell, prev_input_x, prev_state)\n            # h_output: (N * 1, n_out), next_state: 
(N, dec_rnn_size * 3)\n            [o_other_params, o_pen_ras] = self.get_mixture_coef(h_output)\n            # o_other_params: (N * 1, 6)\n            # o_pen_ras: (N * 1, 2), after softmax\n\n            o_other_params = tf.reshape(o_other_params, [-1, 1, 6])  # (N, 1, 6)\n            o_pen_ras_raw = tf.reshape(o_pen_ras, [-1, 1, 2])  # (N, 1, 2)\n\n            other_params_list.append(o_other_params)\n            pen_ras_list.append(o_pen_ras_raw)\n\n            #### sampling part - end ####\n\n            if self.hps.model_mode == 'train' or self.hps.model_mode == 'eval' or self.hps.model_mode == 'eval_sample':\n                # use renderer here to convert the strokes to image\n                curr_other_params = tf.squeeze(o_other_params, axis=1)  # (N, 6), (x1, y1)=[0.0, 1.0], (x2, y2)=[-1.0, 1.0]\n                x1y1, x2y2, width2, scaling = curr_other_params[:, 0:2], curr_other_params[:, 2:4],\\\n                                              curr_other_params[:, 4:5], curr_other_params[:, 5:6]\n                x0y0 = tf.zeros_like(x2y2)  # (N, 2), [-1.0, 1.0]\n                x0y0 = tf.div(tf.add(x0y0, 1.0), 2.0)  # (N, 2), [0.0, 1.0]\n                x2y2 = tf.div(tf.add(x2y2, 1.0), 2.0)  # (N, 2), [0.0, 1.0]\n                widths = tf.concat([tf.squeeze(prev_width, axis=1), width2], axis=1)  # (N, 2)\n                curr_other_params = tf.concat([x0y0, x1y1, x2y2, widths], axis=-1)  # (N, 8), (x0, y0)&(x2, y2)=[0.0, 1.0]\n                curr_stroke_image = rasterizor_st.raster_func_stroke_abs(curr_other_params)\n                # (N, raster_size, raster_size), [0.0-BG, 1.0-stroke]\n\n                curr_stroke_image_large = pasting_func(cursor_position_loop, curr_stroke_image,\n                                                             image_size, curr_window_size,\n                                                             is_differentiable=self.hps.pasting_diff)\n                # (N, image_size, image_size), [0.0-BG, 1.0-stroke]\n\n         
       ## soft\n                if not self.hps.use_softargmax:\n                    curr_state_soft = o_pen_ras[:, 1:2]  # (N, 1)\n                else:\n                    curr_state_soft = self.differentiable_argmax(o_pen_ras, self.hps.soft_beta)  # (N, 1)\n\n                curr_state_soft = tf.expand_dims(curr_state_soft, axis=1)  # (N, 1, 1)\n\n                filter_curr_stroke_image_soft = tf.multiply(tf.subtract(1.0, curr_state_soft), curr_stroke_image_large)\n                # (N, image_size, image_size), [0.0-BG, 1.0-stroke]\n                curr_canvas_soft = tf.add(curr_canvas_soft, filter_curr_stroke_image_soft)  # [0.0-BG, 1.0-stroke]\n\n                ## hard\n                curr_state_hard = tf.expand_dims(tf.cast(tf.argmax(o_pen_ras_raw, axis=-1), dtype=tf.float32),\n                                                     axis=-1)  # (N, 1, 1)\n                filter_curr_stroke_image_hard = tf.multiply(tf.subtract(1.0, curr_state_hard), curr_stroke_image_large)\n                # (N, image_size, image_size), [0.0-BG, 1.0-stroke]\n                self.curr_canvas_hard = tf.add(self.curr_canvas_hard, filter_curr_stroke_image_hard)  # [0.0-BG, 1.0-stroke]\n                self.curr_canvas_hard = tf.clip_by_value(self.curr_canvas_hard, 0.0, 1.0)  # [0.0-BG, 1.0-stroke]\n\n            next_width = o_other_params[:, :, 4:5]\n            next_scaling = o_other_params[:, :, 5:6]\n            next_window_size = tf.multiply(next_scaling, tf.stop_gradient(curr_window_size))  # float, with grad\n            window_size_before_max_min = next_window_size  # (N, 1, 1), large-level\n            win_size_before_max_min_list.append(window_size_before_max_min)\n            next_window_size = tf.maximum(next_window_size, tf.cast(self.hps.min_window_size, tf.float32))\n            next_window_size = tf.minimum(next_window_size, tf.cast(image_size, tf.float32))\n\n            prev_state = next_state\n            prev_width = next_width * curr_window_size / 
next_window_size  # (N, 1, 1)\n            prev_scaling = next_scaling  # (N, 1, 1))\n            prev_window_size = curr_window_size\n\n            # update the cursor position\n            new_cursor_offsets = tf.multiply(o_other_params[:, :, 2:4],\n                                             tf.divide(curr_window_size, 2.0))  # (N, 1, 2), window-level\n            new_cursor_offset_next = new_cursor_offsets\n            new_cursor_offset_next = tf.concat([new_cursor_offset_next[:, :, 1:2], new_cursor_offset_next[:, :, 0:1]], axis=-1)\n\n            cursor_position_loop_large = tf.multiply(cursor_position_loop, tf.cast(image_size, tf.float32))\n\n            if self.hps.stop_accu_grad:\n                stroke_position_next = tf.stop_gradient(cursor_position_loop_large) + new_cursor_offset_next  # (N, 1, 2), large-level\n            else:\n                stroke_position_next = cursor_position_loop_large + new_cursor_offset_next  # (N, 1, 2), large-level\n\n            stroke_position_before_max_min = stroke_position_next  # (N, 1, 2), large-level\n            pos_before_max_min_list.append(stroke_position_before_max_min)\n\n            if self.hps.cursor_type == 'next':\n                cursor_position_loop_large = stroke_position_next  # (N, 1, 2), large-level\n            else:\n                raise Exception('Unknown cursor_type')\n\n            cursor_position_loop_large = tf.maximum(cursor_position_loop_large, 0.0)\n            cursor_position_loop_large = tf.minimum(cursor_position_loop_large, tf.cast(image_size - 1, tf.float32))\n            cursor_position_loop = tf.div(cursor_position_loop_large, tf.cast(image_size, tf.float32))\n\n        curr_canvas_soft = tf.clip_by_value(curr_canvas_soft, 0.0, 1.0)  # (N, raster_size, raster_size), [0.0-BG, 1.0-stroke]\n\n        other_params_ = tf.reshape(tf.concat(other_params_list, axis=1), [-1, 6])  # (N * max_seq_len, 6)\n        pen_ras_ = tf.reshape(tf.concat(pen_ras_list, axis=1), [-1, 2])  # (N * 
max_seq_len, 2)\n        pos_before_max_min_ = tf.concat(pos_before_max_min_list, axis=1)  # (N, max_seq_len, 2)\n        win_size_before_max_min_ = tf.concat(win_size_before_max_min_list, axis=1)  # (N, max_seq_len, 1)\n\n        return other_params_, pen_ras_, prev_state, curr_canvas_soft, curr_canvas_soft_rgb, \\\n               pos_before_max_min_, win_size_before_max_min_\n\n    def differentiable_argmax(self, input_pen, soft_beta):\n        \"\"\"\n        Differentiable argmax trick.\n        :param input_pen: (N, n_class)\n        :return: pen_state: (N, 1)\n        \"\"\"\n        def sign_onehot(x):\n            \"\"\"\n            :param x: (N, n_class)\n            :return:  (N, n_class)\n            \"\"\"\n            y = tf.sign(tf.reduce_max(x, axis=-1, keepdims=True) - x)\n            y = (y - 1) * (-1)\n            return y\n\n        def softargmax(x, beta=1e2):\n            \"\"\"\n            :param x: (N, n_class)\n            :param beta: 1e10 is the best. 1e2 is acceptable.\n            :return:  (N)\n            \"\"\"\n            x_range = tf.cumsum(tf.ones_like(x), axis=1)  # (N, 2)\n            return tf.reduce_sum(tf.nn.softmax(x * beta) * x_range, axis=1) - 1\n\n        ## Better to use softargmax(beta=1e2). 
The sign_onehot's gradient is close to zero.\n        # pen_onehot = sign_onehot(input_pen)  # one-hot form, (N * max_seq_len, 2)\n        # pen_state = pen_onehot[:, 1:2]  # (N * max_seq_len, 1)\n        pen_state = softargmax(input_pen, soft_beta)\n        pen_state = tf.expand_dims(pen_state, axis=1)  # (N * max_seq_len, 1)\n        return pen_state\n\n    def build_losses(self, target_sketch, pred_raster_imgs, pred_params,\n                     pos_before_max_min, win_size_before_max_min, image_size):\n        def get_raster_loss(pred_imgs, gt_imgs, loss_type):\n            perc_layer_losses_raw = []\n            perc_layer_losses_weighted = []\n            perc_layer_losses_norm = []\n\n            if loss_type == 'l1':\n                ras_cost = tf.reduce_mean(tf.abs(tf.subtract(gt_imgs, pred_imgs)))  # ()\n            elif loss_type == 'l1_small':\n                gt_imgs_small = tf.image.resize_images(tf.expand_dims(gt_imgs, axis=3), (32, 32))\n                pred_imgs_small = tf.image.resize_images(tf.expand_dims(pred_imgs, axis=3), (32, 32))\n                ras_cost = tf.reduce_mean(tf.abs(tf.subtract(gt_imgs_small, pred_imgs_small)))  # ()\n            elif loss_type == 'mse':\n                ras_cost = tf.reduce_mean(tf.pow(tf.subtract(gt_imgs, pred_imgs), 2))  # ()\n            elif loss_type == 'perceptual':\n                return_map_pred = vgg_net_slim(pred_imgs, image_size)\n                return_map_gt = vgg_net_slim(gt_imgs, image_size)\n                perc_loss_type = 'l1'  # [l1, mse]\n                weighted_map = {'ReLU1_1': 100.0, 'ReLU1_2': 100.0,\n                                'ReLU2_1': 100.0, 'ReLU2_2': 100.0,\n                                'ReLU3_1': 10.0, 'ReLU3_2': 10.0, 'ReLU3_3': 10.0,\n                                'ReLU4_1': 1.0, 'ReLU4_2': 1.0, 'ReLU4_3': 1.0,\n                                'ReLU5_1': 1.0, 'ReLU5_2': 1.0, 'ReLU5_3': 1.0}\n\n                for perc_layer in self.hps.perc_loss_layers:\n             
       if perc_loss_type == 'l1':\n                        perc_layer_loss = tf.reduce_mean(tf.abs(tf.subtract(return_map_pred[perc_layer],\n                                                                            return_map_gt[perc_layer])))  # ()\n                    elif perc_loss_type == 'mse':\n                        perc_layer_loss = tf.reduce_mean(tf.pow(tf.subtract(return_map_pred[perc_layer],\n                                                                            return_map_gt[perc_layer]), 2))  # ()\n                    else:\n                        raise NameError('Unknown perceptual loss type:', perc_loss_type)\n                    perc_layer_losses_raw.append(perc_layer_loss)\n\n                    assert perc_layer in weighted_map\n                    perc_layer_losses_weighted.append(perc_layer_loss * weighted_map[perc_layer])\n\n                for loop_i in range(len(self.hps.perc_loss_layers)):\n                    perc_relu_loss_raw = perc_layer_losses_raw[loop_i]  # ()\n\n                    if self.hps.model_mode == 'train':\n                        curr_relu_mean = (self.perc_loss_mean_list[loop_i] * self.last_step_num + perc_relu_loss_raw) / (self.last_step_num + 1.0)\n                        relu_cost_norm = perc_relu_loss_raw / curr_relu_mean\n                    else:\n                        relu_cost_norm = perc_relu_loss_raw\n                    perc_layer_losses_norm.append(relu_cost_norm)\n\n                perc_layer_losses_raw = tf.stack(perc_layer_losses_raw, axis=0)\n                perc_layer_losses_norm = tf.stack(perc_layer_losses_norm, axis=0)\n\n                if self.hps.perc_loss_fuse_type == 'max':\n                    ras_cost = tf.reduce_max(perc_layer_losses_norm)\n                elif self.hps.perc_loss_fuse_type == 'add':\n                    ras_cost = tf.reduce_mean(perc_layer_losses_norm)\n                elif self.hps.perc_loss_fuse_type == 'raw_add':\n                    ras_cost = 
tf.reduce_mean(perc_layer_losses_raw)\n                elif self.hps.perc_loss_fuse_type == 'weighted_sum':\n                    ras_cost = tf.reduce_mean(perc_layer_losses_weighted)\n                else:\n                    raise NameError('Unknown perc_loss_fuse_type:', self.hps.perc_loss_fuse_type)\n\n            elif loss_type == 'triplet':\n                raise Exception('Solution for triplet loss is coming soon.')\n            else:\n                raise NameError('Unknown loss type:', loss_type)\n\n            if loss_type != 'perceptual':\n                for perc_layer_i in self.hps.perc_loss_layers:\n                    perc_layer_losses_raw.append(tf.constant(0.0))\n                    perc_layer_losses_norm.append(tf.constant(0.0))\n\n                perc_layer_losses_raw = tf.stack(perc_layer_losses_raw, axis=0)\n                perc_layer_losses_norm = tf.stack(perc_layer_losses_norm, axis=0)\n\n            return ras_cost, perc_layer_losses_raw, perc_layer_losses_norm\n\n        gt_raster_images = tf.squeeze(target_sketch, axis=3)  # (N, raster_h, raster_w), [0.0-stroke, 1.0-BG]\n        raster_cost, perc_relu_losses_raw, perc_relu_losses_norm = \\\n            get_raster_loss(pred_raster_imgs, gt_raster_images, loss_type=self.hps.raster_loss_base_type)\n\n        def get_stroke_num_loss(input_strokes):\n            ending_state = input_strokes[:, :, 0]  # (N, seq_len)\n            stroke_num_loss_pre = tf.reduce_mean(ending_state)  # larger is better, [0.0, 1.0]\n            stroke_num_loss = 1.0 - stroke_num_loss_pre  # lower is better, [0.0, 1.0]\n            return stroke_num_loss\n\n        stroke_num_cost = get_stroke_num_loss(pred_params)  # lower is better\n\n        def get_pos_outside_loss(pos_before_max_min_):\n            pos_after_max_min = tf.maximum(pos_before_max_min_, 0.0)\n            pos_after_max_min = tf.minimum(pos_after_max_min, tf.cast(image_size - 1, tf.float32))  # (N, max_seq_len, 2)\n            pos_outside_loss = 
tf.reduce_mean(tf.abs(pos_before_max_min_ - pos_after_max_min))\n            return pos_outside_loss\n\n        pos_outside_cost = get_pos_outside_loss(pos_before_max_min)  # lower is better\n\n        def get_win_size_outside_loss(win_size_before_max_min_, min_window_size):\n            win_size_outside_top_loss = tf.divide(\n                tf.maximum(win_size_before_max_min_ - tf.cast(image_size, tf.float32), 0.0),\n                tf.cast(image_size, tf.float32))  # (N, max_seq_len, 1)\n            win_size_outside_bottom_loss = tf.divide(\n                tf.maximum(tf.cast(min_window_size, tf.float32) - win_size_before_max_min_, 0.0),\n                tf.cast(min_window_size, tf.float32))  # (N, max_seq_len, 1)\n            win_size_outside_loss = tf.reduce_mean(win_size_outside_top_loss + win_size_outside_bottom_loss)\n            return win_size_outside_loss\n\n        win_size_outside_cost = get_win_size_outside_loss(win_size_before_max_min, self.hps.min_window_size)  # lower is better\n\n        def get_early_pen_states_loss(input_strokes, curr_start, curr_end):\n            # input_strokes: (N, max_seq_len, 7)\n            pred_early_pen_states = input_strokes[:, curr_start:curr_end, 0]  # (N, curr_early_len)\n            pred_early_pen_states_min = tf.reduce_min(pred_early_pen_states, axis=1)  # (N), should not be 1\n            early_pen_states_loss = tf.reduce_mean(pred_early_pen_states_min)  # lower is better\n            return early_pen_states_loss\n\n        early_pen_states_cost = get_early_pen_states_loss(pred_params,\n                                                          self.early_pen_loss_start_idx, self.early_pen_loss_end_idx)\n\n        return raster_cost, stroke_num_cost, pos_outside_cost, win_size_outside_cost, \\\n               early_pen_states_cost, \\\n               perc_relu_losses_raw, perc_relu_losses_norm\n\n    def build_training_op_split(self, raster_cost, sn_cost, cursor_outside_cost, win_size_outside_cost,\n               
                 early_pen_states_cost):\n        total_cost = self.hps.raster_loss_weight * raster_cost + \\\n                self.hps.early_pen_loss_weight * early_pen_states_cost + \\\n                self.stroke_num_loss_weight * sn_cost + \\\n                self.hps.outside_loss_weight * cursor_outside_cost + \\\n                self.hps.win_size_outside_loss_weight * win_size_outside_cost\n\n        tvars = [var for var in tf.trainable_variables()\n                 if 'raster_unit' not in var.op.name and 'VGG16' not in var.op.name]\n        gvs = self.optimizer.compute_gradients(total_cost, var_list=tvars)\n        return total_cost, gvs\n\n    def build_training_op(self, grad_list):\n        with tf.variable_scope('train_op', reuse=tf.AUTO_REUSE):\n            gvs = self.average_gradients(grad_list)\n            g = self.hps.grad_clip\n\n            for grad, var in gvs:\n                print('>>', var.op.name)\n                if grad is None:\n                    print('  >> None value')\n\n            capped_gvs = [(tf.clip_by_value(grad, -g, g), var) for grad, var in gvs]\n\n            self.train_op = self.optimizer.apply_gradients(\n                capped_gvs, global_step=self.global_step, name='train_step')\n\n    def average_gradients(self, grads_list):\n        \"\"\"\n        Compute the average gradients.\n        :param grads_list: list(of length N_GPU) of list(grad, var)\n        :return:\n        \"\"\"\n        avg_grads = []\n        for grad_and_vars in zip(*grads_list):\n            grads = []\n            for g, _ in grad_and_vars:\n                expanded_g = tf.expand_dims(g, 0)\n                grads.append(expanded_g)\n            grad = tf.concat(grads, axis=0)\n            grad = tf.reduce_mean(grad, axis=0)\n\n            v = grad_and_vars[0][1]\n            grad_and_var = (grad, v)\n            avg_grads.append(grad_and_var)\n\n        return avg_grads"
  },
  {
    "path": "rasterization_utils/NeuralRenderer.py",
    "content": "import tensorflow as tf\n\n\nclass RasterUnit(object):\n    def __init__(self,\n                 raster_size,\n                 input_params,  # (N, 10)\n                 reuse=False):\n        self.raster_size = raster_size\n        self.input_params = input_params\n\n        with tf.variable_scope(\"raster_unit\", reuse=reuse):\n            self.build_unit()\n\n    def build_unit(self):\n        x = self.input_params  # (N, 10)\n        x = self.fully_connected(x, 10, 512, scope='fc1')  # (N, 512)\n        x = tf.nn.relu(x)\n        x = self.fully_connected(x, 512, 1024, scope='fc2')  # (N, 1024)\n        x = tf.nn.relu(x)\n        x = self.fully_connected(x, 1024, 2048, scope='fc3')  # (N, 2048)\n        x = tf.nn.relu(x)\n        x = self.fully_connected(x, 2048, 4096, scope='fc4')  # (N, 4096)\n        x = tf.nn.relu(x)\n        x = tf.reshape(x, (-1, 16, 16, 16))  # (N, 16, 16, 16)\n        x = tf.transpose(x, (0, 2, 3, 1))\n\n        x = self.conv2d(x, 32, 3, 1, scope='conv1')  # (N, 16, 16, 32)\n        x = tf.nn.relu(x)\n        x = self.conv2d(x, 32, 3, 1, scope='conv2')  # (N, 16, 16, 32)\n        x = self.pixel_shuffle(x, upscale_factor=2)  # (N, 32, 32, 8)\n\n        x = self.conv2d(x, 16, 3, 1, scope='conv3')  # (N, 32, 32, 16)\n        x = tf.nn.relu(x)\n        x = self.conv2d(x, 16, 3, 1, scope='conv4')  # (N, 32, 32, 16)\n        x = self.pixel_shuffle(x, upscale_factor=2)  # (N, 64, 64, 4)\n\n        x = self.conv2d(x, 8, 3, 1, scope='conv5')  # (N, 64, 64, 8)\n        x = tf.nn.relu(x)\n        x = self.conv2d(x, 4, 3, 1, scope='conv6')  # (N, 64, 64, 4)\n        x = self.pixel_shuffle(x, upscale_factor=2)  # (N, 128, 128, 1)\n        x = tf.sigmoid(x)\n\n        # (N, 128, 128), [0.0-stroke, 1.0-BG]\n        self.stroke_image = 1.0 - tf.reshape(x, (-1, self.raster_size, self.raster_size))\n\n    def conv2d(self, input_tensor, out_channels, kernel_size, stride, scope, reuse=False):\n        with tf.variable_scope(scope, 
reuse=reuse):\n            output_tensor = tf.layers.conv2d(input_tensor, out_channels, kernel_size=kernel_size,\n                                             strides=(stride, stride),\n                                             padding=\"same\", kernel_initializer=tf.keras.initializers.he_normal())\n            return output_tensor\n\n    def fully_connected(self, input_tensor, in_dim, out_dim, scope, reuse=False):\n        with tf.variable_scope(scope, reuse=reuse):\n            weight = tf.get_variable(\"weight\", [in_dim, out_dim], dtype=tf.float32,\n                                     initializer=tf.random_normal_initializer())\n            bias = tf.get_variable(\"bias\", [out_dim], dtype=tf.float32,\n                                   initializer=tf.random_normal_initializer())\n            output_tensor = tf.matmul(input_tensor, weight) + bias\n            return output_tensor\n\n    def pixel_shuffle(self, input_tensor, upscale_factor):\n        params_shape = input_tensor.get_shape()\n        n, h, w, c = params_shape\n        input_tensor_proc = tf.reshape(input_tensor, (n, h, w, c // 4, 4))\n        input_tensor_proc = tf.transpose(input_tensor_proc, (0, 1, 2, 4, 3))\n        input_tensor_proc = tf.reshape(input_tensor_proc, (n, h, w, -1))\n        output_tensor = tf.depth_to_space(input_tensor_proc, block_size=upscale_factor)\n        return output_tensor\n\n\nclass NeuralRasterizor(object):\n    def __init__(self,\n                 raster_size,\n                 seq_len,\n                 position_format='abs',\n                 raster_padding=10,\n                 strokes_format=3):\n        self.raster_size = raster_size\n        self.seq_len = seq_len\n        self.position_format = position_format\n        self.raster_padding = raster_padding\n        self.strokes_format = strokes_format\n\n        assert position_format in ['abs', 'rel']\n\n    def raster_func_abs(self, input_data, raster_seq_len=None):\n        \"\"\"\n        x and y in 
absolute position.\n        :param input_data: (N, seq_len, 10): [x0, y0, x1, y1, x2, y2, r0, r2, w0, w2]. All in [0.0, 1.0]\n        :return:\n        \"\"\"\n        seq_len = raster_seq_len if raster_seq_len is not None else self.seq_len\n\n        raster_params = tf.transpose(input_data, [1, 0, 2])  # (seq_len, N, 10)\n\n        seq_stroke_images = tf.map_fn(self.stroke_drawer_with_raster_unit, raster_params,\n                                      parallel_iterations=32)  # (seq_len, N, raster_size, raster_size)\n        seq_stroke_images = tf.transpose(seq_stroke_images, [1, 2, 3, 0])\n        # (N, raster_size, raster_size, seq_len), [0.0-stroke, 1.0-BG]\n\n        filter_seq_stroke_images = 1.0 - seq_stroke_images\n        # (N, raster_size, raster_size, seq_len), [0.0-BG, 1.0-stroke]\n\n        # stacking\n        stroke_images_unclip = tf.reduce_sum(filter_seq_stroke_images, axis=-1)  # (N, raster_size, raster_size)\n        stroke_images = tf.clip_by_value(stroke_images_unclip, 0.0, 1.0)  # [0.0-BG, 1.0-stroke]\n        return stroke_images, stroke_images_unclip, seq_stroke_images\n\n    def stroke_drawer_with_raster_unit(self, params_batch):\n        \"\"\"\n        Convert two points into a raster stroke image with RasterUnit.\n        :param params_batch: (N, 10)\n        :return: (N, raster_size, raster_size)\n        \"\"\"\n        raster_unit = RasterUnit(\n            raster_size=self.raster_size,\n            input_params=params_batch,\n            reuse=tf.AUTO_REUSE\n        )\n        stroke_image = raster_unit.stroke_image  # (N, raster_size, raster_size), [0.0-stroke, 1.0-BG]\n        return stroke_image\n\n\nclass NeuralRasterizorStep(object):\n    def __init__(self,\n                 raster_size,\n                 position_format='abs'):\n        self.raster_size = raster_size\n        self.position_format = position_format\n\n        assert position_format in ['abs', 'rel']\n\n    def raster_func_stroke_abs(self, input_data):\n        
\"\"\"\n        x and y in absolute position.\n        :param input_data: (N, 8): [x0, y0, x1, y1, x2, y2, r0, r2]. All in [0.0, 1.0]\n        :return:\n        \"\"\"\n        w_in = tf.ones(shape=(input_data.shape[0], 2), dtype=tf.float32)\n        raster_params = tf.concat([input_data, w_in], axis=-1)  # (N, 10)\n        stroke_image = self.stroke_drawer_with_raster_unit(raster_params)  # (N, raster_size, raster_size), [0.0-stroke, 1.0-BG]\n        stroke_image = 1.0 - stroke_image  # [0.0-BG, 1.0-stroke]\n\n        return stroke_image\n\n    def mask_ending_state(self, input_states):\n        \"\"\"\n        Mask the ending state to be 1\n        :param input_states: (N, seq_len, 1) in offset manner\n        :param seq_len:\n        :return:\n        \"\"\"\n        ending_state_accu = tf.cumsum(input_states, axis=1)  # (N, seq_len, 1)\n        ending_state_clip = tf.clip_by_value(ending_state_accu, 0.0, 1.0)  # (N, seq_len, 1)\n        return ending_state_clip\n\n    def stroke_drawer_with_raster_unit(self, params_batch):\n        \"\"\"\n        Convert two points into a raster stroke image with RasterUnit.\n        :param params_batch: (N, 10)\n        :return: (N, raster_size, raster_size)\n        \"\"\"\n        raster_unit = RasterUnit(\n            raster_size=self.raster_size,\n            input_params=params_batch,\n            reuse=tf.AUTO_REUSE\n        )\n        stroke_image = raster_unit.stroke_image  # (N, raster_size, raster_size), [0.0-stroke, 1.0-BG]\n        return stroke_image\n"
  },
  {
    "path": "rasterization_utils/RealRenderer.py",
    "content": "import numpy as np\nimport gizeh\n\n\nclass GizehRasterizor(object):\n    def __init__(self):\n        self.name = 'GizehRasterizor'\n\n    def get_line_array_v2(self, image_size, seq_strokes, stroke_width, is_bin=True):\n        \"\"\"\n        :param p1: (x, y)\n        :param p2: (x, y)\n        :return: line_arr: (image_size, image_size), {0, 1}, 0 for BG and 1 for strokes\n        \"\"\"\n        surface = gizeh.Surface(width=image_size, height=image_size)  # in pixels\n        shape_list = []\n        for seq_i in range(len(seq_strokes) - 1):\n            p1, p2 = seq_strokes[seq_i, :2], seq_strokes[seq_i + 1, :2]\n            pen_state = seq_strokes[seq_i, 2]\n\n            if pen_state == 0.0:\n                line = gizeh.polyline(points=[p1, p2], stroke_width=stroke_width, stroke=(1, 1, 1), fill=(0, 0, 0))\n                shape_list.append(line)\n\n        group = gizeh.Group(shape_list)\n        group.draw(surface)\n\n        # Now export the surface\n        line_arr = surface.get_npimage()[:, :, 0]  # returns a (width x height x 3) numpy array\n\n        if is_bin:\n            line_arr[line_arr <= 128] = 0\n            line_arr[line_arr != 0] = 1  # (image_size, image_size)\n        else:\n            line_arr = np.array(line_arr, dtype=np.float32) / 255.0\n\n        return line_arr\n\n    def get_line_array(self, p1, p2, image_size, stroke_width, is_bin=True):\n        \"\"\"\n        :param p1: (x, y)\n        :param p2: (x, y)\n        :return: line_arr: (image_size, image_size), {0, 1}, 0 for BG and 1 for strokes\n        \"\"\"\n        surface = gizeh.Surface(width=image_size, height=image_size)  # in pixels\n        line = gizeh.polyline(points=[p1, p2], stroke_width=stroke_width, stroke=(1, 1, 1), fill=(0, 0, 0))\n        line.draw(surface)\n\n        # Now export the surface\n        line_arr = surface.get_npimage()[:, :, 0]  # returns a (width x height x 3) numpy array\n\n        if is_bin:\n            line_arr[line_arr <= 
128] = 0\n            line_arr[line_arr != 0] = 1  # (image_size, image_size)\n        else:\n            line_arr = np.array(line_arr, dtype=np.float32) / 255.0\n\n        return line_arr\n\n    def load_sketch_images_on_the_fly_v2(self, image_size, norm_strokes3, stroke_width, is_bin=True):\n        \"\"\"\n        :param norm_strokes3: list (N_sketches,), each with (N_points, 3)\n        :return: list (N_sketches,), each with (raster_size, raster_size), 0-BG and 1-strokes\n        \"\"\"\n        assert type(norm_strokes3) is list\n        sketch_imgs_list = []\n        for stroke_i in range(len(norm_strokes3)):\n            seq_strokes3 = norm_strokes3[stroke_i]  # (N_points, 3)\n            sketch_img = self.get_line_array_v2(image_size, seq_strokes3, stroke_width=stroke_width, is_bin=is_bin)\n            sketch_img = np.clip(sketch_img, 0.0, 1.0)  # (image_size, image_size), 0 for BG and 1 for strokes\n            sketch_imgs_list.append(sketch_img)\n\n        return sketch_imgs_list\n\n    def load_sketch_images_on_the_fly(self, image_size, norm_strokes3, stroke_width, is_bin=True):\n        \"\"\"\n        :param norm_strokes3: list (N_sketches,), each with (N_points, 3)\n        :return: list (N_sketches,), each with (raster_size, raster_size), 0-BG and 1-strokes\n        \"\"\"\n        assert type(norm_strokes3) is list\n        sketch_imgs_list = []\n        for stroke_i in range(len(norm_strokes3)):\n            seq_strokes3 = norm_strokes3[stroke_i]  # (N_points, 3)\n            seq_len = len(seq_strokes3)\n            stroke_imgs_list = []\n\n            for seq_i in range(seq_len - 1):\n                stroke_img = self.get_line_array(seq_strokes3[seq_i, :2], seq_strokes3[seq_i + 1, :2], image_size,\n                                                 stroke_width=stroke_width, is_bin=is_bin)\n                pen_state = seq_strokes3[seq_i, 2]\n                stroke_img = stroke_img.astype(np.float32) * (1. 
- pen_state)\n                stroke_imgs_list.append(stroke_img)\n\n            stroke_imgs_list = np.stack(stroke_imgs_list,\n                                        axis=-1)  # (image_size, image_size, seq_len-1), 0 for BG and 1 for strokes\n            stroke_imgs_list = np.sum(stroke_imgs_list, axis=-1)\n            stroke_imgs_list = np.clip(stroke_imgs_list, 0.0, 1.0)  # (image_size, image_size), 0 for BG and 1 for strokes\n            sketch_imgs_list.append(stroke_imgs_list)\n\n        return sketch_imgs_list\n\n    def normalize_coordinate_np(self, sx, sy, image_size, raster_padding=10.0):\n        \"\"\"\n        Convert offset to normalized absolute points. The numpy version as in NeuralRasterizor.\n        :param sx: (N, seq_len)\n        :param sy: (N, seq_len)\n        :return:\n        \"\"\"\n        seq_len = sx.shape[1]\n\n        # transfer to abs points\n        abs_x = np.cumsum(sx, axis=1)  # (N, seq_len)\n        abs_y = np.cumsum(sy, axis=1)\n\n        min_x = np.min(abs_x, axis=1, keepdims=True)  # (N, 1)\n        max_x = np.max(abs_x, axis=1, keepdims=True)\n        min_y = np.min(abs_y, axis=1, keepdims=True)\n        max_y = np.max(abs_y, axis=1, keepdims=True)\n\n        # transform to positive coordinate\n        abs_x = np.subtract(abs_x, np.tile(min_x, [1, seq_len]))  # (N, seq_len)\n        abs_y = np.subtract(abs_y, np.tile(min_y, [1, seq_len]))\n\n        # scaling to [0.0, raster_size - 2 * padding - 1]\n        bbox_w = np.squeeze(np.subtract(max_x, min_x), axis=-1)  # (N)\n        bbox_h = np.squeeze(np.subtract(max_y, min_y), axis=-1)\n\n        unpad_raster_size = (image_size - 1.0) - 2.0 * raster_padding\n        scaling = np.divide(unpad_raster_size, np.maximum(bbox_w, bbox_h))  # (N)\n        scaling_tile = np.tile(np.expand_dims(scaling, axis=-1), [1, seq_len])  # (N, seq_len)\n        abs_x = np.multiply(abs_x, scaling_tile)  # (N, seq_len)\n        abs_y = np.multiply(abs_y, scaling_tile)\n\n        # add padding\n     
   abs_x = np.add(abs_x, raster_padding)  # (N, seq_len)\n        abs_y = np.add(abs_y, raster_padding)\n\n        # transform to the middle\n        trans_x = np.divide(np.subtract(unpad_raster_size, np.multiply(bbox_w, scaling)), 2.0)  # (N)\n        trans_y = np.divide(np.subtract(unpad_raster_size, np.multiply(bbox_h, scaling)), 2.0)\n        trans_x = np.tile(np.expand_dims(trans_x, axis=-1), [1, seq_len])  # (N, seq_len)\n        trans_y = np.tile(np.expand_dims(trans_y, axis=-1), [1, seq_len])  # (N, seq_len)\n        abs_x = np.add(abs_x, trans_x)  # (N, seq_len)\n        abs_y = np.add(abs_y, trans_y)\n\n        return abs_x, abs_y\n\n    def normalize_strokes_np(self, strokes_list, image_size):\n        \"\"\"\n\n        :param strokes_list: list (N_sketches,), each with (N_points, 3)\n        :return:\n        \"\"\"\n        assert type(strokes_list) is list\n\n        rst_list = []\n        for i in range(len(strokes_list)):\n            strokes_data = strokes_list[i]  # (N_points, 3)\n            norm_x, norm_y = self.normalize_coordinate_np(np.expand_dims(strokes_data[:, 0], axis=0),\n                                                          np.expand_dims(strokes_data[:, 1], axis=0),\n                                                          image_size)  # (1, N_points)\n            norm_strokes_data = np.stack([norm_x[0], norm_y[0], strokes_data[:, 2]], axis=-1)  # (N_points, 3)\n            rst_list.append(norm_strokes_data)\n        return rst_list\n\n    def raster_func(self, input_data, image_size, stroke_width, is_bin=True, version='v2'):\n        \"\"\"\n        :param input_data: (N_sketches,), each with (N_points, 3)\n        :return: raster_image_array: list (N_sketches,), each with (raster_size, raster_size), 0-BG and 1-strokes\n        \"\"\"\n        norm_test_strokes3 = self.normalize_strokes_np(input_data, image_size)\n        if version == 'v1':\n            raster_image_array = self.load_sketch_images_on_the_fly(image_size, 
norm_test_strokes3, stroke_width, is_bin=is_bin)\n        else:\n            raster_image_array = self.load_sketch_images_on_the_fly_v2(image_size, norm_test_strokes3, stroke_width, is_bin=is_bin)\n\n        return raster_image_array\n"
  },
  {
    "path": "rnn.py",
    "content": "# Copyright 2019 The Magenta Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"SketchRNN RNN definition.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\nimport tensorflow as tf\n\n\ndef orthogonal(shape):\n    \"\"\"Orthogonal initilaizer.\"\"\"\n    flat_shape = (shape[0], np.prod(shape[1:]))\n    a = np.random.normal(0.0, 1.0, flat_shape)\n    u, _, v = np.linalg.svd(a, full_matrices=False)\n    q = u if u.shape == flat_shape else v\n    return q.reshape(shape)\n\n\ndef orthogonal_initializer(scale=1.0):\n    \"\"\"Orthogonal initializer.\"\"\"\n\n    def _initializer(shape, dtype=tf.float32,\n                     partition_info=None):  # pylint: disable=unused-argument\n        return tf.constant(orthogonal(shape) * scale, dtype)\n\n    return _initializer\n\n\ndef lstm_ortho_initializer(scale=1.0):\n    \"\"\"LSTM orthogonal initializer.\"\"\"\n\n    def _initializer(shape, dtype=tf.float32,\n                     partition_info=None):  # pylint: disable=unused-argument\n        size_x = shape[0]\n        size_h = shape[1] // 4  # assumes lstm.\n        t = np.zeros(shape)\n        t[:, :size_h] = orthogonal([size_x, size_h]) * scale\n        t[:, size_h:size_h * 2] = orthogonal([size_x, size_h]) * scale\n        t[:, size_h * 2:size_h * 3] = orthogonal([size_x, size_h]) * scale\n        t[:, size_h * 3:] = orthogonal([size_x, 
size_h]) * scale\n        return tf.constant(t, dtype)\n\n    return _initializer\n\n\nclass LSTMCell(tf.contrib.rnn.RNNCell):\n    \"\"\"Vanilla LSTM cell.\n\n  Uses ortho initializer, and also recurrent dropout without memory loss\n  (https://arxiv.org/abs/1603.05118)\n  \"\"\"\n\n    def __init__(self,\n                 num_units,\n                 forget_bias=1.0,\n                 use_recurrent_dropout=False,\n                 dropout_keep_prob=0.9):\n        self.num_units = num_units\n        self.forget_bias = forget_bias\n        self.use_recurrent_dropout = use_recurrent_dropout\n        self.dropout_keep_prob = dropout_keep_prob\n\n    @property\n    def state_size(self):\n        return 2 * self.num_units\n\n    @property\n    def output_size(self):\n        return self.num_units\n\n    def get_output(self, state):\n        unused_c, h = tf.split(state, 2, 1)\n        return h\n\n    def __call__(self, x, state, scope=None):\n        with tf.variable_scope(scope or type(self).__name__):\n            c, h = tf.split(state, 2, 1)\n\n            x_size = x.get_shape().as_list()[1]\n\n            w_init = None  # uniform\n\n            h_init = lstm_ortho_initializer(1.0)\n\n            # Keep W_xh and W_hh separate here as well to use different init methods.\n            w_xh = tf.get_variable(\n                'W_xh', [x_size, 4 * self.num_units], initializer=w_init)\n            w_hh = tf.get_variable(\n                'W_hh', [self.num_units, 4 * self.num_units], initializer=h_init)\n            bias = tf.get_variable(\n                'bias', [4 * self.num_units],\n                initializer=tf.constant_initializer(0.0))\n\n            concat = tf.concat([x, h], 1)\n            w_full = tf.concat([w_xh, w_hh], 0)\n            hidden = tf.matmul(concat, w_full) + bias\n\n            i, j, f, o = tf.split(hidden, 4, 1)\n\n            if self.use_recurrent_dropout:\n                g = tf.nn.dropout(tf.tanh(j), self.dropout_keep_prob)\n            
else:\n                g = tf.tanh(j)\n\n            new_c = c * tf.sigmoid(f + self.forget_bias) + tf.sigmoid(i) * g\n            new_h = tf.tanh(new_c) * tf.sigmoid(o)\n\n            return new_h, tf.concat([new_c, new_h], 1)  # fuk tuples.\n\n\ndef layer_norm_all(h,\n                   batch_size,\n                   base,\n                   num_units,\n                   scope='layer_norm',\n                   reuse=False,\n                   gamma_start=1.0,\n                   epsilon=1e-3,\n                   use_bias=True):\n    \"\"\"Layer Norm (faster version, but not using defun).\"\"\"\n    # Performs layer norm on multiple base at once (ie, i, g, j, o for lstm)\n    # Reshapes h in to perform layer norm in parallel\n    h_reshape = tf.reshape(h, [batch_size, base, num_units])\n    mean = tf.reduce_mean(h_reshape, [2], keep_dims=True)\n    var = tf.reduce_mean(tf.square(h_reshape - mean), [2], keep_dims=True)\n    epsilon = tf.constant(epsilon)\n    rstd = tf.rsqrt(var + epsilon)\n    h_reshape = (h_reshape - mean) * rstd\n    # reshape back to original\n    h = tf.reshape(h_reshape, [batch_size, base * num_units])\n    with tf.variable_scope(scope):\n        if reuse:\n            tf.get_variable_scope().reuse_variables()\n        gamma = tf.get_variable(\n            'ln_gamma', [4 * num_units],\n            initializer=tf.constant_initializer(gamma_start))\n        if use_bias:\n            beta = tf.get_variable(\n                'ln_beta', [4 * num_units], initializer=tf.constant_initializer(0.0))\n    if use_bias:\n        return gamma * h + beta\n    return gamma * h\n\n\ndef layer_norm(x,\n               num_units,\n               scope='layer_norm',\n               reuse=False,\n               gamma_start=1.0,\n               epsilon=1e-3,\n               use_bias=True):\n    \"\"\"Calculate layer norm.\"\"\"\n    axes = [1]\n    mean = tf.reduce_mean(x, axes, keep_dims=True)\n    x_shifted = x - mean\n    var = 
tf.reduce_mean(tf.square(x_shifted), axes, keep_dims=True)\n    inv_std = tf.rsqrt(var + epsilon)\n    with tf.variable_scope(scope):\n        if reuse:\n            tf.get_variable_scope().reuse_variables()\n        gamma = tf.get_variable(\n            'ln_gamma', [num_units],\n            initializer=tf.constant_initializer(gamma_start))\n        if use_bias:\n            beta = tf.get_variable(\n                'ln_beta', [num_units], initializer=tf.constant_initializer(0.0))\n    output = gamma * (x_shifted) * inv_std\n    if use_bias:\n        output += beta\n    return output\n\n\ndef raw_layer_norm(x, epsilon=1e-3):\n    axes = [1]\n    mean = tf.reduce_mean(x, axes, keep_dims=True)\n    std = tf.sqrt(\n        tf.reduce_mean(tf.square(x - mean), axes, keep_dims=True) + epsilon)\n    output = (x - mean) / (std)\n    return output\n\n\ndef super_linear(x,\n                 output_size,\n                 scope=None,\n                 reuse=False,\n                 init_w='ortho',\n                 weight_start=0.0,\n                 use_bias=True,\n                 bias_start=0.0,\n                 input_size=None):\n    \"\"\"Performs linear operation. 
Uses ortho init defined earlier.\"\"\"\n    shape = x.get_shape().as_list()\n    with tf.variable_scope(scope or 'linear'):\n        if reuse:\n            tf.get_variable_scope().reuse_variables()\n\n        w_init = None  # uniform\n        if input_size is None:\n            x_size = shape[1]\n        else:\n            x_size = input_size\n        if init_w == 'zeros':\n            w_init = tf.constant_initializer(0.0)\n        elif init_w == 'constant':\n            w_init = tf.constant_initializer(weight_start)\n        elif init_w == 'gaussian':\n            w_init = tf.random_normal_initializer(stddev=weight_start)\n        elif init_w == 'ortho':\n            w_init = lstm_ortho_initializer(1.0)\n\n        w = tf.get_variable(\n            'super_linear_w', [x_size, output_size], tf.float32, initializer=w_init)\n        if use_bias:\n            b = tf.get_variable(\n                'super_linear_b', [output_size],\n                tf.float32,\n                initializer=tf.constant_initializer(bias_start))\n            return tf.matmul(x, w) + b\n        return tf.matmul(x, w)\n\n\nclass LayerNormLSTMCell(tf.contrib.rnn.RNNCell):\n    \"\"\"Layer-Norm, with Ortho Init. 
and Recurrent Dropout without Memory Loss.\n\n  https://arxiv.org/abs/1607.06450 - Layer Norm\n  https://arxiv.org/abs/1603.05118 - Recurrent Dropout without Memory Loss\n  \"\"\"\n\n    def __init__(self,\n                 num_units,\n                 forget_bias=1.0,\n                 use_recurrent_dropout=False,\n                 dropout_keep_prob=0.90):\n        \"\"\"Initialize the Layer Norm LSTM cell.\n\n    Args:\n      num_units: int, The number of units in the LSTM cell.\n      forget_bias: float, The bias added to forget gates (default 1.0).\n      use_recurrent_dropout: Whether to use Recurrent Dropout (default False)\n      dropout_keep_prob: float, dropout keep probability (default 0.90)\n    \"\"\"\n        self.num_units = num_units\n        self.forget_bias = forget_bias\n        self.use_recurrent_dropout = use_recurrent_dropout\n        self.dropout_keep_prob = dropout_keep_prob\n\n    @property\n    def input_size(self):\n        return self.num_units\n\n    @property\n    def output_size(self):\n        return self.num_units\n\n    @property\n    def state_size(self):\n        return 2 * self.num_units\n\n    def get_output(self, state):\n        h, unused_c = tf.split(state, 2, 1)\n        return h\n\n    def __call__(self, x, state, timestep=0, scope=None):\n        with tf.variable_scope(scope or type(self).__name__):\n            h, c = tf.split(state, 2, 1)\n\n            h_size = self.num_units\n            x_size = x.get_shape().as_list()[1]\n            batch_size = x.get_shape().as_list()[0]\n\n            w_init = None  # uniform\n\n            h_init = lstm_ortho_initializer(1.0)\n\n            w_xh = tf.get_variable(\n                'W_xh', [x_size, 4 * self.num_units], initializer=w_init)\n            w_hh = tf.get_variable(\n                'W_hh', [self.num_units, 4 * self.num_units], initializer=h_init)\n\n            concat = tf.concat([x, h], 1)  # concat for speed.\n            w_full = tf.concat([w_xh, w_hh], 0)\n           
 concat = tf.matmul(concat, w_full)  # + bias # live life without garbage.\n\n            # i = input_gate, j = new_input, f = forget_gate, o = output_gate\n            concat = layer_norm_all(concat, batch_size, 4, h_size, 'ln_all')\n            i, j, f, o = tf.split(concat, 4, 1)\n\n            if self.use_recurrent_dropout:\n                g = tf.nn.dropout(tf.tanh(j), self.dropout_keep_prob)\n            else:\n                g = tf.tanh(j)\n\n            new_c = c * tf.sigmoid(f + self.forget_bias) + tf.sigmoid(i) * g\n            new_h = tf.tanh(layer_norm(new_c, h_size, 'ln_c')) * tf.sigmoid(o)\n\n        return new_h, tf.concat([new_h, new_c], 1)\n\n\nclass HyperLSTMCell(tf.contrib.rnn.RNNCell):\n    \"\"\"HyperLSTM with Ortho Init, Layer Norm, Recurrent Dropout, no Memory Loss.\n\n  https://arxiv.org/abs/1609.09106\n  http://blog.otoro.net/2016/09/28/hyper-networks/\n  \"\"\"\n\n    def __init__(self,\n                 num_units,\n                 forget_bias=1.0,\n                 use_recurrent_dropout=False,\n                 dropout_keep_prob=0.90,\n                 use_layer_norm=True,\n                 hyper_num_units=256,\n                 hyper_embedding_size=32,\n                 hyper_use_recurrent_dropout=False):\n        \"\"\"Initialize the Layer Norm HyperLSTM cell.\n\n    Args:\n      num_units: int, The number of units in the LSTM cell.\n      forget_bias: float, The bias added to forget gates (default 1.0).\n      use_recurrent_dropout: Whether to use Recurrent Dropout (default False)\n      dropout_keep_prob: float, dropout keep probability (default 0.90)\n      use_layer_norm: boolean. 
(default True)\n        Controls whether we use LayerNorm layers in main LSTM & HyperLSTM cell.\n      hyper_num_units: int, number of units in HyperLSTM cell.\n        (default is 128, recommend experimenting with 256 for larger tasks)\n      hyper_embedding_size: int, size of signals emitted from HyperLSTM cell.\n        (default is 16, recommend trying larger values for large datasets)\n      hyper_use_recurrent_dropout: boolean. (default False)\n        Controls whether HyperLSTM cell also uses recurrent dropout.\n        Recommend turning this on only if hyper_num_units becomes large (>= 512)\n    \"\"\"\n        self.num_units = num_units\n        self.forget_bias = forget_bias\n        self.use_recurrent_dropout = use_recurrent_dropout\n        self.dropout_keep_prob = dropout_keep_prob\n        self.use_layer_norm = use_layer_norm\n        self.hyper_num_units = hyper_num_units\n        self.hyper_embedding_size = hyper_embedding_size\n        self.hyper_use_recurrent_dropout = hyper_use_recurrent_dropout\n\n        self.total_num_units = self.num_units + self.hyper_num_units\n\n        if self.use_layer_norm:\n            cell_fn = LayerNormLSTMCell\n        else:\n            cell_fn = LSTMCell\n        self.hyper_cell = cell_fn(\n            hyper_num_units,\n            use_recurrent_dropout=hyper_use_recurrent_dropout,\n            dropout_keep_prob=dropout_keep_prob)\n\n    @property\n    def input_size(self):\n        return self._input_size\n\n    @property\n    def output_size(self):\n        return self.num_units\n\n    @property\n    def state_size(self):\n        return 2 * self.total_num_units\n\n    def get_output(self, state):\n        total_h, unused_total_c = tf.split(state, 2, 1)\n        h = total_h[:, 0:self.num_units]\n        return h\n\n    def hyper_norm(self, layer, scope='hyper', use_bias=True):\n        num_units = self.num_units\n        embedding_size = self.hyper_embedding_size\n        # recurrent batch norm init trick 
(https://arxiv.org/abs/1603.09025).\n        init_gamma = 0.10  # cooijmans' da man.\n        with tf.variable_scope(scope):\n            zw = super_linear(\n                self.hyper_output,\n                embedding_size,\n                init_w='constant',\n                weight_start=0.00,\n                use_bias=True,\n                bias_start=1.0,\n                scope='zw')\n            alpha = super_linear(\n                zw,\n                num_units,\n                init_w='constant',\n                weight_start=init_gamma / embedding_size,\n                use_bias=False,\n                scope='alpha')\n            result = tf.multiply(alpha, layer)\n            if use_bias:\n                zb = super_linear(\n                    self.hyper_output,\n                    embedding_size,\n                    init_w='gaussian',\n                    weight_start=0.01,\n                    use_bias=False,\n                    bias_start=0.0,\n                    scope='zb')\n                beta = super_linear(\n                    zb,\n                    num_units,\n                    init_w='constant',\n                    weight_start=0.00,\n                    use_bias=False,\n                    scope='beta')\n                result += beta\n        return result\n\n    def __call__(self, x, state, timestep=0, scope=None):\n        with tf.variable_scope(scope or type(self).__name__):\n            total_h, total_c = tf.split(state, 2, 1)\n            h = total_h[:, 0:self.num_units]\n            c = total_c[:, 0:self.num_units]\n            self.hyper_state = tf.concat(\n                [total_h[:, self.num_units:], total_c[:, self.num_units:]], 1)\n\n            batch_size = x.get_shape().as_list()[0]\n            x_size = x.get_shape().as_list()[1]\n            self._input_size = x_size\n\n            w_init = None  # uniform\n\n            h_init = lstm_ortho_initializer(1.0)\n\n            w_xh = tf.get_variable(\n                
'W_xh', [x_size, 4 * self.num_units], initializer=w_init)\n            w_hh = tf.get_variable(\n                'W_hh', [self.num_units, 4 * self.num_units], initializer=h_init)\n            bias = tf.get_variable(\n                'bias', [4 * self.num_units],\n                initializer=tf.constant_initializer(0.0))\n\n            # concatenate the input and hidden states for hyperlstm input\n            hyper_input = tf.concat([x, h], 1)\n            hyper_output, hyper_new_state = self.hyper_cell(hyper_input,\n                                                            self.hyper_state)\n            self.hyper_output = hyper_output\n            self.hyper_state = hyper_new_state\n\n            xh = tf.matmul(x, w_xh)\n            hh = tf.matmul(h, w_hh)\n\n            # split Wxh contributions\n            ix, jx, fx, ox = tf.split(xh, 4, 1)\n            ix = self.hyper_norm(ix, 'hyper_ix', use_bias=False)\n            jx = self.hyper_norm(jx, 'hyper_jx', use_bias=False)\n            fx = self.hyper_norm(fx, 'hyper_fx', use_bias=False)\n            ox = self.hyper_norm(ox, 'hyper_ox', use_bias=False)\n\n            # split Whh contributions\n            ih, jh, fh, oh = tf.split(hh, 4, 1)\n            ih = self.hyper_norm(ih, 'hyper_ih', use_bias=True)\n            jh = self.hyper_norm(jh, 'hyper_jh', use_bias=True)\n            fh = self.hyper_norm(fh, 'hyper_fh', use_bias=True)\n            oh = self.hyper_norm(oh, 'hyper_oh', use_bias=True)\n\n            # split bias\n            ib, jb, fb, ob = tf.split(bias, 4, 0)  # bias is to be broadcasted.\n\n            # i = input_gate, j = new_input, f = forget_gate, o = output_gate\n            i = ix + ih + ib\n            j = jx + jh + jb\n            f = fx + fh + fb\n            o = ox + oh + ob\n\n            if self.use_layer_norm:\n                concat = tf.concat([i, j, f, o], 1)\n                concat = layer_norm_all(concat, batch_size, 4, self.num_units, 'ln_all')\n                i, j, f, o = 
tf.split(concat, 4, 1)\n\n            if self.use_recurrent_dropout:\n                g = tf.nn.dropout(tf.tanh(j), self.dropout_keep_prob)\n            else:\n                g = tf.tanh(j)\n\n            new_c = c * tf.sigmoid(f + self.forget_bias) + tf.sigmoid(i) * g\n            new_h = tf.tanh(layer_norm(new_c, self.num_units, 'ln_c')) * tf.sigmoid(o)\n\n            hyper_h, hyper_c = tf.split(hyper_new_state, 2, 1)\n            new_total_h = tf.concat([new_h, hyper_h], 1)\n            new_total_c = tf.concat([new_c, hyper_c], 1)\n            new_total_state = tf.concat([new_total_h, new_total_c], 1)\n        return new_h, new_total_state\n"
  },
  {
    "path": "subnet_tf_utils.py",
    "content": "import tensorflow as tf\n\n\ndef get_initializer(init_method):\n    if init_method == 'xavier_normal':\n        initializer = tf.glorot_normal_initializer()\n    elif init_method == 'xavier_uniform':\n        initializer = tf.glorot_uniform_initializer()\n    elif init_method == 'he_normal':\n        initializer = tf.keras.initializers.he_normal()\n    elif init_method == 'he_uniform':\n        initializer = tf.keras.initializers.he_uniform()\n    elif init_method == 'lecun_normal':\n        initializer = tf.keras.initializers.lecun_normal()\n    elif init_method == 'lecun_uniform':\n        initializer = tf.keras.initializers.lecun_uniform()\n    else:\n        raise Exception('Unknown initializer:', init_method)\n    return initializer\n\n\ndef lrelu(x, leak=0.2, name=\"lrelu\", alt_relu_impl=False):\n    with tf.variable_scope(name) as scope:\n        if alt_relu_impl:\n            f1 = 0.5 * (1 + leak)\n            f2 = 0.5 * (1 - leak)\n            return f1 * x + f2 * abs(x)\n        else:\n            return tf.maximum(x, leak * x)\n\n\ndef batchnorm(input, name='batch_norm', init_method=None):\n    if init_method is not None:\n        initializer = get_initializer(init_method)\n    else:\n        initializer = tf.random_normal_initializer(1.0, 0.02, dtype=tf.float32)\n\n    with tf.variable_scope(name):\n        # this block looks like it has 3 inputs on the graph unless we do this\n        input = tf.identity(input)\n\n        channels = input.get_shape()[3]\n        offset = tf.get_variable(\"offset\", [channels], dtype=tf.float32, initializer=tf.zeros_initializer())\n        scale = tf.get_variable(\"scale\", [channels], dtype=tf.float32,\n                                initializer=initializer)\n        mean, variance = tf.nn.moments(input, axes=[0, 1, 2], keep_dims=False)\n        variance_epsilon = 1e-5\n        normalized = tf.nn.batch_normalization(input, mean, variance, offset, scale, variance_epsilon=variance_epsilon)\n        
return normalized\n\n\ndef layernorm(input, name='layer_norm', init_method=None):\n    if init_method is not None:\n        initializer = get_initializer(init_method)\n    else:\n        initializer = tf.random_normal_initializer(1.0, 0.02, dtype=tf.float32)\n\n    with tf.variable_scope(name):\n        n_neurons = input.get_shape()[3]\n        offset = tf.get_variable(\"offset\", [n_neurons], dtype=tf.float32, initializer=tf.zeros_initializer())\n        scale = tf.get_variable(\"scale\", [n_neurons], dtype=tf.float32,\n                                initializer=initializer)\n        offset = tf.reshape(offset, [1, 1, -1])\n        scale = tf.reshape(scale, [1, 1, -1])\n        mean, variance = tf.nn.moments(input, axes=[1, 2, 3], keep_dims=True)\n        variance_epsilon = 1e-5\n        normalized = tf.nn.batch_normalization(input, mean, variance, offset, scale, variance_epsilon=variance_epsilon)\n        return normalized\n\n\ndef instance_norm(input, name=\"instance_norm\", init_method=None):\n    if init_method is not None:\n        initializer = get_initializer(init_method)\n    else:\n        initializer = tf.random_normal_initializer(1.0, 0.02, dtype=tf.float32)\n\n    with tf.variable_scope(name):\n        depth = input.get_shape()[3]\n        scale = tf.get_variable(\"scale\", [depth], initializer=initializer)\n        offset = tf.get_variable(\"offset\", [depth], initializer=tf.constant_initializer(0.0))\n        mean, variance = tf.nn.moments(input, axes=[1, 2], keep_dims=True)\n        epsilon = 1e-5\n        inv = tf.rsqrt(variance + epsilon)\n        normalized = (input - mean) * inv\n        return scale * normalized + offset\n\n\ndef linear1d(inputlin, inputdim, outputdim, name=\"linear1d\", init_method=None):\n    if init_method is not None:\n        initializer = get_initializer(init_method)\n    else:\n        initializer = None\n\n    with tf.variable_scope(name) as scope:\n        weight = tf.get_variable(\"weight\", [inputdim, outputdim], 
initializer=initializer)\n        bias = tf.get_variable(\"bias\", [outputdim], dtype=tf.float32, initializer=tf.constant_initializer(0.0))\n        return tf.matmul(inputlin, weight) + bias\n\n\ndef general_conv2d(inputconv, output_dim=64, filter_height=4, filter_width=4, stride_height=2, stride_width=2,\n                   stddev=0.02, padding=\"SAME\", name=\"conv2d\", do_norm=True, norm_type='instance_norm', do_relu=True,\n                   relufactor=0, is_training=True, init_method=None):\n    if init_method is not None:\n        initializer = get_initializer(init_method)\n    else:\n        initializer = tf.truncated_normal_initializer(stddev=stddev)\n\n    with tf.variable_scope(name) as scope:\n        conv = tf.contrib.layers.conv2d(inputconv, output_dim, [filter_width, filter_height],\n                                        [stride_width, stride_height], padding, activation_fn=None,\n                                        weights_initializer=initializer,\n                                        biases_initializer=tf.constant_initializer(0.0))\n        if do_norm:\n            if norm_type == 'instance_norm':\n                conv = instance_norm(conv, init_method=init_method)\n                # conv = tf.contrib.layers.instance_norm(conv, epsilon=1e-05, center=True, scale=True,\n                #                                        scope='instance_norm')\n            elif norm_type == 'batch_norm':\n                # conv = batchnorm(conv, init_method=init_method)\n                conv = tf.contrib.layers.batch_norm(conv, decay=0.9, is_training=is_training, updates_collections=None,\n                                                    epsilon=1e-5, center=True, scale=True, scope=\"batch_norm\")\n            elif norm_type == 'layer_norm':\n                # conv = layernorm(conv, init_method=init_method)\n                conv = tf.contrib.layers.layer_norm(conv, center=True, scale=True, scope='layer_norm')\n\n        if do_relu:\n            if 
relufactor == 0:\n                conv = tf.nn.relu(conv, \"relu\")\n            else:\n                conv = lrelu(conv, relufactor, \"lrelu\")\n\n        return conv\n\n\ndef generative_cnn_c3_encoder(inputs, is_training=True, drop_keep_prob=0.5, init_method=None):\n    tensor_x = inputs\n\n    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE) as scope:\n        tensor_x = general_conv2d(tensor_x, 32, filter_height=3, filter_width=3,\n                                  is_training=is_training, name=\"CNN_ENC_1\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 32, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                  is_training=is_training, name=\"CNN_ENC_1_2\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 64, filter_height=3, filter_width=3,\n                                  is_training=is_training, name=\"CNN_ENC_2\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 64, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                  is_training=is_training, name=\"CNN_ENC_2_2\", init_method=init_method)\n\n        tensor_x = general_conv2d(tensor_x, 128, filter_height=3, filter_width=3,\n                                  is_training=is_training, name=\"CNN_ENC_3\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                  is_training=is_training, name=\"CNN_ENC_3_2\", init_method=init_method)\n\n        tensor_x = general_conv2d(tensor_x, 256, filter_height=3, filter_width=3,\n                                  is_training=is_training, name=\"CNN_ENC_4\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                  is_training=is_training, name=\"CNN_ENC_4_2\", 
init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 256, filter_height=3, filter_width=3,\n                                  is_training=is_training, name=\"CNN_ENC_5\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                  is_training=is_training, name=\"CNN_ENC_5_2\", init_method=init_method)\n        tensor_x_sp = tensor_x  # [N, h, w, 256]\n\n        tensor_x = tf.reshape(tensor_x, (-1, 256 * 4 * 4))\n        tensor_x = linear1d(tensor_x, 256 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method)\n\n        if is_training:\n            tensor_x = tf.nn.dropout(tensor_x, drop_keep_prob)\n\n        return tensor_x, tensor_x_sp\n\n\ndef generative_cnn_c3_encoder_deeper(inputs, is_training=True, drop_keep_prob=0.5, init_method=None):\n    tensor_x = inputs\n\n    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE) as scope:\n        tensor_x = general_conv2d(tensor_x, 32, filter_height=3, filter_width=3,\n                                  is_training=is_training, name=\"CNN_ENC_1\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 32, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                  is_training=is_training, name=\"CNN_ENC_1_2\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 64, filter_height=3, filter_width=3,\n                                  is_training=is_training, name=\"CNN_ENC_2\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 64, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                  is_training=is_training, name=\"CNN_ENC_2_2\", init_method=init_method)\n\n        tensor_x = general_conv2d(tensor_x, 128, filter_height=3, filter_width=3,\n                                  is_training=is_training, name=\"CNN_ENC_3\", 
def generative_cnn_c3_encoder_combine33(local_inputs, global_inputs, is_training=True, drop_keep_prob=0.5, init_method=None):
    """Two-stream conv encoder: separate local/global towers of three stages
    (2-2-3 convs per stage), concatenated along channels, then a shared
    two-stage tail (256, 512 channels) flattened to a 128-d vector.

    Fix: the spatial-feature comment claimed 256 channels; the tail stage
    outputs 512 (see the hard-coded 512 * 4 * 4 reshape below).

    Args:
        local_inputs: local image tensor, NHWC layout.
        global_inputs: global image tensor, NHWC layout.
        is_training: enables the final dropout; also forwarded to conv layers.
        drop_keep_prob: keep probability for the final dropout.
        init_method: initializer forwarded to general_conv2d / linear1d.

    Returns:
        (tensor_x, tensor_x_sp): the 128-d encoding and the last spatial
        feature map before flattening.
    """
    def _stage(x, out_ch, stage_id, n_convs):
        # First conv of a stage omits stride args (downsampling per
        # general_conv2d's defaults — TODO confirm); the rest use stride 1.
        # Variable names exactly reproduce the original literals
        # ("CNN_ENC_<stage>", "CNN_ENC_<stage>_<k>") for checkpoint compat.
        x = general_conv2d(x, out_ch, filter_height=3, filter_width=3,
                           is_training=is_training,
                           name="CNN_ENC_%d" % stage_id, init_method=init_method)
        for k in range(2, n_convs + 1):
            x = general_conv2d(x, out_ch, filter_height=3, filter_width=3,
                               stride_height=1, stride_width=1,
                               is_training=is_training,
                               name="CNN_ENC_%d_%d" % (stage_id, k),
                               init_method=init_method)
        return x

    with tf.variable_scope('Local_Encoder', reuse=tf.AUTO_REUSE):
        local_x = _stage(local_inputs, 32, 1, 2)
        local_x = _stage(local_x, 64, 2, 2)
        local_x = _stage(local_x, 128, 3, 3)

    with tf.variable_scope('Global_Encoder', reuse=tf.AUTO_REUSE):
        global_x = _stage(global_inputs, 32, 1, 2)
        global_x = _stage(global_x, 64, 2, 2)
        global_x = _stage(global_x, 128, 3, 3)

    # Fuse the two streams channel-wise before the shared tail.
    tensor_x = tf.concat([local_x, global_x], axis=-1)

    with tf.variable_scope('Combined_Encoder', reuse=tf.AUTO_REUSE):
        tensor_x = _stage(tensor_x, 256, 4, 3)
        tensor_x = _stage(tensor_x, 512, 5, 3)
        tensor_x_sp = tensor_x  # [N, h, w, 512] -- comment fixed (was mislabeled 256)

        # Hard-coded flatten: assumes the final feature map is 4x4 spatially.
        tensor_x = tf.reshape(tensor_x, (-1, 512 * 4 * 4))
        tensor_x = linear1d(tensor_x, 512 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method)

        if is_training:
            tensor_x = tf.nn.dropout(tensor_x, drop_keep_prob)

        return tensor_x, tensor_x_sp
def generative_cnn_c3_encoder_combine43(local_inputs, global_inputs, is_training=True, drop_keep_prob=0.5, init_method=None):
    """Two-stream conv encoder: local/global towers of four stages each
    (2-2-3-3 convs), concatenated, then one shared 512-channel stage and a
    flatten to a 128-d vector.

    Fix: the spatial-feature comment claimed 256 channels; the tail stage
    outputs 512 (see the hard-coded 512 * 4 * 4 reshape below).

    Args:
        local_inputs: local image tensor, NHWC layout.
        global_inputs: global image tensor, NHWC layout.
        is_training: enables the final dropout; also forwarded to conv layers.
        drop_keep_prob: keep probability for the final dropout.
        init_method: initializer forwarded to general_conv2d / linear1d.

    Returns:
        (tensor_x, tensor_x_sp): the 128-d encoding and the last spatial
        feature map before flattening.
    """
    def _stage(x, out_ch, stage_id, n_convs):
        # First conv of a stage omits stride args (downsampling per
        # general_conv2d's defaults — TODO confirm); the rest use stride 1.
        # Names reproduce the original literals for checkpoint compat.
        x = general_conv2d(x, out_ch, filter_height=3, filter_width=3,
                           is_training=is_training,
                           name="CNN_ENC_%d" % stage_id, init_method=init_method)
        for k in range(2, n_convs + 1):
            x = general_conv2d(x, out_ch, filter_height=3, filter_width=3,
                               stride_height=1, stride_width=1,
                               is_training=is_training,
                               name="CNN_ENC_%d_%d" % (stage_id, k),
                               init_method=init_method)
        return x

    with tf.variable_scope('Local_Encoder', reuse=tf.AUTO_REUSE):
        local_x = _stage(local_inputs, 32, 1, 2)
        local_x = _stage(local_x, 64, 2, 2)
        local_x = _stage(local_x, 128, 3, 3)
        local_x = _stage(local_x, 256, 4, 3)

    with tf.variable_scope('Global_Encoder', reuse=tf.AUTO_REUSE):
        global_x = _stage(global_inputs, 32, 1, 2)
        global_x = _stage(global_x, 64, 2, 2)
        global_x = _stage(global_x, 128, 3, 3)
        global_x = _stage(global_x, 256, 4, 3)

    # Fuse the two streams channel-wise before the shared tail.
    tensor_x = tf.concat([local_x, global_x], axis=-1)

    with tf.variable_scope('Combined_Encoder', reuse=tf.AUTO_REUSE):
        tensor_x = _stage(tensor_x, 512, 5, 3)
        tensor_x_sp = tensor_x  # [N, h, w, 512] -- comment fixed (was mislabeled 256)

        # Hard-coded flatten: assumes the final feature map is 4x4 spatially.
        tensor_x = tf.reshape(tensor_x, (-1, 512 * 4 * 4))
        tensor_x = linear1d(tensor_x, 512 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method)

        if is_training:
            tensor_x = tf.nn.dropout(tensor_x, drop_keep_prob)

        return tensor_x, tensor_x_sp
def generative_cnn_c3_encoder_combine53(local_inputs, global_inputs, is_training=True, drop_keep_prob=0.5, init_method=None):
    """Two-stream conv encoder: full five-stage local and global towers
    (2-2-3-3-3 convs, ending at 512 channels each), concatenated to a
    1024-channel map and flattened directly to a 128-d vector (no extra
    convs in the combined scope).

    Fix: the spatial-feature comment claimed 256 channels; the concat of two
    512-channel towers yields 1024 (see the 1024 * 4 * 4 reshape below).

    Args:
        local_inputs: local image tensor, NHWC layout.
        global_inputs: global image tensor, NHWC layout.
        is_training: enables the final dropout; also forwarded to conv layers.
        drop_keep_prob: keep probability for the final dropout.
        init_method: initializer forwarded to general_conv2d / linear1d.

    Returns:
        (tensor_x, tensor_x_sp): the 128-d encoding and the concatenated
        spatial feature map before flattening.
    """
    def _stage(x, out_ch, stage_id, n_convs):
        # First conv of a stage omits stride args (downsampling per
        # general_conv2d's defaults — TODO confirm); the rest use stride 1.
        # Names reproduce the original literals for checkpoint compat.
        x = general_conv2d(x, out_ch, filter_height=3, filter_width=3,
                           is_training=is_training,
                           name="CNN_ENC_%d" % stage_id, init_method=init_method)
        for k in range(2, n_convs + 1):
            x = general_conv2d(x, out_ch, filter_height=3, filter_width=3,
                               stride_height=1, stride_width=1,
                               is_training=is_training,
                               name="CNN_ENC_%d_%d" % (stage_id, k),
                               init_method=init_method)
        return x

    with tf.variable_scope('Local_Encoder', reuse=tf.AUTO_REUSE):
        local_x = _stage(local_inputs, 32, 1, 2)
        local_x = _stage(local_x, 64, 2, 2)
        local_x = _stage(local_x, 128, 3, 3)
        local_x = _stage(local_x, 256, 4, 3)
        local_x = _stage(local_x, 512, 5, 3)

    with tf.variable_scope('Global_Encoder', reuse=tf.AUTO_REUSE):
        global_x = _stage(global_inputs, 32, 1, 2)
        global_x = _stage(global_x, 64, 2, 2)
        global_x = _stage(global_x, 128, 3, 3)
        global_x = _stage(global_x, 256, 4, 3)
        global_x = _stage(global_x, 512, 5, 3)

    tensor_x = tf.concat([local_x, global_x], axis=-1)

    with tf.variable_scope('Combined_Encoder', reuse=tf.AUTO_REUSE):
        tensor_x_sp = tensor_x  # [N, h, w, 1024] -- comment fixed (was mislabeled 256)

        # Hard-coded flatten: assumes the final feature map is 4x4 spatially.
        tensor_x = tf.reshape(tensor_x, (-1, 1024 * 4 * 4))
        tensor_x = linear1d(tensor_x, 1024 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method)

        if is_training:
            tensor_x = tf.nn.dropout(tensor_x, drop_keep_prob)

        return tensor_x, tensor_x_sp
      tensor_x = tf.reshape(tensor_x, (-1, 1024 * 4 * 4))\n        tensor_x = linear1d(tensor_x, 1024 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method)\n\n        if is_training:\n            tensor_x = tf.nn.dropout(tensor_x, drop_keep_prob)\n\n        return tensor_x, tensor_x_sp\n\n\ndef generative_cnn_c3_encoder_combineFC(local_inputs, global_inputs, is_training=True, drop_keep_prob=0.5, init_method=None):\n    local_x = local_inputs\n    global_x = global_inputs\n\n    with tf.variable_scope('Local_Encoder', reuse=tf.AUTO_REUSE):\n        local_x = general_conv2d(local_x, 32, filter_height=3, filter_width=3,\n                                 is_training=is_training, name=\"CNN_ENC_1\", init_method=init_method)\n        local_x = general_conv2d(local_x, 32, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                 is_training=is_training, name=\"CNN_ENC_1_2\", init_method=init_method)\n\n        local_x = general_conv2d(local_x, 64, filter_height=3, filter_width=3,\n                                 is_training=is_training, name=\"CNN_ENC_2\", init_method=init_method)\n        local_x = general_conv2d(local_x, 64, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                 is_training=is_training, name=\"CNN_ENC_2_2\", init_method=init_method)\n\n        local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3,\n                                 is_training=is_training, name=\"CNN_ENC_3\", init_method=init_method)\n        local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                 is_training=is_training, name=\"CNN_ENC_3_2\", init_method=init_method)\n        local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                 is_training=is_training, name=\"CNN_ENC_3_3\", init_method=init_method)\n\n        
local_x = general_conv2d(local_x, 256, filter_height=3, filter_width=3,\n                                 is_training=is_training, name=\"CNN_ENC_4\", init_method=init_method)\n        local_x = general_conv2d(local_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                 is_training=is_training, name=\"CNN_ENC_4_2\", init_method=init_method)\n        local_x = general_conv2d(local_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                 is_training=is_training, name=\"CNN_ENC_4_3\", init_method=init_method)\n\n        local_x = general_conv2d(local_x, 512, filter_height=3, filter_width=3,\n                                 is_training=is_training, name=\"CNN_ENC_5\", init_method=init_method)\n        local_x = general_conv2d(local_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                 is_training=is_training, name=\"CNN_ENC_5_2\", init_method=init_method)\n        local_x = general_conv2d(local_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                 is_training=is_training, name=\"CNN_ENC_5_3\", init_method=init_method)\n\n        local_x = tf.reshape(local_x, (-1, 512 * 4 * 4))\n        local_x = linear1d(local_x, 512 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method)\n\n        if is_training:\n            local_x = tf.nn.dropout(local_x, drop_keep_prob)\n\n    with tf.variable_scope('Global_Encoder', reuse=tf.AUTO_REUSE):\n        global_x = general_conv2d(global_x, 32, filter_height=3, filter_width=3,\n                                  is_training=is_training, name=\"CNN_ENC_1\", init_method=init_method)\n        global_x = general_conv2d(global_x, 32, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                  is_training=is_training, name=\"CNN_ENC_1_2\", init_method=init_method)\n\n        global_x = 
general_conv2d(global_x, 64, filter_height=3, filter_width=3,\n                                  is_training=is_training, name=\"CNN_ENC_2\", init_method=init_method)\n        global_x = general_conv2d(global_x, 64, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                  is_training=is_training, name=\"CNN_ENC_2_2\", init_method=init_method)\n\n        global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3,\n                                  is_training=is_training, name=\"CNN_ENC_3\", init_method=init_method)\n        global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                  is_training=is_training, name=\"CNN_ENC_3_2\", init_method=init_method)\n        global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                  is_training=is_training, name=\"CNN_ENC_3_3\", init_method=init_method)\n\n        global_x = general_conv2d(global_x, 256, filter_height=3, filter_width=3,\n                                  is_training=is_training, name=\"CNN_ENC_4\", init_method=init_method)\n        global_x = general_conv2d(global_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                  is_training=is_training, name=\"CNN_ENC_4_2\", init_method=init_method)\n        global_x = general_conv2d(global_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                  is_training=is_training, name=\"CNN_ENC_4_3\", init_method=init_method)\n\n        global_x = general_conv2d(global_x, 512, filter_height=3, filter_width=3,\n                                  is_training=is_training, name=\"CNN_ENC_5\", init_method=init_method)\n        global_x = general_conv2d(global_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                
  is_training=is_training, name=\"CNN_ENC_5_2\", init_method=init_method)\n        global_x = general_conv2d(global_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                  is_training=is_training, name=\"CNN_ENC_5_3\", init_method=init_method)\n\n        global_x = tf.reshape(global_x, (-1, 512 * 4 * 4))\n        global_x = linear1d(global_x, 512 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method)\n\n        if is_training:\n            global_x = tf.nn.dropout(global_x, drop_keep_prob)\n\n    tensor_x_sp = None\n    tensor_x = tf.concat([local_x, global_x], axis=-1)\n    return tensor_x, tensor_x_sp\n\n\ndef generative_cnn_c3_encoder_combineFC_jointAttn(local_inputs, global_inputs, is_training=True, drop_keep_prob=0.5,\n                                                  init_method=None, combine_manner='attn'):\n    local_x = local_inputs\n    global_x = global_inputs\n\n    with tf.variable_scope('Local_Encoder', reuse=tf.AUTO_REUSE):\n        local_x = general_conv2d(local_x, 32, filter_height=3, filter_width=3,\n                                 is_training=is_training, name=\"CNN_ENC_1\", init_method=init_method)\n        local_x = general_conv2d(local_x, 32, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                 is_training=is_training, name=\"CNN_ENC_1_2\", init_method=init_method)\n\n        local_x = general_conv2d(local_x, 64, filter_height=3, filter_width=3,\n                                 is_training=is_training, name=\"CNN_ENC_2\", init_method=init_method)\n        local_x = general_conv2d(local_x, 64, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                 is_training=is_training, name=\"CNN_ENC_2_2\", init_method=init_method)\n\n        local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3,\n                                 is_training=is_training, name=\"CNN_ENC_3\", 
init_method=init_method)\n        local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                 is_training=is_training, name=\"CNN_ENC_3_2\", init_method=init_method)\n        local_x = general_conv2d(local_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                 is_training=is_training, name=\"CNN_ENC_3_3\", init_method=init_method)\n\n        share_x = local_x\n\n        with tf.variable_scope('Attn_branch', reuse=tf.AUTO_REUSE):\n            attn_x = general_conv2d(share_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                    is_training=is_training, name=\"CNN_ENC_1\", init_method=init_method)\n            attn_x = general_conv2d(attn_x, 32, filter_height=1, filter_width=1, stride_height=1, stride_width=1,\n                                    is_training=is_training, name=\"CNN_ENC_2\", init_method=init_method)\n            attn_x = general_conv2d(attn_x, 1, filter_height=1, filter_width=1, stride_height=1, stride_width=1,\n                                    is_training=is_training, name=\"CNN_ENC_3\", init_method=init_method)\n            attn_map = tf.nn.sigmoid(attn_x)  # (N, H/8, W/8, 1), [0.0, 1.0]\n\n        if combine_manner == 'attn':\n            attn_feat = attn_map * share_x + share_x\n        else:\n            raise Exception('Unknown combine_manner', combine_manner)\n\n        local_x = general_conv2d(attn_feat, 256, filter_height=3, filter_width=3,\n                                 is_training=is_training, name=\"CNN_ENC_4\", init_method=init_method)\n        local_x = general_conv2d(local_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                 is_training=is_training, name=\"CNN_ENC_4_2\", init_method=init_method)\n        local_x = general_conv2d(local_x, 256, filter_height=3, filter_width=3, stride_height=1, 
stride_width=1,\n                                 is_training=is_training, name=\"CNN_ENC_4_3\", init_method=init_method)\n\n        local_x = general_conv2d(local_x, 512, filter_height=3, filter_width=3,\n                                 is_training=is_training, name=\"CNN_ENC_5\", init_method=init_method)\n        local_x = general_conv2d(local_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                 is_training=is_training, name=\"CNN_ENC_5_2\", init_method=init_method)\n        local_x = general_conv2d(local_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                 is_training=is_training, name=\"CNN_ENC_5_3\", init_method=init_method)\n\n        local_x = tf.reshape(local_x, (-1, 512 * 4 * 4))\n        local_x = linear1d(local_x, 512 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method)\n\n        if is_training:\n            local_x = tf.nn.dropout(local_x, drop_keep_prob)\n\n    with tf.variable_scope('Global_Encoder', reuse=tf.AUTO_REUSE):\n        global_x = general_conv2d(global_x, 32, filter_height=3, filter_width=3,\n                                  is_training=is_training, name=\"CNN_ENC_1\", init_method=init_method)\n        global_x = general_conv2d(global_x, 32, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                  is_training=is_training, name=\"CNN_ENC_1_2\", init_method=init_method)\n\n        global_x = general_conv2d(global_x, 64, filter_height=3, filter_width=3,\n                                  is_training=is_training, name=\"CNN_ENC_2\", init_method=init_method)\n        global_x = general_conv2d(global_x, 64, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                  is_training=is_training, name=\"CNN_ENC_2_2\", init_method=init_method)\n\n        global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3,\n                           
       is_training=is_training, name=\"CNN_ENC_3\", init_method=init_method)\n        global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                  is_training=is_training, name=\"CNN_ENC_3_2\", init_method=init_method)\n        global_x = general_conv2d(global_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                  is_training=is_training, name=\"CNN_ENC_3_3\", init_method=init_method)\n\n        global_x = general_conv2d(global_x, 256, filter_height=3, filter_width=3,\n                                  is_training=is_training, name=\"CNN_ENC_4\", init_method=init_method)\n        global_x = general_conv2d(global_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                  is_training=is_training, name=\"CNN_ENC_4_2\", init_method=init_method)\n        global_x = general_conv2d(global_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                  is_training=is_training, name=\"CNN_ENC_4_3\", init_method=init_method)\n\n        global_x = general_conv2d(global_x, 512, filter_height=3, filter_width=3,\n                                  is_training=is_training, name=\"CNN_ENC_5\", init_method=init_method)\n        global_x = general_conv2d(global_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                  is_training=is_training, name=\"CNN_ENC_5_2\", init_method=init_method)\n        global_x = general_conv2d(global_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                  is_training=is_training, name=\"CNN_ENC_5_3\", init_method=init_method)\n\n        global_x = tf.reshape(global_x, (-1, 512 * 4 * 4))\n        global_x = linear1d(global_x, 512 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method)\n\n        if is_training:\n            
def generative_cnn_c3_encoder_combineFC_sepAttn(local_inputs, global_inputs, is_training=True, drop_keep_prob=0.5,
                                                  init_method=None, combine_manner='attn'):
    """Late-fusion encoder with a *separate* attention network: a standalone
    tower computes a one-channel mask from the raw local input, which is then
    fused into the local tower after its third stage ('attn' residual
    reweighting or 'channel' concatenation). Each tower ends in its own
    128-d FC head; the two vectors are concatenated.

    Returns (tensor_x, tensor_x_sp, attn_map); tensor_x_sp is always None.
    """
    def _down(x, ch, name):
        # 3x3 conv without explicit strides (general_conv2d defaults).
        return general_conv2d(x, ch, filter_height=3, filter_width=3,
                              is_training=is_training, name=name,
                              init_method=init_method)

    def _same(x, ch, name, fh=3, fw=3):
        # Stride-1 conv, optionally 1x1.
        return general_conv2d(x, ch, filter_height=fh, filter_width=fw,
                              stride_height=1, stride_width=1,
                              is_training=is_training, name=name,
                              init_method=init_method)

    # Standalone attention tower over the raw local input.
    with tf.variable_scope('Attn_branch', reuse=tf.AUTO_REUSE):
        attn_x = _down(local_inputs, 32, "CNN_ENC_1")
        attn_x = _same(attn_x, 32, "CNN_ENC_1_2")
        attn_x = _down(attn_x, 64, "CNN_ENC_2")
        attn_x = _same(attn_x, 64, "CNN_ENC_2_2")
        attn_x = _down(attn_x, 128, "CNN_ENC_3")
        attn_x = _same(attn_x, 32, "CNN_ENC_3_2", fh=1, fw=1)
        attn_x = _same(attn_x, 1, "CNN_ENC_3_3", fh=1, fw=1)
        attn_map = tf.nn.sigmoid(attn_x)  # (N, H/8, W/8, 1), values in [0.0, 1.0]

    with tf.variable_scope('Local_Encoder', reuse=tf.AUTO_REUSE):
        local_x = _down(local_inputs, 32, "CNN_ENC_1")
        local_x = _same(local_x, 32, "CNN_ENC_1_2")
        local_x = _down(local_x, 64, "CNN_ENC_2")
        local_x = _same(local_x, 64, "CNN_ENC_2_2")
        local_x = _down(local_x, 128, "CNN_ENC_3")
        local_x = _same(local_x, 128, "CNN_ENC_3_2")
        local_x = _same(local_x, 128, "CNN_ENC_3_3")

        # Fuse the mask into the local features.
        if combine_manner == 'attn':
            attn_feat = attn_map * local_x + local_x
        elif combine_manner == 'channel':
            attn_feat = tf.concat([local_x, attn_map], axis=-1)
        else:
            raise Exception('Unknown combine_manner', combine_manner)

        local_x = _down(attn_feat, 256, "CNN_ENC_4")
        local_x = _same(local_x, 256, "CNN_ENC_4_2")
        local_x = _same(local_x, 256, "CNN_ENC_4_3")
        local_x = _down(local_x, 512, "CNN_ENC_5")
        local_x = _same(local_x, 512, "CNN_ENC_5_2")
        local_x = _same(local_x, 512, "CNN_ENC_5_3")

        local_x = tf.reshape(local_x, (-1, 512 * 4 * 4))
        local_x = linear1d(local_x, 512 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method)
        if is_training:
            local_x = tf.nn.dropout(local_x, drop_keep_prob)

    with tf.variable_scope('Global_Encoder', reuse=tf.AUTO_REUSE):
        global_x = _down(global_inputs, 32, "CNN_ENC_1")
        global_x = _same(global_x, 32, "CNN_ENC_1_2")
        global_x = _down(global_x, 64, "CNN_ENC_2")
        global_x = _same(global_x, 64, "CNN_ENC_2_2")
        global_x = _down(global_x, 128, "CNN_ENC_3")
        global_x = _same(global_x, 128, "CNN_ENC_3_2")
        global_x = _same(global_x, 128, "CNN_ENC_3_3")
        global_x = _down(global_x, 256, "CNN_ENC_4")
        global_x = _same(global_x, 256, "CNN_ENC_4_2")
        global_x = _same(global_x, 256, "CNN_ENC_4_3")
        global_x = _down(global_x, 512, "CNN_ENC_5")
        global_x = _same(global_x, 512, "CNN_ENC_5_2")
        global_x = _same(global_x, 512, "CNN_ENC_5_3")

        global_x = tf.reshape(global_x, (-1, 512 * 4 * 4))
        global_x = linear1d(global_x, 512 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method)
        if is_training:
            global_x = tf.nn.dropout(global_x, drop_keep_prob)

    tensor_x_sp = None
    tensor_x = tf.concat([local_x, global_x], axis=-1)
    return tensor_x, tensor_x_sp, attn_map
128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                  is_training=is_training, name=\"CNN_ENC_3_3\", init_method=init_method)\n\n        tensor_x = general_conv2d(tensor_x, 256, filter_height=3, filter_width=3,\n                                  is_training=is_training, name=\"CNN_ENC_4\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                  is_training=is_training, name=\"CNN_ENC_4_2\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                  is_training=is_training, name=\"CNN_ENC_4_3\", init_method=init_method)\n\n        tensor_x = general_conv2d(tensor_x, 512, filter_height=3, filter_width=3,\n                                  is_training=is_training, name=\"CNN_ENC_5\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                  is_training=is_training, name=\"CNN_ENC_5_2\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                  is_training=is_training, name=\"CNN_ENC_5_3\", init_method=init_method)\n        tensor_x_sp = tensor_x  # [N, h, w, 256]\n\n        tensor_x = tf.reshape(tensor_x, (-1, 512 * 4 * 4))\n        tensor_x = linear1d(tensor_x, 512 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method)\n\n        if is_training:\n            tensor_x = tf.nn.dropout(tensor_x, drop_keep_prob)\n\n        return tensor_x, tensor_x_sp\n\n\ndef generative_cnn_c3_encoder_deeper13_attn(inputs, is_training=True, drop_keep_prob=0.5, init_method=None):\n    tensor_x = inputs\n\n    with tf.variable_scope(tf.get_variable_scope(), 
reuse=tf.AUTO_REUSE) as scope:\n        tensor_x = general_conv2d(tensor_x, 32, filter_height=3, filter_width=3,\n                                  is_training=is_training, name=\"CNN_ENC_1\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 32, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                  is_training=is_training, name=\"CNN_ENC_1_2\", init_method=init_method)\n\n        tensor_x = general_conv2d(tensor_x, 64, filter_height=3, filter_width=3,\n                                  is_training=is_training, name=\"CNN_ENC_2\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 64, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                  is_training=is_training, name=\"CNN_ENC_2_2\", init_method=init_method)\n\n        tensor_x = self_attention(tensor_x, 64)\n\n        tensor_x = general_conv2d(tensor_x, 128, filter_height=3, filter_width=3,\n                                  is_training=is_training, name=\"CNN_ENC_3\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                  is_training=is_training, name=\"CNN_ENC_3_2\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 128, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                  is_training=is_training, name=\"CNN_ENC_3_3\", init_method=init_method)\n\n        tensor_x = general_conv2d(tensor_x, 256, filter_height=3, filter_width=3,\n                                  is_training=is_training, name=\"CNN_ENC_4\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                  is_training=is_training, name=\"CNN_ENC_4_2\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 
256, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                  is_training=is_training, name=\"CNN_ENC_4_3\", init_method=init_method)\n\n        tensor_x = general_conv2d(tensor_x, 512, filter_height=3, filter_width=3,\n                                  is_training=is_training, name=\"CNN_ENC_5\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                  is_training=is_training, name=\"CNN_ENC_5_2\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 512, filter_height=3, filter_width=3, stride_height=1, stride_width=1,\n                                  is_training=is_training, name=\"CNN_ENC_5_3\", init_method=init_method)\n        tensor_x_sp = tensor_x  # [N, h, w, 256]\n\n        tensor_x = tf.reshape(tensor_x, (-1, 512 * 4 * 4))\n        tensor_x = linear1d(tensor_x, 512 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method)\n\n        if is_training:\n            tensor_x = tf.nn.dropout(tensor_x, drop_keep_prob)\n\n        return tensor_x, tensor_x_sp\n\n\ndef generative_cnn_encoder(inputs, is_training=True, drop_keep_prob=0.5, init_method=None):\n    tensor_x = inputs\n\n    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE) as scope:\n        tensor_x = general_conv2d(tensor_x, 32, is_training=is_training, name=\"CNN_ENC_1\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 32, stride_height=1, stride_width=1, is_training=is_training,\n                                  name=\"CNN_ENC_1_2\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 64, is_training=is_training, name=\"CNN_ENC_2\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 64, stride_height=1, stride_width=1, is_training=is_training,\n                                  name=\"CNN_ENC_2_2\", init_method=init_method)\n\n        tensor_x 
= general_conv2d(tensor_x, 128, is_training=is_training, name=\"CNN_ENC_3\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 128, stride_height=1, stride_width=1, is_training=is_training,\n                                  name=\"CNN_ENC_3_2\", init_method=init_method)\n\n        tensor_x = general_conv2d(tensor_x, 256, is_training=is_training, name=\"CNN_ENC_4\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 256, stride_height=1, stride_width=1, is_training=is_training,\n                                  name=\"CNN_ENC_4_2\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 256, is_training=is_training, name=\"CNN_ENC_5\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 256, stride_height=1, stride_width=1, is_training=is_training,\n                                  name=\"CNN_ENC_5_2\", init_method=init_method)\n        tensor_x_sp = tensor_x  # [N, h, w, 256]\n\n        tensor_x = tf.reshape(tensor_x, (-1, 256 * 4 * 4))\n        tensor_x = linear1d(tensor_x, 256 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method)\n\n        if is_training:\n            tensor_x = tf.nn.dropout(tensor_x, drop_keep_prob)\n\n        return tensor_x, tensor_x_sp\n\n\ndef generative_cnn_encoder_deeper(inputs, is_training=True, drop_keep_prob=0.5, init_method=None):\n    tensor_x = inputs\n\n    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE) as scope:\n        tensor_x = general_conv2d(tensor_x, 32, is_training=is_training, name=\"CNN_ENC_1\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 32, stride_height=1, stride_width=1, is_training=is_training,\n                                  name=\"CNN_ENC_1_2\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 64, is_training=is_training, name=\"CNN_ENC_2\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 64, stride_height=1, stride_width=1, 
is_training=is_training,\n                                  name=\"CNN_ENC_2_2\", init_method=init_method)\n\n        tensor_x = general_conv2d(tensor_x, 128, is_training=is_training, name=\"CNN_ENC_3\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 128, stride_height=1, stride_width=1, is_training=is_training,\n                                  name=\"CNN_ENC_3_2\", init_method=init_method)\n\n        tensor_x = general_conv2d(tensor_x, 256, is_training=is_training, name=\"CNN_ENC_4\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 256, stride_height=1, stride_width=1, is_training=is_training,\n                                  name=\"CNN_ENC_4_2\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 512, is_training=is_training, name=\"CNN_ENC_5\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 512, stride_height=1, stride_width=1, is_training=is_training,\n                                  name=\"CNN_ENC_5_2\", init_method=init_method)\n        tensor_x_sp = tensor_x  # [N, h, w, 512]\n\n        tensor_x = tf.reshape(tensor_x, (-1, 512 * 4 * 4))\n        tensor_x = linear1d(tensor_x, 512 * 4 * 4, 512, name='CNN_ENC_FC', init_method=init_method)\n\n        if is_training:\n            tensor_x = tf.nn.dropout(tensor_x, drop_keep_prob)\n\n        return tensor_x, tensor_x_sp\n\n\ndef generative_cnn_encoder_deeper13(inputs, is_training=True, drop_keep_prob=0.5, init_method=None):\n    tensor_x = inputs\n\n    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE) as scope:\n        tensor_x = general_conv2d(tensor_x, 32, is_training=is_training,\n                                  name=\"CNN_ENC_1\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 32, stride_height=1, stride_width=1, is_training=is_training,\n                                  name=\"CNN_ENC_1_2\", init_method=init_method)\n\n        tensor_x = general_conv2d(tensor_x, 64, 
is_training=is_training,\n                                  name=\"CNN_ENC_2\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 64, stride_height=1, stride_width=1, is_training=is_training,\n                                  name=\"CNN_ENC_2_2\", init_method=init_method)\n\n        tensor_x = general_conv2d(tensor_x, 128, is_training=is_training,\n                                  name=\"CNN_ENC_3\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 128, stride_height=1, stride_width=1, is_training=is_training,\n                                  name=\"CNN_ENC_3_2\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 128, stride_height=1, stride_width=1, is_training=is_training,\n                                  name=\"CNN_ENC_3_3\", init_method=init_method)\n\n        tensor_x = general_conv2d(tensor_x, 256, is_training=is_training,\n                                  name=\"CNN_ENC_4\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 256, stride_height=1, stride_width=1, is_training=is_training,\n                                  name=\"CNN_ENC_4_2\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 256, stride_height=1, stride_width=1, is_training=is_training,\n                                  name=\"CNN_ENC_4_3\", init_method=init_method)\n\n        tensor_x = general_conv2d(tensor_x, 256, is_training=is_training,\n                                  name=\"CNN_ENC_5\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 256, stride_height=1, stride_width=1, is_training=is_training,\n                                  name=\"CNN_ENC_5_2\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 256, stride_height=1, stride_width=1, is_training=is_training,\n                                  name=\"CNN_ENC_5_3\", init_method=init_method)\n        tensor_x_sp = tensor_x  # [N, h, w, 256]\n\n        tensor_x = tf.reshape(tensor_x, 
(-1, 256 * 4 * 4))\n        tensor_x = linear1d(tensor_x, 256 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method)\n\n        if is_training:\n            tensor_x = tf.nn.dropout(tensor_x, drop_keep_prob)\n\n        return tensor_x, tensor_x_sp\n\n\ndef max_pooling(x) :\n    return tf.layers.max_pooling2d(x, pool_size=2, strides=2, padding='SAME')\n\n\ndef hw_flatten(x) :\n    return tf.reshape(x, shape=[x.shape[0], -1, x.shape[-1]])\n\n\ndef self_attention(x, in_channel, name='self_attention'):\n    with tf.variable_scope(name) as scope:\n        f = general_conv2d(x, in_channel // 8, filter_height=1, filter_width=1, stride_height=1, stride_width=1,\n                           do_norm=False, do_relu=False, name='f_conv')  # (N, h, w, c')\n        f = max_pooling(f)  # (N, h', w', c')\n        g = general_conv2d(x, in_channel // 8, filter_height=1, filter_width=1, stride_height=1, stride_width=1,\n                           do_norm=False, do_relu=False, name='g_conv')  # (N, h, w, c')\n        h = general_conv2d(x, in_channel, filter_height=1, filter_width=1, stride_height=1, stride_width=1,\n                           do_norm=False, do_relu=False, name='h_conv')  # (N, h, w, c)\n        h = max_pooling(h)  # (N, h', w', c)\n\n        # M = h * w, M' = h' * w'\n        s = tf.matmul(hw_flatten(g), hw_flatten(f), transpose_b=True)  # (N, M, M')\n        beta = tf.nn.softmax(s)  # attention map\n\n        o = tf.matmul(beta, hw_flatten(h))  # (N, M, c)\n        gamma = tf.get_variable(\"gamma\", [1], initializer=tf.constant_initializer(0.0))\n\n        o = tf.reshape(o, shape=x.shape)  # (N, h, w, c)\n        o = general_conv2d(o, in_channel, filter_height=1, filter_width=1, stride_height=1, stride_width=1,\n                           do_norm=False, do_relu=False, name='attn_conv')\n\n        x = gamma * o + x\n\n    return x\n\n\ndef global_avg_pooling(x):\n    gap = tf.reduce_mean(x, axis=[1, 2])\n    return gap\n\n\ndef 
cnn_discriminator_wgan_gp(discrim_inputs, discrim_targets, init_method=None):\n    tensor_x = tf.concat([discrim_inputs, discrim_targets], axis=3)  # (N, H, W, 3 + 1)\n\n    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE) as scope:\n        tensor_x = general_conv2d(tensor_x, 32, filter_height=3, filter_width=3,\n                                  is_training=True, name=\"CNN_ENC_1\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 64, filter_height=3, filter_width=3,\n                                  is_training=True, name=\"CNN_ENC_2\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 128, filter_height=3, filter_width=3,\n                                  is_training=True, name=\"CNN_ENC_3\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 128, filter_height=3, filter_width=3,\n                                  is_training=True, name=\"CNN_ENC_4\", init_method=init_method)\n        tensor_x = general_conv2d(tensor_x, 1, filter_height=3, filter_width=3,\n                                  is_training=True, name=\"CNN_ENC_5\", init_method=init_method)\n        # (N, H/32, W/32, 1)\n\n        d_out = global_avg_pooling(tensor_x)  # (N, 1)\n\n    return d_out\n"
  },
  {
    "path": "test_photograph_to_line.py",
    "content": "import numpy as np\nimport os\nimport tensorflow as tf\nfrom six.moves import range\nfrom PIL import Image\nimport argparse\n\nimport hyper_parameters as hparams\nfrom model_common_test import DiffPastingV3, VirtualSketchingModel\nfrom utils import reset_graph, load_checkpoint, update_hyperparams, draw, \\\n    save_seq_data, image_pasting_v3_testing, draw_strokes\nfrom dataset_utils import load_dataset_testing\n\nos.environ['CUDA_VISIBLE_DEVICES'] = '0'\n\n\ndef sample(sess, model, input_photos, init_cursor, image_size, init_len, seq_len, state_dependent,\n           pasting_func):\n    \"\"\"Samples a sequence from a pre-trained model.\"\"\"\n    select_times = 1\n    cursor_pos = np.squeeze(init_cursor, axis=0)  # (select_times, 1, 2)\n    curr_canvas = np.zeros(dtype=np.float32,\n                           shape=(select_times, image_size, image_size))  # [0.0-BG, 1.0-stroke]\n\n    initial_state = sess.run(model.initial_state)\n    prev_state = initial_state\n    prev_width = np.stack([model.hps.min_width for _ in range(select_times)], axis=0)\n    prev_scaling = np.ones((select_times), dtype=np.float32)  # (N)\n    prev_window_size = np.ones((select_times), dtype=np.float32) * model.hps.raster_size  # (N)\n\n    params_list = [[] for _ in range(select_times)]\n    state_raw_list = [[] for _ in range(select_times)]\n    state_soft_list = [[] for _ in range(select_times)]\n    window_size_list = [[] for _ in range(select_times)]\n\n    input_photos_tiles = np.tile(input_photos, (select_times, 1, 1, 1))\n\n    for i in range(seq_len):\n        if not state_dependent and i % init_len == 0:\n            prev_state = initial_state\n\n        curr_window_size = prev_scaling * prev_window_size  # (N)\n        curr_window_size = np.maximum(curr_window_size, model.hps.min_window_size)\n        curr_window_size = np.minimum(curr_window_size, image_size)\n\n        feed = {\n            model.initial_state: prev_state,\n            model.input_photo: 
input_photos_tiles,\n            model.curr_canvas_hard: curr_canvas.copy(),\n            model.cursor_position: cursor_pos,\n            model.image_size: image_size,\n            model.init_width: prev_width,\n            model.init_scaling: prev_scaling,\n            model.init_window_size: prev_window_size,\n        }\n\n        o_other_params_list, o_pen_list, o_pred_params_list, next_state_list = \\\n            sess.run([model.other_params, model.pen_ras, model.pred_params, model.final_state], feed_dict=feed)\n        # o_other_params: (N, 6), o_pen: (N, 2), pred_params: (N, 1, 7), next_state: (N, 1024)\n        # o_other_params: [tanh*2, sigmoid*2, tanh*2, sigmoid*2]\n\n        idx_eos_list = np.argmax(o_pen_list, axis=1)  # (N)\n\n        for output_i in range(idx_eos_list.shape[0]):\n            idx_eos = idx_eos_list[output_i]\n\n            eos = [0, 0]\n            eos[idx_eos] = 1\n\n            other_params = o_other_params_list[output_i].tolist()  # (6)\n            params_list[output_i].append([eos[1]] + other_params)\n            state_raw_list[output_i].append(o_pen_list[output_i][1])\n            state_soft_list[output_i].append(o_pred_params_list[output_i, 0, 0])\n            window_size_list[output_i].append(curr_window_size[output_i])\n\n            # draw the stroke and add to the canvas\n            x1y1, x2y2, width2 = o_other_params_list[output_i, 0:2], o_other_params_list[output_i, 2:4], \\\n                                 o_other_params_list[output_i, 4]\n            x0y0 = np.zeros_like(x2y2)  # (2), [-1.0, 1.0]\n            x0y0 = np.divide(np.add(x0y0, 1.0), 2.0)  # (2), [0.0, 1.0]\n            x2y2 = np.divide(np.add(x2y2, 1.0), 2.0)  # (2), [0.0, 1.0]\n            widths = np.stack([prev_width[output_i], width2], axis=0)  # (2)\n            o_other_params_proc = np.concatenate([x0y0, x1y1, x2y2, widths], axis=-1).tolist()  # (8)\n\n            if idx_eos == 0:\n                f = o_other_params_proc + [1.0, 1.0]\n                
pred_stroke_img = draw(f)  # (raster_size, raster_size), [0.0-stroke, 1.0-BG]\n                pred_stroke_img_large = image_pasting_v3_testing(1.0 - pred_stroke_img, cursor_pos[output_i, 0],\n                                                                  image_size,\n                                                                  curr_window_size[output_i],\n                                                                  pasting_func, sess)  # [0.0-BG, 1.0-stroke]\n                curr_canvas[output_i] += pred_stroke_img_large  # [0.0-BG, 1.0-stroke]\n        curr_canvas = np.clip(curr_canvas, 0.0, 1.0)\n\n        next_width = o_other_params_list[:, 4]  # (N)\n        next_scaling = o_other_params_list[:, 5]\n        next_window_size = next_scaling * curr_window_size  # (N)\n        next_window_size = np.maximum(next_window_size, model.hps.min_window_size)\n        next_window_size = np.minimum(next_window_size, image_size)\n\n        prev_state = next_state_list\n        prev_width = next_width * curr_window_size / next_window_size  # (N,)\n        prev_scaling = next_scaling  # (N)\n        prev_window_size = curr_window_size\n\n        # update cursor_pos based on hps.cursor_type\n        new_cursor_offsets = o_other_params_list[:, 2:4] * (np.expand_dims(curr_window_size, axis=-1) / 2.0)  # (N, 2), patch-level\n        new_cursor_offset_next = new_cursor_offsets\n\n        # important!!!\n        new_cursor_offset_next = np.concatenate([new_cursor_offset_next[:, 1:2], new_cursor_offset_next[:, 0:1]], axis=-1)\n\n        cursor_pos_large = cursor_pos * float(image_size)\n        stroke_position_next = cursor_pos_large[:, 0, :] + new_cursor_offset_next  # (N, 2), large-level\n\n        if model.hps.cursor_type == 'next':\n            cursor_pos_large = stroke_position_next  # (N, 2), large-level\n        else:\n            raise Exception('Unknown cursor_type')\n\n        cursor_pos_large = np.minimum(np.maximum(cursor_pos_large, 0.0), float(image_size - 
1))  # (N, 2), large-level\n        cursor_pos_large = np.expand_dims(cursor_pos_large, axis=1)  # (N, 1, 2)\n        cursor_pos = cursor_pos_large / float(image_size)\n\n    return params_list, state_raw_list, state_soft_list, curr_canvas, window_size_list\n\n\ndef main_testing(test_image_base_dir, test_dataset, test_image_name,\n                 sampling_base_dir, model_base_dir, model_name,\n                 sampling_num,\n                 draw_seq=False, draw_order=False,\n                 state_dependent=True, longer_infer_len=-1):\n    model_params_default = hparams.get_default_hparams_normal()\n    model_params = update_hyperparams(model_params_default, model_base_dir, model_name, infer_dataset=test_dataset)\n\n    [test_set, eval_hps_model, sample_hps_model] = \\\n        load_dataset_testing(test_image_base_dir, test_dataset, test_image_name, model_params)\n\n    test_image_raw_name = test_image_name[:test_image_name.find('.')]\n    model_dir = os.path.join(model_base_dir, model_name)\n\n    reset_graph()\n    sampling_model = VirtualSketchingModel(sample_hps_model)\n\n    # differentiable pasting graph\n    paste_v3_func = DiffPastingV3(sample_hps_model.raster_size)\n\n    tfconfig = tf.ConfigProto()\n    tfconfig.gpu_options.allow_growth = True\n    sess = tf.InteractiveSession(config=tfconfig)\n    sess.run(tf.global_variables_initializer())\n\n    # loads the weights from checkpoint into our model\n    snapshot_step = load_checkpoint(sess, model_dir, gen_model_pretrain=True)\n    print('snapshot_step', snapshot_step)\n    sampling_dir = os.path.join(sampling_base_dir, test_dataset + '__' + model_name)\n    os.makedirs(sampling_dir, exist_ok=True)\n\n    if longer_infer_len == -1:\n        tmp_max_len = eval_hps_model.max_seq_len\n    else:\n        tmp_max_len = longer_infer_len\n\n    for sampling_i in range(sampling_num):\n        input_photos, init_cursors, test_image_size = test_set.get_test_image()\n        # input_photos: (1, image_size, 
image_size, 3), [0-stroke, 1-BG]\n        # init_cursors: (N, 1, 2), in size [0.0, 1.0)\n\n        print()\n        print(test_image_name, ', image_size:', test_image_size, ', sampling_i:', sampling_i)\n        print('Processing ...')\n\n        if init_cursors.ndim == 3:\n            init_cursors = np.expand_dims(init_cursors, axis=0)\n\n        input_photos = input_photos[0:1, :, :, :]\n\n        ori_img = (input_photos.copy()[0] * 255.0).astype(np.uint8)\n        ori_img_png = Image.fromarray(ori_img, 'RGB')\n        ori_img_png.save(os.path.join(sampling_dir, test_image_raw_name + '_input.png'), 'PNG')\n\n        # decoding for sampling\n        strokes_raw_out_list, states_raw_out_list, states_soft_out_list, pred_imgs_out, window_size_out_list = sample(\n            sess, sampling_model, input_photos, init_cursors, test_image_size,\n            eval_hps_model.max_seq_len, tmp_max_len, state_dependent, paste_v3_func)\n        # pred_imgs_out: (N, H, W), [0.0-BG, 1.0-stroke]\n\n        output_i = 0\n        strokes_raw_out = np.stack(strokes_raw_out_list[output_i], axis=0)\n        states_raw_out = states_raw_out_list[output_i]\n        states_soft_out = states_soft_out_list[output_i]\n        window_size_out = window_size_out_list[output_i]\n\n        round_new_lengths = [tmp_max_len]\n        multi_cursors = [init_cursors[0, output_i, 0, :]]\n\n        print('strokes_raw_out', strokes_raw_out.shape)\n\n        clean_states_soft_out = np.array(states_soft_out)  # (N)\n\n        flag_list = strokes_raw_out[:, 0].astype(np.int32)  # (N)\n        drawing_len = len(flag_list) - np.sum(flag_list)\n        assert drawing_len >= 0\n\n        # print('    flag  raw\\t soft\\t x1\\t\\t y1\\t\\t x2\\t\\t y2\\t\\t r2\\t\\t s2')\n        for i in range(strokes_raw_out.shape[0]):\n            flag, x1, y1, x2, y2, r2, s2 = strokes_raw_out[i]\n            win_size = window_size_out[i]\n            out_format = '#%d: %d  | %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f'\n     
       out_values = (i, flag, states_raw_out[i], clean_states_soft_out[i], x1, y1, x2, y2, r2, s2)\n            out_log = out_format % out_values\n            # print(out_log)\n\n        print('Saving results ...')\n        save_seq_data(sampling_dir, test_image_raw_name + '_' + str(sampling_i),\n                      strokes_raw_out, init_cursors[0, output_i, 0, :],\n                      test_image_size, tmp_max_len, eval_hps_model.min_width)\n\n        draw_strokes(strokes_raw_out, sampling_dir, test_image_raw_name + '_' + str(sampling_i) + '_pred.png',\n                     ori_img, test_image_size,\n                     multi_cursors, round_new_lengths, eval_hps_model.min_width, eval_hps_model.cursor_type,\n                     sample_hps_model.raster_size, sample_hps_model.min_window_size,\n                     sess,\n                     pasting_func=paste_v3_func,\n                     save_seq=draw_seq, draw_order=draw_order)\n\n\ndef main(model_name, test_image_name, sampling_num):\n    test_dataset = 'faces'\n    test_image_base_dir = 'sample_inputs'\n\n    sampling_base_dir = 'outputs/sampling'\n    model_base_dir = 'outputs/snapshot'\n\n    state_dependent = False\n    longer_infer_len = 100\n\n    draw_seq = False\n    draw_color_order = True\n\n    # set numpy output to something sensible\n    np.set_printoptions(precision=8, edgeitems=6, linewidth=200, suppress=True)\n\n    main_testing(test_image_base_dir, test_dataset, test_image_name,\n                 sampling_base_dir, model_base_dir, model_name, sampling_num,\n                 draw_seq=draw_seq, draw_order=draw_color_order,\n                 state_dependent=state_dependent, longer_infer_len=longer_infer_len)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--input', '-i', type=str, default='', help=\"The test image name.\")\n    parser.add_argument('--model', '-m', type=str, default='pretrain_faces', help=\"The trained model.\")\n    
parser.add_argument('--sample', '-s', type=int, default=1, help=\"The number of outputs.\")\n    args = parser.parse_args()\n\n    assert args.input != ''\n    assert args.sample > 0\n\n    main(args.model, args.input, args.sample)\n"
  },
  {
    "path": "test_rough_sketch_simplification.py",
    "content": "import numpy as np\nimport os\nimport tensorflow as tf\nfrom six.moves import range\nfrom PIL import Image\nimport argparse\n\nimport hyper_parameters as hparams\nfrom model_common_test import DiffPastingV3, VirtualSketchingModel\nfrom utils import reset_graph, load_checkpoint, update_hyperparams, draw, \\\n    save_seq_data, image_pasting_v3_testing, draw_strokes\nfrom dataset_utils import load_dataset_testing\n\nos.environ['CUDA_VISIBLE_DEVICES'] = '0'\n\n\ndef move_cursor_to_undrawn(current_pos_list, input_image_, patch_size,\n                           move_min_dist, move_max_dist, trial_times=20):\n    \"\"\"\n    :param current_pos_list: (select_times, 1, 2), [0.0, 1.0)\n    :param input_image_: (1, image_size, image_size, 3), [0-stroke, 1-BG]\n    :return: new_cursor_pos: (select_times, 1, 2), [0.0, 1.0)\n    \"\"\"\n\n    def crop_patch(image, center, image_size, crop_size):\n        x0 = center[0] - crop_size // 2\n        x1 = x0 + crop_size\n        y0 = center[1] - crop_size // 2\n        y1 = y0 + crop_size\n        x0 = max(0, min(x0, image_size))\n        y0 = max(0, min(y0, image_size))\n        x1 = max(0, min(x1, image_size))\n        y1 = max(0, min(y1, image_size))\n        patch = image[y0:y1, x0:x1]\n        return patch\n\n    def isvalid_cursor(input_img, cursor, raster_size, image_size):\n        # input_img: (image_size, image_size, 3), [0.0-BG, 1.0-stroke]\n        cursor_large = cursor * float(image_size)\n        cursor_large = np.round(cursor_large).astype(np.int32)\n        input_crop_patch = crop_patch(input_img, cursor_large, image_size, raster_size)\n        if np.sum(input_crop_patch) > 0.0:\n            return True\n        else:\n            return False\n\n    def randomly_move_cursor(cursor_position, img_size, min_dist_p, max_dist_p):\n        # cursor_position: (2), [0.0, 1.0)\n        cursor_pos_large = cursor_position * img_size\n        min_dist = int(min_dist_p / 2.0 * img_size)\n        max_dist = 
int(max_dist_p / 2.0 * img_size)\n        rand_cursor_offset = np.random.randint(min_dist, max_dist, size=cursor_pos_large.shape)\n        rand_cursor_offset_sign = np.random.randint(0, 1 + 1, size=cursor_pos_large.shape)\n        rand_cursor_offset_sign[rand_cursor_offset_sign == 0] = -1\n        rand_cursor_offset = rand_cursor_offset * rand_cursor_offset_sign\n\n        new_cursor_pos_large = cursor_pos_large + rand_cursor_offset\n        new_cursor_pos_large = np.minimum(np.maximum(new_cursor_pos_large, 0), img_size - 1)  # (2), large-level\n        new_cursor_pos = new_cursor_pos_large.astype(np.float32) / float(img_size)\n        return new_cursor_pos\n\n    input_image = 1.0 - input_image_[0]  # (image_size, image_size, 3), [0-BG, 1-stroke]\n    img_size = input_image.shape[0]\n\n    new_cursor_pos = []\n    for cursor_i in range(current_pos_list.shape[0]):\n        curr_cursor = current_pos_list[cursor_i][0]\n\n        for trial_i in range(trial_times):\n            new_cursor = randomly_move_cursor(curr_cursor, img_size, move_min_dist, move_max_dist)  # (2), [0.0, 1.0)\n\n            if isvalid_cursor(input_image, new_cursor, patch_size, img_size) or trial_i == trial_times - 1:\n                new_cursor_pos.append(new_cursor)\n                break\n\n    assert len(new_cursor_pos) == current_pos_list.shape[0]\n    new_cursor_pos = np.expand_dims(np.stack(new_cursor_pos, axis=0), axis=1)  # (select_times, 1, 2), [0.0, 1.0)\n    return new_cursor_pos\n\n\ndef sample(sess, model, input_photos, init_cursor, image_size, init_len, seq_lens,\n           state_dependent, pasting_func, round_stop_state_num,\n           min_dist_p, max_dist_p):\n    \"\"\"Samples a sequence from a pre-trained model.\"\"\"\n    select_times = 1\n    curr_canvas = np.zeros(dtype=np.float32,\n                           shape=(select_times, image_size, image_size))  # [0.0-BG, 1.0-stroke]\n\n    initial_state = sess.run(model.initial_state)\n\n    params_list = [[] for _ in 
range(select_times)]\n    state_raw_list = [[] for _ in range(select_times)]\n    state_soft_list = [[] for _ in range(select_times)]\n    window_size_list = [[] for _ in range(select_times)]\n\n    round_cursor_list = []\n    round_length_real_list = []\n\n    input_photos_tiles = np.tile(input_photos, (select_times, 1, 1, 1))\n\n    for cursor_i, seq_len in enumerate(seq_lens):\n        if cursor_i == 0:\n            cursor_pos = np.squeeze(init_cursor, axis=0)  # (select_times, 1, 2)\n        else:\n            cursor_pos = move_cursor_to_undrawn(cursor_pos, input_photos, model.hps.raster_size,\n                                                min_dist_p, max_dist_p)  # (select_times, 1, 2)\n            round_cursor_list.append(cursor_pos)\n\n        prev_state = initial_state\n        prev_width = np.stack([model.hps.min_width for _ in range(select_times)], axis=0)\n        prev_scaling = np.ones((select_times), dtype=np.float32)  # (N)\n        prev_window_size = np.ones((select_times), dtype=np.float32) * model.hps.raster_size  # (N)\n\n        continuous_one_state_num = 0\n\n        for i in range(seq_len):\n            if not state_dependent and i % init_len == 0:\n                prev_state = initial_state\n\n            curr_window_size = prev_scaling * prev_window_size  # (N)\n            curr_window_size = np.maximum(curr_window_size, model.hps.min_window_size)\n            curr_window_size = np.minimum(curr_window_size, image_size)\n\n            feed = {\n                model.initial_state: prev_state,\n                model.input_photo: input_photos_tiles,\n                model.curr_canvas_hard: curr_canvas.copy(),\n                model.cursor_position: cursor_pos,\n                model.image_size: image_size,\n                model.init_width: prev_width,\n                model.init_scaling: prev_scaling,\n                model.init_window_size: prev_window_size,\n            }\n\n            o_other_params_list, o_pen_list, o_pred_params_list, 
next_state_list = \\\n                sess.run([model.other_params, model.pen_ras, model.pred_params, model.final_state], feed_dict=feed)\n            # o_other_params: (N, 6), o_pen: (N, 2), pred_params: (N, 1, 7), next_state: (N, 1024)\n            # o_other_params: [tanh*2, sigmoid*2, tanh*2, sigmoid*2]\n\n            idx_eos_list = np.argmax(o_pen_list, axis=1)  # (N)\n\n            output_i = 0\n            idx_eos = idx_eos_list[output_i]\n\n            eos = [0, 0]\n            eos[idx_eos] = 1\n\n            other_params = o_other_params_list[output_i].tolist()  # (6)\n            params_list[output_i].append([eos[1]] + other_params)\n            state_raw_list[output_i].append(o_pen_list[output_i][1])\n            state_soft_list[output_i].append(o_pred_params_list[output_i, 0, 0])\n            window_size_list[output_i].append(curr_window_size[output_i])\n\n            # draw the stroke and add to the canvas\n            x1y1, x2y2, width2 = o_other_params_list[output_i, 0:2], o_other_params_list[output_i, 2:4], \\\n                                 o_other_params_list[output_i, 4]\n            x0y0 = np.zeros_like(x2y2)  # (2), [-1.0, 1.0]\n            x0y0 = np.divide(np.add(x0y0, 1.0), 2.0)  # (2), [0.0, 1.0]\n            x2y2 = np.divide(np.add(x2y2, 1.0), 2.0)  # (2), [0.0, 1.0]\n            widths = np.stack([prev_width[output_i], width2], axis=0)  # (2)\n            o_other_params_proc = np.concatenate([x0y0, x1y1, x2y2, widths], axis=-1).tolist()  # (8)\n\n            if idx_eos == 0:\n                f = o_other_params_proc + [1.0, 1.0]\n                pred_stroke_img = draw(f)  # (raster_size, raster_size), [0.0-stroke, 1.0-BG]\n                pred_stroke_img_large = image_pasting_v3_testing(1.0 - pred_stroke_img,\n                                                                  cursor_pos[output_i, 0],\n                                                                  image_size,\n                                                               
   curr_window_size[output_i],\n                                                                  pasting_func, sess)  # [0.0-BG, 1.0-stroke]\n                curr_canvas[output_i] += pred_stroke_img_large  # [0.0-BG, 1.0-stroke]\n\n                continuous_one_state_num = 0\n            else:\n                continuous_one_state_num += 1\n\n            curr_canvas = np.clip(curr_canvas, 0.0, 1.0)\n\n            next_width = o_other_params_list[:, 4]  # (N)\n            next_scaling = o_other_params_list[:, 5]\n            next_window_size = next_scaling * curr_window_size  # (N)\n            next_window_size = np.maximum(next_window_size, model.hps.min_window_size)\n            next_window_size = np.minimum(next_window_size, image_size)\n\n            prev_state = next_state_list\n            prev_width = next_width * curr_window_size / next_window_size  # (N,)\n            prev_scaling = next_scaling  # (N)\n            prev_window_size = curr_window_size\n\n            # update cursor_pos based on hps.cursor_type\n            new_cursor_offsets = o_other_params_list[:, 2:4] * (\n                        np.expand_dims(curr_window_size, axis=-1) / 2.0)  # (N, 2), patch-level\n            new_cursor_offset_next = new_cursor_offsets\n\n            # important!!!\n            new_cursor_offset_next = np.concatenate([new_cursor_offset_next[:, 1:2], new_cursor_offset_next[:, 0:1]],\n                                                    axis=-1)\n\n            cursor_pos_large = cursor_pos * float(image_size)\n            stroke_position_next = cursor_pos_large[:, 0, :] + new_cursor_offset_next  # (N, 2), large-level\n\n            if model.hps.cursor_type == 'next':\n                cursor_pos_large = stroke_position_next  # (N, 2), large-level\n            else:\n                raise Exception('Unknown cursor_type')\n\n            cursor_pos_large = np.minimum(np.maximum(cursor_pos_large, 0.0),\n                                          float(image_size - 1))  # (N, 
2), large-level\n            cursor_pos_large = np.expand_dims(cursor_pos_large, axis=1)  # (N, 1, 2)\n            cursor_pos = cursor_pos_large / float(image_size)\n\n            if continuous_one_state_num >= round_stop_state_num or i == seq_len - 1:\n                round_length_real_list.append(i + 1)\n                break\n\n    return params_list, state_raw_list, state_soft_list, curr_canvas, window_size_list, \\\n           round_cursor_list, round_length_real_list\n\n\ndef main_testing(test_image_base_dir, test_dataset, test_image_name,\n                 sampling_base_dir, model_base_dir, model_name,\n                 sampling_num,\n                 min_dist_p, max_dist_p,\n                 longer_infer_lens, round_stop_state_num,\n                 draw_seq=False, draw_order=False,\n                 state_dependent=True):\n    model_params_default = hparams.get_default_hparams_rough()\n    model_params = update_hyperparams(model_params_default, model_base_dir, model_name, infer_dataset=test_dataset)\n\n    [test_set, eval_hps_model, sample_hps_model] = \\\n        load_dataset_testing(test_image_base_dir, test_dataset, test_image_name, model_params)\n\n    test_image_raw_name = test_image_name[:test_image_name.find('.')]\n    model_dir = os.path.join(model_base_dir, model_name)\n\n    reset_graph()\n    sampling_model = VirtualSketchingModel(sample_hps_model)\n\n    # differentiable pasting graph\n    paste_v3_func = DiffPastingV3(sample_hps_model.raster_size)\n\n    tfconfig = tf.ConfigProto()\n    tfconfig.gpu_options.allow_growth = True\n    sess = tf.InteractiveSession(config=tfconfig)\n    sess.run(tf.global_variables_initializer())\n\n    # loads the weights from checkpoint into our model\n    snapshot_step = load_checkpoint(sess, model_dir, gen_model_pretrain=True)\n    print('snapshot_step', snapshot_step)\n    sampling_dir = os.path.join(sampling_base_dir, test_dataset + '__' + model_name)\n    os.makedirs(sampling_dir, exist_ok=True)\n\n    for 
sampling_i in range(sampling_num):\n        input_photos, init_cursors, test_image_size = test_set.get_test_image()\n        # input_photos: (1, image_size, image_size, 3), [0-stroke, 1-BG]\n        # init_cursors: (N, 1, 2), in size [0.0, 1.0)\n\n        print()\n        print(test_image_name, ', image_size:', test_image_size, ', sampling_i:', sampling_i)\n        print('Processing ...')\n\n        if init_cursors.ndim == 3:\n            init_cursors = np.expand_dims(init_cursors, axis=0)\n\n        input_photos = input_photos[0:1, :, :, :]\n\n        ori_img = (input_photos.copy()[0] * 255.0).astype(np.uint8)\n        ori_img_png = Image.fromarray(ori_img, 'RGB')\n        ori_img_png.save(os.path.join(sampling_dir, test_image_raw_name + '_input.png'), 'PNG')\n\n        # decoding for sampling\n        strokes_raw_out_list, states_raw_out_list, states_soft_out_list, pred_imgs_out, \\\n        window_size_out_list, round_new_cursors, round_new_lengths = sample(\n            sess, sampling_model, input_photos, init_cursors, test_image_size,\n            eval_hps_model.max_seq_len, longer_infer_lens, state_dependent, paste_v3_func,\n            round_stop_state_num, min_dist_p, max_dist_p)\n        # pred_imgs_out: (N, H, W), [0.0-BG, 1.0-stroke]\n\n        print('## round_lengths:', len(round_new_lengths), ':', round_new_lengths)\n\n        output_i = 0\n        strokes_raw_out = np.stack(strokes_raw_out_list[output_i], axis=0)\n        states_raw_out = states_raw_out_list[output_i]\n        states_soft_out = states_soft_out_list[output_i]\n        window_size_out = window_size_out_list[output_i]\n\n        multi_cursors = [init_cursors[0, output_i, 0]]\n        for c_i in range(len(round_new_cursors)):\n            best_cursor = round_new_cursors[c_i][output_i, 0]  # (2)\n            multi_cursors.append(best_cursor)\n        assert len(multi_cursors) == len(round_new_lengths)\n\n        print('strokes_raw_out', strokes_raw_out.shape)\n\n        
clean_states_soft_out = np.array(states_soft_out)  # (N)\n\n        flag_list = strokes_raw_out[:, 0].astype(np.int32)  # (N)\n        drawing_len = len(flag_list) - np.sum(flag_list)\n        assert drawing_len >= 0\n\n        # print('    flag  raw\\t soft\\t x1\\t\\t y1\\t\\t x2\\t\\t y2\\t\\t r2\\t\\t s2')\n        for i in range(strokes_raw_out.shape[0]):\n            flag, x1, y1, x2, y2, r2, s2 = strokes_raw_out[i]\n            win_size = window_size_out[i]\n            out_format = '#%d: %d  | %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f'\n            out_values = (i, flag, states_raw_out[i], clean_states_soft_out[i], x1, y1, x2, y2, r2, s2)\n            out_log = out_format % out_values\n            # print(out_log)\n\n        print('Saving results ...')\n        save_seq_data(sampling_dir, test_image_raw_name + '_' + str(sampling_i),\n                      strokes_raw_out, multi_cursors,\n                      test_image_size, round_new_lengths, eval_hps_model.min_width)\n\n        draw_strokes(strokes_raw_out, sampling_dir, test_image_raw_name + '_' + str(sampling_i) + '_pred.png',\n                     ori_img, test_image_size,\n                     multi_cursors, round_new_lengths, eval_hps_model.min_width, eval_hps_model.cursor_type,\n                     sample_hps_model.raster_size, sample_hps_model.min_window_size,\n                     sess,\n                     pasting_func=paste_v3_func,\n                     save_seq=draw_seq, draw_order=draw_order)\n\n\ndef main(model_name, test_image_name, sampling_num):\n    test_dataset = 'rough_sketches'\n    test_image_base_dir = 'sample_inputs'\n\n    sampling_base_dir = 'outputs/sampling'\n    model_base_dir = 'outputs/snapshot'\n\n    state_dependent = False\n    longer_infer_lens = [128 for _ in range(10)]\n    round_stop_state_num = 12\n    min_dist_p = 0.3\n    max_dist_p = 0.9\n\n    draw_seq = False\n    draw_color_order = True\n\n    # set numpy output to something sensible\n    
np.set_printoptions(precision=8, edgeitems=6, linewidth=200, suppress=True)\n\n    main_testing(test_image_base_dir, test_dataset, test_image_name,\n                 sampling_base_dir, model_base_dir, model_name, sampling_num,\n                 min_dist_p=min_dist_p, max_dist_p=max_dist_p,\n                 draw_seq=draw_seq, draw_order=draw_color_order,\n                 state_dependent=state_dependent, longer_infer_lens=longer_infer_lens,\n                 round_stop_state_num=round_stop_state_num)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--input', '-i', type=str, default='', help=\"The test image name.\")\n    parser.add_argument('--model', '-m', type=str, default='pretrain_rough_sketches', help=\"The trained model.\")\n    parser.add_argument('--sample', '-s', type=int, default=1, help=\"The number of outputs.\")\n    args = parser.parse_args()\n\n    assert args.input != ''\n    assert args.sample > 0\n\n    main(args.model, args.input, args.sample)\n"
  },
  {
    "path": "test_vectorization.py",
    "content": "import numpy as np\nimport random\nimport os\nimport tensorflow as tf\nfrom six.moves import range\nfrom PIL import Image\nimport time\nimport argparse\n\nimport hyper_parameters as hparams\nfrom model_common_test import DiffPastingV3, VirtualSketchingModel\nfrom utils import reset_graph, load_checkpoint, update_hyperparams, draw, \\\n    save_seq_data, image_pasting_v3_testing, draw_strokes\nfrom dataset_utils import load_dataset_testing\n\nos.environ['CUDA_VISIBLE_DEVICES'] = '0'\n\n\ndef move_cursor_to_undrawn(current_canvas_list, input_image_, last_min_acc_list, grid_patch_size=128,\n                           stroke_acc_threshold=0.95, stroke_num_threshold=5, continuous_min_acc_threshold=2):\n    \"\"\"\n    :param current_canvas_list: (select_times, image_size, image_size), [0.0-BG, 1.0-stroke]\n    :param input_image_: (1, image_size, image_size), [0-stroke, 1-BG]\n    :return: new_cursor_pos: (select_times, 1, 2), [0.0, 1.0)\n    \"\"\"\n    def split_images(in_img, image_size, grid_size):\n        if image_size % grid_size == 0:\n            paddings_ = 0\n        else:\n            paddings_ = grid_size - image_size % grid_size\n        paddings = [[0, paddings_],\n                    [0, paddings_]]\n        image_pad = np.pad(in_img, paddings, mode='constant', constant_values=0.0)  # (H_p, W_p), [0.0-BG, 1.0-stroke]\n\n        assert image_pad.shape[0] % grid_size == 0\n\n        split_num = image_pad.shape[0] // grid_size\n\n        images_h = np.hsplit(image_pad, split_num)\n        image_patches = []\n        for image_h in images_h:\n            images_v = np.vsplit(image_h, split_num)\n            image_patches += images_v\n        image_patches = np.array(image_patches, dtype=np.float32)\n        return image_patches, split_num\n\n    def line_drawing_rounding(line_drawing):\n        line_drawing_r = np.copy(line_drawing)  # [0.0-BG, 1.0-stroke]\n        line_drawing_r[line_drawing_r != 0.0] = 1.0\n        return line_drawing_r\n\n 
   def cal_undrawn_pixels(in_canvas, in_sketch):\n        in_canvas_round = line_drawing_rounding(in_canvas).astype(np.int32)  # (N, H, W), [0.0-BG, 1.0-stroke]\n        in_sketch_round = line_drawing_rounding(in_sketch).astype(np.int32)\n\n        intersection = np.bitwise_and(in_canvas_round, in_sketch_round)\n\n        intersection_sum = np.sum(intersection, axis=(1, 2))\n        gt_sum = np.sum(in_sketch_round, axis=(1, 2))  # (N)\n\n        undrawn_num = gt_sum - intersection_sum\n        return undrawn_num\n\n    def cal_stroke_acc(in_canvas, in_sketch):\n        in_canvas_round = line_drawing_rounding(in_canvas).astype(np.int32)  # (N, H, W), [0.0-BG, 1.0-stroke]\n        in_sketch_round = line_drawing_rounding(in_sketch).astype(np.int32)\n\n        intersection = np.bitwise_and(in_canvas_round, in_sketch_round)\n\n        intersection_sum = np.sum(intersection, axis=(1, 2)).astype(np.float32)\n        gt_sum = np.sum(in_sketch_round, axis=(1, 2)).astype(np.float32)  # (N)\n        undrawn_num = gt_sum - intersection_sum  # (N)\n\n        stroke_acc = intersection_sum / gt_sum  # (N)\n        stroke_acc[gt_sum == 0.0] = 1.0\n        stroke_acc[undrawn_num <= stroke_num_threshold] = 1.0\n        return stroke_acc\n\n    def get_cursor(patch_idx, img_size, grid_size, split_num):\n        y_pos = patch_idx % split_num\n        x_pos = patch_idx // split_num\n\n        y_top = y_pos * grid_size + grid_size // 4\n        y_bottom = y_top + grid_size // 2\n        x_left = x_pos * grid_size + grid_size // 4\n        x_right = x_left + grid_size // 2\n\n        cursor_y = random.randint(y_top, y_bottom)\n        cursor_x = random.randint(x_left, x_right)\n\n        cursor_y = max(0, min(cursor_y, img_size - 1))\n        cursor_x = max(0, min(cursor_x, img_size - 1))  # (2), in large size\n        center = np.array([cursor_x, cursor_y], dtype=np.float32)\n\n        return center / float(img_size)  # (2), in size [0.0, 1.0)\n\n    input_image = 1.0 - input_image_[0]  
# (image_size, image_size), [0-BG, 1-stroke]\n    img_size = input_image.shape[0]\n\n    input_image_patches, split_number = split_images(input_image, img_size, grid_patch_size)  # (N, grid_size, grid_size)\n\n    new_cursor_pos = []\n    last_min_acc_list_new = [item for item in last_min_acc_list]\n    for canvas_i in range(current_canvas_list.shape[0]):\n        curr_canvas = current_canvas_list[canvas_i]  # (image_size, image_size), [0.0-BG, 1.0-stroke]\n\n        curr_canvas_patches, _ = split_images(curr_canvas, img_size, grid_patch_size)  # (N, grid_size, grid_size)\n\n        # 1. detect ending flag by stroke accuracy\n        stroke_accuracy = cal_stroke_acc(curr_canvas_patches, input_image_patches)\n        min_acc_idx = np.argmin(stroke_accuracy)\n        min_acc= stroke_accuracy[min_acc_idx]\n        # print('min_acc_idx', min_acc_idx, ' | ', 'min_acc', min_acc)\n\n        if min_acc >= stroke_acc_threshold:  # end of drawing\n            return None, None\n\n        # 2. detect undrawn pixels\n        undrawn_pixel_num = cal_undrawn_pixels(curr_canvas_patches, input_image_patches)\n        # undrawn_pixel_num_dis = np.reshape(undrawn_pixel_num, (split_number, split_number)).T\n        # print('undrawn_pixel_num_dis')\n        # print(undrawn_pixel_num_dis)\n\n        max_undrawn_idx = np.argmax(undrawn_pixel_num)\n        # max_undrawn = undrawn_pixel_num[max_undrawn_idx]\n        # print('max_undrawn_idx', max_undrawn_idx, ' | ', 'max_undrawn', max_undrawn)\n\n        # 3. 
select a random position\n        last_min_acc_idx, last_min_acc_times = last_min_acc_list[canvas_i]\n        if last_min_acc_times >= continuous_min_acc_threshold and last_min_acc_idx == min_acc_idx:\n            selected_patch_idx = last_min_acc_idx\n            new_min_acc_times = 1\n        else:\n            selected_patch_idx = max_undrawn_idx\n\n            if min_acc_idx == last_min_acc_idx:\n                new_min_acc_times = last_min_acc_times + 1\n            else:\n                new_min_acc_times = 1\n\n        new_min_acc_idx = min_acc_idx\n        last_min_acc_list_new[canvas_i] = (new_min_acc_idx, new_min_acc_times)\n        # print('selected_patch_idx', selected_patch_idx)\n\n        # 4. get cursor according to the selected_patch_idx\n        rand_cursor = get_cursor(selected_patch_idx, img_size, grid_patch_size, split_number)  # (2), in size [0.0, 1.0)\n        new_cursor_pos.append(rand_cursor)\n\n    assert len(new_cursor_pos) == current_canvas_list.shape[0]\n    new_cursor_pos = np.expand_dims(np.stack(new_cursor_pos, axis=0), axis=1)  # (select_times, 1, 2), [0.0, 1.0)\n    return new_cursor_pos, last_min_acc_list_new\n\n\ndef sample(sess, model, input_photos, init_cursor, image_size, init_len, seq_lens, state_dependent,\n           pasting_func, round_stop_state_num, stroke_acc_threshold):\n    \"\"\"Samples a sequence from a pre-trained model.\"\"\"\n    select_times = 1\n    curr_canvas = np.zeros(dtype=np.float32,\n                           shape=(select_times, image_size, image_size))  # [0.0-BG, 1.0-stroke]\n\n    initial_state = sess.run(model.initial_state)\n    prev_width = np.stack([model.hps.min_width for _ in range(select_times)], axis=0)\n\n    params_list = [[] for _ in range(select_times)]\n    state_raw_list = [[] for _ in range(select_times)]\n    state_soft_list = [[] for _ in range(select_times)]\n    window_size_list = [[] for _ in range(select_times)]\n\n    last_min_stroke_acc_list = [(-1, 0) for _ in 
range(select_times)]\n\n    round_cursor_list = []\n    round_length_real_list = []\n\n    input_photos_tiles = np.tile(input_photos, (select_times, 1, 1))\n\n    for cursor_i, seq_len in enumerate(seq_lens):\n        # print('\\n')\n        # print('@@ Round', cursor_i + 1)\n        if cursor_i == 0:\n            cursor_pos = np.squeeze(init_cursor, axis=0)  # (select_times, 1, 2)\n        else:\n            cursor_pos, last_min_stroke_acc_list_updated = \\\n                move_cursor_to_undrawn(curr_canvas, input_photos, last_min_stroke_acc_list,\n                                       grid_patch_size=model.hps.raster_size,\n                                       stroke_acc_threshold=stroke_acc_threshold)  # (select_times, 1, 2)\n            if cursor_pos is not None:\n                round_cursor_list.append(cursor_pos)\n                last_min_stroke_acc_list = last_min_stroke_acc_list_updated\n            else:\n                break\n\n        prev_state = initial_state\n        if not model.hps.init_cursor_on_undrawn_pixel:\n            prev_width = np.stack([model.hps.min_width for _ in range(select_times)], axis=0)\n        prev_scaling = np.ones((select_times), dtype=np.float32)  # (N)\n        prev_window_size = np.ones((select_times), dtype=np.float32) * model.hps.raster_size  # (N)\n\n        continuous_one_state_num = 0\n\n        for i in range(seq_len):\n            if not state_dependent and i % init_len == 0:\n                prev_state = initial_state\n\n            curr_window_size = prev_scaling * prev_window_size  # (N)\n            curr_window_size = np.maximum(curr_window_size, model.hps.min_window_size)\n            curr_window_size = np.minimum(curr_window_size, image_size)\n\n            feed = {\n                model.initial_state: prev_state,\n                model.input_photo: np.expand_dims(input_photos_tiles, axis=-1),\n                model.curr_canvas_hard: curr_canvas.copy(),\n                model.cursor_position: 
cursor_pos,\n                model.image_size: image_size,\n                model.init_width: prev_width,\n                model.init_scaling: prev_scaling,\n                model.init_window_size: prev_window_size,\n            }\n\n            o_other_params_list, o_pen_list, o_pred_params_list, next_state_list = \\\n                sess.run([model.other_params, model.pen_ras, model.pred_params, model.final_state], feed_dict=feed)\n            # o_other_params: (N, 6), o_pen: (N, 2), pred_params: (N, 1, 7), next_state: (N, 1024)\n            # o_other_params: [tanh*2, sigmoid*2, tanh*2, sigmoid*2]\n\n            idx_eos_list = np.argmax(o_pen_list, axis=1)  # (N)\n\n            output_i = 0\n            idx_eos = idx_eos_list[output_i]\n\n            eos = [0, 0]\n            eos[idx_eos] = 1\n\n            other_params = o_other_params_list[output_i].tolist()  # (6)\n            params_list[output_i].append([eos[1]] + other_params)\n            state_raw_list[output_i].append(o_pen_list[output_i][1])\n            state_soft_list[output_i].append(o_pred_params_list[output_i, 0, 0])\n            window_size_list[output_i].append(curr_window_size[output_i])\n\n            # draw the stroke and add to the canvas\n            x1y1, x2y2, width2 = o_other_params_list[output_i, 0:2], o_other_params_list[output_i, 2:4], \\\n                                 o_other_params_list[output_i, 4]\n            x0y0 = np.zeros_like(x2y2)  # (2), [-1.0, 1.0]\n            x0y0 = np.divide(np.add(x0y0, 1.0), 2.0)  # (2), [0.0, 1.0]\n            x2y2 = np.divide(np.add(x2y2, 1.0), 2.0)  # (2), [0.0, 1.0]\n            widths = np.stack([prev_width[output_i], width2], axis=0)  # (2)\n            o_other_params_proc = np.concatenate([x0y0, x1y1, x2y2, widths], axis=-1).tolist()  # (8)\n\n            if idx_eos == 0:\n                f = o_other_params_proc + [1.0, 1.0]\n                pred_stroke_img = draw(f)  # (raster_size, raster_size), [0.0-stroke, 1.0-BG]\n                
pred_stroke_img_large = image_pasting_v3_testing(1.0 - pred_stroke_img, cursor_pos[output_i, 0],\n                                                                  image_size,\n                                                                  curr_window_size[output_i],\n                                                                  pasting_func, sess)  # [0.0-BG, 1.0-stroke]\n                curr_canvas[output_i] += pred_stroke_img_large  # [0.0-BG, 1.0-stroke]\n\n                continuous_one_state_num = 0\n            else:\n                continuous_one_state_num += 1\n\n            curr_canvas = np.clip(curr_canvas, 0.0, 1.0)\n\n            next_width = o_other_params_list[:, 4]  # (N)\n            next_scaling = o_other_params_list[:, 5]\n            next_window_size = next_scaling * curr_window_size  # (N)\n            next_window_size = np.maximum(next_window_size, model.hps.min_window_size)\n            next_window_size = np.minimum(next_window_size, image_size)\n\n            prev_state = next_state_list\n            prev_width = next_width * curr_window_size / next_window_size  # (N,)\n            prev_scaling = next_scaling  # (N)\n            prev_window_size = curr_window_size\n\n            # update cursor_pos based on hps.cursor_type\n            new_cursor_offsets = o_other_params_list[:, 2:4] * (np.expand_dims(curr_window_size, axis=-1) / 2.0)  # (N, 2), patch-level\n            new_cursor_offset_next = new_cursor_offsets\n\n            # important!!!\n            new_cursor_offset_next = np.concatenate([new_cursor_offset_next[:, 1:2], new_cursor_offset_next[:, 0:1]], axis=-1)\n\n            cursor_pos_large = cursor_pos * float(image_size)\n            stroke_position_next = cursor_pos_large[:, 0, :] + new_cursor_offset_next  # (N, 2), large-level\n\n            if model.hps.cursor_type == 'next':\n                cursor_pos_large = stroke_position_next  # (N, 2), large-level\n            else:\n                raise Exception('Unknown 
cursor_type')\n\n            cursor_pos_large = np.minimum(np.maximum(cursor_pos_large, 0.0), float(image_size - 1))  # (N, 2), large-level\n            cursor_pos_large = np.expand_dims(cursor_pos_large, axis=1)  # (N, 1, 2)\n            cursor_pos = cursor_pos_large / float(image_size)\n\n            if continuous_one_state_num >= round_stop_state_num or i == seq_len - 1:\n                round_length_real_list.append(i + 1)\n                break\n\n    return params_list, state_raw_list, state_soft_list, curr_canvas, window_size_list, \\\n           round_cursor_list, round_length_real_list\n\n\ndef main_testing(test_image_base_dir, test_dataset, test_image_name,\n                 sampling_base_dir, model_base_dir, model_name,\n                 sampling_num,\n                 longer_infer_lens,\n                 round_stop_state_num, stroke_acc_threshold,\n                 draw_seq=False, draw_order=False,\n                 state_dependent=True):\n    model_params_default = hparams.get_default_hparams_clean()\n    model_params = update_hyperparams(model_params_default, model_base_dir, model_name, infer_dataset=test_dataset)\n\n    [test_set, eval_hps_model, sample_hps_model] \\\n        = load_dataset_testing(test_image_base_dir, test_dataset, test_image_name, model_params)\n\n    test_image_raw_name = test_image_name[:test_image_name.find('.')]\n    model_dir = os.path.join(model_base_dir, model_name)\n\n    reset_graph()\n    sampling_model = VirtualSketchingModel(sample_hps_model)\n\n    # differentiable pasting graph\n    paste_v3_func = DiffPastingV3(sample_hps_model.raster_size)\n\n    tfconfig = tf.ConfigProto()\n    tfconfig.gpu_options.allow_growth = True\n    sess = tf.InteractiveSession(config=tfconfig)\n    sess.run(tf.global_variables_initializer())\n\n    # loads the weights from checkpoint into our model\n    snapshot_step = load_checkpoint(sess, model_dir, gen_model_pretrain=True)\n    print('snapshot_step', snapshot_step)\n    sampling_dir = 
os.path.join(sampling_base_dir, test_dataset + '__' + model_name)\n    os.makedirs(sampling_dir, exist_ok=True)\n\n    stroke_number_list = []\n    compute_time_list = []\n\n    for sampling_i in range(sampling_num):\n        start_time_point = time.time()\n        input_photos, init_cursors, test_image_size = test_set.get_test_image()\n        # input_photos: (1, image_size, image_size), [0-stroke, 1-BG]\n        # init_cursors: (1, 1, 2), in size [0.0, 1.0)\n\n        print()\n        print(test_image_name, ', image_size:', test_image_size, ', sampling_i:', sampling_i)\n        print('Processing ...')\n\n        if init_cursors.ndim == 3:\n            init_cursors = np.expand_dims(init_cursors, axis=0)\n\n        input_photos = input_photos[0:1, :, :]\n        ori_img = (input_photos.copy()[0] * 255.0).astype(np.uint8)\n        ori_img = np.stack([ori_img for _ in range(3)], axis=2)\n        ori_img_png = Image.fromarray(ori_img, 'RGB')\n        ori_img_png.save(os.path.join(sampling_dir, test_image_raw_name + '_input.png'), 'PNG')\n\n        data_loading_time_point = time.time()\n\n        # decoding for sampling\n        strokes_raw_out_list, states_raw_out_list, states_soft_out_list, pred_imgs_out, \\\n        window_size_out_list, round_new_cursors, round_new_lengths = sample(\n            sess, sampling_model, input_photos, init_cursors, test_image_size,\n            eval_hps_model.max_seq_len, longer_infer_lens, state_dependent,\n            paste_v3_func, round_stop_state_num, stroke_acc_threshold)\n        # pred_imgs_out: [0.0-BG, 1.0-stroke]\n\n        print('## round_lengths:', len(round_new_lengths), ':', round_new_lengths)\n\n        sampling_time_point = time.time()\n\n        data_loading_time = data_loading_time_point - start_time_point\n        sampling_time_total = sampling_time_point - start_time_point\n        sampling_time_wo_data_loading = sampling_time_point - data_loading_time_point\n        compute_time_list.append(sampling_time_total)\n  
      # print('  >>> data_loading_time', data_loading_time)\n        print('  >>> sampling_time_total', sampling_time_total)\n        # print('  >>> sampling_time_wo_data_loading', sampling_time_wo_data_loading)\n\n        best_result_idx = 0\n        strokes_raw_out = np.stack(strokes_raw_out_list[best_result_idx], axis=0)\n        states_raw_out = states_raw_out_list[best_result_idx]\n        states_soft_out = states_soft_out_list[best_result_idx]\n        window_size_out = window_size_out_list[best_result_idx]\n\n        multi_cursors = [init_cursors[0, best_result_idx, 0]]\n        for c_i in range(len(round_new_cursors)):\n            best_cursor = round_new_cursors[c_i][best_result_idx, 0]  # (2)\n            multi_cursors.append(best_cursor)\n        assert len(multi_cursors) == len(round_new_lengths)\n\n        print('strokes_raw_out', strokes_raw_out.shape)\n        stroke_number_list.append(strokes_raw_out.shape[0])\n\n        clean_states_soft_out = np.array(states_soft_out)  # (N)\n\n        flag_list = strokes_raw_out[:, 0].astype(np.int32)  # (N)\n        drawing_len = len(flag_list) - np.sum(flag_list)\n        assert drawing_len >= 0\n\n        # print('    flag  raw\\t soft\\t x1\\t\\t y1\\t\\t x2\\t\\t y2\\t\\t r2\\t\\t s2')\n        for i in range(strokes_raw_out.shape[0]):\n            flag, x1, y1, x2, y2, r2, s2 = strokes_raw_out[i]\n            win_size = window_size_out[i]\n            out_format = '#%d: %d  | %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f'\n            out_values = (i, flag, states_raw_out[i], clean_states_soft_out[i], x1, y1, x2, y2, r2, s2)\n            out_log = out_format % out_values\n            # print(out_log)\n\n        print('Saving results ...')\n        save_seq_data(sampling_dir, test_image_raw_name + '_' + str(sampling_i),\n                      strokes_raw_out, multi_cursors,\n                      test_image_size, round_new_lengths, eval_hps_model.min_width)\n\n        draw_strokes(strokes_raw_out, 
sampling_dir, test_image_raw_name + '_' + str(sampling_i) + '_pred.png', \n                     ori_img, test_image_size,\n                     multi_cursors, round_new_lengths, eval_hps_model.min_width, eval_hps_model.cursor_type,\n                     sample_hps_model.raster_size, sample_hps_model.min_window_size,\n                     sess,\n                     pasting_func=paste_v3_func,\n                     save_seq=draw_seq, draw_order=draw_order)\n\n    average_stroke_number = np.mean(stroke_number_list)\n    average_compute_time = np.mean(compute_time_list)\n    print()\n    print('@@@ Total summary:')\n    print('  >>> average_stroke_number', average_stroke_number)\n    print('  >>> average_compute_time', average_compute_time)\n\n\ndef main(model_name, test_image_name, sampling_num):\n    test_dataset = 'clean_line_drawings'\n    test_image_base_dir = 'sample_inputs'\n\n    sampling_base_dir = 'outputs/sampling'\n    model_base_dir = 'outputs/snapshot'\n\n    state_dependent = False\n    longer_infer_lens = [500 for _ in range(10)]\n    round_stop_state_num = 12\n    stroke_acc_threshold = 0.95\n\n    draw_seq = False\n    draw_color_order = True\n\n    # set numpy output to something sensible\n    np.set_printoptions(precision=8, edgeitems=6, linewidth=200, suppress=True)\n\n    main_testing(test_image_base_dir, test_dataset, test_image_name,\n                 sampling_base_dir, model_base_dir, model_name, sampling_num,\n                 draw_seq=draw_seq, draw_order=draw_color_order,\n                 state_dependent=state_dependent, longer_infer_lens=longer_infer_lens,\n                 round_stop_state_num=round_stop_state_num, stroke_acc_threshold=stroke_acc_threshold)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--input', '-i', type=str, default='', help=\"The test image name.\")\n    parser.add_argument('--model', '-m', type=str, default='pretrain_clean_line_drawings', help=\"The trained 
model.\")\n    parser.add_argument('--sample', '-s', type=int, default=1, help=\"The number of outputs.\")\n    args = parser.parse_args()\n\n    assert args.input != ''\n    assert args.sample > 0\n\n    main(args.model, args.input, args.sample)\n"
  },
  {
    "path": "tools/gif_making.py",
    "content": "import os\nimport sys\nimport argparse\nimport numpy as np\nfrom PIL import Image\nimport tensorflow as tf\n\nsys.path.append('./')\nfrom utils import draw, image_pasting_v3_testing\nfrom model_common_test import DiffPastingV3\n\nos.environ['CUDA_VISIBLE_DEVICES'] = '0'\n\n\ndef add_scaling_visualization(canvas_images, cursor, window_size, image_size):\n    \"\"\"\n    :param canvas_images: (N, H, W, 3)\n    :param cursor:\n    :param window_size:\n    :param image_size:\n    :return:\n    \"\"\"\n    cursor_pos = cursor * float(image_size)\n    cursor_x, cursor_y = int(round(cursor_pos[0])), int(round(cursor_pos[1]))  # in large size\n\n    vis_color = [255, 0, 0]\n    cursor_width = 3\n    box_width = 2\n\n    canvas_imgs = 255 - np.round(canvas_images * 255.0).astype(np.uint8)\n\n    # add cursor visualization\n    canvas_imgs[:, cursor_y - cursor_width: cursor_y + cursor_width, cursor_x - cursor_width: cursor_x + cursor_width, :] = vis_color\n\n    # add box visualization\n    up = max(0, cursor_y - window_size // 2)\n    down = min(image_size, cursor_y + window_size // 2)\n    left = max(0, cursor_x - window_size // 2)\n    right = min(image_size, cursor_x + window_size // 2)\n    # up = cursor_y - window_size // 2\n    # down = cursor_y + window_size // 2\n    # left = cursor_x - window_size // 2\n    # right = cursor_x + window_size // 2\n\n    if up > 0:\n        canvas_imgs[:, up: up + box_width, left: right, :] = vis_color\n    if down < image_size:\n        canvas_imgs[:, down - box_width: down, left: right, :] = vis_color\n    if left > 0:\n        canvas_imgs[:, up: down, left: left + box_width, :] = vis_color\n    if right < image_size:\n        canvas_imgs[:, up: down, right - box_width: right, :] = vis_color\n    return canvas_imgs\n\n\ndef make_gif(sess, pasting_func, data, init_cursor, image_size, infer_lengths, init_width,\n             save_base,\n             cursor_type='next', min_window_size=32, raster_size=128, 
add_box=True):\n    \"\"\"\n    :param data: (N_strokes, 9): flag, x0, y0, x1, y1, x2, y2, r0, r2\n    :return:\n    \"\"\"\n    canvas = np.zeros((image_size, image_size), dtype=np.float32)  # [0.0-BG, 1.0-stroke]\n    gif_frames = []\n\n    cursor_idx = 0\n\n    if init_cursor.ndim == 1:\n        init_cursor = [init_cursor]\n\n    for round_idx in range(len(infer_lengths)):\n        print('Making progress', round_idx + 1, '/', len(infer_lengths))\n        round_length = infer_lengths[round_idx]\n\n        cursor_pos = init_cursor[cursor_idx]  # (2)\n        cursor_idx += 1\n\n        prev_width = init_width\n        prev_scaling = 1.0\n        prev_window_size = float(raster_size)  # (1)\n\n        for round_inner_i in range(round_length):\n            stroke_idx = np.sum(infer_lengths[:round_idx]).astype(np.int32) + round_inner_i\n\n            curr_window_size_raw = prev_scaling * prev_window_size\n            curr_window_size_raw = np.maximum(curr_window_size_raw, min_window_size)\n            curr_window_size_raw = np.minimum(curr_window_size_raw, image_size)\n            curr_window_size = int(round(curr_window_size_raw))  # ()\n\n            pen_state = data[stroke_idx, 0]\n            stroke_params = data[stroke_idx, 1:]  # (8)\n\n            x1y1, x2y2, width2, scaling2 = stroke_params[0:2], stroke_params[2:4], stroke_params[4], stroke_params[5]\n            x0y0 = np.zeros_like(x2y2)  # (2), [-1.0, 1.0]\n            x0y0 = np.divide(np.add(x0y0, 1.0), 2.0)  # (2), [0.0, 1.0]\n            x2y2 = np.divide(np.add(x2y2, 1.0), 2.0)  # (2), [0.0, 1.0]\n            widths = np.stack([prev_width, width2], axis=0)  # (2)\n            stroke_params_proc = np.concatenate([x0y0, x1y1, x2y2, widths], axis=-1)  # (8)\n\n            next_width = stroke_params[4]\n            next_scaling = stroke_params[5]\n            next_window_size = next_scaling * curr_window_size_raw\n            next_window_size = np.maximum(next_window_size, min_window_size)\n            
next_window_size = np.minimum(next_window_size, image_size)\n\n            prev_width = next_width * curr_window_size_raw / next_window_size\n            prev_scaling = next_scaling\n            prev_window_size = curr_window_size_raw\n\n            f = stroke_params_proc.tolist()  # (8)\n            f += [1.0, 1.0]\n            gt_stroke_img = draw(f)  # (H, W), [0.0-stroke, 1.0-BG]\n\n            gt_stroke_img_large = image_pasting_v3_testing(1.0 - gt_stroke_img, cursor_pos,\n                                                           image_size,\n                                                           curr_window_size_raw,\n                                                           pasting_func, sess)  # [0.0-BG, 1.0-stroke]\n\n            if pen_state == 0:\n                canvas += gt_stroke_img_large  # [0.0-BG, 1.0-stroke]\n\n            canvas_rgb = np.stack([np.clip(canvas, 0.0, 1.0) for _ in range(3)], axis=-1)\n\n            if add_box:\n                vis_inputs = np.expand_dims(canvas_rgb, axis=0)\n                vis_outputs = add_scaling_visualization(vis_inputs, cursor_pos, curr_window_size, image_size)\n                canvas_vis = vis_outputs[0]\n            else:\n                canvas_vis = canvas_rgb\n\n            canvas_vis_png = Image.fromarray(canvas_vis, 'RGB')\n            gif_frames.append(canvas_vis_png)\n\n            # update cursor_pos based on hps.cursor_type\n            new_cursor_offsets = stroke_params[2:4] * (float(curr_window_size_raw) / 2.0)  # (1, 6), patch-level\n            new_cursor_offset_next = new_cursor_offsets\n\n            # important!!!\n            new_cursor_offset_next = np.concatenate([new_cursor_offset_next[1:2], new_cursor_offset_next[0:1]], axis=-1)\n\n            cursor_pos_large = cursor_pos * float(image_size)\n\n            stroke_position_next = cursor_pos_large + new_cursor_offset_next  # (2), large-level\n\n            if cursor_type == 'next':\n                cursor_pos_large = 
stroke_position_next  # (2), large-level\n            else:\n                raise Exception('Unknown cursor_type')\n\n            cursor_pos_large = np.minimum(np.maximum(cursor_pos_large, 0.0), float(image_size - 1))  # (2), large-level\n            cursor_pos = cursor_pos_large / float(image_size)\n\n    print('Saving to GIF ...')\n    save_path = os.path.join(save_base, 'dynamic.gif')\n    first_frame = gif_frames[0]\n    first_frame.save(save_path, save_all=True, append_images=gif_frames, loop=0, duration=0.01)\n\n\ndef gif_making(npz_path):\n    assert npz_path != ''\n\n    min_window_size = 32\n    raster_size = 128\n\n    split_idx = npz_path.rfind('/')\n    if split_idx == -1:\n        file_base = './'\n        file_name = npz_path[:-4]\n    else:\n        file_base = npz_path[:npz_path.rfind('/')]\n        file_name = npz_path[npz_path.rfind('/') + 1: -4]\n\n    gif_base = os.path.join(file_base, file_name)\n    os.makedirs(gif_base, exist_ok=True)\n\n    # differentiable pasting graph\n    paste_v3_func = DiffPastingV3(raster_size)\n\n    tfconfig = tf.ConfigProto()\n    tfconfig.gpu_options.allow_growth = True\n    sess = tf.InteractiveSession(config=tfconfig)\n    sess.run(tf.global_variables_initializer())\n\n    data = np.load(npz_path, encoding='latin1', allow_pickle=True)\n    strokes_data = data['strokes_data']\n    init_cursors = data['init_cursors']\n    image_size = data['image_size']\n    round_length = data['round_length']\n    init_width = data['init_width']\n\n    if round_length.ndim == 0:\n        round_lengths = [round_length]\n    else:\n        round_lengths = round_length\n\n    # print('round_lengths', round_lengths)\n\n    make_gif(sess, paste_v3_func,\n             strokes_data, init_cursors, image_size, round_lengths, init_width,\n             gif_base,\n             min_window_size=min_window_size, raster_size=raster_size)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--file', 
'-f', type=str, default='', help=\"define a npz path\")\n    args = parser.parse_args()\n\n    gif_making(args.file)\n"
  },
  {
    "path": "tools/svg_conversion.py",
    "content": "import os\nimport argparse\nimport numpy as np\nfrom xml.dom import minidom\n\n\ndef write_svg_1(path_list, img_size, save_path):\n    ''' A long curve consisting of several strokes forms a path. '''\n    impl = minidom.getDOMImplementation()\n\n    doc = impl.createDocument(None, None, None)\n\n    rootElement = doc.createElement('svg')\n    rootElement.setAttribute('xmlns', 'http://www.w3.org/2000/svg')\n    rootElement.setAttribute('height', str(img_size))\n    rootElement.setAttribute('width', str(img_size))\n\n    path_num = len(path_list)\n    for path_i in range(path_num):\n        path_items = path_list[path_i]\n\n        assert len(path_items) > 0\n        if len(path_items) == 1:\n            continue\n\n        childElement = doc.createElement('path')\n        childElement.setAttribute('id', 'curve_' + str(path_i))\n        childElement.setAttribute('stroke', '#000000')\n        childElement.setAttribute('stroke-width', '3.5')\n        childElement.setAttribute('stroke-linejoin', 'round')\n        childElement.setAttribute('stroke-linecap', 'round')\n        childElement.setAttribute('fill', 'none')\n\n        command_str = ''\n        for stroke_i, stroke_item in enumerate(path_items):\n            if stroke_i == 0:\n                command_str += 'M '\n                stroke_position = stroke_item[0]\n                command_str += str(stroke_position[0]) + ', ' + str(stroke_position[1]) + ' '\n            else:\n                command_str += 'Q '\n                ctrl_position, stroke_position, stroke_width = stroke_item[0], stroke_item[1], stroke_item[2]\n\n                ctrl_position_0 = last_position[0] + (stroke_position[0] - last_position[0]) * ctrl_position[1]\n                ctrl_position_1 = last_position[1] + (stroke_position[1] - last_position[1]) * ctrl_position[0]\n\n                command_str += str(ctrl_position_0) + ', ' + str(ctrl_position_1) + ', ' + \\\n                               str(stroke_position[0]) + ', 
' + str(stroke_position[1]) + ' '\n\n            last_position = stroke_position\n\n        childElement.setAttribute('d', command_str)\n        rootElement.appendChild(childElement)\n\n    doc.appendChild(rootElement)\n\n    f = open(save_path, 'w')\n    doc.writexml(f, addindent='  ', newl='\\n')\n    f.close()\n\n\ndef write_svg_2(path_list, img_size, save_path):\n    ''' A single stroke forms a path. '''\n    impl = minidom.getDOMImplementation()\n\n    doc = impl.createDocument(None, None, None)\n\n    rootElement = doc.createElement('svg')\n    rootElement.setAttribute('xmlns', 'http://www.w3.org/2000/svg')\n    rootElement.setAttribute('height', str(img_size))\n    rootElement.setAttribute('width', str(img_size))\n\n    path_num = len(path_list)\n    for path_i in range(path_num):\n        path_items = path_list[path_i]\n\n        assert len(path_items) > 0\n        if len(path_items) == 1:\n            continue\n\n        for stroke_i, stroke_item in enumerate(path_items):\n            if stroke_i == 0:\n                last_position = stroke_item[0]\n            else:\n                childElement = doc.createElement('path')\n                childElement.setAttribute('id', 'curve_' + str(path_i))\n                childElement.setAttribute('stroke', '#000000')\n                childElement.setAttribute('stroke-linejoin', 'round')\n                childElement.setAttribute('stroke-linecap', 'round')\n                childElement.setAttribute('fill', 'none')\n\n                command_str = 'M ' + str(last_position[0]) + ', ' + str(last_position[1]) + ' '\n                command_str += 'Q '\n\n                ctrl_position, stroke_position, stroke_width = stroke_item[0], stroke_item[1], stroke_item[2]\n\n                ctrl_position_0 = last_position[0] + (stroke_position[0] - last_position[0]) * ctrl_position[1]\n                ctrl_position_1 = last_position[1] + (stroke_position[1] - last_position[1]) * ctrl_position[0]\n\n                command_str += 
str(ctrl_position_0) + ', ' + str(ctrl_position_1) + ', ' + \\\n                               str(stroke_position[0]) + ', ' + str(stroke_position[1]) + ' '\n\n                last_position = stroke_position\n\n                childElement.setAttribute('d', command_str)\n                childElement.setAttribute('stroke-width', str(stroke_width * img_size / 1.66))\n                rootElement.appendChild(childElement)\n\n    doc.appendChild(rootElement)\n\n    f = open(save_path, 'w')\n    doc.writexml(f, addindent='  ', newl='\\n')\n    f.close()\n\n\ndef convert_strokes_to_svg(data, init_cursor, image_size, infer_lengths, init_width, save_path, svg_type,\n                           cursor_type='next', min_window_size=32, raster_size=128):\n    \"\"\"\n    :param data: (N_strokes, 7): flag, x_c, y_c, dx, dy, r, ds\n    :return:\n    \"\"\"\n    cursor_idx = 0\n\n    absolute_strokes = []\n    absolute_strokes_path = []\n\n    if init_cursor.ndim == 1:\n        init_cursor = [init_cursor]\n\n    for round_idx in range(len(infer_lengths)):\n        round_length = infer_lengths[round_idx]\n\n        cursor_pos = init_cursor[cursor_idx]  # (2)\n        cursor_idx += 1\n\n        cursor_pos_large = cursor_pos * float(image_size)\n\n        if len(absolute_strokes_path) > 0:\n            absolute_strokes.append(absolute_strokes_path)\n        absolute_strokes_path = [[cursor_pos_large]]\n\n        prev_width = init_width\n        prev_scaling = 1.0\n        prev_window_size = float(raster_size)  # (1)\n\n        for round_inner_i in range(round_length):\n            stroke_idx = np.sum(infer_lengths[:round_idx]).astype(np.int32) + round_inner_i\n\n            curr_window_size_raw = prev_scaling * prev_window_size\n            curr_window_size_raw = np.maximum(curr_window_size_raw, min_window_size)\n            curr_window_size_raw = np.minimum(curr_window_size_raw, image_size)\n            # curr_window_size = int(round(curr_window_size_raw))  # ()\n\n            
stroke_params = data[stroke_idx, 1:]  # (6)\n            pen_state = data[stroke_idx, 0]\n\n            next_width = stroke_params[4]\n            next_scaling = stroke_params[5]\n\n            next_width_abs = next_width * curr_window_size_raw / float(image_size)\n\n            prev_scaling = next_scaling\n            prev_window_size = curr_window_size_raw\n\n            # update cursor_pos based on hps.cursor_type\n            new_cursor_offsets = stroke_params[2:4] * (float(curr_window_size_raw) / 2.0)  # (1, 6), patch-level\n            new_cursor_offset_next = new_cursor_offsets\n\n            # important!!!\n            new_cursor_offset_next = np.concatenate([new_cursor_offset_next[1:2], new_cursor_offset_next[0:1]], axis=-1)\n            cursor_pos_large = cursor_pos * float(image_size)\n            stroke_position_next = cursor_pos_large + new_cursor_offset_next  # (2), large-level\n\n            if pen_state == 0:\n                absolute_strokes_path.append([stroke_params[0:2], stroke_position_next, next_width_abs])\n            else:\n                absolute_strokes.append(absolute_strokes_path)\n                absolute_strokes_path = [[stroke_position_next]]\n\n            if cursor_type == 'next':\n                cursor_pos_large = stroke_position_next  # (2), large-level\n            else:\n                raise Exception('Unknown cursor_type')\n\n            cursor_pos_large = np.minimum(np.maximum(cursor_pos_large, 0.0), float(image_size - 1))  # (2), large-level\n            cursor_pos = cursor_pos_large / float(image_size)\n\n    absolute_strokes.append(absolute_strokes_path)\n\n    if svg_type == 'cluster':\n        write_svg_1(absolute_strokes, image_size, save_path)\n    elif svg_type == 'single':\n        write_svg_2(absolute_strokes, image_size, save_path)\n    else:\n        raise Exception('Unknown svg_type', svg_type)\n\n\ndef data_convert_to_absolute(npz_path, svg_type):\n    assert npz_path != ''\n    assert svg_type in ['single', 
'cluster']\n\n    min_window_size = 32\n    raster_size = 128\n\n    split_idx = npz_path.rfind('/')\n    if split_idx == -1:\n        file_base = './'\n        file_name = npz_path[:-4]\n    else:\n        file_base = npz_path[:npz_path.rfind('/')]\n        file_name = npz_path[npz_path.rfind('/') + 1: -4]\n\n    svg_data_base = os.path.join(file_base, file_name)\n    os.makedirs(svg_data_base, exist_ok=True)\n\n    data = np.load(npz_path, encoding='latin1', allow_pickle=True)\n    strokes_data = data['strokes_data']\n    init_cursors = data['init_cursors']\n    image_size = data['image_size']\n    round_length = data['round_length']\n    init_width = data['init_width']\n\n    if round_length.ndim == 0:\n        round_lengths = [round_length]\n    else:\n        round_lengths = round_length\n\n    save_path = os.path.join(svg_data_base, str(svg_type) + '.svg')\n\n    convert_strokes_to_svg(strokes_data, init_cursors, image_size, round_lengths, init_width,\n                           min_window_size=min_window_size, raster_size=raster_size, save_path=save_path,\n                           svg_type=svg_type)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--file', '-f', type=str, default='', help=\"define a npz path\")\n    parser.add_argument('--svg_type', '-st', type=str, choices=['single', 'cluster'], default='single',\n                        help=\"svg type\")\n    args = parser.parse_args()\n\n    data_convert_to_absolute(args.file, args.svg_type)\n"
  },
  {
    "path": "tools/visualize_drawing.py",
    "content": "import os\nimport sys\nimport argparse\nimport numpy as np\nfrom PIL import Image\nimport tensorflow as tf\n\nsys.path.append('./')\nfrom utils import get_colors, draw, image_pasting_v3_testing\nfrom model_common_test import DiffPastingV3\n\nos.environ['CUDA_VISIBLE_DEVICES'] = '0'\n\n\ndef display_strokes_final(sess, pasting_func, data, init_cursor, image_size, infer_lengths, init_width,\n                          save_base,\n                          cursor_type='next', min_window_size=32, raster_size=128):\n    \"\"\"\n    :param data: (N_strokes, 9): flag, x0, y0, x1, y1, x2, y2, r0, r2\n    :return:\n    \"\"\"\n    canvas = np.zeros((image_size, image_size), dtype=np.float32)  # [0.0-BG, 1.0-stroke]\n    drawn_region = np.zeros_like(canvas)\n    overlap_region = np.zeros_like(canvas)\n    canvas_color_with_overlap = np.zeros((image_size, image_size, 3), dtype=np.float32)\n    canvas_color_wo_overlap = np.zeros((image_size, image_size, 3), dtype=np.float32)\n    canvas_color_with_moving = np.zeros((image_size, image_size, 3), dtype=np.float32)\n\n    cursor_idx = 0\n\n    if init_cursor.ndim == 1:\n        init_cursor = [init_cursor]\n\n    stroke_count = len(data)\n    color_rgb_set = get_colors(stroke_count)  # list of (3,) in [0, 255]\n    color_idx = 0\n\n    valid_stroke_count = stroke_count - np.sum(data[:, 0]).astype(np.int32) + len(init_cursor)\n    valid_color_rgb_set = get_colors(valid_stroke_count)  # list of (3,) in [0, 255]\n    valid_color_idx = -1\n\n    # print('Drawn stroke number', valid_stroke_count)\n    # print('    flag  x1\\t\\t y1\\t\\t x2\\t\\t y2\\t\\t r2\\t\\t s2')\n\n    for round_idx in range(len(infer_lengths)):\n        round_length = infer_lengths[round_idx]\n\n        cursor_pos = init_cursor[cursor_idx]  # (2)\n        cursor_idx += 1\n\n        prev_width = init_width\n        prev_scaling = 1.0\n        prev_window_size = float(raster_size)  # (1)\n\n        for round_inner_i in range(round_length):\n         
   stroke_idx = np.sum(infer_lengths[:round_idx]).astype(np.int32) + round_inner_i\n\n            curr_window_size_raw = prev_scaling * prev_window_size\n            curr_window_size_raw = np.maximum(curr_window_size_raw, min_window_size)\n            curr_window_size_raw = np.minimum(curr_window_size_raw, image_size)\n\n            pen_state = data[stroke_idx, 0]\n            stroke_params = data[stroke_idx, 1:]  # (8)\n\n            x1y1, x2y2, width2, scaling2 = stroke_params[0:2], stroke_params[2:4], stroke_params[4], stroke_params[5]\n            x0y0 = np.zeros_like(x2y2)  # (2), [-1.0, 1.0]\n            x0y0 = np.divide(np.add(x0y0, 1.0), 2.0)  # (2), [0.0, 1.0]\n            x2y2 = np.divide(np.add(x2y2, 1.0), 2.0)  # (2), [0.0, 1.0]\n            widths = np.stack([prev_width, width2], axis=0)  # (2)\n            stroke_params_proc = np.concatenate([x0y0, x1y1, x2y2, widths], axis=-1)  # (8)\n\n            next_width = stroke_params[4]\n            next_scaling = stroke_params[5]\n            next_window_size = next_scaling * curr_window_size_raw\n            next_window_size = np.maximum(next_window_size, min_window_size)\n            next_window_size = np.minimum(next_window_size, image_size)\n\n            prev_width = next_width * curr_window_size_raw / next_window_size\n            prev_scaling = next_scaling\n            prev_window_size = curr_window_size_raw\n\n            f = stroke_params_proc.tolist()  # (8)\n            f += [1.0, 1.0]\n            gt_stroke_img = draw(f)  # (H, W), [0.0-stroke, 1.0-BG]\n\n            gt_stroke_img_large = image_pasting_v3_testing(1.0 - gt_stroke_img, cursor_pos,\n                                                           image_size,\n                                                           curr_window_size_raw,\n                                                           pasting_func, sess)  # [0.0-BG, 1.0-stroke]\n\n            is_overlap = False\n\n            if pen_state == 0:\n                canvas += 
gt_stroke_img_large  # [0.0-BG, 1.0-stroke]\n\n                curr_drawn_stroke_region = np.zeros_like(gt_stroke_img_large)\n                curr_drawn_stroke_region[gt_stroke_img_large > 0.5] = 1\n                intersection = drawn_region * curr_drawn_stroke_region\n                # regard stroke with >50% overlap area as overlaped stroke\n                if np.sum(intersection) / np.sum(curr_drawn_stroke_region) > 0.5:\n                    # enlarge the stroke a bit for better visualization\n                    overlap_region[gt_stroke_img_large > 0] += 1\n                    is_overlap = True\n\n                drawn_region[gt_stroke_img_large > 0.5] = 1\n\n            color_rgb = color_rgb_set[color_idx]  # (3) in [0, 255]\n            color_idx += 1\n\n            color_rgb = np.reshape(color_rgb, (1, 1, 3)).astype(np.float32)\n            color_stroke = np.expand_dims(gt_stroke_img_large, axis=-1) * (1.0 - color_rgb / 255.0)\n            canvas_color_with_moving = canvas_color_with_moving * np.expand_dims((1.0 - gt_stroke_img_large),\n                                                                                 axis=-1) + color_stroke  # (H, W, 3)\n\n            if pen_state == 0:\n                valid_color_idx += 1\n\n            if pen_state == 0:\n                valid_color_rgb = valid_color_rgb_set[valid_color_idx]  # (3) in [0, 255]\n                # valid_color_idx += 1\n\n                valid_color_rgb = np.reshape(valid_color_rgb, (1, 1, 3)).astype(np.float32)\n                valid_color_stroke = np.expand_dims(gt_stroke_img_large, axis=-1) * (1.0 - valid_color_rgb / 255.0)\n                canvas_color_with_overlap = canvas_color_with_overlap * np.expand_dims((1.0 - gt_stroke_img_large),\n                                                                                       axis=-1) + valid_color_stroke  # (H, W, 3)\n                if not is_overlap:\n                    canvas_color_wo_overlap = canvas_color_wo_overlap * 
np.expand_dims((1.0 - gt_stroke_img_large),\n                                                                                       axis=-1) + valid_color_stroke  # (H, W, 3)\n\n            # update cursor_pos based on hps.cursor_type\n            new_cursor_offsets = stroke_params[2:4] * (float(curr_window_size_raw) / 2.0)  # (1, 6), patch-level\n            new_cursor_offset_next = new_cursor_offsets\n\n            # important!!!\n            new_cursor_offset_next = np.concatenate([new_cursor_offset_next[1:2], new_cursor_offset_next[0:1]], axis=-1)\n\n            cursor_pos_large = cursor_pos * float(image_size)\n\n            stroke_position_next = cursor_pos_large + new_cursor_offset_next  # (2), large-level\n\n            if cursor_type == 'next':\n                cursor_pos_large = stroke_position_next  # (2), large-level\n            else:\n                raise Exception('Unknown cursor_type')\n\n            cursor_pos_large = np.minimum(np.maximum(cursor_pos_large, 0.0), float(image_size - 1))  # (2), large-level\n            cursor_pos = cursor_pos_large / float(image_size)\n\n    canvas_rgb = np.stack([np.clip(canvas, 0.0, 1.0) for _ in range(3)], axis=-1)\n    canvas_black = 255 - np.round(canvas_rgb * 255.0).astype(np.uint8)\n    canvas_color_with_overlap = 255 - np.round(canvas_color_with_overlap * 255.0).astype(np.uint8)\n    canvas_color_wo_overlap = 255 - np.round(canvas_color_wo_overlap * 255.0).astype(np.uint8)\n    canvas_color_with_moving = 255 - np.round(canvas_color_with_moving * 255.0).astype(np.uint8)\n\n    canvas_black_png = Image.fromarray(canvas_black, 'RGB')\n    canvas_black_save_path = os.path.join(save_base, 'output_rendered.png')\n    canvas_black_png.save(canvas_black_save_path, 'PNG')\n\n    canvas_color_png = Image.fromarray(canvas_color_with_overlap, 'RGB')\n    canvas_color_save_path = os.path.join(save_base, 'output_order_with_overlap.png')\n    canvas_color_png.save(canvas_color_save_path, 'PNG')\n\n    canvas_color_wo_png 
= Image.fromarray(canvas_color_wo_overlap, 'RGB')\n    canvas_color_wo_save_path = os.path.join(save_base, 'output_order_wo_overlap.png')\n    canvas_color_wo_png.save(canvas_color_wo_save_path, 'PNG')\n\n    canvas_color_m_png = Image.fromarray(canvas_color_with_moving, 'RGB')\n    canvas_color_m_save_path = os.path.join(save_base, 'output_order_with_moving.png')\n    canvas_color_m_png.save(canvas_color_m_save_path, 'PNG')\n\n\ndef visualize_drawing(npz_path):\n    assert npz_path != ''\n\n    min_window_size = 32\n    raster_size = 128\n\n    split_idx = npz_path.rfind('/')\n    if split_idx == -1:\n        file_base = './'\n        file_name = npz_path[:-4]\n    else:\n        file_base = npz_path[:npz_path.rfind('/')]\n        file_name = npz_path[npz_path.rfind('/') + 1: -4]\n\n    regenerate_base = os.path.join(file_base, file_name)\n    os.makedirs(regenerate_base, exist_ok=True)\n\n    # differentiable pasting graph\n    paste_v3_func = DiffPastingV3(raster_size)\n\n    tfconfig = tf.ConfigProto()\n    tfconfig.gpu_options.allow_growth = True\n    sess = tf.InteractiveSession(config=tfconfig)\n    sess.run(tf.global_variables_initializer())\n\n    data = np.load(npz_path, encoding='latin1', allow_pickle=True)\n    strokes_data = data['strokes_data']\n    init_cursors = data['init_cursors']\n    image_size = data['image_size']\n    round_length = data['round_length']\n    init_width = data['init_width']\n\n    if round_length.ndim == 0:\n        round_lengths = [round_length]\n    else:\n        round_lengths = round_length\n\n    # print('round_lengths', round_lengths)\n\n    print('Processing ...')\n    display_strokes_final(sess, paste_v3_func,\n                          strokes_data, init_cursors, image_size, round_lengths, init_width,\n                          regenerate_base,\n                          min_window_size=min_window_size, raster_size=raster_size)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    
parser.add_argument('--file', '-f', type=str, default='', help=\"define a npz path\")\n    args = parser.parse_args()\n\n    visualize_drawing(args.file)\n"
  },
  {
    "path": "train_rough_photograph.py",
    "content": "import json\nimport os\nimport time\nimport numpy as np\nimport six\nimport tensorflow as tf\nfrom PIL import Image\nimport argparse\n\nimport model_common_train as sketch_image_model\nfrom hyper_parameters import FLAGS, get_default_hparams_rough, get_default_hparams_normal\nfrom utils import create_summary, save_model, reset_graph, load_checkpoint\nfrom dataset_utils import load_dataset_training\n\nos.environ['CUDA_VISIBLE_DEVICES'] = '0, 1'\n\ntf.logging.set_verbosity(tf.logging.INFO)\n\n\ndef should_save_log_img(step_):\n    if step_ % 500 == 0:\n        return True\n    else:\n        return False\n\n\ndef save_log_images(sess, model, data_set, save_root, step_num, curr_photo_prob, interpolate_type, save_num=10):\n    res_gap = (model.hps.image_size_large - model.hps.image_size_small) // (save_num - 1)\n    log_img_resolutions = []\n    for ii in range(save_num - 1):\n        log_img_resolutions.append(model.hps.image_size_small + ii * res_gap)\n    log_img_resolutions.append(model.hps.image_size_large)\n\n    for res_i in range(len(log_img_resolutions)):\n        resolution = log_img_resolutions[res_i]\n\n        sub_save_root = os.path.join(save_root, 'res_' + str(resolution))\n        os.makedirs(sub_save_root, exist_ok=True)\n\n        input_photos, target_sketches, init_cursors, image_size_rand = \\\n            data_set.get_batch_from_memory(memory_idx=res_i,\n                                           fixed_image_size=resolution,\n                                           random_cursor=model.hps.random_cursor,\n                                           photo_prob=curr_photo_prob,\n                                           interpolate_type=interpolate_type)\n        # input_photos: (N, image_size, image_size, 3), [0-stroke, 1-BG]\n        # target_sketches: (N, image_size, image_size), [0-stroke, 1-BG]\n        # init_cursors: (N, 1, 2), in size [0.0, 1.0)\n\n        input_photo_val = input_photos\n\n        init_cursor_input = 
[init_cursors for _ in range(model.total_loop)]\n        init_cursor_input = np.concatenate(init_cursor_input, axis=0)\n        image_size_input = [image_size_rand for _ in range(model.total_loop)]\n        image_size_input = np.stack(image_size_input, axis=0)\n\n        feed = {\n            model.init_cursor: init_cursor_input,\n            model.image_size: image_size_input,\n            model.init_width: [model.hps.min_width],\n        }\n        for loop_i in range(model.total_loop):\n            feed[model.input_photo_list[loop_i]] = input_photo_val\n\n        raster_images_pred, raster_images_pred_rgb = sess.run([model.pred_raster_imgs, model.pred_raster_imgs_rgb],\n                                                              feed)  # (N, image_size, image_size), [0.0-stroke, 1.0-BG]\n        raster_images_pred = (np.array(raster_images_pred[0]) * 255.0).astype(np.uint8)\n        input_photo = (np.array(input_photo_val[0, :, :, :]) * 255.0).astype(np.uint8)\n        target_sketch = (np.array(target_sketches[0]) * 255.0).astype(np.uint8)\n        raster_images_pred_rgb = (np.array(raster_images_pred_rgb[0]) * 255.0).astype(np.uint8)\n\n        pred_save_path = os.path.join(sub_save_root, str(step_num) + '.png')\n        input_save_path = os.path.join(sub_save_root, 'input.png')\n        target_save_path = os.path.join(sub_save_root, 'gt.png')\n\n        pred_rgb_save_root = os.path.join(sub_save_root, 'rgb')\n        os.makedirs(pred_rgb_save_root, exist_ok=True)\n        pred_rgb_save_path = os.path.join(pred_rgb_save_root, str(step_num) + '.png')\n\n        raster_images_pred = Image.fromarray(raster_images_pred, 'L')\n        raster_images_pred.save(pred_save_path, 'PNG')\n        input_photo = Image.fromarray(input_photo, 'RGB')\n        input_photo.save(input_save_path, 'PNG')\n        target_sketch = Image.fromarray(target_sketch, 'L')\n        target_sketch.save(target_save_path, 'PNG')\n        raster_images_pred_rgb = 
Image.fromarray(raster_images_pred_rgb, 'RGB')\n        raster_images_pred_rgb.save(pred_rgb_save_path, 'PNG')\n\n\ndef train(sess, train_model, eval_sample_model, train_set, valid_set, sub_log_root, sub_snapshot_root, sub_log_img_root):\n    # Setup summary writer.\n    summary_writer = tf.summary.FileWriter(sub_log_root)\n\n    print('-' * 100)\n\n    # Calculate trainable params.\n    t_vars = tf.trainable_variables()\n    count_t_vars = 0\n    for var in t_vars:\n        num_param = np.prod(var.get_shape().as_list())\n        count_t_vars += num_param\n        print('%s | shape: %s | num_param: %i' % (var.name, str(var.get_shape()), num_param))\n    print('Total trainable variables %i.' % count_t_vars)\n    print('-' * 100)\n\n    # main train loop\n\n    hps = train_model.hps\n    start = time.time()\n\n    # create saver\n    snapshot_save_vars = [var for var in tf.global_variables()\n                          if 'raster_unit' not in var.op.name and 'VGG16' not in var.op.name]\n    saver = tf.train.Saver(var_list=snapshot_save_vars, max_to_keep=20)\n\n    start_step = 1\n    print('start_step', start_step)\n\n    mean_perc_relu_losses = [0.0 for _ in range(len(hps.perc_loss_layers))]\n\n    for _ in range(start_step, hps.num_steps + 1):\n        step = sess.run(train_model.global_step)  # start from 0\n\n        count_step = min(step, hps.num_steps)\n        curr_learning_rate = ((hps.learning_rate - hps.min_learning_rate) *\n                              (1 - count_step / hps.num_steps) ** hps.decay_power + hps.min_learning_rate)\n\n        if hps.sn_loss_type == 'decreasing':\n            assert hps.decrease_stop_steps <= hps.num_steps\n            assert hps.stroke_num_loss_weight_end <= hps.stroke_num_loss_weight\n            curr_sn_k = (hps.stroke_num_loss_weight - hps.stroke_num_loss_weight_end) / float(hps.decrease_stop_steps)\n            curr_stroke_num_loss_weight = hps.stroke_num_loss_weight - count_step * curr_sn_k\n            
curr_stroke_num_loss_weight = max(curr_stroke_num_loss_weight, hps.stroke_num_loss_weight_end)\n        elif hps.sn_loss_type == 'fixed':\n            curr_stroke_num_loss_weight = hps.stroke_num_loss_weight\n        elif hps.sn_loss_type == 'increasing':\n            curr_sn_k = hps.stroke_num_loss_weight / float(hps.num_steps - hps.increase_start_steps)\n            curr_stroke_num_loss_weight = max(count_step - hps.increase_start_steps, 0) * curr_sn_k\n        else:\n            raise Exception('Unknown sn_loss_type', hps.sn_loss_type)\n\n        if hps.early_pen_loss_type == 'head':\n            curr_early_pen_k = (hps.max_seq_len - hps.early_pen_length) / float(hps.num_steps)\n            curr_early_pen_loss_len = count_step * curr_early_pen_k + hps.early_pen_length\n\n            curr_early_pen_loss_start = 1\n            curr_early_pen_loss_end = curr_early_pen_loss_len\n        elif hps.early_pen_loss_type == 'tail':\n            curr_early_pen_k = (hps.max_seq_len // 2 - 1) / float(hps.num_steps)\n            curr_early_pen_loss_len = count_step * curr_early_pen_k + hps.max_seq_len // 2\n\n            curr_early_pen_loss_end = hps.max_seq_len\n            curr_early_pen_loss_start = curr_early_pen_loss_end - curr_early_pen_loss_len\n        elif hps.early_pen_loss_type == 'move':\n            curr_early_pen_k = (hps.max_seq_len // 2 - 1) / float(hps.num_steps)\n            curr_early_pen_loss_len = count_step * curr_early_pen_k + hps.max_seq_len // 2\n\n            curr_early_pen_loss_start = hps.max_seq_len - curr_early_pen_loss_len\n            curr_early_pen_loss_end = curr_early_pen_loss_start + hps.max_seq_len // 2\n        else:\n            raise Exception('Unknown early_pen_loss_type', hps.early_pen_loss_type)\n        curr_early_pen_loss_start = int(round(curr_early_pen_loss_start))\n        curr_early_pen_loss_end = int(round(curr_early_pen_loss_end))\n\n        if hps.photo_prob_type == 'increasing' or hps.photo_prob_type == 'interpolate':\n     
       assert hps.photo_prob_end_step >= hps.photo_prob_start_step\n            curr_photo_prob_k = 1.0 / float(hps.photo_prob_end_step - hps.photo_prob_start_step)\n            curr_photo_prob = (count_step - hps.photo_prob_start_step) * curr_photo_prob_k\n            curr_photo_prob = max(0.0, curr_photo_prob)\n            curr_photo_prob = min(1.0, curr_photo_prob)\n            interpolate_type = 'prob' if hps.photo_prob_type == 'increasing' else 'image'\n        elif hps.photo_prob_type == 'zero':\n            curr_photo_prob = 0.0\n            interpolate_type = 'prob'\n        elif hps.photo_prob_type == 'one':\n            curr_photo_prob = 1.0\n            interpolate_type = 'prob'\n        else:\n            raise Exception('Unknown photo_prob_type', hps.photo_prob_type)\n\n        input_photos, target_sketches, init_cursors, image_sizes = \\\n            train_set.get_batch_multi_res(loop_num=train_model.total_loop,\n                                          random_cursor=hps.random_cursor,\n                                          photo_prob=curr_photo_prob,\n                                          interpolate_type=interpolate_type)\n        # input_photos: list of (N, image_size, image_size, 3), [0-stroke, 1-BG]\n        # target_sketches: list of (N, image_size, image_size), [0-stroke, 1-BG]\n        # init_cursors: list of (N, 1, 2), in size [0.0, 1.0)\n\n        init_cursors_input = np.concatenate(init_cursors, axis=0)\n        image_size_input = np.stack(image_sizes, axis=0)\n\n        feed = {\n            train_model.init_cursor: init_cursors_input,\n            train_model.image_size: image_size_input,\n            train_model.init_width: [hps.min_width],\n\n            train_model.lr: curr_learning_rate,\n            train_model.stroke_num_loss_weight: curr_stroke_num_loss_weight,\n            train_model.early_pen_loss_start_idx: curr_early_pen_loss_start,\n            train_model.early_pen_loss_end_idx: curr_early_pen_loss_end,\n\n          
  train_model.last_step_num: float(step),\n        }\n        for layer_i in range(len(hps.perc_loss_layers)):\n            feed[train_model.perc_loss_mean_list[layer_i]] = mean_perc_relu_losses[layer_i]\n\n        for loop_i in range(train_model.total_loop):\n            input_photo_val = input_photos[loop_i]\n            target_sketch_val = target_sketches[loop_i]\n            feed[train_model.input_photo_list[loop_i]] = input_photo_val\n            feed[train_model.target_sketch_list[loop_i]] = np.expand_dims(target_sketch_val, axis=-1)\n\n        (train_cost, raster_cost, perc_relu_costs_raw, perc_relu_costs_norm,\n         stroke_num_cost, early_pen_states_cost,\n         pos_outside_cost, win_size_outside_cost,\n         train_step) = sess.run([\n            train_model.cost, train_model.raster_cost,\n            train_model.perc_relu_losses_raw, train_model.perc_relu_losses_norm,\n            train_model.stroke_num_cost,\n            train_model.early_pen_states_cost,\n            train_model.pos_outside_cost, train_model.win_size_outside_cost,\n            train_model.global_step\n         ], feed)\n\n        ## update mean_raster_loss\n        for layer_i in range(len(hps.perc_loss_layers)):\n            perc_relu_cost_raw = perc_relu_costs_raw[layer_i]\n            mean_perc_relu_loss = mean_perc_relu_losses[layer_i]\n            mean_perc_relu_loss = (mean_perc_relu_loss * step + perc_relu_cost_raw) / float(step + 1)\n            mean_perc_relu_losses[layer_i] = mean_perc_relu_loss\n\n        _ = sess.run(train_model.train_op, feed)\n\n        if step % 20 == 0 and step > 0:\n            end = time.time()\n            time_taken = end - start\n\n            train_summary_map = {\n                'Train_Cost': train_cost,\n                'Train_raster_Cost': raster_cost,\n                'Train_stroke_num_Cost': stroke_num_cost,\n                'Train_early_pen_states_cost': early_pen_states_cost,\n                'Train_pos_outside_Cost': 
pos_outside_cost,\n                'Train_win_size_outside_Cost': win_size_outside_cost,\n                'Learning_Rate': curr_learning_rate,\n                'Time_Taken_Train': time_taken\n            }\n            for layer_i in range(len(hps.perc_loss_layers)):\n                layer_name = hps.perc_loss_layers[layer_i]\n                train_summary_map['Train_raster_Cost_' + layer_name] = perc_relu_costs_raw[layer_i]\n\n            create_summary(summary_writer, train_summary_map, train_step)\n\n            output_format = ('step: %d, lr: %.6f, '\n                             'snw: %.3f, '\n                             'cost: %.4f, '\n                             'ras: %.4f, stroke_num: %.4f, early_pen: %.4f, '\n                             'pos_outside: %.4f, win_outside: %.4f, '\n                             'train_time_taken: %.1f')\n            output_values = (step, curr_learning_rate,\n                             curr_stroke_num_loss_weight,\n                             train_cost,\n                             raster_cost, stroke_num_cost, early_pen_states_cost,\n                             pos_outside_cost, win_size_outside_cost,\n                             time_taken)\n            output_log = output_format % output_values\n            # print(output_log)\n            tf.logging.info(output_log)\n            start = time.time()\n\n        if should_save_log_img(step) and step > 0:\n            save_log_images(sess, eval_sample_model, valid_set, sub_log_img_root, step, curr_photo_prob, interpolate_type)\n\n        if step % hps.save_every == 0 and step > 0:\n            save_model(sess, saver, sub_snapshot_root, step)\n\n\ndef trainer(model_params):\n    np.set_printoptions(precision=8, edgeitems=6, linewidth=200, suppress=True)\n\n    print('Hyperparams:')\n    for key, val in six.iteritems(model_params.values()):\n        print('%s = %s' % (key, str(val)))\n    print('Loading data files.')\n    print('-' * 100)\n\n    datasets = 
load_dataset_training(FLAGS.dataset_dir, model_params)\n\n    sub_snapshot_root = os.path.join(FLAGS.snapshot_root, model_params.program_name)\n    sub_log_root = os.path.join(FLAGS.log_root, model_params.program_name)\n    sub_log_img_root = os.path.join(FLAGS.log_img_root, model_params.program_name)\n\n    train_set = datasets[0]\n    valid_set = datasets[1]\n    train_model_params = datasets[2]\n    eval_sample_model_params = datasets[3]\n\n    eval_sample_model_params.loop_per_gpu = 1\n    eval_sample_model_params.batch_size = len(eval_sample_model_params.gpus) * eval_sample_model_params.loop_per_gpu\n\n    reset_graph()\n    train_model = sketch_image_model.VirtualSketchingModel(train_model_params)\n    eval_sample_model = sketch_image_model.VirtualSketchingModel(eval_sample_model_params, reuse=True)\n\n    tfconfig = tf.ConfigProto(allow_soft_placement=True)\n    tfconfig.gpu_options.allow_growth = True\n    sess = tf.InteractiveSession(config=tfconfig)\n    sess.run(tf.global_variables_initializer())\n\n    load_checkpoint(sess, FLAGS.neural_renderer_path, ras_only=True)\n    if train_model_params.raster_loss_base_type == 'perceptual':\n        load_checkpoint(sess, FLAGS.perceptual_model_root, perceptual_only=True)\n\n    # Write config file to json file.\n    os.makedirs(sub_log_root, exist_ok=True)\n    os.makedirs(sub_log_img_root, exist_ok=True)\n    os.makedirs(sub_snapshot_root, exist_ok=True)\n    with tf.gfile.Open(os.path.join(sub_snapshot_root, 'model_config.json'), 'w') as f:\n        json.dump(train_model_params.values(), f, indent=True)\n\n    train(sess, train_model, eval_sample_model, train_set, valid_set,\n          sub_log_root, sub_snapshot_root, sub_log_img_root)\n\n\ndef main(dataset_type):\n    if dataset_type == 'rough':\n        model_params = get_default_hparams_rough()\n    elif dataset_type == 'face':\n        model_params = get_default_hparams_normal()\n    else:\n        raise Exception('Unknown dataset_type:', dataset_type)\n\n  
  trainer(model_params)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--data', '-d', type=str, default='rough', choices=['rough', 'face'], help=\"The dataset type.\")\n    args = parser.parse_args()\n\n    main(args.data)\n"
  },
  {
    "path": "train_vectorization.py",
    "content": "import json\nimport os\nimport time\nimport numpy as np\nimport six\nimport tensorflow as tf\nfrom PIL import Image\n\nimport model_common_train as sketch_vector_model\nfrom hyper_parameters import FLAGS, get_default_hparams_clean\nfrom utils import create_summary, save_model, reset_graph, load_checkpoint\nfrom dataset_utils import load_dataset_training\n\nos.environ['CUDA_VISIBLE_DEVICES'] = '0, 1'\n\ntf.logging.set_verbosity(tf.logging.INFO)\n\n\ndef should_save_log_img(step_):\n    if step_ % 500 == 0:\n        return True\n    else:\n        return False\n\n\ndef save_log_images(sess, model, data_set, save_root, step_num, save_num=10):\n    res_gap = (model.hps.image_size_large - model.hps.image_size_small) // (save_num - 1)\n    log_img_resolutions = []\n    for ii in range(save_num - 1):\n        log_img_resolutions.append(model.hps.image_size_small + ii * res_gap)\n    log_img_resolutions.append(model.hps.image_size_large)\n\n    for res_i in range(len(log_img_resolutions)):\n        resolution = log_img_resolutions[res_i]\n\n        sub_save_root = os.path.join(save_root, 'res_' + str(resolution))\n        os.makedirs(sub_save_root, exist_ok=True)\n\n        input_photos, target_sketches, init_cursors, image_size_rand = \\\n            data_set.get_batch_from_memory(memory_idx=res_i, vary_thickness=model.hps.vary_thickness,\n                                           fixed_image_size=resolution,\n                                           random_cursor=model.hps.random_cursor,\n                                           init_cursor_on_undrawn_pixel=model.hps.init_cursor_on_undrawn_pixel)\n        # input_photos: (N, image_size, image_size), [0-stroke, 1-BG]\n        # target_sketches: (N, image_size, image_size), [0-stroke, 1-BG]\n        # init_cursors: (N, 1, 2), in size [0.0, 1.0)\n\n        if input_photos is not None:\n            input_photo_val = np.expand_dims(input_photos, axis=-1)\n        else:\n            input_photo_val = 
np.expand_dims(target_sketches, axis=-1)\n\n        init_cursor_input = [init_cursors for _ in range(model.total_loop)]\n        init_cursor_input = np.concatenate(init_cursor_input, axis=0)\n        image_size_input = [image_size_rand for _ in range(model.total_loop)]\n        image_size_input = np.stack(image_size_input, axis=0)\n\n        feed = {\n            model.init_cursor: init_cursor_input,\n            model.image_size: image_size_input,\n            model.init_width: [model.hps.min_width],\n        }\n        for loop_i in range(model.total_loop):\n            feed[model.input_photo_list[loop_i]] = input_photo_val\n\n        raster_images_pred, raster_images_pred_rgb = sess.run([model.pred_raster_imgs, model.pred_raster_imgs_rgb],\n                                                              feed)  # (N, image_size, image_size), [0.0-stroke, 1.0-BG]\n        raster_images_pred = (np.array(raster_images_pred[0]) * 255.0).astype(np.uint8)\n        input_sketch = (np.array(target_sketches[0]) * 255.0).astype(np.uint8)\n        raster_images_pred_rgb = (np.array(raster_images_pred_rgb[0]) * 255.0).astype(np.uint8)\n\n        pred_save_path = os.path.join(sub_save_root, str(step_num) + '.png')\n        target_save_path = os.path.join(sub_save_root, 'gt.png')\n\n        pred_rgb_save_root = os.path.join(sub_save_root, 'rgb')\n        os.makedirs(pred_rgb_save_root, exist_ok=True)\n        pred_rgb_save_path = os.path.join(pred_rgb_save_root, str(step_num) + '.png')\n\n        raster_images_pred = Image.fromarray(raster_images_pred, 'L')\n        raster_images_pred.save(pred_save_path, 'PNG')\n        input_sketch = Image.fromarray(input_sketch, 'L')\n        input_sketch.save(target_save_path, 'PNG')\n        raster_images_pred_rgb = Image.fromarray(raster_images_pred_rgb, 'RGB')\n        raster_images_pred_rgb.save(pred_rgb_save_path, 'PNG')\n\n\ndef train(sess, train_model, eval_sample_model, train_set, val_set, sub_log_root, sub_snapshot_root, 
sub_log_img_root):\n    # Setup summary writer.\n    summary_writer = tf.summary.FileWriter(sub_log_root)\n\n    print('-' * 100)\n\n    # Calculate trainable params.\n    t_vars = tf.trainable_variables()\n    count_t_vars = 0\n    for var in t_vars:\n        num_param = np.prod(var.get_shape().as_list())\n        count_t_vars += num_param\n        print('%s | shape: %s | num_param: %i' % (var.name, str(var.get_shape()), num_param))\n    print('Total trainable variables %i.' % count_t_vars)\n    print('-' * 100)\n\n    # main train loop\n\n    hps = train_model.hps\n    start = time.time()\n\n    # create saver\n    snapshot_save_vars = [var for var in tf.global_variables()\n                          if 'raster_unit' not in var.op.name and 'VGG16' not in var.op.name]\n    saver = tf.train.Saver(var_list=snapshot_save_vars, max_to_keep=20)\n\n    start_step = 1\n    print('start_step', start_step)\n\n    mean_perc_relu_losses = [0.0 for _ in range(len(hps.perc_loss_layers))]\n\n    for _ in range(start_step, hps.num_steps + 1):\n        step = sess.run(train_model.global_step)  # start from 0\n\n        count_step = min(step, hps.num_steps)\n        curr_learning_rate = ((hps.learning_rate - hps.min_learning_rate) *\n                              (1 - count_step / hps.num_steps) ** hps.decay_power + hps.min_learning_rate)\n\n        if hps.sn_loss_type == 'decreasing':\n            assert hps.decrease_stop_steps <= hps.num_steps\n            assert hps.stroke_num_loss_weight_end <= hps.stroke_num_loss_weight\n            curr_sn_k = (hps.stroke_num_loss_weight - hps.stroke_num_loss_weight_end) / float(hps.decrease_stop_steps)\n            curr_stroke_num_loss_weight = hps.stroke_num_loss_weight - count_step * curr_sn_k\n            curr_stroke_num_loss_weight = max(curr_stroke_num_loss_weight, hps.stroke_num_loss_weight_end)\n        elif hps.sn_loss_type == 'fixed':\n            curr_stroke_num_loss_weight = hps.stroke_num_loss_weight\n        elif 
hps.sn_loss_type == 'increasing':\n            curr_sn_k = hps.stroke_num_loss_weight / float(hps.num_steps - hps.increase_start_steps)\n            curr_stroke_num_loss_weight = max(count_step - hps.increase_start_steps, 0) * curr_sn_k\n        else:\n            raise Exception('Unknown sn_loss_type', hps.sn_loss_type)\n\n        if hps.early_pen_loss_type == 'head':\n            curr_early_pen_k = (hps.max_seq_len - hps.early_pen_length) / float(hps.num_steps)\n            curr_early_pen_loss_len = count_step * curr_early_pen_k + hps.early_pen_length\n\n            curr_early_pen_loss_start = 1\n            curr_early_pen_loss_end = curr_early_pen_loss_len\n        elif hps.early_pen_loss_type == 'tail':\n            curr_early_pen_k = (hps.max_seq_len // 2 - 1) / float(hps.num_steps)\n            curr_early_pen_loss_len = count_step * curr_early_pen_k + hps.max_seq_len // 2\n\n            curr_early_pen_loss_end = hps.max_seq_len\n            curr_early_pen_loss_start = curr_early_pen_loss_end - curr_early_pen_loss_len\n        elif hps.early_pen_loss_type == 'move':\n            curr_early_pen_k = (hps.max_seq_len // 2 - 1) / float(hps.num_steps)\n            curr_early_pen_loss_len = count_step * curr_early_pen_k + hps.max_seq_len // 2\n\n            curr_early_pen_loss_start = hps.max_seq_len - curr_early_pen_loss_len\n            curr_early_pen_loss_end = curr_early_pen_loss_start + hps.max_seq_len // 2\n        else:\n            raise Exception('Unknown early_pen_loss_type', hps.early_pen_loss_type)\n        curr_early_pen_loss_start = int(round(curr_early_pen_loss_start))\n        curr_early_pen_loss_end = int(round(curr_early_pen_loss_end))\n\n        input_photos, target_sketches, init_cursors, image_sizes = \\\n            train_set.get_batch_multi_res(loop_num=train_model.total_loop, vary_thickness=hps.vary_thickness,\n                                          random_cursor=hps.random_cursor,\n                                          
init_cursor_on_undrawn_pixel=hps.init_cursor_on_undrawn_pixel)\n        # input_photos: list of (N, image_size, image_size), [0-stroke, 1-BG]\n        # target_sketches: list of (N, image_size, image_size), [0-stroke, 1-BG]\n        # init_cursors: list of (N, 1, 2), in size [0.0, 1.0)\n\n        init_cursors_input = np.concatenate(init_cursors, axis=0)\n        image_size_input = np.stack(image_sizes, axis=0)\n\n        feed = {\n            train_model.init_cursor: init_cursors_input,\n            train_model.image_size: image_size_input,\n            train_model.init_width: [hps.min_width],\n\n            train_model.lr: curr_learning_rate,\n            train_model.stroke_num_loss_weight: curr_stroke_num_loss_weight,\n            train_model.early_pen_loss_start_idx: curr_early_pen_loss_start,\n            train_model.early_pen_loss_end_idx: curr_early_pen_loss_end,\n\n            train_model.last_step_num: float(step),\n        }\n        for layer_i in range(len(hps.perc_loss_layers)):\n            feed[train_model.perc_loss_mean_list[layer_i]] = mean_perc_relu_losses[layer_i]\n\n        for loop_i in range(train_model.total_loop):\n            if input_photos is not None:\n                input_photo_val = np.expand_dims(input_photos[loop_i], axis=-1)\n            else:\n                input_photo_val = np.expand_dims(target_sketches[loop_i], axis=-1)\n            feed[train_model.input_photo_list[loop_i]] = input_photo_val\n\n        (train_cost, raster_cost, perc_relu_costs_raw, perc_relu_costs_norm,\n         stroke_num_cost, early_pen_states_cost,\n         pos_outside_cost, win_size_outside_cost,\n         train_step) = sess.run([\n            train_model.cost, train_model.raster_cost,\n            train_model.perc_relu_losses_raw, train_model.perc_relu_losses_norm,\n            train_model.stroke_num_cost,\n            train_model.early_pen_states_cost,\n            train_model.pos_outside_cost, train_model.win_size_outside_cost,\n            
train_model.global_step\n         ], feed)\n\n        ## update mean_raster_loss\n        for layer_i in range(len(hps.perc_loss_layers)):\n            perc_relu_cost_raw = perc_relu_costs_raw[layer_i]\n            mean_perc_relu_loss = mean_perc_relu_losses[layer_i]\n            mean_perc_relu_loss = (mean_perc_relu_loss * step + perc_relu_cost_raw) / float(step + 1)\n            mean_perc_relu_losses[layer_i] = mean_perc_relu_loss\n\n        _ = sess.run(train_model.train_op, feed)\n\n        if step % 20 == 0 and step > 0:\n            end = time.time()\n            time_taken = end - start\n\n            train_summary_map = {\n                'Train_Cost': train_cost,\n                'Train_raster_Cost': raster_cost,\n                'Train_stroke_num_Cost': stroke_num_cost,\n                'Train_early_pen_states_cost': early_pen_states_cost,\n                'Train_pos_outside_Cost': pos_outside_cost,\n                'Train_win_size_outside_Cost': win_size_outside_cost,\n                'Learning_Rate': curr_learning_rate,\n                'Time_Taken_Train': time_taken\n            }\n            for layer_i in range(len(hps.perc_loss_layers)):\n                layer_name = hps.perc_loss_layers[layer_i]\n                train_summary_map['Train_raster_Cost_' + layer_name] = perc_relu_costs_raw[layer_i]\n\n            create_summary(summary_writer, train_summary_map, train_step)\n\n            output_format = ('step: %d, lr: %.6f, '\n                             'snw: %.3f, '\n                             'cost: %.4f, '\n                             'ras: %.4f, stroke_num: %.4f, early_pen: %.4f, '\n                             'pos_outside: %.4f, win_outside: %.4f, '\n                             'train_time_taken: %.1f')\n            output_values = (step, curr_learning_rate,\n                             curr_stroke_num_loss_weight,\n                             train_cost,\n                             raster_cost, stroke_num_cost, 
early_pen_states_cost,\n                             pos_outside_cost, win_size_outside_cost,\n                             time_taken)\n            output_log = output_format % output_values\n            # print(output_log)\n            tf.logging.info(output_log)\n            start = time.time()\n\n        if should_save_log_img(step) and step > 0:\n            save_log_images(sess, eval_sample_model, val_set, sub_log_img_root, step)\n\n        if step % hps.save_every == 0 and step > 0:\n            save_model(sess, saver, sub_snapshot_root, step)\n\n\ndef trainer(model_params):\n    np.set_printoptions(precision=8, edgeitems=6, linewidth=200, suppress=True)\n\n    print('Hyperparams:')\n    for key, val in six.iteritems(model_params.values()):\n        print('%s = %s' % (key, str(val)))\n    print('Loading data files.')\n    print('-' * 100)\n\n    datasets = load_dataset_training(FLAGS.dataset_dir, model_params)\n\n    sub_snapshot_root = os.path.join(FLAGS.snapshot_root, model_params.program_name)\n    sub_log_root = os.path.join(FLAGS.log_root, model_params.program_name)\n    sub_log_img_root = os.path.join(FLAGS.log_img_root, model_params.program_name)\n\n    train_set = datasets[0]\n    val_set = datasets[1]\n    train_model_params = datasets[2]\n    eval_sample_model_params = datasets[3]\n\n    eval_sample_model_params.loop_per_gpu = 1\n    eval_sample_model_params.batch_size = len(eval_sample_model_params.gpus) * eval_sample_model_params.loop_per_gpu\n\n    reset_graph()\n    train_model = sketch_vector_model.VirtualSketchingModel(train_model_params)\n    eval_sample_model = sketch_vector_model.VirtualSketchingModel(eval_sample_model_params, reuse=True)\n\n    tfconfig = tf.ConfigProto(allow_soft_placement=True)\n    tfconfig.gpu_options.allow_growth = True\n    sess = tf.InteractiveSession(config=tfconfig)\n    sess.run(tf.global_variables_initializer())\n\n    load_checkpoint(sess, FLAGS.neural_renderer_path, ras_only=True)\n    if 
train_model_params.raster_loss_base_type == 'perceptual':\n        load_checkpoint(sess, FLAGS.perceptual_model_root, perceptual_only=True)\n\n    # Write config file to json file.\n    os.makedirs(sub_log_root, exist_ok=True)\n    os.makedirs(sub_log_img_root, exist_ok=True)\n    os.makedirs(sub_snapshot_root, exist_ok=True)\n    with tf.gfile.Open(os.path.join(sub_snapshot_root, 'model_config.json'), 'w') as f:\n        json.dump(train_model_params.values(), f, indent=True)\n\n    train(sess, train_model, eval_sample_model, train_set, val_set,\n          sub_log_root, sub_snapshot_root, sub_log_img_root)\n\n\ndef main():\n    model_params = get_default_hparams_clean()\n    trainer(model_params)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "utils.py",
    "content": "import os\nimport cv2\nimport json\nimport numpy as np\nimport tensorflow as tf\nfrom PIL import Image\nimport matplotlib.pyplot as plt\n\n\n#############################################\n# Tensorflow utils\n#############################################\n\ndef reset_graph():\n    \"\"\"Closes the current default session and resets the graph.\"\"\"\n    sess = tf.get_default_session()\n    if sess:\n        sess.close()\n    tf.reset_default_graph()\n\n\ndef load_checkpoint(sess, checkpoint_path, ras_only=False, perceptual_only=False, gen_model_pretrain=False,\n                    train_entire=False):\n    if ras_only:\n        load_var = {var.op.name: var for var in tf.global_variables() if 'raster_unit' in var.op.name}\n    elif perceptual_only:\n        load_var = {var.op.name: var for var in tf.global_variables() if 'VGG16' in var.op.name}\n    elif train_entire:\n        load_var = {var.op.name: var for var in tf.global_variables()\n                    if 'discriminator' not in var.op.name\n                    and 'raster_unit' not in var.op.name\n                    and 'VGG16' not in var.op.name\n                    and 'beta1' not in var.op.name\n                    and 'beta2' not in var.op.name\n                    and 'global_step' not in var.op.name\n                    and 'Entire' not in var.op.name\n                    }\n    else:\n        if gen_model_pretrain:\n            load_var = {var.op.name: var for var in tf.global_variables()\n                        if 'discriminator' not in var.op.name\n                        and 'raster_unit' not in var.op.name\n                        and 'VGG16' not in var.op.name\n                        and 'beta1' not in var.op.name\n                        and 'beta2' not in var.op.name\n                        # and 'global_step' not in var.op.name\n                        }\n        else:\n            load_var = tf.global_variables()\n\n    restorer = tf.train.Saver(load_var)\n    if not 
ras_only:\n        ckpt = tf.train.get_checkpoint_state(checkpoint_path)\n        model_checkpoint_path = ckpt.model_checkpoint_path\n    else:\n        model_checkpoint_path = checkpoint_path\n    print('Loading model %s' % model_checkpoint_path)\n    restorer.restore(sess, model_checkpoint_path)\n\n    snapshot_step = model_checkpoint_path[model_checkpoint_path.rfind('-') + 1:]\n    return snapshot_step\n\n\ndef create_summary(summary_writer, summ_map, step):\n    for summ_key in summ_map:\n        summ_value = summ_map[summ_key]\n        summ = tf.summary.Summary()\n        summ.value.add(tag=summ_key, simple_value=float(summ_value))\n        summary_writer.add_summary(summ, step)\n    summary_writer.flush()\n\n\ndef save_model(sess, saver, model_save_path, global_step):\n    checkpoint_path = os.path.join(model_save_path, 'p2s')\n    print('saving model %s.' % checkpoint_path)\n    print('global_step %i.' % global_step)\n    saver.save(sess, checkpoint_path, global_step=global_step)\n\n\n#############################################\n# Utils for basic image processing\n#############################################\n\n\ndef normal(x, width):\n    return (int)(x * (width - 1) + 0.5)\n\n\ndef draw(f, width=128):\n    x0, y0, x1, y1, x2, y2, z0, z2, w0, w2 = f\n    x1 = x0 + (x2 - x0) * x1\n    y1 = y0 + (y2 - y0) * y1\n    x0 = normal(x0, width * 2)\n    x1 = normal(x1, width * 2)\n    x2 = normal(x2, width * 2)\n    y0 = normal(y0, width * 2)\n    y1 = normal(y1, width * 2)\n    y2 = normal(y2, width * 2)\n    z0 = (int)(1 + z0 * width // 2)\n    z2 = (int)(1 + z2 * width // 2)\n    canvas = np.zeros([width * 2, width * 2]).astype('float32')\n    tmp = 1. 
/ 100\n    for i in range(100):\n        t = i * tmp\n        x = (int)((1-t) * (1-t) * x0 + 2 * t * (1-t) * x1 + t * t * x2)\n        y = (int)((1-t) * (1-t) * y0 + 2 * t * (1-t) * y1 + t * t * y2)\n        z = (int)((1-t) * z0 + t * z2)\n        w = (1-t) * w0 + t * w2\n        cv2.circle(canvas, (y, x), z, w, -1)\n    return 1 - cv2.resize(canvas, dsize=(width, width))\n\n\ndef rgb_trans(split_num, break_values):\n    slice_per_split = split_num // 8\n    break_values_head, break_values_tail = break_values[:-1], break_values[1:]\n\n    results = []\n\n    for split_i in range(8):\n        break_value_head = break_values_head[split_i]\n        break_value_tail = break_values_tail[split_i]\n\n        slice_gap = float(break_value_tail - break_value_head) / float(slice_per_split)\n        for slice_i in range(slice_per_split):\n            slice_val = break_value_head + slice_gap * slice_i\n            slice_val = int(round(slice_val))\n            results.append(slice_val)\n\n    return results\n\n\ndef get_colors(color_num):\n    split_num = (color_num // 8 + 1) * 8\n\n    r_break_values = [0, 0, 0, 0, 128, 255, 255, 255, 128]\n    g_break_values = [0, 0, 128, 255, 255, 255, 128, 0, 0]\n    b_break_values = [128, 255, 255, 255, 128, 0, 0, 0, 0]\n\n    r_rst_list = rgb_trans(split_num, r_break_values)\n    g_rst_list = rgb_trans(split_num, g_break_values)\n    b_rst_list = rgb_trans(split_num, b_break_values)\n\n    assert len(r_rst_list) == len(g_rst_list)\n    assert len(b_rst_list) == len(g_rst_list)\n\n    rgb_color_list = [(r_rst_list[i], g_rst_list[i], b_rst_list[i]) for i in range(len(r_rst_list))]\n    return rgb_color_list\n\n\n#############################################\n# Utils for testing\n#############################################\n\ndef save_seq_data(save_root, save_filename, strokes_data, init_cursors, image_size, round_length, init_width):\n    seq_save_root = os.path.join(save_root, 'seq_data')\n    os.makedirs(seq_save_root, exist_ok=True)\n 
   save_npz_path = os.path.join(seq_save_root, save_filename + '.npz')\n    np.savez(save_npz_path, strokes_data=strokes_data, init_cursors=init_cursors,\n             image_size=image_size, round_length=round_length, init_width=init_width)\n\n\ndef image_pasting_v3_testing(patch_image, cursor, image_size, window_size_f, pasting_func, sess):\n    \"\"\"\n    :param patch_image:  (raster_size, raster_size), [0.0-BG, 1.0-stroke]\n    :param cursor: (2), in size [0.0, 1.0)\n    :param window_size_f: (), float32, [0.0, image_size)\n    :return: (image_size, image_size), [0.0-BG, 1.0-stroke]\n    \"\"\"\n    cursor_pos = cursor * float(image_size)\n    pasted_image = sess.run(pasting_func.pasted_image,\n                            feed_dict={pasting_func.patch_canvas: np.expand_dims(patch_image, axis=-1),\n                                       pasting_func.cursor_pos_a: cursor_pos,\n                                       pasting_func.image_size_a: image_size,\n                                       pasting_func.window_size_a: window_size_f})\n    # (image_size, image_size, 1), [0.0-BG, 1.0-stroke]\n    pasted_image = pasted_image[:, :, 0]\n    return pasted_image\n\n\ndef draw_strokes(data, save_root, save_filename, input_img, image_size, init_cursor, infer_lengths, init_width,\n                 cursor_type, raster_size, min_window_size,\n                 sess,\n                 pasting_func=None,\n                 save_seq=False, draw_order=False):\n    \"\"\"\n    :param data: (N_strokes, 9): flag, x1, y1, x2, y2, r2, s2\n    :return:\n    \"\"\"\n    canvas = np.zeros((image_size, image_size), dtype=np.float32)  # [0.0-BG, 1.0-stroke]\n    canvas_color = np.zeros((image_size, image_size, 3), dtype=np.float32)\n    canvas_color_with_moving = np.zeros((image_size, image_size, 3), dtype=np.float32)\n    frames = []\n\n    cursor_idx = 0\n\n    stroke_count = len(data)\n    color_rgb_set = get_colors(stroke_count)  # list of (3,) in [0, 255]\n    color_idx = 0\n\n    
for round_idx in range(len(infer_lengths)):\n        round_length = infer_lengths[round_idx]\n\n        cursor_pos = init_cursor[cursor_idx]  # (2)\n        cursor_idx += 1\n\n        prev_width = init_width\n        prev_scaling = 1.0\n        prev_window_size = raster_size  # (1)\n\n        for round_inner_i in range(round_length):\n            stroke_idx = np.sum(infer_lengths[:round_idx]).astype(np.int32) + round_inner_i\n\n            curr_window_size = prev_scaling * prev_window_size\n            curr_window_size = np.maximum(curr_window_size, min_window_size)\n            curr_window_size = np.minimum(curr_window_size, image_size)\n\n            pen_state = data[stroke_idx, 0]\n            stroke_params = data[stroke_idx, 1:]  # (8)\n\n            x1y1, x2y2, width2, scaling2 = stroke_params[0:2], stroke_params[2:4], stroke_params[4], stroke_params[5]\n            x0y0 = np.zeros_like(x2y2)  # (2), [-1.0, 1.0]\n            x0y0 = np.divide(np.add(x0y0, 1.0), 2.0)  # (2), [0.0, 1.0]\n            x2y2 = np.divide(np.add(x2y2, 1.0), 2.0)  # (2), [0.0, 1.0]\n            widths = np.stack([prev_width, width2], axis=0)  # (2)\n            stroke_params_proc = np.concatenate([x0y0, x1y1, x2y2, widths], axis=-1)  # (8)\n\n            next_width = stroke_params[4]\n            next_scaling = stroke_params[5]\n            next_window_size = next_scaling * curr_window_size\n            next_window_size = np.maximum(next_window_size, min_window_size)\n            next_window_size = np.minimum(next_window_size, image_size)\n\n            prev_width = next_width * curr_window_size / next_window_size\n            prev_scaling = next_scaling\n            prev_window_size = curr_window_size\n\n            f = stroke_params_proc.tolist()  # (8)\n            f += [1.0, 1.0]\n            gt_stroke_img = draw(f)  # (raster_size, raster_size), [0.0-stroke, 1.0-BG]\n            gt_stroke_img_large = image_pasting_v3_testing(1.0 - gt_stroke_img, cursor_pos, image_size,\n            
                                                curr_window_size,\n                                                            pasting_func, sess)  # [0.0-BG, 1.0-stroke]\n\n            if pen_state == 0:\n                canvas += gt_stroke_img_large  # [0.0-BG, 1.0-stroke]\n\n            if draw_order:\n                color_rgb = color_rgb_set[color_idx]  # (3) in [0, 255]\n                color_idx += 1\n\n                color_rgb = np.reshape(color_rgb, (1, 1, 3)).astype(np.float32)\n                color_stroke = np.expand_dims(gt_stroke_img_large, axis=-1) * (1.0 - color_rgb / 255.0)\n                canvas_color_with_moving = canvas_color_with_moving * np.expand_dims((1.0 - gt_stroke_img_large),\n                                                                                     axis=-1) + color_stroke  # (H, W, 3)\n\n                if pen_state == 0:\n                    canvas_color = canvas_color * np.expand_dims((1.0 - gt_stroke_img_large),\n                                                                 axis=-1) + color_stroke  # (H, W, 3)\n\n            # update cursor_pos based on hps.cursor_type\n            new_cursor_offsets = stroke_params[2:4] * (curr_window_size / 2.0)  # (1, 6), patch-level\n            new_cursor_offset_next = new_cursor_offsets\n\n            # important!!!\n            new_cursor_offset_next = np.concatenate([new_cursor_offset_next[1:2], new_cursor_offset_next[0:1]], axis=-1)\n\n            cursor_pos_large = cursor_pos * float(image_size)\n\n            stroke_position_next = cursor_pos_large + new_cursor_offset_next  # (2), large-level\n\n            if cursor_type == 'next':\n                cursor_pos_large = stroke_position_next  # (2), large-level\n            else:\n                raise Exception('Unknown cursor_type')\n\n            cursor_pos_large = np.minimum(np.maximum(cursor_pos_large, 0.0), float(image_size - 1))  # (2), large-level\n            cursor_pos = cursor_pos_large / float(image_size)\n\n        
    frames.append(canvas.copy())\n\n    canvas = np.clip(canvas, 0.0, 1.0)\n    canvas = np.round((1.0 - canvas) * 255.0).astype(np.uint8)  # [0-stroke, 255-BG]\n\n    os.makedirs(save_root, exist_ok=True)\n    save_path = os.path.join(save_root, save_filename)\n    canvas_img = Image.fromarray(canvas, 'L')\n    canvas_img.save(save_path, 'PNG')\n\n    if save_seq:\n        seq_save_root = os.path.join(save_root, 'seq', save_filename[:-4])\n        os.makedirs(seq_save_root, exist_ok=True)\n        for len_i in range(len(frames)):\n            frame = frames[len_i]\n            frame = np.round((1.0 - frame) * 255.0).astype(np.uint8)\n            save_path = os.path.join(seq_save_root, str(len_i) + '.png')\n            frame_img = Image.fromarray(frame, 'L')\n            frame_img.save(save_path, 'PNG')\n\n    if draw_order:\n        order_save_root = os.path.join(save_root, 'order')\n        order_comp_save_root = os.path.join(save_root, 'order-compare')\n        os.makedirs(order_save_root, exist_ok=True)\n        os.makedirs(order_comp_save_root, exist_ok=True)\n\n        canvas_color = 255 - np.round(canvas_color * 255.0).astype(np.uint8)\n        canvas_color_img = Image.fromarray(canvas_color, 'RGB')\n        save_path = os.path.join(order_save_root, save_filename)\n        canvas_color_img.save(save_path, 'PNG')\n\n        canvas_color_with_moving = 255 - np.round(canvas_color_with_moving * 255.0).astype(np.uint8)\n\n        # comparsions\n        rows = 2\n        cols = 3\n        plt.figure(figsize=(5 * cols, 5 * rows))\n\n        plt.subplot(rows, cols, 1)\n        plt.title('Input', fontsize=12)\n        # plt.axis('off')\n        input_rgb = input_img\n        plt.imshow(input_rgb)\n\n        # plt.subplot(rows, cols, 2)\n        # plt.title('GT', fontsize=12)\n        # # plt.axis('off')\n        # gt_rgb = np.stack([gt_img for _ in range(3)], axis=2)\n        # plt.imshow(gt_rgb)\n\n        plt.subplot(rows, cols, 2)\n        plt.title('Sketch', 
fontsize=12)\n        # plt.axis('off')\n        canvas_rgb = np.stack([canvas for _ in range(3)], axis=2)\n        plt.imshow(canvas_rgb)\n\n        plt.subplot(rows, cols, 4)\n        plt.title('Sketch Order', fontsize=12)\n        # plt.axis('off')\n        plt.imshow(canvas_color)\n\n        plt.subplot(rows, cols, 5)\n        plt.title('Sketch Order with moving', fontsize=12)\n        # plt.axis('off')\n        plt.imshow(canvas_color_with_moving)\n\n        plt.subplot(rows, cols, 6)\n        plt.title('Order', fontsize=12)\n        plt.axis('off')\n\n        img_h = 5\n        img_w = 10\n        color_array = np.zeros([len(color_rgb_set) * img_h, img_w, 3], dtype=np.uint8)\n        for i in range(len(color_rgb_set)):\n            color_array[i * img_h: i * img_h + img_h, :, :] = color_rgb_set[i]\n\n        plt.imshow(color_array)\n\n        comp_save_path = os.path.join(order_comp_save_root, save_filename)\n        plt.savefig(comp_save_path)\n        plt.close()\n        # plt.show()\n\n\ndef update_hyperparams(model_params, model_base_dir, model_name, infer_dataset):\n    with tf.gfile.Open(os.path.join(model_base_dir, model_name, 'model_config.json'), 'r') as f:\n        data = json.load(f)\n\n    ignored_keys = ['image_size_small', 'image_size_large', 'z_size', 'raster_perc_loss_layer', 'raster_loss_wk',\n                    'decreasing_sn', 'raster_loss_weight']\n    for name in model_params._hparam_types.keys():\n        if name not in data and name not in ignored_keys:\n            raise Exception(name, 'not in model_config.json')\n\n    assert data['resize_method'] == 'AREA'\n    data['data_set'] = infer_dataset\n    fix_list = ['use_input_dropout', 'use_output_dropout', 'use_recurrent_dropout']\n    for fix in fix_list:\n        data[fix] = (data[fix] == 1)\n\n    pop_keys = ['gpus', 'image_size', 'resolution_type', 'loop_per_gpu', 'stroke_num_loss_weight_end',\n                'perc_loss_fuse_type',\n                'early_pen_length', 
'early_pen_loss_type', 'early_pen_loss_weight',\n                'increase_start_steps', 'perc_loss_layers', 'sn_loss_type', 'photo_prob_end_step',\n                'sup_weight', 'gan_weight', 'base_raster_loss_base_type']\n    for pop_key in pop_keys:\n        if pop_key in data.keys():\n            data.pop(pop_key)\n\n    model_params.parse_json(json.dumps(data))\n\n    return model_params\n"
  },
  {
    "path": "vgg_utils/VGG16.py",
    "content": "import tensorflow as tf\n\n\ndef vgg_net(x, n_classes, img_size, reuse, is_train=True, dropout_rate=0.5):\n    # Define a scope for reusing the variables\n    with tf.variable_scope('VGG16', reuse=reuse):\n        x = tf.reshape(x, [-1, img_size, img_size, 1])\n\n        x = tf.layers.conv2d(inputs=x, filters=64, kernel_size=[3, 3], strides=1,\n                             padding='SAME', activation=tf.nn.relu)\n        x = tf.layers.conv2d(inputs=x, filters=64, kernel_size=[3, 3], strides=1,\n                             padding='SAME', activation=tf.nn.relu)\n        x = tf.layers.max_pooling2d(inputs=x, pool_size=[2, 2], strides=2)\n        print('#1', x.shape)\n\n        x = tf.layers.conv2d(inputs=x, filters=128, kernel_size=[3, 3], strides=1,\n                             padding='SAME', activation=tf.nn.relu)\n        x = tf.layers.conv2d(inputs=x, filters=128, kernel_size=[3, 3], strides=1,\n                             padding='SAME', activation=tf.nn.relu)\n        x = tf.layers.max_pooling2d(inputs=x, pool_size=[2, 2], strides=2)\n        print('#2', x.shape)\n\n        x = tf.layers.conv2d(inputs=x, filters=256, kernel_size=[3, 3], strides=1,\n                             padding='SAME', activation=tf.nn.relu)\n        x = tf.layers.conv2d(inputs=x, filters=256, kernel_size=[3, 3], strides=1,\n                             padding='SAME', activation=tf.nn.relu)\n        x = tf.layers.conv2d(inputs=x, filters=256, kernel_size=[3, 3], strides=1,\n                             padding='SAME', activation=tf.nn.relu)\n        x = tf.layers.max_pooling2d(inputs=x, pool_size=[2, 2], strides=2)\n        print('#3', x.shape)\n\n        x = tf.layers.conv2d(inputs=x, filters=512, kernel_size=[3, 3], strides=1,\n                             padding='SAME', activation=tf.nn.relu)\n        x = tf.layers.conv2d(inputs=x, filters=512, kernel_size=[3, 3], strides=1,\n                             padding='SAME', activation=tf.nn.relu)\n        x = 
tf.layers.conv2d(inputs=x, filters=512, kernel_size=[3, 3], strides=1,\n                             padding='SAME', activation=tf.nn.relu)\n        x = tf.layers.max_pooling2d(inputs=x, pool_size=[2, 2], strides=2)\n        print('#4', x.shape)\n\n        x = tf.layers.conv2d(inputs=x, filters=512, kernel_size=[3, 3], strides=1,\n                             padding='SAME', activation=tf.nn.relu)\n        x = tf.layers.conv2d(inputs=x, filters=512, kernel_size=[3, 3], strides=1,\n                             padding='SAME', activation=tf.nn.relu)\n        x = tf.layers.conv2d(inputs=x, filters=512, kernel_size=[3, 3], strides=1,\n                             padding='SAME', activation=tf.nn.relu)\n        x = tf.layers.max_pooling2d(inputs=x, pool_size=[2, 2], strides=2)\n        print('#5', x.shape)\n\n        x_shape = x.get_shape().as_list()\n        nodes = x_shape[1] * x_shape[2] * x_shape[3]\n        x = tf.reshape(x, [-1, nodes])\n\n        x = tf.layers.dense(x, 4096, activation=tf.nn.relu)\n        if is_train:\n            x = tf.layers.dropout(x, dropout_rate)\n\n        x = tf.layers.dense(x, 4096, activation=tf.nn.relu)\n        if is_train:\n            x = tf.layers.dropout(x, dropout_rate)\n\n        out = tf.layers.dense(x, n_classes)\n        print(out)\n\n    return out\n\n\ndef vgg_net_slim(x, img_size):\n    return_map = {}\n    # Define a scope for reusing the variables\n    with tf.variable_scope('VGG16', reuse=tf.AUTO_REUSE):\n        x = tf.reshape(x, [-1, img_size, img_size, 1])\n\n        x = tf.layers.conv2d(inputs=x, filters=64, kernel_size=[3, 3], strides=1,\n                             padding='SAME', activation=tf.nn.relu)\n        return_map['ReLU1_1'] = x\n        x = tf.layers.conv2d(inputs=x, filters=64, kernel_size=[3, 3], strides=1,\n                             padding='SAME', activation=tf.nn.relu)\n        return_map['ReLU1_2'] = x\n        x = tf.layers.max_pooling2d(inputs=x, pool_size=[2, 2], strides=2)\n        
print('#1', x.shape)  #1 (?, 64, 64, 64)\n\n        x = tf.layers.conv2d(inputs=x, filters=128, kernel_size=[3, 3], strides=1,\n                             padding='SAME', activation=tf.nn.relu)\n        return_map['ReLU2_1'] = x\n        x = tf.layers.conv2d(inputs=x, filters=128, kernel_size=[3, 3], strides=1,\n                             padding='SAME', activation=tf.nn.relu)\n        return_map['ReLU2_2'] = x\n        x = tf.layers.max_pooling2d(inputs=x, pool_size=[2, 2], strides=2)\n        print('#2', x.shape)  #2 (?, 32, 32, 128)\n\n        x = tf.layers.conv2d(inputs=x, filters=256, kernel_size=[3, 3], strides=1,\n                             padding='SAME', activation=tf.nn.relu)\n        return_map['ReLU3_1'] = x\n        x = tf.layers.conv2d(inputs=x, filters=256, kernel_size=[3, 3], strides=1,\n                             padding='SAME', activation=tf.nn.relu)\n        return_map['ReLU3_2'] = x\n        x = tf.layers.conv2d(inputs=x, filters=256, kernel_size=[3, 3], strides=1,\n                             padding='SAME', activation=tf.nn.relu)\n        return_map['ReLU3_3'] = x\n        x = tf.layers.max_pooling2d(inputs=x, pool_size=[2, 2], strides=2)\n        print('#3', x.shape)  #3 (?, 16, 16, 256)\n\n        x = tf.layers.conv2d(inputs=x, filters=512, kernel_size=[3, 3], strides=1,\n                             padding='SAME', activation=tf.nn.relu)\n        return_map['ReLU4_1'] = x\n        x = tf.layers.conv2d(inputs=x, filters=512, kernel_size=[3, 3], strides=1,\n                             padding='SAME', activation=tf.nn.relu)\n        return_map['ReLU4_2'] = x\n        x = tf.layers.conv2d(inputs=x, filters=512, kernel_size=[3, 3], strides=1,\n                             padding='SAME', activation=tf.nn.relu)\n        return_map['ReLU4_3'] = x\n        x = tf.layers.max_pooling2d(inputs=x, pool_size=[2, 2], strides=2)\n        print('#4', x.shape)  #4 (?, 8, 8, 512)\n\n        x = tf.layers.conv2d(inputs=x, filters=512, 
kernel_size=[3, 3], strides=1,\n                             padding='SAME', activation=tf.nn.relu)\n        return_map['ReLU5_1'] = x\n        x = tf.layers.conv2d(inputs=x, filters=512, kernel_size=[3, 3], strides=1,\n                             padding='SAME', activation=tf.nn.relu)\n        return_map['ReLU5_2'] = x\n        x = tf.layers.conv2d(inputs=x, filters=512, kernel_size=[3, 3], strides=1,\n                             padding='SAME', activation=tf.nn.relu)\n        return_map['ReLU5_3'] = x\n        x = tf.layers.max_pooling2d(inputs=x, pool_size=[2, 2], strides=2)\n        print('#5', x.shape)  #5 (?, 4, 4, 512)\n\n    return return_map\n"
  },
  {
    "path": "virtual_sketch_gui.py",
    "content": "import tkinter as tk\nfrom tkinter import filedialog, messagebox\nimport subprocess\nimport os\nimport threading\nimport glob\nimport sys\nimport shutil\n\n# ==== Nastavení cesty ke skriptům ====\nMODEL_OPTIONS = {\n    \"Rough → Clean Sketch\": \"test_rough_sketch_simplification.py\",\n    \"Photo → Line Drawing\": \"test_photograph_to_line.py\",\n    \"Clean Sketch → Vector\": \"test_vectorization.py\",\n}\n\nSVG_CONVERTER = os.path.join(\"tools\", \"svg_conversion.py\")\n\nclass VirtualSketchApp:\n    def __init__(self, root):\n        self.root = root\n        self.root.title(\"Virtual Sketching GUI\")\n        self.input_file = None\n        self.model_script = tk.StringVar(value=list(MODEL_OPTIONS.values())[0])\n\n        self.build_ui()\n\n    def build_ui(self):\n        tk.Label(self.root, text=\"1. Vyber vstupní obrázek:\").pack(anchor=\"w\")\n        tk.Button(self.root, text=\"Vybrat obrázek\", command=self.choose_file).pack(fill=\"x\")\n\n        self.file_label = tk.Label(self.root, text=\"Žádný soubor nevybrán\", fg=\"gray\")\n        self.file_label.pack(anchor=\"w\")\n\n        tk.Label(self.root, text=\"2. Zvol model:\").pack(anchor=\"w\")\n        for name, script in MODEL_OPTIONS.items():\n            tk.Radiobutton(self.root, text=name, variable=self.model_script, value=script).pack(anchor=\"w\")\n\n        tk.Button(self.root, text=\"3. 
Spustit zpracování\", command=self.run_processing).pack(pady=10, fill=\"x\")\n\n        self.status = tk.Label(self.root, text=\"Připraven\", fg=\"green\")\n        self.status.pack(anchor=\"w\")\n\n    def choose_file(self):\n        path = filedialog.askopenfilename(filetypes=[\n            (\"Obrázkové soubory\", \"*.png *.jpg *.jpeg *.bmp *.gif *.tif *.tiff\"),\n            (\"PNG\", \"*.png\"),\n            (\"JPEG\", \"*.jpg;*.jpeg\"),\n            (\"BMP\", \"*.bmp\"),\n            (\"GIF\", \"*.gif\"),\n            (\"TIFF\", \"*.tif;*.tiff\")\n        ])\n        if path:\n            self.input_file = path\n            self.file_label.config(text=os.path.basename(path), fg=\"black\")\n\n    def run_processing(self):\n        if not self.input_file:\n            messagebox.showerror(\"Chyba\", \"Nejprve vyber obrázek.\")\n            return\n\n        script = self.model_script.get()\n        cmd = [sys.executable, script, \"--input\", self.input_file]\n\n        def task():\n            self.status.config(text=\"Zpracovávám...\", fg=\"blue\")\n            try:\n                subprocess.run(cmd, check=True)\n                self.status.config(text=\"✅ Hotovo\", fg=\"green\")\n                self.move_outputs_to_sketches()\n                self.run_svg_conversion()\n            except subprocess.CalledProcessError:\n                self.status.config(text=\"❌ Chyba při běhu skriptu\", fg=\"red\")\n\n        threading.Thread(target=task).start()\n\n    def move_outputs_to_sketches(self):\n        if not self.input_file:\n            return\n\n        input_dir = os.path.dirname(self.input_file)\n        input_base = os.path.splitext(os.path.basename(self.input_file))[0]\n        sketches_dir = os.path.join(input_dir, \"sketches\")\n        os.makedirs(sketches_dir, exist_ok=True)\n\n        for ext in [\"_0.npz\", \"_0_pred.png\", \"_input.png\", \"_0.svg\"]:\n            candidate = os.path.join(input_dir, f\"{input_base}{ext}\")\n            if 
os.path.isfile(candidate):\n                shutil.move(candidate, os.path.join(sketches_dir, os.path.basename(candidate)))\n\n    def run_svg_conversion(self):\n        if not self.input_file:\n            return\n\n        input_dir = os.path.dirname(self.input_file)\n        input_base = os.path.splitext(os.path.basename(self.input_file))[0]\n        npz_file = os.path.join(input_dir, f\"{input_base}_0.npz\")\n        sketches_dir = os.path.join(input_dir, \"sketches\")\n        npz_file_in_sketches = os.path.join(sketches_dir, f\"{input_base}_0.npz\")\n\n        if not os.path.isfile(npz_file_in_sketches):\n            print(\"⚠️ .npz soubor nebyl nalezen pro SVG konverzi.\")\n            return\n\n        cmd = [sys.executable, SVG_CONVERTER, \"--file\", npz_file_in_sketches, \"--svg_type\", \"single\"]\n        try:\n            subprocess.run(cmd, check=True)\n            svg_path = os.path.join(sketches_dir, f\"{input_base}_0.svg\")\n            if os.path.isfile(svg_path):\n                print(f\"✅ SVG vytvořeno: {svg_path}\")\n        except subprocess.CalledProcessError:\n            print(\"⚠️ Chyba při SVG konverzi\")\n\nif __name__ == \"__main__\":\n    root = tk.Tk()\n    app = VirtualSketchApp(root)\n    root.mainloop()\n"
  }
]