[
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nenv/\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\n*.egg-info/\n.installed.cfg\n*.egg\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n.hypothesis/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# pyenv\n.python-version\n\n# celery beat schedule file\ncelerybeat-schedule\n\n# SageMath parsed files\n*.sage.py\n\n# dotenv\n.env\n\n# virtualenv\n.venv\nvenv/\nENV/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n\ntmp\nruns\nrun\n\n# PyCharm\n.idea/\n\n# macOS metadata\n.DS_Store\n\n.vscode"
  },
  {
    "path": ".travis.yml",
    "content": "language: python\npython:\n  - \"3.6\"\n\n# command to install dependencies\ninstall:\n  - pip install -r requirements.txt\n\n# command to run tests\nscript:\n  - export FILES=\"$(git diff --name-only $TRAVIS_COMMIT_RANGE)\"\n  - cd /home/travis/build/for-ai\n  - python3 -m TD.train --eval_steps 1 --eval_every 1 --train_epochs -1 --env local --hparams mnist_basic_no_dropout\n  \n  - python3 -m TD.train --eval_steps 1 --eval_every 1 --train_epochs -1 --env local --hparams cifar_lenet\n  - python3 -m TD.train --eval_steps 1 --eval_every 1 --train_epochs -1 --env local --hparams cifar_lenet_weight\n  - python3 -m TD.train --eval_steps 1 --eval_every 1 --train_epochs -1 --env local --hparams cifar_lenet_trgtd_weight\n  - python3 -m TD.train --eval_steps 1 --eval_every 1 --train_epochs -1 --env local --hparams cifar_lenet_unit\n  - python3 -m TD.train --eval_steps 1 --eval_every 1 --train_epochs -1 --env local --hparams cifar_lenet_trgtd_unit\n  \n  - python3 -m TD.train --eval_steps 1 --eval_every 1 --train_epochs -1 --env local --hparams cifar10_resnet32\n  - python3 -m TD.train --eval_steps 1 --eval_every 1 --train_epochs -1 --env local --hparams cifar10_resnet32_weight\n  - python3 -m TD.train --eval_steps 1 --eval_every 1 --train_epochs -1 --env local --hparams cifar10_resnet32_trgtd_weight\n  - python3 -m TD.train --eval_steps 1 --eval_every 1 --train_epochs -1 --env local --hparams cifar10_resnet32_unit\n  - python3 -m TD.train --eval_steps 1 --eval_every 1 --train_epochs -1 --env local --hparams cifar10_resnet32_trgtd_unit\n  \n  - python3 -m TD.train --eval_steps 1 --eval_every 1 --train_epochs -1 --env local --hparams cifar100_vgg16_no_dropout\n  - python3 -m TD.train --eval_steps 1 --eval_every 1 --train_epochs -1 --env local --hparams cifar100_vgg16_untargeted_dropout\n  - python3 -m TD.train --eval_steps 1 --eval_every 1 --train_epochs -1 --env local --hparams cifar100_vgg16_targeted_dropout\n  - python3 -m TD.train --eval_steps 1 
--eval_every 1 --train_epochs -1 --env local --hparams cifar100_vgg16_untargeted_unit_dropout\n  - python3 -m TD.train --eval_steps 1 --eval_every 1 --train_epochs -1 --env local --hparams cifar100_vgg16_targeted_unit_dropout\n"
  },
  {
    "path": "README.md",
    "content": "# Targeted Dropout\n\nAidan N. Gomez, Ivan Zhang, Kevin Swersky, Yarin Gal, and Geoffrey E. Hinton\n\n## Table of Contents\n- [Requirements](#requirements)\n- [Quick Start](#quick-start)\n- [Experiments](#experiments)\n\n## Requirements\n- Python 3\n- Tensorflow 1.8\n\n## Quick Start\n1. Train a model: `python -m TD.train --hparams=resnet_default`\n2. Prune that model: `python -m TD.scripts.prune.eval --hparams=resnet_default --prune_percent 0.0,0.25,0.5,0.75,0.95`\n\n### Flags\n- `--env`: one of `local`, `gcp` (GPU instances), or `tpu` (TPU instances). Feel free to add more if necessary.\n- `--hparams`: the hparam set you want to run.\n- `--hparam_override`: manually specify hparams to be overridden (e.g `--hparam_override 'drop_rate=0.66'`)\n"
  },
  {
    "path": "__init__.py",
    "content": "__all__ = [\"data\", \"hparams\", \"models\", \"training\"]\n\nfrom .data import *\nfrom .hparams import *\nfrom .models import *\nfrom .training import *\n"
  },
  {
    "path": "data/__init__.py",
    "content": "__all__ = [\n    \"image_reader\",\n    \"registry\",\n    \"dataset_maps\",\n]\n"
  },
  {
    "path": "data/data_generators/__init__.py",
    "content": "__all__ = [\n    \"cifar_generator\",\n    \"generator_utils\",\n    \"mnist_generator\",\n]\n"
  },
  {
    "path": "data/data_generators/cifar_generator.py",
    "content": "try:\n  import cPickle\nexcept ImportError:\n  import pickle as cPickle\nimport os\nimport random\nimport sys\nimport tarfile\nimport urllib.request\nimport numpy as np\nimport tensorflow as tf\n\nfrom .generator_utils import generate_files\nfrom ...models.utils.model_utils import ModeKeys\n\nFLAGS = tf.app.flags.FLAGS\n\n_URL = \"http://www.cs.toronto.edu/~kriz/\"\n_CIFAR10_TAR = \"cifar-10-python.tar.gz\"\n_CIFAR10_DIR = \"cifar-10-batches-py\"\n_CIFAR10_TRAIN = [\n    \"data_batch_1\", \"data_batch_2\", \"data_batch_3\", \"data_batch_4\",\n    \"data_batch_5\"\n]\n_CIFAR10_TEST = [\"test_batch\"]\n\n_CIFAR100_TAR = \"cifar-100-python.tar.gz\"\n_CIFAR100_DIR = \"cifar-100-python\"\n_CIFAR100_TRAIN = [\"train\"]\n_CIFAR100_TEST = [\"test\"]\n\n_WORKING_DIR = \"/tmp/tf_data\"\n\n\ndef download(v100):\n  archive = _CIFAR100_TAR if v100 else _CIFAR10_TAR\n  filepath = os.path.join(_WORKING_DIR, archive)\n  if not os.path.exists(_WORKING_DIR):\n    os.makedirs(_WORKING_DIR)\n  url = _URL + archive\n  if not os.path.isfile(filepath):\n    print(\"Downloading \" + url)\n    urllib.request.urlretrieve(url, filepath)\n  print(\"Extracting \" + filepath)\n  tar = tarfile.open(filepath, \"r:gz\")\n  tar.extractall(path=_WORKING_DIR)\n  tar.close()\n\n\ndef maybe_download(files, v100):\n  for file in files:\n    filepath = os.path.join(_WORKING_DIR, _CIFAR100_DIR\n                            if v100 else _CIFAR10_DIR, file)\n    if not os.path.isfile(filepath):\n      download(v100)\n      break\n\n\ndef read_files(files, v100):\n  images = None\n  labels = None\n  for file in files:\n    filename = os.path.join(_WORKING_DIR, _CIFAR100_DIR\n                            if v100 else _CIFAR10_DIR, file)\n    data = None\n    with tf.gfile.Open(filename, \"rb\") as f:\n      if sys.version_info < (3,):\n        data = cPickle.load(f)\n      else:\n        data = cPickle.load(f, encoding=\"bytes\")\n\n    info = np.transpose(data[b\"data\"].reshape((-1, 3, 32, 
32)), (0, 2, 3, 1))\n    if images is None:\n      images = info\n    else:\n      images = np.concatenate((images, info))\n\n    info = data[b\"fine_labels\"] if v100 else data[b\"labels\"]\n    if labels is None:\n      labels = info\n    else:\n      labels = np.concatenate((labels, info))\n  return images, labels\n\n\ndef cifar_generator(v100, mode):\n  files = None\n  if v100:\n    files = _CIFAR100_TRAIN if mode != ModeKeys.TEST else _CIFAR100_TEST\n  else:\n    files = _CIFAR10_TRAIN if mode != ModeKeys.TEST else _CIFAR10_TEST\n  maybe_download(files, v100)\n\n  images, labels = read_files(files, v100)\n  data = list(zip(images, labels))\n  random.shuffle(data)\n  \n  samples = len(data)\n  if mode == ModeKeys.TRAIN:\n    data = data[:int(samples * 0.8)]\n  elif mode == ModeKeys.EVAL:\n    data = data[int(samples * 0.8):]\n\n  image_ph = tf.placeholder(dtype=tf.uint8, shape=(32, 32, 3))\n  encoded_ph = tf.image.encode_png(image_ph)\n\n  sess = tf.Session()\n  for image, label in data:\n    encoded_im = sess.run(encoded_ph, feed_dict={image_ph: image})\n    yield {\n        \"image/encoded\": [encoded_im],\n        \"image/format\": [b\"png\"],\n        \"image/class/label\": [label],\n        \"image/height\": [32],\n        \"image/width\": [32],\n        \"image/channels\": [3]\n    }\n\n\ndef generate(train_name, eval_name, test_name, hparams):\n  v100 = hparams.data in [\"cifar100\", \"cifar100_tpu\"]\n  generate_files(\n      cifar_generator(v100, mode=ModeKeys.TRAIN), train_name, hparams.data_dir,\n      FLAGS.num_shards)\n  generate_files(\n      cifar_generator(v100, mode=ModeKeys.EVAL), eval_name, hparams.data_dir,\n      FLAGS.num_shards)\n  generate_files(\n      cifar_generator(v100, mode=ModeKeys.TEST), test_name, hparams.data_dir,\n      FLAGS.num_shards)\n"
  },
  {
    "path": "data/data_generators/generator_utils.py",
    "content": "import operator\nimport os\nimport numpy as np\nimport tensorflow as tf\n\ntf.flags.DEFINE_boolean(\"v100\", False,\n                        \"Download CIFAR-100 instead of CIFAR-10.\")\ntf.flags.DEFINE_integer(\"num_shards\", 1,\n                        \"The number of output shards to write to.\")\n\n\ndef to_example(dictionary):\n  features = {}\n  for k, v in dictionary.items():\n    if len(v) == 0:\n      raise Exception(\"Empty field: %s\" % str((k, v)))\n    if isinstance(v[0], (int, np.int8, np.int32, np.int64)):\n      features[k] = tf.train.Feature(int64_list=tf.train.Int64List(value=v))\n    elif isinstance(v[0], (float, np.float32)):\n      features[k] = tf.train.Feature(float_list=tf.train.FloatList(value=v))\n    elif isinstance(v[0], (str, bytes)):\n      features[k] = tf.train.Feature(bytes_list=tf.train.BytesList(value=v))\n    else:\n      raise Exception(\"Unsupported type: %s\" % type(v[0]))\n  return tf.train.Example(features=tf.train.Features(feature=features))\n\n\ndef generate_files(generator,\n                   output_name,\n                   output_dir,\n                   num_shards,\n                   max_cases=None):\n  if not tf.gfile.Exists(output_dir):\n    tf.gfile.MakeDirs(output_dir)\n\n  writers = []\n  for shard in range(num_shards):\n    output_filename = \"%s-%dof%d\" % (output_name, shard + 1, num_shards)\n    output_file = os.path.join(output_dir, output_filename)\n    writers.append(tf.python_io.TFRecordWriter(output_file))\n\n  counter, shard = 0, 0\n  for case in generator:\n    if counter % 100 == 0:\n      tf.logging.info(\"Processed %d examples...\" % counter)\n    counter += 1\n    if max_cases and counter > max_cases:\n      break\n    sequence_example = to_example(case)\n    writers[shard].write(sequence_example.SerializeToString())\n    shard = (shard + 1) % num_shards\n\n  for writer in writers:\n    writer.close()\n"
  },
  {
    "path": "data/data_generators/mnist_generator.py",
    "content": "import gzip\nimport os\nimport random\nimport urllib\nimport numpy as np\nimport tensorflow as tf\n\nfrom .generator_utils import generate_files\nfrom ...models.utils.model_utils import ModeKeys\n\nFLAGS = tf.app.flags.FLAGS\ntf.logging.set_verbosity(tf.logging.INFO)\n\n_TRAIN_IMAGE_COUNT = 60000\n_TRAIN_IMAGE_FILE = \"train-images-idx3-ubyte.gz\"\n_TRAIN_LABEL_FILE = \"train-labels-idx1-ubyte.gz\"\n\n_TEST_IMAGE_COUNT = 10000\n_TEST_IMAGE_FILE = \"t10k-images-idx3-ubyte.gz\"\n_TEST_LABEL_FILE = \"t10k-labels-idx1-ubyte.gz\"\n\n_WORKING_DIR = \"/tmp/tf_data\"\n\n\ndef download_files(filenames):\n  \"\"\"Download files to tmp/data if file does not exist\n  Args:\n    filenames: list of string; list of filenames to check if exist\n  \"\"\"\n  if not os.path.exists(_WORKING_DIR):\n    os.makedirs(_WORKING_DIR)\n  for filename in filenames:\n    filepath = os.path.join(_WORKING_DIR, filename)\n    url = \"http://yann.lecun.com/exdb/mnist/\" + filename\n    if not os.path.isfile(filepath):\n      print(\"Downloading %s\" % (url + filename))\n      try:\n        urllib.urlretrieve(url, filepath)\n      except AttributeError:\n        urllib.request.urlretrieve(url, filepath)\n\n\ndef read_images(filepath, num_images):\n  with gzip.open(filepath) as f:\n    f.read(16)\n    buf = f.read(28 * 28 * num_images)\n    data = np.frombuffer(buf, dtype=np.uint8)\n    data = data.reshape(num_images, 28, 28, 1)\n  return data\n\n\ndef read_labels(filepath, num_labels):\n  with gzip.open(filepath) as f:\n    f.read(8)\n    buf = f.read(num_labels)\n    data = np.frombuffer(buf, dtype=np.uint8)\n  return data.astype(np.int64)\n\n\ndef mnist_generator(mode):\n  num_images = _TRAIN_IMAGE_COUNT if mode != ModeKeys.TEST else _TEST_IMAGE_COUNT\n  image_filepath = _TRAIN_IMAGE_FILE if mode != ModeKeys.TEST else _TEST_IMAGE_FILE\n  label_filepath = _TRAIN_LABEL_FILE if mode != ModeKeys.TEST else _TEST_LABEL_FILE\n\n  download_files([image_filepath, label_filepath])\n\n  
image_filepath = os.path.join(_WORKING_DIR, image_filepath)\n  label_filepath = os.path.join(_WORKING_DIR, label_filepath)\n\n  images = read_images(image_filepath, num_images)\n  labels = read_labels(label_filepath, num_images)\n\n  data = list(zip(images, labels))\n  random.shuffle(data)\n  \n  if mode == ModeKeys.TRAIN:\n    data = data[:5*num_images//6]\n  elif mode == ModeKeys.EVAL:\n    data = data[5*num_images//6:]\n\n  image_ph = tf.placeholder(dtype=tf.uint8, shape=(28, 28, 1))\n  encoded_ph = tf.image.encode_png(image_ph)\n\n  sess = tf.Session()\n  for image, label in data:\n    encoded_im = sess.run(encoded_ph, feed_dict={image_ph: image})\n    yield {\n        \"image/encoded\": [encoded_im],\n        \"image/format\": [b\"png\"],\n        \"image/class/label\": [label],\n        \"image/height\": [28],\n        \"image/width\": [28]\n    }\n\n\ndef generate(train_name, eval_name, test_name, hparams):\n  generate_files(\n      mnist_generator(mode=ModeKeys.TRAIN), train_name, hparams.data_dir, 1)\n  generate_files(\n      mnist_generator(mode=ModeKeys.EVAL), eval_name, hparams.data_dir, 1)\n  generate_files(\n      mnist_generator(mode=ModeKeys.TEST), test_name, hparams.data_dir, 1)\n"
  },
  {
    "path": "data/dataset_maps.py",
    "content": "import tensorflow as tf\nfrom . import imagenet_augs \n\n_AUGMENTATIONS = dict()\n\n\ndef register(fn):\n  global _AUGMENTATIONS\n  _AUGMENTATIONS[fn.__name__] = fn\n  return fn\n\n\ndef get_augmentation(name, params, training):\n\n  def fn(*args, **kwargs):\n    return _AUGMENTATIONS[name](\n        *args, **kwargs, training=training, params=params)\n\n  return fn\n\n\n@register\ndef cifar_augmentation(image, label, training, params):\n  \"\"\"Image augmentation suitable for CIFAR-10/100.\n  As described in https://arxiv.org/pdf/1608.06993v3.pdf (page 5).\n  Args:\n    images: a Tensor.\n  Returns:\n    Tensor of the same shape as images.\n  \"\"\"\n  if training:\n    image = tf.image.resize_image_with_crop_or_pad(image, 40, 40)\n    image = tf.random_crop(image, [32, 32, 3])\n    image = tf.image.random_flip_left_right(image)\n\n  image = tf.image.per_image_standardization(image)\n  return image, label\n\n@register\ndef imagenet_augmentation(image, label, training, params):\n  \"\"\"Imagenet augmentations.\n  Args:\n    images: a Tensor.\n  Returns:\n    Tensor of the same shape as images.\n  \"\"\"\n  if training:\n    image = imagenet_augs.preprocess_for_train(image, params.input_shape[0])\n  else:\n    image = imagenet_augs.preprocess_for_eval(image, params.input_shape[0])\n  return image, label\n\n\n@register\ndef load_images(example, training, params):\n  data_fields_to_features = {\n      \"image/encoded\": tf.FixedLenFeature((), tf.string),\n      \"image/format\": tf.FixedLenFeature((), tf.string),\n      \"image/class/label\": tf.FixedLenFeature((), tf.int64)\n  }\n\n  example = tf.parse_single_example(example, data_fields_to_features)\n  image = example[\"image/encoded\"]\n  image = tf.image.decode_png(image, channels=params.channels, dtype=tf.uint8)\n  image = tf.to_float(image)\n\n  label = tf.to_int32(example[\"image/class/label\"])\n\n  return image, label\n\n@register\ndef set_shapes(image, label, training, params):\n  image = 
tf.reshape(image, params.input_shape)\n  return image, label\n@register\ndef transpose(image, label, training, params):\n  image = tf.transpose(image, [2, 0, 1])\n  return image, label"
  },
  {
    "path": "data/image_reader.py",
    "content": "import tensorflow as tf\nfrom tensorflow.examples.tutorials.mnist import input_data\n\nfrom .registry import register\nfrom .dataset_maps import get_augmentation\nfrom .data_generators import cifar_generator, mnist_generator\n\n\n@register(\"imagenet\", None)\n@register(\"mnist\", mnist_generator.generate)\n@register(\"cifar10\", cifar_generator.generate)\n@register(\"cifar100\", cifar_generator.generate)\ndef image_reader(data_sources, hparams, training):\n  \"\"\"Input function for image data.\"\"\"\n\n  def _input_fn(params=None):\n    \"\"\"Input function compatible with Experiment API.\"\"\"\n    if params is not None and \"batch_size\" in params:\n      hparams.batch_size = params[\"batch_size\"]\n\n    dataset = tf.data.TFRecordDataset(\n        data_sources, num_parallel_reads=4 if training else 1)\n    dataset = dataset.prefetch(5 * hparams.batch_size)\n\n    if hparams.shuffle_data:\n      dataset = dataset.shuffle(5 * hparams.batch_size)\n\n    dataset = dataset.map(get_augmentation(\"load_images\", hparams, training))\n\n    if hparams.data_augmentations is not None:\n      for augmentation_name in hparams.data_augmentations:\n        dataset = dataset.map(\n            get_augmentation(augmentation_name, hparams, training))\n\n    dataset = dataset.map(get_augmentation(\"set_shapes\", hparams, training))\n    if hparams.data_format == \"channels_first\":\n      dataset = dataset.map(get_augmentation(\"transpose\", hparams, training))\n    dataset = dataset.repeat().batch(hparams.batch_size)\n    dataset_it = dataset.make_one_shot_iterator()\n\n    images, labels = dataset_it.get_next()\n    if params is not None and \"batch_size\" in params:\n      images = tf.reshape(images,\n                          [hparams.batch_size] + images.shape.as_list()[1:])\n      labels = tf.reshape(labels,\n                          [hparams.batch_size] + labels.shape.as_list()[1:])\n    return {\"inputs\": images, \"labels\": labels}, labels\n\n  return 
_input_fn\n\n\n@register(\"mnist_simple\", None)\ndef mnist_simple(data_source, params, training):\n  \"\"\"Input function for MNIST image data.\"\"\"\n\n  mnist = input_data.read_data_sets(data_source, one_hot=True)\n\n  data_set = mnist.train if training else mnist.test\n\n  def _input_fn():\n    input_images = tf.constant(data_set.images)\n\n    input_labels = tf.constant(\n        data_set.labels) if not params.is_ae else tf.constant(data_set.images)\n\n    image, label = tf.train.slice_input_producer([input_images, input_labels])\n\n    imageBatch, labelBatch = tf.train.batch(\n        [image, label], batch_size=params.batch_size)\n\n    return {\"inputs\": imageBatch}, labelBatch\n\n  return _input_fn\n\n\n@register(\"fashion\", None)\ndef fashion(data_source, params, training):\n  \"\"\"Input function for MNIST image data.\"\"\"\n\n  mnist = input_data.read_data_sets(\n      data_source,\n      source_url='http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/',\n      one_hot=True)\n\n  data_set = mnist.train if training else mnist.test\n\n  def _input_fn():\n    input_images = tf.constant(data_set.images)\n\n    input_labels = tf.constant(data_set.labels)\n    image, label = tf.train.slice_input_producer([input_images, input_labels])\n\n    imageBatch, labelBatch = tf.train.batch(\n        [image, label], batch_size=params.batch_size)\n\n    return {\"inputs\": imageBatch}, labelBatch\n\n  return _input_fn\n"
  },
  {
    "path": "data/imagenet_augs.py",
    "content": "import tensorflow as tf\n\nMEAN_RGB = [0.485, 0.456, 0.406]\nSTDDEV_RGB = [0.229, 0.224, 0.225]\n\n\n# The following preprocessing functions were taken from\n# cloud_tpu/models/resnet/resnet_preprocessing.py\n# ==============================================================================\ndef _crop(image, offset_height, offset_width, crop_height, crop_width):\n  \"\"\"Crops the given image using the provided offsets and sizes.\n  Note that the method doesn't assume we know the input image size but it does\n  assume we know the input image rank.\n  Args:\n    image: `Tensor` image of shape [height, width, channels].\n    offset_height: `Tensor` indicating the height offset.\n    offset_width: `Tensor` indicating the width offset.\n    crop_height: the height of the cropped image.\n    crop_width: the width of the cropped image.\n  Returns:\n    the cropped (and resized) image.\n  Raises:\n    InvalidArgumentError: if the rank is not 3 or if the image dimensions are\n      less than the crop size.\n  \"\"\"\n  original_shape = tf.shape(image)\n\n  rank_assertion = tf.Assert(\n      tf.equal(tf.rank(image), 3), [\"Rank of image must be equal to 3.\"])\n  with tf.control_dependencies([rank_assertion]):\n    cropped_shape = tf.stack([crop_height, crop_width, original_shape[2]])\n\n  size_assertion = tf.Assert(\n      tf.logical_and(\n          tf.greater_equal(original_shape[0], crop_height),\n          tf.greater_equal(original_shape[1], crop_width)),\n      [\"Crop size greater than the image size.\"])\n\n  offsets = tf.to_int32(tf.stack([offset_height, offset_width, 0]))\n\n  # Use tf.slice instead of crop_to_bounding box as it accepts tensors to\n  # define the crop size.\n  with tf.control_dependencies([size_assertion]):\n    image = tf.slice(image, offsets, cropped_shape)\n  return tf.reshape(image, cropped_shape)\n\n\ndef distorted_bounding_box_crop(image,\n                                bbox,\n                                
min_object_covered=0.1,\n                                aspect_ratio_range=(0.75, 1.33),\n                                area_range=(0.05, 1.0),\n                                max_attempts=100,\n                                scope=None):\n  \"\"\"Generates cropped_image using a one of the bboxes randomly distorted.\n  See `tf.image.sample_distorted_bounding_box` for more documentation.\n  Args:\n    image: `Tensor` of image (it will be converted to floats in [0, 1]).\n    bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]`\n        where each coordinate is [0, 1) and the coordinates are arranged\n        as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole\n        image.\n    min_object_covered: An optional `float`. Defaults to `0.1`. The cropped\n        area of the image must contain at least this fraction of any bounding\n        box supplied.\n    aspect_ratio_range: An optional list of `float`s. The cropped area of the\n        image must have an aspect ratio = width / height within this range.\n    area_range: An optional list of `float`s. The cropped area of the image\n        must contain a fraction of the supplied image within in this range.\n    max_attempts: An optional `int`. Number of attempts at generating a cropped\n        region of the image of the specified constraints. 
After `max_attempts`\n        failures, return the entire image.\n    scope: Optional `str` for name scope.\n  Returns:\n    (cropped image `Tensor`, distorted bbox `Tensor`).\n  \"\"\"\n  with tf.name_scope(\n      scope, default_name=\"distorted_bounding_box_crop\", values=[image, bbox]):\n    # Each bounding box has shape [1, num_boxes, box coords] and\n    # the coordinates are ordered [ymin, xmin, ymax, xmax].\n\n    # A large fraction of image datasets contain a human-annotated bounding\n    # box delineating the region of the image containing the object of interest.\n    # We choose to create a new bounding box for the object which is a randomly\n    # distorted version of the human-annotated bounding box that obeys an\n    # allowed range of aspect ratios, sizes and overlap with the human-annotated\n    # bounding box. If no box is supplied, then we assume the bounding box is\n    # the entire image.\n    sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(\n        tf.shape(image),\n        bounding_boxes=bbox,\n        min_object_covered=min_object_covered,\n        aspect_ratio_range=aspect_ratio_range,\n        area_range=area_range,\n        max_attempts=max_attempts,\n        use_image_if_no_bounding_boxes=True)\n    bbox_begin, bbox_size, distort_bbox = sample_distorted_bounding_box\n\n    # Crop the image to the specified bounding box.\n    cropped_image = tf.slice(image, bbox_begin, bbox_size)\n    return cropped_image, distort_bbox\n\n\ndef _random_crop(image, size):\n  \"\"\"Make a random crop of (`size` x `size`).\"\"\"\n  bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])\n  random_image, bbox = distorted_bounding_box_crop(\n      image,\n      bbox,\n      min_object_covered=0.1,\n      aspect_ratio_range=(3. / 4, 4. 
/ 3.),\n      area_range=(0.08, 1.0),\n      max_attempts=1,\n      scope=None)\n  bad = _at_least_x_are_true(tf.shape(image), tf.shape(random_image), 3)\n\n  image = tf.cond(\n      bad, lambda: _center_crop(_do_scale(image, size), size),\n      lambda: tf.image.resize_bicubic([random_image], [size, size])[0])\n  return image\n\n\ndef _flip(image):\n  \"\"\"Random horizontal image flip.\"\"\"\n  image = tf.image.random_flip_left_right(image)\n  return image\n\n\ndef _at_least_x_are_true(a, b, x):\n  \"\"\"At least `x` of `a` and `b` `Tensors` are true.\"\"\"\n  match = tf.equal(a, b)\n  match = tf.cast(match, tf.int32)\n  return tf.greater_equal(tf.reduce_sum(match), x)\n\n\ndef _do_scale(image, size):\n  \"\"\"Rescale the image by scaling the smaller spatial dimension to `size`.\"\"\"\n  shape = tf.cast(tf.shape(image), tf.float32)\n  w_greater = tf.greater(shape[0], shape[1])\n  shape = tf.cond(\n      w_greater, lambda: tf.cast([shape[0] / shape[1] * size, size], tf.int32),\n      lambda: tf.cast([size, shape[1] / shape[0] * size], tf.int32))\n\n  return tf.image.resize_bicubic([image], shape)[0]\n\n\ndef _center_crop(image, size):\n  \"\"\"Crops to center of image with specified `size`.\"\"\"\n  image_height = tf.shape(image)[0]\n  image_width = tf.shape(image)[1]\n\n  offset_height = ((image_height - size) + 1) / 2\n  offset_width = ((image_width - size) + 1) / 2\n  image = _crop(image, offset_height, offset_width, size, size)\n  return image\n\n\ndef _normalize(image):\n  \"\"\"Normalize the image to zero mean and unit variance.\"\"\"\n  offset = tf.constant(MEAN_RGB, shape=[1, 1, 3])\n  image -= offset\n\n  scale = tf.constant(STDDEV_RGB, shape=[1, 1, 3])\n  image /= scale\n  return image\n\n\ndef preprocess_for_train(image, image_size=224):\n  \"\"\"Preprocesses the given image for evaluation.\n  Args:\n    image: `Tensor` representing an image of arbitrary size.\n    image_size: int, how large the output image should be.\n  Returns:\n    A preprocessed 
image `Tensor`.\n  \"\"\"\n  image = _random_crop(image, image_size)\n  image = _normalize(image)\n  image = _flip(image)\n  image = tf.reshape(image, [image_size, image_size, 3])\n  return image\n\n\ndef preprocess_for_eval(image, image_size=224):\n  \"\"\"Preprocesses the given image for evaluation.\n  Args:\n    image: `Tensor` representing an image of arbitrary size.\n    image_size: int, how large the output image should be.\n  Returns:\n    A preprocessed image `Tensor`.\n  \"\"\"\n  image = _do_scale(image, image_size + 32)\n  image = _normalize(image)\n  image = _center_crop(image, image_size)\n  image = tf.reshape(image, [image_size, image_size, 3])\n  return image\n"
  },
  {
    "path": "data/registry.py",
    "content": "import os\n\nimport tensorflow as tf\n\n_INPUT_FNS = dict()\n_GENERATORS = dict()\n\n\ndef register(name, generator):\n\n  def add_to_dict(fn):\n    global _INPUT_FNS\n    global _GENERATORS\n    _INPUT_FNS[name] = fn\n    _GENERATORS[name] = generator\n    return fn\n\n  return add_to_dict\n\n\ndef get_input_fns(hparams, generate=True):\n  train_path = os.path.join(hparams.data_dir, \"train*\")\n  eval_path = os.path.join(hparams.data_dir, \"eval*\")\n  test_path = os.path.join(hparams.data_dir, \"test*\")\n\n  if generate:\n    if not tf.gfile.Exists(hparams.data_dir):\n      tf.gfile.MakeDirs(hparams.data_dir)\n\n    # generate if train doesnt exist\n    maybe_generate(train_path, hparams)\n    maybe_generate(eval_path, hparams)\n    maybe_generate(test_path, hparams)\n\n  train_path = tf.gfile.Glob(train_path)\n  eval_path = tf.gfile.Glob(eval_path)\n  test_path = tf.gfile.Glob(test_path)\n\n  input_fn = _INPUT_FNS[hparams.data]\n  train_fn = input_fn(train_path, hparams, training=True)\n  eval_fn = None if not eval_path else input_fn(\n      eval_path, hparams, training=False)\n  test_fn = None if not test_path else input_fn(\n      test_path, hparams, training=False)\n  if not (eval_path or test_path):\n    raise Exception(\"Could not find eval or test files.\")\n  return train_fn, eval_fn, test_fn\n\n\ndef get_dataset(hparams):\n  train_path = os.path.join(hparams.data_dir, \"train*\")\n  eval_path = os.path.join(hparams.data_dir, \"eval*\")\n  test_path = os.path.join(hparams.data_dir, \"test*\")\n  maybe_generate(train_path, hparams)\n  maybe_generate(eval_path, hparams)\n  maybe_generate(test_path, hparams)\n  return train_path, eval_path, test_path\n\n\ndef maybe_generate(check_path, hparams):\n  if not tf.gfile.Glob(check_path):\n    generate_fn = _GENERATORS[hparams.data]\n    if generate_fn:\n      generate_fn(\"train\", \"eval\", \"test\", hparams)\n    else:\n      tf.logging.warn(\n          \"No generator function. 
Unable to generate: %s\" % check_path)\n"
  },
  {
    "path": "hparams/__init__.py",
    "content": "__all__ = [\"defaults\", \"registry\", \"resnet\", \"lenet\", \"utils\", \"vgg\", \"basic\"]\n\nfrom .defaults import *\nfrom .resnet import *\nfrom .registry import *\nfrom .user import *\nfrom .utils import *\nfrom .lenet import *\nfrom .basic import *\nfrom .vgg import *\nfrom .basic import *\n"
  },
  {
    "path": "hparams/basic.py",
    "content": "import tensorflow as tf\n\nfrom . import defaults\nfrom .registry import register\n\n\n# MNIST =========================\n@register\ndef mnist_basic_no_dropout():\n  hps = defaults.default()\n  hps.model = \"basic\"\n  hps.data = \"mnist\"\n  hps.activation = \"relu\"\n  hps.batch_norm = False\n  hps.drop_rate = 0.0\n  hps.dropout_type = None\n  hps.initializer = \"glorot_uniform_initializer\"\n  hps.layers = [128, 64, 32]\n  hps.input_shape = [784]\n  hps.output_shape = [10]\n  hps.layer_type = \"dense\"\n\n  hps.learning_rate = 0.1\n  hps.optimizer = \"momentum\"\n  hps.momentum = 0.0\n\n  return hps\n\n\n@register\ndef mnist_basic_trgtd_dropout():\n  hps = mnist_basic_no_dropout()\n  hps.drop_rate = 0.5\n  hps.dropout_type = \"targeted_weight\"\n  hps.targ_rate = 0.5\n\n  return hps\n\n\n@register\ndef mnist_basic_untrgtd_dropout():\n  hps = mnist_basic_no_dropout()\n  hps.drop_rate = 0.25\n  hps.dropout_type = \"untargeted_weight\"\n\n  return hps\n\n\n@register\ndef mnist_basic_trgtd_dropout_random():\n  hps = mnist_basic_no_dropout()\n  hps.drop_rate = 0.5\n  hps.dropout_type = \"targeted_weight_random\"\n  hps.targ_rate = 0.5\n\n  return hps\n\n\n@register\ndef mnist_basic_trgtd_unit_dropout():\n  hps = mnist_basic_no_dropout()\n  hps.drop_rate = 0.5\n  hps.dropout_type = \"targeted_unit\"\n  hps.targ_rate = 0.5\n\n  return hps\n\n\n@register\ndef mnist_basic_smallify_dropout_1eneg4():\n  hps = mnist_basic_no_dropout()\n  hps.dropout_type = \"smallify_dropout\"\n  hps.smallify = 1e-4\n  hps.smallify_mv = 0.9\n  hps.smallify_thresh = 0.5\n\n  return hps\n\n\n@register\ndef mnist_basic_smallify_dropout_1eneg3():\n  hps = mnist_basic_smallify_dropout_1eneg4()\n  hps.smallify = 1e-3\n\n  return hps\n\n\n@register\ndef mnist_basic_smallify_weight_dropout_1eneg4():\n  hps = mnist_basic_no_dropout()\n  hps.dropout_type = \"smallify_weight_dropout\"\n  hps.smallify = 1e-4\n  hps.smallify_mv = 0.9\n  hps.smallify_thresh = 0.5\n\n  return 
hps\n\n\n@register\ndef cifar10_basic_no_dropout():\n  hps = defaults.default()\n  hps.model = \"basic\"\n  hps.data = \"cifar10\"\n  hps.activation = \"relu\"\n  hps.batch_norm = False\n  hps.drop_rate = 0.0\n  hps.dropout_type = None\n  hps.initializer = \"glorot_uniform_initializer\"\n  hps.layers = [128, 64, 32]\n  hps.channels = 3\n  hps.input_shape = [32, 32, 3]\n  hps.output_shape = [10]\n  hps.layer_type = \"dense\"\n\n  hps.learning_rate = 0.1\n  hps.optimizer = \"momentum\"\n  hps.momentum = 0.0\n\n  return hps\n\n\n@register\ndef cifar100_basic_no_dropout():\n  hps = cifar10_basic_no_dropout()\n  hps.output_shape = [100]\n  hps.data = \"cifar100\"\n  return hps\n\n\n@register\ndef imagenet32_basic():\n  hps = defaults.default_imagenet32()\n  hps.model = \"basic\"\n  hps.activation = \"relu\"\n  hps.batch_norm = False\n  hps.drop_rate = 0.0\n  hps.dropout_type = None\n  hps.initializer = \"glorot_uniform_initializer\"\n  hps.layers = [128, 64, 32]\n  hps.layer_type = \"dense\"\n  hps.learning_rate = 0.1\n  hps.optimizer = \"momentum\"\n  hps.momentum = 0.0\n  return hps"
  },
  {
    "path": "hparams/defaults.py",
    "content": "import tensorflow as tf\n\nfrom .registry import register\nfrom .utils import HParams\n\n\n@register\ndef default():\n  return HParams(\n      model=None,\n      data=None,\n      shuffle_data=True,\n      data_augmentations=None,\n      train_epochs=256,\n      eval_steps=100,\n      type=\"image\",\n      batch_size=64,\n      learning_rate=0.01,\n      lr_scheme=\"constant\",\n      initializer=\"glorot_normal_initializer\",\n      delay=0,\n      staircased=False,\n      learning_rate_decay_interval=2000,\n      learning_rate_decay_rate=0.1,\n      clip_grad_norm=1.0,\n      l2_loss=0.0,\n      prune_val=0.8,\n      label_smoothing=0.1,\n      use_tpu=False,\n      momentum=0.9,\n      init_scheme=\"random\",\n      warmup_steps=10000,\n      use_nesterov=False,\n      louizos_cost=0.0,\n      l1_norm=0.0,\n      thresh=2.5,\n      fixed=False,\n      var_scale=1,\n      klscale=1.0,\n      ard_cost=0.0,\n      logit_packing=0.0,\n      logit_squeezing=0.0,\n      clp=0.0,\n      logit_bound=None,\n      dropout_type=None,\n      smallify=0.0,\n      smallify_delay=1000,\n      linear_drop_rate=False,\n      weight_decay_and_noise=False,\n      weight_decay_only_features=True,\n      weight_decay_weight_names=[\"DW\", \"kernel\", \"bias\"],\n      dropout_delay_steps=5000,\n      grad_noise_scale=0.0,\n      td_nines=0,\n      targ_cost=1.0,\n      aparams=\"\",\n      channels=1,\n      data_format=\"channels_last\",\n      epoch_size=50000,\n  )\n\n\n@register\ndef default_cifar10():\n  hps = default()\n  hps.data = \"cifar10\"\n  hps.data_augmentations = [\"cifar_augmentation\"]\n  hps.epoch_size = 50000  # number of images in train set\n\n  hps.input_shape = [32, 32, 3]\n  hps.output_shape = [10]\n  hps.channels = 3\n  hps.num_classes = 10\n\n  return hps\n\n\n@register\ndef default_cifar100():\n  hps = default_cifar10()\n  hps.data = \"cifar100\"\n  hps.output_shape = [100]\n  hps.num_classes = 100\n\n  return hps\n\n\n@register\ndef 
default_imagenet299():\n  hps = default()\n  hps.data = \"imagenet\"\n  hps.data_augmentations = [\"imagenet_augmentation\"]\n  hps.epoch_size = 1281167\n\n  hps.input_shape = [299, 299, 3]\n  hps.channels = 3\n  hps.output_shape = [1001]\n  hps.num_classes = 1001\n\n  return hps\n\n\n@register\ndef default_imagenet224():\n  hps = default_imagenet299()\n  hps.input_shape = [224, 224, 3]\n\n  return hps\n\n\n@register\ndef default_imagenet64():\n  hps = default_imagenet299()\n  hps.input_shape = [64, 64, 3]\n\n  return hps\n\n\n@register\ndef default_imagenet32():\n  hps = default_imagenet299()\n  hps.input_shape = [32, 32, 3]\n\n  return hps\n"
  },
  {
    "path": "hparams/lenet.py",
    "content": "import tensorflow as tf\n\nfrom .defaults import default, default_cifar10\nfrom .registry import register\n\n# lenet\n\n\n@register\ndef cifar_lenet():\n  hps = default_cifar10()\n\n  hps.model = \"lenet\"\n\n  hps.activation = \"relu\"\n  hps.residual = True\n  hps.initializer = \"glorot_normal_initializer\"\n  hps.kernel_size = 5\n  hps.lr_scheme = \"constant\"\n  hps.batch_size = 128\n\n  hps.learning_rate = 0.01\n  hps.optimizer = \"momentum\"\n  hps.momentum = 0.9\n  hps.use_nesterov = True\n\n  hps.drop_rate = 0.0\n  hps.dropout_type = None\n  hps.targ_rate = 0.0\n\n  hps.axis_aligned_cost = False\n  hps.clp = False\n  hps.logit_squeezing = False\n\n  return hps\n\n\n@register\ndef cifar_lenet_no_dropout():\n  hps = cifar_lenet()\n  return hps\n\n\n@register\ndef cifar_lenet_weight():\n  hps = cifar_lenet_no_dropout()\n  hps.dropout_type = \"untargeted_weight\"\n  hps.drop_rate = 0.25\n  return hps\n\n\n@register\ndef cifar_lenet_trgtd_weight():\n  hps = cifar_lenet_no_dropout()\n  hps.drop_rate = 0.5\n  hps.targ_rate = 0.5\n  hps.dropout_type = \"targeted_weight\"\n  return hps\n\n\n@register\ndef cifar_lenet_unit():\n  hps = cifar_lenet_no_dropout()\n  hps.drop_rate = 0.25\n  hps.dropout_type = \"untargeted_unit\"\n  return hps\n\n\n@register\ndef cifar_lenet_trgtd_unit():\n  hps = cifar_lenet_no_dropout()\n  hps.drop_rate = 0.5\n  hps.targ_rate = 0.5\n  hps.dropout_type = \"targeted_unit\"\n  return hps\n\n\n@register\ndef cifar_lenet_l1():\n  hps = cifar_lenet_no_dropout()\n  hps.l1_norm = 0.1\n  return hps\n\n\n@register\ndef cifar_lenet_trgtd_weight_l1():\n  hps = cifar_lenet_no_dropout()\n  hps.l1_norm = 0.1\n  hps.drop_rate = 0.5\n  hps.targ_rate = 0.5\n  hps.dropout_type = \"targeted_weight\"\n  return hps\n\n\n@register\ndef cifar_lenet_trgtd_unit_l1():\n  hps = cifar_lenet_no_dropout()\n  hps.l1_norm = 0.1\n  hps.drop_rate = 0.5\n  hps.targ_rate = 0.5\n  hps.dropout_type = \"targeted_unit\"\n  return hps\n\n\n@register\ndef 
cifar_lenet_trgtd_unit_botk75_33():\n  hps = cifar_lenet_no_dropout()\n  hps.drop_rate = 0.33\n  hps.dropout_type = \"targeted_unit\"\n  hps.targ_rate = 0.75\n  return hps\n\n\n@register\ndef cifar_lenet_trgtd_unit_botk75_66():\n  hps = cifar_lenet_no_dropout()\n  hps.drop_rate = 0.66\n  hps.dropout_type = \"targeted_unit\"\n  hps.targ_rate = 0.75\n  return hps\n\n\n@register\ndef cifar_lenet_trgtd_weight_botk75_33():\n  hps = cifar_lenet_no_dropout()\n  hps.drop_rate = 0.33\n  hps.dropout_type = \"targeted_weight\"\n  hps.targ_rate = 0.75\n  return hps\n\n\n@register\ndef cifar_lenet_trgtd_weight_botk75_66():\n  hps = cifar_lenet_no_dropout()\n  hps.drop_rate = 0.66\n  hps.dropout_type = \"targeted_weight\"\n  hps.targ_rate = 0.75\n  return hps\n\n\n@register\ndef cifar_lenet_louizos_weight_1en3():\n  hps = cifar_lenet_no_dropout()\n  hps.louizos_beta = 2. / 3.\n  hps.louizos_zeta = 1.1\n  hps.louizos_gamma = -0.1\n  hps.louizos_cost = 0.001\n  hps.dropout_type = \"louizos_weight\"\n  hps.drop_rate = 0.25\n  return hps\n\n\n@register\ndef cifar_lenet_louizos_weight_1en1():\n  hps = cifar_lenet_no_dropout()\n  hps.louizos_beta = 2. / 3.\n  hps.louizos_zeta = 1.1\n  hps.louizos_gamma = -0.1\n  hps.louizos_cost = 0.1\n  hps.dropout_type = \"louizos_weight\"\n  hps.drop_rate = 0.25\n  return hps\n\n\n@register\ndef cifar_lenet_louizos_weight_1en2():\n  hps = cifar_lenet_no_dropout()\n  hps.louizos_beta = 2. / 3.\n  hps.louizos_zeta = 1.1\n  hps.louizos_gamma = -0.1\n  hps.louizos_cost = 0.01\n  hps.dropout_type = \"louizos_weight\"\n  hps.drop_rate = 0.25\n  return hps\n\n\n@register\ndef cifar_lenet_louizos_weight_5en3():\n  hps = cifar_lenet_no_dropout()\n  hps.louizos_beta = 2. / 3.\n  hps.louizos_zeta = 1.1\n  hps.louizos_gamma = -0.1\n  hps.louizos_cost = 0.005\n  hps.dropout_type = \"louizos_weight\"\n  hps.drop_rate = 0.25\n  return hps\n\n\n@register\ndef cifar_lenet_louizos_weight_1en4():\n  hps = cifar_lenet_no_dropout()\n  hps.louizos_beta = 2. 
/ 3.\n  hps.louizos_zeta = 1.1\n  hps.louizos_gamma = -0.1\n  hps.louizos_cost = 0.0001\n  hps.dropout_type = \"louizos_weight\"\n  hps.drop_rate = 0.25\n  return hps\n\n\n@register\ndef cifar_lenet_louizos_unit_1en3():\n  hps = cifar_lenet_no_dropout()\n  hps.louizos_beta = 2. / 3.\n  hps.louizos_zeta = 1.1\n  hps.louizos_gamma = -0.1\n  hps.louizos_cost = 0.001\n  hps.dropout_type = \"louizos_unit\"\n  hps.drop_rate = 0.25\n  return hps\n\n\n@register\ndef cifar_lenet_louizos_unit_1en1():\n  hps = cifar_lenet_no_dropout()\n  hps.louizos_beta = 2. / 3.\n  hps.louizos_zeta = 1.1\n  hps.louizos_gamma = -0.1\n  hps.louizos_cost = 0.1\n  hps.dropout_type = \"louizos_unit\"\n  hps.drop_rate = 0.25\n  return hps\n\n\n@register\ndef cifar_lenet_louizos_unit_1en2():\n  hps = cifar_lenet_no_dropout()\n  hps.louizos_beta = 2. / 3.\n  hps.louizos_zeta = 1.1\n  hps.louizos_gamma = -0.1\n  hps.louizos_cost = 0.01\n  hps.dropout_type = \"louizos_unit\"\n  hps.drop_rate = 0.25\n  return hps\n\n\n@register\ndef cifar_lenet_louizos_unit_5en3():\n  hps = cifar_lenet_no_dropout()\n  hps.louizos_beta = 2. / 3.\n  hps.louizos_zeta = 1.1\n  hps.louizos_gamma = -0.1\n  hps.louizos_cost = 0.005\n  hps.dropout_type = \"louizos_unit\"\n  hps.drop_rate = 0.25\n  return hps\n\n\n@register\ndef cifar_lenet_louizos_unit_1en4():\n  hps = cifar_lenet_no_dropout()\n  hps.louizos_beta = 2. / 3.\n  hps.louizos_zeta = 1.1\n  hps.louizos_gamma = -0.1\n  hps.louizos_cost = 0.0001\n  hps.dropout_type = \"louizos_unit\"\n  hps.drop_rate = 0.25\n  return hps\n\n\n@register\ndef cifar_lenet_variational():\n  hps = cifar_lenet_no_dropout()\n  hps.dropout_type = \"variational\"\n  hps.var_scale = 1. 
/ 100\n  hps.drop_rate = 0.75\n\n  return hps\n\n\n@register\ndef cifar_lenet_variational_unscaled():\n  hps = cifar_lenet_no_dropout()\n  hps.dropout_type = \"variational\"\n  hps.drop_rate = 0.75\n\n  return hps\n\n\n@register\ndef cifar_lenet_variational_unit():\n  hps = cifar_lenet_no_dropout()\n  hps.dropout_type = \"variational_unit\"\n  hps.var_scale = 1. / 100\n  hps.drop_rate = 0.75\n\n  return hps\n\n\n@register\ndef cifar_lenet_variational_unit_unscaled():\n  hps = cifar_lenet_no_dropout()\n  hps.dropout_type = \"variational_unit\"\n  hps.drop_rate = 0.75\n\n  return hps\n\n\n@register\ndef cifar_lenet_smallify_neg4():\n  hps = cifar_lenet_no_dropout()\n  hps.dropout_type = \"smallify_dropout\"\n  hps.smallify = 1e-4\n  hps.smallify_mv = 0.9\n  hps.smallify_thresh = 0.5\n  hps.smallify_delay = 10000\n  return hps\n"
  },
  {
    "path": "hparams/registry.py",
    "content": "import tensorflow as tf\n\n_HPARAMS = dict()\n\n\ndef register(fn):\n  global _HPARAMS\n  _HPARAMS[fn.__name__] = fn()\n  return fn\n\n\ndef get_hparams(hparams_list):\n  \"\"\"Fetches a merged group of hyperparameter sets (chronological priority).\"\"\"\n  final = tf.contrib.training.HParams()\n  for name in hparams_list.split(\"-\"):\n    curr = _HPARAMS[name]\n    final_dict = final.values()\n    for k, v in curr.values().items():\n      if k not in final_dict:\n        final.add_hparam(k, v)\n      elif final_dict[k] is None:\n        setattr(final, k, v)\n  return final\n"
  },
  {
    "path": "hparams/resnet.py",
    "content": "import tensorflow as tf\n\nfrom .registry import register\nfrom .defaults import *\n\n\n# from https://github.com/tensorflow/models/blob/master/resnet/resnet_main.py\n@register\ndef resnet_default():\n  hps = default_cifar10()\n  hps.model = \"resnet\"\n  hps.residual_filters = [16, 32, 64, 128]\n  hps.residual_units = [5, 5, 5]\n  hps.use_bottleneck = False\n  hps.batch_size = 128\n  hps.learning_rate = 0.4\n  hps.lr_scheme = \"resnet\"\n  hps.weight_decay_rate = 2e-4\n  hps.optimizer = \"momentum\"\n  return hps\n\n\n@register\ndef resnet102_imagenet224():\n  hps = default_imagenet224()\n  hps.model = \"resnet\"\n  hps.residual_filters = [64, 64, 128, 256, 512]\n  hps.residual_units = [3, 4, 23, 3]\n  hps.use_bottleneck = True\n  hps.batch_size = 128 * 8\n  hps.learning_rate = 0.128 * hps.batch_size / 256.\n  hps.lr_scheme = \"warmup_cosine\"\n  hps.warmup_steps = 10000\n  hps.weight_decay_rate = 1e-4\n  hps.optimizer = \"momentum\"\n  hps.use_nesterov = True\n  hps.initializer = \"variance_scaling_initializer\"\n  hps.learning_rate_cosine_cycle_steps = 120000\n  hps.cosine_alpha = 0.0\n  return hps\n\n\n@register\ndef resnet102_imagenet64():\n  hps = resnet102_imagenet224()\n  hps.input_shape = [64, 64, 3]\n  return hps\n\n\n@register\ndef resnet50_imagenet224():\n  hps = resnet102_imagenet224()\n  hps.residual_units = [3, 4, 6, 3]\n  return hps\n\n\n@register\ndef resnet34_imagenet224():\n  hps = resnet50_imagenet224()\n  hps.use_bottleneck = False\n  return hps\n\n\n@register\ndef resnet_cifar100():\n  hps = resnet_default()\n  hps.num_classes = 100\n  return hps\n\n\n@register\ndef cifar10_resnet32():\n  hps = resnet_default()\n\n  return hps\n\n\n@register\ndef cifar10_resnet32_no_dropout():\n  hps = cifar10_resnet32()\n  hps.drop_rate = 0.0\n\n  return hps\n\n\n@register\ndef cifar10_resnet32_trgtd_weight():\n  hps = cifar10_resnet32_no_dropout()\n  hps.drop_rate = 0.5\n  hps.dropout_type = \"targeted_weight\"\n  hps.targ_rate = 0.5\n\n  
return hps\n\n\n@register\ndef cifar10_resnet32_weight():\n  hps = cifar10_resnet32_no_dropout()\n  hps.drop_rate = 0.25\n  hps.dropout_type = \"untargeted_weight\"\n\n  return hps\n\n\n@register\ndef cifar10_resnet32_weight_50():\n  hps = cifar10_resnet32_weight()\n  hps.drop_rate = 0.50\n\n  return hps\n\n\n@register\ndef cifar10_resnet32_trgtd_unit():\n  hps = cifar10_resnet32_no_dropout()\n  hps.drop_rate = 0.5\n  hps.dropout_type = \"targeted_unit\"\n  hps.targ_rate = 0.5\n\n  return hps\n\n\n@register\ndef cifar10_resnet32_trgtd_ard():\n  hps = cifar10_resnet32_no_dropout()\n  hps.drop_rate = 0.25\n  hps.dropout_type = \"targeted_ard\"\n  hps.targ_rate = 0.5\n\n  return hps\n\n\n@register\ndef cifar10_resnet32_unit():\n  hps = cifar10_resnet32_no_dropout()\n  hps.drop_rate = 0.25\n  hps.dropout_type = \"untargeted_unit\"\n\n  return hps\n\n\n@register\ndef cifar10_resnet32_unit_50():\n  hps = cifar10_resnet32_unit()\n  hps.drop_rate = 0.50\n\n  return hps\n\n\n@register\ndef cifar10_resnet32_l1_1eneg3():\n  hps = cifar10_resnet32_no_dropout()\n  hps.l1_norm = 0.001\n\n  return hps\n\n\n@register\ndef cifar10_resnet32_l1_1eneg2():\n  hps = cifar10_resnet32_no_dropout()\n  hps.l1_norm = 0.01\n\n  return hps\n\n\n@register\ndef cifar10_resnet32_l1_1eneg1():\n  hps = cifar10_resnet32_no_dropout()\n  hps.l1_norm = 0.1\n\n  return hps\n\n\n@register\ndef cifar10_resnet32_trgted_weight_l1():\n  hps = cifar10_resnet32_no_dropout()\n  hps.drop_rate = 0.5\n  hps.dropout_type = \"targeted_weight\"\n  hps.targ_rate = 0.5\n  hps.l1_norm = 0.1\n\n  return hps\n\n\n@register\ndef cifar10_resnet32_targeted_unit_l1():\n  hps = cifar10_resnet32_no_dropout()\n  hps.drop_rate = 0.5\n  hps.dropout_type = \"targeted_unit\"\n  hps.targ_rate = 0.5\n  hps.l1_norm = 0.1\n\n  return hps\n\n\n@register\ndef cifar10_resnet32_trgtd_unit_botk75_33():\n  hps = cifar10_resnet32_no_dropout()\n  hps.drop_rate = 0.33\n  hps.dropout_type = \"targeted_unit\"\n  hps.targ_rate = 0.75\n\n  return 
hps\n\n\n@register\ndef cifar10_resnet32_trgtd_unit_botk75_66():\n  hps = cifar10_resnet32_no_dropout()\n  hps.drop_rate = 0.66\n  hps.dropout_type = \"targeted_unit\"\n  hps.targ_rate = 0.75\n\n  return hps\n\n\n@register\ndef cifar10_resnet32_trgtd_weight_botk75_33():\n  hps = cifar10_resnet32_no_dropout()\n  hps.drop_rate = 0.33\n  hps.dropout_type = \"targeted_weight\"\n  hps.targ_rate = 0.75\n\n  return hps\n\n\n@register\ndef cifar10_resnet32_trgtd_weight_botk75_66():\n  hps = cifar10_resnet32_no_dropout()\n  hps.drop_rate = 0.66\n  hps.dropout_type = \"targeted_weight\"\n  hps.targ_rate = 0.75\n\n  return hps\n\n\n@register\ndef cifar10_resnet32_trgtd_unit_ramping_botk90_99():\n  hps = cifar10_resnet32_no_dropout()\n  hps.drop_rate = 0.99\n  hps.dropout_type = \"targeted_unit_piecewise\"\n  hps.targ_rate = 0.90\n\n  return hps\n\n\n@register\ndef cifar10_resnet32_trgtd_weight_ramping_botk99_99():\n  hps = cifar10_resnet32_no_dropout()\n  hps.drop_rate = 0.99\n  hps.dropout_type = \"targeted_weight_piecewise\"\n  hps.targ_rate = 0.99\n  hps.linear_drop_rate = True\n\n  return hps\n\n\n@register\ndef cifar10_resnet32_louizos_weight_1en3():\n  hps = cifar10_resnet32_no_dropout()\n  hps.louizos_beta = 2. 
/ 3.\n  hps.louizos_zeta = 1.1\n  hps.louizos_gamma = -0.1\n  hps.louizos_cost = 0.001\n  hps.dropout_type = \"louizos_weight\"\n  hps.drop_rate = 0.001\n\n  return hps\n\n\n@register\ndef cifar10_resnet32_louizos_weight_1en1():\n  hps = cifar10_resnet32_louizos_weight_1en3()\n  hps.louizos_cost = 0.1\n  hps.dropout_type = \"louizos_weight\"\n\n  return hps\n\n\n@register\ndef cifar10_resnet32_louizos_weight_1en2():\n  hps = cifar10_resnet32_louizos_weight_1en3()\n  hps.louizos_cost = 0.01\n\n  return hps\n\n\n@register\ndef cifar10_resnet32_louizos_weight_5en3():\n  hps = cifar10_resnet32_louizos_weight_1en3()\n  hps.louizos_cost = 0.005\n\n  return hps\n\n\n@register\ndef cifar10_resnet32_louizos_weight_1en4():\n  hps = cifar10_resnet32_louizos_weight_1en3()\n  hps.louizos_cost = 0.0001\n\n  return hps\n\n\n@register\ndef cifar10_resnet32_louizos_unit_1en3():\n  hps = cifar10_resnet32_no_dropout()\n  hps.louizos_beta = 2. / 3.\n  hps.louizos_zeta = 1.1\n  hps.louizos_gamma = -0.1\n  hps.louizos_cost = 0.001\n  hps.dropout_type = \"louizos_unit\"\n  hps.drop_rate = 0.001\n\n  return hps\n\n\n@register\ndef cifar10_resnet32_louizos_unit_1en1():\n  hps = cifar10_resnet32_louizos_unit_1en3()\n  hps.louizos_cost = 0.1\n\n  return hps\n\n\n@register\ndef cifar10_resnet32_louizos_unit_1en2():\n  hps = cifar10_resnet32_louizos_unit_1en3()\n  hps.louizos_cost = 0.01\n\n  return hps\n\n\n@register\ndef cifar10_resnet32_louizos_unit_5en3():\n  hps = cifar10_resnet32_louizos_unit_1en3()\n  hps.louizos_cost = 0.005\n\n  return hps\n\n\n@register\ndef cifar10_resnet32_louizos_unit_1en4():\n  hps = cifar10_resnet32_louizos_unit_1en3()\n  hps.louizos_cost = 0.0001\n\n  return hps\n\n\n@register\ndef cifar10_resnet32_louizos_unit_1en5():\n  hps = cifar10_resnet32_louizos_unit_1en3()\n  hps.louizos_cost = 0.00001\n\n  return hps\n\n\n@register\ndef cifar10_resnet32_louizos_unit_1en6():\n  hps = cifar10_resnet32_louizos_unit_1en3()\n  hps.louizos_cost = 0.000001\n\n  return 
hps\n\n\n@register\ndef cifar10_resnet32_variational_weight():\n  hps = cifar10_resnet32_no_dropout()\n  hps.dropout_type = \"variational\"\n  hps.drop_rate = 0.75\n  hps.thresh = 3\n  hps.var_scale = 1. / 100\n  hps.weight_decay_rate = None\n\n  return hps\n\n\n@register\ndef cifar10_resnet32_variational_weight_unscaled():\n  hps = cifar10_resnet32_no_dropout()\n  hps.dropout_type = \"variational\"\n  hps.drop_rate = 0.75\n  hps.thresh = 3\n  hps.var_scale = 1\n  hps.weight_decay_rate = None\n\n  return hps\n\n\n@register\ndef cifar10_resnet32_variational_unit():\n  hps = cifar10_resnet32_no_dropout()\n  hps.dropout_type = \"variational_unit\"\n  hps.drop_rate = 0.75\n  hps.thresh = 3\n  hps.var_scale = 1. / 100\n  hps.weight_decay_rate = None\n\n  return hps\n\n\n@register\ndef cifar10_resnet32_variational_unit_unscaled():\n  hps = cifar10_resnet32_no_dropout()\n  hps.dropout_type = \"variational_unit\"\n  hps.drop_rate = 0.75\n  hps.thresh = 3\n  hps.var_scale = 1\n  hps.weight_decay_rate = None\n\n  return hps\n\n\n@register\ndef cifar10_resnet32_smallify_1eneg4():\n  hps = cifar10_resnet32_no_dropout()\n  hps.dropout_type = \"smallify_dropout\"\n  hps.smallify = 1e-4\n  hps.smallify_mv = 0.9\n  hps.smallify_thresh = 0.5\n\n  return hps\n\n\n@register\ndef cifar10_resnet32_smallify_1eneg3():\n  hps = cifar10_resnet32_smallify_1eneg4()\n  hps.smallify = 1e-3\n\n  return hps\n\n\n@register\ndef cifar10_resnet32_smallify_1eneg5():\n  hps = cifar10_resnet32_smallify_1eneg4()\n  hps.smallify = 1e-5\n\n  return hps\n\n\n@register\ndef cifar10_resnet32_smallify_1eneg6():\n  hps = cifar10_resnet32_smallify_1eneg4()\n  hps.smallify = 1e-6\n\n  return hps\n\n\n@register\ndef cifar10_resnet32_smallify_weight_1eneg4():\n  hps = cifar10_resnet32_no_dropout()\n  hps.dropout_type = \"smallify_weight_dropout\"\n  hps.smallify = 1e-4\n  hps.smallify_mv = 0.9\n  hps.smallify_thresh = 0.5\n\n  return hps\n\n\n@register\ndef cifar10_resnet32_smallify_weight_1eneg3():\n  hps = 
cifar10_resnet32_smallify_weight_1eneg4()\n  hps.smallify = 1e-3\n\n  return hps\n\n\n@register\ndef cifar10_resnet32_smallify_weight_1eneg5():\n  hps = cifar10_resnet32_smallify_weight_1eneg3()\n  hps.smallify = 1e-5\n\n  return hps\n\n\n@register\ndef cifar10_resnet32_smallify_weight_1eneg6():\n  hps = cifar10_resnet32_smallify_weight_1eneg3()\n  hps.smallify = 1e-6\n\n  return hps\n\n\n# ================================\n"
  },
  {
    "path": "hparams/user.py",
    "content": "import tensorflow as tf\n\nfrom .defaults import default\nfrom .registry import register\n\n# Add experimental hparams below\n"
  },
  {
    "path": "hparams/utils.py",
    "content": "import tensorflow as tf\n\n\nclass HParams(tf.contrib.training.HParams):\n  \"\"\"Override of TensorFlow's HParams.\n\n  Replaces HParams.add_hparam(name, value) with simple attribute assignment.\n    I.e. There is no need to explicitly add an hparam:\n      Replace: `hparams.add_hparam(\"learning_rate\", 0.1)`\n      With:    `hparams.learning_rate = 0.1`\n  \"\"\"\n\n  def __setattr__(self, name, value):\n    \"\"\"Adds {name, value} pair to hyperparameters.\n\n    Args:\n      name: Name of the hyperparameter.\n      value: Value of the hyperparameter. Can be one of the following types:\n        int, float, string, int list, float list, or string list.\n\n    Raises:\n      ValueError: if one of the arguments is invalid.\n    \"\"\"\n    # Keys in kwargs are unique, but 'name' could be the name of a pre-existing\n    # attribute of this object.  In that case we refuse to use it as a\n    # hyperparameter name.\n    if name[0] == \"_\":\n      object.__setattr__(self, name, value)\n      return\n    if isinstance(value, (list, tuple)):\n      if not value:\n        raise ValueError(\n            'Multi-valued hyperparameters cannot be empty: %s' % name)\n      self._hparam_types[name] = (type(value[0]), True)\n    else:\n      self._hparam_types[name] = (type(value), False)\n    object.__setattr__(self, name, value)\n"
  },
  {
    "path": "hparams/vgg.py",
    "content": "import tensorflow as tf\n\nfrom .registry import register\nfrom .defaults import default, default_cifar10\n\n\n# from https://github.com/tensorflow/models/blob/master/resnet/resnet_main.py\n@register\ndef vgg16_default():\n  vgg_default = default_cifar10()\n  vgg_default.initializer = \"glorot_uniform_initializer\"\n  vgg_default.model = \"vgg\"\n  vgg_default.learning_rate = 0.01\n  vgg_default.lr_scheme = \"constant\"\n  vgg_default.weight_decay_rate = 0.0005\n  vgg_default.num_classes = 10\n  vgg_default.optimizer = \"adam\"\n  vgg_default.adam_epsilon = 1e-6\n  vgg_default.beta1 = 0.85\n  vgg_default.beta2 = 0.997\n  vgg_default.input_shape = [32, 32, 3]\n  vgg_default.output_shape = [10]\n  return vgg_default\n\n\n@register\ndef cifar10_vgg16():\n  hps = vgg16_default()\n  hps.data = \"cifar10\"\n  return hps\n\n\n@register\ndef cifar100_vgg16_no_dropout():\n  hps = vgg16_default()\n  hps.data = \"cifar100\"\n\n  hps.input_shape = [32, 32, 3]\n  hps.output_shape = [100]\n  hps.num_classes = 100\n  hps.channels = 3\n  hps.learning_rate = 0.0001\n  return hps\n\n\n@register\ndef cifar10_vgg16_no_dropout():\n  hps = vgg16_default()\n  hps.data = \"cifar10\"\n\n  hps.input_shape = [32, 32, 3]\n  hps.output_shape = [10]\n  hps.num_classes = 10\n  hps.channels = 3\n  hps.learning_rate = 0.0001\n  return hps\n\n\n@register\ndef cifar100_vgg16_targeted_dropout():\n  hps = cifar100_vgg16_no_dropout()\n  hps.drop_rate = 0.5\n  hps.dropout_type = \"targeted_weight\"\n  hps.targ_rate = 0.5\n  return hps\n\n\n@register\ndef cifar100_vgg16_untargeted_dropout():\n  hps = cifar100_vgg16_no_dropout()\n  hps.drop_rate = 0.25\n  hps.dropout_type = \"untargeted_weight\"\n  return hps\n\n\n@register\ndef cifar100_vgg16_untargeted_unit_dropout():\n  hps = cifar100_vgg16_no_dropout()\n  hps.drop_rate = 0.25\n  hps.dropout_type = \"untargeted_unit\"\n  return hps\n\n\n@register\ndef cifar100_vgg16_targeted_unit_dropout():\n  hps = cifar100_vgg16_no_dropout()\n  
hps.drop_rate = 0.5\n  hps.dropout_type = \"targeted_unit\"\n  hps.targ_rate = 0.5\n  return hps\n\n\n@register\ndef cifar100_vgg16_targeted_unit_dropout_botk75_66():\n  hps = cifar100_vgg16_targeted_unit_dropout()\n  hps.drop_rate = 0.66\n  hps.targ_rate = 0.75\n  return hps\n\n\n@register\ndef cifar100_vgg16_louizos_unit():\n  hps = cifar100_vgg16_no_dropout()\n  hps.louizos_beta = 2. / 3.\n  hps.louizos_zeta = 1.1\n  hps.louizos_gamma = -0.1\n  hps.louizos_cost = 0.001\n  hps.dropout_type = \"louizos_unit\"\n  hps.drop_rate = 0.25\n\n  return hps\n\n\n@register\ndef cifar100_vgg16_louizos_weight():\n  hps = cifar100_vgg16_louizos_unit()\n  hps.dropout_type = \"louizos_weight\"\n\n  return hps\n\n\n@register\ndef cifar100_vgg16_variational_unscaled():\n  hps = cifar100_vgg16_no_dropout()\n  hps.dropout_type = \"variational\"\n  hps.drop_rate = 0.75\n  hps.thresh = 3\n  hps.var_scale = 1\n  hps.weight_decay_rate = 0.0\n\n  return hps\n\n\n@register\ndef cifar100_vgg16_variational():\n  hps = cifar100_vgg16_variational_unscaled()\n  hps.var_scale = 1. / 100\n\n  return hps\n\n\n@register\ndef cifar100_vgg16_variational_unit_unscaled():\n  hps = cifar100_vgg16_variational_unscaled()\n  hps.dropout_type = \"variational_unit\"\n\n  return hps\n\n\n@register\ndef cifar100_vgg16_variational_unit():\n  hps = cifar100_vgg16_variational_unit_unscaled()\n  hps.var_scale = 1. / 100\n\n  return hps\n\n\n@register\ndef cifar100_vgg16_smallify_1eneg4():\n  hps = cifar100_vgg16_no_dropout()\n  hps.dropout_type = \"smallify_dropout\"\n  hps.smallify = 1e-4\n  hps.smallify_mv = 0.9\n  hps.smallify_thresh = 0.5\n\n  return hps\n\n\n@register\ndef cifar100_vgg16_smallify_weight_1eneg5():\n  hps = cifar100_vgg16_smallify_1eneg4()\n  hps.dropout_type = \"smallify_weight_dropout\"\n  hps.smallify = 1e-5\n\n  return hps\n"
  },
  {
    "path": "models/__init__.py",
    "content": "__all__ = [\"basic\", \"registry\", \"resnet\", \"lenet\", \"vgg\"]\n\nfrom .basic import *\nfrom .resnet import *\nfrom .registry import *\nfrom .lenet import *\nfrom .vgg import *\n"
  },
  {
    "path": "models/basic/__init__.py",
    "content": "__all__ = [\"basic\"]\n\nfrom .basic import *\n"
  },
  {
    "path": "models/basic/basic.py",
    "content": "import tensorflow as tf\n\nfrom ..registry import register\n\nfrom ..utils.activations import get_activation\nfrom ..utils.initializations import get_init\nfrom ..utils.optimizers import get_optimizer\nfrom ..utils import model_utils\n\n\n@register(\"basic\")\ndef get_basic(params, lr):\n  \"\"\"Callable model function compatible with Experiment API.\n\n  Args:\n    params: a HParams object containing values for fields:\n    lr: learning rate variable\n  \"\"\"\n\n  def basic(features, labels, mode, _):\n    \"\"\"The basic neural net net template.\n\n    Args:\n      features: a dict containing key \"inputs\"\n      mode: training, evaluation or infer\n    \"\"\"\n    with tf.variable_scope(\"basic\", initializer=get_init(params)):\n      is_training = mode == tf.estimator.ModeKeys.TRAIN\n      actvn = get_activation(params)\n      x = features[\"inputs\"]\n      batch_size = tf.shape(x)[0]\n\n      nonzero = 0\n      activations = []\n      for i, feature_count in enumerate(params.layers):\n        with tf.variable_scope(\"layer_%d\" % i):\n          if params.layer_type == \"dense\":\n            x, w = model_utils.collect_vars(\n                lambda: model_utils.dense(x, feature_count, params, is_training)\n            )\n          elif params.layer_type == \"conv\":\n            x, w = model_utils.collect_vars(lambda: tf.layers.conv2d(\n                x, feature_count, params.kernel_size, padding=\"SAME\"))\n          if params.batch_norm:\n            x = tf.layers.batch_normalization(x, training=is_training)\n          x = actvn(x)\n          activations.append(x)\n      x = tf.reshape(x, [batch_size, params.layers[-1]])\n      with tf.variable_scope('logit'):\n        x = tf.layers.dense(x, params.output_shape[0], use_bias=False)\n\n      if mode in [model_utils.ModeKeys.PREDICT, model_utils.ModeKeys.ATTACK]:\n        predictions = {\n            'classes': tf.argmax(x, axis=1),\n            'logits': x,\n            'probabilities': 
tf.nn.softmax(x, name='softmax_tensor'),\n        }\n        return tf.estimator.EstimatorSpec(mode, predictions=predictions)\n\n      loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=x)\n      if params.smallify > 0.0:\n        loss += model_utils.switch_loss() * params.smallify\n\n      # Summaries\n      # ========================\n      if not params.use_tpu:\n        tf.summary.scalar(\"nonzero\", model_utils.nonzero_count())\n        tf.summary.scalar(\"percent_sparsity\", model_utils.percent_sparsity())\n      # ========================\n\n      return model_utils.model_top(labels, tf.nn.softmax(x, -1), loss, lr,\n                                   mode, params)\n\n  return basic\n"
  },
  {
    "path": "models/lenet/__init__.py",
    "content": "__all__ = [\"lenet\"]\n\nfrom .lenet import *\n"
  },
  {
    "path": "models/lenet/lenet.py",
    "content": "import tensorflow as tf\n\nfrom ..registry import register\n\nfrom ..utils.activations import get_activation\nfrom ..utils.dropouts import get_dropout\nfrom ..utils.initializations import get_init\nfrom ..utils.optimizers import get_optimizer\nfrom ..utils import model_utils\n\n\n@register(\"lenet\")\ndef get_lenet(hparams, lr):\n  \"\"\"Callable model function compatible with Experiment API.\n\n    Args:\n      params: a HParams object containing values for fields:\n      lr: learning rate variable\n    \"\"\"\n\n  def _conv(name, x, filter_size, in_filters, out_filters, strides, mode):\n    \"\"\"Convolution.\"\"\"\n    with tf.variable_scope(name):\n      kernel = tf.get_variable(\n          'DW', [filter_size, filter_size, in_filters, out_filters],\n          tf.float32)\n      is_training = mode == tf.estimator.ModeKeys.TRAIN\n      if hparams.dropout_type is not None:\n        dropout_fn = get_dropout(hparams.dropout_type)\n        kernel = dropout_fn(kernel, hparams, is_training)\n\n        # special case for variational\n        if hparams.dropout_type and \"variational\" in hparams.dropout_type:\n          kernel, log_alpha = kernel[0], kernel[1]\n          if is_training:\n            conved_mu = tf.nn.conv2d(\n                x, kernel, strides=strides, padding='VALID')\n            conved_si = tf.sqrt(\n                tf.nn.conv2d(\n                    tf.square(x),\n                    tf.exp(log_alpha) * tf.square(kernel),\n                    strides=strides,\n                    padding='VALID') + 1e-8)\n            return conved_mu + tf.random_normal(\n                tf.shape(conved_mu)) * conved_si, tf.count_nonzero(kernel)\n\n      return tf.nn.conv2d(x, kernel, strides, padding='VALID')\n\n  def lenet(features, labels, mode, params):\n    \"\"\"The lenet neural net net template.\n\n            Args:\n              features: a dict containing key \"inputs\"\n              mode: training, evaluation or infer\n            \"\"\"\n 
   with tf.variable_scope(\"lenet\", initializer=get_init(hparams)):\n      is_training = mode == tf.estimator.ModeKeys.TRAIN\n      actvn = get_activation(hparams)\n\n      if hparams.use_tpu and 'batch_size' in params.keys():\n        hparams.batch_size = params['batch_size']\n\n      # input layer\n      x = features[\"inputs\"]\n      x = model_utils.standardize_images(x)\n\n      # unflatten\n      x = tf.reshape(x, [hparams.batch_size] + hparams.input_shape)\n\n      # conv1\n      b_conv1 = tf.get_variable(\n          \"Variable\", initializer=tf.constant_initializer(0.1), shape=[6])\n      h_conv1 = _conv('conv1', x, 5, 3, 6, [1, 1, 1, 1], mode) + b_conv1\n      h_conv1 = tf.nn.relu(h_conv1)\n      h_pool1 = tf.nn.max_pool(\n          h_conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')\n\n      # conv2\n      b_conv2 = tf.get_variable(\n          \"Variable_1\", initializer=tf.constant_initializer(0.1), shape=[16])\n      h_conv2 = _conv('conv2', h_pool1, 5, 6, 16, [1, 1, 1, 1], mode) + b_conv2\n      h_conv2 = tf.nn.relu(h_conv2)\n      h_pool2 = tf.nn.max_pool(\n          h_conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')\n\n      # flatten for fc\n      h_pool2_flat = tf.reshape(h_pool2, [hparams.batch_size, -1])\n\n      # fc1\n      with tf.variable_scope('fc1'):\n        h_fc1 = tf.nn.relu(\n            model_utils.dense(h_pool2_flat, 500, hparams, is_training))\n\n      # fc2\n      with tf.variable_scope('fc2'):\n        y = model_utils.dense(h_fc1, 10, hparams, is_training, dropout=False)\n\n      if mode in [model_utils.ModeKeys.PREDICT, model_utils.ModeKeys.ATTACK]:\n        predictions = {\n            'classes': tf.argmax(y, axis=1),\n            'logits': y,\n            'probabilities': tf.nn.softmax(y, name='softmax_tensor'),\n        }\n\n        return tf.estimator.EstimatorSpec(mode, predictions=predictions)\n\n      loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=y)\n\n      if 
hparams.axis_aligned_cost:\n        negativity_cost, axis_alignedness_cost, one_bound = model_utils.axis_aligned_cost(\n            y, hparams)\n        masked_max = tf.abs(y) * (\n            1 - tf.one_hot(tf.argmax(tf.abs(y), -1), hparams.num_classes))\n        tf.summary.scalar(\n            \"logit_prior\",\n            tf.reduce_mean(\n                tf.to_float(\n                    tf.logical_and(masked_max >= 0.0, masked_max <= 0.1))))\n        tf.summary.scalar(\"avg_max\",\n                          tf.reduce_mean(tf.reduce_max(tf.abs(y), axis=-1)))\n        loss += hparams.axis_aligned_cost * tf.reduce_mean(\n            negativity_cost + axis_alignedness_cost + 20. * one_bound)\n\n      if hparams.logit_squeezing:\n        loss += hparams.logit_squeezing * tf.reduce_mean(y**2)\n\n      if hparams.clp:\n        loss += hparams.clp * tf.reduce_mean(\n            (y[:hparams.batch_size // 2] - y[hparams.batch_size // 2:])**2)\n\n      if hparams.dropout_type and \"variational\" in hparams.dropout_type:\n        # prior DKL part of the ELBO\n        graph = tf.get_default_graph()\n        node_defs = [\n            n for n in graph.as_graph_def().node if 'log_alpha' in n.name\n        ]\n        log_alphas = [\n            graph.get_tensor_by_name(n.name + \":0\") for n in node_defs\n        ]\n        divergences = [model_utils.dkl_qp(la) for la in log_alphas]\n        # combine to form the ELBO\n        N = float(50000)\n        dkl = tf.reduce_sum(tf.stack(divergences))\n\n        warmup_steps = 50000\n        inv_base = tf.exp(tf.log(0.01) / warmup_steps)\n        inv_decay = inv_base**(\n            warmup_steps - tf.to_float(tf.train.get_global_step()))\n\n        loss += (1. 
/ N) * dkl * inv_decay * hparams.var_scale\n\n      if hparams.smallify > 0.0:\n        loss += model_utils.switch_loss() * hparams.smallify\n\n      return model_utils.model_top(labels, tf.nn.softmax(y, -1), loss, lr,\n                                   mode, hparams)\n\n  return lenet\n"
  },
  {
    "path": "models/registry.py",
    "content": "from ..training.lr_schemes import get_lr\n\nimport tensorflow as tf\n\n_MODELS = dict()\n\n\ndef register(name):\n\n  def add_to_dict(fn):\n    global _MODELS\n    _MODELS[name] = fn\n    return fn\n\n  return add_to_dict\n\n\ndef get_model(hparams):\n\n  def model_fn(features, labels, mode, params=None):\n    lr = tf.constant(0.0)\n    if mode == tf.estimator.ModeKeys.TRAIN:\n      lr = get_lr(hparams)\n    return _MODELS[hparams.model](hparams, lr)(features, labels, mode, params)\n\n  return model_fn\n"
  },
  {
    "path": "models/resnet/__init__.py",
    "content": "__all__ = [\"resnet\"]\n"
  },
  {
    "path": "models/resnet/resnet.py",
    "content": "import tensorflow as tf\nimport numpy as np\n\nfrom ..utils import dropouts\nfrom ..utils.activations import get_activation\nfrom ..utils.dropouts import get_dropout, smallify_dropout\nfrom ..utils.initializations import get_init\nfrom ..registry import register\nfrom ..utils import model_utils\nfrom ..utils.model_utils import ModeKeys\nfrom ...training import tpu\n\n\n@register(\"resnet\")\ndef get_resnet(hparams, lr):\n  \"\"\"Callable model function compatible with Experiment API.\n\n          Args:\n            params: a HParams object containing values for fields:\n              use_bottleneck: bool to bottleneck the network\n              num_residual_units: number of residual units\n              num_classes: number of classes\n              batch_size: batch size\n              weight_decay_rate: weight decay rate\n          \"\"\"\n\n  def resnet(features, labels, mode, params):\n    if hparams.use_tpu and 'batch_size' in params.keys():\n      hparams.batch_size = params['batch_size']\n\n    is_training = mode == tf.estimator.ModeKeys.TRAIN\n\n    def _residual(x, out_filter, stride, projection=False):\n      \"\"\"Residual unit with 2 sub layers.\"\"\"\n      is_variational = hparams.dropout_type is not None and \"variational\" in hparams.dropout_type\n\n      orig_x = x\n      if not is_variational:\n        x = model_utils.batch_norm(x, hparams, is_training)\n        x = tf.nn.relu(x)\n\n      if projection:\n        orig_x = model_utils.conv(\n            x,\n            1,\n            out_filter,\n            hparams,\n            is_training=is_training,\n            strides=stride,\n            name=\"shortcut\")\n\n      with tf.variable_scope('sub1'):\n        x = model_utils.conv(\n            x,\n            3,\n            out_filter,\n            hparams,\n            is_training=is_training,\n            strides=stride,\n            name='conv1')\n\n        x = model_utils.batch_norm(x, hparams, is_training)\n        x = 
tf.nn.relu(x)\n\n      with tf.variable_scope('sub2'):\n        x = model_utils.conv(\n            x,\n            3,\n            out_filter,\n            hparams,\n            is_training=is_training,\n            strides=[1, 1, 1, 1],\n            name='conv2')\n\n      x += orig_x\n\n      return x\n\n    def _bottleneck_residual(x, out_filter, stride, projection=False):\n      \"\"\"Residual unit with 3 sub layers.\"\"\"\n\n      is_variational = hparams.dropout_type is not None and \"variational\" in hparams.dropout_type\n\n      orig_x = x\n      if not is_variational:\n        x = model_utils.batch_norm(x, hparams, is_training)\n        x = tf.nn.relu(x)\n\n      if projection:\n        orig_x = model_utils.conv(\n            x,\n            1,\n            4 * out_filter,\n            hparams,\n            is_training=is_training,\n            strides=stride,\n            name=\"shortcut\")\n\n      with tf.variable_scope('sub1'):\n        x = model_utils.conv(\n            x,\n            1,\n            out_filter,\n            hparams,\n            is_training=is_training,\n            strides=[1, 1, 1, 1],\n            name='conv1')\n        x = model_utils.batch_norm(x, hparams, is_training)\n        x = tf.nn.relu(x)\n      with tf.variable_scope('sub2'):\n        x = model_utils.conv(\n            x,\n            3,\n            out_filter,\n            hparams,\n            is_training=is_training,\n            strides=stride,\n            name='conv2')\n        x = model_utils.batch_norm(x, hparams, is_training)\n        x = tf.nn.relu(x)\n      with tf.variable_scope('sub3'):\n        x = model_utils.conv(\n            x,\n            1,\n            4 * out_filter,\n            hparams,\n            is_training=is_training,\n            strides=[1, 1, 1, 1],\n            name='conv3')\n\n      return orig_x + x\n\n    def _l1():\n      \"\"\"L1 weight decay loss.\"\"\"\n      if hparams.l1_norm == 0:\n        return 0\n\n      costs = []\n      
for var in tf.trainable_variables():\n        if \"DW\" in var.name and \"logit\" not in var.name:\n          costs.append(tf.reduce_mean(tf.abs(var)))\n\n      return tf.multiply(hparams.l1_norm, tf.add_n(costs))\n\n    def _fully_connected(x, out_dim):\n      \"\"\"FullyConnected layer for final output.\"\"\"\n      prev_dim = np.product(x.get_shape().as_list()[1:])\n      x = tf.reshape(x, [hparams.batch_size, prev_dim])\n      w = tf.get_variable('DW', [prev_dim, out_dim])\n      b = tf.get_variable(\n          'biases', [out_dim], initializer=tf.zeros_initializer())\n      return tf.nn.xw_plus_b(x, w, b)\n\n    def _global_avg_pool(x):\n      assert x.get_shape().ndims == 4\n      if hparams.data_format == \"channels_last\":\n        return tf.reduce_mean(x, [1, 2])\n\n      return tf.reduce_mean(x, [2, 3])\n\n    def _stride_arr(stride):\n      \"\"\"Map a stride scalar to the stride array for tf.nn.conv2d.\"\"\"\n      if hparams.data_format == \"channels_last\":\n        return [1, stride, stride, 1]\n\n      return [1, 1, stride, stride]\n\n    if mode == ModeKeys.PREDICT or mode == ModeKeys.ATTACK:\n      if \"labels\" in features:\n        labels = features[\"labels\"]\n\n    with tf.variable_scope(\"resnet\", initializer=get_init(hparams)):\n      hparams.mode = mode\n      strides = [1, 2, 2, 2]\n      res_func = (_residual\n                  if not hparams.use_bottleneck else _bottleneck_residual)\n      filters = hparams.residual_filters\n      large_input = hparams.input_shape[0] > 32\n\n      # 3 and 16 picked from example implementation\n      with tf.variable_scope('init'):\n        x = features[\"inputs\"]\n        stride = _stride_arr(2) if large_input else _stride_arr(1)\n        x = model_utils.conv(\n            x,\n            7,\n            filters[0],\n            hparams,\n            strides=stride,\n            dropout=False,\n            name='init_conv')\n\n        if large_input:\n          x = tf.layers.max_pooling2d(\n            
  inputs=x,\n              pool_size=3,\n              strides=2,\n              padding=\"SAME\",\n              data_format=hparams.data_format)\n\n      with tf.variable_scope('unit_1_0'):\n        x = res_func(x, filters[1], _stride_arr(strides[0]), True)\n\n      for i in range(1, hparams.residual_units[0]):\n        with tf.variable_scope('unit_1_%d' % i):\n          x = res_func(x, filters[1], _stride_arr(1), False)\n\n      with tf.variable_scope('unit_2_0'):\n        x = res_func(x, filters[2], _stride_arr(strides[1]), True)\n\n      for i in range(1, hparams.residual_units[1]):\n        with tf.variable_scope('unit_2_%d' % i):\n          x = res_func(x, filters[2], _stride_arr(1), False)\n\n      with tf.variable_scope('unit_3_0'):\n        x = res_func(x, filters[3], _stride_arr(strides[2]), True)\n\n      for i in range(1, hparams.residual_units[2]):\n        with tf.variable_scope('unit_3_%d' % i):\n          x = res_func(x, filters[3], _stride_arr(1), False)\n\n      if len(filters) == 5:\n        with tf.variable_scope('unit_4_0'):\n          x = res_func(x, filters[4], _stride_arr(strides[3]), True)\n\n        for i in range(1, hparams.residual_units[3]):\n          with tf.variable_scope('unit_4_%d' % i):\n            x = res_func(x, filters[4], _stride_arr(1), False)\n\n      x = model_utils.batch_norm(x, hparams, is_training)\n      x = tf.nn.relu(x)\n\n      with tf.variable_scope('unit_last'):\n        x = _global_avg_pool(x)\n\n      with tf.variable_scope('logit'):\n        logits = _fully_connected(x, hparams.num_classes)\n        predictions = tf.nn.softmax(logits)\n\n      if mode in [ModeKeys.PREDICT, ModeKeys.ATTACK]:\n\n        return tf.estimator.EstimatorSpec(\n            mode=mode,\n            predictions={\n                'classes': tf.argmax(predictions, axis=1),\n                'logits': logits,\n                'probabilities': predictions,\n            })\n\n      with tf.variable_scope('costs'):\n        xent = 
tf.losses.sparse_softmax_cross_entropy(\n            labels=labels, logits=logits)\n        cost = tf.reduce_mean(xent, name='xent')\n        if is_training:\n          cost += model_utils.weight_decay(hparams)\n          cost += _l1()\n\n          if hparams.dropout_type is not None:\n            if \"louizos\" in hparams.dropout_type:\n              cost += hparams.louizos_cost * model_utils.louizos_complexity_cost(\n                  hparams) / 50000\n\n            if \"variational\" in hparams.dropout_type:\n              # prior DKL part of the ELBO\n              graph = tf.get_default_graph()\n              node_defs = [\n                  n for n in graph.as_graph_def().node if 'log_alpha' in n.name\n              ]\n              log_alphas = [\n                  graph.get_tensor_by_name(n.name + \":0\") for n in node_defs\n              ]\n              print([\n                  n.name\n                  for n in graph.as_graph_def().node\n                  if 'log_alpha' in n.name\n              ])\n              print(\"found %i logalphas\" % len(log_alphas))\n              divergences = [dropouts.dkl_qp(la) for la in log_alphas]\n              # combine to form the ELBO\n              N = float(50000)\n              dkl = tf.reduce_sum(tf.stack(divergences))\n\n              warmup_steps = 50000\n              dkl = (1. 
/ N) * dkl * tf.minimum(\n                  1.0,\n                  tf.to_float(tf.train.get_global_step()) /\n                  warmup_steps) * hparams.var_scale\n              cost += dkl\n              tf.summary.scalar(\"dkl\", dkl)\n\n          if hparams.ard_cost > 0.0:\n            cost += model_utils.ard_cost() * hparams.ard_cost\n\n          if hparams.smallify > 0.0:\n            cost += model_utils.switch_loss() * hparams.smallify\n\n    # Summaries\n    # ========================\n    tf.summary.scalar(\"total_nonzero\", model_utils.nonzero_count())\n    all_weights = tf.concat(\n        [\n            tf.reshape(v, [-1])\n            for v in tf.trainable_variables()\n            if \"DW\" in v.name\n        ],\n        axis=0)\n    tf.summary.histogram(\"weights\", all_weights)\n    # ========================\n\n    return model_utils.model_top(labels, predictions, cost, lr, mode, hparams)\n\n  return resnet\n"
  },
  {
    "path": "models/utils/__init__.py",
    "content": "__all__ = [\n    \"activations\", \"dropouts\", \"initializations\", \"model_utils\", \"optimizers\"\n]\n\nfrom .activations import *\nfrom .dropouts import *\nfrom .initializations import *\nfrom .model_utils import *\nfrom .optimizers import *\n"
  },
  {
    "path": "models/utils/activations.py",
    "content": "import tensorflow as tf\n\n_ACTIVATION = dict()\n\n\ndef register(name):\n\n  def add_to_dict(fn):\n    global _ACTIVATION\n    _ACTIVATION[name] = fn\n    return fn\n\n  return add_to_dict\n\n\ndef get_activation(params):\n  return _ACTIVATION[params.activation](params)\n\n\n@register(\"relu\")\ndef relu(params):\n  return tf.nn.relu\n\n\n@register(\"brelu\")\ndef brelu(params):\n\n  def fn(a):\n    idx = tf.range(a.shape[-1])\n    idx = tf.mod(idx, 2)\n    idx = tf.cast(idx, tf.bool)\n\n    even = tf.nn.relu(a)\n    odd = -tf.nn.relu(-a)\n\n    return tf.where(idx, odd, even)\n\n  return fn\n\n\n@register(\"selu\")\ndef selu(params):\n  return tf.nn.selu\n\n\n@register(\"elu\")\ndef elu(params):\n  return tf.nn.elu\n\n\n@register(\"sigmoid\")\ndef sigmoid(params):\n  return tf.nn.sigmoid\n\n\n@register(\"swish\")\ndef swish(params):\n  return lambda x: tf.nn.sigmoid(x) * x\n\n\n@register(\"tanh\")\ndef tanh(params):\n  return tf.nn.tanh\n"
  },
  {
    "path": "models/utils/dropouts.py",
    "content": "import numpy as np\nimport tensorflow as tf\n\n_DROPOUTS = dict()\n\n\ndef register(name):\n\n  def add_to_dict(fn):\n    global _DROPOUTS\n    _DROPOUTS[name] = fn\n    return fn\n\n  return add_to_dict\n\n\ndef get_dropout(name):\n  return _DROPOUTS[name]\n\n\n\n@register(\"targeted_weight\")\ndef targeted_weight_dropout(w, params, is_training):\n  drop_rate = params.drop_rate\n  targ_perc = params.targ_rate\n\n  w_shape = w.shape\n  w = tf.reshape(w, [-1, w_shape[-1]])\n  norm = tf.abs(w)\n  idx = tf.to_int32(targ_perc * tf.to_float(tf.shape(w)[0]))\n  threshold = tf.contrib.framework.sort(norm, axis=0)[idx]\n  mask = norm < threshold[None, :]\n\n  if not is_training:\n    w = (1. - tf.to_float(mask)) * w\n    w = tf.reshape(w, w_shape)\n    return w\n\n  mask = tf.to_float(\n      tf.logical_and(tf.random_uniform(tf.shape(w)) < drop_rate, mask))\n  w = (1. - mask) * w\n  w = tf.reshape(w, w_shape)\n  return w\n\n\n@register(\"targeted_weight_random\")\ndef targeted_weight_random(w, params, is_training):\n  drop_rate = params.drop_rate\n  targ_perc = params.targ_rate\n\n  w_shape = w.shape\n  w = tf.reshape(w, [-1, w_shape[-1]])\n\n  switch = tf.get_variable(\n      \"mask\",\n      w.shape,\n      initializer=tf.random_uniform_initializer(),\n      trainable=False)\n\n  if is_training:\n    mask = tf.logical_and(switch < targ_perc,\n                          tf.random_uniform(w.shape) < drop_rate)\n  else:\n    mask = switch < targ_perc\n\n  mask = 1. 
- tf.to_float(mask)\n  mask = tf.stop_gradient(mask)\n\n  w = mask * w\n  w = tf.reshape(w, w_shape)\n  return w\n\n\n@register(\"ramping_targeted_weight_random\")\ndef ramping_targeted_weight_random(w, params, is_training):\n  drop_rate = params.drop_rate\n  targ_perc = 0.95 * params.targ_rate * tf.minimum(\n      1.0,\n      tf.to_float(tf.train.get_global_step()) / 20000.)\n  targ_perc = targ_perc + 0.05 * params.targ_rate * tf.maximum(\n      0.0,\n      tf.minimum(1.0,\n                 (tf.to_float(tf.train.get_global_step()) - 20000.) / 20000.))\n\n  w_shape = w.shape\n  w = tf.reshape(w, [-1, w_shape[-1]])\n\n  switch = tf.get_variable(\n      \"mask\",\n      w.shape,\n      initializer=tf.random_uniform_initializer(),\n      trainable=False)\n\n  if is_training:\n    mask = tf.logical_and(switch < targ_perc,\n                          tf.random_uniform(w.shape) < drop_rate)\n  else:\n    mask = switch < (targ_perc * drop_rate)\n\n  mask = 1. - tf.to_float(mask)\n  mask = tf.stop_gradient(mask)\n\n  w = mask * w\n  w = tf.reshape(w, w_shape)\n  return w\n\n\n@register(\"targeted_weight_piecewise\")\ndef targeted_weight_piecewise_dropout(w, params, is_training):\n  drop_rate = params.drop_rate * tf.minimum(\n      1.0,\n      tf.to_float(tf.train.get_global_step()) / 40000.)\n\n  targ_perc = 0.95 * params.targ_rate * tf.minimum(\n      1.0,\n      tf.to_float(tf.train.get_global_step()) / 20000.)\n  targ_perc = targ_perc + 0.05 * params.targ_rate * tf.maximum(\n      0.0,\n      tf.minimum(1.0,\n                 (tf.to_float(tf.train.get_global_step()) - 20000.) / 20000.))\n\n  w_shape = w.shape\n  w = tf.reshape(w, [-1, w_shape[-1]])\n  norm = tf.abs(w)\n  idx = tf.to_int32(targ_perc * tf.to_float(tf.shape(w)[0]))\n  threshold = tf.contrib.framework.sort(norm, axis=0)[idx]\n  mask = norm < threshold[None, :]\n\n  if not is_training:\n    w = w * (1 - tf.to_float(mask))\n    return tf.reshape(w, w_shape)\n\n  mask = tf.where(\n      tf.logical_and((1. 
- drop_rate) < tf.random_uniform(tf.shape(w)), mask),\n      tf.ones_like(w, dtype=tf.float32), tf.zeros_like(w, dtype=tf.float32))\n  w = (1 - mask) * w\n  w = tf.reshape(w, w_shape)\n  return w\n\n\n@register(\"targeted_unit_piecewise\")\ndef targeted_unit_piecewise(w, params, is_training):\n  drop_rate = params.drop_rate * tf.minimum(\n      1.0,\n      tf.to_float(tf.train.get_global_step()) / 40000.)\n\n  targ_perc = 0.95 * params.targ_rate * tf.minimum(\n      1.0,\n      tf.to_float(tf.train.get_global_step()) / 20000.)\n  targ_perc = targ_perc + 0.05 * params.targ_rate * tf.maximum(\n      0.0,\n      tf.minimum(1.0,\n                 (tf.to_float(tf.train.get_global_step()) - 20000.) / 20000.))\n\n  w_shape = w.shape\n  w = tf.reshape(w, [-1, w.shape[-1]])\n  norm = tf.norm(w, axis=0)\n  idx = tf.to_int32(targ_perc * tf.to_float(w.shape[1]))\n  sorted_norms = tf.contrib.framework.sort(norm)\n  threshold = sorted_norms[idx]\n  mask = (norm < threshold)[None, :]\n\n  if not is_training:\n    w = w * (1 - tf.to_float(mask))\n    return tf.reshape(w, w_shape)\n\n  mask = tf.tile(mask, [w.shape[0], 1])\n  mask = tf.where(\n      tf.logical_and((1. 
- drop_rate) < tf.random_uniform(tf.shape(w)), mask),\n      tf.ones_like(w, dtype=tf.float32), tf.zeros_like(w, dtype=tf.float32))\n  w = tf.reshape((1 - mask) * w, w_shape)\n  return w\n\n\n@register(\"delayed_targeted_weight_prune\")\ndef delayed_targeted_weight(w, params, is_training):\n  orig_w = w\n  targ_perc = params.targ_rate\n\n  w_shape = w.shape\n  w = tf.reshape(w, [-1, w_shape[-1]])\n  norm = tf.abs(w)\n  idx = tf.to_int32(targ_perc * tf.to_float(tf.shape(w)[0]))\n  threshold = tf.contrib.framework.sort(norm, axis=0)[idx]\n  mask = norm >= threshold[None, :]\n\n  w = w * tf.to_float(mask)\n  cond = tf.to_float(tf.train.get_global_step() >= params.dropout_delay_steps)\n  return cond * tf.reshape(w, w_shape) + (1 - cond) * orig_w\n\n\n@register(\"delayed_targeted_unit_prune\")\ndef delayed_targeted_unit(x, params, is_training):\n  orig_x = x\n\n  w = tf.reshape(x, [-1, x.shape[-1]])\n  norm = tf.norm(w, axis=0)\n  idx = int(params.targ_rate * int(w.shape[1]))\n  sorted_norms = tf.contrib.framework.sort(norm)\n  threshold = sorted_norms[idx]\n  mask = (norm >= threshold)[None, None]\n\n  w = w * tf.to_float(mask)\n  return tf.cond(\n      tf.greater(tf.train.get_global_step(), params.dropout_delay_steps),\n      lambda: tf.reshape(w, x.shape), lambda: orig_x)\n\n\n@register(\"untargeted_weight\")\ndef untargeted_weight(w, params, is_training):\n  if not is_training:\n    return w\n  return tf.nn.dropout(w, keep_prob=(1. - params.drop_rate))\n\n\n@register(\"targeted_unit\")\ndef targeted_unit_dropout(x, params, is_training):\n  w = tf.reshape(x, [-1, x.shape[-1]])\n  norm = tf.norm(w, axis=0)\n  idx = int(params.targ_rate * int(w.shape[1]))\n  sorted_norms = tf.contrib.framework.sort(norm)\n  threshold = sorted_norms[idx]\n  mask = (norm < threshold)[None, :]\n  mask = tf.tile(mask, [w.shape[0], 1])\n\n  if not is_training:\n    w = (1. 
- tf.to_float(mask)) * w\n    w = tf.reshape(w, x.shape)\n    return w\n\n  \n  mask = tf.where(\n      tf.logical_and((1. - params.drop_rate) < tf.random_uniform(tf.shape(w)),\n                     mask), tf.ones_like(w, dtype=tf.float32),\n      tf.zeros_like(w, dtype=tf.float32))\n  x = tf.reshape((1 - mask) * w, x.shape)\n  return x\n\n\n@register(\"targeted_unit_random\")\ndef targeted_unit_random(w, params, is_training):\n  drop_rate = params.drop_rate\n  targ_perc = params.targ_rate\n\n  w_shape = w.shape\n  w = tf.reshape(w, [-1, w_shape[-1]])\n\n  switch = tf.get_variable(\n      \"mask\",\n      w.shape[-1],\n      initializer=tf.random_uniform_initializer(),\n      trainable=False)\n\n  if is_training:\n    mask = tf.logical_and(switch < targ_perc,\n                          tf.random_uniform(switch.shape) < drop_rate)\n  else:\n    mask = switch < targ_perc\n\n  mask = 1. - tf.to_float(mask)\n  mask = tf.stop_gradient(mask[None, :])\n\n  w = mask * w\n  w = tf.reshape(w, w_shape)\n  return w\n\n\n@register(\"targeted_ard\")\ndef targeted_ard_dropout(w, x, params, is_training):\n  if not is_training:\n    return w\n  x = tf.reshape(x, [-1, x.shape[-1]])\n  activation_norms = tf.reduce_mean(tf.abs(x), axis=0)\n  w_shape = w.shape\n  w = tf.reshape(w, [-1, w_shape[-2], w_shape[-1]])\n  norm = tf.norm(w, axis=(0, 2)) * activation_norms\n  idx = int(params.targ_rate * int(w.shape[1]))\n  sorted_norms = tf.contrib.framework.sort(norm)\n  threshold = sorted_norms[idx]\n  mask = (norm < threshold)[None, :, None]\n  mask = tf.tile(mask, [w.shape[0], 1, w.shape[-1]])\n  mask = tf.where(\n      tf.logical_and((1. 
- params.drop_rate) < tf.random_uniform(tf.shape(w)),\n                     mask), tf.ones_like(w, dtype=tf.float32),\n      tf.zeros_like(w, dtype=tf.float32))\n  w = tf.reshape((1 - mask) * w, w_shape)\n  return w\n\n\n@register(\"untargeted_unit\")\ndef unit_dropout(w, params, is_training):\n  if not is_training:\n    return w\n  w_shape = w.shape\n  w = tf.reshape(w, [-1, w.shape[-1]])\n  mask = tf.to_float(\n      tf.random_uniform([int(w.shape[1])]) > params.drop_rate)[None, :]\n  w = tf.reshape(mask * w, w_shape)\n  return w / (1 - params.drop_rate)\n\n\n@register(\"louizos_weight\")\ndef louizos_weight_dropout(w, params, is_training):\n  with tf.variable_scope(\"louizos\"):\n    EPS = 1e-8\n    noise = (1 - EPS) * tf.random_uniform(w.shape) + (EPS / 2)\n    rate = np.log(1 - params.drop_rate) - np.log(params.drop_rate)\n    gates = tf.get_variable(\n        \"gates\",\n        shape=w.shape,\n        initializer=tf.random_normal_initializer(mean=rate, stddev=0.01))\n    if is_training:\n      s = tf.nn.sigmoid(\n          (gates + tf.log(noise / (1. - noise))) / params.louizos_beta)\n      s_bar = s * (\n          params.louizos_zeta - params.louizos_gamma) + params.louizos_gamma\n    else:\n      s = tf.nn.sigmoid(gates)\n      s_bar = s * (\n          params.louizos_zeta - params.louizos_gamma) + params.louizos_gamma\n    mask = tf.minimum(1., tf.maximum(0., s_bar))\n\n    return mask * w\n\n\n@register(\"louizos_unit\")\ndef louizos_unit_dropout(w, params, is_training):\n  with tf.variable_scope(\"louizos\"):\n    EPS = 1e-8\n    noise = (1 - EPS) * \\\n        tf.random_uniform([w.shape.as_list()[-1]]) + (EPS / 2)\n    rate = np.log(1 - params.drop_rate) - np.log(params.drop_rate)\n    gates = tf.get_variable(\n        \"gates\",\n        shape=[w.shape.as_list()[-1]],\n        initializer=tf.random_normal_initializer(mean=rate, stddev=0.01))\n    if is_training:\n      s = tf.nn.sigmoid(\n          (gates + tf.log(noise / (1. 
- noise))) / params.louizos_beta)\n      s_bar = s * (\n          params.louizos_zeta - params.louizos_gamma) + params.louizos_gamma\n    else:\n      s = tf.nn.sigmoid(gates)\n      s_bar = s * (\n          params.louizos_zeta - params.louizos_gamma) + params.louizos_gamma\n    mask = tf.minimum(1., tf.maximum(0., s_bar))\n\n    return mask * w\n\n\n# from https://github.com/BayesWatch/tf-variational-dropout/blob/master/variational_dropout.py\ndef log_sigma2_variable(shape, ard_init=-10.):\n  return tf.get_variable(\n      \"log_sigma2\", shape=shape, initializer=tf.constant_initializer(ard_init))\n\n\n# from https://github.com/BayesWatch/tf-variational-dropout/blob/master/variational_dropout.py\ndef get_log_alpha(log_sigma2, w):\n  log_alpha = clip(log_sigma2 - paranoid_log(tf.square(w)))\n  return tf.identity(log_alpha, name='log_alpha')\n\n\n# from https://github.com/BayesWatch/tf-variational-dropout/blob/master/variational_dropout.py\ndef paranoid_log(x, eps=1e-8):\n  v = tf.log(x + eps)\n  return v\n\n\n# from https://github.com/BayesWatch/tf-variational-dropout/blob/master/variational_dropout.py\ndef clip(x):\n  return tf.clip_by_value(x, -8., 8.)\n\n\ndef dkl_qp(log_alpha):\n  k1, k2, k3 = 0.63576, 1.8732, 1.48695\n  C = -k1\n  mdkl = k1 * tf.nn.sigmoid(k2 + k3 * log_alpha) - 0.5 * tf.log1p(\n      tf.exp(-log_alpha)) + C\n  return -tf.reduce_sum(mdkl)\n\n\n@register(\"variational\")\ndef variational_dropout(w, _, is_training):\n  with tf.variable_scope(\"variational\"):\n    log_sigma2 = log_sigma2_variable(w.get_shape())\n    log_alpha = get_log_alpha(log_sigma2, w)\n    select_mask = tf.cast(tf.less(log_alpha, 3), tf.float32)\n\n    if is_training:\n      return w, log_alpha\n\n    return w * select_mask, log_alpha\n\n\n@register(\"variational_unit\")\ndef variational_unit_dropout(w, _, is_training):\n  with tf.variable_scope(\"variational\"):\n    log_sigma2 = log_sigma2_variable(int(w.shape[-1]))\n    log_sigma2 = tf.reshape(log_sigma2, [1, 1, 1, 
-1])\n    log_sigma2 = tf.tile(log_sigma2, [w.shape[0], w.shape[1], w.shape[2], 1])\n    log_alpha = get_log_alpha(log_sigma2, w)\n    select_mask = tf.cast(tf.less(log_alpha, 3), tf.float32)\n\n    if is_training:\n      return w, log_alpha\n\n    return w * select_mask, log_alpha\n\n\n@register(\"smallify_dropout\")\ndef smallify_dropout(x, hparams, is_training):\n  with tf.variable_scope(\"smallify\", reuse=tf.AUTO_REUSE):\n    switch = tf.get_variable(\n        \"switch\",\n        shape=[1] * (len(x.shape) - 1) + [x.shape[-1]],\n        initializer=tf.random_uniform_initializer())\n\n    mask = tf.get_variable(\n        initializer=lambda: tf.ones_like(switch.initialized_value()),\n        name=\"mask\",\n        trainable=False)\n    exp_avg = tf.get_variable(\n        initializer=lambda: tf.sign(switch.initialized_value()),\n        name=\"exp_avg\",\n        trainable=False)\n    exp_std = tf.get_variable(\n        initializer=lambda: tf.zeros_like(switch.initialized_value()),\n        name=\"exp_std\",\n        trainable=False)\n    gates = switch * mask\n\n    batch_sign = tf.sign(switch)\n    diff = batch_sign - exp_avg\n\n    new_mask = tf.cast(tf.less(exp_std, hparams.smallify_thresh), tf.float32)\n\n    if not is_training:\n      return tf.identity(x * gates, name=\"smallified\")\n\n    with tf.control_dependencies([\n        tf.assign(mask, mask * new_mask),\n        tf.assign(\n            exp_std, hparams.smallify_mv * exp_std +\n            (1 - hparams.smallify_mv) * diff**2),\n        tf.assign(\n            exp_avg, hparams.smallify_mv * exp_avg +\n            (1 - hparams.smallify_mv) * batch_sign)\n    ]):\n      return tf.identity(x * gates, name=\"smallified\")\n\n\n@register(\"smallify_weight_dropout\")\ndef smallify_weight_dropout(x, hparams, is_training):\n  with tf.variable_scope(\"smallify\"):\n    switch = tf.get_variable(\n        \"switch\", shape=x.shape, initializer=tf.random_uniform_initializer())\n\n    mask = tf.get_variable(\n 
       initializer=lambda: tf.ones_like(switch.initialized_value()),\n        name=\"mask\",\n        trainable=False)\n    exp_avg = tf.get_variable(\n        initializer=lambda: tf.sign(switch.initialized_value()),\n        name=\"exp_avg\",\n        trainable=False)\n    exp_std = tf.get_variable(\n        initializer=lambda: tf.zeros_like(switch.initialized_value()),\n        name=\"exp_std\",\n        trainable=False)\n    gates = switch * mask\n\n    batch_sign = tf.sign(switch)\n    diff = batch_sign - exp_avg\n\n    new_mask = tf.cast(tf.less(exp_std, hparams.smallify_thresh), tf.float32)\n\n    if not is_training:\n      return tf.identity(x * gates, name=\"smallified\")\n\n    with tf.control_dependencies([\n        tf.assign(mask, mask * new_mask),\n        tf.assign(\n            exp_std, hparams.smallify_mv * exp_std +\n            (1 - hparams.smallify_mv) * diff**2),\n        tf.assign(\n            exp_avg, hparams.smallify_mv * exp_avg +\n            (1 - hparams.smallify_mv) * batch_sign)\n    ]):\n      return tf.identity(x * gates, name=\"smallified\")\n"
  },
  {
    "path": "models/utils/initializations.py",
    "content": "import tensorflow as tf\n\n_INIT = dict()\n\n\ndef register(name):\n\n  def add_to_dict(fn):\n    global _INIT\n    _INIT[name] = fn\n    return fn\n\n  return add_to_dict\n\n\ndef get_init(params):\n  return _INIT[params.initializer](params)\n\n\n@register(\"normal\")\ndef normal(params):\n  return tf.random_normal_initializer(mean=params.mean, stddev=params.sd)\n\n\n@register(\"constant\")\ndef constant(params):\n  return tf.constant_initializer(0.1, tf.float32)\n\n\n@register(\"uniform_unit_scaling\")\ndef uniform_unit_scaling(params):\n  return tf.uniform_unit_scaling_initializer()\n\n\n@register(\"glorot_normal_initializer\")\ndef glorot_normal_initializer(params):\n  return tf.glorot_normal_initializer()\n\n\n@register(\"glorot_uniform_initializer\")\ndef glorot_uniform_initializer(params):\n  return tf.glorot_uniform_initializer()\n\n\n@register(\"variance_scaling_initializer\")\ndef variance_scaling_initializer(params):\n  return tf.variance_scaling_initializer()\n\n\nclass RandomUnitScaling(tf.keras.initializers.Initializer):\n\n  def __call__(self, shape, dtype=None, partition_info=None):\n    if len(shape) == 2:\n      dim = (shape[0] + shape[1]) / 2.\n    elif len(shape) == 4:\n      dim = shape[0] * shape[1] * (shape[2] + shape[3]) / 2.\n\n    m = tf.sqrt(3 / tf.to_float(dim))\n    init = m * (2 * tf.random_uniform(shape) - 1)\n    return init\n\n\nclass RandomHadamardConstant(tf.keras.initializers.Initializer):\n\n  def __call__(self, shape, dtype=None, partition_info=None):\n    dim = (shape[0] + shape[1]) / 2.\n\n    flip = 2 * tf.round(tf.random_uniform(shape)) - 1\n    m = tf.pow(dim, -1 / 2.)\n    return m * flip\n\n\nclass RandomHadamardUnscaled(tf.keras.initializers.Initializer):\n\n  def __call__(self, shape, dtype=None, partition_info=None):\n    return 2 * tf.round(tf.random_uniform(shape)) - 1\n\n\nclass RandomWarpedUniform(tf.keras.initializers.Initializer):\n\n  def __init__(self, k=2):\n    self.k = k\n\n  def 
__call__(self, shape, dtype=None, partition_info=None):\n    if len(shape) == 2:\n      dim = (shape[0] + shape[1]) / 2.\n    elif len(shape) == 4:\n      dim = shape[0] * shape[1] * (shape[2] + shape[3]) / 2.\n\n    m = tf.sqrt(3 / tf.to_float(dim))\n\n    eps = 1e-10\n    unif = (1 - eps) * tf.random_uniform(shape) + eps / 2\n    skew_unif = tf.nn.sigmoid(self.k * tf.log(unif / (1 - unif)))\n    init = m * (2 * skew_unif - 1)\n    return init\n\n\n@register(\"warped_unif\")\ndef warped_unif(params):\n  return RandomWarpedUniform(params.k)\n\n\n@register(\"unit_scaling\")\ndef unit_scaling(params):\n  return RandomUnitScaling()\n\n\n@register(\"hadamard_constant\")\ndef hadamard_constant(params):\n  return RandomHadamardConstant()\n\n\n@register(\"hadamard_unscaled\")\ndef hadamard_unscaled(params):\n  return RandomHadamardUnscaled()"
  },
  {
    "path": "models/utils/model_utils.py",
    "content": "import operator\nfrom functools import reduce\n\nimport tensorflow as tf\nfrom tensorflow.contrib.tpu.python.tpu import tpu_estimator\n\nfrom . import dropouts\nfrom .optimizers import get_optimizer\nfrom ...training import tpu\n\n\nclass ModeKeys(object):\n  TRAIN = tf.estimator.ModeKeys.TRAIN\n  EVAL = tf.estimator.ModeKeys.EVAL\n  TEST = \"test\"\n  PREDICT = tf.estimator.ModeKeys.PREDICT\n  ATTACK = \"attack\"\n\n\ndef collect_vars(fn):\n  \"\"\"Collect all new variables created within `fn`.\n\n  Args:\n    fn: a function that takes no arguments and creates trainable tf.Variable\n      objects.\n\n  Returns:\n    outputs: the outputs of `fn()`.\n    new_vars: a list of the newly created variables.\n  \"\"\"\n  previous_vars = set(tf.trainable_variables())\n  outputs = fn()\n  current_vars = set(tf.trainable_variables())\n  new_vars = current_vars.difference(previous_vars)\n  return outputs, list(new_vars)\n\n\ndef dense(x, units, hparams, is_training, dropout=True):\n  with tf.variable_scope(None, default_name=\"dense\") as scope:\n    w = tf.get_variable(\"kernel\", shape=[x.shape[1], units], dtype=tf.float32)\n    b = tf.get_variable(\n        \"bias\",\n        shape=[units],\n        dtype=tf.float32,\n        initializer=tf.zeros_initializer())\n    if dropout and hparams.dropout_type is not None and is_training:\n      w = dropouts.get_dropout(hparams.dropout_type)(w, hparams, is_training)\n\n    w = tf.identity(w, name=\"post_dropout\")\n    y = tf.matmul(x, w) + b\n    return y\n\n\ndef conv(x,\n         filter_size,\n         out_filters,\n         hparams,\n         strides=[1, 1, 1, 1],\n         padding=\"SAME\",\n         is_training=False,\n         activation=None,\n         dropout=True,\n         name=None,\n         schit_layer=False):\n  \"\"\"Convolution.\"\"\"\n  with tf.variable_scope(name, default_name=\"conv2d\"):\n    if hparams.data_format == \"channels_last\":\n      in_filters = x.shape[-1]\n    else:\n      
in_filters = x.shape[1]\n\n    kernel = tf.get_variable(\n        'DW', [filter_size, filter_size, in_filters, out_filters], tf.float32)\n    use_dropout = hparams.dropout_type is not None and dropout\n\n    # schit layer\n    if schit_layer:\n      scale = tf.get_variable(\n          'scale',\n          kernel.shape[-1],\n          tf.float32,\n          initializer=tf.zeros_initializer())\n      kernel = hparams.lipschitz_constant * tf.nn.sigmoid(\n          scale) * kernel / tf.norm(\n              tf.reshape(kernel, shape=[-1, kernel.shape[-1]]), axis=0)\n\n    if use_dropout:\n      dropout_fn = dropouts.get_dropout(hparams.dropout_type)\n\n      if hparams.dropout_type == \"targeted_ard\":\n        kernel = dropout_fn(kernel, x, hparams, is_training)\n      else:\n        kernel = dropout_fn(kernel, hparams, is_training)\n\n      # special case for variational\n      if \"variational\" in hparams.dropout_type:\n        kernel, log_alpha = kernel[0], kernel[1]\n        if is_training:\n          conved_mu = tf.nn.conv2d(x, kernel, strides=strides, padding=padding)\n          conved_si = tf.sqrt(\n              tf.nn.conv2d(\n                  tf.square(x),\n                  tf.exp(log_alpha) * tf.square(kernel),\n                  strides=strides,\n                  padding=padding) + 1e-8)\n          conved = conved_mu + tf.random_normal(\n              tf.shape(conved_mu)) * conved_si\n\n          conved = tf.identity(conved, name=\"post_dropout\")\n          return conved\n\n    data_format = \"NHWC\" if hparams.data_format == \"channels_last\" else \"NCHW\"\n    conv = tf.nn.conv2d(\n        x, kernel, strides, padding=padding, data_format=data_format)\n\n    if activation:\n      conv = activation(conv)\n\n    conv = tf.identity(conv, name=\"post_dropout\")\n    return conv\n\n\ndef weight_decay_and_noise(loss, hparams, learning_rate, var_list=None):\n  \"\"\"Apply weight decay and weight noise.\"\"\"\n\n  weight_decay_loss = weight_decay(hparams)\n  
tf.summary.scalar(\"losses/weight_decay\", weight_decay_loss)\n  weight_noise_ops = weight_noise(hparams, learning_rate)\n  with tf.control_dependencies(weight_noise_ops):\n    loss = tf.identity(loss)\n\n  loss += weight_decay_loss\n  return loss\n\n\ndef weight_noise(hparams, learning_rate):\n  \"\"\"Apply weight noise to vars in var_list.\"\"\"\n  if not hparams.weight_noise_rate:\n    return [tf.no_op()]\n\n  tf.logging.info(\"Applying weight noise scaled by learning rate, \"\n                  \"noise_rate: %0.5f\", hparams.weight_noise_rate)\n  noise_ops = []\n\n  noise_vars = [v for v in tf.trainable_variables() if \"/body/\" in v.name]\n  for v in noise_vars:\n    with tf.device(v._ref().device):  # pylint: disable=protected-access\n      scale = hparams.weight_noise_rate * learning_rate * 0.001\n      tf.summary.scalar(\"weight_noise_scale\", scale)\n      noise = tf.truncated_normal(v.shape) * scale\n      noise_op = v.assign_add(noise)\n      noise_ops.append(noise_op)\n  return noise_ops\n\n\ndef weight_decay(hparams):\n  \"\"\"Apply weight decay to vars in var_list.\"\"\"\n  if not hparams.weight_decay_rate:\n    return 0.\n\n  only_features = hparams.weight_decay_only_features\n  var_list = [v for v in tf.trainable_variables()]\n  weight_decays = []\n  for v in var_list:\n    # Weight decay.\n    is_feature = any(n in v.name for n in hparams.weight_decay_weight_names)\n    if (not only_features) or is_feature:\n      if hparams.initializer == \"hadamard_unscaled\":\n        v_loss = tf.reduce_sum((tf.abs(v) - 1)**2) / 2\n      else:\n        v_loss = tf.nn.l2_loss(v)\n      weight_decays.append(v_loss)\n\n  return tf.reduce_sum(weight_decays, axis=0) * hparams.weight_decay_rate\n\n\ndef axis_aligned_cost(logits, hparams):\n  negativity_cost = tf.nn.relu(-logits)\n  max_mask = tf.one_hot(tf.argmax(tf.abs(logits), -1), hparams.num_classes)\n  min_logits = tf.abs(logits) * (1 - max_mask)\n  max_logit = tf.abs(logits) * max_mask\n  one_bound = 
tf.nn.relu(logits - hparams.logit_bound)\n  axis_alignedness_cost = tf.nn.relu(min_logits - 0.1 * hparams.logit_bound)\n\n  logits_packed = tf.reduce_all(tf.less(max_logit, hparams.logit_bound), -1)\n  logits_packed = tf.logical_and(logits_packed,\n                                 tf.reduce_all(\n                                     tf.less(min_logits,\n                                             0.1 * hparams.logit_bound), -1))\n  logits_packed = tf.reduce_mean(tf.to_float(logits_packed))\n  tf.summary.scalar(\"logits_packed\", logits_packed)\n  tf.summary.scalar(\n      \"logits_max\",\n      tf.to_float(tf.shape(max_logit)[-1]) * tf.reduce_mean(max_logit))\n\n  return negativity_cost, axis_alignedness_cost, one_bound\n\n\ndef ard_cost():\n  with tf.variable_scope(\"ard_cost\"):\n    cost = 0\n    for v in tf.trainable_variables():\n      if \"kernel\" in v.name or \"DW\" in v.name:\n        rv = tf.reshape(v, [-1, int(v.shape[-1])])\n        sq_rv = tf.square(rv)\n        sum_sq = tf.reduce_sum(sq_rv, axis=1, keepdims=True)\n        ard = sq_rv / (sum_sq / tf.cast(tf.shape(sq_rv)[1], tf.float32)\n                      ) - 0.5 * tf.log(sum_sq)\n        cost += tf.reduce_sum(ard)\n\n    return cost\n\n\ndef shape_list(x):\n  \"\"\"Return list of dims, statically where possible.\"\"\"\n  x = tf.convert_to_tensor(x)\n\n  # If unknown rank, return dynamic shape\n  if x.get_shape().dims is None:\n    return tf.shape(x)\n\n  static = x.get_shape().as_list()\n  shape = tf.shape(x)\n\n  ret = []\n  for i, dim in enumerate(static):\n    if dim is None:\n      dim = shape[i]\n    ret.append(dim)\n  return ret\n\n\ndef standardize_images(x):\n  \"\"\"Image standardization on batches.\"\"\"\n\n  with tf.name_scope(\"standardize_images\", [x]):\n    x = tf.to_float(x)\n    x_mean = tf.reduce_mean(x, axis=[1, 2, 3], keep_dims=True)\n    x_variance = tf.reduce_mean(\n        tf.square(x - x_mean), axis=[1, 2, 3], keep_dims=True)\n    x_shape = shape_list(x)\n    num_pixels = 
tf.to_float(x_shape[1] * x_shape[2] * x_shape[3])\n    x = (x - x_mean) / tf.maximum(tf.sqrt(x_variance), tf.rsqrt(num_pixels))\n    return x\n\n\ndef batch_norm(inputs, hparams, training):\n  \"\"\"Performs a batch normalization using a standard set of parameters.\"\"\"\n  # We set fused=True for a significant performance boost. See\n  # https://www.tensorflow.org/performance/performance_guide#common_fused_ops\n  if hparams.data_format == \"channels_first\":\n    axis = 1\n  else:\n    axis = -1\n\n  return tf.layers.batch_normalization(\n      inputs=inputs,\n      axis=axis,\n      momentum=0.997,\n      epsilon=0.001,\n      center=True,\n      scale=True,\n      training=training,\n      fused=True)\n\n\ndef louizos_complexity_cost(params):\n  gates = {\n      w.name.strip(\":0\"): w\n      for w in tf.trainable_variables()\n      if \"gates\" in w.name\n  }\n  names = list(gates.keys())\n  concat_gates = tf.concat([tf.reshape(gates[name], [-1]) for name in names],\n                           0)\n  if params.dropout_type == \"louizos_weight\":\n    complexity_cost = tf.nn.sigmoid(\n        concat_gates - params.louizos_beta * tf.\n        log(-1 * params.louizos_gamma / params.louizos_zeta))\n  elif params.dropout_type == \"louizos_unit\":\n    reshaped_gates = [\n        tf.reshape(gates[name], [-1, gates[name].shape[-1]]) for name in names\n    ]\n\n    parameters = []\n    for name in names:\n      g_name = name[:-len(\"louizos/gates\")] + \"DW\"\n      g = tf.contrib.framework.get_unique_variable(g_name)\n      parameters.extend(\n          [reduce(operator.mul,\n                  g.shape.as_list()[:-1], 1)] * g.shape.as_list()[-1])\n    group_sizes = tf.constant(parameters)\n    assert group_sizes.shape[0] == concat_gates.shape[0], \"{} != {}\".format(\n        group_sizes.shape[0], concat_gates.shape[0])\n\n    complexity_cost = tf.cast(group_sizes, tf.float32) * tf.nn.sigmoid(\n        concat_gates - params.louizos_beta * tf.\n        log(-1 * 
params.louizos_gamma / params.louizos_zeta))\n  return tf.reduce_sum(complexity_cost)\n\n\ndef switch_loss():\n  losses = 0\n\n  for v in tf.trainable_variables():\n    if \"switch\" in v.name:\n      losses += tf.reduce_sum(tf.abs(v))\n\n  tf.summary.scalar(\"switch_loss\", losses)\n  return losses\n\n\ndef nonzero_count():\n  nonzeroes = 0\n  for op in tf.get_default_graph().get_operations():\n    if \"post_dropout\" in op.name:\n      v = tf.get_default_graph().get_tensor_by_name(op.name + \":0\")\n      count = tf.to_float(tf.equal(v, 0.))\n      count = tf.reduce_sum(1 - count)\n      nonzeroes += count\n  return nonzeroes\n\n\ndef percent_sparsity():\n  nonzeroes = 0\n  total = 0\n  for op in tf.get_default_graph().get_operations():\n    if \"post_dropout\" in op.name:\n      v = tf.get_default_graph().get_tensor_by_name(op.name + \":0\")\n      count = tf.to_float(tf.equal(v, 0.))\n      count = tf.reduce_sum(1 - count)\n      nonzeroes += count\n      total += tf.size(v)\n  return tf.to_float(nonzeroes) / tf.to_float(total)\n\n\ndef convert(num, base, length=None):\n  ''' Converter from decimal to numeral systems from base 2 to base 10 '''\n  num = int(num)\n  base = int(base)\n  result = []\n  if num == 0:\n    result.append(0)\n  else:\n    while (num > 0):\n      result.append(num % base)\n      num //= base\n  # Reverse from LSB to MSB\n  result = result[::-1]\n  if length is not None:\n    n_to_fill = length - len(result)\n    if n_to_fill > 0:\n      result = [0] * n_to_fill + result\n  return result\n\n\ndef equal_mult(size, num_branches):\n  return [\n      tf.constant(1.0 / num_branches, shape=[size, 1, 1, 1], dtype=tf.float32)\n      for _ in range(num_branches)\n  ]\n\n\ndef uniform(size, num_branches):\n  return [\n      tf.random_uniform([size, 1, 1, 1], minval=0, maxval=1, dtype=tf.float32)\n      for _ in range(num_branches)\n  ]\n\n\ndef bernoulli(size, num_branches):\n  random = tf.random_uniform([size], maxval=num_branches, 
dtype=tf.int32)\n  bernoulli = tf.one_hot(random, depth=num_branches)\n  rand = tf.split(bernoulli, [1] * num_branches, 1)\n  rand = [tf.reshape(x, [-1, 1, 1, 1]) for x in rand]\n  return rand\n\n\ndef combine(rand_uniform, rand_bernoulli, num_branches):\n  return [\n      tf.concat([rand_uniform[i], rand_bernoulli[i]], axis=0)\n      for i in range(num_branches)\n  ]\n\n\ndef model_top(labels, preds, cost, lr, mode, hparams):\n  tf.summary.scalar(\"acc\",\n      tf.reduce_mean(\n          tf.to_float(\n              tf.equal(labels,\n                       tf.argmax(\n                           preds, axis=-1,\n                           output_type=tf.int32)))))\n  tf.summary.scalar(\"loss\", cost)\n\n  gs = tf.train.get_global_step()\n\n  if hparams.weight_decay_and_noise:\n    cost = weight_decay_and_noise(cost, hparams, lr)\n    cost = tf.identity(cost, name=\"total_loss\")\n  optimizer = get_optimizer(lr, hparams)\n\n  train_op = tf.contrib.layers.optimize_loss(\n      name=\"training\",\n      loss=cost,\n      global_step=gs,\n      learning_rate=lr,\n      clip_gradients=hparams.clip_grad_norm or None,\n      gradient_noise_scale=hparams.grad_noise_scale or None,\n      optimizer=optimizer,\n      colocate_gradients_with_ops=True)\n\n  if hparams.use_tpu:\n\n    def metric_fn(l, p):\n      return {\n          \"acc\":\n          tf.metrics.accuracy(\n              labels=l, predictions=tf.argmax(p, -1, output_type=tf.int32)),\n      }\n\n    host_call = None\n    if hparams.tpu_summarize:\n      host_call = tpu.create_host_call(hparams.output_dir)\n    tpu.remove_summaries()\n\n    if mode == tf.estimator.ModeKeys.EVAL:\n      return tpu_estimator.TPUEstimatorSpec(\n          mode=mode,\n          predictions=preds,\n          loss=cost,\n          eval_metrics=(metric_fn, [labels, preds]),\n          host_call=host_call)\n\n    return tpu_estimator.TPUEstimatorSpec(\n        mode=mode, loss=cost, train_op=train_op, host_call=host_call)\n\n  return 
tf.estimator.EstimatorSpec(\n      mode,\n      eval_metric_ops={\n          \"acc\":\n          tf.metrics.accuracy(\n              labels=labels,\n              predictions=tf.argmax(preds, axis=-1, output_type=tf.int32)),\n      },\n      loss=cost,\n      train_op=train_op)\n"
  },
  {
    "path": "models/utils/optimizers.py",
    "content": "import tensorflow as tf\n\n_OPTIMIZER = dict()\n\n\ndef register(name):\n\n  def add_to_dict(fn):\n    global _OPTIMIZER\n    _OPTIMIZER[name] = fn\n    return fn\n\n  return add_to_dict\n\n\ndef get_optimizer(lr, params):\n  optimizer = _OPTIMIZER[params.optimizer](lr, params)\n  if params.use_tpu:\n    optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)\n  return optimizer\n\n\n@register(\"sgd\")\ndef sgd(lr, params):\n  return tf.train.GradientDescentOptimizer(lr)\n\n\n@register(\"adam\")\ndef adam(lr, params):\n  return tf.train.AdamOptimizer(lr, beta1=params.beta1, beta2=params.beta2)\n\n\n@register(\"adagrad\")\ndef adagrad(lr, params):\n  return tf.train.AdagradOptimizer(lr)\n\n\n@register(\"momentum\")\ndef momentum(lr, params):\n  return tf.train.MomentumOptimizer(\n      lr, momentum=params.momentum, use_nesterov=params.use_nesterov)\n"
  },
  {
    "path": "models/vgg/__init__.py",
    "content": "__all__ = [\"vgg\"]\n"
  },
  {
    "path": "models/vgg/vgg.py",
    "content": "import tensorflow as tf\n\nfrom ..utils.activations import get_activation\nfrom ..utils.dropouts import get_dropout\nfrom ..utils.initializations import get_init\nfrom ..utils.optimizers import get_optimizer\nfrom ..registry import register\nfrom ..utils import model_utils\nfrom ..utils import dropouts\nfrom ...training import tpu\nimport six\n\nimport numpy as np\nfrom tensorflow.contrib.tpu.python.tpu import tpu_estimator, tpu_optimizer\n\n\ndef metric_fn(labels, predictions):\n  return {\n      \"acc\":\n      tf.metrics.accuracy(\n          labels=tf.argmax(labels, -1), predictions=tf.argmax(predictions,\n                                                              -1)),\n  }\n\n\n@register(\"vgg\")\ndef get_vgg(hparams, lr):\n  \"\"\"Callable model function compatible with Experiment API.\"\"\"\n\n  def vgg(features, labels, mode, params):\n    if hparams.use_tpu and 'batch_size' in params.keys():\n      hparams.batch_size = params['batch_size']\n\n    is_training = mode == tf.estimator.ModeKeys.TRAIN\n\n    inputs = features[\"inputs\"]\n    with tf.variable_scope(\"vgg\", initializer=get_init(hparams)):\n      total_nonzero = 0\n      conv1_1 = model_utils.conv(\n          inputs, 3, 64, hparams, name=\"conv1_1\", is_training=is_training)\n\n      conv1_1 = model_utils.batch_norm(conv1_1, hparams, is_training)\n      conv1_1 = tf.nn.relu(conv1_1)\n\n      conv1_2 = model_utils.conv(\n          conv1_1, 3, 64, hparams, name=\"conv1_2\", is_training=is_training)\n      conv1_2 = model_utils.batch_norm(conv1_2, hparams, is_training)\n      conv1_2 = tf.nn.relu(conv1_2)\n\n      pool1 = tf.layers.max_pooling2d(\n          conv1_2, 2, 2, padding=\"SAME\", name='pool1')\n\n      conv2_1 = model_utils.conv(\n          pool1, 3, 128, hparams, name=\"conv2_1\", is_training=is_training)\n      conv2_1 = model_utils.batch_norm(conv2_1, hparams, is_training)\n      conv2_1 = tf.nn.relu(conv2_1)\n\n      conv2_2 = model_utils.conv(\n          conv2_1, 3, 
128, hparams, name=\"conv2_2\", is_training=is_training)\n      conv2_2 = model_utils.batch_norm(conv2_2, hparams, is_training)\n      conv2_2 = tf.nn.relu(conv2_2)\n\n      pool2 = tf.layers.max_pooling2d(\n          conv2_2, 2, 2, padding=\"SAME\", name='pool2')\n\n      conv3_1 = model_utils.conv(\n          pool2, 3, 256, hparams, name=\"conv3_1\", is_training=is_training)\n      conv3_1 = model_utils.batch_norm(conv3_1, hparams, is_training)\n      conv3_1 = tf.nn.relu(conv3_1)\n\n      conv3_2 = model_utils.conv(\n          conv3_1, 3, 256, hparams, name=\"conv3_2\", is_training=is_training)\n      conv3_2 = model_utils.batch_norm(conv3_2, hparams, is_training)\n      conv3_2 = tf.nn.relu(conv3_2)\n\n      conv3_3 = model_utils.conv(\n          conv3_2, 3, 256, hparams, name=\"conv3_3\", is_training=is_training)\n      conv3_3 = model_utils.batch_norm(conv3_3, hparams, is_training)\n      conv3_3 = tf.nn.relu(conv3_3)\n\n      pool3 = tf.layers.max_pooling2d(\n          conv3_3, 2, 2, padding=\"SAME\", name='pool3')\n\n      conv4_1 = model_utils.conv(\n          pool3, 3, 512, hparams, name=\"conv4_1\", is_training=is_training)\n      conv4_1 = model_utils.batch_norm(conv4_1, hparams, is_training)\n      conv4_1 = tf.nn.relu(conv4_1)\n\n      conv4_2 = model_utils.conv(\n          conv4_1, 3, 512, hparams, name=\"conv4_2\", is_training=is_training)\n      conv4_2 = model_utils.batch_norm(conv4_2, hparams, is_training)\n      conv4_2 = tf.nn.relu(conv4_2)\n\n      conv4_3 = model_utils.conv(\n          conv4_2, 3, 512, hparams, name=\"conv4_3\", is_training=is_training)\n      conv4_3 = model_utils.batch_norm(conv4_3, hparams, is_training)\n      conv4_3 = tf.nn.relu(conv4_3)\n\n      pool4 = tf.layers.max_pooling2d(\n          conv4_3, 2, 2, padding=\"SAME\", name='pool4')\n\n      conv5_1 = model_utils.conv(\n          pool4, 3, 512, hparams, name=\"conv5_1\", is_training=is_training)\n      conv5_1 = model_utils.batch_norm(conv5_1, hparams, is_training)\n  
    conv5_1 = tf.nn.relu(conv5_1)\n\n      conv5_2 = model_utils.conv(\n          conv5_1, 3, 512, hparams, name=\"conv5_2\", is_training=is_training)\n      conv5_2 = model_utils.batch_norm(conv5_2, hparams, is_training)\n      conv5_2 = tf.nn.relu(conv5_2)\n\n      conv5_3 = model_utils.conv(\n          conv5_2, 3, 512, hparams, name=\"conv5_3\", is_training=is_training)\n      conv5_3 = model_utils.batch_norm(conv5_3, hparams, is_training)\n      conv5_3 = tf.nn.relu(conv5_3)\n\n      pool5 = tf.layers.max_pooling2d(\n          conv5_3, 2, 2, padding=\"SAME\", name='pool5')\n\n      flat_x = tf.reshape(pool5, [hparams.batch_size, 512])\n      fc6 = model_utils.batch_norm(\n          model_utils.dense(flat_x, 4096, hparams, is_training), hparams,\n          is_training)\n      fc7 = model_utils.batch_norm(\n          model_utils.dense(fc6, 4096, hparams, is_training), hparams,\n          is_training)\n\n      logits = tf.layers.dense(fc7, hparams.num_classes, name=\"logits\")\n      probs = tf.nn.softmax(logits, axis=-1)\n\n      if mode in [model_utils.ModeKeys.PREDICT, model_utils.ModeKeys.ATTACK]:\n        return tf.estimator.EstimatorSpec(\n            mode=mode,\n            predictions={\n                'classes': tf.argmax(probs, axis=1),\n                'logits': logits,\n                'probabilities': probs,\n            })\n\n      xent = tf.losses.sparse_softmax_cross_entropy(\n          labels=labels, logits=logits)\n      cost = tf.reduce_mean(xent, name='xent')\n      cost += model_utils.weight_decay(hparams)\n\n      tf.summary.scalar(\"total_nonzero\", model_utils.nonzero_count())\n      tf.summary.scalar(\"percent_sparsity\", model_utils.percent_sparsity())\n      if hparams.dropout_type is not None:\n        if \"louizos\" in hparams.dropout_type:\n          cost += hparams.louizos_cost * model_utils.louizos_complexity_cost(\n              hparams) / 50000\n\n        if \"variational\" in hparams.dropout_type:\n          # prior DKL part of 
the ELBO\n          graph = tf.get_default_graph()\n          node_defs = [\n              n for n in graph.as_graph_def().node if 'log_alpha' in n.name\n          ]\n          log_alphas = [\n              graph.get_tensor_by_name(n.name + \":0\") for n in node_defs\n          ]\n          print([\n              n.name\n              for n in graph.as_graph_def().node\n              if 'log_alpha' in n.name\n          ])\n          print(\"found %i logalphas\" % len(log_alphas))\n          divergences = [dropouts.dkl_qp(la) for la in log_alphas]\n          # combine to form the ELBO\n          N = float(50000)\n          dkl = tf.reduce_sum(tf.stack(divergences))\n\n          warmup_steps = 50000\n          dkl = (1. / N) * dkl * tf.minimum(\n              1.0,\n              tf.to_float(tf.train.get_global_step()) /\n              warmup_steps) * hparams.var_scale\n          cost += dkl\n          tf.summary.scalar(\"dkl\", dkl)\n\n      if hparams.ard_cost > 0.0:\n        cost += model_utils.ard_cost() * hparams.ard_cost\n\n      if hparams.smallify > 0.0:\n        cost += model_utils.switch_loss() * hparams.smallify\n\n    return model_utils.model_top(labels, probs, cost, lr, mode, hparams)\n\n  return vgg\n"
  },
  {
    "path": "requirements.txt",
    "content": "tensorflow>=1.9\nrequests>=2.19.1\ndl-cloud>=0.0.4"
  },
  {
    "path": "scripts/__init__.py",
    "content": "\n"
  },
  {
    "path": "scripts/prune/README.md",
    "content": "# Library for Pruning\n"
  },
  {
    "path": "scripts/prune/__init__.py",
    "content": ""
  },
  {
    "path": "scripts/prune/eval.py",
    "content": "import tensorflow as tf\nimport os\nimport numpy as np\n\nfrom ...hparams.registry import get_hparams\nfrom ...models.registry import get_model\nfrom ...data.registry import get_input_fns\nfrom ...training import flags\nfrom .prune import get_prune_fn, get_current_weights, get_louizos_masks, get_smallify_masks, prune_weights, is_prunable_weight\n\n\ndef init_flags():\n  tf.flags.DEFINE_string(\"model\", None, \"Which model to use.\")\n  tf.flags.DEFINE_string(\"data\", None, \"Which data to use.\")\n  tf.flags.DEFINE_string(\"env\", None, \"Which environment to use.\")\n  tf.flags.DEFINE_string(\"hparams\", None, \"Which hparams to use.\")\n  tf.flags.DEFINE_string(\"hparam_override\", \"\",\n                         \"Run-specific hparam settings to use.\")\n  tf.flags.DEFINE_string(\"output_dir\", None, \"The output directory.\")\n  tf.flags.DEFINE_string(\"data_dir\", None, \"The data directory.\")\n  tf.flags.DEFINE_integer(\"train_steps\", 10000,\n                          \"Number of training steps to perform.\")\n  tf.flags.DEFINE_integer(\"eval_every\", 1000,\n                          \"Number of steps between evaluations.\")\n  tf.flags.DEFINE_string(\n      \"post_weights_dir\", \"\",\n      \"folder of the weights, if not set defaults to output_dir\")\n  tf.flags.DEFINE_string(\"prune_percent\", \"0.5\",\n                         \"percent of weights to prune, comma separated\")\n  tf.flags.DEFINE_string(\"prune\", \"weight\", \"one_shot or fisher\")\n  tf.flags.DEFINE_boolean(\"variational\", False, \"use evaluate\")\n  tf.flags.DEFINE_string(\"eval_file\", \"eval_prune_results\",\n                         \"file to put results\")\n  tf.flags.DEFINE_integer(\"train_epochs\", None,\n                          \"Number of training epochs to perform.\")\n  tf.flags.DEFINE_integer(\"eval_steps\", None,\n                          \"Number of evaluation steps to perform.\")\n\ndef eval_model(FLAGS, hparam_name):\n  hparams = 
get_hparams(hparam_name)\n  hparams = hparams.parse(FLAGS.hparam_override)\n  hparams = flags.update_hparams(FLAGS, hparams)\n\n  model_fn = get_model(hparams)\n  _, _, test_input_fn = get_input_fns(hparams, generate=False)\n\n  features, labels = test_input_fn()\n  sess = tf.Session()\n  tf.train.create_global_step()\n  model_fn(features, labels, tf.estimator.ModeKeys.TRAIN)\n  saver = tf.train.Saver()\n  ckpt_dir = tf.train.latest_checkpoint(hparams.output_dir)\n  print(\"Loading model from...\", ckpt_dir)\n  saver.restore(sess, ckpt_dir)\n\n  evals = []\n  prune_percents = [float(i) for i in FLAGS.prune_percent.split(\",\")]\n\n  mode = \"standard\"\n  orig_weights = get_current_weights(sess)\n  louizos_masks, smallify_masks = None, None\n  if \"louizos\" in hparam_name:\n    louizos_masks = get_louizos_masks(sess, orig_weights)\n    mode = \"louizos\"\n  elif \"smallify\" in hparam_name:\n    smallify_masks = get_smallify_masks(sess, orig_weights)\n  elif \"variational\" in hparam_name:\n    mode = \"variational\"\n\n  for prune_percent in prune_percents:\n    if prune_percent > 0.0:\n      prune_fn = get_prune_fn(FLAGS.prune)(mode, k=prune_percent)\n      w_copy = dict(orig_weights)\n      sm_copy = dict(smallify_masks) if smallify_masks is not None else None\n      lm_copy = dict(louizos_masks) if louizos_masks is not None else None\n      post_weights_pruned, weight_counts = prune_weights(\n          prune_fn,\n          w_copy,\n          louizos_masks=lm_copy,\n          smallify_masks=sm_copy,\n          hparams=hparams)\n      print(\"current weight counts at {}: {}\".format(prune_percent,\n                                                     weight_counts))\n\n      print(\"there are \", len(tf.trainable_variables()), \" weights\")\n      for v in tf.trainable_variables():\n        if is_prunable_weight(v):\n          assign_op = v.assign(\n              np.reshape(post_weights_pruned[v.name.strip(\":0\")], v.shape))\n          sess.run(assign_op)\n\n   
 saver.save(sess, os.path.join(hparams.output_dir, \"tmp\", \"model\"))\n    estimator = tf.estimator.Estimator(\n        model_fn=tf.contrib.estimator.replicate_model_fn(model_fn),\n        model_dir=os.path.join(hparams.output_dir, \"tmp\"))\n    print(\n        f\"Processing pruning {prune_percent} of weights for {hparams.eval_steps} steps\"\n    )  \n    acc = estimator.evaluate(test_input_fn, hparams.eval_steps)['acc']\n    print(f\"Accuracy @ prune {100*prune_percent}% is {acc}\")\n    evals.append(acc)\n  return evals\n\n\ndef _run(FLAGS):\n  eval_file = open(FLAGS.eval_file, \"w\")\n\n  hparams_list = FLAGS.hparams.split(\",\")\n  total_evals = {}\n  for hparam_name in hparams_list:\n    evals = eval_model(FLAGS, hparam_name)\n\n    print(hparam_name, \":\", evals)\n    eval_file.writelines(\"{}:{}\\n\".format(hparam_name, evals))\n    total_evals[hparam_name] = evals\n    tf.reset_default_graph()\n\n  print(\"processed results:\", total_evals)\n  eval_file.close()\n\n\nif __name__ == \"__main__\":\n  init_flags()\n  FLAGS = tf.app.flags.FLAGS\n  _run(FLAGS)\n"
  },
  {
    "path": "scripts/prune/prune.py",
    "content": "import numpy as np\nimport tensorflow as tf\nimport statistics\nfrom ...models.utils import model_utils\n\n_PRUNE_FN = dict()\n\n\ndef register(fn):\n  global _PRUNE_FN\n  _PRUNE_FN[fn.__name__] = fn\n  return fn\n\n\ndef get_prune_fn(name):\n  return _PRUNE_FN[name]\n\n\n@register\ndef weight(mode, k=0.5):\n\n  if mode == \"standard\":\n\n    def prune(weight_dict, weight_key):\n      weights = weight_dict[weight_key]\n      w = weights.copy()\n      if len(weights.shape) == 4:\n        w = w.reshape([-1, weights.shape[-1]])\n\n      abs_w = np.abs(w)\n      idx = int(k * abs_w.shape[0])\n      med = np.sort(abs_w, axis=0)[idx:idx + 1]\n      mask = (abs_w >= med).astype(float)\n      pruned_w = mask * w\n\n      return pruned_w, mask\n  elif mode == \"variational\":\n\n    def prune(weight_dict, weight_key):\n      weights = weight_dict[weight_key]\n      if k == 0.0:\n        return weights, None\n      log_alpha = weight_dict[weight_key.strip(\"DW\") + \"variational/log_alpha\"]\n      w = weights.copy()\n      la = log_alpha.copy()\n      if len(weights.shape) == 4:\n        w = w.reshape([-1, weights.shape[-1]])\n        la = la.reshape([-1, weights.shape[-1]])\n\n      idx = int((1 - k) * la.shape[0])\n      med = np.sort(la, axis=0)[idx:idx + 1]\n      mask = (la < med).astype(float)\n      pruned_w = mask * w\n\n      return pruned_w, mask\n  elif mode == \"louizos\":\n\n    def prune(weight_dict, weight_key):\n      weights = weight_dict[weight_key]\n      w = weights.copy()\n      if len(weights.shape) == 4:\n        w = w.reshape([-1, weights.shape[-1]])\n\n      idx = int(k * w.shape[0])\n      med = np.sort(w, axis=0)[idx:idx + 1]\n      mask = (w >= med).astype(float)\n      pruned_w = mask * w\n\n      return pruned_w, mask\n\n  return prune\n\n\n@register\ndef unit(mode, k=0.5):\n\n  if mode == \"standard\" or mode == \"variational\":\n\n    def prune(weight_dict, weight_key):\n      weights = weight_dict[weight_key]\n      w = 
weights.copy()\n      if len(weights.shape) == 4:\n        w = w.reshape([-1, weights.shape[-1]])\n      norm = np.linalg.norm(w, axis=0)\n      idx = int(k * norm.shape[0])\n      med = np.sort(norm, axis=0)[idx]\n      mask = (norm >= med).astype(float)\n      pruned_w = mask * w\n\n      return pruned_w, mask\n  elif mode == \"louizos\":\n\n    def prune(weight_dict, weight_key):\n      weights = weight_dict[weight_key]\n      w = weights.copy()\n      assert len(weights.shape) == 1\n      idx = int(k * w.shape[0])\n      med = np.sort(w, axis=0)[idx]\n      mask = (w >= med).astype(float)\n      pruned_w = mask * w\n\n      return pruned_w, mask\n\n  return prune\n\n\n@register\ndef ard(mode, k=0.5):\n\n  def prune(weight_dict, weight_key):\n    weights = weight_dict[weight_key]\n    w = weights.copy()\n    if len(weights.shape) == 4:\n      w = w.reshape([-1, weights.shape[-1]])\n    norm = np.linalg.norm(w, axis=1, keepdims=True)\n    idx = int(k * norm.shape[0])\n    med = np.sort(norm, axis=0)[idx]\n    mask = (norm >= med).astype(float)\n    pruned_w = mask * w\n\n    return pruned_w, mask\n\n  return prune\n\n\ndef prune_weights(prune_fn,\n                  weights,\n                  louizos_masks=None,\n                  smallify_masks=None,\n                  hparams=None):\n  weights_pruned = {}\n\n  pre_prune_nonzero = 0\n  pre_prune_total = 0\n  if louizos_masks:\n    orig_weights = dict(weights)\n    for weight_name in weights:\n      if weight_name not in louizos_masks.keys():\n        print(\"WARN louizos: mask not found for {}\".format(weight_name))\n        continue\n      weights[weight_name] = louizos_masks[weight_name]\n  elif smallify_masks:\n    orig_weights = dict(weights)\n    for weight_name in weights:\n      if weight_name not in smallify_masks.keys():\n        print(\"WARN smallify: not pruning {}\".format(weight_name))\n        continue\n      mask = smallify_masks[weight_name]\n      weights[weight_name] = weights[weight_name] * mask\n\n 
 for weight_name in weights:\n    if \"variational\" in weight_name:\n      print(\"WARN variational: not pruning {}\".format(weight_name))\n      continue\n\n    pre_prune_nonzero += np.count_nonzero(weights[weight_name])\n    pre_prune_total += weights[weight_name].size\n\n    weights_pruned[weight_name], mask = prune_fn(weights, weight_name)\n    if louizos_masks or smallify_masks:\n      print(\"applied masks to\", weight_name)\n      weights_pruned[weight_name] = mask * orig_weights[weight_name].reshape(\n          [-1, orig_weights[weight_name].shape[-1]])\n\n  return weights_pruned, {\n      \"pre_prune_nonzero\": pre_prune_nonzero,\n      \"pre_prune_total\": pre_prune_total\n  }\n\n\ndef get_louizos_masks(sess, weights):\n  masks = {}\n  for weight_name in weights:\n    m_name = weight_name.strip(\"DW\") + \"louizos/gates\"\n    m = tf.contrib.framework.get_variables_by_name(m_name)\n    assert len(m) == 1\n    m = m[0]\n    masks[weight_name] = sess.run(m)\n\n  return masks\n\n\ndef get_smallify_masks(sess, weights):\n  masks = {}\n  for weight_name in weights:\n    switch_name = weight_name.strip(\"DW\") + \"smallify/switch\"\n    mask_name = weight_name.strip(\"DW\") + \"smallify/mask\"\n    switch = tf.contrib.framework.get_variables_by_name(switch_name)\n    mask = tf.contrib.framework.get_variables_by_name(mask_name)\n    assert len(switch) == 1 and len(mask) == 1\n    switch, mask = switch[0], mask[0]\n    switch, mask = sess.run((switch, mask))\n\n    masks[weight_name] = switch * mask\n\n  return masks\n\n\ndef is_prunable_weight(weight):\n  necessary_tokens = [\"kernel\", \"DW\", \"variational\"]\n  blacklisted_tokens = [\"logit\", \"fc\", \"init\", \"switch\", \"mask\", \"log_sigma\"]\n\n  contains_a_necessary_token = any(t in weight.name for t in necessary_tokens)\n  contains_a_blacklisted_token = any(\n      t in weight.name for t in blacklisted_tokens)\n\n  is_prunable = contains_a_necessary_token and not contains_a_blacklisted_token\n\n  if 
not is_prunable:\n    print(\"WARN: not pruning %s\" % weight.name)\n\n  return is_prunable\n\n\ndef get_current_weights(sess):\n  weights = {}\n  variables = {}\n  for v in tf.trainable_variables():\n    if is_prunable_weight(v):\n      name = v.name.strip(\":0\")\n      variables[name] = v\n\n  graph = tf.get_default_graph()\n  node_defs = [n for n in graph.as_graph_def().node if 'log_alpha' in n.name]\n\n  for n in node_defs:\n    weights[n.name] = sess.run(graph.get_tensor_by_name(n.name + \":0\"))\n\n  for weight_name, w in variables.items():\n    weights[weight_name] = sess.run(w)\n\n  return weights\n\n\ndef prune_sess_weights(sess, prune_percent, FLAGS, hparams):\n  current_weights = get_current_weights(sess)\n  prune_fn = get_prune_fn(FLAGS.prune)(k=prune_percent)\n  current_weights_pruned = prune_weights(prune_fn, current_weights, None,\n                                         hparams)\n\n  print(\"there are \", len(tf.trainable_variables()), \" weights\")\n  for v in tf.trainable_variables():\n    if is_prunable_weight(v):\n      assign_op = v.assign(\n          np.reshape(current_weights_pruned[v.name.strip(\":0\")], v.shape))\n      sess.run(assign_op)\n"
  },
  {
    "path": "train.py",
    "content": "import cloud\nimport os\nimport sys\nimport subprocess\nimport random\nimport tensorflow as tf\nimport numpy as np\nimport time\nimport logging\n\nfrom .hparams.registry import get_hparams\nfrom .models.registry import get_model\nfrom .data.registry import get_input_fns\nfrom .training.lr_schemes import get_lr\nfrom .training.envs import get_env\nfrom .training import flags\nfrom tensorflow.contrib.tpu.python.tpu import tpu_config\nfrom tensorflow.contrib.tpu.python.tpu import tpu_estimator\n\n\ndef init_flags():\n  tf.flags.DEFINE_string(\"env\", None, \"Which environment to use.\")  # required\n  tf.flags.DEFINE_string(\"hparams\", None, \"Which hparams to use.\")  # required\n  # Utility flags\n  tf.flags.DEFINE_string(\"hparam_override\", \"\",\n                         \"Run-specific hparam settings to use.\")\n  tf.flags.DEFINE_boolean(\"fresh\", False, \"Remove output_dir before running.\")\n  tf.flags.DEFINE_integer(\"seed\", None, \"Random seed.\")\n  tf.flags.DEFINE_integer(\"train_epochs\", None,\n                          \"Number of training epochs to perform.\")\n  tf.flags.DEFINE_integer(\"eval_steps\", None,\n                          \"Number of evaluation steps to perform.\")\n  # TPU flags\n  tf.flags.DEFINE_string(\"tpu_name\", \"\", \"Name of TPU(s)\")\n  tf.flags.DEFINE_integer(\n      \"tpu_iterations_per_loop\", 1000,\n      \"The number of training steps to run on TPU before\"\n      \"returning control to CPU.\")\n  tf.flags.DEFINE_integer(\n      \"tpu_shards\", 8, \"The number of TPU shards in the system \"\n      \"(a single Cloud TPU has 8 shards.\")\n  tf.flags.DEFINE_boolean(\n      \"tpu_summarize\", False, \"Save summaries for TensorBoard. 
\"\n      \"Warning: this will slow down execution.\")\n  tf.flags.DEFINE_boolean(\"tpu_dedicated\", False,\n                          \"Do not use preemptible TPUs.\")\n  tf.flags.DEFINE_string(\"data_dir\", None, \"The data directory.\")\n  tf.flags.DEFINE_string(\"output_dir\", None, \"The output directory.\")\n  tf.flags.DEFINE_integer(\"eval_every\", 1000,\n                          \"Number of steps between evaluations.\")\n\n\ntf.logging.set_verbosity(tf.logging.INFO)\nFLAGS = None\n\n\ndef init_random_seeds():\n  tf.set_random_seed(FLAGS.seed)\n  random.seed(FLAGS.seed)\n  np.random.seed(FLAGS.seed)\n\n\ndef init_model(hparams_name):\n  flags.validate_flags(FLAGS)\n\n  tf.reset_default_graph()\n\n  hparams = get_hparams(hparams_name)\n  hparams = hparams.parse(FLAGS.hparam_override)\n  hparams = flags.update_hparams(FLAGS, hparams, hparams_name)\n\n  # set larger eval_every for TPUs to improve utilization\n  if FLAGS.env == \"tpu\":\n    FLAGS.eval_every = max(FLAGS.eval_every, 5000)\n    hparams.tpu_summarize = FLAGS.tpu_summarize\n\n  tf.logging.warn(\"\\n-----------------------------------------\\n\"\n                  \"BEGINNING RUN:\\n\"\n                  \"\\t hparams: %s\\n\"\n                  \"\\t output_dir: %s\\n\"\n                  \"\\t data_dir: %s\\n\"\n                  \"-----------------------------------------\\n\" %\n                  (hparams_name, hparams.output_dir, hparams.data_dir))\n\n  return hparams\n\n\ndef construct_estimator(model_fn, hparams, tpu=None):\n  if hparams.use_tpu:\n    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(\n        tpu=tpu.name)\n    master = tpu_cluster_resolver.get_master()\n    config = tpu_config.RunConfig(\n        master=master,\n        evaluation_master=master,\n        model_dir=hparams.output_dir,\n        session_config=tf.ConfigProto(\n            allow_soft_placement=True, log_device_placement=True),\n        tpu_config=tpu_config.TPUConfig(\n            
iterations_per_loop=FLAGS.tpu_iterations_per_loop,\n            num_shards=FLAGS.tpu_shards),\n        save_checkpoints_steps=FLAGS.eval_every)\n    estimator = tpu_estimator.TPUEstimator(\n        use_tpu=hparams.use_tpu,\n        model_fn=model_fn,\n        model_dir=hparams.output_dir,\n        config=config,\n        train_batch_size=hparams.batch_size,\n        eval_batch_size=hparams.batch_size)\n  else:\n    gpu_config = tf.ConfigProto(allow_soft_placement=True)\n    gpu_config.gpu_options.allow_growth = True\n    run_config = tf.estimator.RunConfig(\n        save_checkpoints_steps=FLAGS.eval_every, session_config=gpu_config)\n\n    estimator = tf.estimator.Estimator(\n        model_fn=tf.contrib.estimator.replicate_model_fn(model_fn),\n        model_dir=hparams.output_dir,\n        config=run_config)\n\n  return estimator\n\n\ndef _run(hparams_name):\n  \"\"\"Run training, evaluation and inference.\"\"\"\n  hparams = init_model(hparams_name)\n  original_batch_size = hparams.batch_size\n  if tf.gfile.Exists(hparams.output_dir) and FLAGS.fresh:\n    tf.gfile.DeleteRecursively(hparams.output_dir)\n\n  if not tf.gfile.Exists(hparams.output_dir):\n    tf.gfile.MakeDirs(hparams.output_dir)\n  model_fn = get_model(hparams)\n  train_input_fn, eval_input_fn, test_input_fn = get_input_fns(hparams)\n\n  tpu = None\n  if hparams.use_tpu:\n    cloud.instance.tpu.clean()\n    tpu = cloud.instance.tpu.get(preemptible=not FLAGS.tpu_dedicated)\n\n  estimator = construct_estimator(model_fn, hparams, tpu)\n\n  if not hparams.use_tpu:\n    features, labels = train_input_fn()\n    sess = tf.Session()\n    tf.train.get_or_create_global_step()\n\n    model_fn(features, labels, tf.estimator.ModeKeys.TRAIN)\n    sess.run(tf.global_variables_initializer())\n\n  # output metadata about the run\n  with tf.gfile.GFile(os.path.join(hparams.output_dir, 'hparams.txt'),\n                      'w') as hparams_file:\n    hparams_file.write(\"{}\\n\".format(time.time()))\n    
hparams_file.write(\"{}\\n\".format(str(hparams)))\n\n  def loop(steps=FLAGS.eval_every):\n    estimator.train(train_input_fn, steps=steps)\n    if eval_input_fn:\n      estimator.evaluate(eval_input_fn, steps=hparams.eval_steps, name=\"eval\")\n    if test_input_fn:\n      estimator.evaluate(test_input_fn, steps=hparams.eval_steps, name=\"test\")\n\n  loop(1)\n\n  steps = estimator.get_variable_value(\"global_step\")\n  k = steps * original_batch_size / float(hparams.epoch_size)\n  while k <= hparams.train_epochs:\n    tf.logging.info(\"Beginning epoch %f / %d\" % (k, hparams.train_epochs))\n\n    if tpu and not tpu.usable:\n      tpu.delete(async=True)\n      tpu = cloud.instance.tpu.get(preemptible=not FLAGS.tpu_dedicated)\n      estimator = construct_estimator(model_fn, hparams, tpu)\n\n    loop()\n\n    steps = estimator.get_variable_value(\"global_step\")\n    k = steps * original_batch_size / float(hparams.epoch_size)\n\n\ndef main(_):\n  global FLAGS\n  FLAGS = tf.app.flags.FLAGS\n\n  init_random_seeds()\n  if FLAGS.env != \"local\":\n    cloud.connect()\n  for hparams_name in FLAGS.hparams.split(\",\"):\n    _run(hparams_name)\n\n\nif __name__ == \"__main__\":\n  init_flags()\n  tf.app.run()\n"
  },
  {
    "path": "training/__init__.py",
    "content": "__all__ = [\"lr_schemes\", \"tpu\", \"flags\"]\n\nfrom .lr_schemes import *\nfrom .tpu import *\nfrom .flags import *\n"
  },
  {
    "path": "training/envs.py",
    "content": "_ENVS = dict()\n\n\ndef register(cls):\n  global _ENVS\n  _ENVS[cls.__name__.lower()] = cls()\n  return cls\n\n\ndef get_env(name):\n  return _ENVS[name]\n\n\n@register\nclass GCP(object):\n  data_dir = \"/path/to/your/data\"\n  output_dir = \"/path/to/your/output\"\n\n\n@register\nclass TPU(object):\n  data_dir = \"/path/to/your/data\"\n  output_dir = \"/path/to/your/output\"\n\n\n@register\nclass Local(object):\n  data_dir = \"/tmp/data\"\n  output_dir = \"/tmp/runs\"\n"
  },
  {
    "path": "training/flags.py",
    "content": "import getpass\nimport os\nimport subprocess\n\nimport tensorflow as tf\n\nfrom .envs import get_env\n\n\ndef validate_flags(FLAGS):\n  messages = []\n  if not FLAGS.env:\n    messages.append(\"Missing required flag --env\")\n  if not FLAGS.hparams:\n    messages.append(\"Missing required flag --hparams\")\n\n  if len(messages) > 0:\n    raise Exception(\"\\n\".join(messages))\n\n  return FLAGS\n\n\ndef update_hparams(FLAGS, hparams, hparams_name):\n  hparams.env = FLAGS.env\n  hparams.use_tpu = hparams.env == \"tpu\"\n  hparams.train_epochs = FLAGS.train_epochs or hparams.train_epochs\n  hparams.eval_steps = FLAGS.eval_steps or hparams.eval_steps\n\n  env = get_env(FLAGS.env)\n  hparams.data_dir = os.path.join(FLAGS.data_dir or env.data_dir, hparams.data)\n  hparams.output_dir = os.path.join(env.output_dir, FLAGS.hparams)\n\n  return hparams\n"
  },
  {
    "path": "training/lr_schemes.py",
    "content": "import tensorflow as tf\n\n_LR = dict()\n\n\ndef register(name):\n\n  def add_to_dict(fn):\n    global _LR\n    _LR[name] = fn\n    return fn\n\n  return add_to_dict\n\n\ndef get_lr(params):\n  gs = tf.train.get_global_step()\n  return _LR[params.lr_scheme](gs, params)\n\n\n@register(\"constant\")\ndef constant(gs, params):\n  return tf.constant(params.learning_rate)\n\n\n@register(\"exp\")\ndef exponential_decay(gs, params, delay=0):\n  gs -= delay\n  return tf.train.exponential_decay(\n      params.learning_rate,\n      gs,\n      params.learning_rate_decay_interval,\n      params.learning_rate_decay_rate,\n      staircase=params.staircased)\n\n\n@register(\"lin\")\ndef linear_decay(gs, params, delay=0):\n  gs -= delay\n  return (\n      params.learning_rate -\n      (tf.to_float(gs) / (params.train_steps - delay)) * params.learning_rate)\n\n\n@register(\"delay_exp\")\ndef delayed_exponential_decay(gs, params):\n  d = params.delay\n  return tf.cond(\n      tf.greater(gs, d), lambda: exponential_decay(gs - d, params, delay=d),\n      lambda: params.learning_rate)\n\n\n@register(\"delay_lin\")\ndef delayed_linear_decay(gs, params):\n  d = params.delay\n  return tf.cond(\n      tf.greater(gs, d), lambda: linear_decay(gs - d, params, delay=d),\n      lambda: params.learning_rate)\n\n\n@register(\"warmup_resnet\")\ndef warmup_resnet(gs, params):\n  warmup_steps = params.warmup_steps\n  inv_base = tf.exp(tf.log(0.01) / warmup_steps)\n  inv_decay = inv_base**(warmup_steps - tf.to_float(gs))\n\n  epoch = params.epoch_size // params.batch_size\n  boundaries = [epoch * 30, epoch * 60, epoch * 80, epoch * 90]\n  values = [1e0, 1e-1, 1e-2, 1e-3, 1e-4]\n  lr = tf.train.piecewise_constant(\n      gs - warmup_steps, boundaries=boundaries, values=values)\n\n  return tf.cond(\n      tf.greater(gs, warmup_steps), lambda: lr,\n      lambda: inv_decay * params.learning_rate)\n\n\n@register(\"resnet\")\ndef resnet(gs, params):\n  return tf.cond(\n      tf.less(gs, 
40000),\n      lambda: params.learning_rate,\n      lambda: tf.cond(\n          tf.less(gs, 60000),\n          lambda: params.learning_rate*0.1,\n          lambda: tf.cond(\n              tf.less(gs, 80000),\n              lambda: params.learning_rate * 0.01,\n              lambda: params.learning_rate * 0.001)))\n\n\n@register(\"lenet\")\ndef lenet(gs, _):\n  return tf.cond(\n      tf.less(gs, 80000), lambda: 0.05,\n      lambda: tf.cond(tf.less(gs, 120000), lambda: 0.005, lambda: 0.0005))\n\n\n@register(\"steps\")\ndef stepped_lr(gs, params):\n  lr = params.lr_values[-1]\n  for step, value in reversed(list(zip(params.lr_steps, params.lr_values))):\n    lr = tf.cond(tf.greater(gs, step), lambda: lr, lambda: value)\n  return lr\n\n\n@register(\"warmup_linear_decay\")\ndef warmup_linear_decay(gs, params):\n  d = params.delay\n  warmup_steps = params.warmup_steps\n  inv_base = tf.exp(tf.log(0.01) / warmup_steps)\n  inv_decay = inv_base**(warmup_steps - tf.to_float(gs))\n\n  return tf.cond(\n      tf.greater(gs, warmup_steps), lambda: linear_decay(gs, params, delay=d),\n      lambda: inv_decay * params.learning_rate)\n\n\n@register(\"warmup_constant\")\ndef warmup_constant(gs, params):\n  warmup_steps = params.warmup_steps\n  inv_base = tf.exp(tf.log(0.01) / warmup_steps)\n  inv_decay = inv_base**(warmup_steps - tf.to_float(gs))\n\n  return tf.cond(\n      tf.greater(gs, warmup_steps), lambda: constant(gs, params),\n      lambda: inv_decay * params.learning_rate)\n\n\n@register(\"warmup_exponential_decay\")\ndef warmup_exponential_decay(gs, params):\n  d = params.delay\n  warmup_steps = params.warmup_steps\n  inv_base = tf.exp(tf.log(0.01) / warmup_steps)\n  inv_decay = inv_base**(warmup_steps - tf.to_float(gs))\n\n  return tf.cond(\n      tf.greater(gs,\n                 warmup_steps), lambda: exponential_decay(gs, params, delay=d),\n      lambda: inv_decay * params.learning_rate)\n\n\n@register(\"warmup_cosine\")\ndef warmup_cosine(gs, params):\n  from numpy import 
pi\n\n  warmup_steps = params.warmup_steps\n  inv_base = tf.exp(tf.log(0.01) / warmup_steps)\n  inv_decay = inv_base**(warmup_steps - tf.to_float(gs))\n\n  gs = tf.minimum(gs - warmup_steps, params.learning_rate_cosine_cycle_steps)\n  cosine_decay = 0.5 * (1 + tf.cos(\n      pi * tf.to_float(gs) / params.learning_rate_cosine_cycle_steps))\n  decayed = (1 - params.cosine_alpha) * cosine_decay + params.cosine_alpha\n  lr = params.learning_rate * decayed\n\n  return tf.cond(\n      tf.greater(gs, warmup_steps), lambda: lr,\n      lambda: inv_decay * params.learning_rate)\n\n\n@register(\"cosine\")\ndef cosine_annealing(gs, params):\n  from numpy import pi\n\n  gs = tf.minimum(gs, params.learning_rate_cosine_cycle_steps)\n  cosine_decay = 0.5 * (1 + tf.cos(\n      pi * tf.to_float(gs) / params.learning_rate_cosine_cycle_steps))\n  decayed = (1 - params.cosine_alpha) * cosine_decay + params.cosine_alpha\n  decayed_learning_rate = params.learning_rate * decayed\n\n  return decayed_learning_rate\n"
  },
  {
    "path": "training/tpu.py",
    "content": "import collections\nimport six\n\nimport tensorflow as tf\n\n\ndef remove_summaries():\n  g = tf.get_default_graph()\n  key = tf.GraphKeys.SUMMARIES\n  del g.get_collection_ref(key)[:]\n  assert not g.get_collection(key)\n\n\n# From Tensor2Tensor\ndef create_host_call(model_dir):\n  \"\"\"Construct a host_call writing scalar summaries.\n  Args:\n    model_dir: String containing path to train\n  Returns:\n    (fn, args) Pair to be called by TPUEstimator as the host_call.\n  \"\"\"\n  graph = tf.get_default_graph()\n  summaries = graph.get_collection(tf.GraphKeys.SUMMARIES)\n\n  gs_t = tf.reshape(tf.to_int32(tf.train.get_global_step()), [1])\n  summary_kwargs = collections.OrderedDict()\n  for t in summaries:\n    if t.op.type not in [\"ScalarSummary\", \"HistogramSummary\"]:\n      tf.logging.warn(\"Ignoring unsupported tf.Summary type %s\" % t.op.type)\n      continue\n\n    name = t.op.name\n    tensor = t.op.inputs[1]\n    if t.op.type == \"ScalarSummary\":\n      assert tensor.shape.is_compatible_with([])\n      if tensor.dtype == tf.int64:\n        tensor = tf.to_int32(tensor)\n      summary_kwargs[\"ScalarSummary\" + name] = tf.reshape(tensor, [1])\n    elif t.op.type == \"HistogramSummary\":\n      summary_kwargs[\"HistogramSummary\" + name] = tf.reshape(tensor, [-1])\n  # When no supported summaries are found, don't create host_call. Otherwise,\n  # TPU outfeed queue would enqueue global_step while host_call doesn't dequeue\n  # it, eventually causing hang.\n  if not summary_kwargs:\n    return None\n  summary_kwargs[\"global_step\"] = gs_t\n\n  def host_call_fn(**kwargs):\n    \"\"\"Training host call. Creates summaries for training metrics.\n    Args:\n      **kwargs: Dict of {str: Tensor} , with `Tensor` of shape `[batch]`. 
Must\n        contain key \"global_step\" with value of current global_step Tensor.\n    Returns:\n      List of summary ops to run on the CPU host.\n    \"\"\"\n    gs = tf.to_int64(kwargs.pop(\"global_step\")[0])\n    with tf.contrib.summary.create_file_writer(model_dir).as_default():\n      with tf.contrib.summary.always_record_summaries():\n        # We need to use tf.contrib.summary in order to feed the `step`.\n        for name, value in sorted(six.iteritems(kwargs)):\n          if name.startswith(\"ScalarSummary\"):\n            name = name[len(\"ScalarSummary\"):]\n            tf.contrib.summary.scalar(\n                name, tf.reduce_mean(tf.to_float(value)), step=gs)\n          elif name.startswith(\"HistogramSummary\"):\n            name = name[len(\"HistogramSummary\"):]\n            tf.contrib.summary.histogram(name, value, step=gs)\n          elif name.startswith(\"ImageSummary\"):\n            name = name[len(\"ImageSummary\"):]\n            tf.contrib.summary.image(name, value, step=gs)\n\n        return tf.contrib.summary.all_summary_ops()\n\n  return (host_call_fn, summary_kwargs)\n"
  }
]