Repository: for-ai/TD
Branch: master
Commit: 877e3b9d1491
Files: 50
Total size: 136.8 KB

Directory structure:
TD/

├── .gitignore
├── .travis.yml
├── README.md
├── __init__.py
├── data/
│   ├── __init__.py
│   ├── data_generators/
│   │   ├── __init__.py
│   │   ├── cifar_generator.py
│   │   ├── generator_utils.py
│   │   └── mnist_generator.py
│   ├── dataset_maps.py
│   ├── image_reader.py
│   ├── imagenet_augs.py
│   └── registry.py
├── hparams/
│   ├── __init__.py
│   ├── basic.py
│   ├── defaults.py
│   ├── lenet.py
│   ├── registry.py
│   ├── resnet.py
│   ├── user.py
│   ├── utils.py
│   └── vgg.py
├── models/
│   ├── __init__.py
│   ├── basic/
│   │   ├── __init__.py
│   │   └── basic.py
│   ├── lenet/
│   │   ├── __init__.py
│   │   └── lenet.py
│   ├── registry.py
│   ├── resnet/
│   │   ├── __init__.py
│   │   └── resnet.py
│   ├── utils/
│   │   ├── __init__.py
│   │   ├── activations.py
│   │   ├── dropouts.py
│   │   ├── initializations.py
│   │   ├── model_utils.py
│   │   └── optimizers.py
│   └── vgg/
│       ├── __init__.py
│       └── vgg.py
├── requirements.txt
├── scripts/
│   ├── __init__.py
│   └── prune/
│       ├── README.md
│       ├── __init__.py
│       ├── eval.py
│       └── prune.py
├── train.py
└── training/
    ├── __init__.py
    ├── envs.py
    ├── flags.py
    ├── lr_schemes.py
    └── tpu.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

tmp
runs
run

# PyCharm
.idea/

# macOS metadata
.DS_Store

.vscode

================================================
FILE: .travis.yml
================================================
language: python
python:
  - "3.6"

# command to install dependencies
install:
  - pip install -r requirements.txt

# command to run tests
script:
  - export FILES="$(git diff --name-only $TRAVIS_COMMIT_RANGE)"
  - cd /home/travis/build/for-ai
  - python3 -m TD.train --eval_steps 1 --eval_every 1 --train_epochs -1 --env local --hparams mnist_basic_no_dropout
  
  - python3 -m TD.train --eval_steps 1 --eval_every 1 --train_epochs -1 --env local --hparams cifar_lenet
  - python3 -m TD.train --eval_steps 1 --eval_every 1 --train_epochs -1 --env local --hparams cifar_lenet_weight
  - python3 -m TD.train --eval_steps 1 --eval_every 1 --train_epochs -1 --env local --hparams cifar_lenet_trgtd_weight
  - python3 -m TD.train --eval_steps 1 --eval_every 1 --train_epochs -1 --env local --hparams cifar_lenet_unit
  - python3 -m TD.train --eval_steps 1 --eval_every 1 --train_epochs -1 --env local --hparams cifar_lenet_trgtd_unit
  
  - python3 -m TD.train --eval_steps 1 --eval_every 1 --train_epochs -1 --env local --hparams cifar10_resnet32
  - python3 -m TD.train --eval_steps 1 --eval_every 1 --train_epochs -1 --env local --hparams cifar10_resnet32_weight
  - python3 -m TD.train --eval_steps 1 --eval_every 1 --train_epochs -1 --env local --hparams cifar10_resnet32_trgtd_weight
  - python3 -m TD.train --eval_steps 1 --eval_every 1 --train_epochs -1 --env local --hparams cifar10_resnet32_unit
  - python3 -m TD.train --eval_steps 1 --eval_every 1 --train_epochs -1 --env local --hparams cifar10_resnet32_trgtd_unit
  
  - python3 -m TD.train --eval_steps 1 --eval_every 1 --train_epochs -1 --env local --hparams cifar100_vgg16_no_dropout
  - python3 -m TD.train --eval_steps 1 --eval_every 1 --train_epochs -1 --env local --hparams cifar100_vgg16_untargeted_dropout
  - python3 -m TD.train --eval_steps 1 --eval_every 1 --train_epochs -1 --env local --hparams cifar100_vgg16_targeted_dropout
  - python3 -m TD.train --eval_steps 1 --eval_every 1 --train_epochs -1 --env local --hparams cifar100_vgg16_untargeted_unit_dropout
  - python3 -m TD.train --eval_steps 1 --eval_every 1 --train_epochs -1 --env local --hparams cifar100_vgg16_targeted_unit_dropout


================================================
FILE: README.md
================================================
# Targeted Dropout

Aidan N. Gomez, Ivan Zhang, Kevin Swersky, Yarin Gal, and Geoffrey E. Hinton

## Table of Contents
- [Requirements](#requirements)
- [Quick Start](#quick-start)
- [Experiments](#experiments)

## Requirements
- Python 3
- TensorFlow 1.8

## Quick Start
1. Train a model: `python -m TD.train --hparams=resnet_default`
2. Prune that model: `python -m TD.scripts.prune.eval --hparams=resnet_default --prune_percent 0.0,0.25,0.5,0.75,0.95`

### Flags
- `--env`: one of `local`, `gcp` (GPU instances), or `tpu` (TPU instances). Feel free to add more if necessary.
- `--hparams`: the hparam set you want to run.
- `--hparam_override`: manually specify hparams to be overridden (e.g. `--hparam_override 'drop_rate=0.66'`)
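
For example, the flags compose: `python -m TD.train --env local --hparams cifar_lenet --hparam_override 'drop_rate=0.5'` trains the `cifar_lenet` hparam set locally with its drop rate overridden (the values here are illustrative).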


================================================
FILE: __init__.py
================================================
__all__ = ["data", "hparams", "models", "training"]

from .data import *
from .hparams import *
from .models import *
from .training import *


================================================
FILE: data/__init__.py
================================================
__all__ = [
    "image_reader",
    "registry",
    "dataset_maps",
]


================================================
FILE: data/data_generators/__init__.py
================================================
__all__ = [
    "cifar_generator",
    "generator_utils",
    "mnist_generator",
]


================================================
FILE: data/data_generators/cifar_generator.py
================================================
try:
  import cPickle
except ImportError:
  import pickle as cPickle
import os
import random
import sys
import tarfile
import urllib.request
import numpy as np
import tensorflow as tf

from .generator_utils import generate_files
from ...models.utils.model_utils import ModeKeys

FLAGS = tf.app.flags.FLAGS

_URL = "http://www.cs.toronto.edu/~kriz/"
_CIFAR10_TAR = "cifar-10-python.tar.gz"
_CIFAR10_DIR = "cifar-10-batches-py"
_CIFAR10_TRAIN = [
    "data_batch_1", "data_batch_2", "data_batch_3", "data_batch_4",
    "data_batch_5"
]
_CIFAR10_TEST = ["test_batch"]

_CIFAR100_TAR = "cifar-100-python.tar.gz"
_CIFAR100_DIR = "cifar-100-python"
_CIFAR100_TRAIN = ["train"]
_CIFAR100_TEST = ["test"]

_WORKING_DIR = "/tmp/tf_data"


def download(v100):
  archive = _CIFAR100_TAR if v100 else _CIFAR10_TAR
  filepath = os.path.join(_WORKING_DIR, archive)
  if not os.path.exists(_WORKING_DIR):
    os.makedirs(_WORKING_DIR)
  url = _URL + archive
  if not os.path.isfile(filepath):
    print("Downloading " + url)
    urllib.request.urlretrieve(url, filepath)
  print("Extracting " + filepath)
  tar = tarfile.open(filepath, "r:gz")
  tar.extractall(path=_WORKING_DIR)
  tar.close()


def maybe_download(files, v100):
  for file in files:
    filepath = os.path.join(_WORKING_DIR, _CIFAR100_DIR
                            if v100 else _CIFAR10_DIR, file)
    if not os.path.isfile(filepath):
      download(v100)
      break


def read_files(files, v100):
  images = None
  labels = None
  for file in files:
    filename = os.path.join(_WORKING_DIR, _CIFAR100_DIR
                            if v100 else _CIFAR10_DIR, file)
    data = None
    with tf.gfile.Open(filename, "rb") as f:
      if sys.version_info < (3,):
        data = cPickle.load(f)
      else:
        data = cPickle.load(f, encoding="bytes")

    info = np.transpose(data[b"data"].reshape((-1, 3, 32, 32)), (0, 2, 3, 1))
    if images is None:
      images = info
    else:
      images = np.concatenate((images, info))

    info = data[b"fine_labels"] if v100 else data[b"labels"]
    if labels is None:
      labels = info
    else:
      labels = np.concatenate((labels, info))
  return images, labels


def cifar_generator(v100, mode):
  files = None
  if v100:
    files = _CIFAR100_TRAIN if mode != ModeKeys.TEST else _CIFAR100_TEST
  else:
    files = _CIFAR10_TRAIN if mode != ModeKeys.TEST else _CIFAR10_TEST
  maybe_download(files, v100)

  images, labels = read_files(files, v100)
  data = list(zip(images, labels))
  random.shuffle(data)
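  # The separate CIFAR test files are untouched; the original train files
  # are split 80/20 below into train and eval subsets.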
  
  samples = len(data)
  if mode == ModeKeys.TRAIN:
    data = data[:int(samples * 0.8)]
  elif mode == ModeKeys.EVAL:
    data = data[int(samples * 0.8):]

  image_ph = tf.placeholder(dtype=tf.uint8, shape=(32, 32, 3))
  encoded_ph = tf.image.encode_png(image_ph)

  sess = tf.Session()
  for image, label in data:
    encoded_im = sess.run(encoded_ph, feed_dict={image_ph: image})
    yield {
        "image/encoded": [encoded_im],
        "image/format": [b"png"],
        "image/class/label": [label],
        "image/height": [32],
        "image/width": [32],
        "image/channels": [3]
    }


def generate(train_name, eval_name, test_name, hparams):
  v100 = hparams.data in ["cifar100", "cifar100_tpu"]
  generate_files(
      cifar_generator(v100, mode=ModeKeys.TRAIN), train_name, hparams.data_dir,
      FLAGS.num_shards)
  generate_files(
      cifar_generator(v100, mode=ModeKeys.EVAL), eval_name, hparams.data_dir,
      FLAGS.num_shards)
  generate_files(
      cifar_generator(v100, mode=ModeKeys.TEST), test_name, hparams.data_dir,
      FLAGS.num_shards)


================================================
FILE: data/data_generators/generator_utils.py
================================================
import operator
import os
import numpy as np
import tensorflow as tf

tf.flags.DEFINE_boolean("v100", False,
                        "Download CIFAR-100 instead of CIFAR-10.")
tf.flags.DEFINE_integer("num_shards", 1,
                        "The number of output shards to write to.")


def to_example(dictionary):
  features = {}
  for k, v in dictionary.items():
    if len(v) == 0:
      raise Exception("Empty field: %s" % str((k, v)))
    if isinstance(v[0], (int, np.int8, np.int32, np.int64)):
      features[k] = tf.train.Feature(int64_list=tf.train.Int64List(value=v))
    elif isinstance(v[0], (float, np.float32)):
      features[k] = tf.train.Feature(float_list=tf.train.FloatList(value=v))
    elif isinstance(v[0], (str, bytes)):
      features[k] = tf.train.Feature(bytes_list=tf.train.BytesList(value=v))
    else:
      raise Exception("Unsupported type: %s" % type(v[0]))
  return tf.train.Example(features=tf.train.Features(feature=features))


def generate_files(generator,
                   output_name,
                   output_dir,
                   num_shards,
                   max_cases=None):
  if not tf.gfile.Exists(output_dir):
    tf.gfile.MakeDirs(output_dir)

  writers = []
  for shard in range(num_shards):
    output_filename = "%s-%dof%d" % (output_name, shard + 1, num_shards)
    output_file = os.path.join(output_dir, output_filename)
    writers.append(tf.python_io.TFRecordWriter(output_file))

  counter, shard = 0, 0
  for case in generator:
    if counter % 100 == 0:
      tf.logging.info("Processed %d examples..." % counter)
    counter += 1
    if max_cases and counter > max_cases:
      break
    sequence_example = to_example(case)
    writers[shard].write(sequence_example.SerializeToString())
    shard = (shard + 1) % num_shards

  for writer in writers:
    writer.close()
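

if __name__ == "__main__":
  # Illustrative usage sketch (added for exposition; not part of the
  # original file). to_example maps each non-empty list in the dict onto
  # the matching feature type: int -> Int64List, float -> FloatList,
  # str/bytes -> BytesList.
  demo = to_example({
      "image/class/label": [7],
      "image/width": [28],
      "image/format": [b"png"],
  })
  print(demo.SerializeToString()[:20])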


================================================
FILE: data/data_generators/mnist_generator.py
================================================
import gzip
import os
import random
import urllib
import numpy as np
import tensorflow as tf

from .generator_utils import generate_files
from ...models.utils.model_utils import ModeKeys

FLAGS = tf.app.flags.FLAGS
tf.logging.set_verbosity(tf.logging.INFO)

_TRAIN_IMAGE_COUNT = 60000
_TRAIN_IMAGE_FILE = "train-images-idx3-ubyte.gz"
_TRAIN_LABEL_FILE = "train-labels-idx1-ubyte.gz"

_TEST_IMAGE_COUNT = 10000
_TEST_IMAGE_FILE = "t10k-images-idx3-ubyte.gz"
_TEST_LABEL_FILE = "t10k-labels-idx1-ubyte.gz"

_WORKING_DIR = "/tmp/tf_data"


def download_files(filenames):
  """Download files to tmp/data if file does not exist
  Args:
    filenames: list of string; list of filenames to check if exist
  """
  if not os.path.exists(_WORKING_DIR):
    os.makedirs(_WORKING_DIR)
  for filename in filenames:
    filepath = os.path.join(_WORKING_DIR, filename)
    url = "http://yann.lecun.com/exdb/mnist/" + filename
    if not os.path.isfile(filepath):
      print("Downloading %s" % (url + filename))
      try:
        urllib.urlretrieve(url, filepath)
      except AttributeError:
        urllib.request.urlretrieve(url, filepath)


def read_images(filepath, num_images):
  with gzip.open(filepath) as f:
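    # Skip the 16-byte IDX image header (magic number, count, rows, cols).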
    f.read(16)
    buf = f.read(28 * 28 * num_images)
    data = np.frombuffer(buf, dtype=np.uint8)
    data = data.reshape(num_images, 28, 28, 1)
  return data


def read_labels(filepath, num_labels):
  with gzip.open(filepath) as f:
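    # Skip the 8-byte IDX label header (magic number, count).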
    f.read(8)
    buf = f.read(num_labels)
    data = np.frombuffer(buf, dtype=np.uint8)
  return data.astype(np.int64)


def mnist_generator(mode):
  num_images = _TRAIN_IMAGE_COUNT if mode != ModeKeys.TEST else _TEST_IMAGE_COUNT
  image_filepath = _TRAIN_IMAGE_FILE if mode != ModeKeys.TEST else _TEST_IMAGE_FILE
  label_filepath = _TRAIN_LABEL_FILE if mode != ModeKeys.TEST else _TEST_LABEL_FILE

  download_files([image_filepath, label_filepath])

  image_filepath = os.path.join(_WORKING_DIR, image_filepath)
  label_filepath = os.path.join(_WORKING_DIR, label_filepath)

  images = read_images(image_filepath, num_images)
  labels = read_labels(label_filepath, num_images)

  data = list(zip(images, labels))
  random.shuffle(data)
  
  if mode == ModeKeys.TRAIN:
    data = data[:5*num_images//6]
  elif mode == ModeKeys.EVAL:
    data = data[5*num_images//6:]

  image_ph = tf.placeholder(dtype=tf.uint8, shape=(28, 28, 1))
  encoded_ph = tf.image.encode_png(image_ph)

  sess = tf.Session()
  for image, label in data:
    encoded_im = sess.run(encoded_ph, feed_dict={image_ph: image})
    yield {
        "image/encoded": [encoded_im],
        "image/format": [b"png"],
        "image/class/label": [label],
        "image/height": [28],
        "image/width": [28]
    }


def generate(train_name, eval_name, test_name, hparams):
  generate_files(
      mnist_generator(mode=ModeKeys.TRAIN), train_name, hparams.data_dir, 1)
  generate_files(
      mnist_generator(mode=ModeKeys.EVAL), eval_name, hparams.data_dir, 1)
  generate_files(
      mnist_generator(mode=ModeKeys.TEST), test_name, hparams.data_dir, 1)


================================================
FILE: data/dataset_maps.py
================================================
import tensorflow as tf
from . import imagenet_augs 

_AUGMENTATIONS = dict()


def register(fn):
  global _AUGMENTATIONS
  _AUGMENTATIONS[fn.__name__] = fn
  return fn


def get_augmentation(name, params, training):

  def fn(*args, **kwargs):
    return _AUGMENTATIONS[name](
        *args, **kwargs, training=training, params=params)

  return fn


@register
def cifar_augmentation(image, label, training, params):
  """Image augmentation suitable for CIFAR-10/100.
  As described in https://arxiv.org/pdf/1608.06993v3.pdf (page 5).
  Args:
    image: a `Tensor`.
    label: a `Tensor`.
  Returns:
    Tuple of (image, label), with the image augmented and standardized.
  """
  if training:
    image = tf.image.resize_image_with_crop_or_pad(image, 40, 40)
    image = tf.random_crop(image, [32, 32, 3])
    image = tf.image.random_flip_left_right(image)

  image = tf.image.per_image_standardization(image)
  return image, label


@register
def imagenet_augmentation(image, label, training, params):
  """Imagenet augmentations.
  Args:
    image: a `Tensor`.
    label: a `Tensor`.
  Returns:
    Tuple of (image, label), with the image resized and preprocessed.
  """
  if training:
    image = imagenet_augs.preprocess_for_train(image, params.input_shape[0])
  else:
    image = imagenet_augs.preprocess_for_eval(image, params.input_shape[0])
  return image, label


@register
def load_images(example, training, params):
  data_fields_to_features = {
      "image/encoded": tf.FixedLenFeature((), tf.string),
      "image/format": tf.FixedLenFeature((), tf.string),
      "image/class/label": tf.FixedLenFeature((), tf.int64)
  }

  example = tf.parse_single_example(example, data_fields_to_features)
  image = example["image/encoded"]
  image = tf.image.decode_png(image, channels=params.channels, dtype=tf.uint8)
  image = tf.to_float(image)

  label = tf.to_int32(example["image/class/label"])

  return image, label


@register
def set_shapes(image, label, training, params):
  image = tf.reshape(image, params.input_shape)
  return image, label


@register
def transpose(image, label, training, params):
  image = tf.transpose(image, [2, 0, 1])
  return image, label
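

# Illustrative sketch (added for exposition; not in the original file):
#
#   @register
#   def my_augmentation(image, label, training, params):
#     image = tf.image.random_brightness(image, max_delta=0.1)
#     return image, label
#
# Any map registered this way can be listed in hparams.data_augmentations
# and is applied to the dataset through get_augmentation(name, hparams,
# training); `my_augmentation` is a hypothetical name.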

================================================
FILE: data/image_reader.py
================================================
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

from .registry import register
from .dataset_maps import get_augmentation
from .data_generators import cifar_generator, mnist_generator


@register("imagenet", None)
@register("mnist", mnist_generator.generate)
@register("cifar10", cifar_generator.generate)
@register("cifar100", cifar_generator.generate)
def image_reader(data_sources, hparams, training):
  """Input function for image data."""

  def _input_fn(params=None):
    """Input function compatible with Experiment API."""
    if params is not None and "batch_size" in params:
      hparams.batch_size = params["batch_size"]

    dataset = tf.data.TFRecordDataset(
        data_sources, num_parallel_reads=4 if training else 1)
    dataset = dataset.prefetch(5 * hparams.batch_size)

    if hparams.shuffle_data:
      dataset = dataset.shuffle(5 * hparams.batch_size)

    dataset = dataset.map(get_augmentation("load_images", hparams, training))

    if hparams.data_augmentations is not None:
      for augmentation_name in hparams.data_augmentations:
        dataset = dataset.map(
            get_augmentation(augmentation_name, hparams, training))

    dataset = dataset.map(get_augmentation("set_shapes", hparams, training))
    if hparams.data_format == "channels_first":
      dataset = dataset.map(get_augmentation("transpose", hparams, training))
    dataset = dataset.repeat().batch(hparams.batch_size)
    dataset_it = dataset.make_one_shot_iterator()

    images, labels = dataset_it.get_next()
    if params is not None and "batch_size" in params:
      images = tf.reshape(images,
                          [hparams.batch_size] + images.shape.as_list()[1:])
      labels = tf.reshape(labels,
                          [hparams.batch_size] + labels.shape.as_list()[1:])
    return {"inputs": images, "labels": labels}, labels

  return _input_fn
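
# Illustrative sketch (added for exposition; not in the original file):
#
#   input_fn = image_reader(train_paths, hparams, training=True)
#   features, labels = input_fn()
#   # features["inputs"]: batched images; features["labels"]: class labels
#
# where train_paths is a list of TFRecord file paths (hypothetical name).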


@register("mnist_simple", None)
def mnist_simple(data_source, params, training):
  """Input function for MNIST image data."""

  mnist = input_data.read_data_sets(data_source, one_hot=True)

  data_set = mnist.train if training else mnist.test

  def _input_fn():
    input_images = tf.constant(data_set.images)

    input_labels = tf.constant(
        data_set.labels) if not params.is_ae else tf.constant(data_set.images)

    image, label = tf.train.slice_input_producer([input_images, input_labels])

    imageBatch, labelBatch = tf.train.batch(
        [image, label], batch_size=params.batch_size)

    return {"inputs": imageBatch}, labelBatch

  return _input_fn


@register("fashion", None)
def fashion(data_source, params, training):
  """Input function for MNIST image data."""

  mnist = input_data.read_data_sets(
      data_source,
      source_url='http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/',
      one_hot=True)

  data_set = mnist.train if training else mnist.test

  def _input_fn():
    input_images = tf.constant(data_set.images)

    input_labels = tf.constant(data_set.labels)
    image, label = tf.train.slice_input_producer([input_images, input_labels])

    imageBatch, labelBatch = tf.train.batch(
        [image, label], batch_size=params.batch_size)

    return {"inputs": imageBatch}, labelBatch

  return _input_fn


================================================
FILE: data/imagenet_augs.py
================================================
import tensorflow as tf

MEAN_RGB = [0.485, 0.456, 0.406]
STDDEV_RGB = [0.229, 0.224, 0.225]


# The following preprocessing functions were taken from
# cloud_tpu/models/resnet/resnet_preprocessing.py
# ==============================================================================
def _crop(image, offset_height, offset_width, crop_height, crop_width):
  """Crops the given image using the provided offsets and sizes.
  Note that the method doesn't assume we know the input image size but it does
  assume we know the input image rank.
  Args:
    image: `Tensor` image of shape [height, width, channels].
    offset_height: `Tensor` indicating the height offset.
    offset_width: `Tensor` indicating the width offset.
    crop_height: the height of the cropped image.
    crop_width: the width of the cropped image.
  Returns:
    the cropped (and resized) image.
  Raises:
    InvalidArgumentError: if the rank is not 3 or if the image dimensions are
      less than the crop size.
  """
  original_shape = tf.shape(image)

  rank_assertion = tf.Assert(
      tf.equal(tf.rank(image), 3), ["Rank of image must be equal to 3."])
  with tf.control_dependencies([rank_assertion]):
    cropped_shape = tf.stack([crop_height, crop_width, original_shape[2]])

  size_assertion = tf.Assert(
      tf.logical_and(
          tf.greater_equal(original_shape[0], crop_height),
          tf.greater_equal(original_shape[1], crop_width)),
      ["Crop size greater than the image size."])

  offsets = tf.to_int32(tf.stack([offset_height, offset_width, 0]))

  # Use tf.slice instead of crop_to_bounding box as it accepts tensors to
  # define the crop size.
  with tf.control_dependencies([size_assertion]):
    image = tf.slice(image, offsets, cropped_shape)
  return tf.reshape(image, cropped_shape)


def distorted_bounding_box_crop(image,
                                bbox,
                                min_object_covered=0.1,
                                aspect_ratio_range=(0.75, 1.33),
                                area_range=(0.05, 1.0),
                                max_attempts=100,
                                scope=None):
  """Generates cropped_image using a one of the bboxes randomly distorted.
  See `tf.image.sample_distorted_bounding_box` for more documentation.
  Args:
    image: `Tensor` of image (it will be converted to floats in [0, 1]).
    bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]`
        where each coordinate is [0, 1) and the coordinates are arranged
        as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole
        image.
    min_object_covered: An optional `float`. Defaults to `0.1`. The cropped
        area of the image must contain at least this fraction of any bounding
        box supplied.
    aspect_ratio_range: An optional list of `float`s. The cropped area of the
        image must have an aspect ratio = width / height within this range.
    area_range: An optional list of `float`s. The cropped area of the image
        must contain a fraction of the supplied image within this range.
    max_attempts: An optional `int`. Number of attempts at generating a cropped
        region of the image of the specified constraints. After `max_attempts`
        failures, return the entire image.
    scope: Optional `str` for name scope.
  Returns:
    (cropped image `Tensor`, distorted bbox `Tensor`).
  """
  with tf.name_scope(
      scope, default_name="distorted_bounding_box_crop", values=[image, bbox]):
    # Each bounding box has shape [1, num_boxes, box coords] and
    # the coordinates are ordered [ymin, xmin, ymax, xmax].

    # A large fraction of image datasets contain a human-annotated bounding
    # box delineating the region of the image containing the object of interest.
    # We choose to create a new bounding box for the object which is a randomly
    # distorted version of the human-annotated bounding box that obeys an
    # allowed range of aspect ratios, sizes and overlap with the human-annotated
    # bounding box. If no box is supplied, then we assume the bounding box is
    # the entire image.
    sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
        tf.shape(image),
        bounding_boxes=bbox,
        min_object_covered=min_object_covered,
        aspect_ratio_range=aspect_ratio_range,
        area_range=area_range,
        max_attempts=max_attempts,
        use_image_if_no_bounding_boxes=True)
    bbox_begin, bbox_size, distort_bbox = sample_distorted_bounding_box

    # Crop the image to the specified bounding box.
    cropped_image = tf.slice(image, bbox_begin, bbox_size)
    return cropped_image, distort_bbox


def _random_crop(image, size):
  """Make a random crop of (`size` x `size`)."""
  bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
  random_image, bbox = distorted_bounding_box_crop(
      image,
      bbox,
      min_object_covered=0.1,
      aspect_ratio_range=(3. / 4, 4. / 3.),
      area_range=(0.08, 1.0),
      max_attempts=1,
      scope=None)
  bad = _at_least_x_are_true(tf.shape(image), tf.shape(random_image), 3)

  image = tf.cond(
      bad, lambda: _center_crop(_do_scale(image, size), size),
      lambda: tf.image.resize_bicubic([random_image], [size, size])[0])
  return image


def _flip(image):
  """Random horizontal image flip."""
  image = tf.image.random_flip_left_right(image)
  return image


def _at_least_x_are_true(a, b, x):
  """At least `x` of `a` and `b` `Tensors` are true."""
  match = tf.equal(a, b)
  match = tf.cast(match, tf.int32)
  return tf.greater_equal(tf.reduce_sum(match), x)


def _do_scale(image, size):
  """Rescale the image by scaling the smaller spatial dimension to `size`."""
  shape = tf.cast(tf.shape(image), tf.float32)
  w_greater = tf.greater(shape[0], shape[1])
  shape = tf.cond(
      w_greater, lambda: tf.cast([shape[0] / shape[1] * size, size], tf.int32),
      lambda: tf.cast([size, shape[1] / shape[0] * size], tf.int32))

  return tf.image.resize_bicubic([image], shape)[0]


def _center_crop(image, size):
  """Crops to center of image with specified `size`."""
  image_height = tf.shape(image)[0]
  image_width = tf.shape(image)[1]

  offset_height = ((image_height - size) + 1) / 2
  offset_width = ((image_width - size) + 1) / 2
  image = _crop(image, offset_height, offset_width, size, size)
  return image


def _normalize(image):
  """Normalize the image to zero mean and unit variance."""
  offset = tf.constant(MEAN_RGB, shape=[1, 1, 3])
  image -= offset

  scale = tf.constant(STDDEV_RGB, shape=[1, 1, 3])
  image /= scale
  return image


def preprocess_for_train(image, image_size=224):
  """Preprocesses the given image for evaluation.
  Args:
    image: `Tensor` representing an image of arbitrary size.
    image_size: int, how large the output image should be.
  Returns:
    A preprocessed image `Tensor`.
  """
  image = _random_crop(image, image_size)
  image = _normalize(image)
  image = _flip(image)
  image = tf.reshape(image, [image_size, image_size, 3])
  return image


def preprocess_for_eval(image, image_size=224):
  """Preprocesses the given image for evaluation.
  Args:
    image: `Tensor` representing an image of arbitrary size.
    image_size: int, how large the output image should be.
  Returns:
    A preprocessed image `Tensor`.
  """
  image = _do_scale(image, image_size + 32)
  image = _normalize(image)
  image = _center_crop(image, image_size)
  image = tf.reshape(image, [image_size, image_size, 3])
  return image


================================================
FILE: data/registry.py
================================================
import os

import tensorflow as tf

_INPUT_FNS = dict()
_GENERATORS = dict()


def register(name, generator):

  def add_to_dict(fn):
    global _INPUT_FNS
    global _GENERATORS
    _INPUT_FNS[name] = fn
    _GENERATORS[name] = generator
    return fn

  return add_to_dict


def get_input_fns(hparams, generate=True):
  train_path = os.path.join(hparams.data_dir, "train*")
  eval_path = os.path.join(hparams.data_dir, "eval*")
  test_path = os.path.join(hparams.data_dir, "test*")

  if generate:
    if not tf.gfile.Exists(hparams.data_dir):
      tf.gfile.MakeDirs(hparams.data_dir)

    # generate the splits if the train files don't exist
    maybe_generate(train_path, hparams)
    maybe_generate(eval_path, hparams)
    maybe_generate(test_path, hparams)

  train_path = tf.gfile.Glob(train_path)
  eval_path = tf.gfile.Glob(eval_path)
  test_path = tf.gfile.Glob(test_path)

  input_fn = _INPUT_FNS[hparams.data]
  train_fn = input_fn(train_path, hparams, training=True)
  eval_fn = None if not eval_path else input_fn(
      eval_path, hparams, training=False)
  test_fn = None if not test_path else input_fn(
      test_path, hparams, training=False)
  if not (eval_path or test_path):
    raise Exception("Could not find eval or test files.")
  return train_fn, eval_fn, test_fn


def get_dataset(hparams):
  train_path = os.path.join(hparams.data_dir, "train*")
  eval_path = os.path.join(hparams.data_dir, "eval*")
  test_path = os.path.join(hparams.data_dir, "test*")
  maybe_generate(train_path, hparams)
  maybe_generate(eval_path, hparams)
  maybe_generate(test_path, hparams)
  return train_path, eval_path, test_path


def maybe_generate(check_path, hparams):
  if not tf.gfile.Glob(check_path):
    generate_fn = _GENERATORS[hparams.data]
    if generate_fn:
      generate_fn("train", "eval", "test", hparams)
    else:
      tf.logging.warn(
          "No generator function. Unable to generate: %s" % check_path)


================================================
FILE: hparams/__init__.py
================================================
__all__ = ["defaults", "registry", "resnet", "lenet", "utils", "vgg", "basic"]

from .defaults import *
from .resnet import *
from .registry import *
from .user import *
from .utils import *
from .lenet import *
from .basic import *
from .vgg import *


================================================
FILE: hparams/basic.py
================================================
import tensorflow as tf

from . import defaults
from .registry import register


# MNIST =========================
@register
def mnist_basic_no_dropout():
  hps = defaults.default()
  hps.model = "basic"
  hps.data = "mnist"
  hps.activation = "relu"
  hps.batch_norm = False
  hps.drop_rate = 0.0
  hps.dropout_type = None
  hps.initializer = "glorot_uniform_initializer"
  hps.layers = [128, 64, 32]
  hps.input_shape = [784]
  hps.output_shape = [10]
  hps.layer_type = "dense"

  hps.learning_rate = 0.1
  hps.optimizer = "momentum"
  hps.momentum = 0.0

  return hps


@register
def mnist_basic_trgtd_dropout():
  hps = mnist_basic_no_dropout()
  hps.drop_rate = 0.5
  hps.dropout_type = "targeted_weight"
  hps.targ_rate = 0.5

  return hps


@register
def mnist_basic_untrgtd_dropout():
  hps = mnist_basic_no_dropout()
  hps.drop_rate = 0.25
  hps.dropout_type = "untargeted_weight"

  return hps


@register
def mnist_basic_trgtd_dropout_random():
  hps = mnist_basic_no_dropout()
  hps.drop_rate = 0.5
  hps.dropout_type = "targeted_weight_random"
  hps.targ_rate = 0.5

  return hps


@register
def mnist_basic_trgtd_unit_dropout():
  hps = mnist_basic_no_dropout()
  hps.drop_rate = 0.5
  hps.dropout_type = "targeted_unit"
  hps.targ_rate = 0.5

  return hps


@register
def mnist_basic_smallify_dropout_1eneg4():
  hps = mnist_basic_no_dropout()
  hps.dropout_type = "smallify_dropout"
  hps.smallify = 1e-4
  hps.smallify_mv = 0.9
  hps.smallify_thresh = 0.5

  return hps


@register
def mnist_basic_smallify_dropout_1eneg3():
  hps = mnist_basic_smallify_dropout_1eneg4()
  hps.smallify = 1e-3

  return hps


@register
def mnist_basic_smallify_weight_dropout_1eneg4():
  hps = mnist_basic_no_dropout()
  hps.dropout_type = "smallify_weight_dropout"
  hps.smallify = 1e-4
  hps.smallify_mv = 0.9
  hps.smallify_thresh = 0.5

  return hps


@register
def cifar10_basic_no_dropout():
  hps = defaults.default()
  hps.model = "basic"
  hps.data = "cifar10"
  hps.activation = "relu"
  hps.batch_norm = False
  hps.drop_rate = 0.0
  hps.dropout_type = None
  hps.initializer = "glorot_uniform_initializer"
  hps.layers = [128, 64, 32]
  hps.channels = 3
  hps.input_shape = [32, 32, 3]
  hps.output_shape = [10]
  hps.layer_type = "dense"

  hps.learning_rate = 0.1
  hps.optimizer = "momentum"
  hps.momentum = 0.0

  return hps


@register
def cifar100_basic_no_dropout():
  hps = cifar10_basic_no_dropout()
  hps.output_shape = [100]
  hps.data = "cifar100"
  return hps


@register
def imagenet32_basic():
  hps = defaults.default_imagenet32()
  hps.model = "basic"
  hps.activation = "relu"
  hps.batch_norm = False
  hps.drop_rate = 0.0
  hps.dropout_type = None
  hps.initializer = "glorot_uniform_initializer"
  hps.layers = [128, 64, 32]
  hps.layer_type = "dense"
  hps.learning_rate = 0.1
  hps.optimizer = "momentum"
  hps.momentum = 0.0
  return hps

================================================
FILE: hparams/defaults.py
================================================
import tensorflow as tf

from .registry import register
from .utils import HParams


@register
def default():
  return HParams(
      model=None,
      data=None,
      shuffle_data=True,
      data_augmentations=None,
      train_epochs=256,
      eval_steps=100,
      type="image",
      batch_size=64,
      learning_rate=0.01,
      lr_scheme="constant",
      initializer="glorot_normal_initializer",
      delay=0,
      staircased=False,
      learning_rate_decay_interval=2000,
      learning_rate_decay_rate=0.1,
      clip_grad_norm=1.0,
      l2_loss=0.0,
      prune_val=0.8,
      label_smoothing=0.1,
      use_tpu=False,
      momentum=0.9,
      init_scheme="random",
      warmup_steps=10000,
      use_nesterov=False,
      louizos_cost=0.0,
      l1_norm=0.0,
      thresh=2.5,
      fixed=False,
      var_scale=1,
      klscale=1.0,
      ard_cost=0.0,
      logit_packing=0.0,
      logit_squeezing=0.0,
      clp=0.0,
      logit_bound=None,
      dropout_type=None,
      smallify=0.0,
      smallify_delay=1000,
      linear_drop_rate=False,
      weight_decay_and_noise=False,
      weight_decay_only_features=True,
      weight_decay_weight_names=["DW", "kernel", "bias"],
      dropout_delay_steps=5000,
      grad_noise_scale=0.0,
      td_nines=0,
      targ_cost=1.0,
      aparams="",
      channels=1,
      data_format="channels_last",
      epoch_size=50000,
  )


@register
def default_cifar10():
  hps = default()
  hps.data = "cifar10"
  hps.data_augmentations = ["cifar_augmentation"]
  hps.epoch_size = 50000  # number of images in train set

  hps.input_shape = [32, 32, 3]
  hps.output_shape = [10]
  hps.channels = 3
  hps.num_classes = 10

  return hps


@register
def default_cifar100():
  hps = default_cifar10()
  hps.data = "cifar100"
  hps.output_shape = [100]
  hps.num_classes = 100

  return hps


@register
def default_imagenet299():
  hps = default()
  hps.data = "imagenet"
  hps.data_augmentations = ["imagenet_augmentation"]
  hps.epoch_size = 1281167

  hps.input_shape = [299, 299, 3]
  hps.channels = 3
  hps.output_shape = [1001]
  hps.num_classes = 1001

  return hps


@register
def default_imagenet224():
  hps = default_imagenet299()
  hps.input_shape = [224, 224, 3]

  return hps


@register
def default_imagenet64():
  hps = default_imagenet299()
  hps.input_shape = [64, 64, 3]

  return hps


@register
def default_imagenet32():
  hps = default_imagenet299()
  hps.input_shape = [32, 32, 3]

  return hps


================================================
FILE: hparams/lenet.py
================================================
import tensorflow as tf

from .defaults import default, default_cifar10
from .registry import register

# lenet


@register
def cifar_lenet():
  hps = default_cifar10()

  hps.model = "lenet"

  hps.activation = "relu"
  hps.residual = True
  hps.initializer = "glorot_normal_initializer"
  hps.kernel_size = 5
  hps.lr_scheme = "constant"
  hps.batch_size = 128

  hps.learning_rate = 0.01
  hps.optimizer = "momentum"
  hps.momentum = 0.9
  hps.use_nesterov = True

  hps.drop_rate = 0.0
  hps.dropout_type = None
  hps.targ_rate = 0.0

  hps.axis_aligned_cost = False
  hps.clp = False
  hps.logit_squeezing = False

  return hps


@register
def cifar_lenet_no_dropout():
  hps = cifar_lenet()
  return hps


@register
def cifar_lenet_weight():
  hps = cifar_lenet_no_dropout()
  hps.dropout_type = "untargeted_weight"
  hps.drop_rate = 0.25
  return hps


@register
def cifar_lenet_trgtd_weight():
  hps = cifar_lenet_no_dropout()
  hps.drop_rate = 0.5
  hps.targ_rate = 0.5
  hps.dropout_type = "targeted_weight"
  return hps


@register
def cifar_lenet_unit():
  hps = cifar_lenet_no_dropout()
  hps.drop_rate = 0.25
  hps.dropout_type = "untargeted_unit"
  return hps


@register
def cifar_lenet_trgtd_unit():
  hps = cifar_lenet_no_dropout()
  hps.drop_rate = 0.5
  hps.targ_rate = 0.5
  hps.dropout_type = "targeted_unit"
  return hps


@register
def cifar_lenet_l1():
  hps = cifar_lenet_no_dropout()
  hps.l1_norm = 0.1
  return hps


@register
def cifar_lenet_trgtd_weight_l1():
  hps = cifar_lenet_no_dropout()
  hps.l1_norm = 0.1
  hps.drop_rate = 0.5
  hps.targ_rate = 0.5
  hps.dropout_type = "targeted_weight"
  return hps


@register
def cifar_lenet_trgtd_unit_l1():
  hps = cifar_lenet_no_dropout()
  hps.l1_norm = 0.1
  hps.drop_rate = 0.5
  hps.targ_rate = 0.5
  hps.dropout_type = "targeted_unit"
  return hps


@register
def cifar_lenet_trgtd_unit_botk75_33():
  hps = cifar_lenet_no_dropout()
  hps.drop_rate = 0.33
  hps.dropout_type = "targeted_unit"
  hps.targ_rate = 0.75
  return hps


@register
def cifar_lenet_trgtd_unit_botk75_66():
  hps = cifar_lenet_no_dropout()
  hps.drop_rate = 0.66
  hps.dropout_type = "targeted_unit"
  hps.targ_rate = 0.75
  return hps


@register
def cifar_lenet_trgtd_weight_botk75_33():
  hps = cifar_lenet_no_dropout()
  hps.drop_rate = 0.33
  hps.dropout_type = "targeted_weight"
  hps.targ_rate = 0.75
  return hps


@register
def cifar_lenet_trgtd_weight_botk75_66():
  hps = cifar_lenet_no_dropout()
  hps.drop_rate = 0.66
  hps.dropout_type = "targeted_weight"
  hps.targ_rate = 0.75
  return hps


@register
def cifar_lenet_louizos_weight_1en3():
  hps = cifar_lenet_no_dropout()
  hps.louizos_beta = 2. / 3.
  hps.louizos_zeta = 1.1
  hps.louizos_gamma = -0.1
  hps.louizos_cost = 0.001
  hps.dropout_type = "louizos_weight"
  hps.drop_rate = 0.25
  return hps


@register
def cifar_lenet_louizos_weight_1en1():
  hps = cifar_lenet_no_dropout()
  hps.louizos_beta = 2. / 3.
  hps.louizos_zeta = 1.1
  hps.louizos_gamma = -0.1
  hps.louizos_cost = 0.1
  hps.dropout_type = "louizos_weight"
  hps.drop_rate = 0.25
  return hps


@register
def cifar_lenet_louizos_weight_1en2():
  hps = cifar_lenet_no_dropout()
  hps.louizos_beta = 2. / 3.
  hps.louizos_zeta = 1.1
  hps.louizos_gamma = -0.1
  hps.louizos_cost = 0.01
  hps.dropout_type = "louizos_weight"
  hps.drop_rate = 0.25
  return hps


@register
def cifar_lenet_louizos_weight_5en3():
  hps = cifar_lenet_no_dropout()
  hps.louizos_beta = 2. / 3.
  hps.louizos_zeta = 1.1
  hps.louizos_gamma = -0.1
  hps.louizos_cost = 0.005
  hps.dropout_type = "louizos_weight"
  hps.drop_rate = 0.25
  return hps


@register
def cifar_lenet_louizos_weight_1en4():
  hps = cifar_lenet_no_dropout()
  hps.louizos_beta = 2. / 3.
  hps.louizos_zeta = 1.1
  hps.louizos_gamma = -0.1
  hps.louizos_cost = 0.0001
  hps.dropout_type = "louizos_weight"
  hps.drop_rate = 0.25
  return hps


@register
def cifar_lenet_louizos_unit_1en3():
  hps = cifar_lenet_no_dropout()
  hps.louizos_beta = 2. / 3.
  hps.louizos_zeta = 1.1
  hps.louizos_gamma = -0.1
  hps.louizos_cost = 0.001
  hps.dropout_type = "louizos_unit"
  hps.drop_rate = 0.25
  return hps


@register
def cifar_lenet_louizos_unit_1en1():
  hps = cifar_lenet_no_dropout()
  hps.louizos_beta = 2. / 3.
  hps.louizos_zeta = 1.1
  hps.louizos_gamma = -0.1
  hps.louizos_cost = 0.1
  hps.dropout_type = "louizos_unit"
  hps.drop_rate = 0.25
  return hps


@register
def cifar_lenet_louizos_unit_1en2():
  hps = cifar_lenet_no_dropout()
  hps.louizos_beta = 2. / 3.
  hps.louizos_zeta = 1.1
  hps.louizos_gamma = -0.1
  hps.louizos_cost = 0.01
  hps.dropout_type = "louizos_unit"
  hps.drop_rate = 0.25
  return hps


@register
def cifar_lenet_louizos_unit_5en3():
  hps = cifar_lenet_no_dropout()
  hps.louizos_beta = 2. / 3.
  hps.louizos_zeta = 1.1
  hps.louizos_gamma = -0.1
  hps.louizos_cost = 0.005
  hps.dropout_type = "louizos_unit"
  hps.drop_rate = 0.25
  return hps


@register
def cifar_lenet_louizos_unit_1en4():
  hps = cifar_lenet_no_dropout()
  hps.louizos_beta = 2. / 3.
  hps.louizos_zeta = 1.1
  hps.louizos_gamma = -0.1
  hps.louizos_cost = 0.0001
  hps.dropout_type = "louizos_unit"
  hps.drop_rate = 0.25
  return hps


@register
def cifar_lenet_variational():
  hps = cifar_lenet_no_dropout()
  hps.dropout_type = "variational"
  hps.var_scale = 1. / 100
  hps.drop_rate = 0.75

  return hps


@register
def cifar_lenet_variational_unscaled():
  hps = cifar_lenet_no_dropout()
  hps.dropout_type = "variational"
  hps.drop_rate = 0.75

  return hps


@register
def cifar_lenet_variational_unit():
  hps = cifar_lenet_no_dropout()
  hps.dropout_type = "variational_unit"
  hps.var_scale = 1. / 100
  hps.drop_rate = 0.75

  return hps


@register
def cifar_lenet_variational_unit_unscaled():
  hps = cifar_lenet_no_dropout()
  hps.dropout_type = "variational_unit"
  hps.drop_rate = 0.75

  return hps


@register
def cifar_lenet_smallify_neg4():
  hps = cifar_lenet_no_dropout()
  hps.dropout_type = "smallify_dropout"
  hps.smallify = 1e-4
  hps.smallify_mv = 0.9
  hps.smallify_thresh = 0.5
  hps.smallify_delay = 10000
  return hps


================================================
FILE: hparams/registry.py
================================================
import tensorflow as tf

_HPARAMS = dict()


def register(fn):
  global _HPARAMS
  _HPARAMS[fn.__name__] = fn()
  return fn


def get_hparams(hparams_list):
  """Fetches a merged group of hyperparameter sets (chronological priority)."""
  final = tf.contrib.training.HParams()
  for name in hparams_list.split("-"):
    curr = _HPARAMS[name]
    final_dict = final.values()
    for k, v in curr.values().items():
      if k not in final_dict:
        final.add_hparam(k, v)
      elif final_dict[k] is None:
        setattr(final, k, v)
  return final
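

if __name__ == "__main__":
  # Illustrative sketch (added for exposition; not in the original file).
  # Sets are merged left to right: the first name's values win, and later
  # names only contribute keys that are missing or still None.
  @register
  def _demo_base():
    hps = tf.contrib.training.HParams()
    hps.add_hparam("learning_rate", 0.1)
    hps.add_hparam("model", None)
    return hps

  @register
  def _demo_fill():
    hps = tf.contrib.training.HParams()
    hps.add_hparam("model", "lenet")
    hps.add_hparam("learning_rate", 0.5)
    return hps

  print(get_hparams("_demo_base-_demo_fill").values())
  # -> {'learning_rate': 0.1, 'model': 'lenet'}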


================================================
FILE: hparams/resnet.py
================================================
import tensorflow as tf

from .registry import register
from .defaults import *


# from https://github.com/tensorflow/models/blob/master/resnet/resnet_main.py
@register
def resnet_default():
  hps = default_cifar10()
  hps.model = "resnet"
  hps.residual_filters = [16, 32, 64, 128]
  hps.residual_units = [5, 5, 5]
  hps.use_bottleneck = False
  hps.batch_size = 128
  hps.learning_rate = 0.4
  hps.lr_scheme = "resnet"
  hps.weight_decay_rate = 2e-4
  hps.optimizer = "momentum"
  return hps


@register
def resnet102_imagenet224():
  hps = default_imagenet224()
  hps.model = "resnet"
  hps.residual_filters = [64, 64, 128, 256, 512]
  hps.residual_units = [3, 4, 23, 3]
  hps.use_bottleneck = True
  hps.batch_size = 128 * 8
  hps.learning_rate = 0.128 * hps.batch_size / 256.
  hps.lr_scheme = "warmup_cosine"
  hps.warmup_steps = 10000
  hps.weight_decay_rate = 1e-4
  hps.optimizer = "momentum"
  hps.use_nesterov = True
  hps.initializer = "variance_scaling_initializer"
  hps.learning_rate_cosine_cycle_steps = 120000
  hps.cosine_alpha = 0.0
  return hps


@register
def resnet102_imagenet64():
  hps = resnet102_imagenet224()
  hps.input_shape = [64, 64, 3]
  return hps


@register
def resnet50_imagenet224():
  hps = resnet102_imagenet224()
  hps.residual_units = [3, 4, 6, 3]
  return hps


@register
def resnet34_imagenet224():
  hps = resnet50_imagenet224()
  hps.use_bottleneck = False
  return hps


@register
def resnet_cifar100():
  hps = resnet_default()
  hps.num_classes = 100
  return hps


@register
def cifar10_resnet32():
  hps = resnet_default()

  return hps


@register
def cifar10_resnet32_no_dropout():
  hps = cifar10_resnet32()
  hps.drop_rate = 0.0

  return hps


@register
def cifar10_resnet32_trgtd_weight():
  hps = cifar10_resnet32_no_dropout()
  hps.drop_rate = 0.5
  hps.dropout_type = "targeted_weight"
  hps.targ_rate = 0.5

  return hps


@register
def cifar10_resnet32_weight():
  hps = cifar10_resnet32_no_dropout()
  hps.drop_rate = 0.25
  hps.dropout_type = "untargeted_weight"

  return hps


@register
def cifar10_resnet32_weight_50():
  hps = cifar10_resnet32_weight()
  hps.drop_rate = 0.50

  return hps


@register
def cifar10_resnet32_trgtd_unit():
  hps = cifar10_resnet32_no_dropout()
  hps.drop_rate = 0.5
  hps.dropout_type = "targeted_unit"
  hps.targ_rate = 0.5

  return hps


@register
def cifar10_resnet32_trgtd_ard():
  hps = cifar10_resnet32_no_dropout()
  hps.drop_rate = 0.25
  hps.dropout_type = "targeted_ard"
  hps.targ_rate = 0.5

  return hps


@register
def cifar10_resnet32_unit():
  hps = cifar10_resnet32_no_dropout()
  hps.drop_rate = 0.25
  hps.dropout_type = "untargeted_unit"

  return hps


@register
def cifar10_resnet32_unit_50():
  hps = cifar10_resnet32_unit()
  hps.drop_rate = 0.50

  return hps


@register
def cifar10_resnet32_l1_1eneg3():
  hps = cifar10_resnet32_no_dropout()
  hps.l1_norm = 0.001

  return hps


@register
def cifar10_resnet32_l1_1eneg2():
  hps = cifar10_resnet32_no_dropout()
  hps.l1_norm = 0.01

  return hps


@register
def cifar10_resnet32_l1_1eneg1():
  hps = cifar10_resnet32_no_dropout()
  hps.l1_norm = 0.1

  return hps


@register
def cifar10_resnet32_trgted_weight_l1():
  hps = cifar10_resnet32_no_dropout()
  hps.drop_rate = 0.5
  hps.dropout_type = "targeted_weight"
  hps.targ_rate = 0.5
  hps.l1_norm = 0.1

  return hps


@register
def cifar10_resnet32_targeted_unit_l1():
  hps = cifar10_resnet32_no_dropout()
  hps.drop_rate = 0.5
  hps.dropout_type = "targeted_unit"
  hps.targ_rate = 0.5
  hps.l1_norm = 0.1

  return hps


@register
def cifar10_resnet32_trgtd_unit_botk75_33():
  hps = cifar10_resnet32_no_dropout()
  hps.drop_rate = 0.33
  hps.dropout_type = "targeted_unit"
  hps.targ_rate = 0.75

  return hps


@register
def cifar10_resnet32_trgtd_unit_botk75_66():
  hps = cifar10_resnet32_no_dropout()
  hps.drop_rate = 0.66
  hps.dropout_type = "targeted_unit"
  hps.targ_rate = 0.75

  return hps


@register
def cifar10_resnet32_trgtd_weight_botk75_33():
  hps = cifar10_resnet32_no_dropout()
  hps.drop_rate = 0.33
  hps.dropout_type = "targeted_weight"
  hps.targ_rate = 0.75

  return hps


@register
def cifar10_resnet32_trgtd_weight_botk75_66():
  hps = cifar10_resnet32_no_dropout()
  hps.drop_rate = 0.66
  hps.dropout_type = "targeted_weight"
  hps.targ_rate = 0.75

  return hps


@register
def cifar10_resnet32_trgtd_unit_ramping_botk90_99():
  hps = cifar10_resnet32_no_dropout()
  hps.drop_rate = 0.99
  hps.dropout_type = "targeted_unit_piecewise"
  hps.targ_rate = 0.90

  return hps


@register
def cifar10_resnet32_trgtd_weight_ramping_botk99_99():
  hps = cifar10_resnet32_no_dropout()
  hps.drop_rate = 0.99
  hps.dropout_type = "targeted_weight_piecewise"
  hps.targ_rate = 0.99
  hps.linear_drop_rate = True

  return hps


@register
def cifar10_resnet32_louizos_weight_1en3():
  hps = cifar10_resnet32_no_dropout()
  hps.louizos_beta = 2. / 3.
  hps.louizos_zeta = 1.1
  hps.louizos_gamma = -0.1
  hps.louizos_cost = 0.001
  hps.dropout_type = "louizos_weight"
  hps.drop_rate = 0.001

  return hps


@register
def cifar10_resnet32_louizos_weight_1en1():
  hps = cifar10_resnet32_louizos_weight_1en3()
  hps.louizos_cost = 0.1
  hps.dropout_type = "louizos_weight"

  return hps


@register
def cifar10_resnet32_louizos_weight_1en2():
  hps = cifar10_resnet32_louizos_weight_1en3()
  hps.louizos_cost = 0.01

  return hps


@register
def cifar10_resnet32_louizos_weight_5en3():
  hps = cifar10_resnet32_louizos_weight_1en3()
  hps.louizos_cost = 0.005

  return hps


@register
def cifar10_resnet32_louizos_weight_1en4():
  hps = cifar10_resnet32_louizos_weight_1en3()
  hps.louizos_cost = 0.0001

  return hps


@register
def cifar10_resnet32_louizos_unit_1en3():
  hps = cifar10_resnet32_no_dropout()
  hps.louizos_beta = 2. / 3.
  hps.louizos_zeta = 1.1
  hps.louizos_gamma = -0.1
  hps.louizos_cost = 0.001
  hps.dropout_type = "louizos_unit"
  hps.drop_rate = 0.001

  return hps


@register
def cifar10_resnet32_louizos_unit_1en1():
  hps = cifar10_resnet32_louizos_unit_1en3()
  hps.louizos_cost = 0.1

  return hps


@register
def cifar10_resnet32_louizos_unit_1en2():
  hps = cifar10_resnet32_louizos_unit_1en3()
  hps.louizos_cost = 0.01

  return hps


@register
def cifar10_resnet32_louizos_unit_5en3():
  hps = cifar10_resnet32_louizos_unit_1en3()
  hps.louizos_cost = 0.005

  return hps


@register
def cifar10_resnet32_louizos_unit_1en4():
  hps = cifar10_resnet32_louizos_unit_1en3()
  hps.louizos_cost = 0.0001

  return hps


@register
def cifar10_resnet32_louizos_unit_1en5():
  hps = cifar10_resnet32_louizos_unit_1en3()
  hps.louizos_cost = 0.00001

  return hps


@register
def cifar10_resnet32_louizos_unit_1en6():
  hps = cifar10_resnet32_louizos_unit_1en3()
  hps.louizos_cost = 0.000001

  return hps


@register
def cifar10_resnet32_variational_weight():
  hps = cifar10_resnet32_no_dropout()
  hps.dropout_type = "variational"
  hps.drop_rate = 0.75
  hps.thresh = 3
  hps.var_scale = 1. / 100
  hps.weight_decay_rate = None

  return hps


@register
def cifar10_resnet32_variational_weight_unscaled():
  hps = cifar10_resnet32_no_dropout()
  hps.dropout_type = "variational"
  hps.drop_rate = 0.75
  hps.thresh = 3
  hps.var_scale = 1
  hps.weight_decay_rate = None

  return hps


@register
def cifar10_resnet32_variational_unit():
  hps = cifar10_resnet32_no_dropout()
  hps.dropout_type = "variational_unit"
  hps.drop_rate = 0.75
  hps.thresh = 3
  hps.var_scale = 1. / 100
  hps.weight_decay_rate = None

  return hps


@register
def cifar10_resnet32_variational_unit_unscaled():
  hps = cifar10_resnet32_no_dropout()
  hps.dropout_type = "variational_unit"
  hps.drop_rate = 0.75
  hps.thresh = 3
  hps.var_scale = 1
  hps.weight_decay_rate = None

  return hps


@register
def cifar10_resnet32_smallify_1eneg4():
  hps = cifar10_resnet32_no_dropout()
  hps.dropout_type = "smallify_dropout"
  hps.smallify = 1e-4
  hps.smallify_mv = 0.9
  hps.smallify_thresh = 0.5

  return hps


@register
def cifar10_resnet32_smallify_1eneg3():
  hps = cifar10_resnet32_smallify_1eneg4()
  hps.smallify = 1e-3

  return hps


@register
def cifar10_resnet32_smallify_1eneg5():
  hps = cifar10_resnet32_smallify_1eneg4()
  hps.smallify = 1e-5

  return hps


@register
def cifar10_resnet32_smallify_1eneg6():
  hps = cifar10_resnet32_smallify_1eneg4()
  hps.smallify = 1e-6

  return hps


@register
def cifar10_resnet32_smallify_weight_1eneg4():
  hps = cifar10_resnet32_no_dropout()
  hps.dropout_type = "smallify_weight_dropout"
  hps.smallify = 1e-4
  hps.smallify_mv = 0.9
  hps.smallify_thresh = 0.5

  return hps


@register
def cifar10_resnet32_smallify_weight_1eneg3():
  hps = cifar10_resnet32_smallify_weight_1eneg4()
  hps.smallify = 1e-3

  return hps


@register
def cifar10_resnet32_smallify_weight_1eneg5():
  hps = cifar10_resnet32_smallify_weight_1eneg3()
  hps.smallify = 1e-5

  return hps


@register
def cifar10_resnet32_smallify_weight_1eneg6():
  hps = cifar10_resnet32_smallify_weight_1eneg3()
  hps.smallify = 1e-6

  return hps


# ================================


================================================
FILE: hparams/user.py
================================================
import tensorflow as tf

from .defaults import default
from .registry import register

# Add experimental hparams below


================================================
FILE: hparams/utils.py
================================================
import tensorflow as tf


class HParams(tf.contrib.training.HParams):
  """Override of TensorFlow's HParams.

  Replaces HParams.add_hparam(name, value) with simple attribute assignment.
    I.e. There is no need to explicitly add an hparam:
      Replace: `hparams.add_hparam("learning_rate", 0.1)`
      With:    `hparams.learning_rate = 0.1`
  """

  def __setattr__(self, name, value):
    """Adds {name, value} pair to hyperparameters.

    Args:
      name: Name of the hyperparameter.
      value: Value of the hyperparameter. Can be one of the following types:
        int, float, string, int list, float list, or string list.

    Raises:
      ValueError: if one of the arguments is invalid.
    """
    # Keys in kwargs are unique, but 'name' could be the name of a pre-existing
    # attribute of this object.  In that case we refuse to use it as a
    # hyperparameter name.
    if name[0] == "_":
      object.__setattr__(self, name, value)
      return
    if isinstance(value, (list, tuple)):
      if not value:
        raise ValueError(
            'Multi-valued hyperparameters cannot be empty: %s' % name)
      self._hparam_types[name] = (type(value[0]), True)
    else:
      self._hparam_types[name] = (type(value), False)
    object.__setattr__(self, name, value)
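

if __name__ == "__main__":
  # Illustrative sketch (added for exposition; not in the original file):
  # plain attribute assignment both registers and sets a hyperparameter.
  hps = HParams()
  hps.learning_rate = 0.1     # instead of hps.add_hparam("learning_rate", 0.1)
  hps.layers = [128, 64, 32]  # list values register as multi-valued hparams
  print(hps.values())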


================================================
FILE: hparams/vgg.py
================================================
import tensorflow as tf

from .registry import register
from .defaults import default, default_cifar10


# from https://github.com/tensorflow/models/blob/master/resnet/resnet_main.py
@register
def vgg16_default():
  vgg_default = default_cifar10()
  vgg_default.initializer = "glorot_uniform_initializer"
  vgg_default.model = "vgg"
  vgg_default.learning_rate = 0.01
  vgg_default.lr_scheme = "constant"
  vgg_default.weight_decay_rate = 0.0005
  vgg_default.num_classes = 10
  vgg_default.optimizer = "adam"
  vgg_default.adam_epsilon = 1e-6
  vgg_default.beta1 = 0.85
  vgg_default.beta2 = 0.997
  vgg_default.input_shape = [32, 32, 3]
  vgg_default.output_shape = [10]
  return vgg_default


@register
def cifar10_vgg16():
  hps = vgg16_default()
  hps.data = "cifar10"
  return hps


@register
def cifar100_vgg16_no_dropout():
  hps = vgg16_default()
  hps.data = "cifar100"

  hps.input_shape = [32, 32, 3]
  hps.output_shape = [100]
  hps.num_classes = 100
  hps.channels = 3
  hps.learning_rate = 0.0001
  return hps


@register
def cifar10_vgg16_no_dropout():
  hps = vgg16_default()
  hps.data = "cifar10"

  hps.input_shape = [32, 32, 3]
  hps.output_shape = [10]
  hps.num_classes = 10
  hps.channels = 3
  hps.learning_rate = 0.0001
  return hps


@register
def cifar100_vgg16_targeted_dropout():
  hps = cifar100_vgg16_no_dropout()
  hps.drop_rate = 0.5
  hps.dropout_type = "targeted_weight"
  hps.targ_rate = 0.5
  return hps


@register
def cifar100_vgg16_untargeted_dropout():
  hps = cifar100_vgg16_no_dropout()
  hps.drop_rate = 0.25
  hps.dropout_type = "untargeted_weight"
  return hps


@register
def cifar100_vgg16_untargeted_unit_dropout():
  hps = cifar100_vgg16_no_dropout()
  hps.drop_rate = 0.25
  hps.dropout_type = "untargeted_unit"
  return hps


@register
def cifar100_vgg16_targeted_unit_dropout():
  hps = cifar100_vgg16_no_dropout()
  hps.drop_rate = 0.5
  hps.dropout_type = "targeted_unit"
  hps.targ_rate = 0.5
  return hps


@register
def cifar100_vgg16_targeted_unit_dropout_botk75_66():
  hps = cifar100_vgg16_targeted_unit_dropout()
  hps.drop_rate = 0.66
  hps.targ_rate = 0.75
  return hps


@register
def cifar100_vgg16_louizos_unit():
  hps = cifar100_vgg16_no_dropout()
  hps.louizos_beta = 2. / 3.
  hps.louizos_zeta = 1.1
  hps.louizos_gamma = -0.1
  hps.louizos_cost = 0.001
  hps.dropout_type = "louizos_unit"
  hps.drop_rate = 0.25

  return hps


@register
def cifar100_vgg16_louizos_weight():
  hps = cifar100_vgg16_louizos_unit()
  hps.dropout_type = "louizos_weight"

  return hps


@register
def cifar100_vgg16_variational_unscaled():
  hps = cifar100_vgg16_no_dropout()
  hps.dropout_type = "variational"
  hps.drop_rate = 0.75
  hps.thresh = 3
  hps.var_scale = 1
  hps.weight_decay_rate = 0.0

  return hps


@register
def cifar100_vgg16_variational():
  hps = cifar100_vgg16_variational_unscaled()
  hps.var_scale = 1. / 100

  return hps


@register
def cifar100_vgg16_variational_unit_unscaled():
  hps = cifar100_vgg16_variational_unscaled()
  hps.dropout_type = "variational_unit"

  return hps


@register
def cifar100_vgg16_variational_unit():
  hps = cifar100_vgg16_variational_unit_unscaled()
  hps.var_scale = 1. / 100

  return hps


@register
def cifar100_vgg16_smallify_1eneg4():
  hps = cifar100_vgg16_no_dropout()
  hps.dropout_type = "smallify_dropout"
  hps.smallify = 1e-4
  hps.smallify_mv = 0.9
  hps.smallify_thresh = 0.5

  return hps


@register
def cifar100_vgg16_smallify_weight_1eneg5():
  hps = cifar100_vgg16_smallify_1eneg4()
  hps.dropout_type = "smallify_weight_dropout"
  hps.smallify = 1e-5

  return hps


================================================
FILE: models/__init__.py
================================================
__all__ = ["basic", "registry", "resnet", "lenet", "vgg"]

from .basic import *
from .resnet import *
from .registry import *
from .lenet import *
from .vgg import *


================================================
FILE: models/basic/__init__.py
================================================
__all__ = ["basic"]

from .basic import *


================================================
FILE: models/basic/basic.py
================================================
import tensorflow as tf

from ..registry import register

from ..utils.activations import get_activation
from ..utils.initializations import get_init
from ..utils.optimizers import get_optimizer
from ..utils import model_utils


@register("basic")
def get_basic(params, lr):
  """Callable model function compatible with Experiment API.

  Args:
    params: a HParams object containing values for fields:
    lr: learning rate variable
  """

  def basic(features, labels, mode, _):
    """The basic neural net net template.

    Args:
      features: a dict containing key "inputs"
      mode: training, evaluation or infer
    """
    with tf.variable_scope("basic", initializer=get_init(params)):
      is_training = mode == tf.estimator.ModeKeys.TRAIN
      actvn = get_activation(params)
      x = features["inputs"]
      batch_size = tf.shape(x)[0]

      nonzero = 0
      activations = []
      for i, feature_count in enumerate(params.layers):
        with tf.variable_scope("layer_%d" % i):
          if params.layer_type == "dense":
            x, w = model_utils.collect_vars(
                lambda: model_utils.dense(x, feature_count, params, is_training)
            )
          elif params.layer_type == "conv":
            x, w = model_utils.collect_vars(lambda: tf.layers.conv2d(
                x, feature_count, params.kernel_size, padding="SAME"))
          if params.batch_norm:
            x = tf.layers.batch_normalization(x, training=is_training)
          x = actvn(x)
          activations.append(x)
      x = tf.reshape(x, [batch_size, params.layers[-1]])
      with tf.variable_scope('logit'):
        x = tf.layers.dense(x, params.output_shape[0], use_bias=False)

      if mode in [model_utils.ModeKeys.PREDICT, model_utils.ModeKeys.ATTACK]:
        predictions = {
            'classes': tf.argmax(x, axis=1),
            'logits': x,
            'probabilities': tf.nn.softmax(x, name='softmax_tensor'),
        }
        return tf.estimator.EstimatorSpec(mode, predictions=predictions)

      loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=x)
      if params.smallify > 0.0:
        loss += model_utils.switch_loss() * params.smallify

      # Summaries
      # ========================
      if not params.use_tpu:
        tf.summary.scalar("nonzero", model_utils.nonzero_count())
        tf.summary.scalar("percent_sparsity", model_utils.percent_sparsity())
      # ========================

      return model_utils.model_top(labels, tf.nn.softmax(x, -1), loss, lr,
                                   mode, params)

  return basic


================================================
FILE: models/lenet/__init__.py
================================================
__all__ = ["lenet"]


================================================
FILE: models/lenet/lenet.py
================================================
import tensorflow as tf

from ..registry import register

from ..utils.activations import get_activation
from ..utils.dropouts import get_dropout
from ..utils.initializations import get_init
from ..utils.optimizers import get_optimizer
from ..utils import model_utils


@register("lenet")
def get_lenet(hparams, lr):
  """Callable model function compatible with Experiment API.

    Args:
      params: a HParams object containing values for fields:
      lr: learning rate variable
    """

  def _conv(name, x, filter_size, in_filters, out_filters, strides, mode):
    """Convolution."""
    with tf.variable_scope(name):
      kernel = tf.get_variable(
          'DW', [filter_size, filter_size, in_filters, out_filters],
          tf.float32)
      is_training = mode == tf.estimator.ModeKeys.TRAIN
      if hparams.dropout_type is not None:
        dropout_fn = get_dropout(hparams.dropout_type)
        kernel = dropout_fn(kernel, hparams, is_training)

        # special case for variational dropout: sample the conv output with
        # the local reparameterization trick (mean + Gaussian noise * std)
        if hparams.dropout_type and "variational" in hparams.dropout_type:
          kernel, log_alpha = kernel[0], kernel[1]
          if is_training:
            conved_mu = tf.nn.conv2d(
                x, kernel, strides=strides, padding='VALID')
            conved_si = tf.sqrt(
                tf.nn.conv2d(
                    tf.square(x),
                    tf.exp(log_alpha) * tf.square(kernel),
                    strides=strides,
                    padding='VALID') + 1e-8)
            # returning a (tensor, count) tuple here would break the callers,
            # which add a bias directly to the result
            return conved_mu + tf.random_normal(
                tf.shape(conved_mu)) * conved_si

      return tf.nn.conv2d(x, kernel, strides, padding='VALID')

  def lenet(features, labels, mode, params):
    """The lenet neural net net template.

            Args:
              features: a dict containing key "inputs"
              mode: training, evaluation or infer
            """
    with tf.variable_scope("lenet", initializer=get_init(hparams)):
      is_training = mode == tf.estimator.ModeKeys.TRAIN
      actvn = get_activation(hparams)

      if hparams.use_tpu and 'batch_size' in params.keys():
        hparams.batch_size = params['batch_size']

      # input layer
      x = features["inputs"]
      x = model_utils.standardize_images(x)

      # unflatten
      x = tf.reshape(x, [hparams.batch_size] + hparams.input_shape)

      # conv1
      b_conv1 = tf.get_variable(
          "Variable", initializer=tf.constant_initializer(0.1), shape=[6])
      h_conv1 = _conv('conv1', x, 5, 3, 6, [1, 1, 1, 1], mode) + b_conv1
      h_conv1 = tf.nn.relu(h_conv1)
      h_pool1 = tf.nn.max_pool(
          h_conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

      # conv2
      b_conv2 = tf.get_variable(
          "Variable_1", initializer=tf.constant_initializer(0.1), shape=[16])
      h_conv2 = _conv('conv2', h_pool1, 5, 6, 16, [1, 1, 1, 1], mode) + b_conv2
      h_conv2 = tf.nn.relu(h_conv2)
      h_pool2 = tf.nn.max_pool(
          h_conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

      # flatten for fc
      h_pool2_flat = tf.reshape(h_pool2, [hparams.batch_size, -1])

      # fc1
      with tf.variable_scope('fc1'):
        h_fc1 = tf.nn.relu(
            model_utils.dense(h_pool2_flat, 500, hparams, is_training))

      # fc2
      with tf.variable_scope('fc2'):
        y = model_utils.dense(h_fc1, 10, hparams, is_training, dropout=False)

      if mode in [model_utils.ModeKeys.PREDICT, model_utils.ModeKeys.ATTACK]:
        predictions = {
            'classes': tf.argmax(y, axis=1),
            'logits': y,
            'probabilities': tf.nn.softmax(y, name='softmax_tensor'),
        }

        return tf.estimator.EstimatorSpec(mode, predictions=predictions)

      loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=y)

      if hparams.axis_aligned_cost:
        negativity_cost, axis_alignedness_cost, one_bound = model_utils.axis_aligned_cost(
            y, hparams)
        masked_max = tf.abs(y) * (
            1 - tf.one_hot(tf.argmax(tf.abs(y), -1), hparams.num_classes))
        tf.summary.scalar(
            "logit_prior",
            tf.reduce_mean(
                tf.to_float(
                    tf.logical_and(masked_max >= 0.0, masked_max <= 0.1))))
        tf.summary.scalar("avg_max",
                          tf.reduce_mean(tf.reduce_max(tf.abs(y), axis=-1)))
        loss += hparams.axis_aligned_cost * tf.reduce_mean(
            negativity_cost + axis_alignedness_cost + 20. * one_bound)

      if hparams.logit_squeezing:
        loss += hparams.logit_squeezing * tf.reduce_mean(y**2)

      if hparams.clp:
        loss += hparams.clp * tf.reduce_mean(
            (y[:hparams.batch_size // 2] - y[hparams.batch_size // 2:])**2)

      if hparams.dropout_type and "variational" in hparams.dropout_type:
        # prior DKL part of the ELBO
        graph = tf.get_default_graph()
        node_defs = [
            n for n in graph.as_graph_def().node if 'log_alpha' in n.name
        ]
        log_alphas = [
            graph.get_tensor_by_name(n.name + ":0") for n in node_defs
        ]
        divergences = [model_utils.dkl_qp(la) for la in log_alphas]
        # combine to form the ELBO
        N = float(50000)  # training-set size
        dkl = tf.reduce_sum(tf.stack(divergences))

        warmup_steps = 50000
        inv_base = tf.exp(tf.log(0.01) / warmup_steps)
        inv_decay = inv_base**(
            warmup_steps - tf.to_float(tf.train.get_global_step()))

        loss += (1. / N) * dkl * inv_decay * hparams.var_scale

      if hparams.smallify > 0.0:
        loss += model_utils.switch_loss() * hparams.smallify

      return model_utils.model_top(labels, tf.nn.softmax(y, -1), loss, lr,
                                   mode, hparams)

  return lenet


================================================
FILE: models/registry.py
================================================
from ..training.lr_schemes import get_lr

import tensorflow as tf

_MODELS = dict()


def register(name):

  def add_to_dict(fn):
    global _MODELS
    _MODELS[name] = fn
    return fn

  return add_to_dict


def get_model(hparams):

  def model_fn(features, labels, mode, params=None):
    lr = tf.constant(0.0)
    if mode == tf.estimator.ModeKeys.TRAIN:
      lr = get_lr(hparams)
    return _MODELS[hparams.model](hparams, lr)(features, labels, mode, params)

  return model_fn
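
# Usage sketch (illustrative, not part of the original file): model builders
# register themselves under a string key, and `get_model` wraps the chosen
# builder in an Estimator-compatible model_fn. The name "my_model" below is
# hypothetical:
#
#   @register("my_model")
#   def get_my_model(hparams, lr):
#     def my_model(features, labels, mode, params):
#       ...  # build the graph and return an EstimatorSpec
#     return my_model
#
#   model_fn = get_model(hparams)  # dispatches on hparams.model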


================================================
FILE: models/resnet/__init__.py
================================================
__all__ = ["resnet"]


================================================
FILE: models/resnet/resnet.py
================================================
import tensorflow as tf
import numpy as np

from ..utils import dropouts
from ..utils.activations import get_activation
from ..utils.dropouts import get_dropout, smallify_dropout
from ..utils.initializations import get_init
from ..registry import register
from ..utils import model_utils
from ..utils.model_utils import ModeKeys
from ...training import tpu


@register("resnet")
def get_resnet(hparams, lr):
  """Callable model function compatible with Experiment API.

          Args:
            params: a HParams object containing values for fields:
              use_bottleneck: bool to bottleneck the network
              num_residual_units: number of residual units
              num_classes: number of classes
              batch_size: batch size
              weight_decay_rate: weight decay rate
          """

  def resnet(features, labels, mode, params):
    if hparams.use_tpu and 'batch_size' in params.keys():
      hparams.batch_size = params['batch_size']

    is_training = mode == tf.estimator.ModeKeys.TRAIN

    def _residual(x, out_filter, stride, projection=False):
      """Residual unit with 2 sub layers."""
      is_variational = (hparams.dropout_type is not None and
                        "variational" in hparams.dropout_type)

      orig_x = x
      if not is_variational:
        x = model_utils.batch_norm(x, hparams, is_training)
        x = tf.nn.relu(x)

      if projection:
        orig_x = model_utils.conv(
            x,
            1,
            out_filter,
            hparams,
            is_training=is_training,
            strides=stride,
            name="shortcut")

      with tf.variable_scope('sub1'):
        x = model_utils.conv(
            x,
            3,
            out_filter,
            hparams,
            is_training=is_training,
            strides=stride,
            name='conv1')

        x = model_utils.batch_norm(x, hparams, is_training)
        x = tf.nn.relu(x)

      with tf.variable_scope('sub2'):
        x = model_utils.conv(
            x,
            3,
            out_filter,
            hparams,
            is_training=is_training,
            strides=[1, 1, 1, 1],
            name='conv2')

      x += orig_x

      return x

    def _bottleneck_residual(x, out_filter, stride, projection=False):
      """Residual unit with 3 sub layers."""

      is_variational = (hparams.dropout_type is not None and
                        "variational" in hparams.dropout_type)

      orig_x = x
      if not is_variational:
        x = model_utils.batch_norm(x, hparams, is_training)
        x = tf.nn.relu(x)

      if projection:
        orig_x = model_utils.conv(
            x,
            1,
            4 * out_filter,
            hparams,
            is_training=is_training,
            strides=stride,
            name="shortcut")

      with tf.variable_scope('sub1'):
        x = model_utils.conv(
            x,
            1,
            out_filter,
            hparams,
            is_training=is_training,
            strides=[1, 1, 1, 1],
            name='conv1')
        x = model_utils.batch_norm(x, hparams, is_training)
        x = tf.nn.relu(x)
      with tf.variable_scope('sub2'):
        x = model_utils.conv(
            x,
            3,
            out_filter,
            hparams,
            is_training=is_training,
            strides=stride,
            name='conv2')
        x = model_utils.batch_norm(x, hparams, is_training)
        x = tf.nn.relu(x)
      with tf.variable_scope('sub3'):
        x = model_utils.conv(
            x,
            1,
            4 * out_filter,
            hparams,
            is_training=is_training,
            strides=[1, 1, 1, 1],
            name='conv3')

      return orig_x + x

    def _l1():
      """L1 weight decay loss."""
      if hparams.l1_norm == 0:
        return 0

      costs = []
      for var in tf.trainable_variables():
        if "DW" in var.name and "logit" not in var.name:
          costs.append(tf.reduce_mean(tf.abs(var)))

      return tf.multiply(hparams.l1_norm, tf.add_n(costs))

    def _fully_connected(x, out_dim):
      """FullyConnected layer for final output."""
      prev_dim = np.product(x.get_shape().as_list()[1:])
      x = tf.reshape(x, [hparams.batch_size, prev_dim])
      w = tf.get_variable('DW', [prev_dim, out_dim])
      b = tf.get_variable(
          'biases', [out_dim], initializer=tf.zeros_initializer())
      return tf.nn.xw_plus_b(x, w, b)

    def _global_avg_pool(x):
      assert x.get_shape().ndims == 4
      if hparams.data_format == "channels_last":
        return tf.reduce_mean(x, [1, 2])

      return tf.reduce_mean(x, [2, 3])

    def _stride_arr(stride):
      """Map a stride scalar to the stride array for tf.nn.conv2d."""
      if hparams.data_format == "channels_last":
        return [1, stride, stride, 1]

      return [1, 1, stride, stride]

    if mode == ModeKeys.PREDICT or mode == ModeKeys.ATTACK:
      if "labels" in features:
        labels = features["labels"]

    with tf.variable_scope("resnet", initializer=get_init(hparams)):
      hparams.mode = mode
      strides = [1, 2, 2, 2]
      res_func = (_residual
                  if not hparams.use_bottleneck else _bottleneck_residual)
      filters = hparams.residual_filters
      large_input = hparams.input_shape[0] > 32

      # initial 7x7 convolution; filter counts come from
      # hparams.residual_filters
      with tf.variable_scope('init'):
        x = features["inputs"]
        stride = _stride_arr(2) if large_input else _stride_arr(1)
        x = model_utils.conv(
            x,
            7,
            filters[0],
            hparams,
            strides=stride,
            dropout=False,
            name='init_conv')

        if large_input:
          x = tf.layers.max_pooling2d(
              inputs=x,
              pool_size=3,
              strides=2,
              padding="SAME",
              data_format=hparams.data_format)

      with tf.variable_scope('unit_1_0'):
        x = res_func(x, filters[1], _stride_arr(strides[0]), True)

      for i in range(1, hparams.residual_units[0]):
        with tf.variable_scope('unit_1_%d' % i):
          x = res_func(x, filters[1], _stride_arr(1), False)

      with tf.variable_scope('unit_2_0'):
        x = res_func(x, filters[2], _stride_arr(strides[1]), True)

      for i in range(1, hparams.residual_units[1]):
        with tf.variable_scope('unit_2_%d' % i):
          x = res_func(x, filters[2], _stride_arr(1), False)

      with tf.variable_scope('unit_3_0'):
        x = res_func(x, filters[3], _stride_arr(strides[2]), True)

      for i in range(1, hparams.residual_units[2]):
        with tf.variable_scope('unit_3_%d' % i):
          x = res_func(x, filters[3], _stride_arr(1), False)

      if len(filters) == 5:
        with tf.variable_scope('unit_4_0'):
          x = res_func(x, filters[4], _stride_arr(strides[3]), True)

        for i in range(1, hparams.residual_units[3]):
          with tf.variable_scope('unit_4_%d' % i):
            x = res_func(x, filters[4], _stride_arr(1), False)

      x = model_utils.batch_norm(x, hparams, is_training)
      x = tf.nn.relu(x)

      with tf.variable_scope('unit_last'):
        x = _global_avg_pool(x)

      with tf.variable_scope('logit'):
        logits = _fully_connected(x, hparams.num_classes)
        predictions = tf.nn.softmax(logits)

      if mode in [ModeKeys.PREDICT, ModeKeys.ATTACK]:

        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions={
                'classes': tf.argmax(predictions, axis=1),
                'logits': logits,
                'probabilities': predictions,
            })

      with tf.variable_scope('costs'):
        xent = tf.losses.sparse_softmax_cross_entropy(
            labels=labels, logits=logits)
        cost = tf.reduce_mean(xent, name='xent')
        if is_training:
          cost += model_utils.weight_decay(hparams)
          cost += _l1()

          if hparams.dropout_type is not None:
            if "louizos" in hparams.dropout_type:
              cost += hparams.louizos_cost * model_utils.louizos_complexity_cost(
                  hparams) / 50000

            if "variational" in hparams.dropout_type:
              # prior DKL part of the ELBO
              graph = tf.get_default_graph()
              node_defs = [
                  n for n in graph.as_graph_def().node if 'log_alpha' in n.name
              ]
              log_alphas = [
                  graph.get_tensor_by_name(n.name + ":0") for n in node_defs
              ]
              tf.logging.info("found %d log_alpha tensors: %s" %
                              (len(log_alphas), [n.name for n in node_defs]))
              divergences = [dropouts.dkl_qp(la) for la in log_alphas]
              # combine to form the ELBO
              N = float(50000)  # training-set size
              dkl = tf.reduce_sum(tf.stack(divergences))

              warmup_steps = 50000
              dkl = (1. / N) * dkl * tf.minimum(
                  1.0,
                  tf.to_float(tf.train.get_global_step()) /
                  warmup_steps) * hparams.var_scale
              cost += dkl
              tf.summary.scalar("dkl", dkl)

          if hparams.ard_cost > 0.0:
            cost += model_utils.ard_cost() * hparams.ard_cost

          if hparams.smallify > 0.0:
            cost += model_utils.switch_loss() * hparams.smallify

    # Summaries
    # ========================
    tf.summary.scalar("total_nonzero", model_utils.nonzero_count())
    all_weights = tf.concat(
        [
            tf.reshape(v, [-1])
            for v in tf.trainable_variables()
            if "DW" in v.name
        ],
        axis=0)
    tf.summary.histogram("weights", all_weights)
    # ========================

    return model_utils.model_top(labels, predictions, cost, lr, mode, hparams)

  return resnet


================================================
FILE: models/utils/__init__.py
================================================
__all__ = [
    "activations", "dropouts", "initializations", "model_utils", "optimizers"
]

from .activations import *
from .dropouts import *
from .initializations import *
from .model_utils import *
from .optimizers import *


================================================
FILE: models/utils/activations.py
================================================
import tensorflow as tf

_ACTIVATION = dict()


def register(name):

  def add_to_dict(fn):
    global _ACTIVATION
    _ACTIVATION[name] = fn
    return fn

  return add_to_dict


def get_activation(params):
  return _ACTIVATION[params.activation](params)


@register("relu")
def relu(params):
  return tf.nn.relu


@register("brelu")
def brelu(params):

  def fn(a):
    idx = tf.range(a.shape[-1])
    idx = tf.mod(idx, 2)
    idx = tf.cast(idx, tf.bool)

    even = tf.nn.relu(a)
    odd = -tf.nn.relu(-a)

    return tf.where(idx, odd, even)

  return fn


@register("selu")
def selu(params):
  return tf.nn.selu


@register("elu")
def elu(params):
  return tf.nn.elu


@register("sigmoid")
def sigmoid(params):
  return tf.nn.sigmoid


@register("swish")
def swish(params):
  return lambda x: tf.nn.sigmoid(x) * x


@register("tanh")
def tanh(params):
  return tf.nn.tanh


================================================
FILE: models/utils/dropouts.py
================================================
import numpy as np
import tensorflow as tf

_DROPOUTS = dict()


def register(name):

  def add_to_dict(fn):
    global _DROPOUTS
    _DROPOUTS[name] = fn
    return fn

  return add_to_dict


def get_dropout(name):
  return _DROPOUTS[name]



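# A brief note on targeted dropout (added for clarity; see the functions
# below): weights in each column are ranked by magnitude, the bottom
# `targ_rate` fraction forms the deletion-candidate set, and during training
# each candidate is dropped independently with probability `drop_rate`.
# At evaluation time the whole candidate set is zeroed deterministically.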
@register("targeted_weight")
def targeted_weight_dropout(w, params, is_training):
  drop_rate = params.drop_rate
  targ_perc = params.targ_rate

  w_shape = w.shape
  w = tf.reshape(w, [-1, w_shape[-1]])
  norm = tf.abs(w)
  idx = tf.to_int32(targ_perc * tf.to_float(tf.shape(w)[0]))
  threshold = tf.contrib.framework.sort(norm, axis=0)[idx]
  mask = norm < threshold[None, :]

  if not is_training:
    w = (1. - tf.to_float(mask)) * w
    w = tf.reshape(w, w_shape)
    return w

  mask = tf.to_float(
      tf.logical_and(tf.random_uniform(tf.shape(w)) < drop_rate, mask))
  w = (1. - mask) * w
  w = tf.reshape(w, w_shape)
  return w


@register("targeted_weight_random")
def targeted_weight_random(w, params, is_training):
  drop_rate = params.drop_rate
  targ_perc = params.targ_rate

  w_shape = w.shape
  w = tf.reshape(w, [-1, w_shape[-1]])

  switch = tf.get_variable(
      "mask",
      w.shape,
      initializer=tf.random_uniform_initializer(),
      trainable=False)

  if is_training:
    mask = tf.logical_and(switch < targ_perc,
                          tf.random_uniform(w.shape) < drop_rate)
  else:
    mask = switch < targ_perc

  mask = 1. - tf.to_float(mask)
  mask = tf.stop_gradient(mask)

  w = mask * w
  w = tf.reshape(w, w_shape)
  return w


@register("ramping_targeted_weight_random")
def ramping_targeted_weight_random(w, params, is_training):
  drop_rate = params.drop_rate
  targ_perc = 0.95 * params.targ_rate * tf.minimum(
      1.0,
      tf.to_float(tf.train.get_global_step()) / 20000.)
  targ_perc = targ_perc + 0.05 * params.targ_rate * tf.maximum(
      0.0,
      tf.minimum(1.0,
                 (tf.to_float(tf.train.get_global_step()) - 20000.) / 20000.))

  w_shape = w.shape
  w = tf.reshape(w, [-1, w_shape[-1]])

  switch = tf.get_variable(
      "mask",
      w.shape,
      initializer=tf.random_uniform_initializer(),
      trainable=False)

  if is_training:
    mask = tf.logical_and(switch < targ_perc,
                          tf.random_uniform(w.shape) < drop_rate)
  else:
    mask = switch < (targ_perc * drop_rate)

  mask = 1. - tf.to_float(mask)
  mask = tf.stop_gradient(mask)

  w = mask * w
  w = tf.reshape(w, w_shape)
  return w


@register("targeted_weight_piecewise")
def targeted_weight_piecewise_dropout(w, params, is_training):
  drop_rate = params.drop_rate * tf.minimum(
      1.0,
      tf.to_float(tf.train.get_global_step()) / 40000.)

  targ_perc = 0.95 * params.targ_rate * tf.minimum(
      1.0,
      tf.to_float(tf.train.get_global_step()) / 20000.)
  targ_perc = targ_perc + 0.05 * params.targ_rate * tf.maximum(
      0.0,
      tf.minimum(1.0,
                 (tf.to_float(tf.train.get_global_step()) - 20000.) / 20000.))

  w_shape = w.shape
  w = tf.reshape(w, [-1, w_shape[-1]])
  norm = tf.abs(w)
  idx = tf.to_int32(targ_perc * tf.to_float(tf.shape(w)[0]))
  threshold = tf.contrib.framework.sort(norm, axis=0)[idx]
  mask = norm < threshold[None, :]

  if not is_training:
    w = w * (1 - tf.to_float(mask))
    return tf.reshape(w, w_shape)

  mask = tf.where(
      tf.logical_and((1. - drop_rate) < tf.random_uniform(tf.shape(w)), mask),
      tf.ones_like(w, dtype=tf.float32), tf.zeros_like(w, dtype=tf.float32))
  w = (1 - mask) * w
  w = tf.reshape(w, w_shape)
  return w


@register("targeted_unit_piecewise")
def targeted_unit_piecewise(w, params, is_training):
  drop_rate = params.drop_rate * tf.minimum(
      1.0,
      tf.to_float(tf.train.get_global_step()) / 40000.)

  targ_perc = 0.95 * params.targ_rate * tf.minimum(
      1.0,
      tf.to_float(tf.train.get_global_step()) / 20000.)
  targ_perc = targ_perc + 0.05 * params.targ_rate * tf.maximum(
      0.0,
      tf.minimum(1.0,
                 (tf.to_float(tf.train.get_global_step()) - 20000.) / 20000.))

  w_shape = w.shape
  w = tf.reshape(w, [-1, w.shape[-1]])
  norm = tf.norm(w, axis=0)
  idx = tf.to_int32(targ_perc * tf.to_float(w.shape[1]))
  sorted_norms = tf.contrib.framework.sort(norm)
  threshold = sorted_norms[idx]
  mask = (norm < threshold)[None, :]

  if not is_training:
    w = w * (1 - tf.to_float(mask))
    return tf.reshape(w, w_shape)

  mask = tf.tile(mask, [w.shape[0], 1])
  mask = tf.where(
      tf.logical_and((1. - drop_rate) < tf.random_uniform(tf.shape(w)), mask),
      tf.ones_like(w, dtype=tf.float32), tf.zeros_like(w, dtype=tf.float32))
  w = tf.reshape((1 - mask) * w, w_shape)
  return w


@register("delayed_targeted_weight_prune")
def delayed_targeted_weight(w, params, is_training):
  orig_w = w
  targ_perc = params.targ_rate

  w_shape = w.shape
  w = tf.reshape(w, [-1, w_shape[-1]])
  norm = tf.abs(w)
  idx = tf.to_int32(targ_perc * tf.to_float(tf.shape(w)[0]))
  threshold = tf.contrib.framework.sort(norm, axis=0)[idx]
  mask = norm >= threshold[None, :]

  w = w * tf.to_float(mask)
  cond = tf.to_float(tf.train.get_global_step() >= params.dropout_delay_steps)
  return cond * tf.reshape(w, w_shape) + (1 - cond) * orig_w


@register("delayed_targeted_unit_prune")
def delayed_targeted_unit(x, params, is_training):
  orig_x = x

  w = tf.reshape(x, [-1, x.shape[-1]])
  norm = tf.norm(w, axis=0)
  idx = int(params.targ_rate * int(w.shape[1]))
  sorted_norms = tf.contrib.framework.sort(norm)
  threshold = sorted_norms[idx]
  mask = (norm >= threshold)[None, None]

  w = w * tf.to_float(mask)
  return tf.cond(
      tf.greater(tf.train.get_global_step(), params.dropout_delay_steps),
      lambda: tf.reshape(w, x.shape), lambda: orig_x)


@register("untargeted_weight")
def untargeted_weight(w, params, is_training):
  if not is_training:
    return w
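  # note: tf.nn.dropout rescales the kept weights by 1 / keep_prob.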
  return tf.nn.dropout(w, keep_prob=(1. - params.drop_rate))


@register("targeted_unit")
def targeted_unit_dropout(x, params, is_training):
  w = tf.reshape(x, [-1, x.shape[-1]])
  norm = tf.norm(w, axis=0)
  idx = int(params.targ_rate * int(w.shape[1]))
  sorted_norms = tf.contrib.framework.sort(norm)
  threshold = sorted_norms[idx]
  mask = (norm < threshold)[None, :]
  mask = tf.tile(mask, [w.shape[0], 1])

  if not is_training:
    w = (1. - tf.to_float(mask)) * w
    w = tf.reshape(w, x.shape)
    return w

  mask = tf.where(
      tf.logical_and((1. - params.drop_rate) < tf.random_uniform(tf.shape(w)),
                     mask), tf.ones_like(w, dtype=tf.float32),
      tf.zeros_like(w, dtype=tf.float32))
  x = tf.reshape((1 - mask) * w, x.shape)
  return x


@register("targeted_unit_random")
def targeted_unit_random(w, params, is_training):
  drop_rate = params.drop_rate
  targ_perc = params.targ_rate

  w_shape = w.shape
  w = tf.reshape(w, [-1, w_shape[-1]])

  switch = tf.get_variable(
      "mask",
      w.shape[-1],
      initializer=tf.random_uniform_initializer(),
      trainable=False)

  if is_training:
    mask = tf.logical_and(switch < targ_perc,
                          tf.random_uniform(switch.shape) < drop_rate)
  else:
    mask = switch < targ_perc

  mask = 1. - tf.to_float(mask)
  mask = tf.stop_gradient(mask[None, :])

  w = mask * w
  w = tf.reshape(w, w_shape)
  return w


@register("targeted_ard")
def targeted_ard_dropout(w, x, params, is_training):
  if not is_training:
    return w
  x = tf.reshape(x, [-1, x.shape[-1]])
  activation_norms = tf.reduce_mean(tf.abs(x), axis=0)
  w_shape = w.shape
  w = tf.reshape(w, [-1, w_shape[-2], w_shape[-1]])
  norm = tf.norm(w, axis=(0, 2)) * activation_norms
  idx = int(params.targ_rate * int(w.shape[1]))
  sorted_norms = tf.contrib.framework.sort(norm)
  threshold = sorted_norms[idx]
  mask = (norm < threshold)[None, :, None]
  mask = tf.tile(mask, [w.shape[0], 1, w.shape[-1]])
  mask = tf.where(
      tf.logical_and((1. - params.drop_rate) < tf.random_uniform(tf.shape(w)),
                     mask), tf.ones_like(w, dtype=tf.float32),
      tf.zeros_like(w, dtype=tf.float32))
  w = tf.reshape((1 - mask) * w, w_shape)
  return w


@register("untargeted_unit")
def unit_dropout(w, params, is_training):
  if not is_training:
    return w
  w_shape = w.shape
  w = tf.reshape(w, [-1, w.shape[-1]])
  mask = tf.to_float(
      tf.random_uniform([int(w.shape[1])]) > params.drop_rate)[None, :]
  w = tf.reshape(mask * w, w_shape)
  return w / (1 - params.drop_rate)


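# A brief note on the Louizos gates below (added for clarity): each weight or
# unit gets a trainable gate; a noisy sigmoid sample is stretched to
# [louizos_gamma, louizos_zeta] and clipped to [0, 1] (a "hard concrete"
# gate), which yields exact zeros with nonzero probability while staying
# differentiable. See Louizos et al. (2018), "Learning Sparse Neural
# Networks through L0 Regularization".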
@register("louizos_weight")
def louizos_weight_dropout(w, params, is_training):
  with tf.variable_scope("louizos"):
    EPS = 1e-8
    noise = (1 - EPS) * tf.random_uniform(w.shape) + (EPS / 2)
    rate = np.log(1 - params.drop_rate) - np.log(params.drop_rate)
    gates = tf.get_variable(
        "gates",
        shape=w.shape,
        initializer=tf.random_normal_initializer(mean=rate, stddev=0.01))
    if is_training:
      s = tf.nn.sigmoid(
          (gates + tf.log(noise / (1. - noise))) / params.louizos_beta)
      s_bar = s * (
          params.louizos_zeta - params.louizos_gamma) + params.louizos_gamma
    else:
      s = tf.nn.sigmoid(gates)
      s_bar = s * (
          params.louizos_zeta - params.louizos_gamma) + params.louizos_gamma
    mask = tf.minimum(1., tf.maximum(0., s_bar))

    return mask * w


@register("louizos_unit")
def louizos_unit_dropout(w, params, is_training):
  with tf.variable_scope("louizos"):
    EPS = 1e-8
    noise = (1 - EPS) * tf.random_uniform([w.shape.as_list()[-1]]) + (EPS / 2)
    rate = np.log(1 - params.drop_rate) - np.log(params.drop_rate)
    gates = tf.get_variable(
        "gates",
        shape=[w.shape.as_list()[-1]],
        initializer=tf.random_normal_initializer(mean=rate, stddev=0.01))
    if is_training:
      s = tf.nn.sigmoid(
          (gates + tf.log(noise / (1. - noise))) / params.louizos_beta)
      s_bar = s * (
          params.louizos_zeta - params.louizos_gamma) + params.louizos_gamma
    else:
      s = tf.nn.sigmoid(gates)
      s_bar = s * (
          params.louizos_zeta - params.louizos_gamma) + params.louizos_gamma
    mask = tf.minimum(1., tf.maximum(0., s_bar))

    return mask * w


# from https://github.com/BayesWatch/tf-variational-dropout/blob/master/variational_dropout.py
def log_sigma2_variable(shape, ard_init=-10.):
  return tf.get_variable(
      "log_sigma2", shape=shape, initializer=tf.constant_initializer(ard_init))


# from https://github.com/BayesWatch/tf-variational-dropout/blob/master/variational_dropout.py
def get_log_alpha(log_sigma2, w):
  log_alpha = clip(log_sigma2 - paranoid_log(tf.square(w)))
  return tf.identity(log_alpha, name='log_alpha')


# from https://github.com/BayesWatch/tf-variational-dropout/blob/master/variational_dropout.py
def paranoid_log(x, eps=1e-8):
  v = tf.log(x + eps)
  return v


# from https://github.com/BayesWatch/tf-variational-dropout/blob/master/variational_dropout.py
def clip(x):
  return tf.clip_by_value(x, -8., 8.)


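# Approximation of the negative KL divergence between the log-uniform prior
# and the Gaussian posterior used in variational dropout; the constants
# k1, k2, k3 follow Molchanov et al. (2017), "Variational Dropout Sparsifies
# Deep Neural Networks".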
def dkl_qp(log_alpha):
  k1, k2, k3 = 0.63576, 1.8732, 1.48695
  C = -k1
  mdkl = k1 * tf.nn.sigmoid(k2 + k3 * log_alpha) - 0.5 * tf.log1p(
      tf.exp(-log_alpha)) + C
  return -tf.reduce_sum(mdkl)


@register("variational")
def variational_dropout(w, _, is_training):
  with tf.variable_scope("variational"):
    log_sigma2 = log_sigma2_variable(w.get_shape())
    log_alpha = get_log_alpha(log_sigma2, w)
    select_mask = tf.cast(tf.less(log_alpha, 3), tf.float32)

    if is_training:
      return w, log_alpha

    return w * select_mask, log_alpha


@register("variational_unit")
def variational_unit_dropout(w, _, is_training):
  with tf.variable_scope("variational"):
    log_sigma2 = log_sigma2_variable(int(w.shape[-1]))
    log_sigma2 = tf.reshape(log_sigma2, [1, 1, 1, -1])
    log_sigma2 = tf.tile(log_sigma2, [w.shape[0], w.shape[1], w.shape[2], 1])
    log_alpha = get_log_alpha(log_sigma2, w)
    select_mask = tf.cast(tf.less(log_alpha, 3), tf.float32)

    if is_training:
      return w, log_alpha

    return w * select_mask, log_alpha


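# A brief note on smallify (added for clarity): every unit (or weight, in the
# variant below) gets a trainable "switch" that multiplies its output; an L1
# penalty on the switches (see model_utils.switch_loss) drives them toward
# zero, and a switch is permanently masked once the exponential moving
# variance of its sign rises above `smallify_thresh`.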
@register("smallify_dropout")
def smallify_dropout(x, hparams, is_training):
  with tf.variable_scope("smallify", reuse=tf.AUTO_REUSE):
    switch = tf.get_variable(
        "switch",
        shape=[1] * (len(x.shape) - 1) + [x.shape[-1]],
        initializer=tf.random_uniform_initializer())

    mask = tf.get_variable(
        initializer=lambda: tf.ones_like(switch.initialized_value()),
        name="mask",
        trainable=False)
    exp_avg = tf.get_variable(
        initializer=lambda: tf.sign(switch.initialized_value()),
        name="exp_avg",
        trainable=False)
    exp_std = tf.get_variable(
        initializer=lambda: tf.zeros_like(switch.initialized_value()),
        name="exp_std",
        trainable=False)
    gates = switch * mask

    batch_sign = tf.sign(switch)
    diff = batch_sign - exp_avg

    new_mask = tf.cast(tf.less(exp_std, hparams.smallify_thresh), tf.float32)

    if not is_training:
      return tf.identity(x * gates, name="smallified")

    with tf.control_dependencies([
        tf.assign(mask, mask * new_mask),
        tf.assign(
            exp_std, hparams.smallify_mv * exp_std +
            (1 - hparams.smallify_mv) * diff**2),
        tf.assign(
            exp_avg, hparams.smallify_mv * exp_avg +
            (1 - hparams.smallify_mv) * batch_sign)
    ]):
      return tf.identity(x * gates, name="smallified")


@register("smallify_weight_dropout")
def smallify_weight_dropout(x, hparams, is_training):
  with tf.variable_scope("smallify"):
    switch = tf.get_variable(
        "switch", shape=x.shape, initializer=tf.random_uniform_initializer())

    mask = tf.get_variable(
        initializer=lambda: tf.ones_like(switch.initialized_value()),
        name="mask",
        trainable=False)
    exp_avg = tf.get_variable(
        initializer=lambda: tf.sign(switch.initialized_value()),
        name="exp_avg",
        trainable=False)
    exp_std = tf.get_variable(
        initializer=lambda: tf.zeros_like(switch.initialized_value()),
        name="exp_std",
        trainable=False)
    gates = switch * mask

    batch_sign = tf.sign(switch)
    diff = batch_sign - exp_avg

    new_mask = tf.cast(tf.less(exp_std, hparams.smallify_thresh), tf.float32)

    if not is_training:
      return tf.identity(x * gates, name="smallified")

    with tf.control_dependencies([
        tf.assign(mask, mask * new_mask),
        tf.assign(
            exp_std, hparams.smallify_mv * exp_std +
            (1 - hparams.smallify_mv) * diff**2),
        tf.assign(
            exp_avg, hparams.smallify_mv * exp_avg +
            (1 - hparams.smallify_mv) * batch_sign)
    ]):
      return tf.identity(x * gates, name="smallified")


================================================
FILE: models/utils/initializations.py
================================================
import tensorflow as tf

_INIT = dict()


def register(name):

  def add_to_dict(fn):
    global _INIT
    _INIT[name] = fn
    return fn

  return add_to_dict


def get_init(params):
  return _INIT[params.initializer](params)


@register("normal")
def normal(params):
  return tf.random_normal_initializer(mean=params.mean, stddev=params.sd)


@register("constant")
def constant(params):
  return tf.constant_initializer(0.1, tf.float32)


@register("uniform_unit_scaling")
def uniform_unit_scaling(params):
  return tf.uniform_unit_scaling_initializer()


@register("glorot_normal_initializer")
def glorot_normal_initializer(params):
  return tf.glorot_normal_initializer()


@register("glorot_uniform_initializer")
def glorot_uniform_initializer(params):
  return tf.glorot_uniform_initializer()


@register("variance_scaling_initializer")
def variance_scaling_initializer(params):
  return tf.variance_scaling_initializer()


class RandomUnitScaling(tf.keras.initializers.Initializer):

  def __call__(self, shape, dtype=None, partition_info=None):
    if len(shape) == 2:
      dim = (shape[0] + shape[1]) / 2.
    elif len(shape) == 4:
      dim = shape[0] * shape[1] * (shape[2] + shape[3]) / 2.
    else:
      raise ValueError("RandomUnitScaling expects a rank-2 or rank-4 shape.")

    m = tf.sqrt(3 / tf.to_float(dim))
    init = m * (2 * tf.random_uniform(shape) - 1)
    return init


class RandomHadamardConstant(tf.keras.initializers.Initializer):

  def __call__(self, shape, dtype=None, partition_info=None):
    dim = (shape[0] + shape[1]) / 2.

    flip = 2 * tf.round(tf.random_uniform(shape)) - 1
    m = tf.pow(dim, -1 / 2.)
    return m * flip


class RandomHadamardUnscaled(tf.keras.initializers.Initializer):

  def __call__(self, shape, dtype=None, partition_info=None):
    return 2 * tf.round(tf.random_uniform(shape)) - 1


class RandomWarpedUniform(tf.keras.initializers.Initializer):

  def __init__(self, k=2):
    self.k = k

  def __call__(self, shape, dtype=None, partition_info=None):
    if len(shape) == 2:
      dim = (shape[0] + shape[1]) / 2.
    elif len(shape) == 4:
      dim = shape[0] * shape[1] * (shape[2] + shape[3]) / 2.
    else:
      raise ValueError("RandomWarpedUniform expects a rank-2 or rank-4 shape.")

    m = tf.sqrt(3 / tf.to_float(dim))

    eps = 1e-10
    unif = (1 - eps) * tf.random_uniform(shape) + eps / 2
    skew_unif = tf.nn.sigmoid(self.k * tf.log(unif / (1 - unif)))
    init = m * (2 * skew_unif - 1)
    return init


@register("warped_unif")
def warped_unif(params):
  return RandomWarpedUniform(params.k)


@register("unit_scaling")
def unit_scaling(params):
  return RandomUnitScaling()


@register("hadamard_constant")
def hadamard_constant(params):
  return RandomHadamardConstant()


@register("hadamard_unscaled")
def hadamard_unscaled(params):
  return RandomHadamardUnscaled()

================================================
FILE: models/utils/model_utils.py
================================================
import operator
from functools import reduce

import tensorflow as tf
from tensorflow.contrib.tpu.python.tpu import tpu_estimator

from . import dropouts
from .optimizers import get_optimizer
from ...training import tpu


class ModeKeys(object):
  TRAIN = tf.estimator.ModeKeys.TRAIN
  EVAL = tf.estimator.ModeKeys.EVAL
  TEST = "test"
  PREDICT = tf.estimator.ModeKeys.PREDICT
  ATTACK = "attack"


def collect_vars(fn):
  """Collect all new variables created within `fn`.

  Args:
    fn: a function that takes no arguments and creates trainable tf.Variable
      objects.

  Returns:
    outputs: the outputs of `fn()`.
    new_vars: a list of the newly created variables.
  """
  previous_vars = set(tf.trainable_variables())
  outputs = fn()
  current_vars = set(tf.trainable_variables())
  new_vars = current_vars.difference(previous_vars)
  return outputs, list(new_vars)


def dense(x, units, hparams, is_training, dropout=True):
  with tf.variable_scope(None, default_name="dense") as scope:
    w = tf.get_variable("kernel", shape=[x.shape[1], units], dtype=tf.float32)
    b = tf.get_variable(
        "bias",
        shape=[units],
        dtype=tf.float32,
        initializer=tf.zeros_initializer())
    if dropout and hparams.dropout_type is not None and is_training:
      # note: unlike `conv` below, dropout here is only applied at training
      # time, so no deterministic mask is used during evaluation
      w = dropouts.get_dropout(hparams.dropout_type)(w, hparams, is_training)

    w = tf.identity(w, name="post_dropout")
    y = tf.matmul(x, w) + b
    return y


def conv(x,
         filter_size,
         out_filters,
         hparams,
         strides=[1, 1, 1, 1],
         padding="SAME",
         is_training=False,
         activation=None,
         dropout=True,
         name=None,
         schit_layer=False):
  """Convolution."""
  with tf.variable_scope(name, default_name="conv2d"):
    if hparams.data_format == "channels_last":
      in_filters = x.shape[-1]
    else:
      in_filters = x.shape[1]

    kernel = tf.get_variable(
        'DW', [filter_size, filter_size, in_filters, out_filters], tf.float32)
    use_dropout = hparams.dropout_type is not None and dropout

    # schit layer
    if schit_layer:
      scale = tf.get_variable(
          'scale',
          kernel.shape[-1],
          tf.float32,
          initializer=tf.zeros_initializer())
      kernel = hparams.lipschitz_constant * tf.nn.sigmoid(
          scale) * kernel / tf.norm(
              tf.reshape(kernel, shape=[-1, kernel.shape[-1]]), axis=0)

    if use_dropout:
      dropout_fn = dropouts.get_dropout(hparams.dropout_type)

      if hparams.dropout_type == "targeted_ard":
        kernel = dropout_fn(kernel, x, hparams, is_training)
      else:
        kernel = dropout_fn(kernel, hparams, is_training)

      # special case for variational dropout: sample the conv output with the
      # local reparameterization trick (mean + Gaussian noise * std)
      if "variational" in hparams.dropout_type:
        kernel, log_alpha = kernel[0], kernel[1]
        if is_training:
          conved_mu = tf.nn.conv2d(x, kernel, strides=strides, padding=padding)
          conved_si = tf.sqrt(
              tf.nn.conv2d(
                  tf.square(x),
                  tf.exp(log_alpha) * tf.square(kernel),
                  strides=strides,
                  padding=padding) + 1e-8)
          conved = conved_mu + tf.random_normal(
              tf.shape(conved_mu)) * conved_si

          conved = tf.identity(conved, name="post_dropout")
          return conved

    data_format = "NHWC" if hparams.data_format == "channels_last" else "NCHW"
    conv = tf.nn.conv2d(
        x, kernel, strides, padding=padding, data_format=data_format)

    if activation:
      conv = activation(conv)

    conv = tf.identity(conv, name="post_dropout")
    return conv


def weight_decay_and_noise(loss, hparams, learning_rate, var_list=None):
  """Apply weight decay and weight noise."""

  weight_decay_loss = weight_decay(hparams)
  tf.summary.scalar("losses/weight_decay", weight_decay_loss)
  weight_noise_ops = weight_noise(hparams, learning_rate)
  with tf.control_dependencies(weight_noise_ops):
    loss = tf.identity(loss)

  loss += weight_decay_loss
  return loss


def weight_noise(hparams, learning_rate):
  """Apply weight noise to vars in var_list."""
  if not hparams.weight_noise_rate:
    return [tf.no_op()]

  tf.logging.info("Applying weight noise scaled by learning rate, "
                  "noise_rate: %0.5f", hparams.weight_noise_rate)
  noise_ops = []

  noise_vars = [v for v in tf.trainable_variables() if "/body/" in v.name]
  for v in noise_vars:
    with tf.device(v._ref().device):  # pylint: disable=protected-access
      scale = hparams.weight_noise_rate * learning_rate * 0.001
      tf.summary.scalar("weight_noise_scale", scale)
      noise = tf.truncated_normal(v.shape) * scale
      noise_op = v.assign_add(noise)
      noise_ops.append(noise_op)
  return noise_ops


def weight_decay(hparams):
  """Apply weight decay to vars in var_list."""
  if not hparams.weight_decay_rate:
    return 0.

  only_features = hparams.weight_decay_only_features
  var_list = [v for v in tf.trainable_variables()]
  weight_decays = []
  for v in var_list:
    # Weight decay.
    is_feature = any(n in v.name for n in hparams.weight_decay_weight_names)
    if (not only_features) or is_feature:
      if hparams.initializer == "hadamard_unscaled":
        v_loss = tf.reduce_sum((tf.abs(v) - 1)**2) / 2
      else:
        v_loss = tf.nn.l2_loss(v)
      weight_decays.append(v_loss)

  return tf.reduce_sum(weight_decays, axis=0) * hparams.weight_decay_rate


def axis_aligned_cost(logits, hparams):
  negativity_cost = tf.nn.relu(-logits)
  max_mask = tf.one_hot(tf.argmax(tf.abs(logits), -1), hparams.num_classes)
  min_logits = tf.abs(logits) * (1 - max_mask)
  max_logit = tf.abs(logits) * max_mask
  one_bound = tf.nn.relu(logits - hparams.logit_bound)
  axis_alignedness_cost = tf.nn.relu(min_logits - 0.1 * hparams.logit_bound)

  logits_packed = tf.reduce_all(tf.less(max_logit, hparams.logit_bound), -1)
  logits_packed = tf.logical_and(logits_packed,
                                 tf.reduce_all(
                                     tf.less(min_logits,
                                             0.1 * hparams.logit_bound), -1))
  logits_packed = tf.reduce_mean(tf.to_float(logits_packed))
  tf.summary.scalar("logits_packed", logits_packed)
  tf.summary.scalar(
      "logits_max",
      tf.to_float(tf.shape(max_logit)[-1]) * tf.reduce_mean(max_logit))

  return negativity_cost, axis_alignedness_cost, one_bound


def ard_cost():
  with tf.variable_scope("ard_cost"):
    cost = 0
    for v in tf.trainable_variables():
      if "kernel" in v.name or "DW" in v.name:
        rv = tf.reshape(v, [-1, int(v.shape[-1])])
        sq_rv = tf.square(rv)
        sum_sq = tf.reduce_sum(sq_rv, axis=1, keepdims=True)
        ard = sq_rv / (sum_sq / tf.cast(tf.shape(sq_rv)[1], tf.float32)
                      ) - 0.5 * tf.log(sum_sq)
        cost += tf.reduce_sum(ard)

    return cost


def shape_list(x):
  """Return list of dims, statically where possible."""
  x = tf.convert_to_tensor(x)

  # If unknown rank, return dynamic shape
  if x.get_shape().dims is None:
    return tf.shape(x)

  static = x.get_shape().as_list()
  shape = tf.shape(x)

  ret = []
  for i, dim in enumerate(static):
    if dim is None:
      dim = shape[i]
    ret.append(dim)
  return ret


def standardize_images(x):
  """Image standardization on batches."""

  with tf.name_scope("standardize_images", values=[x]):
    x = tf.to_float(x)
    x_mean = tf.reduce_mean(x, axis=[1, 2, 3], keep_dims=True)
    x_variance = tf.reduce_mean(
        tf.square(x - x_mean), axis=[1, 2, 3], keep_dims=True)
    x_shape = shape_list(x)
    num_pixels = tf.to_float(x_shape[1] * x_shape[2] * x_shape[3])
    x = (x - x_mean) / tf.maximum(tf.sqrt(x_variance), tf.rsqrt(num_pixels))
    return x


def batch_norm(inputs, hparams, training):
  """Performs a batch normalization using a standard set of parameters."""
  # We set fused=True for a significant performance boost. See
  # https://www.tensorflow.org/performance/performance_guide#common_fused_ops
  if hparams.data_format == "channels_first":
    axis = 1
  else:
    axis = -1

  return tf.layers.batch_normalization(
      inputs=inputs,
      axis=axis,
      momentum=0.997,
      epsilon=0.001,
      center=True,
      scale=True,
      training=training,
      fused=True)


def louizos_complexity_cost(params):
  gates = {
      w.name.strip(":0"): w
      for w in tf.trainable_variables()
      if "gates" in w.name
  }
  names = list(gates.keys())
  concat_gates = tf.concat([tf.reshape(gates[name], [-1]) for name in names],
                           0)
  if params.dropout_type == "louizos_weight":
    complexity_cost = tf.nn.sigmoid(
        concat_gates - params.louizos_beta *
        tf.log(-1 * params.louizos_gamma / params.louizos_zeta))
  elif params.dropout_type == "louizos_unit":
    reshaped_gates = [
        tf.reshape(gates[name], [-1, gates[name].shape[-1]]) for name in names
    ]

    parameters = []
    for name in names:
      g_name = name[:-len("louizos/gates")] + "DW"
      g = tf.contrib.framework.get_unique_variable(g_name)
      parameters.extend(
          [reduce(operator.mul,
                  g.shape.as_list()[:-1], 1)] * g.shape.as_list()[-1])
    group_sizes = tf.constant(parameters)
    assert group_sizes.shape[0] == concat_gates.shape[0], "{} != {}".format(
        group_sizes.shape[0], concat_gates.shape[0])

    complexity_cost = tf.cast(group_sizes, tf.float32) * tf.nn.sigmoid(
        concat_gates - params.louizos_beta *
        tf.log(-1 * params.louizos_gamma / params.louizos_zeta))
  return tf.reduce_sum(complexity_cost)


def switch_loss():
  losses = 0

  for v in tf.trainable_variables():
    if "switch" in v.name:
      losses += tf.reduce_sum(tf.abs(v))

  tf.summary.scalar("switch_loss", losses)
  return losses


def nonzero_count():
  """Count the nonzero entries of every `post_dropout` tensor in the graph."""
  nonzeroes = 0
  for op in tf.get_default_graph().get_operations():
    if "post_dropout" in op.name:
      v = tf.get_default_graph().get_tensor_by_name(op.name + ":0")
      count = tf.to_float(tf.equal(v, 0.))
      count = tf.reduce_sum(1 - count)
      nonzeroes += count
  return nonzeroes


def percent_sparsity():
  """Fraction of nonzero entries across `post_dropout` tensors.

  Note: despite the name, this reports density (1 - sparsity).
  """
  nonzeroes = 0
  total = 0
  for op in tf.get_default_graph().get_operations():
    if "post_dropout" in op.name:
      v = tf.get_default_graph().get_tensor_by_name(op.name + ":0")
      count = tf.to_float(tf.equal(v, 0.))
      count = tf.reduce_sum(1 - count)
      nonzeroes += count
      total += tf.size(v)
  return tf.to_float(nonzeroes) / tf.to_float(total)


def convert(num, base, length=None):
  '''Convert a non-negative decimal integer to a list of digits in the given
  base (2 to 10), e.g. convert(6, 2, length=4) == [0, 1, 1, 0].'''
  num = int(num)
  base = int(base)
  result = []
  if num == 0:
    result.append(0)
  else:
    while (num > 0):
      result.append(num % base)
      num //= base
  # Reverse from LSB to MSB
  result = result[::-1]
  if length is not None:
    n_to_fill = length - len(result)
    if n_to_fill > 0:
      result = [0] * n_to_fill + result
  return result


def equal_mult(size, num_branches):
  return [
      tf.constant(1.0 / num_branches, shape=[size, 1, 1, 1], dtype=tf.float32)
      for _ in range(num_branches)
  ]


def uniform(size, num_branches):
  return [
      tf.random_uniform([size, 1, 1, 1], minval=0, maxval=1, dtype=tf.float32)
      for _ in range(num_branches)
  ]


def bernoulli(size, num_branches):
  random = tf.random_uniform([size], maxval=num_branches, dtype=tf.int32)
  bernoulli = tf.one_hot(random, depth=num_branches)
  rand = tf.split(bernoulli, [1] * num_branches, 1)
  rand = [tf.reshape(x, [-1, 1, 1, 1]) for x in rand]
  return rand


def combine(rand_uniform, rand_bernoulli, num_branches):
  return [
      tf.concat([rand_uniform[i], rand_bernoulli[i]], axis=0)
      for i in range(num_branches)
  ]


def model_top(labels, preds, cost, lr, mode, hparams):
  tf.summary.scalar("acc",
      tf.reduce_mean(
          tf.to_float(
              tf.equal(labels,
                       tf.argmax(
                           preds, axis=-1,
                           output_type=tf.int32)))))
  tf.summary.scalar("loss", cost)

  gs = tf.train.get_global_step()

  if hparams.weight_decay_and_noise:
    cost = weight_decay_and_noise(cost, hparams, lr)
    cost = tf.identity(cost, name="total_loss")
  optimizer = get_optimizer(lr, hparams)

  train_op = tf.contrib.layers.optimize_loss(
      name="training",
      loss=cost,
      global_step=gs,
      learning_rate=lr,
      clip_gradients=hparams.clip_grad_norm or None,
      gradient_noise_scale=hparams.grad_noise_scale or None,
      optimizer=optimizer,
      colocate_gradients_with_ops=True)

  if hparams.use_tpu:

    def metric_fn(l, p):
      return {
          "acc":
          tf.metrics.accuracy(
              labels=l, predictions=tf.argmax(p, -1, output_type=tf.int32)),
      }

    host_call = None
    if hparams.tpu_summarize:
      host_call = tpu.create_host_call(hparams.output_dir)
    tpu.remove_summaries()

    if mode == tf.estimator.ModeKeys.EVAL:
      return tpu_estimator.TPUEstimatorSpec(
          mode=mode,
          predictions=preds,
          loss=cost,
          eval_metrics=(metric_fn, [labels, preds]),
          host_call=host_call)

    return tpu_estimator.TPUEstimatorSpec(
        mode=mode, loss=cost, train_op=train_op, host_call=host_call)

  return tf.estimator.EstimatorSpec(
      mode,
      eval_metric_ops={
          "acc":
          tf.metrics.accuracy(
              labels=labels,
              predictions=tf.argmax(preds, axis=-1, output_type=tf.int32)),
      },
      loss=cost,
      train_op=train_op)


================================================
FILE: models/utils/optimizers.py
================================================
import tensorflow as tf

_OPTIMIZER = dict()


def register(name):

  def add_to_dict(fn):
    global _OPTIMIZER
    _OPTIMIZER[name] = fn
    return fn

  return add_to_dict


def get_optimizer(lr, params):
  optimizer = _OPTIMIZER[params.optimizer](lr, params)
  if params.use_tpu:
    optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
  return optimizer


@register("sgd")
def sgd(lr, params):
  return tf.train.GradientDescentOptimizer(lr)


@register("adam")
def adam(lr, params):
  return tf.train.AdamOptimizer(lr, beta1=params.beta1, beta2=params.beta2)


@register("adagrad")
def adagrad(lr, params):
  return tf.train.AdagradOptimizer(lr)


@register("momentum")
def momentum(lr, params):
  return tf.train.MomentumOptimizer(
      lr, momentum=params.momentum, use_nesterov=params.use_nesterov)


================================================
FILE: models/vgg/__init__.py
================================================
__all__ = ["vgg"]


================================================
FILE: models/vgg/vgg.py
================================================
import tensorflow as tf

from ..utils.activations import get_activation
from ..utils.dropouts import get_dropout
from ..utils.initializations import get_init
from ..utils.optimizers import get_optimizer
from ..registry import register
from ..utils import model_utils
from ..utils import dropouts
from ...training import tpu
import six

import numpy as np
from tensorflow.contrib.tpu.python.tpu import tpu_estimator, tpu_optimizer


def metric_fn(labels, predictions):
  return {
      "acc":
      tf.metrics.accuracy(
          labels=tf.argmax(labels, -1), predictions=tf.argmax(predictions,
                                                              -1)),
  }


@register("vgg")
def get_vgg(hparams, lr):
  """Callable model function compatible with Experiment API."""

  def vgg(features, labels, mode, params):
    if hparams.use_tpu and 'batch_size' in params.keys():
      hparams.batch_size = params['batch_size']

    is_training = mode == tf.estimator.ModeKeys.TRAIN

    inputs = features["inputs"]
    with tf.variable_scope("vgg", initializer=get_init(hparams)):
      total_nonzero = 0
      conv1_1 = model_utils.conv(
          inputs, 3, 64, hparams, name="conv1_1", is_training=is_training)

      conv1_1 = model_utils.batch_norm(conv1_1, hparams, is_training)
      conv1_1 = tf.nn.relu(conv1_1)

      conv1_2 = model_utils.conv(
          conv1_1, 3, 64, hparams, name="conv1_2", is_training=is_training)
      conv1_2 = model_utils.batch_norm(conv1_2, hparams, is_training)
      conv1_2 = tf.nn.relu(conv1_2)

      pool1 = tf.layers.max_pooling2d(
          conv1_2, 2, 2, padding="SAME", name='pool1')

      conv2_1 = model_utils.conv(
          pool1, 3, 128, hparams, name="conv2_1", is_training=is_training)
      conv2_1 = model_utils.batch_norm(conv2_1, hparams, is_training)
      conv2_1 = tf.nn.relu(conv2_1)

      conv2_2 = model_utils.conv(
          conv2_1, 3, 128, hparams, name="conv2_2", is_training=is_training)
      conv2_2 = model_utils.batch_norm(conv2_2, hparams, is_training)
      conv2_2 = tf.nn.relu(conv2_2)

      pool2 = tf.layers.max_pooling2d(
          conv2_2, 2, 2, padding="SAME", name='pool2')

      conv3_1 = model_utils.conv(
          pool2, 3, 256, hparams, name="conv3_1", is_training=is_training)
      conv3_1 = model_utils.batch_norm(conv3_1, hparams, is_training)
      conv3_1 = tf.nn.relu(conv3_1)

      conv3_2 = model_utils.conv(
          conv3_1, 3, 256, hparams, name="conv3_2", is_training=is_training)
      conv3_2 = model_utils.batch_norm(conv3_2, hparams, is_training)
      conv3_2 = tf.nn.relu(conv3_2)

      conv3_3 = model_utils.conv(
          conv3_2, 3, 256, hparams, name="conv3_3", is_training=is_training)
      conv3_3 = model_utils.batch_norm(conv3_3, hparams, is_training)
      conv3_3 = tf.nn.relu(conv3_3)

      pool3 = tf.layers.max_pooling2d(
          conv3_3, 2, 2, padding="SAME", name='pool3')

      conv4_1 = model_utils.conv(
          pool3, 3, 512, hparams, name="conv4_1", is_training=is_training)
      conv4_1 = model_utils.batch_norm(conv4_1, hparams, is_training)
      conv4_1 = tf.nn.relu(conv4_1)

      conv4_2 = model_utils.conv(
          conv4_1, 3, 512, hparams, name="conv4_2", is_training=is_training)
      conv4_2 = model_utils.batch_norm(conv4_2, hparams, is_training)
      conv4_2 = tf.nn.relu(conv4_2)

      conv4_3 = model_utils.conv(
          conv4_2, 3, 512, hparams, name="conv4_3", is_training=is_training)
      conv4_3 = model_utils.batch_norm(conv4_3, hparams, is_training)
      conv4_3 = tf.nn.relu(conv4_3)

      pool4 = tf.layers.max_pooling2d(
          conv4_3, 2, 2, padding="SAME", name='pool4')

      conv5_1 = model_utils.conv(
          pool4, 3, 512, hparams, name="conv5_1", is_training=is_training)
      conv5_1 = model_utils.batch_norm(conv5_1, hparams, is_training)
      conv5_1 = tf.nn.relu(conv5_1)

      conv5_2 = model_utils.conv(
          conv5_1, 3, 512, hparams, name="conv5_2", is_training=is_training)
      conv5_2 = model_utils.batch_norm(conv5_2, hparams, is_training)
      conv5_2 = tf.nn.relu(conv5_2)

      conv5_3 = model_utils.conv(
          conv5_2, 3, 512, hparams, name="conv5_3", is_training=is_training)
      conv5_3 = model_utils.batch_norm(conv5_3, hparams, is_training)
      conv5_3 = tf.nn.relu(conv5_3)

      pool5 = tf.layers.max_pooling2d(
          conv5_3, 2, 2, padding="SAME", name='pool5')

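      # After five 2x2 max-pools a 32x32 input is reduced to 1x1x512, so the
      # pooled features flatten to [batch_size, 512].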
      flat_x = tf.reshape(pool5, [hparams.batch_size, 512])
      fc6 = model_utils.batch_norm(
          model_utils.dense(flat_x, 4096, hparams, is_training), hparams,
          is_training)
      fc7 = model_utils.batch_norm(
          model_utils.dense(fc6, 4096, hparams, is_training), hparams,
          is_training)

      logits = tf.layers.dense(fc7, hparams.num_classes, name="logits")
      probs = tf.nn.softmax(logits, axis=-1)

      if mode in [model_utils.ModeKeys.PREDICT, model_utils.ModeKeys.ATTACK]:
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions={
                'classes': tf.argmax(probs, axis=1),
                'logits': logits,
                'probabilities': probs,
            })

      xent = tf.losses.sparse_softmax_cross_entropy(
          labels=labels, logits=logits)
      cost = tf.reduce_mean(xent, name='xent')
      cost += model_utils.weight_decay(hparams)

      tf.summary.scalar("total_nonzero", model_utils.nonzero_count())
      tf.summary.scalar("percent_sparsity", model_utils.percent_sparsity())
      if hparams.dropout_type is not None:
        if "louizos" in hparams.dropout_type:
          cost += hparams.louizos_cost * model_utils.louizos_complexity_cost(
              hparams) / 50000

        if "variational" in hparams.dropout_type:
          # prior DKL part of the ELBO
          graph = tf.get_default_graph()
          node_defs = [
              n for n in graph.as_graph_def().node if 'log_alpha' in n.name
          ]
          log_alphas = [
              graph.get_tensor_by_name(n.name + ":0") for n in node_defs
          ]
          print([
              n.name
              for n in graph.as_graph_def().node
              if 'log_alpha' in n.name
          ])
          print("found %i logalphas" % len(log_alphas))
          divergences = [dropouts.dkl_qp(la) for la in log_alphas]
          # combine to form the ELBO
          N = float(50000)
          dkl = tf.reduce_sum(tf.stack(divergences))

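          # Anneal the KL term: scale by 1/N (N = training-set size) and ramp
          # linearly from 0 to hparams.var_scale over the first warmup_steps
          # global steps.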
          warmup_steps = 50000
          dkl = (1. / N) * dkl * tf.minimum(
              1.0,
              tf.to_float(tf.train.get_global_step()) /
              warmup_steps) * hparams.var_scale
          cost += dkl
          tf.summary.scalar("dkl", dkl)

      if hparams.ard_cost > 0.0:
        cost += model_utils.ard_cost() * hparams.ard_cost

      if hparams.smallify > 0.0:
        cost += model_utils.switch_loss() * hparams.smallify

    return model_utils.model_top(labels, probs, cost, lr, mode, hparams)

  return vgg


================================================
FILE: requirements.txt
================================================
tensorflow>=1.9
requests>=2.19.1
dl-cloud>=0.0.4

================================================
FILE: scripts/__init__.py
================================================



================================================
FILE: scripts/prune/README.md
================================================
# Library for Pruning
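
Prunes checkpoints trained with `train.py` and reports test accuracy at each
requested pruning percentage; see `eval.py` for the full flag list
(`--hparams`, `--prune`, `--prune_percent`, `--eval_file`, ...).

The "standard" mode of the `weight` prune fn in `prune.py` zeroes the bottom-k
fraction of weights by magnitude. A minimal, self-contained sketch of that
rule (the array values below are made up for illustration):

```python
import numpy as np


def magnitude_prune(w, k=0.5):
  """Zero out the fraction k of entries of w with the smallest magnitude."""
  abs_w = np.abs(w)
  threshold = np.sort(abs_w, axis=0)[int(k * abs_w.shape[0])]
  mask = (abs_w >= threshold).astype(float)
  return w * mask, mask


w = np.array([0.05, -0.8, 0.3, -0.02])
pruned, mask = magnitude_prune(w, k=0.5)
# keeps only the two largest-magnitude entries:
# pruned == [0., -0.8, 0.3, -0.], mask == [0., 1., 1., 0.]
```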


================================================
FILE: scripts/prune/__init__.py
================================================


================================================
FILE: scripts/prune/eval.py
================================================
import tensorflow as tf
import os
import numpy as np

from ...hparams.registry import get_hparams
from ...models.registry import get_model
from ...data.registry import get_input_fns
from ...training import flags
from .prune import get_prune_fn, get_current_weights, get_louizos_masks, get_smallify_masks, prune_weights, is_prunable_weight


def init_flags():
  tf.flags.DEFINE_string("model", None, "Which model to use.")
  tf.flags.DEFINE_string("data", None, "Which data to use.")
  tf.flags.DEFINE_string("env", None, "Which environment to use.")
  tf.flags.DEFINE_string("hparams", None, "Which hparams to use.")
  tf.flags.DEFINE_string("hparam_override", "",
                         "Run-specific hparam settings to use.")
  tf.flags.DEFINE_string("output_dir", None, "The output directory.")
  tf.flags.DEFINE_string("data_dir", None, "The data directory.")
  tf.flags.DEFINE_integer("train_steps", 10000,
                          "Number of training steps to perform.")
  tf.flags.DEFINE_integer("eval_every", 1000,
                          "Number of steps between evaluations.")
  tf.flags.DEFINE_string(
      "post_weights_dir", "",
      "Folder of the weights; if not set, defaults to output_dir.")
  tf.flags.DEFINE_string("prune_percent", "0.5",
                         "Comma-separated fractions (e.g. 0.5) of weights to prune.")
  tf.flags.DEFINE_string("prune", "weight",
                         "Which prune fn to use: weight, unit or ard.")
  tf.flags.DEFINE_boolean("variational", False,
                          "Evaluate variational-dropout models.")
  tf.flags.DEFINE_string("eval_file", "eval_prune_results",
                         "File to write evaluation results to.")
  tf.flags.DEFINE_integer("train_epochs", None,
                          "Number of training epochs to perform.")
  tf.flags.DEFINE_integer("eval_steps", None,
                          "Number of evaluation steps to perform.")


def eval_model(FLAGS, hparam_name):
  """Restore the latest checkpoint for hparam_name, prune at each requested
  percentage, and return the resulting test accuracies."""
  hparams = get_hparams(hparam_name)
  hparams = hparams.parse(FLAGS.hparam_override)
  hparams = flags.update_hparams(FLAGS, hparams, hparam_name)

  model_fn = get_model(hparams)
  _, _, test_input_fn = get_input_fns(hparams, generate=False)

  features, labels = test_input_fn()
  sess = tf.Session()
  tf.train.create_global_step()
  model_fn(features, labels, tf.estimator.ModeKeys.TRAIN)
  saver = tf.train.Saver()
  ckpt_dir = tf.train.latest_checkpoint(hparams.output_dir)
  print("Loading model from...", ckpt_dir)
  saver.restore(sess, ckpt_dir)

  evals = []
  prune_percents = [float(i) for i in FLAGS.prune_percent.split(",")]

  mode = "standard"
  orig_weights = get_current_weights(sess)
  louizos_masks, smallify_masks = None, None
  if "louizos" in hparam_name:
    louizos_masks = get_louizos_masks(sess, orig_weights)
    mode = "louizos"
  elif "smallify" in hparam_name:
    smallify_masks = get_smallify_masks(sess, orig_weights)
  elif "variational" in hparam_name:
    mode = "variational"

  for prune_percent in prune_percents:
    if prune_percent > 0.0:
      prune_fn = get_prune_fn(FLAGS.prune)(mode, k=prune_percent)
      w_copy = dict(orig_weights)
      sm_copy = dict(smallify_masks) if smallify_masks is not None else None
      lm_copy = dict(louizos_masks) if louizos_masks is not None else None
      post_weights_pruned, weight_counts = prune_weights(
          prune_fn,
          w_copy,
          louizos_masks=lm_copy,
          smallify_masks=sm_copy,
          hparams=hparams)
      print("current weight counts at {}: {}".format(prune_percent,
                                                     weight_counts))

      print("there are ", len(tf.trainable_variables()), " weights")
      for v in tf.trainable_variables():
        if is_prunable_weight(v):
          assign_op = v.assign(
              np.reshape(post_weights_pruned[v.name.strip(":0")], v.shape))
          sess.run(assign_op)

    saver.save(sess, os.path.join(hparams.output_dir, "tmp", "model"))
    estimator = tf.estimator.Estimator(
        model_fn=tf.contrib.estimator.replicate_model_fn(model_fn),
        model_dir=os.path.join(hparams.output_dir, "tmp"))
    print(
        f"Processing pruning {prune_percent} of weights for {hparams.eval_steps} steps"
    )
    acc = estimator.evaluate(test_input_fn, hparams.eval_steps)['acc']
    print(f"Accuracy @ prune {100*prune_percent}% is {acc}")
    evals.append(acc)
  return evals


def _run(FLAGS):
  eval_file = open(FLAGS.eval_file, "w")

  hparams_list = FLAGS.hparams.split(",")
  total_evals = {}
  for hparam_name in hparams_list:
    evals = eval_model(FLAGS, hparam_name)

    print(hparam_name, ":", evals)
    eval_file.writelines("{}:{}\n".format(hparam_name, evals))
    total_evals[hparam_name] = evals
    tf.reset_default_graph()

  print("processed results:", total_evals)
  eval_file.close()


if __name__ == "__main__":
  init_flags()
  FLAGS = tf.app.flags.FLAGS
  _run(FLAGS)


================================================
FILE: scripts/prune/prune.py
================================================
import numpy as np
import tensorflow as tf
import statistics
from ...models.utils import model_utils

_PRUNE_FN = dict()


def register(fn):
  global _PRUNE_FN
  _PRUNE_FN[fn.__name__] = fn
  return fn


def get_prune_fn(name):
  return _PRUNE_FN[name]


@register
def weight(mode, k=0.5):
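  """Build a weight-level prune fn for the given mode.

  standard: zero the fraction `k` of weights with the smallest magnitude.
  variational: zero the fraction `k` of weights with the largest log_alpha
    (i.e. the noisiest weights under variational dropout).
  louizos: zero the fraction `k` of weights with the smallest gate values
    (the weight dict holds gate masks in this mode; see prune_weights).
  """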

  if mode == "standard":

    def prune(weight_dict, weight_key):
      weights = weight_dict[weight_key]
      w = weights.copy()
      if len(weights.shape) == 4:
        w = w.reshape([-1, weights.shape[-1]])

      abs_w = np.abs(w)
      idx = int(k * abs_w.shape[0])
      med = np.sort(abs_w, axis=0)[idx:idx + 1]
      mask = (abs_w >= med).astype(float)
      pruned_w = mask * w

      return pruned_w, mask
  elif mode == "variational":

    def prune(weight_dict, weight_key):
      weights = weight_dict[weight_key]
      if k == 0.0:
        return weights, None
      log_alpha = weight_dict[weight_key.strip("DW") + "variational/log_alpha"]
      w = weights.copy()
      la = log_alpha.copy()
      if len(weights.shape) == 4:
        w = w.reshape([-1, weights.shape[-1]])
        la = la.reshape([-1, weights.shape[-1]])

      idx = int((1 - k) * la.shape[0])
      med = np.sort(la, axis=0)[idx:idx + 1]
      mask = (la < med).astype(float)
      pruned_w = mask * w

      return pruned_w, mask
  elif mode == "louizos":

    def prune(weight_dict, weight_key):
      weights = weight_dict[weight_key]
      w = weights.copy()
      if len(weights.shape) == 4:
        w = w.reshape([-1, weights.shape[-1]])

      idx = int(k * w.shape[0])
      med = np.sort(w, axis=0)[idx:idx + 1]
      mask = (w >= med).astype(float)
      pruned_w = mask * w

      return pruned_w, mask

  return prune


@register
def unit(mode, k=0.5):

  if mode == "standard" or mode == "variational":

    def prune(weight_dict, weight_key):
      weights = weight_dict[weight_key]
      w = weights.copy()
      if len(weights.shape) == 4:
        w = w.reshape([-1, weights.shape[-1]])
      norm = np.linalg.norm(w, axis=0)
      idx = int(k * norm.shape[0])
      med = np.sort(norm, axis=0)[idx]
      mask = (norm >= med).astype(float)
      pruned_w = mask * w

      return pruned_w, mask
  elif mode == "louizos":

    def prune(weight_dict, weight_key):
      weights = weight_dict[weight_key]
      w = weights.copy()
      assert len(weights.shape) == 1
      idx = int(k * w.shape[0])
      med = np.sort(w, axis=0)[idx]
      mask = (w >= med).astype(float)
      pruned_w = mask * w

      return pruned_w, mask

  return prune


@register
def ard(mode, k=0.5):
  # `mode` is accepted (and ignored) so `ard` matches the (mode, k) calling
  # convention used when prune fns are invoked in eval.py.

  def prune(weight_dict, weight_key):
    weights = weight_dict[weight_key]
    w = weights.copy()
    if len(weights.shape) == 4:
      w = w.reshape([-1, weights.shape[-1]])
    norm = np.linalg.norm(w, axis=1, keepdims=True)
    idx = int(k * norm.shape[0])
    med = np.sort(norm, axis=0)[idx]
    mask = (norm >= med).astype(float)
    pruned_w = mask * w

    return pruned_w, mask

  return prune


def prune_weights(prune_fn,
                  weights,
                  louizos_masks=None,
                  smallify_masks=None,
                  hparams=None):
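  """Apply `prune_fn` to every prunable weight in `weights`.

  For louizos/smallify, the entries of `weights` are first replaced by (or
  multiplied with) their masks so the pruning criterion runs on the gates;
  the returned values are then the original weights under the pruned mask.

  Returns:
    (pruned_weights, dict with pre-prune nonzero/total parameter counts).
  """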
  weights_pruned = {}

  pre_prune_nonzero = 0
  pre_prune_total = 0
  if louizos_masks:
    orig_weights = dict(weights)
    for weight_name in weights:
      if weight_name not in louizos_masks.keys():
        print("WARN louizos: mask not found for {}".format(weight_name))
        continue
      weights[weight_name] = louizos_masks[weight_name]
  elif smallify_masks:
    orig_weights = dict(weights)
    for weight_name in weights:
      if weight_name not in smallify_masks.keys():
        print("WARN smallify: not pruning {}".format(weight_name))
        continue
      mask = smallify_masks[weight_name]
      weights[weight_name] = weights[weight_name] * mask

  for weight_name in weights:
    if "variational" in weight_name:
      print("WARN variational: not pruning {}".format(weight_name))
      continue

    pre_prune_nonzero += np.count_nonzero(weights[weight_name])
    pre_prune_total += weights[weight_name].size

    weights_pruned[weight_name], mask = prune_fn(weights, weight_name)
    if louizos_masks or smallify_masks:
      print("applied masks to", weight_name)
      weights_pruned[weight_name] = mask * orig_weights[weight_name].reshape(
          [-1, orig_weights[weight_name].shape[-1]])

  return weights_pruned, {
      "pre_prune_nonzero": pre_prune_nonzero,
      "pre_prune_total": pre_prune_total
  }


def get_louizos_masks(sess, weights):
  masks = {}
  for weight_name in weights:
    m_name = weight_name.strip("DW") + "louizos/gates"
    m = tf.contrib.framework.get_variables_by_name(m_name)
    assert len(m) == 1
    m = m[0]
    masks[weight_name] = sess.run(m)

  return masks


def get_smallify_masks(sess, weights):
  masks = {}
  for weight_name in weights:
    switch_name = weight_name.strip("DW") + "smallify/switch"
    mask_name = weight_name.strip("DW") + "smallify/mask"
    switch = tf.contrib.framework.get_variables_by_name(switch_name)
    mask = tf.contrib.framework.get_variables_by_name(mask_name)
    assert len(switch) == 1 and len(mask) == 1
    switch, mask = switch[0], mask[0]
    switch, mask = sess.run((switch, mask))

    masks[weight_name] = switch * mask

  return masks


def is_prunable_weight(weight):
  necessary_tokens = ["kernel", "DW", "variational"]
  blacklisted_tokens = ["logit", "fc", "init", "switch", "mask", "log_sigma"]

  contains_a_necessary_token = any(t in weight.name for t in necessary_tokens)
  contains_a_blacklisted_token = any(
      t in weight.name for t in blacklisted_tokens)

  is_prunable = contains_a_necessary_token and not contains_a_blacklisted_token

  if not is_prunable:
    print("WARN: not pruning %s" % weight.name)

  return is_prunable


def get_current_weights(sess):
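  """Fetch all prunable trainable variables, plus any variational log_alpha
  tensors, as numpy arrays keyed by name (with the ':0' suffix stripped)."""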
  weights = {}
  variables = {}
  for v in tf.trainable_variables():
    if is_prunable_weight(v):
      name = v.name.strip(":0")
      variables[name] = v

  graph = tf.get_default_graph()
  node_defs = [n for n in graph.as_graph_def().node if 'log_alpha' in n.name]

  for n in node_defs:
    weights[n.name] = sess.run(graph.get_tensor_by_name(n.name + ":0"))

  for weight_name, w in variables.items():
    weights[weight_name] = sess.run(w)

  return weights


def prune_sess_weights(sess, prune_percent, FLAGS, hparams):
  current_weights = get_current_weights(sess)
  # NOTE: assumes the "standard" pruning mode; prune_weights returns
  # (weights, counts), so unpack the pruned weights.
  prune_fn = get_prune_fn(FLAGS.prune)("standard", k=prune_percent)
  current_weights_pruned, _ = prune_weights(
      prune_fn, current_weights, hparams=hparams)

  print("there are ", len(tf.trainable_variables()), " weights")
  for v in tf.trainable_variables():
    if is_prunable_weight(v):
      assign_op = v.assign(
          np.reshape(current_weights_pruned[v.name.strip(":0")], v.shape))
      sess.run(assign_op)


================================================
FILE: train.py
================================================
import cloud
import os
import sys
import subprocess
import random
import tensorflow as tf
import numpy as np
import time
import logging

from .hparams.registry import get_hparams
from .models.registry import get_model
from .data.registry import get_input_fns
from .training.lr_schemes import get_lr
from .training.envs import get_env
from .training import flags
from tensorflow.contrib.tpu.python.tpu import tpu_config
from tensorflow.contrib.tpu.python.tpu import tpu_estimator


def init_flags():
  tf.flags.DEFINE_string("env", None, "Which environment to use.")  # required
  tf.flags.DEFINE_string("hparams", None, "Which hparams to use.")  # required
  # Utility flags
  tf.flags.DEFINE_string("hparam_override", "",
                         "Run-specific hparam settings to use.")
  tf.flags.DEFINE_boolean("fresh", False, "Remove output_dir before running.")
  tf.flags.DEFINE_integer("seed", None, "Random seed.")
  tf.flags.DEFINE_integer("train_epochs", None,
                          "Number of training epochs to perform.")
  tf.flags.DEFINE_integer("eval_steps", None,
                          "Number of evaluation steps to perform.")
  # TPU flags
  tf.flags.DEFINE_string("tpu_name", "", "Name of TPU(s)")
  tf.flags.DEFINE_integer(
      "tpu_iterations_per_loop", 1000,
      "The number of training steps to run on TPU before"
      "returning control to CPU.")
  tf.flags.DEFINE_integer(
      "tpu_shards", 8, "The number of TPU shards in the system "
      "(a single Cloud TPU has 8 shards.")
  tf.flags.DEFINE_boolean(
      "tpu_summarize", False, "Save summaries for TensorBoard. "
      "Warning: this will slow down execution.")
  tf.flags.DEFINE_boolean("tpu_dedicated", False,
                          "Do not use preemptible TPUs.")
  tf.flags.DEFINE_string("data_dir", None, "The data directory.")
  tf.flags.DEFINE_string("output_dir", None, "The output directory.")
  tf.flags.DEFINE_integer("eval_every", 1000,
                          "Number of steps between evaluations.")


tf.logging.set_verbosity(tf.logging.INFO)
FLAGS = None


def init_random_seeds():
  tf.set_random_seed(FLAGS.seed)
  random.seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)


def init_model(hparams_name):
  flags.validate_flags(FLAGS)

  tf.reset_default_graph()

  hparams = get_hparams(hparams_name)
  hparams = hparams.parse(FLAGS.hparam_override)
  hparams = flags.update_hparams(FLAGS, hparams, hparams_name)

  # set larger eval_every for TPUs to improve utilization
  if FLAGS.env == "tpu":
    FLAGS.eval_every = max(FLAGS.eval_every, 5000)
    hparams.tpu_summarize = FLAGS.tpu_summarize

  tf.logging.warn("\n-----------------------------------------\n"
                  "BEGINNING RUN:\n"
                  "\t hparams: %s\n"
                  "\t output_dir: %s\n"
                  "\t data_dir: %s\n"
                  "-----------------------------------------\n" %
                  (hparams_name, hparams.output_dir, hparams.data_dir))

  return hparams


def construct_estimator(model_fn, hparams, tpu=None):
  if hparams.use_tpu:
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        tpu=tpu.name)
    master = tpu_cluster_resolver.get_master()
    config = tpu_config.RunConfig(
        master=master,
        evaluation_master=master,
        model_dir=hparams.output_dir,
        session_config=tf.ConfigProto(
            allow_soft_placement=True, log_device_placement=True),
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=FLAGS.tpu_iterations_per_loop,
            num_shards=FLAGS.tpu_shards),
        save_checkpoints_steps=FLAGS.eval_every)
    estimator = tpu_estimator.TPUEstimator(
        use_tpu=hparams.use_tpu,
        model_fn=model_fn,
        model_dir=hparams.output_dir,
        config=config,
        train_batch_size=hparams.batch_size,
        eval_batch_size=hparams.batch_size)
  else:
    gpu_config = tf.ConfigProto(allow_soft_placement=True)
    gpu_config.gpu_options.allow_growth = True
    run_config = tf.estimator.RunConfig(
        save_checkpoints_steps=FLAGS.eval_every, session_config=gpu_config)

    estimator = tf.estimator.Estimator(
        model_fn=tf.contrib.estimator.replicate_model_fn(model_fn),
        model_dir=hparams.output_dir,
        config=run_config)

  return estimator


def _run(hparams_name):
  """Run training, evaluation and inference."""
  hparams = init_model(hparams_name)
  original_batch_size = hparams.batch_size
  if tf.gfile.Exists(hparams.output_dir) and FLAGS.fresh:
    tf.gfile.DeleteRecursively(hparams.output_dir)

  if not tf.gfile.Exists(hparams.output_dir):
    tf.gfile.MakeDirs(hparams.output_dir)
  model_fn = get_model(hparams)
  train_input_fn, eval_input_fn, test_input_fn = get_input_fns(hparams)

  tpu = None
  if hparams.use_tpu:
    cloud.instance.tpu.clean()
    tpu = cloud.instance.tpu.get(preemptible=not FLAGS.tpu_dedicated)

  estimator = construct_estimator(model_fn, hparams, tpu)

  if not hparams.use_tpu:
    features, labels = train_input_fn()
    sess = tf.Session()
    tf.train.get_or_create_global_step()

    model_fn(features, labels, tf.estimator.ModeKeys.TRAIN)
    sess.run(tf.global_variables_initializer())

  # output metadata about the run
  with tf.gfile.GFile(os.path.join(hparams.output_dir, 'hparams.txt'),
                      'w') as hparams_file:
    hparams_file.write("{}\n".format(time.time()))
    hparams_file.write("{}\n".format(str(hparams)))

  def loop(steps=FLAGS.eval_every):
    estimator.train(train_input_fn, steps=steps)
    if eval_input_fn:
      estimator.evaluate(eval_input_fn, steps=hparams.eval_steps, name="eval")
    if test_input_fn:
      estimator.evaluate(test_input_fn, steps=hparams.eval_steps, name="test")

  loop(1)

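  # global_step counts batches, so steps * batch_size / epoch_size is the
  # number of epochs completed so far.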
  steps = estimator.get_variable_value("global_step")
  k = steps * original_batch_size / float(hparams.epoch_size)
  while k <= hparams.train_epochs:
    tf.logging.info("Beginning epoch %f / %d" % (k, hparams.train_epochs))

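    # Preemptible TPUs can be reclaimed mid-run: request a fresh TPU and
    # rebuild the estimator against the same output_dir so training resumes
    # from the latest checkpoint.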
    if tpu and not tpu.usable:
      tpu.delete(async=True)
      tpu = cloud.instance.tpu.get(preemptible=not FLAGS.tpu_dedicated)
      estimator = construct_estimator(model_fn, hparams, tpu)

    loop()

    steps = estimator.get_variable_value("global_step")
    k = steps * original_batch_size / float(hparams.epoch_size)


def main(_):
  global FLAGS
  FLAGS = tf.app.flags.FLAGS

  init_random_seeds()
  if FLAGS.env != "local":
    cloud.connect()
  for hparams_name in FLAGS.hparams.split(","):
    _run(hparams_name)


if __name__ == "__main__":
  init_flags()
  tf.app.run()


================================================
FILE: training/__init__.py
================================================
__all__ = ["lr_schemes", "tpu", "flags"]

from .lr_schemes import *
from .tpu import *
from .flags import *


================================================
FILE: training/envs.py
================================================
_ENVS = dict()


def register(cls):
  global _ENVS
  _ENVS[cls.__name__.lower()] = cls()
  return cls


def get_env(name):
  return _ENVS[name]


@register
class GCP(object):
  data_dir = "/path/to/your/data"
  output_dir = "/path/to/your/output"


@register
class TPU(object):
  data_dir = "/path/to/your/data"
  output_dir = "/path/to/your/output"


@register
class Local(object):
  data_dir = "/tmp/data"
  output_dir = "/tmp/runs"


================================================
FILE: training/flags.py
================================================
import getpass
import os
import subprocess

import tensorflow as tf

from .envs import get_env


def validate_flags(FLAGS):
  messages = []
  if not FLAGS.env:
    messages.append("Missing required flag --env")
  if not FLAGS.hparams:
    messages.append("Missing required flag --hparams")

  if len(messages) > 0:
    raise Exception("\n".join(messages))

  return FLAGS


def update_hparams(FLAGS, hparams, hparams_name):
  hparams.env = FLAGS.env
  hparams.use_tpu = hparams.env == "tpu"
  hparams.train_epochs = FLAGS.train_epochs or hparams.train_epochs
  hparams.eval_steps = FLAGS.eval_steps or hparams.eval_steps

  env = get_env(FLAGS.env)
  hparams.data_dir = os.path.join(FLAGS.data_dir or env.data_dir, hparams.data)
  hparams.output_dir = os.path.join(env.output_dir, FLAGS.hparams)

  return hparams


================================================
FILE: training/lr_schemes.py
================================================
import tensorflow as tf

_LR = dict()


def register(name):

  def add_to_dict(fn):
    global _LR
    _LR[name] = fn
    return fn

  return add_to_dict


def get_lr(params):
  gs = tf.train.get_global_step()
  return _LR[params.lr_scheme](gs, params)


@register("constant")
def constant(gs, params):
  return tf.constant(params.learning_rate)


@register("exp")
def exponential_decay(gs, params, delay=0):
  gs -= delay
  return tf.train.exponential_decay(
      params.learning_rate,
      gs,
      params.learning_rate_decay_interval,
      params.learning_rate_decay_rate,
      staircase=params.staircased)


@register("lin")
def linear_decay(gs, params, delay=0):
  gs -= delay
  return (
      params.learning_rate -
      (tf.to_float(gs) / (params.train_steps - delay)) * params.learning_rate)


@register("delay_exp")
def delayed_exponential_decay(gs, params):
  d = params.delay
  return tf.cond(
      tf.greater(gs, d), lambda: exponential_decay(gs - d, params, delay=d),
      lambda: params.learning_rate)


@register("delay_lin")
def delayed_linear_decay(gs, params):
  d = params.delay
  return tf.cond(
      tf.greater(gs, d), lambda: linear_decay(gs - d, params, delay=d),
      lambda: params.learning_rate)


@register("warmup_resnet")
def warmup_resnet(gs, params):
  warmup_steps = params.warmup_steps
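  # Exponential warmup: inv_decay = 0.01 ** ((warmup_steps - gs) / warmup_steps)
  # ramps the learning rate from 1% of its target value up to 100% over the
  # first warmup_steps steps; the same pattern is reused by the other warmup_*
  # schemes below.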
  inv_base = tf.exp(tf.log(0.01) / warmup_steps)
  inv_decay = inv_base**(warmup_steps - tf.to_float(gs))

  epoch = params.epoch_size // params.batch_size
  boundaries = [epoch * 30, epoch * 60, epoch * 80, epoch * 90]
  values = [1e0, 1e-1, 1e-2, 1e-3, 1e-4]
  lr = tf.train.piecewise_constant(
      gs - warmup_steps, boundaries=boundaries, values=values)

  return tf.cond(
      tf.greater(gs, warmup_steps), lambda: lr,
      lambda: inv_decay * params.learning_rate)


@register("resnet")
def resnet(gs, params):
  return tf.cond(
      tf.less(gs, 40000),
      lambda: params.learning_rate,
      lambda: tf.cond(
          tf.less(gs, 60000),
          lambda: params.learning_rate*0.1,
          lambda: tf.cond(
              tf.less(gs, 80000),
              lambda: params.learning_rate * 0.01,
              lambda: params.learning_rate * 0.001)))


@register("lenet")
def lenet(gs, _):
  return tf.cond(
      tf.less(gs, 80000), lambda: 0.05,
      lambda: tf.cond(tf.less(gs, 120000), lambda: 0.005, lambda: 0.0005))


@register("steps")
def stepped_lr(gs, params):
  lr = params.lr_values[-1]
  for step, value in reversed(list(zip(params.lr_steps, params.lr_values))):
    lr = tf.cond(tf.greater(gs, step), lambda: lr, lambda: value)
  return lr


@register("warmup_linear_decay")
def warmup_linear_decay(gs, params):
  d = params.delay
  warmup_steps = params.warmup_steps
  inv_base = tf.exp(tf.log(0.01) / warmup_steps)
  inv_decay = inv_base**(warmup_steps - tf.to_float(gs))

  return tf.cond(
      tf.greater(gs, warmup_steps), lambda: linear_decay(gs, params, delay=d),
      lambda: inv_decay * params.learning_rate)


@register("warmup_constant")
def warmup_constant(gs, params):
  warmup_steps = params.warmup_steps
  inv_base = tf.exp(tf.log(0.01) / warmup_steps)
  inv_decay = inv_base**(warmup_steps - tf.to_float(gs))

  return tf.cond(
      tf.greater(gs, warmup_steps), lambda: constant(gs, params),
      lambda: inv_decay * params.learning_rate)


@register("warmup_exponential_decay")
def warmup_exponential_decay(gs, params):
  d = params.delay
  warmup_steps = params.warmup_steps
  inv_base = tf.exp(tf.log(0.01) / warmup_steps)
  inv_decay = inv_base**(warmup_steps - tf.to_float(gs))

  return tf.cond(
      tf.greater(gs, warmup_steps),
      lambda: exponential_decay(gs, params, delay=d),
      lambda: inv_decay * params.learning_rate)


@register("warmup_cosine")
def warmup_cosine(gs, params):
  from numpy import pi

  warmup_steps = params.warmup_steps
  inv_base = tf.exp(tf.log(0.01) / warmup_steps)
  inv_decay = inv_base**(warmup_steps - tf.to_float(gs))

  # Use a shifted copy of the global step for the cosine phase so the warmup
  # comparison below still sees the raw global step.
  decay_gs = tf.minimum(gs - warmup_steps,
                        params.learning_rate_cosine_cycle_steps)
  cosine_decay = 0.5 * (1 + tf.cos(
      pi * tf.to_float(decay_gs) / params.learning_rate_cosine_cycle_steps))
  decayed = (1 - params.cosine_alpha) * cosine_decay + params.cosine_alpha
  lr = params.learning_rate * decayed

  return tf.cond(
      tf.greater(gs, warmup_steps), lambda: lr,
      lambda: inv_decay * params.learning_rate)


@register("cosine")
def cosine_annealing(gs, params):
  from numpy import pi

  gs = tf.minimum(gs, params.learning_rate_cosine_cycle_steps)
  cosine_decay = 0.5 * (1 + tf.cos(
      pi * tf.to_float(gs) / params.learning_rate_cosine_cycle_steps))
  decayed = (1 - params.cosine_alpha) * cosine_decay + params.cosine_alpha
  decayed_learning_rate = params.learning_rate * decayed

  return decayed_learning_rate


================================================
FILE: training/tpu.py
================================================
import collections
import six

import tensorflow as tf


def remove_summaries():
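  """Clear the graph's SUMMARIES collection.

  Summary ops are not supported in the TPU training loop; scalar/histogram
  summaries are instead replayed on the host via create_host_call below.
  """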
  g = tf.get_default_graph()
  key = tf.GraphKeys.SUMMARIES
  del g.get_collection_ref(key)[:]
  assert not g.get_collection(key)


# From Tensor2Tensor
def create_host_call(model_dir):
  """Construct a host_call writing scalar summaries.
  Args:
    model_dir: String containing path to train
  Returns:
    (fn, args) Pair to be called by TPUEstimator as the host_call.
  """
  graph = tf.get_default_graph()
  summaries = graph.get_collection(tf.GraphKeys.SUMMARIES)

  gs_t = tf.reshape(tf.to_int32(tf.train.get_global_step()), [1])
  summary_kwargs = collections.OrderedDict()
  for t in summaries:
    if t.op.type not in ["ScalarSummary", "HistogramSummary"]:
      tf.logging.warn("Ignoring unsupported tf.Summary type %s" % t.op.type)
      continue

    name = t.op.name
    tensor = t.op.inputs[1]
    if t.op.type == "ScalarSummary":
      assert tensor.shape.is_compatible_with([])
      if tensor.dtype == tf.int64:
        tensor = tf.to_int32(tensor)
      summary_kwargs["ScalarSummary" + name] = tf.reshape(tensor, [1])
    elif t.op.type == "HistogramSummary":
      summary_kwargs["HistogramSummary" + name] = tf.reshape(tensor, [-1])
  # When no supported summaries are found, don't create host_call. Otherwise,
  # TPU outfeed queue would enqueue global_step while host_call doesn't dequeue
  # it, eventually causing hang.
  if not summary_kwargs:
    return None
  summary_kwargs["global_step"] = gs_t

  def host_call_fn(**kwargs):
    """Training host call. Creates summaries for training metrics.
    Args:
      **kwargs: Dict of {str: Tensor}, with `Tensor` of shape `[batch]`. Must
        contain key "global_step" with value of current global_step Tensor.
    Returns:
      List of summary ops to run on the CPU host.
    """
    gs = tf.to_int64(kwargs.pop("global_step")[0])
    with tf.contrib.summary.create_file_writer(model_dir).as_default():
      with tf.contrib.summary.always_record_summaries():
        # We need to use tf.contrib.summary in order to feed the `step`.
        for name, value in sorted(six.iteritems(kwargs)):
          if name.startswith("ScalarSummary"):
            name = name[len("ScalarSummary"):]
            tf.contrib.summary.scalar(
                name, tf.reduce_mean(tf.to_float(value)), step=gs)
          elif name.startswith("HistogramSummary"):
            name = name[len("HistogramSummary"):]
            tf.contrib.summary.histogram(name, value, step=gs)
          elif name.startswith("ImageSummary"):
            name = name[len("ImageSummary"):]
            tf.contrib.summary.image(name, value, step=gs)

        return tf.contrib.summary.all_summary_ops()

  return (host_call_fn, summary_kwargs)
SYMBOL INDEX (288 symbols across 31 files)

FILE: data/data_generators/cifar_generator.py
  function download (line 35) | def download(v100):
  function maybe_download (line 50) | def maybe_download(files, v100):
  function read_files (line 59) | def read_files(files, v100):
  function cifar_generator (line 86) | def cifar_generator(v100, mode):
  function generate (line 120) | def generate(train_name, eval_name, test_name, hparams):

FILE: data/data_generators/generator_utils.py
  function to_example (line 12) | def to_example(dictionary):
  function generate_files (line 28) | def generate_files(generator,

FILE: data/data_generators/mnist_generator.py
  function download_files (line 25) | def download_files(filenames):
  function read_images (line 43) | def read_images(filepath, num_images):
  function read_labels (line 52) | def read_labels(filepath, num_labels):
  function mnist_generator (line 60) | def mnist_generator(mode):
  function generate (line 96) | def generate(train_name, eval_name, test_name, hparams):

FILE: data/dataset_maps.py
  function register (line 7) | def register(fn):
  function get_augmentation (line 13) | def get_augmentation(name, params, training):
  function cifar_augmentation (line 23) | def cifar_augmentation(image, label, training, params):
  function imagenet_augmentation (line 40) | def imagenet_augmentation(image, label, training, params):
  function load_images (line 55) | def load_images(example, training, params):
  function set_shapes (line 72) | def set_shapes(image, label, training, params):
  function transpose (line 76) | def transpose(image, label, training, params):

FILE: data/image_reader.py
  function image_reader (line 13) | def image_reader(data_sources, hparams, training):
  function mnist_simple (line 53) | def mnist_simple(data_source, params, training):
  function fashion (line 77) | def fashion(data_source, params, training):

FILE: data/imagenet_augs.py
  function _crop (line 10) | def _crop(image, offset_height, offset_width, crop_height, crop_width):
  function distorted_bounding_box_crop (line 48) | def distorted_bounding_box_crop(image,
  function _random_crop (line 104) | def _random_crop(image, size):
  function _flip (line 123) | def _flip(image):
  function _at_least_x_are_true (line 129) | def _at_least_x_are_true(a, b, x):
  function _do_scale (line 136) | def _do_scale(image, size):
  function _center_crop (line 147) | def _center_crop(image, size):
  function _normalize (line 158) | def _normalize(image):
  function preprocess_for_train (line 168) | def preprocess_for_train(image, image_size=224):
  function preprocess_for_eval (line 183) | def preprocess_for_eval(image, image_size=224):

FILE: data/registry.py
  function register (line 9) | def register(name, generator):
  function get_input_fns (line 21) | def get_input_fns(hparams, generate=True):
  function get_dataset (line 50) | def get_dataset(hparams):
  function maybe_generate (line 60) | def maybe_generate(check_path, hparams):

FILE: hparams/basic.py
  function mnist_basic_no_dropout (line 9) | def mnist_basic_no_dropout():
  function mnist_basic_trgtd_dropout (line 31) | def mnist_basic_trgtd_dropout():
  function mnist_basic_untrgtd_dropout (line 41) | def mnist_basic_untrgtd_dropout():
  function mnist_basic_trgtd_dropout_random (line 50) | def mnist_basic_trgtd_dropout_random():
  function mnist_basic_trgtd_unit_dropout (line 60) | def mnist_basic_trgtd_unit_dropout():
  function mnist_basic_smallify_dropout_1eneg4 (line 70) | def mnist_basic_smallify_dropout_1eneg4():
  function mnist_basic_smallify_dropout_1eneg3 (line 81) | def mnist_basic_smallify_dropout_1eneg3():
  function mnist_basic_smallify_weight_dropout_1eneg4 (line 89) | def mnist_basic_smallify_weight_dropout_1eneg4():
  function cifar10_basic_no_dropout (line 100) | def cifar10_basic_no_dropout():
  function cifar100_basic_no_dropout (line 123) | def cifar100_basic_no_dropout():
  function imagenet32_basic (line 131) | def imagenet32_basic():

FILE: hparams/defaults.py
  function default (line 8) | def default():
  function default_cifar10 (line 64) | def default_cifar10():
  function default_cifar100 (line 79) | def default_cifar100():
  function default_imagenet299 (line 89) | def default_imagenet299():
  function default_imagenet224 (line 104) | def default_imagenet224():
  function default_imagenet64 (line 112) | def default_imagenet64():
  function default_imagenet32 (line 120) | def default_imagenet32():

FILE: hparams/lenet.py
  function cifar_lenet (line 10) | def cifar_lenet():
  function cifar_lenet_no_dropout (line 39) | def cifar_lenet_no_dropout():
  function cifar_lenet_weight (line 45) | def cifar_lenet_weight():
  function cifar_lenet_trgtd_weight (line 53) | def cifar_lenet_trgtd_weight():
  function cifar_lenet_unit (line 62) | def cifar_lenet_unit():
  function cifar_lenet_trgtd_unit (line 70) | def cifar_lenet_trgtd_unit():
  function cifar_lenet_l1 (line 79) | def cifar_lenet_l1():
  function cifar_lenet_trgtd_weight_l1 (line 86) | def cifar_lenet_trgtd_weight_l1():
  function cifar_lenet_trgtd_unit_l1 (line 96) | def cifar_lenet_trgtd_unit_l1():
  function cifar_lenet_trgtd_unit_botk75_33 (line 106) | def cifar_lenet_trgtd_unit_botk75_33():
  function cifar_lenet_trgtd_unit_botk75_66 (line 115) | def cifar_lenet_trgtd_unit_botk75_66():
  function cifar_lenet_trgtd_weight_botk75_33 (line 124) | def cifar_lenet_trgtd_weight_botk75_33():
  function cifar_lenet_trgtd_weight_botk75_66 (line 133) | def cifar_lenet_trgtd_weight_botk75_66():
  function cifar_lenet_louizos_weight_1en3 (line 142) | def cifar_lenet_louizos_weight_1en3():
  function cifar_lenet_louizos_weight_1en1 (line 154) | def cifar_lenet_louizos_weight_1en1():
  function cifar_lenet_louizos_weight_1en2 (line 166) | def cifar_lenet_louizos_weight_1en2():
  function cifar_lenet_louizos_weight_5en3 (line 178) | def cifar_lenet_louizos_weight_5en3():
  function cifar_lenet_louizos_weight_1en4 (line 190) | def cifar_lenet_louizos_weight_1en4():
  function cifar_lenet_louizos_unit_1en3 (line 202) | def cifar_lenet_louizos_unit_1en3():
  function cifar_lenet_louizos_unit_1en1 (line 214) | def cifar_lenet_louizos_unit_1en1():
  function cifar_lenet_louizos_unit_1en2 (line 226) | def cifar_lenet_louizos_unit_1en2():
  function cifar_lenet_louizos_unit_5en3 (line 238) | def cifar_lenet_louizos_unit_5en3():
  function cifar_lenet_louizos_unit_1en4 (line 250) | def cifar_lenet_louizos_unit_1en4():
  function cifar_lenet_variational (line 262) | def cifar_lenet_variational():
  function cifar_lenet_variational_unscaled (line 272) | def cifar_lenet_variational_unscaled():
  function cifar_lenet_variational_unit (line 281) | def cifar_lenet_variational_unit():
  function cifar_lenet_variational_unit_unscaled (line 291) | def cifar_lenet_variational_unit_unscaled():
  function cifar_lenet_smallify_neg4 (line 300) | def cifar_lenet_smallify_neg4():

FILE: hparams/registry.py
  function register (line 6) | def register(fn):
  function get_hparams (line 12) | def get_hparams(hparams_list):

FILE: hparams/resnet.py
  function resnet_default (line 9) | def resnet_default():
  function resnet102_imagenet224 (line 24) | def resnet102_imagenet224():
  function resnet102_imagenet64 (line 44) | def resnet102_imagenet64():
  function resnet50_imagenet224 (line 51) | def resnet50_imagenet224():
  function resnet34_imagenet224 (line 58) | def resnet34_imagenet224():
  function resnet_cifar100 (line 65) | def resnet_cifar100():
  function cifar10_resnet32 (line 72) | def cifar10_resnet32():
  function cifar10_resnet32_no_dropout (line 79) | def cifar10_resnet32_no_dropout():
  function cifar10_resnet32_trgtd_weight (line 87) | def cifar10_resnet32_trgtd_weight():
  function cifar10_resnet32_weight (line 97) | def cifar10_resnet32_weight():
  function cifar10_resnet32_weight_50 (line 106) | def cifar10_resnet32_weight_50():
  function cifar10_resnet32_trgtd_unit (line 114) | def cifar10_resnet32_trgtd_unit():
  function cifar10_resnet32_trgtd_ard (line 124) | def cifar10_resnet32_trgtd_ard():
  function cifar10_resnet32_unit (line 134) | def cifar10_resnet32_unit():
  function cifar10_resnet32_unit_50 (line 143) | def cifar10_resnet32_unit_50():
  function cifar10_resnet32_l1_1eneg3 (line 151) | def cifar10_resnet32_l1_1eneg3():
  function cifar10_resnet32_l1_1eneg2 (line 159) | def cifar10_resnet32_l1_1eneg2():
  function cifar10_resnet32_l1_1eneg1 (line 167) | def cifar10_resnet32_l1_1eneg1():
  function cifar10_resnet32_trgted_weight_l1 (line 175) | def cifar10_resnet32_trgted_weight_l1():
  function cifar10_resnet32_targeted_unit_l1 (line 186) | def cifar10_resnet32_targeted_unit_l1():
  function cifar10_resnet32_trgtd_unit_botk75_33 (line 197) | def cifar10_resnet32_trgtd_unit_botk75_33():
  function cifar10_resnet32_trgtd_unit_botk75_66 (line 207) | def cifar10_resnet32_trgtd_unit_botk75_66():
  function cifar10_resnet32_trgtd_weight_botk75_33 (line 217) | def cifar10_resnet32_trgtd_weight_botk75_33():
  function cifar10_resnet32_trgtd_weight_botk75_66 (line 227) | def cifar10_resnet32_trgtd_weight_botk75_66():
  function cifar10_resnet32_trgtd_unit_ramping_botk90_99 (line 237) | def cifar10_resnet32_trgtd_unit_ramping_botk90_99():
  function cifar10_resnet32_trgtd_weight_ramping_botk99_99 (line 247) | def cifar10_resnet32_trgtd_weight_ramping_botk99_99():
  function cifar10_resnet32_louizos_weight_1en3 (line 258) | def cifar10_resnet32_louizos_weight_1en3():
  function cifar10_resnet32_louizos_weight_1en1 (line 271) | def cifar10_resnet32_louizos_weight_1en1():
  function cifar10_resnet32_louizos_weight_1en2 (line 280) | def cifar10_resnet32_louizos_weight_1en2():
  function cifar10_resnet32_louizos_weight_5en3 (line 288) | def cifar10_resnet32_louizos_weight_5en3():
  function cifar10_resnet32_louizos_weight_1en4 (line 296) | def cifar10_resnet32_louizos_weight_1en4():
  function cifar10_resnet32_louizos_unit_1en3 (line 304) | def cifar10_resnet32_louizos_unit_1en3():
  function cifar10_resnet32_louizos_unit_1en1 (line 317) | def cifar10_resnet32_louizos_unit_1en1():
  function cifar10_resnet32_louizos_unit_1en2 (line 325) | def cifar10_resnet32_louizos_unit_1en2():
  function cifar10_resnet32_louizos_unit_5en3 (line 333) | def cifar10_resnet32_louizos_unit_5en3():
  function cifar10_resnet32_louizos_unit_1en4 (line 341) | def cifar10_resnet32_louizos_unit_1en4():
  function cifar10_resnet32_louizos_unit_1en5 (line 349) | def cifar10_resnet32_louizos_unit_1en5():
  function cifar10_resnet32_louizos_unit_1en6 (line 357) | def cifar10_resnet32_louizos_unit_1en6():
  function cifar10_resnet32_variational_weight (line 365) | def cifar10_resnet32_variational_weight():
  function cifar10_resnet32_variational_weight_unscaled (line 377) | def cifar10_resnet32_variational_weight_unscaled():
  function cifar10_resnet32_variational_unit (line 389) | def cifar10_resnet32_variational_unit():
  function cifar10_resnet32_variational_unit_unscaled (line 401) | def cifar10_resnet32_variational_unit_unscaled():
  function cifar10_resnet32_smallify_1eneg4 (line 413) | def cifar10_resnet32_smallify_1eneg4():
  function cifar10_resnet32_smallify_1eneg3 (line 424) | def cifar10_resnet32_smallify_1eneg3():
  function cifar10_resnet32_smallify_1eneg5 (line 432) | def cifar10_resnet32_smallify_1eneg5():
  function cifar10_resnet32_smallify_1eneg6 (line 440) | def cifar10_resnet32_smallify_1eneg6():
  function cifar10_resnet32_smallify_weight_1eneg4 (line 448) | def cifar10_resnet32_smallify_weight_1eneg4():
  function cifar10_resnet32_smallify_weight_1eneg3 (line 459) | def cifar10_resnet32_smallify_weight_1eneg3():
  function cifar10_resnet32_smallify_weight_1eneg5 (line 467) | def cifar10_resnet32_smallify_weight_1eneg5():
  function cifar10_resnet32_smallify_weight_1eneg6 (line 475) | def cifar10_resnet32_smallify_weight_1eneg6():

FILE: hparams/utils.py
  class HParams (line 4) | class HParams(tf.contrib.training.HParams):
    method __setattr__ (line 13) | def __setattr__(self, name, value):

FILE: hparams/vgg.py
  function vgg16_default (line 9) | def vgg16_default():
  function cifar10_vgg16 (line 27) | def cifar10_vgg16():
  function cifar100_vgg16_no_dropout (line 34) | def cifar100_vgg16_no_dropout():
  function cifar10_vgg16_no_dropout (line 47) | def cifar10_vgg16_no_dropout():
  function cifar100_vgg16_targeted_dropout (line 60) | def cifar100_vgg16_targeted_dropout():
  function cifar100_vgg16_untargeted_dropout (line 69) | def cifar100_vgg16_untargeted_dropout():
  function cifar100_vgg16_untargeted_unit_dropout (line 77) | def cifar100_vgg16_untargeted_unit_dropout():
  function cifar100_vgg16_targeted_unit_dropout (line 85) | def cifar100_vgg16_targeted_unit_dropout():
  function cifar100_vgg16_targeted_unit_dropout_botk75_66 (line 94) | def cifar100_vgg16_targeted_unit_dropout_botk75_66():
  function cifar100_vgg16_louizos_unit (line 102) | def cifar100_vgg16_louizos_unit():
  function cifar100_vgg16_louizos_weight (line 115) | def cifar100_vgg16_louizos_weight():
  function cifar100_vgg16_variational_unscaled (line 123) | def cifar100_vgg16_variational_unscaled():
  function cifar100_vgg16_variational (line 135) | def cifar100_vgg16_variational():
  function cifar100_vgg16_variational_unit_unscaled (line 143) | def cifar100_vgg16_variational_unit_unscaled():
  function cifar100_vgg16_variational_unit (line 151) | def cifar100_vgg16_variational_unit():
  function cifar100_vgg16_smallify_1eneg4 (line 159) | def cifar100_vgg16_smallify_1eneg4():
  function cifar100_vgg16_smallify_weight_1eneg5 (line 170) | def cifar100_vgg16_smallify_weight_1eneg5():

FILE: models/basic/basic.py
  function get_basic (line 12) | def get_basic(params, lr):

FILE: models/lenet/lenet.py
  function get_lenet (line 13) | def get_lenet(hparams, lr):

FILE: models/registry.py
  function register (line 8) | def register(name):
  function get_model (line 18) | def get_model(hparams):

FILE: models/resnet/resnet.py
  function get_resnet (line 15) | def get_resnet(hparams, lr):

FILE: models/utils/activations.py
  function register (line 6) | def register(name):
  function get_activation (line 16) | def get_activation(params):
  function relu (line 21) | def relu(params):
  function brelu (line 26) | def brelu(params):
  function selu (line 42) | def selu(params):
  function elu (line 47) | def elu(params):
  function sigmoid (line 52) | def sigmoid(params):
  function swish (line 57) | def swish(params):
  function tanh (line 62) | def tanh(params):

FILE: models/utils/dropouts.py
  function register (line 7) | def register(name):
  function get_dropout (line 17) | def get_dropout(name):
  function targeted_weight_dropout (line 23) | def targeted_weight_dropout(w, params, is_training):
  function targeted_weight_random (line 47) | def targeted_weight_random(w, params, is_training):
  function ramping_targeted_weight_random (line 75) | def ramping_targeted_weight_random(w, params, is_training):
  function targeted_weight_piecewise_dropout (line 109) | def targeted_weight_piecewise_dropout(w, params, is_training):
  function targeted_unit_piecewise (line 142) | def targeted_unit_piecewise(w, params, is_training):
  function delayed_targeted_weight (line 176) | def delayed_targeted_weight(w, params, is_training):
  function delayed_targeted_unit (line 193) | def delayed_targeted_unit(x, params, is_training):
  function untargeted_weight (line 210) | def untargeted_weight(w, params, is_training):
  function targeted_unit_dropout (line 217) | def targeted_unit_dropout(x, params, is_training):
  function targeted_unit_random (line 241) | def targeted_unit_random(w, params, is_training):
  function targeted_ard_dropout (line 269) | def targeted_ard_dropout(w, x, params, is_training):
  function unit_dropout (line 291) | def unit_dropout(w, params, is_training):
  function louizos_weight_dropout (line 303) | def louizos_weight_dropout(w, params, is_training):
  function louizos_unit_dropout (line 327) | def louizos_unit_dropout(w, params, is_training):
  function log_sigma2_variable (line 352) | def log_sigma2_variable(shape, ard_init=-10.):
  function get_log_alpha (line 358) | def get_log_alpha(log_sigma2, w):
  function paranoid_log (line 364) | def paranoid_log(x, eps=1e-8):
  function clip (line 370) | def clip(x):
  function dkl_qp (line 374) | def dkl_qp(log_alpha):
  function variational_dropout (line 383) | def variational_dropout(w, _, is_training):
  function variational_unit_dropout (line 396) | def variational_unit_dropout(w, _, is_training):
  function smallify_dropout (line 411) | def smallify_dropout(x, hparams, is_training):
  function smallify_weight_dropout (line 453) | def smallify_weight_dropout(x, hparams, is_training):

FILE: models/utils/initializations.py
  function register (line 6) | def register(name):
  function get_init (line 16) | def get_init(params):
  function normal (line 21) | def normal(params):
  function constant (line 26) | def constant(params):
  function uniform_unit_scaling (line 31) | def uniform_unit_scaling(params):
  function glorot_normal_initializer (line 36) | def glorot_normal_initializer(params):
  function glorot_uniform_initializer (line 41) | def glorot_uniform_initializer(params):
  function variance_scaling_initializer (line 46) | def variance_scaling_initializer(params):
  class RandomUnitScaling (line 50) | class RandomUnitScaling(tf.keras.initializers.Initializer):
    method __call__ (line 52) | def __call__(self, shape, dtype=None, partition_info=None):
  class RandomHadamardConstant (line 63) | class RandomHadamardConstant(tf.keras.initializers.Initializer):
    method __call__ (line 65) | def __call__(self, shape, dtype=None, partition_info=None):
  class RandomHadamardUnscaled (line 73) | class RandomHadamardUnscaled(tf.keras.initializers.Initializer):
    method __call__ (line 75) | def __call__(self, shape, dtype=None, partition_info=None):
  class RandomWarpedUniform (line 79) | class RandomWarpedUniform(tf.keras.initializers.Initializer):
    method __init__ (line 81) | def __init__(self, k=2):
    method __call__ (line 84) | def __call__(self, shape, dtype=None, partition_info=None):
  function warped_unif (line 100) | def warped_unif(params):
  function unit_scaling (line 105) | def unit_scaling(params):
  function hadamard_constant (line 110) | def hadamard_constant(params):
  function hadamard_unscaled (line 115) | def hadamard_unscaled(params):

FILE: models/utils/model_utils.py
  class ModeKeys (line 12) | class ModeKeys(object):
  function collect_vars (line 20) | def collect_vars(fn):
  function dense (line 38) | def dense(x, units, hparams, is_training, dropout=True):
  function conv (line 54) | def conv(x,
  function weight_decay_and_noise (line 123) | def weight_decay_and_noise(loss, hparams, learning_rate, var_list=None):
  function weight_noise (line 136) | def weight_noise(hparams, learning_rate):
  function weight_decay (line 156) | def weight_decay(hparams):
  function axis_aligned_cost (line 177) | def axis_aligned_cost(logits, hparams):
  function ard_cost (line 199) | def ard_cost():
  function shape_list (line 214) | def shape_list(x):
  function standardize_images (line 233) | def standardize_images(x):
  function batch_norm (line 247) | def batch_norm(inputs, hparams, training):
  function louizos_complexity_cost (line 267) | def louizos_complexity_cost(params):
  function switch_loss (line 302) | def switch_loss():
  function nonzero_count (line 313) | def nonzero_count():
  function percent_sparsity (line 324) | def percent_sparsity():
  function convert (line 337) | def convert(num, base, length=None):
  function equal_mult (line 357) | def equal_mult(size, num_branches):
  function uniform (line 364) | def uniform(size, num_branches):
  function bernoulli (line 371) | def bernoulli(size, num_branches):
  function combine (line 379) | def combine(rand_uniform, rand_bernoulli, num_branches):
  function model_top (line 386) | def model_top(labels, preds, cost, lr, mode, hparams):

FILE: models/utils/optimizers.py
  function register (line 6) | def register(name):
  function get_optimizer (line 16) | def get_optimizer(lr, params):
  function sgd (line 24) | def sgd(lr, params):
  function adam (line 29) | def adam(lr, params):
  function adagrad (line 34) | def adagrad(lr, params):
  function momentum (line 39) | def momentum(lr, params):

FILE: models/vgg/vgg.py
  function metric_fn (line 17) | def metric_fn(labels, predictions):
  function get_vgg (line 27) | def get_vgg(hparams, lr):

FILE: scripts/prune/eval.py
  function init_flags (line 12) | def init_flags():
  function eval_model (line 39) | def eval_model(FLAGS, hparam_name):
  function _run (line 105) | def _run(FLAGS):

FILE: scripts/prune/prune.py
  function register (line 9) | def register(fn):
  function get_prune_fn (line 15) | def get_prune_fn(name):
  function weight (line 20) | def weight(mode, k=0.5):
  function unit (line 75) | def unit(mode, k=0.5):
  function ard (line 108) | def ard(k=0.5):
  function prune_weights (line 126) | def prune_weights(prune_fn,
  function get_louizos_masks (line 171) | def get_louizos_masks(sess, weights):
  function get_smallify_masks (line 183) | def get_smallify_masks(sess, weights):
  function is_prunable_weight (line 199) | def is_prunable_weight(weight):
  function get_current_weights (line 215) | def get_current_weights(sess):
  function prune_sess_weights (line 235) | def prune_sess_weights(sess, prune_percent, FLAGS, hparams):
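
The registered pruners (weight, unit, ard) take a mode and a target fraction k. For magnitude-based weight pruning, the standard recipe is to zero the fraction k of entries with the smallest absolute value; unit pruning applies the same idea to whole columns or filters. A numpy sketch of the weight-level case, as an illustration rather than the repo's exact code:

import numpy as np

def magnitude_prune(w, k=0.5):
  # Zero out roughly the k smallest-magnitude entries of a weight array.
  threshold = np.percentile(np.abs(w), k * 100)
  mask = (np.abs(w) >= threshold).astype(w.dtype)
  return w * mask

Targeted dropout, the subject of this repo, trains networks to be robust to exactly this kind of post-hoc pruning step.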

FILE: train.py
  function init_flags (line 21) | def init_flags():
  function init_random_seeds (line 57) | def init_random_seeds():
  function init_model (line 63) | def init_model(hparams_name):
  function construct_estimator (line 88) | def construct_estimator(model_fn, hparams, tpu=None):
  function _run (line 124) | def _run(hparams_name):
  function main (line 182) | def main(_):
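
construct_estimator(model_fn, hparams, tpu=None) suggests a switch between a plain tf.estimator.Estimator locally and a tf.contrib.tpu.TPUEstimator when a TPU environment is supplied. A hedged sketch of that branch; the tpu object's attributes and the batch-size hparam name are assumptions:

import tensorflow as tf

def construct_estimator(model_fn, hparams, tpu=None):
  if tpu is None:
    return tf.estimator.Estimator(model_fn=model_fn)
  run_config = tf.contrib.tpu.RunConfig(
      master=tpu.master,  # assumed attribute on the TPU env object
      tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=100))
  return tf.contrib.tpu.TPUEstimator(
      model_fn=model_fn,
      config=run_config,
      train_batch_size=hparams.batch_size,
      eval_batch_size=hparams.batch_size)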

FILE: training/envs.py
  function register (line 4) | def register(cls):
  function get_env (line 10) | def get_env(name):
  class GCP (line 15) | class GCP(object):
  class TPU (line 21) | class TPU(object):
  class Local (line 27) | class Local(object):
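
training/envs.py registers environment classes by lowercased class name; its preview below shows _ENVS[cls.__name__.lower()] = cls(). Completed into a runnable sketch (the get_env body and the empty class are assumptions):

_ENVS = dict()

def register(cls):
  # Key an instance of each environment under its lowercased class name,
  # so a name like "local" resolves to the Local() instance.
  global _ENVS
  _ENVS[cls.__name__.lower()] = cls()
  return cls

def get_env(name):
  return _ENVS[name]

@register
class Local(object):
  pass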

FILE: training/flags.py
  function validate_flags (line 10) | def validate_flags(FLAGS):
  function update_hparams (line 23) | def update_hparams(FLAGS, hparams, hparams_name):

FILE: training/lr_schemes.py
  function register (line 6) | def register(name):
  function get_lr (line 16) | def get_lr(params):
  function constant (line 22) | def constant(gs, params):
  function exponential_decay (line 27) | def exponential_decay(gs, params, delay=0):
  function linear_decay (line 38) | def linear_decay(gs, params, delay=0):
  function delayed_exponential_decay (line 46) | def delayed_exponential_decay(gs, params):
  function delayed_linear_decay (line 54) | def delayed_linear_decay(gs, params):
  function warmup_resnet (line 62) | def warmup_resnet(gs, params):
  function resnet (line 79) | def resnet(gs, params):
  function lenet (line 93) | def lenet(gs, _):
  function stepped_lr (line 100) | def stepped_lr(gs, params):
  function warmup_linear_decay (line 108) | def warmup_linear_decay(gs, params):
  function warmup_constant (line 120) | def warmup_constant(gs, params):
  function warmup_exponential_decay (line 131) | def warmup_exponential_decay(gs, params):
  function warmup_cosine (line 144) | def warmup_cosine(gs, params):
  function cosine_annealing (line 163) | def cosine_annealing(gs, params):
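
The schedule functions all share the (gs, params) signature, where gs is the global-step tensor. As an example of the family, cosine annealing conventionally decays the base rate to zero over the training horizon; a sketch with assumed hparam names (learning_rate, train_steps):

import math
import tensorflow as tf

def cosine_annealing(gs, params):
  # Half-cosine decay from params.learning_rate at step 0 to ~0 at the end.
  progress = tf.cast(gs, tf.float32) / float(params.train_steps)
  return 0.5 * params.learning_rate * (1.0 + tf.cos(math.pi * progress))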

FILE: training/tpu.py
  function remove_summaries (line 7) | def remove_summaries():
  function create_host_call (line 15) | def create_host_call(model_dir):
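
remove_summaries is a standard TPU workaround: tf.summary ops are unsupported inside a TPU model_fn, so the graph's SUMMARIES collection is cleared and create_host_call re-emits the summaries on the host. The preview below shows the function beginning with g = tf.get_default_graph(); a likely completion:

import tensorflow as tf

def remove_summaries():
  g = tf.get_default_graph()
  key = tf.GraphKeys.SUMMARIES
  # Empty the collection in place so no summary ops reach the TPU rewrite.
  del g.get_collection_ref(key)[:]
  assert not g.get_collection(key)
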
Condensed preview: 50 files, each listing path, character count, and a content snippet; the full structured JSON runs to 149K characters.
[
  {
    "path": ".gitignore",
    "chars": 1225,
    "preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
  },
  {
    "path": ".travis.yml",
    "chars": 2165,
    "preview": "language: python\npython:\n  - \"3.6\"\n\n# command to install dependencies\ninstall:\n  - pip install -r requirements.txt\n\n# co"
  },
  {
    "path": "README.md",
    "chars": 733,
    "preview": "# Targeted Dropout\n\nAidan N. Gomez, Ivan Zhang, Kevin Swersky, Yarin Gal, and Geoffrey E. Hinton\n\n## Table of Contents\n-"
  },
  {
    "path": "__init__.py",
    "chars": 142,
    "preview": "__all__ = [\"data\", \"hparams\", \"models\", \"training\"]\n\nfrom .data import *\nfrom .hparams import *\nfrom .models import *\nfr"
  },
  {
    "path": "data/__init__.py",
    "chars": 70,
    "preview": "__all__ = [\n    \"image_reader\",\n    \"registry\",\n    \"dataset_maps\",\n]\n"
  },
  {
    "path": "data/data_generators/__init__.py",
    "chars": 83,
    "preview": "__all__ = [\n    \"cifar_generator\",\n    \"generator_utils\",\n    \"mnist_generator\",\n]\n"
  },
  {
    "path": "data/data_generators/cifar_generator.py",
    "chars": 3596,
    "preview": "try:\n  import cPickle\nexcept ImportError:\n  import pickle as cPickle\nimport os\nimport random\nimport sys\nimport tarfile\ni"
  },
  {
    "path": "data/data_generators/generator_utils.py",
    "chars": 1833,
    "preview": "import operator\nimport os\nimport numpy as np\nimport tensorflow as tf\n\ntf.flags.DEFINE_boolean(\"v100\", False,\n           "
  },
  {
    "path": "data/data_generators/mnist_generator.py",
    "chars": 3060,
    "preview": "import gzip\nimport os\nimport random\nimport urllib\nimport numpy as np\nimport tensorflow as tf\n\nfrom .generator_utils impo"
  },
  {
    "path": "data/dataset_maps.py",
    "chars": 2064,
    "preview": "import tensorflow as tf\nfrom . import imagenet_augs \n\n_AUGMENTATIONS = dict()\n\n\ndef register(fn):\n  global _AUGMENTATION"
  },
  {
    "path": "data/image_reader.py",
    "chars": 3271,
    "preview": "import tensorflow as tf\nfrom tensorflow.examples.tutorials.mnist import input_data\n\nfrom .registry import register\nfrom "
  },
  {
    "path": "data/imagenet_augs.py",
    "chars": 7543,
    "preview": "import tensorflow as tf\n\nMEAN_RGB = [0.485, 0.456, 0.406]\nSTDDEV_RGB = [0.229, 0.224, 0.225]\n\n\n# The following preproces"
  },
  {
    "path": "data/registry.py",
    "chars": 1918,
    "preview": "import os\n\nimport tensorflow as tf\n\n_INPUT_FNS = dict()\n_GENERATORS = dict()\n\n\ndef register(name, generator):\n\n  def add"
  },
  {
    "path": "hparams/__init__.py",
    "chars": 273,
    "preview": "__all__ = [\"defaults\", \"registry\", \"resnet\", \"lenet\", \"utils\", \"vgg\", \"basic\"]\n\nfrom .defaults import *\nfrom .resnet imp"
  },
  {
    "path": "hparams/basic.py",
    "chars": 2875,
    "preview": "import tensorflow as tf\n\nfrom . import defaults\nfrom .registry import register\n\n\n# MNIST =========================\n@regi"
  },
  {
    "path": "hparams/defaults.py",
    "chars": 2476,
    "preview": "import tensorflow as tf\n\nfrom .registry import register\nfrom .utils import HParams\n\n\n@register\ndef default():\n  return H"
  },
  {
    "path": "hparams/lenet.py",
    "chars": 6118,
    "preview": "import tensorflow as tf\n\nfrom .defaults import default, default_cifar10\nfrom .registry import register\n\n# lenet\n\n\n@regis"
  },
  {
    "path": "hparams/registry.py",
    "chars": 552,
    "preview": "import tensorflow as tf\n\n_HPARAMS = dict()\n\n\ndef register(fn):\n  global _HPARAMS\n  _HPARAMS[fn.__name__] = fn()\n  return"
  },
  {
    "path": "hparams/resnet.py",
    "chars": 9079,
    "preview": "import tensorflow as tf\n\nfrom .registry import register\nfrom .defaults import *\n\n\n# from https://github.com/tensorflow/m"
  },
  {
    "path": "hparams/user.py",
    "chars": 120,
    "preview": "import tensorflow as tf\n\nfrom .defaults import default\nfrom .registry import register\n\n# Add experimental hparams below\n"
  },
  {
    "path": "hparams/utils.py",
    "chars": 1285,
    "preview": "import tensorflow as tf\n\n\nclass HParams(tf.contrib.training.HParams):\n  \"\"\"Override of TensorFlow's HParams.\n\n  Replaces"
  },
  {
    "path": "hparams/vgg.py",
    "chars": 3599,
    "preview": "import tensorflow as tf\n\nfrom .registry import register\nfrom .defaults import default, default_cifar10\n\n\n# from https://"
  },
  {
    "path": "models/__init__.py",
    "chars": 166,
    "preview": "__all__ = [\"basic\", \"registry\", \"resnet\", \"lenet\", \"vgg\"]\n\nfrom .basic import *\nfrom .resnet import *\nfrom .registry imp"
  },
  {
    "path": "models/basic/__init__.py",
    "chars": 42,
    "preview": "__all__ = [\"basic\"]\n\nfrom .basic import *\n"
  },
  {
    "path": "models/basic/basic.py",
    "chars": 2589,
    "preview": "import tensorflow as tf\n\nfrom ..registry import register\n\nfrom ..utils.activations import get_activation\nfrom ..utils.in"
  },
  {
    "path": "models/lenet/__init__.py",
    "chars": 20,
    "preview": "__all__ = [\"lenet\"]\n"
  },
  {
    "path": "models/lenet/lenet.py",
    "chars": 5786,
    "preview": "import tensorflow as tf\n\nfrom ..registry import register\n\nfrom ..utils.activations import get_activation\nfrom ..utils.dr"
  },
  {
    "path": "models/registry.py",
    "chars": 483,
    "preview": "from ..training.lr_schemes import get_lr\n\nimport tensorflow as tf\n\n_MODELS = dict()\n\n\ndef register(name):\n\n  def add_to_"
  },
  {
    "path": "models/resnet/__init__.py",
    "chars": 21,
    "preview": "__all__ = [\"resnet\"]\n"
  },
  {
    "path": "models/resnet/resnet.py",
    "chars": 9866,
    "preview": "import tensorflow as tf\nimport numpy as np\n\nfrom ..utils import dropouts\nfrom ..utils.activations import get_activation\n"
  },
  {
    "path": "models/utils/__init__.py",
    "chars": 228,
    "preview": "__all__ = [\n    \"activations\", \"dropouts\", \"initializations\", \"model_utils\", \"optimizers\"\n]\n\nfrom .activations import *\n"
  },
  {
    "path": "models/utils/activations.py",
    "chars": 877,
    "preview": "import tensorflow as tf\n\n_ACTIVATION = dict()\n\n\ndef register(name):\n\n  def add_to_dict(fn):\n    global _ACTIVATION\n    _"
  },
  {
    "path": "models/utils/dropouts.py",
    "chars": 14649,
    "preview": "import numpy as np\nimport tensorflow as tf\n\n_DROPOUTS = dict()\n\n\ndef register(name):\n\n  def add_to_dict(fn):\n    global "
  },
  {
    "path": "models/utils/initializations.py",
    "chars": 2683,
    "preview": "import tensorflow as tf\n\n_INIT = dict()\n\n\ndef register(name):\n\n  def add_to_dict(fn):\n    global _INIT\n    _INIT[name] ="
  },
  {
    "path": "models/utils/model_utils.py",
    "chars": 13609,
    "preview": "import operator\nfrom functools import reduce\n\nimport tensorflow as tf\nfrom tensorflow.contrib.tpu.python.tpu import tpu_"
  },
  {
    "path": "models/utils/optimizers.py",
    "chars": 814,
    "preview": "import tensorflow as tf\n\n_OPTIMIZER = dict()\n\n\ndef register(name):\n\n  def add_to_dict(fn):\n    global _OPTIMIZER\n    _OP"
  },
  {
    "path": "models/vgg/__init__.py",
    "chars": 18,
    "preview": "__all__ = [\"vgg\"]\n"
  },
  {
    "path": "models/vgg/vgg.py",
    "chars": 6960,
    "preview": "import tensorflow as tf\n\nfrom ..utils.activations import get_activation\nfrom ..utils.dropouts import get_dropout\nfrom .."
  },
  {
    "path": "requirements.txt",
    "chars": 48,
    "preview": "tensorflow>=1.9\nrequests>=2.19.1\ndl-cloud>=0.0.4"
  },
  {
    "path": "scripts/__init__.py",
    "chars": 1,
    "preview": "\n"
  },
  {
    "path": "scripts/prune/README.md",
    "chars": 22,
    "preview": "# Library for Pruning\n"
  },
  {
    "path": "scripts/prune/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "scripts/prune/eval.py",
    "chars": 4803,
    "preview": "import tensorflow as tf\nimport os\nimport numpy as np\n\nfrom ...hparams.registry import get_hparams\nfrom ...models.registr"
  },
  {
    "path": "scripts/prune/prune.py",
    "chars": 6846,
    "preview": "import numpy as np\nimport tensorflow as tf\nimport statistics\nfrom ...models.utils import model_utils\n\n_PRUNE_FN = dict()"
  },
  {
    "path": "train.py",
    "chars": 6578,
    "preview": "import cloud\nimport os\nimport sys\nimport subprocess\nimport random\nimport tensorflow as tf\nimport numpy as np\nimport time"
  },
  {
    "path": "training/__init__.py",
    "chars": 108,
    "preview": "__all__ = [\"lr_schemes\", \"tpu\", \"flags\"]\n\nfrom .lr_schemes import *\nfrom .tpu import *\nfrom .flags import *\n"
  },
  {
    "path": "training/envs.py",
    "chars": 435,
    "preview": "_ENVS = dict()\n\n\ndef register(cls):\n  global _ENVS\n  _ENVS[cls.__name__.lower()] = cls()\n  return cls\n\n\ndef get_env(name"
  },
  {
    "path": "training/flags.py",
    "chars": 814,
    "preview": "import getpass\nimport os\nimport subprocess\n\nimport tensorflow as tf\n\nfrom .envs import get_env\n\n\ndef validate_flags(FLAG"
  },
  {
    "path": "training/lr_schemes.py",
    "chars": 4784,
    "preview": "import tensorflow as tf\n\n_LR = dict()\n\n\ndef register(name):\n\n  def add_to_dict(fn):\n    global _LR\n    _LR[name] = fn\n  "
  },
  {
    "path": "training/tpu.py",
    "chars": 2764,
    "preview": "import collections\nimport six\n\nimport tensorflow as tf\n\n\ndef remove_summaries():\n  g = tf.get_default_graph()\n  key = tf"
  }
]

About this extraction

This page contains the full source code of the for-ai/TD GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction covers 50 files (136.8 KB, roughly 40.5k tokens) and includes a symbol index of 288 functions, classes, methods, constants, and types. The output works with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input.

Extracted by GitExtract, a free GitHub-repo-to-text converter for AI, built by Nikandr Surkov.
