Showing preview only (374K chars total). Download the full file or copy to clipboard to get everything.
Repository: microsoft/Bringing-Old-Photos-Back-to-Life
Branch: master
Commit: 33875eccf4eb
Files: 75
Total size: 353.2 KB
Directory structure:
gitextract_ojvoon3f/
├── .gitignore
├── CODE_OF_CONDUCT.md
├── Dockerfile
├── Face_Detection/
│ ├── align_warp_back_multiple_dlib.py
│ ├── align_warp_back_multiple_dlib_HR.py
│ ├── detect_all_dlib.py
│ └── detect_all_dlib_HR.py
├── Face_Enhancement/
│ ├── data/
│ │ ├── __init__.py
│ │ ├── base_dataset.py
│ │ ├── custom_dataset.py
│ │ ├── face_dataset.py
│ │ ├── image_folder.py
│ │ └── pix2pix_dataset.py
│ ├── models/
│ │ ├── __init__.py
│ │ ├── networks/
│ │ │ ├── __init__.py
│ │ │ ├── architecture.py
│ │ │ ├── base_network.py
│ │ │ ├── encoder.py
│ │ │ ├── generator.py
│ │ │ └── normalization.py
│ │ └── pix2pix_model.py
│ ├── options/
│ │ ├── __init__.py
│ │ ├── base_options.py
│ │ └── test_options.py
│ ├── requirements.txt
│ ├── test_face.py
│ └── util/
│ ├── __init__.py
│ ├── iter_counter.py
│ ├── util.py
│ └── visualizer.py
├── GUI.py
├── Global/
│ ├── data/
│ │ ├── Create_Bigfile.py
│ │ ├── Load_Bigfile.py
│ │ ├── __init__.py
│ │ ├── base_data_loader.py
│ │ ├── base_dataset.py
│ │ ├── custom_dataset_data_loader.py
│ │ ├── data_loader.py
│ │ ├── image_folder.py
│ │ └── online_dataset_for_old_photos.py
│ ├── detection.py
│ ├── detection_models/
│ │ ├── __init__.py
│ │ ├── antialiasing.py
│ │ └── networks.py
│ ├── detection_util/
│ │ └── util.py
│ ├── models/
│ │ ├── NonLocal_feature_mapping_model.py
│ │ ├── __init__.py
│ │ ├── base_model.py
│ │ ├── mapping_model.py
│ │ ├── models.py
│ │ ├── networks.py
│ │ ├── pix2pixHD_model.py
│ │ └── pix2pixHD_model_DA.py
│ ├── options/
│ │ ├── __init__.py
│ │ ├── base_options.py
│ │ ├── test_options.py
│ │ └── train_options.py
│ ├── test.py
│ ├── train_domain_A.py
│ ├── train_domain_B.py
│ ├── train_mapping.py
│ └── util/
│ ├── __init__.py
│ ├── image_pool.py
│ ├── util.py
│ └── visualizer.py
├── LICENSE
├── README.md
├── SECURITY.md
├── ansible.yaml
├── cog.yaml
├── download-weights
├── kubernetes-pod.yml
├── predict.py
├── requirements.txt
└── run.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
__pycache__/
*.pyc
*~
================================================
FILE: CODE_OF_CONDUCT.md
================================================
# Microsoft Open Source Code of Conduct
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
Resources:
- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
================================================
FILE: Dockerfile
================================================
FROM nvidia/cuda:11.1-base-ubuntu20.04

# System packages for building dlib and running OpenCV/GTK. apt-get is the
# script-stable interface; the package lists are removed to keep the layer small.
RUN apt-get update && \
    DEBIAN_FRONTEND=noninteractive apt-get install -yq git bzip2 wget unzip python3-pip python3-dev cmake libgl1-mesa-dev python-is-python3 libgtk2.0-dev && \
    rm -rf /var/lib/apt/lists/*

ADD . /app
WORKDIR /app

# Every RUN starts a fresh shell in WORKDIR (/app), so no trailing `cd` back is needed.
RUN cd Face_Enhancement/models/networks/ &&\
    git clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch &&\
    cp -rf Synchronized-BatchNorm-PyTorch/sync_batchnorm .

RUN cd Global/detection_models &&\
    git clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch &&\
    cp -rf Synchronized-BatchNorm-PyTorch/sync_batchnorm .

# dlib 68-point landmark model, loaded by the Face_Detection scripts from their working directory.
RUN cd Face_Detection/ &&\
    wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 &&\
    bzip2 -d shape_predictor_68_face_landmarks.dat.bz2

# Pretrained checkpoints; both archives are deleted after extraction.
RUN cd Face_Enhancement/ &&\
    wget https://facevc.blob.core.windows.net/zhanbo/old_photo/pretrain/Face_Enhancement/checkpoints.zip &&\
    unzip checkpoints.zip &&\
    rm -f checkpoints.zip &&\
    cd ../Global/ &&\
    wget https://facevc.blob.core.windows.net/zhanbo/old_photo/pretrain/Global/checkpoints.zip &&\
    unzip checkpoints.zip &&\
    rm -f checkpoints.zip

RUN pip3 install numpy
RUN pip3 install dlib
RUN pip3 install -r requirements.txt

RUN git clone https://github.com/NVlabs/SPADE.git
RUN cd SPADE/ && pip3 install -r requirements.txt

CMD ["python3", "run.py"]
================================================
FILE: Face_Detection/align_warp_back_multiple_dlib.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import numpy as np
import skimage.io as io
# from face_sdk import FaceDetection
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from skimage.transform import SimilarityTransform
from skimage.transform import warp
from PIL import Image, ImageFilter
import torch.nn.functional as F
import torchvision as tv
import torchvision.utils as vutils
import time
import cv2
import os
from skimage import img_as_ubyte
import json
import argparse
import dlib
def calculate_cdf(histogram):
    """
    Turn a histogram into a cumulative distribution scaled by its maximum.

    :param array histogram: per-bin counts (a NumPy array; its ``cumsum`` method is used)
    :return: normalized_cdf: running sum of ``histogram`` divided by the largest running-sum value
    :rtype: array
    """
    running_total = histogram.cumsum()
    peak = float(running_total.max())
    return running_total / peak
def calculate_lookup(src_cdf, ref_cdf):
    """
    Build a 256-entry table mapping each source intensity to a reference intensity.

    For each source level ``v`` the table holds the smallest reference level whose
    CDF value is >= ``src_cdf[v]``. If no reference level qualifies, the mapping of
    the previous source level is reused (0 for the first level).

    :param array src_cdf: The non-decreasing cdf for the source image
    :param array ref_cdf: The non-decreasing cdf for the reference image
    :return: lookup_table: float array of length 256 holding reference levels
    :rtype: array
    """
    src_cdf = np.asarray(src_cdf)
    ref_cdf = np.asarray(ref_cdf)
    # A CDF is non-decreasing, so a left-sided binary search yields the first
    # reference level with ref_cdf >= value. This replaces the nested O(256^2) scan.
    positions = np.searchsorted(ref_cdf, src_cdf, side="left")
    lookup_table = np.zeros(256)
    lookup_val = 0
    for src_pixel_val, pos in enumerate(positions):
        # pos == len(ref_cdf) means "not found": carry the previous mapping forward.
        if pos < len(ref_cdf):
            lookup_val = pos
        lookup_table[src_pixel_val] = lookup_val
    return lookup_table
def match_histograms(src_image, ref_image):
    """
    Remap each colour plane of ``src_image`` so its histogram follows ``ref_image``.

    Both images are split into exactly three planes. For each plane:
    - a 256-bin histogram over [0, 256) is computed;
    - the histogram is converted to a normalized CDF with ``calculate_cdf``;
    - a per-level table from ``calculate_lookup`` is applied with ``cv2.LUT``.

    :param image src_image: The original source image (uint8 in the visible caller)
    :param image ref_image: The reference image
    :return: image_after_matching, converted by ``cv2.convertScaleAbs``
    :rtype: image (array)
    """

    def remap_plane(src_plane, ref_plane):
        # Histogram -> CDF for both planes, then push source levels through the table.
        src_counts, _ = np.histogram(src_plane.flatten(), 256, [0, 256])
        ref_counts, _ = np.histogram(ref_plane.flatten(), 256, [0, 256])
        table = calculate_lookup(calculate_cdf(src_counts), calculate_cdf(ref_counts))
        return cv2.LUT(src_plane, table)

    # Unpacking into three names keeps the same failure on non-3-channel input.
    sb, sg, sr = cv2.split(src_image)
    rb, rg, rr = cv2.split(ref_image)
    merged = cv2.merge([remap_plane(sb, rb), remap_plane(sg, rg), remap_plane(sr, rr)])
    return cv2.convertScaleAbs(merged)
def _standard_face_pts():
    """Return the five template landmarks as a (5, 2) float32 array in normalized units.

    Each raw template coordinate ``p`` is mapped to ``p / 256 - 1``, so a 0..512
    range becomes [-1, 1]. Rows pair up as (x, y), in the same order that
    ``search`` produces landmarks.
    """
    pairs = [
        [196.0, 226.0],
        [316.0, 226.0],
        [256.0, 286.0],
        [220.0, 360.4],
        [292.0, 360.4],
    ]
    template = np.array(pairs, dtype=np.float32)
    return template / 256.0 - 1.0
def _origin_face_pts():
    """Return the five raw (un-normalized) template landmarks as a (5, 2) float32 array."""
    pairs = [
        [196.0, 226.0],
        [316.0, 226.0],
        [256.0, 286.0],
        [220.0, 360.4],
        [292.0, 360.4],
    ]
    return np.array(pairs, dtype=np.float32)
def compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):
    """Estimate the similarity transform from the 256x256 face template to ``landmark``.

    Args:
        img: image the landmarks belong to. Its height/width are read only when
            ``normalize`` is true.
        landmark: (5, 2) array of (x, y) points, in the order produced by ``search``.
        normalize: if true, the landmarks are first rescaled to [-1, 1] relative to
            the image size. A float copy is used, so the caller's array is not modified.
        target_face_scale: scale applied to the template around its centre before it
            is mapped into 0..256 pixel coordinates.

    Returns:
        A fitted ``SimilarityTransform`` estimated from template points to ``landmark``.
    """
    std_pts = _standard_face_pts()  # [-1,1]
    target_pts = (std_pts * target_face_scale + 1) / 2 * 256.0
    if normalize:
        h, w = img.shape[:2]
        # Fixes in this branch:
        # - x runs along the width and y along the height (the original swapped them);
        # - a float copy avoids truncating integer landmarks and mutating the caller.
        landmark = np.array(landmark, dtype=np.float64)
        landmark[:, 0] = landmark[:, 0] / w * 2 - 1.0
        landmark[:, 1] = landmark[:, 1] / h * 2 - 1.0
    affine = SimilarityTransform()
    affine.estimate(target_pts, landmark)
    return affine
def compute_inverse_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):
    """Estimate the similarity transform from ``landmark`` to the 256x256 face template.

    This is the reverse direction of ``compute_transformation_matrix``: the
    estimation runs from the landmarks to the template points.

    Args:
        img: image the landmarks belong to. Its height/width are read only when
            ``normalize`` is true.
        landmark: (5, 2) array of (x, y) points, in the order produced by ``search``.
        normalize: if true, the landmarks are first rescaled to [-1, 1] relative to
            the image size. A float copy is used, so the caller's array is not modified.
        target_face_scale: scale applied to the template before it is mapped into
            0..256 pixel coordinates.

    Returns:
        A fitted ``SimilarityTransform`` estimated from ``landmark`` to the template points.
    """
    std_pts = _standard_face_pts()  # [-1,1]
    target_pts = (std_pts * target_face_scale + 1) / 2 * 256.0
    if normalize:
        h, w = img.shape[:2]
        # x is scaled by the width and y by the height; working on a float copy
        # avoids truncating integer landmarks and mutating the caller's array.
        landmark = np.array(landmark, dtype=np.float64)
        landmark[:, 0] = landmark[:, 0] / w * 2 - 1.0
        landmark[:, 1] = landmark[:, 1] / h * 2 - 1.0
    affine = SimilarityTransform()
    affine.estimate(landmark, target_pts)
    return affine
def show_detection(image, box, landmark):
    """Debug helper: show ``image`` with a red box outline and the first five landmarks.

    Also prints ``box[2] - box[0]`` to stdout.

    NOTE(review): the rectangle is anchored at (box[1], box[0]), but its width is
    box[2] - box[0] and its height box[3] - box[1]. These appear to mix two box
    conventions; confirm the expected layout of ``box``.
    """
    plt.imshow(image)
    print(box[2] - box[0])
    axes = plt.gca()
    outline = Rectangle(
        (box[1], box[0]),
        box[2] - box[0],
        box[3] - box[1],
        linewidth=1,
        edgecolor="r",
        facecolor="none",
    )
    axes.add_patch(outline)
    for idx in range(5):
        plt.scatter(landmark[idx][0], landmark[idx][1])
    plt.show()
def affine2theta(affine, input_w, input_h, target_w, target_h):
    """Rescale a 2x3 (or larger) affine matrix into a 2x3 ``theta`` array.

    Row 0 is divided by ``target_h`` and row 1 by ``target_w``. Within each row:
    - the first coefficient is multiplied by ``input_h``;
    - the second coefficient is multiplied by ``input_w``;
    - the translation becomes ``(2*t + a*input_h + b*input_w) / target - 1``.
    """
    theta = np.zeros([2, 3])
    for row, target in ((0, target_h), (1, target_w)):
        a, b, t = affine[row, 0], affine[row, 1], affine[row, 2]
        theta[row, 0] = a * input_h / target
        theta[row, 1] = b * input_w / target
        theta[row, 2] = (2 * t + a * input_h + b * input_w) / target - 1
    return theta
def blur_blending(im1, im2, mask):
    """Composite ``im1`` over ``im2`` with a feathered mask (PIL implementation).

    The mask is scaled to 0..255 and eroded with a 10x10 kernel. Two composites follow:
    - a hard composite of ``im1`` over ``im2`` using the eroded mask;
    - a composite of that result over ``im2`` using a Gaussian-blurred
      (radius 20) copy of the mask, which softens the seam.

    Args:
        im1: foreground image, converted with ``astype("uint8")``.
        im2: background image, same size as ``im1``.
        mask: blending mask, presumably in [0, 1] (TODO confirm). It is not modified.

    Returns:
        The blended image as a float array divided by 255.
    """
    # Scale a copy. The previous in-place `mask *= 255.0` mutated the caller's
    # array and raised on integer masks.
    mask = mask * 255.0
    kernel = np.ones((10, 10), np.uint8)
    mask = cv2.erode(mask, kernel, iterations=1)
    mask = Image.fromarray(mask.astype("uint8")).convert("L")
    im1 = Image.fromarray(im1.astype("uint8"))
    im2 = Image.fromarray(im2.astype("uint8"))
    mask_blur = mask.filter(ImageFilter.GaussianBlur(20))
    im = Image.composite(im1, im2, mask)
    im = Image.composite(im, im2, mask_blur)
    return np.array(im) / 255.0
def blur_blending_cv2(im1, im2, mask):
    """Alpha-blend ``im1`` over ``im2`` using an eroded, Gaussian-blurred mask.

    Steps:
    - scale the mask to 0..255;
    - erode it three times with a 9x9 kernel, pulling the seam inside the pasted region;
    - blur it with a 25x25 Gaussian and rescale to [0, 1];
    - use it as per-pixel weights for ``im1``.

    Args:
        im1: foreground, presumably 0..255 values (TODO confirm).
        im2: background, same shape and scale as ``im1``.
        mask: weights broadcastable against the images, presumably {0, 1}.
            It is not modified.

    Returns:
        The blend divided by 255 and clipped to [0, 1].
    """
    # Scale a copy instead of `mask *= 255.0`, which mutated the caller's array.
    mask = mask * 255.0
    kernel = np.ones((9, 9), np.uint8)
    mask = cv2.erode(mask, kernel, iterations=3)
    mask_blur = cv2.GaussianBlur(mask, (25, 25), 0)
    mask_blur /= 255.0
    im = im1 * mask_blur + (1 - mask_blur) * im2
    im /= 255.0
    return np.clip(im, 0.0, 1.0)
def Poisson_blending(im1, im2, mask):
    """Poisson-blend ``im2`` into ``im1`` using ``cv2.seamlessClone`` in MIXED_CLONE mode.

    The mask is scaled to 0..255, eroded with a 10x10 kernel and inverted, so the
    cloned region is where the eroded mask is 0. Only the first mask channel is
    used. The clone is centred on the middle of ``im1``.

    Args:
        im1: destination image (h, w, c), converted to uint8.
        im2: source image pasted into ``im1``, converted to uint8.
        mask: 3-D mask, presumably in [0, 1] (TODO confirm). It is not modified.

    Returns:
        The blended image divided by 255.
    """
    # Work on copies; the original scaled and divided `mask` in place.
    eroded = cv2.erode(mask * 255, np.ones((10, 10), np.uint8), iterations=1)
    inverted = (1 - eroded / 255) * 255
    clone_mask = inverted[:, :, 0]
    height, width, _ = im1.shape
    # OpenCV expects the centre point as (x, y).
    center = (int(width / 2), int(height / 2))
    result = cv2.seamlessClone(
        im2.astype("uint8"), im1.astype("uint8"), clone_mask.astype("uint8"), center, cv2.MIXED_CLONE
    )
    return result / 255.0
def Poisson_B(im1, im2, mask, center):
    """Poisson-blend ``im2`` into ``im1`` at ``center`` using NORMAL_CLONE.

    Args:
        im1: destination image, converted to uint8.
        im2: source image, converted to uint8.
        mask: clone mask, presumably in [0, 1] (TODO confirm). It is scaled by 255
            on a copy and is not modified.
        center: point passed straight to ``cv2.seamlessClone``.

    Returns:
        The blended image divided by 255.
    """
    # The original `mask *= 255` mutated the caller's array.
    mask = mask * 255
    result = cv2.seamlessClone(
        im2.astype("uint8"), im1.astype("uint8"), mask.astype("uint8"), center, cv2.NORMAL_CLONE
    )
    return result / 255
def seamless_clone(old_face, new_face, raw_mask):
    """Poisson-blend the masked part of ``new_face`` into ``old_face``.

    ``old_face`` is zero-padded by half its height/width on every side, so the
    clone cannot touch the border. The crop of ``new_face`` / ``raw_mask`` around
    the mask's non-zero pixels is cloned with ``cv2.NORMAL_CLONE`` at the mask
    centre. The padding is then cropped off again.

    Args:
        old_face: destination image; values are multiplied by 255 before uint8
            conversion, so presumably floats in [0, 1] (TODO confirm).
        new_face: source image on the same scale, indexed [y, x].
        raw_mask: 3-D mask indexed [y, x, c]; non-zero entries mark the clone region.

    Returns:
        float32 array divided by 255, cropped back to the size of ``old_face``.

    NOTE(review): an all-zero ``raw_mask`` makes ``np.min`` raise ValueError.
    """
    height, width, _ = old_face.shape
    height = height // 2
    width = width // 2
    y_indices, x_indices, _ = np.nonzero(raw_mask)
    # NOTE(review): slice(min, max) excludes the max row/column, so the crop is one
    # pixel short of the mask's bounding box; possibly an off-by-one.
    y_crop = slice(np.min(y_indices), np.max(y_indices))
    x_crop = slice(np.min(x_indices), np.max(x_indices))
    # Centre of the mask's bounding box, shifted into the padded coordinate frame.
    y_center = int(np.rint((np.max(y_indices) + np.min(y_indices)) / 2 + height))
    x_center = int(np.rint((np.max(x_indices) + np.min(x_indices)) / 2 + width))
    insertion = np.rint(new_face[y_crop, x_crop] * 255.0).astype("uint8")
    insertion_mask = np.rint(raw_mask[y_crop, x_crop] * 255.0).astype("uint8")
    # Binarize: any non-zero mask value becomes fully opaque.
    insertion_mask[insertion_mask != 0] = 255
    prior = np.rint(np.pad(old_face * 255.0, ((height, height), (width, width), (0, 0)), "constant")).astype(
        "uint8"
    )
    # if np.sum(insertion_mask) == 0:
    # Interior of the mask with a 1-pixel zero border; used only to check that the
    # clone region is at least 4x4 before calling seamlessClone.
    n_mask = insertion_mask[1:-1, 1:-1, :]
    n_mask = cv2.copyMakeBorder(n_mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, 0)
    print(n_mask.shape)  # NOTE(review): leftover debug output
    x, y, w, h = cv2.boundingRect(n_mask[:, :, 0])
    if w < 4 or h < 4:
        # Region too small to clone: keep the (padded) original unchanged.
        blended = prior
    else:
        blended = cv2.seamlessClone(
            insertion,  # pylint: disable=no-member
            prior,
            insertion_mask,
            (x_center, y_center),
            cv2.NORMAL_CLONE,
        )  # pylint: disable=no-member
    # Remove the padding added above.
    blended = blended[height:-height, width:-width]
    return blended.astype("float32") / 255.0
def get_landmark(face_landmarks, id):
    """Return the (x, y) position of landmark number ``id``.

    Only ``face_landmarks.part(id)`` and its ``.x`` / ``.y`` attributes are used.
    """
    point = face_landmarks.part(id)
    return point.x, point.y
def search(face_landmarks):
    """Reduce a 68-point dlib landmark set to five alignment points.

    Returns an integer array with these rows, each (x, y):
    - left-eye midpoint of parts 36 and 39;
    - right-eye midpoint of parts 42 and 45;
    - nose, part 30;
    - left mouth corner, part 48;
    - right mouth corner, part 54.

    Eye midpoints are truncated toward zero with ``int``.
    """

    def xy(index):
        point = face_landmarks.part(index)
        return point.x, point.y

    (lx1, ly1), (lx2, ly2) = xy(36), xy(39)
    (rx1, ry1), (rx2, ry2) = xy(42), xy(45)
    nose = list(xy(30))
    mouth_left = list(xy(48))
    mouth_right = list(xy(54))
    left_eye = [int((lx1 + lx2) / 2), int((ly1 + ly2) / 2)]
    right_eye = [int((rx1 + rx2) / 2), int((ry1 + ry2) / 2)]
    return np.array([left_eye, right_eye, nose, mouth_left, mouth_right])
if __name__ == "__main__":
    # Paste restored faces back into the original photos (256x256 face crops).
    # For every image in --origin_url:
    #   1. detect faces with dlib;
    #   2. align each face to the 256x256 template;
    #   3. replace it with the restored crop "<name>_<k>.png" from --replace_url;
    #   4. colour-match it to the original crop;
    #   5. warp it back with the inverse transform and blend it into the photo.
    # Results are saved to --save_url under the original file name.
    parser = argparse.ArgumentParser()
    parser.add_argument("--origin_url", type=str, default="./", help="origin images")
    parser.add_argument("--replace_url", type=str, default="./", help="restored faces")
    parser.add_argument("--save_url", type=str, default="./save")
    opts = parser.parse_args()

    origin_url = opts.origin_url
    replace_url = opts.replace_url
    save_url = opts.save_url

    if not os.path.exists(save_url):
        os.makedirs(save_url)

    face_detector = dlib.get_frontal_face_detector()
    # The landmark model is loaded from the current working directory.
    landmark_locator = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")

    count = 0

    for x in os.listdir(origin_url):
        img_url = os.path.join(origin_url, x)
        pil_img = Image.open(img_url).convert("RGB")

        origin_width, origin_height = pil_img.size
        image = np.array(pil_img)

        # NOTE(review): start/done are recorded but never used.
        start = time.time()
        faces = face_detector(image)
        done = time.time()

        if len(faces) == 0:
            print("Warning: There is no face in %s" % (x))
            continue

        # Accumulator: each face is blended on top of the result for the previous faces.
        blended = image
        for face_id in range(len(faces)):

            current_face = faces[face_id]
            face_landmarks = landmark_locator(image, current_face)
            current_fl = search(face_landmarks)

            # All-ones image warped with the face. After the inverse warp it marks
            # which photo pixels the face crop covers.
            forward_mask = np.ones_like(image).astype("uint8")
            affine = compute_transformation_matrix(image, current_fl, False, target_face_scale=1.3)
            aligned_face = warp(image, affine, output_shape=(256, 256, 3), preserve_range=True)
            forward_mask = warp(
                forward_mask, affine, output_shape=(256, 256, 3), order=0, preserve_range=True
            )

            affine_inverse = affine.inverse
            cur_face = aligned_face
            if replace_url != "":

                # Restored crops are looked up as "<file name minus its last 4
                # characters>_<1-based face index>.png". This is the same naming
                # detect_all_dlib.py uses when it saves the aligned crops.
                face_name = x[:-4] + "_" + str(face_id + 1) + ".png"
                cur_url = os.path.join(replace_url, face_name)
                restored_face = Image.open(cur_url).convert("RGB")
                restored_face = np.array(restored_face)
                cur_face = restored_face

            ## Histogram Color matching
            # Match the replacement face's per-channel histograms to the original
            # aligned crop. Both are converted RGB->BGR for match_histograms and back.
            A = cv2.cvtColor(aligned_face.astype("uint8"), cv2.COLOR_RGB2BGR)
            B = cv2.cvtColor(cur_face.astype("uint8"), cv2.COLOR_RGB2BGR)
            B = match_histograms(B, A)
            cur_face = cv2.cvtColor(B.astype("uint8"), cv2.COLOR_BGR2RGB)

            # Map the face and its coverage mask back onto the full-size photo.
            warped_back = warp(
                cur_face,
                affine_inverse,
                output_shape=(origin_height, origin_width, 3),
                order=3,
                preserve_range=True,
            )

            backward_mask = warp(
                forward_mask,
                affine_inverse,
                output_shape=(origin_height, origin_width, 3),
                order=0,
                preserve_range=True,
            )  ## Nearest neighbour

            blended = blur_blending_cv2(warped_back, blended, backward_mask)
            # blur_blending_cv2 returns [0, 1]; rescale so the next face blends
            # against 0..255 data.
            blended *= 255.0

        io.imsave(os.path.join(save_url, x), img_as_ubyte(blended / 255.0))

        count += 1

        if count % 1000 == 0:
            print("%d have finished ..." % (count))
================================================
FILE: Face_Detection/align_warp_back_multiple_dlib_HR.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import numpy as np
import skimage.io as io
# from face_sdk import FaceDetection
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from skimage.transform import SimilarityTransform
from skimage.transform import warp
from PIL import Image, ImageFilter
import torch.nn.functional as F
import torchvision as tv
import torchvision.utils as vutils
import time
import cv2
import os
from skimage import img_as_ubyte
import json
import argparse
import dlib
def calculate_cdf(histogram):
    """
    Turn a histogram into a cumulative distribution scaled by its maximum.

    :param array histogram: per-bin counts (a NumPy array; its ``cumsum`` method is used)
    :return: normalized_cdf: running sum of ``histogram`` divided by the largest running-sum value
    :rtype: array
    """
    running_total = histogram.cumsum()
    peak = float(running_total.max())
    return running_total / peak
def calculate_lookup(src_cdf, ref_cdf):
    """
    Build a 256-entry table mapping each source intensity to a reference intensity.

    For each source level ``v`` the table holds the smallest reference level whose
    CDF value is >= ``src_cdf[v]``. If no reference level qualifies, the mapping of
    the previous source level is reused (0 for the first level).

    :param array src_cdf: The non-decreasing cdf for the source image
    :param array ref_cdf: The non-decreasing cdf for the reference image
    :return: lookup_table: float array of length 256 holding reference levels
    :rtype: array
    """
    src_cdf = np.asarray(src_cdf)
    ref_cdf = np.asarray(ref_cdf)
    # A CDF is non-decreasing, so a left-sided binary search yields the first
    # reference level with ref_cdf >= value. This replaces the nested O(256^2) scan.
    positions = np.searchsorted(ref_cdf, src_cdf, side="left")
    lookup_table = np.zeros(256)
    lookup_val = 0
    for src_pixel_val, pos in enumerate(positions):
        # pos == len(ref_cdf) means "not found": carry the previous mapping forward.
        if pos < len(ref_cdf):
            lookup_val = pos
        lookup_table[src_pixel_val] = lookup_val
    return lookup_table
def match_histograms(src_image, ref_image):
    """
    Remap each colour plane of ``src_image`` so its histogram follows ``ref_image``.

    Both images are split into exactly three planes. For each plane:
    - a 256-bin histogram over [0, 256) is computed;
    - the histogram is converted to a normalized CDF with ``calculate_cdf``;
    - a per-level table from ``calculate_lookup`` is applied with ``cv2.LUT``.

    :param image src_image: The original source image (uint8 in the visible caller)
    :param image ref_image: The reference image
    :return: image_after_matching, converted by ``cv2.convertScaleAbs``
    :rtype: image (array)
    """

    def remap_plane(src_plane, ref_plane):
        # Histogram -> CDF for both planes, then push source levels through the table.
        src_counts, _ = np.histogram(src_plane.flatten(), 256, [0, 256])
        ref_counts, _ = np.histogram(ref_plane.flatten(), 256, [0, 256])
        table = calculate_lookup(calculate_cdf(src_counts), calculate_cdf(ref_counts))
        return cv2.LUT(src_plane, table)

    # Unpacking into three names keeps the same failure on non-3-channel input.
    sb, sg, sr = cv2.split(src_image)
    rb, rg, rr = cv2.split(ref_image)
    merged = cv2.merge([remap_plane(sb, rb), remap_plane(sg, rg), remap_plane(sr, rr)])
    return cv2.convertScaleAbs(merged)
def _standard_face_pts():
    """Return the five template landmarks as a (5, 2) float32 array in normalized units.

    Each raw template coordinate ``p`` is mapped to ``p / 256 - 1``, so a 0..512
    range becomes [-1, 1]. Rows pair up as (x, y), in the same order that
    ``search`` produces landmarks.
    """
    pairs = [
        [196.0, 226.0],
        [316.0, 226.0],
        [256.0, 286.0],
        [220.0, 360.4],
        [292.0, 360.4],
    ]
    template = np.array(pairs, dtype=np.float32)
    return template / 256.0 - 1.0
def _origin_face_pts():
    """Return the five raw (un-normalized) template landmarks as a (5, 2) float32 array."""
    pairs = [
        [196.0, 226.0],
        [316.0, 226.0],
        [256.0, 286.0],
        [220.0, 360.4],
        [292.0, 360.4],
    ]
    return np.array(pairs, dtype=np.float32)
def compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):
    """Estimate the similarity transform from the 512x512 face template to ``landmark``.

    Args:
        img: image the landmarks belong to. Its height/width are read only when
            ``normalize`` is true.
        landmark: (5, 2) array of (x, y) points, in the order produced by ``search``.
        normalize: if true, the landmarks are first rescaled to [-1, 1] relative to
            the image size. A float copy is used, so the caller's array is not modified.
        target_face_scale: scale applied to the template around its centre before it
            is mapped into 0..512 pixel coordinates.

    Returns:
        A fitted ``SimilarityTransform`` estimated from template points to ``landmark``.
    """
    std_pts = _standard_face_pts()  # [-1,1]
    target_pts = (std_pts * target_face_scale + 1) / 2 * 512.0
    if normalize:
        h, w = img.shape[:2]
        # Fixes in this branch:
        # - x runs along the width and y along the height (the original swapped them);
        # - a float copy avoids truncating integer landmarks and mutating the caller.
        landmark = np.array(landmark, dtype=np.float64)
        landmark[:, 0] = landmark[:, 0] / w * 2 - 1.0
        landmark[:, 1] = landmark[:, 1] / h * 2 - 1.0
    affine = SimilarityTransform()
    affine.estimate(target_pts, landmark)
    return affine
def compute_inverse_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):
    """Estimate the similarity transform from ``landmark`` to the 512x512 face template.

    This is the reverse direction of ``compute_transformation_matrix``: the
    estimation runs from the landmarks to the template points.

    Args:
        img: image the landmarks belong to. Its height/width are read only when
            ``normalize`` is true.
        landmark: (5, 2) array of (x, y) points, in the order produced by ``search``.
        normalize: if true, the landmarks are first rescaled to [-1, 1] relative to
            the image size. A float copy is used, so the caller's array is not modified.
        target_face_scale: scale applied to the template before it is mapped into
            0..512 pixel coordinates.

    Returns:
        A fitted ``SimilarityTransform`` estimated from ``landmark`` to the template points.
    """
    std_pts = _standard_face_pts()  # [-1,1]
    target_pts = (std_pts * target_face_scale + 1) / 2 * 512.0
    if normalize:
        h, w = img.shape[:2]
        # x is scaled by the width and y by the height; working on a float copy
        # avoids truncating integer landmarks and mutating the caller's array.
        landmark = np.array(landmark, dtype=np.float64)
        landmark[:, 0] = landmark[:, 0] / w * 2 - 1.0
        landmark[:, 1] = landmark[:, 1] / h * 2 - 1.0
    affine = SimilarityTransform()
    affine.estimate(landmark, target_pts)
    return affine
def show_detection(image, box, landmark):
    """Debug helper: show ``image`` with a red box outline and the first five landmarks.

    Also prints ``box[2] - box[0]`` to stdout.

    NOTE(review): the rectangle is anchored at (box[1], box[0]), but its width is
    box[2] - box[0] and its height box[3] - box[1]. These appear to mix two box
    conventions; confirm the expected layout of ``box``.
    """
    plt.imshow(image)
    print(box[2] - box[0])
    axes = plt.gca()
    outline = Rectangle(
        (box[1], box[0]),
        box[2] - box[0],
        box[3] - box[1],
        linewidth=1,
        edgecolor="r",
        facecolor="none",
    )
    axes.add_patch(outline)
    for idx in range(5):
        plt.scatter(landmark[idx][0], landmark[idx][1])
    plt.show()
def affine2theta(affine, input_w, input_h, target_w, target_h):
    """Rescale a 2x3 (or larger) affine matrix into a 2x3 ``theta`` array.

    Row 0 is divided by ``target_h`` and row 1 by ``target_w``. Within each row:
    - the first coefficient is multiplied by ``input_h``;
    - the second coefficient is multiplied by ``input_w``;
    - the translation becomes ``(2*t + a*input_h + b*input_w) / target - 1``.
    """
    theta = np.zeros([2, 3])
    for row, target in ((0, target_h), (1, target_w)):
        a, b, t = affine[row, 0], affine[row, 1], affine[row, 2]
        theta[row, 0] = a * input_h / target
        theta[row, 1] = b * input_w / target
        theta[row, 2] = (2 * t + a * input_h + b * input_w) / target - 1
    return theta
def blur_blending(im1, im2, mask):
    """Composite ``im1`` over ``im2`` with a feathered mask (PIL implementation).

    The mask is scaled to 0..255 and eroded with a 10x10 kernel. Two composites follow:
    - a hard composite of ``im1`` over ``im2`` using the eroded mask;
    - a composite of that result over ``im2`` using a Gaussian-blurred
      (radius 20) copy of the mask, which softens the seam.

    Args:
        im1: foreground image, converted with ``astype("uint8")``.
        im2: background image, same size as ``im1``.
        mask: blending mask, presumably in [0, 1] (TODO confirm). It is not modified.

    Returns:
        The blended image as a float array divided by 255.
    """
    # Scale a copy. The previous in-place `mask *= 255.0` mutated the caller's
    # array and raised on integer masks.
    mask = mask * 255.0
    kernel = np.ones((10, 10), np.uint8)
    mask = cv2.erode(mask, kernel, iterations=1)
    mask = Image.fromarray(mask.astype("uint8")).convert("L")
    im1 = Image.fromarray(im1.astype("uint8"))
    im2 = Image.fromarray(im2.astype("uint8"))
    mask_blur = mask.filter(ImageFilter.GaussianBlur(20))
    im = Image.composite(im1, im2, mask)
    im = Image.composite(im, im2, mask_blur)
    return np.array(im) / 255.0
def blur_blending_cv2(im1, im2, mask):
    """Alpha-blend ``im1`` over ``im2`` using an eroded, Gaussian-blurred mask.

    Steps:
    - scale the mask to 0..255;
    - erode it three times with a 9x9 kernel, pulling the seam inside the pasted region;
    - blur it with a 25x25 Gaussian and rescale to [0, 1];
    - use it as per-pixel weights for ``im1``.

    Args:
        im1: foreground, presumably 0..255 values (TODO confirm).
        im2: background, same shape and scale as ``im1``.
        mask: weights broadcastable against the images, presumably {0, 1}.
            It is not modified.

    Returns:
        The blend divided by 255 and clipped to [0, 1].
    """
    # Scale a copy instead of `mask *= 255.0`, which mutated the caller's array.
    mask = mask * 255.0
    kernel = np.ones((9, 9), np.uint8)
    mask = cv2.erode(mask, kernel, iterations=3)
    mask_blur = cv2.GaussianBlur(mask, (25, 25), 0)
    mask_blur /= 255.0
    im = im1 * mask_blur + (1 - mask_blur) * im2
    im /= 255.0
    return np.clip(im, 0.0, 1.0)
def Poisson_blending(im1, im2, mask):
    """Poisson-blend ``im2`` into ``im1`` using ``cv2.seamlessClone`` in MIXED_CLONE mode.

    The mask is scaled to 0..255, eroded with a 10x10 kernel and inverted, so the
    cloned region is where the eroded mask is 0. Only the first mask channel is
    used. The clone is centred on the middle of ``im1``.

    Args:
        im1: destination image (h, w, c), converted to uint8.
        im2: source image pasted into ``im1``, converted to uint8.
        mask: 3-D mask, presumably in [0, 1] (TODO confirm). It is not modified.

    Returns:
        The blended image divided by 255.
    """
    # Work on copies; the original scaled and divided `mask` in place.
    eroded = cv2.erode(mask * 255, np.ones((10, 10), np.uint8), iterations=1)
    inverted = (1 - eroded / 255) * 255
    clone_mask = inverted[:, :, 0]
    height, width, _ = im1.shape
    # OpenCV expects the centre point as (x, y).
    center = (int(width / 2), int(height / 2))
    result = cv2.seamlessClone(
        im2.astype("uint8"), im1.astype("uint8"), clone_mask.astype("uint8"), center, cv2.MIXED_CLONE
    )
    return result / 255.0
def Poisson_B(im1, im2, mask, center):
    """Poisson-blend ``im2`` into ``im1`` at ``center`` using NORMAL_CLONE.

    Args:
        im1: destination image, converted to uint8.
        im2: source image, converted to uint8.
        mask: clone mask, presumably in [0, 1] (TODO confirm). It is scaled by 255
            on a copy and is not modified.
        center: point passed straight to ``cv2.seamlessClone``.

    Returns:
        The blended image divided by 255.
    """
    # The original `mask *= 255` mutated the caller's array.
    mask = mask * 255
    result = cv2.seamlessClone(
        im2.astype("uint8"), im1.astype("uint8"), mask.astype("uint8"), center, cv2.NORMAL_CLONE
    )
    return result / 255
def seamless_clone(old_face, new_face, raw_mask):
    """Poisson-blend the masked part of ``new_face`` into ``old_face``.

    ``old_face`` is zero-padded by half its height/width on every side, so the
    clone cannot touch the border. The crop of ``new_face`` / ``raw_mask`` around
    the mask's non-zero pixels is cloned with ``cv2.NORMAL_CLONE`` at the mask
    centre. The padding is then cropped off again.

    Args:
        old_face: destination image; values are multiplied by 255 before uint8
            conversion, so presumably floats in [0, 1] (TODO confirm).
        new_face: source image on the same scale, indexed [y, x].
        raw_mask: 3-D mask indexed [y, x, c]; non-zero entries mark the clone region.

    Returns:
        float32 array divided by 255, cropped back to the size of ``old_face``.

    NOTE(review): an all-zero ``raw_mask`` makes ``np.min`` raise ValueError.
    """
    height, width, _ = old_face.shape
    height = height // 2
    width = width // 2
    y_indices, x_indices, _ = np.nonzero(raw_mask)
    # NOTE(review): slice(min, max) excludes the max row/column, so the crop is one
    # pixel short of the mask's bounding box; possibly an off-by-one.
    y_crop = slice(np.min(y_indices), np.max(y_indices))
    x_crop = slice(np.min(x_indices), np.max(x_indices))
    # Centre of the mask's bounding box, shifted into the padded coordinate frame.
    y_center = int(np.rint((np.max(y_indices) + np.min(y_indices)) / 2 + height))
    x_center = int(np.rint((np.max(x_indices) + np.min(x_indices)) / 2 + width))
    insertion = np.rint(new_face[y_crop, x_crop] * 255.0).astype("uint8")
    insertion_mask = np.rint(raw_mask[y_crop, x_crop] * 255.0).astype("uint8")
    # Binarize: any non-zero mask value becomes fully opaque.
    insertion_mask[insertion_mask != 0] = 255
    prior = np.rint(np.pad(old_face * 255.0, ((height, height), (width, width), (0, 0)), "constant")).astype(
        "uint8"
    )
    # if np.sum(insertion_mask) == 0:
    # Interior of the mask with a 1-pixel zero border; used only to check that the
    # clone region is at least 4x4 before calling seamlessClone.
    n_mask = insertion_mask[1:-1, 1:-1, :]
    n_mask = cv2.copyMakeBorder(n_mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, 0)
    print(n_mask.shape)  # NOTE(review): leftover debug output
    x, y, w, h = cv2.boundingRect(n_mask[:, :, 0])
    if w < 4 or h < 4:
        # Region too small to clone: keep the (padded) original unchanged.
        blended = prior
    else:
        blended = cv2.seamlessClone(
            insertion,  # pylint: disable=no-member
            prior,
            insertion_mask,
            (x_center, y_center),
            cv2.NORMAL_CLONE,
        )  # pylint: disable=no-member
    # Remove the padding added above.
    blended = blended[height:-height, width:-width]
    return blended.astype("float32") / 255.0
def get_landmark(face_landmarks, id):
    """Return the (x, y) position of landmark number ``id``.

    Only ``face_landmarks.part(id)`` and its ``.x`` / ``.y`` attributes are used.
    """
    point = face_landmarks.part(id)
    return point.x, point.y
def search(face_landmarks):
    """Reduce a 68-point dlib landmark set to five alignment points.

    Returns an integer array with these rows, each (x, y):
    - left-eye midpoint of parts 36 and 39;
    - right-eye midpoint of parts 42 and 45;
    - nose, part 30;
    - left mouth corner, part 48;
    - right mouth corner, part 54.

    Eye midpoints are truncated toward zero with ``int``.
    """

    def xy(index):
        point = face_landmarks.part(index)
        return point.x, point.y

    (lx1, ly1), (lx2, ly2) = xy(36), xy(39)
    (rx1, ry1), (rx2, ry2) = xy(42), xy(45)
    nose = list(xy(30))
    mouth_left = list(xy(48))
    mouth_right = list(xy(54))
    left_eye = [int((lx1 + lx2) / 2), int((ly1 + ly2) / 2)]
    right_eye = [int((rx1 + rx2) / 2), int((ry1 + ry2) / 2)]
    return np.array([left_eye, right_eye, nose, mouth_left, mouth_right])
if __name__ == "__main__":
    # Paste restored faces back into the original photos (512x512 face crops).
    # For every image in --origin_url:
    #   1. detect faces with dlib;
    #   2. align each face to the 512x512 template;
    #   3. replace it with the restored crop "<name>_<k>.png" from --replace_url;
    #   4. colour-match it to the original crop;
    #   5. warp it back with the inverse transform and blend it into the photo.
    # Results are saved to --save_url under the original file name.
    parser = argparse.ArgumentParser()
    parser.add_argument("--origin_url", type=str, default="./", help="origin images")
    parser.add_argument("--replace_url", type=str, default="./", help="restored faces")
    parser.add_argument("--save_url", type=str, default="./save")
    opts = parser.parse_args()

    origin_url = opts.origin_url
    replace_url = opts.replace_url
    save_url = opts.save_url

    if not os.path.exists(save_url):
        os.makedirs(save_url)

    face_detector = dlib.get_frontal_face_detector()
    # The landmark model is loaded from the current working directory.
    landmark_locator = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")

    count = 0

    for x in os.listdir(origin_url):
        img_url = os.path.join(origin_url, x)
        pil_img = Image.open(img_url).convert("RGB")

        origin_width, origin_height = pil_img.size
        image = np.array(pil_img)

        # NOTE(review): start/done are recorded but never used.
        start = time.time()
        faces = face_detector(image)
        done = time.time()

        if len(faces) == 0:
            print("Warning: There is no face in %s" % (x))
            continue

        # Accumulator: each face is blended on top of the result for the previous faces.
        blended = image
        for face_id in range(len(faces)):

            current_face = faces[face_id]
            face_landmarks = landmark_locator(image, current_face)
            current_fl = search(face_landmarks)

            # All-ones image warped with the face. After the inverse warp it marks
            # which photo pixels the face crop covers.
            forward_mask = np.ones_like(image).astype("uint8")
            affine = compute_transformation_matrix(image, current_fl, False, target_face_scale=1.3)
            aligned_face = warp(image, affine, output_shape=(512, 512, 3), preserve_range=True)
            forward_mask = warp(
                forward_mask, affine, output_shape=(512, 512, 3), order=0, preserve_range=True
            )

            affine_inverse = affine.inverse
            cur_face = aligned_face
            if replace_url != "":

                # Restored crops are looked up as "<file name minus its last 4
                # characters>_<1-based face index>.png". This is the same naming
                # the detection scripts use when they save the aligned crops.
                face_name = x[:-4] + "_" + str(face_id + 1) + ".png"
                cur_url = os.path.join(replace_url, face_name)
                restored_face = Image.open(cur_url).convert("RGB")
                restored_face = np.array(restored_face)
                cur_face = restored_face

            ## Histogram Color matching
            # Match the replacement face's per-channel histograms to the original
            # aligned crop. Both are converted RGB->BGR for match_histograms and back.
            A = cv2.cvtColor(aligned_face.astype("uint8"), cv2.COLOR_RGB2BGR)
            B = cv2.cvtColor(cur_face.astype("uint8"), cv2.COLOR_RGB2BGR)
            B = match_histograms(B, A)
            cur_face = cv2.cvtColor(B.astype("uint8"), cv2.COLOR_BGR2RGB)

            # Map the face and its coverage mask back onto the full-size photo.
            warped_back = warp(
                cur_face,
                affine_inverse,
                output_shape=(origin_height, origin_width, 3),
                order=3,
                preserve_range=True,
            )

            backward_mask = warp(
                forward_mask,
                affine_inverse,
                output_shape=(origin_height, origin_width, 3),
                order=0,
                preserve_range=True,
            )  ## Nearest neighbour

            blended = blur_blending_cv2(warped_back, blended, backward_mask)
            # blur_blending_cv2 returns [0, 1]; rescale so the next face blends
            # against 0..255 data.
            blended *= 255.0

        io.imsave(os.path.join(save_url, x), img_as_ubyte(blended / 255.0))

        count += 1

        if count % 1000 == 0:
            print("%d have finished ..." % (count))
================================================
FILE: Face_Detection/detect_all_dlib.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import numpy as np
import skimage.io as io
# from FaceSDK.face_sdk import FaceDetection
# from face_sdk import FaceDetection
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from skimage.transform import SimilarityTransform
from skimage.transform import warp
from PIL import Image
import torch.nn.functional as F
import torchvision as tv
import torchvision.utils as vutils
import time
import cv2
import os
from skimage import img_as_ubyte
import json
import argparse
import dlib
def _standard_face_pts():
    """Return the five template landmarks as a (5, 2) float32 array in normalized units.

    Each raw template coordinate ``p`` is mapped to ``p / 256 - 1``, so a 0..512
    range becomes [-1, 1]. Rows pair up as (x, y), in the same order that
    ``search`` produces landmarks.
    """
    pairs = [
        [196.0, 226.0],
        [316.0, 226.0],
        [256.0, 286.0],
        [220.0, 360.4],
        [292.0, 360.4],
    ]
    template = np.array(pairs, dtype=np.float32)
    return template / 256.0 - 1.0
def _origin_face_pts():
    """Return the five raw (un-normalized) template landmarks as a (5, 2) float32 array."""
    pairs = [
        [196.0, 226.0],
        [316.0, 226.0],
        [256.0, 286.0],
        [220.0, 360.4],
        [292.0, 360.4],
    ]
    return np.array(pairs, dtype=np.float32)
def get_landmark(face_landmarks, id):
    """Return the (x, y) position of landmark number ``id``.

    Only ``face_landmarks.part(id)`` and its ``.x`` / ``.y`` attributes are used.
    """
    point = face_landmarks.part(id)
    return point.x, point.y
def search(face_landmarks):
    """Reduce a 68-point dlib landmark set to five alignment points.

    Returns an integer array with these rows, each (x, y):
    - left-eye midpoint of parts 36 and 39;
    - right-eye midpoint of parts 42 and 45;
    - nose, part 30;
    - left mouth corner, part 48;
    - right mouth corner, part 54.

    Eye midpoints are truncated toward zero with ``int``.
    """

    def xy(index):
        point = face_landmarks.part(index)
        return point.x, point.y

    (lx1, ly1), (lx2, ly2) = xy(36), xy(39)
    (rx1, ry1), (rx2, ry2) = xy(42), xy(45)
    nose = list(xy(30))
    mouth_left = list(xy(48))
    mouth_right = list(xy(54))
    left_eye = [int((lx1 + lx2) / 2), int((ly1 + ly2) / 2)]
    right_eye = [int((rx1 + rx2) / 2), int((ry1 + ry2) / 2)]
    return np.array([left_eye, right_eye, nose, mouth_left, mouth_right])
def compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):
    """Return the matrix of the similarity transform from the 256x256 template to ``landmark``.

    Args:
        img: image the landmarks belong to. Its height/width are read only when
            ``normalize`` is true.
        landmark: (5, 2) array of (x, y) points, in the order produced by ``search``.
        normalize: if true, the landmarks are first rescaled to [-1, 1] relative to
            the image size. A float copy is used, so the caller's array is not modified.
        target_face_scale: scale applied to the template around its centre before it
            is mapped into 0..256 pixel coordinates.

    Returns:
        The ``params`` matrix of a ``SimilarityTransform`` estimated from template
        points to ``landmark``.
    """
    std_pts = _standard_face_pts()  # [-1,1]
    target_pts = (std_pts * target_face_scale + 1) / 2 * 256.0
    if normalize:
        h, w = img.shape[:2]
        # Fixes in this branch:
        # - x runs along the width and y along the height (the original swapped them);
        # - a float copy avoids truncating integer landmarks and mutating the caller.
        landmark = np.array(landmark, dtype=np.float64)
        landmark[:, 0] = landmark[:, 0] / w * 2 - 1.0
        landmark[:, 1] = landmark[:, 1] / h * 2 - 1.0
    affine = SimilarityTransform()
    affine.estimate(target_pts, landmark)
    return affine.params
def show_detection(image, box, landmark):
    """Debug helper: show ``image`` with a red box outline and the first five landmarks.

    Also prints ``box[2] - box[0]`` to stdout.

    NOTE(review): the rectangle is anchored at (box[1], box[0]), but its width is
    box[2] - box[0] and its height box[3] - box[1]. These appear to mix two box
    conventions; confirm the expected layout of ``box``.
    """
    plt.imshow(image)
    print(box[2] - box[0])
    axes = plt.gca()
    outline = Rectangle(
        (box[1], box[0]),
        box[2] - box[0],
        box[3] - box[1],
        linewidth=1,
        edgecolor="r",
        facecolor="none",
    )
    axes.add_patch(outline)
    for idx in range(5):
        plt.scatter(landmark[idx][0], landmark[idx][1])
    plt.show()
def affine2theta(affine, input_w, input_h, target_w, target_h):
    """Rescale a 2x3 (or larger) affine matrix into a 2x3 ``theta`` array.

    Row 0 is divided by ``target_h`` and row 1 by ``target_w``. Within each row:
    - the first coefficient is multiplied by ``input_h``;
    - the second coefficient is multiplied by ``input_w``;
    - the translation becomes ``(2*t + a*input_h + b*input_w) / target - 1``.
    """
    theta = np.zeros([2, 3])
    for row, target in ((0, target_h), (1, target_w)):
        a, b, t = affine[row, 0], affine[row, 1], affine[row, 2]
        theta[row, 0] = a * input_h / target
        theta[row, 1] = b * input_w / target
        theta[row, 2] = (2 * t + a * input_h + b * input_w) / target - 1
    return theta
if __name__ == "__main__":
    # Detect faces in every image under --url and save each one as a 256x256
    # aligned crop named "<file name minus its last 4 characters>_<1-based face
    # index>.png" in --save_url.
    parser = argparse.ArgumentParser()
    parser.add_argument("--url", type=str, default="/home/jingliao/ziyuwan/celebrities", help="input")
    parser.add_argument(
        "--save_url", type=str, default="/home/jingliao/ziyuwan/celebrities_detected_face_reid", help="output"
    )
    opts = parser.parse_args()

    url = opts.url
    save_url = opts.save_url

    # NOTE(review): this also creates the *input* directory if it is missing, in
    # which case os.listdir below simply finds nothing to process.
    os.makedirs(url, exist_ok=True)
    os.makedirs(save_url, exist_ok=True)

    face_detector = dlib.get_frontal_face_detector()
    # The landmark model is loaded from the current working directory.
    landmark_locator = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")

    count = 0

    # NOTE(review): map_id is never used.
    map_id = {}
    for x in os.listdir(url):
        img_url = os.path.join(url, x)
        pil_img = Image.open(img_url).convert("RGB")

        image = np.array(pil_img)

        # NOTE(review): start/done are recorded but never used.
        start = time.time()
        faces = face_detector(image)
        done = time.time()

        if len(faces) == 0:
            print("Warning: There is no face in %s" % (x))
            continue

        print(len(faces))

        # Always true here, because images without faces were skipped above.
        if len(faces) > 0:
            for face_id in range(len(faces)):
                current_face = faces[face_id]
                face_landmarks = landmark_locator(image, current_face)
                current_fl = search(face_landmarks)

                # The matrix maps template coordinates to image coordinates
                # (estimated from target_pts to landmark) and is used directly by warp.
                affine = compute_transformation_matrix(image, current_fl, False, target_face_scale=1.3)
                aligned_face = warp(image, affine, output_shape=(256, 256, 3))
                img_name = x[:-4] + "_" + str(face_id + 1)
                io.imsave(os.path.join(save_url, img_name + ".png"), img_as_ubyte(aligned_face))

        count += 1

        if count % 1000 == 0:
            print("%d have finished ..." % (count))
================================================
FILE: Face_Detection/detect_all_dlib_HR.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import numpy as np
import skimage.io as io
# from FaceSDK.face_sdk import FaceDetection
# from face_sdk import FaceDetection
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from skimage.transform import SimilarityTransform
from skimage.transform import warp
from PIL import Image
import torch.nn.functional as F
import torchvision as tv
import torchvision.utils as vutils
import time
import cv2
import os
from skimage import img_as_ubyte
import json
import argparse
import dlib
def _standard_face_pts():
pts = (
np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32) / 256.0
- 1.0
)
return np.reshape(pts, (5, 2))
def _origin_face_pts():
pts = np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32)
return np.reshape(pts, (5, 2))
def get_landmark(face_landmarks, id):
    """Return the (x, y) coordinates of landmark number `id` from a dlib shape."""
    point = face_landmarks.part(id)
    return point.x, point.y
def search(face_landmarks):
    """Reduce 68-point dlib landmarks to the five points used for alignment.

    Returns an integer array of rows [x, y] in this order: left-eye centre
    (mean of points 36/39), right-eye centre (mean of 42/45), nose (30),
    left mouth corner (48), right mouth corner (54). Eye centres are
    truncated with int().
    """

    def _midpoint(first, second):
        ax, ay = get_landmark(face_landmarks, first)
        bx, by = get_landmark(face_landmarks, second)
        return [int((ax + bx) / 2), int((ay + by) / 2)]

    left_eye = _midpoint(36, 39)
    right_eye = _midpoint(42, 45)
    nose = list(get_landmark(face_landmarks, 30))
    mouth_left = list(get_landmark(face_landmarks, 48))
    mouth_right = list(get_landmark(face_landmarks, 54))
    return np.array([left_eye, right_eye, nose, mouth_left, mouth_right])
def compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):
    """Estimate the similarity transform from the 512x512 face template to `landmark`.

    Args:
        img: image array; only its height and width are read (2-D or 3-D).
        landmark: (5, 2) array of detected points, e.g. from search().
        normalize: if truthy, rescale a float copy of `landmark` to [-1, 1]
            before estimation; the caller's array is never modified.
        target_face_scale: scale applied to the template around its centre.

    Returns:
        The 3x3 matrix (`params`) of the estimated SimilarityTransform.
    """
    std_pts = _standard_face_pts()  # template in [-1, 1]
    target_pts = (std_pts * target_face_scale + 1) / 2 * 512.0
    h, w = img.shape[:2]
    if normalize:
        # Work on a float copy: in-place updates would mutate the caller's
        # array and truncate integer landmarks (search() returns ints).
        landmark = np.array(landmark, dtype=np.float64)
        # NOTE(review): column 0 is scaled by h and column 1 by w, as before;
        # with (x, y) landmarks this looks swapped - confirm before relying
        # on normalize=True.
        landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0
        landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0
    affine = SimilarityTransform()
    affine.estimate(target_pts, landmark)
    return affine.params
def show_detection(image, box, landmark):
    """Debug helper: display `image` with a red box and the first five landmarks.

    Also prints box[2] - box[0].
    """
    plt.imshow(image)
    print(box[2] - box[0])
    # NOTE(review): the anchor is (box[1], box[0]) but the width uses
    # box[2] - box[0] and the height box[3] - box[1]; these index conventions
    # disagree - confirm the intended box layout.
    outline = Rectangle(
        (box[1], box[0]), box[2] - box[0], box[3] - box[1], linewidth=1, edgecolor="r", facecolor="none"
    )
    plt.gca().add_patch(outline)
    for idx in range(5):
        plt.scatter(landmark[idx][0], landmark[idx][1])
    plt.show()
def affine2theta(affine, input_w, input_h, target_w, target_h):
    """Convert a pixel-space affine matrix into a 2x3 normalized `theta`.

    Row 0 is scaled by target_h and row 1 by target_w. The matrix is used
    as given; it is not inverted.
    """
    param = affine
    theta = np.zeros([2, 3])
    for row, target in ((0, target_h), (1, target_w)):
        a, b, c = param[row, 0], param[row, 1], param[row, 2]
        theta[row, 0] = a * input_h / target
        theta[row, 1] = b * input_w / target
        theta[row, 2] = (2 * c + a * input_h + b * input_w) / target - 1
    return theta
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--url", type=str, default="/home/jingliao/ziyuwan/celebrities", help="input")
    parser.add_argument(
        "--save_url", type=str, default="/home/jingliao/ziyuwan/celebrities_detected_face_reid", help="output"
    )
    opts = parser.parse_args()

    url = opts.url
    save_url = opts.save_url

    # Both directories are created up front; for the input directory this only
    # means a missing --url produces an empty run instead of an error.
    os.makedirs(url, exist_ok=True)
    os.makedirs(save_url, exist_ok=True)

    face_detector = dlib.get_frontal_face_detector()
    landmark_locator = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")

    count = 0  # progress counter, incremented once per image in which faces were found
    for x in os.listdir(url):
        img_url = os.path.join(url, x)
        # Close the file handle promptly; convert() returns an independent image.
        with Image.open(img_url) as pil_img:
            image = np.array(pil_img.convert("RGB"))

        faces = face_detector(image)
        if len(faces) == 0:
            print("Warning: There is no face in %s" % (x))
            continue
        print(len(faces))

        for face_id, current_face in enumerate(faces):
            face_landmarks = landmark_locator(image, current_face)
            current_fl = search(face_landmarks)
            affine = compute_transformation_matrix(image, current_fl, False, target_face_scale=1.3)
            # High-resolution variant: faces are aligned into 512x512 crops.
            aligned_face = warp(image, affine, output_shape=(512, 512, 3))
            # Output is "<stem>_<k>.png" with k 1-based; x[:-4] assumes a
            # 3-character extension (".jpg", ".png"), kept for naming
            # compatibility with downstream scripts.
            img_name = x[:-4] + "_" + str(face_id + 1)
            io.imsave(os.path.join(save_url, img_name + ".png"), img_as_ubyte(aligned_face))

        count += 1
        if count % 1000 == 0:
            print("%d have finished ..." % (count))
================================================
FILE: Face_Enhancement/data/__init__.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import importlib
import torch.utils.data
from data.base_dataset import BaseDataset
from data.face_dataset import FaceTestDataset
def create_dataloader(opt):
    """Build a DataLoader over FaceTestDataset configured from `opt`.

    Reads opt.batchSize, opt.serial_batches (disables shuffling),
    opt.nThreads and opt.isTrain (drops the last incomplete batch).
    """
    dataset = FaceTestDataset()
    dataset.initialize(opt)
    print("dataset [%s] of size %d was created" % (type(dataset).__name__, len(dataset)))
    loader_kwargs = dict(
        batch_size=opt.batchSize,
        shuffle=not opt.serial_batches,
        num_workers=int(opt.nThreads),
        drop_last=opt.isTrain,
    )
    return torch.utils.data.DataLoader(dataset, **loader_kwargs)
================================================
FILE: Face_Enhancement/data/base_dataset.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import random
class BaseDataset(data.Dataset):
    """Minimal dataset base class; subclasses configure themselves in `initialize`."""

    def __init__(self):
        super().__init__()

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Hook for dataset-specific CLI options; the base class adds none."""
        return parser

    def initialize(self, opt):
        """Configure the dataset from parsed options; a no-op here."""
        return None
def get_params(opt, size):
    """Draw the random crop position and flip flag shared by an image and its label.

    Args:
        opt: options; reads preprocess_mode, load_size and crop_size.
        size: (w, h) of the source image.

    Returns:
        {"crop_pos": (x, y), "flip": bool}. The crop position is sampled over
        the size the image has after the resize step of get_transform.
    """
    w, h = size
    new_h = h
    new_w = w
    if opt.preprocess_mode == "resize_and_crop":
        new_h = new_w = opt.load_size
    elif opt.preprocess_mode == "scale_width_and_crop":
        new_w = opt.load_size
        new_h = opt.load_size * h // w
    elif opt.preprocess_mode == "scale_shortside_and_crop":
        ss, ls = min(w, h), max(w, h)  # shortside and longside
        width_is_shorter = w == ss
        ls = int(opt.load_size * ls / ss)
        # The short side becomes load_size (as in __scale_shortside); keeping the
        # original short side would sample crops for the wrong image size.
        new_w, new_h = (opt.load_size, ls) if width_is_shorter else (ls, opt.load_size)
    x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
    y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
    flip = random.random() > 0.5
    return {"crop_pos": (x, y), "flip": flip}


def get_transform(opt, params, method=Image.BICUBIC, normalize=True, toTensor=True):
    """Compose the preprocessing pipeline selected by opt.preprocess_mode.

    Order: optional resize/scale, optional crop at params["crop_pos"],
    power-of-32 snapping ("none") or fixed-size resize ("fixed"), optional
    training-time flip, then ToTensor and Normalize to [-1, 1] if requested.
    """
    transform_list = []
    if "resize" in opt.preprocess_mode:
        osize = [opt.load_size, opt.load_size]
        transform_list.append(transforms.Resize(osize, interpolation=method))
    elif "scale_width" in opt.preprocess_mode:
        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))
    elif "scale_shortside" in opt.preprocess_mode:
        transform_list.append(transforms.Lambda(lambda img: __scale_shortside(img, opt.load_size, method)))
    if "crop" in opt.preprocess_mode:
        transform_list.append(transforms.Lambda(lambda img: __crop(img, params["crop_pos"], opt.crop_size)))
    if opt.preprocess_mode == "none":
        base = 32
        transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))
    if opt.preprocess_mode == "fixed":
        w = opt.crop_size
        h = round(opt.crop_size / opt.aspect_ratio)
        transform_list.append(transforms.Lambda(lambda img: __resize(img, w, h, method)))
    if opt.isTrain and not opt.no_flip:
        transform_list.append(transforms.Lambda(lambda img: __flip(img, params["flip"])))
    if toTensor:
        transform_list += [transforms.ToTensor()]
    if normalize:
        transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)


def normalize():
    """Return the Normalize transform mapping [0, 1] RGB tensors to [-1, 1]."""
    return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))


def __resize(img, w, h, method=Image.BICUBIC):
    """Resize a PIL image to exactly (w, h)."""
    return img.resize((w, h), method)


def __make_power_2(img, base, method=Image.BICUBIC):
    """Round each side to the nearest multiple of `base`, resizing only if needed."""
    ow, oh = img.size
    h = int(round(oh / base) * base)
    w = int(round(ow / base) * base)
    if (h == oh) and (w == ow):
        return img
    return img.resize((w, h), method)


def __scale_width(img, target_width, method=Image.BICUBIC):
    """Resize so the width equals `target_width`, preserving aspect ratio."""
    ow, oh = img.size
    if ow == target_width:
        return img
    w = target_width
    h = int(target_width * oh / ow)
    return img.resize((w, h), method)


def __scale_shortside(img, target_width, method=Image.BICUBIC):
    """Resize so the shorter side equals `target_width`, preserving aspect ratio."""
    ow, oh = img.size
    ss, ls = min(ow, oh), max(ow, oh)  # shortside and longside
    width_is_shorter = ow == ss
    if ss == target_width:
        return img
    ls = int(target_width * ls / ss)
    # The short side is set to target_width; it was previously left unchanged,
    # which stretched the image.
    nw, nh = (target_width, ls) if width_is_shorter else (ls, target_width)
    return img.resize((nw, nh), method)


def __crop(img, pos, size):
    """Crop a size x size square whose top-left corner is `pos`."""
    x1, y1 = pos
    tw = th = size
    return img.crop((x1, y1, x1 + tw, y1 + th))


def __flip(img, flip):
    """Mirror the image horizontally when `flip` is true."""
    if flip:
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img
================================================
FILE: Face_Enhancement/data/custom_dataset.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from data.pix2pix_dataset import Pix2pixDataset
from data.image_folder import make_dataset
class CustomDataset(Pix2pixDataset):
    """ Dataset that loads images from directories
    Use option --label_dir, --image_dir, --instance_dir to specify the directories.
    The images in the directories are sorted in alphabetical order and paired in order.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser = Pix2pixDataset.modify_commandline_options(parser, is_train)
        parser.set_defaults(preprocess_mode="resize_and_crop")
        load_size = 286 if is_train else 256
        parser.set_defaults(load_size=load_size)
        parser.set_defaults(crop_size=256)
        parser.set_defaults(display_winsize=256)
        parser.set_defaults(label_nc=13)
        parser.set_defaults(contain_dontcare_label=False)
        parser.add_argument(
            "--label_dir", type=str, required=True, help="path to the directory that contains label images"
        )
        parser.add_argument(
            "--image_dir", type=str, required=True, help="path to the directory that contains photo images"
        )
        parser.add_argument(
            "--instance_dir",
            type=str,
            default="",
            help="path to the directory that contains instance maps. Leave blank if not exists",
        )
        return parser

    def get_paths(self, opt):
        """Return (label_paths, image_paths, instance_paths) listed from the option directories.

        instance_paths is empty when --instance_dir is not given.
        """
        label_dir = opt.label_dir
        label_paths = make_dataset(label_dir, recursive=False, read_cache=True)
        image_dir = opt.image_dir
        image_paths = make_dataset(image_dir, recursive=False, read_cache=True)
        if len(opt.instance_dir) > 0:
            instance_paths = make_dataset(opt.instance_dir, recursive=False, read_cache=True)
        else:
            instance_paths = []
        # The message previously contained unformatted "%s" placeholders.
        assert len(label_paths) == len(image_paths), (
            "The #images in %s and %s do not match. Is there something wrong?" % (label_dir, image_dir)
        )
        return label_paths, image_paths, instance_paths
================================================
FILE: Face_Enhancement/data/face_dataset.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from data.base_dataset import BaseDataset, get_params, get_transform
from PIL import Image
import util.util as util
import os
import torch
class FaceTestDataset(BaseDataset):
    """Test-time dataset pairing each face image with its per-part parsing masks.

    Images come from <dataroot>/<old_face_folder>. Masks are looked up as
    <dataroot>/<old_face_label_folder>/<stem>_<part>.png for every entry in
    self.parts; a missing mask becomes an all-zero channel.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser.add_argument(
            "--no_pairing_check",
            action="store_true",
            help="If specified, skip sanity check of correct label-image file pairing",
        )
        return parser

    def initialize(self, opt):
        self.opt = opt
        image_path = os.path.join(opt.dataroot, opt.old_face_folder)
        label_path = os.path.join(opt.dataroot, opt.old_face_label_folder)
        self.label_paths = label_path  # just the mask root directory
        self.image_paths = sorted(os.listdir(image_path))  # bare file names, sorted
        self.parts = [
            "skin",
            "hair",
            "l_brow",
            "r_brow",
            "l_eye",
            "r_eye",
            "eye_g",
            "l_ear",
            "r_ear",
            "ear_r",
            "nose",
            "mouth",
            "u_lip",
            "l_lip",
            "neck",
            "neck_l",
            "cloth",
            "hat",
        ]
        self.dataset_size = len(self.image_paths)

    def __getitem__(self, index):
        """Return {"label": (len(parts), H, W) mask tensor, "image": tensor, "path": str}."""
        params = get_params(self.opt, (-1, -1))
        image_name = self.image_paths[index]
        image_path = os.path.join(self.opt.dataroot, self.opt.old_face_folder, image_name)
        # Close the file handle once the pixels are decoded.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        transform_image = get_transform(self.opt, params)
        image_tensor = transform_image(image)

        img_name = image_name[:-4]  # assumes a 3-character extension, matching the mask naming
        transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
        # Image and masks share the same transform params, so a missing part's
        # zero mask must match the transformed image size; a fixed load_size
        # square breaks torch.stack whenever cropping changes the size.
        empty_size = image_tensor.shape[-2:]
        full_label = []
        for each_part in self.parts:
            part_url = os.path.join(self.label_paths, img_name + "_" + each_part + ".png")
            if os.path.exists(part_url):
                with Image.open(part_url) as raw_label:
                    label = raw_label.convert("RGB")
                label_tensor = transform_label(label)  ## 3 channels and pixel [0,1]
                full_label.append(label_tensor[0])
            else:
                full_label.append(torch.zeros(empty_size))
        full_label_tensor = torch.stack(full_label, 0)

        input_dict = {
            "label": full_label_tensor,
            "image": image_tensor,
            "path": image_path,
        }
        return input_dict

    def __len__(self):
        return self.dataset_size
================================================
FILE: Face_Enhancement/data/image_folder.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch.utils.data as data
from PIL import Image
import os
# Recognised image suffixes (case-sensitive match): lower- and upper-case
# variants of the common formats, plus lower-case ".tiff" and ".webp".
IMG_EXTENSIONS = [
    variant
    for suffix in (".jpg", ".jpeg", ".png", ".ppm", ".bmp")
    for variant in (suffix, suffix.upper())
] + [".tiff", ".webp"]
def is_image_file(filename):
    """Return True when `filename` ends with one of IMG_EXTENSIONS (case-sensitive)."""
    return filename.endswith(tuple(IMG_EXTENSIONS))
def make_dataset_rec(dir, images):
    """Append every image path under `dir` (following symlinks) to `images` in place."""
    assert os.path.isdir(dir), "%s is not a valid directory" % dir
    for root, _subdirs, files in sorted(os.walk(dir, followlinks=True)):
        images.extend(os.path.join(root, name) for name in files if is_image_file(name))
def make_dataset(dir, recursive=False, read_cache=False, write_cache=False):
    """List image paths under `dir`, optionally via a "files.list" cache.

    With read_cache, an existing <dir>/files.list is returned verbatim (one
    path per line). Otherwise the directory is scanned; with write_cache the
    result is written to <dir>/files.list.
    """
    if read_cache:
        cache_file = os.path.join(dir, "files.list")
        if os.path.isfile(cache_file):
            with open(cache_file, "r") as handle:
                return handle.read().splitlines()

    images = []
    if recursive:
        make_dataset_rec(dir, images)
    else:
        assert os.path.isdir(dir) or os.path.islink(dir), "%s is not a valid directory" % dir
        # NOTE(review): despite recursive=False, os.walk still descends into
        # subdirectories (it just does not follow symlinked ones).
        for root, _subdirs, files in sorted(os.walk(dir)):
            images.extend(os.path.join(root, name) for name in files if is_image_file(name))

    if write_cache:
        cache_file = os.path.join(dir, "files.list")
        with open(cache_file, "w") as handle:
            handle.writelines("%s\n" % path for path in images)
        print("wrote filelist cache at %s" % cache_file)
    return images
def default_loader(path):
    """Open the image at `path` and return it converted to RGB."""
    opened = Image.open(path)
    return opened.convert("RGB")
class ImageFolder(data.Dataset):
    """Dataset over all images found by make_dataset(root).

    Each item is loader(path), passed through `transform` if given; with
    return_paths the item is (image, path).
    """

    def __init__(self, root, transform=None, return_paths=False, loader=default_loader):
        imgs = make_dataset(root)
        if not imgs:
            raise RuntimeError(
                "Found 0 images in: " + root + "\n"
                "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)
            )
        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        path = self.imgs[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        return (sample, path) if self.return_paths else sample

    def __len__(self):
        return len(self.imgs)
================================================
FILE: Face_Enhancement/data/pix2pix_dataset.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from data.base_dataset import BaseDataset, get_params, get_transform
from PIL import Image
import util.util as util
import os
class Pix2pixDataset(BaseDataset):
    """Paired label / image (/ instance map) dataset; subclasses supply get_paths()."""

    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser.add_argument(
            "--no_pairing_check",
            action="store_true",
            help="If specified, skip sanity check of correct label-image file pairing",
        )
        return parser

    def initialize(self, opt):
        """Collect, naturally sort, truncate (max_dataset_size) and optionally pair-check paths."""
        self.opt = opt
        label_paths, image_paths, instance_paths = self.get_paths(opt)
        util.natural_sort(label_paths)
        util.natural_sort(image_paths)
        if not opt.no_instance:
            util.natural_sort(instance_paths)
        label_paths = label_paths[: opt.max_dataset_size]
        image_paths = image_paths[: opt.max_dataset_size]
        instance_paths = instance_paths[: opt.max_dataset_size]
        if not opt.no_pairing_check:
            for path1, path2 in zip(label_paths, image_paths):
                assert self.paths_match(path1, path2), (
                    "The label-image pair (%s, %s) do not look like the right pair because the filenames are quite different. Are you sure about the pairing? Please see data/pix2pix_dataset.py to see what is going on, and use --no_pairing_check to bypass this."
                    % (path1, path2)
                )
        self.label_paths = label_paths
        self.image_paths = image_paths
        self.instance_paths = instance_paths
        self.dataset_size = len(self.label_paths)

    def get_paths(self, opt):
        """Return (label_paths, image_paths, instance_paths); must be overridden.

        Raises NotImplementedError rather than `assert False`, which is stripped
        under `python -O` and would silently yield an empty dataset.
        """
        raise NotImplementedError("A subclass of Pix2pixDataset must override self.get_paths(self, opt)")

    def paths_match(self, path1, path2):
        """True when both paths share the same file name without extension."""
        filename1_without_ext = os.path.splitext(os.path.basename(path1))[0]
        filename2_without_ext = os.path.splitext(os.path.basename(path2))[0]
        return filename1_without_ext == filename2_without_ext

    def __getitem__(self, index):
        # Label Image
        label_path = self.label_paths[index]
        label = Image.open(label_path)
        params = get_params(self.opt, label.size)
        transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
        label_tensor = transform_label(label) * 255.0
        label_tensor[label_tensor == 255] = self.opt.label_nc  # 'unknown' is opt.label_nc
        # input image (real images)
        image_path = self.image_paths[index]
        assert self.paths_match(
            label_path, image_path
        ), "The label_path %s and image_path %s don't match." % (label_path, image_path)
        image = Image.open(image_path)
        image = image.convert("RGB")
        transform_image = get_transform(self.opt, params)
        image_tensor = transform_image(image)
        # if using instance maps
        if self.opt.no_instance:
            instance_tensor = 0
        else:
            instance_path = self.instance_paths[index]
            instance = Image.open(instance_path)
            if instance.mode == "L":
                instance_tensor = transform_label(instance) * 255
                instance_tensor = instance_tensor.long()
            else:
                instance_tensor = transform_label(instance)
        input_dict = {
            "label": label_tensor,
            "instance": instance_tensor,
            "image": image_tensor,
            "path": image_path,
        }
        # Give subclasses a chance to modify the final output
        self.postprocess(input_dict)
        return input_dict

    def postprocess(self, input_dict):
        return input_dict

    def __len__(self):
        return self.dataset_size
================================================
FILE: Face_Enhancement/models/__init__.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import importlib
import torch
def find_model_using_name(model_name):
    """Import models/<model_name>_model.py and return its nn.Module subclass.

    The class name must equal model_name without underscores plus "model",
    compared case-insensitively. Exits with status 1 if none is found.
    """
    import sys

    model_filename = "models." + model_name + "_model"
    modellib = importlib.import_module(model_filename)

    model = None
    target_model_name = model_name.replace("_", "") + "model"
    for name, cls in modellib.__dict__.items():
        # isinstance(cls, type) guards issubclass(), which raises TypeError for
        # functions or instances that happen to share the target name.
        if (
            name.lower() == target_model_name.lower()
            and isinstance(cls, type)
            and issubclass(cls, torch.nn.Module)
        ):
            model = cls

    if model is None:
        print(
            "In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s in lowercase."
            % (model_filename, target_model_name)
        )
        # Non-zero status: the previous exit(0) reported success on failure.
        sys.exit(1)
    return model
def get_option_setter(model_name):
    """Return the modify_commandline_options hook of the model named `model_name`."""
    return find_model_using_name(model_name).modify_commandline_options
def create_model(opt):
    """Instantiate the model class selected by opt.model with `opt`."""
    model_cls = find_model_using_name(opt.model)
    model = model_cls(opt)
    print("model [%s] was created" % (type(model).__name__))
    return model
================================================
FILE: Face_Enhancement/models/networks/__init__.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
from models.networks.base_network import BaseNetwork
from models.networks.generator import *
from models.networks.encoder import *
import util.util as util
def find_network_using_name(target_network_name, filename):
    """Return class <target_network_name><filename> from models.networks.<filename>."""
    network = util.find_class_in_module(target_network_name + filename, "models.networks." + filename)
    assert issubclass(network, BaseNetwork), "Class %s should be a subclass of BaseNetwork" % network
    return network
def modify_commandline_options(parser, is_train):
    """Let the selected generator, discriminator (training only) and encoder add options."""
    known, _unused = parser.parse_known_args()
    generator_cls = find_network_using_name(known.netG, "generator")
    parser = generator_cls.modify_commandline_options(parser, is_train)
    if is_train:
        discriminator_cls = find_network_using_name(known.netD, "discriminator")
        parser = discriminator_cls.modify_commandline_options(parser, is_train)
    encoder_cls = find_network_using_name("conv", "encoder")
    return encoder_cls.modify_commandline_options(parser, is_train)
def create_network(cls, opt):
    """Instantiate `cls(opt)`, print its size, move it to CUDA if requested, init weights.

    Raises:
        RuntimeError: opt.gpu_ids is non-empty but CUDA is unavailable.
    """
    net = cls(opt)
    net.print_network()
    if len(opt.gpu_ids) > 0:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not torch.cuda.is_available():
            raise RuntimeError("GPU ids %s were requested but CUDA is not available" % (opt.gpu_ids,))
        net.cuda()
    net.init_weights(opt.init_type, opt.init_variance)
    return net
def define_G(opt):
    """Build the generator class named by opt.netG."""
    return create_network(find_network_using_name(opt.netG, "generator"), opt)
def define_D(opt):
    """Build the discriminator class named by opt.netD."""
    return create_network(find_network_using_name(opt.netD, "discriminator"), opt)
def define_E(opt):
    """Build the encoder; only the "conv" encoder type exists."""
    return create_network(find_network_using_name("conv", "encoder"), opt)
================================================
FILE: Face_Enhancement/models/networks/architecture.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.nn.utils.spectral_norm as spectral_norm
from models.networks.normalization import SPADE
# ResNet block that uses SPADE.
# It differs from the ResNet block of pix2pixHD in that
# it takes in the segmentation map as input, learns the skip connection if necessary,
# and applies normalization first and then convolution.
# This architecture seemed like a standard architecture for unconditional or
# class-conditional GAN architecture using residual block.
# The code was inspired from https://github.com/LMescheder/GAN_stability.
class SPADEResnetBlock(nn.Module):
    """Residual block whose normalization layers are SPADE modules.

    Main path: SPADE norm -> leaky ReLU -> 3x3 conv, twice. Shortcut: a
    SPADE-normalized 1x1 conv without bias when fin != fout, else identity.
    """

    def __init__(self, fin, fout, opt):
        super().__init__()
        # Attributes
        self.learned_shortcut = fin != fout
        fmiddle = min(fin, fout)
        self.opt = opt
        # create conv layers
        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
        # apply spectral norm if specified
        if "spectral" in opt.norm_G:
            self.conv_0 = spectral_norm(self.conv_0)
            self.conv_1 = spectral_norm(self.conv_1)
            if self.learned_shortcut:
                self.conv_s = spectral_norm(self.conv_s)
        # define normalization layers
        # norm_G minus the "spectral" token is passed to SPADE as its config string.
        spade_config_str = opt.norm_G.replace("spectral", "")
        self.norm_0 = SPADE(spade_config_str, fin, opt.semantic_nc, opt)
        self.norm_1 = SPADE(spade_config_str, fmiddle, opt.semantic_nc, opt)
        if self.learned_shortcut:
            self.norm_s = SPADE(spade_config_str, fin, opt.semantic_nc, opt)

    # note the resnet block with SPADE also takes in |seg|,
    # the semantic segmentation map as input
    def forward(self, x, seg, degraded_image):
        """Return shortcut(x) + residual(x); `seg` and `degraded_image` condition every SPADE norm."""
        x_s = self.shortcut(x, seg, degraded_image)
        dx = self.conv_0(self.actvn(self.norm_0(x, seg, degraded_image)))
        dx = self.conv_1(self.actvn(self.norm_1(dx, seg, degraded_image)))
        out = x_s + dx
        return out

    def shortcut(self, x, seg, degraded_image):
        """Learned SPADE-normalized 1x1 projection when channels change, else identity."""
        if self.learned_shortcut:
            x_s = self.conv_s(self.norm_s(x, seg, degraded_image))
        else:
            x_s = x
        return x_s

    def actvn(self, x):
        """Leaky ReLU with negative slope 0.2."""
        return F.leaky_relu(x, 2e-1)
# ResNet block used in pix2pixHD
# We keep the same architecture as pix2pixHD.
class ResnetBlock(nn.Module):
    """pix2pixHD-style residual block: x + norm(conv(pad(act(norm(conv(pad(x))))))).

    Args:
        dim: channel count (input and output).
        norm_layer: callable wrapping each conv (e.g. a spectral/instance norm factory).
        activation: activation module; defaults to a fresh nn.ReLU(False) per block.
        kernel_size: conv kernel size; reflection padding keeps spatial size for odd sizes.
    """

    def __init__(self, dim, norm_layer, activation=None, kernel_size=3):
        super().__init__()
        # A module instance as a default argument would be created once and
        # shared by every block; build a fresh one per instance instead.
        if activation is None:
            activation = nn.ReLU(False)
        pw = (kernel_size - 1) // 2
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(pw),
            norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size)),
            activation,
            nn.ReflectionPad2d(pw),
            norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size)),
        )

    def forward(self, x):
        y = self.conv_block(x)
        out = x + y
        return out
# VGG architecter, used for the perceptual loss using a pretrained VGG network
class VGG19(torch.nn.Module):
    """Frozen (by default) pretrained VGG19 feature extractor split into five stages.

    The stages cover torchvision's `features` indices [0,2), [2,7), [7,12),
    [12,21) and [21,30). forward() returns the output of each stage as a list.
    """

    def __init__(self, requires_grad=False):
        super().__init__()
        features = torchvision.models.vgg19(pretrained=True).features
        stage_bounds = ((0, 2), (2, 7), (7, 12), (12, 21), (21, 30))
        for stage_idx, (start, stop) in enumerate(stage_bounds, start=1):
            stage = torch.nn.Sequential()
            for layer_idx in range(start, stop):
                # Keep the original feature index as the submodule name.
                stage.add_module(str(layer_idx), features[layer_idx])
            setattr(self, "slice%d" % stage_idx, stage)
        if not requires_grad:
            for p in self.parameters():
                p.requires_grad = False

    def forward(self, X):
        activations = []
        h = X
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            h = stage(h)
            activations.append(h)
        return activations
class SPADEResnetBlock_non_spade(nn.Module):
    """Residual block with the same layers as SPADEResnetBlock but no normalization in forward.

    forward() ignores `seg` and `degraded_image`: main path is leaky ReLU ->
    3x3 conv, twice; shortcut is a plain 1x1 conv when fin != fout.
    """

    def __init__(self, fin, fout, opt):
        super().__init__()
        # Attributes
        self.learned_shortcut = fin != fout
        fmiddle = min(fin, fout)
        self.opt = opt
        # create conv layers
        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
        # apply spectral norm if specified
        if "spectral" in opt.norm_G:
            self.conv_0 = spectral_norm(self.conv_0)
            self.conv_1 = spectral_norm(self.conv_1)
            if self.learned_shortcut:
                self.conv_s = spectral_norm(self.conv_s)
        # define normalization layers
        # NOTE(review): norm_0/norm_1/norm_s are never used in forward/shortcut;
        # presumably kept so checkpoints share keys with SPADEResnetBlock - confirm
        # before removing.
        spade_config_str = opt.norm_G.replace("spectral", "")
        self.norm_0 = SPADE(spade_config_str, fin, opt.semantic_nc, opt)
        self.norm_1 = SPADE(spade_config_str, fmiddle, opt.semantic_nc, opt)
        if self.learned_shortcut:
            self.norm_s = SPADE(spade_config_str, fin, opt.semantic_nc, opt)

    # note the resnet block with SPADE also takes in |seg|,
    # the semantic segmentation map as input
    def forward(self, x, seg, degraded_image):
        """Return shortcut(x) + conv_1(act(conv_0(act(x)))); `seg`/`degraded_image` are unused."""
        x_s = self.shortcut(x, seg, degraded_image)
        dx = self.conv_0(self.actvn(x))
        dx = self.conv_1(self.actvn(dx))
        out = x_s + dx
        return out

    def shortcut(self, x, seg, degraded_image):
        """Un-normalized 1x1 projection when channels change, else identity."""
        if self.learned_shortcut:
            x_s = self.conv_s(x)
        else:
            x_s = x
        return x_s

    def actvn(self, x):
        """Leaky ReLU with negative slope 0.2."""
        return F.leaky_relu(x, 2e-1)
================================================
FILE: Face_Enhancement/models/networks/base_network.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch.nn as nn
from torch.nn import init
class BaseNetwork(nn.Module):
    """Shared base for networks: parameter-count report and configurable weight init."""

    def __init__(self):
        super().__init__()

    @staticmethod
    def modify_commandline_options(parser, is_train):
        return parser

    def print_network(self):
        """Print the class name and parameter count in millions."""
        net = self[0] if isinstance(self, list) else self
        total = sum(p.numel() for p in net.parameters())
        print(
            "Network [%s] was created. Total number of parameters: %.1f million. "
            "To see the architecture, do print(network)." % (type(net).__name__, total / 1000000)
        )

    def init_weights(self, init_type="normal", gain=0.02):
        """Initialise BatchNorm2d, Conv* and Linear submodules in place.

        BatchNorm2d weights ~ N(1, gain), biases 0. Conv/Linear weights use
        `init_type` ("normal", "xavier", "xavier_uniform", "kaiming",
        "orthogonal", or "none" for PyTorch's reset_parameters); biases 0.

        Raises:
            NotImplementedError: unknown init_type on a Conv/Linear module.
        """

        def init_func(m):
            classname = type(m).__name__
            if "BatchNorm2d" in classname:
                if getattr(m, "weight", None) is not None:
                    init.normal_(m.weight.data, 1.0, gain)
                if getattr(m, "bias", None) is not None:
                    init.constant_(m.bias.data, 0.0)
                return
            if not hasattr(m, "weight"):
                return
            if "Conv" not in classname and "Linear" not in classname:
                return
            if init_type == "normal":
                init.normal_(m.weight.data, 0.0, gain)
            elif init_type == "xavier":
                init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == "xavier_uniform":
                init.xavier_uniform_(m.weight.data, gain=1.0)
            elif init_type == "kaiming":
                init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
            elif init_type == "orthogonal":
                init.orthogonal_(m.weight.data, gain=gain)
            elif init_type == "none":  # uses pytorch's default init method
                m.reset_parameters()
            else:
                raise NotImplementedError("initialization method [%s] is not implemented" % init_type)
            if getattr(m, "bias", None) is not None:
                init.constant_(m.bias.data, 0.0)

        self.apply(init_func)
        # apply() already visited every submodule; children that define their
        # own init_weights are initialised a second time here.
        for child in self.children():
            if hasattr(child, "init_weights"):
                child.init_weights(init_type, gain)
================================================
FILE: Face_Enhancement/models/networks/encoder.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from models.networks.base_network import BaseNetwork
from models.networks.normalization import get_nonspade_norm_layer
class ConvEncoder(BaseNetwork):
    """ Same architecture as the image discriminator """

    def __init__(self, opt):
        super().__init__()
        kw = 3
        pw = int(np.ceil((kw - 1.0) / 2))
        ndf = opt.ngf
        norm_layer = get_nonspade_norm_layer(opt, opt.norm_E)
        # Each stride-2 conv halves the spatial size.
        self.layer1 = norm_layer(nn.Conv2d(3, ndf, kw, stride=2, padding=pw))
        self.layer2 = norm_layer(nn.Conv2d(ndf * 1, ndf * 2, kw, stride=2, padding=pw))
        self.layer3 = norm_layer(nn.Conv2d(ndf * 2, ndf * 4, kw, stride=2, padding=pw))
        self.layer4 = norm_layer(nn.Conv2d(ndf * 4, ndf * 8, kw, stride=2, padding=pw))
        self.layer5 = norm_layer(nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw))
        if opt.crop_size >= 256:
            self.layer6 = norm_layer(nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw))
        # forward() always resizes its input to 256x256: six layers leave 4x4
        # maps, five leave 8x8. The FC input size must match, otherwise
        # crop_size < 256 fails with a shape mismatch.
        self.so = s0 = 4 if opt.crop_size >= 256 else 8
        self.fc_mu = nn.Linear(ndf * 8 * s0 * s0, 256)
        self.fc_var = nn.Linear(ndf * 8 * s0 * s0, 256)
        self.actvn = nn.LeakyReLU(0.2, False)
        self.opt = opt

    def forward(self, x):
        """Encode an RGB batch into (mu, logvar), each of shape (batch, 256)."""
        if x.size(2) != 256 or x.size(3) != 256:
            x = F.interpolate(x, size=(256, 256), mode="bilinear")
        x = self.layer1(x)
        x = self.layer2(self.actvn(x))
        x = self.layer3(self.actvn(x))
        x = self.layer4(self.actvn(x))
        x = self.layer5(self.actvn(x))
        if self.opt.crop_size >= 256:
            x = self.layer6(self.actvn(x))
        x = self.actvn(x)
        x = x.view(x.size(0), -1)
        mu = self.fc_mu(x)
        logvar = self.fc_var(x)
        return mu, logvar
================================================
FILE: Face_Enhancement/models/networks/generator.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.networks.base_network import BaseNetwork
from models.networks.normalization import get_nonspade_norm_layer
from models.networks.architecture import ResnetBlock as ResnetBlock
from models.networks.architecture import SPADEResnetBlock as SPADEResnetBlock
from models.networks.architecture import SPADEResnetBlock_non_spade as SPADEResnetBlock_non_spade
class SPADEGenerator(BaseNetwork):
    """SPADE generator conditioned on a parsing map and/or the degraded face.

    Generation starts from an ``sh x sw`` grid, which is either a projected random
    ``z`` (``opt.use_vae``) or a downsampled parsing map / degraded image. It is
    then repeatedly upsampled through residual blocks. ``opt.injection_layer``
    selects which stages receive SPADE conditioning: ``"all"`` or a single stage
    ``"1"``..``"6"``. The other stages use ``SPADEResnetBlock_non_spade``.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser.set_defaults(norm_G="spectralspadesyncbatch3x3")
        parser.add_argument(
            "--num_upsampling_layers",
            choices=("normal", "more", "most"),
            default="normal",
            help="If 'more', adds upsampling layer between the two middle resnet blocks. If 'most', also add one more upsampling + resnet layer at the end of the generator",
        )
        return parser

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        nf = opt.ngf

        self.sw, self.sh = self.compute_latent_vector_size(opt)
        print("The size of the latent vector size is [%d,%d]" % (self.sw, self.sh))

        if opt.use_vae:
            # In case of VAE, we will sample from random z vector
            self.fc = nn.Linear(opt.z_dim, 16 * nf * self.sw * self.sh)
        elif self.opt.no_parsing_map:
            # Deterministic start from the downsampled degraded RGB image
            self.fc = nn.Conv2d(3, 16 * nf, 3, padding=1)
        else:
            # Deterministic start from the downsampled segmentation (parsing) map
            self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1)

        # Attribute names (and hence state_dict keys) match the original layout.
        self.head_0 = self._resnet_block("1", 16 * nf, 16 * nf)
        self.G_middle_0 = self._resnet_block("2", 16 * nf, 16 * nf)
        self.G_middle_1 = self._resnet_block("2", 16 * nf, 16 * nf)
        self.up_0 = self._resnet_block("3", 16 * nf, 8 * nf)
        self.up_1 = self._resnet_block("4", 8 * nf, 4 * nf)
        self.up_2 = self._resnet_block("5", 4 * nf, 2 * nf)
        self.up_3 = self._resnet_block("6", 2 * nf, 1 * nf)

        final_nc = nf
        if opt.num_upsampling_layers == "most":
            self.up_4 = SPADEResnetBlock(1 * nf, nf // 2, opt)
            final_nc = nf // 2

        self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)
        self.up = nn.Upsample(scale_factor=2)

    def _resnet_block(self, stage, fin, fout):
        """Return a SPADE block if ``stage`` is selected by ``opt.injection_layer``, else a non-SPADE block."""
        if self.opt.injection_layer in ("all", stage):
            return SPADEResnetBlock(fin, fout, self.opt)
        return SPADEResnetBlock_non_spade(fin, fout, self.opt)

    def compute_latent_vector_size(self, opt):
        """Return ``(sw, sh)``: the starting grid size implied by ``load_size`` and the upsampling count."""
        if opt.num_upsampling_layers == "normal":
            num_up_layers = 5
        elif opt.num_upsampling_layers == "more":
            num_up_layers = 6
        elif opt.num_upsampling_layers == "most":
            num_up_layers = 7
        else:
            raise ValueError("opt.num_upsampling_layers [%s] not recognized" % opt.num_upsampling_layers)

        sw = opt.load_size // (2 ** num_up_layers)
        sh = round(sw / opt.aspect_ratio)
        return sw, sh

    def forward(self, input, degraded_image, z=None):
        """Generate an image from the segmentation map ``input`` and ``degraded_image``.

        ``z`` is only used when ``opt.use_vae``; if omitted it is sampled from a
        standard normal on the same device as ``input``.
        """
        seg = input

        if self.opt.use_vae:
            # we sample z from unit normal and reshape the tensor
            if z is None:
                # input.device works on CPU as well; input.get_device() returns -1 for
                # CPU tensors, which is not a valid ``device`` argument.
                z = torch.randn(input.size(0), self.opt.z_dim, dtype=torch.float32, device=input.device)
            x = self.fc(z)
            x = x.view(-1, 16 * self.opt.ngf, self.sh, self.sw)
        else:
            # we downsample segmap (or the degraded image) and run convolution
            if self.opt.no_parsing_map:
                x = F.interpolate(degraded_image, size=(self.sh, self.sw), mode="bilinear")
            else:
                x = F.interpolate(seg, size=(self.sh, self.sw), mode="nearest")
            x = self.fc(x)

        x = self.head_0(x, seg, degraded_image)

        x = self.up(x)
        x = self.G_middle_0(x, seg, degraded_image)

        if self.opt.num_upsampling_layers == "more" or self.opt.num_upsampling_layers == "most":
            x = self.up(x)

        x = self.G_middle_1(x, seg, degraded_image)

        x = self.up(x)
        x = self.up_0(x, seg, degraded_image)
        x = self.up(x)
        x = self.up_1(x, seg, degraded_image)
        x = self.up(x)
        x = self.up_2(x, seg, degraded_image)
        x = self.up(x)
        x = self.up_3(x, seg, degraded_image)

        if self.opt.num_upsampling_layers == "most":
            x = self.up(x)
            x = self.up_4(x, seg, degraded_image)

        x = self.conv_img(F.leaky_relu(x, 2e-1))
        # torch.tanh replaces the deprecated F.tanh; output lies in [-1, 1]
        x = torch.tanh(x)

        return x
class Pix2PixHDGenerator(BaseNetwork):
    """pix2pixHD-style residual encoder/decoder that maps the degraded image to the output.

    The segmentation input is ignored: ``forward`` runs only on ``degraded_image``.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser.add_argument(
            "--resnet_n_downsample", type=int, default=4, help="number of downsampling layers in netG"
        )
        parser.add_argument(
            "--resnet_n_blocks",
            type=int,
            default=9,
            help="number of residual blocks in the global generator network",
        )
        parser.add_argument(
            "--resnet_kernel_size", type=int, default=3, help="kernel size of the resnet block"
        )
        parser.add_argument(
            "--resnet_initial_kernel_size", type=int, default=7, help="kernel size of the first convolution"
        )
        return parser

    def __init__(self, opt):
        super().__init__()
        input_nc = 3

        norm_layer = get_nonspade_norm_layer(opt, opt.norm_G)
        activation = nn.ReLU(False)

        model = []

        # initial conv
        model += [
            nn.ReflectionPad2d(opt.resnet_initial_kernel_size // 2),
            norm_layer(nn.Conv2d(input_nc, opt.ngf, kernel_size=opt.resnet_initial_kernel_size, padding=0)),
            activation,
        ]

        # downsample: each step halves resolution and doubles channels
        mult = 1
        for _ in range(opt.resnet_n_downsample):
            model += [
                norm_layer(nn.Conv2d(opt.ngf * mult, opt.ngf * mult * 2, kernel_size=3, stride=2, padding=1)),
                activation,
            ]
            mult *= 2

        # resnet blocks
        for _ in range(opt.resnet_n_blocks):
            model += [
                ResnetBlock(
                    opt.ngf * mult,
                    norm_layer=norm_layer,
                    activation=activation,
                    kernel_size=opt.resnet_kernel_size,
                )
            ]

        # upsample: mirror of the downsampling path
        for _ in range(opt.resnet_n_downsample):
            nc_in = int(opt.ngf * mult)
            nc_out = int((opt.ngf * mult) / 2)
            model += [
                norm_layer(
                    nn.ConvTranspose2d(nc_in, nc_out, kernel_size=3, stride=2, padding=1, output_padding=1)
                ),
                activation,
            ]
            mult = mult // 2

        # final output conv. After the symmetric down/up path ``mult`` is back to 1,
        # so the input width is opt.ngf. (Reading ``nc_out`` here raised NameError
        # when resnet_n_downsample == 0.)
        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(opt.ngf * mult, opt.output_nc, kernel_size=7, padding=0),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*model)

    def forward(self, input, degraded_image, z=None):
        """Return the generator output for ``degraded_image``; ``input`` and ``z`` are unused."""
        return self.model(degraded_image)
================================================
FILE: Face_Enhancement/models/networks/normalization.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.networks.sync_batchnorm import SynchronizedBatchNorm2d
import torch.nn.utils.spectral_norm as spectral_norm
def get_nonspade_norm_layer(opt, norm_type="instance"):
    """Return a function that wraps a layer with the normalization named by ``norm_type``.

    ``norm_type`` may carry a ``"spectral"`` prefix, which applies spectral norm to
    the layer first. The remainder selects the following norm layer: ``"batch"``,
    ``"sync_batch"``, ``"instance"``, or ``"none"``/empty for no extra norm.

    Raises:
        ValueError: (when the returned function is called) for an unknown norm type.
    """

    # helper function to get # output channels of the previous layer
    def get_out_channel(layer):
        if hasattr(layer, "out_channels"):
            return getattr(layer, "out_channels")
        return layer.weight.size(0)

    # this function will be returned
    def add_norm_layer(layer):
        if norm_type.startswith("spectral"):
            layer = spectral_norm(layer)
            subnorm_type = norm_type[len("spectral"):]
        else:
            # Previously unbound for non-spectral types (e.g. "instance"),
            # which raised UnboundLocalError.
            subnorm_type = norm_type

        if subnorm_type == "none" or len(subnorm_type) == 0:
            return layer

        # remove bias in the previous layer, which is meaningless
        # since it has no effect after normalization
        if getattr(layer, "bias", None) is not None:
            delattr(layer, "bias")
            layer.register_parameter("bias", None)

        if subnorm_type == "batch":
            norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == "sync_batch":
            norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == "instance":
            norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
        else:
            raise ValueError("normalization layer %s is not recognized" % subnorm_type)

        return nn.Sequential(layer, norm_layer)

    return add_norm_layer
class SPADE(nn.Module):
    """Spatially-adaptive normalization conditioned on a segmentation map and the degraded image.

    ``config_text`` has the form ``spade<norm><k>x<k>`` (e.g. ``"spadeinstance3x3"``).
    ``<norm>`` is ``instance``, ``syncbatch`` or ``batch`` (used without affine
    params) and ``<k>`` is the conv kernel size of the modulation network.
    With ``opt.no_parsing_map`` only the degraded image conditions the
    modulation. Otherwise the resized segmap (``label_nc`` channels) is
    concatenated with the 3-channel degraded image.
    """

    def __init__(self, config_text, norm_nc, label_nc, opt):
        super().__init__()

        assert config_text.startswith("spade")
        # Raw string: "\D"/"\d" in a normal literal are invalid escapes (SyntaxWarning on 3.12+).
        parsed = re.search(r"spade(\D+)(\d)x\d", config_text)
        param_free_norm_type = parsed.group(1)
        ks = int(parsed.group(2))
        self.opt = opt
        if param_free_norm_type == "instance":
            self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == "syncbatch":
            self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == "batch":
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
        else:
            raise ValueError("%s is not a recognized param-free norm type in SPADE" % param_free_norm_type)

        # The dimension of the intermediate embedding space. Yes, hardcoded.
        nhidden = 128

        pw = ks // 2

        if self.opt.no_parsing_map:
            self.mlp_shared = nn.Sequential(nn.Conv2d(3, nhidden, kernel_size=ks, padding=pw), nn.ReLU())
        else:
            self.mlp_shared = nn.Sequential(
                nn.Conv2d(label_nc + 3, nhidden, kernel_size=ks, padding=pw), nn.ReLU()
            )
        self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)

    def forward(self, x, segmap, degraded_image):
        """Normalize ``x`` and modulate it with ``gamma``/``beta`` predicted from the conditioning inputs."""
        # Part 1. generate parameter-free normalized activations
        normalized = self.param_free_norm(x)

        # Part 2. produce scaling and bias conditioned on semantic map
        segmap = F.interpolate(segmap, size=x.size()[2:], mode="nearest")
        degraded_face = F.interpolate(degraded_image, size=x.size()[2:], mode="bilinear")

        if self.opt.no_parsing_map:
            actv = self.mlp_shared(degraded_face)
        else:
            actv = self.mlp_shared(torch.cat((segmap, degraded_face), dim=1))
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)

        # apply scale and bias
        out = normalized * (1 + gamma) + beta

        return out
================================================
FILE: Face_Enhancement/models/pix2pix_model.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import models.networks as networks
import util.util as util
class Pix2PixModel(torch.nn.Module):
    """Wrapper around the face-enhancement generator, discriminator and optional encoder.

    Holds ``netG``, ``netD`` (only when ``opt.isTrain``) and ``netE`` (only when
    ``opt.use_vae``), plus the training losses. All work goes through
    ``forward(data, mode)``.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        networks.modify_commandline_options(parser, is_train)
        return parser

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        self.FloatTensor = torch.cuda.FloatTensor if self.use_gpu() else torch.FloatTensor
        self.ByteTensor = torch.cuda.ByteTensor if self.use_gpu() else torch.ByteTensor

        self.netG, self.netD, self.netE = self.initialize_networks(opt)

        # set loss functions
        if opt.isTrain:
            self.criterionGAN = networks.GANLoss(opt.gan_mode, tensor=self.FloatTensor, opt=self.opt)
            self.criterionFeat = torch.nn.L1Loss()
            if not opt.no_vgg_loss:
                self.criterionVGG = networks.VGGLoss(self.opt.gpu_ids)
            if opt.use_vae:
                self.KLDLoss = networks.KLDLoss()

    # Entry point for all calls involving forward pass
    # of deep networks. We used this approach since DataParallel module
    # can't parallelize custom functions, we branch to different
    # routines based on |mode|.
    def forward(self, data, mode):
        """Dispatch on ``mode``.

        Modes are ``"generator"``, ``"discriminator"``, ``"encode_only"`` and
        ``"inference"`` (the last one runs under ``torch.no_grad``).

        Raises:
            ValueError: for any other ``mode``.
        """
        input_semantics, real_image, degraded_image = self.preprocess_input(data)

        if mode == "generator":
            g_loss, generated = self.compute_generator_loss(input_semantics, degraded_image, real_image)
            return g_loss, generated
        elif mode == "discriminator":
            d_loss = self.compute_discriminator_loss(input_semantics, degraded_image, real_image)
            return d_loss
        elif mode == "encode_only":
            z, mu, logvar = self.encode_z(real_image)
            return mu, logvar
        elif mode == "inference":
            with torch.no_grad():
                fake_image, _ = self.generate_fake(input_semantics, degraded_image, real_image)
            return fake_image
        else:
            raise ValueError("|mode| is invalid")

    def create_optimizers(self, opt):
        """Return ``(optimizer_G, optimizer_D)`` using Adam.

        With TTUR (the default), G uses ``lr / 2`` and D uses ``lr * 2``.
        ``optimizer_D`` is ``None`` when not training, because ``netD`` does not
        exist then. Previously this case raised NameError.
        """
        G_params = list(self.netG.parameters())
        if opt.use_vae:
            G_params += list(self.netE.parameters())

        beta1, beta2 = opt.beta1, opt.beta2
        if opt.no_TTUR:
            G_lr, D_lr = opt.lr, opt.lr
        else:
            G_lr, D_lr = opt.lr / 2, opt.lr * 2

        optimizer_G = torch.optim.Adam(G_params, lr=G_lr, betas=(beta1, beta2))
        optimizer_D = None
        if opt.isTrain:
            D_params = list(self.netD.parameters())
            optimizer_D = torch.optim.Adam(D_params, lr=D_lr, betas=(beta1, beta2))

        return optimizer_G, optimizer_D

    def save(self, epoch):
        util.save_network(self.netG, "G", epoch, self.opt)
        util.save_network(self.netD, "D", epoch, self.opt)
        if self.opt.use_vae:
            util.save_network(self.netE, "E", epoch, self.opt)

    ############################################################################
    # Private helper methods
    ############################################################################

    def initialize_networks(self, opt):
        """Build netG/netD/netE and load weights when testing or resuming training."""
        netG = networks.define_G(opt)
        netD = networks.define_D(opt) if opt.isTrain else None
        netE = networks.define_E(opt) if opt.use_vae else None

        if not opt.isTrain or opt.continue_train:
            netG = util.load_network(netG, "G", opt.which_epoch, opt)
            if opt.isTrain:
                netD = util.load_network(netD, "D", opt.which_epoch, opt)
            if opt.use_vae:
                netE = util.load_network(netE, "E", opt.which_epoch, opt)

        return netG, netD, netE

    # preprocess the input, such as moving the tensors to GPUs
    # |data|: dictionary of the input data
    def preprocess_input(self, data):
        """Move the batch to the GPU (if used) and return ``(label, real_image, degraded_image)``.

        While testing there is no separate degraded image: the input ``image`` is
        the degraded face, so it fills both image slots.
        """
        if not self.opt.isTrain:
            if self.use_gpu():
                data["label"] = data["label"].cuda()
                data["image"] = data["image"].cuda()
            return data["label"], data["image"], data["image"]

        if self.use_gpu():
            data["label"] = data["label"].cuda()
            data["degraded_image"] = data["degraded_image"].cuda()
            data["image"] = data["image"].cuda()

        # The label map is passed through unchanged (no one-hot conversion here).
        return data["label"], data["image"], data["degraded_image"]

    def compute_generator_loss(self, input_semantics, degraded_image, real_image):
        G_losses = {}

        fake_image, KLD_loss = self.generate_fake(
            input_semantics, degraded_image, real_image, compute_kld_loss=self.opt.use_vae
        )

        if self.opt.use_vae:
            G_losses["KLD"] = KLD_loss

        pred_fake, pred_real = self.discriminate(input_semantics, fake_image, real_image)

        G_losses["GAN"] = self.criterionGAN(pred_fake, True, for_discriminator=False)

        if not self.opt.no_ganFeat_loss:
            num_D = len(pred_fake)
            GAN_Feat_loss = self.FloatTensor(1).fill_(0)
            for i in range(num_D):  # for each discriminator
                # last output is the final prediction, so we exclude it
                num_intermediate_outputs = len(pred_fake[i]) - 1
                for j in range(num_intermediate_outputs):  # for each layer output
                    unweighted_loss = self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach())
                    GAN_Feat_loss += unweighted_loss * self.opt.lambda_feat / num_D
            G_losses["GAN_Feat"] = GAN_Feat_loss

        if not self.opt.no_vgg_loss:
            G_losses["VGG"] = self.criterionVGG(fake_image, real_image) * self.opt.lambda_vgg

        return G_losses, fake_image

    def compute_discriminator_loss(self, input_semantics, degraded_image, real_image):
        D_losses = {}
        # The generator is not updated in this step, so its graph is not needed.
        with torch.no_grad():
            fake_image, _ = self.generate_fake(input_semantics, degraded_image, real_image)
            fake_image = fake_image.detach()
            fake_image.requires_grad_()

        pred_fake, pred_real = self.discriminate(input_semantics, fake_image, real_image)

        D_losses["D_Fake"] = self.criterionGAN(pred_fake, False, for_discriminator=True)
        D_losses["D_real"] = self.criterionGAN(pred_real, True, for_discriminator=True)

        return D_losses

    def encode_z(self, real_image):
        mu, logvar = self.netE(real_image)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar

    def generate_fake(self, input_semantics, degraded_image, real_image, compute_kld_loss=False):
        """Run netG (with a z encoded from ``real_image`` when ``opt.use_vae``); return ``(fake, KLD or None)``."""
        z = None
        KLD_loss = None
        if self.opt.use_vae:
            z, mu, logvar = self.encode_z(real_image)
            if compute_kld_loss:
                KLD_loss = self.KLDLoss(mu, logvar) * self.opt.lambda_kld

        fake_image = self.netG(input_semantics, degraded_image, z=z)

        assert (
            not compute_kld_loss
        ) or self.opt.use_vae, "You cannot compute KLD loss if opt.use_vae == False"

        return fake_image, KLD_loss

    # Given fake and real image, return the prediction of discriminator
    # for each fake and real image.

    def discriminate(self, input_semantics, fake_image, real_image):
        if self.opt.no_parsing_map:
            fake_concat = fake_image
            real_concat = real_image
        else:
            fake_concat = torch.cat([input_semantics, fake_image], dim=1)
            real_concat = torch.cat([input_semantics, real_image], dim=1)

        # In Batch Normalization, the fake and real images are
        # recommended to be in the same batch to avoid disparate
        # statistics in fake and real images.
        # So both fake and real images are fed to D all at once.
        fake_and_real = torch.cat([fake_concat, real_concat], dim=0)

        discriminator_out = self.netD(fake_and_real)

        pred_fake, pred_real = self.divide_pred(discriminator_out)

        return pred_fake, pred_real

    # Take the prediction of fake and real images from the combined batch
    def divide_pred(self, pred):
        """Split a batch-concatenated prediction into its first (fake) and second (real) halves."""
        # the prediction contains the intermediate outputs of multiscale GAN,
        # so it's usually a list
        if isinstance(pred, list):
            fake = []
            real = []
            for p in pred:
                fake.append([tensor[: tensor.size(0) // 2] for tensor in p])
                real.append([tensor[tensor.size(0) // 2 :] for tensor in p])
        else:
            fake = pred[: pred.size(0) // 2]
            real = pred[pred.size(0) // 2 :]

        return fake, real

    def get_edges(self, t):
        """Return a float mask marking positions whose value differs from a horizontal/vertical neighbour."""
        edge = self.ByteTensor(t.size()).zero_()
        edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1])
        edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1])
        edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
        edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
        return edge.float()

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps.mul(std) + mu

    def use_gpu(self):
        return len(self.opt.gpu_ids) > 0
================================================
FILE: Face_Enhancement/options/__init__.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
================================================
FILE: Face_Enhancement/options/base_options.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
import argparse
import os
from util import util
import torch
import models
import data
import pickle
class BaseOptions:
    """Command-line options shared by training and testing.

    Subclasses extend ``initialize`` and set ``self.isTrain``.
    """

    def __init__(self):
        self.initialized = False

    def initialize(self, parser):
        """Register the shared arguments on ``parser`` and return it."""
        # experiment specifics
        parser.add_argument(
            "--name",
            type=str,
            default="label2coco",
            help="name of the experiment. It decides where to store samples and models",
        )

        parser.add_argument(
            "--gpu_ids", type=str, default="0", help="gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU"
        )
        parser.add_argument(
            "--checkpoints_dir", type=str, default="./checkpoints", help="models are saved here"
        )
        parser.add_argument("--model", type=str, default="pix2pix", help="which model to use")
        parser.add_argument(
            "--norm_G",
            type=str,
            default="spectralinstance",
            help="instance normalization or batch normalization",
        )
        parser.add_argument(
            "--norm_D",
            type=str,
            default="spectralinstance",
            help="instance normalization or batch normalization",
        )
        parser.add_argument(
            "--norm_E",
            type=str,
            default="spectralinstance",
            help="instance normalization or batch normalization",
        )
        parser.add_argument("--phase", type=str, default="train", help="train, val, test, etc")

        # input/output sizes
        parser.add_argument("--batchSize", type=int, default=1, help="input batch size")
        parser.add_argument(
            "--preprocess_mode",
            type=str,
            default="scale_width_and_crop",
            help="scaling and cropping of images at load time.",
            choices=(
                "resize_and_crop",
                "crop",
                "scale_width",
                "scale_width_and_crop",
                "scale_shortside",
                "scale_shortside_and_crop",
                "fixed",
                "none",
                "resize",
            ),
        )
        parser.add_argument(
            "--load_size",
            type=int,
            default=1024,
            help="Scale images to this size. The final image will be cropped to --crop_size.",
        )
        parser.add_argument(
            "--crop_size",
            type=int,
            default=512,
            help="Crop to the width of crop_size (after initially scaling the images to load_size.)",
        )
        parser.add_argument(
            "--aspect_ratio",
            type=float,
            default=1.0,
            help="The ratio width/height. The final height of the load image will be crop_size/aspect_ratio",
        )
        parser.add_argument(
            "--label_nc",
            type=int,
            default=182,
            # help previously named a non-existent flag (--contain_dopntcare_label)
            help="# of input label classes without unknown class. If you have unknown class as class label, specify --contain_dontcare_label.",
        )
        parser.add_argument(
            "--contain_dontcare_label",
            action="store_true",
            help="if the label map contains dontcare label (dontcare=255)",
        )
        parser.add_argument("--output_nc", type=int, default=3, help="# of output image channels")

        # for setting inputs
        parser.add_argument("--dataroot", type=str, default="./datasets/cityscapes/")
        parser.add_argument("--dataset_mode", type=str, default="coco")
        parser.add_argument(
            "--serial_batches",
            action="store_true",
            help="if true, takes images in order to make batches, otherwise takes them randomly",
        )
        parser.add_argument(
            "--no_flip",
            action="store_true",
            help="if specified, do not flip the images for data argumentation",
        )
        parser.add_argument("--nThreads", default=0, type=int, help="# threads for loading data")
        parser.add_argument(
            "--max_dataset_size",
            type=int,
            default=sys.maxsize,
            help="Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.",
        )
        parser.add_argument(
            "--load_from_opt_file",
            action="store_true",
            help="load the options from checkpoints and use that as default",
        )
        parser.add_argument(
            "--cache_filelist_write",
            action="store_true",
            help="saves the current filelist into a text file, so that it loads faster",
        )
        parser.add_argument(
            "--cache_filelist_read", action="store_true", help="reads from the file list cache"
        )

        # for displays
        parser.add_argument("--display_winsize", type=int, default=400, help="display window size")

        # for generator
        parser.add_argument(
            "--netG", type=str, default="spade", help="selects model to use for netG (pix2pixhd | spade)"
        )
        parser.add_argument("--ngf", type=int, default=64, help="# of gen filters in first conv layer")
        parser.add_argument(
            "--init_type",
            type=str,
            default="xavier",
            help="network initialization [normal|xavier|kaiming|orthogonal]",
        )
        parser.add_argument(
            "--init_variance", type=float, default=0.02, help="variance of the initialization distribution"
        )
        parser.add_argument("--z_dim", type=int, default=256, help="dimension of the latent z vector")
        parser.add_argument(
            "--no_parsing_map", action="store_true", help="During training, we do not use the parsing map"
        )

        # for instance-wise features
        parser.add_argument(
            "--no_instance", action="store_true", help="if specified, do *not* add instance map as input"
        )
        parser.add_argument(
            "--nef", type=int, default=16, help="# of encoder filters in the first conv layer"
        )
        parser.add_argument("--use_vae", action="store_true", help="enable training with an image encoder.")
        parser.add_argument(
            "--tensorboard_log", action="store_true", help="use tensorboard to record the resutls"
        )

        parser.add_argument(
            "--old_face_folder", type=str, default="", help="The folder name of input old face"
        )
        parser.add_argument(
            "--old_face_label_folder", type=str, default="", help="The folder name of input old face label"
        )
        parser.add_argument("--injection_layer", type=str, default="all", help="")

        self.initialized = True
        return parser

    def gather_options(self):
        """Parse basic options, add model-specific ones, optionally merge the saved option file."""
        # NOTE(review): ``parser`` is only bound on the first call. A second call on the
        # same instance raises UnboundLocalError. Reusing ``self.parser`` is not a fix,
        # because the model option setter would register its arguments twice.
        if not self.initialized:
            parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
            parser = self.initialize(parser)

        # get the basic options
        opt, unknown = parser.parse_known_args()

        # modify model-related parser options
        model_name = opt.model
        model_option_setter = models.get_option_setter(model_name)
        parser = model_option_setter(parser, self.isTrain)

        # Dataset-specific option setters are intentionally not applied here.
        opt, unknown = parser.parse_known_args()

        # if there is opt_file, load it.
        # The previous default options will be overwritten
        if opt.load_from_opt_file:
            parser = self.update_options_from_file(parser, opt)

        opt = parser.parse_args()
        self.parser = parser
        return opt

    def print_options(self, opt):
        """Format all options, marking non-default values.

        Printing is disabled in this file, so the formatted text is currently discarded.
        """
        message = ""
        message += "----------------- Options ---------------\n"
        for k, v in sorted(vars(opt).items()):
            comment = ""
            default = self.parser.get_default(k)
            if v != default:
                comment = "\t[default: %s]" % str(default)
            message += "{:>25}: {:<30}{}\n".format(str(k), str(v), comment)
        message += "----------------- End -------------------"
        # print(message)

    def option_file_path(self, opt, makedir=False):
        """Return ``<checkpoints_dir>/<name>/opt`` (without extension), optionally creating the directory."""
        expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
        if makedir:
            util.mkdirs(expr_dir)
        file_name = os.path.join(expr_dir, "opt")
        return file_name

    def save_options(self, opt):
        """Write options as human-readable ``opt.txt`` and as pickled ``opt.pkl``."""
        file_name = self.option_file_path(opt, makedir=True)
        with open(file_name + ".txt", "wt") as opt_file:
            for k, v in sorted(vars(opt).items()):
                comment = ""
                default = self.parser.get_default(k)
                if v != default:
                    comment = "\t[default: %s]" % str(default)
                opt_file.write("{:>25}: {:<30}{}\n".format(str(k), str(v), comment))

        with open(file_name + ".pkl", "wb") as opt_file:
            pickle.dump(opt, opt_file)

    def update_options_from_file(self, parser, opt):
        """Use saved option values that differ from ``opt`` as the parser's new defaults."""
        new_opt = self.load_options(opt)
        for k, v in sorted(vars(opt).items()):
            if hasattr(new_opt, k) and v != getattr(new_opt, k):
                new_val = getattr(new_opt, k)
                parser.set_defaults(**{k: new_val})
        return parser

    def load_options(self, opt):
        """Unpickle the options saved by ``save_options``.

        Only load option files you trust: unpickling can execute arbitrary code.
        """
        file_name = self.option_file_path(opt, makedir=False)
        # ``with`` closes the file; the original ``pickle.load(open(...))`` leaked the handle.
        with open(file_name + ".pkl", "rb") as opt_file:
            new_opt = pickle.load(opt_file)
        return new_opt

    def parse(self, save=False):
        """Parse, post-process and return the options.

        Post-processing sets ``semantic_nc``, converts ``gpu_ids`` to a list of
        ints and selects the main GPU. ``save`` is unused; options are saved
        automatically when training.
        """
        opt = self.gather_options()
        opt.isTrain = self.isTrain  # train or test
        opt.contain_dontcare_label = False

        self.print_options(opt)
        if opt.isTrain:
            self.save_options(opt)

        # Set semantic_nc based on the option.
        # This will be convenient in many places
        opt.semantic_nc = (
            opt.label_nc + (1 if opt.contain_dontcare_label else 0) + (0 if opt.no_instance else 1)
        )

        # set gpu ids (negative ids, e.g. -1, mean CPU)
        str_ids = opt.gpu_ids.split(",")
        opt.gpu_ids = []
        for str_id in str_ids:
            int_id = int(str_id)
            if int_id >= 0:
                opt.gpu_ids.append(int_id)

        if len(opt.gpu_ids) > 0:
            print("The main GPU is ")
            print(opt.gpu_ids[0])
            torch.cuda.set_device(opt.gpu_ids[0])

        assert (
            len(opt.gpu_ids) == 0 or opt.batchSize % len(opt.gpu_ids) == 0
        ), "Batch size %d is wrong. It must be a multiple of # GPUs %d." % (opt.batchSize, len(opt.gpu_ids))

        self.opt = opt
        return self.opt
================================================
FILE: Face_Enhancement/options/test_options.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .base_options import BaseOptions
class TestOptions(BaseOptions):
    """Inference-time options.

    Adds the result directory, the checkpoint epoch to load and the number of
    images to process. Also switches the shared defaults to deterministic
    256x256 test settings.
    """

    def initialize(self, parser):
        BaseOptions.initialize(self, parser)

        parser.add_argument("--results_dir", type=str, default="./results/", help="saves results here.")
        parser.add_argument(
            "--which_epoch",
            type=str,
            default="latest",
            help="which epoch to load? set to latest to use latest cached model",
        )
        parser.add_argument("--how_many", type=int, default=float("inf"), help="how many test images to run")

        # Fixed-size, ordered, unflipped loading for reproducible test output.
        parser.set_defaults(
            preprocess_mode="scale_width_and_crop",
            crop_size=256,
            load_size=256,
            display_winsize=256,
            serial_batches=True,
            no_flip=True,
            phase="test",
        )

        self.isTrain = False
        return parser
================================================
FILE: Face_Enhancement/requirements.txt
================================================
torch>=1.0.0
torchvision
dominate>=2.3.1
wandb
dill
scikit-image
tensorboardX
scipy
opencv-python
================================================
FILE: Face_Enhancement/test_face.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
from collections import OrderedDict
import data
from options.test_options import TestOptions
from models.pix2pix_model import Pix2PixModel
from util.visualizer import Visualizer
import torchvision.utils as vutils
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

opt = TestOptions().parse()

dataloader = data.create_dataloader(opt)

model = Pix2PixModel(opt)
model.eval()

# Kept from the original script; not used below.
visualizer = Visualizer(opt)

# Each generated face is saved under <checkpoints_dir>/<name>/<results_dir>/each_img/
# using the file name of its input image.
single_save_url = os.path.join(opt.checkpoints_dir, opt.name, opt.results_dir, "each_img")
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs
os.makedirs(single_save_url, exist_ok=True)

for i, data_i in enumerate(dataloader):
    if i * opt.batchSize >= opt.how_many:
        break

    generated = model(data_i, mode="inference")

    img_path = data_i["path"]

    for b in range(generated.shape[0]):
        img_name = os.path.split(img_path[b])[-1]
        save_img_url = os.path.join(single_save_url, img_name)
        # Map the generator's [-1, 1] output to [0, 1] for save_image.
        vutils.save_image((generated[b] + 1) / 2, save_img_url)
================================================
FILE: Face_Enhancement/util/__init__.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
================================================
FILE: Face_Enhancement/util/iter_counter.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import time
import numpy as np
# Helper class that keeps track of training iterations
class IterationCounter:
    """Track training epochs/iterations and persist progress so training can resume.

    Progress lives in ``<checkpoints_dir>/<name>/iter.txt``, written with
    ``np.savetxt`` as two integers: the epoch and the number of samples
    already processed in it.
    """

    def __init__(self, opt, dataset_size):
        self.opt = opt
        self.dataset_size = dataset_size

        self.first_epoch = 1
        self.total_epochs = opt.niter + opt.niter_decay
        self.epoch_iter = 0  # samples processed within the current epoch
        self.iter_record_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, "iter.txt")

        if opt.isTrain and opt.continue_train:
            try:
                self.first_epoch, self.epoch_iter = np.loadtxt(
                    self.iter_record_path, delimiter=",", dtype=int
                )
                print("Resuming from epoch %d at iteration %d" % (self.first_epoch, self.epoch_iter))
            except (OSError, ValueError, TypeError):
                # Error sources:
                #   OSError    - missing or unreadable file
                #   ValueError - malformed content or the wrong number of values
                #   TypeError  - a single scalar, which cannot be unpacked
                # A bare ``except:`` would also swallow KeyboardInterrupt.
                print(
                    "Could not load iteration record at %s. Starting from beginning." % self.iter_record_path
                )

        self.total_steps_so_far = (self.first_epoch - 1) * dataset_size + self.epoch_iter

    # return the iterator of epochs for the training
    def training_epochs(self):
        return range(self.first_epoch, self.total_epochs + 1)

    def record_epoch_start(self, epoch):
        self.epoch_start_time = time.time()
        self.epoch_iter = 0
        self.last_iter_time = time.time()
        self.current_epoch = epoch

    def record_one_iteration(self):
        current_time = time.time()

        # the last remaining batch is dropped (see data/__init__.py),
        # so we can assume batch size is always opt.batchSize
        self.time_per_iter = (current_time - self.last_iter_time) / self.opt.batchSize
        self.last_iter_time = current_time
        self.total_steps_so_far += self.opt.batchSize
        self.epoch_iter += self.opt.batchSize

    def record_epoch_end(self):
        current_time = time.time()
        self.time_per_epoch = current_time - self.epoch_start_time
        print(
            "End of epoch %d / %d \t Time Taken: %d sec"
            % (self.current_epoch, self.total_epochs, self.time_per_epoch)
        )
        if self.current_epoch % self.opt.save_epoch_freq == 0:
            # Resume point is the start of the next epoch.
            np.savetxt(self.iter_record_path, (self.current_epoch + 1, 0), delimiter=",", fmt="%d")
            print("Saved current iteration count at %s." % self.iter_record_path)

    def record_current_iter(self):
        np.savetxt(self.iter_record_path, (self.current_epoch, self.epoch_iter), delimiter=",", fmt="%d")
        print("Saved current iteration count at %s." % self.iter_record_path)

    # The "< batchSize" checks fire once per period, because step counts advance by batchSize.
    def needs_saving(self):
        return (self.total_steps_so_far % self.opt.save_latest_freq) < self.opt.batchSize

    def needs_printing(self):
        return (self.total_steps_so_far % self.opt.print_freq) < self.opt.batchSize

    def needs_displaying(self):
        return (self.total_steps_so_far % self.opt.display_freq) < self.opt.batchSize
================================================
FILE: Face_Enhancement/util/util.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import re
import importlib
import torch
from argparse import Namespace
import numpy as np
from PIL import Image
import os
import argparse
import dill as pickle
def save_obj(obj, name):
    """Serialize ``obj`` to the file path ``name`` using the highest pickle protocol."""
    with open(name, "wb") as handle:
        pickle.dump(obj, handle, pickle.HIGHEST_PROTOCOL)
def load_obj(name):
    """Deserialize and return the object stored at the file path ``name``.

    Only use on trusted files: unpickling can execute arbitrary code.
    """
    with open(name, "rb") as stream:
        loaded = pickle.load(stream)
    return loaded
def copyconf(default_opt, **kwargs):
    """Return a shallow copy of ``default_opt`` as a Namespace with ``kwargs`` applied on top.

    Each override is printed as ``key value`` before it is set. The input
    object is left unchanged.
    """
    conf = argparse.Namespace(**vars(default_opt))
    for name, value in kwargs.items():
        print(name, value)
        setattr(conf, name, value)
    return conf
def tensor2im(image_tensor, imtype=np.uint8, normalize=True, tile=False):
    """Convert an image tensor (or a list of them) into a numpy image array.

    Args:
        image_tensor: a tensor of rank 2 (H x W), 3 (C x H x W) or 4
            (B x C x H x W), or a list of such tensors.
        imtype: dtype of the returned array.
        normalize: if True, map values from [-1, 1] to [0, 255]; otherwise
            multiply by 255 (values assumed in [0, 1]).
        tile: accepted for API compatibility; not used.

    Returns:
        H x W x C (or H x W for single-channel) array, B x H x W x C for a
        batch, or a list of arrays for list input.  Values are clipped to
        [0, 255] before the cast to ``imtype``.
    """
    if isinstance(image_tensor, list):
        # Forward every option so list items are converted like a single tensor.
        return [tensor2im(t, imtype, normalize, tile) for t in image_tensor]

    if image_tensor.dim() == 4:
        # Convert each batch item with the SAME options (previously the
        # recursion dropped ``imtype``/``normalize``), then stack on axis 0.
        images_np = [
            tensor2im(image_tensor[b], imtype, normalize, tile)
            for b in range(image_tensor.size(0))
        ]
        return np.concatenate([img.reshape(1, *img.shape) for img in images_np], axis=0)

    if image_tensor.dim() == 2:
        image_tensor = image_tensor.unsqueeze(0)
    image_numpy = image_tensor.detach().cpu().float().numpy()
    # C x H x W -> H x W x C
    image_numpy = np.transpose(image_numpy, (1, 2, 0))
    if normalize:
        image_numpy = (image_numpy + 1) / 2.0 * 255.0
    else:
        image_numpy = image_numpy * 255.0
    image_numpy = np.clip(image_numpy, 0, 255)
    if image_numpy.shape[2] == 1:
        image_numpy = image_numpy[:, :, 0]
    return image_numpy.astype(imtype)
def tensor2label(label_tensor, n_label, imtype=np.uint8, tile=False):
    """Turn a label / one-hot tensor into a coloured H x W x 3 numpy image.

    A rank-4 input is processed per batch item and stacked on axis 0.  A
    rank-1 input yields a black 64 x 64 x 3 uint8 placeholder.  With
    ``n_label == 0`` the tensor is treated as an ordinary image (tensor2im).
    A multi-channel tensor is reduced with argmax over dim 0 before
    colouring via ``Colorize(n_label)``.  ``tile`` is accepted but unused.
    """
    if label_tensor.dim() == 4:
        per_item = [
            tensor2label(label_tensor[b], n_label, imtype)
            for b in range(label_tensor.size(0))
        ]
        return np.concatenate([img.reshape(1, *img.shape) for img in per_item], axis=0)
    if label_tensor.dim() == 1:
        return np.zeros((64, 64, 3), dtype=np.uint8)
    if n_label == 0:
        return tensor2im(label_tensor, imtype)

    labels = label_tensor.cpu().float()
    if labels.size()[0] > 1:
        # one-hot -> index map (keeps a leading singleton channel)
        labels = labels.max(0, keepdim=True)[1]
    coloured = Colorize(n_label)(labels)
    return np.transpose(coloured.numpy(), (1, 2, 0)).astype(imtype)
def save_image(image_numpy, image_path, create_dir=False):
    """Save a numpy image to disk, converting a ".jpg" target to PNG.

    Args:
        image_numpy: H x W or H x W x C array; single-channel images are
            replicated to 3 channels before ``Image.fromarray``.
        image_path: destination path.  If its extension is ".jpg" it is
            replaced with ".png"; other paths are used as given.
        create_dir: if True, create the parent directory when missing.
    """
    if create_dir:
        parent = os.path.dirname(image_path)
        # dirname is "" for a bare filename, and os.makedirs("") raises.
        if parent:
            os.makedirs(parent, exist_ok=True)
    if len(image_numpy.shape) == 2:
        image_numpy = np.expand_dims(image_numpy, axis=2)
    if image_numpy.shape[2] == 1:
        image_numpy = np.repeat(image_numpy, 3, 2)
    image_pil = Image.fromarray(image_numpy)
    # Save as PNG: swap only the file extension, never ".jpg" that happens
    # to occur inside a directory name.
    root, ext = os.path.splitext(image_path)
    if ext == ".jpg":
        image_path = root + ".png"
    image_pil.save(image_path)
def mkdirs(paths):
    """Create one directory or several.

    Args:
        paths: a single path, or a list/tuple of paths; each is passed to
            ``mkdir`` (which is a no-op for existing paths).
    """
    # A list/tuple can never be a str, so no separate str check is needed;
    # tuples are now accepted too (they previously reached mkdir and failed).
    if isinstance(paths, (list, tuple)):
        for path in paths:
            mkdir(path)
    else:
        mkdir(paths)
def mkdir(path):
    """Create directory ``path`` (with parents) unless something already exists there.

    ``exist_ok=True`` closes the race where another process creates the
    directory between the existence check and ``os.makedirs``.
    """
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)
def atoi(text):
    """Return ``int(text)`` when ``text`` is all digits, else ``text`` unchanged."""
    if not text.isdigit():
        return text
    return int(text)
def natural_keys(text):
    """
    alist.sort(key=natural_keys) sorts in human order
    http://nedbatchelder.com/blog/200712/human_sorting.html
    (See Toothy's implementation in the comments)

    Splits ``text`` into alternating non-digit / digit runs and converts the
    digit runs to int, e.g. "img12b3" -> ["img", 12, "b", 3, ""].
    """
    # Raw string: "\d" in a normal literal is an invalid escape sequence
    # (DeprecationWarning / SyntaxWarning on recent Pythons).
    return [int(chunk) if chunk.isdigit() else chunk for chunk in re.split(r"(\d+)", text)]
def natural_sort(items):
    """Sort the list ``items`` in place in human ("natural") order; returns None."""
    sort_key = natural_keys
    items.sort(key=sort_key)
def str2bool(v):
    """Parse a command-line boolean for argparse ``type=``.

    Accepts (case-insensitively) yes/true/t/y/1 and no/false/f/n/0.  An
    actual ``bool`` (e.g. an argparse default) is returned unchanged instead
    of crashing on ``.lower()``.

    Raises:
        argparse.ArgumentTypeError: for any other value.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
def find_class_in_module(target_cls_name, module):
    """Return the attribute of ``module`` whose lowercase name matches ``target_cls_name``.

    Underscores are removed from ``target_cls_name`` and the comparison is
    case-insensitive; if several names match, the last one in the module's
    ``__dict__`` wins.

    Raises:
        SystemExit: with status 1 (failure) when no matching name exists.
    """
    import sys

    target_cls_name = target_cls_name.replace("_", "").lower()
    clslib = importlib.import_module(module)
    cls = None
    for name, clsobj in clslib.__dict__.items():
        if name.lower() == target_cls_name:
            cls = clsobj
    if cls is None:
        print(
            "In %s, there should be a class whose name matches %s in lowercase without underscore(_)"
            % (module, target_cls_name)
        )
        # Non-zero status: this is an error, and sys.exit (unlike the
        # ``exit`` builtin) is always available.
        sys.exit(1)
    return cls
def save_network(net, label, epoch, opt):
    """Save ``net``'s state dict to ``<checkpoints_dir>/<name>/<epoch>_net_<label>.pth``.

    The network is moved to CPU for saving and moved back to CUDA afterwards
    when ``opt.gpu_ids`` is non-empty and CUDA is available.
    """
    target = os.path.join(opt.checkpoints_dir, opt.name, "%s_net_%s.pth" % (epoch, label))
    cpu_state = net.cpu().state_dict()
    torch.save(cpu_state, target)
    wants_gpu = len(opt.gpu_ids) and torch.cuda.is_available()
    if wants_gpu:
        net.cuda()
def load_network(net, label, epoch, opt):
    """Load ``<checkpoints_dir>/<name>/<epoch>_net_<label>.pth`` into ``net`` if it exists.

    Returns:
        ``net`` (same object), with weights loaded when the file was found.
        A missing checkpoint is reported on stdout and ``net`` is returned
        unchanged.
    """
    save_filename = "%s_net_%s.pth" % (epoch, label)
    save_dir = os.path.join(opt.checkpoints_dir, opt.name)
    save_path = os.path.join(save_dir, save_filename)
    if not os.path.exists(save_path):
        print("Checkpoint %s not found; keeping the current weights of %s." % (save_path, label))
        return net
    # Load onto CPU first so checkpoints written from a GPU also load on
    # CPU-only machines; load_state_dict copies into net's own device.
    weights = torch.load(save_path, map_location="cpu")
    net.load_state_dict(weights)
    return net
###############################################################################
# Code from
# https://github.com/ycszen/pytorch-seg/blob/master/transform.py
# Modified so it complies with the Cityscapes label map colors
###############################################################################
def uint82bin(n, count=8):
    """returns the binary of integer n, count refers to amount of bits"""
    # Most significant bit first; count == 0 yields "".
    bits = []
    for shift in reversed(range(count)):
        bits.append(str((n >> shift) & 1))
    return "".join(bits)
class Colorize(object):
    """Map a label image to an RGB uint8 image using a colour map of ``n`` entries."""

    def __init__(self, n=35):
        palette = labelcolormap(n)
        self.cmap = torch.from_numpy(palette[:n])

    def __call__(self, gray_image):
        """Colour ``gray_image`` (labels read from ``gray_image[0]``).

        Returns a 3 x H x W uint8 tensor; labels without a colour-map entry
        stay black.
        """
        dims = gray_image.size()
        color_image = torch.zeros((3, dims[1], dims[2]), dtype=torch.uint8)
        labels = gray_image[0]
        for label, color in enumerate(self.cmap):
            mask = (label == labels).cpu()
            for channel in range(3):
                color_image[channel][mask] = color[channel]
        return color_image
================================================
FILE: Face_Enhancement/util/visualizer.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import ntpath
import time
from . import util
import scipy.misc
try:
from StringIO import StringIO # Python 2.7
except ImportError:
from io import BytesIO # Python 3.x
import torchvision.utils as vutils
from tensorboardX import SummaryWriter
import torch
import numpy as np
class Visualizer:
    """Record Face_Enhancement visuals and losses.

    With ``opt.tensorboard_log``: in training a tensorboardX SummaryWriter is
    created under ``<checkpoints_dir>/<name>/logs``; in testing image grids
    are saved under ``<checkpoints_dir>/<name>/<results_dir>``.  In training,
    printed losses are also appended to ``<checkpoints_dir>/<name>/loss_log.txt``.
    """

    def __init__(self, opt):
        """Store logging options and create the output locations ``opt`` asks for.

        ``self.writer`` only exists when tensorboard_log and isTrain are both set;
        ``self.log_name`` only exists when isTrain is set.
        """
        self.opt = opt
        # tf_log only takes effect during training.
        self.tf_log = opt.isTrain and opt.tf_log
        self.tensorboard_log = opt.tensorboard_log
        self.win_size = opt.display_winsize
        self.name = opt.name
        if self.tensorboard_log:
            if self.opt.isTrain:
                self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, "logs")
                if not os.path.exists(self.log_dir):
                    os.makedirs(self.log_dir)
                self.writer = SummaryWriter(log_dir=self.log_dir)
            else:
                # Test mode: no SummaryWriter; display_current_results saves PNGs into log_dir.
                print("hi :)")
                # os.path.join discards the earlier parts if results_dir is an absolute path.
                self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, opt.results_dir)
                if not os.path.exists(self.log_dir):
                    os.makedirs(self.log_dir)
        if opt.isTrain:
            # Append a timestamped header; print_current_errors appends to this file.
            self.log_name = os.path.join(opt.checkpoints_dir, opt.name, "loss_log.txt")
            with open(self.log_name, "a") as log_file:
                now = time.strftime("%c")
                log_file.write("================ Training Loss (%s) ================\n" % now)

    # |visuals|: dictionary of images to display or save
    def display_current_results(self, visuals, epoch, step):
        """Tile all tensors in ``visuals`` into one grid for ``step``.

        Only acts when tensorboard_log is set: in training the grid is added to
        TensorBoard, otherwise it is saved as ``<log_dir>/<step>.png``.
        ``epoch`` is not used.
        """
        all_tensor = []
        if self.tensorboard_log:
            for key, tensor in visuals.items():
                # (x + 1) / 2: presumably maps outputs from [-1, 1] to [0, 1] -- TODO confirm.
                all_tensor.append((tensor.data.cpu() + 1) / 2)
            # Concatenated on dim 0, so all visuals must match in every other dimension.
            output = torch.cat(all_tensor, 0)
            img_grid = vutils.make_grid(output, nrow=self.opt.batchSize, padding=0, normalize=False)
            if self.opt.isTrain:
                self.writer.add_image("Face_SPADE/training_samples", img_grid, step)
            else:
                vutils.save_image(
                    output,
                    os.path.join(self.log_dir, str(step) + ".png"),
                    nrow=self.opt.batchSize,
                    padding=0,
                    normalize=False,
                )

    # errors: dictionary of error labels and values
    def plot_current_errors(self, errors, step):
        """Log the loss values in ``errors`` as scalars at ``step``.

        NOTE(review): the tf_log branch reads ``self.tf``, which this class never
        assigns, so it would raise AttributeError if tf_log were enabled -- confirm
        whether that path is reachable.
        The tensorboard branch needs the keys "GAN_Feat", "VGG", "GAN", "D_Fake"
        and "D_real", and ``self.writer`` (created only in training mode).
        """
        if self.tf_log:
            for tag, value in errors.items():
                value = value.mean().float()
                summary = self.tf.Summary(value=[self.tf.Summary.Value(tag=tag, simple_value=value)])
                self.writer.add_summary(summary, step)
        if self.tensorboard_log:
            self.writer.add_scalar("Loss/GAN_Feat", errors["GAN_Feat"].mean().float(), step)
            self.writer.add_scalar("Loss/VGG", errors["VGG"].mean().float(), step)
            # "D" is the average of the discriminator's fake and real losses.
            self.writer.add_scalars(
                "Loss/GAN",
                {
                    "G": errors["GAN"].mean().float(),
                    "D": (errors["D_Fake"].mean().float() + errors["D_real"].mean().float()) / 2,
                },
                step,
            )

    # errors: same format as |errors| of plotCurrentErrors
    def print_current_errors(self, epoch, i, errors, t):
        """Print one line of mean losses and append it to ``self.log_name``.

        ``self.log_name`` is only set by ``__init__`` when ``opt.isTrain``.
        """
        message = "(epoch: %d, iters: %d, time: %.3f) " % (epoch, i, t)
        for k, v in errors.items():
            v = v.mean().float()
            message += "%s: %.3f " % (k, v)
        print(message)
        with open(self.log_name, "a") as log_file:
            log_file.write("%s\n" % message)

    def convert_visuals_to_numpy(self, visuals):
        """Replace each tensor in ``visuals`` with a numpy image, in place, and return the dict.

        "input_label" goes through util.tensor2label with ``opt.label_nc + 2``
        labels; every other entry goes through util.tensor2im.
        """
        for key, t in visuals.items():
            # Passed as ``tile``; util.tensor2im/tensor2label as written accept but ignore it.
            tile = self.opt.batchSize > 8
            if "input_label" == key:
                t = util.tensor2label(t, self.opt.label_nc + 2, tile=tile) ## B*H*W*C 0-255 numpy
            else:
                t = util.tensor2im(t, tile=tile)
            visuals[key] = t
        return visuals

    # save image to the disk
    def save_images(self, webpage, visuals, image_path):
        """Save each visual as ``<image_dir>/<label>/<name>.png`` and register it on ``webpage``.

        ``name`` is the extension-less basename of ``image_path[0]``; ``webpage``
        must provide get_image_dir, add_header and add_images.
        """
        visuals = self.convert_visuals_to_numpy(visuals)
        image_dir = webpage.get_image_dir()
        short_path = ntpath.basename(image_path[0])
        name = os.path.splitext(short_path)[0]
        webpage.add_header(name)
        ims = []
        txts = []
        links = []
        for label, image_numpy in visuals.items():
            image_name = os.path.join(label, "%s.png" % (name))
            save_path = os.path.join(image_dir, image_name)
            util.save_image(image_numpy, save_path, create_dir=True)
            ims.append(image_name)
            txts.append(label)
            links.append(image_name)
        webpage.add_images(ims, txts, links, width=self.win_size)
================================================
FILE: GUI.py
================================================
import numpy as np
import cv2
import PySimpleGUI as sg
import os.path
import argparse
import os
import sys
import shutil
from subprocess import call
def modify(image_filename=None, cv2_frame=None):
    """Run the four-stage restoration pipeline on a folder of images.

    Stage 1 (./Global) restores the whole photo, stage 2 (./Face_Detection)
    detects faces, stage 3 (./Face_Enhancement) enhances them and stage 4
    (./Face_Detection) blends them back.  Finished images end up in
    ``<output_folder>/final_output``.

    Options are parsed from ``sys.argv``; ``--input_folder`` defaults to
    ``image_filename``.  ``cv2_frame`` is unused.  The working directory is
    restored before returning (even on failure), so the function can be
    called repeatedly and relative paths keep resolving from the caller's
    directory.
    """

    def run_cmd(command):
        """Run ``command`` (an argument list, no shell) and exit the process on Ctrl-C."""
        try:
            return_code = call(command)
        except KeyboardInterrupt:
            print("Process interrupted")
            sys.exit(1)
        if return_code != 0:
            print("Command failed with exit code %d: %s" % (return_code, " ".join(command)))

    parser = argparse.ArgumentParser()
    parser.add_argument("--input_folder", type=str,
                        default=image_filename, help="Test images")
    parser.add_argument(
        "--output_folder",
        type=str,
        default="./output",
        help="Restored images, please use the absolute path",
    )
    parser.add_argument("--GPU", type=str, default="-1", help="0,1,2")
    parser.add_argument(
        "--checkpoint_name", type=str, default="Setting_9_epoch_100", help="choose which checkpoint"
    )
    # NOTE: the default is a non-empty string, so with_scratch is always truthy
    # and scratch detection always runs.
    parser.add_argument("--with_scratch", default="--with_scratch", action="store_true")
    opts = parser.parse_args()

    gpu1 = opts.GPU

    # resolve relative paths before changing directory
    opts.input_folder = os.path.abspath(opts.input_folder)
    opts.output_folder = os.path.abspath(opts.output_folder)
    os.makedirs(opts.output_folder, exist_ok=True)

    main_environment = os.getcwd()
    try:
        # Stage 1: Overall Quality Improve
        print("Running Stage 1: Overall restoration")
        os.chdir(os.path.join(main_environment, "Global"))
        stage_1_input_dir = opts.input_folder
        stage_1_output_dir = os.path.join(opts.output_folder, "stage_1_restore_output")
        os.makedirs(stage_1_output_dir, exist_ok=True)

        # Argument lists (no shell) keep paths with spaces or shell
        # metacharacters intact.
        if not opts.with_scratch:
            run_cmd([
                "python", "test.py", "--test_mode", "Full", "--Quality_restore",
                "--test_input", stage_1_input_dir,
                "--outputs_dir", stage_1_output_dir,
                "--gpu_ids", gpu1,
            ])
        else:
            mask_dir = os.path.join(stage_1_output_dir, "masks")
            new_input = os.path.join(mask_dir, "input")
            new_mask = os.path.join(mask_dir, "mask")
            run_cmd([
                "python", "detection.py",
                "--test_path", stage_1_input_dir,
                "--output_dir", mask_dir,
                "--input_size", "full_size",
                "--GPU", gpu1,
            ])
            run_cmd([
                "python", "test.py", "--Scratch_and_Quality_restore",
                "--test_input", new_input,
                "--test_mask", new_mask,
                "--outputs_dir", stage_1_output_dir,
                "--gpu_ids", gpu1,
            ])

        # Solve the case when there is no face in the old photo:
        # seed final_output with the stage-1 results.
        stage_1_results = os.path.join(stage_1_output_dir, "restored_image")
        stage_4_output_dir = os.path.join(opts.output_folder, "final_output")
        os.makedirs(stage_4_output_dir, exist_ok=True)
        for x in os.listdir(stage_1_results):
            shutil.copy(os.path.join(stage_1_results, x), stage_4_output_dir)

        print("Finish Stage 1 ...")
        print("\n")

        # Stage 2: Face Detection
        print("Running Stage 2: Face Detection")
        os.chdir(os.path.join(main_environment, "Face_Detection"))
        stage_2_input_dir = stage_1_results
        stage_2_output_dir = os.path.join(opts.output_folder, "stage_2_detection_output")
        os.makedirs(stage_2_output_dir, exist_ok=True)
        run_cmd([
            "python", "detect_all_dlib.py",
            "--url", stage_2_input_dir,
            "--save_url", stage_2_output_dir,
        ])
        print("Finish Stage 2 ...")
        print("\n")

        # Stage 3: Face Restore
        print("Running Stage 3: Face Enhancement")
        os.chdir(os.path.join(main_environment, "Face_Enhancement"))
        stage_3_input_mask = "./"
        stage_3_input_face = stage_2_output_dir
        stage_3_output_dir = os.path.join(opts.output_folder, "stage_3_face_output")
        os.makedirs(stage_3_output_dir, exist_ok=True)
        run_cmd([
            "python", "test_face.py",
            "--old_face_folder", stage_3_input_face,
            "--old_face_label_folder", stage_3_input_mask,
            "--tensorboard_log",
            "--name", opts.checkpoint_name,
            "--gpu_ids", gpu1,
            "--load_size", "256",
            "--label_nc", "18",
            "--no_instance",
            "--preprocess_mode", "resize",
            "--batchSize", "4",
            "--results_dir", stage_3_output_dir,
            "--no_parsing_map",
        ])
        print("Finish Stage 3 ...")
        print("\n")

        # Stage 4: Warp back
        print("Running Stage 4: Blending")
        os.chdir(os.path.join(main_environment, "Face_Detection"))
        stage_4_input_image_dir = stage_1_results
        stage_4_input_face_dir = os.path.join(stage_3_output_dir, "each_img")
        run_cmd([
            "python", "align_warp_back_multiple_dlib.py",
            "--origin_url", stage_4_input_image_dir,
            "--replace_url", stage_4_input_face_dir,
            "--save_url", stage_4_output_dir,
        ])
        print("Finish Stage 4 ...")
        print("\n")

        print("All the processing is done. Please check the results.")
    finally:
        # Each stage chdirs into a sub-project; always return to where we started.
        os.chdir(main_environment)
# --------------------------------- The GUI ---------------------------------

# First the window layout...
images_col = [[sg.Text('Input file:'), sg.In(enable_events=True, key='-IN FILE-'), sg.FileBrowse()],
              [sg.Button('Modify Photo', key='-MPHOTO-'), sg.Button('Exit')],
              [sg.Image(filename='', key='-IN-'), sg.Image(filename='', key='-OUT-')], ]
# ----- Full layout -----
layout = [[sg.VSeperator(), sg.Column(images_col)]]

# ----- Make the window -----
window = sg.Window('Bringing-old-photos-back-to-life', layout, grab_anywhere=True)

# modify() writes under ./output relative to the launch directory; resolve it
# now so directory changes made during processing cannot affect the lookup.
final_output_dir = os.path.abspath(os.path.join("output", "final_output"))

# ----- Run the Event Loop -----
prev_filename = filename = None
while True:
    event, values = window.read()
    if event in (None, 'Exit'):
        break
    elif event == '-MPHOTO-':
        if not filename:
            # No image selected yet.
            continue
        launch_dir = os.getcwd()
        try:
            # Restore the whole folder that contains the selected image.
            modify(os.path.dirname(filename))
            f_image = os.path.join(final_output_dir, os.path.basename(filename))
            image = cv2.imread(f_image)
            window['-OUT-'].update(data=cv2.imencode('.png', image)[1].tobytes())
        except Exception as err:
            print("Could not restore %s: %s" % (filename, err))
        finally:
            # modify() changes directory between stages; come back for the next run.
            os.chdir(launch_dir)
    elif event == '-IN FILE-':  # A single filename was chosen
        filename = values['-IN FILE-']
        if filename != prev_filename:
            prev_filename = filename
            try:
                image = cv2.imread(filename)
                window['-IN-'].update(data=cv2.imencode('.png', image)[1].tobytes())
            except Exception as err:
                print("Could not display %s: %s" % (filename, err))

# ----- Exit program -----
window.close()
================================================
FILE: Global/data/Create_Bigfile.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import struct
from PIL import Image
# Case-sensitive list of accepted image file extensions.
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    """Return True if ``filename`` ends with one of IMG_EXTENSIONS (case-sensitive)."""
    return filename.endswith(tuple(IMG_EXTENSIONS))
def make_dataset(dir):
    """Recursively collect the paths of image files below ``dir``."""
    assert os.path.isdir(dir), '%s is not a valid directory' % dir
    found = []
    for root, _subdirs, names in sorted(os.walk(dir)):
        found.extend(os.path.join(root, name) for name in names if is_image_file(name))
    return found
### Modify these 3 lines in your own environment
indir = "/home/ziyuwan/workspace/data/temp_old"
target_folders = ['VOC', 'Real_L_old', 'Real_RGB_old']
out_dir = "/home/ziyuwan/workspace/data/temp_old"
###
# NOTE: each folder is written to "<folder>.bigfile".  The training datasets in
# online_dataset_for_old_photos.py open "VOC_RGB_JPEGImages.bigfile",
# "Real_L_old.bigfile" and "Real_RGB_old.bigfile", so name the clean VOC
# folder (or rename its output) accordingly.

os.makedirs(out_dir, exist_ok=True)

# Layout per file (native-order struct 'i' integers):
#   [image count] then, per image: [name length][utf-8 name][data length][raw file bytes]
for target_folder in target_folders:
    curr_indir = os.path.join(indir, target_folder)
    curr_out_file = os.path.join(out_dir, '%s.bigfile' % target_folder)
    image_lists = make_dataset(curr_indir)
    image_lists.sort()
    with open(curr_out_file, 'wb') as wfid:
        # write total image number
        wfid.write(struct.pack('i', len(image_lists)))
        for i, img_path in enumerate(image_lists):
            # write file name first
            img_name = os.path.basename(img_path)
            img_name_bytes = img_name.encode('utf-8')
            wfid.write(struct.pack('i', len(img_name_bytes)))
            wfid.write(img_name_bytes)
            # then the encoded image bytes, unchanged
            with open(img_path, 'rb') as img_fid:
                img_bytes = img_fid.read()
            wfid.write(struct.pack('i', len(img_bytes)))
            wfid.write(img_bytes)
            if i % 1000 == 0:
                print('write %d images done' % i)
================================================
FILE: Global/data/Load_Bigfile.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import io
import os
import struct
from PIL import Image
class BigFileMemoryLoader(object):
    """Hold every record of a ``.bigfile`` container in memory.

    File layout (native-order struct 'i' integers): ``[count]`` followed by
    ``count`` records of ``[name_len][utf-8 name][data_len][data bytes]``.
    Indexing returns ``(name, RGB PIL image)``.
    """

    def __load_bigfile(self):
        """Read all names and raw image bytes from ``self.file_path``."""
        print('start load bigfile (%0.02f GB) into memory' % (os.path.getsize(self.file_path)/1024/1024/1024))
        with open(self.file_path, 'rb') as fid:
            self.img_num = struct.unpack('i', fid.read(4))[0]
            self.img_names = []
            self.img_bytes = []
            print('find total %d images' % self.img_num)
            for i in range(self.img_num):
                img_name_len = struct.unpack('i', fid.read(4))[0]
                img_name = fid.read(img_name_len).decode('utf-8')
                self.img_names.append(img_name)
                img_bytes_len = struct.unpack('i', fid.read(4))[0]
                self.img_bytes.append(fid.read(img_bytes_len))
                if i % 5000 == 0:
                    print('load %d images done' % i)
            print('load all %d images done' % self.img_num)

    def __init__(self, file_path):
        super(BigFileMemoryLoader, self).__init__()
        self.file_path = file_path
        self.__load_bigfile()

    def __getitem__(self, index):
        """Return ``(name, RGB image)`` for ``index``.

        If the record cannot be decoded the following records are tried
        (wrapping around).  The retry is iterative and bounded, so a file full
        of bad records raises RuntimeError instead of recursing until
        RecursionError.  An out-of-range ``index`` raises IndexError.
        """
        current = index
        for _ in range(max(self.img_num, 1)):
            try:
                img = Image.open(io.BytesIO(self.img_bytes[current])).convert('RGB')
                return self.img_names[current], img
            except Exception:
                print('Image read error for index %d: %s' % (current, self.img_names[current]))
                current = (current + 1) % self.img_num
        raise RuntimeError('No decodable image found in %s' % self.file_path)

    def __len__(self):
        return self.img_num
================================================
FILE: Global/data/__init__.py
================================================
================================================
FILE: Global/data/base_data_loader.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
class BaseDataLoader():
    """Minimal base class for data loaders; subclasses override ``initialize``/``load_data``."""

    def __init__(self):
        pass

    def initialize(self, opt):
        """Store the options object on the instance."""
        self.opt = opt

    def load_data(self):
        """Return the underlying loader; the base class has none.

        ``self`` was missing from the signature, so calling this on an
        instance raised TypeError.
        """
        return None
================================================
FILE: Global/data/base_dataset.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import random
class BaseDataset(data.Dataset):
    """Common base for this package's datasets; subclasses implement ``initialize``."""

    def __init__(self):
        super().__init__()

    def name(self):
        return 'BaseDataset'

    def initialize(self, opt):
        """Configuration hook for subclasses; a no-op here."""
        return None
def get_params(opt, size):
    """Draw random crop/flip parameters for an image of ``size`` = (width, height).

    The crop offset is sampled inside the size the image will have after the
    resize step chosen by ``opt.resize_or_crop``:
      * 'resize_and_crop': loadSize x loadSize;
      * 'scale_width_and_crop': shorter side -> loadSize, aspect kept;
      * anything else (e.g. 'crop_only'): original size.

    Returns:
        {'crop_pos': (x, y), 'flip': bool}
    """
    w, h = size
    new_w, new_h = w, h
    mode = opt.resize_or_crop
    if mode == 'resize_and_crop':
        new_w = new_h = opt.loadSize
    elif mode == 'scale_width_and_crop':  # we scale the shorter side into loadSize
        if w < h:
            new_w, new_h = opt.loadSize, opt.loadSize * h // w
        else:
            new_w, new_h = opt.loadSize * w // h, opt.loadSize
    x = random.randint(0, np.maximum(0, new_w - opt.fineSize))
    y = random.randint(0, np.maximum(0, new_h - opt.fineSize))
    return {'crop_pos': (x, y), 'flip': random.random() > 0.5}
def get_transform(opt, params, method=Image.BICUBIC, normalize=True):
    """Build the torchvision preprocessing pipeline described by ``opt``.

    Args:
        opt: options; reads resize_or_crop, loadSize, fineSize, isTrain,
            test_random_crop (test-time crop), n_downsample_global, netG and
            n_local_enhancers (resize_or_crop == 'none'), and no_flip (training).
        params: dict from get_params with 'crop_pos' and 'flip'.
        method: PIL resampling filter for resizing.
        normalize: if True, append Normalize mapping [0, 1] tensors to [-1, 1].

    Returns:
        transforms.Compose of the selected steps.
    """
    transform_list = []
    # transforms.Scale was a deprecated alias of transforms.Resize and is not
    # available in current torchvision; Resize has the same semantics.
    if 'resize' in opt.resize_or_crop:
        osize = [opt.loadSize, opt.loadSize]
        transform_list.append(transforms.Resize(osize, method))
    elif 'scale_width' in opt.resize_or_crop:
        # An int size scales the *shorter* side to 256 (aspect ratio kept), which
        # is intended here, so __scale_width (which fixes the width) is not used.
        transform_list.append(transforms.Resize(256, method))

    if 'crop' in opt.resize_or_crop:
        if opt.isTrain:
            transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))
        elif opt.test_random_crop:
            transform_list.append(transforms.RandomCrop(opt.fineSize))
        else:
            ## when testing, for ablation study, choose center_crop directly.
            transform_list.append(transforms.CenterCrop(opt.fineSize))

    if opt.resize_or_crop == 'none':
        # Make both sides multiples of the generator's total downsampling factor.
        base = float(2 ** opt.n_downsample_global)
        if opt.netG == 'local':
            base *= (2 ** opt.n_local_enhancers)
        transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))

    if opt.isTrain and not opt.no_flip:
        transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))

    transform_list += [transforms.ToTensor()]

    if normalize:
        transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
                                                (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)
def normalize():
    """Return a Normalize transform with mean 0.5 and std 0.5 on all three channels ([0, 1] -> [-1, 1])."""
    half = (0.5, 0.5, 0.5)
    return transforms.Normalize(half, half)
def __make_power_2(img, base, method=Image.BICUBIC):
    """Resize ``img`` so both sides are multiples of ``base``.

    Despite the name, sides are rounded to the nearest multiple of ``base``
    (not to a power of two).  Each side is kept at least ``base`` so a side
    shorter than base/2 is no longer rounded down to 0 pixels.  The image is
    returned unchanged if it is already aligned.
    """
    ow, oh = img.size
    min_side = int(base)
    h = max(min_side, int(round(oh / base) * base))
    w = max(min_side, int(round(ow / base) * base))
    if (h == oh) and (w == ow):
        return img
    return img.resize((w, h), method)
def __scale_width(img, target_width, method=Image.BICUBIC):
    """Resize ``img`` to width ``target_width`` keeping the aspect ratio (height truncated)."""
    ow, oh = img.size
    if ow != target_width:
        return img.resize((target_width, int(target_width * oh / ow)), method)
    return img
def __crop(img, pos, size):
    """Crop a ``size`` x ``size`` window at ``pos`` = (x, y) unless the image already fits in it."""
    ow, oh = img.size
    x1, y1 = pos
    if ow <= size and oh <= size:
        return img
    return img.crop((x1, y1, x1 + size, y1 + size))
def __flip(img, flip):
    """Mirror ``img`` left-right when ``flip`` is truthy; otherwise return it unchanged."""
    return img.transpose(Image.FLIP_LEFT_RIGHT) if flip else img
================================================
FILE: Global/data/custom_dataset_data_loader.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch.utils.data
import random
from data.base_data_loader import BaseDataLoader
from data import online_dataset_for_old_photos as dts_ray_bigfile
def CreateDataset(opt):
    """Create and initialize the training dataset selected by ``opt.training_dataset``.

    'domain_A' / 'domain_B' -> UnPairOldPhotos_SR; 'mapping' ->
    PairOldPhotos_with_hole if ``opt.random_hole`` else PairOldPhotos.

    Raises:
        ValueError: for any other ``opt.training_dataset`` (previously this
            surfaced as an AttributeError on ``None.name()``).
    """
    if opt.training_dataset in ('domain_A', 'domain_B'):
        dataset = dts_ray_bigfile.UnPairOldPhotos_SR()
    elif opt.training_dataset == 'mapping':
        if opt.random_hole:
            dataset = dts_ray_bigfile.PairOldPhotos_with_hole()
        else:
            dataset = dts_ray_bigfile.PairOldPhotos()
    else:
        raise ValueError(
            "Unknown training_dataset %r; expected 'domain_A', 'domain_B' or 'mapping'"
            % (opt.training_dataset,)
        )
    print("dataset [%s] was created" % (dataset.name()))
    dataset.initialize(opt)
    return dataset
class CustomDatasetDataLoader(BaseDataLoader):
    """Wrap the dataset chosen by CreateDataset in a torch DataLoader."""

    def name(self):
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        """Create the dataset and a DataLoader (incomplete last batch dropped)."""
        BaseDataLoader.initialize(self, opt)
        self.dataset = CreateDataset(opt)
        loader_kwargs = dict(
            batch_size=opt.batchSize,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.nThreads),
            drop_last=True,
        )
        self.dataloader = torch.utils.data.DataLoader(self.dataset, **loader_kwargs)

    def load_data(self):
        return self.dataloader

    def __len__(self):
        """Number of dataset samples, capped at ``opt.max_dataset_size``."""
        return min(len(self.dataset), self.opt.max_dataset_size)
================================================
FILE: Global/data/data_loader.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
def CreateDataLoader(opt):
    """Build, announce and initialize a CustomDatasetDataLoader for ``opt``."""
    from data.custom_dataset_data_loader import CustomDatasetDataLoader

    loader = CustomDatasetDataLoader()
    print(loader.name())
    loader.initialize(opt)
    return loader
================================================
FILE: Global/data/image_folder.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch.utils.data as data
from PIL import Image
import os
# Case-sensitive list of accepted image file extensions.
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff'
]


def is_image_file(filename):
    """Return True if ``filename`` ends with one of IMG_EXTENSIONS (case-sensitive)."""
    return filename.endswith(tuple(IMG_EXTENSIONS))
def make_dataset(dir):
    """Recursively collect the paths of image files below ``dir``."""
    assert os.path.isdir(dir), '%s is not a valid directory' % dir
    found = []
    for root, _subdirs, names in sorted(os.walk(dir)):
        found.extend(os.path.join(root, name) for name in names if is_image_file(name))
    return found
def default_loader(path):
    """Open the image at ``path`` and return it converted to RGB."""
    image = Image.open(path)
    return image.convert('RGB')
class ImageFolder(data.Dataset):
    """Dataset over every image file found (recursively) below ``root``.

    Items are ``loader(path)`` passed through ``transform`` when given, and
    paired with the path when ``return_paths`` is True.
    """

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        imgs = make_dataset(root)
        if not imgs:
            raise RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " +
                               ",".join(IMG_EXTENSIONS))
        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        return (img, path) if self.return_paths else img

    def __len__(self):
        return len(self.imgs)
================================================
FILE: Global/data/online_dataset_for_old_photos.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os.path
import io
import zipfile
from data.base_dataset import BaseDataset, get_params, get_transform, normalize
from data.image_folder import make_dataset
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
from data.Load_Bigfile import BigFileMemoryLoader
import random
import cv2
from io import BytesIO
def pil_to_np(img_PIL):
    '''Convert an image (PIL or array) to a float32 numpy array in [0, 1].

    An H x W x C image becomes C x H x W; a 2-D H x W image becomes 1 x H x W.
    '''
    arr = np.array(img_PIL)
    arr = arr.transpose(2, 0, 1) if arr.ndim == 3 else arr[None, ...]
    return arr.astype(np.float32) / 255.
def np_to_pil(img_np):
    '''Convert a C x H x W array with values in [0, 1] to a PIL image.

    Values are scaled by 255 and clipped to [0, 255] as uint8; a single
    channel gives a 2-D image, otherwise channels are moved last.
    '''
    scaled = np.clip(img_np * 255, 0, 255).astype(np.uint8)
    if img_np.shape[0] != 1:
        scaled = scaled.transpose(1, 2, 0)
    else:
        scaled = scaled[0]
    return Image.fromarray(scaled)
def synthesize_salt_pepper(image,amount,salt_vs_pepper):
    """Add salt-and-pepper noise to a PIL image and return a new PIL image.

    Each array element is hit with probability ``amount``; a hit becomes 1
    (salt) with probability ``salt_vs_pepper`` and 0 (pepper) otherwise.
    """
    arr = pil_to_np(image)
    noisy = arr.copy()
    hit = np.random.choice([True, False], size=arr.shape,
                           p=[amount, 1 - amount])
    salt = np.random.choice([True, False], size=arr.shape,
                            p=[salt_vs_pepper, 1 - salt_vs_pepper])
    noisy[hit & salt] = 1
    noisy[hit & ~salt] = 0.
    return np_to_pil(np.clip(noisy, 0, 1).astype(np.float32))
def synthesize_gaussian(image,std_l,std_r):
    """Add zero-mean Gaussian noise with std drawn uniformly from [std_l, std_r] (0-255 scale)."""
    arr = pil_to_np(image)
    sigma = random.uniform(std_l / 255., std_r / 255.)
    noise = np.random.normal(loc=0, scale=sigma, size=arr.shape)
    return np_to_pil(np.clip(arr + noise, 0, 1).astype(np.float32))
def synthesize_speckle(image,std_l,std_r):
    """Add multiplicative (speckle) noise: x + x * n, n ~ N(0, s), s from [std_l, std_r] / 255."""
    arr = pil_to_np(image)
    sigma = random.uniform(std_l / 255., std_r / 255.)
    noise = np.random.normal(loc=0, scale=sigma, size=arr.shape)
    return np_to_pil(np.clip(arr + noise * arr, 0, 1).astype(np.float32))
def synthesize_low_resolution(img):
    """Downscale to a random size between half and full (bicubic), then upscale back.

    The upscale uses nearest or bilinear resampling with equal probability.
    """
    width, height = img.size
    small = img.resize((random.randint(int(width / 2), width),
                        random.randint(int(height / 2), height)), Image.BICUBIC)
    resample = Image.NEAREST if random.uniform(0, 1) < 0.5 else Image.BILINEAR
    return small.resize((width, height), resample)
def convertToJpeg(im,quality):
    """Round-trip ``im`` through in-memory JPEG encoding at ``quality``; return an RGB image."""
    with BytesIO() as stream:
        im.save(stream, format='JPEG', quality=quality)
        stream.seek(0)
        # convert() forces decoding while the buffer is still open.
        decoded = Image.open(stream).convert('RGB')
    return decoded
def blur_image_v2(img):
    """Gaussian-blur ``img`` with a random 3/5/7 kernel and sigma drawn from [1, 5]."""
    pixels = np.array(img)
    ksize = random.sample([(3, 3), (5, 5), (7, 7)], 1)[0]
    sigma = random.uniform(1., 5.)
    blurred = cv2.GaussianBlur(pixels, ksize, sigma)
    return Image.fromarray(blurred.astype(np.uint8))
def online_add_degradation_v2(img):
    """Apply blur, noise, low resolution and JPEG compression in random order.

    Each degradation is applied independently with probability 0.7; the noise
    step picks Gaussian, speckle or salt-and-pepper uniformly.
    """
    for step in np.random.permutation(4):
        if step == 0:
            if random.uniform(0, 1) < 0.7:
                img = blur_image_v2(img)
        elif step == 1:
            if random.uniform(0, 1) < 0.7:
                noise_kind = random.choice([1, 2, 3])
                if noise_kind == 1:
                    img = synthesize_gaussian(img, 5, 50)
                elif noise_kind == 2:
                    img = synthesize_speckle(img, 5, 50)
                elif noise_kind == 3:
                    img = synthesize_salt_pepper(img, random.uniform(0, 0.01), random.uniform(0.3, 0.8))
        elif step == 2:
            if random.uniform(0, 1) < 0.7:
                img = synthesize_low_resolution(img)
        elif step == 3:
            if random.uniform(0, 1) < 0.7:
                img = convertToJpeg(img, random.randint(40, 100))
    return img
def irregular_hole_synthesize(img,mask):
    """Paint the masked region of ``img`` white.

    ``mask`` pixels are scaled by 1/255 and used to blend towards 255.
    Returns ``(holed RGB image, mask converted to 'L')``.
    """
    pixels = np.array(img).astype('uint8')
    hole = np.array(mask).astype('uint8') / 255
    filled = pixels * (1 - hole) + hole * 255
    hole_img = Image.fromarray(filled.astype('uint8')).convert("RGB")
    return hole_img, mask.convert("L")
def zero_mask(size):
    """Return an all-black ``size`` x ``size`` RGB PIL image."""
    blank = np.zeros((size, size, 3), dtype='uint8')
    return Image.fromarray(blank).convert("RGB")
class UnPairOldPhotos_SR(BaseDataset): ## Synthetic + Real Old
    """Unpaired single-image dataset for the domain A / domain B VAEs.

    When 'domainA' is in ``opt.name`` each sample is, with equal probability,
    a real old photo (grayscale or RGB bigfile; ``inst`` = 1) or a clean VOC
    image degraded on the fly.  Otherwise samples are clean VOC images.
    ``__getitem__`` ignores ``index`` and draws a random image every call.
    """

    def initialize(self, opt):
        self.opt = opt
        self.isImage = 'domainA' in opt.name
        self.task = 'old_photo_restoration_training_vae'
        self.dir_AB = opt.dataroot
        if self.isImage:
            self.load_img_dir_L_old = os.path.join(self.dir_AB, "Real_L_old.bigfile")
            self.load_img_dir_RGB_old = os.path.join(self.dir_AB, "Real_RGB_old.bigfile")
            self.load_img_dir_clean = os.path.join(self.dir_AB, "VOC_RGB_JPEGImages.bigfile")
            self.loaded_imgs_L_old = BigFileMemoryLoader(self.load_img_dir_L_old)
            self.loaded_imgs_RGB_old = BigFileMemoryLoader(self.load_img_dir_RGB_old)
            self.loaded_imgs_clean = BigFileMemoryLoader(self.load_img_dir_clean)
        else:
            self.load_img_dir_clean = os.path.join(self.dir_AB, "VOC_RGB_JPEGImages.bigfile")
            self.loaded_imgs_clean = BigFileMemoryLoader(self.load_img_dir_clean)

        ## Keep only clean images whose both sides are at least 256 px.
        print("-------------Filter the imgs whose size <256 in VOC-------------")
        self.filtered_imgs_clean = []
        for i in range(len(self.loaded_imgs_clean)):
            img_name, img = self.loaded_imgs_clean[i]
            w, h = img.size  # PIL reports (width, height)
            if h < 256 or w < 256:
                continue
            self.filtered_imgs_clean.append((img_name, img))
        print("--------Origin image num is [%d], filtered result is [%d]--------" % (
            len(self.loaded_imgs_clean), len(self.filtered_imgs_clean)))
        self.pid = os.getpid()

    def __getitem__(self, index):
        is_real_old = 0
        degradation = None
        if self.isImage:  ## domain A , contains 2 kinds of data: synthetic + real_old
            P = random.uniform(0, 2)
            if P < 1:
                # Real old photo: grayscale or RGB collection, 50/50.
                if random.uniform(0, 1) < 0.5:
                    sampled_dataset = self.loaded_imgs_L_old
                    self.load_img_dir = self.load_img_dir_L_old
                else:
                    sampled_dataset = self.loaded_imgs_RGB_old
                    self.load_img_dir = self.load_img_dir_RGB_old
                is_real_old = 1
            else:
                # Synthetic old photo: clean image degraded below.
                sampled_dataset = self.filtered_imgs_clean
                self.load_img_dir = self.load_img_dir_clean
                degradation = 1
        else:
            sampled_dataset = self.filtered_imgs_clean
            self.load_img_dir = self.load_img_dir_clean

        index = random.randint(0, len(sampled_dataset) - 1)
        img_name, img = sampled_dataset[index]
        if degradation is not None:
            img = online_add_degradation_v2(img)
        path = os.path.join(self.load_img_dir, img_name)

        ## With probability 0.1, drop colour (to L and back to 3-channel RGB).
        if random.uniform(0, 1) < 0.1:
            img = img.convert("L")
        img = img.convert("RGB")

        A = img
        w, h = A.size
        if w < 256 or h < 256:
            ## Upscale so the shorter side is 256 before cropping 256x256 patches.
            # transforms.Resize replaces the removed transforms.Scale alias.
            A = transforms.Resize(256, Image.BICUBIC)(A)

        transform_params = get_params(self.opt, A.size)
        A_transform = get_transform(self.opt, transform_params)
        feat_tensor = 0
        A_tensor = A_transform(A)

        input_dict = {'label': A_tensor, 'inst': is_real_old, 'image': A_tensor,
                      'feat': feat_tensor, 'path': path}
        return input_dict

    def __len__(self):
        ## Not meaningful: __getitem__ samples a random index regardless.
        return len(self.loaded_imgs_clean)

    def name(self):
        return 'UnPairOldPhotos_SR'
class PairOldPhotos(BaseDataset):
    """Paired (degraded input ``A``, clean target ``B``) dataset for the mapping stage.

    Training reads clean VOC images from ``<dataroot>/VOC_RGB_JPEGImages.bigfile``,
    keeps those whose sides are both >= 256, and synthesizes ``A`` from ``B``
    with ``online_add_degradation_v2``. Testing reads
    ``<dataroot>/<opt.test_dataset>``; with ``opt.test_on_synthetic`` the input
    is degraded on the fly, otherwise input and target are the stored image.
    """

    def initialize(self, opt):
        self.opt = opt
        self.isImage = 'imagegan' in opt.name
        self.task = 'old_photo_restoration_training_mapping'
        self.dir_AB = opt.dataroot
        if opt.isTrain:
            self.load_img_dir_clean = os.path.join(self.dir_AB, "VOC_RGB_JPEGImages.bigfile")
            self.loaded_imgs_clean = BigFileMemoryLoader(self.load_img_dir_clean)
            print("-------------Filter the imgs whose size <256 in VOC-------------")
            self.filtered_imgs_clean = []
            for i in range(len(self.loaded_imgs_clean)):
                img_name, img = self.loaded_imgs_clean[i]
                w, h = img.size  # PIL reports (width, height)
                if w < 256 or h < 256:
                    continue
                self.filtered_imgs_clean.append((img_name, img))
            print("--------Origin image num is [%d], filtered result is [%d]--------" % (
                len(self.loaded_imgs_clean), len(self.filtered_imgs_clean)))
        else:
            self.load_img_dir = os.path.join(self.dir_AB, opt.test_dataset)
            self.loaded_imgs = BigFileMemoryLoader(self.load_img_dir)
        self.pid = os.getpid()

    def __getitem__(self, index):
        """Return ``{'label': A, 'inst': 0, 'image': B, 'feat': 0, 'path': ...}``.

        ``A`` is the input tensor and ``B`` the corresponding ground truth;
        both go through the same random transform parameters.

        Raises:
            ValueError: during training when ``opt.use_v2_degradation`` is
                off -- no other degradation is implemented here, and the
                input ``A`` used to be left undefined (``UnboundLocalError``).
        """
        if self.opt.isTrain:
            img_name_clean, B = self.filtered_imgs_clean[index]
            path = os.path.join(self.load_img_dir_clean, img_name_clean)
            if not self.opt.use_v2_degradation:
                raise ValueError("PairOldPhotos training requires opt.use_v2_degradation")
            ### Remind: A is the input and B is corresponding GT
            A = online_add_degradation_v2(B)
        else:
            if self.opt.test_on_synthetic:
                img_name_B, B = self.loaded_imgs[index]
                A = online_add_degradation_v2(B)
                img_name_A = img_name_B
            else:
                img_name_A, A = self.loaded_imgs[index]
                img_name_B, B = self.loaded_imgs[index]
            path = os.path.join(self.load_img_dir, img_name_A)

        # During training, with probability 0.1 drop colour from both images.
        if random.uniform(0, 1) < 0.1 and self.opt.isTrain:
            A = A.convert("L")
            B = B.convert("L")
        A = A.convert("RGB")
        B = B.convert("RGB")

        w, h = A.size
        if w < 256 or h < 256:
            # ``transforms.Resize`` replaces the deprecated ``transforms.Scale``
            # alias, which later torchvision releases removed.
            A = transforms.Resize(256, Image.BICUBIC)(A)
            B = transforms.Resize(256, Image.BICUBIC)(B)

        # apply the same transform to both A and B
        transform_params = get_params(self.opt, A.size)
        A_transform = get_transform(self.opt, transform_params)
        B_transform = get_transform(self.opt, transform_params)
        inst_tensor = feat_tensor = 0
        A_tensor = A_transform(A)
        B_tensor = B_transform(B)
        input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor,
                      'feat': feat_tensor, 'path': path}
        return input_dict

    def __len__(self):
        if self.opt.isTrain:
            return len(self.filtered_imgs_clean)
        return len(self.loaded_imgs)

    def name(self):
        return 'PairOldPhotos'
class PairOldPhotos_with_hole(BaseDataset):
    """Paired dataset whose degraded input additionally gets an irregular hole.

    Training: a random 256x256 crop ``B`` of a clean VOC image (both sides
    >= 256) is degraded with ``online_add_degradation_v2`` into ``A``. Testing:
    ``A`` and ``B`` are the same centre 256x256 crop of a test image. A mask
    from the bigfile at ``opt.irregular_mask`` is applied to ``A`` with
    ``irregular_hole_synthesize`` and returned (first channel) under ``'inst'``.
    """

    def initialize(self, opt):
        self.opt = opt
        self.isImage = 'imagegan' in opt.name
        self.task = 'old_photo_restoration_training_mapping'
        self.dir_AB = opt.dataroot
        if opt.isTrain:
            self.load_img_dir_clean = os.path.join(self.dir_AB, "VOC_RGB_JPEGImages.bigfile")
            self.loaded_imgs_clean = BigFileMemoryLoader(self.load_img_dir_clean)
            print("-------------Filter the imgs whose size <256 in VOC-------------")
            self.filtered_imgs_clean = []
            for i in range(len(self.loaded_imgs_clean)):
                img_name, img = self.loaded_imgs_clean[i]
                w, h = img.size  # PIL reports (width, height)
                if w < 256 or h < 256:
                    continue
                self.filtered_imgs_clean.append((img_name, img))
            print("--------Origin image num is [%d], filtered result is [%d]--------" % (
                len(self.loaded_imgs_clean), len(self.filtered_imgs_clean)))
        else:
            self.load_img_dir = os.path.join(self.dir_AB, opt.test_dataset)
            self.loaded_imgs = BigFileMemoryLoader(self.load_img_dir)
        self.loaded_masks = BigFileMemoryLoader(opt.irregular_mask)
        self.pid = os.getpid()

    def __getitem__(self, index):
        """Return ``{'label': A, 'inst': mask, 'image': B, 'feat': 0, 'path': ...}``."""
        if self.opt.isTrain:
            img_name_clean, B = self.filtered_imgs_clean[index]
            path = os.path.join(self.load_img_dir_clean, img_name_clean)
            B = transforms.RandomCrop(256)(B)
            ### Remind: A is the input and B is corresponding GT
            A = online_add_degradation_v2(B)
        else:
            # Input and target are the same centre crop, so the image only
            # needs to be fetched once (it was previously loaded twice).
            img_name_A, A = self.loaded_imgs[index]
            path = os.path.join(self.load_img_dir, img_name_A)
            A = transforms.CenterCrop(256)(A)
            B = A

        # During training, with probability 0.1 drop colour from both images.
        if random.uniform(0, 1) < 0.1 and self.opt.isTrain:
            A = A.convert("L")
            B = B.convert("L")
        A = A.convert("RGB")
        B = B.convert("RGB")

        if self.opt.isTrain:
            mask_name, mask = self.loaded_masks[random.randint(0, len(self.loaded_masks) - 1)]
        else:
            # Cycle deterministically through (at most) the first 100 masks;
            # min() avoids an IndexError when fewer than 100 masks exist.
            mask_name, mask = self.loaded_masks[index % min(100, len(self.loaded_masks))]
        mask = mask.resize((self.opt.loadSize, self.opt.loadSize), Image.NEAREST)

        # NOTE(review): the empty masks below are hard-coded to 256 while real
        # masks use opt.loadSize; this assumes loadSize == 256 -- confirm.
        if self.opt.random_hole and random.uniform(0, 1) > 0.5 and self.opt.isTrain:
            mask = zero_mask(256)
        if self.opt.no_hole:
            mask = zero_mask(256)

        A, _ = irregular_hole_synthesize(A, mask)

        if not self.opt.isTrain and self.opt.hole_image_no_mask:
            mask = zero_mask(256)

        transform_params = get_params(self.opt, A.size)
        A_transform = get_transform(self.opt, transform_params)
        B_transform = get_transform(self.opt, transform_params)

        # Keep the mask aligned with the images when the transform flips them.
        if transform_params['flip'] and self.opt.isTrain:
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)

        mask_tensor = transforms.ToTensor()(mask)
        feat_tensor = 0
        A_tensor = A_transform(A)
        B_tensor = B_transform(B)
        input_dict = {'label': A_tensor, 'inst': mask_tensor[:1], 'image': B_tensor,
                      'feat': feat_tensor, 'path': path}
        return input_dict

    def __len__(self):
        if self.opt.isTrain:
            return len(self.filtered_imgs_clean)
        return len(self.loaded_imgs)

    def name(self):
        return 'PairOldPhotos_with_hole'
================================================
FILE: Global/detection.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import argparse
import gc
import json
import os
import time
import warnings
import numpy as np
import torch
import torch.nn.functional as F
import torchvision as tv
from PIL import Image, ImageFile
from detection_models import networks
from detection_util.util import *
# Let PIL decode truncated/partially corrupted image files instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Silence every UserWarning emitted while this script runs.
warnings.simplefilter("ignore", category=UserWarning)
def data_transforms(img, full_size, method=Image.BICUBIC):
    """Resize a PIL image to the spatial size fed to the scratch detector.

    Args:
        img: input ``PIL.Image``.
        full_size: resize mode, one of
            * ``"full_size"``  -- keep the resolution, rounding each side to
              the nearest multiple of 16;
            * ``"scale_256"``  -- scale so the shorter side is 256, then round
              each side to the nearest multiple of 16;
            * ``"resize_256"`` -- resize to exactly 256x256 (aspect ratio is
              not preserved).
        method: PIL resampling filter.

    Returns:
        The resized image, or ``img`` itself when it already has the target size.

    Raises:
        ValueError: for any other ``full_size``. Previously such values (including
            ``"resize_256"``, advertised by the CLI) silently returned ``None``.
    """
    if full_size == "full_size":
        ow, oh = img.size
        h = int(round(oh / 16) * 16)
        w = int(round(ow / 16) * 16)
        if (h == oh) and (w == ow):
            return img
        return img.resize((w, h), method)

    if full_size == "scale_256":
        pw, ph = img.size
        if pw < ph:
            ow = 256
            oh = ph / pw * 256
        else:
            oh = 256
            ow = pw / ph * 256
        h = int(round(oh / 16) * 16)
        w = int(round(ow / 16) * 16)
        if (h == ph) and (w == pw):
            return img
        return img.resize((w, h), method)

    if full_size == "resize_256":
        if img.size == (256, 256):
            return img
        return img.resize((256, 256), method)

    raise ValueError("Unknown input_size mode: %r" % (full_size,))
def scale_tensor(img_tensor, default_scale=256):
    """Bilinearly rescale a 4-D ``(N, C, H, W)`` batch so its shorter side is ~``default_scale``.

    The other side keeps the aspect ratio. Both output sides are then rounded
    to the nearest multiple of 16. On a tie (H == W) both sides become
    ``default_scale``.
    """
    _, _, height, width = img_tensor.shape
    if height < width:
        new_h, new_w = default_scale, width / height * default_scale
    else:
        new_h, new_w = height / width * default_scale, default_scale
    new_h = int(round(new_h / 16) * 16)
    new_w = int(round(new_w / 16) * 16)
    return F.interpolate(img_tensor, [new_h, new_w], mode="bilinear")
def blend_mask(img, mask):
    """Blend *img* towards white wherever *mask* is set and return an RGB PIL image.

    Each pixel becomes ``img * (1 - mask) + 255 * mask``.
    NOTE(review): presumably ``mask`` holds values in [0, 1] and broadcasts
    against ``np.array(img)`` -- confirm at the call site.
    """
    pixels = np.array(img).astype("float")
    blended = pixels * (1 - mask) + mask * 255.0
    return Image.fromarray(blended.astype("uint8")).convert("RGB")
def main(config):
    """Run scratch detection on every regular file in ``config.test_path``.

    Each image is resized by ``data_transforms`` according to
    ``config.input_size``, converted to a normalized single-channel tensor,
    rescaled with ``scale_tensor`` and passed through the UNet detector. The
    binarized prediction (sigmoid >= 0.4), resized back to the transformed
    image size, is written to ``<output_dir>/mask/<stem>.png``; the
    transformed input image is written to ``<output_dir>/input/<stem>.png``.

    Args:
        config: parsed CLI namespace with ``GPU`` (negative for CPU),
            ``test_path``, ``output_dir`` and ``input_size``.
    """
    print("initializing the dataloader")

    model = networks.UNet(
        in_channels=1,
        out_channels=1,
        depth=4,
        conv_num=2,
        wf=6,
        padding=True,
        batch_norm=True,
        up_mode="upsample",
        with_tanh=False,
        sync_bn=True,
        antialiasing=True,
    )

    ## load model
    checkpoint_path = os.path.join(os.path.dirname(__file__), "checkpoints/detection/FT_Epoch_latest.pt")
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state"])
    print("model weights loaded")

    if config.GPU >= 0:
        model.to(config.GPU)
    else:
        model.cpu()
    model.eval()

    ## dataloader and transformation
    print("directory of testing image: " + config.test_path)
    imagelist = sorted(os.listdir(config.test_path))

    save_url = config.output_dir
    mkdir_if_not(save_url)
    input_dir = os.path.join(save_url, "input")
    output_dir = os.path.join(save_url, "mask")
    mkdir_if_not(input_dir)
    mkdir_if_not(output_dir)

    for image_name in imagelist:
        print("processing", image_name)

        scratch_file = os.path.join(config.test_path, image_name)
        if not os.path.isfile(scratch_file):
            print("Skipping non-file %s" % image_name)
            continue
        # Strip the extension whatever its length; slicing off 4 characters
        # mangled names such as "photo.jpeg" into "photo..png".
        stem = os.path.splitext(image_name)[0]

        # Close the file handle as soon as the pixels are decoded.
        with Image.open(scratch_file) as opened:
            scratch_image = opened.convert("RGB")

        transformed_image_PIL = data_transforms(scratch_image, config.input_size)
        scratch_image = transformed_image_PIL.convert("L")
        scratch_image = tv.transforms.ToTensor()(scratch_image)
        scratch_image = tv.transforms.Normalize([0.5], [0.5])(scratch_image)
        scratch_image = torch.unsqueeze(scratch_image, 0)
        _, _, height, width = scratch_image.shape
        scratch_image_scale = scale_tensor(scratch_image)

        if config.GPU >= 0:
            scratch_image_scale = scratch_image_scale.to(config.GPU)
        else:
            scratch_image_scale = scratch_image_scale.cpu()
        with torch.no_grad():
            P = torch.sigmoid(model(scratch_image_scale))

        P = P.data.cpu()
        # Bring the prediction back to the size of the transformed input.
        P = F.interpolate(P, [height, width], mode="nearest")

        tv.utils.save_image(
            (P >= 0.4).float(),
            os.path.join(output_dir, stem + ".png"),
            nrow=1,
            padding=0,
            normalize=True,
        )
        transformed_image_PIL.save(os.path.join(input_dir, stem + ".png"))
        gc.collect()
        torch.cuda.empty_cache()
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--GPU", type=int, default=0, help="GPU index; a negative value runs on CPU")
    parser.add_argument("--test_path", type=str, default=".")
    parser.add_argument("--output_dir", type=str, default=".")
    # ``choices`` rejects typos up front instead of failing deep inside the
    # image loop with an opaque error.
    parser.add_argument(
        "--input_size",
        type=str,
        default="scale_256",
        choices=["resize_256", "full_size", "scale_256"],
        help="resize_256|full_size|scale_256",
    )
    config = parser.parse_args()

    main(config)
================================================
FILE: Global/detection_models/__init__.py
================================================
================================================
FILE: Global/detection_models/antialiasing.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn.parallel
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
class Downsample(nn.Module):
    """Anti-aliased (blur, then subsample) downsampling layer.

    Adapted from https://github.com/adobe/antialiased-cnns. The input is
    padded, convolved depthwise with a normalized 2-D binomial filter of side
    ``filt_size`` and subsampled by ``stride``. For ``filt_size == 1`` no blur
    is applied; the input is only (optionally padded and) strided.

    Args:
        pad_type: padding mode understood by ``get_pad_layer``.
        filt_size: side length of the binomial blur kernel; any integer >= 1.
        stride: subsampling factor.
        channels: number of input channels. Required, because the filter is
            replicated once per channel for the depthwise convolution.
        pad_off: extra padding added on every side.

    Raises:
        ValueError: if ``filt_size < 1``.
    """

    def __init__(self, pad_type="reflect", filt_size=3, stride=2, channels=None, pad_off=0):
        super(Downsample, self).__init__()
        if filt_size < 1:
            raise ValueError("filt_size must be >= 1, got %r" % (filt_size,))
        self.filt_size = filt_size
        self.pad_off = pad_off
        pad_lo = int(1.0 * (filt_size - 1) / 2)
        pad_hi = int(np.ceil(1.0 * (filt_size - 1) / 2))
        # (left, right, top, bottom) padding, each increased by pad_off.
        self.pad_sizes = [pad_lo + pad_off, pad_hi + pad_off, pad_lo + pad_off, pad_hi + pad_off]
        self.stride = stride
        self.off = int((self.stride - 1) / 2.0)
        self.channels = channels

        # Row (filt_size - 1) of Pascal's triangle: [1], [1, 1], [1, 2, 1], ...
        # Building it by repeated convolution supports every filt_size >= 1,
        # whereas a hard-coded table stopped at 7 and failed with an
        # UnboundLocalError beyond it.
        a = np.array([1.0])
        for _ in range(filt_size - 1):
            a = np.convolve(a, [1.0, 1.0])

        filt = torch.Tensor(a[:, None] * a[None, :])
        filt = filt / torch.sum(filt)
        self.register_buffer("filt", filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))

        self.pad = get_pad_layer(pad_type)(self.pad_sizes)

    def forward(self, inp):
        if self.filt_size == 1:
            if self.pad_off == 0:
                return inp[:, :, :: self.stride, :: self.stride]
            else:
                return self.pad(inp)[:, :, :: self.stride, :: self.stride]
        else:
            # groups == channel count -> one blur filter per channel.
            return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])
def get_pad_layer(pad_type):
    """Return the ``torch.nn`` padding-layer class for *pad_type*.

    Accepted names: ``"refl"``/``"reflect"`` -> ``nn.ReflectionPad2d``,
    ``"repl"``/``"replicate"`` -> ``nn.ReplicationPad2d`` and
    ``"zero"`` -> ``nn.ZeroPad2d``.

    Raises:
        ValueError: for any other name. Previously the name was only printed,
            after which ``return PadLayer`` failed with ``UnboundLocalError``.
    """
    if pad_type in ["refl", "reflect"]:
        return nn.ReflectionPad2d
    if pad_type in ["repl", "replicate"]:
        return nn.ReplicationPad2d
    if pad_type == "zero":
        return nn.ZeroPad2d
    raise ValueError("Pad type [%s] not recognized" % pad_type)
================================================
FILE: Global/detection_models/networks.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn as nn
import torch.nn.functional as F
from detection_models.sync_batchnorm import DataParallelWithCallback
from detection_models.antialiasing import Downsample
class UNet(nn.Module):
    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        depth=5,
        conv_num=2,
        wf=6,
        padding=True,
        batch_norm=True,
        up_mode="upsample",
        with_tanh=False,
        sync_bn=True,
        antialiasing=True,
    ):
        """U-Net style encoder/decoder (used as the scratch detector).

        Loosely based on "U-Net: Convolutional Networks for Biomedical Image
        Segmentation" (Ronneberger et al., 2015, https://arxiv.org/abs/1505.04597).
        It is not the paper's exact architecture: it adds a reflection-padded
        7x7 stem, uses LeakyReLU(0.2) and can downsample with a blur-pool.

        Args:
            in_channels (int): number of input channels.
            out_channels (int): number of output channels.
            depth (int): number of downsampling (and matching upsampling) stages.
            conv_num (int): number of 3x3 conv units in each encoder/decoder block.
            wf (int): the stem has 2**wf filters; encoder stage i outputs
                2**(wf + i + 1) channels.
            padding (bool): if True, apply padding such that the input shape
                is the same as the output. This may introduce artifacts.
            batch_norm (bool): use BatchNorm inside the conv blocks. The
                downsampling layers always use BatchNorm.
            up_mode (str): one of 'upconv' or 'upsample'.
                'upconv' uses transposed convolutions for learned upsampling;
                'upsample' uses bilinear upsampling followed by a 3x3 conv.
            with_tanh (bool): apply tanh to the output.
            sync_bn (bool): accepted for backward compatibility and ignored.
                Assigning ``self = DataParallelWithCallback(self)`` inside
                ``__init__`` only rebinds a local name and never wraps the
                module. Callers that want the wrapper must apply it to the
                constructed model.
            antialiasing (bool): downsample with a stride-1 conv followed by
                ``Downsample`` (blur-pool); otherwise with a stride-2 4x4 conv.
        """
        super().__init__()
        assert up_mode in ("upconv", "upsample")
        self.padding = padding
        self.depth = depth - 1

        self.first = nn.Sequential(
            *[nn.ReflectionPad2d(3), nn.Conv2d(in_channels, 2 ** wf, kernel_size=7), nn.LeakyReLU(0.2, True)]
        )
        prev_channels = 2 ** wf

        self.down_path = nn.ModuleList()
        self.down_sample = nn.ModuleList()
        for i in range(depth):
            if antialiasing:
                # Same-size conv, then anti-aliased downsampling by 2.
                self.down_sample.append(
                    nn.Sequential(
                        *[
                            nn.ReflectionPad2d(1),
                            nn.Conv2d(prev_channels, prev_channels, kernel_size=3, stride=1, padding=0),
                            nn.BatchNorm2d(prev_channels),
                            nn.LeakyReLU(0.2, True),
                            Downsample(channels=prev_channels, stride=2),
                        ]
                    )
                )
            else:
                # A strided 4x4 conv performs the downsampling directly.
                self.down_sample.append(
                    nn.Sequential(
                        *[
                            nn.ReflectionPad2d(1),
                            nn.Conv2d(prev_channels, prev_channels, kernel_size=4, stride=2, padding=0),
                            nn.BatchNorm2d(prev_channels),
                            nn.LeakyReLU(0.2, True),
                        ]
                    )
                )
            self.down_path.append(
                UNetConvBlock(conv_num, prev_channels, 2 ** (wf + i + 1), padding, batch_norm)
            )
            prev_channels = 2 ** (wf + i + 1)

        self.up_path = nn.ModuleList()
        for i in reversed(range(depth)):
            self.up_path.append(
                UNetUpBlock(conv_num, prev_channels, 2 ** (wf + i), up_mode, padding, batch_norm)
            )
            prev_channels = 2 ** (wf + i)

        if with_tanh:
            self.last = nn.Sequential(
                *[nn.ReflectionPad2d(1), nn.Conv2d(prev_channels, out_channels, kernel_size=3), nn.Tanh()]
            )
        else:
            self.last = nn.Sequential(
                *[nn.ReflectionPad2d(1), nn.Conv2d(prev_channels, out_channels, kernel_size=3)]
            )

    def forward(self, x):
        x = self.first(x)

        # Keep each stage's pre-downsampling feature map for the skip connections.
        blocks = []
        for i, down_block in enumerate(self.down_path):
            blocks.append(x)
            x = self.down_sample[i](x)
            x = down_block(x)

        # Decode, pairing each up block with the matching skip (deepest first).
        for i, up in enumerate(self.up_path):
            x = up(x, blocks[-i - 1])

        return self.last(x)
class UNetConvBlock(nn.Module):
    """A stack of ``conv_num`` units of: reflection pad -> 3x3 conv [-> BatchNorm] -> LeakyReLU(0.2).

    Args:
        conv_num: number of units.
        in_size: input channels of the first unit.
        out_size: output channels of every unit (and input channels of units 2..n).
        padding: reflection padding of ``int(padding)`` pixels before each conv.
        batch_norm: insert ``BatchNorm2d`` after each conv.
    """

    def __init__(self, conv_num, in_size, out_size, padding, batch_norm):
        super(UNetConvBlock, self).__init__()
        layers = []
        channels = in_size
        for _ in range(conv_num):
            layers += [
                nn.ReflectionPad2d(padding=int(padding)),
                nn.Conv2d(channels, out_size, kernel_size=3, padding=0),
            ]
            if batch_norm:
                layers.append(nn.BatchNorm2d(out_size))
            layers.append(nn.LeakyReLU(0.2, True))
            channels = out_size
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)
class UNetUpBlock(nn.Module):
    """Decoder stage: upsample, concatenate the centre-cropped skip feature, convolve.

    Args:
        conv_num: number of conv units in the trailing ``UNetConvBlock``.
        in_size: channels of the incoming (coarser) feature map. It is also
            the input width of the trailing conv block, so the upsampled map
            (``out_size`` channels) plus the skip feature must add up to it.
        out_size: channels produced by the upsampling step and by the block.
        up_mode: ``"upconv"`` (2x2 transposed conv, stride 2) or
            ``"upsample"`` (bilinear x2, reflection pad, 3x3 conv).
        padding, batch_norm: forwarded to ``UNetConvBlock``.

    Raises:
        ValueError: if ``up_mode`` is not one of the two modes above.
            Previously ``self.up`` was silently left undefined and the
            failure only surfaced in ``forward``.
    """

    def __init__(self, conv_num, in_size, out_size, up_mode, padding, batch_norm):
        super(UNetUpBlock, self).__init__()
        if up_mode == "upconv":
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
        elif up_mode == "upsample":
            self.up = nn.Sequential(
                nn.Upsample(mode="bilinear", scale_factor=2, align_corners=False),
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_size, out_size, kernel_size=3, padding=0),
            )
        else:
            raise ValueError("up_mode must be 'upconv' or 'upsample', got %r" % (up_mode,))

        self.conv_block = UNetConvBlock(conv_num, in_size, out_size, padding, batch_norm)

    def center_crop(self, layer, target_size):
        """Crop the last two dims of ``layer`` to ``target_size`` around the centre."""
        _, _, layer_height, layer_width = layer.size()
        diff_y = (layer_height - target_size[0]) // 2
        diff_x = (layer_width - target_size[1]) // 2
        return layer[:, :, diff_y : (diff_y + target_size[0]), diff_x : (diff_x + target_size[1])]

    def forward(self, x, bridge):
        up = self.up(x)
        crop1 = self.center_crop(bridge, up.shape[2:])
        out = torch.cat([up, crop1], 1)
        out = self.conv_block(out)

        return out
class UnetGenerator(nn.Module):
    """U-Net generator assembled from nested ``UnetSkipConnectionBlock``s."""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_type="BN", use_dropout=False):
        """Construct a Unet generator.

        The network is built from the innermost block outwards, each new block
        wrapping the previous one as its submodule.

        Parameters:
            input_nc (int)    -- number of channels in input images
            output_nc (int)   -- number of channels in output images
            num_downs (int)   -- number of downsamplings. For example, with
                                 num_downs == 7 a 128x128 image becomes 1x1 at
                                 the bottleneck. Values below 5 still build the
                                 5 fixed blocks (5 downsamplings).
            ngf (int)         -- base number of filters: the outermost block
                                 uses ngf, deeper blocks up to ngf * 8
            norm_type (str)   -- "BN" (BatchNorm2d) or "IN" (InstanceNorm2d)
            use_dropout (bool)-- add dropout to the intermediate ngf * 8 blocks
        """
        super().__init__()
        if norm_type == "BN":
            norm_layer = nn.BatchNorm2d
        elif norm_type == "IN":
            norm_layer = nn.InstanceNorm2d
        else:
            raise NameError("Unknown norm layer")

        # Innermost block (the bottleneck).
        block = UnetSkipConnectionBlock(
            ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True
        )
        # num_downs - 5 intermediate blocks that keep ngf * 8 filters.
        for _ in range(num_downs - 5):
            block = UnetSkipConnectionBlock(
                ngf * 8,
                ngf * 8,
                input_nc=None,
                submodule=block,
                norm_layer=norm_layer,
                use_dropout=use_dropout,
            )
        # Gradually reduce the number of filters from ngf * 8 to ngf.
        for outer_nc, inner_nc in ((ngf * 4, ngf * 8), (ngf * 2, ngf * 4), (ngf, ngf * 2)):
            block = UnetSkipConnectionBlock(
                outer_nc, inner_nc, input_nc=None, submodule=block, norm_layer=norm_layer
            )
        # Outermost block maps input_nc -> output_nc.
        self.model = UnetSkipConnectionBlock(
            output_nc, ngf, input_nc=input_nc, submodule=block, outermost=True, norm_layer=norm_layer
        )

    def forward(self, input):
        return self.model(input)
class UnetSkipConnectionBlock(nn.Module):
    """Defines the Unet submodule with skip connection.
        -------------------identity----------------------
        |-- downsampling -- |submodule| -- upsampling --|
    """

    def __init__(
        self,
        outer_nc,
        inner_nc,
        input_nc=None,
        submodule=None,
        outermost=False,
        innermost=False,
        norm_layer=nn.BatchNorm2d,
        use_dropout=False,
    ):
        """Construct a Unet submodule with skip connections.

        Parameters:
            outer_nc (int) -- number of filters in the outer conv layer
            inner_nc (int) -- number of filters in the inner conv layer
            input_nc (int) -- channels of the input images/features (defaults to outer_nc)
            submodule (UnetSkipConnectionBlock) -- previously defined submodule
            outermost (bool) -- this is the outermost block (no skip, ends in Tanh)
            innermost (bool) -- this is the innermost block (no submodule)
            norm_layer -- normalization layer class
            use_dropout (bool) -- append Dropout(0.5) (intermediate blocks only)
        """
        super().__init__()
        self.outermost = outermost
        # Convs only carry a bias when the norm layer is InstanceNorm2d.
        use_bias = norm_layer == nn.InstanceNorm2d
        in_nc = outer_nc if input_nc is None else input_nc

        # Module creation order is kept fixed (it determines RNG consumption
        # during weight initialization).
        downconv = nn.Conv2d(in_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.LeakyReLU(0.2, True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
            layers = [downconv, submodule, uprelu, upconv, nn.Tanh()]
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            layers = [downrelu, downconv, uprelu, upconv, upnorm]
        else:
            upconv = nn.ConvTranspose2d(
                inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias
            )
            layers = [downrelu, downconv, downnorm, submodule, uprelu, upconv, upnorm]
            if use_dropout:
                layers.append(nn.Dropout(0.5))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        # Every non-outermost block concatenates its input (the skip) with its output.
        return torch.cat([x, self.model(x)], 1)
# ============================================
# Network testing
# ============================================
if __name__ == "__main__":
    from torchsummary import summary

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # This smoke test previously built ``UNet_two_decoders``, which does not
    # exist in this module; exercise the UNet defined above instead.
    model = UNet(
        in_channels=3,
        out_channels=3,
        depth=4,
        conv_num=1,
        wf=6,
        padding=True,
        batch_norm=True,
        up_mode="upsample",
        with_tanh=False,
    )
    model.to(device)

    model_pix2pix = UnetGenerator(3, 3, 5, ngf=64, norm_type="BN", use_dropout=False)
    model_pix2pix.to(device)

    # Pass the device explicitly so the summary also works on CPU-only machines.
    print("customized unet:")
    summary(model, (3, 256, 256), device=device.type)

    print("cyclegan unet:")
    summary(model_pix2pix, (3, 256, 256), device=device.type)

    # Graph rendering used ``make_dot`` (torchviz), which was never imported
    # here; just run a forward pass on the selected device instead.
    x = torch.zeros(1, 3, 256, 256, device=device).requires_grad_(True)
    y = model(x)
    print("customized unet output shape:", tuple(y.shape))
================================================
FILE: Global/detection_util/util.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import sys
import time
import shutil
import platform
import numpy as np
from datetime import datetime
import torch
import torchvision as tv
import torch.backends.cudnn as cudnn
# from torch.utils.tensorboard import SummaryWriter
import yaml
import matplotlib.pyplot as plt
from easydict import EasyDict as edict
import torchvision.utils as vutils
##### option parsing ######
def print_options(config_dict):
    """Print every ``key: value`` pair of *config_dict*, sorted by key, between banner lines."""
    print("------------ Options -------------")
    for key in sorted(config_dict):
        print("%s: %s" % (key, config_dict[key]))
    print("-------------- End ----------------")
def save_options(config_dict):
    """Write the options to ``<checkpoint_dir>/<name>/opt.txt``.

    The first line records the script name and the current UTC time; the
    rest mirrors ``print_options`` (sorted ``key: value`` lines between banners).
    """
    from time import gmtime, strftime

    file_dir = os.path.join(config_dict["checkpoint_dir"], config_dict["name"])
    mkdir_if_not(file_dir)
    header = os.path.basename(sys.argv[0]) + " " + strftime("%Y-%m-%d %H:%M:%S", gmtime()) + "\n"
    body = ["%s: %s\n" % (str(key), str(value)) for key, value in sorted(config_dict.items())]
    with open(os.path.join(file_dir, "opt.txt"), "wt") as opt_file:
        opt_file.write(header)
        opt_file.write("------------ Options -------------\n")
        opt_file.writelines(body)
        opt_file.write("-------------- End ----------------\n")
def config_parse(config_file, options, save=True):
    """Merge a YAML config file with command-line options.

    Args:
        config_file: path of the YAML file.
        options: ``argparse.Namespace``; each attribute overrides the key of
            the same name from the YAML file.
        save: also write the merged options to disk via ``save_options``.

    Returns:
        An ``EasyDict`` with the merged options. In ``debug_mode`` the number
        of workers is forced to 0 and the batch size to 2. A comma-separated
        ``gpu_ids`` string is reduced to its first id as an ``int``.
    """
    with open(config_file, "r") as stream:
        # An empty YAML file loads as None; treat it as "no options".
        config_dict = yaml.safe_load(stream) or {}
    config = edict(config_dict)

    for option_key, option_value in vars(options).items():
        config_dict[option_key] = option_value
        config[option_key] = option_value

    # config_dict is what gets printed and saved, so every override below is
    # mirrored into it; previously the printed/saved batch_size and gpu_ids
    # disagreed with the values actually used.
    if config.debug_mode:
        config_dict["num_workers"] = 0
        config.num_workers = 0
        config_dict["batch_size"] = 2
        config.batch_size = 2

    if isinstance(config.gpu_ids, str):
        config.gpu_ids = [int(x) for x in config.gpu_ids.split(",")][0]
        config_dict["gpu_ids"] = config.gpu_ids

    print_options(config_dict)
    if save:
        save_options(config_dict)

    return config
###### utility ######
def to_np(x):
    """Copy tensor *x* to host memory and return it as a NumPy array."""
    host_tensor = x.cpu()
    return host_tensor.numpy()
def prepare_device(use_gpu, gpu_ids):
    """Select the torch device to run on.

    Args:
        use_gpu: when falsy the CPU device is returned and ``gpu_ids`` is ignored.
        gpu_ids: GPU index as an ``int``, a comma-separated string such as
            ``"0,1"``, or a list/tuple of indices. Only the first index is made
            the current CUDA device and returned. Lists and tuples are now
            accepted directly; they used to be passed unchanged to
            ``torch.cuda.set_device``, which expects a single device.

    Returns:
        ``torch.device`` for the CPU or the selected GPU.
    """
    if not use_gpu:
        device = torch.device("cpu")
        print("running on CPU")
        return device

    cudnn.benchmark = True
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    if isinstance(gpu_ids, str):
        gpu_ids = [int(x) for x in gpu_ids.split(",")]
    primary = gpu_ids[0] if isinstance(gpu_ids, (list, tuple)) else gpu_ids
    torch.cuda.set_device(primary)
    device = torch.device("cuda:" + str(primary))
    print("running on GPU {}".format(gpu_ids))
    return device
###### file system ######
def get_dir_size(start_path="."):
    """Return the total size in bytes of all files below *start_path* (recursively)."""
    return sum(
        os.path.getsize(os.path.join(root, file_name))
        for root, _subdirs, file_names in os.walk(start_path)
        for file_name in file_names
    )
def mkdir_if_not(dir_path):
    """Create *dir_path* (and any missing parents) unless it already exists.

    Creating first and ignoring ``FileExistsError`` avoids the race in
    check-then-create, where another process (e.g. a DataLoader worker) could
    create the directory between the check and ``os.makedirs``. As before,
    an existing path of any kind is left untouched.
    """
    try:
        os.makedirs(dir_path)
    except FileExistsError:
        pass
##### System related ######
class Timer:
    """Context manager that prints the elapsed time of its body.

    ``msg`` is a %-format string with one numeric placeholder, e.g.
    ``with Timer("epoch took %.2fs"): ...``. The measured seconds are also
    stored as ``self.elapsed``.
    """

    def __init__(self, msg):
        self.msg = msg
        self.start_time = None
        self.elapsed = None

    def __enter__(self):
        # perf_counter is monotonic, so the interval cannot be distorted by
        # system clock adjustments (unlike time.time()).
        self.start_time = time.perf_counter()
        # Returning self enables ``with Timer(...) as t`` (it used to be None).
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        self.elapsed = time.perf_counter() - self.start_time
        print(self.msg % self.elapsed)
###### interactive ######
def get_size(start_path="."):
    """Return the size in bytes of *start_path*.

    A regular file reports its own size. A directory reports the sum of all
    files below it, found recursively. ``os.walk`` yields nothing for a file
    path, so files used to report 0 bytes. That made ``clean_tensorboard``
    delete every file entry regardless of its real size.
    """
    if os.path.isfile(start_path):
        return os.path.getsize(start_path)
    total_size = 0
    for dirpath, _dirnames, filenames in os.walk(start_path):
        for f in filenames:
            total_size += os.path.getsize(os.path.join(dirpath, f))
    return total_size
def clean_tensorboard(directory):
    """Delete near-empty TensorBoard runs from *directory*.

    Every entry for which ``get_size`` reports fewer than 100000 bytes is
    removed: directories with ``shutil.rmtree``, anything else with ``os.remove``.
    """
    size_thresh = 100000
    for entry in os.listdir(directory):
        entry_path = os.path.join(directory, entry)
        if get_size(entry_path) >= size_thresh:
            continue
        print("deleting the empty tensorboard: ", entry_path)
        remove = shutil.rmtree if os.path.isdir(entry_path) else os.remove
        remove(entry_path)
def prepare_tensorboard(config, experiment_name=None):
    """Create a TensorBoard ``SummaryWriter`` for this run.

    Logs go to ``<checkpoint_dir>/<name>/tensorboard_logs/<experiment_name>``.
    Near-empty runs in that folder are deleted first (see ``clean_tensorboard``).

    Args:
        config: options object with ``checkpoint_dir`` and ``name``.
        experiment_name: sub-folder for this run. Defaults to the current
            local time at call time ("%Y-%m-%d %H-%M-%S"). The old default was
            evaluated once at import, so every run in a process shared it.

    Returns:
        The ``SummaryWriter`` (flushing every 10 seconds).
    """
    # Imported here because the module-level import is commented out, which
    # previously made this function fail with NameError.
    from torch.utils.tensorboard import SummaryWriter

    if experiment_name is None:
        experiment_name = datetime.now().strftime("%Y-%m-%d %H-%M-%S")

    tensorboard_directory = os.path.join(config.checkpoint_dir, config.name, "tensorboard_logs")
    mkdir_if_not(tensorboard_directory)
    clean_tensorboard(tensorboard_directory)
    tb_writer = SummaryWriter(os.path.join(tensorboard_directory, experiment_name), flush_secs=10)
    return tb_writer
def tb_loss_logger(tb_writer, iter_index, loss_logger):
    """Write every loss in *loss_logger* (tag -> scalar tensor) as a TensorBoard scalar at step *iter_index*."""
    for name, loss in loss_logger.items():
        scalar = loss.item()
        tb_writer.add_scalar(name, scalar_value=scalar, global_step=iter_index)
def tb_image_logger(tb_writer, iter_index, images_info, config):
    """Log image grids to TensorBoard and save each one as a JPEG.

    Entries tagged ``"test_image_prediction"`` or ``"image_prediction"`` are
    skipped. Every other image batch is tiled with ``make_grid``, clamped to
    [0, 1], sent to *tb_writer* and saved to
    ``<output_dir>/<name>/<train_mode>/<iter_index:06d>_<tag>.jpg``.
    """
    tb_logger_path = os.path.join(config.output_dir, config.name, config.train_mode)
    mkdir_if_not(tb_logger_path)
    skipped_tags = ("test_image_prediction", "image_prediction")
    for tag, image in images_info.items():
        if tag in skipped_tags:
            continue
        grid = torch.clamp(tv.utils.make_grid(image.cpu()), 0, 1)
        tb_writer.add_image(tag, img_tensor=grid, global_step=iter_index)
        out_file = os.path.join(tb_logger_path, "{:06d}_{}.jpg".format(iter_index, tag))
        tv.transforms.functional.to_pil_image(grid).save(out_file)
def tb_image_logger_test(epoch, iter, images_info, config):
    """Save a validation grid: inputs, thresholded predictions and ground-truth masks.

    The three batches are stacked row-wise (one image per column) and written
    to ``<output_dir>/<name>/<train_mode>/val_<epoch>/<iter>.jpg``.
    """
    url = os.path.join(config.output_dir, config.name, config.train_mode, "val_" + str(epoch))
    # exist_ok removes the race between an existence check and creation.
    os.makedirs(url, exist_ok=True)

    scratch_img = images_info["test_scratch_image"].data.cpu()
    if config.norm_input:
        # Presumably inputs were normalized to [-1, 1]; map back to [0, 1].
        scratch_img = (scratch_img + 1.0) / 2.0
    scratch_img = torch.clamp(scratch_img, 0, 1)
    gt_mask = images_info["test_mask_image"].data.cpu()
    predict_mask = images_info["test_scratch_prediction"].data.cpu()
    predict_hard_mask = (predict_mask >= 0.5).float()

    imgs = torch.cat((scratch_img, predict_hard_mask, gt_mask), 0)
    # save_image returns None; its result was previously bound to an unused name.
    vutils.save_image(
        imgs, os.path.join(url, str(iter) + ".jpg"), nrow=len(scratch_img), padding=0, normalize=True
    )
def imshow(input_image, title=None, to_numpy=False):
    """Draw an image into a new matplotlib figure with hidden axes.

    A ``torch.Tensor`` (or any input when ``to_numpy`` is set) is converted
    with ``.numpy()``. 2-D arrays are shown in grayscale with limits [0, 255].
    Other arrays are treated as channel-first, transposed to HWC and cast to uint8.
    """
    needs_conversion = to_numpy or type(input_image) is torch.Tensor
    arr = input_image.numpy() if needs_conversion else input_image
    plt.figure()
    if arr.ndim == 2:
        shown = plt.imshow(arr, cmap="gray", clim=[0, 255])
    else:
        shown = plt.imshow(np.transpose(arr, [1, 2, 0]).astype(np.uint8))
    plt.axis("off")
    shown.axes.get_xaxis().set_visible(False)
    shown.axes.get_yaxis().set_visible(False)
    plt.title(title)
###### vgg preprocessing ######
def vgg_preprocess(tensor):
    """Convert an RGB batch in [0, 1] to Caffe-VGG style input.

    Channels 0..2 are reordered to BGR, the per-channel BGR mean
    (0.40760392, 0.45795686, 0.48501961) is subtracted, and the result is
    scaled by 255.
    """
    bgr = torch.cat([tensor[:, c : c + 1, :, :] for c in (2, 1, 0)], dim=1)
    bgr_mean = torch.Tensor([0.40760392, 0.45795686, 0.48501961]).type_as(bgr).view(1, 3, 1, 1)
    return (bgr - bgr_mean) * 255
def torch_vgg_preprocess(tensor):
    """Normalize an RGB batch in [0, 1] with the torchvision ImageNet mean/std.

    Input and output are both RGB; each channel becomes ``(x - mean) / std``.
    """
    mean = torch.Tensor([0.485, 0.456, 0.406]).type_as(tensor).view(1, 3, 1, 1)
    std = torch.Tensor([0.229, 0.224, 0.225]).type_as(tensor).view(1, 3, 1, 1)
    return (tensor - mean) / std
def network_gradient(net, gradient_on=True):
    """Set ``requires_grad`` on every parameter of *net* to ``bool(gradient_on)`` and return *net*."""
    flag = bool(gradient_on)
    for param in net.parameters():
        param.requires_grad = flag
    return net
================================================
FILE: Global/models/NonLocal_feature_mapping_model.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import functools
from torch.autograd import Variable
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
import math
class Mapping_Model_with_mask(nn.Module):
    """Feature-space mapping network with a mask-aware non-local block.

    Pipeline: ``before_NL`` -> ``NL(x, mask)`` -> ``after_NL``.

    * ``before_NL``: ``n_up`` (=4) units of 3x3 conv + norm + ReLU; channel
      counts start at 64, double each unit and are capped at ``mc``.
    * ``NL``: ``networks.NonLocalBlock2D_with_mask_Res`` configured from
      ``opt``. It is created only when ``opt.NL_res`` is true.
    * ``after_NL``: ``n_blocks`` ``networks.ResnetBlock``s, then ``n_up - 1``
      conv + norm + ReLU units that shrink the channel count, a 3x3 conv
      from ``tmp_nc * 2`` (=128) to 64 channels and, when
      ``0 < opt.feat_dim < 64``, a final norm + ReLU + 1x1 projection to
      ``opt.feat_dim`` channels.

    NOTE(review): with ``opt.NL_res`` false ``self.NL`` is never defined, so
    ``forward`` would raise AttributeError.
    NOTE(review): the last shrinking unit outputs ``min(128, mc)`` channels
    while the next conv expects 128. This composes only for ``mc >= 128``;
    the default ``mc=64`` looks inconsistent -- confirm.
    The layer order fixes the ``nn.Sequential`` indices, and therefore the
    state-dict keys of saved checkpoints. Do not reorder.
    """

    def __init__(self, nc, mc=64, n_blocks=3, norm="instance", padding_type="reflect", opt=None):
        # ``nc`` is not used by this constructor.
        super(Mapping_Model_with_mask, self).__init__()

        norm_layer = networks.get_norm_layer(norm_type=norm)
        activation = nn.ReLU(True)
        model = []

        tmp_nc = 64
        n_up = 4

        # Widen: 64 -> 128 -> 256 -> 512 -> ... (each capped at mc).
        for i in range(n_up):
            ic = min(tmp_nc * (2 ** i), mc)
            oc = min(tmp_nc * (2 ** (i + 1)), mc)
            model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation]

        self.before_NL = nn.Sequential(*model)

        if opt.NL_res:
            self.NL = networks.NonLocalBlock2D_with_mask_Res(
                mc,
                mc,
                opt.NL_fusion_method,
                opt.correlation_renormalize,
                opt.softmax_temperature,
                opt.use_self,
                opt.cosin_similarity,
            )
            print("You are using NL + Res")

        model = []
        for i in range(n_blocks):
            model += [
                networks.ResnetBlock(
                    mc,
                    padding_type=padding_type,
                    activation=activation,
                    norm_layer=norm_layer,
                    opt=opt,
                    dilation=opt.mapping_net_dilation,
                )
            ]

        # Narrow back down: min(1024, mc) -> ... -> min(128, mc).
        for i in range(n_up - 1):
            ic = min(64 * (2 ** (4 - i)), mc)
            oc = min(64 * (2 ** (3 - i)), mc)
            model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation]
        model += [nn.Conv2d(tmp_nc * 2, tmp_nc, 3, 1, 1)]
        if opt.feat_dim > 0 and opt.feat_dim < 64:
            model += [norm_layer(tmp_nc), activation, nn.Conv2d(tmp_nc, opt.feat_dim, 1, 1)]
        # model += [nn.Conv2d(64, 1, 1, 1, 0)]
        self.after_NL = nn.Sequential(*model)

    def forward(self, input, mask):
        # The ``del`` statements drop references to intermediates as soon as
        # they are consumed, presumably to lower peak memory.
        x1 = self.before_NL(input)
        del input
        x2 = self.NL(x1, mask)
        del x1, mask
        x3 = self.after_NL(x2)
        del x2

        return x3
class Mapping_Model_with_mask_2(nn.Module): ## Multi-Scale Patch Attention
def __init__(self, nc, mc=64, n_blocks=3, norm="instance", padding_type="reflect", opt=None):
super(Mapping_Model_with_mask_2, self).__init__()
norm_layer = networks.get_norm_layer(norm_type=norm)
activation = nn.ReLU(True)
model = []
tmp_nc = 64
n_up = 4
for i in range(n_up):
ic = min(tmp_nc * (2 ** i), mc)
oc = min(tmp_nc * (2 ** (i + 1)), mc)
model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation]
for i in range(2):
model += [
networks.ResnetBlock(
mc,
padding_type=padding_type,
activation=activation,
norm_layer=norm_layer,
opt=opt,
dilation=opt.mapping_net_dilation,
)
]
print("Mapping: You are using multi-scale patch attention, conv combine + mask input")
self.before_NL = nn.Sequential(*model)
if opt.mapping_exp==1:
self.NL_scale_1=networks.Patch_Attention_4(mc,mc,8)
model = []
for i in range(2):
model += [
networks.ResnetBlock(
mc,
padding_type=padding_type,
activation=activation,
norm_layer=norm_layer,
opt=opt,
dilation=opt.mapping_net_dilation,
)
]
self.res_block_1 = nn.Sequential(*model)
if opt.mapping_exp==1:
self.NL_scale_2=networks.Patch_Attention_4(mc,mc,4)
model = []
for i in range(2):
model += [
networks.ResnetBlock(
mc,
padding_type=padding_type,
activation=activation,
norm_layer=norm_layer,
opt=opt,
dilation=opt.mapping_net_dilation,
)
]
self.res_block_2 = nn.Sequential(*model)
if opt.mapping_exp==1:
self.NL_scale_3=networks.Patch_Attention_4(mc,mc,2)
# self.NL_scale_3=networks.Patch_Attention_2(mc,mc,2)
model = []
for i in range(2):
model += [
networks.ResnetBlock(
mc,
padding_type=padding_type,
activation=activation,
norm_layer=norm_layer,
opt=opt,
dilation=opt.mapping_net_dilation,
)
]
for i in range(n_up - 1):
ic = min(64 * (2 ** (4 - i)), mc)
oc = min(64 * (2 ** (3 - i)), mc)
model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation]
model += [nn.Conv2d(tmp_nc * 2, tmp_nc, 3, 1, 1)]
if opt.feat_dim > 0 and opt.feat_dim < 64:
model += [norm_layer(tmp_nc), activation, nn.Conv2d(tmp_nc, opt.feat_dim, 1, 1)]
# model += [nn.Conv2d(64, 1, 1, 1, 0)]
self.after_NL = nn.Sequential(*model)
def forward(self, input, mask):
    """Run the full mapping pipeline on ``input`` guided by ``mask``.

    ``before_NL`` -> (``NL_scale_1`` -> ``res_block_1``) ->
    (``NL_scale_2`` -> ``res_block_2``) -> ``NL_scale_3`` -> ``after_NL``.
    The mask is passed only to the ``NL_scale_*`` attention modules.
    """
    feat = self.before_NL(input)
    stages = (
        (self.NL_scale_1, self.res_block_1),
        (self.NL_scale_2, self.res_block_2),
    )
    for attend, refine in stages:
        feat = refine(attend(feat, mask))
    feat = self.NL_scale_3(feat, mask)
    return self.after_NL(feat)
def inference_forward(self, input, mask):
    """Inference-time variant of ``forward``.

    Each ``NL_scale_*`` module is invoked through its own
    ``inference_forward(x, mask)``. A single working variable is rebound at
    every step, so each intermediate result becomes unreferenced as soon as
    the next one has been produced. This keeps peak memory low.
    """
    feat = self.before_NL(input)
    del input
    stages = (
        (self.NL_scale_1, self.res_block_1),
        (self.NL_scale_2, self.res_block_2),
    )
    for attention, res_block in stages:
        feat = attention.inference_forward(feat, mask)
        feat = res_block(feat)
    feat = self.NL_scale_3.inference_forward(feat, mask)
    return self.after_NL(feat)
================================================
FILE: Global/models/__init__.py
================================================
================================================
FILE: Global/models/base_model.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import torch
import sys
class BaseModel(torch.nn.Module):
    """Base class for the models in this package.

    Provides option bookkeeping (``initialize``), no-op hooks that subclasses
    override, and helpers that save/load network and optimizer checkpoints
    named ``<epoch_label>_net_<label>.pth`` / ``<epoch_label>_optimizer_<label>.pth``
    inside ``self.save_dir`` (``<opt.checkpoints_dir>/<opt.name>``).
    """

    def name(self):
        return "BaseModel"

    def initialize(self, opt):
        """Store ``opt`` and derive device and checkpoint-directory settings from it."""
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        # CUDA tensor type when any GPU id is configured, CPU tensor type otherwise.
        self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)

    def set_input(self, input):
        self.input = input

    def forward(self):
        pass

    def test(self):
        """Test-time hook (no backprop); no-op in the base class."""
        pass

    def get_image_paths(self):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        return self.input

    def get_current_errors(self):
        return {}

    def save(self, label):
        pass

    def save_network(self, network, network_label, epoch_label, gpu_ids):
        """Save ``network``'s state dict to ``<epoch_label>_net_<network_label>.pth``.

        The network is moved to CPU before its state dict is taken. It is
        moved back with ``.cuda()`` afterwards when ``gpu_ids`` is non-empty
        and CUDA is available.
        """
        save_filename = "%s_net_%s.pth" % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save(network.cpu().state_dict(), save_path)
        if len(gpu_ids) and torch.cuda.is_available():
            network.cuda()

    def save_optimizer(self, optimizer, optimizer_label, epoch_label):
        """Save ``optimizer``'s state dict to ``<epoch_label>_optimizer_<optimizer_label>.pth``."""
        save_filename = "%s_optimizer_%s.pth" % (epoch_label, optimizer_label)
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save(optimizer.state_dict(), save_path)

    def load_optimizer(self, optimizer, optimizer_label, epoch_label, save_dir=""):
        """Load an optimizer checkpoint from ``save_dir`` (default ``self.save_dir``).

        A missing file is reported on stdout and otherwise ignored.
        """
        save_filename = "%s_optimizer_%s.pth" % (epoch_label, optimizer_label)
        if not save_dir:
            save_dir = self.save_dir
        save_path = os.path.join(save_dir, save_filename)
        if not os.path.isfile(save_path):
            print("%s not exists yet!" % save_path)
        else:
            optimizer.load_state_dict(torch.load(save_path))

    def load_network(self, network, network_label, epoch_label, save_dir=""):
        """Load ``<epoch_label>_net_<network_label>.pth`` into ``network``.

        Looks in ``save_dir`` (default ``self.save_dir``). A missing file is
        reported on stdout and otherwise ignored. If the checkpoint does not
        match ``network`` exactly, the weights are loaded as far as possible
        instead of failing (see ``_load_partial_state_dict``).
        """
        save_filename = "%s_net_%s.pth" % (epoch_label, network_label)
        if not save_dir:
            save_dir = self.save_dir
        save_path = os.path.join(save_dir, save_filename)
        if not os.path.isfile(save_path):
            print("%s not exists yet!" % save_path)
            return
        # Read the checkpoint once; the fallback path reuses it.
        pretrained_dict = torch.load(save_path)
        try:
            network.load_state_dict(pretrained_dict)
        except Exception:  # strict load failed: keys or shapes differ
            self._load_partial_state_dict(network, network_label, pretrained_dict)

    def _load_partial_state_dict(self, network, network_label, pretrained_dict):
        """Best-effort load of a checkpoint that does not match ``network`` exactly.

        First drops checkpoint entries whose keys ``network`` lacks and retries.
        If that still fails, copies only the entries whose tensor sizes match.
        It then prints the sorted top-level module names left uninitialized.
        """
        model_dict = network.state_dict()
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        try:
            network.load_state_dict(pretrained_dict)
        except Exception:  # network still has entries missing or mis-shaped
            print(
                "Pretrained network %s has fewer layers; The following are not initialized:"
                % network_label
            )
        else:
            print(
                "Pretrained network %s has excessive layers; Only loading layers that are used"
                % network_label
            )
            return

        for k, v in pretrained_dict.items():
            if v.size() == model_dict[k].size():
                model_dict[k] = v
        # Builtin set() works on every supported Python; the old sets.Set path was dead code.
        not_initialized = {
            k.split(".")[0]
            for k, v in model_dict.items()
            if k not in pretrained_dict or v.size() != pretrained_dict[k].size()
        }
        print(sorted(not_initialized))
        network.load_state_dict(model_dict)

    def update_learning_rate(self):
        """Learning-rate schedule hook; no-op in the base class.

        Declared with ``self`` so ``model.update_learning_rate()`` is callable
        (previously it raised TypeError).
        """
        pass
================================================
FILE: Global/models/mapping_model.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import functools
from torch.autograd import Variable
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
import math
from .NonLocal_feature_mapping_model import *
class Mapping_Model(nn.Module):
    """Plain convolutional feature-mapping network (no attention).

    ``self.model`` is a single ``nn.Sequential`` made of:
    four 3x3 conv/norm/ReLU stages widening the channels through
    ``min(64 * 2**k, mc)`` for k = 0..4; ``n_blocks`` ResnetBlocks at width
    ``mc``; three 3x3 conv/norm/ReLU stages narrowing back down; a final
    3x3 conv to 64 channels; and, when ``0 < opt.feat_dim < 64``, a
    norm/ReLU/1x1-conv projection to ``opt.feat_dim`` channels.

    ``nc`` is accepted for interface compatibility but unused. ``opt`` is
    dereferenced (``mapping_net_dilation``, ``feat_dim``), so it must be given.
    NOTE(review): the final conv expects 128 input channels, i.e. mc >= 128.
    """

    def __init__(self, nc, mc=64, n_blocks=3, norm="instance", padding_type="reflect", opt=None):
        super(Mapping_Model, self).__init__()
        norm_layer = networks.get_norm_layer(norm_type=norm)
        activation = nn.ReLU(True)  # one instance shared by every stage
        base_nc = 64
        n_up = 4
        print("Mapping: You are using the mapping model without global restoration.")
        widths = [min(base_nc * 2 ** level, mc) for level in range(n_up + 1)]

        layers = []
        # Widen: widths[0] -> widths[1] -> ... -> widths[4].
        for in_c, out_c in zip(widths[:-1], widths[1:]):
            layers += [nn.Conv2d(in_c, out_c, 3, 1, 1), norm_layer(out_c), activation]
        layers += [
            networks.ResnetBlock(
                mc,
                padding_type=padding_type,
                activation=activation,
                norm_layer=norm_layer,
                opt=opt,
                dilation=opt.mapping_net_dilation,
            )
            for _ in range(n_blocks)
        ]
        # Narrow: widths[4] -> widths[3] -> widths[2] -> widths[1].
        for level in range(n_up, 1, -1):
            out_c = widths[level - 1]
            layers += [nn.Conv2d(widths[level], out_c, 3, 1, 1), norm_layer(out_c), activation]
        layers.append(nn.Conv2d(base_nc * 2, base_nc, 3, 1, 1))
        if 0 < opt.feat_dim < 64:
            layers += [norm_layer(base_nc), activation, nn.Conv2d(base_nc, opt.feat_dim, 1, 1)]
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Apply the mapping network to ``input``."""
        return self.model(input)
class Pix2PixHDModel_Mapping(BaseModel):
def name(self):
    """Return the identifier string of this model."""
    model_name = "Pix2PixHDModel_Mapping"
    return model_name
def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss, use_smooth_l1, stage_1_feat_l2):
    """Return a function that keeps only the enabled losses, in order.

    The returned ``loss_filter`` takes the eight losses positionally and
    returns a list. ``g_feat_l2``, ``g_gan``, ``d_real`` and ``d_fake`` are
    always kept. ``g_gan_feat``, ``g_vgg``, ``smooth_l1`` and
    ``stage_1_feat_l2`` are kept only when the matching flag given here is
    truthy.
    """
    enabled = (True, True, use_gan_feat_loss, use_vgg_loss, True, True, use_smooth_l1, stage_1_feat_l2)

    def loss_filter(g_feat_l2, g_gan, g_gan_feat, g_vgg, d_real, d_fake, smooth_l1, stage_1_feat_l2):
        losses = (g_feat_l2, g_gan, g_gan_feat, g_vgg, d_real, d_fake, smooth_l1, stage_1_feat_l2)
        kept = []
        for loss, keep in zip(losses, enabled):
            if keep:
                kept.append(loss)
        return kept

    return loss_filter
def initialize(self, opt):
BaseModel.initialize(self, opt)
if opt.resize_or_crop != "none" or not opt.isTrain:
torch.backends.cudnn.benchmark = True
self.isTrain = opt.isTrain
input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc
##### define networks
# Generator network
netG_input_nc = input_nc
self.netG_A = networks.GlobalGenerator_DCDCv2
gitextract_ojvoon3f/ ├── .gitignore ├── CODE_OF_CONDUCT.md ├── Dockerfile ├── Face_Detection/ │ ├── align_warp_back_multiple_dlib.py │ ├── align_warp_back_multiple_dlib_HR.py │ ├── detect_all_dlib.py │ └── detect_all_dlib_HR.py ├── Face_Enhancement/ │ ├── data/ │ │ ├── __init__.py │ │ ├── base_dataset.py │ │ ├── custom_dataset.py │ │ ├── face_dataset.py │ │ ├── image_folder.py │ │ └── pix2pix_dataset.py │ ├── models/ │ │ ├── __init__.py │ │ ├── networks/ │ │ │ ├── __init__.py │ │ │ ├── architecture.py │ │ │ ├── base_network.py │ │ │ ├── encoder.py │ │ │ ├── generator.py │ │ │ └── normalization.py │ │ └── pix2pix_model.py │ ├── options/ │ │ ├── __init__.py │ │ ├── base_options.py │ │ └── test_options.py │ ├── requirements.txt │ ├── test_face.py │ └── util/ │ ├── __init__.py │ ├── iter_counter.py │ ├── util.py │ └── visualizer.py ├── GUI.py ├── Global/ │ ├── data/ │ │ ├── Create_Bigfile.py │ │ ├── Load_Bigfile.py │ │ ├── __init__.py │ │ ├── base_data_loader.py │ │ ├── base_dataset.py │ │ ├── custom_dataset_data_loader.py │ │ ├── data_loader.py │ │ ├── image_folder.py │ │ └── online_dataset_for_old_photos.py │ ├── detection.py │ ├── detection_models/ │ │ ├── __init__.py │ │ ├── antialiasing.py │ │ └── networks.py │ ├── detection_util/ │ │ └── util.py │ ├── models/ │ │ ├── NonLocal_feature_mapping_model.py │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── mapping_model.py │ │ ├── models.py │ │ ├── networks.py │ │ ├── pix2pixHD_model.py │ │ └── pix2pixHD_model_DA.py │ ├── options/ │ │ ├── __init__.py │ │ ├── base_options.py │ │ ├── test_options.py │ │ └── train_options.py │ ├── test.py │ ├── train_domain_A.py │ ├── train_domain_B.py │ ├── train_mapping.py │ └── util/ │ ├── __init__.py │ ├── image_pool.py │ ├── util.py │ └── visualizer.py ├── LICENSE ├── README.md ├── SECURITY.md ├── ansible.yaml ├── cog.yaml ├── download-weights ├── kubernetes-pod.yml ├── predict.py ├── requirements.txt └── run.py
SYMBOL INDEX (447 symbols across 52 files)
FILE: Face_Detection/align_warp_back_multiple_dlib.py
function calculate_cdf (line 26) | def calculate_cdf(histogram):
function calculate_lookup (line 42) | def calculate_lookup(src_cdf, ref_cdf):
function match_histograms (line 62) | def match_histograms(src_image, ref_image):
function _standard_face_pts (line 112) | def _standard_face_pts():
function _origin_face_pts (line 121) | def _origin_face_pts():
function compute_transformation_matrix (line 127) | def compute_transformation_matrix(img, landmark, normalize, target_face_...
function compute_inverse_transformation_matrix (line 148) | def compute_inverse_transformation_matrix(img, landmark, normalize, targ...
function show_detection (line 169) | def show_detection(image, box, landmark):
function affine2theta (line 185) | def affine2theta(affine, input_w, input_h, target_w, target_h):
function blur_blending (line 198) | def blur_blending(im1, im2, mask):
function blur_blending_cv2 (line 217) | def blur_blending_cv2(im1, im2, mask):
function Poisson_blending (line 239) | def Poisson_blending(im1, im2, mask):
function Poisson_B (line 259) | def Poisson_B(im1, im2, mask, center):
function seamless_clone (line 270) | def seamless_clone(old_face, new_face, raw_mask):
function get_landmark (line 309) | def get_landmark(face_landmarks, id):
function search (line 317) | def search(face_landmarks):
FILE: Face_Detection/align_warp_back_multiple_dlib_HR.py
function calculate_cdf (line 26) | def calculate_cdf(histogram):
function calculate_lookup (line 42) | def calculate_lookup(src_cdf, ref_cdf):
function match_histograms (line 62) | def match_histograms(src_image, ref_image):
function _standard_face_pts (line 112) | def _standard_face_pts():
function _origin_face_pts (line 121) | def _origin_face_pts():
function compute_transformation_matrix (line 127) | def compute_transformation_matrix(img, landmark, normalize, target_face_...
function compute_inverse_transformation_matrix (line 148) | def compute_inverse_transformation_matrix(img, landmark, normalize, targ...
function show_detection (line 169) | def show_detection(image, box, landmark):
function affine2theta (line 185) | def affine2theta(affine, input_w, input_h, target_w, target_h):
function blur_blending (line 198) | def blur_blending(im1, im2, mask):
function blur_blending_cv2 (line 217) | def blur_blending_cv2(im1, im2, mask):
function Poisson_blending (line 239) | def Poisson_blending(im1, im2, mask):
function Poisson_B (line 259) | def Poisson_B(im1, im2, mask, center):
function seamless_clone (line 270) | def seamless_clone(old_face, new_face, raw_mask):
function get_landmark (line 309) | def get_landmark(face_landmarks, id):
function search (line 317) | def search(face_landmarks):
FILE: Face_Detection/detect_all_dlib.py
function _standard_face_pts (line 27) | def _standard_face_pts():
function _origin_face_pts (line 36) | def _origin_face_pts():
function get_landmark (line 42) | def get_landmark(face_landmarks, id):
function search (line 50) | def search(face_landmarks):
function compute_transformation_matrix (line 80) | def compute_transformation_matrix(img, landmark, normalize, target_face_...
function show_detection (line 101) | def show_detection(image, box, landmark):
function affine2theta (line 117) | def affine2theta(affine, input_w, input_h, target_w, target_h):
FILE: Face_Detection/detect_all_dlib_HR.py
function _standard_face_pts (line 27) | def _standard_face_pts():
function _origin_face_pts (line 36) | def _origin_face_pts():
function get_landmark (line 42) | def get_landmark(face_landmarks, id):
function search (line 50) | def search(face_landmarks):
function compute_transformation_matrix (line 80) | def compute_transformation_matrix(img, landmark, normalize, target_face_...
function show_detection (line 101) | def show_detection(image, box, landmark):
function affine2theta (line 117) | def affine2theta(affine, input_w, input_h, target_w, target_h):
FILE: Face_Enhancement/data/__init__.py
function create_dataloader (line 10) | def create_dataloader(opt):
FILE: Face_Enhancement/data/base_dataset.py
class BaseDataset (line 11) | class BaseDataset(data.Dataset):
method __init__ (line 12) | def __init__(self):
method modify_commandline_options (line 16) | def modify_commandline_options(parser, is_train):
method initialize (line 19) | def initialize(self, opt):
function get_params (line 23) | def get_params(opt, size):
function get_transform (line 45) | def get_transform(opt, params, method=Image.BICUBIC, normalize=True, toT...
function normalize (line 78) | def normalize():
function __resize (line 82) | def __resize(img, w, h, method=Image.BICUBIC):
function __make_power_2 (line 86) | def __make_power_2(img, base, method=Image.BICUBIC):
function __scale_width (line 95) | def __scale_width(img, target_width, method=Image.BICUBIC):
function __scale_shortside (line 104) | def __scale_shortside(img, target_width, method=Image.BICUBIC):
function __crop (line 115) | def __crop(img, pos, size):
function __flip (line 122) | def __flip(img, flip):
FILE: Face_Enhancement/data/custom_dataset.py
class CustomDataset (line 8) | class CustomDataset(Pix2pixDataset):
method modify_commandline_options (line 15) | def modify_commandline_options(parser, is_train):
method get_paths (line 39) | def get_paths(self, opt):
FILE: Face_Enhancement/data/face_dataset.py
class FaceTestDataset (line 11) | class FaceTestDataset(BaseDataset):
method modify_commandline_options (line 13) | def modify_commandline_options(parser, is_train):
method initialize (line 23) | def initialize(self, opt):
method __getitem__ (line 60) | def __getitem__(self, index):
method __len__ (line 100) | def __len__(self):
FILE: Face_Enhancement/data/image_folder.py
function is_image_file (line 24) | def is_image_file(filename):
function make_dataset_rec (line 28) | def make_dataset_rec(dir, images):
function make_dataset (line 38) | def make_dataset(dir, recursive=False, read_cache=False, write_cache=Fal...
function default_loader (line 69) | def default_loader(path):
class ImageFolder (line 73) | class ImageFolder(data.Dataset):
method __init__ (line 74) | def __init__(self, root, transform=None, return_paths=False, loader=de...
method __getitem__ (line 90) | def __getitem__(self, index):
method __len__ (line 100) | def __len__(self):
FILE: Face_Enhancement/data/pix2pix_dataset.py
class Pix2pixDataset (line 10) | class Pix2pixDataset(BaseDataset):
method modify_commandline_options (line 12) | def modify_commandline_options(parser, is_train):
method initialize (line 20) | def initialize(self, opt):
method get_paths (line 48) | def get_paths(self, opt):
method paths_match (line 55) | def paths_match(self, path1, path2):
method __getitem__ (line 60) | def __getitem__(self, index):
method postprocess (line 104) | def postprocess(self, input_dict):
method __len__ (line 107) | def __len__(self):
FILE: Face_Enhancement/models/__init__.py
function find_model_using_name (line 8) | def find_model_using_name(model_name):
function get_option_setter (line 34) | def get_option_setter(model_name):
function create_model (line 39) | def create_model(opt):
FILE: Face_Enhancement/models/networks/__init__.py
function find_network_using_name (line 11) | def find_network_using_name(target_network_name, filename):
function modify_commandline_options (line 21) | def modify_commandline_options(parser, is_train):
function create_network (line 35) | def create_network(cls, opt):
function define_G (line 45) | def define_G(opt):
function define_D (line 50) | def define_D(opt):
function define_E (line 55) | def define_E(opt):
FILE: Face_Enhancement/models/networks/architecture.py
class SPADEResnetBlock (line 19) | class SPADEResnetBlock(nn.Module):
method __init__ (line 20) | def __init__(self, fin, fout, opt):
method forward (line 49) | def forward(self, x, seg, degraded_image):
method shortcut (line 59) | def shortcut(self, x, seg, degraded_image):
method actvn (line 66) | def actvn(self, x):
class ResnetBlock (line 72) | class ResnetBlock(nn.Module):
method __init__ (line 73) | def __init__(self, dim, norm_layer, activation=nn.ReLU(False), kernel_...
method forward (line 85) | def forward(self, x):
class VGG19 (line 92) | class VGG19(torch.nn.Module):
method __init__ (line 93) | def __init__(self, requires_grad=False):
method forward (line 115) | def forward(self, X):
class SPADEResnetBlock_non_spade (line 125) | class SPADEResnetBlock_non_spade(nn.Module):
method __init__ (line 126) | def __init__(self, fin, fout, opt):
method forward (line 155) | def forward(self, x, seg, degraded_image):
method shortcut (line 165) | def shortcut(self, x, seg, degraded_image):
method actvn (line 172) | def actvn(self, x):
FILE: Face_Enhancement/models/networks/base_network.py
class BaseNetwork (line 8) | class BaseNetwork(nn.Module):
method __init__ (line 9) | def __init__(self):
method modify_commandline_options (line 13) | def modify_commandline_options(parser, is_train):
method print_network (line 16) | def print_network(self):
method init_weights (line 27) | def init_weights(self, init_type="normal", gain=0.02):
FILE: Face_Enhancement/models/networks/encoder.py
class ConvEncoder (line 11) | class ConvEncoder(BaseNetwork):
method __init__ (line 14) | def __init__(self, opt):
method forward (line 36) | def forward(self, x):
FILE: Face_Enhancement/models/networks/generator.py
class SPADEGenerator (line 14) | class SPADEGenerator(BaseNetwork):
method modify_commandline_options (line 16) | def modify_commandline_options(parser, is_train):
method __init__ (line 27) | def __init__(self, opt):
method compute_latent_vector_size (line 90) | def compute_latent_vector_size(self, opt):
method forward (line 105) | def forward(self, input, degraded_image, z=None):
class Pix2PixHDGenerator (line 151) | class Pix2PixHDGenerator(BaseNetwork):
method modify_commandline_options (line 153) | def modify_commandline_options(parser, is_train):
method __init__ (line 172) | def __init__(self, opt):
method forward (line 231) | def forward(self, input, degraded_image, z=None):
FILE: Face_Enhancement/models/networks/normalization.py
function get_nonspade_norm_layer (line 12) | def get_nonspade_norm_layer(opt, norm_type="instance"):
class SPADE (line 49) | class SPADE(nn.Module):
method __init__ (line 50) | def __init__(self, config_text, norm_nc, label_nc, opt):
method forward (line 81) | def forward(self, x, segmap, degraded_image):
FILE: Face_Enhancement/models/pix2pix_model.py
class Pix2PixModel (line 9) | class Pix2PixModel(torch.nn.Module):
method modify_commandline_options (line 11) | def modify_commandline_options(parser, is_train):
method __init__ (line 15) | def __init__(self, opt):
method forward (line 36) | def forward(self, data, mode):
method create_optimizers (line 55) | def create_optimizers(self, opt):
method save (line 73) | def save(self, epoch):
method initialize_networks (line 83) | def initialize_networks(self, opt):
method preprocess_input (line 101) | def preprocess_input(self, data):
method compute_generator_loss (line 127) | def compute_generator_loss(self, input_semantics, degraded_image, real...
method compute_discriminator_loss (line 157) | def compute_discriminator_loss(self, input_semantics, degraded_image, ...
method encode_z (line 171) | def encode_z(self, real_image):
method generate_fake (line 176) | def generate_fake(self, input_semantics, degraded_image, real_image, c...
method discriminate (line 195) | def discriminate(self, input_semantics, fake_image, real_image):
method divide_pred (line 217) | def divide_pred(self, pred):
method get_edges (line 232) | def get_edges(self, t):
method reparameterize (line 240) | def reparameterize(self, mu, logvar):
method use_gpu (line 245) | def use_gpu(self):
FILE: Face_Enhancement/options/base_options.py
class BaseOptions (line 14) | class BaseOptions:
method __init__ (line 15) | def __init__(self):
method initialize (line 18) | def initialize(self, parser):
method gather_options (line 185) | def gather_options(self):
method print_options (line 215) | def print_options(self, opt):
method option_file_path (line 227) | def option_file_path(self, opt, makedir=False):
method save_options (line 234) | def save_options(self, opt):
method update_options_from_file (line 247) | def update_options_from_file(self, parser, opt):
method load_options (line 255) | def load_options(self, opt):
method parse (line 260) | def parse(self, save=False):
FILE: Face_Enhancement/options/test_options.py
class TestOptions (line 7) | class TestOptions(BaseOptions):
method initialize (line 8) | def initialize(self, parser):
FILE: Face_Enhancement/util/iter_counter.py
class IterationCounter (line 10) | class IterationCounter:
method __init__ (line 11) | def __init__(self, opt, dataset_size):
method training_epochs (line 33) | def training_epochs(self):
method record_epoch_start (line 36) | def record_epoch_start(self, epoch):
method record_one_iteration (line 42) | def record_one_iteration(self):
method record_epoch_end (line 52) | def record_epoch_end(self):
method record_current_iter (line 63) | def record_current_iter(self):
method needs_saving (line 67) | def needs_saving(self):
method needs_printing (line 70) | def needs_printing(self):
method needs_displaying (line 73) | def needs_displaying(self):
FILE: Face_Enhancement/util/util.py
function save_obj (line 15) | def save_obj(obj, name):
function load_obj (line 20) | def load_obj(name):
function copyconf (line 25) | def copyconf(default_opt, **kwargs):
function tensor2im (line 35) | def tensor2im(image_tensor, imtype=np.uint8, normalize=True, tile=False):
function tensor2label (line 67) | def tensor2label(label_tensor, n_label, imtype=np.uint8, tile=False):
function save_image (line 97) | def save_image(image_numpy, image_path, create_dir=False):
function mkdirs (line 110) | def mkdirs(paths):
function mkdir (line 118) | def mkdir(path):
function atoi (line 123) | def atoi(text):
function natural_keys (line 127) | def natural_keys(text):
function natural_sort (line 136) | def natural_sort(items):
function str2bool (line 140) | def str2bool(v):
function find_class_in_module (line 149) | def find_class_in_module(target_cls_name, module):
function save_network (line 167) | def save_network(net, label, epoch, opt):
function load_network (line 175) | def load_network(net, label, epoch, opt):
function uint82bin (line 190) | def uint82bin(n, count=8):
class Colorize (line 195) | class Colorize(object):
method __init__ (line 196) | def __init__(self, n=35):
method __call__ (line 200) | def __call__(self, gray_image):
FILE: Face_Enhancement/util/visualizer.py
class Visualizer (line 20) | class Visualizer:
method __init__ (line 21) | def __init__(self, opt):
method display_current_results (line 49) | def display_current_results(self, visuals, epoch, step):
method plot_current_errors (line 72) | def plot_current_errors(self, errors, step):
method print_current_errors (line 93) | def print_current_errors(self, epoch, i, errors, t):
method convert_visuals_to_numpy (line 103) | def convert_visuals_to_numpy(self, visuals):
method save_images (line 114) | def save_images(self, webpage, visuals, image_path):
FILE: GUI.py
function modify (line 11) | def modify(image_filename=None, cv2_frame=None):
FILE: Global/data/Create_Bigfile.py
function is_image_file (line 14) | def is_image_file(filename):
function make_dataset (line 18) | def make_dataset(dir):
FILE: Global/data/Load_Bigfile.py
class BigFileMemoryLoader (line 9) | class BigFileMemoryLoader(object):
method __load_bigfile (line 10) | def __load_bigfile(self):
method __init__ (line 27) | def __init__(self, file_path):
method __getitem__ (line 32) | def __getitem__(self, index):
method __len__ (line 41) | def __len__(self):
FILE: Global/data/base_data_loader.py
class BaseDataLoader (line 4) | class BaseDataLoader():
method __init__ (line 5) | def __init__(self):
method initialize (line 8) | def initialize(self, opt):
method load_data (line 12) | def load_data():
FILE: Global/data/base_dataset.py
class BaseDataset (line 10) | class BaseDataset(data.Dataset):
method __init__ (line 11) | def __init__(self):
method name (line 14) | def name(self):
method initialize (line 17) | def initialize(self, opt):
function get_params (line 20) | def get_params(opt, size):
function get_transform (line 46) | def get_transform(opt, params, method=Image.BICUBIC, normalize=True):
function normalize (line 84) | def normalize():
function __make_power_2 (line 87) | def __make_power_2(img, base, method=Image.BICUBIC):
function __scale_width (line 95) | def __scale_width(img, target_width, method=Image.BICUBIC):
function __crop (line 103) | def __crop(img, pos, size):
function __flip (line 111) | def __flip(img, flip):
FILE: Global/data/custom_dataset_data_loader.py
function CreateDataset (line 10) | def CreateDataset(opt):
class CustomDatasetDataLoader (line 23) | class CustomDatasetDataLoader(BaseDataLoader):
method name (line 24) | def name(self):
method initialize (line 27) | def initialize(self, opt):
method load_data (line 37) | def load_data(self):
method __len__ (line 40) | def __len__(self):
FILE: Global/data/data_loader.py
function CreateDataLoader (line 4) | def CreateDataLoader(opt):
FILE: Global/data/image_folder.py
function is_image_file (line 14) | def is_image_file(filename):
function make_dataset (line 18) | def make_dataset(dir):
function default_loader (line 31) | def default_loader(path):
class ImageFolder (line 35) | class ImageFolder(data.Dataset):
method __init__ (line 37) | def __init__(self, root, transform=None, return_paths=False,
method __getitem__ (line 51) | def __getitem__(self, index):
method __len__ (line 61) | def __len__(self):
FILE: Global/data/online_dataset_for_old_photos.py
function pil_to_np (line 17) | def pil_to_np(img_PIL):
function np_to_pil (line 32) | def np_to_pil(img_np):
function synthesize_salt_pepper (line 46) | def synthesize_salt_pepper(image,amount,salt_vs_pepper):
function synthesize_gaussian (line 67) | def synthesize_gaussian(image,std_l,std_r):
function synthesize_speckle (line 81) | def synthesize_speckle(image,std_l,std_r):
function synthesize_low_resolution (line 96) | def synthesize_low_resolution(img):
function convertToJpeg (line 112) | def convertToJpeg(im,quality):
function blur_image_v2 (line 119) | def blur_image_v2(img):
function online_add_degradation_v2 (line 132) | def online_add_degradation_v2(img):
function irregular_hole_synthesize (line 156) | def irregular_hole_synthesize(img,mask):
function zero_mask (line 168) | def zero_mask(size):
class UnPairOldPhotos_SR (line 175) | class UnPairOldPhotos_SR(BaseDataset): ## Synthetic + Real Old
method initialize (line 176) | def initialize(self, opt):
method __getitem__ (line 213) | def __getitem__(self, index):
method __len__ (line 278) | def __len__(self):
method name (line 281) | def name(self):
class PairOldPhotos (line 285) | class PairOldPhotos(BaseDataset):
method initialize (line 286) | def initialize(self, opt):
method __getitem__ (line 313) | def __getitem__(self, index):
method __len__ (line 370) | def __len__(self):
method name (line 377) | def name(self):
class PairOldPhotos_with_hole (line 381) | class PairOldPhotos_with_hole(BaseDataset):
method initialize (line 382) | def initialize(self, opt):
method __getitem__ (line 411) | def __getitem__(self, index):
method __len__ (line 476) | def __len__(self):
method name (line 484) | def name(self):
FILE: Global/detection.py
function data_transforms (line 25) | def data_transforms(img, full_size, method=Image.BICUBIC):
function scale_tensor (line 51) | def scale_tensor(img_tensor, default_scale=256):
function blend_mask (line 66) | def blend_mask(img, mask):
function main (line 73) | def main(config):
FILE: Global/detection_models/antialiasing.py
class Downsample (line 11) | class Downsample(nn.Module):
method __init__ (line 14) | def __init__(self, pad_type="reflect", filt_size=3, stride=2, channels...
method forward (line 51) | def forward(self, inp):
function get_pad_layer (line 61) | def get_pad_layer(pad_type):
FILE: Global/detection_models/networks.py
class UNet (line 11) | class UNet(nn.Module):
method __init__ (line 12) | def __init__(
method forward (line 109) | def forward(self, x):
class UNetConvBlock (line 124) | class UNetConvBlock(nn.Module):
method __init__ (line 125) | def __init__(self, conv_num, in_size, out_size, padding, batch_norm):
method forward (line 139) | def forward(self, x):
class UNetUpBlock (line 144) | class UNetUpBlock(nn.Module):
method __init__ (line 145) | def __init__(self, conv_num, in_size, out_size, up_mode, padding, batc...
method center_crop (line 158) | def center_crop(self, layer, target_size):
method forward (line 164) | def forward(self, x, bridge):
class UnetGenerator (line 173) | class UnetGenerator(nn.Module):
method __init__ (line 176) | def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_type="...
method forward (line 223) | def forward(self, input):
class UnetSkipConnectionBlock (line 227) | class UnetSkipConnectionBlock(nn.Module):
method __init__ (line 234) | def __init__(
method forward (line 291) | def forward(self, x):
FILE: Global/detection_util/util.py
function print_options (line 25) | def print_options(config_dict):
function save_options (line 32) | def save_options(config_dict):
function config_parse (line 46) | def config_parse(config_file, options, save=True):
function to_np (line 70) | def to_np(x):
function prepare_device (line 74) | def prepare_device(use_gpu, gpu_ids):
function get_dir_size (line 94) | def get_dir_size(start_path="."):
function mkdir_if_not (line 103) | def mkdir_if_not(dir_path):
class Timer (line 109) | class Timer:
method __init__ (line 110) | def __init__(self, msg):
method __enter__ (line 114) | def __enter__(self):
method __exit__ (line 117) | def __exit__(self, exc_type, exc_value, exc_tb):
function get_size (line 123) | def get_size(start_path="."):
function clean_tensorboard (line 132) | def clean_tensorboard(directory):
function prepare_tensorboard (line 146) | def prepare_tensorboard(config, experiment_name=datetime.now().strftime(...
function tb_loss_logger (line 159) | def tb_loss_logger(tb_writer, iter_index, loss_logger):
function tb_image_logger (line 164) | def tb_image_logger(tb_writer, iter_index, images_info, config):
function tb_image_logger_test (line 179) | def tb_image_logger_test(epoch, iter, images_info, config):
function imshow (line 199) | def imshow(input_image, title=None, to_numpy=False):
function vgg_preprocess (line 216) | def vgg_preprocess(tensor):
function torch_vgg_preprocess (line 228) | def torch_vgg_preprocess(tensor):
function network_gradient (line 238) | def network_gradient(net, gradient_on=True):
FILE: Global/models/NonLocal_feature_mapping_model.py
class Mapping_Model_with_mask (line 17) | class Mapping_Model_with_mask(nn.Module):
method __init__ (line 18) | def __init__(self, nc, mc=64, n_blocks=3, norm="instance", padding_typ...
method forward (line 71) | def forward(self, input, mask):
class Mapping_Model_with_mask_2 (line 81) | class Mapping_Model_with_mask_2(nn.Module): ## Multi-Scale Patch Attention
method __init__ (line 82) | def __init__(self, nc, mc=64, n_blocks=3, norm="instance", padding_typ...
method forward (line 177) | def forward(self, input, mask):
method inference_forward (line 187) | def inference_forward(self, input, mask):
FILE: Global/models/base_model.py
class BaseModel (line 9) | class BaseModel(torch.nn.Module):
method name (line 10) | def name(self):
method initialize (line 13) | def initialize(self, opt):
method set_input (line 20) | def set_input(self, input):
method forward (line 23) | def forward(self):
method test (line 27) | def test(self):
method get_image_paths (line 30) | def get_image_paths(self):
method optimize_parameters (line 33) | def optimize_parameters(self):
method get_current_visuals (line 36) | def get_current_visuals(self):
method get_current_errors (line 39) | def get_current_errors(self):
method save (line 42) | def save(self, label):
method save_network (line 46) | def save_network(self, network, network_label, epoch_label, gpu_ids):
method save_optimizer (line 53) | def save_optimizer(self, optimizer, optimizer_label, epoch_label):
method load_optimizer (line 58) | def load_optimizer(self, optimizer, optimizer_label, epoch_label, save...
method load_network (line 70) | def load_network(self, network, network_label, epoch_label, save_dir=""):
method update_learning_rate (line 121) | def update_learning_rate():
FILE: Global/models/mapping_model.py
class Mapping_Model (line 18) | class Mapping_Model(nn.Module):
method __init__ (line 19) | def __init__(self, nc, mc=64, n_blocks=3, norm="instance", padding_typ...
method forward (line 56) | def forward(self, input):
class Pix2PixHDModel_Mapping (line 60) | class Pix2PixHDModel_Mapping(BaseModel):
method name (line 61) | def name(self):
method init_loss_filter (line 64) | def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss, use_smooth...
method initialize (line 78) | def initialize(self, opt):
method encode_input (line 215) | def encode_input(self, label_map, inst_map=None, real_image=None, feat...
method discriminate (line 240) | def discriminate(self, input_label, test_image, use_pool=False):
method forward (line 248) | def forward(self, label, inst, image, feat, pair=True, infer=False, la...
method inference (line 325) | def inference(self, label, inst):
class InferenceModel (line 349) | class InferenceModel(Pix2PixHDModel_Mapping):
method forward (line 350) | def forward(self, label, inst):
FILE: Global/models/models.py
function create_model (line 7) | def create_model(opt):
function create_da_model (line 29) | def create_da_model(opt):
FILE: Global/models/networks.py
function weights_init (line 17) | def weights_init(m):
function get_norm_layer (line 26) | def get_norm_layer(norm_type="instance"):
function print_network (line 40) | def print_network(net):
function define_G (line 50) | def define_G(input_nc, output_nc, ngf, netG, k_size=3, n_downsample_glob...
function define_D (line 70) | def define_D(input_nc, ndf, n_layers_D, opt, norm='instance', use_sigmoi...
class GlobalGenerator_DCDCv2 (line 82) | class GlobalGenerator_DCDCv2(nn.Module):
method __init__ (line 83) | def __init__(
method forward (line 283) | def forward(self, input, flow="enc_dec"):
class ResnetBlock (line 295) | class ResnetBlock(nn.Module):
method __init__ (line 296) | def __init__(
method build_conv_block (line 304) | def build_conv_block(self, dim, padding_type, norm_layer, activation, ...
method forward (line 337) | def forward(self, x):
class Encoder (line 342) | class Encoder(nn.Module):
method __init__ (line 343) | def __init__(self, input_nc, output_nc, ngf=32, n_downsampling=4, norm...
method forward (line 376) | def forward(self, input, inst):
function SN (line 394) | def SN(module, mode=True):
class NonLocalBlock2D_with_mask_Res (line 401) | class NonLocalBlock2D_with_mask_Res(nn.Module):
method __init__ (line 402) | def __init__(
method forward (line 460) | def forward(self, x, mask): ## The shape of mask is Batch*1*H*W
class MultiscaleDiscriminator (line 526) | class MultiscaleDiscriminator(nn.Module):
method __init__ (line 527) | def __init__(self, input_nc, opt, ndf=64, n_layers=3, norm_layer=nn.Ba...
method singleD_forward (line 544) | def singleD_forward(self, model, input):
method forward (line 553) | def forward(self, input):
class NLayerDiscriminator (line 568) | class NLayerDiscriminator(nn.Module):
method __init__ (line 569) | def __init__(self, input_nc, opt, ndf=64, n_layers=3, norm_layer=nn.Ba...
method forward (line 609) | def forward(self, input):
class Patch_Attention_4 (line 621) | class Patch_Attention_4(nn.Module): ## While combine the feature map, u...
method __init__ (line 622) | def __init__(self, in_channels, inter_channels, patch_size):
method Hard_Compose (line 666) | def Hard_Compose(self, input, dim, index):
method forward (line 678) | def forward(self, z, mask): ## The shape of mask is Batch*1*H*W
method inference_forward (line 720) | def inference_forward(self,z,mask): ## Reduce the extra memory cost
class GANLoss (line 781) | class GANLoss(nn.Module):
method __init__ (line 782) | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_...
method get_target_tensor (line 795) | def get_target_tensor(self, input, target_is_real):
method __call__ (line 813) | def __call__(self, input, target_is_real):
class VGG19_torch (line 831) | class VGG19_torch(torch.nn.Module):
method __init__ (line 832) | def __init__(self, requires_grad=False):
method forward (line 854) | def forward(self, X):
class VGGLoss_torch (line 863) | class VGGLoss_torch(nn.Module):
method __init__ (line 864) | def __init__(self, gpu_ids):
method forward (line 870) | def forward(self, x, y):
FILE: Global/models/pix2pixHD_model.py
class Pix2PixHDModel (line 12) | class Pix2PixHDModel(BaseModel):
method name (line 13) | def name(self):
method init_loss_filter (line 16) | def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss,use_smooth_...
method initialize (line 22) | def initialize(self, opt):
method encode_input (line 112) | def encode_input(self, label_map, inst_map=None, real_image=None, feat...
method discriminate (line 145) | def discriminate(self, input_label, test_image, use_pool=False):
method forward (line 156) | def forward(self, label, inst, image, feat, infer=False):
method inference (line 221) | def inference(self, label, inst, image=None, feat=None):
method sample_features (line 245) | def sample_features(self, inst):
method encode_features (line 266) | def encode_features(self, image, inst):
method get_edges (line 288) | def get_edges(self, t):
method save (line 299) | def save(self, which_epoch):
method update_fixed_params (line 309) | def update_fixed_params(self):
method update_learning_rate (line 318) | def update_learning_rate(self):
class InferenceModel (line 330) | class InferenceModel(Pix2PixHDModel):
method forward (line 331) | def forward(self, inp):
FILE: Global/models/pix2pixHD_model_DA.py
class Pix2PixHDModel (line 13) | class Pix2PixHDModel(BaseModel):
method name (line 14) | def name(self):
method init_loss_filter (line 17) | def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss):
method initialize (line 25) | def initialize(self, opt):
method encode_input (line 118) | def encode_input(self, label_map, inst_map=None, real_image=None, feat...
method discriminate (line 151) | def discriminate(self, input_label, test_image, use_pool=False):
method feat_discriminate (line 162) | def feat_discriminate(self,input):
method forward (line 167) | def forward(self, label, inst, image, feat, infer=False):
method inference (line 256) | def inference(self, label, inst, image=None, feat=None):
method sample_features (line 280) | def sample_features(self, inst):
method encode_features (line 301) | def encode_features(self, image, inst):
method get_edges (line 323) | def get_edges(self, t):
method save (line 334) | def save(self, which_epoch):
method update_fixed_params (line 346) | def update_fixed_params(self):
method update_learning_rate (line 355) | def update_learning_rate(self):
class InferenceModel (line 369) | class InferenceModel(Pix2PixHDModel):
method forward (line 370) | def forward(self, inp):
FILE: Global/options/base_options.py
class BaseOptions (line 10) | class BaseOptions:
method __init__ (line 11) | def __init__(self):
method initialize (line 15) | def initialize(self):
method parse (line 338) | def parse(self, save=True):
FILE: Global/options/test_options.py
class TestOptions (line 7) | class TestOptions(BaseOptions):
method initialize (line 8) | def initialize(self):
FILE: Global/options/train_options.py
class TrainOptions (line 6) | class TrainOptions(BaseOptions):
method initialize (line 7) | def initialize(self):
FILE: Global/test.py
function data_transforms (line 18) | def data_transforms(img, method=Image.BILINEAR, scale=False):
function data_transforms_rgb_old (line 39) | def data_transforms_rgb_old(img):
function irregular_hole_synthesize (line 47) | def irregular_hole_synthesize(img, mask):
function parameter_set (line 59) | def parameter_set(opt):
FILE: Global/util/image_pool.py
class ImagePool (line 9) | class ImagePool:
method __init__ (line 10) | def __init__(self, pool_size):
method query (line 16) | def query(self, images):
FILE: Global/util/util.py
function tensor2im (line 14) | def tensor2im(image_tensor, imtype=np.uint8, normalize=True):
function tensor2label (line 32) | def tensor2label(label_tensor, n_label, imtype=np.uint8):
function save_image (line 43) | def save_image(image_numpy, image_path):
function mkdirs (line 48) | def mkdirs(paths):
function mkdir (line 56) | def mkdir(path):
FILE: Global/util/visualizer.py
class Visualizer (line 16) | class Visualizer():
method __init__ (line 17) | def __init__(self, opt):
method display_current_results (line 40) | def display_current_results(self, visuals, epoch, step):
method plot_current_errors (line 98) | def plot_current_errors(self, errors, step):
method print_current_errors (line 105) | def print_current_errors(self, epoch, i, errors, t, lr):
method print_save (line 116) | def print_save(self,message):
method save_images (line 125) | def save_images(self, webpage, visuals, image_path):
FILE: predict.py
class Predictor (line 12) | class Predictor(cog.Predictor):
method setup (line 13) | def setup(self):
method predict (line 51) | def predict(self, image, HR=False, with_scratch=False):
function clean_folder (line 213) | def clean_folder(folder):
FILE: run.py
function run_cmd (line 10) | def run_cmd(command):
Condensed preview — 75 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (379K chars).
[
{
"path": ".gitignore",
"chars": 23,
"preview": "__pycache__/\n*.pyc\n*~\n\n"
},
{
"path": "CODE_OF_CONDUCT.md",
"chars": 444,
"preview": "# Microsoft Open Source Code of Conduct\n\nThis project has adopted the [Microsoft Open Source Code of Conduct](https://op"
},
{
"path": "Dockerfile",
"chars": 1358,
"preview": "FROM nvidia/cuda:11.1-base-ubuntu20.04\n\nRUN apt update && DEBIAN_FRONTEND=noninteractive apt install git bzip2 wget unzi"
},
{
"path": "Face_Detection/align_warp_back_multiple_dlib.py",
"chars": 13487,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch\nimport numpy as np\nimport skimage"
},
{
"path": "Face_Detection/align_warp_back_multiple_dlib_HR.py",
"chars": 13487,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch\nimport numpy as np\nimport skimage"
},
{
"path": "Face_Detection/detect_all_dlib.py",
"chars": 5291,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch\nimport numpy as np\nimport skimage"
},
{
"path": "Face_Detection/detect_all_dlib_HR.py",
"chars": 5291,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch\nimport numpy as np\nimport skimage"
},
{
"path": "Face_Enhancement/data/__init__.py",
"chars": 624,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport importlib\nimport torch.utils.data\nfrom "
},
{
"path": "Face_Enhancement/data/base_dataset.py",
"chars": 3892,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch.utils.data as data\nfrom PIL impor"
},
{
"path": "Face_Enhancement/data/custom_dataset.py",
"chars": 2149,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom data.pix2pix_dataset import Pix2pixDatase"
},
{
"path": "Face_Enhancement/data/face_dataset.py",
"chars": 3029,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom data.base_dataset import BaseDataset, get"
},
{
"path": "Face_Enhancement/data/image_folder.py",
"chars": 2717,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch.utils.data as data\nfrom PIL impor"
},
{
"path": "Face_Enhancement/data/pix2pix_dataset.py",
"chars": 3911,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom data.base_dataset import BaseDataset, get"
},
{
"path": "Face_Enhancement/models/__init__.py",
"chars": 1336,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport importlib\nimport torch\n\n\ndef find_model"
},
{
"path": "Face_Enhancement/models/networks/__init__.py",
"chars": 1773,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch\nfrom models.networks.base_network"
},
{
"path": "Face_Enhancement/models/networks/architecture.py",
"chars": 6291,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch\nimport torch.nn as nn\nimport torc"
},
{
"path": "Face_Enhancement/models/networks/base_network.py",
"chars": 2370,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch.nn as nn\nfrom torch.nn import ini"
},
{
"path": "Face_Enhancement/models/networks/encoder.py",
"chars": 1873,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch.nn as nn\nimport numpy as np\nimpor"
},
{
"path": "Face_Enhancement/models/networks/generator.py",
"chars": 8343,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch\nimport torch.nn as nn\nimport torc"
},
{
"path": "Face_Enhancement/models/networks/normalization.py",
"chars": 3848,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport re\nimport torch\nimport torch.nn as nn\ni"
},
{
"path": "Face_Enhancement/models/pix2pix_model.py",
"chars": 9786,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch\nimport models.networks as network"
},
{
"path": "Face_Enhancement/options/__init__.py",
"chars": 73,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n"
},
{
"path": "Face_Enhancement/options/base_options.py",
"chars": 10935,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport sys\nimport argparse\nimport os\nfrom util"
},
{
"path": "Face_Enhancement/options/test_options.py",
"chars": 968,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom .base_options import BaseOptions\n\n\nclass "
},
{
"path": "Face_Enhancement/requirements.txt",
"chars": 97,
"preview": "torch>=1.0.0\ntorchvision\ndominate>=2.3.1\nwandb\ndill\nscikit-image\ntensorboardX\nscipy\nopencv-python"
},
{
"path": "Face_Enhancement/test_face.py",
"chars": 1077,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport os\nfrom collections import OrderedDict\n"
},
{
"path": "Face_Enhancement/util/__init__.py",
"chars": 73,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n"
},
{
"path": "Face_Enhancement/util/iter_counter.py",
"chars": 3009,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport os\nimport time\nimport numpy as np\n\n\n# H"
},
{
"path": "Face_Enhancement/util/util.py",
"chars": 6590,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport re\nimport importlib\nimport torch\nfrom a"
},
{
"path": "Face_Enhancement/util/visualizer.py",
"chars": 4786,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport os\nimport ntpath\nimport time\nfrom . imp"
},
{
"path": "GUI.py",
"chars": 7241,
"preview": "import numpy as np\nimport cv2\nimport PySimpleGUI as sg\nimport os.path\nimport argparse\nimport os\nimport sys\nimport shutil"
},
{
"path": "Global/data/Create_Bigfile.py",
"chars": 1952,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport os\nimport struct\nfrom PIL import Image\n"
},
{
"path": "Global/data/Load_Bigfile.py",
"chars": 1590,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport io\nimport os\nimport struct\nfrom PIL imp"
},
{
"path": "Global/data/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "Global/data/base_data_loader.py",
"chars": 268,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nclass BaseDataLoader():\n def __init__(self)"
},
{
"path": "Global/data/base_dataset.py",
"chars": 3596,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch.utils.data as data\nfrom PIL impor"
},
{
"path": "Global/data/custom_dataset_data_loader.py",
"chars": 1316,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch.utils.data\nimport random\nfrom dat"
},
{
"path": "Global/data/data_loader.py",
"chars": 302,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\ndef CreateDataLoader(opt):\n from data.custo"
},
{
"path": "Global/data/image_folder.py",
"chars": 1643,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch.utils.data as data\nfrom PIL impor"
},
{
"path": "Global/data/online_dataset_for_old_photos.py",
"chars": 15313,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport os.path\nimport io\nimport zipfile\nfrom d"
},
{
"path": "Global/detection.py",
"chars": 5032,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport argparse\nimport gc\nimport json\nimport o"
},
{
"path": "Global/detection_models/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "Global/detection_models/antialiasing.py",
"chars": 2451,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch\nimport torch.nn.parallel\nimport n"
},
{
"path": "Global/detection_models/networks.py",
"chars": 11753,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch\nimport torch.nn as nn\nimport torc"
},
{
"path": "Global/detection_util/util.py",
"chars": 7981,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport os\nimport sys\nimport time\nimport shutil"
},
{
"path": "Global/models/NonLocal_feature_mapping_model.py",
"chars": 6414,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport numpy as np\nimport torch\nimport torch.n"
},
{
"path": "Global/models/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "Global/models/base_model.py",
"chars": 4281,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport os\nimport torch\nimport sys\n\n\nclass Base"
},
{
"path": "Global/models/mapping_model.py",
"chars": 13758,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport numpy as np\nimport torch\nimport torch.n"
},
{
"path": "Global/models/models.py",
"chars": 1212,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch\n\n\ndef create_model(opt):\n if o"
},
{
"path": "Global/models/networks.py",
"chars": 30705,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch\nimport torch.nn as nn\nimport func"
},
{
"path": "Global/models/pix2pixHD_model.py",
"chars": 15127,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport numpy as np\nimport torch\nimport os\nfrom"
},
{
"path": "Global/models/pix2pixHD_model_DA.py",
"chars": 16499,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport numpy as np\nimport torch\nimport os\nfrom"
},
{
"path": "Global/options/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "Global/options/base_options.py",
"chars": 16231,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport argparse\nimport os\nfrom util import uti"
},
{
"path": "Global/options/test_options.py",
"chars": 4605,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom .base_options import BaseOptions\n\n\nclass "
},
{
"path": "Global/options/train_options.py",
"chars": 5199,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom .base_options import BaseOptions\n\nclass T"
},
{
"path": "Global/test.py",
"chars": 6028,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport os\nfrom collections import OrderedDict\n"
},
{
"path": "Global/train_domain_A.py",
"chars": 5403,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport time\nfrom collections import OrderedDic"
},
{
"path": "Global/train_domain_B.py",
"chars": 5193,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport time\nfrom collections import OrderedDic"
},
{
"path": "Global/train_mapping.py",
"chars": 5954,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport time\nfrom collections import OrderedDic"
},
{
"path": "Global/util/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "Global/util/image_pool.py",
"chars": 1166,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport random\nimport torch\nfrom torch.autograd"
},
{
"path": "Global/util/util.py",
"chars": 1845,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom __future__ import print_function\nimport t"
},
{
"path": "Global/util/visualizer.py",
"chars": 5790,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport numpy as np\nimport os\nimport ntpath\nimp"
},
{
"path": "LICENSE",
"chars": 1141,
"preview": " MIT License\n\n Copyright (c) Microsoft Corporation.\n\n Permission is hereby granted, free of charge, to any pers"
},
{
"path": "README.md",
"chars": 12303,
"preview": "# Old Photo Restoration (Official PyTorch Implementation)\n\n<img src='imgs/0001.jpg'/>\n\n### [Project Page](http://raywzy."
},
{
"path": "SECURITY.md",
"chars": 2780,
"preview": "<!-- BEGIN MICROSOFT SECURITY.MD V0.0.5 BLOCK -->\n\n## Security\n\nMicrosoft takes the security of our software products an"
},
{
"path": "ansible.yaml",
"chars": 3506,
"preview": "---\r\n- name: Bringing-Old-Photos-Back-to-Life\r\n hosts: all\r\n gather_facts: no\r\n\r\n# Succesfully tested on Ubuntu 18.04\\"
},
{
"path": "cog.yaml",
"chars": 551,
"preview": "build:\n gpu: true\n python_version: \"3.8\"\n system_packages:\n - \"libgl1-mesa-glx\"\n - \"libglib2.0-0\"\n python_pack"
},
{
"path": "download-weights",
"chars": 847,
"preview": "#!/bin/sh\n\ncd Face_Enhancement/models/networks\ngit clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch\ncp -r"
},
{
"path": "kubernetes-pod.yml",
"chars": 727,
"preview": "apiVersion: v1\nkind: Pod\nmetadata:\n name: photo-back2life\nspec:\n containers:\n - name: photos-back2life\n image:"
},
{
"path": "predict.py",
"chars": 8086,
"preview": "import tempfile\nfrom pathlib import Path\nimport argparse\nimport shutil\nimport os\nimport glob\nimport cv2\nimport cog\nfrom "
},
{
"path": "requirements.txt",
"chars": 136,
"preview": "torch\ntorchvision\ndlib\nscikit-image\neasydict\nPyYAML\ndominate>=2.3.1\ndill\ntensorboardX\nscipy\nopencv-python\neinops\nPySimpl"
},
{
"path": "run.py",
"chars": 6775,
"preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport os\nimport argparse\nimport shutil\nimport"
}
]
About this extraction
This page contains the full source code of the microsoft/Bringing-Old-Photos-Back-to-Life GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 75 files (353.2 KB), approximately 90.5k tokens, and a symbol index with 447 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.