Full Code of EPFL-VILAB/XTConsistency for AI

master 6d051610358d cached
32 files
208.2 KB
55.5k tokens
322 symbols
1 requests
Download .txt
Showing preview only (218K chars total). Download the full file or copy to clipboard to get everything.
Repository: EPFL-VILAB/XTConsistency
Branch: master
Commit: 6d051610358d
Files: 32
Total size: 208.2 KB

Directory structure:
gitextract_hs8roq0r/

├── .gitignore
├── Dockerfile
├── README.md
├── config/
│   ├── jobinfo.txt
│   ├── split.txt
│   ├── split_fullplus.txt
│   └── split_medium.txt
├── datasets.py
├── demo.py
├── energy.py
├── graph.py
├── hooks/
│   └── build
├── logger.py
├── models.py
├── modules/
│   ├── __init__.py
│   ├── depth_nets.py
│   ├── percep_nets.py
│   ├── resnet.py
│   ├── unet.py
│   └── unet_mirrored.py
├── plotting.py
├── requirements.txt
├── scripts/
│   ├── energy_calc.py
│   └── jobinfo.txt
├── task_configs.py
├── tools/
│   ├── download_data.sh
│   ├── download_energy_graph_edges.sh
│   ├── download_models.sh
│   └── download_percep_models.sh
├── train.py
├── transfers.py
└── utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
.vscode
raw
__pycache__/*
sftp-config*.json
models
processed
*.pth
*.tar
output
*.gz
.DS_Store/*
result
checkpoints
*.pyc
data/results
.DS_Store
local
command.txt
*.ipynb_checkpoints


================================================
FILE: Dockerfile
================================================
FROM nvidia/cuda:10.1-base-ubuntu16.04
LABEL version="1.0"
LABEL description="Build using the command \
  'docker build -t epflvil/xtconsistency:latest .'"

ARG DEFAULT_GIT_BRANCH=master
ARG DEFAULT_GIT_REPO=git@github.com:EPFL-VIL/XTConsistency.git
ARG GITHUB_DEPLOY_KEY_PATH=docker_key
ARG GITHUB_DEPLOY_KEY
ARG GITHUB_DEPLOY_KEY_PUBLIC

RUN apt-get update && apt-get install -y \
    curl \
    wget \
    ca-certificates \
    sudo \
    git \
    unzip \
    bzip2 \
    libx11-6 \
    nano \
    screen \
    gcc \
    python3-dev \
 && rm -rf /var/lib/apt/lists/*

RUN mkdir /root/.ssh
RUN echo "DEPLOY" "${GITHUB_DEPLOY_KEY}"
RUN echo "DEPLOY" "${GITHUB_DEPLOY_KEY_PUBLIC}"
RUN echo "${GITHUB_DEPLOY_KEY}" > /root/.ssh/id_rsa
RUN echo "${GITHUB_DEPLOY_KEY_PUBLIC}" > /root/.ssh/id_rsa.pub
RUN chmod 600 /root/.ssh/id_rsa
RUN cat /root/.ssh/id_rsa*
RUN eval $(ssh-agent) && \
    ssh-add /root/.ssh/id_rsa && \
    ssh-keyscan -H github.com >> /etc/ssh/ssh_known_hosts
RUN git clone --single-branch --branch "${DEFAULT_GIT_BRANCH}" "${DEFAULT_GIT_REPO}" /app

#############################
# Pull code
#############################
# RUN mkdir /app
WORKDIR /app

RUN cd /app && git config core.filemode false
RUN chmod -R 777 /app


#############################
# Create non-root user
#############################
# Create a non-root user and switch to it
RUN adduser --disabled-password --gecos '' --shell /bin/bash user \
 && chown -R user:user /app
RUN echo "user ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/90-user
USER user

# All users can use /home/user as their home directory
ENV HOME=/home/user
RUN chmod 777 /home/user


#############################
# Create conda environment
#############################
# Install Miniconda
RUN curl -Lso ~/miniconda.sh https://repo.continuum.io/miniconda/Miniconda3-4.5.11-Linux-x86_64.sh \
 && chmod +x ~/miniconda.sh \
 && ~/miniconda.sh -b -p ~/miniconda \
 && rm ~/miniconda.sh
ENV PATH=/home/user/miniconda/bin:$PATH
ENV CONDA_AUTO_UPDATE_CONDA=false

# Create a Python 3.6 environment
RUN /home/user/miniconda/bin/conda create -y --name py36 python=3.6.9 \
 && /home/user/miniconda/bin/conda clean -ya
ENV CONDA_DEFAULT_ENV=py36
ENV CONDA_PREFIX=/home/user/miniconda/envs/$CONDA_DEFAULT_ENV
ENV PATH=$CONDA_PREFIX/bin:$PATH
RUN /home/user/miniconda/bin/conda install conda-build=3.18.9=py36_3 \
 && /home/user/miniconda/bin/conda clean -ya


#############################
# Python packages
#############################
RUN conda install -y -c pytorch \
    cudatoolkit=10.1 \
    "pytorch=1.4.0" \
    "torchvision=0.5.0" \
  && conda clean -ya
RUN conda install -y \
  ipython==6.5.0 \
  matplotlib==3.0.3 \
  plac==0.9.6 \
  py==1.6.0 \
  scipy==1.3.1 \
  tqdm==4.36.1 \
  pathlib==1.0.1 \
  seaborn==0.10.0 \
  scikit-learn==0.22.1 \
  scikit-image==0.16.2 \
 && conda clean -ya
RUN conda install -c conda-forge jupyterlab && conda clean -ya
RUN pip install runstats==1.8.0 \
  fire==0.2.1 \
  visdom==0.1.8.9 \
  parse==1.12.1

  
###############################################
# Default command and environment variables
###############################################
RUN sudo touch /root/.bashrc && sudo chmod 770 /root/.bashrc
RUN echo export PATH="\$PATH:"$PATH >> /tmp/.bashrc
RUN sudo su -c 'cat /tmp/.bashrc >> /root/.bashrc' && rm /tmp/.bashrc

# Set the default command to bash
CMD ["bash"]


================================================
FILE: README.md
================================================
# Robust Learning Through Cross-Task Consistency <br> 

[![](./assets/intro.jpg)](https://consistency.epfl.ch)

<table>
      <tr><td><em>Above: A comparison of the results from consistency-based learning and learning each task individually. The yellow markers highlight the improvement in fine grained details.</em></td></tr>
</table>

<br>
This repository contains tools for training and evaluating models using consistency:

- [Pretrained models](#pretrained-models)
- [Demo code](#quickstart-run-demo-locally) and an **[online live demo](https://consistency.epfl.ch/demo/)**
- [_Uncertainty energy_ estimation code](#Energy-computation)
- [Training scripts](#training)
- [Docker and installation instructions](#installation)

for the following paper:
<!-- <br><a href=https://consistency.epfl.ch>Robust Learing Through Cross-Task Consistency</a> (CVPR 2020, Oral).<br> -->
<!-- Amir Zamir, Alexander Sax, Teresa Yeo, Oğuzhan Kar, Nikhil Cheerla, Rohan Suri, Zhangjie Cao, Jitendra Malik, Leonidas Guibas  -->

<div style="text-align:center">
<h4><a href=https://consistency.epfl.ch>Robust Learing Through Cross-Task Consistency</a> (CVPR 2020, Best Paper Award Nomination, Oral)</h4>
</div>
<br>

[![Cross-Task Consistency Results](./assets/vid_thumbnail_600_gif2.gif)](https://youtu.be/6CSmmrBNX9M "Click to watch results of applying the networks trained with cross-task consisntency frame-by-frame on sample YouTube videos.")


For further details, a [live demo](https://consistency.epfl.ch/demo/), [video visualizations](https://consistency.epfl.ch/visuals/), and an [overview talk](https://consistency.epfl.ch/#paper), refer to our [project website](https://consistency.epfl.ch/).

#### PROJECT WEBSITE:
<div style="text-align:center">

| [LIVE DEMO](https://consistency.epfl.ch/demo/) | [VIDEO VISUALIZATION](https://consistency.epfl.ch/visuals/) 
|:----:|:----:|
| Upload your own images and see the results of different consistency-based models vs. various baselines.<br><br>[<img src=./assets/screenshot-demo.png width="400">](https://consistency.epfl.ch/demo/) | Visualize models with and without consistency, evaluated on a (non-cherry picked) YouTube video.<br><br><br>[<img src=./assets/output_video.gif width="400">](https://consistency.epfl.ch/visuals/) |

</div>

---

Table of Contents
=================

   * [Introduction](#introduction)
   * [Installation](#installation)
   * [Quickstart (demo code)](#quickstart-run-demo-locally)
   * [Energy computation](#energy-computation)
   * [Download all pretrained models](#pretrained-models)
   * [Train a consistency model](#training)
     * [Instructions for training](#steps)
     * [To train on other configurations](#to-train-on-other-target-domains)
   * [Citing](#citation)

<br>

## Introduction 

Visual perception entails solving a wide set of tasks (e.g. object detection, depth estimation, etc). The predictions made for each task out of a particular observation are not independent, and therefore, are expected to be **consistent**.

**What is consistency?** Suppose an object detector detects a ball in a particular region of an image, while a depth estimator returns a flat surface for the same region. This presents an issue -- at least one of them has to be wrong, because they are _inconsistent_.

**Why is it important?** 
1. Desired learning tasks are usually predictions of different aspects of a single underlying reality (the scene that underlies an image). Inconsistency among predictions implies contradiction. 
2. Consistency constraints are informative and can be used to better fit the data or lower the sample complexity. They may also reduce the tendency of neural networks to learn "surface statistics" (superficial cues) by enforcing constraints rooted in different physical or geometric rules. This is empirically supported by the improved generalization of models when trained with consistency constraints.

**How do we enforce it?** The underlying concept is that of path independence in a network of tasks. Given an endpoint `Y2`, the path from 
`X->Y1->Y2` should give the same results as `X->Y2`. This can be generalized to a larger system, with paths of arbitrary lengths. In this case, the nodes of the graph are our prediction domains (eg. depth, normal) and the edges are neural networks mapping these domains.

This repository includes [training](#training) code for enforcing cross task consistency, [demo](#run-demo-script) code for visualizing the results of a consistency trained model on a given image and [links](#pretrained-models) to download these models. For further details, refer to our [paper]() or [website](https://consistency.epfl.ch/).


#### Consistency Domains

Consistency constraints can be used for virtually any set of domains. This repository considers transferring between image domains, and our networks were trained for transferring between the following domains from the [Taskonomy dataset](https://github.com/StanfordVL/taskonomy/tree/master/data).

    Curvature         Edge-3D            Reshading
    Depth-ZBuffer     Keypoint-2D        RGB       
    Edge-2D           Keypoint-3D        Surface-Normal 


The repo contains consistency-trained models for `RGB -> Surface-Normals`,  `RGB -> Depth-ZBuffer`, and `RGB -> Reshading`. In each case the remaining 7 domains are used as consistency constraints in during training.

Descriptions for each domain can be found in the [supplementary file](http://taskonomy.stanford.edu/taskonomy_supp_CVPR2018.pdf) of Taskonomy.

#### Network Architecture

All networks are based on the [UNet](https://arxiv.org/pdf/1505.04597.pdf) architecture. They take in an input size of 256x256, upsampling is done via bilinear interpolations instead of deconvolutions and trained with the L1 loss. See the table below for more information.

|        Task Name        | Output Dimension | Downsample Blocks |
|-------------------------|------------------|-------------------|
| `RGB -> Depth-ZBuffer`  | 256x256x1        | 6                 |
| `RGB -> Reshading`      | 256x256x1        | 5                 |
| `RGB -> Surface-Normal` | 256x256x3        | 6                 |

Other networks (e.g. `Curvature -> Surface-Normal`) use a UNet, their architecture hyperparameters are detailed in [transfers.py](./transfers.py).

More information on the models, including download links, can be found [here](#pretrained-models) and in the [supplementary material](https://consistency.epfl.ch/supplementary_material).

<br>
<br>

## Installation

There are two convenient ways to run the code. Either using Docker (recommended) or using a Python-specific tool such as pip, conda, or virtualenv.

#### Installation via Docker [Recommended]

We provide a docker that contains the code and all the necessary libraries. It's simple to install and run.
1. Simply run:
<!-- docker pull epflvilab/xtconsistency:latest -->
```bash
docker run --runtime=nvidia -ti --rm epflvilab/xtconsistency:latest
```
The code is now available in the docker under your home directory (`/app`), and all the necessary libraries should already be installed in the docker.

#### Installation via Pip/Conda/Virtualenv
The code can also be run using a Python environment manager such as Conda. See [requirements.txt](./requirements.txt) for complete list of packages. We recommend doing a clean installation of requirements using virtualenv:
1.  Clone the repo:
```bash
git clone git@github.com:EPFL-VILAB/XTConsistency.git
cd XTConsistency
```

2. Create a new environment and install the libraries:
```bash
conda create -n testenv -y python=3.6
source activate testenv
pip install -r requirements.txt
```


<br>
<br>

## Quickstart (Run Demo Locally)

#### Download the consistency trained networks
If you haven't yet, then download the [pretrained models](#Download-consistency-trained-models). Models used for the demo can be downloaded with the following command:
```bash
sh ./tools/download_models.sh
```

This downloads the `baseline`, `consistency` trained models for `depth`, `normal` and `reshading` target (1.3GB) to a folder called `./models/`. Individial models can be downloaded [here](https://drive.switch.ch/index.php/s/QPvImzbbdjBKI5P).

#### Run a model on your own image

To run the trained model of a task on a specific image:

```bash
python demo.py --task $TASK --img_path $PATH_TO_IMAGE_OR_FOLDER --output_path $PATH_TO_SAVE_OUTPUT
```

The `--task` flag specifies the target task for the input image, which should be either `normal`, `depth` or `reshading`.

To run the script for a `normal` target on the [example image](./assets/test.png):

```bash
python demo.py --task normal --img_path assets/test.png --output_path assets/
```

It returns the output prediction from the baseline (`test_normal_baseline.png`) and consistency models (`test_normal_consistency.png`).

Test image                 |  Baseline			|  Consistency
:-------------------------:|:-------------------------: |:-------------------------:
![](./assets/test_scaled.png)|  ![](./assets/test_normal_baseline.png) |  ![](./assets/test_normal_consistency.png)


Similarly, running for target tasks `reshading` and `depth` gives the following.

  Baseline (reshading)      |  Consistency (reshading)   |  Baseline (depth)	       |  Consistency (depth)
:-------------------------: |:-------------------------: | :-------------------------: |:-------------------------:
![](./assets/test_reshading_baseline.png) |  ![](./assets/test_reshading_consistency.png) | ![](./assets/test_depth_baseline.png) |  ![](./assets/test_depth_consistency.png)



<br>
<br>

## Energy Computation

Training with consistency involves several paths that each predict the target domain, but using different cues to do so. The disagreement between these predictions yields an unsupervised quantity, _consistency energy_, that our CVPR 2020 paper found correlates with prediciton error. You can view the pixel-wise _consistency energy_ (example below) using our [live demo](https://consistency.epfl.ch/demo/).


|             Sample Image             |             Normal Prediction             |             Consistency Energy             |
|:------------------------------------:|:------------------------------------:|:------------------------------------:|
| <img src=./assets/energy_query.png width="600">  | <img src=./assets/energy_normal_prediction.png width="600">  | <img src=./assets/energy_prediction.png width="600">  |
| _Sample image from the Stanford 2D3DS dataset._  | _Some chair legs are missing in the `RGB -> Normal` prediction._  |  _The white pixels indicate higher uncertainty about areas with missing chair legs._  | 


To compute energy locally, over many images, and/or to plot energy vs error, you can use the following `energy_calc.py` script. For example, to reproduce the following scatterplot using `energy_calc.py`:

|             Energy vs. Error             |
|:----------------------------------------:|
| ![](./assets/energy_vs_error.jpg)        |
| _Result from running the command below._ | 



First download a subset of images from the Taskonomy buildings `almena` and `albertville` (512 images per domain, 388MB):
```bash
sh ./tools/download_data.sh
```


Second, download all the networks necessary to compute the consistency energy. The following script will download them for you (skipping previously downloaded models) (0.8GB - 4.0GB):
```bash
sh ./tools/download_energy_graph_edges.sh
```


Now we are ready to compute energy. The following command generates a scatter plot of _consistency energy_ vs. prediction error:

```bash
python -m scripts.energy_calc energy_calc --batch_size 2 --subset_size=128 --save_dir=results
```


By default, it computes the energy and error of the `subset_size` number of points on the Taskonomy buildings `almena` and `albertville`. The error is computed for the `normal` target. The resulting plot is saved to `energy.pdf` in `RESULTS_DIR` and the corresponding data to `data.csv`. 

#### Compute energy on arbitrary images
_Consistency energy_ is an unsupervised quantity and as such, no ground-truth labels are necessary. To compute the energy for all query images in a directory, run:

```bash
python -m scripts.energy_calc energy_calc_nogt 
    --data-dir=PATH_TO_QUERY_IMAGE --batch_size 1 --save_dir=RESULTS_DIR \
    --subset_size=NUMBER_OF_IMAGES --cont=PATH_TO_TRAINED_MODEL
```

It will append a dashed horizontal line to the plot above where the energy of the query image(s) are. This plot is saved to `energy.pdf` in `RESULTS_DIR`.


<br>
<br>

## Pretrained Models

We are providing all of our pretrained models for download. These models are the same ones used in the [live demo](https://consistency.epfl.ch/demo/) and [video evaluations](https://consistency.epfl.ch/visuals/).


#### Network Architecture
All networks are based on the [UNet](https://arxiv.org/pdf/1505.04597.pdf) architecture. They take in an input size of 256x256, upsampling is done via bilinear interpolations instead of deconvolutions. All models were trained with the L1 loss.


#### Download consistency-trained models
Instructions for downloading the trained consistency models can be found [here](#download-consistency-trained-networks)
```bash
sh ./tools/download_models.sh
```

This downloads the `baseline`, `consistency` trained models for `depth`, `normal` and `reshading` target (1.3GB) to a folder called `./models/`. See the table below for specifics:

|        Task Name        | Output Dimension | Downsample Blocks |
|-------------------------|------------------|-------------------|
| `RGB -> Depth-ZBuffer`  | 256x256x1        | 6                 |
| `RGB -> Reshading`      | 256x256x1        | 5                 |
| `RGB -> Surface-Normal` | 256x256x3        | 6                 |

Individual consistency models can be downloaded [here](https://drive.switch.ch/index.php/s/QPvImzbbdjBKI5P).



#### Download perceptual networks
The pretrained perceptual models can be downloaded with the following command.

```bash
sh ./tools/download_percep_models.sh
```

This downloads the perceptual models for the `depth`, `normal` and `reshading` target (1.6GB). Each target has 7 pretrained models (from the other sources below).

```
Curvature         Edge-3D            Reshading
Depth-ZBuffer     Keypoint-2D        RGB       
Edge-2D           Keypoint-3D        Surface-Normal 
```

Perceptual model architectural hyperparameters are detailed in [transfers.py](./transfers.py), and some of the pretrained models were trained using L2 loss. For using these models with the provided training code, the pretrained models should be placed in the file path defined by `MODELS_DIR` in [utils.py](./utils.py#L25).

Individual perceptual models can be downloaded [here](https://drive.switch.ch/index.php/s/aXu4EFaznqtNzsE).



#### Download baselines
We also provide the models for other baselines used in the paper. Many of these baselines appear in the [live demo](https://consistency.epfl.ch/demo/). The pretrained baselines can be downloaded [here](https://drive.switch.ch/index.php/s/gdom4FpiiYo1Qay). Note that we will not be providing support for them. 
- A full list of baselines is in the table below:
   |                     Baseline Method                     |                                                       Description                                                              |      Tasks (RGB -> X)    | 
   |---------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------|--------------------------|
   | Baseline UNet [[PDF](https://arxiv.org/pdf/1505.04597.pdf)]   | UNets trained on the Taskonomy dataset.                                                                                  | Normal, Reshade, Depth      | 
   | Baseline Perceptual Loss                                | Trained using a randomly initialized percepual network, similar to [RND](http://arxiv.org/pdf/1810.12894.pdf).                 | Normal                      | 
   | Cycle Consistency [[PDF](https://arxiv.org/pdf/1703.10593.pdf)] | A CycleGAN trained on the Taskonomy dataset.                                                                           | Normal                      | 
   | GeoNet [[PDF](https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=8578135_)] | Trained on the Taskonomy dataset and using L1 instead of L2 loss.                                    | Normal, Depth            |
   | Multi-Task [[PDF](http://arxiv.org/pdf/1609.02132.pdf)] | A multi-task model we trained using UNets, using a shared encoder (similar to [here](http://arxiv.org/pdf/1609.02132.pdf))     | [All](#consistency-domains) |
   | Pix2Pix [[PDF](https://arxiv.org/pdf/1611.07004.pdf)]   | A Pix2Pix model trained on the Taskonomy dataset.                                                                              | Normal                      | 
   | Taskonomy [[PDF](https://arxiv.org/pdf/1804.08328.pdf)] | Pretrained models (converted to pytorch [here](https://github.com/alexsax/midlevel-reps/tree/master#using-mid-level-perception-in-your-code-)), originally trained [here](https://github.com/StanfordVL/taskonomy/tree/master/taskbank).  | Normal, Reshading, Depth* |
*Models for other tasks are available using the [`visualpriors` package](https://github.com/alexsax/midlevel-reps/tree/master#using-mid-level-perception-in-your-code-) or in Tensorflow via the [Taskonomy GitHub page](https://github.com/alexsax/midlevel-reps/tree/master#using-mid-level-perception-in-your-code-).






<br>
<br>

## Training

We used the provided training code to train our consistency models on the Taskonomy dataset. We used 3 V100 (32GB) GPUs to train our models, running them for 500 epochs takes about a week. 

> **Runnable Example:** 
   You'll find that the code in the rest of this section expects about 12TB of data (9 single-image tasks from Taskonomy). For a quick runnable example that gives the gist, try the following:  
>     
>  First download the data and then start a visdom (logging) server:
>  ```bash
>  sh ./tools/download_data.sh # Starter data (388MB)
>  visdom &                    # To view the telemetry
>  ```
>  
>  Then, start the training using the following command, which cascades two models (trains a `normal` model using `curvature` consistenct on a training set of 512 images).  
>   ```bash
>   python -m train example_cascade_two_networks --k 1 --fast
>   ```
>   You can add more pereceptual losses by changing the config in `energy.py`. For example, train the above model using _both_ `curvature` and `2D edge` consistency:
>   ```bash
>   python -m train example_normal --k 2 --fast
>   ```

Assuming that you want to train on the full dataset or [on your own dataset], read on.
#### The code is structured as follows
```python
config/             # Configuration parameters: where to save results, etc.
    split.txt           # Train, val split
    jobinfo.txt         # Defines job name, base_dir
modules/            # Network definitions
train.py            # Training script
dataset.py          # Creates dataloader
energy.py           # Defines path config, computes total loss, logging
models.py           # Implements forward backward pass
graph.py            # Computes path defined in energy.py
task_configs.py     # Defines task specific preprocessing, masks, loss fn
transfers.py        # Loads models
utils.py            # Defines file paths (described below) 
demo.py             # Demo script
```

#### Expected folder structure
The code expects folders structured as follows. These can be modified by changing values in `utils.py`
```python
base_dir/                   # The following paths are defined in utils.py (BASE_DIR)
    shared/                 # with the corresponding variable names in brackets
        models/             # Pretrained models (MODELS_DIR)
        results_[jobname]/  # Checkpoint of model being trained (RESULTS_DIR)
        ood_standard_set/   # OOD data for visualization (OOD_DIR)
    data_dir/               # taskonomy data (DATA_DIRS)
```

#### Training with consistency

   
1) **Define locations for data, models, etc.:** Create a `jobinfo.txt` file and define the name of the job and the absolute path to `BASE_DIR` where data, models results would be stored, as shown in the folder structure above. An example config is provided in the starter code (`configs/jobinfo.txt`). To modify individual file paths eg. the models folder, change `MODELS_DIR` variable name in [utils.py](./utils.py#L25).  
  
   We won't cover downloading the Taskonomy dataset, which can be downloaded following the instructions [here](https://github.com/StanfordVL/taskonomy/tree/master/data)

2) **Download perceptual networks:** If you want to initialize from our pretrained models, then then [download them](#Download-perceptual-networks) with the following command (1.6GB): 
    ```bash
    sh ./tools/download_percep_models.sh
    ```
    More info about the networks is available [here](#Download-perceptual-networks).

3) **Train with consistency** using the command:

   ```bash
   python -m train multiperceptual_{depth,normal,reshading}
   ```

   For example, to run the training code for the `normal` target, run 

   ```bash
   python -m train multiperceptual_normal
   ```

   This trains the model for the `normal` target with 8 perceptual losses ie. `curvature`, `edge2d`, `edge3d`, `keypoint2d`, `keypoint3d`, `reshading`, `depth` and `imagenet`. We used 3 V100 (32GB) GPUs to train our models, running them for 500 epochs takes about a week.

   Additional arugments can be specified during training, the most commonly used ones are listed below. For the full list, refer to the [training script](./train.py).
   - The flag `--k` defines the number of perceptual losses used, thus reducing GPU memory requirements.
   - There are several options for choosing how this subset is chosen 1. randomly (`--random-select`) 2. winrate (`--winrate`)
   - Data augmentation is not done by default, it can be added to the training data with the flag `--dataaug`. The transformations applied are 1. random crop with probability 0.5 2. [color jitter](https://pytorch.org/docs/stable/torchvision/transforms.html?highlight=color%20jitter#torchvision.transforms.ColorJitter) with probability 0.5.

   To train a `normal` target domain with 2 perceptual losses selected randomly each epoch, run the following command.

   ```bash
   python -m train multiperceptual_normal --k 2 --random-select
   ```

4) **Logging:** The losses and visualizations are logged in Visdom. This can be accessed via `[server name]/env/[job name]` eg. `localhost:8888/env/normaltarget_allperceps`. 

   An example visualization is shown below. We plot the the outputs from the paths defined in the energy configuration used. Two windows are shown, one shows the predictions before training starts, the other updates them after each epoch. The labels for each column can be found at the top of the window. The second column has the target's ground truth `y^`, the third its prediction `n(x)` from the RGB image `x`. Thereafter, the predictions of each pair of images with the same domain are given by the paths `f(y^),f(n(x))`, where `f` is from the target domain to another domain eg. `curvature`.

   ![](./assets/visdom_eg.png)

   **Logging conventions:** For uninteresting historical reasons, the columns in the logging during training might have strange names. You can define your own names instead of using these by changing the config file in `energy.py`.
   
   Here's a quick guide to the current convention. For example, when training with a `normal` model using consistency:
    - The RGB input is denoted as `x` and the `target` domain is denoted as `y`. The ground truth label for a domain is marked with a `^`(e.g. `y^` for the for `target` domain). 
    - The direct (`RGB -> Z`) and perceptual (`target [Y] -> Z`) transfer functions are named as follows:<br>(i.e. the function for `rgb` to `curvature` is `RC`; for `normal` to `curvature` it's `f`)

   |  Domain (Z) | `rgb -> Z`<br>(Direct) | `Y -> Z`<br>(Perceptual) ||    Domain (Z)   | `rgb -> Z`<br>(Direct) | `Y -> Z`<br>(Perceptual) |
   |-------------|------------------------|--------------------------|-|-----------------|------------------------|---------------------------|
   | target      | n                      | -                        || keypoints2d     | k2                     | Nk2                       |
   | curvature   | RC                     | f                        || keypoints3d     | k3                     | Nk3                       |
   | sobel edges | a                      | s                        || edge occlusion  | E0                     | nE0                       |

#### To train on other target domains
1. A new configuration should be defined in the `energy_configs` dictionary in [energy.py](./energy.py#L39-L521). 

   Decription of the infomation needed:
   - `paths`: `X1->X2->X3`. The keys in this dictionary uses a function notation eg. `f(n(x))`, with its corresponding value being a list of task objects that defines the domains being transfered eg. `[rgb, normal, curvature]`. The `rgb` input is defined as `x`, `n(x)` returns `normal` predictions from `rgb`, and `f(n(x))` returns `curvature` from `normal`. These notations do not need to be same for all configurations. The [table](#function-definitions) below lists those that have been kept constant for all targets.
   - `freeze_list`: the models that will not be optimized,
   - `losses`: loss terms to be constructed from the paths defined above,
   - `plots`: the paths to plots in the visdom environment.

2. New models may need to be defined in the `pretrained_transfers` dictionary in [transfers.py](./transfers.py#L26-L97). For example, for a `curvature` target, and perceptual model `curvature` to `normal`, the code will look for the `principal_curvature2normal.pth` file in `MODELS_DIR` if it is not defined in [transfers.py](./transfers.py#L26-L97).



#### To train on other datasets
The expected folder structure for the data is,
```
DATA_DIRS/
  [building]_[domain]/
      [domain]/
          [view]_domain_[domain].png
          ...
```
Pytorch's dataloader _\_\_getitem\_\__ method has been overwritten to return a tuple of all tasks for a given building and view point. This is done in [datasets.py](./datasets.py#L181-L198). Thus, for other folder structures, a function to get the corresponding file paths for different domains should be defined. 

For task specific configs, like transformations and masks, are defined in [task_configs.py](./task_configs.py#L341-L373).

<br>
<br>

## Citation
If you find the code, models, or data useful, please cite this paper:

```
@article{zamir2020consistency,
  title={Robust Learning Through Cross-Task Consistency},
  author={Zamir, Amir and Sax, Alexander and Yeo, Teresa and Kar, Oğuzhan and Cheerla, Nikhil and Suri, Rohan and Cao, Zhangjie and Malik, Jitendra and Guibas, Leonidas},
  journal={arXiv},
  year={2020}
}
```



================================================
FILE: config/jobinfo.txt
================================================
normaltarget_allperceps, .


================================================
FILE: config/split.txt
================================================
train_buildings: [woodbine, hometown, haymarket, emmaus, swormville, haxtun, martinville,
  winfield, marksville, hammon, mammoth, kronborg, cobalt, lenoir, bonnie, bautista,
  retsof, azusa, munsons, darrtown, michiana, uncertain, lajas, plessis, bellemeade,
  jacobus, swisshome, idanha, lathrup, hambleton, country, ancor, haaswood, orason,
  cisne, byers, muleshoe, fredericksburg, cayuse, elmira, adrian, merchantville, german,
  waipahu, clarkridge, rogue, ludlowville, cochranton, colebrook, lovilia, mullica,
  mobridge, kevin, yscloskey, laytonsville, convoy, sisters, cosmos, calavo, clairton,
  crookston, nicut, ladue, anaheim, howie, codell, seatonville, brown, quantico, goodview,
  parole, kingfisher, churchton, edgemere, micanopy, bettendorf, goffs, windhorst,
  ooltewah, tokeland, darden, benicia, american, neibert, milaca, willow, samuels,
  sanctuary, pablo, cantwell, biltmore, germfask, eastville, rough, westfield, wainscott,
  mahtomedi, kopperl, gluck, mogote, cason, hercules, avonia, pettigrew, ewell, imbery,
  collierville, maguayo, brentsville, mckeesport, hitchland, browntown, aldine, spotswood,
  chrisney, kildare, stockman, ribera, hortense, mentasta, soldier, kettle, northgate,
  crandon, ovalo, cullison, dedham, vails, tallmadge, gratz, fonda, helton, mogadore,
  highspire, sharon, foyil, shelbiana, seiling, brevort, noonday, springhill,
  coronado, sundown, carneiro, silas, assinippi, monticello, wappingers,
  lakeville, stockwell, ogilvie, victorville, braxton, sodaville, silerton, hobson,
  tradewinds, sands, coffeen, umpqua, blackstone, sarcoxie, model, bremerton, capistrano,
  deatsville, graceville, dansville, belpre, edson, mcnary, kirwin, rosenberg, lynchburg,
  ranchester, shingler, auburn, connellsville, alstown, kerrtown, marstons, hurley,
  mifflintown, pamelia, sumas, chilhowie, dryville, deemston, cashel, galatia, harrellsville,
  mcdade, eudora, sasakwa, baneberry, rosser, halfway, hainesburg, gravelly, frierson,
  tyler, irvine, natural, murchison, lindsborg, duarte, wando, globe, neshkoro, cornville,
  bowmore, roxboro, espanola, maugansville, yankeetown, sawpit, schoolcraft, klickitat,
  scandinavia, donaldson, aloha, gaylord, hartline, laupahoehoe, wiconisco, mesic,
  eagerville, keiser, potosi, wyldwood, macarthur, newfields, moberly, everton, lindenwood,
  nuevo, bethlehem, silva, noxapater, lindberg, hornsby, weleetka, tysons, kremlin,
  jenners, trail, freedom, mcclure, ruckersville, sugarville, nemacolin, athens, vacherie,
  checotah, blenheim, allensville, grantsville, holcut, hallettsville, angiola, tomales,
  grangeville, seward, fishersville, kendall, kangley, wilbraham, caruthers, hacienda,
  readsboro, pocopson, bonfield, cohoes, inkom, monson, peacock, touhy, divide, norvelt,
  badger, leilani, corozal, warrenville, lluveras, grigston, cooperstown, nimmons,
  ewansville, paige, matoaca, lessley, purple, kihei, millbury, culbertson, maunawili,
  brewton, maryhill, channel, branford, creede, goodfield, spencerville, kirksville,
  cokeville, barahona, leonardo, mosinee, tolstoy, broseley, broadwell, landing, roeville,
  hatfield, rancocas, mcewen, annona, okabena, terrell, barboursville, booth, hanson,
  mashulaville, cutlerville, euharlee, rabbit, fitchburg, shellsburg, milford, grassy,
  timberon, coeburn, wilkinsburg, lynxville, islandton, arbutus, reyno, wakeman, frankfort,
  sontag, voorhees, beechwood, ossipee, rockport, starks, woonsocket, hildebran, circleville,
  aldrich, sunshine, destin, chesterbrook, musicks, merom, kinde, andover, pittsburg,
  scioto, tilghmanton, castor, potterville, onaga, stanleyville, leavittsburg, carpendale,
  bountiful, kingdom, cebolla, sweatman, arona, sagerton, herricks, morris, montreal,
  stokes, newcomb, adairsville, bertram, kinney, spread, akiak, westerville, texasville,
  springerville, almota, aulander, superior, goodyear, cabin, random, ballou, southfield,
  ballantine, glenmoor, oriole, ashport, denmark, winthrop, bohemia, yadkinville,
  smoketown, waucousta, winooski, peden, mayesville, liddieville, clive, gluek, goodwine,
  uvalda, pleasant, losantville, lineville, hillsdale, ackermanville, waukeenah, mentmore,
  glassboro, bellwood, peconic, pinesdale, hordville, wells, hendrix, dunmor, fleming,
  mccloud, reserve, gilbert, bonesteel, roane, pocasset, greigsville, delton, whitethorn,
  frontenac, siren, artois, helix, melstone, sultan, shumway, seeley, cousins, cauthron,
  gough, anthoston, gladstone, macland, hiteman, shauck, jennie, airport, funkstown,
  markleeville, marland, gloria, poyen, annawan, bolton, wattsville, waldenburg, pearce,
  maiden, gasburg, calmar, applewold, kemblesville, redbank, wilseyville, lucan, archer,
  castroville, pasatiempo, arkansaw, wyatt, shelbyville, merlin, whiteriver, torrington,
  oyens]
val_buildings: [cottonport, elton, corder, mazomanie, barranquitas, ihlen, wilkesboro,
  macedon, portal, gastonia, thrall, orangeburg, poipu, kankakee, chiloquin, sussex,
  maricopa, wesley, bowlus, copemish, tariffville, pomaria, kathryn, rutherford, plumerville,
  waimea, experiment, dalcour, ohoopee, mifflinburg, callicoon, manassas, macksville,
  apache, alfred, maida, dauberville, chireno, stilwell, albertville, ellaville,
  kobuk, burien, carpio, placida, forkland, tippecanoe, beach, eagan, grace, portola,
  hominy, maben]


================================================
FILE: config/split_fullplus.txt
================================================
train_buildings: [hanson, merom, arbutus, goodfield, eagan, arona, adairsville, reserve, aloha, castor, munsons, ballou, woonsocket, foyil, creede, superior, klickitat, cottonport, rancocas, ossipee, haxtun, tyler, sugarville, martinville, haaswood, euharlee, hacienda, peacock, convoy, roane, ellaville, chilhowie, tomales, springhill, shellsburg, chrisney, rutherford, mullica, winthrop, grace, codell, silas, braxton, chireno, bremerton, lynchburg, halfway, ballantine, hendrix, tokeland, merlin, baneberry, nimmons, mccloud, onaga, aldrich, maguayo, bountiful, carpio, ovalo, waimea, grassy, monticello, portal, melstone, cutlerville, frankfort, gilbert, kinney, gough, quantico, goodyear, waukeenah, pocopson, terrell, andover, ackermanville, copemish, leavittsburg, monson, applewold, sasakwa, anthoston, gaylord, whitethorn, clarkridge, gladstone, cornville, hatfield, circleville, marksville, trail, chesterbrook, gloria, funkstown, tippecanoe, airport, mesic, stokes, rogue, bethlehem, mcewen, hobson, kirwin, glenmoor, ancor, vails, biltmore, kangley, voorhees, leonardo, kevin, maida, castroville, seeley, millbury, waucousta, dansville, ludlowville, cooperstown, elton, marstons, athens, mentmore, inkom, model, ruckersville, hambleton, newfields, broseley, bettendorf, irvine, pinesdale, tilghmanton, american, westerville, maunawili, reyno, sanctuary, portola, goodwine, cullison,ran, cobalt, winooski, orangeburg, jacobus, oriole, lakeville, pasatiempo, beach, mahtomedi, redbank, cosmos, goffs, wilkinsburg, barboursville, frierson, sunshine, globe, kettle, booth, barranquitas, burien, carneiro, placida, hallettsville, fredericksburg, aulander, culbertson, bellwood, stanleyville, smoketown, winfield, seward, emmaus, macarthur, pomaria, kankakee, jenners, nemacolin, archer, neibert, random, parole, carpendale, spotswood, tolstoy, shelbyville, potterville, frontenac, avonia, laupahoehoe, milford, seatonville, roxboro, torrington, grantsville, rosser, spencerville, ewell, wyatt, plumerville, lynxville, allensville, springerville, nuevo, grangeville, stilwell, colebrook, browntown, germfask, readsboro, bowmore, pettigrew, brevort, fitchburg, shelbiana, hometown, montreal, mckeesport, wainscott, adrian, branford, dunmor, windhorst, alfred, islandton, arkansaw, rough, brewton, bonnie, cabin, hitchland, cebolla, mammoth, alstown, beechwood, hominy, crandon, sodaville, everton, maiden, churchton, divide, cayuse, helix, umpqua, chiloquin, schoolcraft, kathryn, coffeen, willow, dedham, corder, timberon, pleasant, darnestown, harrellsville, bohemia, micanopy, hillsdale, mashulaville, wilseyville, kemblesville, thrall, kronborg, fleming, bonesteel, hartline, channel, checotah, orason, fonda, kerrtown, noonday, capistrano, poyen, spread, annona, weleetka, silva, gasburg, milaca, bolton, kendall, stockman, whiteriver, ooltewah, wattsville, soldier, cohoes, lucan, neshkoro, newcomb, cason, roeville, kremlin, yankeetown, maryhill, deatsville, lenoir, byers, kildare, kirksville, blackstone, oyens, peden, azusa, mogote, keiser, potosi, delton, victorville, pamelia, cisne, marland, hiteman, lindberg, mayesville, sussex, nicut, lindsborg, seiling, mobridge, bautista, tradewinds, annawan, mogadore, retsof, murchison, lluveras, highspire, woodbine, lajas, sweatman, okabena, clairton, angiola, callicoon, hurley, touhy, michiana, sawpit, stockwell, country, lindenwood, anaheim, dryville, duarte, sumas, siren, musicks, wilbraham, holcut, kingdom, kopperl, calmar, forkland, mifflinburg, lineville, mosinee, hainesburg, gratz, maugansville, haymarket, uncertain, mifflintown, scandinavia, mazomanie, tariffville, mentasta, freedom, bowlus, apache, ranchester, fishersville, northgate, eagerville]
val_buildings: [ogilvie, laytonsville, clive, hortense, wesley, idanha, maricopa, tomkins, texasville, sundown, southfield, moberly, auburn, eudora, wiconisco, tallmadge, ashport, lessley, assinippi, gravelly, silerton, hordville, corozal, swormville, warrenville, almota, cantwell, collierville, cashel, pearce, sisters, merchantville, glassboro, shauck, yscloskey, pablo, ribera, pittsburg, graceville, markleeville, hercules, macedon, howie, sands, cokeville, herricks, kinde, edgemere, kobuk, sagerton, starks, jennie, hornsby, greigsville, westfield, wyldwood, sarcoxie, artois, elmira, deemston, galatia, denmark, swisshome, wells, scioto, pocasset, waldenburg, shumway, waipahu, macksville, crookston, eastville, yadkinville, darden]


================================================
FILE: config/split_medium.txt
================================================
train_buildings: [hanson, merom, goodfield, eagan, adairsville, castor, klickitat, cottonport, tyler, sugarville, martinville, chilhowie, silas, lynchburg, tokeland, onaga, frankfort, goodyear, albertville, andover, airport, rogue, ancor, leonardo, maida, marstons, athens, newfields, broseley, irvine, pinesdale, tilghmanton, goodwine, hildebran, winooski, lakeville, cosmos, goffs, sunshine, globe, benevolence, emmaus, pomaria, neibert, parole, tolstoy, shelbyville, potterville, rosser, allensville, springerville, nuevo, stilwell, browntown, readsboro, shelbiana, wainscott, arkansaw, bonnie, beechwood, hominy, churchton, coffeen, willow, timberon, bohemia, micanopy, hillsdale, wilseyville, kemblesville, thrall, bonesteel, annona, stockman, soldier, neshkoro, newcomb, byers, oyens, victorville, pamelia, marland, hiteman, sussex, bautista, highspire, woodbine, sweatman, clairton, touhy, lindenwood, anaheim, duarte, musicks, forkland, mifflinburg, hainesburg, maugansville, ranchester]
val_buildings: [hortense, southfield, wiconisco, gravelly, hordville, corozal, swormville, collierville, pearce, pablo, pittsburg, markleeville, sands, kobuk, westfield, wyldwood, swisshome, scioto, waipahu, darden]


================================================
FILE: datasets.py
================================================

import numpy as np
import matplotlib as mpl

import os, sys, math, random, tarfile, glob, time, yaml, itertools
import parse

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

from utils import *
from logger import Logger, VisdomLogger
from task_configs import get_task, tasks

from PIL import Image
from io import BytesIO
from sklearn.model_selection import train_test_split
import IPython

import pdb

""" Default data loading configurations for training, validation, and testing. """
def load_train_val(train_tasks, val_tasks=None, fast=False,
        train_buildings=None, val_buildings=None, split_file="config/split.txt",
        dataset_cls=None, batch_size=32, batch_transforms=cycle,
        subset=None, subset_size=None, dataaug=False,
    ):

    dataset_cls = dataset_cls or TaskDataset
    train_cls = TrainTaskDataset if dataaug else dataset_cls
    train_tasks = [get_task(t) if isinstance(t, str) else t for t in train_tasks]
    if val_tasks is None: val_tasks = train_tasks
    val_tasks = [get_task(t) if isinstance(t, str) else t for t in val_tasks]  
    data = yaml.load(open(split_file))
    train_buildings = train_buildings or (["almena"] if fast else data["train_buildings"])
    val_buildings = val_buildings or (["almena"] if fast else data["val_buildings"])
    print("number of train images:")
    train_loader = train_cls(buildings=train_buildings, tasks=train_tasks)
    print("number of val images:")
    val_loader = dataset_cls(buildings=val_buildings, tasks=val_tasks)

    if subset_size is not None or subset is not None:
        train_loader = torch.utils.data.Subset(train_loader,
            random.sample(range(len(train_loader)), subset_size or int(len(train_loader)*subset)),
        )

    train_step = int(len(train_loader) // (400 * batch_size))
    val_step = int(len(val_loader) // (400 * batch_size))
    print("Train step: ", train_step)
    print("Val step: ", val_step)
    if fast: train_step, val_step = 8, 8

    return train_loader, val_loader, train_step, val_step


""" Load all buildings """
def load_all(tasks, buildings=None, batch_size=64, split_file="data/split.txt", batch_transforms=cycle):

    data = yaml.load(open(split_file))
    buildings = buildings or (data["train_buildings"] + data["val_buildings"])

    data_loader = torch.utils.data.DataLoader(
        TaskDataset(buildings=buildings, tasks=tasks),
        batch_size=batch_size,
        num_workers=0, shuffle=True, pin_memory=True
    )

    return data_loader



def load_test(all_tasks, buildings=["almena", "albertville"], sample=4):

    all_tasks = [get_task(t) if isinstance(t, str) else t for t in all_tasks]
    print(f"number of images in {buildings[0]}:")
    test_loader1 = torch.utils.data.DataLoader(
        TaskDataset(buildings=[buildings[0]], tasks=all_tasks, shuffle=False),
        batch_size=sample,
        num_workers=0, shuffle=False, pin_memory=True,
    )
    print(f"number of images in {buildings[1]}:")
    test_loader2 = torch.utils.data.DataLoader(
        TaskDataset(buildings=[buildings[1]], tasks=all_tasks, shuffle=False),
        batch_size=sample,
        num_workers=0, shuffle=False, pin_memory=True,
    )
    set1 = list(itertools.islice(test_loader1, 1))[0]
    set2 = list(itertools.islice(test_loader2, 1))[0]
    test_set = tuple(torch.cat([x, y], dim=0) for x, y in zip(set1, set2))
    return test_set


def load_ood(tasks=[tasks.rgb], ood_path=OOD_DIR, sample=21):
    ood_loader = torch.utils.data.DataLoader(
        ImageDataset(tasks=tasks, data_dir=ood_path),
        batch_size=sample,
        num_workers=sample, shuffle=False, pin_memory=True
    )
    ood_images = list(itertools.islice(ood_loader, 1))[0]
    return ood_images



class TaskDataset(Dataset):

    def __init__(self, buildings, tasks=[get_task("rgb"), get_task("normal")], data_dirs=DATA_DIRS,
            building_files=None, convert_path=None, use_raid=USE_RAID, resize=None, unpaired=False, shuffle=True):

        super().__init__()
        self.buildings, self.tasks, self.data_dirs = buildings, tasks, data_dirs
        self.building_files = building_files or self.building_files
        self.convert_path = convert_path or self.convert_path
        self.resize = resize
        if use_raid:
            self.convert_path = self.convert_path_raid
            self.building_files = self.building_files_raid

        self.file_map = {}
        for data_dir in self.data_dirs:
            for file in glob.glob(f'{data_dir}/*'):
                res = parse.parse("{building}_{task}", file[len(data_dir)+1:])
                if res is None: continue
                self.file_map[file[len(data_dir)+1:]] = data_dir

        filtered_files = None

        assert (len(tasks) > 0), "Building dataset for tasks, but no tasks specified!"
        task = tasks[0]
        task_files = []
        for building in buildings:
            task_files += self.building_files(task, building)
        print(f"    {task.name} file len: {len(task_files)}")
        self.idx_files = task_files
        if not shuffle: self.idx_files = sorted(task_files)

        print ("    Intersection files len: ", len(self.idx_files))

    def reset_unpaired(self):
        if self.unpaired:
            self.task_indices = {task:random.sample(range(len(self.idx_files)), len(self.idx_files)) for task in self.task_indices}

    def building_files(self, task, building):
        """ Gets all the tasks in a given building (grouping of data) """
        return get_files(f"{building}_{task.file_name}/{task.file_name}/*.{task.file_ext}", self.data_dirs)

    def building_files_raid(self, task, building):
        return get_files(f"{task.file_name}/{building}/*.{task.file_ext}", self.data_dirs)

    def convert_path(self, source_file, task):
        """ Converts a file from task A to task B. Can be overriden by subclasses"""
        source_file = "/".join(source_file.split('/')[-3:])
        result = parse.parse("{building}_{task}/{task}/{view}_domain_{task2}.{ext}", source_file)
        building, _, view = (result["building"], result["task"], result["view"])
        dest_file = f"{building}_{task.file_name}/{task.file_name}/{view}_domain_{task.file_name_alt}.{task.file_ext}"
        if f"{building}_{task.file_name}" not in self.file_map:
            print (f"{building}_{task.file_name} not in file map")
            # IPython.embed()
            return ""
        data_dir = self.file_map[f"{building}_{task.file_name}"]
        return f"{data_dir}/{dest_file}"

    def convert_path_raid(self, full_file, task):
        """ Converts a file from task A to task B. Can be overriden by subclasses"""
        source_file = "/".join(full_file.split('/')[-3:])
        result = parse.parse("{task}/{building}/{view}.{ext}", source_file)
        building, _, view = (result["building"], result["task"], result["view"])
        dest_file = f"{task.file_name}/{building}/{view}.{task.file_ext}"
        return f"{full_file[:-len(source_file)-1]}/{dest_file}"

    def __len__(self):
        return len(self.idx_files)

    def __getitem__(self, idx):

        for i in range(200):
            try:
                res = []

                seed = random.randint(0, 1e10)

                for task in self.tasks:
                    file_name = self.convert_path(self.idx_files[idx], task)
                    if len(file_name) == 0: raise Exception("unable to convert file")
                    image = task.file_loader(file_name, resize=self.resize, seed=seed)

                    res.append(image)
                return tuple(res)
            except Exception as e:
                idx = random.randrange(0, len(self.idx_files))
                if i == 199: raise (e)


class TrainTaskDataset(TaskDataset):

    def __getitem__(self, idx):

        for i in range(200):
            try:
                res = []

                seed = random.randint(0, 1e10)
                crop = random.randint(int(0.7*512), 512) if bool(random.getrandbits(1)) else 512

                for task in self.tasks:
                    jitter = bool(random.getrandbits(1)) if task.name == 'rgb' else False
                    file_name = self.convert_path(self.idx_files[idx], task)
                    if len(file_name) == 0: raise Exception("unable to convert file")
                    image = task.file_loader(file_name, resize=self.resize, seed=seed, crop=crop, jitter=jitter)
                    res.append(image)

                return tuple(res)
            except Exception as e:
                idx = random.randrange(0, len(self.idx_files))
                if i == 199: raise (e)


class ImageDataset(Dataset):

    def __init__(
        self,
        tasks=[tasks.rgb],
        data_dir=f"data/ood_images",
        files=None,
    ):

        self.tasks = tasks
        #if not USE_RAID and files is None:
        #    os.system(f"ls {data_dir}/*.png")
        #    os.system(f"ls {data_dir}/*.png")

        self.files = files \
            or sorted(
                glob.glob(f"{data_dir}/*.png")
                + glob.glob(f"{data_dir}/*.jpg")
                + glob.glob(f"{data_dir}/*.jpeg")
            )

        print("number of ood images: ", len(self.files))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):

        file = self.files[idx]
        res = []
        seed = random.randint(0, 1e10)
        for task in self.tasks:
            image = task.file_loader(file, seed=seed)
            if image.shape[0] == 1: image = image.expand(3, -1, -1)
            res.append(image)
        return tuple(res)




if __name__ == "__main__":

    logger = VisdomLogger("data", env=JOB)
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        [tasks.rgb, tasks.normal, tasks.principal_curvature, tasks.rgb(size=512)],
        batch_size=32,
    )
    print ("created dataset")
    logger.add_hook(lambda logger, data: logger.step(), freq=32)

    for i, _ in enumerate(train_dataset):
        logger.update("epoch", i)


================================================
FILE: demo.py
================================================
import torch
from torchvision import transforms

from modules.unet import UNet, UNetReshade

import PIL
from PIL import Image

import argparse
import os.path
from pathlib import Path
import glob
import sys

import pdb



parser = argparse.ArgumentParser(description='Visualize output for a single Task')

parser.add_argument('--task', dest='task', help="normal, depth or reshading")
parser.set_defaults(task='NONE')

parser.add_argument('--img_path', dest='img_path', help="path to rgb image")
parser.set_defaults(im_name='NONE')

parser.add_argument('--output_path', dest='output_path', help="path to where output image should be stored")
parser.set_defaults(store_name='NONE')

args = parser.parse_args()

root_dir = './models/'
trans_totensor = transforms.Compose([transforms.Resize(256, interpolation=PIL.Image.BILINEAR),
                                    transforms.CenterCrop(256),
                                    transforms.ToTensor()])
trans_topil = transforms.ToPILImage()

os.system(f"mkdir -p {args.output_path}")

# get target task and model
target_tasks = ['normal','depth','reshading']
try:
    task_index = target_tasks.index(args.task)
except:
    print("task should be one of the following: normal, depth, reshading")
    sys.exit()
models = [UNet(), UNet(downsample=6, out_channels=1), UNetReshade(downsample=5)]
model = models[task_index]

map_location = (lambda storage, loc: storage.cuda()) if torch.cuda.is_available() else torch.device('cpu')

def save_outputs(img_path, output_file_name):

    img = Image.open(img_path)
    img_tensor = trans_totensor(img)[:3].unsqueeze(0)

    # compute baseline and consistency output
    for type in ['baseline','consistency']:
        path = root_dir + 'rgb2'+args.task+'_'+type+'.pth'
        model_state_dict = torch.load(path, map_location=map_location)
        model.load_state_dict(model_state_dict)
        baseline_output = model(img_tensor).clamp(min=0, max=1)
        trans_topil(baseline_output[0]).save(args.output_path+'/'+output_file_name+'_'+args.task+'_'+type+'.png')


img_path = Path(args.img_path)
if img_path.is_file():
    save_outputs(args.img_path, os.path.splitext(os.path.basename(args.img_path))[0])
elif img_path.is_dir():
    for f in glob.glob(args.img_path+'/*'):
        save_outputs(f, os.path.splitext(os.path.basename(f))[0])
else:
    print("invalid file path!")
    sys.exit()


================================================
FILE: energy.py
================================================
import os, sys, math, random, itertools
from functools import partial
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets, transforms, models
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.checkpoint import checkpoint

from utils import *
from task_configs import tasks, get_task, ImageTask
from transfers import functional_transfers, finetuned_transfers, get_transfer_name, Transfer
from datasets import TaskDataset, load_train_val

from matplotlib.cm import get_cmap


import IPython

import pdb

def get_energy_loss(
    config="", mode="winrate",
    pretrained=True, finetuned=True, **kwargs,
):
    """ Loads energy loss from config dict. """
    if isinstance(mode, str):
        mode = {
            "standard": EnergyLoss,
            "winrate": WinRateEnergyLoss,
        }[mode]
    return mode(**energy_configs[config],
        pretrained=pretrained, finetuned=finetuned, **kwargs
    )

ALL_PERCEPTUAL_TASKS = [tasks.principal_curvature,
             tasks.sobel_edges,
             tasks.depth_zbuffer,
             tasks.edge_occlusion,
             tasks.reshading,
             tasks.keypoints3d,
             tasks.keypoints2d]

def generate_config(perceptual_tasks, target_task=tasks.normal, tree_structure=False, has_gt=True):

    # If we have GT, measure error
    base_keys = {
                    "x": [tasks.rgb],
                    "n((x))": [tasks.rgb, target_task]
    }
    direct_losses = {}
    if has_gt:
        base_keys["y^"] = [target_task]
        direct_losses["direct_normal"] = { ("train", "val", "train_subset"): [ ("n((x))", "y^"), ] }

    # Add in losses for consistency energy
    perceptual_losses = []
    if tree_structure:
        perceptual_losses = [('', t2) for t2 in perceptual_tasks]
    else:
        perceptual_losses = list(itertools.combinations(perceptual_tasks + [''], r=2))
        perceptual_losses = [(t2, t1) if str(t2) == '' else (t1, t2) for t1, t2 in perceptual_losses]

    return {
        "paths": both( base_keys,
                {
                    f"n({intermediate_task}(x))": [tasks.rgb, intermediate_task, target_task]
                    for intermediate_task in perceptual_tasks
                }
        ),
        "freeze_list" : [(t, target_task) for t in perceptual_tasks] +
                        [(tasks.rgb, t) for t in perceptual_tasks],
        "losses": both( direct_losses,
                {
                    f'percep_{t1}_{t2}': { ("train", "val", "train_subset"): [ (f"n({t1}(x))", f"n({t2}(x))"), ], }
                    for t1, t2 in perceptual_losses
                }
        ),
        "plots": {
            "ID": dict(
                size=256,
                realities=("test",),
                paths=[
                    "x",
                    "n((x))",
                ]
            ),
        },
    }

energy_configs = {

    "example_cascade_two_networks": {
        "paths": {
            "x": [tasks.rgb],
            "y^": [tasks.normal],
            "n(x)": [tasks.rgb, tasks.normal],
            "RC(x)": [tasks.rgb, tasks.principal_curvature],
            "curv": [tasks.principal_curvature],
            "f(y^)": [tasks.normal, tasks.principal_curvature],
            "f(n(x))": [tasks.rgb, tasks.normal, tasks.principal_curvature],
        },
        "freeze_list": [
            [tasks.normal, tasks.principal_curvature],
            [tasks.normal, tasks.sobel_edges],
        ],
        "losses": {
            "mae": {
                ("train", "val"): [
                    ("n(x)", "y^"),
                ],
            },
            "percep_curv": {
                ("train", "val"): [
                    ("f(n(x))", "f(y^)"),
                ],
            },
            "direct_curv": {
                ("train", "val"): [
                    ("RC(x)", "curv"),
                ],
            },
        },
        "plots": {
            "": dict(
                size=256,
                realities=('train', 'val'),
                paths=[
                    "x",
                    "y^",
                    "n(x)",
                    "f(y^)",
                    "f(n(x))",
                ]
            ),
        },
    },


    "example_normal": {
        "paths": {
            "x": [tasks.rgb],
            "y^": [tasks.normal],
            "n(x)": [tasks.rgb, tasks.normal],
            "a(x)": [tasks.rgb, tasks.sobel_edges],
            "RC(x)": [tasks.rgb, tasks.principal_curvature],
            "edge": [tasks.sobel_edges],
            "curv": [tasks.principal_curvature],
            "s(y^)": [tasks.normal, tasks.sobel_edges],
            "s(n(x))": [tasks.rgb, tasks.normal, tasks.sobel_edges],
            "f(y^)": [tasks.normal, tasks.principal_curvature],
            "f(n(x))": [tasks.rgb, tasks.normal, tasks.principal_curvature],
        },
        "freeze_list": [
            [tasks.normal, tasks.principal_curvature],
            [tasks.normal, tasks.sobel_edges],
        ],
        "losses": {
            "mae": {
                ("train", "val"): [
                    ("n(x)", "y^"),
                ],
            },
            "percep_curv": {
                ("train", "val"): [
                    ("f(n(x))", "f(y^)"),
                ],
            },
            "direct_curv": {
                ("train", "val"): [
                    ("RC(x)", "curv"),
                ],
            },
            "percep_edge": {
                ("train", "val"): [
                    ("s(n(x))", "s(y^)"),
                ],
            },
            "direct_edge": {
                ("train", "val"): [
                    ("a(x)", "s(y^)"),
                ],
            },
        },
        "plots": {
            "": dict(
                size=256,
                realities=('train', 'val'),
                paths=[
                    "x",
                    "y^",
                    "n(x)",
                    "f(y^)",
                    "f(n(x))",
                    "s(y^)",
                    "s(n(x))",
                ]
            ),
        },
    },

    "example_normal_direct": {
        "paths": {
            "x": [tasks.rgb],
            "y^": [tasks.normal],
            "n(x)": [tasks.rgb, tasks.normal],
        },
        "freeze_list": [
        ],
        "losses": {
            "mae": {
                ("train", "val"): [
                    ("n(x)", "y^"),
                ],
            },
        },
        "plots": {
            "": dict(
                size=256,
                realities=('train', 'val'),
                paths=[
                    "x",
                    "y^",
                    "n(x)",
                ]
            ),
        },
    },
    
    "multiperceptual_normal": {
        "paths": {
            "x": [tasks.rgb],
            "y^": [tasks.normal],
            "n(x)": [tasks.rgb, tasks.normal],
            "RC(x)": [tasks.rgb, tasks.principal_curvature],
            "a(x)": [tasks.rgb, tasks.sobel_edges],
            "d(x)": [tasks.rgb, tasks.reshading],
            "r(x)": [tasks.rgb, tasks.depth_zbuffer],
            "EO(x)": [tasks.rgb, tasks.edge_occlusion],
            "k2(x)": [tasks.rgb, tasks.keypoints2d],
            "k3(x)": [tasks.rgb, tasks.keypoints3d],
            "curv": [tasks.principal_curvature],
            "edge": [tasks.sobel_edges],
            "depth": [tasks.depth_zbuffer],
            "reshading": [tasks.reshading],
            "keypoints2d": [tasks.keypoints2d],
            "keypoints3d": [tasks.keypoints3d],
            "edge_occlusion": [tasks.edge_occlusion],
            "f(y^)": [tasks.normal, tasks.principal_curvature],
            "f(n(x))": [tasks.rgb, tasks.normal, tasks.principal_curvature],
            "s(y^)": [tasks.normal, tasks.sobel_edges],
            "s(n(x))": [tasks.rgb, tasks.normal, tasks.sobel_edges],
            "g(y^)": [tasks.normal, tasks.reshading],
            "g(n(x))": [tasks.rgb, tasks.normal, tasks.reshading],
            "nr(y^)": [tasks.normal, tasks.depth_zbuffer],
            "nr(n(x))": [tasks.rgb, tasks.normal, tasks.depth_zbuffer],
            "Nk2(y^)": [tasks.normal, tasks.keypoints2d],
            "Nk2(n(x))": [tasks.rgb, tasks.normal, tasks.keypoints2d],
            "Nk3(y^)": [tasks.normal, tasks.keypoints3d],
            "Nk3(n(x))": [tasks.rgb, tasks.normal, tasks.keypoints3d],
            "nEO(y^)": [tasks.normal, tasks.edge_occlusion],
            "nEO(n(x))": [tasks.rgb, tasks.normal, tasks.edge_occlusion],
            "imagenet(y^)": [tasks.normal, tasks.imagenet],
            "imagenet(n(x))": [tasks.rgb, tasks.normal, tasks.imagenet],
        },
        "freeze_list": [
            [tasks.normal, tasks.principal_curvature],
            [tasks.normal, tasks.sobel_edges],
            [tasks.normal, tasks.reshading],
            [tasks.normal, tasks.depth_zbuffer],
            [tasks.normal, tasks.keypoints3d],
            [tasks.normal, tasks.keypoints2d],
            [tasks.normal, tasks.edge_occlusion],
            [tasks.normal, tasks.imagenet],
        ],
        "losses": {
            "mae": {
                ("train", "val"): [
                    ("n(x)", "y^"),
                ],
            },
            "percep_curv": {
                ("train", "val"): [
                    ("f(n(x))", "f(y^)"),
                ],
            },
            "direct_curv": {
                ("train", "val"): [
                    ("RC(x)", "curv"),
                ],
            },
            "percep_edge": {
                ("train", "val"): [
                    ("s(n(x))", "s(y^)"),
                ],
            },
            "direct_edge": {
                ("train", "val"): [
                    ("a(x)", "s(y^)"),
                ],
            },
            "percep_reshading": {
                ("train", "val"): [
                    ("g(n(x))", "g(y^)"),
                ],
            },
            "direct_reshading": {
                ("train", "val"): [
                    ("d(x)", "reshading"),
                ],
            },
            "percep_depth_zbuffer": {
                ("train", "val"): [
                    ("nr(n(x))", "nr(y^)"),
                ],
            },
            "direct_depth_zbuffer": {
                ("train", "val"): [
                    ("r(x)", "depth"),
                ],
            },
            "percep_keypoints2d": {
                ("train", "val"): [
                    ("Nk2(n(x))", "Nk2(y^)"),
                ],
            },
            "direct_keypoints2d": {
                ("train", "val"): [
                    ("k2(x)", "keypoints2d"),
                ],
            },
            "percep_keypoints3d": {
                ("train", "val"): [
                    ("Nk3(n(x))", "Nk3(y^)"),
                ],
            },
            "direct_keypoints3d": {
                ("train", "val"): [
                    ("k3(x)", "keypoints3d"),
                ],
            },
            "percep_edge_occlusion": {
                ("train", "val"): [
                    ("nEO(n(x))", "nEO(y^)"),
                ],
            },
            "direct_edge_occlusion": {
                ("train", "val"): [
                    ("EO(x)", "edge_occlusion"),
                ],
            },
            "percep_imagenet_percep": {
                ("train", "val"): [
                    ("imagenet(n(x))", "imagenet(y^)"),
                ],
            },
            "direct_imagenet_percep": {
                ("train", "val"): [
                    ("RC(x)", "curv"),
                ],
            },
        },
        "plots": {
            "": dict(
                size=256,
                realities=("test", "ood"),
                paths=[
                    "x",
                    "y^",
                    "n(x)",
                    "f(y^)",
                    "f(n(x))",
                    "s(y^)",
                    "s(n(x))",
                    "g(y^)",
                    "g(n(x))",
                    "nr(n(x))",
                    "nr(y^)",
                    "Nk3(y^)",
                    "Nk3(n(x))",
                    "Nk2(y^)",
                    "Nk2(n(x))",
                    "nEO(y^)",
                    "nEO(n(x))",
                ]
            ),
        },
    },

    "multiperceptual_reshading": {
        "paths": {
            "x": [tasks.rgb],
            "y^": [tasks.reshading],
            "n(x)": [tasks.rgb, tasks.reshading],
            "RC(x)": [tasks.rgb, tasks.principal_curvature],
            "a(x)": [tasks.rgb, tasks.sobel_edges],
            "d(x)": [tasks.rgb, tasks.normal],
            "r(x)": [tasks.rgb, tasks.depth_zbuffer],
            "EO(x)": [tasks.rgb, tasks.edge_occlusion],
            "k2(x)": [tasks.rgb, tasks.keypoints2d],
            "k3(x)": [tasks.rgb, tasks.keypoints3d],
            "curv": [tasks.principal_curvature],
            "edge": [tasks.sobel_edges],
            "depth": [tasks.depth_zbuffer],
            "normal": [tasks.normal],
            "keypoints2d": [tasks.keypoints2d],
            "keypoints3d": [tasks.keypoints3d],
            "edge_occlusion": [tasks.edge_occlusion],
            "f(y^)": [tasks.reshading, tasks.principal_curvature],
            "f(n(x))": [tasks.rgb, tasks.reshading, tasks.principal_curvature],
            "s(y^)": [tasks.reshading, tasks.sobel_edges],
            "s(n(x))": [tasks.rgb, tasks.reshading, tasks.sobel_edges],
            "g(y^)": [tasks.reshading, tasks.normal],
            "g(n(x))": [tasks.rgb, tasks.reshading, tasks.normal],
            "nr(y^)": [tasks.reshading, tasks.depth_zbuffer],
            "nr(n(x))": [tasks.rgb, tasks.reshading, tasks.depth_zbuffer],
            "Nk2(y^)": [tasks.reshading, tasks.keypoints2d],
            "Nk2(n(x))": [tasks.rgb, tasks.reshading, tasks.keypoints2d],
            "Nk3(y^)": [tasks.reshading, tasks.keypoints3d],
            "Nk3(n(x))": [tasks.rgb, tasks.reshading, tasks.keypoints3d],
            "nEO(y^)": [tasks.reshading, tasks.edge_occlusion],
            "nEO(n(x))": [tasks.rgb, tasks.reshading, tasks.edge_occlusion],
            "imagenet(y^)": [tasks.reshading, tasks.imagenet],
            "imagenet(n(x))": [tasks.rgb, tasks.reshading, tasks.imagenet],
        },
        "freeze_list": [
            [tasks.reshading, tasks.principal_curvature],
            [tasks.reshading, tasks.sobel_edges],
            [tasks.reshading, tasks.normal],
            [tasks.reshading, tasks.depth_zbuffer],
            [tasks.reshading, tasks.keypoints3d],
            [tasks.reshading, tasks.keypoints2d],
            [tasks.reshading, tasks.edge_occlusion],
            [tasks.reshading, tasks.imagenet],
        ],
        "losses": {
            "mae": {
                ("train", "val"): [
                    ("n(x)", "y^"),
                ],
            },
            "percep_curv": {
                ("train", "val"): [
                    ("f(n(x))", "f(y^)"),
                ],
            },
            "direct_curv": {
                ("train", "val"): [
                    ("RC(x)", "curv"),
                ],
            },
            "percep_edge": {
                ("train", "val"): [
                    ("s(n(x))", "s(y^)"),
                ],
            },
            "direct_edge": {
                ("train", "val"): [
                    ("a(x)", "s(y^)"),
                ],
            },
            "percep_normal": {
                ("train", "val"): [
                    ("g(n(x))", "g(y^)"),
                ],
            },
            "direct_normal": {
                ("train", "val"): [
                    ("d(x)", "normal"),
                ],
            },
            "percep_depth_zbuffer": {
                ("train", "val"): [
                    ("nr(n(x))", "nr(y^)"),
                ],
            },
            "direct_depth_zbuffer": {
                ("train", "val"): [
                    ("r(x)", "depth"),
                ],
            },
            "percep_keypoints2d": {
                ("train", "val"): [
                    ("Nk2(n(x))", "Nk2(y^)"),
                ],
            },
            "direct_keypoints2d": {
                ("train", "val"): [
                    ("k2(x)", "keypoints2d"),
                ],
            },
            "percep_keypoints3d": {
                ("train", "val"): [
                    ("Nk3(n(x))", "Nk3(y^)"),
                ],
            },
            "direct_keypoints3d": {
                ("train", "val"): [
                    ("k3(x)", "keypoints3d"),
                ],
            },
            "percep_edge_occlusion": {
                ("train", "val"): [
                    ("nEO(n(x))", "nEO(y^)"),
                ],
            },
            "direct_edge_occlusion": {
                ("train", "val"): [
                    ("EO(x)", "edge_occlusion"),
                ],
            },
            "percep_imagenet_percep": {
                ("train", "val"): [
                    ("imagenet(n(x))", "imagenet(y^)"),
                ],
            },
            "direct_imagenet_percep": {
                ("train", "val"): [
                    ("RC(x)", "curv"),
                ],
            },
        },
        "plots": {
            "": dict(
                size=256,
                realities=("test", "ood"),
                paths=[
                    "x",
                    "y^",
                    "n(x)",
                    "f(y^)",
                    "f(n(x))",
                    "s(y^)",
                    "s(n(x))",
                    "g(y^)",
                    "g(n(x))",
                    "nr(n(x))",
                    "nr(y^)",
                    "Nk3(y^)",
                    "Nk3(n(x))",
                    "Nk2(y^)",
                    "Nk2(n(x))",
                    "nEO(y^)",
                    "nEO(n(x))",
                    "depth",
                ]
            ),
        },
    },

    "multiperceptual_depth": {
        "paths": {
            "x": [tasks.rgb],
            "y^": [tasks.depth_zbuffer],
            "n(x)": [tasks.rgb, tasks.depth_zbuffer],
            "RC(x)": [tasks.rgb, tasks.principal_curvature],
            "a(x)": [tasks.rgb, tasks.sobel_edges],
            "d(x)": [tasks.rgb, tasks.normal],
            "r(x)": [tasks.rgb, tasks.reshading],
            "EO(x)": [tasks.rgb, tasks.edge_occlusion],
            "k2(x)": [tasks.rgb, tasks.keypoints2d],
            "k3(x)": [tasks.rgb, tasks.keypoints3d],
            "curv": [tasks.principal_curvature],
            "edge": [tasks.sobel_edges],
            "normal": [tasks.normal],
            "reshading": [tasks.reshading],
            "keypoints2d": [tasks.keypoints2d],
            "keypoints3d": [tasks.keypoints3d],
            "edge_occlusion": [tasks.edge_occlusion],
            "f(y^)": [tasks.depth_zbuffer, tasks.principal_curvature],
            "f(n(x))": [tasks.rgb, tasks.depth_zbuffer, tasks.principal_curvature],
            "s(y^)": [tasks.depth_zbuffer, tasks.sobel_edges],
            "s(n(x))": [tasks.rgb, tasks.depth_zbuffer, tasks.sobel_edges],
            "g(y^)": [tasks.depth_zbuffer, tasks.normal],
            "g(n(x))": [tasks.rgb, tasks.depth_zbuffer, tasks.normal],
            "nr(y^)": [tasks.depth_zbuffer, tasks.reshading],
            "nr(n(x))": [tasks.rgb, tasks.depth_zbuffer, tasks.reshading],
            "Nk2(y^)": [tasks.depth_zbuffer, tasks.keypoints2d],
            "Nk2(n(x))": [tasks.rgb, tasks.depth_zbuffer, tasks.keypoints2d],
            "Nk3(y^)": [tasks.depth_zbuffer, tasks.keypoints3d],
            "Nk3(n(x))": [tasks.rgb, tasks.depth_zbuffer, tasks.keypoints3d],
            "nEO(y^)": [tasks.depth_zbuffer, tasks.edge_occlusion],
            "nEO(n(x))": [tasks.rgb, tasks.depth_zbuffer, tasks.edge_occlusion],
            "imagenet(y^)": [tasks.depth_zbuffer, tasks.imagenet],
            "imagenet(n(x))": [tasks.rgb, tasks.depth_zbuffer, tasks.imagenet],
        },
        "freeze_list": [
            [tasks.depth_zbuffer, tasks.principal_curvature],
            [tasks.depth_zbuffer, tasks.sobel_edges],
            [tasks.depth_zbuffer, tasks.normal],
            [tasks.depth_zbuffer, tasks.reshading],
            [tasks.depth_zbuffer, tasks.keypoints3d],
            [tasks.depth_zbuffer, tasks.keypoints2d],
            [tasks.depth_zbuffer, tasks.edge_occlusion],
            [tasks.depth_zbuffer, tasks.imagenet],
        ],
        "losses": {
            "mae": {
                ("train", "val"): [
                    ("n(x)", "y^"),
                ],
            },
            "percep_curv": {
                ("train", "val"): [
                    ("f(n(x))", "f(y^)"),
                ],
            },
            "direct_curv": {
                ("train", "val"): [
                    ("RC(x)", "curv"),
                ],
            },
            "percep_edge": {
                ("train", "val"): [
                    ("s(n(x))", "s(y^)"),
                ],
            },
            "direct_edge": {
                ("train", "val"): [
                    ("a(x)", "s(y^)"),
                ],
            },
            "percep_normal": {
                ("train", "val"): [
                    ("g(n(x))", "g(y^)"),
                ],
            },
            "direct_normal": {
                ("train", "val"): [
                    ("d(x)", "normal"),
                ],
            },
            "percep_reshading": {
                ("train", "val"): [
                    ("nr(n(x))", "nr(y^)"),
                ],
            },
            "direct_reshading": {
                ("train", "val"): [
                    ("r(x)", "reshading"),
                ],
            },
            "percep_keypoints2d": {
                ("train", "val"): [
                    ("Nk2(n(x))", "Nk2(y^)"),
                ],
            },
            "direct_keypoints2d": {
                ("train", "val"): [
                    ("k2(x)", "keypoints2d"),
                ],
            },
            "percep_keypoints3d": {
                ("train", "val"): [
                    ("Nk3(n(x))", "Nk3(y^)"),
                ],
            },
            "direct_keypoints3d": {
                ("train", "val"): [
                    ("k3(x)", "keypoints3d"),
                ],
            },
            "percep_edge_occlusion": {
                ("train", "val"): [
                    ("nEO(n(x))", "nEO(y^)"),
                ],
            },
            "direct_edge_occlusion": {
                ("train", "val"): [
                    ("EO(x)", "edge_occlusion"),
                ],
            },
            "percep_imagenet_percep": {
                ("train", "val"): [
                    ("imagenet(n(x))", "imagenet(y^)"),
                ],
            },
            "direct_imagenet_percep": {
                ("train", "val"): [
                    ("RC(x)", "curv"),
                ],
            },
        },
        "plots": {
            "": dict(
                size=256,
                realities=("test", "ood"),
                paths=[
                    "x",
                    "y^",
                    "n(x)",
                    "f(y^)",
                    "f(n(x))",
                    "s(y^)",
                    "s(n(x))",
                    "g(y^)",
                    "g(n(x))",
                    "nr(n(x))",
                    "nr(y^)",
                    "Nk3(y^)",
                    "Nk3(n(x))",
                    "Nk2(y^)",
                    "Nk2(n(x))",
                    "nEO(y^)",
                    "nEO(n(x))",
                ]
            ),
        },
    },

    "energy_calc": generate_config(ALL_PERCEPTUAL_TASKS),
    "energy_calc_nogt": generate_config(ALL_PERCEPTUAL_TASKS, has_gt=False),
}



def coeff_hook(coeff):
    def fun1(grad):
        return coeff*grad.clone()
    return fun1


class EnergyLoss(object):

    def __init__(self, paths, losses, plots,
        pretrained=True, finetuned=False, freeze_list=[]
    ):

        self.paths, self.losses, self.plots = paths, losses, plots
        self.freeze_list = [str((path[0].name, path[1].name)) for path in freeze_list]
        self.metrics = {}

        self.tasks = []
        for _, loss_item in self.losses.items():
            for realities, losses in loss_item.items():
                for path1, path2 in losses:
                    self.tasks += self.paths[path1] + self.paths[path2]

        for name, config in self.plots.items():
            for path in config["paths"]:
                self.tasks += self.paths[path]
        self.tasks = list(set(self.tasks))

    def compute_paths(self, graph, reality=None, paths=None):
        path_cache = {}
        paths = paths or self.paths
        path_values = {
            name: graph.sample_path(path,
                reality=reality, use_cache=True, cache=path_cache,
            ) for name, path in paths.items()
        }
        del path_cache
        return {k: v for k, v in path_values.items() if v is not None}

    def get_tasks(self, reality):
        tasks = []
        for _, loss_item in self.losses.items():
            for realities, losses in loss_item.items():
                if reality in realities:
                    for path1, path2 in losses:
                        tasks += [self.paths[path1][0], self.paths[path2][0]]

        for name, config in self.plots.items():
            if reality in config["realities"]:
                for path in config["paths"]:
                    tasks += [self.paths[path][0]]

        return list(set(tasks))

    def __call__(self, graph, discriminator=None, realities=[], loss_types=None, reduce=True, use_l1=False):
        #pdb.set_trace()
        loss = {}
        for reality in realities:
            loss_dict = {}
            losses = []
            all_loss_types = set()
            for loss_type, loss_item in self.losses.items():
                all_loss_types.add(loss_type)
                loss_dict[loss_type] = []
                for realities_l, data in loss_item.items():
                    if reality.name in realities_l:
                        loss_dict[loss_type] += data
                        if loss_types is not None and loss_type in loss_types:
                            losses += data

            path_values = self.compute_paths(graph,
                paths={
                    path: self.paths[path] for path in \
                    set(path for paths in losses for path in paths)
                    },
                reality=reality)

            if reality.name not in self.metrics:
                self.metrics[reality.name] = defaultdict(list)

            for loss_type, losses in sorted(loss_dict.items()):
                if loss_type not in (loss_types or all_loss_types):
                    continue
                if loss_type not in loss:
                    loss[loss_type] = 0
                for path1, path2 in losses:
                    output_task = self.paths[path1][-1]
                    compute_mask = 'imagenet(n(x))' != path1
                    if loss_type not in loss:
                        loss[loss_type] = 0
                    for path1, path2 in losses:
                        output_task = self.paths[path1][-1]
                        if "direct" in loss_type:
                            with torch.no_grad():
                                path_loss, _ = output_task.norm(path_values[path1], path_values[path2], batch_mean=reduce, compute_mask=compute_mask, compute_mse=False)
                                loss[loss_type] += path_loss
                        else:
                            path_loss, _ = output_task.norm(path_values[path1], path_values[path2], batch_mean=reduce, compute_mask=compute_mask, compute_mse=False)
                            loss[loss_type] += path_loss
                            loss_name = "mae" if "mae" in loss_type else loss_type+"_mae"
                            self.metrics[reality.name][loss_name +" : "+path1 + " -> " + path2] += [path_loss.mean().detach().cpu()]
                            path_loss, _ = output_task.norm(path_values[path1], path_values[path2], batch_mean=reduce, compute_mask=compute_mask, compute_mse=True)
                            loss_name = "mse" if "mae" in loss_type else loss_type + "_mse"
                            self.metrics[reality.name][loss_name +" : "+path1 + " -> " + path2] += [path_loss.mean().detach().cpu()]

        return loss

    def logger_hooks(self, logger):

        name_to_realities = defaultdict(list)
        for loss_type, loss_item in self.losses.items():
            for realities, losses in loss_item.items():
                for path1, path2 in losses:
                    loss_name = "mae" if "mae" in loss_type else loss_type+"_mae"
                    name = loss_name+" : "+path1 + " -> " + path2
                    name_to_realities[name] += list(realities)
                    loss_name = "mse" if "mae" in loss_type else loss_type + "_mse"
                    name = loss_name+" : "+path1 + " -> " + path2
                    name_to_realities[name] += list(realities)

        for name, realities in name_to_realities.items():
            def jointplot(logger, data, name=name, realities=realities):
                names = [f"{reality}_{name}" for reality in realities]
                if not all(x in data for x in names):
                    return
                data = np.stack([data[x] for x in names], axis=1)
                logger.plot(data, name, opts={"legend": names})

            logger.add_hook(partial(jointplot, name=name, realities=realities), feature=f"{realities[-1]}_{name}", freq=1)


    def logger_update(self, logger):

        name_to_realities = defaultdict(list)
        for loss_type, loss_item in self.losses.items():
            for realities, losses in loss_item.items():
                for path1, path2 in losses:
                    loss_name = "mae" if "mae" in loss_type else loss_type+"_mae"
                    name = loss_name+" : "+path1 + " -> " + path2
                    name_to_realities[name] += list(realities)
                    loss_name = "mse" if "mae" in loss_type else loss_type + "_mse"
                    name = loss_name+" : "+path1 + " -> " + path2
                    name_to_realities[name] += list(realities)

        for name, realities in name_to_realities.items():
            for reality in realities:
                # IPython.embed()
                if reality not in self.metrics: continue
                if name not in self.metrics[reality]: continue
                if len(self.metrics[reality][name]) == 0: continue

                logger.update(
                    f"{reality}_{name}",
                    torch.mean(torch.stack(self.metrics[reality][name])),
                )
        self.metrics = {}

    def plot_paths(self, graph, logger, realities=[], plot_names=None, epochs=0, tr_step=0,prefix=""):
        error_pairs = {"n(x)": "y^"}
        realities_map = {reality.name: reality for reality in realities}
        for name, config in (plot_names or self.plots.items()):
            paths = config["paths"]

            realities = config["realities"]
            images = []
            error = False
            cmap = get_cmap("jet")

            first = True
            error_passed_ood = 0
            for reality in realities:
                with torch.no_grad():
                    path_values = self.compute_paths(graph, paths={path: self.paths[path] for path in paths}, reality=realities_map[reality])

                shape = list(path_values[list(path_values.keys())[0]].shape)
                shape[1] = 3

                for i, path in enumerate(paths):
                    if path == 'depth': continue
                    X = path_values.get(path, torch.zeros(shape, device=DEVICE))
                    if first: images +=[[]]

                    if reality is 'ood' and error_passed_ood==0:
                        images[i].append(X.clamp(min=0, max=1).expand(*shape))
                    elif reality is 'ood' and error_passed_ood==1:
                        images[i+1].append(X.clamp(min=0, max=1).expand(*shape))
                    else:
                        images[-1].append(X.clamp(min=0, max=1).expand(*shape))

                    if path in error_pairs:

                        error = True
                        if first:
                            images += [[]]


                    if error:

                        Y = path_values.get(path, torch.zeros(shape, device=DEVICE))
                        Y_hat = path_values.get(error_pairs[path], torch.zeros(shape, device=DEVICE))

                        out_task = self.paths[path][-1]

                        if self.target_task == "reshading": #Use depth mask
                            Y_mask = path_values.get("depth", torch.zeros(shape, device = DEVICE))
                            mask_task = self.paths["r(x)"][-1]
                            mask = ImageTask.build_mask(Y_mask, val=mask_task.mask_val)
                        else:
                            mask = ImageTask.build_mask(Y_hat, val=out_task.mask_val)

                        errors = ((Y - Y_hat)**2).mean(dim=1, keepdim=True)
                        log_errors = torch.log(errors.clamp(min=0, max=out_task.variance))


                        errors = (3*errors/(out_task.variance)).clamp(min=0, max=1)

                        log_errors = torch.log(errors + 1)
                        log_errors = log_errors / log_errors.max()
                        log_errors = torch.tensor(cmap(log_errors.cpu()))[:, 0].permute((0, 3, 1, 2)).float()[:, 0:3]
                        log_errors = log_errors.clamp(min=0, max=1).expand(*shape).to(DEVICE)
                        log_errors[~mask.expand_as(log_errors)] = 0.505
                        if reality is 'ood':
                            images[i+1].append(log_errors)
                            error_passed_ood = 1
                        else:
                            images[-1].append(log_errors)

                        error = False
                first = False

            for i in range(0, len(images)):
                images[i] = torch.cat(images[i], dim=0)

            logger.images_grouped(images,
                f"{prefix}_{name}_[{', '.join(realities)}]_[{', '.join(paths)}]",
                resize=config["size"]
            )

    def __repr__(self):
        return str(self.losses)


class WinRateEnergyLoss(EnergyLoss):

    def __init__(self, *args, **kwargs):
        self.k = kwargs.pop('k', 3)
        self.random_select = kwargs.pop('random_select', False)
        self.running_stats = {}
        self.target_task = kwargs['paths']['y^'][0].name

        super().__init__(*args, **kwargs)

        self.percep_losses = [key[7:] for key in self.losses.keys() if key[0:7] == "percep_"]
        print ("percep losses:",self.percep_losses)
        self.chosen_losses = random.sample(self.percep_losses, self.k)

    def __call__(self, graph, discriminator=None, realities=[], loss_types=None, compute_grad_ratio=False):

        loss_types = ["mae"] + [("percep_" + loss) for loss in self.percep_losses] + [("direct_" + loss) for loss in self.percep_losses]
        # print (self.chosen_losses)
        loss_dict = super().__call__(graph, discriminator=discriminator, realities=realities, loss_types=loss_types, reduce=False)

        chosen_percep_mse_losses = [k for k in loss_dict.keys() if 'direct' not in k]
        percep_mse_coeffs = dict.fromkeys(chosen_percep_mse_losses, 1.0)
        ########### to compute loss coefficients #############
        if compute_grad_ratio:
            percep_mse_gradnorms = dict.fromkeys(chosen_percep_mse_losses, 1.0)
            for loss_name in chosen_percep_mse_losses:
                loss_dict[loss_name].mean().backward(retain_graph=True)
                target_weights=list(graph.edge_map[f"('rgb', '{self.target_task}')"].model.parameters())
                percep_mse_gradnorms[loss_name] = sum([l.grad.abs().sum().item() for l in target_weights])/sum([l.numel() for l in target_weights])
                graph.optimizer.zero_grad()
                graph.zero_grad()
                del target_weights
            total_gradnorms = sum(percep_mse_gradnorms.values())
            n_losses = len(chosen_percep_mse_losses)
            for loss_name, val in percep_mse_coeffs.items():
                percep_mse_coeffs[loss_name] = (total_gradnorms-percep_mse_gradnorms[loss_name])/((n_losses-1)*total_gradnorms)
            percep_mse_coeffs["mae"] *= (n_losses-1)
        ###########################################

        for key in self.chosen_losses:
            winrate = torch.mean((loss_dict[f"percep_{key}"] > loss_dict[f"direct_{key}"]).float())
            winrate = winrate.detach().cpu().item()
            if winrate < 1.0:
                self.running_stats[key] = winrate
            loss_dict[f"percep_{key}"] = loss_dict[f"percep_{key}"].mean() * percep_mse_coeffs[f"percep_{key}"]
            loss_dict.pop(f"direct_{key}")

        # print (self.running_stats)
        loss_dict["mae"] = loss_dict["mae"].mean() * percep_mse_coeffs["mae"]

        return loss_dict, percep_mse_coeffs["mae"]

    def logger_update(self, logger):
        super().logger_update(logger)
        if self.random_select or len(self.running_stats) < len(self.percep_losses):
            self.chosen_losses = random.sample(self.percep_losses, self.k)
        else:
            self.chosen_losses = sorted(self.running_stats, key=self.running_stats.get, reverse=True)[:self.k]

        logger.text (f"Chosen losses: {self.chosen_losses}")




================================================
FILE: graph.py
================================================
import os, sys, math, random, itertools, heapq
from collections import namedtuple, defaultdict
from functools import partial, reduce
import numpy as np
import IPython

import torch
import torch.nn as nn
import torch.nn.functional as F

from utils import *
from models import TrainableModel, WrapperModel
from datasets import TaskDataset
from task_configs import get_task, task_map, tasks, get_model, RealityTask
from transfers import Transfer, RealityTransfer, get_transfer_name

#from modules.gan_dis import GanDisNet

import pdb

class TaskGraph(TrainableModel):
    """Basic graph that encapsulates set of edge constraints. Can be saved and loaded
    from directories."""

    def __init__(
        self, tasks=tasks, edges=None, edges_exclude=None,
        pretrained=True, finetuned=False,
        reality=[], task_filter=[tasks.segment_semantic],
        freeze_list=[], lazy=False, initialize_from_transfer=True,
    ):

        super().__init__()
        self.tasks = list(set(tasks) - set(task_filter))
        self.tasks += [task.base for task in self.tasks if hasattr(task, "base")]
        self.edge_list, self.edge_list_exclude = edges, edges_exclude
        self.pretrained, self.finetuned = pretrained, finetuned
        self.edges, self.adj, self.in_adj = [], defaultdict(list), defaultdict(list)
        self.edge_map, self.reality = {}, reality
        self.initialize_from_transfer = initialize_from_transfer
        print('Creating graph with tasks:', self.tasks)
        self.params = {}

        # construct transfer graph
        for src_task, dest_task in itertools.product(self.tasks, self.tasks):
            key = (src_task, dest_task)
            if edges is not None and key not in edges: continue
            if edges_exclude is not None and key in edges_exclude: continue
            if src_task == dest_task: continue
            if isinstance(dest_task, RealityTask): continue
            # print (src_task, dest_task)
            transfer = None
            if isinstance(src_task, RealityTask):
                if dest_task not in src_task.tasks: continue
                transfer = RealityTransfer(src_task, dest_task)
            else:
                transfer = Transfer(src_task, dest_task,
                    pretrained=pretrained, finetuned=finetuned
                )
                transfer.name = get_transfer_name(transfer)
                if not self.initialize_from_transfer:
                    transfer.path = None
            if transfer.model_type is None:
                continue
            # print ("Added transfer", transfer)
            self.edges += [transfer]
            self.adj[src_task.name] += [transfer]
            self.in_adj[dest_task.name] += [transfer]
            self.edge_map[str((src_task.name, dest_task.name))] = transfer
            if isinstance(transfer, nn.Module):
                if str((src_task.name, dest_task.name)) not in freeze_list:
                    self.params[str((src_task.name, dest_task.name))] = transfer
                else:
                    print("Setting link: " + str((src_task.name, dest_task.name)) + " not trainable.")
                try:
                    if not lazy: transfer.load_model()
                except Exception as e:
                    print(e)
                    IPython.embed()

        self.params = nn.ModuleDict(self.params)

    def edge(self, src_task, dest_task):
        key1 = str((src_task.name, dest_task.name))
        key2 = str((src_task.kind, dest_task.kind))
        if key1 in self.edge_map: return self.edge_map[key1]
        return self.edge_map[key2]

    def sample_path(self, path, reality=None, use_cache=False, cache={}):
        path = [reality or self.reality[0]] + path
        x = None
        for i in range(1, len(path)):
            try:
                # if x is not None: print (x.shape)
                # print (self.edge(path[i-1], path[i]))
                x = cache.get(tuple(path[0:(i+1)]),
                    self.edge(path[i-1], path[i])(x)
                )
            except KeyError:
                return None
            except Exception as e:
                print(e)
                IPython.embed()

            if use_cache: cache[tuple(path[0:(i+1)])] = x
        return x

    def save(self, weights_file=None, weights_dir=None):

        ### TODO: save optimizers here too
        if weights_file:
            torch.save({
                key: model.state_dict() for key, model in self.edge_map.items() \
                if not isinstance(model, RealityTransfer)
            }, weights_file)

        if weights_dir:
            os.makedirs(weights_dir, exist_ok=True)
            for key, model in self.edge_map.items():
                if isinstance(model, RealityTransfer): continue
                if not isinstance(model.model, TrainableModel): continue
                model.model.save(f"{weights_dir}/{model.name}.pth")
            torch.save(self.optimizer, f"{weights_dir}/optimizer.pth")


#    def load_weights(self, weights_file=None):
#        for key, state_dict in torch.load(weights_file).items():
#            if key in self.edge_map:
#                self.edge_map[key].load_state_dict(state_dict)

    def load_weights(self, weights_file=None):
        loaded_something = False
        for key, state_dict in torch.load(weights_file).items():
            if key in self.edge_map:
                loaded_something = True
                self.edge_map[key].load_model()
                self.edge_map[key].load_state_dict(state_dict)
        if not loaded_something:
            raise RuntimeError(f"No edges loaded from file: {weights_file}")


================================================
FILE: hooks/build
================================================
#!/bin/bash

docker build . -t $IMAGE_NAME --build-arg GITHUB_DEPLOY_KEY="$GITHUB_DEPLOY_KEY" --build-arg GITHUB_DEPLOY_KEY_PUBLIC="$GITHUB_DEPLOY_KEY_PUBLIC"



================================================
FILE: logger.py
================================================

import numpy as np
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
import random, sys, os, json, math

import torch
from torchvision import datasets, transforms, utils
import visdom

from utils import *
from utils import elapsed
import IPython
import pdb

class BaseLogger(object):
    """ Logger class, with hooks for data features and plotting functions. """
    def __init__(self, name, verbose=True):

        self.name = name
        self.data = {}
        self.running_data = {}
        self.reset_running = {}
        self.verbose = verbose
        self.hooks = []

    def add_hook(self, hook, feature='epoch', freq=40):
        self.hooks.append((hook, feature, freq))

    def update(self, feature, x):
        if isinstance(x, torch.Tensor):
            x = x.clone().detach().cpu().numpy().mean()
        else:
            x = torch.tensor(x).data.cpu().numpy().mean()

        self.data[feature] = self.data.get(feature, [])
        self.data[feature].append(x)
        if feature not in self.running_data or self.reset_running.pop(feature, False):
            self.running_data[feature] = []
        self.running_data[feature].append(x)

        for hook, hook_feature, freq in self.hooks:
            if feature == hook_feature and len(self.data[feature]) % freq == 0:
                hook(self, self.data)

    def step(self):
        buf = ""
        buf += f"({self.name}) "
        for feature in self.running_data.keys():
            if len(self.running_data[feature]) == 0: continue
            val = np.mean(self.running_data[feature])
            if float(val).is_integer():
                buf += f"{feature}: {int(val)}, "
            else:
                buf += f"{feature}: {val:0.4f}" + ", "
            self.reset_running[feature] = True
        buf += f" ... {elapsed():0.2f} sec"
        self.text (buf)

    def text(self, text, end="\n"):
        raise NotImplementedError()

    def plot(self, data, plot_name, opts={}):
        raise NotImplementedError()

    def images(self, data, image_name):
        raise NotImplementedError()

    def plot_feature(self, feature, opts={}):
        self.plot(self.data[feature], feature, opts)

    def plot_features(self, features, name, opts={}):
        stacked = np.stack([self.data[feature] for feature in features], axis=1)
        self.plot(stacked, name, opts={"legend": features})


class Logger(BaseLogger):

    def __init__(self, *args, **kwargs):
        self.results = kwargs.pop('results', 'output')
        super().__init__(*args, **kwargs)

    def text(self, text, end='\n'):
        print (text, end=end, flush=True)

    def plot(self, data, plot_name, opts={}):
        np.savez_compressed(f"{self.results}/{plot_name}.npz", data)
        plt.plot(data)
        plt.savefig(f"{self.results}/{plot_name}.jpg");
        plt.clf()


class VisdomLogger(BaseLogger):

    def __init__(self, *args, **kwargs):
        self.env = kwargs.pop('env', 'CH')
        self.port = kwargs.pop('port', 8097)
        self.server = kwargs.pop('server', '127.0.0.1')
        self.delete = kwargs.pop('delete', True)
        print ("No deletion")
        print ("In (git) scaling-reset")
        print (f"Logging to environment {self.env}")
        self.visdom = visdom.Visdom(server="http://" + self.server, port=self.port, env=self.env)
        self.visdom.delete_env(self.env)
        self.windows = {}
        super().__init__(*args, **kwargs)
        self.save()
        self.add_hook(lambda logger, data: self.save(), feature="epoch", freq=1)

    def text(self, text, end='\n'):
        print (text, end=end)
        window, old_text = self.windows.get('text', (None, ""))
        if end == '\n': end = '<br>'
        display = old_text + text + end

        if window is not None:
            window = self.visdom.text (display, win=window, append=False)
        else:
            window = self.visdom.text (display)

        self.windows["text"] = window, display

    def window(self, plot_name, plot_func, *args, **kwargs):

        options = {'title': plot_name}
        options.update(kwargs.pop("opts", {}))
        window = self.windows.get(plot_name, None)
        if window is not None and self.visdom.win_exists(window):
            window = plot_func(*args, **kwargs, opts=options, win=window)
        else:
            window = plot_func(*args, **kwargs, opts=options)

        self.windows[plot_name] = window

    def plot(self, data, plot_name, opts={}):
        self.window(plot_name, self.visdom.line,
            np.array(data), X=np.array(range(len(data))), opts=opts
        )

    def histogram(self, data, plot_name, opts={}):
        self.window(plot_name, self.visdom.histogram, np.array(data), opts=opts)

    def scatter(self, X, Y, plot_name, opts={}):
        self.window(plot_name, self.visdom.scatter, np.stack([X, Y], axis=1), opts=opts)

    def bar(self, data, plot_name, opts={}):
        self.window(plot_name, self.visdom.bar, np.array(data), opts=opts)

    def save(self):
        self.visdom.save([self.env])

    def images(self, data, plot_name, opts={}, nrow=2, normalize=False, resize=64):

        transform = transforms.Compose([
                                    transforms.ToPILImage(),
                                    transforms.Resize(resize),
                                    transforms.ToTensor()])
        data = torch.stack([transform(x.cpu()) for x in data])
        data = utils.make_grid(data, nrow=nrow, normalize=normalize, pad_value=0)
        self.window(plot_name, self.visdom.image, np.array(data), opts=opts)

    def images_grouped(self, image_groups, plot_name, **kwargs):
        interleave = [y for x in zip(*image_groups) for y in x]
        self.images(interleave, plot_name, nrow=len(image_groups), **kwargs)




================================================
FILE: models.py
================================================
import os, sys, random
from inspect import signature
import numpy as np
import matplotlib as mpl
import torch
import torch.nn as nn
import torch.nn.functional as F
from   torch import optim

from utils import *


""" Model that implements batchwise training with "compilation" and custom loss.
Exposed methods: predict_on_batch(), fit_on_batch(),
Overridable methods: loss(), forward().
"""


class AbstractModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.compiled = False

    # Compile module and assign optimizer + params
    def compile(self, optimizer=None, **kwargs):

        if optimizer is not None:
            self.optimizer_class = optimizer
            self.optimizer_kwargs = kwargs
            self.optimizer = self.optimizer_class(self.parameters(), **self.optimizer_kwargs)
        else:
            self.optimizer = None

        self.compiled = True
        self.to(DEVICE)

    # Predict scores from a batch of data
    def predict_on_batch(self, data):

        self.eval()
        with torch.no_grad():
            return self.forward(data)

    # Fit (make one optimizer step) on a batch of data
    def fit_on_batch(self, data, target, loss_fn=None, train=True):
        loss_fn = loss_fn or self.loss

        self.zero_grad()
        self.optimizer.zero_grad()

        self.train(train)

        self.zero_grad()
        self.optimizer.zero_grad()
        pred = self.forward(data)
        if isinstance(target, list):
            target = tuple(t.to(pred.device) for t in target)
        else: target = target.to(pred.device)

        if len(signature(loss_fn).parameters) > 2:
            loss, metrics = loss_fn(pred, target, data.to(pred.device))
        else:
            loss, metrics = loss_fn(pred, target)

        if train:
            loss.backward()
            self.optimizer.step()
            self.zero_grad()
            self.optimizer.zero_grad()

        return pred, loss, metrics

    # Make one optimizer step w.r.t a loss
    def step(self, loss, train=True):

        self.zero_grad()
        self.optimizer.zero_grad()
        self.train(train)
        self.zero_grad()
        self.optimizer.zero_grad()

        loss.backward()
        self.optimizer.step()
        self.zero_grad()
        self.optimizer.zero_grad()

    @classmethod
    def load(cls, weights_file=None):
        model = cls()
        if weights_file is not None:
            data = torch.load(weights_file)
            # hack for models saved with optimizers
            if "optimizer" in data: data = data["state_dict"]
            model.load_state_dict(data)
        return model

    def load_weights(self, weights_file, backward_compatible=False):
        data = torch.load(weights_file)
        if backward_compatible:
            data = {'parallel_apply.module.'+k:v for k,v in data.items()}
        self.load_state_dict(data)

    def save(self, weights_file):
        torch.save(self.state_dict(), weights_file)

    # Subclasses: override for custom loss + forward functions
    def loss(self, pred, target):
        raise NotImplementedError()

    def forward(self, x):
        raise NotImplementedError()


""" Model that implements training and prediction on generator objects, with
the ability to print train and validation metrics.
"""


class TrainableModel(AbstractModel):
    def __init__(self):
        super().__init__()

    # Fit on generator for one epoch
    def _process_data(self, datagen, loss_fn=None, train=True, logger=None):

        self.train(train)
        out = []
        for data in datagen:
            batch, y = data[0], data[1:]
            if len(y) == 1: y = y[0]
            y_pred, loss, metric_data = self.fit_on_batch(batch, y, loss_fn=loss_fn, train=train)
            if logger is not None:
                logger.update("loss", float(loss))
            yield ((batch.detach(), y_pred.detach(), y, float(loss), metric_data))

    def fit(self, datagen, loss_fn=None, logger=None):
        for x in self._process_data(datagen, loss_fn=loss_fn, train=train, logger=logger):
            pass

    def fit_with_data(self, datagen, loss_fn=None, logger=None):
        images, preds, targets, losses, metrics = zip(
            *self._process_data(datagen, loss_fn=loss_fn, train=True, logger=logger)
        )
        images, preds, targets = torch.cat(images, dim=0), torch.cat(preds, dim=0), torch.cat(targets, dim=0)
        metrics = zip(*metrics)
        return images, preds, targets, losses, metrics

    def fit_with_metrics(self, datagen, loss_fn=None, logger=None):
        metrics = [
            metrics
            for _, _, _, _, metrics in self._process_data(
                datagen, loss_fn=loss_fn, train=True, logger=logger
            )
        ]
        return list(zip(*metrics))

    def predict_with_data(self, datagen, loss_fn=None, logger=None):
        images, preds, targets, losses, metrics = zip(
            *self._process_data(datagen, loss_fn=loss_fn, train=False, logger=logger)
        )
        images, preds, targets = torch.cat(images, dim=0), torch.cat(preds, dim=0), torch.cat(targets, dim=0)
        images, preds, targets = images.cpu(), preds.cpu(), targets.cpu()
        # preds = torch.cat(preds, dim=0)
        metrics = zip(*metrics)
        return images, preds, targets, losses, metrics

    def predict_with_metrics(self, datagen, loss_fn=None, logger=None):
        metrics = [
            metrics
            for _, _, _, _, metrics in self._process_data(
                datagen, loss_fn=loss_fn, train=False, logger=logger
            )
        ]
        return list(zip(*metrics))

    def predict(self, datagen):
        preds = [self.predict_on_batch(x) for x in datagen]
        preds = torch.cat(preds, dim=0)
        return preds


class DataParallelModel(TrainableModel):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.parallel_apply = nn.DataParallel(*args, **kwargs)

    def forward(self, x):
        return self.parallel_apply(x)

    def loss(self, x, preds):
        return self.parallel_apply.module.loss(x, preds)

    @property
    def module(self):
        return self.parallel_apply.module

    @classmethod
    def load(cls, model=TrainableModel(), weights_file=None):
        model = cls(model)
        if weights_file is not None:
            data = torch.load(weights_file, map_location=lambda storage, loc: storage)
            # hack for models saved with optimizers
            if "optimizer" in data: data = data["state_dict"]
            try:
                model.load_state_dict(data)
            except RuntimeError:
                parallel_module_data = {
                    'parallel_apply.module.' + k: v for k, v in data.items()
                }
                model.load_state_dict(parallel_module_data)
        return model

class WrapperModel(TrainableModel):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        return self.model(x)

    def loss(self, x, preds):
        raise NotImplementedError()

    def __getitem__(self, i):
        return self.model[i]

    @property
    def module(self):
        return self.model


if __name__ == "__main__":
    import IPython
    IPython.embed()


================================================
FILE: modules/__init__.py
================================================


================================================
FILE: modules/depth_nets.py
================================================

import os, sys, math, random, itertools
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets, transforms, models
from torch.optim.lr_scheduler import MultiStepLR

from models import TrainableModel
from utils import *


class ResidualsNet(TrainableModel):
    def __init__(self):
        super().__init__()

        self.encoder = nn.Sequential(
            ConvBlock(3, 32, groups=3, use_groupnorm=False), 
            ConvBlock(32, 32, use_groupnorm=False),
        )
        self.mid = nn.Sequential(
            ConvBlock(32, 64, dilation=1, use_groupnorm=False), 
            ConvBlock(64, 64, dilation=2, use_groupnorm=False),
            ConvBlock(64, 64, dilation=2, use_groupnorm=False),
            ConvBlock(64, 32, dilation=4, use_groupnorm=False),
        )
        self.decoder = nn.Sequential(
            ConvBlock(64, 32, use_groupnorm=False),
            ConvBlock(32, 32, use_groupnorm=False),
            ConvBlock(32, 1, use_groupnorm=False),
        )

    def forward(self, x):
        tmp = self.encoder(x)
        x = F.max_pool2d(tmp, 4)
        x = self.mid(x)
        x = F.upsample(x, scale_factor=4, mode='bilinear')
        x = torch.cat([x, tmp], dim=1)
        x = self.decoder(x)
        return x

    def loss(self, pred, target):
        loss = torch.tensor(0.0, device=pred.device)
        return loss, (loss.detach(),)


class UNet_up_block(nn.Module):
    def __init__(self, prev_channel, input_channel, output_channel, up_sample=True):
        super().__init__()
        self.up_sampling = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.conv1 = nn.Conv2d(prev_channel + input_channel, output_channel, 3, padding=1)
        self.bn1 = nn.GroupNorm(8, output_channel)
        self.conv2 = nn.Conv2d(output_channel, output_channel, 3, padding=1)
        self.bn2 = nn.GroupNorm(8, output_channel)
        self.conv3 = nn.Conv2d(output_channel, output_channel, 3, padding=1)
        self.bn3 = nn.GroupNorm(8, output_channel)        
        self.relu = torch.nn.ReLU()
        self.up_sample = up_sample

    def forward(self, prev_feature_map, x):
        if self.up_sample:
            x = self.up_sampling(x)
        x = torch.cat((x, prev_feature_map), dim=1)
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))
        return x


class UNet_down_block(nn.Module):
    def __init__(self, input_channel, output_channel, down_size=True):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channel, output_channel, 3, padding=1)
        self.bn1 = nn.GroupNorm(8, output_channel)
        self.conv2 = nn.Conv2d(output_channel, output_channel, 3, padding=1)
        self.bn2 = nn.GroupNorm(8, output_channel)
        self.conv3 = nn.Conv2d(output_channel, output_channel, 3, padding=1)
        self.bn3 = nn.GroupNorm(8, output_channel)
        self.max_pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()
        self.down_size = down_size

    def forward(self, x):

        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))
        if self.down_size:
            x = self.max_pool(x)
        return x


class UNetDepth(TrainableModel):
    def __init__(self):
        super().__init__()

        self.down_block1 = UNet_down_block(3, 16, False)
        self.down_block2 = UNet_down_block(16, 32, True)
        self.down_block3 = UNet_down_block(32, 64, True)
        self.down_block4 = UNet_down_block(64, 128, True)
        self.down_block5 = UNet_down_block(128, 256, True)
        self.down_block6 = UNet_down_block(256, 512, True)
        self.down_block7 = UNet_down_block(512, 1024, False)

        self.mid_conv1 = nn.Conv2d(1024, 1024, 3, padding=1)
        self.bn1 = nn.GroupNorm(8, 1024)
        self.mid_conv2 = nn.Conv2d(1024, 1024, 3, padding=1)
        self.bn2 = nn.GroupNorm(8, 1024)
        self.mid_conv3 = torch.nn.Conv2d(1024, 1024, 3, padding=1)
        self.bn3 = torch.nn.GroupNorm(8, 1024)

        self.up_block1 = UNet_up_block(512, 1024, 512, False)
        self.up_block2 = UNet_up_block(256, 512, 256, True)
        self.up_block3 = UNet_up_block(128, 256, 128, True)
        self.up_block4 = UNet_up_block(64, 128, 64, True)
        self.up_block5 = UNet_up_block(32, 64, 32, True)
        self.up_block6 = UNet_up_block(16, 32, 16, True)

        self.last_conv1 = nn.Conv2d(16, 16, 3, padding=1)
        self.last_bn = nn.GroupNorm(8, 16)
        self.last_conv2 = nn.Conv2d(16, 1, 1, padding=0)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.x1 = self.down_block1(x)
        x = self.x2 = self.down_block2(self.x1)
        x = self.x3 = self.down_block3(self.x2)
        x = self.x4 = self.down_block4(self.x3)
        x = self.x5 = self.down_block5(self.x4)
        x = self.x6 = self.down_block6(self.x5)
        x = self.x7 = self.down_block7(self.x6)

        x = self.relu(self.bn1(self.mid_conv1(x)))
        x = self.relu(self.bn2(self.mid_conv2(x)))
        x = self.relu(self.bn3(self.mid_conv3(x)))

        x = self.up_block1(self.x6, x)
        x = self.up_block2(self.x5, x)
        x = self.up_block3(self.x4, x)
        x = self.up_block4(self.x3, x)
        x = self.up_block5(self.x2, x)
        x = self.up_block6(self.x1, x)
        x = self.relu(self.last_bn(self.last_conv1(x)))
        x = self.last_conv2(x)
        return x

    def loss(self, pred, target):
        loss = torch.tensor(0.0, device=pred.device)
        return loss, (loss.detach(),)



class ConvBlock(nn.Module):
    def __init__(self, f1, f2, use_groupnorm=True, groups=8, dilation=1, transpose=False):
        super().__init__()
        self.transpose = transpose
        self.conv = nn.Conv2d(f1, f2, (3, 3), dilation=dilation, padding=dilation)
        if self.transpose:
            self.convt = nn.ConvTranspose2d(
                f1, f1, (3, 3), dilation=dilation, stride=2, padding=dilation, output_padding=1
            )
        if use_groupnorm:
            self.bn = nn.GroupNorm(groups, f1)
        else:
            self.bn = nn.GroupNorm(8, f1)

    def forward(self, x):
        # x = F.dropout(x, 0.04, self.training)
        x = self.bn(x)
        if self.transpose:
            # x = F.upsample(x, scale_factor=2, mode='bilinear')
            x = F.relu(self.convt(x))
            # x = x[:, :, :-1, :-1]
        x = F.relu(self.conv(x))
        return x

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.GroupNorm(8, planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.GroupNorm(8, planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.GroupNorm(8, planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

class ResNetOriginal(nn.Module):

    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNetOriginal, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.GroupNorm(8, 64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 196, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.GroupNorm(8, planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

class ResNetDepth(TrainableModel):
    def __init__(self):
        super().__init__()
        # self.resnet = models.resnet50()
        self.resnet = ResNetOriginal(Bottleneck, [3, 4, 6, 3])
        self.final_conv = nn.Conv2d(2048, 8, (3, 3), padding=1)

        self.decoder = nn.Sequential(
            ConvBlock(8, 128),
            ConvBlock(128, 128),
            ConvBlock(128, 128),
            ConvBlock(128, 128),
            ConvBlock(128, 128),
            ConvBlock(128, 128, transpose=True),
            ConvBlock(128, 128, transpose=True),
            ConvBlock(128, 128, transpose=True),
            ConvBlock(128, 128, transpose=True),
            ConvBlock(128, 1, transpose=True),
        )

    def forward(self, x):

        for layer in list(self.resnet._modules.values())[:-2]:
            x = layer(x)

        x = self.final_conv(x)
        x = self.decoder(x)

        return x

    def loss(self, pred, target):
        loss = torch.tensor(0.0, device=pred.device)
        return loss, (loss.detach(),)



================================================
FILE: modules/percep_nets.py
================================================


import os, sys, math, random, itertools
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets, transforms, models
from torch.optim.lr_scheduler import MultiStepLR

from models import TrainableModel
from utils import *


class ConvBlock(nn.Module):
    def __init__(self, f1, f2, kernel_size=3, padding=1, use_groupnorm=True, groups=8, dilation=1, transpose=False):
        super().__init__()
        self.transpose = transpose
        self.conv = nn.Conv2d(f1, f2, (kernel_size, kernel_size), dilation=dilation, padding=padding*dilation)
        if self.transpose:
            self.convt = nn.ConvTranspose2d(
                f1, f1, (3, 3), dilation=dilation, stride=2, padding=dilation, output_padding=1
            )
        if use_groupnorm:
            self.bn = nn.GroupNorm(groups, f1)
        else:
            self.bn = nn.BatchNorm2d(f1)

    def forward(self, x):
        # x = F.dropout(x, 0.04, self.training)
        x = self.bn(x)
        if self.transpose:
            # x = F.upsample(x, scale_factor=2, mode='bilinear')
            x = F.relu(self.convt(x))
            # x = x[:, :, :-1, :-1]
        x = F.relu(self.conv(x))
        return x


class DenseNet(TrainableModel):
    def __init__(self):
        super().__init__()

        self.decoder = nn.Sequential(
            ConvBlock(3, 96, groups=3), 
            ConvBlock(96, 96),
            ConvBlock(96, 96),
            ConvBlock(96, 96),
            ConvBlock(96, 3),
        )

    def forward(self, x):
        x = self.decoder(x)
        return x

    def loss(self, pred, target):
        loss = torch.tensor(0.0, device=pred.device)
        return loss, (loss.detach(),)


class Dense1by1Net(TrainableModel):
    def __init__(self):
        super().__init__()

        self.decoder = nn.Sequential(
            ConvBlock(3, 64, groups=3, kernel_size=1, padding=0), 
            ConvBlock(64, 96, kernel_size=1, padding=0), 
            ConvBlock(96, 96),
            ConvBlock(96, 96),
            ConvBlock(96, 96),
            ConvBlock(96, 96),
            ConvBlock(96, 3),
        )

    def forward(self, x):
        x = self.decoder(x)
        return x

    def loss(self, pred, target):
        loss = torch.tensor(0.0, device=pred.device)
        return loss, (loss.detach(),)

class Dense1by1end(TrainableModel):
    def __init__(self):
        super().__init__()

        self.decoder = nn.Sequential(
            ConvBlock(3, 64, groups=3, kernel_size=1, padding=0), 
            ConvBlock(64, 96, kernel_size=1, padding=0), 
            ConvBlock(96, 96),
            ConvBlock(96, 96),
            ConvBlock(96, 96),
            ConvBlock(96, 96),
            ConvBlock(96, 1),
        )

    def forward(self, x):
        x = self.decoder(x)
        return x

    def loss(self, pred, target):
        loss = torch.tensor(0.0, device=pred.device)
        return loss, (loss.detach(),)

class DenseKernelsNet(TrainableModel):
    def __init__(self, kernel_size=7):
        super().__init__()

        self.decoder = nn.Sequential(
            ConvBlock(3, 64, groups=3, kernel_size=1, padding=0), 
            ConvBlock(64, 96, kernel_size=1, padding=0), 
            ConvBlock(96, 96, kernel_size=1, padding=0),
            ConvBlock(96, 96, kernel_size=kernel_size, padding=kernel_size//2),
            ConvBlock(96, 96),
            ConvBlock(96, 96),
            ConvBlock(96, 96),
            ConvBlock(96, 3),
        )

    def forward(self, x):
        x = self.decoder(x)
        return x

    def loss(self, pred, target):
        loss = torch.tensor(0.0, device=pred.device)
        return loss, (loss.detach(),)


class DeepNet(TrainableModel):
    def __init__(self):
        super().__init__()

        self.decoder = nn.Sequential(
            ConvBlock(3, 32, groups=3), 
            ConvBlock(32, 32),
            ConvBlock(32, 32, dilation=2),
            ConvBlock(32, 32, dilation=2),
            ConvBlock(32, 32, dilation=4),
            ConvBlock(32, 32, dilation=4),
            ConvBlock(32, 3),
        )

    def forward(self, x):
        x = self.decoder(x)
        return x

    def loss(self, pred, target):
        loss = torch.tensor(0.0, device=pred.device)
        return loss, (loss.detach(),)


class WideNet(TrainableModel):
    def __init__(self):
        super().__init__()

        self.decoder = nn.Sequential(
            ConvBlock(3, 32, groups=3), 
            ConvBlock(32, 32, kernel_size=5, padding=2),
            ConvBlock(32, 32, kernel_size=5, padding=2),
            ConvBlock(32, 32, kernel_size=5, padding=2),
            ConvBlock(32, 3),
        )

    def forward(self, x):
        x = self.decoder(x)
        return x

    def loss(self, pred, target):
        loss = torch.tensor(0.0, device=pred.device)
        return loss, (loss.detach(),)


class PyramidNet(TrainableModel):
    def __init__(self):
        super().__init__()

        self.decoder = nn.Sequential(
            ConvBlock(3, 16, groups=3), 
            ConvBlock(16, 32, kernel_size=5, padding=2),
            ConvBlock(32, 64, kernel_size=5, padding=2),
            ConvBlock(64, 96, kernel_size=3, padding=1),
            ConvBlock(96, 32, kernel_size=3, padding=1),
            ConvBlock(32, 3),
        )

    def forward(self, x):
        x = self.decoder(x)
        return x

    def loss(self, pred, target):
        loss = torch.tensor(0.0, device=pred.device)
        return loss, (loss.detach(),)



class BaseNet(TrainableModel):
    def __init__(self):
        super().__init__()

        self.decoder = nn.Sequential(
            ConvBlock(3, 32, use_groupnorm=False), 
            ConvBlock(32, 32, use_groupnorm=False),
            ConvBlock(32, 32, use_groupnorm=False),
            ConvBlock(32, 1, use_groupnorm=False),
        )

    def forward(self, x):
        x = self.decoder(x)
        return x

    def loss(self, pred, target):
        loss = torch.tensor(0.0, device=pred.device)
        return loss, (loss.detach(),)


class ResidualsNet(TrainableModel):
    def __init__(self):
        super().__init__()

        self.encoder = nn.Sequential(
            ConvBlock(3, 32, use_groupnorm=False), 
            ConvBlock(32, 32, use_groupnorm=False),
        )
        self.mid = nn.Sequential(
            ConvBlock(32, 64, use_groupnorm=False), 
            ConvBlock(64, 64, use_groupnorm=False),
            ConvBlock(64, 32, use_groupnorm=False),
        )
        self.decoder = nn.Sequential(
            ConvBlock(64, 32, use_groupnorm=False), 
            ConvBlock(32, 3, use_groupnorm=False),
        )

    def forward(self, x):
        tmp = self.encoder(x)
        x = F.max_pool2d(tmp, 2)
        x = self.mid(x)
        x = F.upsample(x, scale_factor=2, mode='bilinear')
        x = torch.cat([x, tmp], dim=1)
        x = self.decoder(x)
        return x

    def loss(self, pred, target):
        loss = torch.tensor(0.0, device=pred.device)
        return loss, (loss.detach(),)

class ResNet50(TrainableModel):
    def __init__(self, num_classes=365, in_channels=3):
        super().__init__()
        self.resnet = models.resnet18(num_classes=num_classes)
        self.resnet.fc = nn.Linear(in_features=8192, out_features=num_classes, bias=True)
        self.resnet.conv1 = nn.Conv2d(in_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

    def forward(self, x):
        x = self.resnet(x)
        return F.log_softmax(x, dim=1)

    def loss(self, pred, target):
        loss = F.nll_loss(pred, target)
        return loss, (loss.detach(),)



================================================
FILE: modules/resnet.py
================================================

import os, sys, math, random, itertools
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets, transforms, models
from torch.optim.lr_scheduler import MultiStepLR

from models import TrainableModel
from utils import *



def conv3x3(in_planes, out_planes, stride=1, groups=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, groups=groups, bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)



class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, norm_layer=None):
        super(BasicBlock, self).__init__()
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.GroupNorm(8, planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.GroupNorm(8, planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.GroupNorm(8, planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.GroupNorm(8, planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.GroupNorm(8, planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

class ResNetOriginal(nn.Module):

    def __init__(self, block, layers, in_channels=3, num_classes=1000):
        self.inplanes = 64
        super(ResNetOriginal, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.GroupNorm(8, 64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=1)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.GroupNorm(8, planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

class ResNet(TrainableModel):
    def __init__(self, in_channels=3, out_channels=1000):
        super().__init__()
        self.resnet = ResNetOriginal(BasicBlock, [2, 2, 2, 2], in_channels=in_channels)
        self.final = nn.Linear(512, out_channels)

    def forward(self, x):

        for layer in list(self.resnet._modules.values())[:-2]:
            x = layer(x)
        x = F.relu(x.mean(dim=2).mean(dim=2))
        x = F.log_softmax(self.final(x), dim=1)

        return x

    def loss(self, pred, target):
        loss = torch.tensor(0.0, device=pred.device)
        return loss, (loss.detach(),)


class ResNetClass(TrainableModel):
    def __init__(self):
        super().__init__()
        self.resnet = models.resnet50(pretrained=True)

    def forward(self, x):
        if x.shape[1] == 1: x = x.repeat(1,3,1,1)
        for layer in list(self.resnet._modules.values())[:-2]:
            x = layer(x)
        return x

    ### Not in Use Right Now ###
    def loss(self, pred, target):
        mask = build_mask(pred, val=0.502)
        mse = F.mse_loss(pred[mask], target[mask])
        return mse, (mse.detach(),)


if __name__ == "__main__":
    model = ResNet(out_channels=365)
    print (model(torch.randn(2, 3, 224, 224 )).shape)



================================================
FILE: modules/unet.py
================================================

import os, sys, math, random, itertools
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets, transforms, models
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.checkpoint import checkpoint

from models import TrainableModel
from utils import *


class UNet_up_block(nn.Module):
    def __init__(self, prev_channel, input_channel, output_channel, up_sample=True):
        super().__init__()
        self.up_sampling = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.conv1 = nn.Conv2d(prev_channel + input_channel, output_channel, 3, padding=1)
        self.bn1 = nn.GroupNorm(8, output_channel)
        self.conv2 = nn.Conv2d(output_channel, output_channel, 3, padding=1)
        self.bn2 = nn.GroupNorm(8, output_channel)
        self.conv3 = nn.Conv2d(output_channel, output_channel, 3, padding=1)
        self.bn3 = nn.GroupNorm(8, output_channel)        
        self.relu = torch.nn.ReLU()
        self.up_sample = up_sample

    def forward(self, prev_feature_map, x):
        if self.up_sample:
            x = self.up_sampling(x)
        x = torch.cat((x, prev_feature_map), dim=1)
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))
        return x


class UNet_down_block(nn.Module):
    def __init__(self, input_channel, output_channel, down_size=True):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channel, output_channel, 3, padding=1)
        self.bn1 = nn.GroupNorm(8, output_channel)
        self.conv2 = nn.Conv2d(output_channel, output_channel, 3, padding=1)
        self.bn2 = nn.GroupNorm(8, output_channel)
        self.conv3 = nn.Conv2d(output_channel, output_channel, 3, padding=1)
        self.bn3 = nn.GroupNorm(8, output_channel)
        self.max_pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()
        self.down_size = down_size

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))
        if self.down_size:
            x = self.max_pool(x)
        return x


class UNet(TrainableModel):
    def __init__(self,  downsample=6, in_channels=3, out_channels=3):
        super().__init__()

        self.in_channels, self.out_channels, self.downsample = in_channels, out_channels, downsample
        self.down1 = UNet_down_block(in_channels, 16, False)
        self.down_blocks = nn.ModuleList(
            [UNet_down_block(2**(4+i), 2**(5+i), True) for i in range(0, downsample)]
        )

        bottleneck = 2**(4 + downsample)
        self.mid_conv1 = nn.Conv2d(bottleneck, bottleneck, 3, padding=1)
        self.bn1 = nn.GroupNorm(8, bottleneck)
        self.mid_conv2 = nn.Conv2d(bottleneck, bottleneck, 3, padding=1)
        self.bn2 = nn.GroupNorm(8, bottleneck)
        self.mid_conv3 = torch.nn.Conv2d(bottleneck, bottleneck, 3, padding=1)
        self.bn3 = nn.GroupNorm(8, bottleneck)

        self.up_blocks = nn.ModuleList(
            [UNet_up_block(2**(4+i), 2**(5+i), 2**(4+i)) for i in range(0, downsample)]
        )

        self.last_conv1 = nn.Conv2d(16, 16, 3, padding=1)
        self.last_bn = nn.GroupNorm(8, 16)
        self.last_conv2 = nn.Conv2d(16, out_channels, 1, padding=0)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.down1(x)
        xvals = [x]
        for i in range(0, self.downsample):
            x = self.down_blocks[i](x)
            xvals.append(x)

        x = self.relu(self.bn1(self.mid_conv1(x)))
        x = self.relu(self.bn2(self.mid_conv2(x)))
        x = self.relu(self.bn3(self.mid_conv3(x)))

        for i in range(0, self.downsample)[::-1]:
            x = self.up_blocks[i](xvals[i], x)

        x = self.relu(self.last_bn(self.last_conv1(x)))
        x = self.relu(self.last_conv2(x))
        return x

    def loss(self, pred, target):
        loss = torch.tensor(0.0, device=pred.device)
        return loss, (loss.detach(),)

class UNetReshade(TrainableModel):
    def __init__(self,  downsample=6, in_channels=3, out_channels=3):
        super().__init__()

        self.in_channels, self.out_channels, self.downsample = in_channels, out_channels, downsample
        self.down1 = UNet_down_block(in_channels, 16, False)
        self.down_blocks = nn.ModuleList(
            [UNet_down_block(2**(4+i), 2**(5+i), True) for i in range(0, downsample)]
        )

        bottleneck = 2**(4 + downsample)
        self.mid_conv1 = nn.Conv2d(bottleneck, bottleneck, 3, padding=1)
        self.bn1 = nn.GroupNorm(8, bottleneck)
        self.mid_conv2 = nn.Conv2d(bottleneck, bottleneck, 3, padding=1)
        self.bn2 = nn.GroupNorm(8, bottleneck)
        self.mid_conv3 = torch.nn.Conv2d(bottleneck, bottleneck, 3, padding=1)
        self.bn3 = nn.GroupNorm(8, bottleneck)

        self.up_blocks = nn.ModuleList(
            [UNet_up_block(2**(4+i), 2**(5+i), 2**(4+i)) for i in range(0, downsample)]
        )

        self.last_conv1 = nn.Conv2d(16, 16, 3, padding=1)
        self.last_bn = nn.GroupNorm(8, 16)
        self.last_conv2 = nn.Conv2d(16, out_channels, 1, padding=0)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.down1(x)
        xvals = [x]
        for i in range(0, self.downsample):
            x = self.down_blocks[i](x)
            xvals.append(x)

        x = self.relu(self.bn1(self.mid_conv1(x)))
        x = self.relu(self.bn2(self.mid_conv2(x)))
        x = self.relu(self.bn3(self.mid_conv3(x)))

        for i in range(0, self.downsample)[::-1]:
            x = self.up_blocks[i](xvals[i], x)

        x = self.relu(self.last_bn(self.last_conv1(x)))
        x = self.relu(self.last_conv2(x))
        x = x.clamp(max=1, min=0).mean(dim=1, keepdim=True)
        x = x.expand(-1, 3, -1, -1)
        return x

    def loss(self, pred, target):
        loss = torch.tensor(0.0, device=pred.device)
        return loss, (loss.detach(),)


class UNetOld(TrainableModel):
    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.down_block1 = UNet_down_block(in_channels, 16, False) #   256
        self.down_block2 = UNet_down_block(16, 32, True) #   128
        self.down_block3 = UNet_down_block(32, 64, True) #   64
        self.down_block4 = UNet_down_block(64, 128, True) #  32
        self.down_block5 = UNet_down_block(128, 256, True) # 16
        self.down_block6 = UNet_down_block(256, 512, True) # 8
        self.down_block7 = UNet_down_block(512, 1024, True)# 4 
        
        self.mid_conv1 = nn.Conv2d(1024, 1024, 3, padding=1)
        self.bn1 = nn.GroupNorm(8, 1024)
        self.mid_conv2 = nn.Conv2d(1024, 1024, 3, padding=1)
        self.bn2 = nn.GroupNorm(8, 1024)
        self.mid_conv3 = torch.nn.Conv2d(1024, 1024, 3, padding=1)
        self.bn3 = nn.GroupNorm(8, 1024)

        self.up_block1 = UNet_up_block(512, 1024, 512)
        self.up_block2 = UNet_up_block(256, 512, 256)
        self.up_block3 = UNet_up_block(128, 256, 128)
        self.up_block4 = UNet_up_block(64, 128, 64)
        self.up_block5 = UNet_up_block(32, 64, 32)
        self.up_block6 = UNet_up_block(16, 32, 16)

        self.last_conv1 = nn.Conv2d(16, 16, 3, padding=1)
        self.last_bn = nn.GroupNorm(8, 16)
        self.last_conv2 = nn.Conv2d(16, out_channels, 1, padding=0)
        self.relu = nn.ReLU()

    def forward(self, x):
        self.x1 = self.down_block1(x)
        self.x2 = self.down_block2(self.x1)
        self.x3 = self.down_block3(self.x2)
        self.x4 = self.down_block4(self.x3)
        self.x5 = self.down_block5(self.x4)
        self.x6 = self.down_block6(self.x5)
        self.x7 = self.down_block7(self.x6)

        self.x7 = self.relu(self.bn1(self.mid_conv1(self.x7)))
        self.x7 = self.relu(self.bn2(self.mid_conv2(self.x7)))
        self.x7 = self.relu(self.bn3(self.mid_conv3(self.x7)))

        x = self.up_block1(self.x6, self.x7)
        x = self.up_block2(self.x5, x)
        x = self.up_block3(self.x4, x)
        x = self.up_block4(self.x3, x)
        x = self.up_block5(self.x2, x)
        x = self.up_block6(self.x1, x)
        x = self.relu(self.last_bn(self.last_conv1(x)))
        x = self.relu(self.last_conv2(x))
        return x

    def loss(self, pred, target):
        loss = torch.tensor(0.0, device=pred.device)
        return loss, (loss.detach(),)


class ConvBlock(nn.Module):
    def __init__(self, f1, f2, kernel_size=3, padding=1, use_groupnorm=True, groups=8, dilation=1, transpose=False):
        super().__init__()
        self.transpose = transpose
        self.conv = nn.Conv2d(f1, f2, (kernel_size, kernel_size), dilation=dilation, padding=padding*dilation)
        if self.transpose:
            self.convt = nn.ConvTranspose2d(
                f1, f1, (3, 3), dilation=dilation, stride=2, padding=dilation, output_padding=1
            )
        if use_groupnorm:
            self.bn = nn.GroupNorm(groups, f1)
        else:
            self.bn = nn.BatchNorm2d(f1)

    def forward(self, x):
        # x = F.dropout(x, 0.04, self.training)
        x = self.bn(x)
        if self.transpose:
            # x = F.upsample(x, scale_factor=2, mode='bilinear')
            x = F.relu(self.convt(x))
            # x = x[:, :, :-1, :-1]
        x = F.relu(self.conv(x))
        return x

class UNetOld2(TrainableModel):
    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.initial = nn.Sequential(
            ConvBlock(in_channels, 16, groups=3, kernel_size=1, padding=0),
            ConvBlock(16, 16, groups=4, kernel_size=1, padding=0)
        )
        self.down_block1 = UNet_down_block(16, 16, False)
        self.down_block2 = UNet_down_block(16, 32, True) #   128
        self.down_block3 = UNet_down_block(32, 64, True) #   64
        self.down_block4 = UNet_down_block(64, 128, True) #  32
        self.down_block5 = UNet_down_block(128, 256, True) # 16
        self.down_block6 = UNet_down_block(256, 512, True) # 8
        self.down_block7 = UNet_down_block(512, 1024, True)# 4 
        
        self.mid_conv1 = nn.Conv2d(1024, 1024, 3, padding=1)
        self.bn1 = nn.GroupNorm(8, 1024)
        self.mid_conv2 = nn.Conv2d(1024, 1024, 3, padding=1)
        self.bn2 = nn.GroupNorm(8, 1024)
        self.mid_conv3 = torch.nn.Conv2d(1024, 1024, 3, padding=1)
        self.bn3 = nn.GroupNorm(8, 1024)

        self.up_block1 = UNet_up_block(512, 1024, 512)
        self.up_block2 = UNet_up_block(256, 512, 256)
        self.up_block3 = UNet_up_block(128, 256, 128)
        self.up_block4 = UNet_up_block(64, 128, 64)
        self.up_block5 = UNet_up_block(32, 64, 32)
        self.up_block6 = UNet_up_block(16, 32, 16)

        self.last_conv1 = nn.Conv2d(16, 16, 3, padding=1)
        self.last_bn = nn.GroupNorm(8, 16)
        self.last_conv2 = nn.Conv2d(16, out_channels, 1, padding=0)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.initial(x)
        self.x1 = self.down_block1(x)
        self.x2 = self.down_block2(self.x1)
        self.x3 = self.down_block3(self.x2)
        self.x4 = self.down_block4(self.x3)
        self.x5 = self.down_block5(self.x4)
        self.x6 = self.down_block6(self.x5)
        self.x7 = self.down_block7(self.x6)

        self.x7 = self.relu(self.bn1(self.mid_conv1(self.x7)))
        self.x7 = self.relu(self.bn2(self.mid_conv2(self.x7)))
        self.x7 = self.relu(self.bn3(self.mid_conv3(self.x7)))

        x = self.up_block1(self.x6, self.x7)
        x = self.up_block2(self.x5, x)
        x = self.up_block3(self.x4, x)
        x = self.up_block4(self.x3, x)
        x = self.up_block5(self.x2, x)
        x = self.up_block6(self.x1, x)
        x = self.relu(self.last_bn(self.last_conv1(x)))
        x = self.relu(self.last_conv2(x))
        return x

    def loss(self, pred, target):
        loss = torch.tensor(0.0, device=pred.device)
        return loss, (loss.detach(),)
        


================================================
FILE: modules/unet_mirrored.py
================================================

import os, sys, math, random, itertools
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets, transforms, models
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.checkpoint import checkpoint

from models import TrainableModel
from utils import *


class UNet_up_block(nn.Module):
    def __init__(self, prev_channel, input_channel, output_channel, up_sample=True):
        super().__init__()
        self.up_sampling = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.conv1 = nn.Conv2d(prev_channel + input_channel, output_channel, 3)
        self.bn1 = nn.GroupNorm(8, output_channel)
        self.conv2 = nn.Conv2d(output_channel, output_channel, 3)
        self.bn2 = nn.GroupNorm(8, output_channel)
        self.conv3 = nn.Conv2d(output_channel, output_channel, 3)
        self.bn3 = nn.GroupNorm(8, output_channel)        
        self.relu = torch.nn.ReLU()
        self.up_sample = up_sample

    def forward(self, prev_feature_map, x):
        if self.up_sample:
            x = self.up_sampling(x)
        x = torch.cat((x, prev_feature_map), dim=1)
        x = self.relu(self.bn1(F.pad(self.conv1(x), (1, 1, 1, 1), mode='reflect')))
        x = self.relu(self.bn2(F.pad(self.conv2(x), (1, 1, 1, 1), mode='reflect')))
        x = self.relu(self.bn3(F.pad(self.conv3(x), (1, 1, 1, 1), mode='reflect')))
        return x


class UNet_down_block(nn.Module):
    def __init__(self, input_channel, output_channel, down_size=True):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channel, output_channel, 3)
        self.bn1 = nn.GroupNorm(8, output_channel)
        self.conv2 = nn.Conv2d(output_channel, output_channel, 3)
        self.bn2 = nn.GroupNorm(8, output_channel)
        self.conv3 = nn.Conv2d(output_channel, output_channel, 3)
        self.bn3 = nn.GroupNorm(8, output_channel)
        self.max_pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()
        self.down_size = down_size

    def forward(self, x):
        x = self.relu(self.bn1(F.pad(self.conv1(x), (1, 1, 1, 1), mode='reflect')))
        x = self.relu(self.bn2(F.pad(self.conv2(x), (1, 1, 1, 1), mode='reflect')))
        x = self.relu(self.bn3(F.pad(self.conv3(x), (1, 1, 1, 1), mode='reflect')))
        if self.down_size:
            x = self.max_pool(x)
        return x


class UNet(TrainableModel):
    def __init__(self, downsample=6, in_channels=3, out_channels=3):
        super().__init__()

        self.in_channels, self.out_channels, self.downsample = in_channels, out_channels, downsample
        self.down1 = UNet_down_block(in_channels, 16, False)
        self.down_blocks = nn.ModuleList(
            [UNet_down_block(2**(4+i), 2**(5+i), True) for i in range(0, downsample)]
        )

        bottleneck = 2**(4 + downsample)
        self.mid_conv1 = nn.Conv2d(bottleneck, bottleneck, 3, padding=1)
        self.bn1 = nn.GroupNorm(8, bottleneck)
        self.mid_conv2 = nn.Conv2d(bottleneck, bottleneck, 3, padding=1)
        self.bn2 = nn.GroupNorm(8, bottleneck)
        self.mid_conv3 = torch.nn.Conv2d(bottleneck, bottleneck, 3, padding=1)
        self.bn3 = nn.GroupNorm(8, bottleneck)

        self.up_blocks = nn.ModuleList(
            [UNet_up_block(2**(4+i), 2**(5+i), 2**(4+i)) for i in range(0, downsample)]
        )

        self.last_conv1 = nn.Conv2d(16, 16, 3, padding=1)
        self.last_bn = nn.GroupNorm(8, 16)
        self.last_conv2 = nn.Conv2d(16, out_channels, 1, padding=0)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.down1(x)
        xvals = [x]
        pad = [False for i in range(0, self.downsample)]
        for i in range(0, self.downsample):
            if x.shape[2] % 2 != 0: 
                x = F.pad(x, (1, 0, 1, 0))
                pad[i] = True
            x = self.down_blocks[i](x)
            xvals.append(x)

        x = self.relu(self.bn1(self.mid_conv1(x)))
        x = self.relu(self.bn2(self.mid_conv2(x)))
        x = self.relu(self.bn3(self.mid_conv3(x)))

        for i in range(0, self.downsample)[::-1]:
            # print (x.shape, xvals[i].shape)
            x = self.up_blocks[i](xvals[i], x)
            if pad[i] != 0: 
                x = x[:, :, 1:, 1:]

        x = self.relu(self.last_bn(self.last_conv1(x)))
        x = self.relu(self.last_conv2(x))
        return x

    def loss(self, pred, target):
        loss = torch.tensor(0.0, device=pred.device)
        return loss, (loss.detach(),)

class UNetReshade(TrainableModel):
    def __init__(self,  downsample=6, in_channels=3, out_channels=3):
        super().__init__()

        self.in_channels, self.out_channels, self.downsample = in_channels, out_channels, downsample
        self.down1 = UNet_down_block(in_channels, 16, False)
        self.down_blocks = nn.ModuleList(
            [UNet_down_block(2**(4+i), 2**(5+i), True) for i in range(0, downsample)]
        )

        bottleneck = 2**(4 + downsample)
        self.mid_conv1 = nn.Conv2d(bottleneck, bottleneck, 3, padding=1)
        self.bn1 = nn.GroupNorm(8, bottleneck)
        self.mid_conv2 = nn.Conv2d(bottleneck, bottleneck, 3, padding=1)
        self.bn2 = nn.GroupNorm(8, bottleneck)
        self.mid_conv3 = torch.nn.Conv2d(bottleneck, bottleneck, 3, padding=1)
        self.bn3 = nn.GroupNorm(8, bottleneck)

        self.up_blocks = nn.ModuleList(
            [UNet_up_block(2**(4+i), 2**(5+i), 2**(4+i)) for i in range(0, downsample)]
        )

        self.last_conv1 = nn.Conv2d(16, 16, 3, padding=1)
        self.last_bn = nn.GroupNorm(8, 16)
        self.last_conv2 = nn.Conv2d(16, out_channels, 1, padding=0)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.down1(x)
        xvals = [x]
        for i in range(0, self.downsample):
            x = self.down_blocks[i](x)
            xvals.append(x)

        x = self.relu(self.bn1(self.mid_conv1(x)))
        x = self.relu(self.bn2(self.mid_conv2(x)))
        x = self.relu(self.bn3(self.mid_conv3(x)))

        for i in range(0, self.downsample)[::-1]:
            x = self.up_blocks[i](xvals[i], x)

        x = self.relu(self.last_bn(self.last_conv1(x)))
        x = self.relu(self.last_conv2(x))
        x = x.clamp(max=1, min=0).mean(dim=1, keepdim=True)
        x = x.expand(-1, 3, -1, -1)
        return x

    def loss(self, pred, target):
        loss = torch.tensor(0.0, device=pred.device)
        return loss, (loss.detach(),)


class UNetOld(TrainableModel):
    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.down_block1 = UNet_down_block(in_channels, 16, False) #   256
        self.down_block2 = UNet_down_block(16, 32, True) #   128
        self.down_block3 = UNet_down_block(32, 64, True) #   64
        self.down_block4 = UNet_down_block(64, 128, True) #  32
        self.down_block5 = UNet_down_block(128, 256, True) # 16
        self.down_block6 = UNet_down_block(256, 512, True) # 8
        self.down_block7 = UNet_down_block(512, 1024, True)# 4 
        
        self.mid_conv1 = nn.Conv2d(1024, 1024, 3, padding=1)
        self.bn1 = nn.GroupNorm(8, 1024)
        self.mid_conv2 = nn.Conv2d(1024, 1024, 3, padding=1)
        self.bn2 = nn.GroupNorm(8, 1024)
        self.mid_conv3 = torch.nn.Conv2d(1024, 1024, 3, padding=1)
        self.bn3 = nn.GroupNorm(8, 1024)

        self.up_block1 = UNet_up_block(512, 1024, 512)
        self.up_block2 = UNet_up_block(256, 512, 256)
        self.up_block3 = UNet_up_block(128, 256, 128)
        self.up_block4 = UNet_up_block(64, 128, 64)
        self.up_block5 = UNet_up_block(32, 64, 32)
        self.up_block6 = UNet_up_block(16, 32, 16)

        self.last_conv1 = nn.Conv2d(16, 16, 3, padding=1)
        self.last_bn = nn.GroupNorm(8, 16)
        self.last_conv2 = nn.Conv2d(16, out_channels, 1, padding=0)
        self.relu = nn.ReLU()

    def forward(self, x):
        self.x1 = self.down_block1(x)
        self.x2 = self.down_block2(self.x1)
        self.x3 = self.down_block3(self.x2)
        self.x4 = self.down_block4(self.x3)
        self.x5 = self.down_block5(self.x4)
        self.x6 = self.down_block6(self.x5)
        self.x7 = self.down_block7(self.x6)

        self.x7 = self.relu(self.bn1(self.mid_conv1(self.x7)))
        self.x7 = self.relu(self.bn2(self.mid_conv2(self.x7)))
        self.x7 = self.relu(self.bn3(self.mid_conv3(self.x7)))

        x = self.up_block1(self.x6, self.x7)
        x = self.up_block2(self.x5, x)
        x = self.up_block3(self.x4, x)
        x = self.up_block4(self.x3, x)
        x = self.up_block5(self.x2, x)
        x = self.up_block6(self.x1, x)
        x = self.relu(self.last_bn(self.last_conv1(x)))
        x = self.relu(self.last_conv2(x))
        return x

    def loss(self, pred, target):
        loss = torch.tensor(0.0, device=pred.device)
        return loss, (loss.detach(),)


class ConvBlock(nn.Module):
    def __init__(self, f1, f2, kernel_size=3, padding=1, use_groupnorm=True, groups=8, dilation=1, transpose=False):
        super().__init__()
        self.transpose = transpose
        self.conv = nn.Conv2d(f1, f2, (kernel_size, kernel_size), dilation=dilation, padding=padding*dilation)
        if self.transpose:
            self.convt = nn.ConvTranspose2d(
                f1, f1, (3, 3), dilation=dilation, stride=2, padding=dilation, output_padding=1
            )
        if use_groupnorm:
            self.bn = nn.GroupNorm(groups, f1)
        else:
            self.bn = nn.BatchNorm2d(f1)

    def forward(self, x):
        # x = F.dropout(x, 0.04, self.training)
        x = self.bn(x)
        if self.transpose:
            # x = F.upsample(x, scale_factor=2, mode='bilinear')
            x = F.relu(self.convt(x))
            # x = x[:, :, :-1, :-1]
        x = F.relu(self.conv(x))
        return x

class UNetOld2(TrainableModel):
    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.initial = nn.Sequential(
            ConvBlock(in_channels, 16, groups=3, kernel_size=1, padding=0),
            ConvBlock(16, 16, groups=4, kernel_size=1, padding=0)
        )
        self.down_block1 = UNet_down_block(16, 16, False)
        self.down_block2 = UNet_down_block(16, 32, True) #   128
        self.down_block3 = UNet_down_block(32, 64, True) #   64
        self.down_block4 = UNet_down_block(64, 128, True) #  32
        self.down_block5 = UNet_down_block(128, 256, True) # 16
        self.down_block6 = UNet_down_block(256, 512, True) # 8
        self.down_block7 = UNet_down_block(512, 1024, True)# 4 
        
        self.mid_conv1 = nn.Conv2d(1024, 1024, 3, padding=1)
        self.bn1 = nn.GroupNorm(8, 1024)
        self.mid_conv2 = nn.Conv2d(1024, 1024, 3, padding=1)
        self.bn2 = nn.GroupNorm(8, 1024)
        self.mid_conv3 = torch.nn.Conv2d(1024, 1024, 3, padding=1)
        self.bn3 = nn.GroupNorm(8, 1024)

        self.up_block1 = UNet_up_block(512, 1024, 512)
        self.up_block2 = UNet_up_block(256, 512, 256)
        self.up_block3 = UNet_up_block(128, 256, 128)
        self.up_block4 = UNet_up_block(64, 128, 64)
        self.up_block5 = UNet_up_block(32, 64, 32)
        self.up_block6 = UNet_up_block(16, 32, 16)

        self.last_conv1 = nn.Conv2d(16, 16, 3, padding=1)
        self.last_bn = nn.GroupNorm(8, 16)
        self.last_conv2 = nn.Conv2d(16, out_channels, 1, padding=0)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.initial(x)
        self.x1 = self.down_block1(x)
        self.x2 = self.down_block2(self.x1)
        self.x3 = self.down_block3(self.x2)
        self.x4 = self.down_block4(self.x3)
        self.x5 = self.down_block5(self.x4)
        self.x6 = self.down_block6(self.x5)
        self.x7 = self.down_block7(self.x6)

        self.x7 = self.relu(self.bn1(self.mid_conv1(self.x7)))
        self.x7 = self.relu(self.bn2(self.mid_conv2(self.x7)))
        self.x7 = self.relu(self.bn3(self.mid_conv3(self.x7)))

        x = self.up_block1(self.x6, self.x7)
        x = self.up_block2(self.x5, x)
        x = self.up_block3(self.x4, x)
        x = self.up_block4(self.x3, x)
        x = self.up_block5(self.x2, x)
        x = self.up_block6(self.x1, x)
        x = self.relu(self.last_bn(self.last_conv1(x)))
        x = self.relu(self.last_conv2(x))
        return x

    def loss(self, pred, target):
        loss = torch.tensor(0.0, device=pred.device)
        return loss, (loss.detach(),)


if __name__ == "__main__":

    model = UNet()
    x = torch.randn(1, 3, 256, 256)
    print (model(x).shape)

        


================================================
FILE: plotting.py
================================================
import numpy as np

def jointplot(logger, data, loss_type="mse_loss"):
    data = np.stack((data[f"train_{loss_type}"], data[f"val_{loss_type}"]), axis=1)
    logger.plot(data, loss_type, opts={"legend": [f"train_{loss_type}", f"val_{loss_type}"]})

def get_running_means_w_std_bounds_and_legend_on_diff_prev_time_step(list_of_list_values):
    running_mean_and_std_bounds = []
    legend = ["Mean-STD", "Mean Difference", "Mean+STD"]
    for ii, losses_in_batch_ii in enumerate(list_of_list_values):
        if ii == 0:  # there's no previous time step to compare to
            running_mean_and_std_bounds.append([0, 0, 0])
        else:
            loss_diffs = [loss_val - list_of_list_values[ii - 1][jj]
                          for jj, loss_val in enumerate(losses_in_batch_ii)]
            mean = np.mean(loss_diffs)
            std = np.std(loss_diffs)

            running_mean_and_std_bounds.append([mean - std, mean, mean + std])

    return running_mean_and_std_bounds, legend

def get_running_means_w_std_bounds_and_legend(list_of_list_values):
    running_mean_and_std_bounds = []
    legend = ["Mean-STD", "Mean", "Mean+STD"]
    for ii in range(len(list_of_list_values)):
        mean = np.mean(list_of_list_values[ii])
        std = np.std(list_of_list_values[ii])

        running_mean_and_std_bounds.append([mean - std, mean, mean + std])

    return running_mean_and_std_bounds, legend


def get_running_std(list_of_list_values):
    return [np.std(list_of_list_values[ii]) for ii in range(len(list_of_list_values))]


def get_running_p_coeffs(list_of_list_values_1, list_of_list_values_2):
    assert len(list_of_list_values_1) == len(list_of_list_values_2)

    pearson_coefficients = []
    for ii in range(len(list_of_list_values_1)):
        cov = np.cov(np.stack((list_of_list_values_1[ii],
                               list_of_list_values_2[ii]), axis=0))[0, 1]
        std1 = np.std(list_of_list_values_1[ii])
        std2 = np.std(list_of_list_values_2[ii])
        correlation_coefficient = cov / (std1 * std2)

        pearson_coefficients.append(correlation_coefficient)

    return pearson_coefficients

def mseplots(data, logger):
    data = np.stack((logger.data["train_mse_loss"], logger.data["val_mse_loss"]), axis=1)
    logger.plot(data, "mse_loss", opts={"legend": ["train_mse", "val_mse"]})

    running_mean_and_std_bounds, legend = get_running_means_w_std_bounds_and_legend(logger.data["val_mse_losses"])
    logger.plot(running_mean_and_std_bounds, "val_mse_loss_running_mean", opts={"legend": legend})
    logger.plot(get_running_std(logger.data["val_mse_losses"]), "val_mse_losses_running_stds",
                opts={"legend": ['STD']})

    running_mean_and_std_bounds_diff_prev_time_step, legend = \
        get_running_means_w_std_bounds_and_legend_on_diff_prev_time_step(logger.data["val_mse_losses"])
    logger.plot(running_mean_and_std_bounds_diff_prev_time_step, "val_mse_loss_diff_prev_step_running_mean",
                opts={"legend": legend})


def curvatureplots(data, logger):
    data = np.stack((logger.data["train_curvature_loss"], logger.data["val_curvature_loss"]), axis=1)
    logger.plot(data, "curvature_loss", opts={"legend": ["train_curvature", "val_curvature"]})

    running_mean_and_std_bounds, legend = get_running_means_w_std_bounds_and_legend(
        logger.data["val_curvature_losses"])
    logger.plot(running_mean_and_std_bounds, "val_curvature_loss_running_mean", opts={"legend": legend})
    logger.plot(get_running_std(logger.data["val_curvature_losses"]), "val_curvature_losses_running_stds",
                opts={"legend": ['STD']})

    running_mean_and_std_bounds_diff_prev_time_step, legend = \
        get_running_means_w_std_bounds_and_legend_on_diff_prev_time_step(logger.data["val_curvature_losses"])
    logger.plot(running_mean_and_std_bounds_diff_prev_time_step, "val_curvature_loss_diff_prev_step_running_mean",
                opts={"legend": legend})


def depthplots(data, logger):
    data = np.stack((logger.data["train_depth_loss"], logger.data["val_depth_loss"]), axis=1)
    logger.plot(data, "depth_loss", opts={"legend": ["train_depth", "val_depth"]})

    running_mean_and_std_bounds, legend = get_running_means_w_std_bounds_and_legend(logger.data["val_depth_losses"])
    logger.plot(running_mean_and_std_bounds, "val_depth_loss_running_mean", opts={"legend": legend})
    logger.plot(get_running_std(logger.data["val_depth_losses"]), "val_depth_losses_running_stds",
                opts={"legend": ['STD']})

    running_mean_and_std_bounds_diff_prev_time_step, legend = \
        get_running_means_w_std_bounds_and_legend_on_diff_prev_time_step(logger.data["val_depth_losses"])
    logger.plot(running_mean_and_std_bounds_diff_prev_time_step, "val_depth_loss_diff_prev_step_running_mean",
                opts={"legend": legend})


def covarianceplot(data, logger):
    covs = get_running_p_coeffs(logger.data["val_mse_losses"], logger.data["val_curvature_losses"])
    logger.plot(covs, "val_mse_curvature_running_pearson_coeffs", opts={"legend": ['Pearson Coefficient']})

    covs = get_running_p_coeffs(logger.data["val_mse_losses"], logger.data["val_depth_losses"])
    logger.plot(covs, "val_mse_depth_running_pearson_coeffs", opts={"legend": ['Pearson Coefficient']})

    covs = get_running_p_coeffs(logger.data["val_curvature_losses"], logger.data["val_depth_losses"])
    logger.plot(covs, "train_curvature_depth_running_pearson_coeffs", opts={"legend": ['Pearson Coefficient']})

    ratio_mse_curv_stds = [mse_std / curv_std for mse_std, curv_std in
                           zip(get_running_std(logger.data["val_mse_losses"]),
                               get_running_std(logger.data["val_curvature_losses"]))]
    logger.plot(ratio_mse_curv_stds, "val_mse_over_curvature_stds", opts={"legend": ['MSE / Curvature STD']})

    ratio_mse_depth_stds = [mse_std / depth_std for mse_std, depth_std in
                            zip(get_running_std(logger.data["val_mse_losses"]),
                                get_running_std(logger.data["val_depth_losses"]))]
    logger.plot(ratio_mse_depth_stds, "val_mse_over_depth_stds", opts={"legend": ['MSE / Depth STD']})


================================================
FILE: requirements.txt
================================================
fire==0.2.1
ipython==6.5.0
matplotlib==3.0.3
numpy==1.17.2
parse==1.12.1
pip==19.3.1
plac==0.9.6
py==1.6.0
scipy==1.3.1
scikit-image==0.16.2
scikit-learn==0.22.1
torch==1.2.0
torchvision==0.4.0
tqdm==4.36.1
visdom==0.1.8.9
pathlib==1.0.1
pyyaml==5.3.1
pandas
seaborn
statsmodels


================================================
FILE: scripts/energy_calc.py
================================================
import os, sys, math, random, itertools
import numpy as np
import scipy
from collections import defaultdict
from tqdm import tqdm
import pandas as pd
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F

from utils import *
from plotting import *
from energy import get_energy_loss
from graph import TaskGraph
from datasets import TaskDataset, load_train_val, load_test, load_ood, ImageDataset
from task_configs import tasks, RealityTask

from functools import partial
from fire import Fire

import IPython
import pdb
from modules.unet import UNet

def main(
    loss_config="conservative_full",
    mode="standard",
    pretrained=True, finetuned=False, batch_size=16,
    ood_batch_size=None, subset_size=None,
    cont=None,
    use_l1=True, num_workers=32, data_dir=None, save_dir='mount/shared/', **kwargs,
):

    # CONFIG
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    if data_dir is None:
        buildings = ["almena", "albertville"]
        train_subset_dataset = TaskDataset(buildings, tasks=[tasks.rgb, tasks.normal, tasks.principal_curvature])
    else:
        train_subset_dataset = ImageDataset(data_dir=data_dir)
        data_dir = 'CUSTOM'

    train_subset = RealityTask("train_subset", train_subset_dataset, batch_size=batch_size, shuffle=False)

    if subset_size is None:
        subset_size = len(train_subset_dataset)
    subset_size = min(subset_size, len(train_subset_dataset))

    # GRAPH
    realities = [train_subset]
    edges = []
    for t in energy_loss.tasks:
        if t != tasks.rgb:
            edges.append((tasks.rgb, t))
            edges.append((tasks.rgb, tasks.normal))


    graph = TaskGraph(tasks=energy_loss.tasks + [train_subset],
                      finetuned=finetuned,
                      freeze_list=energy_loss.freeze_list, lazy=True,
                      initialize_from_transfer=True,
                      )

    # print('file', cont)
    #graph.load_weights(cont)
    graph.compile(optimizer=None)

    # Add consistency links
    for target in ['reshading', 'depth_zbuffer', 'normal']:
        graph.edge_map[str(('rgb', target))].path = None
        graph.edge_map[str(('rgb', target))].load_model()
    graph.edge_map[str(('rgb', 'reshading'))].model.load_weights('./models/rgb2reshading_consistency.pth',backward_compatible=True)
    graph.edge_map[str(('rgb', 'depth_zbuffer'))].model.load_weights('./models/rgb2depth_consistency.pth',backward_compatible=True)
    graph.edge_map[str(('rgb', 'normal'))].model.load_weights('./models/rgb2normal_consistency.pth',backward_compatible=True)

    energy_losses, mse_losses = [], []
    percep_losses = defaultdict(list)

    energy_mean_by_blur, energy_std_by_blur = [], []
    error_mean_by_blur, error_std_by_blur = [], []

    energy_losses, error_losses = [], []

    energy_losses_all, energy_losses_headings = [], []

    fnames = []
    train_subset.reload()
    # Compute energies
    for epochs in tqdm(range(subset_size // batch_size)):
        with torch.no_grad():
            losses = energy_loss(graph, realities=[train_subset], reduce=False, use_l1=use_l1)

            if len(energy_losses_headings) == 0:
                energy_losses_headings = sorted([loss_name for loss_name in losses if 'percep' in loss_name])

            all_perceps = [losses[loss_name].cpu().numpy() for loss_name in energy_losses_headings]
            direct_losses = [losses[loss_name].cpu().numpy() for loss_name in losses if 'direct' in loss_name]

            if len(all_perceps) > 0:
                energy_losses_all += [all_perceps]
                all_perceps = np.stack(all_perceps)
                energy_losses += list(all_perceps.mean(0))

            if len(direct_losses) > 0:
                direct_losses = np.stack(direct_losses)
                error_losses += list(direct_losses.mean(0))

            if False:
                fnames += train_subset.task_data[tasks.filename]
        train_subset.step()


    # log losses
    if len(energy_losses) > 0:
        energy_losses = np.array(energy_losses)
        print(f'energy = {energy_losses.mean()}')

        energy_mean_by_blur += [energy_losses.mean()]
        energy_std_by_blur += [np.std(energy_losses)]

    if len(error_losses) > 0:
        error_losses = np.array(error_losses)
        print(f'error = {error_losses.mean()}')

        error_mean_by_blur += [error_losses.mean()]
        error_std_by_blur += [np.std(error_losses)]

    # save to csv
    save_error_losses = error_losses if len(error_losses) > 0 else [0] * subset_size
    save_energy_losses = energy_losses if len(energy_losses) > 0 else [0] * subset_size

    z_score = lambda x: (x - x.mean()) / x.std()
    def get_standardized_energy(df, use_std=False, compare_to_in_domain=False):
        percepts = [c for c in df.columns if 'percep' in c]
        stdize = lambda x: (x - x.mean()).abs().mean()
        means = {k: df[k].mean() for k in percepts}
        stds = {k: stdize(df[k]) for k in percepts}
        stdized = {k: (df[k] - means[k])/stds[k] for k in percepts}
        energies = np.stack([v for k, v in stdized.items() if k[-1] == '_' or '__' in k]).mean(0)
        return energies


    os.makedirs(save_dir, exist_ok=True)
    if data_dir is 'CUSTOM':
        eng_curr = np.array(energy_losses).mean()
        df = pd.read_csv(os.path.join(save_dir, 'data.csv'))
    else:
        percep_losses = { k: v for k, v in zip(energy_losses_headings, np.concatenate(energy_losses_all, axis=-1))}
        df = pd.DataFrame(both(
                        {'energy': save_energy_losses, 'error': save_error_losses },
                        percep_losses
        ))

    # compuate correlation
    df['normalized_energy'] = get_standardized_energy(df, use_std=False)
    df['normalized_error'] = z_score(df['error'])
    print(scipy.stats.spearmanr(z_score(df['error']), df['normalized_energy']))
    print("Pearson r:", scipy.stats.pearsonr(df['error'], df['normalized_energy']))

    if data_dir is not 'CUSTOM':
        df.to_csv(f"{save_dir}/data.csv", mode='w', header=True)

    # plot correlation
    plt.figure(figsize=(4,4))
    g = sns.regplot(df['normalized_error'], df['normalized_energy'],robust=False)
    if data_dir is 'CUSTOM':
        ax1 = g.axes
        ax1.axhline(eng_curr, ls='--', color='red')
        ax1.text(0.5, 25, "Query Image Energy Line")
    plt.xlabel('Error (z-score)')
    plt.ylabel('Energy (z-score)')
    plt.title('')
    plt.savefig(f'{save_dir}/energy.pdf')



if __name__ == "__main__":
    Fire(main)


================================================
FILE: scripts/jobinfo.txt
================================================
CH_lbp_all_winrate_depthtarget_1, 0, /scratch-data


================================================
FILE: task_configs.py
================================================

import numpy as np
import random, sys, os, time, glob, math, itertools, json, copy
from collections import defaultdict, namedtuple
from functools import partial

import PIL
from PIL import Image
from scipy import ndimage

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import torch.optim as optim
from torchvision import transforms

from utils import *
from models import DataParallelModel
from modules.unet import UNet, UNetOld2, UNetOld
from modules.percep_nets import Dense1by1Net
from modules.depth_nets import UNetDepth
from modules.resnet import ResNetClass
import IPython

from PIL import ImageFilter
from skimage.filters import gaussian


class GaussianBulr(object):
    def __init__(self, radius):
        self.radius = radius
        self.filter = ImageFilter.GaussianBlur(radius)

    def __call__(self, im):
        return im.filter(self.filter)

    def __repr__(self):
        return 'GaussianBulr Filter with Radius {:d}'.format(self.radius)


""" Model definitions for launching new transfer jobs between tasks. """
model_types = {
    ('normal', 'principal_curvature'): lambda : Dense1by1Net(),
    ('normal', 'depth_zbuffer'): lambda : UNetDepth(),
    ('normal', 'reshading'): lambda : UNet(downsample=5),
    ('depth_zbuffer', 'normal'): lambda : UNet(downsample=6, in_channels=1, out_channels=3),
    ('reshading', 'normal'): lambda : UNet(downsample=4, in_channels=3, out_channels=3),
    ('sobel_edges', 'principal_curvature'): lambda : UNet(downsample=5, in_channels=1, out_channels=3),
    ('depth_zbuffer', 'principal_curvature'): lambda : UNet(downsample=4, in_channels=1, out_channels=3),
    ('principal_curvature', 'depth_zbuffer'): lambda : UNet(downsample=6, in_channels=3, out_channels=1),
    ('rgb', 'normal'): lambda : UNet(downsample=6),
    ('rgb', 'keypoints2d'): lambda : UNet(downsample=3, out_channels=1),
}

def get_model(src_task, dest_task):

    if isinstance(src_task, str) and isinstance(dest_task, str):
        src_task, dest_task = get_task(src_task), get_task(dest_task)

    if (src_task.name, dest_task.name) in model_types:
        return model_types[(src_task.name, dest_task.name)]()

    elif isinstance(src_task, ImageTask) and isinstance(dest_task, ImageTask):
        return UNet(downsample=5, in_channels=src_task.shape[0], out_channels=dest_task.shape[0])

    elif isinstance(src_task, ImageTask) and isinstance(dest_task, ClassTask):
        return ResNet(in_channels=src_task.shape[0], out_channels=dest_task.classes)

    elif isinstance(src_task, ImageTask) and isinstance(dest_task, PointInfoTask):
        return ResNet(out_channels=dest_task.out_channels)

    return None



"""
Abstract task type definitions.
Includes Task, ImageTask, ClassTask, PointInfoTask, and SegmentationTask.
"""

class Task(object):
    variances = {
        "normal": 1.0,
        "principal_curvature": 1.0,
        "sobel_edges": 5,
        "depth_zbuffer": 0.1,
        "reshading": 1.0,
        "keypoints2d": 0.3,
        "keypoints3d": 0.6,
        "edge_occlusion": 0.1,
    }
    """ General task output space"""
    def __init__(self, name,
            file_name=None, file_name_alt=None, file_ext="png", file_loader=None,
            plot_func=None
        ):

        super().__init__()
        self.name = name
        self.file_name, self.file_ext = file_name or name, file_ext
        self.file_name_alt = file_name_alt or self.file_name
        self.file_loader = file_loader or self.file_loader
        self.plot_func = plot_func or self.plot_func
        self.variance = Task.variances.get(name, 1.0)
        self.kind = name

    def norm(self, pred, target, batch_mean=True, compute_mse=True):
        if batch_mean:
            loss = ((pred - target)**2).mean() if compute_mse else ((pred - target).abs()).mean()
        else:
            loss = ((pred - target)**2).mean(dim=1).mean(dim=1).mean(dim=1) if compute_mse \
                    else ((pred - target).abs()).mean(dim=1).mean(dim=1).mean(dim=1)

        return loss, (loss.mean().detach(),)

    def __call__(self, size=256):
        task = copy.deepcopy(self)
        return task

    def plot_func(self, data, name, logger, **kwargs):
        ### Non-image tasks cannot be easily plotted, default to nothing
        pass

    def file_loader(self, path, resize=None, seed=0, T=0):
        raise NotImplementedError()

    def __eq__(self, other):
        return self.name == other.name

    def __repr__(self):
        return self.name

    def __hash__(self):
        return hash(self.name)


"""
Abstract task type definitions.
Includes Task, ImageTask, ClassTask, PointInfoTask, and SegmentationTask.
"""

class RealityTask(Task):
    """ General task output space"""

    def __init__(self, name, dataset, tasks=None, use_dataset=True, shuffle=False, batch_size=64):

        super().__init__(name=name)
        self.tasks = tasks if tasks is not None else \
            (dataset.dataset.tasks if hasattr(dataset, "dataset") else dataset.tasks)
        self.shape = (1,)
        if not use_dataset: return
        self.dataset, self.shuffle, self.batch_size = dataset, shuffle, batch_size
        loader = torch.utils.data.DataLoader(
            self.dataset, batch_size=self.batch_size,
            num_workers=0, shuffle=self.shuffle, pin_memory=True
        )
        self.generator = cycle(loader)
        self.step()
        self.static = False

    @classmethod
    def from_dataloader(cls, name, loader, tasks):
        reality = cls(name, None, tasks, use_dataset=False)
        reality.loader = loader
        reality.generator = cycle(loader)
        reality.static = False
        reality.step()
        return reality

    @classmethod
    def from_static(cls, name, data, tasks):
        reality = cls(name, None, tasks, use_dataset=False)
        reality.task_data = {task: x.requires_grad_() for task, x in zip(tasks, data)}
        reality.static = True
        return reality

    def norm(self, pred, target, batch_mean=True):
        loss = torch.tensor(0.0, device=pred.device)
        return loss, (loss.detach(),)

    def step(self):
        self.task_data = {task: x.requires_grad_() for task, x in zip(self.tasks, next(self.generator))}

    def reload(self):
        loader = torch.utils.data.DataLoader(
            self.dataset, batch_size=self.batch_size,
            num_workers=0, shuffle=self.shuffle, pin_memory=True
        )
        self.generator = cycle(loader)

class ImageTask(Task):
    """ Output space for image-style tasks """

    def __init__(self, *args, **kwargs):

        self.shape = kwargs.pop("shape", (3, 256, 256))
        self.mask_val = kwargs.pop("mask_val", -1.0)
        self.transform = kwargs.pop("transform", lambda x: x)
        self.resize = kwargs.pop("resize", self.shape[1])
        self.blur_radius = None
        self.image_transform = self.load_image_transform()
        super().__init__(*args, **kwargs)

    @staticmethod
    def build_mask(target, val=0.0, tol=1e-3):
        if target.shape[1] == 1:
            mask = ((target >= val - tol) & (target <= val + tol))
            mask = F.conv2d(mask.float(), torch.ones(1, 1, 5, 5, device=mask.device), padding=2) != 0
            return (~mask).expand_as(target)

        mask1 = (target[:, 0, :, :] >= val - tol) & (target[:, 0, :, :] <= val + tol)
        mask2 = (target[:, 1, :, :] >= val - tol) & (target[:, 1, :, :] <= val + tol)
        mask3 = (target[:, 2, :, :] >= val - tol) & (target[:, 2, :, :] <= val + tol)
        mask = (mask1 & mask2 & mask3).unsqueeze(1)
        mask = F.conv2d(mask.float(), torch.ones(1, 1, 5, 5, device=mask.device), padding=2) != 0
        return (~mask).expand_as(target)

    def norm(self, pred, target, batch_mean=True, compute_mask=0, compute_mse=True):
        if compute_mask:
            mask = ImageTask.build_mask(target, val=self.mask_val)
            return super().norm(pred*mask.float(), target*mask.float(), batch_mean=batch_mean, compute_mse=compute_mse)
        else:
            return super().norm(pred, target, batch_mean=batch_mean, compute_mse=compute_mse)

    def __call__(self, size=256, blur_radius=None):
        task = copy.deepcopy(self)
        task.shape = (3, size, size)
        task.resize = size
        task.blur_radius = blur_radius
        task.name +=  "blur" if blur_radius else str(size)
        task.base = self
        return task

    def plot_func(self, data, name, logger, resize=None, nrow=2):
        logger.images(data.clamp(min=0, max=1), name, nrow=nrow, resize=resize or self.resize)

    def file_loader(self, path, resize=None, crop=None, seed=0, jitter=False):
        image_transform = self.load_image_transform(resize=resize, crop=crop, seed=seed, jitter=jitter)
        return image_transform(Image.open(open(path, 'rb')))[0:3]

    def load_image_transform(self, resize=None, crop=None, seed=0, jitter=False):
        size = resize or self.resize
        random.seed(seed)
        jitter_transform = lambda x: x
        if jitter: jitter_transform = transforms.ColorJitter(0.4,0.4,0.4,0.1)
        crop_transform = lambda x: x
        if crop is not None: crop_transform = transforms.CenterCrop(crop)
        blur = [GaussianBulr(self.blur_radius)] if self.blur_radius else []
        return transforms.Compose(blur+[
            crop_transform,
            transforms.Resize(size, interpolation=PIL.Image.BILINEAR),
            jitter_transform,
            transforms.CenterCrop(size),
            transforms.ToTensor(),
            self.transform]
        )

class ImageClassTask(ImageTask):
    """ Output space for image-class segmentation tasks """

    def __init__(self, *args, **kwargs):

        self.classes = kwargs.pop("classes", (3, 256, 256))
        super().__init__(*args, **kwargs)

    def norm(self, pred, target):
        loss = F.kl_div(F.log_softmax(pred, dim=1), F.softmax(target, dim=1))
        return loss, (loss.detach(),)

    def plot_func(self, data, name, logger, resize=None):
        _, idx = torch.max(data, dim=1)
        idx = idx.float()/16.0
        idx = idx.unsqueeze(1).expand(-1, 3, -1, -1)
        logger.images(idx.clamp(min=0, max=1), name, nrow=2, resize=resize or self.resize)

    def file_loader(self, path, resize=None):

        data = (self.image_transform(Image.open(open(path, 'rb')))*255.0).long()
        one_hot = torch.zeros((self.classes, data.shape[1], data.shape[2]))
        one_hot = one_hot.scatter_(0, data, 1)
        return one_hot


class PointInfoTask(Task):
    """ Output space for point-info prediction tasks (what models do we evem use?) """

    def __init__(self, *args, **kwargs):

        self.point_type = kwargs.pop("point_type", "vanishing_points_gaussian_sphere")
        self.out_channels = 9
        super().__init__(*args, **kwargs)

    def plot_func(self, data, name, logger):
        logger.window(name, logger.visdom.text, str(data.data.cpu().numpy()))

    def file_loader(self, path, resize=None):
        points = json.load(open(path))[self.point_type]
        return np.array(points['x'] + points['y'] + points['z'])




"""
Current list of task definitions.
Accessible via tasks.{TASK_NAME} or get_task("{TASK_NAME}")
"""

def clamp_maximum_transform(x, max_val=8000.0):
    x = x.unsqueeze(0).float() / max_val
    return x[0].clamp(min=0, max=1)

def crop_transform(x, max_val=8000.0):
    x = x.unsqueeze(0).float() / max_val
    return x[0].clamp(min=0, max=1)

def sobel_transform(x):
    image = x.data.cpu().numpy().mean(axis=0)
    blur = ndimage.filters.gaussian_filter(image, sigma=2, )
    sx = ndimage.sobel(blur, axis=0, mode='constant')
    sy = ndimage.sobel(blur, axis=1, mode='constant')
    sob = np.hypot(sx, sy)
    edge = torch.FloatTensor(sob).unsqueeze(0)
    return edge

def blur_transform(x, max_val=4000.0):
    if x.shape[0] == 1:
        x = x.squeeze(0)
    image = x.data.cpu().numpy()
    blur = ndimage.filters.gaussian_filter(image, sigma=2, )
    norm = torch.FloatTensor(blur).unsqueeze(0)**0.8 / (max_val**0.8)
    norm = norm.clamp(min=0, max=1)
    if norm.shape[0] != 1:
        norm = norm.unsqueeze(0)
    return norm

def get_task(task_name):
    return task_map[task_name]


tasks = [
    ImageTask('rgb'),
    ImageTask('imagenet', mask_val=0.0),
    ImageTask('normal', mask_val=0.502),
    ImageTask('principal_curvature', mask_val=0.0),
    ImageTask('depth_zbuffer',
        shape=(1, 256, 256),
        mask_val=1.0,
        transform=partial(clamp_maximum_transform, max_val=8000.0),
    ),
    ImageClassTask('segment_semantic',
        file_name_alt="segmentsemantic",
        shape=(16, 256, 256), classes=16,
    ),
    ImageTask('reshading', mask_val=0.0507),
    ImageTask('edge_occlusion',
        shape=(1, 256, 256),
        transform=partial(blur_transform, max_val=4000.0),
    ),
    ImageTask('sobel_edges',
        shape=(1, 256, 256),
        file_name='rgb',
        transform=sobel_transform,
    ),
    ImageTask('keypoints3d',
        shape=(1, 256, 256),
        transform=partial(clamp_maximum_transform, max_val=64131.0),
    ),
    ImageTask('keypoints2d',
        shape=(1, 256, 256),
        transform=partial(blur_transform, max_val=2000.0),
    ),
]


task_map = {task.name: task for task in tasks}
tasks = namedtuple('TaskMap', task_map.keys())(**task_map)


if __name__ == "__main__":
    IPython.embed()


================================================
FILE: tools/download_data.sh
================================================
##!/usr/bin/env bash

wget https://drive.switch.ch/index.php/s/0Fqr6t6cZsI0cp9/download
unzip download
rm download
cd data 
unzip -qqo albertville_rgb.zip
unzip -qqo albertville_normal.zip
unzip -qqo albertville_principal_curvature.zip
unzip -qqo almena_rgb.zip
unzip -qqo almena_normal.zip
unzip -qqo almena_principal_curvature.zip
rm albertville_rgb.zip albertville_normal.zip albertville_principal_curvature.zip almena_rgb.zip almena_normal.zip almena_principal_curvature.zip
cd -


================================================
FILE: tools/download_energy_graph_edges.sh
================================================
##!/usr/bin/env bash

SCRIPT_DIR=$( dirname "$0" )

FILE=./models/rgb2normal_consistency.pth
if [ -f "$FILE" ]; then
    echo "Found consistency network $FILE: skipping download of these networks"
else
    echo "Downloading consistency networks..."
   sh $SCRIPT_DIR/download_models.sh
fi

FILE=./models/normal2curvature.pth
if [ -f "$FILE" ]; then
    echo "Found perceptual network $FILE: skipping download of these networks"
else
   echo "Downloading perceptual networks..."
   sh $SCRIPT_DIR/download_percep_models.sh
fi

FILE=./models/rgb2principal_curvature.pth
if [ -f "$FILE" ]; then
    echo "Found energy-graph specific network $FILE: skipping download of these networks"
else
    echo "RGB2X energy networks..."
    # Get energy-graph-specific links
    wget https://drive.switch.ch/index.php/s/aZDOEBixS4W7mBL/download
    unzip download
    rm download
    mv energy_graph_edges/* models/
    rmdir energy_graph_edges
fi



================================================
FILE: tools/download_models.sh
================================================
##!/usr/bin/env bash

wget https://drive.switch.ch/index.php/s/QPvImzbbdjBKI5P/download
unzip download
rm download


================================================
FILE: tools/download_percep_models.sh
================================================
##!/usr/bin/env bash

wget https://drive.switch.ch/index.php/s/aXu4EFaznqtNzsE/download
unzip download
rm download
mv percep_models/* models/
rmdir percep_models

================================================
FILE: train.py
================================================
'''
  Name: train.py
  Desc: Executes training of a network with the consistency framework.

    Here are some options that may be specified for any model. If they have a
    default value, it is given at the end of the description in parens.

        Data pipeline:
            Data locations:
                'train_buildings': A list of the folders containing the training data. This
                    is defined in configs/split.txt.
                'val_buildings': As above, but for validation data.
                'data_dirs': The folder that all the data is stored in. This may just be
                    something like '/', and then all filenames in 'train_filenames' will
                    give paths relative to 'dataset_dir'. For example, if 'dataset_dir'='/',
                    then train_filenames might have entries like 'path/to/data/img_01.png'.
                    This is defiled in utils.py.

        Logging:
            'results_dir': An absolute path to where checkpoints are saved. This is
                defined in utils.py.

        Training:
            'batch_size': The size of each batch. (64)
            'num_epochs': The maximum number of epochs to train for. (800)
            'energy_config': {multiperceptual_targettask} The paths taken to compute the losses.
            'k': Number of perceptual loss chosen.
            'data_aug': {True, False} If data augmentation shuold be used during training.
                See TrainTaskDataset class in datasets.py for the types of data augmentation
                used. (False)

        Optimization:
            'initial_learning_rate': The initial learning rate to use for the model. (3e-5)


  Usage:
    python -m train multiperceptual_depth --batch-size 32 --k 8 --max-epochs 100
'''

import torch
import torch.nn as nn

from utils import *
from energy import get_energy_loss
from graph import TaskGraph
from logger import Logger, VisdomLogger
from datasets import load_train_val, load_test, load_ood
from task_configs import tasks, RealityTask
from transfers import functional_transfers

from fire import Fire

#import pdb

def main(
    loss_config="multiperceptual", mode="winrate", visualize=False,
    fast=False, batch_size=None,
    subset_size=None, max_epochs=800, dataaug=False, **kwargs,
):


    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size, fast=fast,
        subset_size=subset_size,
        dataaug=dataaug,
    )

    if fast:
        train_dataset = val_dataset
        train_step, val_step = 2,2

    train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)

    if fast:
        train_dataset = val_dataset
        train_step, val_step = 2,2
        realities = [train, val]
    else:
        test_set = load_test(energy_loss.get_tasks("test"), buildings=['almena', 'albertville'])
        test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))
        realities = [train, val, test]
        # If you wanted to just do some qualitative predictions on inputs w/o labels, you could do:
        # ood_set = load_ood(energy_loss.get_tasks("ood"))
        # ood = RealityTask.from_static("ood", ood_set, [tasks.rgb,])
        # realities.append(ood)

    # GRAPH
    graph = TaskGraph(tasks=energy_loss.tasks + realities, pretrained=True, finetuned=False,
        freeze_list=energy_loss.freeze_list,
        initialize_from_transfer=False,
    )
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)

    # LOGGING
    os.makedirs(RESULTS_DIR, exist_ok=True)
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"), feature="epoch", freq=1)
    energy_loss.logger_hooks(logger)
    energy_loss.plot_paths(graph, logger, realities, prefix="start")

    # BASELINE
    graph.eval()
    with torch.no_grad():
        for _ in range(0, val_step*4):
            val_loss, _ = energy_loss(graph, realities=[val])
            val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        for _ in range(0, train_step*4):
            train_loss, _ = energy_loss(graph, realities=[train])
            train_loss = sum([train_loss[loss_name] for loss_name in train_loss])
            train.step()
            logger.update("loss", train_loss)
    energy_loss.logger_update(logger)

    # TRAINING
    for epochs in range(0, max_epochs):

        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph, logger, realities, prefix="")
        if visualize: return

        graph.train()
        for _ in range(0, train_step):
            train_loss, mse_coeff = energy_loss(graph, realities=[train], compute_grad_ratio=True)
            train_loss = sum([train_loss[loss_name] for loss_name in train_loss])
            graph.step(train_loss)
            train.step()
            logger.update("loss", train_loss)

        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss, _ = energy_loss(graph, realities=[val])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        energy_loss.logger_update(logger)

        logger.step()

if __name__ == "__main__":
    Fire(main)


================================================
FILE: transfers.py
================================================

import os, sys, math, random, itertools, functools
from collections import namedtuple
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.checkpoint import checkpoint as util_checkpoint
from torchvision import models

from utils import *
from models import TrainableModel, DataParallelModel
from task_configs import get_task, task_map, get_model, Task, RealityTask

from modules.percep_nets import DenseNet, Dense1by1Net, DenseKernelsNet, DeepNet, BaseNet, WideNet, PyramidNet
from modules.depth_nets import UNetDepth
from modules.unet import UNet, UNetOld, UNetOld2, UNetReshade
from modules.resnet import ResNetClass

from fire import Fire
import IPython


pretrained_transfers = {

    ('normal', 'principal_curvature'):
        (lambda: Dense1by1Net(), f"{MODELS_DIR}/normal2curvature.pth"),
    ('normal', 'depth_zbuffer'):
        (lambda: UNetDepth(), f"{MODELS_DIR}/normal2zdepth_zbuffer.pth"),
    ('normal', 'sobel_edges'):
        (lambda: UNet(out_channels=1, downsample=4).cuda(), f"{MODELS_DIR}/normal2edges2d.pth"),
    ('normal', 'reshading'):
        (lambda: UNetReshade(downsample=5), f"{MODELS_DIR}/normal2reshade.pth"),
    ('normal', 'keypoints3d'):
        (lambda: UNet(downsample=5, out_channels=1), f"{MODELS_DIR}/normal2keypoints3d.pth"),
    ('normal', 'keypoints2d'):
        (lambda: UNet(downsample=5, out_channels=1), f"{MODELS_DIR}/normal2keypoints2d_new.pth"),
    ('normal
Download .txt
gitextract_hs8roq0r/

├── .gitignore
├── Dockerfile
├── README.md
├── config/
│   ├── jobinfo.txt
│   ├── split.txt
│   ├── split_fullplus.txt
│   └── split_medium.txt
├── datasets.py
├── demo.py
├── energy.py
├── graph.py
├── hooks/
│   └── build
├── logger.py
├── models.py
├── modules/
│   ├── __init__.py
│   ├── depth_nets.py
│   ├── percep_nets.py
│   ├── resnet.py
│   ├── unet.py
│   └── unet_mirrored.py
├── plotting.py
├── requirements.txt
├── scripts/
│   ├── energy_calc.py
│   └── jobinfo.txt
├── task_configs.py
├── tools/
│   ├── download_data.sh
│   ├── download_energy_graph_edges.sh
│   ├── download_models.sh
│   └── download_percep_models.sh
├── train.py
├── transfers.py
└── utils.py
Download .txt
SYMBOL INDEX (322 symbols across 17 files)

FILE: datasets.py
  function load_train_val (line 26) | def load_train_val(train_tasks, val_tasks=None, fast=False,
  function load_all (line 60) | def load_all(tasks, buildings=None, batch_size=64, split_file="data/spli...
  function load_test (line 75) | def load_test(all_tasks, buildings=["almena", "albertville"], sample=4):
  function load_ood (line 96) | def load_ood(tasks=[tasks.rgb], ood_path=OOD_DIR, sample=21):
  class TaskDataset (line 107) | class TaskDataset(Dataset):
    method __init__ (line 109) | def __init__(self, buildings, tasks=[get_task("rgb"), get_task("normal...
    method reset_unpaired (line 141) | def reset_unpaired(self):
    method building_files (line 145) | def building_files(self, task, building):
    method building_files_raid (line 149) | def building_files_raid(self, task, building):
    method convert_path (line 152) | def convert_path(self, source_file, task):
    method convert_path_raid (line 165) | def convert_path_raid(self, full_file, task):
    method __len__ (line 173) | def __len__(self):
    method __getitem__ (line 176) | def __getitem__(self, idx):
  class TrainTaskDataset (line 196) | class TrainTaskDataset(TaskDataset):
    method __getitem__ (line 198) | def __getitem__(self, idx):
  class ImageDataset (line 220) | class ImageDataset(Dataset):
    method __init__ (line 222) | def __init__(
    method __len__ (line 243) | def __len__(self):
    method __getitem__ (line 246) | def __getitem__(self, idx):

FILE: demo.py
  function save_outputs (line 52) | def save_outputs(img_path, output_file_name):

FILE: energy.py
  function get_energy_loss (line 25) | def get_energy_loss(
  function generate_config (line 47) | def generate_config(perceptual_tasks, target_task=tasks.normal, tree_str...
  function coeff_hook (line 720) | def coeff_hook(coeff):
  class EnergyLoss (line 726) | class EnergyLoss(object):
    method __init__ (line 728) | def __init__(self, paths, losses, plots,
    method compute_paths (line 747) | def compute_paths(self, graph, reality=None, paths=None):
    method get_tasks (line 758) | def get_tasks(self, reality):
    method __call__ (line 773) | def __call__(self, graph, discriminator=None, realities=[], loss_types...
    method logger_hooks (line 826) | def logger_hooks(self, logger):
    method logger_update (line 850) | def logger_update(self, logger):
    method plot_paths (line 876) | def plot_paths(self, graph, logger, realities=[], plot_names=None, epo...
    method __repr__ (line 957) | def __repr__(self):
  class WinRateEnergyLoss (line 961) | class WinRateEnergyLoss(EnergyLoss):
    method __init__ (line 963) | def __init__(self, *args, **kwargs):
    method __call__ (line 975) | def __call__(self, graph, discriminator=None, realities=[], loss_types...
    method logger_update (line 1013) | def logger_update(self, logger):

FILE: graph.py
  class TaskGraph (line 21) | class TaskGraph(TrainableModel):
    method __init__ (line 25) | def __init__(
    method edge (line 82) | def edge(self, src_task, dest_task):
    method sample_path (line 88) | def sample_path(self, path, reality=None, use_cache=False, cache={}):
    method save (line 107) | def save(self, weights_file=None, weights_dir=None):
    method load_weights (line 130) | def load_weights(self, weights_file=None):

FILE: logger.py
  class BaseLogger (line 17) | class BaseLogger(object):
    method __init__ (line 19) | def __init__(self, name, verbose=True):
    method add_hook (line 28) | def add_hook(self, hook, feature='epoch', freq=40):
    method update (line 31) | def update(self, feature, x):
    method step (line 47) | def step(self):
    method text (line 61) | def text(self, text, end="\n"):
    method plot (line 64) | def plot(self, data, plot_name, opts={}):
    method images (line 67) | def images(self, data, image_name):
    method plot_feature (line 70) | def plot_feature(self, feature, opts={}):
    method plot_features (line 73) | def plot_features(self, features, name, opts={}):
  class Logger (line 78) | class Logger(BaseLogger):
    method __init__ (line 80) | def __init__(self, *args, **kwargs):
    method text (line 84) | def text(self, text, end='\n'):
    method plot (line 87) | def plot(self, data, plot_name, opts={}):
  class VisdomLogger (line 94) | class VisdomLogger(BaseLogger):
    method __init__ (line 96) | def __init__(self, *args, **kwargs):
    method text (line 111) | def text(self, text, end='\n'):
    method window (line 124) | def window(self, plot_name, plot_func, *args, **kwargs):
    method plot (line 136) | def plot(self, data, plot_name, opts={}):
    method histogram (line 141) | def histogram(self, data, plot_name, opts={}):
    method scatter (line 144) | def scatter(self, X, Y, plot_name, opts={}):
    method bar (line 147) | def bar(self, data, plot_name, opts={}):
    method save (line 150) | def save(self):
    method images (line 153) | def images(self, data, plot_name, opts={}, nrow=2, normalize=False, re...
    method images_grouped (line 163) | def images_grouped(self, image_groups, plot_name, **kwargs):

FILE: models.py
  class AbstractModel (line 19) | class AbstractModel(nn.Module):
    method __init__ (line 20) | def __init__(self):
    method compile (line 25) | def compile(self, optimizer=None, **kwargs):
    method predict_on_batch (line 38) | def predict_on_batch(self, data):
    method fit_on_batch (line 45) | def fit_on_batch(self, data, target, loss_fn=None, train=True):
    method step (line 74) | def step(self, loss, train=True):
    method load (line 88) | def load(cls, weights_file=None):
    method load_weights (line 97) | def load_weights(self, weights_file, backward_compatible=False):
    method save (line 103) | def save(self, weights_file):
    method loss (line 107) | def loss(self, pred, target):
    method forward (line 110) | def forward(self, x):
  class TrainableModel (line 119) | class TrainableModel(AbstractModel):
    method __init__ (line 120) | def __init__(self):
    method _process_data (line 124) | def _process_data(self, datagen, loss_fn=None, train=True, logger=None):
    method fit (line 136) | def fit(self, datagen, loss_fn=None, logger=None):
    method fit_with_data (line 140) | def fit_with_data(self, datagen, loss_fn=None, logger=None):
    method fit_with_metrics (line 148) | def fit_with_metrics(self, datagen, loss_fn=None, logger=None):
    method predict_with_data (line 157) | def predict_with_data(self, datagen, loss_fn=None, logger=None):
    method predict_with_metrics (line 167) | def predict_with_metrics(self, datagen, loss_fn=None, logger=None):
    method predict (line 176) | def predict(self, datagen):
  class DataParallelModel (line 182) | class DataParallelModel(TrainableModel):
    method __init__ (line 183) | def __init__(self, *args, **kwargs):
    method forward (line 187) | def forward(self, x):
    method loss (line 190) | def loss(self, x, preds):
    method module (line 194) | def module(self):
    method load (line 198) | def load(cls, model=TrainableModel(), weights_file=None):
  class WrapperModel (line 213) | class WrapperModel(TrainableModel):
    method __init__ (line 214) | def __init__(self, model):
    method forward (line 218) | def forward(self, x):
    method loss (line 221) | def loss(self, x, preds):
    method __getitem__ (line 224) | def __getitem__(self, i):
    method module (line 228) | def module(self):

FILE: modules/depth_nets.py
  class ResidualsNet (line 16) | class ResidualsNet(TrainableModel):
    method __init__ (line 17) | def __init__(self):
    method forward (line 36) | def forward(self, x):
    method loss (line 45) | def loss(self, pred, target):
  class UNet_up_block (line 50) | class UNet_up_block(nn.Module):
    method __init__ (line 51) | def __init__(self, prev_channel, input_channel, output_channel, up_sam...
    method forward (line 63) | def forward(self, prev_feature_map, x):
  class UNet_down_block (line 73) | class UNet_down_block(nn.Module):
    method __init__ (line 74) | def __init__(self, input_channel, output_channel, down_size=True):
    method forward (line 86) | def forward(self, x):
  class UNetDepth (line 96) | class UNetDepth(TrainableModel):
    method __init__ (line 97) | def __init__(self):
    method forward (line 127) | def forward(self, x):
    method loss (line 150) | def loss(self, pred, target):
  class ConvBlock (line 156) | class ConvBlock(nn.Module):
    method __init__ (line 157) | def __init__(self, f1, f2, use_groupnorm=True, groups=8, dilation=1, t...
    method forward (line 170) | def forward(self, x):
  class Bottleneck (line 180) | class Bottleneck(nn.Module):
    method __init__ (line 183) | def __init__(self, inplanes, planes, stride=1, downsample=None):
    method forward (line 196) | def forward(self, x):
  class ResNetOriginal (line 218) | class ResNetOriginal(nn.Module):
    method __init__ (line 220) | def __init__(self, block, layers, num_classes=1000):
    method _make_layer (line 242) | def _make_layer(self, block, planes, blocks, stride=1):
    method forward (line 259) | def forward(self, x):
  class ResNetDepth (line 276) | class ResNetDepth(TrainableModel):
    method __init__ (line 277) | def __init__(self):
    method forward (line 296) | def forward(self, x):
    method loss (line 306) | def loss(self, pred, target):

FILE: modules/percep_nets.py
  class ConvBlock (line 17) | class ConvBlock(nn.Module):
    method __init__ (line 18) | def __init__(self, f1, f2, kernel_size=3, padding=1, use_groupnorm=Tru...
    method forward (line 31) | def forward(self, x):
  class DenseNet (line 42) | class DenseNet(TrainableModel):
    method __init__ (line 43) | def __init__(self):
    method forward (line 54) | def forward(self, x):
    method loss (line 58) | def loss(self, pred, target):
  class Dense1by1Net (line 63) | class Dense1by1Net(TrainableModel):
    method __init__ (line 64) | def __init__(self):
    method forward (line 77) | def forward(self, x):
    method loss (line 81) | def loss(self, pred, target):
  class Dense1by1end (line 85) | class Dense1by1end(TrainableModel):
    method __init__ (line 86) | def __init__(self):
    method forward (line 99) | def forward(self, x):
    method loss (line 103) | def loss(self, pred, target):
  class DenseKernelsNet (line 107) | class DenseKernelsNet(TrainableModel):
    method __init__ (line 108) | def __init__(self, kernel_size=7):
    method forward (line 122) | def forward(self, x):
    method loss (line 126) | def loss(self, pred, target):
  class DeepNet (line 131) | class DeepNet(TrainableModel):
    method __init__ (line 132) | def __init__(self):
    method forward (line 145) | def forward(self, x):
    method loss (line 149) | def loss(self, pred, target):
  class WideNet (line 154) | class WideNet(TrainableModel):
    method __init__ (line 155) | def __init__(self):
    method forward (line 166) | def forward(self, x):
    method loss (line 170) | def loss(self, pred, target):
  class PyramidNet (line 175) | class PyramidNet(TrainableModel):
    method __init__ (line 176) | def __init__(self):
    method forward (line 188) | def forward(self, x):
    method loss (line 192) | def loss(self, pred, target):
  class BaseNet (line 198) | class BaseNet(TrainableModel):
    method __init__ (line 199) | def __init__(self):
    method forward (line 209) | def forward(self, x):
    method loss (line 213) | def loss(self, pred, target):
  class ResidualsNet (line 218) | class ResidualsNet(TrainableModel):
    method __init__ (line 219) | def __init__(self):
    method forward (line 236) | def forward(self, x):
    method loss (line 245) | def loss(self, pred, target):
  class ResNet50 (line 249) | class ResNet50(TrainableModel):
    method __init__ (line 250) | def __init__(self, num_classes=365, in_channels=3):
    method forward (line 256) | def forward(self, x):
    method loss (line 260) | def loss(self, pred, target):

FILE: modules/resnet.py
  function conv3x3 (line 17) | def conv3x3(in_planes, out_planes, stride=1, groups=1):
  function conv1x1 (line 23) | def conv1x1(in_planes, out_planes, stride=1):
  class BasicBlock (line 29) | class BasicBlock(nn.Module):
    method __init__ (line 32) | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
    method forward (line 46) | def forward(self, x):
  class Bottleneck (line 64) | class Bottleneck(nn.Module):
    method __init__ (line 67) | def __init__(self, inplanes, planes, stride=1, downsample=None):
    method forward (line 80) | def forward(self, x):
  class ResNetOriginal (line 102) | class ResNetOriginal(nn.Module):
    method __init__ (line 104) | def __init__(self, block, layers, in_channels=3, num_classes=1000):
    method _make_layer (line 126) | def _make_layer(self, block, planes, blocks, stride=1):
    method forward (line 143) | def forward(self, x):
  class ResNet (line 160) | class ResNet(TrainableModel):
    method __init__ (line 161) | def __init__(self, in_channels=3, out_channels=1000):
    method forward (line 166) | def forward(self, x):
    method loss (line 175) | def loss(self, pred, target):
  class ResNetClass (line 180) | class ResNetClass(TrainableModel):
    method __init__ (line 181) | def __init__(self):
    method forward (line 185) | def forward(self, x):
    method loss (line 192) | def loss(self, pred, target):

FILE: modules/unet.py
  class UNet_up_block (line 17) | class UNet_up_block(nn.Module):
    method __init__ (line 18) | def __init__(self, prev_channel, input_channel, output_channel, up_sam...
    method forward (line 30) | def forward(self, prev_feature_map, x):
  class UNet_down_block (line 40) | class UNet_down_block(nn.Module):
    method __init__ (line 41) | def __init__(self, input_channel, output_channel, down_size=True):
    method forward (line 53) | def forward(self, x):
  class UNet (line 62) | class UNet(TrainableModel):
    method __init__ (line 63) | def __init__(self,  downsample=6, in_channels=3, out_channels=3):
    method forward (line 89) | def forward(self, x):
    method loss (line 107) | def loss(self, pred, target):
  class UNetReshade (line 111) | class UNetReshade(TrainableModel):
    method __init__ (line 112) | def __init__(self,  downsample=6, in_channels=3, out_channels=3):
    method forward (line 138) | def forward(self, x):
    method loss (line 158) | def loss(self, pred, target):
  class UNetOld (line 163) | class UNetOld(TrainableModel):
    method __init__ (line 164) | def __init__(self, in_channels=3, out_channels=3):
    method forward (line 195) | def forward(self, x):
    method loss (line 218) | def loss(self, pred, target):
  class ConvBlock (line 223) | class ConvBlock(nn.Module):
    method __init__ (line 224) | def __init__(self, f1, f2, kernel_size=3, padding=1, use_groupnorm=Tru...
    method forward (line 237) | def forward(self, x):
  class UNetOld2 (line 247) | class UNetOld2(TrainableModel):
    method __init__ (line 248) | def __init__(self, in_channels=3, out_channels=3):
    method forward (line 283) | def forward(self, x):
    method loss (line 307) | def loss(self, pred, target):

FILE: modules/unet_mirrored.py
  class UNet_up_block (line 17) | class UNet_up_block(nn.Module):
    method __init__ (line 18) | def __init__(self, prev_channel, input_channel, output_channel, up_sam...
    method forward (line 30) | def forward(self, prev_feature_map, x):
  class UNet_down_block (line 40) | class UNet_down_block(nn.Module):
    method __init__ (line 41) | def __init__(self, input_channel, output_channel, down_size=True):
    method forward (line 53) | def forward(self, x):
  class UNet (line 62) | class UNet(TrainableModel):
    method __init__ (line 63) | def __init__(self, downsample=6, in_channels=3, out_channels=3):
    method forward (line 89) | def forward(self, x):
    method loss (line 114) | def loss(self, pred, target):
  class UNetReshade (line 118) | class UNetReshade(TrainableModel):
    method __init__ (line 119) | def __init__(self,  downsample=6, in_channels=3, out_channels=3):
    method forward (line 145) | def forward(self, x):
    method loss (line 165) | def loss(self, pred, target):
  class UNetOld (line 170) | class UNetOld(TrainableModel):
    method __init__ (line 171) | def __init__(self, in_channels=3, out_channels=3):
    method forward (line 202) | def forward(self, x):
    method loss (line 225) | def loss(self, pred, target):
  class ConvBlock (line 230) | class ConvBlock(nn.Module):
    method __init__ (line 231) | def __init__(self, f1, f2, kernel_size=3, padding=1, use_groupnorm=Tru...
    method forward (line 244) | def forward(self, x):
  class UNetOld2 (line 254) | class UNetOld2(TrainableModel):
    method __init__ (line 255) | def __init__(self, in_channels=3, out_channels=3):
    method forward (line 290) | def forward(self, x):
    method loss (line 314) | def loss(self, pred, target):

FILE: plotting.py
  function jointplot (line 3) | def jointplot(logger, data, loss_type="mse_loss"):
  function get_running_means_w_std_bounds_and_legend_on_diff_prev_time_step (line 7) | def get_running_means_w_std_bounds_and_legend_on_diff_prev_time_step(lis...
  function get_running_means_w_std_bounds_and_legend (line 23) | def get_running_means_w_std_bounds_and_legend(list_of_list_values):
  function get_running_std (line 35) | def get_running_std(list_of_list_values):
  function get_running_p_coeffs (line 39) | def get_running_p_coeffs(list_of_list_values_1, list_of_list_values_2):
  function mseplots (line 54) | def mseplots(data, logger):
  function curvatureplots (line 69) | def curvatureplots(data, logger):
  function depthplots (line 85) | def depthplots(data, logger):
  function covarianceplot (line 100) | def covarianceplot(data, logger):

FILE: scripts/energy_calc.py
  function main (line 30) | def main(

FILE: task_configs.py
  class GaussianBulr (line 30) | class GaussianBulr(object):
    method __init__ (line 31) | def __init__(self, radius):
    method __call__ (line 35) | def __call__(self, im):
    method __repr__ (line 38) | def __repr__(self):
  function get_model (line 56) | def get_model(src_task, dest_task):
  class Task (line 82) | class Task(object):
    method __init__ (line 94) | def __init__(self, name,
    method norm (line 108) | def norm(self, pred, target, batch_mean=True, compute_mse=True):
    method __call__ (line 117) | def __call__(self, size=256):
    method plot_func (line 121) | def plot_func(self, data, name, logger, **kwargs):
    method file_loader (line 125) | def file_loader(self, path, resize=None, seed=0, T=0):
    method __eq__ (line 128) | def __eq__(self, other):
    method __repr__ (line 131) | def __repr__(self):
    method __hash__ (line 134) | def __hash__(self):
  class RealityTask (line 143) | class RealityTask(Task):
    method __init__ (line 146) | def __init__(self, name, dataset, tasks=None, use_dataset=True, shuffl...
    method from_dataloader (line 163) | def from_dataloader(cls, name, loader, tasks):
    method from_static (line 172) | def from_static(cls, name, data, tasks):
    method norm (line 178) | def norm(self, pred, target, batch_mean=True):
    method step (line 182) | def step(self):
    method reload (line 185) | def reload(self):
  class ImageTask (line 192) | class ImageTask(Task):
    method __init__ (line 195) | def __init__(self, *args, **kwargs):
    method build_mask (line 206) | def build_mask(target, val=0.0, tol=1e-3):
    method norm (line 219) | def norm(self, pred, target, batch_mean=True, compute_mask=0, compute_...
    method __call__ (line 226) | def __call__(self, size=256, blur_radius=None):
    method plot_func (line 235) | def plot_func(self, data, name, logger, resize=None, nrow=2):
    method file_loader (line 238) | def file_loader(self, path, resize=None, crop=None, seed=0, jitter=Fal...
    method load_image_transform (line 242) | def load_image_transform(self, resize=None, crop=None, seed=0, jitter=...
  class ImageClassTask (line 259) | class ImageClassTask(ImageTask):
    method __init__ (line 262) | def __init__(self, *args, **kwargs):
    method norm (line 267) | def norm(self, pred, target):
    method plot_func (line 271) | def plot_func(self, data, name, logger, resize=None):
    method file_loader (line 277) | def file_loader(self, path, resize=None):
  class PointInfoTask (line 285) | class PointInfoTask(Task):
    method __init__ (line 288) | def __init__(self, *args, **kwargs):
    method plot_func (line 294) | def plot_func(self, data, name, logger):
    method file_loader (line 297) | def file_loader(self, path, resize=None):
  function clamp_maximum_transform (line 309) | def clamp_maximum_transform(x, max_val=8000.0):
  function crop_transform (line 313) | def crop_transform(x, max_val=8000.0):
  function sobel_transform (line 317) | def sobel_transform(x):
  function blur_transform (line 326) | def blur_transform(x, max_val=4000.0):
  function get_task (line 337) | def get_task(task_name):

FILE: train.py
  function main (line 55) | def main(

FILE: transfers.py
  class Transfer (line 117) | class Transfer(nn.Module):
    method __init__ (line 119) | def __init__(self, src_task, dest_task,
    method load_model (line 168) | def load_model(self):
    method __call__ (line 180) | def __call__(self, x):
    method __repr__ (line 186) | def __repr__(self):
  class RealityTransfer (line 190) | class RealityTransfer(Transfer):
    method __init__ (line 192) | def __init__(self, src_task, dest_task):
    method load_model (line 195) | def load_model(self, optimizer=True):
    method __call__ (line 198) | def __call__(self, x):
  class FineTunedTransfer (line 203) | class FineTunedTransfer(Transfer):
    method __init__ (line 205) | def __init__(self, transfer):
    method load_model (line 209) | def load_model(self, parents=[]):
    method __call__ (line 224) | def __call__(self, x):
  function get_transfer_name (line 287) | def get_transfer_name(transfer):

FILE: utils.py
  function both (line 34) | def both(x, y):
  function elapsed (line 39) | def elapsed(last_time=[time.time()]):
  function cycle (line 46) | def cycle(iterable):
  function average (line 52) | def average(arr):
  function get_files (line 64) | def get_files(exp, data_dirs=DATA_DIRS, recursive=False):
  function get_finetuned_model_path (line 82) | def get_finetuned_model_path(parents):
  function plot_images (line 89) | def plot_images(model, logger, test_set, dest_task="normal",
  function gaussian_filter (line 121) | def gaussian_filter(channels=3, kernel_size=5, sigma=1.0, device=0):
  function motion_blur_filter (line 140) | def motion_blur_filter(kernel_size=15):
  function sobel_kernel (line 150) | def sobel_kernel(x):
  class SobelKernel (line 164) | class SobelKernel(nn.Module):
    method __init__ (line 165) | def __init__(self):
    method forward (line 168) | def forward(self, x):
  function set_seed (line 171) | def set_seed(seed):
Condensed preview — 32 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (222K chars).
[
  {
    "path": ".gitignore",
    "chars": 183,
    "preview": ".vscode\nraw\n__pycache__/*\nsftp-config*.json\nmodels\nprocessed\n*.pth\n*.tar\noutput\n*.gz\n.DS_Store/*\nresult\ncheckpoints\n*.py"
  },
  {
    "path": "Dockerfile",
    "chars": 3368,
    "preview": "FROM nvidia/cuda:10.1-base-ubuntu16.04\nLABEL version=\"1.0\"\nLABEL description=\"Build using the command \\\n  'docker build "
  },
  {
    "path": "README.md",
    "chars": 27130,
    "preview": "# Robust Learning Through Cross-Task Consistency <br> \n\n[![](./assets/intro.jpg)](https://consistency.epfl.ch)\n\n<table>\n"
  },
  {
    "path": "config/jobinfo.txt",
    "chars": 27,
    "preview": "normaltarget_allperceps, .\n"
  },
  {
    "path": "config/split.txt",
    "chars": 5350,
    "preview": "train_buildings: [woodbine, hometown, haymarket, emmaus, swormville, haxtun, martinville,\n  winfield, marksville, hammon"
  },
  {
    "path": "config/split_fullplus.txt",
    "chars": 4466,
    "preview": "train_buildings: [hanson, merom, arbutus, goodfield, eagan, arona, adairsville, reserve, aloha, castor, munsons, ballou,"
  },
  {
    "path": "config/split_medium.txt",
    "chars": 1212,
    "preview": "train_buildings: [hanson, merom, goodfield, eagan, adairsville, castor, klickitat, cottonport, tyler, sugarville, martin"
  },
  {
    "path": "datasets.py",
    "chars": 10131,
    "preview": "\nimport numpy as np\nimport matplotlib as mpl\n\nimport os, sys, math, random, tarfile, glob, time, yaml, itertools\nimport "
  },
  {
    "path": "demo.py",
    "chars": 2381,
    "preview": "import torch\nfrom torchvision import transforms\n\nfrom modules.unet import UNet, UNetReshade\n\nimport PIL\nfrom PIL import "
  },
  {
    "path": "energy.py",
    "chars": 37803,
    "preview": "import os, sys, math, random, itertools\nfrom functools import partial\nimport numpy as np\n\nimport torch\nimport torch.nn a"
  },
  {
    "path": "graph.py",
    "chars": 5646,
    "preview": "import os, sys, math, random, itertools, heapq\nfrom collections import namedtuple, defaultdict\nfrom functools import par"
  },
  {
    "path": "hooks/build",
    "chars": 160,
    "preview": "#!/bin/bash\n\ndocker build . -t $IMAGE_NAME --build-arg GITHUB_DEPLOY_KEY=\"$GITHUB_DEPLOY_KEY\" --build-arg GITHUB_DEPLOY_"
  },
  {
    "path": "logger.py",
    "chars": 5799,
    "preview": "\nimport numpy as np\nimport matplotlib as mpl\nmpl.use('Agg')\nimport matplotlib.pyplot as plt\nimport random, sys, os, json"
  },
  {
    "path": "models.py",
    "chars": 7288,
    "preview": "import os, sys, random\nfrom inspect import signature\nimport numpy as np\nimport matplotlib as mpl\nimport torch\nimport tor"
  },
  {
    "path": "modules/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "modules/depth_nets.py",
    "chars": 10887,
    "preview": "\nimport os, sys, math, random, itertools\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.function"
  },
  {
    "path": "modules/percep_nets.py",
    "chars": 7616,
    "preview": "\n\nimport os, sys, math, random, itertools\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functio"
  },
  {
    "path": "modules/resnet.py",
    "chars": 6341,
    "preview": "\nimport os, sys, math, random, itertools\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.function"
  },
  {
    "path": "modules/unet.py",
    "chars": 12172,
    "preview": "\nimport os, sys, math, random, itertools\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.function"
  },
  {
    "path": "modules/unet_mirrored.py",
    "chars": 12718,
    "preview": "\nimport os, sys, math, random, itertools\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.function"
  },
  {
    "path": "plotting.py",
    "chars": 6195,
    "preview": "import numpy as np\n\ndef jointplot(logger, data, loss_type=\"mse_loss\"):\n    data = np.stack((data[f\"train_{loss_type}\"], "
  },
  {
    "path": "requirements.txt",
    "chars": 279,
    "preview": "fire==0.2.1\nipython==6.5.0\nmatplotlib==3.0.3\nnumpy==1.17.2\nparse==1.12.1\npip==19.3.1\nplac==0.9.6\npy==1.6.0\nscipy==1.3.1\n"
  },
  {
    "path": "scripts/energy_calc.py",
    "chars": 6645,
    "preview": "import os, sys, math, random, itertools\nimport numpy as np\nimport scipy\nfrom collections import defaultdict\nfrom tqdm im"
  },
  {
    "path": "scripts/jobinfo.txt",
    "chars": 51,
    "preview": "CH_lbp_all_winrate_depthtarget_1, 0, /scratch-data\n"
  },
  {
    "path": "task_configs.py",
    "chars": 13409,
    "preview": "\nimport numpy as np\nimport random, sys, os, time, glob, math, itertools, json, copy\nfrom collections import defaultdict,"
  },
  {
    "path": "tools/download_data.sh",
    "chars": 484,
    "preview": "##!/usr/bin/env bash\n\nwget https://drive.switch.ch/index.php/s/0Fqr6t6cZsI0cp9/download\nunzip download\nrm download\ncd da"
  },
  {
    "path": "tools/download_energy_graph_edges.sh",
    "chars": 935,
    "preview": "##!/usr/bin/env bash\n\nSCRIPT_DIR=$( dirname \"$0\" )\n\nFILE=./models/rgb2normal_consistency.pth\nif [ -f \"$FILE\" ]; then\n   "
  },
  {
    "path": "tools/download_models.sh",
    "chars": 115,
    "preview": "##!/usr/bin/env bash\n\nwget https://drive.switch.ch/index.php/s/QPvImzbbdjBKI5P/download\nunzip download\nrm download\n"
  },
  {
    "path": "tools/download_percep_models.sh",
    "chars": 161,
    "preview": "##!/usr/bin/env bash\n\nwget https://drive.switch.ch/index.php/s/aXu4EFaznqtNzsE/download\nunzip download\nrm download\nmv pe"
  },
  {
    "path": "train.py",
    "chars": 5763,
    "preview": "'''\n  Name: train.py\n  Desc: Executes training of a network with the consistency framework.\n\n    Here are some options t"
  },
  {
    "path": "transfers.py",
    "chars": 12485,
    "preview": "\nimport os, sys, math, random, itertools, functools\nfrom collections import namedtuple\nimport numpy as np\n\nimport torch\n"
  },
  {
    "path": "utils.py",
    "chars": 6015,
    "preview": "\nimport numpy as np\nimport random, sys, os, time, glob, math, itertools, pickle\nimport parse\nfrom collections import def"
  }
]

About this extraction

This page contains the full source code of the EPFL-VILAB/XTConsistency GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 32 files (208.2 KB), approximately 55.5k tokens, and a symbol index with 322 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!