[
  {
    "path": ".gitignore",
    "content": ".vscode\nraw\n__pycache__/*\nsftp-config*.json\nmodels\nprocessed\n*.pth\n*.tar\noutput\n*.gz\n.DS_Store/*\nresult\ncheckpoints\n*.pyc\ndata/results\n.DS_Store\nlocal\ncommand.txt\n*.ipynb_checkpoints\n"
  },
  {
    "path": "Dockerfile",
    "content": "FROM nvidia/cuda:10.1-base-ubuntu16.04\nLABEL version=\"1.0\"\nLABEL description=\"Build using the command \\\n  'docker build -t epflvil/xtconsistency:latest .'\"\n\nARG DEFAULT_GIT_BRANCH=master\nARG DEFAULT_GIT_REPO=git@github.com:EPFL-VIL/XTConsistency.git\nARG GITHUB_DEPLOY_KEY_PATH=docker_key\nARG GITHUB_DEPLOY_KEY\nARG GITHUB_DEPLOY_KEY_PUBLIC\n\nRUN apt-get update && apt-get install -y \\\n    curl \\\n    wget \\\n    ca-certificates \\\n    sudo \\\n    git \\\n    unzip \\\n    bzip2 \\\n    libx11-6 \\\n    nano \\\n    screen \\\n    gcc \\\n    python3-dev \\\n && rm -rf /var/lib/apt/lists/*\n\nRUN mkdir /root/.ssh\nRUN echo \"DEPLOY\" \"${GITHUB_DEPLOY_KEY}\"\nRUN echo \"DEPLOY\" \"${GITHUB_DEPLOY_KEY_PUBLIC}\"\nRUN echo \"${GITHUB_DEPLOY_KEY}\" > /root/.ssh/id_rsa\nRUN echo \"${GITHUB_DEPLOY_KEY_PUBLIC}\" > /root/.ssh/id_rsa.pub\nRUN chmod 600 /root/.ssh/id_rsa\nRUN cat /root/.ssh/id_rsa*\nRUN eval $(ssh-agent) && \\\n    ssh-add /root/.ssh/id_rsa && \\\n    ssh-keyscan -H github.com >> /etc/ssh/ssh_known_hosts\nRUN git clone --single-branch --branch \"${DEFAULT_GIT_BRANCH}\" \"${DEFAULT_GIT_REPO}\" /app\n\n#############################\n# Pull code\n#############################\n# RUN mkdir /app\nWORKDIR /app\n\nRUN cd /app && git config core.filemode false\nRUN chmod -R 777 /app\n\n\n#############################\n# Create non-root user\n#############################\n# Create a non-root user and switch to it\nRUN adduser --disabled-password --gecos '' --shell /bin/bash user \\\n && chown -R user:user /app\nRUN echo \"user ALL=(ALL) NOPASSWD:ALL\" > /etc/sudoers.d/90-user\nUSER user\n\n# All users can use /home/user as their home directory\nENV HOME=/home/user\nRUN chmod 777 /home/user\n\n\n#############################\n# Create conda environment\n#############################\n# Install Miniconda\nRUN curl -Lso ~/miniconda.sh https://repo.continuum.io/miniconda/Miniconda3-4.5.11-Linux-x86_64.sh \\\n && chmod +x ~/miniconda.sh 
\\\n && ~/miniconda.sh -b -p ~/miniconda \\\n && rm ~/miniconda.sh\nENV PATH=/home/user/miniconda/bin:$PATH\nENV CONDA_AUTO_UPDATE_CONDA=false\n\n# Create a Python 3.6 environment\nRUN /home/user/miniconda/bin/conda create -y --name py36 python=3.6.9 \\\n && /home/user/miniconda/bin/conda clean -ya\nENV CONDA_DEFAULT_ENV=py36\nENV CONDA_PREFIX=/home/user/miniconda/envs/$CONDA_DEFAULT_ENV\nENV PATH=$CONDA_PREFIX/bin:$PATH\nRUN /home/user/miniconda/bin/conda install conda-build=3.18.9=py36_3 \\\n && /home/user/miniconda/bin/conda clean -ya\n\n\n#############################\n# Python packages\n#############################\nRUN conda install -y -c pytorch \\\n    cudatoolkit=10.1 \\\n    \"pytorch=1.4.0\" \\\n    \"torchvision=0.5.0\" \\\n  && conda clean -ya\nRUN conda install -y \\\n  ipython==6.5.0 \\\n  matplotlib==3.0.3 \\\n  plac==0.9.6 \\\n  py==1.6.0 \\\n  scipy==1.3.1 \\\n  tqdm==4.36.1 \\\n  pathlib==1.0.1 \\\n  seaborn==0.10.0 \\\n  scikit-learn==0.22.1 \\\n  scikit-image==0.16.2 \\\n && conda clean -ya\nRUN conda install -c conda-forge jupyterlab && conda clean -ya\nRUN pip install runstats==1.8.0 \\\n  fire==0.2.1 \\\n  visdom==0.1.8.9 \\\n  parse==1.12.1\n\n  \n###############################################\n# Default command and environment variables\n###############################################\nRUN sudo touch /root/.bashrc && sudo chmod 770 /root/.bashrc\nRUN echo export PATH=\"\\$PATH:\"$PATH >> /tmp/.bashrc\nRUN sudo su -c 'cat /tmp/.bashrc >> /root/.bashrc' && rm /tmp/.bashrc\n\n# Set the default command to bash\nCMD [\"bash\"]\n"
  },
  {
    "path": "README.md",
    "content": "# Robust Learning Through Cross-Task Consistency <br> \n\n[![](./assets/intro.jpg)](https://consistency.epfl.ch)\n\n<table>\n      <tr><td><em>Above: A comparison of the results from consistency-based learning and learning each task individually. The yellow markers highlight the improvement in fine grained details.</em></td></tr>\n</table>\n\n<br>\nThis repository contains tools for training and evaluating models using consistency:\n\n- [Pretrained models](#pretrained-models)\n- [Demo code](#quickstart-run-demo-locally) and an **[online live demo](https://consistency.epfl.ch/demo/)**\n- [_Uncertainty energy_ estimation code](#Energy-computation)\n- [Training scripts](#training)\n- [Docker and installation instructions](#installation)\n\nfor the following paper:\n<!-- <br><a href=https://consistency.epfl.ch>Robust Learing Through Cross-Task Consistency</a> (CVPR 2020, Oral).<br> -->\n<!-- Amir Zamir, Alexander Sax, Teresa Yeo, Oğuzhan Kar, Nikhil Cheerla, Rohan Suri, Zhangjie Cao, Jitendra Malik, Leonidas Guibas  -->\n\n<div style=\"text-align:center\">\n<h4><a href=https://consistency.epfl.ch>Robust Learing Through Cross-Task Consistency</a> (CVPR 2020, Best Paper Award Nomination, Oral)</h4>\n</div>\n<br>\n\n[![Cross-Task Consistency Results](./assets/vid_thumbnail_600_gif2.gif)](https://youtu.be/6CSmmrBNX9M \"Click to watch results of applying the networks trained with cross-task consisntency frame-by-frame on sample YouTube videos.\")\n\n\nFor further details, a [live demo](https://consistency.epfl.ch/demo/), [video visualizations](https://consistency.epfl.ch/visuals/), and an [overview talk](https://consistency.epfl.ch/#paper), refer to our [project website](https://consistency.epfl.ch/).\n\n#### PROJECT WEBSITE:\n<div style=\"text-align:center\">\n\n| [LIVE DEMO](https://consistency.epfl.ch/demo/) | [VIDEO VISUALIZATION](https://consistency.epfl.ch/visuals/) \n|:----:|:----:|\n| Upload your own images and see the results of different 
consistency-based models vs. various baselines.<br><br>[<img src=./assets/screenshot-demo.png width=\"400\">](https://consistency.epfl.ch/demo/) | Visualize models with and without consistency, evaluated on a (non-cherry picked) YouTube video.<br><br><br>[<img src=./assets/output_video.gif width=\"400\">](https://consistency.epfl.ch/visuals/) |\n\n</div>\n\n---\n\nTable of Contents\n=================\n\n   * [Introduction](#introduction)\n   * [Installation](#installation)\n   * [Quickstart (demo code)](#quickstart-run-demo-locally)\n   * [Energy computation](#energy-computation)\n   * [Download all pretrained models](#pretrained-models)\n   * [Train a consistency model](#training)\n     * [Instructions for training](#steps)\n     * [To train on other configurations](#to-train-on-other-target-domains)\n   * [Citing](#citation)\n\n<br>\n\n## Introduction \n\nVisual perception entails solving a wide set of tasks (e.g. object detection, depth estimation, etc). The predictions made for each task out of a particular observation are not independent, and therefore, are expected to be **consistent**.\n\n**What is consistency?** Suppose an object detector detects a ball in a particular region of an image, while a depth estimator returns a flat surface for the same region. This presents an issue -- at least one of them has to be wrong, because they are _inconsistent_.\n\n**Why is it important?** \n1. Desired learning tasks are usually predictions of different aspects of a single underlying reality (the scene that underlies an image). Inconsistency among predictions implies contradiction. \n2. Consistency constraints are informative and can be used to better fit the data or lower the sample complexity. They may also reduce the tendency of neural networks to learn \"surface statistics\" (superficial cues) by enforcing constraints rooted in different physical or geometric rules. 
This is empirically supported by the improved generalization of models when trained with consistency constraints.\n\n**How do we enforce it?** The underlying concept is that of path independence in a network of tasks. Given an endpoint `Y2`, the path from \n`X->Y1->Y2` should give the same results as `X->Y2`. This can be generalized to a larger system, with paths of arbitrary lengths. In this case, the nodes of the graph are our prediction domains (eg. depth, normal) and the edges are neural networks mapping these domains.\n\nThis repository includes [training](#training) code for enforcing cross task consistency, [demo](#quickstart-run-demo-locally) code for visualizing the results of a consistency trained model on a given image and [links](#pretrained-models) to download these models. For further details, refer to our [paper]() or [website](https://consistency.epfl.ch/).\n\n\n#### Consistency Domains\n\nConsistency constraints can be used for virtually any set of domains. This repository considers transferring between image domains, and our networks were trained for transferring between the following domains from the [Taskonomy dataset](https://github.com/StanfordVL/taskonomy/tree/master/data).\n\n    Curvature         Edge-3D            Reshading\n    Depth-ZBuffer     Keypoint-2D        RGB       \n    Edge-2D           Keypoint-3D        Surface-Normal \n\n\nThe repo contains consistency-trained models for `RGB -> Surface-Normals`,  `RGB -> Depth-ZBuffer`, and `RGB -> Reshading`. In each case the remaining 7 domains are used as consistency constraints during training.\n\nDescriptions for each domain can be found in the [supplementary file](http://taskonomy.stanford.edu/taskonomy_supp_CVPR2018.pdf) of Taskonomy.\n\n#### Network Architecture\n\nAll networks are based on the [UNet](https://arxiv.org/pdf/1505.04597.pdf) architecture. They take in an input size of 256x256, upsampling is done via bilinear interpolations instead of deconvolutions and trained with the L1 loss. 
See the table below for more information.\n\n|        Task Name        | Output Dimension | Downsample Blocks |\n|-------------------------|------------------|-------------------|\n| `RGB -> Depth-ZBuffer`  | 256x256x1        | 6                 |\n| `RGB -> Reshading`      | 256x256x1        | 5                 |\n| `RGB -> Surface-Normal` | 256x256x3        | 6                 |\n\nOther networks (e.g. `Curvature -> Surface-Normal`) use a UNet, their architecture hyperparameters are detailed in [transfers.py](./transfers.py).\n\nMore information on the models, including download links, can be found [here](#pretrained-models) and in the [supplementary material](https://consistency.epfl.ch/supplementary_material).\n\n<br>\n<br>\n\n## Installation\n\nThere are two convenient ways to run the code. Either using Docker (recommended) or using a Python-specific tool such as pip, conda, or virtualenv.\n\n#### Installation via Docker [Recommended]\n\nWe provide a docker that contains the code and all the necessary libraries. It's simple to install and run.\n1. Simply run:\n<!-- docker pull epflvilab/xtconsistency:latest -->\n```bash\ndocker run --runtime=nvidia -ti --rm epflvilab/xtconsistency:latest\n```\nThe code is now available in the docker under your home directory (`/app`), and all the necessary libraries should already be installed in the docker.\n\n#### Installation via Pip/Conda/Virtualenv\nThe code can also be run using a Python environment manager such as Conda. See [requirements.txt](./requirements.txt) for complete list of packages. We recommend doing a clean installation of requirements using virtualenv:\n1.  Clone the repo:\n```bash\ngit clone git@github.com:EPFL-VILAB/XTConsistency.git\ncd XTConsistency\n```\n\n2. 
Create a new environment and install the libraries:\n```bash\nconda create -n testenv -y python=3.6\nsource activate testenv\npip install -r requirements.txt\n```\n\n\n<br>\n<br>\n\n## Quickstart (Run Demo Locally)\n\n#### Download the consistency trained networks\nIf you haven't yet, then download the [pretrained models](#Download-consistency-trained-models). Models used for the demo can be downloaded with the following command:\n```bash\nsh ./tools/download_models.sh\n```\n\nThis downloads the `baseline`, `consistency` trained models for `depth`, `normal` and `reshading` target (1.3GB) to a folder called `./models/`. Individual models can be downloaded [here](https://drive.switch.ch/index.php/s/QPvImzbbdjBKI5P).\n\n#### Run a model on your own image\n\nTo run the trained model of a task on a specific image:\n\n```bash\npython demo.py --task $TASK --img_path $PATH_TO_IMAGE_OR_FOLDER --output_path $PATH_TO_SAVE_OUTPUT\n```\n\nThe `--task` flag specifies the target task for the input image, which should be either `normal`, `depth` or `reshading`.\n\nTo run the script for a `normal` target on the [example image](./assets/test.png):\n\n```bash\npython demo.py --task normal --img_path assets/test.png --output_path assets/\n```\n\nIt returns the output prediction from the baseline (`test_normal_baseline.png`) and consistency models (`test_normal_consistency.png`).\n\nTest image                 |  Baseline\t\t\t|  Consistency\n:-------------------------:|:-------------------------: |:-------------------------:\n![](./assets/test_scaled.png)|  ![](./assets/test_normal_baseline.png) |  ![](./assets/test_normal_consistency.png)\n\n\nSimilarly, running for target tasks `reshading` and `depth` gives the following.\n\n  Baseline (reshading)      |  Consistency (reshading)   |  Baseline (depth)\t       |  Consistency (depth)\n:-------------------------: |:-------------------------: | :-------------------------: 
|:-------------------------:\n![](./assets/test_reshading_baseline.png) |  ![](./assets/test_reshading_consistency.png) | ![](./assets/test_depth_baseline.png) |  ![](./assets/test_depth_consistency.png)\n\n\n\n<br>\n<br>\n\n## Energy Computation\n\nTraining with consistency involves several paths that each predict the target domain, but using different cues to do so. The disagreement between these predictions yields an unsupervised quantity, _consistency energy_, that our CVPR 2020 paper found correlates with prediction error. You can view the pixel-wise _consistency energy_ (example below) using our [live demo](https://consistency.epfl.ch/demo/).\n\n\n|             Sample Image             |             Normal Prediction             |             Consistency Energy             |\n|:------------------------------------:|:------------------------------------:|:------------------------------------:|\n| <img src=./assets/energy_query.png width=\"600\">  | <img src=./assets/energy_normal_prediction.png width=\"600\">  | <img src=./assets/energy_prediction.png width=\"600\">  |\n| _Sample image from the Stanford 2D3DS dataset._  | _Some chair legs are missing in the `RGB -> Normal` prediction._  |  _The white pixels indicate higher uncertainty about areas with missing chair legs._  | \n\n\nTo compute energy locally, over many images, and/or to plot energy vs error, you can use the following `energy_calc.py` script. For example, to reproduce the following scatterplot using `energy_calc.py`:\n\n|             Energy vs. Error             |\n|:----------------------------------------:|\n| ![](./assets/energy_vs_error.jpg)        |\n| _Result from running the command below._ | \n\n\n\nFirst download a subset of images from the Taskonomy buildings `almena` and `albertville` (512 images per domain, 388MB):\n```bash\nsh ./tools/download_data.sh\n```\n\n\nSecond, download all the networks necessary to compute the consistency energy. 
The following script will download them for you (skipping previously downloaded models) (0.8GB - 4.0GB):\n```bash\nsh ./tools/download_energy_graph_edges.sh\n```\n\n\nNow we are ready to compute energy. The following command generates a scatter plot of _consistency energy_ vs. prediction error:\n\n```bash\npython -m scripts.energy_calc energy_calc --batch_size 2 --subset_size=128 --save_dir=results\n```\n\n\nBy default, it computes the energy and error of the `subset_size` number of points on the Taskonomy buildings `almena` and `albertville`. The error is computed for the `normal` target. The resulting plot is saved to `energy.pdf` in `RESULTS_DIR` and the corresponding data to `data.csv`. \n\n#### Compute energy on arbitrary images\n_Consistency energy_ is an unsupervised quantity and as such, no ground-truth labels are necessary. To compute the energy for all query images in a directory, run:\n\n```bash\npython -m scripts.energy_calc energy_calc_nogt \\\n    --data-dir=PATH_TO_QUERY_IMAGE --batch_size 1 --save_dir=RESULTS_DIR \\\n    --subset_size=NUMBER_OF_IMAGES --cont=PATH_TO_TRAINED_MODEL\n```\n\nIt will append a dashed horizontal line to the plot above where the energy of the query image(s) are. This plot is saved to `energy.pdf` in `RESULTS_DIR`.\n\n\n<br>\n<br>\n\n## Pretrained Models\n\nWe are providing all of our pretrained models for download. These models are the same ones used in the [live demo](https://consistency.epfl.ch/demo/) and [video evaluations](https://consistency.epfl.ch/visuals/).\n\n\n#### Network Architecture\nAll networks are based on the [UNet](https://arxiv.org/pdf/1505.04597.pdf) architecture. They take in an input size of 256x256, upsampling is done via bilinear interpolations instead of deconvolutions. 
All models were trained with the L1 loss.\n\n\n#### Download consistency-trained models\nInstructions for downloading the trained consistency models can be found [here](#download-consistency-trained-networks)\n```bash\nsh ./tools/download_models.sh\n```\n\nThis downloads the `baseline`, `consistency` trained models for `depth`, `normal` and `reshading` target (1.3GB) to a folder called `./models/`. See the table below for specifics:\n\n|        Task Name        | Output Dimension | Downsample Blocks |\n|-------------------------|------------------|-------------------|\n| `RGB -> Depth-ZBuffer`  | 256x256x1        | 6                 |\n| `RGB -> Reshading`      | 256x256x1        | 5                 |\n| `RGB -> Surface-Normal` | 256x256x3        | 6                 |\n\nIndividual consistency models can be downloaded [here](https://drive.switch.ch/index.php/s/QPvImzbbdjBKI5P).\n\n\n\n#### Download perceptual networks\nThe pretrained perceptual models can be downloaded with the following command.\n\n```bash\nsh ./tools/download_percep_models.sh\n```\n\nThis downloads the perceptual models for the `depth`, `normal` and `reshading` target (1.6GB). Each target has 7 pretrained models (from the other sources below).\n\n```\nCurvature         Edge-3D            Reshading\nDepth-ZBuffer     Keypoint-2D        RGB       \nEdge-2D           Keypoint-3D        Surface-Normal \n```\n\nPerceptual model architectural hyperparameters are detailed in [transfers.py](./transfers.py), and some of the pretrained models were trained using L2 loss. For using these models with the provided training code, the pretrained models should be placed in the file path defined by `MODELS_DIR` in [utils.py](./utils.py#L25).\n\nIndividual perceptual models can be downloaded [here](https://drive.switch.ch/index.php/s/aXu4EFaznqtNzsE).\n\n\n\n#### Download baselines\nWe also provide the models for other baselines used in the paper. 
Many of these baselines appear in the [live demo](https://consistency.epfl.ch/demo/). The pretrained baselines can be downloaded [here](https://drive.switch.ch/index.php/s/gdom4FpiiYo1Qay). Note that we will not be providing support for them. \n- A full list of baselines is in the table below:\n   |                     Baseline Method                     |                                                       Description                                                              |      Tasks (RGB -> X)    | \n   |---------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------|--------------------------|\n   | Baseline UNet [[PDF](https://arxiv.org/pdf/1505.04597.pdf)]   | UNets trained on the Taskonomy dataset.                                                                                  | Normal, Reshade, Depth      | \n   | Baseline Perceptual Loss                                | Trained using a randomly initialized percepual network, similar to [RND](http://arxiv.org/pdf/1810.12894.pdf).                 | Normal                      | \n   | Cycle Consistency [[PDF](https://arxiv.org/pdf/1703.10593.pdf)] | A CycleGAN trained on the Taskonomy dataset.                                                                           | Normal                      | \n   | GeoNet [[PDF](https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=8578135_)] | Trained on the Taskonomy dataset and using L1 instead of L2 loss.                                    | Normal, Depth            |\n   | Multi-Task [[PDF](http://arxiv.org/pdf/1609.02132.pdf)] | A multi-task model we trained using UNets, using a shared encoder (similar to [here](http://arxiv.org/pdf/1609.02132.pdf))     | [All](#consistency-domains) |\n   | Pix2Pix [[PDF](https://arxiv.org/pdf/1611.07004.pdf)]   | A Pix2Pix model trained on the Taskonomy dataset.                                    
                                          | Normal                      | \n   | Taskonomy [[PDF](https://arxiv.org/pdf/1804.08328.pdf)] | Pretrained models (converted to pytorch [here](https://github.com/alexsax/midlevel-reps/tree/master#using-mid-level-perception-in-your-code-)), originally trained [here](https://github.com/StanfordVL/taskonomy/tree/master/taskbank).  | Normal, Reshading, Depth* |\n*Models for other tasks are available using the [`visualpriors` package](https://github.com/alexsax/midlevel-reps/tree/master#using-mid-level-perception-in-your-code-) or in Tensorflow via the [Taskonomy GitHub page](https://github.com/alexsax/midlevel-reps/tree/master#using-mid-level-perception-in-your-code-).\n\n\n\n\n\n\n<br>\n<br>\n\n## Training\n\nWe used the provided training code to train our consistency models on the Taskonomy dataset. We used 3 V100 (32GB) GPUs to train our models, running them for 500 epochs takes about a week. \n\n> **Runnable Example:** \n   You'll find that the code in the rest of this section expects about 12TB of data (9 single-image tasks from Taskonomy). For a quick runnable example that gives the gist, try the following:  \n>     \n>  First download the data and then start a visdom (logging) server:\n>  ```bash\n>  sh ./tools/download_data.sh # Starter data (388MB)\n>  visdom &                    # To view the telemetry\n>  ```\n>  \n>  Then, start the training using the following command, which cascades two models (trains a `normal` model using `curvature` consistency on a training set of 512 images).  \n>   ```bash\n>   python -m train example_cascade_two_networks --k 1 --fast\n>   ```\n>   You can add more perceptual losses by changing the config in `energy.py`. 
For example, train the above model using _both_ `curvature` and `2D edge` consistency:\n>   ```bash\n>   python -m train example_normal --k 2 --fast\n>   ```\n\nAssuming that you want to train on the full dataset or [on your own dataset], read on.\n#### The code is structured as follows\n```python\nconfig/             # Configuration parameters: where to save results, etc.\n    split.txt           # Train, val split\n    jobinfo.txt         # Defines job name, base_dir\nmodules/            # Network definitions\ntrain.py            # Training script\ndataset.py          # Creates dataloader\nenergy.py           # Defines path config, computes total loss, logging\nmodels.py           # Implements forward backward pass\ngraph.py            # Computes path defined in energy.py\ntask_configs.py     # Defines task specific preprocessing, masks, loss fn\ntransfers.py        # Loads models\nutils.py            # Defines file paths (described below) \ndemo.py             # Demo script\n```\n\n#### Expected folder structure\nThe code expects folders structured as follows. These can be modified by changing values in `utils.py`\n```python\nbase_dir/                   # The following paths are defined in utils.py (BASE_DIR)\n    shared/                 # with the corresponding variable names in brackets\n        models/             # Pretrained models (MODELS_DIR)\n        results_[jobname]/  # Checkpoint of model being trained (RESULTS_DIR)\n        ood_standard_set/   # OOD data for visualization (OOD_DIR)\n    data_dir/               # taskonomy data (DATA_DIRS)\n```\n\n#### Training with consistency\n\n   \n1) **Define locations for data, models, etc.:** Create a `jobinfo.txt` file and define the name of the job and the absolute path to `BASE_DIR` where data, models and results would be stored, as shown in the folder structure above. An example config is provided in the starter code (`config/jobinfo.txt`). To modify individual file paths eg. 
the models folder, change `MODELS_DIR` variable name in [utils.py](./utils.py#L25).  \n  \n   We won't cover downloading the Taskonomy dataset, which can be downloaded following the instructions [here](https://github.com/StanfordVL/taskonomy/tree/master/data)\n\n2) **Download perceptual networks:** If you want to initialize from our pretrained models, then [download them](#Download-perceptual-networks) with the following command (1.6GB): \n    ```bash\n    sh ./tools/download_percep_models.sh\n    ```\n    More info about the networks is available [here](#Download-perceptual-networks).\n\n3) **Train with consistency** using the command:\n\n   ```bash\n   python -m train multiperceptual_{depth,normal,reshading}\n   ```\n\n   For example, to run the training code for the `normal` target, run \n\n   ```bash\n   python -m train multiperceptual_normal\n   ```\n\n   This trains the model for the `normal` target with 8 perceptual losses ie. `curvature`, `edge2d`, `edge3d`, `keypoint2d`, `keypoint3d`, `reshading`, `depth` and `imagenet`. We used 3 V100 (32GB) GPUs to train our models, running them for 500 epochs takes about a week.\n\n   Additional arguments can be specified during training, the most commonly used ones are listed below. For the full list, refer to the [training script](./train.py).\n   - The flag `--k` defines the number of perceptual losses used, thus reducing GPU memory requirements.\n   - There are several options for choosing how this subset is chosen 1. randomly (`--random-select`) 2. winrate (`--winrate`)\n   - Data augmentation is not done by default, it can be added to the training data with the flag `--dataaug`. The transformations applied are 1. random crop with probability 0.5 2. 
[color jitter](https://pytorch.org/docs/stable/torchvision/transforms.html?highlight=color%20jitter#torchvision.transforms.ColorJitter) with probability 0.5.\n\n   To train a `normal` target domain with 2 perceptual losses selected randomly each epoch, run the following command.\n\n   ```bash\n   python -m train multiperceptual_normal --k 2 --random-select\n   ```\n\n4) **Logging:** The losses and visualizations are logged in Visdom. This can be accessed via `[server name]/env/[job name]` eg. `localhost:8888/env/normaltarget_allperceps`. \n\n   An example visualization is shown below. We plot the outputs from the paths defined in the energy configuration used. Two windows are shown, one shows the predictions before training starts, the other updates them after each epoch. The labels for each column can be found at the top of the window. The second column has the target's ground truth `y^`, the third its prediction `n(x)` from the RGB image `x`. Thereafter, the predictions of each pair of images with the same domain are given by the paths `f(y^),f(n(x))`, where `f` is from the target domain to another domain eg. `curvature`.\n\n   ![](./assets/visdom_eg.png)\n\n   **Logging conventions:** For uninteresting historical reasons, the columns in the logging during training might have strange names. You can define your own names instead of using these by changing the config file in `energy.py`.\n   \n   Here's a quick guide to the current convention. For example, when training with a `normal` model using consistency:\n    - The RGB input is denoted as `x` and the `target` domain is denoted as `y`. The ground truth label for a domain is marked with a `^`(e.g. `y^` for the `target` domain). \n    - The direct (`RGB -> Z`) and perceptual (`target [Y] -> Z`) transfer functions are named as follows:<br>(i.e. 
the function for `rgb` to `curvature` is `RC`; for `normal` to `curvature` it's `f`)\n\n   |  Domain (Z) | `rgb -> Z`<br>(Direct) | `Y -> Z`<br>(Perceptual) ||    Domain (Z)   | `rgb -> Z`<br>(Direct) | `Y -> Z`<br>(Perceptual) |\n   |-------------|------------------------|--------------------------|-|-----------------|------------------------|---------------------------|\n   | target      | n                      | -                        || keypoints2d     | k2                     | Nk2                       |\n   | curvature   | RC                     | f                        || keypoints3d     | k3                     | Nk3                       |\n   | sobel edges | a                      | s                        || edge occlusion  | E0                     | nE0                       |\n\n#### To train on other target domains\n1. A new configuration should be defined in the `energy_configs` dictionary in [energy.py](./energy.py#L39-L521). \n\n   Description of the information needed:\n   - `paths`: `X1->X2->X3`. The keys in this dictionary uses a function notation eg. `f(n(x))`, with its corresponding value being a list of task objects that defines the domains being transferred eg. `[rgb, normal, curvature]`. The `rgb` input is defined as `x`, `n(x)` returns `normal` predictions from `rgb`, and `f(n(x))` returns `curvature` from `normal`. These notations do not need to be same for all configurations. The [table](#function-definitions) below lists those that have been kept constant for all targets.\n   - `freeze_list`: the models that will not be optimized,\n   - `losses`: loss terms to be constructed from the paths defined above,\n   - `plots`: the paths to plots in the visdom environment.\n\n2. New models may need to be defined in the `pretrained_transfers` dictionary in [transfers.py](./transfers.py#L26-L97). 
For example, for a `curvature` target, and perceptual model `curvature` to `normal`, the code will look for the `principal_curvature2normal.pth` file in `MODELS_DIR` if it is not defined in [transfers.py](./transfers.py#L26-L97).\n\n\n\n#### To train on other datasets\nThe expected folder structure for the data is,\n```\nDATA_DIRS/\n  [building]_[domain]/\n      [domain]/\n          [view]_domain_[domain].png\n          ...\n```\nPytorch's dataloader _\\_\\_getitem\\_\\__ method has been overwritten to return a tuple of all tasks for a given building and view point. This is done in [datasets.py](./datasets.py#L181-L198). Thus, for other folder structures, a function to get the corresponding file paths for different domains should be defined. \n\nTask-specific configs, like transformations and masks, are defined in [task_configs.py](./task_configs.py#L341-L373).\n\n<br>\n<br>\n\n## Citation\nIf you find the code, models, or data useful, please cite this paper:\n\n```\n@article{zamir2020consistency,\n  title={Robust Learning Through Cross-Task Consistency},\n  author={Zamir, Amir and Sax, Alexander and Yeo, Teresa and Kar, Oğuzhan and Cheerla, Nikhil and Suri, Rohan and Cao, Zhangjie and Malik, Jitendra and Guibas, Leonidas},\n  journal={arXiv},\n  year={2020}\n}\n```\n\n"
  },
  {
    "path": "config/jobinfo.txt",
    "content": "normaltarget_allperceps, .\n"
  },
  {
    "path": "config/split.txt",
    "content": "train_buildings: [woodbine, hometown, haymarket, emmaus, swormville, haxtun, martinville,\n  winfield, marksville, hammon, mammoth, kronborg, cobalt, lenoir, bonnie, bautista,\n  retsof, azusa, munsons, darrtown, michiana, uncertain, lajas, plessis, bellemeade,\n  jacobus, swisshome, idanha, lathrup, hambleton, country, ancor, haaswood, orason,\n  cisne, byers, muleshoe, fredericksburg, cayuse, elmira, adrian, merchantville, german,\n  waipahu, clarkridge, rogue, ludlowville, cochranton, colebrook, lovilia, mullica,\n  mobridge, kevin, yscloskey, laytonsville, convoy, sisters, cosmos, calavo, clairton,\n  crookston, nicut, ladue, anaheim, howie, codell, seatonville, brown, quantico, goodview,\n  parole, kingfisher, churchton, edgemere, micanopy, bettendorf, goffs, windhorst,\n  ooltewah, tokeland, darden, benicia, american, neibert, milaca, willow, samuels,\n  sanctuary, pablo, cantwell, biltmore, germfask, eastville, rough, westfield, wainscott,\n  mahtomedi, kopperl, gluck, mogote, cason, hercules, avonia, pettigrew, ewell, imbery,\n  collierville, maguayo, brentsville, mckeesport, hitchland, browntown, aldine, spotswood,\n  chrisney, kildare, stockman, ribera, hortense, mentasta, soldier, kettle, northgate,\n  crandon, ovalo, cullison, dedham, vails, tallmadge, gratz, fonda, helton, mogadore,\n  highspire, sharon, foyil, shelbiana, seiling, brevort, noonday, springhill,\n  coronado, sundown, carneiro, silas, assinippi, monticello, wappingers,\n  lakeville, stockwell, ogilvie, victorville, braxton, sodaville, silerton, hobson,\n  tradewinds, sands, coffeen, umpqua, blackstone, sarcoxie, model, bremerton, capistrano,\n  deatsville, graceville, dansville, belpre, edson, mcnary, kirwin, rosenberg, lynchburg,\n  ranchester, shingler, auburn, connellsville, alstown, kerrtown, marstons, hurley,\n  mifflintown, pamelia, sumas, chilhowie, dryville, deemston, cashel, galatia, harrellsville,\n  mcdade, eudora, sasakwa, baneberry, rosser, halfway, 
hainesburg, gravelly, frierson,\n  tyler, irvine, natural, murchison, lindsborg, duarte, wando, globe, neshkoro, cornville,\n  bowmore, roxboro, espanola, maugansville, yankeetown, sawpit, schoolcraft, klickitat,\n  scandinavia, donaldson, aloha, gaylord, hartline, laupahoehoe, wiconisco, mesic,\n  eagerville, keiser, potosi, wyldwood, macarthur, newfields, moberly, everton, lindenwood,\n  nuevo, bethlehem, silva, noxapater, lindberg, hornsby, weleetka, tysons, kremlin,\n  jenners, trail, freedom, mcclure, ruckersville, sugarville, nemacolin, athens, vacherie,\n  checotah, blenheim, allensville, grantsville, holcut, hallettsville, angiola, tomales,\n  grangeville, seward, fishersville, kendall, kangley, wilbraham, caruthers, hacienda,\n  readsboro, pocopson, bonfield, cohoes, inkom, monson, peacock, touhy, divide, norvelt,\n  badger, leilani, corozal, warrenville, lluveras, grigston, cooperstown, nimmons,\n  ewansville, paige, matoaca, lessley, purple, kihei, millbury, culbertson, maunawili,\n  brewton, maryhill, channel, branford, creede, goodfield, spencerville, kirksville,\n  cokeville, barahona, leonardo, mosinee, tolstoy, broseley, broadwell, landing, roeville,\n  hatfield, rancocas, mcewen, annona, okabena, terrell, barboursville, booth, hanson,\n  mashulaville, cutlerville, euharlee, rabbit, fitchburg, shellsburg, milford, grassy,\n  timberon, coeburn, wilkinsburg, lynxville, islandton, arbutus, reyno, wakeman, frankfort,\n  sontag, voorhees, beechwood, ossipee, rockport, starks, woonsocket, hildebran, circleville,\n  aldrich, sunshine, destin, chesterbrook, musicks, merom, kinde, andover, pittsburg,\n  scioto, tilghmanton, castor, potterville, onaga, stanleyville, leavittsburg, carpendale,\n  bountiful, kingdom, cebolla, sweatman, arona, sagerton, herricks, morris, montreal,\n  stokes, newcomb, adairsville, bertram, kinney, spread, akiak, westerville, texasville,\n  springerville, almota, aulander, superior, goodyear, cabin, random, ballou, southfield,\n  
ballantine, glenmoor, oriole, ashport, denmark, winthrop, bohemia, yadkinville,\n  smoketown, waucousta, winooski, peden, mayesville, liddieville, clive, gluek, goodwine,\n  uvalda, pleasant, losantville, lineville, hillsdale, ackermanville, waukeenah, mentmore,\n  glassboro, bellwood, peconic, pinesdale, hordville, wells, hendrix, dunmor, fleming,\n  mccloud, reserve, gilbert, bonesteel, roane, pocasset, greigsville, delton, whitethorn,\n  frontenac, siren, artois, helix, melstone, sultan, shumway, seeley, cousins, cauthron,\n  gough, anthoston, gladstone, macland, hiteman, shauck, jennie, airport, funkstown,\n  markleeville, marland, gloria, poyen, annawan, bolton, wattsville, waldenburg, pearce,\n  maiden, gasburg, calmar, applewold, kemblesville, redbank, wilseyville, lucan, archer,\n  castroville, pasatiempo, arkansaw, wyatt, shelbyville, merlin, whiteriver, torrington,\n  oyens]\nval_buildings: [cottonport, elton, corder, mazomanie, barranquitas, ihlen, wilkesboro,\n  macedon, portal, gastonia, thrall, orangeburg, poipu, kankakee, chiloquin, sussex,\n  maricopa, wesley, bowlus, copemish, tariffville, pomaria, kathryn, rutherford, plumerville,\n  waimea, experiment, dalcour, ohoopee, mifflinburg, callicoon, manassas, macksville,\n  apache, alfred, maida, dauberville, chireno, stilwell, albertville, ellaville,\n  kobuk, burien, carpio, placida, forkland, tippecanoe, beach, eagan, grace, portola,\n  hominy, maben]\n"
  },
  {
    "path": "config/split_fullplus.txt",
    "content": "train_buildings: [hanson, merom, arbutus, goodfield, eagan, arona, adairsville, reserve, aloha, castor, munsons, ballou, woonsocket, foyil, creede, superior, klickitat, cottonport, rancocas, ossipee, haxtun, tyler, sugarville, martinville, haaswood, euharlee, hacienda, peacock, convoy, roane, ellaville, chilhowie, tomales, springhill, shellsburg, chrisney, rutherford, mullica, winthrop, grace, codell, silas, braxton, chireno, bremerton, lynchburg, halfway, ballantine, hendrix, tokeland, merlin, baneberry, nimmons, mccloud, onaga, aldrich, maguayo, bountiful, carpio, ovalo, waimea, grassy, monticello, portal, melstone, cutlerville, frankfort, gilbert, kinney, gough, quantico, goodyear, waukeenah, pocopson, terrell, andover, ackermanville, copemish, leavittsburg, monson, applewold, sasakwa, anthoston, gaylord, whitethorn, clarkridge, gladstone, cornville, hatfield, circleville, marksville, trail, chesterbrook, gloria, funkstown, tippecanoe, airport, mesic, stokes, rogue, bethlehem, mcewen, hobson, kirwin, glenmoor, ancor, vails, biltmore, kangley, voorhees, leonardo, kevin, maida, castroville, seeley, millbury, waucousta, dansville, ludlowville, cooperstown, elton, marstons, athens, mentmore, inkom, model, ruckersville, hambleton, newfields, broseley, bettendorf, irvine, pinesdale, tilghmanton, american, westerville, maunawili, reyno, sanctuary, portola, goodwine, cullison,ran, cobalt, winooski, orangeburg, jacobus, oriole, lakeville, pasatiempo, beach, mahtomedi, redbank, cosmos, goffs, wilkinsburg, barboursville, frierson, sunshine, globe, kettle, booth, barranquitas, burien, carneiro, placida, hallettsville, fredericksburg, aulander, culbertson, bellwood, stanleyville, smoketown, winfield, seward, emmaus, macarthur, pomaria, kankakee, jenners, nemacolin, archer, neibert, random, parole, carpendale, spotswood, tolstoy, shelbyville, potterville, frontenac, avonia, laupahoehoe, milford, seatonville, roxboro, torrington, grantsville, rosser, 
spencerville, ewell, wyatt, plumerville, lynxville, allensville, springerville, nuevo, grangeville, stilwell, colebrook, browntown, germfask, readsboro, bowmore, pettigrew, brevort, fitchburg, shelbiana, hometown, montreal, mckeesport, wainscott, adrian, branford, dunmor, windhorst, alfred, islandton, arkansaw, rough, brewton, bonnie, cabin, hitchland, cebolla, mammoth, alstown, beechwood, hominy, crandon, sodaville, everton, maiden, churchton, divide, cayuse, helix, umpqua, chiloquin, schoolcraft, kathryn, coffeen, willow, dedham, corder, timberon, pleasant, darnestown, harrellsville, bohemia, micanopy, hillsdale, mashulaville, wilseyville, kemblesville, thrall, kronborg, fleming, bonesteel, hartline, channel, checotah, orason, fonda, kerrtown, noonday, capistrano, poyen, spread, annona, weleetka, silva, gasburg, milaca, bolton, kendall, stockman, whiteriver, ooltewah, wattsville, soldier, cohoes, lucan, neshkoro, newcomb, cason, roeville, kremlin, yankeetown, maryhill, deatsville, lenoir, byers, kildare, kirksville, blackstone, oyens, peden, azusa, mogote, keiser, potosi, delton, victorville, pamelia, cisne, marland, hiteman, lindberg, mayesville, sussex, nicut, lindsborg, seiling, mobridge, bautista, tradewinds, annawan, mogadore, retsof, murchison, lluveras, highspire, woodbine, lajas, sweatman, okabena, clairton, angiola, callicoon, hurley, touhy, michiana, sawpit, stockwell, country, lindenwood, anaheim, dryville, duarte, sumas, siren, musicks, wilbraham, holcut, kingdom, kopperl, calmar, forkland, mifflinburg, lineville, mosinee, hainesburg, gratz, maugansville, haymarket, uncertain, mifflintown, scandinavia, mazomanie, tariffville, mentasta, freedom, bowlus, apache, ranchester, fishersville, northgate, eagerville]\nval_buildings: [ogilvie, laytonsville, clive, hortense, wesley, idanha, maricopa, tomkins, texasville, sundown, southfield, moberly, auburn, eudora, wiconisco, tallmadge, ashport, lessley, assinippi, gravelly, silerton, hordville, corozal, 
swormville, warrenville, almota, cantwell, collierville, cashel, pearce, sisters, merchantville, glassboro, shauck, yscloskey, pablo, ribera, pittsburg, graceville, markleeville, hercules, macedon, howie, sands, cokeville, herricks, kinde, edgemere, kobuk, sagerton, starks, jennie, hornsby, greigsville, westfield, wyldwood, sarcoxie, artois, elmira, deemston, galatia, denmark, swisshome, wells, scioto, pocasset, waldenburg, shumway, waipahu, macksville, crookston, eastville, yadkinville, darden]\n"
  },
  {
    "path": "config/split_medium.txt",
    "content": "train_buildings: [hanson, merom, goodfield, eagan, adairsville, castor, klickitat, cottonport, tyler, sugarville, martinville, chilhowie, silas, lynchburg, tokeland, onaga, frankfort, goodyear, albertville, andover, airport, rogue, ancor, leonardo, maida, marstons, athens, newfields, broseley, irvine, pinesdale, tilghmanton, goodwine, hildebran, winooski, lakeville, cosmos, goffs, sunshine, globe, benevolence, emmaus, pomaria, neibert, parole, tolstoy, shelbyville, potterville, rosser, allensville, springerville, nuevo, stilwell, browntown, readsboro, shelbiana, wainscott, arkansaw, bonnie, beechwood, hominy, churchton, coffeen, willow, timberon, bohemia, micanopy, hillsdale, wilseyville, kemblesville, thrall, bonesteel, annona, stockman, soldier, neshkoro, newcomb, byers, oyens, victorville, pamelia, marland, hiteman, sussex, bautista, highspire, woodbine, sweatman, clairton, touhy, lindenwood, anaheim, duarte, musicks, forkland, mifflinburg, hainesburg, maugansville, ranchester]\nval_buildings: [hortense, southfield, wiconisco, gravelly, hordville, corozal, swormville, collierville, pearce, pablo, pittsburg, markleeville, sands, kobuk, westfield, wyldwood, swisshome, scioto, waipahu, darden]\n"
  },
  {
    "path": "datasets.py",
    "content": "\nimport numpy as np\nimport matplotlib as mpl\n\nimport os, sys, math, random, tarfile, glob, time, yaml, itertools\nimport parse\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.utils.data import Dataset, DataLoader\nfrom torchvision import transforms, utils\n\nfrom utils import *\nfrom logger import Logger, VisdomLogger\nfrom task_configs import get_task, tasks\n\nfrom PIL import Image\nfrom io import BytesIO\nfrom sklearn.model_selection import train_test_split\nimport IPython\n\nimport pdb\n\n\"\"\" Default data loading configurations for training, validation, and testing. \"\"\"\ndef load_train_val(train_tasks, val_tasks=None, fast=False,\n        train_buildings=None, val_buildings=None, split_file=\"config/split.txt\",\n        dataset_cls=None, batch_size=32, batch_transforms=cycle,\n        subset=None, subset_size=None, dataaug=False,\n    ):\n\n    dataset_cls = dataset_cls or TaskDataset\n    train_cls = TrainTaskDataset if dataaug else dataset_cls\n    train_tasks = [get_task(t) if isinstance(t, str) else t for t in train_tasks]\n    if val_tasks is None: val_tasks = train_tasks\n    val_tasks = [get_task(t) if isinstance(t, str) else t for t in val_tasks]  \n    data = yaml.load(open(split_file))\n    train_buildings = train_buildings or ([\"almena\"] if fast else data[\"train_buildings\"])\n    val_buildings = val_buildings or ([\"almena\"] if fast else data[\"val_buildings\"])\n    print(\"number of train images:\")\n    train_loader = train_cls(buildings=train_buildings, tasks=train_tasks)\n    print(\"number of val images:\")\n    val_loader = dataset_cls(buildings=val_buildings, tasks=val_tasks)\n\n    if subset_size is not None or subset is not None:\n        train_loader = torch.utils.data.Subset(train_loader,\n            random.sample(range(len(train_loader)), subset_size or int(len(train_loader)*subset)),\n        )\n\n    train_step = int(len(train_loader) // (400 * batch_size))\n    
val_step = int(len(val_loader) // (400 * batch_size))\n    print(\"Train step: \", train_step)\n    print(\"Val step: \", val_step)\n    if fast: train_step, val_step = 8, 8\n\n    return train_loader, val_loader, train_step, val_step\n\n\n\"\"\" Load all buildings \"\"\"\ndef load_all(tasks, buildings=None, batch_size=64, split_file=\"data/split.txt\", batch_transforms=cycle):\n\n    data = yaml.load(open(split_file))\n    buildings = buildings or (data[\"train_buildings\"] + data[\"val_buildings\"])\n\n    data_loader = torch.utils.data.DataLoader(\n        TaskDataset(buildings=buildings, tasks=tasks),\n        batch_size=batch_size,\n        num_workers=0, shuffle=True, pin_memory=True\n    )\n\n    return data_loader\n\n\n\ndef load_test(all_tasks, buildings=[\"almena\", \"albertville\"], sample=4):\n\n    all_tasks = [get_task(t) if isinstance(t, str) else t for t in all_tasks]\n    print(f\"number of images in {buildings[0]}:\")\n    test_loader1 = torch.utils.data.DataLoader(\n        TaskDataset(buildings=[buildings[0]], tasks=all_tasks, shuffle=False),\n        batch_size=sample,\n        num_workers=0, shuffle=False, pin_memory=True,\n    )\n    print(f\"number of images in {buildings[1]}:\")\n    test_loader2 = torch.utils.data.DataLoader(\n        TaskDataset(buildings=[buildings[1]], tasks=all_tasks, shuffle=False),\n        batch_size=sample,\n        num_workers=0, shuffle=False, pin_memory=True,\n    )\n    set1 = list(itertools.islice(test_loader1, 1))[0]\n    set2 = list(itertools.islice(test_loader2, 1))[0]\n    test_set = tuple(torch.cat([x, y], dim=0) for x, y in zip(set1, set2))\n    return test_set\n\n\ndef load_ood(tasks=[tasks.rgb], ood_path=OOD_DIR, sample=21):\n    ood_loader = torch.utils.data.DataLoader(\n        ImageDataset(tasks=tasks, data_dir=ood_path),\n        batch_size=sample,\n        num_workers=sample, shuffle=False, pin_memory=True\n    )\n    ood_images = list(itertools.islice(ood_loader, 1))[0]\n    return 
ood_images\n\n\n\nclass TaskDataset(Dataset):\n\n    def __init__(self, buildings, tasks=[get_task(\"rgb\"), get_task(\"normal\")], data_dirs=DATA_DIRS,\n            building_files=None, convert_path=None, use_raid=USE_RAID, resize=None, unpaired=False, shuffle=True):\n\n        super().__init__()\n        self.buildings, self.tasks, self.data_dirs = buildings, tasks, data_dirs\n        self.building_files = building_files or self.building_files\n        self.convert_path = convert_path or self.convert_path\n        self.resize = resize\n        if use_raid:\n            self.convert_path = self.convert_path_raid\n            self.building_files = self.building_files_raid\n\n        self.file_map = {}\n        for data_dir in self.data_dirs:\n            for file in glob.glob(f'{data_dir}/*'):\n                res = parse.parse(\"{building}_{task}\", file[len(data_dir)+1:])\n                if res is None: continue\n                self.file_map[file[len(data_dir)+1:]] = data_dir\n\n        filtered_files = None\n\n        assert (len(tasks) > 0), \"Building dataset for tasks, but no tasks specified!\"\n        task = tasks[0]\n        task_files = []\n        for building in buildings:\n            task_files += self.building_files(task, building)\n        print(f\"    {task.name} file len: {len(task_files)}\")\n        self.idx_files = task_files\n        if not shuffle: self.idx_files = sorted(task_files)\n\n        print (\"    Intersection files len: \", len(self.idx_files))\n\n    def reset_unpaired(self):\n        if self.unpaired:\n            self.task_indices = {task:random.sample(range(len(self.idx_files)), len(self.idx_files)) for task in self.task_indices}\n\n    def building_files(self, task, building):\n        \"\"\" Gets all the tasks in a given building (grouping of data) \"\"\"\n        return get_files(f\"{building}_{task.file_name}/{task.file_name}/*.{task.file_ext}\", self.data_dirs)\n\n    def building_files_raid(self, task, building):\n      
  return get_files(f\"{task.file_name}/{building}/*.{task.file_ext}\", self.data_dirs)\n\n    def convert_path(self, source_file, task):\n        \"\"\" Converts a file from task A to task B. Can be overridden by subclasses\"\"\"\n        source_file = \"/\".join(source_file.split('/')[-3:])\n        result = parse.parse(\"{building}_{task}/{task}/{view}_domain_{task2}.{ext}\", source_file)\n        building, _, view = (result[\"building\"], result[\"task\"], result[\"view\"])\n        dest_file = f\"{building}_{task.file_name}/{task.file_name}/{view}_domain_{task.file_name_alt}.{task.file_ext}\"\n        if f\"{building}_{task.file_name}\" not in self.file_map:\n            print (f\"{building}_{task.file_name} not in file map\")\n            # IPython.embed()\n            return \"\"\n        data_dir = self.file_map[f\"{building}_{task.file_name}\"]\n        return f\"{data_dir}/{dest_file}\"\n\n    def convert_path_raid(self, full_file, task):\n        \"\"\" Converts a file from task A to task B. 
Can be overridden by subclasses\"\"\"\n        source_file = \"/\".join(full_file.split('/')[-3:])\n        result = parse.parse(\"{task}/{building}/{view}.{ext}\", source_file)\n        building, _, view = (result[\"building\"], result[\"task\"], result[\"view\"])\n        dest_file = f\"{task.file_name}/{building}/{view}.{task.file_ext}\"\n        return f\"{full_file[:-len(source_file)-1]}/{dest_file}\"\n\n    def __len__(self):\n        return len(self.idx_files)\n\n    def __getitem__(self, idx):\n\n        for i in range(200):\n            try:\n                res = []\n\n                seed = random.randint(0, 1e10)\n\n                for task in self.tasks:\n                    file_name = self.convert_path(self.idx_files[idx], task)\n                    if len(file_name) == 0: raise Exception(\"unable to convert file\")\n                    image = task.file_loader(file_name, resize=self.resize, seed=seed)\n\n                    res.append(image)\n                return tuple(res)\n            except Exception as e:\n                idx = random.randrange(0, len(self.idx_files))\n                if i == 199: raise (e)\n\n\nclass TrainTaskDataset(TaskDataset):\n\n    def __getitem__(self, idx):\n\n        for i in range(200):\n            try:\n                res = []\n\n                seed = random.randint(0, 1e10)\n                crop = random.randint(int(0.7*512), 512) if bool(random.getrandbits(1)) else 512\n\n                for task in self.tasks:\n                    jitter = bool(random.getrandbits(1)) if task.name == 'rgb' else False\n                    file_name = self.convert_path(self.idx_files[idx], task)\n                    if len(file_name) == 0: raise Exception(\"unable to convert file\")\n                    image = task.file_loader(file_name, resize=self.resize, seed=seed, crop=crop, jitter=jitter)\n                    res.append(image)\n\n                return tuple(res)\n            except Exception as e:\n                idx = 
random.randrange(0, len(self.idx_files))\n                if i == 199: raise (e)\n\n\nclass ImageDataset(Dataset):\n\n    def __init__(\n        self,\n        tasks=[tasks.rgb],\n        data_dir=f\"data/ood_images\",\n        files=None,\n    ):\n\n        self.tasks = tasks\n        #if not USE_RAID and files is None:\n        #    os.system(f\"ls {data_dir}/*.png\")\n        #    os.system(f\"ls {data_dir}/*.png\")\n\n        self.files = files \\\n            or sorted(\n                glob.glob(f\"{data_dir}/*.png\")\n                + glob.glob(f\"{data_dir}/*.jpg\")\n                + glob.glob(f\"{data_dir}/*.jpeg\")\n            )\n\n        print(\"number of ood images: \", len(self.files))\n\n    def __len__(self):\n        return len(self.files)\n\n    def __getitem__(self, idx):\n\n        file = self.files[idx]\n        res = []\n        seed = random.randint(0, 1e10)\n        for task in self.tasks:\n            image = task.file_loader(file, seed=seed)\n            if image.shape[0] == 1: image = image.expand(3, -1, -1)\n            res.append(image)\n        return tuple(res)\n\n\n\n\nif __name__ == \"__main__\":\n\n    logger = VisdomLogger(\"data\", env=JOB)\n    train_dataset, val_dataset, train_step, val_step = load_train_val(\n        [tasks.rgb, tasks.normal, tasks.principal_curvature, tasks.rgb(size=512)],\n        batch_size=32,\n    )\n    print (\"created dataset\")\n    logger.add_hook(lambda logger, data: logger.step(), freq=32)\n\n    for i, _ in enumerate(train_dataset):\n        logger.update(\"epoch\", i)\n"
  },
  {
    "path": "demo.py",
    "content": "import torch\nfrom torchvision import transforms\n\nfrom modules.unet import UNet, UNetReshade\n\nimport PIL\nfrom PIL import Image\n\nimport argparse\nimport os.path\nfrom pathlib import Path\nimport glob\nimport sys\n\nimport pdb\n\n\n\nparser = argparse.ArgumentParser(description='Visualize output for a single Task')\n\nparser.add_argument('--task', dest='task', help=\"normal, depth or reshading\")\nparser.set_defaults(task='NONE')\n\nparser.add_argument('--img_path', dest='img_path', help=\"path to rgb image\")\nparser.set_defaults(im_name='NONE')\n\nparser.add_argument('--output_path', dest='output_path', help=\"path to where output image should be stored\")\nparser.set_defaults(store_name='NONE')\n\nargs = parser.parse_args()\n\nroot_dir = './models/'\ntrans_totensor = transforms.Compose([transforms.Resize(256, interpolation=PIL.Image.BILINEAR),\n                                    transforms.CenterCrop(256),\n                                    transforms.ToTensor()])\ntrans_topil = transforms.ToPILImage()\n\nos.system(f\"mkdir -p {args.output_path}\")\n\n# get target task and model\ntarget_tasks = ['normal','depth','reshading']\ntry:\n    task_index = target_tasks.index(args.task)\nexcept:\n    print(\"task should be one of the following: normal, depth, reshading\")\n    sys.exit()\nmodels = [UNet(), UNet(downsample=6, out_channels=1), UNetReshade(downsample=5)]\nmodel = models[task_index]\n\nmap_location = (lambda storage, loc: storage.cuda()) if torch.cuda.is_available() else torch.device('cpu')\n\ndef save_outputs(img_path, output_file_name):\n\n    img = Image.open(img_path)\n    img_tensor = trans_totensor(img)[:3].unsqueeze(0)\n\n    # compute baseline and consistency output\n    for type in ['baseline','consistency']:\n        path = root_dir + 'rgb2'+args.task+'_'+type+'.pth'\n        model_state_dict = torch.load(path, map_location=map_location)\n        model.load_state_dict(model_state_dict)\n        baseline_output = 
model(img_tensor).clamp(min=0, max=1)\n        trans_topil(baseline_output[0]).save(args.output_path+'/'+output_file_name+'_'+args.task+'_'+type+'.png')\n\n\nimg_path = Path(args.img_path)\nif img_path.is_file():\n    save_outputs(args.img_path, os.path.splitext(os.path.basename(args.img_path))[0])\nelif img_path.is_dir():\n    for f in glob.glob(args.img_path+'/*'):\n        save_outputs(f, os.path.splitext(os.path.basename(f))[0])\nelse:\n    print(\"invalid file path!\")\n    sys.exit()\n"
  },
  {
    "path": "energy.py",
    "content": "import os, sys, math, random, itertools\nfrom functools import partial\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom torchvision import datasets, transforms, models\nfrom torch.optim.lr_scheduler import MultiStepLR\nfrom torch.utils.checkpoint import checkpoint\n\nfrom utils import *\nfrom task_configs import tasks, get_task, ImageTask\nfrom transfers import functional_transfers, finetuned_transfers, get_transfer_name, Transfer\nfrom datasets import TaskDataset, load_train_val\n\nfrom matplotlib.cm import get_cmap\n\n\nimport IPython\n\nimport pdb\n\ndef get_energy_loss(\n    config=\"\", mode=\"winrate\",\n    pretrained=True, finetuned=True, **kwargs,\n):\n    \"\"\" Loads energy loss from config dict. \"\"\"\n    if isinstance(mode, str):\n        mode = {\n            \"standard\": EnergyLoss,\n            \"winrate\": WinRateEnergyLoss,\n        }[mode]\n    return mode(**energy_configs[config],\n        pretrained=pretrained, finetuned=finetuned, **kwargs\n    )\n\nALL_PERCEPTUAL_TASKS = [tasks.principal_curvature,\n             tasks.sobel_edges,\n             tasks.depth_zbuffer,\n             tasks.edge_occlusion,\n             tasks.reshading,\n             tasks.keypoints3d,\n             tasks.keypoints2d]\n\ndef generate_config(perceptual_tasks, target_task=tasks.normal, tree_structure=False, has_gt=True):\n\n    # If we have GT, measure error\n    base_keys = {\n                    \"x\": [tasks.rgb],\n                    \"n((x))\": [tasks.rgb, target_task]\n    }\n    direct_losses = {}\n    if has_gt:\n        base_keys[\"y^\"] = [target_task]\n        direct_losses[\"direct_normal\"] = { (\"train\", \"val\", \"train_subset\"): [ (\"n((x))\", \"y^\"), ] }\n\n    # Add in losses for consistency energy\n    perceptual_losses = []\n    if tree_structure:\n        perceptual_losses = [('', t2) for t2 in perceptual_tasks]\n    else:\n        perceptual_losses = 
list(itertools.combinations(perceptual_tasks + [''], r=2))\n        perceptual_losses = [(t2, t1) if str(t2) == '' else (t1, t2) for t1, t2 in perceptual_losses]\n\n    return {\n        \"paths\": both( base_keys,\n                {\n                    f\"n({intermediate_task}(x))\": [tasks.rgb, intermediate_task, target_task]\n                    for intermediate_task in perceptual_tasks\n                }\n        ),\n        \"freeze_list\" : [(t, target_task) for t in perceptual_tasks] +\n                        [(tasks.rgb, t) for t in perceptual_tasks],\n        \"losses\": both( direct_losses,\n                {\n                    f'percep_{t1}_{t2}': { (\"train\", \"val\", \"train_subset\"): [ (f\"n({t1}(x))\", f\"n({t2}(x))\"), ], }\n                    for t1, t2 in perceptual_losses\n                }\n        ),\n        \"plots\": {\n            \"ID\": dict(\n                size=256,\n                realities=(\"test\",),\n                paths=[\n                    \"x\",\n                    \"n((x))\",\n                ]\n            ),\n        },\n    }\n\nenergy_configs = {\n\n    \"example_cascade_two_networks\": {\n        \"paths\": {\n            \"x\": [tasks.rgb],\n            \"y^\": [tasks.normal],\n            \"n(x)\": [tasks.rgb, tasks.normal],\n            \"RC(x)\": [tasks.rgb, tasks.principal_curvature],\n            \"curv\": [tasks.principal_curvature],\n            \"f(y^)\": [tasks.normal, tasks.principal_curvature],\n            \"f(n(x))\": [tasks.rgb, tasks.normal, tasks.principal_curvature],\n        },\n        \"freeze_list\": [\n            [tasks.normal, tasks.principal_curvature],\n            [tasks.normal, tasks.sobel_edges],\n        ],\n        \"losses\": {\n            \"mae\": {\n                (\"train\", \"val\"): [\n                    (\"n(x)\", \"y^\"),\n                ],\n            },\n            \"percep_curv\": {\n                (\"train\", \"val\"): [\n                    (\"f(n(x))\", 
\"f(y^)\"),\n                ],\n            },\n            \"direct_curv\": {\n                (\"train\", \"val\"): [\n                    (\"RC(x)\", \"curv\"),\n                ],\n            },\n        },\n        \"plots\": {\n            \"\": dict(\n                size=256,\n                realities=('train', 'val'),\n                paths=[\n                    \"x\",\n                    \"y^\",\n                    \"n(x)\",\n                    \"f(y^)\",\n                    \"f(n(x))\",\n                ]\n            ),\n        },\n    },\n\n\n    \"example_normal\": {\n        \"paths\": {\n            \"x\": [tasks.rgb],\n            \"y^\": [tasks.normal],\n            \"n(x)\": [tasks.rgb, tasks.normal],\n            \"a(x)\": [tasks.rgb, tasks.sobel_edges],\n            \"RC(x)\": [tasks.rgb, tasks.principal_curvature],\n            \"edge\": [tasks.sobel_edges],\n            \"curv\": [tasks.principal_curvature],\n            \"s(y^)\": [tasks.normal, tasks.sobel_edges],\n            \"s(n(x))\": [tasks.rgb, tasks.normal, tasks.sobel_edges],\n            \"f(y^)\": [tasks.normal, tasks.principal_curvature],\n            \"f(n(x))\": [tasks.rgb, tasks.normal, tasks.principal_curvature],\n        },\n        \"freeze_list\": [\n            [tasks.normal, tasks.principal_curvature],\n            [tasks.normal, tasks.sobel_edges],\n        ],\n        \"losses\": {\n            \"mae\": {\n                (\"train\", \"val\"): [\n                    (\"n(x)\", \"y^\"),\n                ],\n            },\n            \"percep_curv\": {\n                (\"train\", \"val\"): [\n                    (\"f(n(x))\", \"f(y^)\"),\n                ],\n            },\n            \"direct_curv\": {\n                (\"train\", \"val\"): [\n                    (\"RC(x)\", \"curv\"),\n                ],\n            },\n            \"percep_edge\": {\n                (\"train\", \"val\"): [\n                    (\"s(n(x))\", \"s(y^)\"),\n                
],\n            },\n            \"direct_edge\": {\n                (\"train\", \"val\"): [\n                    (\"a(x)\", \"s(y^)\"),\n                ],\n            },\n        },\n        \"plots\": {\n            \"\": dict(\n                size=256,\n                realities=('train', 'val'),\n                paths=[\n                    \"x\",\n                    \"y^\",\n                    \"n(x)\",\n                    \"f(y^)\",\n                    \"f(n(x))\",\n                    \"s(y^)\",\n                    \"s(n(x))\",\n                ]\n            ),\n        },\n    },\n\n    \"example_normal_direct\": {\n        \"paths\": {\n            \"x\": [tasks.rgb],\n            \"y^\": [tasks.normal],\n            \"n(x)\": [tasks.rgb, tasks.normal],\n        },\n        \"freeze_list\": [\n        ],\n        \"losses\": {\n            \"mae\": {\n                (\"train\", \"val\"): [\n                    (\"n(x)\", \"y^\"),\n                ],\n            },\n        },\n        \"plots\": {\n            \"\": dict(\n                size=256,\n                realities=('train', 'val'),\n                paths=[\n                    \"x\",\n                    \"y^\",\n                    \"n(x)\",\n                ]\n            ),\n        },\n    },\n    \n    \"multiperceptual_normal\": {\n        \"paths\": {\n            \"x\": [tasks.rgb],\n            \"y^\": [tasks.normal],\n            \"n(x)\": [tasks.rgb, tasks.normal],\n            \"RC(x)\": [tasks.rgb, tasks.principal_curvature],\n            \"a(x)\": [tasks.rgb, tasks.sobel_edges],\n            \"d(x)\": [tasks.rgb, tasks.reshading],\n            \"r(x)\": [tasks.rgb, tasks.depth_zbuffer],\n            \"EO(x)\": [tasks.rgb, tasks.edge_occlusion],\n            \"k2(x)\": [tasks.rgb, tasks.keypoints2d],\n            \"k3(x)\": [tasks.rgb, tasks.keypoints3d],\n            \"curv\": [tasks.principal_curvature],\n            \"edge\": [tasks.sobel_edges],\n            \"depth\": 
[tasks.depth_zbuffer],\n            \"reshading\": [tasks.reshading],\n            \"keypoints2d\": [tasks.keypoints2d],\n            \"keypoints3d\": [tasks.keypoints3d],\n            \"edge_occlusion\": [tasks.edge_occlusion],\n            \"f(y^)\": [tasks.normal, tasks.principal_curvature],\n            \"f(n(x))\": [tasks.rgb, tasks.normal, tasks.principal_curvature],\n            \"s(y^)\": [tasks.normal, tasks.sobel_edges],\n            \"s(n(x))\": [tasks.rgb, tasks.normal, tasks.sobel_edges],\n            \"g(y^)\": [tasks.normal, tasks.reshading],\n            \"g(n(x))\": [tasks.rgb, tasks.normal, tasks.reshading],\n            \"nr(y^)\": [tasks.normal, tasks.depth_zbuffer],\n            \"nr(n(x))\": [tasks.rgb, tasks.normal, tasks.depth_zbuffer],\n            \"Nk2(y^)\": [tasks.normal, tasks.keypoints2d],\n            \"Nk2(n(x))\": [tasks.rgb, tasks.normal, tasks.keypoints2d],\n            \"Nk3(y^)\": [tasks.normal, tasks.keypoints3d],\n            \"Nk3(n(x))\": [tasks.rgb, tasks.normal, tasks.keypoints3d],\n            \"nEO(y^)\": [tasks.normal, tasks.edge_occlusion],\n            \"nEO(n(x))\": [tasks.rgb, tasks.normal, tasks.edge_occlusion],\n            \"imagenet(y^)\": [tasks.normal, tasks.imagenet],\n            \"imagenet(n(x))\": [tasks.rgb, tasks.normal, tasks.imagenet],\n        },\n        \"freeze_list\": [\n            [tasks.normal, tasks.principal_curvature],\n            [tasks.normal, tasks.sobel_edges],\n            [tasks.normal, tasks.reshading],\n            [tasks.normal, tasks.depth_zbuffer],\n            [tasks.normal, tasks.keypoints3d],\n            [tasks.normal, tasks.keypoints2d],\n            [tasks.normal, tasks.edge_occlusion],\n            [tasks.normal, tasks.imagenet],\n        ],\n        \"losses\": {\n            \"mae\": {\n                (\"train\", \"val\"): [\n                    (\"n(x)\", \"y^\"),\n                ],\n            },\n            \"percep_curv\": {\n                (\"train\", 
\"val\"): [\n                    (\"f(n(x))\", \"f(y^)\"),\n                ],\n            },\n            \"direct_curv\": {\n                (\"train\", \"val\"): [\n                    (\"RC(x)\", \"curv\"),\n                ],\n            },\n            \"percep_edge\": {\n                (\"train\", \"val\"): [\n                    (\"s(n(x))\", \"s(y^)\"),\n                ],\n            },\n            \"direct_edge\": {\n                (\"train\", \"val\"): [\n                    (\"a(x)\", \"s(y^)\"),\n                ],\n            },\n            \"percep_reshading\": {\n                (\"train\", \"val\"): [\n                    (\"g(n(x))\", \"g(y^)\"),\n                ],\n            },\n            \"direct_reshading\": {\n                (\"train\", \"val\"): [\n                    (\"d(x)\", \"reshading\"),\n                ],\n            },\n            \"percep_depth_zbuffer\": {\n                (\"train\", \"val\"): [\n                    (\"nr(n(x))\", \"nr(y^)\"),\n                ],\n            },\n            \"direct_depth_zbuffer\": {\n                (\"train\", \"val\"): [\n                    (\"r(x)\", \"depth\"),\n                ],\n            },\n            \"percep_keypoints2d\": {\n                (\"train\", \"val\"): [\n                    (\"Nk2(n(x))\", \"Nk2(y^)\"),\n                ],\n            },\n            \"direct_keypoints2d\": {\n                (\"train\", \"val\"): [\n                    (\"k2(x)\", \"keypoints2d\"),\n                ],\n            },\n            \"percep_keypoints3d\": {\n                (\"train\", \"val\"): [\n                    (\"Nk3(n(x))\", \"Nk3(y^)\"),\n                ],\n            },\n            \"direct_keypoints3d\": {\n                (\"train\", \"val\"): [\n                    (\"k3(x)\", \"keypoints3d\"),\n                ],\n            },\n            \"percep_edge_occlusion\": {\n                (\"train\", \"val\"): [\n                    (\"nEO(n(x))\", 
\"nEO(y^)\"),\n                ],\n            },\n            \"direct_edge_occlusion\": {\n                (\"train\", \"val\"): [\n                    (\"EO(x)\", \"edge_occlusion\"),\n                ],\n            },\n            \"percep_imagenet_percep\": {\n                (\"train\", \"val\"): [\n                    (\"imagenet(n(x))\", \"imagenet(y^)\"),\n                ],\n            },\n            \"direct_imagenet_percep\": {\n                (\"train\", \"val\"): [\n                    (\"RC(x)\", \"curv\"),\n                ],\n            },\n        },\n        \"plots\": {\n            \"\": dict(\n                size=256,\n                realities=(\"test\", \"ood\"),\n                paths=[\n                    \"x\",\n                    \"y^\",\n                    \"n(x)\",\n                    \"f(y^)\",\n                    \"f(n(x))\",\n                    \"s(y^)\",\n                    \"s(n(x))\",\n                    \"g(y^)\",\n                    \"g(n(x))\",\n                    \"nr(n(x))\",\n                    \"nr(y^)\",\n                    \"Nk3(y^)\",\n                    \"Nk3(n(x))\",\n                    \"Nk2(y^)\",\n                    \"Nk2(n(x))\",\n                    \"nEO(y^)\",\n                    \"nEO(n(x))\",\n                ]\n            ),\n        },\n    },\n\n    \"multiperceptual_reshading\": {\n        \"paths\": {\n            \"x\": [tasks.rgb],\n            \"y^\": [tasks.reshading],\n            \"n(x)\": [tasks.rgb, tasks.reshading],\n            \"RC(x)\": [tasks.rgb, tasks.principal_curvature],\n            \"a(x)\": [tasks.rgb, tasks.sobel_edges],\n            \"d(x)\": [tasks.rgb, tasks.normal],\n            \"r(x)\": [tasks.rgb, tasks.depth_zbuffer],\n            \"EO(x)\": [tasks.rgb, tasks.edge_occlusion],\n            \"k2(x)\": [tasks.rgb, tasks.keypoints2d],\n            \"k3(x)\": [tasks.rgb, tasks.keypoints3d],\n            \"curv\": [tasks.principal_curvature],\n            
\"edge\": [tasks.sobel_edges],\n            \"depth\": [tasks.depth_zbuffer],\n            \"normal\": [tasks.normal],\n            \"keypoints2d\": [tasks.keypoints2d],\n            \"keypoints3d\": [tasks.keypoints3d],\n            \"edge_occlusion\": [tasks.edge_occlusion],\n            \"f(y^)\": [tasks.reshading, tasks.principal_curvature],\n            \"f(n(x))\": [tasks.rgb, tasks.reshading, tasks.principal_curvature],\n            \"s(y^)\": [tasks.reshading, tasks.sobel_edges],\n            \"s(n(x))\": [tasks.rgb, tasks.reshading, tasks.sobel_edges],\n            \"g(y^)\": [tasks.reshading, tasks.normal],\n            \"g(n(x))\": [tasks.rgb, tasks.reshading, tasks.normal],\n            \"nr(y^)\": [tasks.reshading, tasks.depth_zbuffer],\n            \"nr(n(x))\": [tasks.rgb, tasks.reshading, tasks.depth_zbuffer],\n            \"Nk2(y^)\": [tasks.reshading, tasks.keypoints2d],\n            \"Nk2(n(x))\": [tasks.rgb, tasks.reshading, tasks.keypoints2d],\n            \"Nk3(y^)\": [tasks.reshading, tasks.keypoints3d],\n            \"Nk3(n(x))\": [tasks.rgb, tasks.reshading, tasks.keypoints3d],\n            \"nEO(y^)\": [tasks.reshading, tasks.edge_occlusion],\n            \"nEO(n(x))\": [tasks.rgb, tasks.reshading, tasks.edge_occlusion],\n            \"imagenet(y^)\": [tasks.reshading, tasks.imagenet],\n            \"imagenet(n(x))\": [tasks.rgb, tasks.reshading, tasks.imagenet],\n        },\n        \"freeze_list\": [\n            [tasks.reshading, tasks.principal_curvature],\n            [tasks.reshading, tasks.sobel_edges],\n            [tasks.reshading, tasks.normal],\n            [tasks.reshading, tasks.depth_zbuffer],\n            [tasks.reshading, tasks.keypoints3d],\n            [tasks.reshading, tasks.keypoints2d],\n            [tasks.reshading, tasks.edge_occlusion],\n            [tasks.reshading, tasks.imagenet],\n        ],\n        \"losses\": {\n            \"mae\": {\n                (\"train\", \"val\"): [\n                    (\"n(x)\", 
\"y^\"),\n                ],\n            },\n            \"percep_curv\": {\n                (\"train\", \"val\"): [\n                    (\"f(n(x))\", \"f(y^)\"),\n                ],\n            },\n            \"direct_curv\": {\n                (\"train\", \"val\"): [\n                    (\"RC(x)\", \"curv\"),\n                ],\n            },\n            \"percep_edge\": {\n                (\"train\", \"val\"): [\n                    (\"s(n(x))\", \"s(y^)\"),\n                ],\n            },\n            \"direct_edge\": {\n                (\"train\", \"val\"): [\n                    (\"a(x)\", \"s(y^)\"),\n                ],\n            },\n            \"percep_normal\": {\n                (\"train\", \"val\"): [\n                    (\"g(n(x))\", \"g(y^)\"),\n                ],\n            },\n            \"direct_normal\": {\n                (\"train\", \"val\"): [\n                    (\"d(x)\", \"normal\"),\n                ],\n            },\n            \"percep_depth_zbuffer\": {\n                (\"train\", \"val\"): [\n                    (\"nr(n(x))\", \"nr(y^)\"),\n                ],\n            },\n            \"direct_depth_zbuffer\": {\n                (\"train\", \"val\"): [\n                    (\"r(x)\", \"depth\"),\n                ],\n            },\n            \"percep_keypoints2d\": {\n                (\"train\", \"val\"): [\n                    (\"Nk2(n(x))\", \"Nk2(y^)\"),\n                ],\n            },\n            \"direct_keypoints2d\": {\n                (\"train\", \"val\"): [\n                    (\"k2(x)\", \"keypoints2d\"),\n                ],\n            },\n            \"percep_keypoints3d\": {\n                (\"train\", \"val\"): [\n                    (\"Nk3(n(x))\", \"Nk3(y^)\"),\n                ],\n            },\n            \"direct_keypoints3d\": {\n                (\"train\", \"val\"): [\n                    (\"k3(x)\", \"keypoints3d\"),\n                ],\n            },\n            
\"percep_edge_occlusion\": {\n                (\"train\", \"val\"): [\n                    (\"nEO(n(x))\", \"nEO(y^)\"),\n                ],\n            },\n            \"direct_edge_occlusion\": {\n                (\"train\", \"val\"): [\n                    (\"EO(x)\", \"edge_occlusion\"),\n                ],\n            },\n            \"percep_imagenet_percep\": {\n                (\"train\", \"val\"): [\n                    (\"imagenet(n(x))\", \"imagenet(y^)\"),\n                ],\n            },\n            \"direct_imagenet_percep\": {\n                (\"train\", \"val\"): [\n                    (\"RC(x)\", \"curv\"),\n                ],\n            },\n        },\n        \"plots\": {\n            \"\": dict(\n                size=256,\n                realities=(\"test\", \"ood\"),\n                paths=[\n                    \"x\",\n                    \"y^\",\n                    \"n(x)\",\n                    \"f(y^)\",\n                    \"f(n(x))\",\n                    \"s(y^)\",\n                    \"s(n(x))\",\n                    \"g(y^)\",\n                    \"g(n(x))\",\n                    \"nr(n(x))\",\n                    \"nr(y^)\",\n                    \"Nk3(y^)\",\n                    \"Nk3(n(x))\",\n                    \"Nk2(y^)\",\n                    \"Nk2(n(x))\",\n                    \"nEO(y^)\",\n                    \"nEO(n(x))\",\n                    \"depth\",\n                ]\n            ),\n        },\n    },\n\n    \"multiperceptual_depth\": {\n        \"paths\": {\n            \"x\": [tasks.rgb],\n            \"y^\": [tasks.depth_zbuffer],\n            \"n(x)\": [tasks.rgb, tasks.depth_zbuffer],\n            \"RC(x)\": [tasks.rgb, tasks.principal_curvature],\n            \"a(x)\": [tasks.rgb, tasks.sobel_edges],\n            \"d(x)\": [tasks.rgb, tasks.normal],\n            \"r(x)\": [tasks.rgb, tasks.reshading],\n            \"EO(x)\": [tasks.rgb, tasks.edge_occlusion],\n            \"k2(x)\": [tasks.rgb, 
tasks.keypoints2d],\n            \"k3(x)\": [tasks.rgb, tasks.keypoints3d],\n            \"curv\": [tasks.principal_curvature],\n            \"edge\": [tasks.sobel_edges],\n            \"normal\": [tasks.normal],\n            \"reshading\": [tasks.reshading],\n            \"keypoints2d\": [tasks.keypoints2d],\n            \"keypoints3d\": [tasks.keypoints3d],\n            \"edge_occlusion\": [tasks.edge_occlusion],\n            \"f(y^)\": [tasks.depth_zbuffer, tasks.principal_curvature],\n            \"f(n(x))\": [tasks.rgb, tasks.depth_zbuffer, tasks.principal_curvature],\n            \"s(y^)\": [tasks.depth_zbuffer, tasks.sobel_edges],\n            \"s(n(x))\": [tasks.rgb, tasks.depth_zbuffer, tasks.sobel_edges],\n            \"g(y^)\": [tasks.depth_zbuffer, tasks.normal],\n            \"g(n(x))\": [tasks.rgb, tasks.depth_zbuffer, tasks.normal],\n            \"nr(y^)\": [tasks.depth_zbuffer, tasks.reshading],\n            \"nr(n(x))\": [tasks.rgb, tasks.depth_zbuffer, tasks.reshading],\n            \"Nk2(y^)\": [tasks.depth_zbuffer, tasks.keypoints2d],\n            \"Nk2(n(x))\": [tasks.rgb, tasks.depth_zbuffer, tasks.keypoints2d],\n            \"Nk3(y^)\": [tasks.depth_zbuffer, tasks.keypoints3d],\n            \"Nk3(n(x))\": [tasks.rgb, tasks.depth_zbuffer, tasks.keypoints3d],\n            \"nEO(y^)\": [tasks.depth_zbuffer, tasks.edge_occlusion],\n            \"nEO(n(x))\": [tasks.rgb, tasks.depth_zbuffer, tasks.edge_occlusion],\n            \"imagenet(y^)\": [tasks.depth_zbuffer, tasks.imagenet],\n            \"imagenet(n(x))\": [tasks.rgb, tasks.depth_zbuffer, tasks.imagenet],\n        },\n        \"freeze_list\": [\n            [tasks.depth_zbuffer, tasks.principal_curvature],\n            [tasks.depth_zbuffer, tasks.sobel_edges],\n            [tasks.depth_zbuffer, tasks.normal],\n            [tasks.depth_zbuffer, tasks.reshading],\n            [tasks.depth_zbuffer, tasks.keypoints3d],\n            [tasks.depth_zbuffer, tasks.keypoints2d],\n            
[tasks.depth_zbuffer, tasks.edge_occlusion],\n            [tasks.depth_zbuffer, tasks.imagenet],\n        ],\n        \"losses\": {\n            \"mae\": {\n                (\"train\", \"val\"): [\n                    (\"n(x)\", \"y^\"),\n                ],\n            },\n            \"percep_curv\": {\n                (\"train\", \"val\"): [\n                    (\"f(n(x))\", \"f(y^)\"),\n                ],\n            },\n            \"direct_curv\": {\n                (\"train\", \"val\"): [\n                    (\"RC(x)\", \"curv\"),\n                ],\n            },\n            \"percep_edge\": {\n                (\"train\", \"val\"): [\n                    (\"s(n(x))\", \"s(y^)\"),\n                ],\n            },\n            \"direct_edge\": {\n                (\"train\", \"val\"): [\n                    (\"a(x)\", \"s(y^)\"),\n                ],\n            },\n            \"percep_normal\": {\n                (\"train\", \"val\"): [\n                    (\"g(n(x))\", \"g(y^)\"),\n                ],\n            },\n            \"direct_normal\": {\n                (\"train\", \"val\"): [\n                    (\"d(x)\", \"normal\"),\n                ],\n            },\n            \"percep_reshading\": {\n                (\"train\", \"val\"): [\n                    (\"nr(n(x))\", \"nr(y^)\"),\n                ],\n            },\n            \"direct_reshading\": {\n                (\"train\", \"val\"): [\n                    (\"r(x)\", \"reshading\"),\n                ],\n            },\n            \"percep_keypoints2d\": {\n                (\"train\", \"val\"): [\n                    (\"Nk2(n(x))\", \"Nk2(y^)\"),\n                ],\n            },\n            \"direct_keypoints2d\": {\n                (\"train\", \"val\"): [\n                    (\"k2(x)\", \"keypoints2d\"),\n                ],\n            },\n            \"percep_keypoints3d\": {\n                (\"train\", \"val\"): [\n                    (\"Nk3(n(x))\", \"Nk3(y^)\"),\n   
             ],\n            },\n            \"direct_keypoints3d\": {\n                (\"train\", \"val\"): [\n                    (\"k3(x)\", \"keypoints3d\"),\n                ],\n            },\n            \"percep_edge_occlusion\": {\n                (\"train\", \"val\"): [\n                    (\"nEO(n(x))\", \"nEO(y^)\"),\n                ],\n            },\n            \"direct_edge_occlusion\": {\n                (\"train\", \"val\"): [\n                    (\"EO(x)\", \"edge_occlusion\"),\n                ],\n            },\n            \"percep_imagenet_percep\": {\n                (\"train\", \"val\"): [\n                    (\"imagenet(n(x))\", \"imagenet(y^)\"),\n                ],\n            },\n            \"direct_imagenet_percep\": {\n                (\"train\", \"val\"): [\n                    (\"RC(x)\", \"curv\"),\n                ],\n            },\n        },\n        \"plots\": {\n            \"\": dict(\n                size=256,\n                realities=(\"test\", \"ood\"),\n                paths=[\n                    \"x\",\n                    \"y^\",\n                    \"n(x)\",\n                    \"f(y^)\",\n                    \"f(n(x))\",\n                    \"s(y^)\",\n                    \"s(n(x))\",\n                    \"g(y^)\",\n                    \"g(n(x))\",\n                    \"nr(n(x))\",\n                    \"nr(y^)\",\n                    \"Nk3(y^)\",\n                    \"Nk3(n(x))\",\n                    \"Nk2(y^)\",\n                    \"Nk2(n(x))\",\n                    \"nEO(y^)\",\n                    \"nEO(n(x))\",\n                ]\n            ),\n        },\n    },\n\n    \"energy_calc\": generate_config(ALL_PERCEPTUAL_TASKS),\n    \"energy_calc_nogt\": generate_config(ALL_PERCEPTUAL_TASKS, has_gt=False),\n}\n\n\n\ndef coeff_hook(coeff):\n    def fun1(grad):\n        return coeff*grad.clone()\n    return fun1\n\n\nclass EnergyLoss(object):\n\n    def __init__(self, paths, losses, plots,\n      
  pretrained=True, finetuned=False, freeze_list=[]\n    ):\n\n        self.paths, self.losses, self.plots = paths, losses, plots\n        self.freeze_list = [str((path[0].name, path[1].name)) for path in freeze_list]\n        self.metrics = {}\n\n        self.tasks = []\n        for _, loss_item in self.losses.items():\n            for realities, losses in loss_item.items():\n                for path1, path2 in losses:\n                    self.tasks += self.paths[path1] + self.paths[path2]\n\n        for name, config in self.plots.items():\n            for path in config[\"paths\"]:\n                self.tasks += self.paths[path]\n        self.tasks = list(set(self.tasks))\n\n    def compute_paths(self, graph, reality=None, paths=None):\n        path_cache = {}\n        paths = paths or self.paths\n        path_values = {\n            name: graph.sample_path(path,\n                reality=reality, use_cache=True, cache=path_cache,\n            ) for name, path in paths.items()\n        }\n        del path_cache\n        return {k: v for k, v in path_values.items() if v is not None}\n\n    def get_tasks(self, reality):\n        tasks = []\n        for _, loss_item in self.losses.items():\n            for realities, losses in loss_item.items():\n                if reality in realities:\n                    for path1, path2 in losses:\n                        tasks += [self.paths[path1][0], self.paths[path2][0]]\n\n        for name, config in self.plots.items():\n            if reality in config[\"realities\"]:\n                for path in config[\"paths\"]:\n                    tasks += [self.paths[path][0]]\n\n        return list(set(tasks))\n\n    def __call__(self, graph, discriminator=None, realities=[], loss_types=None, reduce=True, use_l1=False):\n        #pdb.set_trace()\n        loss = {}\n        for reality in realities:\n            loss_dict = {}\n            losses = []\n            all_loss_types = set()\n            for loss_type, loss_item in 
self.losses.items():\n                all_loss_types.add(loss_type)\n                loss_dict[loss_type] = []\n                for realities_l, data in loss_item.items():\n                    if reality.name in realities_l:\n                        loss_dict[loss_type] += data\n                        if loss_types is not None and loss_type in loss_types:\n                            losses += data\n\n            path_values = self.compute_paths(graph,\n                paths={\n                    path: self.paths[path] for path in \\\n                    set(path for paths in losses for path in paths)\n                    },\n                reality=reality)\n\n            if reality.name not in self.metrics:\n                self.metrics[reality.name] = defaultdict(list)\n\n            for loss_type, losses in sorted(loss_dict.items()):\n                if loss_type not in (loss_types or all_loss_types):\n                    continue\n                if loss_type not in loss:\n                    loss[loss_type] = 0\n                for path1, path2 in losses:\n                    output_task = self.paths[path1][-1]\n                    compute_mask = 'imagenet(n(x))' != path1\n                    if loss_type not in loss:\n                        loss[loss_type] = 0\n                    for path1, path2 in losses:\n                        output_task = self.paths[path1][-1]\n                        if \"direct\" in loss_type:\n                            with torch.no_grad():\n                                path_loss, _ = output_task.norm(path_values[path1], path_values[path2], batch_mean=reduce, compute_mask=compute_mask, compute_mse=False)\n                                loss[loss_type] += path_loss\n                        else:\n                            path_loss, _ = output_task.norm(path_values[path1], path_values[path2], batch_mean=reduce, compute_mask=compute_mask, compute_mse=False)\n                            loss[loss_type] += path_loss\n       
                     loss_name = \"mae\" if \"mae\" in loss_type else loss_type+\"_mae\"\n                            self.metrics[reality.name][loss_name +\" : \"+path1 + \" -> \" + path2] += [path_loss.mean().detach().cpu()]\n                            path_loss, _ = output_task.norm(path_values[path1], path_values[path2], batch_mean=reduce, compute_mask=compute_mask, compute_mse=True)\n                            loss_name = \"mse\" if \"mae\" in loss_type else loss_type + \"_mse\"\n                            self.metrics[reality.name][loss_name +\" : \"+path1 + \" -> \" + path2] += [path_loss.mean().detach().cpu()]\n\n        return loss\n\n    def logger_hooks(self, logger):\n\n        name_to_realities = defaultdict(list)\n        for loss_type, loss_item in self.losses.items():\n            for realities, losses in loss_item.items():\n                for path1, path2 in losses:\n                    loss_name = \"mae\" if \"mae\" in loss_type else loss_type+\"_mae\"\n                    name = loss_name+\" : \"+path1 + \" -> \" + path2\n                    name_to_realities[name] += list(realities)\n                    loss_name = \"mse\" if \"mae\" in loss_type else loss_type + \"_mse\"\n                    name = loss_name+\" : \"+path1 + \" -> \" + path2\n                    name_to_realities[name] += list(realities)\n\n        for name, realities in name_to_realities.items():\n            def jointplot(logger, data, name=name, realities=realities):\n                names = [f\"{reality}_{name}\" for reality in realities]\n                if not all(x in data for x in names):\n                    return\n                data = np.stack([data[x] for x in names], axis=1)\n                logger.plot(data, name, opts={\"legend\": names})\n\n            logger.add_hook(partial(jointplot, name=name, realities=realities), feature=f\"{realities[-1]}_{name}\", freq=1)\n\n\n    def logger_update(self, logger):\n\n        name_to_realities = defaultdict(list)\n    
    for loss_type, loss_item in self.losses.items():\n            for realities, losses in loss_item.items():\n                for path1, path2 in losses:\n                    loss_name = \"mae\" if \"mae\" in loss_type else loss_type+\"_mae\"\n                    name = loss_name+\" : \"+path1 + \" -> \" + path2\n                    name_to_realities[name] += list(realities)\n                    loss_name = \"mse\" if \"mae\" in loss_type else loss_type + \"_mse\"\n                    name = loss_name+\" : \"+path1 + \" -> \" + path2\n                    name_to_realities[name] += list(realities)\n\n        for name, realities in name_to_realities.items():\n            for reality in realities:\n                # IPython.embed()\n                if reality not in self.metrics: continue\n                if name not in self.metrics[reality]: continue\n                if len(self.metrics[reality][name]) == 0: continue\n\n                logger.update(\n                    f\"{reality}_{name}\",\n                    torch.mean(torch.stack(self.metrics[reality][name])),\n                )\n        self.metrics = {}\n\n    def plot_paths(self, graph, logger, realities=[], plot_names=None, epochs=0, tr_step=0,prefix=\"\"):\n        error_pairs = {\"n(x)\": \"y^\"}\n        realities_map = {reality.name: reality for reality in realities}\n        for name, config in (plot_names or self.plots.items()):\n            paths = config[\"paths\"]\n\n            realities = config[\"realities\"]\n            images = []\n            error = False\n            cmap = get_cmap(\"jet\")\n\n            first = True\n            error_passed_ood = 0\n            for reality in realities:\n                with torch.no_grad():\n                    path_values = self.compute_paths(graph, paths={path: self.paths[path] for path in paths}, reality=realities_map[reality])\n\n                shape = list(path_values[list(path_values.keys())[0]].shape)\n                shape[1] = 3\n\n        
        for i, path in enumerate(paths):\n                    if path == 'depth': continue\n                    X = path_values.get(path, torch.zeros(shape, device=DEVICE))\n                    if first: images +=[[]]\n\n                    if reality is 'ood' and error_passed_ood==0:\n                        images[i].append(X.clamp(min=0, max=1).expand(*shape))\n                    elif reality is 'ood' and error_passed_ood==1:\n                        images[i+1].append(X.clamp(min=0, max=1).expand(*shape))\n                    else:\n                        images[-1].append(X.clamp(min=0, max=1).expand(*shape))\n\n                    if path in error_pairs:\n\n                        error = True\n                        if first:\n                            images += [[]]\n\n\n                    if error:\n\n                        Y = path_values.get(path, torch.zeros(shape, device=DEVICE))\n                        Y_hat = path_values.get(error_pairs[path], torch.zeros(shape, device=DEVICE))\n\n                        out_task = self.paths[path][-1]\n\n                        if self.target_task == \"reshading\": #Use depth mask\n                            Y_mask = path_values.get(\"depth\", torch.zeros(shape, device = DEVICE))\n                            mask_task = self.paths[\"r(x)\"][-1]\n                            mask = ImageTask.build_mask(Y_mask, val=mask_task.mask_val)\n                        else:\n                            mask = ImageTask.build_mask(Y_hat, val=out_task.mask_val)\n\n                        errors = ((Y - Y_hat)**2).mean(dim=1, keepdim=True)\n                        log_errors = torch.log(errors.clamp(min=0, max=out_task.variance))\n\n\n                        errors = (3*errors/(out_task.variance)).clamp(min=0, max=1)\n\n                        log_errors = torch.log(errors + 1)\n                        log_errors = log_errors / log_errors.max()\n                        log_errors = torch.tensor(cmap(log_errors.cpu()))[:, 
0].permute((0, 3, 1, 2)).float()[:, 0:3]\n                        log_errors = log_errors.clamp(min=0, max=1).expand(*shape).to(DEVICE)\n                        log_errors[~mask.expand_as(log_errors)] = 0.505\n                        if reality is 'ood':\n                            images[i+1].append(log_errors)\n                            error_passed_ood = 1\n                        else:\n                            images[-1].append(log_errors)\n\n                        error = False\n                first = False\n\n            for i in range(0, len(images)):\n                images[i] = torch.cat(images[i], dim=0)\n\n            logger.images_grouped(images,\n                f\"{prefix}_{name}_[{', '.join(realities)}]_[{', '.join(paths)}]\",\n                resize=config[\"size\"]\n            )\n\n    def __repr__(self):\n        return str(self.losses)\n\n\nclass WinRateEnergyLoss(EnergyLoss):\n\n    def __init__(self, *args, **kwargs):\n        self.k = kwargs.pop('k', 3)\n        self.random_select = kwargs.pop('random_select', False)\n        self.running_stats = {}\n        self.target_task = kwargs['paths']['y^'][0].name\n\n        super().__init__(*args, **kwargs)\n\n        self.percep_losses = [key[7:] for key in self.losses.keys() if key[0:7] == \"percep_\"]\n        print (\"percep losses:\",self.percep_losses)\n        self.chosen_losses = random.sample(self.percep_losses, self.k)\n\n    def __call__(self, graph, discriminator=None, realities=[], loss_types=None, compute_grad_ratio=False):\n\n        loss_types = [\"mae\"] + [(\"percep_\" + loss) for loss in self.percep_losses] + [(\"direct_\" + loss) for loss in self.percep_losses]\n        # print (self.chosen_losses)\n        loss_dict = super().__call__(graph, discriminator=discriminator, realities=realities, loss_types=loss_types, reduce=False)\n\n        chosen_percep_mse_losses = [k for k in loss_dict.keys() if 'direct' not in k]\n        percep_mse_coeffs = 
dict.fromkeys(chosen_percep_mse_losses, 1.0)\n        ########### to compute loss coefficients #############\n        if compute_grad_ratio:\n            percep_mse_gradnorms = dict.fromkeys(chosen_percep_mse_losses, 1.0)\n            for loss_name in chosen_percep_mse_losses:\n                loss_dict[loss_name].mean().backward(retain_graph=True)\n                target_weights=list(graph.edge_map[f\"('rgb', '{self.target_task}')\"].model.parameters())\n                percep_mse_gradnorms[loss_name] = sum([l.grad.abs().sum().item() for l in target_weights])/sum([l.numel() for l in target_weights])\n                graph.optimizer.zero_grad()\n                graph.zero_grad()\n                del target_weights\n            total_gradnorms = sum(percep_mse_gradnorms.values())\n            n_losses = len(chosen_percep_mse_losses)\n            for loss_name, val in percep_mse_coeffs.items():\n                percep_mse_coeffs[loss_name] = (total_gradnorms-percep_mse_gradnorms[loss_name])/((n_losses-1)*total_gradnorms)\n            percep_mse_coeffs[\"mae\"] *= (n_losses-1)\n        ###########################################\n\n        for key in self.chosen_losses:\n            winrate = torch.mean((loss_dict[f\"percep_{key}\"] > loss_dict[f\"direct_{key}\"]).float())\n            winrate = winrate.detach().cpu().item()\n            if winrate < 1.0:\n                self.running_stats[key] = winrate\n            loss_dict[f\"percep_{key}\"] = loss_dict[f\"percep_{key}\"].mean() * percep_mse_coeffs[f\"percep_{key}\"]\n            loss_dict.pop(f\"direct_{key}\")\n\n        # print (self.running_stats)\n        loss_dict[\"mae\"] = loss_dict[\"mae\"].mean() * percep_mse_coeffs[\"mae\"]\n\n        return loss_dict, percep_mse_coeffs[\"mae\"]\n\n    def logger_update(self, logger):\n        super().logger_update(logger)\n        if self.random_select or len(self.running_stats) < len(self.percep_losses):\n            self.chosen_losses = 
random.sample(self.percep_losses, self.k)\n        else:\n            self.chosen_losses = sorted(self.running_stats, key=self.running_stats.get, reverse=True)[:self.k]\n\n        logger.text (f\"Chosen losses: {self.chosen_losses}\")\n\n\n"
  },
  {
    "path": "graph.py",
    "content": "import os, sys, math, random, itertools, heapq\nfrom collections import namedtuple, defaultdict\nfrom functools import partial, reduce\nimport numpy as np\nimport IPython\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom utils import *\nfrom models import TrainableModel, WrapperModel\nfrom datasets import TaskDataset\nfrom task_configs import get_task, task_map, tasks, get_model, RealityTask\nfrom transfers import Transfer, RealityTransfer, get_transfer_name\n\n#from modules.gan_dis import GanDisNet\n\nimport pdb\n\nclass TaskGraph(TrainableModel):\n    \"\"\"Basic graph that encapsulates set of edge constraints. Can be saved and loaded\n    from directories.\"\"\"\n\n    def __init__(\n        self, tasks=tasks, edges=None, edges_exclude=None,\n        pretrained=True, finetuned=False,\n        reality=[], task_filter=[tasks.segment_semantic],\n        freeze_list=[], lazy=False, initialize_from_transfer=True,\n    ):\n\n        super().__init__()\n        self.tasks = list(set(tasks) - set(task_filter))\n        self.tasks += [task.base for task in self.tasks if hasattr(task, \"base\")]\n        self.edge_list, self.edge_list_exclude = edges, edges_exclude\n        self.pretrained, self.finetuned = pretrained, finetuned\n        self.edges, self.adj, self.in_adj = [], defaultdict(list), defaultdict(list)\n        self.edge_map, self.reality = {}, reality\n        self.initialize_from_transfer = initialize_from_transfer\n        print('Creating graph with tasks:', self.tasks)\n        self.params = {}\n\n        # construct transfer graph\n        for src_task, dest_task in itertools.product(self.tasks, self.tasks):\n            key = (src_task, dest_task)\n            if edges is not None and key not in edges: continue\n            if edges_exclude is not None and key in edges_exclude: continue\n            if src_task == dest_task: continue\n            if isinstance(dest_task, RealityTask): continue\n            # 
print (src_task, dest_task)\n            transfer = None\n            if isinstance(src_task, RealityTask):\n                if dest_task not in src_task.tasks: continue\n                transfer = RealityTransfer(src_task, dest_task)\n            else:\n                transfer = Transfer(src_task, dest_task,\n                    pretrained=pretrained, finetuned=finetuned\n                )\n                transfer.name = get_transfer_name(transfer)\n                if not self.initialize_from_transfer:\n                    transfer.path = None\n            if transfer.model_type is None:\n                continue\n            # print (\"Added transfer\", transfer)\n            self.edges += [transfer]\n            self.adj[src_task.name] += [transfer]\n            self.in_adj[dest_task.name] += [transfer]\n            self.edge_map[str((src_task.name, dest_task.name))] = transfer\n            if isinstance(transfer, nn.Module):\n                if str((src_task.name, dest_task.name)) not in freeze_list:\n                    self.params[str((src_task.name, dest_task.name))] = transfer\n                else:\n                    print(\"Setting link: \" + str((src_task.name, dest_task.name)) + \" not trainable.\")\n                try:\n                    if not lazy: transfer.load_model()\n                except Exception as e:\n                    print(e)\n                    IPython.embed()\n\n        self.params = nn.ModuleDict(self.params)\n\n    def edge(self, src_task, dest_task):\n        key1 = str((src_task.name, dest_task.name))\n        key2 = str((src_task.kind, dest_task.kind))\n        if key1 in self.edge_map: return self.edge_map[key1]\n        return self.edge_map[key2]\n\n    def sample_path(self, path, reality=None, use_cache=False, cache={}):\n        path = [reality or self.reality[0]] + path\n        x = None\n        for i in range(1, len(path)):\n            try:\n                # if x is not None: print (x.shape)\n                # 
print (self.edge(path[i-1], path[i]))\n                x = cache.get(tuple(path[0:(i+1)]),\n                    self.edge(path[i-1], path[i])(x)\n                )\n            except KeyError:\n                return None\n            except Exception as e:\n                print(e)\n                IPython.embed()\n\n            if use_cache: cache[tuple(path[0:(i+1)])] = x\n        return x\n\n    def save(self, weights_file=None, weights_dir=None):\n\n        ### TODO: save optimizers here too\n        if weights_file:\n            torch.save({\n                key: model.state_dict() for key, model in self.edge_map.items() \\\n                if not isinstance(model, RealityTransfer)\n            }, weights_file)\n\n        if weights_dir:\n            os.makedirs(weights_dir, exist_ok=True)\n            for key, model in self.edge_map.items():\n                if isinstance(model, RealityTransfer): continue\n                if not isinstance(model.model, TrainableModel): continue\n                model.model.save(f\"{weights_dir}/{model.name}.pth\")\n            torch.save(self.optimizer, f\"{weights_dir}/optimizer.pth\")\n\n\n#    def load_weights(self, weights_file=None):\n#        for key, state_dict in torch.load(weights_file).items():\n#            if key in self.edge_map:\n#                self.edge_map[key].load_state_dict(state_dict)\n\n    def load_weights(self, weights_file=None):\n        loaded_something = False\n        for key, state_dict in torch.load(weights_file).items():\n            if key in self.edge_map:\n                loaded_something = True\n                self.edge_map[key].load_model()\n                self.edge_map[key].load_state_dict(state_dict)\n        if not loaded_something:\n            raise RuntimeError(f\"No edges loaded from file: {weights_file}\")\n"
  },
  {
    "path": "hooks/build",
    "content": "#!/bin/bash\n\ndocker build . -t $IMAGE_NAME --build-arg GITHUB_DEPLOY_KEY=\"$GITHUB_DEPLOY_KEY\" --build-arg GITHUB_DEPLOY_KEY_PUBLIC=\"$GITHUB_DEPLOY_KEY_PUBLIC\"\n\n"
  },
  {
    "path": "logger.py",
    "content": "\nimport numpy as np\nimport matplotlib as mpl\nmpl.use('Agg')\nimport matplotlib.pyplot as plt\nimport random, sys, os, json, math\n\nimport torch\nfrom torchvision import datasets, transforms, utils\nimport visdom\n\nfrom utils import *\nfrom utils import elapsed\nimport IPython\nimport pdb\n\nclass BaseLogger(object):\n    \"\"\" Logger class, with hooks for data features and plotting functions. \"\"\"\n    def __init__(self, name, verbose=True):\n\n        self.name = name\n        self.data = {}\n        self.running_data = {}\n        self.reset_running = {}\n        self.verbose = verbose\n        self.hooks = []\n\n    def add_hook(self, hook, feature='epoch', freq=40):\n        self.hooks.append((hook, feature, freq))\n\n    def update(self, feature, x):\n        if isinstance(x, torch.Tensor):\n            x = x.clone().detach().cpu().numpy().mean()\n        else:\n            x = torch.tensor(x).data.cpu().numpy().mean()\n\n        self.data[feature] = self.data.get(feature, [])\n        self.data[feature].append(x)\n        if feature not in self.running_data or self.reset_running.pop(feature, False):\n            self.running_data[feature] = []\n        self.running_data[feature].append(x)\n\n        for hook, hook_feature, freq in self.hooks:\n            if feature == hook_feature and len(self.data[feature]) % freq == 0:\n                hook(self, self.data)\n\n    def step(self):\n        buf = \"\"\n        buf += f\"({self.name}) \"\n        for feature in self.running_data.keys():\n            if len(self.running_data[feature]) == 0: continue\n            val = np.mean(self.running_data[feature])\n            if float(val).is_integer():\n                buf += f\"{feature}: {int(val)}, \"\n            else:\n                buf += f\"{feature}: {val:0.4f}\" + \", \"\n            self.reset_running[feature] = True\n        buf += f\" ... 
{elapsed():0.2f} sec\"\n        self.text (buf)\n\n    def text(self, text, end=\"\\n\"):\n        raise NotImplementedError()\n\n    def plot(self, data, plot_name, opts={}):\n        raise NotImplementedError()\n\n    def images(self, data, image_name):\n        raise NotImplementedError()\n\n    def plot_feature(self, feature, opts={}):\n        self.plot(self.data[feature], feature, opts)\n\n    def plot_features(self, features, name, opts={}):\n        stacked = np.stack([self.data[feature] for feature in features], axis=1)\n        self.plot(stacked, name, opts={\"legend\": features})\n\n\nclass Logger(BaseLogger):\n\n    def __init__(self, *args, **kwargs):\n        self.results = kwargs.pop('results', 'output')\n        super().__init__(*args, **kwargs)\n\n    def text(self, text, end='\\n'):\n        print (text, end=end, flush=True)\n\n    def plot(self, data, plot_name, opts={}):\n        np.savez_compressed(f\"{self.results}/{plot_name}.npz\", data)\n        plt.plot(data)\n        plt.savefig(f\"{self.results}/{plot_name}.jpg\");\n        plt.clf()\n\n\nclass VisdomLogger(BaseLogger):\n\n    def __init__(self, *args, **kwargs):\n        self.env = kwargs.pop('env', 'CH')\n        self.port = kwargs.pop('port', 8097)\n        self.server = kwargs.pop('server', '127.0.0.1')\n        self.delete = kwargs.pop('delete', True)\n        print (\"No deletion\")\n        print (\"In (git) scaling-reset\")\n        print (f\"Logging to environment {self.env}\")\n        self.visdom = visdom.Visdom(server=\"http://\" + self.server, port=self.port, env=self.env)\n        self.visdom.delete_env(self.env)\n        self.windows = {}\n        super().__init__(*args, **kwargs)\n        self.save()\n        self.add_hook(lambda logger, data: self.save(), feature=\"epoch\", freq=1)\n\n    def text(self, text, end='\\n'):\n        print (text, end=end)\n        window, old_text = self.windows.get('text', (None, \"\"))\n        if end == '\\n': end = '<br>'\n        
display = old_text + text + end\n\n        if window is not None:\n            window = self.visdom.text (display, win=window, append=False)\n        else:\n            window = self.visdom.text (display)\n\n        self.windows[\"text\"] = window, display\n\n    def window(self, plot_name, plot_func, *args, **kwargs):\n\n        options = {'title': plot_name}\n        options.update(kwargs.pop(\"opts\", {}))\n        window = self.windows.get(plot_name, None)\n        if window is not None and self.visdom.win_exists(window):\n            window = plot_func(*args, **kwargs, opts=options, win=window)\n        else:\n            window = plot_func(*args, **kwargs, opts=options)\n\n        self.windows[plot_name] = window\n\n    def plot(self, data, plot_name, opts={}):\n        self.window(plot_name, self.visdom.line,\n            np.array(data), X=np.array(range(len(data))), opts=opts\n        )\n\n    def histogram(self, data, plot_name, opts={}):\n        self.window(plot_name, self.visdom.histogram, np.array(data), opts=opts)\n\n    def scatter(self, X, Y, plot_name, opts={}):\n        self.window(plot_name, self.visdom.scatter, np.stack([X, Y], axis=1), opts=opts)\n\n    def bar(self, data, plot_name, opts={}):\n        self.window(plot_name, self.visdom.bar, np.array(data), opts=opts)\n\n    def save(self):\n        self.visdom.save([self.env])\n\n    def images(self, data, plot_name, opts={}, nrow=2, normalize=False, resize=64):\n\n        transform = transforms.Compose([\n                                    transforms.ToPILImage(),\n                                    transforms.Resize(resize),\n                                    transforms.ToTensor()])\n        data = torch.stack([transform(x.cpu()) for x in data])\n        data = utils.make_grid(data, nrow=nrow, normalize=normalize, pad_value=0)\n        self.window(plot_name, self.visdom.image, np.array(data), opts=opts)\n\n    def images_grouped(self, image_groups, plot_name, **kwargs):\n        
interleave = [y for x in zip(*image_groups) for y in x]\n        self.images(interleave, plot_name, nrow=len(image_groups), **kwargs)\n\n\n"
  },
  {
    "path": "models.py",
    "content": "import os, sys, random\nfrom inspect import signature\nimport numpy as np\nimport matplotlib as mpl\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom   torch import optim\n\nfrom utils import *\n\n\n\"\"\" Model that implements batchwise training with \"compilation\" and custom loss.\nExposed methods: predict_on_batch(), fit_on_batch(),\nOverridable methods: loss(), forward().\n\"\"\"\n\n\nclass AbstractModel(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.compiled = False\n\n    # Compile module and assign optimizer + params\n    def compile(self, optimizer=None, **kwargs):\n\n        if optimizer is not None:\n            self.optimizer_class = optimizer\n            self.optimizer_kwargs = kwargs\n            self.optimizer = self.optimizer_class(self.parameters(), **self.optimizer_kwargs)\n        else:\n            self.optimizer = None\n\n        self.compiled = True\n        self.to(DEVICE)\n\n    # Predict scores from a batch of data\n    def predict_on_batch(self, data):\n\n        self.eval()\n        with torch.no_grad():\n            return self.forward(data)\n\n    # Fit (make one optimizer step) on a batch of data\n    def fit_on_batch(self, data, target, loss_fn=None, train=True):\n        loss_fn = loss_fn or self.loss\n\n        self.zero_grad()\n        self.optimizer.zero_grad()\n\n        self.train(train)\n\n        self.zero_grad()\n        self.optimizer.zero_grad()\n        pred = self.forward(data)\n        if isinstance(target, list):\n            target = tuple(t.to(pred.device) for t in target)\n        else: target = target.to(pred.device)\n\n        if len(signature(loss_fn).parameters) > 2:\n            loss, metrics = loss_fn(pred, target, data.to(pred.device))\n        else:\n            loss, metrics = loss_fn(pred, target)\n\n        if train:\n            loss.backward()\n            self.optimizer.step()\n            self.zero_grad()\n            
self.optimizer.zero_grad()\n\n        return pred, loss, metrics\n\n    # Make one optimizer step w.r.t a loss\n    def step(self, loss, train=True):\n\n        self.zero_grad()\n        self.optimizer.zero_grad()\n        self.train(train)\n        self.zero_grad()\n        self.optimizer.zero_grad()\n\n        loss.backward()\n        self.optimizer.step()\n        self.zero_grad()\n        self.optimizer.zero_grad()\n\n    @classmethod\n    def load(cls, weights_file=None):\n        model = cls()\n        if weights_file is not None:\n            data = torch.load(weights_file)\n            # hack for models saved with optimizers\n            if \"optimizer\" in data: data = data[\"state_dict\"]\n            model.load_state_dict(data)\n        return model\n\n    def load_weights(self, weights_file, backward_compatible=False):\n        data = torch.load(weights_file)\n        if backward_compatible:\n            data = {'parallel_apply.module.'+k:v for k,v in data.items()}\n        self.load_state_dict(data)\n\n    def save(self, weights_file):\n        torch.save(self.state_dict(), weights_file)\n\n    # Subclasses: override for custom loss + forward functions\n    def loss(self, pred, target):\n        raise NotImplementedError()\n\n    def forward(self, x):\n        raise NotImplementedError()\n\n\n\"\"\" Model that implements training and prediction on generator objects, with\nthe ability to print train and validation metrics.\n\"\"\"\n\n\nclass TrainableModel(AbstractModel):\n    def __init__(self):\n        super().__init__()\n\n    # Fit on generator for one epoch\n    def _process_data(self, datagen, loss_fn=None, train=True, logger=None):\n\n        self.train(train)\n        out = []\n        for data in datagen:\n            batch, y = data[0], data[1:]\n            if len(y) == 1: y = y[0]\n            y_pred, loss, metric_data = self.fit_on_batch(batch, y, loss_fn=loss_fn, train=train)\n            if logger is not None:\n                
logger.update(\"loss\", float(loss))\n            yield ((batch.detach(), y_pred.detach(), y, float(loss), metric_data))\n\n    def fit(self, datagen, loss_fn=None, logger=None):\n        for x in self._process_data(datagen, loss_fn=loss_fn, train=True, logger=logger):\n            pass\n\n    def fit_with_data(self, datagen, loss_fn=None, logger=None):\n        images, preds, targets, losses, metrics = zip(\n            *self._process_data(datagen, loss_fn=loss_fn, train=True, logger=logger)\n        )\n        images, preds, targets = torch.cat(images, dim=0), torch.cat(preds, dim=0), torch.cat(targets, dim=0)\n        metrics = zip(*metrics)\n        return images, preds, targets, losses, metrics\n\n    def fit_with_metrics(self, datagen, loss_fn=None, logger=None):\n        metrics = [\n            metrics\n            for _, _, _, _, metrics in self._process_data(\n                datagen, loss_fn=loss_fn, train=True, logger=logger\n            )\n        ]\n        return list(zip(*metrics))\n\n    def predict_with_data(self, datagen, loss_fn=None, logger=None):\n        images, preds, targets, losses, metrics = zip(\n            *self._process_data(datagen, loss_fn=loss_fn, train=False, logger=logger)\n        )\n        images, preds, targets = torch.cat(images, dim=0), torch.cat(preds, dim=0), torch.cat(targets, dim=0)\n        images, preds, targets = images.cpu(), preds.cpu(), targets.cpu()\n        # preds = torch.cat(preds, dim=0)\n        metrics = zip(*metrics)\n        return images, preds, targets, losses, metrics\n\n    def predict_with_metrics(self, datagen, loss_fn=None, logger=None):\n        metrics = [\n            metrics\n            for _, _, _, _, metrics in self._process_data(\n                datagen, loss_fn=loss_fn, train=False, logger=logger\n            )\n        ]\n        return list(zip(*metrics))\n\n    def predict(self, datagen):\n        preds = [self.predict_on_batch(x) for x in datagen]\n        preds = torch.cat(preds, 
dim=0)\n        return preds\n\n\nclass DataParallelModel(TrainableModel):\n    def __init__(self, *args, **kwargs):\n        super().__init__()\n        self.parallel_apply = nn.DataParallel(*args, **kwargs)\n\n    def forward(self, x):\n        return self.parallel_apply(x)\n\n    def loss(self, x, preds):\n        return self.parallel_apply.module.loss(x, preds)\n\n    @property\n    def module(self):\n        return self.parallel_apply.module\n\n    @classmethod\n    def load(cls, model=TrainableModel(), weights_file=None):\n        model = cls(model)\n        if weights_file is not None:\n            data = torch.load(weights_file, map_location=lambda storage, loc: storage)\n            # hack for models saved with optimizers\n            if \"optimizer\" in data: data = data[\"state_dict\"]\n            try:\n                model.load_state_dict(data)\n            except RuntimeError:\n                parallel_module_data = {\n                    'parallel_apply.module.' + k: v for k, v in data.items()\n                }\n                model.load_state_dict(parallel_module_data)\n        return model\n\nclass WrapperModel(TrainableModel):\n    def __init__(self, model):\n        super().__init__()\n        self.model = model\n\n    def forward(self, x):\n        return self.model(x)\n\n    def loss(self, x, preds):\n        raise NotImplementedError()\n\n    def __getitem__(self, i):\n        return self.model[i]\n\n    @property\n    def module(self):\n        return self.model\n\n\nif __name__ == \"__main__\":\n    import IPython\n    IPython.embed()\n"
  },
  {
    "path": "modules/__init__.py",
    "content": ""
  },
  {
    "path": "modules/depth_nets.py",
    "content": "\nimport os, sys, math, random, itertools\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom torchvision import datasets, transforms, models\nfrom torch.optim.lr_scheduler import MultiStepLR\n\nfrom models import TrainableModel\nfrom utils import *\n\n\nclass ResidualsNet(TrainableModel):\n    def __init__(self):\n        super().__init__()\n\n        self.encoder = nn.Sequential(\n            ConvBlock(3, 32, groups=3, use_groupnorm=False), \n            ConvBlock(32, 32, use_groupnorm=False),\n        )\n        self.mid = nn.Sequential(\n            ConvBlock(32, 64, dilation=1, use_groupnorm=False), \n            ConvBlock(64, 64, dilation=2, use_groupnorm=False),\n            ConvBlock(64, 64, dilation=2, use_groupnorm=False),\n            ConvBlock(64, 32, dilation=4, use_groupnorm=False),\n        )\n        self.decoder = nn.Sequential(\n            ConvBlock(64, 32, use_groupnorm=False),\n            ConvBlock(32, 32, use_groupnorm=False),\n            ConvBlock(32, 1, use_groupnorm=False),\n        )\n\n    def forward(self, x):\n        tmp = self.encoder(x)\n        x = F.max_pool2d(tmp, 4)\n        x = self.mid(x)\n        x = F.upsample(x, scale_factor=4, mode='bilinear')\n        x = torch.cat([x, tmp], dim=1)\n        x = self.decoder(x)\n        return x\n\n    def loss(self, pred, target):\n        loss = torch.tensor(0.0, device=pred.device)\n        return loss, (loss.detach(),)\n\n\nclass UNet_up_block(nn.Module):\n    def __init__(self, prev_channel, input_channel, output_channel, up_sample=True):\n        super().__init__()\n        self.up_sampling = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)\n        self.conv1 = nn.Conv2d(prev_channel + input_channel, output_channel, 3, padding=1)\n        self.bn1 = nn.GroupNorm(8, output_channel)\n        self.conv2 = nn.Conv2d(output_channel, output_channel, 3, padding=1)\n        self.bn2 = nn.GroupNorm(8, 
output_channel)\n        self.conv3 = nn.Conv2d(output_channel, output_channel, 3, padding=1)\n        self.bn3 = nn.GroupNorm(8, output_channel)        \n        self.relu = torch.nn.ReLU()\n        self.up_sample = up_sample\n\n    def forward(self, prev_feature_map, x):\n        if self.up_sample:\n            x = self.up_sampling(x)\n        x = torch.cat((x, prev_feature_map), dim=1)\n        x = self.relu(self.bn1(self.conv1(x)))\n        x = self.relu(self.bn2(self.conv2(x)))\n        x = self.relu(self.bn3(self.conv3(x)))\n        return x\n\n\nclass UNet_down_block(nn.Module):\n    def __init__(self, input_channel, output_channel, down_size=True):\n        super().__init__()\n        self.conv1 = nn.Conv2d(input_channel, output_channel, 3, padding=1)\n        self.bn1 = nn.GroupNorm(8, output_channel)\n        self.conv2 = nn.Conv2d(output_channel, output_channel, 3, padding=1)\n        self.bn2 = nn.GroupNorm(8, output_channel)\n        self.conv3 = nn.Conv2d(output_channel, output_channel, 3, padding=1)\n        self.bn3 = nn.GroupNorm(8, output_channel)\n        self.max_pool = nn.MaxPool2d(2, 2)\n        self.relu = nn.ReLU()\n        self.down_size = down_size\n\n    def forward(self, x):\n\n        x = self.relu(self.bn1(self.conv1(x)))\n        x = self.relu(self.bn2(self.conv2(x)))\n        x = self.relu(self.bn3(self.conv3(x)))\n        if self.down_size:\n            x = self.max_pool(x)\n        return x\n\n\nclass UNetDepth(TrainableModel):\n    def __init__(self):\n        super().__init__()\n\n        self.down_block1 = UNet_down_block(3, 16, False)\n        self.down_block2 = UNet_down_block(16, 32, True)\n        self.down_block3 = UNet_down_block(32, 64, True)\n        self.down_block4 = UNet_down_block(64, 128, True)\n        self.down_block5 = UNet_down_block(128, 256, True)\n        self.down_block6 = UNet_down_block(256, 512, True)\n        self.down_block7 = UNet_down_block(512, 1024, False)\n\n        self.mid_conv1 = nn.Conv2d(1024, 
1024, 3, padding=1)\n        self.bn1 = nn.GroupNorm(8, 1024)\n        self.mid_conv2 = nn.Conv2d(1024, 1024, 3, padding=1)\n        self.bn2 = nn.GroupNorm(8, 1024)\n        self.mid_conv3 = torch.nn.Conv2d(1024, 1024, 3, padding=1)\n        self.bn3 = torch.nn.GroupNorm(8, 1024)\n\n        self.up_block1 = UNet_up_block(512, 1024, 512, False)\n        self.up_block2 = UNet_up_block(256, 512, 256, True)\n        self.up_block3 = UNet_up_block(128, 256, 128, True)\n        self.up_block4 = UNet_up_block(64, 128, 64, True)\n        self.up_block5 = UNet_up_block(32, 64, 32, True)\n        self.up_block6 = UNet_up_block(16, 32, 16, True)\n\n        self.last_conv1 = nn.Conv2d(16, 16, 3, padding=1)\n        self.last_bn = nn.GroupNorm(8, 16)\n        self.last_conv2 = nn.Conv2d(16, 1, 1, padding=0)\n        self.relu = nn.ReLU()\n\n    def forward(self, x):\n        x = self.x1 = self.down_block1(x)\n        x = self.x2 = self.down_block2(self.x1)\n        x = self.x3 = self.down_block3(self.x2)\n        x = self.x4 = self.down_block4(self.x3)\n        x = self.x5 = self.down_block5(self.x4)\n        x = self.x6 = self.down_block6(self.x5)\n        x = self.x7 = self.down_block7(self.x6)\n\n        x = self.relu(self.bn1(self.mid_conv1(x)))\n        x = self.relu(self.bn2(self.mid_conv2(x)))\n        x = self.relu(self.bn3(self.mid_conv3(x)))\n\n        x = self.up_block1(self.x6, x)\n        x = self.up_block2(self.x5, x)\n        x = self.up_block3(self.x4, x)\n        x = self.up_block4(self.x3, x)\n        x = self.up_block5(self.x2, x)\n        x = self.up_block6(self.x1, x)\n        x = self.relu(self.last_bn(self.last_conv1(x)))\n        x = self.last_conv2(x)\n        return x\n\n    def loss(self, pred, target):\n        loss = torch.tensor(0.0, device=pred.device)\n        return loss, (loss.detach(),)\n\n\n\nclass ConvBlock(nn.Module):\n    def __init__(self, f1, f2, use_groupnorm=True, groups=8, dilation=1, transpose=False):\n        super().__init__()\n   
     self.transpose = transpose\n        self.conv = nn.Conv2d(f1, f2, (3, 3), dilation=dilation, padding=dilation)\n        if self.transpose:\n            self.convt = nn.ConvTranspose2d(\n                f1, f1, (3, 3), dilation=dilation, stride=2, padding=dilation, output_padding=1\n            )\n        if use_groupnorm:\n            self.bn = nn.GroupNorm(groups, f1)\n        else:\n            self.bn = nn.GroupNorm(8, f1)\n\n    def forward(self, x):\n        # x = F.dropout(x, 0.04, self.training)\n        x = self.bn(x)\n        if self.transpose:\n            # x = F.upsample(x, scale_factor=2, mode='bilinear')\n            x = F.relu(self.convt(x))\n            # x = x[:, :, :-1, :-1]\n        x = F.relu(self.conv(x))\n        return x\n\nclass Bottleneck(nn.Module):\n    expansion = 4\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None):\n        super(Bottleneck, self).__init__()\n        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)\n        self.bn1 = nn.GroupNorm(8, planes)\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,\n                               padding=1, bias=False)\n        self.bn2 = nn.GroupNorm(8, planes)\n        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)\n        self.bn3 = nn.GroupNorm(8, planes * self.expansion)\n        self.relu = nn.ReLU(inplace=True)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out += residual\n        out = self.relu(out)\n\n        return out\n\nclass ResNetOriginal(nn.Module):\n\n    def 
__init__(self, block, layers, num_classes=1000):\n        self.inplanes = 64\n        super(ResNetOriginal, self).__init__()\n        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,\n                               bias=False)\n        self.bn1 = nn.GroupNorm(8, 64)\n        self.relu = nn.ReLU(inplace=True)\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n        self.layer1 = self._make_layer(block, 64, layers[0])\n        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)\n        self.layer3 = self._make_layer(block, 196, layers[2], stride=2)\n        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)\n        self.avgpool = nn.AvgPool2d(7, stride=1)\n        self.fc = nn.Linear(512 * block.expansion, num_classes)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n            elif isinstance(m, nn.BatchNorm2d):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n\n    def _make_layer(self, block, planes, blocks, stride=1):\n        downsample = None\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                nn.Conv2d(self.inplanes, planes * block.expansion,\n                          kernel_size=1, stride=stride, bias=False),\n                nn.GroupNorm(8, planes * block.expansion),\n            )\n\n        layers = []\n        layers.append(block(self.inplanes, planes, stride, downsample))\n        self.inplanes = planes * block.expansion\n        for i in range(1, blocks):\n            layers.append(block(self.inplanes, planes))\n\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu(x)\n        x = self.maxpool(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n      
  x = self.layer3(x)\n        x = self.layer4(x)\n\n        x = self.avgpool(x)\n        x = x.view(x.size(0), -1)\n        x = self.fc(x)\n\n        return x\n\nclass ResNetDepth(TrainableModel):\n    def __init__(self):\n        super().__init__()\n        # self.resnet = models.resnet50()\n        self.resnet = ResNetOriginal(Bottleneck, [3, 4, 6, 3])\n        self.final_conv = nn.Conv2d(2048, 8, (3, 3), padding=1)\n\n        self.decoder = nn.Sequential(\n            ConvBlock(8, 128),\n            ConvBlock(128, 128),\n            ConvBlock(128, 128),\n            ConvBlock(128, 128),\n            ConvBlock(128, 128),\n            ConvBlock(128, 128, transpose=True),\n            ConvBlock(128, 128, transpose=True),\n            ConvBlock(128, 128, transpose=True),\n            ConvBlock(128, 128, transpose=True),\n            ConvBlock(128, 1, transpose=True),\n        )\n\n    def forward(self, x):\n\n        for layer in list(self.resnet._modules.values())[:-2]:\n            x = layer(x)\n\n        x = self.final_conv(x)\n        x = self.decoder(x)\n\n        return x\n\n    def loss(self, pred, target):\n        loss = torch.tensor(0.0, device=pred.device)\n        return loss, (loss.detach(),)\n\n"
  },
  {
    "path": "modules/percep_nets.py",
    "content": "\n\nimport os, sys, math, random, itertools\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom torchvision import datasets, transforms, models\nfrom torch.optim.lr_scheduler import MultiStepLR\n\nfrom models import TrainableModel\nfrom utils import *\n\n\nclass ConvBlock(nn.Module):\n    def __init__(self, f1, f2, kernel_size=3, padding=1, use_groupnorm=True, groups=8, dilation=1, transpose=False):\n        super().__init__()\n        self.transpose = transpose\n        self.conv = nn.Conv2d(f1, f2, (kernel_size, kernel_size), dilation=dilation, padding=padding*dilation)\n        if self.transpose:\n            self.convt = nn.ConvTranspose2d(\n                f1, f1, (3, 3), dilation=dilation, stride=2, padding=dilation, output_padding=1\n            )\n        if use_groupnorm:\n            self.bn = nn.GroupNorm(groups, f1)\n        else:\n            self.bn = nn.BatchNorm2d(f1)\n\n    def forward(self, x):\n        # x = F.dropout(x, 0.04, self.training)\n        x = self.bn(x)\n        if self.transpose:\n            # x = F.upsample(x, scale_factor=2, mode='bilinear')\n            x = F.relu(self.convt(x))\n            # x = x[:, :, :-1, :-1]\n        x = F.relu(self.conv(x))\n        return x\n\n\nclass DenseNet(TrainableModel):\n    def __init__(self):\n        super().__init__()\n\n        self.decoder = nn.Sequential(\n            ConvBlock(3, 96, groups=3), \n            ConvBlock(96, 96),\n            ConvBlock(96, 96),\n            ConvBlock(96, 96),\n            ConvBlock(96, 3),\n        )\n\n    def forward(self, x):\n        x = self.decoder(x)\n        return x\n\n    def loss(self, pred, target):\n        loss = torch.tensor(0.0, device=pred.device)\n        return loss, (loss.detach(),)\n\n\nclass Dense1by1Net(TrainableModel):\n    def __init__(self):\n        super().__init__()\n\n        self.decoder = nn.Sequential(\n            ConvBlock(3, 64, groups=3, kernel_size=1, 
padding=0), \n            ConvBlock(64, 96, kernel_size=1, padding=0), \n            ConvBlock(96, 96),\n            ConvBlock(96, 96),\n            ConvBlock(96, 96),\n            ConvBlock(96, 96),\n            ConvBlock(96, 3),\n        )\n\n    def forward(self, x):\n        x = self.decoder(x)\n        return x\n\n    def loss(self, pred, target):\n        loss = torch.tensor(0.0, device=pred.device)\n        return loss, (loss.detach(),)\n\nclass Dense1by1end(TrainableModel):\n    def __init__(self):\n        super().__init__()\n\n        self.decoder = nn.Sequential(\n            ConvBlock(3, 64, groups=3, kernel_size=1, padding=0), \n            ConvBlock(64, 96, kernel_size=1, padding=0), \n            ConvBlock(96, 96),\n            ConvBlock(96, 96),\n            ConvBlock(96, 96),\n            ConvBlock(96, 96),\n            ConvBlock(96, 1),\n        )\n\n    def forward(self, x):\n        x = self.decoder(x)\n        return x\n\n    def loss(self, pred, target):\n        loss = torch.tensor(0.0, device=pred.device)\n        return loss, (loss.detach(),)\n\nclass DenseKernelsNet(TrainableModel):\n    def __init__(self, kernel_size=7):\n        super().__init__()\n\n        self.decoder = nn.Sequential(\n            ConvBlock(3, 64, groups=3, kernel_size=1, padding=0), \n            ConvBlock(64, 96, kernel_size=1, padding=0), \n            ConvBlock(96, 96, kernel_size=1, padding=0),\n            ConvBlock(96, 96, kernel_size=kernel_size, padding=kernel_size//2),\n            ConvBlock(96, 96),\n            ConvBlock(96, 96),\n            ConvBlock(96, 96),\n            ConvBlock(96, 3),\n        )\n\n    def forward(self, x):\n        x = self.decoder(x)\n        return x\n\n    def loss(self, pred, target):\n        loss = torch.tensor(0.0, device=pred.device)\n        return loss, (loss.detach(),)\n\n\nclass DeepNet(TrainableModel):\n    def __init__(self):\n        super().__init__()\n\n        self.decoder = nn.Sequential(\n            
ConvBlock(3, 32, groups=3), \n            ConvBlock(32, 32),\n            ConvBlock(32, 32, dilation=2),\n            ConvBlock(32, 32, dilation=2),\n            ConvBlock(32, 32, dilation=4),\n            ConvBlock(32, 32, dilation=4),\n            ConvBlock(32, 3),\n        )\n\n    def forward(self, x):\n        x = self.decoder(x)\n        return x\n\n    def loss(self, pred, target):\n        loss = torch.tensor(0.0, device=pred.device)\n        return loss, (loss.detach(),)\n\n\nclass WideNet(TrainableModel):\n    def __init__(self):\n        super().__init__()\n\n        self.decoder = nn.Sequential(\n            ConvBlock(3, 32, groups=3), \n            ConvBlock(32, 32, kernel_size=5, padding=2),\n            ConvBlock(32, 32, kernel_size=5, padding=2),\n            ConvBlock(32, 32, kernel_size=5, padding=2),\n            ConvBlock(32, 3),\n        )\n\n    def forward(self, x):\n        x = self.decoder(x)\n        return x\n\n    def loss(self, pred, target):\n        loss = torch.tensor(0.0, device=pred.device)\n        return loss, (loss.detach(),)\n\n\nclass PyramidNet(TrainableModel):\n    def __init__(self):\n        super().__init__()\n\n        self.decoder = nn.Sequential(\n            ConvBlock(3, 16, groups=3), \n            ConvBlock(16, 32, kernel_size=5, padding=2),\n            ConvBlock(32, 64, kernel_size=5, padding=2),\n            ConvBlock(64, 96, kernel_size=3, padding=1),\n            ConvBlock(96, 32, kernel_size=3, padding=1),\n            ConvBlock(32, 3),\n        )\n\n    def forward(self, x):\n        x = self.decoder(x)\n        return x\n\n    def loss(self, pred, target):\n        loss = torch.tensor(0.0, device=pred.device)\n        return loss, (loss.detach(),)\n\n\n\nclass BaseNet(TrainableModel):\n    def __init__(self):\n        super().__init__()\n\n        self.decoder = nn.Sequential(\n            ConvBlock(3, 32, use_groupnorm=False), \n            ConvBlock(32, 32, use_groupnorm=False),\n            ConvBlock(32, 
32, use_groupnorm=False),\n            ConvBlock(32, 1, use_groupnorm=False),\n        )\n\n    def forward(self, x):\n        x = self.decoder(x)\n        return x\n\n    def loss(self, pred, target):\n        loss = torch.tensor(0.0, device=pred.device)\n        return loss, (loss.detach(),)\n\n\nclass ResidualsNet(TrainableModel):\n    def __init__(self):\n        super().__init__()\n\n        self.encoder = nn.Sequential(\n            ConvBlock(3, 32, use_groupnorm=False), \n            ConvBlock(32, 32, use_groupnorm=False),\n        )\n        self.mid = nn.Sequential(\n            ConvBlock(32, 64, use_groupnorm=False), \n            ConvBlock(64, 64, use_groupnorm=False),\n            ConvBlock(64, 32, use_groupnorm=False),\n        )\n        self.decoder = nn.Sequential(\n            ConvBlock(64, 32, use_groupnorm=False), \n            ConvBlock(32, 3, use_groupnorm=False),\n        )\n\n    def forward(self, x):\n        tmp = self.encoder(x)\n        x = F.max_pool2d(tmp, 2)\n        x = self.mid(x)\n        x = F.upsample(x, scale_factor=2, mode='bilinear')\n        x = torch.cat([x, tmp], dim=1)\n        x = self.decoder(x)\n        return x\n\n    def loss(self, pred, target):\n        loss = torch.tensor(0.0, device=pred.device)\n        return loss, (loss.detach(),)\n\nclass ResNet50(TrainableModel):\n    def __init__(self, num_classes=365, in_channels=3):\n        super().__init__()\n        self.resnet = models.resnet18(num_classes=num_classes)\n        self.resnet.fc = nn.Linear(in_features=8192, out_features=num_classes, bias=True)\n        self.resnet.conv1 = nn.Conv2d(in_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n\n    def forward(self, x):\n        x = self.resnet(x)\n        return F.log_softmax(x, dim=1)\n\n    def loss(self, pred, target):\n        loss = F.nll_loss(pred, target)\n        return loss, (loss.detach(),)\n\n"
  },
  {
    "path": "modules/resnet.py",
    "content": "\nimport os, sys, math, random, itertools\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom torchvision import datasets, transforms, models\nfrom torch.optim.lr_scheduler import MultiStepLR\n\nfrom models import TrainableModel\nfrom utils import *\n\n\n\ndef conv3x3(in_planes, out_planes, stride=1, groups=1):\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n                     padding=1, groups=groups, bias=False)\n\n\ndef conv1x1(in_planes, out_planes, stride=1):\n    \"\"\"1x1 convolution\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)\n\n\n\nclass BasicBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,\n                 base_width=64, norm_layer=None):\n        super(BasicBlock, self).__init__()\n        if groups != 1 or base_width != 64:\n            raise ValueError('BasicBlock only supports groups=1 and base_width=64')\n        # Both self.conv1 and self.downsample layers downsample the input when stride != 1\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        self.bn1 = nn.GroupNorm(8, planes)\n        self.relu = nn.ReLU(inplace=True)\n        self.conv2 = conv3x3(planes, planes)\n        self.bn2 = nn.GroupNorm(8, planes)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        out = self.relu(out)\n\n        return out\n\nclass Bottleneck(nn.Module):\n    expansion = 4\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None):\n        
super(Bottleneck, self).__init__()\n        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)\n        self.bn1 = nn.GroupNorm(8, planes)\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,\n                               padding=1, bias=False)\n        self.bn2 = nn.GroupNorm(8, planes)\n        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)\n        self.bn3 = nn.GroupNorm(8, planes * self.expansion)\n        self.relu = nn.ReLU(inplace=True)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out += residual\n        out = self.relu(out)\n\n        return out\n\nclass ResNetOriginal(nn.Module):\n\n    def __init__(self, block, layers, in_channels=3, num_classes=1000):\n        self.inplanes = 64\n        super(ResNetOriginal, self).__init__()\n        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3,\n                               bias=False)\n        self.bn1 = nn.GroupNorm(8, 64)\n        self.relu = nn.ReLU(inplace=True)\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n        self.layer1 = self._make_layer(block, 64, layers[0])\n        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)\n        self.layer3 = self._make_layer(block, 256, layers[2], stride=1)\n        self.layer4 = self._make_layer(block, 512, layers[3], stride=1)\n        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n        self.fc = nn.Linear(512 * block.expansion, num_classes)\n\n        for m in self.modules():\n            if isinstance(m, 
nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n            elif isinstance(m, nn.BatchNorm2d):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n\n    def _make_layer(self, block, planes, blocks, stride=1):\n        downsample = None\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                nn.Conv2d(self.inplanes, planes * block.expansion,\n                          kernel_size=1, stride=stride, bias=False),\n                nn.GroupNorm(8, planes * block.expansion),\n            )\n\n        layers = []\n        layers.append(block(self.inplanes, planes, stride, downsample))\n        self.inplanes = planes * block.expansion\n        for i in range(1, blocks):\n            layers.append(block(self.inplanes, planes))\n\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu(x)\n        x = self.maxpool(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n\n        x = self.avgpool(x)\n        x = x.view(x.size(0), -1)\n        x = self.fc(x)\n\n        return x\n\nclass ResNet(TrainableModel):\n    def __init__(self, in_channels=3, out_channels=1000):\n        super().__init__()\n        self.resnet = ResNetOriginal(BasicBlock, [2, 2, 2, 2], in_channels=in_channels)\n        self.final = nn.Linear(512, out_channels)\n\n    def forward(self, x):\n\n        for layer in list(self.resnet._modules.values())[:-2]:\n            x = layer(x)\n        x = F.relu(x.mean(dim=2).mean(dim=2))\n        x = F.log_softmax(self.final(x), dim=1)\n\n        return x\n\n    def loss(self, pred, target):\n        loss = torch.tensor(0.0, device=pred.device)\n        return loss, (loss.detach(),)\n\n\nclass ResNetClass(TrainableModel):\n    def __init__(self):\n        
super().__init__()\n        self.resnet = models.resnet50(pretrained=True)\n\n    def forward(self, x):\n        if x.shape[1] == 1: x = x.repeat(1,3,1,1)\n        for layer in list(self.resnet._modules.values())[:-2]:\n            x = layer(x)\n        return x\n\n    ### Not in Use Right Now ###\n    def loss(self, pred, target):\n        mask = build_mask(pred, val=0.502)\n        mse = F.mse_loss(pred[mask], target[mask])\n        return mse, (mse.detach(),)\n\n\nif __name__ == \"__main__\":\n    model = ResNet(out_channels=365)\n    print (model(torch.randn(2, 3, 224, 224 )).shape)\n\n"
  },
  {
    "path": "modules/unet.py",
    "content": "\nimport os, sys, math, random, itertools\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom torchvision import datasets, transforms, models\nfrom torch.optim.lr_scheduler import MultiStepLR\nfrom torch.utils.checkpoint import checkpoint\n\nfrom models import TrainableModel\nfrom utils import *\n\n\nclass UNet_up_block(nn.Module):\n    def __init__(self, prev_channel, input_channel, output_channel, up_sample=True):\n        super().__init__()\n        self.up_sampling = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)\n        self.conv1 = nn.Conv2d(prev_channel + input_channel, output_channel, 3, padding=1)\n        self.bn1 = nn.GroupNorm(8, output_channel)\n        self.conv2 = nn.Conv2d(output_channel, output_channel, 3, padding=1)\n        self.bn2 = nn.GroupNorm(8, output_channel)\n        self.conv3 = nn.Conv2d(output_channel, output_channel, 3, padding=1)\n        self.bn3 = nn.GroupNorm(8, output_channel)        \n        self.relu = torch.nn.ReLU()\n        self.up_sample = up_sample\n\n    def forward(self, prev_feature_map, x):\n        if self.up_sample:\n            x = self.up_sampling(x)\n        x = torch.cat((x, prev_feature_map), dim=1)\n        x = self.relu(self.bn1(self.conv1(x)))\n        x = self.relu(self.bn2(self.conv2(x)))\n        x = self.relu(self.bn3(self.conv3(x)))\n        return x\n\n\nclass UNet_down_block(nn.Module):\n    def __init__(self, input_channel, output_channel, down_size=True):\n        super().__init__()\n        self.conv1 = nn.Conv2d(input_channel, output_channel, 3, padding=1)\n        self.bn1 = nn.GroupNorm(8, output_channel)\n        self.conv2 = nn.Conv2d(output_channel, output_channel, 3, padding=1)\n        self.bn2 = nn.GroupNorm(8, output_channel)\n        self.conv3 = nn.Conv2d(output_channel, output_channel, 3, padding=1)\n        self.bn3 = nn.GroupNorm(8, output_channel)\n        self.max_pool = nn.MaxPool2d(2, 2)\n        
self.relu = nn.ReLU()\n        self.down_size = down_size\n\n    def forward(self, x):\n        x = self.relu(self.bn1(self.conv1(x)))\n        x = self.relu(self.bn2(self.conv2(x)))\n        x = self.relu(self.bn3(self.conv3(x)))\n        if self.down_size:\n            x = self.max_pool(x)\n        return x\n\n\nclass UNet(TrainableModel):\n    def __init__(self,  downsample=6, in_channels=3, out_channels=3):\n        super().__init__()\n\n        self.in_channels, self.out_channels, self.downsample = in_channels, out_channels, downsample\n        self.down1 = UNet_down_block(in_channels, 16, False)\n        self.down_blocks = nn.ModuleList(\n            [UNet_down_block(2**(4+i), 2**(5+i), True) for i in range(0, downsample)]\n        )\n\n        bottleneck = 2**(4 + downsample)\n        self.mid_conv1 = nn.Conv2d(bottleneck, bottleneck, 3, padding=1)\n        self.bn1 = nn.GroupNorm(8, bottleneck)\n        self.mid_conv2 = nn.Conv2d(bottleneck, bottleneck, 3, padding=1)\n        self.bn2 = nn.GroupNorm(8, bottleneck)\n        self.mid_conv3 = torch.nn.Conv2d(bottleneck, bottleneck, 3, padding=1)\n        self.bn3 = nn.GroupNorm(8, bottleneck)\n\n        self.up_blocks = nn.ModuleList(\n            [UNet_up_block(2**(4+i), 2**(5+i), 2**(4+i)) for i in range(0, downsample)]\n        )\n\n        self.last_conv1 = nn.Conv2d(16, 16, 3, padding=1)\n        self.last_bn = nn.GroupNorm(8, 16)\n        self.last_conv2 = nn.Conv2d(16, out_channels, 1, padding=0)\n        self.relu = nn.ReLU()\n\n    def forward(self, x):\n        x = self.down1(x)\n        xvals = [x]\n        for i in range(0, self.downsample):\n            x = self.down_blocks[i](x)\n            xvals.append(x)\n\n        x = self.relu(self.bn1(self.mid_conv1(x)))\n        x = self.relu(self.bn2(self.mid_conv2(x)))\n        x = self.relu(self.bn3(self.mid_conv3(x)))\n\n        for i in range(0, self.downsample)[::-1]:\n            x = self.up_blocks[i](xvals[i], x)\n\n        x = 
self.relu(self.last_bn(self.last_conv1(x)))\n        x = self.relu(self.last_conv2(x))\n        return x\n\n    def loss(self, pred, target):\n        loss = torch.tensor(0.0, device=pred.device)\n        return loss, (loss.detach(),)\n\nclass UNetReshade(TrainableModel):\n    def __init__(self,  downsample=6, in_channels=3, out_channels=3):\n        super().__init__()\n\n        self.in_channels, self.out_channels, self.downsample = in_channels, out_channels, downsample\n        self.down1 = UNet_down_block(in_channels, 16, False)\n        self.down_blocks = nn.ModuleList(\n            [UNet_down_block(2**(4+i), 2**(5+i), True) for i in range(0, downsample)]\n        )\n\n        bottleneck = 2**(4 + downsample)\n        self.mid_conv1 = nn.Conv2d(bottleneck, bottleneck, 3, padding=1)\n        self.bn1 = nn.GroupNorm(8, bottleneck)\n        self.mid_conv2 = nn.Conv2d(bottleneck, bottleneck, 3, padding=1)\n        self.bn2 = nn.GroupNorm(8, bottleneck)\n        self.mid_conv3 = torch.nn.Conv2d(bottleneck, bottleneck, 3, padding=1)\n        self.bn3 = nn.GroupNorm(8, bottleneck)\n\n        self.up_blocks = nn.ModuleList(\n            [UNet_up_block(2**(4+i), 2**(5+i), 2**(4+i)) for i in range(0, downsample)]\n        )\n\n        self.last_conv1 = nn.Conv2d(16, 16, 3, padding=1)\n        self.last_bn = nn.GroupNorm(8, 16)\n        self.last_conv2 = nn.Conv2d(16, out_channels, 1, padding=0)\n        self.relu = nn.ReLU()\n\n    def forward(self, x):\n        x = self.down1(x)\n        xvals = [x]\n        for i in range(0, self.downsample):\n            x = self.down_blocks[i](x)\n            xvals.append(x)\n\n        x = self.relu(self.bn1(self.mid_conv1(x)))\n        x = self.relu(self.bn2(self.mid_conv2(x)))\n        x = self.relu(self.bn3(self.mid_conv3(x)))\n\n        for i in range(0, self.downsample)[::-1]:\n            x = self.up_blocks[i](xvals[i], x)\n\n        x = self.relu(self.last_bn(self.last_conv1(x)))\n        x = self.relu(self.last_conv2(x))\n    
    x = x.clamp(max=1, min=0).mean(dim=1, keepdim=True)\n        x = x.expand(-1, 3, -1, -1)\n        return x\n\n    def loss(self, pred, target):\n        loss = torch.tensor(0.0, device=pred.device)\n        return loss, (loss.detach(),)\n\n\nclass UNetOld(TrainableModel):\n    def __init__(self, in_channels=3, out_channels=3):\n        super().__init__()\n\n        self.in_channels, self.out_channels = in_channels, out_channels\n        self.down_block1 = UNet_down_block(in_channels, 16, False) #   256\n        self.down_block2 = UNet_down_block(16, 32, True) #   128\n        self.down_block3 = UNet_down_block(32, 64, True) #   64\n        self.down_block4 = UNet_down_block(64, 128, True) #  32\n        self.down_block5 = UNet_down_block(128, 256, True) # 16\n        self.down_block6 = UNet_down_block(256, 512, True) # 8\n        self.down_block7 = UNet_down_block(512, 1024, True)# 4 \n        \n        self.mid_conv1 = nn.Conv2d(1024, 1024, 3, padding=1)\n        self.bn1 = nn.GroupNorm(8, 1024)\n        self.mid_conv2 = nn.Conv2d(1024, 1024, 3, padding=1)\n        self.bn2 = nn.GroupNorm(8, 1024)\n        self.mid_conv3 = torch.nn.Conv2d(1024, 1024, 3, padding=1)\n        self.bn3 = nn.GroupNorm(8, 1024)\n\n        self.up_block1 = UNet_up_block(512, 1024, 512)\n        self.up_block2 = UNet_up_block(256, 512, 256)\n        self.up_block3 = UNet_up_block(128, 256, 128)\n        self.up_block4 = UNet_up_block(64, 128, 64)\n        self.up_block5 = UNet_up_block(32, 64, 32)\n        self.up_block6 = UNet_up_block(16, 32, 16)\n\n        self.last_conv1 = nn.Conv2d(16, 16, 3, padding=1)\n        self.last_bn = nn.GroupNorm(8, 16)\n        self.last_conv2 = nn.Conv2d(16, out_channels, 1, padding=0)\n        self.relu = nn.ReLU()\n\n    def forward(self, x):\n        self.x1 = self.down_block1(x)\n        self.x2 = self.down_block2(self.x1)\n        self.x3 = self.down_block3(self.x2)\n        self.x4 = self.down_block4(self.x3)\n        self.x5 = 
self.down_block5(self.x4)\n        self.x6 = self.down_block6(self.x5)\n        self.x7 = self.down_block7(self.x6)\n\n        self.x7 = self.relu(self.bn1(self.mid_conv1(self.x7)))\n        self.x7 = self.relu(self.bn2(self.mid_conv2(self.x7)))\n        self.x7 = self.relu(self.bn3(self.mid_conv3(self.x7)))\n\n        x = self.up_block1(self.x6, self.x7)\n        x = self.up_block2(self.x5, x)\n        x = self.up_block3(self.x4, x)\n        x = self.up_block4(self.x3, x)\n        x = self.up_block5(self.x2, x)\n        x = self.up_block6(self.x1, x)\n        x = self.relu(self.last_bn(self.last_conv1(x)))\n        x = self.relu(self.last_conv2(x))\n        return x\n\n    def loss(self, pred, target):\n        loss = torch.tensor(0.0, device=pred.device)\n        return loss, (loss.detach(),)\n\n\nclass ConvBlock(nn.Module):\n    def __init__(self, f1, f2, kernel_size=3, padding=1, use_groupnorm=True, groups=8, dilation=1, transpose=False):\n        super().__init__()\n        self.transpose = transpose\n        self.conv = nn.Conv2d(f1, f2, (kernel_size, kernel_size), dilation=dilation, padding=padding*dilation)\n        if self.transpose:\n            self.convt = nn.ConvTranspose2d(\n                f1, f1, (3, 3), dilation=dilation, stride=2, padding=dilation, output_padding=1\n            )\n        if use_groupnorm:\n            self.bn = nn.GroupNorm(groups, f1)\n        else:\n            self.bn = nn.BatchNorm2d(f1)\n\n    def forward(self, x):\n        # x = F.dropout(x, 0.04, self.training)\n        x = self.bn(x)\n        if self.transpose:\n            # x = F.upsample(x, scale_factor=2, mode='bilinear')\n            x = F.relu(self.convt(x))\n            # x = x[:, :, :-1, :-1]\n        x = F.relu(self.conv(x))\n        return x\n\nclass UNetOld2(TrainableModel):\n    def __init__(self, in_channels=3, out_channels=3):\n        super().__init__()\n\n        self.in_channels, self.out_channels = in_channels, out_channels\n        self.initial = 
nn.Sequential(\n            ConvBlock(in_channels, 16, groups=3, kernel_size=1, padding=0),\n            ConvBlock(16, 16, groups=4, kernel_size=1, padding=0)\n        )\n        self.down_block1 = UNet_down_block(16, 16, False)\n        self.down_block2 = UNet_down_block(16, 32, True) #   128\n        self.down_block3 = UNet_down_block(32, 64, True) #   64\n        self.down_block4 = UNet_down_block(64, 128, True) #  32\n        self.down_block5 = UNet_down_block(128, 256, True) # 16\n        self.down_block6 = UNet_down_block(256, 512, True) # 8\n        self.down_block7 = UNet_down_block(512, 1024, True)# 4 \n        \n        self.mid_conv1 = nn.Conv2d(1024, 1024, 3, padding=1)\n        self.bn1 = nn.GroupNorm(8, 1024)\n        self.mid_conv2 = nn.Conv2d(1024, 1024, 3, padding=1)\n        self.bn2 = nn.GroupNorm(8, 1024)\n        self.mid_conv3 = torch.nn.Conv2d(1024, 1024, 3, padding=1)\n        self.bn3 = nn.GroupNorm(8, 1024)\n\n        self.up_block1 = UNet_up_block(512, 1024, 512)\n        self.up_block2 = UNet_up_block(256, 512, 256)\n        self.up_block3 = UNet_up_block(128, 256, 128)\n        self.up_block4 = UNet_up_block(64, 128, 64)\n        self.up_block5 = UNet_up_block(32, 64, 32)\n        self.up_block6 = UNet_up_block(16, 32, 16)\n\n        self.last_conv1 = nn.Conv2d(16, 16, 3, padding=1)\n        self.last_bn = nn.GroupNorm(8, 16)\n        self.last_conv2 = nn.Conv2d(16, out_channels, 1, padding=0)\n        self.relu = nn.ReLU()\n\n    def forward(self, x):\n        x = self.initial(x)\n        self.x1 = self.down_block1(x)\n        self.x2 = self.down_block2(self.x1)\n        self.x3 = self.down_block3(self.x2)\n        self.x4 = self.down_block4(self.x3)\n        self.x5 = self.down_block5(self.x4)\n        self.x6 = self.down_block6(self.x5)\n        self.x7 = self.down_block7(self.x6)\n\n        self.x7 = self.relu(self.bn1(self.mid_conv1(self.x7)))\n        self.x7 = self.relu(self.bn2(self.mid_conv2(self.x7)))\n        self.x7 = 
self.relu(self.bn3(self.mid_conv3(self.x7)))\n\n        x = self.up_block1(self.x6, self.x7)\n        x = self.up_block2(self.x5, x)\n        x = self.up_block3(self.x4, x)\n        x = self.up_block4(self.x3, x)\n        x = self.up_block5(self.x2, x)\n        x = self.up_block6(self.x1, x)\n        x = self.relu(self.last_bn(self.last_conv1(x)))\n        x = self.relu(self.last_conv2(x))\n        return x\n\n    def loss(self, pred, target):\n        loss = torch.tensor(0.0, device=pred.device)\n        return loss, (loss.detach(),)\n        \n"
  },
  {
    "path": "modules/unet_mirrored.py",
    "content": "\nimport os, sys, math, random, itertools\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom torchvision import datasets, transforms, models\nfrom torch.optim.lr_scheduler import MultiStepLR\nfrom torch.utils.checkpoint import checkpoint\n\nfrom models import TrainableModel\nfrom utils import *\n\n\nclass UNet_up_block(nn.Module):\n    def __init__(self, prev_channel, input_channel, output_channel, up_sample=True):\n        super().__init__()\n        self.up_sampling = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)\n        self.conv1 = nn.Conv2d(prev_channel + input_channel, output_channel, 3)\n        self.bn1 = nn.GroupNorm(8, output_channel)\n        self.conv2 = nn.Conv2d(output_channel, output_channel, 3)\n        self.bn2 = nn.GroupNorm(8, output_channel)\n        self.conv3 = nn.Conv2d(output_channel, output_channel, 3)\n        self.bn3 = nn.GroupNorm(8, output_channel)        \n        self.relu = torch.nn.ReLU()\n        self.up_sample = up_sample\n\n    def forward(self, prev_feature_map, x):\n        if self.up_sample:\n            x = self.up_sampling(x)\n        x = torch.cat((x, prev_feature_map), dim=1)\n        x = self.relu(self.bn1(F.pad(self.conv1(x), (1, 1, 1, 1), mode='reflect')))\n        x = self.relu(self.bn2(F.pad(self.conv2(x), (1, 1, 1, 1), mode='reflect')))\n        x = self.relu(self.bn3(F.pad(self.conv3(x), (1, 1, 1, 1), mode='reflect')))\n        return x\n\n\nclass UNet_down_block(nn.Module):\n    def __init__(self, input_channel, output_channel, down_size=True):\n        super().__init__()\n        self.conv1 = nn.Conv2d(input_channel, output_channel, 3)\n        self.bn1 = nn.GroupNorm(8, output_channel)\n        self.conv2 = nn.Conv2d(output_channel, output_channel, 3)\n        self.bn2 = nn.GroupNorm(8, output_channel)\n        self.conv3 = nn.Conv2d(output_channel, output_channel, 3)\n        self.bn3 = nn.GroupNorm(8, output_channel)\n        
self.max_pool = nn.MaxPool2d(2, 2)\n        self.relu = nn.ReLU()\n        self.down_size = down_size\n\n    def forward(self, x):\n        x = self.relu(self.bn1(F.pad(self.conv1(x), (1, 1, 1, 1), mode='reflect')))\n        x = self.relu(self.bn2(F.pad(self.conv2(x), (1, 1, 1, 1), mode='reflect')))\n        x = self.relu(self.bn3(F.pad(self.conv3(x), (1, 1, 1, 1), mode='reflect')))\n        if self.down_size:\n            x = self.max_pool(x)\n        return x\n\n\nclass UNet(TrainableModel):\n    def __init__(self, downsample=6, in_channels=3, out_channels=3):\n        super().__init__()\n\n        self.in_channels, self.out_channels, self.downsample = in_channels, out_channels, downsample\n        self.down1 = UNet_down_block(in_channels, 16, False)\n        self.down_blocks = nn.ModuleList(\n            [UNet_down_block(2**(4+i), 2**(5+i), True) for i in range(0, downsample)]\n        )\n\n        bottleneck = 2**(4 + downsample)\n        self.mid_conv1 = nn.Conv2d(bottleneck, bottleneck, 3, padding=1)\n        self.bn1 = nn.GroupNorm(8, bottleneck)\n        self.mid_conv2 = nn.Conv2d(bottleneck, bottleneck, 3, padding=1)\n        self.bn2 = nn.GroupNorm(8, bottleneck)\n        self.mid_conv3 = torch.nn.Conv2d(bottleneck, bottleneck, 3, padding=1)\n        self.bn3 = nn.GroupNorm(8, bottleneck)\n\n        self.up_blocks = nn.ModuleList(\n            [UNet_up_block(2**(4+i), 2**(5+i), 2**(4+i)) for i in range(0, downsample)]\n        )\n\n        self.last_conv1 = nn.Conv2d(16, 16, 3, padding=1)\n        self.last_bn = nn.GroupNorm(8, 16)\n        self.last_conv2 = nn.Conv2d(16, out_channels, 1, padding=0)\n        self.relu = nn.ReLU()\n\n    def forward(self, x):\n        x = self.down1(x)\n        xvals = [x]\n        pad = [False for i in range(0, self.downsample)]\n        for i in range(0, self.downsample):\n            if x.shape[2] % 2 != 0: \n                x = F.pad(x, (1, 0, 1, 0))\n                pad[i] = True\n            x = 
self.down_blocks[i](x)\n            xvals.append(x)\n\n        x = self.relu(self.bn1(self.mid_conv1(x)))\n        x = self.relu(self.bn2(self.mid_conv2(x)))\n        x = self.relu(self.bn3(self.mid_conv3(x)))\n\n        for i in range(0, self.downsample)[::-1]:\n            # print (x.shape, xvals[i].shape)\n            x = self.up_blocks[i](xvals[i], x)\n            if pad[i] != 0: \n                x = x[:, :, 1:, 1:]\n\n        x = self.relu(self.last_bn(self.last_conv1(x)))\n        x = self.relu(self.last_conv2(x))\n        return x\n\n    def loss(self, pred, target):\n        loss = torch.tensor(0.0, device=pred.device)\n        return loss, (loss.detach(),)\n\nclass UNetReshade(TrainableModel):\n    def __init__(self,  downsample=6, in_channels=3, out_channels=3):\n        super().__init__()\n\n        self.in_channels, self.out_channels, self.downsample = in_channels, out_channels, downsample\n        self.down1 = UNet_down_block(in_channels, 16, False)\n        self.down_blocks = nn.ModuleList(\n            [UNet_down_block(2**(4+i), 2**(5+i), True) for i in range(0, downsample)]\n        )\n\n        bottleneck = 2**(4 + downsample)\n        self.mid_conv1 = nn.Conv2d(bottleneck, bottleneck, 3, padding=1)\n        self.bn1 = nn.GroupNorm(8, bottleneck)\n        self.mid_conv2 = nn.Conv2d(bottleneck, bottleneck, 3, padding=1)\n        self.bn2 = nn.GroupNorm(8, bottleneck)\n        self.mid_conv3 = torch.nn.Conv2d(bottleneck, bottleneck, 3, padding=1)\n        self.bn3 = nn.GroupNorm(8, bottleneck)\n\n        self.up_blocks = nn.ModuleList(\n            [UNet_up_block(2**(4+i), 2**(5+i), 2**(4+i)) for i in range(0, downsample)]\n        )\n\n        self.last_conv1 = nn.Conv2d(16, 16, 3, padding=1)\n        self.last_bn = nn.GroupNorm(8, 16)\n        self.last_conv2 = nn.Conv2d(16, out_channels, 1, padding=0)\n        self.relu = nn.ReLU()\n\n    def forward(self, x):\n        x = self.down1(x)\n        xvals = [x]\n        for i in range(0, 
self.downsample):\n            x = self.down_blocks[i](x)\n            xvals.append(x)\n\n        x = self.relu(self.bn1(self.mid_conv1(x)))\n        x = self.relu(self.bn2(self.mid_conv2(x)))\n        x = self.relu(self.bn3(self.mid_conv3(x)))\n\n        for i in range(0, self.downsample)[::-1]:\n            x = self.up_blocks[i](xvals[i], x)\n\n        x = self.relu(self.last_bn(self.last_conv1(x)))\n        x = self.relu(self.last_conv2(x))\n        x = x.clamp(max=1, min=0).mean(dim=1, keepdim=True)\n        x = x.expand(-1, 3, -1, -1)\n        return x\n\n    def loss(self, pred, target):\n        loss = torch.tensor(0.0, device=pred.device)\n        return loss, (loss.detach(),)\n\n\nclass UNetOld(TrainableModel):\n    def __init__(self, in_channels=3, out_channels=3):\n        super().__init__()\n\n        self.in_channels, self.out_channels = in_channels, out_channels\n        self.down_block1 = UNet_down_block(in_channels, 16, False) #   256\n        self.down_block2 = UNet_down_block(16, 32, True) #   128\n        self.down_block3 = UNet_down_block(32, 64, True) #   64\n        self.down_block4 = UNet_down_block(64, 128, True) #  32\n        self.down_block5 = UNet_down_block(128, 256, True) # 16\n        self.down_block6 = UNet_down_block(256, 512, True) # 8\n        self.down_block7 = UNet_down_block(512, 1024, True)# 4 \n        \n        self.mid_conv1 = nn.Conv2d(1024, 1024, 3, padding=1)\n        self.bn1 = nn.GroupNorm(8, 1024)\n        self.mid_conv2 = nn.Conv2d(1024, 1024, 3, padding=1)\n        self.bn2 = nn.GroupNorm(8, 1024)\n        self.mid_conv3 = torch.nn.Conv2d(1024, 1024, 3, padding=1)\n        self.bn3 = nn.GroupNorm(8, 1024)\n\n        self.up_block1 = UNet_up_block(512, 1024, 512)\n        self.up_block2 = UNet_up_block(256, 512, 256)\n        self.up_block3 = UNet_up_block(128, 256, 128)\n        self.up_block4 = UNet_up_block(64, 128, 64)\n        self.up_block5 = UNet_up_block(32, 64, 32)\n        self.up_block6 = UNet_up_block(16, 
32, 16)\n\n        self.last_conv1 = nn.Conv2d(16, 16, 3, padding=1)\n        self.last_bn = nn.GroupNorm(8, 16)\n        self.last_conv2 = nn.Conv2d(16, out_channels, 1, padding=0)\n        self.relu = nn.ReLU()\n\n    def forward(self, x):\n        self.x1 = self.down_block1(x)\n        self.x2 = self.down_block2(self.x1)\n        self.x3 = self.down_block3(self.x2)\n        self.x4 = self.down_block4(self.x3)\n        self.x5 = self.down_block5(self.x4)\n        self.x6 = self.down_block6(self.x5)\n        self.x7 = self.down_block7(self.x6)\n\n        self.x7 = self.relu(self.bn1(self.mid_conv1(self.x7)))\n        self.x7 = self.relu(self.bn2(self.mid_conv2(self.x7)))\n        self.x7 = self.relu(self.bn3(self.mid_conv3(self.x7)))\n\n        x = self.up_block1(self.x6, self.x7)\n        x = self.up_block2(self.x5, x)\n        x = self.up_block3(self.x4, x)\n        x = self.up_block4(self.x3, x)\n        x = self.up_block5(self.x2, x)\n        x = self.up_block6(self.x1, x)\n        x = self.relu(self.last_bn(self.last_conv1(x)))\n        x = self.relu(self.last_conv2(x))\n        return x\n\n    def loss(self, pred, target):\n        loss = torch.tensor(0.0, device=pred.device)\n        return loss, (loss.detach(),)\n\n\nclass ConvBlock(nn.Module):\n    def __init__(self, f1, f2, kernel_size=3, padding=1, use_groupnorm=True, groups=8, dilation=1, transpose=False):\n        super().__init__()\n        self.transpose = transpose\n        self.conv = nn.Conv2d(f1, f2, (kernel_size, kernel_size), dilation=dilation, padding=padding*dilation)\n        if self.transpose:\n            self.convt = nn.ConvTranspose2d(\n                f1, f1, (3, 3), dilation=dilation, stride=2, padding=dilation, output_padding=1\n            )\n        if use_groupnorm:\n            self.bn = nn.GroupNorm(groups, f1)\n        else:\n            self.bn = nn.BatchNorm2d(f1)\n\n    def forward(self, x):\n        # x = F.dropout(x, 0.04, self.training)\n        x = self.bn(x)\n        if 
self.transpose:\n            # x = F.upsample(x, scale_factor=2, mode='bilinear')\n            x = F.relu(self.convt(x))\n            # x = x[:, :, :-1, :-1]\n        x = F.relu(self.conv(x))\n        return x\n\nclass UNetOld2(TrainableModel):\n    def __init__(self, in_channels=3, out_channels=3):\n        super().__init__()\n\n        self.in_channels, self.out_channels = in_channels, out_channels\n        self.initial = nn.Sequential(\n            ConvBlock(in_channels, 16, groups=3, kernel_size=1, padding=0),\n            ConvBlock(16, 16, groups=4, kernel_size=1, padding=0)\n        )\n        self.down_block1 = UNet_down_block(16, 16, False)\n        self.down_block2 = UNet_down_block(16, 32, True) #   128\n        self.down_block3 = UNet_down_block(32, 64, True) #   64\n        self.down_block4 = UNet_down_block(64, 128, True) #  32\n        self.down_block5 = UNet_down_block(128, 256, True) # 16\n        self.down_block6 = UNet_down_block(256, 512, True) # 8\n        self.down_block7 = UNet_down_block(512, 1024, True)# 4 \n        \n        self.mid_conv1 = nn.Conv2d(1024, 1024, 3, padding=1)\n        self.bn1 = nn.GroupNorm(8, 1024)\n        self.mid_conv2 = nn.Conv2d(1024, 1024, 3, padding=1)\n        self.bn2 = nn.GroupNorm(8, 1024)\n        self.mid_conv3 = torch.nn.Conv2d(1024, 1024, 3, padding=1)\n        self.bn3 = nn.GroupNorm(8, 1024)\n\n        self.up_block1 = UNet_up_block(512, 1024, 512)\n        self.up_block2 = UNet_up_block(256, 512, 256)\n        self.up_block3 = UNet_up_block(128, 256, 128)\n        self.up_block4 = UNet_up_block(64, 128, 64)\n        self.up_block5 = UNet_up_block(32, 64, 32)\n        self.up_block6 = UNet_up_block(16, 32, 16)\n\n        self.last_conv1 = nn.Conv2d(16, 16, 3, padding=1)\n        self.last_bn = nn.GroupNorm(8, 16)\n        self.last_conv2 = nn.Conv2d(16, out_channels, 1, padding=0)\n        self.relu = nn.ReLU()\n\n    def forward(self, x):\n        x = self.initial(x)\n        self.x1 = 
self.down_block1(x)\n        self.x2 = self.down_block2(self.x1)\n        self.x3 = self.down_block3(self.x2)\n        self.x4 = self.down_block4(self.x3)\n        self.x5 = self.down_block5(self.x4)\n        self.x6 = self.down_block6(self.x5)\n        self.x7 = self.down_block7(self.x6)\n\n        self.x7 = self.relu(self.bn1(self.mid_conv1(self.x7)))\n        self.x7 = self.relu(self.bn2(self.mid_conv2(self.x7)))\n        self.x7 = self.relu(self.bn3(self.mid_conv3(self.x7)))\n\n        x = self.up_block1(self.x6, self.x7)\n        x = self.up_block2(self.x5, x)\n        x = self.up_block3(self.x4, x)\n        x = self.up_block4(self.x3, x)\n        x = self.up_block5(self.x2, x)\n        x = self.up_block6(self.x1, x)\n        x = self.relu(self.last_bn(self.last_conv1(x)))\n        x = self.relu(self.last_conv2(x))\n        return x\n\n    def loss(self, pred, target):\n        loss = torch.tensor(0.0, device=pred.device)\n        return loss, (loss.detach(),)\n\n\nif __name__ == \"__main__\":\n\n    model = UNet()\n    x = torch.randn(1, 3, 256, 256)\n    print (model(x).shape)\n\n        \n"
  },
  {
    "path": "plotting.py",
    "content": "import numpy as np\n\ndef jointplot(logger, data, loss_type=\"mse_loss\"):\n    data = np.stack((data[f\"train_{loss_type}\"], data[f\"val_{loss_type}\"]), axis=1)\n    logger.plot(data, loss_type, opts={\"legend\": [f\"train_{loss_type}\", f\"val_{loss_type}\"]})\n\ndef get_running_means_w_std_bounds_and_legend_on_diff_prev_time_step(list_of_list_values):\n    running_mean_and_std_bounds = []\n    legend = [\"Mean-STD\", \"Mean Difference\", \"Mean+STD\"]\n    for ii, losses_in_batch_ii in enumerate(list_of_list_values):\n        if ii == 0:  # there's no previous time step to compare to\n            running_mean_and_std_bounds.append([0, 0, 0])\n        else:\n            loss_diffs = [loss_val - list_of_list_values[ii - 1][jj]\n                          for jj, loss_val in enumerate(losses_in_batch_ii)]\n            mean = np.mean(loss_diffs)\n            std = np.std(loss_diffs)\n\n            running_mean_and_std_bounds.append([mean - std, mean, mean + std])\n\n    return running_mean_and_std_bounds, legend\n\ndef get_running_means_w_std_bounds_and_legend(list_of_list_values):\n    running_mean_and_std_bounds = []\n    legend = [\"Mean-STD\", \"Mean\", \"Mean+STD\"]\n    for ii in range(len(list_of_list_values)):\n        mean = np.mean(list_of_list_values[ii])\n        std = np.std(list_of_list_values[ii])\n\n        running_mean_and_std_bounds.append([mean - std, mean, mean + std])\n\n    return running_mean_and_std_bounds, legend\n\n\ndef get_running_std(list_of_list_values):\n    return [np.std(list_of_list_values[ii]) for ii in range(len(list_of_list_values))]\n\n\ndef get_running_p_coeffs(list_of_list_values_1, list_of_list_values_2):\n    assert len(list_of_list_values_1) == len(list_of_list_values_2)\n\n    pearson_coefficients = []\n    for ii in range(len(list_of_list_values_1)):\n        cov = np.cov(np.stack((list_of_list_values_1[ii],\n                               list_of_list_values_2[ii]), axis=0))[0, 1]\n        std1 = 
np.std(list_of_list_values_1[ii])\n        std2 = np.std(list_of_list_values_2[ii])\n        correlation_coefficient = cov / (std1 * std2)\n\n        pearson_coefficients.append(correlation_coefficient)\n\n    return pearson_coefficients\n\ndef mseplots(data, logger):\n    data = np.stack((logger.data[\"train_mse_loss\"], logger.data[\"val_mse_loss\"]), axis=1)\n    logger.plot(data, \"mse_loss\", opts={\"legend\": [\"train_mse\", \"val_mse\"]})\n\n    running_mean_and_std_bounds, legend = get_running_means_w_std_bounds_and_legend(logger.data[\"val_mse_losses\"])\n    logger.plot(running_mean_and_std_bounds, \"val_mse_loss_running_mean\", opts={\"legend\": legend})\n    logger.plot(get_running_std(logger.data[\"val_mse_losses\"]), \"val_mse_losses_running_stds\",\n                opts={\"legend\": ['STD']})\n\n    running_mean_and_std_bounds_diff_prev_time_step, legend = \\\n        get_running_means_w_std_bounds_and_legend_on_diff_prev_time_step(logger.data[\"val_mse_losses\"])\n    logger.plot(running_mean_and_std_bounds_diff_prev_time_step, \"val_mse_loss_diff_prev_step_running_mean\",\n                opts={\"legend\": legend})\n\n\ndef curvatureplots(data, logger):\n    data = np.stack((logger.data[\"train_curvature_loss\"], logger.data[\"val_curvature_loss\"]), axis=1)\n    logger.plot(data, \"curvature_loss\", opts={\"legend\": [\"train_curvature\", \"val_curvature\"]})\n\n    running_mean_and_std_bounds, legend = get_running_means_w_std_bounds_and_legend(\n        logger.data[\"val_curvature_losses\"])\n    logger.plot(running_mean_and_std_bounds, \"val_curvature_loss_running_mean\", opts={\"legend\": legend})\n    logger.plot(get_running_std(logger.data[\"val_curvature_losses\"]), \"val_curvature_losses_running_stds\",\n                opts={\"legend\": ['STD']})\n\n    running_mean_and_std_bounds_diff_prev_time_step, legend = \\\n        get_running_means_w_std_bounds_and_legend_on_diff_prev_time_step(logger.data[\"val_curvature_losses\"])\n    
logger.plot(running_mean_and_std_bounds_diff_prev_time_step, \"val_curvature_loss_diff_prev_step_running_mean\",\n                opts={\"legend\": legend})\n\n\ndef depthplots(data, logger):\n    data = np.stack((logger.data[\"train_depth_loss\"], logger.data[\"val_depth_loss\"]), axis=1)\n    logger.plot(data, \"depth_loss\", opts={\"legend\": [\"train_depth\", \"val_depth\"]})\n\n    running_mean_and_std_bounds, legend = get_running_means_w_std_bounds_and_legend(logger.data[\"val_depth_losses\"])\n    logger.plot(running_mean_and_std_bounds, \"val_depth_loss_running_mean\", opts={\"legend\": legend})\n    logger.plot(get_running_std(logger.data[\"val_depth_losses\"]), \"val_depth_losses_running_stds\",\n                opts={\"legend\": ['STD']})\n\n    running_mean_and_std_bounds_diff_prev_time_step, legend = \\\n        get_running_means_w_std_bounds_and_legend_on_diff_prev_time_step(logger.data[\"val_depth_losses\"])\n    logger.plot(running_mean_and_std_bounds_diff_prev_time_step, \"val_depth_loss_diff_prev_step_running_mean\",\n                opts={\"legend\": legend})\n\n\ndef covarianceplot(data, logger):\n    covs = get_running_p_coeffs(logger.data[\"val_mse_losses\"], logger.data[\"val_curvature_losses\"])\n    logger.plot(covs, \"val_mse_curvature_running_pearson_coeffs\", opts={\"legend\": ['Pearson Coefficient']})\n\n    covs = get_running_p_coeffs(logger.data[\"val_mse_losses\"], logger.data[\"val_depth_losses\"])\n    logger.plot(covs, \"val_mse_depth_running_pearson_coeffs\", opts={\"legend\": ['Pearson Coefficient']})\n\n    covs = get_running_p_coeffs(logger.data[\"val_curvature_losses\"], logger.data[\"val_depth_losses\"])\n    logger.plot(covs, \"train_curvature_depth_running_pearson_coeffs\", opts={\"legend\": ['Pearson Coefficient']})\n\n    ratio_mse_curv_stds = [mse_std / curv_std for mse_std, curv_std in\n                           zip(get_running_std(logger.data[\"val_mse_losses\"]),\n                               
get_running_std(logger.data[\"val_curvature_losses\"]))]\n    logger.plot(ratio_mse_curv_stds, \"val_mse_over_curvature_stds\", opts={\"legend\": ['MSE / Curvature STD']})\n\n    ratio_mse_depth_stds = [mse_std / depth_std for mse_std, depth_std in\n                            zip(get_running_std(logger.data[\"val_mse_losses\"]),\n                                get_running_std(logger.data[\"val_depth_losses\"]))]\n    logger.plot(ratio_mse_depth_stds, \"val_mse_over_depth_stds\", opts={\"legend\": ['MSE / Depth STD']})\n"
  },
  {
    "path": "requirements.txt",
    "content": "fire==0.2.1\nipython==6.5.0\nmatplotlib==3.0.3\nnumpy==1.17.2\nparse==1.12.1\npip==19.3.1\nplac==0.9.6\npy==1.6.0\nscipy==1.3.1\nscikit-image==0.16.2\nscikit-learn==0.22.1\ntorch==1.2.0\ntorchvision==0.4.0\ntqdm==4.36.1\nvisdom==0.1.8.9\npathlib==1.0.1\npyyaml==5.3.1\npandas\nseaborn\nstatsmodels\n"
  },
  {
    "path": "scripts/energy_calc.py",
    "content": "import os, sys, math, random, itertools\nimport numpy as np\nimport scipy\nfrom collections import defaultdict\nfrom tqdm import tqdm\nimport pandas as pd\nimport matplotlib\nmatplotlib.use('agg')\nimport matplotlib.pyplot as plt\nimport seaborn as sns\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom utils import *\nfrom plotting import *\nfrom energy import get_energy_loss\nfrom graph import TaskGraph\nfrom datasets import TaskDataset, load_train_val, load_test, load_ood, ImageDataset\nfrom task_configs import tasks, RealityTask\n\nfrom functools import partial\nfrom fire import Fire\n\nimport IPython\nimport pdb\nfrom modules.unet import UNet\n\ndef main(\n    loss_config=\"conservative_full\",\n    mode=\"standard\",\n    pretrained=True, finetuned=False, batch_size=16,\n    ood_batch_size=None, subset_size=None,\n    cont=None,\n    use_l1=True, num_workers=32, data_dir=None, save_dir='mount/shared/', **kwargs,\n):\n\n    # CONFIG\n    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)\n\n    if data_dir is None:\n        buildings = [\"almena\", \"albertville\"]\n        train_subset_dataset = TaskDataset(buildings, tasks=[tasks.rgb, tasks.normal, tasks.principal_curvature])\n    else:\n        train_subset_dataset = ImageDataset(data_dir=data_dir)\n        data_dir = 'CUSTOM'\n\n    train_subset = RealityTask(\"train_subset\", train_subset_dataset, batch_size=batch_size, shuffle=False)\n\n    if subset_size is None:\n        subset_size = len(train_subset_dataset)\n    subset_size = min(subset_size, len(train_subset_dataset))\n\n    # GRAPH\n    realities = [train_subset]\n    edges = []\n    for t in energy_loss.tasks:\n        if t != tasks.rgb:\n            edges.append((tasks.rgb, t))\n            edges.append((tasks.rgb, tasks.normal))\n\n\n    graph = TaskGraph(tasks=energy_loss.tasks + [train_subset],\n                      finetuned=finetuned,\n                      
freeze_list=energy_loss.freeze_list, lazy=True,\n                      initialize_from_transfer=True,\n                      )\n\n    # print('file', cont)\n    #graph.load_weights(cont)\n    graph.compile(optimizer=None)\n\n    # Add consistency links\n    for target in ['reshading', 'depth_zbuffer', 'normal']:\n        graph.edge_map[str(('rgb', target))].path = None\n        graph.edge_map[str(('rgb', target))].load_model()\n    graph.edge_map[str(('rgb', 'reshading'))].model.load_weights('./models/rgb2reshading_consistency.pth',backward_compatible=True)\n    graph.edge_map[str(('rgb', 'depth_zbuffer'))].model.load_weights('./models/rgb2depth_consistency.pth',backward_compatible=True)\n    graph.edge_map[str(('rgb', 'normal'))].model.load_weights('./models/rgb2normal_consistency.pth',backward_compatible=True)\n\n    energy_losses, mse_losses = [], []\n    percep_losses = defaultdict(list)\n\n    energy_mean_by_blur, energy_std_by_blur = [], []\n    error_mean_by_blur, error_std_by_blur = [], []\n\n    energy_losses, error_losses = [], []\n\n    energy_losses_all, energy_losses_headings = [], []\n\n    fnames = []\n    train_subset.reload()\n    # Compute energies\n    for epochs in tqdm(range(subset_size // batch_size)):\n        with torch.no_grad():\n            losses = energy_loss(graph, realities=[train_subset], reduce=False, use_l1=use_l1)\n\n            if len(energy_losses_headings) == 0:\n                energy_losses_headings = sorted([loss_name for loss_name in losses if 'percep' in loss_name])\n\n            all_perceps = [losses[loss_name].cpu().numpy() for loss_name in energy_losses_headings]\n            direct_losses = [losses[loss_name].cpu().numpy() for loss_name in losses if 'direct' in loss_name]\n\n            if len(all_perceps) > 0:\n                energy_losses_all += [all_perceps]\n                all_perceps = np.stack(all_perceps)\n                energy_losses += list(all_perceps.mean(0))\n\n            if len(direct_losses) > 0:\n   
             direct_losses = np.stack(direct_losses)\n                error_losses += list(direct_losses.mean(0))\n\n            if False:\n                fnames += train_subset.task_data[tasks.filename]\n        train_subset.step()\n\n\n    # log losses\n    if len(energy_losses) > 0:\n        energy_losses = np.array(energy_losses)\n        print(f'energy = {energy_losses.mean()}')\n\n        energy_mean_by_blur += [energy_losses.mean()]\n        energy_std_by_blur += [np.std(energy_losses)]\n\n    if len(error_losses) > 0:\n        error_losses = np.array(error_losses)\n        print(f'error = {error_losses.mean()}')\n\n        error_mean_by_blur += [error_losses.mean()]\n        error_std_by_blur += [np.std(error_losses)]\n\n    # save to csv\n    save_error_losses = error_losses if len(error_losses) > 0 else [0] * subset_size\n    save_energy_losses = energy_losses if len(energy_losses) > 0 else [0] * subset_size\n\n    z_score = lambda x: (x - x.mean()) / x.std()\n    def get_standardized_energy(df, use_std=False, compare_to_in_domain=False):\n        percepts = [c for c in df.columns if 'percep' in c]\n        stdize = lambda x: (x - x.mean()).abs().mean()\n        means = {k: df[k].mean() for k in percepts}\n        stds = {k: stdize(df[k]) for k in percepts}\n        stdized = {k: (df[k] - means[k])/stds[k] for k in percepts}\n        energies = np.stack([v for k, v in stdized.items() if k[-1] == '_' or '__' in k]).mean(0)\n        return energies\n\n\n    os.makedirs(save_dir, exist_ok=True)\n    if data_dir is 'CUSTOM':\n        eng_curr = np.array(energy_losses).mean()\n        df = pd.read_csv(os.path.join(save_dir, 'data.csv'))\n    else:\n        percep_losses = { k: v for k, v in zip(energy_losses_headings, np.concatenate(energy_losses_all, axis=-1))}\n        df = pd.DataFrame(both(\n                        {'energy': save_energy_losses, 'error': save_error_losses },\n                        percep_losses\n        ))\n\n    # compuate 
correlation\n    df['normalized_energy'] = get_standardized_energy(df, use_std=False)\n    df['normalized_error'] = z_score(df['error'])\n    print(scipy.stats.spearmanr(z_score(df['error']), df['normalized_energy']))\n    print(\"Pearson r:\", scipy.stats.pearsonr(df['error'], df['normalized_energy']))\n\n    if data_dir is not 'CUSTOM':\n        df.to_csv(f\"{save_dir}/data.csv\", mode='w', header=True)\n\n    # plot correlation\n    plt.figure(figsize=(4,4))\n    g = sns.regplot(df['normalized_error'], df['normalized_energy'],robust=False)\n    if data_dir is 'CUSTOM':\n        ax1 = g.axes\n        ax1.axhline(eng_curr, ls='--', color='red')\n        ax1.text(0.5, 25, \"Query Image Energy Line\")\n    plt.xlabel('Error (z-score)')\n    plt.ylabel('Energy (z-score)')\n    plt.title('')\n    plt.savefig(f'{save_dir}/energy.pdf')\n\n\n\nif __name__ == \"__main__\":\n    Fire(main)\n"
  },
  {
    "path": "scripts/jobinfo.txt",
    "content": "CH_lbp_all_winrate_depthtarget_1, 0, /scratch-data\n"
  },
  {
    "path": "task_configs.py",
    "content": "\nimport numpy as np\nimport random, sys, os, time, glob, math, itertools, json, copy\nfrom collections import defaultdict, namedtuple\nfrom functools import partial\n\nimport PIL\nfrom PIL import Image\nfrom scipy import ndimage\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchvision.transforms.functional as TF\nimport torch.optim as optim\nfrom torchvision import transforms\n\nfrom utils import *\nfrom models import DataParallelModel\nfrom modules.unet import UNet, UNetOld2, UNetOld\nfrom modules.percep_nets import Dense1by1Net\nfrom modules.depth_nets import UNetDepth\nfrom modules.resnet import ResNetClass\nimport IPython\n\nfrom PIL import ImageFilter\nfrom skimage.filters import gaussian\n\n\nclass GaussianBulr(object):\n    def __init__(self, radius):\n        self.radius = radius\n        self.filter = ImageFilter.GaussianBlur(radius)\n\n    def __call__(self, im):\n        return im.filter(self.filter)\n\n    def __repr__(self):\n        return 'GaussianBulr Filter with Radius {:d}'.format(self.radius)\n\n\n\"\"\" Model definitions for launching new transfer jobs between tasks. 
\"\"\"\nmodel_types = {\n    ('normal', 'principal_curvature'): lambda : Dense1by1Net(),\n    ('normal', 'depth_zbuffer'): lambda : UNetDepth(),\n    ('normal', 'reshading'): lambda : UNet(downsample=5),\n    ('depth_zbuffer', 'normal'): lambda : UNet(downsample=6, in_channels=1, out_channels=3),\n    ('reshading', 'normal'): lambda : UNet(downsample=4, in_channels=3, out_channels=3),\n    ('sobel_edges', 'principal_curvature'): lambda : UNet(downsample=5, in_channels=1, out_channels=3),\n    ('depth_zbuffer', 'principal_curvature'): lambda : UNet(downsample=4, in_channels=1, out_channels=3),\n    ('principal_curvature', 'depth_zbuffer'): lambda : UNet(downsample=6, in_channels=3, out_channels=1),\n    ('rgb', 'normal'): lambda : UNet(downsample=6),\n    ('rgb', 'keypoints2d'): lambda : UNet(downsample=3, out_channels=1),\n}\n\ndef get_model(src_task, dest_task):\n\n    if isinstance(src_task, str) and isinstance(dest_task, str):\n        src_task, dest_task = get_task(src_task), get_task(dest_task)\n\n    if (src_task.name, dest_task.name) in model_types:\n        return model_types[(src_task.name, dest_task.name)]()\n\n    elif isinstance(src_task, ImageTask) and isinstance(dest_task, ImageTask):\n        return UNet(downsample=5, in_channels=src_task.shape[0], out_channels=dest_task.shape[0])\n\n    elif isinstance(src_task, ImageTask) and isinstance(dest_task, ClassTask):\n        return ResNet(in_channels=src_task.shape[0], out_channels=dest_task.classes)\n\n    elif isinstance(src_task, ImageTask) and isinstance(dest_task, PointInfoTask):\n        return ResNet(out_channels=dest_task.out_channels)\n\n    return None\n\n\n\n\"\"\"\nAbstract task type definitions.\nIncludes Task, ImageTask, ClassTask, PointInfoTask, and SegmentationTask.\n\"\"\"\n\nclass Task(object):\n    variances = {\n        \"normal\": 1.0,\n        \"principal_curvature\": 1.0,\n        \"sobel_edges\": 5,\n        \"depth_zbuffer\": 0.1,\n        \"reshading\": 1.0,\n        
\"keypoints2d\": 0.3,\n        \"keypoints3d\": 0.6,\n        \"edge_occlusion\": 0.1,\n    }\n    \"\"\" General task output space\"\"\"\n    def __init__(self, name,\n            file_name=None, file_name_alt=None, file_ext=\"png\", file_loader=None,\n            plot_func=None\n        ):\n\n        super().__init__()\n        self.name = name\n        self.file_name, self.file_ext = file_name or name, file_ext\n        self.file_name_alt = file_name_alt or self.file_name\n        self.file_loader = file_loader or self.file_loader\n        self.plot_func = plot_func or self.plot_func\n        self.variance = Task.variances.get(name, 1.0)\n        self.kind = name\n\n    def norm(self, pred, target, batch_mean=True, compute_mse=True):\n        if batch_mean:\n            loss = ((pred - target)**2).mean() if compute_mse else ((pred - target).abs()).mean()\n        else:\n            loss = ((pred - target)**2).mean(dim=1).mean(dim=1).mean(dim=1) if compute_mse \\\n                    else ((pred - target).abs()).mean(dim=1).mean(dim=1).mean(dim=1)\n\n        return loss, (loss.mean().detach(),)\n\n    def __call__(self, size=256):\n        task = copy.deepcopy(self)\n        return task\n\n    def plot_func(self, data, name, logger, **kwargs):\n        ### Non-image tasks cannot be easily plotted, default to nothing\n        pass\n\n    def file_loader(self, path, resize=None, seed=0, T=0):\n        raise NotImplementedError()\n\n    def __eq__(self, other):\n        return self.name == other.name\n\n    def __repr__(self):\n        return self.name\n\n    def __hash__(self):\n        return hash(self.name)\n\n\n\"\"\"\nAbstract task type definitions.\nIncludes Task, ImageTask, ClassTask, PointInfoTask, and SegmentationTask.\n\"\"\"\n\nclass RealityTask(Task):\n    \"\"\" General task output space\"\"\"\n\n    def __init__(self, name, dataset, tasks=None, use_dataset=True, shuffle=False, batch_size=64):\n\n        super().__init__(name=name)\n        self.tasks = 
tasks if tasks is not None else \\\n            (dataset.dataset.tasks if hasattr(dataset, \"dataset\") else dataset.tasks)\n        self.shape = (1,)\n        if not use_dataset: return\n        self.dataset, self.shuffle, self.batch_size = dataset, shuffle, batch_size\n        loader = torch.utils.data.DataLoader(\n            self.dataset, batch_size=self.batch_size,\n            num_workers=0, shuffle=self.shuffle, pin_memory=True\n        )\n        self.generator = cycle(loader)\n        self.step()\n        self.static = False\n\n    @classmethod\n    def from_dataloader(cls, name, loader, tasks):\n        reality = cls(name, None, tasks, use_dataset=False)\n        reality.loader = loader\n        reality.generator = cycle(loader)\n        reality.static = False\n        reality.step()\n        return reality\n\n    @classmethod\n    def from_static(cls, name, data, tasks):\n        reality = cls(name, None, tasks, use_dataset=False)\n        reality.task_data = {task: x.requires_grad_() for task, x in zip(tasks, data)}\n        reality.static = True\n        return reality\n\n    def norm(self, pred, target, batch_mean=True):\n        loss = torch.tensor(0.0, device=pred.device)\n        return loss, (loss.detach(),)\n\n    def step(self):\n        self.task_data = {task: x.requires_grad_() for task, x in zip(self.tasks, next(self.generator))}\n\n    def reload(self):\n        loader = torch.utils.data.DataLoader(\n            self.dataset, batch_size=self.batch_size,\n            num_workers=0, shuffle=self.shuffle, pin_memory=True\n        )\n        self.generator = cycle(loader)\n\nclass ImageTask(Task):\n    \"\"\" Output space for image-style tasks \"\"\"\n\n    def __init__(self, *args, **kwargs):\n\n        self.shape = kwargs.pop(\"shape\", (3, 256, 256))\n        self.mask_val = kwargs.pop(\"mask_val\", -1.0)\n        self.transform = kwargs.pop(\"transform\", lambda x: x)\n        self.resize = kwargs.pop(\"resize\", self.shape[1])\n        
self.blur_radius = None\n        self.image_transform = self.load_image_transform()\n        super().__init__(*args, **kwargs)\n\n    @staticmethod\n    def build_mask(target, val=0.0, tol=1e-3):\n        if target.shape[1] == 1:\n            mask = ((target >= val - tol) & (target <= val + tol))\n            mask = F.conv2d(mask.float(), torch.ones(1, 1, 5, 5, device=mask.device), padding=2) != 0\n            return (~mask).expand_as(target)\n\n        mask1 = (target[:, 0, :, :] >= val - tol) & (target[:, 0, :, :] <= val + tol)\n        mask2 = (target[:, 1, :, :] >= val - tol) & (target[:, 1, :, :] <= val + tol)\n        mask3 = (target[:, 2, :, :] >= val - tol) & (target[:, 2, :, :] <= val + tol)\n        mask = (mask1 & mask2 & mask3).unsqueeze(1)\n        mask = F.conv2d(mask.float(), torch.ones(1, 1, 5, 5, device=mask.device), padding=2) != 0\n        return (~mask).expand_as(target)\n\n    def norm(self, pred, target, batch_mean=True, compute_mask=0, compute_mse=True):\n        if compute_mask:\n            mask = ImageTask.build_mask(target, val=self.mask_val)\n            return super().norm(pred*mask.float(), target*mask.float(), batch_mean=batch_mean, compute_mse=compute_mse)\n        else:\n            return super().norm(pred, target, batch_mean=batch_mean, compute_mse=compute_mse)\n\n    def __call__(self, size=256, blur_radius=None):\n        task = copy.deepcopy(self)\n        task.shape = (3, size, size)\n        task.resize = size\n        task.blur_radius = blur_radius\n        task.name +=  \"blur\" if blur_radius else str(size)\n        task.base = self\n        return task\n\n    def plot_func(self, data, name, logger, resize=None, nrow=2):\n        logger.images(data.clamp(min=0, max=1), name, nrow=nrow, resize=resize or self.resize)\n\n    def file_loader(self, path, resize=None, crop=None, seed=0, jitter=False):\n        image_transform = self.load_image_transform(resize=resize, crop=crop, seed=seed, jitter=jitter)\n        return 
image_transform(Image.open(open(path, 'rb')))[0:3]\n\n    def load_image_transform(self, resize=None, crop=None, seed=0, jitter=False):\n        size = resize or self.resize\n        random.seed(seed)\n        jitter_transform = lambda x: x\n        if jitter: jitter_transform = transforms.ColorJitter(0.4,0.4,0.4,0.1)\n        crop_transform = lambda x: x\n        if crop is not None: crop_transform = transforms.CenterCrop(crop)\n        blur = [GaussianBulr(self.blur_radius)] if self.blur_radius else []\n        return transforms.Compose(blur+[\n            crop_transform,\n            transforms.Resize(size, interpolation=PIL.Image.BILINEAR),\n            jitter_transform,\n            transforms.CenterCrop(size),\n            transforms.ToTensor(),\n            self.transform]\n        )\n\nclass ImageClassTask(ImageTask):\n    \"\"\" Output space for image-class segmentation tasks \"\"\"\n\n    def __init__(self, *args, **kwargs):\n\n        self.classes = kwargs.pop(\"classes\", (3, 256, 256))\n        super().__init__(*args, **kwargs)\n\n    def norm(self, pred, target):\n        loss = F.kl_div(F.log_softmax(pred, dim=1), F.softmax(target, dim=1))\n        return loss, (loss.detach(),)\n\n    def plot_func(self, data, name, logger, resize=None):\n        _, idx = torch.max(data, dim=1)\n        idx = idx.float()/16.0\n        idx = idx.unsqueeze(1).expand(-1, 3, -1, -1)\n        logger.images(idx.clamp(min=0, max=1), name, nrow=2, resize=resize or self.resize)\n\n    def file_loader(self, path, resize=None):\n\n        data = (self.image_transform(Image.open(open(path, 'rb')))*255.0).long()\n        one_hot = torch.zeros((self.classes, data.shape[1], data.shape[2]))\n        one_hot = one_hot.scatter_(0, data, 1)\n        return one_hot\n\n\nclass PointInfoTask(Task):\n    \"\"\" Output space for point-info prediction tasks (what models do we evem use?) 
\"\"\"\n\n    def __init__(self, *args, **kwargs):\n\n        self.point_type = kwargs.pop(\"point_type\", \"vanishing_points_gaussian_sphere\")\n        self.out_channels = 9\n        super().__init__(*args, **kwargs)\n\n    def plot_func(self, data, name, logger):\n        logger.window(name, logger.visdom.text, str(data.data.cpu().numpy()))\n\n    def file_loader(self, path, resize=None):\n        points = json.load(open(path))[self.point_type]\n        return np.array(points['x'] + points['y'] + points['z'])\n\n\n\n\n\"\"\"\nCurrent list of task definitions.\nAccessible via tasks.{TASK_NAME} or get_task(\"{TASK_NAME}\")\n\"\"\"\n\ndef clamp_maximum_transform(x, max_val=8000.0):\n    x = x.unsqueeze(0).float() / max_val\n    return x[0].clamp(min=0, max=1)\n\ndef crop_transform(x, max_val=8000.0):\n    x = x.unsqueeze(0).float() / max_val\n    return x[0].clamp(min=0, max=1)\n\ndef sobel_transform(x):\n    image = x.data.cpu().numpy().mean(axis=0)\n    blur = ndimage.filters.gaussian_filter(image, sigma=2, )\n    sx = ndimage.sobel(blur, axis=0, mode='constant')\n    sy = ndimage.sobel(blur, axis=1, mode='constant')\n    sob = np.hypot(sx, sy)\n    edge = torch.FloatTensor(sob).unsqueeze(0)\n    return edge\n\ndef blur_transform(x, max_val=4000.0):\n    if x.shape[0] == 1:\n        x = x.squeeze(0)\n    image = x.data.cpu().numpy()\n    blur = ndimage.filters.gaussian_filter(image, sigma=2, )\n    norm = torch.FloatTensor(blur).unsqueeze(0)**0.8 / (max_val**0.8)\n    norm = norm.clamp(min=0, max=1)\n    if norm.shape[0] != 1:\n        norm = norm.unsqueeze(0)\n    return norm\n\ndef get_task(task_name):\n    return task_map[task_name]\n\n\ntasks = [\n    ImageTask('rgb'),\n    ImageTask('imagenet', mask_val=0.0),\n    ImageTask('normal', mask_val=0.502),\n    ImageTask('principal_curvature', mask_val=0.0),\n    ImageTask('depth_zbuffer',\n        shape=(1, 256, 256),\n        mask_val=1.0,\n        transform=partial(clamp_maximum_transform, max_val=8000.0),\n    
),\n    ImageClassTask('segment_semantic',\n        file_name_alt=\"segmentsemantic\",\n        shape=(16, 256, 256), classes=16,\n    ),\n    ImageTask('reshading', mask_val=0.0507),\n    ImageTask('edge_occlusion',\n        shape=(1, 256, 256),\n        transform=partial(blur_transform, max_val=4000.0),\n    ),\n    ImageTask('sobel_edges',\n        shape=(1, 256, 256),\n        file_name='rgb',\n        transform=sobel_transform,\n    ),\n    ImageTask('keypoints3d',\n        shape=(1, 256, 256),\n        transform=partial(clamp_maximum_transform, max_val=64131.0),\n    ),\n    ImageTask('keypoints2d',\n        shape=(1, 256, 256),\n        transform=partial(blur_transform, max_val=2000.0),\n    ),\n]\n\n\ntask_map = {task.name: task for task in tasks}\ntasks = namedtuple('TaskMap', task_map.keys())(**task_map)\n\n\nif __name__ == \"__main__\":\n    IPython.embed()\n"
  },
  {
    "path": "tools/download_data.sh",
    "content": "##!/usr/bin/env bash\n\nwget https://drive.switch.ch/index.php/s/0Fqr6t6cZsI0cp9/download\nunzip download\nrm download\ncd data \nunzip -qqo albertville_rgb.zip\nunzip -qqo albertville_normal.zip\nunzip -qqo albertville_principal_curvature.zip\nunzip -qqo almena_rgb.zip\nunzip -qqo almena_normal.zip\nunzip -qqo almena_principal_curvature.zip\nrm albertville_rgb.zip albertville_normal.zip albertville_principal_curvature.zip almena_rgb.zip almena_normal.zip almena_principal_curvature.zip\ncd -\n"
  },
  {
    "path": "tools/download_energy_graph_edges.sh",
    "content": "##!/usr/bin/env bash\n\nSCRIPT_DIR=$( dirname \"$0\" )\n\nFILE=./models/rgb2normal_consistency.pth\nif [ -f \"$FILE\" ]; then\n    echo \"Found consistency network $FILE: skipping download of these networks\"\nelse\n    echo \"Downloading consistency networks...\"\n   sh $SCRIPT_DIR/download_models.sh\nfi\n\nFILE=./models/normal2curvature.pth\nif [ -f \"$FILE\" ]; then\n    echo \"Found perceptual network $FILE: skipping download of these networks\"\nelse\n   echo \"Downloading perceptual networks...\"\n   sh $SCRIPT_DIR/download_percep_models.sh\nfi\n\nFILE=./models/rgb2principal_curvature.pth\nif [ -f \"$FILE\" ]; then\n    echo \"Found energy-graph specific network $FILE: skipping download of these networks\"\nelse\n    echo \"RGB2X energy networks...\"\n    # Get energy-graph-specific links\n    wget https://drive.switch.ch/index.php/s/aZDOEBixS4W7mBL/download\n    unzip download\n    rm download\n    mv energy_graph_edges/* models/\n    rmdir energy_graph_edges\nfi\n\n"
  },
  {
    "path": "tools/download_models.sh",
    "content": "#!/usr/bin/env bash\n\nwget https://drive.switch.ch/index.php/s/QPvImzbbdjBKI5P/download\nunzip download\nrm download\n"
  },
  {
    "path": "tools/download_percep_models.sh",
    "content": "#!/usr/bin/env bash\n\nwget https://drive.switch.ch/index.php/s/aXu4EFaznqtNzsE/download\nunzip download\nrm download\nmv percep_models/* models/\nrmdir percep_models"
  },
  {
    "path": "train.py",
    "content": "'''\n  Name: train.py\n  Desc: Executes training of a network with the consistency framework.\n\n    Here are some options that may be specified for any model. If they have a\n    default value, it is given at the end of the description in parens.\n\n        Data pipeline:\n            Data locations:\n                'train_buildings': A list of the folders containing the training data. This\n                    is defined in configs/split.txt.\n                'val_buildings': As above, but for validation data.\n                'data_dirs': The folder that all the data is stored in. This may just be\n                    something like '/', and then all filenames in 'train_filenames' will\n                    give paths relative to 'dataset_dir'. For example, if 'dataset_dir'='/',\n                    then train_filenames might have entries like 'path/to/data/img_01.png'.\n                    This is defined in utils.py.\n\n        Logging:\n            'results_dir': An absolute path to where checkpoints are saved. This is\n                defined in utils.py.\n\n        Training:\n            'batch_size': The size of each batch. (64)\n            'num_epochs': The maximum number of epochs to train for. (800)\n            'energy_config': {multiperceptual_targettask} The paths taken to compute the losses.\n            'k': Number of perceptual losses chosen.\n            'data_aug': {True, False} If data augmentation should be used during training.\n                See TrainTaskDataset class in datasets.py for the types of data augmentation\n                used. (False)\n\n        Optimization:\n            'initial_learning_rate': The initial learning rate to use for the model. 
(3e-5)\n\n\n  Usage:\n    python -m train multiperceptual_depth --batch-size 32 --k 8 --max-epochs 100\n'''\n\nimport torch\nimport torch.nn as nn\n\nfrom utils import *\nfrom energy import get_energy_loss\nfrom graph import TaskGraph\nfrom logger import Logger, VisdomLogger\nfrom datasets import load_train_val, load_test, load_ood\nfrom task_configs import tasks, RealityTask\nfrom transfers import functional_transfers\n\nfrom fire import Fire\n\n#import pdb\n\ndef main(\n    loss_config=\"multiperceptual\", mode=\"winrate\", visualize=False,\n    fast=False, batch_size=None,\n    subset_size=None, max_epochs=800, dataaug=False, **kwargs,\n):\n\n\n    # CONFIG\n    batch_size = batch_size or (4 if fast else 64)\n    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)\n\n    # DATA LOADING\n    train_dataset, val_dataset, train_step, val_step = load_train_val(\n        energy_loss.get_tasks(\"train\"),\n        batch_size=batch_size, fast=fast,\n        subset_size=subset_size,\n        dataaug=dataaug,\n    )\n\n    if fast:\n        train_dataset = val_dataset\n        train_step, val_step = 2,2\n\n    train = RealityTask(\"train\", train_dataset, batch_size=batch_size, shuffle=True)\n    val = RealityTask(\"val\", val_dataset, batch_size=batch_size, shuffle=True)\n\n    if fast:\n        train_dataset = val_dataset\n        train_step, val_step = 2,2\n        realities = [train, val]\n    else:\n        test_set = load_test(energy_loss.get_tasks(\"test\"), buildings=['almena', 'albertville'])\n        test = RealityTask.from_static(\"test\", test_set, energy_loss.get_tasks(\"test\"))\n        realities = [train, val, test]\n        # If you wanted to just do some qualitative predictions on inputs w/o labels, you could do:\n        # ood_set = load_ood(energy_loss.get_tasks(\"ood\"))\n        # ood = RealityTask.from_static(\"ood\", ood_set, [tasks.rgb,])\n        # realities.append(ood)\n\n    # GRAPH\n    graph = 
TaskGraph(tasks=energy_loss.tasks + realities, pretrained=True, finetuned=False,\n        freeze_list=energy_loss.freeze_list,\n        initialize_from_transfer=False,\n    )\n    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)\n\n    # LOGGING\n    os.makedirs(RESULTS_DIR, exist_ok=True)\n    logger = VisdomLogger(\"train\", env=JOB)\n    logger.add_hook(lambda logger, data: logger.step(), feature=\"loss\", freq=20)\n    logger.add_hook(lambda _, __: graph.save(f\"{RESULTS_DIR}/graph.pth\"), feature=\"epoch\", freq=1)\n    energy_loss.logger_hooks(logger)\n    energy_loss.plot_paths(graph, logger, realities, prefix=\"start\")\n\n    # BASELINE\n    graph.eval()\n    with torch.no_grad():\n        for _ in range(0, val_step*4):\n            val_loss, _ = energy_loss(graph, realities=[val])\n            val_loss = sum([val_loss[loss_name] for loss_name in val_loss])\n            val.step()\n            logger.update(\"loss\", val_loss)\n\n        for _ in range(0, train_step*4):\n            train_loss, _ = energy_loss(graph, realities=[train])\n            train_loss = sum([train_loss[loss_name] for loss_name in train_loss])\n            train.step()\n            logger.update(\"loss\", train_loss)\n    energy_loss.logger_update(logger)\n\n    # TRAINING\n    for epochs in range(0, max_epochs):\n\n        logger.update(\"epoch\", epochs)\n        energy_loss.plot_paths(graph, logger, realities, prefix=\"\")\n        if visualize: return\n\n        graph.train()\n        for _ in range(0, train_step):\n            train_loss, mse_coeff = energy_loss(graph, realities=[train], compute_grad_ratio=True)\n            train_loss = sum([train_loss[loss_name] for loss_name in train_loss])\n            graph.step(train_loss)\n            train.step()\n            logger.update(\"loss\", train_loss)\n\n        graph.eval()\n        for _ in range(0, val_step):\n            with torch.no_grad():\n                val_loss, _ = energy_loss(graph, 
realities=[val])\n                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])\n            val.step()\n            logger.update(\"loss\", val_loss)\n\n        energy_loss.logger_update(logger)\n\n        logger.step()\n\nif __name__ == \"__main__\":\n    Fire(main)\n"
  },
  {
    "path": "transfers.py",
    "content": "\nimport os, sys, math, random, itertools, functools\nfrom collections import namedtuple\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom torch.utils.checkpoint import checkpoint as util_checkpoint\nfrom torchvision import models\n\nfrom utils import *\nfrom models import TrainableModel, DataParallelModel\nfrom task_configs import get_task, task_map, get_model, Task, RealityTask\n\nfrom modules.percep_nets import DenseNet, Dense1by1Net, DenseKernelsNet, DeepNet, BaseNet, WideNet, PyramidNet\nfrom modules.depth_nets import UNetDepth\nfrom modules.unet import UNet, UNetOld, UNetOld2, UNetReshade\nfrom modules.resnet import ResNetClass\n\nfrom fire import Fire\nimport IPython\n\n\npretrained_transfers = {\n\n    ('normal', 'principal_curvature'):\n        (lambda: Dense1by1Net(), f\"{MODELS_DIR}/normal2curvature.pth\"),\n    ('normal', 'depth_zbuffer'):\n        (lambda: UNetDepth(), f\"{MODELS_DIR}/normal2zdepth_zbuffer.pth\"),\n    ('normal', 'sobel_edges'):\n        (lambda: UNet(out_channels=1, downsample=4).cuda(), f\"{MODELS_DIR}/normal2edges2d.pth\"),\n    ('normal', 'reshading'):\n        (lambda: UNetReshade(downsample=5), f\"{MODELS_DIR}/normal2reshade.pth\"),\n    ('normal', 'keypoints3d'):\n        (lambda: UNet(downsample=5, out_channels=1), f\"{MODELS_DIR}/normal2keypoints3d.pth\"),\n    ('normal', 'keypoints2d'):\n        (lambda: UNet(downsample=5, out_channels=1), f\"{MODELS_DIR}/normal2keypoints2d_new.pth\"),\n    ('normal', 'edge_occlusion'):\n        (lambda: UNet(downsample=5, out_channels=1), f\"{MODELS_DIR}/normal2edge_occlusion.pth\"),\n\n    ('depth_zbuffer', 'normal'):\n        (lambda: UNet(in_channels=1, downsample=6), f\"{MODELS_DIR}/depth2normal.pth\"),\n    ('depth_zbuffer', 'sobel_edges'):\n        (lambda: UNet(downsample=4, in_channels=1, out_channels=1).cuda(), f\"{MODELS_DIR}/depth_zbuffer2sobel_edges.pth\"),\n    ('depth_zbuffer', 'principal_curvature'):\n        
(lambda: UNet(downsample=4, in_channels=1), f\"{MODELS_DIR}/depth_zbuffer2principal_curvature.pth\"),\n    ('depth_zbuffer', 'reshading'):\n        (lambda: UNetReshade(downsample=5, in_channels=1), f\"{MODELS_DIR}/depth_zbuffer2reshading.pth\"),\n    ('depth_zbuffer', 'keypoints3d'):\n        (lambda: UNet(downsample=5, in_channels=1, out_channels=1), f\"{MODELS_DIR}/depth_zbuffer2keypoints3d.pth\"),\n    ('depth_zbuffer', 'keypoints2d'):\n        (lambda: UNet(downsample=5, in_channels=1, out_channels=1), f\"{MODELS_DIR}/depth_zbuffer2keypoints2d.pth\"),\n    ('depth_zbuffer', 'edge_occlusion'):\n        (lambda: UNet(downsample=5, in_channels=1, out_channels=1), f\"{MODELS_DIR}/depth_zbuffer2edge_occlusion.pth\"),\n\n    ('reshading', 'depth_zbuffer'):\n        (lambda: UNetReshade(downsample=5, out_channels=1), f\"{MODELS_DIR}/reshading2depth_zbuffer.pth\"),\n    ('reshading', 'keypoints2d'):\n        (lambda: UNet(downsample=5, out_channels=1), f\"{MODELS_DIR}/reshading2keypoints2d_new.pth\"),\n    ('reshading', 'edge_occlusion'):\n        (lambda: UNet(downsample=5, out_channels=1), f\"{MODELS_DIR}/reshading2edge_occlusion.pth\"),\n    ('reshading', 'normal'):\n        (lambda: UNet(downsample=4), f\"{MODELS_DIR}/reshading2normal.pth\"),\n    ('reshading', 'keypoints3d'):\n        (lambda: UNet(downsample=5, out_channels=1), f\"{MODELS_DIR}/reshading2keypoints3d.pth\"),\n    ('reshading', 'sobel_edges'):\n        (lambda: UNet(downsample=5, out_channels=1), f\"{MODELS_DIR}/reshading2sobel_edges.pth\"),\n    ('reshading', 'principal_curvature'):\n        (lambda: UNet(downsample=5), f\"{MODELS_DIR}/reshading2principal_curvature.pth\"),\n\n    ('rgb', 'sobel_edges'):\n        (lambda: SobelKernel(), None),\n    ('rgb', 'principal_curvature'):\n        (lambda: UNet(downsample=5), f\"{MODELS_DIR}/rgb2principal_curvature.pth\"),\n    ('rgb', 'keypoints2d'):\n        (lambda: UNet(downsample=3, out_channels=1), f\"{MODELS_DIR}/rgb2keypoints2d_new.pth\"),\n    
('rgb', 'keypoints3d'):\n        (lambda: UNet(downsample=5, out_channels=1), f\"{MODELS_DIR}/rgb2keypoints3d.pth\"),\n    ('rgb', 'edge_occlusion'):\n        (lambda: UNet(downsample=5, out_channels=1), f\"{MODELS_DIR}/rgb2edge_occlusion.pth\"),\n    ('rgb', 'normal'):\n        (lambda: UNet(), f\"{MODELS_DIR}/rgb2normal_baseline.pth\"),\n    ('rgb', 'reshading'):\n        (lambda: UNetReshade(downsample=5), f\"{MODELS_DIR}/rgb2reshading_baseline.pth\"),\n    ('rgb', 'depth_zbuffer'):\n        (lambda: UNet(downsample=6, out_channels=1), f\"{MODELS_DIR}/rgb2zdepth_baseline.pth\"),\n\n    ('normal', 'imagenet'):\n        (lambda: ResNetClass().cuda(), None),\n    ('depth_zbuffer', 'imagenet'):\n        (lambda: ResNetClass().cuda(), None),\n    ('reshading', 'imagenet'):\n        (lambda: ResNetClass().cuda(), None),\n\n    ('principal_curvature', 'sobel_edges'): \n        (lambda: UNet(downsample=4, out_channels=1), f\"{MODELS_DIR}/principal_curvature2sobel_edges.pth\"),\n    ('sobel_edges', 'depth_zbuffer'):\n        (lambda: UNet(downsample=6, in_channels=1, out_channels=1), f\"{MODELS_DIR}/sobel_edges2depth_zbuffer.pth\"),\n\n    ('depth_zbuffer', 'normal'): \n        (lambda: UNet(in_channels=1, downsample=6), f\"{MODELS_DIR}/depth2normal.pth\"),\n    ('keypoints2d', 'normal'):\n        (lambda: UNet(downsample=5, in_channels=1), f\"{MODELS_DIR}/keypoints2d2normal_new.pth\"),\n    ('keypoints3d', 'normal'):\n        (lambda: UNet(downsample=5, in_channels=1), f\"{MODELS_DIR}/keypoints3d2normal.pth\"),\n    ('principal_curvature', 'normal'): \n        (lambda: UNetOld2(), f\"{MODELS_DIR}/principal_curvature2normal.pth\"),\n    ('sobel_edges', 'normal'): \n        (lambda: UNet(in_channels=1, downsample=5).cuda(), f\"{MODELS_DIR}/sobel_edges2normal.pth\"),\n    ('edge_occlusion', 'normal'):\n        (lambda: UNet(in_channels=1, downsample=5), f\"{MODELS_DIR}/edge_occlusion2normal.pth\"),\n\n}\n\nclass Transfer(nn.Module):\n\n    def __init__(self, src_task, 
dest_task,\n        checkpoint=True, name=None, model_type=None, path=None,\n        pretrained=True, finetuned=False\n    ):\n        super().__init__()\n        if isinstance(src_task, str) and isinstance(dest_task, str):\n            src_task, dest_task = get_task(src_task), get_task(dest_task)\n\n        self.src_task, self.dest_task, self.checkpoint = src_task, dest_task, checkpoint\n        self.name = name or f\"{src_task.name}2{dest_task.name}\"\n        saved_type, saved_path = None, None\n        if model_type is None and path is None:\n            saved_type, saved_path = pretrained_transfers.get((src_task.name, dest_task.name), (None, None))\n\n        self.model_type, self.path = model_type or saved_type, path or saved_path\n        self.model = None\n\n        if finetuned:\n            path = f\"{MODELS_DIR}/ft_perceptual/{src_task.name}2{dest_task.name}.pth\"\n            if os.path.exists(path):\n                self.model_type, self.path = saved_type or (lambda: get_model(src_task, dest_task)), path\n                print (\"Using finetuned: \", path)\n                return\n\n        if self.model_type is None:\n\n            if src_task.kind == dest_task.kind and src_task.resize != dest_task.resize:\n\n                class Module(TrainableModel):\n\n                    def __init__(self):\n                        super().__init__()\n\n                    def forward(self, x):\n                        return resize(x, val=dest_task.resize)\n\n                self.model_type = lambda: Module()\n                self.path = None\n\n            path = f\"{MODELS_DIR}/{src_task.name}2{dest_task.name}.pth\"\n            if src_task.name == \"keypoints2d\" or dest_task.name == \"keypoints2d\":\n                path = f\"{MODELS_DIR}/{src_task.name}2{dest_task.name}_new.pth\"\n            if os.path.exists(path):\n                self.model_type, self.path = lambda: get_model(src_task, dest_task), path\n\n        if not pretrained:\n            print 
(\"Not using pretrained [heavily discouraged]\")\n            self.path = None\n\n    def load_model(self):\n        if self.model is None:\n            if self.path is not None:\n                self.model = DataParallelModel.load(self.model_type().to(DEVICE), self.path)\n                # if optimizer:\n                #     self.model.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)\n            else:\n                self.model = self.model_type().to(DEVICE)\n                if isinstance(self.model, nn.Module):\n                    self.model = DataParallelModel(self.model)\n        return self.model\n\n    def __call__(self, x):\n        self.load_model()\n        preds = util_checkpoint(self.model, x) if self.checkpoint else self.model(x)\n        preds.task = self.dest_task\n        return preds\n\n    def __repr__(self):\n        return self.name or str(self.src_task) + \" -> \" + str(self.dest_task)\n\n\nclass RealityTransfer(Transfer):\n\n    def __init__(self, src_task, dest_task):\n        super().__init__(src_task, dest_task, model_type=lambda: None)\n\n    def load_model(self, optimizer=True):\n        pass\n\n    def __call__(self, x):\n        assert (isinstance(self.src_task, RealityTask))\n        return self.src_task.task_data[self.dest_task].to(DEVICE)\n\n\nclass FineTunedTransfer(Transfer):\n\n    def __init__(self, transfer):\n        super().__init__(transfer.src_task, transfer.dest_task, checkpoint=transfer.checkpoint, name=transfer.name)\n        self.cached_models = {}\n\n    def load_model(self, parents=[]):\n\n        model_path = get_finetuned_model_path(parents + [self])\n\n        if model_path not in self.cached_models:\n            if not os.path.exists(model_path):\n                print(f\"{model_path} not found, loading pretrained\")\n                self.cached_models[model_path] = super().load_model()\n            else:\n                print(f\"{model_path} found, loading finetuned\")\n                
self.cached_models[model_path] = DataParallelModel.load(self.model_type().cuda(), model_path)\n                print(f\"\")\n        self.model = self.cached_models[model_path]\n        return self.model\n\n    def __call__(self, x):\n\n        if not hasattr(x, \"parents\"):\n            x.parents = []\n\n        self.load_model(parents=x.parents)\n        preds = util_checkpoint(self.model, x) if self.checkpoint else self.model(x)\n        preds.parents = x.parents + ([self])\n        return preds\n\n\n\nfunctional_transfers = (\n    Transfer('normal', 'principal_curvature', name='f'),\n    Transfer('principal_curvature', 'normal', name='F'),\n\n    Transfer('normal', 'depth_zbuffer', name='g'),\n    Transfer('depth_zbuffer', 'normal', name='G'),\n\n    Transfer('normal', 'sobel_edges', name='s'),\n    Transfer('sobel_edges', 'normal', name='S'),\n\n    Transfer('principal_curvature', 'sobel_edges', name='CE'),\n    Transfer('sobel_edges', 'principal_curvature', name='EC'),\n\n    Transfer('depth_zbuffer', 'sobel_edges', name='DE'),\n    Transfer('sobel_edges', 'depth_zbuffer', name='ED'),\n\n    Transfer('principal_curvature', 'depth_zbuffer', name='h'),\n    Transfer('depth_zbuffer', 'principal_curvature', name='H'),\n\n    Transfer('rgb', 'normal', name='n'),\n    Transfer('rgb', 'normal', name='npstep',\n        model_type=lambda: UNetOld(),\n        path=f\"{MODELS_DIR}/unet_percepstep_0.1.pth\",\n    ),\n    Transfer('rgb', 'principal_curvature', name='RC'),\n    Transfer('rgb', 'keypoints2d', name='k'),\n    Transfer('rgb', 'sobel_edges', name='a'),\n    Transfer('rgb', 'reshading', name='r'),\n    Transfer('rgb', 'depth_zbuffer', name='d'),\n\n    Transfer('keypoints2d', 'principal_curvature', name='KC'),\n\n    Transfer('keypoints3d', 'principal_curvature', name='k3C'),\n    Transfer('principal_curvature', 'keypoints3d', name='Ck3'),\n\n    Transfer('normal', 'reshading', name='nr'),\n    Transfer('reshading', 'normal', name='rn'),\n\n    
Transfer('keypoints3d', 'normal', name='k3N'),\n    Transfer('normal', 'keypoints3d', name='Nk3'),\n\n    Transfer('keypoints2d', 'normal', name='k2N'),\n    Transfer('normal', 'keypoints2d', name='Nk2'),\n\n    Transfer('sobel_edges', 'reshading', name='Er'),\n)\n\nfinetuned_transfers = [FineTunedTransfer(transfer) for transfer in functional_transfers]\nTRANSFER_MAP = {t.name:t for t in functional_transfers}\nfunctional_transfers = namedtuple('functional_transfers', TRANSFER_MAP.keys())(**TRANSFER_MAP)\n\ndef get_transfer_name(transfer):\n    for t in functional_transfers:\n        if transfer.src_task == t.src_task and transfer.dest_task == t.dest_task:\n            return t.name\n    return transfer.name\n\n(f, F, g, G, s, S, CE, EC, DE, ED, h, H, n, npstep, RC, k, a, r, d, KC, k3C, Ck3, nr, rn, k3N, Nk3, Er, k2N, N2k) = functional_transfers\n\nif __name__ == \"__main__\":\n    y = g(F(f(x)))\n    print (y.shape)\n\n\n\n\n\n\n"
  },
  {
    "path": "utils.py",
    "content": "\nimport numpy as np\nimport random, sys, os, time, glob, math, itertools, pickle\nimport parse\nfrom collections import defaultdict\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nfrom torch.autograd import Variable\n\nfrom functools import partial\nfrom scipy import ndimage\n\nimport IPython\n\nDEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\nUSE_CUDA = torch.cuda.is_available()\ndtype = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n\nEXPERIMENT, BASE_DIR = open(\"config/jobinfo.txt\").read().strip().split(', ')\nJOB = \"_\".join(EXPERIMENT.split(\"_\")[0:-1])\n\nMODELS_DIR = f\"{BASE_DIR}/models\"\nDATA_DIRS = [f\"/taskonomy-data/taskonomydata\", 'data']\nRESULTS_DIR = f\"{BASE_DIR}/results/results_{EXPERIMENT}\"\nSHARED_DIR = f\"{BASE_DIR}/shared\"\nOOD_DIR = f\"{SHARED_DIR}/ood_standard_set\"\nUSE_RAID = False\n\n# os.system(f\"mkdir -p {RESULTS_DIR}\")\n\ndef both(x, y):\n    x = dict(x.items())\n    x.update(y)\n    return x\n\ndef elapsed(last_time=[time.time()]):\n    \"\"\" Returns the time passed since elapsed() was last called. \"\"\"\n    current_time = time.time()\n    diff = current_time - last_time[0]\n    last_time[0] = current_time\n    return diff\n\ndef cycle(iterable):\n    \"\"\" Cycles through iterable without making extra copies. \"\"\"\n    while True:\n        for i in iterable:\n            yield i\n\ndef average(arr):\n    return sum(arr) / len(arr)\n\n# def random_resize(iterable, vals=[128, 192, 256, 320]):\n#    \"\"\" Cycles through iterable while randomly resizing batch values. 
\"\"\"\n#     from transforms import resize\n#     while True:\n#         for X, Y in iterable:\n#             val = random.choice(vals)\n#             yield resize(X.to(DEVICE), val=val).detach(), resize(Y.to(DEVICE), val=val).detach()\n\n\ndef get_files(exp, data_dirs=DATA_DIRS, recursive=False):\n    \"\"\" Gets data files across mounted directories matching glob expression pattern. \"\"\"\n    # cache = SHARED_DIR + \"/filecache_\" + \"_\".join(exp.split()).replace(\".\", \"_\").replace(\"/\", \"_\").replace(\"*\", \"_\") + (\"r\" if recursive else \"f\") + \".pkl\"\n    # print (\"Cache file: \", cache)\n    # if os.path.exists(cache):\n    #     return pickle.load(open(cache, 'rb'))\n\n    files, seen = [], set()\n    for data_dir in data_dirs:\n        for file in glob.glob(f'{data_dir}/{exp}', recursive=recursive):\n            if file[len(data_dir):] not in seen:\n                files.append(file)\n                seen.add(file[len(data_dir):])\n\n    # pickle.dump(files, open(cache, 'wb'))\n    return files\n\n\ndef get_finetuned_model_path(parents):\n    if BASE_DIR == \"/\":\n        return f\"{RESULTS_DIR}/\" + \"_\".join([parent.name for parent in parents[::-1]]) + \".pth\"\n    else:\n        return f\"{MODELS_DIR}/finetuned/\" + \"_\".join([parent.name for parent in parents[::-1]]) + \".pth\"\n\n\ndef plot_images(model, logger, test_set, dest_task=\"normal\",\n        ood_images=None, show_masks=False, loss_models={},\n        preds_name=None, target_name=None, ood_name=None,\n    ):\n\n    from task_configs import get_task, ImageTask\n\n    test_images, preds, targets, losses, _ = model.predict_with_data(test_set)\n\n    if isinstance(dest_task, str):\n        dest_task = get_task(dest_task)\n\n    if show_masks and isinstance(dest_task, ImageTask):\n        test_masks = ImageTask.build_mask(targets, dest_task.mask_val, tol=1e-3)\n        logger.images(test_masks.float(), f\"{dest_task}_masks\", resize=64)\n\n    dest_task.plot_func(preds, 
preds_name or f\"{dest_task.name}_preds\", logger)\n    dest_task.plot_func(targets, target_name or f\"{dest_task.name}_target\", logger)\n\n    if ood_images is not None:\n        ood_preds = model.predict(ood_images)\n        dest_task.plot_func(ood_preds, ood_name or f\"{dest_task.name}_ood_preds\", logger)\n\n    for name, loss_model in loss_models.items():\n        with torch.no_grad():\n            output = loss_model(preds, targets, test_images)\n            if hasattr(output, \"task\"):\n                output.task.plot_func(output, name, logger, resize=128)\n            else:\n                logger.images(output.clamp(min=0, max=1), name, resize=128)\n\n\ndef gaussian_filter(channels=3, kernel_size=5, sigma=1.0, device=0):\n\n    x_cord = torch.arange(kernel_size).float()\n    x_grid = x_cord.repeat(kernel_size).view(kernel_size, kernel_size)\n    y_grid = x_grid.t()\n    xy_grid = torch.stack([x_grid, y_grid], dim=-1)\n\n    mean = (kernel_size - 1) / 2.\n    variance = sigma ** 2.\n    gaussian_kernel = (1. / (2. 
* math.pi * variance)) * torch.exp(\n        -torch.sum((xy_grid - mean) ** 2., dim=-1) / (2 * variance)\n    )\n    gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)\n    gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)\n    gaussian_kernel = gaussian_kernel.repeat(channels, 1, 1, 1)\n\n    return gaussian_kernel\n\n\ndef motion_blur_filter(kernel_size=15):\n    channels = 3\n    kernel_motion_blur = torch.zeros((kernel_size, kernel_size))\n    kernel_motion_blur[int((kernel_size - 1) / 2), :] = torch.ones(kernel_size)\n    kernel_motion_blur = kernel_motion_blur / kernel_size\n    kernel_motion_blur = kernel_motion_blur.view(1, 1, kernel_size, kernel_size)\n    kernel_motion_blur = kernel_motion_blur.repeat(channels, 1, 1, 1)\n    return kernel_motion_blur\n\n\ndef sobel_kernel(x):\n    def sobel_transform(x):\n        image = x.data.cpu().numpy().mean(axis=0)\n        blur = ndimage.filters.gaussian_filter(image, sigma=2, )\n        sx = ndimage.sobel(blur, axis=0, mode='constant')\n        sy = ndimage.sobel(blur, axis=1, mode='constant')\n        sob = np.hypot(sx, sy)\n        edge = torch.FloatTensor(sob).unsqueeze(0)\n        return edge\n\n    x = torch.stack([sobel_transform(y) for y in x], dim=0)\n    return x.to(DEVICE).requires_grad_()\n\n\nclass SobelKernel(nn.Module):\n    def __init__(self):\n        super().__init__()\n    \n    def forward(self, x):\n        return sobel_kernel(x)\n    \ndef set_seed(seed):\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed) # cpu  vars\n    torch.cuda.manual_seed_all(seed) # gpu vars\n"
  }
]