Repository: xuanandsix/GFPGAN-onnxruntime-demo
Branch: main
Commit: 3728cee8138d
Files: 5
Total size: 45.6 KB
Directory structure:
gitextract_ll211iz8/
├── GFPGANReconsitution.py
├── README.md
├── demo_onnx.py
├── noise_main.py
└── torch2onnx.py
================================================
FILE CONTENTS
================================================
================================================
FILE: GFPGANReconsitution.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import sys
import argparse
import cv2
import numpy as np
import torch
import timeit
#import onnxruntime
from torch.nn import functional as F
from torchvision.transforms.functional import normalize
from torch import nn
import math
from collections import OrderedDict
from noise_main import noise_dict
class ResBlock(nn.Module):
    """Residual block that also resamples its input by a factor of two.

    The main path is conv3x3 -> LeakyReLU -> bilinear resize -> conv3x3 ->
    LeakyReLU; the skip path is bilinear resize -> bias-free conv1x1. The two
    paths are summed.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        mode (str): ``'down'`` halves the spatial size, ``'up'`` doubles it.

    Raises:
        ValueError: If ``mode`` is neither ``'down'`` nor ``'up'``.
    """

    def __init__(self, in_channels, out_channels, mode='down'):
        super(ResBlock, self).__init__()

        # Validate eagerly: otherwise ``scale_factor`` would simply be missing
        # and the failure would only surface as an AttributeError in forward().
        if mode == 'down':
            self.scale_factor = 0.5
        elif mode == 'up':
            self.scale_factor = 2
        else:
            raise ValueError(f"mode must be 'down' or 'up', got {mode!r}")

        self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
        self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
        self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False)

    def forward(self, x):
        """Return the resampled residual output for feature map ``x``."""
        out = F.leaky_relu_(self.conv1(x), negative_slope=0.2)
        # upsample/downsample the main path
        out = F.interpolate(out, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
        out = F.leaky_relu_(self.conv2(out), negative_slope=0.2)
        # skip path: resample first, then project channels with the 1x1 conv
        x = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
        skip = self.skip(x)
        out = out + skip
        return out
class ConstantInput(nn.Module):
    """Learned constant tensor, tiled along the batch dimension on demand.

    Args:
        num_channel (int): Channel number of constant input.
        size (int): Spatial size (height and width) of constant input.
    """

    def __init__(self, num_channel, size):
        super().__init__()
        # Single learnable map of shape [1, num_channel, size, size].
        initial = torch.randn(1, num_channel, size, size)
        self.weight = nn.Parameter(initial)

    def forward(self, batch):
        """Return the constant repeated ``batch`` times along dim 0."""
        return self.weight.repeat(batch, 1, 1, 1)
class GFPGAN(nn.Module):
    """GFPGAN generator written as one flat, branch-free forward pass.

    The network is a U-Net encoder/decoder that produces (a) a stack of 16
    style vectors of width 512 and (b) per-resolution scale/shift maps (SFT
    conditions), followed by a StyleGAN2-style decoder whose modulated
    convolutions are driven by those style vectors.

    The decoder's layers are stored as flat attributes whose names spell the
    dotted path with ``dot`` (e.g. ``stylegan_decoderdotstyle_convsdot0dotbias``).
    NOTE(review): presumably this lets checkpoint keys be mapped by replacing
    ``.`` with ``dot`` -- confirm against the conversion script.

    Limitations visible in the code: all reshapes hard-code a batch size of 1,
    and the encoder must reduce the input to 4x4 after seven halvings, so the
    input has to be a ``(1, 3, 512, 512)`` tensor. Per-layer noise is read from
    the module-level ``noise_dict`` (keyed by spatial size) instead of being
    sampled inside the graph.
    """

    # Output channels of the seven decoder stages (8x8 up to 512x512).
    _STAGE_CHANNELS = (512, 512, 512, 512, 256, 128, 64)

    def __init__(self):
        super(GFPGAN, self).__init__()
        unet_narrow = 0.5
        channel_multiplier = 2
        channels = {
            '4': int(512 * unet_narrow),
            '8': int(512 * unet_narrow),
            '16': int(512 * unet_narrow),
            '32': int(512 * unet_narrow),
            '64': int(256 * channel_multiplier * unet_narrow),
            '128': int(128 * channel_multiplier * unet_narrow),
            '256': int(64 * channel_multiplier * unet_narrow),
            '512': int(32 * channel_multiplier * unet_narrow),
            '1024': int(16 * channel_multiplier * unet_narrow)
        }

        # ---- U-Net encoder: 512x512 -> 4x4 ----
        self.conv_body_first = nn.Conv2d(3, 32, 1)
        self.conv_body_down = nn.ModuleList()
        in_channels = channels['512']
        for i in range(9, 2, -1):
            out_channels = channels[f'{2**(i - 1)}']
            self.conv_body_down.append(ResBlock(in_channels, out_channels, mode='down'))
            in_channels = out_channels

        num_style_feat = 512
        self.final_conv = nn.Conv2d(in_channels, channels['4'], 3, 1, 1)
        # (log2(512) * 2 - 2) = 16 style vectors of width num_style_feat.
        linear_out_channel = (int(math.log(512, 2)) * 2 - 2) * num_style_feat
        self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel)

        # ---- U-Net decoder: 4x4 -> 512x512 ----
        in_channels = channels['4']
        self.conv_body_up = nn.ModuleList()
        for i in range(3, 9 + 1):
            out_channels = channels[f'{2**i}']
            self.conv_body_up.append(ResBlock(in_channels, out_channels, mode='up'))
            in_channels = out_channels

        # ---- SFT heads: one scale and one shift map per decoder resolution ----
        self.condition_scale = nn.ModuleList()
        self.condition_shift = nn.ModuleList()
        for i in range(3, 9 + 1):
            out_channels = channels[f'{2**i}']
            sft_out_channels = out_channels
            self.condition_scale.append(
                nn.Sequential(
                    nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
                    nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
            self.condition_shift.append(
                nn.Sequential(
                    nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
                    nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))

        # ---- StyleGAN2 decoder, registered under flat "dot" names ----
        self.stylegan_decoderdotconstant_input = ConstantInput(512, size=4)
        self._add_style_conv('stylegan_decoderdotstyle_conv1', 512, 512)
        self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self._add_to_rgb('stylegan_decoderdotto_rgb1', 512)

        c_in = 512
        for stage, c_out in enumerate(self._STAGE_CHANNELS):
            # Each stage owns two style convs (upsampling + plain) and one toRGB.
            self._add_style_conv(f'stylegan_decoderdotstyle_convsdot{2 * stage}', c_in, c_out)
            self._add_style_conv(f'stylegan_decoderdotstyle_convsdot{2 * stage + 1}', c_out, c_out)
            self._add_to_rgb(f'stylegan_decoderdotto_rgbsdot{stage}', c_out)
            c_in = c_out

    def _add_style_conv(self, prefix, c_in, c_out):
        """Register modulation, conv weight, noise weight and bias of one style conv."""
        setattr(self, prefix + 'dotmodulated_convdotmodulation', nn.Linear(512, c_in, bias=True))
        # The init scale uses c_out, exactly as the hand-written original did;
        # it only matters for random initialisation.
        setattr(self, prefix + 'dotmodulated_convdotweight',
                nn.Parameter(torch.randn(1, c_out, c_in, 3, 3) / math.sqrt(c_out * 3**2)))
        setattr(self, prefix + 'dotweight', nn.Parameter(torch.zeros(1)))  # for noise injection
        setattr(self, prefix + 'dotbias', nn.Parameter(torch.zeros(1, c_out, 1, 1)))

    def _add_to_rgb(self, prefix, c_in):
        """Register modulation, 1x1 conv weight and bias of one toRGB layer."""
        setattr(self, prefix + 'dotmodulated_convdotmodulation', nn.Linear(512, c_in, bias=True))
        setattr(self, prefix + 'dotmodulated_convdotweight',
                nn.Parameter(torch.randn(1, 3, c_in, 1, 1) / math.sqrt(c_in * 1**2)))
        setattr(self, prefix + 'dotbias', nn.Parameter(torch.zeros(1, 3, 1, 1)))

    def _style_conv(self, prefix, x, style, c_in, c_out, size, upsample):
        """Apply one modulated + demodulated 3x3 conv, then noise, bias and LeakyReLU.

        ``size`` is the spatial size of the output; when ``upsample`` is true
        the input is first bilinearly enlarged by 2 to reach it.
        """
        modulation = getattr(self, prefix + 'dotmodulated_convdotmodulation')
        weight = getattr(self, prefix + 'dotmodulated_convdotweight')
        # weight modulation: (1, c_out, c_in, k, k) scaled per input channel
        weight = weight * modulation(style).view(1, 1, c_in, 1, 1)
        # demodulation: normalise every output filter
        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
        weight = weight * demod.view(1, c_out, 1, 1, 1)
        weight = weight.view(c_out, c_in, 3, 3)

        if upsample:
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        x = x.view(1, c_in, size, size)
        out = F.conv2d(x, weight, padding=1, groups=1)
        out = out.view(1, c_out, size, size) * 2 ** 0.5
        # Pre-generated noise looked up by resolution, scaled by a learned scalar.
        out = out + getattr(self, prefix + 'dotweight') * noise_dict[size]
        out = out + getattr(self, prefix + 'dotbias')
        return self.activate(out)

    def _to_rgb(self, prefix, x, style, c_in, size):
        """Project features to 3 channels with a modulated (not demodulated) 1x1 conv."""
        modulation = getattr(self, prefix + 'dotmodulated_convdotmodulation')
        weight = getattr(self, prefix + 'dotmodulated_convdotweight')
        weight = weight * modulation(style).view(1, 1, c_in, 1, 1)
        weight = weight.view(3, c_in, 1, 1)
        out = F.conv2d(x.view(1, c_in, size, size), weight, padding=0, groups=1)
        out = out.view(1, 3, size, size)
        return out + getattr(self, prefix + 'dotbias')

    def forward(self, x):
        """Restore a face crop.

        Args:
            x (Tensor): Input of shape ``(1, 3, 512, 512)``.

        Returns:
            Tensor: Restored image of shape ``(1, 3, 512, 512)``.
        """
        # ---- encoder ----
        feat = F.leaky_relu_(self.conv_body_first(x), negative_slope=0.2)
        unet_skips = []
        for down in self.conv_body_down:
            feat = down(feat)
            unet_skips.insert(0, feat)  # coarsest feature first
        feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2)

        # ---- style code: (1, 16, 512) ----
        style_code = self.final_linear(feat.view(feat.size(0), -1))
        latent = style_code.view(style_code.size(0), -1, 512)

        # ---- U-Net decoder producing the SFT conditions ----
        # conditions[2 * s] is the scale and conditions[2 * s + 1] the shift of stage s.
        conditions = []
        for i in range(7):
            feat = feat + unet_skips[i]
            feat = self.conv_body_up[i](feat)
            conditions.append(self.condition_scale[i](feat).clone())
            conditions.append(self.condition_shift[i](feat).clone())

        # ---- StyleGAN2 decoder ----
        out = self.stylegan_decoderdotconstant_input(latent.shape[0])
        out = self._style_conv('stylegan_decoderdotstyle_conv1', out, latent[:, 0],
                               512, 512, 4, upsample=False)
        skip = self._to_rgb('stylegan_decoderdotto_rgb1', out, latent[:, 1], 512, 4)

        c_in = 512
        for stage, c_out in enumerate(self._STAGE_CHANNELS):
            i = 2 * stage + 1          # latent index of this stage's first conv
            size = 8 << stage          # 8, 16, ..., 512

            # conv1: upsampling style conv. Noise weight and bias come from the
            # same layer as the conv weight (the 64x64 stage used to read them
            # from the following layer by mistake).
            out = self._style_conv(f'stylegan_decoderdotstyle_convsdot{2 * stage}',
                                   out, latent[:, i], c_in, c_out, size, upsample=True)

            # SFT: modulate only the second half of the channels.
            out_same, out_sft = torch.split(out, c_out // 2, dim=1)
            out_sft = out_sft * conditions[i - 1] + conditions[i]
            out = torch.cat([out_same, out_sft], dim=1)

            # conv2: style conv at the same resolution.
            out = self._style_conv(f'stylegan_decoderdotstyle_convsdot{2 * stage + 1}',
                                   out, latent[:, i + 1], c_out, c_out, size, upsample=False)

            # toRGB, accumulated onto the upsampled running image.
            rgb = self._to_rgb(f'stylegan_decoderdotto_rgbsdot{stage}',
                               out, latent[:, i + 2], c_out, size)
            skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)
            skip = rgb + skip
            c_in = c_out

        return skip
================================================
FILE: README.md
================================================
# GFPGAN-onnxruntime-demo
This is the onnxruntime inference code for GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior (CVPR 2021). Official code: https://github.com/TencentARC/GFPGAN
## The following issues are addressed:
1. `noise = out.new_empty(b, 1, h, w).normal_()` in stylegan2_clean_arch.py is not supported by ONNX. I moved it out of the model class and replaced it with a lookup such as `noise = Noise[i]`, where `Noise` is a list (or a similar container) that stores pre-generated random noise.
2. The forward function of the original model is hard to export, especially the StyleGAN decoder: it contains many `if`/`else` branches and reuses the same classes many times (for example `StyleConv` in `self.style_convs.append(StyleConv(...))`). So I rewrote the model as a single flat forward function.
## convert torch to onnx.
```
wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth
python torch2onnx.py --src_model_path ./GFPGANv1.3.pth --dst_model_path ./GFPGANv1.3.onnx --img_size 512
```
## run onnx demo.
```
python demo_onnx.py --model_path GFPGANv1.3.onnx --image_path ./cropped_faces/Adele_crop.png --save_path Adele_v3.jpg
```
| input | output|
| :-: |:-:|
|
|
|
|
|
|
|
|
|
================================================
FILE: demo_onnx.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import sys
import argparse
import cv2
import numpy as np
import timeit
import onnxruntime
class GFPGANFaceAugment:
    """Run a GFPGAN ONNX model on a single image with onnxruntime.

    The class wraps three steps: ``pre_process`` (BGR image -> normalised
    NCHW float32 tensor), the onnxruntime session call, and ``post_process``
    (network output -> BGR image at the original resolution).

    Args:
        model_path: Path of the ``.onnx`` model handed to
            ``onnxruntime.InferenceSession``.
        use_gpu: When true, the CUDA execution provider is requested ahead
            of the CPU provider, provided the installed onnxruntime build
            offers it. When false, only the CPU provider is used.
    """

    def __init__(self, model_path, use_gpu=False):
        # Pass an explicit provider list: it honours ``use_gpu`` and avoids
        # relying on whatever default a given onnxruntime build picks.
        wanted = ['CPUExecutionProvider']
        if use_gpu:
            wanted.insert(0, 'CUDAExecutionProvider')
        available = onnxruntime.get_available_providers()
        providers = [p for p in wanted if p in available]
        self.ort_session = onnxruntime.InferenceSession(model_path, providers=providers)

        net_input = self.ort_session.get_inputs()[0]
        self.net_input_name = net_input.name
        _, self.net_input_channels, self.net_input_height, self.net_input_width = net_input.shape
        self.net_output_count = len(self.ort_session.get_outputs())

        # Side length the image is resized to before inference.
        self.face_size = 512
        # Three 2-D points scaled with face_size; not read anywhere else in
        # this class (presumably reference landmark positions - TODO confirm).
        self.face_template = np.array([[192, 240], [319, 240], [257, 371]]) * (self.face_size / 512.0)
        self.upscale_factor = 2
        # The affine paste-back path in post_process only runs when a caller
        # sets ``affine`` to True and supplies ``affine_matrix``.
        self.affine = False
        self.affine_matrix = None

    def pre_process(self, img):
        """Convert a BGR image into the network input tensor.

        Args:
            img: H x W x 3 array as returned by ``cv2.imread``.

        Returns:
            A contiguous float32 array of shape (1, 3, face_size, face_size)
            in RGB order with values mapped from [0, 255] to [-1, 1].
        """
        # Halve the resolution first, then resize to the network size.
        # Clamp to 1 so a 1-pixel-wide/high image does not request a
        # zero-sized target, which cv2.resize rejects.
        half_size = (max(1, img.shape[1] // 2), max(1, img.shape[0] // 2))
        img = cv2.resize(img, half_size)
        img = cv2.resize(img, (self.face_size, self.face_size))
        img = (img / 255.0).astype('float32')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # Same (x - 0.5) / 0.5 normalisation for every channel.
        img = (img - 0.5) / 0.5
        img = np.float32(img[np.newaxis, :, :, :]).transpose(0, 3, 1, 2)
        # The transpose yields a strided view; hand the session a dense array.
        return np.ascontiguousarray(img)

    def post_process(self, output, height, width):
        """Turn one network output into a BGR image of the requested size.

        Args:
            output: Array of shape (3, H, W), RGB, nominally in [-1, 1].
            height: Height of the returned image.
            width: Width of the returned image.

        Returns:
            Tuple ``(image, inv_soft_mask)``. ``image`` is a float array with
            rounded values in [0, 255]; ``inv_soft_mask`` has shape
            (height, width, 1). Without the affine path the mask is all ones.
        """
        output = output.clip(-1, 1)
        output = (output + 1) / 2
        output = output.transpose(1, 2, 0)
        output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
        output = (output * 255.0).round()
        if self.affine:
            # Warp the restored face back with the inverse of affine_matrix,
            # scaled by the upscale factor.
            inverse_affine = cv2.invertAffineTransform(self.affine_matrix)
            inverse_affine *= self.upscale_factor
            if self.upscale_factor > 1:
                extra_offset = 0.5 * self.upscale_factor
            else:
                extra_offset = 0
            inverse_affine[:, 2] += extra_offset
            inv_restored = cv2.warpAffine(output, inverse_affine, (width, height))
            # Warp an all-ones mask the same way to know which pixels the
            # face covers, then erode it slightly.
            mask = np.ones((self.face_size, self.face_size), dtype=np.float32)
            inv_mask = cv2.warpAffine(mask, inverse_affine, (width, height))
            inv_mask_erosion = cv2.erode(
                inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
            pasted_face = inv_mask_erosion[:, :, None] * inv_restored
            total_face_area = np.sum(inv_mask_erosion)
            # compute the fusion edge based on the area of face
            w_edge = int(total_face_area**0.5) // 20
            erosion_radius = w_edge * 2
            inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
            blur_size = w_edge * 2
            # +1 keeps the Gaussian kernel size odd.
            inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
            inv_soft_mask = inv_soft_mask[:, :, None]
            output = pasted_face
        else:
            inv_soft_mask = np.ones((height, width, 1), dtype=np.float32)
            output = cv2.resize(output, (width, height))
        return output, inv_soft_mask

    def forward(self, img):
        """Restore ``img`` and return it at its original resolution.

        Args:
            img: H x W x 3 BGR image array.

        Returns:
            Tuple ``(output, inv_soft_mask)`` where ``output`` is a uint8 BGR
            image and ``inv_soft_mask`` is the mask from ``post_process``.
        """
        height, width = img.shape[0], img.shape[1]
        img = self.pre_process(img)
        # The printed time covers the session run plus post-processing.
        t = timeit.default_timer()
        ort_inputs = {self.net_input_name: img}
        ort_outs = self.ort_session.run(None, ort_inputs)
        # First output tensor, first (only) batch element.
        output = ort_outs[0][0]
        output, inv_soft_mask = self.post_process(output, height, width)
        print('infer time:', timeit.default_timer() - t)
        output = output.astype(np.uint8)
        return output, inv_soft_mask
if __name__ == "__main__":
    # Command-line entry point: restore one image and write the result.
    parser = argparse.ArgumentParser("onnxruntime demo")
    parser.add_argument('--model_path', type=str, default=None, help='model path')
    parser.add_argument('--image_path', type=str, default=None, help='input image path')
    parser.add_argument('--save_path', type=str, default="output.jpg", help='output image path')
    args = parser.parse_args()
    # Both paths default to None; report that clearly instead of failing
    # later inside onnxruntime / OpenCV.
    if args.model_path is None or args.image_path is None:
        parser.error('--model_path and --image_path are required')

    # cv2.imread signals an unreadable file by returning None rather than
    # raising, so check before doing any further work.
    image = cv2.imread(args.image_path, cv2.IMREAD_COLOR)
    if image is None:
        parser.error('cannot read image: %s' % args.image_path)

    faceaugment = GFPGANFaceAugment(model_path=args.model_path)
    output, _ = faceaugment.forward(image)
    # cv2.imwrite returns False when the file could not be written.
    if not cv2.imwrite(args.save_path, output):
        raise SystemExit('failed to write output image: %s' % args.save_path)
# python demo_onnx.py --model_path GFPGANv1.4.onnx --image_path ./cropped_faces/Adele_crop.png
# python demo_onnx.py --model_path GFPGANv1.2.onnx --image_path ./cropped_faces/Adele_crop.png --save_path Adele_v2.jpg
# python demo_onnx.py --model_path GFPGANv1.2.onnx --image_path ./cropped_faces/Julia_Roberts_crop.png --save_path Julia_Roberts_v2.jpg
# python demo_onnx.py --model_path GFPGANv1.2.onnx --image_path ./cropped_faces/Justin_Timberlake_crop.png --save_path Justin_Timberlake_v2.jpg
# python demo_onnx.py --model_path GFPGANv1.2.onnx --image_path ./cropped_faces/Paris_Hilton_crop.png --save_path Paris_Hilton_v2.jpg
================================================
FILE: noise_main.py
================================================
import torch
# Shapes of the per-resolution noise tensors: (1, 1, s, s) for
# s = 4, 8, 16, ..., 512.
size = [(1, 1, 2 ** p, 2 ** p) for p in range(2, 10)]

# Standard-normal noise drawn once at import time and keyed by spatial size,
# so the model can look noise up instead of sampling it inside forward().
# torch.randn samples N(0, 1) directly; no scratch tensor is needed.
noise_dict = {shape[2]: torch.randn(shape) for shape in size}
================================================
FILE: torch2onnx.py
================================================
# -*- coding: utf-8 -*-
#import cv2
import numpy as np
import time
import torch
import pdb
from collections import OrderedDict
import sys
sys.path.append('.')
sys.path.append('./lib')
import torch.nn as nn
from torch.autograd import Variable
import onnxruntime
import timeit
import argparse
from GFPGANReconsitution import GFPGAN
# ---- command-line arguments ----
parser = argparse.ArgumentParser("ONNX converter")
parser.add_argument('--src_model_path', type=str, default=None, help='src model path')
parser.add_argument('--dst_model_path', type=str, default=None, help='dst model path')
parser.add_argument('--img_size', type=int, default=None, help='img size')
args = parser.parse_args()

#device = torch.device('cuda')
model_path = args.src_model_path
onnx_model_path = args.dst_model_path
img_size = args.img_size
# The dummy export input below is fixed at 512x512, so any other value
# would otherwise be ignored without notice.
if img_size not in (None, 512):
    print('warning: --img_size=%d is ignored; the export input is fixed at 512x512' % img_size)

# ---- build the model and load the checkpoint ----
model = GFPGAN()#.cuda()
x = torch.rand(1, 3, 512, 512)#.cuda()
# map_location='cpu' lets a checkpoint saved from GPU tensors load on a
# machine without CUDA; the model itself stays on the CPU here.
# NOTE: torch.load unpickles the file - only use checkpoints you trust.
state_dict = torch.load(model_path, map_location='cpu')['params_ema']
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    # stylegan_decoderdotto_rgbsdot1dotmodulated_convdotbias
    if "stylegan_decoder" in k:
        # Register each StyleGAN decoder entry under two names: one with
        # every '.' replaced by 'dot', and one where the trailing
        # '.weight' / '.bias' is restored (presumably for entries that are
        # sub-modules in the rewritten model - verify against GFPGAN).
        k = k.replace('.', 'dot')
        new_state_dict[k] = v
        k = k.replace('dotweight', '.weight')
        k = k.replace('dotbias', '.bias')
        new_state_dict[k] = v
    else:
        new_state_dict[k] = v

# strict=False: keys that match nothing in the model are skipped.
model.load_state_dict(new_state_dict, strict=False)
model.eval()

# ---- export ----
torch.onnx.export(model, x, onnx_model_path,
                  export_params=True, opset_version=11, do_constant_folding=True,
                  input_names = ['input'],output_names = [])

####
# Optional graph optimisation (best effort).
# NOTE(review): neither `onnx` nor `optimizer` is imported in this file, so
# this step currently always ends in the except branch. Catch Exception
# (not a bare except) so Ctrl-C still works, and print the reason.
try:
    original_model = onnx.load(onnx_model_path)
    passes = ['fuse_bn_into_conv']
    optimized_model = optimizer.optimize(original_model, passes)
    onnx.save(optimized_model, onnx_model_path)
except Exception as exc:
    print('skip optimize.', repr(exc))
####

# ---- sanity check: run the exported model once with onnxruntime ----
ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=['CPUExecutionProvider'])

for var in ort_session.get_inputs():
    print(var.name)
for var in ort_session.get_outputs():
    print(var.name)
_,_,input_h,input_w = ort_session.get_inputs()[0].shape

# All-zero float32 input in (1, 3, H, W) layout.
img = np.zeros((1, 3, input_h, input_w), dtype=np.float32)
img = np.ascontiguousarray(img)
#
ort_inputs = {ort_session.get_inputs()[0].name: img}
# Time only the session run so the figure reflects inference.
t = timeit.default_timer()
ort_outs = ort_session.run(None, ort_inputs)
print('onnxruntime infer time:', timeit.default_timer()-t)
print(ort_outs[0].shape)
# python torch2onnx.py --src_model_path ./experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth --dst_model_path ./GFPGAN.onnx --img_size 512
# Newer versions
# wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth
# python torch2onnx.py --src_model_path ./GFPGANv1.4.pth --dst_model_path ./GFPGANv1.4.onnx --img_size 512
# python torch2onnx.py --src_model_path ./GFPGANCleanv1-NoCE-C2.pth --dst_model_path ./GFPGANv1.2.onnx --img_size 512