master 33875eccf4eb cached
75 files
353.2 KB
90.5k tokens
447 symbols
1 requests
Download .txt
Showing preview only (374K chars total). Download the full file or copy to clipboard to get everything.
Repository: microsoft/Bringing-Old-Photos-Back-to-Life
Branch: master
Commit: 33875eccf4eb
Files: 75
Total size: 353.2 KB

Directory structure:
gitextract_ojvoon3f/

├── .gitignore
├── CODE_OF_CONDUCT.md
├── Dockerfile
├── Face_Detection/
│   ├── align_warp_back_multiple_dlib.py
│   ├── align_warp_back_multiple_dlib_HR.py
│   ├── detect_all_dlib.py
│   └── detect_all_dlib_HR.py
├── Face_Enhancement/
│   ├── data/
│   │   ├── __init__.py
│   │   ├── base_dataset.py
│   │   ├── custom_dataset.py
│   │   ├── face_dataset.py
│   │   ├── image_folder.py
│   │   └── pix2pix_dataset.py
│   ├── models/
│   │   ├── __init__.py
│   │   ├── networks/
│   │   │   ├── __init__.py
│   │   │   ├── architecture.py
│   │   │   ├── base_network.py
│   │   │   ├── encoder.py
│   │   │   ├── generator.py
│   │   │   └── normalization.py
│   │   └── pix2pix_model.py
│   ├── options/
│   │   ├── __init__.py
│   │   ├── base_options.py
│   │   └── test_options.py
│   ├── requirements.txt
│   ├── test_face.py
│   └── util/
│       ├── __init__.py
│       ├── iter_counter.py
│       ├── util.py
│       └── visualizer.py
├── GUI.py
├── Global/
│   ├── data/
│   │   ├── Create_Bigfile.py
│   │   ├── Load_Bigfile.py
│   │   ├── __init__.py
│   │   ├── base_data_loader.py
│   │   ├── base_dataset.py
│   │   ├── custom_dataset_data_loader.py
│   │   ├── data_loader.py
│   │   ├── image_folder.py
│   │   └── online_dataset_for_old_photos.py
│   ├── detection.py
│   ├── detection_models/
│   │   ├── __init__.py
│   │   ├── antialiasing.py
│   │   └── networks.py
│   ├── detection_util/
│   │   └── util.py
│   ├── models/
│   │   ├── NonLocal_feature_mapping_model.py
│   │   ├── __init__.py
│   │   ├── base_model.py
│   │   ├── mapping_model.py
│   │   ├── models.py
│   │   ├── networks.py
│   │   ├── pix2pixHD_model.py
│   │   └── pix2pixHD_model_DA.py
│   ├── options/
│   │   ├── __init__.py
│   │   ├── base_options.py
│   │   ├── test_options.py
│   │   └── train_options.py
│   ├── test.py
│   ├── train_domain_A.py
│   ├── train_domain_B.py
│   ├── train_mapping.py
│   └── util/
│       ├── __init__.py
│       ├── image_pool.py
│       ├── util.py
│       └── visualizer.py
├── LICENSE
├── README.md
├── SECURITY.md
├── ansible.yaml
├── cog.yaml
├── download-weights
├── kubernetes-pod.yml
├── predict.py
├── requirements.txt
└── run.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
__pycache__/
*.pyc
*~



================================================
FILE: CODE_OF_CONDUCT.md
================================================
# Microsoft Open Source Code of Conduct

This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).

Resources:

- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns


================================================
FILE: Dockerfile
================================================
FROM nvidia/cuda:11.1-base-ubuntu20.04

RUN apt update && DEBIAN_FRONTEND=noninteractive apt install git bzip2 wget unzip python3-pip python3-dev cmake libgl1-mesa-dev python-is-python3 libgtk2.0-dev -yq
ADD . /app
WORKDIR /app
RUN cd Face_Enhancement/models/networks/ &&\
  git clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch &&\
  cp -rf Synchronized-BatchNorm-PyTorch/sync_batchnorm . &&\
  cd ../../../

RUN cd Global/detection_models &&\
  git clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch &&\
  cp -rf Synchronized-BatchNorm-PyTorch/sync_batchnorm . &&\
  cd ../../

RUN cd Face_Detection/ &&\
  wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 &&\
  bzip2 -d shape_predictor_68_face_landmarks.dat.bz2 &&\
  cd ../ 

RUN cd Face_Enhancement/ &&\
  wget https://facevc.blob.core.windows.net/zhanbo/old_photo/pretrain/Face_Enhancement/checkpoints.zip &&\
  unzip checkpoints.zip &&\
  cd ../ &&\
  cd Global/ &&\
  wget https://facevc.blob.core.windows.net/zhanbo/old_photo/pretrain/Global/checkpoints.zip &&\
  unzip checkpoints.zip &&\
  rm -f checkpoints.zip &&\
  cd ../

RUN pip3 install numpy

RUN pip3 install dlib

RUN pip3 install -r requirements.txt

RUN git clone https://github.com/NVlabs/SPADE.git

RUN cd SPADE/ && pip3 install -r requirements.txt

RUN cd ..

CMD ["python3", "run.py"]


================================================
FILE: Face_Detection/align_warp_back_multiple_dlib.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
import numpy as np
import skimage.io as io

# from face_sdk import FaceDetection
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from skimage.transform import SimilarityTransform
from skimage.transform import warp
from PIL import Image, ImageFilter
import torch.nn.functional as F
import torchvision as tv
import torchvision.utils as vutils
import time
import cv2
import os
from skimage import img_as_ubyte
import json
import argparse
import dlib


def calculate_cdf(histogram):
    """
    This method calculates the cumulative distribution function
    :param array histogram: The values of the histogram
    :return: normalized_cdf: The normalized cumulative distribution function
    :rtype: array
    """
    # Get the cumulative sum of the elements
    cdf = histogram.cumsum()

    # Normalize the cdf
    normalized_cdf = cdf / float(cdf.max())

    return normalized_cdf


def calculate_lookup(src_cdf, ref_cdf):
    """
    This method creates the lookup table
    :param array src_cdf: The cdf for the source image
    :param array ref_cdf: The cdf for the reference image
    :return: lookup_table: The lookup table
    :rtype: array
    """
    lookup_table = np.zeros(256)
    lookup_val = 0
    for src_pixel_val in range(len(src_cdf)):
        lookup_val
        for ref_pixel_val in range(len(ref_cdf)):
            if ref_cdf[ref_pixel_val] >= src_cdf[src_pixel_val]:
                lookup_val = ref_pixel_val
                break
        lookup_table[src_pixel_val] = lookup_val
    return lookup_table


def match_histograms(src_image, ref_image):
    """
    This method matches the source image histogram to the
    reference signal
    :param image src_image: The original source image
    :param image  ref_image: The reference image
    :return: image_after_matching
    :rtype: image (array)
    """
    # Split the images into the different color channels
    # b means blue, g means green and r means red
    src_b, src_g, src_r = cv2.split(src_image)
    ref_b, ref_g, ref_r = cv2.split(ref_image)

    # Compute the b, g, and r histograms separately
    # The flatten() Numpy method returns a copy of the array c
    # collapsed into one dimension.
    src_hist_blue, bin_0 = np.histogram(src_b.flatten(), 256, [0, 256])
    src_hist_green, bin_1 = np.histogram(src_g.flatten(), 256, [0, 256])
    src_hist_red, bin_2 = np.histogram(src_r.flatten(), 256, [0, 256])
    ref_hist_blue, bin_3 = np.histogram(ref_b.flatten(), 256, [0, 256])
    ref_hist_green, bin_4 = np.histogram(ref_g.flatten(), 256, [0, 256])
    ref_hist_red, bin_5 = np.histogram(ref_r.flatten(), 256, [0, 256])

    # Compute the normalized cdf for the source and reference image
    src_cdf_blue = calculate_cdf(src_hist_blue)
    src_cdf_green = calculate_cdf(src_hist_green)
    src_cdf_red = calculate_cdf(src_hist_red)
    ref_cdf_blue = calculate_cdf(ref_hist_blue)
    ref_cdf_green = calculate_cdf(ref_hist_green)
    ref_cdf_red = calculate_cdf(ref_hist_red)

    # Make a separate lookup table for each color
    blue_lookup_table = calculate_lookup(src_cdf_blue, ref_cdf_blue)
    green_lookup_table = calculate_lookup(src_cdf_green, ref_cdf_green)
    red_lookup_table = calculate_lookup(src_cdf_red, ref_cdf_red)

    # Use the lookup function to transform the colors of the original
    # source image
    blue_after_transform = cv2.LUT(src_b, blue_lookup_table)
    green_after_transform = cv2.LUT(src_g, green_lookup_table)
    red_after_transform = cv2.LUT(src_r, red_lookup_table)

    # Put the image back together
    image_after_matching = cv2.merge([blue_after_transform, green_after_transform, red_after_transform])
    image_after_matching = cv2.convertScaleAbs(image_after_matching)

    return image_after_matching


def _standard_face_pts():
    pts = (
        np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32) / 256.0
        - 1.0
    )

    return np.reshape(pts, (5, 2))


def _origin_face_pts():
    pts = np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32)

    return np.reshape(pts, (5, 2))


def compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):

    std_pts = _standard_face_pts()  # [-1,1]
    target_pts = (std_pts * target_face_scale + 1) / 2 * 256.0

    # print(target_pts)

    h, w, c = img.shape
    if normalize == True:
        landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0
        landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0

    # print(landmark)

    affine = SimilarityTransform()

    affine.estimate(target_pts, landmark)

    return affine


def compute_inverse_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):

    std_pts = _standard_face_pts()  # [-1,1]
    target_pts = (std_pts * target_face_scale + 1) / 2 * 256.0

    # print(target_pts)

    h, w, c = img.shape
    if normalize == True:
        landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0
        landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0

    # print(landmark)

    affine = SimilarityTransform()

    affine.estimate(landmark, target_pts)

    return affine


def show_detection(image, box, landmark):
    plt.imshow(image)
    print(box[2] - box[0])
    plt.gca().add_patch(
        Rectangle(
            (box[1], box[0]), box[2] - box[0], box[3] - box[1], linewidth=1, edgecolor="r", facecolor="none"
        )
    )
    plt.scatter(landmark[0][0], landmark[0][1])
    plt.scatter(landmark[1][0], landmark[1][1])
    plt.scatter(landmark[2][0], landmark[2][1])
    plt.scatter(landmark[3][0], landmark[3][1])
    plt.scatter(landmark[4][0], landmark[4][1])
    plt.show()


def affine2theta(affine, input_w, input_h, target_w, target_h):
    # param = np.linalg.inv(affine)
    param = affine
    theta = np.zeros([2, 3])
    theta[0, 0] = param[0, 0] * input_h / target_h
    theta[0, 1] = param[0, 1] * input_w / target_h
    theta[0, 2] = (2 * param[0, 2] + param[0, 0] * input_h + param[0, 1] * input_w) / target_h - 1
    theta[1, 0] = param[1, 0] * input_h / target_w
    theta[1, 1] = param[1, 1] * input_w / target_w
    theta[1, 2] = (2 * param[1, 2] + param[1, 0] * input_h + param[1, 1] * input_w) / target_w - 1
    return theta


def blur_blending(im1, im2, mask):

    mask *= 255.0

    kernel = np.ones((10, 10), np.uint8)
    mask = cv2.erode(mask, kernel, iterations=1)

    mask = Image.fromarray(mask.astype("uint8")).convert("L")
    im1 = Image.fromarray(im1.astype("uint8"))
    im2 = Image.fromarray(im2.astype("uint8"))

    mask_blur = mask.filter(ImageFilter.GaussianBlur(20))
    im = Image.composite(im1, im2, mask)

    im = Image.composite(im, im2, mask_blur)

    return np.array(im) / 255.0


def blur_blending_cv2(im1, im2, mask):

    mask *= 255.0

    kernel = np.ones((9, 9), np.uint8)
    mask = cv2.erode(mask, kernel, iterations=3)

    mask_blur = cv2.GaussianBlur(mask, (25, 25), 0)
    mask_blur /= 255.0

    im = im1 * mask_blur + (1 - mask_blur) * im2

    im /= 255.0
    im = np.clip(im, 0.0, 1.0)

    return im


# def Poisson_blending(im1,im2,mask):


#     Image.composite(
def Poisson_blending(im1, im2, mask):

    # mask=1-mask
    mask *= 255
    kernel = np.ones((10, 10), np.uint8)
    mask = cv2.erode(mask, kernel, iterations=1)
    mask /= 255
    mask = 1 - mask
    mask *= 255

    mask = mask[:, :, 0]
    width, height, channels = im1.shape
    center = (int(height / 2), int(width / 2))
    result = cv2.seamlessClone(
        im2.astype("uint8"), im1.astype("uint8"), mask.astype("uint8"), center, cv2.MIXED_CLONE
    )

    return result / 255.0


def Poisson_B(im1, im2, mask, center):

    mask *= 255

    result = cv2.seamlessClone(
        im2.astype("uint8"), im1.astype("uint8"), mask.astype("uint8"), center, cv2.NORMAL_CLONE
    )

    return result / 255


def seamless_clone(old_face, new_face, raw_mask):

    height, width, _ = old_face.shape
    height = height // 2
    width = width // 2

    y_indices, x_indices, _ = np.nonzero(raw_mask)
    y_crop = slice(np.min(y_indices), np.max(y_indices))
    x_crop = slice(np.min(x_indices), np.max(x_indices))
    y_center = int(np.rint((np.max(y_indices) + np.min(y_indices)) / 2 + height))
    x_center = int(np.rint((np.max(x_indices) + np.min(x_indices)) / 2 + width))

    insertion = np.rint(new_face[y_crop, x_crop] * 255.0).astype("uint8")
    insertion_mask = np.rint(raw_mask[y_crop, x_crop] * 255.0).astype("uint8")
    insertion_mask[insertion_mask != 0] = 255
    prior = np.rint(np.pad(old_face * 255.0, ((height, height), (width, width), (0, 0)), "constant")).astype(
        "uint8"
    )
    # if np.sum(insertion_mask) == 0:
    n_mask = insertion_mask[1:-1, 1:-1, :]
    n_mask = cv2.copyMakeBorder(n_mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, 0)
    print(n_mask.shape)
    x, y, w, h = cv2.boundingRect(n_mask[:, :, 0])
    if w < 4 or h < 4:
        blended = prior
    else:
        blended = cv2.seamlessClone(
            insertion,  # pylint: disable=no-member
            prior,
            insertion_mask,
            (x_center, y_center),
            cv2.NORMAL_CLONE,
        )  # pylint: disable=no-member

    blended = blended[height:-height, width:-width]

    return blended.astype("float32") / 255.0


def get_landmark(face_landmarks, id):
    part = face_landmarks.part(id)
    x = part.x
    y = part.y

    return (x, y)


def search(face_landmarks):

    x1, y1 = get_landmark(face_landmarks, 36)
    x2, y2 = get_landmark(face_landmarks, 39)
    x3, y3 = get_landmark(face_landmarks, 42)
    x4, y4 = get_landmark(face_landmarks, 45)

    x_nose, y_nose = get_landmark(face_landmarks, 30)

    x_left_mouth, y_left_mouth = get_landmark(face_landmarks, 48)
    x_right_mouth, y_right_mouth = get_landmark(face_landmarks, 54)

    x_left_eye = int((x1 + x2) / 2)
    y_left_eye = int((y1 + y2) / 2)
    x_right_eye = int((x3 + x4) / 2)
    y_right_eye = int((y3 + y4) / 2)

    results = np.array(
        [
            [x_left_eye, y_left_eye],
            [x_right_eye, y_right_eye],
            [x_nose, y_nose],
            [x_left_mouth, y_left_mouth],
            [x_right_mouth, y_right_mouth],
        ]
    )

    return results


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("--origin_url", type=str, default="./", help="origin images")
    parser.add_argument("--replace_url", type=str, default="./", help="restored faces")
    parser.add_argument("--save_url", type=str, default="./save")
    opts = parser.parse_args()

    origin_url = opts.origin_url
    replace_url = opts.replace_url
    save_url = opts.save_url

    if not os.path.exists(save_url):
        os.makedirs(save_url)

    face_detector = dlib.get_frontal_face_detector()
    landmark_locator = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")

    count = 0

    for x in os.listdir(origin_url):
        img_url = os.path.join(origin_url, x)
        pil_img = Image.open(img_url).convert("RGB")

        origin_width, origin_height = pil_img.size
        image = np.array(pil_img)

        start = time.time()
        faces = face_detector(image)
        done = time.time()

        if len(faces) == 0:
            print("Warning: There is no face in %s" % (x))
            continue

        blended = image
        for face_id in range(len(faces)):

            current_face = faces[face_id]
            face_landmarks = landmark_locator(image, current_face)
            current_fl = search(face_landmarks)

            forward_mask = np.ones_like(image).astype("uint8")
            affine = compute_transformation_matrix(image, current_fl, False, target_face_scale=1.3)
            aligned_face = warp(image, affine, output_shape=(256, 256, 3), preserve_range=True)
            forward_mask = warp(
                forward_mask, affine, output_shape=(256, 256, 3), order=0, preserve_range=True
            )

            affine_inverse = affine.inverse
            cur_face = aligned_face
            if replace_url != "":

                face_name = x[:-4] + "_" + str(face_id + 1) + ".png"
                cur_url = os.path.join(replace_url, face_name)
                restored_face = Image.open(cur_url).convert("RGB")
                restored_face = np.array(restored_face)
                cur_face = restored_face

            ## Histogram Color matching
            A = cv2.cvtColor(aligned_face.astype("uint8"), cv2.COLOR_RGB2BGR)
            B = cv2.cvtColor(cur_face.astype("uint8"), cv2.COLOR_RGB2BGR)
            B = match_histograms(B, A)
            cur_face = cv2.cvtColor(B.astype("uint8"), cv2.COLOR_BGR2RGB)

            warped_back = warp(
                cur_face,
                affine_inverse,
                output_shape=(origin_height, origin_width, 3),
                order=3,
                preserve_range=True,
            )

            backward_mask = warp(
                forward_mask,
                affine_inverse,
                output_shape=(origin_height, origin_width, 3),
                order=0,
                preserve_range=True,
            )  ## Nearest neighbour

            blended = blur_blending_cv2(warped_back, blended, backward_mask)
            blended *= 255.0

        io.imsave(os.path.join(save_url, x), img_as_ubyte(blended / 255.0))

        count += 1

        if count % 1000 == 0:
            print("%d have finished ..." % (count))



================================================
FILE: Face_Detection/align_warp_back_multiple_dlib_HR.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
import numpy as np
import skimage.io as io

# from face_sdk import FaceDetection
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from skimage.transform import SimilarityTransform
from skimage.transform import warp
from PIL import Image, ImageFilter
import torch.nn.functional as F
import torchvision as tv
import torchvision.utils as vutils
import time
import cv2
import os
from skimage import img_as_ubyte
import json
import argparse
import dlib


def calculate_cdf(histogram):
    """
    This method calculates the cumulative distribution function
    :param array histogram: The values of the histogram
    :return: normalized_cdf: The normalized cumulative distribution function
    :rtype: array
    """
    # Get the cumulative sum of the elements
    cdf = histogram.cumsum()

    # Normalize the cdf
    normalized_cdf = cdf / float(cdf.max())

    return normalized_cdf


def calculate_lookup(src_cdf, ref_cdf):
    """
    This method creates the lookup table
    :param array src_cdf: The cdf for the source image
    :param array ref_cdf: The cdf for the reference image
    :return: lookup_table: The lookup table
    :rtype: array
    """
    lookup_table = np.zeros(256)
    lookup_val = 0
    for src_pixel_val in range(len(src_cdf)):
        lookup_val
        for ref_pixel_val in range(len(ref_cdf)):
            if ref_cdf[ref_pixel_val] >= src_cdf[src_pixel_val]:
                lookup_val = ref_pixel_val
                break
        lookup_table[src_pixel_val] = lookup_val
    return lookup_table


def match_histograms(src_image, ref_image):
    """
    This method matches the source image histogram to the
    reference signal
    :param image src_image: The original source image
    :param image  ref_image: The reference image
    :return: image_after_matching
    :rtype: image (array)
    """
    # Split the images into the different color channels
    # b means blue, g means green and r means red
    src_b, src_g, src_r = cv2.split(src_image)
    ref_b, ref_g, ref_r = cv2.split(ref_image)

    # Compute the b, g, and r histograms separately
    # The flatten() Numpy method returns a copy of the array c
    # collapsed into one dimension.
    src_hist_blue, bin_0 = np.histogram(src_b.flatten(), 256, [0, 256])
    src_hist_green, bin_1 = np.histogram(src_g.flatten(), 256, [0, 256])
    src_hist_red, bin_2 = np.histogram(src_r.flatten(), 256, [0, 256])
    ref_hist_blue, bin_3 = np.histogram(ref_b.flatten(), 256, [0, 256])
    ref_hist_green, bin_4 = np.histogram(ref_g.flatten(), 256, [0, 256])
    ref_hist_red, bin_5 = np.histogram(ref_r.flatten(), 256, [0, 256])

    # Compute the normalized cdf for the source and reference image
    src_cdf_blue = calculate_cdf(src_hist_blue)
    src_cdf_green = calculate_cdf(src_hist_green)
    src_cdf_red = calculate_cdf(src_hist_red)
    ref_cdf_blue = calculate_cdf(ref_hist_blue)
    ref_cdf_green = calculate_cdf(ref_hist_green)
    ref_cdf_red = calculate_cdf(ref_hist_red)

    # Make a separate lookup table for each color
    blue_lookup_table = calculate_lookup(src_cdf_blue, ref_cdf_blue)
    green_lookup_table = calculate_lookup(src_cdf_green, ref_cdf_green)
    red_lookup_table = calculate_lookup(src_cdf_red, ref_cdf_red)

    # Use the lookup function to transform the colors of the original
    # source image
    blue_after_transform = cv2.LUT(src_b, blue_lookup_table)
    green_after_transform = cv2.LUT(src_g, green_lookup_table)
    red_after_transform = cv2.LUT(src_r, red_lookup_table)

    # Put the image back together
    image_after_matching = cv2.merge([blue_after_transform, green_after_transform, red_after_transform])
    image_after_matching = cv2.convertScaleAbs(image_after_matching)

    return image_after_matching


def _standard_face_pts():
    pts = (
        np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32) / 256.0
        - 1.0
    )

    return np.reshape(pts, (5, 2))


def _origin_face_pts():
    pts = np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32)

    return np.reshape(pts, (5, 2))


def compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):

    std_pts = _standard_face_pts()  # [-1,1]
    target_pts = (std_pts * target_face_scale + 1) / 2 * 512.0

    # print(target_pts)

    h, w, c = img.shape
    if normalize == True:
        landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0
        landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0

    # print(landmark)

    affine = SimilarityTransform()

    affine.estimate(target_pts, landmark)

    return affine


def compute_inverse_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):

    std_pts = _standard_face_pts()  # [-1,1]
    target_pts = (std_pts * target_face_scale + 1) / 2 * 512.0

    # print(target_pts)

    h, w, c = img.shape
    if normalize == True:
        landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0
        landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0

    # print(landmark)

    affine = SimilarityTransform()

    affine.estimate(landmark, target_pts)

    return affine


def show_detection(image, box, landmark):
    plt.imshow(image)
    print(box[2] - box[0])
    plt.gca().add_patch(
        Rectangle(
            (box[1], box[0]), box[2] - box[0], box[3] - box[1], linewidth=1, edgecolor="r", facecolor="none"
        )
    )
    plt.scatter(landmark[0][0], landmark[0][1])
    plt.scatter(landmark[1][0], landmark[1][1])
    plt.scatter(landmark[2][0], landmark[2][1])
    plt.scatter(landmark[3][0], landmark[3][1])
    plt.scatter(landmark[4][0], landmark[4][1])
    plt.show()


def affine2theta(affine, input_w, input_h, target_w, target_h):
    # param = np.linalg.inv(affine)
    param = affine
    theta = np.zeros([2, 3])
    theta[0, 0] = param[0, 0] * input_h / target_h
    theta[0, 1] = param[0, 1] * input_w / target_h
    theta[0, 2] = (2 * param[0, 2] + param[0, 0] * input_h + param[0, 1] * input_w) / target_h - 1
    theta[1, 0] = param[1, 0] * input_h / target_w
    theta[1, 1] = param[1, 1] * input_w / target_w
    theta[1, 2] = (2 * param[1, 2] + param[1, 0] * input_h + param[1, 1] * input_w) / target_w - 1
    return theta


def blur_blending(im1, im2, mask):

    mask *= 255.0

    kernel = np.ones((10, 10), np.uint8)
    mask = cv2.erode(mask, kernel, iterations=1)

    mask = Image.fromarray(mask.astype("uint8")).convert("L")
    im1 = Image.fromarray(im1.astype("uint8"))
    im2 = Image.fromarray(im2.astype("uint8"))

    mask_blur = mask.filter(ImageFilter.GaussianBlur(20))
    im = Image.composite(im1, im2, mask)

    im = Image.composite(im, im2, mask_blur)

    return np.array(im) / 255.0


def blur_blending_cv2(im1, im2, mask):

    mask *= 255.0

    kernel = np.ones((9, 9), np.uint8)
    mask = cv2.erode(mask, kernel, iterations=3)

    mask_blur = cv2.GaussianBlur(mask, (25, 25), 0)
    mask_blur /= 255.0

    im = im1 * mask_blur + (1 - mask_blur) * im2

    im /= 255.0
    im = np.clip(im, 0.0, 1.0)

    return im


# def Poisson_blending(im1,im2,mask):


#     Image.composite(
def Poisson_blending(im1, im2, mask):

    # mask=1-mask
    mask *= 255
    kernel = np.ones((10, 10), np.uint8)
    mask = cv2.erode(mask, kernel, iterations=1)
    mask /= 255
    mask = 1 - mask
    mask *= 255

    mask = mask[:, :, 0]
    width, height, channels = im1.shape
    center = (int(height / 2), int(width / 2))
    result = cv2.seamlessClone(
        im2.astype("uint8"), im1.astype("uint8"), mask.astype("uint8"), center, cv2.MIXED_CLONE
    )

    return result / 255.0


def Poisson_B(im1, im2, mask, center):

    mask *= 255

    result = cv2.seamlessClone(
        im2.astype("uint8"), im1.astype("uint8"), mask.astype("uint8"), center, cv2.NORMAL_CLONE
    )

    return result / 255


def seamless_clone(old_face, new_face, raw_mask):

    height, width, _ = old_face.shape
    height = height // 2
    width = width // 2

    y_indices, x_indices, _ = np.nonzero(raw_mask)
    y_crop = slice(np.min(y_indices), np.max(y_indices))
    x_crop = slice(np.min(x_indices), np.max(x_indices))
    y_center = int(np.rint((np.max(y_indices) + np.min(y_indices)) / 2 + height))
    x_center = int(np.rint((np.max(x_indices) + np.min(x_indices)) / 2 + width))

    insertion = np.rint(new_face[y_crop, x_crop] * 255.0).astype("uint8")
    insertion_mask = np.rint(raw_mask[y_crop, x_crop] * 255.0).astype("uint8")
    insertion_mask[insertion_mask != 0] = 255
    prior = np.rint(np.pad(old_face * 255.0, ((height, height), (width, width), (0, 0)), "constant")).astype(
        "uint8"
    )
    # if np.sum(insertion_mask) == 0:
    n_mask = insertion_mask[1:-1, 1:-1, :]
    n_mask = cv2.copyMakeBorder(n_mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, 0)
    print(n_mask.shape)
    x, y, w, h = cv2.boundingRect(n_mask[:, :, 0])
    if w < 4 or h < 4:
        blended = prior
    else:
        blended = cv2.seamlessClone(
            insertion,  # pylint: disable=no-member
            prior,
            insertion_mask,
            (x_center, y_center),
            cv2.NORMAL_CLONE,
        )  # pylint: disable=no-member

    blended = blended[height:-height, width:-width]

    return blended.astype("float32") / 255.0


def get_landmark(face_landmarks, id):
    part = face_landmarks.part(id)
    x = part.x
    y = part.y

    return (x, y)


def search(face_landmarks):

    x1, y1 = get_landmark(face_landmarks, 36)
    x2, y2 = get_landmark(face_landmarks, 39)
    x3, y3 = get_landmark(face_landmarks, 42)
    x4, y4 = get_landmark(face_landmarks, 45)

    x_nose, y_nose = get_landmark(face_landmarks, 30)

    x_left_mouth, y_left_mouth = get_landmark(face_landmarks, 48)
    x_right_mouth, y_right_mouth = get_landmark(face_landmarks, 54)

    x_left_eye = int((x1 + x2) / 2)
    y_left_eye = int((y1 + y2) / 2)
    x_right_eye = int((x3 + x4) / 2)
    y_right_eye = int((y3 + y4) / 2)

    results = np.array(
        [
            [x_left_eye, y_left_eye],
            [x_right_eye, y_right_eye],
            [x_nose, y_nose],
            [x_left_mouth, y_left_mouth],
            [x_right_mouth, y_right_mouth],
        ]
    )

    return results


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("--origin_url", type=str, default="./", help="origin images")
    parser.add_argument("--replace_url", type=str, default="./", help="restored faces")
    parser.add_argument("--save_url", type=str, default="./save")
    opts = parser.parse_args()

    origin_url = opts.origin_url
    replace_url = opts.replace_url
    save_url = opts.save_url

    if not os.path.exists(save_url):
        os.makedirs(save_url)

    face_detector = dlib.get_frontal_face_detector()
    landmark_locator = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")

    count = 0

    for x in os.listdir(origin_url):
        img_url = os.path.join(origin_url, x)
        pil_img = Image.open(img_url).convert("RGB")

        origin_width, origin_height = pil_img.size
        image = np.array(pil_img)

        start = time.time()
        faces = face_detector(image)
        done = time.time()

        if len(faces) == 0:
            print("Warning: There is no face in %s" % (x))
            continue

        blended = image
        for face_id in range(len(faces)):

            current_face = faces[face_id]
            face_landmarks = landmark_locator(image, current_face)
            current_fl = search(face_landmarks)

            forward_mask = np.ones_like(image).astype("uint8")
            affine = compute_transformation_matrix(image, current_fl, False, target_face_scale=1.3)
            aligned_face = warp(image, affine, output_shape=(512, 512, 3), preserve_range=True)
            forward_mask = warp(
                forward_mask, affine, output_shape=(512, 512, 3), order=0, preserve_range=True
            )

            affine_inverse = affine.inverse
            cur_face = aligned_face
            if replace_url != "":

                face_name = x[:-4] + "_" + str(face_id + 1) + ".png"
                cur_url = os.path.join(replace_url, face_name)
                restored_face = Image.open(cur_url).convert("RGB")
                restored_face = np.array(restored_face)
                cur_face = restored_face

            ## Histogram Color matching
            A = cv2.cvtColor(aligned_face.astype("uint8"), cv2.COLOR_RGB2BGR)
            B = cv2.cvtColor(cur_face.astype("uint8"), cv2.COLOR_RGB2BGR)
            B = match_histograms(B, A)
            cur_face = cv2.cvtColor(B.astype("uint8"), cv2.COLOR_BGR2RGB)

            warped_back = warp(
                cur_face,
                affine_inverse,
                output_shape=(origin_height, origin_width, 3),
                order=3,
                preserve_range=True,
            )

            backward_mask = warp(
                forward_mask,
                affine_inverse,
                output_shape=(origin_height, origin_width, 3),
                order=0,
                preserve_range=True,
            )  ## Nearest neighbour

            blended = blur_blending_cv2(warped_back, blended, backward_mask)
            blended *= 255.0

        io.imsave(os.path.join(save_url, x), img_as_ubyte(blended / 255.0))

        count += 1

        if count % 1000 == 0:
            print("%d have finished ..." % (count))



================================================
FILE: Face_Detection/detect_all_dlib.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
import numpy as np
import skimage.io as io

# from FaceSDK.face_sdk import FaceDetection
# from face_sdk import FaceDetection
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from skimage.transform import SimilarityTransform
from skimage.transform import warp
from PIL import Image
import torch.nn.functional as F
import torchvision as tv
import torchvision.utils as vutils
import time
import cv2
import os
from skimage import img_as_ubyte
import json
import argparse
import dlib


def _standard_face_pts():
    pts = (
        np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32) / 256.0
        - 1.0
    )

    return np.reshape(pts, (5, 2))


def _origin_face_pts():
    pts = np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32)

    return np.reshape(pts, (5, 2))


def get_landmark(face_landmarks, id):
    part = face_landmarks.part(id)
    x = part.x
    y = part.y

    return (x, y)


def search(face_landmarks):

    x1, y1 = get_landmark(face_landmarks, 36)
    x2, y2 = get_landmark(face_landmarks, 39)
    x3, y3 = get_landmark(face_landmarks, 42)
    x4, y4 = get_landmark(face_landmarks, 45)

    x_nose, y_nose = get_landmark(face_landmarks, 30)

    x_left_mouth, y_left_mouth = get_landmark(face_landmarks, 48)
    x_right_mouth, y_right_mouth = get_landmark(face_landmarks, 54)

    x_left_eye = int((x1 + x2) / 2)
    y_left_eye = int((y1 + y2) / 2)
    x_right_eye = int((x3 + x4) / 2)
    y_right_eye = int((y3 + y4) / 2)

    results = np.array(
        [
            [x_left_eye, y_left_eye],
            [x_right_eye, y_right_eye],
            [x_nose, y_nose],
            [x_left_mouth, y_left_mouth],
            [x_right_mouth, y_right_mouth],
        ]
    )

    return results


def compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):

    std_pts = _standard_face_pts()  # [-1,1]
    target_pts = (std_pts * target_face_scale + 1) / 2 * 256.0

    # print(target_pts)

    h, w, c = img.shape
    if normalize == True:
        landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0
        landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0

    # print(landmark)

    affine = SimilarityTransform()

    affine.estimate(target_pts, landmark)

    return affine.params


def show_detection(image, box, landmark):
    plt.imshow(image)
    print(box[2] - box[0])
    plt.gca().add_patch(
        Rectangle(
            (box[1], box[0]), box[2] - box[0], box[3] - box[1], linewidth=1, edgecolor="r", facecolor="none"
        )
    )
    plt.scatter(landmark[0][0], landmark[0][1])
    plt.scatter(landmark[1][0], landmark[1][1])
    plt.scatter(landmark[2][0], landmark[2][1])
    plt.scatter(landmark[3][0], landmark[3][1])
    plt.scatter(landmark[4][0], landmark[4][1])
    plt.show()


def affine2theta(affine, input_w, input_h, target_w, target_h):
    # param = np.linalg.inv(affine)
    param = affine
    theta = np.zeros([2, 3])
    theta[0, 0] = param[0, 0] * input_h / target_h
    theta[0, 1] = param[0, 1] * input_w / target_h
    theta[0, 2] = (2 * param[0, 2] + param[0, 0] * input_h + param[0, 1] * input_w) / target_h - 1
    theta[1, 0] = param[1, 0] * input_h / target_w
    theta[1, 1] = param[1, 1] * input_w / target_w
    theta[1, 2] = (2 * param[1, 2] + param[1, 0] * input_h + param[1, 1] * input_w) / target_w - 1
    return theta


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("--url", type=str, default="/home/jingliao/ziyuwan/celebrities", help="input")
    parser.add_argument(
        "--save_url", type=str, default="/home/jingliao/ziyuwan/celebrities_detected_face_reid", help="output"
    )
    opts = parser.parse_args()

    url = opts.url
    save_url = opts.save_url

    ### If the origin url is None, then we don't need to reid the origin image

    os.makedirs(url, exist_ok=True)
    os.makedirs(save_url, exist_ok=True)

    face_detector = dlib.get_frontal_face_detector()
    landmark_locator = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")

    count = 0

    map_id = {}
    for x in os.listdir(url):
        img_url = os.path.join(url, x)
        pil_img = Image.open(img_url).convert("RGB")

        image = np.array(pil_img)

        start = time.time()
        faces = face_detector(image)
        done = time.time()

        if len(faces) == 0:
            print("Warning: There is no face in %s" % (x))
            continue

        print(len(faces))

        if len(faces) > 0:
            for face_id in range(len(faces)):
                current_face = faces[face_id]
                face_landmarks = landmark_locator(image, current_face)
                current_fl = search(face_landmarks)

                affine = compute_transformation_matrix(image, current_fl, False, target_face_scale=1.3)
                aligned_face = warp(image, affine, output_shape=(256, 256, 3))
                img_name = x[:-4] + "_" + str(face_id + 1)
                io.imsave(os.path.join(save_url, img_name + ".png"), img_as_ubyte(aligned_face))

        count += 1

        if count % 1000 == 0:
            print("%d have finished ..." % (count))



================================================
FILE: Face_Detection/detect_all_dlib_HR.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
import numpy as np
import skimage.io as io

# from FaceSDK.face_sdk import FaceDetection
# from face_sdk import FaceDetection
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from skimage.transform import SimilarityTransform
from skimage.transform import warp
from PIL import Image
import torch.nn.functional as F
import torchvision as tv
import torchvision.utils as vutils
import time
import cv2
import os
from skimage import img_as_ubyte
import json
import argparse
import dlib


def _standard_face_pts():
    pts = (
        np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32) / 256.0
        - 1.0
    )

    return np.reshape(pts, (5, 2))


def _origin_face_pts():
    pts = np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32)

    return np.reshape(pts, (5, 2))


def get_landmark(face_landmarks, id):
    part = face_landmarks.part(id)
    x = part.x
    y = part.y

    return (x, y)


def search(face_landmarks):

    x1, y1 = get_landmark(face_landmarks, 36)
    x2, y2 = get_landmark(face_landmarks, 39)
    x3, y3 = get_landmark(face_landmarks, 42)
    x4, y4 = get_landmark(face_landmarks, 45)

    x_nose, y_nose = get_landmark(face_landmarks, 30)

    x_left_mouth, y_left_mouth = get_landmark(face_landmarks, 48)
    x_right_mouth, y_right_mouth = get_landmark(face_landmarks, 54)

    x_left_eye = int((x1 + x2) / 2)
    y_left_eye = int((y1 + y2) / 2)
    x_right_eye = int((x3 + x4) / 2)
    y_right_eye = int((y3 + y4) / 2)

    results = np.array(
        [
            [x_left_eye, y_left_eye],
            [x_right_eye, y_right_eye],
            [x_nose, y_nose],
            [x_left_mouth, y_left_mouth],
            [x_right_mouth, y_right_mouth],
        ]
    )

    return results


def compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):

    std_pts = _standard_face_pts()  # [-1,1]
    target_pts = (std_pts * target_face_scale + 1) / 2 * 512.0

    # print(target_pts)

    h, w, c = img.shape
    if normalize == True:
        landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0
        landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0

    # print(landmark)

    affine = SimilarityTransform()

    affine.estimate(target_pts, landmark)

    return affine.params


def show_detection(image, box, landmark):
    plt.imshow(image)
    print(box[2] - box[0])
    plt.gca().add_patch(
        Rectangle(
            (box[1], box[0]), box[2] - box[0], box[3] - box[1], linewidth=1, edgecolor="r", facecolor="none"
        )
    )
    plt.scatter(landmark[0][0], landmark[0][1])
    plt.scatter(landmark[1][0], landmark[1][1])
    plt.scatter(landmark[2][0], landmark[2][1])
    plt.scatter(landmark[3][0], landmark[3][1])
    plt.scatter(landmark[4][0], landmark[4][1])
    plt.show()


def affine2theta(affine, input_w, input_h, target_w, target_h):
    # param = np.linalg.inv(affine)
    param = affine
    theta = np.zeros([2, 3])
    theta[0, 0] = param[0, 0] * input_h / target_h
    theta[0, 1] = param[0, 1] * input_w / target_h
    theta[0, 2] = (2 * param[0, 2] + param[0, 0] * input_h + param[0, 1] * input_w) / target_h - 1
    theta[1, 0] = param[1, 0] * input_h / target_w
    theta[1, 1] = param[1, 1] * input_w / target_w
    theta[1, 2] = (2 * param[1, 2] + param[1, 0] * input_h + param[1, 1] * input_w) / target_w - 1
    return theta


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("--url", type=str, default="/home/jingliao/ziyuwan/celebrities", help="input")
    parser.add_argument(
        "--save_url", type=str, default="/home/jingliao/ziyuwan/celebrities_detected_face_reid", help="output"
    )
    opts = parser.parse_args()

    url = opts.url
    save_url = opts.save_url

    ### If the origin url is None, then we don't need to reid the origin image

    os.makedirs(url, exist_ok=True)
    os.makedirs(save_url, exist_ok=True)

    face_detector = dlib.get_frontal_face_detector()
    landmark_locator = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")

    count = 0

    map_id = {}
    for x in os.listdir(url):
        img_url = os.path.join(url, x)
        pil_img = Image.open(img_url).convert("RGB")

        image = np.array(pil_img)

        start = time.time()
        faces = face_detector(image)
        done = time.time()

        if len(faces) == 0:
            print("Warning: There is no face in %s" % (x))
            continue

        print(len(faces))

        if len(faces) > 0:
            for face_id in range(len(faces)):
                current_face = faces[face_id]
                face_landmarks = landmark_locator(image, current_face)
                current_fl = search(face_landmarks)

                affine = compute_transformation_matrix(image, current_fl, False, target_face_scale=1.3)
                aligned_face = warp(image, affine, output_shape=(512, 512, 3))
                img_name = x[:-4] + "_" + str(face_id + 1)
                io.imsave(os.path.join(save_url, img_name + ".png"), img_as_ubyte(aligned_face))

        count += 1

        if count % 1000 == 0:
            print("%d have finished ..." % (count))



================================================
FILE: Face_Enhancement/data/__init__.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import importlib
import torch.utils.data
from data.base_dataset import BaseDataset
from data.face_dataset import FaceTestDataset


def create_dataloader(opt):

    instance = FaceTestDataset()
    instance.initialize(opt)
    print("dataset [%s] of size %d was created" % (type(instance).__name__, len(instance)))
    dataloader = torch.utils.data.DataLoader(
        instance,
        batch_size=opt.batchSize,
        shuffle=not opt.serial_batches,
        num_workers=int(opt.nThreads),
        drop_last=opt.isTrain,
    )
    return dataloader


================================================
FILE: Face_Enhancement/data/base_dataset.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import random


class BaseDataset(data.Dataset):
    def __init__(self):
        super(BaseDataset, self).__init__()

    @staticmethod
    def modify_commandline_options(parser, is_train):
        return parser

    def initialize(self, opt):
        pass


def get_params(opt, size):
    w, h = size
    new_h = h
    new_w = w
    if opt.preprocess_mode == "resize_and_crop":
        new_h = new_w = opt.load_size
    elif opt.preprocess_mode == "scale_width_and_crop":
        new_w = opt.load_size
        new_h = opt.load_size * h // w
    elif opt.preprocess_mode == "scale_shortside_and_crop":
        ss, ls = min(w, h), max(w, h)  # shortside and longside
        width_is_shorter = w == ss
        ls = int(opt.load_size * ls / ss)
        new_w, new_h = (ss, ls) if width_is_shorter else (ls, ss)

    x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
    y = random.randint(0, np.maximum(0, new_h - opt.crop_size))

    flip = random.random() > 0.5
    return {"crop_pos": (x, y), "flip": flip}


def get_transform(opt, params, method=Image.BICUBIC, normalize=True, toTensor=True):
    transform_list = []
    if "resize" in opt.preprocess_mode:
        osize = [opt.load_size, opt.load_size]
        transform_list.append(transforms.Resize(osize, interpolation=method))
    elif "scale_width" in opt.preprocess_mode:
        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))
    elif "scale_shortside" in opt.preprocess_mode:
        transform_list.append(transforms.Lambda(lambda img: __scale_shortside(img, opt.load_size, method)))

    if "crop" in opt.preprocess_mode:
        transform_list.append(transforms.Lambda(lambda img: __crop(img, params["crop_pos"], opt.crop_size)))

    if opt.preprocess_mode == "none":
        base = 32
        transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))

    if opt.preprocess_mode == "fixed":
        w = opt.crop_size
        h = round(opt.crop_size / opt.aspect_ratio)
        transform_list.append(transforms.Lambda(lambda img: __resize(img, w, h, method)))

    if opt.isTrain and not opt.no_flip:
        transform_list.append(transforms.Lambda(lambda img: __flip(img, params["flip"])))

    if toTensor:
        transform_list += [transforms.ToTensor()]

    if normalize:
        transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)


def normalize():
    return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))


def __resize(img, w, h, method=Image.BICUBIC):
    return img.resize((w, h), method)


def __make_power_2(img, base, method=Image.BICUBIC):
    ow, oh = img.size
    h = int(round(oh / base) * base)
    w = int(round(ow / base) * base)
    if (h == oh) and (w == ow):
        return img
    return img.resize((w, h), method)


def __scale_width(img, target_width, method=Image.BICUBIC):
    ow, oh = img.size
    if ow == target_width:
        return img
    w = target_width
    h = int(target_width * oh / ow)
    return img.resize((w, h), method)


def __scale_shortside(img, target_width, method=Image.BICUBIC):
    ow, oh = img.size
    ss, ls = min(ow, oh), max(ow, oh)  # shortside and longside
    width_is_shorter = ow == ss
    if ss == target_width:
        return img
    ls = int(target_width * ls / ss)
    nw, nh = (ss, ls) if width_is_shorter else (ls, ss)
    return img.resize((nw, nh), method)


def __crop(img, pos, size):
    ow, oh = img.size
    x1, y1 = pos
    tw = th = size
    return img.crop((x1, y1, x1 + tw, y1 + th))


def __flip(img, flip):
    if flip:
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img


================================================
FILE: Face_Enhancement/data/custom_dataset.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from data.pix2pix_dataset import Pix2pixDataset
from data.image_folder import make_dataset


class CustomDataset(Pix2pixDataset):
    """ Dataset that loads images from directories
        Use option --label_dir, --image_dir, --instance_dir to specify the directories.
        The images in the directories are sorted in alphabetical order and paired in order.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser = Pix2pixDataset.modify_commandline_options(parser, is_train)
        parser.set_defaults(preprocess_mode="resize_and_crop")
        load_size = 286 if is_train else 256
        parser.set_defaults(load_size=load_size)
        parser.set_defaults(crop_size=256)
        parser.set_defaults(display_winsize=256)
        parser.set_defaults(label_nc=13)
        parser.set_defaults(contain_dontcare_label=False)

        parser.add_argument(
            "--label_dir", type=str, required=True, help="path to the directory that contains label images"
        )
        parser.add_argument(
            "--image_dir", type=str, required=True, help="path to the directory that contains photo images"
        )
        parser.add_argument(
            "--instance_dir",
            type=str,
            default="",
            help="path to the directory that contains instance maps. Leave black if not exists",
        )
        return parser

    def get_paths(self, opt):
        label_dir = opt.label_dir
        label_paths = make_dataset(label_dir, recursive=False, read_cache=True)

        image_dir = opt.image_dir
        image_paths = make_dataset(image_dir, recursive=False, read_cache=True)

        if len(opt.instance_dir) > 0:
            instance_dir = opt.instance_dir
            instance_paths = make_dataset(instance_dir, recursive=False, read_cache=True)
        else:
            instance_paths = []

        assert len(label_paths) == len(
            image_paths
        ), "The #images in %s and %s do not match. Is there something wrong?"

        return label_paths, image_paths, instance_paths


================================================
FILE: Face_Enhancement/data/face_dataset.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from data.base_dataset import BaseDataset, get_params, get_transform
from PIL import Image
import util.util as util
import os
import torch


class FaceTestDataset(BaseDataset):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser.add_argument(
            "--no_pairing_check",
            action="store_true",
            help="If specified, skip sanity check of correct label-image file pairing",
        )
        #    parser.set_defaults(contain_dontcare_label=False)
        #    parser.set_defaults(no_instance=True)
        return parser

    def initialize(self, opt):
        self.opt = opt

        image_path = os.path.join(opt.dataroot, opt.old_face_folder)
        label_path = os.path.join(opt.dataroot, opt.old_face_label_folder)

        image_list = os.listdir(image_path)
        image_list = sorted(image_list)
        # image_list=image_list[:opt.max_dataset_size]

        self.label_paths = label_path  ## Just the root dir
        self.image_paths = image_list  ## All the image name

        self.parts = [
            "skin",
            "hair",
            "l_brow",
            "r_brow",
            "l_eye",
            "r_eye",
            "eye_g",
            "l_ear",
            "r_ear",
            "ear_r",
            "nose",
            "mouth",
            "u_lip",
            "l_lip",
            "neck",
            "neck_l",
            "cloth",
            "hat",
        ]

        size = len(self.image_paths)
        self.dataset_size = size

    def __getitem__(self, index):

        params = get_params(self.opt, (-1, -1))
        image_name = self.image_paths[index]
        image_path = os.path.join(self.opt.dataroot, self.opt.old_face_folder, image_name)
        image = Image.open(image_path)
        image = image.convert("RGB")

        transform_image = get_transform(self.opt, params)
        image_tensor = transform_image(image)

        img_name = image_name[:-4]
        transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
        full_label = []

        cnt = 0

        for each_part in self.parts:
            part_name = img_name + "_" + each_part + ".png"
            part_url = os.path.join(self.label_paths, part_name)

            if os.path.exists(part_url):
                label = Image.open(part_url).convert("RGB")
                label_tensor = transform_label(label)  ## 3 channels and pixel [0,1]
                full_label.append(label_tensor[0])
            else:
                current_part = torch.zeros((self.opt.load_size, self.opt.load_size))
                full_label.append(current_part)
                cnt += 1

        full_label_tensor = torch.stack(full_label, 0)

        input_dict = {
            "label": full_label_tensor,
            "image": image_tensor,
            "path": image_path,
        }

        return input_dict

    def __len__(self):
        return self.dataset_size



================================================
FILE: Face_Enhancement/data/image_folder.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch.utils.data as data
from PIL import Image
import os

IMG_EXTENSIONS = [
    ".jpg",
    ".JPG",
    ".jpeg",
    ".JPEG",
    ".png",
    ".PNG",
    ".ppm",
    ".PPM",
    ".bmp",
    ".BMP",
    ".tiff",
    ".webp",
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset_rec(dir, images):
    assert os.path.isdir(dir), "%s is not a valid directory" % dir

    for root, dnames, fnames in sorted(os.walk(dir, followlinks=True)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)


def make_dataset(dir, recursive=False, read_cache=False, write_cache=False):
    images = []

    if read_cache:
        possible_filelist = os.path.join(dir, "files.list")
        if os.path.isfile(possible_filelist):
            with open(possible_filelist, "r") as f:
                images = f.read().splitlines()
                return images

    if recursive:
        make_dataset_rec(dir, images)
    else:
        assert os.path.isdir(dir) or os.path.islink(dir), "%s is not a valid directory" % dir

        for root, dnames, fnames in sorted(os.walk(dir)):
            for fname in fnames:
                if is_image_file(fname):
                    path = os.path.join(root, fname)
                    images.append(path)

    if write_cache:
        filelist_cache = os.path.join(dir, "files.list")
        with open(filelist_cache, "w") as f:
            for path in images:
                f.write("%s\n" % path)
            print("wrote filelist cache at %s" % filelist_cache)

    return images


def default_loader(path):
    return Image.open(path).convert("RGB")


class ImageFolder(data.Dataset):
    def __init__(self, root, transform=None, return_paths=False, loader=default_loader):
        imgs = make_dataset(root)
        if len(imgs) == 0:
            raise (
                RuntimeError(
                    "Found 0 images in: " + root + "\n"
                    "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)
                )
            )

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.return_paths:
            return img, path
        else:
            return img

    def __len__(self):
        return len(self.imgs)


================================================
FILE: Face_Enhancement/data/pix2pix_dataset.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from data.base_dataset import BaseDataset, get_params, get_transform
from PIL import Image
import util.util as util
import os


class Pix2pixDataset(BaseDataset):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser.add_argument(
            "--no_pairing_check",
            action="store_true",
            help="If specified, skip sanity check of correct label-image file pairing",
        )
        return parser

    def initialize(self, opt):
        self.opt = opt

        label_paths, image_paths, instance_paths = self.get_paths(opt)

        util.natural_sort(label_paths)
        util.natural_sort(image_paths)
        if not opt.no_instance:
            util.natural_sort(instance_paths)

        label_paths = label_paths[: opt.max_dataset_size]
        image_paths = image_paths[: opt.max_dataset_size]
        instance_paths = instance_paths[: opt.max_dataset_size]

        if not opt.no_pairing_check:
            for path1, path2 in zip(label_paths, image_paths):
                assert self.paths_match(path1, path2), (
                    "The label-image pair (%s, %s) do not look like the right pair because the filenames are quite different. Are you sure about the pairing? Please see data/pix2pix_dataset.py to see what is going on, and use --no_pairing_check to bypass this."
                    % (path1, path2)
                )

        self.label_paths = label_paths
        self.image_paths = image_paths
        self.instance_paths = instance_paths

        size = len(self.label_paths)
        self.dataset_size = size

    def get_paths(self, opt):
        label_paths = []
        image_paths = []
        instance_paths = []
        assert False, "A subclass of Pix2pixDataset must override self.get_paths(self, opt)"
        return label_paths, image_paths, instance_paths

    def paths_match(self, path1, path2):
        filename1_without_ext = os.path.splitext(os.path.basename(path1))[0]
        filename2_without_ext = os.path.splitext(os.path.basename(path2))[0]
        return filename1_without_ext == filename2_without_ext

    def __getitem__(self, index):
        # Label Image
        label_path = self.label_paths[index]
        label = Image.open(label_path)
        params = get_params(self.opt, label.size)
        transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
        label_tensor = transform_label(label) * 255.0
        label_tensor[label_tensor == 255] = self.opt.label_nc  # 'unknown' is opt.label_nc

        # input image (real images)
        image_path = self.image_paths[index]
        assert self.paths_match(
            label_path, image_path
        ), "The label_path %s and image_path %s don't match." % (label_path, image_path)
        image = Image.open(image_path)
        image = image.convert("RGB")

        transform_image = get_transform(self.opt, params)
        image_tensor = transform_image(image)

        # if using instance maps
        if self.opt.no_instance:
            instance_tensor = 0
        else:
            instance_path = self.instance_paths[index]
            instance = Image.open(instance_path)
            if instance.mode == "L":
                instance_tensor = transform_label(instance) * 255
                instance_tensor = instance_tensor.long()
            else:
                instance_tensor = transform_label(instance)

        input_dict = {
            "label": label_tensor,
            "instance": instance_tensor,
            "image": image_tensor,
            "path": image_path,
        }

        # Give subclasses a chance to modify the final output
        self.postprocess(input_dict)

        return input_dict

    def postprocess(self, input_dict):
        return input_dict

    def __len__(self):
        return self.dataset_size


================================================
FILE: Face_Enhancement/models/__init__.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import importlib
import torch


def find_model_using_name(model_name):
    # Given the option --model [modelname],
    # the file "models/modelname_model.py"
    # will be imported.
    model_filename = "models." + model_name + "_model"
    modellib = importlib.import_module(model_filename)

    # In the file, the class called ModelNameModel() will
    # be instantiated. It has to be a subclass of torch.nn.Module,
    # and it is case-insensitive.
    model = None
    target_model_name = model_name.replace("_", "") + "model"
    for name, cls in modellib.__dict__.items():
        if name.lower() == target_model_name.lower() and issubclass(cls, torch.nn.Module):
            model = cls

    if model is None:
        print(
            "In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s in lowercase."
            % (model_filename, target_model_name)
        )
        exit(0)

    return model


def get_option_setter(model_name):
    model_class = find_model_using_name(model_name)
    return model_class.modify_commandline_options


def create_model(opt):
    model = find_model_using_name(opt.model)
    instance = model(opt)
    print("model [%s] was created" % (type(instance).__name__))

    return instance


================================================
FILE: Face_Enhancement/models/networks/__init__.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
from models.networks.base_network import BaseNetwork
from models.networks.generator import *
from models.networks.encoder import *
import util.util as util


def find_network_using_name(target_network_name, filename):
    target_class_name = target_network_name + filename
    module_name = "models.networks." + filename
    network = util.find_class_in_module(target_class_name, module_name)

    assert issubclass(network, BaseNetwork), "Class %s should be a subclass of BaseNetwork" % network

    return network


def modify_commandline_options(parser, is_train):
    opt, _ = parser.parse_known_args()

    netG_cls = find_network_using_name(opt.netG, "generator")
    parser = netG_cls.modify_commandline_options(parser, is_train)
    if is_train:
        netD_cls = find_network_using_name(opt.netD, "discriminator")
        parser = netD_cls.modify_commandline_options(parser, is_train)
    netE_cls = find_network_using_name("conv", "encoder")
    parser = netE_cls.modify_commandline_options(parser, is_train)

    return parser


def create_network(cls, opt):
    net = cls(opt)
    net.print_network()
    if len(opt.gpu_ids) > 0:
        assert torch.cuda.is_available()
        net.cuda()
    net.init_weights(opt.init_type, opt.init_variance)
    return net


def define_G(opt):
    netG_cls = find_network_using_name(opt.netG, "generator")
    return create_network(netG_cls, opt)


def define_D(opt):
    netD_cls = find_network_using_name(opt.netD, "discriminator")
    return create_network(netD_cls, opt)


def define_E(opt):
    # there exists only one encoder type
    netE_cls = find_network_using_name("conv", "encoder")
    return create_network(netE_cls, opt)


================================================
FILE: Face_Enhancement/models/networks/architecture.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.nn.utils.spectral_norm as spectral_norm
from models.networks.normalization import SPADE


# ResNet block that uses SPADE.
# It differs from the ResNet block of pix2pixHD in that
# it takes in the segmentation map as input, learns the skip connection if necessary,
# and applies normalization first and then convolution.
# This architecture seemed like a standard architecture for unconditional or
# class-conditional GAN architecture using residual block.
# The code was inspired from https://github.com/LMescheder/GAN_stability.
class SPADEResnetBlock(nn.Module):
    def __init__(self, fin, fout, opt):
        super().__init__()
        # Attributes
        self.learned_shortcut = fin != fout
        fmiddle = min(fin, fout)

        self.opt = opt
        # create conv layers
        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)

        # apply spectral norm if specified
        if "spectral" in opt.norm_G:
            self.conv_0 = spectral_norm(self.conv_0)
            self.conv_1 = spectral_norm(self.conv_1)
            if self.learned_shortcut:
                self.conv_s = spectral_norm(self.conv_s)

        # define normalization layers
        spade_config_str = opt.norm_G.replace("spectral", "")
        self.norm_0 = SPADE(spade_config_str, fin, opt.semantic_nc, opt)
        self.norm_1 = SPADE(spade_config_str, fmiddle, opt.semantic_nc, opt)
        if self.learned_shortcut:
            self.norm_s = SPADE(spade_config_str, fin, opt.semantic_nc, opt)

    # note the resnet block with SPADE also takes in |seg|,
    # the semantic segmentation map as input
    def forward(self, x, seg, degraded_image):
        x_s = self.shortcut(x, seg, degraded_image)

        dx = self.conv_0(self.actvn(self.norm_0(x, seg, degraded_image)))
        dx = self.conv_1(self.actvn(self.norm_1(dx, seg, degraded_image)))

        out = x_s + dx

        return out

    def shortcut(self, x, seg, degraded_image):
        if self.learned_shortcut:
            x_s = self.conv_s(self.norm_s(x, seg, degraded_image))
        else:
            x_s = x
        return x_s

    def actvn(self, x):
        return F.leaky_relu(x, 2e-1)


# ResNet block used in pix2pixHD
# We keep the same architecture as pix2pixHD.
class ResnetBlock(nn.Module):
    def __init__(self, dim, norm_layer, activation=nn.ReLU(False), kernel_size=3):
        super().__init__()

        pw = (kernel_size - 1) // 2
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(pw),
            norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size)),
            activation,
            nn.ReflectionPad2d(pw),
            norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size)),
        )

    def forward(self, x):
        y = self.conv_block(x)
        out = x + y
        return out


# VGG architecter, used for the perceptual loss using a pretrained VGG network
class VGG19(torch.nn.Module):
    def __init__(self, requires_grad=False):
        super().__init__()
        vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        for x in range(2):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out


class SPADEResnetBlock_non_spade(nn.Module):
    def __init__(self, fin, fout, opt):
        super().__init__()
        # Attributes
        self.learned_shortcut = fin != fout
        fmiddle = min(fin, fout)

        self.opt = opt
        # create conv layers
        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)

        # apply spectral norm if specified
        if "spectral" in opt.norm_G:
            self.conv_0 = spectral_norm(self.conv_0)
            self.conv_1 = spectral_norm(self.conv_1)
            if self.learned_shortcut:
                self.conv_s = spectral_norm(self.conv_s)

        # define normalization layers
        spade_config_str = opt.norm_G.replace("spectral", "")
        self.norm_0 = SPADE(spade_config_str, fin, opt.semantic_nc, opt)
        self.norm_1 = SPADE(spade_config_str, fmiddle, opt.semantic_nc, opt)
        if self.learned_shortcut:
            self.norm_s = SPADE(spade_config_str, fin, opt.semantic_nc, opt)

    # note the resnet block with SPADE also takes in |seg|,
    # the semantic segmentation map as input
    def forward(self, x, seg, degraded_image):
        x_s = self.shortcut(x, seg, degraded_image)

        dx = self.conv_0(self.actvn(x))
        dx = self.conv_1(self.actvn(dx))

        out = x_s + dx

        return out

    def shortcut(self, x, seg, degraded_image):
        if self.learned_shortcut:
            x_s = self.conv_s(x)
        else:
            x_s = x
        return x_s

    def actvn(self, x):
        return F.leaky_relu(x, 2e-1)


================================================
FILE: Face_Enhancement/models/networks/base_network.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch.nn as nn
from torch.nn import init


class BaseNetwork(nn.Module):
    def __init__(self):
        super(BaseNetwork, self).__init__()

    @staticmethod
    def modify_commandline_options(parser, is_train):
        return parser

    def print_network(self):
        if isinstance(self, list):
            self = self[0]
        num_params = 0
        for param in self.parameters():
            num_params += param.numel()
        print(
            "Network [%s] was created. Total number of parameters: %.1f million. "
            "To see the architecture, do print(network)." % (type(self).__name__, num_params / 1000000)
        )

    def init_weights(self, init_type="normal", gain=0.02):
        def init_func(m):
            classname = m.__class__.__name__
            if classname.find("BatchNorm2d") != -1:
                if hasattr(m, "weight") and m.weight is not None:
                    init.normal_(m.weight.data, 1.0, gain)
                if hasattr(m, "bias") and m.bias is not None:
                    init.constant_(m.bias.data, 0.0)
            elif hasattr(m, "weight") and (classname.find("Conv") != -1 or classname.find("Linear") != -1):
                if init_type == "normal":
                    init.normal_(m.weight.data, 0.0, gain)
                elif init_type == "xavier":
                    init.xavier_normal_(m.weight.data, gain=gain)
                elif init_type == "xavier_uniform":
                    init.xavier_uniform_(m.weight.data, gain=1.0)
                elif init_type == "kaiming":
                    init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
                elif init_type == "orthogonal":
                    init.orthogonal_(m.weight.data, gain=gain)
                elif init_type == "none":  # uses pytorch's default init method
                    m.reset_parameters()
                else:
                    raise NotImplementedError("initialization method [%s] is not implemented" % init_type)
                if hasattr(m, "bias") and m.bias is not None:
                    init.constant_(m.bias.data, 0.0)

        self.apply(init_func)

        # propagate to children
        for m in self.children():
            if hasattr(m, "init_weights"):
                m.init_weights(init_type, gain)


================================================
FILE: Face_Enhancement/models/networks/encoder.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from models.networks.base_network import BaseNetwork
from models.networks.normalization import get_nonspade_norm_layer


class ConvEncoder(BaseNetwork):
    """ Same architecture as the image discriminator """

    def __init__(self, opt):
        super().__init__()

        kw = 3
        pw = int(np.ceil((kw - 1.0) / 2))
        ndf = opt.ngf
        norm_layer = get_nonspade_norm_layer(opt, opt.norm_E)
        self.layer1 = norm_layer(nn.Conv2d(3, ndf, kw, stride=2, padding=pw))
        self.layer2 = norm_layer(nn.Conv2d(ndf * 1, ndf * 2, kw, stride=2, padding=pw))
        self.layer3 = norm_layer(nn.Conv2d(ndf * 2, ndf * 4, kw, stride=2, padding=pw))
        self.layer4 = norm_layer(nn.Conv2d(ndf * 4, ndf * 8, kw, stride=2, padding=pw))
        self.layer5 = norm_layer(nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw))
        if opt.crop_size >= 256:
            self.layer6 = norm_layer(nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw))

        self.so = s0 = 4
        self.fc_mu = nn.Linear(ndf * 8 * s0 * s0, 256)
        self.fc_var = nn.Linear(ndf * 8 * s0 * s0, 256)

        self.actvn = nn.LeakyReLU(0.2, False)
        self.opt = opt

    def forward(self, x):
        if x.size(2) != 256 or x.size(3) != 256:
            x = F.interpolate(x, size=(256, 256), mode="bilinear")

        x = self.layer1(x)
        x = self.layer2(self.actvn(x))
        x = self.layer3(self.actvn(x))
        x = self.layer4(self.actvn(x))
        x = self.layer5(self.actvn(x))
        if self.opt.crop_size >= 256:
            x = self.layer6(self.actvn(x))
        x = self.actvn(x)

        x = x.view(x.size(0), -1)
        mu = self.fc_mu(x)
        logvar = self.fc_var(x)

        return mu, logvar


================================================
FILE: Face_Enhancement/models/networks/generator.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
import torch.nn as nn
import torch.nn.functional as F
from models.networks.base_network import BaseNetwork
from models.networks.normalization import get_nonspade_norm_layer
from models.networks.architecture import ResnetBlock as ResnetBlock
from models.networks.architecture import SPADEResnetBlock as SPADEResnetBlock
from models.networks.architecture import SPADEResnetBlock_non_spade as SPADEResnetBlock_non_spade


class SPADEGenerator(BaseNetwork):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser.set_defaults(norm_G="spectralspadesyncbatch3x3")
        parser.add_argument(
            "--num_upsampling_layers",
            choices=("normal", "more", "most"),
            default="normal",
            help="If 'more', adds upsampling layer between the two middle resnet blocks. If 'most', also add one more upsampling + resnet layer at the end of the generator",
        )

        return parser

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        nf = opt.ngf

        self.sw, self.sh = self.compute_latent_vector_size(opt)

        print("The size of the latent vector size is [%d,%d]" % (self.sw, self.sh))

        if opt.use_vae:
            # In case of VAE, we will sample from random z vector
            self.fc = nn.Linear(opt.z_dim, 16 * nf * self.sw * self.sh)
        else:
            # Otherwise, we make the network deterministic by starting with
            # downsampled segmentation map instead of random z
            if self.opt.no_parsing_map:
                self.fc = nn.Conv2d(3, 16 * nf, 3, padding=1)
            else:
                self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1)

        if self.opt.injection_layer == "all" or self.opt.injection_layer == "1":
            self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt)
        else:
            self.head_0 = SPADEResnetBlock_non_spade(16 * nf, 16 * nf, opt)

        if self.opt.injection_layer == "all" or self.opt.injection_layer == "2":
            self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt)
            self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt)

        else:
            self.G_middle_0 = SPADEResnetBlock_non_spade(16 * nf, 16 * nf, opt)
            self.G_middle_1 = SPADEResnetBlock_non_spade(16 * nf, 16 * nf, opt)

        if self.opt.injection_layer == "all" or self.opt.injection_layer == "3":
            self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf, opt)
        else:
            self.up_0 = SPADEResnetBlock_non_spade(16 * nf, 8 * nf, opt)

        if self.opt.injection_layer == "all" or self.opt.injection_layer == "4":
            self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf, opt)
        else:
            self.up_1 = SPADEResnetBlock_non_spade(8 * nf, 4 * nf, opt)

        if self.opt.injection_layer == "all" or self.opt.injection_layer == "5":
            self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf, opt)
        else:
            self.up_2 = SPADEResnetBlock_non_spade(4 * nf, 2 * nf, opt)

        if self.opt.injection_layer == "all" or self.opt.injection_layer == "6":
            self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf, opt)
        else:
            self.up_3 = SPADEResnetBlock_non_spade(2 * nf, 1 * nf, opt)

        final_nc = nf

        if opt.num_upsampling_layers == "most":
            self.up_4 = SPADEResnetBlock(1 * nf, nf // 2, opt)
            final_nc = nf // 2

        self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)

        self.up = nn.Upsample(scale_factor=2)

    def compute_latent_vector_size(self, opt):
        if opt.num_upsampling_layers == "normal":
            num_up_layers = 5
        elif opt.num_upsampling_layers == "more":
            num_up_layers = 6
        elif opt.num_upsampling_layers == "most":
            num_up_layers = 7
        else:
            raise ValueError("opt.num_upsampling_layers [%s] not recognized" % opt.num_upsampling_layers)

        sw = opt.load_size // (2 ** num_up_layers)
        sh = round(sw / opt.aspect_ratio)

        return sw, sh

    def forward(self, input, degraded_image, z=None):
        seg = input

        if self.opt.use_vae:
            # we sample z from unit normal and reshape the tensor
            if z is None:
                z = torch.randn(input.size(0), self.opt.z_dim, dtype=torch.float32, device=input.get_device())
            x = self.fc(z)
            x = x.view(-1, 16 * self.opt.ngf, self.sh, self.sw)
        else:
            # we downsample segmap and run convolution
            if self.opt.no_parsing_map:
                x = F.interpolate(degraded_image, size=(self.sh, self.sw), mode="bilinear")
            else:
                x = F.interpolate(seg, size=(self.sh, self.sw), mode="nearest")
            x = self.fc(x)

        x = self.head_0(x, seg, degraded_image)

        x = self.up(x)
        x = self.G_middle_0(x, seg, degraded_image)

        if self.opt.num_upsampling_layers == "more" or self.opt.num_upsampling_layers == "most":
            x = self.up(x)

        x = self.G_middle_1(x, seg, degraded_image)

        x = self.up(x)
        x = self.up_0(x, seg, degraded_image)
        x = self.up(x)
        x = self.up_1(x, seg, degraded_image)
        x = self.up(x)
        x = self.up_2(x, seg, degraded_image)
        x = self.up(x)
        x = self.up_3(x, seg, degraded_image)

        if self.opt.num_upsampling_layers == "most":
            x = self.up(x)
            x = self.up_4(x, seg, degraded_image)

        x = self.conv_img(F.leaky_relu(x, 2e-1))
        x = F.tanh(x)

        return x


class Pix2PixHDGenerator(BaseNetwork):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser.add_argument(
            "--resnet_n_downsample", type=int, default=4, help="number of downsampling layers in netG"
        )
        parser.add_argument(
            "--resnet_n_blocks",
            type=int,
            default=9,
            help="number of residual blocks in the global generator network",
        )
        parser.add_argument(
            "--resnet_kernel_size", type=int, default=3, help="kernel size of the resnet block"
        )
        parser.add_argument(
            "--resnet_initial_kernel_size", type=int, default=7, help="kernel size of the first convolution"
        )
        # parser.set_defaults(norm_G='instance')
        return parser

    def __init__(self, opt):
        super().__init__()
        input_nc = 3

        # print("xxxxx")
        # print(opt.norm_G)
        norm_layer = get_nonspade_norm_layer(opt, opt.norm_G)
        activation = nn.ReLU(False)

        model = []

        # initial conv
        model += [
            nn.ReflectionPad2d(opt.resnet_initial_kernel_size // 2),
            norm_layer(nn.Conv2d(input_nc, opt.ngf, kernel_size=opt.resnet_initial_kernel_size, padding=0)),
            activation,
        ]

        # downsample
        mult = 1
        for i in range(opt.resnet_n_downsample):
            model += [
                norm_layer(nn.Conv2d(opt.ngf * mult, opt.ngf * mult * 2, kernel_size=3, stride=2, padding=1)),
                activation,
            ]
            mult *= 2

        # resnet blocks
        for i in range(opt.resnet_n_blocks):
            model += [
                ResnetBlock(
                    opt.ngf * mult,
                    norm_layer=norm_layer,
                    activation=activation,
                    kernel_size=opt.resnet_kernel_size,
                )
            ]

        # upsample
        for i in range(opt.resnet_n_downsample):
            nc_in = int(opt.ngf * mult)
            nc_out = int((opt.ngf * mult) / 2)
            model += [
                norm_layer(
                    nn.ConvTranspose2d(nc_in, nc_out, kernel_size=3, stride=2, padding=1, output_padding=1)
                ),
                activation,
            ]
            mult = mult // 2

        # final output conv
        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(nc_out, opt.output_nc, kernel_size=7, padding=0),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*model)

    def forward(self, input, degraded_image, z=None):
        return self.model(degraded_image)



================================================
FILE: Face_Enhancement/models/networks/normalization.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import re
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.networks.sync_batchnorm import SynchronizedBatchNorm2d
import torch.nn.utils.spectral_norm as spectral_norm


def get_nonspade_norm_layer(opt, norm_type="instance"):
    # helper function to get # output channels of the previous layer
    def get_out_channel(layer):
        if hasattr(layer, "out_channels"):
            return getattr(layer, "out_channels")
        return layer.weight.size(0)

    # this function will be returned
    def add_norm_layer(layer):
        nonlocal norm_type
        if norm_type.startswith("spectral"):
            layer = spectral_norm(layer)
            subnorm_type = norm_type[len("spectral") :]

        if subnorm_type == "none" or len(subnorm_type) == 0:
            return layer

        # remove bias in the previous layer, which is meaningless
        # since it has no effect after normalization
        if getattr(layer, "bias", None) is not None:
            delattr(layer, "bias")
            layer.register_parameter("bias", None)

        if subnorm_type == "batch":
            norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == "sync_batch":
            norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == "instance":
            norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
        else:
            raise ValueError("normalization layer %s is not recognized" % subnorm_type)

        return nn.Sequential(layer, norm_layer)

    return add_norm_layer


class SPADE(nn.Module):
    def __init__(self, config_text, norm_nc, label_nc, opt):
        super().__init__()

        assert config_text.startswith("spade")
        parsed = re.search("spade(\D+)(\d)x\d", config_text)
        param_free_norm_type = str(parsed.group(1))
        ks = int(parsed.group(2))
        self.opt = opt
        if param_free_norm_type == "instance":
            self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == "syncbatch":
            self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == "batch":
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
        else:
            raise ValueError("%s is not a recognized param-free norm type in SPADE" % param_free_norm_type)

        # The dimension of the intermediate embedding space. Yes, hardcoded.
        nhidden = 128

        pw = ks // 2

        if self.opt.no_parsing_map:
            self.mlp_shared = nn.Sequential(nn.Conv2d(3, nhidden, kernel_size=ks, padding=pw), nn.ReLU())
        else:
            self.mlp_shared = nn.Sequential(
                nn.Conv2d(label_nc + 3, nhidden, kernel_size=ks, padding=pw), nn.ReLU()
            )
        self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)

    def forward(self, x, segmap, degraded_image):

        # Part 1. generate parameter-free normalized activations
        normalized = self.param_free_norm(x)

        # Part 2. produce scaling and bias conditioned on semantic map
        segmap = F.interpolate(segmap, size=x.size()[2:], mode="nearest")
        degraded_face = F.interpolate(degraded_image, size=x.size()[2:], mode="bilinear")

        if self.opt.no_parsing_map:
            actv = self.mlp_shared(degraded_face)
        else:
            actv = self.mlp_shared(torch.cat((segmap, degraded_face), dim=1))
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)

        # apply scale and bias
        out = normalized * (1 + gamma) + beta

        return out


================================================
FILE: Face_Enhancement/models/pix2pix_model.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
import models.networks as networks
import util.util as util


class Pix2PixModel(torch.nn.Module):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        networks.modify_commandline_options(parser, is_train)
        return parser

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        self.FloatTensor = torch.cuda.FloatTensor if self.use_gpu() else torch.FloatTensor
        self.ByteTensor = torch.cuda.ByteTensor if self.use_gpu() else torch.ByteTensor

        self.netG, self.netD, self.netE = self.initialize_networks(opt)

        # set loss functions
        if opt.isTrain:
            self.criterionGAN = networks.GANLoss(opt.gan_mode, tensor=self.FloatTensor, opt=self.opt)
            self.criterionFeat = torch.nn.L1Loss()
            if not opt.no_vgg_loss:
                self.criterionVGG = networks.VGGLoss(self.opt.gpu_ids)
            if opt.use_vae:
                self.KLDLoss = networks.KLDLoss()

    # Entry point for all calls involving forward pass
    # of deep networks. We used this approach since DataParallel module
    # can't parallelize custom functions, we branch to different
    # routines based on |mode|.
    def forward(self, data, mode):
        input_semantics, real_image, degraded_image = self.preprocess_input(data)

        if mode == "generator":
            g_loss, generated = self.compute_generator_loss(input_semantics, degraded_image, real_image)
            return g_loss, generated
        elif mode == "discriminator":
            d_loss = self.compute_discriminator_loss(input_semantics, degraded_image, real_image)
            return d_loss
        elif mode == "encode_only":
            z, mu, logvar = self.encode_z(real_image)
            return mu, logvar
        elif mode == "inference":
            with torch.no_grad():
                fake_image, _ = self.generate_fake(input_semantics, degraded_image, real_image)
            return fake_image
        else:
            raise ValueError("|mode| is invalid")

    def create_optimizers(self, opt):
        G_params = list(self.netG.parameters())
        if opt.use_vae:
            G_params += list(self.netE.parameters())
        if opt.isTrain:
            D_params = list(self.netD.parameters())

        beta1, beta2 = opt.beta1, opt.beta2
        if opt.no_TTUR:
            G_lr, D_lr = opt.lr, opt.lr
        else:
            G_lr, D_lr = opt.lr / 2, opt.lr * 2

        optimizer_G = torch.optim.Adam(G_params, lr=G_lr, betas=(beta1, beta2))
        optimizer_D = torch.optim.Adam(D_params, lr=D_lr, betas=(beta1, beta2))

        return optimizer_G, optimizer_D

    def save(self, epoch):
        util.save_network(self.netG, "G", epoch, self.opt)
        util.save_network(self.netD, "D", epoch, self.opt)
        if self.opt.use_vae:
            util.save_network(self.netE, "E", epoch, self.opt)

    ############################################################################
    # Private helper methods
    ############################################################################

    def initialize_networks(self, opt):
        netG = networks.define_G(opt)
        netD = networks.define_D(opt) if opt.isTrain else None
        netE = networks.define_E(opt) if opt.use_vae else None

        if not opt.isTrain or opt.continue_train:
            netG = util.load_network(netG, "G", opt.which_epoch, opt)
            if opt.isTrain:
                netD = util.load_network(netD, "D", opt.which_epoch, opt)
            if opt.use_vae:
                netE = util.load_network(netE, "E", opt.which_epoch, opt)

        return netG, netD, netE

    # preprocess the input, such as moving the tensors to GPUs and
    # transforming the label map to one-hot encoding
    # |data|: dictionary of the input data

    def preprocess_input(self, data):
        # move to GPU and change data types
        # data['label'] = data['label'].long()

        if not self.opt.isTrain:
            if self.use_gpu():
                data["label"] = data["label"].cuda()
                data["image"] = data["image"].cuda()
            return data["label"], data["image"], data["image"]

        ## While testing, the input image is the degraded face
        if self.use_gpu():
            data["label"] = data["label"].cuda()
            data["degraded_image"] = data["degraded_image"].cuda()
            data["image"] = data["image"].cuda()

        # # create one-hot label map
        # label_map = data['label']
        # bs, _, h, w = label_map.size()
        # nc = self.opt.label_nc + 1 if self.opt.contain_dontcare_label \
        #     else self.opt.label_nc
        # input_label = self.FloatTensor(bs, nc, h, w).zero_()
        # input_semantics = input_label.scatter_(1, label_map, 1.0)

        return data["label"], data["image"], data["degraded_image"]

    def compute_generator_loss(self, input_semantics, degraded_image, real_image):
        G_losses = {}

        fake_image, KLD_loss = self.generate_fake(
            input_semantics, degraded_image, real_image, compute_kld_loss=self.opt.use_vae
        )

        if self.opt.use_vae:
            G_losses["KLD"] = KLD_loss

        pred_fake, pred_real = self.discriminate(input_semantics, fake_image, real_image)

        G_losses["GAN"] = self.criterionGAN(pred_fake, True, for_discriminator=False)

        if not self.opt.no_ganFeat_loss:
            num_D = len(pred_fake)
            GAN_Feat_loss = self.FloatTensor(1).fill_(0)
            for i in range(num_D):  # for each discriminator
                # last output is the final prediction, so we exclude it
                num_intermediate_outputs = len(pred_fake[i]) - 1
                for j in range(num_intermediate_outputs):  # for each layer output
                    unweighted_loss = self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach())
                    GAN_Feat_loss += unweighted_loss * self.opt.lambda_feat / num_D
            G_losses["GAN_Feat"] = GAN_Feat_loss

        if not self.opt.no_vgg_loss:
            G_losses["VGG"] = self.criterionVGG(fake_image, real_image) * self.opt.lambda_vgg

        return G_losses, fake_image

    def compute_discriminator_loss(self, input_semantics, degraded_image, real_image):
        D_losses = {}
        with torch.no_grad():
            fake_image, _ = self.generate_fake(input_semantics, degraded_image, real_image)
            fake_image = fake_image.detach()
            fake_image.requires_grad_()

        pred_fake, pred_real = self.discriminate(input_semantics, fake_image, real_image)

        D_losses["D_Fake"] = self.criterionGAN(pred_fake, False, for_discriminator=True)
        D_losses["D_real"] = self.criterionGAN(pred_real, True, for_discriminator=True)

        return D_losses

    def encode_z(self, real_image):
        mu, logvar = self.netE(real_image)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar

    def generate_fake(self, input_semantics, degraded_image, real_image, compute_kld_loss=False):
        z = None
        KLD_loss = None
        if self.opt.use_vae:
            z, mu, logvar = self.encode_z(real_image)
            if compute_kld_loss:
                KLD_loss = self.KLDLoss(mu, logvar) * self.opt.lambda_kld

        fake_image = self.netG(input_semantics, degraded_image, z=z)

        assert (
            not compute_kld_loss
        ) or self.opt.use_vae, "You cannot compute KLD loss if opt.use_vae == False"

        return fake_image, KLD_loss

    # Given fake and real image, return the prediction of discriminator
    # for each fake and real image.

    def discriminate(self, input_semantics, fake_image, real_image):

        if self.opt.no_parsing_map:
            fake_concat = fake_image
            real_concat = real_image
        else:
            fake_concat = torch.cat([input_semantics, fake_image], dim=1)
            real_concat = torch.cat([input_semantics, real_image], dim=1)

        # In Batch Normalization, the fake and real images are
        # recommended to be in the same batch to avoid disparate
        # statistics in fake and real images.
        # So both fake and real images are fed to D all at once.
        fake_and_real = torch.cat([fake_concat, real_concat], dim=0)

        discriminator_out = self.netD(fake_and_real)

        pred_fake, pred_real = self.divide_pred(discriminator_out)

        return pred_fake, pred_real

    # Take the prediction of fake and real images from the combined batch
    def divide_pred(self, pred):
        # the prediction contains the intermediate outputs of multiscale GAN,
        # so it's usually a list
        if type(pred) == list:
            fake = []
            real = []
            for p in pred:
                fake.append([tensor[: tensor.size(0) // 2] for tensor in p])
                real.append([tensor[tensor.size(0) // 2 :] for tensor in p])
        else:
            fake = pred[: pred.size(0) // 2]
            real = pred[pred.size(0) // 2 :]

        return fake, real

    def get_edges(self, t):
        edge = self.ByteTensor(t.size()).zero_()
        edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1])
        edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1])
        edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
        edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
        return edge.float()

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps.mul(std) + mu

    def use_gpu(self):
        return len(self.opt.gpu_ids) > 0


================================================
FILE: Face_Enhancement/options/__init__.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.


================================================
FILE: Face_Enhancement/options/base_options.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import sys
import argparse
import os
from util import util
import torch
import models
import data
import pickle


class BaseOptions:
    def __init__(self):
        self.initialized = False

    def initialize(self, parser):
        # experiment specifics
        parser.add_argument(
            "--name",
            type=str,
            default="label2coco",
            help="name of the experiment. It decides where to store samples and models",
        )

        parser.add_argument(
            "--gpu_ids", type=str, default="0", help="gpu ids: e.g. 0  0,1,2, 0,2. use -1 for CPU"
        )
        parser.add_argument(
            "--checkpoints_dir", type=str, default="./checkpoints", help="models are saved here"
        )
        parser.add_argument("--model", type=str, default="pix2pix", help="which model to use")
        parser.add_argument(
            "--norm_G",
            type=str,
            default="spectralinstance",
            help="instance normalization or batch normalization",
        )
        parser.add_argument(
            "--norm_D",
            type=str,
            default="spectralinstance",
            help="instance normalization or batch normalization",
        )
        parser.add_argument(
            "--norm_E",
            type=str,
            default="spectralinstance",
            help="instance normalization or batch normalization",
        )
        parser.add_argument("--phase", type=str, default="train", help="train, val, test, etc")

        # input/output sizes
        parser.add_argument("--batchSize", type=int, default=1, help="input batch size")
        parser.add_argument(
            "--preprocess_mode",
            type=str,
            default="scale_width_and_crop",
            help="scaling and cropping of images at load time.",
            choices=(
                "resize_and_crop",
                "crop",
                "scale_width",
                "scale_width_and_crop",
                "scale_shortside",
                "scale_shortside_and_crop",
                "fixed",
                "none",
                "resize",
            ),
        )
        parser.add_argument(
            "--load_size",
            type=int,
            default=1024,
            help="Scale images to this size. The final image will be cropped to --crop_size.",
        )
        parser.add_argument(
            "--crop_size",
            type=int,
            default=512,
            help="Crop to the width of crop_size (after initially scaling the images to load_size.)",
        )
        parser.add_argument(
            "--aspect_ratio",
            type=float,
            default=1.0,
            help="The ratio width/height. The final height of the load image will be crop_size/aspect_ratio",
        )
        parser.add_argument(
            "--label_nc",
            type=int,
            default=182,
            help="# of input label classes without unknown class. If you have unknown class as class label, specify --contain_dopntcare_label.",
        )
        parser.add_argument(
            "--contain_dontcare_label",
            action="store_true",
            help="if the label map contains dontcare label (dontcare=255)",
        )
        parser.add_argument("--output_nc", type=int, default=3, help="# of output image channels")

        # for setting inputs
        parser.add_argument("--dataroot", type=str, default="./datasets/cityscapes/")
        parser.add_argument("--dataset_mode", type=str, default="coco")
        parser.add_argument(
            "--serial_batches",
            action="store_true",
            help="if true, takes images in order to make batches, otherwise takes them randomly",
        )
        parser.add_argument(
            "--no_flip",
            action="store_true",
            help="if specified, do not flip the images for data argumentation",
        )
        parser.add_argument("--nThreads", default=0, type=int, help="# threads for loading data")
        parser.add_argument(
            "--max_dataset_size",
            type=int,
            default=sys.maxsize,
            help="Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.",
        )
        parser.add_argument(
            "--load_from_opt_file",
            action="store_true",
            help="load the options from checkpoints and use that as default",
        )
        parser.add_argument(
            "--cache_filelist_write",
            action="store_true",
            help="saves the current filelist into a text file, so that it loads faster",
        )
        parser.add_argument(
            "--cache_filelist_read", action="store_true", help="reads from the file list cache"
        )

        # for displays
        parser.add_argument("--display_winsize", type=int, default=400, help="display window size")

        # for generator
        parser.add_argument(
            "--netG", type=str, default="spade", help="selects model to use for netG (pix2pixhd | spade)"
        )
        parser.add_argument("--ngf", type=int, default=64, help="# of gen filters in first conv layer")
        parser.add_argument(
            "--init_type",
            type=str,
            default="xavier",
            help="network initialization [normal|xavier|kaiming|orthogonal]",
        )
        parser.add_argument(
            "--init_variance", type=float, default=0.02, help="variance of the initialization distribution"
        )
        parser.add_argument("--z_dim", type=int, default=256, help="dimension of the latent z vector")
        parser.add_argument(
            "--no_parsing_map", action="store_true", help="During training, we do not use the parsing map"
        )

        # for instance-wise features
        parser.add_argument(
            "--no_instance", action="store_true", help="if specified, do *not* add instance map as input"
        )
        parser.add_argument(
            "--nef", type=int, default=16, help="# of encoder filters in the first conv layer"
        )
        parser.add_argument("--use_vae", action="store_true", help="enable training with an image encoder.")
        parser.add_argument(
            "--tensorboard_log", action="store_true", help="use tensorboard to record the resutls"
        )

        # parser.add_argument('--img_dir',)
        parser.add_argument(
            "--old_face_folder", type=str, default="", help="The folder name of input old face"
        )
        parser.add_argument(
            "--old_face_label_folder", type=str, default="", help="The folder name of input old face label"
        )

        parser.add_argument("--injection_layer", type=str, default="all", help="")

        self.initialized = True
        return parser

    def gather_options(self):
        # initialize parser with basic options
        if not self.initialized:
            parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
            parser = self.initialize(parser)

        # get the basic options
        opt, unknown = parser.parse_known_args()

        # modify model-related parser options
        model_name = opt.model
        model_option_setter = models.get_option_setter(model_name)
        parser = model_option_setter(parser, self.isTrain)

        # modify dataset-related parser options
        # dataset_mode = opt.dataset_mode
        # dataset_option_setter = data.get_option_setter(dataset_mode)
        # parser = dataset_option_setter(parser, self.isTrain)

        opt, unknown = parser.parse_known_args()

        # if there is opt_file, load it.
        # The previous default options will be overwritten
        if opt.load_from_opt_file:
            parser = self.update_options_from_file(parser, opt)

        opt = parser.parse_args()
        self.parser = parser
        return opt

    def print_options(self, opt):
        message = ""
        message += "----------------- Options ---------------\n"
        for k, v in sorted(vars(opt).items()):
            comment = ""
            default = self.parser.get_default(k)
            if v != default:
                comment = "\t[default: %s]" % str(default)
            message += "{:>25}: {:<30}{}\n".format(str(k), str(v), comment)
        message += "----------------- End -------------------"
        # print(message)

    def option_file_path(self, opt, makedir=False):
        expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
        if makedir:
            util.mkdirs(expr_dir)
        file_name = os.path.join(expr_dir, "opt")
        return file_name

    def save_options(self, opt):
        file_name = self.option_file_path(opt, makedir=True)
        with open(file_name + ".txt", "wt") as opt_file:
            for k, v in sorted(vars(opt).items()):
                comment = ""
                default = self.parser.get_default(k)
                if v != default:
                    comment = "\t[default: %s]" % str(default)
                opt_file.write("{:>25}: {:<30}{}\n".format(str(k), str(v), comment))

        with open(file_name + ".pkl", "wb") as opt_file:
            pickle.dump(opt, opt_file)

    def update_options_from_file(self, parser, opt):
        new_opt = self.load_options(opt)
        for k, v in sorted(vars(opt).items()):
            if hasattr(new_opt, k) and v != getattr(new_opt, k):
                new_val = getattr(new_opt, k)
                parser.set_defaults(**{k: new_val})
        return parser

    def load_options(self, opt):
        file_name = self.option_file_path(opt, makedir=False)
        new_opt = pickle.load(open(file_name + ".pkl", "rb"))
        return new_opt

    def parse(self, save=False):

        opt = self.gather_options()
        opt.isTrain = self.isTrain  # train or test
        opt.contain_dontcare_label = False

        self.print_options(opt)
        if opt.isTrain:
            self.save_options(opt)

        # Set semantic_nc based on the option.
        # This will be convenient in many places
        opt.semantic_nc = (
            opt.label_nc + (1 if opt.contain_dontcare_label else 0) + (0 if opt.no_instance else 1)
        )

        # set gpu ids
        str_ids = opt.gpu_ids.split(",")
        opt.gpu_ids = []
        for str_id in str_ids:
            int_id = int(str_id)
            if int_id >= 0:
                opt.gpu_ids.append(int_id)

        if len(opt.gpu_ids) > 0:
            print("The main GPU is ")
            print(opt.gpu_ids[0])
            torch.cuda.set_device(opt.gpu_ids[0])

        assert (
            len(opt.gpu_ids) == 0 or opt.batchSize % len(opt.gpu_ids) == 0
        ), "Batch size %d is wrong. It must be a multiple of # GPUs %d." % (opt.batchSize, len(opt.gpu_ids))

        self.opt = opt
        return self.opt


================================================
FILE: Face_Enhancement/options/test_options.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from .base_options import BaseOptions


class TestOptions(BaseOptions):
    def initialize(self, parser):
        BaseOptions.initialize(self, parser)
        parser.add_argument("--results_dir", type=str, default="./results/", help="saves results here.")
        parser.add_argument(
            "--which_epoch",
            type=str,
            default="latest",
            help="which epoch to load? set to latest to use latest cached model",
        )
        parser.add_argument("--how_many", type=int, default=float("inf"), help="how many test images to run")

        parser.set_defaults(
            preprocess_mode="scale_width_and_crop", crop_size=256, load_size=256, display_winsize=256
        )
        parser.set_defaults(serial_batches=True)
        parser.set_defaults(no_flip=True)
        parser.set_defaults(phase="test")
        self.isTrain = False
        return parser


================================================
FILE: Face_Enhancement/requirements.txt
================================================
torch>=1.0.0
torchvision
dominate>=2.3.1
wandb
dill
scikit-image
tensorboardX
scipy
opencv-python

================================================
FILE: Face_Enhancement/test_face.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import os
from collections import OrderedDict

import data
from options.test_options import TestOptions
from models.pix2pix_model import Pix2PixModel
from util.visualizer import Visualizer
import torchvision.utils as vutils
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

opt = TestOptions().parse()

dataloader = data.create_dataloader(opt)

model = Pix2PixModel(opt)
model.eval()

visualizer = Visualizer(opt)


single_save_url = os.path.join(opt.checkpoints_dir, opt.name, opt.results_dir, "each_img")


if not os.path.exists(single_save_url):
    os.makedirs(single_save_url)


for i, data_i in enumerate(dataloader):
    if i * opt.batchSize >= opt.how_many:
        break

    generated = model(data_i, mode="inference")

    img_path = data_i["path"]

    for b in range(generated.shape[0]):
        img_name = os.path.split(img_path[b])[-1]
        save_img_url = os.path.join(single_save_url, img_name)

        vutils.save_image((generated[b] + 1) / 2, save_img_url)



================================================
FILE: Face_Enhancement/util/__init__.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.


================================================
FILE: Face_Enhancement/util/iter_counter.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import os
import time
import numpy as np


# Helper class that keeps track of training iterations
class IterationCounter:
    def __init__(self, opt, dataset_size):
        self.opt = opt
        self.dataset_size = dataset_size

        self.first_epoch = 1
        self.total_epochs = opt.niter + opt.niter_decay
        self.epoch_iter = 0  # iter number within each epoch
        self.iter_record_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, "iter.txt")
        if opt.isTrain and opt.continue_train:
            try:
                self.first_epoch, self.epoch_iter = np.loadtxt(
                    self.iter_record_path, delimiter=",", dtype=int
                )
                print("Resuming from epoch %d at iteration %d" % (self.first_epoch, self.epoch_iter))
            except:
                print(
                    "Could not load iteration record at %s. Starting from beginning." % self.iter_record_path
                )

        self.total_steps_so_far = (self.first_epoch - 1) * dataset_size + self.epoch_iter

    # return the iterator of epochs for the training
    def training_epochs(self):
        return range(self.first_epoch, self.total_epochs + 1)

    def record_epoch_start(self, epoch):
        self.epoch_start_time = time.time()
        self.epoch_iter = 0
        self.last_iter_time = time.time()
        self.current_epoch = epoch

    def record_one_iteration(self):
        current_time = time.time()

        # the last remaining batch is dropped (see data/__init__.py),
        # so we can assume batch size is always opt.batchSize
        self.time_per_iter = (current_time - self.last_iter_time) / self.opt.batchSize
        self.last_iter_time = current_time
        self.total_steps_so_far += self.opt.batchSize
        self.epoch_iter += self.opt.batchSize

    def record_epoch_end(self):
        current_time = time.time()
        self.time_per_epoch = current_time - self.epoch_start_time
        print(
            "End of epoch %d / %d \t Time Taken: %d sec"
            % (self.current_epoch, self.total_epochs, self.time_per_epoch)
        )
        if self.current_epoch % self.opt.save_epoch_freq == 0:
            np.savetxt(self.iter_record_path, (self.current_epoch + 1, 0), delimiter=",", fmt="%d")
            print("Saved current iteration count at %s." % self.iter_record_path)

    def record_current_iter(self):
        np.savetxt(self.iter_record_path, (self.current_epoch, self.epoch_iter), delimiter=",", fmt="%d")
        print("Saved current iteration count at %s." % self.iter_record_path)

    def needs_saving(self):
        return (self.total_steps_so_far % self.opt.save_latest_freq) < self.opt.batchSize

    def needs_printing(self):
        return (self.total_steps_so_far % self.opt.print_freq) < self.opt.batchSize

    def needs_displaying(self):
        return (self.total_steps_so_far % self.opt.display_freq) < self.opt.batchSize


================================================
FILE: Face_Enhancement/util/util.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import re
import importlib
import torch
from argparse import Namespace
import numpy as np
from PIL import Image
import os
import argparse
import dill as pickle


def save_obj(obj, name):
    with open(name, "wb") as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)


def load_obj(name):
    with open(name, "rb") as f:
        return pickle.load(f)


def copyconf(default_opt, **kwargs):
    conf = argparse.Namespace(**vars(default_opt))
    for key in kwargs:
        print(key, kwargs[key])
        setattr(conf, key, kwargs[key])
    return conf


# Converts a Tensor into a Numpy array
# |imtype|: the desired type of the converted numpy array
def tensor2im(image_tensor, imtype=np.uint8, normalize=True, tile=False):
    if isinstance(image_tensor, list):
        image_numpy = []
        for i in range(len(image_tensor)):
            image_numpy.append(tensor2im(image_tensor[i], imtype, normalize))
        return image_numpy

    if image_tensor.dim() == 4:
        # transform each image in the batch
        images_np = []
        for b in range(image_tensor.size(0)):
            one_image = image_tensor[b]
            one_image_np = tensor2im(one_image)
            images_np.append(one_image_np.reshape(1, *one_image_np.shape))
        images_np = np.concatenate(images_np, axis=0)

        return images_np

    if image_tensor.dim() == 2:
        image_tensor = image_tensor.unsqueeze(0)
    image_numpy = image_tensor.detach().cpu().float().numpy()
    if normalize:
        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
    else:
        image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
    image_numpy = np.clip(image_numpy, 0, 255)
    if image_numpy.shape[2] == 1:
        image_numpy = image_numpy[:, :, 0]
    return image_numpy.astype(imtype)


# Converts a one-hot tensor into a colorful label map
def tensor2label(label_tensor, n_label, imtype=np.uint8, tile=False):
    if label_tensor.dim() == 4:
        # transform each image in the batch
        images_np = []
        for b in range(label_tensor.size(0)):
            one_image = label_tensor[b]
            one_image_np = tensor2label(one_image, n_label, imtype)
            images_np.append(one_image_np.reshape(1, *one_image_np.shape))
        images_np = np.concatenate(images_np, axis=0)
        # if tile:
        #     images_tiled = tile_images(images_np)
        #     return images_tiled
        # else:
        #     images_np = images_np[0]
        #     return images_np
        return images_np

    if label_tensor.dim() == 1:
        return np.zeros((64, 64, 3), dtype=np.uint8)
    if n_label == 0:
        return tensor2im(label_tensor, imtype)
    label_tensor = label_tensor.cpu().float()
    if label_tensor.size()[0] > 1:
        label_tensor = label_tensor.max(0, keepdim=True)[1]
    label_tensor = Colorize(n_label)(label_tensor)
    label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0))
    result = label_numpy.astype(imtype)
    return result


def save_image(image_numpy, image_path, create_dir=False):
    if create_dir:
        os.makedirs(os.path.dirname(image_path), exist_ok=True)
    if len(image_numpy.shape) == 2:
        image_numpy = np.expand_dims(image_numpy, axis=2)
    if image_numpy.shape[2] == 1:
        image_numpy = np.repeat(image_numpy, 3, 2)
    image_pil = Image.fromarray(image_numpy)

    # save to png
    image_pil.save(image_path.replace(".jpg", ".png"))


def mkdirs(paths):
    if isinstance(paths, list) and not isinstance(paths, str):
        for path in paths:
            mkdir(path)
    else:
        mkdir(paths)


def mkdir(path):
    if not os.path.exists(path):
        os.makedirs(path)


def atoi(text):
    return int(text) if text.isdigit() else text


def natural_keys(text):
    """
    alist.sort(key=natural_keys) sorts in human order
    http://nedbatchelder.com/blog/200712/human_sorting.html
    (See Toothy's implementation in the comments)
    """
    return [atoi(c) for c in re.split("(\d+)", text)]


def natural_sort(items):
    items.sort(key=natural_keys)


def str2bool(v):
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")


def find_class_in_module(target_cls_name, module):
    target_cls_name = target_cls_name.replace("_", "").lower()
    clslib = importlib.import_module(module)
    cls = None
    for name, clsobj in clslib.__dict__.items():
        if name.lower() == target_cls_name:
            cls = clsobj

    if cls is None:
        print(
            "In %s, there should be a class whose name matches %s in lowercase without underscore(_)"
            % (module, target_cls_name)
        )
        exit(0)

    return cls


def save_network(net, label, epoch, opt):
    save_filename = "%s_net_%s.pth" % (epoch, label)
    save_path = os.path.join(opt.checkpoints_dir, opt.name, save_filename)
    torch.save(net.cpu().state_dict(), save_path)
    if len(opt.gpu_ids) and torch.cuda.is_available():
        net.cuda()


def load_network(net, label, epoch, opt):
    save_filename = "%s_net_%s.pth" % (epoch, label)
    save_dir = os.path.join(opt.checkpoints_dir, opt.name)
    save_path = os.path.join(save_dir, save_filename)
    if os.path.exists(save_path):
        weights = torch.load(save_path)
        net.load_state_dict(weights)
    return net


###############################################################################
# Code from
# https://github.com/ycszen/pytorch-seg/blob/master/transform.py
# Modified so it complies with the Citscape label map colors
###############################################################################
def uint82bin(n, count=8):
    """returns the binary of integer n, count refers to amount of bits"""
    return "".join([str((n >> y) & 1) for y in range(count - 1, -1, -1)])


class Colorize(object):
    def __init__(self, n=35):
        self.cmap = labelcolormap(n)
        self.cmap = torch.from_numpy(self.cmap[:n])

    def __call__(self, gray_image):
        size = gray_image.size()
        color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0)

        for label in range(0, len(self.cmap)):
            mask = (label == gray_image[0]).cpu()
            color_image[0][mask] = self.cmap[label][0]
            color_image[1][mask] = self.cmap[label][1]
            color_image[2][mask] = self.cmap[label][2]

        return color_image


================================================
FILE: Face_Enhancement/util/visualizer.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import os
import ntpath
import time
from . import util
import scipy.misc

try:
    from StringIO import StringIO  # Python 2.7
except ImportError:
    from io import BytesIO  # Python 3.x
import torchvision.utils as vutils
from tensorboardX import SummaryWriter
import torch
import numpy as np


class Visualizer:
    def __init__(self, opt):
        self.opt = opt
        self.tf_log = opt.isTrain and opt.tf_log

        self.tensorboard_log = opt.tensorboard_log

        self.win_size = opt.display_winsize
        self.name = opt.name
        if self.tensorboard_log:

            if self.opt.isTrain:
                self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, "logs")
                if not os.path.exists(self.log_dir):
                    os.makedirs(self.log_dir)
                self.writer = SummaryWriter(log_dir=self.log_dir)
            else:
                print("hi :)")
                self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, opt.results_dir)
                if not os.path.exists(self.log_dir):
                    os.makedirs(self.log_dir)

        if opt.isTrain:
            self.log_name = os.path.join(opt.checkpoints_dir, opt.name, "loss_log.txt")
            with open(self.log_name, "a") as log_file:
                now = time.strftime("%c")
                log_file.write("================ Training Loss (%s) ================\n" % now)

    # |visuals|: dictionary of images to display or save
    def display_current_results(self, visuals, epoch, step):

        all_tensor = []
        if self.tensorboard_log:

            for key, tensor in visuals.items():
                all_tensor.append((tensor.data.cpu() + 1) / 2)

            output = torch.cat(all_tensor, 0)
            img_grid = vutils.make_grid(output, nrow=self.opt.batchSize, padding=0, normalize=False)

            if self.opt.isTrain:
                self.writer.add_image("Face_SPADE/training_samples", img_grid, step)
            else:
                vutils.save_image(
                    output,
                    os.path.join(self.log_dir, str(step) + ".png"),
                    nrow=self.opt.batchSize,
                    padding=0,
                    normalize=False,
                )

    # errors: dictionary of error labels and values
    def plot_current_errors(self, errors, step):
        if self.tf_log:
            for tag, value in errors.items():
                value = value.mean().float()
                summary = self.tf.Summary(value=[self.tf.Summary.Value(tag=tag, simple_value=value)])
                self.writer.add_summary(summary, step)

        if self.tensorboard_log:

            self.writer.add_scalar("Loss/GAN_Feat", errors["GAN_Feat"].mean().float(), step)
            self.writer.add_scalar("Loss/VGG", errors["VGG"].mean().float(), step)
            self.writer.add_scalars(
                "Loss/GAN",
                {
                    "G": errors["GAN"].mean().float(),
                    "D": (errors["D_Fake"].mean().float() + errors["D_real"].mean().float()) / 2,
                },
                step,
            )

    # errors: same format as |errors| of plotCurrentErrors
    def print_current_errors(self, epoch, i, errors, t):
        message = "(epoch: %d, iters: %d, time: %.3f) " % (epoch, i, t)
        for k, v in errors.items():
            v = v.mean().float()
            message += "%s: %.3f " % (k, v)

        print(message)
        with open(self.log_name, "a") as log_file:
            log_file.write("%s\n" % message)

    def convert_visuals_to_numpy(self, visuals):
        for key, t in visuals.items():
            tile = self.opt.batchSize > 8
            if "input_label" == key:
                t = util.tensor2label(t, self.opt.label_nc + 2, tile=tile)  ## B*H*W*C 0-255 numpy
            else:
                t = util.tensor2im(t, tile=tile)
            visuals[key] = t
        return visuals

    # save image to the disk
    def save_images(self, webpage, visuals, image_path):
        visuals = self.convert_visuals_to_numpy(visuals)

        image_dir = webpage.get_image_dir()
        short_path = ntpath.basename(image_path[0])
        name = os.path.splitext(short_path)[0]

        webpage.add_header(name)
        ims = []
        txts = []
        links = []

        for label, image_numpy in visuals.items():
            image_name = os.path.join(label, "%s.png" % (name))
            save_path = os.path.join(image_dir, image_name)
            util.save_image(image_numpy, save_path, create_dir=True)

            ims.append(image_name)
            txts.append(label)
            links.append(image_name)
        webpage.add_images(ims, txts, links, width=self.win_size)


================================================
FILE: GUI.py
================================================
import numpy as np
import cv2
import PySimpleGUI as sg
import os.path
import argparse
import os
import sys
import shutil
from subprocess import call

def modify(image_filename=None, cv2_frame=None):

    def run_cmd(command):
        try:
            call(command, shell=True)
        except KeyboardInterrupt:
            print("Process interrupted")
            sys.exit(1)

    parser = argparse.ArgumentParser()
    parser.add_argument("--input_folder", type=str,
                        default= image_filename, help="Test images")
    parser.add_argument(
        "--output_folder",
        type=str,
        default="./output",
        help="Restored images, please use the absolute path",
    )
    parser.add_argument("--GPU", type=str, default="-1", help="0,1,2")
    parser.add_argument(
        "--checkpoint_name", type=str, default="Setting_9_epoch_100", help="choose which checkpoint"
    )
    parser.add_argument("--with_scratch",default="--with_scratch" ,action="store_true")
    opts = parser.parse_args()

    gpu1 = opts.GPU

    # resolve relative paths before changing directory
    opts.input_folder = os.path.abspath(opts.input_folder)
    opts.output_folder = os.path.abspath(opts.output_folder)
    if not os.path.exists(opts.output_folder):
        os.makedirs(opts.output_folder)

    main_environment = os.getcwd()

    # Stage 1: Overall Quality Improve
    print("Running Stage 1: Overall restoration")
    os.chdir("./Global")
    stage_1_input_dir = opts.input_folder
    stage_1_output_dir = os.path.join(
        opts.output_folder, "stage_1_restore_output")
    if not os.path.exists(stage_1_output_dir):
        os.makedirs(stage_1_output_dir)

    if not opts.with_scratch:
        stage_1_command = (
            "python test.py --test_mode Full --Quality_restore --test_input "
            + stage_1_input_dir
            + " --outputs_dir "
            + stage_1_output_dir
            + " --gpu_ids "
            + gpu1
        )
        run_cmd(stage_1_command)
    else:

        mask_dir = os.path.join(stage_1_output_dir, "masks")
        new_input = os.path.join(mask_dir, "input")
        new_mask = os.path.join(mask_dir, "mask")
        stage_1_command_1 = (
            "python detection.py --test_path "
            + stage_1_input_dir
            + " --output_dir "
            + mask_dir
            + " --input_size full_size"
            + " --GPU "
            + gpu1
        )
        stage_1_command_2 = (
            "python test.py --Scratch_and_Quality_restore --test_input "
            + new_input
            + " --test_mask "
            + new_mask
            + " --outputs_dir "
            + stage_1_output_dir
            + " --gpu_ids "
            + gpu1
        )
        run_cmd(stage_1_command_1)
        run_cmd(stage_1_command_2)

    # Solve the case when there is no face in the old photo
    stage_1_results = os.path.join(stage_1_output_dir, "restored_image")
    stage_4_output_dir = os.path.join(opts.output_folder, "final_output")
    if not os.path.exists(stage_4_output_dir):
        os.makedirs(stage_4_output_dir)
    for x in os.listdir(stage_1_results):
        img_dir = os.path.join(stage_1_results, x)
        shutil.copy(img_dir, stage_4_output_dir)

    print("Finish Stage 1 ...")
    print("\n")

    # Stage 2: Face Detection

    print("Running Stage 2: Face Detection")
    os.chdir(".././Face_Detection")
    stage_2_input_dir = os.path.join(stage_1_output_dir, "restored_image")
    stage_2_output_dir = os.path.join(
        opts.output_folder, "stage_2_detection_output")
    if not os.path.exists(stage_2_output_dir):
        os.makedirs(stage_2_output_dir)
    stage_2_command = (
        "python detect_all_dlib.py --url " + stage_2_input_dir +
        " --save_url " + stage_2_output_dir
    )
    run_cmd(stage_2_command)
    print("Finish Stage 2 ...")
    print("\n")

    # Stage 3: Face Restore
    print("Running Stage 3: Face Enhancement")
    os.chdir(".././Face_Enhancement")
    stage_3_input_mask = "./"
    stage_3_input_face = stage_2_output_dir
    stage_3_output_dir = os.path.join(
        opts.output_folder, "stage_3_face_output")
    if not os.path.exists(stage_3_output_dir):
        os.makedirs(stage_3_output_dir)
    stage_3_command = (
        "python test_face.py --old_face_folder "
        + stage_3_input_face
        + " --old_face_label_folder "
        + stage_3_input_mask
        + " --tensorboard_log --name "
        + opts.checkpoint_name
        + " --gpu_ids "
        + gpu1
        + " --load_size 256 --label_nc 18 --no_instance --preprocess_mode resize --batchSize 4 --results_dir "
        + stage_3_output_dir
        + " --no_parsing_map"
    )
    run_cmd(stage_3_command)
    print("Finish Stage 3 ...")
    print("\n")

    # Stage 4: Warp back
    print("Running Stage 4: Blending")
    os.chdir(".././Face_Detection")
    stage_4_input_image_dir = os.path.join(
        stage_1_output_dir, "restored_image")
    stage_4_input_face_dir = os.path.join(stage_3_output_dir, "each_img")
    stage_4_output_dir = os.path.join(opts.output_folder, "final_output")
    if not os.path.exists(stage_4_output_dir):
        os.makedirs(stage_4_output_dir)
    stage_4_command = (
        "python align_warp_back_multiple_dlib.py --origin_url "
        + stage_4_input_image_dir
        + " --replace_url "
        + stage_4_input_face_dir
        + " --save_url "
        + stage_4_output_dir
    )
    run_cmd(stage_4_command)
    print("Finish Stage 4 ...")
    print("\n")

    print("All the processing is done. Please check the results.")

# --------------------------------- The GUI ---------------------------------

# First the window layout...

images_col = [[sg.Text('Input file:'), sg.In(enable_events=True, key='-IN FILE-'), sg.FileBrowse()],
              [sg.Button('Modify Photo', key='-MPHOTO-'), sg.Button('Exit')],
              [sg.Image(filename='', key='-IN-'), sg.Image(filename='', key='-OUT-')],]
# ----- Full layout -----
layout = [[sg.VSeperator(), sg.Column(images_col)]]

# ----- Make the window -----
window = sg.Window('Bringing-old-photos-back-to-life', layout, grab_anywhere=True)

# ----- Run the Event Loop -----
prev_filename = colorized = cap = None
while True:
    event, values = window.read()
    if event in (None, 'Exit'):
        break

    elif event == '-MPHOTO-':
        try:
            n1 = filename.split("/")[-2]
            n2 = filename.split("/")[-3]
            n3 = filename.split("/")[-1]
            filename= str(f"./{n2}/{n1}")
            modify(filename)
           
            global f_image
            f_image = f'./output/final_output/{n3}'
            image = cv2.imread(f_image)
            window['-OUT-'].update(data=cv2.imencode('.png', image)[1].tobytes())
            
        except:
            continue

    elif event == '-IN FILE-':      # A single filename was chosen
        filename = values['-IN FILE-']
        if filename != prev_filename:
            prev_filename = filename
            try:
                image = cv2.imread(filename)
                window['-IN-'].update(data=cv2.imencode('.png', image)[1].tobytes())
            except:
                continue

# ----- Exit program -----
window.close()

================================================
FILE: Global/data/Create_Bigfile.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import os
import struct
from PIL import Image

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir):
    images = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname):
                #print(fname)
                path = os.path.join(root, fname)
                images.append(path)

    return images

### Modify these 3 lines in your own environment
indir="/home/ziyuwan/workspace/data/temp_old"
target_folders=['VOC','Real_L_old','Real_RGB_old']
out_dir ="/home/ziyuwan/workspace/data/temp_old"
###

if os.path.exists(out_dir) is False:
    os.makedirs(out_dir)

#
for target_folder in target_folders:
    curr_indir = os.path.join(indir, target_folder)
    curr_out_file = os.path.join(os.path.join(out_dir, '%s.bigfile'%(target_folder)))
    image_lists = make_dataset(curr_indir)
    image_lists.sort()
    with open(curr_out_file, 'wb') as wfid:
        # write total image number
        wfid.write(struct.pack('i', len(image_lists)))
        for i, img_path in enumerate(image_lists):
             # write file name first
             img_name = os.path.basename(img_path)
             img_name_bytes = img_name.encode('utf-8')
             wfid.write(struct.pack('i', len(img_name_bytes)))
             wfid.write(img_name_bytes)
    #
    #             # write image data in
             with open(img_path, 'rb') as img_fid:
                 img_bytes = img_fid.read()
             wfid.write(struct.pack('i', len(img_bytes)))
             wfid.write(img_bytes)

             if i % 1000 == 0:
                 print('write %d images done' % i)

================================================
FILE: Global/data/Load_Bigfile.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import io
import os
import struct
from PIL import Image

class BigFileMemoryLoader(object):
    def __load_bigfile(self):
        print('start load bigfile (%0.02f GB) into memory' % (os.path.getsize(self.file_path)/1024/1024/1024))
        with open(self.file_path, 'rb') as fid:
            self.img_num = struct.unpack('i', fid.read(4))[0]
            self.img_names = []
            self.img_bytes = []
            print('find total %d images' % self.img_num)
            for i in range(self.img_num):
                img_name_len = struct.unpack('i', fid.read(4))[0]
                img_name = fid.read(img_name_len).decode('utf-8')
                self.img_names.append(img_name)
                img_bytes_len = struct.unpack('i', fid.read(4))[0]
                self.img_bytes.append(fid.read(img_bytes_len))
                if i % 5000 == 0:
                    print('load %d images done' % i)
            print('load all %d images done' % self.img_num)

    def __init__(self, file_path):
        super(BigFileMemoryLoader, self).__init__()
        self.file_path = file_path
        self.__load_bigfile()

    def __getitem__(self, index):
        try:
            img = Image.open(io.BytesIO(self.img_bytes[index])).convert('RGB')
            return self.img_names[index], img
        except Exception:
            print('Image read error for index %d: %s' % (index, self.img_names[index]))
            return self.__getitem__((index+1)%self.img_num)


    def __len__(self):
        return self.img_num


================================================
FILE: Global/data/__init__.py
================================================


================================================
FILE: Global/data/base_data_loader.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

class BaseDataLoader():
    def __init__(self):
        pass
    
    def initialize(self, opt):
        self.opt = opt
        pass

    def load_data():
        return None

        
        


================================================
FILE: Global/data/base_dataset.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import random

class BaseDataset(data.Dataset):
    def __init__(self):
        super(BaseDataset, self).__init__()

    def name(self):
        return 'BaseDataset'

    def initialize(self, opt):
        pass

def get_params(opt, size):
    w, h = size
    new_h = h
    new_w = w
    if opt.resize_or_crop == 'resize_and_crop':
        new_h = new_w = opt.loadSize

    if opt.resize_or_crop == 'scale_width_and_crop': # we scale the shorter side into 256

        if w<h:
            new_w = opt.loadSize
            new_h = opt.loadSize * h // w
        else:
            new_h=opt.loadSize
            new_w = opt.loadSize * w // h

    if opt.resize_or_crop=='crop_only':
        pass


    x = random.randint(0, np.maximum(0, new_w - opt.fineSize))
    y = random.randint(0, np.maximum(0, new_h - opt.fineSize))
    
    flip = random.random() > 0.5
    return {'crop_pos': (x, y), 'flip': flip}

def get_transform(opt, params, method=Image.BICUBIC, normalize=True):
    transform_list = []
    if 'resize' in opt.resize_or_crop:
        osize = [opt.loadSize, opt.loadSize]
        transform_list.append(transforms.Scale(osize, method))   
    elif 'scale_width' in opt.resize_or_crop:
    #    transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method)))  ## Here , We want the shorter side to match 256, and Scale will finish it.
        transform_list.append(transforms.Scale(256,method))

    if 'crop' in opt.resize_or_crop:
        if opt.isTrain:
            transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))
        else:
            if opt.test_random_crop:
                transform_list.append(transforms.RandomCrop(opt.fineSize))
            else:
                transform_list.append(transforms.CenterCrop(opt.fineSize))

    ## when testing, for ablation study, choose center_crop directly.



    if opt.resize_or_crop == 'none':
        base = float(2 ** opt.n_downsample_global)
        if opt.netG == 'local':
            base *= (2 ** opt.n_local_enhancers)
        transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))

    if opt.isTrain and not opt.no_flip:
        transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))

    transform_list += [transforms.ToTensor()]

    if normalize:
        transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
                                                (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)

def normalize():    
    return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

def __make_power_2(img, base, method=Image.BICUBIC):
    ow, oh = img.size        
    h = int(round(oh / base) * base)
    w = int(round(ow / base) * base)
    if (h == oh) and (w == ow):
        return img
    return img.resize((w, h), method)

def __scale_width(img, target_width, method=Image.BICUBIC):
    ow, oh = img.size
    if (ow == target_width):
        return img    
    w = target_width
    h = int(target_width * oh / ow)    
    return img.resize((w, h), method)

def __crop(img, pos, size):
    ow, oh = img.size
    x1, y1 = pos
    tw = th = size
    if (ow > tw or oh > th):        
        return img.crop((x1, y1, x1 + tw, y1 + th))
    return img

def __flip(img, flip):
    if flip:
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img


================================================
FILE: Global/data/custom_dataset_data_loader.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch.utils.data
import random
from data.base_data_loader import BaseDataLoader
from data import online_dataset_for_old_photos as dts_ray_bigfile


def CreateDataset(opt):
    dataset = None
    if opt.training_dataset=='domain_A' or opt.training_dataset=='domain_B':
        dataset = dts_ray_bigfile.UnPairOldPhotos_SR()
    if opt.training_dataset=='mapping':
        if opt.random_hole:
            dataset = dts_ray_bigfile.PairOldPhotos_with_hole()
        else:
            dataset = dts_ray_bigfile.PairOldPhotos()
    print("dataset [%s] was created" % (dataset.name()))
    dataset.initialize(opt)
    return dataset

class CustomDatasetDataLoader(BaseDataLoader):
    def name(self):
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        BaseDataLoader.initialize(self, opt)
        self.dataset = CreateDataset(opt)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batchSize,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.nThreads),
            drop_last=True)

    def load_data(self):
        return self.dataloader

    def __len__(self):
        return min(len(self.dataset), self.opt.max_dataset_size)


================================================
FILE: Global/data/data_loader.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

def CreateDataLoader(opt):
    from data.custom_dataset_data_loader import CustomDatasetDataLoader
    data_loader = CustomDatasetDataLoader()
    print(data_loader.name())
    data_loader.initialize(opt)
    return data_loader


================================================
FILE: Global/data/image_folder.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch.utils.data as data
from PIL import Image
import os

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff'
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir):
    images = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)

    return images


def default_loader(path):
    return Image.open(path).convert('RGB')


class ImageFolder(data.Dataset):

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        imgs = make_dataset(root)
        if len(imgs) == 0:
            raise(RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " +
                               ",".join(IMG_EXTENSIONS)))

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.return_paths:
            return img, path
        else:
            return img

    def __len__(self):
        return len(self.imgs)


================================================
FILE: Global/data/online_dataset_for_old_photos.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import os.path
import io
import zipfile
from data.base_dataset import BaseDataset, get_params, get_transform, normalize
from data.image_folder import make_dataset
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
from data.Load_Bigfile import BigFileMemoryLoader
import random
import cv2
from io import BytesIO

def pil_to_np(img_PIL):
    '''Converts image in PIL format to np.array.

    From W x H x C [0...255] to C x W x H [0..1]
    '''
    ar = np.array(img_PIL)

    if len(ar.shape) == 3:
        ar = ar.transpose(2, 0, 1)
    else:
        ar = ar[None, ...]

    return ar.astype(np.float32) / 255.


def np_to_pil(img_np):
    '''Converts image in np.array format to PIL image.

    From C x W x H [0..1] to  W x H x C [0...255]
    '''
    ar = np.clip(img_np * 255, 0, 255).astype(np.uint8)

    if img_np.shape[0] == 1:
        ar = ar[0]
    else:
        ar = ar.transpose(1, 2, 0)

    return Image.fromarray(ar)

def synthesize_salt_pepper(image,amount,salt_vs_pepper):

    ## Give PIL, return the noisy PIL

    img_pil=pil_to_np(image)

    out = img_pil.copy()
    p = amount
    q = salt_vs_pepper
    flipped = np.random.choice([True, False], size=img_pil.shape,
                               p=[p, 1 - p])
    salted = np.random.choice([True, False], size=img_pil.shape,
                              p=[q, 1 - q])
    peppered = ~salted
    out[flipped & salted] = 1
    out[flipped & peppered] = 0.
    noisy = np.clip(out, 0, 1).astype(np.float32)


    return np_to_pil(noisy)

def synthesize_gaussian(image,std_l,std_r):

    ## Give PIL, return the noisy PIL

    img_pil=pil_to_np(image)

    mean=0
    std=random.uniform(std_l/255.,std_r/255.)
    gauss=np.random.normal(loc=mean,scale=std,size=img_pil.shape)
    noisy=img_pil+gauss
    noisy=np.clip(noisy,0,1).astype(np.float32)

    return np_to_pil(noisy)

def synthesize_speckle(image,std_l,std_r):

    ## Give PIL, return the noisy PIL

    img_pil=pil_to_np(image)

    mean=0
    std=random.uniform(std_l/255.,std_r/255.)
    gauss=np.random.normal(loc=mean,scale=std,size=img_pil.shape)
    noisy=img_pil+gauss*img_pil
    noisy=np.clip(noisy,0,1).astype(np.float32)

    return np_to_pil(noisy)


def synthesize_low_resolution(img):
    w,h=img.size

    new_w=random.randint(int(w/2),w)
    new_h=random.randint(int(h/2),h)

    img=img.resize((new_w,new_h),Image.BICUBIC)

    if random.uniform(0,1)<0.5:
        img=img.resize((w,h),Image.NEAREST)
    else:
        img = img.resize((w, h), Image.BILINEAR)

    return img


def convertToJpeg(im,quality):
    with BytesIO() as f:
        im.save(f, format='JPEG',quality=quality)
        f.seek(0)
        return Image.open(f).convert('RGB')


def blur_image_v2(img):


    x=np.array(img)
    kernel_size_candidate=[(3,3),(5,5),(7,7)]
    kernel_size=random.sample(kernel_size_candidate,1)[0]
    std=random.uniform(1.,5.)

    #print("The gaussian kernel size: (%d,%d) std: %.2f"%(kernel_size[0],kernel_size[1],std))
    blur=cv2.GaussianBlur(x,kernel_size,std)

    return Image.fromarray(blur.astype(np.uint8))

def online_add_degradation_v2(img):

    task_id=np.random.permutation(4)

    for x in task_id:
        if x==0 and random.uniform(0,1)<0.7:
            img = blur_image_v2(img)
        if x==1 and random.uniform(0,1)<0.7:
            flag = random.choice([1, 2, 3])
            if flag == 1:
                img = synthesize_gaussian(img, 5, 50)
            if flag == 2:
                img = synthesize_speckle(img, 5, 50)
            if flag == 3:
                img = synthesize_salt_pepper(img, random.uniform(0, 0.01), random.uniform(0.3, 0.8))
        if x==2 and random.uniform(0,1)<0.7:
            img=synthesize_low_resolution(img)

        if x==3 and random.uniform(0,1)<0.7:
            img=convertToJpeg(img,random.randint(40,100))

    return img


def irregular_hole_synthesize(img,mask):

    img_np=np.array(img).astype('uint8')
    mask_np=np.array(mask).astype('uint8')
    mask_np=mask_np/255
    img_new=img_np*(1-mask_np)+mask_np*255


    hole_img=Image.fromarray(img_new.astype('uint8')).convert("RGB")

    return hole_img,mask.convert("L")

def zero_mask(size):
    x=np.zeros((size,size,3)).astype('uint8')
    mask=Image.fromarray(x).convert("RGB")
    return mask



class UnPairOldPhotos_SR(BaseDataset):  ## Synthetic + Real Old
    def initialize(self, opt):
        self.opt = opt
        self.isImage = 'domainA' in opt.name
        self.task = 'old_photo_restoration_training_vae'
        self.dir_AB = opt.dataroot
        if self.isImage:

            self.load_img_dir_L_old=os.path.join(self.dir_AB,"Real_L_old.bigfile")
            self.load_img_dir_RGB_old=os.path.join(self.dir_AB,"Real_RGB_old.bigfile")
            self.load_img_dir_clean=os.path.join(self.dir_AB,"VOC_RGB_JPEGImages.bigfile")

            self.loaded_imgs_L_old=BigFileMemoryLoader(self.load_img_dir_L_old)
            self.loaded_imgs_RGB_old=BigFileMemoryLoader(self.load_img_dir_RGB_old)
            self.loaded_imgs_clean=BigFileMemoryLoader(self.load_img_dir_clean)

        else:
            # self.load_img_dir_clean=os.path.join(self.dir_AB,self.opt.test_dataset)
            self.load_img_dir_clean=os.path.join(self.dir_AB,"VOC_RGB_JPEGImages.bigfile")
            self.loaded_imgs_clean=BigFileMemoryLoader(self.load_img_dir_clean)

        ####
        print("-------------Filter the imgs whose size <256 in VOC-------------")
        self.filtered_imgs_clean=[]
        for i in range(len(self.loaded_imgs_clean)):
            img_name,img=self.loaded_imgs_clean[i]
            h,w=img.size
            if h<256 or w<256:
                continue
            self.filtered_imgs_clean.append((img_name,img))

        print("--------Origin image num is [%d], filtered result is [%d]--------" % (
        len(self.loaded_imgs_clean), len(self.filtered_imgs_clean)))
        ## Filter these images whose size is less than 256

        # self.img_list=os.listdir(load_img_dir)
        self.pid = os.getpid()

    def __getitem__(self, index):


        is_real_old=0

        sampled_dataset=None
        degradation=None
        if self.isImage: ## domain A , contains 2 kinds of data: synthetic + real_old
            P=random.uniform(0,2)
            if P>=0 and P<1:
                if random.uniform(0,1)<0.5:
                    sampled_dataset=self.loaded_imgs_L_old
                    self.load_img_dir=self.load_img_dir_L_old
                else:
                    sampled_dataset=self.loaded_imgs_RGB_old
                    self.load_img_dir=self.load_img_dir_RGB_old
                is_real_old=1
            if P>=1 and P<2:
                sampled_dataset=self.filtered_imgs_clean
                self.load_img_dir=self.load_img_dir_clean
                degradation=1
        else:

            sampled_dataset=self.filtered_imgs_clean
            self.load_img_dir=self.load_img_dir_clean

        sampled_dataset_len=len(sampled_dataset)

        index=random.randint(0,sampled_dataset_len-1)

        img_name,img = sampled_dataset[index]

        if degradation is not None:
            img=online_add_degradation_v2(img)

        path=os.path.join(self.load_img_dir,img_name)

        # AB = Image.open(path).convert('RGB')
        # split AB image into A and B

        # apply the same transform to both A and B

        if random.uniform(0,1) <0.1:
            img=img.convert("L")
            img=img.convert("RGB")
            ## Give a probability P, we convert the RGB image into L


        A=img
        w,h=A.size
        if w<256 or h<256:
            A=transforms.Scale(256,Image.BICUBIC)(A)
        ## Since we want to only crop the images (256*256), for those old photos whose size is smaller than 256, we first resize them.

        transform_params = get_params(self.opt, A.size)
        A_transform = get_transform(self.opt, transform_params)

        B_tensor = inst_tensor = feat_tensor = 0
        A_tensor = A_transform(A)


        input_dict = {'label': A_tensor, 'inst': is_real_old, 'image': A_tensor,
                        'feat': feat_tensor, 'path': path}
        return input_dict

    def __len__(self):
        return len(self.loaded_imgs_clean) ## actually, this is useless, since the selected index is just a random number

    def name(self):
        return 'UnPairOldPhotos_SR'


class PairOldPhotos(BaseDataset):
    def initialize(self, opt):
        self.opt = opt
        self.isImage = 'imagegan' in opt.name
        self.task = 'old_photo_restoration_training_mapping'
        self.dir_AB = opt.dataroot
        if opt.isTrain:
            self.load_img_dir_clean= os.path.join(self.dir_AB, "VOC_RGB_JPEGImages.bigfile")
            self.loaded_imgs_clean = BigFileMemoryLoader(self.load_img_dir_clean)

            print("-------------Filter the imgs whose size <256 in VOC-------------")
            self.filtered_imgs_clean = []
            for i in range(len(self.loaded_imgs_clean)):
                img_name, img = self.loaded_imgs_clean[i]
                h, w = img.size
                if h < 256 or w < 256:
                    continue
                self.filtered_imgs_clean.append((img_name, img))

            print("--------Origin image num is [%d], filtered result is [%d]--------" % (
            len(self.loaded_imgs_clean), len(self.filtered_imgs_clean)))

        else:
            self.load_img_dir=os.path.join(self.dir_AB,opt.test_dataset)
            self.loaded_imgs=BigFileMemoryLoader(self.load_img_dir)

        self.pid = os.getpid()

    def __getitem__(self, index):



        if self.opt.isTrain:
            img_name_clean,B = self.filtered_imgs_clean[index]
            path = os.path.join(self.load_img_dir_clean, img_name_clean)
            if self.opt.use_v2_degradation:
                A=online_add_degradation_v2(B)
            ### Remind: A is the input and B is corresponding GT
        else:

            if self.opt.test_on_synthetic:

                img_name_B,B=self.loaded_imgs[index]
                A=online_add_degradation_v2(B)
                img_name_A=img_name_B
                path = os.path.join(self.load_img_dir, img_name_A)
            else:
                img_name_A,A=self.loaded_imgs[index]
                img_name_B,B=self.loaded_imgs[index]
                path = os.path.join(self.load_img_dir, img_name_A)


        if random.uniform(0,1)<0.1 and self.opt.isTrain:
            A=A.convert("L")
            B=B.convert("L")
            A=A.convert("RGB")
            B=B.convert("RGB")
        ## In P, we convert the RGB into L


        ##test on L

        # split AB image into A and B
        # w, h = img.size
        # w2 = int(w / 2)
        # A = img.crop((0, 0, w2, h))
        # B = img.crop((w2, 0, w, h))
        w,h=A.size
        if w<256 or h<256:
            A=transforms.Scale(256,Image.BICUBIC)(A)
            B=transforms.Scale(256, Image.BICUBIC)(B)

        # apply the same transform to both A and B
        transform_params = get_params(self.opt, A.size)
        A_transform = get_transform(self.opt, transform_params)
        B_transform = get_transform(self.opt, transform_params)

        B_tensor = inst_tensor = feat_tensor = 0
        A_tensor = A_transform(A)
        B_tensor = B_transform(B)

        input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor,
                    'feat': feat_tensor, 'path': path}
        return input_dict

    def __len__(self):

        if self.opt.isTrain:
            return len(self.filtered_imgs_clean)
        else:
            return len(self.loaded_imgs)

    def name(self):
        return 'PairOldPhotos'


class PairOldPhotos_with_hole(BaseDataset):
    def initialize(self, opt):
        self.opt = opt
        self.isImage = 'imagegan' in opt.name
        self.task = 'old_photo_restoration_training_mapping'
        self.dir_AB = opt.dataroot
        if opt.isTrain:
            self.load_img_dir_clean= os.path.join(self.dir_AB, "VOC_RGB_JPEGImages.bigfile")
            self.loaded_imgs_clean = BigFileMemoryLoader(self.load_img_dir_clean)

            print("-------------Filter the imgs whose size <256 in VOC-------------")
            self.filtered_imgs_clean = []
            for i in range(len(self.loaded_imgs_clean)):
                img_name, img = self.loaded_imgs_clean[i]
                h, w = img.size
                if h < 256 or w < 256:
                    continue
                self.filtered_imgs_clean.append((img_name, img))

            print("--------Origin image num is [%d], filtered result is [%d]--------" % (
            len(self.loaded_imgs_clean), len(self.filtered_imgs_clean)))

        else:
            self.load_img_dir=os.path.join(self.dir_AB,opt.test_dataset)
            self.loaded_imgs=BigFileMemoryLoader(self.load_img_dir)

        self.loaded_masks = BigFileMemoryLoader(opt.irregular_mask)

        self.pid = os.getpid()

    def __getitem__(self, index):



        if self.opt.isTrain:
            img_name_clean,B = self.filtered_imgs_clean[index]
            path = os.path.join(self.load_img_dir_clean, img_name_clean)


            B=transforms.RandomCrop(256)(B)
            A=online_add_degradation_v2(B)
            ### Remind: A is the input and B is corresponding GT

        else:
            img_name_A,A=self.loaded_imgs[index]
            img_name_B,B=self.loaded_imgs[index]
            path = os.path.join(self.load_img_dir, img_name_A)

            #A=A.resize((256,256))
            A=transforms.CenterCrop(256)(A)
            B=A

        if random.uniform(0,1)<0.1 and self.opt.isTrain:
            A=A.convert("L")
            B=B.convert("L")
            A=A.convert("RGB")
            B=B.convert("RGB")
        ## In P, we convert the RGB into L

        if self.opt.isTrain:
            mask_name,mask=self.loaded_masks[random.randint(0,len(self.loaded_masks)-1)]
        else:
            mask_name, mask = self.loaded_masks[index%100]
        mask = mask.resize((self.opt.loadSize, self.opt.loadSize), Image.NEAREST)

        if self.opt.random_hole and random.uniform(0,1)>0.5 and self.opt.isTrain:
            mask=zero_mask(256)

        if self.opt.no_hole:
            mask=zero_mask(256)


        A,_=irregular_hole_synthesize(A,mask)

        if not self.opt.isTrain and self.opt.hole_image_no_mask:
            mask=zero_mask(256)

        transform_params = get_params(self.opt, A.size)
        A_transform = get_transform(self.opt, transform_params)
        B_transform = get_transform(self.opt, transform_params)

        if transform_params['flip'] and self.opt.isTrain:
            mask=mask.transpose(Image.FLIP_LEFT_RIGHT)

        mask_tensor = transforms.ToTensor()(mask)


        B_tensor = inst_tensor = feat_tensor = 0
        A_tensor = A_transform(A)
        B_tensor = B_transform(B)

        input_dict = {'label': A_tensor, 'inst': mask_tensor[:1], 'image': B_tensor,
                    'feat': feat_tensor, 'path': path}
        return input_dict

    def __len__(self):

        if self.opt.isTrain:
            return len(self.filtered_imgs_clean)

        else:
            return len(self.loaded_imgs)

    def name(self):
        return 'PairOldPhotos_with_hole'

================================================
FILE: Global/detection.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import argparse
import gc
import json
import os
import time
import warnings

import numpy as np
import torch
import torch.nn.functional as F
import torchvision as tv
from PIL import Image, ImageFile

from detection_models import networks
from detection_util.util import *

warnings.filterwarnings("ignore", category=UserWarning)

ImageFile.LOAD_TRUNCATED_IMAGES = True


def data_transforms(img, full_size, method=Image.BICUBIC):
    if full_size == "full_size":
        ow, oh = img.size
        h = int(round(oh / 16) * 16)
        w = int(round(ow / 16) * 16)
        if (h == oh) and (w == ow):
            return img
        return img.resize((w, h), method)

    elif full_size == "scale_256":
        ow, oh = img.size
        pw, ph = ow, oh
        if ow < oh:
            ow = 256
            oh = ph / pw * 256
        else:
            oh = 256
            ow = pw / ph * 256

        h = int(round(oh / 16) * 16)
        w = int(round(ow / 16) * 16)
        if (h == ph) and (w == pw):
            return img
        return img.resize((w, h), method)


def scale_tensor(img_tensor, default_scale=256):
    _, _, w, h = img_tensor.shape
    if w < h:
        ow = default_scale
        oh = h / w * default_scale
    else:
        oh = default_scale
        ow = w / h * default_scale

    oh = int(round(oh / 16) * 16)
    ow = int(round(ow / 16) * 16)

    return F.interpolate(img_tensor, [ow, oh], mode="bilinear")


def blend_mask(img, mask):

    np_img = np.array(img).astype("float")

    return Image.fromarray((np_img * (1 - mask) + mask * 255.0).astype("uint8")).convert("RGB")


def main(config):
    print("initializing the dataloader")

    model = networks.UNet(
        in_channels=1,
        out_channels=1,
        depth=4,
        conv_num=2,
        wf=6,
        padding=True,
        batch_norm=True,
        up_mode="upsample",
        with_tanh=False,
        sync_bn=True,
        antialiasing=True,
    )

    ## load model
    checkpoint_path = os.path.join(os.path.dirname(__file__), "checkpoints/detection/FT_Epoch_latest.pt")
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state"])
    print("model weights loaded")

    if config.GPU >= 0:
        model.to(config.GPU)
    else: 
        model.cpu()
    model.eval()

    ## dataloader and transformation
    print("directory of testing image: " + config.test_path)
    imagelist = os.listdir(config.test_path)
    imagelist.sort()
    total_iter = 0

    P_matrix = {}
    save_url = os.path.join(config.output_dir)
    mkdir_if_not(save_url)

    input_dir = os.path.join(save_url, "input")
    output_dir = os.path.join(save_url, "mask")
    # blend_output_dir=os.path.join(save_url, 'blend_output')
    mkdir_if_not(input_dir)
    mkdir_if_not(output_dir)
    # mkdir_if_not(blend_output_dir)

    idx = 0

    results = []
    for image_name in imagelist:

        idx += 1

        print("processing", image_name)

        scratch_file = os.path.join(config.test_path, image_name)
        if not os.path.isfile(scratch_file):
            print("Skipping non-file %s" % image_name)
            continue
        scratch_image = Image.open(scratch_file).convert("RGB")
        w, h = scratch_image.size

        transformed_image_PIL = data_transforms(scratch_image, config.input_size)
        scratch_image = transformed_image_PIL.convert("L")
        scratch_image = tv.transforms.ToTensor()(scratch_image)
        scratch_image = tv.transforms.Normalize([0.5], [0.5])(scratch_image)
        scratch_image = torch.unsqueeze(scratch_image, 0)
        _, _, ow, oh = scratch_image.shape
        scratch_image_scale = scale_tensor(scratch_image)

        if config.GPU >= 0:
            scratch_image_scale = scratch_image_scale.to(config.GPU)
        else:
            scratch_image_scale = scratch_image_scale.cpu()
        with torch.no_grad():
            P = torch.sigmoid(model(scratch_image_scale))

        P = P.data.cpu()
        P = F.interpolate(P, [ow, oh], mode="nearest")

        tv.utils.save_image(
            (P >= 0.4).float(),
            os.path.join(
                output_dir,
                image_name[:-4] + ".png",
            ),
            nrow=1,
            padding=0,
            normalize=True,
        )
        transformed_image_PIL.save(os.path.join(input_dir, image_name[:-4] + ".png"))
        gc.collect()
        torch.cuda.empty_cache()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # parser.add_argument('--checkpoint_name', type=str, default="FT_Epoch_latest.pt", help='Checkpoint Name')

    parser.add_argument("--GPU", type=int, default=0)
    parser.add_argument("--test_path", type=str, default=".")
    parser.add_argument("--output_dir", type=str, default=".")
    parser.add_argument("--input_size", type=str, default="scale_256", help="resize_256|full_size|scale_256")
    config = parser.parse_args()

    main(config)


================================================
FILE: Global/detection_models/__init__.py
================================================


================================================
FILE: Global/detection_models/antialiasing.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
import torch.nn.parallel
import numpy as np
import torch.nn as nn
import torch.nn.functional as F


class Downsample(nn.Module):
    # https://github.com/adobe/antialiased-cnns

    def __init__(self, pad_type="reflect", filt_size=3, stride=2, channels=None, pad_off=0):
        super(Downsample, self).__init__()
        self.filt_size = filt_size
        self.pad_off = pad_off
        self.pad_sizes = [
            int(1.0 * (filt_size - 1) / 2),
            int(np.ceil(1.0 * (filt_size - 1) / 2)),
            int(1.0 * (filt_size - 1) / 2),
            int(np.ceil(1.0 * (filt_size - 1) / 2)),
        ]
        self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
        self.stride = stride
        self.off = int((self.stride - 1) / 2.0)
        self.channels = channels

        # print('Filter size [%i]'%filt_size)
        if self.filt_size == 1:
            a = np.array([1.0,])
        elif self.filt_size == 2:
            a = np.array([1.0, 1.0])
        elif self.filt_size == 3:
            a = np.array([1.0, 2.0, 1.0])
        elif self.filt_size == 4:
            a = np.array([1.0, 3.0, 3.0, 1.0])
        elif self.filt_size == 5:
            a = np.array([1.0, 4.0, 6.0, 4.0, 1.0])
        elif self.filt_size == 6:
            a = np.array([1.0, 5.0, 10.0, 10.0, 5.0, 1.0])
        elif self.filt_size == 7:
            a = np.array([1.0, 6.0, 15.0, 20.0, 15.0, 6.0, 1.0])

        filt = torch.Tensor(a[:, None] * a[None, :])
        filt = filt / torch.sum(filt)
        self.register_buffer("filt", filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))

        self.pad = get_pad_layer(pad_type)(self.pad_sizes)

    def forward(self, inp):
        if self.filt_size == 1:
            if self.pad_off == 0:
                return inp[:, :, :: self.stride, :: self.stride]
            else:
                return self.pad(inp)[:, :, :: self.stride, :: self.stride]
        else:
            return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])


def get_pad_layer(pad_type):
    if pad_type in ["refl", "reflect"]:
        PadLayer = nn.ReflectionPad2d
    elif pad_type in ["repl", "replicate"]:
        PadLayer = nn.ReplicationPad2d
    elif pad_type == "zero":
        PadLayer = nn.ZeroPad2d
    else:
        print("Pad type [%s] not recognized" % pad_type)
    return PadLayer


================================================
FILE: Global/detection_models/networks.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
import torch.nn as nn
import torch.nn.functional as F
from detection_models.sync_batchnorm import DataParallelWithCallback
from detection_models.antialiasing import Downsample


class UNet(nn.Module):
    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        depth=5,
        conv_num=2,
        wf=6,
        padding=True,
        batch_norm=True,
        up_mode="upsample",
        with_tanh=False,
        sync_bn=True,
        antialiasing=True,
    ):
        """
		Implementation of
		U-Net: Convolutional Networks for Biomedical Image Segmentation
		(Ronneberger et al., 2015)
		https://arxiv.org/abs/1505.04597
		Using the default arguments will yield the exact version used
		in the original paper
		Args:
			in_channels (int): number of input channels
			out_channels (int): number of output channels
			depth (int): depth of the network
			wf (int): number of filters in the first layer is 2**wf
			padding (bool): if True, apply padding such that the input shape
							is the same as the output.
							This may introduce artifacts
			batch_norm (bool): Use BatchNorm after layers with an
							   activation function
			up_mode (str): one of 'upconv' or 'upsample'.
						   'upconv' will use transposed convolutions for
						   learned upsampling.
						   'upsample' will use bilinear upsampling.
		"""
        super().__init__()
        assert up_mode in ("upconv", "upsample")
        self.padding = padding
        self.depth = depth - 1
        prev_channels = in_channels

        self.first = nn.Sequential(
            *[nn.ReflectionPad2d(3), nn.Conv2d(in_channels, 2 ** wf, kernel_size=7), nn.LeakyReLU(0.2, True)]
        )
        prev_channels = 2 ** wf

        self.down_path = nn.ModuleList()
        self.down_sample = nn.ModuleList()
        for i in range(depth):
            if antialiasing and depth > 0:
                self.down_sample.append(
                    nn.Sequential(
                        *[
                            nn.ReflectionPad2d(1),
                            nn.Conv2d(prev_channels, prev_channels, kernel_size=3, stride=1, padding=0),
                            nn.BatchNorm2d(prev_channels),
                            nn.LeakyReLU(0.2, True),
                            Downsample(channels=prev_channels, stride=2),
                        ]
                    )
                )
            else:
                self.down_sample.append(
                    nn.Sequential(
                        *[
                            nn.ReflectionPad2d(1),
                            nn.Conv2d(prev_channels, prev_channels, kernel_size=4, stride=2, padding=0),
                            nn.BatchNorm2d(prev_channels),
                            nn.LeakyReLU(0.2, True),
                        ]
                    )
                )
            self.down_path.append(
                UNetConvBlock(conv_num, prev_channels, 2 ** (wf + i + 1), padding, batch_norm)
            )
            prev_channels = 2 ** (wf + i + 1)

        self.up_path = nn.ModuleList()
        for i in reversed(range(depth)):
            self.up_path.append(
                UNetUpBlock(conv_num, prev_channels, 2 ** (wf + i), up_mode, padding, batch_norm)
            )
            prev_channels = 2 ** (wf + i)

        if with_tanh:
            self.last = nn.Sequential(
                *[nn.ReflectionPad2d(1), nn.Conv2d(prev_channels, out_channels, kernel_size=3), nn.Tanh()]
            )
        else:
            self.last = nn.Sequential(
                *[nn.ReflectionPad2d(1), nn.Conv2d(prev_channels, out_channels, kernel_size=3)]
            )

        if sync_bn:
            self = DataParallelWithCallback(self)

    def forward(self, x):
        x = self.first(x)

        blocks = []
        for i, down_block in enumerate(self.down_path):
            blocks.append(x)
            x = self.down_sample[i](x)
            x = down_block(x)

        for i, up in enumerate(self.up_path):
            x = up(x, blocks[-i - 1])

        return self.last(x)


class UNetConvBlock(nn.Module):
    def __init__(self, conv_num, in_size, out_size, padding, batch_norm):
        super(UNetConvBlock, self).__init__()
        block = []

        for _ in range(conv_num):
            block.append(nn.ReflectionPad2d(padding=int(padding)))
            block.append(nn.Conv2d(in_size, out_size, kernel_size=3, padding=0))
            if batch_norm:
                block.append(nn.BatchNorm2d(out_size))
            block.append(nn.LeakyReLU(0.2, True))
            in_size = out_size

        self.block = nn.Sequential(*block)

    def forward(self, x):
        out = self.block(x)
        return out


class UNetUpBlock(nn.Module):
    def __init__(self, conv_num, in_size, out_size, up_mode, padding, batch_norm):
        super(UNetUpBlock, self).__init__()
        if up_mode == "upconv":
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
        elif up_mode == "upsample":
            self.up = nn.Sequential(
                nn.Upsample(mode="bilinear", scale_factor=2, align_corners=False),
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_size, out_size, kernel_size=3, padding=0),
            )

        self.conv_block = UNetConvBlock(conv_num, in_size, out_size, padding, batch_norm)

    def center_crop(self, layer, target_size):
        _, _, layer_height, layer_width = layer.size()
        diff_y = (layer_height - target_size[0]) // 2
        diff_x = (layer_width - target_size[1]) // 2
        return layer[:, :, diff_y : (diff_y + target_size[0]), diff_x : (diff_x + target_size[1])]

    def forward(self, x, bridge):
        up = self.up(x)
        crop1 = self.center_crop(bridge, up.shape[2:])
        out = torch.cat([up, crop1], 1)
        out = self.conv_block(out)

        return out


class UnetGenerator(nn.Module):
    """Create a Unet-based generator"""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_type="BN", use_dropout=False):
        """Construct a Unet generator
		Parameters:
			input_nc (int)  -- the number of channels in input images
			output_nc (int) -- the number of channels in output images
			num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
								image of size 128x128 will become of size 1x1 # at the bottleneck
			ngf (int)       -- the number of filters in the last conv layer
			norm_layer      -- normalization layer
		We construct the U-Net from the innermost layer to the outermost layer.
		It is a recursive process.
		"""
        super().__init__()
        if norm_type == "BN":
            norm_layer = nn.BatchNorm2d
        elif norm_type == "IN":
            norm_layer = nn.InstanceNorm2d
        else:
            raise NameError("Unknown norm layer")

        # construct unet structure
        unet_block = UnetSkipConnectionBlock(
            ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True
        )  # add the innermost layer
        for i in range(num_downs - 5):  # add intermediate layers with ngf * 8 filters
            unet_block = UnetSkipConnectionBlock(
                ngf * 8,
                ngf * 8,
                input_nc=None,
                submodule=unet_block,
                norm_layer=norm_layer,
                use_dropout=use_dropout,
            )
        # gradually reduce the number of filters from ngf * 8 to ngf
        unet_block = UnetSkipConnectionBlock(
            ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer
        )
        unet_block = UnetSkipConnectionBlock(
            ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer
        )
        unet_block = UnetSkipConnectionBlock(
            ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer
        )
        self.model = UnetSkipConnectionBlock(
            output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer
        )  # add the outermost layer

    def forward(self, input):
        return self.model(input)


class UnetSkipConnectionBlock(nn.Module):
    """Defines the Unet submodule with skip connection.

		-------------------identity----------------------
		|-- downsampling -- |submodule| -- upsampling --|
	"""

    def __init__(
        self,
        outer_nc,
        inner_nc,
        input_nc=None,
        submodule=None,
        outermost=False,
        innermost=False,
        norm_layer=nn.BatchNorm2d,
        use_dropout=False,
    ):
        """Construct a Unet submodule with skip connections.
		Parameters:
			outer_nc (int) -- the number of filters in the outer conv layer
			inner_nc (int) -- the number of filters in the inner conv layer
			input_nc (int) -- the number of channels in input images/features
			submodule (UnetSkipConnectionBlock) -- previously defined submodules
			outermost (bool)    -- if this module is the outermost module
			innermost (bool)    -- if this module is the innermost module
			norm_layer          -- normalization layer
			user_dropout (bool) -- if use dropout layers.
		"""
        super().__init__()
        self.outermost = outermost
        use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.LeakyReLU(0.2, True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(
                inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias
            )
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:  # add skip connections
            return torch.cat([x, self.model(x)], 1)


# ============================================
# Network testing
# ============================================
if __name__ == "__main__":
    from torchsummary import summary

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = UNet_two_decoders(
        in_channels=3,
        out_channels1=3,
        out_channels2=1,
        depth=4,
        conv_num=1,
        wf=6,
        padding=True,
        batch_norm=True,
        up_mode="upsample",
        with_tanh=False,
    )
    model.to(device)

    model_pix2pix = UnetGenerator(3, 3, 5, ngf=64, norm_type="BN", use_dropout=False)
    model_pix2pix.to(device)

    print("customized unet:")
    summary(model, (3, 256, 256))

    print("cyclegan unet:")
    summary(model_pix2pix, (3, 256, 256))

    x = torch.zeros(1, 3, 256, 256).requires_grad_(True).cuda()
    g = make_dot(model(x))
    g.render("models/Digraph.gv", view=False)



================================================
FILE: Global/detection_util/util.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import os
import sys
import time
import shutil
import platform
import numpy as np
from datetime import datetime

import torch
import torchvision as tv
import torch.backends.cudnn as cudnn

# from torch.utils.tensorboard import SummaryWriter

import yaml
import matplotlib.pyplot as plt
from easydict import EasyDict as edict
import torchvision.utils as vutils


##### option parsing ######
def print_options(config_dict):
    print("------------ Options -------------")
    for k, v in sorted(config_dict.items()):
        print("%s: %s" % (str(k), str(v)))
    print("-------------- End ----------------")


def save_options(config_dict):
    from time import gmtime, strftime

    file_dir = os.path.join(config_dict["checkpoint_dir"], config_dict["name"])
    mkdir_if_not(file_dir)
    file_name = os.path.join(file_dir, "opt.txt")
    with open(file_name, "wt") as opt_file:
        opt_file.write(os.path.basename(sys.argv[0]) + " " + strftime("%Y-%m-%d %H:%M:%S", gmtime()) + "\n")
        opt_file.write("------------ Options -------------\n")
        for k, v in sorted(config_dict.items()):
            opt_file.write("%s: %s\n" % (str(k), str(v)))
        opt_file.write("-------------- End ----------------\n")


def config_parse(config_file, options, save=True):
    with open(config_file, "r") as stream:
        config_dict = yaml.safe_load(stream)
        config = edict(config_dict)

    for option_key, option_value in vars(options).items():
        config_dict[option_key] = option_value
        config[option_key] = option_value

    if config.debug_mode:
        config_dict["num_workers"] = 0
        config.num_workers = 0
        config.batch_size = 2
        if isinstance(config.gpu_ids, str):
            config.gpu_ids = [int(x) for x in config.gpu_ids.split(",")][0]

    print_options(config_dict)
    if save:
        save_options(config_dict)

    return config


###### utility ######
def to_np(x):
    return x.cpu().numpy()


def prepare_device(use_gpu, gpu_ids):
    if use_gpu:
        cudnn.benchmark = True
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        if isinstance(gpu_ids, str):
            gpu_ids = [int(x) for x in gpu_ids.split(",")]
            torch.cuda.set_device(gpu_ids[0])
            device = torch.device("cuda:" + str(gpu_ids[0]))
        else:
            torch.cuda.set_device(gpu_ids)
            device = torch.device("cuda:" + str(gpu_ids))
        print("running on GPU {}".format(gpu_ids))
    else:
        device = torch.device("cpu")
        print("running on CPU")

    return device


###### file system ######
def get_dir_size(start_path="."):
    total_size = 0
    for dirpath, dirnames, filenames in os.walk(start_path):
        for f in filenames:
            fp = os.path.join(dirpath, f)
            total_size += os.path.getsize(fp)
    return total_size


def mkdir_if_not(dir_path):
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)


##### System related ######
class Timer:
    def __init__(self, msg):
        self.msg = msg
        self.start_time = None

    def __enter__(self):
        self.start_time = time.time()

    def __exit__(self, exc_type, exc_value, exc_tb):
        elapse = time.time() - self.start_time
        print(self.msg % elapse)


###### interactive ######
def get_size(start_path="."):
    total_size = 0
    for dirpath, dirnames, filenames in os.walk(start_path):
        for f in filenames:
            fp = os.path.join(dirpath, f)
            total_size += os.path.getsize(fp)
    return total_size


def clean_tensorboard(directory):
    tensorboard_list = os.listdir(directory)
    SIZE_THRESH = 100000
    for tensorboard in tensorboard_list:
        tensorboard = os.path.join(directory, tensorboard)
        if get_size(tensorboard) < SIZE_THRESH:
            print("deleting the empty tensorboard: ", tensorboard)
            #
            if os.path.isdir(tensorboard):
                shutil.rmtree(tensorboard)
            else:
                os.remove(tensorboard)


def prepare_tensorboard(config, experiment_name=datetime.now().strftime("%Y-%m-%d %H-%M-%S")):
    tensorboard_directory = os.path.join(config.checkpoint_dir, config.name, "tensorboard_logs")
    mkdir_if_not(tensorboard_directory)
    clean_tensorboard(tensorboard_directory)
    tb_writer = SummaryWriter(os.path.join(tensorboard_directory, experiment_name), flush_secs=10)

    # try:
    #     shutil.copy('outputs/opt.txt', tensorboard_directory)
    # except:
    #     print('cannot find file opt.txt')
    return tb_writer


def tb_loss_logger(tb_writer, iter_index, loss_logger):
    for tag, value in loss_logger.items():
        tb_writer.add_scalar(tag, scalar_value=value.item(), global_step=iter_index)


def tb_image_logger(tb_writer, iter_index, images_info, config):
    ### Save and write the output into the tensorboard
    tb_logger_path = os.path.join(config.output_dir, config.name, config.train_mode)
    mkdir_if_not(tb_logger_path)
    for tag, image in images_info.items():
        if tag == "test_image_prediction" or tag == "image_prediction":
            continue
        image = tv.utils.make_grid(image.cpu())
        image = torch.clamp(image, 0, 1)
        tb_writer.add_image(tag, img_tensor=image, global_step=iter_index)
        tv.transforms.functional.to_pil_image(image).save(
            os.path.join(tb_logger_path, "{:06d}_{}.jpg".format(iter_index, tag))
        )


def tb_image_logger_test(epoch, iter, images_info, config):

    url = os.path.join(config.output_dir, config.name, config.train_mode, "val_" + str(epoch))
    if not os.path.exists(url):
        os.makedirs(url)
    scratch_img = images_info["test_scratch_image"].data.cpu()
    if config.norm_input:
        scratch_img = (scratch_img + 1.0) / 2.0
    scratch_img = torch.clamp(scratch_img, 0, 1)
    gt_mask = images_info["test_mask_image"].data.cpu()
    predict_mask = images_info["test_scratch_prediction"].data.cpu()

    predict_hard_mask = (predict_mask.data.cpu() >= 0.5).float()

    imgs = torch.cat((scratch_img, predict_hard_mask, gt_mask), 0)
    img_grid = vutils.save_image(
        imgs, os.path.join(url, str(iter) + ".jpg"), nrow=len(scratch_img), padding=0, normalize=True
    )


def imshow(input_image, title=None, to_numpy=False):
    inp = input_image
    if to_numpy or type(input_image) is torch.Tensor:
        inp = input_image.numpy()

    fig = plt.figure()
    if inp.ndim == 2:
        fig = plt.imshow(inp, cmap="gray", clim=[0, 255])
    else:
        fig = plt.imshow(np.transpose(inp, [1, 2, 0]).astype(np.uint8))
    plt.axis("off")
    fig.axes.get_xaxis().set_visible(False)
    fig.axes.get_yaxis().set_visible(False)
    plt.title(title)


###### vgg preprocessing ######
def vgg_preprocess(tensor):
    # input is RGB tensor which ranges in [0,1]
    # output is BGR tensor which ranges in [0,255]
    tensor_bgr = torch.cat((tensor[:, 2:3, :, :], tensor[:, 1:2, :, :], tensor[:, 0:1, :, :]), dim=1)
    # tensor_bgr = tensor[:, [2, 1, 0], ...]
    tensor_bgr_ml = tensor_bgr - torch.Tensor([0.40760392, 0.45795686, 0.48501961]).type_as(tensor_bgr).view(
        1, 3, 1, 1
    )
    tensor_rst = tensor_bgr_ml * 255
    return tensor_rst


def torch_vgg_preprocess(tensor):
    # pytorch version normalization
    # note that both input and output are RGB tensors;
    # input and output ranges in [0,1]
    # normalize the tensor with mean and variance
    tensor_mc = tensor - torch.Tensor([0.485, 0.456, 0.406]).type_as(tensor).view(1, 3, 1, 1)
    tensor_mc_norm = tensor_mc / torch.Tensor([0.229, 0.224, 0.225]).type_as(tensor_mc).view(1, 3, 1, 1)
    return tensor_mc_norm


def network_gradient(net, gradient_on=True):
    if gradient_on:
        for param in net.parameters():
            param.requires_grad = True
    else:
        for param in net.parameters():
            param.requires_grad = False
    return net


================================================
FILE: Global/models/NonLocal_feature_mapping_model.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import functools
from torch.autograd import Variable
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
import math


class Mapping_Model_with_mask(nn.Module):
    def __init__(self, nc, mc=64, n_blocks=3, norm="instance", padding_type="reflect", opt=None):
        super(Mapping_Model_with_mask, self).__init__()

        norm_layer = networks.get_norm_layer(norm_type=norm)
        activation = nn.ReLU(True)
        model = []

        tmp_nc = 64
        n_up = 4

        for i in range(n_up):
            ic = min(tmp_nc * (2 ** i), mc)
            oc = min(tmp_nc * (2 ** (i + 1)), mc)
            model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation]

        self.before_NL = nn.Sequential(*model)

        if opt.NL_res:
            self.NL = networks.NonLocalBlock2D_with_mask_Res(
                mc,
                mc,
                opt.NL_fusion_method,
                opt.correlation_renormalize,
                opt.softmax_temperature,
                opt.use_self,
                opt.cosin_similarity,
            )
            print("You are using NL + Res")

        model = []
        for i in range(n_blocks):
            model += [
                networks.ResnetBlock(
                    mc,
                    padding_type=padding_type,
                    activation=activation,
                    norm_layer=norm_layer,
                    opt=opt,
                    dilation=opt.mapping_net_dilation,
                )
            ]

        for i in range(n_up - 1):
            ic = min(64 * (2 ** (4 - i)), mc)
            oc = min(64 * (2 ** (3 - i)), mc)
            model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation]
        model += [nn.Conv2d(tmp_nc * 2, tmp_nc, 3, 1, 1)]
        if opt.feat_dim > 0 and opt.feat_dim < 64:
            model += [norm_layer(tmp_nc), activation, nn.Conv2d(tmp_nc, opt.feat_dim, 1, 1)]
        # model += [nn.Conv2d(64, 1, 1, 1, 0)]
        self.after_NL = nn.Sequential(*model)
        
    
    def forward(self, input, mask):
        x1 = self.before_NL(input)
        del input
        x2 = self.NL(x1, mask)
        del x1, mask
        x3 = self.after_NL(x2)
        del x2

        return x3

class Mapping_Model_with_mask_2(nn.Module): ## Multi-Scale Patch Attention
    def __init__(self, nc, mc=64, n_blocks=3, norm="instance", padding_type="reflect", opt=None):
        super(Mapping_Model_with_mask_2, self).__init__()

        norm_layer = networks.get_norm_layer(norm_type=norm)
        activation = nn.ReLU(True)
        model = []

        tmp_nc = 64
        n_up = 4

        for i in range(n_up):
            ic = min(tmp_nc * (2 ** i), mc)
            oc = min(tmp_nc * (2 ** (i + 1)), mc)
            model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation]

        for i in range(2):
            model += [
                networks.ResnetBlock(
                    mc,
                    padding_type=padding_type,
                    activation=activation,
                    norm_layer=norm_layer,
                    opt=opt,
                    dilation=opt.mapping_net_dilation,
                )
            ]

        print("Mapping: You are using multi-scale patch attention, conv combine + mask input")

        self.before_NL = nn.Sequential(*model)

        if opt.mapping_exp==1:
            self.NL_scale_1=networks.Patch_Attention_4(mc,mc,8)

        model = []
        for i in range(2):
            model += [
                networks.ResnetBlock(
                    mc,
                    padding_type=padding_type,
                    activation=activation,
                    norm_layer=norm_layer,
                    opt=opt,
                    dilation=opt.mapping_net_dilation,
                )
            ]

        self.res_block_1 = nn.Sequential(*model)

        if opt.mapping_exp==1:
            self.NL_scale_2=networks.Patch_Attention_4(mc,mc,4)

        model = []
        for i in range(2):
            model += [
                networks.ResnetBlock(
                    mc,
                    padding_type=padding_type,
                    activation=activation,
                    norm_layer=norm_layer,
                    opt=opt,
                    dilation=opt.mapping_net_dilation,
                )
            ]
        
        self.res_block_2 = nn.Sequential(*model)
        
        if opt.mapping_exp==1:
            self.NL_scale_3=networks.Patch_Attention_4(mc,mc,2)
        # self.NL_scale_3=networks.Patch_Attention_2(mc,mc,2)

        model = []
        for i in range(2):
            model += [
                networks.ResnetBlock(
                    mc,
                    padding_type=padding_type,
                    activation=activation,
                    norm_layer=norm_layer,
                    opt=opt,
                    dilation=opt.mapping_net_dilation,
                )
            ]

        for i in range(n_up - 1):
            ic = min(64 * (2 ** (4 - i)), mc)
            oc = min(64 * (2 ** (3 - i)), mc)
            model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation]
        model += [nn.Conv2d(tmp_nc * 2, tmp_nc, 3, 1, 1)]
        if opt.feat_dim > 0 and opt.feat_dim < 64:
            model += [norm_layer(tmp_nc), activation, nn.Conv2d(tmp_nc, opt.feat_dim, 1, 1)]
        # model += [nn.Conv2d(64, 1, 1, 1, 0)]
        self.after_NL = nn.Sequential(*model)
        
    
    def forward(self, input, mask):
        x1 = self.before_NL(input)
        x2 = self.NL_scale_1(x1,mask)
        x3 = self.res_block_1(x2)
        x4 = self.NL_scale_2(x3,mask)
        x5 = self.res_block_2(x4)
        x6 = self.NL_scale_3(x5,mask)
        x7 = self.after_NL(x6)
        return x7

    def inference_forward(self, input, mask):
        x1 = self.before_NL(input)
        del input
        x2 = self.NL_scale_1.inference_forward(x1,mask)
        del x1
        x3 = self.res_block_1(x2)
        del x2
        x4 = self.NL_scale_2.inference_forward(x3,mask)
        del x3
        x5 = self.res_block_2(x4)
        del x4
        x6 = self.NL_scale_3.inference_forward(x5,mask)
        del x5
        x7 = self.after_NL(x6)
        del x6
        return x7   

================================================
FILE: Global/models/__init__.py
================================================


================================================
FILE: Global/models/base_model.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import os
import torch
import sys


class BaseModel(torch.nn.Module):
    def name(self):
        return "BaseModel"

    def initialize(self, opt):
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)

    def set_input(self, input):
        self.input = input

    def forward(self):
        pass

    # used in test time, no backprop
    def test(self):
        pass

    def get_image_paths(self):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        return self.input

    def get_current_errors(self):
        return {}

    def save(self, label):
        pass

    # helper saving function that can be used by subclasses
    def save_network(self, network, network_label, epoch_label, gpu_ids):
        save_filename = "%s_net_%s.pth" % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save(network.cpu().state_dict(), save_path)
        if len(gpu_ids) and torch.cuda.is_available():
            network.cuda()

    def save_optimizer(self, optimizer, optimizer_label, epoch_label):
        save_filename = "%s_optimizer_%s.pth" % (epoch_label, optimizer_label)
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save(optimizer.state_dict(), save_path)

    def load_optimizer(self, optimizer, optimizer_label, epoch_label, save_dir=""):
        save_filename = "%s_optimizer_%s.pth" % (epoch_label, optimizer_label)
        if not save_dir:
            save_dir = self.save_dir
        save_path = os.path.join(save_dir, save_filename)

        if not os.path.isfile(save_path):
            print("%s not exists yet!" % save_path)
        else:
            optimizer.load_state_dict(torch.load(save_path))

    # helper loading function that can be used by subclasses
    def load_network(self, network, network_label, epoch_label, save_dir=""):
        save_filename = "%s_net_%s.pth" % (epoch_label, network_label)
        if not save_dir:
            save_dir = self.save_dir

        # print(save_dir)
        # print(self.save_dir)
        save_path = os.path.join(save_dir, save_filename)
        if not os.path.isfile(save_path):
            print("%s not exists yet!" % save_path)
            # if network_label == 'G':
            #     raise('Generator must exist!')
        else:
            # network.load_state_dict(torch.load(save_path))
            try:
                # print(save_path)
                network.load_state_dict(torch.load(save_path))
            except:
                pretrained_dict = torch.load(save_path)
                model_dict = network.state_dict()
                try:
                    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
                    network.load_state_dict(pretrained_dict)
                    # if self.opt.verbose:
                    print(
                        "Pretrained network %s has excessive layers; Only loading layers that are used"
                        % network_label
                    )
                except:
                    print(
                        "Pretrained network %s has fewer layers; The following are not initialized:"
                        % network_label
                    )
                    for k, v in pretrained_dict.items():
                        if v.size() == model_dict[k].size():
                            model_dict[k] = v

                    if sys.version_info >= (3, 0):
                        not_initialized = set()
                    else:
                        from sets import Set

                        not_initialized = Set()

                    for k, v in model_dict.items():
                        if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
                            not_initialized.add(k.split(".")[0])

                    print(sorted(not_initialized))
                    network.load_state_dict(model_dict)

    def update_learning_rate():
        pass


================================================
FILE: Global/models/mapping_model.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import functools
from torch.autograd import Variable
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
import math
from .NonLocal_feature_mapping_model import *


class Mapping_Model(nn.Module):
    def __init__(self, nc, mc=64, n_blocks=3, norm="instance", padding_type="reflect", opt=None):
        super(Mapping_Model, self).__init__()

        norm_layer = networks.get_norm_layer(norm_type=norm)
        activation = nn.ReLU(True)
        model = []
        tmp_nc = 64
        n_up = 4

        print("Mapping: You are using the mapping model without global restoration.")

        for i in range(n_up):
            ic = min(tmp_nc * (2 ** i), mc)
            oc = min(tmp_nc * (2 ** (i + 1)), mc)
            model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation]
        for i in range(n_blocks):
            model += [
                networks.ResnetBlock(
                    mc,
                    padding_type=padding_type,
                    activation=activation,
                    norm_layer=norm_layer,
                    opt=opt,
                    dilation=opt.mapping_net_dilation,
                )
            ]

        for i in range(n_up - 1):
            ic = min(64 * (2 ** (4 - i)), mc)
            oc = min(64 * (2 ** (3 - i)), mc)
            model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation]
        model += [nn.Conv2d(tmp_nc * 2, tmp_nc, 3, 1, 1)]
        if opt.feat_dim > 0 and opt.feat_dim < 64:
            model += [norm_layer(tmp_nc), activation, nn.Conv2d(tmp_nc, opt.feat_dim, 1, 1)]
        # model += [nn.Conv2d(64, 1, 1, 1, 0)]
        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)


class Pix2PixHDModel_Mapping(BaseModel):
    def name(self):
        return "Pix2PixHDModel_Mapping"

    def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss, use_smooth_l1, stage_1_feat_l2):
        flags = (True, True, use_gan_feat_loss, use_vgg_loss, True, True, use_smooth_l1, stage_1_feat_l2)

        def loss_filter(g_feat_l2, g_gan, g_gan_feat, g_vgg, d_real, d_fake, smooth_l1, stage_1_feat_l2):
            return [
                l
                for (l, f) in zip(
                    (g_feat_l2, g_gan, g_gan_feat, g_vgg, d_real, d_fake, smooth_l1, stage_1_feat_l2), flags
                )
                if f
            ]

        return loss_filter

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        if opt.resize_or_crop != "none" or not opt.isTrain:
            torch.backends.cudnn.benchmark = True
        self.isTrain = opt.isTrain
        input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc

        ##### define networks
        # Generator network
        netG_input_nc = input_nc
        self.netG_A = networks.GlobalGenerator_DCDCv2
Download .txt
gitextract_ojvoon3f/

├── .gitignore
├── CODE_OF_CONDUCT.md
├── Dockerfile
├── Face_Detection/
│   ├── align_warp_back_multiple_dlib.py
│   ├── align_warp_back_multiple_dlib_HR.py
│   ├── detect_all_dlib.py
│   └── detect_all_dlib_HR.py
├── Face_Enhancement/
│   ├── data/
│   │   ├── __init__.py
│   │   ├── base_dataset.py
│   │   ├── custom_dataset.py
│   │   ├── face_dataset.py
│   │   ├── image_folder.py
│   │   └── pix2pix_dataset.py
│   ├── models/
│   │   ├── __init__.py
│   │   ├── networks/
│   │   │   ├── __init__.py
│   │   │   ├── architecture.py
│   │   │   ├── base_network.py
│   │   │   ├── encoder.py
│   │   │   ├── generator.py
│   │   │   └── normalization.py
│   │   └── pix2pix_model.py
│   ├── options/
│   │   ├── __init__.py
│   │   ├── base_options.py
│   │   └── test_options.py
│   ├── requirements.txt
│   ├── test_face.py
│   └── util/
│       ├── __init__.py
│       ├── iter_counter.py
│       ├── util.py
│       └── visualizer.py
├── GUI.py
├── Global/
│   ├── data/
│   │   ├── Create_Bigfile.py
│   │   ├── Load_Bigfile.py
│   │   ├── __init__.py
│   │   ├── base_data_loader.py
│   │   ├── base_dataset.py
│   │   ├── custom_dataset_data_loader.py
│   │   ├── data_loader.py
│   │   ├── image_folder.py
│   │   └── online_dataset_for_old_photos.py
│   ├── detection.py
│   ├── detection_models/
│   │   ├── __init__.py
│   │   ├── antialiasing.py
│   │   └── networks.py
│   ├── detection_util/
│   │   └── util.py
│   ├── models/
│   │   ├── NonLocal_feature_mapping_model.py
│   │   ├── __init__.py
│   │   ├── base_model.py
│   │   ├── mapping_model.py
│   │   ├── models.py
│   │   ├── networks.py
│   │   ├── pix2pixHD_model.py
│   │   └── pix2pixHD_model_DA.py
│   ├── options/
│   │   ├── __init__.py
│   │   ├── base_options.py
│   │   ├── test_options.py
│   │   └── train_options.py
│   ├── test.py
│   ├── train_domain_A.py
│   ├── train_domain_B.py
│   ├── train_mapping.py
│   └── util/
│       ├── __init__.py
│       ├── image_pool.py
│       ├── util.py
│       └── visualizer.py
├── LICENSE
├── README.md
├── SECURITY.md
├── ansible.yaml
├── cog.yaml
├── download-weights
├── kubernetes-pod.yml
├── predict.py
├── requirements.txt
└── run.py
Download .txt
SYMBOL INDEX (447 symbols across 52 files)

FILE: Face_Detection/align_warp_back_multiple_dlib.py
  function calculate_cdf (line 26) | def calculate_cdf(histogram):
  function calculate_lookup (line 42) | def calculate_lookup(src_cdf, ref_cdf):
  function match_histograms (line 62) | def match_histograms(src_image, ref_image):
  function _standard_face_pts (line 112) | def _standard_face_pts():
  function _origin_face_pts (line 121) | def _origin_face_pts():
  function compute_transformation_matrix (line 127) | def compute_transformation_matrix(img, landmark, normalize, target_face_...
  function compute_inverse_transformation_matrix (line 148) | def compute_inverse_transformation_matrix(img, landmark, normalize, targ...
  function show_detection (line 169) | def show_detection(image, box, landmark):
  function affine2theta (line 185) | def affine2theta(affine, input_w, input_h, target_w, target_h):
  function blur_blending (line 198) | def blur_blending(im1, im2, mask):
  function blur_blending_cv2 (line 217) | def blur_blending_cv2(im1, im2, mask):
  function Poisson_blending (line 239) | def Poisson_blending(im1, im2, mask):
  function Poisson_B (line 259) | def Poisson_B(im1, im2, mask, center):
  function seamless_clone (line 270) | def seamless_clone(old_face, new_face, raw_mask):
  function get_landmark (line 309) | def get_landmark(face_landmarks, id):
  function search (line 317) | def search(face_landmarks):

FILE: Face_Detection/align_warp_back_multiple_dlib_HR.py
  function calculate_cdf (line 26) | def calculate_cdf(histogram):
  function calculate_lookup (line 42) | def calculate_lookup(src_cdf, ref_cdf):
  function match_histograms (line 62) | def match_histograms(src_image, ref_image):
  function _standard_face_pts (line 112) | def _standard_face_pts():
  function _origin_face_pts (line 121) | def _origin_face_pts():
  function compute_transformation_matrix (line 127) | def compute_transformation_matrix(img, landmark, normalize, target_face_...
  function compute_inverse_transformation_matrix (line 148) | def compute_inverse_transformation_matrix(img, landmark, normalize, targ...
  function show_detection (line 169) | def show_detection(image, box, landmark):
  function affine2theta (line 185) | def affine2theta(affine, input_w, input_h, target_w, target_h):
  function blur_blending (line 198) | def blur_blending(im1, im2, mask):
  function blur_blending_cv2 (line 217) | def blur_blending_cv2(im1, im2, mask):
  function Poisson_blending (line 239) | def Poisson_blending(im1, im2, mask):
  function Poisson_B (line 259) | def Poisson_B(im1, im2, mask, center):
  function seamless_clone (line 270) | def seamless_clone(old_face, new_face, raw_mask):
  function get_landmark (line 309) | def get_landmark(face_landmarks, id):
  function search (line 317) | def search(face_landmarks):

FILE: Face_Detection/detect_all_dlib.py
  function _standard_face_pts (line 27) | def _standard_face_pts():
  function _origin_face_pts (line 36) | def _origin_face_pts():
  function get_landmark (line 42) | def get_landmark(face_landmarks, id):
  function search (line 50) | def search(face_landmarks):
  function compute_transformation_matrix (line 80) | def compute_transformation_matrix(img, landmark, normalize, target_face_...
  function show_detection (line 101) | def show_detection(image, box, landmark):
  function affine2theta (line 117) | def affine2theta(affine, input_w, input_h, target_w, target_h):

FILE: Face_Detection/detect_all_dlib_HR.py
  function _standard_face_pts (line 27) | def _standard_face_pts():
  function _origin_face_pts (line 36) | def _origin_face_pts():
  function get_landmark (line 42) | def get_landmark(face_landmarks, id):
  function search (line 50) | def search(face_landmarks):
  function compute_transformation_matrix (line 80) | def compute_transformation_matrix(img, landmark, normalize, target_face_...
  function show_detection (line 101) | def show_detection(image, box, landmark):
  function affine2theta (line 117) | def affine2theta(affine, input_w, input_h, target_w, target_h):

FILE: Face_Enhancement/data/__init__.py
  function create_dataloader (line 10) | def create_dataloader(opt):

FILE: Face_Enhancement/data/base_dataset.py
  class BaseDataset (line 11) | class BaseDataset(data.Dataset):
    method __init__ (line 12) | def __init__(self):
    method modify_commandline_options (line 16) | def modify_commandline_options(parser, is_train):
    method initialize (line 19) | def initialize(self, opt):
  function get_params (line 23) | def get_params(opt, size):
  function get_transform (line 45) | def get_transform(opt, params, method=Image.BICUBIC, normalize=True, toT...
  function normalize (line 78) | def normalize():
  function __resize (line 82) | def __resize(img, w, h, method=Image.BICUBIC):
  function __make_power_2 (line 86) | def __make_power_2(img, base, method=Image.BICUBIC):
  function __scale_width (line 95) | def __scale_width(img, target_width, method=Image.BICUBIC):
  function __scale_shortside (line 104) | def __scale_shortside(img, target_width, method=Image.BICUBIC):
  function __crop (line 115) | def __crop(img, pos, size):
  function __flip (line 122) | def __flip(img, flip):

FILE: Face_Enhancement/data/custom_dataset.py
  class CustomDataset (line 8) | class CustomDataset(Pix2pixDataset):
    method modify_commandline_options (line 15) | def modify_commandline_options(parser, is_train):
    method get_paths (line 39) | def get_paths(self, opt):

FILE: Face_Enhancement/data/face_dataset.py
  class FaceTestDataset (line 11) | class FaceTestDataset(BaseDataset):
    method modify_commandline_options (line 13) | def modify_commandline_options(parser, is_train):
    method initialize (line 23) | def initialize(self, opt):
    method __getitem__ (line 60) | def __getitem__(self, index):
    method __len__ (line 100) | def __len__(self):

FILE: Face_Enhancement/data/image_folder.py
  function is_image_file (line 24) | def is_image_file(filename):
  function make_dataset_rec (line 28) | def make_dataset_rec(dir, images):
  function make_dataset (line 38) | def make_dataset(dir, recursive=False, read_cache=False, write_cache=Fal...
  function default_loader (line 69) | def default_loader(path):
  class ImageFolder (line 73) | class ImageFolder(data.Dataset):
    method __init__ (line 74) | def __init__(self, root, transform=None, return_paths=False, loader=de...
    method __getitem__ (line 90) | def __getitem__(self, index):
    method __len__ (line 100) | def __len__(self):

FILE: Face_Enhancement/data/pix2pix_dataset.py
  class Pix2pixDataset (line 10) | class Pix2pixDataset(BaseDataset):
    method modify_commandline_options (line 12) | def modify_commandline_options(parser, is_train):
    method initialize (line 20) | def initialize(self, opt):
    method get_paths (line 48) | def get_paths(self, opt):
    method paths_match (line 55) | def paths_match(self, path1, path2):
    method __getitem__ (line 60) | def __getitem__(self, index):
    method postprocess (line 104) | def postprocess(self, input_dict):
    method __len__ (line 107) | def __len__(self):

FILE: Face_Enhancement/models/__init__.py
  function find_model_using_name (line 8) | def find_model_using_name(model_name):
  function get_option_setter (line 34) | def get_option_setter(model_name):
  function create_model (line 39) | def create_model(opt):

FILE: Face_Enhancement/models/networks/__init__.py
  function find_network_using_name (line 11) | def find_network_using_name(target_network_name, filename):
  function modify_commandline_options (line 21) | def modify_commandline_options(parser, is_train):
  function create_network (line 35) | def create_network(cls, opt):
  function define_G (line 45) | def define_G(opt):
  function define_D (line 50) | def define_D(opt):
  function define_E (line 55) | def define_E(opt):

FILE: Face_Enhancement/models/networks/architecture.py
  class SPADEResnetBlock (line 19) | class SPADEResnetBlock(nn.Module):
    method __init__ (line 20) | def __init__(self, fin, fout, opt):
    method forward (line 49) | def forward(self, x, seg, degraded_image):
    method shortcut (line 59) | def shortcut(self, x, seg, degraded_image):
    method actvn (line 66) | def actvn(self, x):
  class ResnetBlock (line 72) | class ResnetBlock(nn.Module):
    method __init__ (line 73) | def __init__(self, dim, norm_layer, activation=nn.ReLU(False), kernel_...
    method forward (line 85) | def forward(self, x):
  class VGG19 (line 92) | class VGG19(torch.nn.Module):
    method __init__ (line 93) | def __init__(self, requires_grad=False):
    method forward (line 115) | def forward(self, X):
  class SPADEResnetBlock_non_spade (line 125) | class SPADEResnetBlock_non_spade(nn.Module):
    method __init__ (line 126) | def __init__(self, fin, fout, opt):
    method forward (line 155) | def forward(self, x, seg, degraded_image):
    method shortcut (line 165) | def shortcut(self, x, seg, degraded_image):
    method actvn (line 172) | def actvn(self, x):

FILE: Face_Enhancement/models/networks/base_network.py
  class BaseNetwork (line 8) | class BaseNetwork(nn.Module):
    method __init__ (line 9) | def __init__(self):
    method modify_commandline_options (line 13) | def modify_commandline_options(parser, is_train):
    method print_network (line 16) | def print_network(self):
    method init_weights (line 27) | def init_weights(self, init_type="normal", gain=0.02):

FILE: Face_Enhancement/models/networks/encoder.py
  class ConvEncoder (line 11) | class ConvEncoder(BaseNetwork):
    method __init__ (line 14) | def __init__(self, opt):
    method forward (line 36) | def forward(self, x):

FILE: Face_Enhancement/models/networks/generator.py
  class SPADEGenerator (line 14) | class SPADEGenerator(BaseNetwork):
    method modify_commandline_options (line 16) | def modify_commandline_options(parser, is_train):
    method __init__ (line 27) | def __init__(self, opt):
    method compute_latent_vector_size (line 90) | def compute_latent_vector_size(self, opt):
    method forward (line 105) | def forward(self, input, degraded_image, z=None):
  class Pix2PixHDGenerator (line 151) | class Pix2PixHDGenerator(BaseNetwork):
    method modify_commandline_options (line 153) | def modify_commandline_options(parser, is_train):
    method __init__ (line 172) | def __init__(self, opt):
    method forward (line 231) | def forward(self, input, degraded_image, z=None):

FILE: Face_Enhancement/models/networks/normalization.py
  function get_nonspade_norm_layer (line 12) | def get_nonspade_norm_layer(opt, norm_type="instance"):
  class SPADE (line 49) | class SPADE(nn.Module):
    method __init__ (line 50) | def __init__(self, config_text, norm_nc, label_nc, opt):
    method forward (line 81) | def forward(self, x, segmap, degraded_image):

FILE: Face_Enhancement/models/pix2pix_model.py
  class Pix2PixModel (line 9) | class Pix2PixModel(torch.nn.Module):
    method modify_commandline_options (line 11) | def modify_commandline_options(parser, is_train):
    method __init__ (line 15) | def __init__(self, opt):
    method forward (line 36) | def forward(self, data, mode):
    method create_optimizers (line 55) | def create_optimizers(self, opt):
    method save (line 73) | def save(self, epoch):
    method initialize_networks (line 83) | def initialize_networks(self, opt):
    method preprocess_input (line 101) | def preprocess_input(self, data):
    method compute_generator_loss (line 127) | def compute_generator_loss(self, input_semantics, degraded_image, real...
    method compute_discriminator_loss (line 157) | def compute_discriminator_loss(self, input_semantics, degraded_image, ...
    method encode_z (line 171) | def encode_z(self, real_image):
    method generate_fake (line 176) | def generate_fake(self, input_semantics, degraded_image, real_image, c...
    method discriminate (line 195) | def discriminate(self, input_semantics, fake_image, real_image):
    method divide_pred (line 217) | def divide_pred(self, pred):
    method get_edges (line 232) | def get_edges(self, t):
    method reparameterize (line 240) | def reparameterize(self, mu, logvar):
    method use_gpu (line 245) | def use_gpu(self):

FILE: Face_Enhancement/options/base_options.py
  class BaseOptions (line 14) | class BaseOptions:
    method __init__ (line 15) | def __init__(self):
    method initialize (line 18) | def initialize(self, parser):
    method gather_options (line 185) | def gather_options(self):
    method print_options (line 215) | def print_options(self, opt):
    method option_file_path (line 227) | def option_file_path(self, opt, makedir=False):
    method save_options (line 234) | def save_options(self, opt):
    method update_options_from_file (line 247) | def update_options_from_file(self, parser, opt):
    method load_options (line 255) | def load_options(self, opt):
    method parse (line 260) | def parse(self, save=False):

FILE: Face_Enhancement/options/test_options.py
  class TestOptions (line 7) | class TestOptions(BaseOptions):
    method initialize (line 8) | def initialize(self, parser):

FILE: Face_Enhancement/util/iter_counter.py
  class IterationCounter (line 10) | class IterationCounter:
    method __init__ (line 11) | def __init__(self, opt, dataset_size):
    method training_epochs (line 33) | def training_epochs(self):
    method record_epoch_start (line 36) | def record_epoch_start(self, epoch):
    method record_one_iteration (line 42) | def record_one_iteration(self):
    method record_epoch_end (line 52) | def record_epoch_end(self):
    method record_current_iter (line 63) | def record_current_iter(self):
    method needs_saving (line 67) | def needs_saving(self):
    method needs_printing (line 70) | def needs_printing(self):
    method needs_displaying (line 73) | def needs_displaying(self):

FILE: Face_Enhancement/util/util.py
  function save_obj (line 15) | def save_obj(obj, name):
  function load_obj (line 20) | def load_obj(name):
  function copyconf (line 25) | def copyconf(default_opt, **kwargs):
  function tensor2im (line 35) | def tensor2im(image_tensor, imtype=np.uint8, normalize=True, tile=False):
  function tensor2label (line 67) | def tensor2label(label_tensor, n_label, imtype=np.uint8, tile=False):
  function save_image (line 97) | def save_image(image_numpy, image_path, create_dir=False):
  function mkdirs (line 110) | def mkdirs(paths):
  function mkdir (line 118) | def mkdir(path):
  function atoi (line 123) | def atoi(text):
  function natural_keys (line 127) | def natural_keys(text):
  function natural_sort (line 136) | def natural_sort(items):
  function str2bool (line 140) | def str2bool(v):
  function find_class_in_module (line 149) | def find_class_in_module(target_cls_name, module):
  function save_network (line 167) | def save_network(net, label, epoch, opt):
  function load_network (line 175) | def load_network(net, label, epoch, opt):
  function uint82bin (line 190) | def uint82bin(n, count=8):
  class Colorize (line 195) | class Colorize(object):
    method __init__ (line 196) | def __init__(self, n=35):
    method __call__ (line 200) | def __call__(self, gray_image):

FILE: Face_Enhancement/util/visualizer.py
  class Visualizer (line 20) | class Visualizer:
    method __init__ (line 21) | def __init__(self, opt):
    method display_current_results (line 49) | def display_current_results(self, visuals, epoch, step):
    method plot_current_errors (line 72) | def plot_current_errors(self, errors, step):
    method print_current_errors (line 93) | def print_current_errors(self, epoch, i, errors, t):
    method convert_visuals_to_numpy (line 103) | def convert_visuals_to_numpy(self, visuals):
    method save_images (line 114) | def save_images(self, webpage, visuals, image_path):

FILE: GUI.py
  function modify (line 11) | def modify(image_filename=None, cv2_frame=None):

FILE: Global/data/Create_Bigfile.py
  function is_image_file (line 14) | def is_image_file(filename):
  function make_dataset (line 18) | def make_dataset(dir):

FILE: Global/data/Load_Bigfile.py
  class BigFileMemoryLoader (line 9) | class BigFileMemoryLoader(object):
    method __load_bigfile (line 10) | def __load_bigfile(self):
    method __init__ (line 27) | def __init__(self, file_path):
    method __getitem__ (line 32) | def __getitem__(self, index):
    method __len__ (line 41) | def __len__(self):

FILE: Global/data/base_data_loader.py
  class BaseDataLoader (line 4) | class BaseDataLoader():
    method __init__ (line 5) | def __init__(self):
    method initialize (line 8) | def initialize(self, opt):
    method load_data (line 12) | def load_data():

FILE: Global/data/base_dataset.py
  class BaseDataset (line 10) | class BaseDataset(data.Dataset):
    method __init__ (line 11) | def __init__(self):
    method name (line 14) | def name(self):
    method initialize (line 17) | def initialize(self, opt):
  function get_params (line 20) | def get_params(opt, size):
  function get_transform (line 46) | def get_transform(opt, params, method=Image.BICUBIC, normalize=True):
  function normalize (line 84) | def normalize():
  function __make_power_2 (line 87) | def __make_power_2(img, base, method=Image.BICUBIC):
  function __scale_width (line 95) | def __scale_width(img, target_width, method=Image.BICUBIC):
  function __crop (line 103) | def __crop(img, pos, size):
  function __flip (line 111) | def __flip(img, flip):

FILE: Global/data/custom_dataset_data_loader.py
  function CreateDataset (line 10) | def CreateDataset(opt):
  class CustomDatasetDataLoader (line 23) | class CustomDatasetDataLoader(BaseDataLoader):
    method name (line 24) | def name(self):
    method initialize (line 27) | def initialize(self, opt):
    method load_data (line 37) | def load_data(self):
    method __len__ (line 40) | def __len__(self):

FILE: Global/data/data_loader.py
  function CreateDataLoader (line 4) | def CreateDataLoader(opt):

FILE: Global/data/image_folder.py
  function is_image_file (line 14) | def is_image_file(filename):
  function make_dataset (line 18) | def make_dataset(dir):
  function default_loader (line 31) | def default_loader(path):
  class ImageFolder (line 35) | class ImageFolder(data.Dataset):
    method __init__ (line 37) | def __init__(self, root, transform=None, return_paths=False,
    method __getitem__ (line 51) | def __getitem__(self, index):
    method __len__ (line 61) | def __len__(self):

FILE: Global/data/online_dataset_for_old_photos.py
  function pil_to_np (line 17) | def pil_to_np(img_PIL):
  function np_to_pil (line 32) | def np_to_pil(img_np):
  function synthesize_salt_pepper (line 46) | def synthesize_salt_pepper(image,amount,salt_vs_pepper):
  function synthesize_gaussian (line 67) | def synthesize_gaussian(image,std_l,std_r):
  function synthesize_speckle (line 81) | def synthesize_speckle(image,std_l,std_r):
  function synthesize_low_resolution (line 96) | def synthesize_low_resolution(img):
  function convertToJpeg (line 112) | def convertToJpeg(im,quality):
  function blur_image_v2 (line 119) | def blur_image_v2(img):
  function online_add_degradation_v2 (line 132) | def online_add_degradation_v2(img):
  function irregular_hole_synthesize (line 156) | def irregular_hole_synthesize(img,mask):
  function zero_mask (line 168) | def zero_mask(size):
  class UnPairOldPhotos_SR (line 175) | class UnPairOldPhotos_SR(BaseDataset):  ## Synthetic + Real Old
    method initialize (line 176) | def initialize(self, opt):
    method __getitem__ (line 213) | def __getitem__(self, index):
    method __len__ (line 278) | def __len__(self):
    method name (line 281) | def name(self):
  class PairOldPhotos (line 285) | class PairOldPhotos(BaseDataset):
    method initialize (line 286) | def initialize(self, opt):
    method __getitem__ (line 313) | def __getitem__(self, index):
    method __len__ (line 370) | def __len__(self):
    method name (line 377) | def name(self):
  class PairOldPhotos_with_hole (line 381) | class PairOldPhotos_with_hole(BaseDataset):
    method initialize (line 382) | def initialize(self, opt):
    method __getitem__ (line 411) | def __getitem__(self, index):
    method __len__ (line 476) | def __len__(self):
    method name (line 484) | def name(self):

FILE: Global/detection.py
  function data_transforms (line 25) | def data_transforms(img, full_size, method=Image.BICUBIC):
  function scale_tensor (line 51) | def scale_tensor(img_tensor, default_scale=256):
  function blend_mask (line 66) | def blend_mask(img, mask):
  function main (line 73) | def main(config):

FILE: Global/detection_models/antialiasing.py
  class Downsample (line 11) | class Downsample(nn.Module):
    method __init__ (line 14) | def __init__(self, pad_type="reflect", filt_size=3, stride=2, channels...
    method forward (line 51) | def forward(self, inp):
  function get_pad_layer (line 61) | def get_pad_layer(pad_type):

FILE: Global/detection_models/networks.py
  class UNet (line 11) | class UNet(nn.Module):
    method __init__ (line 12) | def __init__(
    method forward (line 109) | def forward(self, x):
  class UNetConvBlock (line 124) | class UNetConvBlock(nn.Module):
    method __init__ (line 125) | def __init__(self, conv_num, in_size, out_size, padding, batch_norm):
    method forward (line 139) | def forward(self, x):
  class UNetUpBlock (line 144) | class UNetUpBlock(nn.Module):
    method __init__ (line 145) | def __init__(self, conv_num, in_size, out_size, up_mode, padding, batc...
    method center_crop (line 158) | def center_crop(self, layer, target_size):
    method forward (line 164) | def forward(self, x, bridge):
  class UnetGenerator (line 173) | class UnetGenerator(nn.Module):
    method __init__ (line 176) | def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_type="...
    method forward (line 223) | def forward(self, input):
  class UnetSkipConnectionBlock (line 227) | class UnetSkipConnectionBlock(nn.Module):
    method __init__ (line 234) | def __init__(
    method forward (line 291) | def forward(self, x):

FILE: Global/detection_util/util.py
  function print_options (line 25) | def print_options(config_dict):
  function save_options (line 32) | def save_options(config_dict):
  function config_parse (line 46) | def config_parse(config_file, options, save=True):
  function to_np (line 70) | def to_np(x):
  function prepare_device (line 74) | def prepare_device(use_gpu, gpu_ids):
  function get_dir_size (line 94) | def get_dir_size(start_path="."):
  function mkdir_if_not (line 103) | def mkdir_if_not(dir_path):
  class Timer (line 109) | class Timer:
    method __init__ (line 110) | def __init__(self, msg):
    method __enter__ (line 114) | def __enter__(self):
    method __exit__ (line 117) | def __exit__(self, exc_type, exc_value, exc_tb):
  function get_size (line 123) | def get_size(start_path="."):
  function clean_tensorboard (line 132) | def clean_tensorboard(directory):
  function prepare_tensorboard (line 146) | def prepare_tensorboard(config, experiment_name=datetime.now().strftime(...
  function tb_loss_logger (line 159) | def tb_loss_logger(tb_writer, iter_index, loss_logger):
  function tb_image_logger (line 164) | def tb_image_logger(tb_writer, iter_index, images_info, config):
  function tb_image_logger_test (line 179) | def tb_image_logger_test(epoch, iter, images_info, config):
  function imshow (line 199) | def imshow(input_image, title=None, to_numpy=False):
  function vgg_preprocess (line 216) | def vgg_preprocess(tensor):
  function torch_vgg_preprocess (line 228) | def torch_vgg_preprocess(tensor):
  function network_gradient (line 238) | def network_gradient(net, gradient_on=True):

FILE: Global/models/NonLocal_feature_mapping_model.py
  class Mapping_Model_with_mask (line 17) | class Mapping_Model_with_mask(nn.Module):
    method __init__ (line 18) | def __init__(self, nc, mc=64, n_blocks=3, norm="instance", padding_typ...
    method forward (line 71) | def forward(self, input, mask):
  class Mapping_Model_with_mask_2 (line 81) | class Mapping_Model_with_mask_2(nn.Module): ## Multi-Scale Patch Attention
    method __init__ (line 82) | def __init__(self, nc, mc=64, n_blocks=3, norm="instance", padding_typ...
    method forward (line 177) | def forward(self, input, mask):
    method inference_forward (line 187) | def inference_forward(self, input, mask):

FILE: Global/models/base_model.py
  class BaseModel (line 9) | class BaseModel(torch.nn.Module):
    method name (line 10) | def name(self):
    method initialize (line 13) | def initialize(self, opt):
    method set_input (line 20) | def set_input(self, input):
    method forward (line 23) | def forward(self):
    method test (line 27) | def test(self):
    method get_image_paths (line 30) | def get_image_paths(self):
    method optimize_parameters (line 33) | def optimize_parameters(self):
    method get_current_visuals (line 36) | def get_current_visuals(self):
    method get_current_errors (line 39) | def get_current_errors(self):
    method save (line 42) | def save(self, label):
    method save_network (line 46) | def save_network(self, network, network_label, epoch_label, gpu_ids):
    method save_optimizer (line 53) | def save_optimizer(self, optimizer, optimizer_label, epoch_label):
    method load_optimizer (line 58) | def load_optimizer(self, optimizer, optimizer_label, epoch_label, save...
    method load_network (line 70) | def load_network(self, network, network_label, epoch_label, save_dir=""):
    method update_learning_rate (line 121) | def update_learning_rate():

FILE: Global/models/mapping_model.py
  class Mapping_Model (line 18) | class Mapping_Model(nn.Module):
    method __init__ (line 19) | def __init__(self, nc, mc=64, n_blocks=3, norm="instance", padding_typ...
    method forward (line 56) | def forward(self, input):
  class Pix2PixHDModel_Mapping (line 60) | class Pix2PixHDModel_Mapping(BaseModel):
    method name (line 61) | def name(self):
    method init_loss_filter (line 64) | def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss, use_smooth...
    method initialize (line 78) | def initialize(self, opt):
    method encode_input (line 215) | def encode_input(self, label_map, inst_map=None, real_image=None, feat...
    method discriminate (line 240) | def discriminate(self, input_label, test_image, use_pool=False):
    method forward (line 248) | def forward(self, label, inst, image, feat, pair=True, infer=False, la...
    method inference (line 325) | def inference(self, label, inst):
  class InferenceModel (line 349) | class InferenceModel(Pix2PixHDModel_Mapping):
    method forward (line 350) | def forward(self, label, inst):

FILE: Global/models/models.py
  function create_model (line 7) | def create_model(opt):
  function create_da_model (line 29) | def create_da_model(opt):

FILE: Global/models/networks.py
  function weights_init (line 17) | def weights_init(m):
  function get_norm_layer (line 26) | def get_norm_layer(norm_type="instance"):
  function print_network (line 40) | def print_network(net):
  function define_G (line 50) | def define_G(input_nc, output_nc, ngf, netG, k_size=3, n_downsample_glob...
  function define_D (line 70) | def define_D(input_nc, ndf, n_layers_D, opt, norm='instance', use_sigmoi...
  class GlobalGenerator_DCDCv2 (line 82) | class GlobalGenerator_DCDCv2(nn.Module):
    method __init__ (line 83) | def __init__(
    method forward (line 283) | def forward(self, input, flow="enc_dec"):
  class ResnetBlock (line 295) | class ResnetBlock(nn.Module):
    method __init__ (line 296) | def __init__(
    method build_conv_block (line 304) | def build_conv_block(self, dim, padding_type, norm_layer, activation, ...
    method forward (line 337) | def forward(self, x):
  class Encoder (line 342) | class Encoder(nn.Module):
    method __init__ (line 343) | def __init__(self, input_nc, output_nc, ngf=32, n_downsampling=4, norm...
    method forward (line 376) | def forward(self, input, inst):
  function SN (line 394) | def SN(module, mode=True):
  class NonLocalBlock2D_with_mask_Res (line 401) | class NonLocalBlock2D_with_mask_Res(nn.Module):
    method __init__ (line 402) | def __init__(
    method forward (line 460) | def forward(self, x, mask):  ## The shape of mask is Batch*1*H*W
  class MultiscaleDiscriminator (line 526) | class MultiscaleDiscriminator(nn.Module):
    method __init__ (line 527) | def __init__(self, input_nc, opt, ndf=64, n_layers=3, norm_layer=nn.Ba...
    method singleD_forward (line 544) | def singleD_forward(self, model, input):
    method forward (line 553) | def forward(self, input):
  class NLayerDiscriminator (line 568) | class NLayerDiscriminator(nn.Module):
    method __init__ (line 569) | def __init__(self, input_nc, opt, ndf=64, n_layers=3, norm_layer=nn.Ba...
    method forward (line 609) | def forward(self, input):
  class Patch_Attention_4 (line 621) | class Patch_Attention_4(nn.Module):  ## While combine the feature map, u...
    method __init__ (line 622) | def __init__(self, in_channels, inter_channels, patch_size):
    method Hard_Compose (line 666) | def Hard_Compose(self, input, dim, index):
    method forward (line 678) | def forward(self, z, mask):  ## The shape of mask is Batch*1*H*W
    method inference_forward (line 720) | def inference_forward(self,z,mask): ## Reduce the extra memory cost
  class GANLoss (line 781) | class GANLoss(nn.Module):
    method __init__ (line 782) | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_...
    method get_target_tensor (line 795) | def get_target_tensor(self, input, target_is_real):
    method __call__ (line 813) | def __call__(self, input, target_is_real):
  class VGG19_torch (line 831) | class VGG19_torch(torch.nn.Module):
    method __init__ (line 832) | def __init__(self, requires_grad=False):
    method forward (line 854) | def forward(self, X):
  class VGGLoss_torch (line 863) | class VGGLoss_torch(nn.Module):
    method __init__ (line 864) | def __init__(self, gpu_ids):
    method forward (line 870) | def forward(self, x, y):

FILE: Global/models/pix2pixHD_model.py
  class Pix2PixHDModel (line 12) | class Pix2PixHDModel(BaseModel):
    method name (line 13) | def name(self):
    method init_loss_filter (line 16) | def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss,use_smooth_...
    method initialize (line 22) | def initialize(self, opt):
    method encode_input (line 112) | def encode_input(self, label_map, inst_map=None, real_image=None, feat...
    method discriminate (line 145) | def discriminate(self, input_label, test_image, use_pool=False):
    method forward (line 156) | def forward(self, label, inst, image, feat, infer=False):
    method inference (line 221) | def inference(self, label, inst, image=None, feat=None):
    method sample_features (line 245) | def sample_features(self, inst):
    method encode_features (line 266) | def encode_features(self, image, inst):
    method get_edges (line 288) | def get_edges(self, t):
    method save (line 299) | def save(self, which_epoch):
    method update_fixed_params (line 309) | def update_fixed_params(self):
    method update_learning_rate (line 318) | def update_learning_rate(self):
  class InferenceModel (line 330) | class InferenceModel(Pix2PixHDModel):
    method forward (line 331) | def forward(self, inp):

FILE: Global/models/pix2pixHD_model_DA.py
  class Pix2PixHDModel (line 13) | class Pix2PixHDModel(BaseModel):
    method name (line 14) | def name(self):
    method init_loss_filter (line 17) | def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss):
    method initialize (line 25) | def initialize(self, opt):
    method encode_input (line 118) | def encode_input(self, label_map, inst_map=None, real_image=None, feat...
    method discriminate (line 151) | def discriminate(self, input_label, test_image, use_pool=False):
    method feat_discriminate (line 162) | def feat_discriminate(self,input):
    method forward (line 167) | def forward(self, label, inst, image, feat, infer=False):
    method inference (line 256) | def inference(self, label, inst, image=None, feat=None):
    method sample_features (line 280) | def sample_features(self, inst):
    method encode_features (line 301) | def encode_features(self, image, inst):
    method get_edges (line 323) | def get_edges(self, t):
    method save (line 334) | def save(self, which_epoch):
    method update_fixed_params (line 346) | def update_fixed_params(self):
    method update_learning_rate (line 355) | def update_learning_rate(self):
  class InferenceModel (line 369) | class InferenceModel(Pix2PixHDModel):
    method forward (line 370) | def forward(self, inp):

FILE: Global/options/base_options.py
  class BaseOptions (line 10) | class BaseOptions:
    method __init__ (line 11) | def __init__(self):
    method initialize (line 15) | def initialize(self):
    method parse (line 338) | def parse(self, save=True):

FILE: Global/options/test_options.py
  class TestOptions (line 7) | class TestOptions(BaseOptions):
    method initialize (line 8) | def initialize(self):

FILE: Global/options/train_options.py
  class TrainOptions (line 6) | class TrainOptions(BaseOptions):
    method initialize (line 7) | def initialize(self):

FILE: Global/test.py
  function data_transforms (line 18) | def data_transforms(img, method=Image.BILINEAR, scale=False):
  function data_transforms_rgb_old (line 39) | def data_transforms_rgb_old(img):
  function irregular_hole_synthesize (line 47) | def irregular_hole_synthesize(img, mask):
  function parameter_set (line 59) | def parameter_set(opt):

FILE: Global/util/image_pool.py
  class ImagePool (line 9) | class ImagePool:
    method __init__ (line 10) | def __init__(self, pool_size):
    method query (line 16) | def query(self, images):

FILE: Global/util/util.py
  function tensor2im (line 14) | def tensor2im(image_tensor, imtype=np.uint8, normalize=True):
  function tensor2label (line 32) | def tensor2label(label_tensor, n_label, imtype=np.uint8):
  function save_image (line 43) | def save_image(image_numpy, image_path):
  function mkdirs (line 48) | def mkdirs(paths):
  function mkdir (line 56) | def mkdir(path):

FILE: Global/util/visualizer.py
  class Visualizer (line 16) | class Visualizer():
    method __init__ (line 17) | def __init__(self, opt):
    method display_current_results (line 40) | def display_current_results(self, visuals, epoch, step):
    method plot_current_errors (line 98) | def plot_current_errors(self, errors, step):
    method print_current_errors (line 105) | def print_current_errors(self, epoch, i, errors, t, lr):
    method print_save (line 116) | def print_save(self,message):
    method save_images (line 125) | def save_images(self, webpage, visuals, image_path):

FILE: predict.py
  class Predictor (line 12) | class Predictor(cog.Predictor):
    method setup (line 13) | def setup(self):
    method predict (line 51) | def predict(self, image, HR=False, with_scratch=False):
  function clean_folder (line 213) | def clean_folder(folder):

FILE: run.py
  function run_cmd (line 10) | def run_cmd(command):
Condensed preview — 75 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (379K chars).
[
  {
    "path": ".gitignore",
    "chars": 23,
    "preview": "__pycache__/\n*.pyc\n*~\n\n"
  },
  {
    "path": "CODE_OF_CONDUCT.md",
    "chars": 444,
    "preview": "# Microsoft Open Source Code of Conduct\n\nThis project has adopted the [Microsoft Open Source Code of Conduct](https://op"
  },
  {
    "path": "Dockerfile",
    "chars": 1358,
    "preview": "FROM nvidia/cuda:11.1-base-ubuntu20.04\n\nRUN apt update && DEBIAN_FRONTEND=noninteractive apt install git bzip2 wget unzi"
  },
  {
    "path": "Face_Detection/align_warp_back_multiple_dlib.py",
    "chars": 13487,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch\nimport numpy as np\nimport skimage"
  },
  {
    "path": "Face_Detection/align_warp_back_multiple_dlib_HR.py",
    "chars": 13487,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch\nimport numpy as np\nimport skimage"
  },
  {
    "path": "Face_Detection/detect_all_dlib.py",
    "chars": 5291,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch\nimport numpy as np\nimport skimage"
  },
  {
    "path": "Face_Detection/detect_all_dlib_HR.py",
    "chars": 5291,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch\nimport numpy as np\nimport skimage"
  },
  {
    "path": "Face_Enhancement/data/__init__.py",
    "chars": 624,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport importlib\nimport torch.utils.data\nfrom "
  },
  {
    "path": "Face_Enhancement/data/base_dataset.py",
    "chars": 3892,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch.utils.data as data\nfrom PIL impor"
  },
  {
    "path": "Face_Enhancement/data/custom_dataset.py",
    "chars": 2149,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom data.pix2pix_dataset import Pix2pixDatase"
  },
  {
    "path": "Face_Enhancement/data/face_dataset.py",
    "chars": 3029,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom data.base_dataset import BaseDataset, get"
  },
  {
    "path": "Face_Enhancement/data/image_folder.py",
    "chars": 2717,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch.utils.data as data\nfrom PIL impor"
  },
  {
    "path": "Face_Enhancement/data/pix2pix_dataset.py",
    "chars": 3911,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom data.base_dataset import BaseDataset, get"
  },
  {
    "path": "Face_Enhancement/models/__init__.py",
    "chars": 1336,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport importlib\nimport torch\n\n\ndef find_model"
  },
  {
    "path": "Face_Enhancement/models/networks/__init__.py",
    "chars": 1773,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch\nfrom models.networks.base_network"
  },
  {
    "path": "Face_Enhancement/models/networks/architecture.py",
    "chars": 6291,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch\nimport torch.nn as nn\nimport torc"
  },
  {
    "path": "Face_Enhancement/models/networks/base_network.py",
    "chars": 2370,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch.nn as nn\nfrom torch.nn import ini"
  },
  {
    "path": "Face_Enhancement/models/networks/encoder.py",
    "chars": 1873,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch.nn as nn\nimport numpy as np\nimpor"
  },
  {
    "path": "Face_Enhancement/models/networks/generator.py",
    "chars": 8343,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch\nimport torch.nn as nn\nimport torc"
  },
  {
    "path": "Face_Enhancement/models/networks/normalization.py",
    "chars": 3848,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport re\nimport torch\nimport torch.nn as nn\ni"
  },
  {
    "path": "Face_Enhancement/models/pix2pix_model.py",
    "chars": 9786,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch\nimport models.networks as network"
  },
  {
    "path": "Face_Enhancement/options/__init__.py",
    "chars": 73,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n"
  },
  {
    "path": "Face_Enhancement/options/base_options.py",
    "chars": 10935,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport sys\nimport argparse\nimport os\nfrom util"
  },
  {
    "path": "Face_Enhancement/options/test_options.py",
    "chars": 968,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom .base_options import BaseOptions\n\n\nclass "
  },
  {
    "path": "Face_Enhancement/requirements.txt",
    "chars": 97,
    "preview": "torch>=1.0.0\ntorchvision\ndominate>=2.3.1\nwandb\ndill\nscikit-image\ntensorboardX\nscipy\nopencv-python"
  },
  {
    "path": "Face_Enhancement/test_face.py",
    "chars": 1077,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport os\nfrom collections import OrderedDict\n"
  },
  {
    "path": "Face_Enhancement/util/__init__.py",
    "chars": 73,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n"
  },
  {
    "path": "Face_Enhancement/util/iter_counter.py",
    "chars": 3009,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport os\nimport time\nimport numpy as np\n\n\n# H"
  },
  {
    "path": "Face_Enhancement/util/util.py",
    "chars": 6590,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport re\nimport importlib\nimport torch\nfrom a"
  },
  {
    "path": "Face_Enhancement/util/visualizer.py",
    "chars": 4786,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport os\nimport ntpath\nimport time\nfrom . imp"
  },
  {
    "path": "GUI.py",
    "chars": 7241,
    "preview": "import numpy as np\nimport cv2\nimport PySimpleGUI as sg\nimport os.path\nimport argparse\nimport os\nimport sys\nimport shutil"
  },
  {
    "path": "Global/data/Create_Bigfile.py",
    "chars": 1952,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport os\nimport struct\nfrom PIL import Image\n"
  },
  {
    "path": "Global/data/Load_Bigfile.py",
    "chars": 1590,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport io\nimport os\nimport struct\nfrom PIL imp"
  },
  {
    "path": "Global/data/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "Global/data/base_data_loader.py",
    "chars": 268,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nclass BaseDataLoader():\n    def __init__(self)"
  },
  {
    "path": "Global/data/base_dataset.py",
    "chars": 3596,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch.utils.data as data\nfrom PIL impor"
  },
  {
    "path": "Global/data/custom_dataset_data_loader.py",
    "chars": 1316,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch.utils.data\nimport random\nfrom dat"
  },
  {
    "path": "Global/data/data_loader.py",
    "chars": 302,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\ndef CreateDataLoader(opt):\n    from data.custo"
  },
  {
    "path": "Global/data/image_folder.py",
    "chars": 1643,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch.utils.data as data\nfrom PIL impor"
  },
  {
    "path": "Global/data/online_dataset_for_old_photos.py",
    "chars": 15313,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport os.path\nimport io\nimport zipfile\nfrom d"
  },
  {
    "path": "Global/detection.py",
    "chars": 5032,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport argparse\nimport gc\nimport json\nimport o"
  },
  {
    "path": "Global/detection_models/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "Global/detection_models/antialiasing.py",
    "chars": 2451,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch\nimport torch.nn.parallel\nimport n"
  },
  {
    "path": "Global/detection_models/networks.py",
    "chars": 11753,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch\nimport torch.nn as nn\nimport torc"
  },
  {
    "path": "Global/detection_util/util.py",
    "chars": 7981,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport os\nimport sys\nimport time\nimport shutil"
  },
  {
    "path": "Global/models/NonLocal_feature_mapping_model.py",
    "chars": 6414,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport numpy as np\nimport torch\nimport torch.n"
  },
  {
    "path": "Global/models/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "Global/models/base_model.py",
    "chars": 4281,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport os\nimport torch\nimport sys\n\n\nclass Base"
  },
  {
    "path": "Global/models/mapping_model.py",
    "chars": 13758,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport numpy as np\nimport torch\nimport torch.n"
  },
  {
    "path": "Global/models/models.py",
    "chars": 1212,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch\n\n\ndef create_model(opt):\n    if o"
  },
  {
    "path": "Global/models/networks.py",
    "chars": 30705,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch\nimport torch.nn as nn\nimport func"
  },
  {
    "path": "Global/models/pix2pixHD_model.py",
    "chars": 15127,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport numpy as np\nimport torch\nimport os\nfrom"
  },
  {
    "path": "Global/models/pix2pixHD_model_DA.py",
    "chars": 16499,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport numpy as np\nimport torch\nimport os\nfrom"
  },
  {
    "path": "Global/options/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "Global/options/base_options.py",
    "chars": 16231,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport argparse\nimport os\nfrom util import uti"
  },
  {
    "path": "Global/options/test_options.py",
    "chars": 4605,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom .base_options import BaseOptions\n\n\nclass "
  },
  {
    "path": "Global/options/train_options.py",
    "chars": 5199,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom .base_options import BaseOptions\n\nclass T"
  },
  {
    "path": "Global/test.py",
    "chars": 6028,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport os\nfrom collections import OrderedDict\n"
  },
  {
    "path": "Global/train_domain_A.py",
    "chars": 5403,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport time\nfrom collections import OrderedDic"
  },
  {
    "path": "Global/train_domain_B.py",
    "chars": 5193,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport time\nfrom collections import OrderedDic"
  },
  {
    "path": "Global/train_mapping.py",
    "chars": 5954,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport time\nfrom collections import OrderedDic"
  },
  {
    "path": "Global/util/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "Global/util/image_pool.py",
    "chars": 1166,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport random\nimport torch\nfrom torch.autograd"
  },
  {
    "path": "Global/util/util.py",
    "chars": 1845,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom __future__ import print_function\nimport t"
  },
  {
    "path": "Global/util/visualizer.py",
    "chars": 5790,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport numpy as np\nimport os\nimport ntpath\nimp"
  },
  {
    "path": "LICENSE",
    "chars": 1141,
    "preview": "    MIT License\n\n    Copyright (c) Microsoft Corporation.\n\n    Permission is hereby granted, free of charge, to any pers"
  },
  {
    "path": "README.md",
    "chars": 12303,
    "preview": "# Old Photo Restoration (Official PyTorch Implementation)\n\n<img src='imgs/0001.jpg'/>\n\n### [Project Page](http://raywzy."
  },
  {
    "path": "SECURITY.md",
    "chars": 2780,
    "preview": "<!-- BEGIN MICROSOFT SECURITY.MD V0.0.5 BLOCK -->\n\n## Security\n\nMicrosoft takes the security of our software products an"
  },
  {
    "path": "ansible.yaml",
    "chars": 3506,
    "preview": "---\r\n- name: Bringing-Old-Photos-Back-to-Life\r\n  hosts: all\r\n  gather_facts: no\r\n\r\n# Succesfully tested on Ubuntu 18.04\\"
  },
  {
    "path": "cog.yaml",
    "chars": 551,
    "preview": "build:\n  gpu: true\n  python_version: \"3.8\"\n  system_packages:\n    - \"libgl1-mesa-glx\"\n    - \"libglib2.0-0\"\n  python_pack"
  },
  {
    "path": "download-weights",
    "chars": 847,
    "preview": "#!/bin/sh\n\ncd Face_Enhancement/models/networks\ngit clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch\ncp -r"
  },
  {
    "path": "kubernetes-pod.yml",
    "chars": 727,
    "preview": "apiVersion: v1\nkind: Pod\nmetadata:\n  name: photo-back2life\nspec:\n  containers:\n    - name: photos-back2life\n      image:"
  },
  {
    "path": "predict.py",
    "chars": 8086,
    "preview": "import tempfile\nfrom pathlib import Path\nimport argparse\nimport shutil\nimport os\nimport glob\nimport cv2\nimport cog\nfrom "
  },
  {
    "path": "requirements.txt",
    "chars": 136,
    "preview": "torch\ntorchvision\ndlib\nscikit-image\neasydict\nPyYAML\ndominate>=2.3.1\ndill\ntensorboardX\nscipy\nopencv-python\neinops\nPySimpl"
  },
  {
    "path": "run.py",
    "chars": 6775,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport os\nimport argparse\nimport shutil\nimport"
  }
]

About this extraction

This page contains the full source code of the microsoft/Bringing-Old-Photos-Back-to-Life GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 75 files (353.2 KB), approximately 90.5k tokens, and a symbol index with 447 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!