[
  {
    "path": "GFPGANReconsitution.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\nimport os\nimport sys\nimport argparse\nimport cv2\nimport numpy as np\nimport torch \nimport timeit\n#import onnxruntime\nfrom torch.nn import functional as F\nfrom torchvision.transforms.functional import normalize\nfrom torch import nn\nimport math\nfrom collections import OrderedDict\nfrom noise_main import noise_dict\n\nclass ResBlock(nn.Module):\n    \"\"\"Residual block with upsampling/downsampling.\n\n    Args:\n        in_channels (int): Channel number of the input.\n        out_channels (int): Channel number of the output.\n    \"\"\"\n\n    def __init__(self, in_channels, out_channels, mode='down'):\n        super(ResBlock, self).__init__()\n\n        self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)\n        self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)\n        self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False)\n        if mode == 'down':\n            self.scale_factor = 0.5\n        elif mode == 'up':\n            self.scale_factor = 2\n\n    def forward(self, x):\n        out = F.leaky_relu_(self.conv1(x), negative_slope=0.2)\n        # upsample/downsample\n        out = F.interpolate(out, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)\n        out = F.leaky_relu_(self.conv2(out), negative_slope=0.2)\n        # skip\n        x = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)\n        skip = self.skip(x)\n        out = out + skip\n        return out\n\n\nclass ConstantInput(nn.Module):\n    \"\"\"Constant input.\n\n    Args:\n        num_channel (int): Channel number of constant input.\n        size (int): Spatial size of constant input.\n    \"\"\"\n\n    def __init__(self, num_channel, size):\n        super(ConstantInput, self).__init__()\n        self.weight = nn.Parameter(torch.randn(1, num_channel, size, size)) # [1, 512, 4, 4]\n    \n    def forward(self, batch):\n        out = 
self.weight.repeat(batch, 1, 1, 1)\n        \n        return out\n\nclass GFPGAN(nn.Module):\n    def __init__(self):\n        super(GFPGAN, self).__init__()\n        unet_narrow = 0.5\n        channel_multiplier=2\n        channels = {\n            '4': int(512 * unet_narrow),\n            '8': int(512 * unet_narrow),\n            '16': int(512 * unet_narrow),\n            '32': int(512 * unet_narrow),\n            '64': int(256 * channel_multiplier * unet_narrow),\n            '128': int(128 * channel_multiplier * unet_narrow),\n            '256': int(64 * channel_multiplier * unet_narrow),\n            '512': int(32 * channel_multiplier * unet_narrow),\n            '1024': int(16 * channel_multiplier * unet_narrow)\n        }\n\n        self.conv_body_first = nn.Conv2d(3, 32, 1)\n        self.conv_body_down = nn.ModuleList()\n        \n        in_channels = channels['512']\n        for i in range(9, 2, -1):\n            out_channels = channels[f'{2**(i - 1)}']\n            self.conv_body_down.append(ResBlock(in_channels, out_channels, mode='down'))\n            in_channels = out_channels\n        num_style_feat = 512\n        self.final_conv = nn.Conv2d(in_channels, channels['4'], 3, 1, 1)\n        linear_out_channel = (int(math.log(512, 2)) * 2 - 2) * num_style_feat\n        self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel)\n        \n        # upsample\n        in_channels = channels['4']\n        self.conv_body_up = nn.ModuleList()\n        for i in range(3, 9 + 1):\n            out_channels = channels[f'{2**i}']\n            self.conv_body_up.append(ResBlock(in_channels, out_channels, mode='up'))\n            in_channels = out_channels\n\n        # for SFT\n        self.condition_scale = nn.ModuleList()\n        self.condition_shift = nn.ModuleList()\n        for i in range(3, 9 + 1):\n            out_channels = channels[f'{2**i}']\n            sft_out_channels = out_channels\n            self.condition_scale.append(\n                
nn.Sequential(\n                    nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),\n                    nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))\n            self.condition_shift.append(\n                nn.Sequential(\n                    nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),\n                    nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))\n        \n        self.stylegan_decoderdotconstant_input = ConstantInput(512, size=4)\n        \n        \n        # self.style_conv1\n        self.stylegan_decoderdotstyle_conv1dotmodulated_convdotmodulation = nn.Linear(512, 512, bias=True)\n        \n        self.stylegan_decoderdotstyle_conv1dotmodulated_convdotweight = nn.Parameter(\n            torch.randn(1, 512, 512, 3, 3) /\n            math.sqrt(512 * 3**2))\n        \n        self.stylegan_decoderdotstyle_conv1dotweight = nn.Parameter(torch.zeros(1))  # for noise injection\n        \n        \n        self.stylegan_decoderdotstyle_conv1dotbias = nn.Parameter(torch.zeros(1, 512, 1, 1))\n        self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)\n        \n        \n        # toRGB\n        self.stylegan_decoderdotto_rgb1dotmodulated_convdotmodulation = nn.Linear(512, 512, bias=True)\n        self.stylegan_decoderdotto_rgb1dotmodulated_convdotweight = nn.Parameter(\n            torch.randn(1, 3, 512, 1, 1) /\n            math.sqrt(512 * 1**2))\n        self.stylegan_decoderdotto_rgb1dotbias = nn.Parameter(torch.zeros(1, 3, 1, 1))\n        \n        # i = 1\n        self.stylegan_decoderdotstyle_convsdot0dotmodulated_convdotmodulation = nn.Linear(512, 512, bias=True)\n        self.stylegan_decoderdotstyle_convsdot0dotmodulated_convdotweight = nn.Parameter(\n            torch.randn(1, 512, 512, 3, 3) /\n            math.sqrt(512 * 3**2))\n        self.stylegan_decoderdotstyle_convsdot0dotweight = nn.Parameter(torch.zeros(1))  # for noise injection\n        
self.stylegan_decoderdotstyle_convsdot0dotbias = nn.Parameter(torch.zeros(1, 512, 1, 1))\n        self.stylegan_decoderdotstyle_convsdot1dotmodulated_convdotmodulation = nn.Linear(512, 512, bias=True)\n        self.stylegan_decoderdotstyle_convsdot1dotmodulated_convdotweight = nn.Parameter(\n            torch.randn(1, 512, 512, 3, 3) /\n            math.sqrt(512 * 3**2))\n        self.stylegan_decoderdotstyle_convsdot1dotweight = nn.Parameter(torch.zeros(1))  \n        self.stylegan_decoderdotstyle_convsdot1dotbias = nn.Parameter(torch.zeros(1, 512, 1, 1))\n        #self.stylegan_decoderdotstyle_convsdot0dotmodulated_convdotmodulation = nn.Linear(512, 512, bias=True)\n        self.stylegan_decoderdotto_rgbsdot0dotmodulated_convdotmodulation = nn.Linear(512, 512, bias=True)\n        self.stylegan_decoderdotto_rgbsdot0dotmodulated_convdotweight = nn.Parameter(\n            torch.randn(1, 3, 512, 1, 1) /\n            math.sqrt(512 * 1**2))\n        self.stylegan_decoderdotto_rgbsdot0dotbias = nn.Parameter(torch.zeros(1, 3, 1, 1))\n\n        #i = 3\n        self.stylegan_decoderdotstyle_convsdot2dotmodulated_convdotmodulation = nn.Linear(512, 512, bias=True)\n        self.stylegan_decoderdotstyle_convsdot2dotmodulated_convdotweight = nn.Parameter(\n            torch.randn(1, 512, 512, 3, 3) /\n            math.sqrt(512 * 3**2))\n        self.stylegan_decoderdotstyle_convsdot2dotweight = nn.Parameter(torch.zeros(1))  # for noise injection\n        self.stylegan_decoderdotstyle_convsdot2dotbias = nn.Parameter(torch.zeros(1, 512, 1, 1))\n        self.stylegan_decoderdotstyle_convsdot3dotmodulated_convdotmodulation = nn.Linear(512, 512, bias=True)\n        self.stylegan_decoderdotstyle_convsdot3dotmodulated_convdotweight = nn.Parameter(\n            torch.randn(1, 512, 512, 3, 3) /\n            math.sqrt(512 * 3**2))\n        self.stylegan_decoderdotstyle_convsdot3dotweight = nn.Parameter(torch.zeros(1))  \n        self.stylegan_decoderdotstyle_convsdot3dotbias = 
nn.Parameter(torch.zeros(1, 512, 1, 1))\n        self.stylegan_decoderdotstyle_convsdot2dotmodulated_convdotmodulation = nn.Linear(512, 512, bias=True)\n        self.stylegan_decoderdotto_rgbsdot1dotmodulated_convdotmodulation = nn.Linear(512, 512, bias=True)\n        self.stylegan_decoderdotto_rgbsdot1dotmodulated_convdotweight = nn.Parameter(\n            torch.randn(1, 3, 512, 1, 1) /\n            math.sqrt(512 * 1**2))\n        self.stylegan_decoderdotto_rgbsdot1dotbias = nn.Parameter(torch.zeros(1, 3, 1, 1))\n        \n        #i = 5\n        self.stylegan_decoderdotstyle_convsdot4dotmodulated_convdotmodulation = nn.Linear(512, 512, bias=True)\n        self.stylegan_decoderdotstyle_convsdot4dotmodulated_convdotweight = nn.Parameter(\n            torch.randn(1, 512, 512, 3, 3) /\n            math.sqrt(512 * 3**2))\n        self.stylegan_decoderdotstyle_convsdot4dotweight = nn.Parameter(torch.zeros(1))  # for noise injection\n        self.stylegan_decoderdotstyle_convsdot4dotbias = nn.Parameter(torch.zeros(1, 512, 1, 1))\n        self.stylegan_decoderdotstyle_convsdot5dotmodulated_convdotmodulation = nn.Linear(512, 512, bias=True)\n        self.stylegan_decoderdotstyle_convsdot5dotmodulated_convdotweight = nn.Parameter(\n            torch.randn(1, 512, 512, 3, 3) /\n            math.sqrt(512 * 3**2))\n        self.stylegan_decoderdotstyle_convsdot5dotweight = nn.Parameter(torch.zeros(1))  \n        self.stylegan_decoderdotstyle_convsdot5dotbias = nn.Parameter(torch.zeros(1, 512, 1, 1))\n        self.stylegan_decoderdotto_rgbsdot2dotmodulated_convdotmodulation = nn.Linear(512, 512, bias=True)\n        self.stylegan_decoderdotto_rgbsdot2dotmodulated_convdotweight = nn.Parameter(\n            torch.randn(1, 3, 512, 1, 1) /\n            math.sqrt(512 * 1**2))\n        self.stylegan_decoderdotto_rgbsdot2dotbias = nn.Parameter(torch.zeros(1, 3, 1, 1))\n        \n        #i = 7\n        self.stylegan_decoderdotstyle_convsdot6dotmodulated_convdotmodulation = 
nn.Linear(512, 512, bias=True)\n        self.stylegan_decoderdotstyle_convsdot6dotmodulated_convdotweight = nn.Parameter(\n            torch.randn(1, 512, 512, 3, 3) /\n            math.sqrt(512 * 3**2))\n        self.stylegan_decoderdotstyle_convsdot6dotweight = nn.Parameter(torch.zeros(1))  # for noise injection\n        self.stylegan_decoderdotstyle_convsdot6dotbias = nn.Parameter(torch.zeros(1, 512, 1, 1))\n        self.stylegan_decoderdotstyle_convsdot7dotmodulated_convdotmodulation = nn.Linear(512, 512, bias=True)\n        self.stylegan_decoderdotstyle_convsdot7dotmodulated_convdotweight = nn.Parameter(\n            torch.randn(1, 512, 512, 3, 3) /\n            math.sqrt(512 * 3**2))\n        self.stylegan_decoderdotstyle_convsdot7dotweight = nn.Parameter(torch.zeros(1))  \n        self.stylegan_decoderdotstyle_convsdot7dotbias = nn.Parameter(torch.zeros(1, 512, 1, 1))\n        \n        self.stylegan_decoderdotto_rgbsdot3dotmodulated_convdotmodulation = nn.Linear(512, 512, bias=True)\n        self.stylegan_decoderdotto_rgbsdot3dotmodulated_convdotweight = nn.Parameter(\n            torch.randn(1, 3, 512, 1, 1) /\n            math.sqrt(512 * 1**2))\n        self.stylegan_decoderdotto_rgbsdot3dotbias = nn.Parameter(torch.zeros(1, 3, 1, 1))\n\n        #i = 9\n        self.stylegan_decoderdotstyle_convsdot8dotmodulated_convdotmodulation = nn.Linear(512, 512, bias=True)\n        self.stylegan_decoderdotstyle_convsdot8dotmodulated_convdotweight = nn.Parameter(\n            torch.randn(1, 256, 512, 3, 3) /\n            math.sqrt(256 * 3**2))\n        self.stylegan_decoderdotstyle_convsdot8dotweight = nn.Parameter(torch.zeros(1))  # for noise injection\n        self.stylegan_decoderdotstyle_convsdot8dotbias = nn.Parameter(torch.zeros(1, 256, 1, 1))\n        \n        self.stylegan_decoderdotstyle_convsdot9dotmodulated_convdotmodulation = nn.Linear(512, 256, bias=True)\n        self.stylegan_decoderdotstyle_convsdot9dotmodulated_convdotweight = nn.Parameter(\n        
    torch.randn(1, 256, 256, 3, 3) /\n            math.sqrt(256 * 3**2))\n        \n        self.stylegan_decoderdotstyle_convsdot9dotweight = nn.Parameter(torch.zeros(1))  \n        self.stylegan_decoderdotstyle_convsdot9dotbias = nn.Parameter(torch.zeros(1, 256, 1, 1))\n        self.stylegan_decoderdotto_rgbsdot4dotmodulated_convdotmodulation = nn.Linear(512, 256, bias=True)\n        self.stylegan_decoderdotto_rgbsdot4dotmodulated_convdotweight = nn.Parameter(\n            torch.randn(1, 3, 256, 1, 1) /\n            math.sqrt(256 * 1**2))\n        self.stylegan_decoderdotto_rgbsdot4dotbias = nn.Parameter(torch.zeros(1, 3, 1, 1))\n        \n        #i = 11\n        self.stylegan_decoderdotstyle_convsdot10dotmodulated_convdotmodulation = nn.Linear(512, 256, bias=True)\n        self.stylegan_decoderdotstyle_convsdot10dotmodulated_convdotweight = nn.Parameter(\n            torch.randn(1, 128, 256, 3, 3) /\n            math.sqrt(128 * 3**2))\n        self.stylegan_decoderdotstyle_convsdot10dotweight = nn.Parameter(torch.zeros(1))  # for noise injection\n        self.stylegan_decoderdotstyle_convsdot10dotbias = nn.Parameter(torch.zeros(1, 128, 1, 1))\n        self.stylegan_decoderdotstyle_convsdot11dotmodulated_convdotmodulation = nn.Linear(512, 128, bias=True)\n        \n        self.stylegan_decoderdotstyle_convsdot11dotmodulated_convdotweight = nn.Parameter(\n            torch.randn(1, 128, 128, 3, 3) /\n            math.sqrt(128 * 3**2)) \n        self.stylegan_decoderdotstyle_convsdot11dotweight = nn.Parameter(torch.zeros(1))  \n        self.stylegan_decoderdotstyle_convsdot11dotbias = nn.Parameter(torch.zeros(1, 128, 1, 1))\n        self.stylegan_decoderdotto_rgbsdot5dotmodulated_convdotmodulation = nn.Linear(512, 128, bias=True)\n        self.stylegan_decoderdotto_rgbsdot5dotmodulated_convdotweight = nn.Parameter(\n            torch.randn(1, 3, 128, 1, 1) /\n            math.sqrt(128 * 1**2))\n        self.stylegan_decoderdotto_rgbsdot5dotbias = 
nn.Parameter(torch.zeros(1, 3, 1, 1))\n        \n        #i = 13\n        self.stylegan_decoderdotstyle_convsdot12dotmodulated_convdotmodulation = nn.Linear(512, 128, bias=True)\n        self.stylegan_decoderdotstyle_convsdot12dotmodulated_convdotweight = nn.Parameter(\n            torch.randn(1, 64, 128, 3, 3) /\n            math.sqrt(64 * 3**2))\n        self.stylegan_decoderdotstyle_convsdot12dotweight = nn.Parameter(torch.zeros(1))  # for noise injection\n        self.stylegan_decoderdotstyle_convsdot12dotbias = nn.Parameter(torch.zeros(1, 64, 1, 1))\n        self.stylegan_decoderdotstyle_convsdot13dotmodulated_convdotmodulation = nn.Linear(512, 64, bias=True)\n        self.stylegan_decoderdotstyle_convsdot13dotmodulated_convdotweight = nn.Parameter(\n            torch.randn(1, 64, 64, 3, 3) /\n            math.sqrt(64 * 3**2))\n        self.stylegan_decoderdotstyle_convsdot13dotweight = nn.Parameter(torch.zeros(1))  \n        self.stylegan_decoderdotstyle_convsdot13dotbias = nn.Parameter(torch.zeros(1, 64, 1, 1))\n        self.stylegan_decoderdotto_rgbsdot6dotmodulated_convdotmodulation = nn.Linear(512, 64, bias=True)\n        self.stylegan_decoderdotto_rgbsdot6dotmodulated_convdotweight = nn.Parameter(\n            torch.randn(1, 3, 64, 1, 1) /\n            math.sqrt(64 * 1**2))\n        self.stylegan_decoderdotto_rgbsdot6dotbias = nn.Parameter(torch.zeros(1, 3, 1, 1))\n        ''' \n        '''\n    def forward(self, x):\n        # encoder\n        feat = F.leaky_relu_(self.conv_body_first(x), negative_slope=0.2)\n        conditions = []\n        unet_skips = []\n        out_rgbs = []\n\n        for i in range(7):\n            feat = self.conv_body_down[i](feat)\n            unet_skips.insert(0, feat)\n        \n        feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2)\n        \n        # style code\n        style_code = self.final_linear(feat.view(feat.size(0), -1))\n        style_code = style_code.view(style_code.size(0), -1, 512)\n\n        
# decode\n        for i in range(7):\n            # add unet skip\n            feat = feat + unet_skips[i]\n            # ResUpLayer\n            feat = self.conv_body_up[i](feat)\n            # generate scale and shift for SFT layer\n            scale = self.condition_scale[i](feat)\n            conditions.append(scale.clone())\n            shift = self.condition_shift[i](feat)\n            conditions.append(shift.clone())\n\n        styles = [style_code]\n\n       \n        #noise = [None] * 15  # for each style conv layer\n        latent = styles[0]    \n        out = self.stylegan_decoderdotconstant_input(latent.shape[0])\n    \n        b, c, h, w = 1, 512, 4, 4\n        # weight modulation\n        style = self.stylegan_decoderdotstyle_conv1dotmodulated_convdotmodulation(latent[:, 0]).view(b, 1, c, 1, 1)\n        weight = self.stylegan_decoderdotstyle_conv1dotmodulated_convdotweight * style  # (b, c_out, c_in, k, k)\n        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)\n        weight = weight * demod.view(b, 512, 1, 1, 1)\n        weight = weight.view(b * 512, c, 3, 3)\n        b, c, h, w = 1, 512, 4, 4\n        out = out.view(1, b * c, h, w)\n        # weight: (b*c_out, c_in, k, k), groups=b\n        out = F.conv2d(out, weight, padding=1, groups=b)\n        out = out.view(b, 512, *out.shape[2:4]) * 2**0.5 \n        b, _, h, w = 1, 512, 4, 4\n        noise = noise_dict[w]\n        out = out + self.stylegan_decoderdotstyle_conv1dotweight * noise\n        out = out + self.stylegan_decoderdotstyle_conv1dotbias\n        out = self.activate(out)\n        out0 = out \n        \n       \n        # toRGB    \n        x = out    ########\n        style = latent[:, 1] ###########\n        style = self.stylegan_decoderdotto_rgb1dotmodulated_convdotmodulation(latent[:, 1]).view(b, 1, c, 1, 1)\n        weight = self.stylegan_decoderdotto_rgb1dotmodulated_convdotweight * style     \n        weight = weight.view(3, 512, 1, 1)\n        b, c, h, w = 1, 512, 4, 4\n 
       x = x.view(1, 512, 4, 4)\n        out = F.conv2d(x, weight, padding=0, groups=b)\n        out = out.view(1, 3, 4, 4)\n        out = out + self.stylegan_decoderdotto_rgb1dotbias\n        skip = out\n        out = out0\n  \n        # i = 1\n        i = 1\n        x = out\n        b, c, h, w = 1, 512, 4, 4\n        \n        #conv1\n        style = self.stylegan_decoderdotstyle_convsdot0dotmodulated_convdotmodulation(latent[:, i]).view(b, 1, c, 1, 1)\n        weight = self.stylegan_decoderdotstyle_convsdot0dotmodulated_convdotweight * style\n        # self.demodulate = True:\n        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08)\n        weight = weight * demod.view(b, 512, 1, 1, 1)\n        #\n        weight = weight.view(b * 512, c, 3, 3)\n        # self.sample_mode == 'upsample'\n        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)\n        b, c, h, w = x.shape\n        x = x.view(1, b * c, h, w)\n        out = F.conv2d(x, weight, padding=1, groups=b)\n        out = out.view(1, 512, 8, 8) * 2 ** 0.5 \n        b, _, h, w = 1,_,8,8\n        noise = noise_dict[w]\n        out = out + self.stylegan_decoderdotstyle_convsdot0dotweight * noise\n        out = out + self.stylegan_decoderdotstyle_convsdot0dotbias\n        out = self.activate(out)\n        out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)\n        out_sft = out_sft * conditions[i - 1] + conditions[i]\n        out = torch.cat([out_same, out_sft], dim=1)\n        #conv2\n        style = self.stylegan_decoderdotstyle_convsdot1dotmodulated_convdotmodulation(latent[:, i + 1]).view(1, 1, 512, 1, 1)\n        weight = self.stylegan_decoderdotstyle_convsdot1dotmodulated_convdotweight * style\n        # self.demodulate = True:\n        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08)\n        weight = weight * demod.view(b, 512, 1, 1, 1)\n        weight = weight.view(b * 512, 512, 3, 3)\n        out = F.conv2d(out, weight, padding=1, groups=b)\n 
       out = out.view(1, 512, 8, 8) * 2 ** 0.5 \n        noise = noise_dict[w]\n        out = out + self.stylegan_decoderdotstyle_convsdot1dotweight * noise\n        out = out + self.stylegan_decoderdotstyle_convsdot1dotbias\n        out = self.activate(out)\n        out0 = out\n        #to_rgb\n        x = out\n        style = latent[:, i + 2]  \n        style = self.stylegan_decoderdotto_rgbsdot0dotmodulated_convdotmodulation(style).view(1, 1, 512, 1, 1)\n        weight = self.stylegan_decoderdotto_rgbsdot0dotmodulated_convdotweight * style     \n        weight = weight.view(3, 512, 1, 1)\n        #b, c, h, w = x.shape\n        x = x.view(1, b * c, h, w)\n        out = F.conv2d(x, weight, padding=0, groups=b)\n        out = out.view(1, 3, 8, 8)\n        out = out + self.stylegan_decoderdotto_rgbsdot0dotbias\n        \n        skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)\n        skip = out + skip\n        \n        # i = 3\n        out = out0\n        x = out\n        b, c, h, w = 1, 512, 8, 8\n        i += 2\n        style = latent[:, i]\n        #conv1\n        style = self.stylegan_decoderdotstyle_convsdot2dotmodulated_convdotmodulation(latent[:, i]).view(b, 1, c, 1, 1)\n        weight = self.stylegan_decoderdotstyle_convsdot2dotmodulated_convdotweight * style\n        # self.demodulate = True:\n        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08)\n        weight = weight * demod.view(b, 512, 1, 1, 1)\n        # self.sample_mode == 'upsample'\n        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)\n        weight = weight.view(b * 512, c, 3, 3)\n        b, c, h, w = x.shape\n        x = x.view(1, b * c, h, w)\n        out = F.conv2d(x, weight, padding=1, groups=b)\n        out = out.view(1, 512, 16, 16) * 2 ** 0.5 \n        b, _, h, w = 1, _, 16, 16\n        noise = noise_dict[w]\n        out = out + self.stylegan_decoderdotstyle_convsdot2dotweight * noise\n        out = out + 
self.stylegan_decoderdotstyle_convsdot2dotbias\n        out = self.activate(out)\n        out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)\n        out_sft = out_sft * conditions[i - 1] + conditions[i]\n        out = torch.cat([out_same, out_sft], dim=1)\n        #conv2\n        style = self.stylegan_decoderdotstyle_convsdot3dotmodulated_convdotmodulation(latent[:, i + 1]).view(1, 1, 512, 1, 1)\n        weight = self.stylegan_decoderdotstyle_convsdot3dotmodulated_convdotweight * style\n        # self.demodulate = True:\n        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08)\n        weight = weight * demod.view(1, 512, 1, 1, 1)\n        weight = weight.view(b * 512, 512, 3, 3)\n        out = F.conv2d(out, weight, padding=1, groups=b)\n        out = out.view(1, 512, 16, 16) * 2 ** 0.5 \n        noise = noise_dict[w]\n        out = out + self.stylegan_decoderdotstyle_convsdot3dotweight * noise\n        out = out + self.stylegan_decoderdotstyle_convsdot3dotbias\n        out = self.activate(out)\n        out0 = out\n        #to_rgb\n        x = out\n        style = latent[:, i + 2]  \n        style = self.stylegan_decoderdotto_rgbsdot1dotmodulated_convdotmodulation(style).view(1, 1, 512, 1, 1)\n        weight = self.stylegan_decoderdotto_rgbsdot1dotmodulated_convdotweight * style     \n        weight = weight.view(3, 512, 1, 1)\n        #b, c, h, w = x.shape\n        x = x.view(1, b * c, h, w)\n        out = F.conv2d(x, weight, padding=0, groups=b)\n        out = out.view(1, 3, 16, 16)\n        out = out + self.stylegan_decoderdotto_rgbsdot1dotbias\n        skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)\n        skip = out + skip\n        \n        \n        # i = 5\n        out = out0\n        x = out\n        b, c, h, w = 1, 512, 32, 32\n        i += 2\n        style = latent[:, i] \n        #conv1\n        style = self.stylegan_decoderdotstyle_convsdot4dotmodulated_convdotmodulation(latent[:, i]).view(b, 
1, c, 1, 1)   \n        weight = self.stylegan_decoderdotstyle_convsdot4dotmodulated_convdotweight * style\n        # self.demodulate = True:\n        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08)\n        weight = weight * demod.view(b, 512, 1, 1, 1)\n        # self.sample_mode == 'upsample'\n        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)\n        weight = weight.view(b * 512, c, 3, 3)    \n        b, c, h, w = x.shape\n        x = x.view(1, b * c, h, w)\n        out = F.conv2d(x, weight, padding=1, groups=b)\n        out = out.view(1, 512, 32, 32) * 2 ** 0.5 \n        b, _, h, w = 1, _, 32, 32\n        noise = noise_dict[w]\n        out = out + self.stylegan_decoderdotstyle_convsdot4dotweight * noise\n        out = out + self.stylegan_decoderdotstyle_convsdot4dotbias\n        out = self.activate(out)\n        \n        out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)\n        out_sft = out_sft * conditions[i - 1] + conditions[i]\n        out = torch.cat([out_same, out_sft], dim=1)\n       \n        #conv2\n        style = self.stylegan_decoderdotstyle_convsdot5dotmodulated_convdotmodulation(latent[:, i + 1]).view(1, 1, 512, 1, 1)\n        weight = self.stylegan_decoderdotstyle_convsdot5dotmodulated_convdotweight * style\n        # self.demodulate = True:\n        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08)\n        weight = weight * demod.view(1, 512, 1, 1, 1)\n        weight = weight.view(b * 512, 512, 3, 3)\n        out = F.conv2d(out, weight, padding=1, groups=b)\n        out = out.view(1, 512, 32, 32) * 2 ** 0.5 \n        noise = noise_dict[w]\n        out = out + self.stylegan_decoderdotstyle_convsdot5dotweight * noise\n        out = out + self.stylegan_decoderdotstyle_convsdot5dotbias\n        out = self.activate(out)\n        out0 = out\n        #to_rgb\n        x = out\n        style = latent[:, i + 2]\n        style = 
self.stylegan_decoderdotto_rgbsdot2dotmodulated_convdotmodulation(style).view(1, 1, 512, 1, 1)\n        weight = self.stylegan_decoderdotto_rgbsdot2dotmodulated_convdotweight * style     \n        \n        weight = weight.view(3, 512, 1, 1)\n        #b, c, h, w = x.shape\n        x = x.view(1, b * c, h, w)\n        out = F.conv2d(x, weight, padding=0, groups=b)\n        \n        out = out.view(1, 3, 32, 32)\n        out = out + self.stylegan_decoderdotto_rgbsdot2dotbias\n        skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)\n        skip = out + skip\n\n        # i = 7\n        out = out0\n        x = out\n        b, c, h, w = 1, 512, 32, 32\n        i += 2\n        style = latent[:, i]   # 数值一致\n        #conv1\n        style = self.stylegan_decoderdotstyle_convsdot6dotmodulated_convdotmodulation(latent[:, i]).view(b, 1, c, 1, 1)   \n        weight = self.stylegan_decoderdotstyle_convsdot6dotmodulated_convdotweight * style\n        # self.demodulate = True:\n        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08)\n        weight = weight * demod.view(b, 512, 1, 1, 1)\n        # self.sample_mode == 'upsample'\n        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)\n        weight = weight.view(b * 512, c, 3, 3)    \n        b, c, h, w = x.shape\n        x = x.view(1, b * c, h, w)\n        out = F.conv2d(x, weight, padding=1, groups=b)\n        out = out.view(1, 512, 64, 64) * 2 ** 0.5 \n        b, _, h, w = 1, _, 64, 64\n        noise = noise_dict[w]\n        out = out + self.stylegan_decoderdotstyle_convsdot7dotweight * noise\n        out = out + self.stylegan_decoderdotstyle_convsdot7dotbias\n        out = self.activate(out)\n        out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)\n        out_sft = out_sft * conditions[i - 1] + conditions[i]\n        out = torch.cat([out_same, out_sft], dim=1)\n        #conv2\n        style = 
self.stylegan_decoderdotstyle_convsdot7dotmodulated_convdotmodulation(latent[:, i + 1]).view(1, 1, 512, 1, 1)\n        weight = self.stylegan_decoderdotstyle_convsdot7dotmodulated_convdotweight * style\n        # self.demodulate = True:\n        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08)\n        weight = weight * demod.view(1, 512, 1, 1, 1)\n        weight = weight.view(b * 512, 512, 3, 3)\n        out = F.conv2d(out, weight, padding=1, groups=b)\n        out = out.view(1, 512, 64, 64) * 2 ** 0.5 \n        noise = noise_dict[w]\n        out = out + self.stylegan_decoderdotstyle_convsdot7dotweight * noise\n        out = out + self.stylegan_decoderdotstyle_convsdot7dotbias\n        out = self.activate(out)\n        out0 = out\n        #to_rgb\n        x = out\n        style = latent[:, i + 2]\n        style = self.stylegan_decoderdotto_rgbsdot3dotmodulated_convdotmodulation(style).view(1, 1, 512, 1, 1)\n        weight = self.stylegan_decoderdotto_rgbsdot3dotmodulated_convdotweight * style     \n        weight = weight.view(3, 512, 1, 1)\n        #b, c, h, w = x.shape\n        x = x.view(1, b * c, h, w)\n        out = F.conv2d(x, weight, padding=0, groups=b)\n        out = out.view(1, 3, 64, 64)\n        out = out + self.stylegan_decoderdotto_rgbsdot3dotbias\n        skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)\n        skip = out + skip\n        \n        # i = 9\n        out = out0\n        x = out\n        b, c, h, w = 1, 512, 64, 64\n        i += 2\n        style = latent[:, i]   # 数值一致\n        #conv1\n        style = self.stylegan_decoderdotstyle_convsdot8dotmodulated_convdotmodulation(latent[:, i]).view(b, 1, c, 1, 1)   \n        weight = self.stylegan_decoderdotstyle_convsdot8dotmodulated_convdotweight * style    \n        # self.demodulate = True:\n        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08)\n        weight = weight * demod.view(b, 256, 1, 1, 1)\n        # self.sample_mode == 
'upsample'\n        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)\n        weight = weight.view(b * 256, c, 3, 3)    \n        b, c, h, w = x.shape\n        x = x.view(1, b * c, h, w)\n        out = F.conv2d(x, weight, padding=1, groups=b)\n        out = out.view(1, 256, 128, 128) * 2 ** 0.5 \n        b, _, h, w = 1, _, 128, 128\n        noise = noise_dict[w]\n        out = out + self.stylegan_decoderdotstyle_convsdot8dotweight * noise\n        out = out + self.stylegan_decoderdotstyle_convsdot8dotbias\n        out = self.activate(out)\n        out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)\n        out_sft = out_sft * conditions[i - 1] + conditions[i]\n        out = torch.cat([out_same, out_sft], dim=1)\n        #conv2\n        style = self.stylegan_decoderdotstyle_convsdot9dotmodulated_convdotmodulation(latent[:, i + 1]).view(1, 1, 256, 1, 1)\n        weight = self.stylegan_decoderdotstyle_convsdot9dotmodulated_convdotweight * style\n        # self.demodulate = True:\n        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08)\n        weight = weight * demod.view(1, 256, 1, 1, 1)\n        weight = weight.view(b * 256, 256, 3, 3)\n        out = F.conv2d(out, weight, padding=1, groups=b)\n        out = out.view(1, 256, 128, 128) * 2 ** 0.5 \n        noise = noise_dict[w]\n        out = out + self.stylegan_decoderdotstyle_convsdot9dotweight * noise\n        out = out + self.stylegan_decoderdotstyle_convsdot9dotbias\n        out = self.activate(out)\n        out0 = out\n        #to_rgb\n        x = out\n        style = latent[:, i + 2]\n        style = self.stylegan_decoderdotto_rgbsdot4dotmodulated_convdotmodulation(style).view(1, 1, 256, 1, 1)\n        weight = self.stylegan_decoderdotto_rgbsdot4dotmodulated_convdotweight * style     \n        weight = weight.view(3, 256, 1, 1)\n        b, c, h, w = x.shape\n        x = x.view(1, b * c, h, w)\n        out = F.conv2d(x, weight, padding=0, groups=b)\n        out 
= out.view(1, 3, 128, 128)\n        out = out + self.stylegan_decoderdotto_rgbsdot4dotbias\n        skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)\n        skip = out + skip\n\n        # i = 11\n        out = out0\n        x = out\n        b, c, h, w = 1, 256, 128, 128\n        i += 2\n        style = latent[:, i]   # 数值一致 \n        style = self.stylegan_decoderdotstyle_convsdot10dotmodulated_convdotmodulation(latent[:, i]).view(b, 1, c, 1, 1)     \n        #conv1\n        weight = self.stylegan_decoderdotstyle_convsdot10dotmodulated_convdotweight * style    \n        # self.demodulate = True:\n        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08)\n        weight = weight * demod.view(b, 128, 1, 1, 1)\n        weight = weight.view(b * 128, 256, 3, 3) \n        # self.sample_mode == 'upsample'\n        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)\n        b, c, h, w = x.shape\n        out = F.conv2d(x, weight, padding=1, groups=b)\n        out = out.view(1, 128, 256, 256) * 2 ** 0.5 \n        b, _, h, w = 1, _, 256, 256\n        noise = noise_dict[w]\n        out = out + self.stylegan_decoderdotstyle_convsdot10dotweight * noise\n        out = out + self.stylegan_decoderdotstyle_convsdot10dotbias\n        out = self.activate(out)\n        out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)\n        out_sft = out_sft * conditions[i - 1] + conditions[i]\n        out = torch.cat([out_same, out_sft], dim=1)\n        #conv2\n        style = self.stylegan_decoderdotstyle_convsdot11dotmodulated_convdotmodulation(latent[:, i + 1]).view(1, 1, 128, 1, 1)\n        weight = self.stylegan_decoderdotstyle_convsdot11dotmodulated_convdotweight * style \n        # self.demodulate = True:\n        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08)\n        weight = weight * demod.view(1, 128, 1, 1, 1)\n        weight = weight.view(b * 128, 128, 3, 3)\n        out = F.conv2d(out, weight, 
padding=1, groups=b)\n        out = out.view(1, 128, 256, 256) * 2 ** 0.5 \n        noise = noise_dict[w]\n        out = out + self.stylegan_decoderdotstyle_convsdot11dotweight * noise\n        out = out + self.stylegan_decoderdotstyle_convsdot11dotbias\n        out = self.activate(out)  \n        out0 = out\n        #to_rgb\n        x = out\n        style = latent[:, i + 2]\n        style = self.stylegan_decoderdotto_rgbsdot5dotmodulated_convdotmodulation(style).view(1, 1, 128, 1, 1)\n        weight = self.stylegan_decoderdotto_rgbsdot5dotmodulated_convdotweight * style     \n        weight = weight.view(3, 128, 1, 1)\n        b, c, h, w = x.shape\n        x = x.view(1, b * c, h, w)\n        out = F.conv2d(x, weight, padding=0, groups=b)\n        out = out.view(1, 3, 256, 256) \n        out = out + self.stylegan_decoderdotto_rgbsdot5dotbias\n        skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)\n        skip = out + skip\n        \n        # i = 13\n        out = out0\n        x = out\n        b, c, h, w = 1, 128, 256, 256\n        i += 2\n        style = latent[:, i]   # 数值一致 \n        style = self.stylegan_decoderdotstyle_convsdot12dotmodulated_convdotmodulation(latent[:, i]).view(b, 1, c, 1, 1)     \n        #conv1\n        weight = self.stylegan_decoderdotstyle_convsdot12dotmodulated_convdotweight * style    \n        \n        # self.demodulate = True:\n        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08)\n        weight = weight * demod.view(b, 64, 1, 1, 1)\n        weight = weight.view(b * 64, 128, 3, 3) \n        \n        # self.sample_mode == 'upsample'\n        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)\n        b, c, h, w = x.shape\n        out = F.conv2d(x, weight, padding=1, groups=b)\n        out = out.view(1, 64, 512, 512) * 2 ** 0.5 \n        b, _, h, w = 1, _, 512, 512\n        noise = noise_dict[w]\n        out = out + 
self.stylegan_decoderdotstyle_convsdot12dotweight * noise\n        out = out + self.stylegan_decoderdotstyle_convsdot12dotbias\n        out = self.activate(out)\n        \n        out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)\n        out_sft = out_sft * conditions[i - 1] + conditions[i]\n        out = torch.cat([out_same, out_sft], dim=1)\n        #conv2\n        style = self.stylegan_decoderdotstyle_convsdot13dotmodulated_convdotmodulation(latent[:, i + 1]).view(1, 1, 64, 1, 1)\n        weight = self.stylegan_decoderdotstyle_convsdot13dotmodulated_convdotweight * style \n        \n        # self.demodulate = True:\n        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08)\n        weight = weight * demod.view(1, 64, 1, 1, 1)\n        weight = weight.view(b * 64, 64, 3, 3)\n        out = F.conv2d(out, weight, padding=1, groups=b)\n        out = out.view(1, 64, 512, 512) * 2 ** 0.5 \n        noise = noise_dict[w]\n        out = out + self.stylegan_decoderdotstyle_convsdot13dotweight * noise\n        out = out + self.stylegan_decoderdotstyle_convsdot13dotbias\n        out = self.activate(out)    \n        out0 = out\n        #to_rgb\n        x = out\n        style = latent[:, i + 2]\n        style = self.stylegan_decoderdotto_rgbsdot6dotmodulated_convdotmodulation(style).view(1, 1, 64, 1, 1) \n        weight = self.stylegan_decoderdotto_rgbsdot6dotmodulated_convdotweight * style     \n        weight = weight.view(3, 64, 1, 1)\n        b, c, h, w = x.shape\n        x = x.view(1, b * c, h, w)\n        out = F.conv2d(x, weight, padding=0, groups=b)    \n        out = out.view(1, 3, 512, 512) \n        out = out + self.stylegan_decoderdotto_rgbsdot6dotbias\n        skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)\n        skip = out + skip\n        return skip        \n\n"
  },
  {
    "path": "README.md",
    "content": "# GFPGAN-onnxruntime-demo\nThis is the ONNX Runtime inference code for GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior (CVPR 2021). Official code: https://github.com/TencentARC/GFPGAN\n\n## Issues addressed\n1. `noise = out.new_empty(b, 1, h, w).normal_()` in `stylegan2_clean_arch.py` cannot be exported to ONNX. I moved noise generation out of the model class: the model now reads pre-generated random noise from a container (e.g. `noise = Noise[i]`, where `Noise` is a list, dict, or similar; in this repo it is `noise_dict` in `noise_main.py`).\n\n2. The original `forward` functions, especially in the StyleGAN decoder, are hard to export: they contain many `if`/`else` branches and reuse module classes built in loops (e.g. `self.style_convs.append(StyleConv(...))`). I rewrote the model as a single, flattened `forward`.\n\n## Convert PyTorch to ONNX\n```\nwget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth\n\npython torch2onnx.py  --src_model_path ./GFPGANv1.3.pth --dst_model_path ./GFPGANv1.3.onnx --img_size 512 \n```\n\n## Run the ONNX demo\n```\npython demo_onnx.py --model_path GFPGANv1.3.onnx --image_path ./cropped_faces/Adele_crop.png --save_path Adele_v3.jpg\n```\n\n| input | output|\n| :-: |:-:|\n|<img src=\"https://github.com/xuanandsix/GFPGAN-onnxruntime-demo/raw/main/cropped_faces/Justin_Timberlake_crop.png\" height=\"80%\" width=\"80%\">|<img src=\"https://github.com/xuanandsix/GFPGAN-onnxruntime-demo/raw/main/imgs/Justin_Timberlake_v2.jpg\" height=\"80%\" width=\"80%\">|\n|<img src=\"https://github.com/xuanandsix/GFPGAN-onnxruntime-demo/raw/main/cropped_faces/Julia_Roberts_crop.png\" height=\"80%\" width=\"80%\">|<img src=\"https://github.com/xuanandsix/GFPGAN-onnxruntime-demo/raw/main/imgs/Julia_Roberts_v2.jpg\" height=\"80%\" width=\"80%\">|\n|<img src=\"https://github.com/xuanandsix/GFPGAN-onnxruntime-demo/raw/main/cropped_faces/Paris_Hilton_crop.png\" height=\"80%\" width=\"80%\">|<img src=\"https://github.com/xuanandsix/GFPGAN-onnxruntime-demo/raw/main/imgs/Paris_Hilton_v2.jpg\" height=\"80%\" width=\"80%\">|\n\n\n"
  },
  {
    "path": "demo_onnx.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\nimport os\nimport sys\nimport argparse\nimport cv2\nimport numpy as np\nimport timeit\nimport onnxruntime\n\nclass GFPGANFaceAugment:\n    def __init__(self, model_path, use_gpu = False):\n        self.ort_session = onnxruntime.InferenceSession(model_path)\n        self.net_input_name = self.ort_session.get_inputs()[0].name\n        _,self.net_input_channels,self.net_input_height,self.net_input_width = self.ort_session.get_inputs()[0].shape\n        self.net_output_count = len(self.ort_session.get_outputs())\n        self.face_size = 512\n        self.face_template = np.array([[192, 240], [319, 240], [257, 371]]) * (self.face_size / 512.0)\n        self.upscale_factor = 2\n        self.affine = False\n        self.affine_matrix = None\n    def pre_process(self, img):\n        img = cv2.resize(img, (int(img.shape[1] / 2), int(img.shape[0] / 2)))\n        img = cv2.resize(img, (self.face_size, self.face_size))\n        img = img / 255.0\n        img = img.astype('float32')\n        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n        img[:,:,0] = (img[:,:,0]-0.5)/0.5\n        img[:,:,1] = (img[:,:,1]-0.5)/0.5\n        img[:,:,2] = (img[:,:,2]-0.5)/0.5\n        img = np.float32(img[np.newaxis,:,:,:])\n        img = img.transpose(0, 3, 1, 2)\n        return img\n    def post_process(self, output, height, width):\n        output = output.clip(-1,1)\n        output = (output + 1) / 2\n        output = output.transpose(1, 2, 0)\n        output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)\n        output = (output * 255.0).round()\n        if self.affine:\n            inverse_affine = cv2.invertAffineTransform(self.affine_matrix)\n            inverse_affine *= self.upscale_factor\n            if self.upscale_factor > 1:\n                extra_offset = 0.5 * self.upscale_factor\n            else:\n                extra_offset = 0\n            inverse_affine[:, 2] += extra_offset\n            inv_restored = 
cv2.warpAffine(output, inverse_affine, (width, height))\n            mask = np.ones((self.face_size, self.face_size), dtype=np.float32)\n            inv_mask = cv2.warpAffine(mask, inverse_affine, (width, height))\n            inv_mask_erosion = cv2.erode(\n                inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))\n            pasted_face = inv_mask_erosion[:, :, None] * inv_restored\n            total_face_area = np.sum(inv_mask_erosion)\n            # compute the fusion edge based on the area of face\n            w_edge = int(total_face_area**0.5) // 20\n            erosion_radius = w_edge * 2\n            inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))\n            blur_size = w_edge * 2\n            inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)\n            inv_soft_mask = inv_soft_mask[:, :, None]\n            output = pasted_face\n        else:\n            inv_soft_mask = np.ones((height, width, 1), dtype=np.float32)\n            output = cv2.resize(output, (width, height))\n        return output, inv_soft_mask\n\n    def forward(self, img):\n        height, width = img.shape[0], img.shape[1]\n        img = self.pre_process(img)\n        t = timeit.default_timer()\n        ort_inputs = {self.ort_session.get_inputs()[0].name: img}\n        ort_outs = self.ort_session.run(None, ort_inputs)\n        output = ort_outs[0][0]\n        output, inv_soft_mask = self.post_process(output, height, width)\n        print('infer time:',timeit.default_timer()-t)  \n        output = output.astype(np.uint8)\n        return output, inv_soft_mask\n        \nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\"onnxruntime demo\")\n    parser.add_argument('--model_path', type=str, default=None, help='model path')\n    parser.add_argument('--image_path', type=str, default=None, help='input image path')\n    
parser.add_argument('--save_path', type=str, default=\"output.jpg\", help='output image path')\n    args = parser.parse_args()\n\n    faceaugment = GFPGANFaceAugment(model_path=args.model_path)\n    image = cv2.imread(args.image_path, 1)\n    output, _ = faceaugment.forward(image)\n    cv2.imwrite(args.save_path, output)\n\n# python demo_onnx.py --model_path GFPGANv1.4.onnx --image_path ./cropped_faces/Adele_crop.png\n\n\n# python demo_onnx.py --model_path GFPGANv1.2.onnx --image_path ./cropped_faces/Adele_crop.png --save_path Adele_v2.jpg\n# python demo_onnx.py --model_path GFPGANv1.2.onnx --image_path ./cropped_faces/Julia_Roberts_crop.png --save_path Julia_Roberts_v2.jpg\n# python demo_onnx.py --model_path GFPGANv1.2.onnx --image_path ./cropped_faces/Justin_Timberlake_crop.png --save_path Justin_Timberlake_v2.jpg\n# python demo_onnx.py --model_path GFPGANv1.2.onnx --image_path ./cropped_faces/Paris_Hilton_crop.png --save_path Paris_Hilton_v2.jpg"
  },
  {
    "path": "noise_main.py",
    "content": "import torch\nnoise_dict = {}\nsize = [(1, 1, 4, 4),(1, 1, 8, 8),(1, 1, 16, 16),(1, 1, 32, 32),(1, 1, 64, 64),(1, 1, 128, 128),(1, 1, 256, 256),(1, 1, 512, 512)]\nfor s in size: \n    out = torch.rand(s)#.cuda()\n    noise = out.new_empty(s).normal_()\n    #print(s[2])\n    noise_dict[s[2]] = noise\n    #print(noise_dict)\n"
  },
  {
    "path": "torch2onnx.py",
    "content": "# -*- coding: utf-8 -*-\n\n#import cv2\nimport numpy as np\nimport time\nimport torch\nimport pdb\nfrom collections import OrderedDict\n\nimport sys\nsys.path.append('.')\nsys.path.append('./lib')\nimport torch.nn as nn\nfrom torch.autograd import Variable\nimport onnxruntime\nimport timeit\n\nimport argparse\nfrom GFPGANReconsitution import GFPGAN\n\nparser = argparse.ArgumentParser(\"ONNX converter\")\nparser.add_argument('--src_model_path', type=str, default=None, help='src model path')\nparser.add_argument('--dst_model_path', type=str, default=None, help='dst model path')\nparser.add_argument('--img_size', type=int, default=None, help='img size')\nargs = parser.parse_args()\n    \n#device = torch.device('cuda')\nmodel_path = args.src_model_path\nonnx_model_path = args.dst_model_path\nimg_size = args.img_size\n\nmodel = GFPGAN()#.cuda()\n\nx = torch.rand(1, 3, 512, 512)#.cuda()\n\nstate_dict = torch.load(model_path)['params_ema']\nnew_state_dict = OrderedDict()\nfor k, v in state_dict.items():\n    # stylegan_decoderdotto_rgbsdot1dotmodulated_convdotbias\n    if \"stylegan_decoder\" in k:\n        k = k.replace('.', 'dot')\n        new_state_dict[k] = v\n        k = k.replace('dotweight', '.weight')\n        k = k.replace('dotbias', '.bias')\n        new_state_dict[k] = v\n    else:\n        new_state_dict[k] = v\n     \nmodel.load_state_dict(new_state_dict, strict=False)\nmodel.eval()\n\ntorch.onnx.export(model, x, onnx_model_path,\n                    export_params=True, opset_version=11, do_constant_folding=True,\n                    input_names = ['input'],output_names = [])\n\n\n####\ntry:\n    original_model = onnx.load(onnx_model_path)\n    passes = ['fuse_bn_into_conv']\n    optimized_model = optimizer.optimize(original_model, passes)\n    onnx.save(optimized_model, onnx_model_path)\nexcept:\n    print('skip optimize.')\n\n####\nort_session = onnxruntime.InferenceSession(onnx_model_path)\nfor var in ort_session.get_inputs():\n    
print(var.name)\nfor var in ort_session.get_outputs():\n    print(var.name)\n_,_,input_h,input_w = ort_session.get_inputs()[0].shape\nt = timeit.default_timer()\n\nimg = np.zeros((input_h,input_w,3))\n\nimg = (np.transpose(np.float32(img[:,:,:,np.newaxis]), (3,2,0,1)) )#*self.scale\n\nimg = np.ascontiguousarray(img)\n#    \nort_inputs = {ort_session.get_inputs()[0].name: img}\nort_outs = ort_session.run(None, ort_inputs)\n\nprint('onnxruntime infer time:', timeit.default_timer()-t)\nprint(ort_outs[0].shape)\n\n# python torch2onnx.py  --src_model_path ./experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth --dst_model_path ./GFPGAN.onnx --img_size 512 \n\n# 新版本\n\n\n# wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth\n\n# python torch2onnx.py  --src_model_path ./GFPGANv1.4.pth --dst_model_path ./GFPGANv1.4.onnx --img_size 512 \n\n# python torch2onnx.py  --src_model_path ./GFPGANCleanv1-NoCE-C2.pth --dst_model_path ./GFPGANv1.2.onnx --img_size 512 \n"
  }
]