[
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nenv/\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\n*.egg-info/\n.installed.cfg\n*.egg\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n.hypothesis/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# pyenv\n.python-version\n\n# celery beat schedule file\ncelerybeat-schedule\n\n# SageMath parsed files\n*.sage.py\n\n# dotenv\n.env\n\n# virtualenv\n.venv\nvenv/\nENV/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n\n*.hdf5\n*.h5\n*.pth\n\nnotebooks/"
  },
  {
    "path": ".travis.yml",
    "content": "language: python\npython:\n  # We don't actually use the Travis Python, but this keeps it organized.\n  - \"2.7\"\n  - \"3.6\"\ninstall:\n  - sudo apt-get update\n  # We do this conditionally because it saves us some downloading if the\n  # version is the same.\n  - if [[ \"$TRAVIS_PYTHON_VERSION\" == \"2.7\" ]]; then\n      wget https://repo.continuum.io/miniconda/Miniconda2-latest-Linux-x86_64.sh -O miniconda.sh;\n    else\n      wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh;\n    fi\n  - bash miniconda.sh -b -p $HOME/miniconda\n  - export PATH=\"$HOME/miniconda/bin:$PATH\"\n  - hash -r\n  - conda config --set always_yes yes --set changeps1 no\n  - conda update -q conda\n  # Useful for debugging any issues with conda\n  - conda info -a\n\n  - conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION\n  - source activate test-environment\n  - conda install pytorch torchvision cuda80 numpy scipy h5py -c pytorch -q\n  - conda install tensorflow -c conda-forge -q\n  - conda install theano pygpu -q\n  - python setup.py install\n\nscript:\n  - TEST_TRANSFER_DIRECTION=keras2pytorch KERAS_BACKEND=theano python setup.py test\n  - TEST_TRANSFER_DIRECTION=keras2pytorch KERAS_BACKEND=tensorflow python setup.py test\n  - TEST_TRANSFER_DIRECTION=pytorch2keras KERAS_BACKEND=theano python setup.py test\n  - TEST_TRANSFER_DIRECTION=pytorch2keras KERAS_BACKEND=tensorflow python setup.py test"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2017 Guido Zuidhof\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# nn-transfer\n\n[![Build Status](https://travis-ci.org/gzuidhof/nn-transfer.svg?branch=master)](https://travis-ci.org/gzuidhof/nn-transfer)\n\n**NOTE: This repository does not seem to yield the correct output anymore with the latest versions of Keras and PyTorch. Take care to verify the results or use an alternative method for conversion.**\n\nThis repository contains utilities for **converting PyTorch models to Keras and the other way around**. More specifically, it allows you to copy the weights from a PyTorch model to an identical model in Keras and vice-versa.\n\nFrom Keras you can then run it on the **TensorFlow**, **Theano** and **CNTK** backend. You can also convert it to a pure TensorFlow model (see [[1]](https://github.com/amir-abdi/keras_to_tensorflow) and [[2]](https://blog.keras.io/keras-as-a-simplified-interface-to-tensorflow-tutorial.html)), which allows you to choose more robust deployment options in the cloud, or even mobile devices. From Keras you can also do inference in browsers with [keras-js](https://github.com/transcranial/keras-js).\n\n## Installation\nClone this repository, and simply run\n\n```\npip install .\n```\n\nYou need to have PyTorch and torchvision installed beforehand, see the [PyTorch website](https://www.pytorch.org) for how to easily install that.\n\n## Tests\n\nTo run the unit and integration tests:\n\n```\npython setup.py test\n# OR, if you have nose2 installed,\nnose2\n```\n\nThere is also Travis CI which will automatically build every commit, see the button at the top of the readme. You can test the direction of weight transfer individually using the `TEST_TRANSFER_DIRECTION` environment variable, see `.travis.yml`.\n\n## How to use\n\nSee [**example.ipynb**](example.ipynb) for a small tutorial on how to use this library.\n\n## Code guidelines\n\n* This repository is fully PEP8 compliant, I recommend `flake8`.\n* It works for both Python 2 and 3.\n"
  },
  {
    "path": "example.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {\n    \"collapsed\": false\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Using Theano backend.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"from __future__ import print_function\\n\",\n    \"from collections import OrderedDict\\n\",\n    \"\\n\",\n    \"import matplotlib.pyplot as plt\\n\",\n    \"import numpy as np\\n\",\n    \"\\n\",\n    \"from nn_transfer import transfer, util\\n\",\n    \"\\n\",\n    \"%matplotlib inline\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Step 1\\n\",\n    \"Simply define your PyTorch model like usual, and create an instance of it.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {\n    \"collapsed\": true\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"import torch\\n\",\n    \"from torch.autograd import Variable\\n\",\n    \"import torch.nn as nn\\n\",\n    \"import torch.nn.functional as F\\n\",\n    \"\\n\",\n    \"class LeNet(nn.Module):\\n\",\n    \"    def __init__(self):\\n\",\n    \"        super(LeNet, self).__init__()\\n\",\n    \"        self.conv1 = nn.Conv2d(1, 6, 5)\\n\",\n    \"        self.conv2 = nn.Conv2d(6, 16, 5)\\n\",\n    \"        self.fc1   = nn.Linear(16*5*5, 120)\\n\",\n    \"        self.fc2   = nn.Linear(120, 84)\\n\",\n    \"        self.fc3   = nn.Linear(84, 10)\\n\",\n    \"\\n\",\n    \"    def forward(self, x):\\n\",\n    \"        out = F.relu(self.conv1(x))\\n\",\n    \"        out = F.max_pool2d(out, 2)\\n\",\n    \"        out = F.relu(self.conv2(out))\\n\",\n    \"        out = F.max_pool2d(out, 2)\\n\",\n    \"        out = out.view(out.size(0), -1)\\n\",\n    \"        out = F.relu(self.fc1(out))\\n\",\n    \"        out = F.relu(self.fc2(out))\\n\",\n    \"        out = self.fc3(out)\\n\",\n    \"        return out\\n\",\n    \"    \\n\",\n    \"pytorch_network = LeNet()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Step 2\\n\",\n    \"Determine the names of the layers.\\n\",\n    \"\\n\",\n    \"For the above model example it is very straightforward, but if you use param groups it may be a little more involved. To determine the names of the layers the next commands are useful:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {\n    \"collapsed\": false\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"LeNet (\\n\",\n      \"  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))\\n\",\n      \"  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))\\n\",\n      \"  (fc1): Linear (400 -> 120)\\n\",\n      \"  (fc2): Linear (120 -> 84)\\n\",\n      \"  (fc3): Linear (84 -> 10)\\n\",\n      \")\\n\",\n      \"['conv1', 'conv2', 'fc1', 'fc2', 'fc3']\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# The most useful, just print the network\\n\",\n    \"print(pytorch_network)\\n\",\n    \"\\n\",\n    \"# Also useful: will only print those layers with params\\n\",\n    \"state_dict = pytorch_network.state_dict()\\n\",\n    \"print(util.state_dict_layer_names(state_dict))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Step 3\\n\",\n    \"Define an equivalent Keras network. 
Use the built-in `name` keyword argument for each layer with params.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {\n    \"collapsed\": true\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"import keras\\n\",\n    \"from keras import backend as K\\n\",\n    \"from keras.models import Sequential\\n\",\n    \"from keras.layers import Dense, Dropout, Flatten\\n\",\n    \"from keras.layers import Conv2D, MaxPooling2D\\n\",\n    \"K.set_image_data_format('channels_first')\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def lenet_keras():\\n\",\n    \"\\n\",\n    \"    model = Sequential()\\n\",\n    \"    model.add(Conv2D(6, kernel_size=(5, 5),\\n\",\n    \"                     activation='relu',\\n\",\n    \"                     input_shape=(1,32,32),\\n\",\n    \"                     name='conv1'))\\n\",\n    \"    model.add(MaxPooling2D(pool_size=(2, 2)))\\n\",\n    \"    model.add(Conv2D(16, (5, 5), activation='relu', name='conv2'))\\n\",\n    \"    model.add(MaxPooling2D(pool_size=(2, 2)))\\n\",\n    \"    model.add(Flatten())\\n\",\n    \"    model.add(Dense(120, activation='relu', name='fc1'))\\n\",\n    \"    model.add(Dense(84, activation='relu', name='fc2'))\\n\",\n    \"    model.add(Dense(10, activation=None, name='fc3'))\\n\",\n    \"\\n\",\n    \"    model.compile(loss=keras.losses.categorical_crossentropy,\\n\",\n    \"                  optimizer=keras.optimizers.Adadelta())\\n\",\n    \"    \\n\",\n    \"    return model\\n\",\n    \"    \\n\",\n    \"keras_network = lenet_keras()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"collapsed\": false\n   },\n   \"source\": [\n    \"## Step 4\\n\",\n    \"Now simply convert!\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"metadata\": {\n    \"collapsed\": false\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Layer names in target ['conv1', 'conv2', 'fc1', 'fc2', 'fc3']\\n\",\n      \"Layer names in Keras HDF5 ['conv1', 'conv2', 'fc1', 'fc2', 'fc3', 'flatten_1', 'max_pooling2d_1', 'max_pooling2d_2']\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"transfer.keras_to_pytorch(keras_network, pytorch_network)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"collapsed\": false\n   },\n   \"source\": [\n    \"## Done!\\n\",\n    \"\\n\",\n    \"Now let's check whether it was successful. 
If it was, both networks should have the same output.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"metadata\": {\n    \"collapsed\": false\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# Create dummy data\\n\",\n    \"data = torch.rand(6,1,32,32)\\n\",\n    \"data_keras = data.numpy()\\n\",\n    \"data_pytorch = Variable(data, requires_grad=False)\\n\",\n    \"\\n\",\n    \"# Do a forward pass in both frameworks\\n\",\n    \"keras_pred = keras_network.predict(data_keras)\\n\",\n    \"pytorch_pred = pytorch_network(data_pytorch).data.numpy()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"metadata\": {\n    \"collapsed\": false\n   },\n   \"outputs\": [\n    {\n     \"data\": {\n      \"image/png\": \"iVBORw0KGgoAAAANSUhEUgAAAW4AAADrCAYAAABJqHxQAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\\nAAALEgAACxIB0t1+/AAABU5JREFUeJzt28GKVwUYxuFXZ1SqGQNLAjEzjMqNLZKgNkGLLiC6mWgT\\nUVcQrdq0iFZRVBvbtbdF4qISlIisLMQkZxR0/HcLnsWfzxeeZz18vBzm/OZsZt9qtQoAPfZPDwBg\\nGeEGKCPcAGWEG6CMcAOUEW6AMsINUEa4AcoIN0CZzXUcPf3Ve+P/jvnhma+nJyRJjm3emJ6Qd6+8\\nNT0hSXLljyenJ+TjVz+fnpAk+Xfv0ekJOX/r2ekJSZIvz5+dnpDTH92cnpAkOXfxg30P8nO+uAHK\\nCDdAGeEGKCPcAGWEG6CMcAOUEW6AMsINUEa4AcoIN0AZ4QYoI9wAZYQboIxwA5QRboAywg1QRrgB\\nygg3QBnhBigj3ABlhBugjHADlBFugDLCDVBGuAHKCDdAGeEGKCPcAGWEG6CMcAOUEW6AMsINUGZz\\nHUdfP3F5HWcX+eKfs9MTkiSfnfx+ekKu/bc1PSFJstqb/064eOf49IQkydvbF6Yn5NOrr01PSJIc\\nuLExPSG3j29PT1hk/k0CYBHhBigj3ABlhBugjHADlBFugDLCDVBGuAHKCDdAGeEGKCPcAGWEG6CM\\ncAOUEW6AMsINUEa4AcoIN0AZ4QYoI9wAZYQboIxwA5QRboAywg1QRrgBygg3QBnhBigj3ABlhBug\\njHADlBFugDLCDVBGuAHKCDdAmc11HP3h76fXcXaRI4/sTk9Ikly+e2t6Qo5u7UxPSJLc3j00PSFv\\nPPbT9IQkyS93n5iekKs3H5+ekCS5/8zt6Qk5+N2P0xMW8cUNUEa4AcoIN0AZ4QYoI9wAZYQboIxw\\nA5QRboAywg1QRrgBygg3QBnhBigj3ABlhBugjHADlBFugDLCDVBGuAHKCDdAGeEGKCPcAGWEG6CM\\ncAOUEW6AMsINUEa4AcoIN0AZ4QYoI9wAZYQboIxwA5QRboAym+s4unPn4DrOVtpZreURL/LrxWPT\\nE5Ikz535fXpC/to7PD0hSXL93tb0hJw6cn16QpLkwqUT0xOy8fyp6QmL+OIGKCPcAGWEG6CMcAOU\\nEW6AMsINUEa4AcoIN0AZ4QYoI9wAZYQboIxwA5QRboAywg1QRrgBygg3QBnhBigj3ABlhBugjHAD\\nlBFugDLCDVBGuAHKCDdAGeEGKCPcAGWEG6CMcAOUEW6AMsINUEa4AcoIN0AZ4QYos7mOo0e3d9Zx\\ndpFDG/emJyRJDuT+9ISsHpI/zxv755/F9v470xOSJHc31vLqLXJtd2t6QpJk3+7G9ISsfrs6PWGR\\nh+SVBuBBCTdAGeEGKCPcAGWEG6CMcAOUEW6AMsINUEa4AcoIN0AZ4QYoI9wAZYQboIxwA5QRboAy\\nwg1QRrgBygg3QBnhBigj3ABlhBugjHADlBFugDLCDVBGuAHKCDdAGeEGKCPcAGWEG6CMcAOUEW6A\\nMsINUGZzHUf/vHF4HWcXeeelc9MTkiS7q7U84kVeOXtpekKS5MonL0xPyMn3b01PSJLcX81/M715\\n7OfpCUmSb759anpC9l5+cXrCIvO/PQAsItwAZYQboIxwA5QRboAywg1QRrgBygg3QBnhBigj3ABl\\nhBugjHADlBFugDLCDVBGuAHKCDdAGeEGKCPcAGWEG6CMcAOUEW6AMsINUEa4AcoIN0AZ4QYoI9wA\\nZYQboIxwA5QRboAywg1QRrgBygg3QBnhBiizb7VaTW8AYAFf3ABlhBugjHADlBFugDLCDVBGuAHK\\nCDdAGeEGKCPcAGWEG6CMcAOUEW6AMsINUEa4AcoIN0AZ4QYoI9wAZYQboIxwA5QRboAywg1QRrgB\\nygg3QJn/ATUPZi0NSPeeAAAAAElFTkSuQmCC\\n\",\n      \"text/plain\": [\n       \"<matplotlib.figure.Figure at 0x7f588249dcd0>\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"image/png\": 
\"iVBORw0KGgoAAAANSUhEUgAAAW4AAADrCAYAAABJqHxQAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\\nAAALEgAACxIB0t1+/AAABU5JREFUeJzt28GKVwUYxuFXZ1SqGQNLAjEzjMqNLZKgNkGLLiC6mWgT\\nUVcQrdq0iFZRVBvbtbdF4qISlIisLMQkZxR0/HcLnsWfzxeeZz18vBzm/OZsZt9qtQoAPfZPDwBg\\nGeEGKCPcAGWEG6CMcAOUEW6AMsINUEa4AcoIN0CZzXUcPf3Ve+P/jvnhma+nJyRJjm3emJ6Qd6+8\\nNT0hSXLljyenJ+TjVz+fnpAk+Xfv0ekJOX/r2ekJSZIvz5+dnpDTH92cnpAkOXfxg30P8nO+uAHK\\nCDdAGeEGKCPcAGWEG6CMcAOUEW6AMsINUEa4AcoIN0AZ4QYoI9wAZYQboIxwA5QRboAywg1QRrgB\\nygg3QBnhBigj3ABlhBugjHADlBFugDLCDVBGuAHKCDdAGeEGKCPcAGWEG6CMcAOUEW6AMsINUGZz\\nHUdfP3F5HWcX+eKfs9MTkiSfnfx+ekKu/bc1PSFJstqb/064eOf49IQkydvbF6Yn5NOrr01PSJIc\\nuLExPSG3j29PT1hk/k0CYBHhBigj3ABlhBugjHADlBFugDLCDVBGuAHKCDdAGeEGKCPcAGWEG6CM\\ncAOUEW6AMsINUEa4AcoIN0AZ4QYoI9wAZYQboIxwA5QRboAywg1QRrgBygg3QBnhBigj3ABlhBug\\njHADlBFugDLCDVBGuAHKCDdAmc11HP3h76fXcXaRI4/sTk9Ikly+e2t6Qo5u7UxPSJLc3j00PSFv\\nPPbT9IQkyS93n5iekKs3H5+ekCS5/8zt6Qk5+N2P0xMW8cUNUEa4AcoIN0AZ4QYoI9wAZYQboIxw\\nA5QRboAywg1QRrgBygg3QBnhBigj3ABlhBugjHADlBFugDLCDVBGuAHKCDdAGeEGKCPcAGWEG6CM\\ncAOUEW6AMsINUEa4AcoIN0AZ4QYoI9wAZYQboIxwA5QRboAym+s4unPn4DrOVtpZreURL/LrxWPT\\nE5Ikz535fXpC/to7PD0hSXL93tb0hJw6cn16QpLkwqUT0xOy8fyp6QmL+OIGKCPcAGWEG6CMcAOU\\nEW6AMsINUEa4AcoIN0AZ4QYoI9wAZYQboIxwA5QRboAywg1QRrgBygg3QBnhBigj3ABlhBugjHAD\\nlBFugDLCDVBGuAHKCDdAGeEGKCPcAGWEG6CMcAOUEW6AMsINUEa4AcoIN0AZ4QYos7mOo0e3d9Zx\\ndpFDG/emJyRJDuT+9ISsHpI/zxv755/F9v470xOSJHc31vLqLXJtd2t6QpJk3+7G9ISsfrs6PWGR\\nh+SVBuBBCTdAGeEGKCPcAGWEG6CMcAOUEW6AMsINUEa4AcoIN0AZ4QYoI9wAZYQboIxwA5QRboAy\\nwg1QRrgBygg3QBnhBigj3ABlhBugjHADlBFugDLCDVBGuAHKCDdAGeEGKCPcAGWEG6CMcAOUEW6A\\nMsINUGZzHUf/vHF4HWcXeeelc9MTkiS7q7U84kVeOXtpekKS5MonL0xPyMn3b01PSJLcX81/M715\\n7OfpCUmSb759anpC9l5+cXrCIvO/PQAsItwAZYQboIxwA5QRboAywg1QRrgBygg3QBnhBigj3ABl\\nhBugjHADlBFugDLCDVBGuAHKCDdAGeEGKCPcAGWEG6CMcAOUEW6AMsINUEa4AcoIN0AZ4QYoI9wA\\nZYQboIxwA5QRboAywg1QRrgBygg3QBnhBiizb7VaTW8AYAFf3ABlhBugjHADlBFugDLCDVBGuAHK\\nCDdAGeEGKCPcAGWEG6CMcAOUEW6AMsINUEa4AcoIN0AZ4QYoI9wAZYQboIxwA5QRboAywg1QRrgB\\nygg3QJn/ATUPZi0NSPeeAAAAAElFTkSuQmCC\\n\",\n      \"text/plain\": [\n       \"<matplotlib.figure.Figure at 0x7f58824a9a90>\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"assert keras_pred.shape == pytorch_pred.shape\\n\",\n    \"\\n\",\n    \"plt.axis('Off')\\n\",\n    \"plt.imshow(keras_pred)\\n\",\n    \"plt.show()\\n\",\n    \"plt.axis('Off')\\n\",\n    \"plt.imshow(pytorch_pred)\\n\",\n    \"plt.show()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"collapsed\": false\n   },\n   \"source\": [\n    \"They are the same, it works :)\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python [Root]\",\n   \"language\": \"python\",\n   \"name\": \"Python [Root]\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 2\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython2\",\n   \"version\": \"2.7.12\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 0\n}\n"
  },
  {
    "path": "nn_transfer/__init__.py",
    "content": ""
  },
  {
    "path": "nn_transfer/test/__init__.py",
    "content": ""
  },
  {
    "path": "nn_transfer/test/architectures/__init__.py",
    "content": ""
  },
  {
    "path": "nn_transfer/test/architectures/lenet.py",
    "content": "import torch.nn as nn\nimport torch.nn.functional as F\n\nimport keras\nfrom keras import backend as K\nfrom keras.models import Sequential\nfrom keras.layers import Dense, Flatten\nfrom keras.layers import Conv2D, MaxPooling2D\n\nK.set_image_data_format('channels_first')\n\n\nclass LeNetPytorch(nn.Module):\n    def __init__(self):\n        super(LeNetPytorch, self).__init__()\n        self.conv1 = nn.Conv2d(1, 6, 5)\n        self.conv2 = nn.Conv2d(6, 16, 5)\n        self.fc1 = nn.Linear(16 * 5 * 5, 120)\n        self.fc2 = nn.Linear(120, 84)\n        self.fc3 = nn.Linear(84, 10)\n\n    def forward(self, x):\n        out = F.relu(self.conv1(x))\n        out = F.max_pool2d(out, 2)\n        out = F.relu(self.conv2(out))\n        out = F.max_pool2d(out, 2)\n        out = out.view(out.size(0), -1)\n        out = F.relu(self.fc1(out))\n        out = F.relu(self.fc2(out))\n        out = self.fc3(out)\n        return out\n\n\ndef lenet_keras():\n\n    model = Sequential()\n    model.add(Conv2D(6, kernel_size=(5, 5),\n                     activation='relu',\n                     input_shape=(1, 32, 32),\n                     name='conv1'))\n    model.add(MaxPooling2D(pool_size=(2, 2)))\n    model.add(Conv2D(16, (5, 5), activation='relu', name='conv2'))\n    model.add(MaxPooling2D(pool_size=(2, 2)))\n    model.add(Flatten())\n    model.add(Dense(120, activation='relu', name='fc1'))\n    model.add(Dense(84, activation='relu', name='fc2'))\n    model.add(Dense(10, activation=None, name='fc3'))\n\n    model.compile(loss=keras.losses.categorical_crossentropy,\n                  optimizer=keras.optimizers.SGD())\n\n    return model\n"
  },
  {
    "path": "nn_transfer/test/architectures/simplenet.py",
    "content": "import torch.nn as nn\nimport torch.nn.functional as F\n\nimport keras\nfrom keras import backend as K\nfrom keras.models import Sequential\nfrom keras.layers import Dense, Flatten\nfrom keras.layers import Conv2D, MaxPooling2D, BatchNormalization\n\nK.set_image_data_format('channels_first')\n\n\nclass SimpleNetPytorch(nn.Module):\n    def __init__(self):\n        super(SimpleNetPytorch, self).__init__()\n        self.conv1 = nn.Conv2d(1, 6, 5)\n        self.bn = nn.BatchNorm2d(6)\n        self.fc1 = nn.Linear(6 * 14 * 14, 10)\n\n    def forward(self, x):\n        out = F.relu(self.bn(self.conv1(x)))\n        out = F.max_pool2d(out, 2)\n        out = out.view(out.size(0), -1)\n        out = self.fc1(out)\n        return out\n\n\ndef simplenet_keras():\n    model = Sequential()\n    model.add(Conv2D(6, kernel_size=(5, 5),\n                     activation='relu',\n                     input_shape=(1, 32, 32),\n                     name='conv1'))\n    model.add(BatchNormalization(axis=1, name='bn'))\n    model.add(MaxPooling2D(pool_size=(2, 2)))\n    model.add(Flatten())\n    model.add(Dense(10, activation=None, name='fc1'))\n\n    model.compile(loss=keras.losses.categorical_crossentropy,\n                  optimizer=keras.optimizers.SGD())\n\n    return model\n"
  },
  {
    "path": "nn_transfer/test/architectures/unet.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nimport keras\nfrom keras import backend as K\nfrom keras.models import Input, Model\nfrom keras.layers import Conv2D, MaxPooling2D\nfrom keras.layers import Conv2DTranspose, concatenate\n\nK.set_image_data_format('channels_first')\n\n# From https://github.com/jocicmarko/ultrasound-nerve-segmentation\n\n\ndef unet_keras(input_size=224):\n    inputs = Input((1, input_size, input_size))\n    conv1 = Conv2D(32, (3, 3), activation='relu', padding='same',\n                   name='conv_block1_32.conv')(inputs)\n    conv1 = Conv2D(32, (3, 3), activation='relu', padding='same',\n                   name='conv_block1_32.conv2')(conv1)\n    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)\n\n    conv2 = Conv2D(64, (3, 3), activation='relu', padding='same',\n                   name='conv_block32_64.conv')(pool1)\n    conv2 = Conv2D(64, (3, 3), activation='relu', padding='same',\n                   name='conv_block32_64.conv2')(conv2)\n    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)\n\n    conv3 = Conv2D(128, (3, 3), activation='relu', padding='same',\n                   name='conv_block64_128.conv')(pool2)\n    conv3 = Conv2D(128, (3, 3), activation='relu', padding='same',\n                   name='conv_block64_128.conv2')(conv3)\n    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)\n\n    conv4 = Conv2D(256, (3, 3), activation='relu', padding='same',\n                   name='conv_block128_256.conv')(pool3)\n    conv4 = Conv2D(256, (3, 3), activation='relu', padding='same',\n                   name='conv_block128_256.conv2')(conv4)\n    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)\n\n    conv5 = Conv2D(512, (3, 3), activation='relu', padding='same',\n                   name='conv_block256_512.conv')(pool4)\n    conv5 = Conv2D(512, (3, 3), activation='relu', padding='same',\n                   name='conv_block256_512.conv2')(conv5)\n\n    up6 = concatenate([Conv2DTranspose(256, (2, 2), strides=(2, 2),\n                                       padding='valid',\n                                       name='up_block512_256.up')(conv5),\n                       conv4], axis=1)\n    conv6 = Conv2D(256, (3, 3), activation='relu',\n                   padding='same', name='up_block512_256.conv')(up6)\n    conv6 = Conv2D(256, (3, 3), activation='relu', padding='same',\n                   name='up_block512_256.conv2')(conv6)\n\n    up7 = concatenate([Conv2DTranspose(128, (2, 2), strides=(2, 2),\n                                       padding='valid',\n                                       name='up_block256_128.up')(conv6),\n                       conv3], axis=1)\n    conv7 = Conv2D(128, (3, 3), activation='relu',\n                   padding='same', name='up_block256_128.conv')(up7)\n    conv7 = Conv2D(128, (3, 3), activation='relu', padding='same',\n                   name='up_block256_128.conv2')(conv7)\n\n    up8 = concatenate([Conv2DTranspose(64, (2, 2), strides=(2, 2),\n                                       padding='valid',\n                                       name='up_block128_64.up')(conv7),\n                       conv2], axis=1)\n    conv8 = Conv2D(64, (3, 3), activation='relu',\n                   padding='same', name='up_block128_64.conv')(up8)\n    conv8 = Conv2D(64, (3, 3), activation='relu', padding='same',\n                   name='up_block128_64.conv2')(conv8)\n\n    up9 = concatenate([Conv2DTranspose(32, (2, 2), strides=(2, 2),\n                                       padding='valid',\n                          
             name='up_block64_32.up')(conv8),\n                       conv1], axis=1)\n    conv9 = Conv2D(32, (3, 3), activation='relu',\n                   padding='same', name='up_block64_32.conv')(up9)\n    conv9 = Conv2D(32, (3, 3), activation='relu', padding='same',\n                   name='up_block64_32.conv2')(conv9)\n\n    conv10 = Conv2D(2, (1, 1), activation=None, name='last')(conv9)\n\n    model = Model(inputs=[inputs], outputs=[conv10])\n    model.compile(optimizer=keras.optimizers.SGD(),\n                  loss=keras.losses.categorical_crossentropy)\n\n    return model\n\n\nclass UNetConvBlock(nn.Module):\n    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):\n        super(UNetConvBlock, self).__init__()\n        self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=1)\n        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=1)\n        self.activation = activation\n\n    def forward(self, x):\n        out = self.activation(self.conv(x))\n        out = self.activation(self.conv2(out))\n        return out\n\n\nclass UNetUpBlock(nn.Module):\n    def __init__(self, in_size, out_size, kernel_size=3,\n                 activation=F.relu, space_dropout=False):\n        super(UNetUpBlock, self).__init__()\n        self.up = nn.ConvTranspose2d(in_size, out_size, 2, stride=2)\n        self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=1)\n        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=1)\n        self.activation = activation\n\n    def center_crop(self, layer, target_size):\n        batch_size, n_channels, layer_width, layer_height = layer.size()\n        xy1 = (layer_width - target_size) // 2\n        return layer[:, :, xy1:(xy1 + target_size), xy1:(xy1 + target_size)]\n\n    def forward(self, x, bridge):\n        up = self.up(x)\n        crop1 = self.center_crop(bridge, up.size()[2])\n        out = torch.cat([up, crop1], 1)\n        out = self.activation(self.conv(out))\n        out = self.activation(self.conv2(out))\n\n        return out\n\n\nclass UNetPytorch(nn.Module):\n    def __init__(self):\n        super(UNetPytorch, self).__init__()\n\n        self.activation = F.relu\n\n        self.pool1 = nn.MaxPool2d(2)\n        self.pool2 = nn.MaxPool2d(2)\n        self.pool3 = nn.MaxPool2d(2)\n        self.pool4 = nn.MaxPool2d(2)\n\n        self.conv_block1_32 = UNetConvBlock(1, 32)\n        self.conv_block32_64 = UNetConvBlock(32, 64)\n        self.conv_block64_128 = UNetConvBlock(64, 128)\n        self.conv_block128_256 = UNetConvBlock(128, 256)\n\n        self.conv_block256_512 = UNetConvBlock(256, 512)\n        self.up_block512_256 = UNetUpBlock(512, 256)\n\n        self.up_block256_128 = UNetUpBlock(256, 128)\n        self.up_block128_64 = UNetUpBlock(128, 64)\n        self.up_block64_32 = UNetUpBlock(64, 32)\n\n        self.last = nn.Conv2d(32, 2, 1)\n\n    def forward(self, x):\n\n        block1 = self.conv_block1_32(x)\n        pool1 = self.pool1(block1)\n\n        block2 = self.conv_block32_64(pool1)\n        pool2 = self.pool2(block2)\n\n        block3 = self.conv_block64_128(pool2)\n        pool3 = self.pool3(block3)\n\n        block4 = self.conv_block128_256(pool3)\n        pool4 = self.pool4(block4)\n\n        block5 = self.conv_block256_512(pool4)\n\n        up1 = self.up_block512_256(block5, block4)\n        up2 = self.up_block256_128(up1, block3)\n        up3 = self.up_block128_64(up2, block2)\n        up4 = self.up_block64_32(up3, block1)\n\n        return self.last(up4)\n\n\nif 
__name__ == \"__main__\":\n    net = UNetPytorch()\n"
  },
  {
    "path": "nn_transfer/test/architectures/vggnet.py",
    "content": "from torchvision.models import vgg\n\nfrom keras import backend as K\nfrom keras.models import Input, Model\nfrom keras.layers import Dense, Flatten, Dropout\nfrom keras.layers import Conv2D, MaxPooling2D\n\nK.set_image_data_format('channels_first')\n\n\ndef vggnet_pytorch():\n    return vgg.vgg16()\n\n\ndef vggnet_keras():\n\n    # Block 1\n    img_input = Input((3, 224, 224))\n    x = Conv2D(64, (3, 3), activation='relu',\n               padding='same', name='features.0')(img_input)\n    x = Conv2D(64, (3, 3), activation='relu',\n               padding='same', name='features.2')(x)\n    x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x)\n\n    # Block 2\n    x = Conv2D(128, (3, 3), activation='relu',\n               padding='same', name='features.5')(x)\n    x = Conv2D(128, (3, 3), activation='relu',\n               padding='same', name='features.7')(x)\n    x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x)\n\n    # Block 3\n    x = Conv2D(256, (3, 3), activation='relu',\n               padding='same', name='features.10')(x)\n    x = Conv2D(256, (3, 3), activation='relu',\n               padding='same', name='features.12')(x)\n    x = Conv2D(256, (3, 3), activation='relu',\n               padding='same', name='features.14')(x)\n    x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x)\n\n    # Block 4\n    x = Conv2D(512, (3, 3), activation='relu',\n               padding='same', name='features.17')(x)\n    x = Conv2D(512, (3, 3), activation='relu',\n               padding='same', name='features.19')(x)\n    x = Conv2D(512, (3, 3), activation='relu',\n               padding='same', name='features.21')(x)\n    x = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x)\n\n    # Block 5\n    x = Conv2D(512, (3, 3), activation='relu',\n               padding='same', name='features.24')(x)\n    x = Conv2D(512, (3, 3), activation='relu',\n               padding='same', name='features.26')(x)\n    x = Conv2D(512, (3, 3), activation='relu',\n               padding='same', name='features.28')(x)\n    x = MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(x)\n\n    x = Flatten(name='flatten')(x)\n    x = Dense(4096, activation='relu', name='classifier.0')(x)\n    x = Dropout(0.5)(x)\n    x = Dense(4096, activation='relu', name='classifier.3')(x)\n    x = Dropout(0.5)(x)\n    x = Dense(1000, activation=None, name='classifier.6')(x)\n\n    # Create model.\n    model = Model(img_input, x, name='vgg16')\n\n    return model\n"
  },
  {
    "path": "nn_transfer/test/helpers.py",
    "content": "from __future__ import print_function\nimport os\n\nimport numpy as np\nimport torch\nfrom torch.autograd import Variable\n\nfrom .. import transfer\n\nif 'TEST_TRANSFER_DIRECTION' in os.environ:\n    TRANSFER_DIRECTION = os.environ['TEST_TRANSFER_DIRECTION'].lower()\nelse:\n    TRANSFER_DIRECTION = 'keras2pytorch'\n\nprint(TRANSFER_DIRECTION, \"tests\")\n\n\ndef set_seeds():\n    torch.manual_seed(0)\n    np.random.seed(0)\n\n\nclass TransferTestCase(object):\n    def assertEqualPrediction(\n            self, keras_model, pytorch_model, test_data, delta=1e-6):\n\n        # Make sure the pytorch model is in evaluation mode (i.e. no dropout)\n        pytorch_model.eval()\n\n        test_data = test_data.astype(np.float32, copy=False)\n        test_data_tensor = Variable(\n            torch.from_numpy(test_data),\n            requires_grad=False)\n\n        keras_prediction = keras_model.predict(test_data)\n        pytorch_prediction = pytorch_model(test_data_tensor).data.numpy()\n\n        self.assertEqual(keras_prediction.shape, pytorch_prediction.shape)\n        for v1, v2 in zip(keras_prediction.flatten(),\n                          pytorch_prediction.flatten()):\n            self.assertAlmostEqual(v1, v2, delta=delta)\n        return keras_prediction, pytorch_prediction\n\n    def is_keras_to_pytorch(self):\n        return TRANSFER_DIRECTION == 'keras2pytorch'\n\n    def transfer(self, keras_model, pytorch_model, verbose=False):\n\n        if self.is_keras_to_pytorch():\n            transfer.keras_to_pytorch(keras_model,\n                                      pytorch_model,\n                                      verbose=verbose)\n        else:\n            transfer.pytorch_to_keras(pytorch_model,\n                                      keras_model,\n                                      verbose=verbose)\n"
  },
  {
    "path": "nn_transfer/test/test_architectures.py",
    "content": "import unittest\n\nimport numpy as np\n\nfrom .helpers import TransferTestCase, set_seeds\nfrom .architectures.lenet import lenet_keras, LeNetPytorch\nfrom .architectures.simplenet import simplenet_keras, SimpleNetPytorch\nfrom .architectures.vggnet import vggnet_keras, vggnet_pytorch\nfrom .architectures.unet import unet_keras, UNetPytorch\n\n\nclass TestArchitectures(TransferTestCase, unittest.TestCase):\n\n    def setUp(self):\n        self.test_data_small = np.random.rand(4, 1, 32, 32)\n        self.test_data_vgg = np.random.rand(1, 3, 224, 224)\n        self.test_data_unet = np.random.rand(1, 1, 224, 224)\n        set_seeds()\n\n    def test_simplenet(self):\n        keras_model = simplenet_keras()\n        pytorch_model = SimpleNetPytorch()\n\n        self.transfer(keras_model, pytorch_model)\n        self.assertEqualPrediction(\n            keras_model,\n            pytorch_model,\n            self.test_data_small,\n            delta=1e-3)  # These results can vary due to float imprecision\n\n    def test_lenet(self):\n        keras_model = lenet_keras()\n        pytorch_model = LeNetPytorch()\n\n        self.transfer(keras_model, pytorch_model)\n        self.assertEqualPrediction(\n            keras_model,\n            pytorch_model,\n            self.test_data_small)\n\n    def test_unet(self):\n        keras_model = unet_keras()\n        pytorch_model = UNetPytorch()\n        pytorch_model.eval()\n\n        self.transfer(keras_model, pytorch_model)\n        self.assertEqualPrediction(\n            keras_model,\n            pytorch_model,\n            self.test_data_unet)\n\n    def test_vggnet(self):\n        keras_model = vggnet_keras()\n        pytorch_model = vggnet_pytorch()\n        pytorch_model.eval()\n\n        self.transfer(keras_model, pytorch_model)\n        self.assertEqualPrediction(\n            keras_model, pytorch_model, self.test_data_vgg)\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "nn_transfer/test/test_layers.py",
    "content": "import unittest\n\nimport numpy as np\nimport torch.nn as nn\n\nimport keras\nfrom keras.models import Sequential\nfrom keras.layers import BatchNormalization, PReLU, ELU\nfrom keras.layers import Conv2DTranspose, Conv2D, Conv3D\n\nfrom .helpers import TransferTestCase\n\nkeras.backend.set_image_data_format('channels_first')\n\n\nclass BatchNet(nn.Module):\n    def __init__(self):\n        super(BatchNet, self).__init__()\n        self.bn = nn.BatchNorm3d(3)\n\n    def forward(self, x):\n        return self.bn(x)\n\n\nclass ELUNet(nn.Module):\n    def __init__(self):\n        super(ELUNet, self).__init__()\n        self.elu = nn.ELU()\n\n    def forward(self, x):\n        return self.elu(x)\n\n\nclass TransposeNet(nn.Module):\n    def __init__(self):\n        super(TransposeNet, self).__init__()\n        self.trans = nn.ConvTranspose2d(3, 32, 2, 2)\n\n    def forward(self, x):\n        return self.trans(x)\n\n\nclass PReLUNet(nn.Module):\n    def __init__(self):\n        super(PReLUNet, self).__init__()\n        self.prelu = nn.PReLU(3)\n\n    def forward(self, x):\n        return self.prelu(x)\n\n\nclass Conv2DNet(nn.Module):\n    def __init__(self):\n        super(Conv2DNet, self).__init__()\n        self.conv = nn.Conv2d(3, 16, 7)\n\n    def forward(self, x):\n        return self.conv(x)\n\n\nclass Conv3DNet(nn.Module):\n    def __init__(self):\n        super(Conv3DNet, self).__init__()\n        self.conv = nn.Conv3d(3, 8, 5)\n\n    def forward(self, x):\n        return self.conv(x)\n\n\nclass TestLayers(TransferTestCase, unittest.TestCase):\n\n    def setUp(self):\n        self.test_data = np.random.rand(2, 3, 32, 32)\n        self.test_data_3d = np.random.rand(2, 3, 8, 8, 8)\n\n    def test_batch_normalization(self):\n        keras_model = Sequential()\n        keras_model.add(\n            BatchNormalization(input_shape=(3, 32, 32), axis=1, name='bn'))\n        keras_model.compile(loss=keras.losses.categorical_crossentropy,\n                            optimizer=keras.optimizers.SGD())\n\n        pytorch_model = BatchNet()\n\n        self.transfer(keras_model, pytorch_model)\n        self.assertEqualPrediction(\n            keras_model, pytorch_model, self.test_data, 1e-3)\n\n    def test_transposed_conv(self):\n        keras_model = Sequential()\n        keras_model.add(Conv2DTranspose(32, (2, 2), strides=(\n            2, 2), input_shape=(3, 32, 32), name='trans'))\n        keras_model.compile(loss=keras.losses.categorical_crossentropy,\n                            optimizer=keras.optimizers.SGD())\n\n        pytorch_model = TransposeNet()\n\n        self.transfer(keras_model, pytorch_model)\n        self.assertEqualPrediction(keras_model, pytorch_model, self.test_data)\n\n    # Tests special activation function\n    def test_elu(self):\n        keras_model = Sequential()\n        keras_model.add(ELU(input_shape=(3, 32, 32), name='elu'))\n        keras_model.compile(loss=keras.losses.categorical_crossentropy,\n                            optimizer=keras.optimizers.SGD())\n\n        pytorch_model = ELUNet()\n\n        self.transfer(keras_model, pytorch_model)\n        self.assertEqualPrediction(keras_model, pytorch_model, self.test_data)\n\n    # Tests activation function with learned parameters\n    def test_prelu(self):\n        keras_model = Sequential()\n        keras_model.add(PReLU(input_shape=(3, 32, 32), shared_axes=(2, 3),\n                              name='prelu'))\n        keras_model.compile(loss=keras.losses.categorical_crossentropy,\n                  
          optimizer=keras.optimizers.SGD())\n\n        pytorch_model = PReLUNet()\n\n        self.transfer(keras_model, pytorch_model)\n        self.assertEqualPrediction(keras_model, pytorch_model, self.test_data)\n\n    def test_conv2d(self):\n        keras_model = Sequential()\n        keras_model.add(Conv2D(16, (7, 7), input_shape=(3, 32, 32),\n                               name='conv'))\n        keras_model.compile(loss=keras.losses.categorical_crossentropy,\n                            optimizer=keras.optimizers.SGD())\n\n        pytorch_model = Conv2DNet()\n\n        self.transfer(keras_model, pytorch_model)\n        self.assertEqualPrediction(keras_model,\n                                   pytorch_model,\n                                   self.test_data,\n                                   delta=1e-4)\n\n    def test_conv3d(self):\n        keras_model = Sequential()\n        keras_model.add(Conv3D(8, (5, 5, 5), input_shape=(3, 8, 8, 8),\n                               name='conv'))\n        keras_model.compile(loss=keras.losses.categorical_crossentropy,\n                            optimizer=keras.optimizers.SGD())\n\n        pytorch_model = Conv3DNet()\n\n        self.transfer(keras_model, pytorch_model)\n        self.assertEqualPrediction(keras_model,\n                                   pytorch_model,\n                                   self.test_data_3d,\n                                   delta=1e-4)\n\n    def test_keras_model_changed_as_expected(self):\n        keras_model = Sequential()\n        keras_model.add(Conv2D(16, (7, 7), input_shape=(3, 32, 32),\n                               name='conv'))\n        keras_model.compile(loss=keras.losses.categorical_crossentropy,\n                            optimizer=keras.optimizers.SGD())\n\n        pytorch_model = Conv2DNet()\n\n        weights_before = keras_model.layers[0].get_weights()[0]\n        prediction_before = keras_model.predict(self.test_data)\n\n        self.transfer(keras_model, pytorch_model)\n\n        weights_after = keras_model.layers[0].get_weights()[0]\n        prediction_after = keras_model.predict(self.test_data)\n\n        if self.is_keras_to_pytorch():  # Keras model should be unchanged\n            self.assertTrue((weights_before == weights_after).all())\n            self.assertEqual(\n                prediction_before.tobytes(),\n                prediction_after.tobytes(),\n                msg=\"Predictions are not exactly the same\")\n        else:\n            self.assertFalse((weights_before == weights_after).all())\n            self.assertNotEqual(\n                prediction_before.tobytes(),\n                prediction_after.tobytes(),\n                msg=\"Predictions are exactly the same\")\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "nn_transfer/test/test_util.py",
    "content": "import unittest\n\nfrom .. import util\n\n\nclass TestUtil(unittest.TestCase):\n\n    def test_state_dict_names(self):\n        state_dict = {\n            'conv1.weight': 0,\n            'conv1.bias': 0,\n            'fc1.weight': 0,\n            'fc2.weight': 0,\n            'fc2.bias': 0\n        }\n        layer_names = util.state_dict_layer_names(state_dict)\n\n        self.assertListEqual(sorted(layer_names), ['conv1', 'fc1', 'fc2'])\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "nn_transfer/transfer.py",
    "content": "from __future__ import print_function\nfrom collections import OrderedDict\n\nimport numpy as np\nimport h5py\nimport keras\nimport torch\n\nfrom . import util\n\nVAR_AFFIX = ':0' if keras.backend.backend() == 'tensorflow' else ''\n\nKERAS_GAMMA_KEY = 'gamma' + VAR_AFFIX\nKERAS_KERNEL_KEY = 'kernel' + VAR_AFFIX\nKERAS_ALPHA_KEY = 'alpha' + VAR_AFFIX\nKERAS_BIAS_KEY = 'bias' + VAR_AFFIX\nKERAS_BETA_KEY = 'beta' + VAR_AFFIX\nKERAS_MOVING_MEAN_KEY = 'moving_mean' + VAR_AFFIX\nKERAS_MOVING_VARIANCE_KEY = 'moving_variance' + VAR_AFFIX\nKERAS_EPSILON = 1e-3\nPYTORCH_EPSILON = 1e-5\n\n\ndef check_for_missing_layers(keras_names, pytorch_layer_names, verbose):\n\n    if verbose:\n        print(\"Layer names in PyTorch state_dict\", pytorch_layer_names)\n        print(\"Layer names in Keras HDF5\", keras_names)\n\n    if not all(x in keras_names for x in pytorch_layer_names):\n        missing_layers = list(set(pytorch_layer_names) - set(keras_names))\n        raise Exception(\"Missing layer(s) in Keras HDF5 that are present\" +\n                        \" in state_dict: {}\".format(missing_layers))\n\n\ndef keras_to_pytorch(keras_model, pytorch_model,\n                     flip_filters=None, verbose=True):\n\n    # If not specifically set, determine whether to flip filters automatically\n    # for the right backend.\n    if flip_filters is None:\n        flip_filters = not keras.backend.backend() == 'tensorflow'\n\n    keras_model.save('temp.h5')\n    input_state_dict = pytorch_model.state_dict()\n    pytorch_layer_names = util.state_dict_layer_names(input_state_dict)\n\n    with h5py.File('temp.h5', 'r') as f:\n        model_weights = f['model_weights']\n        layer_names = list(map(str, model_weights.keys()))\n        check_for_missing_layers(layer_names, pytorch_layer_names, verbose)\n        state_dict = OrderedDict()\n\n        for layer in pytorch_layer_names:\n\n            params = util.dig_to_params(model_weights[layer])\n\n            weight_key = layer + '.weight'\n            bias_key = layer + '.bias'\n            running_mean_key = layer + '.running_mean'\n            running_var_key = layer + '.running_var'\n\n            # Load weights (or other learned parameters)\n            if weight_key in input_state_dict:\n                if KERAS_GAMMA_KEY in params:\n                    weights = params[KERAS_GAMMA_KEY][:]\n                elif KERAS_KERNEL_KEY in params:\n                    weights = params[KERAS_KERNEL_KEY][:]\n                else:\n                    weights = np.squeeze(params[KERAS_ALPHA_KEY][:])\n\n                weights = convert_weights(weights,\n                                          to_keras=True,\n                                          flip_filters=flip_filters)\n\n                state_dict[weight_key] = torch.from_numpy(weights)\n\n            # Load bias\n            if bias_key in input_state_dict:\n                if running_var_key in input_state_dict:\n                    bias = params[KERAS_BETA_KEY][:]\n                else:\n                    bias = params[KERAS_BIAS_KEY][:]\n                state_dict[bias_key] = torch.from_numpy(\n                    bias.transpose())\n\n            # Load batch normalization running mean\n            if running_mean_key in input_state_dict:\n                running_mean = params[KERAS_MOVING_MEAN_KEY][:]\n                state_dict[running_mean_key] = torch.from_numpy(\n                    running_mean.transpose())\n\n            # Load batch normalization running variance\n            if 
running_var_key in input_state_dict:\n                running_var = params[KERAS_MOVING_VARIANCE_KEY][:]\n                # account for difference in epsilon used\n                running_var += KERAS_EPSILON - PYTORCH_EPSILON\n                state_dict[running_var_key] = torch.from_numpy(\n                    running_var.transpose())\n\n    pytorch_model.load_state_dict(state_dict)\n\n\ndef pytorch_to_keras(pytorch_model, keras_model,\n                     flip_filters=False, flip_channels=None, verbose=True):\n\n    if flip_channels is None:\n        flip_channels = not keras.backend.backend() == 'tensorflow'\n\n    keras_model.save('temp.h5')\n    input_state_dict = pytorch_model.state_dict()\n    pytorch_layer_names = util.state_dict_layer_names(input_state_dict)\n\n    with h5py.File('temp.h5', 'a') as f:\n        model_weights = f['model_weights']\n        target_layer_names = list(map(str, model_weights.keys()))\n        check_for_missing_layers(\n            target_layer_names,\n            pytorch_layer_names,\n            verbose)\n\n        for layer in pytorch_layer_names:\n\n            params = util.dig_to_params(model_weights[layer])\n\n            weight_key = layer + '.weight'\n            bias_key = layer + '.bias'\n            running_mean_key = layer + '.running_mean'\n            running_var_key = layer + '.running_var'\n\n            # Load weights (or other learned parameters)\n            if weight_key in input_state_dict:\n                weights = input_state_dict[weight_key].numpy()\n                weights = convert_weights(weights,\n                                          to_keras=False,\n                                          flip_filters=flip_filters,\n                                          flip_channels=flip_channels)\n\n                if KERAS_GAMMA_KEY in params:\n                    params[KERAS_GAMMA_KEY][:] = weights\n                elif KERAS_KERNEL_KEY in params:\n                    params[KERAS_KERNEL_KEY][:] = weights\n                else:\n                    weights = weights.reshape(params[KERAS_ALPHA_KEY][:].shape)\n                    params[KERAS_ALPHA_KEY][:] = weights\n\n            # Load bias\n            if bias_key in input_state_dict:\n                bias = input_state_dict[bias_key].numpy()\n                if running_var_key in input_state_dict:\n                    params[KERAS_BETA_KEY][:] = bias\n                else:\n                    params[KERAS_BIAS_KEY][:] = bias\n\n            # Load batch normalization running mean\n            if running_mean_key in input_state_dict:\n                running_mean = input_state_dict[running_mean_key].numpy()\n                params[KERAS_MOVING_MEAN_KEY][:] = running_mean\n\n            # Load batch normalization running variance\n            if running_var_key in input_state_dict:\n                running_var = input_state_dict[running_var_key].numpy()\n                # account for difference in epsilon used; do not use += here,\n                # since .numpy() shares memory with the model's own tensor\n                running_var = running_var + PYTORCH_EPSILON - KERAS_EPSILON\n                params[KERAS_MOVING_VARIANCE_KEY][:] = running_var\n\n    keras_model.load_weights('temp.h5')\n\n\ndef convert_weights(weights, to_keras, flip_filters, flip_channels=False):\n
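    # Layout note: Keras stores 2D conv kernels as (rows, cols, in, out),\n    # whereas PyTorch uses (out, in, rows, cols). Despite its name,\n    # to_keras=True is used when weights come from a Keras HDF5 file and are\n    # converted to the PyTorch layout; to_keras=False converts the other way.\n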
\n    if len(weights.shape) == 3:  # 1D conv\n        weights = weights.transpose()\n\n        if flip_channels:\n            weights = weights[::-1]\n\n        if flip_filters:\n            weights = weights[..., ::-1].copy()\n\n    elif len(weights.shape) == 4:  # 2D conv\n        if to_keras:  # (rows, cols, in, out) -> (out, in, rows, cols)\n            weights = weights.transpose(3, 2, 0, 1)\n        else:  # (out, in, rows, cols) -> (rows, cols, in, out)\n            weights = weights.transpose(2, 3, 1, 0)\n\n        if flip_channels:\n            weights = weights[::-1, ::-1]\n        if flip_filters:\n            weights = weights[..., ::-1, ::-1].copy()\n\n    elif len(weights.shape) == 5:  # 3D conv\n        if to_keras:  # (d1, d2, d3, in, out) -> (out, in, d1, d2, d3)\n            weights = weights.transpose(4, 3, 0, 1, 2)\n        else:  # (out, in, d1, d2, d3) -> (d1, d2, d3, in, out)\n            weights = weights.transpose(2, 3, 4, 1, 0)\n\n        if flip_channels:\n            weights = weights[::-1, ::-1, ::-1]\n\n        if flip_filters:\n            weights = weights[..., ::-1, ::-1, ::-1].copy()\n    else:  # dense kernels and other parameters\n        weights = weights.transpose()\n\n    return weights\n"
  },
  {
    "path": "nn_transfer/util.py",
    "content": "from collections import OrderedDict\n\n_WEIGHT_KEYS = ['kernel', 'beta', 'alpha']\n_WEIGHT_KEYS += [key+':0' for key in _WEIGHT_KEYS]\n\n\ndef state_dict_layer_names(state_dict):\n    layer_names = [\".\".join(k.split('.')[:-1]) for k in state_dict.keys()]\n    # Order preserving unique set of names\n    return list(OrderedDict.fromkeys(layer_names))\n\n\ndef _contains_weights(keras_h5_layer):\n    for key in _WEIGHT_KEYS:\n        if key in keras_h5_layer:\n            return True\n    return False\n\n\ndef dig_to_params(keras_h5_layer):\n    # Params are hidden many layers deep in keras HDF5 files for\n    # some reason. e.g. h5['model_weights']['conv1']['dense_1'] \\\n    # ['dense_2']['dense_3']['conv2d_7']['dense_4']['conv1']\n    while not _contains_weights(keras_h5_layer):\n        keras_h5_layer = keras_h5_layer[list(keras_h5_layer.keys())[0]]\n\n    return keras_h5_layer\n"
  },
  {
    "path": "setup.py",
    "content": "from setuptools import setup\n\nsetup(\n    name='nn_transfer',\n    version='0.1.0',\n    description='Transfer weights between Keras and PyTorch.',\n    install_requires=[\n        'numpy',\n        'keras',\n        'h5py',\n    ],\n    test_suite='nose.collector',\n    tests_require=['nose'],\n    packages=['nn_transfer'],\n)\n"
  }
]