| \n", " | index | \n", "genome | \n", "identifier | \n", "file | \n", "clip | \n", "scale | \n", "sum_stat | \n", "description | \n", "
|---|---|---|---|---|---|---|---|---|
| 0 | \n", "0 | \n", "0 | \n", "ENCFF833POA | \n", "/home/drk/tillage/datasets/human/dnase/encode/... | \n", "32 | \n", "2 | \n", "mean | \n", "DNASE:cerebellum male adult (27 years) and mal... | \n", "
| 1 | \n", "1 | \n", "0 | \n", "ENCFF110QGM | \n", "/home/drk/tillage/datasets/human/dnase/encode/... | \n", "32 | \n", "2 | \n", "mean | \n", "DNASE:frontal cortex male adult (27 years) and... | \n", "
| 2 | \n", "2 | \n", "0 | \n", "ENCFF880MKD | \n", "/home/drk/tillage/datasets/human/dnase/encode/... | \n", "32 | \n", "2 | \n", "mean | \n", "DNASE:chorion | \n", "
| \n", " | chrom | \n", "pos | \n", "id | \n", "ref | \n", "alt | \n", "PC0 | \n", "PC1 | \n", "PC2 | \n", "PC3 | \n", "PC4 | \n", "PC5 | \n", "PC6 | \n", "PC7 | \n", "PC8 | \n", "PC9 | \n", "PC10 | \n", "PC11 | \n", "PC12 | \n", "PC13 | \n", "PC14 | \n", "PC15 | \n", "PC16 | \n", "PC17 | \n", "PC18 | \n", "PC19 | \n", "
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | \n", "chr1 | \n", "925952 | \n", "1019397 | \n", "G | \n", "A | \n", "13.865371 | \n", "9.379375 | \n", "1.322473 | \n", "7.198019 | \n", "11.926774 | \n", "-4.407538 | \n", "-5.878580 | \n", "-10.701156 | \n", "-3.140507 | \n", "2.994015 | \n", "4.716916 | \n", "1.098637 | \n", "2.569388 | \n", "13.693736 | \n", "8.564518 | \n", "9.383035 | \n", "-2.159512 | \n", "-9.733231 | \n", "7.727090 | \n", "-0.669298 | \n", "
| 1 | \n", "chr1 | \n", "930188 | \n", "846933 | \n", "G | \n", "A | \n", "-61.468933 | \n", "-5.653942 | \n", "-2.758731 | \n", "6.289482 | \n", "1.844845 | \n", "3.446712 | \n", "6.052454 | \n", "0.632046 | \n", "2.584915 | \n", "1.117951 | \n", "1.942497 | \n", "-6.513691 | \n", "-4.948788 | \n", "-1.172066 | \n", "-2.903700 | \n", "0.482835 | \n", "2.896945 | \n", "1.757883 | \n", "3.686084 | \n", "-6.673547 | \n", "
| 2 | \n", "chr1 | \n", "930200 | \n", "1043045 | \n", "G | \n", "A | \n", "-61.995975 | \n", "-10.007704 | \n", "-0.312641 | \n", "10.605079 | \n", "-5.349404 | \n", "-2.555728 | \n", "7.008485 | \n", "9.793589 | \n", "-14.216670 | \n", "4.411201 | \n", "0.295830 | \n", "-4.968991 | \n", "-10.770261 | \n", "-1.512434 | \n", "0.186349 | \n", "7.461446 | \n", "5.153117 | \n", "4.041459 | \n", "0.512155 | \n", "-2.865725 | \n", "
| 3 | \n", "chr1 | \n", "930203 | \n", "972363 | \n", "C | \n", "T | \n", "21.486368 | \n", "-9.248886 | \n", "3.684767 | \n", "9.606424 | \n", "0.538447 | \n", "5.264496 | \n", "1.418155 | \n", "-14.005326 | \n", "12.898662 | \n", "9.391730 | \n", "-5.201692 | \n", "-3.091272 | \n", "1.975370 | \n", "-5.240757 | \n", "-14.367105 | \n", "-7.802942 | \n", "0.138479 | \n", "12.087408 | \n", "-5.559704 | \n", "9.171222 | \n", "
| 4 | \n", "chr1 | \n", "930222 | \n", "998906 | \n", "GAACTC | \n", "TTCTTCTG | \n", "13.672197 | \n", "181.645172 | \n", "-302.586548 | \n", "184.414001 | \n", "146.373199 | \n", "284.204163 | \n", "-108.902489 | \n", "100.794731 | \n", "-205.568008 | \n", "302.769409 | \n", "203.416458 | \n", "111.947685 | \n", "-61.380695 | \n", "222.271515 | \n", "152.539993 | \n", "114.129166 | \n", "-26.604349 | \n", "-68.656372 | \n", "-36.595196 | \n", "38.175354 | \n", "
| \n", " | chrom | \n", "pos | \n", "id | \n", "ref | \n", "alt | \n", "DNASE:cerebellum mal | \n", "DNASE:frontal cortex | \n", "DNASE:chorion | \n", "DNASE:Ishikawa treat | \n", "DNASE:GM03348 | \n", "DNASE:GM03348 geneti | \n", "DNASE:AG08395 | \n", "DNASE:AG08396 | \n", "DNASE:AG20443 | \n", "DNASE:H54 | \n", "DNASE:GM10248 | \n", "DNASE:GM12878 | \n", "DNASE:GM12891 | \n", "DNASE:GM12892 | \n", "DNASE:GM18507 | \n", "DNASE:GM19238 | \n", "DNASE:GM19239 | \n", "DNASE:GM19240 | \n", "DNASE:H1-hESC | \n", "DNASE:H7-hESC | \n", "DNASE:H9 | \n", "DNASE:heart male adu | \n", "DNASE:HEK293T | \n", "DNASE:HeLa-S3 treate | \n", "DNASE:HeLa-S3 | \n", "DNASE:hepatocyte | \n", "DNASE:HepG2 | \n", "DNASE:HTR-8/SVneo | \n", "DNASE:endothelial ce | \n", "DNASE:CWRU1 male | \n", "DNASE:iPS-NIHi11 mal | \n", "DNASE:iPS-NIHi7 fema | \n", "DNASE:K562 treated w | \n", "DNASE:K562 G2 phase | \n", "DNASE:K562 G1 phase | \n", "... | \n", "CAGE:CD14+CD16- Mono | \n", "CAGE:achilles tendon | \n", "CAGE:cerebrospinal f | \n", "CAGE:cruciate ligame | \n", "CAGE:eye - vitreous | \n", "CAGE:eye - muscle su | \n", "CAGE:eye - muscle la | \n", "CAGE:eye - muscle me | \n", "CAGE:eye - muscle in | \n", "CAGE:Fingernail (inc | \n", "CAGE:optic nerve, | \n", "CAGE:Skin - palm, | \n", "CAGE:tongue epidermi | \n", "CAGE:Urethra, | \n", "CAGE:CD14+ monocytes | \n", "CAGE:Hep-2 cells tre | \n", "CAGE:Hep-2 cells moc | \n", "CAGE:immature langer | \n", "CAGE:migratory lange | \n", "CAGE:CD34 cells diff | \n", "CAGE:amygdala - adul | \n", "CAGE:thalamus - adul | \n", "CAGE:hippocampus - a | \n", "CAGE:parietal lobe - | \n", "CAGE:cerebellum - ad | \n", "CAGE:pineal gland - | \n", "CAGE:spinal cord - a | \n", "CAGE:Olfactory epith | \n", "CAGE:gamma delta pos | \n", "CAGE:Mast cell, expa | \n", "CAGE:adipose, | \n", "CAGE:cerebellum, new | \n", "CAGE:spinal cord, ne | \n", "CAGE:amygdala, newbo | \n", "CAGE:hippocampus, ne | \n", "CAGE:putamen, newbor | \n", "CAGE:thalamus, newbo | \n", "CAGE:thymic 
carcinom | \n", "CAGE:Smooth muscle c | \n", "CAGE:parietal cortex | \n", "
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | \n", "chr1 | \n", "925952 | \n", "1019397 | \n", "G | \n", "A | \n", "-3.953313 | \n", "-2.132398 | \n", "-1.133056 | \n", "-0.018265 | \n", "-7.324341 | \n", "-5.928950 | \n", "-2.696301 | \n", "-2.234811 | \n", "-1.869620 | \n", "16.743450 | \n", "2.218565 | \n", "3.167206 | \n", "1.084149 | \n", "0.671318 | \n", "6.513158 | \n", "0.971147 | \n", "6.255165 | \n", "5.352125 | \n", "1.729454 | \n", "-3.350782 | \n", "1.896389 | \n", "-1.563089 | \n", "-2.268677 | \n", "11.647644 | \n", "14.272476 | \n", "-7.074303 | \n", "1.850793 | \n", "5.803924 | \n", "0.381484 | \n", "-0.443644 | \n", "-0.299062 | \n", "-0.139204 | \n", "-8.861604 | \n", "-8.072408 | \n", "-6.212577 | \n", "... | \n", "-0.046439 | \n", "4.083508 | \n", "3.535412 | \n", "11.717014 | \n", "-5.239547 | \n", "1.081644 | \n", "2.213174 | \n", "0.853902 | \n", "-0.216257 | \n", "2.654844 | \n", "0.402727 | \n", "1.567667 | \n", "3.049335 | \n", "1.010953 | \n", "-0.800231 | \n", "157.583069 | \n", "172.648773 | \n", "0.576450 | \n", "5.026437 | \n", "0.464598 | \n", "2.133233 | \n", "1.689052 | \n", "2.324542 | \n", "1.849092 | \n", "-1.319847 | \n", "-25.130344 | \n", "2.456997 | \n", "11.093707 | \n", "-4.529293 | \n", "5.072914 | \n", "-6.317924 | \n", "-2.482907 | \n", "1.591787 | \n", "-5.362659 | \n", "0.643638 | \n", "-1.936050 | \n", "2.499822 | \n", "12.372701 | \n", "23.365688 | \n", "2.915666 | \n", "
| 1 | \n", "chr1 | \n", "930188 | \n", "846933 | \n", "G | \n", "A | \n", "-1.791559 | \n", "-15.679683 | \n", "-22.533566 | \n", "0.523771 | \n", "-3.826311 | \n", "-3.219894 | \n", "-2.668196 | \n", "-9.468966 | \n", "-3.158970 | \n", "-7.704255 | \n", "-3.150757 | \n", "-0.523286 | \n", "-11.126574 | \n", "-11.116987 | \n", "-3.674609 | \n", "-8.464477 | \n", "-11.757355 | \n", "-8.595486 | \n", "-0.536293 | \n", "0.227047 | \n", "-0.779173 | \n", "-15.736004 | \n", "0.879350 | \n", "-6.572251 | \n", "-4.200795 | \n", "-8.630462 | \n", "-0.465244 | \n", "-3.442806 | \n", "-0.311721 | \n", "-4.801476 | \n", "-5.145035 | \n", "-1.415882 | \n", "0.362791 | \n", "-1.997213 | \n", "-2.307346 | \n", "... | \n", "-33.052082 | \n", "-2.994223 | \n", "-8.416894 | \n", "-5.190420 | \n", "-4.534584 | \n", "-5.453042 | \n", "-3.034095 | \n", "-7.163062 | \n", "-2.291493 | \n", "-8.985706 | \n", "-4.789318 | \n", "-4.794244 | \n", "-1.449592 | \n", "-1.989321 | \n", "-22.916094 | \n", "-35.752350 | \n", "-35.275730 | \n", "-13.375345 | \n", "-28.358622 | \n", "-31.596754 | \n", "-11.371334 | \n", "-9.254264 | \n", "-9.006978 | \n", "-9.230368 | \n", "-3.783889 | \n", "-0.283448 | \n", "-12.094806 | \n", "-59.170837 | \n", "-66.303123 | \n", "-34.475914 | \n", "-21.272675 | \n", "-4.328580 | \n", "-16.075752 | \n", "-6.281566 | \n", "-10.389462 | \n", "-6.422119 | \n", "-9.181828 | \n", "-14.224935 | \n", "-46.487968 | \n", "-11.499000 | \n", "
| 2 | \n", "chr1 | \n", "930200 | \n", "1043045 | \n", "G | \n", "A | \n", "-0.526763 | \n", "-14.688834 | \n", "-17.484657 | \n", "2.908407 | \n", "2.734465 | \n", "3.514852 | \n", "0.294236 | \n", "-4.154023 | \n", "1.062205 | \n", "-7.452735 | \n", "-2.387097 | \n", "-0.988034 | \n", "-7.148329 | \n", "-7.479169 | \n", "-3.070057 | \n", "-9.537191 | \n", "-12.243765 | \n", "-10.340226 | \n", "0.070186 | \n", "0.338647 | \n", "-0.706960 | \n", "-13.103579 | \n", "3.004640 | \n", "-8.224607 | \n", "-7.412826 | \n", "-9.011602 | \n", "0.814177 | \n", "-2.572933 | \n", "0.154302 | \n", "-4.661432 | \n", "-3.152854 | \n", "-1.105172 | \n", "-2.336678 | \n", "-4.403462 | \n", "-4.830977 | \n", "... | \n", "-26.200890 | \n", "-4.082983 | \n", "-11.688838 | \n", "-5.358507 | \n", "-5.728800 | \n", "-6.592316 | \n", "-5.169149 | \n", "-8.596480 | \n", "-3.361437 | \n", "-6.800953 | \n", "-4.874091 | \n", "-4.676669 | \n", "-2.054912 | \n", "-2.659639 | \n", "-19.059345 | \n", "-86.237938 | \n", "-90.280151 | \n", "-13.633442 | \n", "-32.443680 | \n", "-19.412699 | \n", "-7.266642 | \n", "-7.957299 | \n", "-6.725488 | \n", "-6.803071 | \n", "-6.814243 | \n", "1.440345 | \n", "-8.739142 | \n", "-70.331505 | \n", "-39.735146 | \n", "-31.107134 | \n", "-18.125492 | \n", "-3.016214 | \n", "-8.161813 | \n", "-12.665734 | \n", "-9.362435 | \n", "-5.167178 | \n", "-3.976753 | \n", "-12.334543 | \n", "-38.066742 | \n", "-8.078856 | \n", "
| 3 | \n", "chr1 | \n", "930203 | \n", "972363 | \n", "C | \n", "T | \n", "3.887291 | \n", "-0.802319 | \n", "-1.669034 | \n", "4.747784 | \n", "1.695362 | \n", "2.217486 | \n", "-0.282765 | \n", "-1.139245 | \n", "0.368520 | \n", "-3.772724 | \n", "3.466754 | \n", "1.096701 | \n", "2.536208 | \n", "5.108255 | \n", "1.874467 | \n", "3.965143 | \n", "1.232345 | \n", "3.818713 | \n", "0.651470 | \n", "2.554193 | \n", "0.991104 | \n", "1.639103 | \n", "5.514153 | \n", "-6.208551 | \n", "-13.487192 | \n", "6.362380 | \n", "0.164872 | \n", "-4.553425 | \n", "-0.160736 | \n", "-1.997855 | \n", "-2.654027 | \n", "0.416074 | \n", "7.220230 | \n", "4.069194 | \n", "2.957357 | \n", "... | \n", "33.843147 | \n", "-0.583508 | \n", "0.264617 | \n", "-3.204525 | \n", "5.382181 | \n", "1.291355 | \n", "0.997021 | \n", "2.093935 | \n", "-0.294896 | \n", "4.008999 | \n", "-0.024318 | \n", "-0.550827 | \n", "-1.001716 | \n", "0.569551 | \n", "27.447952 | \n", "-22.858391 | \n", "-32.015335 | \n", "12.032137 | \n", "15.755173 | \n", "19.405922 | \n", "2.799939 | \n", "3.406132 | \n", "1.644471 | \n", "2.383641 | \n", "-1.271341 | \n", "35.166897 | \n", "-0.529095 | \n", "-15.446460 | \n", "42.712193 | \n", "22.990261 | \n", "-2.007191 | \n", "1.718304 | \n", "1.981185 | \n", "5.209938 | \n", "6.387944 | \n", "3.160304 | \n", "0.638121 | \n", "-0.786704 | \n", "-4.402274 | \n", "3.053887 | \n", "
| 4 | \n", "chr1 | \n", "930222 | \n", "998906 | \n", "GAACTC | \n", "TTCTTCTG | \n", "-6.190521 | \n", "-137.378281 | \n", "34.811737 | \n", "-1.299490 | \n", "-44.415581 | \n", "-15.935372 | \n", "-6.835675 | \n", "-44.718884 | \n", "12.754320 | \n", "71.546776 | \n", "-100.758392 | \n", "-128.016174 | \n", "-168.724716 | \n", "-330.777496 | \n", "-142.454132 | \n", "-350.230988 | \n", "-131.790482 | \n", "-207.531921 | \n", "-7.376725 | \n", "-45.294678 | \n", "-16.531658 | \n", "-183.512756 | \n", "10.864501 | \n", "18.745724 | \n", "9.173240 | \n", "15.418909 | \n", "13.582534 | \n", "24.021170 | \n", "-13.055187 | \n", "41.811131 | \n", "57.707191 | \n", "46.170353 | \n", "-125.677353 | \n", "-66.172997 | \n", "-53.709393 | \n", "... | \n", "45.992233 | \n", "1.093487 | \n", "-150.802185 | \n", "16.499266 | \n", "24.087925 | \n", "-66.276947 | \n", "-92.977158 | \n", "-139.676208 | \n", "-51.160301 | \n", "-36.617786 | \n", "-4.993864 | \n", "5.801898 | \n", "-10.315315 | \n", "-5.569825 | \n", "353.592102 | \n", "544.979797 | \n", "542.144470 | \n", "44.670349 | \n", "-91.620117 | \n", "-66.408241 | \n", "65.320213 | \n", "27.775331 | \n", "36.321232 | \n", "27.018156 | \n", "-37.838741 | \n", "-20.264828 | \n", "-15.695430 | \n", "-72.608513 | \n", "-119.713341 | \n", "-22.978678 | \n", "20.147364 | \n", "3.596015 | \n", "20.581316 | \n", "-82.511292 | \n", "10.583510 | \n", "99.163254 | \n", "87.440636 | \n", "28.396255 | \n", "26.421061 | \n", "27.204330 | \n", "
5 rows × 3378 columns
\n", "
The data used in the "Stanford 3D Objects" section of the experimental results
is available in [TensorFlow Datasets](https://www.tensorflow.org/datasets/catalog/s3o4d).
The data consists of 100,000 renderings each of the Bunny and Dragon objects
from the [Stanford 3D Scanning Repository](http://graphics.stanford.edu/data/3Dscanrep/).
More objects may be added in the future, but only the Bunny and Dragon are used
in the paper. Each object is rendered with a uniformly sampled illumination from
a point on the 2-sphere, and a uniformly sampled 3D rotation. The true latent
states are provided as NumPy arrays along with the images. The lighting is given
as a 3-vector with unit norm, while the rotation is provided both as a
quaternion and a 3x3 orthogonal matrix.
### Why another dataset?
There are many similarities between S3O4D and existing ML benchmark datasets
like [NORB](https://cs.nyu.edu/~ylclab/data/norb-v1.0/),
[3D Chairs](https://github.com/mathieuaubry/seeing3Dchairs),
[3D Shapes](https://github.com/deepmind/3d-shapes) and many others, which also
include renderings of a set of objects under different pose and illumination
conditions. However, none of these existing datasets include the *full manifold*
of rotations in 3D - most include only a subset of changes to elevation and
azimuth. S3O4D images are sampled uniformly and independently from the full
space of rotations and illuminations, meaning the dataset contains objects that
are upside down and illuminated from behind or underneath. We believe that this
makes S3O4D uniquely suited for research on generative models where the latent
space has non-trivial topology, as well as for general manifold learning
methods where the curvature of the manifold is important.
### Usage
To load from TensorFlow Datasets, simply run:
```
import tensorflow_datasets as tfds
ds = tfds.load('s3o4d', split='bunny_train', shuffle_files=True)
for example in ds.take(1):
  image, label, illumination, pose_mat, pose_quat = (
      example['image'], example['label'], example['illumination'],
      example['pose_mat'], example['pose_quat'])
```
where the split can be any of `bunny_train`, `dragon_train`, `bunny_test` or
`dragon_test`.
If you prefer to not have TensorFlow as a dependency for your project, and want
to download the data manually, you can find the raw data (as zipped JPEGs and
NumPy arrays) on [Google Cloud](https://console.cloud.google.com/storage/browser/dm_s3o4d).
To load the data for a given object, unzip `images.zip` into a folder called
`images` in the same directory as `latents.npz`, and from inside that
directory run:
```
import numpy as np
from PIL import Image

# np.load needs the file in binary mode; opening it by path handles that, and
# the arrays are read while the archive is still open.
with np.load('latents.npz') as data:
  illumination = data['illumination']  # lighting source position, a 3-vector
  pose_quat = data['pose_quat']  # object pose (3D rotation as a quaternion)
  pose_mat = data['pose_mat']  # object pose (3D rotation as a matrix)

def get_data(i):
  """Return data and latent given an index below 100,000."""
  img = np.array(Image.open(f'images/{i:05}.jpg'))
  # Uses the matrix, not quaternion, representation,
  # similarly to the experiments in the paper
  latent = np.concatenate((illumination[i],
                           pose_mat[i].reshape(-1)))
  return img, latent

img, latent = get_data(0)
```
To do the same train/test split as in TensorFlow Datasets, simply use the first
80,000 images for each object as training data and the last 20,000 as test.
## Giving Credit
If you use this code or the Stanford 3D Objects for Disentangling data in your
work, we ask you to cite this paper:
```
@article{pfau2020disentangling,
title={Disentangling by Subspace Diffusion},
author={Pfau, David and Higgins, Irina and Botev, Aleksandar and Racani\`ere,
S{\'e}bastian},
journal={Advances in Neural Information Processing Systems (NeurIPS)},
year={2020}
}
```
## Disclaimer
This is not an official Google product.
================================================
FILE: geomancer/data_writer.py
================================================
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data writer for Stanford Bunny experiments and other objects."""
# pylint: disable=unused-import
import copy
import io
import os
import time
from absl import app
from absl import flags
from absl import logging
from dm_control import mujoco
import numpy as np
# Command-line flags: which shard to write, how many images per shard, which
# object to render, and where its .stl mesh files live.
_SHARD = flags.DEFINE_integer('shard', 0, 'Shard index')
_SIZE = flags.DEFINE_integer('size', 1000,
                             'Number of images to save to a shard')
_OBJECT = flags.DEFINE_string('object', 'dragon', 'Which object to render')
_PATH = flags.DEFINE_string('path', '', 'Path to folder with .stl files')
# Image sizes in pixels. NOTE(review): presumably scenes are rendered at
# render_height x render_width and then downsampled to height x width before
# saving -- the code using these constants is outside this view; confirm.
render_height = 1024
render_width = 1024
height = 256
width = 256
def get_normal(x):
  """Get vectors normal to a unit vector.

  Args:
    x: A 1-D array of length n (presumably a unit direction such as a light
      position -- TODO confirm against callers).

  Returns:
    An (n, n - 1) array whose columns are orthonormal and orthogonal to `x`.
  """
  # The SVD of the 1 x n matrix x[None, :] yields vh whose first row is the
  # normalized direction of x and whose remaining rows span its orthogonal
  # complement. numpy returns right singular vectors as the ROWS of vh, so
  # take rows 1: and transpose them into columns (indexing columns of vh is
  # only correct when vh happens to be symmetric).
  _, _, vh = np.linalg.svd(x[None, :])
  return vh[1:].T
def render(quat, light, mesh='bunny', meshdir='data'):
"""Script to render an image."""
scale, pos = None, None
if mesh == 'bunny':
scale = 0.03
pos = -1.0
elif mesh == 'dragon':
scale = 0.06
pos = -0.3
simple_world_mjcf_template = """
We also make available the code for linear evaluation of a pre-trained model
on UCF101 and the JAX checkpoints for our best models.
We use different parameters for video compression in UCF101 than the ones
used in `tensorflow_datasets`. We provide the code to download and
preprocess the dataset. The eval_ucf101.py script reproduces the results we
report in Table 2 of the paper, using the checkpoints provided below.
Visual Backbone | Training Dataset | Results on Linear UCF101
------- | -------- | --------
S3D-G | AudioSet + HowTo | 89.6
Resnet-50 TSM | AudioSet + HowTo | 91.5
Resnet-50 TSM (x2) | AudioSet + HowTo | 91.8
## Setup
To set up a Python virtual environment with the required dependencies, run:
```shell
python3 -m venv mmv_env
source mmv_env/bin/activate
pip install --upgrade pip setuptools wheel
pip install -r mmv/requirements.txt --use-feature=2020-resolver
```
### Linear evaluation
The linear evaluation on UCF101 can be run using:
```shell
python -m mmv.eval_ucf101 \
  --checkpoint_path=<path/to/checkpoint.pkl> \
  --dataset_folder=<path/to/ucf101>
```
## Checkpoints
We provide three checkpoints containing the best pre-trained weights for each
of the visual backbones we use in the paper, i.e., S3D-G, Resnet-50 TSM,
and Resnet-50 TSM x2.
- [S3D-G](https://storage.googleapis.com/deepmind-research-mmv/mmv_s3d.pkl)
- [Resnet-50 TSM](https://storage.googleapis.com/deepmind-research-mmv/mmv_tsm_resnet_x1.pkl)
- [Resnet-50 TSMx2](https://storage.googleapis.com/deepmind-research-mmv/mmv_tsm_resnet_x2.pkl)
## References
### Citing our work
If you use this code for your research, please consider citing our paper:
```bibtex
@inproceedings{alayrac2020self,
title={{S}elf-{S}upervised {M}ulti{M}odal {V}ersatile {N}etworks},
author={Alayrac, Jean-Baptiste and Recasens, Adri{\`a} and Schneider, Rosalia and Arandjelovi{\'c}, Relja and Ramapuram, Jason and De Fauw, Jeffrey and Smaira, Lucas and Dieleman, Sander and Zisserman, Andrew},
booktitle={NeurIPS},
year={2020}
}
```
### Models in TF
You may also be interested in using our TF-Hub release models available at:
- [S3D-G](https://tfhub.dev/deepmind/mmv/s3d/1)
- [Resnet-50 TSM](https://tfhub.dev/deepmind/mmv/tsm-resnet50/1)
- [Resnet-50 TSMx2](https://tfhub.dev/deepmind/mmv/tsm-resnet50x2/1)
## License
While the code is licensed under the Apache 2.0 License, the checkpoint weights
are made available for non-commercial use only under the terms of the
Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0)
license. You can find details at:
https://creativecommons.org/licenses/by-nc/4.0/legalcode.
================================================
FILE: mmv/config.py
================================================
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration parameters for MMV."""
def get_model_config(ckpt_path):
  """Builds the MMV model configuration that matches a checkpoint.

  Everything except the visual backbone is identical for all checkpoints.
  The backbone (and, for the x2 model, the width multiplier) is selected by
  looking for known substrings in `ckpt_path`; when several match, the one
  checked last ('tsm_resnet_x2') wins.

  Args:
    ckpt_path: Checkpoint path or file name, e.g. containing 's3d',
      'tsm_resnet_x1' or 'tsm_resnet_x2'.

  Returns:
    A newly built nested dict. If no known substring is found, the
    'visual_backbone' key is not set.
  """

  def fresh_bn_config():
    # A new dict on every call so that no two sections share an object.
    return {
        'create_offset': True,
        'create_scale': True,
        'decay_rate': 0.9,
        'eps': 1.0e-5
    }

  model_config = {
      'audio_backbone': 'resnet50',
      'audio_model_kwargs': {'bn_config': fresh_bn_config()},
      'bn_config_proj': fresh_bn_config(),
      'config_audio_text': {
          'embedding_dim': 512,
          'toaud_bn_after_proj': False,
          'toaud_head_mode': 'linear',
          'totxt_bn_after_proj': False,
          'totxt_head_mode': 'linear'
      },
      'config_video_audio': {
          'embedding_dim': 512,
          'toaud_bn_after_proj': True,
          'toaud_head_mode': 'mlp@512',
          'tovid_bn_after_proj': False,
          'tovid_head_mode': 'linear'
      },
      'config_video_text': {
          'embedding_dim': 256,
          'totxt_bn_after_proj': True,
          'totxt_head_mode': 'linear',
          'tovid_bn_after_proj': False,
          'tovid_head_mode': 'linear'
      },
      'mm_embedding_graph': 'fac_relu',
      'name': 'text_audio_video',
      'sentence_dim': 2048,
      'use_xreplica_bn': True,
      'vision_model_kwargs': {
          'bn_config': fresh_bn_config(),
          'n_frames': 32,
          'width_mult': 1,
      },
  }

  # Checked in order; a later match overwrites an earlier one.
  backbone_by_tag = (
      ('s3d', 's3d'),
      ('tsm_resnet_x1', 'resnet50tsm'),
      ('tsm_resnet_x2', 'resnet50tsm'),
  )
  for tag, backbone in backbone_by_tag:
    if tag in ckpt_path:
      model_config['visual_backbone'] = backbone

  if 'tsm_resnet_x2' in ckpt_path:
    # The x2 checkpoint uses a double-width ResNet-50 TSM.
    model_config['vision_model_kwargs']['width_mult'] = 2

  return model_config
================================================
FILE: mmv/eval_ucf101.py
================================================
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""UCF101 linear evaluation."""
import functools
from typing import Any, Dict, Optional
from absl import app
from absl import flags
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import sklearn
from sklearn import preprocessing
import sklearn.linear_model
import sklearn.svm
import tensorflow as tf
import tensorflow_datasets as tfds
from mmv import config
from mmv.models import mm_embeddings
from mmv.utils import checkpoint
from mmv.utils import ucf101_dataset
# Pre-trained weights to evaluate. NOTE(review): the help text says
# "directory", but the default points at a single .pkl file; the path is
# presumably also what config.get_model_config inspects -- confirm.
flags.DEFINE_string('checkpoint_path', '~/tmp/mmv_s3d.pkl',
                    'The directory to load pre-trained weights from.')
# Folder holding the UCF101 data (presumably consumed through
# mmv.utils.ucf101_dataset -- confirm).
flags.DEFINE_string('dataset_folder', '/tmp/ucf101',
                    'The directory with the ucf101 dataset.')
# Batch sizes for the evaluation and training passes.
flags.DEFINE_integer('eval_batch_size', 1,
                     'The batch size for evaluation.')
flags.DEFINE_integer('train_batch_size', 16,
                     'The batch size for training.')
# Number of passes over the training set used to collect features.
flags.DEFINE_integer('num_train_epochs', 10,
                     'How many epochs to collect features during training.')
# Number of clips per test video whose predictions are averaged (presumably
# passed as `num_windows` to process_samples -- confirm).
flags.DEFINE_integer('num_test_windows', 10,
                     'How many windows to average on during test.')
# Frame preprocessing sizes. NOTE(review): presumably forwarded to
# process_samples, where `crop_size` is the side of the square crop rather
# than a resize target as its help text suggests -- confirm.
flags.DEFINE_integer('min_resize', 224,
                     'Min value to resize images to during preprocessing.')
flags.DEFINE_integer('crop_size', 224,
                     'Value to resize images to during preprocessing.')
# Clip length in frames and the temporal stride between sampled frames.
flags.DEFINE_integer('num_frames', 32,
                     'Number of video frames.')
flags.DEFINE_integer('stride', 2,
                     'Stride for video frames.')
# Which UCF101 train/test split to use.
flags.DEFINE_integer('ucf101_split', 1,
                     'Which split of ucf101 to use.')
FLAGS = flags.FLAGS
def get_sampling_offset(sequence: tf.Tensor,
                        num_steps: Optional[int],
                        is_training: bool,
                        stride: int = 1,
                        seed: Optional[int] = None) -> tf.Tensor:
  """Chooses the index at which strided sampling of a sequence starts.

  Args:
    sequence: Any tensor whose first dimension indexes timesteps.
    num_steps: How many timesteps will be taken. If None, sampling always
      starts at the first frame.
    is_training: Whether this is for training. If False, sampling always
      starts at the first frame.
    stride: Distance between consecutive sampled timesteps.
    seed: Optional op-level seed forwarded to `tf.random.uniform`.

  Returns:
    A scalar int32 tensor holding the starting index. If the sequence is long
    enough, the offset is drawn uniformly among the values for which every
    requested step lies inside the sequence, i.e.
    offset + (num_steps - 1) * stride < len(sequence). Otherwise it is drawn
    uniformly from [0, len(sequence)). NOTE(review): for an empty sequence the
    upper bound is 0 -- confirm callers never pass one.
  """
  if not is_training or num_steps is None:
    return tf.constant(0)

  seq_len = tf.shape(sequence)[0]
  window_span = (num_steps - 1) * stride
  # Only narrow the range of start positions when the whole window fits.
  upper = tf.cond(
      tf.greater(seq_len, window_span),
      lambda: seq_len - window_span,
      lambda: seq_len)
  return tf.random.uniform(
      (), maxval=tf.cast(upper, tf.int32), dtype=tf.int32, seed=seed)
def sample_or_pad_sequence_indices(sequence: tf.Tensor,
                                   num_steps: Optional[int],
                                   is_training: bool,
                                   repeat_sequence: bool = True,
                                   stride: int = 1,
                                   offset: Optional[int] = None) -> tf.Tensor:
  """Builds indices selecting a fixed number of strided timesteps.

  With `num_steps` set, the indices start at `offset` (or, if that is None,
  at the value from `get_sampling_offset`, which is random only when
  `is_training` is True) and advance by `stride` for `num_steps` steps. With
  `repeat_sequence`, the index range is tiled so that a sequence that is too
  short wraps around. If `num_steps` is None or 0, every `stride`-th timestep
  of the whole sequence is selected and `is_training`/`offset` are unused.

  Args:
    sequence: Any tensor whose first dimension indexes timesteps.
    num_steps: Number of timesteps to select, or None/0 for all of them.
    is_training: Forwarded to `get_sampling_offset` when `offset` is None.
    repeat_sequence: Whether to loop the sequence when it is too short. If
      False and the window runs past the end, the final gather receives
      out-of-range indices (an error on CPU).
    stride: Distance between consecutive selected timesteps.
    offset: Optional fixed start index; takes precedence over sampling.

  Returns:
    A 1-D integer tensor of indices into the first dimension of `sequence`.
  """
  seq_len = tf.shape(sequence)[0]
  positions = tf.range(seq_len)

  if not num_steps:
    return tf.gather(positions, tf.range(0, seq_len, stride))

  if offset is None:
    offset = get_sampling_offset(sequence, num_steps, is_training, stride)
  window_end = offset + num_steps * stride

  if repeat_sequence:
    # Enough copies of [0, seq_len) to cover every position below window_end.
    copies = tf.cast(
        tf.math.ceil(
            tf.math.divide(tf.cast(window_end, tf.float32),
                           tf.cast(seq_len, tf.float32))),
        tf.int32)
    positions = tf.tile(positions, [copies])

  return tf.gather(positions, tf.range(offset, window_end, stride))
def random_sample_sequence(sequence: tf.Tensor,
                           num_steps: int,
                           stride: int = 1) -> tf.Tensor:
  """Takes `num_steps` strided timesteps starting at a random offset.

  The sequence is looped when it is too short to supply every requested step.
  """
  picked = sample_or_pad_sequence_indices(
      sequence=sequence,
      num_steps=num_steps,
      is_training=True,  # Draw a random starting offset.
      repeat_sequence=True,  # Wrap around if the sequence is too short.
      stride=stride,
      offset=None)
  # The number of indices is fixed; expose it to static shape inference.
  picked.set_shape((num_steps,))
  return tf.gather(sequence, picked)
def sample_linspace_sequence(sequence: tf.Tensor,
                             num_windows: int,
                             num_steps: int,
                             stride: int = 1) -> tf.Tensor:
  """Takes `num_windows` clips whose start offsets are evenly spaced.

  The clips are concatenated along the first dimension, so the output keeps
  the per-timestep structure of the input (e.g. one frame per row). If
  num_steps * stride exceeds the sequence length, the sequence is looped.
  Intended for evaluation, to cover the whole sequence with several clips.

  Args:
    sequence: Any tensor whose first dimension indexes timesteps.
    num_windows: Number of clips to extract.
    num_steps: Number of timesteps (e.g. frames) per clip.
    stride: Distance between consecutive timesteps within a clip.

  Returns:
    A tensor whose first dimension is num_windows * num_steps, holding the
    clips one after another.
  """
  seq_len = tf.shape(sequence)[0]
  last_start = tf.maximum(0, seq_len - num_steps * stride)
  # Evenly spaced (then truncated to int) start positions in [0, last_start].
  starts = tf.cast(
      tf.linspace(0.0, tf.cast(last_start, tf.float32), num_windows),
      tf.int32)

  window_indices = [
      sample_or_pad_sequence_indices(
          sequence=sequence,
          num_steps=num_steps,
          is_training=False,
          repeat_sequence=True,  # Wrap around if the sequence is too short.
          stride=stride,
          offset=starts[w])
      for w in range(num_windows)
  ]

  indices = tf.concat(window_indices, axis=0)
  indices.set_shape((num_windows * num_steps,))
  return tf.gather(sequence, indices)
def resize_smallest(frames: tf.Tensor, min_resize: int) -> tf.Tensor:
  """Rescales frames so that their shorter side equals `min_resize`.

  The longer side is scaled by the same factor (rounded down by integer
  division). When the computed size equals the input size -- e.g. the shorter
  side is already `min_resize` -- the frames are returned without resizing.

  Args:
    frames: Tensor of shape [timesteps, input_h, input_w, channels].
    min_resize: Target length of the shorter spatial side.

  Returns:
    Tensor of shape [timesteps, output_h, output_w, channels] with the same
    dtype as `frames`, where min(output_h, output_w) == min_resize.
  """
  dims = tf.shape(frames)
  in_h, in_w = dims[1], dims[2]
  out_h = tf.maximum(min_resize, (in_h * min_resize) // in_w)
  out_w = tf.maximum(min_resize, (in_w * min_resize) // in_h)

  def _resized():
    # tf.image.resize returns floats; cast back to the input dtype.
    scaled = tf.image.resize(frames, (out_h, out_w))
    return tf.cast(scaled, frames.dtype)

  needs_resize = tf.math.logical_or(tf.not_equal(in_h, out_h),
                                    tf.not_equal(in_w, out_w))
  return tf.cond(needs_resize, _resized, lambda: frames)
def process_samples(features_dict, num_frames=32, stride=1, is_training=True,
                    min_resize=224, crop_size=224, num_windows=1):
  """Process video frames.

  Training: randomly samples `num_frames` frames, flips them left/right with
  probability 1/2, resizes the shorter side to `min_resize` and takes a random
  `crop_size` x `crop_size` crop (3 channels). Evaluation: samples
  `num_windows` linearly spaced windows of `num_frames` frames, resizes the
  shorter side and takes a central crop. Both paths then cast to float32 and
  divide by 255 (i.e. [0, 1] assuming 8-bit input — confirm input dtype).

  Args:
    features_dict: Example dict; its 'video' entry is replaced in place.
    num_frames: Frames per clip (window).
    stride: Temporal stride between sampled frames.
    is_training: Selects the random (training) or deterministic (eval) path.
    min_resize: Target size of the shorter spatial side.
    crop_size: Height and width of the output crop.
    num_windows: Number of eval windows; must be 1 when training.

  Returns:
    `features_dict` with the processed 'video' entry.
  """
  frames = features_dict['video']
  if is_training:
    assert num_windows == 1
    frames = random_sample_sequence(frames, num_frames, stride)
    coin = tf.random.uniform((), minval=0, maxval=2, dtype=tf.int32)
    frames = tf.cond(tf.equal(coin, 1),
                     true_fn=lambda: tf.image.flip_left_right(frames),
                     false_fn=lambda: frames)
    frames = resize_smallest(frames, min_resize)
    # Random spatial crop; the channel count is fixed to 3.
    frames = tf.image.random_crop(frames,
                                  [num_frames, crop_size, crop_size, 3])
  else:
    frames = sample_linspace_sequence(frames, num_windows, num_frames, stride)
    frames = resize_smallest(frames, min_resize)
    # Central crop.
    frames = tf.image.resize_with_crop_or_pad(frames, crop_size, crop_size)
  features_dict['video'] = tf.cast(frames, tf.float32) / 255.0
  return features_dict
def space_to_depth_batch(features_dict):
  """Moves each 2x2x2 (time, height, width) block of the video into channels.

  Uses the static shape of the batched 'video' tensor and maps
  [B, T, H, W, C] to [B, T // 2, H // 2, W // 2, 8 * C]. The 'video' entry of
  `features_dict` is replaced in place and the dict is returned.
  """
  video = features_dict['video']
  _, t_len, h_len, w_len, c_len = video.shape
  split = tf.reshape(
      video, [-1, t_len // 2, 2, h_len // 2, 2, w_len // 2, 2, c_len])
  # Group the three size-2 sub-axes next to the channel axis.
  split = tf.transpose(split, [0, 1, 3, 5, 2, 4, 6, 7])
  features_dict['video'] = tf.reshape(
      split, [-1, t_len // 2, h_len // 2, w_len // 2, 8 * c_len])
  return features_dict
def reshape_windows(features_dict, num_frames):
  """Splits the concatenated eval windows into separate clips.

  A batched 'video' whose second axis holds several windows of `num_frames`
  frames each is reshaped to [-1, num_frames, H, W, C] (row-major, so the
  windows of one video stay contiguous). The entry is replaced in place and
  the dict is returned.
  """
  video = features_dict['video']
  static_shape = video.shape
  clip_shape = (-1, num_frames, static_shape[2], static_shape[3],
                static_shape[4])
  features_dict['video'] = tf.reshape(video, clip_shape)
  return features_dict
def compute_accuracy_metrics(pred, gt, prefix=''):
  """Computes top-1 and top-5 accuracy from per-class scores.

  Args:
    pred: 2-D array [num_examples, num_classes]; larger means more likely.
    gt: 2-D array of labels, [num_examples, 1].
    prefix: String prepended to the metric names.

  Returns:
    Dict with keys `prefix + 'top1'` and `prefix + 'top5'`: the fraction of
    rows whose label is among the 1 (resp. 5) highest-scoring classes.
  """
  # Ascending order, so the best-scoring classes are the last columns.
  ranking = np.argsort(pred, axis=1)
  assert len(gt.shape) == len(ranking.shape) == 2
  top1_hits = ranking[:, -1:] == gt
  top5_hits = (ranking[:, -5:] == gt).max(axis=1)
  return {
      prefix + 'top1': np.mean(top1_hits),
      prefix + 'top5': np.mean(top5_hits),
  }
def forward_fn(images: jnp.ndarray,
               audio_spectrogram: jnp.ndarray,
               word_ids: jnp.ndarray,
               is_training: bool,
               model_config: Dict[str, Any]):
  """Forward pass of the model, returning only the visual representation.

  The word-embedding matrix handed to the model is an all-zero placeholder of
  shape [65536, 300]; the real values come from the checkpoint parameters
  that are passed to `apply`.

  Returns:
    The 'vid_repr' entry of the model output.
  """
  vocab_size, embedding_width = 65536, 300
  placeholder_matrix = jnp.zeros(shape=(vocab_size, embedding_width))
  model = mm_embeddings.AudioTextVideoEmbedding(
      word_embedding_matrix=placeholder_matrix, **model_config)
  outputs = model(images=images,
                  audio_spectrogram=audio_spectrogram,
                  word_ids=word_ids,
                  is_training=is_training)
  return outputs['vid_repr']
def _collect_features(dataset, forward_apply, params, state, batch_size,
                      split_name, audio_frames=96, mel_filters=40,
                      num_tokens=16):
  """Extracts visual representations for every batch of `dataset`.

  Only the model's 'vid_repr' output is used, but the model signature also
  requires audio and text inputs, so all-zero dummy tensors are fed for those.

  Args:
    dataset: Batched dataset whose elements carry 'video' and 'label' entries.
    forward_apply: Jitted `forward.apply` returning (vid_repr, state).
    params: Model parameters.
    state: Model state.
    batch_size: Leading dimension of the dummy audio/text inputs.
    split_name: Split name used in the log messages.
    audio_frames: Number of frames of the dummy spectrogram.
    mel_filters: Number of mel filters of the dummy spectrogram.
    num_tokens: Number of tokens of the dummy text input.

  Returns:
    A (features, labels) tuple of numpy arrays concatenated over batches.

  Raises:
    ValueError: If `dataset` yields no batches.
  """
  dummy_audio = jnp.zeros(shape=(batch_size, audio_frames, mel_filters, 1))
  dummy_word_ids = jnp.zeros(shape=(batch_size, num_tokens), dtype=jnp.int32)
  features = []
  labels = []
  num_examples = 0
  print(f'Computing features on {split_name}')
  for batch in tfds.as_numpy(dataset):
    vid_representation, _ = forward_apply(params=params,
                                          state=state,
                                          images=batch['video'],
                                          audio_spectrogram=dummy_audio,
                                          word_ids=dummy_word_ids)
    labels.append(batch['label'])
    features.append(vid_representation)
    # Count labelled examples (not batches) so the progress log is accurate.
    num_examples += len(batch['label'])
    if len(labels) % 50 == 0:
      print(f'Processed {num_examples} examples.')
  if not labels:
    raise ValueError(f'No batches found for the {split_name} split.')
  return np.concatenate(features, axis=0), np.concatenate(labels, axis=0)


def main(argv):
  """Linear-probe evaluation of a pretrained checkpoint on UCF101.

  Extracts visual representations for the train and test splits, fits a
  `LinearSVC` on standardized train features and prints top-1/top-5
  accuracies. With several test windows per video, the per-window SVM scores
  are averaged before the test accuracies are computed.
  """
  del argv
  sklearn_reg = 0.001
  model_config = config.get_model_config(FLAGS.checkpoint_path)
  forward = hk.without_apply_rng(hk.transform_with_state(forward_fn))
  forward_apply = jax.jit(functools.partial(forward.apply,
                                            is_training=False,
                                            model_config=model_config))

  # Get the UCF101 config.
  dset_config = tfds.video.ucf101.Ucf101.BUILDER_CONFIGS[FLAGS.ucf101_split]
  builder = ucf101_dataset.ModUcf101(
      data_dir=FLAGS.dataset_folder,
      config=dset_config)
  # Create the tfrecord files (no-op if already exists)
  # NOTE(review): SSL certificate verification is disabled for the download.
  dl_config = tfds.download.DownloadConfig(verify_ssl=False)
  builder.download_and_prepare(download_config=dl_config)

  # Generate the training dataset.
  train_ds = builder.as_dataset(split='train', shuffle_files=False)
  train_ds = train_ds.map(lambda x: process_samples(  # pylint: disable=g-long-lambda
      x, num_frames=FLAGS.num_frames, stride=FLAGS.stride, is_training=True,
      min_resize=FLAGS.min_resize, crop_size=FLAGS.crop_size))
  train_ds = train_ds.batch(batch_size=FLAGS.train_batch_size)
  if model_config['visual_backbone'] == 's3d':
    train_ds = train_ds.map(space_to_depth_batch)
  train_ds = train_ds.repeat(FLAGS.num_train_epochs)

  # Generate the test dataset.
  test_ds = builder.as_dataset(split='test', shuffle_files=False)
  test_ds = test_ds.map(lambda x: process_samples(  # pylint: disable=g-long-lambda
      x, num_frames=FLAGS.num_frames, stride=FLAGS.stride, is_training=False,
      min_resize=FLAGS.min_resize, crop_size=FLAGS.crop_size,
      num_windows=FLAGS.num_test_windows))
  test_ds = test_ds.batch(batch_size=FLAGS.eval_batch_size)
  test_ds = test_ds.map(lambda x: reshape_windows(  # pylint: disable=g-long-lambda
      x, num_frames=FLAGS.num_frames))
  if model_config['visual_backbone'] == 's3d':
    test_ds = test_ds.map(space_to_depth_batch)
  test_ds = test_ds.repeat(1)

  pretrained_weights = checkpoint.load_checkpoint(FLAGS.checkpoint_path)
  params = pretrained_weights['params']
  state = pretrained_weights['state']

  # Collect training samples.
  train_features, train_labels = _collect_features(
      train_ds, forward_apply, params, state,
      batch_size=FLAGS.train_batch_size, split_name='train')
  print(f'Finish collecting train features of shape {train_features.shape}')

  # Collect test samples.
  test_features, test_labels = _collect_features(
      test_ds, forward_apply, params, state,
      batch_size=FLAGS.eval_batch_size, split_name='test')
  print(f'Finish collecting test features of shape {test_features.shape}')

  # Train classifier
  print('Training linear classifier!')
  classifier = sklearn.svm.LinearSVC(C=sklearn_reg)
  scaler = preprocessing.StandardScaler().fit(train_features)
  train_features = scaler.transform(train_features)
  classifier.fit(train_features, train_labels.ravel())
  print('Training done !')

  # Evaluation.
  test_features = scaler.transform(test_features)
  print('Running inference on train')
  pred_train = classifier.decision_function(train_features)
  print('Running inference on test')
  pred_test = classifier.decision_function(test_features)
  if FLAGS.num_test_windows > 1:
    # reshape_windows keeps the windows of each video contiguous, so group
    # them per video and average their scores.
    pred_test = np.reshape(
        pred_test, (test_labels.shape[0], -1, pred_test.shape[1]))
    pred_test = pred_test.mean(axis=1)

  # Compute accuracies.
  metrics = compute_accuracy_metrics(pred_train, train_labels[:, None],
                                     prefix='train_')
  metrics.update(
      compute_accuracy_metrics(pred_test, test_labels[:, None], prefix='test_'))
  print(metrics)
if __name__ == '__main__':
  # `app.run` parses the command-line flags, then calls main(argv).
  app.run(main)
================================================
FILE: mmv/models/mm_embeddings.py
================================================
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model for text-video-audio embeddings."""
from typing import Any, Dict, Optional
import haiku as hk
import jax
import jax.numpy as jnp
from mmv.models import normalization
from mmv.models import resnet
from mmv.models import s3d
from mmv.models import tsm_resnet
_DEFAULT_CFG_AUDTXT = {
"totxt_head_mode": "linear",
"toaud_head_mode": "linear",
"toaud_bn_after_proj": False,
"totxt_bn_after_proj": False,
"embedding_dim": 512}
_DEFAULT_CFG_VIDAUD = {
"tovid_head_mode": "linear",
"toaud_head_mode": "mlp@512",
"tovid_bn_after_proj": False,
"toaud_bn_after_proj": True,
"embedding_dim": 512}
_DEFAULT_CFG_VIDTXT = {
"tovid_head_mode": "linear",
"totxt_head_mode": "mlp@512",
"tovid_bn_after_proj": False,
"totxt_bn_after_proj": True,
"embedding_dim": 512}
_DEFAULT_CFG_BN = {"decay_rate": 0.9, "eps": 1e-5,
"create_scale": True, "create_offset": True}
def _setkey_if_not_exists(d, key, value):
if key not in d:
d[key] = value
def _merge_with_defaults(config, defaults):
  """Returns a new dict holding `defaults` overridden by entries of `config`.

  Neither argument is modified. A `config` of None yields a copy of
  `defaults`.
  """
  merged = dict(defaults)
  if config is not None:
    merged.update(config)
  return merged


class AudioTextVideoEmbedding(hk.Module):
  """Module to fuse audio, text and video for joint embedding learning.

  Each modality is encoded by its own backbone (VisualModule, AudioModule,
  TextModule) and projected by EmbeddingModule heads into the embedding
  spaces shared with the other two modalities.
  """

  def __init__(
      self,
      # Language parameters.
      word_embedding_matrix,
      sentence_dim=2048,
      # Audio parameters.
      audio_backbone="resnet18",
      audio_model_kwargs=None,
      # Vision parameters.
      visual_backbone="s3d",
      vision_model_kwargs=None,
      # Common parameters.
      mm_embedding_graph="fac_relu",
      use_xreplica_bn=True,
      bn_config_proj=None,
      config_video_text=None,
      config_video_audio=None,
      config_audio_text=None,
      use_audio_text=False,
      name="audio_text_video_model"):
    """Initialize the AudioTextVideoEmbedding class.

    Args:
      word_embedding_matrix: 2d matrix [vocab_size, embed_size] to embed words.
      sentence_dim: The dimension of the sentence representation.
      audio_backbone: Backbone for audio.
      audio_model_kwargs: Other specific parameters to pass to the audio
        module.
      visual_backbone: The video backbone.
      vision_model_kwargs: Other specific parameters to pass to the vision
        module.
      mm_embedding_graph: Embedding graph merging strategy.
        Can be `shared`, `disjoint` or `fac` (fac can be followed by an
        activation function name, `relu` or `bnrelu`, e.g. `fac_relu`).
      use_xreplica_bn: Whether or not to use the cross replica batch norm.
      bn_config_proj: BN config of the projection heads.
      config_video_text: Config for the video and the text branches. Missing
        keys are taken from `_DEFAULT_CFG_VIDTXT`; the argument is copied and
        never modified.
      config_video_audio: Config for the video and the audio branches. Missing
        keys are taken from `_DEFAULT_CFG_VIDAUD`; the argument is copied and
        never modified.
      config_audio_text: Config for the audio and the text branches. Missing
        keys are taken from `_DEFAULT_CFG_AUDTXT`; the argument is copied and
        never modified.
      use_audio_text: Whether or not the audio text branch is used during
        training.
      name: graph name.
    """
    super(AudioTextVideoEmbedding, self).__init__(name=name)
    # Audio parameters.
    self._audio_backbone = audio_backbone
    self._audio_model_kwargs = audio_model_kwargs
    # Language parameters.
    self._sentence_dim = sentence_dim
    self._word_embedding_matrix = word_embedding_matrix
    # Vision parameters.
    self._visual_backbone = visual_backbone
    self._vision_model_kwargs = vision_model_kwargs
    # Joint parameters.
    self._use_xreplica_bn = use_xreplica_bn
    if self._use_xreplica_bn:
      self._normalizer_name = "cross_replica_batch_norm"
    else:
      self._normalizer_name = "batch_norm"
    # Projection head parameters. Defaults are merged into fresh dicts so that
    # neither the caller's configs nor the module-level defaults are mutated
    # or aliased by this instance.
    self._cfg_vid_txt = _merge_with_defaults(config_video_text,
                                             _DEFAULT_CFG_VIDTXT)
    self._cfg_vid_aud = _merge_with_defaults(config_video_audio,
                                             _DEFAULT_CFG_VIDAUD)
    self._cfg_aud_txt = _merge_with_defaults(config_audio_text,
                                             _DEFAULT_CFG_AUDTXT)
    self._use_audio_text = use_audio_text
    self._mm_embedding_graph = mm_embedding_graph
    # "disjoint" and "fac*" graphs build two heads per modality; "shared"
    # reuses one head for both target modalities.
    self._use_separate_heads = (
        mm_embedding_graph == "disjoint" or
        mm_embedding_graph.startswith("fac"))
    self._bn_config_proj = bn_config_proj or _DEFAULT_CFG_BN

  def _get_pair_embedding_heads(self,
                                embedding_dim_1, embedding_dim_2,
                                mode1, mode2,
                                use_bn_out1, use_bn_out2,
                                name1, name2):
    """Builds the two projection heads leaving one modality.

    With separate heads two EmbeddingModules are created; otherwise the first
    head is returned twice, after asserting both requested configs agree.

    Returns:
      A (head_1, head_2) tuple of EmbeddingModules.
    """
    embd1_module = EmbeddingModule(
        embedding_dim_1,
        mode=mode1,
        use_bn_out=use_bn_out1,
        bn_config=self._bn_config_proj,
        use_xreplica_bn=self._use_xreplica_bn,
        name=name1)
    if self._use_separate_heads:
      embd2_module = EmbeddingModule(
          embedding_dim_2,
          mode=mode2,
          use_bn_out=use_bn_out2,
          use_xreplica_bn=self._use_xreplica_bn,
          bn_config=self._bn_config_proj,
          name=name2)
    else:
      assert embedding_dim_1 == embedding_dim_2, (
          "Using shared heads but inconsistent embedding dims where provided.")
      assert mode1 == mode2, (
          "Using shared heads but inconsistent modes where provided.")
      assert use_bn_out1 == use_bn_out2, (
          "Using shared heads but inconsistent bn conf where provided.")
      embd2_module = embd1_module
    return embd1_module, embd2_module

  def _activate_interaction(self, inputs, activation_fn, is_training,
                            activation_module=None):
    """Activation function for the interaction modules.

    Args:
      inputs: Tensor to activate.
      activation_fn: "relu" or "bnrelu" (normalization followed by ReLU).
      is_training: Forwarded to the normalizer.
      activation_module: Normalize function to reuse for "bnrelu"; created
        from `self._bn_config_proj` when None.

    Returns:
      (activated inputs, normalize function or None for "relu").

    Raises:
      ValueError: If `activation_fn` is not supported.
    """
    if activation_fn == "relu":
      inputs = jax.nn.relu(inputs)
    elif activation_fn == "bnrelu":
      # NOTE(review): get_normalize_fn returns a function that adds a new
      # normalization layer on every application, so reusing
      # `activation_module` does not share BN parameters between calls —
      # confirm whether sharing was intended.
      if activation_module is None:
        activation_module = normalization.get_normalize_fn(
            normalizer_name=self._normalizer_name,
            normalizer_kwargs=self._bn_config_proj)
      inputs = activation_module(inputs, is_training=is_training)
      inputs = jax.nn.relu(inputs)
    else:
      raise ValueError(f"{activation_fn} not supported.")
    return inputs, activation_module

  def __call__(self,
               images,
               audio_spectrogram,
               word_ids,
               is_training,
               return_intermediate_audio=False):
    """Computes video, text and audio embeddings.

    Args:
      images: The videos tensor of shape [B1, T, H, W, 3] where B1 is the batch
        size, T is the number of frames per clip, H the height, W the width
        and 3 the rgb channels.
      audio_spectrogram: The audio tensor of shape [B2, T', F] where B2 is the
        batch size, T' is the number of temporal frames, F is the number of
        frequency frames. NOTE(review): the UCF101 eval script feeds
        [B2, T', F, 1] (trailing channel axis) — confirm the expected rank.
      word_ids: Integer word indices of shape [B3, N], where B3 is the batch
        size and N the maximum number of words per sentence. They are embedded
        inside the graph with `word_embedding_matrix`.
      is_training: Whether or not to activate the graph in training mode.
      return_intermediate_audio: Return audio intermediate representation.

    Returns:
      If `return_intermediate_audio` is True: the audio backbone output at its
      "last_conv" endpoint (per the original notes, the 4-dim representation
      before spatial averaging). The text branch is not built in that case.
      Otherwise a dict with:
        "vid_embd": {"toaud", "totxt"} video embeddings of shape [B1, d_embd].
        "aud_embd": {"tovid", "totxt"} audio embeddings of shape [B2, d_embd].
        "txt_embd": {"tovid", "toaud"} text embeddings of shape [B3, d_embd].
        "vid_repr": the video rep of shape [B1, d_visual].
        "aud_repr": the audio rep of shape [B2, d_audio].
    """
    # Computes the visual representation.
    video_cnn = VisualModule(backbone=self._visual_backbone,
                             use_xreplica_bn=self._use_xreplica_bn,
                             model_kwargs=self._vision_model_kwargs)
    visual_representation = video_cnn(images, is_training=is_training)

    # Projection heads: Video -> Text and Video -> Audio.
    vid2txt_embd_module, vid2aud_embd_module = self._get_pair_embedding_heads(
        embedding_dim_1=self._cfg_vid_txt["embedding_dim"],
        embedding_dim_2=self._cfg_vid_aud["embedding_dim"],
        mode1=self._cfg_vid_txt["totxt_head_mode"],
        mode2=self._cfg_vid_aud["toaud_head_mode"],
        use_bn_out1=self._cfg_vid_txt["totxt_bn_after_proj"],
        use_bn_out2=self._cfg_vid_aud["toaud_bn_after_proj"],
        name1="vis_embd",
        name2="vid2audio_embd")
    video_embd = {}
    # Only assigned and read for "fac*" graphs; initialized here so no name
    # depends on which branch below was taken.
    activation_fn = None
    activation_module = None
    if self._mm_embedding_graph in ["shared", "disjoint"]:
      video_embd["toaud"] = vid2aud_embd_module(visual_representation,
                                                is_training=is_training)
      video_embd["totxt"] = vid2txt_embd_module(visual_representation,
                                                is_training=is_training)
    elif self._mm_embedding_graph.startswith("fac"):
      # Activation function if specified in the name, e.g. fac_relu.
      if len(self._mm_embedding_graph.split("_")) == 2:
        activation_fn = self._mm_embedding_graph.split("_")[1]
      # Factorized graph: the text embedding is computed on top of the
      # (optionally activated) audio-space embedding.
      video_embd["toaud"] = vid2aud_embd_module(visual_representation,
                                                is_training=is_training)
      fine_rep = video_embd["toaud"]
      # Eventually activate the fine grained representation.
      if activation_fn:
        fine_rep, activation_module = self._activate_interaction(
            inputs=fine_rep, activation_fn=activation_fn,
            is_training=is_training)
      video_embd["totxt"] = vid2txt_embd_module(fine_rep,
                                                is_training=is_training)
    else:
      raise ValueError(
          f"{self._mm_embedding_graph} is not a valid MM embedding graph.")

    # Computes the audio representation.
    audio_cnn = AudioModule(backbone=self._audio_backbone,
                            use_xreplica_bn=self._use_xreplica_bn,
                            model_kwargs=self._audio_model_kwargs)
    if return_intermediate_audio:
      return audio_cnn(audio_spectrogram,
                       is_training=is_training,
                       return_intermediate=True)
    audio_representation = audio_cnn(audio_spectrogram, is_training=is_training)

    # Projection heads: Audio -> Video and Audio -> Text.
    aud2vid_embd_module, aud2txt_embd_module = self._get_pair_embedding_heads(
        embedding_dim_1=self._cfg_vid_aud["embedding_dim"],
        embedding_dim_2=self._cfg_aud_txt["embedding_dim"],
        mode1=self._cfg_vid_aud["tovid_head_mode"],
        mode2=self._cfg_aud_txt["totxt_head_mode"],
        use_bn_out1=self._cfg_vid_aud["tovid_bn_after_proj"],
        use_bn_out2=self._cfg_aud_txt["totxt_bn_after_proj"],
        name1="audio_embd",
        name2="audio2txt_embd")
    audio_embd = {}
    audio_embd["tovid"] = aud2vid_embd_module(audio_representation,
                                              is_training=is_training)
    # Computes the projection to the text domain depending on the MM graph mode.
    if (self._mm_embedding_graph.startswith("fac") and
        (self._use_audio_text or (not is_training))):
      # In case the audio text branch is not used during training, we do that
      # only at eval time (is_training=False) in order to not pollute the BN
      # stats in vid2txt_embd_module with audio features during training.
      fine_rep_audio = audio_embd["tovid"]
      if activation_fn:
        fine_rep_audio, _ = self._activate_interaction(
            inputs=fine_rep_audio, activation_fn=activation_fn,
            is_training=is_training, activation_module=activation_module)
      audio_embd["totxt"] = vid2txt_embd_module(fine_rep_audio,
                                                is_training=is_training)
    else:
      audio_embd["totxt"] = aud2txt_embd_module(audio_representation,
                                                is_training=is_training)

    # Computes the text representation.
    txt_representation = TextModule(
        sentence_dim=self._sentence_dim,
        word_embedding_matrix=self._word_embedding_matrix)(
            word_ids, is_training=is_training)

    # Projection heads: Text -> Video and Text -> Audio.
    txt2vid_embd_module, txt2aud_embd_module = self._get_pair_embedding_heads(
        embedding_dim_1=self._cfg_vid_txt["embedding_dim"],
        embedding_dim_2=self._cfg_aud_txt["embedding_dim"],
        mode1=self._cfg_vid_txt["tovid_head_mode"],
        mode2=self._cfg_aud_txt["toaud_head_mode"],
        use_bn_out1=self._cfg_vid_txt["tovid_bn_after_proj"],
        use_bn_out2=self._cfg_aud_txt["toaud_bn_after_proj"],
        name1="txt_embd",
        name2="txt2audio_embd")
    txt_embd = {}
    txt_embd["tovid"] = txt2vid_embd_module(txt_representation,
                                            is_training=is_training)
    txt_embd["toaud"] = txt2aud_embd_module(txt_representation,
                                            is_training=is_training)

    return {
        "vid_embd": video_embd,
        "aud_embd": audio_embd,
        "txt_embd": txt_embd,
        "vid_repr": visual_representation,
        "aud_repr": audio_representation,
    }
class EmbeddingModule(hk.Module):
  """Final Embedding module.

  Projects an input feature to `embedding_dim` according to `mode`:
    * "linear": a single `hk.Linear` named "final_projection" (with bias).
    * "mlp@D1[@D2...]": hidden `hk.Linear` layers of widths D1, D2, ..., each
      followed by normalization and ReLU, then "final_projection" (with a
      bias only when `use_bn_out` is False).
    * "mlp_nobn@D1[@D2...]": like "mlp" but without the hidden normalization.
  When `use_bn_out` is True the resulting embedding is normalized.
  """

  def __init__(self,
               embedding_dim: int,
               mode: str = "linear",
               use_bn_out: bool = False,
               bn_config: Optional[Dict[str, Any]] = None,
               use_xreplica_bn: bool = True,
               name="embedding_module"):
    """Initializes the embedding head.

    Args:
      embedding_dim: Output dimension of the final projection.
      mode: Head architecture ("linear", "mlp@..." or "mlp_nobn@...").
      use_bn_out: Whether to normalize the final embedding.
      bn_config: Normalizer kwargs; defaults to `_DEFAULT_CFG_BN`.
      use_xreplica_bn: Use cross-replica batch norm instead of batch norm.
      name: Module name.
    """
    super(EmbeddingModule, self).__init__(name=name)
    self._embedding_dim = embedding_dim
    self._use_bn_out = use_bn_out
    self._mode = mode
    # Set default BN config.
    bn_config = bn_config or _DEFAULT_CFG_BN
    if use_xreplica_bn:
      normalizer_name = "cross_replica_batch_norm"
    else:
      normalizer_name = "batch_norm"
    # Each application of this function adds a new normalization layer.
    self._batch_norm = normalization.get_normalize_fn(
        normalizer_name=normalizer_name,
        normalizer_kwargs=bn_config)

  def __call__(self, input_feature, is_training):
    """Projects `input_feature` into the embedding space.

    Raises:
      ValueError: If an "mlp" mode does not list inner dims with `@`.
      NotImplementedError: If `mode` is neither "linear" nor "mlp*".
    """
    if self._mode == "linear":
      proj = hk.Linear(self._embedding_dim, name="final_projection")
      embedding = proj(input_feature)
    elif self._mode.startswith("mlp"):
      if "@" not in self._mode:
        raise ValueError(
            "Please specify the inner dimensions of the MLP with `@` symbol "
            "e.g. mlp@512 or mlp@512@256 for a 2 layer MLP.")
      inner_dims = [int(dim) for dim in self._mode.split("@")[1:]]
      embedding = input_feature
      for inner_dim in inner_dims:
        embedding = hk.Linear(inner_dim, with_bias=True,
                              name="final_projection_inner")(embedding)
        if not self._mode.startswith("mlp_nobn"):
          embedding = self._batch_norm(embedding, is_training=is_training)
        embedding = jax.nn.relu(embedding)
      # Final projection; the bias is dropped when a normalization follows.
      embedding = hk.Linear(self._embedding_dim, name="final_projection",
                            with_bias=not self._use_bn_out)(embedding)
    else:
      raise NotImplementedError(
          f"Unsupported embedding head mode: {self._mode!r}.")
    if self._use_bn_out:
      embedding = self._batch_norm(embedding, is_training=is_training)
    return embedding
class VisualModule(hk.Module):
  """The visual module selects which CNN backbone to connect to the graph."""

  def __init__(self,
               use_xreplica_bn=True,
               backbone="s3d",
               model_kwargs=None,
               name="visual_module"):
    """Builds the visual backbone.

    Args:
      use_xreplica_bn: Use cross-replica batch norm instead of batch norm.
      backbone: "s3d" or "resnet50tsm".
      model_kwargs: Optional dict. "bn_config" overrides the normalizer
        kwargs (default `_DEFAULT_CFG_BN`). For "resnet50tsm", "n_frames" is
        required and "width_mult" is optional (default 1).
      name: Module name.

    Raises:
      NotImplementedError: If `backbone` is not supported.
      KeyError: If `backbone` is "resnet50tsm" and "n_frames" is missing.
    """
    super(VisualModule, self).__init__(name=name)
    self._backbone = backbone
    if model_kwargs is None:
      model_kwargs = {}
    bn_config = model_kwargs.get("bn_config", _DEFAULT_CFG_BN)
    if use_xreplica_bn:
      normalizer_name = "cross_replica_batch_norm"
    else:
      normalizer_name = "batch_norm"
    normalize_fn = normalization.get_normalize_fn(
        normalizer_name=normalizer_name,
        normalizer_kwargs=bn_config)
    if backbone == "s3d":
      self._cnn = s3d.S3D(normalize_fn=normalize_fn)
    elif backbone == "resnet50tsm":
      width_mult = model_kwargs.get("width_mult", 1)
      self._cnn = tsm_resnet.TSMResNetV2(
          normalize_fn=normalize_fn,
          depth=50,
          num_frames=model_kwargs["n_frames"],
          width_mult=width_mult)
    else:
      raise NotImplementedError(
          f"Unsupported visual backbone {backbone!r}; expected 's3d' or "
          "'resnet50tsm'.")

  def __call__(self, images, is_training):
    """Connects graph to images and returns the backbone features."""
    features = self._cnn(images, is_training=is_training)
    return features
class AudioModule(hk.Module):
  """The audio module selects which CNN backbone to connect to the graph.

  The backbone is a `resnet.ResNetV2` without classification head whose depth
  is taken from the `backbone` name.
  """

  def __init__(self,
               backbone="resnet18",
               use_xreplica_bn=True,
               model_kwargs=None,
               name="audio_module"):
    """Builds the audio backbone.

    Args:
      backbone: One of "resnet18", "resnet34", "resnet50", "resnet101".
      use_xreplica_bn: Use cross-replica batch norm instead of batch norm.
      model_kwargs: Optional dict; "bn_config" overrides the normalizer
        kwargs (default `_DEFAULT_CFG_BN`).
      name: Module name.

    Raises:
      AssertionError: If `backbone` is not a supported ResNet name.
    """
    super(AudioModule, self).__init__(name=name)
    kwargs = model_kwargs if model_kwargs else {}
    bn_config = kwargs.get("bn_config", _DEFAULT_CFG_BN)
    depth_by_name = {f"resnet{depth}": depth for depth in (18, 34, 50, 101)}
    assert backbone in depth_by_name, (
        f"backbone should be in {depth_by_name.keys()}")
    normalizer_name = ("cross_replica_batch_norm" if use_xreplica_bn
                       else "batch_norm")
    normalize_fn = normalization.get_normalize_fn(
        normalizer_name=normalizer_name,
        normalizer_kwargs=bn_config)
    self._cnn = resnet.ResNetV2(
        depth=depth_by_name[backbone],
        normalize_fn=normalize_fn,
        num_classes=None)

  def __call__(self,
               audio_spectrogram,
               is_training,
               return_intermediate=False):
    """Connects graph to audio spectrogram.

    Returns the backbone's "last_conv" endpoint when `return_intermediate` is
    True and its "output" endpoint otherwise.
    """
    endpoint = "last_conv" if return_intermediate else "output"
    return self._cnn(audio_spectrogram,
                     is_training=is_training,
                     final_endpoint=endpoint)
class TextModule(hk.Module):
  """Text module computes the sentences representation.

  Word ids are embedded (gradients stopped), passed through a width-1
  `hk.Conv1D` and a ReLU, then max-pooled over the word axis.
  """

  def __init__(self,
               word_embedding_matrix,
               sentence_dim=1024,
               name="text_module"):
    """Initialize text module.

    Args:
      word_embedding_matrix: 2d matrix [vocab_size, embed_size] to embed words.
      sentence_dim: dimension of sentence representation.
      name: module name.
    """
    super(TextModule, self).__init__(name=name)
    self._word_embedding_module = hk.Embed(
        embedding_matrix=word_embedding_matrix)
    self._conv1d_module = hk.Conv1D(sentence_dim, 1, name="text_conv1")

  def __call__(self, word_ids, is_training):
    """Maps word ids (presumably [B, N]) to one sentence vector per row."""
    del is_training  # Unused.
    # No gradient flows back into the word embeddings.
    embedded = jax.lax.stop_gradient(self._word_embedding_module(word_ids))
    per_word = jax.nn.relu(self._conv1d_module(embedded))
    # Max-pool over the word axis.
    return jnp.amax(per_word, axis=1)
================================================
FILE: mmv/models/normalization.py
================================================
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Normalize functions constructors."""
from typing import Any, Dict, Optional, Sequence, Union
import haiku as hk
from jax import numpy as jnp
from mmv.models import types
class _BatchNorm(hk.BatchNorm):
  """A `hk.BatchNorm` with adapted default arguments.

  `test_local_stats` is fixed at construction and forwarded on every call.
  A non-None `cross_replica_axis` is rejected.
  """

  def __init__(self,
               create_scale: bool = True,
               create_offset: bool = True,
               decay_rate: float = 0.9,
               eps: float = 1e-5,
               test_local_stats: bool = False,
               **kwargs):
    """Builds the layer; extra kwargs go to `hk.BatchNorm`.

    Raises:
      ValueError: If `cross_replica_axis` is given and is not None.
    """
    if kwargs.get('cross_replica_axis') is not None:
      raise ValueError(
          'Attempting to use \'batch_norm\' normalizer, but specifying '
          '`cross_replica_axis`. If you want this behavior use '
          '`normalizer=\'cross_replica_batch_norm\'` directly.')
    self._test_local_stats = test_local_stats
    super().__init__(
        create_scale=create_scale,
        create_offset=create_offset,
        decay_rate=decay_rate,
        eps=eps,
        **kwargs)

  def __call__(self,
               x: types.TensorLike,
               is_training: bool) -> jnp.ndarray:
    local_stats = self._test_local_stats
    return super().__call__(x, is_training, test_local_stats=local_stats)
class _CrossReplicaBatchNorm(hk.BatchNorm):
  """A `hk.BatchNorm` with adapted default arguments for cross replica.

  `cross_replica_axis` defaults to the axis name 'i' (presumably the pmap
  axis name used by the training code — confirm); an explicit None is
  rejected.
  """

  def __init__(self,
               create_scale: bool = True,
               create_offset: bool = True,
               decay_rate: float = 0.9,
               eps: float = 1e-5,
               test_local_stats: bool = False,
               **kwargs):
    """Builds the layer; extra kwargs go to `hk.BatchNorm`.

    Raises:
      ValueError: If `cross_replica_axis` is explicitly set to None.
    """
    # An absent key falls back to 'i', which is not None, so only an explicit
    # None triggers the error.
    if kwargs.get('cross_replica_axis', 'i') is None:
      raise ValueError(
          'Attempting to use \'cross_replica_batch_norm\' normalizer, but '
          'specifying `cross_replica_axis` to be None. If you want this '
          'behavior use `normalizer=\'batch_norm\'` directly.')
    self._test_local_stats = test_local_stats
    kwargs.setdefault('cross_replica_axis', 'i')
    super().__init__(
        create_scale=create_scale,
        create_offset=create_offset,
        decay_rate=decay_rate,
        eps=eps,
        **kwargs)

  def __call__(self,
               x: types.TensorLike,
               is_training: bool) -> jnp.ndarray:
    local_stats = self._test_local_stats
    return super().__call__(x, is_training, test_local_stats=local_stats)
class _LayerNorm(hk.LayerNorm):
  """A `hk.LayerNorm` accepting (and discarding) an `is_training` argument.

  This matches the call signature of the batch-norm normalizers so it can be
  selected interchangeably via `get_normalize_fn`.
  """

  def __init__(self,
               axis: Union[int, Sequence[int]] = (1, 2),
               create_scale: bool = True,
               create_offset: bool = True,
               **kwargs):
    """Forwards all arguments to `hk.LayerNorm` (default `axis=(1, 2)`)."""
    super().__init__(
        axis=axis,
        create_scale=create_scale,
        create_offset=create_offset,
        **kwargs)

  def __call__(self,  # pytype: disable=signature-mismatch  # overriding-parameter-count-checks
               x: types.TensorLike,
               is_training: bool) -> jnp.ndarray:
    del is_training  # Unused: no train/eval distinction here.
    normalized = super().__call__(x)
    return normalized
# Maps each `normalizer_name` accepted by `get_normalize_fn` to its class.
_NORMALIZER_NAME_TO_CLASS = dict(
    batch_norm=_BatchNorm,
    cross_replica_batch_norm=_CrossReplicaBatchNorm,
    layer_norm=_LayerNorm,
)
def get_normalize_fn(
    normalizer_name: str = 'batch_norm',
    normalizer_kwargs: Optional[Dict[str, Any]] = None,
) -> types.NormalizeFn:
  """Handles NormalizeFn creation.

  These functions are expected to be used as part of Haiku model. Every call
  of the returned function constructs a fresh normalizer from
  `normalizer_kwargs` and applies it, so each application adds a new Haiku
  layer to the model.

  Args:
    normalizer_name: The name of the normalizer to be constructed.
    normalizer_kwargs: The kwargs passed to the normalizer constructor.

  Returns:
    A `types.NormalizeFn` that when applied will create a new layer.

  Raises:
    ValueError: If `normalizer_name` is unknown.
  """
  if normalizer_name not in _NORMALIZER_NAME_TO_CLASS:
    raise ValueError(f'Unrecognized `normalizer_name` {normalizer_name}.')
  layer_cls = _NORMALIZER_NAME_TO_CLASS[normalizer_name]
  ctor_kwargs = normalizer_kwargs if normalizer_kwargs else {}

  def normalize_fn(*args, **call_kwargs):
    # A new layer instance is created on every application.
    return layer_cls(**ctor_kwargs)(*args, **call_kwargs)

  return normalize_fn
================================================
FILE: mmv/models/resnet.py
================================================
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ResNet V2 modules.
Equivalent to hk.Resnet except accepting a final_endpoint to return
intermediate activations.
"""
from typing import Optional, Sequence, Text, Type, Union
import haiku as hk
import jax
import jax.numpy as jnp
from mmv.models import types
class BottleneckBlock(hk.Module):
  """Implements a bottleneck residual block (ResNet50 and ResNet101).

  Pre-activation layout: each of the three bias-free convolutions
  (1x1 -> 3x3 strided -> 1x1, with `channels // 4` inner channels) is preceded
  by the optional normalization and a ReLU. With `use_projection`, the
  shortcut is a strided 1x1 convolution of the first pre-activated input.
  """

  # pylint:disable=g-bare-generic
  def __init__(self,
               channels: int,
               stride: Union[int, Sequence[int]],
               use_projection: bool,
               normalize_fn: Optional[types.NormalizeFn] = None,
               name: Optional[Text] = None):
    super(BottleneckBlock, self).__init__(name=name)
    self._channels = channels
    self._stride = stride
    self._use_projection = use_projection
    self._normalize_fn = normalize_fn

    def make_conv(out_channels, kernel, conv_stride, conv_name):
      # Bias-free, SAME-padded 2D convolution.
      return hk.Conv2D(
          output_channels=out_channels,
          kernel_shape=kernel,
          stride=conv_stride,
          with_bias=False,
          padding='SAME',
          name=conv_name)

    inner_channels = channels // 4
    if self._use_projection:
      self._proj_conv = make_conv(channels, 1, stride, 'shortcut_conv')
    self._conv_0 = make_conv(inner_channels, 1, 1, 'conv_0')
    self._conv_1 = make_conv(inner_channels, 3, stride, 'conv_1')
    self._conv_2 = make_conv(channels, 1, 1, 'conv_2')

  def __call__(self,
               inputs,
               is_training):
    out = inputs
    shortcut = inputs
    layers = (self._conv_0, self._conv_1, self._conv_2)
    for position, conv_layer in enumerate(layers):
      # Pre-activation: optional normalization, then ReLU, before each conv.
      if self._normalize_fn is not None:
        out = self._normalize_fn(out, is_training=is_training)
      out = jax.nn.relu(out)
      if position == 0 and self._use_projection:
        shortcut = self._proj_conv(out)
      out = conv_layer(out)
    return out + shortcut
class BasicBlock(hk.Module):
  """Implements a basic residual block (ResNet18 and ResNet34).

  Pre-activation layout: each of the two bias-free convolutions (1x1, then a
  strided 3x3, both with `channels` outputs) is preceded by the optional
  normalization and a ReLU. With `use_projection`, the shortcut is a strided
  1x1 convolution of the first pre-activated input.

  NOTE(review): `conv_0` is 1x1, unlike the two 3x3 convs of the textbook
  basic block; changing it would alter parameter shapes (and presumably break
  existing checkpoints) — confirm intent before touching.
  """

  # pylint:disable=g-bare-generic
  def __init__(self,
               channels: int,
               stride: Union[int, Sequence[int]],
               use_projection: bool,
               normalize_fn: Optional[types.NormalizeFn] = None,
               name: Optional[Text] = None):
    super(BasicBlock, self).__init__(name=name)
    self._channels = channels
    self._stride = stride
    self._use_projection = use_projection
    self._normalize_fn = normalize_fn

    def make_conv(kernel, conv_stride, conv_name):
      # Bias-free, SAME-padded 2D convolution with `channels` outputs.
      return hk.Conv2D(
          output_channels=channels,
          kernel_shape=kernel,
          stride=conv_stride,
          with_bias=False,
          padding='SAME',
          name=conv_name)

    if self._use_projection:
      self._proj_conv = make_conv(1, stride, 'shortcut_conv')
    self._conv_0 = make_conv(1, 1, 'conv_0')
    self._conv_1 = make_conv(3, stride, 'conv_1')

  def __call__(self,
               inputs,
               is_training):
    out = inputs
    shortcut = inputs
    layers = (self._conv_0, self._conv_1)
    for position, conv_layer in enumerate(layers):
      # Pre-activation: optional normalization, then ReLU, before each conv.
      if self._normalize_fn is not None:
        out = self._normalize_fn(out, is_training=is_training)
      out = jax.nn.relu(out)
      if position == 0 and self._use_projection:
        shortcut = self._proj_conv(out)
      out = conv_layer(out)
    return out + shortcut
class ResNetUnit(hk.Module):
  """Unit (group of blocks) for ResNet.

  Stacks `num_blocks` instances of `block_module`. Only the first block uses
  `stride`, and it gets a projection shortcut when the channel count changes.
  """

  # pylint:disable=g-bare-generic
  def __init__(self,
               channels: int,
               num_blocks: int,
               stride: Union[int, Sequence[int]],
               block_module: Type[BottleneckBlock],
               normalize_fn: Optional[types.NormalizeFn] = None,
               name: Optional[Text] = None,
               remat: bool = False):
    super(ResNetUnit, self).__init__(name=name)
    self._channels = channels
    self._num_blocks = num_blocks
    self._stride = stride
    self._normalize_fn = normalize_fn
    self._block_module = block_module
    self._remat = remat

  def __call__(self,
               inputs,
               is_training):
    in_channels = inputs.shape[-1]
    # Blocks are (re)built on every call and named 'block_<index>'.
    self._blocks = [
        self._block_module(
            channels=self._channels,
            stride=self._stride if idx == 0 else 1,
            use_projection=(idx == 0 and self._channels != in_channels),
            normalize_fn=self._normalize_fn,
            name='block_%d' % idx)
        for idx in range(self._num_blocks)
    ]
    out = inputs
    for block in self._blocks:
      if not self._remat:
        out = block(out, is_training=is_training)
        continue
      # The lambda is invoked within this same iteration, so capturing the
      # loop variable `block` is safe; wrapping is required by hk.remat.
      out = hk.remat(lambda x: block(x, is_training=is_training))(out)  # pylint: disable=cell-var-from-loop
    return out
class ResNetV2(hk.Module):
  """ResNetV2 model."""

  # Endpoints of the model in order.
  VALID_ENDPOINTS = (
      'resnet_stem',
      'resnet_unit_0',
      'resnet_unit_1',
      'resnet_unit_2',
      'resnet_unit_3',
      'last_conv',
      'output',
  )

  # pylint:disable=g-bare-generic
  def __init__(self,
               depth=50,
               num_classes: Optional[int] = 1000,
               width_mult: int = 1,
               normalize_fn: Optional[types.NormalizeFn] = None,
               name: Optional[Text] = None,
               remat: bool = False):
    """Creates ResNetV2 Haiku module.

    Args:
      depth: depth of the desired ResNet (18, 34, 50, 101, 152 or 200).
        Depths >= 50 use `BottleneckBlock`, shallower ones `BasicBlock`.
      num_classes: (int) Number of outputs in final layer. If None will not add
        a classification head and will return the output embedding.
      width_mult: multiplier for channel width.
      normalize_fn: normalization function, see helpers/utils.py
      name: Name of the module.
      remat: Whether to rematerialize intermediate activations (saves memory).

    Raises:
      ValueError: If `depth` is not one of the supported depths.
    """
    super(ResNetV2, self).__init__(name=name)
    self._normalize_fn = normalize_fn
    self._num_classes = num_classes
    self._width_mult = width_mult

    self._strides = [1, 2, 2, 2]
    num_blocks = {
        18: [2, 2, 2, 2],
        34: [3, 4, 6, 3],
        50: [3, 4, 6, 3],
        101: [3, 4, 23, 3],
        152: [3, 8, 36, 3],
        200: [3, 24, 36, 3],
    }
    if depth not in num_blocks:
      raise ValueError(
          f'`depth` should be in {list(num_blocks.keys())} ({depth} given).')
    self._num_blocks = num_blocks[depth]

    if depth >= 50:
      self._block_module = BottleneckBlock
      self._channels = [256, 512, 1024, 2048]
    else:
      self._block_module = BasicBlock
      self._channels = [64, 128, 256, 512]

    self._initial_conv = hk.Conv2D(
        output_channels=64 * self._width_mult,
        kernel_shape=7,
        stride=2,
        with_bias=False,
        padding='SAME',
        name='initial_conv')

    if remat:
      self._initial_conv = hk.remat(self._initial_conv)

    self._block_groups = []
    for i in range(4):
      self._block_groups.append(
          ResNetUnit(
              channels=self._channels[i] * self._width_mult,
              num_blocks=self._num_blocks[i],
              block_module=self._block_module,
              stride=self._strides[i],
              normalize_fn=self._normalize_fn,
              name='block_group_%d' % i,
              remat=remat))

    if num_classes is not None:
      self._logits_layer = hk.Linear(
          output_size=num_classes, w_init=jnp.zeros, name='logits')

  def __call__(self, inputs, is_training, final_endpoint='output'):
    """Runs the network up to `final_endpoint`.

    Args:
      inputs: A 4-D `[B, H, W, C]` array (the `(1, 3, 3, 1)` max-pool window
        and the final mean over axes 1 and 2 assume this layout).
      is_training: Whether to use training mode (forwarded to normalization).
      final_endpoint: One of `VALID_ENDPOINTS`. 'last_conv' is the output of
        the last unit before the final normalization and ReLU; 'output' is the
        logits, or the spatially pooled embedding when `num_classes` is None.

    Returns:
      The network output at `final_endpoint`.

    Raises:
      ValueError: If `final_endpoint` is not in `VALID_ENDPOINTS`.
    """
    # Validate before doing any work; previously an unknown endpoint ran the
    # whole network and was only caught by an `assert` (skipped under -O).
    if final_endpoint not in self.VALID_ENDPOINTS:
      raise ValueError(f'Provided final_endpoint: {final_endpoint} not valid.'
                       f' It must be one of: {self.VALID_ENDPOINTS}')
    self._final_endpoint = final_endpoint

    net = self._initial_conv(inputs)
    net = hk.max_pool(
        net, window_shape=(1, 3, 3, 1),
        strides=(1, 2, 2, 1),
        padding='SAME')
    end_point = 'resnet_stem'
    if self._final_endpoint == end_point:
      return net

    for i_group, block_group in enumerate(self._block_groups):
      net = block_group(net, is_training=is_training)
      end_point = f'resnet_unit_{i_group}'
      if self._final_endpoint == end_point:
        return net

    end_point = 'last_conv'
    if self._final_endpoint == end_point:
      return net

    if self._normalize_fn is not None:
      net = self._normalize_fn(net, is_training=is_training)
    net = jax.nn.relu(net)

    # The actual representation: average over the two spatial axes.
    net = jnp.mean(net, axis=(1, 2))

    assert self._final_endpoint == 'output'
    if self._num_classes is None:
      # If num_classes was None, we just return the output
      # of the last block, without fully connected layer.
      return net

    return self._logits_layer(net)
================================================
FILE: mmv/models/s3d.py
================================================
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Haiku S3D model."""
import collections
from typing import Optional, Sequence
import haiku as hk
import jax
from jax import numpy as jnp
from mmv.models import types
class _MaxPool(hk.MaxPool):
  """Max pooling that tolerates (and ignores) an `is_training` keyword.

  `S3D.__call__` invokes every layer as `layer(x, is_training=...)`; this
  subclass lets a plain `hk.MaxPool` take part in that uniform call.
  """

  def __call__(self,
               x: types.TensorLike,
               is_training: bool = True) -> jnp.ndarray:
    del is_training  # Pooling behaves the same in training and evaluation.
    return hk.MaxPool.__call__(self, x)
def self_gating(inputs: types.TensorLike) -> jnp.ndarray:
  """Feature gating as used in S3D-G.

  Averages the input over its time and space axes, maps the pooled features
  through a learned linear layer followed by a sigmoid, and rescales every
  channel of the input by the resulting per-example gate. More details can be
  found at: https://arxiv.org/abs/1712.04851.

  Args:
    inputs: A 5-D float array of shape `[B, T, H, W, C]`.

  Returns:
    An array with the same shape as `inputs`.

  Raises:
    ValueError: If `inputs` is not 5-D.
  """
  if inputs.ndim != 5:
    raise ValueError(
        f'Expected an input of shape `[B, T, H, W, C]` but got {inputs.shape}.')

  # [B, C] summary of every channel across time and space.
  pooled = jnp.mean(inputs, axis=(1, 2, 3))
  gate_logits = hk.Linear(inputs.shape[4], name='self_gating')(pooled)
  gates = jax.nn.sigmoid(gate_logits)
  # Broadcast the [B, C] gates over the T, H and W axes.
  return jnp.multiply(gates[:, None, None, None, :], inputs)
class SUnit3D(hk.Module):
  """Base 3d Unit combining Conv3d + Batch Norm + non-linearity."""

  def __init__(
      self,
      output_channels: int,
      kernel_shape: Sequence[int] = (1, 1, 1),
      stride: Sequence[int] = (1, 1, 1),
      with_bias: bool = False,
      separable: bool = False,
      normalize_fn: Optional[types.NormalizeFn] = None,
      activation_fn: Optional[types.ActivationFn] = jax.nn.relu,
      self_gating_fn: Optional[types.GatingFn] = None,
      name='SUnit3D'):
    """Initializes the SUnit3D module.

    Args:
      output_channels: Number of output channels of every convolution.
      kernel_shape: `(t, h, w)` kernel size; must have length 3.
      stride: `(t, h, w)` stride; must have length 3.
      with_bias: Whether the convolutions have a bias term.
      separable: If True and `kernel_shape[1] != 1`, factorize the kernel into
        a spatial `[1, h, w]` convolution followed by a temporal `[t, 1, 1]`
        one, splitting the stride the same way. Otherwise one conv is used.
      normalize_fn: Optional normalization applied after every convolution.
      activation_fn: Optional non-linearity applied after every convolution.
      self_gating_fn: Optional gating applied once, to the final output.
      name: The name of the module.

    Raises:
      ValueError: If `kernel_shape` or `stride` does not have length 3.
    """
    super().__init__(name=name)

    for arg_name, arg_value in (('kernel_shape', kernel_shape),
                                ('stride', stride)):
      if len(arg_value) != 3:
        raise ValueError(
            f'Given `{arg_name}` must have length 3 but has length '
            f'{len(arg_value)}.')

    self._normalize_fn = normalize_fn
    self._activation_fn = activation_fn
    self._self_gating_fn = self_gating_fn

    t_k, h_k, w_k = kernel_shape
    if separable and h_k != 1:
      t_s, h_s, w_s = stride
      # Spatial convolution first, then the temporal one.
      conv_specs = [
          ([1, h_k, w_k], [1, h_s, w_s]),
          ([t_k, 1, 1], [t_s, 1, 1]),
      ]
    else:
      conv_specs = [(kernel_shape, stride)]

    self._convolutions = [
        hk.Conv3D(
            output_channels=output_channels,
            kernel_shape=spec_kernel,
            stride=spec_stride,
            padding='SAME',
            with_bias=with_bias)
        for spec_kernel, spec_stride in conv_specs
    ]

  def __call__(
      self,
      inputs: types.TensorLike,
      is_training: bool) -> jnp.ndarray:
    """Connects the module to inputs.

    Args:
      inputs: A 5-D float array of shape `[B, T, H, W, C]`.
      is_training: Whether to use training mode.

    Returns:
      A 5-D float array of shape `[B, new_t, new_h, new_w, output_channels]`.
    """
    out = inputs
    for conv in self._convolutions:
      # Every convolution is followed by its own normalization and activation.
      out = conv(out)
      if self._normalize_fn is not None:
        out = self._normalize_fn(out, is_training=is_training)
      if self._activation_fn is not None:
        out = self._activation_fn(out)

    if self._self_gating_fn:
      out = self._self_gating_fn(out)
    return out  # pytype: disable=bad-return-type  # jax-devicearray
class InceptionBlockV13D(hk.Module):
  """A 3D Inception v1 block.

  This allows use of separable 3D convolutions and self-gating, as described in:

    Rethinking Spatiotemporal Feature Learning For Video Understanding.
    Saining Xie, Chen Sun, Jonathan Huang, Zhuowen Tu and Kevin Murphy.
    https://arxiv.org/abs/1712.04851.
  """

  def __init__(self,
               output_channels: Sequence[int],
               normalize_fn: Optional[types.NormalizeFn],
               temporal_kernel_size: int = 3,
               self_gating_fn: Optional[types.GatingFn] = None,
               name: str = 'InceptionBlockV13D'):
    """Initializes the InceptionBlockV13D module.

    Args:
      output_channels: The size of the output channels of each block, ordered as
        [Branch_0_Conv2d_0a_1x1, Branch_1_Conv2d_0a_1x1, Branch_1_Conv2d_0b_3x3,
        Branch_2_Conv2d_0a_1x1, Branch_2_Conv2d_0b_3x3, Branch_3_Conv2d_0b_1x1].
        Entries 1 and 3 are intermediate bottleneck widths and do not
        contribute to the block's output channels.
      normalize_fn: Function used for normalization.
      temporal_kernel_size: The size of the temporal convolutional filters in
        the conv3d_spatiotemporal blocks.
      self_gating_fn: Function which optionally performs self-gating. If `None`,
        no self-gating is applied.
      name: The name of the module.

    Raises:
      ValueError: If `output_channels` does not have length 6.
    """
    super().__init__(name=name)
    # Check args.
    if len(output_channels) != 6:
      raise ValueError(
          'Given `output_channels` must have length 6 but has length '
          f'{len(output_channels)}.')

    self._output_channels = output_channels
    self._normalize_fn = normalize_fn
    self._temporal_kernel_size = temporal_kernel_size
    # Use the identity when no gating is requested, so the gated branches
    # below can pass the function along unconditionally.
    if self_gating_fn is None:
      self._self_gating_fn = lambda x: x
    else:
      self._self_gating_fn = self_gating_fn

  def __call__(
      self,
      inputs: types.TensorLike,
      is_training: bool) -> jnp.ndarray:
    """Connects the module to inputs.

    Args:
      inputs: A 5-D float array of shape `[B, T, H, W, C]`.
      is_training: Whether to use training mode.

    Returns:
      A 5-D float array of shape `[B, new_t, new_h, new_w, out_c]` where
      `out_c = output_channels[0] + output_channels[2] + output_channels[4] +
      output_channels[5]` (the four branches concatenated on the last axis).
    """
    # Branch 0: 1x1x1 convolution.
    branch_0 = SUnit3D(
        output_channels=self._output_channels[0],
        kernel_shape=(1, 1, 1),
        separable=False,
        normalize_fn=self._normalize_fn,
        self_gating_fn=self._self_gating_fn,
        name='Branch_0_Conv2d_0a_1x1')(
            inputs, is_training=is_training)

    # Branch 1: 1x1x1 bottleneck, then separable (t, 3, 3) convolution.
    branch_1 = SUnit3D(
        output_channels=self._output_channels[1],
        kernel_shape=(1, 1, 1),
        separable=False,
        normalize_fn=self._normalize_fn,
        self_gating_fn=None,
        name='Branch_1_Conv2d_0a_1x1')(
            inputs, is_training=is_training)
    branch_1 = SUnit3D(
        output_channels=self._output_channels[2],
        kernel_shape=(self._temporal_kernel_size, 3, 3),
        separable=True,
        normalize_fn=self._normalize_fn,
        self_gating_fn=self._self_gating_fn,
        name='Branch_1_Conv2d_0b_3x3')(
            branch_1, is_training=is_training)

    # Branch 2: 1x1x1 bottleneck, then separable (t, 3, 3) convolution.
    branch_2 = SUnit3D(
        output_channels=self._output_channels[3],
        kernel_shape=(1, 1, 1),
        separable=False,
        normalize_fn=self._normalize_fn,
        self_gating_fn=None,
        name='Branch_2_Conv2d_0a_1x1')(
            inputs, is_training=is_training)
    branch_2 = SUnit3D(
        output_channels=self._output_channels[4],
        kernel_shape=(self._temporal_kernel_size, 3, 3),
        separable=True,
        normalize_fn=self._normalize_fn,
        self_gating_fn=self._self_gating_fn,
        name='Branch_2_Conv2d_0b_3x3')(
            branch_2, is_training=is_training)

    # Branch 3: stride-1 3x3x3 max pooling, then 1x1x1 convolution.
    branch_3 = hk.MaxPool(
        window_shape=(1, 3, 3, 3, 1),
        strides=(1, 1, 1, 1, 1),
        padding='SAME',
        name='Branch_3_MaxPool_0a_3x3')(
            inputs)
    branch_3 = SUnit3D(
        output_channels=self._output_channels[5],
        kernel_shape=(1, 1, 1),
        separable=False,
        normalize_fn=self._normalize_fn,
        self_gating_fn=self._self_gating_fn,
        name='Branch_3_Conv2d_0b_1x1')(
            branch_3, is_training=is_training)

    return jnp.concatenate((branch_0, branch_1, branch_2, branch_3), axis=4)
_Layer = collections.namedtuple('_Layer', ('name', 'module', 'kwargs'))
class S3D(hk.Module):
  """S3D architecture.

  Any intermediary representation can be obtained by choosing one of the valid
  `final_endpoint`s. The final value returned by this model (when 'Embeddings'
  is used as `final_endpoint`) is a single 1-D representation for each video in
  the batch. Another layer can be externally added on top of that to obtain
  logits.
  """

  # Endpoints of the model in order.
  VALID_ENDPOINTS = (
      'Conv2d_1a_7x7',
      'MaxPool_2a_3x3',
      'Conv2d_2b_1x1',
      'Conv2d_2c_3x3',
      'MaxPool_3a_3x3',
      'Mixed_3b',
      'Mixed_3c',
      'MaxPool_4a_3x3',
      'Mixed_4b',
      'Mixed_4c',
      'Mixed_4d',
      'Mixed_4e',
      'Mixed_4f',
      'MaxPool_5a_2x2',
      'Mixed_5b',
      'Mixed_5c',
      'Embeddings',
  )

  def __init__(self,
               normalize_fn: Optional[types.NormalizeFn] = None,
               first_temporal_kernel_size: int = 7,
               temporal_conv_startat: Optional[str] = 'Conv2d_2c_3x3',
               gating_startat: Optional[str] = 'Conv2d_2c_3x3',
               name='S3D'):
    """Initializes the S3D module.

    Args:
      normalize_fn: Function used for normalization.
      first_temporal_kernel_size: Specifies the temporal kernel size for the
        first conv3d filter. A larger value slows down the model but provides
        little accuracy improvement. Must be set to one of 1, 3, 5 or 7. Only
        used when the input has 3 channels (no space-to-depth).
      temporal_conv_startat: Name of the first block from which the
        separable convolutions use a temporal kernel size of 3; blocks before
        it use a temporal kernel size of 1 (i.e. purely spatial `[1, k, k]`
        filters). This is used to construct the inverted pyramid models.
        'Conv2d_2c_3x3' is the first block this setting affects. If `None`, or
        a name that is not a layer (e.g. 'Embeddings'), every block keeps a
        temporal kernel size of 1.
      gating_startat: Name of the first block to use self-gating; that block
        and every later conv / inception block are gated. If `None`, or a
        name that is not a layer (e.g. 'Embeddings'), no block is gated.
      name: The name of the module.

    Raises:
      ValueError: If `temporal_conv_startat`, `gating_startat` or
        `first_temporal_kernel_size` is not recognized.
    """
    super().__init__(name=name)
    self._first_temporal_kernel_size = first_temporal_kernel_size
    self._temporal_conv_startat = temporal_conv_startat
    self._gating_startat = gating_startat
    self._normalize_fn = normalize_fn

    if (temporal_conv_startat not in self.VALID_ENDPOINTS
        and temporal_conv_startat is not None):
      raise ValueError(
          f'Provided `temporal_conv_startat`: {temporal_conv_startat} not '
          f'valid. It must be one of: {self.VALID_ENDPOINTS}, or `None`.')

    if (gating_startat not in self.VALID_ENDPOINTS
        and gating_startat is not None):
      raise ValueError(
          f'Provided `gating_startat`: {gating_startat} not valid. '
          f'It must be one of: {self.VALID_ENDPOINTS}, or `None`.')

    if first_temporal_kernel_size not in [1, 3, 5, 7]:
      raise ValueError('`first_temporal_kernel_size` can only be 1, 3, 5 or 7.')

  def __call__(self,
               inputs: types.TensorLike,
               is_training: bool,
               final_endpoint: str = 'Embeddings') -> jnp.ndarray:
    """Connects the model to inputs.

    Args:
      inputs: A 5-D float array of shape `[B, T, H, W, C]`. If `C != 3` the
        input is treated as already space-to-depth transformed.
      is_training: Whether to use training mode.
      final_endpoint: Up to which endpoint to run / return.

    Returns:
      Network output at location `final_endpoint`: a 5-D array for every
      endpoint except 'Embeddings', which returns a `[B, C]` array obtained by
      averaging the last block's output over axes 1, 2 and 3.

    Raises:
      ValueError: If `final_endpoint` is not recognized.
    """
    if final_endpoint not in self.VALID_ENDPOINTS:
      raise ValueError(f'Provided final_endpoint: {final_endpoint} not valid.'
                       f' It must be one of: {self.VALID_ENDPOINTS}')

    x = inputs

    # We define layers with tuples (name, module, kwargs)
    # Not all kwargs are present, as we will need to fill in certain properties
    # as we move down the network.
    layers = []

    # The first layer is conditional on the input data shape: the channel size
    # is used to identify whether the `space_to_depth` transformation has been
    # applied to the input. This is used to speed up computation on TPUs.
    if x.shape[-1] == 3:
      layers.append(
          _Layer('Conv2d_1a_7x7', SUnit3D,
                 dict(output_channels=64, stride=(2, 2, 2), separable=False,
                      kernel_shape=(self._first_temporal_kernel_size, 7, 7),
                      normalize_fn=self._normalize_fn)))
    else:
      layers.append(
          _Layer('Conv2d_1a_7x7', SUnit3D,
                 dict(output_channels=64, kernel_shape=(2, 4, 4),
                      stride=(1, 1, 1), separable=False,
                      normalize_fn=self._normalize_fn)))

    layers.extend([
        _Layer('MaxPool_2a_3x3', _MaxPool,
               dict(window_shape=(1, 1, 3, 3, 1), strides=(1, 1, 2, 2, 1),
                    padding='SAME')),
        _Layer('Conv2d_2b_1x1', SUnit3D,
               dict(output_channels=64, kernel_shape=(1, 1, 1),
                    normalize_fn=self._normalize_fn)),
        _Layer('Conv2d_2c_3x3', SUnit3D,
               dict(output_channels=192, separable=True,
                    normalize_fn=self._normalize_fn)),
        _Layer('MaxPool_3a_3x3', _MaxPool,
               dict(window_shape=(1, 1, 3, 3, 1), strides=(1, 1, 2, 2, 1),
                    padding='SAME')),
        _Layer('Mixed_3b', InceptionBlockV13D,
               dict(output_channels=(64, 96, 128, 16, 32, 32),
                    normalize_fn=self._normalize_fn)),
        _Layer('Mixed_3c', InceptionBlockV13D,
               dict(output_channels=(128, 128, 192, 32, 96, 64),
                    normalize_fn=self._normalize_fn)),
        _Layer('MaxPool_4a_3x3', _MaxPool,
               dict(window_shape=(1, 3, 3, 3, 1), strides=(1, 2, 2, 2, 1),
                    padding='SAME')),
        _Layer('Mixed_4b', InceptionBlockV13D,
               dict(output_channels=(192, 96, 208, 16, 48, 64),
                    normalize_fn=self._normalize_fn)),
        _Layer('Mixed_4c', InceptionBlockV13D,
               dict(output_channels=(160, 112, 224, 24, 64, 64),
                    normalize_fn=self._normalize_fn)),
        _Layer('Mixed_4d', InceptionBlockV13D,
               dict(output_channels=(128, 128, 256, 24, 64, 64),
                    normalize_fn=self._normalize_fn)),
        _Layer('Mixed_4e', InceptionBlockV13D,
               dict(output_channels=(112, 144, 288, 32, 64, 64),
                    normalize_fn=self._normalize_fn)),
        _Layer('Mixed_4f', InceptionBlockV13D,
               dict(output_channels=(256, 160, 320, 32, 128, 128),
                    normalize_fn=self._normalize_fn)),
        _Layer('MaxPool_5a_2x2', _MaxPool,
               dict(window_shape=(1, 2, 2, 2, 1), strides=(1, 2, 2, 2, 1),
                    padding='SAME')),
        _Layer('Mixed_5b', InceptionBlockV13D,
               dict(output_channels=(256, 160, 320, 32, 128, 128),
                    normalize_fn=self._normalize_fn)),
        _Layer('Mixed_5c', InceptionBlockV13D,
               dict(output_channels=(384, 192, 384, 48, 128, 128),
                    normalize_fn=self._normalize_fn)),
    ])

    # These parameters may change throughout the computation: once the
    # configured start layer is reached they stay switched on for every
    # subsequent layer.
    self_gating_fn = None
    temporal_kernel_size = 1

    # Iterate over layers.
    for layer in layers:
      # Update
      if layer.name == self._gating_startat:
        self_gating_fn = self_gating
      if layer.name == self._temporal_conv_startat:
        temporal_kernel_size = 3

      # Work on a copy so the layer spec itself is never mutated.
      kwargs = dict(layer.kwargs)

      if layer.module is SUnit3D:
        kwargs['self_gating_fn'] = self_gating_fn
        if 'kernel_shape' not in kwargs:
          kwargs['kernel_shape'] = (temporal_kernel_size, 3, 3)

      elif layer.module is InceptionBlockV13D:
        kwargs['self_gating_fn'] = self_gating_fn
        kwargs['temporal_kernel_size'] = temporal_kernel_size

      module = layer.module(name=layer.name, **kwargs)
      x = module(x, is_training=is_training)
      if final_endpoint == layer.name:
        return x

    assert final_endpoint == 'Embeddings'
    # Average over time, height and width.
    return jnp.mean(x, axis=(1, 2, 3))
================================================
FILE: mmv/models/s3d_test.py
================================================
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for s3d."""
from absl.testing import absltest
from absl.testing import parameterized
import haiku as hk
import jax
import numpy as np
from mmv.models import normalization
from mmv.models import s3d
class _CallableS3D:
  """Wraps `s3d.S3D` in a Haiku transform and keeps its params and state."""

  def __init__(self, *args, **kwargs):

    def forward(*call_args, **call_kwargs):
      model = s3d.S3D(
          *args, normalize_fn=normalization.get_normalize_fn(), **kwargs)
      return model(*call_args, **call_kwargs)

    self._model = hk.transform_with_state(forward)
    self._rng = jax.random.PRNGKey(42)
    self._params = None
    self._state = None

  def init(self, inputs, **kwargs):
    """Initializes params and state in training mode."""
    self._params, self._state = self._model.init(
        self._rng, inputs, is_training=True, **kwargs)

  def __call__(self, inputs, **kwargs):
    """Runs the model, initializing it lazily on first use."""
    if self._params is None:
      self.init(inputs)
    return self._model.apply(
        self._params, self._state, self._rng, inputs, **kwargs)[0]
class S3DTest(parameterized.TestCase):
  """Output-shape tests for the S3D model."""

  # Testing all layers is quite slow, added in comments for completeness.
  @parameterized.parameters(
      # dict(endpoint='Conv2d_1a_7x7', expected_size=(2, 8, 112, 112, 64)),
      # dict(endpoint='MaxPool_2a_3x3', expected_size=(2, 8, 56, 56, 64)),
      # dict(endpoint='Conv2d_2b_1x1', expected_size=(2, 8, 56, 56, 64)),
      # dict(endpoint='Conv2d_2c_3x3', expected_size=(2, 8, 56, 56, 192)),
      # dict(endpoint='MaxPool_3a_3x3', expected_size=(2, 8, 28, 28, 192)),
      # dict(endpoint='Mixed_3b', expected_size=(2, 8, 28, 28, 256)),
      # dict(endpoint='Mixed_3c', expected_size=(2, 8, 28, 28, 480)),
      # dict(endpoint='MaxPool_4a_3x3', expected_size=(2, 4, 14, 14, 480)),
      # dict(endpoint='Mixed_4b', expected_size=(2, 4, 14, 14, 512)),
      # dict(endpoint='Mixed_4c', expected_size=(2, 4, 14, 14, 512)),
      # dict(endpoint='Mixed_4d', expected_size=(2, 4, 14, 14, 512)),
      # dict(endpoint='Mixed_4e', expected_size=(2, 4, 14, 14, 528)),
      # dict(endpoint='Mixed_4f', expected_size=(2, 4, 14, 14, 832)),
      # dict(endpoint='MaxPool_5a_2x2', expected_size=(2, 2, 7, 7, 832)),
      # dict(endpoint='Mixed_5b', expected_size=(2, 2, 7, 7, 832)),
      # dict(endpoint='Mixed_5c', expected_size=(2, 2, 7, 7, 1024)),
      dict(endpoint='Embeddings', expected_size=(2, 1024)),
  )
  def test_endpoint_expected_output_dimensions(self, endpoint, expected_size):
    inputs = np.random.normal(size=(2, 16, 224, 224, 3))
    model = _CallableS3D()
    output = model(inputs, is_training=False, final_endpoint=endpoint)
    # `assertSameElements` compares the shapes as sets (ignoring order and
    # repeated dimensions), so wrong shapes could pass; compare exactly.
    self.assertEqual(tuple(output.shape), tuple(expected_size))

  def test_space_to_depth(self):
    # T, H and W halved and channels multiplied by 2*2*2, so the model takes
    # the space-to-depth branch (channel count != 3) of the first layer.
    inputs = np.random.normal(size=(2, 16//2, 224//2, 224//2, 3*2*2*2))
    model = _CallableS3D()
    output = model(inputs, is_training=False, final_endpoint='Conv2d_1a_7x7')
    self.assertEqual(tuple(output.shape), (2, 8, 112, 112, 64))
if __name__ == '__main__':
  # Run this module's test cases when the file is executed directly.
  absltest.main()
================================================
FILE: mmv/models/tsm_resnet.py
================================================
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Temporal Shift Module (TSM) with ResNet V2 backbones (depth 50, 101, 152 or 200).
Based on:
TSM: Temporal Shift Module for Efficient Video Understanding
Ji Lin, Chuang Gan, Song Han
https://arxiv.org/pdf/1811.08383.pdf.
"""
from typing import Optional
import haiku as hk
import jax
import jax.numpy as jnp
from mmv.models import tsm_utils as tsmu
from mmv.models import types
class TSMResNetBlock(hk.Module):
  """A ResNet subblock with Temporal Channel Shifting.

  Combines a typical ResNetV2 bottleneck block implementation
  (see https://arxiv.org/abs/1512.03385) with a pre-convolution Temporal
  Shift Module (see https://arxiv.org/pdf/1811.08383.pdf) in the residual.
  """

  def __init__(self,
               output_channels: int,
               stride: int,
               use_projection: bool,
               tsm_mode: str,
               normalize_fn: Optional[types.NormalizeFn] = None,
               channel_shift_fraction: float = 0.125,
               num_frames: int = 8,
               name: str = 'TSMResNetBlock'):
    """Initializes the TSMResNetBlock module.

    Args:
      output_channels: Number of output channels; the bottleneck convolutions
        use a quarter of this.
      stride: Stride of the 3x3 convolution and of the projection shortcut.
      use_projection: Whether the shortcut goes through a 1x1 convolution.
      tsm_mode: Mode for TSM, forwarded to `tsmu.apply_temporal_shift`.
      normalize_fn: Optional normalization applied before each ReLU.
      channel_shift_fraction: The fraction of temporally shifted channels. If
        `channel_shift_fraction` is 0, the block is the same as a normal ResNet
        block.
      num_frames: Size of frame dimension in a single batch example.
      name: The name of the module.
    """
    super().__init__(name=name)
    self._output_channels = output_channels
    self._bottleneck_channels = output_channels // 4
    self._stride = stride
    self._use_projection = use_projection
    self._normalize_fn = normalize_fn
    self._tsm_mode = tsm_mode
    self._channel_shift_fraction = channel_shift_fraction
    self._num_frames = num_frames

  def __call__(self,
               inputs: types.TensorLike,
               is_training: bool = True) -> jnp.ndarray:
    """Connects the ResNetBlock module into the graph.

    Args:
      inputs: A 4-D float array; frames are folded into the batch axis, i.e.
        `[B * num_frames, H, W, C]` as in `TSMResNetUnit`.
      is_training: Whether to use training mode.

    Returns:
      A 4-D float array of shape
      `[B * num_frames, new_h, new_w, output_channels]`.
    """

    def norm_relu(x):
      # ResNet V2 pre-activation: normalization and ReLU come before a conv.
      if self._normalize_fn is not None:
        x = self._normalize_fn(x, is_training=is_training)
      return jax.nn.relu(x)

    def conv(x, channels, kernel, conv_stride, conv_name):
      return hk.Conv2D(
          output_channels=channels,
          kernel_shape=kernel,
          stride=conv_stride,
          with_bias=False,
          padding='SAME',
          name=conv_name)(x)

    preact = norm_relu(inputs)

    if self._use_projection:
      shortcut = conv(preact, self._output_channels, 1, self._stride,
                      'shortcut_conv')
    else:
      shortcut = inputs

    # The temporal shift only touches the residual path, never the shortcut.
    if self._channel_shift_fraction != 0:
      preact = tsmu.apply_temporal_shift(
          preact, tsm_mode=self._tsm_mode, num_frames=self._num_frames,
          channel_shift_fraction=self._channel_shift_fraction)

    # Bottleneck: 1x1 reduce, strided 3x3, 1x1 expand.
    residual = conv(preact, self._bottleneck_channels, 1, 1, 'conv_0')
    residual = conv(norm_relu(residual), self._bottleneck_channels, 3,
                    self._stride, 'conv_1')
    residual = conv(norm_relu(residual), self._output_channels, 1, 1,
                    'conv_2')

    # NOTE: we do not use block multiplier.
    return shortcut + residual
class TSMResNetUnit(hk.Module):
  """Block group for TSM ResNet: a stack of `TSMResNetBlock`s."""

  def __init__(self,
               output_channels: int,
               num_blocks: int,
               stride: int,
               tsm_mode: str,
               num_frames: int,
               normalize_fn: Optional[types.NormalizeFn] = None,
               channel_shift_fraction: float = 0.125,
               name: str = 'tsm_resnet_unit'):
    """Creates a TSMResNet Unit.

    Args:
      output_channels: Number of output channels of every block.
      num_blocks: Number of ResNet blocks in the unit.
      stride: Stride of the first block; the remaining blocks use stride 1.
      tsm_mode: Which temporal shift module to use.
      num_frames: Size of frame dimension in a single batch example.
      normalize_fn: Function used for normalization.
      channel_shift_fraction: The fraction of temporally shifted channels. If
        `channel_shift_fraction` is 0, the block is the same as a normal ResNet
        block.
      name: The name of the module.
    """
    super().__init__(name=name)
    self._output_channels = output_channels
    self._num_blocks = num_blocks
    self._stride = stride
    self._tsm_mode = tsm_mode
    self._num_frames = num_frames
    self._normalize_fn = normalize_fn
    self._channel_shift_fraction = channel_shift_fraction

  def __call__(self,
               inputs: types.TensorLike,
               is_training: bool) -> jnp.ndarray:
    """Connects the module to inputs.

    Args:
      inputs: A 4-D float array of shape `[B * num_frames, H, W, C]`.
      is_training: Whether to use training mode.

    Returns:
      A 4-D float array of shape
      `[B * num_frames, H // stride, W // stride, output_channels]`
      (rounded up, since the convolutions use 'SAME' padding).
    """
    net = inputs
    for idx_block in range(self._num_blocks):
      # Only the first block downsamples and projects its shortcut.
      is_first = idx_block == 0
      block = TSMResNetBlock(
          self._output_channels,
          stride=self._stride if is_first else 1,
          use_projection=is_first,
          normalize_fn=self._normalize_fn,
          tsm_mode=self._tsm_mode,
          channel_shift_fraction=self._channel_shift_fraction,
          num_frames=self._num_frames,
          name=f'block_{idx_block}')
      net = block(net, is_training=is_training)
    return net  # pytype: disable=bad-return-type  # jax-devicearray
class TSMResNetV2(hk.Module):
  """TSM based on ResNet V2 as described in https://arxiv.org/abs/1603.05027."""

  # Endpoints of the model in order.
  VALID_ENDPOINTS = (
      'tsm_resnet_stem',
      'tsm_resnet_unit_0',
      'tsm_resnet_unit_1',
      'tsm_resnet_unit_2',
      'tsm_resnet_unit_3',
      'last_conv',
      'Embeddings',
  )

  def __init__(self,
               normalize_fn: Optional[types.NormalizeFn] = None,
               depth: int = 50,
               num_frames: int = 16,
               channel_shift_fraction: float = 0.125,
               width_mult: int = 1,
               name: str = 'TSMResNetV2'):
    """Constructs a ResNet model.

    Args:
      normalize_fn: Function used for normalization.
      depth: Depth of the desired ResNet (50, 101, 152 or 200).
      num_frames: Number of frames per example; used whenever
        `tsmu.prepare_inputs` does not report a frame count (TPU mode).
      channel_shift_fraction: Fraction of channels that are temporally shifted,
        if `channel_shift_fraction` is 0, a regular ResNet is returned.
      width_mult: Integer multiplier applied to the channel count of every
        layer.
      name: The name of the module.

    Raises:
      ValueError: If `channel_shift_fraction` or `depth` has invalid value.
    """
    super().__init__(name=name)

    if not 0. <= channel_shift_fraction <= 1.0:
      raise ValueError(
          f'channel_shift_fraction ({channel_shift_fraction})'
          ' has to be in [0, 1].')

    self._num_frames = num_frames

    self._channels = (256, 512, 1024, 2048)
    self._strides = (1, 2, 2, 2)

    num_blocks = {
        50: (3, 4, 6, 3),
        101: (3, 4, 23, 3),
        152: (3, 8, 36, 3),
        200: (3, 24, 36, 3),
    }
    if depth not in num_blocks:
      raise ValueError(
          f'`depth` should be in {list(num_blocks.keys())} ({depth} given).')
    self._num_blocks = num_blocks[depth]

    self._width_mult = width_mult
    self._channel_shift_fraction = channel_shift_fraction
    self._normalize_fn = normalize_fn

  def __call__(
      self,
      inputs: types.TensorLike,
      is_training: bool = True,
      final_endpoint: str = 'Embeddings') -> jnp.ndarray:
    """Connects the TSM ResNetV2 module into the graph.

    Args:
      inputs: Either a 5-D `[B, T, H, W, C]` array or a 4-D `[B * T, H, W, C]`
        array with frames folded into the batch (TPU mode, where `T` comes from
        `num_frames`); both layouts are handled by `tsmu.prepare_inputs`.
      is_training: Whether to use training mode.
      final_endpoint: Up to which endpoint to run / return.

    Returns:
      Network output at location `final_endpoint`. A float array which shape
      depends on `final_endpoint`; 'Embeddings' gives one vector per example.

    Raises:
      ValueError: If `final_endpoint` is not recognized.
    """
    # Fail fast, before any input preparation or module creation.
    if final_endpoint not in self.VALID_ENDPOINTS:
      raise ValueError(f'Unknown final endpoint {final_endpoint}')
    self._final_endpoint = final_endpoint

    # Prepare inputs for TSM.
    inputs, tsm_mode, num_frames = tsmu.prepare_inputs(inputs)
    num_frames = num_frames or self._num_frames

    # Stem convolution.
    end_point = 'tsm_resnet_stem'
    net = hk.Conv2D(
        output_channels=64 * self._width_mult,
        kernel_shape=7,
        stride=2,
        with_bias=False,
        name=end_point,
        padding='SAME')(
            inputs)
    net = hk.MaxPool(
        window_shape=(1, 3, 3, 1),
        strides=(1, 2, 2, 1),
        padding='SAME')(
            net)
    if self._final_endpoint == end_point:
      return net

    # Residual block.
    for unit_id, (channels, num_blocks, stride) in enumerate(
        zip(self._channels, self._num_blocks, self._strides)):
      end_point = f'tsm_resnet_unit_{unit_id}'
      net = TSMResNetUnit(
          output_channels=channels * self._width_mult,
          num_blocks=num_blocks,
          stride=stride,
          normalize_fn=self._normalize_fn,
          channel_shift_fraction=self._channel_shift_fraction,
          num_frames=num_frames,
          tsm_mode=tsm_mode,
          name=end_point)(
              net, is_training=is_training)
      if self._final_endpoint == end_point:
        return net

    if self._normalize_fn is not None:
      net = self._normalize_fn(net, is_training=is_training)
    net = jax.nn.relu(net)

    end_point = 'last_conv'
    if self._final_endpoint == end_point:
      return net
    net = jnp.mean(net, axis=(1, 2))
    # Prepare embedding outputs for TSM (temporal average of features).
    net = tsmu.prepare_outputs(net, tsm_mode, num_frames)
    assert self._final_endpoint == 'Embeddings'
    return net
================================================
FILE: mmv/models/tsm_resnet_test.py
================================================
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for TSM ResNet model."""
from absl.testing import absltest
from absl.testing import parameterized
import haiku as hk
import jax
import jax.numpy as jnp
from mmv.models import tsm_resnet
class TSMResNetTest(parameterized.TestCase):
  """Output-shape tests for `tsm_resnet.TSMResNetV2`."""

  @staticmethod
  def _run_model(input_shape, final_endpoint, **model_kwargs):
    """Initialises and applies TSMResNetV2 on an all-zeros input.

    Args:
      input_shape: shape of the zero-valued input fed to the model.
      final_endpoint: endpoint at which the model stops and returns.
      **model_kwargs: forwarded unchanged to the `TSMResNetV2` constructor.

    Returns:
      The model output at `final_endpoint`.
    """

    def forward():
      inputs = jnp.zeros(input_shape)
      model = tsm_resnet.TSMResNetV2(**model_kwargs)
      return model(inputs, final_endpoint=final_endpoint)

    model_fn = hk.transform(forward)
    params = model_fn.init(jax.random.PRNGKey(42))
    return model_fn.apply(params, None)

  # The [B=2, T=32, 224, 224, 3] input is processed as a batch of B * T frames
  # until 'Embeddings', which averages back to one vector per video.
  @parameterized.parameters(
      ('tsm_resnet_stem', (2 * 32, 56, 56, 64)),
      ('tsm_resnet_unit_0', (2 * 32, 56, 56, 256)),
      ('tsm_resnet_unit_1', (2 * 32, 28, 28, 512)),
      ('tsm_resnet_unit_2', (2 * 32, 14, 14, 1024)),
      ('tsm_resnet_unit_3', (2 * 32, 7, 7, 2048)),
      ('last_conv', (2 * 32, 7, 7, 2048)),
      ('Embeddings', (2, 2048)),
  )
  def test_output_dimension(self, final_endpoint, expected_shape):
    result = self._run_model((2, 32, 224, 224, 3), final_endpoint)
    self.assertEqual(result.shape, expected_shape)

  def test_tpu_mode(self):
    # Time-major [T * B, H, W, C] input; the frame count cannot be read from
    # the shape, so it is passed to the constructor.
    result = self._run_model(
        (32 * 2, 224, 224, 3), 'Embeddings', num_frames=32)
    self.assertEqual(result.shape, (2, 2048))


if __name__ == '__main__':
  absltest.main()
================================================
FILE: mmv/models/tsm_utils.py
================================================
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils functions for TSM."""
from typing import Optional, Tuple
import jax
import jax.numpy as jnp
from mmv.models import types
def prepare_inputs(
    inputs: types.TensorLike) -> Tuple[jnp.ndarray, str, Optional[int]]:
  """Deduces the TSM input mode from the rank of `inputs`.

  Rank-5 inputs are taken to be in the standard [B, T, H, W, C] format ('gpu'
  mode). The time axis is folded into the batch, giving [B * T, H, W, C], and
  T is returned as the number of frames.

  Rank-4 inputs are taken to already be in the time-major [T * B, H, W, C]
  format ('tpu' mode). They are returned unchanged. The number of frames
  cannot be deduced from the shape, so None is returned for it.

  Args:
    inputs: video batch of rank 4 or 5.

  Returns:
    A tuple `(inputs, tsm_mode, num_frames)`. `tsm_mode` is 'gpu' or 'tpu'.
    `num_frames` is an int in 'gpu' mode and None in 'tpu' mode.

  Raises:
    ValueError: if `inputs` is neither rank 4 nor rank 5.
  """
  rank = len(inputs.shape)
  if rank == 5:
    # Standard [B, T, H, W, C] input: merge B and T into a single batch axis.
    num_frames = inputs.shape[1]
    inputs = jnp.reshape(inputs, [-1] + list(inputs.shape[2:]))
    return inputs, 'gpu', num_frames
  if rank == 4:
    # Already [T * B, H, W, C]; the caller must know the frame count.
    return inputs, 'tpu', None
  raise ValueError(
      '`inputs` should be rank 4 ([T * B, H, W, C]) or rank 5 '
      f'([B, T, H, W, C]); got shape {tuple(inputs.shape)}')
def prepare_outputs(outputs: types.TensorLike,
                    tsm_mode: str,
                    num_frames: int) -> jnp.ndarray:
  """Averages per-frame TSM features over the time axis.

  Args:
    outputs: per-frame features whose last axis holds the channels. They are
      presumably [T * B, C] in 'tpu' mode (time-major) and [B * T, C] in 'gpu'
      mode (batch-major).
    tsm_mode: either 'tpu' or 'gpu'; selects how frames are grouped.
    num_frames: number of frames T per video.

  Returns:
    The temporal mean of the features ([B, C] for the inputs described above).

  Raises:
    ValueError: if `tsm_mode` is neither 'tpu' nor 'gpu'.
  """
  channels = outputs.shape[-1]
  if tsm_mode == 'tpu':
    # Time-major: group as [T, B, C] and average away the leading axis.
    grouped_shape, frames_axis = [num_frames, -1, channels], 0
  elif tsm_mode == 'gpu':
    # Batch-major: group as [B, T, C] and average away the middle axis.
    grouped_shape, frames_axis = [-1, num_frames, channels], 1
  else:
    raise ValueError(
        f'`tsm_mode` should be \'tpu\' or \'gpu\' ({tsm_mode} given)')
  return jnp.mean(jnp.reshape(outputs, grouped_shape), axis=frames_axis)
def apply_temporal_shift(
    x: types.TensorLike,
    tsm_mode: str,
    num_frames: int,
    channel_shift_fraction: float = 0.125) -> jnp.ndarray:
  """Applies the TSM temporal shift (https://arxiv.org/abs/1811.08383).

  Dispatches to the implementation matching the frame layout of `x`.

  Args:
    x: per-frame features; [T * B, H, W, C] for 'tpu' and [B * T, H, W, C]
      for 'gpu'.
    tsm_mode: either 'tpu' or 'gpu'.
    num_frames: number of frames T per video.
    channel_shift_fraction: fraction of channels shifted in each direction.

  Returns:
    The temporally shifted features.

  Raises:
    ValueError: if `tsm_mode` is neither 'tpu' nor 'gpu'.
  """
  if tsm_mode == 'tpu':
    return temporal_shift_tpu(x, num_frames, channel_shift_fraction)
  if tsm_mode == 'gpu':
    return temporal_shift_gpu(x, num_frames, channel_shift_fraction)
  raise ValueError(
      f'`tsm_mode` should be \'tpu\' or \'gpu\' ({tsm_mode} given)')
def temporal_shift_gpu(
    x: types.TensorLike,
    num_frames: int,
    channel_shift_fraction: float = 0.125) -> jnp.ndarray:
  """Performs a temporal shift: https://arxiv.org/abs/1811.08383.

  The input is batch-major: frames of the same video are contiguous, i.e. x is
  [B * T, H, W, C]. With `n_shift = int(C * channel_shift_fraction)`:
  - the first n_shift output channels hold the last n_shift input channels of
    the next frame;
  - the last n_shift output channels hold the first n_shift input channels of
    the previous frame;
  - the remaining channels pass through unchanged.
  Frames that fall outside a video are zero-filled. Channels are therefore
  reordered, matching the ordering produced by `temporal_shift_tpu`.

  Args:
    x: input of shape [B * T, H, W, C].
    num_frames: number of frames T per video.
    channel_shift_fraction: fraction of the channels shifted in each direction.

  Returns:
    The temporally shifted version of x, with the same shape as x.
  """
  # B, T, H, W, C = batch_size, num_frames, im_height, im_width, channels
  # Input is (B * T, H, W, C)
  orig_shp = tuple(x.shape)
  reshaped_x = jnp.reshape(x, (-1, num_frames) + orig_shp[1:])
  n_channels = orig_shp[-1]
  n_shift = int(n_channels * channel_shift_fraction)
  new_shp = tuple(reshaped_x.shape)
  # shifted_backward = reshaped_x[:, 1:, :, :, -n_shift:]; frame t reads t + 1.
  shifted_backward = jax.lax.slice(
      reshaped_x, (0, 1, 0, 0, new_shp[4] - n_shift),
      (new_shp[0], new_shp[1], new_shp[2], new_shp[3], new_shp[4]))
  # The last frame has no successor: zero-fill it.
  shifted_backward_padding = ((0, 0), (0, 1), (0, 0), (0, 0), (0, 0))
  shifted_backward = jnp.pad(shifted_backward, shifted_backward_padding)
  # shifted_forward = reshaped_x[:, :-1, :, :, :n_shift]; frame t reads t - 1.
  shifted_forward = jax.lax.slice(
      reshaped_x, (0, 0, 0, 0, 0),
      (new_shp[0], new_shp[1] - 1, new_shp[2], new_shp[3], n_shift))
  # The first frame has no predecessor: zero-fill it.
  shifted_forward_padding = ((0, 0), (1, 0), (0, 0), (0, 0), (0, 0))
  shifted_forward = jnp.pad(shifted_forward, shifted_forward_padding)
  # Use an explicit end index: `n_shift:-n_shift` is the empty slice `0:0`
  # when n_shift == 0, which would drop every channel.
  no_shift = reshaped_x[:, :, :, :, n_shift:n_channels - n_shift]
  shifted_x = jnp.concatenate([shifted_backward, no_shift, shifted_forward],
                              axis=4)
  return jnp.reshape(shifted_x, (-1,) + orig_shp[1:])
def temporal_shift_tpu(
    x: types.TensorLike,
    num_frames: int,
    channel_shift_fraction: float = 0.125) -> jnp.ndarray:
  """Performs a temporal shift: https://arxiv.org/abs/1811.08383.

  This is the TPU-optimized version of TSM. No reshape is needed because the
  images are laid out as [T * B, :], so the frames for the same time step of
  different videos are contiguous in memory.

  Thanks to cr/288510308, which allows pad->slice to be fused into the
  convolution, the slice-then-pad formulation is rewritten as pad-then-slice.

  Finally, a concatenate would prevent some fusion from happening, so the
  function instead sums masked versions of the features.

  The computation is carried out in bfloat16. The result is cast to float32
  whatever the input dtype was.

  Args:
    x: Input expected to be [T * B, H, W, C] (where the batch has been reshaped
      from a time major version of the input).
    num_frames: number of frames T per video.
    channel_shift_fraction: fraction of the channels to shift forward and
      backward; `int(C * channel_shift_fraction)` channels move each way.

  Returns:
    The temporally shifted version of x, with the same shape as x. Channels
    are reordered:
    - the first n_shift output channels come from the next time step;
    - the last n_shift output channels come from the previous time step;
    - out-of-range time steps are zero-filled.
  """
  # B, T, H, W, C = batch_size, num_frames, im_height, im_width, channels
  # Input is (T * B, H, W, C)
  original_shape = list(x.shape)
  # Number of videos B. NOTE(review): assumes x.shape[0] is a multiple of
  # num_frames; int() silently truncates otherwise.
  batch_size = int(original_shape[0] / num_frames)
  n_channels = int(original_shape[-1])
  n_shift = int(n_channels * channel_shift_fraction)
  # Cast to bfloat16 (the result is cast back to float32 at the end).
  x = x.astype(jnp.bfloat16)
  # For the following, assume that x has 3 channels [x1, x2, x3] and n_shift=1.
  # Shift backward, we first pad by zeros [x1, x2, x3, 0, 0].
  orig_shp = list(x.shape)
  # lax.pad takes one (low, high, interior) triple per axis: append batch_size
  # zero rows (one extra time step) and n_channels - n_shift zero channels.
  shifted_backward_padding = ((0, batch_size, 0), (0, 0, 0), (0, 0, 0),
                              (0, n_channels - n_shift, 0))
  x_backward_padding = jax.lax.pad(
      x,
      padding_value=jnp.bfloat16(0.),
      padding_config=shifted_backward_padding)
  # The following shift gets to [x3^+1, 0, 0] (where +1 means from the future).
  # Starting at row batch_size makes time step t read time step t + 1 (the last
  # step reads the zero rows). Starting at channel n_channels - n_shift moves
  # the last n_shift channels to the front.
  shifted_backward = jax.lax.slice(x_backward_padding,
                                   (batch_size, 0, 0, n_channels - n_shift),
                                   (orig_shp[0] + batch_size, orig_shp[1],
                                    orig_shp[2], 2 * n_channels - n_shift))
  # Shift forward, we first pad by zeros [0, 0, x1, x2, x3].
  # Here the batch_size zero rows and n_channels - n_shift zero channels are
  # prepended instead of appended.
  shifted_forward_padding = ((batch_size, 0, 0), (0, 0, 0), (0, 0, 0),
                             (n_channels - n_shift, 0, 0))
  x_forward_padding = jax.lax.pad(
      x,
      padding_value=jnp.bfloat16(0.),
      padding_config=shifted_forward_padding)
  # The following shift gets to [0, 0, x1^-1] (where -1 means from the past).
  # Keeping the first T * B rows makes time step t read time step t - 1 (the
  # first step reads the zero rows). Keeping the first n_channels channels
  # leaves the original first n_shift channels at the end.
  shifted_forward = jax.lax.slice(
      x_forward_padding, (0, 0, 0, 0),
      (orig_shp[0], orig_shp[1], orig_shp[2], n_channels))
  # No shift is in the middle, this gets [0, x2, 0].
  # The mask is 1 for channels in [n_shift, n_channels - n_shift), else 0.
  mask_noshift = (jnp.reshape((jnp.arange(n_channels) >= n_shift) &
                              (jnp.arange(n_channels) < n_channels - n_shift),
                              (1, 1, 1, -1))).astype(jnp.bfloat16)
  no_shift = mask_noshift * x
  # By summing everything together, we end up with [x3^+1, x2, x1^-1].
  # Note: channels have been reordered but that doesn't matter for the model.
  shifted_x = shifted_backward + shifted_forward + no_shift
  return shifted_x.astype(jnp.float32)
================================================
FILE: mmv/models/tsm_utils_test.py
================================================
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tsm_utils."""
from absl.testing import absltest
from absl.testing import parameterized
import jax.numpy as jnp
import numpy as np
from mmv.models import tsm_utils
class TsmUtilsTest(parameterized.TestCase):
  """Tests for the input/output/shift helpers in `tsm_utils`."""

  @parameterized.parameters(
      # 5-D [B, T, H, W, C] input: time is folded into the batch ('gpu').
      ((2, 32, 224, 224, 3), 'gpu', (2 * 32, 224, 224, 3), 32),
      # 4-D [T * B, H, W, C] input: returned unchanged ('tpu').
      ((32, 224, 224, 3), 'tpu', (32, 224, 224, 3), None),
  )
  def test_prepare_inputs(self, input_shape, expected_mode, expected_shape,
                          expected_num_frames):
    reshaped, mode, frames = tsm_utils.prepare_inputs(jnp.zeros(input_shape))
    self.assertEqual(reshaped.shape, expected_shape)
    self.assertEqual(mode, expected_mode)
    self.assertEqual(frames, expected_num_frames)

  def test_prepare_outputs(self):
    # Four frames of two channels: rows [0, 0], [0, 0], [1, 1], [1, 1].
    features = jnp.concatenate([jnp.zeros(4), jnp.ones(4)]).reshape(4, 2)
    gpu_mean = tsm_utils.prepare_outputs(features, 'gpu', 2)
    tpu_mean = tsm_utils.prepare_outputs(features, 'tpu', 2)
    # Batch-major grouping averages consecutive rows: 0 then 1.
    np.testing.assert_allclose(
        gpu_mean, np.concatenate([np.zeros(2), np.ones(2)]).reshape(2, 2))
    # Time-major grouping averages rows {0, 2} and {1, 3}: 0.5 everywhere.
    np.testing.assert_allclose(tpu_mean, 0.5 * jnp.ones((2, 2)))

  def test_apply_tsm(self):
    shape = (32, 224, 224, 16)
    zeros = jnp.zeros(shape)
    shifted = {
        mode: tsm_utils.apply_temporal_shift(zeros, mode, 16)
        for mode in ('gpu', 'tpu')
    }
    for result in shifted.values():
      self.assertEqual(result.shape, shape)


if __name__ == '__main__':
  absltest.main()
================================================
FILE: mmv/models/types.py
================================================
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Type Aliases."""
from typing import Callable, Tuple, Union
import jax
import numpy as np
import optax
# Array values accepted by the model code: NumPy arrays or JAX arrays.
TensorLike = Union[np.ndarray, jax.Array]

# Activation, gating and network functions share one array -> array signature.
_UnaryTensorFn = Callable[[TensorLike], TensorLike]
ActivationFn = _UnaryTensorFn
GatingFn = _UnaryTensorFn
NetworkFn = _UnaryTensorFn

# Callable doesn't allow kwargs to be used, and we often want to
# pass in is_training=..., so ignore the arguments for the sake of pytype.
NormalizeFn = Callable[..., TensorLike]

# Optimizer state as a tuple of three optax states.
# NOTE(review): presumably mirrors a chain of trace, scale_by_schedule and
# scale transformations in that order; confirm against the optimizer setup.
OptState = Tuple[optax.TraceState, optax.ScaleByScheduleState,
                 optax.ScaleState]
================================================
FILE: mmv/requirements.txt
================================================
dill
dm-haiku
dm-tree
jax
jaxlib
numpy>=1.16
optax
scikit-learn
tensorflow
tensorflow_datasets
================================================
FILE: mmv/utils/checkpoint.py
================================================
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Checkpoint restoring utilities."""
from absl import logging
import dill
def load_checkpoint(checkpoint_path):
  """Restores a dill-serialized checkpoint from disk.

  NOTE(review): `dill.load` is pickle-based and can execute arbitrary code
  embedded in the file; only load checkpoints from trusted sources.

  Args:
    checkpoint_path: path of the checkpoint file to read.

  Returns:
    The deserialized checkpoint object, or None when no file exists at
    `checkpoint_path`.
  """
  # Only a missing file means "no checkpoint". A FileNotFoundError raised while
  # deserializing must not be mistaken for it, so the try covers open() alone.
  try:
    checkpoint_file = open(checkpoint_path, 'rb')
  except FileNotFoundError:
    logging.info('No checkpoint found at %s', checkpoint_path)
    return None
  with checkpoint_file:
    logging.info('Loading checkpoint from %s', checkpoint_path)
    return dill.load(checkpoint_file)
================================================
FILE: mmv/utils/ucf101_dataset.py
================================================
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Ucf101 with custom decoding params."""
import tensorflow as tf
import tensorflow_datasets as tfds
# Turn on eager execution at import time through the TF1 compatibility API.
# NOTE(review): eager mode is already the default under TF2; this call
# presumably only matters when TF1 behavior is active — confirm.
tf.compat.v1.enable_eager_execution()
_CITATION = """\
@article{DBLP:journals/corr/abs-1212-0402,
author = {Khurram Soomro and
Amir Roshan Zamir and
Mubarak Shah},
title = {{UCF101:} {A} Dataset of 101 Human Actions Classes From Videos in
The Wild},
journal = {CoRR},
volume = {abs/1212.0402},
year = {2012},
url = {http://arxiv.org/abs/1212.0402},
archivePrefix = {arXiv},
eprint = {1212.0402},
timestamp = {Mon, 13 Aug 2018 16:47:45 +0200},
biburl = {https://dblp.org/rec/bib/journals/corr/abs-1212-0402},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
"""
_LABELS_FNAME = 'video/ucf101_labels.txt'


class ModUcf101(tfds.video.Ucf101):
  """UCF101 action recognition dataset with better video quality.

  The video feature is declared with the ffmpeg arguments
  `-qscale:v 2 -r 25 -t 00:00:20`:
  - quality scale 2;
  - 25 frames per second;
  - at most the first 20 seconds of each clip.
  Frames are stored JPEG-encoded.
  """

  def _info(self):
    """Describes the custom video feature and the 101-class label feature."""
    config = self.builder_config
    frame_shape = (None, config.height, config.width, 3)
    video_feature = tfds.features.Video(
        frame_shape,
        ffmpeg_extra_args=('-qscale:v', '2', '-r', '25', '-t', '00:00:20'),
        encoding_format='jpeg')  # pytype: disable=wrong-arg-types # gen-stub-imports
    label_feature = tfds.features.ClassLabel(
        names_file=tfds.core.tfds_path(_LABELS_FNAME))
    return tfds.core.DatasetInfo(
        builder=self,
        description='A 101-label video classification dataset.',
        features=tfds.features.FeaturesDict({
            'video': video_feature,
            'label': label_feature,
        }),
        homepage='https://www.crcv.ucf.edu/data-sets/ucf101/',
        citation=_CITATION,
    )
================================================
FILE: neural_mip_solving/README.md
================================================
# Neural MIP Solving - NN Verification Dataset
This is the “Neural Network Verification” dataset used in the paper
[Solving Mixed Integer Programs Using Neural Networks (Nair et al., 2020)](https://arxiv.org/abs/2012.13349).
It contains a set of mixed integer programs (MIPs) for the problem of verifying
a neural network’s robustness to perturbations of its inputs. The MIP
formulation is described in the paper
[On the Effectiveness of Interval Bound Propagation for Training Verifiably Robust Models (Gowal et al., 2018)](https://arxiv.org/abs/1810.12715).
This dataset corresponds to MIPs defined for
verifying a neural network with the architecture labelled as “small” in Table 1
of Gowal et al., 2018, and trained on the MNIST image dataset. The code used to
train the neural network to be verified is available at
https://github.com/deepmind/interval-bound-propagation. The MIPs are split into
the same training, validation, and test sets as that in Nair et al., 2020.
## Dataset Location
The dataset is available at
[this link](https://storage.cloud.google.com/neural-mip-solving/nn_verification.tar.gz).
## Dataset Metadata
The following table is necessary for this dataset to be indexed by search
engines such as Google Dataset Search.
| property | value |
|---|---|
| name | Neural Network Verification Dataset |
| url | https://github.com/deepmind/deepmind-research/tree/master/neural_mip_solving |
| sameAs | https://github.com/deepmind/deepmind-research/tree/master/neural_mip_solving |
| description | This dataset contains a set of mixed integer programs (MIPs) for the problem of verifying a neural network’s robustness to perturbations of its inputs. The MIPs are encoded in LP format. |
| license | https://creativecommons.org/licenses/by/4.0/legalcode |
| provider | |
| citation | https://arxiv.org/abs/2012.13349 |
# An Instance-Dependent Simulation Framework for Learning with Label Noise
We propose a simulation framework for generating instance-dependent
noisy labels via a pseudo-labeling paradigm. We show that this framework
generates synthetic noisy labels whose distribution is closer to human labels
compared to independent and class-conditional random flipping.
Equipped with controllable label noise, we study the negative impact of
noisy labels across a few practical settings to
understand when label noise is more problematic. Additionally, with the
availability of annotator information from our simulation framework, we propose
a new technique, Label Quality Model (LQM), that leverages annotator features to
predict and correct against noisy labels. We show that by adding LQM as a label
correction step before applying existing noisy label techniques, we can further
improve the models' performance.
[An Instance-Dependent Simulation Framework for Learning with Label Noise](https://arxiv.org/pdf/2107.11413.pdf).
In this repository, we provide the link to the datasets that we used in Sections
4 and 5 of the above paper, along with a colab that demonstrates how to load the
data and rater features.
We consider 4 tasks:
[CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html),
[CIFAR100](https://www.cs.toronto.edu/~kriz/cifar.html),
[Patch Camelyon](https://patchcamelyon.grand-challenge.org/),
and
[Cats vs Dogs](https://www.microsoft.com/en-us/download/details.aspx?id=54765).
For each task, we generate three synthetic noisy label
datasets, named "low", "medium", and "high" according to the amount of label
noise. The data are stored as TFRecords and the rater features are stored as
JSON files.
The data is available under
[noisy label synthetic dataset GCP bucket](https://console.cloud.google.com/storage/browser/noisy_label_synthetic_datasets).
A colab with details of the datasets and data-loading examples is available at
[this colab example](https://github.com/deepmind/deepmind-research/blob/master/noisy_label/noisy_label_datasets_and_rater_features.ipynb).
## License
The noisy labels and rater features in our datasets are under the
[CC0 License](https://choosealicense.com/licenses/cc0-1.0/).
Other parts of the datasets are under the original license of the datasets.
When using the datasets based on CIFAR10/CIFAR100, users are required to
attribute the following paper:
Learning Multiple Layers of Features from Tiny Images, Alex Krizhevsky, 2009
When using the datasets based on Patch Camelyon, users are required to
attribute the following paper:
Rotation Equivariant CNNs for Digital Pathology, Bastiaan S. Veeling,
Jasper Linmans, Jim Winkens, Taco Cohen, and Max Welling, arXiv:1806.03962.
When using the datasets based on Cats vs Dogs, users are required to
attribute the following paper:
Asirra: a CAPTCHA that exploits interest-aligned manual image categorization,
Jeremy Elson, John R. Douceur, Jon Howell, and Jared Saul, ACM Conference on
Computer and Communications Security, 2007.
The colab example is provided under the Apache License, Version 2.0.
## Citation
Please use the following bibtex for citations to our paper:
```
@article{gu2021instance,
title={An Instance-Dependent Simulation Framework for Learning with Label Noise},
author={Gu, Keren and Masotto, Xander and Bachani, Vandana and Lakshminarayanan, Balaji and Nikodem, Jack and Yin, Dong},
year={2021}
}
```
# Dataset Metadata
The following table is necessary for this dataset to be indexed by search
engines such as Google Dataset Search.
| property | value |
|---|---|
| name | Noisy Label Synthetic Datasets |
| url | https://github.com/deepmind/deepmind-research/tree/master/noisy_label |
| sameAs | https://github.com/deepmind/deepmind-research/tree/master/noisy_label |
| description | Data accompanying [An Instance-Dependent Simulation Framework for Learning with Label Noise](https://arxiv.org/abs/2107.11413). |
| provider | |
| citation | https://identifiers.org/arxiv:2107.11413 |
| property | value |
|---|---|
| name | REGAL CostGraphDef Synthetic Dataset |
| url | https://github.com/deepmind/deepmind-research/tree/master/regal |
| sameAs | https://github.com/deepmind/deepmind-research/tree/master/regal |
| description | This dataset contains dataflow computational graphs generated procedurally, intended for training and evaluating algorithms that optimize execution (e.g. placement and scheduling), in [TensorFlow's CostGraphDef](https://github.com/tensorflow/tensorflow/blob/59ee7f9138482d85cd93c004aca961bea35820c7/tensorflow/core/framework/cost_graph.proto#L12) [protocol buffer](https://en.wikipedia.org/wiki/Protocol_Buffers) format and encoded as [text](https://developers.google.com/protocol-buffers/docs/reference/cpp/google.protobuf.text_format). |
| provider | |
| citation | https://identifiers.org/arxiv:1905.02494 |
# RL Unplugged: Benchmarks for Offline Reinforcement Learning
RL Unplugged is a suite of benchmarks for offline reinforcement learning. RL
Unplugged is designed around the following considerations: to facilitate ease of
use, we provide the datasets with a unified API which makes it easy for the
practitioner to work with all data in the suite once a general pipeline has been
established. This dataset accompanies the paper
[RL Unplugged: Benchmarks for Offline Reinforcement Learning](https://arxiv.org/abs/2006.13888).
In this suite of benchmarks, we try to focus on the following problems:
- High-dimensional action spaces; for example, the locomotion humanoid domains
  have 56-dimensional actions.
- High-dimensional observations.
- Partial observability; observations have egocentric vision.
- Difficulty of exploration; we use state-of-the-art algorithms and imitation
  to generate data for difficult environments.
- Real-world challenges.
The data is available under
[RL Unplugged GCP bucket](https://console.cloud.google.com/storage/browser/rl_unplugged).
## Atari Dataset
We are releasing a large and diverse dataset of gameplay following the protocol
described by [Agarwal et al., 2020], which can be used to evaluate several
discrete offline RL algorithms. The dataset is generated by running an online
DQN agent and recording transitions from its replay during training with sticky
actions [Machado et al., 2018]. As stated in [Agarwal et al., 2020], for each
game we use data from five runs with 50 million transitions each. States in each
transition include stacks of four frames to be able to do frame-stacking with
our baselines. We release datasets for 46 Atari games. For details on how the
dataset was generated, please refer to the paper.
Atari is a standard RL benchmark. We recommend trying offline RL methods
on Atari if you are interested in comparing your approach to other
state-of-the-art offline RL methods with discrete actions.
## DeepMind Locomotion Dataset
These tasks are made up of the corridor locomotion tasks involving the CMU
Humanoid, for which prior efforts have either used motion capture data [Merel et
al., 2019a], [Merel et al., 2019b] or training from scratch [Song et al., 2020].
In addition, the DM Locomotion repository contains a set of tasks adapted to be
suited to a virtual rodent [Merel et al., 2020]. We emphasize that the DM
Locomotion tasks feature the combination of challenging high-DoF continuous
control along with perception from rich egocentric observations. For details on
how the dataset was generated, please refer to the paper.
We recommend trying offline RL methods on the DeepMind Locomotion dataset if
you are interested in a very challenging offline RL dataset with a continuous
action space.
## DeepMind Control Suite Dataset
DeepMind Control Suite [Tassa et al., 2018] is a set of control tasks
implemented in MuJoCo [Todorov et al., 2012]. We consider a subset of the tasks
provided in the suite that cover a wide range of difficulties.
Most of the datasets in this domain are generated using D4PG. For the
environments Manipulator insert ball and Manipulator insert peg we use V-MPO
[Song et al., 2020] to generate the data as D4PG is unable to solve these tasks.
We release datasets for 9 control suite tasks. For details on how the dataset
was generated, please refer to the paper.
DeepMind Control Suite is a traditional continuous-action RL benchmark. In
particular, we recommend you test your approach in DeepMind Control Suite if
you are interested in comparing against other state-of-the-art offline RL
methods.
## Realworld RL Dataset
Examples in the dataset represent SARS transitions stored when running a
partially online trained agent as described in
[RWRL](https://arxiv.org/abs/2003.11881).
We release 8 datasets in total -- with no combined challenge and easy combined
challenge on the cartpole, walker, quadruped, and humanoid tasks. For details on
how the dataset was generated, please refer to the paper.
## DeepMind Lab Dataset
The DeepMind Lab dataset has several levels from the challenging, partially
observable [DeepMind Lab suite](https://github.com/deepmind/lab). The DeepMind
Lab dataset is collected by training distributed R2D2 agents
[Kapturowski et al., 2018] from scratch on individual tasks. We recorded the
experience across all actors during entire training runs a few times for every
task. The details of the dataset generation process are described in
[Gulcehre et al., 2021].
We release datasets for five different DeepMind Lab levels: `seekavoid_arena_01`,
`explore_rewards_few`, `explore_rewards_many`, `rooms_watermaze`,
`rooms_select_nonmatching_object`. We also release snapshot datasets for the
`seekavoid_arena_01` level, which we generated from a trained R2D2 snapshot
using different epsilon values for the epsilon-greedy policy when evaluating
the agent in the environment.
The DeepMind Lab dataset is fairly large-scale. We recommend trying it if you
are interested in large-scale offline RL models with memory.
## bsuite Dataset
[bsuite](https://github.com/deepmind/bsuite) data was collected by training DQN
agents with the default setting in [Acme](https://github.com/deepmind/acme) from
scratch in each one of the following three tasks: cartpole, catch, and
mountain_car.
We converted the originally deterministic environments into stochastic ones by
randomly replacing the agent's action with a uniformly sampled action with a
probability of {0, 0.1, 0.2, 0.3, 0.4, 0.5}. A probability of 0 corresponds to
the original environment. The details of the dataset generation process are
described in [Gulcehre et al., 2021].
bsuite datasets are fairly lightweight and running experiments doesn't require
much compute. We recommend trying bsuite if you are interested in small-scale,
easy-to-run offline RL datasets generated by stochastic environments where the
stochasticity of the environment is easy to control.
## Running the code
### Installation
* Install dependencies: `pip install -r requirements.txt`
* (Optional) Setup MuJoCo license key for DM Control environments
([instructions](https://github.com/deepmind/dm_control#requirements-and-installation)).
* (Optional) Install
[realworldrl_suite](https://github.com/google-research/realworldrl_suite#installation).
### Atari example
```
mkdir -p /tmp/dataset/Asterix
gsutil cp gs://rl_unplugged/atari/Asterix/run_1-00000-of-00100 \
/tmp/dataset/Asterix/run_1-00000-of-00001
python atari_example.py --path=/tmp/dataset --game=Asterix
```
This copies a single shard from one of the Asterix datasets from GCP to a local
folder, and then runs a script that loads a single example and runs a step on
the Atari environment.
## Citation
Please use the following bibtex for citations:
```
@misc{gulcehre2020rl,
title={RL Unplugged: Benchmarks for Offline Reinforcement Learning},
author={Caglar Gulcehre and Ziyu Wang and Alexander Novikov and Tom Le Paine
and Sergio Gómez Colmenarejo and Konrad Zolna and Rishabh Agarwal and
Josh Merel and Daniel Mankowitz and Cosmin Paduraru and Gabriel
Dulac-Arnold and Jerry Li and Mohammad Norouzi and Matt Hoffman and
Ofir Nachum and George Tucker and Nicolas Heess and Nando deFreitas},
year={2020},
eprint={2006.13888},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
```
# Dataset Metadata
The following table is necessary for this dataset to be indexed by search
engines such as Google Dataset Search.
| property | value |
|---|---|
| name | RL Unplugged |
| url | https://github.com/deepmind/deepmind-research/tree/master/rl_unplugged |
| sameAs | https://github.com/deepmind/deepmind-research/tree/master/rl_unplugged |
| description | Data accompanying [RL Unplugged: Benchmarks for Offline Reinforcement Learning](https://arxiv.org/abs/2006.13888). |
| provider | |
| citation | https://identifiers.org/arxiv:2006.13888 |
| property | value |
|---|---|
| name | Sketchy |
| url | https://github.com/deepmind/deepmind-research/tree/master/sketchy |
| sameAs | https://github.com/deepmind/deepmind-research/tree/master/sketchy |
| description | Data accompanying [Scaling data-driven robotics with reward sketching and batch reinforcement learning](https://arxiv.org/abs/1909.12200). |
| provider | |
| citation | https://identifiers.org/arxiv:1909.12200 |