Repository: EPFL-VILAB/XTConsistency Branch: master Commit: 6d051610358d Files: 32 Total size: 208.2 KB Directory structure: gitextract_hs8roq0r/ ├── .gitignore ├── Dockerfile ├── README.md ├── config/ │ ├── jobinfo.txt │ ├── split.txt │ ├── split_fullplus.txt │ └── split_medium.txt ├── datasets.py ├── demo.py ├── energy.py ├── graph.py ├── hooks/ │ └── build ├── logger.py ├── models.py ├── modules/ │ ├── __init__.py │ ├── depth_nets.py │ ├── percep_nets.py │ ├── resnet.py │ ├── unet.py │ └── unet_mirrored.py ├── plotting.py ├── requirements.txt ├── scripts/ │ ├── energy_calc.py │ └── jobinfo.txt ├── task_configs.py ├── tools/ │ ├── download_data.sh │ ├── download_energy_graph_edges.sh │ ├── download_models.sh │ └── download_percep_models.sh ├── train.py ├── transfers.py └── utils.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ .vscode raw __pycache__/* sftp-config*.json models processed *.pth *.tar output *.gz .DS_Store/* result checkpoints *.pyc data/results .DS_Store local command.txt *.ipynb_checkpoints ================================================ FILE: Dockerfile ================================================ FROM nvidia/cuda:10.1-base-ubuntu16.04 LABEL version="1.0" LABEL description="Build using the command \ 'docker build -t epflvil/xtconsistency:latest .'" ARG DEFAULT_GIT_BRANCH=master ARG DEFAULT_GIT_REPO=git@github.com:EPFL-VIL/XTConsistency.git ARG GITHUB_DEPLOY_KEY_PATH=docker_key ARG GITHUB_DEPLOY_KEY ARG GITHUB_DEPLOY_KEY_PUBLIC RUN apt-get update && apt-get install -y \ curl \ wget \ ca-certificates \ sudo \ git \ unzip \ bzip2 \ libx11-6 \ nano \ screen \ gcc \ python3-dev \ && rm -rf /var/lib/apt/lists/* RUN mkdir /root/.ssh RUN echo "DEPLOY" "${GITHUB_DEPLOY_KEY}" RUN echo "DEPLOY" "${GITHUB_DEPLOY_KEY_PUBLIC}" RUN echo "${GITHUB_DEPLOY_KEY}" > /root/.ssh/id_rsa 
RUN echo "${GITHUB_DEPLOY_KEY_PUBLIC}" > /root/.ssh/id_rsa.pub RUN chmod 600 /root/.ssh/id_rsa RUN cat /root/.ssh/id_rsa* RUN eval $(ssh-agent) && \ ssh-add /root/.ssh/id_rsa && \ ssh-keyscan -H github.com >> /etc/ssh/ssh_known_hosts RUN git clone --single-branch --branch "${DEFAULT_GIT_BRANCH}" "${DEFAULT_GIT_REPO}" /app ############################# # Pull code ############################# # RUN mkdir /app WORKDIR /app RUN cd /app && git config core.filemode false RUN chmod -R 777 /app ############################# # Create non-root user ############################# # Create a non-root user and switch to it RUN adduser --disabled-password --gecos '' --shell /bin/bash user \ && chown -R user:user /app RUN echo "user ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/90-user USER user # All users can use /home/user as their home directory ENV HOME=/home/user RUN chmod 777 /home/user ############################# # Create conda environment ############################# # Install Miniconda RUN curl -Lso ~/miniconda.sh https://repo.continuum.io/miniconda/Miniconda3-4.5.11-Linux-x86_64.sh \ && chmod +x ~/miniconda.sh \ && ~/miniconda.sh -b -p ~/miniconda \ && rm ~/miniconda.sh ENV PATH=/home/user/miniconda/bin:$PATH ENV CONDA_AUTO_UPDATE_CONDA=false # Create a Python 3.6 environment RUN /home/user/miniconda/bin/conda create -y --name py36 python=3.6.9 \ && /home/user/miniconda/bin/conda clean -ya ENV CONDA_DEFAULT_ENV=py36 ENV CONDA_PREFIX=/home/user/miniconda/envs/$CONDA_DEFAULT_ENV ENV PATH=$CONDA_PREFIX/bin:$PATH RUN /home/user/miniconda/bin/conda install conda-build=3.18.9=py36_3 \ && /home/user/miniconda/bin/conda clean -ya ############################# # Python packages ############################# RUN conda install -y -c pytorch \ cudatoolkit=10.1 \ "pytorch=1.4.0" \ "torchvision=0.5.0" \ && conda clean -ya RUN conda install -y \ ipython==6.5.0 \ matplotlib==3.0.3 \ plac==0.9.6 \ py==1.6.0 \ scipy==1.3.1 \ tqdm==4.36.1 \ pathlib==1.0.1 \ seaborn==0.10.0 \ 
scikit-learn==0.22.1 \ scikit-image==0.16.2 \ && conda clean -ya RUN conda install -c conda-forge jupyterlab && conda clean -ya RUN pip install runstats==1.8.0 \ fire==0.2.1 \ visdom==0.1.8.9 \ parse==1.12.1 ############################################### # Default command and environment variables ############################################### RUN sudo touch /root/.bashrc && sudo chmod 770 /root/.bashrc RUN echo export PATH="\$PATH:"$PATH >> /tmp/.bashrc RUN sudo su -c 'cat /tmp/.bashrc >> /root/.bashrc' && rm /tmp/.bashrc # Set the default command to bash CMD ["bash"] ================================================ FILE: README.md ================================================ # Robust Learning Through Cross-Task Consistency
[![](./assets/intro.jpg)](https://consistency.epfl.ch)
Above: A comparison of the results from consistency-based learning and learning each task individually. The yellow markers highlight the improvement in fine grained details.

This repository contains tools for training and evaluating models using consistency: - [Pretrained models](#pretrained-models) - [Demo code](#quickstart-run-demo-locally) and an **[online live demo](https://consistency.epfl.ch/demo/)** - [_Uncertainty energy_ estimation code](#Energy-computation) - [Training scripts](#training) - [Docker and installation instructions](#installation) for the following paper:

Robust Learning Through Cross-Task Consistency (CVPR 2020, Best Paper Award Nomination, Oral)


[![Cross-Task Consistency Results](./assets/vid_thumbnail_600_gif2.gif)](https://youtu.be/6CSmmrBNX9M "Click to watch results of applying the networks trained with cross-task consistency frame-by-frame on sample YouTube videos.") For further details, a [live demo](https://consistency.epfl.ch/demo/), [video visualizations](https://consistency.epfl.ch/visuals/), and an [overview talk](https://consistency.epfl.ch/#paper), refer to our [project website](https://consistency.epfl.ch/). #### PROJECT WEBSITE:
| [LIVE DEMO](https://consistency.epfl.ch/demo/) | [VIDEO VISUALIZATION](https://consistency.epfl.ch/visuals/) |:----:|:----:| | Upload your own images and see the results of different consistency-based models vs. various baselines.

[](https://consistency.epfl.ch/demo/) | Visualize models with and without consistency, evaluated on a (non-cherry picked) YouTube video.


[](https://consistency.epfl.ch/visuals/) |
--- Table of Contents ================= * [Introduction](#introduction) * [Installation](#installation) * [Quickstart (demo code)](#quickstart-run-demo-locally) * [Energy computation](#energy-computation) * [Download all pretrained models](#pretrained-models) * [Train a consistency model](#training) * [Instructions for training](#steps) * [To train on other configurations](#to-train-on-other-target-domains) * [Citing](#citation)
## Introduction Visual perception entails solving a wide set of tasks (e.g. object detection, depth estimation, etc). The predictions made for each task out of a particular observation are not independent, and therefore, are expected to be **consistent**. **What is consistency?** Suppose an object detector detects a ball in a particular region of an image, while a depth estimator returns a flat surface for the same region. This presents an issue -- at least one of them has to be wrong, because they are _inconsistent_. **Why is it important?** 1. Desired learning tasks are usually predictions of different aspects of a single underlying reality (the scene that underlies an image). Inconsistency among predictions implies contradiction. 2. Consistency constraints are informative and can be used to better fit the data or lower the sample complexity. They may also reduce the tendency of neural networks to learn "surface statistics" (superficial cues) by enforcing constraints rooted in different physical or geometric rules. This is empirically supported by the improved generalization of models when trained with consistency constraints. **How do we enforce it?** The underlying concept is that of path independence in a network of tasks. Given an endpoint `Y2`, the path from `X->Y1->Y2` should give the same results as `X->Y2`. This can be generalized to a larger system, with paths of arbitrary lengths. In this case, the nodes of the graph are our prediction domains (eg. depth, normal) and the edges are neural networks mapping these domains. This repository includes [training](#training) code for enforcing cross task consistency, [demo](#run-demo-script) code for visualizing the results of a consistency trained model on a given image and [links](#pretrained-models) to download these models. For further details, refer to our [paper]() or [website](https://consistency.epfl.ch/). #### Consistency Domains Consistency constraints can be used for virtually any set of domains. 
This repository considers transferring between image domains, and our networks were trained for transferring between the following domains from the [Taskonomy dataset](https://github.com/StanfordVL/taskonomy/tree/master/data). Curvature Edge-3D Reshading Depth-ZBuffer Keypoint-2D RGB Edge-2D Keypoint-3D Surface-Normal The repo contains consistency-trained models for `RGB -> Surface-Normals`, `RGB -> Depth-ZBuffer`, and `RGB -> Reshading`. In each case the remaining 7 domains are used as consistency constraints during training. Descriptions for each domain can be found in the [supplementary file](http://taskonomy.stanford.edu/taskonomy_supp_CVPR2018.pdf) of Taskonomy. #### Network Architecture All networks are based on the [UNet](https://arxiv.org/pdf/1505.04597.pdf) architecture. They take an input of size 256x256, upsample via bilinear interpolation instead of deconvolutions, and are trained with the L1 loss. See the table below for more information. | Task Name | Output Dimension | Downsample Blocks | |-------------------------|------------------|-------------------| | `RGB -> Depth-ZBuffer` | 256x256x1 | 6 | | `RGB -> Reshading` | 256x256x1 | 5 | | `RGB -> Surface-Normal` | 256x256x3 | 6 | Other networks (e.g. `Curvature -> Surface-Normal`) use a UNet; their architecture hyperparameters are detailed in [transfers.py](./transfers.py). More information on the models, including download links, can be found [here](#pretrained-models) and in the [supplementary material](https://consistency.epfl.ch/supplementary_material).

## Installation There are two convenient ways to run the code. Either using Docker (recommended) or using a Python-specific tool such as pip, conda, or virtualenv. #### Installation via Docker [Recommended] We provide a docker that contains the code and all the necessary libraries. It's simple to install and run. 1. Simply run: ```bash docker run --runtime=nvidia -ti --rm epflvilab/xtconsistency:latest ``` The code is now available in the docker under your home directory (`/app`), and all the necessary libraries should already be installed in the docker. #### Installation via Pip/Conda/Virtualenv The code can also be run using a Python environment manager such as Conda. See [requirements.txt](./requirements.txt) for complete list of packages. We recommend doing a clean installation of requirements using virtualenv: 1. Clone the repo: ```bash git clone git@github.com:EPFL-VILAB/XTConsistency.git cd XTConsistency ``` 2. Create a new environment and install the libraries: ```bash conda create -n testenv -y python=3.6 source activate testenv pip install -r requirements.txt ```

## Quickstart (Run Demo Locally) #### Download the consistency trained networks If you haven't yet, then download the [pretrained models](#Download-consistency-trained-models). Models used for the demo can be downloaded with the following command: ```bash sh ./tools/download_models.sh ``` This downloads the `baseline`, `consistency` trained models for `depth`, `normal` and `reshading` target (1.3GB) to a folder called `./models/`. Individual models can be downloaded [here](https://drive.switch.ch/index.php/s/QPvImzbbdjBKI5P). #### Run a model on your own image To run the trained model of a task on a specific image: ```bash python demo.py --task $TASK --img_path $PATH_TO_IMAGE_OR_FOLDER --output_path $PATH_TO_SAVE_OUTPUT ``` The `--task` flag specifies the target task for the input image, which should be either `normal`, `depth` or `reshading`. To run the script for a `normal` target on the [example image](./assets/test.png): ```bash python demo.py --task normal --img_path assets/test.png --output_path assets/ ``` It returns the output prediction from the baseline (`test_normal_baseline.png`) and consistency models (`test_normal_consistency.png`). Test image | Baseline | Consistency :-------------------------:|:-------------------------: |:-------------------------: ![](./assets/test_scaled.png)| ![](./assets/test_normal_baseline.png) | ![](./assets/test_normal_consistency.png) Similarly, running for target tasks `reshading` and `depth` gives the following. Baseline (reshading) | Consistency (reshading) | Baseline (depth) | Consistency (depth) :-------------------------: |:-------------------------: | :-------------------------: |:-------------------------: ![](./assets/test_reshading_baseline.png) | ![](./assets/test_reshading_consistency.png) | ![](./assets/test_depth_baseline.png) | ![](./assets/test_depth_consistency.png)

## Energy Computation Training with consistency involves several paths that each predict the target domain, but using different cues to do so. The disagreement between these predictions yields an unsupervised quantity, _consistency energy_, that our CVPR 2020 paper found correlates with prediction error. You can view the pixel-wise _consistency energy_ (example below) using our [live demo](https://consistency.epfl.ch/demo/). | Sample Image | Normal Prediction | Consistency Energy | |:------------------------------------:|:------------------------------------:|:------------------------------------:| | | | | | _Sample image from the Stanford 2D3DS dataset._ | _Some chair legs are missing in the `RGB -> Normal` prediction._ | _The white pixels indicate higher uncertainty about areas with missing chair legs._ | To compute energy locally, over many images, and/or to plot energy vs error, you can use the following `energy_calc.py` script. For example, to reproduce the following scatterplot using `energy_calc.py`: | Energy vs. Error | |:----------------------------------------:| | ![](./assets/energy_vs_error.jpg) | | _Result from running the command below._ | First download a subset of images from the Taskonomy buildings `almena` and `albertville` (512 images per domain, 388MB): ```bash sh ./tools/download_data.sh ``` Second, download all the networks necessary to compute the consistency energy. The following script will download them for you (skipping previously downloaded models) (0.8GB - 4.0GB): ```bash sh ./tools/download_energy_graph_edges.sh ``` Now we are ready to compute energy. The following command generates a scatter plot of _consistency energy_ vs. prediction error: ```bash python -m scripts.energy_calc energy_calc --batch_size 2 --subset_size=128 --save_dir=results ``` By default, it computes the energy and error of the `subset_size` number of points on the Taskonomy buildings `almena` and `albertville`. The error is computed for the `normal` target. 
The resulting plot is saved to `energy.pdf` in `RESULTS_DIR` and the corresponding data to `data.csv`. #### Compute energy on arbitrary images _Consistency energy_ is an unsupervised quantity and as such, no ground-truth labels are necessary. To compute the energy for all query images in a directory, run: ```bash python -m scripts.energy_calc energy_calc_nogt --data-dir=PATH_TO_QUERY_IMAGE --batch_size 1 --save_dir=RESULTS_DIR \ --subset_size=NUMBER_OF_IMAGES --cont=PATH_TO_TRAINED_MODEL ``` It will append a dashed horizontal line to the plot above where the energy of the query image(s) are. This plot is saved to `energy.pdf` in `RESULTS_DIR`.

## Pretrained Models We are providing all of our pretrained models for download. These models are the same ones used in the [live demo](https://consistency.epfl.ch/demo/) and [video evaluations](https://consistency.epfl.ch/visuals/). #### Network Architecture All networks are based on the [UNet](https://arxiv.org/pdf/1505.04597.pdf) architecture. They take in an input size of 256x256, upsampling is done via bilinear interpolations instead of deconvolutions. All models were trained with the L1 loss. #### Download consistency-trained models Instructions for downloading the trained consistency models can be found [here](#download-consistency-trained-networks) ```bash sh ./tools/download_models.sh ``` This downloads the `baseline`, `consistency` trained models for `depth`, `normal` and `reshading` target (1.3GB) to a folder called `./models/`. See the table below for specifics: | Task Name | Output Dimension | Downsample Blocks | |-------------------------|------------------|-------------------| | `RGB -> Depth-ZBuffer` | 256x256x1 | 6 | | `RGB -> Reshading` | 256x256x1 | 5 | | `RGB -> Surface-Normal` | 256x256x3 | 6 | Individual consistency models can be downloaded [here](https://drive.switch.ch/index.php/s/QPvImzbbdjBKI5P). #### Download perceptual networks The pretrained perceptual models can be downloaded with the following command. ```bash sh ./tools/download_percep_models.sh ``` This downloads the perceptual models for the `depth`, `normal` and `reshading` target (1.6GB). Each target has 7 pretrained models (from the other sources below). ``` Curvature Edge-3D Reshading Depth-ZBuffer Keypoint-2D RGB Edge-2D Keypoint-3D Surface-Normal ``` Perceptual model architectural hyperparameters are detailed in [transfers.py](./transfers.py), and some of the pretrained models were trained using L2 loss. For using these models with the provided training code, the pretrained models should be placed in the file path defined by `MODELS_DIR` in [utils.py](./utils.py#L25). 
Individual perceptual models can be downloaded [here](https://drive.switch.ch/index.php/s/aXu4EFaznqtNzsE). #### Download baselines We also provide the models for other baselines used in the paper. Many of these baselines appear in the [live demo](https://consistency.epfl.ch/demo/). The pretrained baselines can be downloaded [here](https://drive.switch.ch/index.php/s/gdom4FpiiYo1Qay). Note that we will not be providing support for them. - A full list of baselines is in the table below: | Baseline Method | Description | Tasks (RGB -> X) | |---------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------|--------------------------| | Baseline UNet [[PDF](https://arxiv.org/pdf/1505.04597.pdf)] | UNets trained on the Taskonomy dataset. | Normal, Reshade, Depth | | Baseline Perceptual Loss | Trained using a randomly initialized percepual network, similar to [RND](http://arxiv.org/pdf/1810.12894.pdf). | Normal | | Cycle Consistency [[PDF](https://arxiv.org/pdf/1703.10593.pdf)] | A CycleGAN trained on the Taskonomy dataset. | Normal | | GeoNet [[PDF](https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=8578135_)] | Trained on the Taskonomy dataset and using L1 instead of L2 loss. | Normal, Depth | | Multi-Task [[PDF](http://arxiv.org/pdf/1609.02132.pdf)] | A multi-task model we trained using UNets, using a shared encoder (similar to [here](http://arxiv.org/pdf/1609.02132.pdf)) | [All](#consistency-domains) | | Pix2Pix [[PDF](https://arxiv.org/pdf/1611.07004.pdf)] | A Pix2Pix model trained on the Taskonomy dataset. | Normal | | Taskonomy [[PDF](https://arxiv.org/pdf/1804.08328.pdf)] | Pretrained models (converted to pytorch [here](https://github.com/alexsax/midlevel-reps/tree/master#using-mid-level-perception-in-your-code-)), originally trained [here](https://github.com/StanfordVL/taskonomy/tree/master/taskbank). 
| Normal, Reshading, Depth* | *Models for other tasks are available using the [`visualpriors` package](https://github.com/alexsax/midlevel-reps/tree/master#using-mid-level-perception-in-your-code-) or in Tensorflow via the [Taskonomy GitHub page](https://github.com/alexsax/midlevel-reps/tree/master#using-mid-level-perception-in-your-code-).

## Training We used the provided training code to train our consistency models on the Taskonomy dataset. We used 3 V100 (32GB) GPUs to train our models; running them for 500 epochs takes about a week. > **Runnable Example:** You'll find that the code in the rest of this section expects about 12TB of data (9 single-image tasks from Taskonomy). For a quick runnable example that gives the gist, try the following: > > First download the data and then start a visdom (logging) server: > ```bash > sh ./tools/download_data.sh # Starter data (388MB) > visdom & # To view the telemetry > ``` > > Then, start the training using the following command, which cascades two models (trains a `normal` model using `curvature` consistency on a training set of 512 images). > ```bash > python -m train example_cascade_two_networks --k 1 --fast > ``` > You can add more perceptual losses by changing the config in `energy.py`. For example, train the above model using _both_ `curvature` and `2D edge` consistency: > ```bash > python -m train example_normal --k 2 --fast > ``` Assuming that you want to train on the full dataset or [on your own dataset], read on. #### The code is structured as follows ```python config/ # Configuration parameters: where to save results, etc. split.txt # Train, val split jobinfo.txt # Defines job name, base_dir modules/ # Network definitions train.py # Training script datasets.py # Creates dataloader energy.py # Defines path config, computes total loss, logging models.py # Implements forward backward pass graph.py # Computes path defined in energy.py task_configs.py # Defines task specific preprocessing, masks, loss fn transfers.py # Loads models utils.py # Defines file paths (described below) demo.py # Demo script ``` #### Expected folder structure The code expects folders structured as follows. 
These can be modified by changing values in `utils.py` ```python base_dir/ # The following paths are defined in utils.py (BASE_DIR) shared/ # with the corresponding variable names in brackets models/ # Pretrained models (MODELS_DIR) results_[jobname]/ # Checkpoint of model being trained (RESULTS_DIR) ood_standard_set/ # OOD data for visualization (OOD_DIR) data_dir/ # taskonomy data (DATA_DIRS) ``` #### Training with consistency 1) **Define locations for data, models, etc.:** Create a `jobinfo.txt` file and define the name of the job and the absolute path to `BASE_DIR` where data, models, and results would be stored, as shown in the folder structure above. An example config is provided in the starter code (`config/jobinfo.txt`). To modify individual file paths eg. the models folder, change `MODELS_DIR` variable name in [utils.py](./utils.py#L25). We won't cover downloading the Taskonomy dataset, which can be downloaded following the instructions [here](https://github.com/StanfordVL/taskonomy/tree/master/data). 2) **Download perceptual networks:** If you want to initialize from our pretrained models, then [download them](#Download-perceptual-networks) with the following command (1.6GB): ```bash sh ./tools/download_percep_models.sh ``` More info about the networks is available [here](#Download-perceptual-networks). 3) **Train with consistency** using the command: ```bash python -m train multiperceptual_{depth,normal,reshading} ``` For example, to run the training code for the `normal` target, run ```bash python -m train multiperceptual_normal ``` This trains the model for the `normal` target with 8 perceptual losses ie. `curvature`, `edge2d`, `edge3d`, `keypoint2d`, `keypoint3d`, `reshading`, `depth` and `imagenet`. We used 3 V100 (32GB) GPUs to train our models; running them for 500 epochs takes about a week. Additional arguments can be specified during training; the most commonly used ones are listed below. 
For the full list, refer to the [training script](./train.py). - The flag `--k` defines the number of perceptual losses used, thus reducing GPU memory requirements. - There are several options for choosing how this subset is chosen 1. randomly (`--random-select`) 2. winrate (`--winrate`) - Data augmentation is not done by default; it can be added to the training data with the flag `--dataaug`. The transformations applied are 1. random crop with probability 0.5 2. [color jitter](https://pytorch.org/docs/stable/torchvision/transforms.html?highlight=color%20jitter#torchvision.transforms.ColorJitter) with probability 0.5. To train a `normal` target domain with 2 perceptual losses selected randomly each epoch, run the following command. ```bash python -m train multiperceptual_normal --k 2 --random-select ``` 4) **Logging:** The losses and visualizations are logged in Visdom. This can be accessed via `[server name]/env/[job name]` eg. `localhost:8888/env/normaltarget_allperceps`. An example visualization is shown below. We plot the outputs from the paths defined in the energy configuration used. Two windows are shown, one shows the predictions before training starts, the other updates them after each epoch. The labels for each column can be found at the top of the window. The second column has the target's ground truth `y^`, the third its prediction `n(x)` from the RGB image `x`. Thereafter, the predictions of each pair of images with the same domain are given by the paths `f(y^),f(n(x))`, where `f` is from the target domain to another domain eg. `curvature`. ![](./assets/visdom_eg.png) **Logging conventions:** For uninteresting historical reasons, the columns in the logging during training might have strange names. You can define your own names instead of using these by changing the config file in `energy.py`. Here's a quick guide to the current convention. 
For example, when training with a `normal` model using consistency: - The RGB input is denoted as `x` and the `target` domain is denoted as `y`. The ground truth label for a domain is marked with a `^` (e.g. `y^` for the `target` domain). - The direct (`RGB -> Z`) and perceptual (`target [Y] -> Z`) transfer functions are named as follows:
(i.e. the function for `rgb` to `curvature` is `RC`; for `normal` to `curvature` it's `f`) | Domain (Z) | `rgb -> Z`
(Direct) | `Y -> Z`
(Perceptual) || Domain (Z) | `rgb -> Z`
(Direct) | `Y -> Z`
(Perceptual) | |-------------|------------------------|--------------------------|-|-----------------|------------------------|---------------------------| | target | n | - || keypoints2d | k2 | Nk2 | | curvature | RC | f || keypoints3d | k3 | Nk3 | | sobel edges | a | s || edge occlusion | E0 | nE0 | #### To train on other target domains 1. A new configuration should be defined in the `energy_configs` dictionary in [energy.py](./energy.py#L39-L521). Description of the information needed: - `paths`: `X1->X2->X3`. The keys in this dictionary use a function notation eg. `f(n(x))`, with its corresponding value being a list of task objects that defines the domains being transferred eg. `[rgb, normal, curvature]`. The `rgb` input is defined as `x`, `n(x)` returns `normal` predictions from `rgb`, and `f(n(x))` returns `curvature` from `normal`. These notations do not need to be the same for all configurations. The [table](#function-definitions) below lists those that have been kept constant for all targets. - `freeze_list`: the models that will not be optimized, - `losses`: loss terms to be constructed from the paths defined above, - `plots`: the paths to plots in the visdom environment. 2. New models may need to be defined in the `pretrained_transfers` dictionary in [transfers.py](./transfers.py#L26-L97). For example, for a `curvature` target, and perceptual model `curvature` to `normal`, the code will look for the `principal_curvature2normal.pth` file in `MODELS_DIR` if it is not defined in [transfers.py](./transfers.py#L26-L97). #### To train on other datasets The expected folder structure for the data is, ``` DATA_DIRS/ [building]_[domain]/ [domain]/ [view]_domain_[domain].png ... ``` Pytorch's dataloader _\_\_getitem\_\__ method has been overridden to return a tuple of all tasks for a given building and view point. This is done in [datasets.py](./datasets.py#L181-L198). 
Thus, for other folder structures, a function to get the corresponding file paths for different domains should be defined. For task specific configs, like transformations and masks, are defined in [task_configs.py](./task_configs.py#L341-L373).

## Citation If you find the code, models, or data useful, please cite this paper: ``` @article{zamir2020consistency, title={Robust Learning Through Cross-Task Consistency}, author={Zamir, Amir and Sax, Alexander and Yeo, Teresa and Kar, Oğuzhan and Cheerla, Nikhil and Suri, Rohan and Cao, Zhangjie and Malik, Jitendra and Guibas, Leonidas}, journal={arXiv}, year={2020} } ``` ================================================ FILE: config/jobinfo.txt ================================================ normaltarget_allperceps, . ================================================ FILE: config/split.txt ================================================ train_buildings: [woodbine, hometown, haymarket, emmaus, swormville, haxtun, martinville, winfield, marksville, hammon, mammoth, kronborg, cobalt, lenoir, bonnie, bautista, retsof, azusa, munsons, darrtown, michiana, uncertain, lajas, plessis, bellemeade, jacobus, swisshome, idanha, lathrup, hambleton, country, ancor, haaswood, orason, cisne, byers, muleshoe, fredericksburg, cayuse, elmira, adrian, merchantville, german, waipahu, clarkridge, rogue, ludlowville, cochranton, colebrook, lovilia, mullica, mobridge, kevin, yscloskey, laytonsville, convoy, sisters, cosmos, calavo, clairton, crookston, nicut, ladue, anaheim, howie, codell, seatonville, brown, quantico, goodview, parole, kingfisher, churchton, edgemere, micanopy, bettendorf, goffs, windhorst, ooltewah, tokeland, darden, benicia, american, neibert, milaca, willow, samuels, sanctuary, pablo, cantwell, biltmore, germfask, eastville, rough, westfield, wainscott, mahtomedi, kopperl, gluck, mogote, cason, hercules, avonia, pettigrew, ewell, imbery, collierville, maguayo, brentsville, mckeesport, hitchland, browntown, aldine, spotswood, chrisney, kildare, stockman, ribera, hortense, mentasta, soldier, kettle, northgate, crandon, ovalo, cullison, dedham, vails, tallmadge, gratz, fonda, helton, mogadore, highspire, sharon, foyil, shelbiana, seiling, brevort, noonday, springhill, 
coronado, sundown, carneiro, silas, assinippi, monticello, wappingers, lakeville, stockwell, ogilvie, victorville, braxton, sodaville, silerton, hobson, tradewinds, sands, coffeen, umpqua, blackstone, sarcoxie, model, bremerton, capistrano, deatsville, graceville, dansville, belpre, edson, mcnary, kirwin, rosenberg, lynchburg, ranchester, shingler, auburn, connellsville, alstown, kerrtown, marstons, hurley, mifflintown, pamelia, sumas, chilhowie, dryville, deemston, cashel, galatia, harrellsville, mcdade, eudora, sasakwa, baneberry, rosser, halfway, hainesburg, gravelly, frierson, tyler, irvine, natural, murchison, lindsborg, duarte, wando, globe, neshkoro, cornville, bowmore, roxboro, espanola, maugansville, yankeetown, sawpit, schoolcraft, klickitat, scandinavia, donaldson, aloha, gaylord, hartline, laupahoehoe, wiconisco, mesic, eagerville, keiser, potosi, wyldwood, macarthur, newfields, moberly, everton, lindenwood, nuevo, bethlehem, silva, noxapater, lindberg, hornsby, weleetka, tysons, kremlin, jenners, trail, freedom, mcclure, ruckersville, sugarville, nemacolin, athens, vacherie, checotah, blenheim, allensville, grantsville, holcut, hallettsville, angiola, tomales, grangeville, seward, fishersville, kendall, kangley, wilbraham, caruthers, hacienda, readsboro, pocopson, bonfield, cohoes, inkom, monson, peacock, touhy, divide, norvelt, badger, leilani, corozal, warrenville, lluveras, grigston, cooperstown, nimmons, ewansville, paige, matoaca, lessley, purple, kihei, millbury, culbertson, maunawili, brewton, maryhill, channel, branford, creede, goodfield, spencerville, kirksville, cokeville, barahona, leonardo, mosinee, tolstoy, broseley, broadwell, landing, roeville, hatfield, rancocas, mcewen, annona, okabena, terrell, barboursville, booth, hanson, mashulaville, cutlerville, euharlee, rabbit, fitchburg, shellsburg, milford, grassy, timberon, coeburn, wilkinsburg, lynxville, islandton, arbutus, reyno, wakeman, frankfort, sontag, voorhees, beechwood, ossipee, 
rockport, starks, woonsocket, hildebran, circleville, aldrich, sunshine, destin, chesterbrook, musicks, merom, kinde, andover, pittsburg, scioto, tilghmanton, castor, potterville, onaga, stanleyville, leavittsburg, carpendale, bountiful, kingdom, cebolla, sweatman, arona, sagerton, herricks, morris, montreal, stokes, newcomb, adairsville, bertram, kinney, spread, akiak, westerville, texasville, springerville, almota, aulander, superior, goodyear, cabin, random, ballou, southfield, ballantine, glenmoor, oriole, ashport, denmark, winthrop, bohemia, yadkinville, smoketown, waucousta, winooski, peden, mayesville, liddieville, clive, gluek, goodwine, uvalda, pleasant, losantville, lineville, hillsdale, ackermanville, waukeenah, mentmore, glassboro, bellwood, peconic, pinesdale, hordville, wells, hendrix, dunmor, fleming, mccloud, reserve, gilbert, bonesteel, roane, pocasset, greigsville, delton, whitethorn, frontenac, siren, artois, helix, melstone, sultan, shumway, seeley, cousins, cauthron, gough, anthoston, gladstone, macland, hiteman, shauck, jennie, airport, funkstown, markleeville, marland, gloria, poyen, annawan, bolton, wattsville, waldenburg, pearce, maiden, gasburg, calmar, applewold, kemblesville, redbank, wilseyville, lucan, archer, castroville, pasatiempo, arkansaw, wyatt, shelbyville, merlin, whiteriver, torrington, oyens] val_buildings: [cottonport, elton, corder, mazomanie, barranquitas, ihlen, wilkesboro, macedon, portal, gastonia, thrall, orangeburg, poipu, kankakee, chiloquin, sussex, maricopa, wesley, bowlus, copemish, tariffville, pomaria, kathryn, rutherford, plumerville, waimea, experiment, dalcour, ohoopee, mifflinburg, callicoon, manassas, macksville, apache, alfred, maida, dauberville, chireno, stilwell, albertville, ellaville, kobuk, burien, carpio, placida, forkland, tippecanoe, beach, eagan, grace, portola, hominy, maben] ================================================ FILE: config/split_fullplus.txt 
================================================ train_buildings: [hanson, merom, arbutus, goodfield, eagan, arona, adairsville, reserve, aloha, castor, munsons, ballou, woonsocket, foyil, creede, superior, klickitat, cottonport, rancocas, ossipee, haxtun, tyler, sugarville, martinville, haaswood, euharlee, hacienda, peacock, convoy, roane, ellaville, chilhowie, tomales, springhill, shellsburg, chrisney, rutherford, mullica, winthrop, grace, codell, silas, braxton, chireno, bremerton, lynchburg, halfway, ballantine, hendrix, tokeland, merlin, baneberry, nimmons, mccloud, onaga, aldrich, maguayo, bountiful, carpio, ovalo, waimea, grassy, monticello, portal, melstone, cutlerville, frankfort, gilbert, kinney, gough, quantico, goodyear, waukeenah, pocopson, terrell, andover, ackermanville, copemish, leavittsburg, monson, applewold, sasakwa, anthoston, gaylord, whitethorn, clarkridge, gladstone, cornville, hatfield, circleville, marksville, trail, chesterbrook, gloria, funkstown, tippecanoe, airport, mesic, stokes, rogue, bethlehem, mcewen, hobson, kirwin, glenmoor, ancor, vails, biltmore, kangley, voorhees, leonardo, kevin, maida, castroville, seeley, millbury, waucousta, dansville, ludlowville, cooperstown, elton, marstons, athens, mentmore, inkom, model, ruckersville, hambleton, newfields, broseley, bettendorf, irvine, pinesdale, tilghmanton, american, westerville, maunawili, reyno, sanctuary, portola, goodwine, cullison,ran, cobalt, winooski, orangeburg, jacobus, oriole, lakeville, pasatiempo, beach, mahtomedi, redbank, cosmos, goffs, wilkinsburg, barboursville, frierson, sunshine, globe, kettle, booth, barranquitas, burien, carneiro, placida, hallettsville, fredericksburg, aulander, culbertson, bellwood, stanleyville, smoketown, winfield, seward, emmaus, macarthur, pomaria, kankakee, jenners, nemacolin, archer, neibert, random, parole, carpendale, spotswood, tolstoy, shelbyville, potterville, frontenac, avonia, laupahoehoe, milford, seatonville, roxboro, 
torrington, grantsville, rosser, spencerville, ewell, wyatt, plumerville, lynxville, allensville, springerville, nuevo, grangeville, stilwell, colebrook, browntown, germfask, readsboro, bowmore, pettigrew, brevort, fitchburg, shelbiana, hometown, montreal, mckeesport, wainscott, adrian, branford, dunmor, windhorst, alfred, islandton, arkansaw, rough, brewton, bonnie, cabin, hitchland, cebolla, mammoth, alstown, beechwood, hominy, crandon, sodaville, everton, maiden, churchton, divide, cayuse, helix, umpqua, chiloquin, schoolcraft, kathryn, coffeen, willow, dedham, corder, timberon, pleasant, darnestown, harrellsville, bohemia, micanopy, hillsdale, mashulaville, wilseyville, kemblesville, thrall, kronborg, fleming, bonesteel, hartline, channel, checotah, orason, fonda, kerrtown, noonday, capistrano, poyen, spread, annona, weleetka, silva, gasburg, milaca, bolton, kendall, stockman, whiteriver, ooltewah, wattsville, soldier, cohoes, lucan, neshkoro, newcomb, cason, roeville, kremlin, yankeetown, maryhill, deatsville, lenoir, byers, kildare, kirksville, blackstone, oyens, peden, azusa, mogote, keiser, potosi, delton, victorville, pamelia, cisne, marland, hiteman, lindberg, mayesville, sussex, nicut, lindsborg, seiling, mobridge, bautista, tradewinds, annawan, mogadore, retsof, murchison, lluveras, highspire, woodbine, lajas, sweatman, okabena, clairton, angiola, callicoon, hurley, touhy, michiana, sawpit, stockwell, country, lindenwood, anaheim, dryville, duarte, sumas, siren, musicks, wilbraham, holcut, kingdom, kopperl, calmar, forkland, mifflinburg, lineville, mosinee, hainesburg, gratz, maugansville, haymarket, uncertain, mifflintown, scandinavia, mazomanie, tariffville, mentasta, freedom, bowlus, apache, ranchester, fishersville, northgate, eagerville] val_buildings: [ogilvie, laytonsville, clive, hortense, wesley, idanha, maricopa, tomkins, texasville, sundown, southfield, moberly, auburn, eudora, wiconisco, tallmadge, ashport, lessley, assinippi, gravelly, 
silerton, hordville, corozal, swormville, warrenville, almota, cantwell, collierville, cashel, pearce, sisters, merchantville, glassboro, shauck, yscloskey, pablo, ribera, pittsburg, graceville, markleeville, hercules, macedon, howie, sands, cokeville, herricks, kinde, edgemere, kobuk, sagerton, starks, jennie, hornsby, greigsville, westfield, wyldwood, sarcoxie, artois, elmira, deemston, galatia, denmark, swisshome, wells, scioto, pocasset, waldenburg, shumway, waipahu, macksville, crookston, eastville, yadkinville, darden] ================================================ FILE: config/split_medium.txt ================================================ train_buildings: [hanson, merom, goodfield, eagan, adairsville, castor, klickitat, cottonport, tyler, sugarville, martinville, chilhowie, silas, lynchburg, tokeland, onaga, frankfort, goodyear, albertville, andover, airport, rogue, ancor, leonardo, maida, marstons, athens, newfields, broseley, irvine, pinesdale, tilghmanton, goodwine, hildebran, winooski, lakeville, cosmos, goffs, sunshine, globe, benevolence, emmaus, pomaria, neibert, parole, tolstoy, shelbyville, potterville, rosser, allensville, springerville, nuevo, stilwell, browntown, readsboro, shelbiana, wainscott, arkansaw, bonnie, beechwood, hominy, churchton, coffeen, willow, timberon, bohemia, micanopy, hillsdale, wilseyville, kemblesville, thrall, bonesteel, annona, stockman, soldier, neshkoro, newcomb, byers, oyens, victorville, pamelia, marland, hiteman, sussex, bautista, highspire, woodbine, sweatman, clairton, touhy, lindenwood, anaheim, duarte, musicks, forkland, mifflinburg, hainesburg, maugansville, ranchester] val_buildings: [hortense, southfield, wiconisco, gravelly, hordville, corozal, swormville, collierville, pearce, pablo, pittsburg, markleeville, sands, kobuk, westfield, wyldwood, swisshome, scioto, waipahu, darden] ================================================ FILE: datasets.py ================================================ import numpy 
def load_train_val(train_tasks, val_tasks=None, fast=False,
        train_buildings=None, val_buildings=None, split_file="config/split.txt",
        dataset_cls=None, batch_size=32, batch_transforms=cycle,
        subset=None, subset_size=None, dataaug=False,
    ):
    """Build the training and validation datasets plus their step counts.

    Args:
        train_tasks: Tasks for the training set; string entries are resolved
            with ``get_task``.
        val_tasks: Tasks for the validation set (defaults to ``train_tasks``).
        fast: If true, any building list that was not given explicitly
            becomes ``["almena"]`` and both step counts are forced to 8.
        train_buildings: Explicit training buildings; when empty or ``None``
            they come from ``split_file`` (or from ``fast``).
        val_buildings: Same as ``train_buildings`` for the validation set.
        split_file: YAML file with ``train_buildings`` and ``val_buildings``
            keys. It is only opened when one of the lists is actually needed.
        dataset_cls: Dataset class to instantiate (defaults to ``TaskDataset``).
        batch_size: Only used to derive the returned step counts.
        batch_transforms: Unused here; kept so existing callers keep working.
        subset: Fraction of the training set to keep (random sample).
        subset_size: Number of training samples to keep; used instead of
            ``subset`` whenever it is non-zero.
        dataaug: If true, the training set is a ``TrainTaskDataset``
            regardless of ``dataset_cls``.

    Returns:
        ``(train_dataset, val_dataset, train_step, val_step)``. The first two
        are dataset objects, not ``DataLoader`` instances.
    """
    dataset_cls = dataset_cls or TaskDataset
    train_cls = TrainTaskDataset if dataaug else dataset_cls

    train_tasks = [get_task(t) if isinstance(t, str) else t for t in train_tasks]
    if val_tasks is None:
        val_tasks = train_tasks
    val_tasks = [get_task(t) if isinstance(t, str) else t for t in val_tasks]

    # The split file is irrelevant in fast mode or when both building lists
    # are supplied, so do not require it to exist in those cases.
    data = {}
    if not fast and not (train_buildings and val_buildings):
        # safe_load: the split file only holds plain lists of names, and
        # yaml.load() without an explicit Loader fails on PyYAML >= 6.
        with open(split_file) as split_fh:
            data = yaml.safe_load(split_fh)

    train_buildings = train_buildings or (["almena"] if fast else data["train_buildings"])
    val_buildings = val_buildings or (["almena"] if fast else data["val_buildings"])

    # The dataset constructors print the number of files they found.
    print("number of train images:")
    train_set = train_cls(buildings=train_buildings, tasks=train_tasks)

    print("number of val images:")
    val_set = dataset_cls(buildings=val_buildings, tasks=val_tasks)

    if subset_size is not None or subset is not None:
        keep = subset_size or int(len(train_set) * subset)
        train_set = torch.utils.data.Subset(
            train_set,
            random.sample(range(len(train_set)), keep),
        )

    train_step = int(len(train_set) // (400 * batch_size))
    val_step = int(len(val_set) // (400 * batch_size))
    print("Train step: ", train_step)
    print("Val step: ", val_step)
    if fast:
        train_step, val_step = 8, 8

    return train_set, val_set, train_step, val_step
def load_all(tasks, buildings=None, batch_size=64, split_file="data/split.txt", batch_transforms=cycle):
    """Create a shuffled ``DataLoader`` over the requested buildings.

    Args:
        tasks: Tasks to load for every sample; string entries are resolved
            with ``get_task`` (as the other loaders in this module do).
        buildings: Buildings to include. When empty or ``None``, the train
            and validation buildings listed in ``split_file`` are combined.
        batch_size: Batch size of the returned loader.
        split_file: YAML file with ``train_buildings`` and ``val_buildings``
            keys; only opened when ``buildings`` is not given.
            NOTE(review): this default differs from ``load_train_val``'s
            "config/split.txt" -- confirm which path is intended.
        batch_transforms: Unused here; kept so existing callers keep working.

    Returns:
        A ``torch.utils.data.DataLoader`` wrapping a ``TaskDataset``.
    """
    if not buildings:
        # safe_load: the split file only holds plain lists of names, and
        # yaml.load() without an explicit Loader fails on PyYAML >= 6.
        with open(split_file) as split_fh:
            data = yaml.safe_load(split_fh)
        buildings = data["train_buildings"] + data["val_buildings"]

    tasks = [get_task(t) if isinstance(t, str) else t for t in tasks]

    data_loader = torch.utils.data.DataLoader(
        TaskDataset(buildings=buildings, tasks=tasks),
        batch_size=batch_size,
        num_workers=0, shuffle=True, pin_memory=True
    )
    return data_loader
def convert_path(self, source_file, task):
    """Map a file of one task onto the corresponding file of ``task``.

    Can be overridden by subclasses (or replaced through the ``convert_path``
    constructor argument).

    Args:
        source_file: Path whose last three components look like
            ``<building>_<task>/<task>/<view>_domain_<task2>.<ext>``.
        task: Destination task; its ``file_name``, ``file_name_alt`` and
            ``file_ext`` attributes are used to build the new path.

    Returns:
        The destination path inside the data directory registered for
        ``<building>_<task.file_name>``, or ``""`` when the source path does
        not match the expected layout or that folder is not in
        ``self.file_map``. ``__getitem__`` treats an empty string as a
        failed conversion.
    """
    # Only the trailing <folder>/<task>/<file> part carries information.
    relative = "/".join(source_file.split('/')[-3:])
    result = parse.parse("{building}_{task}/{task}/{view}_domain_{task2}.{ext}", relative)
    if result is None:
        # Previously this surfaced as an opaque TypeError on result[...].
        print (f"{source_file} does not match the expected file layout")
        return ""

    building, view = result["building"], result["view"]
    folder = f"{building}_{task.file_name}"
    if folder not in self.file_map:
        print (f"{building}_{task.file_name} not in file map")
        return ""

    dest_file = f"{folder}/{task.file_name}/{view}_domain_{task.file_name_alt}.{task.file_ext}"
    data_dir = self.file_map[folder]
    return f"{data_dir}/{dest_file}"
class TrainTaskDataset(TaskDataset):
    """``TaskDataset`` variant that passes random crop/jitter settings to the task loaders."""

    def __getitem__(self, idx):
        """Load one sample, i.e. one image per task in ``self.tasks``.

        A single ``seed`` and ``crop`` value is drawn per sample and handed
        to every task's ``file_loader``; ``jitter`` is only ever enabled for
        the task named ``'rgb'``. If anything fails, a random index is tried
        instead, up to 200 attempts in total, so the returned sample may
        belong to a different index than the one requested.

        Args:
            idx: Index into ``self.idx_files``.

        Returns:
            A tuple with one loaded image per task, in task order.

        Raises:
            Exception: the error of the 200th failed attempt is re-raised.
        """
        for i in range(200):
            try:
                res = []

                # Integer bound: random.randint() rejects float arguments
                # such as 1e10 on Python >= 3.12 (deprecated since 3.10).
                seed = random.randint(0, 10**10)
                # Half of the time crop to a random size in [358, 512].
                crop = random.randint(int(0.7*512), 512) if bool(random.getrandbits(1)) else 512

                for task in self.tasks:
                    jitter = bool(random.getrandbits(1)) if task.name == 'rgb' else False
                    file_name = self.convert_path(self.idx_files[idx], task)
                    if len(file_name) == 0:
                        raise Exception("unable to convert file")
                    image = task.file_loader(file_name, resize=self.resize, seed=seed, crop=crop, jitter=jitter)
                    res.append(image)

                return tuple(res)
            except Exception:
                # Best effort: fall back to a random sample and try again.
                idx = random.randrange(0, len(self.idx_files))
                if i == 199:
                    raise
if __name__ == "__main__":

    # Manual smoke test: build the default train/val datasets and walk the
    # training set while reporting progress to a Visdom logger.
    logger = VisdomLogger("data", env=JOB)
    # load_train_val returns dataset objects (not DataLoaders) plus step counts.
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        [tasks.rgb, tasks.normal, tasks.principal_curvature, tasks.rgb(size=512)],
        batch_size=32,
    )
    print ("created dataset")
    # Registered with freq=32; the hook itself only calls logger.step().
    logger.add_hook(lambda logger, data: logger.step(), freq=32)

    # Every iteration loads one full sample through the dataset's __getitem__.
    # NOTE(review): TaskDataset.__getitem__ retries with a random index on any
    # exception, so an out-of-range index does not appear to raise IndexError
    # and this loop may not stop on its own -- confirm.
    for i, _ in enumerate(train_dataset):
        logger.update("epoch", i)
def save_outputs(img_path, output_file_name):
    """Run the baseline and consistency models on one image and save both outputs.

    For each variant, the weights ``<root_dir>rgb2<task>_<variant>.pth`` are
    loaded into the module-level ``model``. The prediction is clamped to
    [0, 1] and written to
    ``<args.output_path>/<output_file_name>_<task>_<variant>.png``.

    Args:
        img_path: Path of the input image.
        output_file_name: Stem used for the two output file names.
    """
    # Close the image file once it has been converted to a tensor, so that
    # processing a whole directory does not accumulate open handles.
    with Image.open(img_path) as img:
        # The [:3] slice below only drops extra channels (e.g. alpha).
        # Modes with fewer channels (grayscale, palette, ...) are expanded
        # to RGB first -- presumably the rgb2* models expect 3 channels.
        if img.mode not in ("RGB", "RGBA"):
            img = img.convert("RGB")
        img_tensor = trans_totensor(img)[:3].unsqueeze(0)

    # compute baseline and consistency output
    for variant in ('baseline', 'consistency'):
        weights_path = root_dir + 'rgb2' + args.task + '_' + variant + '.pth'
        model_state_dict = torch.load(weights_path, map_location=map_location)
        model.load_state_dict(model_state_dict)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            output = model(img_tensor).clamp(min=0, max=1)

        out_name = output_file_name + '_' + args.task + '_' + variant + '.png'
        trans_topil(output[0]).save(os.path.join(args.output_path, out_name))
def generate_config(perceptual_tasks, target_task=tasks.normal, tree_structure=False, has_gt=True):
    """Assemble an energy config dict for a target task and a list of perceptual tasks.

    The result has the keys "paths", "freeze_list", "losses" and "plots":

    * paths: "x" (rgb), "n((x))" (rgb -> target) and one
      "n(<task>(x))" path (rgb -> task -> target) per perceptual task.
      With ``has_gt`` the ground truth path "y^" is added as well.
    * freeze_list: (task, target) and (rgb, task) for every perceptual task.
    * losses: one "percep_<t1>_<t2>" entry per compared pair of paths, where
      an empty task name stands for the direct path "n((x))". With
      ``has_gt`` a "direct_normal" entry compares "n((x))" with "y^"
      (this key name is used whatever ``target_task`` is).
    * plots: a single "ID" plot of "x" and "n((x))" on the "test" reality.

    Args:
        perceptual_tasks: List of intermediate tasks.
        target_task: Task every path ends in.
        tree_structure: If true, only compare the direct path with each
            perceptual path; otherwise compare every pair of paths.
        has_gt: Whether ground truth for the target task is available.

    Returns:
        The config dict. "paths" and "losses" are built with ``both``
        (presumably a dict merge -- defined elsewhere).
    """
    split_names = ("train", "val", "train_subset")

    path_table = {
        "x": [tasks.rgb],
        "n((x))": [tasks.rgb, target_task]
    }
    gt_losses = {}
    # If we have GT, measure error
    if has_gt:
        path_table["y^"] = [target_task]
        gt_losses["direct_normal"] = {split_names: [("n((x))", "y^")]}

    # Add in losses for consistency energy
    if tree_structure:
        pairs = [('', leaf) for leaf in perceptual_tasks]
    else:
        pairs = []
        for first, second in itertools.combinations(perceptual_tasks + [''], r=2):
            # Keep the empty (direct) entry in the first position.
            if str(second) == '':
                pairs.append((second, first))
            else:
                pairs.append((first, second))

    all_paths = both(
        path_table,
        {
            f"n({intermediate_task}(x))": [tasks.rgb, intermediate_task, target_task]
            for intermediate_task in perceptual_tasks
        },
    )
    frozen = [(t, target_task) for t in perceptual_tasks] \
        + [(tasks.rgb, t) for t in perceptual_tasks]
    all_losses = both(
        gt_losses,
        {
            f'percep_{t1}_{t2}': {split_names: [(f"n({t1}(x))", f"n({t2}(x))")]}
            for t1, t2 in pairs
        },
    )
    plot_table = {
        "ID": dict(size=256, realities=("test",), paths=["x", "n((x))"]),
    }

    return {
        "paths": all_paths,
        "freeze_list": frozen,
        "losses": all_losses,
        "plots": plot_table,
    }
"losses": { "mae": { ("train", "val"): [ ("n(x)", "y^"), ], }, "percep_curv": { ("train", "val"): [ ("f(n(x))", "f(y^)"), ], }, "direct_curv": { ("train", "val"): [ ("RC(x)", "curv"), ], }, }, "plots": { "": dict( size=256, realities=('train', 'val'), paths=[ "x", "y^", "n(x)", "f(y^)", "f(n(x))", ] ), }, }, "example_normal": { "paths": { "x": [tasks.rgb], "y^": [tasks.normal], "n(x)": [tasks.rgb, tasks.normal], "a(x)": [tasks.rgb, tasks.sobel_edges], "RC(x)": [tasks.rgb, tasks.principal_curvature], "edge": [tasks.sobel_edges], "curv": [tasks.principal_curvature], "s(y^)": [tasks.normal, tasks.sobel_edges], "s(n(x))": [tasks.rgb, tasks.normal, tasks.sobel_edges], "f(y^)": [tasks.normal, tasks.principal_curvature], "f(n(x))": [tasks.rgb, tasks.normal, tasks.principal_curvature], }, "freeze_list": [ [tasks.normal, tasks.principal_curvature], [tasks.normal, tasks.sobel_edges], ], "losses": { "mae": { ("train", "val"): [ ("n(x)", "y^"), ], }, "percep_curv": { ("train", "val"): [ ("f(n(x))", "f(y^)"), ], }, "direct_curv": { ("train", "val"): [ ("RC(x)", "curv"), ], }, "percep_edge": { ("train", "val"): [ ("s(n(x))", "s(y^)"), ], }, "direct_edge": { ("train", "val"): [ ("a(x)", "s(y^)"), ], }, }, "plots": { "": dict( size=256, realities=('train', 'val'), paths=[ "x", "y^", "n(x)", "f(y^)", "f(n(x))", "s(y^)", "s(n(x))", ] ), }, }, "example_normal_direct": { "paths": { "x": [tasks.rgb], "y^": [tasks.normal], "n(x)": [tasks.rgb, tasks.normal], }, "freeze_list": [ ], "losses": { "mae": { ("train", "val"): [ ("n(x)", "y^"), ], }, }, "plots": { "": dict( size=256, realities=('train', 'val'), paths=[ "x", "y^", "n(x)", ] ), }, }, "multiperceptual_normal": { "paths": { "x": [tasks.rgb], "y^": [tasks.normal], "n(x)": [tasks.rgb, tasks.normal], "RC(x)": [tasks.rgb, tasks.principal_curvature], "a(x)": [tasks.rgb, tasks.sobel_edges], "d(x)": [tasks.rgb, tasks.reshading], "r(x)": [tasks.rgb, tasks.depth_zbuffer], "EO(x)": [tasks.rgb, tasks.edge_occlusion], "k2(x)": [tasks.rgb, 
tasks.keypoints2d], "k3(x)": [tasks.rgb, tasks.keypoints3d], "curv": [tasks.principal_curvature], "edge": [tasks.sobel_edges], "depth": [tasks.depth_zbuffer], "reshading": [tasks.reshading], "keypoints2d": [tasks.keypoints2d], "keypoints3d": [tasks.keypoints3d], "edge_occlusion": [tasks.edge_occlusion], "f(y^)": [tasks.normal, tasks.principal_curvature], "f(n(x))": [tasks.rgb, tasks.normal, tasks.principal_curvature], "s(y^)": [tasks.normal, tasks.sobel_edges], "s(n(x))": [tasks.rgb, tasks.normal, tasks.sobel_edges], "g(y^)": [tasks.normal, tasks.reshading], "g(n(x))": [tasks.rgb, tasks.normal, tasks.reshading], "nr(y^)": [tasks.normal, tasks.depth_zbuffer], "nr(n(x))": [tasks.rgb, tasks.normal, tasks.depth_zbuffer], "Nk2(y^)": [tasks.normal, tasks.keypoints2d], "Nk2(n(x))": [tasks.rgb, tasks.normal, tasks.keypoints2d], "Nk3(y^)": [tasks.normal, tasks.keypoints3d], "Nk3(n(x))": [tasks.rgb, tasks.normal, tasks.keypoints3d], "nEO(y^)": [tasks.normal, tasks.edge_occlusion], "nEO(n(x))": [tasks.rgb, tasks.normal, tasks.edge_occlusion], "imagenet(y^)": [tasks.normal, tasks.imagenet], "imagenet(n(x))": [tasks.rgb, tasks.normal, tasks.imagenet], }, "freeze_list": [ [tasks.normal, tasks.principal_curvature], [tasks.normal, tasks.sobel_edges], [tasks.normal, tasks.reshading], [tasks.normal, tasks.depth_zbuffer], [tasks.normal, tasks.keypoints3d], [tasks.normal, tasks.keypoints2d], [tasks.normal, tasks.edge_occlusion], [tasks.normal, tasks.imagenet], ], "losses": { "mae": { ("train", "val"): [ ("n(x)", "y^"), ], }, "percep_curv": { ("train", "val"): [ ("f(n(x))", "f(y^)"), ], }, "direct_curv": { ("train", "val"): [ ("RC(x)", "curv"), ], }, "percep_edge": { ("train", "val"): [ ("s(n(x))", "s(y^)"), ], }, "direct_edge": { ("train", "val"): [ ("a(x)", "s(y^)"), ], }, "percep_reshading": { ("train", "val"): [ ("g(n(x))", "g(y^)"), ], }, "direct_reshading": { ("train", "val"): [ ("d(x)", "reshading"), ], }, "percep_depth_zbuffer": { ("train", "val"): [ ("nr(n(x))", "nr(y^)"), ], 
}, "direct_depth_zbuffer": { ("train", "val"): [ ("r(x)", "depth"), ], }, "percep_keypoints2d": { ("train", "val"): [ ("Nk2(n(x))", "Nk2(y^)"), ], }, "direct_keypoints2d": { ("train", "val"): [ ("k2(x)", "keypoints2d"), ], }, "percep_keypoints3d": { ("train", "val"): [ ("Nk3(n(x))", "Nk3(y^)"), ], }, "direct_keypoints3d": { ("train", "val"): [ ("k3(x)", "keypoints3d"), ], }, "percep_edge_occlusion": { ("train", "val"): [ ("nEO(n(x))", "nEO(y^)"), ], }, "direct_edge_occlusion": { ("train", "val"): [ ("EO(x)", "edge_occlusion"), ], }, "percep_imagenet_percep": { ("train", "val"): [ ("imagenet(n(x))", "imagenet(y^)"), ], }, "direct_imagenet_percep": { ("train", "val"): [ ("RC(x)", "curv"), ], }, }, "plots": { "": dict( size=256, realities=("test", "ood"), paths=[ "x", "y^", "n(x)", "f(y^)", "f(n(x))", "s(y^)", "s(n(x))", "g(y^)", "g(n(x))", "nr(n(x))", "nr(y^)", "Nk3(y^)", "Nk3(n(x))", "Nk2(y^)", "Nk2(n(x))", "nEO(y^)", "nEO(n(x))", ] ), }, }, "multiperceptual_reshading": { "paths": { "x": [tasks.rgb], "y^": [tasks.reshading], "n(x)": [tasks.rgb, tasks.reshading], "RC(x)": [tasks.rgb, tasks.principal_curvature], "a(x)": [tasks.rgb, tasks.sobel_edges], "d(x)": [tasks.rgb, tasks.normal], "r(x)": [tasks.rgb, tasks.depth_zbuffer], "EO(x)": [tasks.rgb, tasks.edge_occlusion], "k2(x)": [tasks.rgb, tasks.keypoints2d], "k3(x)": [tasks.rgb, tasks.keypoints3d], "curv": [tasks.principal_curvature], "edge": [tasks.sobel_edges], "depth": [tasks.depth_zbuffer], "normal": [tasks.normal], "keypoints2d": [tasks.keypoints2d], "keypoints3d": [tasks.keypoints3d], "edge_occlusion": [tasks.edge_occlusion], "f(y^)": [tasks.reshading, tasks.principal_curvature], "f(n(x))": [tasks.rgb, tasks.reshading, tasks.principal_curvature], "s(y^)": [tasks.reshading, tasks.sobel_edges], "s(n(x))": [tasks.rgb, tasks.reshading, tasks.sobel_edges], "g(y^)": [tasks.reshading, tasks.normal], "g(n(x))": [tasks.rgb, tasks.reshading, tasks.normal], "nr(y^)": [tasks.reshading, tasks.depth_zbuffer], "nr(n(x))": 
[tasks.rgb, tasks.reshading, tasks.depth_zbuffer], "Nk2(y^)": [tasks.reshading, tasks.keypoints2d], "Nk2(n(x))": [tasks.rgb, tasks.reshading, tasks.keypoints2d], "Nk3(y^)": [tasks.reshading, tasks.keypoints3d], "Nk3(n(x))": [tasks.rgb, tasks.reshading, tasks.keypoints3d], "nEO(y^)": [tasks.reshading, tasks.edge_occlusion], "nEO(n(x))": [tasks.rgb, tasks.reshading, tasks.edge_occlusion], "imagenet(y^)": [tasks.reshading, tasks.imagenet], "imagenet(n(x))": [tasks.rgb, tasks.reshading, tasks.imagenet], }, "freeze_list": [ [tasks.reshading, tasks.principal_curvature], [tasks.reshading, tasks.sobel_edges], [tasks.reshading, tasks.normal], [tasks.reshading, tasks.depth_zbuffer], [tasks.reshading, tasks.keypoints3d], [tasks.reshading, tasks.keypoints2d], [tasks.reshading, tasks.edge_occlusion], [tasks.reshading, tasks.imagenet], ], "losses": { "mae": { ("train", "val"): [ ("n(x)", "y^"), ], }, "percep_curv": { ("train", "val"): [ ("f(n(x))", "f(y^)"), ], }, "direct_curv": { ("train", "val"): [ ("RC(x)", "curv"), ], }, "percep_edge": { ("train", "val"): [ ("s(n(x))", "s(y^)"), ], }, "direct_edge": { ("train", "val"): [ ("a(x)", "s(y^)"), ], }, "percep_normal": { ("train", "val"): [ ("g(n(x))", "g(y^)"), ], }, "direct_normal": { ("train", "val"): [ ("d(x)", "normal"), ], }, "percep_depth_zbuffer": { ("train", "val"): [ ("nr(n(x))", "nr(y^)"), ], }, "direct_depth_zbuffer": { ("train", "val"): [ ("r(x)", "depth"), ], }, "percep_keypoints2d": { ("train", "val"): [ ("Nk2(n(x))", "Nk2(y^)"), ], }, "direct_keypoints2d": { ("train", "val"): [ ("k2(x)", "keypoints2d"), ], }, "percep_keypoints3d": { ("train", "val"): [ ("Nk3(n(x))", "Nk3(y^)"), ], }, "direct_keypoints3d": { ("train", "val"): [ ("k3(x)", "keypoints3d"), ], }, "percep_edge_occlusion": { ("train", "val"): [ ("nEO(n(x))", "nEO(y^)"), ], }, "direct_edge_occlusion": { ("train", "val"): [ ("EO(x)", "edge_occlusion"), ], }, "percep_imagenet_percep": { ("train", "val"): [ ("imagenet(n(x))", "imagenet(y^)"), ], }, 
"direct_imagenet_percep": { ("train", "val"): [ ("RC(x)", "curv"), ], }, }, "plots": { "": dict( size=256, realities=("test", "ood"), paths=[ "x", "y^", "n(x)", "f(y^)", "f(n(x))", "s(y^)", "s(n(x))", "g(y^)", "g(n(x))", "nr(n(x))", "nr(y^)", "Nk3(y^)", "Nk3(n(x))", "Nk2(y^)", "Nk2(n(x))", "nEO(y^)", "nEO(n(x))", "depth", ] ), }, }, "multiperceptual_depth": { "paths": { "x": [tasks.rgb], "y^": [tasks.depth_zbuffer], "n(x)": [tasks.rgb, tasks.depth_zbuffer], "RC(x)": [tasks.rgb, tasks.principal_curvature], "a(x)": [tasks.rgb, tasks.sobel_edges], "d(x)": [tasks.rgb, tasks.normal], "r(x)": [tasks.rgb, tasks.reshading], "EO(x)": [tasks.rgb, tasks.edge_occlusion], "k2(x)": [tasks.rgb, tasks.keypoints2d], "k3(x)": [tasks.rgb, tasks.keypoints3d], "curv": [tasks.principal_curvature], "edge": [tasks.sobel_edges], "normal": [tasks.normal], "reshading": [tasks.reshading], "keypoints2d": [tasks.keypoints2d], "keypoints3d": [tasks.keypoints3d], "edge_occlusion": [tasks.edge_occlusion], "f(y^)": [tasks.depth_zbuffer, tasks.principal_curvature], "f(n(x))": [tasks.rgb, tasks.depth_zbuffer, tasks.principal_curvature], "s(y^)": [tasks.depth_zbuffer, tasks.sobel_edges], "s(n(x))": [tasks.rgb, tasks.depth_zbuffer, tasks.sobel_edges], "g(y^)": [tasks.depth_zbuffer, tasks.normal], "g(n(x))": [tasks.rgb, tasks.depth_zbuffer, tasks.normal], "nr(y^)": [tasks.depth_zbuffer, tasks.reshading], "nr(n(x))": [tasks.rgb, tasks.depth_zbuffer, tasks.reshading], "Nk2(y^)": [tasks.depth_zbuffer, tasks.keypoints2d], "Nk2(n(x))": [tasks.rgb, tasks.depth_zbuffer, tasks.keypoints2d], "Nk3(y^)": [tasks.depth_zbuffer, tasks.keypoints3d], "Nk3(n(x))": [tasks.rgb, tasks.depth_zbuffer, tasks.keypoints3d], "nEO(y^)": [tasks.depth_zbuffer, tasks.edge_occlusion], "nEO(n(x))": [tasks.rgb, tasks.depth_zbuffer, tasks.edge_occlusion], "imagenet(y^)": [tasks.depth_zbuffer, tasks.imagenet], "imagenet(n(x))": [tasks.rgb, tasks.depth_zbuffer, tasks.imagenet], }, "freeze_list": [ [tasks.depth_zbuffer, 
def coeff_hook(coeff):
    """Return a hook that scales a copy of the incoming gradient by ``coeff``.

    The returned callable takes ``grad``, clones it and returns
    ``coeff * clone``; ``grad`` itself is not modified by the hook.
    """
    def scale_grad(grad):
        duplicate = grad.clone()
        return coeff * duplicate

    return scale_grad
def get_tasks(self, reality):
    """Return the distinct source tasks needed for ``reality``.

    A loss entry contributes when ``reality`` is listed in its realities key;
    a plot contributes when ``reality`` is in its ``"realities"`` setting.
    For every contributing path only the first task of the path is kept.
    The result is de-duplicated through a set, so its order is unspecified.
    """
    collected = []

    for per_reality in self.losses.values():
        for reality_names, path_pairs in per_reality.items():
            if reality not in reality_names:
                continue
            for first, second in path_pairs:
                collected.append(self.paths[first][0])
                collected.append(self.paths[second][0])

    for plot_config in self.plots.values():
        if reality not in plot_config["realities"]:
            continue
        collected.extend(self.paths[p][0] for p in plot_config["paths"])

    return list(set(collected))
def logger_hooks(self, logger):
    """Register one joint-plot hook per tracked metric on ``logger``.

    For every (path1, path2) pair in ``self.losses`` two metric names are
    derived, an "mae" one and an "mse" one, and the realities that use the
    pair are accumulated under each name. A hook is then added for each name;
    it fires on the feature of the last reality and plots the series of all
    realities together once every series is present in the logger data.
    """
    realities_by_metric = defaultdict(list)
    for loss_type, per_reality in self.losses.items():
        for reality_names, path_pairs in per_reality.items():
            for src, dst in path_pairs:
                for kind in ("mae", "mse"):
                    # Loss types containing "mae" are reported under the bare
                    # metric kind; all others are prefixed with the loss type.
                    label = kind if "mae" in loss_type else loss_type + "_" + kind
                    metric = label + " : " + src + " -> " + dst
                    realities_by_metric[metric] += list(reality_names)

    for metric, reality_names in realities_by_metric.items():

        def jointplot(logger, data, name=metric, realities=reality_names):
            series_names = [f"{reality}_{name}" for reality in realities]
            # Wait until every reality has logged this metric at least once.
            for series in series_names:
                if series not in data:
                    return
            stacked = np.stack([data[series] for series in series_names], axis=1)
            logger.plot(stacked, name, opts={"legend": series_names})

        logger.add_hook(
            partial(jointplot, name=metric, realities=reality_names),
            feature=f"{reality_names[-1]}_{metric}",
            freq=1,
        )
class WinRateEnergyLoss(EnergyLoss):
    """EnergyLoss variant that tracks per-loss "win rates" and applies ``k`` chosen perceptual losses.

    For each chosen loss ``key`` the win rate is the mean of the elementwise
    comparison ``percep_<key> > direct_<key>``. Win rates below 1.0 are kept
    in ``running_stats`` and used by ``logger_update`` to pick the next set of
    chosen losses.

    Extra keyword arguments (removed before calling ``EnergyLoss.__init__``):
        k: number of perceptual losses to choose (default 3).
        random_select: when true, always re-sample the chosen losses at random.
    """

    def __init__(self, *args, **kwargs):
        self.k = kwargs.pop('k', 3)
        self.random_select = kwargs.pop('random_select', False)
        self.running_stats = {}
        # `paths` is the first positional parameter of EnergyLoss.__init__, so
        # accept it positionally as well as by keyword.
        paths = args[0] if args else kwargs['paths']
        self.target_task = paths['y^'][0].name
        super().__init__(*args, **kwargs)

        prefix = "percep_"
        self.percep_losses = [key[len(prefix):] for key in self.losses if key.startswith(prefix)]
        print("percep losses:", self.percep_losses)
        self.chosen_losses = random.sample(self.percep_losses, self.k)

    def __call__(self, graph, discriminator=None, realities=[], loss_types=None, compute_grad_ratio=False):
        """Compute the per-loss values and weight the chosen perceptual losses.

        The ``loss_types`` argument is ignored: the list is always rebuilt as
        "mae" plus the percep_/direct_ variant of every perceptual loss.

        When ``compute_grad_ratio`` is true, each non-direct loss is
        back-propagated on its own, the mean absolute gradient over the
        parameters of the ``('rgb', target_task)`` edge is measured, and the
        coefficients are rebalanced from those gradient norms.

        Returns:
            ``(loss_dict, mae_coefficient)``. In ``loss_dict`` the "mae" entry
            and every chosen ``percep_*`` entry are reduced to their weighted
            mean and the matching ``direct_*`` entries are removed.
        """
        loss_types = (
            ["mae"]
            + [("percep_" + loss) for loss in self.percep_losses]
            + [("direct_" + loss) for loss in self.percep_losses]
        )
        loss_dict = super().__call__(
            graph, discriminator=discriminator, realities=realities,
            loss_types=loss_types, reduce=False,
        )

        chosen_percep_mse_losses = [k for k in loss_dict if 'direct' not in k]
        percep_mse_coeffs = dict.fromkeys(chosen_percep_mse_losses, 1.0)

        ########### to compute loss coefficients #############
        if compute_grad_ratio:
            percep_mse_gradnorms = dict.fromkeys(chosen_percep_mse_losses, 1.0)
            for loss_name in chosen_percep_mse_losses:
                # retain_graph: the same graph is back-propagated once per loss.
                loss_dict[loss_name].mean().backward(retain_graph=True)
                target_weights = list(
                    graph.edge_map[f"('rgb', '{self.target_task}')"].model.parameters()
                )
                percep_mse_gradnorms[loss_name] = (
                    sum(w.grad.abs().sum().item() for w in target_weights)
                    / sum(w.numel() for w in target_weights)
                )
                graph.optimizer.zero_grad()
                graph.zero_grad()
                del target_weights

            total_gradnorms = sum(percep_mse_gradnorms.values())
            n_losses = len(chosen_percep_mse_losses)
            # With a single loss or all-zero gradients the ratio below would
            # divide by zero; keep the neutral coefficients of 1.0 instead.
            if n_losses > 1 and total_gradnorms > 0:
                for loss_name in percep_mse_coeffs:
                    percep_mse_coeffs[loss_name] = (
                        (total_gradnorms - percep_mse_gradnorms[loss_name])
                        / ((n_losses - 1) * total_gradnorms)
                    )
                percep_mse_coeffs["mae"] *= (n_losses - 1)
        ###########################################

        for key in self.chosen_losses:
            percep_name, direct_name = f"percep_{key}", f"direct_{key}"
            winrate = torch.mean((loss_dict[percep_name] > loss_dict[direct_name]).float())
            winrate = winrate.detach().cpu().item()
            if winrate < 1.0:
                self.running_stats[key] = winrate
            loss_dict[percep_name] = loss_dict[percep_name].mean() * percep_mse_coeffs[percep_name]
            loss_dict.pop(direct_name)

        loss_dict["mae"] = loss_dict["mae"].mean() * percep_mse_coeffs["mae"]
        return loss_dict, percep_mse_coeffs["mae"]

    def logger_update(self, logger):
        """Flush metrics via the parent class, then choose the losses for the next round.

        Losses are sampled at random while ``random_select`` is set or until a
        win rate has been recorded for every perceptual loss; afterwards the
        ``k`` losses with the highest recorded win rate are chosen.
        """
        super().logger_update(logger)

        if self.random_select or len(self.running_stats) < len(self.percep_losses):
            self.chosen_losses = random.sample(self.percep_losses, self.k)
        else:
            ranked = sorted(self.running_stats, key=self.running_stats.get, reverse=True)
            self.chosen_losses = ranked[:self.k]

        logger.text(f"Chosen losses: {self.chosen_losses}")
def sample_path(self, path, reality=None, use_cache=False, cache=None):
    """Push data along ``path`` edge by edge and return the final value.

    The path is prefixed with ``reality`` (or ``self.reality[0]`` when no
    reality is given) and each consecutive pair of entries is resolved with
    ``self.edge`` and applied to the running value, starting from ``None``.

    Args:
        path: list of tasks to traverse (it is concatenated to a list).
        reality: optional starting node; defaults to ``self.reality[0]``.
        use_cache: when true, every intermediate value is written to ``cache``
            under the tuple of the path prefix that produced it.
        cache: dict of already computed prefixes. It is always consulted;
            a fresh dict is used when omitted, so nothing leaks between calls.

    Returns:
        The value at the end of the path, or ``None`` if a ``KeyError`` is
        raised while resolving or applying an edge (e.g. the edge is missing).
    """
    if cache is None:
        cache = {}

    path = [reality or self.reality[0]] + path
    x = None
    for i in range(1, len(path)):
        key = tuple(path[:i + 1])
        if key in cache:
            # Reuse the stored prefix instead of re-running the transfer;
            # dict.get(key, default) would evaluate the transfer eagerly.
            x = cache[key]
        else:
            try:
                x = self.edge(path[i - 1], path[i])(x)
            except KeyError:
                return None
            except Exception as e:
                # Debugging aid kept from the original: report the error and
                # drop into an interactive shell; x keeps its previous value.
                print(e)
                IPython.embed()
        if use_cache:
            cache[key] = x
    return x
class BaseLogger(object):
    """Logger base class with hooks on data features and plotting entry points.

    Values are recorded per feature name in ``data`` (full history) and
    ``running_data`` (values since the last ``step``). Subclasses implement
    ``text``, ``plot`` and ``images``.
    """

    def __init__(self, name, verbose=True):
        self.name = name
        self.data = {}            # feature -> every value ever recorded
        self.running_data = {}    # feature -> values since the last step()
        self.reset_running = {}   # feature -> True when running data must restart
        self.verbose = verbose
        self.hooks = []           # (hook, feature, freq) triples

    def add_hook(self, hook, feature='epoch', freq=40):
        """Call ``hook(logger, data)`` on every ``freq``-th update of ``feature``."""
        self.hooks.append((hook, feature, freq))

    def update(self, feature, x):
        """Record the mean of ``x`` under ``feature`` and fire matching hooks."""
        if isinstance(x, torch.Tensor):
            x = x.clone().detach().cpu().numpy().mean()
        else:
            x = torch.tensor(x).data.cpu().numpy().mean()

        self.data.setdefault(feature, []).append(x)

        # Start a new running window for unseen features, or when step() has
        # flagged this feature for reset (the flag is consumed here).
        if feature not in self.running_data or self.reset_running.pop(feature, False):
            self.running_data[feature] = []
        self.running_data[feature].append(x)

        history_length = len(self.data[feature])
        for hook, hook_feature, freq in self.hooks:
            if feature == hook_feature and history_length % freq == 0:
                hook(self, self.data)

    def step(self):
        """Emit one text line with the running mean of every feature, then flag them for reset."""
        buf = f"({self.name}) "
        for feature, values in self.running_data.items():
            if len(values) == 0:
                continue
            val = np.mean(values)
            if float(val).is_integer():
                buf += f"{feature}: {int(val)}, "
            else:
                buf += f"{feature}: {val:0.4f}, "
            self.reset_running[feature] = True
        buf += f" ... {elapsed():0.2f} sec"
        self.text(buf)

    def text(self, text, end="\n"):
        raise NotImplementedError()

    def plot(self, data, plot_name, opts={}):
        raise NotImplementedError()

    def images(self, data, image_name):
        raise NotImplementedError()

    def plot_feature(self, feature, opts={}):
        """Plot the full history of a single feature."""
        self.plot(self.data[feature], feature, opts)

    def plot_features(self, features, name, opts={}):
        """Plot several feature histories together, one column per feature.

        The legend defaults to ``features``; entries in ``opts`` are merged on
        top (previously ``opts`` was accepted but silently discarded).
        """
        stacked = np.stack([self.data[feature] for feature in features], axis=1)
        merged_opts = {"legend": features}
        merged_opts.update(opts)
        self.plot(stacked, name, opts=merged_opts)
def window(self, plot_name, plot_func, *args, **kwargs):
    """Draw into the window registered for ``plot_name``, creating it when needed.

    ``plot_func`` is called with the given arguments plus ``opts`` (a title
    equal to ``plot_name``, overridden by any caller-supplied ``opts``). When
    a window is already registered under ``plot_name`` and visdom reports it
    exists, it is passed as ``win`` so it is redrawn in place. The value
    returned by ``plot_func`` is stored as the window for ``plot_name``.
    """
    plot_opts = dict(title=plot_name)
    plot_opts.update(kwargs.pop("opts", {}))

    handle = self.windows.get(plot_name)
    reusable = handle is not None and self.visdom.win_exists(handle)

    if reusable:
        result = plot_func(*args, win=handle, opts=plot_opts, **kwargs)
    else:
        result = plot_func(*args, opts=plot_opts, **kwargs)
    self.windows[plot_name] = result
# Fit (make one optimizer step) on a batch of data
def fit_on_batch(self, data, target, loss_fn=None, train=True):
    """Run one forward pass on ``data`` and, when ``train`` is true, one optimizer step.

    ``loss_fn`` defaults to ``self.loss``. The target (or each item of a list
    target) is moved to the device of the prediction. A loss function taking
    more than two parameters additionally receives the input batch.

    Returns:
        ``(pred, loss, metrics)`` as produced by the forward pass and loss.
    """
    criterion = loss_fn if loss_fn else self.loss

    def clear_gradients():
        self.zero_grad()
        self.optimizer.zero_grad()

    clear_gradients()
    self.train(train)
    clear_gradients()

    pred = self.forward(data)
    device = pred.device
    if isinstance(target, list):
        target = tuple(item.to(device) for item in target)
    else:
        target = target.to(device)

    # Loss functions with a third parameter also get the input batch.
    wants_input = len(signature(criterion).parameters) > 2
    extra_args = (data.to(device),) if wants_input else ()
    loss, metrics = criterion(pred, target, *extra_args)

    if not train:
        return pred, loss, metrics

    loss.backward()
    self.optimizer.step()
    clear_gradients()
    return pred, loss, metrics
def fit(self, datagen, loss_fn=None, logger=None):
    """Train on every batch produced by ``datagen``, discarding per-batch results.

    ``_process_data`` is a generator, so it has to be drained for the
    training steps to run. It is driven with ``train=True``, like the sibling
    ``fit_with_data`` / ``fit_with_metrics`` methods; the previous
    ``train=train`` referred to a name that is not defined in this method.
    """
    for _ in self._process_data(datagen, loss_fn=loss_fn, train=True, logger=logger):
        pass
def predict(self, datagen):
    """Predict on every batch from ``datagen`` and concatenate the results along dim 0."""
    batch_outputs = []
    for batch in datagen:
        batch_outputs.append(self.predict_on_batch(batch))
    return torch.cat(batch_outputs, dim=0)
class ResidualsNet(TrainableModel):
    """Small encoder / dilated-middle / decoder network with one skip connection.

    The encoder output is max-pooled by 4, processed by the dilated middle
    blocks, upsampled back by 4 and concatenated with the encoder output
    (32 + 32 channels) before the decoder reduces it to one channel.
    """

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            ConvBlock(3, 32, groups=3, use_groupnorm=False),
            ConvBlock(32, 32, use_groupnorm=False),
        )
        self.mid = nn.Sequential(
            ConvBlock(32, 64, dilation=1, use_groupnorm=False),
            ConvBlock(64, 64, dilation=2, use_groupnorm=False),
            ConvBlock(64, 64, dilation=2, use_groupnorm=False),
            ConvBlock(64, 32, dilation=4, use_groupnorm=False),
        )
        self.decoder = nn.Sequential(
            ConvBlock(64, 32, use_groupnorm=False),
            ConvBlock(32, 32, use_groupnorm=False),
            ConvBlock(32, 1, use_groupnorm=False),
        )

    def forward(self, x):
        tmp = self.encoder(x)
        x = F.max_pool2d(tmp, 4)
        x = self.mid(x)
        # F.upsample is deprecated; interpolate with align_corners=False is
        # the documented equivalent of its default bilinear behaviour.
        x = F.interpolate(x, scale_factor=4, mode='bilinear', align_corners=False)
        # Skip connection: concatenate with the full-resolution encoder output.
        x = torch.cat([x, tmp], dim=1)
        x = self.decoder(x)
        return x

    def loss(self, pred, target):
        """Return a constant zero loss on the prediction's device, plus its detached copy as metrics."""
        loss = torch.tensor(0.0, device=pred.device)
        return loss, (loss.detach(),)
class UNet_down_block(nn.Module):
    """Three conv(3x3) -> GroupNorm(8) -> ReLU stages, optionally followed by a 2x2 max-pool."""

    def __init__(self, input_channel, output_channel, down_size=True):
        super().__init__()
        conv_args = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(input_channel, output_channel, **conv_args)
        self.bn1 = nn.GroupNorm(num_groups=8, num_channels=output_channel)
        self.conv2 = nn.Conv2d(output_channel, output_channel, **conv_args)
        self.bn2 = nn.GroupNorm(num_groups=8, num_channels=output_channel)
        self.conv3 = nn.Conv2d(output_channel, output_channel, **conv_args)
        self.bn3 = nn.GroupNorm(num_groups=8, num_channels=output_channel)
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu = nn.ReLU()
        self.down_size = down_size

    def forward(self, x):
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in stages:
            x = self.relu(norm(conv(x)))
        # Halve the spatial size only when this block is a down-sizing one.
        return self.max_pool(x) if self.down_size else x
class ConvBlock(nn.Module):
    """Normalisation -> (optional stride-2 transposed conv + ReLU) -> 3x3 conv + ReLU.

    Args:
        f1: number of input channels (also the channel count that is normalised).
        f2: number of output channels of the final convolution.
        use_groupnorm: normalise with ``GroupNorm(groups, f1)`` when true,
            otherwise with ``BatchNorm2d(f1)``.
        groups: number of groups for the group normalisation.
        dilation: dilation of the convolutions; padding is set equal to it.
        transpose: when true, a transposed convolution (f1 -> f1, stride 2)
            is applied before the final convolution.
    """

    def __init__(self, f1, f2, use_groupnorm=True, groups=8, dilation=1, transpose=False):
        super().__init__()
        self.transpose = transpose
        self.conv = nn.Conv2d(f1, f2, (3, 3), dilation=dilation, padding=dilation)
        if self.transpose:
            self.convt = nn.ConvTranspose2d(
                f1, f1, (3, 3), dilation=dilation, stride=2, padding=dilation, output_padding=1
            )
        if use_groupnorm:
            self.bn = nn.GroupNorm(groups, f1)
        else:
            # The flag asks for no group norm. The previous GroupNorm(8, f1)
            # here rejected channel counts not divisible by 8, such as the
            # 3-channel first layer of ResidualsNet.
            self.bn = nn.BatchNorm2d(f1)

    def forward(self, x):
        x = self.bn(x)
        if self.transpose:
            x = F.relu(self.convt(x))
        x = F.relu(self.conv(x))
        return x
class ResNetDepth(TrainableModel):
    """ResNet-style encoder followed by a ConvBlock decoder with one output channel.

    The encoder is ``ResNetOriginal(Bottleneck, [3, 4, 6, 3])``; only its
    submodules up to (but excluding) the last two registered ones are used.
    """

    def __init__(self):
        super().__init__()
        self.resnet = ResNetOriginal(Bottleneck, [3, 4, 6, 3])
        self.final_conv = nn.Conv2d(2048, 8, (3, 3), padding=1)

        # Built in the original order so parameter names and initialisation
        # order stay the same: 1 entry block, 4 plain blocks, 5 transposed
        # (stride-2 ConvTranspose2d) blocks, the last one reducing to 1 channel.
        stages = [ConvBlock(8, 128)]
        stages += [ConvBlock(128, 128) for _ in range(4)]
        stages += [ConvBlock(128, 128, transpose=True) for _ in range(4)]
        stages.append(ConvBlock(128, 1, transpose=True))
        self.decoder = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the truncated encoder, then ``final_conv`` and the decoder."""
        encoder_layers = list(self.resnet._modules.values())[:-2]
        for encoder_layer in encoder_layers:
            x = encoder_layer(x)
        return self.decoder(self.final_conv(x))

    def loss(self, pred, target):
        """Return a constant zero loss on ``pred``'s device; ``target`` is ignored."""
        zero = torch.tensor(0.0, device=pred.device)
        return zero, (zero.detach(),)
class DenseNet(TrainableModel):
    """Five stacked ``ConvBlock``s mapping 3 channels to 3 channels via 96-wide layers."""

    def __init__(self):
        super().__init__()
        # The first block normalises its 3 input channels with groups=3;
        # the remaining blocks use ConvBlock's default grouping.
        blocks = [ConvBlock(3, 96, groups=3)]
        blocks.extend(ConvBlock(96, 96) for _ in range(3))
        blocks.append(ConvBlock(96, 3))
        self.decoder = nn.Sequential(*blocks)

    def forward(self, x):
        """Return ``self.decoder(x)``."""
        return self.decoder(x)

    def loss(self, pred, target):
        """Return a constant zero loss on ``pred``'s device; ``target`` is ignored."""
        zero = torch.tensor(0.0, device=pred.device)
        return zero, (zero.detach(),)
class DeepNet(TrainableModel):
    """Seven stacked 32-wide ``ConvBlock``s (3 -> 3 channels) with dilations 1, 1, 2, 2, 4, 4, 1."""

    def __init__(self):
        super().__init__()
        blocks = [ConvBlock(3, 32, groups=3), ConvBlock(32, 32)]
        # Middle of the stack: two blocks at dilation 2, then two at dilation 4.
        for rate in (2, 2, 4, 4):
            blocks.append(ConvBlock(32, 32, dilation=rate))
        blocks.append(ConvBlock(32, 3))
        self.decoder = nn.Sequential(*blocks)

    def forward(self, x):
        """Return ``self.decoder(x)``."""
        return self.decoder(x)

    def loss(self, pred, target):
        """Return a constant zero loss on ``pred``'s device; ``target`` is ignored."""
        zero = torch.tensor(0.0, device=pred.device)
        return zero, (zero.detach(),)
class ResidualsNet(TrainableModel):
    """Encoder / half-resolution middle / decoder network with one skip connection.

    ``forward`` max-pools the encoder output by 2, runs the middle stack,
    resizes the result back to the encoder output's spatial size and
    concatenates both along the channel axis (32 + 32 = 64 channels) before
    decoding to 3 channels.  All blocks use ``use_groupnorm=False``.
    """

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            ConvBlock(3, 32, use_groupnorm=False),
            ConvBlock(32, 32, use_groupnorm=False),
        )
        self.mid = nn.Sequential(
            ConvBlock(32, 64, use_groupnorm=False),
            ConvBlock(64, 64, use_groupnorm=False),
            ConvBlock(64, 32, use_groupnorm=False),
        )
        self.decoder = nn.Sequential(
            ConvBlock(64, 32, use_groupnorm=False),
            ConvBlock(32, 3, use_groupnorm=False),
        )

    def forward(self, x):
        """Return the decoded tensor for input batch ``x``."""
        tmp = self.encoder(x)
        x = F.max_pool2d(tmp, 2)
        x = self.mid(x)
        # F.upsample is deprecated in favour of F.interpolate.  Passing
        # align_corners=False explicitly keeps the value the old call fell
        # back to (with a warning) for bilinear mode.  Resizing to the skip
        # tensor's own spatial size equals scale_factor=2 for even sizes and
        # additionally keeps torch.cat valid when height/width are odd
        # (max_pool2d floors the halved size).
        x = F.interpolate(x, size=tmp.shape[-2:], mode='bilinear', align_corners=False)
        x = torch.cat([x, tmp], dim=1)
        x = self.decoder(x)
        return x

    def loss(self, pred, target):
        """Return a constant zero loss on ``pred``'s device; ``target`` is ignored."""
        loss = torch.tensor(0.0, device=pred.device)
        return loss, (loss.detach(),)
class BasicBlock(nn.Module):
    """Residual block of two 3x3 convolutions, each followed by GroupNorm(8, planes).

    ``groups`` and ``base_width`` are accepted only with the values 1 and 64;
    anything else raises ``ValueError``.  ``norm_layer`` is accepted but not
    used by this implementation.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, norm_layer=None):
        super().__init__()
        if not (groups == 1 and base_width == 64):
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        # When stride != 1 the first convolution reduces the resolution; the
        # caller-supplied ``downsample`` must do the same for the shortcut.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.GroupNorm(8, planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.GroupNorm(8, planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        identity = x if self.downsample is None else self.downsample(x)
        out += identity
        return self.relu(out)
class ResNetOriginal(nn.Module):
    """ResNet-style classifier built from ``block`` with GroupNorm normalisation.

    Four stages of 64/128/256/512 planes are stacked after a 7x7 stem and a
    max-pool; only the second stage uses stride 2, the others stride 1.
    ``layers`` gives the number of blocks per stage.  ``block`` must expose an
    ``expansion`` attribute and accept ``(inplanes, planes, stride, downsample)``.
    """

    # (planes, stride) for layer1 .. layer4.
    _STAGES = ((64, 1), (128, 2), (256, 1), (512, 1))

    def __init__(self, block, layers, in_channels=3, num_classes=1000):
        super().__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.GroupNorm(8, 64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        for index, (planes, stride) in enumerate(self._STAGES):
            stage = self._make_layer(block, planes, layers[index], stride=stride)
            setattr(self, "layer%d" % (index + 1), stage)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm2d):
                # NOTE: the layers created in this class are GroupNorm, which
                # this branch does not match; it only affects BatchNorm2d
                # modules that a supplied ``block`` might contain.
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` blocks and update ``self.inplanes``."""
        out_planes = planes * block.expansion
        if stride == 1 and self.inplanes == out_planes:
            downsample = None
        else:
            # Projection shortcut for the first block of the stage.
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes, kernel_size=1, stride=stride,
                          bias=False),
                nn.GroupNorm(8, out_planes),
            )

        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        stage.extend(block(self.inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return the ``fc`` output for input batch ``x``."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)
class ResNetClass(TrainableModel):
    """Feature extractor wrapping torchvision's ``resnet50(pretrained=True)``.

    ``forward`` applies every registered submodule of the wrapped network
    except the last two, so it returns a feature map rather than class scores.
    """

    def __init__(self):
        super().__init__()
        self.resnet = models.resnet50(pretrained=True)

    def forward(self, x):
        """Return backbone features; 1-channel input is tiled to 3 channels first."""
        features = x.repeat(1, 3, 1, 1) if x.shape[1] == 1 else x
        backbone = list(self.resnet._modules.values())[:-2]
        for stage in backbone:
            features = stage(features)
        return features

    # Marked "Not in Use Right Now" by the original author.
    def loss(self, pred, target):
        """Return the MSE between ``pred`` and ``target`` over the masked elements.

        The mask comes from ``build_mask(pred, val=0.502)``; what it selects
        is defined by that helper, which is not part of this class.
        """
        valid = build_mask(pred, val=0.502)
        mse = F.mse_loss(pred[valid], target[valid])
        return mse, (mse.detach(),)
class UNet_down_block(nn.Module):
    """Encoder stage: three (3x3 conv -> GroupNorm(8) -> ReLU) steps, then optional 2x2 max-pool.

    Args:
        input_channel: channels of the incoming tensor.
        output_channel: channels produced by every convolution in the block.
        down_size: when true, the output is max-pooled with kernel 2, stride 2.
    """

    def __init__(self, input_channel, output_channel, down_size=True):
        super().__init__()
        # Attribute names and creation order are kept so parameter names and
        # initialisation order do not change.
        self.conv1 = nn.Conv2d(input_channel, output_channel, 3, padding=1)
        self.bn1 = nn.GroupNorm(8, output_channel)
        self.conv2 = nn.Conv2d(output_channel, output_channel, 3, padding=1)
        self.bn2 = nn.GroupNorm(8, output_channel)
        self.conv3 = nn.Conv2d(output_channel, output_channel, 3, padding=1)
        self.bn3 = nn.GroupNorm(8, output_channel)
        self.max_pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()
        self.down_size = down_size

    def forward(self, x):
        """Return the block output for ``x``."""
        steps = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in steps:
            x = self.relu(norm(conv(x)))
        return self.max_pool(x) if self.down_size else x
class UNetReshade(TrainableModel):
    """U-Net whose output is collapsed to one clamped channel and expanded to 3 channels.

    Args:
        downsample: number of pooling encoder stages (and matching up blocks).
        in_channels: channels of the input tensor.
        out_channels: channels produced by the last convolution, before the
            channel mean is taken.
    """

    def __init__(self, downsample=6, in_channels=3, out_channels=3):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.downsample = downsample

        self.down1 = UNet_down_block(in_channels, 16, False)
        encoder = []
        for level in range(downsample):
            narrow = 2 ** (4 + level)
            encoder.append(UNet_down_block(narrow, 2 * narrow, True))
        self.down_blocks = nn.ModuleList(encoder)

        width = 2 ** (4 + downsample)
        self.mid_conv1 = nn.Conv2d(width, width, 3, padding=1)
        self.bn1 = nn.GroupNorm(8, width)
        self.mid_conv2 = nn.Conv2d(width, width, 3, padding=1)
        self.bn2 = nn.GroupNorm(8, width)
        self.mid_conv3 = nn.Conv2d(width, width, 3, padding=1)
        self.bn3 = nn.GroupNorm(8, width)

        decoder = []
        for level in range(downsample):
            narrow = 2 ** (4 + level)
            decoder.append(UNet_up_block(narrow, 2 * narrow, narrow))
        self.up_blocks = nn.ModuleList(decoder)

        self.last_conv1 = nn.Conv2d(16, 16, 3, padding=1)
        self.last_bn = nn.GroupNorm(8, 16)
        self.last_conv2 = nn.Conv2d(16, out_channels, 1, padding=0)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return a 3-channel tensor whose channels all hold the clamped channel mean."""
        x = self.down1(x)
        skips = [x]
        for down_block in self.down_blocks:
            x = down_block(x)
            skips.append(x)

        for conv, norm in (
            (self.mid_conv1, self.bn1),
            (self.mid_conv2, self.bn2),
            (self.mid_conv3, self.bn3),
        ):
            x = self.relu(norm(conv(x)))

        # Walk back up; up block i is paired with the activation stored at index i.
        for level in reversed(range(self.downsample)):
            x = self.up_blocks[level](skips[level], x)

        x = self.relu(self.last_bn(self.last_conv1(x)))
        x = self.relu(self.last_conv2(x))
        # Clamp to [0, 1], average over channels, then broadcast to 3 channels.
        x = x.clamp(max=1, min=0).mean(dim=1, keepdim=True)
        return x.expand(-1, 3, -1, -1)

    def loss(self, pred, target):
        """Return a constant zero loss on ``pred``'s device; ``target`` is ignored."""
        zero = torch.tensor(0.0, device=pred.device)
        return zero, (zero.detach(),)
class ConvBlock(nn.Module):
    """Normalise the input, optionally upsample with a transposed conv, then convolve.

    Args:
        f1: input channels (also the width of the norm and transposed conv).
        f2: output channels of the final convolution.
        kernel_size: side of the square kernel of the final convolution.
        padding: base padding; the convolution uses ``padding * dilation``.
        use_groupnorm: GroupNorm(groups, f1) when true, BatchNorm2d(f1) otherwise.
        groups: number of groups for GroupNorm.
        dilation: dilation of both the convolution and the transposed conv.
        transpose: when true, a stride-2 3x3 ConvTranspose2d + ReLU runs before
            the final convolution.
    """

    def __init__(self, f1, f2, kernel_size=3, padding=1, use_groupnorm=True,
                 groups=8, dilation=1, transpose=False):
        super().__init__()
        self.transpose = transpose
        self.conv = nn.Conv2d(
            f1, f2, (kernel_size, kernel_size),
            dilation=dilation, padding=padding * dilation,
        )
        if transpose:
            self.convt = nn.ConvTranspose2d(
                f1, f1, (3, 3),
                dilation=dilation, stride=2, padding=dilation, output_padding=1,
            )
        self.bn = nn.GroupNorm(groups, f1) if use_groupnorm else nn.BatchNorm2d(f1)

    def forward(self, x):
        """Return ``relu(conv(...))`` of the normalised (and optionally upsampled) input."""
        normed = self.bn(x)
        if self.transpose:
            normed = F.relu(self.convt(normed))
        return F.relu(self.conv(normed))
class UNet_up_block(nn.Module):
    """Decoder stage that reflect-pads each unpadded 3x3 convolution's output.

    The incoming tensor is optionally bilinearly upsampled by 2, concatenated
    with the skip tensor along channels, and passed through three
    (conv -> reflect pad by 1 -> GroupNorm(8) -> ReLU) steps.  The padding
    restores the border the unpadded convolution removes.

    Args:
        prev_channel: channels of the skip tensor.
        input_channel: channels of the tensor coming from the deeper stage.
        output_channel: channels produced by every convolution in the block.
        up_sample: whether to upsample the deeper tensor before concatenation.
    """

    def __init__(self, prev_channel, input_channel, output_channel, up_sample=True):
        super().__init__()
        self.up_sampling = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.conv1 = nn.Conv2d(prev_channel + input_channel, output_channel, 3)
        self.bn1 = nn.GroupNorm(8, output_channel)
        self.conv2 = nn.Conv2d(output_channel, output_channel, 3)
        self.bn2 = nn.GroupNorm(8, output_channel)
        self.conv3 = nn.Conv2d(output_channel, output_channel, 3)
        self.bn3 = nn.GroupNorm(8, output_channel)
        self.relu = nn.ReLU()
        self.up_sample = up_sample

    def forward(self, prev_feature_map, x):
        """Return the block output for skip tensor ``prev_feature_map`` and deeper tensor ``x``."""
        if self.up_sample:
            x = self.up_sampling(x)
        merged = torch.cat((x, prev_feature_map), dim=1)

        border = (1, 1, 1, 1)
        for conv, norm in (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        ):
            padded = F.pad(conv(merged), border, mode='reflect')
            merged = self.relu(norm(padded))
        return merged
class UNetReshade(TrainableModel):
    """U-Net whose output is collapsed to one clamped channel and expanded to 3 channels.

    Args:
        downsample: number of pooling encoder stages (and matching up blocks).
        in_channels: channels of the input tensor.
        out_channels: channels produced by the last convolution, before the
            channel mean is taken.
    """

    def __init__(self, downsample=6, in_channels=3, out_channels=3):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.downsample = downsample

        self.down1 = UNet_down_block(in_channels, 16, False)
        # Stage i widens 2**(4+i) channels to 2**(5+i) and halves the resolution.
        widths = [2 ** (4 + level) for level in range(downsample)]
        self.down_blocks = nn.ModuleList(
            [UNet_down_block(width, 2 * width, True) for width in widths]
        )

        deepest = 2 ** (4 + downsample)
        self.mid_conv1 = nn.Conv2d(deepest, deepest, 3, padding=1)
        self.bn1 = nn.GroupNorm(8, deepest)
        self.mid_conv2 = nn.Conv2d(deepest, deepest, 3, padding=1)
        self.bn2 = nn.GroupNorm(8, deepest)
        self.mid_conv3 = nn.Conv2d(deepest, deepest, 3, padding=1)
        self.bn3 = nn.GroupNorm(8, deepest)

        self.up_blocks = nn.ModuleList(
            [UNet_up_block(width, 2 * width, width) for width in widths]
        )

        self.last_conv1 = nn.Conv2d(16, 16, 3, padding=1)
        self.last_bn = nn.GroupNorm(8, 16)
        self.last_conv2 = nn.Conv2d(16, out_channels, 1, padding=0)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return a 3-channel tensor whose channels all hold the clamped channel mean."""
        x = self.down1(x)
        skips = [x]
        for level in range(self.downsample):
            x = self.down_blocks[level](x)
            skips.append(x)

        x = self.relu(self.bn1(self.mid_conv1(x)))
        x = self.relu(self.bn2(self.mid_conv2(x)))
        x = self.relu(self.bn3(self.mid_conv3(x)))

        # Decode from the deepest level upwards, pairing up block i with skips[i].
        level = self.downsample - 1
        while level >= 0:
            x = self.up_blocks[level](skips[level], x)
            level -= 1

        x = self.relu(self.last_bn(self.last_conv1(x)))
        x = self.relu(self.last_conv2(x))
        # Clamp to [0, 1], average over channels, then broadcast to 3 channels.
        shading = x.clamp(max=1, min=0).mean(dim=1, keepdim=True)
        return shading.expand(-1, 3, -1, -1)

    def loss(self, pred, target):
        """Return a constant zero loss on ``pred``'s device; ``target`` is ignored."""
        zero = torch.tensor(0.0, device=pred.device)
        return zero, (zero.detach(),)
class ConvBlock(nn.Module):
    """Norm -> (optional stride-2 transposed conv + ReLU) -> conv + ReLU.

    Args:
        f1: input channels; the norm layer and transposed conv both use this width.
        f2: output channels of the final convolution.
        kernel_size: side of the square kernel of the final convolution.
        padding: base padding, multiplied by ``dilation`` for the convolution.
        use_groupnorm: choose GroupNorm(groups, f1) over BatchNorm2d(f1).
        groups: group count passed to GroupNorm.
        dilation: dilation shared by the convolution and the transposed conv.
        transpose: enable the 3x3, stride-2 ConvTranspose2d upsampling step.
    """

    def __init__(self, f1, f2, kernel_size=3, padding=1, use_groupnorm=True,
                 groups=8, dilation=1, transpose=False):
        super().__init__()
        self.transpose = transpose

        kernel = (kernel_size, kernel_size)
        self.conv = nn.Conv2d(f1, f2, kernel, dilation=dilation, padding=padding * dilation)

        if self.transpose:
            self.convt = nn.ConvTranspose2d(
                f1, f1, (3, 3), dilation=dilation, stride=2,
                padding=dilation, output_padding=1,
            )

        norm_layer = nn.GroupNorm(groups, f1) if use_groupnorm else nn.BatchNorm2d(f1)
        self.bn = norm_layer

    def forward(self, x):
        """Return the activated convolution of the normalised input."""
        out = self.bn(x)
        if self.transpose:
            out = F.relu(self.convt(out))
        out = self.conv(out)
        return F.relu(out)
if __name__ == "__main__":
    # Smoke test: push one random 3x256x256 image through a default UNet and
    # print the shape of the result.
    net = UNet()
    sample = torch.randn(1, 3, 256, 256)
    print(net(sample).shape)
def get_running_p_coeffs(list_of_list_values_1, list_of_list_values_2):
    """Return one Pearson correlation coefficient per time step.

    Args:
        list_of_list_values_1: sequence of per-step value sequences.
        list_of_list_values_2: sequence of per-step value sequences; must have
            the same number of steps as the first argument, and each step must
            hold as many values as the matching step of the first argument.

    Returns:
        A list with, for every step ``ii``, the Pearson coefficient between
        ``list_of_list_values_1[ii]`` and ``list_of_list_values_2[ii]``.
        A step with zero variance yields ``nan`` (NumPy division semantics).
    """
    assert len(list_of_list_values_1) == len(list_of_list_values_2)
    pearson_coefficients = []
    for values_1, values_2 in zip(list_of_list_values_1, list_of_list_values_2):
        # np.cov normalises by N-1 by default whereas np.std normalises by N.
        # Mixing the two scales the result by N/(N-1) (so |r| could exceed 1);
        # bias=True makes the covariance use the same N normalisation.
        cov = np.cov(np.stack((values_1, values_2), axis=0), bias=True)[0, 1]
        std1 = np.std(values_1)
        std2 = np.std(values_2)
        pearson_coefficients.append(cov / (std1 * std2))
    return pearson_coefficients
def curvatureplots(data, logger):
    """Plot the curvature-loss curves and validation statistics held by ``logger``.

    All series are read from ``logger.data``; the ``data`` argument is accepted
    but its incoming value is never used.

    Four plots are emitted through ``logger.plot``:
      * train vs. validation curvature loss,
      * running mean of the validation losses with +/- one std bounds,
      * running std of the validation losses,
      * mean (+/- std) of the step-to-step difference of the validation losses.
    """
    train_vs_val = np.stack(
        (logger.data["train_curvature_loss"], logger.data["val_curvature_loss"]),
        axis=1,
    )
    logger.plot(train_vs_val, "curvature_loss", opts={"legend": ["train_curvature", "val_curvature"]})

    mean_bounds, mean_legend = get_running_means_w_std_bounds_and_legend(
        logger.data["val_curvature_losses"]
    )
    logger.plot(mean_bounds, "val_curvature_loss_running_mean", opts={"legend": mean_legend})

    spreads = get_running_std(logger.data["val_curvature_losses"])
    logger.plot(spreads, "val_curvature_losses_running_stds", opts={"legend": ['STD']})

    delta_bounds, delta_legend = get_running_means_w_std_bounds_and_legend_on_diff_prev_time_step(
        logger.data["val_curvature_losses"]
    )
    logger.plot(delta_bounds, "val_curvature_loss_diff_prev_step_running_mean", opts={"legend": delta_legend})
def covarianceplot(data, logger):
    """Plot correlations and std ratios between the validation loss series.

    Reads ``val_mse_losses``, ``val_curvature_losses`` and ``val_depth_losses``
    from ``logger.data``; the ``data`` argument is accepted but never used.

    Emits, through ``logger.plot``:
      * the running Pearson coefficient for each pair of series,
      * the ratio of the running std of the MSE losses over the running std of
        the curvature and depth losses.  A zero denominator follows NumPy
        division semantics (``inf``/``nan``).
    """
    pearson_plots = (
        ("val_mse_losses", "val_curvature_losses", "val_mse_curvature_running_pearson_coeffs"),
        ("val_mse_losses", "val_depth_losses", "val_mse_depth_running_pearson_coeffs"),
        # This series is computed from validation data; it used to be
        # published under a "train_" name, unlike its two siblings.
        ("val_curvature_losses", "val_depth_losses", "val_curvature_depth_running_pearson_coeffs"),
    )
    for first_key, second_key, plot_name in pearson_plots:
        coeffs = get_running_p_coeffs(logger.data[first_key], logger.data[second_key])
        logger.plot(coeffs, plot_name, opts={"legend": ['Pearson Coefficient']})

    ratio_plots = (
        ("val_curvature_losses", "val_mse_over_curvature_stds", 'MSE / Curvature STD'),
        ("val_depth_losses", "val_mse_over_depth_stds", 'MSE / Depth STD'),
    )
    for other_key, plot_name, legend in ratio_plots:
        ratios = [
            mse_std / other_std
            for mse_std, other_std in zip(
                get_running_std(logger.data["val_mse_losses"]),
                get_running_std(logger.data[other_key]),
            )
        ]
        logger.plot(ratios, plot_name, opts={"legend": [legend]})
def main(
    loss_config="conservative_full", mode="standard",
    pretrained=True, finetuned=False, batch_size=16,
    ood_batch_size=None, subset_size=None,
    cont=None,
    use_l1=True, num_workers=32, data_dir=None, save_dir='mount/shared/', **kwargs,
):
    """Compute per-image energies, correlate them with the error and plot both.

    Two modes, selected by ``data_dir``:

    * ``data_dir is None``: energies and errors are computed on the built-in
      buildings, written to ``{save_dir}/data.csv`` and plotted.
    * ``data_dir`` given ("custom query"): energies are computed on the images
      in ``data_dir``; the reference statistics are read back from an existing
      ``{save_dir}/data.csv`` and the mean query energy is drawn as a
      horizontal line on the plot.

    Args:
        loss_config, mode, **kwargs: forwarded to ``get_energy_loss``.
        finetuned: forwarded to ``TaskGraph``.
        batch_size: batch size of the evaluation reality.
        subset_size: upper bound on the number of samples evaluated; defaults
            to (and is capped at) the dataset length.
        use_l1: forwarded to the energy loss call.
        data_dir: directory of query images, or ``None`` (see above).
        save_dir: directory receiving ``data.csv`` and ``energy.pdf``.
        pretrained, ood_batch_size, cont, num_workers: accepted for
            command-line compatibility; not used by this function.
    """
    # `import scipy` alone does not guarantee the `stats` submodule is loaded.
    import scipy.stats

    # CONFIG
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # Remember the mode as a boolean instead of comparing strings by identity.
    custom_query = data_dir is not None
    if custom_query:
        train_subset_dataset = ImageDataset(data_dir=data_dir)
    else:
        buildings = ["almena", "albertville"]
        train_subset_dataset = TaskDataset(buildings, tasks=[tasks.rgb, tasks.normal, tasks.principal_curvature])

    train_subset = RealityTask("train_subset", train_subset_dataset, batch_size=batch_size, shuffle=False)

    if subset_size is None:
        subset_size = len(train_subset_dataset)
    subset_size = min(subset_size, len(train_subset_dataset))

    # GRAPH
    graph = TaskGraph(
        tasks=energy_loss.tasks + [train_subset],
        finetuned=finetuned,
        freeze_list=energy_loss.freeze_list,
        lazy=True,
        initialize_from_transfer=True,
    )
    graph.compile(optimizer=None)

    # Add consistency links: first (re)load every edge model, then its weights.
    consistency_weights = (
        ('reshading', './models/rgb2reshading_consistency.pth'),
        ('depth_zbuffer', './models/rgb2depth_consistency.pth'),
        ('normal', './models/rgb2normal_consistency.pth'),
    )
    for target, _ in consistency_weights:
        edge = graph.edge_map[str(('rgb', target))]
        edge.path = None
        edge.load_model()
    for target, weights_path in consistency_weights:
        graph.edge_map[str(('rgb', target))].model.load_weights(weights_path, backward_compatible=True)

    energy_losses, error_losses = [], []
    energy_losses_all, energy_losses_headings = [], []

    train_subset.reload()
    # Compute energies
    for _ in tqdm(range(subset_size // batch_size)):
        with torch.no_grad():
            losses = energy_loss(graph, realities=[train_subset], reduce=False, use_l1=use_l1)

            # The perceptual loss names are fixed from the first batch on.
            if not energy_losses_headings:
                energy_losses_headings = sorted(name for name in losses if 'percep' in name)

            all_perceps = [losses[name].cpu().numpy() for name in energy_losses_headings]
            direct_losses = [losses[name].cpu().numpy() for name in losses if 'direct' in name]

            if all_perceps:
                energy_losses_all.append(all_perceps)
                energy_losses.extend(np.stack(all_perceps).mean(0))

            if direct_losses:
                error_losses.extend(np.stack(direct_losses).mean(0))
        train_subset.step()

    # log losses
    if len(energy_losses) > 0:
        energy_losses = np.array(energy_losses)
        print(f'energy = {energy_losses.mean()}')

    if len(error_losses) > 0:
        error_losses = np.array(error_losses)
        print(f'error = {error_losses.mean()}')

    # save to csv
    save_error_losses = error_losses if len(error_losses) > 0 else [0] * subset_size
    save_energy_losses = energy_losses if len(energy_losses) > 0 else [0] * subset_size

    def z_score(x):
        """Standardise ``x`` to zero mean and unit std."""
        return (x - x.mean()) / x.std()

    def get_standardized_energy(df):
        """Average the standardised perceptual columns of ``df``.

        Each column whose name contains 'percep' is centred on its mean and
        divided by its mean absolute deviation; the columns whose name ends
        with '_' or contains '__' are then averaged row-wise.
        """
        percepts = [c for c in df.columns if 'percep' in c]
        stdized = {}
        for k in percepts:
            centred = df[k] - df[k].mean()
            stdized[k] = centred / centred.abs().mean()
        return np.stack([v for k, v in stdized.items() if k[-1] == '_' or '__' in k]).mean(0)

    os.makedirs(save_dir, exist_ok=True)
    if custom_query:
        eng_curr = np.array(energy_losses).mean()
        df = pd.read_csv(os.path.join(save_dir, 'data.csv'))
    else:
        percep_losses = dict(zip(energy_losses_headings, np.concatenate(energy_losses_all, axis=-1)))
        # NOTE(review): `both` is not defined in this script; presumably it is
        # a dict-merging helper provided by `from utils import *` -- confirm.
        df = pd.DataFrame(both(
            {'energy': save_energy_losses, 'error': save_error_losses},
            percep_losses,
        ))

    # compute correlation
    df['normalized_energy'] = get_standardized_energy(df)
    df['normalized_error'] = z_score(df['error'])
    print(scipy.stats.spearmanr(z_score(df['error']), df['normalized_energy']))
    print("Pearson r:", scipy.stats.pearsonr(df['error'], df['normalized_energy']))
    if not custom_query:
        df.to_csv(f"{save_dir}/data.csv", mode='w', header=True)

    # plot correlation
    plt.figure(figsize=(4, 4))
    # x/y are passed by keyword: recent seaborn releases treat the first
    # positional argument as `data`.
    g = sns.regplot(x=df['normalized_error'], y=df['normalized_energy'], robust=False)
    if custom_query:
        ax1 = g.axes
        ax1.axhline(eng_curr, ls='--', color='red')
        ax1.text(0.5, 25, "Query Image Energy Line")
    plt.xlabel('Error (z-score)')
    plt.ylabel('Energy (z-score)')
    plt.title('')
    plt.savefig(f'{save_dir}/energy.pdf')
class GaussianBulr(object):
    """Callable transform that applies a PIL Gaussian blur of a fixed radius.

    The (misspelled) class name is kept because other code refers to it.
    """

    def __init__(self, radius):
        """Store ``radius`` and build the matching ``ImageFilter.GaussianBlur``."""
        self.radius = radius
        self.filter = ImageFilter.GaussianBlur(radius)

    def __call__(self, im):
        """Return ``im`` filtered with the stored blur filter."""
        return im.filter(self.filter)

    def __repr__(self):
        # '{}' rather than '{:d}': the latter raises ValueError for a float
        # radius, while integer radii render identically with both.
        return 'GaussianBulr Filter with Radius {}'.format(self.radius)
def get_model(src_task, dest_task):
    """Build a fresh, untrained model for the ``src_task -> dest_task`` transfer.

    Both tasks may be given as task objects or, together, as task names (names
    are resolved through ``get_task``).  A pair registered in ``model_types``
    uses its registered factory; otherwise an architecture is picked from the
    task classes.  Returns ``None`` when no rule applies.
    """
    given_as_names = isinstance(src_task, str) and isinstance(dest_task, str)
    if given_as_names:
        src_task = get_task(src_task)
        dest_task = get_task(dest_task)

    pair = (src_task.name, dest_task.name)
    if pair in model_types:
        return model_types[pair]()

    # Every fallback rule below requires an image-valued source.
    if not isinstance(src_task, ImageTask):
        return None

    if isinstance(dest_task, ImageTask):
        return UNet(downsample=5, in_channels=src_task.shape[0], out_channels=dest_task.shape[0])

    # NOTE(review): `ClassTask` and `ResNet` are not defined or imported in the
    # visible part of this module (it defines ImageClassTask and imports
    # ResNetClass); presumably they come from `from utils import *` -- confirm,
    # otherwise reaching these branches raises NameError.
    if isinstance(dest_task, ClassTask):
        return ResNet(in_channels=src_task.shape[0], out_channels=dest_task.classes)

    if isinstance(dest_task, PointInfoTask):
        return ResNet(out_channels=dest_task.out_channels)

    return None
class Task(object):
    """General task output space.

    A task is identified by its ``name``: equality, hashing and ``repr`` are
    all based on it.

    Args:
        name: task identifier.
        file_name: base name of the task's files; defaults to ``name``.
        file_name_alt: alternative base name; defaults to ``file_name``.
        file_ext: file extension of the task's files.
        file_loader: optional callable replacing the ``file_loader`` method.
        plot_func: optional callable replacing the ``plot_func`` method.
    """

    # Per-task constants looked up by name in __init__ (default 1.0).
    variances = {
        "normal": 1.0,
        "principal_curvature": 1.0,
        "sobel_edges": 5,
        "depth_zbuffer": 0.1,
        "reshading": 1.0,
        "keypoints2d": 0.3,
        "keypoints3d": 0.6,
        "edge_occlusion": 0.1,
    }

    def __init__(self, name,
            file_name=None, file_name_alt=None, file_ext="png", file_loader=None,
            plot_func=None
        ):
        super().__init__()
        self.name = name
        self.file_name, self.file_ext = file_name or name, file_ext
        self.file_name_alt = file_name_alt or self.file_name
        # Fall back to the (possibly overridden) methods when no callable is given.
        self.file_loader = file_loader or self.file_loader
        self.plot_func = plot_func or self.plot_func
        self.variance = Task.variances.get(name, 1.0)
        self.kind = name

    def norm(self, pred, target, batch_mean=True, compute_mse=True):
        """Return ``(loss, (detached_mean_loss,))`` between ``pred`` and ``target``.

        Args:
            pred, target: tensors of identical shape.
            batch_mean: when true the error is averaged over every element;
                otherwise dimensions 1, 2 and 3 are averaged in turn, leaving
                one value per leading index (so 4-D inputs are expected).
            compute_mse: squared error when true, absolute error otherwise.
        """
        diff = pred - target
        error = diff**2 if compute_mse else diff.abs()
        if batch_mean:
            loss = error.mean()
        else:
            loss = error.mean(dim=1).mean(dim=1).mean(dim=1)
        return loss, (loss.mean().detach(),)

    def __call__(self, size=256):
        """Return a deep copy of this task; ``size`` is unused at this level."""
        task = copy.deepcopy(self)
        return task

    def plot_func(self, data, name, logger, **kwargs):
        ### Non-image tasks cannot be easily plotted, default to nothing
        pass

    def file_loader(self, path, resize=None, seed=0, T=0):
        """Load one sample from ``path``; must be provided by subclasses."""
        raise NotImplementedError()

    def __eq__(self, other):
        # Comparing with an object that has no name (None, a string, ...)
        # must answer "not equal" instead of raising AttributeError.
        if not hasattr(other, "name"):
            return NotImplemented
        return self.name == other.name

    def __repr__(self):
        return self.name

    def __hash__(self):
        return hash(self.name)
class RealityTask(Task):
    """Task node that supplies batches of data for a list of tasks.

    The current batch is exposed as ``self.task_data``, a dict mapping each
    entry of ``self.tasks`` to the corresponding tensor.

    Args:
        name: name of this reality.
        dataset: dataset to batch; when ``tasks`` is None its ``tasks``
            attribute (or that of ``dataset.dataset``) is used.
        tasks: tasks matching, in order, the items of one batch.
        use_dataset: when false no loader is built (used by the alternate
            constructors, which attach their own data source).
        shuffle, batch_size: forwarded to the ``DataLoader``.
    """

    def __init__(self, name, dataset, tasks=None, use_dataset=True, shuffle=False, batch_size=64):
        super().__init__(name=name)
        if tasks is not None:
            self.tasks = tasks
        elif hasattr(dataset, "dataset"):
            self.tasks = dataset.dataset.tasks
        else:
            self.tasks = dataset.tasks
        self.shape = (1,)
        if not use_dataset:
            return
        self.dataset, self.shuffle, self.batch_size = dataset, shuffle, batch_size
        # Build the batch generator and load the first batch.
        self.reload()
        self.step()
        self.static = False

    @classmethod
    def from_dataloader(cls, name, loader, tasks):
        """Build a reality fed by an existing ``loader``."""
        reality = cls(name, None, tasks, use_dataset=False)
        reality.loader = loader
        reality.generator = cycle(loader)
        reality.static = False
        reality.step()
        return reality

    @classmethod
    def from_static(cls, name, data, tasks):
        """Build a reality holding one fixed batch ``data`` (one tensor per task)."""
        reality = cls(name, None, tasks, use_dataset=False)
        reality.task_data = {task: x.requires_grad_() for task, x in zip(tasks, data)}
        reality.static = True
        return reality

    def norm(self, pred, target, batch_mean=True, compute_mse=True):
        """Return a zero loss on ``pred``'s device.

        ``compute_mse`` is accepted (and ignored) so the method can be called
        with the same keywords as ``Task.norm``.
        """
        loss = torch.tensor(0.0, device=pred.device)
        return loss, (loss.detach(),)

    def step(self):
        """Advance to the next batch and store it in ``self.task_data``."""
        self.task_data = {task: x.requires_grad_() for task, x in zip(self.tasks, next(self.generator))}

    def reload(self):
        """Rebuild the batch generator from ``self.dataset``.

        Requires a reality created with ``use_dataset=True``.
        """
        loader = torch.utils.data.DataLoader(
            self.dataset, batch_size=self.batch_size,
            num_workers=0, shuffle=self.shuffle, pin_memory=True
        )
        self.generator = cycle(loader)
def file_loader(self, path, resize=None, crop=None, seed=0, jitter=False):
    """Load the image at ``path`` and return its first three transformed channels.

    ``resize``, ``crop``, ``seed`` and ``jitter`` are forwarded to
    ``load_image_transform`` to build the transform applied to the image.
    """
    image_transform = self.load_image_transform(resize=resize, crop=crop, seed=seed, jitter=jitter)
    # The transform is applied while the file is still open; the handle is
    # closed as soon as the result has been produced instead of being leaked.
    with open(path, 'rb') as image_file:
        return image_transform(Image.open(image_file))[0:3]
class PointInfoTask(Task):
    """Output space for point-info prediction tasks.

    Keyword Args:
        point_type: key read from each sample's JSON file
            (default "vanishing_points_gaussian_sphere").
    Remaining arguments are forwarded to ``Task``.
    """

    def __init__(self, *args, **kwargs):
        self.point_type = kwargs.pop("point_type", "vanishing_points_gaussian_sphere")
        self.out_channels = 9
        super().__init__(*args, **kwargs)

    def plot_func(self, data, name, logger, **kwargs):
        """Show ``data`` as text in the logger window called ``name``.

        Extra keyword arguments are accepted and ignored, matching the
        signature of ``Task.plot_func``.
        """
        logger.window(name, logger.visdom.text, str(data.data.cpu().numpy()))

    def file_loader(self, path, resize=None):
        """Read the JSON file at ``path`` and return its x, y, z lists concatenated.

        ``resize`` is accepted for signature compatibility and ignored.
        """
        # Close the file deterministically instead of leaking the handle.
        with open(path) as json_file:
            points = json.load(json_file)[self.point_type]
        return np.array(points['x'] + points['y'] + points['z'])
def blur_transform(x, max_val=4000.0):
    """Gaussian-blur ``x`` and compress it into ``[0, 1]``.

    Args:
        x: tensor; a leading dimension of size 1 is dropped before blurring.
        max_val: value mapped to 1 after the power compression.

    Returns:
        ``clamp((blur(x) ** 0.8) / (max_val ** 0.8), 0, 1)`` with a new
        leading dimension of size 1.
    """
    if x.shape[0] == 1:
        x = x.squeeze(0)
    image = x.detach().cpu().numpy()
    # Use the public `ndimage.gaussian_filter`; the `ndimage.filters`
    # namespace is deprecated in recent SciPy releases.
    blur = ndimage.gaussian_filter(image, sigma=2)
    norm = torch.FloatTensor(blur).unsqueeze(0)**0.8 / (max_val**0.8)
    # After unsqueeze(0) the leading dimension is always 1, so no further
    # reshaping is needed.
    return norm.clamp(min=0, max=1)
#!/usr/bin/env bash
# Download the sample dataset archive and unpack the per-building task archives
# inside ./data (presumably created by the first unzip -- confirm).

# Stop at the first failure so that a failed download or `cd` cannot make the
# following unzip/rm commands run in the wrong directory.
set -e

wget https://drive.switch.ch/index.php/s/0Fqr6t6cZsI0cp9/download
unzip download
rm download

cd data
for building in albertville almena; do
    for task in rgb normal principal_curvature; do
        unzip -qqo "${building}_${task}.zip"
    done
done
rm albertville_rgb.zip albertville_normal.zip albertville_principal_curvature.zip almena_rgb.zip almena_normal.zip almena_principal_curvature.zip
cd -
# Get energy-graph-specific links wget https://drive.switch.ch/index.php/s/aZDOEBixS4W7mBL/download unzip download rm download mv energy_graph_edges/* models/ rmdir energy_graph_edges fi ================================================ FILE: tools/download_models.sh ================================================ ##!/usr/bin/env bash wget https://drive.switch.ch/index.php/s/QPvImzbbdjBKI5P/download unzip download rm download ================================================ FILE: tools/download_percep_models.sh ================================================ ##!/usr/bin/env bash wget https://drive.switch.ch/index.php/s/aXu4EFaznqtNzsE/download unzip download rm download mv percep_models/* models/ rmdir percep_models ================================================ FILE: train.py ================================================ ''' Name: train.py Desc: Executes training of a network with the consistency framework. Here are some options that may be specified for any model. If they have a default value, it is given at the end of the description in parens. Data pipeline: Data locations: 'train_buildings': A list of the folders containing the training data. This is defined in configs/split.txt. 'val_buildings': As above, but for validation data. 'data_dirs': The folder that all the data is stored in. This may just be something like '/', and then all filenames in 'train_filenames' will give paths relative to 'dataset_dir'. For example, if 'dataset_dir'='/', then train_filenames might have entries like 'path/to/data/img_01.png'. This is defiled in utils.py. Logging: 'results_dir': An absolute path to where checkpoints are saved. This is defined in utils.py. Training: 'batch_size': The size of each batch. (64) 'num_epochs': The maximum number of epochs to train for. (800) 'energy_config': {multiperceptual_targettask} The paths taken to compute the losses. 'k': Number of perceptual loss chosen. 'data_aug': {True, False} If data augmentation shuold be used during training. 
def main(
    loss_config="multiperceptual", mode="winrate", visualize=False,
    fast=False, batch_size=None,
    subset_size=None, max_epochs=800, dataaug=False, **kwargs,
):
    """Train a task graph with the consistency energy loss.

    Args:
        loss_config, mode, **kwargs: forwarded to ``get_energy_loss``.
        visualize: when true, return right after the first path plot of the
            training loop (no parameter update is performed).
        fast: smoke-test mode -- batch size defaults to 4, the validation set
            is also used for training, 2 steps per phase, no test reality.
        batch_size: batch size; defaults to 4 in fast mode, 64 otherwise.
        subset_size, dataaug: forwarded to ``load_train_val``.
        max_epochs: number of training epochs.
    """

    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size, fast=fast,
        subset_size=subset_size,
        dataaug=dataaug,
    )

    if fast:
        train_dataset = val_dataset
        train_step, val_step = 2,2

    train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)

    if fast:
        # NOTE(review): these two assignments repeat the ones made before the
        # realities were created and have no further effect here.
        train_dataset = val_dataset
        train_step, val_step = 2,2
        realities = [train, val]
    else:
        test_set = load_test(energy_loss.get_tasks("test"), buildings=['almena', 'albertville'])
        test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))
        realities = [train, val, test]
        # If you wanted to just do some qualitative predictions on inputs w/o labels, you could do:
        # ood_set = load_ood(energy_loss.get_tasks("ood"))
        # ood = RealityTask.from_static("ood", ood_set, [tasks.rgb,])
        # realities.append(ood)

    # GRAPH
    graph = TaskGraph(tasks=energy_loss.tasks + realities, pretrained=True, finetuned=False,
        freeze_list=energy_loss.freeze_list,
        initialize_from_transfer=False,
    )
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)

    # LOGGING
    os.makedirs(RESULTS_DIR, exist_ok=True)
    logger = VisdomLogger("train", env=JOB)
    # Hooks: one registered on the "loss" feature (freq=20) that steps the
    # logger, one on the "epoch" feature (freq=1) that saves the graph.
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"), feature="epoch", freq=1)
    energy_loss.logger_hooks(logger)
    energy_loss.plot_paths(graph, logger, realities, prefix="start")

    # BASELINE
    # Evaluate the untrained graph (no gradients) on 4x the usual number of
    # validation and training steps.
    graph.eval()
    with torch.no_grad():
        for _ in range(0, val_step*4):
            val_loss, _ = energy_loss(graph, realities=[val])
            # The energy loss returns a dict of named terms; sum them up.
            val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        for _ in range(0, train_step*4):
            train_loss, _ = energy_loss(graph, realities=[train])
            train_loss = sum([train_loss[loss_name] for loss_name in train_loss])
            train.step()
            logger.update("loss", train_loss)
    energy_loss.logger_update(logger)

    # TRAINING
    for epochs in range(0, max_epochs):

        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph, logger, realities, prefix="")
        if visualize: return

        graph.train()
        for _ in range(0, train_step):
            # mse_coeff is returned by the loss but not used here.
            train_loss, mse_coeff = energy_loss(graph, realities=[train], compute_grad_ratio=True)
            train_loss = sum([train_loss[loss_name] for loss_name in train_loss])
            graph.step(train_loss)
            train.step()
            logger.update("loss", train_loss)

        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss, _ = energy_loss(graph, realities=[val])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        energy_loss.logger_update(logger)

        logger.step()
import checkpoint as util_checkpoint from torchvision import models from utils import * from models import TrainableModel, DataParallelModel from task_configs import get_task, task_map, get_model, Task, RealityTask from modules.percep_nets import DenseNet, Dense1by1Net, DenseKernelsNet, DeepNet, BaseNet, WideNet, PyramidNet from modules.depth_nets import UNetDepth from modules.unet import UNet, UNetOld, UNetOld2, UNetReshade from modules.resnet import ResNetClass from fire import Fire import IPython pretrained_transfers = { ('normal', 'principal_curvature'): (lambda: Dense1by1Net(), f"{MODELS_DIR}/normal2curvature.pth"), ('normal', 'depth_zbuffer'): (lambda: UNetDepth(), f"{MODELS_DIR}/normal2zdepth_zbuffer.pth"), ('normal', 'sobel_edges'): (lambda: UNet(out_channels=1, downsample=4).cuda(), f"{MODELS_DIR}/normal2edges2d.pth"), ('normal', 'reshading'): (lambda: UNetReshade(downsample=5), f"{MODELS_DIR}/normal2reshade.pth"), ('normal', 'keypoints3d'): (lambda: UNet(downsample=5, out_channels=1), f"{MODELS_DIR}/normal2keypoints3d.pth"), ('normal', 'keypoints2d'): (lambda: UNet(downsample=5, out_channels=1), f"{MODELS_DIR}/normal2keypoints2d_new.pth"), ('normal', 'edge_occlusion'): (lambda: UNet(downsample=5, out_channels=1), f"{MODELS_DIR}/normal2edge_occlusion.pth"), ('depth_zbuffer', 'normal'): (lambda: UNet(in_channels=1, downsample=6), f"{MODELS_DIR}/depth2normal.pth"), ('depth_zbuffer', 'sobel_edges'): (lambda: UNet(downsample=4, in_channels=1, out_channels=1).cuda(), f"{MODELS_DIR}/depth_zbuffer2sobel_edges.pth"), ('depth_zbuffer', 'principal_curvature'): (lambda: UNet(downsample=4, in_channels=1), f"{MODELS_DIR}/depth_zbuffer2principal_curvature.pth"), ('depth_zbuffer', 'reshading'): (lambda: UNetReshade(downsample=5, in_channels=1), f"{MODELS_DIR}/depth_zbuffer2reshading.pth"), ('depth_zbuffer', 'keypoints3d'): (lambda: UNet(downsample=5, in_channels=1, out_channels=1), f"{MODELS_DIR}/depth_zbuffer2keypoints3d.pth"), ('depth_zbuffer', 'keypoints2d'): (lambda: 
UNet(downsample=5, in_channels=1, out_channels=1), f"{MODELS_DIR}/depth_zbuffer2keypoints2d.pth"), ('depth_zbuffer', 'edge_occlusion'): (lambda: UNet(downsample=5, in_channels=1, out_channels=1), f"{MODELS_DIR}/depth_zbuffer2edge_occlusion.pth"), ('reshading', 'depth_zbuffer'): (lambda: UNetReshade(downsample=5, out_channels=1), f"{MODELS_DIR}/reshading2depth_zbuffer.pth"), ('reshading', 'keypoints2d'): (lambda: UNet(downsample=5, out_channels=1), f"{MODELS_DIR}/reshading2keypoints2d_new.pth"), ('reshading', 'edge_occlusion'): (lambda: UNet(downsample=5, out_channels=1), f"{MODELS_DIR}/reshading2edge_occlusion.pth"), ('reshading', 'normal'): (lambda: UNet(downsample=4), f"{MODELS_DIR}/reshading2normal.pth"), ('reshading', 'keypoints3d'): (lambda: UNet(downsample=5, out_channels=1), f"{MODELS_DIR}/reshading2keypoints3d.pth"), ('reshading', 'sobel_edges'): (lambda: UNet(downsample=5, out_channels=1), f"{MODELS_DIR}/reshading2sobel_edges.pth"), ('reshading', 'principal_curvature'): (lambda: UNet(downsample=5), f"{MODELS_DIR}/reshading2principal_curvature.pth"), ('rgb', 'sobel_edges'): (lambda: SobelKernel(), None), ('rgb', 'principal_curvature'): (lambda: UNet(downsample=5), f"{MODELS_DIR}/rgb2principal_curvature.pth"), ('rgb', 'keypoints2d'): (lambda: UNet(downsample=3, out_channels=1), f"{MODELS_DIR}/rgb2keypoints2d_new.pth"), ('rgb', 'keypoints3d'): (lambda: UNet(downsample=5, out_channels=1), f"{MODELS_DIR}/rgb2keypoints3d.pth"), ('rgb', 'edge_occlusion'): (lambda: UNet(downsample=5, out_channels=1), f"{MODELS_DIR}/rgb2edge_occlusion.pth"), ('rgb', 'normal'): (lambda: UNet(), f"{MODELS_DIR}/rgb2normal_baseline.pth"), ('rgb', 'reshading'): (lambda: UNetReshade(downsample=5), f"{MODELS_DIR}/rgb2reshading_baseline.pth"), ('rgb', 'depth_zbuffer'): (lambda: UNet(downsample=6, out_channels=1), f"{MODELS_DIR}/rgb2zdepth_baseline.pth"), ('normal', 'imagenet'): (lambda: ResNetClass().cuda(), None), ('depth_zbuffer', 'imagenet'): (lambda: ResNetClass().cuda(), None), 
('reshading', 'imagenet'): (lambda: ResNetClass().cuda(), None), ('principal_curvature', 'sobel_edges'): (lambda: UNet(downsample=4, out_channels=1), f"{MODELS_DIR}/principal_curvature2sobel_edges.pth"), ('sobel_edges', 'depth_zbuffer'): (lambda: UNet(downsample=6, in_channels=1, out_channels=1), f"{MODELS_DIR}/sobel_edges2depth_zbuffer.pth"), ('depth_zbuffer', 'normal'): (lambda: UNet(in_channels=1, downsample=6), f"{MODELS_DIR}/depth2normal.pth"), ('keypoints2d', 'normal'): (lambda: UNet(downsample=5, in_channels=1), f"{MODELS_DIR}/keypoints2d2normal_new.pth"), ('keypoints3d', 'normal'): (lambda: UNet(downsample=5, in_channels=1), f"{MODELS_DIR}/keypoints3d2normal.pth"), ('principal_curvature', 'normal'): (lambda: UNetOld2(), f"{MODELS_DIR}/principal_curvature2normal.pth"), ('sobel_edges', 'normal'): (lambda: UNet(in_channels=1, downsample=5).cuda(), f"{MODELS_DIR}/sobel_edges2normal.pth"), ('edge_occlusion', 'normal'): (lambda: UNet(in_channels=1, downsample=5), f"{MODELS_DIR}/edge_occlusion2normal.pth"), } class Transfer(nn.Module): def __init__(self, src_task, dest_task, checkpoint=True, name=None, model_type=None, path=None, pretrained=True, finetuned=False ): super().__init__() if isinstance(src_task, str) and isinstance(dest_task, str): src_task, dest_task = get_task(src_task), get_task(dest_task) self.src_task, self.dest_task, self.checkpoint = src_task, dest_task, checkpoint self.name = name or f"{src_task.name}2{dest_task.name}" saved_type, saved_path = None, None if model_type is None and path is None: saved_type, saved_path = pretrained_transfers.get((src_task.name, dest_task.name), (None, None)) self.model_type, self.path = model_type or saved_type, path or saved_path self.model = None if finetuned: path = f"{MODELS_DIR}/ft_perceptual/{src_task.name}2{dest_task.name}.pth" if os.path.exists(path): self.model_type, self.path = saved_type or (lambda: get_model(src_task, dest_task)), path print ("Using finetuned: ", path) return if self.model_type is 
None: if src_task.kind == dest_task.kind and src_task.resize != dest_task.resize: class Module(TrainableModel): def __init__(self): super().__init__() def forward(self, x): return resize(x, val=dest_task.resize) self.model_type = lambda: Module() self.path = None path = f"{MODELS_DIR}/{src_task.name}2{dest_task.name}.pth" if src_task.name == "keypoints2d" or dest_task.name == "keypoints2d": path = f"{MODELS_DIR}/{src_task.name}2{dest_task.name}_new.pth" if os.path.exists(path): self.model_type, self.path = lambda: get_model(src_task, dest_task), path if not pretrained: print ("Not using pretrained [heavily discouraged]") self.path = None def load_model(self): if self.model is None: if self.path is not None: self.model = DataParallelModel.load(self.model_type().to(DEVICE), self.path) # if optimizer: # self.model.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True) else: self.model = self.model_type().to(DEVICE) if isinstance(self.model, nn.Module): self.model = DataParallelModel(self.model) return self.model def __call__(self, x): self.load_model() preds = util_checkpoint(self.model, x) if self.checkpoint else self.model(x) preds.task = self.dest_task return preds def __repr__(self): return self.name or str(self.src_task) + " -> " + str(self.dest_task) class RealityTransfer(Transfer): def __init__(self, src_task, dest_task): super().__init__(src_task, dest_task, model_type=lambda: None) def load_model(self, optimizer=True): pass def __call__(self, x): assert (isinstance(self.src_task, RealityTask)) return self.src_task.task_data[self.dest_task].to(DEVICE) class FineTunedTransfer(Transfer): def __init__(self, transfer): super().__init__(transfer.src_task, transfer.dest_task, checkpoint=transfer.checkpoint, name=transfer.name) self.cached_models = {} def load_model(self, parents=[]): model_path = get_finetuned_model_path(parents + [self]) if model_path not in self.cached_models: if not os.path.exists(model_path): print(f"{model_path} not found, loading 
pretrained") self.cached_models[model_path] = super().load_model() else: print(f"{model_path} found, loading finetuned") self.cached_models[model_path] = DataParallelModel.load(self.model_type().cuda(), model_path) print(f"") self.model = self.cached_models[model_path] return self.model def __call__(self, x): if not hasattr(x, "parents"): x.parents = [] self.load_model(parents=x.parents) preds = util_checkpoint(self.model, x) if self.checkpoint else self.model(x) preds.parents = x.parents + ([self]) return preds functional_transfers = ( Transfer('normal', 'principal_curvature', name='f'), Transfer('principal_curvature', 'normal', name='F'), Transfer('normal', 'depth_zbuffer', name='g'), Transfer('depth_zbuffer', 'normal', name='G'), Transfer('normal', 'sobel_edges', name='s'), Transfer('sobel_edges', 'normal', name='S'), Transfer('principal_curvature', 'sobel_edges', name='CE'), Transfer('sobel_edges', 'principal_curvature', name='EC'), Transfer('depth_zbuffer', 'sobel_edges', name='DE'), Transfer('sobel_edges', 'depth_zbuffer', name='ED'), Transfer('principal_curvature', 'depth_zbuffer', name='h'), Transfer('depth_zbuffer', 'principal_curvature', name='H'), Transfer('rgb', 'normal', name='n'), Transfer('rgb', 'normal', name='npstep', model_type=lambda: UNetOld(), path=f"{MODELS_DIR}/unet_percepstep_0.1.pth", ), Transfer('rgb', 'principal_curvature', name='RC'), Transfer('rgb', 'keypoints2d', name='k'), Transfer('rgb', 'sobel_edges', name='a'), Transfer('rgb', 'reshading', name='r'), Transfer('rgb', 'depth_zbuffer', name='d'), Transfer('keypoints2d', 'principal_curvature', name='KC'), Transfer('keypoints3d', 'principal_curvature', name='k3C'), Transfer('principal_curvature', 'keypoints3d', name='Ck3'), Transfer('normal', 'reshading', name='nr'), Transfer('reshading', 'normal', name='rn'), Transfer('keypoints3d', 'normal', name='k3N'), Transfer('normal', 'keypoints3d', name='Nk3'), Transfer('keypoints2d', 'normal', name='k2N'), Transfer('normal', 'keypoints2d', 
name='Nk2'), Transfer('sobel_edges', 'reshading', name='Er'), ) finetuned_transfers = [FineTunedTransfer(transfer) for transfer in functional_transfers] TRANSFER_MAP = {t.name:t for t in functional_transfers} functional_transfers = namedtuple('functional_transfers', TRANSFER_MAP.keys())(**TRANSFER_MAP) def get_transfer_name(transfer): for t in functional_transfers: if transfer.src_task == t.src_task and transfer.dest_task == t.dest_task: return t.name return transfer.name (f, F, g, G, s, S, CE, EC, DE, ED, h, H, n, npstep, RC, k, a, r, d, KC, k3C, Ck3, nr, rn, k3N, Nk3, Er, k2N, N2k) = functional_transfers if __name__ == "__main__": y = g(F(f(x))) print (y.shape) ================================================ FILE: utils.py ================================================ import numpy as np import random, sys, os, time, glob, math, itertools, pickle import parse from collections import defaultdict import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torch.autograd import Variable from functools import partial from scipy import ndimage import IPython DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") USE_CUDA = torch.cuda.is_available() dtype = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor EXPERIMENT, BASE_DIR = open("config/jobinfo.txt").read().strip().split(', ') JOB = "_".join(EXPERIMENT.split("_")[0:-1]) MODELS_DIR = f"{BASE_DIR}/models" DATA_DIRS = [f"/taskonomy-data/taskonomydata", 'data'] RESULTS_DIR = f"{BASE_DIR}/results/results_{EXPERIMENT}" SHARED_DIR = f"{BASE_DIR}/shared" OOD_DIR = f"{SHARED_DIR}/ood_standard_set" USE_RAID = False # os.system(f"mkdir -p {RESULTS_DIR}") def both(x, y): x = dict(x.items()) x.update(y) return x def elapsed(last_time=[time.time()]): """ Returns the time passed since elapsed() was last called. 
""" current_time = time.time() diff = current_time - last_time[0] last_time[0] = current_time return diff def cycle(iterable): """ Cycles through iterable without making extra copies. """ while True: for i in iterable: yield i def average(arr): return sum(arr) / len(arr) # def random_resize(iterable, vals=[128, 192, 256, 320]): # """ Cycles through iterable while randomly resizing batch values. """ # from transforms import resize # while True: # for X, Y in iterable: # val = random.choice(vals) # yield resize(X.to(DEVICE), val=val).detach(), resize(Y.to(DEVICE), val=val).detach() def get_files(exp, data_dirs=DATA_DIRS, recursive=False): """ Gets data files across mounted directories matching glob expression pattern. """ # cache = SHARED_DIR + "/filecache_" + "_".join(exp.split()).replace(".", "_").replace("/", "_").replace("*", "_") + ("r" if recursive else "f") + ".pkl" # print ("Cache file: ", cache) # if os.path.exists(cache): # return pickle.load(open(cache, 'rb')) files, seen = [], set() for data_dir in data_dirs: for file in glob.glob(f'{data_dir}/{exp}', recursive=recursive): if file[len(data_dir):] not in seen: files.append(file) seen.add(file[len(data_dir):]) # pickle.dump(files, open(cache, 'wb')) return files def get_finetuned_model_path(parents): if BASE_DIR == "/": return f"{RESULTS_DIR}/" + "_".join([parent.name for parent in parents[::-1]]) + ".pth" else: return f"{MODELS_DIR}/finetuned/" + "_".join([parent.name for parent in parents[::-1]]) + ".pth" def plot_images(model, logger, test_set, dest_task="normal", ood_images=None, show_masks=False, loss_models={}, preds_name=None, target_name=None, ood_name=None, ): from task_configs import get_task, ImageTask test_images, preds, targets, losses, _ = model.predict_with_data(test_set) if isinstance(dest_task, str): dest_task = get_task(dest_task) if show_masks and isinstance(dest_task, ImageTask): test_masks = ImageTask.build_mask(targets, dest_task.mask_val, tol=1e-3) logger.images(test_masks.float(), 
f"{dest_task}_masks", resize=64) dest_task.plot_func(preds, preds_name or f"{dest_task.name}_preds", logger) dest_task.plot_func(targets, target_name or f"{dest_task.name}_target", logger) if ood_images is not None: ood_preds = model.predict(ood_images) dest_task.plot_func(ood_preds, ood_name or f"{dest_task.name}_ood_preds", logger) for name, loss_model in loss_models.items(): with torch.no_grad(): output = loss_model(preds, targets, test_images) if hasattr(output, "task"): output.task.plot_func(output, name, logger, resize=128) else: logger.images(output.clamp(min=0, max=1), name, resize=128) def gaussian_filter(channels=3, kernel_size=5, sigma=1.0, device=0): x_cord = torch.arange(kernel_size).float() x_grid = x_cord.repeat(kernel_size).view(kernel_size, kernel_size) y_grid = x_grid.t() xy_grid = torch.stack([x_grid, y_grid], dim=-1) mean = (kernel_size - 1) / 2. variance = sigma ** 2. gaussian_kernel = (1. / (2. * math.pi * variance)) * torch.exp( -torch.sum((xy_grid - mean) ** 2., dim=-1) / (2 * variance) ) gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel) gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size) gaussian_kernel = gaussian_kernel.repeat(channels, 1, 1, 1) return gaussian_kernel def motion_blur_filter(kernel_size=15): channels = 3 kernel_motion_blur = torch.zeros((kernel_size, kernel_size)) kernel_motion_blur[int((kernel_size - 1) / 2), :] = torch.ones(kernel_size) kernel_motion_blur = kernel_motion_blur / kernel_size kernel_motion_blur = kernel_motion_blur.view(1, 1, kernel_size, kernel_size) kernel_motion_blur = kernel_motion_blur.repeat(channels, 1, 1, 1) return kernel_motion_blur def sobel_kernel(x): def sobel_transform(x): image = x.data.cpu().numpy().mean(axis=0) blur = ndimage.filters.gaussian_filter(image, sigma=2, ) sx = ndimage.sobel(blur, axis=0, mode='constant') sy = ndimage.sobel(blur, axis=1, mode='constant') sob = np.hypot(sx, sy) edge = torch.FloatTensor(sob).unsqueeze(0) return edge x = 
torch.stack([sobel_transform(y) for y in x], dim=0) return x.to(DEVICE).requires_grad_() class SobelKernel(nn.Module): def __init__(self): super().__init__() def forward(self, x): return sobel_kernel(x) def set_seed(seed): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) # cpu vars torch.cuda.manual_seed_all(seed) # gpu vars