[
  {
    "path": ".gitignore",
    "content": "# OS specific\n*.DS_Store\n\n# Python\n/build\n/dist\n__pycache__\n*.ipynb_checkpoints\n*.egg-info\n\n# Vim\n*.vim\n*.swk\n*.swl\n*.swm\n*.swn\n*.swo\n*.swp\n"
  },
  {
    "path": "LICENSE",
    "content": "Modified MIT License\n\nSoftware Copyright (c) 2021 OpenAI\n\nWe don’t claim ownership of the content you create with the DALL-E discrete VAE, so it is yours to\ndo with as you please. We only ask that you use the model responsibly and clearly indicate that it\nwas used.\n\nPermission is hereby granted, free of charge, to any person obtaining a copy of this software and\nassociated documentation files (the \"Software\"), to deal in the Software without restriction,\nincluding without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,\nand/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,\nsubject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included\nin all copies or substantial portions of the Software.\nThe above copyright notice and this permission notice need not be included\nwith content created by the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,\nINCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS\nBE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,\nTORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE\nOR OTHER DEALINGS IN THE SOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# Overview\n\n[[Blog]](https://openai.com/blog/dall-e/) [[Paper]](https://arxiv.org/abs/2102.12092) [[Model Card]](model_card.md) [[Usage]](notebooks/usage.ipynb)\n\nThis is the official PyTorch package for the discrete VAE used for DALL·E. The transformer used to generate the images from the text is not part of this code release.\n\n# Installation\n\nBefore running [the example notebook](notebooks/usage.ipynb), you will need to install the package using\n\n\tpip install DALL-E\n"
  },
  {
    "path": "dall_e/__init__.py",
    "content": "import io, requests\nimport torch\nimport torch.nn as nn\n\nfrom dall_e.encoder import Encoder\nfrom dall_e.decoder import Decoder\nfrom dall_e.utils   import map_pixels, unmap_pixels\n\ndef load_model(path: str, device: torch.device = None) -> nn.Module:\n    if path.startswith('http://') or path.startswith('https://'):\n        resp = requests.get(path)\n        resp.raise_for_status()\n            \n        with io.BytesIO(resp.content) as buf:\n            return torch.load(buf, map_location=device)\n    else:\n        with open(path, 'rb') as f:\n            return torch.load(f, map_location=device)\n"
  },
  {
    "path": "dall_e/decoder.py",
    "content": "import attr\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom collections  import OrderedDict\nfrom functools    import partial\nfrom dall_e.utils import Conv2d\n\n@attr.s(eq=False, repr=False)\nclass DecoderBlock(nn.Module):\n\tn_in:     int = attr.ib(validator=lambda i, a, x: x >= 1)\n\tn_out:    int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 4 ==0)\n\tn_layers: int = attr.ib(validator=lambda i, a, x: x >= 1)\n\n\tdevice:        torch.device = attr.ib(default=None)\n\trequires_grad: bool         = attr.ib(default=False)\n\n\tdef __attrs_post_init__(self) -> None:\n\t\tsuper().__init__()\n\t\tself.n_hid = self.n_out // 4\n\t\tself.post_gain = 1 / (self.n_layers ** 2)\n\n\t\tmake_conv     = partial(Conv2d, device=self.device, requires_grad=self.requires_grad)\n\t\tself.id_path  = make_conv(self.n_in, self.n_out, 1) if self.n_in != self.n_out else nn.Identity()\n\t\tself.res_path = nn.Sequential(OrderedDict([\n\t\t\t\t('relu_1', nn.ReLU()),\n\t\t\t\t('conv_1', make_conv(self.n_in,  self.n_hid, 1)),\n\t\t\t\t('relu_2', nn.ReLU()),\n\t\t\t\t('conv_2', make_conv(self.n_hid, self.n_hid, 3)),\n\t\t\t\t('relu_3', nn.ReLU()),\n\t\t\t\t('conv_3', make_conv(self.n_hid, self.n_hid, 3)),\n\t\t\t\t('relu_4', nn.ReLU()),\n\t\t\t\t('conv_4', make_conv(self.n_hid, self.n_out, 3)),]))\n\n\tdef forward(self, x: torch.Tensor) -> torch.Tensor:\n\t\treturn self.id_path(x) + self.post_gain * self.res_path(x)\n\n@attr.s(eq=False, repr=False)\nclass Decoder(nn.Module):\n\tgroup_count:     int = 4\n\tn_init:          int = attr.ib(default=128,  validator=lambda i, a, x: x >= 8)\n\tn_hid:           int = attr.ib(default=256,  validator=lambda i, a, x: x >= 64)\n\tn_blk_per_group: int = attr.ib(default=2,    validator=lambda i, a, x: x >= 1)\n\toutput_channels: int = attr.ib(default=3,    validator=lambda i, a, x: x >= 1)\n\tvocab_size:      int = attr.ib(default=8192, validator=lambda i, a, x: x >= 512)\n\n\tdevice:   
           torch.device = attr.ib(default=torch.device('cpu'))\n\trequires_grad:       bool         = attr.ib(default=False)\n\tuse_mixed_precision: bool         = attr.ib(default=True)\n\n\tdef __attrs_post_init__(self) -> None:\n\t\tsuper().__init__()\n\n\t\tblk_range  = range(self.n_blk_per_group)\n\t\tn_layers   = self.group_count * self.n_blk_per_group\n\t\tmake_conv  = partial(Conv2d, device=self.device, requires_grad=self.requires_grad)\n\t\tmake_blk   = partial(DecoderBlock, n_layers=n_layers, device=self.device,\n\t\t\t\trequires_grad=self.requires_grad)\n\n\t\tself.blocks = nn.Sequential(OrderedDict([\n\t\t\t('input', make_conv(self.vocab_size, self.n_init, 1, use_float16=False)),\n\t\t\t('group_1', nn.Sequential(OrderedDict([\n\t\t\t\t*[(f'block_{i + 1}', make_blk(self.n_init if i == 0 else 8 * self.n_hid, 8 * self.n_hid)) for i in blk_range],\n\t\t\t\t('upsample', nn.Upsample(scale_factor=2, mode='nearest')),\n\t\t\t]))),\n\t\t\t('group_2', nn.Sequential(OrderedDict([\n\t\t\t\t*[(f'block_{i + 1}', make_blk(8 * self.n_hid if i == 0 else 4 * self.n_hid, 4 * self.n_hid)) for i in blk_range],\n\t\t\t\t('upsample', nn.Upsample(scale_factor=2, mode='nearest')),\n\t\t\t]))),\n\t\t\t('group_3', nn.Sequential(OrderedDict([\n\t\t\t\t*[(f'block_{i + 1}', make_blk(4 * self.n_hid if i == 0 else 2 * self.n_hid, 2 * self.n_hid)) for i in blk_range],\n\t\t\t\t('upsample', nn.Upsample(scale_factor=2, mode='nearest')),\n\t\t\t]))),\n\t\t\t('group_4', nn.Sequential(OrderedDict([\n\t\t\t\t*[(f'block_{i + 1}', make_blk(2 * self.n_hid if i == 0 else 1 * self.n_hid, 1 * self.n_hid)) for i in blk_range],\n\t\t\t]))),\n\t\t\t('output', nn.Sequential(OrderedDict([\n\t\t\t\t('relu', nn.ReLU()),\n\t\t\t\t('conv', make_conv(1 * self.n_hid, 2 * self.output_channels, 1)),\n\t\t\t]))),\n\t\t]))\n\n\tdef forward(self, x: torch.Tensor) -> torch.Tensor:\n\t\tif len(x.shape) != 4:\n\t\t\traise ValueError(f'input shape {x.shape} is not 4d')\n\t\tif x.shape[1] != 
self.vocab_size:\n\t\t\traise ValueError(f'input has {x.shape[1]} channels but model built for {self.vocab_size}')\n\t\tif x.dtype != torch.float32:\n\t\t\traise ValueError('input must have dtype torch.float32')\n\n\t\treturn self.blocks(x)\n"
  },
  {
    "path": "dall_e/encoder.py",
    "content": "import attr\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom collections  import OrderedDict\nfrom functools    import partial\nfrom dall_e.utils import Conv2d\n\n@attr.s(eq=False, repr=False)\nclass EncoderBlock(nn.Module):\n\tn_in:     int = attr.ib(validator=lambda i, a, x: x >= 1)\n\tn_out:    int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 4 ==0)\n\tn_layers: int = attr.ib(validator=lambda i, a, x: x >= 1)\n\n\tdevice:        torch.device = attr.ib(default=None)\n\trequires_grad: bool         = attr.ib(default=False)\n\n\tdef __attrs_post_init__(self) -> None:\n\t\tsuper().__init__()\n\t\tself.n_hid = self.n_out // 4\n\t\tself.post_gain = 1 / (self.n_layers ** 2)\n\n\t\tmake_conv     = partial(Conv2d, device=self.device, requires_grad=self.requires_grad)\n\t\tself.id_path  = make_conv(self.n_in, self.n_out, 1) if self.n_in != self.n_out else nn.Identity()\n\t\tself.res_path = nn.Sequential(OrderedDict([\n\t\t\t\t('relu_1', nn.ReLU()),\n\t\t\t\t('conv_1', make_conv(self.n_in,  self.n_hid, 3)),\n\t\t\t\t('relu_2', nn.ReLU()),\n\t\t\t\t('conv_2', make_conv(self.n_hid, self.n_hid, 3)),\n\t\t\t\t('relu_3', nn.ReLU()),\n\t\t\t\t('conv_3', make_conv(self.n_hid, self.n_hid, 3)),\n\t\t\t\t('relu_4', nn.ReLU()),\n\t\t\t\t('conv_4', make_conv(self.n_hid, self.n_out, 1)),]))\n\n\tdef forward(self, x: torch.Tensor) -> torch.Tensor:\n\t\treturn self.id_path(x) + self.post_gain * self.res_path(x)\n\n@attr.s(eq=False, repr=False)\nclass Encoder(nn.Module):\n\tgroup_count:     int = 4\n\tn_hid:           int = attr.ib(default=256,  validator=lambda i, a, x: x >= 64)\n\tn_blk_per_group: int = attr.ib(default=2,    validator=lambda i, a, x: x >= 1)\n\tinput_channels:  int = attr.ib(default=3,    validator=lambda i, a, x: x >= 1)\n\tvocab_size:      int = attr.ib(default=8192, validator=lambda i, a, x: x >= 512)\n\n\tdevice:              torch.device = attr.ib(default=torch.device('cpu'))\n\trequires_grad:  
     bool         = attr.ib(default=False)\n\tuse_mixed_precision: bool         = attr.ib(default=True)\n\n\tdef __attrs_post_init__(self) -> None:\n\t\tsuper().__init__()\n\n\t\tblk_range  = range(self.n_blk_per_group)\n\t\tn_layers   = self.group_count * self.n_blk_per_group\n\t\tmake_conv  = partial(Conv2d, device=self.device, requires_grad=self.requires_grad)\n\t\tmake_blk   = partial(EncoderBlock, n_layers=n_layers, device=self.device,\n\t\t\t\trequires_grad=self.requires_grad)\n\n\t\tself.blocks = nn.Sequential(OrderedDict([\n\t\t\t('input', make_conv(self.input_channels, 1 * self.n_hid, 7)),\n\t\t\t('group_1', nn.Sequential(OrderedDict([\n\t\t\t\t*[(f'block_{i + 1}', make_blk(1 * self.n_hid, 1 * self.n_hid)) for i in blk_range],\n\t\t\t\t('pool', nn.MaxPool2d(kernel_size=2)),\n\t\t\t]))),\n\t\t\t('group_2', nn.Sequential(OrderedDict([\n\t\t\t\t*[(f'block_{i + 1}', make_blk(1 * self.n_hid if i == 0 else 2 * self.n_hid, 2 * self.n_hid)) for i in blk_range],\n\t\t\t\t('pool', nn.MaxPool2d(kernel_size=2)),\n\t\t\t]))),\n\t\t\t('group_3', nn.Sequential(OrderedDict([\n\t\t\t\t*[(f'block_{i + 1}', make_blk(2 * self.n_hid if i == 0 else 4 * self.n_hid, 4 * self.n_hid)) for i in blk_range],\n\t\t\t\t('pool', nn.MaxPool2d(kernel_size=2)),\n\t\t\t]))),\n\t\t\t('group_4', nn.Sequential(OrderedDict([\n\t\t\t\t*[(f'block_{i + 1}', make_blk(4 * self.n_hid if i == 0 else 8 * self.n_hid, 8 * self.n_hid)) for i in blk_range],\n\t\t\t]))),\n\t\t\t('output', nn.Sequential(OrderedDict([\n\t\t\t\t('relu', nn.ReLU()),\n\t\t\t\t('conv', make_conv(8 * self.n_hid, self.vocab_size, 1, use_float16=False)),\n\t\t\t]))),\n\t\t]))\n\n\tdef forward(self, x: torch.Tensor) -> torch.Tensor:\n\t\tif len(x.shape) != 4:\n\t\t\traise ValueError(f'input shape {x.shape} is not 4d')\n\t\tif x.shape[1] != self.input_channels:\n\t\t\traise ValueError(f'input has {x.shape[1]} channels but model built for {self.input_channels}')\n\t\tif x.dtype != torch.float32:\n\t\t\traise ValueError('input must have 
dtype torch.float32')\n\n\t\treturn self.blocks(x)\n"
  },
  {
    "path": "dall_e/utils.py",
    "content": "import attr\nimport math\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nlogit_laplace_eps: float = 0.1\n\n@attr.s(eq=False)\nclass Conv2d(nn.Module):\n\tn_in:  int = attr.ib(validator=lambda i, a, x: x >= 1)\n\tn_out: int = attr.ib(validator=lambda i, a, x: x >= 1)\n\tkw:    int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 2 == 1)\n\n\tuse_float16:   bool         = attr.ib(default=True)\n\tdevice:        torch.device = attr.ib(default=torch.device('cpu'))\n\trequires_grad: bool         = attr.ib(default=False)\n\n\tdef __attrs_post_init__(self) -> None:\n\t\tsuper().__init__()\n\n\t\tw = torch.empty((self.n_out, self.n_in, self.kw, self.kw), dtype=torch.float32,\n\t\t\tdevice=self.device, requires_grad=self.requires_grad)\n\t\tw.normal_(std=1 / math.sqrt(self.n_in * self.kw ** 2))\n\n\t\tb = torch.zeros((self.n_out,), dtype=torch.float32, device=self.device,\n\t\t\trequires_grad=self.requires_grad)\n\t\tself.w, self.b = nn.Parameter(w), nn.Parameter(b)\n\n\tdef forward(self, x: torch.Tensor) -> torch.Tensor:\n\t\tif self.use_float16 and 'cuda' in self.w.device.type:\n\t\t\tif x.dtype != torch.float16:\n\t\t\t\tx = x.half()\n\n\t\t\tw, b = self.w.half(), self.b.half()\n\t\telse:\n\t\t\tif x.dtype != torch.float32:\n\t\t\t\tx = x.float()\n\n\t\t\tw, b = self.w, self.b\n\n\t\treturn F.conv2d(x, w, b, padding=(self.kw - 1) // 2)\n\ndef map_pixels(x: torch.Tensor) -> torch.Tensor:\n\tif len(x.shape) != 4:\n\t\traise ValueError('expected input to be 4d')\n\tif x.dtype != torch.float:\n\t\traise ValueError('expected input to have type float')\n\n\treturn (1 - 2 * logit_laplace_eps) * x + logit_laplace_eps\n\ndef unmap_pixels(x: torch.Tensor) -> torch.Tensor:\n\tif len(x.shape) != 4:\n\t\traise ValueError('expected input to be 4d')\n\tif x.dtype != torch.float:\n\t\traise ValueError('expected input to have type float')\n\n\treturn torch.clamp((x - logit_laplace_eps) / (1 - 2 * logit_laplace_eps), 0, 1)\n"
  },
  {
    "path": "model_card.md",
    "content": "# Model Card: DALL·E dVAE\n\nFollowing [Model Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993) and [Lessons from\nArchives (Jo & Gebru)](https://arxiv.org/pdf/1912.10389.pdf), we're providing some information about the discrete\nVAE (dVAE) that was used to train DALL·E.\n\n## Model Details\n\nThe dVAE was developed by researchers at OpenAI to reduce the memory footprint of the transformer trained on the\ntext-to-image generation task. The details involved in training the dVAE are described in [the paper][dalle_paper]. This\nmodel card describes the first version of the model, released in February 2021. The model consists of a convolutional\nencoder and decoder whose architectures are described [here](dall_e/encoder.py) and [here](dall_e/decoder.py), respectively.\nFor questions or comments about the models or the code release, please file a GitHub issue.\n\n## Model Use\n\n### Intended Use\n\nThe model is intended for others to use for training their own generative models.\n\n### Out-of-Scope Use Cases\n\nThis model is inappropriate for high-fidelity image processing applications. We also do not recommend its use as a\ngeneral-purpose image compressor.\n\n## Training Data\n\nThe model was trained on publicly available text-image pairs collected from the internet. This data consists partly of\n[Conceptual Captions][cc] and a filtered subset of [YFCC100M][yfcc100m]. We used a subset of the filters described in\n[Sharma et al.][cc_paper] to construct this dataset; further details are described in [our paper][dalle_paper]. We will\nnot be releasing the dataset.\n\n## Performance and Limitations\n\nThe heavy compression from the encoding process results in a noticeable loss of detail in the reconstructed images. 
This\nrenders it inappropriate for applications that require fine-grained details of the image to be preserved.\n\n[dalle_paper]: https://arxiv.org/abs/2102.12092\n[cc]: https://ai.google.com/research/ConceptualCaptions\n[cc_paper]: https://www.aclweb.org/anthology/P18-1238/\n[yfcc100m]: http://projects.dfki.uni-kl.de/yfcc100m/\n"
  },
  {
    "path": "notebooks/usage.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import io\\n\",\n    \"import os, sys\\n\",\n    \"import requests\\n\",\n    \"import PIL\\n\",\n    \"\\n\",\n    \"import torch\\n\",\n    \"import torchvision.transforms as T\\n\",\n    \"import torchvision.transforms.functional as TF\\n\",\n    \"\\n\",\n    \"from dall_e          import map_pixels, unmap_pixels, load_model\\n\",\n    \"from IPython.display import display, display_markdown\\n\",\n    \"\\n\",\n    \"target_image_size = 256\\n\",\n    \"\\n\",\n    \"def download_image(url):\\n\",\n    \"    resp = requests.get(url)\\n\",\n    \"    resp.raise_for_status()\\n\",\n    \"    return PIL.Image.open(io.BytesIO(resp.content))\\n\",\n    \"\\n\",\n    \"def preprocess(img):\\n\",\n    \"    s = min(img.size)\\n\",\n    \"    \\n\",\n    \"    if s < target_image_size:\\n\",\n    \"        raise ValueError(f'min dim for image {s} < {target_image_size}')\\n\",\n    \"        \\n\",\n    \"    r = target_image_size / s\\n\",\n    \"    s = (round(r * img.size[1]), round(r * img.size[0]))\\n\",\n    \"    img = TF.resize(img, s, interpolation=PIL.Image.LANCZOS)\\n\",\n    \"    img = TF.center_crop(img, output_size=2 * [target_image_size])\\n\",\n    \"    img = torch.unsqueeze(T.ToTensor()(img), 0)\\n\",\n    \"    return map_pixels(img)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# This can be changed to a GPU, e.g. 
'cuda:0'.\\n\",\n    \"dev = torch.device('cpu')\\n\",\n    \"\\n\",\n    \"# For faster load times, download these files locally and use the local paths instead.\\n\",\n    \"enc = load_model(\\\"https://cdn.openai.com/dall-e/encoder.pkl\\\", dev)\\n\",\n    \"dec = load_model(\\\"https://cdn.openai.com/dall-e/decoder.pkl\\\", dev)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"x = preprocess(download_image('https://assets.bwbx.io/images/users/iqjWHBFdfxIU/iKIWgaiJUtss/v2/1000x-1.jpg'))\\n\",\n    \"display_markdown('Original image:')\\n\",\n    \"display(T.ToPILImage(mode='RGB')(x[0]))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import torch.nn.functional as F\\n\",\n    \"\\n\",\n    \"z_logits = enc(x)\\n\",\n    \"z = torch.argmax(z_logits, axis=1)\\n\",\n    \"z = F.one_hot(z, num_classes=enc.vocab_size).permute(0, 3, 1, 2).float()\\n\",\n    \"\\n\",\n    \"x_stats = dec(z).float()\\n\",\n    \"x_rec = unmap_pixels(torch.sigmoid(x_stats[:, :3]))\\n\",\n    \"x_rec = T.ToPILImage(mode='RGB')(x_rec[0])\\n\",\n    \"\\n\",\n    \"display_markdown('Reconstructed image:')\\n\",\n    \"display(x_rec)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.9.1\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "requirements.txt",
    "content": "Pillow\nblobfile\nmypy\nnumpy\npytest\nrequests\ntorch\ntorchvision\n"
  },
  {
    "path": "setup.py",
    "content": "from setuptools import setup\n\ndef parse_requirements(filename):\n\tlines = (line.strip() for line in open(filename))\n\treturn [line for line in lines if line and not line.startswith(\"#\")]\n\nsetup(name='DALL-E',\n        version='0.1',\n        description='PyTorch package for the discrete VAE used for DALL·E.',\n        url='http://github.com/openai/DALL-E',\n        author='Aditya Ramesh',\n        author_email='aramesh@openai.com',\n        license='BSD',\n        packages=['dall_e'],\n        install_requires=parse_requirements('requirements.txt'),\n        zip_safe=True)\n"
  }
]