main 3728cee8138d cached
5 files
45.6 KB
14.7k tokens
14 symbols
1 requests
Download .txt
Repository: xuanandsix/GFPGAN-onnxruntime-demo
Branch: main
Commit: 3728cee8138d
Files: 5
Total size: 45.6 KB

Directory structure:
gitextract_ll211iz8/

├── GFPGANReconsitution.py
├── README.md
├── demo_onnx.py
├── noise_main.py
└── torch2onnx.py

================================================
FILE CONTENTS
================================================

================================================
FILE: GFPGANReconsitution.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import sys
import argparse
import cv2
import numpy as np
import torch 
import timeit
#import onnxruntime
from torch.nn import functional as F
from torchvision.transforms.functional import normalize
from torch import nn
import math
from collections import OrderedDict
from noise_main import noise_dict

class ResBlock(nn.Module):
    """Residual block with bilinear upsampling/downsampling.

    Main path: ``conv1 -> leaky_relu -> resample -> conv2 -> leaky_relu``.
    Skip path: ``resample -> 1x1 conv (no bias)``. Both paths are resampled by
    the same factor so their outputs can be summed.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        mode (str): ``'down'`` halves the spatial size, ``'up'`` doubles it.
            Default: ``'down'``.

    Raises:
        ValueError: If ``mode`` is neither ``'down'`` nor ``'up'``.
    """

    def __init__(self, in_channels, out_channels, mode='down'):
        super(ResBlock, self).__init__()
        # Reject unknown modes here; otherwise scale_factor would stay unset
        # and the error would only surface later as an AttributeError in forward().
        if mode == 'down':
            self.scale_factor = 0.5
        elif mode == 'up':
            self.scale_factor = 2
        else:
            raise ValueError(f"mode must be 'down' or 'up', got {mode!r}")

        self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
        self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
        self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False)

    def forward(self, x):
        """Apply the block; the spatial size is scaled by ``self.scale_factor``."""
        # In-place activation is safe: it only touches the fresh conv output.
        out = F.leaky_relu_(self.conv1(x), negative_slope=0.2)
        # upsample/downsample
        out = F.interpolate(out, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
        out = F.leaky_relu_(self.conv2(out), negative_slope=0.2)
        # skip: resample the input, then project it to out_channels with the 1x1 conv
        x = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
        skip = self.skip(x)
        return out + skip


class ConstantInput(nn.Module):
    """Learned constant feature map used as the decoder's starting tensor.

    Args:
        num_channel (int): Channel number of constant input.
        size (int): Spatial size of constant input.
    """

    def __init__(self, num_channel, size):
        super().__init__()
        shape = (1, num_channel, size, size)  # GFPGAN builds it as [1, 512, 4, 4]
        self.weight = nn.Parameter(torch.randn(*shape))

    def forward(self, batch):
        """Return the constant tiled to ``(batch, num_channel, size, size)``.

        ``repeat`` copies the data, so the result does not alias ``weight``.
        """
        return self.weight.repeat((batch, 1, 1, 1))

class GFPGAN(nn.Module):
    def __init__(self):
        """Build the GFPGAN encoder, SFT condition branches and StyleGAN2 decoder.

        The layout is hard-coded for 512x512 inputs: seven 2x downsamplings
        reach the 4x4 bottleneck that ``final_linear`` expects, and the decoder
        upsamples 2x seven times from the 4x4 constant input.

        Decoder attribute names spell the nested module path with ``dot`` in
        place of ``.`` (e.g. ``stylegan_decoderdotstyle_convsdot0dotbias``).
        NOTE(review): presumably this mirrors checkpoint keys such as
        ``stylegan_decoder.style_convs.0.bias``. Confirm against the
        weight-loading code before renaming anything.
        """
        super(GFPGAN, self).__init__()
        unet_narrow = 0.5
        channel_multiplier = 2
        num_style_feat = 512
        channels = {
            '4': int(512 * unet_narrow),
            '8': int(512 * unet_narrow),
            '16': int(512 * unet_narrow),
            '32': int(512 * unet_narrow),
            '64': int(256 * channel_multiplier * unet_narrow),
            '128': int(128 * channel_multiplier * unet_narrow),
            '256': int(64 * channel_multiplier * unet_narrow),
            '512': int(32 * channel_multiplier * unet_narrow),
            '1024': int(16 * channel_multiplier * unet_narrow)
        }

        # ---- encoder: 512x512 -> 4x4 ----
        self.conv_body_first = nn.Conv2d(3, channels['512'], 1)
        self.conv_body_down = nn.ModuleList()
        in_channels = channels['512']
        for i in range(9, 2, -1):
            out_channels = channels[f'{2**(i - 1)}']
            self.conv_body_down.append(ResBlock(in_channels, out_channels, mode='down'))
            in_channels = out_channels
        self.final_conv = nn.Conv2d(in_channels, channels['4'], 3, 1, 1)
        # (2 * log2(512) - 2) = 16 style vectors of num_style_feat each;
        # math.log2 is exact for powers of two.
        num_latent = int(math.log2(512)) * 2 - 2
        linear_out_channel = num_latent * num_style_feat
        self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel)

        # ---- UNet upsampling path: 4x4 -> 512x512 ----
        in_channels = channels['4']
        self.conv_body_up = nn.ModuleList()
        for i in range(3, 9 + 1):
            out_channels = channels[f'{2**i}']
            self.conv_body_up.append(ResBlock(in_channels, out_channels, mode='up'))
            in_channels = out_channels

        # ---- SFT condition branches: one scale and one shift per resolution ----
        self.condition_scale = nn.ModuleList()
        self.condition_shift = nn.ModuleList()
        for i in range(3, 9 + 1):
            out_channels = channels[f'{2**i}']
            sft_out_channels = out_channels
            self.condition_scale.append(
                nn.Sequential(
                    nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
                    nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
            self.condition_shift.append(
                nn.Sequential(
                    nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
                    nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))

        # ---- StyleGAN2 decoder ----
        self.stylegan_decoderdotconstant_input = ConstantInput(512, size=4)

        # 4x4 layer: style_conv1 followed by to_rgb1.
        self._add_style_conv('stylegan_decoderdotstyle_conv1dot', 512, 512)
        self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self._add_to_rgb('stylegan_decoderdotto_rgb1dot', 512)

        # Each step upsamples 2x and owns style_convs[2k] (in -> out channels),
        # style_convs[2k + 1] (out -> out) and to_rgbs[k] (out -> 3 channels).
        # Registration order matches the original hand-written declarations.
        decoder_steps = (
            (512, 512),  # 8x8
            (512, 512),  # 16x16
            (512, 512),  # 32x32
            (512, 512),  # 64x64
            (512, 256),  # 128x128
            (256, 128),  # 256x256
            (128, 64),   # 512x512
        )
        for step, (in_ch, out_ch) in enumerate(decoder_steps):
            self._add_style_conv(f'stylegan_decoderdotstyle_convsdot{2 * step}dot', in_ch, out_ch)
            self._add_style_conv(f'stylegan_decoderdotstyle_convsdot{2 * step + 1}dot', out_ch, out_ch)
            self._add_to_rgb(f'stylegan_decoderdotto_rgbsdot{step}dot', out_ch)

    def _add_style_conv(self, prefix, in_channels, out_channels):
        """Register one modulated 3x3 style conv's parameters under ``prefix``.

        Creates ``<prefix>modulated_convdotmodulation`` (512-d style -> one
        scale per input channel), ``<prefix>modulated_convdotweight``,
        ``<prefix>weight`` (noise-injection strength) and ``<prefix>bias``.
        """
        setattr(self, prefix + 'modulated_convdotmodulation',
                nn.Linear(512, in_channels, bias=True))
        # NOTE(review): this init scale divides by out_channels, whereas
        # StyleGAN2's ModulatedConv2d uses in_channels. Kept as-is; presumably
        # it is overwritten when pretrained weights are loaded.
        setattr(self, prefix + 'modulated_convdotweight', nn.Parameter(
            torch.randn(1, out_channels, in_channels, 3, 3) /
            math.sqrt(out_channels * 3**2)))
        setattr(self, prefix + 'weight', nn.Parameter(torch.zeros(1)))  # for noise injection
        setattr(self, prefix + 'bias', nn.Parameter(torch.zeros(1, out_channels, 1, 1)))

    def _add_to_rgb(self, prefix, in_channels):
        """Register one modulated 1x1 to-RGB conv's parameters under ``prefix``."""
        setattr(self, prefix + 'modulated_convdotmodulation',
                nn.Linear(512, in_channels, bias=True))
        setattr(self, prefix + 'modulated_convdotweight', nn.Parameter(
            torch.randn(1, 3, in_channels, 1, 1) /
            math.sqrt(in_channels * 1**2)))
        setattr(self, prefix + 'bias', nn.Parameter(torch.zeros(1, 3, 1, 1)))

    def forward(self, x):
        # encoder
        feat = F.leaky_relu_(self.conv_body_first(x), negative_slope=0.2)
        conditions = []
        unet_skips = []
        out_rgbs = []

        for i in range(7):
            feat = self.conv_body_down[i](feat)
            unet_skips.insert(0, feat)
        
        feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2)
        
        # style code
        style_code = self.final_linear(feat.view(feat.size(0), -1))
        style_code = style_code.view(style_code.size(0), -1, 512)

        # decode
        for i in range(7):
            # add unet skip
            feat = feat + unet_skips[i]
            # ResUpLayer
            feat = self.conv_body_up[i](feat)
            # generate scale and shift for SFT layer
            scale = self.condition_scale[i](feat)
            conditions.append(scale.clone())
            shift = self.condition_shift[i](feat)
            conditions.append(shift.clone())

        styles = [style_code]

       
        #noise = [None] * 15  # for each style conv layer
        latent = styles[0]    
        out = self.stylegan_decoderdotconstant_input(latent.shape[0])
    
        b, c, h, w = 1, 512, 4, 4
        # weight modulation
        style = self.stylegan_decoderdotstyle_conv1dotmodulated_convdotmodulation(latent[:, 0]).view(b, 1, c, 1, 1)
        weight = self.stylegan_decoderdotstyle_conv1dotmodulated_convdotweight * style  # (b, c_out, c_in, k, k)
        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
        weight = weight * demod.view(b, 512, 1, 1, 1)
        weight = weight.view(b * 512, c, 3, 3)
        b, c, h, w = 1, 512, 4, 4
        out = out.view(1, b * c, h, w)
        # weight: (b*c_out, c_in, k, k), groups=b
        out = F.conv2d(out, weight, padding=1, groups=b)
        out = out.view(b, 512, *out.shape[2:4]) * 2**0.5 
        b, _, h, w = 1, 512, 4, 4
        noise = noise_dict[w]
        out = out + self.stylegan_decoderdotstyle_conv1dotweight * noise
        out = out + self.stylegan_decoderdotstyle_conv1dotbias
        out = self.activate(out)
        out0 = out 
        
       
        # toRGB    
        x = out    ########
        style = latent[:, 1] ###########
        style = self.stylegan_decoderdotto_rgb1dotmodulated_convdotmodulation(latent[:, 1]).view(b, 1, c, 1, 1)
        weight = self.stylegan_decoderdotto_rgb1dotmodulated_convdotweight * style     
        weight = weight.view(3, 512, 1, 1)
        b, c, h, w = 1, 512, 4, 4
        x = x.view(1, 512, 4, 4)
        out = F.conv2d(x, weight, padding=0, groups=b)
        out = out.view(1, 3, 4, 4)
        out = out + self.stylegan_decoderdotto_rgb1dotbias
        skip = out
        out = out0
  
        # i = 1
        i = 1
        x = out
        b, c, h, w = 1, 512, 4, 4
        
        #conv1
        style = self.stylegan_decoderdotstyle_convsdot0dotmodulated_convdotmodulation(latent[:, i]).view(b, 1, c, 1, 1)
        weight = self.stylegan_decoderdotstyle_convsdot0dotmodulated_convdotweight * style
        # self.demodulate = True:
        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08)
        weight = weight * demod.view(b, 512, 1, 1, 1)
        #
        weight = weight.view(b * 512, c, 3, 3)
        # self.sample_mode == 'upsample'
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        b, c, h, w = x.shape
        x = x.view(1, b * c, h, w)
        out = F.conv2d(x, weight, padding=1, groups=b)
        out = out.view(1, 512, 8, 8) * 2 ** 0.5 
        b, _, h, w = 1,_,8,8
        noise = noise_dict[w]
        out = out + self.stylegan_decoderdotstyle_convsdot0dotweight * noise
        out = out + self.stylegan_decoderdotstyle_convsdot0dotbias
        out = self.activate(out)
        out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
        out_sft = out_sft * conditions[i - 1] + conditions[i]
        out = torch.cat([out_same, out_sft], dim=1)
        #conv2
        style = self.stylegan_decoderdotstyle_convsdot1dotmodulated_convdotmodulation(latent[:, i + 1]).view(1, 1, 512, 1, 1)
        weight = self.stylegan_decoderdotstyle_convsdot1dotmodulated_convdotweight * style
        # self.demodulate = True:
        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08)
        weight = weight * demod.view(b, 512, 1, 1, 1)
        weight = weight.view(b * 512, 512, 3, 3)
        out = F.conv2d(out, weight, padding=1, groups=b)
        out = out.view(1, 512, 8, 8) * 2 ** 0.5 
        noise = noise_dict[w]
        out = out + self.stylegan_decoderdotstyle_convsdot1dotweight * noise
        out = out + self.stylegan_decoderdotstyle_convsdot1dotbias
        out = self.activate(out)
        out0 = out
        #to_rgb
        x = out
        style = latent[:, i + 2]  
        style = self.stylegan_decoderdotto_rgbsdot0dotmodulated_convdotmodulation(style).view(1, 1, 512, 1, 1)
        weight = self.stylegan_decoderdotto_rgbsdot0dotmodulated_convdotweight * style     
        weight = weight.view(3, 512, 1, 1)
        #b, c, h, w = x.shape
        x = x.view(1, b * c, h, w)
        out = F.conv2d(x, weight, padding=0, groups=b)
        out = out.view(1, 3, 8, 8)
        out = out + self.stylegan_decoderdotto_rgbsdot0dotbias
        
        skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)
        skip = out + skip
        
        # i = 3
        out = out0
        x = out
        b, c, h, w = 1, 512, 8, 8
        i += 2
        style = latent[:, i]
        #conv1
        style = self.stylegan_decoderdotstyle_convsdot2dotmodulated_convdotmodulation(latent[:, i]).view(b, 1, c, 1, 1)
        weight = self.stylegan_decoderdotstyle_convsdot2dotmodulated_convdotweight * style
        # self.demodulate = True:
        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08)
        weight = weight * demod.view(b, 512, 1, 1, 1)
        # self.sample_mode == 'upsample'
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        weight = weight.view(b * 512, c, 3, 3)
        b, c, h, w = x.shape
        x = x.view(1, b * c, h, w)
        out = F.conv2d(x, weight, padding=1, groups=b)
        out = out.view(1, 512, 16, 16) * 2 ** 0.5 
        b, _, h, w = 1, _, 16, 16
        noise = noise_dict[w]
        out = out + self.stylegan_decoderdotstyle_convsdot2dotweight * noise
        out = out + self.stylegan_decoderdotstyle_convsdot2dotbias
        out = self.activate(out)
        out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
        out_sft = out_sft * conditions[i - 1] + conditions[i]
        out = torch.cat([out_same, out_sft], dim=1)
        #conv2
        style = self.stylegan_decoderdotstyle_convsdot3dotmodulated_convdotmodulation(latent[:, i + 1]).view(1, 1, 512, 1, 1)
        weight = self.stylegan_decoderdotstyle_convsdot3dotmodulated_convdotweight * style
        # self.demodulate = True:
        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08)
        weight = weight * demod.view(1, 512, 1, 1, 1)
        weight = weight.view(b * 512, 512, 3, 3)
        out = F.conv2d(out, weight, padding=1, groups=b)
        out = out.view(1, 512, 16, 16) * 2 ** 0.5 
        noise = noise_dict[w]
        out = out + self.stylegan_decoderdotstyle_convsdot3dotweight * noise
        out = out + self.stylegan_decoderdotstyle_convsdot3dotbias
        out = self.activate(out)
        out0 = out
        #to_rgb
        x = out
        style = latent[:, i + 2]  
        style = self.stylegan_decoderdotto_rgbsdot1dotmodulated_convdotmodulation(style).view(1, 1, 512, 1, 1)
        weight = self.stylegan_decoderdotto_rgbsdot1dotmodulated_convdotweight * style     
        weight = weight.view(3, 512, 1, 1)
        #b, c, h, w = x.shape
        x = x.view(1, b * c, h, w)
        out = F.conv2d(x, weight, padding=0, groups=b)
        out = out.view(1, 3, 16, 16)
        out = out + self.stylegan_decoderdotto_rgbsdot1dotbias
        skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)
        skip = out + skip
        
        
        # i = 5
        out = out0
        x = out
        b, c, h, w = 1, 512, 32, 32
        i += 2
        style = latent[:, i] 
        #conv1
        style = self.stylegan_decoderdotstyle_convsdot4dotmodulated_convdotmodulation(latent[:, i]).view(b, 1, c, 1, 1)   
        weight = self.stylegan_decoderdotstyle_convsdot4dotmodulated_convdotweight * style
        # self.demodulate = True:
        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08)
        weight = weight * demod.view(b, 512, 1, 1, 1)
        # self.sample_mode == 'upsample'
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        weight = weight.view(b * 512, c, 3, 3)    
        b, c, h, w = x.shape
        x = x.view(1, b * c, h, w)
        out = F.conv2d(x, weight, padding=1, groups=b)
        out = out.view(1, 512, 32, 32) * 2 ** 0.5 
        b, _, h, w = 1, _, 32, 32
        noise = noise_dict[w]
        out = out + self.stylegan_decoderdotstyle_convsdot4dotweight * noise
        out = out + self.stylegan_decoderdotstyle_convsdot4dotbias
        out = self.activate(out)
        
        out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
        out_sft = out_sft * conditions[i - 1] + conditions[i]
        out = torch.cat([out_same, out_sft], dim=1)
       
        #conv2
        style = self.stylegan_decoderdotstyle_convsdot5dotmodulated_convdotmodulation(latent[:, i + 1]).view(1, 1, 512, 1, 1)
        weight = self.stylegan_decoderdotstyle_convsdot5dotmodulated_convdotweight * style
        # self.demodulate = True:
        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08)
        weight = weight * demod.view(1, 512, 1, 1, 1)
        weight = weight.view(b * 512, 512, 3, 3)
        out = F.conv2d(out, weight, padding=1, groups=b)
        out = out.view(1, 512, 32, 32) * 2 ** 0.5 
        noise = noise_dict[w]
        out = out + self.stylegan_decoderdotstyle_convsdot5dotweight * noise
        out = out + self.stylegan_decoderdotstyle_convsdot5dotbias
        out = self.activate(out)
        out0 = out
        #to_rgb
        x = out
        style = latent[:, i + 2]
        style = self.stylegan_decoderdotto_rgbsdot2dotmodulated_convdotmodulation(style).view(1, 1, 512, 1, 1)
        weight = self.stylegan_decoderdotto_rgbsdot2dotmodulated_convdotweight * style     
        
        weight = weight.view(3, 512, 1, 1)
        #b, c, h, w = x.shape
        x = x.view(1, b * c, h, w)
        out = F.conv2d(x, weight, padding=0, groups=b)
        
        out = out.view(1, 3, 32, 32)
        out = out + self.stylegan_decoderdotto_rgbsdot2dotbias
        skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)
        skip = out + skip

        # i = 7
        out = out0
        x = out
        b, c, h, w = 1, 512, 32, 32
        i += 2
        style = latent[:, i]   # 数值一致
        #conv1
        style = self.stylegan_decoderdotstyle_convsdot6dotmodulated_convdotmodulation(latent[:, i]).view(b, 1, c, 1, 1)   
        weight = self.stylegan_decoderdotstyle_convsdot6dotmodulated_convdotweight * style
        # self.demodulate = True:
        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08)
        weight = weight * demod.view(b, 512, 1, 1, 1)
        # self.sample_mode == 'upsample'
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        weight = weight.view(b * 512, c, 3, 3)    
        b, c, h, w = x.shape
        x = x.view(1, b * c, h, w)
        out = F.conv2d(x, weight, padding=1, groups=b)
        out = out.view(1, 512, 64, 64) * 2 ** 0.5 
        b, _, h, w = 1, _, 64, 64
        noise = noise_dict[w]
        out = out + self.stylegan_decoderdotstyle_convsdot7dotweight * noise
        out = out + self.stylegan_decoderdotstyle_convsdot7dotbias
        out = self.activate(out)
        out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
        out_sft = out_sft * conditions[i - 1] + conditions[i]
        out = torch.cat([out_same, out_sft], dim=1)
        #conv2
        style = self.stylegan_decoderdotstyle_convsdot7dotmodulated_convdotmodulation(latent[:, i + 1]).view(1, 1, 512, 1, 1)
        weight = self.stylegan_decoderdotstyle_convsdot7dotmodulated_convdotweight * style
        # self.demodulate = True:
        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08)
        weight = weight * demod.view(1, 512, 1, 1, 1)
        weight = weight.view(b * 512, 512, 3, 3)
        out = F.conv2d(out, weight, padding=1, groups=b)
        out = out.view(1, 512, 64, 64) * 2 ** 0.5 
        noise = noise_dict[w]
        out = out + self.stylegan_decoderdotstyle_convsdot7dotweight * noise
        out = out + self.stylegan_decoderdotstyle_convsdot7dotbias
        out = self.activate(out)
        out0 = out
        #to_rgb
        x = out
        style = latent[:, i + 2]
        style = self.stylegan_decoderdotto_rgbsdot3dotmodulated_convdotmodulation(style).view(1, 1, 512, 1, 1)
        weight = self.stylegan_decoderdotto_rgbsdot3dotmodulated_convdotweight * style     
        weight = weight.view(3, 512, 1, 1)
        #b, c, h, w = x.shape
        x = x.view(1, b * c, h, w)
        out = F.conv2d(x, weight, padding=0, groups=b)
        out = out.view(1, 3, 64, 64)
        out = out + self.stylegan_decoderdotto_rgbsdot3dotbias
        skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)
        skip = out + skip
        
        # i = 9
        out = out0
        x = out
        b, c, h, w = 1, 512, 64, 64
        i += 2
        style = latent[:, i]   # 数值一致
        #conv1
        style = self.stylegan_decoderdotstyle_convsdot8dotmodulated_convdotmodulation(latent[:, i]).view(b, 1, c, 1, 1)   
        weight = self.stylegan_decoderdotstyle_convsdot8dotmodulated_convdotweight * style    
        # self.demodulate = True:
        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08)
        weight = weight * demod.view(b, 256, 1, 1, 1)
        # self.sample_mode == 'upsample'
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        weight = weight.view(b * 256, c, 3, 3)    
        b, c, h, w = x.shape
        x = x.view(1, b * c, h, w)
        out = F.conv2d(x, weight, padding=1, groups=b)
        out = out.view(1, 256, 128, 128) * 2 ** 0.5 
        b, _, h, w = 1, _, 128, 128
        noise = noise_dict[w]
        out = out + self.stylegan_decoderdotstyle_convsdot8dotweight * noise
        out = out + self.stylegan_decoderdotstyle_convsdot8dotbias
        out = self.activate(out)
        out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
        out_sft = out_sft * conditions[i - 1] + conditions[i]
        out = torch.cat([out_same, out_sft], dim=1)
        #conv2
        style = self.stylegan_decoderdotstyle_convsdot9dotmodulated_convdotmodulation(latent[:, i + 1]).view(1, 1, 256, 1, 1)
        weight = self.stylegan_decoderdotstyle_convsdot9dotmodulated_convdotweight * style
        # self.demodulate = True:
        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08)
        weight = weight * demod.view(1, 256, 1, 1, 1)
        weight = weight.view(b * 256, 256, 3, 3)
        out = F.conv2d(out, weight, padding=1, groups=b)
        out = out.view(1, 256, 128, 128) * 2 ** 0.5 
        noise = noise_dict[w]
        out = out + self.stylegan_decoderdotstyle_convsdot9dotweight * noise
        out = out + self.stylegan_decoderdotstyle_convsdot9dotbias
        out = self.activate(out)
        out0 = out
        #to_rgb
        x = out
        style = latent[:, i + 2]
        style = self.stylegan_decoderdotto_rgbsdot4dotmodulated_convdotmodulation(style).view(1, 1, 256, 1, 1)
        weight = self.stylegan_decoderdotto_rgbsdot4dotmodulated_convdotweight * style     
        weight = weight.view(3, 256, 1, 1)
        b, c, h, w = x.shape
        x = x.view(1, b * c, h, w)
        out = F.conv2d(x, weight, padding=0, groups=b)
        out = out.view(1, 3, 128, 128)
        out = out + self.stylegan_decoderdotto_rgbsdot4dotbias
        skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)
        skip = out + skip

        # i = 11
        out = out0
        x = out
        b, c, h, w = 1, 256, 128, 128
        i += 2
        style = latent[:, i]   # 数值一致 
        style = self.stylegan_decoderdotstyle_convsdot10dotmodulated_convdotmodulation(latent[:, i]).view(b, 1, c, 1, 1)     
        #conv1
        weight = self.stylegan_decoderdotstyle_convsdot10dotmodulated_convdotweight * style    
        # self.demodulate = True:
        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08)
        weight = weight * demod.view(b, 128, 1, 1, 1)
        weight = weight.view(b * 128, 256, 3, 3) 
        # self.sample_mode == 'upsample'
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        b, c, h, w = x.shape
        out = F.conv2d(x, weight, padding=1, groups=b)
        out = out.view(1, 128, 256, 256) * 2 ** 0.5 
        b, _, h, w = 1, _, 256, 256
        noise = noise_dict[w]
        out = out + self.stylegan_decoderdotstyle_convsdot10dotweight * noise
        out = out + self.stylegan_decoderdotstyle_convsdot10dotbias
        out = self.activate(out)
        out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
        out_sft = out_sft * conditions[i - 1] + conditions[i]
        out = torch.cat([out_same, out_sft], dim=1)
        #conv2
        style = self.stylegan_decoderdotstyle_convsdot11dotmodulated_convdotmodulation(latent[:, i + 1]).view(1, 1, 128, 1, 1)
        weight = self.stylegan_decoderdotstyle_convsdot11dotmodulated_convdotweight * style 
        # self.demodulate = True:
        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08)
        weight = weight * demod.view(1, 128, 1, 1, 1)
        weight = weight.view(b * 128, 128, 3, 3)
        out = F.conv2d(out, weight, padding=1, groups=b)
        out = out.view(1, 128, 256, 256) * 2 ** 0.5 
        noise = noise_dict[w]
        out = out + self.stylegan_decoderdotstyle_convsdot11dotweight * noise
        out = out + self.stylegan_decoderdotstyle_convsdot11dotbias
        out = self.activate(out)  
        out0 = out
        #to_rgb
        x = out
        style = latent[:, i + 2]
        style = self.stylegan_decoderdotto_rgbsdot5dotmodulated_convdotmodulation(style).view(1, 1, 128, 1, 1)
        weight = self.stylegan_decoderdotto_rgbsdot5dotmodulated_convdotweight * style     
        weight = weight.view(3, 128, 1, 1)
        b, c, h, w = x.shape
        x = x.view(1, b * c, h, w)
        out = F.conv2d(x, weight, padding=0, groups=b)
        out = out.view(1, 3, 256, 256) 
        out = out + self.stylegan_decoderdotto_rgbsdot5dotbias
        skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)
        skip = out + skip
        
        # i = 13
        out = out0
        x = out
        b, c, h, w = 1, 128, 256, 256
        i += 2
        style = latent[:, i]   # 数值一致 
        style = self.stylegan_decoderdotstyle_convsdot12dotmodulated_convdotmodulation(latent[:, i]).view(b, 1, c, 1, 1)     
        #conv1
        weight = self.stylegan_decoderdotstyle_convsdot12dotmodulated_convdotweight * style    
        
        # self.demodulate = True:
        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08)
        weight = weight * demod.view(b, 64, 1, 1, 1)
        weight = weight.view(b * 64, 128, 3, 3) 
        
        # self.sample_mode == 'upsample'
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        b, c, h, w = x.shape
        out = F.conv2d(x, weight, padding=1, groups=b)
        out = out.view(1, 64, 512, 512) * 2 ** 0.5 
        b, _, h, w = 1, _, 512, 512
        noise = noise_dict[w]
        out = out + self.stylegan_decoderdotstyle_convsdot12dotweight * noise
        out = out + self.stylegan_decoderdotstyle_convsdot12dotbias
        out = self.activate(out)
        
        out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
        out_sft = out_sft * conditions[i - 1] + conditions[i]
        out = torch.cat([out_same, out_sft], dim=1)
        #conv2
        style = self.stylegan_decoderdotstyle_convsdot13dotmodulated_convdotmodulation(latent[:, i + 1]).view(1, 1, 64, 1, 1)
        weight = self.stylegan_decoderdotstyle_convsdot13dotmodulated_convdotweight * style 
        
        # self.demodulate = True:
        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-08)
        weight = weight * demod.view(1, 64, 1, 1, 1)
        weight = weight.view(b * 64, 64, 3, 3)
        out = F.conv2d(out, weight, padding=1, groups=b)
        out = out.view(1, 64, 512, 512) * 2 ** 0.5 
        noise = noise_dict[w]
        out = out + self.stylegan_decoderdotstyle_convsdot13dotweight * noise
        out = out + self.stylegan_decoderdotstyle_convsdot13dotbias
        out = self.activate(out)    
        out0 = out
        #to_rgb
        x = out
        style = latent[:, i + 2]
        style = self.stylegan_decoderdotto_rgbsdot6dotmodulated_convdotmodulation(style).view(1, 1, 64, 1, 1) 
        weight = self.stylegan_decoderdotto_rgbsdot6dotmodulated_convdotweight * style     
        weight = weight.view(3, 64, 1, 1)
        b, c, h, w = x.shape
        x = x.view(1, b * c, h, w)
        out = F.conv2d(x, weight, padding=0, groups=b)    
        out = out.view(1, 3, 512, 512) 
        out = out + self.stylegan_decoderdotto_rgbsdot6dotbias
        skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)
        skip = out + skip
        return skip        



================================================
FILE: README.md
================================================
# GFPGAN-onnxruntime-demo
This is ONNX Runtime inference code for GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior (CVPR 2021). Official code: https://github.com/TencentARC/GFPGAN

## The following issues are addressed:
1. `noise = out.new_empty(b, 1, h, w).normal_()` in `stylegan2_clean_arch.py` cannot be exported to ONNX. I moved noise generation out of the model class. The model now reads pre-generated random noise from a container, e.g. `noise = Noise[i]`, where `Noise` is a list (or a similar structure) that stores the noise in advance.

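2. The original model's `forward` functions, especially the StyleGAN decoder's, are hard to follow and convert. They contain many `if`/`else` branches and reuse module classes in loops, for example `StyleConv` instances appended via `self.style_convs.append(StyleConv(...))`. I rewrote the model so that everything runs in a single, unrolled `forward`.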

## convert torch to onnx.
```
wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth

python torch2onnx.py  --src_model_path ./GFPGANv1.3.pth --dst_model_path ./GFPGANv1.3.onnx --img_size 512 
```

## run onnx demo.
```
python demo_onnx.py --model_path GFPGANv1.3.onnx --image_path ./cropped_faces/Adele_crop.png --save_path Adele_v3.jpg
```

| input | output|
| :-: |:-:|
|<img src="https://github.com/xuanandsix/GFPGAN-onnxruntime-demo/raw/main/cropped_faces/Justin_Timberlake_crop.png" height="80%" width="80%">|<img src="https://github.com/xuanandsix/GFPGAN-onnxruntime-demo/raw/main/imgs/Justin_Timberlake_v2.jpg" height="80%" width="80%">|
|<img src="https://github.com/xuanandsix/GFPGAN-onnxruntime-demo/raw/main/cropped_faces/Julia_Roberts_crop.png" height="80%" width="80%">|<img src="https://github.com/xuanandsix/GFPGAN-onnxruntime-demo/raw/main/imgs/Julia_Roberts_v2.jpg" height="80%" width="80%">|
|<img src="https://github.com/xuanandsix/GFPGAN-onnxruntime-demo/raw/main/cropped_faces/Paris_Hilton_crop.png" height="80%" width="80%">|<img src="https://github.com/xuanandsix/GFPGAN-onnxruntime-demo/raw/main/imgs/Paris_Hilton_v2.jpg" height="80%" width="80%">|




================================================
FILE: demo_onnx.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import sys
import argparse
import cv2
import numpy as np
import timeit
import onnxruntime

class GFPGANFaceAugment:
    """Restore a face image with a GFPGAN model exported to ONNX.

    Args:
        model_path (str): Path to the ``.onnx`` model file.
        use_gpu (bool): If True and the installed onnxruntime build offers
            ``CUDAExecutionProvider``, run on the GPU. CPU is always kept as
            the fallback provider.
    """

    def __init__(self, model_path, use_gpu = False):
        # Pass providers explicitly. Since onnxruntime 1.9, builds that ship
        # several providers refuse to create a session without them. This is
        # also where ``use_gpu`` takes effect; it used to be ignored.
        providers = ['CPUExecutionProvider']
        if use_gpu and 'CUDAExecutionProvider' in onnxruntime.get_available_providers():
            providers.insert(0, 'CUDAExecutionProvider')
        self.ort_session = onnxruntime.InferenceSession(model_path, providers=providers)
        net_input = self.ort_session.get_inputs()[0]
        self.net_input_name = net_input.name
        # Assumes a 4-D NCHW input. The dims may be symbolic strings if the
        # model was exported with dynamic axes.
        _, self.net_input_channels, self.net_input_height, self.net_input_width = net_input.shape
        self.net_output_count = len(self.ort_session.get_outputs())
        self.face_size = 512
        # Not used by the methods below; kept for external callers.
        self.face_template = np.array([[192, 240], [319, 240], [257, 371]]) * (self.face_size / 512.0)
        self.upscale_factor = 2
        # The inverse-affine paste-back path in post_process runs only when
        # ``affine`` is True. Nothing in this class enables it.
        self.affine = False
        self.affine_matrix = None

    def pre_process(self, img):
        """Convert a BGR image into a normalized NCHW float32 batch of one.

        Args:
            img (np.ndarray): HxWx3 image in BGR order, as read by
                ``cv2.imread``. Presumably 8-bit, since values are divided
                by 255.

        Returns:
            np.ndarray: Contiguous array of shape (1, 3, face_size, face_size)
            in RGB order, with each value mapped as ``(x / 255 - 0.5) / 0.5``.
        """
        # NOTE(review): the image is first shrunk to half size and then
        # resized to face_size. For a face_size crop this is a down/up round
        # trip that blurs the input. Presumably intentional (the published
        # sample outputs were made this way); confirm before removing.
        # Clamp to 1 px so tiny inputs do not produce a zero-sized resize.
        half_size = (max(1, img.shape[1] // 2), max(1, img.shape[0] // 2))
        img = cv2.resize(img, half_size)
        img = cv2.resize(img, (self.face_size, self.face_size))
        img = (img / 255.0).astype(np.float32)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # Map [0, 1] to [-1, 1] on all channels at once.
        img = (img - 0.5) / 0.5
        # HWC -> CHW, then add the batch axis.
        return np.ascontiguousarray(img.transpose(2, 0, 1)[np.newaxis])

    def post_process(self, output, height, width):
        """Turn one network output back into a BGR image of the original size.

        Args:
            output (np.ndarray): 3-D CHW array in RGB order. Values are
                clipped to [-1, 1] before conversion.
            height (int): Target image height.
            width (int): Target image width.

        Returns:
            tuple: ``(image, soft_mask)``. ``image`` is an HxWx3 BGR float
            array rounded to the [0, 255] range. ``soft_mask`` is an HxWx1
            float32 blend mask, all ones unless the affine path is enabled.
        """
        # [-1, 1] CHW RGB  ->  [0, 255] HWC BGR
        output = output.clip(-1,1)
        output = (output + 1) / 2
        output = output.transpose(1, 2, 0)
        output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
        output = (output * 255.0).round()
        if self.affine:
            # Warp the restored face back into the original frame with the
            # inverse of ``affine_matrix``, then build a feathered blend mask.
            inverse_affine = cv2.invertAffineTransform(self.affine_matrix)
            inverse_affine *= self.upscale_factor
            if self.upscale_factor > 1:
                extra_offset = 0.5 * self.upscale_factor
            else:
                extra_offset = 0
            inverse_affine[:, 2] += extra_offset
            inv_restored = cv2.warpAffine(output, inverse_affine, (width, height))
            mask = np.ones((self.face_size, self.face_size), dtype=np.float32)
            inv_mask = cv2.warpAffine(mask, inverse_affine, (width, height))
            inv_mask_erosion = cv2.erode(
                inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
            pasted_face = inv_mask_erosion[:, :, None] * inv_restored
            total_face_area = np.sum(inv_mask_erosion)
            # compute the fusion edge based on the area of face
            w_edge = int(total_face_area**0.5) // 20
            erosion_radius = w_edge * 2
            inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
            blur_size = w_edge * 2
            inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
            inv_soft_mask = inv_soft_mask[:, :, None]
            output = pasted_face
        else:
            inv_soft_mask = np.ones((height, width, 1), dtype=np.float32)
            output = cv2.resize(output, (width, height))
        return output, inv_soft_mask

    def forward(self, img):
        """Restore ``img`` and return ``(uint8 BGR image, soft mask)``.

        Args:
            img (np.ndarray): HxWx3 BGR image.

        Returns:
            tuple: The restored image resized to ``img``'s height and width
            as uint8, and the blend mask from ``post_process``.
        """
        height, width = img.shape[0], img.shape[1]
        net_in = self.pre_process(img)
        start = timeit.default_timer()
        ort_outs = self.ort_session.run(None, {self.net_input_name: net_in})
        # Report model time only; pre- and post-processing are excluded.
        print('infer time:', timeit.default_timer() - start)
        # First output tensor, first (only) batch element.
        output, inv_soft_mask = self.post_process(ort_outs[0][0], height, width)
        return output.astype(np.uint8), inv_soft_mask
        
if __name__ == "__main__":
    # Command-line demo: restore one image and write the result to disk.
    parser = argparse.ArgumentParser("onnxruntime demo")
    # Both paths are mandatory. A None default only crashes later inside
    # onnxruntime/cv2 with an unhelpful error.
    parser.add_argument('--model_path', type=str, required=True, help='model path')
    parser.add_argument('--image_path', type=str, required=True, help='input image path')
    parser.add_argument('--save_path', type=str, default="output.jpg", help='output image path')
    args = parser.parse_args()

    # cv2.imread returns None instead of raising on a missing or unreadable
    # file. Check before paying for model loading.
    image = cv2.imread(args.image_path, cv2.IMREAD_COLOR)
    if image is None:
        sys.exit(f"could not read image: {args.image_path}")

    faceaugment = GFPGANFaceAugment(model_path=args.model_path)
    output, _ = faceaugment.forward(image)
    # cv2.imwrite signals failure (bad extension, unwritable dir) by returning False.
    if not cv2.imwrite(args.save_path, output):
        sys.exit(f"could not write image: {args.save_path}")

# python demo_onnx.py --model_path GFPGANv1.4.onnx --image_path ./cropped_faces/Adele_crop.png


# python demo_onnx.py --model_path GFPGANv1.2.onnx --image_path ./cropped_faces/Adele_crop.png --save_path Adele_v2.jpg
# python demo_onnx.py --model_path GFPGANv1.2.onnx --image_path ./cropped_faces/Julia_Roberts_crop.png --save_path Julia_Roberts_v2.jpg
# python demo_onnx.py --model_path GFPGANv1.2.onnx --image_path ./cropped_faces/Justin_Timberlake_crop.png --save_path Justin_Timberlake_v2.jpg
# python demo_onnx.py --model_path GFPGANv1.2.onnx --image_path ./cropped_faces/Paris_Hilton_crop.png --save_path Paris_Hilton_v2.jpg

================================================
FILE: noise_main.py
================================================
import torch
# Pre-generated per-resolution noise maps for the StyleGAN decoder.
# ``new_empty(...).normal_()`` inside the model cannot be exported to ONNX,
# so the model reads these stored tensors instead (e.g. ``noise_dict[w]``).
#
# A dedicated, fixed-seed generator keeps the noise identical across imports
# and exports, and leaves torch's global RNG state untouched.
NOISE_SEED = 0

# Shapes (1, 1, r, r) for r = 4, 8, ..., 512.
size = [(1, 1, 2 ** k, 2 ** k) for k in range(2, 10)]

_generator = torch.Generator().manual_seed(NOISE_SEED)
# Keyed by spatial resolution r; each value is standard-normal float32 noise.
noise_dict = {s[2]: torch.randn(s, generator=_generator) for s in size}


================================================
FILE: torch2onnx.py
================================================
# -*- coding: utf-8 -*-

#import cv2
import numpy as np
import time
import torch
import pdb
from collections import OrderedDict

import sys
sys.path.append('.')
sys.path.append('./lib')
import torch.nn as nn
from torch.autograd import Variable
import onnxruntime
import timeit

import argparse
from GFPGANReconsitution import GFPGAN

parser = argparse.ArgumentParser("ONNX converter")
# Both paths are mandatory; a None default only fails later inside torch.
parser.add_argument('--src_model_path', type=str, required=True, help='src model path')
parser.add_argument('--dst_model_path', type=str, required=True, help='dst model path')
parser.add_argument('--img_size', type=int, default=512, help='img size')
args = parser.parse_args()

# The reconstructed GFPGAN graph hard-codes 512x512 feature-map shapes
# (e.g. ``out.view(1, 3, 512, 512)`` in its forward), so reject other sizes
# up front instead of failing mid-trace.
if args.img_size != 512:
    parser.error('--img_size must be 512: the reconstructed GFPGAN only supports 512x512 inputs')

#device = torch.device('cuda')
model_path = args.src_model_path
onnx_model_path = args.dst_model_path
img_size = args.img_size

model = GFPGAN()#.cuda()

# Dummy input; only used to trace the graph during export.
x = torch.rand(1, 3, img_size, img_size)#.cuda()

# map_location lets checkpoints saved on a GPU load on CPU-only machines.
state_dict = torch.load(model_path, map_location='cpu')['params_ema']
# The reconstructed model registers decoder tensors under flattened names in
# which '.' became 'dot'. Each decoder tensor is offered under two names:
#   - fully flattened, for parameters registered directly on the model
#     (e.g. ``...modulated_convdotweight``);
#   - flattened, but with '.weight'/'.bias' restored, for submodules such as
#     ``...modulated_convdotmodulation``.
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    # stylegan_decoderdotto_rgbsdot1dotmodulated_convdotbias
    if "stylegan_decoder" in k:
        k = k.replace('.', 'dot')
        new_state_dict[k] = v
        k = k.replace('dotweight', '.weight')
        k = k.replace('dotbias', '.bias')
        new_state_dict[k] = v
    else:
        new_state_dict[k] = v

# strict=False is needed because of the duplicated names above (typically
# only one of each pair matches). Still surface genuinely missing parameters
# rather than exporting a partially initialized model silently.
incompatible = model.load_state_dict(new_state_dict, strict=False)
if incompatible.missing_keys:
    print('warning: parameters missing from checkpoint:', incompatible.missing_keys)
model.eval()

torch.onnx.export(model, x, onnx_model_path,
                    export_params=True, opset_version=11, do_constant_folding=True,
                    input_names = ['input'],output_names = [])


#### Optional graph optimization.
# Needs the `onnx` package plus either `onnxoptimizer` or the legacy
# `onnx.optimizer` module (removed in onnx 1.9). Previously neither name was
# imported, so this step always failed silently.
try:
    import onnx
    try:
        import onnxoptimizer as optimizer
    except ImportError:
        from onnx import optimizer
except ImportError:
    print('skip optimize.')
else:
    try:
        original_model = onnx.load(onnx_model_path)
        passes = ['fuse_bn_into_conv']
        optimized_model = optimizer.optimize(original_model, passes)
        onnx.save(optimized_model, onnx_model_path)
    except Exception as err:
        # Best effort: the unoptimized export is already on disk.
        print('skip optimize:', err)

#### Smoke test: run the exported model once on an all-zero input.
# Explicit providers are required on multi-provider onnxruntime builds (>= 1.9).
ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=['CPUExecutionProvider'])
for var in ort_session.get_inputs():
    print(var.name)
for var in ort_session.get_outputs():
    print(var.name)
_,_,input_h,input_w = ort_session.get_inputs()[0].shape

img = np.zeros((1, 3, input_h, input_w), dtype=np.float32)

t = timeit.default_timer()
ort_inputs = {ort_session.get_inputs()[0].name: img}
ort_outs = ort_session.run(None, ort_inputs)

print('onnxruntime infer time:', timeit.default_timer()-t)
print(ort_outs[0].shape)

# python torch2onnx.py  --src_model_path ./experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth --dst_model_path ./GFPGAN.onnx --img_size 512 

# Newer release checkpoints:


# wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth

# python torch2onnx.py  --src_model_path ./GFPGANv1.4.pth --dst_model_path ./GFPGANv1.4.onnx --img_size 512 

# python torch2onnx.py  --src_model_path ./GFPGANCleanv1-NoCE-C2.pth --dst_model_path ./GFPGANv1.2.onnx --img_size 512 
Download .txt
gitextract_ll211iz8/

├── GFPGANReconsitution.py
├── README.md
├── demo_onnx.py
├── noise_main.py
└── torch2onnx.py
Download .txt
SYMBOL INDEX (14 symbols across 2 files)

FILE: GFPGANReconsitution.py
  class ResBlock (line 18) | class ResBlock(nn.Module):
    method __init__ (line 26) | def __init__(self, in_channels, out_channels, mode='down'):
    method forward (line 37) | def forward(self, x):
  class ConstantInput (line 49) | class ConstantInput(nn.Module):
    method __init__ (line 57) | def __init__(self, num_channel, size):
    method forward (line 61) | def forward(self, batch):
  class GFPGAN (line 66) | class GFPGAN(nn.Module):
    method __init__ (line 67) | def __init__(self):
    method forward (line 283) | def forward(self, x):

FILE: demo_onnx.py
  class GFPGANFaceAugment (line 11) | class GFPGANFaceAugment:
    method __init__ (line 12) | def __init__(self, model_path, use_gpu = False):
    method pre_process (line 22) | def pre_process(self, img):
    method post_process (line 34) | def post_process(self, output, height, width):
    method forward (line 68) | def forward(self, img):
Condensed preview — 5 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (48K chars).
[
  {
    "path": "GFPGANReconsitution.py",
    "chars": 36834,
    "preview": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\nimport os\nimport sys\nimport argparse\nimport cv2\nimport numpy as np\nimport "
  },
  {
    "path": "README.md",
    "chars": 1904,
    "preview": "# GFPGAN-onnxruntime-demo\nThis is the onnxruntime inference code for  GFP-GAN: Towards Real-World Blind Face Restoration"
  },
  {
    "path": "demo_onnx.py",
    "chars": 4812,
    "preview": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\nimport os\nimport sys\nimport argparse\nimport cv2\nimport numpy as np\nimport "
  },
  {
    "path": "noise_main.py",
    "chars": 316,
    "preview": "import torch\nnoise_dict = {}\nsize = [(1, 1, 4, 4),(1, 1, 8, 8),(1, 1, 16, 16),(1, 1, 32, 32),(1, 1, 64, 64),(1, 1, 128, "
  },
  {
    "path": "torch2onnx.py",
    "chars": 2871,
    "preview": "# -*- coding: utf-8 -*-\n\n#import cv2\nimport numpy as np\nimport time\nimport torch\nimport pdb\nfrom collections import Orde"
  }
]

About this extraction

This page contains the full source code of the xuanandsix/GFPGAN-onnxruntime-demo GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 5 files (45.6 KB), approximately 14.7k tokens, and a symbol index with 14 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!