Full code of gzuidhof/nn-transfer

Repository: gzuidhof/nn-transfer
Branch: master
Commit: c71038af2308
Files: 19
Total size: 47.0 KB

Directory structure:
gitextract_5890qzio/

├── .gitignore
├── .travis.yml
├── LICENSE
├── README.md
├── example.ipynb
├── nn_transfer/
│   ├── __init__.py
│   ├── test/
│   │   ├── __init__.py
│   │   ├── architectures/
│   │   │   ├── __init__.py
│   │   │   ├── lenet.py
│   │   │   ├── simplenet.py
│   │   │   ├── unet.py
│   │   │   └── vggnet.py
│   │   ├── helpers.py
│   │   ├── test_architectures.py
│   │   ├── test_layers.py
│   │   └── test_util.py
│   ├── transfer.py
│   └── util.py
└── setup.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

*.hdf5
*.h5
*.pth

notebooks/

================================================
FILE: .travis.yml
================================================
language: python
python:
  # We don't actually use the Travis Python, but this keeps it organized.
  - "2.7"
  - "3.6"
install:
  - sudo apt-get update
  # We do this conditionally because it saves us some downloading if the
  # version is the same.
  - if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then
      wget https://repo.continuum.io/miniconda/Miniconda2-latest-Linux-x86_64.sh -O miniconda.sh;
    else
      wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh;
    fi
  - bash miniconda.sh -b -p $HOME/miniconda
  - export PATH="$HOME/miniconda/bin:$PATH"
  - hash -r
  - conda config --set always_yes yes --set changeps1 no
  - conda update -q conda
  # Useful for debugging any issues with conda
  - conda info -a

  - conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION
  - source activate test-environment
  - conda install pytorch torchvision cuda80 numpy scipy h5py -c pytorch -q
  - conda install tensorflow -c conda-forge -q
  - conda install theano pygpu -q
  - python setup.py install

script:
  - TEST_TRANSFER_DIRECTION=keras2pytorch KERAS_BACKEND=theano python setup.py test
  - TEST_TRANSFER_DIRECTION=keras2pytorch KERAS_BACKEND=tensorflow python setup.py test
  - TEST_TRANSFER_DIRECTION=pytorch2keras KERAS_BACKEND=theano python setup.py test
  - TEST_TRANSFER_DIRECTION=pytorch2keras KERAS_BACKEND=tensorflow python setup.py test

================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2017 Guido Zuidhof

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
# nn-transfer

[![Build Status](https://travis-ci.org/gzuidhof/nn-transfer.svg?branch=master)](https://travis-ci.org/gzuidhof/nn-transfer)

**NOTE: This repository does not seem to yield the correct output anymore with the latest versions of Keras and PyTorch. Take care to verify the results or use an alternative method for conversion.**

This repository contains utilities for **converting PyTorch models to Keras and the other way around**. More specifically, it allows you to copy the weights from a PyTorch model to an identical model in Keras and vice-versa.

From Keras you can then run it on the **TensorFlow**, **Theano** and **CNTK** backends. You can also convert it to a pure TensorFlow model (see [[1]](https://github.com/amir-abdi/keras_to_tensorflow) and [[2]](https://blog.keras.io/keras-as-a-simplified-interface-to-tensorflow-tutorial.html)), which opens up more robust deployment options in the cloud or even on mobile devices. From Keras you can also do inference in browsers with [keras-js](https://github.com/transcranial/keras-js).
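
The central detail any such transfer has to handle is that the two frameworks lay out weights differently: a PyTorch `Linear` kernel has shape `(out_features, in_features)`, while a Keras `Dense` kernel has shape `(input_dim, units)`, so copying a dense weight involves a transpose (convolution kernels need a similar axis permutation). A rough numpy illustration of the dense case, not this library's actual code:

```python
import numpy as np

# A PyTorch-style linear kernel: shape (out_features, in_features)
pytorch_weight = np.arange(6, dtype=np.float32).reshape(2, 3)

# Keras stores the same dense kernel as (input_dim, units),
# so transferring this weight amounts to a transpose.
keras_weight = pytorch_weight.T

print(keras_weight.shape)  # (3, 2)
```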

## Installation
Clone this repository and run:

```
pip install .
```

You need to have PyTorch and torchvision installed beforehand; see the [PyTorch website](https://www.pytorch.org) for installation instructions.

## Tests

To run the unit and integration tests:

```
python setup.py test
# OR, if you have nose2 installed,
nose2
```

Travis CI also builds every commit automatically; see the badge at the top of this readme. You can test each direction of weight transfer individually using the `TEST_TRANSFER_DIRECTION` environment variable, see `.travis.yml`.

## How to use

See [**example.ipynb**](example.ipynb) for a small tutorial on how to use this library.
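
Matching between the two models is done by layer name: each Keras layer's `name` must equal the attribute name of the corresponding PyTorch module. As a rough sketch of how layer names can be derived from a PyTorch state dict (this mirrors what `util.state_dict_layer_names` appears to do; the actual implementation may differ), strip the parameter suffix from each key and deduplicate:

```python
from collections import OrderedDict

def layer_names_sketch(state_dict_keys):
    # Hypothetical helper: drop the trailing parameter name
    # ('weight', 'bias', ...) from each state-dict key and keep
    # the unique layer prefixes in their original order.
    prefixes = [".".join(key.split(".")[:-1]) for key in state_dict_keys]
    return list(OrderedDict.fromkeys(prefixes))

print(layer_names_sketch(["conv1.weight", "conv1.bias", "fc1.weight", "fc1.bias"]))
# ['conv1', 'fc1']
```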

## Code guidelines

* This repository is fully PEP8 compliant; `flake8` is recommended for checking.
* It works for both Python 2 and 3.


================================================
FILE: example.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using Theano backend.\n"
     ]
    }
   ],
   "source": [
    "from __future__ import print_function\n",
    "from collections import OrderedDict\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "from nn_transfer import transfer, util\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1\n",
    "Simply define your PyTorch model like usual, and create an instance of it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.autograd import Variable\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class LeNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(LeNet, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 6, 5)\n",
    "        self.conv2 = nn.Conv2d(6, 16, 5)\n",
    "        self.fc1   = nn.Linear(16*5*5, 120)\n",
    "        self.fc2   = nn.Linear(120, 84)\n",
    "        self.fc3   = nn.Linear(84, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = F.relu(self.conv1(x))\n",
    "        out = F.max_pool2d(out, 2)\n",
    "        out = F.relu(self.conv2(out))\n",
    "        out = F.max_pool2d(out, 2)\n",
    "        out = out.view(out.size(0), -1)\n",
    "        out = F.relu(self.fc1(out))\n",
    "        out = F.relu(self.fc2(out))\n",
    "        out = self.fc3(out)\n",
    "        return out\n",
    "    \n",
    "pytorch_network = LeNet()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2\n",
    "Determine the names of the layers.\n",
    "\n",
    "For the above model example it is very straightforward, but if you use param groups it may be a little more involved. To determine the names of the layers the next commands are useful:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LeNet (\n",
      "  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))\n",
      "  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))\n",
      "  (fc1): Linear (400 -> 120)\n",
      "  (fc2): Linear (120 -> 84)\n",
      "  (fc3): Linear (84 -> 10)\n",
      ")\n",
      "['conv1', 'conv2', 'fc1', 'fc2', 'fc3']\n"
     ]
    }
   ],
   "source": [
    "# The most useful, just print the network\n",
    "print(pytorch_network)\n",
    "\n",
    "# Also useful: will only print those layers with params\n",
    "state_dict = pytorch_network.state_dict()\n",
    "print(util.state_dict_layer_names(state_dict))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3\n",
    "Define an equivalent Keras network. Use the built-in `name` keyword argument for each layer with params."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import keras\n",
    "from keras import backend as K\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Dense, Dropout, Flatten\n",
    "from keras.layers import Conv2D, MaxPooling2D\n",
    "K.set_image_data_format('channels_first')\n",
    "\n",
    "\n",
    "def lenet_keras():\n",
    "\n",
    "    model = Sequential()\n",
    "    model.add(Conv2D(6, kernel_size=(5, 5),\n",
    "                     activation='relu',\n",
    "                     input_shape=(1,32,32),\n",
    "                     name='conv1'))\n",
    "    model.add(MaxPooling2D(pool_size=(2, 2)))\n",
    "    model.add(Conv2D(16, (5, 5), activation='relu', name='conv2'))\n",
    "    model.add(MaxPooling2D(pool_size=(2, 2)))\n",
    "    model.add(Flatten())\n",
    "    model.add(Dense(120, activation='relu', name='fc1'))\n",
    "    model.add(Dense(84, activation='relu', name='fc2'))\n",
    "    model.add(Dense(10, activation=None, name='fc3'))\n",
    "\n",
    "    model.compile(loss=keras.losses.categorical_crossentropy,\n",
    "                  optimizer=keras.optimizers.Adadelta())\n",
    "    \n",
    "    return model\n",
    "    \n",
    "keras_network = lenet_keras()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "## Step 4\n",
    "Now simply convert!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Layer names in target ['conv1', 'conv2', 'fc1', 'fc2', 'fc3']\n",
      "Layer names in Keras HDF5 ['conv1', 'conv2', 'fc1', 'fc2', 'fc3', 'flatten_1', 'max_pooling2d_1', 'max_pooling2d_2']\n"
     ]
    }
   ],
   "source": [
    "transfer.keras_to_pytorch(keras_network, pytorch_network)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "## Done!\n",
    "\n",
    "Now let's check whether it was succesful. If it was, both networks should have the same output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Create dummy data\n",
    "data = torch.rand(6,1,32,32)\n",
    "data_keras = data.numpy()\n",
    "data_pytorch = Variable(data, requires_grad=False)\n",
    "\n",
    "# Do a forward pass in both frameworks\n",
    "keras_pred = keras_network.predict(data_keras)\n",
    "pytorch_pred = pytorch_network(data_pytorch).data.numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAW4AAADrCAYAAABJqHxQAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAABU5JREFUeJzt28GKVwUYxuFXZ1SqGQNLAjEzjMqNLZKgNkGLLiC6mWgT\nUVcQrdq0iFZRVBvbtbdF4qISlIisLMQkZxR0/HcLnsWfzxeeZz18vBzm/OZsZt9qtQoAPfZPDwBg\nGeEGKCPcAGWEG6CMcAOUEW6AMsINUEa4AcoIN0CZzXUcPf3Ve+P/jvnhma+nJyRJjm3emJ6Qd6+8\nNT0hSXLljyenJ+TjVz+fnpAk+Xfv0ekJOX/r2ekJSZIvz5+dnpDTH92cnpAkOXfxg30P8nO+uAHK\nCDdAGeEGKCPcAGWEG6CMcAOUEW6AMsINUEa4AcoIN0AZ4QYoI9wAZYQboIxwA5QRboAywg1QRrgB\nygg3QBnhBigj3ABlhBugjHADlBFugDLCDVBGuAHKCDdAGeEGKCPcAGWEG6CMcAOUEW6AMsINUGZz\nHUdfP3F5HWcX+eKfs9MTkiSfnfx+ekKu/bc1PSFJstqb/064eOf49IQkydvbF6Yn5NOrr01PSJIc\nuLExPSG3j29PT1hk/k0CYBHhBigj3ABlhBugjHADlBFugDLCDVBGuAHKCDdAGeEGKCPcAGWEG6CM\ncAOUEW6AMsINUEa4AcoIN0AZ4QYoI9wAZYQboIxwA5QRboAywg1QRrgBygg3QBnhBigj3ABlhBug\njHADlBFugDLCDVBGuAHKCDdAmc11HP3h76fXcXaRI4/sTk9Ikly+e2t6Qo5u7UxPSJLc3j00PSFv\nPPbT9IQkyS93n5iekKs3H5+ekCS5/8zt6Qk5+N2P0xMW8cUNUEa4AcoIN0AZ4QYoI9wAZYQboIxw\nA5QRboAywg1QRrgBygg3QBnhBigj3ABlhBugjHADlBFugDLCDVBGuAHKCDdAGeEGKCPcAGWEG6CM\ncAOUEW6AMsINUEa4AcoIN0AZ4QYoI9wAZYQboIxwA5QRboAym+s4unPn4DrOVtpZreURL/LrxWPT\nE5Ikz535fXpC/to7PD0hSXL93tb0hJw6cn16QpLkwqUT0xOy8fyp6QmL+OIGKCPcAGWEG6CMcAOU\nEW6AMsINUEa4AcoIN0AZ4QYoI9wAZYQboIxwA5QRboAywg1QRrgBygg3QBnhBigj3ABlhBugjHAD\nlBFugDLCDVBGuAHKCDdAGeEGKCPcAGWEG6CMcAOUEW6AMsINUEa4AcoIN0AZ4QYos7mOo0e3d9Zx\ndpFDG/emJyRJDuT+9ISsHpI/zxv755/F9v470xOSJHc31vLqLXJtd2t6QpJk3+7G9ISsfrs6PWGR\nh+SVBuBBCTdAGeEGKCPcAGWEG6CMcAOUEW6AMsINUEa4AcoIN0AZ4QYoI9wAZYQboIxwA5QRboAy\nwg1QRrgBygg3QBnhBigj3ABlhBugjHADlBFugDLCDVBGuAHKCDdAGeEGKCPcAGWEG6CMcAOUEW6A\nMsINUGZzHUf/vHF4HWcXeeelc9MTkiS7q7U84kVeOXtpekKS5MonL0xPyMn3b01PSJLcX81/M715\n7OfpCUmSb759anpC9l5+cXrCIvO/PQAsItwAZYQboIxwA5QRboAywg1QRrgBygg3QBnhBigj3ABl\nhBugjHADlBFugDLCDVBGuAHKCDdAGeEGKCPcAGWEG6CMcAOUEW6AMsINUEa4AcoIN0AZ4QYoI9wA\nZYQboIxwA5QRboAywg1QRrgBygg3QBnhBiizb7VaTW8AYAFf3ABlhBugjHADlBFugDLCDVBGuAHK\nCDdAGeEGKCPcAGWEG6CMcAOUEW6AMsINUEa4AcoIN0AZ4QYoI9wAZYQboIxwA5QRboAywg1QRrgB\nygg3QJn/ATUPZi0NSPeeAAAAAElFTk
SuQmCC\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f588249dcd0>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAW4AAADrCAYAAABJqHxQAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAABU5JREFUeJzt28GKVwUYxuFXZ1SqGQNLAjEzjMqNLZKgNkGLLiC6mWgT\nUVcQrdq0iFZRVBvbtbdF4qISlIisLMQkZxR0/HcLnsWfzxeeZz18vBzm/OZsZt9qtQoAPfZPDwBg\nGeEGKCPcAGWEG6CMcAOUEW6AMsINUEa4AcoIN0CZzXUcPf3Ve+P/jvnhma+nJyRJjm3emJ6Qd6+8\nNT0hSXLljyenJ+TjVz+fnpAk+Xfv0ekJOX/r2ekJSZIvz5+dnpDTH92cnpAkOXfxg30P8nO+uAHK\nCDdAGeEGKCPcAGWEG6CMcAOUEW6AMsINUEa4AcoIN0AZ4QYoI9wAZYQboIxwA5QRboAywg1QRrgB\nygg3QBnhBigj3ABlhBugjHADlBFugDLCDVBGuAHKCDdAGeEGKCPcAGWEG6CMcAOUEW6AMsINUGZz\nHUdfP3F5HWcX+eKfs9MTkiSfnfx+ekKu/bc1PSFJstqb/064eOf49IQkydvbF6Yn5NOrr01PSJIc\nuLExPSG3j29PT1hk/k0CYBHhBigj3ABlhBugjHADlBFugDLCDVBGuAHKCDdAGeEGKCPcAGWEG6CM\ncAOUEW6AMsINUEa4AcoIN0AZ4QYoI9wAZYQboIxwA5QRboAywg1QRrgBygg3QBnhBigj3ABlhBug\njHADlBFugDLCDVBGuAHKCDdAmc11HP3h76fXcXaRI4/sTk9Ikly+e2t6Qo5u7UxPSJLc3j00PSFv\nPPbT9IQkyS93n5iekKs3H5+ekCS5/8zt6Qk5+N2P0xMW8cUNUEa4AcoIN0AZ4QYoI9wAZYQboIxw\nA5QRboAywg1QRrgBygg3QBnhBigj3ABlhBugjHADlBFugDLCDVBGuAHKCDdAGeEGKCPcAGWEG6CM\ncAOUEW6AMsINUEa4AcoIN0AZ4QYoI9wAZYQboIxwA5QRboAym+s4unPn4DrOVtpZreURL/LrxWPT\nE5Ikz535fXpC/to7PD0hSXL93tb0hJw6cn16QpLkwqUT0xOy8fyp6QmL+OIGKCPcAGWEG6CMcAOU\nEW6AMsINUEa4AcoIN0AZ4QYoI9wAZYQboIxwA5QRboAywg1QRrgBygg3QBnhBigj3ABlhBugjHAD\nlBFugDLCDVBGuAHKCDdAGeEGKCPcAGWEG6CMcAOUEW6AMsINUEa4AcoIN0AZ4QYos7mOo0e3d9Zx\ndpFDG/emJyRJDuT+9ISsHpI/zxv755/F9v470xOSJHc31vLqLXJtd2t6QpJk3+7G9ISsfrs6PWGR\nh+SVBuBBCTdAGeEGKCPcAGWEG6CMcAOUEW6AMsINUEa4AcoIN0AZ4QYoI9wAZYQboIxwA5QRboAy\nwg1QRrgBygg3QBnhBigj3ABlhBugjHADlBFugDLCDVBGuAHKCDdAGeEGKCPcAGWEG6CMcAOUEW6A\nMsINUGZzHUf/vHF4HWcXeeelc9MTkiS7q7U84kVeOXtpekKS5MonL0xPyMn3b01PSJLcX81/M715\n7OfpCUmSb759anpC9l5+cXrCIvO/PQAsItwAZYQboIxwA5QRboAywg1QRrgBygg3QBnhBigj3ABl\nhBugjHADlBFugDLCDVBGuAHKCDdAGeEGKCPcAGWEG6CMcAOUEW6AMsINUEa4AcoIN0AZ4QYoI9wA\nZYQboIxwA5QRboAywg1QRrgBygg3QBnhBiizb7VaTW8AYAFf3ABlhBugjHADlBFugDLCDVBGuAHK\nCDdAGeEGKCPcAGWEG6CMcAOUEW6AMsINUEa4AcoIN0AZ4QYoI9wAZYQboIxwA5QRboAywg1QRrgB\nygg3QJn/ATUPZi0NSPeeAAAAAElFTk
SuQmCC\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f58824a9a90>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "assert keras_pred.shape == pytorch_pred.shape\n",
    "\n",
    "plt.axis('Off')\n",
    "plt.imshow(keras_pred)\n",
    "plt.show()\n",
    "plt.axis('Off')\n",
    "plt.imshow(pytorch_pred)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "They are the same, it works :)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [Root]",
   "language": "python",
   "name": "Python [Root]"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}


================================================
FILE: nn_transfer/__init__.py
================================================


================================================
FILE: nn_transfer/test/__init__.py
================================================


================================================
FILE: nn_transfer/test/architectures/__init__.py
================================================


================================================
FILE: nn_transfer/test/architectures/lenet.py
================================================
import torch.nn as nn
import torch.nn.functional as F

import keras
from keras import backend as K
from keras.models import Sequential
from keras.layers import Dense, Flatten
from keras.layers import Conv2D, MaxPooling2D

K.set_image_data_format('channels_first')


class LeNetPytorch(nn.Module):
    def __init__(self):
        super(LeNetPytorch, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out


def lenet_keras():

    model = Sequential()
    model.add(Conv2D(6, kernel_size=(5, 5),
                     activation='relu',
                     input_shape=(1, 32, 32),
                     name='conv1'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Conv2D(16, (5, 5), activation='relu', name='conv2'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Flatten())
    model.add(Dense(120, activation='relu', name='fc1'))
    model.add(Dense(84, activation='relu', name='fc2'))
    model.add(Dense(10, activation=None, name='fc3'))

    model.compile(loss=keras.losses.categorical_crossentropy,
                  optimizer=keras.optimizers.SGD())

    return model


================================================
FILE: nn_transfer/test/architectures/simplenet.py
================================================
import torch.nn as nn
import torch.nn.functional as F

import keras
from keras import backend as K
from keras.models import Sequential
from keras.layers import Dense, Flatten
from keras.layers import Conv2D, MaxPooling2D, BatchNormalization

K.set_image_data_format('channels_first')


class SimpleNetPytorch(nn.Module):
    def __init__(self):
        super(SimpleNetPytorch, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.bn = nn.BatchNorm2d(6)
        self.fc1 = nn.Linear(6 * 14 * 14, 10)

    def forward(self, x):
        out = F.relu(self.bn(self.conv1(x)))
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = self.fc1(out)
        return out


def simplenet_keras():
    model = Sequential()
    model.add(Conv2D(6, kernel_size=(5, 5),
                     activation='relu',
                     input_shape=(1, 32, 32),
                     name='conv1'))
    model.add(BatchNormalization(axis=1, name='bn'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Flatten())
    model.add(Dense(10, activation=None, name='fc1'))

    model.compile(loss=keras.losses.categorical_crossentropy,
                  optimizer=keras.optimizers.SGD())

    return model


================================================
FILE: nn_transfer/test/architectures/unet.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F

import keras
from keras import backend as K
from keras.models import Input, Model
from keras.layers import Conv2D, MaxPooling2D
from keras.layers import Conv2DTranspose, concatenate

K.set_image_data_format('channels_first')

# From https://github.com/jocicmarko/ultrasound-nerve-segmentation


def unet_keras(input_size=224):
    inputs = Input((1, input_size, input_size))
    conv1 = Conv2D(32, (3, 3), activation='relu', padding='same',
                   name='conv_block1_32.conv')(inputs)
    conv1 = Conv2D(32, (3, 3), activation='relu', padding='same',
                   name='conv_block1_32.conv2')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = Conv2D(64, (3, 3), activation='relu', padding='same',
                   name='conv_block32_64.conv')(pool1)
    conv2 = Conv2D(64, (3, 3), activation='relu', padding='same',
                   name='conv_block32_64.conv2')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = Conv2D(128, (3, 3), activation='relu', padding='same',
                   name='conv_block64_128.conv')(pool2)
    conv3 = Conv2D(128, (3, 3), activation='relu', padding='same',
                   name='conv_block64_128.conv2')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = Conv2D(256, (3, 3), activation='relu', padding='same',
                   name='conv_block128_256.conv')(pool3)
    conv4 = Conv2D(256, (3, 3), activation='relu', padding='same',
                   name='conv_block128_256.conv2')(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    conv5 = Conv2D(512, (3, 3), activation='relu', padding='same',
                   name='conv_block256_512.conv')(pool4)
    conv5 = Conv2D(512, (3, 3), activation='relu', padding='same',
                   name='conv_block256_512.conv2')(conv5)

    up6 = concatenate([Conv2DTranspose(256, (2, 2), strides=(2, 2),
                                       padding='valid',
                                       name='up_block512_256.up')(conv5),
                       conv4], axis=1)
    conv6 = Conv2D(256, (3, 3), activation='relu',
                   padding='same', name='up_block512_256.conv')(up6)
    conv6 = Conv2D(256, (3, 3), activation='relu', padding='same',
                   name='up_block512_256.conv2')(conv6)

    up7 = concatenate([Conv2DTranspose(128, (2, 2), strides=(2, 2),
                                       padding='valid',
                                       name='up_block256_128.up')(conv6),
                       conv3], axis=1)
    conv7 = Conv2D(128, (3, 3), activation='relu',
                   padding='same', name='up_block256_128.conv')(up7)
    conv7 = Conv2D(128, (3, 3), activation='relu', padding='same',
                   name='up_block256_128.conv2')(conv7)

    up8 = concatenate([Conv2DTranspose(64, (2, 2), strides=(2, 2),
                                       padding='valid',
                                       name='up_block128_64.up')(conv7),
                       conv2], axis=1)
    conv8 = Conv2D(64, (3, 3), activation='relu',
                   padding='same', name='up_block128_64.conv')(up8)
    conv8 = Conv2D(64, (3, 3), activation='relu', padding='same',
                   name='up_block128_64.conv2')(conv8)

    up9 = concatenate([Conv2DTranspose(32, (2, 2), strides=(2, 2),
                                       padding='valid',
                                       name='up_block64_32.up')(conv8),
                       conv1], axis=1)
    conv9 = Conv2D(32, (3, 3), activation='relu',
                   padding='same', name='up_block64_32.conv')(up9)
    conv9 = Conv2D(32, (3, 3), activation='relu', padding='same',
                   name='up_block64_32.conv2')(conv9)

    conv10 = Conv2D(2, (1, 1), activation=None, name='last')(conv9)

    model = Model(inputs=[inputs], outputs=[conv10])
    model.compile(optimizer=keras.optimizers.SGD(),
                  loss=keras.losses.categorical_crossentropy)

    return model


class UNetConvBlock(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
        super(UNetConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=1)
        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=1)
        self.activation = activation

    def forward(self, x):
        out = self.activation(self.conv(x))
        out = self.activation(self.conv2(out))
        return out


class UNetUpBlock(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3,
                 activation=F.relu, space_dropout=False):
        super(UNetUpBlock, self).__init__()
        self.up = nn.ConvTranspose2d(in_size, out_size, 2, stride=2)
        self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=1)
        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=1)
        self.activation = activation

    def center_crop(self, layer, target_size):
        batch_size, n_channels, layer_width, layer_height = layer.size()
        xy1 = (layer_width - target_size) // 2
        return layer[:, :, xy1:(xy1 + target_size), xy1:(xy1 + target_size)]

    def forward(self, x, bridge):
        up = self.up(x)
        crop1 = self.center_crop(bridge, up.size()[2])
        out = torch.cat([up, crop1], 1)
        out = self.activation(self.conv(out))
        out = self.activation(self.conv2(out))

        return out


class UNetPytorch(nn.Module):
    def __init__(self):
        super(UNetPytorch, self).__init__()

        self.activation = F.relu

        self.pool1 = nn.MaxPool2d(2)
        self.pool2 = nn.MaxPool2d(2)
        self.pool3 = nn.MaxPool2d(2)
        self.pool4 = nn.MaxPool2d(2)

        self.conv_block1_32 = UNetConvBlock(1, 32)
        self.conv_block32_64 = UNetConvBlock(32, 64)
        self.conv_block64_128 = UNetConvBlock(64, 128)
        self.conv_block128_256 = UNetConvBlock(128, 256)

        self.conv_block256_512 = UNetConvBlock(256, 512)
        self.up_block512_256 = UNetUpBlock(512, 256)

        self.up_block256_128 = UNetUpBlock(256, 128)
        self.up_block128_64 = UNetUpBlock(128, 64)
        self.up_block64_32 = UNetUpBlock(64, 32)

        self.last = nn.Conv2d(32, 2, 1)

    def forward(self, x):

        block1 = self.conv_block1_32(x)
        pool1 = self.pool1(block1)

        block2 = self.conv_block32_64(pool1)
        pool2 = self.pool2(block2)

        block3 = self.conv_block64_128(pool2)
        pool3 = self.pool3(block3)

        block4 = self.conv_block128_256(pool3)
        pool4 = self.pool4(block4)

        block5 = self.conv_block256_512(pool4)

        up1 = self.up_block512_256(block5, block4)
        up2 = self.up_block256_128(up1, block3)
        up3 = self.up_block128_64(up2, block2)
        up4 = self.up_block64_32(up3, block1)

        return self.last(up4)


if __name__ == "__main__":
    net = UNetPytorch()


================================================
FILE: nn_transfer/test/architectures/vggnet.py
================================================
from torchvision.models import vgg

from keras import backend as K
from keras.models import Input, Model
from keras.layers import Dense, Flatten, Dropout
from keras.layers import Conv2D, MaxPooling2D

K.set_image_data_format('channels_first')


def vggnet_pytorch():
    return vgg.vgg16()


def vggnet_keras():

    # Block 1
    img_input = Input((3, 224, 224))
    x = Conv2D(64, (3, 3), activation='relu',
               padding='same', name='features.0')(img_input)
    x = Conv2D(64, (3, 3), activation='relu',
               padding='same', name='features.2')(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x)

    # Block 2
    x = Conv2D(128, (3, 3), activation='relu',
               padding='same', name='features.5')(x)
    x = Conv2D(128, (3, 3), activation='relu',
               padding='same', name='features.7')(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x)

    # Block 3
    x = Conv2D(256, (3, 3), activation='relu',
               padding='same', name='features.10')(x)
    x = Conv2D(256, (3, 3), activation='relu',
               padding='same', name='features.12')(x)
    x = Conv2D(256, (3, 3), activation='relu',
               padding='same', name='features.14')(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x)

    # Block 4
    x = Conv2D(512, (3, 3), activation='relu',
               padding='same', name='features.17')(x)
    x = Conv2D(512, (3, 3), activation='relu',
               padding='same', name='features.19')(x)
    x = Conv2D(512, (3, 3), activation='relu',
               padding='same', name='features.21')(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x)

    # Block 5
    x = Conv2D(512, (3, 3), activation='relu',
               padding='same', name='features.24')(x)
    x = Conv2D(512, (3, 3), activation='relu',
               padding='same', name='features.26')(x)
    x = Conv2D(512, (3, 3), activation='relu',
               padding='same', name='features.28')(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(x)

    x = Flatten(name='flatten')(x)
    x = Dense(4096, activation='relu', name='classifier.0')(x)
    x = Dropout(0.5)(x)
    x = Dense(4096, activation='relu', name='classifier.3')(x)
    x = Dropout(0.5)(x)
    x = Dense(1000, activation=None, name='classifier.6')(x)

    # Create model.
    model = Model(img_input, x, name='vgg16')

    return model


================================================
FILE: nn_transfer/test/helpers.py
================================================
from __future__ import print_function
import os

import numpy as np
import torch
from torch.autograd import Variable

from .. import transfer

if 'TEST_TRANSFER_DIRECTION' in os.environ:
    TRANSFER_DIRECTION = os.environ['TEST_TRANSFER_DIRECTION'].lower()
else:
    TRANSFER_DIRECTION = 'keras2pytorch'

print(TRANSFER_DIRECTION, "tests")


def set_seeds():
    torch.manual_seed(0)
    np.random.seed(0)


class TransferTestCase(object):
    def assertEqualPrediction(
            self, keras_model, pytorch_model, test_data, delta=1e-6):

        # Make sure the pytorch model is in evaluation mode (i.e. no dropout)
        pytorch_model.eval()

        test_data = test_data.astype(np.float32, copy=False)
        test_data_tensor = Variable(
            torch.from_numpy(test_data),
            requires_grad=False)

        keras_prediction = keras_model.predict(test_data)
        pytorch_prediction = pytorch_model(test_data_tensor).data.numpy()

        self.assertEqual(keras_prediction.shape, pytorch_prediction.shape)
        for v1, v2 in zip(keras_prediction.flatten(),
                          pytorch_prediction.flatten()):
            self.assertAlmostEqual(v1, v2, delta=delta)
        return keras_prediction, pytorch_prediction

    def is_keras_to_pytorch(self):
        return TRANSFER_DIRECTION == 'keras2pytorch'

    def transfer(self, keras_model, pytorch_model, verbose=False):

        if self.is_keras_to_pytorch():
            transfer.keras_to_pytorch(keras_model,
                                      pytorch_model,
                                      verbose=verbose)
        else:
            transfer.pytorch_to_keras(pytorch_model,
                                      keras_model,
                                      verbose=verbose)
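
The element-by-element `assertAlmostEqual` loop in `assertEqualPrediction` above can be expressed with NumPy in one vectorized call; a minimal sketch (the name `assert_equal_prediction` is hypothetical, not part of this package):

```python
import numpy as np

def assert_equal_prediction(keras_prediction, pytorch_prediction, delta=1e-6):
    # Vectorized equivalent of the loop above: same shape, then every
    # element within an absolute tolerance of `delta`.
    assert keras_prediction.shape == pytorch_prediction.shape
    np.testing.assert_allclose(keras_prediction, pytorch_prediction,
                               rtol=0, atol=delta)

a = np.array([[1.0, 2.0]], dtype=np.float32)
b = a + 1e-8  # well within the default tolerance
assert_equal_prediction(a, b)
```

`np.testing.assert_allclose` with `rtol=0, atol=delta` matches the absolute-difference semantics of `assertAlmostEqual(..., delta=delta)` and raises `AssertionError` on mismatch, just like the unittest helper.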


================================================
FILE: nn_transfer/test/test_architectures.py
================================================
import unittest

import numpy as np

from .helpers import TransferTestCase, set_seeds
from .architectures.lenet import lenet_keras, LeNetPytorch
from .architectures.simplenet import simplenet_keras, SimpleNetPytorch
from .architectures.vggnet import vggnet_keras, vggnet_pytorch
from .architectures.unet import unet_keras, UNetPytorch


class TestArchitectures(TransferTestCase, unittest.TestCase):

    def setUp(self):
        self.test_data_small = np.random.rand(4, 1, 32, 32)
        self.test_data_vgg = np.random.rand(1, 3, 224, 224)
        self.test_data_unet = np.random.rand(1, 1, 224, 224)
        set_seeds()

    def test_simplenet(self):
        keras_model = simplenet_keras()
        pytorch_model = SimpleNetPytorch()

        self.transfer(keras_model, pytorch_model)
        self.assertEqualPrediction(
            keras_model,
            pytorch_model,
            self.test_data_small,
            delta=1e-3)  # These results can vary due to float imprecision

    def test_lenet(self):
        keras_model = lenet_keras()
        pytorch_model = LeNetPytorch()

        self.transfer(keras_model, pytorch_model)
        self.assertEqualPrediction(
            keras_model,
            pytorch_model,
            self.test_data_small)

    def test_unet(self):
        keras_model = unet_keras()
        pytorch_model = UNetPytorch()
        pytorch_model.eval()

        self.transfer(keras_model, pytorch_model)
        self.assertEqualPrediction(
            keras_model,
            pytorch_model,
            self.test_data_unet)

    def test_vggnet(self):
        keras_model = vggnet_keras()
        pytorch_model = vggnet_pytorch()
        pytorch_model.eval()

        self.transfer(keras_model, pytorch_model)
        self.assertEqualPrediction(
            keras_model, pytorch_model, self.test_data_vgg)


if __name__ == '__main__':
    unittest.main()


================================================
FILE: nn_transfer/test/test_layers.py
================================================
import unittest

import numpy as np
import torch.nn as nn

import keras
from keras.models import Sequential
from keras.layers import BatchNormalization, PReLU, ELU
from keras.layers import Conv2DTranspose, Conv2D, Conv3D

from .helpers import TransferTestCase

keras.backend.set_image_data_format('channels_first')


class BatchNet(nn.Module):
    def __init__(self):
        super(BatchNet, self).__init__()
        self.bn = nn.BatchNorm3d(3)

    def forward(self, x):
        return self.bn(x)


class ELUNet(nn.Module):
    def __init__(self):
        super(ELUNet, self).__init__()
        self.elu = nn.ELU()

    def forward(self, x):
        return self.elu(x)


class TransposeNet(nn.Module):
    def __init__(self):
        super(TransposeNet, self).__init__()
        self.trans = nn.ConvTranspose2d(3, 32, 2, 2)

    def forward(self, x):
        return self.trans(x)


class PReLUNet(nn.Module):
    def __init__(self):
        super(PReLUNet, self).__init__()
        self.prelu = nn.PReLU(3)

    def forward(self, x):
        return self.prelu(x)


class Conv2DNet(nn.Module):
    def __init__(self):
        super(Conv2DNet, self).__init__()
        self.conv = nn.Conv2d(3, 16, 7)

    def forward(self, x):
        return self.conv(x)


class Conv3DNet(nn.Module):
    def __init__(self):
        super(Conv3DNet, self).__init__()
        self.conv = nn.Conv3d(3, 8, 5)

    def forward(self, x):
        return self.conv(x)


class TestLayers(TransferTestCase, unittest.TestCase):

    def setUp(self):
        self.test_data = np.random.rand(2, 3, 32, 32)
        self.test_data_3d = np.random.rand(2, 3, 8, 8, 8)

    def test_batch_normalization(self):
        keras_model = Sequential()
        keras_model.add(
            BatchNormalization(input_shape=(3, 32, 32), axis=1, name='bn'))
        keras_model.compile(loss=keras.losses.categorical_crossentropy,
                            optimizer=keras.optimizers.SGD())

        pytorch_model = BatchNet()

        self.transfer(keras_model, pytorch_model)
        self.assertEqualPrediction(
            keras_model, pytorch_model, self.test_data, 1e-3)

    def test_transposed_conv(self):
        keras_model = Sequential()
        keras_model.add(Conv2DTranspose(32, (2, 2), strides=(
            2, 2), input_shape=(3, 32, 32), name='trans'))
        keras_model.compile(loss=keras.losses.categorical_crossentropy,
                            optimizer=keras.optimizers.SGD())

        pytorch_model = TransposeNet()

        self.transfer(keras_model, pytorch_model)
        self.assertEqualPrediction(keras_model, pytorch_model, self.test_data)

    # Tests special activation function
    def test_elu(self):
        keras_model = Sequential()
        keras_model.add(ELU(input_shape=(3, 32, 32), name='elu'))
        keras_model.compile(loss=keras.losses.categorical_crossentropy,
                            optimizer=keras.optimizers.SGD())

        pytorch_model = ELUNet()

        self.transfer(keras_model, pytorch_model)
        self.assertEqualPrediction(keras_model, pytorch_model, self.test_data)

    # Tests activation function with learned parameters
    def test_prelu(self):
        keras_model = Sequential()
        keras_model.add(PReLU(input_shape=(3, 32, 32), shared_axes=(2, 3),
                              name='prelu'))
        keras_model.compile(loss=keras.losses.categorical_crossentropy,
                            optimizer=keras.optimizers.SGD())

        pytorch_model = PReLUNet()

        self.transfer(keras_model, pytorch_model)
        self.assertEqualPrediction(keras_model, pytorch_model, self.test_data)

    def test_conv2d(self):
        keras_model = Sequential()
        keras_model.add(Conv2D(16, (7, 7), input_shape=(3, 32, 32),
                               name='conv'))
        keras_model.compile(loss=keras.losses.categorical_crossentropy,
                            optimizer=keras.optimizers.SGD())

        pytorch_model = Conv2DNet()

        self.transfer(keras_model, pytorch_model)
        self.assertEqualPrediction(keras_model,
                                   pytorch_model,
                                   self.test_data,
                                   delta=1e-4)

    def test_conv3d(self):
        keras_model = Sequential()
        keras_model.add(Conv3D(8, (5, 5, 5), input_shape=(3, 8, 8, 8),
                               name='conv'))
        keras_model.compile(loss=keras.losses.categorical_crossentropy,
                            optimizer=keras.optimizers.SGD())

        pytorch_model = Conv3DNet()

        self.transfer(keras_model, pytorch_model)
        self.assertEqualPrediction(keras_model,
                                   pytorch_model,
                                   self.test_data_3d,
                                   delta=1e-4)

    def test_keras_model_changed_as_expected(self):
        keras_model = Sequential()
        keras_model.add(Conv2D(16, (7, 7), input_shape=(3, 32, 32),
                               name='conv'))
        keras_model.compile(loss=keras.losses.categorical_crossentropy,
                            optimizer=keras.optimizers.SGD())

        pytorch_model = Conv2DNet()

        weights_before = keras_model.layers[0].get_weights()[0]
        prediction_before = keras_model.predict(self.test_data)

        self.transfer(keras_model, pytorch_model)

        weights_after = keras_model.layers[0].get_weights()[0]
        prediction_after = keras_model.predict(self.test_data)

        if self.is_keras_to_pytorch():  # Keras model should be unchanged
            self.assertTrue((weights_before == weights_after).all())
            self.assertEqual(
                prediction_before.tobytes(),
                prediction_after.tobytes(),
                msg="Predictions are not exactly the same")
        else:
            self.assertFalse((weights_before == weights_after).all())
            self.assertNotEqual(
                prediction_before.tobytes(),
                prediction_after.tobytes(),
                msg="Predictions are exactly the same")


if __name__ == '__main__':
    unittest.main()


================================================
FILE: nn_transfer/test/test_util.py
================================================
import unittest

from .. import util


class TestUtil(unittest.TestCase):

    def test_state_dict_names(self):
        state_dict = {
            'conv1.weight': 0,
            'conv1.bias': 0,
            'fc1.weight': 0,
            'fc2.weight': 0,
            'fc2.bias': 0
        }
        layer_names = util.state_dict_layer_names(state_dict)

        self.assertListEqual(sorted(layer_names), ['conv1', 'fc1', 'fc2'])


if __name__ == '__main__':
    unittest.main()


================================================
FILE: nn_transfer/transfer.py
================================================
from __future__ import print_function
from collections import OrderedDict

import numpy as np
import h5py
import keras
import torch

from . import util

VAR_AFFIX = ':0' if keras.backend.backend() == 'tensorflow' else ''

KERAS_GAMMA_KEY = 'gamma' + VAR_AFFIX
KERAS_KERNEL_KEY = 'kernel' + VAR_AFFIX
KERAS_ALPHA_KEY = 'alpha' + VAR_AFFIX
KERAS_BIAS_KEY = 'bias' + VAR_AFFIX
KERAS_BETA_KEY = 'beta' + VAR_AFFIX
KERAS_MOVING_MEAN_KEY = 'moving_mean' + VAR_AFFIX
KERAS_MOVING_VARIANCE_KEY = 'moving_variance' + VAR_AFFIX
KERAS_EPSILON = 1e-3
PYTORCH_EPSILON = 1e-5


def check_for_missing_layers(keras_names, pytorch_layer_names, verbose):

    if verbose:
        print("Layer names in PyTorch state_dict", pytorch_layer_names)
        print("Layer names in Keras HDF5", keras_names)

    if not all(x in keras_names for x in pytorch_layer_names):
        missing_layers = list(set(pytorch_layer_names) - set(keras_names))
        raise Exception("Missing layer(s) in Keras HDF5 that are present" +
                        " in state_dict: {}".format(missing_layers))


def keras_to_pytorch(keras_model, pytorch_model,
                     flip_filters=None, verbose=True):

    # If not specifically set, determine whether to flip filters automatically
    # for the right backend.
    if flip_filters is None:
        flip_filters = not keras.backend.backend() == 'tensorflow'

    keras_model.save('temp.h5')
    input_state_dict = pytorch_model.state_dict()
    pytorch_layer_names = util.state_dict_layer_names(input_state_dict)

    with h5py.File('temp.h5', 'r') as f:
        model_weights = f['model_weights']
        layer_names = list(map(str, model_weights.keys()))
        check_for_missing_layers(layer_names, pytorch_layer_names, verbose)
        state_dict = OrderedDict()

        for layer in pytorch_layer_names:

            params = util.dig_to_params(model_weights[layer])

            weight_key = layer + '.weight'
            bias_key = layer + '.bias'
            running_mean_key = layer + '.running_mean'
            running_var_key = layer + '.running_var'

            # Load weights (or other learned parameters)
            if weight_key in input_state_dict:
                if KERAS_GAMMA_KEY in params:
                    weights = params[KERAS_GAMMA_KEY][:]
                elif KERAS_KERNEL_KEY in params:
                    weights = params[KERAS_KERNEL_KEY][:]
                else:
                    weights = np.squeeze(params[KERAS_ALPHA_KEY][:])

                weights = convert_weights(weights,
                                          to_keras=True,
                                          flip_filters=flip_filters)

                state_dict[weight_key] = torch.from_numpy(weights)

            # Load bias
            if bias_key in input_state_dict:
                if running_var_key in input_state_dict:
                    bias = params[KERAS_BETA_KEY][:]
                else:
                    bias = params[KERAS_BIAS_KEY][:]
                state_dict[bias_key] = torch.from_numpy(
                    bias.transpose())

            # Load batch normalization running mean
            if running_mean_key in input_state_dict:
                running_mean = params[KERAS_MOVING_MEAN_KEY][:]
                state_dict[running_mean_key] = torch.from_numpy(
                    running_mean.transpose())

            # Load batch normalization running variance
            if running_var_key in input_state_dict:
                running_var = params[KERAS_MOVING_VARIANCE_KEY][:]
                # account for difference in epsilon used
                running_var += KERAS_EPSILON - PYTORCH_EPSILON
                state_dict[running_var_key] = torch.from_numpy(
                    running_var.transpose())

    pytorch_model.load_state_dict(state_dict)


def pytorch_to_keras(pytorch_model, keras_model,
                     flip_filters=False, flip_channels=None, verbose=True):

    if flip_channels is None:
        flip_channels = not keras.backend.backend() == 'tensorflow'

    keras_model.save('temp.h5')
    input_state_dict = pytorch_model.state_dict()
    pytorch_layer_names = util.state_dict_layer_names(input_state_dict)

    with h5py.File('temp.h5', 'a') as f:
        model_weights = f['model_weights']
        target_layer_names = list(map(str, model_weights.keys()))
        check_for_missing_layers(
            target_layer_names,
            pytorch_layer_names,
            verbose)

        for layer in pytorch_layer_names:

            params = util.dig_to_params(model_weights[layer])

            weight_key = layer + '.weight'
            bias_key = layer + '.bias'
            running_mean_key = layer + '.running_mean'
            running_var_key = layer + '.running_var'

            # Load weights (or other learned parameters)
            if weight_key in input_state_dict:
                weights = input_state_dict[weight_key].numpy()
                weights = convert_weights(weights,
                                          to_keras=False,
                                          flip_filters=flip_filters,
                                          flip_channels=flip_channels)

                if KERAS_GAMMA_KEY in params:
                    params[KERAS_GAMMA_KEY][:] = weights
                elif KERAS_KERNEL_KEY in params:
                    params[KERAS_KERNEL_KEY][:] = weights
                else:
                    weights = weights.reshape(params[KERAS_ALPHA_KEY][:].shape)
                    params[KERAS_ALPHA_KEY][:] = weights

            # Load bias
            if bias_key in input_state_dict:
                bias = input_state_dict[bias_key].numpy()
                if running_var_key in input_state_dict:
                    params[KERAS_BETA_KEY][:] = bias
                else:
                    params[KERAS_BIAS_KEY][:] = bias

            # Load batch normalization running mean
            if running_mean_key in input_state_dict:
                running_mean = input_state_dict[running_mean_key].numpy()
                params[KERAS_MOVING_MEAN_KEY][:] = running_mean

            # Load batch normalization running variance
            if running_var_key in input_state_dict:
                running_var = input_state_dict[running_var_key].numpy()
                # account for difference in epsilon used
                running_var += PYTORCH_EPSILON - KERAS_EPSILON
                params[KERAS_MOVING_VARIANCE_KEY][:] = running_var

    keras_model.load_weights('temp.h5')


def convert_weights(weights, to_keras, flip_filters, flip_channels=False):
    # Note: to_keras=True is the Keras -> PyTorch direction here (it is
    # the value passed by keras_to_pytorch above).

    if len(weights.shape) == 3:  # 1D conv
        weights = weights.transpose()

        if flip_channels:
            weights = weights[::-1]

        if flip_filters:
            weights = weights[..., ::-1].copy()

    elif len(weights.shape) == 4:  # 2D conv
        if to_keras:  # Keras (kh, kw, in, out) -> PyTorch (out, in, kh, kw)
            weights = weights.transpose(3, 2, 0, 1)
        else:  # PyTorch (out, in, kh, kw) -> Keras (kh, kw, in, out)
            weights = weights.transpose(2, 3, 1, 0)

        if flip_channels:
            weights = weights[::-1, ::-1]
        if flip_filters:
            weights = weights[..., ::-1, ::-1].copy()

    elif len(weights.shape) == 5:  # 3D conv
        if to_keras:  # Keras (kd, kh, kw, in, out) -> PyTorch (out, in, kd, kh, kw)
            weights = weights.transpose(4, 3, 0, 1, 2)
        else:  # PyTorch (out, in, kd, kh, kw) -> Keras (kd, kh, kw, in, out)
            weights = weights.transpose(2, 3, 4, 1, 0)

        if flip_channels:
            weights = weights[::-1, ::-1, ::-1]

        if flip_filters:
            weights = weights[..., ::-1, ::-1, ::-1].copy()
    else:
        weights = weights.transpose()  # dense layers: (in, out) <-> (out, in)

    return weights
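
A quick sanity check for the two 2D-conv transposes above, and for the batch-norm epsilon shift used in both transfer directions, needs only NumPy; the shapes here are illustrative:

```python
import numpy as np

# PyTorch stores 2D-conv kernels as (out_ch, in_ch, kh, kw);
# Keras with the TensorFlow backend stores (kh, kw, in_ch, out_ch).
pytorch_kernel = np.random.rand(16, 3, 7, 7)

# PyTorch -> Keras layout (the transpose(2, 3, 1, 0) branch above)
keras_kernel = pytorch_kernel.transpose(2, 3, 1, 0)
assert keras_kernel.shape == (7, 7, 3, 16)

# Keras -> PyTorch layout (the transpose(3, 2, 0, 1) branch above);
# the two transposes are inverses, so the round trip is exact.
roundtrip = keras_kernel.transpose(3, 2, 0, 1)
assert np.array_equal(roundtrip, pytorch_kernel)

# BatchNorm epsilon correction: Keras normalizes by sqrt(var + 1e-3),
# PyTorch by sqrt(var + 1e-5). Shifting the stored variance by the
# epsilon difference makes both denominators identical.
var_keras = np.array([0.5])
var_pytorch = var_keras + 1e-3 - 1e-5
assert np.allclose(np.sqrt(var_keras + 1e-3), np.sqrt(var_pytorch + 1e-5))
```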


================================================
FILE: nn_transfer/util.py
================================================
from collections import OrderedDict

_WEIGHT_KEYS = ['kernel', 'beta', 'alpha']
_WEIGHT_KEYS += [key+':0' for key in _WEIGHT_KEYS]


def state_dict_layer_names(state_dict):
    layer_names = [".".join(k.split('.')[:-1]) for k in state_dict.keys()]
    # Order preserving unique set of names
    return list(OrderedDict.fromkeys(layer_names))


def _contains_weights(keras_h5_layer):
    for key in _WEIGHT_KEYS:
        if key in keras_h5_layer:
            return True
    return False


def dig_to_params(keras_h5_layer):
    # Params are hidden many layers deep in keras HDF5 files for
    # some reason. e.g. h5['model_weights']['conv1']['dense_1'] \
    # ['dense_2']['dense_3']['conv2d_7']['dense_4']['conv1']
    while not _contains_weights(keras_h5_layer):
        keras_h5_layer = keras_h5_layer[list(keras_h5_layer.keys())[0]]

    return keras_h5_layer
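
state_dict_layer_names strips the final component of each dotted parameter name, so nested module names such as 'features.10.weight' collapse to the layer name 'features.10' — which is why the Keras VGG layers earlier in this dump carry those PyTorch-style names. A self-contained sketch restating the function with sample keys:

```python
from collections import OrderedDict

def state_dict_layer_names(state_dict):
    # Drop the trailing '.weight' / '.bias' / '.running_mean' component,
    # keeping insertion order and uniqueness (as in util.py above).
    layer_names = [".".join(k.split('.')[:-1]) for k in state_dict.keys()]
    return list(OrderedDict.fromkeys(layer_names))

state_dict = OrderedDict([
    ('features.0.weight', 0), ('features.0.bias', 0),
    ('features.10.weight', 0), ('classifier.6.bias', 0),
])
print(state_dict_layer_names(state_dict))
# -> ['features.0', 'features.10', 'classifier.6']
```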


================================================
FILE: setup.py
================================================
from setuptools import setup

setup(
    name='nn_transfer',
    version='0.1.0',
    description='Transfer weights between Keras and PyTorch.',
    install_requires=[
        'numpy',
        'keras',
        'h5py',
    ],
    test_suite='nose.collector',
    tests_require=['nose'],
    packages=['nn_transfer'],
)
SYMBOL INDEX (68 symbols across 10 files)

FILE: nn_transfer/test/architectures/lenet.py
  class LeNetPytorch (line 13) | class LeNetPytorch(nn.Module):
    method __init__ (line 14) | def __init__(self):
    method forward (line 22) | def forward(self, x):
  function lenet_keras (line 34) | def lenet_keras():

FILE: nn_transfer/test/architectures/simplenet.py
  class SimpleNetPytorch (line 13) | class SimpleNetPytorch(nn.Module):
    method __init__ (line 14) | def __init__(self):
    method forward (line 20) | def forward(self, x):
  function simplenet_keras (line 28) | def simplenet_keras():

FILE: nn_transfer/test/architectures/unet.py
  function unet_keras (line 16) | def unet_keras(input_size=224):
  class UNetConvBlock (line 92) | class UNetConvBlock(nn.Module):
    method __init__ (line 93) | def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
    method forward (line 99) | def forward(self, x):
  class UNetUpBlock (line 105) | class UNetUpBlock(nn.Module):
    method __init__ (line 106) | def __init__(self, in_size, out_size, kernel_size=3,
    method center_crop (line 114) | def center_crop(self, layer, target_size):
    method forward (line 119) | def forward(self, x, bridge):
  class UNetPytorch (line 129) | class UNetPytorch(nn.Module):
    method __init__ (line 130) | def __init__(self):
    method forward (line 154) | def forward(self, x):

FILE: nn_transfer/test/architectures/vggnet.py
  function vggnet_pytorch (line 11) | def vggnet_pytorch():
  function vggnet_keras (line 15) | def vggnet_keras():

FILE: nn_transfer/test/helpers.py
  function set_seeds (line 18) | def set_seeds():
  class TransferTestCase (line 23) | class TransferTestCase(object):
    method assertEqualPrediction (line 24) | def assertEqualPrediction(
    method is_keras_to_pytorch (line 44) | def is_keras_to_pytorch(self):
    method transfer (line 47) | def transfer(self, keras_model, pytorch_model, verbose=False):

FILE: nn_transfer/test/test_architectures.py
  class TestArchitectures (line 12) | class TestArchitectures(TransferTestCase, unittest.TestCase):
    method setUp (line 14) | def setUp(self):
    method test_simplenet (line 20) | def test_simplenet(self):
    method test_lenet (line 31) | def test_lenet(self):
    method test_unet (line 41) | def test_unet(self):
    method test_vggnet (line 52) | def test_vggnet(self):

FILE: nn_transfer/test/test_layers.py
  class BatchNet (line 16) | class BatchNet(nn.Module):
    method __init__ (line 17) | def __init__(self):
    method forward (line 21) | def forward(self, x):
  class ELUNet (line 25) | class ELUNet(nn.Module):
    method __init__ (line 26) | def __init__(self):
    method forward (line 30) | def forward(self, x):
  class TransposeNet (line 34) | class TransposeNet(nn.Module):
    method __init__ (line 35) | def __init__(self):
    method forward (line 39) | def forward(self, x):
  class PReLUNet (line 43) | class PReLUNet(nn.Module):
    method __init__ (line 44) | def __init__(self):
    method forward (line 48) | def forward(self, x):
  class Conv2DNet (line 52) | class Conv2DNet(nn.Module):
    method __init__ (line 53) | def __init__(self):
    method forward (line 57) | def forward(self, x):
  class Conv3DNet (line 61) | class Conv3DNet(nn.Module):
    method __init__ (line 62) | def __init__(self):
    method forward (line 66) | def forward(self, x):
  class TestLayers (line 70) | class TestLayers(TransferTestCase, unittest.TestCase):
    method setUp (line 72) | def setUp(self):
    method test_batch_normalization (line 76) | def test_batch_normalization(self):
    method test_transposed_conv (line 89) | def test_transposed_conv(self):
    method test_elu (line 102) | def test_elu(self):
    method test_prelu (line 114) | def test_prelu(self):
    method test_conv2d (line 126) | def test_conv2d(self):
    method test_conv3d (line 142) | def test_conv3d(self):
    method test_keras_model_changed_as_expected (line 157) | def test_keras_model_changed_as_expected(self):

FILE: nn_transfer/test/test_util.py
  class TestUtil (line 6) | class TestUtil(unittest.TestCase):
    method test_state_dict_names (line 8) | def test_state_dict_names(self):

FILE: nn_transfer/transfer.py
  function check_for_missing_layers (line 24) | def check_for_missing_layers(keras_names, pytorch_layer_names, verbose):
  function keras_to_pytorch (line 36) | def keras_to_pytorch(keras_model, pytorch_model,
  function pytorch_to_keras (line 104) | def pytorch_to_keras(pytorch_model, keras_model,
  function convert_weights (line 171) | def convert_weights(weights, to_keras, flip_filters, flip_channels=False):

FILE: nn_transfer/util.py
  function state_dict_layer_names (line 7) | def state_dict_layer_names(state_dict):
  function _contains_weights (line 13) | def _contains_weights(keras_h5_layer):
  function dig_to_params (line 20) | def dig_to_params(keras_h5_layer):

About this extraction

This page contains the full source code of the gzuidhof/nn-transfer GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 19 files (47.0 KB), approximately 14.6k tokens, and a symbol index with 68 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
