Repository: microsoft/Bringing-Old-Photos-Back-to-Life Branch: master Commit: 33875eccf4eb Files: 75 Total size: 353.2 KB Directory structure: gitextract_ojvoon3f/ ├── .gitignore ├── CODE_OF_CONDUCT.md ├── Dockerfile ├── Face_Detection/ │ ├── align_warp_back_multiple_dlib.py │ ├── align_warp_back_multiple_dlib_HR.py │ ├── detect_all_dlib.py │ └── detect_all_dlib_HR.py ├── Face_Enhancement/ │ ├── data/ │ │ ├── __init__.py │ │ ├── base_dataset.py │ │ ├── custom_dataset.py │ │ ├── face_dataset.py │ │ ├── image_folder.py │ │ └── pix2pix_dataset.py │ ├── models/ │ │ ├── __init__.py │ │ ├── networks/ │ │ │ ├── __init__.py │ │ │ ├── architecture.py │ │ │ ├── base_network.py │ │ │ ├── encoder.py │ │ │ ├── generator.py │ │ │ └── normalization.py │ │ └── pix2pix_model.py │ ├── options/ │ │ ├── __init__.py │ │ ├── base_options.py │ │ └── test_options.py │ ├── requirements.txt │ ├── test_face.py │ └── util/ │ ├── __init__.py │ ├── iter_counter.py │ ├── util.py │ └── visualizer.py ├── GUI.py ├── Global/ │ ├── data/ │ │ ├── Create_Bigfile.py │ │ ├── Load_Bigfile.py │ │ ├── __init__.py │ │ ├── base_data_loader.py │ │ ├── base_dataset.py │ │ ├── custom_dataset_data_loader.py │ │ ├── data_loader.py │ │ ├── image_folder.py │ │ └── online_dataset_for_old_photos.py │ ├── detection.py │ ├── detection_models/ │ │ ├── __init__.py │ │ ├── antialiasing.py │ │ └── networks.py │ ├── detection_util/ │ │ └── util.py │ ├── models/ │ │ ├── NonLocal_feature_mapping_model.py │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── mapping_model.py │ │ ├── models.py │ │ ├── networks.py │ │ ├── pix2pixHD_model.py │ │ └── pix2pixHD_model_DA.py │ ├── options/ │ │ ├── __init__.py │ │ ├── base_options.py │ │ ├── test_options.py │ │ └── train_options.py │ ├── test.py │ ├── train_domain_A.py │ ├── train_domain_B.py │ ├── train_mapping.py │ └── util/ │ ├── __init__.py │ ├── image_pool.py │ ├── util.py │ └── visualizer.py ├── LICENSE ├── README.md ├── SECURITY.md ├── ansible.yaml ├── cog.yaml ├── download-weights 
├── kubernetes-pod.yml ├── predict.py ├── requirements.txt └── run.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ __pycache__/ *.pyc *~ ================================================ FILE: CODE_OF_CONDUCT.md ================================================ # Microsoft Open Source Code of Conduct This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). Resources: - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns ================================================ FILE: Dockerfile ================================================ FROM nvidia/cuda:11.1-base-ubuntu20.04 RUN apt update && DEBIAN_FRONTEND=noninteractive apt install git bzip2 wget unzip python3-pip python3-dev cmake libgl1-mesa-dev python-is-python3 libgtk2.0-dev -yq ADD . /app WORKDIR /app RUN cd Face_Enhancement/models/networks/ &&\ git clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch &&\ cp -rf Synchronized-BatchNorm-PyTorch/sync_batchnorm . &&\ cd ../../../ RUN cd Global/detection_models &&\ git clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch &&\ cp -rf Synchronized-BatchNorm-PyTorch/sync_batchnorm . 
def calculate_cdf(histogram):
    """
    This method calculates the normalized cumulative distribution function
    :param array histogram: The values of the histogram (bin counts)
    :return: normalized_cdf: The running sum of the bins divided by its
        maximum. An empty histogram, or one whose running sum never rises
        above zero, yields zeros instead of NaN.
    :rtype: array
    """
    # Get the cumulative sum of the elements
    cdf = np.asarray(histogram).cumsum()
    if cdf.size == 0:
        return cdf.astype(float)

    peak = float(cdf.max())
    if peak == 0.0:
        # Dividing by a zero maximum would fill the result with NaN (and
        # emit a RuntimeWarning); report a flat, all-zero CDF instead.
        return np.zeros(cdf.shape, dtype=float)

    # Normalize the cdf
    normalized_cdf = cdf / peak
    return normalized_cdf
def match_histograms(src_image, ref_image):
    """
    This method matches the source image histogram to the reference image,
    one channel at a time.
    :param image src_image: The original source image (8-bit, as required
        by cv2.LUT)
    :param image ref_image: The reference image with the same number of
        channels as the source
    :return: image_after_matching
    :rtype: image (array)
    :raises ValueError: if the two images have different channel counts
    """
    # Split the images into the different color channels
    src_channels = cv2.split(src_image)
    ref_channels = cv2.split(ref_image)
    if len(src_channels) != len(ref_channels):
        raise ValueError(
            "source and reference images must have the same number of channels"
        )

    matched_channels = []
    for src_channel, ref_channel in zip(src_channels, ref_channels):
        # Compute the 256-bin histogram of each channel; the bin edges
        # returned by np.histogram are not needed.
        src_hist, _ = np.histogram(src_channel.flatten(), 256, [0, 256])
        ref_hist, _ = np.histogram(ref_channel.flatten(), 256, [0, 256])

        # Normalized cdf of both channels, then the lookup table that maps
        # a source intensity to the reference intensity with a matching cdf.
        lookup_table = calculate_lookup(calculate_cdf(src_hist), calculate_cdf(ref_hist))

        # Use the lookup table to transform the source channel
        matched_channels.append(cv2.LUT(src_channel, lookup_table))

    # Put the image back together and convert it to 8-bit
    image_after_matching = cv2.merge(matched_channels)
    image_after_matching = cv2.convertScaleAbs(image_after_matching)
    return image_after_matching
def blur_blending(im1, im2, mask):
    """
    Blend im1 over im2 through an eroded, Gaussian-blurred mask using PIL.
    :param array im1: Foreground image; cast to uint8 before compositing
    :param array im2: Background image; cast to uint8 before compositing
    :param array mask: Blend mask that is scaled by 255 here, so it is
        presumably expected in [0, 1]. The caller's array is not modified.
    :return: The composited image divided by 255.0
    :rtype: array
    """
    # Scale into a new array: an in-place `mask *= 255.0` would change the
    # caller's mask and raises for integer mask dtypes.
    mask = np.asarray(mask) * 255.0

    # Shrink the mask so the seam falls inside the foreground region
    kernel = np.ones((10, 10), np.uint8)
    mask = cv2.erode(mask, kernel, iterations=1)

    mask = Image.fromarray(mask.astype("uint8")).convert("L")
    im1 = Image.fromarray(im1.astype("uint8"))
    im2 = Image.fromarray(im2.astype("uint8"))

    mask_blur = mask.filter(ImageFilter.GaussianBlur(20))
    # Hard composite first, then blend that result over im2 with the
    # blurred mask to soften the border.
    im = Image.composite(im1, im2, mask)
    im = Image.composite(im, im2, mask_blur)

    return np.array(im) / 255.0


def blur_blending_cv2(im1, im2, mask):
    """
    Blend im1 over im2 through an eroded, Gaussian-blurred mask using OpenCV.
    :param array im1: Foreground image in the 0-255 range
    :param array im2: Background image in the 0-255 range
    :param array mask: Blend mask that is scaled by 255 here, so it is
        presumably expected in [0, 1]. The caller's array is not modified.
    :return: The blended image divided by 255.0 and clipped to [0, 1]
    :rtype: array
    """
    # Scale into a new array: an in-place `mask *= 255.0` would change the
    # caller's mask and raises for integer mask dtypes.
    mask = np.asarray(mask) * 255.0

    kernel = np.ones((9, 9), np.uint8)
    mask = cv2.erode(mask, kernel, iterations=3)

    mask_blur = cv2.GaussianBlur(mask, (25, 25), 0)
    mask_blur /= 255.0

    # Per-pixel linear interpolation between the two images
    im = im1 * mask_blur + (1 - mask_blur) * im2

    im /= 255.0
    im = np.clip(im, 0.0, 1.0)

    return im
def seamless_clone(old_face, new_face, raw_mask):
    """
    Paste the masked region of new_face into old_face with cv2.seamlessClone.
    :param array old_face: Destination image, scaled by 255 here (so
        presumably float in [0, 1]), shape (H, W, 3)
    :param array new_face: Image the insertion is cropped from, same scaling
    :param array raw_mask: 3-D mask; its non-zero entries select the region
        to insert
    :return: The blended image as float32 divided by 255. When the mask is
        empty, or its bounding rectangle is narrower than 4 pixels, this is
        old_face after 8-bit quantization with nothing inserted.
    :rtype: array
    """
    full_h, full_w, _ = old_face.shape
    pad_y = full_h // 2
    pad_x = full_w // 2

    y_indices, x_indices, _ = np.nonzero(raw_mask)
    if y_indices.size == 0:
        # Nothing to insert. np.min/np.max below would raise on empty index
        # arrays, so return what the "mask too small" branch produces.
        return np.rint(old_face * 255.0).astype("uint8").astype("float32") / 255.0

    y_min, y_max = np.min(y_indices), np.max(y_indices)
    x_min, x_max = np.min(x_indices), np.max(x_indices)
    y_crop = slice(y_min, y_max)
    x_crop = slice(x_min, x_max)
    # Center of the mask's bounding box, shifted into the padded canvas
    y_center = int(np.rint((y_max + y_min) / 2 + pad_y))
    x_center = int(np.rint((x_max + x_min) / 2 + pad_x))

    insertion = np.rint(new_face[y_crop, x_crop] * 255.0).astype("uint8")
    insertion_mask = np.rint(raw_mask[y_crop, x_crop] * 255.0).astype("uint8")
    insertion_mask[insertion_mask != 0] = 255

    # Pad the destination by half its size on every side so the insertion
    # stays inside the canvas for any center position.
    prior = np.rint(
        np.pad(old_face * 255.0, ((pad_y, pad_y), (pad_x, pad_x), (0, 0)), "constant")
    ).astype("uint8")

    # Zero the mask's outermost ring, then measure what is left of it
    n_mask = insertion_mask[1:-1, 1:-1, :]
    n_mask = cv2.copyMakeBorder(n_mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, 0)
    x, y, w, h = cv2.boundingRect(n_mask[:, :, 0])

    if w < 4 or h < 4:
        blended = prior
    else:
        blended = cv2.seamlessClone(
            insertion,  # pylint: disable=no-member
            prior,
            insertion_mask,
            (x_center, y_center),
            cv2.NORMAL_CLONE,
        )  # pylint: disable=no-member

    # Crop the padding with explicit end indices: `[pad:-pad]` would select
    # nothing when the padding is 0 (images less than 2 pixels on a side).
    blended = blended[pad_y : pad_y + full_h, pad_x : pad_x + full_w]

    return blended.astype("float32") / 255.0
y_right_mouth], ] ) return results if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--origin_url", type=str, default="./", help="origin images") parser.add_argument("--replace_url", type=str, default="./", help="restored faces") parser.add_argument("--save_url", type=str, default="./save") opts = parser.parse_args() origin_url = opts.origin_url replace_url = opts.replace_url save_url = opts.save_url if not os.path.exists(save_url): os.makedirs(save_url) face_detector = dlib.get_frontal_face_detector() landmark_locator = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat") count = 0 for x in os.listdir(origin_url): img_url = os.path.join(origin_url, x) pil_img = Image.open(img_url).convert("RGB") origin_width, origin_height = pil_img.size image = np.array(pil_img) start = time.time() faces = face_detector(image) done = time.time() if len(faces) == 0: print("Warning: There is no face in %s" % (x)) continue blended = image for face_id in range(len(faces)): current_face = faces[face_id] face_landmarks = landmark_locator(image, current_face) current_fl = search(face_landmarks) forward_mask = np.ones_like(image).astype("uint8") affine = compute_transformation_matrix(image, current_fl, False, target_face_scale=1.3) aligned_face = warp(image, affine, output_shape=(256, 256, 3), preserve_range=True) forward_mask = warp( forward_mask, affine, output_shape=(256, 256, 3), order=0, preserve_range=True ) affine_inverse = affine.inverse cur_face = aligned_face if replace_url != "": face_name = x[:-4] + "_" + str(face_id + 1) + ".png" cur_url = os.path.join(replace_url, face_name) restored_face = Image.open(cur_url).convert("RGB") restored_face = np.array(restored_face) cur_face = restored_face ## Histogram Color matching A = cv2.cvtColor(aligned_face.astype("uint8"), cv2.COLOR_RGB2BGR) B = cv2.cvtColor(cur_face.astype("uint8"), cv2.COLOR_RGB2BGR) B = match_histograms(B, A) cur_face = cv2.cvtColor(B.astype("uint8"), cv2.COLOR_BGR2RGB) 
def calculate_lookup(src_cdf, ref_cdf):
    """
    This method creates the lookup table
    :param array src_cdf: The cdf for the source image
    :param array ref_cdf: The cdf for the reference image
    :return: lookup_table: For every source bin, the index of the first
        reference bin whose cdf is >= the source cdf. A source bin that no
        reference bin reaches reuses the value of the previous source bin
        (0 for the first one). The table has at least 256 entries; entries
        past len(src_cdf) are 0.
    :rtype: array
    """
    src_cdf = np.asarray(src_cdf)
    ref_cdf = np.asarray(ref_cdf)
    n_src = len(src_cdf)

    # 256 entries as before; longer inputs no longer overflow the table
    lookup_table = np.zeros(max(256, n_src))
    if n_src == 0 or len(ref_cdf) == 0:
        return lookup_table

    # reaches[i, j] is True when reference bin j satisfies source bin i
    reaches = ref_cdf[np.newaxis, :] >= src_cdf[:, np.newaxis]
    first_hit = reaches.argmax(axis=1)
    has_hit = reaches.any(axis=1)

    # Forward-fill rows without a hit from the closest earlier row that had
    # one; -1 marks "no earlier hit", which maps to the initial value 0.
    last_hit_row = np.maximum.accumulate(np.where(has_hit, np.arange(n_src), -1))
    lookup_table[:n_src] = np.where(
        last_hit_row >= 0, first_hit[np.maximum(last_hit_row, 0)], 0
    )
    return lookup_table
def _rescale_landmark(img, landmark, normalize):
    """Return the landmarks to fit, rescaled to [-1, 1] when normalize is set."""
    if not normalize:
        return landmark

    h, w = img.shape[:2]
    # Work on a float copy: assigning the rescaled values back into the
    # caller's array mutated it, and truncated them when it was integer.
    scaled = np.array(landmark, dtype=np.float64)
    # NOTE(review): column 0 is divided by the image height and column 1 by
    # the width, exactly as before. If landmarks are (x, y) pairs these
    # divisors look swapped -- confirm before relying on normalize=True.
    scaled[:, 0] = scaled[:, 0] / h * 2 - 1.0
    scaled[:, 1] = scaled[:, 1] / w * 2 - 1.0
    return scaled


def compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0, output_size=512.0):
    """
    Estimate the similarity transform from the face template to the landmarks.
    :param array img: Image the landmarks were found in; only its shape is
        read, and only when normalize is set
    :param array landmark: (5, 2) landmark coordinates; never modified
    :param bool normalize: Rescale the landmarks to [-1, 1] before fitting
    :param float target_face_scale: Factor applied to the template points
        around the template center
    :param float output_size: Side length of the aligned crop the template
        is expressed in (512 for this high-resolution script)
    :return: The SimilarityTransform fitted with estimate(target_pts, landmark)
    """
    std_pts = _standard_face_pts()  # [-1,1]
    target_pts = (std_pts * target_face_scale + 1) / 2 * output_size

    landmark = _rescale_landmark(img, landmark, normalize)

    affine = SimilarityTransform()
    affine.estimate(target_pts, landmark)

    return affine


def compute_inverse_transformation_matrix(img, landmark, normalize, target_face_scale=1.0, output_size=512.0):
    """
    Estimate the similarity transform from the landmarks to the face template.
    :param array img: Image the landmarks were found in; only its shape is
        read, and only when normalize is set
    :param array landmark: (5, 2) landmark coordinates; never modified
    :param bool normalize: Rescale the landmarks to [-1, 1] before fitting
    :param float target_face_scale: Factor applied to the template points
        around the template center
    :param float output_size: Side length of the aligned crop the template
        is expressed in (512 for this high-resolution script)
    :return: The SimilarityTransform fitted with estimate(landmark, target_pts)
    """
    std_pts = _standard_face_pts()  # [-1,1]
    target_pts = (std_pts * target_face_scale + 1) / 2 * output_size

    landmark = _rescale_landmark(img, landmark, normalize)

    affine = SimilarityTransform()
    affine.estimate(landmark, target_pts)

    return affine
def Poisson_blending(im1, im2, mask):
    """
    Clone im2 into im1 with cv2.seamlessClone in MIXED_CLONE mode.
    :param array im1: Destination image, cast to uint8, shape (H, W, C)
    :param array im2: Source image, cast to uint8
    :param array mask: 3-D mask that is scaled by 255 here, so it is
        presumably expected in [0, 1]; it is eroded and inverted, and its
        first channel is used. The caller's array is not modified.
    :return: The cloned image divided by 255.0
    :rtype: array
    """
    # Build every intermediate as a new array: the in-place `mask *= 255`
    # and `mask /= 255` changed the caller's mask, and the division raises
    # for integer mask dtypes.
    mask = np.asarray(mask) * 255.0
    kernel = np.ones((10, 10), np.uint8)
    mask = cv2.erode(mask, kernel, iterations=1)
    mask = mask / 255
    mask = 1 - mask
    mask = mask * 255

    mask = mask[:, :, 0]
    # shape is (rows, cols, channels); seamlessClone takes the center as (x, y)
    height, width = im1.shape[:2]
    center = (int(width / 2), int(height / 2))
    result = cv2.seamlessClone(
        im2.astype("uint8"), im1.astype("uint8"), mask.astype("uint8"), center, cv2.MIXED_CLONE
    )

    return result / 255.0


def Poisson_B(im1, im2, mask, center):
    """
    Clone im2 into im1 around center with cv2.seamlessClone in NORMAL_CLONE mode.
    :param array im1: Destination image, cast to uint8
    :param array im2: Source image, cast to uint8
    :param array mask: Mask that is scaled by 255 here, so it is presumably
        expected in [0, 1]. The caller's array is not modified.
    :param tuple center: Center passed straight to cv2.seamlessClone
    :return: The cloned image divided by 255
    :rtype: array
    """
    # Scale into a new array instead of multiplying the caller's mask in place
    mask = np.asarray(mask) * 255.0

    result = cv2.seamlessClone(
        im2.astype("uint8"), im1.astype("uint8"), mask.astype("uint8"), center, cv2.NORMAL_CLONE
    )

    return result / 255
if __name__ == "__main__":
    # Warp every restored 512x512 face crop back onto its original photo.
    #   --origin_url   directory with the original photos
    #   --replace_url  directory with the restored crops, named
    #                  <photo name minus its last 4 characters>_<face number>.png
    #   --save_url     directory the blended photos are written to
    parser = argparse.ArgumentParser()
    parser.add_argument("--origin_url", type=str, default="./", help="origin images")
    parser.add_argument("--replace_url", type=str, default="./", help="restored faces")
    parser.add_argument("--save_url", type=str, default="./save")
    opts = parser.parse_args()

    origin_url = opts.origin_url
    replace_url = opts.replace_url
    save_url = opts.save_url

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs
    os.makedirs(save_url, exist_ok=True)

    face_detector = dlib.get_frontal_face_detector()
    landmark_locator = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")

    count = 0

    for x in os.listdir(origin_url):
        img_url = os.path.join(origin_url, x)
        # Close the file as soon as the pixels are converted
        with Image.open(img_url) as img_file:
            pil_img = img_file.convert("RGB")

        origin_width, origin_height = pil_img.size
        image = np.array(pil_img)

        faces = face_detector(image)

        if len(faces) == 0:
            print("Warning: There is no face in %s" % (x))
            continue

        # Running composite; every face is blended on top of the previous result
        blended = image
        for face_id, current_face in enumerate(faces):
            face_landmarks = landmark_locator(image, current_face)
            current_fl = search(face_landmarks)

            forward_mask = np.ones_like(image).astype("uint8")
            affine = compute_transformation_matrix(image, current_fl, False, target_face_scale=1.3)
            aligned_face = warp(image, affine, output_shape=(512, 512, 3), preserve_range=True)
            forward_mask = warp(
                forward_mask, affine, output_shape=(512, 512, 3), order=0, preserve_range=True
            )

            affine_inverse = affine.inverse
            cur_face = aligned_face
            if replace_url != "":
                # x[:-4] drops the last four characters of the file name
                # (e.g. ".jpg"); the restored crops must be named with the
                # same rule for this lookup to succeed.
                face_name = x[:-4] + "_" + str(face_id + 1) + ".png"
                cur_url = os.path.join(replace_url, face_name)
                with Image.open(cur_url) as face_file:
                    restored_face = np.array(face_file.convert("RGB"))

                cur_face = restored_face

            ## Histogram Color matching
            A = cv2.cvtColor(aligned_face.astype("uint8"), cv2.COLOR_RGB2BGR)
            B = cv2.cvtColor(cur_face.astype("uint8"), cv2.COLOR_RGB2BGR)
            B = match_histograms(B, A)
            cur_face = cv2.cvtColor(B.astype("uint8"), cv2.COLOR_BGR2RGB)

            warped_back = warp(
                cur_face,
                affine_inverse,
                output_shape=(origin_height, origin_width, 3),
                order=3,
                preserve_range=True,
            )

            backward_mask = warp(
                forward_mask,
                affine_inverse,
                output_shape=(origin_height, origin_width, 3),
                order=0,
                preserve_range=True,
            )  ## Nearest neighbour

            # blur_blending_cv2 returns values in [0, 1]; scale back to 0-255
            # so the result can serve as the background for the next face.
            blended = blur_blending_cv2(warped_back, blended, backward_mask)
            blended *= 255.0

        io.imsave(os.path.join(save_url, x), img_as_ubyte(blended / 255.0))

        count += 1

        if count % 1000 == 0:
            print("%d have finished ..." % (count))
if __name__ == "__main__":
    # Detect every face in the photos under --url and save each one as an
    # aligned 256x256 crop named
    # <photo name minus its last 4 characters>_<face number>.png in --save_url.
    parser = argparse.ArgumentParser()
    parser.add_argument("--url", type=str, default="/home/jingliao/ziyuwan/celebrities", help="input")
    parser.add_argument(
        "--save_url", type=str, default="/home/jingliao/ziyuwan/celebrities_detected_face_reid", help="output"
    )
    opts = parser.parse_args()

    url = opts.url
    save_url = opts.save_url

    # Both directories are created when missing, so a non-existent input
    # directory results in an empty run rather than an error.
    os.makedirs(url, exist_ok=True)
    os.makedirs(save_url, exist_ok=True)

    face_detector = dlib.get_frontal_face_detector()
    landmark_locator = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")

    count = 0

    for x in os.listdir(url):
        img_url = os.path.join(url, x)
        # Close the file as soon as the pixels are converted
        with Image.open(img_url) as img_file:
            pil_img = img_file.convert("RGB")

        image = np.array(pil_img)

        faces = face_detector(image)

        if len(faces) == 0:
            print("Warning: There is no face in %s" % (x))
            continue

        print(len(faces))

        # At least one face is present here, so no further length check is needed
        for face_id, current_face in enumerate(faces):
            face_landmarks = landmark_locator(image, current_face)
            current_fl = search(face_landmarks)
            affine = compute_transformation_matrix(image, current_fl, False, target_face_scale=1.3)
            aligned_face = warp(image, affine, output_shape=(256, 256, 3))
            # x[:-4] drops the last four characters of the file name (e.g. ".jpg")
            img_name = x[:-4] + "_" + str(face_id + 1)
            io.imsave(os.path.join(save_url, img_name + ".png"), img_as_ubyte(aligned_face))

        count += 1

        if count % 1000 == 0:
            print("%d have finished ..." % (count))
def compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):
    """Fit the similarity transform from the 512x512 face template to ``landmark``.

    Args:
        img: H x W x C image array; only its shape is read.
        landmark: (5, 2) array of ``[x, y]`` landmark positions, in the order
            produced by ``search`` (eyes, nose, mouth corners).
        normalize: when true, rescale the landmarks to [-1, 1] before fitting.
        target_face_scale: factor applied to the template before it is mapped
            to pixel coordinates.

    Returns:
        ``SimilarityTransform.params``: the 3x3 matrix estimated with the
        template points as source and ``landmark`` as destination.
    """
    std_pts = _standard_face_pts()  # [-1,1]
    target_pts = (std_pts * target_face_scale + 1) / 2 * 512.0

    h, w, c = img.shape
    if normalize:
        # ``search`` returns an integer array; normalising it in place would
        # truncate the values and also modify the caller's data, so use a
        # float copy instead.
        landmark = np.array(landmark, dtype=np.float64)
        # Column 0 is x (scaled by the width), column 1 is y (by the height).
        landmark[:, 0] = landmark[:, 0] / w * 2 - 1.0
        landmark[:, 1] = landmark[:, 1] / h * 2 - 1.0

    affine = SimilarityTransform()
    affine.estimate(target_pts, landmark)
    return affine.params
def create_dataloader(opt):
    """Build the ``DataLoader`` over a freshly initialised ``FaceTestDataset``.

    Reads ``batchSize``, ``serial_batches``, ``nThreads`` and ``isTrain`` from
    ``opt``; batches are shuffled unless ``serial_batches`` is set, and the
    last incomplete batch is dropped only when ``isTrain`` is true.
    """
    dataset = FaceTestDataset()
    dataset.initialize(opt)
    print("dataset [%s] of size %d was created" % (type(dataset).__name__, len(dataset)))

    loader_options = {
        "batch_size": opt.batchSize,
        "shuffle": not opt.serial_batches,
        "num_workers": int(opt.nThreads),
        "drop_last": opt.isTrain,
    }
    return torch.utils.data.DataLoader(dataset, **loader_options)
def get_params(opt, size):
    """Draw the random crop origin and flip flag for one sample.

    Args:
        opt: options object providing ``preprocess_mode``, ``load_size`` and
            ``crop_size``.
        size: ``(width, height)`` of the image before preprocessing.

    Returns:
        ``{"crop_pos": (x, y), "flip": bool}``. The crop origin is drawn
        uniformly from the positions that keep a ``crop_size`` window inside
        the resized image, and is 0 on an axis with no room.
    """
    w, h = size
    mode = opt.preprocess_mode

    # Size the image is expected to have after the resize step.
    if mode == "resize_and_crop":
        resized_w, resized_h = opt.load_size, opt.load_size
    elif mode == "scale_width_and_crop":
        resized_w, resized_h = opt.load_size, opt.load_size * h // w
    elif mode == "scale_shortside_and_crop":
        short_side, long_side = min(w, h), max(w, h)
        scaled_long = int(opt.load_size * long_side / short_side)
        # NOTE(review): the short side keeps its original length here, which
        # mirrors what __scale_shortside in this module produces; change both
        # together if the short side is meant to become load_size.
        if w == short_side:
            resized_w, resized_h = short_side, scaled_long
        else:
            resized_w, resized_h = scaled_long, short_side
    else:
        resized_w, resized_h = w, h

    # Keep this draw order (x, y, flip) so seeded runs stay reproducible.
    crop_x = random.randint(0, max(0, resized_w - opt.crop_size))
    crop_y = random.randint(0, max(0, resized_h - opt.crop_size))
    do_flip = random.random() > 0.5
    return {"crop_pos": (crop_x, crop_y), "flip": do_flip}
class CustomDataset(Pix2pixDataset):
    """Dataset that loads images from directories

    Use option --label_dir, --image_dir, --instance_dir to specify the directories.
    The images in the directories are sorted in alphabetical order and paired in order.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register this dataset's defaults and directory options on ``parser``.

        Starts from the ``Pix2pixDataset`` options, overrides several defaults
        (``load_size`` is 286 when training and 256 otherwise) and adds the
        ``--label_dir``, ``--image_dir`` and ``--instance_dir`` arguments.
        """
        parser = Pix2pixDataset.modify_commandline_options(parser, is_train)
        parser.set_defaults(preprocess_mode="resize_and_crop")
        load_size = 286 if is_train else 256
        parser.set_defaults(load_size=load_size)
        parser.set_defaults(crop_size=256)
        parser.set_defaults(display_winsize=256)
        parser.set_defaults(label_nc=13)
        parser.set_defaults(contain_dontcare_label=False)

        parser.add_argument(
            "--label_dir", type=str, required=True, help="path to the directory that contains label images"
        )
        parser.add_argument(
            "--image_dir", type=str, required=True, help="path to the directory that contains photo images"
        )
        parser.add_argument(
            "--instance_dir",
            type=str,
            default="",
            help="path to the directory that contains instance maps. Leave black if not exists",
        )
        return parser

    def get_paths(self, opt):
        """Collect the label, image and instance file lists named in ``opt``.

        Returns:
            ``(label_paths, image_paths, instance_paths)``; the instance list
            is empty when ``opt.instance_dir`` is an empty string.

        Raises:
            AssertionError: if the label and image directories yield a
                different number of files.
        """
        label_dir = opt.label_dir
        label_paths = make_dataset(label_dir, recursive=False, read_cache=True)

        image_dir = opt.image_dir
        image_paths = make_dataset(image_dir, recursive=False, read_cache=True)

        if len(opt.instance_dir) > 0:
            instance_dir = opt.instance_dir
            instance_paths = make_dataset(instance_dir, recursive=False, read_cache=True)
        else:
            instance_paths = []

        # The message has two %s placeholders that were never filled in; name
        # the two directories being compared.
        assert len(label_paths) == len(
            image_paths
        ), "The #images in %s and %s do not match. Is there something wrong?" % (label_dir, image_dir)

        return label_paths, image_paths, instance_paths
class FaceTestDataset(BaseDataset):
    """Test-time dataset pairing each face image with its per-part label masks.

    Faces are read from ``<dataroot>/<old_face_folder>`` and masks from
    ``<dataroot>/<old_face_label_folder>``, named ``<face stem>_<part>.png``.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add the ``--no_pairing_check`` flag to ``parser`` and return it."""
        parser.add_argument(
            "--no_pairing_check",
            action="store_true",
            help="If specified, skip sanity check of correct label-image file pairing",
        )
        #    parser.set_defaults(contain_dontcare_label=False)
        #    parser.set_defaults(no_instance=True)
        return parser

    def initialize(self, opt):
        """Record the options, list the face images and set the part names."""
        self.opt = opt

        image_path = os.path.join(opt.dataroot, opt.old_face_folder)
        label_path = os.path.join(opt.dataroot, opt.old_face_label_folder)

        image_list = sorted(os.listdir(image_path))
        # image_list=image_list[:opt.max_dataset_size]

        self.label_paths = label_path  ## Just the root dir
        self.image_paths = image_list  ## All the image name

        self.parts = [
            "skin",
            "hair",
            "l_brow",
            "r_brow",
            "l_eye",
            "r_eye",
            "eye_g",
            "l_ear",
            "r_ear",
            "ear_r",
            "nose",
            "mouth",
            "u_lip",
            "l_lip",
            "neck",
            "neck_l",
            "cloth",
            "hat",
        ]

        self.dataset_size = len(self.image_paths)

    def __getitem__(self, index):
        """Return the ``label``, ``image`` and ``path`` entries for one face.

        ``label`` stacks one channel per entry of ``self.parts``; a part with
        no mask file on disk contributes an all-zero channel.
        """
        params = get_params(self.opt, (-1, -1))
        image_name = self.image_paths[index]
        image_path = os.path.join(self.opt.dataroot, self.opt.old_face_folder, image_name)
        # convert() returns an independent image, so the file can be closed here.
        with Image.open(image_path) as image_file:
            image = image_file.convert("RGB")

        transform_image = get_transform(self.opt, params)
        image_tensor = transform_image(image)

        # Mask files are looked up by the image name minus its last four characters.
        img_name = image_name[:-4]
        transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)

        # Size the placeholder for a missing part from the transformed image
        # rather than from load_size, so torch.stack still works when the
        # transform crops or rescales to a different size. The mask transform
        # uses the same opt and params.
        spatial_size = tuple(image_tensor.shape[-2:])

        full_label = []
        for each_part in self.parts:
            part_name = img_name + "_" + each_part + ".png"
            part_url = os.path.join(self.label_paths, part_name)

            if os.path.exists(part_url):
                with Image.open(part_url) as label_file:
                    label = label_file.convert("RGB")
                label_tensor = transform_label(label)  ## 3 channels and pixel [0,1]
                full_label.append(label_tensor[0])
            else:
                full_label.append(torch.zeros(spatial_size))

        full_label_tensor = torch.stack(full_label, 0)

        input_dict = {
            "label": full_label_tensor,
            "image": image_tensor,
            "path": image_path,
        }

        return input_dict

    def __len__(self):
        """Return the number of face images found by ``initialize``."""
        return self.dataset_size
"Found 0 images in: " + root + "\n" "Supported image extensions are: " + ",".join(IMG_EXTENSIONS) ) ) self.root = root self.imgs = imgs self.transform = transform self.return_paths = return_paths self.loader = loader def __getitem__(self, index): path = self.imgs[index] img = self.loader(path) if self.transform is not None: img = self.transform(img) if self.return_paths: return img, path else: return img def __len__(self): return len(self.imgs) ================================================ FILE: Face_Enhancement/data/pix2pix_dataset.py ================================================ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from data.base_dataset import BaseDataset, get_params, get_transform from PIL import Image import util.util as util import os class Pix2pixDataset(BaseDataset): @staticmethod def modify_commandline_options(parser, is_train): parser.add_argument( "--no_pairing_check", action="store_true", help="If specified, skip sanity check of correct label-image file pairing", ) return parser def initialize(self, opt): self.opt = opt label_paths, image_paths, instance_paths = self.get_paths(opt) util.natural_sort(label_paths) util.natural_sort(image_paths) if not opt.no_instance: util.natural_sort(instance_paths) label_paths = label_paths[: opt.max_dataset_size] image_paths = image_paths[: opt.max_dataset_size] instance_paths = instance_paths[: opt.max_dataset_size] if not opt.no_pairing_check: for path1, path2 in zip(label_paths, image_paths): assert self.paths_match(path1, path2), ( "The label-image pair (%s, %s) do not look like the right pair because the filenames are quite different. Are you sure about the pairing? Please see data/pix2pix_dataset.py to see what is going on, and use --no_pairing_check to bypass this." 
def paths_match(self, path1, path2):
    """Return True when both paths share a file name once the extension is dropped.

    Only the final path component is compared, and only its last extension
    is stripped.
    """
    stems = [os.path.splitext(os.path.basename(path))[0] for path in (path1, path2)]
    return stems[0] == stems[1]
def find_model_using_name(model_name):
    """Import ``models/<model_name>_model.py`` and return its model class.

    The class is matched case-insensitively against ``model_name`` with the
    underscores removed and ``model`` appended, and must subclass
    ``torch.nn.Module``. If several names match, the last one found wins.

    Raises:
        SystemExit: with status 1, after printing a diagnostic, when the
            module defines no matching class.
    """
    # Given the option --model [modelname],
    # the file "models/modelname_model.py"
    # will be imported.
    model_filename = "models." + model_name + "_model"
    modellib = importlib.import_module(model_filename)

    # In the file, the class called ModelNameModel() will
    # be instantiated. It has to be a subclass of torch.nn.Module,
    # and it is case-insensitive.
    model = None
    target_model_name = model_name.replace("_", "") + "model"
    wanted = target_model_name.lower()
    for name, cls in modellib.__dict__.items():
        # issubclass() raises TypeError for non-classes, so skip a function or
        # constant that merely shares the name.
        if name.lower() == wanted and isinstance(cls, type) and issubclass(cls, torch.nn.Module):
            model = cls

    if model is None:
        print(
            "In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s in lowercase."
            % (model_filename, target_model_name)
        )
        # This is a fatal configuration error, so exit with a failing status
        # instead of 0.
        raise SystemExit(1)

    return model
class SPADEResnetBlock(nn.Module):
    """Residual block whose normalisation layers are SPADE modules.

    Each SPADE layer is called with the feature map, the segmentation map
    ``seg`` and ``degraded_image``. A 1x1 convolution shortcut is learned only
    when the input and output channel counts differ.
    """

    def __init__(self, fin, fout, opt):
        super().__init__()
        # Attributes
        self.learned_shortcut = fin != fout
        fmiddle = min(fin, fout)
        self.opt = opt

        # Convolutions first, in the same registration order as before.
        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)

        # Wrap every convolution in spectral norm when the config asks for it.
        if "spectral" in opt.norm_G:
            conv_names = ["conv_0", "conv_1"]
            if self.learned_shortcut:
                conv_names.append("conv_s")
            for conv_name in conv_names:
                setattr(self, conv_name, spectral_norm(getattr(self, conv_name)))

        # The SPADE config string is norm_G with the "spectral" marker removed.
        spade_config_str = opt.norm_G.replace("spectral", "")
        self.norm_0 = SPADE(spade_config_str, fin, opt.semantic_nc, opt)
        self.norm_1 = SPADE(spade_config_str, fmiddle, opt.semantic_nc, opt)
        if self.learned_shortcut:
            self.norm_s = SPADE(spade_config_str, fin, opt.semantic_nc, opt)

    def forward(self, x, seg, degraded_image):
        """Return shortcut(x) plus two rounds of norm -> leaky ReLU -> conv."""
        residual = self.shortcut(x, seg, degraded_image)

        hidden = self.norm_0(x, seg, degraded_image)
        hidden = self.conv_0(self.actvn(hidden))
        hidden = self.norm_1(hidden, seg, degraded_image)
        hidden = self.conv_1(self.actvn(hidden))

        return residual + hidden

    def shortcut(self, x, seg, degraded_image):
        """Return ``x`` unchanged, or normalised and projected when the shortcut is learned."""
        if not self.learned_shortcut:
            return x
        return self.conv_s(self.norm_s(x, seg, degraded_image))

    def actvn(self, x):
        """Leaky ReLU with negative slope 0.2."""
        return F.leaky_relu(x, 2e-1)
class SPADEResnetBlock_non_spade(nn.Module):
    """Residual block with the same layers as ``SPADEResnetBlock`` but no SPADE in its forward pass.

    The SPADE layers are still constructed, yet ``forward`` and ``shortcut``
    never call them, and ``seg`` and ``degraded_image`` are accepted but ignored.
    NOTE(review): presumably the unused layers are kept so parameter names
    line up with saved checkpoints -- confirm before removing them.
    """

    def __init__(self, fin, fout, opt):
        super().__init__()
        # Attributes
        self.learned_shortcut = fin != fout
        fmiddle = min(fin, fout)
        self.opt = opt
        use_spectral = "spectral" in opt.norm_G

        # create conv layers
        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)

        # apply spectral norm if specified
        if use_spectral:
            self.conv_0 = spectral_norm(self.conv_0)
            self.conv_1 = spectral_norm(self.conv_1)
        if use_spectral and self.learned_shortcut:
            self.conv_s = spectral_norm(self.conv_s)

        # define normalization layers (constructed but not used by forward)
        spade_config_str = opt.norm_G.replace("spectral", "")
        self.norm_0 = SPADE(spade_config_str, fin, opt.semantic_nc, opt)
        self.norm_1 = SPADE(spade_config_str, fmiddle, opt.semantic_nc, opt)
        if self.learned_shortcut:
            self.norm_s = SPADE(spade_config_str, fin, opt.semantic_nc, opt)

    def forward(self, x, seg, degraded_image):
        """Return shortcut(x) plus conv_1(act(conv_0(act(x))))."""
        skip = self.shortcut(x, seg, degraded_image)
        branch = self.conv_0(self.actvn(x))
        branch = self.conv_1(self.actvn(branch))
        return skip + branch

    def shortcut(self, x, seg, degraded_image):
        """Return ``x``, passed through the 1x1 convolution when the shortcut is learned."""
        return self.conv_s(x) if self.learned_shortcut else x

    def actvn(self, x):
        """Leaky ReLU with negative slope 0.2."""
        return F.leaky_relu(x, 2e-1)
class BaseNetwork(nn.Module):
    """Common base for the networks: parameter report and weight initialisation."""

    def __init__(self):
        super(BaseNetwork, self).__init__()

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Hook for subclasses to add options; the base returns ``parser`` unchanged."""
        return parser

    def print_network(self):
        """Print the class name and the parameter count in millions."""
        if isinstance(self, list):
            self = self[0]
        num_params = sum(param.numel() for param in self.parameters())
        print(
            "Network [%s] was created. Total number of parameters: %.1f million. "
            "To see the architecture, do print(network)." % (type(self).__name__, num_params / 1000000)
        )

    def init_weights(self, init_type="normal", gain=0.02):
        """Initialise BatchNorm2d, Conv* and Linear* modules in place.

        Args:
            init_type: one of ``normal``, ``xavier``, ``xavier_uniform``,
                ``kaiming``, ``orthogonal`` or ``none`` (which calls the
                module's own ``reset_parameters``).
            gain: std for ``normal`` and for BatchNorm weights; gain for
                ``xavier`` and ``orthogonal``.

        Raises:
            NotImplementedError: if ``init_type`` is unknown and a Conv or
                Linear module is encountered.
        """
        weight_initializers = {
            "normal": lambda weight: init.normal_(weight, 0.0, gain),
            "xavier": lambda weight: init.xavier_normal_(weight, gain=gain),
            "xavier_uniform": lambda weight: init.xavier_uniform_(weight, gain=1.0),
            "kaiming": lambda weight: init.kaiming_normal_(weight, a=0, mode="fan_in"),
            "orthogonal": lambda weight: init.orthogonal_(weight, gain=gain),
        }

        def zero_bias(m):
            if hasattr(m, "bias") and m.bias is not None:
                init.constant_(m.bias.data, 0.0)

        def init_func(m):
            classname = m.__class__.__name__
            if "BatchNorm2d" in classname:
                if hasattr(m, "weight") and m.weight is not None:
                    init.normal_(m.weight.data, 1.0, gain)
                zero_bias(m)
                return

            if not hasattr(m, "weight"):
                return
            if "Conv" not in classname and "Linear" not in classname:
                return

            if init_type == "none":  # uses pytorch's default init method
                m.reset_parameters()
            elif init_type in weight_initializers:
                weight_initializers[init_type](m.weight.data)
            else:
                raise NotImplementedError("initialization method [%s] is not implemented" % init_type)
            zero_bias(m)

        self.apply(init_func)

        # propagate to children
        for child in self.children():
            if hasattr(child, "init_weights"):
                child.init_weights(init_type, gain)
class ConvEncoder(BaseNetwork):
    """Strided-convolution image encoder producing ``mu`` and ``logvar``.

    Same architecture as the image discriminator.
    """

    def __init__(self, opt):
        super().__init__()

        kw = 3
        pw = int(np.ceil((kw - 1.0) / 2))
        ndf = opt.ngf
        norm_layer = get_nonspade_norm_layer(opt, opt.norm_E)

        # layer1..layer5: stride-2 convolutions, widening 3 -> ndf -> ... -> 8*ndf.
        widths = [3, ndf * 1, ndf * 2, ndf * 4, ndf * 8, ndf * 8]
        for idx in range(5):
            conv = nn.Conv2d(widths[idx], widths[idx + 1], kw, stride=2, padding=pw)
            setattr(self, "layer%d" % (idx + 1), norm_layer(conv))
        if opt.crop_size >= 256:
            self.layer6 = norm_layer(nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw))

        # The linear heads expect a flattened (8 * ndf) x 4 x 4 feature map.
        self.so = s0 = 4
        flat_features = ndf * 8 * s0 * s0
        self.fc_mu = nn.Linear(flat_features, 256)
        self.fc_var = nn.Linear(flat_features, 256)

        self.actvn = nn.LeakyReLU(0.2, False)
        self.opt = opt

    def forward(self, x):
        """Encode ``x`` (resized to 256x256 when needed) into ``(mu, logvar)``."""
        if tuple(x.shape[2:4]) != (256, 256):
            x = F.interpolate(x, size=(256, 256), mode="bilinear")

        # The first layer sees the raw input; each later layer sees activated features.
        x = self.layer1(x)
        stages = [self.layer2, self.layer3, self.layer4, self.layer5]
        if self.opt.crop_size >= 256:
            stages.append(self.layer6)
        for stage in stages:
            x = stage(self.actvn(x))
        x = self.actvn(x)

        flat = x.view(x.size(0), -1)
        mu = self.fc_mu(flat)
        logvar = self.fc_var(flat)
        return mu, logvar
class SPADEGenerator(BaseNetwork):
    """SPADE-style generator conditioned on a parsing map and a degraded image.

    Which residual blocks receive the conditioning is selected with
    ``opt.injection_layer`` ("all", or a single stage id "1".."6"); the other
    stages use ``SPADEResnetBlock_non_spade``.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register generator options and set the default ``norm_G``."""
        parser.set_defaults(norm_G="spectralspadesyncbatch3x3")
        parser.add_argument(
            "--num_upsampling_layers",
            choices=("normal", "more", "most"),
            default="normal",
            help="If 'more', adds upsampling layer between the two middle resnet blocks. If 'most', also add one more upsampling + resnet layer at the end of the generator",
        )
        return parser

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        nf = opt.ngf

        self.sw, self.sh = self.compute_latent_vector_size(opt)
        print("The size of the latent vector size is [%d,%d]" % (self.sw, self.sh))

        if opt.use_vae:
            # In case of VAE, we will sample from random z vector
            self.fc = nn.Linear(opt.z_dim, 16 * nf * self.sw * self.sh)
        else:
            # Otherwise, we make the network deterministic by starting with
            # the downsampled degraded image (no parsing map) or the
            # downsampled segmentation map instead of random z
            in_nc = 3 if opt.no_parsing_map else opt.semantic_nc
            self.fc = nn.Conv2d(in_nc, 16 * nf, 3, padding=1)

        def make_block(stage_id, fin, fout):
            # Conditioned block when this stage is selected, plain one otherwise.
            if opt.injection_layer in ("all", stage_id):
                return SPADEResnetBlock(fin, fout, opt)
            return SPADEResnetBlock_non_spade(fin, fout, opt)

        # Construction order is kept identical to the original so that
        # parameter registration and RNG consumption do not change.
        self.head_0 = make_block("1", 16 * nf, 16 * nf)
        self.G_middle_0 = make_block("2", 16 * nf, 16 * nf)
        self.G_middle_1 = make_block("2", 16 * nf, 16 * nf)
        self.up_0 = make_block("3", 16 * nf, 8 * nf)
        self.up_1 = make_block("4", 8 * nf, 4 * nf)
        self.up_2 = make_block("5", 4 * nf, 2 * nf)
        self.up_3 = make_block("6", 2 * nf, 1 * nf)

        final_nc = nf
        if opt.num_upsampling_layers == "most":
            self.up_4 = SPADEResnetBlock(1 * nf, nf // 2, opt)
            final_nc = nf // 2

        self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)
        self.up = nn.Upsample(scale_factor=2)

    def compute_latent_vector_size(self, opt):
        """Return ``(sw, sh)``, the spatial size of the coarsest feature map.

        Raises:
            ValueError: if ``opt.num_upsampling_layers`` is not one of
                "normal", "more" or "most".
        """
        up_layers_by_mode = {"normal": 5, "more": 6, "most": 7}
        if opt.num_upsampling_layers not in up_layers_by_mode:
            raise ValueError("opt.num_upsampling_layers [%s] not recognized" % opt.num_upsampling_layers)
        num_up_layers = up_layers_by_mode[opt.num_upsampling_layers]

        sw = opt.load_size // (2 ** num_up_layers)
        sh = round(sw / opt.aspect_ratio)
        return sw, sh

    def forward(self, input, degraded_image, z=None):
        """Generate an image from the parsing map and the degraded image.

        Args:
            input: parsing/segmentation map, passed to every block as ``seg``.
            degraded_image: degraded face image used as conditioning.
            z: optional latent code; only used when ``opt.use_vae`` is set, in
                which case a standard-normal sample is drawn if it is None.

        Returns:
            The generated image after a final ``tanh``.
        """
        seg = input

        if self.opt.use_vae:
            # we sample z from unit normal and reshape the tensor
            if z is None:
                # ``input.device`` works for CPU and CUDA tensors alike;
                # ``get_device()`` returns -1 on CPU, which randn rejects.
                z = torch.randn(
                    input.size(0), self.opt.z_dim, dtype=torch.float32, device=input.device
                )
            x = self.fc(z)
            x = x.view(-1, 16 * self.opt.ngf, self.sh, self.sw)
        else:
            # we downsample the conditioning input and run convolution
            if self.opt.no_parsing_map:
                x = F.interpolate(degraded_image, size=(self.sh, self.sw), mode="bilinear")
            else:
                x = F.interpolate(seg, size=(self.sh, self.sw), mode="nearest")
            x = self.fc(x)

        x = self.head_0(x, seg, degraded_image)

        x = self.up(x)
        x = self.G_middle_0(x, seg, degraded_image)

        if self.opt.num_upsampling_layers in ("more", "most"):
            x = self.up(x)

        x = self.G_middle_1(x, seg, degraded_image)

        x = self.up(x)
        x = self.up_0(x, seg, degraded_image)
        x = self.up(x)
        x = self.up_1(x, seg, degraded_image)
        x = self.up(x)
        x = self.up_2(x, seg, degraded_image)
        x = self.up(x)
        x = self.up_3(x, seg, degraded_image)

        if self.opt.num_upsampling_layers == "most":
            x = self.up(x)
            x = self.up_4(x, seg, degraded_image)

        x = self.conv_img(F.leaky_relu(x, 2e-1))
        # torch.tanh replaces the deprecated F.tanh; the result is identical.
        x = torch.tanh(x)

        return x
def get_nonspade_norm_layer(opt, norm_type="instance"):
    """Build a function that wraps a layer with the requested normalisation.

    Args:
        opt: options object; not read by this function, kept for interface
            compatibility with callers.
        norm_type: "none", "batch", "sync_batch" or "instance", optionally
            prefixed with "spectral" to also apply spectral normalisation to
            the wrapped layer (e.g. "spectralinstance").

    Returns:
        A callable ``add_norm_layer(layer)`` that returns either the
        (possibly spectral-normalised) layer itself, or an ``nn.Sequential``
        of the layer followed by the normalisation layer.
    """

    # helper function to get # output channels of the previous layer
    def get_out_channel(layer):
        if hasattr(layer, "out_channels"):
            return getattr(layer, "out_channels")
        return layer.weight.size(0)

    # this function will be returned
    def add_norm_layer(layer):
        if norm_type.startswith("spectral"):
            layer = spectral_norm(layer)
            subnorm_type = norm_type[len("spectral") :]
        else:
            # Without this branch a non-"spectral" norm_type (including the
            # default "instance") left subnorm_type unbound and crashed.
            subnorm_type = norm_type

        if subnorm_type == "none" or len(subnorm_type) == 0:
            return layer

        # remove bias in the previous layer, which is meaningless
        # since it has no effect after normalization
        if getattr(layer, "bias", None) is not None:
            delattr(layer, "bias")
            layer.register_parameter("bias", None)

        if subnorm_type == "batch":
            norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == "sync_batch":
            norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == "instance":
            norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
        else:
            raise ValueError("normalization layer %s is not recognized" % subnorm_type)

        return nn.Sequential(layer, norm_layer)

    return add_norm_layer
class Pix2PixModel(torch.nn.Module):
    """Wraps generator, discriminator and optional encoder behind one module.

    All network calls go through :meth:`forward`, which dispatches on a
    ``mode`` string, so the whole model can be wrapped by ``DataParallel``.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Let the network definitions register their options on ``parser``."""
        networks.modify_commandline_options(parser, is_train)
        return parser

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        self.FloatTensor = torch.cuda.FloatTensor if self.use_gpu() else torch.FloatTensor
        self.ByteTensor = torch.cuda.ByteTensor if self.use_gpu() else torch.ByteTensor

        self.netG, self.netD, self.netE = self.initialize_networks(opt)

        # set loss functions
        if opt.isTrain:
            self.criterionGAN = networks.GANLoss(opt.gan_mode, tensor=self.FloatTensor, opt=self.opt)
            self.criterionFeat = torch.nn.L1Loss()
            if not opt.no_vgg_loss:
                self.criterionVGG = networks.VGGLoss(self.opt.gpu_ids)
            if opt.use_vae:
                self.KLDLoss = networks.KLDLoss()

    # Entry point for all calls involving forward pass
    # of deep networks. We used this approach since DataParallel module
    # can't parallelize custom functions, we branch to different
    # routines based on |mode|.
    def forward(self, data, mode):
        """Run one of the model's routines on ``data``.

        Args:
            data: dict with "label" and "image" (plus "degraded_image" when
                training), see :meth:`preprocess_input`.
            mode: "generator", "discriminator", "encode_only" or "inference".

        Raises:
            ValueError: if ``mode`` is none of the above.
        """
        input_semantics, real_image, degraded_image = self.preprocess_input(data)

        if mode == "generator":
            g_loss, generated = self.compute_generator_loss(input_semantics, degraded_image, real_image)
            return g_loss, generated
        if mode == "discriminator":
            return self.compute_discriminator_loss(input_semantics, degraded_image, real_image)
        if mode == "encode_only":
            _, mu, logvar = self.encode_z(real_image)
            return mu, logvar
        if mode == "inference":
            with torch.no_grad():
                fake_image, _ = self.generate_fake(input_semantics, degraded_image, real_image)
            return fake_image
        raise ValueError("|mode| is invalid")

    def create_optimizers(self, opt):
        """Return ``(optimizer_G, optimizer_D)``.

        ``optimizer_D`` is ``None`` when ``opt.isTrain`` is false, because no
        discriminator exists in that case (previously this path raised
        ``UnboundLocalError``).
        """
        G_params = list(self.netG.parameters())
        if opt.use_vae:
            G_params += list(self.netE.parameters())

        beta1, beta2 = opt.beta1, opt.beta2
        if opt.no_TTUR:
            G_lr, D_lr = opt.lr, opt.lr
        else:
            # Two time-scale update rule: slower generator, faster discriminator.
            G_lr, D_lr = opt.lr / 2, opt.lr * 2

        optimizer_G = torch.optim.Adam(G_params, lr=G_lr, betas=(beta1, beta2))
        optimizer_D = None
        if opt.isTrain:
            D_params = list(self.netD.parameters())
            optimizer_D = torch.optim.Adam(D_params, lr=D_lr, betas=(beta1, beta2))

        return optimizer_G, optimizer_D

    def save(self, epoch):
        """Save G and D (and E when ``opt.use_vae``) for ``epoch``."""
        util.save_network(self.netG, "G", epoch, self.opt)
        util.save_network(self.netD, "D", epoch, self.opt)
        if self.opt.use_vae:
            util.save_network(self.netE, "E", epoch, self.opt)

    ############################################################################
    # Private helper methods
    ############################################################################

    def initialize_networks(self, opt):
        """Create the networks and load weights when testing or resuming."""
        netG = networks.define_G(opt)
        netD = networks.define_D(opt) if opt.isTrain else None
        netE = networks.define_E(opt) if opt.use_vae else None

        if not opt.isTrain or opt.continue_train:
            netG = util.load_network(netG, "G", opt.which_epoch, opt)
            if opt.isTrain:
                netD = util.load_network(netD, "D", opt.which_epoch, opt)
            if opt.use_vae:
                netE = util.load_network(netE, "E", opt.which_epoch, opt)

        return netG, netD, netE

    # preprocess the input, such as moving the tensors to GPUs
    # |data|: dictionary of the input data
    def preprocess_input(self, data):
        """Move the tensors in ``data`` to the GPU if one is configured.

        Note that ``data`` is updated in place.

        Returns:
            ``(label, image, degraded_image)``. Outside training the input
            image itself is the degraded face, so it is returned twice.
        """
        if not self.opt.isTrain:
            if self.use_gpu():
                data["label"] = data["label"].cuda()
                data["image"] = data["image"].cuda()
            # While testing, the input image is the degraded face
            return data["label"], data["image"], data["image"]

        if self.use_gpu():
            data["label"] = data["label"].cuda()
            data["degraded_image"] = data["degraded_image"].cuda()
            data["image"] = data["image"].cuda()

        return data["label"], data["image"], data["degraded_image"]

    def compute_generator_loss(self, input_semantics, degraded_image, real_image):
        """Return ``(G_losses, fake_image)`` for one generator step.

        ``G_losses`` may contain "KLD", "GAN", "GAN_Feat" and "VGG" depending
        on the options.
        """
        G_losses = {}

        fake_image, KLD_loss = self.generate_fake(
            input_semantics, degraded_image, real_image, compute_kld_loss=self.opt.use_vae
        )

        if self.opt.use_vae:
            G_losses["KLD"] = KLD_loss

        pred_fake, pred_real = self.discriminate(input_semantics, fake_image, real_image)

        G_losses["GAN"] = self.criterionGAN(pred_fake, True, for_discriminator=False)

        if not self.opt.no_ganFeat_loss:
            num_D = len(pred_fake)
            GAN_Feat_loss = self.FloatTensor(1).fill_(0)
            for i in range(num_D):  # for each discriminator
                # last output is the final prediction, so we exclude it
                num_intermediate_outputs = len(pred_fake[i]) - 1
                for j in range(num_intermediate_outputs):  # for each layer output
                    unweighted_loss = self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach())
                    GAN_Feat_loss += unweighted_loss * self.opt.lambda_feat / num_D
            G_losses["GAN_Feat"] = GAN_Feat_loss

        if not self.opt.no_vgg_loss:
            G_losses["VGG"] = self.criterionVGG(fake_image, real_image) * self.opt.lambda_vgg

        return G_losses, fake_image

    def compute_discriminator_loss(self, input_semantics, degraded_image, real_image):
        """Return the dict of discriminator losses ("D_Fake", "D_real")."""
        D_losses = {}
        # The generator is frozen for the discriminator step.
        with torch.no_grad():
            fake_image, _ = self.generate_fake(input_semantics, degraded_image, real_image)
            fake_image = fake_image.detach()
            fake_image.requires_grad_()

        pred_fake, pred_real = self.discriminate(input_semantics, fake_image, real_image)

        D_losses["D_Fake"] = self.criterionGAN(pred_fake, False, for_discriminator=True)
        D_losses["D_real"] = self.criterionGAN(pred_real, True, for_discriminator=True)

        return D_losses

    def encode_z(self, real_image):
        """Encode ``real_image`` and return ``(z, mu, logvar)``."""
        mu, logvar = self.netE(real_image)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar

    def generate_fake(self, input_semantics, degraded_image, real_image, compute_kld_loss=False):
        """Run the generator and return ``(fake_image, KLD_loss)``.

        ``KLD_loss`` is ``None`` unless ``compute_kld_loss`` is set, which is
        only valid when ``opt.use_vae`` is enabled.
        """
        # Checked up front so an invalid request fails before any network runs.
        assert (
            not compute_kld_loss
        ) or self.opt.use_vae, "You cannot compute KLD loss if opt.use_vae == False"

        z = None
        KLD_loss = None
        if self.opt.use_vae:
            z, mu, logvar = self.encode_z(real_image)
            if compute_kld_loss:
                KLD_loss = self.KLDLoss(mu, logvar) * self.opt.lambda_kld

        fake_image = self.netG(input_semantics, degraded_image, z=z)

        return fake_image, KLD_loss

    # Given fake and real image, return the prediction of discriminator
    # for each fake and real image.
    def discriminate(self, input_semantics, fake_image, real_image):
        """Return ``(pred_fake, pred_real)`` from a single discriminator pass."""
        if self.opt.no_parsing_map:
            fake_concat = fake_image
            real_concat = real_image
        else:
            fake_concat = torch.cat([input_semantics, fake_image], dim=1)
            real_concat = torch.cat([input_semantics, real_image], dim=1)

        # In Batch Normalization, the fake and real images are
        # recommended to be in the same batch to avoid disparate
        # statistics in fake and real images.
        # So both fake and real images are fed to D all at once.
        fake_and_real = torch.cat([fake_concat, real_concat], dim=0)

        discriminator_out = self.netD(fake_and_real)

        pred_fake, pred_real = self.divide_pred(discriminator_out)

        return pred_fake, pred_real

    # Take the prediction of fake and real images from the combined batch
    def divide_pred(self, pred):
        """Split a combined fake+real prediction along the batch dimension.

        The first half of the batch is the fake part, the second half the
        real part. ``pred`` is either a tensor or, for a multiscale
        discriminator, a list of lists of tensors.
        """
        if isinstance(pred, list):
            fake = [[tensor[: tensor.size(0) // 2] for tensor in p] for p in pred]
            real = [[tensor[tensor.size(0) // 2 :] for tensor in p] for p in pred]
        else:
            half = pred.size(0) // 2
            fake = pred[:half]
            real = pred[half:]

        return fake, real

    def get_edges(self, t):
        """Return a float mask that is 1 where ``t`` differs from a neighbour.

        Neighbours are compared along the last two dimensions of ``t``.
        """
        edge = self.ByteTensor(t.size()).zero_()
        edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1])
        edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1])
        edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
        edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
        return edge.float()

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps.mul(std) + mu

    def use_gpu(self):
        """Return True when at least one GPU id is configured."""
        return len(self.opt.gpu_ids) > 0
It decides where to store samples and models", ) parser.add_argument( "--gpu_ids", type=str, default="0", help="gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU" ) parser.add_argument( "--checkpoints_dir", type=str, default="./checkpoints", help="models are saved here" ) parser.add_argument("--model", type=str, default="pix2pix", help="which model to use") parser.add_argument( "--norm_G", type=str, default="spectralinstance", help="instance normalization or batch normalization", ) parser.add_argument( "--norm_D", type=str, default="spectralinstance", help="instance normalization or batch normalization", ) parser.add_argument( "--norm_E", type=str, default="spectralinstance", help="instance normalization or batch normalization", ) parser.add_argument("--phase", type=str, default="train", help="train, val, test, etc") # input/output sizes parser.add_argument("--batchSize", type=int, default=1, help="input batch size") parser.add_argument( "--preprocess_mode", type=str, default="scale_width_and_crop", help="scaling and cropping of images at load time.", choices=( "resize_and_crop", "crop", "scale_width", "scale_width_and_crop", "scale_shortside", "scale_shortside_and_crop", "fixed", "none", "resize", ), ) parser.add_argument( "--load_size", type=int, default=1024, help="Scale images to this size. The final image will be cropped to --crop_size.", ) parser.add_argument( "--crop_size", type=int, default=512, help="Crop to the width of crop_size (after initially scaling the images to load_size.)", ) parser.add_argument( "--aspect_ratio", type=float, default=1.0, help="The ratio width/height. The final height of the load image will be crop_size/aspect_ratio", ) parser.add_argument( "--label_nc", type=int, default=182, help="# of input label classes without unknown class. 
If you have unknown class as class label, specify --contain_dopntcare_label.", ) parser.add_argument( "--contain_dontcare_label", action="store_true", help="if the label map contains dontcare label (dontcare=255)", ) parser.add_argument("--output_nc", type=int, default=3, help="# of output image channels") # for setting inputs parser.add_argument("--dataroot", type=str, default="./datasets/cityscapes/") parser.add_argument("--dataset_mode", type=str, default="coco") parser.add_argument( "--serial_batches", action="store_true", help="if true, takes images in order to make batches, otherwise takes them randomly", ) parser.add_argument( "--no_flip", action="store_true", help="if specified, do not flip the images for data argumentation", ) parser.add_argument("--nThreads", default=0, type=int, help="# threads for loading data") parser.add_argument( "--max_dataset_size", type=int, default=sys.maxsize, help="Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.", ) parser.add_argument( "--load_from_opt_file", action="store_true", help="load the options from checkpoints and use that as default", ) parser.add_argument( "--cache_filelist_write", action="store_true", help="saves the current filelist into a text file, so that it loads faster", ) parser.add_argument( "--cache_filelist_read", action="store_true", help="reads from the file list cache" ) # for displays parser.add_argument("--display_winsize", type=int, default=400, help="display window size") # for generator parser.add_argument( "--netG", type=str, default="spade", help="selects model to use for netG (pix2pixhd | spade)" ) parser.add_argument("--ngf", type=int, default=64, help="# of gen filters in first conv layer") parser.add_argument( "--init_type", type=str, default="xavier", help="network initialization [normal|xavier|kaiming|orthogonal]", ) parser.add_argument( "--init_variance", type=float, default=0.02, help="variance of the 
initialization distribution" ) parser.add_argument("--z_dim", type=int, default=256, help="dimension of the latent z vector") parser.add_argument( "--no_parsing_map", action="store_true", help="During training, we do not use the parsing map" ) # for instance-wise features parser.add_argument( "--no_instance", action="store_true", help="if specified, do *not* add instance map as input" ) parser.add_argument( "--nef", type=int, default=16, help="# of encoder filters in the first conv layer" ) parser.add_argument("--use_vae", action="store_true", help="enable training with an image encoder.") parser.add_argument( "--tensorboard_log", action="store_true", help="use tensorboard to record the resutls" ) # parser.add_argument('--img_dir',) parser.add_argument( "--old_face_folder", type=str, default="", help="The folder name of input old face" ) parser.add_argument( "--old_face_label_folder", type=str, default="", help="The folder name of input old face label" ) parser.add_argument("--injection_layer", type=str, default="all", help="") self.initialized = True return parser def gather_options(self): # initialize parser with basic options if not self.initialized: parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser = self.initialize(parser) # get the basic options opt, unknown = parser.parse_known_args() # modify model-related parser options model_name = opt.model model_option_setter = models.get_option_setter(model_name) parser = model_option_setter(parser, self.isTrain) # modify dataset-related parser options # dataset_mode = opt.dataset_mode # dataset_option_setter = data.get_option_setter(dataset_mode) # parser = dataset_option_setter(parser, self.isTrain) opt, unknown = parser.parse_known_args() # if there is opt_file, load it. 
# The previous default options will be overwritten if opt.load_from_opt_file: parser = self.update_options_from_file(parser, opt) opt = parser.parse_args() self.parser = parser return opt def print_options(self, opt): message = "" message += "----------------- Options ---------------\n" for k, v in sorted(vars(opt).items()): comment = "" default = self.parser.get_default(k) if v != default: comment = "\t[default: %s]" % str(default) message += "{:>25}: {:<30}{}\n".format(str(k), str(v), comment) message += "----------------- End -------------------" # print(message) def option_file_path(self, opt, makedir=False): expr_dir = os.path.join(opt.checkpoints_dir, opt.name) if makedir: util.mkdirs(expr_dir) file_name = os.path.join(expr_dir, "opt") return file_name def save_options(self, opt): file_name = self.option_file_path(opt, makedir=True) with open(file_name + ".txt", "wt") as opt_file: for k, v in sorted(vars(opt).items()): comment = "" default = self.parser.get_default(k) if v != default: comment = "\t[default: %s]" % str(default) opt_file.write("{:>25}: {:<30}{}\n".format(str(k), str(v), comment)) with open(file_name + ".pkl", "wb") as opt_file: pickle.dump(opt, opt_file) def update_options_from_file(self, parser, opt): new_opt = self.load_options(opt) for k, v in sorted(vars(opt).items()): if hasattr(new_opt, k) and v != getattr(new_opt, k): new_val = getattr(new_opt, k) parser.set_defaults(**{k: new_val}) return parser def load_options(self, opt): file_name = self.option_file_path(opt, makedir=False) new_opt = pickle.load(open(file_name + ".pkl", "rb")) return new_opt def parse(self, save=False): opt = self.gather_options() opt.isTrain = self.isTrain # train or test opt.contain_dontcare_label = False self.print_options(opt) if opt.isTrain: self.save_options(opt) # Set semantic_nc based on the option. 
class TestOptions(BaseOptions):
    """Options for running the face-enhancement model at test time."""

    def initialize(self, parser):
        """Add the test-only arguments and test-time defaults to ``parser``.

        Also marks this options object as non-training (``isTrain = False``).
        """
        super().initialize(parser)

        parser.add_argument("--results_dir", type=str, default="./results/", help="saves results here.")
        parser.add_argument(
            "--which_epoch",
            type=str,
            default="latest",
            help="which epoch to load? set to latest to use latest cached model",
        )
        parser.add_argument("--how_many", type=int, default=float("inf"), help="how many test images to run")

        test_defaults = {
            "preprocess_mode": "scale_width_and_crop",
            "crop_size": 256,
            "load_size": 256,
            "display_winsize": 256,
            "serial_batches": True,
            "no_flip": True,
            "phase": "test",
        }
        parser.set_defaults(**test_defaults)

        self.isTrain = False
        return parser
import os
from collections import OrderedDict

import data
from options.test_options import TestOptions
from models.pix2pix_model import Pix2PixModel
from util.visualizer import Visualizer
import torchvision.utils as vutils
import warnings

# Face-enhancement inference script: runs the model over the test set and
# writes one image per input into <checkpoints_dir>/<name>/<results_dir>/each_img.

warnings.filterwarnings("ignore", category=UserWarning)

opt = TestOptions().parse()

dataloader = data.create_dataloader(opt)

model = Pix2PixModel(opt)
model.eval()

visualizer = Visualizer(opt)

single_save_url = os.path.join(opt.checkpoints_dir, opt.name, opt.results_dir, "each_img")
# exist_ok avoids the check-then-create race when several runs start together.
os.makedirs(single_save_url, exist_ok=True)

for i, data_i in enumerate(dataloader):
    # Stop once at least opt.how_many images have been processed.
    if i * opt.batchSize >= opt.how_many:
        break

    generated = model(data_i, mode="inference")

    img_path = data_i["path"]

    for b in range(generated.shape[0]):
        img_name = os.path.basename(img_path[b])
        save_img_url = os.path.join(single_save_url, img_name)

        # Map the generator output from [-1, 1] to [0, 1] before saving.
        vutils.save_image((generated[b] + 1) / 2, save_img_url)
import os import time import numpy as np # Helper class that keeps track of training iterations class IterationCounter: def __init__(self, opt, dataset_size): self.opt = opt self.dataset_size = dataset_size self.first_epoch = 1 self.total_epochs = opt.niter + opt.niter_decay self.epoch_iter = 0 # iter number within each epoch self.iter_record_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, "iter.txt") if opt.isTrain and opt.continue_train: try: self.first_epoch, self.epoch_iter = np.loadtxt( self.iter_record_path, delimiter=",", dtype=int ) print("Resuming from epoch %d at iteration %d" % (self.first_epoch, self.epoch_iter)) except: print( "Could not load iteration record at %s. Starting from beginning." % self.iter_record_path ) self.total_steps_so_far = (self.first_epoch - 1) * dataset_size + self.epoch_iter # return the iterator of epochs for the training def training_epochs(self): return range(self.first_epoch, self.total_epochs + 1) def record_epoch_start(self, epoch): self.epoch_start_time = time.time() self.epoch_iter = 0 self.last_iter_time = time.time() self.current_epoch = epoch def record_one_iteration(self): current_time = time.time() # the last remaining batch is dropped (see data/__init__.py), # so we can assume batch size is always opt.batchSize self.time_per_iter = (current_time - self.last_iter_time) / self.opt.batchSize self.last_iter_time = current_time self.total_steps_so_far += self.opt.batchSize self.epoch_iter += self.opt.batchSize def record_epoch_end(self): current_time = time.time() self.time_per_epoch = current_time - self.epoch_start_time print( "End of epoch %d / %d \t Time Taken: %d sec" % (self.current_epoch, self.total_epochs, self.time_per_epoch) ) if self.current_epoch % self.opt.save_epoch_freq == 0: np.savetxt(self.iter_record_path, (self.current_epoch + 1, 0), delimiter=",", fmt="%d") print("Saved current iteration count at %s." 
% self.iter_record_path) def record_current_iter(self): np.savetxt(self.iter_record_path, (self.current_epoch, self.epoch_iter), delimiter=",", fmt="%d") print("Saved current iteration count at %s." % self.iter_record_path) def needs_saving(self): return (self.total_steps_so_far % self.opt.save_latest_freq) < self.opt.batchSize def needs_printing(self): return (self.total_steps_so_far % self.opt.print_freq) < self.opt.batchSize def needs_displaying(self): return (self.total_steps_so_far % self.opt.display_freq) < self.opt.batchSize ================================================ FILE: Face_Enhancement/util/util.py ================================================ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. import re import importlib import torch from argparse import Namespace import numpy as np from PIL import Image import os import argparse import dill as pickle def save_obj(obj, name): with open(name, "wb") as f: pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) def load_obj(name): with open(name, "rb") as f: return pickle.load(f) def copyconf(default_opt, **kwargs): conf = argparse.Namespace(**vars(default_opt)) for key in kwargs: print(key, kwargs[key]) setattr(conf, key, kwargs[key]) return conf # Converts a Tensor into a Numpy array # |imtype|: the desired type of the converted numpy array def tensor2im(image_tensor, imtype=np.uint8, normalize=True, tile=False): if isinstance(image_tensor, list): image_numpy = [] for i in range(len(image_tensor)): image_numpy.append(tensor2im(image_tensor[i], imtype, normalize)) return image_numpy if image_tensor.dim() == 4: # transform each image in the batch images_np = [] for b in range(image_tensor.size(0)): one_image = image_tensor[b] one_image_np = tensor2im(one_image) images_np.append(one_image_np.reshape(1, *one_image_np.shape)) images_np = np.concatenate(images_np, axis=0) return images_np if image_tensor.dim() == 2: image_tensor = image_tensor.unsqueeze(0) image_numpy = 
image_tensor.detach().cpu().float().numpy() if normalize: image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 else: image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 image_numpy = np.clip(image_numpy, 0, 255) if image_numpy.shape[2] == 1: image_numpy = image_numpy[:, :, 0] return image_numpy.astype(imtype) # Converts a one-hot tensor into a colorful label map def tensor2label(label_tensor, n_label, imtype=np.uint8, tile=False): if label_tensor.dim() == 4: # transform each image in the batch images_np = [] for b in range(label_tensor.size(0)): one_image = label_tensor[b] one_image_np = tensor2label(one_image, n_label, imtype) images_np.append(one_image_np.reshape(1, *one_image_np.shape)) images_np = np.concatenate(images_np, axis=0) # if tile: # images_tiled = tile_images(images_np) # return images_tiled # else: # images_np = images_np[0] # return images_np return images_np if label_tensor.dim() == 1: return np.zeros((64, 64, 3), dtype=np.uint8) if n_label == 0: return tensor2im(label_tensor, imtype) label_tensor = label_tensor.cpu().float() if label_tensor.size()[0] > 1: label_tensor = label_tensor.max(0, keepdim=True)[1] label_tensor = Colorize(n_label)(label_tensor) label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0)) result = label_numpy.astype(imtype) return result def save_image(image_numpy, image_path, create_dir=False): if create_dir: os.makedirs(os.path.dirname(image_path), exist_ok=True) if len(image_numpy.shape) == 2: image_numpy = np.expand_dims(image_numpy, axis=2) if image_numpy.shape[2] == 1: image_numpy = np.repeat(image_numpy, 3, 2) image_pil = Image.fromarray(image_numpy) # save to png image_pil.save(image_path.replace(".jpg", ".png")) def mkdirs(paths): if isinstance(paths, list) and not isinstance(paths, str): for path in paths: mkdir(path) else: mkdir(paths) def mkdir(path): if not os.path.exists(path): os.makedirs(path) def atoi(text): return int(text) if text.isdigit() else text def natural_keys(text): """ 
alist.sort(key=natural_keys) sorts in human order http://nedbatchelder.com/blog/200712/human_sorting.html (See Toothy's implementation in the comments) """ return [atoi(c) for c in re.split("(\d+)", text)] def natural_sort(items): items.sort(key=natural_keys) def str2bool(v): if v.lower() in ("yes", "true", "t", "y", "1"): return True elif v.lower() in ("no", "false", "f", "n", "0"): return False else: raise argparse.ArgumentTypeError("Boolean value expected.") def find_class_in_module(target_cls_name, module): target_cls_name = target_cls_name.replace("_", "").lower() clslib = importlib.import_module(module) cls = None for name, clsobj in clslib.__dict__.items(): if name.lower() == target_cls_name: cls = clsobj if cls is None: print( "In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name) ) exit(0) return cls def save_network(net, label, epoch, opt): save_filename = "%s_net_%s.pth" % (epoch, label) save_path = os.path.join(opt.checkpoints_dir, opt.name, save_filename) torch.save(net.cpu().state_dict(), save_path) if len(opt.gpu_ids) and torch.cuda.is_available(): net.cuda() def load_network(net, label, epoch, opt): save_filename = "%s_net_%s.pth" % (epoch, label) save_dir = os.path.join(opt.checkpoints_dir, opt.name) save_path = os.path.join(save_dir, save_filename) if os.path.exists(save_path): weights = torch.load(save_path) net.load_state_dict(weights) return net ############################################################################### # Code from # https://github.com/ycszen/pytorch-seg/blob/master/transform.py # Modified so it complies with the Citscape label map colors ############################################################################### def uint82bin(n, count=8): """returns the binary of integer n, count refers to amount of bits""" return "".join([str((n >> y) & 1) for y in range(count - 1, -1, -1)]) class Colorize(object): def __init__(self, n=35): self.cmap = labelcolormap(n) 
self.cmap = torch.from_numpy(self.cmap[:n]) def __call__(self, gray_image): size = gray_image.size() color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0) for label in range(0, len(self.cmap)): mask = (label == gray_image[0]).cpu() color_image[0][mask] = self.cmap[label][0] color_image[1][mask] = self.cmap[label][1] color_image[2][mask] = self.cmap[label][2] return color_image ================================================ FILE: Face_Enhancement/util/visualizer.py ================================================ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. import os import ntpath import time from . import util import scipy.misc try: from StringIO import StringIO # Python 2.7 except ImportError: from io import BytesIO # Python 3.x import torchvision.utils as vutils from tensorboardX import SummaryWriter import torch import numpy as np class Visualizer: def __init__(self, opt): self.opt = opt self.tf_log = opt.isTrain and opt.tf_log self.tensorboard_log = opt.tensorboard_log self.win_size = opt.display_winsize self.name = opt.name if self.tensorboard_log: if self.opt.isTrain: self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, "logs") if not os.path.exists(self.log_dir): os.makedirs(self.log_dir) self.writer = SummaryWriter(log_dir=self.log_dir) else: print("hi :)") self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, opt.results_dir) if not os.path.exists(self.log_dir): os.makedirs(self.log_dir) if opt.isTrain: self.log_name = os.path.join(opt.checkpoints_dir, opt.name, "loss_log.txt") with open(self.log_name, "a") as log_file: now = time.strftime("%c") log_file.write("================ Training Loss (%s) ================\n" % now) # |visuals|: dictionary of images to display or save def display_current_results(self, visuals, epoch, step): all_tensor = [] if self.tensorboard_log: for key, tensor in visuals.items(): all_tensor.append((tensor.data.cpu() + 1) / 2) output = torch.cat(all_tensor, 0) img_grid = 
vutils.make_grid(output, nrow=self.opt.batchSize, padding=0, normalize=False) if self.opt.isTrain: self.writer.add_image("Face_SPADE/training_samples", img_grid, step) else: vutils.save_image( output, os.path.join(self.log_dir, str(step) + ".png"), nrow=self.opt.batchSize, padding=0, normalize=False, ) # errors: dictionary of error labels and values def plot_current_errors(self, errors, step): if self.tf_log: for tag, value in errors.items(): value = value.mean().float() summary = self.tf.Summary(value=[self.tf.Summary.Value(tag=tag, simple_value=value)]) self.writer.add_summary(summary, step) if self.tensorboard_log: self.writer.add_scalar("Loss/GAN_Feat", errors["GAN_Feat"].mean().float(), step) self.writer.add_scalar("Loss/VGG", errors["VGG"].mean().float(), step) self.writer.add_scalars( "Loss/GAN", { "G": errors["GAN"].mean().float(), "D": (errors["D_Fake"].mean().float() + errors["D_real"].mean().float()) / 2, }, step, ) # errors: same format as |errors| of plotCurrentErrors def print_current_errors(self, epoch, i, errors, t): message = "(epoch: %d, iters: %d, time: %.3f) " % (epoch, i, t) for k, v in errors.items(): v = v.mean().float() message += "%s: %.3f " % (k, v) print(message) with open(self.log_name, "a") as log_file: log_file.write("%s\n" % message) def convert_visuals_to_numpy(self, visuals): for key, t in visuals.items(): tile = self.opt.batchSize > 8 if "input_label" == key: t = util.tensor2label(t, self.opt.label_nc + 2, tile=tile) ## B*H*W*C 0-255 numpy else: t = util.tensor2im(t, tile=tile) visuals[key] = t return visuals # save image to the disk def save_images(self, webpage, visuals, image_path): visuals = self.convert_visuals_to_numpy(visuals) image_dir = webpage.get_image_dir() short_path = ntpath.basename(image_path[0]) name = os.path.splitext(short_path)[0] webpage.add_header(name) ims = [] txts = [] links = [] for label, image_numpy in visuals.items(): image_name = os.path.join(label, "%s.png" % (name)) save_path = 
os.path.join(image_dir, image_name) util.save_image(image_numpy, save_path, create_dir=True) ims.append(image_name) txts.append(label) links.append(image_name) webpage.add_images(ims, txts, links, width=self.win_size) ================================================ FILE: GUI.py ================================================ import numpy as np import cv2 import PySimpleGUI as sg import os.path import argparse import os import sys import shutil from subprocess import call def modify(image_filename=None, cv2_frame=None): def run_cmd(command): try: call(command, shell=True) except KeyboardInterrupt: print("Process interrupted") sys.exit(1) parser = argparse.ArgumentParser() parser.add_argument("--input_folder", type=str, default= image_filename, help="Test images") parser.add_argument( "--output_folder", type=str, default="./output", help="Restored images, please use the absolute path", ) parser.add_argument("--GPU", type=str, default="-1", help="0,1,2") parser.add_argument( "--checkpoint_name", type=str, default="Setting_9_epoch_100", help="choose which checkpoint" ) parser.add_argument("--with_scratch",default="--with_scratch" ,action="store_true") opts = parser.parse_args() gpu1 = opts.GPU # resolve relative paths before changing directory opts.input_folder = os.path.abspath(opts.input_folder) opts.output_folder = os.path.abspath(opts.output_folder) if not os.path.exists(opts.output_folder): os.makedirs(opts.output_folder) main_environment = os.getcwd() # Stage 1: Overall Quality Improve print("Running Stage 1: Overall restoration") os.chdir("./Global") stage_1_input_dir = opts.input_folder stage_1_output_dir = os.path.join( opts.output_folder, "stage_1_restore_output") if not os.path.exists(stage_1_output_dir): os.makedirs(stage_1_output_dir) if not opts.with_scratch: stage_1_command = ( "python test.py --test_mode Full --Quality_restore --test_input " + stage_1_input_dir + " --outputs_dir " + stage_1_output_dir + " --gpu_ids " + gpu1 ) run_cmd(stage_1_command) 
else: mask_dir = os.path.join(stage_1_output_dir, "masks") new_input = os.path.join(mask_dir, "input") new_mask = os.path.join(mask_dir, "mask") stage_1_command_1 = ( "python detection.py --test_path " + stage_1_input_dir + " --output_dir " + mask_dir + " --input_size full_size" + " --GPU " + gpu1 ) stage_1_command_2 = ( "python test.py --Scratch_and_Quality_restore --test_input " + new_input + " --test_mask " + new_mask + " --outputs_dir " + stage_1_output_dir + " --gpu_ids " + gpu1 ) run_cmd(stage_1_command_1) run_cmd(stage_1_command_2) # Solve the case when there is no face in the old photo stage_1_results = os.path.join(stage_1_output_dir, "restored_image") stage_4_output_dir = os.path.join(opts.output_folder, "final_output") if not os.path.exists(stage_4_output_dir): os.makedirs(stage_4_output_dir) for x in os.listdir(stage_1_results): img_dir = os.path.join(stage_1_results, x) shutil.copy(img_dir, stage_4_output_dir) print("Finish Stage 1 ...") print("\n") # Stage 2: Face Detection print("Running Stage 2: Face Detection") os.chdir(".././Face_Detection") stage_2_input_dir = os.path.join(stage_1_output_dir, "restored_image") stage_2_output_dir = os.path.join( opts.output_folder, "stage_2_detection_output") if not os.path.exists(stage_2_output_dir): os.makedirs(stage_2_output_dir) stage_2_command = ( "python detect_all_dlib.py --url " + stage_2_input_dir + " --save_url " + stage_2_output_dir ) run_cmd(stage_2_command) print("Finish Stage 2 ...") print("\n") # Stage 3: Face Restore print("Running Stage 3: Face Enhancement") os.chdir(".././Face_Enhancement") stage_3_input_mask = "./" stage_3_input_face = stage_2_output_dir stage_3_output_dir = os.path.join( opts.output_folder, "stage_3_face_output") if not os.path.exists(stage_3_output_dir): os.makedirs(stage_3_output_dir) stage_3_command = ( "python test_face.py --old_face_folder " + stage_3_input_face + " --old_face_label_folder " + stage_3_input_mask + " --tensorboard_log --name " + opts.checkpoint_name + " 
--gpu_ids " + gpu1 + " --load_size 256 --label_nc 18 --no_instance --preprocess_mode resize --batchSize 4 --results_dir " + stage_3_output_dir + " --no_parsing_map" ) run_cmd(stage_3_command) print("Finish Stage 3 ...") print("\n") # Stage 4: Warp back print("Running Stage 4: Blending") os.chdir(".././Face_Detection") stage_4_input_image_dir = os.path.join( stage_1_output_dir, "restored_image") stage_4_input_face_dir = os.path.join(stage_3_output_dir, "each_img") stage_4_output_dir = os.path.join(opts.output_folder, "final_output") if not os.path.exists(stage_4_output_dir): os.makedirs(stage_4_output_dir) stage_4_command = ( "python align_warp_back_multiple_dlib.py --origin_url " + stage_4_input_image_dir + " --replace_url " + stage_4_input_face_dir + " --save_url " + stage_4_output_dir ) run_cmd(stage_4_command) print("Finish Stage 4 ...") print("\n") print("All the processing is done. Please check the results.") # --------------------------------- The GUI --------------------------------- # First the window layout... 
images_col = [[sg.Text('Input file:'), sg.In(enable_events=True, key='-IN FILE-'), sg.FileBrowse()], [sg.Button('Modify Photo', key='-MPHOTO-'), sg.Button('Exit')], [sg.Image(filename='', key='-IN-'), sg.Image(filename='', key='-OUT-')],] # ----- Full layout ----- layout = [[sg.VSeperator(), sg.Column(images_col)]] # ----- Make the window ----- window = sg.Window('Bringing-old-photos-back-to-life', layout, grab_anywhere=True) # ----- Run the Event Loop ----- prev_filename = colorized = cap = None while True: event, values = window.read() if event in (None, 'Exit'): break elif event == '-MPHOTO-': try: n1 = filename.split("/")[-2] n2 = filename.split("/")[-3] n3 = filename.split("/")[-1] filename= str(f"./{n2}/{n1}") modify(filename) global f_image f_image = f'./output/final_output/{n3}' image = cv2.imread(f_image) window['-OUT-'].update(data=cv2.imencode('.png', image)[1].tobytes()) except: continue elif event == '-IN FILE-': # A single filename was chosen filename = values['-IN FILE-'] if filename != prev_filename: prev_filename = filename try: image = cv2.imread(filename) window['-IN-'].update(data=cv2.imencode('.png', image)[1].tobytes()) except: continue # ----- Exit program ----- window.close() ================================================ FILE: Global/data/Create_Bigfile.py ================================================ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
import os import struct from PIL import Image IMG_EXTENSIONS = [ '.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', ] def is_image_file(filename): return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) def make_dataset(dir): images = [] assert os.path.isdir(dir), '%s is not a valid directory' % dir for root, _, fnames in sorted(os.walk(dir)): for fname in fnames: if is_image_file(fname): #print(fname) path = os.path.join(root, fname) images.append(path) return images ### Modify these 3 lines in your own environment indir="/home/ziyuwan/workspace/data/temp_old" target_folders=['VOC','Real_L_old','Real_RGB_old'] out_dir ="/home/ziyuwan/workspace/data/temp_old" ### if os.path.exists(out_dir) is False: os.makedirs(out_dir) # for target_folder in target_folders: curr_indir = os.path.join(indir, target_folder) curr_out_file = os.path.join(os.path.join(out_dir, '%s.bigfile'%(target_folder))) image_lists = make_dataset(curr_indir) image_lists.sort() with open(curr_out_file, 'wb') as wfid: # write total image number wfid.write(struct.pack('i', len(image_lists))) for i, img_path in enumerate(image_lists): # write file name first img_name = os.path.basename(img_path) img_name_bytes = img_name.encode('utf-8') wfid.write(struct.pack('i', len(img_name_bytes))) wfid.write(img_name_bytes) # # # write image data in with open(img_path, 'rb') as img_fid: img_bytes = img_fid.read() wfid.write(struct.pack('i', len(img_bytes))) wfid.write(img_bytes) if i % 1000 == 0: print('write %d images done' % i) ================================================ FILE: Global/data/Load_Bigfile.py ================================================ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
import io import os import struct from PIL import Image class BigFileMemoryLoader(object): def __load_bigfile(self): print('start load bigfile (%0.02f GB) into memory' % (os.path.getsize(self.file_path)/1024/1024/1024)) with open(self.file_path, 'rb') as fid: self.img_num = struct.unpack('i', fid.read(4))[0] self.img_names = [] self.img_bytes = [] print('find total %d images' % self.img_num) for i in range(self.img_num): img_name_len = struct.unpack('i', fid.read(4))[0] img_name = fid.read(img_name_len).decode('utf-8') self.img_names.append(img_name) img_bytes_len = struct.unpack('i', fid.read(4))[0] self.img_bytes.append(fid.read(img_bytes_len)) if i % 5000 == 0: print('load %d images done' % i) print('load all %d images done' % self.img_num) def __init__(self, file_path): super(BigFileMemoryLoader, self).__init__() self.file_path = file_path self.__load_bigfile() def __getitem__(self, index): try: img = Image.open(io.BytesIO(self.img_bytes[index])).convert('RGB') return self.img_names[index], img except Exception: print('Image read error for index %d: %s' % (index, self.img_names[index])) return self.__getitem__((index+1)%self.img_num) def __len__(self): return self.img_num ================================================ FILE: Global/data/__init__.py ================================================ ================================================ FILE: Global/data/base_data_loader.py ================================================ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. class BaseDataLoader(): def __init__(self): pass def initialize(self, opt): self.opt = opt pass def load_data(): return None ================================================ FILE: Global/data/base_dataset.py ================================================ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
import torch.utils.data as data from PIL import Image import torchvision.transforms as transforms import numpy as np import random class BaseDataset(data.Dataset): def __init__(self): super(BaseDataset, self).__init__() def name(self): return 'BaseDataset' def initialize(self, opt): pass def get_params(opt, size): w, h = size new_h = h new_w = w if opt.resize_or_crop == 'resize_and_crop': new_h = new_w = opt.loadSize if opt.resize_or_crop == 'scale_width_and_crop': # we scale the shorter side into 256 if w 0.5 return {'crop_pos': (x, y), 'flip': flip} def get_transform(opt, params, method=Image.BICUBIC, normalize=True): transform_list = [] if 'resize' in opt.resize_or_crop: osize = [opt.loadSize, opt.loadSize] transform_list.append(transforms.Scale(osize, method)) elif 'scale_width' in opt.resize_or_crop: # transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method))) ## Here , We want the shorter side to match 256, and Scale will finish it. transform_list.append(transforms.Scale(256,method)) if 'crop' in opt.resize_or_crop: if opt.isTrain: transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize))) else: if opt.test_random_crop: transform_list.append(transforms.RandomCrop(opt.fineSize)) else: transform_list.append(transforms.CenterCrop(opt.fineSize)) ## when testing, for ablation study, choose center_crop directly. 
if opt.resize_or_crop == 'none': base = float(2 ** opt.n_downsample_global) if opt.netG == 'local': base *= (2 ** opt.n_local_enhancers) transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method))) if opt.isTrain and not opt.no_flip: transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) transform_list += [transforms.ToTensor()] if normalize: transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] return transforms.Compose(transform_list) def normalize(): return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) def __make_power_2(img, base, method=Image.BICUBIC): ow, oh = img.size h = int(round(oh / base) * base) w = int(round(ow / base) * base) if (h == oh) and (w == ow): return img return img.resize((w, h), method) def __scale_width(img, target_width, method=Image.BICUBIC): ow, oh = img.size if (ow == target_width): return img w = target_width h = int(target_width * oh / ow) return img.resize((w, h), method) def __crop(img, pos, size): ow, oh = img.size x1, y1 = pos tw = th = size if (ow > tw or oh > th): return img.crop((x1, y1, x1 + tw, y1 + th)) return img def __flip(img, flip): if flip: return img.transpose(Image.FLIP_LEFT_RIGHT) return img ================================================ FILE: Global/data/custom_dataset_data_loader.py ================================================ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
import torch.utils.data import random from data.base_data_loader import BaseDataLoader from data import online_dataset_for_old_photos as dts_ray_bigfile def CreateDataset(opt): dataset = None if opt.training_dataset=='domain_A' or opt.training_dataset=='domain_B': dataset = dts_ray_bigfile.UnPairOldPhotos_SR() if opt.training_dataset=='mapping': if opt.random_hole: dataset = dts_ray_bigfile.PairOldPhotos_with_hole() else: dataset = dts_ray_bigfile.PairOldPhotos() print("dataset [%s] was created" % (dataset.name())) dataset.initialize(opt) return dataset class CustomDatasetDataLoader(BaseDataLoader): def name(self): return 'CustomDatasetDataLoader' def initialize(self, opt): BaseDataLoader.initialize(self, opt) self.dataset = CreateDataset(opt) self.dataloader = torch.utils.data.DataLoader( self.dataset, batch_size=opt.batchSize, shuffle=not opt.serial_batches, num_workers=int(opt.nThreads), drop_last=True) def load_data(self): return self.dataloader def __len__(self): return min(len(self.dataset), self.opt.max_dataset_size) ================================================ FILE: Global/data/data_loader.py ================================================ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. def CreateDataLoader(opt): from data.custom_dataset_data_loader import CustomDatasetDataLoader data_loader = CustomDatasetDataLoader() print(data_loader.name()) data_loader.initialize(opt) return data_loader ================================================ FILE: Global/data/image_folder.py ================================================ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
import torch.utils.data as data from PIL import Image import os IMG_EXTENSIONS = [ '.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff' ] def is_image_file(filename): return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) def make_dataset(dir): images = [] assert os.path.isdir(dir), '%s is not a valid directory' % dir for root, _, fnames in sorted(os.walk(dir)): for fname in fnames: if is_image_file(fname): path = os.path.join(root, fname) images.append(path) return images def default_loader(path): return Image.open(path).convert('RGB') class ImageFolder(data.Dataset): def __init__(self, root, transform=None, return_paths=False, loader=default_loader): imgs = make_dataset(root) if len(imgs) == 0: raise(RuntimeError("Found 0 images in: " + root + "\n" "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) self.root = root self.imgs = imgs self.transform = transform self.return_paths = return_paths self.loader = loader def __getitem__(self, index): path = self.imgs[index] img = self.loader(path) if self.transform is not None: img = self.transform(img) if self.return_paths: return img, path else: return img def __len__(self): return len(self.imgs) ================================================ FILE: Global/data/online_dataset_for_old_photos.py ================================================ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. import os.path import io import zipfile from data.base_dataset import BaseDataset, get_params, get_transform, normalize from data.image_folder import make_dataset from PIL import Image import torchvision.transforms as transforms import numpy as np from data.Load_Bigfile import BigFileMemoryLoader import random import cv2 from io import BytesIO def pil_to_np(img_PIL): '''Converts image in PIL format to np.array. 
From W x H x C [0...255] to C x W x H [0..1] ''' ar = np.array(img_PIL) if len(ar.shape) == 3: ar = ar.transpose(2, 0, 1) else: ar = ar[None, ...] return ar.astype(np.float32) / 255. def np_to_pil(img_np): '''Converts image in np.array format to PIL image. From C x W x H [0..1] to W x H x C [0...255] ''' ar = np.clip(img_np * 255, 0, 255).astype(np.uint8) if img_np.shape[0] == 1: ar = ar[0] else: ar = ar.transpose(1, 2, 0) return Image.fromarray(ar) def synthesize_salt_pepper(image,amount,salt_vs_pepper): ## Give PIL, return the noisy PIL img_pil=pil_to_np(image) out = img_pil.copy() p = amount q = salt_vs_pepper flipped = np.random.choice([True, False], size=img_pil.shape, p=[p, 1 - p]) salted = np.random.choice([True, False], size=img_pil.shape, p=[q, 1 - q]) peppered = ~salted out[flipped & salted] = 1 out[flipped & peppered] = 0. noisy = np.clip(out, 0, 1).astype(np.float32) return np_to_pil(noisy) def synthesize_gaussian(image,std_l,std_r): ## Give PIL, return the noisy PIL img_pil=pil_to_np(image) mean=0 std=random.uniform(std_l/255.,std_r/255.) gauss=np.random.normal(loc=mean,scale=std,size=img_pil.shape) noisy=img_pil+gauss noisy=np.clip(noisy,0,1).astype(np.float32) return np_to_pil(noisy) def synthesize_speckle(image,std_l,std_r): ## Give PIL, return the noisy PIL img_pil=pil_to_np(image) mean=0 std=random.uniform(std_l/255.,std_r/255.) 
gauss=np.random.normal(loc=mean,scale=std,size=img_pil.shape) noisy=img_pil+gauss*img_pil noisy=np.clip(noisy,0,1).astype(np.float32) return np_to_pil(noisy) def synthesize_low_resolution(img): w,h=img.size new_w=random.randint(int(w/2),w) new_h=random.randint(int(h/2),h) img=img.resize((new_w,new_h),Image.BICUBIC) if random.uniform(0,1)<0.5: img=img.resize((w,h),Image.NEAREST) else: img = img.resize((w, h), Image.BILINEAR) return img def convertToJpeg(im,quality): with BytesIO() as f: im.save(f, format='JPEG',quality=quality) f.seek(0) return Image.open(f).convert('RGB') def blur_image_v2(img): x=np.array(img) kernel_size_candidate=[(3,3),(5,5),(7,7)] kernel_size=random.sample(kernel_size_candidate,1)[0] std=random.uniform(1.,5.) #print("The gaussian kernel size: (%d,%d) std: %.2f"%(kernel_size[0],kernel_size[1],std)) blur=cv2.GaussianBlur(x,kernel_size,std) return Image.fromarray(blur.astype(np.uint8)) def online_add_degradation_v2(img): task_id=np.random.permutation(4) for x in task_id: if x==0 and random.uniform(0,1)<0.7: img = blur_image_v2(img) if x==1 and random.uniform(0,1)<0.7: flag = random.choice([1, 2, 3]) if flag == 1: img = synthesize_gaussian(img, 5, 50) if flag == 2: img = synthesize_speckle(img, 5, 50) if flag == 3: img = synthesize_salt_pepper(img, random.uniform(0, 0.01), random.uniform(0.3, 0.8)) if x==2 and random.uniform(0,1)<0.7: img=synthesize_low_resolution(img) if x==3 and random.uniform(0,1)<0.7: img=convertToJpeg(img,random.randint(40,100)) return img def irregular_hole_synthesize(img,mask): img_np=np.array(img).astype('uint8') mask_np=np.array(mask).astype('uint8') mask_np=mask_np/255 img_new=img_np*(1-mask_np)+mask_np*255 hole_img=Image.fromarray(img_new.astype('uint8')).convert("RGB") return hole_img,mask.convert("L") def zero_mask(size): x=np.zeros((size,size,3)).astype('uint8') mask=Image.fromarray(x).convert("RGB") return mask class UnPairOldPhotos_SR(BaseDataset): ## Synthetic + Real Old def initialize(self, opt): self.opt = opt 
self.isImage = 'domainA' in opt.name self.task = 'old_photo_restoration_training_vae' self.dir_AB = opt.dataroot if self.isImage: self.load_img_dir_L_old=os.path.join(self.dir_AB,"Real_L_old.bigfile") self.load_img_dir_RGB_old=os.path.join(self.dir_AB,"Real_RGB_old.bigfile") self.load_img_dir_clean=os.path.join(self.dir_AB,"VOC_RGB_JPEGImages.bigfile") self.loaded_imgs_L_old=BigFileMemoryLoader(self.load_img_dir_L_old) self.loaded_imgs_RGB_old=BigFileMemoryLoader(self.load_img_dir_RGB_old) self.loaded_imgs_clean=BigFileMemoryLoader(self.load_img_dir_clean) else: # self.load_img_dir_clean=os.path.join(self.dir_AB,self.opt.test_dataset) self.load_img_dir_clean=os.path.join(self.dir_AB,"VOC_RGB_JPEGImages.bigfile") self.loaded_imgs_clean=BigFileMemoryLoader(self.load_img_dir_clean) #### print("-------------Filter the imgs whose size <256 in VOC-------------") self.filtered_imgs_clean=[] for i in range(len(self.loaded_imgs_clean)): img_name,img=self.loaded_imgs_clean[i] h,w=img.size if h<256 or w<256: continue self.filtered_imgs_clean.append((img_name,img)) print("--------Origin image num is [%d], filtered result is [%d]--------" % ( len(self.loaded_imgs_clean), len(self.filtered_imgs_clean))) ## Filter these images whose size is less than 256 # self.img_list=os.listdir(load_img_dir) self.pid = os.getpid() def __getitem__(self, index): is_real_old=0 sampled_dataset=None degradation=None if self.isImage: ## domain A , contains 2 kinds of data: synthetic + real_old P=random.uniform(0,2) if P>=0 and P<1: if random.uniform(0,1)<0.5: sampled_dataset=self.loaded_imgs_L_old self.load_img_dir=self.load_img_dir_L_old else: sampled_dataset=self.loaded_imgs_RGB_old self.load_img_dir=self.load_img_dir_RGB_old is_real_old=1 if P>=1 and P<2: sampled_dataset=self.filtered_imgs_clean self.load_img_dir=self.load_img_dir_clean degradation=1 else: sampled_dataset=self.filtered_imgs_clean self.load_img_dir=self.load_img_dir_clean sampled_dataset_len=len(sampled_dataset) 
class PairOldPhotos(BaseDataset):
    """Paired dataset: a degraded input (A) and its clean ground truth (B).

    In training mode the clean images come from ``VOC_RGB_JPEGImages.bigfile``
    (images with a side below 256 are dropped) and A is synthesized from B.
    In test mode images are read from ``opt.test_dataset``.
    """

    def initialize(self, opt):
        """Load the image store selected by ``opt.isTrain``."""
        self.opt = opt
        self.isImage = 'imagegan' in opt.name
        self.task = 'old_photo_restoration_training_mapping'
        self.dir_AB = opt.dataroot
        if opt.isTrain:
            self.load_img_dir_clean = os.path.join(self.dir_AB, "VOC_RGB_JPEGImages.bigfile")
            self.loaded_imgs_clean = BigFileMemoryLoader(self.load_img_dir_clean)

            print("-------------Filter the imgs whose size <256 in VOC-------------")
            self.filtered_imgs_clean = []
            for i in range(len(self.loaded_imgs_clean)):
                img_name, img = self.loaded_imgs_clean[i]
                # PIL reports size as (width, height).
                w, h = img.size
                if w < 256 or h < 256:
                    continue
                self.filtered_imgs_clean.append((img_name, img))

            print("--------Origin image num is [%d], filtered result is [%d]--------" % (
                len(self.loaded_imgs_clean), len(self.filtered_imgs_clean)))
        else:
            self.load_img_dir = os.path.join(self.dir_AB, opt.test_dataset)
            self.loaded_imgs = BigFileMemoryLoader(self.load_img_dir)

        self.pid = os.getpid()

    def __getitem__(self, index):
        """Return a dict with 'label' (input A), 'image' (target B), 'inst', 'feat' and 'path'.

        Raises:
            NotImplementedError: in training mode when ``opt.use_v2_degradation``
                is off, because no other degradation is implemented here.
        """
        if self.opt.isTrain:
            img_name_clean, B = self.filtered_imgs_clean[index]
            path = os.path.join(self.load_img_dir_clean, img_name_clean)
            if not self.opt.use_v2_degradation:
                # Previously A stayed unassigned here and the method failed
                # later with an UnboundLocalError.
                raise NotImplementedError(
                    "PairOldPhotos only implements the v2 degradation for training; "
                    "enable use_v2_degradation."
                )
            # A is the degraded input, B the corresponding ground truth.
            A = online_add_degradation_v2(B)
        else:
            if self.opt.test_on_synthetic:
                img_name_B, B = self.loaded_imgs[index]
                A = online_add_degradation_v2(B)
                img_name_A = img_name_B
                path = os.path.join(self.load_img_dir, img_name_A)
            else:
                img_name_A, A = self.loaded_imgs[index]
                img_name_B, B = self.loaded_imgs[index]
                path = os.path.join(self.load_img_dir, img_name_A)

        # With probability 0.1 (training only) drop colour from both images.
        if random.uniform(0, 1) < 0.1 and self.opt.isTrain:
            A = A.convert("L")
            B = B.convert("L")
            A = A.convert("RGB")
            B = B.convert("RGB")

        w, h = A.size
        if w < 256 or h < 256:
            # transforms.Scale is only a deprecated alias of Resize.
            upscale = transforms.Resize(256, Image.BICUBIC)
            A = upscale(A)
            B = upscale(B)

        # Apply the same transform parameters to both A and B.
        transform_params = get_params(self.opt, A.size)
        A_transform = get_transform(self.opt, transform_params)
        B_transform = get_transform(self.opt, transform_params)

        inst_tensor = feat_tensor = 0
        A_tensor = A_transform(A)
        B_tensor = B_transform(B)

        input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor,
                      'feat': feat_tensor, 'path': path}
        return input_dict

    def __len__(self):
        """Number of usable images for the current mode."""
        if self.opt.isTrain:
            return len(self.filtered_imgs_clean)
        return len(self.loaded_imgs)

    def name(self):
        """Return the dataset identifier."""
        return 'PairOldPhotos'
BigFileMemoryLoader(self.load_img_dir_clean) print("-------------Filter the imgs whose size <256 in VOC-------------") self.filtered_imgs_clean = [] for i in range(len(self.loaded_imgs_clean)): img_name, img = self.loaded_imgs_clean[i] h, w = img.size if h < 256 or w < 256: continue self.filtered_imgs_clean.append((img_name, img)) print("--------Origin image num is [%d], filtered result is [%d]--------" % ( len(self.loaded_imgs_clean), len(self.filtered_imgs_clean))) else: self.load_img_dir=os.path.join(self.dir_AB,opt.test_dataset) self.loaded_imgs=BigFileMemoryLoader(self.load_img_dir) self.loaded_masks = BigFileMemoryLoader(opt.irregular_mask) self.pid = os.getpid() def __getitem__(self, index): if self.opt.isTrain: img_name_clean,B = self.filtered_imgs_clean[index] path = os.path.join(self.load_img_dir_clean, img_name_clean) B=transforms.RandomCrop(256)(B) A=online_add_degradation_v2(B) ### Remind: A is the input and B is corresponding GT else: img_name_A,A=self.loaded_imgs[index] img_name_B,B=self.loaded_imgs[index] path = os.path.join(self.load_img_dir, img_name_A) #A=A.resize((256,256)) A=transforms.CenterCrop(256)(A) B=A if random.uniform(0,1)<0.1 and self.opt.isTrain: A=A.convert("L") B=B.convert("L") A=A.convert("RGB") B=B.convert("RGB") ## In P, we convert the RGB into L if self.opt.isTrain: mask_name,mask=self.loaded_masks[random.randint(0,len(self.loaded_masks)-1)] else: mask_name, mask = self.loaded_masks[index%100] mask = mask.resize((self.opt.loadSize, self.opt.loadSize), Image.NEAREST) if self.opt.random_hole and random.uniform(0,1)>0.5 and self.opt.isTrain: mask=zero_mask(256) if self.opt.no_hole: mask=zero_mask(256) A,_=irregular_hole_synthesize(A,mask) if not self.opt.isTrain and self.opt.hole_image_no_mask: mask=zero_mask(256) transform_params = get_params(self.opt, A.size) A_transform = get_transform(self.opt, transform_params) B_transform = get_transform(self.opt, transform_params) if transform_params['flip'] and self.opt.isTrain: 
def scale_tensor(img_tensor, default_scale=256):
    """Bilinearly resize a 4-D tensor so its smaller trailing side is ``default_scale``.

    The other trailing side is scaled by the same ratio, and both target sizes
    are then rounded to the nearest multiple of 16.

    Args:
        img_tensor: tensor with exactly four dimensions; only the last two are resized.
        default_scale: target length for the smaller of the last two dimensions.

    Returns:
        The resized tensor produced by ``F.interpolate``.
    """
    _, _, side_a, side_b = img_tensor.shape
    if side_a < side_b:
        target_a, target_b = default_scale, side_b / side_a * default_scale
    else:
        target_a, target_b = side_a / side_b * default_scale, default_scale

    # Snap both sides to a multiple of 16.
    target_size = [int(round(side / 16) * 16) for side in (target_a, target_b)]
    return F.interpolate(img_tensor, target_size, mode="bilinear")
def main(config):
    """Run scratch detection on every file in ``config.test_path``.

    For each image the resized input is saved under ``<output_dir>/input`` and
    the thresholded (>= 0.4) prediction under ``<output_dir>/mask``, both as PNG
    files named after the source file without its extension.

    Args:
        config: namespace with ``GPU`` (device index, negative for CPU),
            ``test_path``, ``output_dir`` and ``input_size``.

    Raises:
        ValueError: if ``config.input_size`` is not handled by ``data_transforms``.
    """
    print("initializing the dataloader")

    model = networks.UNet(
        in_channels=1,
        out_channels=1,
        depth=4,
        conv_num=2,
        wf=6,
        padding=True,
        batch_norm=True,
        up_mode="upsample",
        with_tanh=False,
        sync_bn=True,
        antialiasing=True,
    )

    ## load model
    checkpoint_path = os.path.join(os.path.dirname(__file__), "checkpoints/detection/FT_Epoch_latest.pt")
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state"])
    print("model weights loaded")

    use_gpu = config.GPU >= 0
    if use_gpu:
        model.to(config.GPU)
    else:
        model.cpu()
    model.eval()

    ## dataloader and transformation
    print("directory of testing image: " + config.test_path)
    imagelist = sorted(os.listdir(config.test_path))

    save_url = config.output_dir
    mkdir_if_not(save_url)

    input_dir = os.path.join(save_url, "input")
    output_dir = os.path.join(save_url, "mask")
    mkdir_if_not(input_dir)
    mkdir_if_not(output_dir)

    for image_name in imagelist:
        print("processing", image_name)

        scratch_file = os.path.join(config.test_path, image_name)
        if not os.path.isfile(scratch_file):
            print("Skipping non-file %s" % image_name)
            continue
        scratch_image = Image.open(scratch_file).convert("RGB")

        transformed_image_PIL = data_transforms(scratch_image, config.input_size)
        if transformed_image_PIL is None:
            # data_transforms only handles "full_size" and "scale_256".
            raise ValueError("Unsupported input_size: %r" % (config.input_size,))

        scratch_image = transformed_image_PIL.convert("L")
        scratch_image = tv.transforms.ToTensor()(scratch_image)
        scratch_image = tv.transforms.Normalize([0.5], [0.5])(scratch_image)
        scratch_image = torch.unsqueeze(scratch_image, 0)
        _, _, ow, oh = scratch_image.shape
        scratch_image_scale = scale_tensor(scratch_image)

        if use_gpu:
            scratch_image_scale = scratch_image_scale.to(config.GPU)
        else:
            scratch_image_scale = scratch_image_scale.cpu()
        with torch.no_grad():
            P = torch.sigmoid(model(scratch_image_scale))

        P = P.data.cpu()
        # Bring the prediction back to the size of the transformed input.
        P = F.interpolate(P, [ow, oh], mode="nearest")

        # splitext handles any extension length; slicing off four characters
        # turned "photo.jpeg" into "photo." and "scan.tiff" into "scan.".
        stem = os.path.splitext(image_name)[0]
        tv.utils.save_image(
            (P >= 0.4).float(),
            os.path.join(output_dir, stem + ".png"),
            nrow=1,
            padding=0,
            normalize=True,
        )
        transformed_image_PIL.save(os.path.join(input_dir, stem + ".png"))
        gc.collect()
        torch.cuda.empty_cache()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # parser.add_argument('--checkpoint_name', type=str, default="FT_Epoch_latest.pt", help='Checkpoint Name')
    parser.add_argument("--GPU", type=int, default=0)
    parser.add_argument("--test_path", type=str, default=".")
    parser.add_argument("--output_dir", type=str, default=".")
    parser.add_argument("--input_size", type=str, default="scale_256", help="resize_256|full_size|scale_256")
    config = parser.parse_args()

    main(config)
import torch
import torch.nn.parallel
import numpy as np
import torch.nn as nn
import torch.nn.functional as F


class Downsample(nn.Module):
    """Blur-then-subsample layer (https://github.com/adobe/antialiased-cnns).

    The input is padded, filtered per channel with a normalised 2-D kernel built
    from a 1-D binomial row, and subsampled by ``stride``.

    Args:
        pad_type: padding flavour understood by ``get_pad_layer``.
        filt_size: side of the blur kernel, 1 to 7.
        stride: subsampling stride.
        channels: number of input channels (required; one kernel copy per channel).
        pad_off: extra padding added on every side.

    Raises:
        ValueError: if ``channels`` is None, ``filt_size`` is unsupported, or
            ``pad_type`` is unknown.
    """

    # Un-normalised 1-D binomial rows keyed by filter size.
    _BINOMIAL_ROWS = {
        1: [1.0],
        2: [1.0, 1.0],
        3: [1.0, 2.0, 1.0],
        4: [1.0, 3.0, 3.0, 1.0],
        5: [1.0, 4.0, 6.0, 4.0, 1.0],
        6: [1.0, 5.0, 10.0, 10.0, 5.0, 1.0],
        7: [1.0, 6.0, 15.0, 20.0, 15.0, 6.0, 1.0],
    }

    def __init__(self, pad_type="reflect", filt_size=3, stride=2, channels=None, pad_off=0):
        super(Downsample, self).__init__()
        if channels is None:
            # The kernel is repeated once per channel, so the count is mandatory.
            raise ValueError("Downsample requires an explicit `channels` value")
        if filt_size not in self._BINOMIAL_ROWS:
            raise ValueError("Unsupported filt_size %r; expected an integer from 1 to 7" % (filt_size,))

        self.filt_size = filt_size
        self.pad_off = pad_off
        pad_lo = int(1.0 * (filt_size - 1) / 2)
        pad_hi = int(np.ceil(1.0 * (filt_size - 1) / 2))
        self.pad_sizes = [size + pad_off for size in (pad_lo, pad_hi, pad_lo, pad_hi)]
        self.stride = stride
        self.off = int((self.stride - 1) / 2.0)
        self.channels = channels

        a = np.array(self._BINOMIAL_ROWS[filt_size])
        # Outer product gives the 2-D kernel; normalise so it sums to one.
        filt = torch.Tensor(a[:, None] * a[None, :])
        filt = filt / torch.sum(filt)
        self.register_buffer("filt", filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))

        self.pad = get_pad_layer(pad_type)(self.pad_sizes)

    def forward(self, inp):
        """Return the blurred, strided version of ``inp``."""
        if self.filt_size == 1:
            # A 1x1 kernel is the identity, so only subsample (after optional padding).
            source = inp if self.pad_off == 0 else self.pad(inp)
            return source[:, :, :: self.stride, :: self.stride]
        # groups == channel count makes the convolution depthwise.
        return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])


def get_pad_layer(pad_type):
    """Map a padding name to the matching ``torch.nn`` padding layer class.

    Accepted names: "refl"/"reflect", "repl"/"replicate" and "zero".

    Raises:
        ValueError: for any other name (previously this printed a message and
            then failed with an UnboundLocalError).
    """
    if pad_type in ["refl", "reflect"]:
        return nn.ReflectionPad2d
    if pad_type in ["repl", "replicate"]:
        return nn.ReplicationPad2d
    if pad_type == "zero":
        return nn.ZeroPad2d
    raise ValueError("Pad type [%s] not recognized" % pad_type)
Global/detection_models/networks.py ================================================ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. import torch import torch.nn as nn import torch.nn.functional as F from detection_models.sync_batchnorm import DataParallelWithCallback from detection_models.antialiasing import Downsample class UNet(nn.Module): def __init__( self, in_channels=3, out_channels=3, depth=5, conv_num=2, wf=6, padding=True, batch_norm=True, up_mode="upsample", with_tanh=False, sync_bn=True, antialiasing=True, ): """ Implementation of U-Net: Convolutional Networks for Biomedical Image Segmentation (Ronneberger et al., 2015) https://arxiv.org/abs/1505.04597 Using the default arguments will yield the exact version used in the original paper Args: in_channels (int): number of input channels out_channels (int): number of output channels depth (int): depth of the network wf (int): number of filters in the first layer is 2**wf padding (bool): if True, apply padding such that the input shape is the same as the output. This may introduce artifacts batch_norm (bool): Use BatchNorm after layers with an activation function up_mode (str): one of 'upconv' or 'upsample'. 'upconv' will use transposed convolutions for learned upsampling. 'upsample' will use bilinear upsampling. 
class UNetConvBlock(nn.Module):
    """Stack of ``conv_num`` (reflection pad, 3x3 conv, optional BatchNorm, LeakyReLU) groups."""

    def __init__(self, conv_num, in_size, out_size, padding, batch_norm):
        super(UNetConvBlock, self).__init__()
        layers = []
        channels_in = in_size
        for _ in range(conv_num):
            layers.append(nn.ReflectionPad2d(padding=int(padding)))
            layers.append(nn.Conv2d(channels_in, out_size, kernel_size=3, padding=0))
            if batch_norm:
                layers.append(nn.BatchNorm2d(out_size))
            layers.append(nn.LeakyReLU(0.2, True))
            # Every conv after the first maps out_size -> out_size.
            channels_in = out_size

        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the stacked layers."""
        return self.block(x)


class UNetUpBlock(nn.Module):
    """Decoder step: upsample, concatenate the (center-cropped) skip tensor, then convolve."""

    def __init__(self, conv_num, in_size, out_size, up_mode, padding, batch_norm):
        super(UNetUpBlock, self).__init__()
        if up_mode == "upconv":
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
        elif up_mode == "upsample":
            upsample_layers = [
                nn.Upsample(mode="bilinear", scale_factor=2, align_corners=False),
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_size, out_size, kernel_size=3, padding=0),
            ]
            self.up = nn.Sequential(*upsample_layers)

        self.conv_block = UNetConvBlock(conv_num, in_size, out_size, padding, batch_norm)

    def center_crop(self, layer, target_size):
        """Cut a ``target_size`` (height, width) window from the middle of ``layer``."""
        target_h, target_w = target_size[0], target_size[1]
        _, _, layer_height, layer_width = layer.size()
        top = (layer_height - target_h) // 2
        left = (layer_width - target_w) // 2
        return layer[:, :, top : (top + target_h), left : (left + target_w)]

    def forward(self, x, bridge):
        """Upsample ``x``, join it with ``bridge`` along channels and apply the conv block."""
        upsampled = self.up(x)
        skip = self.center_crop(bridge, upsampled.shape[2:])
        merged = torch.cat([upsampled, skip], 1)
        return self.conv_block(merged)
class UnetSkipConnectionBlock(nn.Module):
    """One level of a recursively built U-Net with a skip connection.

    -------------------identity----------------------
    |-- downsampling -- |submodule| -- upsampling --|
    """

    def __init__(
        self,
        outer_nc,
        inner_nc,
        input_nc=None,
        submodule=None,
        outermost=False,
        innermost=False,
        norm_layer=nn.BatchNorm2d,
        use_dropout=False,
    ):
        """Build the level.

        Parameters:
            outer_nc (int) -- the number of filters in the outer conv layer
            inner_nc (int) -- the number of filters in the inner conv layer
            input_nc (int) -- the number of channels in input images/features;
                                defaults to outer_nc
            submodule (UnetSkipConnectionBlock) -- previously defined submodules
            outermost (bool)    -- if this module is the outermost module
            innermost (bool)    -- if this module is the innermost module
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers (intermediate levels only).
        """
        super().__init__()
        self.outermost = outermost
        # Convolutions carry a bias only when paired with InstanceNorm2d.
        use_bias = norm_layer == nn.InstanceNorm2d
        channels_in = outer_nc if input_nc is None else input_nc

        downconv = nn.Conv2d(channels_in, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.LeakyReLU(0.2, True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
            layers = [downconv, submodule, uprelu, upconv, nn.Tanh()]
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            layers = [downrelu, downconv, uprelu, upconv, upnorm]
        else:
            upconv = nn.ConvTranspose2d(
                inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias
            )
            layers = [downrelu, downconv, downnorm, submodule, uprelu, upconv, upnorm]
            if use_dropout:
                layers.append(nn.Dropout(0.5))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the level output; every level but the outermost concatenates its input."""
        if self.outermost:
            return self.model(x)
        # add skip connections
        return torch.cat([x, self.model(x)], 1)
def print_options(config_dict):
    """Print every option as ``key: value``, sorted, between header and footer lines."""
    print("------------ Options -------------")
    for entry in sorted(config_dict.items()):
        key, value = entry
        print("{}: {}".format(str(key), str(value)))
    print("-------------- End ----------------")


def save_options(config_dict):
    """Write the options to ``<checkpoint_dir>/<name>/opt.txt``.

    The file starts with the running script's basename and the current UTC
    time, followed by the same sorted ``key: value`` listing that
    ``print_options`` prints.
    """
    from time import gmtime, strftime

    target_dir = os.path.join(config_dict["checkpoint_dir"], config_dict["name"])
    mkdir_if_not(target_dir)
    target_file = os.path.join(target_dir, "opt.txt")

    stamp = strftime("%Y-%m-%d %H:%M:%S", gmtime())
    with open(target_file, "wt") as handle:
        handle.write(os.path.basename(sys.argv[0]) + " " + stamp + "\n")
        handle.write("------------ Options -------------\n")
        for key, value in sorted(config_dict.items()):
            handle.write("{}: {}\n".format(str(key), str(value)))
        handle.write("-------------- End ----------------\n")
def to_np(x):
    """Move tensor ``x`` to the CPU and return it as a NumPy array."""
    return x.cpu().numpy()


def prepare_device(use_gpu, gpu_ids):
    """Select and return the torch device to run on.

    Args:
        use_gpu: when false, the CPU device is returned and ``gpu_ids`` is ignored.
        gpu_ids: a single device index, a comma-separated string of indices, or
            a list/tuple of indices; for the multi-id forms the first id is used.

    Returns:
        torch.device: the selected device.
    """
    if not use_gpu:
        print("running on CPU")
        return torch.device("cpu")

    cudnn.benchmark = True
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    if isinstance(gpu_ids, str):
        gpu_ids = [int(x) for x in gpu_ids.split(",")]
    # A list/tuple cannot be turned into a device string directly; use its head.
    primary = gpu_ids[0] if isinstance(gpu_ids, (list, tuple)) else gpu_ids
    torch.cuda.set_device(primary)
    device = torch.device("cuda:" + str(primary))
    print("running on GPU {}".format(gpu_ids))
    return device


###### file system ######
def get_dir_size(start_path="."):
    """Return the summed size in bytes of all files found by walking ``start_path``.

    Files whose size cannot be read (for example dangling symlinks) are
    skipped. ``os.walk`` yields nothing for a missing directory or a plain
    file, so the result is 0 in those cases.
    """
    total_size = 0
    for dirpath, _dirnames, filenames in os.walk(start_path):
        for filename in filenames:
            try:
                total_size += os.path.getsize(os.path.join(dirpath, filename))
            except OSError:
                continue
    return total_size


def mkdir_if_not(dir_path):
    """Create ``dir_path`` (with parents) unless something already exists there."""
    if not os.path.exists(dir_path):
        # exist_ok covers another process creating the directory between the
        # check above and this call.
        os.makedirs(dir_path, exist_ok=True)


##### System related ######
class Timer:
    """Context manager that prints ``msg % elapsed_seconds`` on exit."""

    def __init__(self, msg):
        self.msg = msg
        self.start_time = None

    def __enter__(self):
        self.start_time = time.time()
        # Returning self makes `with Timer(...) as t` usable.
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        elapse = time.time() - self.start_time
        print(self.msg % elapse)


###### interactive ######
def get_size(start_path="."):
    """Alias of ``get_dir_size`` kept for existing callers."""
    return get_dir_size(start_path)
def prepare_tensorboard(config, experiment_name=None):
    """Create a TensorBoard writer under ``<checkpoint_dir>/<name>/tensorboard_logs``.

    Args:
        config: object with ``checkpoint_dir`` and ``name`` attributes.
        experiment_name: sub-directory for this run. Defaults to the current
            time formatted as ``%Y-%m-%d %H-%M-%S``, computed per call. It was
            previously computed once when the module was imported.

    Returns:
        The ``SummaryWriter`` for the run directory.
    """
    # Imported here because the module-level import is commented out in this file.
    from torch.utils.tensorboard import SummaryWriter

    if experiment_name is None:
        experiment_name = datetime.now().strftime("%Y-%m-%d %H-%M-%S")

    tensorboard_directory = os.path.join(config.checkpoint_dir, config.name, "tensorboard_logs")
    mkdir_if_not(tensorboard_directory)
    clean_tensorboard(tensorboard_directory)
    tb_writer = SummaryWriter(os.path.join(tensorboard_directory, experiment_name), flush_secs=10)
    return tb_writer


def tb_loss_logger(tb_writer, iter_index, loss_logger):
    """Log every ``tag -> value`` pair of ``loss_logger`` as a scalar at step ``iter_index``.

    Each value must provide ``.item()``.
    """
    for tag, value in loss_logger.items():
        tb_writer.add_scalar(tag, scalar_value=value.item(), global_step=iter_index)


def tb_image_logger(tb_writer, iter_index, images_info, config):
    """Log image batches to TensorBoard and also save each grid as a JPEG.

    JPEGs go to ``<output_dir>/<name>/<train_mode>/<iter>_<tag>.jpg``. The tags
    "test_image_prediction" and "image_prediction" are skipped.
    """
    tb_logger_path = os.path.join(config.output_dir, config.name, config.train_mode)
    mkdir_if_not(tb_logger_path)
    for tag, image in images_info.items():
        if tag in ("test_image_prediction", "image_prediction"):
            continue
        image = tv.utils.make_grid(image.cpu())
        image = torch.clamp(image, 0, 1)
        tb_writer.add_image(tag, img_tensor=image, global_step=iter_index)
        tv.transforms.functional.to_pil_image(image).save(
            os.path.join(tb_logger_path, "{:06d}_{}.jpg".format(iter_index, tag))
        )
def imshow(input_image, title=None, to_numpy=False):
    """Show an image with matplotlib, axes hidden.

    A 2-D input is drawn in grayscale with colour limits [0, 255]; any other
    input is transposed with axes (1, 2, 0) and cast to uint8 before drawing.
    Tensors (or any input when ``to_numpy`` is set) are converted via ``.numpy()``.
    """
    needs_conversion = to_numpy or type(input_image) is torch.Tensor
    array = input_image.numpy() if needs_conversion else input_image

    fig = plt.figure()
    if array.ndim == 2:
        fig = plt.imshow(array, cmap="gray", clim=[0, 255])
    else:
        fig = plt.imshow(np.transpose(array, [1, 2, 0]).astype(np.uint8))
    plt.axis("off")
    for axis in (fig.axes.get_xaxis(), fig.axes.get_yaxis()):
        axis.set_visible(False)
    plt.title(title)


###### vgg preprocessing ######
def vgg_preprocess(tensor):
    """Convert an RGB batch in [0, 1] to mean-subtracted BGR scaled by 255."""
    # Reverse the first three channels: RGB -> BGR.
    bgr = torch.cat([tensor[:, c : c + 1, :, :] for c in (2, 1, 0)], dim=1)
    bgr_mean = torch.Tensor([0.40760392, 0.45795686, 0.48501961]).type_as(bgr).view(1, 3, 1, 1)
    return (bgr - bgr_mean) * 255


def torch_vgg_preprocess(tensor):
    """Normalise an RGB batch with per-channel mean and std; channel order is kept."""
    mean = torch.Tensor([0.485, 0.456, 0.406]).type_as(tensor).view(1, 3, 1, 1)
    centered = tensor - mean
    std = torch.Tensor([0.229, 0.224, 0.225]).type_as(centered).view(1, 3, 1, 1)
    return centered / std


def network_gradient(net, gradient_on=True):
    """Set ``requires_grad`` on every parameter of ``net`` and return ``net``."""
    flag = True if gradient_on else False
    for param in net.parameters():
        param.requires_grad = flag
    return net
class Mapping_Model_with_mask(nn.Module):
    """Mapping network with a masked non-local block between two conv stacks.

    The data path is ``before_NL`` -> ``NL`` -> ``after_NL``:

    * ``before_NL``: four 3x3 conv / norm / ReLU stages whose widths double
      from 64 and are capped at ``mc``.
    * ``NL``: ``networks.NonLocalBlock2D_with_mask_Res``; it is only created
      when ``opt.NL_res`` is truthy, while ``forward`` reads it unconditionally.
    * ``after_NL``: ``n_blocks`` residual blocks, three narrowing conv stages,
      a conv from 128 to 64 channels and, when ``0 < opt.feat_dim < 64``, a
      norm / ReLU / 1x1 conv projection to ``opt.feat_dim`` channels.

    ``nc`` is accepted but never read. ``opt`` is dereferenced in the
    constructor, so the ``None`` default cannot actually be used.
    """

    def __init__(self, nc, mc=64, n_blocks=3, norm="instance", padding_type="reflect", opt=None):
        super(Mapping_Model_with_mask, self).__init__()

        make_norm = networks.get_norm_layer(norm_type=norm)
        relu = nn.ReLU(True)  # one in-place ReLU instance shared by every stage
        base_nc = 64
        n_up = 4

        widen = []
        for level in range(n_up):
            in_ch = min(base_nc * (2 ** level), mc)
            out_ch = min(base_nc * (2 ** (level + 1)), mc)
            widen.extend([nn.Conv2d(in_ch, out_ch, 3, 1, 1), make_norm(out_ch), relu])
        self.before_NL = nn.Sequential(*widen)

        if opt.NL_res:
            self.NL = networks.NonLocalBlock2D_with_mask_Res(
                mc,
                mc,
                opt.NL_fusion_method,
                opt.correlation_renormalize,
                opt.softmax_temperature,
                opt.use_self,
                opt.cosin_similarity,
            )
            print("You are using NL + Res")

        tail = []
        for _ in range(n_blocks):
            tail.append(
                networks.ResnetBlock(
                    mc,
                    padding_type=padding_type,
                    activation=relu,
                    norm_layer=make_norm,
                    opt=opt,
                    dilation=opt.mapping_net_dilation,
                )
            )

        # Mirror of the widening stack: halve the (capped) width at each step.
        for step in range(n_up - 1):
            in_ch = min(64 * (2 ** (4 - step)), mc)
            out_ch = min(64 * (2 ** (3 - step)), mc)
            tail.extend([nn.Conv2d(in_ch, out_ch, 3, 1, 1), make_norm(out_ch), relu])
        tail.append(nn.Conv2d(base_nc * 2, base_nc, 3, 1, 1))

        if 0 < opt.feat_dim < 64:
            tail.extend([make_norm(base_nc), relu, nn.Conv2d(base_nc, opt.feat_dim, 1, 1)])
        self.after_NL = nn.Sequential(*tail)

    def forward(self, input, mask):
        """Run ``input`` through the three stages; ``mask`` is consumed by ``NL`` only."""
        hidden = self.before_NL(input)
        del input  # drop local references as soon as each stage is done
        hidden = self.NL(hidden, mask)
        del mask
        hidden = self.after_NL(hidden)
        return hidden
activation = nn.ReLU(True) model = [] tmp_nc = 64 n_up = 4 for i in range(n_up): ic = min(tmp_nc * (2 ** i), mc) oc = min(tmp_nc * (2 ** (i + 1)), mc) model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation] for i in range(2): model += [ networks.ResnetBlock( mc, padding_type=padding_type, activation=activation, norm_layer=norm_layer, opt=opt, dilation=opt.mapping_net_dilation, ) ] print("Mapping: You are using multi-scale patch attention, conv combine + mask input") self.before_NL = nn.Sequential(*model) if opt.mapping_exp==1: self.NL_scale_1=networks.Patch_Attention_4(mc,mc,8) model = [] for i in range(2): model += [ networks.ResnetBlock( mc, padding_type=padding_type, activation=activation, norm_layer=norm_layer, opt=opt, dilation=opt.mapping_net_dilation, ) ] self.res_block_1 = nn.Sequential(*model) if opt.mapping_exp==1: self.NL_scale_2=networks.Patch_Attention_4(mc,mc,4) model = [] for i in range(2): model += [ networks.ResnetBlock( mc, padding_type=padding_type, activation=activation, norm_layer=norm_layer, opt=opt, dilation=opt.mapping_net_dilation, ) ] self.res_block_2 = nn.Sequential(*model) if opt.mapping_exp==1: self.NL_scale_3=networks.Patch_Attention_4(mc,mc,2) # self.NL_scale_3=networks.Patch_Attention_2(mc,mc,2) model = [] for i in range(2): model += [ networks.ResnetBlock( mc, padding_type=padding_type, activation=activation, norm_layer=norm_layer, opt=opt, dilation=opt.mapping_net_dilation, ) ] for i in range(n_up - 1): ic = min(64 * (2 ** (4 - i)), mc) oc = min(64 * (2 ** (3 - i)), mc) model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation] model += [nn.Conv2d(tmp_nc * 2, tmp_nc, 3, 1, 1)] if opt.feat_dim > 0 and opt.feat_dim < 64: model += [norm_layer(tmp_nc), activation, nn.Conv2d(tmp_nc, opt.feat_dim, 1, 1)] # model += [nn.Conv2d(64, 1, 1, 1, 0)] self.after_NL = nn.Sequential(*model) def forward(self, input, mask): x1 = self.before_NL(input) x2 = self.NL_scale_1(x1,mask) x3 = self.res_block_1(x2) x4 = 
# helper saving function that can be used by subclasses
def save_network(self, network, network_label, epoch_label, gpu_ids):
    """Write ``network``'s weights to ``<save_dir>/<epoch_label>_net_<network_label>.pth``.

    Args:
        network: Module whose ``state_dict`` is saved. It is moved to the CPU
            for serialisation and moved back afterwards when GPUs are in use.
        network_label: Short name of the network, used in the file name.
        epoch_label: Epoch tag, used as the file-name prefix.
        gpu_ids: Sequence of GPU indices the model runs on; empty for CPU.
    """
    save_filename = "%s_net_%s.pth" % (epoch_label, network_label)
    save_path = os.path.join(self.save_dir, save_filename)
    # Serialise CPU tensors so the file does not record a CUDA device.
    torch.save(network.cpu().state_dict(), save_path)
    if len(gpu_ids) and torch.cuda.is_available():
        # Return to the first configured GPU. A bare ``.cuda()`` targets the
        # current default CUDA device, which differs from ``gpu_ids[0]`` when
        # training on a non-default card.
        network.cuda(gpu_ids[0])
# helper loading function that can be used by subclasses
def load_network(self, network, network_label, epoch_label, save_dir=""):
    """Load ``<epoch_label>_net_<network_label>.pth`` into ``network``.

    Loading is attempted in three steps, each used only if the previous one
    is rejected by ``load_state_dict``:

    1. strict load of the checkpoint as stored;
    2. strict load of the checkpoint restricted to keys the network has
       (checkpoint has extra layers);
    3. merge of every size-compatible checkpoint tensor into the network's
       own state dict (checkpoint has fewer or differently shaped layers);
       the top-level names left untouched are printed.

    A missing file is reported with a message and is not an error.

    Args:
        network: Module that receives the weights.
        network_label: Short name of the network, used in the file name.
        epoch_label: Epoch tag, used as the file-name prefix.
        save_dir: Directory to read from; falls back to ``self.save_dir``
            when empty.
    """
    save_filename = "%s_net_%s.pth" % (epoch_label, network_label)
    if not save_dir:
        save_dir = self.save_dir
    save_path = os.path.join(save_dir, save_filename)

    if not os.path.isfile(save_path):
        print("%s not exists yet!" % save_path)
        return

    # Read the file once; the fallbacks below reuse the same tensors.
    checkpoint = torch.load(save_path)
    try:
        network.load_state_dict(checkpoint)
        return
    except RuntimeError:
        # Only key/shape mismatches are recoverable. Interrupts and unrelated
        # failures propagate instead of being swallowed by a bare ``except``.
        model_dict = network.state_dict()

    usable = {k: v for k, v in checkpoint.items() if k in model_dict}
    try:
        network.load_state_dict(usable)
    except RuntimeError:
        print(
            "Pretrained network %s has fewer layers; The following are not initialized:"
            % network_label
        )
    else:
        print(
            "Pretrained network %s has excessive layers; Only loading layers that are used"
            % network_label
        )
        return

    for key, tensor in usable.items():
        if tensor.size() == model_dict[key].size():
            model_dict[key] = tensor

    # Report by top-level module name which parts kept their current weights.
    not_initialized = {
        key.split(".")[0]
        for key, tensor in model_dict.items()
        if key not in usable or tensor.size() != usable[key].size()
    }
    print(sorted(not_initialized))
    network.load_state_dict(model_dict)
class Mapping_Model(nn.Module):
    """Purely convolutional mapping network (no non-local / global branch).

    A single ``nn.Sequential`` stored in ``self.model`` made of:

    * four 3x3 conv / norm / ReLU stages whose widths double from 64 and are
      capped at ``mc``;
    * ``n_blocks`` residual blocks at width ``mc``;
    * three narrowing conv / norm / ReLU stages followed by a conv from 128
      to 64 channels;
    * when ``0 < opt.feat_dim < 64``, a norm / ReLU / 1x1 conv projection to
      ``opt.feat_dim`` channels.

    ``nc`` is accepted but never read. ``opt`` is dereferenced in the
    constructor, so the ``None`` default cannot actually be used.
    """

    def __init__(self, nc, mc=64, n_blocks=3, norm="instance", padding_type="reflect", opt=None):
        super(Mapping_Model, self).__init__()

        make_norm = networks.get_norm_layer(norm_type=norm)
        relu = nn.ReLU(True)  # one in-place ReLU instance shared by every stage
        base_nc = 64
        n_up = 4

        print("Mapping: You are using the mapping model without global restoration.")

        layers = []
        for level in range(n_up):
            in_ch = min(base_nc * (2 ** level), mc)
            out_ch = min(base_nc * (2 ** (level + 1)), mc)
            layers.extend([nn.Conv2d(in_ch, out_ch, 3, 1, 1), make_norm(out_ch), relu])

        for _ in range(n_blocks):
            layers.append(
                networks.ResnetBlock(
                    mc,
                    padding_type=padding_type,
                    activation=relu,
                    norm_layer=make_norm,
                    opt=opt,
                    dilation=opt.mapping_net_dilation,
                )
            )

        # Mirror of the widening stack: halve the (capped) width at each step.
        for step in range(n_up - 1):
            in_ch = min(64 * (2 ** (4 - step)), mc)
            out_ch = min(64 * (2 ** (3 - step)), mc)
            layers.extend([nn.Conv2d(in_ch, out_ch, 3, 1, 1), make_norm(out_ch), relu])
        layers.append(nn.Conv2d(base_nc * 2, base_nc, 3, 1, 1))

        if 0 < opt.feat_dim < 64:
            layers.extend([make_norm(base_nc), relu, nn.Conv2d(base_nc, opt.feat_dim, 1, 1)])

        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Apply the whole stack to ``input``."""
        return self.model(input)
g_gan, g_gan_feat, g_vgg, d_real, d_fake, smooth_l1, stage_1_feat_l2), flags ) if f ] return loss_filter def initialize(self, opt): BaseModel.initialize(self, opt) if opt.resize_or_crop != "none" or not opt.isTrain: torch.backends.cudnn.benchmark = True self.isTrain = opt.isTrain input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc ##### define networks # Generator network netG_input_nc = input_nc self.netG_A = networks.GlobalGenerator_DCDCv2( netG_input_nc, opt.output_nc, opt.ngf, opt.k_size, opt.n_downsample_global, networks.get_norm_layer(norm_type=opt.norm), opt=opt, ) self.netG_B = networks.GlobalGenerator_DCDCv2( netG_input_nc, opt.output_nc, opt.ngf, opt.k_size, opt.n_downsample_global, networks.get_norm_layer(norm_type=opt.norm), opt=opt, ) if opt.non_local == "Setting_42" or opt.NL_use_mask: if opt.mapping_exp==1: self.mapping_net = Mapping_Model_with_mask_2( min(opt.ngf * 2 ** opt.n_downsample_global, opt.mc), opt.map_mc, n_blocks=opt.mapping_n_block, opt=opt, ) else: self.mapping_net = Mapping_Model_with_mask( min(opt.ngf * 2 ** opt.n_downsample_global, opt.mc), opt.map_mc, n_blocks=opt.mapping_n_block, opt=opt, ) else: self.mapping_net = Mapping_Model( min(opt.ngf * 2 ** opt.n_downsample_global, opt.mc), opt.map_mc, n_blocks=opt.mapping_n_block, opt=opt, ) self.mapping_net.apply(networks.weights_init) if opt.load_pretrain != "": self.load_network(self.mapping_net, "mapping_net", opt.which_epoch, opt.load_pretrain) if not opt.no_load_VAE: self.load_network(self.netG_A, "G", opt.use_vae_which_epoch, opt.load_pretrainA) self.load_network(self.netG_B, "G", opt.use_vae_which_epoch, opt.load_pretrainB) for param in self.netG_A.parameters(): param.requires_grad = False for param in self.netG_B.parameters(): param.requires_grad = False self.netG_A.eval() self.netG_B.eval() if opt.gpu_ids: self.netG_A.cuda(opt.gpu_ids[0]) self.netG_B.cuda(opt.gpu_ids[0]) self.mapping_net.cuda(opt.gpu_ids[0]) if not self.isTrain: self.load_network(self.mapping_net, 
"mapping_net", opt.which_epoch) # Discriminator network if self.isTrain: use_sigmoid = opt.no_lsgan netD_input_nc = opt.ngf * 2 if opt.feat_gan else input_nc + opt.output_nc if not opt.no_instance: netD_input_nc += 1 self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt, opt.norm, use_sigmoid, opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids) # set loss functions and optimizers if self.isTrain: if opt.pool_size > 0 and (len(self.gpu_ids)) > 1: raise NotImplementedError("Fake Pool Not Implemented for MultiGPU") self.fake_pool = ImagePool(opt.pool_size) self.old_lr = opt.lr # define loss functions self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss, opt.Smooth_L1, opt.use_two_stage_mapping) self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor) self.criterionFeat = torch.nn.L1Loss() self.criterionFeat_feat = torch.nn.L1Loss() if opt.use_l1_feat else torch.nn.MSELoss() if self.opt.image_L1: self.criterionImage=torch.nn.L1Loss() else: self.criterionImage = torch.nn.SmoothL1Loss() print(self.criterionFeat_feat) if not opt.no_vgg_loss: self.criterionVGG = networks.VGGLoss_torch(self.gpu_ids) # Names so we can breakout loss self.loss_names = self.loss_filter('G_Feat_L2', 'G_GAN', 'G_GAN_Feat', 'G_VGG','D_real', 'D_fake', 'Smooth_L1', 'G_Feat_L2_Stage_1') # initialize optimizers # optimizer G if opt.no_TTUR: beta1,beta2=opt.beta1,0.999 G_lr,D_lr=opt.lr,opt.lr else: beta1,beta2=0,0.9 G_lr,D_lr=opt.lr/2,opt.lr*2 if not opt.no_load_VAE: params = list(self.mapping_net.parameters()) self.optimizer_mapping = torch.optim.Adam(params, lr=G_lr, betas=(beta1, beta2)) # optimizer D params = list(self.netD.parameters()) self.optimizer_D = torch.optim.Adam(params, lr=D_lr, betas=(beta1, beta2)) print("---------- Optimizers initialized -------------") def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False): if self.opt.label_nc == 0: input_label = 
def discriminate(self, input_label, test_image, use_pool=False):
    """Score a (condition, candidate) pair with the discriminator.

    Args:
        input_label: Conditioning tensor; concatenated first along dim 1.
        test_image: Candidate tensor; detached before concatenation.
        use_pool: When true, the concatenated pair is first passed through
            ``self.fake_pool.query`` and the returned tensor is scored instead.

    Returns:
        Whatever ``self.netD.forward`` returns for the chosen input.
    """
    # Detach so discriminator gradients never reach the producer of test_image.
    pair = torch.cat((input_label, test_image.detach()), dim=1)
    disc_input = self.fake_pool.query(pair) if use_pool else pair
    return self.netD.forward(disc_input)
self.criterionGAN(pred_fake_pool, False) # Real Detection and Loss pred_real = self.discriminate(label_feat.detach(), image_feat) loss_D_real = self.criterionGAN(pred_real, True) # GAN loss (Fake Passability Loss) pred_fake = self.netD.forward(torch.cat((label_feat.detach(), label_feat_map), dim=1)) loss_G_GAN = self.criterionGAN(pred_fake, True) else: # Fake Detection and Loss pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True) loss_D_fake = self.criterionGAN(pred_fake_pool, False) # Real Detection and Loss if pair: pred_real = self.discriminate(input_label, real_image) else: pred_real = self.discriminate(last_label, last_image) loss_D_real = self.criterionGAN(pred_real, True) # GAN loss (Fake Passability Loss) pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1)) loss_G_GAN = self.criterionGAN(pred_fake, True) # GAN feature matching loss loss_G_GAN_Feat = 0 if not self.opt.no_ganFeat_loss and pair: feat_weights = 4.0 / (self.opt.n_layers_D + 1) D_weights = 1.0 / self.opt.num_D for i in range(self.opt.num_D): for j in range(len(pred_fake[i])-1): tmp = self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat loss_G_GAN_Feat += D_weights * feat_weights * tmp else: loss_G_GAN_Feat = torch.zeros(1).to(label.device) # VGG feature matching loss loss_G_VGG = 0 if not self.opt.no_vgg_loss: loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat if pair else torch.zeros(1).to(label.device) smooth_l1_loss=0 if self.opt.Smooth_L1: smooth_l1_loss=self.criterionImage(fake_image,real_image)*self.opt.L1_weight return [ self.loss_filter(loss_feat_l2, loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake,smooth_l1_loss,loss_feat_l2_stage_1), None if not infer else fake_image ] def inference(self, label, inst): use_gpu = len(self.opt.gpu_ids) > 0 if use_gpu: input_concat = label.data.cuda() inst_data = inst.cuda() else: input_concat = label.data inst_data = inst label_feat = 
def create_model(opt):
    """Build, initialise and (for multi-GPU training) wrap the model named by ``opt.model``.

    ``"pix2pixHD"`` selects ``Pix2PixHDModel`` when ``opt.isTrain`` is set and
    ``InferenceModel`` otherwise; any other value selects ``UIModel``.

    Args:
        opt: Options object; ``model``, ``isTrain``, ``verbose`` and ``gpu_ids``
            are read here and the whole object is passed to ``initialize``.

    Returns:
        The initialised model, wrapped in ``torch.nn.DataParallel`` when
        training on more than one GPU.
    """
    if opt.model != "pix2pixHD":
        # NOTE(review): no ui_model module appears in the repository listing;
        # confirm this branch is reachable.
        from .ui_model import UIModel

        model = UIModel()
    else:
        from .pix2pixHD_model import Pix2PixHDModel, InferenceModel

        model_cls = Pix2PixHDModel if opt.isTrain else InferenceModel
        model = model_cls()

    model.initialize(opt)
    if opt.verbose:
        print("model [%s] was created" % (model.name()))

    multi_gpu_training = opt.isTrain and len(opt.gpu_ids) > 1
    if multi_gpu_training:
        model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)

    return model
def get_norm_layer(norm_type="instance"):
    """Return a factory that builds a 2-D normalisation layer from a channel count.

    Args:
        norm_type: ``"batch"`` for ``nn.BatchNorm2d(affine=True)``,
            ``"instance"`` for ``nn.InstanceNorm2d(affine=False)``, or
            ``"SwitchNorm"`` for a ``SwitchNorm2d`` class if one is defined
            in this module.

    Returns:
        A callable that takes the number of channels and returns the layer.

    Raises:
        NotImplementedError: If ``norm_type`` is unknown, is ``"spectral"``,
            or is ``"SwitchNorm"`` while no ``SwitchNorm2d`` is available.
    """
    if norm_type == "batch":
        return functools.partial(nn.BatchNorm2d, affine=True)
    if norm_type == "instance":
        return functools.partial(nn.InstanceNorm2d, affine=False)
    if norm_type == "spectral":
        # spectral_norm() re-parametrises the weight of an existing module; it
        # is not a layer that can be built from a channel count. Calling it
        # with no module, as this branch used to, raised a TypeError.
        raise NotImplementedError(
            "normalization layer [%s] is not supported; wrap modules with spectral_norm instead"
            % norm_type
        )
    if norm_type == "SwitchNorm":
        # The import of SwitchNorm2d is commented out in this file, so look the
        # name up defensively instead of failing with a NameError.
        switch_norm = globals().get("SwitchNorm2d")
        if switch_norm is None:
            raise NotImplementedError(
                "normalization layer [%s] is not found: SwitchNorm2d is not available" % norm_type
            )
        return switch_norm
    raise NotImplementedError("normalization layer [%s] is not found" % norm_type)
use_sigmoid=False, num_D=1, getIntermFeat=False, gpu_ids=[]): norm_layer = get_norm_layer(norm_type=norm) netD = MultiscaleDiscriminator(input_nc, opt, ndf, n_layers_D, norm_layer, use_sigmoid, num_D, getIntermFeat) print(netD) if len(gpu_ids) > 0: assert(torch.cuda.is_available()) netD.cuda(gpu_ids[0]) netD.apply(weights_init) return netD class GlobalGenerator_DCDCv2(nn.Module): def __init__( self, input_nc, output_nc, ngf=64, k_size=3, n_downsampling=8, norm_layer=nn.BatchNorm2d, padding_type="reflect", opt=None, ): super(GlobalGenerator_DCDCv2, self).__init__() activation = nn.ReLU(True) model = [ nn.ReflectionPad2d(3), nn.Conv2d(input_nc, min(ngf, opt.mc), kernel_size=7, padding=0), norm_layer(ngf), activation, ] ### downsample for i in range(opt.start_r): mult = 2 ** i model += [ nn.Conv2d( min(ngf * mult, opt.mc), min(ngf * mult * 2, opt.mc), kernel_size=k_size, stride=2, padding=1, ), norm_layer(min(ngf * mult * 2, opt.mc)), activation, ] for i in range(opt.start_r, n_downsampling - 1): mult = 2 ** i model += [ nn.Conv2d( min(ngf * mult, opt.mc), min(ngf * mult * 2, opt.mc), kernel_size=k_size, stride=2, padding=1, ), norm_layer(min(ngf * mult * 2, opt.mc)), activation, ] model += [ ResnetBlock( min(ngf * mult * 2, opt.mc), padding_type=padding_type, activation=activation, norm_layer=norm_layer, opt=opt, ) ] model += [ ResnetBlock( min(ngf * mult * 2, opt.mc), padding_type=padding_type, activation=activation, norm_layer=norm_layer, opt=opt, ) ] mult = 2 ** (n_downsampling - 1) if opt.spatio_size == 32: model += [ nn.Conv2d( min(ngf * mult, opt.mc), min(ngf * mult * 2, opt.mc), kernel_size=k_size, stride=2, padding=1, ), norm_layer(min(ngf * mult * 2, opt.mc)), activation, ] if opt.spatio_size == 64: model += [ ResnetBlock( min(ngf * mult * 2, opt.mc), padding_type=padding_type, activation=activation, norm_layer=norm_layer, opt=opt, ) ] model += [ ResnetBlock( min(ngf * mult * 2, opt.mc), padding_type=padding_type, activation=activation, 
norm_layer=norm_layer, opt=opt, ) ] # model += [nn.Conv2d(min(ngf * mult * 2, opt.mc), min(ngf, opt.mc), 1, 1)] if opt.feat_dim > 0: model += [nn.Conv2d(min(ngf * mult * 2, opt.mc), opt.feat_dim, 1, 1)] self.encoder = nn.Sequential(*model) # decode model = [] if opt.feat_dim > 0: model += [nn.Conv2d(opt.feat_dim, min(ngf * mult * 2, opt.mc), 1, 1)] # model += [nn.Conv2d(min(ngf, opt.mc), min(ngf * mult * 2, opt.mc), 1, 1)] o_pad = 0 if k_size == 4 else 1 mult = 2 ** n_downsampling model += [ ResnetBlock( min(ngf * mult, opt.mc), padding_type=padding_type, activation=activation, norm_layer=norm_layer, opt=opt, ) ] if opt.spatio_size == 32: model += [ nn.ConvTranspose2d( min(ngf * mult, opt.mc), min(int(ngf * mult / 2), opt.mc), kernel_size=k_size, stride=2, padding=1, output_padding=o_pad, ), norm_layer(min(int(ngf * mult / 2), opt.mc)), activation, ] if opt.spatio_size == 64: model += [ ResnetBlock( min(ngf * mult, opt.mc), padding_type=padding_type, activation=activation, norm_layer=norm_layer, opt=opt, ) ] for i in range(1, n_downsampling - opt.start_r): mult = 2 ** (n_downsampling - i) model += [ ResnetBlock( min(ngf * mult, opt.mc), padding_type=padding_type, activation=activation, norm_layer=norm_layer, opt=opt, ) ] model += [ ResnetBlock( min(ngf * mult, opt.mc), padding_type=padding_type, activation=activation, norm_layer=norm_layer, opt=opt, ) ] model += [ nn.ConvTranspose2d( min(ngf * mult, opt.mc), min(int(ngf * mult / 2), opt.mc), kernel_size=k_size, stride=2, padding=1, output_padding=o_pad, ), norm_layer(min(int(ngf * mult / 2), opt.mc)), activation, ] for i in range(n_downsampling - opt.start_r, n_downsampling): mult = 2 ** (n_downsampling - i) model += [ nn.ConvTranspose2d( min(ngf * mult, opt.mc), min(int(ngf * mult / 2), opt.mc), kernel_size=k_size, stride=2, padding=1, output_padding=o_pad, ), norm_layer(min(int(ngf * mult / 2), opt.mc)), activation, ] if opt.use_segmentation_model: model += [nn.ReflectionPad2d(3), nn.Conv2d(min(ngf, opt.mc), 
# Define a resnet block
class ResnetBlock(nn.Module):
    """Residual block computing ``x + conv_block(x)``.

    ``conv_block`` is: padding, 3x3 conv with ``dilation``, norm, activation,
    optional ``Dropout(0.5)``, padding, plain 3x3 conv, norm. Padding is sized
    so the spatial size is preserved. ``opt`` is only stored on the instance.
    """

    def __init__(
        self, dim, padding_type, norm_layer, opt, activation=nn.ReLU(True), use_dropout=False, dilation=1
    ):
        super(ResnetBlock, self).__init__()
        self.opt = opt
        self.dilation = dilation
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout)

    @staticmethod
    def _padding(padding_type, amount):
        """Return ``(explicit pad layers, conv padding)`` for one convolution."""
        if padding_type == "zero":
            # Zero padding is delegated to the convolution itself.
            return [], amount
        if padding_type == "reflect":
            return [nn.ReflectionPad2d(amount)], 0
        if padding_type == "replicate":
            return [nn.ReplicationPad2d(amount)], 0
        raise NotImplementedError("padding [%s] is not implemented" % padding_type)

    def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
        """Assemble the two-convolution body as an ``nn.Sequential``.

        Raises:
            NotImplementedError: If ``padding_type`` is not ``"reflect"``,
                ``"replicate"`` or ``"zero"``.
        """
        # First convolution: dilated, padded by the dilation amount.
        layers, conv_pad = self._padding(padding_type, self.dilation)
        layers.append(nn.Conv2d(dim, dim, kernel_size=3, padding=conv_pad, dilation=self.dilation))
        layers.append(norm_layer(dim))
        layers.append(activation)
        if use_dropout:
            layers.append(nn.Dropout(0.5))

        # Second convolution: undilated, padded by one pixel.
        pad_layers, conv_pad = self._padding(padding_type, 1)
        layers.extend(pad_layers)
        layers.append(nn.Conv2d(dim, dim, kernel_size=3, padding=conv_pad, dilation=1))
        layers.append(norm_layer(dim))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Add the block's output to its input."""
        return x + self.conv_block(x)
def SN(module, mode=True):
    """Optionally apply spectral normalisation to ``module``.

    Args:
        module: The module to wrap.
        mode: When truthy, ``torch.nn.utils.spectral_norm`` is applied;
            otherwise ``module`` is returned untouched.

    Returns:
        The result of ``spectral_norm(module)`` or ``module`` itself.
    """
    return torch.nn.utils.spectral_norm(module) if mode else module
nn.init.constant_(self.W.weight, 0) nn.init.constant_(self.W.bias, 0) self.theta = nn.Conv2d( in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0 ) self.phi = nn.Conv2d( in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0 ) self.mode = mode self.temperature = temperature self.use_self = use_self norm_layer = get_norm_layer(norm_type="instance") activation = nn.ReLU(True) model = [] for i in range(3): model += [ ResnetBlock( inter_channels, padding_type="reflect", activation=activation, norm_layer=norm_layer, opt=None, ) ] self.res_block = nn.Sequential(*model) def forward(self, x, mask): ## The shape of mask is Batch*1*H*W batch_size = x.size(0) g_x = self.g(x).view(batch_size, self.inter_channels, -1) g_x = g_x.permute(0, 2, 1) theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) theta_x = theta_x.permute(0, 2, 1) phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) if self.cosin: theta_x = F.normalize(theta_x, dim=2) phi_x = F.normalize(phi_x, dim=1) f = torch.matmul(theta_x, phi_x) f /= self.temperature f_div_C = F.softmax(f, dim=2) tmp = 1 - mask mask = F.interpolate(mask, (x.size(2), x.size(3)), mode="bilinear") mask[mask > 0] = 1.0 mask = 1 - mask tmp = F.interpolate(tmp, (x.size(2), x.size(3))) mask *= tmp mask_expand = mask.view(batch_size, 1, -1) mask_expand = mask_expand.repeat(1, x.size(2) * x.size(3), 1) # mask = 1 - mask # mask=F.interpolate(mask,(x.size(2),x.size(3))) # mask_expand=mask.view(batch_size,1,-1) # mask_expand=mask_expand.repeat(1,x.size(2)*x.size(3),1) if self.use_self: mask_expand[:, range(x.size(2) * x.size(3)), range(x.size(2) * x.size(3))] = 1.0 # print(mask_expand.shape) # print(f_div_C.shape) f_div_C = mask_expand * f_div_C if self.renorm: f_div_C = F.normalize(f_div_C, p=1, dim=2) ########################### y = torch.matmul(f_div_C, g_x) y = y.permute(0, 2, 1).contiguous() y = y.view(batch_size, self.inter_channels, 
class MultiscaleDiscriminator(nn.Module):
    """Bank of ``num_D`` ``NLayerDiscriminator`` networks applied at successive scales.

    The sub-networks are not kept as objects: their layers are re-registered
    on this module as ``layer<i>`` (whole network), or as
    ``scale<i>_layer<j>`` for ``j < n_layers + 2`` when ``getIntermFeat`` is
    set. ``forward`` feeds the full-resolution input to the highest-indexed
    discriminator and average-pools by a factor of two before each lower one.
    """

    def __init__(self, input_nc, opt, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
                 use_sigmoid=False, num_D=3, getIntermFeat=False):
        super(MultiscaleDiscriminator, self).__init__()
        self.num_D = num_D
        self.n_layers = n_layers
        self.getIntermFeat = getIntermFeat

        for scale in range(num_D):
            single = NLayerDiscriminator(input_nc, opt, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat)
            if not getIntermFeat:
                setattr(self, 'layer' + str(scale), single.model)
                continue
            for stage in range(n_layers + 2):
                setattr(self, 'scale' + str(scale) + '_layer' + str(stage),
                        getattr(single, 'model' + str(stage)))

        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)

    def singleD_forward(self, model, input):
        """Run one discriminator.

        With ``getIntermFeat`` set, ``model`` is a list of stages and the
        output of every stage is returned; otherwise ``model`` is a single
        module and a one-element list with its output is returned.
        """
        if not self.getIntermFeat:
            return [model(input)]

        feats = []
        current = input
        for stage in model:
            current = stage(current)
            feats.append(current)
        return feats

    def forward(self, input):
        """Return one result list per scale, ordered from full resolution to coarsest."""
        outputs = []
        current = input
        for scale in reversed(range(self.num_D)):
            if self.getIntermFeat:
                model = [getattr(self, 'scale' + str(scale) + '_layer' + str(stage))
                         for stage in range(self.n_layers + 2)]
            else:
                model = getattr(self, 'layer' + str(scale))
            outputs.append(self.singleD_forward(model, current))
            # No need to shrink the input after the last discriminator has run.
            if scale != 0:
                current = self.downsample(current)
        return outputs
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator built from strided 4x4 convolutions.

    The network is assembled as a list of stages: an input convolution,
    ``n_layers - 1`` stride-2 convolutions, one stride-1 convolution, a
    1-channel output convolution and, if ``use_sigmoid`` is set, a Sigmoid.
    Channel width doubles per stage and is capped at 512.

    With ``getIntermFeat`` every stage is registered as ``model<n>`` and
    ``forward`` returns the output of each stage; otherwise all layers are
    chained into a single ``self.model``.

    Every convolution is passed through ``SN(conv, opt.use_SN)``; that helper
    is defined elsewhere in this file (presumably optional spectral
    normalisation -- confirm against its definition).
    """

    def __init__(self, input_nc, opt, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False):
        super(NLayerDiscriminator, self).__init__()
        self.getIntermFeat = getIntermFeat
        self.n_layers = n_layers

        kw = 4
        padw = int(np.ceil((kw - 1.0) / 2))

        # Layers are created in the same order as before so parameter
        # registration (and therefore checkpoints) is unchanged.
        sequence = [[
            SN(nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), opt.use_SN),
            nn.LeakyReLU(0.2, True),
        ]]

        nf = ndf
        for n in range(1, n_layers):
            nf_prev = nf
            nf = min(nf * 2, 512)
            sequence += [[
                SN(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), opt.use_SN),
                norm_layer(nf),
                nn.LeakyReLU(0.2, True),
            ]]

        nf_prev = nf
        nf = min(nf * 2, 512)
        sequence += [[
            SN(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), opt.use_SN),
            norm_layer(nf),
            nn.LeakyReLU(0.2, True),
        ]]

        sequence += [[SN(nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw), opt.use_SN)]]

        if use_sigmoid:
            sequence += [[nn.Sigmoid()]]

        # Remember how many stages exist: n_layers + 2, plus one when the
        # Sigmoid stage was appended.  forward() must visit all of them.
        self.n_stages = len(sequence)

        if getIntermFeat:
            for n, stage in enumerate(sequence):
                setattr(self, 'model' + str(n), nn.Sequential(*stage))
        else:
            sequence_stream = []
            for stage in sequence:
                sequence_stream += stage
            self.model = nn.Sequential(*sequence_stream)

    def forward(self, input):
        """Return the final output, or the list of all stage outputs in ``getIntermFeat`` mode."""
        if not self.getIntermFeat:
            return self.model(input)
        # Previously only n_layers + 2 stages were run, which skipped the
        # Sigmoid stage whenever use_sigmoid was enabled.
        n_stages = getattr(self, 'n_stages', self.n_layers + 2)
        res = [input]
        for n in range(n_stages):
            stage = getattr(self, 'model' + str(n))
            res.append(stage(res[-1]))
        return res[1:]
def forward(self, z, mask):
    """Replace masked feature patches by their most similar unmasked patch.

    Args:
        z: input feature map, indexed as (batch, channels, height, width).
        mask: damage mask; per the original note its shape is Batch*1*H*W.
            After resizing, every positive value is treated as "masked".

    Returns:
        ``self.F_Combine`` applied to the channel-wise concatenation of ``z``,
        the patch-composed feature map and the resized binary mask.
    """
    x = self.res_block(z)
    h, w = x.size(2), x.size(3)
    patch = (self.patch_size, self.patch_size)

    ## mask resize + dilation
    # align_corners=False is the default behaviour; spelling it out only
    # silences the deprecation warning.
    mask = F.interpolate(mask, (h, w), mode="bilinear", align_corners=False)
    mask[mask > 0] = 1.0

    ## 1: mask position 0: non-mask
    mask_unfold = F.unfold(mask, kernel_size=patch, padding=0, stride=self.patch_size)
    # A patch counts as masked when more than 60% of its pixels are masked.
    # Shape (B, 1, L): it broadcasts over the query axis below, so the number
    # of patches never has to be recomputed from h * w / patch_size ** 2
    # (which disagreed with unfold when h or w was not a multiple of the
    # patch size).
    masked_patch = torch.mean(mask_unfold, dim=1, keepdim=True) > 0.6

    x_unfold = F.unfold(x, kernel_size=patch, padding=0, stride=self.patch_size)
    keys = F.normalize(x_unfold, dim=1)
    queries = F.normalize(x_unfold.permute(0, 2, 1), dim=2)

    # Cosine similarity between every pair of patches; masked patches are
    # excluded as candidates (keys) before taking the best match.
    correlation_matrix = torch.bmm(queries, keys)
    correlation_matrix = correlation_matrix.masked_fill(masked_patch, -1e9)
    correlation_matrix = F.softmax(correlation_matrix, dim=2)
    _, max_arg = torch.max(correlation_matrix, dim=2)

    composed_unfold = self.Hard_Compose(x_unfold, 2, max_arg)
    composed_fold = F.fold(
        composed_unfold,
        output_size=(h, w),
        kernel_size=patch,
        padding=0,
        stride=self.patch_size,
    )

    concat_1 = torch.cat((z, composed_fold, mask), dim=1)
    return self.F_Combine(concat_1)
class GANLoss(nn.Module):
    """GAN objective that builds (and caches) its own constant target tensors.

    Uses ``nn.MSELoss`` (LSGAN) when ``use_lsgan`` is true and ``nn.BCELoss``
    otherwise.  Targets are tensors filled with ``target_real_label`` or
    ``target_fake_label`` and shaped like the prediction being scored.
    """

    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCELoss()

    @staticmethod
    def _is_stale(cached, reference):
        """Tell whether a cached target cannot be reused for ``reference``."""
        return (
            cached is None
            or cached.size() != reference.size()
            or cached.device != reference.device
            or cached.dtype != reference.dtype
        )

    def _make_target(self, reference, value):
        """Create a constant target matching ``reference`` in shape, device and dtype."""
        target = self.Tensor(reference.size()).fill_(value)
        # No-op in the usual case where self.Tensor already matches the input.
        return target.to(device=reference.device, dtype=reference.dtype)

    def get_target_tensor(self, input, target_is_real):
        """Return the cached real/fake target for ``input``, rebuilding it if needed.

        The cache used to be validated by element count only, so a prediction
        with the same number of elements but a different shape silently got a
        target of the wrong shape.  The full shape is compared now.
        """
        if target_is_real:
            if self._is_stale(self.real_label_var, input):
                self.real_label_var = self._make_target(input, self.real_label)
            return self.real_label_var
        if self._is_stale(self.fake_label_var, input):
            self.fake_label_var = self._make_target(input, self.fake_label)
        return self.fake_label_var

    def __call__(self, input, target_is_real):
        """Compute the loss.

        Args:
            input: either a list of per-discriminator output lists (the last
                entry of each inner list is the prediction; the losses are
                summed), or a single output list whose last entry is scored.
            target_is_real: whether the prediction should be pushed towards the
                real label or the fake label.
        """
        if isinstance(input[0], list):
            loss = 0
            for input_i in input:
                pred = input_i[-1]
                loss += self.loss(pred, self.get_target_tensor(pred, target_is_real))
            return loss
        pred = input[-1]
        return self.loss(pred, self.get_target_tensor(pred, target_is_real))
class VGGLoss_torch(nn.Module):
    """Perceptual loss: weighted L1 distance between VGG19 feature maps.

    The five feature maps returned by ``VGG19_torch`` are compared with
    ``nn.L1Loss`` and combined with the weights 1/32, 1/16, 1/8, 1/4 and 1.
    """

    def __init__(self, gpu_ids):
        # gpu_ids is accepted for interface compatibility; it is not used.
        super(VGGLoss_torch, self).__init__()
        self.vgg = VGG19_torch()
        # The feature extractor used to be moved to CUDA unconditionally,
        # which raised on machines without a GPU.
        if torch.cuda.is_available():
            self.vgg = self.vgg.cuda()
        self.criterion = nn.L1Loss()
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        """Return the weighted feature distance between ``x`` and the target ``y``.

        Gradients flow through ``x`` only; the target features are computed
        without building an autograd graph (they were detached before, so the
        value and the gradients are unchanged).
        """
        x_vgg = self.vgg(x)
        with torch.no_grad():
            y_vgg = self.vgg(y)
        loss = 0
        for weight, x_feat, y_feat in zip(self.weights, x_vgg, y_vgg):
            loss += weight * self.criterion(x_feat, y_feat)
        return loss
def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss,use_smooth_L1):
    """Build a function that keeps only the enabled entries of a loss tuple.

    The GAN, KL and both discriminator entries are always kept; the GAN
    feature-matching, VGG and smooth-L1 entries are kept only when the
    corresponding argument is truthy.  The same filter is applied to loss
    values and to loss names so both stay aligned.
    """
    flags = (True, use_gan_feat_loss, use_vgg_loss, True, True, True, use_smooth_L1)

    def loss_filter(g_gan, g_gan_feat, g_vgg, g_kl, d_real, d_fake,smooth_l1):
        candidates = (g_gan, g_gan_feat, g_vgg, g_kl, d_real, d_fake, smooth_l1)
        kept = []
        for entry, enabled in zip(candidates, flags):
            if enabled:
                kept.append(entry)
        return kept

    return loss_filter
self.isTrain: self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path) print("---------- D Networks reloaded -------------") if self.gen_features: self.load_network(self.netE, 'E', opt.which_epoch, pretrained_path) # set loss functions and optimizers if self.isTrain: if opt.pool_size > 0 and (len(self.gpu_ids)) > 1: ## The pool_size is 0! raise NotImplementedError("Fake Pool Not Implemented for MultiGPU") self.fake_pool = ImagePool(opt.pool_size) self.old_lr = opt.lr # define loss functions self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss, opt.Smooth_L1) self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor) self.criterionFeat = torch.nn.L1Loss() # self.criterionImage = torch.nn.SmoothL1Loss() if not opt.no_vgg_loss: self.criterionVGG = networks.VGGLoss_torch(self.gpu_ids) self.loss_names = self.loss_filter('G_GAN','G_GAN_Feat','G_VGG', 'G_KL', 'D_real', 'D_fake', 'Smooth_L1') # initialize optimizers # optimizer G params = list(self.netG.parameters()) if self.gen_features: params += list(self.netE.parameters()) self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999)) # optimizer D params = list(self.netD.parameters()) self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999)) print("---------- Optimizers initialized -------------") if opt.continue_train: self.load_optimizer(self.optimizer_D, 'D', opt.which_epoch) self.load_optimizer(self.optimizer_G, "G", opt.which_epoch) for param_groups in self.optimizer_D.param_groups: self.old_lr=param_groups['lr'] print("---------- Optimizers reloaded -------------") print("---------- Current LR is %.8f -------------"%(self.old_lr)) ## We also want to re-load the parameters of optimizer. 
def discriminate(self, input_label, test_image, use_pool=False):
    """Run ``self.netD`` on an image that is first cut off from the autograd graph.

    Args:
        input_label: conditioning tensor concatenated in front of the image
            along dim 1, or ``None`` to score the image alone.
        test_image: image to score; it is detached before use.
        use_pool: when true, the discriminator input is routed through
            ``self.fake_pool.query`` first.

    Returns:
        Whatever ``self.netD.forward`` returns for the assembled input.
    """
    detached = test_image.detach()
    if input_label is None:
        d_input = detached
    else:
        d_input = torch.cat((input_label, detached), dim=1)

    if use_pool:
        d_input = self.fake_pool.query(d_input)
    return self.netD.forward(d_input)
def inference(self, label, inst, image=None, feat=None):
    """Generate an image for ``label`` without tracking gradients.

    Args:
        label: input passed to ``encode_input`` as the label map.
        inst: instance map passed to ``encode_input``.
        image: optional real image; only used when features are encoded from it.
        feat: unused; kept for interface compatibility.

    Returns:
        The output of ``self.netG.forward`` for the encoded input.
    """
    # Encode Inputs
    image = Variable(image) if image is not None else None
    input_label, inst_map, real_image, _ = self.encode_input(Variable(label), Variable(inst), image, infer=True)

    # Fake Generation
    if self.use_features:
        if self.opt.use_encoded_image:
            # encode the real image to get feature map
            feat_map = self.netE.forward(real_image, inst_map)
        else:
            # sample clusters from precomputed features
            feat_map = self.sample_features(inst_map)
        input_concat = torch.cat((input_label, feat_map), dim=1)
    else:
        input_concat = input_label

    # no_grad used to be entered only when the version string started with
    # '0.4', so every later PyTorch release ran the generator with autograd
    # enabled.  Use it whenever it exists.
    no_grad = getattr(torch, 'no_grad', None)
    if no_grad is None:
        return self.netG.forward(input_concat)
    with no_grad():
        return self.netG.forward(input_concat)
def get_edges(self, t):
    """Mark every pixel whose horizontal or vertical neighbour has a different value.

    Args:
        t: 4-D tensor; neighbours are compared along the last two dimensions.

    Returns:
        A tensor shaped like ``t`` holding 1 at edge pixels and 0 elsewhere,
        as half precision when ``self.opt.data_type == 16`` and float
        otherwise.  It lives on the same device as ``t`` (the buffer used to
        be a hard-coded ``torch.cuda.ByteTensor``).
    """
    edge = torch.zeros_like(t, dtype=torch.uint8)
    # Cast the comparison results so the in-place OR works no matter which
    # dtype the comparison operators produce.
    horizontal = (t[:, :, :, 1:] != t[:, :, :, :-1]).to(torch.uint8)
    vertical = (t[:, :, 1:, :] != t[:, :, :-1, :]).to(torch.uint8)

    # Both pixels on either side of a value change are flagged.
    edge[:, :, :, 1:] |= horizontal
    edge[:, :, :, :-1] |= horizontal
    edge[:, :, 1:, :] |= vertical
    edge[:, :, :-1, :] |= vertical

    if self.opt.data_type == 16:
        return edge.half()
    return edge.float()
def initialize(self, opt):
    """Build the networks, losses and optimizers of the domain-adaptation model.

    Always creates the generator ``self.netG``.  In training mode it also
    creates the image discriminator ``self.netD``, the feature discriminator
    ``self.feat_D`` (64 input channels, a single scale), the loss criteria and
    one Adam optimizer per network, and optionally restores networks and
    optimizers from a checkpoint.
    """
    BaseModel.initialize(self, opt)
    if opt.resize_or_crop != 'none' or not opt.isTrain:  # when training at full res this causes OOM
        torch.backends.cudnn.benchmark = True
    self.isTrain = opt.isTrain
    self.use_features = opt.instance_feat or opt.label_feat  ## Clearly it is false
    self.gen_features = self.use_features and not self.opt.load_features  ## it is also false
    input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc  ## Just is the origin input channel #

    ##### define networks
    # Generator network
    netG_input_nc = input_nc
    if not opt.no_instance:
        netG_input_nc += 1
    if self.use_features:
        netG_input_nc += opt.feat_num
    self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG, opt.k_size,
                                  opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers,
                                  opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids, opt=opt)

    # Discriminator network
    if self.isTrain:
        use_sigmoid = opt.no_lsgan
        netD_input_nc = opt.output_nc if opt.no_cgan else input_nc + opt.output_nc
        if not opt.no_instance:
            netD_input_nc += 1
        self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt, opt.norm, use_sigmoid,
                                      opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)

        self.feat_D = networks.define_D(64, opt.ndf, opt.n_layers_D, opt, opt.norm, use_sigmoid,
                                        1, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)
    if self.opt.verbose:
        print('---------- Networks initialized -------------')

    # load networks
    if not self.isTrain or opt.continue_train or opt.load_pretrain:
        pretrained_path = '' if not self.isTrain else opt.load_pretrain
        self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)

        print("---------- G Networks reloaded -------------")
        if self.isTrain:
            self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
            # save() writes this network (and its optimizer) under the label
            # 'featD'; it used to be loaded as 'feat_D', a label save() never
            # writes.
            # NOTE(review): assumes BaseModel derives the checkpoint file name
            # from this label -- confirm in base_model.py.
            self.load_network(self.feat_D, 'featD', opt.which_epoch, pretrained_path)
            print("---------- D Networks reloaded -------------")

    # set loss functions and optimizers
    if self.isTrain:
        if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:  ## The pool_size is 0!
            raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
        self.fake_pool = ImagePool(opt.pool_size)
        self.old_lr = opt.lr

        # define loss functions
        self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss)

        self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
        self.criterionFeat = torch.nn.L1Loss()
        if not opt.no_vgg_loss:
            self.criterionVGG = networks.VGGLoss_torch(self.gpu_ids)

        # Names so we can breakout loss
        self.loss_names = self.loss_filter('G_GAN', 'G_GAN_Feat', 'G_VGG', 'G_KL', 'D_real', 'D_fake',
                                           'G_featD', 'featD_real', 'featD_fake')

        # initialize optimizers
        # optimizer G
        params = list(self.netG.parameters())
        if self.gen_features:
            params += list(self.netE.parameters())
        self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))

        # optimizer D
        params = list(self.netD.parameters())
        self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))

        # optimizer for the feature discriminator
        params = list(self.feat_D.parameters())
        self.optimizer_featD = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))

        print("---------- Optimizers initialized -------------")

        if opt.continue_train:
            self.load_optimizer(self.optimizer_D, 'D', opt.which_epoch)
            self.load_optimizer(self.optimizer_G, "G", opt.which_epoch)
            self.load_optimizer(self.optimizer_featD, 'featD', opt.which_epoch)
            # Resume the decay schedule from the restored learning rate.
            for param_groups in self.optimizer_D.param_groups:
                self.old_lr = param_groups['lr']

            print("---------- Optimizers reloaded -------------")
            print("---------- Current LR is %.8f -------------" % (self.old_lr))

        ## We also want to re-load the parameters of optimizer.
# Fake Generation if self.use_features: if not self.opt.load_features: feat_map = self.netE.forward(real_image, inst_map) input_concat = torch.cat((input_label, feat_map), dim=1) else: input_concat = input_label hiddens = self.netG.forward(input_concat, 'enc') noise = Variable(torch.randn(hiddens.size()).cuda(hiddens.data.get_device())) # This is a reduced VAE implementation where we assume the outputs are multivariate Gaussian distribution with mean = hiddens and std_dev = all ones. # We follow the the VAE of MUNIT (https://github.com/NVlabs/MUNIT/blob/master/networks.py) fake_image = self.netG.forward(hiddens + noise, 'dec') #################### ##### GAN for the intermediate feature real_old_feat =[] syn_feat = [] for index,x in enumerate(inst): if x==1: real_old_feat.append(hiddens[index].unsqueeze(0)) else: syn_feat.append(hiddens[index].unsqueeze(0)) L=min(len(real_old_feat),len(syn_feat)) real_old_feat=real_old_feat[:L] syn_feat=syn_feat[:L] real_old_feat=torch.cat(real_old_feat,0) syn_feat=torch.cat(syn_feat,0) pred_fake_feat=self.feat_discriminate(real_old_feat) loss_featD_fake = self.criterionGAN(pred_fake_feat, False) pred_real_feat=self.feat_discriminate(syn_feat) loss_featD_real = self.criterionGAN(pred_real_feat, True) pred_fake_feat_G=self.feat_D.forward(real_old_feat) loss_G_featD=self.criterionGAN(pred_fake_feat_G,True) ##################################### if self.opt.no_cgan: # Fake Detection and Loss pred_fake_pool = self.discriminate(None, fake_image, use_pool=True) loss_D_fake = self.criterionGAN(pred_fake_pool, False) # Real Detection and Loss pred_real = self.discriminate(None, real_image) loss_D_real = self.criterionGAN(pred_real, True) # GAN loss (Fake Passability Loss) pred_fake = self.netD.forward(fake_image) loss_G_GAN = self.criterionGAN(pred_fake, True) else: # Fake Detection and Loss pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True) loss_D_fake = self.criterionGAN(pred_fake_pool, False) # Real Detection and 
def inference(self, label, inst, image=None, feat=None):
    """Generate an image for ``label`` without tracking gradients.

    Args:
        label: input passed to ``encode_input`` as the label map.
        inst: instance map passed to ``encode_input``.
        image: optional real image; only used when features are encoded from it.
        feat: unused; kept for interface compatibility.

    Returns:
        The output of ``self.netG.forward`` for the encoded input.
    """
    # Encode Inputs
    image = Variable(image) if image is not None else None
    input_label, inst_map, real_image, _ = self.encode_input(Variable(label), Variable(inst), image, infer=True)

    # Fake Generation
    if self.use_features:
        if self.opt.use_encoded_image:
            # encode the real image to get feature map
            feat_map = self.netE.forward(real_image, inst_map)
        else:
            # sample clusters from precomputed features
            feat_map = self.sample_features(inst_map)
        input_concat = torch.cat((input_label, feat_map), dim=1)
    else:
        input_concat = input_label

    # The generator used to run under no_grad only when the version string
    # started with '0.4'; on every later release autograd stayed enabled
    # during inference.  Enter no_grad whenever it is available.
    no_grad = getattr(torch, 'no_grad', None)
    if no_grad is None:
        return self.netG.forward(input_concat)
    with no_grad():
        return self.netG.forward(input_concat)
def encode_features(self, image, inst):
    """Sample one encoder feature vector per instance id.

    For every distinct value in ``inst`` the feature map produced by
    ``self.netE`` is read at the middle entry of that instance's pixel list,
    and the instance's pixel count divided by ``h * w // 32`` is appended as
    an extra column.

    Args:
        image: input image passed to ``self.netE.forward``.
        inst: 4-D instance map; ids of 1000 or more are mapped to a label by
            integer division by 1000.

    Returns:
        dict mapping label -> array of shape (n_instances, feat_num + 1).
        Labels below ``opt.label_nc`` are always present, possibly with zero
        rows.
    """
    # Inputs are moved to the GPU when one is present, as before; without CUDA
    # the computation now stays on the CPU instead of raising.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        image = image.cuda()
    feat_num = self.opt.feat_num
    h, w = inst.size()[2], inst.size()[3]
    block_num = 32
    # volatile=True has no effect on current PyTorch; no_grad is its replacement.
    with torch.no_grad():
        feat_map = self.netE.forward(image, inst.cuda() if use_cuda else inst)
    inst_np = inst.cpu().numpy().astype(int)

    feature = {}
    for i in range(self.opt.label_nc):
        feature[i] = np.zeros((0, feat_num + 1))

    for i in np.unique(inst_np):
        label = int(i if i < 1000 else i // 1000)
        idx = (inst == int(i)).nonzero()
        num = idx.size()[0]
        idx = idx[num // 2, :]
        val = np.zeros((1, feat_num + 1))
        for k in range(feat_num):
            # .item() replaces .data[0], which fails on 0-dim tensors.
            val[0, k] = feat_map[idx[0], idx[1] + k, idx[2], idx[3]].item()
        val[0, feat_num] = float(num) / (h * w // block_num)
        # Labels outside range(label_nc) used to raise KeyError here.
        rows = feature.get(label)
        if rows is None:
            rows = np.zeros((0, feat_num + 1))
        feature[label] = np.append(rows, val, axis=0)
    return feature
def update_learning_rate(self):
    """Lower the learning rate of all three optimizers by one linear-decay step.

    The step is ``opt.lr / opt.niter_decay``.  The new rate is written to
    every parameter group of the D, G and feature-D optimizers (in that order)
    and remembered in ``self.old_lr``.
    """
    new_lr = self.old_lr - self.opt.lr / self.opt.niter_decay

    for optimizer_name in ('optimizer_D', 'optimizer_G', 'optimizer_featD'):
        for group in getattr(self, optimizer_name).param_groups:
            group['lr'] = new_lr

    if self.opt.verbose:
        print('update learning rate: %f -> %f' % (self.old_lr, new_lr))
    self.old_lr = new_lr
use -1 for CPU" ) self.parser.add_argument( "--checkpoints_dir", type=str, default="./checkpoints", help="models are saved here" ) ## note: to add this param when using philly # self.parser.add_argument('--project_dir', type=str, default='./', help='the project is saved here') ################### This is necessary for philly self.parser.add_argument( "--outputs_dir", type=str, default="./outputs", help="models are saved here" ) ## note: to add this param when using philly Please end with '/' self.parser.add_argument("--model", type=str, default="pix2pixHD", help="which model to use") self.parser.add_argument( "--norm", type=str, default="instance", help="instance normalization or batch normalization" ) self.parser.add_argument("--use_dropout", action="store_true", help="use dropout for the generator") self.parser.add_argument( "--data_type", default=32, type=int, choices=[8, 16, 32], help="Supported data type i.e. 8, 16, 32 bit", ) self.parser.add_argument("--verbose", action="store_true", default=False, help="toggles verbose") # input/output sizes self.parser.add_argument("--batchSize", type=int, default=1, help="input batch size") self.parser.add_argument("--loadSize", type=int, default=1024, help="scale images to this size") self.parser.add_argument("--fineSize", type=int, default=512, help="then crop to this size") self.parser.add_argument("--label_nc", type=int, default=35, help="# of input label channels") self.parser.add_argument("--input_nc", type=int, default=3, help="# of input image channels") self.parser.add_argument("--output_nc", type=int, default=3, help="# of output image channels") # for setting inputs self.parser.add_argument("--dataroot", type=str, default="./datasets/cityscapes/") self.parser.add_argument( "--resize_or_crop", type=str, default="scale_width", help="scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]", ) self.parser.add_argument( "--serial_batches", action="store_true", help="if 
true, takes images in order to make batches, otherwise takes them randomly", ) self.parser.add_argument( "--no_flip", action="store_true", help="if specified, do not flip the images for data argumentation", ) self.parser.add_argument("--nThreads", default=2, type=int, help="# threads for loading data") self.parser.add_argument( "--max_dataset_size", type=int, default=float("inf"), help="Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.", ) # for displays self.parser.add_argument("--display_winsize", type=int, default=512, help="display window size") self.parser.add_argument( "--tf_log", action="store_true", help="if specified, use tensorboard logging. Requires tensorflow installed", ) # for generator self.parser.add_argument("--netG", type=str, default="global", help="selects model to use for netG") self.parser.add_argument("--ngf", type=int, default=64, help="# of gen filters in first conv layer") self.parser.add_argument("--k_size", type=int, default=3, help="# kernel size conv layer") self.parser.add_argument("--use_v2", action="store_true", help="use DCDCv2") self.parser.add_argument("--mc", type=int, default=1024, help="# max channel") self.parser.add_argument("--start_r", type=int, default=3, help="start layer to use resblock") self.parser.add_argument( "--n_downsample_global", type=int, default=4, help="number of downsampling layers in netG" ) self.parser.add_argument( "--n_blocks_global", type=int, default=9, help="number of residual blocks in the global generator network", ) self.parser.add_argument( "--n_blocks_local", type=int, default=3, help="number of residual blocks in the local enhancer network", ) self.parser.add_argument( "--n_local_enhancers", type=int, default=1, help="number of local enhancers to use" ) self.parser.add_argument( "--niter_fix_global", type=int, default=0, help="number of epochs that we only train the outmost local enhancer", ) 
self.parser.add_argument( "--load_pretrain", type=str, default="", help="load the pretrained model from the specified location", ) # for instance-wise features self.parser.add_argument( "--no_instance", action="store_true", help="if specified, do *not* add instance map as input" ) self.parser.add_argument( "--instance_feat", action="store_true", help="if specified, add encoded instance features as input", ) self.parser.add_argument( "--label_feat", action="store_true", help="if specified, add encoded label features as input" ) self.parser.add_argument("--feat_num", type=int, default=3, help="vector length for encoded features") self.parser.add_argument( "--load_features", action="store_true", help="if specified, load precomputed feature maps" ) self.parser.add_argument( "--n_downsample_E", type=int, default=4, help="# of downsampling layers in encoder" ) self.parser.add_argument( "--nef", type=int, default=16, help="# of encoder filters in the first conv layer" ) self.parser.add_argument("--n_clusters", type=int, default=10, help="number of clusters for features") # diy self.parser.add_argument("--self_gen", action="store_true", help="self generate") self.parser.add_argument( "--mapping_n_block", type=int, default=3, help="number of resblock in mapping" ) self.parser.add_argument("--map_mc", type=int, default=64, help="max channel of mapping") self.parser.add_argument("--kl", type=float, default=0, help="KL Loss") self.parser.add_argument( "--load_pretrainA", type=str, default="", help="load the pretrained model from the specified location", ) self.parser.add_argument( "--load_pretrainB", type=str, default="", help="load the pretrained model from the specified location", ) self.parser.add_argument("--feat_gan", action="store_true") self.parser.add_argument("--no_cgan", action="store_true") self.parser.add_argument("--map_unet", action="store_true") self.parser.add_argument("--map_densenet", action="store_true") self.parser.add_argument("--fcn", action="store_true") 
self.parser.add_argument("--is_image", action="store_true", help="train image recon only pair data") self.parser.add_argument("--label_unpair", action="store_true") self.parser.add_argument("--mapping_unpair", action="store_true") self.parser.add_argument("--unpair_w", type=float, default=1.0) self.parser.add_argument("--pair_num", type=int, default=-1) self.parser.add_argument("--Gan_w", type=float, default=1) self.parser.add_argument("--feat_dim", type=int, default=-1) self.parser.add_argument("--abalation_vae_len", type=int, default=-1) ######################### useless, just to cooperate with docker self.parser.add_argument("--gpu", type=str) self.parser.add_argument("--dataDir", type=str) self.parser.add_argument("--modelDir", type=str) self.parser.add_argument("--logDir", type=str) self.parser.add_argument("--data_dir", type=str) self.parser.add_argument("--use_skip_model", action="store_true") self.parser.add_argument("--use_segmentation_model", action="store_true") self.parser.add_argument("--spatio_size", type=int, default=64) self.parser.add_argument("--test_random_crop", action="store_true") ########################## self.parser.add_argument("--contain_scratch_L", action="store_true") self.parser.add_argument( "--mask_dilation", type=int, default=0 ) ## Don't change the input, only dilation the mask self.parser.add_argument( "--irregular_mask", type=str, default="", help="This is the root of the mask" ) self.parser.add_argument( "--mapping_net_dilation", type=int, default=1, help="This parameter is the dilation size of the translation net", ) self.parser.add_argument( "--VOC", type=str, default="VOC_RGB_JPEGImages.bigfile", help="The root of VOC dataset" ) self.parser.add_argument("--non_local", type=str, default="", help="which non_local setting") self.parser.add_argument( "--NL_fusion_method", type=str, default="add", help="how to fuse the origin feature and nl feature", ) self.parser.add_argument( "--NL_use_mask", action="store_true", help="If use 
mask while using Non-local mapping model" ) self.parser.add_argument( "--correlation_renormalize", action="store_true", help="Since after mask out the correlation matrix(which is softmaxed), the sum is not 1 any more, enable this param to re-weight", ) self.parser.add_argument("--Smooth_L1", action="store_true", help="Use L1 Loss in image level") self.parser.add_argument( "--face_restore_setting", type=int, default=1, help="This is for the aligned face restoration" ) self.parser.add_argument("--face_clean_url", type=str, default="") self.parser.add_argument("--syn_input_url", type=str, default="") self.parser.add_argument("--syn_gt_url", type=str, default="") self.parser.add_argument( "--test_on_synthetic", action="store_true", help="If you want to test on the synthetic data, enable this parameter", ) self.parser.add_argument("--use_SN", action="store_true", help="Add SN to every parametric layer") self.parser.add_argument( "--use_two_stage_mapping", action="store_true", help="choose the model which uses two stage" ) self.parser.add_argument("--L1_weight", type=float, default=10.0) self.parser.add_argument("--softmax_temperature", type=float, default=1.0) self.parser.add_argument( "--patch_similarity", action="store_true", help="Enable this denotes using 3*3 patch to calculate similarity", ) self.parser.add_argument( "--use_self", action="store_true", help="Enable this denotes that while constructing the new feature maps, using original feature (diagonal == 1)", ) self.parser.add_argument("--use_own_dataset", action="store_true") self.parser.add_argument( "--test_hole_two_folders", action="store_true", help="Enable this parameter means test the restoration with inpainting given twp folders which are mask and old respectively", ) self.parser.add_argument( "--no_hole", action="store_true", help="While test the full_model on non_scratch data, do not add random mask into the real old photos", ) ## Only for testing self.parser.add_argument( "--random_hole", 
action="store_true", help="While training the full model, 50% probability add hole", ) self.parser.add_argument("--NL_res", action="store_true", help="NL+Resdual Block") self.parser.add_argument("--image_L1", action="store_true", help="Image level loss: L1") self.parser.add_argument( "--hole_image_no_mask", action="store_true", help="while testing, give hole image but not give the mask", ) self.parser.add_argument( "--down_sample_degradation", action="store_true", help="down_sample the image only, corresponds to [down_sample_face]", ) self.parser.add_argument( "--norm_G", type=str, default="spectralinstance", help="The norm type of Generator" ) self.parser.add_argument( "--init_G", type=str, default="xavier", help="normal|xavier|xavier_uniform|kaiming|orthogonal|none", ) self.parser.add_argument("--use_new_G", action="store_true") self.parser.add_argument("--use_new_D", action="store_true") self.parser.add_argument( "--only_voc", action="store_true", help="test the trianed celebA face model using VOC face" ) self.parser.add_argument( "--cosin_similarity", action="store_true", help="For non-local, using cosin to calculate the similarity", ) self.parser.add_argument( "--downsample_mode", type=str, default="nearest", help="For partial non-local, choose how to downsample the mask", ) self.parser.add_argument("--mapping_exp",type=int,default=0,help='Default 0: original PNL|1: Multi-Scale Patch Attention') self.parser.add_argument("--inference_optimize",action='store_true',help='optimize the memory cost') self.initialized = True def parse(self, save=True): if not self.initialized: self.initialize() self.opt = self.parser.parse_args() self.opt.isTrain = self.isTrain # train or test str_ids = self.opt.gpu_ids.split(",") self.opt.gpu_ids = [] for str_id in str_ids: int_id = int(str_id) if int_id >= 0: self.opt.gpu_ids.append(int_id) # set gpu ids if len(self.opt.gpu_ids) > 0: # pass torch.cuda.set_device(self.opt.gpu_ids[0]) args = vars(self.opt) # print('------------ 
Options -------------') # for k, v in sorted(args.items()): # print('%s: %s' % (str(k), str(v))) # print('-------------- End ----------------') # save to the disk expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name) util.mkdirs(expr_dir) if save and not self.opt.continue_train: file_name = os.path.join(expr_dir, "opt.txt") with open(file_name, "wt") as opt_file: opt_file.write("------------ Options -------------\n") for k, v in sorted(args.items()): opt_file.write("%s: %s\n" % (str(k), str(v))) opt_file.write("-------------- End ----------------\n") return self.opt ================================================ FILE: Global/options/test_options.py ================================================ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from .base_options import BaseOptions class TestOptions(BaseOptions): def initialize(self): BaseOptions.initialize(self) self.parser.add_argument("--ntest", type=int, default=float("inf"), help="# of test examples.") self.parser.add_argument("--results_dir", type=str, default="./results/", help="saves results here.") self.parser.add_argument( "--aspect_ratio", type=float, default=1.0, help="aspect ratio of result images" ) self.parser.add_argument("--phase", type=str, default="test", help="train, val, test, etc") self.parser.add_argument( "--which_epoch", type=str, default="latest", help="which epoch to load? 
set to latest to use latest cached model", ) self.parser.add_argument("--how_many", type=int, default=50, help="how many test images to run") self.parser.add_argument( "--cluster_path", type=str, default="features_clustered_010.npy", help="the path for clustered results of encoded features", ) self.parser.add_argument( "--use_encoded_image", action="store_true", help="if specified, encode the real image to get the feature map", ) self.parser.add_argument("--export_onnx", type=str, help="export ONNX model to a given file") self.parser.add_argument("--engine", type=str, help="run serialized TRT engine") self.parser.add_argument("--onnx", type=str, help="run ONNX model via TRT") self.parser.add_argument( "--start_epoch", type=int, default=-1, help="write the start_epoch of iter.txt into this parameter", ) self.parser.add_argument("--test_dataset", type=str, default="Real_RGB_old.bigfile") self.parser.add_argument( "--no_degradation", action="store_true", help="when train the mapping, enable this parameter --> no degradation will be added into clean image", ) self.parser.add_argument( "--no_load_VAE", action="store_true", help="when train the mapping, enable this parameter --> random initialize the encoder an decoder", ) self.parser.add_argument( "--use_v2_degradation", action="store_true", help="enable this parameter --> 4 kinds of degradations will be used to synthesize corruption", ) self.parser.add_argument("--use_vae_which_epoch", type=str, default="latest") self.isTrain = False self.parser.add_argument("--generate_pair", action="store_true") self.parser.add_argument("--multi_scale_test", type=float, default=0.5) self.parser.add_argument("--multi_scale_threshold", type=float, default=0.5) self.parser.add_argument( "--mask_need_scale", action="store_true", help="enable this param meas that the pixel range of mask is 0-255", ) self.parser.add_argument("--scale_num", type=int, default=1) self.parser.add_argument( "--save_feature_url", type=str, default="", 
help="While extracting the features, where to put" ) self.parser.add_argument( "--test_input", type=str, default="", help="A directory or a root of bigfile" ) self.parser.add_argument("--test_mask", type=str, default="", help="A directory or a root of bigfile") self.parser.add_argument("--test_gt", type=str, default="", help="A directory or a root of bigfile") self.parser.add_argument( "--scale_input", action="store_true", help="While testing, choose to scale the input firstly" ) self.parser.add_argument( "--save_feature_name", type=str, default="features.json", help="The name of saved features" ) self.parser.add_argument( "--test_rgb_old_wo_scratch", action="store_true", help="Same setting with origin test" ) self.parser.add_argument("--test_mode", type=str, default="Crop", help="Scale|Full|Crop") self.parser.add_argument("--Quality_restore", action="store_true", help="For RGB images") self.parser.add_argument( "--Scratch_and_Quality_restore", action="store_true", help="For scratched images" ) self.parser.add_argument("--HR", action='store_true',help='Large input size with scratches') ================================================ FILE: Global/options/train_options.py ================================================ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
class TrainOptions(BaseOptions):
    """Options used by the training scripts, layered on top of ``BaseOptions``."""

    def initialize(self):
        """Register the shared flags, then the training-only ones, and mark the run as training."""
        super().initialize()
        add = self.parser.add_argument

        # for displays
        add("--display_freq", default=100, type=int, help="frequency of showing training results on screen")
        add("--print_freq", default=100, type=int, help="frequency of showing training results on console")
        add("--save_latest_freq", default=10000, type=int, help="frequency of saving the latest results")
        add(
            "--save_epoch_freq",
            default=1,
            type=int,
            help="frequency of saving checkpoints at the end of epochs",
        )
        add(
            "--no_html",
            action="store_true",
            help="do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/",
        )
        add("--debug", action="store_true", help="only do one epoch and displays at each iteration")

        # for training
        add("--continue_train", action="store_true", help="continue training: load the latest model")
        add(
            "--which_epoch",
            default="latest",
            type=str,
            help="which epoch to load? set to latest to use latest cached model",
        )
        add("--phase", default="train", type=str, help="train, val, test, etc")
        add("--niter", default=100, type=int, help="# of iter at starting learning rate")
        add(
            "--niter_decay",
            default=100,
            type=int,
            help="# of iter to linearly decay learning rate to zero",
        )
        add("--beta1", default=0.5, type=float, help="momentum term of adam")
        add("--lr", default=0.0002, type=float, help="initial learning rate for adam")
        add("--training_dataset", default="", type=str, help="training use which dataset")

        # for discriminators
        add("--num_D", default=2, type=int, help="number of discriminators to use")
        add("--n_layers_D", default=3, type=int, help="only used if which_model_netD==n_layers")
        add("--ndf", default=64, type=int, help="# of discrim filters in first conv layer")
        add("--lambda_feat", default=10.0, type=float, help="weight for feature matching loss")
        add("--l2_feat", type=float, help="weight for feature mapping loss")
        add("--use_l1_feat", action="store_true", help="use l1 for feat mapping")
        add(
            "--no_ganFeat_loss",
            action="store_true",
            help="if specified, do *not* use discriminator feature matching loss",
        )
        add(
            "--no_vgg_loss",
            action="store_true",
            help="if specified, do *not* use VGG feature matching loss",
        )
        add(
            "--no_lsgan",
            action="store_true",
            help="do *not* use least square GAN, if false, use vanilla GAN",
        )
        add("--gan_type", default="lsgan", type=str, help="Choose the loss type of GAN")
        add(
            "--pool_size",
            default=0,
            type=int,
            help="the size of image buffer that stores previously generated images",
        )
        add(
            "--norm_D",
            default="spectralinstance",
            type=str,
            help="instance normalization or batch normalization",
        )
        add(
            "--init_D",
            default="xavier",
            type=str,
            help="normal|xavier|xavier_uniform|kaiming|orthogonal|none",
        )
        add("--no_TTUR", action="store_true", help="No TTUR")
        add(
            "--start_epoch",
            default=-1,
            type=int,
            help="write the start_epoch of iter.txt into this parameter",
        )
        add(
            "--no_degradation",
            action="store_true",
            help="when train the mapping, enable this parameter --> no degradation will be added into clean image",
        )
        add(
            "--no_load_VAE",
            action="store_true",
            help="when train the mapping, enable this parameter --> random initialize the encoder an decoder",
        )
        add(
            "--use_v2_degradation",
            action="store_true",
            help="enable this parameter --> 4 kinds of degradations will be used to synthesize corruption",
        )
        add("--use_vae_which_epoch", default="200", type=str)
        add("--use_focal_loss", action="store_true")
        add(
            "--mask_need_scale",
            action="store_true",
            help="enable this param means that the pixel range of mask is 0-255",
        )
        add(
            "--positive_weight",
            default=1.0,
            type=float,
            help="(For scratch detection) Since the scratch number is less, and we use a weight strategy. This parameter means that we want to decrease the weight.",
        )
        add(
            "--no_update_lr",
            action="store_true",
            help="use this means we do not update the LR while training",
        )

        # Read by BaseOptions.parse() and copied onto the parsed namespace.
        self.isTrain = True
def data_transforms(img, method=Image.BILINEAR, scale=False):
    """Resize ``img`` so both sides are multiples of 4.

    Args:
        img: PIL image.
        method: resampling filter handed to ``img.resize``.
        scale: when true, first rescale so the shorter side becomes 256
            while keeping the aspect ratio.

    Returns:
        The original image object if its size already matches the target,
        otherwise a resized copy.
    """
    ow, oh = img.size
    pw, ph = ow, oh
    if scale:
        if ow < oh:
            ow = 256
            oh = ph / pw * 256
        else:
            oh = 256
            ow = pw / ph * 256

    # Snap both sides to the nearest multiple of 4.
    h = int(round(oh / 4) * 4)
    w = int(round(ow / 4) * 4)

    if (h == ph) and (w == pw):
        return img

    return img.resize((w, h), method)


def data_transforms_rgb_old(img):
    """Return a 256x256 centre crop, upscaling first if a side is below 256."""
    w, h = img.size
    A = img
    if w < 256 or h < 256:
        # ``transforms.Scale`` was removed from torchvision; ``Resize`` is
        # its drop-in replacement.
        A = transforms.Resize(256, Image.BILINEAR)(img)
    return transforms.CenterCrop(256)(A)


def irregular_hole_synthesize(img, mask):
    """Paint the masked region of ``img`` white and return an RGB PIL image.

    ``mask`` is divided by 255, so pixels where it equals 255 become 255 in
    the output and pixels where it equals 0 keep the value from ``img``.
    """
    img_np = np.array(img).astype("uint8")
    mask_np = np.array(mask).astype("uint8")
    mask_np = mask_np / 255
    img_new = img_np * (1 - mask_np) + mask_np * 255

    hole_img = Image.fromarray(img_new.astype("uint8")).convert("RGB")

    return hole_img


def parameter_set(opt):
    """Overwrite ``opt`` in place with the fixed settings of the released restoration models."""
    ## Default parameters
    opt.serial_batches = True  # no shuffle
    opt.no_flip = True  # no flip
    opt.label_nc = 0
    opt.n_downsample_global = 3
    opt.mc = 64
    opt.k_size = 4
    opt.start_r = 1
    opt.mapping_n_block = 6
    opt.map_mc = 512
    opt.no_instance = True
    opt.checkpoints_dir = "./checkpoints/restoration"

    if opt.Quality_restore:
        opt.name = "mapping_quality"
        opt.load_pretrainA = os.path.join(opt.checkpoints_dir, "VAE_A_quality")
        opt.load_pretrainB = os.path.join(opt.checkpoints_dir, "VAE_B_quality")
    if opt.Scratch_and_Quality_restore:
        opt.NL_res = True
        opt.use_SN = True
        opt.correlation_renormalize = True
        opt.NL_use_mask = True
        opt.NL_fusion_method = "combine"
        opt.non_local = "Setting_42"
        opt.name = "mapping_scratch"
        opt.load_pretrainA = os.path.join(opt.checkpoints_dir, "VAE_A_quality")
        opt.load_pretrainB = os.path.join(opt.checkpoints_dir, "VAE_B_scratch")
        # NOTE(review): nesting reconstructed from a whitespace-collapsed
        # source; --HR is assumed to refine the scratch setting only.
        if opt.HR:
            opt.mapping_exp = 1
            opt.inference_optimize = True
            opt.mask_dilation = 3
            opt.name = "mapping_Patch_Attention"


if __name__ == "__main__":

    opt = TestOptions().parse(save=False)
    parameter_set(opt)

    # Fail with a clear message instead of a NameError deep inside the loop
    # when a mask-based mode is selected without a mask folder.
    if opt.NL_use_mask and opt.test_mask == "":
        raise ValueError("--test_mask must be given when the mask-based (scratch) mode is enabled")

    model = Pix2PixHDModel_Mapping()

    model.initialize(opt)
    model.eval()

    for sub_dir in ("input_image", "restored_image", "origin"):
        os.makedirs(opt.outputs_dir + "/" + sub_dir, exist_ok=True)

    input_loader = sorted(os.listdir(opt.test_input))
    dataset_size = len(input_loader)

    mask_loader = None
    if opt.test_mask != "":
        mask_loader = sorted(os.listdir(opt.test_mask))
        # Inputs and masks are paired by sorted position; never index past
        # the shorter of the two listings.
        dataset_size = min(len(input_loader), len(mask_loader))

    img_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    mask_transform = transforms.ToTensor()

    for i in range(dataset_size):

        input_name = input_loader[i]
        input_file = os.path.join(opt.test_input, input_name)
        if not os.path.isfile(input_file):
            print("Skipping non-file %s" % input_name)
            continue
        input_img = Image.open(input_file).convert("RGB")

        print("Now you are processing %s" % (input_name))

        if opt.NL_use_mask:
            mask_name = mask_loader[i]
            mask = Image.open(os.path.join(opt.test_mask, mask_name)).convert("RGB")
            if opt.mask_dilation != 0:
                kernel = np.ones((3, 3), np.uint8)
                mask = np.array(mask)
                mask = cv2.dilate(mask, kernel, iterations=opt.mask_dilation)
                mask = Image.fromarray(mask.astype("uint8"))
            origin = input_img
            input_img = irregular_hole_synthesize(input_img, mask)
            mask = mask_transform(mask)
            mask = mask[:1, :, :]  ## Convert to single channel
            mask = mask.unsqueeze(0)
            input_tensor = img_transform(input_img)
            input_tensor = input_tensor.unsqueeze(0)
        else:
            if opt.test_mode == "Scale":
                input_img = data_transforms(input_img, scale=True)
            elif opt.test_mode == "Full":
                input_img = data_transforms(input_img, scale=False)
            elif opt.test_mode == "Crop":
                input_img = data_transforms_rgb_old(input_img)
            origin = input_img
            input_tensor = img_transform(input_img)
            input_tensor = input_tensor.unsqueeze(0)
            mask = torch.zeros_like(input_tensor)

        ### Necessary input
        # Best effort by design: one failing image must not abort the batch.
        try:
            with torch.no_grad():
                generated = model.inference(input_tensor, mask)
        except Exception as ex:
            print("Skip %s due to an error:\n%s" % (input_name, str(ex)))
            continue

        if input_name.endswith(".jpg"):
            input_name = input_name[:-4] + ".png"

        vutils.save_image(
            (input_tensor + 1.0) / 2.0,
            opt.outputs_dir + "/input_image/" + input_name,
            nrow=1,
            padding=0,
            normalize=True,
        )
        vutils.save_image(
            (generated.data.cpu() + 1.0) / 2.0,
            opt.outputs_dir + "/restored_image/" + input_name,
            nrow=1,
            padding=0,
            normalize=True,
        )

        origin.save(opt.outputs_dir + "/origin/" + input_name)
# Training entry script for the domain-A model: parses the options, builds
# the data loader and model, then runs the epoch loop with periodic logging,
# sample dumps and checkpointing.
opt = TrainOptions().parse()

if opt.debug:
    opt.display_freq = 1
    opt.print_freq = 1
    opt.niter = 1
    opt.niter_decay = 0
    opt.max_dataset_size = 10

data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
dataset_size = len(dataset) * opt.batchSize
print('#training images = %d' % dataset_size)

path = os.path.join(opt.checkpoints_dir, opt.name, 'model.txt')
visualizer = Visualizer(opt)

iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
if opt.continue_train:
    try:
        start_epoch, epoch_iter = np.loadtxt(iter_path, delimiter=',', dtype=int)
    except (OSError, ValueError, TypeError):
        # Missing or malformed iter.txt: start from scratch. Unlike a bare
        # ``except:``, this no longer swallows KeyboardInterrupt/SystemExit.
        start_epoch, epoch_iter = 1, 0
    visualizer.print_save('Resuming from epoch %d at iteration %d' % (start_epoch - 1, epoch_iter))
else:
    start_epoch, epoch_iter = 1, 0

model = create_da_model(opt)

# Dump the network architectures next to the checkpoints; the context
# manager closes the file even if str() on a network raises.
with open(path, 'w') as fd:
    fd.write(str(model.module.netG))
    fd.write(str(model.module.netD))

total_steps = (start_epoch - 1) * dataset_size + epoch_iter

display_delta = total_steps % opt.display_freq
print_delta = total_steps % opt.print_freq
save_delta = total_steps % opt.save_latest_freq  # currently unused below

# Join instead of concatenating so an --outputs_dir without a trailing '/'
# (such as the default './outputs') still yields '<outputs_dir>/<name>'.
sample_dir = os.path.join(opt.outputs_dir, opt.name)

for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
    epoch_start_time = time.time()
    if epoch != start_epoch:
        epoch_iter = epoch_iter % dataset_size
    for i, data in enumerate(dataset, start=epoch_iter):
        iter_start_time = time.time()
        total_steps += opt.batchSize
        epoch_iter += opt.batchSize

        # whether to collect output images
        save_fake = total_steps % opt.display_freq == display_delta

        ############## Forward Pass ######################
        losses, generated = model(Variable(data['label']), Variable(data['inst']),
                                  Variable(data['image']), Variable(data['feat']), infer=save_fake)

        # sum per device losses
        losses = [torch.mean(x) if not isinstance(x, int) else x for x in losses]
        loss_dict = dict(zip(model.module.loss_names, losses))

        # calculate final loss scalar
        loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
        loss_featD = (loss_dict['featD_fake'] + loss_dict['featD_real']) * 0.5
        loss_G = (loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat', 0) + loss_dict.get('G_VGG', 0)
                  + loss_dict['G_KL'] + loss_dict['G_featD'])

        ############### Backward Pass ####################
        # update generator weights
        model.module.optimizer_G.zero_grad()
        loss_G.backward()
        model.module.optimizer_G.step()

        # update discriminator weights
        model.module.optimizer_D.zero_grad()
        loss_D.backward()
        model.module.optimizer_D.step()

        # update feature discriminator weights
        model.module.optimizer_featD.zero_grad()
        loss_featD.backward()
        model.module.optimizer_featD.step()

        ############## Display results and errors ##########
        ### print out errors
        if total_steps % opt.print_freq == print_delta:
            errors = {k: v.data if not isinstance(v, int) else v for k, v in loss_dict.items()}
            t = (time.time() - iter_start_time) / opt.batchSize
            visualizer.print_current_errors(epoch, epoch_iter, errors, t, model.module.old_lr)
            visualizer.plot_current_errors(errors, total_steps)

        ### display output images
        if save_fake:
            os.makedirs(sample_dir, exist_ok=True)
            imgs_num = data['label'].shape[0]
            imgs = torch.cat((data['label'], generated.data.cpu(), data['image']), 0)

            imgs = (imgs + 1.) / 2.0

            # A failed sample dump is reported but must not stop training.
            try:
                vutils.save_image(imgs, os.path.join(sample_dir, str(epoch) + '_' + str(total_steps) + '.png'),
                                  nrow=imgs_num, padding=0, normalize=True)
            except OSError as err:
                print(err)

        if epoch_iter >= dataset_size:
            break

    # end of epoch
    print('End of epoch %d / %d \t Time Taken: %d sec' %
          (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))

    ### save model for this epoch
    if epoch % opt.save_epoch_freq == 0:
        print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
        model.module.save('latest')
        model.module.save(epoch)
        np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')

    ### instead of only training the local enhancer, train the entire network after certain iterations
    if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
        model.module.update_fixed_params()

    ### linearly decay learning rate after certain iterations
    if epoch > opt.niter:
        model.module.update_learning_rate()
import time from collections import OrderedDict from options.train_options import TrainOptions from data.data_loader import CreateDataLoader from models.models import create_model import util.util as util from util.visualizer import Visualizer import os import numpy as np import torch import torchvision.utils as vutils from torch.autograd import Variable import random opt = TrainOptions().parse() if opt.debug: opt.display_freq = 1 opt.print_freq = 1 opt.niter = 1 opt.niter_decay = 0 opt.max_dataset_size = 10 data_loader = CreateDataLoader(opt) dataset = data_loader.load_data() dataset_size = len(dataset) * opt.batchSize print('#training images = %d' % dataset_size) path = os.path.join(opt.checkpoints_dir, opt.name, 'model.txt') visualizer = Visualizer(opt) iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt') if opt.continue_train: try: start_epoch, epoch_iter = np.loadtxt(iter_path , delimiter=',', dtype=int) except: start_epoch, epoch_iter = 1, 0 visualizer.print_save('Resuming from epoch %d at iteration %d' % (start_epoch-1, epoch_iter)) else: start_epoch, epoch_iter = 1, 0 # opt.which_epoch=start_epoch-1 model = create_model(opt) fd = open(path, 'w') fd.write(str(model.module.netG)) fd.write(str(model.module.netD)) fd.close() total_steps = (start_epoch-1) * dataset_size + epoch_iter display_delta = total_steps % opt.display_freq print_delta = total_steps % opt.print_freq save_delta = total_steps % opt.save_latest_freq for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1): epoch_start_time = time.time() if epoch != start_epoch: epoch_iter = epoch_iter % dataset_size for i, data in enumerate(dataset, start=epoch_iter): iter_start_time = time.time() total_steps += opt.batchSize epoch_iter += opt.batchSize # whether to collect output images save_fake = total_steps % opt.display_freq == display_delta ############## Forward Pass ###################### losses, generated = model(Variable(data['label']), Variable(data['inst']), 
Variable(data['image']), Variable(data['feat']), infer=save_fake) # sum per device losses losses = [ torch.mean(x) if not isinstance(x, int) else x for x in losses ] loss_dict = dict(zip(model.module.loss_names, losses)) # calculate final loss scalar loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5 loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat',0) + loss_dict.get('G_VGG',0) + loss_dict['G_KL'] + loss_dict.get('Smooth_L1',0)*opt.L1_weight ############### Backward Pass #################### # update generator weights model.module.optimizer_G.zero_grad() loss_G.backward() model.module.optimizer_G.step() # update discriminator weights model.module.optimizer_D.zero_grad() loss_D.backward() model.module.optimizer_D.step() #call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"]) ############## Display results and errors ########## ### print out errors if total_steps % opt.print_freq == print_delta: errors = {k: v.data if not isinstance(v, int) else v for k, v in loss_dict.items()} t = (time.time() - iter_start_time) / opt.batchSize visualizer.print_current_errors(epoch, epoch_iter, errors, t, model.module.old_lr) visualizer.plot_current_errors(errors, total_steps) ### display output images if save_fake: if not os.path.exists(opt.outputs_dir + opt.name): os.makedirs(opt.outputs_dir + opt.name) imgs_num = 5 imgs = torch.cat((data['label'][:imgs_num], generated.data.cpu()[:imgs_num], data['image'][:imgs_num]), 0) imgs = (imgs + 1.) 
/ 2.0 try: image_grid = vutils.save_image(imgs, opt.outputs_dir + opt.name + '/' + str(epoch) + '_' + str(total_steps) + '.png', nrow=imgs_num, padding=0, normalize=True) except OSError as err: print(err) if epoch_iter >= dataset_size: break # end of epoch iter_end_time = time.time() print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) ### save model for this epoch if epoch % opt.save_epoch_freq == 0: print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps)) model.module.save('latest') model.module.save(epoch) np.savetxt(iter_path, (epoch+1, 0), delimiter=',', fmt='%d') ### instead of only training the local enhancer, train the entire network after certain iterations if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global): model.module.update_fixed_params() ### linearly decay learning rate after certain iterations if epoch > opt.niter: model.module.update_learning_rate() ================================================ FILE: Global/train_mapping.py ================================================ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
import time from collections import OrderedDict from options.train_options import TrainOptions from data.data_loader import CreateDataLoader from models.mapping_model import Pix2PixHDModel_Mapping import util.util as util from util.visualizer import Visualizer import os import numpy as np import torch import torchvision.utils as vutils from torch.autograd import Variable import datetime import random opt = TrainOptions().parse() visualizer = Visualizer(opt) iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt') if opt.continue_train: try: start_epoch, epoch_iter = np.loadtxt(iter_path , delimiter=',', dtype=int) except: start_epoch, epoch_iter = 1, 0 visualizer.print_save('Resuming from epoch %d at iteration %d' % (start_epoch-1, epoch_iter)) else: start_epoch, epoch_iter = 1, 0 if opt.which_epoch != "latest": start_epoch=int(opt.which_epoch) visualizer.print_save('Notice : Resuming from epoch %d at iteration %d' % (start_epoch - 1, epoch_iter)) opt.start_epoch=start_epoch ### temp for continue train unfixed decoder data_loader = CreateDataLoader(opt) dataset = data_loader.load_data() dataset_size = len(dataset) * opt.batchSize print('#training images = %d' % dataset_size) model = Pix2PixHDModel_Mapping() model.initialize(opt) path = os.path.join(opt.checkpoints_dir, opt.name, 'model.txt') fd = open(path, 'w') if opt.use_skip_model: fd.write(str(model.mapping_net)) fd.close() else: fd.write(str(model.netG_A)) fd.write(str(model.mapping_net)) fd.close() if opt.isTrain and len(opt.gpu_ids) > 1: model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids) total_steps = (start_epoch-1) * dataset_size + epoch_iter display_delta = total_steps % opt.display_freq print_delta = total_steps % opt.print_freq save_delta = total_steps % opt.save_latest_freq ### used for recovering training for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1): epoch_s_t=datetime.datetime.now() epoch_start_time = time.time() if epoch != start_epoch: epoch_iter = 
epoch_iter % dataset_size for i, data in enumerate(dataset, start=epoch_iter): iter_start_time = time.time() total_steps += opt.batchSize epoch_iter += opt.batchSize # whether to collect output images save_fake = total_steps % opt.display_freq == display_delta ############## Forward Pass ###################### #print(pair) losses, generated = model(Variable(data['label']), Variable(data['inst']), Variable(data['image']), Variable(data['feat']), infer=save_fake) # sum per device losses losses = [ torch.mean(x) if not isinstance(x, int) else x for x in losses ] loss_dict = dict(zip(model.module.loss_names, losses)) # calculate final loss scalar loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5 loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat',0) + loss_dict.get('G_VGG',0) + loss_dict.get('G_Feat_L2', 0) +loss_dict.get('Smooth_L1', 0)+loss_dict.get('G_Feat_L2_Stage_1',0) #loss_G = loss_dict['G_Feat_L2'] ############### Backward Pass #################### # update generator weights model.module.optimizer_mapping.zero_grad() loss_G.backward() model.module.optimizer_mapping.step() # update discriminator weights model.module.optimizer_D.zero_grad() loss_D.backward() model.module.optimizer_D.step() ############## Display results and errors ########## ### print out errors if i == 0 or total_steps % opt.print_freq == print_delta: errors = {k: v.data if not isinstance(v, int) else v for k, v in loss_dict.items()} t = (time.time() - iter_start_time) / opt.batchSize visualizer.print_current_errors(epoch, epoch_iter, errors, t,model.module.old_lr) visualizer.plot_current_errors(errors, total_steps) ### display output images if save_fake: if not os.path.exists(opt.outputs_dir + opt.name): os.makedirs(opt.outputs_dir + opt.name) imgs_num = 5 if opt.NL_use_mask: mask=data['inst'][:imgs_num] mask=mask.repeat(1,3,1,1) imgs = torch.cat((data['label'][:imgs_num], mask,generated.data.cpu()[:imgs_num], data['image'][:imgs_num]), 0) else: imgs = 
torch.cat((data['label'][:imgs_num], generated.data.cpu()[:imgs_num], data['image'][:imgs_num]), 0) imgs=(imgs+1.)/2.0 ## de-normalize try: image_grid = vutils.save_image(imgs, opt.outputs_dir + opt.name + '/' + str(epoch) + '_' + str(total_steps) + '.png', nrow=imgs_num, padding=0, normalize=True) except OSError as err: print(err) if epoch_iter >= dataset_size: break # end of epoch epoch_e_t=datetime.datetime.now() iter_end_time = time.time() print('End of epoch %d / %d \t Time Taken: %s' % (epoch, opt.niter + opt.niter_decay, str(epoch_e_t-epoch_s_t))) ### save model for this epoch if epoch % opt.save_epoch_freq == 0: print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps)) model.module.save('latest') model.module.save(epoch) np.savetxt(iter_path, (epoch+1, 0), delimiter=',', fmt='%d') ### instead of only training the local enhancer, train the entire network after certain iterations if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global): model.module.update_fixed_params() ### linearly decay learning rate after certain iterations if epoch > opt.niter: model.module.update_learning_rate() ================================================ FILE: Global/util/__init__.py ================================================ ================================================ FILE: Global/util/image_pool.py ================================================ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
import random import torch from torch.autograd import Variable class ImagePool: def __init__(self, pool_size): self.pool_size = pool_size if self.pool_size > 0: self.num_imgs = 0 self.images = [] def query(self, images): if self.pool_size == 0: return images return_images = [] for image in images.data: image = torch.unsqueeze(image, 0) if self.num_imgs < self.pool_size: self.num_imgs = self.num_imgs + 1 self.images.append(image) return_images.append(image) else: p = random.uniform(0, 1) if p > 0.5: random_id = random.randint(0, self.pool_size - 1) tmp = self.images[random_id].clone() self.images[random_id] = image return_images.append(tmp) else: return_images.append(image) return_images = Variable(torch.cat(return_images, 0)) return return_images ================================================ FILE: Global/util/util.py ================================================ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from __future__ import print_function import torch import numpy as np from PIL import Image import numpy as np import os import torch.nn as nn # Converts a Tensor into a Numpy array # |imtype|: the desired type of the converted numpy array def tensor2im(image_tensor, imtype=np.uint8, normalize=True): if isinstance(image_tensor, list): image_numpy = [] for i in range(len(image_tensor)): image_numpy.append(tensor2im(image_tensor[i], imtype, normalize)) return image_numpy image_numpy = image_tensor.cpu().float().numpy() if normalize: image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 else: image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 image_numpy = np.clip(image_numpy, 0, 255) if image_numpy.shape[2] == 1 or image_numpy.shape[2] > 3: image_numpy = image_numpy[:, :, 0] return image_numpy.astype(imtype) # Converts a one-hot tensor into a colorful label map def tensor2label(label_tensor, n_label, imtype=np.uint8): if n_label == 0: return tensor2im(label_tensor, imtype) label_tensor = 
label_tensor.cpu().float() if label_tensor.size()[0] > 1: label_tensor = label_tensor.max(0, keepdim=True)[1] label_tensor = Colorize(n_label)(label_tensor) label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0)) return label_numpy.astype(imtype) def save_image(image_numpy, image_path): image_pil = Image.fromarray(image_numpy) image_pil.save(image_path) def mkdirs(paths): if isinstance(paths, list) and not isinstance(paths, str): for path in paths: mkdir(path) else: mkdir(paths) def mkdir(path): if not os.path.exists(path): os.makedirs(path) ================================================ FILE: Global/util/visualizer.py ================================================ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. import numpy as np import os import ntpath import time from . import util #from . import html import scipy.misc try: from StringIO import StringIO # Python 2.7 except ImportError: from io import BytesIO # Python 3.x class Visualizer(): def __init__(self, opt): # self.opt = opt self.tf_log = opt.tf_log self.use_html = opt.isTrain and not opt.no_html self.win_size = opt.display_winsize self.name = opt.name if self.tf_log: import tensorflow as tf self.tf = tf self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, 'logs') self.writer = tf.summary.FileWriter(self.log_dir) if self.use_html: self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') self.img_dir = os.path.join(self.web_dir, 'images') print('create web directory %s...' 
% self.web_dir) util.mkdirs([self.web_dir, self.img_dir]) self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') with open(self.log_name, "a") as log_file: now = time.strftime("%c") log_file.write('================ Training Loss (%s) ================\n' % now) # |visuals|: dictionary of images to display or save def display_current_results(self, visuals, epoch, step): if self.tf_log: # show images in tensorboard output img_summaries = [] for label, image_numpy in visuals.items(): # Write the image to a string try: s = StringIO() except: s = BytesIO() scipy.misc.toimage(image_numpy).save(s, format="jpeg") # Create an Image object img_sum = self.tf.Summary.Image(encoded_image_string=s.getvalue(), height=image_numpy.shape[0], width=image_numpy.shape[1]) # Create a Summary value img_summaries.append(self.tf.Summary.Value(tag=label, image=img_sum)) # Create and write Summary summary = self.tf.Summary(value=img_summaries) self.writer.add_summary(summary, step) if self.use_html: # save images to a html file for label, image_numpy in visuals.items(): if isinstance(image_numpy, list): for i in range(len(image_numpy)): img_path = os.path.join(self.img_dir, 'epoch%.3d_%s_%d.jpg' % (epoch, label, i)) util.save_image(image_numpy[i], img_path) else: img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.jpg' % (epoch, label)) util.save_image(image_numpy, img_path) # update website webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=30) for n in range(epoch, 0, -1): webpage.add_header('epoch [%d]' % n) ims = [] txts = [] links = [] for label, image_numpy in visuals.items(): if isinstance(image_numpy, list): for i in range(len(image_numpy)): img_path = 'epoch%.3d_%s_%d.jpg' % (n, label, i) ims.append(img_path) txts.append(label+str(i)) links.append(img_path) else: img_path = 'epoch%.3d_%s.jpg' % (n, label) ims.append(img_path) txts.append(label) links.append(img_path) if len(ims) < 10: webpage.add_images(ims, txts, links, 
width=self.win_size) else: num = int(round(len(ims)/2.0)) webpage.add_images(ims[:num], txts[:num], links[:num], width=self.win_size) webpage.add_images(ims[num:], txts[num:], links[num:], width=self.win_size) webpage.save() # errors: dictionary of error labels and values def plot_current_errors(self, errors, step): if self.tf_log: for tag, value in errors.items(): summary = self.tf.Summary(value=[self.tf.Summary.Value(tag=tag, simple_value=value)]) self.writer.add_summary(summary, step) # errors: same format as |errors| of plotCurrentErrors def print_current_errors(self, epoch, i, errors, t, lr): message = '(epoch: %d, iters: %d, time: %.3f lr: %.5f) ' % (epoch, i, t, lr) for k, v in errors.items(): if v != 0: message += '%s: %.3f ' % (k, v) print(message) with open(self.log_name, "a") as log_file: log_file.write('%s\n' % message) def print_save(self,message): print(message) with open(self.log_name,"a") as log_file: log_file.write('%s\n'%message) # save image to the disk def save_images(self, webpage, visuals, image_path): image_dir = webpage.get_image_dir() short_path = ntpath.basename(image_path[0]) name = os.path.splitext(short_path)[0] webpage.add_header(name) ims = [] txts = [] links = [] for label, image_numpy in visuals.items(): image_name = '%s_%s.jpg' % (name, label) save_path = os.path.join(image_dir, image_name) util.save_image(image_numpy, save_path) ims.append(image_name) txts.append(label) links.append(image_name) webpage.add_images(ims, txts, links, width=self.win_size) ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) Microsoft Corporation. 
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE ================================================ FILE: README.md ================================================ # Old Photo Restoration (Official PyTorch Implementation) ### [Project Page](http://raywzy.com/Old_Photo/) | [Paper (CVPR version)](https://arxiv.org/abs/2004.09484) | [Paper (Journal version)](https://arxiv.org/pdf/2009.07047v1.pdf) | [Pretrained Model](https://hkustconnect-my.sharepoint.com/:f:/g/personal/bzhangai_connect_ust_hk/Em0KnYOeSSxFtp4g_dhWdf0BdeT3tY12jIYJ6qvSf300cA?e=nXkJH2) | [Colab Demo](https://colab.research.google.com/drive/1NEm6AsybIiC5TwTU_4DqDkQO0nFRB-uA?usp=sharing) | [Replicate Demo & Docker Image](https://replicate.ai/zhangmozhe/bringing-old-photos-back-to-life) :fire: **Bringing Old Photos Back to Life, CVPR2020 (Oral)** **Old Photo Restoration via Deep Latent Space Translation, TPAMI 2022** [Ziyu Wan](http://raywzy.com/)1, [Bo Zhang](https://www.microsoft.com/en-us/research/people/zhanbo/)2, [Dongdong 
Chen](http://www.dongdongchen.bid/)3, [Pan Zhang](https://panzhang0212.github.io/)4, [Dong Chen](https://www.microsoft.com/en-us/research/people/doch/)2, [Jing Liao](https://liaojing.github.io/html/)1, [Fang Wen](https://www.microsoft.com/en-us/research/people/fangwen/)2
1City University of Hong Kong, 2Microsoft Research Asia, 3Microsoft Cloud AI, 4USTC ## :sparkles: News **2022.3.31**: Our new work regarding old film restoration will be published in CVPR 2022. For more details, please refer to the [project website](http://raywzy.com/Old_Film/) and [github repo](https://github.com/raywzy/Bringing-Old-Films-Back-to-Life). The framework now supports the restoration of high-resolution input. Training code is available; you are welcome to try it and learn the training details. You can now play with our [Colab](https://colab.research.google.com/drive/1NEm6AsybIiC5TwTU_4DqDkQO0nFRB-uA?usp=sharing) and try it on your photos. ## Requirement The code is tested on Ubuntu with Nvidia GPUs and CUDA installed. Python>=3.6 is required to run the code. ## Installation Clone the Synchronized-BatchNorm-PyTorch repository and copy `sync_batchnorm` into both `Face_Enhancement/models/networks/` and `Global/detection_models`: ``` cd Face_Enhancement/models/networks/ git clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch cp -rf Synchronized-BatchNorm-PyTorch/sync_batchnorm . cd ../../../ ``` ``` cd Global/detection_models git clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch cp -rf Synchronized-BatchNorm-PyTorch/sync_batchnorm . cd ../../ ``` Download the landmark detection pretrained model ``` cd Face_Detection/ wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 bzip2 -d shape_predictor_68_face_landmarks.dat.bz2 cd ../ ``` Download the pretrained model, put the file `Face_Enhancement/checkpoints.zip` under `./Face_Enhancement`, and put the file `Global/checkpoints.zip` under `./Global`. Then unzip them respectively. 
``` cd Face_Enhancement/ wget https://github.com/microsoft/Bringing-Old-Photos-Back-to-Life/releases/download/v1.0/face_checkpoints.zip unzip face_checkpoints.zip cd ../ cd Global/ wget https://github.com/microsoft/Bringing-Old-Photos-Back-to-Life/releases/download/v1.0/global_checkpoints.zip unzip global_checkpoints.zip cd ../ ``` Install dependencies: ```bash pip install -r requirements.txt ``` ## :rocket: How to use? **Note**: GPU can be set 0 or 0,1,2 or 0,2; use -1 for CPU ### 1) Full Pipeline You could easily restore the old photos with one simple command after installation and downloading the pretrained model. For images without scratches: ``` python run.py --input_folder [test_image_folder_path] \ --output_folder [output_path] \ --GPU 0 ``` For scratched images: ``` python run.py --input_folder [test_image_folder_path] \ --output_folder [output_path] \ --GPU 0 \ --with_scratch ``` **For high-resolution images with scratches**: ``` python run.py --input_folder [test_image_folder_path] \ --output_folder [output_path] \ --GPU 0 \ --with_scratch \ --HR ``` Note: Please try to use the absolute path. The final results will be saved in `./output_path/final_output/`. You could also check the produced results of different steps in `output_path`. ### 2) Scratch Detection Currently we don't plan to release the scratched old photos dataset with labels directly. If you want to get the paired data, you could use our pretrained model to test the collected images to obtain the labels. ``` cd Global/ python detection.py --test_path [test_image_folder_path] \ --output_dir [output_path] \ --input_size [resize_256|full_size|scale_256] ``` ### 3) Global Restoration A triplet domain translation network is proposed to solve both structured degradation and unstructured degradation of old photos.

``` cd Global/ python test.py --Scratch_and_Quality_restore \ --test_input [test_image_folder_path] \ --test_mask [corresponding mask] \ --outputs_dir [output_path] python test.py --Quality_restore \ --test_input [test_image_folder_path] \ --outputs_dir [output_path] ``` ### 4) Face Enhancement We use a progressive generator to refine the face regions of old photos. More details could be found in our journal submission and `./Face_Enhancement` folder.

> *NOTE*: > This repo is mainly for research purpose and we have not yet optimized the running performance. > > Since the model is pretrained with 256*256 images, the model may not work ideally for arbitrary resolution. ### 5) GUI A user-friendly GUI which takes input of image by user and shows result in respective window. #### How it works: 1. Run GUI.py file. 2. Click browse and select your image from test_images/old_w_scratch folder to remove scratches. 3. Click Modify Photo button. 4. Wait for a while and see results on GUI window. 5. Exit window by clicking Exit Window and get your result image in output folder. ## How to train? ### 1) Create Training File Put the folders of VOC dataset, collected old photos (e.g., Real_L_old and Real_RGB_old) into one shared folder. Then ``` cd Global/data/ python Create_Bigfile.py ``` Note: Remember to modify the code based on your own environment. ### 2) Train the VAEs of domain A and domain B respectively ``` cd .. python train_domain_A.py --use_v2_degradation --continue_train --training_dataset domain_A --name domainA_SR_old_photos --label_nc 0 --loadSize 256 --fineSize 256 --dataroot [your_data_folder] --no_instance --resize_or_crop crop_only --batchSize 100 --no_html --gpu_ids 0,1,2,3 --self_gen --nThreads 4 --n_downsample_global 3 --k_size 4 --use_v2 --mc 64 --start_r 1 --kl 1 --no_cgan --outputs_dir [your_output_folder] --checkpoints_dir [your_ckpt_folder] python train_domain_B.py --continue_train --training_dataset domain_B --name domainB_old_photos --label_nc 0 --loadSize 256 --fineSize 256 --dataroot [your_data_folder] --no_instance --resize_or_crop crop_only --batchSize 120 --no_html --gpu_ids 0,1,2,3 --self_gen --nThreads 4 --n_downsample_global 3 --k_size 4 --use_v2 --mc 64 --start_r 1 --kl 1 --no_cgan --outputs_dir [your_output_folder] --checkpoints_dir [your_ckpt_folder] ``` Note: For the --name option, please ensure your experiment name contains "domainA" or "domainB", which will be used to select different 
dataset. ### 3) Train the mapping network between domains Train the mapping without scratches: ``` python train_mapping.py --use_v2_degradation --training_dataset mapping --use_vae_which_epoch 200 --continue_train --name mapping_quality --label_nc 0 --loadSize 256 --fineSize 256 --dataroot [your_data_folder] --no_instance --resize_or_crop crop_only --batchSize 80 --no_html --gpu_ids 0,1,2,3 --nThreads 8 --load_pretrainA [ckpt_of_domainA_SR_old_photos] --load_pretrainB [ckpt_of_domainB_old_photos] --l2_feat 60 --n_downsample_global 3 --mc 64 --k_size 4 --start_r 1 --mapping_n_block 6 --map_mc 512 --use_l1_feat --niter 150 --niter_decay 100 --outputs_dir [your_output_folder] --checkpoints_dir [your_ckpt_folder] ``` Train the mapping with scratches: ``` python train_mapping.py --no_TTUR --NL_res --random_hole --use_SN --correlation_renormalize --training_dataset mapping --NL_use_mask --NL_fusion_method combine --non_local Setting_42 --use_v2_degradation --use_vae_which_epoch 200 --continue_train --name mapping_scratch --label_nc 0 --loadSize 256 --fineSize 256 --dataroot [your_data_folder] --no_instance --resize_or_crop crop_only --batchSize 36 --no_html --gpu_ids 0,1,2,3 --nThreads 8 --load_pretrainA [ckpt_of_domainA_SR_old_photos] --load_pretrainB [ckpt_of_domainB_old_photos] --l2_feat 60 --n_downsample_global 3 --mc 64 --k_size 4 --start_r 1 --mapping_n_block 6 --map_mc 512 --use_l1_feat --niter 150 --niter_decay 100 --outputs_dir [your_output_folder] --checkpoints_dir [your_ckpt_folder] --irregular_mask [absolute_path_of_mask_file] ``` Train the mapping with scratches (Multi-Scale Patch Attention for HR input): ``` python train_mapping.py --no_TTUR --NL_res --random_hole --use_SN --correlation_renormalize --training_dataset mapping --NL_use_mask --NL_fusion_method combine --non_local Setting_42 --use_v2_degradation --use_vae_which_epoch 200 --continue_train --name mapping_Patch_Attention --label_nc 0 --loadSize 256 --fineSize 256 --dataroot [your_data_folder] 
--no_instance --resize_or_crop crop_only --batchSize 36 --no_html --gpu_ids 0,1,2,3 --nThreads 8 --load_pretrainA [ckpt_of_domainA_SR_old_photos] --load_pretrainB [ckpt_of_domainB_old_photos] --l2_feat 60 --n_downsample_global 3 --mc 64 --k_size 4 --start_r 1 --mapping_n_block 6 --map_mc 512 --use_l1_feat --niter 150 --niter_decay 100 --outputs_dir [your_output_folder] --checkpoints_dir [your_ckpt_folder] --irregular_mask [absolute_path_of_mask_file] --mapping_exp 1 ``` ## Citation If you find our work useful for your research, please consider citing the following papers :) ```bibtex @inproceedings{wan2020bringing, title={Bringing Old Photos Back to Life}, author={Wan, Ziyu and Zhang, Bo and Chen, Dongdong and Zhang, Pan and Chen, Dong and Liao, Jing and Wen, Fang}, booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, pages={2747--2757}, year={2020} } ``` ```bibtex @article{wan2020old, title={Old Photo Restoration via Deep Latent Space Translation}, author={Wan, Ziyu and Zhang, Bo and Chen, Dongdong and Zhang, Pan and Chen, Dong and Liao, Jing and Wen, Fang}, journal={arXiv preprint arXiv:2009.07047}, year={2020} } ``` If you are also interested in the legacy photo/video colorization, please refer to [this work](https://github.com/zhangmozhe/video-colorization). ## Maintenance This project is currently maintained by Ziyu Wan and is for academic research use only. If you have any questions, feel free to contact raywzy@gmail.com. ## License The codes and the pretrained model in this repository are under the MIT license as specified by the LICENSE file. We use our labeled dataset to train the scratch detection model. This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. ================================================ FILE: SECURITY.md ================================================ ## Security Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below. ## Reporting Security Issues **Please do not report security vulnerabilities through public GitHub issues.** Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 
Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) * Full paths of source file(s) related to the manifestation of the issue * The location of the affected source code (tag/branch/commit or direct URL) * Any special configuration required to reproduce the issue * Step-by-step instructions to reproduce the issue * Proof-of-concept or exploit code (if possible) * Impact of the issue, including how an attacker might exploit the issue This information will help us triage your report more quickly. If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. ## Preferred Languages We prefer all communications to be in English. ## Policy Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). 
================================================ FILE: ansible.yaml ================================================ --- - name: Bringing-Old-Photos-Back-to-Life hosts: all gather_facts: no # Successfully tested on Ubuntu 18.04/20.04 and Debian 10 pre_tasks: - name: install packages package: name: - python3 - python3-pip - python3-venv - git - unzip - tar - lbzip2 - build-essential - cmake - ffmpeg - libsm6 - libxext6 - libgl1-mesa-glx state: latest become: yes tasks: - name: git clone repo git: repo: 'https://github.com/microsoft/Bringing-Old-Photos-Back-to-Life.git' dest: Bringing-Old-Photos-Back-to-Life clone: yes - name: requirements setup pip: requirements: "~/Bringing-Old-Photos-Back-to-Life/requirements.txt" virtualenv: "~/Bringing-Old-Photos-Back-to-Life/.venv" virtualenv_command: /usr/bin/python3 -m venv .venv - name: additional pip packages # requirements.txt lacks some packages pip: name: - setuptools - wheel - scikit-build virtualenv: "~/Bringing-Old-Photos-Back-to-Life/.venv" virtualenv_command: /usr/bin/python3 -m venv .venv - name: git clone batchnorm-pytorch git: repo: 'https://github.com/vacancy/Synchronized-BatchNorm-PyTorch' dest: Synchronized-BatchNorm-PyTorch clone: yes - name: copy sync_batchnorm to face_enhancement copy: src: Synchronized-BatchNorm-PyTorch/sync_batchnorm dest: Bringing-Old-Photos-Back-to-Life/Face_Enhancement/models/networks/ remote_src: yes - name: copy sync_batchnorm to global copy: src: Synchronized-BatchNorm-PyTorch/sync_batchnorm dest: Bringing-Old-Photos-Back-to-Life/Global/detection_models remote_src: yes - name: check if shape_predictor_68_face_landmarks.dat stat: path: Bringing-Old-Photos-Back-to-Life/Face_Detection/shape_predictor_68_face_landmarks.dat register: p - name: get shape_predictor_68_face_landmarks.dat.bz2 get_url: url: http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 dest: Bringing-Old-Photos-Back-to-Life/Face_Detection/ when: p.stat.exists == False - name: unarchive 
shape_predictor_68_face_landmarks.dat.bz2 shell: 'bzip2 -d Bringing-Old-Photos-Back-to-Life/Face_Detection/shape_predictor_68_face_landmarks.dat.bz2' when: p.stat.exists == False - name: check if face_enhancement stat: path: Bringing-Old-Photos-Back-to-Life/Face_Enhancement/checkpoints/Setting_9_epoch_100/latest_net_G.pth register: fc - name: unarchive Face_Enhancement/checkpoints.zip unarchive: src: https://facevc.blob.core.windows.net/zhanbo/old_photo/pretrain/Face_Enhancement/checkpoints.zip dest: Bringing-Old-Photos-Back-to-Life/Face_Enhancement/ remote_src: yes when: fc.stat.exists == False - name: check if global stat: path: Bringing-Old-Photos-Back-to-Life/Global/checkpoints/detection/FT_Epoch_latest.pt register: gc - name: unarchive Global/checkpoints.zip unarchive: src: https://facevc.blob.core.windows.net/zhanbo/old_photo/pretrain/Global/checkpoints.zip dest: Bringing-Old-Photos-Back-to-Life/Global/ remote_src: yes when: gc.stat.exists == False # Do not forget to execute 'source .venv/bin/activate' inside Bringing-Old-Photos-Back-to-Life before starting run.py ================================================ FILE: cog.yaml ================================================ build: gpu: true python_version: "3.8" system_packages: - "libgl1-mesa-glx" - "libglib2.0-0" python_packages: - "cmake==3.21.2" - "torchvision==0.9.0" - "torch==1.8.0" - "numpy==1.19.4" - "opencv-python==4.4.0.46" - "scipy==1.5.3" - "tensorboardX==2.4" - "dominate==2.6.0" - "easydict==1.9" - "PyYAML==5.3.1" - "scikit-image==0.18.3" - "dill==0.3.4" - "einops==0.3.0" - "PySimpleGUI==4.46.0" - "ipython==7.19.0" run: - pip install dlib predict: "predict.py:Predictor" ================================================ FILE: download-weights ================================================ #!/bin/sh cd Face_Enhancement/models/networks git clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch cp -rf Synchronized-BatchNorm-PyTorch/sync_batchnorm . 
import tempfile
from pathlib import Path
import argparse
import shlex
import shutil
import os
import glob
import cv2
import cog
from run import run_cmd


class Predictor(cog.Predictor):
    """Cog predictor that runs the four restoration stages on a single image."""

    def setup(self):
        """Build the default options and create the working folders.

        The input and output folders are made absolute relative to the
        current working directory, which is remembered as ``self.basepath``.
        """
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "--input_folder", type=str, default="input/cog_temp", help="Test images"
        )
        parser.add_argument(
            "--output_folder",
            type=str,
            default="output",
            help="Restored images, please use the absolute path",
        )
        parser.add_argument("--GPU", type=str, default="0", help="0,1,2")
        parser.add_argument(
            "--checkpoint_name",
            type=str,
            default="Setting_9_epoch_100",
            help="choose which checkpoint",
        )
        # Parse an empty argument list: only the defaults above are wanted.
        self.opts = parser.parse_args("")
        self.basepath = os.getcwd()
        self.opts.input_folder = os.path.join(self.basepath, self.opts.input_folder)
        self.opts.output_folder = os.path.join(self.basepath, self.opts.output_folder)
        os.makedirs(self.opts.input_folder, exist_ok=True)
        os.makedirs(self.opts.output_folder, exist_ok=True)

    @cog.input("image", type=Path, help="input image")
    @cog.input(
        "HR",
        type=bool,
        default=False,
        help="whether the input image is high-resolution",
    )
    @cog.input(
        "with_scratch",
        type=bool,
        default=False,
        help="whether the input image is scratched",
    )
    def predict(self, image, HR=False, with_scratch=False):
        """Restore ``image`` and return the path of the restored PNG.

        Args:
            image: Path of the input image; it is copied into the input folder.
            HR: When true, ``--HR`` is appended to the stage-1 scratch
                restoration command. Stages 2-4 always use the ``*_HR.py``
                scripts and the ``FaceSR_512`` checkpoint regardless.
            with_scratch: When true, stage 1 runs ``detection.py`` followed by
                ``test.py --Scratch_and_Quality_restore``; otherwise it runs
                ``test.py --Quality_restore`` only.

        Returns:
            Path to ``out.png`` inside a freshly created temporary directory.

        Raises:
            RuntimeError: If the pipeline left no readable image in the
                ``final_output`` folder.

        The input and output folders are emptied afterwards, whether or not
        the run succeeded.
        """
        # Paths are interpolated into shell command lines, so quote them to
        # survive spaces and shell metacharacters in the working directory.
        quote = shlex.quote
        try:
            os.chdir(self.basepath)
            input_path = os.path.join(self.opts.input_folder, os.path.basename(image))
            shutil.copy(str(image), input_path)

            gpu1 = self.opts.GPU

            ## Stage 1: Overall Quality Improve
            print("Running Stage 1: Overall restoration")
            os.chdir("./Global")
            stage_1_input_dir = self.opts.input_folder
            stage_1_output_dir = os.path.join(
                self.opts.output_folder, "stage_1_restore_output"
            )
            os.makedirs(stage_1_output_dir, exist_ok=True)

            if not with_scratch:
                stage_1_command = (
                    "python test.py --test_mode Full --Quality_restore --test_input "
                    + quote(stage_1_input_dir)
                    + " --outputs_dir "
                    + quote(stage_1_output_dir)
                    + " --gpu_ids "
                    + gpu1
                )
                run_cmd(stage_1_command)
            else:
                mask_dir = os.path.join(stage_1_output_dir, "masks")
                new_input = os.path.join(mask_dir, "input")
                new_mask = os.path.join(mask_dir, "mask")
                stage_1_command_1 = (
                    "python detection.py --test_path "
                    + quote(stage_1_input_dir)
                    + " --output_dir "
                    + quote(mask_dir)
                    + " --input_size full_size"
                    + " --GPU "
                    + gpu1
                )
                HR_suffix = " --HR" if HR else ""
                stage_1_command_2 = (
                    "python test.py --Scratch_and_Quality_restore --test_input "
                    + quote(new_input)
                    + " --test_mask "
                    + quote(new_mask)
                    + " --outputs_dir "
                    + quote(stage_1_output_dir)
                    + " --gpu_ids "
                    + gpu1
                    + HR_suffix
                )
                run_cmd(stage_1_command_1)
                run_cmd(stage_1_command_2)

            ## Solve the case when there is no face in the old photo:
            ## seed the final folder with the stage-1 results.
            stage_1_results = os.path.join(stage_1_output_dir, "restored_image")
            stage_4_output_dir = os.path.join(self.opts.output_folder, "final_output")
            os.makedirs(stage_4_output_dir, exist_ok=True)
            for x in os.listdir(stage_1_results):
                img_dir = os.path.join(stage_1_results, x)
                shutil.copy(img_dir, stage_4_output_dir)

            print("Finish Stage 1 ...")
            print("\n")

            ## Stage 2: Face Detection
            print("Running Stage 2: Face Detection")
            os.chdir(".././Face_Detection")
            stage_2_input_dir = os.path.join(stage_1_output_dir, "restored_image")
            stage_2_output_dir = os.path.join(
                self.opts.output_folder, "stage_2_detection_output"
            )
            os.makedirs(stage_2_output_dir, exist_ok=True)

            stage_2_command = (
                "python detect_all_dlib_HR.py --url "
                + quote(stage_2_input_dir)
                + " --save_url "
                + quote(stage_2_output_dir)
            )
            run_cmd(stage_2_command)
            print("Finish Stage 2 ...")
            print("\n")

            ## Stage 3: Face Restore
            print("Running Stage 3: Face Enhancement")
            os.chdir(".././Face_Enhancement")
            stage_3_input_mask = "./"
            stage_3_input_face = stage_2_output_dir
            stage_3_output_dir = os.path.join(
                self.opts.output_folder, "stage_3_face_output"
            )
            os.makedirs(stage_3_output_dir, exist_ok=True)

            self.opts.checkpoint_name = "FaceSR_512"
            stage_3_command = (
                "python test_face.py --old_face_folder "
                + quote(stage_3_input_face)
                + " --old_face_label_folder "
                + stage_3_input_mask
                + " --tensorboard_log --name "
                + self.opts.checkpoint_name
                + " --gpu_ids "
                + gpu1
                + " --load_size 512 --label_nc 18 --no_instance --preprocess_mode resize --batchSize 1 --results_dir "
                + quote(stage_3_output_dir)
                + " --no_parsing_map"
            )
            run_cmd(stage_3_command)
            print("Finish Stage 3 ...")
            print("\n")

            ## Stage 4: Warp back
            print("Running Stage 4: Blending")
            os.chdir(".././Face_Detection")
            stage_4_input_image_dir = os.path.join(stage_1_output_dir, "restored_image")
            stage_4_input_face_dir = os.path.join(stage_3_output_dir, "each_img")
            stage_4_output_dir = os.path.join(self.opts.output_folder, "final_output")
            os.makedirs(stage_4_output_dir, exist_ok=True)

            stage_4_command = (
                "python align_warp_back_multiple_dlib_HR.py --origin_url "
                + quote(stage_4_input_image_dir)
                + " --replace_url "
                + quote(stage_4_input_face_dir)
                + " --save_url "
                + quote(stage_4_output_dir)
            )
            run_cmd(stage_4_command)
            print("Finish Stage 4 ...")
            print("\n")

            print("All the processing is done. Please check the results.")

            # Fail with a clear message instead of IndexError / an OpenCV
            # error when the pipeline produced nothing usable.
            produced = sorted(os.listdir(stage_4_output_dir))
            if not produced:
                raise RuntimeError(
                    "The pipeline produced no image in " + stage_4_output_dir
                )
            image_restore = cv2.imread(os.path.join(stage_4_output_dir, produced[0]))
            if image_restore is None:
                raise RuntimeError("Could not read restored image " + produced[0])

            out_path = Path(tempfile.mkdtemp()) / "out.png"
            cv2.imwrite(str(out_path), image_restore)
            return out_path
        finally:
            clean_folder(self.opts.input_folder)
            clean_folder(self.opts.output_folder)


def clean_folder(folder):
    """Delete every entry inside ``folder`` while keeping the folder itself.

    Files and symlinks are unlinked and sub-directories are removed
    recursively. A failure on one entry is printed and the remaining entries
    are still processed. A ``folder`` that is not a directory is ignored, so
    calling this from a ``finally`` clause cannot hide the original error.
    """
    if not os.path.isdir(folder):
        return
    for filename in os.listdir(folder):
        file_path = os.path.join(folder, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            # Best-effort cleanup: report and keep going.
            print(f"Failed to delete {file_path}. Reason:{e}")
import os
import argparse
import shutil
import sys
from subprocess import call


def run_cmd(command):
    """Run ``command`` as a child process and return its exit status.

    Args:
        command: Either a single string, which is handed to the shell
            unchanged, or a list of arguments, which is executed directly
            without a shell so that arguments containing spaces need no
            quoting.

    Returns:
        The child's exit status. A non-zero status is reported on stdout but
        does not stop the caller.

    A ``KeyboardInterrupt`` during the call prints a message and terminates
    the program with status 1.
    """
    try:
        status = call(command, shell=isinstance(command, str))
    except KeyboardInterrupt:
        print("Process interrupted")
        sys.exit(1)
    if status != 0:
        print("Command exited with status %s: %s" % (status, command))
    return status


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("--input_folder", type=str, default="./test_images/old", help="Test images")
    parser.add_argument(
        "--output_folder",
        type=str,
        default="./output",
        help="Restored images, please use the absolute path",
    )
    parser.add_argument("--GPU", type=str, default="6,7", help="0,1,2")
    parser.add_argument(
        "--checkpoint_name", type=str, default="Setting_9_epoch_100", help="choose which checkpoint"
    )
    parser.add_argument("--with_scratch", action="store_true")
    parser.add_argument("--HR", action="store_true")
    opts = parser.parse_args()

    gpu1 = opts.GPU

    # Launch the stage scripts with the interpreter running this script, so
    # an active virtual environment is honoured even if "python" is not on PATH.
    python = sys.executable or "python"

    # The stage folders live next to this script; resolving them from here
    # lets the script be started from any working directory.
    repo_root = os.path.dirname(os.path.abspath(__file__))

    # resolve relative paths before changing directory
    opts.input_folder = os.path.abspath(opts.input_folder)
    opts.output_folder = os.path.abspath(opts.output_folder)
    os.makedirs(opts.output_folder, exist_ok=True)

    ## Stage 1: Overall Quality Improve
    print("Running Stage 1: Overall restoration")
    os.chdir(os.path.join(repo_root, "Global"))
    stage_1_input_dir = opts.input_folder
    stage_1_output_dir = os.path.join(opts.output_folder, "stage_1_restore_output")
    os.makedirs(stage_1_output_dir, exist_ok=True)

    if not opts.with_scratch:
        run_cmd(
            [
                python, "test.py",
                "--test_mode", "Full",
                "--Quality_restore",
                "--test_input", stage_1_input_dir,
                "--outputs_dir", stage_1_output_dir,
                "--gpu_ids", gpu1,
            ]
        )
    else:
        mask_dir = os.path.join(stage_1_output_dir, "masks")
        new_input = os.path.join(mask_dir, "input")
        new_mask = os.path.join(mask_dir, "mask")
        stage_1_command_1 = [
            python, "detection.py",
            "--test_path", stage_1_input_dir,
            "--output_dir", mask_dir,
            "--input_size", "full_size",
            "--GPU", gpu1,
        ]
        stage_1_command_2 = [
            python, "test.py",
            "--Scratch_and_Quality_restore",
            "--test_input", new_input,
            "--test_mask", new_mask,
            "--outputs_dir", stage_1_output_dir,
            "--gpu_ids", gpu1,
        ]
        if opts.HR:
            stage_1_command_2.append("--HR")

        run_cmd(stage_1_command_1)
        run_cmd(stage_1_command_2)

    ## Solve the case when there is no face in the old photo:
    ## seed the final folder with the stage-1 results.
    stage_1_results = os.path.join(stage_1_output_dir, "restored_image")
    if not os.path.isdir(stage_1_results):
        # Nothing to continue with; stop with a message instead of a traceback.
        print("Stage 1 produced no results in " + stage_1_results)
        sys.exit(1)
    stage_4_output_dir = os.path.join(opts.output_folder, "final_output")
    os.makedirs(stage_4_output_dir, exist_ok=True)
    for x in os.listdir(stage_1_results):
        img_dir = os.path.join(stage_1_results, x)
        shutil.copy(img_dir, stage_4_output_dir)

    print("Finish Stage 1 ...")
    print("\n")

    ## Stage 2: Face Detection
    print("Running Stage 2: Face Detection")
    os.chdir(os.path.join(repo_root, "Face_Detection"))

    stage_2_input_dir = os.path.join(stage_1_output_dir, "restored_image")
    stage_2_output_dir = os.path.join(opts.output_folder, "stage_2_detection_output")
    os.makedirs(stage_2_output_dir, exist_ok=True)

    stage_2_script = "detect_all_dlib_HR.py" if opts.HR else "detect_all_dlib.py"
    run_cmd(
        [
            python, stage_2_script,
            "--url", stage_2_input_dir,
            "--save_url", stage_2_output_dir,
        ]
    )
    print("Finish Stage 2 ...")
    print("\n")

    ## Stage 3: Face Restore
    print("Running Stage 3: Face Enhancement")
    os.chdir(os.path.join(repo_root, "Face_Enhancement"))
    stage_3_input_mask = "./"
    stage_3_input_face = stage_2_output_dir
    stage_3_output_dir = os.path.join(opts.output_folder, "stage_3_face_output")
    os.makedirs(stage_3_output_dir, exist_ok=True)

    if opts.HR:
        # High-resolution mode always uses the FaceSR_512 checkpoint.
        opts.checkpoint_name = "FaceSR_512"
        load_size, batch_size = "512", "1"
    else:
        load_size, batch_size = "256", "4"

    run_cmd(
        [
            python, "test_face.py",
            "--old_face_folder", stage_3_input_face,
            "--old_face_label_folder", stage_3_input_mask,
            "--tensorboard_log",
            "--name", opts.checkpoint_name,
            "--gpu_ids", gpu1,
            "--load_size", load_size,
            "--label_nc", "18",
            "--no_instance",
            "--preprocess_mode", "resize",
            "--batchSize", batch_size,
            "--results_dir", stage_3_output_dir,
            "--no_parsing_map",
        ]
    )
    print("Finish Stage 3 ...")
    print("\n")

    ## Stage 4: Warp back
    print("Running Stage 4: Blending")
    os.chdir(os.path.join(repo_root, "Face_Detection"))
    stage_4_input_image_dir = os.path.join(stage_1_output_dir, "restored_image")
    stage_4_input_face_dir = os.path.join(stage_3_output_dir, "each_img")
    stage_4_output_dir = os.path.join(opts.output_folder, "final_output")
    os.makedirs(stage_4_output_dir, exist_ok=True)

    if opts.HR:
        stage_4_script = "align_warp_back_multiple_dlib_HR.py"
    else:
        stage_4_script = "align_warp_back_multiple_dlib.py"
    run_cmd(
        [
            python, stage_4_script,
            "--origin_url", stage_4_input_image_dir,
            "--replace_url", stage_4_input_face_dir,
            "--save_url", stage_4_output_dir,
        ]
    )
    print("Finish Stage 4 ...")
    print("\n")

    print("All the processing is done. Please check the results.")