[
  {
    "path": ".gitignore",
    "content": "__pycache__/\n*.pyc\n*~\n\n"
  },
  {
    "path": "CODE_OF_CONDUCT.md",
    "content": "# Microsoft Open Source Code of Conduct\n\nThis project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).\n\nResources:\n\n- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)\n- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)\n- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns\n"
  },
  {
    "path": "Dockerfile",
    "content": "FROM nvidia/cuda:11.1-base-ubuntu20.04\n\nRUN apt update && DEBIAN_FRONTEND=noninteractive apt install git bzip2 wget unzip python3-pip python3-dev cmake libgl1-mesa-dev python-is-python3 libgtk2.0-dev -yq\nADD . /app\nWORKDIR /app\nRUN cd Face_Enhancement/models/networks/ &&\\\n  git clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch &&\\\n  cp -rf Synchronized-BatchNorm-PyTorch/sync_batchnorm . &&\\\n  cd ../../../\n\nRUN cd Global/detection_models &&\\\n  git clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch &&\\\n  cp -rf Synchronized-BatchNorm-PyTorch/sync_batchnorm . &&\\\n  cd ../../\n\nRUN cd Face_Detection/ &&\\\n  wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 &&\\\n  bzip2 -d shape_predictor_68_face_landmarks.dat.bz2 &&\\\n  cd ../ \n\nRUN cd Face_Enhancement/ &&\\\n  wget https://facevc.blob.core.windows.net/zhanbo/old_photo/pretrain/Face_Enhancement/checkpoints.zip &&\\\n  unzip checkpoints.zip &&\\\n  cd ../ &&\\\n  cd Global/ &&\\\n  wget https://facevc.blob.core.windows.net/zhanbo/old_photo/pretrain/Global/checkpoints.zip &&\\\n  unzip checkpoints.zip &&\\\n  rm -f checkpoints.zip &&\\\n  cd ../\n\nRUN pip3 install numpy\n\nRUN pip3 install dlib\n\nRUN pip3 install -r requirements.txt\n\nRUN git clone https://github.com/NVlabs/SPADE.git\n\nRUN cd SPADE/ && pip3 install -r requirements.txt\n\nRUN cd ..\n\nCMD [\"python3\", \"run.py\"]\n"
  },
  {
    "path": "Face_Detection/align_warp_back_multiple_dlib.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch\nimport numpy as np\nimport skimage.io as io\n\n# from face_sdk import FaceDetection\nimport matplotlib.pyplot as plt\nfrom matplotlib.patches import Rectangle\nfrom skimage.transform import SimilarityTransform\nfrom skimage.transform import warp\nfrom PIL import Image, ImageFilter\nimport torch.nn.functional as F\nimport torchvision as tv\nimport torchvision.utils as vutils\nimport time\nimport cv2\nimport os\nfrom skimage import img_as_ubyte\nimport json\nimport argparse\nimport dlib\n\n\ndef calculate_cdf(histogram):\n    \"\"\"\n    This method calculates the cumulative distribution function\n    :param array histogram: The values of the histogram\n    :return: normalized_cdf: The normalized cumulative distribution function\n    :rtype: array\n    \"\"\"\n    # Get the cumulative sum of the elements\n    cdf = histogram.cumsum()\n\n    # Normalize the cdf\n    normalized_cdf = cdf / float(cdf.max())\n\n    return normalized_cdf\n\n\ndef calculate_lookup(src_cdf, ref_cdf):\n    \"\"\"\n    This method creates the lookup table\n    :param array src_cdf: The cdf for the source image\n    :param array ref_cdf: The cdf for the reference image\n    :return: lookup_table: The lookup table\n    :rtype: array\n    \"\"\"\n    lookup_table = np.zeros(256)\n    lookup_val = 0\n    for src_pixel_val in range(len(src_cdf)):\n        lookup_val\n        for ref_pixel_val in range(len(ref_cdf)):\n            if ref_cdf[ref_pixel_val] >= src_cdf[src_pixel_val]:\n                lookup_val = ref_pixel_val\n                break\n        lookup_table[src_pixel_val] = lookup_val\n    return lookup_table\n\n\ndef match_histograms(src_image, ref_image):\n    \"\"\"\n    This method matches the source image histogram to the\n    reference signal\n    :param image src_image: The original source image\n    :param image  ref_image: The reference image\n    :return: 
image_after_matching\n    :rtype: image (array)\n    \"\"\"\n    # Split the images into the different color channels\n    # b means blue, g means green and r means red\n    src_b, src_g, src_r = cv2.split(src_image)\n    ref_b, ref_g, ref_r = cv2.split(ref_image)\n\n    # Compute the b, g, and r histograms separately\n    # The flatten() Numpy method returns a copy of the array c\n    # collapsed into one dimension.\n    src_hist_blue, bin_0 = np.histogram(src_b.flatten(), 256, [0, 256])\n    src_hist_green, bin_1 = np.histogram(src_g.flatten(), 256, [0, 256])\n    src_hist_red, bin_2 = np.histogram(src_r.flatten(), 256, [0, 256])\n    ref_hist_blue, bin_3 = np.histogram(ref_b.flatten(), 256, [0, 256])\n    ref_hist_green, bin_4 = np.histogram(ref_g.flatten(), 256, [0, 256])\n    ref_hist_red, bin_5 = np.histogram(ref_r.flatten(), 256, [0, 256])\n\n    # Compute the normalized cdf for the source and reference image\n    src_cdf_blue = calculate_cdf(src_hist_blue)\n    src_cdf_green = calculate_cdf(src_hist_green)\n    src_cdf_red = calculate_cdf(src_hist_red)\n    ref_cdf_blue = calculate_cdf(ref_hist_blue)\n    ref_cdf_green = calculate_cdf(ref_hist_green)\n    ref_cdf_red = calculate_cdf(ref_hist_red)\n\n    # Make a separate lookup table for each color\n    blue_lookup_table = calculate_lookup(src_cdf_blue, ref_cdf_blue)\n    green_lookup_table = calculate_lookup(src_cdf_green, ref_cdf_green)\n    red_lookup_table = calculate_lookup(src_cdf_red, ref_cdf_red)\n\n    # Use the lookup function to transform the colors of the original\n    # source image\n    blue_after_transform = cv2.LUT(src_b, blue_lookup_table)\n    green_after_transform = cv2.LUT(src_g, green_lookup_table)\n    red_after_transform = cv2.LUT(src_r, red_lookup_table)\n\n    # Put the image back together\n    image_after_matching = cv2.merge([blue_after_transform, green_after_transform, red_after_transform])\n    image_after_matching = cv2.convertScaleAbs(image_after_matching)\n\n    return 
image_after_matching\n\n\ndef _standard_face_pts():\n    pts = (\n        np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32) / 256.0\n        - 1.0\n    )\n\n    return np.reshape(pts, (5, 2))\n\n\ndef _origin_face_pts():\n    pts = np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32)\n\n    return np.reshape(pts, (5, 2))\n\n\ndef compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):\n\n    std_pts = _standard_face_pts()  # [-1,1]\n    target_pts = (std_pts * target_face_scale + 1) / 2 * 256.0\n\n    # print(target_pts)\n\n    h, w, c = img.shape\n    if normalize == True:\n        landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0\n        landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0\n\n    # print(landmark)\n\n    affine = SimilarityTransform()\n\n    affine.estimate(target_pts, landmark)\n\n    return affine\n\n\ndef compute_inverse_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):\n\n    std_pts = _standard_face_pts()  # [-1,1]\n    target_pts = (std_pts * target_face_scale + 1) / 2 * 256.0\n\n    # print(target_pts)\n\n    h, w, c = img.shape\n    if normalize == True:\n        landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0\n        landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0\n\n    # print(landmark)\n\n    affine = SimilarityTransform()\n\n    affine.estimate(landmark, target_pts)\n\n    return affine\n\n\ndef show_detection(image, box, landmark):\n    plt.imshow(image)\n    print(box[2] - box[0])\n    plt.gca().add_patch(\n        Rectangle(\n            (box[1], box[0]), box[2] - box[0], box[3] - box[1], linewidth=1, edgecolor=\"r\", facecolor=\"none\"\n        )\n    )\n    plt.scatter(landmark[0][0], landmark[0][1])\n    plt.scatter(landmark[1][0], landmark[1][1])\n    plt.scatter(landmark[2][0], landmark[2][1])\n    plt.scatter(landmark[3][0], landmark[3][1])\n    plt.scatter(landmark[4][0], landmark[4][1])\n    
plt.show()\n\n\ndef affine2theta(affine, input_w, input_h, target_w, target_h):\n    # param = np.linalg.inv(affine)\n    param = affine\n    theta = np.zeros([2, 3])\n    theta[0, 0] = param[0, 0] * input_h / target_h\n    theta[0, 1] = param[0, 1] * input_w / target_h\n    theta[0, 2] = (2 * param[0, 2] + param[0, 0] * input_h + param[0, 1] * input_w) / target_h - 1\n    theta[1, 0] = param[1, 0] * input_h / target_w\n    theta[1, 1] = param[1, 1] * input_w / target_w\n    theta[1, 2] = (2 * param[1, 2] + param[1, 0] * input_h + param[1, 1] * input_w) / target_w - 1\n    return theta\n\n\ndef blur_blending(im1, im2, mask):\n\n    mask *= 255.0\n\n    kernel = np.ones((10, 10), np.uint8)\n    mask = cv2.erode(mask, kernel, iterations=1)\n\n    mask = Image.fromarray(mask.astype(\"uint8\")).convert(\"L\")\n    im1 = Image.fromarray(im1.astype(\"uint8\"))\n    im2 = Image.fromarray(im2.astype(\"uint8\"))\n\n    mask_blur = mask.filter(ImageFilter.GaussianBlur(20))\n    im = Image.composite(im1, im2, mask)\n\n    im = Image.composite(im, im2, mask_blur)\n\n    return np.array(im) / 255.0\n\n\ndef blur_blending_cv2(im1, im2, mask):\n\n    mask *= 255.0\n\n    kernel = np.ones((9, 9), np.uint8)\n    mask = cv2.erode(mask, kernel, iterations=3)\n\n    mask_blur = cv2.GaussianBlur(mask, (25, 25), 0)\n    mask_blur /= 255.0\n\n    im = im1 * mask_blur + (1 - mask_blur) * im2\n\n    im /= 255.0\n    im = np.clip(im, 0.0, 1.0)\n\n    return im\n\n\n# def Poisson_blending(im1,im2,mask):\n\n\n#     Image.composite(\ndef Poisson_blending(im1, im2, mask):\n\n    # mask=1-mask\n    mask *= 255\n    kernel = np.ones((10, 10), np.uint8)\n    mask = cv2.erode(mask, kernel, iterations=1)\n    mask /= 255\n    mask = 1 - mask\n    mask *= 255\n\n    mask = mask[:, :, 0]\n    width, height, channels = im1.shape\n    center = (int(height / 2), int(width / 2))\n    result = cv2.seamlessClone(\n        im2.astype(\"uint8\"), im1.astype(\"uint8\"), mask.astype(\"uint8\"), center, 
cv2.MIXED_CLONE\n    )\n\n    return result / 255.0\n\n\ndef Poisson_B(im1, im2, mask, center):\n\n    mask *= 255\n\n    result = cv2.seamlessClone(\n        im2.astype(\"uint8\"), im1.astype(\"uint8\"), mask.astype(\"uint8\"), center, cv2.NORMAL_CLONE\n    )\n\n    return result / 255\n\n\ndef seamless_clone(old_face, new_face, raw_mask):\n\n    height, width, _ = old_face.shape\n    height = height // 2\n    width = width // 2\n\n    y_indices, x_indices, _ = np.nonzero(raw_mask)\n    y_crop = slice(np.min(y_indices), np.max(y_indices))\n    x_crop = slice(np.min(x_indices), np.max(x_indices))\n    y_center = int(np.rint((np.max(y_indices) + np.min(y_indices)) / 2 + height))\n    x_center = int(np.rint((np.max(x_indices) + np.min(x_indices)) / 2 + width))\n\n    insertion = np.rint(new_face[y_crop, x_crop] * 255.0).astype(\"uint8\")\n    insertion_mask = np.rint(raw_mask[y_crop, x_crop] * 255.0).astype(\"uint8\")\n    insertion_mask[insertion_mask != 0] = 255\n    prior = np.rint(np.pad(old_face * 255.0, ((height, height), (width, width), (0, 0)), \"constant\")).astype(\n        \"uint8\"\n    )\n    # if np.sum(insertion_mask) == 0:\n    n_mask = insertion_mask[1:-1, 1:-1, :]\n    n_mask = cv2.copyMakeBorder(n_mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, 0)\n    print(n_mask.shape)\n    x, y, w, h = cv2.boundingRect(n_mask[:, :, 0])\n    if w < 4 or h < 4:\n        blended = prior\n    else:\n        blended = cv2.seamlessClone(\n            insertion,  # pylint: disable=no-member\n            prior,\n            insertion_mask,\n            (x_center, y_center),\n            cv2.NORMAL_CLONE,\n        )  # pylint: disable=no-member\n\n    blended = blended[height:-height, width:-width]\n\n    return blended.astype(\"float32\") / 255.0\n\n\ndef get_landmark(face_landmarks, id):\n    part = face_landmarks.part(id)\n    x = part.x\n    y = part.y\n\n    return (x, y)\n\n\ndef search(face_landmarks):\n\n    x1, y1 = get_landmark(face_landmarks, 36)\n    x2, y2 = 
get_landmark(face_landmarks, 39)\n    x3, y3 = get_landmark(face_landmarks, 42)\n    x4, y4 = get_landmark(face_landmarks, 45)\n\n    x_nose, y_nose = get_landmark(face_landmarks, 30)\n\n    x_left_mouth, y_left_mouth = get_landmark(face_landmarks, 48)\n    x_right_mouth, y_right_mouth = get_landmark(face_landmarks, 54)\n\n    x_left_eye = int((x1 + x2) / 2)\n    y_left_eye = int((y1 + y2) / 2)\n    x_right_eye = int((x3 + x4) / 2)\n    y_right_eye = int((y3 + y4) / 2)\n\n    results = np.array(\n        [\n            [x_left_eye, y_left_eye],\n            [x_right_eye, y_right_eye],\n            [x_nose, y_nose],\n            [x_left_mouth, y_left_mouth],\n            [x_right_mouth, y_right_mouth],\n        ]\n    )\n\n    return results\n\n\nif __name__ == \"__main__\":\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--origin_url\", type=str, default=\"./\", help=\"origin images\")\n    parser.add_argument(\"--replace_url\", type=str, default=\"./\", help=\"restored faces\")\n    parser.add_argument(\"--save_url\", type=str, default=\"./save\")\n    opts = parser.parse_args()\n\n    origin_url = opts.origin_url\n    replace_url = opts.replace_url\n    save_url = opts.save_url\n\n    if not os.path.exists(save_url):\n        os.makedirs(save_url)\n\n    face_detector = dlib.get_frontal_face_detector()\n    landmark_locator = dlib.shape_predictor(\"shape_predictor_68_face_landmarks.dat\")\n\n    count = 0\n\n    for x in os.listdir(origin_url):\n        img_url = os.path.join(origin_url, x)\n        pil_img = Image.open(img_url).convert(\"RGB\")\n\n        origin_width, origin_height = pil_img.size\n        image = np.array(pil_img)\n\n        start = time.time()\n        faces = face_detector(image)\n        done = time.time()\n\n        if len(faces) == 0:\n            print(\"Warning: There is no face in %s\" % (x))\n            continue\n\n        blended = image\n        for face_id in range(len(faces)):\n\n            current_face = 
faces[face_id]\n            face_landmarks = landmark_locator(image, current_face)\n            current_fl = search(face_landmarks)\n\n            forward_mask = np.ones_like(image).astype(\"uint8\")\n            affine = compute_transformation_matrix(image, current_fl, False, target_face_scale=1.3)\n            aligned_face = warp(image, affine, output_shape=(256, 256, 3), preserve_range=True)\n            forward_mask = warp(\n                forward_mask, affine, output_shape=(256, 256, 3), order=0, preserve_range=True\n            )\n\n            affine_inverse = affine.inverse\n            cur_face = aligned_face\n            if replace_url != \"\":\n\n                face_name = x[:-4] + \"_\" + str(face_id + 1) + \".png\"\n                cur_url = os.path.join(replace_url, face_name)\n                restored_face = Image.open(cur_url).convert(\"RGB\")\n                restored_face = np.array(restored_face)\n                cur_face = restored_face\n\n            ## Histogram Color matching\n            A = cv2.cvtColor(aligned_face.astype(\"uint8\"), cv2.COLOR_RGB2BGR)\n            B = cv2.cvtColor(cur_face.astype(\"uint8\"), cv2.COLOR_RGB2BGR)\n            B = match_histograms(B, A)\n            cur_face = cv2.cvtColor(B.astype(\"uint8\"), cv2.COLOR_BGR2RGB)\n\n            warped_back = warp(\n                cur_face,\n                affine_inverse,\n                output_shape=(origin_height, origin_width, 3),\n                order=3,\n                preserve_range=True,\n            )\n\n            backward_mask = warp(\n                forward_mask,\n                affine_inverse,\n                output_shape=(origin_height, origin_width, 3),\n                order=0,\n                preserve_range=True,\n            )  ## Nearest neighbour\n\n            blended = blur_blending_cv2(warped_back, blended, backward_mask)\n            blended *= 255.0\n\n        io.imsave(os.path.join(save_url, x), img_as_ubyte(blended / 255.0))\n\n        
count += 1\n\n        if count % 1000 == 0:\n            print(\"%d have finished ...\" % (count))\n\n"
  },
  {
    "path": "Face_Detection/align_warp_back_multiple_dlib_HR.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch\nimport numpy as np\nimport skimage.io as io\n\n# from face_sdk import FaceDetection\nimport matplotlib.pyplot as plt\nfrom matplotlib.patches import Rectangle\nfrom skimage.transform import SimilarityTransform\nfrom skimage.transform import warp\nfrom PIL import Image, ImageFilter\nimport torch.nn.functional as F\nimport torchvision as tv\nimport torchvision.utils as vutils\nimport time\nimport cv2\nimport os\nfrom skimage import img_as_ubyte\nimport json\nimport argparse\nimport dlib\n\n\ndef calculate_cdf(histogram):\n    \"\"\"\n    This method calculates the cumulative distribution function\n    :param array histogram: The values of the histogram\n    :return: normalized_cdf: The normalized cumulative distribution function\n    :rtype: array\n    \"\"\"\n    # Get the cumulative sum of the elements\n    cdf = histogram.cumsum()\n\n    # Normalize the cdf\n    normalized_cdf = cdf / float(cdf.max())\n\n    return normalized_cdf\n\n\ndef calculate_lookup(src_cdf, ref_cdf):\n    \"\"\"\n    This method creates the lookup table\n    :param array src_cdf: The cdf for the source image\n    :param array ref_cdf: The cdf for the reference image\n    :return: lookup_table: The lookup table\n    :rtype: array\n    \"\"\"\n    lookup_table = np.zeros(256)\n    lookup_val = 0\n    for src_pixel_val in range(len(src_cdf)):\n        lookup_val\n        for ref_pixel_val in range(len(ref_cdf)):\n            if ref_cdf[ref_pixel_val] >= src_cdf[src_pixel_val]:\n                lookup_val = ref_pixel_val\n                break\n        lookup_table[src_pixel_val] = lookup_val\n    return lookup_table\n\n\ndef match_histograms(src_image, ref_image):\n    \"\"\"\n    This method matches the source image histogram to the\n    reference signal\n    :param image src_image: The original source image\n    :param image  ref_image: The reference image\n    :return: 
image_after_matching\n    :rtype: image (array)\n    \"\"\"\n    # Split the images into the different color channels\n    # b means blue, g means green and r means red\n    src_b, src_g, src_r = cv2.split(src_image)\n    ref_b, ref_g, ref_r = cv2.split(ref_image)\n\n    # Compute the b, g, and r histograms separately\n    # The flatten() Numpy method returns a copy of the array c\n    # collapsed into one dimension.\n    src_hist_blue, bin_0 = np.histogram(src_b.flatten(), 256, [0, 256])\n    src_hist_green, bin_1 = np.histogram(src_g.flatten(), 256, [0, 256])\n    src_hist_red, bin_2 = np.histogram(src_r.flatten(), 256, [0, 256])\n    ref_hist_blue, bin_3 = np.histogram(ref_b.flatten(), 256, [0, 256])\n    ref_hist_green, bin_4 = np.histogram(ref_g.flatten(), 256, [0, 256])\n    ref_hist_red, bin_5 = np.histogram(ref_r.flatten(), 256, [0, 256])\n\n    # Compute the normalized cdf for the source and reference image\n    src_cdf_blue = calculate_cdf(src_hist_blue)\n    src_cdf_green = calculate_cdf(src_hist_green)\n    src_cdf_red = calculate_cdf(src_hist_red)\n    ref_cdf_blue = calculate_cdf(ref_hist_blue)\n    ref_cdf_green = calculate_cdf(ref_hist_green)\n    ref_cdf_red = calculate_cdf(ref_hist_red)\n\n    # Make a separate lookup table for each color\n    blue_lookup_table = calculate_lookup(src_cdf_blue, ref_cdf_blue)\n    green_lookup_table = calculate_lookup(src_cdf_green, ref_cdf_green)\n    red_lookup_table = calculate_lookup(src_cdf_red, ref_cdf_red)\n\n    # Use the lookup function to transform the colors of the original\n    # source image\n    blue_after_transform = cv2.LUT(src_b, blue_lookup_table)\n    green_after_transform = cv2.LUT(src_g, green_lookup_table)\n    red_after_transform = cv2.LUT(src_r, red_lookup_table)\n\n    # Put the image back together\n    image_after_matching = cv2.merge([blue_after_transform, green_after_transform, red_after_transform])\n    image_after_matching = cv2.convertScaleAbs(image_after_matching)\n\n    return 
image_after_matching\n\n\ndef _standard_face_pts():\n    pts = (\n        np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32) / 256.0\n        - 1.0\n    )\n\n    return np.reshape(pts, (5, 2))\n\n\ndef _origin_face_pts():\n    pts = np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32)\n\n    return np.reshape(pts, (5, 2))\n\n\ndef compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):\n\n    std_pts = _standard_face_pts()  # [-1,1]\n    target_pts = (std_pts * target_face_scale + 1) / 2 * 512.0\n\n    # print(target_pts)\n\n    h, w, c = img.shape\n    if normalize == True:\n        landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0\n        landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0\n\n    # print(landmark)\n\n    affine = SimilarityTransform()\n\n    affine.estimate(target_pts, landmark)\n\n    return affine\n\n\ndef compute_inverse_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):\n\n    std_pts = _standard_face_pts()  # [-1,1]\n    target_pts = (std_pts * target_face_scale + 1) / 2 * 512.0\n\n    # print(target_pts)\n\n    h, w, c = img.shape\n    if normalize == True:\n        landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0\n        landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0\n\n    # print(landmark)\n\n    affine = SimilarityTransform()\n\n    affine.estimate(landmark, target_pts)\n\n    return affine\n\n\ndef show_detection(image, box, landmark):\n    plt.imshow(image)\n    print(box[2] - box[0])\n    plt.gca().add_patch(\n        Rectangle(\n            (box[1], box[0]), box[2] - box[0], box[3] - box[1], linewidth=1, edgecolor=\"r\", facecolor=\"none\"\n        )\n    )\n    plt.scatter(landmark[0][0], landmark[0][1])\n    plt.scatter(landmark[1][0], landmark[1][1])\n    plt.scatter(landmark[2][0], landmark[2][1])\n    plt.scatter(landmark[3][0], landmark[3][1])\n    plt.scatter(landmark[4][0], landmark[4][1])\n    
plt.show()\n\n\ndef affine2theta(affine, input_w, input_h, target_w, target_h):\n    # param = np.linalg.inv(affine)\n    param = affine\n    theta = np.zeros([2, 3])\n    theta[0, 0] = param[0, 0] * input_h / target_h\n    theta[0, 1] = param[0, 1] * input_w / target_h\n    theta[0, 2] = (2 * param[0, 2] + param[0, 0] * input_h + param[0, 1] * input_w) / target_h - 1\n    theta[1, 0] = param[1, 0] * input_h / target_w\n    theta[1, 1] = param[1, 1] * input_w / target_w\n    theta[1, 2] = (2 * param[1, 2] + param[1, 0] * input_h + param[1, 1] * input_w) / target_w - 1\n    return theta\n\n\ndef blur_blending(im1, im2, mask):\n\n    mask *= 255.0\n\n    kernel = np.ones((10, 10), np.uint8)\n    mask = cv2.erode(mask, kernel, iterations=1)\n\n    mask = Image.fromarray(mask.astype(\"uint8\")).convert(\"L\")\n    im1 = Image.fromarray(im1.astype(\"uint8\"))\n    im2 = Image.fromarray(im2.astype(\"uint8\"))\n\n    mask_blur = mask.filter(ImageFilter.GaussianBlur(20))\n    im = Image.composite(im1, im2, mask)\n\n    im = Image.composite(im, im2, mask_blur)\n\n    return np.array(im) / 255.0\n\n\ndef blur_blending_cv2(im1, im2, mask):\n\n    mask *= 255.0\n\n    kernel = np.ones((9, 9), np.uint8)\n    mask = cv2.erode(mask, kernel, iterations=3)\n\n    mask_blur = cv2.GaussianBlur(mask, (25, 25), 0)\n    mask_blur /= 255.0\n\n    im = im1 * mask_blur + (1 - mask_blur) * im2\n\n    im /= 255.0\n    im = np.clip(im, 0.0, 1.0)\n\n    return im\n\n\n# def Poisson_blending(im1,im2,mask):\n\n\n#     Image.composite(\ndef Poisson_blending(im1, im2, mask):\n\n    # mask=1-mask\n    mask *= 255\n    kernel = np.ones((10, 10), np.uint8)\n    mask = cv2.erode(mask, kernel, iterations=1)\n    mask /= 255\n    mask = 1 - mask\n    mask *= 255\n\n    mask = mask[:, :, 0]\n    width, height, channels = im1.shape\n    center = (int(height / 2), int(width / 2))\n    result = cv2.seamlessClone(\n        im2.astype(\"uint8\"), im1.astype(\"uint8\"), mask.astype(\"uint8\"), center, 
cv2.MIXED_CLONE\n    )\n\n    return result / 255.0\n\n\ndef Poisson_B(im1, im2, mask, center):\n\n    mask *= 255\n\n    result = cv2.seamlessClone(\n        im2.astype(\"uint8\"), im1.astype(\"uint8\"), mask.astype(\"uint8\"), center, cv2.NORMAL_CLONE\n    )\n\n    return result / 255\n\n\ndef seamless_clone(old_face, new_face, raw_mask):\n\n    height, width, _ = old_face.shape\n    height = height // 2\n    width = width // 2\n\n    y_indices, x_indices, _ = np.nonzero(raw_mask)\n    y_crop = slice(np.min(y_indices), np.max(y_indices))\n    x_crop = slice(np.min(x_indices), np.max(x_indices))\n    y_center = int(np.rint((np.max(y_indices) + np.min(y_indices)) / 2 + height))\n    x_center = int(np.rint((np.max(x_indices) + np.min(x_indices)) / 2 + width))\n\n    insertion = np.rint(new_face[y_crop, x_crop] * 255.0).astype(\"uint8\")\n    insertion_mask = np.rint(raw_mask[y_crop, x_crop] * 255.0).astype(\"uint8\")\n    insertion_mask[insertion_mask != 0] = 255\n    prior = np.rint(np.pad(old_face * 255.0, ((height, height), (width, width), (0, 0)), \"constant\")).astype(\n        \"uint8\"\n    )\n    # if np.sum(insertion_mask) == 0:\n    n_mask = insertion_mask[1:-1, 1:-1, :]\n    n_mask = cv2.copyMakeBorder(n_mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, 0)\n    print(n_mask.shape)\n    x, y, w, h = cv2.boundingRect(n_mask[:, :, 0])\n    if w < 4 or h < 4:\n        blended = prior\n    else:\n        blended = cv2.seamlessClone(\n            insertion,  # pylint: disable=no-member\n            prior,\n            insertion_mask,\n            (x_center, y_center),\n            cv2.NORMAL_CLONE,\n        )  # pylint: disable=no-member\n\n    blended = blended[height:-height, width:-width]\n\n    return blended.astype(\"float32\") / 255.0\n\n\ndef get_landmark(face_landmarks, id):\n    part = face_landmarks.part(id)\n    x = part.x\n    y = part.y\n\n    return (x, y)\n\n\ndef search(face_landmarks):\n\n    x1, y1 = get_landmark(face_landmarks, 36)\n    x2, y2 = 
get_landmark(face_landmarks, 39)\n    x3, y3 = get_landmark(face_landmarks, 42)\n    x4, y4 = get_landmark(face_landmarks, 45)\n\n    x_nose, y_nose = get_landmark(face_landmarks, 30)\n\n    x_left_mouth, y_left_mouth = get_landmark(face_landmarks, 48)\n    x_right_mouth, y_right_mouth = get_landmark(face_landmarks, 54)\n\n    x_left_eye = int((x1 + x2) / 2)\n    y_left_eye = int((y1 + y2) / 2)\n    x_right_eye = int((x3 + x4) / 2)\n    y_right_eye = int((y3 + y4) / 2)\n\n    results = np.array(\n        [\n            [x_left_eye, y_left_eye],\n            [x_right_eye, y_right_eye],\n            [x_nose, y_nose],\n            [x_left_mouth, y_left_mouth],\n            [x_right_mouth, y_right_mouth],\n        ]\n    )\n\n    return results\n\n\nif __name__ == \"__main__\":\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--origin_url\", type=str, default=\"./\", help=\"origin images\")\n    parser.add_argument(\"--replace_url\", type=str, default=\"./\", help=\"restored faces\")\n    parser.add_argument(\"--save_url\", type=str, default=\"./save\")\n    opts = parser.parse_args()\n\n    origin_url = opts.origin_url\n    replace_url = opts.replace_url\n    save_url = opts.save_url\n\n    if not os.path.exists(save_url):\n        os.makedirs(save_url)\n\n    face_detector = dlib.get_frontal_face_detector()\n    landmark_locator = dlib.shape_predictor(\"shape_predictor_68_face_landmarks.dat\")\n\n    count = 0\n\n    for x in os.listdir(origin_url):\n        img_url = os.path.join(origin_url, x)\n        pil_img = Image.open(img_url).convert(\"RGB\")\n\n        origin_width, origin_height = pil_img.size\n        image = np.array(pil_img)\n\n        start = time.time()\n        faces = face_detector(image)\n        done = time.time()\n\n        if len(faces) == 0:\n            print(\"Warning: There is no face in %s\" % (x))\n            continue\n\n        blended = image\n        for face_id in range(len(faces)):\n\n            current_face = 
faces[face_id]\n            face_landmarks = landmark_locator(image, current_face)\n            current_fl = search(face_landmarks)\n\n            forward_mask = np.ones_like(image).astype(\"uint8\")\n            affine = compute_transformation_matrix(image, current_fl, False, target_face_scale=1.3)\n            aligned_face = warp(image, affine, output_shape=(512, 512, 3), preserve_range=True)\n            forward_mask = warp(\n                forward_mask, affine, output_shape=(512, 512, 3), order=0, preserve_range=True\n            )\n\n            affine_inverse = affine.inverse\n            cur_face = aligned_face\n            if replace_url != \"\":\n\n                face_name = x[:-4] + \"_\" + str(face_id + 1) + \".png\"\n                cur_url = os.path.join(replace_url, face_name)\n                restored_face = Image.open(cur_url).convert(\"RGB\")\n                restored_face = np.array(restored_face)\n                cur_face = restored_face\n\n            ## Histogram Color matching\n            A = cv2.cvtColor(aligned_face.astype(\"uint8\"), cv2.COLOR_RGB2BGR)\n            B = cv2.cvtColor(cur_face.astype(\"uint8\"), cv2.COLOR_RGB2BGR)\n            B = match_histograms(B, A)\n            cur_face = cv2.cvtColor(B.astype(\"uint8\"), cv2.COLOR_BGR2RGB)\n\n            warped_back = warp(\n                cur_face,\n                affine_inverse,\n                output_shape=(origin_height, origin_width, 3),\n                order=3,\n                preserve_range=True,\n            )\n\n            backward_mask = warp(\n                forward_mask,\n                affine_inverse,\n                output_shape=(origin_height, origin_width, 3),\n                order=0,\n                preserve_range=True,\n            )  ## Nearest neighbour\n\n            blended = blur_blending_cv2(warped_back, blended, backward_mask)\n            blended *= 255.0\n\n        io.imsave(os.path.join(save_url, x), img_as_ubyte(blended / 255.0))\n\n        
count += 1\n\n        if count % 1000 == 0:\n            print(\"%d have finished ...\" % (count))\n\n"
  },
  {
    "path": "Face_Detection/detect_all_dlib.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch\nimport numpy as np\nimport skimage.io as io\n\n# from FaceSDK.face_sdk import FaceDetection\n# from face_sdk import FaceDetection\nimport matplotlib.pyplot as plt\nfrom matplotlib.patches import Rectangle\nfrom skimage.transform import SimilarityTransform\nfrom skimage.transform import warp\nfrom PIL import Image\nimport torch.nn.functional as F\nimport torchvision as tv\nimport torchvision.utils as vutils\nimport time\nimport cv2\nimport os\nfrom skimage import img_as_ubyte\nimport json\nimport argparse\nimport dlib\n\n\ndef _standard_face_pts():\n    pts = (\n        np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32) / 256.0\n        - 1.0\n    )\n\n    return np.reshape(pts, (5, 2))\n\n\ndef _origin_face_pts():\n    pts = np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32)\n\n    return np.reshape(pts, (5, 2))\n\n\ndef get_landmark(face_landmarks, id):\n    part = face_landmarks.part(id)\n    x = part.x\n    y = part.y\n\n    return (x, y)\n\n\ndef search(face_landmarks):\n\n    x1, y1 = get_landmark(face_landmarks, 36)\n    x2, y2 = get_landmark(face_landmarks, 39)\n    x3, y3 = get_landmark(face_landmarks, 42)\n    x4, y4 = get_landmark(face_landmarks, 45)\n\n    x_nose, y_nose = get_landmark(face_landmarks, 30)\n\n    x_left_mouth, y_left_mouth = get_landmark(face_landmarks, 48)\n    x_right_mouth, y_right_mouth = get_landmark(face_landmarks, 54)\n\n    x_left_eye = int((x1 + x2) / 2)\n    y_left_eye = int((y1 + y2) / 2)\n    x_right_eye = int((x3 + x4) / 2)\n    y_right_eye = int((y3 + y4) / 2)\n\n    results = np.array(\n        [\n            [x_left_eye, y_left_eye],\n            [x_right_eye, y_right_eye],\n            [x_nose, y_nose],\n            [x_left_mouth, y_left_mouth],\n            [x_right_mouth, y_right_mouth],\n        ]\n    )\n\n    return 
results\n\n\ndef compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):\n\n    std_pts = _standard_face_pts()  # [-1,1]\n    target_pts = (std_pts * target_face_scale + 1) / 2 * 256.0\n\n    # print(target_pts)\n\n    h, w, c = img.shape\n    if normalize == True:\n        landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0\n        landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0\n\n    # print(landmark)\n\n    affine = SimilarityTransform()\n\n    affine.estimate(target_pts, landmark)\n\n    return affine.params\n\n\ndef show_detection(image, box, landmark):\n    plt.imshow(image)\n    print(box[2] - box[0])\n    plt.gca().add_patch(\n        Rectangle(\n            (box[1], box[0]), box[2] - box[0], box[3] - box[1], linewidth=1, edgecolor=\"r\", facecolor=\"none\"\n        )\n    )\n    plt.scatter(landmark[0][0], landmark[0][1])\n    plt.scatter(landmark[1][0], landmark[1][1])\n    plt.scatter(landmark[2][0], landmark[2][1])\n    plt.scatter(landmark[3][0], landmark[3][1])\n    plt.scatter(landmark[4][0], landmark[4][1])\n    plt.show()\n\n\ndef affine2theta(affine, input_w, input_h, target_w, target_h):\n    # param = np.linalg.inv(affine)\n    param = affine\n    theta = np.zeros([2, 3])\n    theta[0, 0] = param[0, 0] * input_h / target_h\n    theta[0, 1] = param[0, 1] * input_w / target_h\n    theta[0, 2] = (2 * param[0, 2] + param[0, 0] * input_h + param[0, 1] * input_w) / target_h - 1\n    theta[1, 0] = param[1, 0] * input_h / target_w\n    theta[1, 1] = param[1, 1] * input_w / target_w\n    theta[1, 2] = (2 * param[1, 2] + param[1, 0] * input_h + param[1, 1] * input_w) / target_w - 1\n    return theta\n\n\nif __name__ == \"__main__\":\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--url\", type=str, default=\"/home/jingliao/ziyuwan/celebrities\", help=\"input\")\n    parser.add_argument(\n        \"--save_url\", type=str, default=\"/home/jingliao/ziyuwan/celebrities_detected_face_reid\", help=\"output\"\n    )\n    
opts = parser.parse_args()\n\n    url = opts.url\n    save_url = opts.save_url\n\n    ### If the origin url is None, then we don't need to reid the origin image\n\n    os.makedirs(url, exist_ok=True)\n    os.makedirs(save_url, exist_ok=True)\n\n    face_detector = dlib.get_frontal_face_detector()\n    landmark_locator = dlib.shape_predictor(\"shape_predictor_68_face_landmarks.dat\")\n\n    count = 0\n\n    map_id = {}\n    for x in os.listdir(url):\n        img_url = os.path.join(url, x)\n        pil_img = Image.open(img_url).convert(\"RGB\")\n\n        image = np.array(pil_img)\n\n        start = time.time()\n        faces = face_detector(image)\n        done = time.time()\n\n        if len(faces) == 0:\n            print(\"Warning: There is no face in %s\" % (x))\n            continue\n\n        print(len(faces))\n\n        if len(faces) > 0:\n            for face_id in range(len(faces)):\n                current_face = faces[face_id]\n                face_landmarks = landmark_locator(image, current_face)\n                current_fl = search(face_landmarks)\n\n                affine = compute_transformation_matrix(image, current_fl, False, target_face_scale=1.3)\n                aligned_face = warp(image, affine, output_shape=(256, 256, 3))\n                img_name = x[:-4] + \"_\" + str(face_id + 1)\n                io.imsave(os.path.join(save_url, img_name + \".png\"), img_as_ubyte(aligned_face))\n\n        count += 1\n\n        if count % 1000 == 0:\n            print(\"%d have finished ...\" % (count))\n\n"
  },
  {
    "path": "Face_Detection/detect_all_dlib_HR.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch\nimport numpy as np\nimport skimage.io as io\n\n# from FaceSDK.face_sdk import FaceDetection\n# from face_sdk import FaceDetection\nimport matplotlib.pyplot as plt\nfrom matplotlib.patches import Rectangle\nfrom skimage.transform import SimilarityTransform\nfrom skimage.transform import warp\nfrom PIL import Image\nimport torch.nn.functional as F\nimport torchvision as tv\nimport torchvision.utils as vutils\nimport time\nimport cv2\nimport os\nfrom skimage import img_as_ubyte\nimport json\nimport argparse\nimport dlib\n\n\ndef _standard_face_pts():\n    pts = (\n        np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32) / 256.0\n        - 1.0\n    )\n\n    return np.reshape(pts, (5, 2))\n\n\ndef _origin_face_pts():\n    pts = np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32)\n\n    return np.reshape(pts, (5, 2))\n\n\ndef get_landmark(face_landmarks, id):\n    part = face_landmarks.part(id)\n    x = part.x\n    y = part.y\n\n    return (x, y)\n\n\ndef search(face_landmarks):\n\n    x1, y1 = get_landmark(face_landmarks, 36)\n    x2, y2 = get_landmark(face_landmarks, 39)\n    x3, y3 = get_landmark(face_landmarks, 42)\n    x4, y4 = get_landmark(face_landmarks, 45)\n\n    x_nose, y_nose = get_landmark(face_landmarks, 30)\n\n    x_left_mouth, y_left_mouth = get_landmark(face_landmarks, 48)\n    x_right_mouth, y_right_mouth = get_landmark(face_landmarks, 54)\n\n    x_left_eye = int((x1 + x2) / 2)\n    y_left_eye = int((y1 + y2) / 2)\n    x_right_eye = int((x3 + x4) / 2)\n    y_right_eye = int((y3 + y4) / 2)\n\n    results = np.array(\n        [\n            [x_left_eye, y_left_eye],\n            [x_right_eye, y_right_eye],\n            [x_nose, y_nose],\n            [x_left_mouth, y_left_mouth],\n            [x_right_mouth, y_right_mouth],\n        ]\n    )\n\n    return 
results\n\n\ndef compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):\n\n    std_pts = _standard_face_pts()  # [-1,1]\n    target_pts = (std_pts * target_face_scale + 1) / 2 * 512.0\n\n    # print(target_pts)\n\n    h, w, c = img.shape\n    if normalize == True:\n        landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0\n        landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0\n\n    # print(landmark)\n\n    affine = SimilarityTransform()\n\n    affine.estimate(target_pts, landmark)\n\n    return affine.params\n\n\ndef show_detection(image, box, landmark):\n    plt.imshow(image)\n    print(box[2] - box[0])\n    plt.gca().add_patch(\n        Rectangle(\n            (box[1], box[0]), box[2] - box[0], box[3] - box[1], linewidth=1, edgecolor=\"r\", facecolor=\"none\"\n        )\n    )\n    plt.scatter(landmark[0][0], landmark[0][1])\n    plt.scatter(landmark[1][0], landmark[1][1])\n    plt.scatter(landmark[2][0], landmark[2][1])\n    plt.scatter(landmark[3][0], landmark[3][1])\n    plt.scatter(landmark[4][0], landmark[4][1])\n    plt.show()\n\n\ndef affine2theta(affine, input_w, input_h, target_w, target_h):\n    # param = np.linalg.inv(affine)\n    param = affine\n    theta = np.zeros([2, 3])\n    theta[0, 0] = param[0, 0] * input_h / target_h\n    theta[0, 1] = param[0, 1] * input_w / target_h\n    theta[0, 2] = (2 * param[0, 2] + param[0, 0] * input_h + param[0, 1] * input_w) / target_h - 1\n    theta[1, 0] = param[1, 0] * input_h / target_w\n    theta[1, 1] = param[1, 1] * input_w / target_w\n    theta[1, 2] = (2 * param[1, 2] + param[1, 0] * input_h + param[1, 1] * input_w) / target_w - 1\n    return theta\n\n\nif __name__ == \"__main__\":\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--url\", type=str, default=\"/home/jingliao/ziyuwan/celebrities\", help=\"input\")\n    parser.add_argument(\n        \"--save_url\", type=str, default=\"/home/jingliao/ziyuwan/celebrities_detected_face_reid\", help=\"output\"\n    )\n    
opts = parser.parse_args()\n\n    url = opts.url\n    save_url = opts.save_url\n\n    ### If the origin url is None, then we don't need to reid the origin image\n\n    os.makedirs(url, exist_ok=True)\n    os.makedirs(save_url, exist_ok=True)\n\n    face_detector = dlib.get_frontal_face_detector()\n    landmark_locator = dlib.shape_predictor(\"shape_predictor_68_face_landmarks.dat\")\n\n    count = 0\n\n    map_id = {}\n    for x in os.listdir(url):\n        img_url = os.path.join(url, x)\n        pil_img = Image.open(img_url).convert(\"RGB\")\n\n        image = np.array(pil_img)\n\n        start = time.time()\n        faces = face_detector(image)\n        done = time.time()\n\n        if len(faces) == 0:\n            print(\"Warning: There is no face in %s\" % (x))\n            continue\n\n        print(len(faces))\n\n        if len(faces) > 0:\n            for face_id in range(len(faces)):\n                current_face = faces[face_id]\n                face_landmarks = landmark_locator(image, current_face)\n                current_fl = search(face_landmarks)\n\n                affine = compute_transformation_matrix(image, current_fl, False, target_face_scale=1.3)\n                aligned_face = warp(image, affine, output_shape=(512, 512, 3))\n                img_name = x[:-4] + \"_\" + str(face_id + 1)\n                io.imsave(os.path.join(save_url, img_name + \".png\"), img_as_ubyte(aligned_face))\n\n        count += 1\n\n        if count % 1000 == 0:\n            print(\"%d have finished ...\" % (count))\n\n"
  },
  {
    "path": "Face_Enhancement/data/__init__.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport importlib\nimport torch.utils.data\nfrom data.base_dataset import BaseDataset\nfrom data.face_dataset import FaceTestDataset\n\n\ndef create_dataloader(opt):\n\n    instance = FaceTestDataset()\n    instance.initialize(opt)\n    print(\"dataset [%s] of size %d was created\" % (type(instance).__name__, len(instance)))\n    dataloader = torch.utils.data.DataLoader(\n        instance,\n        batch_size=opt.batchSize,\n        shuffle=not opt.serial_batches,\n        num_workers=int(opt.nThreads),\n        drop_last=opt.isTrain,\n    )\n    return dataloader\n"
  },
  {
    "path": "Face_Enhancement/data/base_dataset.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch.utils.data as data\nfrom PIL import Image\nimport torchvision.transforms as transforms\nimport numpy as np\nimport random\n\n\nclass BaseDataset(data.Dataset):\n    def __init__(self):\n        super(BaseDataset, self).__init__()\n\n    @staticmethod\n    def modify_commandline_options(parser, is_train):\n        return parser\n\n    def initialize(self, opt):\n        pass\n\n\ndef get_params(opt, size):\n    w, h = size\n    new_h = h\n    new_w = w\n    if opt.preprocess_mode == \"resize_and_crop\":\n        new_h = new_w = opt.load_size\n    elif opt.preprocess_mode == \"scale_width_and_crop\":\n        new_w = opt.load_size\n        new_h = opt.load_size * h // w\n    elif opt.preprocess_mode == \"scale_shortside_and_crop\":\n        ss, ls = min(w, h), max(w, h)  # shortside and longside\n        width_is_shorter = w == ss\n        ls = int(opt.load_size * ls / ss)\n        new_w, new_h = (ss, ls) if width_is_shorter else (ls, ss)\n\n    x = random.randint(0, np.maximum(0, new_w - opt.crop_size))\n    y = random.randint(0, np.maximum(0, new_h - opt.crop_size))\n\n    flip = random.random() > 0.5\n    return {\"crop_pos\": (x, y), \"flip\": flip}\n\n\ndef get_transform(opt, params, method=Image.BICUBIC, normalize=True, toTensor=True):\n    transform_list = []\n    if \"resize\" in opt.preprocess_mode:\n        osize = [opt.load_size, opt.load_size]\n        transform_list.append(transforms.Resize(osize, interpolation=method))\n    elif \"scale_width\" in opt.preprocess_mode:\n        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))\n    elif \"scale_shortside\" in opt.preprocess_mode:\n        transform_list.append(transforms.Lambda(lambda img: __scale_shortside(img, opt.load_size, method)))\n\n    if \"crop\" in opt.preprocess_mode:\n        transform_list.append(transforms.Lambda(lambda img: __crop(img, 
params[\"crop_pos\"], opt.crop_size)))\n\n    if opt.preprocess_mode == \"none\":\n        base = 32\n        transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))\n\n    if opt.preprocess_mode == \"fixed\":\n        w = opt.crop_size\n        h = round(opt.crop_size / opt.aspect_ratio)\n        transform_list.append(transforms.Lambda(lambda img: __resize(img, w, h, method)))\n\n    if opt.isTrain and not opt.no_flip:\n        transform_list.append(transforms.Lambda(lambda img: __flip(img, params[\"flip\"])))\n\n    if toTensor:\n        transform_list += [transforms.ToTensor()]\n\n    if normalize:\n        transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]\n    return transforms.Compose(transform_list)\n\n\ndef normalize():\n    return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n\n\ndef __resize(img, w, h, method=Image.BICUBIC):\n    return img.resize((w, h), method)\n\n\ndef __make_power_2(img, base, method=Image.BICUBIC):\n    ow, oh = img.size\n    h = int(round(oh / base) * base)\n    w = int(round(ow / base) * base)\n    if (h == oh) and (w == ow):\n        return img\n    return img.resize((w, h), method)\n\n\ndef __scale_width(img, target_width, method=Image.BICUBIC):\n    ow, oh = img.size\n    if ow == target_width:\n        return img\n    w = target_width\n    h = int(target_width * oh / ow)\n    return img.resize((w, h), method)\n\n\ndef __scale_shortside(img, target_width, method=Image.BICUBIC):\n    ow, oh = img.size\n    ss, ls = min(ow, oh), max(ow, oh)  # shortside and longside\n    width_is_shorter = ow == ss\n    if ss == target_width:\n        return img\n    ls = int(target_width * ls / ss)\n    nw, nh = (ss, ls) if width_is_shorter else (ls, ss)\n    return img.resize((nw, nh), method)\n\n\ndef __crop(img, pos, size):\n    ow, oh = img.size\n    x1, y1 = pos\n    tw = th = size\n    return img.crop((x1, y1, x1 + tw, y1 + th))\n\n\ndef __flip(img, flip):\n    if flip:\n     
   return img.transpose(Image.FLIP_LEFT_RIGHT)\n    return img\n"
  },
  {
    "path": "Face_Enhancement/data/custom_dataset.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom data.pix2pix_dataset import Pix2pixDataset\nfrom data.image_folder import make_dataset\n\n\nclass CustomDataset(Pix2pixDataset):\n    \"\"\" Dataset that loads images from directories\n        Use option --label_dir, --image_dir, --instance_dir to specify the directories.\n        The images in the directories are sorted in alphabetical order and paired in order.\n    \"\"\"\n\n    @staticmethod\n    def modify_commandline_options(parser, is_train):\n        parser = Pix2pixDataset.modify_commandline_options(parser, is_train)\n        parser.set_defaults(preprocess_mode=\"resize_and_crop\")\n        load_size = 286 if is_train else 256\n        parser.set_defaults(load_size=load_size)\n        parser.set_defaults(crop_size=256)\n        parser.set_defaults(display_winsize=256)\n        parser.set_defaults(label_nc=13)\n        parser.set_defaults(contain_dontcare_label=False)\n\n        parser.add_argument(\n            \"--label_dir\", type=str, required=True, help=\"path to the directory that contains label images\"\n        )\n        parser.add_argument(\n            \"--image_dir\", type=str, required=True, help=\"path to the directory that contains photo images\"\n        )\n        parser.add_argument(\n            \"--instance_dir\",\n            type=str,\n            default=\"\",\n            help=\"path to the directory that contains instance maps. 
Leave blank if it does not exist\",\n        )\n        return parser\n\n    def get_paths(self, opt):\n        label_dir = opt.label_dir\n        label_paths = make_dataset(label_dir, recursive=False, read_cache=True)\n\n        image_dir = opt.image_dir\n        image_paths = make_dataset(image_dir, recursive=False, read_cache=True)\n\n        if len(opt.instance_dir) > 0:\n            instance_dir = opt.instance_dir\n            instance_paths = make_dataset(instance_dir, recursive=False, read_cache=True)\n        else:\n            instance_paths = []\n\n        assert len(label_paths) == len(\n            image_paths\n        ), \"The #images in %s and %s do not match. Is there something wrong?\"\n\n        return label_paths, image_paths, instance_paths\n"
  },
  {
    "path": "Face_Enhancement/data/face_dataset.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom data.base_dataset import BaseDataset, get_params, get_transform\nfrom PIL import Image\nimport util.util as util\nimport os\nimport torch\n\n\nclass FaceTestDataset(BaseDataset):\n    @staticmethod\n    def modify_commandline_options(parser, is_train):\n        parser.add_argument(\n            \"--no_pairing_check\",\n            action=\"store_true\",\n            help=\"If specified, skip sanity check of correct label-image file pairing\",\n        )\n        #    parser.set_defaults(contain_dontcare_label=False)\n        #    parser.set_defaults(no_instance=True)\n        return parser\n\n    def initialize(self, opt):\n        self.opt = opt\n\n        image_path = os.path.join(opt.dataroot, opt.old_face_folder)\n        label_path = os.path.join(opt.dataroot, opt.old_face_label_folder)\n\n        image_list = os.listdir(image_path)\n        image_list = sorted(image_list)\n        # image_list=image_list[:opt.max_dataset_size]\n\n        self.label_paths = label_path  ## Just the root dir\n        self.image_paths = image_list  ## All the image name\n\n        self.parts = [\n            \"skin\",\n            \"hair\",\n            \"l_brow\",\n            \"r_brow\",\n            \"l_eye\",\n            \"r_eye\",\n            \"eye_g\",\n            \"l_ear\",\n            \"r_ear\",\n            \"ear_r\",\n            \"nose\",\n            \"mouth\",\n            \"u_lip\",\n            \"l_lip\",\n            \"neck\",\n            \"neck_l\",\n            \"cloth\",\n            \"hat\",\n        ]\n\n        size = len(self.image_paths)\n        self.dataset_size = size\n\n    def __getitem__(self, index):\n\n        params = get_params(self.opt, (-1, -1))\n        image_name = self.image_paths[index]\n        image_path = os.path.join(self.opt.dataroot, self.opt.old_face_folder, image_name)\n        image = Image.open(image_path)\n        image = 
image.convert(\"RGB\")\n\n        transform_image = get_transform(self.opt, params)\n        image_tensor = transform_image(image)\n\n        img_name = image_name[:-4]\n        transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)\n        full_label = []\n\n        cnt = 0\n\n        for each_part in self.parts:\n            part_name = img_name + \"_\" + each_part + \".png\"\n            part_url = os.path.join(self.label_paths, part_name)\n\n            if os.path.exists(part_url):\n                label = Image.open(part_url).convert(\"RGB\")\n                label_tensor = transform_label(label)  ## 3 channels and pixel [0,1]\n                full_label.append(label_tensor[0])\n            else:\n                current_part = torch.zeros((self.opt.load_size, self.opt.load_size))\n                full_label.append(current_part)\n                cnt += 1\n\n        full_label_tensor = torch.stack(full_label, 0)\n\n        input_dict = {\n            \"label\": full_label_tensor,\n            \"image\": image_tensor,\n            \"path\": image_path,\n        }\n\n        return input_dict\n\n    def __len__(self):\n        return self.dataset_size\n\n"
  },
  {
    "path": "Face_Enhancement/data/image_folder.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch.utils.data as data\nfrom PIL import Image\nimport os\n\nIMG_EXTENSIONS = [\n    \".jpg\",\n    \".JPG\",\n    \".jpeg\",\n    \".JPEG\",\n    \".png\",\n    \".PNG\",\n    \".ppm\",\n    \".PPM\",\n    \".bmp\",\n    \".BMP\",\n    \".tiff\",\n    \".webp\",\n]\n\n\ndef is_image_file(filename):\n    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)\n\n\ndef make_dataset_rec(dir, images):\n    assert os.path.isdir(dir), \"%s is not a valid directory\" % dir\n\n    for root, dnames, fnames in sorted(os.walk(dir, followlinks=True)):\n        for fname in fnames:\n            if is_image_file(fname):\n                path = os.path.join(root, fname)\n                images.append(path)\n\n\ndef make_dataset(dir, recursive=False, read_cache=False, write_cache=False):\n    images = []\n\n    if read_cache:\n        possible_filelist = os.path.join(dir, \"files.list\")\n        if os.path.isfile(possible_filelist):\n            with open(possible_filelist, \"r\") as f:\n                images = f.read().splitlines()\n                return images\n\n    if recursive:\n        make_dataset_rec(dir, images)\n    else:\n        assert os.path.isdir(dir) or os.path.islink(dir), \"%s is not a valid directory\" % dir\n\n        for root, dnames, fnames in sorted(os.walk(dir)):\n            for fname in fnames:\n                if is_image_file(fname):\n                    path = os.path.join(root, fname)\n                    images.append(path)\n\n    if write_cache:\n        filelist_cache = os.path.join(dir, \"files.list\")\n        with open(filelist_cache, \"w\") as f:\n            for path in images:\n                f.write(\"%s\\n\" % path)\n            print(\"wrote filelist cache at %s\" % filelist_cache)\n\n    return images\n\n\ndef default_loader(path):\n    return Image.open(path).convert(\"RGB\")\n\n\nclass ImageFolder(data.Dataset):\n   
 def __init__(self, root, transform=None, return_paths=False, loader=default_loader):\n        imgs = make_dataset(root)\n        if len(imgs) == 0:\n            raise (\n                RuntimeError(\n                    \"Found 0 images in: \" + root + \"\\n\"\n                    \"Supported image extensions are: \" + \",\".join(IMG_EXTENSIONS)\n                )\n            )\n\n        self.root = root\n        self.imgs = imgs\n        self.transform = transform\n        self.return_paths = return_paths\n        self.loader = loader\n\n    def __getitem__(self, index):\n        path = self.imgs[index]\n        img = self.loader(path)\n        if self.transform is not None:\n            img = self.transform(img)\n        if self.return_paths:\n            return img, path\n        else:\n            return img\n\n    def __len__(self):\n        return len(self.imgs)\n"
  },
  {
    "path": "Face_Enhancement/data/pix2pix_dataset.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom data.base_dataset import BaseDataset, get_params, get_transform\nfrom PIL import Image\nimport util.util as util\nimport os\n\n\nclass Pix2pixDataset(BaseDataset):\n    @staticmethod\n    def modify_commandline_options(parser, is_train):\n        parser.add_argument(\n            \"--no_pairing_check\",\n            action=\"store_true\",\n            help=\"If specified, skip sanity check of correct label-image file pairing\",\n        )\n        return parser\n\n    def initialize(self, opt):\n        self.opt = opt\n\n        label_paths, image_paths, instance_paths = self.get_paths(opt)\n\n        util.natural_sort(label_paths)\n        util.natural_sort(image_paths)\n        if not opt.no_instance:\n            util.natural_sort(instance_paths)\n\n        label_paths = label_paths[: opt.max_dataset_size]\n        image_paths = image_paths[: opt.max_dataset_size]\n        instance_paths = instance_paths[: opt.max_dataset_size]\n\n        if not opt.no_pairing_check:\n            for path1, path2 in zip(label_paths, image_paths):\n                assert self.paths_match(path1, path2), (\n                    \"The label-image pair (%s, %s) do not look like the right pair because the filenames are quite different. Are you sure about the pairing? 
Please see data/pix2pix_dataset.py to see what is going on, and use --no_pairing_check to bypass this.\"\n                    % (path1, path2)\n                )\n\n        self.label_paths = label_paths\n        self.image_paths = image_paths\n        self.instance_paths = instance_paths\n\n        size = len(self.label_paths)\n        self.dataset_size = size\n\n    def get_paths(self, opt):\n        label_paths = []\n        image_paths = []\n        instance_paths = []\n        assert False, \"A subclass of Pix2pixDataset must override self.get_paths(self, opt)\"\n        return label_paths, image_paths, instance_paths\n\n    def paths_match(self, path1, path2):\n        filename1_without_ext = os.path.splitext(os.path.basename(path1))[0]\n        filename2_without_ext = os.path.splitext(os.path.basename(path2))[0]\n        return filename1_without_ext == filename2_without_ext\n\n    def __getitem__(self, index):\n        # Label Image\n        label_path = self.label_paths[index]\n        label = Image.open(label_path)\n        params = get_params(self.opt, label.size)\n        transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)\n        label_tensor = transform_label(label) * 255.0\n        label_tensor[label_tensor == 255] = self.opt.label_nc  # 'unknown' is opt.label_nc\n\n        # input image (real images)\n        image_path = self.image_paths[index]\n        assert self.paths_match(\n            label_path, image_path\n        ), \"The label_path %s and image_path %s don't match.\" % (label_path, image_path)\n        image = Image.open(image_path)\n        image = image.convert(\"RGB\")\n\n        transform_image = get_transform(self.opt, params)\n        image_tensor = transform_image(image)\n\n        # if using instance maps\n        if self.opt.no_instance:\n            instance_tensor = 0\n        else:\n            instance_path = self.instance_paths[index]\n            instance = Image.open(instance_path)\n    
        if instance.mode == \"L\":\n                instance_tensor = transform_label(instance) * 255\n                instance_tensor = instance_tensor.long()\n            else:\n                instance_tensor = transform_label(instance)\n\n        input_dict = {\n            \"label\": label_tensor,\n            \"instance\": instance_tensor,\n            \"image\": image_tensor,\n            \"path\": image_path,\n        }\n\n        # Give subclasses a chance to modify the final output\n        self.postprocess(input_dict)\n\n        return input_dict\n\n    def postprocess(self, input_dict):\n        return input_dict\n\n    def __len__(self):\n        return self.dataset_size\n"
  },
  {
    "path": "Face_Enhancement/models/__init__.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport importlib\nimport torch\n\n\ndef find_model_using_name(model_name):\n    # Given the option --model [modelname],\n    # the file \"models/modelname_model.py\"\n    # will be imported.\n    model_filename = \"models.\" + model_name + \"_model\"\n    modellib = importlib.import_module(model_filename)\n\n    # In the file, the class called ModelNameModel() will\n    # be instantiated. It has to be a subclass of torch.nn.Module,\n    # and it is case-insensitive.\n    model = None\n    target_model_name = model_name.replace(\"_\", \"\") + \"model\"\n    for name, cls in modellib.__dict__.items():\n        if name.lower() == target_model_name.lower() and issubclass(cls, torch.nn.Module):\n            model = cls\n\n    if model is None:\n        print(\n            \"In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s in lowercase.\"\n            % (model_filename, target_model_name)\n        )\n        exit(0)\n\n    return model\n\n\ndef get_option_setter(model_name):\n    model_class = find_model_using_name(model_name)\n    return model_class.modify_commandline_options\n\n\ndef create_model(opt):\n    model = find_model_using_name(opt.model)\n    instance = model(opt)\n    print(\"model [%s] was created\" % (type(instance).__name__))\n\n    return instance\n"
  },
  {
    "path": "Face_Enhancement/models/networks/__init__.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch\nfrom models.networks.base_network import BaseNetwork\nfrom models.networks.generator import *\nfrom models.networks.encoder import *\nimport util.util as util\n\n\ndef find_network_using_name(target_network_name, filename):\n    target_class_name = target_network_name + filename\n    module_name = \"models.networks.\" + filename\n    network = util.find_class_in_module(target_class_name, module_name)\n\n    assert issubclass(network, BaseNetwork), \"Class %s should be a subclass of BaseNetwork\" % network\n\n    return network\n\n\ndef modify_commandline_options(parser, is_train):\n    opt, _ = parser.parse_known_args()\n\n    netG_cls = find_network_using_name(opt.netG, \"generator\")\n    parser = netG_cls.modify_commandline_options(parser, is_train)\n    if is_train:\n        netD_cls = find_network_using_name(opt.netD, \"discriminator\")\n        parser = netD_cls.modify_commandline_options(parser, is_train)\n    netE_cls = find_network_using_name(\"conv\", \"encoder\")\n    parser = netE_cls.modify_commandline_options(parser, is_train)\n\n    return parser\n\n\ndef create_network(cls, opt):\n    net = cls(opt)\n    net.print_network()\n    if len(opt.gpu_ids) > 0:\n        assert torch.cuda.is_available()\n        net.cuda()\n    net.init_weights(opt.init_type, opt.init_variance)\n    return net\n\n\ndef define_G(opt):\n    netG_cls = find_network_using_name(opt.netG, \"generator\")\n    return create_network(netG_cls, opt)\n\n\ndef define_D(opt):\n    netD_cls = find_network_using_name(opt.netD, \"discriminator\")\n    return create_network(netD_cls, opt)\n\n\ndef define_E(opt):\n    # there exists only one encoder type\n    netE_cls = find_network_using_name(\"conv\", \"encoder\")\n    return create_network(netE_cls, opt)\n"
  },
  {
    "path": "Face_Enhancement/models/networks/architecture.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchvision\nimport torch.nn.utils.spectral_norm as spectral_norm\nfrom models.networks.normalization import SPADE\n\n\n# ResNet block that uses SPADE.\n# It differs from the ResNet block of pix2pixHD in that\n# it takes in the segmentation map as input, learns the skip connection if necessary,\n# and applies normalization first and then convolution.\n# This architecture seemed like a standard architecture for unconditional or\n# class-conditional GAN architecture using residual block.\n# The code was inspired from https://github.com/LMescheder/GAN_stability.\nclass SPADEResnetBlock(nn.Module):\n    def __init__(self, fin, fout, opt):\n        super().__init__()\n        # Attributes\n        self.learned_shortcut = fin != fout\n        fmiddle = min(fin, fout)\n\n        self.opt = opt\n        # create conv layers\n        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)\n        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)\n        if self.learned_shortcut:\n            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)\n\n        # apply spectral norm if specified\n        if \"spectral\" in opt.norm_G:\n            self.conv_0 = spectral_norm(self.conv_0)\n            self.conv_1 = spectral_norm(self.conv_1)\n            if self.learned_shortcut:\n                self.conv_s = spectral_norm(self.conv_s)\n\n        # define normalization layers\n        spade_config_str = opt.norm_G.replace(\"spectral\", \"\")\n        self.norm_0 = SPADE(spade_config_str, fin, opt.semantic_nc, opt)\n        self.norm_1 = SPADE(spade_config_str, fmiddle, opt.semantic_nc, opt)\n        if self.learned_shortcut:\n            self.norm_s = SPADE(spade_config_str, fin, opt.semantic_nc, opt)\n\n    # note the resnet block with SPADE also takes in |seg|,\n    # the 
semantic segmentation map as input\n    def forward(self, x, seg, degraded_image):\n        x_s = self.shortcut(x, seg, degraded_image)\n\n        dx = self.conv_0(self.actvn(self.norm_0(x, seg, degraded_image)))\n        dx = self.conv_1(self.actvn(self.norm_1(dx, seg, degraded_image)))\n\n        out = x_s + dx\n\n        return out\n\n    def shortcut(self, x, seg, degraded_image):\n        if self.learned_shortcut:\n            x_s = self.conv_s(self.norm_s(x, seg, degraded_image))\n        else:\n            x_s = x\n        return x_s\n\n    def actvn(self, x):\n        return F.leaky_relu(x, 2e-1)\n\n\n# ResNet block used in pix2pixHD\n# We keep the same architecture as pix2pixHD.\nclass ResnetBlock(nn.Module):\n    def __init__(self, dim, norm_layer, activation=nn.ReLU(False), kernel_size=3):\n        super().__init__()\n\n        pw = (kernel_size - 1) // 2\n        self.conv_block = nn.Sequential(\n            nn.ReflectionPad2d(pw),\n            norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size)),\n            activation,\n            nn.ReflectionPad2d(pw),\n            norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size)),\n        )\n\n    def forward(self, x):\n        y = self.conv_block(x)\n        out = x + y\n        return out\n\n\n# VGG architecture, used for the perceptual loss using a pretrained VGG network\nclass VGG19(torch.nn.Module):\n    def __init__(self, requires_grad=False):\n        super().__init__()\n        vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features\n        self.slice1 = torch.nn.Sequential()\n        self.slice2 = torch.nn.Sequential()\n        self.slice3 = torch.nn.Sequential()\n        self.slice4 = torch.nn.Sequential()\n        self.slice5 = torch.nn.Sequential()\n        for x in range(2):\n            self.slice1.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(2, 7):\n            self.slice2.add_module(str(x), vgg_pretrained_features[x])\n        for x in 
range(7, 12):\n            self.slice3.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(12, 21):\n            self.slice4.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(21, 30):\n            self.slice5.add_module(str(x), vgg_pretrained_features[x])\n        if not requires_grad:\n            for param in self.parameters():\n                param.requires_grad = False\n\n    def forward(self, X):\n        h_relu1 = self.slice1(X)\n        h_relu2 = self.slice2(h_relu1)\n        h_relu3 = self.slice3(h_relu2)\n        h_relu4 = self.slice4(h_relu3)\n        h_relu5 = self.slice5(h_relu4)\n        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]\n        return out\n\n\nclass SPADEResnetBlock_non_spade(nn.Module):\n    def __init__(self, fin, fout, opt):\n        super().__init__()\n        # Attributes\n        self.learned_shortcut = fin != fout\n        fmiddle = min(fin, fout)\n\n        self.opt = opt\n        # create conv layers\n        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)\n        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)\n        if self.learned_shortcut:\n            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)\n\n        # apply spectral norm if specified\n        if \"spectral\" in opt.norm_G:\n            self.conv_0 = spectral_norm(self.conv_0)\n            self.conv_1 = spectral_norm(self.conv_1)\n            if self.learned_shortcut:\n                self.conv_s = spectral_norm(self.conv_s)\n\n        # define normalization layers\n        spade_config_str = opt.norm_G.replace(\"spectral\", \"\")\n        self.norm_0 = SPADE(spade_config_str, fin, opt.semantic_nc, opt)\n        self.norm_1 = SPADE(spade_config_str, fmiddle, opt.semantic_nc, opt)\n        if self.learned_shortcut:\n            self.norm_s = SPADE(spade_config_str, fin, opt.semantic_nc, opt)\n\n    # note the resnet block with SPADE also takes in |seg|,\n    # the 
semantic segmentation map as input\n    def forward(self, x, seg, degraded_image):\n        x_s = self.shortcut(x, seg, degraded_image)\n\n        dx = self.conv_0(self.actvn(x))\n        dx = self.conv_1(self.actvn(dx))\n\n        out = x_s + dx\n\n        return out\n\n    def shortcut(self, x, seg, degraded_image):\n        if self.learned_shortcut:\n            x_s = self.conv_s(x)\n        else:\n            x_s = x\n        return x_s\n\n    def actvn(self, x):\n        return F.leaky_relu(x, 2e-1)\n"
  },
  {
    "path": "Face_Enhancement/models/networks/base_network.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch.nn as nn\nfrom torch.nn import init\n\n\nclass BaseNetwork(nn.Module):\n    def __init__(self):\n        super(BaseNetwork, self).__init__()\n\n    @staticmethod\n    def modify_commandline_options(parser, is_train):\n        return parser\n\n    def print_network(self):\n        if isinstance(self, list):\n            self = self[0]\n        num_params = 0\n        for param in self.parameters():\n            num_params += param.numel()\n        print(\n            \"Network [%s] was created. Total number of parameters: %.1f million. \"\n            \"To see the architecture, do print(network).\" % (type(self).__name__, num_params / 1000000)\n        )\n\n    def init_weights(self, init_type=\"normal\", gain=0.02):\n        def init_func(m):\n            classname = m.__class__.__name__\n            if classname.find(\"BatchNorm2d\") != -1:\n                if hasattr(m, \"weight\") and m.weight is not None:\n                    init.normal_(m.weight.data, 1.0, gain)\n                if hasattr(m, \"bias\") and m.bias is not None:\n                    init.constant_(m.bias.data, 0.0)\n            elif hasattr(m, \"weight\") and (classname.find(\"Conv\") != -1 or classname.find(\"Linear\") != -1):\n                if init_type == \"normal\":\n                    init.normal_(m.weight.data, 0.0, gain)\n                elif init_type == \"xavier\":\n                    init.xavier_normal_(m.weight.data, gain=gain)\n                elif init_type == \"xavier_uniform\":\n                    init.xavier_uniform_(m.weight.data, gain=1.0)\n                elif init_type == \"kaiming\":\n                    init.kaiming_normal_(m.weight.data, a=0, mode=\"fan_in\")\n                elif init_type == \"orthogonal\":\n                    init.orthogonal_(m.weight.data, gain=gain)\n                elif init_type == \"none\":  # uses pytorch's default init method\n          
          m.reset_parameters()\n                else:\n                    raise NotImplementedError(\"initialization method [%s] is not implemented\" % init_type)\n                if hasattr(m, \"bias\") and m.bias is not None:\n                    init.constant_(m.bias.data, 0.0)\n\n        self.apply(init_func)\n\n        # propagate to children\n        for m in self.children():\n            if hasattr(m, \"init_weights\"):\n                m.init_weights(init_type, gain)\n"
  },
  {
    "path": "Face_Enhancement/models/networks/encoder.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch.nn as nn\nimport numpy as np\nimport torch.nn.functional as F\nfrom models.networks.base_network import BaseNetwork\nfrom models.networks.normalization import get_nonspade_norm_layer\n\n\nclass ConvEncoder(BaseNetwork):\n    \"\"\" Same architecture as the image discriminator \"\"\"\n\n    def __init__(self, opt):\n        super().__init__()\n\n        kw = 3\n        pw = int(np.ceil((kw - 1.0) / 2))\n        ndf = opt.ngf\n        norm_layer = get_nonspade_norm_layer(opt, opt.norm_E)\n        self.layer1 = norm_layer(nn.Conv2d(3, ndf, kw, stride=2, padding=pw))\n        self.layer2 = norm_layer(nn.Conv2d(ndf * 1, ndf * 2, kw, stride=2, padding=pw))\n        self.layer3 = norm_layer(nn.Conv2d(ndf * 2, ndf * 4, kw, stride=2, padding=pw))\n        self.layer4 = norm_layer(nn.Conv2d(ndf * 4, ndf * 8, kw, stride=2, padding=pw))\n        self.layer5 = norm_layer(nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw))\n        if opt.crop_size >= 256:\n            self.layer6 = norm_layer(nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw))\n\n        self.so = s0 = 4\n        self.fc_mu = nn.Linear(ndf * 8 * s0 * s0, 256)\n        self.fc_var = nn.Linear(ndf * 8 * s0 * s0, 256)\n\n        self.actvn = nn.LeakyReLU(0.2, False)\n        self.opt = opt\n\n    def forward(self, x):\n        if x.size(2) != 256 or x.size(3) != 256:\n            x = F.interpolate(x, size=(256, 256), mode=\"bilinear\")\n\n        x = self.layer1(x)\n        x = self.layer2(self.actvn(x))\n        x = self.layer3(self.actvn(x))\n        x = self.layer4(self.actvn(x))\n        x = self.layer5(self.actvn(x))\n        if self.opt.crop_size >= 256:\n            x = self.layer6(self.actvn(x))\n        x = self.actvn(x)\n\n        x = x.view(x.size(0), -1)\n        mu = self.fc_mu(x)\n        logvar = self.fc_var(x)\n\n        return mu, logvar\n"
  },
  {
    "path": "Face_Enhancement/models/networks/generator.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom models.networks.base_network import BaseNetwork\nfrom models.networks.normalization import get_nonspade_norm_layer\nfrom models.networks.architecture import ResnetBlock as ResnetBlock\nfrom models.networks.architecture import SPADEResnetBlock as SPADEResnetBlock\nfrom models.networks.architecture import SPADEResnetBlock_non_spade as SPADEResnetBlock_non_spade\n\n\nclass SPADEGenerator(BaseNetwork):\n    @staticmethod\n    def modify_commandline_options(parser, is_train):\n        parser.set_defaults(norm_G=\"spectralspadesyncbatch3x3\")\n        parser.add_argument(\n            \"--num_upsampling_layers\",\n            choices=(\"normal\", \"more\", \"most\"),\n            default=\"normal\",\n            help=\"If 'more', adds upsampling layer between the two middle resnet blocks. If 'most', also add one more upsampling + resnet layer at the end of the generator\",\n        )\n\n        return parser\n\n    def __init__(self, opt):\n        super().__init__()\n        self.opt = opt\n        nf = opt.ngf\n\n        self.sw, self.sh = self.compute_latent_vector_size(opt)\n\n        print(\"The size of the latent vector size is [%d,%d]\" % (self.sw, self.sh))\n\n        if opt.use_vae:\n            # In case of VAE, we will sample from random z vector\n            self.fc = nn.Linear(opt.z_dim, 16 * nf * self.sw * self.sh)\n        else:\n            # Otherwise, we make the network deterministic by starting with\n            # downsampled segmentation map instead of random z\n            if self.opt.no_parsing_map:\n                self.fc = nn.Conv2d(3, 16 * nf, 3, padding=1)\n            else:\n                self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1)\n\n        if self.opt.injection_layer == \"all\" or self.opt.injection_layer == \"1\":\n            self.head_0 = 
SPADEResnetBlock(16 * nf, 16 * nf, opt)\n        else:\n            self.head_0 = SPADEResnetBlock_non_spade(16 * nf, 16 * nf, opt)\n\n        if self.opt.injection_layer == \"all\" or self.opt.injection_layer == \"2\":\n            self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt)\n            self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt)\n\n        else:\n            self.G_middle_0 = SPADEResnetBlock_non_spade(16 * nf, 16 * nf, opt)\n            self.G_middle_1 = SPADEResnetBlock_non_spade(16 * nf, 16 * nf, opt)\n\n        if self.opt.injection_layer == \"all\" or self.opt.injection_layer == \"3\":\n            self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf, opt)\n        else:\n            self.up_0 = SPADEResnetBlock_non_spade(16 * nf, 8 * nf, opt)\n\n        if self.opt.injection_layer == \"all\" or self.opt.injection_layer == \"4\":\n            self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf, opt)\n        else:\n            self.up_1 = SPADEResnetBlock_non_spade(8 * nf, 4 * nf, opt)\n\n        if self.opt.injection_layer == \"all\" or self.opt.injection_layer == \"5\":\n            self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf, opt)\n        else:\n            self.up_2 = SPADEResnetBlock_non_spade(4 * nf, 2 * nf, opt)\n\n        if self.opt.injection_layer == \"all\" or self.opt.injection_layer == \"6\":\n            self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf, opt)\n        else:\n            self.up_3 = SPADEResnetBlock_non_spade(2 * nf, 1 * nf, opt)\n\n        final_nc = nf\n\n        if opt.num_upsampling_layers == \"most\":\n            self.up_4 = SPADEResnetBlock(1 * nf, nf // 2, opt)\n            final_nc = nf // 2\n\n        self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)\n\n        self.up = nn.Upsample(scale_factor=2)\n\n    def compute_latent_vector_size(self, opt):\n        if opt.num_upsampling_layers == \"normal\":\n            num_up_layers = 5\n        elif opt.num_upsampling_layers == \"more\":\n            num_up_layers 
= 6\n        elif opt.num_upsampling_layers == \"most\":\n            num_up_layers = 7\n        else:\n            raise ValueError(\"opt.num_upsampling_layers [%s] not recognized\" % opt.num_upsampling_layers)\n\n        sw = opt.load_size // (2 ** num_up_layers)\n        sh = round(sw / opt.aspect_ratio)\n\n        return sw, sh\n\n    def forward(self, input, degraded_image, z=None):\n        seg = input\n\n        if self.opt.use_vae:\n            # we sample z from unit normal and reshape the tensor\n            if z is None:\n                z = torch.randn(input.size(0), self.opt.z_dim, dtype=torch.float32, device=input.get_device())\n            x = self.fc(z)\n            x = x.view(-1, 16 * self.opt.ngf, self.sh, self.sw)\n        else:\n            # we downsample segmap and run convolution\n            if self.opt.no_parsing_map:\n                x = F.interpolate(degraded_image, size=(self.sh, self.sw), mode=\"bilinear\")\n            else:\n                x = F.interpolate(seg, size=(self.sh, self.sw), mode=\"nearest\")\n            x = self.fc(x)\n\n        x = self.head_0(x, seg, degraded_image)\n\n        x = self.up(x)\n        x = self.G_middle_0(x, seg, degraded_image)\n\n        if self.opt.num_upsampling_layers == \"more\" or self.opt.num_upsampling_layers == \"most\":\n            x = self.up(x)\n\n        x = self.G_middle_1(x, seg, degraded_image)\n\n        x = self.up(x)\n        x = self.up_0(x, seg, degraded_image)\n        x = self.up(x)\n        x = self.up_1(x, seg, degraded_image)\n        x = self.up(x)\n        x = self.up_2(x, seg, degraded_image)\n        x = self.up(x)\n        x = self.up_3(x, seg, degraded_image)\n\n        if self.opt.num_upsampling_layers == \"most\":\n            x = self.up(x)\n            x = self.up_4(x, seg, degraded_image)\n\n        x = self.conv_img(F.leaky_relu(x, 2e-1))\n        x = F.tanh(x)\n\n        return x\n\n\nclass Pix2PixHDGenerator(BaseNetwork):\n    @staticmethod\n    def 
modify_commandline_options(parser, is_train):\n        parser.add_argument(\n            \"--resnet_n_downsample\", type=int, default=4, help=\"number of downsampling layers in netG\"\n        )\n        parser.add_argument(\n            \"--resnet_n_blocks\",\n            type=int,\n            default=9,\n            help=\"number of residual blocks in the global generator network\",\n        )\n        parser.add_argument(\n            \"--resnet_kernel_size\", type=int, default=3, help=\"kernel size of the resnet block\"\n        )\n        parser.add_argument(\n            \"--resnet_initial_kernel_size\", type=int, default=7, help=\"kernel size of the first convolution\"\n        )\n        # parser.set_defaults(norm_G='instance')\n        return parser\n\n    def __init__(self, opt):\n        super().__init__()\n        input_nc = 3\n\n        # print(\"xxxxx\")\n        # print(opt.norm_G)\n        norm_layer = get_nonspade_norm_layer(opt, opt.norm_G)\n        activation = nn.ReLU(False)\n\n        model = []\n\n        # initial conv\n        model += [\n            nn.ReflectionPad2d(opt.resnet_initial_kernel_size // 2),\n            norm_layer(nn.Conv2d(input_nc, opt.ngf, kernel_size=opt.resnet_initial_kernel_size, padding=0)),\n            activation,\n        ]\n\n        # downsample\n        mult = 1\n        for i in range(opt.resnet_n_downsample):\n            model += [\n                norm_layer(nn.Conv2d(opt.ngf * mult, opt.ngf * mult * 2, kernel_size=3, stride=2, padding=1)),\n                activation,\n            ]\n            mult *= 2\n\n        # resnet blocks\n        for i in range(opt.resnet_n_blocks):\n            model += [\n                ResnetBlock(\n                    opt.ngf * mult,\n                    norm_layer=norm_layer,\n                    activation=activation,\n                    kernel_size=opt.resnet_kernel_size,\n                )\n            ]\n\n        # upsample\n        for i in 
range(opt.resnet_n_downsample):\n            nc_in = int(opt.ngf * mult)\n            nc_out = int((opt.ngf * mult) / 2)\n            model += [\n                norm_layer(\n                    nn.ConvTranspose2d(nc_in, nc_out, kernel_size=3, stride=2, padding=1, output_padding=1)\n                ),\n                activation,\n            ]\n            mult = mult // 2\n\n        # final output conv\n        model += [\n            nn.ReflectionPad2d(3),\n            nn.Conv2d(nc_out, opt.output_nc, kernel_size=7, padding=0),\n            nn.Tanh(),\n        ]\n\n        self.model = nn.Sequential(*model)\n\n    def forward(self, input, degraded_image, z=None):\n        return self.model(degraded_image)\n\n"
  },
  {
    "path": "Face_Enhancement/models/networks/normalization.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport re\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom models.networks.sync_batchnorm import SynchronizedBatchNorm2d\nimport torch.nn.utils.spectral_norm as spectral_norm\n\n\ndef get_nonspade_norm_layer(opt, norm_type=\"instance\"):\n    # helper function to get # output channels of the previous layer\n    def get_out_channel(layer):\n        if hasattr(layer, \"out_channels\"):\n            return getattr(layer, \"out_channels\")\n        return layer.weight.size(0)\n\n    # this function will be returned\n    def add_norm_layer(layer):\n        nonlocal norm_type\n        if norm_type.startswith(\"spectral\"):\n            layer = spectral_norm(layer)\n            subnorm_type = norm_type[len(\"spectral\") :]\n\n        if subnorm_type == \"none\" or len(subnorm_type) == 0:\n            return layer\n\n        # remove bias in the previous layer, which is meaningless\n        # since it has no effect after normalization\n        if getattr(layer, \"bias\", None) is not None:\n            delattr(layer, \"bias\")\n            layer.register_parameter(\"bias\", None)\n\n        if subnorm_type == \"batch\":\n            norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)\n        elif subnorm_type == \"sync_batch\":\n            norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True)\n        elif subnorm_type == \"instance\":\n            norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)\n        else:\n            raise ValueError(\"normalization layer %s is not recognized\" % subnorm_type)\n\n        return nn.Sequential(layer, norm_layer)\n\n    return add_norm_layer\n\n\nclass SPADE(nn.Module):\n    def __init__(self, config_text, norm_nc, label_nc, opt):\n        super().__init__()\n\n        assert config_text.startswith(\"spade\")\n        parsed = re.search(\"spade(\\D+)(\\d)x\\d\", 
config_text)\n        param_free_norm_type = str(parsed.group(1))\n        ks = int(parsed.group(2))\n        self.opt = opt\n        if param_free_norm_type == \"instance\":\n            self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)\n        elif param_free_norm_type == \"syncbatch\":\n            self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False)\n        elif param_free_norm_type == \"batch\":\n            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)\n        else:\n            raise ValueError(\"%s is not a recognized param-free norm type in SPADE\" % param_free_norm_type)\n\n        # The dimension of the intermediate embedding space. Yes, hardcoded.\n        nhidden = 128\n\n        pw = ks // 2\n\n        if self.opt.no_parsing_map:\n            self.mlp_shared = nn.Sequential(nn.Conv2d(3, nhidden, kernel_size=ks, padding=pw), nn.ReLU())\n        else:\n            self.mlp_shared = nn.Sequential(\n                nn.Conv2d(label_nc + 3, nhidden, kernel_size=ks, padding=pw), nn.ReLU()\n            )\n        self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)\n        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)\n\n    def forward(self, x, segmap, degraded_image):\n\n        # Part 1. generate parameter-free normalized activations\n        normalized = self.param_free_norm(x)\n\n        # Part 2. 
produce scaling and bias conditioned on semantic map\n        segmap = F.interpolate(segmap, size=x.size()[2:], mode=\"nearest\")\n        degraded_face = F.interpolate(degraded_image, size=x.size()[2:], mode=\"bilinear\")\n\n        if self.opt.no_parsing_map:\n            actv = self.mlp_shared(degraded_face)\n        else:\n            actv = self.mlp_shared(torch.cat((segmap, degraded_face), dim=1))\n        gamma = self.mlp_gamma(actv)\n        beta = self.mlp_beta(actv)\n\n        # apply scale and bias\n        out = normalized * (1 + gamma) + beta\n\n        return out\n"
  },
  {
    "path": "Face_Enhancement/models/pix2pix_model.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch\nimport models.networks as networks\nimport util.util as util\n\n\nclass Pix2PixModel(torch.nn.Module):\n    @staticmethod\n    def modify_commandline_options(parser, is_train):\n        networks.modify_commandline_options(parser, is_train)\n        return parser\n\n    def __init__(self, opt):\n        super().__init__()\n        self.opt = opt\n        self.FloatTensor = torch.cuda.FloatTensor if self.use_gpu() else torch.FloatTensor\n        self.ByteTensor = torch.cuda.ByteTensor if self.use_gpu() else torch.ByteTensor\n\n        self.netG, self.netD, self.netE = self.initialize_networks(opt)\n\n        # set loss functions\n        if opt.isTrain:\n            self.criterionGAN = networks.GANLoss(opt.gan_mode, tensor=self.FloatTensor, opt=self.opt)\n            self.criterionFeat = torch.nn.L1Loss()\n            if not opt.no_vgg_loss:\n                self.criterionVGG = networks.VGGLoss(self.opt.gpu_ids)\n            if opt.use_vae:\n                self.KLDLoss = networks.KLDLoss()\n\n    # Entry point for all calls involving forward pass\n    # of deep networks. 
We used this approach since DataParallel module\n    # can't parallelize custom functions, we branch to different\n    # routines based on |mode|.\n    def forward(self, data, mode):\n        input_semantics, real_image, degraded_image = self.preprocess_input(data)\n\n        if mode == \"generator\":\n            g_loss, generated = self.compute_generator_loss(input_semantics, degraded_image, real_image)\n            return g_loss, generated\n        elif mode == \"discriminator\":\n            d_loss = self.compute_discriminator_loss(input_semantics, degraded_image, real_image)\n            return d_loss\n        elif mode == \"encode_only\":\n            z, mu, logvar = self.encode_z(real_image)\n            return mu, logvar\n        elif mode == \"inference\":\n            with torch.no_grad():\n                fake_image, _ = self.generate_fake(input_semantics, degraded_image, real_image)\n            return fake_image\n        else:\n            raise ValueError(\"|mode| is invalid\")\n\n    def create_optimizers(self, opt):\n        G_params = list(self.netG.parameters())\n        if opt.use_vae:\n            G_params += list(self.netE.parameters())\n        if opt.isTrain:\n            D_params = list(self.netD.parameters())\n\n        beta1, beta2 = opt.beta1, opt.beta2\n        if opt.no_TTUR:\n            G_lr, D_lr = opt.lr, opt.lr\n        else:\n            G_lr, D_lr = opt.lr / 2, opt.lr * 2\n\n        optimizer_G = torch.optim.Adam(G_params, lr=G_lr, betas=(beta1, beta2))\n        optimizer_D = torch.optim.Adam(D_params, lr=D_lr, betas=(beta1, beta2))\n\n        return optimizer_G, optimizer_D\n\n    def save(self, epoch):\n        util.save_network(self.netG, \"G\", epoch, self.opt)\n        util.save_network(self.netD, \"D\", epoch, self.opt)\n        if self.opt.use_vae:\n            util.save_network(self.netE, \"E\", epoch, self.opt)\n\n    ############################################################################\n    # Private helper 
methods\n    ############################################################################\n\n    def initialize_networks(self, opt):\n        netG = networks.define_G(opt)\n        netD = networks.define_D(opt) if opt.isTrain else None\n        netE = networks.define_E(opt) if opt.use_vae else None\n\n        if not opt.isTrain or opt.continue_train:\n            netG = util.load_network(netG, \"G\", opt.which_epoch, opt)\n            if opt.isTrain:\n                netD = util.load_network(netD, \"D\", opt.which_epoch, opt)\n            if opt.use_vae:\n                netE = util.load_network(netE, \"E\", opt.which_epoch, opt)\n\n        return netG, netD, netE\n\n    # preprocess the input, such as moving the tensors to GPUs and\n    # transforming the label map to one-hot encoding\n    # |data|: dictionary of the input data\n\n    def preprocess_input(self, data):\n        # move to GPU and change data types\n        # data['label'] = data['label'].long()\n\n        if not self.opt.isTrain:\n            if self.use_gpu():\n                data[\"label\"] = data[\"label\"].cuda()\n                data[\"image\"] = data[\"image\"].cuda()\n            return data[\"label\"], data[\"image\"], data[\"image\"]\n\n        ## While testing, the input image is the degraded face\n        if self.use_gpu():\n            data[\"label\"] = data[\"label\"].cuda()\n            data[\"degraded_image\"] = data[\"degraded_image\"].cuda()\n            data[\"image\"] = data[\"image\"].cuda()\n\n        # # create one-hot label map\n        # label_map = data['label']\n        # bs, _, h, w = label_map.size()\n        # nc = self.opt.label_nc + 1 if self.opt.contain_dontcare_label \\\n        #     else self.opt.label_nc\n        # input_label = self.FloatTensor(bs, nc, h, w).zero_()\n        # input_semantics = input_label.scatter_(1, label_map, 1.0)\n\n        return data[\"label\"], data[\"image\"], data[\"degraded_image\"]\n\n    def compute_generator_loss(self, 
input_semantics, degraded_image, real_image):\n        G_losses = {}\n\n        fake_image, KLD_loss = self.generate_fake(\n            input_semantics, degraded_image, real_image, compute_kld_loss=self.opt.use_vae\n        )\n\n        if self.opt.use_vae:\n            G_losses[\"KLD\"] = KLD_loss\n\n        pred_fake, pred_real = self.discriminate(input_semantics, fake_image, real_image)\n\n        G_losses[\"GAN\"] = self.criterionGAN(pred_fake, True, for_discriminator=False)\n\n        if not self.opt.no_ganFeat_loss:\n            num_D = len(pred_fake)\n            GAN_Feat_loss = self.FloatTensor(1).fill_(0)\n            for i in range(num_D):  # for each discriminator\n                # last output is the final prediction, so we exclude it\n                num_intermediate_outputs = len(pred_fake[i]) - 1\n                for j in range(num_intermediate_outputs):  # for each layer output\n                    unweighted_loss = self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach())\n                    GAN_Feat_loss += unweighted_loss * self.opt.lambda_feat / num_D\n            G_losses[\"GAN_Feat\"] = GAN_Feat_loss\n\n        if not self.opt.no_vgg_loss:\n            G_losses[\"VGG\"] = self.criterionVGG(fake_image, real_image) * self.opt.lambda_vgg\n\n        return G_losses, fake_image\n\n    def compute_discriminator_loss(self, input_semantics, degraded_image, real_image):\n        D_losses = {}\n        with torch.no_grad():\n            fake_image, _ = self.generate_fake(input_semantics, degraded_image, real_image)\n            fake_image = fake_image.detach()\n            fake_image.requires_grad_()\n\n        pred_fake, pred_real = self.discriminate(input_semantics, fake_image, real_image)\n\n        D_losses[\"D_Fake\"] = self.criterionGAN(pred_fake, False, for_discriminator=True)\n        D_losses[\"D_real\"] = self.criterionGAN(pred_real, True, for_discriminator=True)\n\n        return D_losses\n\n    def encode_z(self, real_image):\n        
mu, logvar = self.netE(real_image)\n        z = self.reparameterize(mu, logvar)\n        return z, mu, logvar\n\n    def generate_fake(self, input_semantics, degraded_image, real_image, compute_kld_loss=False):\n        z = None\n        KLD_loss = None\n        if self.opt.use_vae:\n            z, mu, logvar = self.encode_z(real_image)\n            if compute_kld_loss:\n                KLD_loss = self.KLDLoss(mu, logvar) * self.opt.lambda_kld\n\n        fake_image = self.netG(input_semantics, degraded_image, z=z)\n\n        assert (\n            not compute_kld_loss\n        ) or self.opt.use_vae, \"You cannot compute KLD loss if opt.use_vae == False\"\n\n        return fake_image, KLD_loss\n\n    # Given fake and real image, return the prediction of discriminator\n    # for each fake and real image.\n\n    def discriminate(self, input_semantics, fake_image, real_image):\n\n        if self.opt.no_parsing_map:\n            fake_concat = fake_image\n            real_concat = real_image\n        else:\n            fake_concat = torch.cat([input_semantics, fake_image], dim=1)\n            real_concat = torch.cat([input_semantics, real_image], dim=1)\n\n        # In Batch Normalization, the fake and real images are\n        # recommended to be in the same batch to avoid disparate\n        # statistics in fake and real images.\n        # So both fake and real images are fed to D all at once.\n        fake_and_real = torch.cat([fake_concat, real_concat], dim=0)\n\n        discriminator_out = self.netD(fake_and_real)\n\n        pred_fake, pred_real = self.divide_pred(discriminator_out)\n\n        return pred_fake, pred_real\n\n    # Take the prediction of fake and real images from the combined batch\n    def divide_pred(self, pred):\n        # the prediction contains the intermediate outputs of multiscale GAN,\n        # so it's usually a list\n        if type(pred) == list:\n            fake = []\n            real = []\n            for p in pred:\n                
fake.append([tensor[: tensor.size(0) // 2] for tensor in p])\n                real.append([tensor[tensor.size(0) // 2 :] for tensor in p])\n        else:\n            fake = pred[: pred.size(0) // 2]\n            real = pred[pred.size(0) // 2 :]\n\n        return fake, real\n\n    def get_edges(self, t):\n        edge = self.ByteTensor(t.size()).zero_()\n        edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1])\n        edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1])\n        edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])\n        edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])\n        return edge.float()\n\n    def reparameterize(self, mu, logvar):\n        std = torch.exp(0.5 * logvar)\n        eps = torch.randn_like(std)\n        return eps.mul(std) + mu\n\n    def use_gpu(self):\n        return len(self.opt.gpu_ids) > 0\n"
  },
  {
    "path": "Face_Enhancement/options/__init__.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n"
  },
  {
    "path": "Face_Enhancement/options/base_options.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport sys\nimport argparse\nimport os\nfrom util import util\nimport torch\nimport models\nimport data\nimport pickle\n\n\nclass BaseOptions:\n    def __init__(self):\n        self.initialized = False\n\n    def initialize(self, parser):\n        # experiment specifics\n        parser.add_argument(\n            \"--name\",\n            type=str,\n            default=\"label2coco\",\n            help=\"name of the experiment. It decides where to store samples and models\",\n        )\n\n        parser.add_argument(\n            \"--gpu_ids\", type=str, default=\"0\", help=\"gpu ids: e.g. 0  0,1,2, 0,2. use -1 for CPU\"\n        )\n        parser.add_argument(\n            \"--checkpoints_dir\", type=str, default=\"./checkpoints\", help=\"models are saved here\"\n        )\n        parser.add_argument(\"--model\", type=str, default=\"pix2pix\", help=\"which model to use\")\n        parser.add_argument(\n            \"--norm_G\",\n            type=str,\n            default=\"spectralinstance\",\n            help=\"instance normalization or batch normalization\",\n        )\n        parser.add_argument(\n            \"--norm_D\",\n            type=str,\n            default=\"spectralinstance\",\n            help=\"instance normalization or batch normalization\",\n        )\n        parser.add_argument(\n            \"--norm_E\",\n            type=str,\n            default=\"spectralinstance\",\n            help=\"instance normalization or batch normalization\",\n        )\n        parser.add_argument(\"--phase\", type=str, default=\"train\", help=\"train, val, test, etc\")\n\n        # input/output sizes\n        parser.add_argument(\"--batchSize\", type=int, default=1, help=\"input batch size\")\n        parser.add_argument(\n            \"--preprocess_mode\",\n            type=str,\n            default=\"scale_width_and_crop\",\n            help=\"scaling and cropping of 
images at load time.\",\n            choices=(\n                \"resize_and_crop\",\n                \"crop\",\n                \"scale_width\",\n                \"scale_width_and_crop\",\n                \"scale_shortside\",\n                \"scale_shortside_and_crop\",\n                \"fixed\",\n                \"none\",\n                \"resize\",\n            ),\n        )\n        parser.add_argument(\n            \"--load_size\",\n            type=int,\n            default=1024,\n            help=\"Scale images to this size. The final image will be cropped to --crop_size.\",\n        )\n        parser.add_argument(\n            \"--crop_size\",\n            type=int,\n            default=512,\n            help=\"Crop to the width of crop_size (after initially scaling the images to load_size.)\",\n        )\n        parser.add_argument(\n            \"--aspect_ratio\",\n            type=float,\n            default=1.0,\n            help=\"The ratio width/height. The final height of the load image will be crop_size/aspect_ratio\",\n        )\n        parser.add_argument(\n            \"--label_nc\",\n            type=int,\n            default=182,\n            help=\"# of input label classes without unknown class. 
If you have unknown class as class label, specify --contain_dopntcare_label.\",\n        )\n        parser.add_argument(\n            \"--contain_dontcare_label\",\n            action=\"store_true\",\n            help=\"if the label map contains dontcare label (dontcare=255)\",\n        )\n        parser.add_argument(\"--output_nc\", type=int, default=3, help=\"# of output image channels\")\n\n        # for setting inputs\n        parser.add_argument(\"--dataroot\", type=str, default=\"./datasets/cityscapes/\")\n        parser.add_argument(\"--dataset_mode\", type=str, default=\"coco\")\n        parser.add_argument(\n            \"--serial_batches\",\n            action=\"store_true\",\n            help=\"if true, takes images in order to make batches, otherwise takes them randomly\",\n        )\n        parser.add_argument(\n            \"--no_flip\",\n            action=\"store_true\",\n            help=\"if specified, do not flip the images for data argumentation\",\n        )\n        parser.add_argument(\"--nThreads\", default=0, type=int, help=\"# threads for loading data\")\n        parser.add_argument(\n            \"--max_dataset_size\",\n            type=int,\n            default=sys.maxsize,\n            help=\"Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.\",\n        )\n        parser.add_argument(\n            \"--load_from_opt_file\",\n            action=\"store_true\",\n            help=\"load the options from checkpoints and use that as default\",\n        )\n        parser.add_argument(\n            \"--cache_filelist_write\",\n            action=\"store_true\",\n            help=\"saves the current filelist into a text file, so that it loads faster\",\n        )\n        parser.add_argument(\n            \"--cache_filelist_read\", action=\"store_true\", help=\"reads from the file list cache\"\n        )\n\n        # for displays\n        parser.add_argument(\"--display_winsize\", type=int, default=400, help=\"display window size\")\n\n        # for generator\n        parser.add_argument(\n            \"--netG\", type=str, default=\"spade\", help=\"selects model to use for netG (pix2pixhd | spade)\"\n        )\n        parser.add_argument(\"--ngf\", type=int, default=64, help=\"# of gen filters in first conv layer\")\n        parser.add_argument(\n            \"--init_type\",\n            type=str,\n            default=\"xavier\",\n            help=\"network initialization [normal|xavier|kaiming|orthogonal]\",\n        )\n        parser.add_argument(\n            \"--init_variance\", type=float, default=0.02, help=\"variance of the initialization distribution\"\n        )\n        parser.add_argument(\"--z_dim\", type=int, default=256, help=\"dimension of the latent z vector\")\n        parser.add_argument(\n            \"--no_parsing_map\", action=\"store_true\", help=\"During training, we do not use the parsing map\"\n        )\n\n        # for instance-wise features\n        parser.add_argument(\n            \"--no_instance\", action=\"store_true\", help=\"if specified, do *not* add instance map as input\"\n        )\n        parser.add_argument(\n            \"--nef\", type=int, default=16, help=\"# of encoder filters in 
the first conv layer\"\n        )\n        parser.add_argument(\"--use_vae\", action=\"store_true\", help=\"enable training with an image encoder.\")\n        parser.add_argument(\n            \"--tensorboard_log\", action=\"store_true\", help=\"use tensorboard to record the resutls\"\n        )\n\n        # parser.add_argument('--img_dir',)\n        parser.add_argument(\n            \"--old_face_folder\", type=str, default=\"\", help=\"The folder name of input old face\"\n        )\n        parser.add_argument(\n            \"--old_face_label_folder\", type=str, default=\"\", help=\"The folder name of input old face label\"\n        )\n\n        parser.add_argument(\"--injection_layer\", type=str, default=\"all\", help=\"\")\n\n        self.initialized = True\n        return parser\n\n    def gather_options(self):\n        # initialize parser with basic options\n        if not self.initialized:\n            parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)\n            parser = self.initialize(parser)\n\n        # get the basic options\n        opt, unknown = parser.parse_known_args()\n\n        # modify model-related parser options\n        model_name = opt.model\n        model_option_setter = models.get_option_setter(model_name)\n        parser = model_option_setter(parser, self.isTrain)\n\n        # modify dataset-related parser options\n        # dataset_mode = opt.dataset_mode\n        # dataset_option_setter = data.get_option_setter(dataset_mode)\n        # parser = dataset_option_setter(parser, self.isTrain)\n\n        opt, unknown = parser.parse_known_args()\n\n        # if there is opt_file, load it.\n        # The previous default options will be overwritten\n        if opt.load_from_opt_file:\n            parser = self.update_options_from_file(parser, opt)\n\n        opt = parser.parse_args()\n        self.parser = parser\n        return opt\n\n    def print_options(self, opt):\n        message = \"\"\n        
message += \"----------------- Options ---------------\\n\"\n        for k, v in sorted(vars(opt).items()):\n            comment = \"\"\n            default = self.parser.get_default(k)\n            if v != default:\n                comment = \"\\t[default: %s]\" % str(default)\n            message += \"{:>25}: {:<30}{}\\n\".format(str(k), str(v), comment)\n        message += \"----------------- End -------------------\"\n        # print(message)\n\n    def option_file_path(self, opt, makedir=False):\n        expr_dir = os.path.join(opt.checkpoints_dir, opt.name)\n        if makedir:\n            util.mkdirs(expr_dir)\n        file_name = os.path.join(expr_dir, \"opt\")\n        return file_name\n\n    def save_options(self, opt):\n        file_name = self.option_file_path(opt, makedir=True)\n        with open(file_name + \".txt\", \"wt\") as opt_file:\n            for k, v in sorted(vars(opt).items()):\n                comment = \"\"\n                default = self.parser.get_default(k)\n                if v != default:\n                    comment = \"\\t[default: %s]\" % str(default)\n                opt_file.write(\"{:>25}: {:<30}{}\\n\".format(str(k), str(v), comment))\n\n        with open(file_name + \".pkl\", \"wb\") as opt_file:\n            pickle.dump(opt, opt_file)\n\n    def update_options_from_file(self, parser, opt):\n        new_opt = self.load_options(opt)\n        for k, v in sorted(vars(opt).items()):\n            if hasattr(new_opt, k) and v != getattr(new_opt, k):\n                new_val = getattr(new_opt, k)\n                parser.set_defaults(**{k: new_val})\n        return parser\n\n    def load_options(self, opt):\n        file_name = self.option_file_path(opt, makedir=False)\n        new_opt = pickle.load(open(file_name + \".pkl\", \"rb\"))\n        return new_opt\n\n    def parse(self, save=False):\n\n        opt = self.gather_options()\n        opt.isTrain = self.isTrain  # train or test\n        opt.contain_dontcare_label = False\n\n   
     self.print_options(opt)\n        if opt.isTrain:\n            self.save_options(opt)\n\n        # Set semantic_nc based on the option.\n        # This will be convenient in many places\n        opt.semantic_nc = (\n            opt.label_nc + (1 if opt.contain_dontcare_label else 0) + (0 if opt.no_instance else 1)\n        )\n\n        # set gpu ids\n        str_ids = opt.gpu_ids.split(\",\")\n        opt.gpu_ids = []\n        for str_id in str_ids:\n            int_id = int(str_id)\n            if int_id >= 0:\n                opt.gpu_ids.append(int_id)\n\n        if len(opt.gpu_ids) > 0:\n            print(\"The main GPU is \")\n            print(opt.gpu_ids[0])\n            torch.cuda.set_device(opt.gpu_ids[0])\n\n        assert (\n            len(opt.gpu_ids) == 0 or opt.batchSize % len(opt.gpu_ids) == 0\n        ), \"Batch size %d is wrong. It must be a multiple of # GPUs %d.\" % (opt.batchSize, len(opt.gpu_ids))\n\n        self.opt = opt\n        return self.opt\n"
  },
  {
    "path": "Face_Enhancement/options/test_options.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom .base_options import BaseOptions\n\n\nclass TestOptions(BaseOptions):\n    def initialize(self, parser):\n        BaseOptions.initialize(self, parser)\n        parser.add_argument(\"--results_dir\", type=str, default=\"./results/\", help=\"saves results here.\")\n        parser.add_argument(\n            \"--which_epoch\",\n            type=str,\n            default=\"latest\",\n            help=\"which epoch to load? set to latest to use latest cached model\",\n        )\n        parser.add_argument(\"--how_many\", type=int, default=float(\"inf\"), help=\"how many test images to run\")\n\n        parser.set_defaults(\n            preprocess_mode=\"scale_width_and_crop\", crop_size=256, load_size=256, display_winsize=256\n        )\n        parser.set_defaults(serial_batches=True)\n        parser.set_defaults(no_flip=True)\n        parser.set_defaults(phase=\"test\")\n        self.isTrain = False\n        return parser\n"
  },
  {
    "path": "Face_Enhancement/requirements.txt",
    "content": "torch>=1.0.0\ntorchvision\ndominate>=2.3.1\nwandb\ndill\nscikit-image\ntensorboardX\nscipy\nopencv-python"
  },
  {
    "path": "Face_Enhancement/test_face.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport os\nfrom collections import OrderedDict\n\nimport data\nfrom options.test_options import TestOptions\nfrom models.pix2pix_model import Pix2PixModel\nfrom util.visualizer import Visualizer\nimport torchvision.utils as vutils\nimport warnings\nwarnings.filterwarnings(\"ignore\", category=UserWarning)\n\nopt = TestOptions().parse()\n\ndataloader = data.create_dataloader(opt)\n\nmodel = Pix2PixModel(opt)\nmodel.eval()\n\nvisualizer = Visualizer(opt)\n\n\nsingle_save_url = os.path.join(opt.checkpoints_dir, opt.name, opt.results_dir, \"each_img\")\n\n\nif not os.path.exists(single_save_url):\n    os.makedirs(single_save_url)\n\n\nfor i, data_i in enumerate(dataloader):\n    if i * opt.batchSize >= opt.how_many:\n        break\n\n    generated = model(data_i, mode=\"inference\")\n\n    img_path = data_i[\"path\"]\n\n    for b in range(generated.shape[0]):\n        img_name = os.path.split(img_path[b])[-1]\n        save_img_url = os.path.join(single_save_url, img_name)\n\n        vutils.save_image((generated[b] + 1) / 2, save_img_url)\n\n"
  },
  {
    "path": "Face_Enhancement/util/__init__.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n"
  },
  {
    "path": "Face_Enhancement/util/iter_counter.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport os\nimport time\nimport numpy as np\n\n\n# Helper class that keeps track of training iterations\nclass IterationCounter:\n    def __init__(self, opt, dataset_size):\n        self.opt = opt\n        self.dataset_size = dataset_size\n\n        self.first_epoch = 1\n        self.total_epochs = opt.niter + opt.niter_decay\n        self.epoch_iter = 0  # iter number within each epoch\n        self.iter_record_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, \"iter.txt\")\n        if opt.isTrain and opt.continue_train:\n            try:\n                self.first_epoch, self.epoch_iter = np.loadtxt(\n                    self.iter_record_path, delimiter=\",\", dtype=int\n                )\n                print(\"Resuming from epoch %d at iteration %d\" % (self.first_epoch, self.epoch_iter))\n            except:\n                print(\n                    \"Could not load iteration record at %s. 
Starting from beginning.\" % self.iter_record_path\n                )\n\n        self.total_steps_so_far = (self.first_epoch - 1) * dataset_size + self.epoch_iter\n\n    # return the iterator of epochs for the training\n    def training_epochs(self):\n        return range(self.first_epoch, self.total_epochs + 1)\n\n    def record_epoch_start(self, epoch):\n        self.epoch_start_time = time.time()\n        self.epoch_iter = 0\n        self.last_iter_time = time.time()\n        self.current_epoch = epoch\n\n    def record_one_iteration(self):\n        current_time = time.time()\n\n        # the last remaining batch is dropped (see data/__init__.py),\n        # so we can assume batch size is always opt.batchSize\n        self.time_per_iter = (current_time - self.last_iter_time) / self.opt.batchSize\n        self.last_iter_time = current_time\n        self.total_steps_so_far += self.opt.batchSize\n        self.epoch_iter += self.opt.batchSize\n\n    def record_epoch_end(self):\n        current_time = time.time()\n        self.time_per_epoch = current_time - self.epoch_start_time\n        print(\n            \"End of epoch %d / %d \\t Time Taken: %d sec\"\n            % (self.current_epoch, self.total_epochs, self.time_per_epoch)\n        )\n        if self.current_epoch % self.opt.save_epoch_freq == 0:\n            np.savetxt(self.iter_record_path, (self.current_epoch + 1, 0), delimiter=\",\", fmt=\"%d\")\n            print(\"Saved current iteration count at %s.\" % self.iter_record_path)\n\n    def record_current_iter(self):\n        np.savetxt(self.iter_record_path, (self.current_epoch, self.epoch_iter), delimiter=\",\", fmt=\"%d\")\n        print(\"Saved current iteration count at %s.\" % self.iter_record_path)\n\n    def needs_saving(self):\n        return (self.total_steps_so_far % self.opt.save_latest_freq) < self.opt.batchSize\n\n    def needs_printing(self):\n        return (self.total_steps_so_far % self.opt.print_freq) < self.opt.batchSize\n\n    def 
needs_displaying(self):\n        return (self.total_steps_so_far % self.opt.display_freq) < self.opt.batchSize\n"
  },
  {
    "path": "Face_Enhancement/util/util.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport re\nimport importlib\nimport torch\nfrom argparse import Namespace\nimport numpy as np\nfrom PIL import Image\nimport os\nimport argparse\nimport dill as pickle\n\n\ndef save_obj(obj, name):\n    with open(name, \"wb\") as f:\n        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)\n\n\ndef load_obj(name):\n    with open(name, \"rb\") as f:\n        return pickle.load(f)\n\n\ndef copyconf(default_opt, **kwargs):\n    conf = argparse.Namespace(**vars(default_opt))\n    for key in kwargs:\n        print(key, kwargs[key])\n        setattr(conf, key, kwargs[key])\n    return conf\n\n\n# Converts a Tensor into a Numpy array\n# |imtype|: the desired type of the converted numpy array\ndef tensor2im(image_tensor, imtype=np.uint8, normalize=True, tile=False):\n    if isinstance(image_tensor, list):\n        image_numpy = []\n        for i in range(len(image_tensor)):\n            image_numpy.append(tensor2im(image_tensor[i], imtype, normalize))\n        return image_numpy\n\n    if image_tensor.dim() == 4:\n        # transform each image in the batch\n        images_np = []\n        for b in range(image_tensor.size(0)):\n            one_image = image_tensor[b]\n            one_image_np = tensor2im(one_image)\n            images_np.append(one_image_np.reshape(1, *one_image_np.shape))\n        images_np = np.concatenate(images_np, axis=0)\n\n        return images_np\n\n    if image_tensor.dim() == 2:\n        image_tensor = image_tensor.unsqueeze(0)\n    image_numpy = image_tensor.detach().cpu().float().numpy()\n    if normalize:\n        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0\n    else:\n        image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0\n    image_numpy = np.clip(image_numpy, 0, 255)\n    if image_numpy.shape[2] == 1:\n        image_numpy = image_numpy[:, :, 0]\n    return image_numpy.astype(imtype)\n\n\n# Converts a one-hot tensor 
into a colorful label map\ndef tensor2label(label_tensor, n_label, imtype=np.uint8, tile=False):\n    if label_tensor.dim() == 4:\n        # transform each image in the batch\n        images_np = []\n        for b in range(label_tensor.size(0)):\n            one_image = label_tensor[b]\n            one_image_np = tensor2label(one_image, n_label, imtype)\n            images_np.append(one_image_np.reshape(1, *one_image_np.shape))\n        images_np = np.concatenate(images_np, axis=0)\n        # if tile:\n        #     images_tiled = tile_images(images_np)\n        #     return images_tiled\n        # else:\n        #     images_np = images_np[0]\n        #     return images_np\n        return images_np\n\n    if label_tensor.dim() == 1:\n        return np.zeros((64, 64, 3), dtype=np.uint8)\n    if n_label == 0:\n        return tensor2im(label_tensor, imtype)\n    label_tensor = label_tensor.cpu().float()\n    if label_tensor.size()[0] > 1:\n        label_tensor = label_tensor.max(0, keepdim=True)[1]\n    label_tensor = Colorize(n_label)(label_tensor)\n    label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0))\n    result = label_numpy.astype(imtype)\n    return result\n\n\ndef save_image(image_numpy, image_path, create_dir=False):\n    if create_dir:\n        os.makedirs(os.path.dirname(image_path), exist_ok=True)\n    if len(image_numpy.shape) == 2:\n        image_numpy = np.expand_dims(image_numpy, axis=2)\n    if image_numpy.shape[2] == 1:\n        image_numpy = np.repeat(image_numpy, 3, 2)\n    image_pil = Image.fromarray(image_numpy)\n\n    # save to png\n    image_pil.save(image_path.replace(\".jpg\", \".png\"))\n\n\ndef mkdirs(paths):\n    if isinstance(paths, list) and not isinstance(paths, str):\n        for path in paths:\n            mkdir(path)\n    else:\n        mkdir(paths)\n\n\ndef mkdir(path):\n    if not os.path.exists(path):\n        os.makedirs(path)\n\n\ndef atoi(text):\n    return int(text) if text.isdigit() else text\n\n\ndef 
natural_keys(text):\n    \"\"\"\n    alist.sort(key=natural_keys) sorts in human order\n    http://nedbatchelder.com/blog/200712/human_sorting.html\n    (See Toothy's implementation in the comments)\n    \"\"\"\n    return [atoi(c) for c in re.split(\"(\\d+)\", text)]\n\n\ndef natural_sort(items):\n    items.sort(key=natural_keys)\n\n\ndef str2bool(v):\n    if v.lower() in (\"yes\", \"true\", \"t\", \"y\", \"1\"):\n        return True\n    elif v.lower() in (\"no\", \"false\", \"f\", \"n\", \"0\"):\n        return False\n    else:\n        raise argparse.ArgumentTypeError(\"Boolean value expected.\")\n\n\ndef find_class_in_module(target_cls_name, module):\n    target_cls_name = target_cls_name.replace(\"_\", \"\").lower()\n    clslib = importlib.import_module(module)\n    cls = None\n    for name, clsobj in clslib.__dict__.items():\n        if name.lower() == target_cls_name:\n            cls = clsobj\n\n    if cls is None:\n        print(\n            \"In %s, there should be a class whose name matches %s in lowercase without underscore(_)\"\n            % (module, target_cls_name)\n        )\n        exit(0)\n\n    return cls\n\n\ndef save_network(net, label, epoch, opt):\n    save_filename = \"%s_net_%s.pth\" % (epoch, label)\n    save_path = os.path.join(opt.checkpoints_dir, opt.name, save_filename)\n    torch.save(net.cpu().state_dict(), save_path)\n    if len(opt.gpu_ids) and torch.cuda.is_available():\n        net.cuda()\n\n\ndef load_network(net, label, epoch, opt):\n    save_filename = \"%s_net_%s.pth\" % (epoch, label)\n    save_dir = os.path.join(opt.checkpoints_dir, opt.name)\n    save_path = os.path.join(save_dir, save_filename)\n    if os.path.exists(save_path):\n        weights = torch.load(save_path)\n        net.load_state_dict(weights)\n    return net\n\n\n###############################################################################\n# Code from\n# https://github.com/ycszen/pytorch-seg/blob/master/transform.py\n# Modified so it complies with the 
Citscape label map colors\n###############################################################################\ndef uint82bin(n, count=8):\n    \"\"\"returns the binary of integer n, count refers to amount of bits\"\"\"\n    return \"\".join([str((n >> y) & 1) for y in range(count - 1, -1, -1)])\n\n\nclass Colorize(object):\n    def __init__(self, n=35):\n        self.cmap = labelcolormap(n)\n        self.cmap = torch.from_numpy(self.cmap[:n])\n\n    def __call__(self, gray_image):\n        size = gray_image.size()\n        color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0)\n\n        for label in range(0, len(self.cmap)):\n            mask = (label == gray_image[0]).cpu()\n            color_image[0][mask] = self.cmap[label][0]\n            color_image[1][mask] = self.cmap[label][1]\n            color_image[2][mask] = self.cmap[label][2]\n\n        return color_image\n"
  },
  {
    "path": "Face_Enhancement/util/visualizer.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport os\nimport ntpath\nimport time\nfrom . import util\nimport scipy.misc\n\ntry:\n    from StringIO import StringIO  # Python 2.7\nexcept ImportError:\n    from io import BytesIO  # Python 3.x\nimport torchvision.utils as vutils\nfrom tensorboardX import SummaryWriter\nimport torch\nimport numpy as np\n\n\nclass Visualizer:\n    def __init__(self, opt):\n        self.opt = opt\n        self.tf_log = opt.isTrain and opt.tf_log\n\n        self.tensorboard_log = opt.tensorboard_log\n\n        self.win_size = opt.display_winsize\n        self.name = opt.name\n        if self.tensorboard_log:\n\n            if self.opt.isTrain:\n                self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, \"logs\")\n                if not os.path.exists(self.log_dir):\n                    os.makedirs(self.log_dir)\n                self.writer = SummaryWriter(log_dir=self.log_dir)\n            else:\n                print(\"hi :)\")\n                self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, opt.results_dir)\n                if not os.path.exists(self.log_dir):\n                    os.makedirs(self.log_dir)\n\n        if opt.isTrain:\n            self.log_name = os.path.join(opt.checkpoints_dir, opt.name, \"loss_log.txt\")\n            with open(self.log_name, \"a\") as log_file:\n                now = time.strftime(\"%c\")\n                log_file.write(\"================ Training Loss (%s) ================\\n\" % now)\n\n    # |visuals|: dictionary of images to display or save\n    def display_current_results(self, visuals, epoch, step):\n\n        all_tensor = []\n        if self.tensorboard_log:\n\n            for key, tensor in visuals.items():\n                all_tensor.append((tensor.data.cpu() + 1) / 2)\n\n            output = torch.cat(all_tensor, 0)\n            img_grid = vutils.make_grid(output, nrow=self.opt.batchSize, padding=0, 
normalize=False)\n\n            if self.opt.isTrain:\n                self.writer.add_image(\"Face_SPADE/training_samples\", img_grid, step)\n            else:\n                vutils.save_image(\n                    output,\n                    os.path.join(self.log_dir, str(step) + \".png\"),\n                    nrow=self.opt.batchSize,\n                    padding=0,\n                    normalize=False,\n                )\n\n    # errors: dictionary of error labels and values\n    def plot_current_errors(self, errors, step):\n        if self.tf_log:\n            for tag, value in errors.items():\n                value = value.mean().float()\n                summary = self.tf.Summary(value=[self.tf.Summary.Value(tag=tag, simple_value=value)])\n                self.writer.add_summary(summary, step)\n\n        if self.tensorboard_log:\n\n            self.writer.add_scalar(\"Loss/GAN_Feat\", errors[\"GAN_Feat\"].mean().float(), step)\n            self.writer.add_scalar(\"Loss/VGG\", errors[\"VGG\"].mean().float(), step)\n            self.writer.add_scalars(\n                \"Loss/GAN\",\n                {\n                    \"G\": errors[\"GAN\"].mean().float(),\n                    \"D\": (errors[\"D_Fake\"].mean().float() + errors[\"D_real\"].mean().float()) / 2,\n                },\n                step,\n            )\n\n    # errors: same format as |errors| of plotCurrentErrors\n    def print_current_errors(self, epoch, i, errors, t):\n        message = \"(epoch: %d, iters: %d, time: %.3f) \" % (epoch, i, t)\n        for k, v in errors.items():\n            v = v.mean().float()\n            message += \"%s: %.3f \" % (k, v)\n\n        print(message)\n        with open(self.log_name, \"a\") as log_file:\n            log_file.write(\"%s\\n\" % message)\n\n    def convert_visuals_to_numpy(self, visuals):\n        for key, t in visuals.items():\n            tile = self.opt.batchSize > 8\n            if \"input_label\" == key:\n                t = 
util.tensor2label(t, self.opt.label_nc + 2, tile=tile)  ## B*H*W*C 0-255 numpy\n            else:\n                t = util.tensor2im(t, tile=tile)\n            visuals[key] = t\n        return visuals\n\n    # save image to the disk\n    def save_images(self, webpage, visuals, image_path):\n        visuals = self.convert_visuals_to_numpy(visuals)\n\n        image_dir = webpage.get_image_dir()\n        short_path = ntpath.basename(image_path[0])\n        name = os.path.splitext(short_path)[0]\n\n        webpage.add_header(name)\n        ims = []\n        txts = []\n        links = []\n\n        for label, image_numpy in visuals.items():\n            image_name = os.path.join(label, \"%s.png\" % (name))\n            save_path = os.path.join(image_dir, image_name)\n            util.save_image(image_numpy, save_path, create_dir=True)\n\n            ims.append(image_name)\n            txts.append(label)\n            links.append(image_name)\n        webpage.add_images(ims, txts, links, width=self.win_size)\n"
  },
  {
    "path": "GUI.py",
    "content": "import numpy as np\nimport cv2\nimport PySimpleGUI as sg\nimport os.path\nimport argparse\nimport os\nimport sys\nimport shutil\nfrom subprocess import call\n\ndef modify(image_filename=None, cv2_frame=None):\n\n    def run_cmd(command):\n        try:\n            call(command, shell=True)\n        except KeyboardInterrupt:\n            print(\"Process interrupted\")\n            sys.exit(1)\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--input_folder\", type=str,\n                        default= image_filename, help=\"Test images\")\n    parser.add_argument(\n        \"--output_folder\",\n        type=str,\n        default=\"./output\",\n        help=\"Restored images, please use the absolute path\",\n    )\n    parser.add_argument(\"--GPU\", type=str, default=\"-1\", help=\"0,1,2\")\n    parser.add_argument(\n        \"--checkpoint_name\", type=str, default=\"Setting_9_epoch_100\", help=\"choose which checkpoint\"\n    )\n    parser.add_argument(\"--with_scratch\",default=\"--with_scratch\" ,action=\"store_true\")\n    opts = parser.parse_args()\n\n    gpu1 = opts.GPU\n\n    # resolve relative paths before changing directory\n    opts.input_folder = os.path.abspath(opts.input_folder)\n    opts.output_folder = os.path.abspath(opts.output_folder)\n    if not os.path.exists(opts.output_folder):\n        os.makedirs(opts.output_folder)\n\n    main_environment = os.getcwd()\n\n    # Stage 1: Overall Quality Improve\n    print(\"Running Stage 1: Overall restoration\")\n    os.chdir(\"./Global\")\n    stage_1_input_dir = opts.input_folder\n    stage_1_output_dir = os.path.join(\n        opts.output_folder, \"stage_1_restore_output\")\n    if not os.path.exists(stage_1_output_dir):\n        os.makedirs(stage_1_output_dir)\n\n    if not opts.with_scratch:\n        stage_1_command = (\n            \"python test.py --test_mode Full --Quality_restore --test_input \"\n            + stage_1_input_dir\n            + \" --outputs_dir \"\n  
          + stage_1_output_dir\n            + \" --gpu_ids \"\n            + gpu1\n        )\n        run_cmd(stage_1_command)\n    else:\n\n        mask_dir = os.path.join(stage_1_output_dir, \"masks\")\n        new_input = os.path.join(mask_dir, \"input\")\n        new_mask = os.path.join(mask_dir, \"mask\")\n        stage_1_command_1 = (\n            \"python detection.py --test_path \"\n            + stage_1_input_dir\n            + \" --output_dir \"\n            + mask_dir\n            + \" --input_size full_size\"\n            + \" --GPU \"\n            + gpu1\n        )\n        stage_1_command_2 = (\n            \"python test.py --Scratch_and_Quality_restore --test_input \"\n            + new_input\n            + \" --test_mask \"\n            + new_mask\n            + \" --outputs_dir \"\n            + stage_1_output_dir\n            + \" --gpu_ids \"\n            + gpu1\n        )\n        run_cmd(stage_1_command_1)\n        run_cmd(stage_1_command_2)\n\n    # Solve the case when there is no face in the old photo\n    stage_1_results = os.path.join(stage_1_output_dir, \"restored_image\")\n    stage_4_output_dir = os.path.join(opts.output_folder, \"final_output\")\n    if not os.path.exists(stage_4_output_dir):\n        os.makedirs(stage_4_output_dir)\n    for x in os.listdir(stage_1_results):\n        img_dir = os.path.join(stage_1_results, x)\n        shutil.copy(img_dir, stage_4_output_dir)\n\n    print(\"Finish Stage 1 ...\")\n    print(\"\\n\")\n\n    # Stage 2: Face Detection\n\n    print(\"Running Stage 2: Face Detection\")\n    os.chdir(\".././Face_Detection\")\n    stage_2_input_dir = os.path.join(stage_1_output_dir, \"restored_image\")\n    stage_2_output_dir = os.path.join(\n        opts.output_folder, \"stage_2_detection_output\")\n    if not os.path.exists(stage_2_output_dir):\n        os.makedirs(stage_2_output_dir)\n    stage_2_command = (\n        \"python detect_all_dlib.py --url \" + stage_2_input_dir +\n        \" --save_url \" + 
stage_2_output_dir\n    )\n    run_cmd(stage_2_command)\n    print(\"Finish Stage 2 ...\")\n    print(\"\\n\")\n\n    # Stage 3: Face Restore\n    print(\"Running Stage 3: Face Enhancement\")\n    os.chdir(\".././Face_Enhancement\")\n    stage_3_input_mask = \"./\"\n    stage_3_input_face = stage_2_output_dir\n    stage_3_output_dir = os.path.join(\n        opts.output_folder, \"stage_3_face_output\")\n    if not os.path.exists(stage_3_output_dir):\n        os.makedirs(stage_3_output_dir)\n    stage_3_command = (\n        \"python test_face.py --old_face_folder \"\n        + stage_3_input_face\n        + \" --old_face_label_folder \"\n        + stage_3_input_mask\n        + \" --tensorboard_log --name \"\n        + opts.checkpoint_name\n        + \" --gpu_ids \"\n        + gpu1\n        + \" --load_size 256 --label_nc 18 --no_instance --preprocess_mode resize --batchSize 4 --results_dir \"\n        + stage_3_output_dir\n        + \" --no_parsing_map\"\n    )\n    run_cmd(stage_3_command)\n    print(\"Finish Stage 3 ...\")\n    print(\"\\n\")\n\n    # Stage 4: Warp back\n    print(\"Running Stage 4: Blending\")\n    os.chdir(\".././Face_Detection\")\n    stage_4_input_image_dir = os.path.join(\n        stage_1_output_dir, \"restored_image\")\n    stage_4_input_face_dir = os.path.join(stage_3_output_dir, \"each_img\")\n    stage_4_output_dir = os.path.join(opts.output_folder, \"final_output\")\n    if not os.path.exists(stage_4_output_dir):\n        os.makedirs(stage_4_output_dir)\n    stage_4_command = (\n        \"python align_warp_back_multiple_dlib.py --origin_url \"\n        + stage_4_input_image_dir\n        + \" --replace_url \"\n        + stage_4_input_face_dir\n        + \" --save_url \"\n        + stage_4_output_dir\n    )\n    run_cmd(stage_4_command)\n    print(\"Finish Stage 4 ...\")\n    print(\"\\n\")\n\n    print(\"All the processing is done. 
Please check the results.\")\n\n# --------------------------------- The GUI ---------------------------------\n\n# First the window layout...\n\nimages_col = [[sg.Text('Input file:'), sg.In(enable_events=True, key='-IN FILE-'), sg.FileBrowse()],\n              [sg.Button('Modify Photo', key='-MPHOTO-'), sg.Button('Exit')],\n              [sg.Image(filename='', key='-IN-'), sg.Image(filename='', key='-OUT-')],]\n# ----- Full layout -----\nlayout = [[sg.VSeperator(), sg.Column(images_col)]]\n\n# ----- Make the window -----\nwindow = sg.Window('Bringing-old-photos-back-to-life', layout, grab_anywhere=True)\n\n# ----- Run the Event Loop -----\nprev_filename = colorized = cap = None\nwhile True:\n    event, values = window.read()\n    if event in (None, 'Exit'):\n        break\n\n    elif event == '-MPHOTO-':\n        try:\n            n1 = filename.split(\"/\")[-2]\n            n2 = filename.split(\"/\")[-3]\n            n3 = filename.split(\"/\")[-1]\n            filename= str(f\"./{n2}/{n1}\")\n            modify(filename)\n           \n            global f_image\n            f_image = f'./output/final_output/{n3}'\n            image = cv2.imread(f_image)\n            window['-OUT-'].update(data=cv2.imencode('.png', image)[1].tobytes())\n            \n        except:\n            continue\n\n    elif event == '-IN FILE-':      # A single filename was chosen\n        filename = values['-IN FILE-']\n        if filename != prev_filename:\n            prev_filename = filename\n            try:\n                image = cv2.imread(filename)\n                window['-IN-'].update(data=cv2.imencode('.png', image)[1].tobytes())\n            except:\n                continue\n\n# ----- Exit program -----\nwindow.close()"
  },
  {
    "path": "Global/data/Create_Bigfile.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport os\nimport struct\nfrom PIL import Image\n\nIMG_EXTENSIONS = [\n    '.jpg', '.JPG', '.jpeg', '.JPEG',\n    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',\n]\n\n\ndef is_image_file(filename):\n    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)\n\n\ndef make_dataset(dir):\n    images = []\n    assert os.path.isdir(dir), '%s is not a valid directory' % dir\n\n    for root, _, fnames in sorted(os.walk(dir)):\n        for fname in fnames:\n            if is_image_file(fname):\n                #print(fname)\n                path = os.path.join(root, fname)\n                images.append(path)\n\n    return images\n\n### Modify these 3 lines in your own environment\nindir=\"/home/ziyuwan/workspace/data/temp_old\"\ntarget_folders=['VOC','Real_L_old','Real_RGB_old']\nout_dir =\"/home/ziyuwan/workspace/data/temp_old\"\n###\n\nif os.path.exists(out_dir) is False:\n    os.makedirs(out_dir)\n\n#\nfor target_folder in target_folders:\n    curr_indir = os.path.join(indir, target_folder)\n    curr_out_file = os.path.join(os.path.join(out_dir, '%s.bigfile'%(target_folder)))\n    image_lists = make_dataset(curr_indir)\n    image_lists.sort()\n    with open(curr_out_file, 'wb') as wfid:\n        # write total image number\n        wfid.write(struct.pack('i', len(image_lists)))\n        for i, img_path in enumerate(image_lists):\n             # write file name first\n             img_name = os.path.basename(img_path)\n             img_name_bytes = img_name.encode('utf-8')\n             wfid.write(struct.pack('i', len(img_name_bytes)))\n             wfid.write(img_name_bytes)\n    #\n    #             # write image data in\n             with open(img_path, 'rb') as img_fid:\n                 img_bytes = img_fid.read()\n             wfid.write(struct.pack('i', len(img_bytes)))\n             wfid.write(img_bytes)\n\n             if i % 1000 == 0:\n                
 print('write %d images done' % i)"
  },
  {
    "path": "Global/data/Load_Bigfile.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport io\nimport os\nimport struct\nfrom PIL import Image\n\nclass BigFileMemoryLoader(object):\n    def __load_bigfile(self):\n        print('start load bigfile (%0.02f GB) into memory' % (os.path.getsize(self.file_path)/1024/1024/1024))\n        with open(self.file_path, 'rb') as fid:\n            self.img_num = struct.unpack('i', fid.read(4))[0]\n            self.img_names = []\n            self.img_bytes = []\n            print('find total %d images' % self.img_num)\n            for i in range(self.img_num):\n                img_name_len = struct.unpack('i', fid.read(4))[0]\n                img_name = fid.read(img_name_len).decode('utf-8')\n                self.img_names.append(img_name)\n                img_bytes_len = struct.unpack('i', fid.read(4))[0]\n                self.img_bytes.append(fid.read(img_bytes_len))\n                if i % 5000 == 0:\n                    print('load %d images done' % i)\n            print('load all %d images done' % self.img_num)\n\n    def __init__(self, file_path):\n        super(BigFileMemoryLoader, self).__init__()\n        self.file_path = file_path\n        self.__load_bigfile()\n\n    def __getitem__(self, index):\n        try:\n            img = Image.open(io.BytesIO(self.img_bytes[index])).convert('RGB')\n            return self.img_names[index], img\n        except Exception:\n            print('Image read error for index %d: %s' % (index, self.img_names[index]))\n            return self.__getitem__((index+1)%self.img_num)\n\n\n    def __len__(self):\n        return self.img_num\n"
  },
  {
    "path": "Global/data/__init__.py",
    "content": ""
  },
  {
    "path": "Global/data/base_data_loader.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nclass BaseDataLoader():\n    def __init__(self):\n        pass\n    \n    def initialize(self, opt):\n        self.opt = opt\n        pass\n\n    def load_data():\n        return None\n\n        \n        \n"
  },
  {
    "path": "Global/data/base_dataset.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch.utils.data as data\nfrom PIL import Image\nimport torchvision.transforms as transforms\nimport numpy as np\nimport random\n\nclass BaseDataset(data.Dataset):\n    def __init__(self):\n        super(BaseDataset, self).__init__()\n\n    def name(self):\n        return 'BaseDataset'\n\n    def initialize(self, opt):\n        pass\n\ndef get_params(opt, size):\n    w, h = size\n    new_h = h\n    new_w = w\n    if opt.resize_or_crop == 'resize_and_crop':\n        new_h = new_w = opt.loadSize\n\n    if opt.resize_or_crop == 'scale_width_and_crop': # we scale the shorter side into 256\n\n        if w<h:\n            new_w = opt.loadSize\n            new_h = opt.loadSize * h // w\n        else:\n            new_h=opt.loadSize\n            new_w = opt.loadSize * w // h\n\n    if opt.resize_or_crop=='crop_only':\n        pass\n\n\n    x = random.randint(0, np.maximum(0, new_w - opt.fineSize))\n    y = random.randint(0, np.maximum(0, new_h - opt.fineSize))\n    \n    flip = random.random() > 0.5\n    return {'crop_pos': (x, y), 'flip': flip}\n\ndef get_transform(opt, params, method=Image.BICUBIC, normalize=True):\n    transform_list = []\n    if 'resize' in opt.resize_or_crop:\n        osize = [opt.loadSize, opt.loadSize]\n        transform_list.append(transforms.Scale(osize, method))   \n    elif 'scale_width' in opt.resize_or_crop:\n    #    transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method)))  ## Here , We want the shorter side to match 256, and Scale will finish it.\n        transform_list.append(transforms.Scale(256,method))\n\n    if 'crop' in opt.resize_or_crop:\n        if opt.isTrain:\n            transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))\n        else:\n            if opt.test_random_crop:\n                
transform_list.append(transforms.RandomCrop(opt.fineSize))\n            else:\n                transform_list.append(transforms.CenterCrop(opt.fineSize))\n\n    ## when testing, for ablation study, choose center_crop directly.\n\n\n\n    if opt.resize_or_crop == 'none':\n        base = float(2 ** opt.n_downsample_global)\n        if opt.netG == 'local':\n            base *= (2 ** opt.n_local_enhancers)\n        transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))\n\n    if opt.isTrain and not opt.no_flip:\n        transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))\n\n    transform_list += [transforms.ToTensor()]\n\n    if normalize:\n        transform_list += [transforms.Normalize((0.5, 0.5, 0.5),\n                                                (0.5, 0.5, 0.5))]\n    return transforms.Compose(transform_list)\n\ndef normalize():    \n    return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n\ndef __make_power_2(img, base, method=Image.BICUBIC):\n    ow, oh = img.size        \n    h = int(round(oh / base) * base)\n    w = int(round(ow / base) * base)\n    if (h == oh) and (w == ow):\n        return img\n    return img.resize((w, h), method)\n\ndef __scale_width(img, target_width, method=Image.BICUBIC):\n    ow, oh = img.size\n    if (ow == target_width):\n        return img    \n    w = target_width\n    h = int(target_width * oh / ow)    \n    return img.resize((w, h), method)\n\ndef __crop(img, pos, size):\n    ow, oh = img.size\n    x1, y1 = pos\n    tw = th = size\n    if (ow > tw or oh > th):        \n        return img.crop((x1, y1, x1 + tw, y1 + th))\n    return img\n\ndef __flip(img, flip):\n    if flip:\n        return img.transpose(Image.FLIP_LEFT_RIGHT)\n    return img\n"
  },
  {
    "path": "Global/data/custom_dataset_data_loader.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch.utils.data\nimport random\nfrom data.base_data_loader import BaseDataLoader\nfrom data import online_dataset_for_old_photos as dts_ray_bigfile\n\n\ndef CreateDataset(opt):\n    dataset = None\n    if opt.training_dataset=='domain_A' or opt.training_dataset=='domain_B':\n        dataset = dts_ray_bigfile.UnPairOldPhotos_SR()\n    if opt.training_dataset=='mapping':\n        if opt.random_hole:\n            dataset = dts_ray_bigfile.PairOldPhotos_with_hole()\n        else:\n            dataset = dts_ray_bigfile.PairOldPhotos()\n    print(\"dataset [%s] was created\" % (dataset.name()))\n    dataset.initialize(opt)\n    return dataset\n\nclass CustomDatasetDataLoader(BaseDataLoader):\n    def name(self):\n        return 'CustomDatasetDataLoader'\n\n    def initialize(self, opt):\n        BaseDataLoader.initialize(self, opt)\n        self.dataset = CreateDataset(opt)\n        self.dataloader = torch.utils.data.DataLoader(\n            self.dataset,\n            batch_size=opt.batchSize,\n            shuffle=not opt.serial_batches,\n            num_workers=int(opt.nThreads),\n            drop_last=True)\n\n    def load_data(self):\n        return self.dataloader\n\n    def __len__(self):\n        return min(len(self.dataset), self.opt.max_dataset_size)\n"
  },
  {
    "path": "Global/data/data_loader.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\ndef CreateDataLoader(opt):\n    from data.custom_dataset_data_loader import CustomDatasetDataLoader\n    data_loader = CustomDatasetDataLoader()\n    print(data_loader.name())\n    data_loader.initialize(opt)\n    return data_loader\n"
  },
  {
    "path": "Global/data/image_folder.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch.utils.data as data\nfrom PIL import Image\nimport os\n\nIMG_EXTENSIONS = [\n    '.jpg', '.JPG', '.jpeg', '.JPEG',\n    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff'\n]\n\n\ndef is_image_file(filename):\n    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)\n\n\ndef make_dataset(dir):\n    images = []\n    assert os.path.isdir(dir), '%s is not a valid directory' % dir\n\n    for root, _, fnames in sorted(os.walk(dir)):\n        for fname in fnames:\n            if is_image_file(fname):\n                path = os.path.join(root, fname)\n                images.append(path)\n\n    return images\n\n\ndef default_loader(path):\n    return Image.open(path).convert('RGB')\n\n\nclass ImageFolder(data.Dataset):\n\n    def __init__(self, root, transform=None, return_paths=False,\n                 loader=default_loader):\n        imgs = make_dataset(root)\n        if len(imgs) == 0:\n            raise(RuntimeError(\"Found 0 images in: \" + root + \"\\n\"\n                               \"Supported image extensions are: \" +\n                               \",\".join(IMG_EXTENSIONS)))\n\n        self.root = root\n        self.imgs = imgs\n        self.transform = transform\n        self.return_paths = return_paths\n        self.loader = loader\n\n    def __getitem__(self, index):\n        path = self.imgs[index]\n        img = self.loader(path)\n        if self.transform is not None:\n            img = self.transform(img)\n        if self.return_paths:\n            return img, path\n        else:\n            return img\n\n    def __len__(self):\n        return len(self.imgs)\n"
  },
  {
    "path": "Global/data/online_dataset_for_old_photos.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport os.path\nimport io\nimport zipfile\nfrom data.base_dataset import BaseDataset, get_params, get_transform, normalize\nfrom data.image_folder import make_dataset\nfrom PIL import Image\nimport torchvision.transforms as transforms\nimport numpy as np\nfrom data.Load_Bigfile import BigFileMemoryLoader\nimport random\nimport cv2\nfrom io import BytesIO\n\ndef pil_to_np(img_PIL):\n    '''Converts image in PIL format to np.array.\n\n    From W x H x C [0...255] to C x W x H [0..1]\n    '''\n    ar = np.array(img_PIL)\n\n    if len(ar.shape) == 3:\n        ar = ar.transpose(2, 0, 1)\n    else:\n        ar = ar[None, ...]\n\n    return ar.astype(np.float32) / 255.\n\n\ndef np_to_pil(img_np):\n    '''Converts image in np.array format to PIL image.\n\n    From C x W x H [0..1] to  W x H x C [0...255]\n    '''\n    ar = np.clip(img_np * 255, 0, 255).astype(np.uint8)\n\n    if img_np.shape[0] == 1:\n        ar = ar[0]\n    else:\n        ar = ar.transpose(1, 2, 0)\n\n    return Image.fromarray(ar)\n\ndef synthesize_salt_pepper(image,amount,salt_vs_pepper):\n\n    ## Give PIL, return the noisy PIL\n\n    img_pil=pil_to_np(image)\n\n    out = img_pil.copy()\n    p = amount\n    q = salt_vs_pepper\n    flipped = np.random.choice([True, False], size=img_pil.shape,\n                               p=[p, 1 - p])\n    salted = np.random.choice([True, False], size=img_pil.shape,\n                              p=[q, 1 - q])\n    peppered = ~salted\n    out[flipped & salted] = 1\n    out[flipped & peppered] = 0.\n    noisy = np.clip(out, 0, 1).astype(np.float32)\n\n\n    return np_to_pil(noisy)\n\ndef synthesize_gaussian(image,std_l,std_r):\n\n    ## Give PIL, return the noisy PIL\n\n    img_pil=pil_to_np(image)\n\n    mean=0\n    std=random.uniform(std_l/255.,std_r/255.)\n    gauss=np.random.normal(loc=mean,scale=std,size=img_pil.shape)\n    noisy=img_pil+gauss\n    
noisy=np.clip(noisy,0,1).astype(np.float32)\n\n    return np_to_pil(noisy)\n\ndef synthesize_speckle(image,std_l,std_r):\n\n    ## Give PIL, return the noisy PIL\n\n    img_pil=pil_to_np(image)\n\n    mean=0\n    std=random.uniform(std_l/255.,std_r/255.)\n    gauss=np.random.normal(loc=mean,scale=std,size=img_pil.shape)\n    noisy=img_pil+gauss*img_pil\n    noisy=np.clip(noisy,0,1).astype(np.float32)\n\n    return np_to_pil(noisy)\n\n\ndef synthesize_low_resolution(img):\n    w,h=img.size\n\n    new_w=random.randint(int(w/2),w)\n    new_h=random.randint(int(h/2),h)\n\n    img=img.resize((new_w,new_h),Image.BICUBIC)\n\n    if random.uniform(0,1)<0.5:\n        img=img.resize((w,h),Image.NEAREST)\n    else:\n        img = img.resize((w, h), Image.BILINEAR)\n\n    return img\n\n\ndef convertToJpeg(im,quality):\n    with BytesIO() as f:\n        im.save(f, format='JPEG',quality=quality)\n        f.seek(0)\n        return Image.open(f).convert('RGB')\n\n\ndef blur_image_v2(img):\n\n\n    x=np.array(img)\n    kernel_size_candidate=[(3,3),(5,5),(7,7)]\n    kernel_size=random.sample(kernel_size_candidate,1)[0]\n    std=random.uniform(1.,5.)\n\n    #print(\"The gaussian kernel size: (%d,%d) std: %.2f\"%(kernel_size[0],kernel_size[1],std))\n    blur=cv2.GaussianBlur(x,kernel_size,std)\n\n    return Image.fromarray(blur.astype(np.uint8))\n\ndef online_add_degradation_v2(img):\n\n    task_id=np.random.permutation(4)\n\n    for x in task_id:\n        if x==0 and random.uniform(0,1)<0.7:\n            img = blur_image_v2(img)\n        if x==1 and random.uniform(0,1)<0.7:\n            flag = random.choice([1, 2, 3])\n            if flag == 1:\n                img = synthesize_gaussian(img, 5, 50)\n            if flag == 2:\n                img = synthesize_speckle(img, 5, 50)\n            if flag == 3:\n                img = synthesize_salt_pepper(img, random.uniform(0, 0.01), random.uniform(0.3, 0.8))\n        if x==2 and random.uniform(0,1)<0.7:\n            
img=synthesize_low_resolution(img)\n\n        if x==3 and random.uniform(0,1)<0.7:\n            img=convertToJpeg(img,random.randint(40,100))\n\n    return img\n\n\ndef irregular_hole_synthesize(img,mask):\n\n    img_np=np.array(img).astype('uint8')\n    mask_np=np.array(mask).astype('uint8')\n    mask_np=mask_np/255\n    img_new=img_np*(1-mask_np)+mask_np*255\n\n\n    hole_img=Image.fromarray(img_new.astype('uint8')).convert(\"RGB\")\n\n    return hole_img,mask.convert(\"L\")\n\ndef zero_mask(size):\n    x=np.zeros((size,size,3)).astype('uint8')\n    mask=Image.fromarray(x).convert(\"RGB\")\n    return mask\n\n\n\nclass UnPairOldPhotos_SR(BaseDataset):  ## Synthetic + Real Old\n    def initialize(self, opt):\n        self.opt = opt\n        self.isImage = 'domainA' in opt.name\n        self.task = 'old_photo_restoration_training_vae'\n        self.dir_AB = opt.dataroot\n        if self.isImage:\n\n            self.load_img_dir_L_old=os.path.join(self.dir_AB,\"Real_L_old.bigfile\")\n            self.load_img_dir_RGB_old=os.path.join(self.dir_AB,\"Real_RGB_old.bigfile\")\n            self.load_img_dir_clean=os.path.join(self.dir_AB,\"VOC_RGB_JPEGImages.bigfile\")\n\n            self.loaded_imgs_L_old=BigFileMemoryLoader(self.load_img_dir_L_old)\n            self.loaded_imgs_RGB_old=BigFileMemoryLoader(self.load_img_dir_RGB_old)\n            self.loaded_imgs_clean=BigFileMemoryLoader(self.load_img_dir_clean)\n\n        else:\n            # self.load_img_dir_clean=os.path.join(self.dir_AB,self.opt.test_dataset)\n            self.load_img_dir_clean=os.path.join(self.dir_AB,\"VOC_RGB_JPEGImages.bigfile\")\n            self.loaded_imgs_clean=BigFileMemoryLoader(self.load_img_dir_clean)\n\n        ####\n        print(\"-------------Filter the imgs whose size <256 in VOC-------------\")\n        self.filtered_imgs_clean=[]\n        for i in range(len(self.loaded_imgs_clean)):\n            img_name,img=self.loaded_imgs_clean[i]\n            h,w=img.size\n            if 
h<256 or w<256:\n                continue\n            self.filtered_imgs_clean.append((img_name,img))\n\n        print(\"--------Origin image num is [%d], filtered result is [%d]--------\" % (\n        len(self.loaded_imgs_clean), len(self.filtered_imgs_clean)))\n        ## Filter these images whose size is less than 256\n\n        # self.img_list=os.listdir(load_img_dir)\n        self.pid = os.getpid()\n\n    def __getitem__(self, index):\n\n\n        is_real_old=0\n\n        sampled_dataset=None\n        degradation=None\n        if self.isImage: ## domain A , contains 2 kinds of data: synthetic + real_old\n            P=random.uniform(0,2)\n            if P>=0 and P<1:\n                if random.uniform(0,1)<0.5:\n                    sampled_dataset=self.loaded_imgs_L_old\n                    self.load_img_dir=self.load_img_dir_L_old\n                else:\n                    sampled_dataset=self.loaded_imgs_RGB_old\n                    self.load_img_dir=self.load_img_dir_RGB_old\n                is_real_old=1\n            if P>=1 and P<2:\n                sampled_dataset=self.filtered_imgs_clean\n                self.load_img_dir=self.load_img_dir_clean\n                degradation=1\n        else:\n\n            sampled_dataset=self.filtered_imgs_clean\n            self.load_img_dir=self.load_img_dir_clean\n\n        sampled_dataset_len=len(sampled_dataset)\n\n        index=random.randint(0,sampled_dataset_len-1)\n\n        img_name,img = sampled_dataset[index]\n\n        if degradation is not None:\n            img=online_add_degradation_v2(img)\n\n        path=os.path.join(self.load_img_dir,img_name)\n\n        # AB = Image.open(path).convert('RGB')\n        # split AB image into A and B\n\n        # apply the same transform to both A and B\n\n        if random.uniform(0,1) <0.1:\n            img=img.convert(\"L\")\n            img=img.convert(\"RGB\")\n            ## Give a probability P, we convert the RGB image into L\n\n\n        A=img\n        
w,h=A.size\n        if w<256 or h<256:\n            A=transforms.Scale(256,Image.BICUBIC)(A)\n        ## Since we want to only crop the images (256*256), for those old photos whose size is smaller than 256, we first resize them.\n\n        transform_params = get_params(self.opt, A.size)\n        A_transform = get_transform(self.opt, transform_params)\n\n        B_tensor = inst_tensor = feat_tensor = 0\n        A_tensor = A_transform(A)\n\n\n        input_dict = {'label': A_tensor, 'inst': is_real_old, 'image': A_tensor,\n                        'feat': feat_tensor, 'path': path}\n        return input_dict\n\n    def __len__(self):\n        return len(self.loaded_imgs_clean) ## actually, this is useless, since the selected index is just a random number\n\n    def name(self):\n        return 'UnPairOldPhotos_SR'\n\n\nclass PairOldPhotos(BaseDataset):\n    def initialize(self, opt):\n        self.opt = opt\n        self.isImage = 'imagegan' in opt.name\n        self.task = 'old_photo_restoration_training_mapping'\n        self.dir_AB = opt.dataroot\n        if opt.isTrain:\n            self.load_img_dir_clean= os.path.join(self.dir_AB, \"VOC_RGB_JPEGImages.bigfile\")\n            self.loaded_imgs_clean = BigFileMemoryLoader(self.load_img_dir_clean)\n\n            print(\"-------------Filter the imgs whose size <256 in VOC-------------\")\n            self.filtered_imgs_clean = []\n            for i in range(len(self.loaded_imgs_clean)):\n                img_name, img = self.loaded_imgs_clean[i]\n                h, w = img.size\n                if h < 256 or w < 256:\n                    continue\n                self.filtered_imgs_clean.append((img_name, img))\n\n            print(\"--------Origin image num is [%d], filtered result is [%d]--------\" % (\n            len(self.loaded_imgs_clean), len(self.filtered_imgs_clean)))\n\n        else:\n            self.load_img_dir=os.path.join(self.dir_AB,opt.test_dataset)\n            
self.loaded_imgs=BigFileMemoryLoader(self.load_img_dir)\n\n        self.pid = os.getpid()\n\n    def __getitem__(self, index):\n\n\n\n        if self.opt.isTrain:\n            img_name_clean,B = self.filtered_imgs_clean[index]\n            path = os.path.join(self.load_img_dir_clean, img_name_clean)\n            if self.opt.use_v2_degradation:\n                A=online_add_degradation_v2(B)\n            ### Remind: A is the input and B is corresponding GT\n        else:\n\n            if self.opt.test_on_synthetic:\n\n                img_name_B,B=self.loaded_imgs[index]\n                A=online_add_degradation_v2(B)\n                img_name_A=img_name_B\n                path = os.path.join(self.load_img_dir, img_name_A)\n            else:\n                img_name_A,A=self.loaded_imgs[index]\n                img_name_B,B=self.loaded_imgs[index]\n                path = os.path.join(self.load_img_dir, img_name_A)\n\n\n        if random.uniform(0,1)<0.1 and self.opt.isTrain:\n            A=A.convert(\"L\")\n            B=B.convert(\"L\")\n            A=A.convert(\"RGB\")\n            B=B.convert(\"RGB\")\n        ## In P, we convert the RGB into L\n\n\n        ##test on L\n\n        # split AB image into A and B\n        # w, h = img.size\n        # w2 = int(w / 2)\n        # A = img.crop((0, 0, w2, h))\n        # B = img.crop((w2, 0, w, h))\n        w,h=A.size\n        if w<256 or h<256:\n            A=transforms.Scale(256,Image.BICUBIC)(A)\n            B=transforms.Scale(256, Image.BICUBIC)(B)\n\n        # apply the same transform to both A and B\n        transform_params = get_params(self.opt, A.size)\n        A_transform = get_transform(self.opt, transform_params)\n        B_transform = get_transform(self.opt, transform_params)\n\n        B_tensor = inst_tensor = feat_tensor = 0\n        A_tensor = A_transform(A)\n        B_tensor = B_transform(B)\n\n        input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor,\n                    'feat': 
feat_tensor, 'path': path}\n        return input_dict\n\n    def __len__(self):\n\n        if self.opt.isTrain:\n            return len(self.filtered_imgs_clean)\n        else:\n            return len(self.loaded_imgs)\n\n    def name(self):\n        return 'PairOldPhotos'\n\n\nclass PairOldPhotos_with_hole(BaseDataset):\n    def initialize(self, opt):\n        self.opt = opt\n        self.isImage = 'imagegan' in opt.name\n        self.task = 'old_photo_restoration_training_mapping'\n        self.dir_AB = opt.dataroot\n        if opt.isTrain:\n            self.load_img_dir_clean= os.path.join(self.dir_AB, \"VOC_RGB_JPEGImages.bigfile\")\n            self.loaded_imgs_clean = BigFileMemoryLoader(self.load_img_dir_clean)\n\n            print(\"-------------Filter the imgs whose size <256 in VOC-------------\")\n            self.filtered_imgs_clean = []\n            for i in range(len(self.loaded_imgs_clean)):\n                img_name, img = self.loaded_imgs_clean[i]\n                h, w = img.size\n                if h < 256 or w < 256:\n                    continue\n                self.filtered_imgs_clean.append((img_name, img))\n\n            print(\"--------Origin image num is [%d], filtered result is [%d]--------\" % (\n            len(self.loaded_imgs_clean), len(self.filtered_imgs_clean)))\n\n        else:\n            self.load_img_dir=os.path.join(self.dir_AB,opt.test_dataset)\n            self.loaded_imgs=BigFileMemoryLoader(self.load_img_dir)\n\n        self.loaded_masks = BigFileMemoryLoader(opt.irregular_mask)\n\n        self.pid = os.getpid()\n\n    def __getitem__(self, index):\n\n\n\n        if self.opt.isTrain:\n            img_name_clean,B = self.filtered_imgs_clean[index]\n            path = os.path.join(self.load_img_dir_clean, img_name_clean)\n\n\n            B=transforms.RandomCrop(256)(B)\n            A=online_add_degradation_v2(B)\n            ### Remind: A is the input and B is corresponding GT\n\n        else:\n            
img_name_A,A=self.loaded_imgs[index]\n            img_name_B,B=self.loaded_imgs[index]\n            path = os.path.join(self.load_img_dir, img_name_A)\n\n            #A=A.resize((256,256))\n            A=transforms.CenterCrop(256)(A)\n            B=A\n\n        if random.uniform(0,1)<0.1 and self.opt.isTrain:\n            A=A.convert(\"L\")\n            B=B.convert(\"L\")\n            A=A.convert(\"RGB\")\n            B=B.convert(\"RGB\")\n        ## In P, we convert the RGB into L\n\n        if self.opt.isTrain:\n            mask_name,mask=self.loaded_masks[random.randint(0,len(self.loaded_masks)-1)]\n        else:\n            mask_name, mask = self.loaded_masks[index%100]\n        mask = mask.resize((self.opt.loadSize, self.opt.loadSize), Image.NEAREST)\n\n        if self.opt.random_hole and random.uniform(0,1)>0.5 and self.opt.isTrain:\n            mask=zero_mask(256)\n\n        if self.opt.no_hole:\n            mask=zero_mask(256)\n\n\n        A,_=irregular_hole_synthesize(A,mask)\n\n        if not self.opt.isTrain and self.opt.hole_image_no_mask:\n            mask=zero_mask(256)\n\n        transform_params = get_params(self.opt, A.size)\n        A_transform = get_transform(self.opt, transform_params)\n        B_transform = get_transform(self.opt, transform_params)\n\n        if transform_params['flip'] and self.opt.isTrain:\n            mask=mask.transpose(Image.FLIP_LEFT_RIGHT)\n\n        mask_tensor = transforms.ToTensor()(mask)\n\n\n        B_tensor = inst_tensor = feat_tensor = 0\n        A_tensor = A_transform(A)\n        B_tensor = B_transform(B)\n\n        input_dict = {'label': A_tensor, 'inst': mask_tensor[:1], 'image': B_tensor,\n                    'feat': feat_tensor, 'path': path}\n        return input_dict\n\n    def __len__(self):\n\n        if self.opt.isTrain:\n            return len(self.filtered_imgs_clean)\n\n        else:\n            return len(self.loaded_imgs)\n\n    def name(self):\n        return 'PairOldPhotos_with_hole'"
  },
  {
    "path": "Global/detection.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport argparse\nimport gc\nimport json\nimport os\nimport time\nimport warnings\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torchvision as tv\nfrom PIL import Image, ImageFile\n\nfrom detection_models import networks\nfrom detection_util.util import *\n\nwarnings.filterwarnings(\"ignore\", category=UserWarning)\n\nImageFile.LOAD_TRUNCATED_IMAGES = True\n\n\ndef data_transforms(img, full_size, method=Image.BICUBIC):\n    if full_size == \"full_size\":\n        ow, oh = img.size\n        h = int(round(oh / 16) * 16)\n        w = int(round(ow / 16) * 16)\n        if (h == oh) and (w == ow):\n            return img\n        return img.resize((w, h), method)\n\n    elif full_size == \"scale_256\":\n        ow, oh = img.size\n        pw, ph = ow, oh\n        if ow < oh:\n            ow = 256\n            oh = ph / pw * 256\n        else:\n            oh = 256\n            ow = pw / ph * 256\n\n        h = int(round(oh / 16) * 16)\n        w = int(round(ow / 16) * 16)\n        if (h == ph) and (w == pw):\n            return img\n        return img.resize((w, h), method)\n\n\ndef scale_tensor(img_tensor, default_scale=256):\n    _, _, w, h = img_tensor.shape\n    if w < h:\n        ow = default_scale\n        oh = h / w * default_scale\n    else:\n        oh = default_scale\n        ow = w / h * default_scale\n\n    oh = int(round(oh / 16) * 16)\n    ow = int(round(ow / 16) * 16)\n\n    return F.interpolate(img_tensor, [ow, oh], mode=\"bilinear\")\n\n\ndef blend_mask(img, mask):\n\n    np_img = np.array(img).astype(\"float\")\n\n    return Image.fromarray((np_img * (1 - mask) + mask * 255.0).astype(\"uint8\")).convert(\"RGB\")\n\n\ndef main(config):\n    print(\"initializing the dataloader\")\n\n    model = networks.UNet(\n        in_channels=1,\n        out_channels=1,\n        depth=4,\n        conv_num=2,\n        wf=6,\n        padding=True,\n  
      batch_norm=True,\n        up_mode=\"upsample\",\n        with_tanh=False,\n        sync_bn=True,\n        antialiasing=True,\n    )\n\n    ## load model\n    checkpoint_path = os.path.join(os.path.dirname(__file__), \"checkpoints/detection/FT_Epoch_latest.pt\")\n    checkpoint = torch.load(checkpoint_path, map_location=\"cpu\")\n    model.load_state_dict(checkpoint[\"model_state\"])\n    print(\"model weights loaded\")\n\n    if config.GPU >= 0:\n        model.to(config.GPU)\n    else: \n        model.cpu()\n    model.eval()\n\n    ## dataloader and transformation\n    print(\"directory of testing image: \" + config.test_path)\n    imagelist = os.listdir(config.test_path)\n    imagelist.sort()\n    total_iter = 0\n\n    P_matrix = {}\n    save_url = os.path.join(config.output_dir)\n    mkdir_if_not(save_url)\n\n    input_dir = os.path.join(save_url, \"input\")\n    output_dir = os.path.join(save_url, \"mask\")\n    # blend_output_dir=os.path.join(save_url, 'blend_output')\n    mkdir_if_not(input_dir)\n    mkdir_if_not(output_dir)\n    # mkdir_if_not(blend_output_dir)\n\n    idx = 0\n\n    results = []\n    for image_name in imagelist:\n\n        idx += 1\n\n        print(\"processing\", image_name)\n\n        scratch_file = os.path.join(config.test_path, image_name)\n        if not os.path.isfile(scratch_file):\n            print(\"Skipping non-file %s\" % image_name)\n            continue\n        scratch_image = Image.open(scratch_file).convert(\"RGB\")\n        w, h = scratch_image.size\n\n        transformed_image_PIL = data_transforms(scratch_image, config.input_size)\n        scratch_image = transformed_image_PIL.convert(\"L\")\n        scratch_image = tv.transforms.ToTensor()(scratch_image)\n        scratch_image = tv.transforms.Normalize([0.5], [0.5])(scratch_image)\n        scratch_image = torch.unsqueeze(scratch_image, 0)\n        _, _, ow, oh = scratch_image.shape\n        scratch_image_scale = scale_tensor(scratch_image)\n\n        if config.GPU 
>= 0:\n            scratch_image_scale = scratch_image_scale.to(config.GPU)\n        else:\n            scratch_image_scale = scratch_image_scale.cpu()\n        with torch.no_grad():\n            P = torch.sigmoid(model(scratch_image_scale))\n\n        P = P.data.cpu()\n        P = F.interpolate(P, [ow, oh], mode=\"nearest\")\n\n        tv.utils.save_image(\n            (P >= 0.4).float(),\n            os.path.join(\n                output_dir,\n                image_name[:-4] + \".png\",\n            ),\n            nrow=1,\n            padding=0,\n            normalize=True,\n        )\n        transformed_image_PIL.save(os.path.join(input_dir, image_name[:-4] + \".png\"))\n        gc.collect()\n        torch.cuda.empty_cache()\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    # parser.add_argument('--checkpoint_name', type=str, default=\"FT_Epoch_latest.pt\", help='Checkpoint Name')\n\n    parser.add_argument(\"--GPU\", type=int, default=0)\n    parser.add_argument(\"--test_path\", type=str, default=\".\")\n    parser.add_argument(\"--output_dir\", type=str, default=\".\")\n    parser.add_argument(\"--input_size\", type=str, default=\"scale_256\", help=\"resize_256|full_size|scale_256\")\n    config = parser.parse_args()\n\n    main(config)\n"
  },
  {
    "path": "Global/detection_models/__init__.py",
    "content": ""
  },
  {
    "path": "Global/detection_models/antialiasing.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch\nimport torch.nn.parallel\nimport numpy as np\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass Downsample(nn.Module):\n    # https://github.com/adobe/antialiased-cnns\n\n    def __init__(self, pad_type=\"reflect\", filt_size=3, stride=2, channels=None, pad_off=0):\n        super(Downsample, self).__init__()\n        self.filt_size = filt_size\n        self.pad_off = pad_off\n        self.pad_sizes = [\n            int(1.0 * (filt_size - 1) / 2),\n            int(np.ceil(1.0 * (filt_size - 1) / 2)),\n            int(1.0 * (filt_size - 1) / 2),\n            int(np.ceil(1.0 * (filt_size - 1) / 2)),\n        ]\n        self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]\n        self.stride = stride\n        self.off = int((self.stride - 1) / 2.0)\n        self.channels = channels\n\n        # print('Filter size [%i]'%filt_size)\n        if self.filt_size == 1:\n            a = np.array([1.0,])\n        elif self.filt_size == 2:\n            a = np.array([1.0, 1.0])\n        elif self.filt_size == 3:\n            a = np.array([1.0, 2.0, 1.0])\n        elif self.filt_size == 4:\n            a = np.array([1.0, 3.0, 3.0, 1.0])\n        elif self.filt_size == 5:\n            a = np.array([1.0, 4.0, 6.0, 4.0, 1.0])\n        elif self.filt_size == 6:\n            a = np.array([1.0, 5.0, 10.0, 10.0, 5.0, 1.0])\n        elif self.filt_size == 7:\n            a = np.array([1.0, 6.0, 15.0, 20.0, 15.0, 6.0, 1.0])\n\n        filt = torch.Tensor(a[:, None] * a[None, :])\n        filt = filt / torch.sum(filt)\n        self.register_buffer(\"filt\", filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))\n\n        self.pad = get_pad_layer(pad_type)(self.pad_sizes)\n\n    def forward(self, inp):\n        if self.filt_size == 1:\n            if self.pad_off == 0:\n                return inp[:, :, :: self.stride, :: self.stride]\n            
else:\n                return self.pad(inp)[:, :, :: self.stride, :: self.stride]\n        else:\n            return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])\n\n\ndef get_pad_layer(pad_type):\n    if pad_type in [\"refl\", \"reflect\"]:\n        PadLayer = nn.ReflectionPad2d\n    elif pad_type in [\"repl\", \"replicate\"]:\n        PadLayer = nn.ReplicationPad2d\n    elif pad_type == \"zero\":\n        PadLayer = nn.ZeroPad2d\n    else:\n        print(\"Pad type [%s] not recognized\" % pad_type)\n    return PadLayer\n"
  },
  {
    "path": "Global/detection_models/networks.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom detection_models.sync_batchnorm import DataParallelWithCallback\nfrom detection_models.antialiasing import Downsample\n\n\nclass UNet(nn.Module):\n    def __init__(\n        self,\n        in_channels=3,\n        out_channels=3,\n        depth=5,\n        conv_num=2,\n        wf=6,\n        padding=True,\n        batch_norm=True,\n        up_mode=\"upsample\",\n        with_tanh=False,\n        sync_bn=True,\n        antialiasing=True,\n    ):\n        \"\"\"\n\t\tImplementation of\n\t\tU-Net: Convolutional Networks for Biomedical Image Segmentation\n\t\t(Ronneberger et al., 2015)\n\t\thttps://arxiv.org/abs/1505.04597\n\t\tUsing the default arguments will yield the exact version used\n\t\tin the original paper\n\t\tArgs:\n\t\t\tin_channels (int): number of input channels\n\t\t\tout_channels (int): number of output channels\n\t\t\tdepth (int): depth of the network\n\t\t\twf (int): number of filters in the first layer is 2**wf\n\t\t\tpadding (bool): if True, apply padding such that the input shape\n\t\t\t\t\t\t\tis the same as the output.\n\t\t\t\t\t\t\tThis may introduce artifacts\n\t\t\tbatch_norm (bool): Use BatchNorm after layers with an\n\t\t\t\t\t\t\t   activation function\n\t\t\tup_mode (str): one of 'upconv' or 'upsample'.\n\t\t\t\t\t\t   'upconv' will use transposed convolutions for\n\t\t\t\t\t\t   learned upsampling.\n\t\t\t\t\t\t   'upsample' will use bilinear upsampling.\n\t\t\"\"\"\n        super().__init__()\n        assert up_mode in (\"upconv\", \"upsample\")\n        self.padding = padding\n        self.depth = depth - 1\n        prev_channels = in_channels\n\n        self.first = nn.Sequential(\n            *[nn.ReflectionPad2d(3), nn.Conv2d(in_channels, 2 ** wf, kernel_size=7), nn.LeakyReLU(0.2, True)]\n        )\n        prev_channels = 2 ** wf\n\n        self.down_path = 
nn.ModuleList()\n        self.down_sample = nn.ModuleList()\n        for i in range(depth):\n            if antialiasing and depth > 0:\n                self.down_sample.append(\n                    nn.Sequential(\n                        *[\n                            nn.ReflectionPad2d(1),\n                            nn.Conv2d(prev_channels, prev_channels, kernel_size=3, stride=1, padding=0),\n                            nn.BatchNorm2d(prev_channels),\n                            nn.LeakyReLU(0.2, True),\n                            Downsample(channels=prev_channels, stride=2),\n                        ]\n                    )\n                )\n            else:\n                self.down_sample.append(\n                    nn.Sequential(\n                        *[\n                            nn.ReflectionPad2d(1),\n                            nn.Conv2d(prev_channels, prev_channels, kernel_size=4, stride=2, padding=0),\n                            nn.BatchNorm2d(prev_channels),\n                            nn.LeakyReLU(0.2, True),\n                        ]\n                    )\n                )\n            self.down_path.append(\n                UNetConvBlock(conv_num, prev_channels, 2 ** (wf + i + 1), padding, batch_norm)\n            )\n            prev_channels = 2 ** (wf + i + 1)\n\n        self.up_path = nn.ModuleList()\n        for i in reversed(range(depth)):\n            self.up_path.append(\n                UNetUpBlock(conv_num, prev_channels, 2 ** (wf + i), up_mode, padding, batch_norm)\n            )\n            prev_channels = 2 ** (wf + i)\n\n        if with_tanh:\n            self.last = nn.Sequential(\n                *[nn.ReflectionPad2d(1), nn.Conv2d(prev_channels, out_channels, kernel_size=3), nn.Tanh()]\n            )\n        else:\n            self.last = nn.Sequential(\n                *[nn.ReflectionPad2d(1), nn.Conv2d(prev_channels, out_channels, kernel_size=3)]\n            )\n\n        if sync_bn:\n            self = 
DataParallelWithCallback(self)\n\n    def forward(self, x):\n        x = self.first(x)\n\n        blocks = []\n        for i, down_block in enumerate(self.down_path):\n            blocks.append(x)\n            x = self.down_sample[i](x)\n            x = down_block(x)\n\n        for i, up in enumerate(self.up_path):\n            x = up(x, blocks[-i - 1])\n\n        return self.last(x)\n\n\nclass UNetConvBlock(nn.Module):\n    def __init__(self, conv_num, in_size, out_size, padding, batch_norm):\n        super(UNetConvBlock, self).__init__()\n        block = []\n\n        for _ in range(conv_num):\n            block.append(nn.ReflectionPad2d(padding=int(padding)))\n            block.append(nn.Conv2d(in_size, out_size, kernel_size=3, padding=0))\n            if batch_norm:\n                block.append(nn.BatchNorm2d(out_size))\n            block.append(nn.LeakyReLU(0.2, True))\n            in_size = out_size\n\n        self.block = nn.Sequential(*block)\n\n    def forward(self, x):\n        out = self.block(x)\n        return out\n\n\nclass UNetUpBlock(nn.Module):\n    def __init__(self, conv_num, in_size, out_size, up_mode, padding, batch_norm):\n        super(UNetUpBlock, self).__init__()\n        if up_mode == \"upconv\":\n            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)\n        elif up_mode == \"upsample\":\n            self.up = nn.Sequential(\n                nn.Upsample(mode=\"bilinear\", scale_factor=2, align_corners=False),\n                nn.ReflectionPad2d(1),\n                nn.Conv2d(in_size, out_size, kernel_size=3, padding=0),\n            )\n\n        self.conv_block = UNetConvBlock(conv_num, in_size, out_size, padding, batch_norm)\n\n    def center_crop(self, layer, target_size):\n        _, _, layer_height, layer_width = layer.size()\n        diff_y = (layer_height - target_size[0]) // 2\n        diff_x = (layer_width - target_size[1]) // 2\n        return layer[:, :, diff_y : (diff_y + target_size[0]), diff_x 
: (diff_x + target_size[1])]\n\n    def forward(self, x, bridge):\n        up = self.up(x)\n        crop1 = self.center_crop(bridge, up.shape[2:])\n        out = torch.cat([up, crop1], 1)\n        out = self.conv_block(out)\n\n        return out\n\n\nclass UnetGenerator(nn.Module):\n    \"\"\"Create a Unet-based generator\"\"\"\n\n    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_type=\"BN\", use_dropout=False):\n        \"\"\"Construct a Unet generator\n\t\tParameters:\n\t\t\tinput_nc (int)  -- the number of channels in input images\n\t\t\toutput_nc (int) -- the number of channels in output images\n\t\t\tnum_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,\n\t\t\t\t\t\t\t\timage of size 128x128 will become of size 1x1 # at the bottleneck\n\t\t\tngf (int)       -- the number of filters in the last conv layer\n\t\t\tnorm_layer      -- normalization layer\n\t\tWe construct the U-Net from the innermost layer to the outermost layer.\n\t\tIt is a recursive process.\n\t\t\"\"\"\n        super().__init__()\n        if norm_type == \"BN\":\n            norm_layer = nn.BatchNorm2d\n        elif norm_type == \"IN\":\n            norm_layer = nn.InstanceNorm2d\n        else:\n            raise NameError(\"Unknown norm layer\")\n\n        # construct unet structure\n        unet_block = UnetSkipConnectionBlock(\n            ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True\n        )  # add the innermost layer\n        for i in range(num_downs - 5):  # add intermediate layers with ngf * 8 filters\n            unet_block = UnetSkipConnectionBlock(\n                ngf * 8,\n                ngf * 8,\n                input_nc=None,\n                submodule=unet_block,\n                norm_layer=norm_layer,\n                use_dropout=use_dropout,\n            )\n        # gradually reduce the number of filters from ngf * 8 to ngf\n        unet_block = UnetSkipConnectionBlock(\n       
     ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer\n        )\n        unet_block = UnetSkipConnectionBlock(\n            ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer\n        )\n        unet_block = UnetSkipConnectionBlock(\n            ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer\n        )\n        self.model = UnetSkipConnectionBlock(\n            output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer\n        )  # add the outermost layer\n\n    def forward(self, input):\n        return self.model(input)\n\n\nclass UnetSkipConnectionBlock(nn.Module):\n    \"\"\"Defines the Unet submodule with skip connection.\n\n\t\t-------------------identity----------------------\n\t\t|-- downsampling -- |submodule| -- upsampling --|\n\t\"\"\"\n\n    def __init__(\n        self,\n        outer_nc,\n        inner_nc,\n        input_nc=None,\n        submodule=None,\n        outermost=False,\n        innermost=False,\n        norm_layer=nn.BatchNorm2d,\n        use_dropout=False,\n    ):\n        \"\"\"Construct a Unet submodule with skip connections.\n\t\tParameters:\n\t\t\touter_nc (int) -- the number of filters in the outer conv layer\n\t\t\tinner_nc (int) -- the number of filters in the inner conv layer\n\t\t\tinput_nc (int) -- the number of channels in input images/features\n\t\t\tsubmodule (UnetSkipConnectionBlock) -- previously defined submodules\n\t\t\toutermost (bool)    -- if this module is the outermost module\n\t\t\tinnermost (bool)    -- if this module is the innermost module\n\t\t\tnorm_layer          -- normalization layer\n\t\t\tuser_dropout (bool) -- if use dropout layers.\n\t\t\"\"\"\n        super().__init__()\n        self.outermost = outermost\n        use_bias = norm_layer == nn.InstanceNorm2d\n        if input_nc is None:\n            input_nc = outer_nc\n        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, 
stride=2, padding=1, bias=use_bias)\n        downrelu = nn.LeakyReLU(0.2, True)\n        downnorm = norm_layer(inner_nc)\n        uprelu = nn.LeakyReLU(0.2, True)\n        upnorm = norm_layer(outer_nc)\n\n        if outermost:\n            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)\n            down = [downconv]\n            up = [uprelu, upconv, nn.Tanh()]\n            model = down + [submodule] + up\n        elif innermost:\n            upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)\n            down = [downrelu, downconv]\n            up = [uprelu, upconv, upnorm]\n            model = down + up\n        else:\n            upconv = nn.ConvTranspose2d(\n                inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias\n            )\n            down = [downrelu, downconv, downnorm]\n            up = [uprelu, upconv, upnorm]\n\n            if use_dropout:\n                model = down + [submodule] + up + [nn.Dropout(0.5)]\n            else:\n                model = down + [submodule] + up\n\n        self.model = nn.Sequential(*model)\n\n    def forward(self, x):\n        if self.outermost:\n            return self.model(x)\n        else:  # add skip connections\n            return torch.cat([x, self.model(x)], 1)\n\n\n# ============================================\n# Network testing\n# ============================================\nif __name__ == \"__main__\":\n    from torchsummary import summary\n\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n    model = UNet_two_decoders(\n        in_channels=3,\n        out_channels1=3,\n        out_channels2=1,\n        depth=4,\n        conv_num=1,\n        wf=6,\n        padding=True,\n        batch_norm=True,\n        up_mode=\"upsample\",\n        with_tanh=False,\n    )\n    model.to(device)\n\n    model_pix2pix = UnetGenerator(3, 3, 5, ngf=64, norm_type=\"BN\", 
use_dropout=False)\n    model_pix2pix.to(device)\n\n    print(\"customized unet:\")\n    summary(model, (3, 256, 256))\n\n    print(\"cyclegan unet:\")\n    summary(model_pix2pix, (3, 256, 256))\n\n    x = torch.zeros(1, 3, 256, 256).requires_grad_(True).cuda()\n    g = make_dot(model(x))\n    g.render(\"models/Digraph.gv\", view=False)\n\n"
  },
  {
    "path": "Global/detection_util/util.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport os\nimport sys\nimport time\nimport shutil\nimport platform\nimport numpy as np\nfrom datetime import datetime\n\nimport torch\nimport torchvision as tv\nimport torch.backends.cudnn as cudnn\n\n# from torch.utils.tensorboard import SummaryWriter\n\nimport yaml\nimport matplotlib.pyplot as plt\nfrom easydict import EasyDict as edict\nimport torchvision.utils as vutils\n\n\n##### option parsing ######\ndef print_options(config_dict):\n    print(\"------------ Options -------------\")\n    for k, v in sorted(config_dict.items()):\n        print(\"%s: %s\" % (str(k), str(v)))\n    print(\"-------------- End ----------------\")\n\n\ndef save_options(config_dict):\n    from time import gmtime, strftime\n\n    file_dir = os.path.join(config_dict[\"checkpoint_dir\"], config_dict[\"name\"])\n    mkdir_if_not(file_dir)\n    file_name = os.path.join(file_dir, \"opt.txt\")\n    with open(file_name, \"wt\") as opt_file:\n        opt_file.write(os.path.basename(sys.argv[0]) + \" \" + strftime(\"%Y-%m-%d %H:%M:%S\", gmtime()) + \"\\n\")\n        opt_file.write(\"------------ Options -------------\\n\")\n        for k, v in sorted(config_dict.items()):\n            opt_file.write(\"%s: %s\\n\" % (str(k), str(v)))\n        opt_file.write(\"-------------- End ----------------\\n\")\n\n\ndef config_parse(config_file, options, save=True):\n    with open(config_file, \"r\") as stream:\n        config_dict = yaml.safe_load(stream)\n        config = edict(config_dict)\n\n    for option_key, option_value in vars(options).items():\n        config_dict[option_key] = option_value\n        config[option_key] = option_value\n\n    if config.debug_mode:\n        config_dict[\"num_workers\"] = 0\n        config.num_workers = 0\n        config.batch_size = 2\n        if isinstance(config.gpu_ids, str):\n            config.gpu_ids = [int(x) for x in config.gpu_ids.split(\",\")][0]\n\n    
print_options(config_dict)\n    if save:\n        save_options(config_dict)\n\n    return config\n\n\n###### utility ######\ndef to_np(x):\n    return x.cpu().numpy()\n\n\ndef prepare_device(use_gpu, gpu_ids):\n    if use_gpu:\n        cudnn.benchmark = True\n        os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n        if isinstance(gpu_ids, str):\n            gpu_ids = [int(x) for x in gpu_ids.split(\",\")]\n            torch.cuda.set_device(gpu_ids[0])\n            device = torch.device(\"cuda:\" + str(gpu_ids[0]))\n        else:\n            torch.cuda.set_device(gpu_ids)\n            device = torch.device(\"cuda:\" + str(gpu_ids))\n        print(\"running on GPU {}\".format(gpu_ids))\n    else:\n        device = torch.device(\"cpu\")\n        print(\"running on CPU\")\n\n    return device\n\n\n###### file system ######\ndef get_dir_size(start_path=\".\"):\n    total_size = 0\n    for dirpath, dirnames, filenames in os.walk(start_path):\n        for f in filenames:\n            fp = os.path.join(dirpath, f)\n            total_size += os.path.getsize(fp)\n    return total_size\n\n\ndef mkdir_if_not(dir_path):\n    if not os.path.exists(dir_path):\n        os.makedirs(dir_path)\n\n\n##### System related ######\nclass Timer:\n    def __init__(self, msg):\n        self.msg = msg\n        self.start_time = None\n\n    def __enter__(self):\n        self.start_time = time.time()\n\n    def __exit__(self, exc_type, exc_value, exc_tb):\n        elapse = time.time() - self.start_time\n        print(self.msg % elapse)\n\n\n###### interactive ######\ndef get_size(start_path=\".\"):\n    total_size = 0\n    for dirpath, dirnames, filenames in os.walk(start_path):\n        for f in filenames:\n            fp = os.path.join(dirpath, f)\n            total_size += os.path.getsize(fp)\n    return total_size\n\n\ndef clean_tensorboard(directory):\n    tensorboard_list = os.listdir(directory)\n    SIZE_THRESH = 100000\n    for tensorboard in tensorboard_list:\n        
tensorboard = os.path.join(directory, tensorboard)\n        if get_size(tensorboard) < SIZE_THRESH:\n            print(\"deleting the empty tensorboard: \", tensorboard)\n            #\n            if os.path.isdir(tensorboard):\n                shutil.rmtree(tensorboard)\n            else:\n                os.remove(tensorboard)\n\n\ndef prepare_tensorboard(config, experiment_name=datetime.now().strftime(\"%Y-%m-%d %H-%M-%S\")):\n    tensorboard_directory = os.path.join(config.checkpoint_dir, config.name, \"tensorboard_logs\")\n    mkdir_if_not(tensorboard_directory)\n    clean_tensorboard(tensorboard_directory)\n    tb_writer = SummaryWriter(os.path.join(tensorboard_directory, experiment_name), flush_secs=10)\n\n    # try:\n    #     shutil.copy('outputs/opt.txt', tensorboard_directory)\n    # except:\n    #     print('cannot find file opt.txt')\n    return tb_writer\n\n\ndef tb_loss_logger(tb_writer, iter_index, loss_logger):\n    for tag, value in loss_logger.items():\n        tb_writer.add_scalar(tag, scalar_value=value.item(), global_step=iter_index)\n\n\ndef tb_image_logger(tb_writer, iter_index, images_info, config):\n    ### Save and write the output into the tensorboard\n    tb_logger_path = os.path.join(config.output_dir, config.name, config.train_mode)\n    mkdir_if_not(tb_logger_path)\n    for tag, image in images_info.items():\n        if tag == \"test_image_prediction\" or tag == \"image_prediction\":\n            continue\n        image = tv.utils.make_grid(image.cpu())\n        image = torch.clamp(image, 0, 1)\n        tb_writer.add_image(tag, img_tensor=image, global_step=iter_index)\n        tv.transforms.functional.to_pil_image(image).save(\n            os.path.join(tb_logger_path, \"{:06d}_{}.jpg\".format(iter_index, tag))\n        )\n\n\ndef tb_image_logger_test(epoch, iter, images_info, config):\n\n    url = os.path.join(config.output_dir, config.name, config.train_mode, \"val_\" + str(epoch))\n    if not os.path.exists(url):\n        
os.makedirs(url)\n    scratch_img = images_info[\"test_scratch_image\"].data.cpu()\n    if config.norm_input:\n        scratch_img = (scratch_img + 1.0) / 2.0\n    scratch_img = torch.clamp(scratch_img, 0, 1)\n    gt_mask = images_info[\"test_mask_image\"].data.cpu()\n    predict_mask = images_info[\"test_scratch_prediction\"].data.cpu()\n\n    predict_hard_mask = (predict_mask.data.cpu() >= 0.5).float()\n\n    imgs = torch.cat((scratch_img, predict_hard_mask, gt_mask), 0)\n    img_grid = vutils.save_image(\n        imgs, os.path.join(url, str(iter) + \".jpg\"), nrow=len(scratch_img), padding=0, normalize=True\n    )\n\n\ndef imshow(input_image, title=None, to_numpy=False):\n    inp = input_image\n    if to_numpy or type(input_image) is torch.Tensor:\n        inp = input_image.numpy()\n\n    fig = plt.figure()\n    if inp.ndim == 2:\n        fig = plt.imshow(inp, cmap=\"gray\", clim=[0, 255])\n    else:\n        fig = plt.imshow(np.transpose(inp, [1, 2, 0]).astype(np.uint8))\n    plt.axis(\"off\")\n    fig.axes.get_xaxis().set_visible(False)\n    fig.axes.get_yaxis().set_visible(False)\n    plt.title(title)\n\n\n###### vgg preprocessing ######\ndef vgg_preprocess(tensor):\n    # input is RGB tensor which ranges in [0,1]\n    # output is BGR tensor which ranges in [0,255]\n    tensor_bgr = torch.cat((tensor[:, 2:3, :, :], tensor[:, 1:2, :, :], tensor[:, 0:1, :, :]), dim=1)\n    # tensor_bgr = tensor[:, [2, 1, 0], ...]\n    tensor_bgr_ml = tensor_bgr - torch.Tensor([0.40760392, 0.45795686, 0.48501961]).type_as(tensor_bgr).view(\n        1, 3, 1, 1\n    )\n    tensor_rst = tensor_bgr_ml * 255\n    return tensor_rst\n\n\ndef torch_vgg_preprocess(tensor):\n    # pytorch version normalization\n    # note that both input and output are RGB tensors;\n    # input and output ranges in [0,1]\n    # normalize the tensor with mean and variance\n    tensor_mc = tensor - torch.Tensor([0.485, 0.456, 0.406]).type_as(tensor).view(1, 3, 1, 1)\n    tensor_mc_norm = tensor_mc / 
torch.Tensor([0.229, 0.224, 0.225]).type_as(tensor_mc).view(1, 3, 1, 1)\n    return tensor_mc_norm\n\n\ndef network_gradient(net, gradient_on=True):\n    if gradient_on:\n        for param in net.parameters():\n            param.requires_grad = True\n    else:\n        for param in net.parameters():\n            param.requires_grad = False\n    return net\n"
  },
  {
    "path": "Global/models/NonLocal_feature_mapping_model.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport os\nimport functools\nfrom torch.autograd import Variable\nfrom util.image_pool import ImagePool\nfrom .base_model import BaseModel\nfrom . import networks\nimport math\n\n\nclass Mapping_Model_with_mask(nn.Module):\n    def __init__(self, nc, mc=64, n_blocks=3, norm=\"instance\", padding_type=\"reflect\", opt=None):\n        super(Mapping_Model_with_mask, self).__init__()\n\n        norm_layer = networks.get_norm_layer(norm_type=norm)\n        activation = nn.ReLU(True)\n        model = []\n\n        tmp_nc = 64\n        n_up = 4\n\n        for i in range(n_up):\n            ic = min(tmp_nc * (2 ** i), mc)\n            oc = min(tmp_nc * (2 ** (i + 1)), mc)\n            model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation]\n\n        self.before_NL = nn.Sequential(*model)\n\n        if opt.NL_res:\n            self.NL = networks.NonLocalBlock2D_with_mask_Res(\n                mc,\n                mc,\n                opt.NL_fusion_method,\n                opt.correlation_renormalize,\n                opt.softmax_temperature,\n                opt.use_self,\n                opt.cosin_similarity,\n            )\n            print(\"You are using NL + Res\")\n\n        model = []\n        for i in range(n_blocks):\n            model += [\n                networks.ResnetBlock(\n                    mc,\n                    padding_type=padding_type,\n                    activation=activation,\n                    norm_layer=norm_layer,\n                    opt=opt,\n                    dilation=opt.mapping_net_dilation,\n                )\n            ]\n\n        for i in range(n_up - 1):\n            ic = min(64 * (2 ** (4 - i)), mc)\n            oc = min(64 * (2 ** (3 - i)), mc)\n            model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation]\n        
model += [nn.Conv2d(tmp_nc * 2, tmp_nc, 3, 1, 1)]\n        if opt.feat_dim > 0 and opt.feat_dim < 64:\n            model += [norm_layer(tmp_nc), activation, nn.Conv2d(tmp_nc, opt.feat_dim, 1, 1)]\n        # model += [nn.Conv2d(64, 1, 1, 1, 0)]\n        self.after_NL = nn.Sequential(*model)\n        \n    \n    def forward(self, input, mask):\n        x1 = self.before_NL(input)\n        del input\n        x2 = self.NL(x1, mask)\n        del x1, mask\n        x3 = self.after_NL(x2)\n        del x2\n\n        return x3\n\nclass Mapping_Model_with_mask_2(nn.Module): ## Multi-Scale Patch Attention\n    def __init__(self, nc, mc=64, n_blocks=3, norm=\"instance\", padding_type=\"reflect\", opt=None):\n        super(Mapping_Model_with_mask_2, self).__init__()\n\n        norm_layer = networks.get_norm_layer(norm_type=norm)\n        activation = nn.ReLU(True)\n        model = []\n\n        tmp_nc = 64\n        n_up = 4\n\n        for i in range(n_up):\n            ic = min(tmp_nc * (2 ** i), mc)\n            oc = min(tmp_nc * (2 ** (i + 1)), mc)\n            model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation]\n\n        for i in range(2):\n            model += [\n                networks.ResnetBlock(\n                    mc,\n                    padding_type=padding_type,\n                    activation=activation,\n                    norm_layer=norm_layer,\n                    opt=opt,\n                    dilation=opt.mapping_net_dilation,\n                )\n            ]\n\n        print(\"Mapping: You are using multi-scale patch attention, conv combine + mask input\")\n\n        self.before_NL = nn.Sequential(*model)\n\n        if opt.mapping_exp==1:\n            self.NL_scale_1=networks.Patch_Attention_4(mc,mc,8)\n\n        model = []\n        for i in range(2):\n            model += [\n                networks.ResnetBlock(\n                    mc,\n                    padding_type=padding_type,\n                    activation=activation,\n               
     norm_layer=norm_layer,\n                    opt=opt,\n                    dilation=opt.mapping_net_dilation,\n                )\n            ]\n\n        self.res_block_1 = nn.Sequential(*model)\n\n        if opt.mapping_exp==1:\n            self.NL_scale_2=networks.Patch_Attention_4(mc,mc,4)\n\n        model = []\n        for i in range(2):\n            model += [\n                networks.ResnetBlock(\n                    mc,\n                    padding_type=padding_type,\n                    activation=activation,\n                    norm_layer=norm_layer,\n                    opt=opt,\n                    dilation=opt.mapping_net_dilation,\n                )\n            ]\n        \n        self.res_block_2 = nn.Sequential(*model)\n        \n        if opt.mapping_exp==1:\n            self.NL_scale_3=networks.Patch_Attention_4(mc,mc,2)\n        # self.NL_scale_3=networks.Patch_Attention_2(mc,mc,2)\n\n        model = []\n        for i in range(2):\n            model += [\n                networks.ResnetBlock(\n                    mc,\n                    padding_type=padding_type,\n                    activation=activation,\n                    norm_layer=norm_layer,\n                    opt=opt,\n                    dilation=opt.mapping_net_dilation,\n                )\n            ]\n\n        for i in range(n_up - 1):\n            ic = min(64 * (2 ** (4 - i)), mc)\n            oc = min(64 * (2 ** (3 - i)), mc)\n            model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation]\n        model += [nn.Conv2d(tmp_nc * 2, tmp_nc, 3, 1, 1)]\n        if opt.feat_dim > 0 and opt.feat_dim < 64:\n            model += [norm_layer(tmp_nc), activation, nn.Conv2d(tmp_nc, opt.feat_dim, 1, 1)]\n        # model += [nn.Conv2d(64, 1, 1, 1, 0)]\n        self.after_NL = nn.Sequential(*model)\n        \n    \n    def forward(self, input, mask):\n        x1 = self.before_NL(input)\n        x2 = self.NL_scale_1(x1,mask)\n        x3 = self.res_block_1(x2)\n        
x4 = self.NL_scale_2(x3,mask)\n        x5 = self.res_block_2(x4)\n        x6 = self.NL_scale_3(x5,mask)\n        x7 = self.after_NL(x6)\n        return x7\n\n    def inference_forward(self, input, mask):\n        x1 = self.before_NL(input)\n        del input\n        x2 = self.NL_scale_1.inference_forward(x1,mask)\n        del x1\n        x3 = self.res_block_1(x2)\n        del x2\n        x4 = self.NL_scale_2.inference_forward(x3,mask)\n        del x3\n        x5 = self.res_block_2(x4)\n        del x4\n        x6 = self.NL_scale_3.inference_forward(x5,mask)\n        del x5\n        x7 = self.after_NL(x6)\n        del x6\n        return x7   "
  },
  {
    "path": "Global/models/__init__.py",
    "content": ""
  },
  {
    "path": "Global/models/base_model.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport os\nimport torch\nimport sys\n\n\nclass BaseModel(torch.nn.Module):\n    def name(self):\n        return \"BaseModel\"\n\n    def initialize(self, opt):\n        self.opt = opt\n        self.gpu_ids = opt.gpu_ids\n        self.isTrain = opt.isTrain\n        self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor\n        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)\n\n    def set_input(self, input):\n        self.input = input\n\n    def forward(self):\n        pass\n\n    # used in test time, no backprop\n    def test(self):\n        pass\n\n    def get_image_paths(self):\n        pass\n\n    def optimize_parameters(self):\n        pass\n\n    def get_current_visuals(self):\n        return self.input\n\n    def get_current_errors(self):\n        return {}\n\n    def save(self, label):\n        pass\n\n    # helper saving function that can be used by subclasses\n    def save_network(self, network, network_label, epoch_label, gpu_ids):\n        save_filename = \"%s_net_%s.pth\" % (epoch_label, network_label)\n        save_path = os.path.join(self.save_dir, save_filename)\n        torch.save(network.cpu().state_dict(), save_path)\n        if len(gpu_ids) and torch.cuda.is_available():\n            network.cuda()\n\n    def save_optimizer(self, optimizer, optimizer_label, epoch_label):\n        save_filename = \"%s_optimizer_%s.pth\" % (epoch_label, optimizer_label)\n        save_path = os.path.join(self.save_dir, save_filename)\n        torch.save(optimizer.state_dict(), save_path)\n\n    def load_optimizer(self, optimizer, optimizer_label, epoch_label, save_dir=\"\"):\n        save_filename = \"%s_optimizer_%s.pth\" % (epoch_label, optimizer_label)\n        if not save_dir:\n            save_dir = self.save_dir\n        save_path = os.path.join(save_dir, save_filename)\n\n        if not os.path.isfile(save_path):\n            print(\"%s 
not exists yet!\" % save_path)\n        else:\n            optimizer.load_state_dict(torch.load(save_path))\n\n    # helper loading function that can be used by subclasses\n    def load_network(self, network, network_label, epoch_label, save_dir=\"\"):\n        save_filename = \"%s_net_%s.pth\" % (epoch_label, network_label)\n        if not save_dir:\n            save_dir = self.save_dir\n\n        # print(save_dir)\n        # print(self.save_dir)\n        save_path = os.path.join(save_dir, save_filename)\n        if not os.path.isfile(save_path):\n            print(\"%s not exists yet!\" % save_path)\n            # if network_label == 'G':\n            #     raise('Generator must exist!')\n        else:\n            # network.load_state_dict(torch.load(save_path))\n            try:\n                # print(save_path)\n                network.load_state_dict(torch.load(save_path))\n            except:\n                pretrained_dict = torch.load(save_path)\n                model_dict = network.state_dict()\n                try:\n                    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}\n                    network.load_state_dict(pretrained_dict)\n                    # if self.opt.verbose:\n                    print(\n                        \"Pretrained network %s has excessive layers; Only loading layers that are used\"\n                        % network_label\n                    )\n                except:\n                    print(\n                        \"Pretrained network %s has fewer layers; The following are not initialized:\"\n                        % network_label\n                    )\n                    for k, v in pretrained_dict.items():\n                        if v.size() == model_dict[k].size():\n                            model_dict[k] = v\n\n                    if sys.version_info >= (3, 0):\n                        not_initialized = set()\n                    else:\n                        from 
sets import Set\n\n                        not_initialized = Set()\n\n                    for k, v in model_dict.items():\n                        if k not in pretrained_dict or v.size() != pretrained_dict[k].size():\n                            not_initialized.add(k.split(\".\")[0])\n\n                    print(sorted(not_initialized))\n                    network.load_state_dict(model_dict)\n\n    def update_learning_rate():\n        pass\n"
  },
  {
    "path": "Global/models/mapping_model.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport os\nimport functools\nfrom torch.autograd import Variable\nfrom util.image_pool import ImagePool\nfrom .base_model import BaseModel\nfrom . import networks\nimport math\nfrom .NonLocal_feature_mapping_model import *\n\n\nclass Mapping_Model(nn.Module):\n    def __init__(self, nc, mc=64, n_blocks=3, norm=\"instance\", padding_type=\"reflect\", opt=None):\n        super(Mapping_Model, self).__init__()\n\n        norm_layer = networks.get_norm_layer(norm_type=norm)\n        activation = nn.ReLU(True)\n        model = []\n        tmp_nc = 64\n        n_up = 4\n\n        print(\"Mapping: You are using the mapping model without global restoration.\")\n\n        for i in range(n_up):\n            ic = min(tmp_nc * (2 ** i), mc)\n            oc = min(tmp_nc * (2 ** (i + 1)), mc)\n            model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation]\n        for i in range(n_blocks):\n            model += [\n                networks.ResnetBlock(\n                    mc,\n                    padding_type=padding_type,\n                    activation=activation,\n                    norm_layer=norm_layer,\n                    opt=opt,\n                    dilation=opt.mapping_net_dilation,\n                )\n            ]\n\n        for i in range(n_up - 1):\n            ic = min(64 * (2 ** (4 - i)), mc)\n            oc = min(64 * (2 ** (3 - i)), mc)\n            model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation]\n        model += [nn.Conv2d(tmp_nc * 2, tmp_nc, 3, 1, 1)]\n        if opt.feat_dim > 0 and opt.feat_dim < 64:\n            model += [norm_layer(tmp_nc), activation, nn.Conv2d(tmp_nc, opt.feat_dim, 1, 1)]\n        # model += [nn.Conv2d(64, 1, 1, 1, 0)]\n        self.model = nn.Sequential(*model)\n\n    def forward(self, input):\n        return 
self.model(input)\n\n\nclass Pix2PixHDModel_Mapping(BaseModel):\n    def name(self):\n        return \"Pix2PixHDModel_Mapping\"\n\n    def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss, use_smooth_l1, stage_1_feat_l2):\n        flags = (True, True, use_gan_feat_loss, use_vgg_loss, True, True, use_smooth_l1, stage_1_feat_l2)\n\n        def loss_filter(g_feat_l2, g_gan, g_gan_feat, g_vgg, d_real, d_fake, smooth_l1, stage_1_feat_l2):\n            return [\n                l\n                for (l, f) in zip(\n                    (g_feat_l2, g_gan, g_gan_feat, g_vgg, d_real, d_fake, smooth_l1, stage_1_feat_l2), flags\n                )\n                if f\n            ]\n\n        return loss_filter\n\n    def initialize(self, opt):\n        BaseModel.initialize(self, opt)\n        if opt.resize_or_crop != \"none\" or not opt.isTrain:\n            torch.backends.cudnn.benchmark = True\n        self.isTrain = opt.isTrain\n        input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc\n\n        ##### define networks\n        # Generator network\n        netG_input_nc = input_nc\n        self.netG_A = networks.GlobalGenerator_DCDCv2(\n            netG_input_nc,\n            opt.output_nc,\n            opt.ngf,\n            opt.k_size,\n            opt.n_downsample_global,\n            networks.get_norm_layer(norm_type=opt.norm),\n            opt=opt,\n        )\n        self.netG_B = networks.GlobalGenerator_DCDCv2(\n            netG_input_nc,\n            opt.output_nc,\n            opt.ngf,\n            opt.k_size,\n            opt.n_downsample_global,\n            networks.get_norm_layer(norm_type=opt.norm),\n            opt=opt,\n        )\n\n        if opt.non_local == \"Setting_42\" or opt.NL_use_mask:\n            if opt.mapping_exp==1:\n                self.mapping_net = Mapping_Model_with_mask_2(\n                    min(opt.ngf * 2 ** opt.n_downsample_global, opt.mc),\n                    opt.map_mc,\n                    
n_blocks=opt.mapping_n_block,\n                    opt=opt,\n                )\n            else:\n                self.mapping_net = Mapping_Model_with_mask(\n                    min(opt.ngf * 2 ** opt.n_downsample_global, opt.mc),\n                    opt.map_mc,\n                    n_blocks=opt.mapping_n_block,\n                    opt=opt,\n                )\n        else:\n            self.mapping_net = Mapping_Model(\n                min(opt.ngf * 2 ** opt.n_downsample_global, opt.mc),\n                opt.map_mc,\n                n_blocks=opt.mapping_n_block,\n                opt=opt,\n            )\n\n        self.mapping_net.apply(networks.weights_init)\n\n        if opt.load_pretrain != \"\":\n            self.load_network(self.mapping_net, \"mapping_net\", opt.which_epoch, opt.load_pretrain)\n\n        if not opt.no_load_VAE:\n\n            self.load_network(self.netG_A, \"G\", opt.use_vae_which_epoch, opt.load_pretrainA)\n            self.load_network(self.netG_B, \"G\", opt.use_vae_which_epoch, opt.load_pretrainB)\n            for param in self.netG_A.parameters():\n                param.requires_grad = False\n            for param in self.netG_B.parameters():\n                param.requires_grad = False\n            self.netG_A.eval()\n            self.netG_B.eval()\n\n        if opt.gpu_ids:\n            self.netG_A.cuda(opt.gpu_ids[0])\n            self.netG_B.cuda(opt.gpu_ids[0])\n            self.mapping_net.cuda(opt.gpu_ids[0])\n        \n        if not self.isTrain:\n            self.load_network(self.mapping_net, \"mapping_net\", opt.which_epoch)\n\n        # Discriminator network\n        if self.isTrain:\n            use_sigmoid = opt.no_lsgan\n            netD_input_nc = opt.ngf * 2 if opt.feat_gan else input_nc + opt.output_nc\n            if not opt.no_instance:\n                netD_input_nc += 1\n\n            self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt, opt.norm, use_sigmoid,\n                              
                opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)\n\n        # set loss functions and optimizers\n        if self.isTrain:\n            if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:\n                raise NotImplementedError(\"Fake Pool Not Implemented for MultiGPU\")\n            self.fake_pool = ImagePool(opt.pool_size)\n            self.old_lr = opt.lr\n\n            # define loss functions\n            self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss, opt.Smooth_L1, opt.use_two_stage_mapping)\n\n            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)\n\n\n            self.criterionFeat = torch.nn.L1Loss()\n            self.criterionFeat_feat = torch.nn.L1Loss() if opt.use_l1_feat else torch.nn.MSELoss()\n\n            if self.opt.image_L1:\n                self.criterionImage=torch.nn.L1Loss()\n            else:\n                self.criterionImage = torch.nn.SmoothL1Loss()\n\n\n            print(self.criterionFeat_feat)\n            if not opt.no_vgg_loss:\n                self.criterionVGG = networks.VGGLoss_torch(self.gpu_ids)\n                \n        \n            # Names so we can breakout loss\n            self.loss_names = self.loss_filter('G_Feat_L2', 'G_GAN', 'G_GAN_Feat', 'G_VGG','D_real', 'D_fake', 'Smooth_L1', 'G_Feat_L2_Stage_1')\n\n            # initialize optimizers\n            # optimizer G\n\n            if opt.no_TTUR:\n                beta1,beta2=opt.beta1,0.999\n                G_lr,D_lr=opt.lr,opt.lr\n            else:\n                beta1,beta2=0,0.9\n                G_lr,D_lr=opt.lr/2,opt.lr*2\n\n\n            if not opt.no_load_VAE:\n                params = list(self.mapping_net.parameters())\n                self.optimizer_mapping = torch.optim.Adam(params, lr=G_lr, betas=(beta1, beta2))\n\n            # optimizer D                        \n            params = list(self.netD.parameters())    \n            self.optimizer_D = 
torch.optim.Adam(params, lr=D_lr, betas=(beta1, beta2))\n\n            print(\"---------- Optimizers initialized -------------\")\n\n    def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):             \n        if self.opt.label_nc == 0:\n            input_label = label_map.data.cuda()\n        else:\n            # create one-hot vector for label map \n            size = label_map.size()\n            oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])\n            input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()\n            input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)\n            if self.opt.data_type == 16:\n                input_label = input_label.half()\n\n        # get edges from instance map\n        if not self.opt.no_instance:\n            inst_map = inst_map.data.cuda()\n            edge_map = self.get_edges(inst_map)\n            input_label = torch.cat((input_label, edge_map), dim=1)         \n        input_label = Variable(input_label, volatile=infer)\n\n        # real images for training\n        if real_image is not None:\n            real_image = Variable(real_image.data.cuda())\n\n        return input_label, inst_map, real_image, feat_map\n\n    def discriminate(self, input_label, test_image, use_pool=False):\n        input_concat = torch.cat((input_label, test_image.detach()), dim=1)\n        if use_pool:            \n            fake_query = self.fake_pool.query(input_concat)\n            return self.netD.forward(fake_query)\n        else:\n            return self.netD.forward(input_concat)\n\n    def forward(self, label, inst, image, feat, pair=True, infer=False, last_label=None, last_image=None):\n        # Encode Inputs\n        input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat)  \n\n        # Fake Generation\n        input_concat = input_label\n        \n        label_feat = 
self.netG_A.forward(input_concat, flow='enc')\n        # print('label:')\n        # print(label_feat.min(), label_feat.max(), label_feat.mean())\n        #label_feat = label_feat / 16.0\n\n        if self.opt.NL_use_mask: \n            label_feat_map=self.mapping_net(label_feat.detach(),inst)\n        else:\n            label_feat_map = self.mapping_net(label_feat.detach())\n        \n        fake_image = self.netG_B.forward(label_feat_map, flow='dec')\n        image_feat = self.netG_B.forward(real_image, flow='enc')\n\n        loss_feat_l2_stage_1=0\n        loss_feat_l2 = self.criterionFeat_feat(label_feat_map, image_feat.data) * self.opt.l2_feat\n            \n\n        if self.opt.feat_gan:\n            # Fake Detection and Loss\n            pred_fake_pool = self.discriminate(label_feat.detach(), label_feat_map, use_pool=True)\n            loss_D_fake = self.criterionGAN(pred_fake_pool, False)        \n\n            # Real Detection and Loss        \n            pred_real = self.discriminate(label_feat.detach(), image_feat)\n            loss_D_real = self.criterionGAN(pred_real, True)\n\n            # GAN loss (Fake Passability Loss)        \n            pred_fake = self.netD.forward(torch.cat((label_feat.detach(), label_feat_map), dim=1))        \n            loss_G_GAN = self.criterionGAN(pred_fake, True)  \n        else:\n            # Fake Detection and Loss\n            pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True)\n            loss_D_fake = self.criterionGAN(pred_fake_pool, False)        \n\n            # Real Detection and Loss  \n            if pair:      \n                pred_real = self.discriminate(input_label, real_image)\n            else:\n                pred_real = self.discriminate(last_label, last_image)\n            loss_D_real = self.criterionGAN(pred_real, True)\n\n            # GAN loss (Fake Passability Loss)        \n            pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1))       
 \n            loss_G_GAN = self.criterionGAN(pred_fake, True)               \n        \n        # GAN feature matching loss\n        loss_G_GAN_Feat = 0\n        if not self.opt.no_ganFeat_loss and pair:\n            feat_weights = 4.0 / (self.opt.n_layers_D + 1)\n            D_weights = 1.0 / self.opt.num_D\n            for i in range(self.opt.num_D):\n                for j in range(len(pred_fake[i])-1):\n                    tmp = self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat\n                    loss_G_GAN_Feat += D_weights * feat_weights * tmp\n        else:\n            loss_G_GAN_Feat = torch.zeros(1).to(label.device)\n                   \n        # VGG feature matching loss\n        loss_G_VGG = 0\n        if not self.opt.no_vgg_loss:\n            loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat if pair else torch.zeros(1).to(label.device)\n\n        smooth_l1_loss=0\n        if self.opt.Smooth_L1:\n            smooth_l1_loss=self.criterionImage(fake_image,real_image)*self.opt.L1_weight\n\n\n        return [ self.loss_filter(loss_feat_l2, loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake,smooth_l1_loss,loss_feat_l2_stage_1), None if not infer else fake_image ]\n\n\n    def inference(self, label, inst):\n\n        use_gpu = len(self.opt.gpu_ids) > 0\n        if use_gpu:\n            input_concat = label.data.cuda()\n            inst_data = inst.cuda()\n        else:\n            input_concat = label.data\n            inst_data = inst\n\n        label_feat = self.netG_A.forward(input_concat, flow=\"enc\")\n\n        if self.opt.NL_use_mask:\n            if self.opt.inference_optimize:\n                label_feat_map=self.mapping_net.inference_forward(label_feat.detach(),inst_data)\n            else:   \n                label_feat_map = self.mapping_net(label_feat.detach(), inst_data)\n        else:\n            label_feat_map = self.mapping_net(label_feat.detach())\n\n        
fake_image = self.netG_B.forward(label_feat_map, flow=\"dec\")\n        return fake_image\n\n\nclass InferenceModel(Pix2PixHDModel_Mapping):\n    def forward(self, label, inst):\n        return self.inference(label, inst)\n\n"
  },
  {
    "path": "Global/models/models.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch\n\n\ndef create_model(opt):\n    if opt.model == \"pix2pixHD\":\n        from .pix2pixHD_model import Pix2PixHDModel, InferenceModel\n\n        if opt.isTrain:\n            model = Pix2PixHDModel()\n        else:\n            model = InferenceModel()\n    else:\n        from .ui_model import UIModel\n\n        model = UIModel()\n    model.initialize(opt)\n    if opt.verbose:\n        print(\"model [%s] was created\" % (model.name()))\n\n    if opt.isTrain and len(opt.gpu_ids) > 1:\n        # pass\n        model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)\n\n    return model\n\ndef create_da_model(opt):\n    if opt.model == 'pix2pixHD':\n        from .pix2pixHD_model_DA import Pix2PixHDModel, InferenceModel\n        if opt.isTrain:\n            model = Pix2PixHDModel()\n        else:\n            model = InferenceModel()\n    else:\n    \tfrom .ui_model import UIModel\n    \tmodel = UIModel()\n    model.initialize(opt)\n    if opt.verbose:\n        print(\"model [%s] was created\" % (model.name()))\n\n    if opt.isTrain and len(opt.gpu_ids) > 1:\n        #pass\n        model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)\n\n    return model"
  },
  {
    "path": "Global/models/networks.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport torch\nimport torch.nn as nn\nimport functools\nfrom torch.autograd import Variable\nimport numpy as np\nfrom torch.nn.utils import spectral_norm\n\n# from util.util import SwitchNorm2d\nimport torch.nn.functional as F\n\n###############################################################################\n# Functions\n###############################################################################\ndef weights_init(m):\n    classname = m.__class__.__name__\n    if classname.find(\"Conv\") != -1:\n        m.weight.data.normal_(0.0, 0.02)\n    elif classname.find(\"BatchNorm2d\") != -1:\n        m.weight.data.normal_(1.0, 0.02)\n        m.bias.data.fill_(0)\n\n\ndef get_norm_layer(norm_type=\"instance\"):\n    if norm_type == \"batch\":\n        norm_layer = functools.partial(nn.BatchNorm2d, affine=True)\n    elif norm_type == \"instance\":\n        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)\n    elif norm_type == \"spectral\":\n        norm_layer = spectral_norm()\n    elif norm_type == \"SwitchNorm\":\n        norm_layer = SwitchNorm2d\n    else:\n        raise NotImplementedError(\"normalization layer [%s] is not found\" % norm_type)\n    return norm_layer\n\n\ndef print_network(net):\n    if isinstance(net, list):\n        net = net[0]\n    num_params = 0\n    for param in net.parameters():\n        num_params += param.numel()\n    print(net)\n    print(\"Total number of parameters: %d\" % num_params)\n\n\ndef define_G(input_nc, output_nc, ngf, netG, k_size=3, n_downsample_global=3, n_blocks_global=9, n_local_enhancers=1,\n             n_blocks_local=3, norm='instance', gpu_ids=[], opt=None):\n    \n    norm_layer = get_norm_layer(norm_type=norm)\n    if netG == 'global':\n        # if opt.self_gen:\n        if opt.use_v2:\n            netG = GlobalGenerator_DCDCv2(input_nc, output_nc, ngf, k_size, n_downsample_global, norm_layer, opt=opt)\n        
else:\n            netG = GlobalGenerator_v2(input_nc, output_nc, ngf, k_size, n_downsample_global, n_blocks_global, norm_layer, opt=opt)\n    else:\n        raise('generator not implemented!')\n    print(netG)\n    if len(gpu_ids) > 0:\n        assert(torch.cuda.is_available())\n        netG.cuda(gpu_ids[0])\n    netG.apply(weights_init)\n    return netG\n\n\ndef define_D(input_nc, ndf, n_layers_D, opt, norm='instance', use_sigmoid=False, num_D=1, getIntermFeat=False, gpu_ids=[]):\n    norm_layer = get_norm_layer(norm_type=norm)\n    netD = MultiscaleDiscriminator(input_nc, opt, ndf, n_layers_D, norm_layer, use_sigmoid, num_D, getIntermFeat)\n    print(netD)\n    if len(gpu_ids) > 0:\n        assert(torch.cuda.is_available())\n        netD.cuda(gpu_ids[0])\n    netD.apply(weights_init)\n    return netD\n\n\n\nclass GlobalGenerator_DCDCv2(nn.Module):\n    def __init__(\n        self,\n        input_nc,\n        output_nc,\n        ngf=64,\n        k_size=3,\n        n_downsampling=8,\n        norm_layer=nn.BatchNorm2d,\n        padding_type=\"reflect\",\n        opt=None,\n    ):\n        super(GlobalGenerator_DCDCv2, self).__init__()\n        activation = nn.ReLU(True)\n\n        model = [\n            nn.ReflectionPad2d(3),\n            nn.Conv2d(input_nc, min(ngf, opt.mc), kernel_size=7, padding=0),\n            norm_layer(ngf),\n            activation,\n        ]\n        ### downsample\n        for i in range(opt.start_r):\n            mult = 2 ** i\n            model += [\n                nn.Conv2d(\n                    min(ngf * mult, opt.mc),\n                    min(ngf * mult * 2, opt.mc),\n                    kernel_size=k_size,\n                    stride=2,\n                    padding=1,\n                ),\n                norm_layer(min(ngf * mult * 2, opt.mc)),\n                activation,\n            ]\n        for i in range(opt.start_r, n_downsampling - 1):\n            mult = 2 ** i\n            model += [\n                nn.Conv2d(\n         
           min(ngf * mult, opt.mc),\n                    min(ngf * mult * 2, opt.mc),\n                    kernel_size=k_size,\n                    stride=2,\n                    padding=1,\n                ),\n                norm_layer(min(ngf * mult * 2, opt.mc)),\n                activation,\n            ]\n            model += [\n                ResnetBlock(\n                    min(ngf * mult * 2, opt.mc),\n                    padding_type=padding_type,\n                    activation=activation,\n                    norm_layer=norm_layer,\n                    opt=opt,\n                )\n            ]\n            model += [\n                ResnetBlock(\n                    min(ngf * mult * 2, opt.mc),\n                    padding_type=padding_type,\n                    activation=activation,\n                    norm_layer=norm_layer,\n                    opt=opt,\n                )\n            ]\n        mult = 2 ** (n_downsampling - 1)\n\n        if opt.spatio_size == 32:\n            model += [\n                nn.Conv2d(\n                    min(ngf * mult, opt.mc),\n                    min(ngf * mult * 2, opt.mc),\n                    kernel_size=k_size,\n                    stride=2,\n                    padding=1,\n                ),\n                norm_layer(min(ngf * mult * 2, opt.mc)),\n                activation,\n            ]\n        if opt.spatio_size == 64:\n            model += [\n                ResnetBlock(\n                    min(ngf * mult * 2, opt.mc),\n                    padding_type=padding_type,\n                    activation=activation,\n                    norm_layer=norm_layer,\n                    opt=opt,\n                )\n            ]\n        model += [\n            ResnetBlock(\n                min(ngf * mult * 2, opt.mc),\n                padding_type=padding_type,\n                activation=activation,\n                norm_layer=norm_layer,\n                opt=opt,\n            )\n        ]\n        # model += 
[nn.Conv2d(min(ngf * mult * 2, opt.mc), min(ngf, opt.mc), 1, 1)]\n        if opt.feat_dim > 0:\n            model += [nn.Conv2d(min(ngf * mult * 2, opt.mc), opt.feat_dim, 1, 1)]\n        self.encoder = nn.Sequential(*model)\n\n        # decode\n        model = []\n        if opt.feat_dim > 0:\n            model += [nn.Conv2d(opt.feat_dim, min(ngf * mult * 2, opt.mc), 1, 1)]\n        # model += [nn.Conv2d(min(ngf, opt.mc), min(ngf * mult * 2, opt.mc), 1, 1)]\n        o_pad = 0 if k_size == 4 else 1\n        mult = 2 ** n_downsampling\n        model += [\n            ResnetBlock(\n                min(ngf * mult, opt.mc),\n                padding_type=padding_type,\n                activation=activation,\n                norm_layer=norm_layer,\n                opt=opt,\n            )\n        ]\n\n        if opt.spatio_size == 32:\n            model += [\n                nn.ConvTranspose2d(\n                    min(ngf * mult, opt.mc),\n                    min(int(ngf * mult / 2), opt.mc),\n                    kernel_size=k_size,\n                    stride=2,\n                    padding=1,\n                    output_padding=o_pad,\n                ),\n                norm_layer(min(int(ngf * mult / 2), opt.mc)),\n                activation,\n            ]\n        if opt.spatio_size == 64:\n            model += [\n                ResnetBlock(\n                    min(ngf * mult, opt.mc),\n                    padding_type=padding_type,\n                    activation=activation,\n                    norm_layer=norm_layer,\n                    opt=opt,\n                )\n            ]\n\n        for i in range(1, n_downsampling - opt.start_r):\n            mult = 2 ** (n_downsampling - i)\n            model += [\n                ResnetBlock(\n                    min(ngf * mult, opt.mc),\n                    padding_type=padding_type,\n                    activation=activation,\n                    norm_layer=norm_layer,\n                    opt=opt,\n                
)\n            ]\n            model += [\n                ResnetBlock(\n                    min(ngf * mult, opt.mc),\n                    padding_type=padding_type,\n                    activation=activation,\n                    norm_layer=norm_layer,\n                    opt=opt,\n                )\n            ]\n            model += [\n                nn.ConvTranspose2d(\n                    min(ngf * mult, opt.mc),\n                    min(int(ngf * mult / 2), opt.mc),\n                    kernel_size=k_size,\n                    stride=2,\n                    padding=1,\n                    output_padding=o_pad,\n                ),\n                norm_layer(min(int(ngf * mult / 2), opt.mc)),\n                activation,\n            ]\n        for i in range(n_downsampling - opt.start_r, n_downsampling):\n            mult = 2 ** (n_downsampling - i)\n            model += [\n                nn.ConvTranspose2d(\n                    min(ngf * mult, opt.mc),\n                    min(int(ngf * mult / 2), opt.mc),\n                    kernel_size=k_size,\n                    stride=2,\n                    padding=1,\n                    output_padding=o_pad,\n                ),\n                norm_layer(min(int(ngf * mult / 2), opt.mc)),\n                activation,\n            ]\n        if opt.use_segmentation_model:\n            model += [nn.ReflectionPad2d(3), nn.Conv2d(min(ngf, opt.mc), output_nc, kernel_size=7, padding=0)]\n        else:\n            model += [\n                nn.ReflectionPad2d(3),\n                nn.Conv2d(min(ngf, opt.mc), output_nc, kernel_size=7, padding=0),\n                nn.Tanh(),\n            ]\n        self.decoder = nn.Sequential(*model)\n\n    def forward(self, input, flow=\"enc_dec\"):\n        if flow == \"enc\":\n            return self.encoder(input)\n        elif flow == \"dec\":\n            return self.decoder(input)\n        elif flow == \"enc_dec\":\n            x = self.encoder(input)\n            x = 
self.decoder(x)\n            return x\n\n\n# Define a resnet block\nclass ResnetBlock(nn.Module):\n    def __init__(\n        self, dim, padding_type, norm_layer, opt, activation=nn.ReLU(True), use_dropout=False, dilation=1\n    ):\n        super(ResnetBlock, self).__init__()\n        self.opt = opt\n        self.dilation = dilation\n        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout)\n\n    def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):\n        conv_block = []\n        p = 0\n        if padding_type == \"reflect\":\n            conv_block += [nn.ReflectionPad2d(self.dilation)]\n        elif padding_type == \"replicate\":\n            conv_block += [nn.ReplicationPad2d(self.dilation)]\n        elif padding_type == \"zero\":\n            p = self.dilation\n        else:\n            raise NotImplementedError(\"padding [%s] is not implemented\" % padding_type)\n\n        conv_block += [\n            nn.Conv2d(dim, dim, kernel_size=3, padding=p, dilation=self.dilation),\n            norm_layer(dim),\n            activation,\n        ]\n        if use_dropout:\n            conv_block += [nn.Dropout(0.5)]\n\n        p = 0\n        if padding_type == \"reflect\":\n            conv_block += [nn.ReflectionPad2d(1)]\n        elif padding_type == \"replicate\":\n            conv_block += [nn.ReplicationPad2d(1)]\n        elif padding_type == \"zero\":\n            p = 1\n        else:\n            raise NotImplementedError(\"padding [%s] is not implemented\" % padding_type)\n        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, dilation=1), norm_layer(dim)]\n\n        return nn.Sequential(*conv_block)\n\n    def forward(self, x):\n        out = x + self.conv_block(x)\n        return out\n\n\nclass Encoder(nn.Module):\n    def __init__(self, input_nc, output_nc, ngf=32, n_downsampling=4, norm_layer=nn.BatchNorm2d):\n        super(Encoder, self).__init__()\n        
self.output_nc = output_nc\n\n        model = [\n            nn.ReflectionPad2d(3),\n            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),\n            norm_layer(ngf),\n            nn.ReLU(True),\n        ]\n        ### downsample\n        for i in range(n_downsampling):\n            mult = 2 ** i\n            model += [\n                nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),\n                norm_layer(ngf * mult * 2),\n                nn.ReLU(True),\n            ]\n\n        ### upsample\n        for i in range(n_downsampling):\n            mult = 2 ** (n_downsampling - i)\n            model += [\n                nn.ConvTranspose2d(\n                    ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1\n                ),\n                norm_layer(int(ngf * mult / 2)),\n                nn.ReLU(True),\n            ]\n\n        model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]\n        self.model = nn.Sequential(*model)\n\n    def forward(self, input, inst):\n        outputs = self.model(input)\n\n        # instance-wise average pooling\n        outputs_mean = outputs.clone()\n        inst_list = np.unique(inst.cpu().numpy().astype(int))\n        for i in inst_list:\n            for b in range(input.size()[0]):\n                indices = (inst[b : b + 1] == int(i)).nonzero()  # n x 4\n                for j in range(self.output_nc):\n                    output_ins = outputs[indices[:, 0] + b, indices[:, 1] + j, indices[:, 2], indices[:, 3]]\n                    mean_feat = torch.mean(output_ins).expand_as(output_ins)\n                    outputs_mean[\n                        indices[:, 0] + b, indices[:, 1] + j, indices[:, 2], indices[:, 3]\n                    ] = mean_feat\n        return outputs_mean\n\n\ndef SN(module, mode=True):\n    if mode:\n        return torch.nn.utils.spectral_norm(module)\n\n    return module\n\n\nclass 
NonLocalBlock2D_with_mask_Res(nn.Module):\n    def __init__(\n        self,\n        in_channels,\n        inter_channels,\n        mode=\"add\",\n        re_norm=False,\n        temperature=1.0,\n        use_self=False,\n        cosin=False,\n    ):\n        super(NonLocalBlock2D_with_mask_Res, self).__init__()\n\n        self.cosin = cosin\n        self.renorm = re_norm\n        self.in_channels = in_channels\n        self.inter_channels = inter_channels\n\n        self.g = nn.Conv2d(\n            in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0\n        )\n\n        self.W = nn.Conv2d(\n            in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0\n        )\n        # for pytorch 0.3.1\n        # nn.init.constant(self.W.weight, 0)\n        # nn.init.constant(self.W.bias, 0)\n        # for pytorch 0.4.0\n        nn.init.constant_(self.W.weight, 0)\n        nn.init.constant_(self.W.bias, 0)\n        self.theta = nn.Conv2d(\n            in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0\n        )\n\n        self.phi = nn.Conv2d(\n            in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0\n        )\n\n        self.mode = mode\n        self.temperature = temperature\n        self.use_self = use_self\n\n        norm_layer = get_norm_layer(norm_type=\"instance\")\n        activation = nn.ReLU(True)\n\n        model = []\n        for i in range(3):\n            model += [\n                ResnetBlock(\n                    inter_channels,\n                    padding_type=\"reflect\",\n                    activation=activation,\n                    norm_layer=norm_layer,\n                    opt=None,\n                )\n            ]\n        self.res_block = nn.Sequential(*model)\n\n    def forward(self, x, mask):  ## The shape of mask is Batch*1*H*W\n        batch_size 
= x.size(0)\n\n        g_x = self.g(x).view(batch_size, self.inter_channels, -1)\n\n        g_x = g_x.permute(0, 2, 1)\n\n        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)\n\n        theta_x = theta_x.permute(0, 2, 1)\n\n        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)\n\n        if self.cosin:\n            theta_x = F.normalize(theta_x, dim=2)\n            phi_x = F.normalize(phi_x, dim=1)\n\n        f = torch.matmul(theta_x, phi_x)\n\n        f /= self.temperature\n\n        f_div_C = F.softmax(f, dim=2)\n\n        tmp = 1 - mask\n        mask = F.interpolate(mask, (x.size(2), x.size(3)), mode=\"bilinear\")\n        mask[mask > 0] = 1.0\n        mask = 1 - mask\n\n        tmp = F.interpolate(tmp, (x.size(2), x.size(3)))\n        mask *= tmp\n\n        mask_expand = mask.view(batch_size, 1, -1)\n        mask_expand = mask_expand.repeat(1, x.size(2) * x.size(3), 1)\n\n        # mask = 1 - mask\n        # mask=F.interpolate(mask,(x.size(2),x.size(3)))\n        # mask_expand=mask.view(batch_size,1,-1)\n        # mask_expand=mask_expand.repeat(1,x.size(2)*x.size(3),1)\n\n        if self.use_self:\n            mask_expand[:, range(x.size(2) * x.size(3)), range(x.size(2) * x.size(3))] = 1.0\n\n        #    print(mask_expand.shape)\n        #    print(f_div_C.shape)\n\n        f_div_C = mask_expand * f_div_C\n        if self.renorm:\n            f_div_C = F.normalize(f_div_C, p=1, dim=2)\n\n        ###########################\n\n        y = torch.matmul(f_div_C, g_x)\n\n        y = y.permute(0, 2, 1).contiguous()\n\n        y = y.view(batch_size, self.inter_channels, *x.size()[2:])\n        W_y = self.W(y)\n\n        W_y = self.res_block(W_y)\n\n        if self.mode == \"combine\":\n            full_mask = mask.repeat(1, self.inter_channels, 1, 1)\n            z = full_mask * x + (1 - full_mask) * W_y\n        return z\n\n\nclass MultiscaleDiscriminator(nn.Module):\n    def __init__(self, input_nc, opt, ndf=64, n_layers=3, 
norm_layer=nn.BatchNorm2d,\n                 use_sigmoid=False, num_D=3, getIntermFeat=False):\n        super(MultiscaleDiscriminator, self).__init__()\n        self.num_D = num_D\n        self.n_layers = n_layers\n        self.getIntermFeat = getIntermFeat\n\n        for i in range(num_D):\n            netD = NLayerDiscriminator(input_nc, opt, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat)\n            if getIntermFeat:\n                for j in range(n_layers+2):\n                    setattr(self, 'scale'+str(i)+'_layer'+str(j), getattr(netD, 'model'+str(j)))\n            else:\n                setattr(self, 'layer'+str(i), netD.model)\n\n        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)\n\n    def singleD_forward(self, model, input):\n        if self.getIntermFeat:\n            result = [input]\n            for i in range(len(model)):\n                result.append(model[i](result[-1]))\n            return result[1:]\n        else:\n            return [model(input)]\n\n    def forward(self, input):\n        num_D = self.num_D\n        result = []\n        input_downsampled = input\n        for i in range(num_D):\n            if self.getIntermFeat:\n                model = [getattr(self, 'scale'+str(num_D-1-i)+'_layer'+str(j)) for j in range(self.n_layers+2)]\n            else:\n                model = getattr(self, 'layer'+str(num_D-1-i))\n            result.append(self.singleD_forward(model, input_downsampled))\n            if i != (num_D-1):\n                input_downsampled = self.downsample(input_downsampled)\n        return result\n\n# Defines the PatchGAN discriminator with the specified arguments.\nclass NLayerDiscriminator(nn.Module):\n    def __init__(self, input_nc, opt, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False):\n        super(NLayerDiscriminator, self).__init__()\n        self.getIntermFeat = getIntermFeat\n        self.n_layers = n_layers\n\n        kw = 
4\n        padw = int(np.ceil((kw-1.0)/2))\n        sequence = [[SN(nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),opt.use_SN), nn.LeakyReLU(0.2, True)]]\n\n        nf = ndf\n        for n in range(1, n_layers):\n            nf_prev = nf\n            nf = min(nf * 2, 512)\n            sequence += [[\n                SN(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),opt.use_SN),\n                norm_layer(nf), nn.LeakyReLU(0.2, True)\n            ]]\n\n        nf_prev = nf\n        nf = min(nf * 2, 512)\n        sequence += [[\n            SN(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),opt.use_SN),\n            norm_layer(nf),\n            nn.LeakyReLU(0.2, True)\n        ]]\n\n        sequence += [[SN(nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw),opt.use_SN)]]\n\n        if use_sigmoid:\n            sequence += [[nn.Sigmoid()]]\n\n        if getIntermFeat:\n            for n in range(len(sequence)):\n                setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))\n        else:\n            sequence_stream = []\n            for n in range(len(sequence)):\n                sequence_stream += sequence[n]\n            self.model = nn.Sequential(*sequence_stream)\n\n    def forward(self, input):\n        if self.getIntermFeat:\n            res = [input]\n            for n in range(self.n_layers+2):\n                model = getattr(self, 'model'+str(n))\n                res.append(model(res[-1]))\n            return res[1:]\n        else:\n            return self.model(input)\n\n\n\nclass Patch_Attention_4(nn.Module):  ## While combine the feature map, use conv and mask\n    def __init__(self, in_channels, inter_channels, patch_size):\n        super(Patch_Attention_4, self).__init__()\n\n        self.patch_size=patch_size\n\n\n        # self.g = nn.Conv2d(\n        #     in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0\n        # )\n\n        # 
self.W = nn.Conv2d(\n        #     in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0\n        # )\n        # # for pytorch 0.3.1\n        # # nn.init.constant(self.W.weight, 0)\n        # # nn.init.constant(self.W.bias, 0)\n        # # for pytorch 0.4.0\n        # nn.init.constant_(self.W.weight, 0)\n        # nn.init.constant_(self.W.bias, 0)\n        # self.theta = nn.Conv2d(\n        #     in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0\n        # )\n\n        # self.phi = nn.Conv2d(\n        #     in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0\n        # )\n\n        self.F_Combine=nn.Conv2d(in_channels=1025,out_channels=512,kernel_size=3,stride=1,padding=1,bias=True)\n        norm_layer = get_norm_layer(norm_type=\"instance\")\n        activation = nn.ReLU(True)\n\n        model = []\n        for i in range(1):\n            model += [\n                ResnetBlock(\n                    inter_channels,\n                    padding_type=\"reflect\",\n                    activation=activation,\n                    norm_layer=norm_layer,\n                    opt=None,\n                )\n            ]\n        self.res_block = nn.Sequential(*model)\n\n    def Hard_Compose(self, input, dim, index):\n        # batch index select\n        # input: [B,C,HW]\n        # dim: scalar > 0\n        # index: [B, HW]\n        views = [input.size(0)] + [1 if i!=dim else -1 for i in range(1, len(input.size()))]\n        expanse = list(input.size())\n        expanse[0] = -1\n        expanse[dim] = -1\n        index = index.view(views).expand(expanse)\n        return torch.gather(input, dim, index)\n\n    def forward(self, z, mask):  ## The shape of mask is Batch*1*H*W\n\n        x=self.res_block(z)\n\n        b,c,h,w=x.shape\n\n        ## mask resize + dilation\n        # tmp = 1 - mask\n        mask = F.interpolate(mask, 
(x.size(2), x.size(3)), mode=\"bilinear\")\n        mask[mask > 0] = 1.0\n\n        # mask = 1 - mask\n        # tmp = F.interpolate(tmp, (x.size(2), x.size(3)))\n        # mask *= tmp\n        # mask=1-mask\n        ## 1: mask position 0: non-mask\n\n        mask_unfold=F.unfold(mask, kernel_size=(self.patch_size,self.patch_size), padding=0, stride=self.patch_size)\n        non_mask_region=(torch.mean(mask_unfold,dim=1,keepdim=True)>0.6).float()\n        all_patch_num=h*w/self.patch_size/self.patch_size\n        non_mask_region=non_mask_region.repeat(1,int(all_patch_num),1)\n\n        x_unfold=F.unfold(x, kernel_size=(self.patch_size,self.patch_size), padding=0, stride=self.patch_size)\n        y_unfold=x_unfold.permute(0,2,1)\n        x_unfold_normalized=F.normalize(x_unfold,dim=1)\n        y_unfold_normalized=F.normalize(y_unfold,dim=2)\n        correlation_matrix=torch.bmm(y_unfold_normalized,x_unfold_normalized)\n        correlation_matrix=correlation_matrix.masked_fill(non_mask_region==1.,-1e9)\n        correlation_matrix=F.softmax(correlation_matrix,dim=2)\n\n        # print(correlation_matrix)\n\n        R, max_arg=torch.max(correlation_matrix,dim=2)\n\n        composed_unfold=self.Hard_Compose(x_unfold, 2, max_arg)\n        composed_fold=F.fold(composed_unfold,output_size=(h,w),kernel_size=(self.patch_size,self.patch_size),padding=0,stride=self.patch_size)\n\n        concat_1=torch.cat((z,composed_fold,mask),dim=1)\n        concat_1=self.F_Combine(concat_1)\n\n        return concat_1\n\n    def inference_forward(self,z,mask): ## Reduce the extra memory cost\n\n\n        x=self.res_block(z)\n\n        b,c,h,w=x.shape\n\n        ## mask resize + dilation\n        # tmp = 1 - mask\n        mask = F.interpolate(mask, (x.size(2), x.size(3)), mode=\"bilinear\")\n        mask[mask > 0] = 1.0\n        # mask = 1 - mask\n        # tmp = F.interpolate(tmp, (x.size(2), x.size(3)))\n        # mask *= tmp\n        # mask=1-mask\n        ## 1: mask position 0: 
non-mask\n\n        mask_unfold=F.unfold(mask, kernel_size=(self.patch_size,self.patch_size), padding=0, stride=self.patch_size)\n        non_mask_region=(torch.mean(mask_unfold,dim=1,keepdim=True)>0.6).float()[0,0,:] # 1*1*all_patch_num\n\n        all_patch_num=h*w/self.patch_size/self.patch_size\n\n        mask_index=torch.nonzero(non_mask_region,as_tuple=True)[0]\n\n\n        if len(mask_index)==0: ## No mask patch is selected, no attention is needed\n\n            composed_fold=x\n\n        else:\n\n            unmask_index=torch.nonzero(non_mask_region!=1,as_tuple=True)[0]\n\n            x_unfold=F.unfold(x, kernel_size=(self.patch_size,self.patch_size), padding=0, stride=self.patch_size)\n            \n            Query_Patch=torch.index_select(x_unfold,2,mask_index)\n            Key_Patch=torch.index_select(x_unfold,2,unmask_index)\n\n            Query_Patch=Query_Patch.permute(0,2,1)        \n            Query_Patch_normalized=F.normalize(Query_Patch,dim=2)\n            Key_Patch_normalized=F.normalize(Key_Patch,dim=1)\n\n            correlation_matrix=torch.bmm(Query_Patch_normalized,Key_Patch_normalized)\n            correlation_matrix=F.softmax(correlation_matrix,dim=2)\n\n\n            R, max_arg=torch.max(correlation_matrix,dim=2)\n\n            composed_unfold=self.Hard_Compose(Key_Patch, 2, max_arg)\n            x_unfold[:,:,mask_index]=composed_unfold\n            composed_fold=F.fold(x_unfold,output_size=(h,w),kernel_size=(self.patch_size,self.patch_size),padding=0,stride=self.patch_size)\n\n        concat_1=torch.cat((z,composed_fold,mask),dim=1)\n        concat_1=self.F_Combine(concat_1)\n\n\n        return concat_1\n\n##############################################################################\n# Losses\n##############################################################################\nclass GANLoss(nn.Module):\n    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,\n                 tensor=torch.FloatTensor):\n      
  super(GANLoss, self).__init__()\n        self.real_label = target_real_label\n        self.fake_label = target_fake_label\n        self.real_label_var = None\n        self.fake_label_var = None\n        self.Tensor = tensor\n        if use_lsgan:\n            self.loss = nn.MSELoss()\n        else:\n            self.loss = nn.BCELoss()\n\n    def get_target_tensor(self, input, target_is_real):\n        target_tensor = None\n        if target_is_real:\n            create_label = ((self.real_label_var is None) or\n                            (self.real_label_var.numel() != input.numel()))\n            if create_label:\n                real_tensor = self.Tensor(input.size()).fill_(self.real_label)\n                self.real_label_var = Variable(real_tensor, requires_grad=False)\n            target_tensor = self.real_label_var\n        else:\n            create_label = ((self.fake_label_var is None) or\n                            (self.fake_label_var.numel() != input.numel()))\n            if create_label:\n                fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)\n                self.fake_label_var = Variable(fake_tensor, requires_grad=False)\n            target_tensor = self.fake_label_var\n        return target_tensor\n\n    def __call__(self, input, target_is_real):\n        if isinstance(input[0], list):\n            loss = 0\n            for input_i in input:\n                pred = input_i[-1]\n                target_tensor = self.get_target_tensor(pred, target_is_real)\n                loss += self.loss(pred, target_tensor)\n            return loss\n        else:\n            target_tensor = self.get_target_tensor(input[-1], target_is_real)\n            return self.loss(input[-1], target_tensor)\n\n\n\n\n####################################### VGG Loss\n\nfrom torchvision import models\nclass VGG19_torch(torch.nn.Module):\n    def __init__(self, requires_grad=False):\n        super(VGG19_torch, self).__init__()\n        
vgg_pretrained_features = models.vgg19(pretrained=True).features\n        self.slice1 = torch.nn.Sequential()\n        self.slice2 = torch.nn.Sequential()\n        self.slice3 = torch.nn.Sequential()\n        self.slice4 = torch.nn.Sequential()\n        self.slice5 = torch.nn.Sequential()\n        for x in range(2):\n            self.slice1.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(2, 7):\n            self.slice2.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(7, 12):\n            self.slice3.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(12, 21):\n            self.slice4.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(21, 30):\n            self.slice5.add_module(str(x), vgg_pretrained_features[x])\n        if not requires_grad:\n            for param in self.parameters():\n                param.requires_grad = False\n\n    def forward(self, X):\n        h_relu1 = self.slice1(X)\n        h_relu2 = self.slice2(h_relu1)\n        h_relu3 = self.slice3(h_relu2)\n        h_relu4 = self.slice4(h_relu3)\n        h_relu5 = self.slice5(h_relu4)\n        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]\n        return out\n\nclass VGGLoss_torch(nn.Module):\n    def __init__(self, gpu_ids):\n        super(VGGLoss_torch, self).__init__()\n        self.vgg = VGG19_torch().cuda()\n        self.criterion = nn.L1Loss()\n        self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]\n\n    def forward(self, x, y):\n        x_vgg, y_vgg = self.vgg(x), self.vgg(y)\n        loss = 0\n        for i in range(len(x_vgg)):\n            loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())\n        return loss"
  },
  {
    "path": "Global/models/pix2pixHD_model.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport numpy as np\nimport torch\nimport os\nfrom torch.autograd import Variable\nfrom util.image_pool import ImagePool\nfrom .base_model import BaseModel\nfrom . import networks\n\nclass Pix2PixHDModel(BaseModel):\n    def name(self):\n        return 'Pix2PixHDModel'\n    \n    def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss,use_smooth_L1):\n        flags = (True, use_gan_feat_loss, use_vgg_loss, True, True, True,use_smooth_L1)\n        def loss_filter(g_gan, g_gan_feat, g_vgg, g_kl, d_real, d_fake,smooth_l1):\n            return [l for (l,f) in zip((g_gan,g_gan_feat,g_vgg, g_kl, d_real,d_fake,smooth_l1),flags) if f]\n        return loss_filter\n    \n    def initialize(self, opt):\n        BaseModel.initialize(self, opt)\n        if opt.resize_or_crop != 'none' or not opt.isTrain: # when training at full res this causes OOM\n            torch.backends.cudnn.benchmark = True\n        self.isTrain = opt.isTrain\n        self.use_features = opt.instance_feat or opt.label_feat   ## Clearly it is false\n        self.gen_features = self.use_features and not self.opt.load_features ## it is also false\n        input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc ## Just is the origin input channel #\n\n        ##### define networks        \n        # Generator network\n        netG_input_nc = input_nc        \n        if not opt.no_instance:\n            netG_input_nc += 1\n        if self.use_features:\n            netG_input_nc += opt.feat_num                  \n        self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG, opt.k_size, \n                                      opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers, \n                                      opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids, opt=opt)       \n\n        # Discriminator network\n        if self.isTrain:\n            use_sigmoid 
= opt.no_lsgan\n            netD_input_nc = opt.output_nc if opt.no_cgan else input_nc + opt.output_nc\n            if not opt.no_instance:\n                netD_input_nc += 1\n            self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt, opt.norm, use_sigmoid,\n                                          opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)\n\n        if self.opt.verbose:\n                print('---------- Networks initialized -------------')\n\n        # load networks\n        if not self.isTrain or opt.continue_train or opt.load_pretrain:\n            pretrained_path = '' if not self.isTrain else opt.load_pretrain\n            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)\n\n            print(\"---------- G Networks reloaded -------------\")\n            if self.isTrain:\n                self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)\n                print(\"---------- D Networks reloaded -------------\")\n\n\n            if self.gen_features:\n                self.load_network(self.netE, 'E', opt.which_epoch, pretrained_path)              \n\n        # set loss functions and optimizers\n        if self.isTrain:\n            if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:   ## The pool_size is 0!\n                raise NotImplementedError(\"Fake Pool Not Implemented for MultiGPU\")\n            self.fake_pool = ImagePool(opt.pool_size)\n            self.old_lr = opt.lr\n\n            # define loss functions\n            self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss, opt.Smooth_L1)\n            \n            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)   \n            self.criterionFeat = torch.nn.L1Loss()\n\n            # self.criterionImage = torch.nn.SmoothL1Loss()\n            if not opt.no_vgg_loss:\n                self.criterionVGG = networks.VGGLoss_torch(self.gpu_ids)\n                
\n\n            self.loss_names = self.loss_filter('G_GAN','G_GAN_Feat','G_VGG', 'G_KL', 'D_real', 'D_fake', 'Smooth_L1')\n\n            # initialize optimizers\n            # optimizer G\n            params = list(self.netG.parameters())\n            if self.gen_features:              \n                params += list(self.netE.parameters())         \n            self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))                            \n\n            # optimizer D                        \n            params = list(self.netD.parameters())    \n            self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))\n\n            print(\"---------- Optimizers initialized -------------\")\n\n            if opt.continue_train:\n                self.load_optimizer(self.optimizer_D, 'D', opt.which_epoch)\n                self.load_optimizer(self.optimizer_G, \"G\", opt.which_epoch)\n                for param_groups in self.optimizer_D.param_groups:\n                    self.old_lr=param_groups['lr']\n\n                print(\"---------- Optimizers reloaded -------------\")\n                print(\"---------- Current LR is %.8f -------------\"%(self.old_lr))\n\n            ## We also want to re-load the parameters of optimizer.\n\n    def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):             \n        if self.opt.label_nc == 0:\n            input_label = label_map.data.cuda()\n        else:\n            # create one-hot vector for label map \n            size = label_map.size()\n            oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])\n            input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()\n            input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)\n            if self.opt.data_type == 16:\n                input_label = input_label.half()\n\n        # get edges from instance map\n        if not 
self.opt.no_instance:\n            inst_map = inst_map.data.cuda()\n            edge_map = self.get_edges(inst_map)\n            input_label = torch.cat((input_label, edge_map), dim=1)         \n        input_label = Variable(input_label, volatile=infer)\n\n        # real images for training\n        if real_image is not None:\n            real_image = Variable(real_image.data.cuda())\n\n        # instance map for feature encoding\n        if self.use_features:\n            # get precomputed feature maps\n            if self.opt.load_features:\n                feat_map = Variable(feat_map.data.cuda())\n            if self.opt.label_feat:\n                inst_map = label_map.cuda()\n\n        return input_label, inst_map, real_image, feat_map\n\n    def discriminate(self, input_label, test_image, use_pool=False):\n        if input_label is None:\n            input_concat = test_image.detach()\n        else:\n            input_concat = torch.cat((input_label, test_image.detach()), dim=1)\n        if use_pool:            \n            fake_query = self.fake_pool.query(input_concat)\n            return self.netD.forward(fake_query)\n        else:\n            return self.netD.forward(input_concat)\n\n    def forward(self, label, inst, image, feat, infer=False):\n        # Encode Inputs\n        input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat)  \n\n        # Fake Generation\n        if self.use_features:\n            if not self.opt.load_features:\n                feat_map = self.netE.forward(real_image, inst_map)                     \n            input_concat = torch.cat((input_label, feat_map), dim=1)                        \n        else:\n            input_concat = input_label\n        hiddens = self.netG.forward(input_concat, 'enc')\n        noise = Variable(torch.randn(hiddens.size()).cuda(hiddens.data.get_device()))\n        # This is a reduced VAE implementation where we assume the outputs are multivariate Gaussian 
distribution with mean = hiddens and std_dev = all ones.\n        # We follow the the VAE of MUNIT (https://github.com/NVlabs/MUNIT/blob/master/networks.py)\n        fake_image = self.netG.forward(hiddens + noise, 'dec')\n\n        if self.opt.no_cgan:\n            # Fake Detection and Loss\n            pred_fake_pool = self.discriminate(None, fake_image, use_pool=True)\n            loss_D_fake = self.criterionGAN(pred_fake_pool, False)        \n\n            # Real Detection and Loss        \n            pred_real = self.discriminate(None, real_image)\n            loss_D_real = self.criterionGAN(pred_real, True)\n\n            # GAN loss (Fake Passability Loss)        \n            pred_fake = self.netD.forward(fake_image)        \n            loss_G_GAN = self.criterionGAN(pred_fake, True)\n        else:\n            # Fake Detection and Loss\n            pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True)\n            loss_D_fake = self.criterionGAN(pred_fake_pool, False)        \n\n            # Real Detection and Loss        \n            pred_real = self.discriminate(input_label, real_image)\n            loss_D_real = self.criterionGAN(pred_real, True)\n\n            # GAN loss (Fake Passability Loss)        \n            pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1))        \n            loss_G_GAN = self.criterionGAN(pred_fake, True) \n        \n        \n        loss_G_kl = torch.mean(torch.pow(hiddens, 2)) * self.opt.kl\n\n        # GAN feature matching loss\n        loss_G_GAN_Feat = 0\n        if not self.opt.no_ganFeat_loss:\n            feat_weights = 4.0 / (self.opt.n_layers_D + 1)\n            D_weights = 1.0 / self.opt.num_D\n            for i in range(self.opt.num_D):\n                for j in range(len(pred_fake[i])-1):\n                    loss_G_GAN_Feat += D_weights * feat_weights * \\\n                        self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * 
self.opt.lambda_feat\n                   \n        # VGG feature matching loss\n        loss_G_VGG = 0\n        if not self.opt.no_vgg_loss:\n            loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat\n\n\n        smooth_l1_loss=0\n\n        return [ self.loss_filter( loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_G_kl, loss_D_real, loss_D_fake,smooth_l1_loss ), None if not infer else fake_image ]\n\n    def inference(self, label, inst, image=None, feat=None):\n        # Encode Inputs        \n        image = Variable(image) if image is not None else None\n        input_label, inst_map, real_image, _ = self.encode_input(Variable(label), Variable(inst), image, infer=True)\n\n        # Fake Generation\n        if self.use_features:\n            if self.opt.use_encoded_image:\n                # encode the real image to get feature map\n                feat_map = self.netE.forward(real_image, inst_map)\n            else:\n                # sample clusters from precomputed features             \n                feat_map = self.sample_features(inst_map)\n            input_concat = torch.cat((input_label, feat_map), dim=1)\n        else:\n            input_concat = input_label        \n           \n        if torch.__version__.startswith('0.4'):\n            with torch.no_grad():\n                fake_image = self.netG.forward(input_concat)\n        else:\n            fake_image = self.netG.forward(input_concat)\n        return fake_image\n\n    def sample_features(self, inst): \n        # read precomputed feature clusters \n        cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, self.opt.cluster_path)        \n        features_clustered = np.load(cluster_path, encoding='latin1').item()\n\n        # randomly sample from the feature clusters\n        inst_np = inst.cpu().numpy().astype(int)                                      \n        feat_map = self.Tensor(inst.size()[0], self.opt.feat_num, inst.size()[2], 
inst.size()[3])\n        for i in np.unique(inst_np):    \n            label = i if i < 1000 else i//1000\n            if label in features_clustered:\n                feat = features_clustered[label]\n                cluster_idx = np.random.randint(0, feat.shape[0]) \n                                            \n                idx = (inst == int(i)).nonzero()\n                for k in range(self.opt.feat_num):                                    \n                    feat_map[idx[:,0], idx[:,1] + k, idx[:,2], idx[:,3]] = feat[cluster_idx, k]\n        if self.opt.data_type==16:\n            feat_map = feat_map.half()\n        return feat_map\n\n    def encode_features(self, image, inst):\n        image = Variable(image.cuda(), volatile=True)\n        feat_num = self.opt.feat_num\n        h, w = inst.size()[2], inst.size()[3]\n        block_num = 32\n        feat_map = self.netE.forward(image, inst.cuda())\n        inst_np = inst.cpu().numpy().astype(int)\n        feature = {}\n        for i in range(self.opt.label_nc):\n            feature[i] = np.zeros((0, feat_num+1))\n        for i in np.unique(inst_np):\n            label = i if i < 1000 else i//1000\n            idx = (inst == int(i)).nonzero()\n            num = idx.size()[0]\n            idx = idx[num//2,:]\n            val = np.zeros((1, feat_num+1))                        \n            for k in range(feat_num):\n                val[0, k] = feat_map[idx[0], idx[1] + k, idx[2], idx[3]].data[0]            \n            val[0, feat_num] = float(num) / (h * w // block_num)\n            feature[label] = np.append(feature[label], val, axis=0)\n        return feature\n\n    def get_edges(self, t):\n        edge = torch.cuda.ByteTensor(t.size()).zero_()\n        edge[:,:,:,1:] = edge[:,:,:,1:] | (t[:,:,:,1:] != t[:,:,:,:-1])\n        edge[:,:,:,:-1] = edge[:,:,:,:-1] | (t[:,:,:,1:] != t[:,:,:,:-1])\n        edge[:,:,1:,:] = edge[:,:,1:,:] | (t[:,:,1:,:] != t[:,:,:-1,:])\n        edge[:,:,:-1,:] = edge[:,:,:-1,:] | 
(t[:,:,1:,:] != t[:,:,:-1,:])\n        if self.opt.data_type==16:\n            return edge.half()\n        else:\n            return edge.float()\n\n    def save(self, which_epoch):\n        self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)\n        self.save_network(self.netD, 'D', which_epoch, self.gpu_ids)\n\n        self.save_optimizer(self.optimizer_G,\"G\",which_epoch)\n        self.save_optimizer(self.optimizer_D,\"D\",which_epoch)\n\n        if self.gen_features:\n            self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)\n\n    def update_fixed_params(self):\n\n        params = list(self.netG.parameters())\n        if self.gen_features:\n            params += list(self.netE.parameters())           \n        self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))\n        if self.opt.verbose:\n            print('------------ Now also finetuning global generator -----------')\n\n    def update_learning_rate(self):\n        lrd = self.opt.lr / self.opt.niter_decay\n        lr = self.old_lr - lrd        \n        for param_group in self.optimizer_D.param_groups:\n            param_group['lr'] = lr\n        for param_group in self.optimizer_G.param_groups:\n            param_group['lr'] = lr\n        if self.opt.verbose:\n            print('update learning rate: %f -> %f' % (self.old_lr, lr))\n        self.old_lr = lr\n\n\nclass InferenceModel(Pix2PixHDModel):\n    def forward(self, inp):\n        label, inst = inp\n        return self.inference(label, inst)\n"
  },
  {
    "path": "Global/models/pix2pixHD_model_DA.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport numpy as np\nimport torch\nimport os\nfrom torch.autograd import Variable\nfrom util.image_pool import ImagePool\nfrom .base_model import BaseModel\nfrom . import networks\n\n\nclass Pix2PixHDModel(BaseModel):\n    def name(self):\n        return 'Pix2PixHDModel'\n\n    def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss):\n        flags = (True, use_gan_feat_loss, use_vgg_loss, True, True, True, True, True, True)\n\n        def loss_filter(g_gan, g_gan_feat, g_vgg, g_kl, d_real, d_fake, g_featd, featd_real, featd_fake):\n            return [l for (l, f) in zip((g_gan, g_gan_feat, g_vgg, g_kl, d_real, d_fake, g_featd, featd_real, featd_fake), flags) if f]\n\n        return loss_filter\n\n    def initialize(self, opt):\n        BaseModel.initialize(self, opt)\n        if opt.resize_or_crop != 'none' or not opt.isTrain:  # when training at full res this causes OOM\n            torch.backends.cudnn.benchmark = True\n        self.isTrain = opt.isTrain\n        self.use_features = opt.instance_feat or opt.label_feat  ## Clearly it is false\n        self.gen_features = self.use_features and not self.opt.load_features  ## it is also false\n        input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc  ## Just is the origin input channel #\n\n        ##### define networks\n        # Generator network\n        netG_input_nc = input_nc\n        if not opt.no_instance:\n            netG_input_nc += 1\n        if self.use_features:\n            netG_input_nc += opt.feat_num\n        self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG, opt.k_size,\n                                      opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers,\n                                      opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids, opt=opt)\n\n        # Discriminator network\n        if self.isTrain:\n            use_sigmoid = 
opt.no_lsgan\n            netD_input_nc = opt.output_nc if opt.no_cgan else input_nc + opt.output_nc\n            if not opt.no_instance:\n                netD_input_nc += 1\n            self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt,opt.norm, use_sigmoid,\n                                          opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)\n\n            self.feat_D=networks.define_D(64, opt.ndf, opt.n_layers_D, opt, opt.norm, use_sigmoid,\n                                          1, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)\n\n        if self.opt.verbose:\n            print('---------- Networks initialized -------------')\n\n        # load networks\n        if not self.isTrain or opt.continue_train or opt.load_pretrain:\n            pretrained_path = '' if not self.isTrain else opt.load_pretrain\n            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)\n\n            print(\"---------- G Networks reloaded -------------\")\n            if self.isTrain:\n                self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)\n                self.load_network(self.feat_D, 'feat_D', opt.which_epoch, pretrained_path)\n                print(\"---------- D Networks reloaded -------------\")\n\n\n                # set loss functions and optimizers\n        if self.isTrain:\n            if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:  ## The pool_size is 0!\n                raise NotImplementedError(\"Fake Pool Not Implemented for MultiGPU\")\n            self.fake_pool = ImagePool(opt.pool_size)\n            self.old_lr = opt.lr\n\n            # define loss functions\n            self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss)\n\n            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)\n            self.criterionFeat = torch.nn.L1Loss()\n            if not opt.no_vgg_loss:\n                self.criterionVGG = 
networks.VGGLoss_torch(self.gpu_ids)\n\n            # Names so we can breakout loss\n            self.loss_names = self.loss_filter('G_GAN', 'G_GAN_Feat', 'G_VGG', 'G_KL', 'D_real', 'D_fake', 'G_featD', 'featD_real','featD_fake')\n\n            # initialize optimizers\n            # optimizer G\n            params = list(self.netG.parameters())\n            if self.gen_features:\n                params += list(self.netE.parameters())\n            self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))\n\n            # optimizer D\n            params = list(self.netD.parameters())\n            self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))\n\n            params = list(self.feat_D.parameters())\n            self.optimizer_featD = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))\n\n            print(\"---------- Optimizers initialized -------------\")\n\n            if opt.continue_train:\n                self.load_optimizer(self.optimizer_D, 'D', opt.which_epoch)\n                self.load_optimizer(self.optimizer_G, \"G\", opt.which_epoch)\n                self.load_optimizer(self.optimizer_featD,'featD',opt.which_epoch)\n                for param_groups in self.optimizer_D.param_groups:\n                    self.old_lr = param_groups['lr']\n\n                print(\"---------- Optimizers reloaded -------------\")\n                print(\"---------- Current LR is %.8f -------------\" % (self.old_lr))\n\n            ## We also want to re-load the parameters of optimizer.\n\n    def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):\n        if self.opt.label_nc == 0:\n            input_label = label_map.data.cuda()\n        else:\n            # create one-hot vector for label map\n            size = label_map.size()\n            oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])\n            input_label = 
torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()\n            input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)\n            if self.opt.data_type == 16:\n                input_label = input_label.half()\n\n        # get edges from instance map\n        if not self.opt.no_instance:\n            inst_map = inst_map.data.cuda()\n            edge_map = self.get_edges(inst_map)\n            input_label = torch.cat((input_label, edge_map), dim=1)\n        input_label = Variable(input_label, volatile=infer)\n\n        # real images for training\n        if real_image is not None:\n            real_image = Variable(real_image.data.cuda())\n\n        # instance map for feature encoding\n        if self.use_features:\n            # get precomputed feature maps\n            if self.opt.load_features:\n                feat_map = Variable(feat_map.data.cuda())\n            if self.opt.label_feat:\n                inst_map = label_map.cuda()\n\n        return input_label, inst_map, real_image, feat_map\n\n    def discriminate(self, input_label, test_image, use_pool=False):\n        if input_label is None:\n            input_concat = test_image.detach()\n        else:\n            input_concat = torch.cat((input_label, test_image.detach()), dim=1)\n        if use_pool:\n            fake_query = self.fake_pool.query(input_concat)\n            return self.netD.forward(fake_query)\n        else:\n            return self.netD.forward(input_concat)\n\n    def feat_discriminate(self,input):\n\n        return self.feat_D.forward(input.detach())\n\n\n    def forward(self, label, inst, image, feat, infer=False):\n        # Encode Inputs\n        input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat)\n\n        # Fake Generation\n        if self.use_features:\n            if not self.opt.load_features:\n                feat_map = self.netE.forward(real_image, inst_map)\n            input_concat = torch.cat((input_label, 
feat_map), dim=1)\n        else:\n            input_concat = input_label\n        hiddens = self.netG.forward(input_concat, 'enc')\n        noise = Variable(torch.randn(hiddens.size()).cuda(hiddens.data.get_device()))\n        # This is a reduced VAE implementation where we assume the outputs are multivariate Gaussian distribution with mean = hiddens and std_dev = all ones.\n        # We follow the the VAE of MUNIT (https://github.com/NVlabs/MUNIT/blob/master/networks.py)\n        fake_image = self.netG.forward(hiddens + noise, 'dec')\n\n        ####################\n        ##### GAN for the intermediate feature\n        real_old_feat =[]\n        syn_feat = []\n        for index,x in enumerate(inst):\n            if x==1:\n                real_old_feat.append(hiddens[index].unsqueeze(0))\n            else:\n                syn_feat.append(hiddens[index].unsqueeze(0))\n        L=min(len(real_old_feat),len(syn_feat))\n        real_old_feat=real_old_feat[:L]\n        syn_feat=syn_feat[:L]\n        real_old_feat=torch.cat(real_old_feat,0)\n        syn_feat=torch.cat(syn_feat,0)\n\n        pred_fake_feat=self.feat_discriminate(real_old_feat)\n        loss_featD_fake = self.criterionGAN(pred_fake_feat, False)\n        pred_real_feat=self.feat_discriminate(syn_feat)\n        loss_featD_real = self.criterionGAN(pred_real_feat, True)\n\n        pred_fake_feat_G=self.feat_D.forward(real_old_feat)\n        loss_G_featD=self.criterionGAN(pred_fake_feat_G,True)\n\n\n        #####################################\n        if self.opt.no_cgan:\n            # Fake Detection and Loss\n            pred_fake_pool = self.discriminate(None, fake_image, use_pool=True)\n            loss_D_fake = self.criterionGAN(pred_fake_pool, False)\n\n            # Real Detection and Loss\n            pred_real = self.discriminate(None, real_image)\n            loss_D_real = self.criterionGAN(pred_real, True)\n\n            # GAN loss (Fake Passability Loss)\n            pred_fake = 
self.netD.forward(fake_image)\n            loss_G_GAN = self.criterionGAN(pred_fake, True)\n        else:\n            # Fake Detection and Loss\n            pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True)\n            loss_D_fake = self.criterionGAN(pred_fake_pool, False)\n\n            # Real Detection and Loss\n            pred_real = self.discriminate(input_label, real_image)\n            loss_D_real = self.criterionGAN(pred_real, True)\n\n            # GAN loss (Fake Passability Loss)\n            pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1))\n            loss_G_GAN = self.criterionGAN(pred_fake, True)\n\n        loss_G_kl = torch.mean(torch.pow(hiddens, 2)) * self.opt.kl\n\n        # GAN feature matching loss\n        loss_G_GAN_Feat = 0\n        if not self.opt.no_ganFeat_loss:\n            feat_weights = 4.0 / (self.opt.n_layers_D + 1)\n            D_weights = 1.0 / self.opt.num_D\n            for i in range(self.opt.num_D):\n                for j in range(len(pred_fake[i]) - 1):\n                    loss_G_GAN_Feat += D_weights * feat_weights * \\\n                                       self.criterionFeat(pred_fake[i][j],\n                                                          pred_real[i][j].detach()) * self.opt.lambda_feat\n\n        # VGG feature matching loss\n        loss_G_VGG = 0\n        if not self.opt.no_vgg_loss:\n            loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat\n\n        # Only return the fake_B image if necessary to save BW\n        return [self.loss_filter(loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_G_kl, loss_D_real, loss_D_fake,loss_G_featD, loss_featD_real, loss_featD_fake),\n                None if not infer else fake_image]\n\n    def inference(self, label, inst, image=None, feat=None):\n        # Encode Inputs\n        image = Variable(image) if image is not None else None\n        input_label, inst_map, real_image, _ = 
self.encode_input(Variable(label), Variable(inst), image, infer=True)\n\n        # Fake Generation\n        if self.use_features:\n            if self.opt.use_encoded_image:\n                # encode the real image to get feature map\n                feat_map = self.netE.forward(real_image, inst_map)\n            else:\n                # sample clusters from precomputed features\n                feat_map = self.sample_features(inst_map)\n            input_concat = torch.cat((input_label, feat_map), dim=1)\n        else:\n            input_concat = input_label\n\n        if torch.__version__.startswith('0.4'):\n            with torch.no_grad():\n                fake_image = self.netG.forward(input_concat)\n        else:\n            fake_image = self.netG.forward(input_concat)\n        return fake_image\n\n    def sample_features(self, inst):\n        # read precomputed feature clusters\n        cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, self.opt.cluster_path)\n        features_clustered = np.load(cluster_path, encoding='latin1').item()\n\n        # randomly sample from the feature clusters\n        inst_np = inst.cpu().numpy().astype(int)\n        feat_map = self.Tensor(inst.size()[0], self.opt.feat_num, inst.size()[2], inst.size()[3])\n        for i in np.unique(inst_np):\n            label = i if i < 1000 else i // 1000\n            if label in features_clustered:\n                feat = features_clustered[label]\n                cluster_idx = np.random.randint(0, feat.shape[0])\n\n                idx = (inst == int(i)).nonzero()\n                for k in range(self.opt.feat_num):\n                    feat_map[idx[:, 0], idx[:, 1] + k, idx[:, 2], idx[:, 3]] = feat[cluster_idx, k]\n        if self.opt.data_type == 16:\n            feat_map = feat_map.half()\n        return feat_map\n\n    def encode_features(self, image, inst):\n        image = Variable(image.cuda(), volatile=True)\n        feat_num = self.opt.feat_num\n        h, w = 
inst.size()[2], inst.size()[3]\n        block_num = 32\n        feat_map = self.netE.forward(image, inst.cuda())\n        inst_np = inst.cpu().numpy().astype(int)\n        feature = {}\n        for i in range(self.opt.label_nc):\n            feature[i] = np.zeros((0, feat_num + 1))\n        for i in np.unique(inst_np):\n            label = i if i < 1000 else i // 1000\n            idx = (inst == int(i)).nonzero()\n            num = idx.size()[0]\n            idx = idx[num // 2, :]\n            val = np.zeros((1, feat_num + 1))\n            for k in range(feat_num):\n                val[0, k] = feat_map[idx[0], idx[1] + k, idx[2], idx[3]].data[0]\n            val[0, feat_num] = float(num) / (h * w // block_num)\n            feature[label] = np.append(feature[label], val, axis=0)\n        return feature\n\n    def get_edges(self, t):\n        edge = torch.cuda.ByteTensor(t.size()).zero_()\n        edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1])\n        edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1])\n        edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])\n        edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])\n        if self.opt.data_type == 16:\n            return edge.half()\n        else:\n            return edge.float()\n\n    def save(self, which_epoch):\n        self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)\n        self.save_network(self.netD, 'D', which_epoch, self.gpu_ids)\n        self.save_network(self.feat_D,'featD',which_epoch,self.gpu_ids)\n\n        self.save_optimizer(self.optimizer_G, \"G\", which_epoch)\n        self.save_optimizer(self.optimizer_D, \"D\", which_epoch)\n        self.save_optimizer(self.optimizer_featD,'featD',which_epoch)\n\n        if self.gen_features:\n            self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)\n\n    def update_fixed_params(self):\n\n        params = 
list(self.netG.parameters())\n        if self.gen_features:\n            params += list(self.netE.parameters())\n        self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))\n        if self.opt.verbose:\n            print('------------ Now also finetuning global generator -----------')\n\n    def update_learning_rate(self):\n        lrd = self.opt.lr / self.opt.niter_decay\n        lr = self.old_lr - lrd\n        for param_group in self.optimizer_D.param_groups:\n            param_group['lr'] = lr\n        for param_group in self.optimizer_G.param_groups:\n            param_group['lr'] = lr\n        for param_group in self.optimizer_featD.param_groups:\n            param_group['lr'] = lr\n        if self.opt.verbose:\n            print('update learning rate: %f -> %f' % (self.old_lr, lr))\n        self.old_lr = lr\n\n\nclass InferenceModel(Pix2PixHDModel):\n    def forward(self, inp):\n        label, inst = inp\n        return self.inference(label, inst)\n"
  },
  {
    "path": "Global/options/__init__.py",
    "content": ""
  },
  {
    "path": "Global/options/base_options.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport argparse\nimport os\nfrom util import util\nimport torch\n\n\nclass BaseOptions:\n    def __init__(self):\n        self.parser = argparse.ArgumentParser()\n        self.initialized = False\n\n    def initialize(self):\n        # experiment specifics\n        self.parser.add_argument(\n            \"--name\",\n            type=str,\n            default=\"label2city\",\n            help=\"name of the experiment. It decides where to store samples and models\",\n        )\n        self.parser.add_argument(\n            \"--gpu_ids\", type=str, default=\"0\", help=\"gpu ids: e.g. 0  0,1,2, 0,2. use -1 for CPU\"\n        )\n        self.parser.add_argument(\n            \"--checkpoints_dir\", type=str, default=\"./checkpoints\", help=\"models are saved here\"\n        )  ## note: to add this param when using philly\n        # self.parser.add_argument('--project_dir', type=str, default='./', help='the project is saved here')  ################### This is necessary for philly\n        self.parser.add_argument(\n            \"--outputs_dir\", type=str, default=\"./outputs\", help=\"models are saved here\"\n        )  ## note: to add this param when using philly  Please end with '/'\n        self.parser.add_argument(\"--model\", type=str, default=\"pix2pixHD\", help=\"which model to use\")\n        self.parser.add_argument(\n            \"--norm\", type=str, default=\"instance\", help=\"instance normalization or batch normalization\"\n        )\n        self.parser.add_argument(\"--use_dropout\", action=\"store_true\", help=\"use dropout for the generator\")\n        self.parser.add_argument(\n            \"--data_type\",\n            default=32,\n            type=int,\n            choices=[8, 16, 32],\n            help=\"Supported data type i.e. 
8, 16, 32 bit\",\n        )\n        self.parser.add_argument(\"--verbose\", action=\"store_true\", default=False, help=\"toggles verbose\")\n\n        # input/output sizes\n        self.parser.add_argument(\"--batchSize\", type=int, default=1, help=\"input batch size\")\n        self.parser.add_argument(\"--loadSize\", type=int, default=1024, help=\"scale images to this size\")\n        self.parser.add_argument(\"--fineSize\", type=int, default=512, help=\"then crop to this size\")\n        self.parser.add_argument(\"--label_nc\", type=int, default=35, help=\"# of input label channels\")\n        self.parser.add_argument(\"--input_nc\", type=int, default=3, help=\"# of input image channels\")\n        self.parser.add_argument(\"--output_nc\", type=int, default=3, help=\"# of output image channels\")\n\n        # for setting inputs\n        self.parser.add_argument(\"--dataroot\", type=str, default=\"./datasets/cityscapes/\")\n        self.parser.add_argument(\n            \"--resize_or_crop\",\n            type=str,\n            default=\"scale_width\",\n            help=\"scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]\",\n        )\n        self.parser.add_argument(\n            \"--serial_batches\",\n            action=\"store_true\",\n            help=\"if true, takes images in order to make batches, otherwise takes them randomly\",\n        )\n        self.parser.add_argument(\n            \"--no_flip\",\n            action=\"store_true\",\n            help=\"if specified, do not flip the images for data argumentation\",\n        )\n        self.parser.add_argument(\"--nThreads\", default=2, type=int, help=\"# threads for loading data\")\n        self.parser.add_argument(\n            \"--max_dataset_size\",\n            type=int,\n            default=float(\"inf\"),\n            help=\"Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.\",\n        )\n\n        # for displays\n        self.parser.add_argument(\"--display_winsize\", type=int, default=512, help=\"display window size\")\n        self.parser.add_argument(\n            \"--tf_log\",\n            action=\"store_true\",\n            help=\"if specified, use tensorboard logging. Requires tensorflow installed\",\n        )\n\n        # for generator\n        self.parser.add_argument(\"--netG\", type=str, default=\"global\", help=\"selects model to use for netG\")\n        self.parser.add_argument(\"--ngf\", type=int, default=64, help=\"# of gen filters in first conv layer\")\n        self.parser.add_argument(\"--k_size\", type=int, default=3, help=\"# kernel size conv layer\")\n        self.parser.add_argument(\"--use_v2\", action=\"store_true\", help=\"use DCDCv2\")\n        self.parser.add_argument(\"--mc\", type=int, default=1024, help=\"# max channel\")\n        self.parser.add_argument(\"--start_r\", type=int, default=3, help=\"start layer to use resblock\")\n        self.parser.add_argument(\n            \"--n_downsample_global\", type=int, default=4, help=\"number of downsampling layers in netG\"\n        )\n        self.parser.add_argument(\n            \"--n_blocks_global\",\n            type=int,\n            default=9,\n            help=\"number of residual blocks in the global generator network\",\n        )\n        self.parser.add_argument(\n            \"--n_blocks_local\",\n            type=int,\n            default=3,\n            help=\"number of residual blocks in the local enhancer network\",\n        )\n        self.parser.add_argument(\n            \"--n_local_enhancers\", type=int, default=1, help=\"number of local enhancers to use\"\n        )\n        self.parser.add_argument(\n            \"--niter_fix_global\",\n            type=int,\n            default=0,\n            help=\"number of epochs that we only train the outmost 
local enhancer\",\n        )\n\n        self.parser.add_argument(\n            \"--load_pretrain\",\n            type=str,\n            default=\"\",\n            help=\"load the pretrained model from the specified location\",\n        )\n\n        # for instance-wise features\n        self.parser.add_argument(\n            \"--no_instance\", action=\"store_true\", help=\"if specified, do *not* add instance map as input\"\n        )\n        self.parser.add_argument(\n            \"--instance_feat\",\n            action=\"store_true\",\n            help=\"if specified, add encoded instance features as input\",\n        )\n        self.parser.add_argument(\n            \"--label_feat\", action=\"store_true\", help=\"if specified, add encoded label features as input\"\n        )\n        self.parser.add_argument(\"--feat_num\", type=int, default=3, help=\"vector length for encoded features\")\n        self.parser.add_argument(\n            \"--load_features\", action=\"store_true\", help=\"if specified, load precomputed feature maps\"\n        )\n        self.parser.add_argument(\n            \"--n_downsample_E\", type=int, default=4, help=\"# of downsampling layers in encoder\"\n        )\n        self.parser.add_argument(\n            \"--nef\", type=int, default=16, help=\"# of encoder filters in the first conv layer\"\n        )\n        self.parser.add_argument(\"--n_clusters\", type=int, default=10, help=\"number of clusters for features\")\n\n        # diy\n        self.parser.add_argument(\"--self_gen\", action=\"store_true\", help=\"self generate\")\n        self.parser.add_argument(\n            \"--mapping_n_block\", type=int, default=3, help=\"number of resblock in mapping\"\n        )\n        self.parser.add_argument(\"--map_mc\", type=int, default=64, help=\"max channel of mapping\")\n        self.parser.add_argument(\"--kl\", type=float, default=0, help=\"KL Loss\")\n        self.parser.add_argument(\n            \"--load_pretrainA\",\n            
type=str,\n            default=\"\",\n            help=\"load the pretrained model from the specified location\",\n        )\n        self.parser.add_argument(\n            \"--load_pretrainB\",\n            type=str,\n            default=\"\",\n            help=\"load the pretrained model from the specified location\",\n        )\n        self.parser.add_argument(\"--feat_gan\", action=\"store_true\")\n        self.parser.add_argument(\"--no_cgan\", action=\"store_true\")\n        self.parser.add_argument(\"--map_unet\", action=\"store_true\")\n        self.parser.add_argument(\"--map_densenet\", action=\"store_true\")\n        self.parser.add_argument(\"--fcn\", action=\"store_true\")\n        self.parser.add_argument(\"--is_image\", action=\"store_true\", help=\"train image recon only pair data\")\n        self.parser.add_argument(\"--label_unpair\", action=\"store_true\")\n        self.parser.add_argument(\"--mapping_unpair\", action=\"store_true\")\n        self.parser.add_argument(\"--unpair_w\", type=float, default=1.0)\n        self.parser.add_argument(\"--pair_num\", type=int, default=-1)\n        self.parser.add_argument(\"--Gan_w\", type=float, default=1)\n        self.parser.add_argument(\"--feat_dim\", type=int, default=-1)\n        self.parser.add_argument(\"--abalation_vae_len\", type=int, default=-1)\n\n        ######################### useless, just to cooperate with docker\n        self.parser.add_argument(\"--gpu\", type=str)\n        self.parser.add_argument(\"--dataDir\", type=str)\n        self.parser.add_argument(\"--modelDir\", type=str)\n        self.parser.add_argument(\"--logDir\", type=str)\n        self.parser.add_argument(\"--data_dir\", type=str)\n\n        self.parser.add_argument(\"--use_skip_model\", action=\"store_true\")\n        self.parser.add_argument(\"--use_segmentation_model\", action=\"store_true\")\n\n        self.parser.add_argument(\"--spatio_size\", type=int, default=64)\n        
self.parser.add_argument(\"--test_random_crop\", action=\"store_true\")\n        ##########################\n\n        self.parser.add_argument(\"--contain_scratch_L\", action=\"store_true\")\n        self.parser.add_argument(\n            \"--mask_dilation\", type=int, default=0\n        )  ## Don't change the input, only dilation the mask\n\n        self.parser.add_argument(\n            \"--irregular_mask\", type=str, default=\"\", help=\"This is the root of the mask\"\n        )\n        self.parser.add_argument(\n            \"--mapping_net_dilation\",\n            type=int,\n            default=1,\n            help=\"This parameter is the dilation size of the translation net\",\n        )\n\n        self.parser.add_argument(\n            \"--VOC\", type=str, default=\"VOC_RGB_JPEGImages.bigfile\", help=\"The root of VOC dataset\"\n        )\n\n        self.parser.add_argument(\"--non_local\", type=str, default=\"\", help=\"which non_local setting\")\n        self.parser.add_argument(\n            \"--NL_fusion_method\",\n            type=str,\n            default=\"add\",\n            help=\"how to fuse the origin feature and nl feature\",\n        )\n        self.parser.add_argument(\n            \"--NL_use_mask\", action=\"store_true\", help=\"If use mask while using Non-local mapping model\"\n        )\n        self.parser.add_argument(\n            \"--correlation_renormalize\",\n            action=\"store_true\",\n            help=\"Since after mask out the correlation matrix(which is softmaxed), the sum is not 1 any more, enable this param to re-weight\",\n        )\n\n        self.parser.add_argument(\"--Smooth_L1\", action=\"store_true\", help=\"Use L1 Loss in image level\")\n\n        self.parser.add_argument(\n            \"--face_restore_setting\", type=int, default=1, help=\"This is for the aligned face restoration\"\n        )\n        self.parser.add_argument(\"--face_clean_url\", type=str, default=\"\")\n        
self.parser.add_argument(\"--syn_input_url\", type=str, default=\"\")\n        self.parser.add_argument(\"--syn_gt_url\", type=str, default=\"\")\n\n        self.parser.add_argument(\n            \"--test_on_synthetic\",\n            action=\"store_true\",\n            help=\"If you want to test on the synthetic data, enable this parameter\",\n        )\n\n        self.parser.add_argument(\"--use_SN\", action=\"store_true\", help=\"Add SN to every parametric layer\")\n\n        self.parser.add_argument(\n            \"--use_two_stage_mapping\", action=\"store_true\", help=\"choose the model which uses two stage\"\n        )\n\n        self.parser.add_argument(\"--L1_weight\", type=float, default=10.0)\n        self.parser.add_argument(\"--softmax_temperature\", type=float, default=1.0)\n        self.parser.add_argument(\n            \"--patch_similarity\",\n            action=\"store_true\",\n            help=\"Enable this denotes using 3*3 patch to calculate similarity\",\n        )\n        self.parser.add_argument(\n            \"--use_self\",\n            action=\"store_true\",\n            help=\"Enable this denotes that while constructing the new feature maps, using original feature (diagonal == 1)\",\n        )\n\n        self.parser.add_argument(\"--use_own_dataset\", action=\"store_true\")\n\n        self.parser.add_argument(\n            \"--test_hole_two_folders\",\n            action=\"store_true\",\n            help=\"Enable this parameter means test the restoration with inpainting given twp folders which are mask and old respectively\",\n        )\n\n        self.parser.add_argument(\n            \"--no_hole\",\n            action=\"store_true\",\n            help=\"While test the full_model on non_scratch data, do not add random mask into the real old photos\",\n        )  ## Only for testing\n        self.parser.add_argument(\n            \"--random_hole\",\n            action=\"store_true\",\n            help=\"While training the full model, 50% 
probability add hole\",\n        )\n\n        self.parser.add_argument(\"--NL_res\", action=\"store_true\", help=\"NL+Resdual Block\")\n\n        self.parser.add_argument(\"--image_L1\", action=\"store_true\", help=\"Image level loss: L1\")\n        self.parser.add_argument(\n            \"--hole_image_no_mask\",\n            action=\"store_true\",\n            help=\"while testing, give hole image but not give the mask\",\n        )\n\n        self.parser.add_argument(\n            \"--down_sample_degradation\",\n            action=\"store_true\",\n            help=\"down_sample the image only, corresponds to [down_sample_face]\",\n        )\n\n        self.parser.add_argument(\n            \"--norm_G\", type=str, default=\"spectralinstance\", help=\"The norm type of Generator\"\n        )\n        self.parser.add_argument(\n            \"--init_G\",\n            type=str,\n            default=\"xavier\",\n            help=\"normal|xavier|xavier_uniform|kaiming|orthogonal|none\",\n        )\n\n        self.parser.add_argument(\"--use_new_G\", action=\"store_true\")\n        self.parser.add_argument(\"--use_new_D\", action=\"store_true\")\n\n        self.parser.add_argument(\n            \"--only_voc\", action=\"store_true\", help=\"test the trianed celebA face model using VOC face\"\n        )\n\n        self.parser.add_argument(\n            \"--cosin_similarity\",\n            action=\"store_true\",\n            help=\"For non-local, using cosin to calculate the similarity\",\n        )\n\n        self.parser.add_argument(\n            \"--downsample_mode\",\n            type=str,\n            default=\"nearest\",\n            help=\"For partial non-local, choose how to downsample the mask\",\n        )\n\n        self.parser.add_argument(\"--mapping_exp\",type=int,default=0,help='Default 0: original PNL|1: Multi-Scale Patch Attention')\n        self.parser.add_argument(\"--inference_optimize\",action='store_true',help='optimize the memory cost')\n\n\n        
self.initialized = True\n\n    def parse(self, save=True):\n        if not self.initialized:\n            self.initialize()\n        self.opt = self.parser.parse_args()\n        self.opt.isTrain = self.isTrain  # train or test\n\n        str_ids = self.opt.gpu_ids.split(\",\")\n        self.opt.gpu_ids = []\n        for str_id in str_ids:\n            int_id = int(str_id)\n            if int_id >= 0:\n                self.opt.gpu_ids.append(int_id)\n\n        # set gpu ids\n        if len(self.opt.gpu_ids) > 0:\n            # pass\n            torch.cuda.set_device(self.opt.gpu_ids[0])\n\n        args = vars(self.opt)\n\n        # print('------------ Options -------------')\n        # for k, v in sorted(args.items()):\n        #     print('%s: %s' % (str(k), str(v)))\n        # print('-------------- End ----------------')\n\n        # save to the disk\n        expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)\n        util.mkdirs(expr_dir)\n        if save and not self.opt.continue_train:\n            file_name = os.path.join(expr_dir, \"opt.txt\")\n            with open(file_name, \"wt\") as opt_file:\n                opt_file.write(\"------------ Options -------------\\n\")\n                for k, v in sorted(args.items()):\n                    opt_file.write(\"%s: %s\\n\" % (str(k), str(v)))\n                opt_file.write(\"-------------- End ----------------\\n\")\n        return self.opt\n"
  },
  {
    "path": "Global/options/test_options.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom .base_options import BaseOptions\n\n\nclass TestOptions(BaseOptions):\n    def initialize(self):\n        BaseOptions.initialize(self)\n        self.parser.add_argument(\"--ntest\", type=int, default=float(\"inf\"), help=\"# of test examples.\")\n        self.parser.add_argument(\"--results_dir\", type=str, default=\"./results/\", help=\"saves results here.\")\n        self.parser.add_argument(\n            \"--aspect_ratio\", type=float, default=1.0, help=\"aspect ratio of result images\"\n        )\n        self.parser.add_argument(\"--phase\", type=str, default=\"test\", help=\"train, val, test, etc\")\n        self.parser.add_argument(\n            \"--which_epoch\",\n            type=str,\n            default=\"latest\",\n            help=\"which epoch to load? set to latest to use latest cached model\",\n        )\n        self.parser.add_argument(\"--how_many\", type=int, default=50, help=\"how many test images to run\")\n        self.parser.add_argument(\n            \"--cluster_path\",\n            type=str,\n            default=\"features_clustered_010.npy\",\n            help=\"the path for clustered results of encoded features\",\n        )\n        self.parser.add_argument(\n            \"--use_encoded_image\",\n            action=\"store_true\",\n            help=\"if specified, encode the real image to get the feature map\",\n        )\n        self.parser.add_argument(\"--export_onnx\", type=str, help=\"export ONNX model to a given file\")\n        self.parser.add_argument(\"--engine\", type=str, help=\"run serialized TRT engine\")\n        self.parser.add_argument(\"--onnx\", type=str, help=\"run ONNX model via TRT\")\n        self.parser.add_argument(\n            \"--start_epoch\",\n            type=int,\n            default=-1,\n            help=\"write the start_epoch of iter.txt into this parameter\",\n        )\n\n        
self.parser.add_argument(\"--test_dataset\", type=str, default=\"Real_RGB_old.bigfile\")\n        self.parser.add_argument(\n            \"--no_degradation\",\n            action=\"store_true\",\n            help=\"when train the mapping, enable this parameter --> no degradation will be added into clean image\",\n        )\n        self.parser.add_argument(\n            \"--no_load_VAE\",\n            action=\"store_true\",\n            help=\"when train the mapping, enable this parameter --> random initialize the encoder an decoder\",\n        )\n        self.parser.add_argument(\n            \"--use_v2_degradation\",\n            action=\"store_true\",\n            help=\"enable this parameter --> 4 kinds of degradations will be used to synthesize corruption\",\n        )\n        self.parser.add_argument(\"--use_vae_which_epoch\", type=str, default=\"latest\")\n        self.isTrain = False\n\n        self.parser.add_argument(\"--generate_pair\", action=\"store_true\")\n\n        self.parser.add_argument(\"--multi_scale_test\", type=float, default=0.5)\n        self.parser.add_argument(\"--multi_scale_threshold\", type=float, default=0.5)\n        self.parser.add_argument(\n            \"--mask_need_scale\",\n            action=\"store_true\",\n            help=\"enable this param meas that the pixel range of mask is 0-255\",\n        )\n        self.parser.add_argument(\"--scale_num\", type=int, default=1)\n\n        self.parser.add_argument(\n            \"--save_feature_url\", type=str, default=\"\", help=\"While extracting the features, where to put\"\n        )\n\n        self.parser.add_argument(\n            \"--test_input\", type=str, default=\"\", help=\"A directory or a root of bigfile\"\n        )\n        self.parser.add_argument(\"--test_mask\", type=str, default=\"\", help=\"A directory or a root of bigfile\")\n        self.parser.add_argument(\"--test_gt\", type=str, default=\"\", help=\"A directory or a root of bigfile\")\n\n        
self.parser.add_argument(\n            \"--scale_input\", action=\"store_true\", help=\"While testing, choose to scale the input firstly\"\n        )\n\n        self.parser.add_argument(\n            \"--save_feature_name\", type=str, default=\"features.json\", help=\"The name of saved features\"\n        )\n        self.parser.add_argument(\n            \"--test_rgb_old_wo_scratch\", action=\"store_true\", help=\"Same setting with origin test\"\n        )\n\n        self.parser.add_argument(\"--test_mode\", type=str, default=\"Crop\", help=\"Scale|Full|Crop\")\n        self.parser.add_argument(\"--Quality_restore\", action=\"store_true\", help=\"For RGB images\")\n        self.parser.add_argument(\n            \"--Scratch_and_Quality_restore\", action=\"store_true\", help=\"For scratched images\"\n        )\n        self.parser.add_argument(\"--HR\", action='store_true',help='Large input size with scratches')\n"
  },
  {
    "path": "Global/options/train_options.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom .base_options import BaseOptions\n\nclass TrainOptions(BaseOptions):\n    def initialize(self):\n        BaseOptions.initialize(self)\n        # for displays\n        self.parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen')\n        self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')\n        self.parser.add_argument('--save_latest_freq', type=int, default=10000, help='frequency of saving the latest results')\n        self.parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs')\n        self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')\n        self.parser.add_argument('--debug', action='store_true', help='only do one epoch and displays at each iteration')\n\n        # for training\n        self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')\n        # self.parser.add_argument('--load_pretrain', type=str, default='', help='load the pretrained model from the specified location')\n        self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? 
set to latest to use latest cached model')\n        self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')\n        self.parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')\n        self.parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')\n        self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')\n        self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')\n        self.parser.add_argument('--training_dataset',type=str,default='',help='training use which dataset')\n\n        # for discriminators        \n        self.parser.add_argument('--num_D', type=int, default=2, help='number of discriminators to use')\n        self.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')\n        self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')    \n        self.parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss')\n        self.parser.add_argument('--l2_feat', type=float, help='weight for feature mapping loss') \n        self.parser.add_argument('--use_l1_feat', action='store_true', help='use l1 for feat mapping')               \n        self.parser.add_argument('--no_ganFeat_loss', action='store_true', help='if specified, do *not* use discriminator feature matching loss')\n        self.parser.add_argument('--no_vgg_loss', action='store_true', help='if specified, do *not* use VGG feature matching loss')        \n        self.parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least square GAN, if false, use vanilla GAN')\n        self.parser.add_argument('--gan_type', type=str, default='lsgan', help='Choose the loss type of GAN')\n        
self.parser.add_argument('--pool_size', type=int, default=0, help='the size of image buffer that stores previously generated images')\n        self.parser.add_argument('--norm_D',type=str, default='spectralinstance', help='instance normalization or batch normalization')\n        self.parser.add_argument('--init_D',type=str,default='xavier',help='normal|xavier|xavier_uniform|kaiming|orthogonal|none')\n\n        self.parser.add_argument('--no_TTUR',action='store_true',help='No TTUR')\n\n        self.parser.add_argument('--start_epoch',type=int,default=-1,help='write the start_epoch of iter.txt into this parameter')\n        self.parser.add_argument('--no_degradation',action='store_true',help='when train the mapping, enable this parameter --> no degradation will be added into clean image')\n        self.parser.add_argument('--no_load_VAE',action='store_true',help='when train the mapping, enable this parameter --> random initialize the encoder an decoder')\n        self.parser.add_argument('--use_v2_degradation',action='store_true',help='enable this parameter --> 4 kinds of degradations will be used to synthesize corruption')\n        self.parser.add_argument('--use_vae_which_epoch',type=str,default='200')\n\n\n        self.parser.add_argument('--use_focal_loss',action='store_true')\n\n        self.parser.add_argument('--mask_need_scale',action='store_true',help='enable this param means that the pixel range of mask is 0-255')\n        self.parser.add_argument('--positive_weight',type=float,default=1.0,help='(For scratch detection) Since the scratch number is less, and we use a weight strategy. This parameter means that we want to decrease the weight.')\n\n        self.parser.add_argument('--no_update_lr',action='store_true',help='use this means we do not update the LR while training')\n\n\n        self.isTrain = True\n"
  },
  {
    "path": "Global/test.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport os\nfrom collections import OrderedDict\nfrom torch.autograd import Variable\nfrom options.test_options import TestOptions\nfrom models.models import create_model\nfrom models.mapping_model import Pix2PixHDModel_Mapping\nimport util.util as util\nfrom PIL import Image\nimport torch\nimport torchvision.utils as vutils\nimport torchvision.transforms as transforms\nimport numpy as np\nimport cv2\n\ndef data_transforms(img, method=Image.BILINEAR, scale=False):\n\n    ow, oh = img.size\n    pw, ph = ow, oh\n    if scale == True:\n        if ow < oh:\n            ow = 256\n            oh = ph / pw * 256\n        else:\n            oh = 256\n            ow = pw / ph * 256\n\n    h = int(round(oh / 4) * 4)\n    w = int(round(ow / 4) * 4)\n\n    if (h == ph) and (w == pw):\n        return img\n\n    return img.resize((w, h), method)\n\n\ndef data_transforms_rgb_old(img):\n    w, h = img.size\n    A = img\n    if w < 256 or h < 256:\n        A = transforms.Scale(256, Image.BILINEAR)(img)\n    return transforms.CenterCrop(256)(A)\n\n\ndef irregular_hole_synthesize(img, mask):\n\n    img_np = np.array(img).astype(\"uint8\")\n    mask_np = np.array(mask).astype(\"uint8\")\n    mask_np = mask_np / 255\n    img_new = img_np * (1 - mask_np) + mask_np * 255\n\n    hole_img = Image.fromarray(img_new.astype(\"uint8\")).convert(\"RGB\")\n\n    return hole_img\n\n\ndef parameter_set(opt):\n    ## Default parameters\n    opt.serial_batches = True  # no shuffle\n    opt.no_flip = True  # no flip\n    opt.label_nc = 0\n    opt.n_downsample_global = 3\n    opt.mc = 64\n    opt.k_size = 4\n    opt.start_r = 1\n    opt.mapping_n_block = 6\n    opt.map_mc = 512\n    opt.no_instance = True\n    opt.checkpoints_dir = \"./checkpoints/restoration\"\n    ##\n\n    if opt.Quality_restore:\n        opt.name = \"mapping_quality\"\n        opt.load_pretrainA = os.path.join(opt.checkpoints_dir, 
\"VAE_A_quality\")\n        opt.load_pretrainB = os.path.join(opt.checkpoints_dir, \"VAE_B_quality\")\n    if opt.Scratch_and_Quality_restore:\n        opt.NL_res = True\n        opt.use_SN = True\n        opt.correlation_renormalize = True\n        opt.NL_use_mask = True\n        opt.NL_fusion_method = \"combine\"\n        opt.non_local = \"Setting_42\"\n        opt.name = \"mapping_scratch\"\n        opt.load_pretrainA = os.path.join(opt.checkpoints_dir, \"VAE_A_quality\")\n        opt.load_pretrainB = os.path.join(opt.checkpoints_dir, \"VAE_B_scratch\")\n        if opt.HR:\n            opt.mapping_exp = 1\n            opt.inference_optimize = True\n            opt.mask_dilation = 3\n            opt.name = \"mapping_Patch_Attention\"\n\n\nif __name__ == \"__main__\":\n\n    opt = TestOptions().parse(save=False)\n    parameter_set(opt)\n\n    model = Pix2PixHDModel_Mapping()\n\n    model.initialize(opt)\n    model.eval()\n\n    if not os.path.exists(opt.outputs_dir + \"/\" + \"input_image\"):\n        os.makedirs(opt.outputs_dir + \"/\" + \"input_image\")\n    if not os.path.exists(opt.outputs_dir + \"/\" + \"restored_image\"):\n        os.makedirs(opt.outputs_dir + \"/\" + \"restored_image\")\n    if not os.path.exists(opt.outputs_dir + \"/\" + \"origin\"):\n        os.makedirs(opt.outputs_dir + \"/\" + \"origin\")\n\n    dataset_size = 0\n\n    input_loader = os.listdir(opt.test_input)\n    dataset_size = len(input_loader)\n    input_loader.sort()\n\n    if opt.test_mask != \"\":\n        mask_loader = os.listdir(opt.test_mask)\n        dataset_size = len(os.listdir(opt.test_mask))\n        mask_loader.sort()\n\n    img_transform = transforms.Compose(\n        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]\n    )\n    mask_transform = transforms.ToTensor()\n\n    for i in range(dataset_size):\n\n        input_name = input_loader[i]\n        input_file = os.path.join(opt.test_input, input_name)\n        if not 
os.path.isfile(input_file):\n            print(\"Skipping non-file %s\" % input_name)\n            continue\n        input = Image.open(input_file).convert(\"RGB\")\n\n        print(\"Now you are processing %s\" % (input_name))\n\n        if opt.NL_use_mask:\n            mask_name = mask_loader[i]\n            mask = Image.open(os.path.join(opt.test_mask, mask_name)).convert(\"RGB\")\n            if opt.mask_dilation != 0:\n                kernel = np.ones((3,3),np.uint8)\n                mask = np.array(mask)\n                mask = cv2.dilate(mask,kernel,iterations = opt.mask_dilation)\n                mask = Image.fromarray(mask.astype('uint8'))\n            origin = input\n            input = irregular_hole_synthesize(input, mask)\n            mask = mask_transform(mask)\n            mask = mask[:1, :, :]  ## Convert to single channel\n            mask = mask.unsqueeze(0)\n            input = img_transform(input)\n            input = input.unsqueeze(0)\n        else:\n            if opt.test_mode == \"Scale\":\n                input = data_transforms(input, scale=True)\n            if opt.test_mode == \"Full\":\n                input = data_transforms(input, scale=False)\n            if opt.test_mode == \"Crop\":\n                input = data_transforms_rgb_old(input)\n            origin = input\n            input = img_transform(input)\n            input = input.unsqueeze(0)\n            mask = torch.zeros_like(input)\n        ### Necessary input\n\n        try:\n            with torch.no_grad():\n                generated = model.inference(input, mask)\n        except Exception as ex:\n            print(\"Skip %s due to an error:\\n%s\" % (input_name, str(ex)))\n            continue\n\n        if input_name.endswith(\".jpg\"):\n            input_name = input_name[:-4] + \".png\"\n\n        image_grid = vutils.save_image(\n            (input + 1.0) / 2.0,\n            opt.outputs_dir + \"/input_image/\" + input_name,\n            nrow=1,\n            
padding=0,\n            normalize=True,\n        )\n        image_grid = vutils.save_image(\n            (generated.data.cpu() + 1.0) / 2.0,\n            opt.outputs_dir + \"/restored_image/\" + input_name,\n            nrow=1,\n            padding=0,\n            normalize=True,\n        )\n\n        origin.save(opt.outputs_dir + \"/origin/\" + input_name)"
  },
  {
    "path": "Global/train_domain_A.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport time\nfrom collections import OrderedDict\nfrom options.train_options import TrainOptions\nfrom data.data_loader import CreateDataLoader\nfrom models.models import create_da_model\nimport util.util as util\nfrom util.visualizer import Visualizer\nimport os\nimport numpy as np\nimport torch\nimport torchvision.utils as vutils\nfrom torch.autograd import Variable\n\nopt = TrainOptions().parse()\n\nif opt.debug:\n    opt.display_freq = 1\n    opt.print_freq = 1\n    opt.niter = 1\n    opt.niter_decay = 0\n    opt.max_dataset_size = 10\n\ndata_loader = CreateDataLoader(opt)\ndataset = data_loader.load_data()\ndataset_size = len(dataset) * opt.batchSize\nprint('#training images = %d' % dataset_size)\n\npath = os.path.join(opt.checkpoints_dir, opt.name, 'model.txt')\nvisualizer = Visualizer(opt)\n\niter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')\nif opt.continue_train:\n    try:\n        start_epoch, epoch_iter = np.loadtxt(iter_path, delimiter=',', dtype=int)\n    except:\n        start_epoch, epoch_iter = 1, 0\n    visualizer.print_save('Resuming from epoch %d at iteration %d' % (start_epoch - 1, epoch_iter))\nelse:\n    start_epoch, epoch_iter = 1, 0\n\n# opt.which_epoch=start_epoch-1\nmodel = create_da_model(opt)\nfd = open(path, 'w')\nfd.write(str(model.module.netG))\nfd.write(str(model.module.netD))\nfd.close()\n\ntotal_steps = (start_epoch - 1) * dataset_size + epoch_iter\n\ndisplay_delta = total_steps % opt.display_freq\nprint_delta = total_steps % opt.print_freq\nsave_delta = total_steps % opt.save_latest_freq\n\nfor epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):\n    epoch_start_time = time.time()\n    if epoch != start_epoch:\n        epoch_iter = epoch_iter % dataset_size\n    for i, data in enumerate(dataset, start=epoch_iter):\n        iter_start_time = time.time()\n        total_steps += opt.batchSize\n        epoch_iter += 
opt.batchSize\n\n        # whether to collect output images\n        save_fake = total_steps % opt.display_freq == display_delta\n\n        ############## Forward Pass ######################\n        losses, generated = model(Variable(data['label']), Variable(data['inst']),\n                                  Variable(data['image']), Variable(data['feat']), infer=save_fake)\n\n        # sum per device losses\n        losses = [torch.mean(x) if not isinstance(x, int) else x for x in losses]\n        loss_dict = dict(zip(model.module.loss_names, losses))\n\n        # calculate final loss scalar\n        loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5\n        loss_featD=(loss_dict['featD_fake'] + loss_dict['featD_real']) * 0.5\n        loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat', 0) + loss_dict.get('G_VGG', 0) + loss_dict['G_KL'] + loss_dict['G_featD']\n\n        ############### Backward Pass ####################\n        # update generator weights\n        model.module.optimizer_G.zero_grad()\n        loss_G.backward()\n        model.module.optimizer_G.step()\n\n        # update discriminator weights\n        model.module.optimizer_D.zero_grad()\n        loss_D.backward()\n        model.module.optimizer_D.step()\n\n        model.module.optimizer_featD.zero_grad()\n        loss_featD.backward()\n        model.module.optimizer_featD.step()\n\n        # call([\"nvidia-smi\", \"--format=csv\", \"--query-gpu=memory.used,memory.free\"])\n\n        ############## Display results and errors ##########\n        ### print out errors\n        if total_steps % opt.print_freq == print_delta:\n            errors = {k: v.data if not isinstance(v, int) else v for k, v in loss_dict.items()}\n            t = (time.time() - iter_start_time) / opt.batchSize\n            visualizer.print_current_errors(epoch, epoch_iter, errors, t, model.module.old_lr)\n            visualizer.plot_current_errors(errors, total_steps)\n\n        ### display output images\n        if 
save_fake:\n\n            if not os.path.exists(opt.outputs_dir + opt.name):\n                os.makedirs(opt.outputs_dir + opt.name)\n            imgs_num = data['label'].shape[0]\n            imgs = torch.cat((data['label'], generated.data.cpu(), data['image']), 0)\n\n            imgs = (imgs + 1.) / 2.0\n\n            try:\n                image_grid = vutils.save_image(imgs, opt.outputs_dir + opt.name + '/' + str(epoch) + '_' + str(\n                    total_steps) + '.png',\n                                               nrow=imgs_num, padding=0, normalize=True)\n            except OSError as err:\n                print(err)\n\n\n        if epoch_iter >= dataset_size:\n            break\n\n    # end of epoch\n    iter_end_time = time.time()\n    print('End of epoch %d / %d \\t Time Taken: %d sec' %\n          (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))\n\n    ### save model for this epoch\n    if epoch % opt.save_epoch_freq == 0:\n        print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))\n        model.module.save('latest')\n        model.module.save(epoch)\n        np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')\n\n    ### instead of only training the local enhancer, train the entire network after certain iterations\n    if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):\n        model.module.update_fixed_params()\n\n    ### linearly decay learning rate after certain iterations\n    if epoch > opt.niter:\n        model.module.update_learning_rate()\n\n"
  },
  {
    "path": "Global/train_domain_B.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport time\nfrom collections import OrderedDict\nfrom options.train_options import TrainOptions\nfrom data.data_loader import CreateDataLoader\nfrom models.models import create_model\nimport util.util as util\nfrom util.visualizer import Visualizer\nimport os\nimport numpy as np\nimport torch\nimport torchvision.utils as vutils\nfrom torch.autograd import Variable\nimport random\n\n\nopt = TrainOptions().parse()\n\nif opt.debug:\n    opt.display_freq = 1\n    opt.print_freq = 1\n    opt.niter = 1\n    opt.niter_decay = 0\n    opt.max_dataset_size = 10\n\ndata_loader = CreateDataLoader(opt)\ndataset = data_loader.load_data()\ndataset_size = len(dataset) * opt.batchSize\nprint('#training images = %d' % dataset_size)\n\npath = os.path.join(opt.checkpoints_dir, opt.name, 'model.txt')\nvisualizer = Visualizer(opt)\n\n\niter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')\nif opt.continue_train:\n    try:\n        start_epoch, epoch_iter = np.loadtxt(iter_path , delimiter=',', dtype=int)\n    except:\n        start_epoch, epoch_iter = 1, 0\n    visualizer.print_save('Resuming from epoch %d at iteration %d' % (start_epoch-1, epoch_iter))\nelse:\n    start_epoch, epoch_iter = 1, 0\n\n# opt.which_epoch=start_epoch-1\nmodel = create_model(opt)\nfd = open(path, 'w')\nfd.write(str(model.module.netG))\nfd.write(str(model.module.netD))\nfd.close()\n\ntotal_steps = (start_epoch-1) * dataset_size + epoch_iter\n\ndisplay_delta = total_steps % opt.display_freq\nprint_delta = total_steps % opt.print_freq\nsave_delta = total_steps % opt.save_latest_freq\n\nfor epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):\n    epoch_start_time = time.time()\n    if epoch != start_epoch:\n        epoch_iter = epoch_iter % dataset_size\n    for i, data in enumerate(dataset, start=epoch_iter):\n        iter_start_time = time.time()\n        total_steps += opt.batchSize\n        
epoch_iter += opt.batchSize\n\n        # whether to collect output images\n        save_fake = total_steps % opt.display_freq == display_delta\n\n        ############## Forward Pass ######################\n        losses, generated = model(Variable(data['label']), Variable(data['inst']), \n            Variable(data['image']), Variable(data['feat']), infer=save_fake)\n\n        # sum per device losses\n        losses = [ torch.mean(x) if not isinstance(x, int) else x for x in losses ]\n        loss_dict = dict(zip(model.module.loss_names, losses))\n\n\n        # calculate final loss scalar\n        loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5\n        loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat',0) + loss_dict.get('G_VGG',0) + loss_dict['G_KL'] + loss_dict.get('Smooth_L1',0)*opt.L1_weight\n\n\n        ############### Backward Pass ####################\n        # update generator weights\n        model.module.optimizer_G.zero_grad()\n        loss_G.backward()\n        model.module.optimizer_G.step()\n\n        # update discriminator weights\n        model.module.optimizer_D.zero_grad()\n        loss_D.backward()\n        model.module.optimizer_D.step()\n\n        #call([\"nvidia-smi\", \"--format=csv\", \"--query-gpu=memory.used,memory.free\"]) \n\n        ############## Display results and errors ##########\n        ### print out errors\n        if total_steps % opt.print_freq == print_delta:\n            errors = {k: v.data if not isinstance(v, int) else v for k, v in loss_dict.items()}\n            t = (time.time() - iter_start_time) / opt.batchSize\n            visualizer.print_current_errors(epoch, epoch_iter, errors, t, model.module.old_lr)\n            visualizer.plot_current_errors(errors, total_steps)\n\n        ### display output images\n        if save_fake:\n\n            if not os.path.exists(opt.outputs_dir + opt.name):\n                os.makedirs(opt.outputs_dir + opt.name)\n            imgs_num = 5\n            imgs = 
torch.cat((data['label'][:imgs_num], generated.data.cpu()[:imgs_num], data['image'][:imgs_num]), 0)\n\n            imgs = (imgs + 1.) / 2.0\n\n            try:\n                image_grid = vutils.save_image(imgs, opt.outputs_dir + opt.name + '/' + str(epoch) + '_' + str(total_steps) + '.png',\n                        nrow=imgs_num, padding=0, normalize=True)\n            except OSError as err:\n                print(err)\n\n        if epoch_iter >= dataset_size:\n            break\n       \n    # end of epoch \n    iter_end_time = time.time()\n    print('End of epoch %d / %d \\t Time Taken: %d sec' %\n          (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))\n\n    ### save model for this epoch\n    if epoch % opt.save_epoch_freq == 0:\n        print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))        \n        model.module.save('latest')\n        model.module.save(epoch)\n        np.savetxt(iter_path, (epoch+1, 0), delimiter=',', fmt='%d')\n\n    ### instead of only training the local enhancer, train the entire network after certain iterations\n    if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):\n        model.module.update_fixed_params()\n\n    ### linearly decay learning rate after certain iterations\n    if epoch > opt.niter:\n        model.module.update_learning_rate()\n\n"
  },
  {
    "path": "Global/train_mapping.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport time\nfrom collections import OrderedDict\nfrom options.train_options import TrainOptions\nfrom data.data_loader import CreateDataLoader\nfrom models.mapping_model import Pix2PixHDModel_Mapping\nimport util.util as util\nfrom util.visualizer import Visualizer\nimport os\nimport numpy as np\nimport torch\nimport torchvision.utils as vutils\nfrom torch.autograd import Variable\nimport datetime\nimport random\n\n\n\nopt = TrainOptions().parse()\nvisualizer = Visualizer(opt)\niter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')\nif opt.continue_train:\n    try:\n        start_epoch, epoch_iter = np.loadtxt(iter_path , delimiter=',', dtype=int)\n    except:\n        start_epoch, epoch_iter = 1, 0\n    visualizer.print_save('Resuming from epoch %d at iteration %d' % (start_epoch-1, epoch_iter))\nelse:\n    start_epoch, epoch_iter = 1, 0\n\nif opt.which_epoch != \"latest\":\n    start_epoch=int(opt.which_epoch)\n    visualizer.print_save('Notice : Resuming from epoch %d at iteration %d' % (start_epoch - 1, epoch_iter))\n\nopt.start_epoch=start_epoch\n### temp for continue train unfixed decoder\n\ndata_loader = CreateDataLoader(opt)\ndataset = data_loader.load_data()\ndataset_size = len(dataset) * opt.batchSize\nprint('#training images = %d' % dataset_size)\n\n\nmodel = Pix2PixHDModel_Mapping()\nmodel.initialize(opt)\n\npath = os.path.join(opt.checkpoints_dir, opt.name, 'model.txt')\nfd = open(path, 'w')\n\nif opt.use_skip_model:\n    fd.write(str(model.mapping_net))\n    fd.close()\nelse:\n    fd.write(str(model.netG_A))\n    fd.write(str(model.mapping_net))\n    fd.close()\n\nif opt.isTrain and len(opt.gpu_ids) > 1:\n    model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)\n\n\n\ntotal_steps = (start_epoch-1) * dataset_size + epoch_iter\n\ndisplay_delta = total_steps % opt.display_freq\nprint_delta = total_steps % opt.print_freq\nsave_delta = 
total_steps % opt.save_latest_freq\n### used for recovering training\n\nfor epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):\n    epoch_s_t=datetime.datetime.now()\n    epoch_start_time = time.time()\n    if epoch != start_epoch:\n        epoch_iter = epoch_iter % dataset_size\n    for i, data in enumerate(dataset, start=epoch_iter):\n        iter_start_time = time.time()\n        total_steps += opt.batchSize\n        epoch_iter += opt.batchSize\n\n        # whether to collect output images\n        save_fake = total_steps % opt.display_freq == display_delta\n\n        ############## Forward Pass ######################\n        #print(pair)\n        losses, generated = model(Variable(data['label']), Variable(data['inst']), \n            Variable(data['image']), Variable(data['feat']), infer=save_fake)\n        \n        # sum per device losses\n        losses = [ torch.mean(x) if not isinstance(x, int) else x for x in losses ]\n        loss_dict = dict(zip(model.module.loss_names, losses))\n\n        # calculate final loss scalar\n        loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5\n        loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat',0) + loss_dict.get('G_VGG',0) + loss_dict.get('G_Feat_L2', 0) +loss_dict.get('Smooth_L1', 0)+loss_dict.get('G_Feat_L2_Stage_1',0)\n        #loss_G = loss_dict['G_Feat_L2'] \n\n        ############### Backward Pass ####################\n        # update generator weights\n        model.module.optimizer_mapping.zero_grad()\n        loss_G.backward()\n        model.module.optimizer_mapping.step()\n\n        # update discriminator weights\n        model.module.optimizer_D.zero_grad()\n        loss_D.backward()\n        model.module.optimizer_D.step()\n\n        ############## Display results and errors ##########\n        ### print out errors\n        if i == 0 or total_steps % opt.print_freq == print_delta:\n            errors = {k: v.data if not isinstance(v, int) else v for k, v in 
loss_dict.items()}\n            t = (time.time() - iter_start_time) / opt.batchSize\n            visualizer.print_current_errors(epoch, epoch_iter, errors, t,model.module.old_lr)\n            visualizer.plot_current_errors(errors, total_steps)\n\n        ### display output images\n        if save_fake:\n\n            if not os.path.exists(opt.outputs_dir + opt.name):\n                os.makedirs(opt.outputs_dir + opt.name)\n\n            imgs_num = 5\n            if opt.NL_use_mask:\n                mask=data['inst'][:imgs_num]\n                mask=mask.repeat(1,3,1,1)\n                imgs = torch.cat((data['label'][:imgs_num], mask,generated.data.cpu()[:imgs_num], data['image'][:imgs_num]), 0)\n            else:\n                imgs = torch.cat((data['label'][:imgs_num], generated.data.cpu()[:imgs_num], data['image'][:imgs_num]), 0)\n\n            imgs=(imgs+1.)/2.0   ## de-normalize\n\n            try:\n                image_grid = vutils.save_image(imgs, opt.outputs_dir + opt.name + '/' + str(epoch) + '_' + str(total_steps) + '.png',\n                        nrow=imgs_num, padding=0, normalize=True)\n            except OSError as err:\n                print(err)\n\n        if epoch_iter >= dataset_size:\n            break\n       \n    # end of epoch\n    epoch_e_t=datetime.datetime.now()\n    iter_end_time = time.time()\n    print('End of epoch %d / %d \\t Time Taken: %s' %\n          (epoch, opt.niter + opt.niter_decay, str(epoch_e_t-epoch_s_t)))\n\n    ### save model for this epoch\n    if epoch % opt.save_epoch_freq == 0:\n        print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))        \n        model.module.save('latest')\n        model.module.save(epoch)\n        np.savetxt(iter_path, (epoch+1, 0), delimiter=',', fmt='%d')\n\n    ### instead of only training the local enhancer, train the entire network after certain iterations\n    if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):\n        
model.module.update_fixed_params()\n\n    ### linearly decay learning rate after certain iterations\n    if epoch > opt.niter:\n        model.module.update_learning_rate()"
  },
  {
    "path": "Global/util/__init__.py",
    "content": ""
  },
  {
    "path": "Global/util/image_pool.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport random\nimport torch\nfrom torch.autograd import Variable\n\n\nclass ImagePool:\n    def __init__(self, pool_size):\n        self.pool_size = pool_size\n        if self.pool_size > 0:\n            self.num_imgs = 0\n            self.images = []\n\n    def query(self, images):\n        if self.pool_size == 0:\n            return images\n        return_images = []\n        for image in images.data:\n            image = torch.unsqueeze(image, 0)\n            if self.num_imgs < self.pool_size:\n                self.num_imgs = self.num_imgs + 1\n                self.images.append(image)\n                return_images.append(image)\n            else:\n                p = random.uniform(0, 1)\n                if p > 0.5:\n                    random_id = random.randint(0, self.pool_size - 1)\n                    tmp = self.images[random_id].clone()\n                    self.images[random_id] = image\n                    return_images.append(tmp)\n                else:\n                    return_images.append(image)\n        return_images = Variable(torch.cat(return_images, 0))\n        return return_images\n"
  },
  {
    "path": "Global/util/util.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nfrom __future__ import print_function\nimport torch\nimport numpy as np\nfrom PIL import Image\nimport numpy as np\nimport os\nimport torch.nn as nn\n\n# Converts a Tensor into a Numpy array\n# |imtype|: the desired type of the converted numpy array\ndef tensor2im(image_tensor, imtype=np.uint8, normalize=True):\n    if isinstance(image_tensor, list):\n        image_numpy = []\n        for i in range(len(image_tensor)):\n            image_numpy.append(tensor2im(image_tensor[i], imtype, normalize))\n        return image_numpy\n    image_numpy = image_tensor.cpu().float().numpy()\n    if normalize:\n        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0\n    else:\n        image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0\n    image_numpy = np.clip(image_numpy, 0, 255)\n    if image_numpy.shape[2] == 1 or image_numpy.shape[2] > 3:\n        image_numpy = image_numpy[:, :, 0]\n    return image_numpy.astype(imtype)\n\n\n# Converts a one-hot tensor into a colorful label map\ndef tensor2label(label_tensor, n_label, imtype=np.uint8):\n    if n_label == 0:\n        return tensor2im(label_tensor, imtype)\n    label_tensor = label_tensor.cpu().float()\n    if label_tensor.size()[0] > 1:\n        label_tensor = label_tensor.max(0, keepdim=True)[1]\n    label_tensor = Colorize(n_label)(label_tensor)\n    label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0))\n    return label_numpy.astype(imtype)\n\n\ndef save_image(image_numpy, image_path):\n    image_pil = Image.fromarray(image_numpy)\n    image_pil.save(image_path)\n\n\ndef mkdirs(paths):\n    if isinstance(paths, list) and not isinstance(paths, str):\n        for path in paths:\n            mkdir(path)\n    else:\n        mkdir(paths)\n\n\ndef mkdir(path):\n    if not os.path.exists(path):\n        os.makedirs(path)\n"
  },
  {
    "path": "Global/util/visualizer.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport numpy as np\nimport os\nimport ntpath\nimport time\nfrom . import util\n#from . import html\nimport scipy.misc\ntry:\n    from StringIO import StringIO  # Python 2.7\nexcept ImportError:\n    from io import BytesIO         # Python 3.x\n\nclass Visualizer():\n    def __init__(self, opt):\n        # self.opt = opt\n        self.tf_log = opt.tf_log\n        self.use_html = opt.isTrain and not opt.no_html\n        self.win_size = opt.display_winsize\n        self.name = opt.name\n        if self.tf_log:\n            import tensorflow as tf\n            self.tf = tf\n            self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, 'logs')\n            self.writer = tf.summary.FileWriter(self.log_dir)\n\n        if self.use_html:\n            self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')\n            self.img_dir = os.path.join(self.web_dir, 'images')\n            print('create web directory %s...' 
% self.web_dir)\n            util.mkdirs([self.web_dir, self.img_dir])\n        self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')\n        with open(self.log_name, \"a\") as log_file:\n            now = time.strftime(\"%c\")\n            log_file.write('================ Training Loss (%s) ================\\n' % now)\n\n    # |visuals|: dictionary of images to display or save\n    def display_current_results(self, visuals, epoch, step):\n        if self.tf_log: # show images in tensorboard output\n            img_summaries = []\n            for label, image_numpy in visuals.items():\n                # Write the image to a string\n                try:\n                    s = StringIO()\n                except:\n                    s = BytesIO()\n                scipy.misc.toimage(image_numpy).save(s, format=\"jpeg\")\n                # Create an Image object\n                img_sum = self.tf.Summary.Image(encoded_image_string=s.getvalue(), height=image_numpy.shape[0], width=image_numpy.shape[1])\n                # Create a Summary value\n                img_summaries.append(self.tf.Summary.Value(tag=label, image=img_sum))\n\n            # Create and write Summary\n            summary = self.tf.Summary(value=img_summaries)\n            self.writer.add_summary(summary, step)\n\n        if self.use_html: # save images to a html file\n            for label, image_numpy in visuals.items():\n                if isinstance(image_numpy, list):\n                    for i in range(len(image_numpy)):\n                        img_path = os.path.join(self.img_dir, 'epoch%.3d_%s_%d.jpg' % (epoch, label, i))\n                        util.save_image(image_numpy[i], img_path)\n                else:\n                    img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.jpg' % (epoch, label))\n                    util.save_image(image_numpy, img_path)\n\n            # update website\n            webpage = html.HTML(self.web_dir, 'Experiment name = %s' % 
self.name, refresh=30)\n            for n in range(epoch, 0, -1):\n                webpage.add_header('epoch [%d]' % n)\n                ims = []\n                txts = []\n                links = []\n\n                for label, image_numpy in visuals.items():\n                    if isinstance(image_numpy, list):\n                        for i in range(len(image_numpy)):\n                            img_path = 'epoch%.3d_%s_%d.jpg' % (n, label, i)\n                            ims.append(img_path)\n                            txts.append(label+str(i))\n                            links.append(img_path)\n                    else:\n                        img_path = 'epoch%.3d_%s.jpg' % (n, label)\n                        ims.append(img_path)\n                        txts.append(label)\n                        links.append(img_path)\n                if len(ims) < 10:\n                    webpage.add_images(ims, txts, links, width=self.win_size)\n                else:\n                    num = int(round(len(ims)/2.0))\n                    webpage.add_images(ims[:num], txts[:num], links[:num], width=self.win_size)\n                    webpage.add_images(ims[num:], txts[num:], links[num:], width=self.win_size)\n            webpage.save()\n\n    # errors: dictionary of error labels and values\n    def plot_current_errors(self, errors, step):\n        if self.tf_log:\n            for tag, value in errors.items():\n                summary = self.tf.Summary(value=[self.tf.Summary.Value(tag=tag, simple_value=value)])\n                self.writer.add_summary(summary, step)\n\n    # errors: same format as |errors| of plotCurrentErrors\n    def print_current_errors(self, epoch, i, errors, t, lr):\n        message = '(epoch: %d, iters: %d, time: %.3f lr: %.5f) ' % (epoch, i, t, lr)\n        for k, v in errors.items():\n            if v != 0:\n                message += '%s: %.3f ' % (k, v)\n\n        print(message)\n        with open(self.log_name, \"a\") as log_file:\n       
     log_file.write('%s\\n' % message)\n\n\n    def print_save(self,message):\n\n        print(message)\n\n        with open(self.log_name,\"a\") as log_file:\n            log_file.write('%s\\n'%message)\n\n\n    # save image to the disk\n    def save_images(self, webpage, visuals, image_path):\n        image_dir = webpage.get_image_dir()\n        short_path = ntpath.basename(image_path[0])\n        name = os.path.splitext(short_path)[0]\n\n        webpage.add_header(name)\n        ims = []\n        txts = []\n        links = []\n\n        for label, image_numpy in visuals.items():\n            image_name = '%s_%s.jpg' % (name, label)\n            save_path = os.path.join(image_dir, image_name)\n            util.save_image(image_numpy, save_path)\n\n            ims.append(image_name)\n            txts.append(label)\n            links.append(image_name)\n        webpage.add_images(ims, txts, links, width=self.win_size)\n"
  },
  {
    "path": "LICENSE",
    "content": "    MIT License\n\n    Copyright (c) Microsoft Corporation.\n\n    Permission is hereby granted, free of charge, to any person obtaining a copy\n    of this software and associated documentation files (the \"Software\"), to deal\n    in the Software without restriction, including without limitation the rights\n    to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n    copies of the Software, and to permit persons to whom the Software is\n    furnished to do so, subject to the following conditions:\n\n    The above copyright notice and this permission notice shall be included in all\n    copies or substantial portions of the Software.\n\n    THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n    OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n    SOFTWARE\n"
  },
  {
    "path": "README.md",
    "content": "# Old Photo Restoration (Official PyTorch Implementation)\n\n<img src='imgs/0001.jpg'/>\n\n### [Project Page](http://raywzy.com/Old_Photo/) | [Paper (CVPR version)](https://arxiv.org/abs/2004.09484) | [Paper (Journal version)](https://arxiv.org/pdf/2009.07047v1.pdf) | [Pretrained Model](https://hkustconnect-my.sharepoint.com/:f:/g/personal/bzhangai_connect_ust_hk/Em0KnYOeSSxFtp4g_dhWdf0BdeT3tY12jIYJ6qvSf300cA?e=nXkJH2) | [Colab Demo](https://colab.research.google.com/drive/1NEm6AsybIiC5TwTU_4DqDkQO0nFRB-uA?usp=sharing)  | [Replicate Demo & Docker Image](https://replicate.ai/zhangmozhe/bringing-old-photos-back-to-life) :fire:\n\n**Bringing Old Photos Back to Life, CVPR2020 (Oral)**\n\n**Old Photo Restoration via Deep Latent Space Translation, TPAMI 2022**\n\n[Ziyu Wan](http://raywzy.com/)<sup>1</sup>,\n[Bo Zhang](https://www.microsoft.com/en-us/research/people/zhanbo/)<sup>2</sup>,\n[Dongdong Chen](http://www.dongdongchen.bid/)<sup>3</sup>,\n[Pan Zhang](https://panzhang0212.github.io/)<sup>4</sup>,\n[Dong Chen](https://www.microsoft.com/en-us/research/people/doch/)<sup>2</sup>,\n[Jing Liao](https://liaojing.github.io/html/)<sup>1</sup>,\n[Fang Wen](https://www.microsoft.com/en-us/research/people/fangwen/)<sup>2</sup> <br>\n<sup>1</sup>City University of Hong Kong, <sup>2</sup>Microsoft Research Asia, <sup>3</sup>Microsoft Cloud AI, <sup>4</sup>USTC\n\n<!-- ## Notes of this project\nThe code originates from our research project and the aim is to demonstrate the research idea, so we have not optimized it from a product perspective. And we will spend time to address some common issues, such as out of memory issue, limited resolution, but will not involve too much in engineering problems, such as speedup of the inference, fastapi deployment and so on. 
**We welcome volunteers to contribute to this project to make it more usable for practical application.** -->\n\n## :sparkles: News\n**2022.3.31**: Our new work regarding old film restoration will be published in CVPR 2022. For more details, please refer to the [project website](http://raywzy.com/Old_Film/) and [github repo](https://github.com/raywzy/Bringing-Old-Films-Back-to-Life).\n\nThe framework now supports the restoration of high-resolution input.\n\n<img src='imgs/HR_result.png'>\n\nThe training code is available; you are welcome to try it and learn the training details. \n\nYou can now play with our [Colab](https://colab.research.google.com/drive/1NEm6AsybIiC5TwTU_4DqDkQO0nFRB-uA?usp=sharing) and try it on your photos. \n\n## Requirement\nThe code is tested on Ubuntu with Nvidia GPUs and CUDA installed. Python>=3.6 is required to run the code.\n\n## Installation\n\nClone the Synchronized-BatchNorm-PyTorch repository and copy its `sync_batchnorm` folder into both of the following locations:\n\n```\ncd Face_Enhancement/models/networks/\ngit clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch\ncp -rf Synchronized-BatchNorm-PyTorch/sync_batchnorm .\ncd ../../../\n```\n\n```\ncd Global/detection_models\ngit clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch\ncp -rf Synchronized-BatchNorm-PyTorch/sync_batchnorm .\ncd ../../\n```\n\nDownload the landmark detection pretrained model\n\n```\ncd Face_Detection/\nwget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2\nbzip2 -d shape_predictor_68_face_landmarks.dat.bz2\ncd ../\n```\n\nDownload the pretrained model, put the file `Face_Enhancement/checkpoints.zip` under `./Face_Enhancement`, and put the file `Global/checkpoints.zip` under `./Global`. 
Then unzip them respectively.\n\n```\ncd Face_Enhancement/\nwget https://github.com/microsoft/Bringing-Old-Photos-Back-to-Life/releases/download/v1.0/face_checkpoints.zip\nunzip face_checkpoints.zip\ncd ../\ncd Global/\nwget https://github.com/microsoft/Bringing-Old-Photos-Back-to-Life/releases/download/v1.0/global_checkpoints.zip\nunzip global_checkpoints.zip\ncd ../\n```\n\nInstall dependencies:\n\n```bash\npip install -r requirements.txt\n```\n\n## :rocket: How to use?\n\n**Note**: GPU can be set 0 or 0,1,2 or 0,2; use -1 for CPU\n\n### 1) Full Pipeline\n\nYou could easily restore the old photos with one simple command after installation and downloading the pretrained model.\n\nFor images without scratches:\n\n```\npython run.py --input_folder [test_image_folder_path] \\\n              --output_folder [output_path] \\\n              --GPU 0\n```\n\nFor scratched images:\n\n```\npython run.py --input_folder [test_image_folder_path] \\\n              --output_folder [output_path] \\\n              --GPU 0 \\\n              --with_scratch\n```\n\n**For high-resolution images with scratches**:\n\n```\npython run.py --input_folder [test_image_folder_path] \\\n              --output_folder [output_path] \\\n              --GPU 0 \\\n              --with_scratch \\\n              --HR\n```\n\nNote: Please try to use the absolute path. The final results will be saved in `./output_path/final_output/`. You could also check the produced results of different steps in `output_path`.\n\n### 2) Scratch Detection\n\nCurrently we don't plan to release the scratched old photos dataset with labels directly. 
If you want to get the paired data, you could use our pretrained model to test the collected images to obtain the labels.\n\n```\ncd Global/\npython detection.py --test_path [test_image_folder_path] \\\n                    --output_dir [output_path] \\\n                    --input_size [resize_256|full_size|scale_256]\n```\n\n<img src='imgs/scratch_detection.png'>\n\n### 3) Global Restoration\n\nA triplet domain translation network is proposed to solve both structured degradation and unstructured degradation of old photos.\n\n<p align=\"center\">\n<img src='imgs/pipeline.PNG' width=\"50%\" height=\"50%\"/>\n</p>\n\n```\ncd Global/\npython test.py --Scratch_and_Quality_restore \\\n               --test_input [test_image_folder_path] \\\n               --test_mask [corresponding mask] \\\n               --outputs_dir [output_path]\n\npython test.py --Quality_restore \\\n               --test_input [test_image_folder_path] \\\n               --outputs_dir [output_path]\n```\n\n<img src='imgs/global.png'>\n\n\n### 4) Face Enhancement\n\nWe use a progressive generator to refine the face regions of old photos. More details could be found in our journal submission and `./Face_Enhancement` folder.\n\n<p align=\"center\">\n<img src='imgs/face_pipeline.jpg' width=\"60%\" height=\"60%\"/>\n</p>\n\n\n<img src='imgs/face.png'>\n\n> *NOTE*: \n> This repo is mainly for research purpose and we have not yet optimized the running performance. \n> \n> Since the model is pretrained with 256*256 images, the model may not work ideally for arbitrary resolution.\n\n### 5) GUI\n\nA user-friendly GUI which takes input of image by user and shows result in respective window.\n\n#### How it works:\n\n1. Run GUI.py file.\n2. Click browse and select your image from test_images/old_w_scratch folder to remove scratches.\n3. Click Modify Photo button.\n4. Wait for a while and see results on GUI window.\n5. 
Exit window by clicking Exit Window and get your result image in output folder.\n\n<img src='imgs/gui.PNG'>\n\n## How to train?\n\n### 1) Create Training File\n\nPut the folders of VOC dataset, collected old photos (e.g., Real_L_old and Real_RGB_old) into one shared folder. Then\n```\ncd Global/data/\npython Create_Bigfile.py\n```\nNote: Remember to modify the code based on your own environment.\n\n### 2) Train the VAEs of domain A and domain B respectively\n\n```\ncd ..\npython train_domain_A.py --use_v2_degradation --continue_train --training_dataset domain_A --name domainA_SR_old_photos --label_nc 0 --loadSize 256 --fineSize 256 --dataroot [your_data_folder] --no_instance --resize_or_crop crop_only --batchSize 100 --no_html --gpu_ids 0,1,2,3 --self_gen --nThreads 4 --n_downsample_global 3 --k_size 4 --use_v2 --mc 64 --start_r 1 --kl 1 --no_cgan --outputs_dir [your_output_folder] --checkpoints_dir [your_ckpt_folder]\n\npython train_domain_B.py --continue_train --training_dataset domain_B --name domainB_old_photos --label_nc 0 --loadSize 256 --fineSize 256 --dataroot [your_data_folder]  --no_instance --resize_or_crop crop_only --batchSize 120 --no_html --gpu_ids 0,1,2,3 --self_gen --nThreads 4 --n_downsample_global 3 --k_size 4 --use_v2 --mc 64 --start_r 1 --kl 1 --no_cgan --outputs_dir [your_output_folder]  --checkpoints_dir [your_ckpt_folder]\n```\nNote: For the --name option, please ensure your experiment name contains \"domainA\" or \"domainB\", which will be used to select different dataset.\n\n### 3) Train the mapping network between domains\n\nTrain the mapping without scratches:\n```\npython train_mapping.py --use_v2_degradation --training_dataset mapping --use_vae_which_epoch 200 --continue_train --name mapping_quality --label_nc 0 --loadSize 256 --fineSize 256 --dataroot [your_data_folder] --no_instance --resize_or_crop crop_only --batchSize 80 --no_html --gpu_ids 0,1,2,3 --nThreads 8 --load_pretrainA [ckpt_of_domainA_SR_old_photos] --load_pretrainB 
[ckpt_of_domainB_old_photos] --l2_feat 60 --n_downsample_global 3 --mc 64 --k_size 4 --start_r 1 --mapping_n_block 6 --map_mc 512 --use_l1_feat --niter 150 --niter_decay 100 --outputs_dir [your_output_folder] --checkpoints_dir [your_ckpt_folder]\n```\n\n\nTrain the mapping with scratches:\n```\npython train_mapping.py --no_TTUR --NL_res --random_hole --use_SN --correlation_renormalize --training_dataset mapping --NL_use_mask --NL_fusion_method combine --non_local Setting_42 --use_v2_degradation --use_vae_which_epoch 200 --continue_train --name mapping_scratch --label_nc 0 --loadSize 256 --fineSize 256 --dataroot [your_data_folder] --no_instance --resize_or_crop crop_only --batchSize 36 --no_html --gpu_ids 0,1,2,3 --nThreads 8 --load_pretrainA [ckpt_of_domainA_SR_old_photos] --load_pretrainB [ckpt_of_domainB_old_photos] --l2_feat 60 --n_downsample_global 3 --mc 64 --k_size 4 --start_r 1 --mapping_n_block 6 --map_mc 512 --use_l1_feat --niter 150 --niter_decay 100 --outputs_dir [your_output_folder] --checkpoints_dir [your_ckpt_folder] --irregular_mask [absolute_path_of_mask_file]\n```\n\nTrain the mapping with scratches (Multi-Scale Patch Attention for HR input):\n```\npython train_mapping.py --no_TTUR --NL_res --random_hole --use_SN --correlation_renormalize --training_dataset mapping --NL_use_mask --NL_fusion_method combine --non_local Setting_42 --use_v2_degradation --use_vae_which_epoch 200 --continue_train --name mapping_Patch_Attention --label_nc 0 --loadSize 256 --fineSize 256 --dataroot [your_data_folder] --no_instance --resize_or_crop crop_only --batchSize 36 --no_html --gpu_ids 0,1,2,3 --nThreads 8 --load_pretrainA [ckpt_of_domainA_SR_old_photos] --load_pretrainB [ckpt_of_domainB_old_photos] --l2_feat 60 --n_downsample_global 3 --mc 64 --k_size 4 --start_r 1 --mapping_n_block 6 --map_mc 512 --use_l1_feat --niter 150 --niter_decay 100 --outputs_dir [your_output_folder] --checkpoints_dir [your_ckpt_folder] --irregular_mask [absolute_path_of_mask_file] 
--mapping_exp 1\n```\n\n\n## Citation\n\nIf you find our work useful for your research, please consider citing the following papers :)\n\n```bibtex\n@inproceedings{wan2020bringing,\ntitle={Bringing Old Photos Back to Life},\nauthor={Wan, Ziyu and Zhang, Bo and Chen, Dongdong and Zhang, Pan and Chen, Dong and Liao, Jing and Wen, Fang},\nbooktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},\npages={2747--2757},\nyear={2020}\n}\n```\n\n```bibtex\n@article{wan2020old,\n  title={Old Photo Restoration via Deep Latent Space Translation},\n  author={Wan, Ziyu and Zhang, Bo and Chen, Dongdong and Zhang, Pan and Chen, Dong and Liao, Jing and Wen, Fang},\n  journal={arXiv preprint arXiv:2009.07047},\n  year={2020}\n}\n```\n\nIf you are also interested in the legacy photo/video colorization, please refer to [this work](https://github.com/zhangmozhe/video-colorization).\n\n## Maintenance\n\nThis project is currently maintained by Ziyu Wan and is for academic research use only. If you have any questions, feel free to contact raywzy@gmail.com.\n\n## License\n\nThe codes and the pretrained model in this repository are under the MIT license as specified by the LICENSE file. We use our labeled dataset to train the scratch detection model.\n\nThis project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.\n"
  },
  {
    "path": "SECURITY.md",
    "content": "<!-- BEGIN MICROSOFT SECURITY.MD V0.0.5 BLOCK -->\n\n## Security\n\nMicrosoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).\n\nIf you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below.\n\n## Reporting Security Issues\n\n**Please do not report security vulnerabilities through public GitHub issues.**\n\nInstead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report).\n\nIf you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com).  If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc).\n\nYou should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). \n\nPlease include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:\n\n  * Type of issue (e.g. 
buffer overflow, SQL injection, cross-site scripting, etc.)\n  * Full paths of source file(s) related to the manifestation of the issue\n  * The location of the affected source code (tag/branch/commit or direct URL)\n  * Any special configuration required to reproduce the issue\n  * Step-by-step instructions to reproduce the issue\n  * Proof-of-concept or exploit code (if possible)\n  * Impact of the issue, including how an attacker might exploit the issue\n\nThis information will help us triage your report more quickly.\n\nIf you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs.\n\n## Preferred Languages\n\nWe prefer all communications to be in English.\n\n## Policy\n\nMicrosoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).\n\n<!-- END MICROSOFT SECURITY.MD BLOCK -->"
  },
  {
    "path": "ansible.yaml",
    "content": "---\r\n- name: Bringing-Old-Photos-Back-to-Life\r\n  hosts: all\r\n  gather_facts: no\r\n\r\n# Succesfully tested on Ubuntu 18.04\\20.04 and Debian 10 \r\n\r\n  pre_tasks: \r\n  - name: install packages\r\n    package:\r\n      name:\r\n        - python3\r\n        - python3-pip\r\n        - python3-venv\r\n        - git\r\n        - unzip\r\n        - tar\r\n        - lbzip2\r\n        - build-essential\r\n        - cmake\r\n        - ffmpeg\r\n        - libsm6\r\n        - libxext6\r\n        - libgl1-mesa-glx\r\n      state: latest\r\n    become: yes\r\n\r\n  tasks:\r\n  - name: git clone repo\r\n    git:\r\n      repo: 'https://github.com/microsoft/Bringing-Old-Photos-Back-to-Life.git'\r\n      dest: Bringing-Old-Photos-Back-to-Life\r\n      clone: yes\r\n\r\n  - name: requirements setup\r\n    pip:\r\n      requirements: \"~/Bringing-Old-Photos-Back-to-Life/requirements.txt\"\r\n      virtualenv: \"~/Bringing-Old-Photos-Back-to-Life/.venv\"\r\n      virtualenv_command: /usr/bin/python3 -m venv .venv\r\n\r\n  - name: additional pip packages #requirements lack some packs\r\n    pip:\r\n      name: \r\n          - setuptools \r\n          - wheel\r\n          - scikit-build\r\n      virtualenv: \"~/Bringing-Old-Photos-Back-to-Life/.venv\"\r\n      virtualenv_command: /usr/bin/python3 -m venv .venv\r\n\r\n  - name: git clone batchnorm-pytorch\r\n    git:\r\n      repo: 'https://github.com/vacancy/Synchronized-BatchNorm-PyTorch'\r\n      dest: Synchronized-BatchNorm-PyTorch\r\n      clone: yes\r\n\r\n  - name: copy sync_batchnorm to face_enhancement\r\n    copy:\r\n      src: Synchronized-BatchNorm-PyTorch/sync_batchnorm\r\n      dest: Bringing-Old-Photos-Back-to-Life/Face_Enhancement/models/networks/\r\n      remote_src: yes\r\n\r\n  - name: copy sync_batchnorm to global\r\n    copy:\r\n      src: Synchronized-BatchNorm-PyTorch/sync_batchnorm\r\n      dest: Bringing-Old-Photos-Back-to-Life/Global/detection_models\r\n      remote_src: yes\r\n\r\n  - 
name: check if shape_predictor_68_face_landmarks.dat\r\n    stat:\r\n      path: Bringing-Old-Photos-Back-to-Life/Face_Detection/shape_predictor_68_face_landmarks.dat\r\n    register: p\r\n\r\n  - name: get shape_predictor_68_face_landmarks.dat.bz2\r\n    get_url:\r\n      url: http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2\r\n      dest: Bringing-Old-Photos-Back-to-Life/Face_Detection/\r\n    when: p.stat.exists == False\r\n\r\n  - name: unarchive shape_predictor_68_face_landmarks.dat.bz2\r\n    shell: 'bzip2 -d Bringing-Old-Photos-Back-to-Life/Face_Detection/shape_predictor_68_face_landmarks.dat.bz2'\r\n    when: p.stat.exists == False\r\n\r\n  - name: check if face_enhancement\r\n    stat:\r\n      path: Bringing-Old-Photos-Back-to-Life/Face_Enhancement/checkpoints/Setting_9_epoch_100/latest_net_G.pth\r\n    register: fc\r\n\r\n  - name: unarchive Face_Enhancement/checkpoints.zip\r\n    unarchive:\r\n      src: https://facevc.blob.core.windows.net/zhanbo/old_photo/pretrain/Face_Enhancement/checkpoints.zip\r\n      dest: Bringing-Old-Photos-Back-to-Life/Face_Enhancement/\r\n      remote_src: yes\r\n    when: fc.stat.exists == False\r\n\r\n  - name: check if global\r\n    stat:\r\n      path: Bringing-Old-Photos-Back-to-Life/Global/checkpoints/detection/FT_Epoch_latest.pt\r\n    register: gc\r\n\r\n  - name: unarchive Global/checkpoints.zip\r\n    unarchive:\r\n      src: https://facevc.blob.core.windows.net/zhanbo/old_photo/pretrain/Global/checkpoints.zip\r\n      dest: Bringing-Old-Photos-Back-to-Life/Global/\r\n      remote_src: yes\r\n    when: gc.stat.exists == False\r\n\r\n# Do not forget to execute 'source .venv/bin/activate' inside Bringing-Old-Photos-Back-to-Life before starting run.py"
  },
  {
    "path": "cog.yaml",
    "content": "build:\n  gpu: true\n  python_version: \"3.8\"\n  system_packages:\n    - \"libgl1-mesa-glx\"\n    - \"libglib2.0-0\"\n  python_packages:\n    - \"cmake==3.21.2\"\n    - \"torchvision==0.9.0\"\n    - \"torch==1.8.0\"\n    - \"numpy==1.19.4\"\n    - \"opencv-python==4.4.0.46\"\n    - \"scipy==1.5.3\"\n    - \"tensorboardX==2.4\"\n    - \"dominate==2.6.0\"\n    - \"easydict==1.9\"\n    - \"PyYAML==5.3.1\"\n    - \"scikit-image==0.18.3\"\n    - \"dill==0.3.4\"\n    - \"einops==0.3.0\"\n    - \"PySimpleGUI==4.46.0\"\n    - \"ipython==7.19.0\"\n  run:\n    - pip install dlib\n\npredict: \"predict.py:Predictor\"\n"
  },
  {
    "path": "download-weights",
    "content": "#!/bin/sh\n\ncd Face_Enhancement/models/networks\ngit clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch\ncp -rf Synchronized-BatchNorm-PyTorch/sync_batchnorm .\ncd ../../../\n\ncd Global/detection_models\ngit clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch\ncp -rf Synchronized-BatchNorm-PyTorch/sync_batchnorm .\ncd ../../\n\n# download the landmark detection model\ncd Face_Detection/\nwget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2\nbzip2 -d shape_predictor_68_face_landmarks.dat.bz2\ncd ../\n\n# download the pretrained model\ncd Face_Enhancement/\nwget https://facevc.blob.core.windows.net/zhanbo/old_photo/pretrain/Face_Enhancement/checkpoints.zip\nunzip checkpoints.zip\ncd ../\n\ncd Global/\nwget https://facevc.blob.core.windows.net/zhanbo/old_photo/pretrain/Global/checkpoints.zip\nunzip checkpoints.zip\ncd ../\n"
  },
  {
    "path": "kubernetes-pod.yml",
    "content": "apiVersion: v1\nkind: Pod\nmetadata:\n  name: photo-back2life\nspec:\n  containers:\n    - name: photos-back2life\n      image: <YOUR IMAGE>\n      volumeMounts:\n      - mountPath: /in\n        name: in-folder\n      - mountPath: /out\n        name: out-folder\n      command: \n        - python\n        - /app/run.py\n      args:\n        - --input_folder\n        - /in\n        - --output_folder\n        - /out\n        - --GPU \n        - '0'\n        - --with_scratch\n      resources:\n        limits:\n          memory: 4Gi\n          cpu: 0\n          nvidia.com/gpu: 1      \n  volumes:\n  - name: in-folder\n    hostPath:\n      path: /srv/in\n      type: Directory   \n  - name: out-folder\n    hostPath:\n      path: /srv/out\n      type: Directory\n"
  },
  {
    "path": "predict.py",
    "content": "import tempfile\nfrom pathlib import Path\nimport argparse\nimport shutil\nimport os\nimport glob\nimport cv2\nimport cog\nfrom run import run_cmd\n\n\nclass Predictor(cog.Predictor):\n    def setup(self):\n        parser = argparse.ArgumentParser()\n        parser.add_argument(\n            \"--input_folder\", type=str, default=\"input/cog_temp\", help=\"Test images\"\n        )\n        parser.add_argument(\n            \"--output_folder\",\n            type=str,\n            default=\"output\",\n            help=\"Restored images, please use the absolute path\",\n        )\n        parser.add_argument(\"--GPU\", type=str, default=\"0\", help=\"0,1,2\")\n        parser.add_argument(\n            \"--checkpoint_name\",\n            type=str,\n            default=\"Setting_9_epoch_100\",\n            help=\"choose which checkpoint\",\n        )\n        self.opts = parser.parse_args(\"\")\n        self.basepath = os.getcwd()\n        self.opts.input_folder = os.path.join(self.basepath, self.opts.input_folder)\n        self.opts.output_folder = os.path.join(self.basepath, self.opts.output_folder)\n        os.makedirs(self.opts.input_folder, exist_ok=True)\n        os.makedirs(self.opts.output_folder, exist_ok=True)\n\n    @cog.input(\"image\", type=Path, help=\"input image\")\n    @cog.input(\n        \"HR\",\n        type=bool,\n        default=False,\n        help=\"whether the input image is high-resolution\",\n    )\n    @cog.input(\n        \"with_scratch\",\n        type=bool,\n        default=False,\n        help=\"whether the input image is scratched\",\n    )\n    def predict(self, image, HR=False, with_scratch=False):\n        try:\n            os.chdir(self.basepath)\n            input_path = os.path.join(self.opts.input_folder, os.path.basename(image))\n            shutil.copy(str(image), input_path)\n\n            gpu1 = self.opts.GPU\n\n            ## Stage 1: Overall Quality Improve\n            print(\"Running Stage 1: Overall 
restoration\")\n            os.chdir(\"./Global\")\n            stage_1_input_dir = self.opts.input_folder\n            stage_1_output_dir = os.path.join(\n                self.opts.output_folder, \"stage_1_restore_output\"\n            )\n\n            os.makedirs(stage_1_output_dir, exist_ok=True)\n\n            if not with_scratch:\n\n                stage_1_command = (\n                        \"python test.py --test_mode Full --Quality_restore --test_input \"\n                        + stage_1_input_dir\n                        + \" --outputs_dir \"\n                        + stage_1_output_dir\n                        + \" --gpu_ids \"\n                        + gpu1\n                )\n                run_cmd(stage_1_command)\n            else:\n\n                mask_dir = os.path.join(stage_1_output_dir, \"masks\")\n                new_input = os.path.join(mask_dir, \"input\")\n                new_mask = os.path.join(mask_dir, \"mask\")\n                stage_1_command_1 = (\n                        \"python detection.py --test_path \"\n                        + stage_1_input_dir\n                        + \" --output_dir \"\n                        + mask_dir\n                        + \" --input_size full_size\"\n                        + \" --GPU \"\n                        + gpu1\n                )\n\n                if HR:\n                    HR_suffix = \" --HR\"\n                else:\n                    HR_suffix = \"\"\n\n                stage_1_command_2 = (\n                        \"python test.py --Scratch_and_Quality_restore --test_input \"\n                        + new_input\n                        + \" --test_mask \"\n                        + new_mask\n                        + \" --outputs_dir \"\n                        + stage_1_output_dir\n                        + \" --gpu_ids \"\n                        + gpu1\n                        + HR_suffix\n                )\n\n                run_cmd(stage_1_command_1)\n                
run_cmd(stage_1_command_2)\n\n            ## Solve the case when there is no face in the old photo\n            stage_1_results = os.path.join(stage_1_output_dir, \"restored_image\")\n            stage_4_output_dir = os.path.join(self.opts.output_folder, \"final_output\")\n            os.makedirs(stage_4_output_dir, exist_ok=True)\n            for x in os.listdir(stage_1_results):\n                img_dir = os.path.join(stage_1_results, x)\n                shutil.copy(img_dir, stage_4_output_dir)\n\n            print(\"Finish Stage 1 ...\")\n            print(\"\\n\")\n\n            ## Stage 2: Face Detection\n\n            print(\"Running Stage 2: Face Detection\")\n            os.chdir(\".././Face_Detection\")\n            stage_2_input_dir = os.path.join(stage_1_output_dir, \"restored_image\")\n            stage_2_output_dir = os.path.join(\n                self.opts.output_folder, \"stage_2_detection_output\"\n            )\n            os.makedirs(stage_2_output_dir, exist_ok=True)\n\n            stage_2_command = (\n                    \"python detect_all_dlib_HR.py --url \"\n                    + stage_2_input_dir\n                    + \" --save_url \"\n                    + stage_2_output_dir\n            )\n\n            run_cmd(stage_2_command)\n            print(\"Finish Stage 2 ...\")\n            print(\"\\n\")\n\n            ## Stage 3: Face Restore\n            print(\"Running Stage 3: Face Enhancement\")\n            os.chdir(\".././Face_Enhancement\")\n            stage_3_input_mask = \"./\"\n            stage_3_input_face = stage_2_output_dir\n            stage_3_output_dir = os.path.join(\n                self.opts.output_folder, \"stage_3_face_output\"\n            )\n\n            os.makedirs(stage_3_output_dir, exist_ok=True)\n\n            self.opts.checkpoint_name = \"FaceSR_512\"\n            stage_3_command = (\n                    \"python test_face.py --old_face_folder \"\n                    + stage_3_input_face\n                    + 
\" --old_face_label_folder \"\n                    + stage_3_input_mask\n                    + \" --tensorboard_log --name \"\n                    + self.opts.checkpoint_name\n                    + \" --gpu_ids \"\n                    + gpu1\n                    + \" --load_size 512 --label_nc 18 --no_instance --preprocess_mode resize --batchSize 1 --results_dir \"\n                    + stage_3_output_dir\n                    + \" --no_parsing_map\"\n            )\n\n            run_cmd(stage_3_command)\n            print(\"Finish Stage 3 ...\")\n            print(\"\\n\")\n\n            ## Stage 4: Warp back\n            print(\"Running Stage 4: Blending\")\n            os.chdir(\".././Face_Detection\")\n            stage_4_input_image_dir = os.path.join(stage_1_output_dir, \"restored_image\")\n            stage_4_input_face_dir = os.path.join(stage_3_output_dir, \"each_img\")\n            stage_4_output_dir = os.path.join(self.opts.output_folder, \"final_output\")\n            os.makedirs(stage_4_output_dir, exist_ok=True)\n\n            stage_4_command = (\n                    \"python align_warp_back_multiple_dlib_HR.py --origin_url \"\n                    + stage_4_input_image_dir\n                    + \" --replace_url \"\n                    + stage_4_input_face_dir\n                    + \" --save_url \"\n                    + stage_4_output_dir\n            )\n\n            run_cmd(stage_4_command)\n            print(\"Finish Stage 4 ...\")\n            print(\"\\n\")\n\n            print(\"All the processing is done. 
Please check the results.\")\n\n            final_output = os.listdir(os.path.join(self.opts.output_folder, \"final_output\"))[0]\n\n            image_restore = cv2.imread(os.path.join(self.opts.output_folder, \"final_output\", final_output))\n\n            out_path = Path(tempfile.mkdtemp()) / \"out.png\"\n\n            cv2.imwrite(str(out_path), image_restore)\n        finally:\n            clean_folder(self.opts.input_folder)\n            clean_folder(self.opts.output_folder)\n        return out_path\n\n\ndef clean_folder(folder):\n    for filename in os.listdir(folder):\n        file_path = os.path.join(folder, filename)\n        try:\n            if os.path.isfile(file_path) or os.path.islink(file_path):\n                os.unlink(file_path)\n            elif os.path.isdir(file_path):\n                shutil.rmtree(file_path)\n        except Exception as e:\n            print(f\"Failed to delete {file_path}. Reason:{e}\")\n"
  },
  {
    "path": "requirements.txt",
    "content": "torch\ntorchvision\ndlib\nscikit-image\neasydict\nPyYAML\ndominate>=2.3.1\ndill\ntensorboardX\nscipy\nopencv-python\neinops\nPySimpleGUI\nmatplotlib\n"
  },
  {
    "path": "run.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\nimport os\nimport argparse\nimport shutil\nimport sys\nfrom subprocess import call\n\ndef run_cmd(command):\n    try:\n        call(command, shell=True)\n    except KeyboardInterrupt:\n        print(\"Process interrupted\")\n        sys.exit(1)\n\nif __name__ == \"__main__\":\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--input_folder\", type=str, default=\"./test_images/old\", help=\"Test images\")\n    parser.add_argument(\n        \"--output_folder\",\n        type=str,\n        default=\"./output\",\n        help=\"Restored images, please use the absolute path\",\n    )\n    parser.add_argument(\"--GPU\", type=str, default=\"6,7\", help=\"0,1,2\")\n    parser.add_argument(\n        \"--checkpoint_name\", type=str, default=\"Setting_9_epoch_100\", help=\"choose which checkpoint\"\n    )\n    parser.add_argument(\"--with_scratch\", action=\"store_true\")\n    parser.add_argument(\"--HR\", action='store_true')\n    opts = parser.parse_args()\n\n    gpu1 = opts.GPU\n\n    # resolve relative paths before changing directory\n    opts.input_folder = os.path.abspath(opts.input_folder)\n    opts.output_folder = os.path.abspath(opts.output_folder)\n    if not os.path.exists(opts.output_folder):\n        os.makedirs(opts.output_folder)\n\n    main_environment = os.getcwd()\n\n    ## Stage 1: Overall Quality Improve\n    print(\"Running Stage 1: Overall restoration\")\n    os.chdir(\"./Global\")\n    stage_1_input_dir = opts.input_folder\n    stage_1_output_dir = os.path.join(opts.output_folder, \"stage_1_restore_output\")\n    if not os.path.exists(stage_1_output_dir):\n        os.makedirs(stage_1_output_dir)\n\n    if not opts.with_scratch:\n        stage_1_command = (\n            \"python test.py --test_mode Full --Quality_restore --test_input \"\n            + stage_1_input_dir\n            + \" --outputs_dir \"\n            + stage_1_output_dir\n       
     + \" --gpu_ids \"\n            + gpu1\n        )\n        run_cmd(stage_1_command)\n    else:\n\n        mask_dir = os.path.join(stage_1_output_dir, \"masks\")\n        new_input = os.path.join(mask_dir, \"input\")\n        new_mask = os.path.join(mask_dir, \"mask\")\n        stage_1_command_1 = (\n            \"python detection.py --test_path \"\n            + stage_1_input_dir\n            + \" --output_dir \"\n            + mask_dir\n            + \" --input_size full_size\"\n            + \" --GPU \"\n            + gpu1\n        )\n\n        if opts.HR:\n            HR_suffix=\" --HR\"\n        else:\n            HR_suffix=\"\"\n\n        stage_1_command_2 = (\n            \"python test.py --Scratch_and_Quality_restore --test_input \"\n            + new_input\n            + \" --test_mask \"\n            + new_mask\n            + \" --outputs_dir \"\n            + stage_1_output_dir\n            + \" --gpu_ids \"\n            + gpu1 + HR_suffix\n        )\n\n        run_cmd(stage_1_command_1)\n        run_cmd(stage_1_command_2)\n\n    ## Solve the case when there is no face in the old photo\n    stage_1_results = os.path.join(stage_1_output_dir, \"restored_image\")\n    stage_4_output_dir = os.path.join(opts.output_folder, \"final_output\")\n    if not os.path.exists(stage_4_output_dir):\n        os.makedirs(stage_4_output_dir)\n    for x in os.listdir(stage_1_results):\n        img_dir = os.path.join(stage_1_results, x)\n        shutil.copy(img_dir, stage_4_output_dir)\n\n    print(\"Finish Stage 1 ...\")\n    print(\"\\n\")\n\n    ## Stage 2: Face Detection\n\n    print(\"Running Stage 2: Face Detection\")\n    os.chdir(\".././Face_Detection\")\n    stage_2_input_dir = os.path.join(stage_1_output_dir, \"restored_image\")\n    stage_2_output_dir = os.path.join(opts.output_folder, \"stage_2_detection_output\")\n    if not os.path.exists(stage_2_output_dir):\n        os.makedirs(stage_2_output_dir)\n    if opts.HR:\n        stage_2_command = (\n            
\"python detect_all_dlib_HR.py --url \" + stage_2_input_dir + \" --save_url \" + stage_2_output_dir\n        )\n    else:\n        stage_2_command = (\n            \"python detect_all_dlib.py --url \" + stage_2_input_dir + \" --save_url \" + stage_2_output_dir\n        )\n    run_cmd(stage_2_command)\n    print(\"Finish Stage 2 ...\")\n    print(\"\\n\")\n\n    ## Stage 3: Face Restore\n    print(\"Running Stage 3: Face Enhancement\")\n    os.chdir(\".././Face_Enhancement\")\n    stage_3_input_mask = \"./\"\n    stage_3_input_face = stage_2_output_dir\n    stage_3_output_dir = os.path.join(opts.output_folder, \"stage_3_face_output\")\n    if not os.path.exists(stage_3_output_dir):\n        os.makedirs(stage_3_output_dir)\n    \n    if opts.HR:\n        opts.checkpoint_name='FaceSR_512'\n        stage_3_command = (\n            \"python test_face.py --old_face_folder \"\n            + stage_3_input_face\n            + \" --old_face_label_folder \"\n            + stage_3_input_mask\n            + \" --tensorboard_log --name \"\n            + opts.checkpoint_name\n            + \" --gpu_ids \"\n            + gpu1\n            + \" --load_size 512 --label_nc 18 --no_instance --preprocess_mode resize --batchSize 1 --results_dir \"\n            + stage_3_output_dir\n            + \" --no_parsing_map\"\n        ) \n    else:\n        stage_3_command = (\n            \"python test_face.py --old_face_folder \"\n            + stage_3_input_face\n            + \" --old_face_label_folder \"\n            + stage_3_input_mask\n            + \" --tensorboard_log --name \"\n            + opts.checkpoint_name\n            + \" --gpu_ids \"\n            + gpu1\n            + \" --load_size 256 --label_nc 18 --no_instance --preprocess_mode resize --batchSize 4 --results_dir \"\n            + stage_3_output_dir\n            + \" --no_parsing_map\"\n        )\n    run_cmd(stage_3_command)\n    print(\"Finish Stage 3 ...\")\n    print(\"\\n\")\n\n    ## Stage 4: Warp back\n    
print(\"Running Stage 4: Blending\")\n    os.chdir(\".././Face_Detection\")\n    stage_4_input_image_dir = os.path.join(stage_1_output_dir, \"restored_image\")\n    stage_4_input_face_dir = os.path.join(stage_3_output_dir, \"each_img\")\n    stage_4_output_dir = os.path.join(opts.output_folder, \"final_output\")\n    if not os.path.exists(stage_4_output_dir):\n        os.makedirs(stage_4_output_dir)\n    if opts.HR:\n        stage_4_command = (\n            \"python align_warp_back_multiple_dlib_HR.py --origin_url \"\n            + stage_4_input_image_dir\n            + \" --replace_url \"\n            + stage_4_input_face_dir\n            + \" --save_url \"\n            + stage_4_output_dir\n        )\n    else:\n        stage_4_command = (\n            \"python align_warp_back_multiple_dlib.py --origin_url \"\n            + stage_4_input_image_dir\n            + \" --replace_url \"\n            + stage_4_input_face_dir\n            + \" --save_url \"\n            + stage_4_output_dir\n        )\n    run_cmd(stage_4_command)\n    print(\"Finish Stage 4 ...\")\n    print(\"\\n\")\n\n    print(\"All the processing is done. Please check the results.\")\n\n"
  }
]