Repository: xuanandsix/GFPGAN-onnxruntime-demo Branch: main Commit: 3728cee8138d Files: 5 Total size: 45.6 KB Directory structure: gitextract_ll211iz8/ ├── GFPGANReconsitution.py ├── README.md ├── demo_onnx.py ├── noise_main.py └── torch2onnx.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: GFPGANReconsitution.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- import os import sys import argparse import cv2 import numpy as np import torch import timeit #import onnxruntime from torch.nn import functional as F from torchvision.transforms.functional import normalize from torch import nn import math from collections import OrderedDict from noise_main import noise_dict class ResBlock(nn.Module): """Residual block with upsampling/downsampling. Args: in_channels (int): Channel number of the input. out_channels (int): Channel number of the output. """ def __init__(self, in_channels, out_channels, mode='down'): super(ResBlock, self).__init__() self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1) self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1) self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False) if mode == 'down': self.scale_factor = 0.5 elif mode == 'up': self.scale_factor = 2 def forward(self, x): out = F.leaky_relu_(self.conv1(x), negative_slope=0.2) # upsample/downsample out = F.interpolate(out, scale_factor=self.scale_factor, mode='bilinear', align_corners=False) out = F.leaky_relu_(self.conv2(out), negative_slope=0.2) # skip x = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=False) skip = self.skip(x) out = out + skip return out class ConstantInput(nn.Module): """Constant input. Args: num_channel (int): Channel number of constant input. size (int): Spatial size of constant input. """ def __init__(self, num_channel, size): super(ConstantInput, self).__init__() self.weight = nn.Parameter(torch.randn(1, num_channel, size, size)) # [1, 512, 4, 4] def forward(self, batch): out = self.weight.repeat(batch, 1, 1, 1) return out class GFPGAN(nn.Module): def __init__(self): super(GFPGAN, self).__init__() unet_narrow = 0.5 channel_multiplier=2 channels = { '4': int(512 * unet_narrow), '8': int(512 * unet_narrow), '16': int(512 * unet_narrow), '32': int(512 * unet_narrow), '64': int(256 * channel_multiplier * unet_narrow), '128': int(128 * channel_multiplier * unet_narrow), '256': int(64 * channel_multiplier * unet_narrow), '512': int(32 * channel_multiplier * unet_narrow), '1024': int(16 * channel_multiplier * unet_narrow) } self.conv_body_first = nn.Conv2d(3, 32, 1) self.conv_body_down = nn.ModuleList() in_channels = channels['512'] for i in range(9, 2, -1): out_channels = channels[f'{2**(i - 1)}'] self.conv_body_down.append(ResBlock(in_channels, out_channels, mode='down')) in_channels = out_channels num_style_feat = 512 self.final_conv = nn.Conv2d(in_channels, channels['4'], 3, 1, 1) linear_out_channel = (int(math.log(512, 2)) * 2 - 2) * num_style_feat self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel) # upsample in_channels = channels['4'] self.conv_body_up = nn.ModuleList() for i in range(3, 9 + 1): out_channels = channels[f'{2**i}'] self.conv_body_up.append(ResBlock(in_channels, out_channels, mode='up')) in_channels = out_channels # for SFT self.condition_scale = nn.ModuleList() self.condition_shift = nn.ModuleList() for i in range(3, 9 + 1): out_channels = channels[f'{2**i}'] sft_out_channels = out_channels self.condition_scale.append( nn.Sequential( nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True), nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1))) self.condition_shift.append( nn.Sequential( nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True), nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1))) self.stylegan_decoderdotconstant_input = ConstantInput(512, size=4) # self.style_conv1 self.stylegan_decoderdotstyle_conv1dotmodulated_convdotmodulation = nn.Linear(512, 512, bias=True) self.stylegan_decoderdotstyle_conv1dotmodulated_convdotweight = nn.Parameter( torch.randn(1, 512, 512, 3, 3) / math.sqrt(512 * 3**2)) self.stylegan_decoderdotstyle_conv1dotweight = nn.Parameter(torch.zeros(1)) # for noise injection self.stylegan_decoderdotstyle_conv1dotbias = nn.Parameter(torch.zeros(1, 512, 1, 1)) self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True) # toRGB self.stylegan_decoderdotto_rgb1dotmodulated_convdotmodulation = nn.Linear(512, 512, bias=True) self.stylegan_decoderdotto_rgb1dotmodulated_convdotweight = nn.Parameter( torch.randn(1, 3, 512, 1, 1) / math.sqrt(512 * 1**2)) self.stylegan_decoderdotto_rgb1dotbias = nn.Parameter(torch.zeros(1, 3, 1, 1)) # i = 1 self.stylegan_decoderdotstyle_convsdot0dotmodulated_convdotmodulation = nn.Linear(512, 512, bias=True) self.stylegan_decoderdotstyle_convsdot0dotmodulated_convdotweight = nn.Parameter( torch.randn(1, 512, 512, 3, 3) / math.sqrt(512 * 3**2)) self.stylegan_decoderdotstyle_convsdot0dotweight = nn.Parameter(torch.zeros(1)) # for noise injection self.stylegan_decoderdotstyle_convsdot0dotbias = nn.Parameter(torch.zeros(1, 512, 1, 1)) self.stylegan_decoderdotstyle_convsdot1dotmodulated_convdotmodulation = nn.Linear(512, 512, bias=True) self.stylegan_decoderdotstyle_convsdot1dotmodulated_convdotweight = nn.Parameter( torch.randn(1, 512, 512, 3, 3) / math.sqrt(512 * 3**2)) self.stylegan_decoderdotstyle_convsdot1dotweight = nn.Parameter(torch.zeros(1)) self.stylegan_decoderdotstyle_convsdot1dotbias = nn.Parameter(torch.zeros(1, 512, 1, 1)) #self.stylegan_decoderdotstyle_convsdot0dotmodulated_convdotmodulation = nn.Linear(512, 512, bias=True) self.stylegan_decoderdotto_rgbsdot0dotmodulated_convdotmodulation = nn.Linear(512, 512, bias=True) self.stylegan_decoderdotto_rgbsdot0dotmodulated_convdotweight = nn.Parameter( torch.randn(1, 3, 512, 1, 1) / math.sqrt(512 * 1**2)) self.stylegan_decoderdotto_rgbsdot0dotbias = nn.Parameter(torch.zeros(1, 3, 1, 1)) #i = 3 self.stylegan_decoderdotstyle_convsdot2dotmodulated_convdotmodulation = nn.Linear(512, 512, bias=True) self.stylegan_decoderdotstyle_convsdot2dotmodulated_convdotweight = nn.Parameter( torch.randn(1, 512, 512, 3, 3) / math.sqrt(512 * 3**2)) self.stylegan_decoderdotstyle_convsdot2dotweight = nn.Parameter(torch.zeros(1)) # for noise injection self.stylegan_decoderdotstyle_convsdot2dotbias = nn.Parameter(torch.zeros(1, 512, 1, 1)) self.stylegan_decoderdotstyle_convsdot3dotmodulated_convdotmodulation = nn.Linear(512, 512, bias=True) self.stylegan_decoderdotstyle_convsdot3dotmodulated_convdotweight = nn.Parameter( torch.randn(1, 512, 512, 3, 3) / math.sqrt(512 * 3**2)) self.stylegan_decoderdotstyle_convsdot3dotweight = nn.Parameter(torch.zeros(1)) self.stylegan_decoderdotstyle_convsdot3dotbias = nn.Parameter(torch.zeros(1, 512, 1, 1)) self.stylegan_decoderdotstyle_convsdot2dotmodulated_convdotmodulation = nn.Linear(512, 512, bias=True) self.stylegan_decoderdotto_rgbsdot1dotmodulated_convdotmodulation = nn.Linear(512, 512, bias=True) self.stylegan_decoderdotto_rgbsdot1dotmodulated_convdotweight = nn.Parameter( torch.randn(1, 3, 512, 1, 1) / math.sqrt(512 * 1**2)) self.stylegan_decoderdotto_rgbsdot1dotbias = nn.Parameter(torch.zeros(1, 3, 1, 1)) #i = 5 self.stylegan_decoderdotstyle_convsdot4dotmodulated_convdotmodulation = nn.Linear(512, 512, bias=True) self.stylegan_decoderdotstyle_convsdot4dotmodulated_convdotweight = nn.Parameter( torch.randn(1, 512, 512, 3, 3) / math.sqrt(512 * 3**2)) self.stylegan_decoderdotstyle_convsdot4dotweight = nn.Parameter(torch.zeros(1)) # for noise injection self.stylegan_decoderdotstyle_convsdot4dotbias = nn.Parameter(torch.zeros(1, 512, 1, 1)) self.stylegan_decoderdotstyle_convsdot5dotmodulated_convdotmodulation = nn.Linear(512, 512, bias=True) self.stylegan_decoderdotstyle_convsdot5dotmodulated_convdotweight = nn.Parameter( torch.randn(1, 512, 512, 3, 3) / math.sqrt(512 * 3**2)) self.stylegan_decoderdotstyle_convsdot5dotweight = nn.Parameter(torch.zeros(1)) self.stylegan_decoderdotstyle_convsdot5dotbias = nn.Parameter(torch.zeros(1, 512, 1, 1)) self.stylegan_decoderdotto_rgbsdot2dotmodulated_convdotmodulation = nn.Linear(512, 512, bias=True) self.stylegan_decoderdotto_rgbsdot2dotmodulated_convdotweight = nn.Parameter( torch.randn(1, 3, 512, 1, 1) / math.sqrt(512 * 1**2)) self.stylegan_decoderdotto_rgbsdot2dotbias = nn.Parameter(torch.zeros(1, 3, 1, 1)) #i = 7 self.stylegan_decoderdotstyle_convsdot6dotmodulated_convdotmodulation = nn.Linear(512, 512, bias=True) self.stylegan_decoderdotstyle_convsdot6dotmodulated_convdotweight = nn.Parameter( torch.randn(1, 512, 512, 3, 3) / math.sqrt(512 * 3**2)) self.stylegan_decoderdotstyle_convsdot6dotweight = nn.Parameter(torch.zeros(1)) # for noise injection self.stylegan_decoderdotstyle_convsdot6dotbias = nn.Parameter(torch.zeros(1, 512, 1, 1)) self.stylegan_decoderdotstyle_convsdot7dotmodulated_convdotmodulation = nn.Linear(512, 512, bias=True) self.stylegan_decoderdotstyle_convsdot7dotmodulated_convdotweight = nn.Parameter( torch.randn(1, 512, 512, 3, 3) / math.sqrt(512 * 3**2)) self.stylegan_decoderdotstyle_convsdot7dotweight = nn.Parameter(torch.zeros(1)) self.stylegan_decoderdotstyle_convsdot7dotbias = nn.Parameter(torch.zeros(1, 512, 1, 1)) self.stylegan_decoderdotto_rgbsdot3dotmodulated_convdotmodulation = nn.Linear(512, 512, bias=True) self.stylegan_decoderdotto_rgbsdot3dotmodulated_convdotweight = nn.Parameter( torch.randn(1, 3, 512, 1, 1) / math.sqrt(512 * 1**2)) self.stylegan_decoderdotto_rgbsdot3dotbias = nn.Parameter(torch.zeros(1, 3, 1, 1)) #i = 9 self.stylegan_decoderdotstyle_convsdot8dotmodulated_convdotmodulation = nn.Linear(512, 512, bias=True) self.stylegan_decoderdotstyle_convsdot8dotmodulated_convdotweight = nn.Parameter( torch.randn(1, 256, 512, 3, 3) / math.sqrt(256 * 3**2)) self.stylegan_decoderdotstyle_convsdot8dotweight = nn.Parameter(torch.zeros(1)) # for noise injection self.stylegan_decoderdotstyle_convsdot8dotbias = nn.Parameter(torch.zeros(1, 256, 1, 1)) self.stylegan_decoderdotstyle_convsdot9dotmodulated_convdotmodulation = nn.Linear(512, 256, bias=True) self.stylegan_decoderdotstyle_convsdot9dotmodulated_convdotweight = nn.Parameter( torch.randn(1, 256, 256, 3, 3) / math.sqrt(256 * 3**2)) self.stylegan_decoderdotstyle_convsdot9dotweight = nn.Parameter(torch.zeros(1)) self.stylegan_decoderdotstyle_convsdot9dotbias = nn.Parameter(torch.zeros(1, 256, 1, 1)) self.stylegan_decoderdotto_rgbsdot4dotmodulated_convdotmodulation = nn.Linear(512, 256, bias=True) self.stylegan_decoderdotto_rgbsdot4dotmodulated_convdotweight = nn.Parameter( torch.randn(1, 3, 256, 1, 1) / math.sqrt(256 * 1**2)) self.stylegan_decoderdotto_rgbsdot4dotbias = nn.Parameter(torch.zeros(1, 3, 1, 1)) #i = 11 self.stylegan_decoderdotstyle_convsdot10dotmodulated_convdotmodulation = nn.Linear(512, 256, bias=True) self.stylegan_decoderdotstyle_convsdot10dotmodulated_convdotweight = nn.Parameter( torch.randn(1, 128, 256, 3, 3) / math.sqrt(128 * 3**2)) self.stylegan_decoderdotstyle_convsdot10dotweight = nn.Parameter(torch.zeros(1)) # for noise injection self.stylegan_decoderdotstyle_convsdot10dotbias = nn.Parameter(torch.zeros(1, 128, 1, 1)) self.stylegan_decoderdotstyle_convsdot11dotmodulated_convdotmodulation = nn.Linear(512, 128, bias=True) self.stylegan_decoderdotstyle_convsdot11dotmodulated_convdotweight = nn.Parameter( torch.randn(1, 128, 128, 3, 3) / math.sqrt(128 * 3**2)) self.stylegan_decoderdotstyle_convsdot11dotweight = nn.Parameter(torch.zeros(1)) self.stylegan_decoderdotstyle_convsdot11dotbias = nn.Parameter(torch.zeros(1, 128, 1, 1)) self.stylegan_decoderdotto_rgbsdot5dotmodulated_convdotmodulation = nn.Linear(512, 128, bias=True) self.stylegan_decoderdotto_rgbsdot5dotmodulated_convdotweight = nn.Parameter( torch.randn(1, 3, 128, 1, 1) / math.sqrt(128 * 1**2)) self.stylegan_decoderdotto_rgbsdot5dotbias = nn.Parameter(torch.zeros(1, 3, 1, 1)) #i = 13 self.stylegan_decoderdotstyle_convsdot12dotmodulated_convdotmodulation = nn.Linear(512, 128, bias=True) self.stylegan_decoderdotstyle_convsdot12dotmodulated_convdotweight = nn.Parameter( torch.randn(1, 64, 128, 3, 3) / math.sqrt(64 * 3**2)) self.stylegan_decoderdotstyle_convsdot12dotweight = nn.Parameter(torch.zeros(1)) # for noise injection self.stylegan_decoderdotstyle_convsdot12dotbias = nn.Parameter(torch.zeros(1, 64, 1, 1)) self.stylegan_decoderdotstyle_convsdot13dotmodulated_convdotmodulation = nn.Linear(512, 64, bias=True) self.stylegan_decoderdotstyle_convsdot13dotmodulated_convdotweight = nn.Parameter( torch.randn(1, 64, 64, 3, 3) / math.sqrt(64 * 3**2)) self.stylegan_decoderdotstyle_convsdot13dotweight = nn.Parameter(torch.zeros(1)) self.stylegan_decoderdotstyle_convsdot13dotbias = nn.Parameter(torch.zeros(1, 64, 1, 1)) self.stylegan_decoderdotto_rgbsdot6dotmodulated_convdotmodulation = nn.Linear(512, 64, bias=True) self.stylegan_decoderdotto_rgbsdot6dotmodulated_convdotweight = nn.Parameter( torch.randn(1, 3, 64, 1, 1) / math.sqrt(64 * 1**2)) self.stylegan_decoderdotto_rgbsdot6dotbias = nn.Parameter(torch.zeros(1, 3, 1, 1)) ''' ''' def forward(self, x): # encoder feat = F.leaky_relu_(self.conv_body_first(x), negative_slope=0.2) conditions = [] unet_skips = [] out_rgbs = [] for i in range(7): feat = self.conv_body_down[i](feat) unet_skips.insert(0, feat) feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2) # style code style_code = self.final_linear(feat.view(feat.size(0), -1)) style_code = style_code.view(style_code.size(0), -1, 512) # decode for i in range(7): # add unet skip feat = feat + unet_skips[i] # ResUpLayer feat = self.conv_body_up[i](feat) # generate scale and shift for SFT layer scale = self.condition_scale[i](feat) conditions.append(scale.clone()) shift = self.condition_shift[i](feat) conditions.append(shift.clone()) styles = [style_code] #noise = [None] * 15 # for each style conv layer latent = styles[0] out = self.stylegan_decoderdotconstant_input(latent.shape[0]) b, c, h, w = 1, 512, 4, 4 # weight modulation style = self.stylegan_decoderdotstyle_conv1dotmodulated_convdotmodulation(latent[:, 0]).view(b, 1, c, 1, 1) weight = self.stylegan_decoderdotstyle_conv1dotmodulated_convdotweight * style # (b, c_out, c_in, k, k) demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) weight = weight * demod.view(b, 512, 1, 1, 1) weight = weight.view(b * 512, c, 3, 3) b, c, h, w = 1, 512, 4, 4 out = out.view(1, b * c, h, w) # weight: (b*c_out, c_in, k, k), groups=b out = F.conv2d(out, weight, padding=1, groups=b) out = out.view(b, 512, *out.shape[2:4]) * 2**0.5 b, _, h, w = 1, 512, 4, 4 noise = noise_dict[w] out = out + self.stylegan_decoderdotstyle_conv1dotweight * noise out = out + self.stylegan_decoderdotstyle_conv1dotbias out = self.activate(out) out0 = out # toRGB x = out ######## style = latent[:, 1] ########### style = self.stylegan_decoderdotto_rgb1dotmodulated_convdotmodulation(latent[:, 1]).view(b, 1, c, 1, 1) weight = self.stylegan_decoderdotto_rgb1dotmodulated_convdotweight * style weight = weight.view(3, 512, 1, 1) b, c, h, w = 1, 512, 4, 4 x = x.view(1, 512, 4, 4) out = F.conv2d(x, weight, padding=0, groups=b) out = out.view(1, 3, 4, 4) out = out + self.stylegan_decoderdotto_rgb1dotbias skip = out out = out0 # i = 1 i = 1 x = out b, c, h, w = 1, 512, 4, 4 #conv1 style = self.stylegan_decoderdotstyle_convsdot0dotmodulated_convdotmodulation(latent[:, i]).view(b, 1, c, 1, 1) weight = self.stylegan_decoderdotstyle_convsdot0dotmodulated_convdotweight * style # self.demodulate = True: demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08) weight = weight * demod.view(b, 512, 1, 1, 1) # weight = weight.view(b * 512, c, 3, 3) # self.sample_mode == 'upsample' x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) b, c, h, w = x.shape x = x.view(1, b * c, h, w) out = F.conv2d(x, weight, padding=1, groups=b) out = out.view(1, 512, 8, 8) * 2 ** 0.5 b, _, h, w = 1,_,8,8 noise = noise_dict[w] out = out + self.stylegan_decoderdotstyle_convsdot0dotweight * noise out = out + self.stylegan_decoderdotstyle_convsdot0dotbias out = self.activate(out) out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1) out_sft = out_sft * conditions[i - 1] + conditions[i] out = torch.cat([out_same, out_sft], dim=1) #conv2 style = self.stylegan_decoderdotstyle_convsdot1dotmodulated_convdotmodulation(latent[:, i + 1]).view(1, 1, 512, 1, 1) weight = self.stylegan_decoderdotstyle_convsdot1dotmodulated_convdotweight * style # self.demodulate = True: demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08) weight = weight * demod.view(b, 512, 1, 1, 1) weight = weight.view(b * 512, 512, 3, 3) out = F.conv2d(out, weight, padding=1, groups=b) out = out.view(1, 512, 8, 8) * 2 ** 0.5 noise = noise_dict[w] out = out + self.stylegan_decoderdotstyle_convsdot1dotweight * noise out = out + self.stylegan_decoderdotstyle_convsdot1dotbias out = self.activate(out) out0 = out #to_rgb x = out style = latent[:, i + 2] style = self.stylegan_decoderdotto_rgbsdot0dotmodulated_convdotmodulation(style).view(1, 1, 512, 1, 1) weight = self.stylegan_decoderdotto_rgbsdot0dotmodulated_convdotweight * style weight = weight.view(3, 512, 1, 1) #b, c, h, w = x.shape x = x.view(1, b * c, h, w) out = F.conv2d(x, weight, padding=0, groups=b) out = out.view(1, 3, 8, 8) out = out + self.stylegan_decoderdotto_rgbsdot0dotbias skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False) skip = out + skip # i = 3 out = out0 x = out b, c, h, w = 1, 512, 8, 8 i += 2 style = latent[:, i] #conv1 style = self.stylegan_decoderdotstyle_convsdot2dotmodulated_convdotmodulation(latent[:, i]).view(b, 1, c, 1, 1) weight = self.stylegan_decoderdotstyle_convsdot2dotmodulated_convdotweight * style # self.demodulate = True: demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08) weight = weight * demod.view(b, 512, 1, 1, 1) # self.sample_mode == 'upsample' x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) weight = weight.view(b * 512, c, 3, 3) b, c, h, w = x.shape x = x.view(1, b * c, h, w) out = F.conv2d(x, weight, padding=1, groups=b) out = out.view(1, 512, 16, 16) * 2 ** 0.5 b, _, h, w = 1, _, 16, 16 noise = noise_dict[w] out = out + self.stylegan_decoderdotstyle_convsdot2dotweight * noise out = out + self.stylegan_decoderdotstyle_convsdot2dotbias out = self.activate(out) out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1) out_sft = out_sft * conditions[i - 1] + conditions[i] out = torch.cat([out_same, out_sft], dim=1) #conv2 style = self.stylegan_decoderdotstyle_convsdot3dotmodulated_convdotmodulation(latent[:, i + 1]).view(1, 1, 512, 1, 1) weight = self.stylegan_decoderdotstyle_convsdot3dotmodulated_convdotweight * style # self.demodulate = True: demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08) weight = weight * demod.view(1, 512, 1, 1, 1) weight = weight.view(b * 512, 512, 3, 3) out = F.conv2d(out, weight, padding=1, groups=b) out = out.view(1, 512, 16, 16) * 2 ** 0.5 noise = noise_dict[w] out = out + self.stylegan_decoderdotstyle_convsdot3dotweight * noise out = out + self.stylegan_decoderdotstyle_convsdot3dotbias out = self.activate(out) out0 = out #to_rgb x = out style = latent[:, i + 2] style = self.stylegan_decoderdotto_rgbsdot1dotmodulated_convdotmodulation(style).view(1, 1, 512, 1, 1) weight = self.stylegan_decoderdotto_rgbsdot1dotmodulated_convdotweight * style weight = weight.view(3, 512, 1, 1) #b, c, h, w = x.shape x = x.view(1, b * c, h, w) out = F.conv2d(x, weight, padding=0, groups=b) out = out.view(1, 3, 16, 16) out = out + self.stylegan_decoderdotto_rgbsdot1dotbias skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False) skip = out + skip # i = 5 out = out0 x = out b, c, h, w = 1, 512, 32, 32 i += 2 style = latent[:, i] #conv1 style = self.stylegan_decoderdotstyle_convsdot4dotmodulated_convdotmodulation(latent[:, i]).view(b, 1, c, 1, 1) weight = self.stylegan_decoderdotstyle_convsdot4dotmodulated_convdotweight * style # self.demodulate = True: demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08) weight = weight * demod.view(b, 512, 1, 1, 1) # self.sample_mode == 'upsample' x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) weight = weight.view(b * 512, c, 3, 3) b, c, h, w = x.shape x = x.view(1, b * c, h, w) out = F.conv2d(x, weight, padding=1, groups=b) out = out.view(1, 512, 32, 32) * 2 ** 0.5 b, _, h, w = 1, _, 32, 32 noise = noise_dict[w] out = out + self.stylegan_decoderdotstyle_convsdot4dotweight * noise out = out + self.stylegan_decoderdotstyle_convsdot4dotbias out = self.activate(out) out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1) out_sft = out_sft * conditions[i - 1] + conditions[i] out = torch.cat([out_same, out_sft], dim=1) #conv2 style = self.stylegan_decoderdotstyle_convsdot5dotmodulated_convdotmodulation(latent[:, i + 1]).view(1, 1, 512, 1, 1) weight = self.stylegan_decoderdotstyle_convsdot5dotmodulated_convdotweight * style # self.demodulate = True: demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08) weight = weight * demod.view(1, 512, 1, 1, 1) weight = weight.view(b * 512, 512, 3, 3) out = F.conv2d(out, weight, padding=1, groups=b) out = out.view(1, 512, 32, 32) * 2 ** 0.5 noise = noise_dict[w] out = out + self.stylegan_decoderdotstyle_convsdot5dotweight * noise out = out + self.stylegan_decoderdotstyle_convsdot5dotbias out = self.activate(out) out0 = out #to_rgb x = out style = latent[:, i + 2] style = self.stylegan_decoderdotto_rgbsdot2dotmodulated_convdotmodulation(style).view(1, 1, 512, 1, 1) weight = self.stylegan_decoderdotto_rgbsdot2dotmodulated_convdotweight * style weight = weight.view(3, 512, 1, 1) #b, c, h, w = x.shape x = x.view(1, b * c, h, w) out = F.conv2d(x, weight, padding=0, groups=b) out = out.view(1, 3, 32, 32) out = out + self.stylegan_decoderdotto_rgbsdot2dotbias skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False) skip = out + skip # i = 7 out = out0 x = out b, c, h, w = 1, 512, 32, 32 i += 2 style = latent[:, i] # 数值一致 #conv1 style = self.stylegan_decoderdotstyle_convsdot6dotmodulated_convdotmodulation(latent[:, i]).view(b, 1, c, 1, 1) weight = self.stylegan_decoderdotstyle_convsdot6dotmodulated_convdotweight * style # self.demodulate = True: demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08) weight = weight * demod.view(b, 512, 1, 1, 1) # self.sample_mode == 'upsample' x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) weight = weight.view(b * 512, c, 3, 3) b, c, h, w = x.shape x = x.view(1, b * c, h, w) out = F.conv2d(x, weight, padding=1, groups=b) out = out.view(1, 512, 64, 64) * 2 ** 0.5 b, _, h, w = 1, _, 64, 64 noise = noise_dict[w] out = out + self.stylegan_decoderdotstyle_convsdot7dotweight * noise out = out + self.stylegan_decoderdotstyle_convsdot7dotbias out = self.activate(out) out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1) out_sft = out_sft * conditions[i - 1] + conditions[i] out = torch.cat([out_same, out_sft], dim=1) #conv2 style = self.stylegan_decoderdotstyle_convsdot7dotmodulated_convdotmodulation(latent[:, i + 1]).view(1, 1, 512, 1, 1) weight = self.stylegan_decoderdotstyle_convsdot7dotmodulated_convdotweight * style # self.demodulate = True: demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08) weight = weight * demod.view(1, 512, 1, 1, 1) weight = weight.view(b * 512, 512, 3, 3) out = F.conv2d(out, weight, padding=1, groups=b) out = out.view(1, 512, 64, 64) * 2 ** 0.5 noise = noise_dict[w] out = out + self.stylegan_decoderdotstyle_convsdot7dotweight * noise out = out + self.stylegan_decoderdotstyle_convsdot7dotbias out = self.activate(out) out0 = out #to_rgb x = out style = latent[:, i + 2] style = self.stylegan_decoderdotto_rgbsdot3dotmodulated_convdotmodulation(style).view(1, 1, 512, 1, 1) weight = self.stylegan_decoderdotto_rgbsdot3dotmodulated_convdotweight * style weight = weight.view(3, 512, 1, 1) #b, c, h, w = x.shape x = x.view(1, b * c, h, w) out = F.conv2d(x, weight, padding=0, groups=b) out = out.view(1, 3, 64, 64) out = out + self.stylegan_decoderdotto_rgbsdot3dotbias skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False) skip = out + skip # i = 9 out = out0 x = out b, c, h, w = 1, 512, 64, 64 i += 2 style = latent[:, i] # 数值一致 #conv1 style = self.stylegan_decoderdotstyle_convsdot8dotmodulated_convdotmodulation(latent[:, i]).view(b, 1, c, 1, 1) weight = self.stylegan_decoderdotstyle_convsdot8dotmodulated_convdotweight * style # self.demodulate = True: demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08) weight = weight * demod.view(b, 256, 1, 1, 1) # self.sample_mode == 'upsample' x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) weight = weight.view(b * 256, c, 3, 3) b, c, h, w = x.shape x = x.view(1, b * c, h, w) out = F.conv2d(x, weight, padding=1, groups=b) out = out.view(1, 256, 128, 128) * 2 ** 0.5 b, _, h, w = 1, _, 128, 128 noise = noise_dict[w] out = out + self.stylegan_decoderdotstyle_convsdot8dotweight * noise out = out + self.stylegan_decoderdotstyle_convsdot8dotbias out = self.activate(out) out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1) out_sft = out_sft * conditions[i - 1] + conditions[i] out = torch.cat([out_same, out_sft], dim=1) #conv2 style = self.stylegan_decoderdotstyle_convsdot9dotmodulated_convdotmodulation(latent[:, i + 1]).view(1, 1, 256, 1, 1) weight = self.stylegan_decoderdotstyle_convsdot9dotmodulated_convdotweight * style # self.demodulate = True: demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08) weight = weight * demod.view(1, 256, 1, 1, 1) weight = weight.view(b * 256, 256, 3, 3) out = F.conv2d(out, weight, padding=1, groups=b) out = out.view(1, 256, 128, 128) * 2 ** 0.5 noise = noise_dict[w] out = out + self.stylegan_decoderdotstyle_convsdot9dotweight * noise out = out + self.stylegan_decoderdotstyle_convsdot9dotbias out = self.activate(out) out0 = out #to_rgb x = out style = latent[:, i + 2] style = self.stylegan_decoderdotto_rgbsdot4dotmodulated_convdotmodulation(style).view(1, 1, 256, 1, 1) weight = self.stylegan_decoderdotto_rgbsdot4dotmodulated_convdotweight * style weight = weight.view(3, 256, 1, 1) b, c, h, w = x.shape x = x.view(1, b * c, h, w) out = F.conv2d(x, weight, padding=0, groups=b) out = out.view(1, 3, 128, 128) out = out + self.stylegan_decoderdotto_rgbsdot4dotbias skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False) skip = out + skip # i = 11 out = out0 x = out b, c, h, w = 1, 256, 128, 128 i += 2 style = latent[:, i] # 数值一致 style = self.stylegan_decoderdotstyle_convsdot10dotmodulated_convdotmodulation(latent[:, i]).view(b, 1, c, 1, 1) #conv1 weight = self.stylegan_decoderdotstyle_convsdot10dotmodulated_convdotweight * style # self.demodulate = True: demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08) weight = weight * demod.view(b, 128, 1, 1, 1) weight = weight.view(b * 128, 256, 3, 3) # self.sample_mode == 'upsample' x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) b, c, h, w = x.shape out = F.conv2d(x, weight, padding=1, groups=b) out = out.view(1, 128, 256, 256) * 2 ** 0.5 b, _, h, w = 1, _, 256, 256 noise = noise_dict[w] out = out + self.stylegan_decoderdotstyle_convsdot10dotweight * noise out = out + self.stylegan_decoderdotstyle_convsdot10dotbias out = self.activate(out) out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1) out_sft = out_sft * conditions[i - 1] + conditions[i] out = torch.cat([out_same, out_sft], dim=1) #conv2 style = self.stylegan_decoderdotstyle_convsdot11dotmodulated_convdotmodulation(latent[:, i + 1]).view(1, 1, 128, 1, 1) weight = self.stylegan_decoderdotstyle_convsdot11dotmodulated_convdotweight * style # self.demodulate = True: demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08) weight = weight * demod.view(1, 128, 1, 1, 1) weight = weight.view(b * 128, 128, 3, 3) out = F.conv2d(out, weight, padding=1, groups=b) out = out.view(1, 128, 256, 256) * 2 ** 0.5 noise = noise_dict[w] out = out + self.stylegan_decoderdotstyle_convsdot11dotweight * noise out = out + self.stylegan_decoderdotstyle_convsdot11dotbias out = self.activate(out) out0 = out #to_rgb x = out style = latent[:, i + 2] style = self.stylegan_decoderdotto_rgbsdot5dotmodulated_convdotmodulation(style).view(1, 1, 128, 1, 1) weight = self.stylegan_decoderdotto_rgbsdot5dotmodulated_convdotweight * style weight = weight.view(3, 128, 1, 1) b, c, h, w = x.shape x = x.view(1, b * c, h, w) out = F.conv2d(x, weight, padding=0, groups=b) out = out.view(1, 3, 256, 256) out = out + self.stylegan_decoderdotto_rgbsdot5dotbias skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False) skip = out + skip # i = 13 out = out0 x = out b, c, h, w = 1, 128, 256, 256 i += 2 style = latent[:, i] # 数值一致 style = self.stylegan_decoderdotstyle_convsdot12dotmodulated_convdotmodulation(latent[:, i]).view(b, 1, c, 1, 1) #conv1 weight = self.stylegan_decoderdotstyle_convsdot12dotmodulated_convdotweight * style # self.demodulate = True: demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08) weight = weight * demod.view(b, 64, 1, 1, 1) weight = weight.view(b * 64, 128, 3, 3) # self.sample_mode == 'upsample' x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) b, c, h, w = x.shape out = F.conv2d(x, weight, padding=1, groups=b) out = out.view(1, 64, 512, 512) * 2 ** 0.5 b, _, h, w = 1, _, 512, 512 noise = noise_dict[w] out = out + self.stylegan_decoderdotstyle_convsdot12dotweight * noise out = out + self.stylegan_decoderdotstyle_convsdot12dotbias out = self.activate(out) out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1) out_sft = out_sft * conditions[i - 1] + conditions[i] out = torch.cat([out_same, out_sft], dim=1) #conv2 style = self.stylegan_decoderdotstyle_convsdot13dotmodulated_convdotmodulation(latent[:, i + 1]).view(1, 1, 64, 1, 1) weight = self.stylegan_decoderdotstyle_convsdot13dotmodulated_convdotweight * style # self.demodulate = True: demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08) weight = weight * demod.view(1, 64, 1, 1, 1) weight = weight.view(b * 64, 64, 3, 3) out = F.conv2d(out, weight, padding=1, groups=b) out = out.view(1, 64, 512, 512) * 2 ** 0.5 noise = noise_dict[w] out = out + self.stylegan_decoderdotstyle_convsdot13dotweight * noise out = out + self.stylegan_decoderdotstyle_convsdot13dotbias out = self.activate(out) out0 = out #to_rgb x = out style = latent[:, i + 2] style = self.stylegan_decoderdotto_rgbsdot6dotmodulated_convdotmodulation(style).view(1, 1, 64, 1, 1) weight = self.stylegan_decoderdotto_rgbsdot6dotmodulated_convdotweight * style weight = weight.view(3, 64, 1, 1) b, c, h, w = x.shape x = x.view(1, b * c, h, w) out = F.conv2d(x, weight, padding=0, groups=b) out = out.view(1, 3, 512, 512) out = out + self.stylegan_decoderdotto_rgbsdot6dotbias skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False) skip = out + skip return skip ================================================ FILE: README.md ================================================ # GFPGAN-onnxruntime-demo This is the onnxruntime inference code for GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior (CVPR 2021). Official code: https://github.com/TencentARC/GFPGAN ## The following issues are addressed: 1、noise = out.new_empty(b, 1, h, w).normal_() in stylegan2_clean_arch.py can‘t be supported in ONNX. I move it out the Model class, like noise = Noise[i], the Noise is a list or others which prestores generated random noise. 2、the forward function of Model is very bad, especially stylegan, so many " if else " and class be reused. Like the StyleConv " in "useself.style_convs.append StyleConv ...". So I rewrite and make it in single forward. ## convert torch to onnx. ``` wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth python torch2onnx.py --src_model_path ./GFPGANv1.3.pth --dst_model_path ./GFPGANv1.3.onnx --img_size 512 ``` ## run onnx demo. ``` python demo_onnx.py --model_path GFPGANv1.3.onnx --image_path ./cropped_faces/Adele_crop.png --save_path Adele_v3.jpg ``` | input | output| | :-: |:-:| ||| ||| ||| ================================================ FILE: demo_onnx.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- import os import sys import argparse import cv2 import numpy as np import timeit import onnxruntime class GFPGANFaceAugment: def __init__(self, model_path, use_gpu = False): self.ort_session = onnxruntime.InferenceSession(model_path) self.net_input_name = self.ort_session.get_inputs()[0].name _,self.net_input_channels,self.net_input_height,self.net_input_width = self.ort_session.get_inputs()[0].shape self.net_output_count = len(self.ort_session.get_outputs()) self.face_size = 512 self.face_template = np.array([[192, 240], [319, 240], [257, 371]]) * (self.face_size / 512.0) self.upscale_factor = 2 self.affine = False self.affine_matrix = None def pre_process(self, img): img = cv2.resize(img, (int(img.shape[1] / 2), int(img.shape[0] / 2))) img = cv2.resize(img, (self.face_size, self.face_size)) img = img / 255.0 img = img.astype('float32') img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) img[:,:,0] = (img[:,:,0]-0.5)/0.5 img[:,:,1] = (img[:,:,1]-0.5)/0.5 img[:,:,2] = (img[:,:,2]-0.5)/0.5 img = np.float32(img[np.newaxis,:,:,:]) img = img.transpose(0, 3, 1, 2) return img def post_process(self, output, height, width): output = output.clip(-1,1) output = (output + 1) / 2 output = output.transpose(1, 2, 0) output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) output = (output * 255.0).round() if self.affine: inverse_affine = cv2.invertAffineTransform(self.affine_matrix) inverse_affine *= self.upscale_factor if self.upscale_factor > 1: extra_offset = 0.5 * self.upscale_factor else: extra_offset = 0 inverse_affine[:, 2] += extra_offset inv_restored = cv2.warpAffine(output, inverse_affine, (width, height)) mask = np.ones((self.face_size, self.face_size), dtype=np.float32) inv_mask = cv2.warpAffine(mask, inverse_affine, (width, height)) inv_mask_erosion = cv2.erode( inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8)) pasted_face = inv_mask_erosion[:, :, None] * inv_restored total_face_area = np.sum(inv_mask_erosion) # compute the fusion edge based on the area of face w_edge = int(total_face_area**0.5) // 20 erosion_radius = w_edge * 2 inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8)) blur_size = w_edge * 2 inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0) inv_soft_mask = inv_soft_mask[:, :, None] output = pasted_face else: inv_soft_mask = np.ones((height, width, 1), dtype=np.float32) output = cv2.resize(output, (width, height)) return output, inv_soft_mask def forward(self, img): height, width = img.shape[0], img.shape[1] img = self.pre_process(img) t = timeit.default_timer() ort_inputs = {self.ort_session.get_inputs()[0].name: img} ort_outs = self.ort_session.run(None, ort_inputs) output = ort_outs[0][0] output, inv_soft_mask = self.post_process(output, height, width) print('infer time:',timeit.default_timer()-t) output = output.astype(np.uint8) return output, inv_soft_mask if __name__ == "__main__": parser = argparse.ArgumentParser("onnxruntime demo") parser.add_argument('--model_path', type=str, default=None, help='model path') parser.add_argument('--image_path', type=str, default=None, help='input image path') parser.add_argument('--save_path', type=str, default="output.jpg", help='output image path') args = parser.parse_args() faceaugment = GFPGANFaceAugment(model_path=args.model_path) image = cv2.imread(args.image_path, 1) output, _ = faceaugment.forward(image) cv2.imwrite(args.save_path, output) # python demo_onnx.py --model_path GFPGANv1.4.onnx --image_path ./cropped_faces/Adele_crop.png # python demo_onnx.py --model_path GFPGANv1.2.onnx --image_path ./cropped_faces/Adele_crop.png --save_path Adele_v2.jpg # python demo_onnx.py --model_path GFPGANv1.2.onnx --image_path ./cropped_faces/Julia_Roberts_crop.png --save_path Julia_Roberts_v2.jpg # python demo_onnx.py --model_path GFPGANv1.2.onnx --image_path ./cropped_faces/Justin_Timberlake_crop.png --save_path Justin_Timberlake_v2.jpg # python demo_onnx.py --model_path GFPGANv1.2.onnx --image_path ./cropped_faces/Paris_Hilton_crop.png --save_path Paris_Hilton_v2.jpg ================================================ FILE: noise_main.py ================================================ import torch noise_dict = {} size = [(1, 1, 4, 4),(1, 1, 8, 8),(1, 1, 16, 16),(1, 1, 32, 32),(1, 1, 64, 64),(1, 1, 128, 128),(1, 1, 256, 256),(1, 1, 512, 512)] for s in size: out = torch.rand(s)#.cuda() noise = out.new_empty(s).normal_() #print(s[2]) noise_dict[s[2]] = noise #print(noise_dict) ================================================ FILE: torch2onnx.py ================================================ # -*- coding: utf-8 -*- #import cv2 import numpy as np import time import torch import pdb from collections import OrderedDict import sys sys.path.append('.') sys.path.append('./lib') import torch.nn as nn from torch.autograd import Variable import onnxruntime import timeit import argparse from GFPGANReconsitution import GFPGAN parser = argparse.ArgumentParser("ONNX converter") parser.add_argument('--src_model_path', type=str, default=None, help='src model path') parser.add_argument('--dst_model_path', type=str, default=None, help='dst model path') parser.add_argument('--img_size', type=int, default=None, help='img size') args = parser.parse_args() #device = torch.device('cuda') model_path = args.src_model_path onnx_model_path = args.dst_model_path img_size = args.img_size model = GFPGAN()#.cuda() x = torch.rand(1, 3, 512, 512)#.cuda() state_dict = torch.load(model_path)['params_ema'] new_state_dict = OrderedDict() for k, v in state_dict.items(): # stylegan_decoderdotto_rgbsdot1dotmodulated_convdotbias if "stylegan_decoder" in k: k = k.replace('.', 'dot') new_state_dict[k] = v k = k.replace('dotweight', '.weight') k = k.replace('dotbias', '.bias') new_state_dict[k] = v else: new_state_dict[k] = v model.load_state_dict(new_state_dict, strict=False) model.eval() torch.onnx.export(model, x, onnx_model_path, export_params=True, opset_version=11, do_constant_folding=True, input_names = ['input'],output_names = []) #### try: original_model = onnx.load(onnx_model_path) passes = ['fuse_bn_into_conv'] optimized_model = optimizer.optimize(original_model, passes) onnx.save(optimized_model, onnx_model_path) except: print('skip optimize.') #### ort_session = onnxruntime.InferenceSession(onnx_model_path) for var in ort_session.get_inputs(): print(var.name) for var in ort_session.get_outputs(): print(var.name) _,_,input_h,input_w = ort_session.get_inputs()[0].shape t = timeit.default_timer() img = np.zeros((input_h,input_w,3)) img = (np.transpose(np.float32(img[:,:,:,np.newaxis]), (3,2,0,1)) )#*self.scale img = np.ascontiguousarray(img) # ort_inputs = {ort_session.get_inputs()[0].name: img} ort_outs = ort_session.run(None, ort_inputs) print('onnxruntime infer time:', timeit.default_timer()-t) print(ort_outs[0].shape) # python torch2onnx.py --src_model_path ./experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth --dst_model_path ./GFPGAN.onnx --img_size 512 # 新版本 # wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth # python torch2onnx.py --src_model_path ./GFPGANv1.4.pth --dst_model_path ./GFPGANv1.4.onnx --img_size 512 # python torch2onnx.py --src_model_path ./GFPGANCleanv1-NoCE-C2.pth --dst_model_path ./GFPGANv1.2.onnx --img_size 512