Repository: silviazuffi/smalst
Branch: master
Commit: f7871d6e5333
Files: 50
Total size: 37.6 MB

Directory structure:
gitextract_38wvuzcu/

├── LICENSE.txt
├── LICENSE_SMAL_MODEL.txt
├── README.md
├── __init__.py
├── _config.yml
├── data/
│   ├── __init__.py
│   ├── smal_base.py
│   └── zebra.py
├── docs/
│   └── tmp.txt
├── experiments/
│   ├── __init__.py
│   └── smal_shape.py
├── external/
│   ├── __init__.py
│   └── install_external.sh
├── nnutils/
│   ├── __init__.py
│   ├── geom_utils.py
│   ├── loss_utils.py
│   ├── net_blocks.py
│   ├── nmr.py
│   ├── perceptual_loss.py
│   ├── smal_mesh_eval.py
│   ├── smal_mesh_net.py
│   ├── smal_predictor.py
│   └── train_utils.py
├── requirements.txt
├── scripts/
│   ├── smalst_evaluation_run.sh
│   ├── smalst_op_run.sh
│   └── smalst_train_run.sh
├── smal_eval.py
├── smal_model/
│   ├── __init__.py
│   ├── batch_lbs.py
│   ├── smal_basics.py
│   ├── smal_torch.py
│   ├── symIdx.pkl
│   ├── template_w_tex_uv.mtl
│   └── template_w_tex_uv.obj
├── smpl_models/
│   ├── my_smpl_00781_4_all.pkl
│   ├── my_smpl_00781_4_all_template_w_tex_uv_001.pkl
│   ├── my_smpl_data_00781_4_all.pkl
│   └── symIdx.pkl
├── testset_shape_experiments_crops.py
├── utils/
│   ├── __init__.py
│   ├── geometry.py
│   ├── image.py
│   ├── mesh.py
│   ├── obj2nmr.py
│   ├── smal_vis.py
│   ├── transformations.py
│   ├── visualizer.py
│   └── visutil.py
└── zebra_data/
    └── verts2kp.pkl

================================================
FILE CONTENTS
================================================

================================================
FILE: LICENSE.txt
================================================
MIT License

Copyright (c) 2018 silviazuffi

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: LICENSE_SMAL_MODEL.txt
================================================
License
Software Copyright License for non-commercial scientific research purposes
Please read carefully the following terms and conditions and any accompanying documentation before you download and/or use the SMAL data and software, (the "Data & Software"), including 3D meshes, images, videos, textures, software, scripts, and animations. By downloading and/or using the Data & Software, you acknowledge that you have read these terms and conditions, understand them, and agree to be bound by them. If you do not agree with these terms and conditions, you must not download and/or use the Data & Software. Any infringement of the terms of this agreement will automatically terminate your rights under this License

Ownership / Licensees
The Software and the associated materials has been developed at the 

Max Planck Institute for Intelligent Systems (hereinafter "MPI").

Any copyright or patent right is owned by and proprietary material of the

Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (hereinafter “MPG”; MPI and MPG hereinafter collectively “Max-Planck”)

hereinafter the “Licensor”.

License Grant
Licensor grants you (Licensee) personally a single-user, non-exclusive, non-transferable, free of charge right:

To install the Data & Software on computers owned, leased or otherwise controlled by you and/or your organization;
To use the Data & Software for the sole purpose of performing non-commercial scientific research, non-commercial education, or non-commercial artistic projects;
Any other use, in particular any use for commercial purposes, is prohibited. This includes, without limitation, incorporation in a commercial product, use in a commercial service, or production of other artifacts for commercial purposes. The Data & Software may not be reproduced, modified and/or made available in any form to any third party without Max-Planck’s prior written permission. 

This license also prohibits the use of the Data & Software to train methods/algorithms/neural networks/etc for commercial use of any kind. By downloading the Data & Software, you agree not to reverse engineer it.

No Distribution
The Data & Software and the license herein granted shall not be copied, shared, distributed, re-sold, offered for re-sale, transferred or sub-licensed in whole or in part except that you may make one copy for archive purposes only.

Disclaimer of Representations and Warranties
You expressly acknowledge and agree that the Data & Software results from basic research, is provided “AS IS”, may contain errors, and that any use of the Data & Software is at your sole risk. LICENSOR MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE DATA & SOFTWARE, NEITHER EXPRESS NOR IMPLIED, AND THE ABSENCE OF ANY LEGAL OR ACTUAL DEFECTS, WHETHER DISCOVERABLE OR NOT. Specifically, and not to limit the foregoing, licensor makes no representations or warranties (i) regarding the merchantability or fitness for a particular purpose of the Data & Software, (ii) that the use of the Data & Software will not infringe any patents, copyrights or other intellectual property rights of a third party, and (iii) that the use of the Data & Software will not cause any damage of any kind to you or a third party.

Limitation of Liability
Because this Data & Software License Agreement qualifies as a donation, according to Section 521 of the German Civil Code (Bürgerliches Gesetzbuch – BGB) Licensor as a donor is liable for intent and gross negligence only. If the Licensor fraudulently conceals a legal or material defect, they are obliged to compensate the Licensee for the resulting damage.
Licensor shall be liable for loss of data only up to the amount of typical recovery costs which would have arisen had proper and regular data backup measures been taken. For the avoidance of doubt Licensor shall be liable in accordance with the German Product Liability Act in the event of product liability. The foregoing applies also to Licensor’s legal representatives or assistants in performance. Any further liability shall be excluded.
Patent claims generated through the usage of the Data & Software cannot be directed towards the copyright holders.
The Data & Software is provided in the state of development the licensor defines. If modified or extended by Licensee, the Licensor makes no claims about the fitness of the Data & Software and is not responsible for any problems such modifications cause.

No Maintenance Services
You understand and agree that Licensor is under no obligation to provide either maintenance services, update services, notices of latent defects, or corrections of defects with regard to the Data & Software. Licensor nevertheless reserves the right to update, modify, or discontinue the Data & Software at any time.

Defects of the Data & Software must be notified in writing to the Licensor with a comprehensible description of the error symptoms. The notification of the defect should enable the reproduction of the error. The Licensee is encouraged to communicate any use, results, modification or publication.

Publications using the Data & Software
You acknowledge that the Data & Software is a valuable scientific resource and agree to appropriately reference the following paper in any publication making use of the Data & Software.

Citation:
@inproceedings{Zuffi:CVPR:2017,
  title = {{3D} Menagerie: Modeling the {3D} Shape and Pose of Animals},
  author = {Zuffi, Silvia and Kanazawa, Angjoo and Jacobs, David and Black, Michael J.},
  booktitle = {Proceedings IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2017},
  pages = {5524--5532},
  publisher = {IEEE},
  address = {Piscataway, NJ, USA},
  month = jul,
  year = {2017},
  month_numeric = {7}
}

Commercial licensing opportunities
For commercial uses of the Data & Software, please send email to ps-license@tue.mpg.de

This Agreement shall be governed by the laws of the Federal Republic of Germany except for the UN Sales Convention.


================================================
FILE: README.md
================================================
# Three-D Safari: Learning to Estimate Zebra Pose, Shape, and Texture from Images "In the Wild"

Silvia Zuffi<sup>1</sup>, Angjoo Kanazawa<sup>2</sup>, Tanya Berger-Wolf<sup>3</sup>, Michael J. Black<sup>4</sup>

<sup>1</sup>IMATI-CNR, Milan, Italy, <sup>2</sup>University of California, Berkeley,
<sup>3</sup>University of Illinois at Chicago, <sup>4</sup>Max Planck Institute for Intelligent Systems, Tuebingen, Germany

In ICCV 2019

![alt text](https://github.com/silviazuffi/smalst/blob/master/docs/teaser4.jpg)

<p align="center">
  <img src="https://github.com/silviazuffi/smalst/blob/master/docs/zebra_video.gif">
</p>


[paper](https://ps.is.tuebingen.mpg.de/uploads_file/attachment/attachment/533/6034_after_pdfexpress.pdf)

[suppmat](https://ps.is.tuebingen.mpg.de/uploads_file/attachment/attachment/535/6034_supp.pdf)


### Requirements
- Python 2.7
- [PyTorch](https://pytorch.org/) tested on version `0.5.0`

### Installation

Note that the following security warning has been issued for Pillow:
"Pillow before 8.1.1 allows attackers to cause a denial of service (memory consumption) because the reported size of a contained image is not properly checked for an ICO container, and thus an attempted memory allocation can be very large."

#### Setup virtualenv
```
virtualenv venv_smalst
source venv_smalst/bin/activate
pip install -U pip
deactivate
source venv_smalst/bin/activate
pip install -r requirements.txt
```

#### Install Neural Mesh Renderer and Perceptual loss
```
cd external;
bash install_external.sh
```
#### Install SMPL model
Download the [SMPL model](https://ps.is.tuebingen.mpg.de/code/smpl/) and create a directory `smpl_webuser` under the `smalst/smal_model` directory.

#### Download data
- [Trained network](https://drive.google.com/a/berkeley.edu/file/d/1ZkKmqlbs3LlcGTrMK1j0ZVBpddg9b6Jf/view?usp=drivesdk)

- [Training data](https://drive.google.com/open?id=1yVy4--M4CNfE5x9wUr1QBmAXEcWb6PWF)

- [Test data](https://drive.google.com/a/berkeley.edu/file/d/1g5jZeA2ptAgdKVOAbZoVqsU-dNE-HD-e/view?usp=drivesdk)

- [Validation data](https://drive.google.com/a/berkeley.edu/file/d/1Ae0J83Y7Un1zBYFVd2za94d1KNnks8IL/view?usp=drivesdk)

The test and validation data are images collected during [The Great Grevy's Rally 2018](https://www.marwell.org.uk/media/other/cs_report_ggr_2018v.4.pdf).

Place the downloaded network `pred_net_186.pth` in the folder `cachedir/snapshots/smal_net_600/`.
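
A minimal Python sketch of this step, assuming the `cachedir/snapshots/smal_net_600/` path is relative to the repository root and that you pass in wherever the download landed. `install_checkpoint` is a hypothetical helper, not part of the repository; it avoids `exist_ok` so it also runs under the Python 2.7 listed in the requirements.

```python
import os
import shutil


def install_checkpoint(downloaded_path, repo_root='.'):
    # Hypothetical helper: copy the downloaded network to
    # <repo_root>/cachedir/snapshots/smal_net_600/pred_net_186.pth
    target_dir = os.path.join(repo_root, 'cachedir', 'snapshots', 'smal_net_600')
    if not os.path.isdir(target_dir):
        os.makedirs(target_dir)
    target = os.path.join(target_dir, 'pred_net_186.pth')
    shutil.copy(downloaded_path, target)
    return target
```

For example, run `install_checkpoint('/path/to/pred_net_186.pth')` from the repository root.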


#### Usage

See the scripts in the `smalst/scripts` directory for training and testing.

#### Notes
The code in this repository is largely based on the project https://github.com/akanazawa/cmr

#### Citation

If you use this code, please cite:
```
@inproceedings{Zuffi:ICCV:2019,
  title = {Three-D Safari: Learning to Estimate Zebra Pose, Shape, and Texture from Images "In the Wild"},
  author = {Zuffi, Silvia and Kanazawa, Angjoo and Berger-Wolf, Tanya and Black, Michael J.},
  booktitle = {International Conference on Computer Vision},
  month = oct,
  year = {2019},
  month_numeric = {10}
}
```





================================================
FILE: __init__.py
================================================


================================================
FILE: _config.yml
================================================
theme: jekyll-theme-slate

================================================
FILE: data/__init__.py
================================================


================================================
FILE: data/smal_base.py
================================================
"""
Base data loading class.

Should output:
    - img: B X 3 X H X W
    - kp: B X nKp X 2
    - mask: B X H X W
    # Silvia - sfm_pose: B X 7 (s, tr, q)
    - camera_params: B X 4 (s, tr)
    (kp, sfm_pose) correspond to image coordinates in [-1, 1]
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os.path as osp
import numpy as np

import scipy.misc
import scipy.linalg
import scipy.ndimage.interpolation
from absl import flags, app

import pickle as pkl

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate

from ..utils import image as image_utils
from ..utils import transformations
from ..nnutils.geom_utils import perspective_proj_withz


flags.DEFINE_integer('bgval', 1, 'color for padding input image')
flags.DEFINE_integer('border', 0, 'padding input image')
flags.DEFINE_integer('img_size', 256, 'image size')
flags.DEFINE_boolean('use_bbox', True, 'If doing the cropping based on bboxes')
flags.DEFINE_boolean('perturb_bbox', True, '')
flags.DEFINE_boolean('online_training', False, 'If to change dataset')
flags.DEFINE_boolean('save_training', False, 'Save the cropped images')
flags.DEFINE_boolean('update_vis', True, 'Whether to update keypoint visibility during keypoint normalization')

flags.DEFINE_float('padding_frac', 0.1, #0.05,
                   'bbox is increased by this fraction of max_dim')

flags.DEFINE_float('jitter_frac', 0.1, #0.05,
                   'bbox is jittered by this fraction of max_dim')

flags.DEFINE_enum('split', 'train', ['train', 'val', 'all', 'test'], 'eval split')
flags.DEFINE_integer('num_kps', 28, 'The dataloader should override these.')
flags.DEFINE_integer('n_data_workers', 4, 'Number of data loading workers')
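
# Illustrative sketch (hypothetical helper, not called by the loader): the
# keypoint normalization that BaseDataset.normalize_kp performs, written with
# plain NumPy. A pixel coordinate x in an image of width img_w maps to
# 2 * x / img_w - 1, so the image spans [-1, 1]. With update_vis=True,
# keypoints that land outside [-1, 1] get visibility 0, and keypoints that
# were invisible on input are zeroed entirely.
import numpy as np


def _example_normalize_kp(kp, img_h, img_w, update_vis=False):
    # kp: N x 3 array of (x, y, visibility) in pixel coordinates
    vis = kp[:, 2, None] > 0
    new_kp = np.stack([2 * (kp[:, 0] / float(img_w)) - 1,
                       2 * (kp[:, 1] / float(img_h)) - 1,
                       kp[:, 2]]).T
    if update_vis:
        # Mark keypoints that fall outside the normalized image as invisible
        outside = (np.abs(new_kp[:, 0]) > 1) | (np.abs(new_kp[:, 1]) > 1)
        new_kp[outside, 2] = 0
    # Multiplying by the input visibility mask zeroes invisible keypoints
    return vis * new_kp

# Example: a 256 x 256 crop with one keypoint inside, one outside the crop,
# and one marked invisible:
# _example_normalize_kp(np.array([[128., 64., 1.], [300., 10., 1.], [50., 50., 0.]]), 256, 256, True)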


# -------------- Dataset ------------- #
# ------------------------------------ #
class BaseDataset(Dataset):
    ''' 
    img, mask, kp, pose, texture_map data loader
    '''

    def __init__(self, opts, filter_key=None):
        # Child class should define/load:
        # self.kp_perm
        # self.img_dir
        # self.anno
        # self.anno_camera
        self.opts = opts
        self.img_size = opts.img_size
        self.jitter_frac = opts.jitter_frac
        self.padding_frac = opts.padding_frac
        self.filter_key = filter_key

        if not opts.use_smal_betas:
            # We need to load the blendshapes
            model_path = osp.join(self.opts.model_dir, self.opts.model_name)
            with open(model_path, 'rb') as f:
                dd = pkl.load(f)
                num_betas = dd['shapedirs'].shape[-1]
                self.shapedirs = np.reshape(dd['shapedirs'], [-1, num_betas]).T

    def forward_img(self, index):
        if True: 
            data = self.anno[index].copy()
            data_sfm = self.anno_camera[index].copy()

            img_path = data['img_path']
            if 'img' in data.keys():
                img = data['img']
            else:
                img = scipy.misc.imread(img_path) / 255.0

            camera_params = [np.copy(data_sfm['flength']), np.zeros(2)]

            if 'texture_map' in data.keys():
                texture_map_path = data['texture_map']
                if 'texture_map_data' in data.keys():
                    texture_map = data['texture_map_data']
                else:
                    texture_map = scipy.misc.imread(texture_map_path) / 255.0
                    texture_map = np.transpose(texture_map, (2, 0, 1))
            else:
                texture_map = None

            if data['mask_path'] is not None:
                mask_path = data['mask_path']
                if 'mask' in data.keys():
                    mask = data['mask']
                else:
                    mask = scipy.misc.imread(mask_path) / 255.0
            else:
                mask = None

            if 'uv_flow_path' in data.keys():
                uv_flow_path = data['uv_flow_path']
                if 'uv_flow' in data.keys():
                    uv_flow = data['uv_flow']
                else:
                    uvdata = pkl.load(open(uv_flow_path, 'rb'))
                    uv_flow = uvdata['uv_flow'].astype(np.float32)
                    uv_flow[:,:,0] = uv_flow[:,:,0] /(uvdata['img_h']/2.)
                    uv_flow[:,:,1] = uv_flow[:,:,1] /(uvdata['img_w']/2.)
            else:
                uv_flow = None

            occ_map = None


            kp = data['keypoints']
            if 'trans' in data.keys():
                model_trans = data['trans'].copy()
            else:
                model_trans = None
            if 'pose' in data.keys():
                model_pose = data['pose'].copy()
            else:
                model_pose = None
            if 'betas' in data.keys():
                model_betas = data['betas'].copy()
            else:
                model_betas = None
            if 'delta_v' in data.keys():
                model_delta_v = data['delta_v'].copy()
                if not self.opts.use_smal_betas:
                    # Modify the deformation to include B*\betas
                    nBetas = len(model_betas)
                    model_delta_v = model_delta_v + np.reshape(np.matmul(model_betas, self.shapedirs[:nBetas,:]), [model_delta_v.shape[0], model_delta_v.shape[1]])
            else:
                model_delta_v = None
        
        # Perspective camera needs image center 
        camera_params[1][0] = img.shape[1]/2.
        camera_params[1][1] = img.shape[0]/2.

        if mask is not None:
            M = mask[:,:,0]
            xmin = np.min(np.where(M>0)[1])
            ymin = np.min(np.where(M>0)[0])
            xmax = np.max(np.where(M>0)[1])
            ymax = np.max(np.where(M>0)[0])
        else:
            xmin = 0
            ymin = 0
            xmax = img.shape[1]
            ymax = img.shape[0]

        # Compute bbox
        bbox = np.array([xmin, ymin, xmax, ymax], float)

        if self.opts.border > 0:
            assert 'trans' not in data
            assert 'pose' not in data
            assert 'kp' not in data
            # This has to be used only if there are no annotations for the refinement!
            scale_factor = float(self.opts.img_size-2*self.opts.border) / np.max(img.shape[:2])
            img, _ = image_utils.resize_img(img, scale_factor)
            if mask is not None:
                mask, _ = image_utils.resize_img(mask, scale_factor)
            # Crop img_size x img_size from the center
            center = np.round(np.array(img.shape[:2]) / 2).astype(int)
            # img center in (x, y)
            center = center[::-1]
            bbox = np.hstack([center - self.opts.img_size / 2., center + self.opts.img_size / 2.])


        if kp is not None:
            vis = kp[:, 2] > 0
            kp[vis, :2] -= 1
        else:
            vis = None

        # Perturb bbox
        if self.opts.perturb_bbox and mask is not None:
            bbox = image_utils.peturb_bbox(bbox, pf=self.padding_frac, jf=self.jitter_frac)

        orig_bbox = bbox[:]
        bbox = image_utils.square_bbox(bbox)

        # crop image around bbox, translate kps

        #if self.opts.use_bbox and mask is not None:
        if not self.opts.is_optimization:
            img, mask, kp, camera_params, model_trans, occ_map, uv_flow = self.crop_image(img, 
                mask, bbox, kp, vis, camera_params, model_trans, occ_map, uv_flow)

            # scale image, and mask. And scale kps.        
            img, mask, kp, camera_params, occ_map, uv_flow = self.scale_image(img, mask, kp, vis, camera_params, 
                    occ_map, orig_bbox, uv_flow)

        # Normalize kp to be [-1, 1]
        img_h, img_w = img.shape[:2]
        if kp is not None:
            kp_norm = self.normalize_kp(kp, img_h, img_w, self.opts.update_vis)
        else:
            kp_norm = None

        if not self.opts.use_camera:
            focal_length_fix = self.opts.camera_ref
            f_scale = focal_length_fix/camera_params[0]
            camera_params[0] *= f_scale
            model_trans[2] *= f_scale

        if self.opts.save_training:
            scipy.misc.imsave(img_path+'.crop.png', img)
            scipy.misc.imsave(img_path+'.crop_mask.png', mask)
            data = {'kp':kp, 'sfm_pose':camera_params, 'model_trans':model_trans, 'model_pose':model_pose, 'model_betas':model_betas}
            pkl.dump(data, open(img_path+'.crop.pkl', 'wb'))
            if uv_flow is not None:
                pkl.dump(uv_flow, open(img_path+'._uv_flow_crop.pkl', 'wb'))
            print('saved ' + img_path)

        # Finally transpose the image to 3xHxW
        img = np.transpose(img, (2, 0, 1))
        if mask is not None:
            mask = np.transpose(mask, (2, 0, 1))

        if self.opts.border > 0:
            camera_params[1][0] = img.shape[1]/2.
            camera_params[1][1] = img.shape[0]/2.
        return img, kp_norm, mask, camera_params, texture_map, model_trans, model_pose, model_betas, model_delta_v, occ_map, img_path, uv_flow

    def normalize_kp(self, kp, img_h, img_w, update_vis=False):
        vis = kp[:, 2, None] > 0
        new_kp = np.stack([2 * (kp[:, 0] / img_w) - 1,
                           2 * (kp[:, 1] / img_h) - 1,
                           kp[:, 2]]).T

        if update_vis:
            new_kp[np.where(new_kp[:,0] < -1),2] = 0
            new_kp[np.where(new_kp[:,0] > 1),2] = 0
            new_kp[np.where(new_kp[:,1] < -1),2] = 0
            new_kp[np.where(new_kp[:,1] > 1),2] = 0

        new_kp = vis * new_kp

        return new_kp

    def get_camera_projection_matrix(self, f, c):
        P = np.hstack([np.eye(3), np.zeros((3,1))])
        # Add camera matrix
        K = np.zeros((3, 3))
        K[0, 0] = f
        K[1, 1] = f
        K[2, 2] = 1
        K[0, 2] = c[0]
        K[1, 2] = c[1]
        KP = np.array(np.matrix(K)*np.matrix(P))
        return KP

    def my_project_points(self, ptsw, P):
        # Project world points ptsw(Nx3) into image points ptsi(Nx2) using the camera matrix P(3X4)
        nPts = ptsw.shape[0]
        ptswh = np.ones((nPts, 4))
        ptswh[:, :-1] = ptsw
        ptsih = np.dot(ptswh, P.T)
        ptsi = np.divide(ptsih[:, :-1], ptsih[:, -1][:, np.newaxis])
        return ptsi

    def my_anti_project_points(self, ptsi, P):
        nPts = ptsi.shape[0]
        ptsih = np.ones((nPts, 3))
        ptsih[:, :-1] = ptsi
        ptswh = np.dot(ptsih, np.array(np.matrix(P.T).I))
        nPts = ptswh.shape[0]
        if P[-1,-1] == 0:
            ptsw = ptswh[:, :-1]
        else:
            ptsw = np.divide(ptswh[:, :-1], ptswh[:, -1][:, np.newaxis])
        return ptsw


    def get_model_trans_for_cropped_image(self, trans, bbox, flength, img_w, img_h):
        '''
        trans: 3  model translation 
        bbox: 1 x 4 xmin, ymin, xmax, ymax
        flength: 1
        img_w: 1 width original image
        img_h: 1 height original image
        '''
        # Location of the model in image frame (pixel coo)
        P = self.get_camera_projection_matrix(flength, np.array([img_w/2., img_h/2.]))
        Q = np.zeros((1,3))
        Q[0,:] = trans
        W = self.my_project_points(Q, P)

        # Location of the model w.r.t. the center of the bbox (pixel coo)
        E = np.zeros((1,2))
        E[0,0] = W[0,0] - (bbox[0] + (bbox[2]-bbox[0])/2.)
        E[0,1] = W[0,1] - (bbox[1] + (bbox[3]-bbox[1])/2.)

        # Define the new camera for the bbox
        # Center of the bbox in the bbox frame
        c = np.array([bbox[2]-bbox[0], bbox[3]-bbox[1]])/2.
        P = self.get_camera_projection_matrix(flength, c)
        P[-1,-1] = trans[2]
        # Location of the model in world space w.r.t. the new image and camera
        D = self.my_anti_project_points(E, P)

        trans[:] = np.array([D[0,0], D[0,1], trans[2]])

        return trans


    def crop_image(self, img, mask, bbox, kp, vis, camera_params, model_trans, occ_map, uv_flow):

        img_orig_h, img_orig_w = img.shape[:2]

        # crop image and mask and translate kps
        img = image_utils.crop(img, bbox, bgval=self.opts.bgval)
        if mask is not None:
            mask = image_utils.crop(mask, bbox, bgval=0)

        if occ_map is not None:
            occ_map = image_utils.crop(occ_map, bbox, bgval=0)

        if uv_flow is not None:
            # uv_flow has image coordinates in the first 2 channels and a mask in the third channel
            # image coordinates are normalized w.r.t the original size so their value is in [-1,1]
            # un-normalize uv_flow coordinates
            uv = uv_flow[:,:,:2]
            uv[:,:,0] = uv[:,:,0]*(img_orig_h/2.)+img_orig_h/2.
            uv[:,:,1] = uv[:,:,1]*(img_orig_w/2.)+img_orig_w/2.
            # Change the values
            uv[:,:,0] -= bbox[0]
            uv[:,:,1] -= bbox[1]
            img_h, img_w = img.shape[:2]
            uv_flow[:,:,0] = (uv[:,:,0]-(img_h/2.))/(img_h/2.)
            uv_flow[:,:,1] = (uv[:,:,1]-(img_w/2.))/(img_w/2.)

        if kp is not None:
            kp[vis, 0] -= bbox[0]
            kp[vis, 1] -= bbox[1]
        
        if camera_params[0]>0: 
            model_trans = self.get_model_trans_for_cropped_image(model_trans, bbox, camera_params[0], img_orig_w, img_orig_h)
            camera_params[1][0] = img.shape[1]/2.
            camera_params[1][1] = img.shape[0]/2.
        else:
            raise ValueError('crop_image expects a positive focal length in camera_params[0]')

        return img, mask, kp, camera_params, model_trans, occ_map, uv_flow

    def scale_image(self, img, mask, kp, vis, camera_params, occ_map, orig_bbox, uv_flow):

        # Scale the image so that its largest side equals img_size
        crop_h, crop_w = np.shape(img)[:2]

        scale = self.img_size / float(max(crop_h, crop_w))
        img_scale, _ = image_utils.resize_img(img, scale)

        if occ_map is not None:
            occ_map, _ = image_utils.resize_img(occ_map, scale)

        if mask is not None:
            mask_scale, _ = image_utils.resize_img(mask, scale)
        else:
            mask_scale = None

        if kp is not None:
            kp[vis, :2] *= scale

        if uv_flow is not None:
            img_orig_h, img_orig_w = img.shape[:2]
            # un-normalize uv_flow coordinates
            uv = uv_flow[:,:,:2]
            uv[:,:,0] = uv[:,:,0]*(img_orig_h/2.)+img_orig_h/2.
            uv[:,:,1] = uv[:,:,1]*(img_orig_w/2.)+img_orig_w/2.
            # Change the values
            uv[:,:,0] *= scale
            uv[:,:,1] *= scale
            img_h, img_w = img_scale.shape[:2]
            uv_flow[:,:,0] = (uv[:,:,0]-(img_h/2.))/(img_h/2.)
            uv_flow[:,:,1] = (uv[:,:,1]-(img_w/2.))/(img_w/2.)

        if camera_params[0] > 0:
            camera_params[0] *= scale 
        camera_params[1] *= scale
        return img_scale, mask_scale, kp, camera_params, occ_map, uv_flow


    def __len__(self):
        return self.num_imgs

    def __getitem__(self, index):
        img, kp, mask, camera_params, texture_map, model_trans, model_pose, model_betas, model_delta_v, occ_map, img_path, uv_flow = self.forward_img(index)

        camera_params[0].shape = 1

        elem = {
            'img': img,
            'inds': index,
            'img_path':img_path,
            'camera_params_c':camera_params[1],
        }
        if kp is not None:
            elem['kp'] = kp
        if mask is not None:
            elem['mask'] = mask
        if texture_map is not None:
            elem['texture_map'] = texture_map
        if camera_params[0]:
            elem['camera_params'] = np.concatenate(camera_params)
        if model_trans is not None:
            elem['model_trans'] = model_trans
        if model_pose is not None:
            elem['model_pose'] = model_pose
        if model_betas is not None:
            elem['model_betas'] = model_betas
        if model_delta_v is not None:
            elem['model_delta_v'] = model_delta_v
        if uv_flow is not None:
            elem['uv_flow'] = uv_flow

        if self.filter_key is not None:
            if self.filter_key not in elem.keys():
                raise KeyError('Bad filter key %s' % self.filter_key)
            if self.filter_key == 'camera_params':
                # Return both vis and sfm_pose
                vis = elem['kp'][:, 2]
                elem = {
                    'vis': vis,
                    'camera_params': elem['camera_params'],
                }
            else:
                elem = elem[self.filter_key]


        return elem
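
# Illustrative sketch (hypothetical helpers, not called by the loader): the
# pinhole model behind get_camera_projection_matrix and my_project_points,
# written with plain ndarrays instead of np.matrix. The 3 x 4 matrix is
# K [I | 0], where K holds the focal length f and the principal point
# c = (cx, cy) in pixels.
import numpy as np


def _example_projection_matrix(f, c):
    K = np.array([[f, 0.0, c[0]],
                  [0.0, f, c[1]],
                  [0.0, 0.0, 1.0]])
    return K.dot(np.hstack([np.eye(3), np.zeros((3, 1))]))


def _example_project(pts_world, P):
    # Homogeneous coordinates: append a 1 to every N x 3 world point
    pts_h = np.hstack([pts_world, np.ones((pts_world.shape[0], 1))])
    proj = pts_h.dot(P.T)
    # Perspective divide by depth gives N x 2 pixel coordinates
    return proj[:, :2] / proj[:, 2:3]

# Example: f = 1000 px, principal point at the center of a 640 x 480 image:
# _example_project(np.array([[0.5, -0.2, 10.0]]), _example_projection_matrix(1000.0, (320.0, 240.0)))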

# ------------ Data Loader ----------- #
# ------------------------------------ #
def base_loader(d_set_func, batch_size, opts, filter_key=None, shuffle=True, filter_name=None):
    dset = d_set_func(opts, filter_key=filter_key, filter_name=filter_name)
    return DataLoader(
        dset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=opts.n_data_workers,
        drop_last=True)
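
# Illustrative sketch (hypothetical helper): a closed-form way to see what a
# crop does to the model translation under a pinhole camera with focal length
# f. This is NOT the route taken by get_model_trans_for_cropped_image, which
# goes through a pseudo-inverse of the projection matrix, and the two are not
# necessarily numerically identical.
# A point at depth tz projects to u = f * tx / tz + cx. After cropping to
# bbox = (xmin, ymin, xmax, ymax), its pixel in the crop is u - xmin, and the
# crop's principal point is ((xmax - xmin) / 2, (ymax - ymin) / 2). Solving
# f * tx' / tz + cx' = u - xmin gives tx' = (u - bbox_center_x) * tz / f,
# likewise for y, while the depth tz stays unchanged.
import numpy as np


def _example_trans_for_crop(trans, bbox, f, img_w, img_h):
    tx, ty, tz = [float(t) for t in trans]
    # Projection of the model origin in the original image (pixels)
    u = f * tx / tz + img_w / 2.0
    v = f * ty / tz + img_h / 2.0
    # Offset of that projection from the bbox center (pixels)
    du = u - (bbox[0] + (bbox[2] - bbox[0]) / 2.0)
    dv = v - (bbox[1] + (bbox[3] - bbox[1]) / 2.0)
    return np.array([du * tz / f, dv * tz / f, tz])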


================================================
FILE: data/zebra.py
================================================
"""
Zebra dataset loader.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os.path as osp
import numpy as np
import pickle as pkl

import scipy.io as sio
import scipy.misc

from absl import flags, app

import torch
from torch.utils.data import Dataset

from . import smal_base as base_data
from ..utils import transformations



# -------------- flags ------------- #
# ---------------------------------- #
#kData = 'nokap/zebra_data'
    
flags.DEFINE_string('zebra_dir', 'smalst/zebra_data', 'Zebra Data Directory')
flags.DEFINE_string('image_file_string', '*.png', 'Glob pattern used to read the images')

curr_path = osp.dirname(osp.abspath(__file__))
cache_path = osp.join(curr_path, '..', 'cachedir')
flags.DEFINE_string('zebra_cache_dir', osp.join(cache_path, 'zebra'), 'Zebra Data Directory')
flags.DEFINE_integer('num_images', 3000, 'Number of training images')
flags.DEFINE_boolean('preload_image', False, '')
flags.DEFINE_boolean('preload_mask', False, '')
flags.DEFINE_boolean('preload_texture_map', False, '')
flags.DEFINE_boolean('preload_uvflow', False, '')
flags.DEFINE_boolean('use_per_file_texmap', True, 'Use the per-image texture map in texmap/ (with updated colors) instead of the file in texture_maps/')

opts = flags.FLAGS
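
# Illustrative sketch (hypothetical helper, not called by ZebraDataset): the
# per-image file layout that ZebraDataset.__init__ below looks up under
# opts.zebra_dir. Every file is keyed by the image's basename without its
# extension; the uv-flow file is optional, and a texture map is only used
# when the annotation carries a 'texture_map_filename' entry.
import os.path as osp


def _example_zebra_paths(data_dir, img_path, use_per_file_texmap=True,
                         texture_map_filename=None):
    stem = osp.splitext(osp.basename(img_path))[0]
    paths = {
        'annotation': osp.join(data_dir, 'annotations', stem + '.pkl'),
        'mask': osp.join(data_dir, 'bgsub', stem + '.png'),
        'uv_flow': osp.join(data_dir, 'uvflow', stem + '.pkl'),
    }
    if texture_map_filename is not None:
        if use_per_file_texmap:
            # One texture map per image, named after the image
            paths['texture_map'] = osp.join(data_dir, 'texmap', stem + '.png')
        else:
            # Texture map file named in the annotation
            paths['texture_map'] = osp.join(data_dir, 'texture_maps', texture_map_filename)
    return paths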

# -------------- Dataset ------------- #
# ------------------------------------ #
class ZebraDataset(base_data.BaseDataset):
    '''
    Zebra Data loader
    '''

    def __init__(self, opts, filter_key=None, filter_name=None):
        super(ZebraDataset, self).__init__(opts, filter_key=filter_key)
        self.data_cache_dir = opts.zebra_cache_dir
        self.filter_key = filter_key
        self.filter_name = filter_name 

        if True: 
            self.data_dir = opts.zebra_dir
            self.img_dir = osp.join(self.data_dir, 'images')
            from glob import glob
            if filter_name is None:
                images = glob(osp.join(self.img_dir, opts.image_file_string))
            else:
                images = glob(osp.join(self.img_dir, filter_name + opts.image_file_string))
            num_images = np.min([len(images), opts.num_images])
            images = images[:num_images]
            self.anno = [None]*num_images
            self.anno_camera = [None]*len(images)


            for i, img in enumerate(images):
                anno_path = osp.join(self.data_dir, 'annotations/%s.pkl' % osp.splitext(osp.basename(img))[0])
                if osp.exists(anno_path):
                    with open(anno_path, 'rb') as f:
                        self.anno[i] = pkl.load(f)
                    self.anno[i]['mask_path'] = osp.join(self.data_dir, 'bgsub/%s.png' % osp.splitext(osp.basename(img))[0])
                    self.anno[i]['img_path'] = img
                    uv_flow_path = osp.join(self.data_dir, 'uvflow/%s.pkl' % osp.splitext(osp.basename(img))[0])
                    if osp.exists(uv_flow_path):
                        self.anno[i]['uv_flow_path'] = uv_flow_path

                    # In case we have the texture map
                    if 'texture_map_filename' in self.anno[i].keys():
                        if opts.use_per_file_texmap:
                            self.anno[i]['texture_map'] = osp.join(self.data_dir, 'texmap/%s.png' % osp.splitext(osp.basename(img))[0])
                        else:
                            self.anno[i]['texture_map'] = osp.join(self.data_dir, 'texture_maps/%s' % self.anno[i]['texture_map_filename'])

                    # Add a column to the keypoints in case the visibility is not defined
                    if self.anno[i]['keypoints'].shape[1] < 3:
                        self.anno[i]['keypoints'] = np.column_stack([self.anno[i]['keypoints'], np.ones(self.anno[i]['keypoints'].shape[0])])

                    self.anno_camera[i]= {'flength': self.anno[i]['flength'], 'trans': np.zeros(2, dtype=float)}
                    # Identity permutation over the keypoints of this annotation
                    self.kp_perm = np.arange(self.anno[i]['keypoints'].shape[0])

                    if opts.preload_image:
                        self.anno[i]['img'] = scipy.misc.imread(self.anno[i]['img_path']) / 255.0
                    if opts.preload_texture_map and 'texture_map' in self.anno[i]:
                        texture_map = scipy.misc.imread(self.anno[i]['texture_map']) / 255.0
                        self.anno[i]['texture_map_data'] = np.transpose(texture_map, (2, 0, 1))
                    if opts.preload_mask:
                        self.anno[i]['mask'] = scipy.misc.imread(self.anno[i]['mask_path']) / 255.0
                    if opts.preload_uvflow and 'uv_flow_path' in self.anno[i]:
                        with open(self.anno[i]['uv_flow_path'], 'rb') as f:
                            uvdata = pkl.load(f)
                        uv_flow = uvdata['uv_flow'].astype(np.float32)
                        # Rescale the flow by half the image size (pixel units to normalized units)
                        uv_flow[:,:,0] = uv_flow[:,:,0] /(uvdata['img_h']/2.)
                        uv_flow[:,:,1] = uv_flow[:,:,1] /(uvdata['img_w']/2.)
                        self.anno[i]['uv_flow'] = uv_flow

                else:
                    mask_path = osp.join(self.data_dir, 'bgsub/%s.png' % osp.splitext(osp.basename(img))[0])
                    if osp.exists(mask_path):
                        self.anno[i] = {'mask_path': mask_path, 'img_path':img, 'keypoints': None, 'uv_flow':None}
                    else:
                        self.anno[i] = {'mask_path': None, 'img_path':img, 'keypoints': None, 'uv_flow':None}
                    self.anno_camera[i] = {'flength': None, 'trans': None}

            self.num_imgs = len(self.anno)

        print('%d images' % self.num_imgs)

        #import pdb; pdb.set_trace()
        #self.debug_crop()

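# Illustrative sketch (hypothetical helper, not called by the loaders below):
# the preload_uvflow branch of ZebraDataset rescales a stored uv flow by half the
# image size, channel 0 by img_h / 2 and channel 1 by img_w / 2. The H x W x 2
# layout is an assumption taken from the indexing used above.
import numpy as np


def _example_rescale_uv_flow(uv_flow, img_h, img_w):
    """Return a float32 copy of uv_flow rescaled as in ZebraDataset.__init__."""
    uv = np.asarray(uv_flow, dtype=np.float32).copy()
    uv[:, :, 0] /= (img_h / 2.)
    uv[:, :, 1] /= (img_w / 2.)
    return uv
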

#----------- Data Loader ----------#
#----------------------------------#
def data_loader(opts, shuffle=True, filter_name=None):
    return base_data.base_loader(ZebraDataset, opts.batch_size, opts, filter_key=None, shuffle=shuffle, filter_name=filter_name)


def kp_data_loader(batch_size, opts):
    return base_data.base_loader(ZebraDataset, batch_size, opts, filter_key='kp')


def mask_data_loader(batch_size, opts):
    return base_data.base_loader(ZebraDataset, batch_size, opts, filter_key='mask')

def texture_map_data_loader(batch_size, opts):
    return base_data.base_loader(ZebraDataset, batch_size, opts, filter_key='texture_map')
    

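# Illustrative sketch (hypothetical helper, not used by the loaders above):
# ZebraDataset appends a column of ones to K x 2 keypoint arrays, so every
# keypoint is treated as visible when no visibility flag is annotated.
import numpy as np


def _example_add_kp_visibility(keypoints):
    """Return a K x 3 array; a missing third (visibility) column is filled with 1."""
    kp = np.asarray(keypoints, dtype=float)
    if kp.shape[1] < 3:
        kp = np.column_stack([kp, np.ones(kp.shape[0])])
    return kp
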

================================================
FILE: docs/tmp.txt
================================================



================================================
FILE: experiments/__init__.py
================================================


================================================
FILE: experiments/smal_shape.py
================================================
"""

Example usage:

python -m smalst.experiments.smal_shape --zebra_dir='smalst/zebra_no_toys_wtex_1000_0' --num_epochs=100000 --save_epoch_freq=20 --name=smal_net_600 --save_training_imgs=True --num_images=20000 --do_validation=True

"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import app
from absl import flags

import os.path as osp
import numpy as np
import torch
import torchvision
from torch.autograd import Variable
import scipy.io as sio
import scipy
import scipy.misc
from collections import OrderedDict
import pickle as pkl

from ..data import zebra as zebra_data
from ..utils import visutil
from ..utils import smal_vis
from ..utils import image as image_utils
from ..nnutils import train_utils
from ..nnutils import loss_utils
from ..nnutils import smal_mesh_net
from ..nnutils.nmr import NeuralRenderer
from ..nnutils import geom_utils

flags.DEFINE_string('dataset', 'zebra', 'dataset name (only zebra is supported)')
# Weights:
flags.DEFINE_float('kp_loss_wt', 10., 'keypoint loss weight')
flags.DEFINE_float('kp_2D_loss_wt', 10., 'loss weight for the 2D keypoints predicted by the network')
flags.DEFINE_float('mask_loss_wt', 30., 'mask loss weight')
flags.DEFINE_float('cam_loss_wt', 10000., 'camera loss weight')
flags.DEFINE_float('deform_reg_wt', 100., 'deformation regularizer weight')
flags.DEFINE_float('triangle_reg_wt', 100., 'triangle smoothness prior weight')
flags.DEFINE_float('vert2kp_loss_wt', .16, 'weight of the entropy regularizer on the vertex-to-keypoint assignment')
flags.DEFINE_float('tex_loss_wt', 10., 'texture loss weight')

flags.DEFINE_boolean('grad_v_in_tex_loss', False, 'if true the texture loss also backpropagates into the predicted vertices')
flags.DEFINE_boolean('use_keypoints', True, 'use keypoints loss')
flags.DEFINE_boolean('use_mask', True, 'use mask loss')
flags.DEFINE_boolean('use_shape_reg', False, 'use shape regularizers')
flags.DEFINE_float('tex_map_loss_wt', 10., 'texture map loss weight')
flags.DEFINE_float('tex_dt_loss_wt', .5, 'texture distance transform loss weight')
flags.DEFINE_float('mod_trans_loss_wt', 4000., 'model translation loss weight')
flags.DEFINE_float('mod_pose_loss_wt', 200000., 'model pose loss weight')
flags.DEFINE_float('betas_reg_wt', 100000., 'betas prior loss weight')
flags.DEFINE_float('delta_v_loss_wt', 100000., 'model delta_v loss weight')
flags.DEFINE_float('occ_loss_wt', 100., 'occlusion loss weight')
flags.DEFINE_boolean('infer_vert2kp', False, 'estimate keypoints on the 3D model instead of using predefined values.')

flags.DEFINE_boolean('no_delta_v', False, 'set predicted deformations to zero')
flags.DEFINE_boolean('use_gtpose', False, 'if true uses gt pose for projection, but trans still gets trained.')
flags.DEFINE_boolean('use_gttrans', False, 'if true uses gt trans for projection, but pose still gets trained.')
flags.DEFINE_boolean('use_gtcam', False, 'if true uses gt cam for projection, but cam still gets trained.')
flags.DEFINE_boolean('use_gtbetas', False, 'if true uses gt betas for projection, but betas still get trained.')
flags.DEFINE_boolean('use_gtdeltav', False, 'if true uses gt delta_v for the predicted shape')
flags.DEFINE_boolean('use_gttexture', False, 'if true replaces the texture of the first image in the batch with one sampled from the gt texture map')

flags.DEFINE_boolean('use_camera_loss', True, 'if true trains with a loss on the gt camera')
flags.DEFINE_boolean('random_bkg', False, 'if using a random background rather than black in the pred image')
flags.DEFINE_boolean('use_perceptual_loss', True, 'if true uses a perceptual texture loss instead of loss_utils.texture_loss')

flags.DEFINE_boolean('uv_flow', True, 'if true uses the uv flow loss when gt uv flow is available')
flags.DEFINE_float('uv_flow_loss_wt', 100000., 'uv flow loss weight')

flags.DEFINE_boolean('use_pose_geodesic_loss', True, '')
flags.DEFINE_boolean('use_loss_on_whole_image', False, 'if true composes the predicted animal with a background image and computes the texture loss on the whole image')
flags.DEFINE_boolean('use_tex_dt', True, 'if true uses the texture distance transform loss (loss (4) in the birds paper)')

flags.DEFINE_boolean('white_balance_for_texture_map', False, '')
flags.DEFINE_boolean('use_img_as_background', False, 'whether to use the input image as background for the optimization')
flags.DEFINE_boolean('use_gtmask_for_background', False, 'whether to use the gt mask to define the background for the optimization')
flags.DEFINE_boolean('use_per_image_rgb_bg', False, 'whether to compute per-image rgb colors for the background in optimization')

opts = flags.FLAGS
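
# Illustrative sketch (hypothetical helper, not used by ShapeTrainer): set_input
# applies torchvision's Normalize image by image in a loop. Assuming a
# B x 3 x H x W float batch, the same per-channel normalization (x - mean) / std
# can be written with broadcasting over the whole batch at once.
import torch


def _example_normalize_batch(imgs, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Normalize a B x 3 x H x W tensor channel-wise with the ImageNet statistics."""
    mean_t = torch.tensor(mean, dtype=imgs.dtype, device=imgs.device).view(1, 3, 1, 1)
    std_t = torch.tensor(std, dtype=imgs.dtype, device=imgs.device).view(1, 3, 1, 1)
    return (imgs - mean_t) / std_t
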

curr_path = osp.dirname(osp.abspath(__file__))
cache_path = osp.join(curr_path, '..', 'cachedir')

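# Illustrative sketch (hypothetical helpers, not used by ShapeTrainer): with
# use_norm_f_and_z, set_input maps the gt focal length to (f - norm_f0) / norm_f
# and the gt depth to z - norm_z + 1. The inverses recover the original values,
# e.g. to convert a normalized focal length back to its original units.
def _example_normalize_f_and_z(f, z, norm_f0, norm_f, norm_z):
    return (f - norm_f0) / norm_f, z - norm_z + 1.


def _example_denormalize_f_and_z(f_n, z_n, norm_f0, norm_f, norm_z):
    return f_n * norm_f + norm_f0, z_n + norm_z - 1.
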

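# Illustrative sketch (hypothetical helper, not used by ShapeTrainer): the
# use_per_image_rgb_bg branch of forward() estimates a background color as the
# mean of the image over the pixels outside the predicted mask. Shown here for
# one 3 x H x W image and an H x W foreground mask with values in [0, 1].
import numpy as np


def _example_mean_background_rgb(img, fg_mask):
    bg_mask = 1.0 - fg_mask  # weight of each pixel as background
    n = np.sum(bg_mask)
    return np.array([np.sum(img[c] * bg_mask) / n for c in range(3)])

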
class ShapeTrainer(train_utils.Trainer):
    def define_model(self):
        opts = self.opts

        self.symmetric = opts.symmetric

        img_size = (opts.img_size, opts.img_size)

        texture_mask_path = 'smalst/'+opts.dataset+'_data/texture_maps/my_smpl_00781_4_all_template_w_tex_uv_001_mask_small.png'
        self.texture_map_mask = torch.Tensor(scipy.misc.imread(texture_mask_path) / 255.0).cuda(device=opts.gpu_id)

        tex_masks = None

        data_path = 'smalst/smpl_models/my_smpl_data_00781_4_all.pkl'
        with open(data_path, 'rb') as f:
            data = pkl.load(f)

        pca_var = data['eigenvalues'][:opts.num_betas]
        self.betas_prec = torch.Tensor(pca_var).cuda(device=opts.gpu_id).expand(opts.batch_size, opts.num_betas)

        self.model = smal_mesh_net.MeshNet(
            img_size, opts, nz_feat=opts.nz_feat, num_kps=opts.num_kps, tex_masks=tex_masks)

        if opts.num_pretrain_epochs > 0:
            self.load_network(self.model, 'pred', opts.num_pretrain_epochs)

        self.model = self.model.cuda(device=opts.gpu_id)

        if not opts.infer_vert2kp:
            with open('smalst/'+opts.dataset+'_data/verts2kp.pkl', 'rb') as f:
                self.vert2kp = torch.Tensor(pkl.load(f)).cuda(device=opts.gpu_id)

        # Data structures to use for triangle priors.
        edges2verts = self.model.edges2verts
        # B x E x 4
        edges2verts = np.tile(np.expand_dims(edges2verts, 0), (opts.batch_size, 1, 1))
        self.edges2verts = Variable(torch.LongTensor(edges2verts).cuda(device=opts.gpu_id), requires_grad=False)
        # For rendering.
        faces = self.model.faces.view(1, -1, 3)
        self.faces = faces.repeat(opts.batch_size, 1, 1)
        self.renderer = NeuralRenderer(opts.img_size, opts.projection_type, opts.norm_f, opts.norm_z, opts.norm_f0)
        
        if opts.texture:
            self.tex_renderer = NeuralRenderer(opts.img_size, opts.projection_type, opts.norm_f, opts.norm_z, opts.norm_f0)
            # Use a single light type for the tex renderer: directional if requested, otherwise ambient only
            if opts.use_directional_light:
                self.tex_renderer.directional_light_only()
            else:
                self.tex_renderer.ambient_light_only()

        # For visualization
        self.vis_rend = smal_vis.VisRenderer(opts.img_size, faces.data.cpu().numpy(), opts.projection_type, opts.norm_f, opts.norm_z, opts.norm_f0)

        self.background_imgs = None
        return

    def init_dataset(self):
        opts = self.opts
        if opts.dataset == 'zebra':
            self.data_module = zebra_data
        else:
            raise ValueError('Unknown dataset %s!' % opts.dataset)

        self.dataloader = self.data_module.data_loader(opts)
        self.resnet_transform = torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225])

    def define_criterion(self):
        opts = self.opts
        if opts.use_keypoints:
            self.projection_loss = loss_utils.kp_l2_loss
        if opts.use_mask:
            self.mask_loss_fn = loss_utils.mask_loss
        if opts.infer_vert2kp:
            self.entropy_loss = loss_utils.entropy_loss
        if self.opts.use_camera_loss:
            self.camera_loss = loss_utils.camera_loss

        if opts.use_smal_betas:
            self.betas_loss_fn = loss_utils.betas_loss
        self.delta_v_loss_fn = loss_utils.delta_v_loss

        if self.opts.texture:
            if self.opts.use_perceptual_loss:
                if False: 
                    self.texture_loss = loss_utils.MSE_texture_loss
                else:
                    self.texture_loss = loss_utils.PerceptualTextureLoss()
            else:
                self.texture_loss = loss_utils.texture_loss
            self.texture_dt_loss_fn = loss_utils.texture_dt_loss
            if opts.texture_map:
                self.texture_map_loss = loss_utils.texture_map_loss
            if opts.uv_flow:
                self.uv_flow_loss = loss_utils.uv_flow_loss

        self.model_trans_loss_fn = loss_utils.model_trans_loss
        self.model_pose_loss_fn = loss_utils.model_pose_loss

    def set_optimization_input(self):
        opts = self.opts
        cams = np.zeros((self.scale_pred.shape[0], 3))
        # scale_pred is B x 1; flatten it to fill the first camera column
        cams[:,0] = self.scale_pred.detach().cpu().numpy().reshape(-1)
        cams[:,1:] = 128
        self.cams = Variable(torch.FloatTensor(cams).cuda(device=opts.gpu_id), requires_grad=False)
        self.model_trans = Variable(self.trans_pred.cuda(device=opts.gpu_id), requires_grad=False)

    def set_optimization_variables(self):
        '''
        Sets as optimization variables those obtained as prediction from the network
        '''
        opts = self.opts
        cams = np.zeros((self.scale_pred.shape[0], 3))
        # scale_pred is B x 1; flatten it to fill the first camera column
        cams[:,0] = self.scale_pred.detach().cpu().numpy().reshape(-1)
        cams[:,1:] = 128

        # Prediction is gt
        self.cams = Variable(torch.FloatTensor(cams).cuda(device=opts.gpu_id), requires_grad=False)
        self.model_pose = Variable(self.pose_pred.cuda(device=opts.gpu_id), requires_grad=False)
        self.model_trans = Variable(self.trans_pred.cuda(device=opts.gpu_id), requires_grad=False)
        self.delta_v= Variable(self.delta_v.cuda(device=opts.gpu_id), requires_grad=False)


    def set_input(self, batch):
        opts = self.opts

        # Image with annotations.
        input_img_tensor = batch['img'].type(torch.FloatTensor)

        for b in range(input_img_tensor.size(0)):
            input_img_tensor[b] = self.resnet_transform(input_img_tensor[b])

        img_tensor = batch['img'].type(torch.FloatTensor)
        self.input_imgs = Variable( input_img_tensor.cuda(device=opts.gpu_id), requires_grad=False)
        self.imgs = Variable( img_tensor.cuda(device=opts.gpu_id), requires_grad=False)

        #if opts.use_mask and 'mask' in batch.keys():
        if 'mask' in batch.keys():
            mask_tensor = batch['mask'].type(torch.FloatTensor)
            self.masks = Variable( mask_tensor.cuda(device=opts.gpu_id), requires_grad=False)
        else:
            self.masks = None
        if opts.use_keypoints and 'kp' in batch.keys():
            kp_tensor = batch['kp'].type(torch.FloatTensor)
            self.kps = Variable( kp_tensor.cuda(device=opts.gpu_id), requires_grad=False)
        else:
            self.kps = None

        self.img_paths = batch['img_path']

        if 'camera_params' in batch.keys():
            cam_tensor = batch['camera_params'].type(torch.FloatTensor)
            if opts.use_norm_f_and_z:
                cam_tensor[:,0] = (cam_tensor[:,0]-opts.norm_f0)/opts.norm_f
            self.cams = Variable( cam_tensor.cuda(device=opts.gpu_id), requires_grad=False)
        else:
            self.cams = None
            cam_c_tensor = batch['camera_params_c'].type(torch.FloatTensor)
            self.cams_center = Variable(cam_c_tensor.cuda(device=opts.gpu_id), requires_grad=False)

        if 'model_trans' in batch.keys():
            model_trans_tensor = batch['model_trans'].type(torch.FloatTensor)
            if opts.use_norm_f_and_z:
                model_trans_tensor[:,2] = model_trans_tensor[:,2]-opts.norm_z +1.
            self.model_trans = Variable(
                model_trans_tensor.cuda(device=opts.gpu_id), requires_grad=False)
        else:
            self.model_trans = None
        if 'model_pose' in batch.keys():
            model_pose_tensor = batch['model_pose'].type(torch.FloatTensor)
            self.model_pose = Variable(
                model_pose_tensor.cuda(device=opts.gpu_id), requires_grad=False)
        else:
            self.model_pose = None

        if 'model_betas' in batch.keys():
            model_betas_tensor = batch['model_betas'][:,:self.opts.num_betas].type(torch.FloatTensor)
            self.model_betas = Variable(
                model_betas_tensor.cuda(device=opts.gpu_id), requires_grad=False)
        else:
            self.model_betas = None

        if 'model_delta_v' in batch.keys():
            model_delta_v_tensor = batch['model_delta_v'].type(torch.FloatTensor)
            self.model_delta_v = Variable(
                model_delta_v_tensor.cuda(device=opts.gpu_id), requires_grad=False)
        else:
            self.model_delta_v = None

        if opts.texture_map:
            assert('texture_map' in batch.keys())
            texture_map_tensor = batch['texture_map'].type(torch.FloatTensor)
            self.texture_map = Variable(texture_map_tensor.cuda(device=opts.gpu_id), requires_grad=False)
        else:
            self.texture_map = None

        if 'uv_flow' in batch.keys():
            uv_flow_tensor = batch['uv_flow'].type(torch.FloatTensor).permute(0,3,1,2)
            self.uv_flow_gt = Variable(uv_flow_tensor.cuda(device=opts.gpu_id), requires_grad=False)
        else:
            self.uv_flow_gt = None

        # Compute barrier distance transform.
        #if opts.use_mask and self.masks is not None:
        if self.masks is not None:
            mask_dts = np.stack([image_utils.compute_dt_barrier(m) for m in batch['mask']])
            dt_tensor = torch.FloatTensor(mask_dts).cuda(device=opts.gpu_id)
            # B x 1 x N x N
            self.dts_barrier = Variable(dt_tensor, requires_grad=False).unsqueeze(1)


    def forward(self, opts_scale=None, opts_pose=None, opts_trans=None, opts_delta_v=None):
        opts = self.opts
        if opts.use_double_input:
            masks = self.input_imgs*self.masks
        else:
            masks = None
        if opts.texture:
            pred_codes, self.textures = self.model.forward(self.input_imgs, masks)
        else:
            pred_codes = self.model.forward(self.input_imgs, masks)

        self.delta_v, self.scale_pred, self.trans_pred, self.pose_pred, self.betas_pred, self.kp_2D_pred = pred_codes

        if opts.fix_trans:
            self.trans_pred[:,2] = self.model_trans[:,2] 

        if opts.use_gttrans:
            print('Using gt trans') 
            self.trans_pred = self.model_trans 
        if opts.use_gtpose:
            print('Using gt pose') 
            self.pose_pred = self.model_pose
        if opts.use_gtcam:
            print('Using gt cam') 
            self.scale_pred = self.cams[:,0,None]
        if opts.use_gtbetas:
            print('Using gt betas') 
            self.betas_pred = self.model_betas
        if opts.use_gtdeltav:
            print('Using gt delta_v') 
            self.delta_v = self.model_delta_v

        if self.cams is not None:
            # The camera center does not change; only the focal length is predicted
            self.cam_pred = torch.cat([self.scale_pred, self.cams[:,1:]], 1)
        else:
            self.cam_pred = torch.cat([self.scale_pred, self.cams_center], 1)

        if opts.only_mean_sym:
            del_v = self.delta_v
        else:
            del_v = self.model.symmetrize(self.delta_v)

        if opts.no_delta_v:
            del_v[:] = 0

        if opts.use_smal_pose: 
            self.pred_v = self.model.get_smal_verts(self.pose_pred, self.betas_pred, self.trans_pred, del_v)
        else:
            # TODO
            self.mean_shape = self.model.get_mean_shape()
            self.pred_v = self.mean_shape + del_v + self.trans_pred

        # Compute keypoints.
        if opts.infer_vert2kp:
            self.vert2kp = torch.nn.functional.softmax(self.model.vert2kp, dim=1)
        self.kp_verts = torch.matmul(self.vert2kp, self.pred_v)

        # Set projection camera
        proj_cam = self.cam_pred

        # Project keypoints
        if opts.use_keypoints:
            self.kp_pred = self.renderer.project_points(self.kp_verts, proj_cam)

        # Render mask.
        self.mask_pred = self.renderer.forward(self.pred_v, self.faces, proj_cam)

        if opts.texture:
            self.texture_flow = self.textures
            self.textures = geom_utils.sample_textures(self.texture_flow, self.imgs)
            tex_size = self.textures.size(2)
            self.textures = self.textures.unsqueeze(4).repeat(1, 1, 1, 1, tex_size, 1)

            if opts.use_gttexture:
                idx=0
                from ..utils.obj2nmr import obj2nmr_uvmap
                uv_map = obj2nmr_uvmap(self.model.ft, self.model.vt, tex_size=tex_size)
                uv_img = self.texture_map[idx,:,:,:]
                uv_img = uv_img.permute(1,2,0)
                # NOTE: sample_texture is neither defined nor imported in this module
                texture_t = sample_texture(uv_map, uv_img)
                self.textures[0,:,:,:,:,:] = texture_t[0,:,:,:,:,:]

            if opts.grad_v_in_tex_loss:
                self.texture_pred = self.tex_renderer.forward(self.pred_v, self.faces, proj_cam.detach(), textures=self.textures)
            else:
                self.texture_pred = self.tex_renderer.forward(self.pred_v.detach(), self.faces, proj_cam.detach(), textures=self.textures)

        else:
            self.textures = None
            if opts.save_training_imgs and opts.use_mask and self.masks is not None:
                T = 255*self.mask_pred.cpu().detach().numpy()[0,:,:]
                scipy.misc.imsave(opts.name + '_mask_pred.png', T)
                T = 255*self.masks.cpu().detach().numpy()[0,:,:,:]
                T = np.transpose(T,(1,2,0))[:,:,0]
                scipy.misc.imsave(opts.name + '_mask_gt.png', T)

        # Compute losses for this instance.
        if self.opts.use_keypoints and self.kps is not None:
            self.kp_loss = self.projection_loss(self.kp_pred, self.kps)
        if self.opts.use_mask and self.masks is not None:
            self.mask_loss = self.mask_loss_fn(self.mask_pred, self.masks[:,0,:,:])
        if self.opts.use_camera_loss and self.cams is not None:
            self.cam_loss = self.camera_loss(self.cam_pred, self.cams, 0, self.opts.use_norm_f_and_z)
        if self.model_trans is not None:
            self.mod_trans_loss = self.model_trans_loss_fn(self.trans_pred, self.model_trans)
        if self.model_pose is not None:
            self.mod_pose_loss = self.model_pose_loss_fn(self.pose_pred, self.model_pose, self.opts)

        if opts.texture:
            if opts.use_loss_on_whole_image:
                
                if self.background_imgs is None:
                    print("SETTING BACKGROUND MODEL")
                    self.background_imgs = np.zeros(self.imgs.shape)
                    fg_mask = self.mask_pred.detach().cpu().numpy()
                    I = self.imgs.detach().cpu().numpy()
                    bg_mask = np.abs(fg_mask-1)
                    rgb = np.zeros((3))
                    # Mean background color of the first image in the batch
                    n = np.sum(bg_mask[0])
                    for c in range(3):
                        I[:,c,:,:] = I[:,c,:,:] * bg_mask
                        rgb[c] = np.sum(I[0,c,:,:])/n

                    if getattr(self, 'background_model_top', None) is not None:
                        N = 128
                        for c in range(3):
                            self.background_imgs[:,c,:N,:] = self.background_model_top[c]
                            self.background_imgs[:,c,N:,:] = self.background_model_bottom[c]
                    else:
                        # This is what we use for optimization
                        if opts.use_per_image_rgb_bg:
                            self.background_imgs[:,0,:,:] = rgb[0] 
                            self.background_imgs[:,1,:,:] = rgb[1] 
                            self.background_imgs[:,2,:,:] = rgb[2] 
                        else:
                            self.background_imgs[:,0,:,:] = .6964 
                            self.background_imgs[:,1,:,:] = .5806 
                            self.background_imgs[:,2,:,:] = .4780 
                        
                        # Verification experiment: replace with image
                        if opts.use_img_as_background:
                            self.background_imgs[:,0,:,:] = self.imgs.data[:,0,:,:]
                            self.background_imgs[:,1,:,:] = self.imgs.data[:,1,:,:]
                            self.background_imgs[:,2,:,:] = self.imgs.data[:,2,:,:]
               
                    self.background_imgs = torch.Tensor(self.background_imgs).cuda(device=opts.gpu_id)
            if self.masks is not None:
                if opts.use_loss_on_whole_image:
                    self.tex_loss = self.texture_loss(self.texture_pred, self.imgs, self.mask_pred, None, self.background_imgs)
                else:
                    self.tex_loss = self.texture_loss(self.texture_pred, self.imgs, self.mask_pred, self.masks[:,0,:,:])
                if opts.use_tex_dt:
                    self.tex_dt_loss = self.texture_dt_loss_fn(self.texture_flow, self.dts_barrier[:,:,:,:,0])
            else:
                if opts.use_loss_on_whole_image:
                    self.tex_loss = self.texture_loss(self.texture_pred, self.imgs, self.mask_pred, None, self.background_imgs)
                else:
                    self.tex_loss = self.texture_loss(self.texture_pred, self.imgs, self.mask_pred, None)
            if opts.texture_map and self.texture_map is not None:
                uv_flows = self.model.texture_predictor.uvimage_pred
                uv_flows = uv_flows.permute(0, 2, 3, 1)
                uv_images = torch.nn.functional.grid_sample(self.imgs, uv_flows)
                self.tex_map_loss = self.texture_map_loss(uv_images, self.texture_map, self.texture_map_mask, self.opts)
            if opts.uv_flow and self.uv_flow_gt is not None:
                uv_flows = self.model.texture_predictor.uvimage_pred
                self.uv_f_loss = self.uv_flow_loss(uv_flows, self.uv_flow_gt)
           
        # Priors:
        if opts.infer_vert2kp:
            self.vert2kp_loss = self.entropy_loss(self.vert2kp)
        if opts.use_smal_betas: 
            self.betas_loss = self.betas_loss_fn(self.betas_pred, self.model_betas, self.betas_prec)
        if self.model_delta_v is not None:
            self.delta_v_loss = self.delta_v_loss_fn(self.delta_v, self.model_delta_v)

        # Finally sum up the loss.
        # Instance loss:
        if opts.use_keypoints and self.kps is not None:
            self.total_loss = opts.kp_loss_wt * self.kp_loss
            if opts.use_mask and self.masks is not None:
                self.total_loss += opts.mask_loss_wt * self.mask_loss
        else:
            if opts.use_mask and self.masks is not None:
                self.total_loss = opts.mask_loss_wt * self.mask_loss
            else:
                self.total_loss = 0

        if not opts.use_gtcam and self.opts.use_camera_loss and self.cams is not None:
            self.total_loss += opts.cam_loss_wt * self.cam_loss

        if opts.texture:
            self.total_loss += opts.tex_loss_wt * self.tex_loss

        if opts.texture_map and self.texture_map is not None:
            self.total_loss += opts.tex_map_loss_wt * self.tex_map_loss
        if opts.uv_flow and self.uv_flow_gt is not None:
            self.total_loss += opts.uv_flow_loss_wt * self.uv_f_loss
        if self.model_trans is not None:
            if not opts.use_gttrans:
                self.total_loss += opts.mod_trans_loss_wt * self.mod_trans_loss
        if self.model_pose is not None:
            if not opts.use_gtpose:
                self.total_loss += opts.mod_pose_loss_wt * self.mod_pose_loss

        if self.model_delta_v is not None:
            self.total_loss += opts.delta_v_loss_wt*self.delta_v_loss

        # Priors:
        if opts.infer_vert2kp:
            self.total_loss += opts.vert2kp_loss_wt * self.vert2kp_loss
        if opts.use_smal_betas: 
            self.total_loss += opts.betas_reg_wt * self.betas_loss

        if opts.texture and self.masks is not None and opts.use_tex_dt:
            self.total_loss += opts.tex_dt_loss_wt * self.tex_dt_loss



    def get_current_visuals(self):
        vis_dict = {}
        mask_concat = torch.cat([self.masks[:,0,:,:], self.mask_pred], 2)


        if self.opts.texture:
            # B x 2 x H x W
            uv_flows = self.model.texture_predictor.uvimage_pred
            # B x H x W x 2
            uv_flows = uv_flows.permute(0, 2, 3, 1)
            uv_images = torch.nn.functional.grid_sample(self.imgs, uv_flows)

        num_show = min(2, self.opts.batch_size)
        show_uv_imgs = []
        show_uv_flows = []

        for i in range(num_show):
            input_img = smal_vis.kp2im(self.kps[i].data, self.imgs[i].data)
            pred_kp_img = smal_vis.kp2im(self.kp_pred[i].data, self.imgs[i].data)
            masks = smal_vis.tensor2mask(mask_concat[i].data)
            if self.opts.texture:
                texture_here = self.textures[i]
            else:
                texture_here = None

            rend_predcam = self.vis_rend(self.pred_v[i], self.cam_pred[i], texture=texture_here)
            # Render from front & back:
            rend_frontal = self.vis_rend.diff_vp(self.pred_v[i], self.cam_pred[i], texture=texture_here, kp_verts=self.kp_verts[i])
            rend_top = self.vis_rend.diff_vp(self.pred_v[i], self.cam_pred[i], axis=[0, 1, 0], texture=texture_here, kp_verts=self.kp_verts[i])
            diff_rends = np.hstack((rend_frontal, rend_top))

            if self.opts.texture:
                uv_img = smal_vis.tensor2im(uv_images[i].data)
                show_uv_imgs.append(uv_img)
                uv_flow = smal_vis.visflow(uv_flows[i].data)
                show_uv_flows.append(uv_flow)

                tex_img = smal_vis.tensor2im(self.texture_pred[i].data)
                imgs = np.hstack((input_img, pred_kp_img, tex_img))
            else:
                imgs = np.hstack((input_img, pred_kp_img))

            rend_gtcam = self.vis_rend(self.pred_v[i], self.cams[i], texture=texture_here)
            rends = np.hstack((diff_rends, rend_predcam, rend_gtcam))
            vis_dict['%d' % i] = np.hstack((imgs, rends, masks))
            vis_dict['masked_img %d' % i] = smal_vis.tensor2im((self.imgs[i] * self.masks[i]).data)

        if self.opts.texture:
            vis_dict['uv_images'] = np.hstack(show_uv_imgs)
            vis_dict['uv_flow_vis'] = np.hstack(show_uv_flows)

        return vis_dict


    def get_current_points(self):
        return {
            'mean_shape': visutil.tensor2verts(self.mean_shape.data),
            'verts': visutil.tensor2verts(self.pred_v.data),
        }

    def get_current_scalars(self):
        sc_dict = OrderedDict([
            ('smoothed_total_loss', self.smoothed_total_loss),
            ('total_loss', self.total_loss.item()),
        ])
        if self.opts.use_smal_betas: 
            sc_dict['betas_reg'] = self.betas_loss.item()
        if self.opts.use_mask and self.masks is not None:
            sc_dict['mask_loss'] = self.mask_loss.item()
        if self.opts.use_keypoints and self.kps is not None:
            sc_dict['kp_loss'] = self.kp_loss.item()
        if self.opts.use_camera_loss and self.cams is not None:
            sc_dict['cam_loss'] = self.cam_loss.item()
        if self.opts.texture:
            sc_dict['tex_loss'] = self.tex_loss.item()
        if self.opts.texture and self.opts.use_tex_dt and self.masks is not None:
            sc_dict['tex_dt_loss'] = self.tex_dt_loss.item()
        if self.opts.uv_flow and self.uv_flow_gt is not None:
            sc_dict['uv_flow_loss'] = self.uv_f_loss.item()
        if self.opts.texture_map and self.texture_map is not None:
            sc_dict['tex_map_loss'] = self.tex_map_loss.item()
        if self.model_trans is not None:
            sc_dict['model_trans_loss'] = self.mod_trans_loss.item()
        if self.model_pose is not None:
            sc_dict['model_pose_loss'] = self.mod_pose_loss.item()
        if self.opts.infer_vert2kp:
            sc_dict['vert2kp_loss'] = self.vert2kp_loss.item()
        if self.model_delta_v is not None:
            sc_dict['model_delta_v_loss'] = self.delta_v_loss.item()

        return sc_dict


def main(_):
    torch.manual_seed(0)
    np.random.seed(0)
    trainer = ShapeTrainer(opts)
    trainer.init_training()
    trainer.train()

if __name__ == '__main__':
    app.run(main)


================================================
FILE: external/__init__.py
================================================


================================================
FILE: external/install_external.sh
================================================
git clone https://github.com/shubhtuls/PerceptualSimilarity

git clone https://github.com/hiroharu-kato/neural_renderer --branch v1.1.0
cd neural_renderer
python setup.py install


================================================
FILE: nnutils/__init__.py
================================================


================================================
FILE: nnutils/geom_utils.py
================================================
"""
Utils related to geometry like projection.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch

def sample_textures(texture_flow, images):
    """
    texture_flow: B x F x T x T x 2
    (In normalized coordinate [-1, 1])
    images: B x 3 x N x N

    output: B x F x T x T x 3
    """
    # Reshape into B x F x T*T x 2
    T = texture_flow.size(-2)
    F = texture_flow.size(1)
    flow_grid = texture_flow.view(-1, F, T * T, 2)
    # B x 3 x F x T*T
    samples = torch.nn.functional.grid_sample(images, flow_grid)
    # B x 3 x F x T x T
    samples = samples.view(-1, 3, F, T, T)
    # B x F x T x T x 3
    return samples.permute(0, 2, 3, 4, 1)
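
# Sketch (illustrative only; not called by the code in this file): how a
# normalized flow coordinate u in [-1, 1], as consumed by grid_sample above,
# maps to a continuous pixel position along an axis of `size` pixels. This
# assumes the align_corners=True convention, which to our knowledge matches
# PyTorch releases before 1.3; newer releases default to align_corners=False,
# which shifts the mapping by half a pixel. The helper name is hypothetical.
def _normalized_to_pixel(u, size):
    # -1 maps to the centre of the first pixel, +1 to the centre of the last one
    return (u + 1.0) * 0.5 * (size - 1)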


def perspective_proj_withz(X, cam, offset_z=0, cuda_device=0,norm_f=1., norm_z=0.,norm_f0=0.):
    """
    X: B x N x 3
    cam: B x 3: [f, cx, cy] 
    offset_z is for being compatible with previous code and is not used and should be removed
    """

    # B x 1 x 1
    #f = norm_f * cam[:, 0].contiguous().view(-1, 1, 1)
    f = norm_f0+norm_f * cam[:, 0].contiguous().view(-1, 1, 1)
    # B x N x 1
    z = norm_z + X[:, :, 2, None]

    # Will z ever be 0? We probably should max it..
    eps = 1e-6 * torch.ones(1).cuda(device=cuda_device)
    z = torch.max(z, eps)
    image_size_half = cam[0,1]
    scale = f / (z*image_size_half)

    # Offset is because cam is at -1
    return torch.cat((scale * X[:, :, :2], z+offset_z),2)
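
# Sketch (illustrative only; the helper name is hypothetical): the projection
# above written out for a single point in plain Python, to make the formula
# explicit. With f' = norm_f0 + norm_f * f and z' = max(norm_z + z, 1e-6):
#     x_img = f' * x / (z' * image_size_half)
#     y_img = f' * y / (z' * image_size_half)
#     depth = z' + offset_z
def _perspective_point(x, y, z, f, image_size_half, norm_f=1., norm_z=0., norm_f0=0., offset_z=0.):
    f_eff = norm_f0 + norm_f * f
    z_eff = max(norm_z + z, 1e-6)  # same guard against z == 0 as above
    scale = f_eff / (z_eff * image_size_half)
    return scale * x, scale * y, z_eff + offset_z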



================================================
FILE: nnutils/loss_utils.py
================================================
"""
Loss Utils.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
from . import geom_utils
import numpy as np
from ..smal_model.batch_lbs import batch_rodrigues

def texture_dt_loss(texture_flow, dist_transf, vis_rend=None, cams=None, verts=None, tex_pred=None):
    """
    texture_flow: B x F x T x T x 2
    (In normalized coordinate [-1, 1])
    dist_transf: B x 1 x N x N

    Similar to geom_utils.sample_textures
    But instead of sampling image, it samples dt values.
    """
    # Reshape into B x F x T*T x 2
    T = texture_flow.size(-2)
    F = texture_flow.size(1)
    flow_grid = texture_flow.view(-1, F, T * T, 2)
    # B x 1 x F x T*T
    dist_transf = torch.nn.functional.grid_sample(dist_transf, flow_grid)

    if vis_rend is not None:
        # Visualize the error!
        # B x 3 x F x T*T
        dts = dist_transf.repeat(1, 3, 1, 1)
        # B x 3 x F x T x T
        dts = dts.view(-1, 3, F, T, T)
        # B x F x T x T x 3
        dts = dts.permute(0, 2, 3, 4, 1)
        dts = dts.unsqueeze(4).repeat(1, 1, 1, 1, T, 1) / dts.max()

        from ..utils import smal_vis
        for i in range(dist_transf.size(0)):
            rend_dt = vis_rend(verts[i], cams[i], dts[i])
            rend_img = smal_vis.tensor2im(tex_pred[i].data)            
            import matplotlib.pyplot as plt
            plt.ion()
            fig=plt.figure(1)
            plt.clf()
            ax = fig.add_subplot(121)
            ax.imshow(rend_dt)
            ax = fig.add_subplot(122)
            ax.imshow(rend_img)
            import pdb; pdb.set_trace()

    return dist_transf.mean()


def texture_loss(img_pred, img_gt, mask_pred, mask_gt):
    """
    Input:
      img_pred, img_gt: B x 3 x H x W
      mask_pred, mask_gt: B x H x W
    """
    mask_pred = mask_pred.unsqueeze(1)
    mask_gt = mask_gt.unsqueeze(1)

    return torch.nn.L1Loss()(img_pred * mask_pred, img_gt * mask_gt)

def uv_flow_loss(uv_flow_pred, uv_flow_gt_w_mask):
    """
    Input:
      uv_flow_pred: B x 2 x H x W
      uv_flow_gt_w_mask: B x 3 x H x W
    """
    # We only have info for the uv_flow for the points where mask is one
    mask = uv_flow_gt_w_mask[:,2,:,:].unsqueeze(1)
    uv_flow_gt = uv_flow_gt_w_mask[:,:2,:,:]
    return torch.nn.L1Loss()(uv_flow_pred*mask, uv_flow_gt*mask)

def mask_loss(mask_pred, mask_gt):
    """
    Input:
      mask_pred: B x H x W
      mask_gt: B x H x W
    """

    return torch.nn.L1Loss()(mask_pred, mask_gt)

def delta_v_loss(delta_v, delta_v_gt):
    criterion = torch.nn.MSELoss()
    return criterion(delta_v, delta_v_gt)

def texture_map_loss(texture_map_pred, texture_map_gt, texture_map_mask, opts=None):
    """
    Input:
      texture_map_pred: B x 3 x tH x tW
      texture_gt: B x 3 x tH x tW
      texture_map_mask: tH x tW
    """
    mask = texture_map_mask[None,:,:,0]
    texture_map_pred = texture_map_pred*mask
    texture_map_gt = texture_map_gt*mask
    if opts.white_balance_for_texture_map:
        # do gray world normalization
        N = torch.sum(mask)
        B = texture_map_pred.shape[0]
        # gray values
        g_pred = torch.sum(texture_map_pred.view(B,3,-1),dim=2)/N
        g_gt = torch.sum(texture_map_gt.view(B,3,-1),dim=2)/N

        texture_map_pred = texture_map_pred / (g_pred.unsqueeze_(-1).unsqueeze_(-1))
        texture_map_gt = texture_map_gt / (g_gt.unsqueeze_(-1).unsqueeze_(-1))
    
    return torch.nn.L1Loss()(texture_map_pred, texture_map_gt)
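
# Sketch (illustrative only; the helper name is hypothetical): the gray-world
# normalization used above when opts.white_balance_for_texture_map is set.
# Each channel is divided by its mean over the masked texels (sum / N, with N
# the number of masked texels), so every channel then averages 1 inside the mask.
# Plain Python over a list of (r, g, b) texels and a 0/1 mask.
def _gray_world(texels, mask):
    n = float(sum(mask))
    # Zero out texels outside the mask, as the torch code does with *mask
    masked = [tuple(c * m for c in t) for t, m in zip(texels, mask)]
    means = [sum(t[ch] for t in masked) / n for ch in range(3)]
    return [tuple(t[ch] / means[ch] for ch in range(3)) for t in masked]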
    
def camera_loss(cam_pred, cam_gt, margin, normalized):
    """
    cam_* are B x 7, [sc, tx, ty, quat]
    Losses are in similar magnitude so one margin is ok.
    """

    # Only the first element as the rest is fixed
    if normalized:
        criterion = torch.nn.MSELoss()
        return criterion(cam_pred[:, 0], cam_gt[:, 0])
    else:
        st_loss = ((cam_pred[:, 0] - cam_gt[:, 0])/1e3)**2
        return st_loss.mean()

def model_trans_loss(trans_pred, trans_gt):
    """
    trans_pred: B x 3
    trans_gt: B x 3
    """
    criterion = torch.nn.MSELoss()
    return criterion(trans_pred, trans_gt)

def model_pose_loss(pose_pred, pose_gt, opts):
    """
    pose_pred: B x 115
    pose_gt: B x 115
    """
    if opts.use_pose_geodesic_loss:
        # Convert each angle in 
        R = torch.reshape( batch_rodrigues(torch.reshape(pose_pred, [-1, 3]), opts=opts), [-1, 35, 3, 3])
        # Loss is acos((tr(R'R)-1)/2)
        Rgt = torch.reshape( batch_rodrigues(torch.reshape(pose_gt, [-1, 3]), opts=opts), [-1, 35, 3, 3])
        RT = R.permute(0,1,3,2)
        A = torch.matmul(RT.view(-1,3,3),Rgt.view(-1,3,3))
        # torch.trace works only for 2D tensors

        n = A.shape[0]
        po_loss =  0    
        eps = 1e-7
        for i in range(A.shape[0]):
            T = (torch.trace(A[i,:,:])-1)/2.
            po_loss += torch.acos(torch.clamp(T, -1 + eps, 1-eps))
        # n already equals B*35, so this is the mean geodesic angle scaled by 1/35
        po_loss = po_loss/(n*35)
        return po_loss
    else:
        criterion = torch.nn.MSELoss()
        return criterion(pose_pred, pose_gt)
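
# Sketch (illustrative only; helper names are hypothetical): the geodesic
# distance used above, acos((tr(R^T R_gt) - 1) / 2), for a single pair of 3x3
# rotation matrices in plain Python. For two rotations about the same axis it
# reduces to the absolute difference of their angles. Because of the eps clamp,
# identical rotations give a small positive value (about sqrt(2 * eps)), not 0.
import math

def _rot_z(theta):
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s, 0.], [s, c, 0.], [0., 0., 1.]]

def _geodesic_angle(R, R_gt, eps=1e-7):
    # tr(R^T R_gt) equals the sum of the element-wise products R[i][j] * R_gt[i][j]
    tr = sum(R[i][j] * R_gt[i][j] for i in range(3) for j in range(3))
    t = (tr - 1.) / 2.
    return math.acos(min(max(t, -1 + eps), 1 - eps))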
    


def betas_loss(betas_pred, betas_gt=None, prec=None):
    """
    betas_pred: B x 10
    """
    if betas_gt is None:
        if prec is None:
            b_loss = betas_pred**2
        else:
            b_loss = betas_pred*prec
        return b_loss.mean()
    else:
        criterion = torch.nn.MSELoss()
        return criterion(betas_pred, betas_gt)


def hinge_loss(loss, margin):
    # Only penalize if loss > margin
    return torch.clamp(loss - margin, min=0)

def kp_l2_loss(kp_pred, kp_gt):
    """
    L2 loss between visible keypoints.

    \Sum_i [0.5 * vis[i] * (kp_gt[i] - kp_pred[i])^2] / (|vis|)
    """
    criterion = torch.nn.MSELoss()

    vis = (kp_gt[:, :, 2, None] > 0).float()

    # This always has to be (output, target), not (target, output)
    return criterion(vis * kp_pred, vis * kp_gt[:, :, :2])
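
# Sketch (illustrative only; the helper name is hypothetical): why the loss above
# is not an average over visible keypoints. Masking with vis zeroes the invisible
# entries, but MSELoss still divides by every entry (B x K x 2). Plain Python, for
# one image with keypoints given as (x, y) pairs and a 0/1 visibility list.
def _masked_mse_two_ways(kp_pred, kp_gt, vis):
    sq = [v * ((px - gx) ** 2 + (py - gy) ** 2)
          for (px, py), (gx, gy), v in zip(kp_pred, kp_gt, vis)]
    mean_all_entries = sum(sq) / (2 * len(vis))  # what MSELoss computes
    mean_visible = sum(sq) / (2 * sum(vis))      # mean over visible coordinates
    return mean_all_entries, mean_visible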

def keypoints_2D_loss(kp_pred, kp_gt):
    criterion = torch.nn.MSELoss()

    vis = (kp_gt[:, :, 2, None] > 0).float()

    return criterion(vis * kp_pred, vis * kp_gt[:, :, :2])

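def MSE_texture_loss(img_pred, img_gt, mask_pred, mask_gt, background_imgs=None):
    mask_pred = mask_pred.unsqueeze(1)
    if mask_gt is None:
        # Composite the prediction over the background where the predicted mask is 0
        M = torch.abs(mask_pred - 1.)
        img_pred = img_pred*mask_pred + background_imgs*M

        dist = torch.nn.MSELoss()(img_pred*mask_pred, img_gt*mask_pred)

        return dist

    # Otherwise compare only inside the ground-truth mask, as PerceptualTextureLoss does
    mask_gt = mask_gt.unsqueeze(1)
    return torch.nn.MSELoss()(img_pred * mask_gt, img_gt * mask_gt)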

class PerceptualTextureLoss(object):
    def __init__(self):
        from ..nnutils.perceptual_loss import PerceptualLoss
        self.perceptual_loss = PerceptualLoss()

    def __call__(self, img_pred, img_gt, mask_pred, mask_gt, background_imgs=None):
        """
        Input:
          img_pred, img_gt: B x 3 x H x W
        mask_pred, mask_gt: B x H x W
        """
        mask_pred = mask_pred.unsqueeze(1)
        
        # Add a background to img_pred. This is used for the optimization without the groundtruth mask, but could
        # be also used for the regular training
        if mask_gt is None:
            img_pred = img_pred*mask_pred + background_imgs*(torch.abs(mask_pred - 1.))

            dist = self.perceptual_loss(img_pred, img_gt)
            return dist.mean()
        
        mask_gt = mask_gt.unsqueeze(1)
        # Only use mask_gt..
        dist = self.perceptual_loss(img_pred * mask_gt, img_gt * mask_gt)
        return dist.mean()


================================================
FILE: nnutils/net_blocks.py
================================================
'''
CNN building blocks.
Taken from https://github.com/shubhtuls/factored3d/
'''
from __future__ import division
from __future__ import print_function
import torch
import torch.nn as nn
import math

class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size()[0], -1)

class Unsqueeze(nn.Module):
    def __init__(self, dim):
        super(Unsqueeze, self).__init__()
        self.dim = dim

    def forward(self, x):
        return x.unsqueeze(self.dim)

## fc layers
def fc(norm_type, nc_inp, nc_out):
    if norm_type == 'batch':
        return nn.Sequential(
            nn.Linear(nc_inp, nc_out, bias=True),
            nn.BatchNorm1d(nc_out),
            nn.LeakyReLU(0.2,inplace=True)
        )
    else:
        return nn.Sequential(
            nn.Linear(nc_inp, nc_out),
            nn.LeakyReLU(0.1,inplace=True)
        )

def fc_stack(nc_inp, nc_out, nlayers, norm_type='batch'):
    modules = []
    for l in range(nlayers):
        modules.append(fc(norm_type, nc_inp, nc_out))
        nc_inp = nc_out
    encoder = nn.Sequential(*modules)
    net_init(encoder)
    return encoder

def fc_stack_dropout(nc_inp, nc_out, nlayers): 
    modules = []
    modules.append(nn.Linear(nc_inp, 1024, bias=True))
    modules.append(nn.ReLU())
    modules.append(nn.Dropout())
    modules.append(nn.Linear(1024, 1024, bias=True))
    modules.append(nn.ReLU())
    modules.append(nn.Dropout())
    modules.append(nn.Linear(1024, nc_out, bias=True))

    encoder = nn.Sequential(*modules)
    net_init(encoder)
    nl = 1
    for m in encoder.modules():
        if isinstance(m, nn.Linear):
            if nl == nlayers:
                torch.nn.init.xavier_normal_(m.weight, gain=0.01)
            else:
                torch.nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                m.bias.data.zero_()
            nl += 1

    return encoder

## 2D convolution layers
def conv2d(norm_type, in_planes, out_planes, kernel_size=3, stride=1, num_groups=2):
    if norm_type == 'batch':
        return nn.Sequential(
            nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=True),
            nn.BatchNorm2d(out_planes),
            nn.LeakyReLU(0.2,inplace=True)
        )
    elif norm_type == 'group':
        return nn.Sequential(
            nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=True),
            nn.GroupNorm(num_groups, out_planes),
            nn.LeakyReLU(0.2,inplace=True)
        )
    else:
        return nn.Sequential(
            nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=True),
            nn.LeakyReLU(0.2,inplace=True)
        )


def deconv2d(in_planes, out_planes):
    return nn.Sequential(
        nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=True),
        nn.LeakyReLU(0.2,inplace=True)
    )


def upconv2d(in_planes, out_planes, mode='bilinear'):
    if mode == 'nearest':
        print('Using NN upsample!!')
    upconv = nn.Sequential(
        nn.Upsample(scale_factor=2, mode=mode),
        nn.ReflectionPad2d(1),
        nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, padding=0),
        nn.LeakyReLU(0.2,inplace=True)
    )
    return upconv


def decoder2d(nlayers, nz_shape, nc_input, norm_type='batch', nc_final=1, nc_min=8, nc_step=1, init_fc=True, use_deconv=False, upconv_mode='bilinear', num_groups=2):
    ''' Simple 2D decoder with nlayers.

    Args:
        nlayers: number of decoder layers
        nz_shape: size of the bottleneck (input feature) vector
        nc_input: number of channels to start upconvolution from
        norm_type: normalization type ('batch', 'group', or anything else for none)
        nc_final: number of output channels
        nc_min: number of min channels
        nc_step: halve number of channels every nc_step layers
        init_fc: initial features are not spatial, use an fc & unsqueezing to make them 2D
    '''
    modules = []
    if init_fc:
        modules.append(fc('batch', nz_shape, nc_input))
        # B x C -> B x C x 1 x 1
        for d in range(2):
            modules.append(Unsqueeze(2))
    nc_output = nc_input
    for nl in range(nlayers):
        if (nl % nc_step==0) and (nc_output//2 >= nc_min):
            nc_output = nc_output//2
        if use_deconv:
            print('Using deconv decoder!')
            modules.append(deconv2d(nc_input, nc_output))
            nc_input = nc_output
            modules.append(conv2d(norm_type, nc_input, nc_output, num_groups=num_groups//2))
        else:
            modules.append(upconv2d(nc_input, nc_output, mode=upconv_mode))
            nc_input = nc_output
            modules.append(conv2d(norm_type, nc_input, nc_output, num_groups=num_groups//2))

    modules.append(nn.Conv2d(nc_output, nc_final, kernel_size=3, stride=1, padding=1, bias=True))
    decoder = nn.Sequential(*modules)
    net_init(decoder)
    return decoder
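
# Sketch (illustrative only; the helper name is hypothetical): the channel
# schedule that decoder2d follows. Each layer upsamples by 2 and halves the
# channel count every nc_step layers, without going below nc_min, so starting
# from a 1 x 1 map the output is 2**nlayers pixels wide.
def _decoder2d_channels(nlayers, nc_input, nc_min=8, nc_step=1):
    widths = []
    nc_output = nc_input
    for nl in range(nlayers):
        if (nl % nc_step == 0) and (nc_output // 2 >= nc_min):
            nc_output = nc_output // 2
        widths.append(nc_output)
    return widths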


## 3D convolution layers
def conv3d(norm_type, in_planes, out_planes, kernel_size=3, stride=1, num_groups=2):
    if norm_type == 'batch':
        return nn.Sequential(
            nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=True),
            nn.BatchNorm3d(out_planes),
            nn.LeakyReLU(0.2,inplace=True)
        )
    elif norm_type == 'group':
        return nn.Sequential(
            nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=True),
            nn.GroupNorm(num_groups, out_planes),
            nn.LeakyReLU(0.2,inplace=True)
        )
    else:
        return nn.Sequential(
            nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=True),
            nn.LeakyReLU(0.2,inplace=True)
        )


def deconv3d(norm_type, in_planes, out_planes, num_groups=2):
    if norm_type == 'batch':
        return nn.Sequential(
            nn.ConvTranspose3d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=True),
            nn.BatchNorm3d(out_planes),
            nn.LeakyReLU(0.2,inplace=True)
        )
    elif norm_type == 'group':
        return nn.Sequential(
            nn.ConvTranspose3d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=True),
            nn.GroupNorm(num_groups, out_planes),
            nn.LeakyReLU(0.2,inplace=True)
        )
    else:        
        return nn.Sequential(
            nn.ConvTranspose3d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=True),
            nn.LeakyReLU(0.2,inplace=True)
        )


## 3D Network Modules
def encoder3d(nlayers, norm_type='batch', nc_input=1, nc_max=128, nc_l1=8, nc_step=1, nz_shape=20):
    ''' Simple 3D encoder with nlayers.
    
    Args:
        nlayers: number of encoder layers
        norm_type: normalization type ('batch', 'group', or anything else for none)
        nc_input: number of input channels
        nc_max: number of max channels
        nc_l1: number of channels in layer 1
        nc_step: double number of channels every nc_step layers      
        nz_shape: size of bottleneck layer
    '''
    modules = []
    nc_output = nc_l1
    for nl in range(nlayers):
        if (nl>=1) and (nl%nc_step==0) and (nc_output <= nc_max*2):
            nc_output *= 2

        modules.append(conv3d(norm_type, nc_input, nc_output, stride=1))
        nc_input = nc_output
        modules.append(conv3d(norm_type, nc_input, nc_output, stride=1))
        modules.append(torch.nn.MaxPool3d(kernel_size=2, stride=2))

    modules.append(Flatten())
    modules.append(fc_stack(nc_output, nz_shape, 2, norm_type))
    encoder = nn.Sequential(*modules)
    net_init(encoder)
    return encoder, nc_output


def decoder3d(nlayers, nz_shape, nc_input, norm_type='batch', nc_final=1, nc_min=8, nc_step=1, init_fc=True):
    ''' Simple 3D decoder with nlayers.

    Args:
        nlayers: number of decoder layers
        nz_shape: size of the bottleneck (input feature) vector
        nc_input: number of channels to start upconvolution from
        norm_type: normalization type ('batch', 'group', or anything else for none)
        nc_final: number of output channels
        nc_min: number of min channels
        nc_step: halve number of channels every nc_step layers
        init_fc: initial features are not spatial, use an fc & unsqueezing to make them 3D
    '''
    modules = []
    if init_fc:
        modules.append(fc('batch', nz_shape, nc_input))
        for d in range(3):
            modules.append(Unsqueeze(2))
    nc_output = nc_input
    for nl in range(nlayers):
        if (nl%nc_step==0) and (nc_output//2 >= nc_min):
            nc_output = nc_output//2

        modules.append(deconv3d(norm_type, nc_input, nc_output))
        nc_input = nc_output
        modules.append(conv3d(norm_type, nc_input, nc_output))

    modules.append(nn.Conv3d(nc_output, nc_final, kernel_size=3, stride=1, padding=1, bias=True))
    decoder = nn.Sequential(*modules)
    net_init(decoder)
    return decoder


def net_init(net):
    for m in net.modules():
        if isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.02)
            if m.bias is not None:
                m.bias.data.zero_()

        if isinstance(m, nn.Conv2d): #or isinstance(m, nn.ConvTranspose2d):
            m.weight.data.normal_(0, 0.02)
            if m.bias is not None:
                m.bias.data.zero_()

        if isinstance(m, nn.ConvTranspose2d):
            # Initialize Deconv with bilinear weights.
            base_weights = bilinear_init(m.weight.data.size(-1))
            base_weights = base_weights.unsqueeze(0).unsqueeze(0)
            m.weight.data = base_weights.repeat(m.weight.data.size(0), m.weight.data.size(1), 1, 1)
            if m.bias is not None:
                m.bias.data.zero_()

        if isinstance(m, nn.Conv3d) or isinstance(m, nn.ConvTranspose3d):
            m.weight.data.normal_(0, 0.02)
            if m.bias is not None:
                m.bias.data.zero_()

        elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm3d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()


def bilinear_init(kernel_size=4):
    # Following Caffe's BilinearUpsamplingFiller
    # https://github.com/BVLC/caffe/pull/2213/files
    import numpy as np
    width = kernel_size
    height = kernel_size
    f = int(np.ceil(width / 2.))
    cc = (2 * f - 1 - f % 2) / (2.*f)
    weights = torch.zeros((height, width))
    for y in range(height):
        for x in range(width):
            weights[y, x] = (1 - np.abs(x / f - cc)) * (1 - np.abs(y / f - cc))

    return weights
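
# Sketch (illustrative only; the helper name is hypothetical): the kernel built
# by bilinear_init is separable, weights[y, x] = w[y] * w[x], where w is the 1-D
# triangle filter below. For kernel_size=4 it is [0.25, 0.75, 0.75, 0.25].
import math

def _bilinear_1d(kernel_size=4):
    f = int(math.ceil(kernel_size / 2.))
    cc = (2 * f - 1 - f % 2) / (2. * f)
    return [1 - abs(x / f - cc) for x in range(kernel_size)]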


if __name__ == '__main__':
    decoder2d(5, None, 256, use_deconv=True, init_fc=False)
    bilinear_init()


================================================
FILE: nnutils/nmr.py
================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import scipy.misc
import tqdm

import chainer
import torch

import neural_renderer

from ..nnutils import geom_utils

#############
### Utils ###
#############
def convert_as(src, trg):
    src = src.type_as(trg)
    if src.is_cuda:
        src = src.cuda(device=trg.get_device())
    return src

########################################################################
############ Wrapper class for the chainer Neural Renderer #############
##### All functions must only use numpy arrays as inputs/outputs #######
########################################################################
class NMR(object):
    def __init__(self):
        # setup renderer
        renderer = neural_renderer.Renderer()
        self.renderer = renderer

    def to_gpu(self, device=0):
        # self.renderer.to_gpu(device)
        self.cuda_device = device

    def forward_mask(self, vertices, faces):
        ''' Renders masks.
        Args:
            vertices: B X N X 3 numpy array
            faces: B X F X 3 numpy array
        Returns:
            masks: B X 256 X 256 numpy array
        '''
        self.faces = chainer.Variable(chainer.cuda.to_gpu(faces, self.cuda_device))
        self.vertices = chainer.Variable(chainer.cuda.to_gpu(vertices, self.cuda_device))

        self.masks = self.renderer.render_silhouettes(self.vertices, self.faces)

        masks = self.masks.data.get()
        return masks
    
    def backward_mask(self, grad_masks):
        ''' Compute gradient of vertices given mask gradients.
        Args:
            grad_masks: B X 256 X 256 numpy array
        Returns:
            grad_vertices: B X N X 3 numpy array
        '''
        self.masks.grad = chainer.cuda.to_gpu(grad_masks, self.cuda_device)
        self.masks.backward()
        return self.vertices.grad.get()

    def forward_img(self, vertices, faces, textures):
        ''' Renders images.
        Args:
            vertices: B X N X 3 numpy array
            faces: B X F X 3 numpy array
            textures: B X F X T X T X T X 3 numpy array
        Returns:
            images: B X 3 x 256 X 256 numpy array
        '''
        self.faces = chainer.Variable(chainer.cuda.to_gpu(faces, self.cuda_device))
        self.vertices = chainer.Variable(chainer.cuda.to_gpu(vertices, self.cuda_device))
        self.textures = chainer.Variable(chainer.cuda.to_gpu(textures, self.cuda_device))
        self.images = self.renderer.render(self.vertices, self.faces, self.textures)

        images = self.images.data.get()
        return images


    def backward_img(self, grad_images):
        ''' Compute gradient of vertices given image gradients.
        Args:
            grad_images: B X 3? X 256 X 256 numpy array
        Returns:
            grad_vertices: B X N X 3 numpy array
            grad_textures: B X F X T X T X T X 3 numpy array
        '''
        self.images.grad = chainer.cuda.to_gpu(grad_images, self.cuda_device)
        self.images.backward()
        return self.vertices.grad.get(), self.textures.grad.get()

########################################################################
############### Wrapper class for a rendering PythonOp #################
##### All functions must only use torch Tensors as inputs/outputs ######
########################################################################
class Render(torch.autograd.Function):
    # TODO(Shubham): Make sure the outputs/gradients are on the GPU
    def __init__(self, renderer):
        super(Render, self).__init__()
        self.renderer = renderer

    def forward(self, vertices, faces, textures=None):
        # B x N x 3
        # Flipping the y-axis here to make it align with the image coordinate system!
        vs = vertices.cpu().numpy()
        vs[:, :, 1] *= -1
        fs = faces.cpu().numpy()
        if textures is None:
            self.mask_only = True
            masks = self.renderer.forward_mask(vs, fs)
            return convert_as(torch.Tensor(masks), vertices)
        else:
            self.mask_only = False
            ts = textures.cpu().numpy()
            imgs = self.renderer.forward_img(vs, fs, ts)
            return convert_as(torch.Tensor(imgs), vertices)

    def backward(self, grad_out):
        g_o = grad_out.cpu().numpy()
        if self.mask_only:
            grad_verts = self.renderer.backward_mask(g_o)
            grad_verts = convert_as(torch.Tensor(grad_verts), grad_out)
            grad_tex = None
        else:
            grad_verts, grad_tex = self.renderer.backward_img(g_o)
            grad_verts = convert_as(torch.Tensor(grad_verts), grad_out)
            grad_tex = convert_as(torch.Tensor(grad_tex), grad_out)

        grad_verts[:, :, 1] *= -1
        return grad_verts, None, grad_tex


########################################################################
############## Wrapper torch module for Neural Renderer ################
########################################################################
class NeuralRenderer(torch.nn.Module):
    """
    This is the core pytorch function to call.
    Every torch NMR has a chainer NMR.
    Only fwd/bwd once per iteration.
    """
    def __init__(self, img_size=256, proj_type='perspective', norm_f=1., norm_z=0.,norm_f0=0.):
        super(NeuralRenderer, self).__init__()
        self.renderer = NMR()

        self.norm_f = norm_f
        self.norm_f0 = norm_f0
        self.norm_z = norm_z

        # Adjust the core renderer
        self.renderer.renderer.image_size = img_size
        self.renderer.renderer.perspective = False

        # Set the default camera eye at (0, 0, -1)
        self.renderer.renderer.eye = [0, 0, -1.0]

        # Make it a bit brighter for vis
        self.renderer.renderer.light_intensity_ambient = 0.8

        self.renderer.to_gpu()

        # Silvia
        if proj_type == 'perspective':
            self.proj_fn = geom_utils.perspective_proj_withz
        else:
            raise ValueError('unknown projection type: %s' % proj_type)

        self.offset_z = -1.0

    def ambient_light_only(self):
        # Make light only ambient.
        self.renderer.renderer.light_intensity_ambient = 1
        self.renderer.renderer.light_intensity_directional = 0

    def directional_light_only(self):
        # Make light only directional.
        self.renderer.renderer.light_intensity_ambient = 0.8
        self.renderer.renderer.light_intensity_directional = 0.8
        self.renderer.renderer.light_direction = [0, 1, 0]  # up-to-down, this is the default

    def set_bgcolor(self, color):
        self.renderer.renderer.background_color = color

    def project_points(self, verts, cams):
        proj = self.proj_fn(verts, cams, offset_z=self.offset_z, norm_f=self.norm_f, norm_z=self.norm_z, norm_f0=self.norm_f0)
        return proj[:, :, :2]

    def forward(self, vertices, faces, cams, textures=None):
        verts = self.proj_fn(vertices, cams, offset_z=self.offset_z, norm_f=self.norm_f, norm_z=self.norm_z, norm_f0=self.norm_f0)

        if textures is not None:
            return Render(self.renderer)(verts, faces, textures)
        else:
            return Render(self.renderer)(verts, faces)




================================================
FILE: nnutils/perceptual_loss.py
================================================
"""
Calls Richard's Perceptual Loss.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
from torch.autograd import Variable
from ..external.PerceptualSimilarity.models import dist_model


class PerceptualLoss(object):
    def __init__(self, model='net', net='alex', use_gpu=True):
        print('Setting up Perceptual loss..')
        self.model = dist_model.DistModel()
        self.model.initialize(model=model, net=net, use_gpu=use_gpu)
        print('Done')

    def __call__(self, pred, target, normalize=True):
        """
        Pred and target are Variables.
        If normalize is on, scales images between [-1, 1]
        Assumes the inputs are in range [0, 1].
        """
        if normalize:
            target = 2 * target - 1
            pred = 2 * pred - 1

        dist = self.model.forward_pair(target, pred)

        return dist

================================================
FILE: nnutils/smal_mesh_eval.py
================================================
"""
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import flags, app
import numpy as np
import skimage.io as io

import pickle as pkl

import torch
import scipy
import scipy.misc
from ..nnutils import smal_predictor as pred_util
from ..utils import image as img_util
from glob import glob
import scipy.io as sio
import torchvision
from torch.autograd import Variable


opts = flags.FLAGS


def preprocess_image(img_path, img_size=256, kp=None, border=20):
    img = io.imread(img_path) / 255.
    img = img[:,:,:3]

    # Scale the max image size to be img_size
    scale_factor = float(img_size-2*border) / np.max(img.shape[:2])
    img, _ = img_util.resize_img(img, scale_factor)

    # Crop img_size x img_size from the center
    center = np.round(np.array(img.shape[:2]) / 2).astype(int)
    # img center in (x, y)
    center = center[::-1]
    bbox = np.hstack([center - img_size / 2., center + img_size / 2.])

    img = img_util.crop(img, bbox, bgval=None)
    img, _ = img_util.resize_img(img, 256/257.)

    # Transpose the image to 3xHxW
    img = np.transpose(img, (2, 0, 1))

    if kp is not None:
        kp = kp*scale_factor
        kp[:,0] -= bbox[0]
        kp[:,1] -= bbox[1]

    return img, kp


def smal_mesh_eval(num_train_epoch):
    """Return the mean keypoint error (in pixels) over the validation images.

    Images are evaluated one at a time; opts.batch_size is restored afterwards.
    """
    opts.num_train_epoch = num_train_epoch
    img_dir = 'smalst/validation_set/'
    images = sorted(glob(img_dir + '*.jpg'))
    N = len(images)
    if N == 0:
        raise IOError('No validation images found in {}'.format(img_dir))

    batch_size = opts.batch_size
    opts.batch_size = 1
    try:
        predictor = pred_util.MeshPredictor(opts)
        tot_pose_err = 0

        err_tot = np.zeros((N))
        for idx, img_path in enumerate(images):

            anno_path = img_path.replace('.jpg', '_ferrari-tail-face.mat')
            res = sio.loadmat(anno_path, squeeze_me=True, struct_as_record=False)
            res = res['annotation']
            kp = res.kp.astype(float)
            invisible = res.invisible
            vis = np.atleast_2d(~invisible.astype(bool)).T
            landmarks = np.hstack((kp, vis))
            names = [str(res.names[i]) for i in range(len(res.names))]

            img, kp = preprocess_image(img_path, img_size=opts.img_size, kp=kp)

            batch = {'img': torch.Tensor(np.expand_dims(img, 0))}

            outputs = predictor.predict(batch)

            # Map predicted keypoints from normalized [-1, 1] coordinates to pixels
            kp_pred = ((outputs['kp_pred'].cpu().detach().numpy()[0,:,:]+1.)*(opts.img_size/2.)).astype(int)
            # Sum of |dx| + |dy| over the visible keypoints, divided by their count
            kp_err = np.sum(np.abs(kp - kp_pred)*vis)/np.sum(vis)
            tot_pose_err += kp_err
            err_tot[idx] = kp_err
    finally:
        opts.batch_size = batch_size
    return tot_pose_err/N
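
# Hypothetical sketch (not part of the original SMALST code): the per-image
# keypoint error used in smal_mesh_eval, as a standalone function. Predictions
# are assumed to be in normalized [-1, 1] image coordinates, as implied by how
# smal_mesh_eval converts kp_pred back to pixels. Unlike the loop above, this
# sketch skips the integer truncation of the predicted pixels and returns NaN
# when no keypoint is visible instead of dividing by zero.
import numpy as np


def _example_keypoint_error(kp_gt, kp_pred_norm, vis, img_size=256):
    """Sum of |dx| + |dy| over visible keypoints, divided by their count."""
    kp_pred = (np.asarray(kp_pred_norm, dtype=float) + 1.) * (img_size / 2.)
    # N x 1 visibility mask, broadcast over the x and y columns
    vis = np.asarray(vis, dtype=float).reshape(-1, 1)
    n_vis = vis.sum()
    if n_vis == 0:
        return float('nan')
    return float(np.sum(np.abs(np.asarray(kp_gt, dtype=float) - kp_pred) * vis) / n_vis)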

def set_input(self, batch):
    resnet_transform = torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    opts = self.opts

    img_tensor = batch['img'].clone().type(torch.FloatTensor)

    input_img_tensor = batch['img'].type(torch.FloatTensor)
    for b in range(input_img_tensor.size(0)):
        input_img_tensor[b] = resnet_transform(input_img_tensor[b])

    self.input_imgs = Variable(
        input_img_tensor.cuda(device=opts.gpu_id), requires_grad=False)
    self.imgs = Variable(
        img_tensor.cuda(device=opts.gpu_id), requires_grad=False)
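
# Hypothetical sketch (not part of the original SMALST code): what
# torchvision.transforms.Normalize does to each image in set_input, written
# with numpy. Channel c becomes (x_c - mean_c) / std_c using the ImageNet
# statistics the pretrained ResNet expects; inputs are assumed to be in [0, 1].
# Because the statistics broadcast over H and W, a B x 3 x H x W batch can be
# normalized in one expression instead of a per-image loop.
import numpy as np

_IMAGENET_MEAN = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
_IMAGENET_STD = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)


def _example_imagenet_normalize(img_chw):
    """Normalize a 3 x H x W image (or a B x 3 x H x W batch) channel by channel."""
    return (np.asarray(img_chw, dtype=float) - _IMAGENET_MEAN) / _IMAGENET_STD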


def collect_outputs(self):
    outputs = {
            'pose_pred': self.pose.data,
            'kp_pred': self.kp_pred.data,
            'verts': self.pred_v.data,
            'kp_verts': self.kp_verts.data,
            'cam_pred': self.cam_pred.data,
            'mask_pred': self.mask_pred.data,
            'faces': self.faces,
            'delta_v_pred': self.delta_v.data,
            'trans_pred':self.trans.data,
            'kp_2D_pred':self.kp_2D_pred,
            'f':self.faces,
            'v':self.smal_verts
    }
    if self.opts.use_smal_betas:
        outputs['betas_pred'] = self.betas.data
    if self.opts.texture:
        outputs['texture'] = self.textures
        outputs['texture_pred'] = self.texture_pred.data
        outputs['uv_image'] = self.uv_images.data
        outputs['uv_flow'] = self.uv_flows.data
    if self.opts.predict_ambient_occlusion:
        outputs['occ_pred'] = self.occ_map

    return outputs
 


================================================
FILE: nnutils/smal_mesh_net.py
================================================
"""
Mesh net model.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app
from absl import flags
import os
import os.path as osp
import numpy as np
import pickle as pkl
import torch
import torchvision
import torch.nn as nn
from torch.autograd import Variable
from ..smal_model.smal_basics import load_smal_model
from ..smal_model.smal_torch import SMAL

from ..utils import mesh
from ..utils import geometry as geom_utils
from . import net_blocks as nb

#-------------- flags -------------#
#----------------------------------#
flags.DEFINE_boolean('is_optimization', False, 'set to True to do test-time optimization')
flags.DEFINE_boolean('is_refinement', False, 'set to True to do refinement')

flags.DEFINE_string('model_dir', 'smalst/smpl_models/', 'location of the SMAL model')
flags.DEFINE_string('model_name', 'my_smpl_00781_4_all.pkl', 'name of the model')
flags.DEFINE_boolean('symmetric', False, 'Use symmetric mesh or not')
flags.DEFINE_boolean('symmetric_texture', False, 'if true texture is symmetric!')
flags.DEFINE_integer('nz_feat', 1024, 'Encoded feature size')

flags.DEFINE_boolean('use_norm_f_and_z', True, 'if to use normalized f and z')
flags.DEFINE_float('camera_ref', 2700., 'expected focal length value')
flags.DEFINE_float('trans_ref', 19., 'expected model distance from camera (13 for linear)')
flags.DEFINE_float('norm_f', 2700., 'term in f=norm_f0+norm_f*x_f')  
flags.DEFINE_float('norm_f0', 2700., 'term in f=norm_f0+norm_f*x_f')  
flags.DEFINE_float('norm_z', 20., 'normalization term for depth')
flags.DEFINE_boolean('use_sym_idx', True, 'If to predict only half delta_v')

flags.DEFINE_bool('use_double_input', False, 'input the img and the fg')

flags.DEFINE_boolean('use_camera', True, 'if optimize camera focal length')
flags.DEFINE_boolean('use_delta_v', True, 'if predict vertex displacements')

flags.DEFINE_boolean('texture', True, 'if true uses texture!')
flags.DEFINE_boolean('texture_map', True, 'if true uses texture map loss!')
flags.DEFINE_boolean('use_directional_light', True, 'if using directional light rather than ambient')

flags.DEFINE_integer('num_betas', 20, 'Number of betas variables')
flags.DEFINE_boolean('use_smal_pose', True, 'if using articulated shape')
flags.DEFINE_boolean('use_smal_betas', False, 'if using smal shape space')

flags.DEFINE_integer('scale_bias', 1, '1 or 0 for bias in nn.Linear') # Does not work for 0
flags.DEFINE_boolean('fix_trans', False, 'do not optimize trans')

flags.DEFINE_integer('tex_size', 6, 'Texture resolution per face') 
flags.DEFINE_integer('texture_img_size', 256, 'Size in pixels of the square UV texture image')
flags.DEFINE_integer('number_of_textures', 4, 'Number of texture layers that compose the texture map')

flags.DEFINE_float('occlusion_map_scale', 1./16., 'division of the image')

flags.DEFINE_integer('bottleneck_size', 2048, 'Define bottleneck size')
flags.DEFINE_integer('channels_per_group', 16, 'number of channels per group in group normalization')

flags.DEFINE_integer('subdivide', 3, '# to subdivide icosahedron, 3=642verts, 4=2562 verts')

flags.DEFINE_boolean('use_deconv', False, 'If true uses Deconv')
flags.DEFINE_string('upconv_mode', 'bilinear', 'upsample mode')

flags.DEFINE_boolean('only_mean_sym', True, 'If true, only the meanshape is symmetric')

flags.DEFINE_boolean('use_resnet50', True, 'otherwise use resnet18')

flags.DEFINE_string('uv_data_file', 'my_smpl_00781_4_all_template_w_tex_uv_001.pkl', 'ft and vt data of the obj file')
flags.DEFINE_string('projection_type', 'perspective', 'camera projection type (orth or perspective)')

flags.DEFINE_integer('n_shape_feat', 40, 'number of shape features when we do not use the betas')

flags.DEFINE_float('depth_var', 2.0, 'see TransPredictor')
flags.DEFINE_float('x_var', 2.0, 'see TransPredictor')
flags.DEFINE_float('y_var', 1.0, 'see TransPredictor')
flags.DEFINE_float('pose_var', 1.0, 'scale applied to the PosePredictor output')


#------------- Modules ------------#
#----------------------------------#
class ResNetConv(nn.Module):
    def __init__(self, n_blocks=4, opts=None):
        super(ResNetConv, self).__init__()
        if opts.use_resnet50:
            self.resnet = torchvision.models.resnet50(pretrained=True)
        else:
            self.resnet = torchvision.models.resnet18(pretrained=True)
        self.n_blocks = n_blocks
        self.opts = opts
        if self.opts.use_double_input:
            self.fc = nb.fc_stack(512*16*8, 512*8*8, 2)

    def forward(self, x, y=None):
        if self.opts.use_double_input and y is not None:
            x = torch.cat([x, y], 2)
        n_blocks = self.n_blocks
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)

        if n_blocks >= 1:
            x = self.resnet.layer1(x)
        if n_blocks >= 2:
            x = self.resnet.layer2(x)
        if n_blocks >= 3:
            x = self.resnet.layer3(x)
        if n_blocks >= 4:
            x = self.resnet.layer4(x)
        if self.opts.use_double_input and y is not None:
            x = x.view(x.size(0), -1)
            x = self.fc.forward(x)
            x = x.view(x.size(0), 512, 8, 8)
            
        return x

class Encoder(nn.Module):
    """
    Current:
    Resnet with 4 blocks (x32 spatial dim reduction)
    Another conv with stride 2 (x64)
    This is sent to 2 fc layers with final output nz_feat.
    """

    def __init__(self, opts, input_shape, n_blocks=4, nz_feat=100, bott_size=256):
        super(Encoder, self).__init__()
        self.opts = opts
        self.resnet_conv = ResNetConv(n_blocks=n_blocks, opts=opts)
        num_norm_groups = bott_size//opts.channels_per_group
        if opts.use_resnet50:
            self.enc_conv1 = nb.conv2d('group', 2048, bott_size, stride=2, kernel_size=4, num_groups=num_norm_groups)
        else:
            self.enc_conv1 = nb.conv2d('group', 512, bott_size, stride=2, kernel_size=4, num_groups=num_norm_groups)

        nc_input = bott_size * (input_shape[0] // 64) * (input_shape[1] // 64)
        self.enc_fc = nb.fc_stack(nc_input, nz_feat, 2, 'batch')
        self.nenc_feat = nc_input

        nb.net_init(self.enc_conv1)

    def forward(self, img, fg_img):
        resnet_feat = self.resnet_conv.forward(img, fg_img)
        out_enc_conv1 = self.enc_conv1(resnet_feat)
        out_enc_conv1 = out_enc_conv1.view(img.size(0), -1)
        feat = self.enc_fc.forward(out_enc_conv1)
        return feat, out_enc_conv1
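
# Hypothetical sketch (not part of the original SMALST code): the size of the
# flattened encoder feature (nenc_feat). The ResNet trunk reduces H and W by 32,
# and enc_conv1 (kernel 4, stride 2) halves them again. The padding used by
# nb.conv2d is not shown in this file; padding=1 is an assumption that makes
# the result match nc_input = bott_size * (H // 64) * (W // 64).
def _example_conv_out(n, kernel_size, stride, padding):
    """Standard convolution output length: floor((n + 2p - k) / s) + 1."""
    return (n + 2 * padding - kernel_size) // stride + 1


def _example_encoder_flat_size(img_h, img_w, bott_size, padding=1):
    """Flattened size of enc_conv1's output for an img_h x img_w input."""
    h = _example_conv_out(img_h // 32, 4, 2, padding)
    w = _example_conv_out(img_w // 32, 4, 2, padding)
    return bott_size * h * w


# Usage: a 256 x 256 input with bottleneck_size=2048 gives 2048 * 4 * 4 = 32768.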


class TexturePredictorUV(nn.Module):
    """
    Outputs mesh texture
    """

    def __init__(self, nz_feat, uv_sampler, opts, img_H=64, img_W=128, n_upconv=5, nc_init=256, predict_flow=False, symmetric=False, num_sym_faces=624, tex_masks=None,vt=None, ft=None):
        super(TexturePredictorUV, self).__init__()
        self.opts = opts
        self.feat_H = img_H // (2 ** n_upconv)
        self.feat_W = img_W // (2 ** n_upconv)
        self.nc_init = nc_init
        self.symmetric = symmetric
        self.num_sym_faces = num_sym_faces
        self.F = uv_sampler.size(1)
        self.T = uv_sampler.size(2)
        self.predict_flow = predict_flow
        self.tex_masks = tex_masks

        # Reshape the UV sampler into the NMR format:
        # B x F x T x T x 2 --> B x F x T*T x 2
        self.uv_sampler = uv_sampler.view(-1, self.F, self.T*self.T, 2)

        if opts.number_of_textures > 0:
            self.enc = nn.ModuleList([nb.fc_stack(nz_feat, self.nc_init*self.feat_H[i]*self.feat_W[i], 2, 'batch') for i in range(opts.number_of_textures)])
        else:
            self.enc = nb.fc_stack(nz_feat, self.nc_init*self.feat_H*self.feat_W, 2, 'batch')
        
        if predict_flow:
            nc_final=2
        else:
            nc_final=3
        if opts.number_of_textures > 0:
            num_groups = nc_init//opts.channels_per_group
            self.decoder = nn.ModuleList([nb.decoder2d(n_upconv, None, nc_init, 
                norm_type='group', num_groups=num_groups, init_fc=False, nc_final=nc_final,
                use_deconv=opts.use_deconv, upconv_mode=opts.upconv_mode) for _ in range(opts.number_of_textures)])
            self.uvimage_pred_layer = [None]*opts.number_of_textures
        else:
            num_groups = nc_init//opts.channels_per_group
            self.decoder = nb.decoder2d(n_upconv, None, nc_init, norm_type='group',
                num_groups=num_groups, init_fc=False, nc_final=nc_final,
                use_deconv=opts.use_deconv, upconv_mode=opts.upconv_mode)

    def forward(self, feat):
        if self.opts.number_of_textures > 0:
            for i in range(self.opts.number_of_textures):
                # Decode each texture layer from the shared image feature
                layer_feat = self.enc[i].forward(feat)
                layer_feat = layer_feat.view(layer_feat.size(0), self.nc_init, self.feat_H[i], self.feat_W[i])
                # B x 2 or 3 x H x W
                self.uvimage_pred_layer[i] = torch.tanh(self.decoder[i].forward(layer_feat))

            # Compose the predicted texture maps by tiling the layers
            if self.opts.number_of_textures == 7:
                upper = torch.cat((self.uvimage_pred_layer[0], self.uvimage_pred_layer[1], self.uvimage_pred_layer[2]), 3)
                lower = torch.cat((self.uvimage_pred_layer[4], self.uvimage_pred_layer[5]), 3)
                right = torch.cat((self.uvimage_pred_layer[3], self.uvimage_pred_layer[6]), 2)
                self.uvimage_pred = torch.cat((torch.cat((upper, lower), 2), right), 3)
            elif self.opts.number_of_textures == 4:
                self.uvimage_pred = torch.cat((torch.cat((self.uvimage_pred_layer[0],
                    torch.cat((self.uvimage_pred_layer[1], self.uvimage_pred_layer[2]), 3)), 2), self.uvimage_pred_layer[3]), 3)
            else:
                raise ValueError('number_of_textures must be 4 or 7, got {}'.format(self.opts.number_of_textures))
        else:
            uvimage_pred = self.enc.forward(feat)
            uvimage_pred = uvimage_pred.view(uvimage_pred.size(0), self.nc_init, self.feat_H, self.feat_W)
            # B x 2 or 3 x H x W
            self.uvimage_pred = torch.tanh(self.decoder.forward(uvimage_pred))

        tex_pred = torch.nn.functional.grid_sample(self.uvimage_pred, self.uv_sampler)
        tex_pred = tex_pred.view(self.uvimage_pred.size(0), -1, self.F, self.T, self.T).permute(0, 2, 3, 4, 1)

        if self.symmetric:
            # Symmetrize.
            tex_left = tex_pred[:, -self.num_sym_faces:]
            return torch.cat([tex_pred, tex_left], 1)
        else:
            # Contiguous Needed after the permute..
            return tex_pred.contiguous()
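
# Hypothetical sketch (not part of the original SMALST code): the tiling used in
# TexturePredictorUV.forward, reproduced with numpy arrays of the layer sizes
# that MeshNet passes when texture_img_size == 256. It only checks shapes: the
# 4-layer and 7-layer layouts both tile a full 256 x 256 UV image. Two channels
# are used because MeshNet builds the predictor with predict_flow=True.
import numpy as np


def _example_compose_tiles(layers):
    """layers: list of B x C x H_i x W_i arrays, in the order used by the model."""
    if len(layers) == 4:
        middle = np.concatenate((layers[1], layers[2]), axis=3)
        left = np.concatenate((layers[0], middle), axis=2)
        return np.concatenate((left, layers[3]), axis=3)
    if len(layers) == 7:
        upper = np.concatenate(layers[0:3], axis=3)
        lower = np.concatenate((layers[4], layers[5]), axis=3)
        right = np.concatenate((layers[3], layers[6]), axis=2)
        return np.concatenate((np.concatenate((upper, lower), axis=2), right), axis=3)
    raise ValueError('expected 4 or 7 layers')


def _example_layers(heights, widths, batch=1, channels=2):
    """Zero-filled placeholder layers with the given heights and widths."""
    return [np.zeros((batch, channels, h, w)) for h, w in zip(heights, widths)]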



class ShapePredictor(nn.Module):
    """
    Outputs mesh deformations
    """

    def __init__(self, nz_feat, num_verts, opts, left_idx, right_idx, shapedirs):
        super(ShapePredictor, self).__init__()
        self.opts = opts
        self.num_verts = num_verts
        if opts.use_delta_v:
            if opts.use_sym_idx:
                # Predict displacements only for the left and center vertices;
                # forward() mirrors them onto the right side.
                self.left_idx = left_idx
                self.right_idx = right_idx
                B = shapedirs.reshape([shapedirs.shape[0], num_verts, 3])[:,left_idx]
                B = B.reshape([B.shape[0], -1])
                n_out = len(left_idx) * 3
            else:
                B = shapedirs
                n_out = num_verts * 3

            if opts.use_smal_betas:
                self.pred_layer = nn.Linear(nz_feat, n_out)
                # Initialize pred_layer weights to be small so initial deformations aren't too big
                self.pred_layer.weight.data.normal_(0, 0.0001)
            else:
                # Reduce the image feature to n_shape_feat coefficients and
                # initialize the linear map with the first SMAL shape directions.
                self.fc = nb.fc('batch', nz_feat, opts.n_shape_feat)
                n_feat = opts.n_shape_feat
                self.pred_layer = nn.Linear(n_feat, n_out)
                B = B.permute(1,0)
                A = torch.zeros(B.size(0), n_feat)
                n = min(B.size(1), n_feat)
                A[:,:n] = B[:,:n]
                self.pred_layer.weight.data.copy_(A)
                self.pred_layer.bias.data.fill_(0.)

        else:
            self.ref_delta_v = torch.Tensor(np.zeros((opts.batch_size,num_verts,3))).cuda(device=opts.gpu_id)


    def forward(self, feat):
        if not self.opts.use_smal_betas:
            # Project the image feature onto the shape coefficients
            feat = self.fc(feat)
            self.shape_f = feat
        if self.opts.use_sym_idx:
            delta_v = torch.Tensor(np.zeros((feat.size(0),self.num_verts,3))).cuda(device=self.opts.gpu_id)
            half_delta_v = self.pred_layer.forward(feat)
            half_delta_v = half_delta_v.view(half_delta_v.size(0), -1, 3)
            delta_v[:,self.left_idx,:] = half_delta_v
            # Mirror onto the right side by negating y (on a copy, leaving half_delta_v intact)
            mirrored = half_delta_v.clone()
            mirrored[:,:,1] = -1.*mirrored[:,:,1]
            delta_v[:,self.right_idx,:] = mirrored
        else:
            delta_v = self.pred_layer.forward(feat)
            # Make it B x num_verts x 3
            delta_v = delta_v.view(delta_v.size(0), -1, 3)
        # print('shape: ( Mean = {}, Var = {} )'.format(delta_v.mean().data[0], delta_v.var().data[0]))
        return delta_v
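
# Hypothetical sketch (not part of the original SMALST code): how ShapePredictor
# fills the full displacement field when use_sym_idx is set. Displacements are
# predicted only for left_idx (left + center vertices) and copied to right_idx
# (right + center vertices) with the y component negated, which implies y is
# the left/right axis of the template. Center vertices appear in both index
# lists, so they end up holding the mirrored copy.
import numpy as np


def _example_mirror_delta_v(half_delta_v, left_idx, right_idx, num_verts):
    """half_delta_v: B x len(left_idx) x 3 -> B x num_verts x 3."""
    half_delta_v = np.asarray(half_delta_v, dtype=float)
    delta_v = np.zeros((half_delta_v.shape[0], num_verts, 3))
    delta_v[:, left_idx, :] = half_delta_v
    mirrored = half_delta_v.copy()
    mirrored[:, :, 1] *= -1.
    delta_v[:, right_idx, :] = mirrored
    return delta_v


# Usage: vertex 0 is left, 1 is right, 2 is center.
# _example_mirror_delta_v([[[1., 2., 3.], [4., 5., 6.]]], [0, 2], [1, 2], 3)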

class PosePredictor(nn.Module):
    """
    """
    def __init__(self, opts, nz_feat, num_joints=35):
        super(PosePredictor, self).__init__()
        self.opts = opts
        self.num_joints = num_joints
        self.pred_layer = nn.Linear(nz_feat, num_joints*3)

    def forward(self, feat):
        pose = self.opts.pose_var*self.pred_layer.forward(feat)

        # Add this to have zero to correspond to frontal facing
        pose[:,0] += 1.20919958
        pose[:,1] += 1.20919958
        pose[:,2] += -1.20919958
        return pose
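
# Hypothetical sketch (not part of the original SMALST code): the constant added
# to the root joint in PosePredictor is an axis-angle vector of length
# 1.20919958 * sqrt(3), approximately 2*pi/3, about the axis (1, 1, -1), i.e. a
# 120-degree rotation. Converting it with the Rodrigues formula (numpy only, no
# OpenCV) shows that it is a signed permutation of the coordinate axes.
import numpy as np


def _example_rodrigues(r):
    """Axis-angle vector (3,) -> 3 x 3 rotation matrix."""
    r = np.asarray(r, dtype=float)
    theta = np.linalg.norm(r)
    if theta < 1e-12:
        return np.eye(3)
    k = r / theta
    # Cross-product matrix of the unit axis
    K = np.array([[0., -k[2], k[1]],
                  [k[2], 0., -k[0]],
                  [-k[1], k[0], 0.]])
    return np.eye(3) + np.sin(theta) * K + (1. - np.cos(theta)) * K.dot(K)


_ROOT_OFFSET = np.array([1.20919958, 1.20919958, -1.20919958])
# _example_rodrigues(_ROOT_OFFSET) is approximately
# [[0, 1, 0], [0, 0, -1], [-1, 0, 0]]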

class BetasPredictor(nn.Module):
    def __init__(self, opts, nz_feat, nenc_feat, num_betas=10):
        super(BetasPredictor, self).__init__()
        self.opts = opts
        self.pred_layer = nn.Linear(nenc_feat, num_betas)

    def forward(self, feat, enc_feat):
        betas = self.pred_layer.forward(enc_feat)

        return betas

class Keypoints2DPredictor(nn.Module):
    def __init__(self, opts, nz_feat, nenc_feat, num_keypoints=28):
        super(Keypoints2DPredictor, self).__init__()
        self.opts = opts
        self.num_keypoints = num_keypoints
        self.pred_layer = nn.Linear(nz_feat, 2*num_keypoints)

    def forward(self, feat, enc_feat):
        keypoints2D = self.pred_layer.forward(feat)
        return keypoints2D.view(-1,self.num_keypoints,2)



class ScalePredictor(nn.Module):
    '''
    Predicts the camera scale; with perspective projection the scale is the focal length.
    '''
    def __init__(self, nz, opts):
        super(ScalePredictor, self).__init__()
        self.opts = opts
        if opts.use_camera:
            self.pred_layer = nn.Linear(nz, opts.scale_bias)
        else:
            # Fixed zero scale when the camera is not predicted
            self.ref_camera = torch.Tensor(np.zeros((opts.batch_size,1))).cuda(device=opts.gpu_id)

    def forward(self, feat):
        if not self.opts.use_camera:
            return self.ref_camera
        if self.opts.norm_f0 != 0:
            off = 0.
        else:
            off = 1.
        scale = self.pred_layer.forward(feat) + off   
        return scale
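
# Hypothetical sketch (not part of the original SMALST code): how the raw
# ScalePredictor output x relates to a focal length, assuming the renderer uses
# f = norm_f0 + norm_f * scale, as the help text of the norm_f/norm_f0 flags
# states (the renderer itself, nnutils/nmr.py, is not shown here). The offset
# added in forward() makes a zero network output give f = 2700 both with the
# default flags and when norm_f0 is set to 0.
def _example_focal_length(x, norm_f=2700., norm_f0=2700.):
    """Focal length for raw predictor output x under the assumed formula."""
    off = 0. if norm_f0 != 0 else 1.
    scale = x + off
    return norm_f0 + norm_f * scale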


class TransPredictor(nn.Module):
    """
    Outputs [tx, ty] or [tx, ty, tz]
    """

    def __init__(self, nz, projection_type, opts):
        super(TransPredictor, self).__init__()
        self.opts = opts
        self.projection_type = projection_type
        if projection_type =='orth':
            self.pred_layer = nn.Linear(nz, 2)
        elif projection_type == 'perspective':
            self.pred_layer_xy = nn.Linear(nz, 2)
            self.pred_layer_z = nn.Linear(nz, 1)
            self.pred_layer_xy.weight.data.normal_(0, 0.0001)
            self.pred_layer_xy.bias.data.normal_(0, 0.0001)
            self.pred_layer_z.weight.data.normal_(0, 0.0001)
            self.pred_layer_z.bias.data.normal_(0, 0.0001)
        else:
            raise ValueError('Unknown projection type: {}'.format(projection_type))

    def forward(self, feat):
        if self.projection_type == 'orth':
            # [tx, ty]
            return self.pred_layer(feat)

        # [tx, ty, tz]; the offsets make a zero network output give (1, 0, 1)
        trans = torch.Tensor(np.zeros((feat.shape[0],3))).cuda(device=self.opts.gpu_id)
        trans[:,:2] = self.pred_layer_xy(feat)
        trans[:,0] += 1.0
        trans[:,2] = 1.0+self.pred_layer_z(feat)[:,0]

        if self.opts.fix_trans:
            trans[:,2] = 1.

        # print('trans: ( Mean = {}, Var = {} )'.format(trans.mean().data[0], trans.var().data[0]))
        return trans


class CodePredictor(nn.Module):
    def __init__(self, nz_feat=100, nenc_feat=2048, num_verts=1000, opts=None, left_idx=None, right_idx=None, shapedirs=None):
        super(CodePredictor, self).__init__()
        self.opts = opts
        self.shape_predictor = ShapePredictor(nz_feat, num_verts=num_verts, opts=self.opts, left_idx=left_idx, right_idx=right_idx, shapedirs=shapedirs)
        self.scale_predictor = ScalePredictor(nz_feat, self.opts)
        self.trans_predictor = TransPredictor(nz_feat, self.opts.projection_type, self.opts)
        if opts.use_smal_pose:
            self.pose_predictor = PosePredictor(self.opts, nz_feat)
        if opts.use_smal_betas:
            self.betas_predictor = BetasPredictor(self.opts, nz_feat, nenc_feat, self.opts.num_betas)

    def forward(self, feat, enc_feat):
        if self.opts.use_delta_v:
            shape_pred = self.shape_predictor.forward(feat)
        else:
            shape_pred = self.shape_predictor.ref_delta_v
        if self.opts.use_camera:
            scale_pred = self.scale_predictor.forward(feat)
        else:
            scale_pred = self.scale_predictor.ref_camera

        trans_pred = self.trans_predictor.forward(feat)

        if self.opts.use_smal_pose:
            pose_pred = self.pose_predictor.forward(feat)
        else:
            pose_pred = None

        if self.opts.use_smal_betas:
            betas_pred = self.betas_predictor.forward(feat, enc_feat)
        else:
            betas_pred = None

        keypoints2D_pred = None

        return shape_pred, scale_pred, trans_pred, pose_pred, betas_pred, keypoints2D_pred

#------------ Mesh Net ------------#
#----------------------------------#
class MeshNet(nn.Module):
    def __init__(self, input_shape, opts, nz_feat=100, num_kps=28, sfm_mean_shape=None, tex_masks=None):
        # Input shape is H x W of the image.
        super(MeshNet, self).__init__()
        self.opts = opts
        self.pred_texture = opts.texture
        self.symmetric = opts.symmetric
        self.symmetric_texture = opts.symmetric_texture
        self.tex_masks = tex_masks

        self.op_features = None

        # Instantiate the SMAL model in Torch
        model_path = os.path.join(self.opts.model_dir, self.opts.model_name)
        self.smal = SMAL(pkl_path=model_path, opts=self.opts)

        self.left_idx = np.hstack((self.smal.left_inds, self.smal.center_inds))
        self.right_idx = np.hstack((self.smal.right_inds, self.smal.center_inds))

        pose = np.zeros((1,105))
        betas = np.zeros((1,self.opts.num_betas))
        V,J,R = self.smal(torch.Tensor(betas).cuda(device=self.opts.gpu_id), torch.Tensor(pose).cuda(device=self.opts.gpu_id))
        verts = V[0,:,:]
        verts = verts.data.cpu().numpy()
        faces = self.smal.f


        num_verts = verts.shape[0]

        if self.symmetric:
            verts, faces, num_indept, num_sym, num_indept_faces, num_sym_faces = mesh.make_symmetric(verts, faces, self.smal.left_inds, self.smal.right_inds, self.smal.center_inds)
            if sfm_mean_shape is not None:
                verts = geom_utils.project_verts_on_mesh(verts, sfm_mean_shape[0], sfm_mean_shape[1])

            num_sym_output = num_indept + num_sym
            if opts.only_mean_sym:
                print('Only the mean shape is symmetric!')
                self.num_output = num_verts
            else:
                self.num_output = num_sym_output
            self.num_sym = num_sym
            self.num_indept = num_indept
            self.num_indept_faces = num_indept_faces
            self.num_sym_faces = num_sym_faces
            # mean shape is only half.
            self.mean_v = nn.Parameter(torch.Tensor(verts[:num_sym_output]))

            # Needed for symmetrizing..
            self.flip = Variable(torch.ones(1, 3).cuda(device=self.opts.gpu_id), requires_grad=False)
            self.flip[0, 0] = -1
        else:
            if sfm_mean_shape is not None:
                verts = geom_utils.project_verts_on_mesh(verts, sfm_mean_shape[0], sfm_mean_shape[1])            
            self.mean_v = nn.Parameter(torch.Tensor(verts))
            self.num_output = num_verts
            faces = faces.astype(np.int32) 

        verts_np = verts
        faces_np = faces
        self.faces = Variable(torch.LongTensor(faces).cuda(device=self.opts.gpu_id), requires_grad=False)
        self.edges2verts = mesh.compute_edges2verts(verts, faces)

        vert2kp_init = torch.Tensor(np.ones((num_kps, num_verts)) / float(num_verts))
        # Remember initial vert2kp (after softmax)
        self.vert2kp_init = torch.nn.functional.softmax(Variable(vert2kp_init.cuda(device=self.opts.gpu_id), requires_grad=False), dim=1)
        self.vert2kp = nn.Parameter(vert2kp_init)

        self.encoder = Encoder(self.opts, input_shape, n_blocks=4, nz_feat=nz_feat, bott_size=opts.bottleneck_size)
        nenc_feat = self.encoder.nenc_feat
        self.code_predictor = CodePredictor(nz_feat=nz_feat, nenc_feat=nenc_feat,
            num_verts=self.num_output, opts=opts, left_idx=self.left_idx, right_idx=self.right_idx, shapedirs=self.smal.shapedirs)

        if self.pred_texture:
            if self.symmetric_texture:
                num_faces = self.num_indept_faces + self.num_sym_faces
            else:
                num_faces = faces.shape[0]
                self.num_sym_faces = 0

            # Instead of loading an obj file
            with open(os.path.join(self.opts.model_dir, opts.uv_data_file), 'rb') as f:
                uv_data = pkl.load(f)
            vt = uv_data['vt']
            ft = uv_data['ft']
            self.vt = vt
            self.ft = ft
            uv_sampler = mesh.compute_uvsampler(verts_np, faces_np[:num_faces], vt, ft, tex_size=opts.tex_size)
            # F' x T x T x 2
            uv_sampler = Variable(torch.FloatTensor(uv_sampler).cuda(device=self.opts.gpu_id), requires_grad=False)
            # B x F' x T x T x 2
            uv_sampler = uv_sampler.unsqueeze(0).repeat(self.opts.batch_size, 1, 1, 1, 1)
            if opts.number_of_textures > 0:
                if opts.texture_img_size == 256 and opts.number_of_textures == 7:
                    img_H = np.array([96, 96, 96, 96, 160, 160, 160])
                    img_W = np.array([64, 128, 32, 32, 128, 96, 32])
                elif opts.texture_img_size == 256 and opts.number_of_textures == 4:
                    img_H = np.array([96, 160, 160, 256])
                    img_W = np.array([224, 128, 96, 32])
                else:
                    raise ValueError('Unsupported texture layout: texture_img_size={}, number_of_textures={}'.format(
                        opts.texture_img_size, opts.number_of_textures))
            else:
                img_H = opts.texture_img_size 
                img_W = opts.texture_img_size 


            self.texture_predictor = TexturePredictorUV(
              nz_feat, uv_sampler, opts, img_H=img_H, img_W=img_W, predict_flow=True, symmetric=opts.symmetric_texture,
              num_sym_faces=self.num_sym_faces, tex_masks=self.tex_masks, vt=vt, ft=ft)
           
            nb.net_init(self.texture_predictor)

    def forward(self, img, masks=None):
        opts = self.opts
        if self.opts.is_optimization:
            if self.opts.is_var_opt:
                img_feat, enc_feat = self.encoder.forward(img, masks)
                if self.op_features is None:
                    codes_pred = self.code_predictor.forward(img_feat, enc_feat)
                    self.opts_scale = Variable(codes_pred[1].cuda(device=opts.gpu_id), requires_grad=True)
                    self.opts_pose = Variable(codes_pred[3].cuda(device=opts.gpu_id), requires_grad=True)
                    self.opts_trans = Variable(codes_pred[2].cuda(device=opts.gpu_id), requires_grad=True)
                    self.opts_delta_v= Variable(codes_pred[0].cuda(device=opts.gpu_id), requires_grad=True)
                    self.op_features = [self.opts_scale, self.opts_pose, self.opts_trans] 
                codes_pred = (self.opts_delta_v, self.opts_scale, self.opts_trans, self.opts_pose, None, None)
            else:
                # Optimization over the features
                if self.op_features is None:
                    img_feat, enc_feat = self.encoder.forward(img, masks)
                    self.op_features = Variable(img_feat.cuda(device=self.opts.gpu_id), requires_grad=True)
                codes_pred = self.code_predictor.forward(self.op_features, None)
                img_feat = self.op_features

        else:
            img_feat, enc_feat = self.encoder.forward(img, masks)
            codes_pred = self.code_predictor.forward(img_feat, enc_feat)
        if self.pred_texture:
            texture_pred = self.texture_predictor.forward(img_feat)
            return codes_pred, texture_pred
        else:
            return codes_pred

    def symmetrize(self, V):
        """
        Takes num_indept+num_sym verts and makes it
        num_indept + num_sym + num_sym
        Is identity if model is not symmetric
        """
        if self.symmetric:
            if V.dim() == 2:
                # No batch
                V_left = self.flip * V[-self.num_sym:]
                return torch.cat([V, V_left], 0)
            else:
                # With batch
                V_left = self.flip * V[:, -self.num_sym:]
                return torch.cat([V, V_left], 1)
        else:
            return V

    def get_smal_verts(self, pose=None, betas=None, trans=None, del_v=None):
        if pose is None:
            pose = torch.Tensor(np.zeros((1,105))).cuda(device=self.opts.gpu_id)
        if betas is None:
            betas = torch.Tensor(np.zeros((1,self.opts.num_betas))).cuda(device=self.opts.gpu_id)
        if trans is None:
            trans = torch.Tensor(np.zeros((1,3))).cuda(device=self.opts.gpu_id)

        verts, _, _ = self.smal(betas, pose, trans, del_v)
        return verts

    def get_mean_shape(self):
        return self.symmetrize(self.mean_v)
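
# Hypothetical sketch (not part of the original SMALST code): MeshNet keeps a
# num_kps x num_verts matrix vert2kp whose row-wise softmax weights the mesh
# vertices. Assuming keypoints are read off the mesh as
# softmax(vert2kp) @ verts (the code that does this is not in this file), the
# uniform initialization makes every keypoint start at the vertex centroid.
import numpy as np


def _example_softmax_rows(a):
    """Row-wise softmax of a 2-D array."""
    a = np.asarray(a, dtype=float)
    # Subtract the row max for numerical stability
    e = np.exp(a - a.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)


def _example_keypoints_from_verts(vert2kp, verts):
    """vert2kp: K x V logits, verts: V x 3 -> K x 3 keypoints."""
    return _example_softmax_rows(vert2kp).dot(np.asarray(verts, dtype=float))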



================================================
FILE: nnutils/smal_predictor.py
================================================
"""
Takes an image, returns stuff.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app
from absl import flags
import os
import os.path as osp
import numpy as np
import scipy.misc
import torch
import torchvision
from torch.autograd import Variable
import scipy.io as sio

from ..nnutils import smal_mesh_net as mesh_net
from ..nnutils import geom_utils
from ..nnutils.nmr import NeuralRenderer
from ..utils import smal_vis
import pickle as pkl

# These options are off by default, but used for some ablations reported.
flags.DEFINE_boolean('ignore_pred_delta_v', False, 'Use only mean shape for prediction')
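
# Hypothetical sketch (not part of the original SMALST code): the checkpoint
# path that MeshPredictor.load_network below builds, e.g. label 'pred' and
# epoch 10 give <checkpoint_dir>/<name>/pred_net_10.pth.
import os


def _example_checkpoint_path(checkpoint_dir, name, network_label, epoch_label):
    """Path of the saved network weights for a given label and epoch."""
    save_filename = '{}_net_{}.pth'.format(network_label, epoch_label)
    return os.path.join(checkpoint_dir, name, save_filename)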

class MeshPredictor(object):
    def __init__(self, opts):
        self.opts = opts

        self.symmetric = opts.symmetric

        img_size = (opts.img_size, opts.img_size)

        # Load the texture map layers
        tex_masks = [None]*opts.number_of_textures
        with open('smalst/zebra_data/verts2kp.pkl', 'rb') as f:
            self.vert2kp = torch.Tensor(pkl.load(f)).cuda(device=opts.gpu_id)

        print('Setting up model..')
        self.model = mesh_net.MeshNet(img_size, opts, nz_feat=opts.nz_feat, tex_masks=tex_masks)

        self.load_network(self.model, 'pred', self.opts.num_train_epoch)
        # set the module in evaluation mode
        self.model.eval()
        self.model = self.model.cuda(device=self.opts.gpu_id)

        self.renderer = NeuralRenderer(opts.img_size, opts.projection_type, opts.norm_f, opts.norm_z, opts.norm_f0)

        if opts.texture:
            self.tex_renderer = NeuralRenderer(opts.img_size, opts.projection_type, opts.norm_f, opts.norm_z, opts.norm_f0)
            # Only use ambient light for tex renderer
            self.tex_renderer.ambient_light_only()

        self.mean_shape = self.model.get_mean_shape()

        # For visualization
        faces = self.model.faces.view(1, -1, 3)

        self.faces = faces.repeat(opts.batch_size, 1, 1)
        self.vis_rend = smal_vis.VisRenderer(opts.img_size,
                                             faces.data.cpu().numpy(), opts.projection_type, opts.norm_f, opts.norm_z, opts.norm_f0)
        self.vis_rend.set_bgcolor([1., 1., 1.])

        self.resnet_transform = torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def load_network(self, network, network_label, epoch_label):
        save_filename = '{}_net_{}.pth'.format(network_label, epoch_label)
        network_dir = os.path.join(self.opts.checkpoint_dir, self.opts.name)
        save_path = os.path.join(network_dir, save_filename)
        print('loading {}..'.format(save_path))
        network.load_state_dict(torch.load(save_path))

        return

    def set_input(self, batch):
        opts = self.opts

        # Original image, from which the texture is sampled.
        img_tensor = batch['img'].clone().type(torch.FloatTensor)

        # ResNet input; it is ImageNet-normalized below.
        input_img_tensor = batch['img'].type(torch.FloatTensor)

        for b in range(input_img_tensor.size(0)):
            input_img_tensor[b] = self.resnet_transform(input_img_tensor[b])

        self.input_imgs = Variable(
            input_img_tensor.cuda(device=opts.gpu_id), requires_grad=False)
        self.imgs = Variable(
            img_tensor.cuda(device=opts.gpu_id), requires_grad=False)

    def predict(self, batch, cam_gt=None, trans_gt=None, pose_gt=None, betas_gt=None, rot=0):
        """
        batch has B x C x H x W numpy
        """
        self.set_input(batch)
        self.forward(cam_gt, trans_gt, pose_gt, betas_gt, rot)
        return self.collect_outputs()

    def forward(self, cam_gt=None, trans_gt=None, pose_gt=None, betas_gt=None, rot=0):
        if self.opts.texture:
            pred_codes, self.textures = self.model.forward(self.input_imgs)
        else:
            pred_codes = self.model.forward(self.input_imgs)

        self.delta_v, scale, self.trans, self.pose, self.betas, self.kp_2D_pred = pred_codes

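        # Rotate the view: compose a rotation of `rot` radians about the y axis
        # with the predicted global (root) rotation. Assumes batch size 1.
        if rot != 0:
            import cv2
            r0 = self.pose[:,:3].detach().cpu().numpy()
            R0, _ = cv2.Rodrigues(r0)
            # Use a float vector: cv2.Rodrigues rejects integer arrays
            ry = np.array([0., rot, 0.])
            Ry, _ = cv2.Rodrigues(ry)
            Rt = Ry.dot(R0)
            rt, _ = cv2.Rodrigues(Rt)
            self.pose[:,:3] = torch.Tensor(rt).permute(1,0)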

        if cam_gt is not None:
            print('Setting gt cam')
            scale[:] = cam_gt
        if trans_gt is not None:
            print('Setting gt trans')
            self.trans[0,:] = torch.Tensor(trans_gt)
        if pose_gt is not None:
            print('Setting gt pose')
            self.pose[0,:] = torch.Tensor(pose_gt)
        if betas_gt is not None:
            print('Setting gt betas')
            self.betas[0,:] = torch.Tensor(betas_gt[:10])
            print('Removing delta_v')
            self.delta_v[:] = 0

        if self.opts.projection_type == 'perspective':
            # The camera center is fixed at the image center; only the first
            # camera parameter (scale) is predicted by the network.
            cam_center = torch.Tensor([self.input_imgs.shape[2]//2, self.input_imgs.shape[3]//2]).cuda(device=self.opts.gpu_id)
            # B x (k + 2): predicted scale followed by the fixed center
            self.cam_pred = torch.cat([scale, cam_center.repeat(scale.shape[0], 1)], 1)
        else:
            raise NotImplementedError('Only perspective projection is supported')


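        # Optionally ignore the predicted per-vertex deformation (ablation).
        # del_v aliases self.delta_v, so zeroing it also zeroes the reported delta_v.
        del_v = self.delta_v
        if self.opts.ignore_pred_delta_v:
            del_v[:] = 0

        if self.opts.use_smal_pose:
            self.smal_verts = self.model.get_smal_verts(self.pose, self.betas, self.trans, del_v)
            self.pred_v = self.smal_verts
        else:
            # TODO: only the SMAL pose parameterization is implemented.
            raise NotImplementedError('use_smal_pose=False is not supported')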

        # 3D keypoints as a linear combination of the mesh vertices
        self.kp_verts = torch.matmul(self.vert2kp, self.pred_v)

        # Project keypoints
        self.kp_pred = self.renderer.project_points(self.kp_verts,
                                                    self.cam_pred)
        self.mask_pred = self.renderer.forward(self.pred_v, self.faces,
                                               self.cam_pred)

        # Predict the texture.
        if self.opts.texture:
            if self.textures.size(-1) == 2:
                # The texture is a flow field: sample colors from the input image.
                self.texture_flow = self.textures
                self.textures = geom_utils.sample_textures(self.textures,
                                                           self.imgs)
            if self.textures.dim() == 5:  # B x F x T x T x 3
                tex_size = self.textures.size(2)
                self.textures = self.textures.unsqueeze(4).repeat(1, 1, 1, 1,
                                                                  tex_size, 1)

            # Render texture:
            self.texture_pred = self.tex_renderer.forward(
                self.pred_v, self.faces, self.cam_pred, textures=self.textures)

            # B x 2 x H x W
            uv_flows = self.model.texture_predictor.uvimage_pred
            # B x H x W x 2
            self.uv_flows = uv_flows.permute(0, 2, 3, 1)

            self.uv_images = torch.nn.functional.grid_sample(self.imgs, self.uv_flows)
        else:
            self.textures = None

    def collect_outputs(self):
        outputs = {
            'pose_pred': self.pose.data,
            'kp_pred': self.kp_pred.data,
            'verts': self.pred_v.data,
            'kp_verts': self.kp_verts.data,
            'cam_pred': self.cam_pred.data,
            'mask_pred': self.mask_pred.data,
            'delta_v_pred': self.delta_v.data,
            'trans_pred': self.trans.data,
            'kp_2D_pred': self.kp_2D_pred,
            'shape_f': self.model.code_predictor.shape_predictor.shape_f.data,
            'f': self.faces,
            'v': self.smal_verts
        }
        if self.opts.use_smal_betas:
            outputs['betas_pred'] = self.betas.data
        if self.opts.texture:
            outputs['texture'] = self.textures
            outputs['texture_pred'] = self.texture_pred.data
            outputs['uv_image'] = self.uv_images.data
            outputs['uv_flow'] = self.uv_flows.data

        return outputs
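The `rot` branch of `MeshPredictor.forward` re-poses the animal by composing a rotation of `rot` radians about the y axis with the predicted root rotation stored in `pose[:, :3]`, using `cv2.Rodrigues` for the axis-angle conversions. The sketch below reproduces that composition in plain NumPy so it can be checked without OpenCV; the helper names are hypothetical, and the inverse conversion assumes rotation angles strictly below pi.

```python
import numpy as np


def axis_angle_to_matrix(r):
    """Rodrigues formula: axis-angle vector of shape (3,) -> 3 x 3 rotation matrix."""
    r = np.asarray(r, dtype=np.float64).reshape(3)
    theta = np.linalg.norm(r)
    if theta < 1e-12:
        return np.eye(3)
    k = r / theta  # unit rotation axis
    K = np.array([[0.0, -k[2], k[1]],
                  [k[2], 0.0, -k[0]],
                  [-k[1], k[0], 0.0]])  # skew-symmetric cross-product matrix
    return np.eye(3) + np.sin(theta) * K + (1.0 - np.cos(theta)) * K.dot(K)


def matrix_to_axis_angle(R):
    """Inverse Rodrigues; assumes the rotation angle is strictly below pi."""
    cos_theta = np.clip((np.trace(R) - 1.0) / 2.0, -1.0, 1.0)
    theta = np.arccos(cos_theta)
    if theta < 1e-12:
        return np.zeros(3)
    w = np.array([R[2, 1] - R[1, 2], R[0, 2] - R[2, 0], R[1, 0] - R[0, 1]])
    return theta * w / (2.0 * np.sin(theta))


def rotate_global_pose(r0, rot):
    """Left-multiply root rotation r0 by a rotation of `rot` radians about y (Ry * R0)."""
    Ry = axis_angle_to_matrix([0.0, rot, 0.0])
    R0 = axis_angle_to_matrix(r0)
    return matrix_to_axis_angle(Ry.dot(R0))


# Rotating an un-rotated root by 0.5 rad about y gives approximately [0, 0.5, 0].
print(rotate_global_pose([0.0, 0.0, 0.0], 0.5))
```

Left-multiplying by `Ry` rotates the whole animal about the camera-frame y axis, which is why two yaw rotations simply add their angles.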


================================================
FILE: nnutils/train_utils.py
================================================
"""
Generic Training Utils.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import os
import os.path as osp
import time
import pdb
from absl import flags
import pickle as pkl
import scipy.misc
import numpy as np
from ..nnutils import geom_utils

from ..utils.visualizer import Visualizer
from .smal_mesh_eval import smal_mesh_eval

#-------------- flags -------------#
#----------------------------------#
## Flags for training
curr_path = osp.dirname(osp.abspath(__file__))
cache_path = osp.join(curr_path, '..', 'cachedir')

flags.DEFINE_string('name', 'exp_name', 'Experiment Name')
# flags.DEFINE_string('cache_dir', cache_path, 'Cachedir') # Not used!
flags.DEFINE_integer('gpu_id', 0, 'Which gpu to use')
flags.DEFINE_integer('num_epochs', 1000, 'Number of epochs to train')
flags.DEFINE_integer('num_pretrain_epochs', 0, 'If >0, we will pretrain from an existing saved model.')
flags.DEFINE_float('learning_rate', 0.0001, 'learning rate')
flags.DEFINE_float('beta1', 0.9, 'Momentum term of adam')

flags.DEFINE_bool('use_sgd', False, 'If true, uses SGD instead of Adam; beta1 is then used as the momentum')

flags.DEFINE_integer('batch_size', 8, 'Size of minibatches')
flags.DEFINE_integer('num_iter', 0, 'Number of training iterations. 0 -> Use epoch_iter')
flags.DEFINE_integer('new_dataset_freq', 2, 'Frequency (in epochs) at which to get a new dataset')

## Flags for logging and snapshotting
flags.DEFINE_string('checkpoint_dir', osp.join(cache_path, 'snapshots'),
                    'Root directory for output files')
flags.DEFINE_integer('print_freq', 20, 'scalar logging frequency')
flags.DEFINE_integer('save_latest_freq', 10000, 'save latest model every x iterations')
flags.DEFINE_integer('save_epoch_freq', 25, 'save model every k epochs')

flags.DEFINE_bool('save_training_imgs', False, 'save mask and images for debugging')

## Flags for visualization
flags.DEFINE_integer('display_freq', 100, 'visuals logging frequency')
flags.DEFINE_boolean('display_visuals', False, 'whether to display images')
flags.DEFINE_boolean('print_scalars', True, 'whether to print scalars')
flags.DEFINE_boolean('plot_scalars', False, 'whether to plot scalars')
flags.DEFINE_boolean('is_train', True, 'Are we training?')
flags.DEFINE_integer('display_id', 1, 'Display Id')
flags.DEFINE_integer('display_winsize', 256, 'Display Size')
flags.DEFINE_integer('display_port', 8097, 'Display port')
flags.DEFINE_integer('display_single_pane_ncols', 0, 'if positive, display all images in a single visdom web panel with certain number of images per row.')

flags.DEFINE_integer('num_train_epoch', 40, 'Epoch label of the saved network to load')
flags.DEFINE_boolean('do_validation', False, 'compute on validation set at each epoch')
flags.DEFINE_boolean('is_var_opt', False, 'Set to True to optimize over pose, scale and translation')



def set_bn_eval(m):
    classname = m.__class__.__name__
    if (classname.find('BatchNorm1d') != -1) or (classname.find('BatchNorm2d') != -1):
        m.eval()

#-------- training class ----------#
#----------------------------------#
class Trainer():
    def __init__(self, opts):
        self.opts = opts
        self.gpu_id = opts.gpu_id
        self.Tensor = torch.cuda.FloatTensor if (self.gpu_id is not None) else torch.Tensor
        self.invalid_batch = False #the trainer can optionally reset this every iteration during set_input call
        self.save_dir = os.path.join(opts.checkpoint_dir, opts.name)
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)
        log_file = os.path.join(self.save_dir, 'opts.log')
        with open(log_file, 'w') as f:
            for k in dir(opts):
                f.write('{}: {}\n'.format(k, opts.__getattr__(k)))


    # helper saving function that can be used by subclasses
    def save_network(self, network, network_label, epoch_label, gpu_id=None):
        save_filename = '{}_net_{}.pth'.format(network_label, epoch_label)
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save(network.cpu().state_dict(), save_path)
        if gpu_id is not None and torch.cuda.is_available():
            network.cuda(device=gpu_id)
        return

    # helper loading function that can be used by subclasses
    def load_network(self, network, network_label, epoch_label, network_dir=None):
        save_filename = '{}_net_{}.pth'.format(network_label, epoch_label)
        if network_dir is None:
            network_dir = self.save_dir
        save_path = os.path.join(network_dir, save_filename)
        network.load_state_dict(torch.load(save_path))
        return

    def define_model(self):
        '''Should be implemented by the child class.'''
        raise NotImplementedError

    def init_dataset(self):
        '''Should be implemented by the child class.'''
        raise NotImplementedError

    def define_criterion(self):
        '''Should be implemented by the child class.'''
        raise NotImplementedError

    def set_input(self, batch):
        '''Should be implemented by the child class.'''
        raise NotImplementedError

    def forward(self):
        '''Should compute self.total_loss. To be implemented by the child class.'''
        raise NotImplementedError

    def save(self, epoch_prefix):
        '''Saves the model.'''
        self.save_network(self.model, 'pred', epoch_prefix, gpu_id=self.opts.gpu_id)
        return

    def get_current_visuals(self):
        '''Should be implemented by the child class.'''
        raise NotImplementedError

    def get_current_scalars(self):
        '''Should be implemented by the child class.'''
        raise NotImplementedError

    def get_current_points(self):
        '''Should be implemented by the child class.'''
        raise NotImplementedError

    def init_training(self):
        opts = self.opts
        self.init_dataset()    
        self.define_model()
        self.define_criterion()
        if opts.use_sgd:
            self.optimizer = torch.optim.SGD(
                self.model.parameters(), lr=opts.learning_rate, momentum=opts.beta1)
        else:
            self.optimizer = torch.optim.Adam(
                self.model.parameters(), lr=opts.learning_rate, betas=(opts.beta1, 0.999))

    def save_current(self, opts, initial_loss=0, final_loss=0, code='_'):
        res_dict = {
            'final_loss': final_loss.data.detach().cpu().numpy(),
            'delta_v': self.delta_v.data.detach().cpu().numpy(),
            'kp_pred': self.kp_pred.data.detach().cpu().numpy(),
            'scale': self.scale_pred.data.detach().cpu().numpy(),
            'trans': self.trans_pred.data.detach().cpu().numpy(),
            'pose': self.pose_pred.data.detach().cpu().numpy(),
            'initial_loss': initial_loss.data.detach().cpu().numpy(),
            }

        scipy.misc.imsave(opts.image_file_string.replace('.png', code + 'mask.png'),
                          255*self.mask_pred.detach().cpu().numpy()[0,:,:])
        uv_flows = self.model.texture_predictor.uvimage_pred
        uv_flows = uv_flows.permute(0, 2, 3, 1)
        uv_images = torch.nn.functional.grid_sample(self.imgs, uv_flows)

        scipy.misc.imsave(opts.image_file_string.replace('.png', code + 'tex.png'),
                          255*np.transpose(uv_images.detach().cpu().numpy()[0,:,:,:], (1, 2, 0)))
        with open(opts.image_file_string.replace('.png', code + 'res.pkl'), 'wb') as f:
            pkl.dump(res_dict, f)


    def train(self):

        time_stamp = str(time.time())[:10]
        opts = self.opts
        self.smoothed_total_loss = 0
        self.visualizer = Visualizer(opts)
        visualizer = self.visualizer
        total_steps = 0
        dataset_size = len(self.dataloader)
        print('dataset_size '+str(dataset_size))
        
        v_log_file = os.path.join(self.save_dir, 'validation.log')
        curr_epoch_err = 1000

        if opts.is_optimization:
            # Freeze the pretrained network layers listed below.
            self.model.eval()
            frozen_modules = [
                self.model.texture_predictor,
                self.model.encoder,
                self.model.code_predictor.shape_predictor.pred_layer,
                self.model.code_predictor.shape_predictor.fc,
                self.model.code_predictor.scale_predictor.pred_layer,
                self.model.code_predictor.trans_predictor.pred_layer_xy,
                self.model.code_predictor.trans_predictor.pred_layer_z,
                self.model.code_predictor.pose_predictor.pred_layer,
            ]
            for module in frozen_modules:
                module.eval()
                for param in module.parameters():
                    param.requires_grad = False

            self.model.apply(set_bn_eval)
            

        if opts.is_optimization:
            code = osp.splitext(osp.basename(opts.image_file_string))[0]
            visualizer.print_message(code)
        self.background_model_top = None
        set_optimization_input = True
        if True:
            for epoch in range(opts.num_pretrain_epochs, opts.num_epochs):
                epoch_iter = 0

                for i, batch in enumerate(self.dataloader):
                    iter_start_time = time.time()
                    if not opts.is_optimization:
                        self.set_input(batch)
                    else:
                        if set_optimization_input:
                            self.set_input(batch)

                    if not self.invalid_batch:
                        self.optimizer.zero_grad()

                        self.forward()

                        if opts.is_optimization:
                            if set_optimization_input:
                                initial_loss = self.tex_loss
                                print("Initial loss")
                                print(initial_loss)
                                current_loss = initial_loss
                                opt_loss = current_loss
                                # Now the input should be the image prediction
                                self.set_optimization_input()
                                set_optimization_input = False
                                self.save_current(opts, initial_loss, current_loss, code='_init_')
                            else:
                                current_loss = self.tex_loss
                                if current_loss < opt_loss:
                                    opt_loss = current_loss
                                    self.save_current(opts, initial_loss, current_loss, code='_best_')
                                    visualizer.print_message('save current best ' + str(current_loss))
                                 
                        # self.background_model_top is not used here, but serves as a flag
                        if opts.is_optimization and self.background_model_top is None:
                            # Create background model with current prediction
                            # Background model: mean color of the pixels outside the
                            # predicted mask (M = 1 - mask). N = 128 splits a 256-pixel
                            # image into top and bottom halves.
                            M = np.abs(self.mask_pred.cpu().detach().numpy()[0,:,:]-1)
                            I = np.transpose(self.imgs.cpu().detach().numpy()[0,:,:,:],(1,2,0))
                            N = 128

                            # Top half of the image
                            self.background_model_top = np.zeros((3))
                            n = np.sum(M[:N,:])
                            for c in range(3):
                                J = I[:,:,c] * M
                                self.background_model_top[c] = np.sum(J[:N,:])/n

                            # Bottom half of the image
                            self.background_model_bottom = np.zeros((3))
                            n = np.sum(M[N:,:])
                            for c in range(3):
                                J = I[:,:,c] * M
                                self.background_model_bottom[c] = np.sum(J[N:,:])/n
                            if opts.use_sgd:
                                self.optimizer = torch.optim.SGD(
                                    [self.model.op_features], lr=opts.learning_rate, momentum=opts.beta1)
                            else:
                                if opts.is_var_opt:
                                    self.optimizer = torch.optim.Adam(
                                        self.model.op_features, lr=opts.learning_rate, betas=(opts.beta1, 0.999))
                                else:
                                    self.optimizer = torch.optim.Adam(
                                        [self.model.op_features], lr=opts.learning_rate, betas=(opts.beta1, 0.999))

                        self.smoothed_total_loss = self.smoothed_total_loss*0.99 + 0.01*self.total_loss.data
                        self.total_loss.backward()
                        self.optimizer.step()

                    total_steps += 1
                    epoch_iter += 1

                    if opts.display_visuals and (total_steps % opts.display_freq == 0):
                        iter_end_time = time.time()
                        # iter_start_time is reset every iteration, so this is the time of one iteration
                        print('time/itr %.2g' % (iter_end_time - iter_start_time))
                        visualizer.display_current_results(self.get_current_visuals(), epoch)
                        visualizer.plot_current_points(self.get_current_points())

                    if opts.print_scalars and (total_steps % opts.print_freq == 0):
                        scalars = self.get_current_scalars()
                        visualizer.print_current_scalars(epoch, epoch_iter, scalars)
                        if opts.plot_scalars:
                            visualizer.plot_current_scalars(epoch, float(epoch_iter)/dataset_size, opts, scalars)

                    if total_steps % opts.save_latest_freq == 0:
                        print('saving the latest model (epoch {:d}, iters {:d})'.format(epoch, total_steps))
                        self.save('latest')

                    if total_steps == opts.num_iter:
                        return

                if opts.do_validation:
                    self.save('100000')
                    epoch_err = smal_mesh_eval(num_train_epoch=100000)
                    if epoch_err <= curr_epoch_err:
                        print('update best model')
                        curr_epoch_err = epoch_err
                        self.save('best')
                        with open(v_log_file, 'a') as f:
                            f.write('{}: {}\n'.format(epoch, epoch_err))

                '''
                if opts.is_optimization and (epoch==(opts.num_epochs-1) or epoch==opts.num_pretrain_epochs):
                    img_pred = self.texture_pred*self.mask_pred + self.background_imgs*(torch.abs(self.mask_pred - 1.))
                    T = img_pred.cpu().detach().numpy()[0,:,:,:]
                    T = np.transpose(T,(1,2,0))
                    scipy.misc.imsave(code + '_img_pred_'+str(epoch)+'.png', T)
                    img_pred = self.texture_pred*self.mask_pred + self.imgs*(torch.abs(self.mask_pred - 1.))
                    T = img_pred.cpu().detach().numpy()[0,:,:,:]
                    T = np.transpose(T,(1,2,0))
                    scipy.misc.imsave(code + '_img_ol_'+str(epoch)+'.png', T)
                '''

                '''
                if opts.is_optimization and opts.save_training_imgs and np.mod(epoch,20)==0:
                    img_pred = self.texture_pred*self.mask_pred + self.background_imgs*(torch.abs(self.mask_pred - 1.))
                    T = img_pred.cpu().detach().numpy()[0,:,:,:]
                    T = np.transpose(T,(1,2,0))
                    scipy.misc.imsave(opts.name + '_img_pred_'+str(epoch)+'.png', T)
                    img_pred = self.texture_pred*self.mask_pred + self.imgs*(torch.abs(self.mask_pred - 1.))
                    T = img_pred.cpu().detach().numpy()[0,:,:,:]
                    T = np.transpose(T,(1,2,0))
                    scipy.misc.imsave(opts.name + '_img_ol_'+str(epoch)+'.png', T)
                '''

                '''
                if opts.is_optimization and epoch == opts.num_pretrain_epochs:
                    T = 255*self.imgs.cpu().detach().numpy()[0,:,:,:]
                    T = np.transpose(T,(1,2,0))
                    scipy.misc.imsave(code+'_img_gt.png', T)
                '''

                
                if (epoch+1) % opts.save_epoch_freq == 0:
                    print('saving the model at the end of epoch {:d}, iters {:d}'.format(epoch, total_steps))
                    if opts.is_optimization:
                        # code=None would fail in save_current (None + str); use the default prefix
                        self.save_current(opts, initial_loss, current_loss)
                    else:
                        self.save(epoch+1)
                        self.save('latest')
            if opts.is_optimization:
                if opt_loss < initial_loss:
                    visualizer.print_message('updated')
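In optimization mode, `Trainer.train()` builds a simple background model on the first iteration: the mean color of the pixels outside the predicted mask, computed separately for the top and bottom halves of the image (the hard-coded `N = 128` is half of a 256-pixel input). A vectorized NumPy sketch of the same computation follows; the helper name and `split_row` parameter are hypothetical. Like the original per-channel loop, it divides by the background weight of each half, so it assumes each half contains some background.

```python
import numpy as np


def half_background_means(mask_pred, img, split_row=128):
    """Mean RGB color of the background above and below `split_row`.

    mask_pred: (H, W) predicted foreground mask in [0, 1] (1 = animal).
    img:       (H, W, 3) image, e.g. the network input transposed from C x H x W.
    Returns (top_mean, bottom_mean), each of shape (3,).
    """
    bg = np.abs(mask_pred - 1.0)        # background weight, M in train()
    weighted = img * bg[:, :, None]     # suppress foreground pixels in every channel
    top = weighted[:split_row].sum(axis=(0, 1)) / bg[:split_row].sum()
    bottom = weighted[split_row:].sum(axis=(0, 1)) / bg[split_row:].sum()
    return top, bottom
```

Broadcasting the mask over the channel axis replaces the explicit `for c in range(3)` loop while computing the same weighted averages.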


================================================
FILE: requirements.txt
================================================
absl-py==0.1.10
chainer==3.3.0
cupy==2.3.0
Cython==0.27.3
h5py==2.7.1
imageio==2.2.0
ipdb==0.10.3
matplotlib==2.1.2
meshzoo==0.3.1
numpy==1.14.0
opencv-python==3.4.0.12
Pillow==8.3.2
progressbar==2.3
protobuf==3.5.1
PyOpenGL==3.1.0
scikit-image==0.13.1
scipy==1.0.0
torch==0.4.1
torchfile==0.1.0
torchvision==0.2.0
tqdm==4.19.5
visdom==0.1.7


================================================
FILE: scripts/smalst_evaluation_run.sh
================================================
# Example script to run the evaluation

python -m smalst.smal_eval --name=smal_net_600 --img_path='smalst/zebra_testset/' --num_train_epoch=186 --use_annotations=True --mirror=True --segm_eval=True --img_ext='.jpg' --anno_path='smalst/zebra_testset/annotations'

#python -m smalst.smal_eval --name=smal_net_600 --img_path='smalst/zebra_video_frame/' --num_train_epoch=186 --use_annotations=False --mirror=False --segm_eval=False --img_ext='.png' --bgval=0 --save_input=False --test_optimization_results=True --optimization_dir=smalst_experiments/demo_var_feat
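The first evaluation command sets `--mirror=True`, which, judging from the `mirror_image` and `mirror_keypoints` helpers in `smal_eval.py`, enables evaluation on horizontally flipped images. Flipping an image makes left and right keypoints trade places, and each x coordinate becomes `img_w - 1 - x`. The sketch below copies the permutation from `mirror_keypoints` (the other names are hypothetical) and checks that mirroring twice is the identity, a cheap way to catch a mistyped permutation.

```python
import numpy as np

# Left/right swap permutation copied from mirror_keypoints in smal_eval.py
# (28 zebra keypoints; entry i is the index of the mirror-side counterpart).
DX2SX = [1, 0, 2, 4, 3, 6, 5, 7, 9, 8, 11, 10, 13, 12, 15, 14,
         17, 16, 18, 19, 21, 20, 23, 22, 25, 24, 27, 26]


def mirror_keypoints_sketch(keyp, vis, img_w):
    """Swap left/right keypoints and flip x for an image of width img_w."""
    keyp = np.asarray(keyp, dtype=np.float64)
    keyp_m = np.empty_like(keyp)
    keyp_m[:, 0] = img_w - keyp[DX2SX, 0] - 1   # pixel x -> img_w - 1 - x
    keyp_m[:, 1] = keyp[DX2SX, 1]               # y is unchanged
    return keyp_m, np.asarray(vis)[DX2SX]


def check_mirror_is_involution(img_w=256):
    """Mirroring twice must return the original keypoints and visibilities."""
    assert all(DX2SX[DX2SX[i]] == i for i in range(len(DX2SX)))
    keyp = np.random.RandomState(0).uniform(0, img_w - 1, size=(len(DX2SX), 2))
    vis = np.arange(len(DX2SX))
    k1, v1 = mirror_keypoints_sketch(keyp, vis, img_w)
    k2, v2 = mirror_keypoints_sketch(k1, v1, img_w)
    return np.allclose(k2, keyp) and np.array_equal(v2, vis)


print(check_mirror_is_involution())  # True
```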


================================================
FILE: scripts/smalst_op_run.sh
================================================
#!/bin/sh
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/lib

# The "processed" directory contains the input images as preprocessed for the network's feed-forward pass.
# Before running the optimization, run the feed-forward network on the data with the flag save_input=True (see smalst_evaluation_run.sh).

zebra_dir='/Users/silvia/Dropbox/Work/smalst/zebra_video_frame/processed'
zebra_dir_images='/Users/silvia/Dropbox/Work/smalst/zebra_video_frame/processed/*'
for dir in $zebra_dir_images
do
    for fil in $dir
    do
        file=$(basename $fil)
        echo "$file"
        python -m smalst.experiments.smal_shape --name=smal_net_600 --zebra_dir=$zebra_dir --image_file_string=$fil --num_pretrain_epochs=186 --batch_size=1 --texture_map=False --perturb_bbox=False --save_epoch_freq=1000 --save_training_imgs=True --is_optimization=True --use_loss_on_whole_image=True --learning_rate=0.0001 --use_directional_light=False --num_epochs=220 --is_var_opt=False
    done
done



================================================
FILE: scripts/smalst_train_run.sh
================================================
 
#python -m smalst.experiments.smal_shape --zebra_dir='smalst/zebra_training_set' --num_epochs=40 --save_epoch_freq=20 --name=smal_net_0 --save_training_imgs=False --num_images=20000 --do_validation=False --texture=False --texture_map=False --uv_flow=False --batch_size=2 --nz_feat=16 --bottleneck_size=16

python -m smalst.experiments.smal_shape --zebra_dir='smalst/zebra_training_set' --num_epochs=200 --save_epoch_freq=20 --name=smal_net_0 --save_training_imgs=False --num_images=20000 --do_validation=True





================================================
FILE: smal_eval.py
================================================
"""
Evaluation on the testset.

python -m smalst.smal_eval --name=smal_net_600 --img_path='smalst/testset_zoo/' --num_train_epoch=186 --use_annotations=False --mirror=False --segm_eval=False --img_ext='.png' --bgval=0

"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import flags, app
import numpy as np

import pickle as pkl

import torch
import scipy
import scipy.misc
#from .nnutils import test_utils
from .nnutils import smal_predictor as pred_util
from .utils import image as img_util
from glob import glob
import scipy.io as sio
import matplotlib.pyplot as plt

# Only necessary for running on the testset to pad the image with the original one
from .testset_shape_experiments_crops import bboxes
import os.path as osp
from os.path import exists
from os import makedirs

curr_path = osp.dirname(osp.abspath(__file__))
cache_path = osp.join(curr_path, 'cachedir')

flags.DEFINE_string('img_path', 'data/im1963.jpg', 'Image to run')
flags.DEFINE_integer('img_size', 256, 'image size the network was trained on.')
flags.DEFINE_boolean('use_annotations', True, '')
flags.DEFINE_boolean('mirror', False, '')
flags.DEFINE_boolean('visualize', False, '')
flags.DEFINE_boolean('save_input', False, '')
flags.DEFINE_string('anno_path', '.', 'where the annotations are')
flags.DEFINE_string('img_ext', '.jpg', 'image extension')
flags.DEFINE_boolean('segm_eval', True, 'if we have gt segmentations and evaluate overlap')
flags.DEFINE_boolean('synthetic', False, '')
flags.DEFINE_integer('bgval', -1, '-1 means to pad the image with the original one (used for the testset)')
flags.DEFINE_boolean('test_optimization_results', False, '')
flags.DEFINE_string('optimization_dir', 'smalst/optimization_results', '')
flags.DEFINE_integer('batch_size', 4, 'Size of minibatches')
flags.DEFINE_string('name', 'exp_name', 'Experiment Name')
flags.DEFINE_integer('num_train_epoch', 0, 'Epoch label of the saved network to load')
flags.DEFINE_integer('gpu_id', 0, 'Which gpu to use')
flags.DEFINE_string('cache_dir', cache_path, 'Cachedir')
flags.DEFINE_string('checkpoint_dir',
                    osp.join(cache_path, 'snapshots'),
                    'Directory where networks are saved')
flags.DEFINE_string('out_path', './smalst_results', 'where to save the result images')


opts = flags.FLAGS

def get_bbox(img_path, kp):
    # Load the mask
    mask_img = scipy.misc.imread(img_path.replace('images', 'bgsub')) / 255.
    # Load the image
    img = scipy.misc.imread(img_path) / 255.
    where = np.array(np.where(mask_img))
    xmin, ymin, _ = np.amin(where, axis=1)
    xmax, ymax, _ = np.amax(where, axis=1)
    mask_img = mask_img[xmin:xmax, ymin:ymax,:]
    img = img[xmin:xmax, ymin:ymax,:]
    kp[:,0] = kp[:,0] - ymin
    kp[:,1] = kp[:,1] - xmin
    return img, mask_img, kp

def preprocess_image(img_path, img_size=256, kp=None, border=5, bgval=-1, img=None, img_ext='.jpg'):

    if img is None:
        img = scipy.misc.imread(img_path) / 255.
    
    img = img[:,:,:3]
    img_in_shape = img.shape

    # Scale the image so that its longer side is img_size - 2*border
    scale_factor = float(img_size-2*border) / np.max(img.shape[:2])
    img, _ = img_util.resize_img(img, scale_factor)

    # Crop img_size x img_size from the center
    center = np.round(np.array(img.shape[:2]) / 2).astype(int)
    # img center in (x, y)
    center = center[::-1]
    bbox = np.hstack([center - img_size / 2., center + img_size / 2.])

    img = img_util.crop(img, bbox, bgval=bgval)

    
    # Replace the border with the real background
    img_name = osp.splitext(osp.basename(img_path))[0] 
    # Read the full image
    if bgval == -1:
        # glob returns an empty list when no full-size image exists
        full_img_paths = glob(osp.join(osp.dirname(img_path), 'full_size', img_name+'*'+img_ext))
        if len(full_img_paths) > 0:
            full_img_path = full_img_paths[0]
            full_img = scipy.misc.imread(full_img_path)/255.
            bbox_orig = np.array(bboxes[img_name])
            sf = img_in_shape[0]/(1.0*bbox_orig[3])
            new_img, _ = img_util.resize_img(full_img, sf*scale_factor)
            center[0] = np.round((bbox_orig[2]/2. + bbox_orig[0])*sf*scale_factor).astype(int)
            center[1] = np.round((bbox_orig[3]/2. + bbox_orig[1])*sf*scale_factor).astype(int)
            bbox2 = np.hstack([center - img_size / 2., center + img_size / 2.])
            img = img_util.crop(new_img, bbox2, bgval=0)
 

    img, _ = img_util.resize_img(img, 256/257.)

    # Transpose the image to 3xHxW
    img = np.transpose(img, (2, 0, 1))

    if kp is not None:
        kp = kp*scale_factor
        kp[:,0] -= bbox[0]
        kp[:,1] -= bbox[1]

    return img, kp

def mirror_keypoints(keyp, vis, img_w):
    '''
    Swap left and right keypoints and mirror their x coordinates, for an
    image of width img_w that has been flipped horizontally.
    names =['leftEye','rightEye','chin','frontLeftFoot','frontRightFoot',
            'backLeftFoot','backRightFoot','tailStart','frontLeftKnee',
            'frontRightKnee','backLeftKnee','backRightKnee','leftShoulder',
            'rightShoulder','frontLeftAnkle','frontRightAnkle','backLeftAnkle',
            'backRightAnkle','neck','TailTip','leftEar','rightEar',
            'nostrilLeft','nostrilRight','mouthLeft','mouthRight',
            'cheekLeft','cheekRight']
    '''

    dx2sx = [1,0,2,4,3,6,5,7,9,8,11,10,13,12,15,14,17,16,18,19,21,20,23,22,25,24,27,26]
        

    keyp_m = np.zeros_like(keyp)
    keyp_m[:,0] = img_w - keyp[dx2sx,0] - 1
    keyp_m[:,1] = keyp[dx2sx,1]

    vis_m = np.zeros_like(vis)
    vis_m[:] = vis[dx2sx]

    return keyp_m, vis_m

def mirror_image(img):
    if len(img.shape)==3:
        img_m = img[:,:,::-1].copy()
    else:
        img_m = img[:,::-1].copy()
    return img_m
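Mirroring keypoints is a permutation plus a reflection of x: each left keypoint takes the place of its right counterpart, and x becomes `img_w - x - 1`. Because the swap table must be its own inverse, mirroring twice should return the original points. A small sketch on a hypothetical three-keypoint layout (`mirror_keypoints_sketch` is an illustrative name, not part of the repo):

```python
import numpy as np

def mirror_keypoints_sketch(kp, vis, img_w, swap):
    """Horizontally flip (K, 2) keypoints and swap left/right labels.

    swap[i] is the index of the keypoint that becomes keypoint i after the flip
    (the role dx2sx plays in mirror_keypoints); it must be its own inverse.
    """
    kp = np.asarray(kp, dtype=float)
    kp_m = np.empty_like(kp)
    kp_m[:, 0] = img_w - kp[swap, 0] - 1  # reflect x about the image centre
    kp_m[:, 1] = kp[swap, 1]              # y is unchanged
    return kp_m, np.asarray(vis)[swap]    # visibility follows the permutation

# Hypothetical layout: 0 = leftEye, 1 = rightEye, 2 = chin
swap = [1, 0, 2]
kp = np.array([[10.0, 20.0], [30.0, 20.0], [20.0, 40.0]])
vis = np.array([True, False, True])
kp_m, vis_m = mirror_keypoints_sketch(kp, vis, img_w=256, swap=swap)
```

The same round-trip check applies to the 28-entry `dx2sx` table above, which maps every index back to itself when applied twice.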

def visualize_opt(img, predictor, renderer, data, out_path):

    pose = torch.Tensor(data['pose']).cuda(0)
    trans = torch.Tensor(data['trans']).cuda(0)
    del_v = torch.Tensor(data['delta_v']).cuda(0)
    vert = predictor.model.get_smal_verts(pose, None, trans, del_v)
    cam = 128*np.ones((3))
    cam[0] = data['scale'][0,:]
    cam = torch.Tensor(cam).cuda(0)
    shape_pred = renderer(vert, cam)
    img = np.transpose(img, (1, 2, 0))
    I = 0.3*255*img + 0.7*shape_pred
    scipy.misc.imsave(out_path, I)


def visualize(img, outputs, renderer, img_path=None, opts=None, cam_gt=0,
              z_gt=0, kp_gt=None, kp_pred=None, mask_gt=None, vis=None, out_path=None, tex_mask=None, visualize=False):

    vert = outputs['verts'][0]
    cam = outputs['cam_pred'][0]
    mask = outputs['mask_pred'].cpu().detach().numpy()

    if 'texture' in outputs.keys():
        texture = outputs['texture'][0]
        uv_image = outputs['uv_image'][0].cpu().detach().numpy()
        T = uv_image.transpose(1,2,0)
        img_pred = renderer(vert, cam, texture=texture)
    if 'occ_pred' in outputs.keys():
        occ_pred = outputs['occ_pred'][0].cpu().detach().numpy()

    shape_pred = renderer(vert, cam)

    img = np.transpose(img, (1, 2, 0))
    mask = np.transpose(mask, (1, 2, 0))[:,:,0]

    Iov = 0.3*255*img + 0.7*shape_pred
    I = shape_pred
    scipy.misc.imsave(osp.join(opts.out_path, 'shape_ov_'+out_path), Iov)
    scipy.misc.imsave(osp.join(opts.out_path, 'shape_'+out_path), I)

    
    N = 0.6*np.abs(tex_mask[:,:,:3]-1)
    if 'texture' in outputs.keys():
        scipy.misc.imsave(osp.join(opts.out_path, 'tex_'+out_path), N+T*tex_mask[:,:,:3])
        scipy.misc.imsave(osp.join(opts.out_path, 'img_'+out_path), img_pred)
    
    if visualize:

        plt.ion()
        plt.figure(1)
        plt.clf()
        plt.subplot(231)
        plt.imshow(img)
        plt.title('input')
        plt.axis('off')
        plt.subplot(232)
        plt.imshow(img)
        plt.imshow(shape_pred, alpha=0.7)
        if kp_gt is not None:
            # vis is K x 1; index both keypoint sets with the same row indices
            idx = np.where(vis[:,0])[0]
            plt.scatter(kp_gt[idx,0], kp_gt[idx,1])
            plt.scatter(kp_pred[idx,0], kp_pred[idx,1])

        plt.title('pred mesh')
        plt.axis('off')
        plt.subplot(233)
        if 'texture' in outputs.keys():
            plt.imshow(img_pred)
            plt.title('pred mesh w/texture')
            plt.axis('off')
            plt.subplot(234)
            plt.imshow(T)
            plt.axis('off')
            plt.subplot(235)
            plt.imshow(T*tex_mask[:,:,:3])
            plt.axis('off')
        plt.subplot(236)
        plt.imshow(mask)
        plt.axis('off')
        plt.draw()
        plt.show()
        plt.savefig(out_path, bbox_inches='tight')

        import pdb; pdb.set_trace()


def save_params(outputs, idx):
    data = {'pose': outputs['pose_pred'].data.detach().cpu().numpy()[0,:],
            'verts': outputs['verts'].data.detach().cpu().numpy()[0,:],
            'f': outputs['f'].data.detach().cpu().numpy()[0,:],
            'v': outputs['v'].data.detach().cpu().numpy()[0,:],}
    with open('data_'+str(idx)+'.pkl', 'wb') as f:
        pkl.dump(data, f)


def main(_):

    texture_mask_path = 'smalst/zebra_data/texture_maps/my_smpl_00781_4_all_template_w_tex_uv_001_mask_small_256.png'
    tex_mask = scipy.misc.imread(texture_mask_path)/255. 

    if not exists(opts.out_path): makedirs(opts.out_path)

    images = sorted(glob(opts.img_path+'*'+opts.img_ext))
    print(str(len(images)))
    N = len(images)
    show = True
    segm_eval = opts.segm_eval
    use_annotations = opts.use_annotations
    mirror = opts.mirror
    if mirror:
        N = 2*N
    if show:
        predictor = pred_util.MeshPredictor(opts)
    tot_pose_err = 0
    annotations_path = opts.anno_path
    alpha = [0.01, 0.02, 0.05, 0.1, 0.15] 
    n_alpha = len(alpha)

    global_rotation = 0 
    mirr_global_rotation = 0 

    err_tot = np.zeros((N, n_alpha))

    shape_f = np.zeros((N,40))

    overlap = np.zeros((N))
    IOU = np.zeros((N))

    idx = 0

    for iidx, img_path in enumerate(images):

        if use_annotations:
            if opts.synthetic:
                anno_path = osp.join(annotations_path, osp.basename(img_path).replace(opts.img_ext, '.pkl'))
                with open(anno_path, 'rb') as f:
                    res = pkl.load(f)
                kp = res['keypoints']
                pose = res['pose']
                trans = res['trans']
                flength = res['flength']
                delta_v = res['delta_v']
                vis = np.ones((kp.shape[0],1),dtype=bool)
            else:
                anno_path = osp.join(annotations_path, osp.basename(img_path).replace(opts.img_ext, '_ferrari-tail-face.mat'))
                res = sio.loadmat(anno_path, squeeze_me=True, struct_as_record=False)
                res = res['annotation']
                kp = res.kp.astype(float)
                invisible = res.invisible
                vis = np.atleast_2d(~invisible.astype(bool)).T
                landmarks = np.hstack((kp, vis))
                names = [str(res.names[i]) for i in range(len(res.names))]
        else:
            kp = None
            kp_pred = None
            vis = None

        if opts.synthetic:
            img, mask_img, kp = get_bbox(img_path, kp=kp)
            mask_img, _ = preprocess_image(img_path.replace('images','bgsub'), img_size=opts.img_size, kp=None, bgval=0, img=mask_img, img_ext=opts.img_ext)

            img, kp = preprocess_image(img_path, img_size=opts.img_size, kp=kp, img=img, bgval=0, img_ext=opts.img_ext)
        else:
            img, kp = preprocess_image(img_path, img_size=opts.img_size, kp=kp, bgval=opts.bgval, img_ext=opts.img_ext)

        code = osp.splitext(osp.basename(img_path))[0]

        # Load the gt mask
        if segm_eval:
            if not opts.synthetic:
                mask_path = osp.join(opts.img_path, 'masks', osp.basename(img_path).replace('jpg','png'))
                mask_img, _ = preprocess_image(mask_path, img_size=opts.img_size, bgval=0, img_ext=opts.img_ext)

        if opts.save_input:
            scipy.misc.imsave(osp.join(opts.out_path, 'proc_'+osp.basename(img_path)), 255*np.transpose(img, (1, 2, 0)))
            if segm_eval:
                scipy.misc.imsave(osp.join(opts.out_path, 'mask_proc_'+osp.basename(img_path)), 255*np.transpose(mask_img, (1, 2, 0)))

        print(idx)
        print(code)

        if opts.test_optimization_results:
            res_file = osp.join(opts.optimization_dir, 'proc_'+code+'_best_res.pkl')
            mask_file = osp.join(opts.optimization_dir, 'proc_'+code+'_best_mask.png')

            print('look for file ' + res_file)
            if not osp.exists(res_file) or not osp.exists(mask_file):
                res_file = osp.join(opts.optimization_dir, 'proc_'+code+'_init_res.pkl')
                mask_file = osp.join(opts.optimization_dir, 'proc_'+code+'_init_mask.png')
            else:
                print('found optimization result')

            with open(res_file, 'rb') as f:
                data = pkl.load(f)
            mask_pred = scipy.misc.imread(mask_file)/255. 

        else:
            batch = {'img': torch.Tensor(np.expand_dims(img, 0))}
            outputs = predictor.predict(batch, rot=global_rotation)
            mask_pred = outputs['mask_pred'].detach().cpu().numpy()[0,:,:]
            shape_f[idx,:] = outputs['shape_f'].detach().cpu().numpy()
    
        if segm_eval: 
            M_gt = mask_img[0,:,:]
            overlap[idx] = np.sum(M_gt*mask_pred)/(np.sum(M_gt)+np.sum(mask_pred))
            IOU[idx] = np.sum(M_gt*mask_pred)/(np.sum(M_gt)+np.sum(mask_pred)-np.sum(M_gt*mask_pred))
            print(overlap[idx])
            print(IOU[idx])


        if use_annotations:
            if opts.test_optimization_results:
                kp_pred = ((data['kp_pred'][0,:,:]+1.)*128).astype(int)
            else:
                kp_pred = ((outputs['kp_pred'].cpu().detach().numpy()[0,:,:]+1.)*128).astype(int)

            kp_diffs = np.linalg.norm(kp[vis[:,0],:]/256. - kp_pred[vis[:,0],:]/256., axis=1)
            for a in range(n_alpha):
                kp_err = np.mean(kp_diffs < alpha[a])
                print(kp_err)
                err_tot[idx, a] = kp_err

        if show and not opts.test_optimization_results:
            renderer = predictor.vis_rend
            renderer.set_light_dir([0, 1, -1], 0.4)
            visualize(img, outputs, predictor.vis_rend, img_path, opts,
                    kp_gt=kp, kp_pred=kp_pred, vis=vis, out_path=opts.name+'_test_%03d' % (idx) +'.png', tex_mask=tex_mask,
                    visualize=opts.visualize)

        if show and opts.test_optimization_results:
            visualize_opt(img, predictor, predictor.vis_rend, data, opts.name+'_opt_%03d' % (idx) +'.png')

        idx += 1

        if mirror:
            kp_m = None
            vis_m = None
            img_m = mirror_image(img)
            # The ground-truth mask only exists when segmentation is evaluated
            if segm_eval:
                M_gt_m = mirror_image(mask_img[0,:,:])
            if opts.save_input:
                scipy.misc.imsave(osp.join(opts.out_path, 'proc_mirr_'+osp.basename(img_path)), 255*np.transpose(img_m, (1, 2, 0)))

            if opts.test_optimization_results:
                res_file = osp.join(opts.optimization_dir, 'proc_mirr_'+code+'_best_res.pkl')
                mask_file = osp.join(opts.optimization_dir, 'proc_mirr_'+code+'_best_mask.png')
                if not osp.exists(res_file) or not osp.exists(mask_file):
                    res_file = osp.join(opts.optimization_dir, 'proc_mirr_'+code+'_init_res.pkl')
                    mask_file = osp.join(opts.optimization_dir, 'proc_mirr_'+code+'_init_mask.png')
                with open(res_file, 'rb') as f:
                    data = pkl.load(f)
                mask_pred = scipy.misc.imread(mask_file)/255. 
            else:
                batch = {'img': torch.Tensor(np.expand_dims(img_m, 0))}
                outputs = predictor.predict(batch, rot=mirr_global_rotation)
                shape_f[idx,:] = outputs['shape_f'].detach().cpu().numpy()
                mask_pred = outputs['mask_pred'].detach().cpu().numpy()[0,:,:]
            if segm_eval:
                # Compare the prediction on the mirrored image with the mirrored ground-truth mask
                overlap[idx] = np.sum(M_gt_m*mask_pred)/(np.sum(M_gt_m)+np.sum(mask_pred))
                IOU[idx] = np.sum(M_gt_m*mask_pred)/(np.sum(M_gt_m)+np.sum(mask_pred)-np.sum(M_gt_m*mask_pred))
                print(overlap[idx])
                print(IOU[idx])

            if use_annotations:
                # Mirror the ground-truth keypoints and their visibility flags (img is 3 x H x W)
                kp_m, vis_m = mirror_keypoints(kp, vis, img.shape[2])
                if opts.test_optimization_results:
                    kp_pred = ((data['kp_pred'][0,:,:]+1.)*128).astype(int)
                else:
                    kp_pred = ((outputs['kp_pred'].cpu().detach().numpy()[0,:,:]+1.)*128).astype(int)

                kp_diffs = np.linalg.norm(kp_m[vis_m[:,0],:]/256. - kp_pred[vis_m[:,0],:]/256., axis=1)
                for a in range(n_alpha):
                    kp_err = np.mean(kp_diffs < alpha[a])
                    print(kp_err)
                    err_tot[idx, a] = kp_err

            if show and not opts.test_optimization_results:
                visualize(img_m, outputs, predictor.vis_rend, img_path, opts,
                    kp_gt=kp_m, kp_pred=kp_pred, vis=vis_m, out_path=opts.name+'_test_%03d' % (idx) +'.png', tex_mask=tex_mask,
                    visualize=opts.visualize)

            if show and opts.test_optimization_results:
                visualize_opt(img_m, predictor, predictor.vis_rend, data, opts.name+'_opt_%03d' % (idx) +'.png')
            idx += 1

    

    if use_annotations:
        print('PCK')
        print(np.mean(err_tot, axis=0))
        print(np.median(err_tot, axis=0))
        print(np.std(err_tot, axis=0))
    if segm_eval:
        print('Overlap')
        print(np.mean(overlap))
        print(np.median(overlap))
        print(np.std(overlap))
        print('IOU')
        print(np.mean(IOU))
        print(np.median(IOU))
        print(np.std(IOU))
    


if __name__ == '__main__':
    opts.batch_size = 1
    app.run(main)
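The evaluation loop in main() reports PCK (the share of visible keypoints whose error, divided by the 256-pixel crop size, falls below each alpha) together with two mask scores. A compact sketch of those formulas, assuming binary or soft masks with values in [0, 1]; the helper names `pck` and `mask_scores` are hypothetical:

```python
import numpy as np

def pck(kp_gt, kp_pred, vis, alphas, img_size=256):
    """Fraction of visible keypoints whose normalized error is strictly below each alpha.

    Distances are divided by img_size, as in main(); vis is a (K,) boolean mask.
    """
    vis = np.asarray(vis, dtype=bool)
    gt = np.asarray(kp_gt, dtype=float)[vis]
    pred = np.asarray(kp_pred, dtype=float)[vis]
    diffs = np.linalg.norm((gt - pred) / img_size, axis=1)
    return np.array([np.mean(diffs < a) for a in alphas])

def mask_scores(m_gt, m_pred):
    """Return (overlap, iou) for two masks.

    overlap = |A*B| / (|A| + |B|), as computed in main(); this equals half the Dice score.
    iou     = |A*B| / (|A| + |B| - |A*B|).
    """
    inter = np.sum(m_gt * m_pred)
    total = np.sum(m_gt) + np.sum(m_pred)
    return inter / total, inter / (total - inter)
```

Note that a perfect prediction yields an overlap of 0.5 under this definition, so the "Overlap" numbers printed above are not on the usual Dice scale.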


================================================
FILE: smal_model/__init__.py
================================================


================================================
FILE: smal_model/batch_lbs.py
================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import numpy as np


def batch_skew(vec, batch_size=None, opts=None):
    """
    vec is N x 3, batch_size is int

    returns N x 3 x 3. Skew_sym version of each matrix.
    """
    if batch_size is None:
        batch_size = vec.shape.as_list()[0]
    col_inds = torch.LongTensor([1, 2, 3, 5, 6, 7])
    indices = torch.reshape(torch.reshape(torch.arange(0, batch_size) * 9, [-1, 1]) + col_inds, [-1, 1])
    updates = torch.reshape(
            torch.stack(
                [
                    -vec[:, 2], vec[:, 1], vec[:, 2], -vec[:, 0], -vec[:, 1],
                    vec[:, 0]
                ],
                dim=1), [-1])
    out_shape = [batch_size * 9]
    res = torch.Tensor(np.zeros(out_shape[0])).cuda(device=opts.gpu_id)
    res[np.array(indices.flatten())] = updates
    res = torch.reshape(res, [batch_size, 3, 3])

    return res



def batch_rodrigues(theta, opts=None):
    """
    Theta is Nx3
    """
    batch_size = theta.shape[0]

    angle = (torch.norm(theta + 1e-8, p=2, dim=1)).unsqueeze(-1)
    r = (torch.div(theta, angle)).unsqueeze(-1)

    angle = angle.unsqueeze(-1)
    cos = torch.cos(angle)
    sin = torch.sin(angle)

    outer = torch.matmul(r, r.transpose(1,2))

    eyes = torch.eye(3).unsqueeze(0).repeat([batch_size, 1, 1]).cuda(device=opts.gpu_id)
    H = batch_skew(r, batch_size=batch_size, opts=opts)
    R = cos * eyes + (1 - cos) * outer + sin * H 

    return R
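batch_rodrigues implements Rodrigues' formula, R = cos(theta) I + (1 - cos(theta)) r r^T + sin(theta) [r]_x, where r is the unit rotation axis and [r]_x is the cross-product matrix that batch_skew assembles. A single-vector NumPy sketch of the same formula (CPU only, no batching; `skew` and `rodrigues` are illustrative names, not repo functions):

```python
import numpy as np

def skew(r):
    """Cross-product matrix [r]_x, with the same entry layout as batch_skew."""
    x, y, z = r
    return np.array([[0.0, -z, y],
                     [z, 0.0, -x],
                     [-y, x, 0.0]])

def rodrigues(theta, eps=1e-8):
    """Axis-angle vector (3,) -> rotation matrix (3, 3)."""
    theta = np.asarray(theta, dtype=float)
    angle = np.linalg.norm(theta + eps)  # same small offset as batch_rodrigues, avoids 0/0
    r = theta / angle                    # unit axis (exactly zero for a zero rotation)
    c, s = np.cos(angle), np.sin(angle)
    return c * np.eye(3) + (1 - c) * np.outer(r, r) + s * skew(r)

# A 90 degree rotation about z maps the x axis onto the y axis
R = rodrigues([0.0, 0.0, np.pi / 2])
print(np.round(R @ np.array([1.0, 0.0, 0.0]), 6))
```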

def batch_lrotmin(theta, opts=None):
    """
    Output of this is used to compute joint-to-pose blend shape mapping.
    Equation 9 in SMPL paper.


    Args:
      theta: `Tensor`, N x 3K vector holding the axis-angle rep of K joints.
            This includes the global rotation (K=24 for SMPL, K=35 for SMAL).

    Returns
      lrotmin : `Tensor`: N x 9(K-1) rotation matrices of the K-1 non-root joints with identity subtracted
                (N x 207 for SMPL, N x 306 for SMAL).
    """
    num_batch = theta.shape[0]
    # Ignore global rotation
    theta = theta[:,3:]

    # batch_rodrigues needs opts to place its tensors on the right GPU
    Rs = batch_rodrigues(torch.reshape(theta, [-1,3]), opts=opts)
    lrotmin = torch.reshape(Rs - torch.eye(3, device=Rs.device), [num_batch, -1])

    return lrotmin

def batch_global_rigid_transformation(Rs, Js, parent, rotate_base = False, opts=None):
    """
    Computes absolute joint locations given pose.

    rotate_base: if True, rotates the global rotation by 90 deg in x axis.
    if False, this is the original SMPL coordinate.

    Args:
      Rs: N x K x 3 x 3 rotation matrices of K joints (K=35 for SMAL)
      Js: N x K x 3, joint locations before posing
      parent: K holding the parent id for each index

    Returns
      new_J : `Tensor`: N x K x 3 location of absolute joints
      A     : `Tensor`: N x K x 4 x 4 relative joint transformations for LBS.
    """
    N = Rs.shape[0]
    if rotate_base:
        print('Flipping the SMPL coordinate frame!!!!')
        rot_x = torch.Tensor([[1, 0, 0], [0, -1, 0], [0, 0, -1]]).to(Rs.device)
        rot_x = torch.reshape(rot_x.repeat(N, 1), [N, 3, 3]) # In tf it was tile
        root_rotation = torch.matmul(Rs[:, 0, :, :], rot_x)
    else:
        root_rotation = Rs[:, 0, :, :]

    # Now Js is N x K x 3 x 1
    Js = Js.unsqueeze(-1)

    def make_A(R, t):
        # Rs is N x 3 x 3, ts is N x 3 x 1
        R_homo = torch.nn.functional.pad(R, (0,0,0,1,0,0))
        t_homo = torch.cat([t, torch.ones([N, 1, 1]).cuda(device=opts.gpu_id)], 1)
        return torch.cat([R_homo, t_homo], 2)

    A0 = make_A(root_rotation, Js[:, 0])
    results = [A0]
    for i in range(1, parent.shape[0]):
        j_here = Js[:, i] - Js[:, parent[i]]
        A_here = make_A(Rs[:, i], j_here)
        res_here = torch.matmul(
            results[parent[i]], A_here)
        results.append(res_here)

    # N x K x 4 x 4
    results = torch.stack(results, dim=1)

    new_J = results[:, :, :3, 3]

    # --- Compute relative A: Skinning is based on
    # how much the bone moved (not the final location of the bone)
    # but (final_bone - init_bone)
    # ---
    Js_w0 = torch.cat([Js, torch.zeros([N, Js.shape[1], 1, 1]).cuda(device=opts.gpu_id)], 2)
    init_bone = torch.matmul(results, Js_w0)
    # Append empty 4 x 3:
    init_bone = torch.nn.functional.pad(init_bone, (3,0,0,0,0,0,0,0))
    A = results - init_bone

    return new_J, A
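batch_global_rigid_transformation walks the kinematic tree: each joint's world transform is its parent's transform times a local [R_i | J_i - J_parent] block, and the skinning transform A subtracts G_k [J_k; 0] so that only the motion relative to the rest pose remains. An unbatched NumPy sketch of that recursion, assuming `parent[0]` is the root and that every parent index precedes its children (helper names are hypothetical):

```python
import numpy as np

def make_A(R, t):
    """4x4 rigid transform from a 3x3 rotation and a 3-vector translation."""
    A = np.eye(4)
    A[:3, :3] = R
    A[:3, 3] = t
    return A

def global_rigid_transformation(Rs, Js, parent):
    """Rs: (K, 3, 3) local rotations, Js: (K, 3) rest-pose joints.

    Returns posed joints (K, 3) and relative transforms A (K, 4, 4) for skinning.
    """
    G = [make_A(Rs[0], Js[0])]
    for i in range(1, len(parent)):
        # chain the parent's world transform with the local bone transform
        G.append(G[parent[i]] @ make_A(Rs[i], Js[i] - Js[parent[i]]))
    G = np.stack(G)
    new_J = G[:, :3, 3]
    # G_k @ [J_k; 0]: where the rest-pose joint lands; remove it from the translation column
    init_bone = np.einsum('kij,kj->ki', G, np.hstack([Js, np.zeros((len(Js), 1))]))
    A = G.copy()
    A[:, :, 3] -= init_bone
    return new_J, A
```

With all rotations set to identity, every A_k is the identity, so rest-pose vertices stay where they are; this is the property the subtraction step is designed to guarantee.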


================================================
FILE: smal_model/smal_basics.py
================================================
import os
import pickle as pkl
import numpy as np
from smpl_webuser.serialization import load_model

model_dir = 'smalst/smpl_models/'

def align_smal_template_to_symmetry_axis(v):
    # These are the indexes of the points that are on the symmetry axis
    I = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 37, 55, 119, 120, 163, 209, 210, 211, 213, 216, 227, 326, 395, 452, 578, 910, 959, 964, 975, 976, 977, 1172, 1175, 1176, 1178, 1194, 1243, 1739, 1796, 1797, 1798, 1799, 1800, 1801, 1802, 1803, 1804, 1805, 1806, 1807, 1808, 1809, 1810, 1811, 1812, 1813, 1814, 1815, 1816, 1817, 1818, 1819, 1820, 1821, 1822, 1823, 1824, 1825, 1826, 1827, 1828, 1829, 1830, 1831, 1832, 1833, 1834, 1835, 1836, 1837, 1838, 1839, 1840, 1842, 1843, 1844, 1845, 1846, 1847, 1848, 1849, 1850, 1851, 1852, 1853, 1854, 1855, 1856, 1857, 1858, 1859, 1860, 1861, 1862, 1863, 1870, 1919, 1960, 1961, 1965, 1967, 2003]

    v = v - np.mean(v)
    y = np.mean(v[I,1])
    v[:,1] = v[:,1] - y
    v[I,1] = 0
    sym_path = os.path.join(model_dir, 'symIdx.pkl')
    with open(sym_path, 'rb') as f:
        symIdx = pkl.load(f)
    left = v[:, 1] < 0
    right = v[:, 1] > 0
    center = v[:, 1] == 0
    v[left[symIdx]] = np.array([1,-1,1])*v[left]

    left_inds = np.where(left)[0]
    right_inds = np.where(right)[0]
    center_inds = np.where(center)[0]

    assert len(left_inds) == len(right_inds), \
        'symmetry alignment failed: %d left vs %d right vertices' % (len(left_inds), len(right_inds))

    return v, left_inds, right_inds, center_inds
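align_smal_template_to_symmetry_axis makes the template exactly left/right symmetric by copying the reflection of each left-side vertex (y < 0) onto its partner listed in symIdx.pkl. The sketch below shows the idea with explicit partner indices rather than the boolean-mask indexing used above, so it is not a line-for-line port; `symmetrize` is a hypothetical name, and `sym_idx[i]` is assumed to hold the index of the mirror partner of vertex i (with `sym_idx[i] == i` for vertices on the symmetry plane):

```python
import numpy as np

def symmetrize(v, sym_idx):
    """Return a copy of the (V, 3) vertices that is mirror-symmetric about the plane y = 0."""
    v = np.array(v, dtype=float)  # work on a copy
    sym_idx = np.asarray(sym_idx)
    left = np.where(v[:, 1] < 0)[0]
    # overwrite each right-side partner with the reflection of its left vertex
    v[sym_idx[left]] = v[left] * np.array([1.0, -1.0, 1.0])
    # vertices that are their own partner lie exactly on the plane
    on_plane = sym_idx == np.arange(len(v))
    v[on_plane, 1] = 0.0
    return v
```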

def load_smal_model(model_name='my_smpl_00781_4_all.pkl'):
    model_path = os.path.join(model_dir, model_name)

    model = load_model(model_path)
    # align_smal_template_to_symmetry_axis also returns the left/right/center index arrays
    v, _, _, _ = align_smal_template_to_symmetry_axis(model.r.copy())

    return v, model.f

def get_horse_template(model_name='my_smpl_00781_4_all.pkl', data_name='my_smpl_data_00781_4_all.pkl'):

    model_path = os.path.join(model_dir, model_name)
    model = load_model(model_path)
    nBetas = len(model.betas.r)
    data_path = os.path.join(model_dir, data_name)
    with open(data_path, 'rb') as f:
        data = pkl.load(f)
    # Select average zebra/horse
    betas = data['cluster_means'][2][:nBetas]
    model.betas[:] = betas
    v = model.r.copy()
    return v




================================================
FILE: smal_model/smal_torch.py
================================================
"""

    PyTorch implementation of the SMAL/SMPL model

"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import torch
from torch.autograd import Variable
import pickle as pkl 
from .batch_lbs import batch_rodrigues, batch_global_rigid_transformation
from .smal_basics import align_smal_template_to_symmetry_axis, get_horse_template

# There are chumpy variables so convert them to numpy.
def undo_chumpy(x):
    return x if isinstance(x, np.ndarray) else x.r

class SMAL(object):
    def __init__(self, pkl_path, opts, dtype=torch.float):
        self.opts = opts
        # -- Load SMPL params --
        with open(pkl_path, 'rb') as f:
            dd = pkl.load(f)

        self.f = dd['f']

        v_template = get_horse_template(model_name='my_smpl_00781_4_all.pkl', data_name='my_smpl_data_00781_4_all.pkl')
        v, self.left_inds, self.right_inds, self.center_inds = align_smal_template_to_symmetry_axis(v_template)

        # Mean template vertices
        self.v_template = Variable(
            torch.Tensor(v).cuda(device=self.opts.gpu_id),
            requires_grad=False)
        # Size of mesh [Number of vertices, 3]
        self.size = [self.v_template.shape[0], 3]
        self.num_betas = dd['shapedirs'].shape[-1]
        # Shape blend shape basis
        
        shapedir = np.reshape(
            undo_chumpy(dd['shapedirs']), [-1, self.num_betas]).T
        self.shapedirs = Variable(
            torch.Tensor(shapedir).cuda(device=self.opts.gpu_id), requires_grad=False)

        # Regressor for joint locations given shape 
        self.J_regressor = Variable(
            torch.Tensor(dd['J_regressor'].T.todense()).cuda(device=self.opts.gpu_id),
            requires_grad=False)

        # Pose blend shape basis
        num_pose_basis = dd['posedirs'].shape[-1]
        
        posedirs = np.reshape(
            undo_chumpy(dd['posedirs']), [-1, num_pose_basis]).T
        self.posedirs = Variable(
            torch.Tensor(posedirs).cuda(device=self.opts.gpu_id), requires_grad=False)

        # indices of parents for each joints
        self.parents = dd['kintree_table'][0].astype(np.int32)

        # LBS weights
        self.weights = Variable(
            torch.Tensor(undo_chumpy(dd['weights'])).cuda(device=self.opts.gpu_id),
            requires_grad=False)

    def __call__(self, beta, theta, trans=None, del_v=None, get_skin=True):

        if self.opts.use_smal_betas:
            nBetas = beta.shape[1]
        else:
            nBetas = 0

        # 1. Add shape blend shapes
        
        if nBetas > 0:
            if del_v is None:
                v_shaped = self.v_template + torch.reshape(torch.matmul(beta, self.shapedirs[:nBetas,:]), [-1, self.size[0], self.size[1]])
            else:
                v_shaped = self.v_template + del_v + torch.reshape(torch.matmul(beta, self.shapedirs[:nBetas,:]), [-1, self.size[0], self.size[1]])
        else:
            if del_v is None:
                v_shaped = self.v_template.unsqueeze(0)
            else:
                v_shaped = self.v_template + del_v 

        # 2. Infer shape-dependent joint locations.
        Jx = torch.matmul(v_shaped[:, :, 0], self.J_regressor)
        Jy = torch.matmul(v_shaped[:, :, 1], self.J_regressor)
        Jz = torch.matmul(v_shaped[:, :, 2], self.J_regressor)
        J = torch.stack([Jx, Jy, Jz], dim=2)

        # 3. Add pose blend shapes
        # N x 35 x 3 x 3
        Rs = torch.reshape( batch_rodrigues(torch.reshape(theta, [-1, 3]), opts=self.opts), [-1, 35, 3, 3])
        # Ignore global rotation: 34 non-root joints x 9 matrix entries = 306 pose features.
        pose_feature = torch.reshape(Rs[:, 1:, :, :] - torch.eye(3).cuda(device=self.opts.gpu_id), [-1, 306])


        
        v_posed = torch.reshape(
            torch.matmul(pose_feature, self.posedirs),
            [-1, self.size[0], self.size[1]]) + v_shaped

        #4. Get the global joint location
        self.J_transformed, A = batch_global_rigid_transformation(Rs, J, self.parents, opts=self.opts)


        # 5. Do skinning:
        num_batch = theta.shape[0]
        
        weights_t = self.weights.repeat([num_batch, 1])
        W = torch.reshape(weights_t, [num_batch, -1, 35])

            
        T = torch.reshape(
            torch.matmul(W, torch.reshape(A, [num_batch, 35, 16])),
                [num_batch, -1, 4, 4])
        v_posed_homo = torch.cat(
                [v_posed, torch.ones([num_batch, v_posed.shape[1], 1]).cuda(device=self.opts.gpu_id)], 2)
        v_homo = torch.matmul(T, v_posed_homo.unsqueeze(-1))

        verts = v_homo[:, :, :3, 0]

        if trans is None:
            trans = torch.zeros((num_batch,3)).cuda(device=self.opts.gpu_id)

        verts = verts + trans[:,None,:]

        # Get joints:
        joint_x = torch.matmul(verts[:, :, 0], self.J_regressor)
        joint_y = torch.matmul(verts[:, :, 1], self.J_regressor)
        joint_z = torch.matmul(verts[:, :, 2], self.J_regressor)
        joints = torch.stack([joint_x, joint_y, joint_z], dim=2)

        if get_skin:
            return verts, joints, Rs
        else:
            return joints
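SMAL.__call__ follows the SMPL pipeline: add shape blend shapes, regress joints with J_regressor, then skin each vertex with a weighted sum of the per-joint 4x4 transforms A. A stripped-down, single-example NumPy sketch of steps 1, 2 and 5 (it leaves out the pose blend shapes and the global translation; `lbs_sketch` is a hypothetical name and the argument shapes are assumptions chosen to match the tensors above):

```python
import numpy as np

def lbs_sketch(v_template, shapedirs, beta, J_regressor, weights, A):
    """Linear blend skinning for one example.

    v_template: (V, 3); shapedirs: (B, 3V); beta: (B,); J_regressor: (V, K);
    weights: (V, K) skinning weights; A: (K, 4, 4) relative joint transforms.
    """
    V = v_template.shape[0]
    v_shaped = v_template + (beta @ shapedirs).reshape(V, 3)  # 1. shape blend shapes
    J = J_regressor.T @ v_shaped                              # 2. joint locations (K, 3)
    T = np.einsum('vk,kij->vij', weights, A)                  # 5. per-vertex blended transforms
    v_homo = np.hstack([v_shaped, np.ones((V, 1))])
    verts = np.einsum('vij,vj->vi', T, v_homo)[:, :3]
    return verts, J

# Two vertices, each bound to its own joint; joint 1 is translated by (0, 2, 0)
A = np.stack([np.eye(4), np.eye(4)])
A[1, 1, 3] = 2.0
verts, J = lbs_sketch(np.array([[0., 0., 0.], [1., 0., 0.]]),
                      np.array([[0., 0., 1., 0., 0., 1.]]), np.array([0.5]),
                      np.eye(2), np.eye(2), A)
```

Pose blend shapes would add `(Rs[1:] - I).reshape(-1) @ posedirs`, reshaped to (V, 3), to `v_shaped` before skinning, as step 3 above does with the 306-dimensional pose feature.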













================================================
FILE: smal_model/symIdx.pkl
================================================
cnumpy.core.multiarray
_reconstruct
p0
(cnumpy
ndarray
p1
(I0
tp2
S'b'
p3
tp4
Rp5
(I1
(I3889
tp6
cnumpy
dtype
p7
(S'i8'
p8
I0
I1
tp9
Rp10
(I3
S'<'
p11
NNNI-1
I-1
I0
tp12
bI00
S'\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x05\x00\x00\x00\x00\x00\x00\x00\x06\x00\x00\x00\x00\x00\x00\x00\x07\x00\x00\x00\x00\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x00\t\x00\x00\x00\x00\x00\x00\x00\n\x00\x00\x00\x00\x00\x00\x00\x0b\x00\x00\x00\x00\x00\x00\x00\x0c\x00\x00\x00\x00\x00\x00\x00\r\x00\x00\x00\x00\x00\x00\x00\x0e\x00\x00\x00\x00\x00\x00\x00\x0f\x00\x00\x00\x00\x00\x00\x00\x10\x00\x00\x00\x00\x00\x00\x00\x11\x00\x00\x00\x00\x00\x00\x00\x12\x00\x00\x00\x00\x00\x00\x00\x13\x00\x00\x00\x00\x00\x00\x00\x14\x00\x00\x00\x00\x00\x00\x00\x15\x00\x00\x00\x00\x00\x00\x00\x16\x00\x00\x00\x00\x00\x00\x00\x17\x00\x00\x00\x00\x00\x00\x00\x18\x00\x00\x00\x00\x00\x00\x00\x19\x00\x00\x00\x00\x00\x00\x00\x1a\x00\x00\x00\x00\x00\x00\x00\x1b\x00\x00\x00\x00\x00\x00\x00\x1c\x00\x00\x00\x00\x00\x00\x00\x1d\x00\x00\x00\x00\x00\x00\x00\x1e\x00\x00\x00\x00\x00\x00\x00\x1f\x00\x00\x00\x00\x00\x00\x00 
\x00\x00\x00\x00\x00\x00\x00\xdc\x07\x00\x00\x00\x00\x00\x00\xdd\x07\x00\x00\x00\x00\x00\x00\xde\x07\x00\x00\x00\x00\x00\x00\xdf\x07\x00\x00\x00\x00\x00\x00%\x00\x00\x00\x00\x00\x00\x00\xe0\x07\x00\x00\x00\x00\x00\x00\xe1\x07\x00\x00\x00\x00\x00\x00\xe2\x07\x00\x00\x00\x00\x00\x00\xe3\x07\x00\x00\x00\x00\x00\x00\xe4\x07\x00\x00\x00\x00\x00\x00\xe5\x07\x00\x00\x00\x00\x00\x00\xe6\x07\x00\x00\x00\x00\x00\x00\xe7\x07\x00\x00\x00\x00\x00\x00\xe8\x07\x00\x00\x00\x00\x00\x00\xe9\x07\x00\x00\x00\x00\x00\x00\xea\x07\x00\x00\x00\x00\x00\x00\xeb\x07\x00\x00\x00\x00\x00\x00\xec\x07\x00\x00\x00\x00\x00\x00\xed\x07\x00\x00\x00\x00\x00\x00\xee\x07\x00\x00\x00\x00\x00\x00\xef\x07\x00\x00\x00\x00\x00\x00\xf0\x07\x00\x00\x00\x00\x00\x007\x00\x00\x00\x00\x00\x00\x00\xf1\x07\x00\x00\x00\x00\x00\x00\xf2\x07\x00\x00\x00\x00\x00\x00\xf3\x07\x00\x00\x00\x00\x00\x00\xf4\x07\x00\x00\x00\x00\x00\x00\xf5\x07\x00\x00\x00\x00\x00\x00\xf6\x07\x00\x00\x00\x00\x00\x00\xf7\x07\x00\x00\x00\x00\x00\x00\xf8\x07\x00\x00\x00\x00\x00\x00\xf9\x07\x00\x00\x00\x00\x00\x00\xfa\x07\x00\x00\x00\x00\x00\x00\xfb\x07\x00\x00\x00\x00\x00\x00\xfc\x07\x00\x00\x00\x00\x00\x00\xfd\x07\x00\x00\x00\x00\x00\x00\xfe\x07\x00\x00\x00\x00\x00\x00\xff\x07\x00\x00\x00\x00\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x01\x08\x00\x00\x00\x00\x00\x00\x02\x08\x00\x00\x00\x00\x00\x00\x03\x08\x00\x00\x00\x00\x00\x00\x04\x08\x00\x00\x00\x00\x00\x00\x05\x08\x00\x00\x00\x00\x00\x00\x06\x08\x00\x00\x00\x00\x00\x00\x07\x08\x00\x00\x00\x00\x00\x00\x08\x08\x00\x00\x00\x00\x00\x00\t\x08\x00\x00\x00\x00\x00\x00\n\x08\x00\x00\x00\x00\x00\x00\x0b\x08\x00\x00\x00\x00\x00\x00\x0c\x08\x00\x00\x00\x00\x00\x00\r\x08\x00\x00\x00\x00\x00\x00\x0e\x08\x00\x00\x00\x00\x00\x00\x0f\x08\x00\x00\x00\x00\x00\x00\x10\x08\x00\x00\x00\x00\x00\x00\x11\x08\x00\x00\x00\x00\x00\x00\x12\x08\x00\x00\x00\x00\x00\x00\x13\x08\x00\x00\x00\x00\x00\x00\x14\x08\x00\x00\x00\x00\x00\x00\x15\x08\x00\x00\x00\x00\x00\x00\x16\x08\x00\x00\x00\x00\x00\x00\x17\x08\x00\x00\x00\x00\x00\x00
\x18\x08\x00\x00\x00\x00\x00\x00\x19\x08\x00\x00\x00\x00\x00\x00\x1a\x08\x00\x00\x00\x00\x00\x00\x1b\x08\x00\x00\x00\x00\x00\x00\x1c\x08\x00\x00\x00\x00\x00\x00\x1d\x08\x00\x00\x00\x00\x00\x00\x1e\x08\x00\x00\x00\x00\x00\x00\x1f\x08\x00\x00\x00\x00\x00\x00 \x08\x00\x00\x00\x00\x00\x00!\x08\x00\x00\x00\x00\x00\x00"\x08\x00\x00\x00\x00\x00\x00#\x08\x00\x00\x00\x00\x00\x00$\x08\x00\x00\x00\x00\x00\x00%\x08\x00\x00\x00\x00\x00\x00&\x08\x00\x00\x00\x00\x00\x00\'\x08\x00\x00\x00\x00\x00\x00(\x08\x00\x00\x00\x00\x00\x00)\x08\x00\x00\x00\x00\x00\x00*\x08\x00\x00\x00\x00\x00\x00+\x08\x00\x00\x00\x00\x00\x00,\x08\x00\x00\x00\x00\x00\x00-\x08\x00\x00\x00\x00\x00\x00.\x08\x00\x00\x00\x00\x00\x00/\x08\x00\x00\x00\x00\x00\x00w\x00\x00\x00\x00\x00\x00\x00x\x00\x00\x00\x00\x00\x00\x000\x08\x00\x00\x00\x00\x00\x001\x08\x00\x00\x00\x00\x00\x002\x08\x00\x00\x00\x00\x00\x003\x08\x00\x00\x00\x00\x00\x004\x08\x00\x00\x00\x00\x00\x005\x08\x00\x00\x00\x00\x00\x006\x08\x00\x00\x00\x00\x00\x007\x08\x00\x00\x00\x00\x00\x008\x08\x00\x00\x00\x00\x00\x009\x08\x00\x00\x00\x00\x00\x00:\x08\x00\x00\x00\x00\x00\x00;\x08\x00\x00\x00\x00\x00\x00<\x08\x00\x00\x00\x00\x00\x00=\x08\x00\x00\x00\x00\x00\x00>\x08\x00\x00\x00\x00\x00\x00?\x08\x00\x00\x00\x00\x00\x00@\x08\x00\x00\x00\x00\x00\x00A\x08\x00\x00\x00\x00\x00\x00B\x08\x00\x00\x00\x00\x00\x00C\x08\x00\x00\x00\x00\x00\x00D\x08\x00\x00\x00\x00\x00\x00E\x08\x00\x00\x00\x00\x00\x00F\x08\x00\x00\x00\x00\x00\x00G\x08\x00\x00\x00\x00\x00\x00H\x08\x00\x00\x00\x00\x00\x00I\x08\x00\x00\x00\x00\x00\x00J\x08\x00\x00\x00\x00\x00\x00K\x08\x00\x00\x00\x00\x00\x00L\x08\x00\x00\x00\x00\x00\x00M\x08\x00\x00\x00\x00\x00\x00N\x08\x00\x00\x00\x00\x00\x00O\x08\x00\x00\x00\x00\x00\x00P\x08\x00\x00\x00\x00\x00\x00Q\x08\x00\x00\x00\x00\x00\x00R\x08\x00\x00\x00\x00\x00\x00S\x08\x00\x00\x00\x00\x00\x00T\x08\x00\x00\x00\x00\x00\x00U\x08\x00\x00\x00\x00\x00\x00V\x08\x00\x00\x00\x00\x00\x00W\x08\x00\x00\x00\x00\x00\x00X\x08\x00\x00\x00\x00\x00\x00Y\x08\x00\x00\x00\x00\x00\x00\xa
3\x00\x00\x00\x00\x00\x00\x00Z\x08\x00\x00\x00\x00\x00\x00[\x08\x00\x00\x00\x00\x00\x00\\\x08\x00\x00\x00\x00\x00\x00]\x08\x00\x00\x00\x00\x00\x00^\x08\x00\x00\x00\x00\x00\x00_\x08\x00\x00\x00\x00\x00\x00`\x08\x00\x00\x00\x00\x00\x00a\x08\x00\x00\x00\x00\x00\x00b\x08\x00\x00\x00\x00\x00\x00c\x08\x00\x00\x00\x00\x00\x00d\x08\x00\x00\x00\x00\x00\x00e\x08\x00\x00\x00\x00\x00\x00f\x08\x00\x00\x00\x00\x00\x00g\x08\x00\x00\x00\x00\x00\x00h\x08\x00\x00\x00\x00\x00\x00i\x08\x00\x00\x00\x00\x00\x00j\x08\x00\x00\x00\x00\x00\x00k\x08\x00\x00\x00\x00\x00\x00l\x08\x00\x00\x00\x00\x00\x00m\x08\x00\x00\x00\x00\x00\x00n\x08\x00\x00\x00\x00\x00\x00o\x08\x00\x00\x00\x00\x00\x00p\x08\x00\x00\x00\x00\x00\x00q\x08\x00\x00\x00\x00\x00\x00r\x08\x00\x00\x00\x00\x00\x00s\x08\x00\x00\x00\x00\x00\x00t\x08\x00\x00\x00\x00\x00\x00u\x08\x00\x00\x00\x00\x00\x00v\x08\x00\x00\x00\x00\x00\x00w\x08\x00\x00\x00\x00\x00\x00x\x08\x00\x00\x00\x00\x00\x00y\x08\x00\x00\x00\x00\x00\x00z\x08\x00\x00\x00\x00\x00\x00{\x08\x00\x00\x00\x00\x00\x00|\x08\x00\x00\x00\x00\x00\x00}\x08\x00\x00\x00\x00\x00\x00~\x08\x00\x00\x00\x00\x00\x00\x7f\x08\x00\x00\x00\x00\x00\x00\x80\x08\x00\x00\x00\x00\x00\x00\x81\x08\x00\x00\x00\x00\x00\x00\x82\x08\x00\x00\x00\x00\x00\x00\x83\x08\x00\x00\x00\x00\x00\x00\x84\x08\x00\x00\x00\x00\x00\x00\x85\x08\x00\x00\x00\x00\x00\x00\x86\x08\x00\x00\x00\x00\x00\x00\xd1\x00\x00\x00\x00\x00\x00\x00\xd2\x00\x00\x00\x00\x00\x00\x00\xd3\x00\x00\x00\x00\x00\x00\x00\x87\x08\x00\x00\x00\x00\x00\x00\xd5\x00\x00\x00\x00\x00\x00\x00\x88\x08\x00\x00\x00\x00\x00\x00\x89\x08\x00\x00\x00\x00\x00\x00\xd8\x00\x00\x00\x00\x00\x00\x00\x8a\x08\x00\x00\x00\x00\x00\x00\x8b\x08\x00\x00\x00\x00\x00\x00\x8c\x08\x00\x00\x00\x00\x00\x00\x8d\x08\x00\x00\x00\x00\x00\x00\x8e\x08\x00\x00\x00\x00\x00\x00\x8f\x08\x00\x00\x00\x00\x00\x00\x90\x08\x00\x00\x00\x00\x00\x00\x91\x08\x00\x00\x00\x00\x00\x00\x92\x08\x00\x00\x00\x00\x00\x00\x93\x08\x00\x00\x00\x00\x00\x00\xe3\x00\x00\x00\x00\x00\x00\x00\x94\x08\x00\x00\x00\x00\x00\x00\
x95\x08\x00\x00\x00\x00\x00\x00\x96\x08\x00\x00\x00\x00\x00\x00\x97\x08\x00\x00\x00\x00\x00\x00\x98\x08\x00\x00\x00\x00\x00\x00\x99\x08\x00\x00\x00\x00\x00\x00\x9a\x08\x00\x00\x00\x00\x00\x00\x9b\x08\x00\x00\x00\x00\x00\x00\x9c\x08\x00\x00\x00\x00\x00\x00\x9d\x08\x00\x00\x00\x00\x00\x00\x9e\x08\x00\x00\x00\x00\x00\x00\x9f\x08\x00\x00\x00\x00\x00\x00\xa0\x08\x00\x00\x00\x00\x00\x00\xa1\x08\x00\x00\x00\x00\x00\x00\xa2\x08\x00\x00\x00\x00\x00\x00\xa3\x08\x00\x00\x00\x00\x00\x00\xa4\x08\x00\x00\x00\x00\x00\x00\xa5\x08\x00\x00\x00\x00\x00\x00\xa6\x08\x00\x00\x00\x00\x00\x00\xa7\x08\x00\x00\x00\x00\x00\x00\xa8\x08\x00\x00\x00\x00\x00\x00\xa9\x08\x00\x00\x00\x00\x00\x00\xaa\x08\x00\x00\x00\x00\x00\x00\xab\x08\x00\x00\x00\x00\x00\x00\xac\x08\x00\x00\x00\x00\x00\x00\xad\x08\x00\x00\x00\x00\x00\x00\xae\x08\x00\x00\x00\x00\x00\x00\xaf\x08\x00\x00\x00\x00\x00\x00\xb0\x08\x00\x00\x00\x00\x00\x00\xb1\x08\x00\x00\x00\x00\x00\x00\xb2\x08\x00\x00\x00\x00\x00\x00\xb3\x08\x00\x00\x00\x00\x00\x00\xb4\x08\x00\x00\x00\x00\x00\x00\xb5\x08\x00\x00\x00\x00\x00\x00\xb6\x08\x00\x00\x00\x00\x00\x00\xb7\x08\x00\x00\x00\x00\x00\x00\xb8\x08\x00\x00\x00\x00\x00\x00\xb9\x08\x00\x00\x00\x00\x00\x00\xba\x08\x00\x00\x00\x00\x00\x00\xbb\x08\x00\x00\x00\x00\x00\x00\xbc\x08\x00\x00\x00\x00\x00\x00\xbd\x08\x00\x00\x00\x00\x00\x00\xbe\x08\x00\x00\x00\x00\x00\x00\xbf\x08\x00\x00\x00\x00\x00\x00\xc0\x08\x00\x00\x00\x00\x00\x00\xc1\x08\x00\x00\x00\x00\x00\x00\xc2\x08\x00\x00\x00\x00\x00\x00\xc3\x08\x00\x00\x00\x00\x00\x00\xc4\x08\x00\x00\x00\x00\x00\x00\xc5\x08\x00\x00\x00\x00\x00\x00\xc6\x08\x00\x00\x00\x00\x00\x00\xc7\x08\x00\x00\x00\x00\x00\x00\xc8\x08\x00\x00\x00\x00\x00\x00\xc9\x08\x00\x00\x00\x00\x00\x00\xca\x08\x00\x00\x00\x00\x00\x00\xcb\x08\x00\x00\x00\x00\x00\x00\xcc\x08\x00\x00\x00\x00\x00\x00\xcd\x08\x00\x00\x00\x00\x00\x00\xce\x08\x00\x00\x00\x00\x00\x00\xcf\x08\x00\x00\x00\x00\x00\x00\xd0\x08\x00\x00\x00\x00\x00\x00\xd1\x08\x00\x00\x00\x00\x00\x00\xd2\x08\x00\x00\x00\x00\x00\x00\xd3\x08\x00\x00\
x00\x00\x00\x00\xd4\x08\x00\x00\x00\x00\x00\x00\xd5\x08\x00\x00\x00\x00\x00\x00\xd6\x08\x00\x00\x00\x00\x00\x00\xd7\x08\x00\x00\x00\x00\x00\x00\xd8\x08\x00\x00\x00\x00\x00\x00\xd9\x08\x00\x00\x00\x00\x00\x00\xda\x08\x00\x00\x00\x00\x00\x00\xdb\x08\x00\x00\x00\x00\x00\x00\xdc\x08\x00\x00\x00\x00\x00\x00\xdd\x08\x00\x00\x00\x00\x00\x00\xde\x08\x00\x00\x00\x00\x00\x00\xdf\x08\x00\x00\x00\x00\x00\x00\xe0\x08\x00\x00\x00\x00\x00\x00\xe1\x08\x00\x00\x00\x00\x00\x00\xe2\x08\x00\x00\x00\x00\x00\x00\xe3\x08\x00\x00\x00\x00\x00\x00\xe4\x08\x00\x00\x00\x00\x00\x00\xe5\x08\x00\x00\x00\x00\x00\x00\xe6\x08\x00\x00\x00\x00\x00\x00\xe7\x08\x00\x00\x00\x00\x00\x00\xe8\x08\x00\x00\x00\x00\x00\x00\xe9\x08\x00\x00\x00\x00\x00\x00\xea\x08\x00\x00\x00\x00\x00\x00\xeb\x08\x00\x00\x00\x00\x00\x00\xec\x08\x00\x00\x00\x00\x00\x00\xed\x08\x00\x00\x00\x00\x00\x00\xee\x08\x00\x00\x00\x00\x00\x00\xef\x08\x00\x00\x00\x00\x00\x00\xf0\x08\x00\x00\x00\x00\x00\x00\xf1\x08\x00\x00\x00\x00\x00\x00\xf2\x08\x00\x00\x00\x00\x00\x00\xf3\x08\x00\x00\x00\x00\x00\x00\xf4\x08\x00\x00\x00\x00\x00\x00\xf5\x08\x00\x00\x00\x00\x00\x00F\x01\x00\x00\x00\x00\x00\x00\xf6\x08\x00\x00\x00\x00\x00\x00\xf7\x08\x00\x00\x00\x00\x00\x00\xf8\x08\x00\x00\x00\x00\x00\x00\xf9\x08\x00\x00\x00\x00\x00\x00\xfa\x08\x00\x00\x00\x00\x00\x00\xfb\x08\x00\x00\x00\x00\x00\x00\xfc\x08\x00\x00\x00\x00\x00\x00\xfd\x08\x00\x00\x00\x00\x00\x00\xfe\x08\x00\x00\x00\x00\x00\x00\xff\x08\x00\x00\x00\x00\x00\x00\x00\t\x00\x00\x00\x00\x00\x00\x01\t\x00\x00\x00\x00\x00\x00\x02\t\x00\x00\x00\x00\x00\x00\x03\t\x00\x00\x00\x00\x00\x00\x04\t\x00\x00\x00\x00\x00\x00\x05\t\x00\x00\x00\x00\x00\x00\x06\t\x00\x00\x00\x00\x00\x00\x07\t\x00\x00\x00\x00\x00\x00\x08\t\x00\x00\x00\x00\x00\x00\t\t\x00\x00\x00\x00\x00\x00\n\t\x00\x00\x00\x00\x00\x00\x0b\t\x00\x00\x00\x00\x00\x00\x0c\t\x00\x00\x00\x00\x00\x00\r\t\x00\x00\x00\x00\x00\x00\x0e\t\x00\x00\x00\x00\x00\x00\x0f\t\x00\x00\x00\x00\x00\x00\x10\t\x00\x00\x00\x00\x00\x00\x11\t\x00\x00\x00\x00\x00\x00\x12\t\x00\x00
\x00\x00\x00\x00\x13\t\x00\x00\x00\x00\x00\x00\x14\t\x00\x00\x00\x00\x00\x00\x15\t\x00\x00\x00\x00\x00\x00\x16\t\x00\x00\x00\x00\x00\x00\x17\t\x00\x00\x00\x00\x00\x00\x18\t\x00\x00\x00\x00\x00\x00\x19\t\x00\x00\x00\x00\x00\x00\x1a\t\x00\x00\x00\x00\x00\x00\x1b\t\x00\x00\x00\x00\x00\x00\x1c\t\x00\x00\x00\x00\x00\x00\x1d\t\x00\x00\x00\x00\x00\x00\x1e\t\x00\x00\x00\x00\x00\x00\x1f\t\x00\x00\x00\x00\x00\x00 \t\x00\x00\x00\x00\x00\x00!\t\x00\x00\x00\x00\x00\x00"\t\x00\x00\x00\x00\x00\x00#\t\x00\x00\x00\x00\x00\x00$\t\x00\x00\x00\x00\x00\x00%\t\x00\x00\x00\x00\x00\x00&\t\x00\x00\x00\x00\x00\x00\'\t\x00\x00\x00\x00\x00\x00(\t\x00\x00\x00\x00\x00\x00)\t\x00\x00\x00\x00\x00\x00*\t\x00\x00\x00\x00\x00\x00+\t\x00\x00\x00\x00\x00\x00,\t\x00\x00\x00\x00\x00\x00-\t\x00\x00\x00\x00\x00\x00.\t\x00\x00\x00\x00\x00\x00/\t\x00\x00\x00\x00\x00\x000\t\x00\x00\x00\x00\x00\x001\t\x00\x00\x00\x00\x00\x002\t\x00\x00\x00\x00\x00\x003\t\x00\x00\x00\x00\x00\x004\t\x00\x00\x00\x00\x00\x005\t\x00\x00\x00\x00\x00\x006\t\x00\x00\x00\x00\x00\x007\t\x00\x00\x00\x00\x00\x008\t\x00\x00\x00\x00\x00\x009\t\x00\x00\x00\x00\x00\x00\x8b\x01\x00\x00\x00\x00\x00\x00:\t\x00\x00\x00\x00\x00\x00;\t\x00\x00\x00\x00\x00\x00<\t\x00\x00\x00\x00\x00\x00=\t\x00\x00\x00\x00\x00\x00>\t\x00\x00\x00\x00\x00\x00?\t\x00\x00\x00\x00\x00\x00@\t\x00\x00\x00\x00\x00\x00A\t\x00\x00\x00\x00\x00\x00B\t\x00\x00\x00\x00\x00\x00C\t\x00\x00\x00\x00\x00\x00D\t\x00\x00\x00\x00\x00\x00E\t\x00\x00\x00\x00\x00\x00F\t\x00\x00\x00\x00\x00\x00G\t\x00\x00\x00\x00\x00\x00H\t\x00\x00\x00\x00\x00\x00I\t\x00\x00\x00\x00\x00\x00J\t\x00\x00\x00\x00\x00\x00K\t\x00\x00\x00\x00\x00\x00L\t\x00\x00\x00\x00\x00\x00M\t\x00\x00\x00\x00\x00\x00N\t\x00\x00\x00\x00\x00\x00O\t\x00\x00\x00\x00\x00\x00P\t\x00\x00\x00\x00\x00\x00Q\t\x00\x00\x00\x00\x00\x00R\t\x00\x00\x00\x00\x00\x00S\t\x00\x00\x00\x00\x00\x00T\t\x00\x00\x00\x00\x00\x00U\t\x00\x00\x00\x00\x00\x00V\t\x00\x00\x00\x00\x00\x00W\t\x00\x00\x00\x00\x00\x00X\t\x00\x00\x00\x00\x00\x00Y\t\x00\x00\x00\x00\x0
0\x00Z\t\x00\x00\x00\x00\x00\x00[\t\x00\x00\x00\x00\x00\x00\\\t\x00\x00\x00\x00\x00\x00]\t\x00\x00\x00\x00\x00\x00^\t\x00\x00\x00\x00\x00\x00_\t\x00\x00\x00\x00\x00\x00`\t\x00\x00\x00\x00\x00\x00a\t\x00\x00\x00\x00\x00\x00b\t\x00\x00\x00\x00\x00\x00c\t\x00\x00\x00\x00\x00\x00d\t\x00\x00\x00\x00\x00\x00e\t\x00\x00\x00\x00\x00\x00f\t\x00\x00\x00\x00\x00\x00g\t\x00\x00\x00\x00\x00\x00h\t\x00\x00\x00\x00\x00\x00i\t\x00\x00\x00\x00\x00\x00j\t\x00\x00\x00\x00\x00\x00k\t\x00\x00\x00\x00\x00\x00l\t\x00\x00\x00\x00\x00\x00m\t\x00\x00\x00\x00\x00\x00n\t\x00\x00\x00\x00\x00\x00o\t\x00\x00\x00\x00\x00\x00p\t\x00\x00\x00\x00\x00\x00q\t\x00\x00\x00\x00\x00\x00\xc4\x01\x00\x00\x00\x00\x00\x00r\t\x00\x00\x00\x00\x00\x00s\t\x00\x00\x00\x00\x00\x00t\t\x00\x00\x00\x00\x00\x00u\t\x00\x00\x00\x00\x00\x00v\t\x00\x00\x00\x00\x00\x00w\t\x00\x00\x00\x00\x00\x00x\t\x00\x00\x00\x00\x00\x00y\t\x00\x00\x00\x00\x00\x00z\t\x00\x00\x00\x00\x00\x00{\t\x00\x00\x00\x00\x00\x00|\t\x00\x00\x00\x00\x00\x00}\t\x00\x00\x00\x00\x00\x00~\t\x00\x00\x00\x00\x00\x00\x7f\t\x00\x00\x00\x00\x00\x00\x80\t\x00\x00\x00\x00\x00\x00\x81\t\x00\x00\x00\x00\x00\x00\x82\t\x00\x00\x00\x00\x00\x00\x83\t\x00\x00\x00\x00\x00\x00\x84\t\x00\x00\x00\x00\x00\x00\x85\t\x00\x00\x00\x00\x00\x00\x86\t\x00\x00\x00\x00\x00\x00\x87\t\x00\x00\x00\x00\x00\x00\x88\t\x00\x00\x00\x00\x00\x00\x89\t\x00\x00\x00\x00\x00\x00\x8a\t\x00\x00\x00\x00\x00\x00\x8b\t\x00\x00\x00\x00\x00\x00\x8c\t\x00\x00\x00\x00\x00\x00\x8d\t\x00\x00\x00\x00\x00\x00\x8e\t\x00\x00\x00\x00\x00\x00\x8f\t\x00\x00\x00\x00\x00\x00\x90\t\x00\x00\x00\x00\x00\x00\x91\t\x00\x00\x00\x00\x00\x00\x92\t\x00\x00\x00\x00\x00\x00\x93\t\x00\x00\x00\x00\x00\x00\x94\t\x00\x00\x00\x00\x00\x00\x95\t\x00\x00\x00\x00\x00\x00\x96\t\x00\x00\x00\x00\x00\x00\x97\t\x00\x00\x00\x00\x00\x00\x98\t\x00\x00\x00\x00\x00\x00\x99\t\x00\x00\x00\x00\x00\x00\x9a\t\x00\x00\x00\x00\x00\x00\x9b\t\x00\x00\x00\x00\x00\x00\x9c\t\x00\x00\x00\x00\x00\x00\x9d\t\x00\x00\x00\x00\x00\x00\x9e\t\x00\x00\x00\x00\x00\x00\x9
f\t\x00\x00\x00\x00\x00\x00\xa0\t\x00\x00\x00\x00\x00\x00\xa1\t\x00\x00\x00\x00\x00\x00\xa2\t\x00\x00\x00\x00\x00\x00\xa3\t\x00\x00\x00\x00\x00\x00\xa4\t\x00\x00\x00\x00\x00\x00\xa5\t\x00\x00\x00\x00\x00\x00\xa6\t\x00\x00\x00\x00\x00\x00\xa7\t\x00\x00\x00\x00\x00\x00\xa8\t\x00\x00\x00\x00\x00\x00\xa9\t\x00\x00\x00\x00\x00\x00\xaa\t\x00\x00\x00\x00\x00\x00\xab\t\x00\x00\x00\x00\x00\x00\xac\t\x00\x00\x00\x00\x00\x00\xad\t\x00\x00\x00\x00\x00\x00\xae\t\x00\x00\x00\x00\x00\x00\xaf\t\x00\x00\x00\x00\x00\x00\xb0\t\x00\x00\x00\x00\x00\x00\xb1\t\x00\x00\x00\x00\x00\x00\xb2\t\x00\x00\x00\x00\x00\x00\xb3\t\x00\x00\x00\x00\x00\x00\xb4\t\x00\x00\x00\x00\x00\x00\xb5\t\x00\x00\x00\x00\x00\x00\xb6\t\x00\x00\x00\x00\x00\x00\xb7\t\x00\x00\x00\x00\x00\x00\xb8\t\x00\x00\x00\x00\x00\x00\xb9\t\x00\x00\x00\x00\x00\x00\xba\t\x00\x00\x00\x00\x00\x00\xbb\t\x00\x00\x00\x00\x00\x00\xbc\t\x00\x00\x00\x00\x00\x00\xbd\t\x00\x00\x00\x00\x00\x00\xbe\t\x00\x00\x00\x00\x00\x00\xbf\t\x00\x00\x00\x00\x00\x00\xc0\t\x00\x00\x00\x00\x00\x00\xc1\t\x00\x00\x00\x00\x00\x00\xc2\t\x00\x00\x00\x00\x00\x00\xc3\t\x00\x00\x00\x00\x00\x00\xc4\t\x00\x00\x00\x00\x00\x00\xc5\t\x00\x00\x00\x00\x00\x00\xc6\t\x00\x00\x00\x00\x00\x00\xc7\t\x00\x00\x00\x00\x00\x00\xc8\t\x00\x00\x00\x00\x00\x00\xc9\t\x00\x00\x00\x00\x00\x00\xca\t\x00\x00\x00\x00\x00\x00\xcb\t\x00\x00\x00\x00\x00\x00\xcc\t\x00\x00\x00\x00\x00\x00\xcd\t\x00\x00\x00\x00\x00\x00\xce\t\x00\x00\x00\x00\x00\x00\xcf\t\x00\x00\x00\x00\x00\x00\xd0\t\x00\x00\x00\x00\x00\x00\xd1\t\x00\x00\x00\x00\x00\x00\xd2\t\x00\x00\x00\x00\x00\x00\xd3\t\x00\x00\x00\x00\x00\x00\xd4\t\x00\x00\x00\x00\x00\x00\xd5\t\x00\x00\x00\x00\x00\x00\xd6\t\x00\x00\x00\x00\x00\x00\xd7\t\x00\x00\x00\x00\x00\x00\xd8\t\x00\x00\x00\x00\x00\x00\xd9\t\x00\x00\x00\x00\x00\x00\xda\t\x00\x00\x00\x00\x00\x00\xdb\t\x00\x00\x00\x00\x00\x00\xdc\t\x00\x00\x00\x00\x00\x00\xdd\t\x00\x00\x00\x00\x00\x00\xde\t\
SYMBOL INDEX (289 symbols across 24 files)

FILE: data/smal_base.py
  class BaseDataset (line 59) | class BaseDataset(Dataset):
    method __init__ (line 64) | def __init__(self, opts, filter_key=None):
    method forward_img (line 84) | def forward_img(self, index):
    method normalize_kp (line 244) | def normalize_kp(self, kp, img_h, img_w, update_vis=False):
    method get_camera_projection_matrix (line 260) | def get_camera_projection_matrix(self, f, c):
    method my_project_points (line 272) | def my_project_points(self, ptsw, P):
    method my_anti_project_points (line 281) | def my_anti_project_points(self, ptsi, P):
    method get_model_trans_for_cropped_image (line 294) | def get_model_trans_for_cropped_image(self, trans, bbox, flength, img_...
    method crop_image (line 326) | def crop_image(self, img, mask, bbox, kp, vis, camera_params, model_tr...
    method scale_image (line 365) | def scale_image(self, img, mask, kp, vis, camera_params, occ_map, orig...
    method __len__ (line 407) | def __len__(self):
    method __getitem__ (line 410) | def __getitem__(self, index):
  function base_loader (line 459) | def base_loader(d_set_func, batch_size, opts, filter_key=None, shuffle=T...

FILE: data/zebra.py
  class ZebraDataset (line 47) | class ZebraDataset(base_data.BaseDataset):
    method __init__ (line 52) | def __init__(self, opts, filter_key=None, filter_name=None):
  function data_loader (line 128) | def data_loader(opts, shuffle=True, filter_name=None):
  function kp_data_loader (line 132) | def kp_data_loader(batch_size, opts):
  function mask_data_loader (line 136) | def mask_data_loader(batch_size, opts):
  function texture_map_data_loader (line 139) | def texture_map_data_loader(batch_size, opts):

FILE: experiments/smal_shape.py
  class ShapeTrainer (line 90) | class ShapeTrainer(train_utils.Trainer):
    method define_model (line 91) | def define_model(self):
    method init_dataset (line 144) | def init_dataset(self):
    method define_criterion (line 156) | def define_criterion(self):
    method set_optimization_input (line 187) | def set_optimization_input(self):
    method set_optimization_variables (line 195) | def set_optimization_variables(self):
    method set_input (line 211) | def set_input(self, batch):
    method forward (line 298) | def forward(self, opts_scale=None, opts_pose=None, opts_trans=None, op...
    method get_current_visuals (line 517) | def get_current_visuals(self):
    method get_current_points (line 574) | def get_current_points(self):
    method get_current_scalars (line 580) | def get_current_scalars(self):
  function main (line 613) | def main(_):

FILE: nnutils/geom_utils.py
  function sample_textures (line 10) | def sample_textures(texture_flow, images):
  function perspective_proj_withz (line 30) | def perspective_proj_withz(X, cam, offset_z=0, cuda_device=0,norm_f=1., ...

FILE: nnutils/loss_utils.py
  function texture_dt_loss (line 14) | def texture_dt_loss(texture_flow, dist_transf, vis_rend=None, cams=None,...
  function texture_loss (line 57) | def texture_loss(img_pred, img_gt, mask_pred, mask_gt):
  function uv_flow_loss (line 71) | def uv_flow_loss(uv_flow_pred, uv_flow_gt_w_mask):
  function mask_loss (line 82) | def mask_loss(mask_pred, mask_gt):
  function delta_v_loss (line 91) | def delta_v_loss(delta_v, delta_v_gt):
  function texture_map_loss (line 95) | def texture_map_loss(texture_map_pred, texture_map_gt, texture_map_mask,...
  function camera_loss (line 118) | def camera_loss(cam_pred, cam_gt, margin, normalized):
  function model_trans_loss (line 132) | def model_trans_loss(trans_pred, trans_gt):
  function model_pose_loss (line 140) | def model_pose_loss(pose_pred, pose_gt, opts):
  function betas_loss (line 168) | def betas_loss(betas_pred, betas_gt=None, prec=None):
  function hinge_loss (line 183) | def hinge_loss(loss, margin):
  function kp_l2_loss (line 188) | def kp_l2_loss(kp_pred, kp_gt):
  function keypoints_2D_loss (line 201) | def keypoints_2D_loss(kp_pred, kp_gt):
  function MSE_texture_loss (line 208) | def MSE_texture_loss(img_pred, img_gt, mask_pred, mask_gt, background_im...
  class PerceptualTextureLoss (line 218) | class PerceptualTextureLoss(object):
    method __init__ (line 219) | def __init__(self):
    method __call__ (line 223) | def __call__(self, img_pred, img_gt, mask_pred, mask_gt, background_im...

FILE: nnutils/net_blocks.py
  class Flatten (line 11) | class Flatten(nn.Module):
    method forward (line 12) | def forward(self, x):
  class Unsqueeze (line 15) | class Unsqueeze(nn.Module):
    method __init__ (line 16) | def __init__(self, dim):
    method forward (line 20) | def forward(self, x):
  function fc (line 24) | def fc(norm_type, nc_inp, nc_out):
  function fc_stack (line 37) | def fc_stack(nc_inp, nc_out, nlayers, norm_type='batch'):
  function fc_stack_dropout (line 46) | def fc_stack_dropout(nc_inp, nc_out, nlayers):
  function conv2d (line 72) | def conv2d(norm_type, in_planes, out_planes, kernel_size=3, stride=1, nu...
  function deconv2d (line 92) | def deconv2d(in_planes, out_planes):
  function upconv2d (line 99) | def upconv2d(in_planes, out_planes, mode='bilinear'):
  function decoder2d (line 111) | def decoder2d(nlayers, nz_shape, nc_input, norm_type='batch', nc_final=1...
  function conv3d (line 150) | def conv3d(norm_type, in_planes, out_planes, kernel_size=3, stride=1, nu...
  function deconv3d (line 170) | def deconv3d(norm_type, in_planes, out_planes, num_groups=2):
  function encoder3d (line 191) | def encoder3d(nlayers, norm_type='batch', nc_input=1, nc_max=128, nc_l1=...
  function decoder3d (line 221) | def decoder3d(nlayers, nz_shape, nc_input, norm_type='batch', nc_final=1...
  function net_init (line 254) | def net_init(net):
  function bilinear_init (line 284) | def bilinear_init(kernel_size=4):

FILE: nnutils/nmr.py
  function convert_as (line 19) | def convert_as(src, trg):
  class NMR (line 29) | class NMR(object):
    method __init__ (line 30) | def __init__(self):
    method to_gpu (line 35) | def to_gpu(self, device=0):
    method forward_mask (line 39) | def forward_mask(self, vertices, faces):
    method backward_mask (line 55) | def backward_mask(self, grad_masks):
    method forward_img (line 66) | def forward_img(self, vertices, faces, textures):
    method backward_img (line 84) | def backward_img(self, grad_images):
  class Render (line 100) | class Render(torch.autograd.Function):
    method __init__ (line 102) | def __init__(self, renderer):
    method forward (line 106) | def forward(self, vertices, faces, textures=None):
    method backward (line 122) | def backward(self, grad_out):
  class NeuralRenderer (line 140) | class NeuralRenderer(torch.nn.Module):
    method __init__ (line 146) | def __init__(self, img_size=256, proj_type='perspective', norm_f=1., n...
    method ambient_light_only (line 175) | def ambient_light_only(self):
    method directional_light_only (line 180) | def directional_light_only(self):
    method set_bgcolor (line 186) | def set_bgcolor(self, color):
    method project_points (line 189) | def project_points(self, verts, cams):
    method forward (line 193) | def forward(self, vertices, faces, cams, textures=None):

FILE: nnutils/perceptual_loss.py
  class PerceptualLoss (line 14) | class PerceptualLoss(object):
    method __init__ (line 15) | def __init__(self, model='net', net='alex', use_gpu=True):
    method __call__ (line 21) | def __call__(self, pred, target, normalize=True):

FILE: nnutils/smal_mesh_eval.py
  function preprocess_image (line 28) | def preprocess_image(img_path, img_size=256, kp=None, border=20):
  function smal_mesh_eval (line 56) | def smal_mesh_eval(num_train_epoch):
  function set_input (line 97) | def set_input(self, batch):
  function collect_outputs (line 114) | def collect_outputs(self):

FILE: nnutils/smal_mesh_net.py
  class ResNetConv (line 91) | class ResNetConv(nn.Module):
    method __init__ (line 92) | def __init__(self, n_blocks=4, opts=None):
    method forward (line 103) | def forward(self, x, y=None):
  class Encoder (line 127) | class Encoder(nn.Module):
    method __init__ (line 135) | def __init__(self, opts, input_shape, n_blocks=4, nz_feat=100, bott_si...
    method forward (line 151) | def forward(self, img, fg_img):
  class TexturePredictorUV (line 159) | class TexturePredictorUV(nn.Module):
    method __init__ (line 164) | def __init__(self, nz_feat, uv_sampler, opts, img_H=64, img_W=128, n_u...
    method forward (line 202) | def forward(self, feat):
  class ShapePredictor (line 250) | class ShapePredictor(nn.Module):
    method __init__ (line 255) | def __init__(self, nz_feat, num_verts, opts, left_idx, right_idx, shap...
    method forward (line 287) | def forward(self, feat):
  class PosePredictor (line 305) | class PosePredictor(nn.Module):
    method __init__ (line 308) | def __init__(self, opts, nz_feat, num_joints=35):
    method forward (line 314) | def forward(self, feat):
  class BetasPredictor (line 323) | class BetasPredictor(nn.Module):
    method __init__ (line 324) | def __init__(self, opts, nz_feat, nenc_feat, num_betas=10):
    method forward (line 329) | def forward(self, feat, enc_feat):
  class Keypoints2DPredictor (line 334) | class Keypoints2DPredictor(nn.Module):
    method __init__ (line 335) | def __init__(self, opts, nz_feat, nenc_feat, num_keypoints=28):
    method forward (line 341) | def forward(self, feat, enc_feat):
  class ScalePredictor (line 347) | class ScalePredictor(nn.Module):
    method __init__ (line 351) | def __init__(self, nz, opts):
    method forward (line 362) | def forward(self, feat):
  class TransPredictor (line 373) | class TransPredictor(nn.Module):
    method __init__ (line 378) | def __init__(self, nz, projection_type, opts):
    method forward (line 393) | def forward(self, feat):
  class CodePredictor (line 409) | class CodePredictor(nn.Module):
    method __init__ (line 410) | def __init__(self, nz_feat=100, nenc_feat=2048, num_verts=1000, opts=N...
    method forward (line 421) | def forward(self, feat, enc_feat):
  class MeshNet (line 449) | class MeshNet(nn.Module):
    method __init__ (line 450) | def __init__(self, input_shape, opts, nz_feat=100, num_kps=28, sfm_mea...
    method forward (line 561) | def forward(self, img, masks=None):
    method symmetrize (line 591) | def symmetrize(self, V):
    method get_smal_verts (line 609) | def get_smal_verts(self, pose=None, betas=None, trans=None, del_v=None):
    method get_mean_shape (line 620) | def get_mean_shape(self):

FILE: nnutils/smal_predictor.py
  class MeshPredictor (line 29) | class MeshPredictor(object):
    method __init__ (line 30) | def __init__(self, opts):
    method load_network (line 69) | def load_network(self, network, network_label, epoch_label):
    method set_input (line 78) | def set_input(self, batch):
    method predict (line 95) | def predict(self, batch, cam_gt=None, trans_gt=None, pose_gt=None, bet...
    method forward (line 103) | def forward(self, cam_gt=None, trans_gt=None, pose_gt=None, betas_gt=N...
    method collect_outputs (line 194) | def collect_outputs(self):

FILE: nnutils/train_utils.py
  function set_bn_eval (line 69) | def set_bn_eval(m):
  class Trainer (line 76) | class Trainer():
    method __init__ (line 77) | def __init__(self, opts):
    method save_network (line 92) | def save_network(self, network, network_label, epoch_label, gpu_id=None):
    method load_network (line 101) | def load_network(self, network, network_label, epoch_label, network_di...
    method define_model (line 109) | def define_model(self):
    method init_dataset (line 113) | def init_dataset(self):
    method define_criterion (line 117) | def define_criterion(self):
    method set_input (line 121) | def set_input(self, batch):
    method forward (line 125) | def forward(self):
    method save (line 129) | def save(self, epoch_prefix):
    method get_current_visuals (line 134) | def get_current_visuals(self):
    method get_current_scalars (line 138) | def get_current_scalars(self):
    method get_current_points (line 142) | def get_current_points(self):
    method init_training (line 146) | def init_training(self):
    method save_current (line 158) | def save_current(self, opts, initial_loss=0, final_loss=0, code='_'):
    method train (line 180) | def train(self):

FILE: smal_eval.py
  function get_bbox (line 62) | def get_bbox(img_path, kp):
  function preprocess_image (line 76) | def preprocess_image(img_path, img_size=256, kp=None, border=5, bgval=-1...
  function mirror_keypoints (line 126) | def mirror_keypoints(keyp, vis, img_w):
  function mirror_image (line 150) | def mirror_image(img):
  function visualize_opt (line 157) | def visualize_opt(img, predictor, renderer, data, out_path):
  function visualize (line 172) | def visualize(img, outputs, renderer, img_path=None, opts=None, cam_gt=0,
  function save_params (line 243) | def save_params(outputs, idx):
  function main (line 251) | def main(_):

FILE: smal_model/batch_lbs.py
  function batch_skew (line 9) | def batch_skew(vec, batch_size=None, opts=None):
  function batch_rodrigues (line 35) | def batch_rodrigues(theta, opts=None):
  function batch_lrotmin (line 56) | def batch_lrotmin(theta):
  function batch_global_rigid_transformation (line 77) | def batch_global_rigid_transformation(Rs, Js, parent, rotate_base = Fals...

FILE: smal_model/smal_basics.py
  function align_smal_template_to_symmetry_axis (line 9) | def align_smal_template_to_symmetry_axis(v):
  function load_smal_model (line 35) | def load_smal_model(model_name='my_smpl_00781_4_all.pkl'):
  function get_horse_template (line 44) | def get_horse_template(model_name='my_smpl_00781_4_all.pkl', data_name='...

FILE: smal_model/smal_torch.py
  function undo_chumpy (line 18) | def undo_chumpy(x):
  class SMAL (line 21) | class SMAL(object):
    method __init__ (line 22) | def __init__(self, pkl_path, opts, dtype=torch.float):
    method __call__ (line 68) | def __call__(self, beta, theta, trans=None, del_v=None, get_skin=True):

FILE: utils/geometry.py
  function triangle_direction_intersection (line 11) | def triangle_direction_intersection(tri, trg):
  function project_verts_on_mesh (line 42) | def project_verts_on_mesh(verts, mesh_verts, mesh_faces):

FILE: utils/image.py
  function resize_img (line 8) | def resize_img(img, scale_factor):
  function peturb_bbox (line 17) | def peturb_bbox(bbox, pf=0, jf=0):
  function square_bbox (line 40) | def square_bbox(bbox):
  function crop (line 60) | def crop(img, bbox, bgval=None):
  function compute_dt (line 97) | def compute_dt(mask):
  function compute_dt_barrier (line 105) | def compute_dt_barrier(mask, k=50):

FILE: utils/mesh.py
  function from_img_to_nmr (line 14) | def from_img_to_nmr(texture_map, ft, vt, tex_size):
  function from_uvflow_to_img (line 24) | def from_uvflow_to_img(imgs, uv_flows):
  function sample_texture (line 31) | def sample_texture(uv_map, uv_img):
  function sample_textures (line 52) | def sample_textures(texture_flow, images):
  function create_sphere (line 71) | def create_sphere(n_subdivide=3):
  function make_symmetric (line 78) | def make_symmetric(verts, faces, left_inds, right_inds, center_inds):
  function make_faces_symmetric (line 136) | def make_faces_symmetric(verts, faces, num_indept_verts, num_sym_verts):
  function compute_edges2verts (line 234) | def compute_edges2verts(verts, faces):
  function compute_vert2kp (line 251) | def compute_vert2kp(verts, mean_shape):
  function get_spherical_coords (line 268) | def get_spherical_coords(X):
  function compute_uvsampler_sphere (line 283) | def compute_uvsampler_sphere(verts, faces, tex_size=2):
  function compute_uvsampler (line 311) | def compute_uvsampler(verts, faces, vt, ft, tex_size=2):
  function append_obj (line 323) | def append_obj(mf_handle, vertices, faces):

FILE: utils/obj2nmr.py
  function obj2nmr_uvmap (line 29) | def obj2nmr_uvmap(ft, vt, tex_size=6):

FILE: utils/smal_vis.py
  class VisRenderer (line 18) | class VisRenderer(object):
    method __init__ (line 24) | def __init__(self, img_size, faces, proj_type, norm_f, norm_z, norm_f0...
    method __call__ (line 52) | def __call__(self, verts, cams=None, texture=None, rend_mask=False):
    method rotated (line 88) | def rotated(self, vert, deg, axis=[0, 1, 0], cam=None, texture=None):
    method set_bgcolor (line 102) | def set_bgcolor(self, color):
    method set_light_dir (line 105) | def set_light_dir(self, direction, int_dir=0.8, int_amb=0.8):
  function asVariable (line 112) | def asVariable(x):
  function convert_as (line 118) | def convert_as(src, trg):
  function convert2np (line 127) | def convert2np(x):
  function tensor2mask (line 137) | def tensor2mask(image_tensor, imtype=np.uint8):
  function kp2im (line 145) | def kp2im(kp, img, radius=None):
  function draw_kp (line 178) | def draw_kp(kp, img, radius=None):
  function vis_verts (line 217) | def vis_verts(mean_shape, verts, face, mvs=None, textures=None):
  function vis_vert2kp (line 253) | def vis_vert2kp(verts, vert2kp, face, mvs=None):
  function tensor2im (line 303) | def tensor2im(image_tensor, imtype=np.uint8, scale_to_range_1=False):
  function visflow (line 316) | def visflow(flow_img):
  function visflow_jonas (line 347) | def visflow_jonas(flow_img, img_size):

FILE: utils/transformations.py
  function identity_matrix (line 207) | def identity_matrix():
  function translation_matrix (line 222) | def translation_matrix(direction):
  function translation_from_matrix (line 235) | def translation_from_matrix(matrix):
  function reflection_matrix (line 247) | def reflection_matrix(point, normal):
  function reflection_from_matrix (line 273) | def reflection_from_matrix(matrix):
  function rotation_matrix (line 302) | def rotation_matrix(angle, direction, point=None):
  function rotation_from_matrix (line 346) | def rotation_from_matrix(matrix):
  function scale_matrix (line 386) | def scale_matrix(factor, origin=None, direction=None):
  function scale_from_matrix (line 420) | def scale_from_matrix(matrix):
  function projection_matrix (line 461) | def projection_matrix(point, normal, direction=None,
  function projection_from_matrix (line 523) | def projection_from_matrix(matrix, pseudo=False):
  function clip_matrix (line 596) | def clip_matrix(left, right, bottom, top, near, far, perspective=False):
  function shear_matrix (line 648) | def shear_matrix(angle, direction, point, normal):
  function shear_from_matrix (line 679) | def shear_from_matrix(matrix):
  function decompose_matrix (line 724) | def decompose_matrix(matrix):
  function compose_matrix (line 809) | def compose_matrix(scale=None, shear=None, angles=None, translate=None,
  function orthogonalization_matrix (line 862) | def orthogonalization_matrix(lengths, angles):
  function affine_matrix_from_points (line 889) | def affine_matrix_from_points(v0, v1, shear=True, scale=True, usesvd=True):
  function superimposition_matrix (line 998) | def superimposition_matrix(v0, v1, scale=False, usesvd=True):
  function euler_matrix (line 1049) | def euler_matrix(ai, aj, ak, axes='sxyz'):
  function euler_from_matrix (line 1112) | def euler_from_matrix(matrix, axes='sxyz'):
  function euler_from_quaternion (line 1170) | def euler_from_quaternion(quaternion, axes='sxyz'):
  function quaternion_from_euler (line 1181) | def quaternion_from_euler(ai, aj, ak, axes='sxyz'):
  function quaternion_about_axis (line 1238) | def quaternion_about_axis(angle, axis):
  function quaternion_matrix (line 1254) | def quaternion_matrix(quaternion):
  function quaternion_from_matrix (line 1281) | def quaternion_from_matrix(matrix, isprecise=False):
  function quaternion_multiply (line 1366) | def quaternion_multiply(quaternion1, quaternion0):
  function quaternion_conjugate (line 1383) | def quaternion_conjugate(quaternion):
  function quaternion_inverse (line 1397) | def quaternion_inverse(quaternion):
  function quaternion_real (line 1411) | def quaternion_real(quaternion):
  function quaternion_imag (line 1421) | def quaternion_imag(quaternion):
  function quaternion_slerp (line 1431) | def quaternion_slerp(quat0, quat1, fraction, spin=0, shortestpath=True):
  function random_quaternion (line 1472) | def random_quaternion(rand=None):
  function random_rotation_matrix (line 1500) | def random_rotation_matrix(rand=None):
  class Arcball (line 1515) | class Arcball(object):
    method __init__ (line 1538) | def __init__(self, initial=None):
    method place (line 1563) | def place(self, center, radius):
    method setaxes (line 1576) | def setaxes(self, *axes):
    method constrain (line 1584) | def constrain(self):
    method constrain (line 1589) | def constrain(self, value):
    method down (line 1593) | def down(self, point):
    method drag (line 1603) | def drag(self, point):
    method next (line 1616) | def next(self, acceleration=0.0):
    method matrix (line 1621) | def matrix(self):
  function arcball_map_to_sphere (line 1626) | def arcball_map_to_sphere(point, center, radius):
  function arcball_constrain_to_axis (line 1639) | def arcball_constrain_to_axis(point, axis):
  function arcball_nearest_axis (line 1655) | def arcball_nearest_axis(point, axes):
  function vector_norm (line 1688) | def vector_norm(data, axis=None, out=None):
  function unit_vector (line 1727) | def unit_vector(data, axis=None, out=None):
  function random_vector (line 1771) | def random_vector(size):
  function vector_product (line 1786) | def vector_product(v0, v1, axis=0):
  function angle_between_vectors (line 1807) | def angle_between_vectors(v0, v1, directed=True, axis=0):
  function inverse_matrix (line 1838) | def inverse_matrix(matrix):
  function concatenate_matrices (line 1854) | def concatenate_matrices(*matrices):
  function is_same_transform (line 1870) | def is_same_transform(matrix0, matrix1):
  function is_same_quaternion (line 1886) | def is_same_quaternion(q0, q1):
  function _import_module (line 1893) | def _import_module(name, package=None, warn=True, prefix='_py_', ignore=...

FILE: utils/visualizer.py
  class Visualizer (line 9) | class Visualizer():
    method __init__ (line 10) | def __init__(self, opt):
    method print_message (line 24) | def print_message(self, text):
    method display_current_results (line 29) | def display_current_results(self, visuals, epoch):
    method plot_current_scalars (line 77) | def plot_current_scalars(self, epoch, counter_ratio, opt, scalars):
    method plot_current_points (line 93) | def plot_current_points(self, points, disp_offset=10):
    method print_current_scalars (line 102) | def print_current_scalars(self, epoch, i, scalars):
    method save_images (line 112) | def save_images(self, webpage, visuals, image_path):

FILE: utils/visutil.py
  function tensor2im (line 13) | def tensor2im(image_tensor, imtype=np.uint8):
  function tensor2kps (line 19) | def tensor2kps(kp_tensor):
  function tensor2verts (line 28) | def tensor2verts(vert_tensor):
  function tensor2im_batch (line 38) | def tensor2im_batch(image_tensor, num_batch, imtype=np.uint8):
  function undo_resnet_preprocess (line 49) | def undo_resnet_preprocess(image_tensor):
  function diagnose_network (line 57) | def diagnose_network(net, name='network'):
  function save_image (line 70) | def save_image(image_numpy, image_path):
  function info (line 74) | def info(object, spacing=10, collapse=1):
  function varname (line 84) | def varname(p):
  function print_numpy (line 90) | def print_numpy(x, val=True, shp=False):
  function mkdirs (line 100) | def mkdirs(paths):
  function mkdir (line 108) | def mkdir(path):
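
`utils/visutil.py` is credited to the pytorch-CycleGAN-and-pix2pix code base, as its header in the preview below shows. In that project, `tensor2im` takes the first image of a batch, moves channels last, and rescales values from [-1, 1] to [0, 255]. The sketch below reproduces that convention in plain NumPy, assuming a B x C x H x W array. `tensor2im_sketch` is a hypothetical stand-in, and the smalst version may rescale differently. For example, its inputs may use ResNet-style normalization, which the listed `undo_resnet_preprocess` suggests. A PyTorch tensor would first be converted with `.cpu().float().numpy()`.

```python
import numpy as np


def tensor2im_sketch(image_batch, imtype=np.uint8):
    """Convert the first image of a B x C x H x W batch in [-1, 1] to H x W x C in [0, 255]."""
    image = np.asarray(image_batch[0], dtype=np.float32)
    if image.shape[0] == 1:
        # Grayscale: repeat the single channel so the result is RGB.
        image = np.tile(image, (3, 1, 1))
    # Channels last, then map [-1, 1] -> [0, 255].
    image = (np.transpose(image, (1, 2, 0)) + 1.0) / 2.0 * 255.0
    return image.astype(imtype)


batch = np.zeros((2, 3, 4, 5), dtype=np.float32)
print(tensor2im_sketch(batch).shape)
```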

Condensed preview: 50 files, each showing path, character count, and a content snippet (full structured content: 6,863K chars).
[
  {
    "path": "LICENSE.txt",
    "chars": 1068,
    "preview": "MIT License\n\nCopyright (c) 2018 silviazuffi\n\nPermission is hereby granted, free of charge, to any person obtaining a cop"
  },
  {
    "path": "LICENSE_SMAL_MODEL.txt",
    "chars": 5987,
    "preview": "License\nSoftware Copyright License for non-commercial scientific research purposes\nPlease read carefully the following t"
  },
  {
    "path": "README.md",
    "chars": 2983,
    "preview": "# Three-D Safari: Learning to Estimate Zebra Pose, Shape, and Texture from Images \"In the Wild\"\n\nSilvia Zuffi<sup>1</sup"
  },
  {
    "path": "__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "_config.yml",
    "chars": 25,
    "preview": "theme: jekyll-theme-slate"
  },
  {
    "path": "data/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "data/smal_base.py",
    "chars": 17168,
    "preview": "\"\"\"\nBase data loading class.\n\nShould output:\n    - img: B X 3 X H X W\n    - kp: B X nKp X 2\n    - mask: B X H X W\n    # "
  },
  {
    "path": "data/zebra.py",
    "chars": 6147,
    "preview": "\"\"\"\n\n\"\"\"\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\ni"
  },
  {
    "path": "docs/tmp.txt",
    "chars": 1,
    "preview": "\n"
  },
  {
    "path": "experiments/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "experiments/smal_shape.py",
    "chars": 28200,
    "preview": "\"\"\"\n\nExample usage:\n\npython -m smalst.experiments.smal_shape --zebra_dir='smalst/zebra_no_toys_wtex_1000_0' --num_epochs"
  },
  {
    "path": "external/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "external/install_external.sh",
    "chars": 179,
    "preview": "git clone https://github.com/shubhtuls/PerceptualSimilarity\n\ngit clone https://github.com/hiroharu-kato/neural_renderer "
  },
  {
    "path": "nnutils/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "nnutils/geom_utils.py",
    "chars": 1470,
    "preview": "\"\"\"\nUtils related to geometry like projection,,\n\"\"\"\nfrom __future__ import absolute_import\nfrom __future__ import divisi"
  },
  {
    "path": "nnutils/loss_utils.py",
    "chars": 7618,
    "preview": "\"\"\"\nLoss Utils.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print"
  },
  {
    "path": "nnutils/net_blocks.py",
    "chars": 10721,
    "preview": "'''\nCNN building blocks.\nTaken from https://github.com/shubhtuls/factored3d/\n'''\nfrom __future__ import division\nfrom __"
  },
  {
    "path": "nnutils/nmr.py",
    "chars": 7287,
    "preview": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport num"
  },
  {
    "path": "nnutils/perceptual_loss.py",
    "chars": 924,
    "preview": "\"\"\"\nCalls Richard's Perceptual Loss.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __"
  },
  {
    "path": "nnutils/smal_mesh_eval.py",
    "chars": 4135,
    "preview": "\"\"\"\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nf"
  },
  {
    "path": "nnutils/smal_mesh_net.py",
    "chars": 27075,
    "preview": "\"\"\"\nMesh net model.\n\"\"\"\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import pr"
  },
  {
    "path": "nnutils/smal_predictor.py",
    "chars": 8312,
    "preview": "\"\"\"\nTakes an image, returns stuff.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __fu"
  },
  {
    "path": "nnutils/train_utils.py",
    "chars": 17585,
    "preview": "\"\"\"\nGeneric Training Utils.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ "
  },
  {
    "path": "requirements.txt",
    "chars": 342,
    "preview": "absl-py==0.1.10\nchainer==3.3.0\ncupy==2.3.0\nCython==0.27.3\nh5py==2.7.1\nimageio==2.2.0\nipdb==0.10.3\nmatplotlib==2.1.2\nmesh"
  },
  {
    "path": "scripts/smalst_evaluation_run.sh",
    "chars": 560,
    "preview": "# Example script to run the evaluation\n\npython -m smalst.smal_eval --name=smal_net_600 --img_path='smalst/zebra_testset/"
  },
  {
    "path": "scripts/smalst_op_run.sh",
    "chars": 995,
    "preview": "#!/bin/sh\nexport LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/lib\n\n# The directory processed contains the input imag"
  },
  {
    "path": "scripts/smalst_train_run.sh",
    "chars": 512,
    "preview": " \n#python -m smalst.experiments.smal_shape --zebra_dir='smalst/zebra_training_set' --num_epochs=40 --save_epoch_freq=20 "
  },
  {
    "path": "smal_eval.py",
    "chars": 17756,
    "preview": "\"\"\"\nEvaluation on the testset.\n\npython -m smalst.smal_eval --name=smal_net_600 --img_path='smalst/testset_zoo/' --num_tr"
  },
  {
    "path": "smal_model/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "smal_model/batch_lbs.py",
    "chars": 4235,
    "preview": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport tor"
  },
  {
    "path": "smal_model/smal_basics.py",
    "chars": 2277,
    "preview": "import os\nimport pickle as pkl\nimport numpy as np\nfrom smpl_webuser.serialization import load_model\nimport pickle as pkl"
  },
  {
    "path": "smal_model/smal_torch.py",
    "chars": 5157,
    "preview": "\"\"\"\n\n    PyTorch implementation of the SMAL/SMPL model\n\n\"\"\"\nfrom __future__ import absolute_import\nfrom __future__ impor"
  },
  {
    "path": "smal_model/symIdx.pkl",
    "chars": 118711,
    "preview": "cnumpy.core.multiarray\n_reconstruct\np0\n(cnumpy\nndarray\np1\n(I0\ntp2\nS'b'\np3\ntp4\nRp5\n(I1\n(I3889\ntp6\ncnumpy\ndtype\np7\n(S'i8'\n"
  },
  {
    "path": "smal_model/template_w_tex_uv.mtl",
    "chars": 221,
    "preview": "# Blender MTL File: 'None'\n# Material Count: 1\n\nnewmtl my_mat\nNs 96.078431\nKa 0.000000 0.000000 0.000000\nKd 0.640000 0.6"
  },
  {
    "path": "smal_model/template_w_tex_uv.obj",
    "chars": 842006,
    "preview": "# Blender v2.75 (sub 0) OBJ File: ''\n# www.blender.org\nmtllib template_w_tex_uv.mtl\no my_smpl_0001_4_all_template\nv -0.8"
  },
  {
    "path": "smpl_models/my_smpl_00781_4_all_template_w_tex_uv_001.pkl",
    "chars": 594716,
    "preview": "(dp0\nS'vt'\np1\ncnumpy.core.multiarray\n_reconstruct\np2\n(cnumpy\nndarray\np3\n(I0\ntp4\nS'b'\np5\ntp6\nRp7\n(I1\n(I5691\nI2\ntp8\ncnumpy"
  },
  {
    "path": "smpl_models/my_smpl_data_00781_4_all.pkl",
    "chars": 246211,
    "preview": "(dp0\nS'eigenvalues'\np1\ncnumpy.core.multiarray\n_reconstruct\np2\n(cnumpy\nndarray\np3\n(I0\ntp4\nS'b'\np5\ntp6\nRp7\n(I1\n(I41\ntp8\ncn"
  },
  {
    "path": "smpl_models/symIdx.pkl",
    "chars": 118711,
    "preview": "cnumpy.core.multiarray\n_reconstruct\np0\n(cnumpy\nndarray\np1\n(I0\ntp2\nS'b'\np3\ntp4\nRp5\n(I1\n(I3889\ntp6\ncnumpy\ndtype\np7\n(S'i8'\n"
  },
  {
    "path": "testset_shape_experiments_crops.py",
    "chars": 11315,
    "preview": "bboxes = {'804454c0-4f2f-407a-a4fa-d2b38a88d50c_female_96c17794-e32d-4aa4-943a-193e17426ce5_right': [731,517,2540,2472],"
  },
  {
    "path": "utils/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "utils/geometry.py",
    "chars": 1639,
    "preview": "\"\"\"\nGeometry stuff.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import p"
  },
  {
    "path": "utils/image.py",
    "chars": 3573,
    "preview": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\nimport cv2\n"
  },
  {
    "path": "utils/mesh.py",
    "chars": 11693,
    "preview": "\"\"\"\nMesh stuff.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print"
  },
  {
    "path": "utils/obj2nmr.py",
    "chars": 2024,
    "preview": "\"\"\"\nHelper for converting obj (vt, ft) texture to NMR texture format.\n\n\nTo use NMR, we need to supply a uv_map (F x T x "
  },
  {
    "path": "utils/smal_vis.py",
    "chars": 10891,
    "preview": "\"\"\"\nVisualization helpers specific to birds.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division"
  },
  {
    "path": "utils/transformations.py",
    "chars": 66389,
    "preview": "# -*- coding: utf-8 -*-\n# transformations.py\n\n# Copyright (c) 2006-2017, Christoph Gohlke\n# Copyright (c) 2006-2017, The"
  },
  {
    "path": "utils/visualizer.py",
    "chars": 5516,
    "preview": "'''Code adapted from https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix'''\nimport numpy as np\nimport os\nimport ntpa"
  },
  {
    "path": "utils/visutil.py",
    "chars": 3332,
    "preview": "'''Code from https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix'''\nfrom __future__ import print_function\nimport tor"
  },
  {
    "path": "zebra_data/verts2kp.pkl",
    "chars": 3483709,
    "preview": "cnumpy.core.multiarray\n_reconstruct\np0\n(cnumpy\nndarray\np1\n(I0\ntp2\nS'b'\np3\ntp4\nRp5\n(I1\n(I28\nI3889\ntp6\ncnumpy\ndtype\np7\n(S'"
  }
]

// ... and 1 more file (not shown in this preview)

About this extraction

This page contains the full source code of the silviazuffi/smalst GitHub repository, extracted and formatted as plain text. The extraction includes 50 files (37.6 MB), approximately 1.4M tokens, and a symbol index with 289 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!