[
  {
    "path": ".gitignore",
    "content": "jax/internal/pycolmap\njax/dataset\njax/output\n# jax/scripts\n**/__pycache__\n.DS_Store\n.vscode/\n.idea/\n__MACOSX/"
  },
  {
    "path": "LICENSE",
    "content": "## creative commons\n\n# Attribution-NonCommercial 4.0 International\n\nCreative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.\n\n### Using Creative Commons Public Licenses\n\nCreative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses.\n\n* __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. 
[More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors).\n\n* __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees).\n\n## Creative Commons Attribution-NonCommercial 4.0 International Public License\n\nBy exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial 4.0 International Public License (\"Public License\"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.\n\n### Section 1 – Definitions.\n\na. 
__Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.\n\nb. __Adapter's License__ means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License.\n\nc. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.\n\nd. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.\n\ne. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.\n\nf. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License.\n\ng. 
__Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.\n\nh. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License.\n\ni. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange.\n\nj. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.\n\nk. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.\n\nl. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning.\n\n### Section 2 – Scope.\n\na. ___License grant.___\n\n   1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:\n\n       A. 
reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and\n\n       B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only.\n\n   2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.\n\n   3. __Term.__ The term of this Public License is specified in Section 6(a).\n\n   4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.\n\n   5. __Downstream recipients.__\n\n        A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.\n\n        B. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.\n\n   6. 
__No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).\n\nb. ___Other rights.___\n\n   1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.\n\n   2. Patent and trademark rights are not licensed under this Public License.\n\n   3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes.\n\n### Section 3 – License Conditions.\n\nYour exercise of the Licensed Rights is expressly made subject to the following conditions.\n\na. ___Attribution.___\n\n   1. If You Share the Licensed Material (including in modified form), You must:\n\n       A. retain the following if it is supplied by the Licensor with the Licensed Material:\n\n         i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);\n\n         ii. a copyright notice;\n\n         iii. a notice that refers to this Public License;\n\n         iv. a notice that refers to the disclaimer of warranties;\n\n         v. 
a URI or hyperlink to the Licensed Material to the extent reasonably practicable;\n\n       B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and\n\n       C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.\n\n   2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.\n\n   3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.\n\n   4. If You Share Adapted Material You produce, the Adapter's License You apply must not prevent recipients of the Adapted Material from complying with this Public License.\n\n### Section 4 – Sui Generis Database Rights.\n\nWhere the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:\n\na. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only;\n\nb. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and\n\nc. 
You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.\n\nFor the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.\n\n### Section 5 – Disclaimer of Warranties and Limitation of Liability.\n\na. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__\n\nb. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__\n\nc. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.\n\n### Section 6 – Term and Termination.\n\na. This Public License applies for the term of the Copyright and Similar Rights licensed here. 
However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.\n\nb. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:\n\n   1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or\n\n   2. upon express reinstatement by the Licensor.\n\n   For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.\n\nc. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.\n\nd. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.\n\n### Section 7 – Other Terms and Conditions.\n\na. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.\n\nb. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.\n\n### Section 8 – Interpretation.\n\na. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.\n\nb. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.\n\nc. 
No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.\n\nd. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.\n\n> Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.\n>\n> Creative Commons may be contacted at creativecommons.org\n\nCopyright (c) 2022 MCG-NKU\n"
  },
  {
    "path": "eval_metrics_llff.py",
    "content": "import sys\nsys.path.insert(0,'jax/internal/pycolmap')\nsys.path.insert(0,'jax/internal/pycolmap/pycolmap')\nimport os\nimport glob\nimport numpy as np\nimport cv2\nimport torch\nfrom torchmetrics.functional import peak_signal_noise_ratio as compute_psnr\nfrom torchmetrics.functional import structural_similarity_index_measure as compute_ssim\nfrom torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity as LPIPS\nimport json\nfrom tqdm import tqdm\n\nfrom typing import Mapping, Optional, Sequence, Text, Tuple, Union\nimport enum\nimport types\n\nimport pycolmap\n\n_Array = Union[np.ndarray]\n\nclass ProjectionType(enum.Enum):\n  \"\"\"Camera projection type (standard perspective pinhole or fisheye model).\"\"\"\n  PERSPECTIVE = 'perspective'\n  FISHEYE = 'fisheye'\n  \ndef intrinsic_matrix(fx: float,\n                     fy: float,\n                     cx: float,\n                     cy: float,\n                     xnp: types.ModuleType = np) -> _Array:\n  \"\"\"Intrinsic matrix for a pinhole camera in OpenCV coordinate system.\"\"\"\n  return xnp.array([\n      [fx, 0, cx],\n      [0, fy, cy],\n      [0, 0, 1.],\n  ])\n  \ndef file_exists(pth):\n  return os.path.exists(pth)\n  \ndef listdir(pth):\n  return os.listdir(pth)\n\n\n\nclass NeRFSceneManager(pycolmap.SceneManager):\n  \"\"\"COLMAP pose loader.\n\n  Minor NeRF-specific extension to the third_party Python COLMAP loader:\n  google3/third_party/py/pycolmap/scene_manager.py\n  \"\"\"\n\n  def process(\n      self\n  ) -> Tuple[Sequence[Text], np.ndarray, np.ndarray, Optional[Mapping[\n      Text, float]], ProjectionType]:\n    \"\"\"Applies NeRF-specific postprocessing to the loaded pose data.\n\n    Returns:\n      a tuple [image_names, poses, pixtocam, distortion_params].\n      image_names:  contains the only the basename of the images.\n      poses: [N, 4, 4] array containing the camera to world matrices.\n      pixtocam: [N, 3, 3] array containing the camera to pixel space 
matrices.\n      distortion_params: mapping of distortion param name to distortion\n        parameters. Cameras share intrinsics. Valid keys are k1, k2, p1 and p2.\n    \"\"\"\n\n    self.load_cameras()\n    self.load_images()\n    # self.load_points3D()  # For now, we do not need the point cloud data.\n\n    # Assume shared intrinsics between all cameras.\n    cam = self.cameras[1]\n\n    # Extract focal lengths and principal point parameters.\n    fx, fy, cx, cy = cam.fx, cam.fy, cam.cx, cam.cy\n    pixtocam = np.linalg.inv(intrinsic_matrix(fx, fy, cx, cy))\n\n    # Extract extrinsic matrices in world-to-camera format.\n    imdata = self.images\n    w2c_mats = []\n    bottom = np.array([0, 0, 0, 1]).reshape(1, 4)\n    for k in imdata:\n      im = imdata[k]\n      rot = im.R()\n      trans = im.tvec.reshape(3, 1)\n      w2c = np.concatenate([np.concatenate([rot, trans], 1), bottom], axis=0)\n      w2c_mats.append(w2c)\n    w2c_mats = np.stack(w2c_mats, axis=0)\n\n    # Convert extrinsics to camera-to-world.\n    c2w_mats = np.linalg.inv(w2c_mats)\n    poses = c2w_mats[:, :3, :4]\n\n    # Image names from COLMAP. No need for permuting the poses according to\n    # image names anymore.\n    names = [imdata[k].name for k in imdata]\n\n    # Switch from COLMAP (right, down, fwd) to NeRF (right, up, back) frame.\n    poses = poses @ np.diag([1, -1, -1, 1])\n\n    # Get distortion parameters.\n    type_ = cam.camera_type\n\n    if type_ == 0 or type_ == 'SIMPLE_PINHOLE':\n      params = None\n      camtype = ProjectionType.PERSPECTIVE\n\n    elif type_ == 1 or type_ == 'PINHOLE':\n      params = None\n      camtype = ProjectionType.PERSPECTIVE\n\n    if type_ == 2 or type_ == 'SIMPLE_RADIAL':\n      params = {k: 0. for k in ['k1', 'k2', 'k3', 'p1', 'p2']}\n      params['k1'] = cam.k1\n      camtype = ProjectionType.PERSPECTIVE\n\n    elif type_ == 3 or type_ == 'RADIAL':\n      params = {k: 0. 
for k in ['k1', 'k2', 'k3', 'p1', 'p2']}\n      params['k1'] = cam.k1\n      params['k2'] = cam.k2\n      camtype = ProjectionType.PERSPECTIVE\n\n    elif type_ == 4 or type_ == 'OPENCV':\n      params = {k: 0. for k in ['k1', 'k2', 'k3', 'p1', 'p2']}\n      params['k1'] = cam.k1\n      params['k2'] = cam.k2\n      params['p1'] = cam.p1\n      params['p2'] = cam.p2\n      camtype = ProjectionType.PERSPECTIVE\n\n    elif type_ == 5 or type_ == 'OPENCV_FISHEYE':\n      params = {k: 0. for k in ['k1', 'k2', 'k3', 'k4']}\n      params['k1'] = cam.k1\n      params['k2'] = cam.k2\n      params['k3'] = cam.k3\n      params['k4'] = cam.k4\n      camtype = ProjectionType.FISHEYE\n\n    return names, poses, pixtocam, params, camtype\n\ndef load_gts(data_dir, llffhold = 8, factor = 8, load_alphabetical = True):\n    image_dir_suffix = ''\n    # Use downsampling factor (unless loading training split for raw dataset,\n    # we train raw at full resolution because of the Bayer mosaic pattern).\n    if factor > 0:\n      image_dir_suffix = f'_{factor}'\n\n    # Copy COLMAP data to local disk for faster loading.\n    colmap_dir = os.path.join(data_dir, 'sparse/0/')\n\n    # Load poses.\n    if os.path.exists(colmap_dir):\n      pose_data = NeRFSceneManager(colmap_dir).process()\n    else:\n      raise ValueError(f'Image folder {colmap_dir} does not exist.')\n    image_names = pose_data[0]\n\n    # Previous NeRF results were generated with images sorted by filename,\n    # use this flag to ensure metrics are reported on the same test set.\n    if load_alphabetical:\n      inds = np.argsort(image_names)\n      image_names = [image_names[i] for i in inds]\n      \n    colmap_image_dir = os.path.join(data_dir, 'images')\n    image_dir = os.path.join(data_dir, 'images' + image_dir_suffix)\n    for d in [image_dir, colmap_image_dir]:\n        if not file_exists(d):\n            raise ValueError(f'Image folder {d} does not exist.')\n    # Downsampled images may have different names vs 
images used for COLMAP,\n    # so we need to map between the two sorted lists of files.\n    colmap_files = sorted(listdir(colmap_image_dir))\n    image_files = sorted(listdir(image_dir))\n    colmap_to_image = dict(zip(colmap_files, image_files))\n    image_paths = [os.path.join(image_dir, colmap_to_image[f])\n                     for f in image_names]\n    images = [load_image(x) for x in image_paths]\n    images = np.stack(images, axis=0)\n    test_indices = np.arange(images.shape[0])[::llffhold]\n    test_images = images[test_indices]\n    \n    return test_images\n\ndef load_image(path):\n    return cv2.imread(path)[:,:,::-1]\n\ndef im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):\n# def im2tensor(image, imtype=np.uint8, cent=1., factor=1.):\n    return torch.Tensor((image / factor - cent)\n                        [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))\n_ = torch.manual_seed(202208)\ncompute_lpips = LPIPS()\n\n# prediction path\npath = '/mnt/sda/experiments/cvpr23/Mip-NeRF-360/logs_Mip-NeRF-360/Scan*/test_preds'\n# ground truth path\ngt_path_root = '/mnt/sda/experiments/cvpr23_real_cap_dataset'\nbasedir = '/'.join(path.split('/')[:-2])\n\nresults_list = sorted(glob.glob(path))\nresults_dict = {}\nfor scene in results_list:\n    scene_name = scene.split('/')[-2].split('_')[-1]\n    gt_images = load_gts(os.path.join(gt_path_root, scene_name))\n    img_list = glob.glob(os.path.join(scene, 'color_0*.png'))\n    scene_dict  = {}\n    psnrs, ssims, lpipss = [], [], []\n    print (f'eval {scene_name}')\n    for pred_path in tqdm(img_list):\n        img_name = pred_path.split('/')[-1]\n        idx = int(img_name.split('.')[0].split('_')[-1])\n        \n        pred = load_image(pred_path)\n        gt   = gt_images[idx]\n        \n        gt_out = pred_path.replace('color', 'gt')\n        cv2.imwrite(gt_out, gt[..., ::-1])\n        \n        pred_lpips = im2tensor(pred)\n        gt_lpips   = im2tensor(gt)\n        \n        pred = 
torch.from_numpy((pred / 255.).astype(np.float32)).permute(2, 0, 1).unsqueeze(0)\n        gt   = torch.from_numpy((gt / 255.).astype(np.float32)).permute(2, 0, 1).unsqueeze(0)\n        \n        with torch.no_grad():\n            psnr = compute_psnr(pred, gt, data_range=1.0).item()\n            ssim = compute_ssim(pred, gt, data_range=1.0).item()\n            lpips = compute_lpips(pred_lpips, gt_lpips).item()\n        img_dict = {'psnr'  : psnr,\n                    'ssim'  : ssim,\n                    'lpips' : lpips}\n        scene_dict[img_name] = img_dict\n        psnrs.append(psnr)\n        ssims.append(ssim)\n        lpipss.append(lpips)\n    avg_psnr = np.mean(psnrs)\n    avg_ssim = np.mean(ssims)\n    avg_lpips = np.mean(lpipss)\n    avg_dict = {'avg_psnr'  : avg_psnr,\n                'avg_ssim'  : avg_ssim,\n                'avg_lpips' : avg_lpips}\n    scene_dict[scene_name] = avg_dict\n    with open(os.path.join(scene, 'psnr_ssim_lpips.json'), 'w')  as f:\n        json.dump(scene_dict, f)\n    results_dict[scene_name] = avg_dict\n\nwhole_psnrs, whole_ssims, whole_lpipss = [], [], []\nfor _, val in results_dict.items():\n    whole_psnrs.append(val['avg_psnr'])\n    whole_ssims.append(val['avg_ssim'])\n    whole_lpipss.append(val['avg_lpips'])\nwhole_avg = {\n    'avg_psnr'  : np.mean(whole_psnrs),\n    'avg_ssim'  : np.mean(whole_ssims),\n    'avg_lpips' : np.mean(whole_lpipss)\n}\nresults_dict['whole'] = whole_avg\nwith open(os.path.join(basedir, 'psnr_ssim_lpips_llff.json'), 'w') as f:\n    json.dump(results_dict, f)"
  },
  {
    "path": "eval_metrics_syn.py",
    "content": "import os\nimport glob\nimport numpy as np\nimport cv2\nimport torch\nfrom torchmetrics.functional import peak_signal_noise_ratio as compute_psnr\nfrom torchmetrics.functional import structural_similarity_index_measure as compute_ssim\nfrom torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity as LPIPS\nimport json\nfrom tqdm import tqdm\n\ndef load_image(path):\n    return cv2.imread(path)[:,:,::-1]\n\ndef im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):\n    return torch.Tensor((image / factor - cent)\n                        [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))\n_ = torch.manual_seed(202208)\ncompute_lpips = LPIPS()\n\n# prediction path\npath = '/mnt/sda/experiments/cvpr23/Mip-NeRF-360/logs_MS-Mip-NeRF-360/Scene*/test_preds'\n# ground truth path\ngt_path_root = '/mnt/sda/T3/cvpr23/dataset/synthetic_scenes'\nbasedir = '/'.join(path.split('/')[:-2])\nresults_list = sorted(glob.glob(path))\nresults_dict = {}\nfor scene in results_list:\n    scene_name = scene.split('/')[-2]\n    img_list = glob.glob(os.path.join(scene, 'color_0*.png'))\n    scene_dict  = {}\n    psnrs, ssims, lpipss = [], [], []\n    print (f'eval {scene_name}')\n    for pred_path in tqdm(img_list):\n        img_name = pred_path.split('/')[-1][6:9]\n        gt_path = os.path.join(gt_path_root, scene_name, f'test/r_{int(img_name)}.png')\n\n        pred = load_image(pred_path)\n        gt   = load_image(gt_path)\n        pred_lpips = im2tensor(pred)\n        gt_lpips   = im2tensor(gt)\n        \n        pred = torch.from_numpy((pred / 255.).astype(np.float32)).permute(2, 0, 1).unsqueeze(0)\n        gt   = torch.from_numpy((gt / 255.).astype(np.float32)).permute(2, 0, 1).unsqueeze(0)\n        \n        with torch.no_grad():\n            psnr = compute_psnr(pred, gt, data_range=1.0).item()\n            ssim = compute_ssim(pred, gt, data_range=1.0).item()\n            lpips = compute_lpips(pred_lpips, gt_lpips).item()\n        img_dict = {'psnr'  : 
psnr,\n                    'ssim'  : ssim,\n                    'lpips' : lpips}\n        scene_dict[img_name] = img_dict\n        psnrs.append(psnr)\n        ssims.append(ssim)\n        lpipss.append(lpips)\n    avg_psnr = np.mean(psnrs)\n    avg_ssim = np.mean(ssims)\n    avg_lpips = np.mean(lpipss)\n    avg_dict = {'avg_psnr'  : avg_psnr,\n                'avg_ssim'  : avg_ssim,\n                'avg_lpips' : avg_lpips}\n    scene_dict[scene_name] = avg_dict\n    with open(os.path.join(scene, 'psnr_ssim_lpips.json'), 'w')  as f:\n        json.dump(scene_dict, f)\n    results_dict[scene_name] = avg_dict\n\nwhole_psnrs, whole_ssims, whole_lpipss = [], [], []\nfor _, val in results_dict.items():\n    whole_psnrs.append(val['avg_psnr'])\n    whole_ssims.append(val['avg_ssim'])\n    whole_lpipss.append(val['avg_lpips'])\nwhole_avg = {\n    'avg_psnr'  : np.mean(whole_psnrs),\n    'avg_ssim'  : np.mean(whole_ssims),\n    'avg_lpips' : np.mean(whole_lpipss)\n}\nresults_dict['whole'] = whole_avg\nwith open(os.path.join(basedir, 'psnr_ssim_lpips_syn.json'), 'w') as f:\n    json.dump(results_dict, f)"
  },
  {
    "path": "mip360/README.md",
    "content": "# Code release for MS-NeRF (CVPR 2023), based on [MultiNeRF](https://github.com/google-research/multinerf)\n\n\nThis repository contains the code release for the CVPR 2023 paper [MS-NeRF](https://arxiv.org/abs/2305.04268). All of the Mip-NeRF 360 based experiments were conducted in this repository, so you should be able to reproduce the results reported in our paper.\n\nThe original repository also contains the code for three CVPR 2022 papers: \n[Mip-NeRF 360](https://jonbarron.info/mipnerf360/),\n[Ref-NeRF](https://dorverbin.github.io/refnerf/), and\n[RawNeRF](https://bmild.github.io/rawnerf/).\nBecause we make minimal modifications, those methods should still be runnable,\nbut we recommend using the original repository for them.\n\nThis implementation is written in [JAX](https://github.com/google/jax), and\nis a fork of [MultiNeRF](https://github.com/google-research/multinerf).\nThis is research code, and should be treated accordingly.\n\n## Setup\n\n```\n# Clone the repo.\ngit clone https://github.com/ZX-Yin/ms-nerf.git\ncd ms-nerf/jax/\n\n# Make a conda environment.\nconda create --name ms-nerf python=3.9\nconda activate ms-nerf\n\n# Prepare pip.\nconda install pip\npip install --upgrade pip\n\n# Install requirements.\npip install -r requirements.txt\n\n# Manually install rmbrualla's `pycolmap` (don't use pip's! It's different).\ngit clone https://github.com/rmbrualla/pycolmap.git ./internal/pycolmap\n\n# Confirm that all the unit tests pass.\n./scripts/run_all_unit_tests.sh\n```\nYou'll probably also need to update your JAX installation to support GPUs or TPUs.\n\n## Running MS-Mip-NeRF 360\n\nExample scripts for training, evaluating and rendering with MS-Mip-NeRF 360 can be found in `scripts/msnerf/`. 
We compute PSNR, SSIM, and LPIPS with our own scripts (see the Evaluating section).\n\n## Running MS-Mip-NeRF and NeRF\n\nWe are working on integrating these two experiments into this repository.\n\n## Evaluating\n\nWe use PyTorch-based Python scripts to evaluate all our results, as many convenient packages are available for PyTorch. Experiments on the synthetic part of our dataset are evaluated with `../eval_metrics_syn.py`, and those on the real captured part with `../eval_metrics_llff.py`. You only need to set the variables `path` and `gt_path_root` in the scripts.\n\n### OOM errors\n\nYou may need to reduce the batch size (`Config.batch_size`) to avoid out of memory\nerrors. If you do this, but want to preserve quality, be sure to increase the number\nof training iterations and decrease the learning rate by whatever scale factor you\ndecrease batch size by.\n\n### Note\n\nWe made only minimal modifications to the original repository, so all other behavior should remain the same. For further experiments, please follow the original [instructions](https://github.com/google-research/multinerf/blob/main/README.md)."
  },
  {
    "path": "mip360/configs/360.gin",
    "content": "Config.dataset_loader = 'llff'\nConfig.near = 0.2\nConfig.far = 1e6\nConfig.factor = 4\n\nModel.raydist_fn = @jnp.reciprocal\nModel.opaque_background = True\n\nPropMLP.warp_fn = @coord.contract\nPropMLP.net_depth = 4\nPropMLP.net_width = 256\nPropMLP.disable_density_normals = True\nPropMLP.disable_rgb = True\n\nNerfMLP.warp_fn = @coord.contract\nNerfMLP.net_depth = 8\nNerfMLP.net_width = 1024\nNerfMLP.disable_density_normals = True\n\n"
  },
  {
    "path": "mip360/configs/360_glo4.gin",
    "content": "Config.dataset_loader = 'llff'\nConfig.near = 0.2\nConfig.far = 1e6\nConfig.factor = 4\n\nModel.raydist_fn = @jnp.reciprocal\nModel.num_glo_features = 4\nModel.opaque_background = True\n\nPropMLP.warp_fn = @coord.contract\nPropMLP.net_depth = 4\nPropMLP.net_width = 256\nPropMLP.disable_density_normals = True\nPropMLP.disable_rgb = True\n\nNerfMLP.warp_fn = @coord.contract\nNerfMLP.net_depth = 8\nNerfMLP.net_width = 1024\nNerfMLP.disable_density_normals = True\n"
  },
  {
    "path": "mip360/configs/blender_256.gin",
    "content": "Config.dataset_loader = 'blender'\nConfig.batching = 'single_image'\nConfig.near = 2\nConfig.far = 6\nConfig.eval_render_interval = 5\nConfig.data_loss_type = 'mse'\nConfig.adam_eps = 1e-8\n\nModel.num_levels = 2\nModel.num_prop_samples = 128\nModel.num_nerf_samples = 32\n\nPropMLP.net_depth = 4\nPropMLP.net_width = 256\nPropMLP.basis_shape = 'octahedron'\nPropMLP.basis_subdivisions = 1\nPropMLP.disable_density_normals = True\nPropMLP.disable_rgb = True\n\nNerfMLP.net_depth = 8\nNerfMLP.net_width = 256\nNerfMLP.basis_shape = 'octahedron'\nNerfMLP.basis_subdivisions = 1\nNerfMLP.disable_density_normals = True\n\nConfig.distortion_loss_mult = 0.\n\nNerfMLP.max_deg_point = 16\nPropMLP.max_deg_point = 16\n"
  },
  {
    "path": "mip360/configs/blender_512.gin",
    "content": "Config.dataset_loader = 'blender'\nConfig.batching = 'single_image'\nConfig.near = 2\nConfig.far = 6\nConfig.eval_render_interval = 5\nConfig.data_loss_type = 'mse'\nConfig.adam_eps = 1e-8\n\nModel.num_levels = 2\nModel.num_prop_samples = 128\nModel.num_nerf_samples = 32\n\nPropMLP.net_depth = 4\nPropMLP.net_width = 256\nPropMLP.disable_density_normals = True\nPropMLP.disable_rgb = True\n\nNerfMLP.net_depth = 8\nNerfMLP.net_width = 512\nNerfMLP.disable_density_normals = True\n\nConfig.distortion_loss_mult = 0.\n\nNerfMLP.max_deg_point = 16\nPropMLP.max_deg_point = 16\n"
  },
  {
    "path": "mip360/configs/blender_refnerf.gin",
    "content": "Config.dataset_loader = 'blender'\nConfig.batching = 'single_image'\nConfig.near = 2\nConfig.far = 6\nConfig.eval_render_interval = 5\nConfig.compute_normal_metrics = True\nConfig.data_loss_type = 'mse'\nConfig.distortion_loss_mult = 0.0\nConfig.orientation_loss_mult = 0.1\nConfig.orientation_loss_target = 'normals_pred'\nConfig.predicted_normal_loss_mult = 3e-4\nConfig.orientation_coarse_loss_mult = 0.01\nConfig.predicted_normal_coarse_loss_mult = 3e-5\nConfig.interlevel_loss_mult = 0.0\nConfig.data_coarse_loss_mult = 0.1\nConfig.adam_eps = 1e-8\n\nModel.num_levels = 2\nModel.single_mlp = True\nModel.num_prop_samples = 128  # This needs to be set despite single_mlp = True.\nModel.num_nerf_samples = 128\nModel.anneal_slope = 0.\nModel.dilation_multiplier = 0.\nModel.dilation_bias = 0.\nModel.single_jitter = False\nModel.resample_padding = 0.01\n\nNerfMLP.net_depth = 8\nNerfMLP.net_width = 256\nNerfMLP.net_depth_viewdirs = 8\nNerfMLP.basis_shape = 'octahedron'\nNerfMLP.basis_subdivisions = 1\nNerfMLP.disable_density_normals = False\nNerfMLP.enable_pred_normals = True\nNerfMLP.use_directional_enc = True\nNerfMLP.use_reflections = True\nNerfMLP.deg_view = 5\nNerfMLP.enable_pred_roughness = True\nNerfMLP.use_diffuse_color = True\nNerfMLP.use_specular_tint = True\nNerfMLP.use_n_dot_v = True\nNerfMLP.bottleneck_width = 128\nNerfMLP.density_bias = 0.5\nNerfMLP.max_deg_point = 16\n"
  },
  {
    "path": "mip360/configs/debug.gin",
    "content": "# A short training schedule with no \"warm up\", useful for debugging.\nConfig.checkpoint_every = 1000\nConfig.print_every = 100\nConfig.train_render_every = 1000\nConfig.lr_delay_mult = 0.1\nConfig.lr_delay_steps = 500\nConfig.batch_size = 2048\nConfig.render_chunk_size = 2048\nConfig.lr_init = 5e-4\nConfig.lr_final = 5e-6\nConfig.factor = 4\nConfig.early_exit_steps = 3000\n\nPropMLP.net_depth = 2\nPropMLP.net_width = 64\n\nNerfMLP.net_depth = 4\nNerfMLP.net_width = 128\n"
  },
  {
    "path": "mip360/configs/llff_256.gin",
    "content": "Config.dataset_loader = 'llff'\nConfig.near = 0.\nConfig.far = 1.\nConfig.factor = 4\nConfig.forward_facing = True\nConfig.adam_eps = 1e-8\n\nModel.ray_shape = 'cylinder'\nModel.opaque_background = True\nModel.num_levels = 2\nModel.num_prop_samples = 128\nModel.num_nerf_samples = 32\n\nPropMLP.net_depth = 4\nPropMLP.net_width = 256\nPropMLP.basis_shape = 'octahedron'\nPropMLP.basis_subdivisions = 1\nPropMLP.disable_density_normals = True\nPropMLP.disable_rgb = True\n\nNerfMLP.net_depth = 8\nNerfMLP.net_width = 256\nNerfMLP.basis_shape = 'octahedron'\nNerfMLP.basis_subdivisions = 1\nNerfMLP.disable_density_normals = True\n\nNerfMLP.max_deg_point = 16\nPropMLP.max_deg_point = 16\n"
  },
  {
    "path": "mip360/configs/llff_512.gin",
    "content": "Config.dataset_loader = 'llff'\nConfig.near = 0.\nConfig.far = 1.\nConfig.factor = 4\nConfig.forward_facing = True\nConfig.adam_eps = 1e-8\n\nModel.ray_shape = 'cylinder'\nModel.opaque_background = True\nModel.num_levels = 2\nModel.num_prop_samples = 128\nModel.num_nerf_samples = 32\n\nPropMLP.net_depth = 4\nPropMLP.net_width = 256\nPropMLP.disable_density_normals = True\nPropMLP.disable_rgb = True\n\nNerfMLP.net_depth = 8\nNerfMLP.net_width = 512\nNerfMLP.disable_density_normals = True\n\nNerfMLP.max_deg_point = 16\nPropMLP.max_deg_point = 16\n"
  },
  {
    "path": "mip360/configs/llff_raw.gin",
    "content": "# General LLFF settings\n\nConfig.dataset_loader = 'llff'\nConfig.near = 0.\nConfig.far = 1.\nConfig.factor = 4\nConfig.forward_facing = True\n\nModel.ray_shape = 'cylinder'\n\nPropMLP.net_depth = 4\nPropMLP.net_width = 256\nPropMLP.basis_shape = 'octahedron'\nPropMLP.basis_subdivisions = 1\nPropMLP.disable_density_normals = True  # Turn this off if using orientation loss.\nPropMLP.disable_rgb = True\n\nNerfMLP.net_depth = 8\nNerfMLP.net_width = 256\nNerfMLP.basis_shape = 'octahedron'\nNerfMLP.basis_subdivisions = 1\nNerfMLP.disable_density_normals = True  # Turn this off if using orientation loss.\n\nNerfMLP.max_deg_point = 16\nPropMLP.max_deg_point = 16\n\nConfig.train_render_every = 5000\n\n\n########################## RawNeRF specific settings ##########################\n\nConfig.rawnerf_mode = True\nConfig.data_loss_type = 'rawnerf'\nConfig.apply_bayer_mask = True\nModel.learned_exposure_scaling = True\n\nModel.num_levels = 2\nModel.num_prop_samples = 128  # Using extra samples for now because of noise instability.\nModel.num_nerf_samples = 128\nModel.opaque_background = True\n\n# RGB activation we use for linear color outputs is exp(x - 5).\nNerfMLP.rgb_padding = 0.\nNerfMLP.rgb_activation = @math.safe_exp\nNerfMLP.rgb_bias = -5.\nPropMLP.rgb_padding = 0.\nPropMLP.rgb_activation = @math.safe_exp\nPropMLP.rgb_bias = -5.\n\n## Experimenting with the various regularizers and losses:\nConfig.interlevel_loss_mult = .0  # Turning off interlevel for now (default = 1.).\nConfig.distortion_loss_mult = .01  # Distortion loss helps with floaters (default = .01).\nConfig.orientation_loss_mult = 0.  
# Orientation loss also not great (try .01).\nConfig.data_coarse_loss_mult = 0.1  # Setting this to match old MipNeRF.\n\n## Density noise used in original NeRF:\nNerfMLP.density_noise = 1.\nPropMLP.density_noise = 1.\n\n## Use a single MLP for all rounds of sampling:\nModel.single_mlp = True\n\n## Some algorithmic settings to match the paper:\nModel.anneal_slope = 0.\nModel.dilation_multiplier = 0.\nModel.dilation_bias = 0.\nModel.single_jitter = False\nNerfMLP.weight_init = 'glorot_uniform'\nPropMLP.weight_init = 'glorot_uniform'\n\n## Training hyperparameters used in the paper:\nConfig.batch_size = 16384\nConfig.render_chunk_size = 16384\nConfig.lr_init = 1e-3\nConfig.lr_final = 1e-5\nConfig.max_steps = 500000\nConfig.checkpoint_every = 25000\nConfig.lr_delay_steps = 2500\nConfig.lr_delay_mult = 0.01\nConfig.grad_max_norm = 0.1\nConfig.grad_max_val = 0.1\nConfig.adam_eps = 1e-8\n"
  },
  {
    "path": "mip360/configs/llff_raw_test.gin",
    "content": "include 'experimental/users/barron/mipnerf360/configs/llff_raw.gin'\n\nConfig.factor = 0\nConfig.eval_raw_affine_cc = True\nConfig.eval_crop_borders = 16\nConfig.vis_decimate = 4\n"
  },
  {
    "path": "mip360/configs/ms-nerf/360.gin",
    "content": "Config.dataset_loader = 'blender'\nConfig.near = 0.2\nConfig.far = 1e6\nConfig.factor = 0\nConfig.max_steps = 200000\nConfig.batch_size = 1024\nConfig.render_chunk_size = 1024\nConfig.checkpoint_every = 10000\nConfig.train_render_every = 10000\nConfig.lr_init = 2e-3\n\nModel.raydist_fn = @jnp.reciprocal\nModel.opaque_background = True\nModel.num_space = 1\n\nPropMLP.warp_fn = @coord.contract\nPropMLP.net_depth = 4\nPropMLP.net_width = 256\nPropMLP.disable_density_normals = True\nPropMLP.disable_rgb = True\n\nNerfMLP.warp_fn = @coord.contract\nNerfMLP.net_depth = 8\nNerfMLP.net_width = 1024\nNerfMLP.disable_density_normals = True\nNerfMLP.num_rgb_channels = 3\n"
  },
  {
    "path": "mip360/configs/ms-nerf/ms360.gin",
    "content": "Config.dataset_loader = 'blender'\nConfig.near = 0.2\nConfig.far = 1e6\nConfig.factor = 0\nConfig.max_steps = 200000\nConfig.batch_size = 1024\nConfig.render_chunk_size = 1024\nConfig.checkpoint_every = 10000\nConfig.train_render_every = 10000\nConfig.lr_init = 2e-3\n\nModel.raydist_fn = @jnp.reciprocal\nModel.opaque_background = True\nModel.num_space = 8\n\nPropMLP.warp_fn = @coord.contract\nPropMLP.net_depth = 4\nPropMLP.net_width = 256\nPropMLP.disable_density_normals = True\nPropMLP.disable_rgb = True\n\nNerfMLP.warp_fn = @coord.contract\nNerfMLP.net_depth = 8\nNerfMLP.net_width = 1024\nNerfMLP.disable_density_normals = True\nNerfMLP.num_rgb_channels = 32\nNerfMLP.rgb_activation = @nn.relu\n\nDecoderMLP.hidden_width = 64\n"
  },
  {
    "path": "mip360/configs/render_config.gin",
    "content": "Config.render_path = True\nConfig.render_path_frames = 480\nConfig.render_video_fps = 60\n"
  },
  {
    "path": "mip360/configs/tat.gin",
    "content": "# This config is meant to be run while overriding a 360*.gin config.\n\nConfig.dataset_loader = 'tat_nerfpp'\nConfig.near = 0.1\nConfig.far = 1e6\n"
  },
  {
    "path": "mip360/eval.py",
    "content": "# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Evaluation script.\"\"\"\n\nimport functools\nfrom os import path\nimport sys\nimport time\n\nfrom absl import app\nfrom flax.metrics import tensorboard\nfrom flax.training import checkpoints\nimport gin\nfrom internal import configs\nfrom internal import datasets\nfrom internal import image\nfrom internal import models\nfrom internal import raw_utils\nfrom internal import ref_utils\nfrom internal import train_utils\nfrom internal import utils\nfrom internal import vis\nimport jax\nfrom jax import random\nimport jax.numpy as jnp\nimport numpy as np\n\nconfigs.define_common_flags()\njax.config.parse_flags_with_absl()\n\n\ndef main(unused_argv):\n  config = configs.load_config(save_config=False)\n\n  dataset = datasets.load_dataset('test', config.data_dir, config)\n\n  key = random.PRNGKey(20200823)\n  _, state, render_eval_pfn, _, _ = train_utils.setup_model(config, key)\n\n  if config.rawnerf_mode:\n    postprocess_fn = dataset.metadata['postprocess_fn']\n  else:\n    postprocess_fn = lambda z: z\n\n  if config.eval_raw_affine_cc:\n    cc_fun = raw_utils.match_images_affine\n  else:\n    cc_fun = image.color_correct\n\n  metric_harness = image.MetricHarness()\n\n  last_step = 0\n  out_dir = path.join(config.checkpoint_dir,\n                      'path_renders' if config.render_path else 'test_preds')\n  path_fn = lambda x: path.join(out_dir, x)\n\n  if not 
config.eval_only_once:\n    summary_writer = tensorboard.SummaryWriter(\n        path.join(config.checkpoint_dir, 'eval'))\n  while True:\n    state = checkpoints.restore_checkpoint(config.checkpoint_dir, state)\n    step = int(state.step)\n    if step <= last_step:\n      print(f'Checkpoint step {step} <= last step {last_step}, sleeping.')\n      time.sleep(10)\n      continue\n    print(f'Evaluating checkpoint at step {step}.')\n    if config.eval_save_output and (not utils.isdir(out_dir)):\n      utils.makedirs(out_dir)\n\n    num_eval = min(dataset.size, config.eval_dataset_limit)\n    key = random.PRNGKey(0 if config.deterministic_showcase else step)\n    perm = random.permutation(key, num_eval)\n    showcase_indices = np.sort(perm[:config.num_showcase_images])\n\n    metrics = []\n    metrics_cc = []\n    showcases = []\n    render_times = []\n    for idx in range(dataset.size):\n      eval_start_time = time.time()\n      batch = next(dataset)\n      if idx >= num_eval:\n        print(f'Skipping image {idx+1}/{dataset.size}')\n        continue\n      if config.eval_one is not None:\n        if idx != config.eval_one[0]:\n          print(f'Skipping image {idx+1}/{dataset.size}')\n          continue\n      print(f'Evaluating image {idx+1}/{dataset.size}')\n      rays = batch.rays\n      train_frac = state.step / config.max_steps\n      rendering = models.render_image(\n          functools.partial(\n              render_eval_pfn,\n              state.params,\n              train_frac,\n          ),\n          rays,\n          None,\n          config,\n      )\n\n      if jax.host_id() != 0:  # Only record via host 0.\n        continue\n\n      render_times.append((time.time() - eval_start_time))\n      print(f'Rendered in {render_times[-1]:0.3f}s')\n\n      # Cast to 64-bit to ensure high precision for color correction function.\n      gt_rgb = np.array(batch.rgb, dtype=np.float64)\n      rendering['rgb'] = np.array(rendering['rgb'], dtype=np.float64)\n\n      
cc_start_time = time.time()\n      rendering['rgb_cc'] = cc_fun(rendering['rgb'], gt_rgb)\n      print(f'Color corrected in {(time.time() - cc_start_time):0.3f}s')\n\n      if not config.eval_only_once and idx in showcase_indices:\n        showcase_idx = idx if config.deterministic_showcase else len(showcases)\n        showcases.append((showcase_idx, rendering, batch))\n      if not config.render_path:\n        rgb = postprocess_fn(rendering['rgb'])\n        rgb_cc = postprocess_fn(rendering['rgb_cc'])\n        rgb_gt = postprocess_fn(gt_rgb)\n\n        if config.eval_quantize_metrics:\n          # Ensures that the images written to disk reproduce the metrics.\n          rgb = np.round(rgb * 255) / 255\n          rgb_cc = np.round(rgb_cc * 255) / 255\n\n        if config.eval_crop_borders > 0:\n          crop_fn = lambda x, c=config.eval_crop_borders: x[c:-c, c:-c]\n          rgb = crop_fn(rgb)\n          rgb_cc = crop_fn(rgb_cc)\n          rgb_gt = crop_fn(rgb_gt)\n\n        metric = metric_harness(rgb, rgb_gt)\n        metric_cc = metric_harness(rgb_cc, rgb_gt)\n\n        if config.compute_disp_metrics:\n          for tag in ['mean', 'median']:\n            key = f'distance_{tag}'\n            if key in rendering:\n              disparity = 1 / (1 + rendering[key])\n              metric[f'disparity_{tag}_mse'] = float(\n                  ((disparity - batch.disps)**2).mean())\n\n        if config.compute_normal_metrics:\n          weights = rendering['acc'] * batch.alphas\n          normalized_normals_gt = ref_utils.l2_normalize(batch.normals)\n          for key, val in rendering.items():\n            if key.startswith('normals') and val is not None:\n              normalized_normals = ref_utils.l2_normalize(val)\n              metric[key + '_mae'] = ref_utils.compute_weighted_mae(\n                  weights, normalized_normals, normalized_normals_gt)\n\n        for m, v in metric.items():\n          print(f'{m:30s} = {v:.4f}')\n\n        metrics.append(metric)\n 
       metrics_cc.append(metric_cc)\n\n      if config.eval_save_output and (config.eval_render_interval > 0):\n        if (idx % config.eval_render_interval) == 0:\n          utils.save_img_u8(postprocess_fn(rendering['rgb']),\n                            path_fn(f'color_{idx:03d}.png'))\n          utils.save_img_u8(postprocess_fn(rendering['rgb_cc']),\n                            path_fn(f'color_cc_{idx:03d}.png'))\n\n          for key in ['distance_mean', 'distance_median']:\n            if key in rendering:\n              utils.save_img_f32(rendering[key],\n                                 path_fn(f'{key}_{idx:03d}.tiff'))\n\n          for key in ['normals']:\n            if key in rendering:\n              utils.save_img_u8(rendering[key] / 2. + 0.5,\n                                path_fn(f'{key}_{idx:03d}.png'))\n\n          utils.save_img_f32(rendering['acc'], path_fn(f'acc_{idx:03d}.tiff'))\n\n      if config.eval_one is not None:\n        out_dir_one = path.join(out_dir, '_'.join([f\"{t:03d}\" for t in config.eval_one]))\n        utils.makedirs(out_dir_one)\n        path_fn_one = lambda x: path.join(out_dir_one, x)\n        utils.save_img_u8(postprocess_fn(rendering['rgb']),\n                          path_fn_one(f'color_{idx:03d}.png'))\n        utils.save_img_u8(postprocess_fn(rendering['rgb_cc']),\n                          path_fn_one(f'color_cc_{idx:03d}.png'))\n\n        for key in ['distance_mean', 'distance_median']:\n          if key in rendering:\n            utils.save_img_f32(rendering[key],\n                               path_fn_one(f'{key}_{idx:03d}.tiff'))\n\n        for key in ['normals']:\n          if key in rendering:\n            utils.save_img_u8(rendering[key] / 2. 
+ 0.5,\n                              path_fn_one(f'{key}_{idx:03d}.png'))\n\n        utils.save_img_f32(rendering['acc'], path_fn_one(f'acc_{idx:03d}.tiff'))\n\n        _, one_x, one_y = config.eval_one\n        dict_one = {}\n        for key in rendering:\n          if key.startswith('ray_'):\n            dict_one[key] = [r[one_x, one_y] for r in rendering[key]]\n          else:\n            dict_one[key] = rendering[key][one_x, one_y]\n        color_mark = rendering['rgb'].copy()\n        color_mark[one_x, one_y] = np.array([255,0,0])\n        utils.save_img_u8(color_mark, path_fn_one(f'color_mark_{idx:03d}_{one_x}_{one_y}.png'))\n        np.savez(path_fn_one(f'eval_one_{idx:03d}_{one_x}_{one_y}.npz'), dict_one)\n\n    if (not config.eval_only_once) and (jax.host_id() == 0):\n      summary_writer.scalar('eval_median_render_time', np.median(render_times),\n                            step)\n      for name in metrics[0]:\n        scores = [m[name] for m in metrics]\n        summary_writer.scalar('eval_metrics/' + name, np.mean(scores), step)\n        summary_writer.histogram('eval_metrics/' + 'perimage_' + name, scores,\n                                 step)\n      for name in metrics_cc[0]:\n        scores = [m[name] for m in metrics_cc]\n        summary_writer.scalar('eval_metrics_cc/' + name, np.mean(scores), step)\n        summary_writer.histogram('eval_metrics_cc/' + 'perimage_' + name,\n                                 scores, step)\n\n      for i, r, b in showcases:\n        if config.vis_decimate > 1:\n          d = config.vis_decimate\n          decimate_fn = lambda x, d=d: None if x is None else x[::d, ::d]\n        else:\n          decimate_fn = lambda x: x\n        r = jax.tree_util.tree_map(decimate_fn, r)\n        b = jax.tree_util.tree_map(decimate_fn, b)\n        visualizations = vis.visualize_suite(r, b.rays)\n        for k, v in visualizations.items():\n          if k == 'color':\n            v = postprocess_fn(v)\n          
summary_writer.image(f'output_{k}_{i}', v, step)\n        if not config.render_path:\n          target = postprocess_fn(b.rgb)\n          summary_writer.image(f'true_color_{i}', target, step)\n          pred = postprocess_fn(visualizations['color'])\n          residual = np.clip(pred - target + 0.5, 0, 1)\n          summary_writer.image(f'true_residual_{i}', residual, step)\n          if config.compute_normal_metrics:\n            summary_writer.image(f'true_normals_{i}', b.normals / 2. + 0.5,\n                                 step)\n\n    if (config.eval_save_output and (not config.render_path) and\n        (jax.host_id() == 0)):\n      with utils.open_file(path_fn(f'render_times_{step}.txt'), 'w') as f:\n        f.write(' '.join([str(r) for r in render_times]))\n      for name in metrics[0]:\n        with utils.open_file(path_fn(f'metric_{name}_{step}.txt'), 'w') as f:\n          f.write(' '.join([str(m[name]) for m in metrics]))\n      for name in metrics_cc[0]:\n        with utils.open_file(path_fn(f'metric_cc_{name}_{step}.txt'), 'w') as f:\n          f.write(' '.join([str(m[name]) for m in metrics_cc]))\n      if config.eval_save_ray_data:\n        for i, r, b in showcases:\n          rays = {k: v for k, v in r.items() if 'ray_' in k}\n          np.set_printoptions(threshold=sys.maxsize)\n          with utils.open_file(path_fn(f'ray_data_{step}_{i}.txt'), 'w') as f:\n            f.write(repr(rays))\n\n    # A hack that forces Jax to keep all TPUs alive until every TPU is finished.\n    x = jnp.ones([jax.local_device_count()])\n    x = jax.device_get(jax.pmap(lambda x: jax.lax.psum(x, 'i'), 'i')(x))\n    print(x)\n\n    if config.eval_only_once:\n      break\n    if config.early_exit_steps is not None:\n      num_steps = config.early_exit_steps\n    else:\n      num_steps = config.max_steps\n    if int(step) >= num_steps:\n      break\n    last_step = step\n\n\nif __name__ == '__main__':\n  with gin.config_scope('eval'):\n    app.run(main)\n"
  },
  {
    "path": "mip360/internal/camera_utils.py",
    "content": "# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Camera pose and ray generation utility functions.\"\"\"\n\nimport enum\nimport types\nfrom typing import List, Mapping, Optional, Text, Tuple, Union\n\nfrom internal import configs\nfrom internal import math\nfrom internal import stepfun\nfrom internal import utils\nimport jax.numpy as jnp\nimport numpy as np\nimport scipy\n\n_Array = Union[np.ndarray, jnp.ndarray]\n\n\ndef convert_to_ndc(origins: _Array,\n                   directions: _Array,\n                   pixtocam: _Array,\n                   near: float = 1.,\n                   xnp: types.ModuleType = np) -> Tuple[_Array, _Array]:\n  \"\"\"Converts a set of rays to normalized device coordinates (NDC).\n\n  Args:\n    origins: ndarray(float32), [..., 3], world space ray origins.\n    directions: ndarray(float32), [..., 3], world space ray directions.\n    pixtocam: ndarray(float32), [3, 3], inverse intrinsic matrix.\n    near: float, near plane along the negative z axis.\n    xnp: either numpy or jax.numpy.\n\n  Returns:\n    origins_ndc: ndarray(float32), [..., 3].\n    directions_ndc: ndarray(float32), [..., 3].\n\n  This function assumes input rays should be mapped into the NDC space for a\n  perspective projection pinhole camera, with identity extrinsic matrix (pose)\n  and intrinsic parameters defined by inputs focal, width, and height.\n\n  The near value specifies the near plane of the frustum, 
and the far plane is\n  assumed to be infinity.\n\n  The ray bundle for the identity pose camera will be remapped to parallel rays\n  within the (-1, -1, -1) to (1, 1, 1) cube. Any other ray in the original\n  world space can be remapped as long as it has dz < 0 (ray direction has a\n  negative z-coord); this allows us to share a common NDC space for \"forward\n  facing\" scenes.\n\n  Note that\n      projection(origins + t * directions)\n  will NOT be equal to\n      origins_ndc + t * directions_ndc\n  and that the directions_ndc are not unit length. Rather, directions_ndc is\n  defined such that the valid near and far planes in NDC will be 0 and 1.\n\n  See Appendix C in https://arxiv.org/abs/2003.08934 for additional details.\n  \"\"\"\n\n  # Shift ray origins to near plane, such that oz = -near.\n  # This makes the new near bound equal to 0.\n  t = -(near + origins[..., 2]) / directions[..., 2]\n  origins = origins + t[..., None] * directions\n\n  dx, dy, dz = xnp.moveaxis(directions, -1, 0)\n  ox, oy, oz = xnp.moveaxis(origins, -1, 0)\n\n  xmult = 1. / pixtocam[0, 2]  # Equal to -2. * focal / cx\n  ymult = 1. / pixtocam[1, 2]  # Equal to -2. 
* focal / cy\n\n  # Perspective projection into NDC for the t = 0 near points\n  #     origins + 0 * directions\n  origins_ndc = xnp.stack([xmult * ox / oz, ymult * oy / oz,\n                           -xnp.ones_like(oz)], axis=-1)\n\n  # Perspective projection into NDC for the t = infinity far points\n  #     origins + infinity * directions\n  # Use xnp (not np) so this also works when xnp is jax.numpy.\n  infinity_ndc = xnp.stack([xmult * dx / dz, ymult * dy / dz,\n                            xnp.ones_like(oz)],\n                           axis=-1)\n\n  # directions_ndc points from origins_ndc to infinity_ndc\n  directions_ndc = infinity_ndc - origins_ndc\n\n  return origins_ndc, directions_ndc\n\n\ndef pad_poses(p: np.ndarray) -> np.ndarray:\n  \"\"\"Pad [..., 3, 4] pose matrices with a homogeneous bottom row [0,0,0,1].\"\"\"\n  bottom = np.broadcast_to([0, 0, 0, 1.], p[..., :1, :4].shape)\n  return np.concatenate([p[..., :3, :4], bottom], axis=-2)\n\n\ndef unpad_poses(p: np.ndarray) -> np.ndarray:\n  \"\"\"Remove the homogeneous bottom row from [..., 4, 4] pose matrices.\"\"\"\n  return p[..., :3, :4]\n\n\ndef recenter_poses(poses: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:\n  \"\"\"Recenter poses around the origin.\"\"\"\n  cam2world = average_pose(poses)\n  transform = np.linalg.inv(pad_poses(cam2world))\n  poses = transform @ pad_poses(poses)\n  return unpad_poses(poses), transform\n\n\ndef average_pose(poses: np.ndarray) -> np.ndarray:\n  \"\"\"New pose using average position, z-axis, and up vector of input poses.\"\"\"\n  position = poses[:, :3, 3].mean(0)\n  z_axis = poses[:, :3, 2].mean(0)\n  up = poses[:, :3, 1].mean(0)\n  cam2world = viewmatrix(z_axis, up, position)\n  return cam2world\n\n\ndef viewmatrix(lookdir: np.ndarray, up: np.ndarray,\n               position: np.ndarray) -> np.ndarray:\n  \"\"\"Construct lookat view matrix.\"\"\"\n  vec2 = normalize(lookdir)\n  vec0 = normalize(np.cross(up, vec2))\n  vec1 = normalize(np.cross(vec2, vec0))\n  m = np.stack([vec0, vec1, vec2, position], axis=1)\n  return 
m\n\n\ndef normalize(x: np.ndarray) -> np.ndarray:\n  \"\"\"Normalization helper function.\"\"\"\n  return x / np.linalg.norm(x)\n\n\ndef focus_point_fn(poses: np.ndarray) -> np.ndarray:\n  \"\"\"Calculate nearest point to all focal axes in poses.\"\"\"\n  directions, origins = poses[:, :3, 2:3], poses[:, :3, 3:4]\n  m = np.eye(3) - directions * np.transpose(directions, [0, 2, 1])\n  mt_m = np.transpose(m, [0, 2, 1]) @ m\n  focus_pt = np.linalg.inv(mt_m.mean(0)) @ (mt_m @ origins).mean(0)[:, 0]\n  return focus_pt\n\n\n# Constants for generate_spiral_path():\nNEAR_STRETCH = .9  # Push forward near bound for forward facing render path.\nFAR_STRETCH = 5.  # Push back far bound for forward facing render path.\nFOCUS_DISTANCE = .75  # Relative weighting of near, far bounds for render path.\n\n\ndef generate_spiral_path(poses: np.ndarray,\n                         bounds: np.ndarray,\n                         n_frames: int = 120,\n                         n_rots: int = 2,\n                         zrate: float = .5) -> np.ndarray:\n  \"\"\"Calculates a forward facing spiral path for rendering.\"\"\"\n  # Find a reasonable 'focus depth' for this dataset as a weighted average\n  # of conservative near and far bounds in disparity space.\n  near_bound = bounds.min() * NEAR_STRETCH\n  far_bound = bounds.max() * FAR_STRETCH\n  # All cameras will point towards the world space point (0, 0, -focal).\n  focal = 1 / (((1 - FOCUS_DISTANCE) / near_bound + FOCUS_DISTANCE / far_bound))\n\n  # Get radii for spiral path using 90th percentile of camera positions.\n  positions = poses[:, :3, 3]\n  radii = np.percentile(np.abs(positions), 90, 0)\n  radii = np.concatenate([radii, [1.]])\n\n  # Generate poses for spiral path.\n  render_poses = []\n  cam2world = average_pose(poses)\n  up = poses[:, :3, 1].mean(0)\n  for theta in np.linspace(0., 2. 
* np.pi * n_rots, n_frames, endpoint=False):\n    t = radii * [np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]\n    position = cam2world @ t\n    lookat = cam2world @ [0, 0, -focal, 1.]\n    z_axis = position - lookat\n    render_poses.append(viewmatrix(z_axis, up, position))\n  render_poses = np.stack(render_poses, axis=0)\n  return render_poses\n\n\ndef transform_poses_pca(poses: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:\n  \"\"\"Transforms poses so principal components lie on XYZ axes.\n\n  Args:\n    poses: a (N, 3, 4) array containing the cameras' camera to world transforms.\n\n  Returns:\n    A tuple (poses, transform), with the transformed poses and the applied\n    camera_to_world transforms.\n  \"\"\"\n  t = poses[:, :3, 3]\n  t_mean = t.mean(axis=0)\n  t = t - t_mean\n\n  eigval, eigvec = np.linalg.eig(t.T @ t)\n  # Sort eigenvectors in order of largest to smallest eigenvalue.\n  inds = np.argsort(eigval)[::-1]\n  eigvec = eigvec[:, inds]\n  rot = eigvec.T\n  if np.linalg.det(rot) < 0:\n    rot = np.diag(np.array([1, 1, -1])) @ rot\n\n  transform = np.concatenate([rot, rot @ -t_mean[:, None]], -1)\n  poses_recentered = unpad_poses(transform @ pad_poses(poses))\n  transform = np.concatenate([transform, np.eye(4)[3:]], axis=0)\n\n  # Flip coordinate system if z component of y-axis is negative\n  if poses_recentered.mean(axis=0)[2, 1] < 0:\n    poses_recentered = np.diag(np.array([1, -1, -1])) @ poses_recentered\n    transform = np.diag(np.array([1, -1, -1, 1])) @ transform\n\n  # Rescale so that all camera positions lie in the [-1, 1]^3 cube.\n  scale_factor = 1. 
/ np.max(np.abs(poses_recentered[:, :3, 3]))\n  poses_recentered[:, :3, 3] *= scale_factor\n  transform = np.diag(np.array([scale_factor] * 3 + [1])) @ transform\n\n  return poses_recentered, transform\n\n\ndef generate_ellipse_path(poses: np.ndarray,\n                          n_frames: int = 120,\n                          const_speed: bool = True,\n                          z_variation: float = 0.,\n                          z_phase: float = 0.) -> np.ndarray:\n  \"\"\"Generate an elliptical render path based on the given poses.\"\"\"\n  # Calculate the focal point for the path (cameras point toward this).\n  center = focus_point_fn(poses)\n  # Path height sits at z=0 (in middle of zero-mean capture pattern).\n  offset = np.array([center[0], center[1], 0])\n\n  # Calculate scaling for ellipse axes based on input camera positions.\n  sc = np.percentile(np.abs(poses[:, :3, 3] - offset), 90, axis=0)\n  # Use ellipse that is symmetric about the focal point in xy.\n  low = -sc + offset\n  high = sc + offset\n  # Optional height variation need not be symmetric\n  z_low = np.percentile((poses[:, :3, 3]), 10, axis=0)\n  z_high = np.percentile((poses[:, :3, 3]), 90, axis=0)\n\n  def get_positions(theta):\n    # Interpolate between bounds with trig functions to get ellipse in x-y.\n    # Optionally also interpolate in z to change camera height along path.\n    return np.stack([\n        low[0] + (high - low)[0] * (np.cos(theta) * .5 + .5),\n        low[1] + (high - low)[1] * (np.sin(theta) * .5 + .5),\n        z_variation * (z_low[2] + (z_high - z_low)[2] *\n                       (np.cos(theta + 2 * np.pi * z_phase) * .5 + .5)),\n    ], -1)\n\n  theta = np.linspace(0, 2. 
* np.pi, n_frames + 1, endpoint=True)\n  positions = get_positions(theta)\n\n  if const_speed:\n    # Resample theta angles so that the velocity is closer to constant.\n    lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1)\n    theta = stepfun.sample(None, theta, np.log(lengths), n_frames + 1)\n    positions = get_positions(theta)\n\n  # Throw away duplicated last position.\n  positions = positions[:-1]\n\n  # Set path's up vector to axis closest to average of input pose up vectors.\n  avg_up = poses[:, :3, 1].mean(0)\n  avg_up = avg_up / np.linalg.norm(avg_up)\n  ind_up = np.argmax(np.abs(avg_up))\n  up = np.eye(3)[ind_up] * np.sign(avg_up[ind_up])\n\n  return np.stack([viewmatrix(p - center, up, p) for p in positions])\n\n\ndef generate_interpolated_path(poses: np.ndarray,\n                               n_interp: int,\n                               spline_degree: int = 5,\n                               smoothness: float = .03,\n                               rot_weight: float = .1):\n  \"\"\"Creates a smooth spline path between input keyframe camera poses.\n\n  Spline is calculated with poses in format (position, lookat-point, up-point).\n\n  Args:\n    poses: (n, 3, 4) array of input pose keyframes.\n    n_interp: returned path will have n_interp * (n - 1) total poses.\n    spline_degree: polynomial degree of B-spline.\n    smoothness: parameter for spline smoothing, 0 forces exact interpolation.\n    rot_weight: relative weighting of rotation/translation in spline solve.\n\n  Returns:\n    Array of new camera poses with shape (n_interp * (n - 1), 3, 4).\n  \"\"\"\n\n  def poses_to_points(poses, dist):\n    \"\"\"Converts from pose matrices to (position, lookat, up) format.\"\"\"\n    pos = poses[:, :3, -1]\n    lookat = poses[:, :3, -1] - dist * poses[:, :3, 2]\n    up = poses[:, :3, -1] + dist * poses[:, :3, 1]\n    return np.stack([pos, lookat, up], 1)\n\n  def points_to_poses(points):\n    \"\"\"Converts from (position, lookat, up) format 
to pose matrices.\"\"\"\n    return np.array([viewmatrix(p - l, u - p, p) for p, l, u in points])\n\n  def interp(points, n, k, s):\n    \"\"\"Runs multidimensional B-spline interpolation on the input points.\"\"\"\n    sh = points.shape\n    pts = np.reshape(points, (sh[0], -1))\n    k = min(k, sh[0] - 1)\n    tck, _ = scipy.interpolate.splprep(pts.T, k=k, s=s)\n    u = np.linspace(0, 1, n, endpoint=False)\n    new_points = np.array(scipy.interpolate.splev(u, tck))\n    new_points = np.reshape(new_points.T, (n, sh[1], sh[2]))\n    return new_points\n\n  points = poses_to_points(poses, dist=rot_weight)\n  new_points = interp(points,\n                      n_interp * (points.shape[0] - 1),\n                      k=spline_degree,\n                      s=smoothness)\n  return points_to_poses(new_points)\n\n\ndef interpolate_1d(x: np.ndarray,\n                   n_interp: int,\n                   spline_degree: int,\n                   smoothness: float) -> np.ndarray:\n  \"\"\"Interpolate 1d signal x (by a factor of n_interp times).\"\"\"\n  t = np.linspace(0, 1, len(x), endpoint=True)\n  tck = scipy.interpolate.splrep(t, x, s=smoothness, k=spline_degree)\n  n = n_interp * (len(x) - 1)\n  u = np.linspace(0, 1, n, endpoint=False)\n  return scipy.interpolate.splev(u, tck)\n\n\ndef create_render_spline_path(\n    config: configs.Config,\n    image_names: Union[Text, List[Text]],\n    poses: np.ndarray,\n    exposures: Optional[np.ndarray],\n) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:\n  \"\"\"Creates spline interpolation render path from subset of dataset poses.\n\n  Args:\n    config: configs.Config object.\n    image_names: either a directory of images or a text file of image names.\n    poses: [N, 3, 4] array of extrinsic camera pose matrices.\n    exposures: optional list of floating point exposure values.\n\n  Returns:\n    spline_indices: list of indices used to select spline keyframe poses.\n    render_poses: array of interpolated extrinsic camera 
poses for the path.\n    render_exposures: optional list of interpolated exposures for the path.\n  \"\"\"\n  if utils.isdir(config.render_spline_keyframes):\n    # If directory, use image filenames.\n    keyframe_names = sorted(utils.listdir(config.render_spline_keyframes))\n  else:\n    # If text file, treat each line as an image filename.\n    with utils.open_file(config.render_spline_keyframes, 'r') as fp:\n      # Decode bytes into string and split into lines.\n      keyframe_names = fp.read().decode('utf-8').splitlines()\n  # Grab poses corresponding to the image filenames.\n  spline_indices = np.array(\n      [i for i, n in enumerate(image_names) if n in keyframe_names])\n  keyframes = poses[spline_indices]\n  render_poses = generate_interpolated_path(\n      keyframes,\n      n_interp=config.render_spline_n_interp,\n      spline_degree=config.render_spline_degree,\n      smoothness=config.render_spline_smoothness,\n      rot_weight=.1)\n  if config.render_spline_interpolate_exposure:\n    if exposures is None:\n      raise ValueError('config.render_spline_interpolate_exposure is True but '\n                       'create_render_spline_path() was passed exposures=None.')\n    # Interpolate per-frame exposure value.\n    log_exposure = np.log(exposures[spline_indices])\n    # Use aggressive smoothing for exposure interpolation to avoid flickering.\n    log_exposure_interp = interpolate_1d(\n        log_exposure,\n        config.render_spline_n_interp,\n        spline_degree=5,\n        smoothness=20)\n    render_exposures = np.exp(log_exposure_interp)\n  else:\n    render_exposures = None\n  return spline_indices, render_poses, render_exposures\n\n\ndef intrinsic_matrix(fx: float,\n                     fy: float,\n                     cx: float,\n                     cy: float,\n                     xnp: types.ModuleType = np) -> _Array:\n  \"\"\"Intrinsic matrix for a pinhole camera in OpenCV coordinate system.\"\"\"\n  return xnp.array([\n      [fx, 0, 
cx],\n      [0, fy, cy],\n      [0, 0, 1.],\n  ])\n\n\ndef get_pixtocam(focal: float,\n                 width: float,\n                 height: float,\n                 xnp: types.ModuleType = np) -> _Array:\n  \"\"\"Inverse intrinsic matrix for a perfect pinhole camera.\"\"\"\n  camtopix = intrinsic_matrix(focal, focal, width * .5, height * .5, xnp)\n  return xnp.linalg.inv(camtopix)\n\n\ndef pixel_coordinates(width: int,\n                      height: int,\n                      xnp: types.ModuleType = np) -> Tuple[_Array, _Array]:\n  \"\"\"Tuple of the x and y integer coordinates for a grid of pixels.\"\"\"\n  return xnp.meshgrid(xnp.arange(width), xnp.arange(height), indexing='xy')\n\n\ndef _compute_residual_and_jacobian(\n    x: _Array,\n    y: _Array,\n    xd: _Array,\n    yd: _Array,\n    k1: float = 0.0,\n    k2: float = 0.0,\n    k3: float = 0.0,\n    k4: float = 0.0,\n    p1: float = 0.0,\n    p2: float = 0.0,\n) -> Tuple[_Array, _Array, _Array, _Array, _Array, _Array]:\n  \"\"\"Auxiliary function of radial_and_tangential_undistort().\"\"\"\n  # Adapted from https://github.com/google/nerfies/blob/main/nerfies/camera.py\n  # let r(x, y) = x^2 + y^2;\n  #     d(x, y) = 1 + k1 * r(x, y) + k2 * r(x, y) ^2 + k3 * r(x, y)^3 +\n  #                   k4 * r(x, y)^4;\n  r = x * x + y * y\n  d = 1.0 + r * (k1 + r * (k2 + r * (k3  + r * k4)))\n\n  # The perfect projection is:\n  # xd = x * d(x, y) + 2 * p1 * x * y + p2 * (r(x, y) + 2 * x^2);\n  # yd = y * d(x, y) + 2 * p2 * x * y + p1 * (r(x, y) + 2 * y^2);\n  #\n  # Let's define\n  #\n  # fx(x, y) = x * d(x, y) + 2 * p1 * x * y + p2 * (r(x, y) + 2 * x^2) - xd;\n  # fy(x, y) = y * d(x, y) + 2 * p2 * x * y + p1 * (r(x, y) + 2 * y^2) - yd;\n  #\n  # We are looking for a solution that satisfies\n  # fx(x, y) = fy(x, y) = 0;\n  fx = d * x + 2 * p1 * x * y + p2 * (r + 2 * x * x) - xd\n  fy = d * y + 2 * p2 * x * y + p1 * (r + 2 * y * y) - yd\n\n  # Compute derivative of d over [x, y]\n  d_r = (k1 + r * (2.0 * k2 + r * 
(3.0 * k3 + r * 4.0 * k4)))\n  d_x = 2.0 * x * d_r\n  d_y = 2.0 * y * d_r\n\n  # Compute derivative of fx over x and y.\n  fx_x = d + d_x * x + 2.0 * p1 * y + 6.0 * p2 * x\n  fx_y = d_y * x + 2.0 * p1 * x + 2.0 * p2 * y\n\n  # Compute derivative of fy over x and y.\n  fy_x = d_x * y + 2.0 * p2 * y + 2.0 * p1 * x\n  fy_y = d + d_y * y + 2.0 * p2 * x + 6.0 * p1 * y\n\n  return fx, fy, fx_x, fx_y, fy_x, fy_y\n\n\ndef _radial_and_tangential_undistort(\n    xd: _Array,\n    yd: _Array,\n    k1: float = 0,\n    k2: float = 0,\n    k3: float = 0,\n    k4: float = 0,\n    p1: float = 0,\n    p2: float = 0,\n    eps: float = 1e-9,\n    max_iterations=10,\n    xnp: types.ModuleType = np) -> Tuple[_Array, _Array]:\n  \"\"\"Computes undistorted (x, y) from (xd, yd).\"\"\"\n  # From https://github.com/google/nerfies/blob/main/nerfies/camera.py\n  # Initialize from the distorted point.\n  x = xnp.copy(xd)\n  y = xnp.copy(yd)\n\n  for _ in range(max_iterations):\n    fx, fy, fx_x, fx_y, fy_x, fy_y = _compute_residual_and_jacobian(\n        x=x, y=y, xd=xd, yd=yd, k1=k1, k2=k2, k3=k3, k4=k4, p1=p1, p2=p2)\n    denominator = fy_x * fx_y - fx_x * fy_y\n    x_numerator = fx * fy_y - fy * fx_y\n    y_numerator = fy * fx_x - fx * fy_x\n    step_x = xnp.where(\n        xnp.abs(denominator) > eps, x_numerator / denominator,\n        xnp.zeros_like(denominator))\n    step_y = xnp.where(\n        xnp.abs(denominator) > eps, y_numerator / denominator,\n        xnp.zeros_like(denominator))\n\n    x = x + step_x\n    y = y + step_y\n\n  return x, y\n\n\nclass ProjectionType(enum.Enum):\n  \"\"\"Camera projection type (standard perspective pinhole or fisheye model).\"\"\"\n  PERSPECTIVE = 'perspective'\n  FISHEYE = 'fisheye'\n\n\ndef pixels_to_rays(\n    pix_x_int: _Array,\n    pix_y_int: _Array,\n    pixtocams: _Array,\n    camtoworlds: _Array,\n    distortion_params: Optional[Mapping[str, float]] = None,\n    pixtocam_ndc: Optional[_Array] = None,\n    camtype: ProjectionType = 
ProjectionType.PERSPECTIVE,\n    xnp: types.ModuleType = np,\n) -> Tuple[_Array, _Array, _Array, _Array, _Array]:\n  \"\"\"Calculates rays given pixel coordinates, intrinsics, and extrinsics.\n\n  Given 2D pixel coordinates pix_x_int, pix_y_int for cameras with\n  inverse intrinsics pixtocams and extrinsics camtoworlds (and optional\n  distortion coefficients distortion_params and NDC space projection matrix\n  pixtocam_ndc), computes the corresponding 3D camera rays.\n\n  Vectorized over the leading dimensions of the first four arguments.\n\n  Args:\n    pix_x_int: int array, shape SH, x coordinates of image pixels.\n    pix_y_int: int array, shape SH, y coordinates of image pixels.\n    pixtocams: float array, broadcastable to SH + [3, 3], inverse intrinsics.\n    camtoworlds: float array, broadcastable to SH + [3, 4], camera extrinsics.\n    distortion_params: dict of floats, optional camera distortion parameters.\n    pixtocam_ndc: float array, [3, 3], optional inverse intrinsics for NDC.\n    camtype: camera_utils.ProjectionType, fisheye or perspective camera.\n    xnp: either numpy or jax.numpy.\n\n  Returns:\n    origins: float array, shape SH + [3], ray origin points.\n    directions: float array, shape SH + [3], ray direction vectors.\n    viewdirs: float array, shape SH + [3], normalized ray direction vectors.\n    radii: float array, shape SH + [1], ray differential radii.\n    imageplane: float array, shape SH + [2], xy coordinates on the image plane.\n      If the image plane is at world space distance 1 from the pinhole, then\n      imageplane will be the xy coordinates of a pixel in that space (so the\n      camera ray direction at the origin would be (x, y, -1) in OpenGL coords).\n  \"\"\"\n  # Must add half pixel offset to shoot rays through pixel centers.\n  def pix_to_dir(x, y):\n    return xnp.stack([x + .5, y + .5, xnp.ones_like(x)], axis=-1)\n  # We need the dx and dy rays to calculate ray radii for mip-NeRF cones.\n  pixel_dirs_stacked = 
xnp.stack([\n      pix_to_dir(pix_x_int, pix_y_int),\n      pix_to_dir(pix_x_int + 1, pix_y_int),\n      pix_to_dir(pix_x_int, pix_y_int + 1)\n  ], axis=0)\n\n  # For jax, need to specify high-precision matmul.\n  matmul = math.matmul if xnp == jnp else xnp.matmul\n  mat_vec_mul = lambda A, b: matmul(A, b[..., None])[..., 0]\n\n  # Apply inverse intrinsic matrices.\n  camera_dirs_stacked = mat_vec_mul(pixtocams, pixel_dirs_stacked)\n\n  if distortion_params is not None:\n    # Correct for distortion.\n    x, y = _radial_and_tangential_undistort(\n        camera_dirs_stacked[..., 0],\n        camera_dirs_stacked[..., 1],\n        **distortion_params,\n        xnp=xnp)\n    camera_dirs_stacked = xnp.stack([x, y, xnp.ones_like(x)], -1)\n\n  if camtype == ProjectionType.FISHEYE:\n    theta = xnp.sqrt(xnp.sum(xnp.square(camera_dirs_stacked[..., :2]), axis=-1))\n    theta = xnp.minimum(xnp.pi, theta)\n\n    sin_theta_over_theta = xnp.sin(theta) / theta\n    camera_dirs_stacked = xnp.stack([\n        camera_dirs_stacked[..., 0] * sin_theta_over_theta,\n        camera_dirs_stacked[..., 1] * sin_theta_over_theta,\n        xnp.cos(theta),\n    ], axis=-1)\n\n  # Flip from OpenCV to OpenGL coordinate system.\n  camera_dirs_stacked = matmul(camera_dirs_stacked,\n                               xnp.diag(xnp.array([1., -1., -1.])))\n\n  # Extract 2D image plane (x, y) coordinates.\n  imageplane = camera_dirs_stacked[0, ..., :2]\n\n  # Apply camera rotation matrices.\n  directions_stacked = mat_vec_mul(camtoworlds[..., :3, :3],\n                                   camera_dirs_stacked)\n  # Extract the offset rays.\n  directions, dx, dy = directions_stacked\n\n  origins = xnp.broadcast_to(camtoworlds[..., :3, -1], directions.shape)\n  viewdirs = directions / xnp.linalg.norm(directions, axis=-1, keepdims=True)\n\n  if pixtocam_ndc is None:\n    # Distance from each unit-norm direction vector to its neighbors.\n    dx_norm = xnp.linalg.norm(dx - directions, axis=-1)\n    dy_norm = 
xnp.linalg.norm(dy - directions, axis=-1)\n\n  else:\n    # Convert ray origins and directions into projective NDC space.\n    origins_dx, _ = convert_to_ndc(origins, dx, pixtocam_ndc)\n    origins_dy, _ = convert_to_ndc(origins, dy, pixtocam_ndc)\n    origins, directions = convert_to_ndc(origins, directions, pixtocam_ndc)\n\n    # In NDC space, we use the offset between origins instead of directions.\n    dx_norm = xnp.linalg.norm(origins_dx - origins, axis=-1)\n    dy_norm = xnp.linalg.norm(origins_dy - origins, axis=-1)\n\n  # Cut the distance in half, multiply it to match the variance of a uniform\n  # distribution the size of a pixel (1/12, see the original mipnerf paper).\n  radii = (0.5 * (dx_norm + dy_norm))[..., None] * 2 / xnp.sqrt(12)\n\n  return origins, directions, viewdirs, radii, imageplane\n\n\ndef cast_ray_batch(\n    cameras: Tuple[_Array, ...],\n    pixels: utils.Pixels,\n    camtype: ProjectionType = ProjectionType.PERSPECTIVE,\n    xnp: types.ModuleType = np) -> utils.Rays:\n  \"\"\"Maps from input cameras and Pixel batch to output Ray batch.\n\n  `cameras` is a Tuple of four sets of camera parameters.\n    pixtocams: 1 or N stacked [3, 3] inverse intrinsic matrices.\n    camtoworlds: 1 or N stacked [3, 4] extrinsic pose matrices.\n    distortion_params: optional, dict[str, float] containing pinhole model\n      distortion parameters.\n    pixtocam_ndc: optional, [3, 3] inverse intrinsic matrix for mapping to NDC.\n\n  Args:\n    cameras: described above.\n    pixels: integer pixel coordinates and camera indices, plus ray metadata.\n      These fields can be an arbitrary batch shape.\n    camtype: camera_utils.ProjectionType, fisheye or perspective camera.\n    xnp: either numpy or jax.numpy.\n\n  Returns:\n    rays: Rays dataclass with computed 3D world space ray data.\n  \"\"\"\n  pixtocams, camtoworlds, distortion_params, pixtocam_ndc = cameras\n\n  # pixels.cam_idx has shape [..., 1], remove this hanging dimension.\n  cam_idx = 
pixels.cam_idx[..., 0]\n  batch_index = lambda arr: arr if arr.ndim == 2 else arr[cam_idx]\n\n  # Compute rays from pixel coordinates.\n  origins, directions, viewdirs, radii, imageplane = pixels_to_rays(\n      pixels.pix_x_int,\n      pixels.pix_y_int,\n      batch_index(pixtocams),\n      batch_index(camtoworlds),\n      distortion_params=distortion_params,\n      pixtocam_ndc=pixtocam_ndc,\n      camtype=camtype,\n      xnp=xnp)\n\n  # Create Rays data structure.\n  return utils.Rays(\n      origins=origins,\n      directions=directions,\n      viewdirs=viewdirs,\n      radii=radii,\n      imageplane=imageplane,\n      lossmult=pixels.lossmult,\n      near=pixels.near,\n      far=pixels.far,\n      cam_idx=pixels.cam_idx,\n      exposure_idx=pixels.exposure_idx,\n      exposure_values=pixels.exposure_values,\n  )\n\n\ndef cast_pinhole_rays(camtoworld: _Array,\n                      height: int,\n                      width: int,\n                      focal: float,\n                      near: float,\n                      far: float,\n                      xnp: types.ModuleType) -> utils.Rays:\n  \"\"\"Wrapper for generating a pinhole camera ray batch (w/o distortion).\"\"\"\n\n  pix_x_int, pix_y_int = pixel_coordinates(width, height, xnp=xnp)\n  pixtocam = get_pixtocam(focal, width, height, xnp=xnp)\n\n  ray_args = pixels_to_rays(pix_x_int, pix_y_int, pixtocam, camtoworld, xnp=xnp)\n\n  broadcast_scalar = lambda x: xnp.broadcast_to(x, pix_x_int.shape)[..., None]\n  ray_kwargs = {\n      'lossmult': broadcast_scalar(1.),\n      'near': broadcast_scalar(near),\n      'far': broadcast_scalar(far),\n      'cam_idx': broadcast_scalar(0),\n  }\n\n  return utils.Rays(*ray_args, **ray_kwargs)\n\n\ndef cast_spherical_rays(camtoworld: _Array,\n                        height: int,\n                        width: int,\n                        near: float,\n                        far: float,\n                        xnp: types.ModuleType) -> utils.Rays:\n  
\"\"\"Generates a spherical camera ray batch.\"\"\"\n\n  theta_vals = xnp.linspace(0, 2 * xnp.pi, width + 1)\n  phi_vals = xnp.linspace(0, xnp.pi, height + 1)\n  theta, phi = xnp.meshgrid(theta_vals, phi_vals, indexing='xy')\n\n  # Spherical coordinates in camera reference frame (y is up).\n  directions = xnp.stack([\n      -xnp.sin(phi) * xnp.sin(theta),\n      xnp.cos(phi),\n      xnp.sin(phi) * xnp.cos(theta),\n  ],\n                         axis=-1)\n\n  # For jax, need to specify high-precision matmul.\n  matmul = math.matmul if xnp == jnp else xnp.matmul\n  directions = matmul(camtoworld[:3, :3], directions[..., None])[..., 0]\n\n  dy = xnp.diff(directions[:, :-1], axis=0)\n  dx = xnp.diff(directions[:-1, :], axis=1)\n  directions = directions[:-1, :-1]\n  viewdirs = directions\n\n  origins = xnp.broadcast_to(camtoworld[:3, -1], directions.shape)\n\n  dx_norm = xnp.linalg.norm(dx, axis=-1)\n  dy_norm = xnp.linalg.norm(dy, axis=-1)\n  radii = (0.5 * (dx_norm + dy_norm))[..., None] * 2 / xnp.sqrt(12)\n\n  imageplane = xnp.zeros_like(directions[..., :2])\n\n  ray_args = (origins, directions, viewdirs, radii, imageplane)\n\n  broadcast_scalar = lambda x: xnp.broadcast_to(x, radii.shape[:-1])[..., None]\n  ray_kwargs = {\n      'lossmult': broadcast_scalar(1.),\n      'near': broadcast_scalar(near),\n      'far': broadcast_scalar(far),\n      'cam_idx': broadcast_scalar(0),\n  }\n\n  return utils.Rays(*ray_args, **ray_kwargs)\n"
  },
  {
    "path": "mip360/internal/configs.py",
    "content": "# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Utility functions for handling configurations.\"\"\"\n\nimport dataclasses\nfrom typing import Any, Callable, Optional, Tuple\n\nfrom absl import flags\nfrom flax.core import FrozenDict\nimport gin\nfrom internal import utils\nimport jax\nimport jax.numpy as jnp\n\ngin.add_config_file_search_path('experimental/users/barron/mipnerf360/')\n\nconfigurables = {\n    'jnp': [jnp.reciprocal, jnp.log, jnp.log1p, jnp.exp, jnp.sqrt, jnp.square],\n    'jax.nn': [jax.nn.relu, jax.nn.softplus, jax.nn.silu],\n    'jax.nn.initializers.he_normal': [jax.nn.initializers.he_normal()],\n    'jax.nn.initializers.he_uniform': [jax.nn.initializers.he_uniform()],\n    'jax.nn.initializers.glorot_normal': [jax.nn.initializers.glorot_normal()],\n    'jax.nn.initializers.glorot_uniform': [\n        jax.nn.initializers.glorot_uniform()\n    ],\n}\n\nfor module, configurables in configurables.items():\n  for configurable in configurables:\n    gin.config.external_configurable(configurable, module=module)\n\n\n@gin.configurable()\n@dataclasses.dataclass\nclass Config:\n  \"\"\"Configuration flags for everything.\"\"\"\n  dataset_loader: str = 'llff'  # The type of dataset loader to use.\n  batching: str = 'all_images'  # Batch composition, [single_image, all_images].\n  batch_size: int = 16384  # The number of rays/pixels in each batch.\n  patch_size: int = 1  # Resolution of patches 
sampled for training batches.\n  factor: int = 0  # The downsample factor of images, 0 for no downsampling.\n  load_alphabetical: bool = True  # Load images in COLMAP vs alphabetical\n  # ordering (affects heldout test set).\n  forward_facing: bool = False  # Set to True for forward-facing LLFF captures.\n  render_path: bool = False  # If True, render a path. Used only by LLFF.\n  llffhold: int = 8  # Use every Nth image for the test set. Used only by LLFF.\n  # If true, use all input images for training.\n  llff_use_all_images_for_training: bool = False\n  use_tiffs: bool = False  # If True, use 32-bit TIFFs. Used only by Blender.\n  compute_disp_metrics: bool = False  # If True, load and compute disparity MSE.\n  compute_normal_metrics: bool = False  # If True, load and compute normal MAE.\n  gc_every: int = 10000  # The number of steps between garbage collections.\n  disable_multiscale_loss: bool = False  # If True, disable multiscale loss.\n  randomized: bool = True  # Use randomized stratified sampling.\n  near: float = 2.  # Near plane distance.\n  far: float = 6.  
# Far plane distance.\n  checkpoint_dir: Optional[str] = None  # Where to log checkpoints.\n  render_dir: Optional[str] = None  # Output rendering directory.\n  data_dir: Optional[str] = None  # Input data directory.\n  vocab_tree_path: Optional[str] = None  # Path to vocab tree for COLMAP.\n  render_chunk_size: int = 16384  # Chunk size for whole-image renderings.\n  num_showcase_images: int = 5  # The number of test-set images to showcase.\n  deterministic_showcase: bool = True  # If True, showcase the same images.\n  vis_num_rays: int = 16  # The number of rays to visualize.\n  # Decimate images for tensorboard (ie, x[::d, ::d]) to conserve memory usage.\n  vis_decimate: int = 0\n\n  # Only used by train.py:\n  max_steps: int = 250000  # The number of optimization steps.\n  early_exit_steps: Optional[int] = None  # Early stopping, for debugging.\n  checkpoint_every: int = 25000  # The number of steps to save a checkpoint.\n  print_every: int = 100  # The number of steps between reports to tensorboard.\n  train_render_every: int = 5000  # Steps between test set renders when training\n  cast_rays_in_train_step: bool = False  # If True, compute rays in train step.\n  data_loss_type: str = 'charb'  # What kind of loss to use ('mse' or 'charb').\n  charb_padding: float = 0.001  # The padding used for Charbonnier loss.\n  data_loss_mult: float = 1.0  # Mult for the finest data term in the loss.\n  data_coarse_loss_mult: float = 0.  # Multiplier for the coarser data terms.\n  interlevel_loss_mult: float = 1.0  # Mult. for the loss on the proposal MLP.\n  orientation_loss_mult: float = 0.0  # Multiplier on the orientation loss.\n  orientation_coarse_loss_mult: float = 0.0  # Coarser orientation loss weights.\n  # What that loss is imposed on, options are 'normals' or 'normals_pred'.\n  orientation_loss_target: str = 'normals_pred'\n  predicted_normal_loss_mult: float = 0.0  # Mult. on the predicted normal loss.\n  # Mult. 
on the coarser predicted normal loss.\n  predicted_normal_coarse_loss_mult: float = 0.0\n  weight_decay_mults: FrozenDict[str, Any] = FrozenDict({})  # Weight decays.\n  # An example that regularizes the NeRF and the first layer of the prop MLP:\n  #   weight_decay_mults = {\n  #       'NerfMLP_0': 0.00001,\n  #       'PropMLP_0/Dense_0': 0.001,\n  #   }\n  # Any model parameter that isn't specified gets a mult of 0. See the\n  # train_weight_l2_* parameters in TensorBoard to know what can be regularized.\n\n  lr_init: float = 0.002  # The initial learning rate.\n  lr_final: float = 0.00002  # The final learning rate.\n  lr_delay_steps: int = 512  # The number of \"warmup\" learning steps.\n  lr_delay_mult: float = 0.01  # How severe the \"warmup\" should be.\n  adam_beta1: float = 0.9  # Adam's beta1 hyperparameter.\n  adam_beta2: float = 0.999  # Adam's beta2 hyperparameter.\n  adam_eps: float = 1e-6  # Adam's epsilon hyperparameter.\n  grad_max_norm: float = 0.001  # Gradient clipping magnitude, disabled if == 0.\n  grad_max_val: float = 0.  
# Gradient clipping value, disabled if == 0.\n  distortion_loss_mult: float = 0.01  # Multiplier on the distortion loss.\n\n  # Only used by eval.py:\n  eval_only_once: bool = True  # If True evaluate the model only once, ow loop.\n  eval_save_output: bool = True  # If True save predicted images to disk.\n  eval_save_ray_data: bool = False  # If True save individual ray traces.\n  eval_render_interval: int = 1  # The interval between images saved to disk.\n  eval_dataset_limit: int = jnp.iinfo(jnp.int32).max  # Num test images to eval.\n  eval_one: Optional[tuple] = None  # eval one image or one ray\n  eval_quantize_metrics: bool = True  # If True, run metrics on 8-bit images.\n  eval_crop_borders: int = 0  # Ignore c border pixels in eval (x[c:-c, c:-c]).\n\n  # Only used by render.py\n  render_video_fps: int = 60  # Framerate in frames-per-second.\n  render_video_crf: int = 18  # Constant rate factor for ffmpeg video quality.\n  render_path_frames: int = 120  # Number of frames in render path.\n  z_variation: float = 0.  # How much height variation in render path.\n  z_phase: float = 0.  
# Phase offset for height variation in render path.\n  render_dist_percentile: float = 0.5  # How much to trim from near/far planes.\n  render_dist_curve_fn: Callable[..., Any] = jnp.log  # How depth is curved.\n  render_path_file: Optional[str] = None  # Numpy render pose file to load.\n  render_job_id: int = 0  # Render job id.\n  render_num_jobs: int = 1  # Total number of render jobs.\n  render_resolution: Optional[Tuple[int, int]] = None  # Render resolution, as\n  # (width, height).\n  render_focal: Optional[float] = None  # Render focal length.\n  render_camtype: Optional[str] = None  # 'perspective', 'fisheye', or 'pano'.\n  render_spherical: bool = False  # Render spherical 360 panoramas.\n  render_save_async: bool = True  # Save to CNS using a separate thread.\n\n  render_spline_keyframes: Optional[str] = None  # Text file containing names of\n  # images to be used as spline\n  # keyframes, OR directory\n  # containing those images.\n  render_spline_n_interp: int = 30  # Num. frames to interpolate per keyframe.\n  render_spline_degree: int = 5  # Polynomial degree of B-spline interpolation.\n  render_spline_smoothness: float = .03  # B-spline smoothing factor, 0 for\n  # exact interpolation of keyframes.\n  # Interpolate per-frame exposure value from spline keyframes.\n  render_spline_interpolate_exposure: bool = False\n\n  # Flags for raw datasets.\n  rawnerf_mode: bool = False  # Load raw images and train in raw color space.\n  exposure_percentile: float = 97.  
# Image percentile to expose as white.\n  num_border_pixels_to_mask: int = 0  # During training, discard N-pixel border\n  # around each input image.\n  apply_bayer_mask: bool = False  # During training, apply Bayer mosaic mask.\n  autoexpose_renders: bool = False  # During rendering, autoexpose each image.\n  # For raw test scenes, use affine raw-space color correction.\n  eval_raw_affine_cc: bool = False\n\n\ndef define_common_flags():\n  # Define the flags used by both train.py and eval.py\n  flags.DEFINE_string('mode', None, 'Required by GINXM, not used.')\n  flags.DEFINE_string('base_folder', None, 'Required by GINXM, not used.')\n  flags.DEFINE_multi_string('gin_bindings', None, 'Gin parameter bindings.')\n  flags.DEFINE_multi_string('gin_configs', None, 'Gin config files.')\n\n\ndef load_config(save_config=True):\n  \"\"\"Load the config, and optionally checkpoint it.\"\"\"\n  gin.parse_config_files_and_bindings(\n      flags.FLAGS.gin_configs, flags.FLAGS.gin_bindings, skip_unknown=True)\n  config = Config()\n  if save_config and jax.host_id() == 0:\n    utils.makedirs(config.checkpoint_dir)\n    with utils.open_file(config.checkpoint_dir + '/config.gin', 'w') as f:\n      f.write(gin.config_str())\n  return config\n"
  },
  {
    "path": "mip360/internal/coord.py",
    "content": "# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Tools for manipulating coordinate spaces and distances along rays.\"\"\"\n\nfrom internal import math\nimport jax\nimport jax.numpy as jnp\n\n\ndef contract(x):\n  \"\"\"Contracts points towards the origin (Eq 10 of arxiv.org/abs/2111.12077).\"\"\"\n  eps = jnp.finfo(jnp.float32).eps\n  # Clamping to eps prevents non-finite gradients when x == 0.\n  x_mag_sq = jnp.maximum(eps, jnp.sum(x**2, axis=-1, keepdims=True))\n  z = jnp.where(x_mag_sq <= 1, x, ((2 * jnp.sqrt(x_mag_sq) - 1) / x_mag_sq) * x)\n  return z\n\n\ndef inv_contract(z):\n  \"\"\"The inverse of contract().\"\"\"\n  eps = jnp.finfo(jnp.float32).eps\n  # Clamping to eps prevents non-finite gradients when z == 0.\n  z_mag_sq = jnp.maximum(eps, jnp.sum(z**2, axis=-1, keepdims=True))\n  x = jnp.where(z_mag_sq <= 1, z, z / (2 * jnp.sqrt(z_mag_sq) - z_mag_sq))\n  return x\n\n\ndef track_linearize(fn, mean, cov):\n  \"\"\"Apply function `fn` to a set of means and covariances, ala a Kalman filter.\n\n  We can analytically transform a Gaussian parameterized by `mean` and `cov`\n  with a function `fn` by linearizing `fn` around `mean`, and taking advantage\n  of the fact that Covar[Ax + y] = A(Covar[x])A^T (see\n  https://cs.nyu.edu/~roweis/notes/gaussid.pdf for details).\n\n  Args:\n    fn: the function applied to the Gaussians parameterized by (mean, cov).\n    mean: a tensor of means, where the last axis is 
the dimension.\n    cov: a tensor of covariances, where the last two axes are the dimensions.\n\n  Returns:\n    fn_mean: the transformed means.\n    fn_cov: the transformed covariances.\n  \"\"\"\n  if (len(mean.shape) + 1) != len(cov.shape):\n    raise ValueError('cov must be non-diagonal')\n  fn_mean, lin_fn = jax.linearize(fn, mean)\n  fn_cov = jax.vmap(lin_fn, -1, -2)(jax.vmap(lin_fn, -1, -2)(cov))\n  return fn_mean, fn_cov\n\n\ndef construct_ray_warps(fn, t_near, t_far):\n  \"\"\"Construct a bijection between metric distances and normalized distances.\n\n  See the text around Equation 11 in https://arxiv.org/abs/2111.12077 for a\n  detailed explanation.\n\n  Args:\n    fn: the function to apply to ray distances (None, 'piecewise', or a jnp\n      function whose __name__ is a key of inv_mapping below).\n    t_near: a tensor of near-plane distances.\n    t_far: a tensor of far-plane distances.\n\n  Returns:\n    t_to_s: a function that maps distances to normalized distances in [0, 1].\n    s_to_t: the inverse of t_to_s.\n  \"\"\"\n  if fn is None:\n    fn_fwd = lambda x: x\n    fn_inv = lambda x: x\n  elif fn == 'piecewise':\n    # Piecewise spacing combining identity and 1/x functions to allow t_near=0.\n    fn_fwd = lambda x: jnp.where(x < 1, .5 * x, 1 - .5 / x)\n    fn_inv = lambda x: jnp.where(x < .5, 2 * x, .5 / (1 - x))\n  else:\n    inv_mapping = {\n        'reciprocal': jnp.reciprocal,\n        'log': jnp.exp,\n        'exp': jnp.log,\n        'sqrt': jnp.square,\n        'square': jnp.sqrt\n    }\n    fn_fwd = fn\n    fn_inv = inv_mapping[fn.__name__]\n\n  s_near, s_far = [fn_fwd(x) for x in (t_near, t_far)]\n  t_to_s = lambda t: (fn_fwd(t) - s_near) / (s_far - s_near)\n  s_to_t = lambda s: fn_inv(s * s_far + (1 - s) * s_near)\n  return t_to_s, s_to_t\n\n\ndef expected_sin(mean, var):\n  \"\"\"Compute the mean of sin(x), x ~ N(mean, var).\"\"\"\n  return jnp.exp(-0.5 * var) * math.safe_sin(mean)  # large var -> small value.\n\n\ndef integrated_pos_enc(mean, var, min_deg, max_deg):\n  \"\"\"Encode `x` with sinusoids scaled by 2^[min_deg, max_deg).\n\n 
 Args:\n    mean: tensor, the mean coordinates to be encoded\n    var: tensor, the variance of the coordinates to be encoded.\n    min_deg: int, the min degree of the encoding.\n    max_deg: int, the max degree of the encoding.\n\n  Returns:\n    encoded: jnp.ndarray, encoded variables.\n  \"\"\"\n  scales = 2**jnp.arange(min_deg, max_deg)\n  shape = mean.shape[:-1] + (-1,)\n  scaled_mean = jnp.reshape(mean[..., None, :] * scales[:, None], shape)\n  scaled_var = jnp.reshape(var[..., None, :] * scales[:, None]**2, shape)\n\n  return expected_sin(\n      jnp.concatenate([scaled_mean, scaled_mean + 0.5 * jnp.pi], axis=-1),\n      jnp.concatenate([scaled_var] * 2, axis=-1))\n\n\ndef lift_and_diagonalize(mean, cov, basis):\n  \"\"\"Project `mean` and `cov` onto basis and diagonalize the projected cov.\"\"\"\n  fn_mean = math.matmul(mean, basis)\n  fn_cov_diag = jnp.sum(basis * math.matmul(cov, basis), axis=-2)\n  return fn_mean, fn_cov_diag\n\n\ndef pos_enc(x, min_deg, max_deg, append_identity=True):\n  \"\"\"The positional encoding used by the original NeRF paper.\"\"\"\n  scales = 2**jnp.arange(min_deg, max_deg)\n  shape = x.shape[:-1] + (-1,)\n  scaled_x = jnp.reshape((x[..., None, :] * scales[:, None]), shape)\n  # Note that we're not using safe_sin, unlike IPE.\n  four_feat = jnp.sin(\n      jnp.concatenate([scaled_x, scaled_x + 0.5 * jnp.pi], axis=-1))\n  if append_identity:\n    return jnp.concatenate([x] + [four_feat], axis=-1)\n  else:\n    return four_feat\n"
  },
  {
    "path": "mip360/internal/datasets.py",
    "content": "# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Different datasets implementation plus a general port for all the datasets.\"\"\"\n\nimport abc\nimport copy\nimport json\nimport os\nfrom os import path\nimport queue\nimport threading\nfrom typing import Mapping, Optional, Sequence, Text, Tuple, Union\n\nimport cv2\nfrom internal import camera_utils\nfrom internal import configs\nfrom internal import image as lib_image\nfrom internal import raw_utils\nfrom internal import utils\nimport jax\nimport numpy as np\nfrom PIL import Image\n\n# This is ugly, but it works.\nimport sys\nsys.path.insert(0,'internal/pycolmap')\nsys.path.insert(0,'internal/pycolmap/pycolmap')\nimport pycolmap\n\n\ndef load_dataset(split, train_dir, config):\n  \"\"\"Loads a split of a dataset using the data_loader specified by `config`.\"\"\"\n  dataset_dict = {\n      'blender': Blender,\n      'llff': LLFF,\n      'tat_nerfpp': TanksAndTemplesNerfPP,\n      'tat_fvs': TanksAndTemplesFVS,\n      'dtu': DTU,\n  }\n  return dataset_dict[config.dataset_loader](split, train_dir, config)\n\n\nclass NeRFSceneManager(pycolmap.SceneManager):\n  \"\"\"COLMAP pose loader.\n\n  Minor NeRF-specific extension to the third_party Python COLMAP loader:\n  google3/third_party/py/pycolmap/scene_manager.py\n  \"\"\"\n\n  def process(\n      self\n  ) -> Tuple[Sequence[Text], np.ndarray, np.ndarray, Optional[Mapping[\n      Text, float]], 
camera_utils.ProjectionType]:\n    \"\"\"Applies NeRF-specific postprocessing to the loaded pose data.\n\n    Returns:\n      a tuple [image_names, poses, pixtocam, distortion_params, camtype].\n      image_names: contains only the basenames of the images.\n      poses: [N, 3, 4] array containing the camera to world matrices.\n      pixtocam: [3, 3] array containing the pixel to camera space matrix.\n      distortion_params: mapping of distortion param name to distortion\n        parameters, or None for pinhole cameras. Cameras share intrinsics.\n        Valid keys are k1, k2, k3, k4, p1 and p2.\n      camtype: the camera_utils.ProjectionType of the shared camera model.\n    \"\"\"\n\n    self.load_cameras()\n    self.load_images()\n    # self.load_points3D()  # For now, we do not need the point cloud data.\n\n    # Assume shared intrinsics between all cameras.\n    cam = self.cameras[1]\n\n    # Extract focal lengths and principal point parameters.\n    fx, fy, cx, cy = cam.fx, cam.fy, cam.cx, cam.cy\n    pixtocam = np.linalg.inv(camera_utils.intrinsic_matrix(fx, fy, cx, cy))\n\n    # Extract extrinsic matrices in world-to-camera format.\n    imdata = self.images\n    w2c_mats = []\n    bottom = np.array([0, 0, 0, 1]).reshape(1, 4)\n    for k in imdata:\n      im = imdata[k]\n      rot = im.R()\n      trans = im.tvec.reshape(3, 1)\n      w2c = np.concatenate([np.concatenate([rot, trans], 1), bottom], axis=0)\n      w2c_mats.append(w2c)\n    w2c_mats = np.stack(w2c_mats, axis=0)\n\n    # Convert extrinsics to camera-to-world.\n    c2w_mats = np.linalg.inv(w2c_mats)\n    poses = c2w_mats[:, :3, :4]\n\n    # Image names from COLMAP. 
No need for permuting the poses according to\n    # image names anymore.\n    names = [imdata[k].name for k in imdata]\n\n    # Switch from COLMAP (right, down, fwd) to NeRF (right, up, back) frame.\n    poses = poses @ np.diag([1, -1, -1, 1])\n\n    # Get distortion parameters.\n    type_ = cam.camera_type\n\n    if type_ == 0 or type_ == 'SIMPLE_PINHOLE':\n      params = None\n      camtype = camera_utils.ProjectionType.PERSPECTIVE\n\n    elif type_ == 1 or type_ == 'PINHOLE':\n      params = None\n      camtype = camera_utils.ProjectionType.PERSPECTIVE\n\n    elif type_ == 2 or type_ == 'SIMPLE_RADIAL':\n      params = {k: 0. for k in ['k1', 'k2', 'k3', 'p1', 'p2']}\n      params['k1'] = cam.k1\n      camtype = camera_utils.ProjectionType.PERSPECTIVE\n\n    elif type_ == 3 or type_ == 'RADIAL':\n      params = {k: 0. for k in ['k1', 'k2', 'k3', 'p1', 'p2']}\n      params['k1'] = cam.k1\n      params['k2'] = cam.k2\n      camtype = camera_utils.ProjectionType.PERSPECTIVE\n\n    elif type_ == 4 or type_ == 'OPENCV':\n      params = {k: 0. for k in ['k1', 'k2', 'k3', 'p1', 'p2']}\n      params['k1'] = cam.k1\n      params['k2'] = cam.k2\n      params['p1'] = cam.p1\n      params['p2'] = cam.p2\n      camtype = camera_utils.ProjectionType.PERSPECTIVE\n\n    elif type_ == 5 or type_ == 'OPENCV_FISHEYE':\n      params = {k: 0. 
for k in ['k1', 'k2', 'k3', 'k4']}\n      params['k1'] = cam.k1\n      params['k2'] = cam.k2\n      params['k3'] = cam.k3\n      params['k4'] = cam.k4\n      camtype = camera_utils.ProjectionType.FISHEYE\n\n    return names, poses, pixtocam, params, camtype\n\n\ndef load_blender_posedata(data_dir, split=None):\n  \"\"\"Load poses from `transforms.json` file, as used in Blender/NGP datasets.\"\"\"\n  suffix = '' if split is None else f'_{split}'\n  pose_file = path.join(data_dir, f'transforms{suffix}.json')\n  with utils.open_file(pose_file, 'r') as fp:\n    meta = json.load(fp)\n  names = []\n  poses = []\n  for _, frame in enumerate(meta['frames']):\n    filepath = os.path.join(data_dir, frame['file_path'])\n    if utils.file_exists(filepath):\n      names.append(frame['file_path'].split('/')[-1])\n      poses.append(np.array(frame['transform_matrix'], dtype=np.float32))\n  poses = np.stack(poses, axis=0)\n\n  w = meta['w']\n  h = meta['h']\n  cx = meta['cx'] if 'cx' in meta else w / 2.\n  cy = meta['cy'] if 'cy' in meta else h / 2.\n  if 'fl_x' in meta:\n    fx = meta['fl_x']\n  else:\n    fx = 0.5 * w / np.tan(0.5 * float(meta['camera_angle_x']))\n  if 'fl_y' in meta:\n    fy = meta['fl_y']\n  else:\n    fy = 0.5 * h / np.tan(0.5 * float(meta['camera_angle_y']))\n  pixtocam = np.linalg.inv(camera_utils.intrinsic_matrix(fx, fy, cx, cy))\n  coeffs = ['k1', 'k2', 'p1', 'p2']\n  if not any([c in meta for c in coeffs]):\n    params = None\n  else:\n    params = {c: (meta[c] if c in meta else 0.) for c in coeffs}\n  camtype = camera_utils.ProjectionType.PERSPECTIVE\n  return names, poses, pixtocam, params, camtype\n\n\nclass Dataset(threading.Thread, metaclass=abc.ABCMeta):\n  \"\"\"Dataset Base Class.\n\n  Base class for a NeRF dataset. Creates batches of ray and color data used for\n  training or rendering a NeRF model.\n\n  Each subclass is responsible for loading images and camera poses from disk by\n  implementing the _load_renderings() method. 
This data is used to generate\n  train and test batches of ray + color data for feeding through the NeRF model.\n  The ray parameters are calculated in _generate_rays().\n\n  The public interface mimics the behavior of a standard machine learning\n  pipeline dataset provider that can provide infinite batches of data to the\n  training/testing pipelines without exposing any details of how the batches are\n  loaded/created or how this is parallelized. Therefore, the initializer runs\n  all setup, including data loading from disk using _load_renderings(), and\n  begins the thread using its parent start() method. After the initializer\n  returns, the caller can request batches of data straight away.\n\n  The internal self._queue is initialized as queue.Queue(3), so the infinite\n  loop in run() will block on the call self._queue.put(self._next_fn()) once\n  there are 3 elements. The main thread training job runs in a loop that pops 1\n  element at a time off the front of the queue. The Dataset thread's run() loop\n  will populate the queue with 3 elements, then wait until a batch has been\n  removed and push one more onto the end.\n\n  This repeats indefinitely until the main thread's training loop completes\n  (typically hundreds of thousands of iterations), then the main thread will\n  exit and the Dataset thread will automatically be killed since it is a daemon.\n\n  Attributes:\n    alphas: np.ndarray, optional array of alpha channel data.\n    cameras: tuple summarizing all camera extrinsic/intrinsic/distortion params.\n    camtoworlds: np.ndarray, a list of extrinsic camera pose matrices.\n    camtype: camera_utils.ProjectionType, fisheye or perspective camera.\n    data_dir: str, location of the dataset on disk.\n    disp_images: np.ndarray, optional array of disparity (inverse depth) data.\n    distortion_params: dict, the camera distortion model parameters.\n    exposures: optional per-image exposure value (shutter * ISO / 1000).\n    far: float, far plane 
value for rays.\n    focal: float, focal length from camera intrinsics.\n    height: int, height of images.\n    images: np.ndarray, array of RGB image data.\n    metadata: dict, optional metadata for raw datasets.\n    near: float, near plane value for rays.\n    normal_images: np.ndarray, optional array of surface normal vector data.\n    pixtocams: np.ndarray, one or a list of inverse intrinsic camera matrices.\n    pixtocam_ndc: np.ndarray, the inverse intrinsic matrix used for NDC space.\n    poses: np.ndarray, optional array of auxiliary camera pose data.\n    rays: utils.Rays, ray data for every pixel in the dataset.\n    render_exposures: optional list of exposure values for the render path.\n    render_path: bool, indicates if a smooth camera path should be generated.\n    size: int, number of images in the dataset.\n    split: str, indicates if this is a \"train\" or \"test\" dataset.\n    width: int, width of images.\n  \"\"\"\n\n  def __init__(self,\n               split: str,\n               data_dir: str,\n               config: configs.Config):\n    super().__init__()\n\n    # Initialize attributes\n    self._queue = queue.Queue(3)  # Set prefetch buffer to 3 batches.\n    self.daemon = True  # Sets parent Thread to be a daemon.\n    self._patch_size = np.maximum(config.patch_size, 1)\n    self._batch_size = config.batch_size // jax.process_count()\n    if self._patch_size**2 > self._batch_size:\n      raise ValueError(f'Patch size {self._patch_size}^2 too large for ' +\n                       f'per-process batch size {self._batch_size}')\n    self._batching = utils.BatchingMethod(config.batching)\n    self._use_tiffs = config.use_tiffs\n    self._load_disps = config.compute_disp_metrics\n    self._load_normals = config.compute_normal_metrics\n    self._test_camera_idx = 0\n    self._num_border_pixels_to_mask = config.num_border_pixels_to_mask\n    self._apply_bayer_mask = config.apply_bayer_mask\n    self._cast_rays_in_train_step = 
config.cast_rays_in_train_step\n    self._render_spherical = False\n\n    self.split = utils.DataSplit(split)\n    self.data_dir = data_dir\n    self.near = config.near\n    self.far = config.far\n    self.render_path = config.render_path\n    self.distortion_params = None\n    self.disp_images = None\n    self.normal_images = None\n    self.alphas = None\n    self.poses = None\n    self.pixtocam_ndc = None\n    self.metadata = None\n    self.camtype = camera_utils.ProjectionType.PERSPECTIVE\n    self.exposures = None\n    self.render_exposures = None\n\n    # Providing type comments for these attributes, they must be correctly\n    # initialized by _load_renderings() (see docstring) in any subclass.\n    self.images: np.ndarray = None\n    self.camtoworlds: np.ndarray = None\n    self.pixtocams: np.ndarray = None\n    self.height: int = None\n    self.width: int = None\n\n    # Load data from disk using provided config parameters.\n    self._load_renderings(config)\n\n    if self.render_path:\n      if config.render_path_file is not None:\n        with utils.open_file(config.render_path_file, 'rb') as fp:\n          render_poses = np.load(fp)\n        self.camtoworlds = render_poses\n      if config.render_resolution is not None:\n        self.width, self.height = config.render_resolution\n      if config.render_focal is not None:\n        self.focal = config.render_focal\n      if config.render_camtype is not None:\n        if config.render_camtype == 'pano':\n          self._render_spherical = True\n        else:\n          self.camtype = camera_utils.ProjectionType(config.render_camtype)\n\n      self.distortion_params = None\n      self.pixtocams = camera_utils.get_pixtocam(self.focal, self.width,\n                                                 self.height)\n\n    self._n_examples = self.camtoworlds.shape[0]\n\n    self.cameras = (self.pixtocams,\n                    self.camtoworlds,\n                    self.distortion_params,\n                    
self.pixtocam_ndc)\n\n    # Seed the queue with one batch to avoid race condition.\n    if self.split == utils.DataSplit.TRAIN:\n      self._next_fn = self._next_train\n    else:\n      self._next_fn = self._next_test\n    self._queue.put(self._next_fn())\n    self.start()\n\n  def __iter__(self):\n    return self\n\n  def __next__(self):\n    \"\"\"Get the next training batch or test example.\n\n    Returns:\n      batch: dict, has 'rgb' and 'rays'.\n    \"\"\"\n    x = self._queue.get()\n    if self.split == utils.DataSplit.TRAIN:\n      return utils.shard(x)\n    else:\n      # Do NOT move test `rays` to device, since it may be very large.\n      return x\n\n  def peek(self):\n    \"\"\"Peek at the next training batch or test example without dequeuing it.\n\n    Returns:\n      batch: dict, has 'rgb' and 'rays'.\n    \"\"\"\n    x = copy.copy(self._queue.queue[0])  # Make a copy of front of queue.\n    if self.split == utils.DataSplit.TRAIN:\n      return utils.shard(x)\n    else:\n      return jax.device_put(x)\n\n  def run(self):\n    while True:\n      self._queue.put(self._next_fn())\n\n  @property\n  def size(self):\n    return self._n_examples\n\n  @abc.abstractmethod\n  def _load_renderings(self, config):\n    \"\"\"Load images and poses from disk.\n\n    Args:\n      config: configs.Config, user-specified config parameters.\n\n    In inherited classes, this method must set the following public attributes:\n      images: [N, height, width, 3] array for RGB images.\n      disp_images: [N, height, width] array for depth data (optional).\n      normal_images: [N, height, width, 3] array for normals (optional).\n      camtoworlds: [N, 3, 4] array of extrinsic pose matrices.\n      poses: [..., 3, 4] array of auxiliary pose data (optional).\n      pixtocams: [N, 3, 3] array of inverse intrinsic matrices.\n      distortion_params: dict, camera lens distortion model parameters.\n      height: int, height of images.\n      width: int, width of images.\n      focal: 
float, focal length to use for ideal pinhole rendering.\n    \"\"\"\n\n  def _make_ray_batch(self,\n                      pix_x_int: np.ndarray,\n                      pix_y_int: np.ndarray,\n                      cam_idx: Union[np.ndarray, np.int32],\n                      lossmult: Optional[np.ndarray] = None\n                      ) -> utils.Batch:\n    \"\"\"Creates ray data batch from pixel coordinates and camera indices.\n\n    All arguments must have broadcastable shapes. If the arguments together\n    broadcast to a shape [a, b, c, ..., z] then the returned utils.Rays object\n    will have array attributes with shape [a, b, c, ..., z, N], where N=3 for\n    3D vectors and N=1 for per-ray scalar attributes.\n\n    Args:\n      pix_x_int: int array, x coordinates of image pixels.\n      pix_y_int: int array, y coordinates of image pixels.\n      cam_idx: int or int array, camera indices.\n      lossmult: float array, weight to apply to each ray when computing loss fn.\n\n    Returns:\n      A dict mapping from strings to utils.Rays or arrays of image data.\n      This is the batch provided for one NeRF train or test iteration.\n    \"\"\"\n\n    broadcast_scalar = lambda x: np.broadcast_to(x, pix_x_int.shape)[..., None]\n    ray_kwargs = {\n        'lossmult': broadcast_scalar(1.) 
if lossmult is None else lossmult,\n        'near': broadcast_scalar(self.near),\n        'far': broadcast_scalar(self.far),\n        'cam_idx': broadcast_scalar(cam_idx),\n    }\n    # Collect per-camera information needed for each ray.\n    if self.metadata is not None:\n      # Exposure index and relative shutter speed, needed for RawNeRF.\n      for key in ['exposure_idx', 'exposure_values']:\n        idx = 0 if self.render_path else cam_idx\n        ray_kwargs[key] = broadcast_scalar(self.metadata[key][idx])\n    if self.exposures is not None:\n      idx = 0 if self.render_path else cam_idx\n      ray_kwargs['exposure_values'] = broadcast_scalar(self.exposures[idx])\n    if self.render_path and self.render_exposures is not None:\n      ray_kwargs['exposure_values'] = broadcast_scalar(\n          self.render_exposures[cam_idx])\n\n    pixels = utils.Pixels(pix_x_int, pix_y_int, **ray_kwargs)\n    if self._cast_rays_in_train_step and self.split == utils.DataSplit.TRAIN:\n      # Fast path, defer ray computation to the training loop (on device).\n      rays = pixels\n    else:\n      # Slow path, do ray computation using numpy (on CPU).\n      rays = camera_utils.cast_ray_batch(\n          self.cameras, pixels, self.camtype, xnp=np)\n\n    # Create data batch.\n    batch = {}\n    batch['rays'] = rays\n    if not self.render_path:\n      batch['rgb'] = self.images[cam_idx, pix_y_int, pix_x_int]\n    if self._load_disps:\n      batch['disps'] = self.disp_images[cam_idx, pix_y_int, pix_x_int]\n    if self._load_normals:\n      batch['normals'] = self.normal_images[cam_idx, pix_y_int, pix_x_int]\n      batch['alphas'] = self.alphas[cam_idx, pix_y_int, pix_x_int]\n    return utils.Batch(**batch)\n\n  def _next_train(self) -> utils.Batch:\n    \"\"\"Sample next training batch (random rays).\"\"\"\n    # We assume all images in the dataset are the same resolution, so we can use\n    # the same width/height for sampling all pixels coordinates in the batch.\n    # 
Batch/patch sampling parameters.\n    num_patches = self._batch_size // self._patch_size ** 2\n    lower_border = self._num_border_pixels_to_mask\n    upper_border = self._num_border_pixels_to_mask + self._patch_size - 1\n    # Random pixel patch x-coordinates.\n    pix_x_int = np.random.randint(lower_border, self.width - upper_border,\n                                  (num_patches, 1, 1))\n    # Random pixel patch y-coordinates.\n    pix_y_int = np.random.randint(lower_border, self.height - upper_border,\n                                  (num_patches, 1, 1))\n    # Add patch coordinate offsets.\n    # Shape will broadcast to (num_patches, _patch_size, _patch_size).\n    patch_dx_int, patch_dy_int = camera_utils.pixel_coordinates(\n        self._patch_size, self._patch_size)\n    pix_x_int = pix_x_int + patch_dx_int\n    pix_y_int = pix_y_int + patch_dy_int\n    # Random camera indices.\n    if self._batching == utils.BatchingMethod.ALL_IMAGES:\n      cam_idx = np.random.randint(0, self._n_examples, (num_patches, 1, 1))\n    else:\n      cam_idx = np.random.randint(0, self._n_examples, (1,))\n\n    if self._apply_bayer_mask:\n      # Compute the Bayer mosaic mask for each pixel in the batch.\n      lossmult = raw_utils.pixels_to_bayer_mask(pix_x_int, pix_y_int)\n    else:\n      lossmult = None\n\n    return self._make_ray_batch(pix_x_int, pix_y_int, cam_idx,\n                                lossmult=lossmult)\n\n  def generate_ray_batch(self, cam_idx: int) -> utils.Batch:\n    \"\"\"Generate ray batch for a specified camera in the dataset.\"\"\"\n    if self._render_spherical:\n      camtoworld = self.camtoworlds[cam_idx]\n      rays = camera_utils.cast_spherical_rays(\n          camtoworld, self.height, self.width, self.near, self.far, xnp=np)\n      return utils.Batch(rays=rays)\n    else:\n      # Generate rays for all pixels in the image.\n      pix_x_int, pix_y_int = camera_utils.pixel_coordinates(\n          self.width, self.height)\n      return 
self._make_ray_batch(pix_x_int, pix_y_int, cam_idx)\n\n  def _next_test(self) -> utils.Batch:\n    \"\"\"Sample next test batch (one full image).\"\"\"\n    # Use the next camera index.\n    cam_idx = self._test_camera_idx\n    self._test_camera_idx = (self._test_camera_idx + 1) % self._n_examples\n    return self.generate_ray_batch(cam_idx)\n\n\nclass Blender(Dataset):\n  \"\"\"Blender Dataset.\"\"\"\n\n  def _load_renderings(self, config):\n    \"\"\"Load images from disk.\"\"\"\n    if config.render_path:\n      raise ValueError('render_path cannot be used for the blender dataset.')\n    pose_file = path.join(self.data_dir, f'transforms_{self.split.value}.json')\n    with utils.open_file(pose_file, 'r') as fp:\n      meta = json.load(fp)\n    images = []\n    disp_images = []\n    normal_images = []\n    cams = []\n    for _, frame in enumerate(meta['frames']):\n      fprefix = os.path.join(self.data_dir, frame['file_path'])\n\n      def get_img(f, fprefix=fprefix):\n        image = utils.load_img(fprefix + f)\n        if config.factor > 1:\n          image = lib_image.downsample(image, config.factor)\n        return image\n\n      if self._use_tiffs:\n        channels = [get_img(f'_{ch}.tiff') for ch in ['R', 'G', 'B', 'A']]\n        # Convert image to sRGB color space.\n        image = lib_image.linear_to_srgb(np.stack(channels, axis=-1))\n      else:\n        image = get_img('.png') / 255.\n      images.append(image)\n\n      if self._load_disps:\n        disp_image = get_img('_disp.tiff')\n        disp_images.append(disp_image)\n      if self._load_normals:\n        normal_image = get_img('_normal.png')[..., :3] * 2. / 255. 
- 1.\n        normal_images.append(normal_image)\n\n      cams.append(np.array(frame['transform_matrix'], dtype=np.float32))\n\n    self.images = np.stack(images, axis=0)\n    if self._load_disps:\n      self.disp_images = np.stack(disp_images, axis=0)\n    if self._load_normals:\n      self.normal_images = np.stack(normal_images, axis=0)\n      self.alphas = self.images[..., -1]\n\n    rgb, alpha = self.images[..., :3], self.images[..., -1:]\n    self.images = rgb * alpha + (1. - alpha)  # Use a white background.\n    self.height, self.width = self.images.shape[1:3]\n    self.camtoworlds = np.stack(cams, axis=0)\n    self.focal = .5 * self.width / np.tan(.5 * float(meta['camera_angle_x']))\n    self.pixtocams = camera_utils.get_pixtocam(self.focal, self.width,\n                                               self.height)\n\n\nclass LLFF(Dataset):\n  \"\"\"LLFF Dataset.\"\"\"\n\n  def _load_renderings(self, config):\n    \"\"\"Load images from disk.\"\"\"\n    # Set up scaling factor.\n    image_dir_suffix = ''\n    # Use downsampling factor (unless loading training split for raw dataset,\n    # we train raw at full resolution because of the Bayer mosaic pattern).\n    if config.factor > 0 and not (config.rawnerf_mode and\n                                  self.split == utils.DataSplit.TRAIN):\n      image_dir_suffix = f'_{config.factor}'\n      factor = config.factor\n    else:\n      factor = 1\n\n    # Copy COLMAP data to local disk for faster loading.\n    colmap_dir = os.path.join(self.data_dir, 'sparse/0/')\n\n    # Load poses.\n    if utils.file_exists(colmap_dir):\n      pose_data = NeRFSceneManager(colmap_dir).process()\n    else:\n      # Attempt to load Blender/NGP format if COLMAP data not present.\n      pose_data = load_blender_posedata(self.data_dir)\n    image_names, poses, pixtocam, distortion_params, camtype = pose_data\n\n    # Previous NeRF results were generated with images sorted by filename,\n    # use this flag to ensure metrics are reported 
on the same test set.\n    if config.load_alphabetical:\n      inds = np.argsort(image_names)\n      image_names = [image_names[i] for i in inds]\n      poses = poses[inds]\n\n    # Scale the inverse intrinsics matrix by the image downsampling factor.\n    pixtocam = pixtocam @ np.diag([factor, factor, 1.])\n    self.pixtocams = pixtocam.astype(np.float32)\n    self.focal = 1. / self.pixtocams[0, 0]\n    self.distortion_params = distortion_params\n    self.camtype = camtype\n\n    raw_testscene = False\n    if config.rawnerf_mode:\n      # Load raw images and metadata.\n      images, metadata, raw_testscene = raw_utils.load_raw_dataset(\n          self.split,\n          self.data_dir,\n          image_names,\n          config.exposure_percentile,\n          factor)\n      self.metadata = metadata\n\n    else:\n      # Load images.\n      colmap_image_dir = os.path.join(self.data_dir, 'images')\n      image_dir = os.path.join(self.data_dir, 'images' + image_dir_suffix)\n      for d in [image_dir, colmap_image_dir]:\n        if not utils.file_exists(d):\n          raise ValueError(f'Image folder {d} does not exist.')\n      # Downsampled images may have different names vs images used for COLMAP,\n      # so we need to map between the two sorted lists of files.\n      colmap_files = sorted(utils.listdir(colmap_image_dir))\n      image_files = sorted(utils.listdir(image_dir))\n      colmap_to_image = dict(zip(colmap_files, image_files))\n      image_paths = [os.path.join(image_dir, colmap_to_image[f])\n                     for f in image_names]\n      images = [utils.load_img(x) for x in image_paths]\n      images = np.stack(images, axis=0) / 255.\n\n      # EXIF data is usually only present in the original JPEG images.\n      jpeg_paths = [os.path.join(colmap_image_dir, f) for f in image_names]\n      exifs = [utils.load_exif(x) for x in jpeg_paths]\n      self.exifs = exifs\n      if 'ExposureTime' in exifs[0] and 'ISOSpeedRatings' in exifs[0]:\n        
gather_exif_value = lambda k: np.array([float(x[k]) for x in exifs])\n        shutters = gather_exif_value('ExposureTime')\n        isos = gather_exif_value('ISOSpeedRatings')\n        self.exposures = shutters * isos / 1000.\n\n    # Load bounds if possible (only used in forward facing scenes).\n    posefile = os.path.join(self.data_dir, 'poses_bounds.npy')\n    if utils.file_exists(posefile):\n      with utils.open_file(posefile, 'rb') as fp:\n        poses_arr = np.load(fp)\n      bounds = poses_arr[:, -2:]\n    else:\n      bounds = np.array([0.01, 1.])\n    self.colmap_to_world_transform = np.eye(4)\n\n    # Separate out 360 versus forward facing scenes.\n    if config.forward_facing:\n      # Set the projective matrix defining the NDC transformation.\n      self.pixtocam_ndc = self.pixtocams.reshape(-1, 3, 3)[0]\n      # Rescale according to a default bd factor.\n      scale = 1. / (bounds.min() * .75)\n      poses[:, :3, 3] *= scale\n      self.colmap_to_world_transform = np.diag([scale] * 3 + [1])\n      bounds *= scale\n      # Recenter poses.\n      poses, transform = camera_utils.recenter_poses(poses)\n      self.colmap_to_world_transform = (\n          transform @ self.colmap_to_world_transform)\n      # Forward-facing spiral render path.\n      self.render_poses = camera_utils.generate_spiral_path(\n          poses, bounds, n_frames=config.render_path_frames)\n    else:\n      # Rotate/scale poses to align ground with xy plane and fit to unit cube.\n      poses, transform = camera_utils.transform_poses_pca(poses)\n      self.colmap_to_world_transform = transform\n      if config.render_spline_keyframes is not None:\n        rets = camera_utils.create_render_spline_path(config, image_names,\n                                                      poses, self.exposures)\n        self.spline_indices, self.render_poses, self.render_exposures = rets\n      else:\n        # Automatically generated inward-facing elliptical render path.\n        
self.render_poses = camera_utils.generate_ellipse_path(\n            poses,\n            n_frames=config.render_path_frames,\n            z_variation=config.z_variation,\n            z_phase=config.z_phase)\n\n    if raw_testscene:\n      # For raw testscene, the first image sent to COLMAP has the same pose as\n      # the ground truth test image. The remaining images form the training set.\n      raw_testscene_poses = {\n          utils.DataSplit.TEST: poses[:1],\n          utils.DataSplit.TRAIN: poses[1:],\n      }\n      poses = raw_testscene_poses[self.split]\n\n    self.poses = poses\n\n    # Select the split.\n    all_indices = np.arange(images.shape[0])\n    if config.llff_use_all_images_for_training or raw_testscene:\n      train_indices = all_indices\n    else:\n      train_indices = all_indices % config.llffhold != 0\n    split_indices = {\n        utils.DataSplit.TEST: all_indices[all_indices % config.llffhold == 0],\n        utils.DataSplit.TRAIN: train_indices,\n    }\n    indices = split_indices[self.split]\n    # All per-image quantities must be re-indexed using the split indices.\n    images = images[indices]\n    poses = poses[indices]\n    if self.exposures is not None:\n      self.exposures = self.exposures[indices]\n    if config.rawnerf_mode:\n      for key in ['exposure_idx', 'exposure_values']:\n        self.metadata[key] = self.metadata[key][indices]\n\n    self.images = images\n    self.camtoworlds = self.render_poses if config.render_path else poses\n    self.height, self.width = images.shape[1:3]\n\n\nclass TanksAndTemplesNerfPP(Dataset):\n  \"\"\"Subset of Tanks and Temples Dataset as processed by NeRF++.\"\"\"\n\n  def _load_renderings(self, config):\n    \"\"\"Load images from disk.\"\"\"\n    if config.render_path:\n      split_str = 'camera_path'\n    else:\n      split_str = self.split.value\n\n    basedir = os.path.join(self.data_dir, split_str)\n\n    def load_files(dirname, load_fn, shape=None):\n      files = [\n          
os.path.join(basedir, dirname, f)\n          for f in sorted(utils.listdir(os.path.join(basedir, dirname)))\n      ]\n      mats = np.array([load_fn(utils.open_file(f, 'rb')) for f in files])\n      if shape is not None:\n        mats = mats.reshape(mats.shape[:1] + shape)\n      return mats\n\n    poses = load_files('pose', np.loadtxt, (4, 4))\n    # Flip Y and Z axes to get correct coordinate frame.\n    poses = np.matmul(poses, np.diag(np.array([1, -1, -1, 1])))\n\n    # For now, ignore all but the first focal length in intrinsics\n    intrinsics = load_files('intrinsics', np.loadtxt, (4, 4))\n\n    if not config.render_path:\n      images = load_files('rgb', lambda f: np.array(Image.open(f))) / 255.\n      self.images = images\n      self.height, self.width = self.images.shape[1:3]\n\n    else:\n      # Hack to grab the image resolution from a test image\n      d = os.path.join(self.data_dir, 'test', 'rgb')\n      f = os.path.join(d, sorted(utils.listdir(d))[0])\n      shape = utils.load_img(f).shape\n      self.height, self.width = shape[:2]\n      self.images = None\n\n    self.camtoworlds = poses\n    self.focal = intrinsics[0, 0, 0]\n    self.pixtocams = camera_utils.get_pixtocam(self.focal, self.width,\n                                               self.height)\n\n\nclass TanksAndTemplesFVS(Dataset):\n  \"\"\"Subset of Tanks and Temples Dataset as processed by Free View Synthesis.\"\"\"\n\n  def _load_renderings(self, config):\n    \"\"\"Load images from disk.\"\"\"\n    render_only = config.render_path and self.split == utils.DataSplit.TEST\n\n    basedir = os.path.join(self.data_dir, 'dense')\n    sizes = [f for f in sorted(utils.listdir(basedir)) if f.startswith('ibr3d')]\n    sizes = sizes[::-1]\n\n    if config.factor >= len(sizes):\n      raise ValueError(f'Factor {config.factor} larger than {len(sizes)}')\n\n    basedir = os.path.join(basedir, sizes[config.factor])\n    open_fn = lambda f: utils.open_file(os.path.join(basedir, f), 'rb')\n\n    
files = [f for f in sorted(utils.listdir(basedir)) if f.startswith('im_')]\n    if render_only:\n      files = files[:1]\n    images = np.array([np.array(Image.open(open_fn(f))) for f in files]) / 255.\n\n    names = ['Ks', 'Rs', 'ts']\n    intrinsics, rot, trans = (np.load(open_fn(f'{n}.npy')) for n in names)\n\n    # Convert poses from colmap world-to-cam into our cam-to-world.\n    w2c = np.concatenate([rot, trans[..., None]], axis=-1)\n    c2w_colmap = np.linalg.inv(camera_utils.pad_poses(w2c))[:, :3, :4]\n    c2w = c2w_colmap @ np.diag(np.array([1, -1, -1, 1]))\n\n    # Reorient poses so z-axis is up\n    poses, _ = camera_utils.transform_poses_pca(c2w)\n    self.poses = poses\n\n    self.images = images\n    self.height, self.width = self.images.shape[1:3]\n    self.camtoworlds = poses\n    # For now, ignore all but the first focal length in intrinsics\n    self.focal = intrinsics[0, 0, 0]\n    self.pixtocams = camera_utils.get_pixtocam(self.focal, self.width,\n                                               self.height)\n\n    if render_only:\n      render_path = camera_utils.generate_ellipse_path(\n          poses,\n          config.render_path_frames,\n          z_variation=config.z_variation,\n          z_phase=config.z_phase)\n      self.images = None\n      self.camtoworlds = render_path\n      self.render_poses = render_path\n    else:\n      # Select the split.\n      all_indices = np.arange(images.shape[0])\n      indices = {\n          utils.DataSplit.TEST:\n              all_indices[all_indices % config.llffhold == 0],\n          utils.DataSplit.TRAIN:\n              all_indices[all_indices % config.llffhold != 0],\n      }[self.split]\n\n      self.images = self.images[indices]\n      self.camtoworlds = self.camtoworlds[indices]\n\n\nclass DTU(Dataset):\n  \"\"\"DTU Dataset.\"\"\"\n\n  def _load_renderings(self, config):\n    \"\"\"Load images from disk.\"\"\"\n    if config.render_path:\n      raise ValueError('render_path cannot be used for the 
DTU dataset.')\n\n    images = []\n    pixtocams = []\n    camtoworlds = []\n\n    # Find out whether the particular scan has 49 or 65 images.\n    n_images = len(utils.listdir(self.data_dir)) // 8\n\n    # Loop over all images.\n    for i in range(1, n_images + 1):\n      # Set light condition string accordingly.\n      if config.dtu_light_cond < 7:\n        light_str = f'{config.dtu_light_cond}_r' + ('5000'\n                                                    if i < 50 else '7000')\n      else:\n        light_str = 'max'\n\n      # Load image.\n      fname = os.path.join(self.data_dir, f'rect_{i:03d}_{light_str}.png')\n      image = utils.load_img(fname) / 255.\n      if config.factor > 1:\n        image = lib_image.downsample(image, config.factor)\n      images.append(image)\n\n      # Load projection matrix from file.\n      fname = path.join(self.data_dir, f'../../cal18/pos_{i:03d}.txt')\n      with utils.open_file(fname, 'rb') as f:\n        projection = np.loadtxt(f, dtype=np.float32)\n\n      # Decompose projection matrix into pose and camera matrix.\n      camera_mat, rot_mat, t = cv2.decomposeProjectionMatrix(projection)[:3]\n      camera_mat = camera_mat / camera_mat[2, 2]\n      pose = np.eye(4, dtype=np.float32)\n      pose[:3, :3] = rot_mat.transpose()\n      pose[:3, 3] = (t[:3] / t[3])[:, 0]\n      pose = pose[:3]\n      camtoworlds.append(pose)\n\n      if config.factor > 0:\n        # Scale camera matrix according to downsampling factor.\n        camera_mat = np.diag([1. / config.factor, 1. 
/ config.factor, 1.\n                             ]).astype(np.float32) @ camera_mat\n      pixtocams.append(np.linalg.inv(camera_mat))\n\n    pixtocams = np.stack(pixtocams)\n    camtoworlds = np.stack(camtoworlds)\n    images = np.stack(images)\n\n    def rescale_poses(poses):\n      \"\"\"Rescales camera poses according to maximum x/y/z value.\"\"\"\n      s = np.max(np.abs(poses[:, :3, -1]))\n      out = np.copy(poses)\n      out[:, :3, -1] /= s\n      return out\n\n    # Center and scale poses.\n    camtoworlds, _ = camera_utils.recenter_poses(camtoworlds)\n    camtoworlds = rescale_poses(camtoworlds)\n    # Flip y and z axes to get poses in OpenGL coordinate system.\n    camtoworlds = camtoworlds @ np.diag([1., -1., -1., 1.]).astype(np.float32)\n\n    all_indices = np.arange(images.shape[0])\n    split_indices = {\n        utils.DataSplit.TEST: all_indices[all_indices % config.dtuhold == 0],\n        utils.DataSplit.TRAIN: all_indices[all_indices % config.dtuhold != 0],\n    }\n    indices = split_indices[self.split]\n\n    self.images = images[indices]\n    self.height, self.width = images.shape[1:3]\n    self.camtoworlds = camtoworlds[indices]\n    self.pixtocams = pixtocams[indices]\n"
  },
  {
    "path": "mip360/internal/geopoly.py",
    "content": "# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tools for constructing geodesic polyhedra, which are used as a basis.\"\"\"\n\nimport itertools\nimport numpy as np\n\n\ndef compute_sq_dist(mat0, mat1=None):\n  \"\"\"Compute the squared Euclidean distance between all pairs of columns.\"\"\"\n  if mat1 is None:\n    mat1 = mat0\n  # Use the fact that ||x - y||^2 == ||x||^2 + ||y||^2 - 2 x^T y.\n  sq_norm0 = np.sum(mat0**2, 0)\n  sq_norm1 = np.sum(mat1**2, 0)\n  sq_dist = sq_norm0[:, None] + sq_norm1[None, :] - 2 * mat0.T @ mat1\n  sq_dist = np.maximum(0, sq_dist)  # Negative values must be numerical errors.\n  return sq_dist\n\n\ndef compute_tesselation_weights(v):\n  \"\"\"Tesselate the vertices of a triangle by a factor of `v`.\"\"\"\n  if v < 1:\n    raise ValueError(f'v {v} must be >= 1')\n  int_weights = []\n  for i in range(v + 1):\n    for j in range(v + 1 - i):\n      int_weights.append((i, j, v - (i + j)))\n  int_weights = np.array(int_weights)\n  weights = int_weights / v  # Barycentric weights.\n  return weights\n\n\ndef tesselate_geodesic(base_verts, base_faces, v, eps=1e-4):\n  \"\"\"Tesselate the vertices of a geodesic polyhedron.\n\n  Args:\n    base_verts: tensor of floats, the vertex coordinates of the geodesic.\n    base_faces: tensor of ints, the indices of the vertices of base_verts that\n      constitute each face of the polyhedron.\n    v: int, the factor of the tesselation (v==1 is a 
no-op).\n    eps: float, a small value used to determine if two vertices are the same.\n\n  Returns:\n    verts: a tensor of floats, the coordinates of the tesselated vertices.\n  \"\"\"\n  if not isinstance(v, int):\n    raise ValueError(f'v {v} must be an integer')\n  tri_weights = compute_tesselation_weights(v)\n\n  verts = []\n  for base_face in base_faces:\n    new_verts = np.matmul(tri_weights, base_verts[base_face, :])\n    new_verts /= np.sqrt(np.sum(new_verts**2, 1, keepdims=True))\n    verts.append(new_verts)\n  verts = np.concatenate(verts, 0)\n\n  sq_dist = compute_sq_dist(verts.T)\n  assignment = np.array([np.min(np.argwhere(d <= eps)) for d in sq_dist])\n  unique = np.unique(assignment)\n  verts = verts[unique, :]\n\n  return verts\n\n\ndef generate_basis(base_shape,\n                   angular_tesselation,\n                   remove_symmetries=True,\n                   eps=1e-4):\n  \"\"\"Generates a 3D basis by tesselating a geometric polyhedron.\n\n  Args:\n    base_shape: string, the name of the starting polyhedron, must be either\n      'icosahedron' or 'octahedron'.\n    angular_tesselation: int, the number of times to tesselate the polyhedron,\n      must be >= 1 (a value of 1 is a no-op to the polyhedron).\n    remove_symmetries: bool, if True then remove the symmetric basis columns,\n      which is usually a good idea because otherwise projections onto the basis\n      will have redundant negative copies of each other.\n    eps: float, a small number used to determine symmetries.\n\n  Returns:\n    basis: a matrix with shape [n, 3].\n  \"\"\"\n  if base_shape == 'icosahedron':\n    a = (np.sqrt(5) + 1) / 2\n    verts = np.array([(-1, 0, a), (1, 0, a), (-1, 0, -a), (1, 0, -a), (0, a, 1),\n                      (0, a, -1), (0, -a, 1), (0, -a, -1), (a, 1, 0),\n                      (-a, 1, 0), (a, -1, 0), (-a, -1, 0)]) / np.sqrt(a + 2)\n    faces = np.array([(0, 4, 1), (0, 9, 4), (9, 5, 4), (4, 5, 8), (4, 8, 1),\n                      (8, 10, 1), 
(8, 3, 10), (5, 3, 8), (5, 2, 3), (2, 7, 3),\n                      (7, 10, 3), (7, 6, 10), (7, 11, 6), (11, 0, 6), (0, 1, 6),\n                      (6, 1, 10), (9, 0, 11), (9, 11, 2), (9, 2, 5),\n                      (7, 2, 11)])\n    verts = tesselate_geodesic(verts, faces, angular_tesselation)\n  elif base_shape == 'octahedron':\n    verts = np.array([(0, 0, -1), (0, 0, 1), (0, -1, 0), (0, 1, 0), (-1, 0, 0),\n                      (1, 0, 0)])\n    corners = np.array(list(itertools.product([-1, 1], repeat=3)))\n    pairs = np.argwhere(compute_sq_dist(corners.T, verts.T) == 2)\n    faces = np.sort(np.reshape(pairs[:, 1], [3, -1]).T, 1)\n    verts = tesselate_geodesic(verts, faces, angular_tesselation)\n  else:\n    raise ValueError(f'base_shape {base_shape} not supported')\n\n  if remove_symmetries:\n    # Remove elements of `verts` that are reflections of each other.\n    match = compute_sq_dist(verts.T, -verts.T) < eps\n    verts = verts[np.any(np.triu(match), 1), :]\n\n  basis = verts[:, ::-1]\n  return basis\n"
  },
  {
    "path": "mip360/internal/image.py",
    "content": "# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Functions for processing images.\"\"\"\n\nimport types\nfrom typing import Optional, Union\n\nimport dm_pix\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\n\n_Array = Union[np.ndarray, jnp.ndarray]\n\n\ndef mse_to_psnr(mse):\n  \"\"\"Compute PSNR given an MSE (we assume the maximum pixel value is 1).\"\"\"\n  return -10. / jnp.log(10.) * jnp.log(mse)\n\n\ndef psnr_to_mse(psnr):\n  \"\"\"Compute MSE given a PSNR (we assume the maximum pixel value is 1).\"\"\"\n  return jnp.exp(-0.1 * jnp.log(10.) 
* psnr)\n\n\ndef ssim_to_dssim(ssim):\n  \"\"\"Compute DSSIM given an SSIM.\"\"\"\n  return (1 - ssim) / 2\n\n\ndef dssim_to_ssim(dssim):\n  \"\"\"Compute SSIM given a DSSIM.\"\"\"\n  return 1 - 2 * dssim\n\n\ndef linear_to_srgb(linear: _Array,\n                   eps: Optional[float] = None,\n                   xnp: types.ModuleType = jnp) -> _Array:\n  \"\"\"Assumes `linear` is in [0, 1], see https://en.wikipedia.org/wiki/SRGB.\"\"\"\n  if eps is None:\n    eps = xnp.finfo(xnp.float32).eps\n  srgb0 = 323 / 25 * linear\n  srgb1 = (211 * xnp.maximum(eps, linear)**(5 / 12) - 11) / 200\n  return xnp.where(linear <= 0.0031308, srgb0, srgb1)\n\n\ndef srgb_to_linear(srgb: _Array,\n                   eps: Optional[float] = None,\n                   xnp: types.ModuleType = jnp) -> _Array:\n  \"\"\"Assumes `srgb` is in [0, 1], see https://en.wikipedia.org/wiki/SRGB.\"\"\"\n  if eps is None:\n    eps = xnp.finfo(xnp.float32).eps\n  linear0 = 25 / 323 * srgb\n  linear1 = xnp.maximum(eps, ((200 * srgb + 11) / (211)))**(12 / 5)\n  return xnp.where(srgb <= 0.04045, linear0, linear1)\n\n\ndef downsample(img, factor):\n  \"\"\"Area downsample img (factor must evenly divide img height and width).\"\"\"\n  sh = img.shape\n  if not (sh[0] % factor == 0 and sh[1] % factor == 0):\n    raise ValueError(f'Downsampling factor {factor} does not '\n                     f'evenly divide image shape {sh[:2]}')\n  img = img.reshape((sh[0] // factor, factor, sh[1] // factor, factor) + sh[2:])\n  img = img.mean((1, 3))\n  return img\n\n\ndef color_correct(img, ref, num_iters=5, eps=0.5 / 255):\n  \"\"\"Warp `img` to match the colors in `ref`.\"\"\"\n  if img.shape[-1] != ref.shape[-1]:\n    raise ValueError(\n        f'img\\'s {img.shape[-1]} and ref\\'s {ref.shape[-1]} channels must match'\n    )\n  num_channels = img.shape[-1]\n  img_mat = img.reshape([-1, num_channels])\n  ref_mat = ref.reshape([-1, num_channels])\n  is_unclipped = lambda z: (z >= eps) & (z <= (1 - eps))  # z \\in [eps, 
1-eps].\n  mask0 = is_unclipped(img_mat)\n  # Because the set of saturated pixels may change after solving for a\n  # transformation, we repeatedly solve a system `num_iters` times and update\n  # our estimate of which pixels are saturated.\n  for _ in range(num_iters):\n    # Construct the left hand side of a linear system that contains a quadratic\n    # expansion of each pixel of `img`.\n    a_mat = []\n    for c in range(num_channels):\n      a_mat.append(img_mat[:, c:(c + 1)] * img_mat[:, c:])  # Quadratic term.\n    a_mat.append(img_mat)  # Linear term.\n    a_mat.append(jnp.ones_like(img_mat[:, :1]))  # Bias term.\n    a_mat = jnp.concatenate(a_mat, axis=-1)\n    warp = []\n    for c in range(num_channels):\n      # Construct the right hand side of a linear system containing each color\n      # of `ref`.\n      b = ref_mat[:, c]\n      # Ignore rows of the linear system that were saturated in the input or are\n      # saturated in the current corrected color estimate.\n      mask = mask0[:, c] & is_unclipped(img_mat[:, c]) & is_unclipped(b)\n      ma_mat = jnp.where(mask[:, None], a_mat, 0)\n      mb = jnp.where(mask, b, 0)\n      # Solve the linear system. 
We're using the np.lstsq instead of jnp because\n      # it's significantly more stable in this case, for some reason.\n      w = np.linalg.lstsq(ma_mat, mb, rcond=-1)[0]\n      assert jnp.all(jnp.isfinite(w))\n      warp.append(w)\n    warp = jnp.stack(warp, axis=-1)\n    # Apply the warp to update img_mat.\n    img_mat = jnp.clip(\n        jnp.matmul(a_mat, warp, precision=jax.lax.Precision.HIGHEST), 0, 1)\n  corrected_img = jnp.reshape(img_mat, img.shape)\n  return corrected_img\n\n\nclass MetricHarness:\n  \"\"\"A helper class for evaluating several error metrics.\"\"\"\n\n  def __init__(self):\n    self.ssim_fn = jax.jit(dm_pix.ssim)\n\n  def __call__(self, rgb_pred, rgb_gt, name_fn=lambda s: s):\n    \"\"\"Evaluate the error between a predicted rgb image and the true image.\"\"\"\n    psnr = float(mse_to_psnr(((rgb_pred - rgb_gt)**2).mean()))\n    ssim = float(self.ssim_fn(rgb_pred, rgb_gt))\n\n    return {\n        name_fn('psnr'): psnr,\n        name_fn('ssim'): ssim,\n    }\n"
  },
  {
    "path": "mip360/internal/math.py",
    "content": "# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Mathy utility functions.\"\"\"\n\nimport jax\nimport jax.numpy as jnp\n\n\ndef matmul(a, b):\n  \"\"\"jnp.matmul defaults to bfloat16, but this helper function doesn't.\"\"\"\n  return jnp.matmul(a, b, precision=jax.lax.Precision.HIGHEST)\n\n\ndef safe_trig_helper(x, fn, t=100 * jnp.pi):\n  \"\"\"Helper function used by safe_cos/safe_sin: mods x before sin()/cos().\"\"\"\n  return fn(jnp.where(jnp.abs(x) < t, x, x % t))\n\n\ndef safe_cos(x):\n  \"\"\"jnp.cos() on a TPU may NaN out for large values.\"\"\"\n  return safe_trig_helper(x, jnp.cos)\n\n\ndef safe_sin(x):\n  \"\"\"jnp.sin() on a TPU may NaN out for large values.\"\"\"\n  return safe_trig_helper(x, jnp.sin)\n\n\n@jax.custom_jvp\ndef safe_exp(x):\n  \"\"\"jnp.exp() but with finite output and gradients for large inputs.\"\"\"\n  return jnp.exp(jnp.minimum(x, 88.))  # jnp.exp(89) is infinity.\n\n\n@safe_exp.defjvp\ndef safe_exp_jvp(primals, tangents):\n  \"\"\"Override safe_exp()'s gradient so that it's large when inputs are large.\"\"\"\n  x, = primals\n  x_dot, = tangents\n  exp_x = safe_exp(x)\n  exp_x_dot = exp_x * x_dot\n  return exp_x, exp_x_dot\n\n\ndef log_lerp(t, v0, v1):\n  \"\"\"Interpolate log-linearly from `v0` (t=0) to `v1` (t=1).\"\"\"\n  if v0 <= 0 or v1 <= 0:\n    raise ValueError(f'Interpolants {v0} and {v1} must be positive.')\n  lv0 = jnp.log(v0)\n  lv1 = jnp.log(v1)\n  return 
jnp.exp(jnp.clip(t, 0, 1) * (lv1 - lv0) + lv0)\n\n\ndef learning_rate_decay(step,\n                        lr_init,\n                        lr_final,\n                        max_steps,\n                        lr_delay_steps=0,\n                        lr_delay_mult=1):\n  \"\"\"Continuous learning rate decay function.\n\n  The returned rate is lr_init when step=0 and lr_final when step=max_steps, and\n  is log-linearly interpolated elsewhere (equivalent to exponential decay).\n  If lr_delay_steps>0 then the learning rate will be scaled by some smooth\n  function of lr_delay_mult, such that the initial learning rate is\n  lr_init*lr_delay_mult at the beginning of optimization but will be eased back\n  to the normal learning rate when step>lr_delay_steps.\n\n  Args:\n    step: int, the current optimization step.\n    lr_init: float, the initial learning rate.\n    lr_final: float, the final learning rate.\n    max_steps: int, the number of steps during optimization.\n    lr_delay_steps: int, the number of steps to delay the full learning rate.\n    lr_delay_mult: float, the multiplier on the rate when delaying it.\n\n  Returns:\n    lr: the learning rate for the current step 'step'.\n  \"\"\"\n  if lr_delay_steps > 0:\n    # A kind of reverse cosine decay.\n    delay_rate = lr_delay_mult + (1 - lr_delay_mult) * jnp.sin(\n        0.5 * jnp.pi * jnp.clip(step / lr_delay_steps, 0, 1))\n  else:\n    delay_rate = 1.\n  return delay_rate * log_lerp(step / max_steps, lr_init, lr_final)\n\n\ndef interp(*args):\n  \"\"\"A gather-based (GPU-friendly) vectorized replacement for jnp.interp().\"\"\"\n  args_flat = [x.reshape([-1, x.shape[-1]]) for x in args]\n  ret = jax.vmap(jnp.interp)(*args_flat).reshape(args[0].shape)\n  return ret\n\n\ndef sorted_interp(x, xp, fp):\n  \"\"\"A TPU-friendly version of interp(), where xp and fp must be sorted.\"\"\"\n\n  # Identify the location in `xp` that corresponds to each `x`.\n  # The final `True` index in `mask` is the start of the matching 
interval.\n  mask = x[..., None, :] >= xp[..., :, None]\n\n  def find_interval(x):\n    # Grab the value where `mask` switches from True to False, and vice versa.\n    # This approach takes advantage of the fact that `x` is sorted.\n    x0 = jnp.max(jnp.where(mask, x[..., None], x[..., :1, None]), -2)\n    x1 = jnp.min(jnp.where(~mask, x[..., None], x[..., -1:, None]), -2)\n    return x0, x1\n\n  fp0, fp1 = find_interval(fp)\n  xp0, xp1 = find_interval(xp)\n\n  offset = jnp.clip(jnp.nan_to_num((x - xp0) / (xp1 - xp0), 0), 0, 1)\n  ret = fp0 + offset * (fp1 - fp0)\n  return ret\n"
  },
  {
    "path": "mip360/internal/models.py",
    "content": "# Copyright 2023 Ze-Xin Yin\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"NeRF and its MLPs, with helper functions for construction and rendering.\"\"\"\n\nimport functools\nfrom typing import Any, Callable, List, Mapping, MutableMapping, Optional, Text, Tuple\n\nfrom flax import linen as nn\nimport gin\nfrom internal import configs\nfrom internal import coord\nfrom internal import geopoly\nfrom internal import image\nfrom internal import math\nfrom internal import ref_utils\nfrom internal import render\nfrom internal import stepfun\nfrom internal import utils\nimport jax\nfrom jax import random\nimport jax.numpy as jnp\n\ngin.config.external_configurable(math.safe_exp, module='math')\ngin.config.external_configurable(coord.contract, module='coord')\n\n\ndef random_split(rng):\n  if rng is None:\n    key = None\n  else:\n    key, rng = random.split(rng)\n  return key, rng\n\n\n@gin.configurable\nclass Model(nn.Module):\n  \"\"\"A mip-Nerf360 model containing all MLPs.\"\"\"\n  config: Any = None  # A Config class, must be set upon construction.\n  num_prop_samples: int = 64  # The number of samples for each proposal level.\n  num_nerf_samples: int = 32  # The number of samples the final nerf level.\n  num_levels: int = 3  # The number of sampling levels (3==2 proposals, 1 nerf).\n  bg_intensity_range: Tuple[float] = (1., 1.)  
# The range of background colors.\n  anneal_slope: float = 10  # Higher = more rapid annealing.\n  stop_level_grad: bool = True  # If True, don't backprop across levels.\n  use_viewdirs: bool = True  # If True, use view directions as input.\n  raydist_fn: Callable[..., Any] = None  # The curve used for ray dists.\n  ray_shape: str = 'cone'  # The shape of cast rays ('cone' or 'cylinder').\n  disable_integration: bool = False  # If True, use PE instead of IPE.\n  single_jitter: bool = True  # If True, jitter whole rays instead of samples.\n  dilation_multiplier: float = 0.5  # How much to dilate intervals relatively.\n  dilation_bias: float = 0.0025  # How much to dilate intervals absolutely.\n  num_glo_features: int = 0  # GLO vector length, disabled if 0.\n  num_glo_embeddings: int = 1000  # Upper bound on max number of train images.\n  learned_exposure_scaling: bool = False  # Learned exposure scaling (RawNeRF).\n  near_anneal_rate: Optional[float] = None  # How fast to anneal in near bound.\n  near_anneal_init: float = 0.95  # Where to initialize near bound (in [0, 1]).\n  single_mlp: bool = False  # Use the NerfMLP for all rounds of sampling.\n  resample_padding: float = 0.0  # Dirichlet/alpha \"padding\" on the histogram.\n  use_gpu_resampling: bool = False  # Use gather ops for faster GPU resampling.\n  opaque_background: bool = False  # If true, make the background opaque.\n  num_space: int = 1 # The number of sub-spaces\n\n  @nn.compact\n  def __call__(\n      self,\n      rng,\n      rays,\n      train_frac,\n      compute_extras,\n      zero_glo=True,\n  ):\n    \"\"\"The mip-NeRF Model.\n\n    Args:\n      rng: random number generator (or None for deterministic output).\n      rays: util.Rays, a pytree of ray origins, directions, and viewdirs.\n      train_frac: float in [0, 1], what fraction of training is complete.\n      compute_extras: bool, if True, compute extra quantities besides color.\n      zero_glo: bool, if True, when using GLO pass in vector 
of zeros.\n\n    Returns:\n      ret: list, [*(rgb, distance, acc)]\n    \"\"\"\n\n    # Construct MLPs. WARNING: Construction order may matter, if MLP weights are\n    # being regularized.\n    nerf_mlp = NerfMLP()\n    prop_mlp = nerf_mlp if self.single_mlp else PropMLP()\n    decoder = None if self.num_space == 1 else DecoderMLP()\n\n    if self.num_glo_features > 0:\n      if not zero_glo:\n        # Construct/grab GLO vectors for the cameras of each input ray.\n        glo_vecs = nn.Embed(self.num_glo_embeddings, self.num_glo_features)\n        cam_idx = rays.cam_idx[..., 0]\n        glo_vec = glo_vecs(cam_idx)\n      else:\n        glo_vec = jnp.zeros(rays.origins.shape[:-1] + (self.num_glo_features,))\n    else:\n      glo_vec = None\n\n    if self.learned_exposure_scaling:\n      # Setup learned scaling factors for output colors.\n      max_num_exposures = self.num_glo_embeddings\n      # Initialize the learned scaling offsets at 0.\n      init_fn = jax.nn.initializers.zeros\n      exposure_scaling_offsets = nn.Embed(\n          max_num_exposures,\n          features=3,\n          embedding_init=init_fn,\n          name='exposure_scaling_offsets')\n\n    # Define the mapping from normalized to metric ray distance.\n    _, s_to_t = coord.construct_ray_warps(self.raydist_fn, rays.near, rays.far)\n\n    # Initialize the range of (normalized) distances for each ray to [0, 1],\n    # and assign that single interval a weight of 1. These distances and weights\n    # will be repeatedly updated as we proceed through sampling levels.\n    # `near_anneal_rate` can be used to anneal in the near bound at the start\n    # of training, eg. 
0.1 anneals in the bound over the first 10% of training.\n    if self.near_anneal_rate is None:\n      init_s_near = 0.\n    else:\n      init_s_near = jnp.clip(1 - train_frac / self.near_anneal_rate, 0,\n                             self.near_anneal_init)\n    init_s_far = 1.\n    sdist = jnp.concatenate([\n        jnp.full_like(rays.near, init_s_near),\n        jnp.full_like(rays.far, init_s_far)\n    ],\n                            axis=-1)\n    weights = jnp.ones_like(rays.near)\n    prod_num_samples = 1\n\n    ray_history = []\n    renderings = []\n    for i_level in range(self.num_levels):\n      is_prop = i_level < (self.num_levels - 1)\n      num_samples = self.num_prop_samples if is_prop else self.num_nerf_samples\n\n      # Dilate by some multiple of the expected span of each current interval,\n      # with some bias added in.\n      dilation = self.dilation_bias + self.dilation_multiplier * (\n          init_s_far - init_s_near) / prod_num_samples\n\n      # Record the product of the number of samples seen so far.\n      prod_num_samples *= num_samples\n\n      # After the first level (where dilation would be a no-op) optionally\n      # dilate the interval weights along each ray slightly so that they're\n      # overestimates, which can reduce aliasing.\n      use_dilation = self.dilation_bias > 0 or self.dilation_multiplier > 0\n      if i_level > 0 and use_dilation:\n        sdist, weights = stepfun.max_dilate_weights(\n            sdist,\n            weights,\n            dilation,\n            domain=(init_s_near, init_s_far),\n            renormalize=True)\n        sdist = sdist[..., 1:-1]\n        weights = weights[..., 1:-1]\n\n      # Optionally anneal the weights as a function of training iteration.\n      if self.anneal_slope > 0:\n        # Schlick's bias function, see https://arxiv.org/abs/2010.09714\n        bias = lambda x, s: (s * x) / ((s - 1) * x + 1)\n        anneal = bias(train_frac, self.anneal_slope)\n      else:\n        anneal = 
1.\n\n      # A slightly more stable way to compute weights**anneal. If the distance\n      # between adjacent intervals is zero then its weight is fixed to 0.\n      logits_resample = jnp.where(\n          sdist[..., 1:] > sdist[..., :-1],\n          anneal * jnp.log(weights + self.resample_padding), -jnp.inf)\n\n      # Draw sampled intervals from each ray's current weights.\n      key, rng = random_split(rng)\n      sdist = stepfun.sample_intervals(\n          key,\n          sdist,\n          logits_resample,\n          num_samples,\n          single_jitter=self.single_jitter,\n          domain=(init_s_near, init_s_far),\n          use_gpu_resampling=self.use_gpu_resampling)\n\n      # Optimization will usually go nonlinear if you propagate gradients\n      # through sampling.\n      if self.stop_level_grad:\n        sdist = jax.lax.stop_gradient(sdist)\n\n      # Convert normalized distances to metric distances.\n      tdist = s_to_t(sdist)\n\n      # Cast our rays, by turning our distance intervals into Gaussians.\n      gaussians = render.cast_rays(\n          tdist,\n          rays.origins,\n          rays.directions,\n          rays.radii,\n          self.ray_shape,\n          diag=False)\n\n      if self.disable_integration:\n        # Setting the covariance of our Gaussian samples to 0 disables the\n        # \"integrated\" part of integrated positional encoding.\n        gaussians = (gaussians[0], jnp.zeros_like(gaussians[1]))\n      # print (f'level {i_level}')\n      # Push our Gaussians through one of our two MLPs.\n      mlp = prop_mlp if is_prop else nerf_mlp\n      key, rng = random_split(rng)\n      \n      ray_results = mlp(\n          key,\n          gaussians,\n          viewdirs=rays.viewdirs if self.use_viewdirs else None,\n          imageplane=rays.imageplane,\n          glo_vec=None if is_prop else glo_vec,\n          exposure=rays.exposure_values,\n          num_space=1 if is_prop else self.num_space,\n      )\n\n      # Get the weights 
used by volumetric rendering (and our other losses).\n      cumprod_weights = render.compute_alpha_weights if is_prop or self.num_space == 1 else render.compute_alpha_weights_multispace\n      weights = cumprod_weights(\n          ray_results['density'],\n          tdist,\n          rays.directions,\n          opaque_background=self.opaque_background,\n      )[0]\n\n      # Define or sample the background color for each ray.\n      if self.bg_intensity_range[0] == self.bg_intensity_range[1]:\n        # If the min and max of the range are equal, just take it.\n        # bg_rgbs scalar\n        bg_rgbs = self.bg_intensity_range[0]\n      elif rng is None:\n        # If rendering is deterministic, use the midpoint of the range.\n        bg_rgbs = (self.bg_intensity_range[0] + self.bg_intensity_range[1]) / 2\n      else:\n        # Sample RGB values from the range for each ray.\n        # bg_rgbs [n_rays, n_samples, 3]\n        key, rng = random_split(rng)\n        bg_rgbs = random.uniform(\n            key,\n            shape=weights.shape[:-1] + (3,),\n            minval=self.bg_intensity_range[0],\n            maxval=self.bg_intensity_range[1])\n\n      # RawNeRF exposure logic.\n      if rays.exposure_idx is not None:\n        # Scale output colors by the exposure.\n        ray_results['rgb'] *= rays.exposure_values[..., None, :]\n        if self.learned_exposure_scaling:\n          exposure_idx = rays.exposure_idx[..., 0]\n          # Force scaling offset to always be zero when exposure_idx is 0.\n          # This constraint fixes a reference point for the scene's brightness.\n          mask = exposure_idx > 0\n          # Scaling is parameterized as an offset from 1.\n          scaling = 1 + mask[..., None] * exposure_scaling_offsets(exposure_idx)\n          ray_results['rgb'] *= scaling[..., None, :]\n\n      if self.num_space == 1 or is_prop:\n        rendering = render.volumetric_rendering(\n            ray_results['rgb'],\n            weights,\n            
tdist,\n            bg_rgbs,\n            rays.far,\n            compute_extras,\n            extras={\n                k: v\n                for k, v in ray_results.items()\n                if k.startswith('normals') or k in ['roughness']\n            })\n      else:\n        rendering, weights, ray_rgbs = render.volumetric_rendering_multispace(ray_results['rgb'], \n                                                                              weights, \n                                                                              tdist, \n                                                                              bg_rgbs, \n                                                                              rays.far, \n                                                                              compute_extras,\n                                                                              num_space=self.num_space,\n                                                                              decoder=decoder,\n                                                                              extras={\n                                                                              k: v\n                                                                              for k, v in ray_results.items()\n                                                                              if k.startswith('normals') or k in ['roughness']\n                                                                              })\n        ray_results['rgb'] = ray_rgbs\n      \n      if compute_extras:\n        # Collect some rays to visualize directly. 
By naming these quantities\n        # with `ray_` they get treated differently downstream --- they're\n        # treated as bags of rays, rather than image chunks.\n        n = self.config.vis_num_rays\n        if self.config.eval_one:\n          n = rays.origins.shape[0]\n        rendering['ray_sdist'] = sdist.reshape([-1, sdist.shape[-1]])[:n, :]\n        rendering['ray_densities'] = ray_results['density'][:n]\n        rendering['ray_weights'] = (\n            weights.reshape([-1, weights.shape[-1]])[:n, :])\n        rgb = ray_results['rgb']\n        rendering['ray_rgbs'] = (rgb.reshape((-1,) + rgb.shape[-2:]))[:n, :, :]\n\n      renderings.append(rendering)\n      ray_results['sdist'] = jnp.copy(sdist)\n      ray_results['weights'] = jnp.copy(weights)\n      ray_history.append(ray_results)\n\n    if compute_extras:\n      # Because the proposal network doesn't produce meaningful colors, for\n      # easier visualization we replace their colors with the final average\n      # color.\n      weights = [r['ray_weights'] for r in renderings]\n      rgbs = [r['ray_rgbs'] for r in renderings]\n      final_rgb = jnp.sum(rgbs[-1] * weights[-1][..., None], axis=-2)\n      avg_rgbs = [\n          jnp.broadcast_to(final_rgb[:, None, :], r.shape) for r in rgbs[:-1]\n      ]\n      for i in range(len(avg_rgbs)):\n        renderings[i]['ray_rgbs'] = avg_rgbs[i]\n        \n    return renderings, ray_history\n\n\ndef construct_model(rng, rays, config):\n  \"\"\"Construct a mip-NeRF 360 model.\n\n  Args:\n    rng: jnp.ndarray. 
Random number generator.\n    rays: an example of input Rays.\n    config: A Config class.\n\n  Returns:\n    model: initialized nn.Module, a NeRF model with parameters.\n    init_variables: flax.Module.state, initialized NeRF model parameters.\n  \"\"\"\n  # Grab just 10 rays, to minimize memory overhead during construction.\n  ray = jax.tree_util.tree_map(lambda x: jnp.reshape(x, [-1, x.shape[-1]])[:10],\n                               rays)\n  model = Model(config=config)\n  init_variables = model.init(\n      rng,  # The RNG used by flax to initialize random weights.\n      rng=None,  # The RNG used by sampling within the model.\n      rays=ray,\n      train_frac=1.,\n      compute_extras=False,\n      zero_glo=model.num_glo_features == 0)\n  return model, init_variables\n\n\nclass MLP(nn.Module):\n  \"\"\"A PosEnc MLP.\"\"\"\n  net_depth: int = 8  # The depth of the first part of MLP.\n  net_width: int = 256  # The width of the first part of MLP.\n  bottleneck_width: int = 256  # The width of the bottleneck vector.\n  net_depth_viewdirs: int = 1  # The depth of the second part of MLP.\n  net_width_viewdirs: int = 128  # The width of the second part of MLP.\n  net_activation: Callable[..., Any] = nn.relu  # The activation function.\n  min_deg_point: int = 0  # Min degree of positional encoding for 3D points.\n  max_deg_point: int = 12  # Max degree of positional encoding for 3D points.\n  weight_init: str = 'he_uniform'  # Initializer for the weights of the MLP.\n  skip_layer: int = 4  # Add a skip connection to the output of every N layers.\n  skip_layer_dir: int = 4  # Add a skip connection to 2nd MLP every N layers.\n  num_rgb_channels: int = 3  # The number of RGB channels.\n  deg_view: int = 4  # Degree of encoding for viewdirs or refdirs.\n  use_reflections: bool = False  # If True, use refdirs instead of viewdirs.\n  use_directional_enc: bool = False  # If True, use IDE to encode directions.\n  # If False and if use_directional_enc is True, use zero 
roughness in IDE.\n  enable_pred_roughness: bool = False\n  # Roughness activation function.\n  roughness_activation: Callable[..., Any] = nn.softplus\n  roughness_bias: float = -1.  # Shift added to raw roughness pre-activation.\n  use_diffuse_color: bool = False  # If True, predict diffuse & specular colors.\n  use_specular_tint: bool = False  # If True, predict tint.\n  use_n_dot_v: bool = False  # If True, feed dot(n * viewdir) to 2nd MLP.\n  bottleneck_noise: float = 0.0  # Std. deviation of noise added to bottleneck.\n  density_activation: Callable[..., Any] = nn.softplus  # Density activation.\n  density_bias: float = -1.  # Shift added to raw densities pre-activation.\n  density_noise: float = 0.  # Standard deviation of noise added to raw density.\n  rgb_premultiplier: float = 1.  # Premultiplier on RGB before activation.\n  rgb_activation: Callable[..., Any] = nn.sigmoid  # The RGB activation.\n  rgb_bias: float = 0.  # The shift added to raw colors pre-activation.\n  rgb_padding: float = 0.001  # Padding added to the RGB outputs.\n  enable_pred_normals: bool = False  # If True compute predicted normals.\n  disable_density_normals: bool = False  # If True don't compute normals.\n  disable_rgb: bool = False  # If True don't output RGB.\n  warp_fn: Callable[..., Any] = None\n  basis_shape: str = 'icosahedron'  # `octahedron` or `icosahedron`.\n  basis_subdivisions: int = 2  # Tessellation count. 
'octahedron' + 1 == eye(3).\n\n  def setup(self):\n    # Make sure that normals are computed if reflection direction is used.\n    if self.use_reflections and not (self.enable_pred_normals or\n                                     not self.disable_density_normals):\n      raise ValueError('Normals must be computed for reflection directions.')\n\n    # Precompute and store (the transpose of) the basis being used.\n    self.pos_basis_t = jnp.array(\n        geopoly.generate_basis(self.basis_shape, self.basis_subdivisions)).T\n\n    # Precompute and define viewdir or refdir encoding function.\n    if self.use_directional_enc:\n      self.dir_enc_fn = ref_utils.generate_ide_fn(self.deg_view)\n    else:\n\n      def dir_enc_fn(direction, _):\n        return coord.pos_enc(\n            direction, min_deg=0, max_deg=self.deg_view, append_identity=True)\n\n      self.dir_enc_fn = dir_enc_fn\n\n  @nn.compact\n  def __call__(self,\n               rng,\n               gaussians,\n               viewdirs=None,\n               imageplane=None,\n               glo_vec=None,\n               exposure=None,\n               num_space=1):\n    \"\"\"Evaluate the MLP.\n\n    Args:\n      rng: jnp.ndarray. Random number generator.\n      gaussians: a tuple containing:                                           /\n        - mean: [..., n, 3], coordinate means, and                             /\n        - cov: [..., n, 3{, 3}], coordinate covariance matrices.\n      viewdirs: jnp.ndarray(float32), [..., 3], if not None, this variable will\n        be part of the input to the second part of the MLP concatenated with the\n        output vector of the first part of the MLP. If None, only the first part\n        of the MLP will be used with input x. In the original paper, this\n        variable is the view direction.\n      imageplane: jnp.ndarray(float32), [batch, 2], xy image plane coordinates\n        for each ray in the batch. 
Useful for image plane operations such as a\n        learned vignette mapping.\n      glo_vec: [..., num_glo_features], The GLO vector for each ray.\n      exposure: [..., 1], exposure value (shutter_speed * ISO) for each ray.\n\n    Returns:\n      rgb: jnp.ndarray(float32), with a shape of [..., num_rgb_channels].\n      density: jnp.ndarray(float32), with a shape of [...].\n      normals: jnp.ndarray(float32), with a shape of [..., 3], or None.\n      normals_pred: jnp.ndarray(float32), with a shape of [..., 3], or None.\n      roughness: jnp.ndarray(float32), with a shape of [..., 1], or None.\n    \"\"\"\n\n    dense_layer = functools.partial(\n        nn.Dense, kernel_init=getattr(jax.nn.initializers, self.weight_init)())\n\n    density_key, rng = random_split(rng)\n\n    def predict_density(means, covs):\n      \"\"\"Helper function to output density.\"\"\"\n      # Encode input positions\n\n      if self.warp_fn is not None:\n        means, covs = coord.track_linearize(self.warp_fn, means, covs)\n\n      lifted_means, lifted_vars = (\n          coord.lift_and_diagonalize(means, covs, self.pos_basis_t))\n      x = coord.integrated_pos_enc(lifted_means, lifted_vars,\n                                   self.min_deg_point, self.max_deg_point)\n\n      inputs = x\n      # Evaluate network to produce the output density.\n      for i in range(self.net_depth):\n        x = dense_layer(self.net_width)(x)\n        x = self.net_activation(x)\n        if i % self.skip_layer == 0 and i > 0:\n          x = jnp.concatenate([x, inputs], axis=-1)\n      # raw_density: [n_rays, n_samples] if num_space == 1 else [n_rays, n_samples, num_space]\n      # x: [n_rays, n_samples, self.net_width]\n      if num_space == 1:\n        raw_density = dense_layer(1)(x)[..., 0]  # Hardcoded to a single channel.\n      else:\n        raw_density = dense_layer(1 * num_space)(x)\n      # print (f'raw_density shape {raw_density.shape}, x shape {x.shape}')\n      # Add noise to regularize the 
density predictions if needed.\n      if (density_key is not None) and (self.density_noise > 0):\n        raw_density += self.density_noise * random.normal(\n            density_key, raw_density.shape)\n      return raw_density, x\n\n    means, covs = gaussians\n    if self.disable_density_normals:\n      raw_density, x = predict_density(means, covs)\n      raw_grad_density = None\n      normals = None\n    else:\n      # Flatten the input so value_and_grad can be vmap'ed.\n      means_flat = means.reshape((-1, means.shape[-1]))\n      covs_flat = covs.reshape((-1,) + covs.shape[len(means.shape) - 1:])\n\n      # Evaluate the network and its gradient on the flattened input.\n      predict_density_and_grad_fn = jax.vmap(\n          jax.value_and_grad(predict_density, has_aux=True), in_axes=(0, 0))\n      (raw_density_flat, x_flat), raw_grad_density_flat = (\n          predict_density_and_grad_fn(means_flat, covs_flat))\n\n      # Unflatten the output.\n      raw_density = raw_density_flat.reshape(list(means.shape[:-1]) + ([] if num_space == 1 else [num_space]))\n      x = x_flat.reshape(means.shape[:-1] + (x_flat.shape[-1],))\n      # means shape\n      raw_grad_density = raw_grad_density_flat.reshape(means.shape)\n\n      # Compute normal vectors as negative normalized density gradient.\n      # We normalize the gradient of raw (pre-activation) density because\n      # it's the same as post-activation density, but is more numerically stable\n      # when the activation function has a steep or flat gradient.\n      normals = -ref_utils.l2_normalize(raw_grad_density)\n\n    if self.enable_pred_normals:\n      grad_pred = dense_layer(3)(x)\n\n      # Normalize negative predicted gradients to get predicted normal vectors.\n      normals_pred = -ref_utils.l2_normalize(grad_pred)\n      normals_to_use = normals_pred\n    else:\n      grad_pred = None\n      normals_pred = None\n      normals_to_use = normals\n\n    # Apply bias and activation to raw density\n    density = 
self.density_activation(raw_density + self.density_bias)\n\n    roughness = None\n    if self.disable_rgb:\n      rgb = jnp.zeros_like(means)\n    else:\n      if viewdirs is not None:\n        # Predict diffuse color.\n        if self.use_diffuse_color:\n          raw_rgb_diffuse = dense_layer(self.num_rgb_channels)(x)\n\n        if self.use_specular_tint:\n          tint = nn.sigmoid(dense_layer(3)(x))\n\n        if self.enable_pred_roughness:\n          raw_roughness = dense_layer(1)(x)\n          roughness = (\n              self.roughness_activation(raw_roughness + self.roughness_bias))\n\n        # Output of the first part of MLP.\n        if self.bottleneck_width > 0:\n          bottleneck = dense_layer(self.bottleneck_width)(x)\n\n          # Add bottleneck noise.\n          if (rng is not None) and (self.bottleneck_noise > 0):\n            key, rng = random_split(rng)\n            bottleneck += self.bottleneck_noise * random.normal(\n                key, bottleneck.shape)\n\n          x = [bottleneck]\n        else:\n          x = []\n\n        # Encode view (or reflection) directions.\n        if self.use_reflections:\n          # Compute reflection directions. 
Note that we flip viewdirs before\n          # reflecting, because they point from the camera to the point,\n          # whereas ref_utils.reflect() assumes they point toward the camera.\n          # Returned refdirs then point from the point to the environment.\n          refdirs = ref_utils.reflect(-viewdirs[..., None, :], normals_to_use)\n          # Encode reflection directions.\n          dir_enc = self.dir_enc_fn(refdirs, roughness)\n        else:\n          # Encode view directions.\n          dir_enc = self.dir_enc_fn(viewdirs, roughness)\n          dir_enc = jnp.broadcast_to(\n              dir_enc[..., None, :],\n              bottleneck.shape[:-1] + (dir_enc.shape[-1],))\n\n        # Append view (or reflection) direction encoding to bottleneck vector.\n        x.append(dir_enc)\n\n        # Append dot product between normal vectors and view directions.\n        # False for mip nerf 360\n        if self.use_n_dot_v:\n          dotprod = jnp.sum(\n              normals_to_use * viewdirs[..., None, :], axis=-1, keepdims=True)\n          x.append(dotprod)\n\n        # Append GLO vector if used.\n        if glo_vec is not None:\n          glo_vec = jnp.broadcast_to(glo_vec[..., None, :],\n                                     bottleneck.shape[:-1] + glo_vec.shape[-1:])\n          x.append(glo_vec)\n\n        # Concatenate bottleneck, directional encoding, and GLO.\n        x = jnp.concatenate(x, axis=-1)\n\n        # Output of the second part of MLP.\n        inputs = x\n        # print (f'inputs to mlp shape {inputs.shape}')\n        for i in range(self.net_depth_viewdirs):\n          x = dense_layer(self.net_width_viewdirs)(x)\n          x = self.net_activation(x)\n          if i % self.skip_layer_dir == 0 and i > 0:\n            x = jnp.concatenate([x, inputs], axis=-1)\n\n      # If using diffuse/specular colors, then `rgb` is treated as linear\n      # specular color. 
Otherwise it's treated as the color itself.\n      rgb = self.rgb_activation(self.rgb_premultiplier *\n                                dense_layer(self.num_rgb_channels * num_space)(x) +\n                                self.rgb_bias)\n\n      if self.use_diffuse_color:\n        # Initialize linear diffuse color around 0.25, so that the combined\n        # linear color is initialized around 0.5.\n        diffuse_linear = nn.sigmoid(raw_rgb_diffuse - jnp.log(3.0))\n        if self.use_specular_tint:\n          specular_linear = tint * rgb\n        else:\n          specular_linear = 0.5 * rgb\n\n        # Combine specular and diffuse components and tone map to sRGB.\n        rgb = jnp.clip(\n            image.linear_to_srgb(specular_linear + diffuse_linear), 0.0, 1.0)\n\n      # Apply padding, mapping color to [-rgb_padding, 1+rgb_padding].\n      rgb = rgb * (1 + 2 * self.rgb_padding) - self.rgb_padding\n\n    return dict(\n        density=density,\n        rgb=rgb,\n        raw_grad_density=raw_grad_density,\n        grad_pred=grad_pred,\n        normals=normals,\n        normals_pred=normals_pred,\n        roughness=roughness,\n    )\n    \n\nclass Decoder(nn.Module):\n  net_depth: int = 2 # The depth of decoder module.\n  hidden_width: int = 64 # The hidden dim of the bottleneck layer.\n  net_output: int = 3 # The output dim of the Decoder.\n  weight_init: str = 'he_uniform'  # Initializer for the weights of the Decoder.\n  net_activation: Callable[..., Any] = nn.relu  # The activation function.\n  \n  @nn.compact\n  def __call__(self,\n               feats,\n               ):\n    dense_layer = functools.partial(\n        nn.Dense, kernel_init=getattr(jax.nn.initializers, self.weight_init)())\n    \n    alpha = feats\n    for _ in range(self.net_depth - 1):\n      feats = dense_layer(self.hidden_width)(feats)\n      feats = self.net_activation(feats)\n      \n      alpha = dense_layer(self.hidden_width)(alpha)\n      alpha = self.net_activation(alpha)\n    \n    
alpha = dense_layer(1)(alpha)\n    alpha = nn.softmax(alpha, axis=-2)\n    \n    rgbs = dense_layer(3)(feats)\n    rgbs = nn.sigmoid(rgbs)\n    rgb = (alpha * rgbs).sum(-2)\n      \n    return rgb, rgbs, alpha\n\n\n@gin.configurable\nclass NerfMLP(MLP):\n  pass\n\n\n@gin.configurable\nclass PropMLP(MLP):\n  pass\n\n@gin.configurable\nclass DecoderMLP(Decoder):\n  pass\n\ndef render_image(render_fn: Callable[[jnp.array, utils.Rays],\n                                     Tuple[List[Mapping[Text, jnp.ndarray]],\n                                           List[Tuple[jnp.ndarray, ...]]]],\n                 rays: utils.Rays,\n                 rng: jnp.array,\n                 config: configs.Config,\n                 verbose: bool = True) -> MutableMapping[Text, Any]:\n  \"\"\"Render all the pixels of an image (in test mode).\n\n  Args:\n    render_fn: function, jit-ed render function mapping (rng, rays) -> pytree.\n    rays: a `Rays` pytree, the rays to be rendered.\n    rng: jnp.ndarray, random number generator (used in training mode only).\n    config: A Config class.\n    verbose: print progress indicators.\n\n  Returns:\n    rgb: jnp.ndarray, rendered color image.\n    disp: jnp.ndarray, rendered disparity image.\n    acc: jnp.ndarray, rendered accumulated weights per pixel.\n  \"\"\"\n  height, width = rays.origins.shape[:2]\n  num_rays = height * width\n  rays = jax.tree_util.tree_map(lambda r: r.reshape((num_rays, -1)), rays)\n\n  host_id = jax.process_index()\n  chunks = []\n  idx0s = range(0, num_rays, config.render_chunk_size)\n  for i_chunk, idx0 in enumerate(idx0s):\n    # pylint: disable=cell-var-from-loop\n    if verbose and i_chunk % max(1, len(idx0s) // 10) == 0:\n      print(f'Rendering chunk {i_chunk}/{len(idx0s)-1}')\n    chunk_rays = (\n        jax.tree_util.tree_map(\n            lambda r: r[idx0:idx0 + config.render_chunk_size], rays))\n    actual_chunk_size = chunk_rays.origins.shape[0]\n    rays_remaining = actual_chunk_size % 
jax.device_count()\n    if rays_remaining != 0:\n      padding = jax.device_count() - rays_remaining\n      chunk_rays = jax.tree_util.tree_map(\n          lambda r: jnp.pad(r, ((0, padding), (0, 0)), mode='edge'), chunk_rays)\n    else:\n      padding = 0\n    # After padding the number of chunk_rays is always divisible by host_count.\n    rays_per_host = chunk_rays.origins.shape[0] // jax.process_count()\n    start, stop = host_id * rays_per_host, (host_id + 1) * rays_per_host\n    chunk_rays = jax.tree_util.tree_map(lambda r: utils.shard(r[start:stop]),\n                                        chunk_rays)\n    chunk_renderings, _ = render_fn(rng, chunk_rays)\n\n    # Unshard the renderings.\n    chunk_renderings = jax.tree_util.tree_map(\n        lambda v: utils.unshard(v[0], padding), chunk_renderings)\n\n    # Gather the final pass for 2D buffers and all passes for ray bundles.\n    chunk_rendering = chunk_renderings[-1]\n    for k in chunk_renderings[0]:\n      if k.startswith('ray_'):\n        chunk_rendering[k] = [r[k] for r in chunk_renderings]\n\n    chunks.append(chunk_rendering)\n\n  # Concatenate all chunks within each leaf of a single pytree.\n  rendering = (\n      jax.tree_util.tree_map(lambda *args: jnp.concatenate(args), *chunks))\n  for k, z in rendering.items():\n    if not k.startswith('ray_'):\n      # Reshape 2D buffers into original image shape.\n      rendering[k] = z.reshape((height, width) + z.shape[1:])\n\n  # After all of the ray bundles have been concatenated together, extract a\n  # new random bundle (deterministically) from the concatenation that is the\n  # same size as one of the individual bundles.\n  keys = [k for k in rendering if k.startswith('ray_')]\n  if keys:\n    num_rays = rendering[keys[0]][0].shape[0]\n    if config.eval_one:\n      for k in keys:\n        rendering[k] = [z.reshape((height, width) + z.shape[1:]) for z in rendering[k]]\n    else:\n      ray_idx = random.permutation(random.PRNGKey(0), num_rays)\n      
ray_idx = ray_idx[:config.vis_num_rays]\n      for k in keys:\n        rendering[k] = [r[ray_idx] for r in rendering[k]]\n\n  return rendering\n"
  },
  {
    "path": "mip360/internal/raw_utils.py",
    "content": "# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Functions for processing and loading raw image data.\"\"\"\n\nimport glob\nimport json\nimport os\nimport types\nfrom typing import Any, Mapping, MutableMapping, Optional, Sequence, Tuple, Union\n\nfrom internal import image as lib_image\nfrom internal import math\nfrom internal import utils\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport rawpy\n\n_Array = Union[np.ndarray, jnp.ndarray]\n_Axis = Optional[Union[int, Tuple[int, ...]]]\n\n\ndef postprocess_raw(raw: _Array,\n                    camtorgb: _Array,\n                    exposure: Optional[float] = None,\n                    xnp: types.ModuleType = np) -> _Array:\n  \"\"\"Converts demosaicked raw to sRGB with a minimal postprocessing pipeline.\n\n  Numpy array inputs will be automatically converted to Jax arrays.\n\n  Args:\n    raw: [H, W, 3], demosaicked raw camera image.\n    camtorgb: [3, 3], color correction transformation to apply to raw image.\n    exposure: color value to be scaled to pure white after color correction.\n              If None, \"autoexposes\" at the 97th percentile.\n    xnp: either numpy or jax.numpy.\n\n  Returns:\n    srgb: [H, W, 3], color corrected + exposed + gamma mapped image.\n  \"\"\"\n  if raw.shape[-1] != 3:\n    raise ValueError(f'raw.shape[-1] is {raw.shape[-1]}, expected 3')\n  if camtorgb.shape != (3, 3):\n    raise ValueError(f'camtorgb.shape is 
{camtorgb.shape}, expected (3, 3)')\n  # Convert from camera color space to standard linear RGB color space.\n  matmul = math.matmul if xnp == jnp else np.matmul\n  rgb_linear = matmul(raw, camtorgb.T)\n  if exposure is None:\n    exposure = xnp.percentile(rgb_linear, 97)\n  # \"Expose\" image by mapping the input exposure level to white and clipping.\n  rgb_linear_scaled = xnp.clip(rgb_linear / exposure, 0, 1)\n  # Apply sRGB gamma curve to serve as a simple tonemap.\n  srgb = lib_image.linear_to_srgb(rgb_linear_scaled, xnp=xnp)\n  return srgb\n\n\ndef pixels_to_bayer_mask(pix_x: np.ndarray, pix_y: np.ndarray) -> np.ndarray:\n  \"\"\"Computes binary RGB Bayer mask values from integer pixel coordinates.\"\"\"\n  # Red is top left (0, 0).\n  r = (pix_x % 2 == 0) * (pix_y % 2 == 0)\n  # Green is top right (0, 1) and bottom left (1, 0).\n  g = (pix_x % 2 == 1) * (pix_y % 2 == 0) + (pix_x % 2 == 0) * (pix_y % 2 == 1)\n  # Blue is bottom right (1, 1).\n  b = (pix_x % 2 == 1) * (pix_y % 2 == 1)\n  return np.stack([r, g, b], -1).astype(np.float32)\n\n\ndef bilinear_demosaic(bayer: _Array,\n                      xnp: types.ModuleType) -> _Array:\n  \"\"\"Converts Bayer data into a full RGB image using bilinear demosaicking.\n\n  Input data should be ndarray of shape [height, width] with 2x2 mosaic pattern:\n    -------------\n    |red  |green|\n    -------------\n    |green|blue |\n    -------------\n  Red and blue channels are bilinearly upsampled 2x, missing green channel\n  elements are the average of the neighboring 4 values in a cross pattern.\n\n  Args:\n    bayer: [H, W] array, Bayer mosaic pattern input image.\n    xnp: either numpy or jax.numpy.\n\n  Returns:\n    rgb: [H, W, 3] array, full RGB image.\n  \"\"\"\n  def reshape_quads(*planes):\n    \"\"\"Reshape pixels from four input images to make tiled 2x2 quads.\"\"\"\n    planes = xnp.stack(planes, -1)\n    shape = planes.shape[:-1]\n    # Create [2, 2] arrays out of 4 channels.\n    zup = planes.reshape(shape 
+ (2, 2,))\n    # Transpose so that x-axis dimensions come before y-axis dimensions.\n    zup = xnp.transpose(zup, (0, 2, 1, 3))\n    # Reshape to 2D.\n    zup = zup.reshape((shape[0] * 2, shape[1] * 2))\n    return zup\n\n  def bilinear_upsample(z):\n    \"\"\"2x bilinear image upsample.\"\"\"\n    # Using np.roll makes the right and bottom edges wrap around. The raw image\n    # data has a few garbage columns/rows at the edges that must be discarded\n    # anyway, so this does not matter in practice.\n    # Horizontally interpolated values.\n    zx = .5 * (z + xnp.roll(z, -1, axis=-1))\n    # Vertically interpolated values.\n    zy = .5 * (z + xnp.roll(z, -1, axis=-2))\n    # Diagonally interpolated values.\n    zxy = .5 * (zx + xnp.roll(zx, -1, axis=-2))\n    return reshape_quads(z, zx, zy, zxy)\n\n  def upsample_green(g1, g2):\n    \"\"\"Special 2x upsample from the two green channels.\"\"\"\n    z = xnp.zeros_like(g1)\n    z = reshape_quads(z, g1, g2, z)\n    alt = 0\n    # Grab the 4 directly adjacent neighbors in a \"cross\" pattern.\n    for i in range(4):\n      axis = -1 - (i // 2)\n      roll = -1 + 2 * (i % 2)\n      alt = alt + .25 * xnp.roll(z, roll, axis=axis)\n    # For observed pixels, alt = 0, and for unobserved pixels, alt = avg(cross),\n    # so alt + z will have every pixel filled in.\n    return alt + z\n\n  r, g1, g2, b = [bayer[(i//2)::2, (i%2)::2] for i in range(4)]\n  r = bilinear_upsample(r)\n  # Flip in x and y before and after calling upsample, as bilinear_upsample\n  # assumes that the samples are at the top-left corner of the 2x2 sample.\n  b = bilinear_upsample(b[::-1, ::-1])[::-1, ::-1]\n  g = upsample_green(g1, g2)\n  rgb = xnp.stack([r, g, b], -1)\n  return rgb\n\n\nbilinear_demosaic_jax = jax.jit(lambda bayer: bilinear_demosaic(bayer, xnp=jnp))\n\n\ndef load_raw_images(image_dir: str,\n                    image_names: Optional[Sequence[str]] = None\n                    ) -> Tuple[np.ndarray, Sequence[Mapping[str, Any]]]:\n  
\"\"\"Loads raw images and their metadata from disk.\n\n  Args:\n    image_dir: directory containing raw image and EXIF data.\n    image_names: files to load (ignores file extension), loads all DNGs if None.\n\n  Returns:\n    A tuple (images, exifs).\n    images: [N, height, width, 3] array of raw sensor data.\n    exifs: [N] list of dicts, one per image, containing the EXIF data.\n  Raises:\n    ValueError: The requested `image_dir` does not exist on disk.\n  \"\"\"\n\n  if not utils.file_exists(image_dir):\n    raise ValueError(f'Raw image folder {image_dir} does not exist.')\n\n  # Load raw images (dng files) and exif metadata (json files).\n  def load_raw_exif(image_name):\n    base = os.path.join(image_dir, os.path.splitext(image_name)[0])\n    with utils.open_file(base + '.dng', 'rb') as f:\n      raw = rawpy.imread(f).raw_image\n    with utils.open_file(base + '.json', 'rb') as f:\n      exif = json.load(f)[0]\n    return raw, exif\n\n  if image_names is None:\n    image_names = [\n        os.path.basename(f)\n        for f in sorted(glob.glob(os.path.join(image_dir, '*.dng')))\n    ]\n\n  data = [load_raw_exif(x) for x in image_names]\n  raws, exifs = zip(*data)\n  raws = np.stack(raws, axis=0).astype(np.float32)\n\n  return raws, exifs\n\n\n# Brightness percentiles to use for re-exposing and tonemapping raw images.\n_PERCENTILE_LIST = (80, 90, 97, 99, 100)\n\n# Relevant fields to extract from raw image EXIF metadata.\n# For details regarding EXIF parameters, see:\n# https://www.adobe.com/content/dam/acom/en/products/photoshop/pdfs/dng_spec_1.4.0.0.pdf.\n_EXIF_KEYS = (\n    'BlackLevel',  # Black level offset added to sensor measurements.\n    'WhiteLevel',  # Maximum possible sensor measurement.\n    'AsShotNeutral',  # RGB white balance coefficients.\n    'ColorMatrix2',  # XYZ to camera color space conversion matrix.\n    'NoiseProfile',  # Shot and read noise levels.\n)\n\n# Color conversion from reference illuminant XYZ to RGB color space.\n# See 
http://www.brucelindbloom.com/index.html?Eqn_RGB_XYZ_Matrix.html.\n_RGB2XYZ = np.array([[0.4124564, 0.3575761, 0.1804375],\n                     [0.2126729, 0.7151522, 0.0721750],\n                     [0.0193339, 0.1191920, 0.9503041]])\n\n\ndef process_exif(\n    exifs: Sequence[Mapping[str, Any]]) -> MutableMapping[str, Any]:\n  \"\"\"Processes list of raw image EXIF data into useful metadata dict.\n\n  Input should be a list of dictionaries loaded from JSON files.\n  These JSON files are produced by running\n    $ exiftool -json IMAGE.dng > IMAGE.json\n  for each input raw file.\n\n  We extract only the parameters relevant to\n  1. Rescaling the raw data to [0, 1],\n  2. White balance and color correction, and\n  3. Noise level estimation.\n\n  Args:\n    exifs: a list of dicts containing EXIF data as loaded from JSON files.\n\n  Returns:\n    meta: a dict of the relevant metadata for running RawNeRF.\n  \"\"\"\n  meta = {}\n  exif = exifs[0]\n  # Convert from array of dicts (exifs) to dict of arrays (meta).\n  for key in _EXIF_KEYS:\n    exif_value = exif.get(key)\n    if exif_value is None:\n      continue\n    # Values can be a single int or float...\n    if isinstance(exif_value, int) or isinstance(exif_value, float):\n      vals = [x[key] for x in exifs]\n    # Or a string of numbers with ' ' between.\n    elif isinstance(exif_value, str):\n      vals = [[float(z) for z in x[key].split(' ')] for x in exifs]\n    meta[key] = np.squeeze(np.array(vals))\n  # Shutter speed is a special case, a string written like 1/N.\n  meta['ShutterSpeed'] = np.fromiter(\n      (1. / float(exif['ShutterSpeed'].split('/')[1]) for exif in exifs), float)\n\n  # Create raw-to-sRGB color transform matrices. 
Pipeline is:\n  # cam space -> white balanced cam space (\"camwb\") -> XYZ space -> RGB space.\n  # 'AsShotNeutral' is an RGB triplet representing how pure white would measure\n  # on the sensor, so dividing by these numbers corrects the white balance.\n  whitebalance = meta['AsShotNeutral'].reshape(-1, 3)\n  cam2camwb = np.array([np.diag(1. / x) for x in whitebalance])\n  # ColorMatrix2 converts from XYZ color space to \"reference illuminant\" (white\n  # balanced) camera space.\n  xyz2camwb = meta['ColorMatrix2'].reshape(-1, 3, 3)\n  rgb2camwb = xyz2camwb @ _RGB2XYZ\n  # We normalize the rows of the full color correction matrix, as is done in\n  # https://github.com/AbdoKamel/simple-camera-pipeline.\n  rgb2camwb /= rgb2camwb.sum(axis=-1, keepdims=True)\n  # Combining color correction with white balance gives the entire transform.\n  cam2rgb = np.linalg.inv(rgb2camwb) @ cam2camwb\n  meta['cam2rgb'] = cam2rgb\n\n  return meta\n\n\ndef load_raw_dataset(split: utils.DataSplit,\n                     data_dir: str,\n                     image_names: Sequence[str],\n                     exposure_percentile: float,\n                     n_downsample: int,\n                     ) -> Tuple[np.ndarray, MutableMapping[str, Any], bool]:\n  \"\"\"Loads and processes a set of RawNeRF input images.\n\n  Includes logic necessary for special \"test\" scenes that include a noiseless\n  ground truth frame, produced by HDR+ merge.\n\n  Args:\n    split: DataSplit.TRAIN or DataSplit.TEST, only used for test scene logic.\n    data_dir: base directory for scene data.\n    image_names: which images were successfully posed by COLMAP.\n    exposure_percentile: what brightness percentile to expose to white.\n    n_downsample: returned images are downsampled by a factor of n_downsample.\n\n  Returns:\n    A tuple (images, meta, testscene).\n    images: [N, height // n_downsample, width // n_downsample, 3] array of\n      demosaicked raw image data.\n    meta: EXIF metadata and other useful 
processing parameters. Includes per\n      image exposure information that can be passed into the NeRF model with\n      each ray: the set of unique exposure times is determined and each image\n      assigned a corresponding exposure index (mapping to an exposure value).\n      These are keys 'unique_shutters', 'exposure_idx', and\n      'exposure_values' in the `meta` dictionary.\n      We rescale so the maximum of `exposure_values` is 1 for convenience.\n    testscene: True when dataset includes ground truth test image, else False.\n  \"\"\"\n\n  image_dir = os.path.join(data_dir, 'raw')\n\n  testimg_file = os.path.join(data_dir, 'hdrplus_test/merged.dng')\n  testscene = utils.file_exists(testimg_file)\n  if testscene:\n    # Test scenes have train/ and test/ split subdirectories inside raw/.\n    image_dir = os.path.join(image_dir, split.value)\n    if split == utils.DataSplit.TEST:\n      # COLMAP image names not valid for test split of test scene.\n      image_names = None\n    else:\n      # Discard the first COLMAP image name as it is a copy of the test image.\n      image_names = image_names[1:]\n\n  raws, exifs = load_raw_images(image_dir, image_names)\n  meta = process_exif(exifs)\n\n  if testscene and split == utils.DataSplit.TEST:\n    # Test split for test scene must load the \"ground truth\" HDR+ merged image.\n    with utils.open_file(testimg_file, 'rb') as imgin:\n      testraw = rawpy.imread(imgin).raw_image\n    # HDR+ output has 2 extra bits of fixed precision, need to divide by 4.\n    testraw = testraw.astype(np.float32) / 4.\n    # Need to rescale long exposure test image by fast:slow shutter speed ratio.\n    fast_shutter = meta['ShutterSpeed'][0]\n    slow_shutter = meta['ShutterSpeed'][-1]\n    shutter_ratio = fast_shutter / slow_shutter\n    # Replace loaded raws with the \"ground truth\" test image.\n    raws = testraw[None]\n    # Test image shares metadata with the first loaded image (fast exposure).\n    meta = {k: meta[k][:1] for k in 
meta}\n  else:\n    shutter_ratio = 1.\n\n  # Next we determine an index for each unique shutter speed in the data.\n  shutter_speeds = meta['ShutterSpeed']\n  # Sort the shutter speeds from slowest (largest) to fastest (smallest).\n  # This way index 0 will always correspond to the brightest image.\n  unique_shutters = np.sort(np.unique(shutter_speeds))[::-1]\n  exposure_idx = np.zeros_like(shutter_speeds, dtype=np.int32)\n  for i, shutter in enumerate(unique_shutters):\n    # Assign index `i` to all images with shutter speed `shutter`.\n    exposure_idx[shutter_speeds == shutter] = i\n  meta['exposure_idx'] = exposure_idx\n  meta['unique_shutters'] = unique_shutters\n  # Rescale to use relative shutter speeds, where 1. is the brightest.\n  # This way the NeRF output with exposure=1 will always be reasonable.\n  meta['exposure_values'] = shutter_speeds / unique_shutters[0]\n\n  # Rescale raw sensor measurements to [0, 1] (plus noise).\n  blacklevel = meta['BlackLevel'].reshape(-1, 1, 1)\n  whitelevel = meta['WhiteLevel'].reshape(-1, 1, 1)\n  images = (raws - blacklevel) / (whitelevel - blacklevel) * shutter_ratio\n\n  # Calculate value for exposure level when gamma mapping, defaults to 97%.\n  # Always based on full resolution image 0 (for consistency).\n  image0_raw_demosaic = np.array(bilinear_demosaic_jax(images[0]))\n  image0_rgb = image0_raw_demosaic @ meta['cam2rgb'][0].T\n  exposure = np.percentile(image0_rgb, exposure_percentile)\n  meta['exposure'] = exposure\n  # Sweep over various exposure percentiles to visualize in training logs.\n  exposure_levels = {p: np.percentile(image0_rgb, p) for p in _PERCENTILE_LIST}\n  meta['exposure_levels'] = exposure_levels\n\n  # Create postprocessing function mapping raw images to tonemapped sRGB space.\n  cam2rgb0 = meta['cam2rgb'][0]\n  meta['postprocess_fn'] = lambda z, x=exposure: postprocess_raw(z, cam2rgb0, x)\n\n  # Demosaic Bayer images (preserves the measured RGGB values) and downsample\n  # if needed. 
Moving array to device + running processing function in Jax +\n  # copying back to CPU is faster than running directly on CPU.\n  def processing_fn(x):\n    x_jax = jnp.array(x)\n    x_demosaic_jax = bilinear_demosaic_jax(x_jax)\n    if n_downsample > 1:\n      x_demosaic_jax = lib_image.downsample(x_demosaic_jax, n_downsample)\n    return np.array(x_demosaic_jax)\n  images = np.stack([processing_fn(im) for im in images], axis=0)\n\n  return images, meta, testscene\n\n\ndef best_fit_affine(x: _Array, y: _Array, axis: _Axis) -> _Array:\n  \"\"\"Computes best fit a, b such that a * x + b = y, in a least square sense.\"\"\"\n  x_m = x.mean(axis=axis)\n  y_m = y.mean(axis=axis)\n  xy_m = (x * y).mean(axis=axis)\n  xx_m = (x * x).mean(axis=axis)\n  # slope a = Cov(x, y) / Cov(x, x).\n  a = (xy_m - x_m * y_m) / (xx_m - x_m * x_m)\n  b = y_m - a * x_m\n  return a, b\n\n\ndef match_images_affine(est: _Array, gt: _Array,\n                        axis: _Axis = (0, 1)) -> _Array:\n  \"\"\"Computes affine best fit of gt->est, then maps est back to match gt.\"\"\"\n  # Mapping is computed gt->est to be robust since `est` may be very noisy.\n  a, b = best_fit_affine(gt, est, axis=axis)\n  # Inverse mapping back to gt ensures we use a consistent space for metrics.\n  est_matched = (est - b) / a\n  return est_matched\n"
  },
  {
    "path": "mip360/internal/ref_utils.py",
    "content": "# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Functions for reflection directions and directional encodings.\"\"\"\n\nfrom internal import math\nimport jax.numpy as jnp\nimport numpy as np\n\n\ndef reflect(viewdirs, normals):\n  \"\"\"Reflect view directions about normals.\n\n  The reflection of a vector v about a unit vector n is a vector u such that\n  dot(v, n) = dot(u, n), and dot(u, u) = dot(v, v). The solution to these two\n  equations is u = 2 dot(n, v) n - v.\n\n  Args:\n    viewdirs: [..., 3] array of view directions.\n    normals: [..., 3] array of normal directions (assumed to be unit vectors).\n\n  Returns:\n    [..., 3] array of reflection directions.\n  \"\"\"\n  return 2.0 * jnp.sum(\n      normals * viewdirs, axis=-1, keepdims=True) * normals - viewdirs\n\n\ndef l2_normalize(x, eps=jnp.finfo(jnp.float32).eps):\n  \"\"\"Normalize x to unit length along last axis.\"\"\"\n  return x / jnp.sqrt(jnp.maximum(jnp.sum(x**2, axis=-1, keepdims=True), eps))\n\n\ndef compute_weighted_mae(weights, normals, normals_gt):\n  \"\"\"Compute weighted mean angular error, assuming normals are unit length.\"\"\"\n  one_eps = 1 - jnp.finfo(jnp.float32).eps\n  return (weights * jnp.arccos(\n      jnp.clip((normals * normals_gt).sum(-1), -one_eps,\n               one_eps))).sum() / weights.sum() * 180.0 / jnp.pi\n\n\ndef generalized_binomial_coeff(a, k):\n  \"\"\"Compute generalized binomial coefficients.\"\"\"\n  
return np.prod(a - np.arange(k)) / np.math.factorial(k)\n\n\ndef assoc_legendre_coeff(l, m, k):\n  \"\"\"Compute associated Legendre polynomial coefficients.\n\n  Returns the coefficient of the cos^k(theta)*sin^m(theta) term in the\n  (l, m)th associated Legendre polynomial, P_l^m(cos(theta)).\n\n  Args:\n    l: associated Legendre polynomial degree.\n    m: associated Legendre polynomial order.\n    k: power of cos(theta).\n\n  Returns:\n    A float, the coefficient of the term corresponding to the inputs.\n  \"\"\"\n  return ((-1)**m * 2**l * np.math.factorial(l) / np.math.factorial(k) /\n          np.math.factorial(l - k - m) *\n          generalized_binomial_coeff(0.5 * (l + k + m - 1.0), l))\n\n\ndef sph_harm_coeff(l, m, k):\n  \"\"\"Compute spherical harmonic coefficients.\"\"\"\n  return (np.sqrt(\n      (2.0 * l + 1.0) * np.math.factorial(l - m) /\n      (4.0 * np.pi * np.math.factorial(l + m))) * assoc_legendre_coeff(l, m, k))\n\n\ndef get_ml_array(deg_view):\n  \"\"\"Create a list with all pairs of (l, m) values to use in the encoding.\"\"\"\n  ml_list = []\n  for i in range(deg_view):\n    l = 2**i\n    # Only use nonnegative m values, later splitting real and imaginary parts.\n    for m in range(l + 1):\n      ml_list.append((m, l))\n\n  # Convert list into a numpy array.\n  ml_array = np.array(ml_list).T\n  return ml_array\n\n\ndef generate_ide_fn(deg_view):\n  \"\"\"Generate integrated directional encoding (IDE) function.\n\n  This function returns a function that computes the integrated directional\n  encoding from Equations 6-8 of arxiv.org/abs/2112.03907.\n\n  Args:\n    deg_view: number of spherical harmonics degrees to use.\n\n  Returns:\n    A function for evaluating integrated directional encoding.\n\n  Raises:\n    ValueError: if deg_view is larger than 5.\n  \"\"\"\n  if deg_view > 5:\n    raise ValueError('Only deg_view of at most 5 is numerically stable.')\n\n  ml_array = get_ml_array(deg_view)\n  l_max = 2**(deg_view - 1)\n\n  # Create a 
matrix corresponding to ml_array holding all coefficients, which,\n  # when multiplied (from the right) by the z coordinate Vandermonde matrix,\n  # results in the z component of the encoding.\n  mat = np.zeros((l_max + 1, ml_array.shape[1]))\n  for i, (m, l) in enumerate(ml_array.T):\n    for k in range(l - m + 1):\n      mat[k, i] = sph_harm_coeff(l, m, k)\n\n  def integrated_dir_enc_fn(xyz, kappa_inv):\n    \"\"\"Function returning integrated directional encoding (IDE).\n\n    Args:\n      xyz: [..., 3] array of Cartesian coordinates of directions to evaluate at.\n      kappa_inv: [..., 1] reciprocal of the concentration parameter of the von\n        Mises-Fisher distribution.\n\n    Returns:\n      An array with the resulting IDE.\n    \"\"\"\n    x = xyz[..., 0:1]\n    y = xyz[..., 1:2]\n    z = xyz[..., 2:3]\n\n    # Compute z Vandermonde matrix.\n    vmz = jnp.concatenate([z**i for i in range(mat.shape[0])], axis=-1)\n\n    # Compute x+iy Vandermonde matrix.\n    vmxy = jnp.concatenate([(x + 1j * y)**m for m in ml_array[0, :]], axis=-1)\n\n    # Get spherical harmonics.\n    sph_harms = vmxy * math.matmul(vmz, mat)\n\n    # Apply attenuation function using the von Mises-Fisher distribution\n    # concentration parameter, kappa.\n    sigma = 0.5 * ml_array[1, :] * (ml_array[1, :] + 1)\n    ide = sph_harms * jnp.exp(-sigma * kappa_inv)\n\n    # Split into real and imaginary parts and return\n    return jnp.concatenate([jnp.real(ide), jnp.imag(ide)], axis=-1)\n\n  return integrated_dir_enc_fn\n\n\ndef generate_dir_enc_fn(deg_view):\n  \"\"\"Generate directional encoding (DE) function.\n\n  Args:\n    deg_view: number of spherical harmonics degrees to use.\n\n  Returns:\n    A function for evaluating directional encoding.\n  \"\"\"\n  integrated_dir_enc_fn = generate_ide_fn(deg_view)\n\n  def dir_enc_fn(xyz):\n    \"\"\"Function returning directional encoding (DE).\"\"\"\n    return integrated_dir_enc_fn(xyz, jnp.zeros_like(xyz[..., :1]))\n\n  return 
dir_enc_fn\n"
  },
  {
    "path": "mip360/internal/render.py",
    "content": "# Copyright 2023 Ze-Xin Yin\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Helper functions for shooting and rendering rays.\"\"\"\n\nfrom internal import stepfun\nimport jax.numpy as jnp\nfrom einops import rearrange, repeat\n\n\ndef lift_gaussian(d, t_mean, t_var, r_var, diag):\n  \"\"\"Lift a Gaussian defined along a ray to 3D coordinates.\"\"\"\n  mean = d[..., None, :] * t_mean[..., None]\n\n  d_mag_sq = jnp.maximum(1e-10, jnp.sum(d**2, axis=-1, keepdims=True))\n\n  if diag:\n    d_outer_diag = d**2\n    null_outer_diag = 1 - d_outer_diag / d_mag_sq\n    t_cov_diag = t_var[..., None] * d_outer_diag[..., None, :]\n    xy_cov_diag = r_var[..., None] * null_outer_diag[..., None, :]\n    cov_diag = t_cov_diag + xy_cov_diag\n    return mean, cov_diag\n  else:\n    d_outer = d[..., :, None] * d[..., None, :]\n    eye = jnp.eye(d.shape[-1])\n    null_outer = eye - d[..., :, None] * (d / d_mag_sq)[..., None, :]\n    t_cov = t_var[..., None, None] * d_outer[..., None, :, :]\n    xy_cov = r_var[..., None, None] * null_outer[..., None, :, :]\n    cov = t_cov + xy_cov\n    return mean, cov\n\n\ndef conical_frustum_to_gaussian(d, t0, t1, base_radius, diag, stable=True):\n  \"\"\"Approximate a conical frustum as a Gaussian distribution (mean+cov).\n\n  Assumes the ray is originating from the origin, and base_radius is the\n  radius at dist=1. 
Doesn't assume `d` is normalized.\n\n  Args:\n    d: jnp.float32 3-vector, the axis of the cone.\n    t0: float, the starting distance of the frustum.\n    t1: float, the ending distance of the frustum.\n    base_radius: float, the scale of the radius as a function of distance.\n    diag: boolean, whether the Gaussian will be diagonal or full-covariance.\n    stable: boolean, whether or not to use the stable computation described in\n      the paper (setting this to False will cause catastrophic failure).\n\n  Returns:\n    a Gaussian (mean and covariance).\n  \"\"\"\n  if stable:\n    # Equation 7 in the paper (https://arxiv.org/abs/2103.13415).\n    mu = (t0 + t1) / 2  # The average of the two `t` values.\n    hw = (t1 - t0) / 2  # The half-width of the two `t` values.\n    eps = jnp.finfo(jnp.float32).eps\n    # Shared denominator, clamped away from zero.\n    denom = jnp.maximum(eps, 3 * mu**2 + hw**2)\n    t_mean = mu + (2 * mu * hw**2) / denom\n    t_var = (hw**2) / 3 - (4 / 15) * hw**4 * (12 * mu**2 - hw**2) / denom**2\n    r_var = (mu**2) / 4 + (5 / 12) * hw**2 - (4 / 15) * (hw**4) / denom\n  else:\n    # Equations 37-39 in the paper.\n    t_mean = (3 * (t1**4 - t0**4)) / (4 * (t1**3 - t0**3))\n    r_var = 3 / 20 * (t1**5 - t0**5) / (t1**3 - t0**3)\n    t_mosq = 3 / 5 * (t1**5 - t0**5) / (t1**3 - t0**3)\n    t_var = t_mosq - t_mean**2\n  r_var *= base_radius**2\n  return lift_gaussian(d, t_mean, t_var, r_var, diag)\n\n\ndef cylinder_to_gaussian(d, t0, t1, radius, diag):\n  \"\"\"Approximate a cylinder as a Gaussian distribution (mean+cov).\n\n  Assumes the ray originates from the origin, and `radius` is the radius of\n  the cylinder. 
Does not renormalize `d`.\n\n  Args:\n    d: jnp.float32 3-vector, the axis of the cylinder.\n    t0: float, the starting distance of the cylinder.\n    t1: float, the ending distance of the cylinder.\n    radius: float, the radius of the cylinder.\n    diag: boolean, whether the Gaussian will be diagonal or full-covariance.\n\n  Returns:\n    a Gaussian (mean and covariance).\n  \"\"\"\n  t_mean = (t0 + t1) / 2\n  r_var = radius**2 / 4\n  t_var = (t1 - t0)**2 / 12\n  return lift_gaussian(d, t_mean, t_var, r_var, diag)\n\n\ndef cast_rays(tdist, origins, directions, radii, ray_shape, diag=True):\n  \"\"\"Cast rays (cone- or cylinder-shaped) and featurize sections of them.\n\n  Args:\n    tdist: float array, the \"fencepost\" distances along the ray.\n    origins: float array, the ray origin coordinates.\n    directions: float array, the ray direction vectors.\n    radii: float array, the radii (base radii for cones) of the rays.\n    ray_shape: string, the shape of the ray, must be 'cone' or 'cylinder'.\n    diag: boolean, whether or not the covariance matrices should be diagonal.\n\n  Returns:\n    a tuple of arrays of means and covariances.\n    means: [n_rays, n_samples, 3]\n    covs: [n_rays, n_samples, 3] if `diag` is True (diagonal entries only),\n      otherwise [n_rays, n_samples, 3, 3].\n  \"\"\"\n  t0 = tdist[..., :-1]\n  t1 = tdist[..., 1:]\n  if ray_shape == 'cone':\n    gaussian_fn = conical_frustum_to_gaussian\n  elif ray_shape == 'cylinder':\n    gaussian_fn = cylinder_to_gaussian\n  else:\n    raise ValueError('ray_shape must be \\'cone\\' or \\'cylinder\\'')\n  means, covs = gaussian_fn(directions, t0, t1, radii, diag)\n  means = means + origins[..., None, :]\n  return means, covs\n\n\ndef compute_alpha_weights(density, tdist, dirs, opaque_background=False):\n  \"\"\"Helper function for computing alpha compositing weights.\"\"\"\n  t_delta = tdist[..., 1:] - tdist[..., :-1]\n  delta = t_delta * jnp.linalg.norm(dirs[..., None, :], axis=-1)\n  density_delta = density 
* delta\n\n  if opaque_background:\n    # Equivalent to making the final t-interval infinitely wide.\n    density_delta = jnp.concatenate([\n        density_delta[..., :-1],\n        jnp.full_like(density_delta[..., -1:], jnp.inf)\n    ], axis=-1)\n\n  alpha = 1 - jnp.exp(-density_delta)\n  trans = jnp.exp(-jnp.concatenate([\n      jnp.zeros_like(density_delta[..., :1]),\n      jnp.cumsum(density_delta[..., :-1], axis=-1)\n  ], axis=-1))\n  weights = alpha * trans\n  return weights, alpha, trans\n\n\ndef compute_alpha_weights_multispace(density, tdist, dirs,\n                                     opaque_background=False):\n  \"\"\"Helper function for computing per-space alpha compositing weights.\n\n  Here `density` carries a trailing sub-space axis, [..., num_samples,\n  num_space], so samples lie along axis -2.\n  \"\"\"\n  t_delta = tdist[..., 1:] - tdist[..., :-1]\n  delta = t_delta * jnp.linalg.norm(dirs[..., None, :], axis=-1)\n  density_delta = density * delta[..., None]\n\n  if opaque_background:\n    # Equivalent to making the final t-interval infinitely wide.\n    density_delta = jnp.concatenate([\n        density_delta[..., :-1, :],\n        jnp.full_like(density_delta[..., -1:, :], jnp.inf)\n    ], axis=-2)\n\n  alpha = 1 - jnp.exp(-density_delta)\n  # Transmittance accumulates along the sample axis (-2), independently for\n  # each sub-space, mirroring compute_alpha_weights above.\n  trans = jnp.exp(-jnp.concatenate([\n      jnp.zeros_like(density_delta[..., :1, :]),\n      jnp.cumsum(density_delta[..., :-1, :], axis=-2)\n  ], axis=-2))\n  weights = alpha * trans\n  return weights, alpha, trans\n\n\ndef volumetric_rendering(rgbs,\n                         weights,\n                         tdist,\n                         bg_rgbs,\n                         t_far,\n                         compute_extras,\n                         extras=None):\n  \"\"\"Volumetric Rendering Function.\n\n  Args:\n    rgbs: jnp.ndarray(float32), color, [batch_size, num_samples, 3]\n    weights: jnp.ndarray(float32), weights, [batch_size, num_samples].\n    tdist: jnp.ndarray(float32), [batch_size, num_samples + 1].\n    bg_rgbs: 
jnp.ndarray(float32), the color(s) to use for the background.\n    t_far: jnp.ndarray(float32), [batch_size, 1], the distance of the far plane.\n    compute_extras: bool, if True, compute extra quantities besides color.\n    extras: dict, a set of values along rays to render by alpha compositing.\n\n  Returns:\n    rendering: a dict containing an rgb image of size [batch_size, 3], and other\n      visualizations if compute_extras=True.\n  \"\"\"\n  eps = jnp.finfo(jnp.float32).eps\n  rendering = {}\n\n  acc = weights.sum(axis=-1)\n  bg_w = jnp.maximum(0, 1 - acc[..., None])  # The weight of the background.\n  rgb = (weights[..., None] * rgbs).sum(axis=-2) + bg_w * bg_rgbs\n  rendering['rgb'] = rgb\n\n  if compute_extras:\n    rendering['acc'] = acc\n\n    if extras is not None:\n      for k, v in extras.items():\n        if v is not None:\n          rendering[k] = (weights[..., None] * v).sum(axis=-2)\n\n    expectation = lambda x: (weights * x).sum(axis=-1) / jnp.maximum(eps, acc)\n    t_mids = 0.5 * (tdist[..., :-1] + tdist[..., 1:])\n    # For numerical stability this expectation is computed using log-distance.\n    rendering['distance_mean'] = (\n        jnp.clip(\n            jnp.nan_to_num(jnp.exp(expectation(jnp.log(t_mids))), jnp.inf),\n            tdist[..., 0], tdist[..., -1]))\n\n    # Add an extra fencepost with the far distance at the end of each ray, with\n    # whatever weight is needed to make the new weight vector sum to exactly 1\n    # (`weights` is only guaranteed to sum to <= 1, not == 1).\n    t_aug = jnp.concatenate([tdist, t_far], axis=-1)\n    weights_aug = jnp.concatenate([weights, bg_w], axis=-1)\n\n    ps = [5, 50, 95]\n    distance_percentiles = stepfun.weighted_percentile(t_aug, weights_aug, ps)\n\n    for i, p in enumerate(ps):\n      s = 'median' if p == 50 else 'percentile_' + str(p)\n      rendering['distance_' + s] = distance_percentiles[..., i]\n\n  return rendering\n\n\ndef volumetric_rendering_multispace(feats,\n                 
                   weights,\n                                    tdist,\n                                    bg_rgbs,\n                                    t_far,\n                                    compute_extras,\n                                    num_space=1,\n                                    decoder=None,\n                                    extras=None):\n  \"\"\"Volumetric Rendering Function for multi space.\n\n  Args:\n    feats: jnp.ndarray(float32), ray marching feature, [batch_size, num_samples, num_feat * num_space]\n    weights: jnp.ndarray(float32), weights, [batch_size, num_samples, num_space].\n    tdist: jnp.ndarray(float32), [batch_size, num_samples].\n    bg_rgbs: jnp.ndarray(float32), the color(s) to use for the background.\n    t_far: jnp.ndarray(float32), [batch_size, 1], the distance of the far plane.\n    compute_extras: bool, if True, compute extra quantities besides color.\n    decoder: decoder to transform features to rgbs\n    extras: dict, a set of values along rays to render by alpha compositing.\n\n  Returns:\n    rendering: a dict containing an rgb image of size [batch_size, 3], and other\n      visualizations if compute_extras=True.\n  \"\"\"\n  eps = jnp.finfo(jnp.float32).eps\n  rendering = {}\n  n_rays, n_samples = feats.shape[:2]\n  \n  feats = rearrange(feats, '... (n c) -> ... n c', n = num_space)\n  feats = (weights[..., None] * feats).sum(axis=-3)\n  rgb, sub_rgbs, alpha = decoder(feats)\n\n  acc = (weights.sum(axis=-2) * alpha[..., 0]).sum(-1)\n  bg_w = jnp.maximum(0, 1 - acc[..., None])  # The weight of the background.\n  rendering['rgb'] = rgb + bg_w * bg_rgbs\n  for n in range(num_space):\n    rendering[f'rgb_{n}'] = sub_rgbs[..., n, :]\n  weights = (weights * repeat(alpha[..., 0], '... s -> ... repeat s', repeat=n_samples)).sum(-1)\n  \n  ray_rgb = repeat(rendering['rgb'], 'r ... c -> r ... 
s c', s=n_samples) / n_samples / jnp.linalg.norm(weights[..., None], axis=-2, keepdims=True)\n\n  if compute_extras:\n    rendering['acc'] = acc\n\n    if extras is not None:\n      for k, v in extras.items():\n        if v is not None:\n          rendering[k] = (weights[..., None] * v).sum(axis=-2)\n\n    expectation = lambda x: (weights * x).sum(axis=-1) / jnp.maximum(eps, acc)\n    t_mids = 0.5 * (tdist[..., :-1] + tdist[..., 1:])\n    # For numerical stability this expectation is computed using log-distance.\n    rendering['distance_mean'] = (\n        jnp.clip(\n            jnp.nan_to_num(jnp.exp(expectation(jnp.log(t_mids))), jnp.inf),\n            tdist[..., 0], tdist[..., -1]))\n\n    # Add an extra fencepost with the far distance at the end of each ray, with\n    # whatever weight is needed to make the new weight vector sum to exactly 1\n    # (`weights` is only guaranteed to sum to <= 1, not == 1).\n    t_aug = jnp.concatenate([tdist, t_far], axis=-1)\n    weights_aug = jnp.concatenate([weights, bg_w], axis=-1)\n\n    ps = [5, 50, 95]\n    distance_percentiles = stepfun.weighted_percentile(t_aug, weights_aug, ps)\n\n    for i, p in enumerate(ps):\n      s = 'median' if p == 50 else 'percentile_' + str(p)\n      rendering['distance_' + s] = distance_percentiles[..., i]\n\n  return rendering, weights, ray_rgb\n"
  },
  {
    "path": "mip360/internal/stepfun.py",
    "content": "# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tools for manipulating step functions (piecewise-constant 1D functions).\n\nWe have a shared naming and dimension convention for these functions.\nAll input/output step functions are assumed to be aligned along the last axis.\n`t` always indicates the x coordinates of the *endpoints* of a step function.\n`y` indicates unconstrained values for the *bins* of a step function\n`w` indicates bin weights that sum to <= 1. `p` indicates non-negative bin\nvalues that *integrate* to <= 1.\n\"\"\"\n\nfrom internal import math\nimport jax\nimport jax.numpy as jnp\n\n\ndef searchsorted(a, v):\n  \"\"\"Find indices where v should be inserted into a to maintain order.\n\n  This behaves like jnp.searchsorted (its second output is the same as\n  jnp.searchsorted's output if all elements of v are in [a[0], a[-1]]) but is\n  faster because it wastes memory to save some compute.\n\n  Args:\n    a: tensor, the sorted reference points that we are scanning to see where v\n      should lie.\n    v: tensor, the query points that we are pretending to insert into a. Does\n      not need to be sorted. 
All but the last dimensions should match or expand\n      to those of a, the last dimension can differ.\n\n  Returns:\n    (idx_lo, idx_hi), where a[idx_lo] <= v < a[idx_hi], unless v is out of the\n    range [a[0], a[-1]] in which case idx_lo and idx_hi are both the first or\n    last index of a.\n  \"\"\"\n  i = jnp.arange(a.shape[-1])\n  v_ge_a = v[..., None, :] >= a[..., :, None]\n  idx_lo = jnp.max(jnp.where(v_ge_a, i[..., :, None], i[..., :1, None]), -2)\n  idx_hi = jnp.min(jnp.where(~v_ge_a, i[..., :, None], i[..., -1:, None]), -2)\n  return idx_lo, idx_hi\n\n\ndef query(tq, t, y, outside_value=0):\n  \"\"\"Look up the values of the step function (t, y) at locations tq.\"\"\"\n  idx_lo, idx_hi = searchsorted(t, tq)\n  yq = jnp.where(idx_lo == idx_hi, outside_value,\n                 jnp.take_along_axis(y, idx_lo, axis=-1))\n  return yq\n\n\ndef inner_outer(t0, t1, y1):\n  \"\"\"Construct inner and outer measures on (t1, y1) for t0.\"\"\"\n  cy1 = jnp.concatenate([jnp.zeros_like(y1[..., :1]),\n                         jnp.cumsum(y1, axis=-1)],\n                        axis=-1)\n  idx_lo, idx_hi = searchsorted(t1, t0)\n\n  cy1_lo = jnp.take_along_axis(cy1, idx_lo, axis=-1)\n  cy1_hi = jnp.take_along_axis(cy1, idx_hi, axis=-1)\n\n  y0_outer = cy1_hi[..., 1:] - cy1_lo[..., :-1]\n  y0_inner = jnp.where(idx_hi[..., :-1] <= idx_lo[..., 1:],\n                       cy1_lo[..., 1:] - cy1_hi[..., :-1], 0)\n  return y0_inner, y0_outer\n\n\ndef lossfun_outer(t, w, t_env, w_env, eps=jnp.finfo(jnp.float32).eps):\n  \"\"\"The proposal weight should be an upper envelope on the nerf weight.\"\"\"\n  _, w_outer = inner_outer(t, t_env, w_env)\n  # We assume w_inner <= w <= w_outer. 
We don't penalize w_inner because it's\n  # more effective to pull w_outer up than it is to push w_inner down.\n  # Scaled half-quadratic loss that gives a constant gradient at w_outer = 0.\n  return jnp.maximum(0, w - w_outer)**2 / (w + eps)\n\n\ndef weight_to_pdf(t, w, eps=jnp.finfo(jnp.float32).eps**2):\n  \"\"\"Turn a vector of weights that sums to 1 into a PDF that integrates to 1.\"\"\"\n  return w / jnp.maximum(eps, (t[..., 1:] - t[..., :-1]))\n\n\ndef pdf_to_weight(t, p):\n  \"\"\"Turn a PDF that integrates to 1 into a vector of weights that sums to 1.\"\"\"\n  return p * (t[..., 1:] - t[..., :-1])\n\n\ndef max_dilate(t, w, dilation, domain=(-jnp.inf, jnp.inf)):\n  \"\"\"Dilate (via max-pooling) a non-negative step function.\"\"\"\n  t0 = t[..., :-1] - dilation\n  t1 = t[..., 1:] + dilation\n  t_dilate = jnp.sort(jnp.concatenate([t, t0, t1], axis=-1), axis=-1)\n  t_dilate = jnp.clip(t_dilate, *domain)\n  w_dilate = jnp.max(\n      jnp.where(\n          (t0[..., None, :] <= t_dilate[..., None])\n          & (t1[..., None, :] > t_dilate[..., None]),\n          w[..., None, :],\n          0,\n      ),\n      axis=-1)[..., :-1]\n  return t_dilate, w_dilate\n\n\ndef max_dilate_weights(t,\n                       w,\n                       dilation,\n                       domain=(-jnp.inf, jnp.inf),\n                       renormalize=False,\n                       eps=jnp.finfo(jnp.float32).eps**2):\n  \"\"\"Dilate (via max-pooling) a set of weights.\"\"\"\n  p = weight_to_pdf(t, w)\n  t_dilate, p_dilate = max_dilate(t, p, dilation, domain=domain)\n  w_dilate = pdf_to_weight(t_dilate, p_dilate)\n  if renormalize:\n    w_dilate /= jnp.maximum(eps, jnp.sum(w_dilate, axis=-1, keepdims=True))\n  return t_dilate, w_dilate\n\n\ndef integrate_weights(w):\n  \"\"\"Compute the cumulative sum of w, assuming all weight vectors sum to 1.\n\n  The output's size on the last dimension is one greater than that of the input,\n  because we're computing the integral corresponding 
to the endpoints of a step\n  function, not the integral of the interior/bin values.\n\n  Args:\n    w: Tensor, which will be integrated along the last axis. This is assumed to\n      sum to 1 along the last axis, and this function will (silently) break if\n      that is not the case.\n\n  Returns:\n    cw0: Tensor, the integral of w, where cw0[..., 0] = 0 and cw0[..., -1] = 1\n  \"\"\"\n  cw = jnp.minimum(1, jnp.cumsum(w[..., :-1], axis=-1))\n  shape = cw.shape[:-1] + (1,)\n  # Ensure that the CDF starts with exactly 0 and ends with exactly 1.\n  cw0 = jnp.concatenate([jnp.zeros(shape), cw, jnp.ones(shape)], axis=-1)\n  return cw0\n\n\ndef invert_cdf(u, t, w_logits, use_gpu_resampling=False):\n  \"\"\"Invert the CDF defined by (t, w) at the points specified by u in [0, 1).\"\"\"\n  # Compute the PDF and CDF for each weight vector.\n  w = jax.nn.softmax(w_logits, axis=-1)\n  cw = integrate_weights(w)\n  # Interpolate into the inverse CDF.\n  interp_fn = math.interp if use_gpu_resampling else math.sorted_interp\n  t_new = interp_fn(u, cw, t)\n  return t_new\n\n\ndef sample(rng,\n           t,\n           w_logits,\n           num_samples,\n           single_jitter=False,\n           deterministic_center=False,\n           use_gpu_resampling=False):\n  \"\"\"Piecewise-Constant PDF sampling from a step function.\n\n  Args:\n    rng: random number generator (or None for `linspace` sampling).\n    t: [..., num_bins + 1], bin endpoint coordinates (must be sorted)\n    w_logits: [..., num_bins], logits corresponding to bin weights\n    num_samples: int, the number of samples.\n    single_jitter: bool, if True, jitter every sample along each ray by the same\n      amount in the inverse CDF. Otherwise, jitter each sample independently.\n    deterministic_center: bool, if False, when `rng` is None return samples that\n      linspace the entire PDF. 
If True, skip the front and back of the linspace\n      so that the centers of each PDF interval are returned.\n    use_gpu_resampling: bool, If True this resamples the rays based on a\n      \"gather\" instruction, which is fast on GPUs but slow on TPUs. If False,\n      this resamples the rays based on brute-force searches, which is fast on\n      TPUs, but slow on GPUs.\n\n  Returns:\n    t_samples: jnp.ndarray(float32), [batch_size, num_samples].\n  \"\"\"\n  eps = jnp.finfo(jnp.float32).eps\n\n  # Draw uniform samples.\n  if rng is None:\n    # Match the behavior of jax.random.uniform() by spanning [0, 1-eps].\n    if deterministic_center:\n      pad = 1 / (2 * num_samples)\n      u = jnp.linspace(pad, 1. - pad - eps, num_samples)\n    else:\n      u = jnp.linspace(0, 1. - eps, num_samples)\n    u = jnp.broadcast_to(u, t.shape[:-1] + (num_samples,))\n  else:\n    # `u` is in [0, 1) --- it can be zero, but it can never be 1.\n    u_max = eps + (1 - eps) / num_samples\n    max_jitter = (1 - u_max) / (num_samples - 1) - eps\n    d = 1 if single_jitter else num_samples\n    u = (\n        jnp.linspace(0, 1 - u_max, num_samples) +\n        jax.random.uniform(rng, t.shape[:-1] + (d,), maxval=max_jitter))\n\n  return invert_cdf(u, t, w_logits, use_gpu_resampling=use_gpu_resampling)\n\n\ndef sample_intervals(rng,\n                     t,\n                     w_logits,\n                     num_samples,\n                     single_jitter=False,\n                     domain=(-jnp.inf, jnp.inf),\n                     use_gpu_resampling=False):\n  \"\"\"Sample *intervals* (rather than points) from a step function.\n\n  Args:\n    rng: random number generator (or None for `linspace` sampling).\n    t: [..., num_bins + 1], bin endpoint coordinates (must be sorted)\n    w_logits: [..., num_bins], logits corresponding to bin weights\n    num_samples: int, the number of intervals to sample.\n    single_jitter: bool, if True, jitter every sample along each ray by the same\n   
   amount in the inverse CDF. Otherwise, jitter each sample independently.\n    domain: (minval, maxval), the range of valid values for `t`.\n    use_gpu_resampling: bool, If True this resamples the rays based on a\n      \"gather\" instruction, which is fast on GPUs but slow on TPUs. If False,\n      this resamples the rays based on brute-force searches, which is fast on\n      TPUs, but slow on GPUs.\n\n  Returns:\n    t_samples: jnp.ndarray(float32), [batch_size, num_samples + 1].\n  \"\"\"\n  if num_samples <= 1:\n    raise ValueError(f'num_samples must be > 1, is {num_samples}.')\n\n  # Sample a set of points from the step function.\n  centers = sample(\n      rng,\n      t,\n      w_logits,\n      num_samples,\n      single_jitter,\n      deterministic_center=True,\n      use_gpu_resampling=use_gpu_resampling)\n\n  # The intervals we return will span the midpoints of each adjacent sample.\n  mid = (centers[..., 1:] + centers[..., :-1]) / 2\n\n  # Each first/last fencepost is the reflection of the first/last midpoint\n  # around the first/last sampled center. 
We clamp to the limits of the input\n  # domain, provided by the caller.\n  minval, maxval = domain\n  first = jnp.maximum(minval, 2 * centers[..., :1] - mid[..., :1])\n  last = jnp.minimum(maxval, 2 * centers[..., -1:] - mid[..., -1:])\n\n  t_samples = jnp.concatenate([first, mid, last], axis=-1)\n  return t_samples\n\n\ndef lossfun_distortion(t, w):\n  \"\"\"Compute iint w[i] w[j] |t[i] - t[j]| di dj.\"\"\"\n  # The loss incurred between all pairs of intervals.\n  ut = (t[..., 1:] + t[..., :-1]) / 2\n  dut = jnp.abs(ut[..., :, None] - ut[..., None, :])\n  loss_inter = jnp.sum(w * jnp.sum(w[..., None, :] * dut, axis=-1), axis=-1)\n\n  # The loss incurred within each individual interval with itself.\n  loss_intra = jnp.sum(w**2 * (t[..., 1:] - t[..., :-1]), axis=-1) / 3\n\n  return loss_inter + loss_intra\n\n\ndef interval_distortion(t0_lo, t0_hi, t1_lo, t1_hi):\n  \"\"\"Compute mean(abs(x-y); x in [t0_lo, t0_hi], y in [t1_lo, t1_hi]).\"\"\"\n  # Distortion when the intervals do not overlap.\n  d_disjoint = jnp.abs((t1_lo + t1_hi) / 2 - (t0_lo + t0_hi) / 2)\n\n  # Distortion when the intervals overlap.\n  d_overlap = (2 *\n               (jnp.minimum(t0_hi, t1_hi)**3 - jnp.maximum(t0_lo, t1_lo)**3) +\n               3 * (t1_hi * t0_hi * jnp.abs(t1_hi - t0_hi) +\n                    t1_lo * t0_lo * jnp.abs(t1_lo - t0_lo) + t1_hi * t0_lo *\n                    (t0_lo - t1_hi) + t1_lo * t0_hi *\n                    (t1_lo - t0_hi))) / (6 * (t0_hi - t0_lo) * (t1_hi - t1_lo))\n\n  # Are the two intervals not overlapping?\n  are_disjoint = (t0_lo > t1_hi) | (t1_lo > t0_hi)\n\n  return jnp.where(are_disjoint, d_disjoint, d_overlap)\n\n\ndef weighted_percentile(t, w, ps):\n  \"\"\"Compute the weighted percentiles of a step function. 
w's must sum to 1.\"\"\"\n  cw = integrate_weights(w)\n  # We want to interpolate into the integrated weights according to `ps`.\n  fn = lambda cw_i, t_i: jnp.interp(jnp.array(ps) / 100, cw_i, t_i)\n  # Vmap fn to an arbitrary number of leading dimensions.\n  cw_mat = cw.reshape([-1, cw.shape[-1]])\n  t_mat = t.reshape([-1, t.shape[-1]])\n  wprctile_mat = (jax.vmap(fn, 0)(cw_mat, t_mat))\n  wprctile = wprctile_mat.reshape(cw.shape[:-1] + (len(ps),))\n  return wprctile\n\n\ndef resample(t, tp, vp, use_avg=False, eps=jnp.finfo(jnp.float32).eps):\n  \"\"\"Resample a step function defined by (tp, vp) into intervals t.\n\n  Notation roughly matches jnp.interp. Resamples by summation by default.\n\n  Args:\n    t: tensor with shape (..., n+1), the endpoints to resample into.\n    tp: tensor with shape (..., m+1), the endpoints of the step function being\n      resampled.\n    vp: tensor with shape (..., m), the values of the step function being\n      resampled.\n    use_avg: bool, if False, return the sum of the step function for each\n      interval in `t`. If True, return the average, weighted by the width of\n      each interval in `t`.\n    eps: float, a small value to prevent division by zero when use_avg=True.\n\n  Returns:\n    v: tensor with shape (..., n), the values of the resampled step function.\n  \"\"\"\n  if use_avg:\n    wp = jnp.diff(tp, axis=-1)\n    v_numer = resample(t, tp, vp * wp, use_avg=False)\n    v_denom = resample(t, tp, wp, use_avg=False)\n    v = v_numer / jnp.maximum(eps, v_denom)\n    return v\n\n  acc = jnp.cumsum(vp, axis=-1)\n  acc0 = jnp.concatenate([jnp.zeros(acc.shape[:-1] + (1,)), acc], axis=-1)\n  acc0_resampled = jnp.vectorize(\n      jnp.interp, signature='(n),(m),(m)->(n)')(t, tp, acc0)\n  v = jnp.diff(acc0_resampled, axis=-1)\n  return v\n"
  },
  {
    "path": "mip360/internal/train_utils.py",
    "content": "# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Training step and model creation functions.\"\"\"\n\nimport collections\nimport functools\nfrom typing import Any, Callable, Dict, MutableMapping, Optional, Text, Tuple\n\nfrom flax.core.scope import FrozenVariableDict\nfrom flax.training.train_state import TrainState\nfrom internal import camera_utils\nfrom internal import configs\nfrom internal import datasets\nfrom internal import image\nfrom internal import math\nfrom internal import models\nfrom internal import ref_utils\nfrom internal import stepfun\nfrom internal import utils\nimport jax\nfrom jax import random\nimport jax.numpy as jnp\nimport optax\n\n\ndef tree_sum(tree):\n  return jax.tree_util.tree_reduce(lambda x, y: x + y, tree, initializer=0)\n\n\ndef tree_norm_sq(tree):\n  return tree_sum(jax.tree_util.tree_map(lambda x: jnp.sum(x**2), tree))\n\n\ndef tree_norm(tree):\n  return jnp.sqrt(tree_norm_sq(tree))\n\n\ndef tree_abs_max(tree):\n  return jax.tree_util.tree_reduce(\n      lambda x, y: jnp.maximum(x, jnp.max(jnp.abs(y))), tree, initializer=0)\n\n\ndef tree_len(tree):\n  return tree_sum(\n      jax.tree_util.tree_map(lambda z: jnp.prod(jnp.array(z.shape)), tree))\n\n\ndef summarize_tree(tree, fn, ancestry=(), max_depth=3):\n  \"\"\"Flatten 'tree' while 'fn'-ing values and formatting keys like/this.\"\"\"\n  stats = {}\n  for k, v in tree.items():\n    name = ancestry + (k,)\n    
stats['/'.join(name)] = fn(v)\n    if hasattr(v, 'items') and len(ancestry) < (max_depth - 1):\n      stats.update(summarize_tree(v, fn, ancestry=name, max_depth=max_depth))\n  return stats\n\n\ndef compute_data_loss(batch, renderings, rays, config):\n  \"\"\"Computes data loss terms for RGB, normal, and depth outputs.\"\"\"\n  data_losses = []\n  stats = collections.defaultdict(lambda: [])\n\n  # lossmult can be used to apply a weight to each ray in the batch.\n  # For example: masking out rays, applying the Bayer mosaic mask, upweighting\n  # rays from lower resolution images and so on.\n  lossmult = rays.lossmult\n  lossmult = jnp.broadcast_to(lossmult, batch.rgb[..., :3].shape)\n  if config.disable_multiscale_loss:\n    lossmult = jnp.ones_like(lossmult)\n\n  for rendering in renderings:\n    resid_sq = (rendering['rgb'] - batch.rgb[..., :3])**2\n    denom = lossmult.sum()\n    stats['mses'].append((lossmult * resid_sq).sum() / denom)\n\n    if config.data_loss_type == 'mse':\n      # Mean-squared error (L2) loss.\n      data_loss = resid_sq\n    elif config.data_loss_type == 'charb':\n      # Charbonnier loss.\n      data_loss = jnp.sqrt(resid_sq + config.charb_padding**2)\n    elif config.data_loss_type == 'rawnerf':\n      # Clip raw values against 1 to match sensor overexposure behavior.\n      rgb_render_clip = jnp.minimum(1., rendering['rgb'])\n      resid_sq_clip = (rgb_render_clip - batch.rgb[..., :3])**2\n      # Scale by gradient of log tonemapping curve.\n      scaling_grad = 1. 
/ (1e-3 + jax.lax.stop_gradient(rgb_render_clip))\n      # Reweighted L2 loss.\n      data_loss = resid_sq_clip * scaling_grad**2\n    else:\n      assert False\n    data_losses.append((lossmult * data_loss).sum() / denom)\n\n    if config.compute_disp_metrics:\n      # Using mean to compute disparity, but other distance statistics can\n      # be used instead.\n      disp = 1 / (1 + rendering['distance_mean'])\n      stats['disparity_mses'].append(((disp - batch.disps)**2).mean())\n\n    if config.compute_normal_metrics:\n      if 'normals' in rendering:\n        weights = rendering['acc'] * batch.alphas\n        normalized_normals_gt = ref_utils.l2_normalize(batch.normals)\n        normalized_normals = ref_utils.l2_normalize(rendering['normals'])\n        normal_mae = ref_utils.compute_weighted_mae(weights, normalized_normals,\n                                                    normalized_normals_gt)\n      else:\n        # If normals are not computed, set MAE to NaN.\n        normal_mae = jnp.nan\n      stats['normal_maes'].append(normal_mae)\n\n  data_losses = jnp.array(data_losses)\n  loss = (\n      config.data_coarse_loss_mult * jnp.sum(data_losses[:-1]) +\n      config.data_loss_mult * data_losses[-1])\n  stats = {k: jnp.array(stats[k]) for k in stats}\n  return loss, stats\n\n\ndef interlevel_loss(ray_history, config):\n  \"\"\"Computes the interlevel loss defined in mip-NeRF 360.\"\"\"\n  # Stop the gradient from the interlevel loss onto the NeRF MLP.\n  last_ray_results = ray_history[-1]\n  c = jax.lax.stop_gradient(last_ray_results['sdist'])\n  w = jax.lax.stop_gradient(last_ray_results['weights'])\n  loss_interlevel = 0.\n  for ray_results in ray_history[:-1]:\n    cp = ray_results['sdist']\n    wp = ray_results['weights']\n    loss_interlevel += jnp.mean(stepfun.lossfun_outer(c, w, cp, wp))\n  return config.interlevel_loss_mult * loss_interlevel\n\n\ndef distortion_loss(ray_history, config):\n  \"\"\"Computes the distortion loss regularizer defined 
in mip-NeRF 360.\"\"\"\n  last_ray_results = ray_history[-1]\n  c = last_ray_results['sdist']\n  w = last_ray_results['weights']\n  loss = jnp.mean(stepfun.lossfun_distortion(c, w))\n  return config.distortion_loss_mult * loss\n\n\ndef orientation_loss(rays, model, ray_history, config):\n  \"\"\"Computes the orientation loss regularizer defined in ref-NeRF.\"\"\"\n  total_loss = 0.\n  for i, ray_results in enumerate(ray_history):\n    w = ray_results['weights']\n    n = ray_results[config.orientation_loss_target]\n    if n is None:\n      raise ValueError('Normals cannot be None if orientation loss is on.')\n    # Negate viewdirs to represent normalized vectors from point to camera.\n    v = -1. * rays.viewdirs\n    n_dot_v = (n * v[..., None, :]).sum(axis=-1)\n    loss = jnp.mean((w * jnp.minimum(0.0, n_dot_v)**2).sum(axis=-1))\n    if i < model.num_levels - 1:\n      total_loss += config.orientation_coarse_loss_mult * loss\n    else:\n      total_loss += config.orientation_loss_mult * loss\n  return total_loss\n\n\ndef predicted_normal_loss(model, ray_history, config):\n  \"\"\"Computes the predicted normal supervision loss defined in ref-NeRF.\"\"\"\n  total_loss = 0.\n  for i, ray_results in enumerate(ray_history):\n    w = ray_results['weights']\n    n = ray_results['normals']\n    n_pred = ray_results['normals_pred']\n    if n is None or n_pred is None:\n      raise ValueError(\n          'Predicted normals and gradient normals cannot be None if '\n          'predicted normal loss is on.')\n    loss = jnp.mean((w * (1.0 - jnp.sum(n * n_pred, axis=-1))).sum(axis=-1))\n    if i < model.num_levels - 1:\n      total_loss += config.predicted_normal_coarse_loss_mult * loss\n    else:\n      total_loss += config.predicted_normal_loss_mult * loss\n  return total_loss\n\n\ndef clip_gradients(grad, config):\n  \"\"\"Clips gradients of each MLP individually based on norm and max value.\"\"\"\n  # Clip the gradients of each MLP individually.\n  grad_clipped = {'params': 
{}}\n  for k, g in grad['params'].items():\n    # Clip by value.\n    if config.grad_max_val > 0:\n      g = jax.tree_util.tree_map(\n          lambda z: jnp.clip(z, -config.grad_max_val, config.grad_max_val), g)\n\n    # Then clip by norm.\n    if config.grad_max_norm > 0:\n      mult = jnp.minimum(\n          1, config.grad_max_norm / (jnp.finfo(jnp.float32).eps + tree_norm(g)))\n      g = jax.tree_util.tree_map(lambda z: mult * z, g)  # pylint:disable=cell-var-from-loop\n\n    grad_clipped['params'][k] = g\n  grad = type(grad)(grad_clipped)\n  return grad\n\n\ndef create_train_step(model: models.Model,\n                      config: configs.Config,\n                      dataset: Optional[datasets.Dataset] = None):\n  \"\"\"Creates the pmap'ed Nerf training function.\n\n  Args:\n    model: The linen model.\n    config: The configuration.\n    dataset: Training dataset.\n\n  Returns:\n    pmap'ed training function.\n  \"\"\"\n  if dataset is None:\n    camtype = camera_utils.ProjectionType.PERSPECTIVE\n  else:\n    camtype = dataset.camtype\n\n  def train_step(\n      rng,\n      state,\n      batch,\n      cameras,\n      train_frac,\n  ):\n    \"\"\"One optimization step.\n\n    Args:\n      rng: jnp.ndarray, random number generator.\n      state: TrainState, state of the model/optimizer.\n      batch: dict, a mini-batch of data for training.\n      cameras: module containing camera poses.\n      train_frac: float, the fraction of training that is complete.\n\n    Returns:\n      A tuple (new_state, stats, rng) with\n        new_state: TrainState, new training state.\n        stats: list. 
[(loss, psnr), (loss_coarse, psnr_coarse)].\n        rng: jnp.ndarray, updated random number generator.\n    \"\"\"\n    rng, key = random.split(rng)\n\n    def loss_fn(variables):\n      rays = batch.rays\n      if config.cast_rays_in_train_step:\n        rays = camera_utils.cast_ray_batch(cameras, rays, camtype, xnp=jnp)\n\n      # Indicates whether we need to compute output normal or depth maps in 2D.\n      compute_extras = (\n          config.compute_disp_metrics or config.compute_normal_metrics)\n\n      renderings, ray_history = model.apply(\n          variables,\n          key if config.randomized else None,\n          rays,\n          train_frac=train_frac,\n          compute_extras=compute_extras,\n          zero_glo=False)\n\n      losses = {}\n\n      data_loss, stats = compute_data_loss(batch, renderings, rays, config)\n      losses['data'] = data_loss\n\n      if config.interlevel_loss_mult > 0:\n        losses['interlevel'] = interlevel_loss(ray_history, config)\n\n      if config.distortion_loss_mult > 0:\n        losses['distortion'] = distortion_loss(ray_history, config)\n\n      if (config.orientation_coarse_loss_mult > 0 or\n          config.orientation_loss_mult > 0):\n        losses['orientation'] = orientation_loss(rays, model, ray_history,\n                                                 config)\n\n      if (config.predicted_normal_coarse_loss_mult > 0 or\n          config.predicted_normal_loss_mult > 0):\n        losses['predicted_normals'] = predicted_normal_loss(\n            model, ray_history, config)\n\n      stats['weight_l2s'] = summarize_tree(variables['params'], tree_norm_sq)\n\n      if config.weight_decay_mults:\n        it = config.weight_decay_mults.items\n        losses['weight'] = jnp.sum(\n            jnp.array([m * stats['weight_l2s'][k] for k, m in it()]))\n\n      stats['loss'] = jnp.sum(jnp.array(list(losses.values())))\n      stats['losses'] = losses\n\n      return stats['loss'], stats\n\n    loss_grad_fn = 
jax.value_and_grad(loss_fn, has_aux=True)\n    (_, stats), grad = loss_grad_fn(state.params)\n\n    pmean = lambda x: jax.lax.pmean(x, axis_name='batch')\n    grad = pmean(grad)\n    stats = pmean(stats)\n\n    stats['grad_norms'] = summarize_tree(grad['params'], tree_norm)\n    stats['grad_maxes'] = summarize_tree(grad['params'], tree_abs_max)\n\n    grad = clip_gradients(grad, config)\n\n    grad = jax.tree_util.tree_map(jnp.nan_to_num, grad)\n\n    new_state = state.apply_gradients(grads=grad)\n\n    opt_delta = jax.tree_util.tree_map(lambda x, y: x - y, new_state,\n                                       state).params['params']\n    stats['opt_update_norms'] = summarize_tree(opt_delta, tree_norm)\n    stats['opt_update_maxes'] = summarize_tree(opt_delta, tree_abs_max)\n\n    stats['psnrs'] = image.mse_to_psnr(stats['mses'])\n    stats['psnr'] = stats['psnrs'][-1]\n    return new_state, stats, rng\n\n  train_pstep = jax.pmap(\n      train_step,\n      axis_name='batch',\n      in_axes=(0, 0, 0, None, None),\n      donate_argnums=(0, 1))\n  return train_pstep\n\n\ndef create_optimizer(\n    config: configs.Config,\n    variables: FrozenVariableDict) -> Tuple[TrainState, Callable[[int], float]]:\n  \"\"\"Creates optax optimizer for model training.\"\"\"\n  adam_kwargs = {\n      'b1': config.adam_beta1,\n      'b2': config.adam_beta2,\n      'eps': config.adam_eps,\n  }\n  lr_kwargs = {\n      'max_steps': config.max_steps,\n      'lr_delay_steps': config.lr_delay_steps,\n      'lr_delay_mult': config.lr_delay_mult,\n  }\n\n  def get_lr_fn(lr_init, lr_final):\n    return functools.partial(\n        math.learning_rate_decay,\n        lr_init=lr_init,\n        lr_final=lr_final,\n        **lr_kwargs)\n\n  lr_fn_main = get_lr_fn(config.lr_init, config.lr_final)\n  tx = optax.adam(learning_rate=lr_fn_main, **adam_kwargs)\n\n  return TrainState.create(apply_fn=None, params=variables, tx=tx), lr_fn_main\n\n\ndef create_render_fn(model: models.Model):\n  \"\"\"Creates 
pmap'ed function for full image rendering.\"\"\"\n\n  def render_eval_fn(variables, train_frac, _, rays):\n    return jax.lax.all_gather(\n        model.apply(\n            variables,\n            None,  # Deterministic.\n            rays,\n            train_frac=train_frac,\n            compute_extras=True),\n        axis_name='batch')\n\n  # pmap over only the data input.\n  render_eval_pfn = jax.pmap(\n      render_eval_fn,\n      in_axes=(None, None, None, 0),\n      axis_name='batch',\n  )\n  return render_eval_pfn\n\n\ndef setup_model(\n    config: configs.Config,\n    rng: jnp.array,\n    dataset: Optional[datasets.Dataset] = None,\n) -> Tuple[models.Model, TrainState, Callable[\n    [FrozenVariableDict, jnp.array, utils.Rays],\n    MutableMapping[Text, Any]], Callable[\n        [jnp.array, TrainState, utils.Batch, Optional[Tuple[Any, ...]], float],\n        Tuple[TrainState, Dict[Text, Any], jnp.array]], Callable[[int], float]]:\n  \"\"\"Creates NeRF model, optimizer, and pmap-ed train/render functions.\"\"\"\n\n  dummy_rays = utils.dummy_rays(\n      include_exposure_idx=config.rawnerf_mode, include_exposure_values=True)\n  model, variables = models.construct_model(rng, dummy_rays, config)\n\n  state, lr_fn = create_optimizer(config, variables)\n  render_eval_pfn = create_render_fn(model)\n  train_pstep = create_train_step(model, config, dataset=dataset)\n\n  return model, state, render_eval_pfn, train_pstep, lr_fn\n"
  },
  {
    "path": "mip360/internal/utils.py",
    "content": "# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Utility functions.\"\"\"\n\nimport enum\nimport os\nfrom typing import Any, Dict, Optional, Union\n\nimport flax\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nfrom PIL import ExifTags\nfrom PIL import Image\n\n_Array = Union[np.ndarray, jnp.ndarray]\n\n\n@flax.struct.dataclass\nclass Pixels:\n  \"\"\"All tensors must have the same num_dims and first n-1 dims must match.\"\"\"\n  pix_x_int: _Array\n  pix_y_int: _Array\n  lossmult: _Array\n  near: _Array\n  far: _Array\n  cam_idx: _Array\n  exposure_idx: Optional[_Array] = None\n  exposure_values: Optional[_Array] = None\n\n\n@flax.struct.dataclass\nclass Rays:\n  \"\"\"All tensors must have the same num_dims and first n-1 dims must match.\"\"\"\n  origins: _Array\n  directions: _Array\n  viewdirs: _Array\n  radii: _Array\n  imageplane: _Array\n  lossmult: _Array\n  near: _Array\n  far: _Array\n  cam_idx: _Array\n  exposure_idx: Optional[_Array] = None\n  exposure_values: Optional[_Array] = None\n\n\n# Dummy Rays object that can be used to initialize NeRF model.\ndef dummy_rays(include_exposure_idx: bool = False,\n               include_exposure_values: bool = False) -> Rays:\n  data_fn = lambda n: jnp.zeros((1, n))\n  exposure_kwargs = {}\n  if include_exposure_idx:\n    exposure_kwargs['exposure_idx'] = data_fn(1).astype(jnp.int32)\n  if include_exposure_values:\n    
exposure_kwargs['exposure_values'] = data_fn(1)\n  return Rays(\n      origins=data_fn(3),\n      directions=data_fn(3),\n      viewdirs=data_fn(3),\n      radii=data_fn(1),\n      imageplane=data_fn(2),\n      lossmult=data_fn(1),\n      near=data_fn(1),\n      far=data_fn(1),\n      cam_idx=data_fn(1).astype(jnp.int32),\n      **exposure_kwargs)\n\n\n@flax.struct.dataclass\nclass Batch:\n  \"\"\"Data batch for NeRF training or testing.\"\"\"\n  rays: Union[Pixels, Rays]\n  rgb: Optional[_Array] = None\n  disps: Optional[_Array] = None\n  normals: Optional[_Array] = None\n  alphas: Optional[_Array] = None\n\n\nclass DataSplit(enum.Enum):\n  \"\"\"Dataset split.\"\"\"\n  TRAIN = 'train'\n  TEST = 'test'\n\n\nclass BatchingMethod(enum.Enum):\n  \"\"\"Draw rays randomly from a single image or all images, in each batch.\"\"\"\n  ALL_IMAGES = 'all_images'\n  SINGLE_IMAGE = 'single_image'\n\n\ndef open_file(pth, mode='r'):\n  return open(pth, mode=mode)\n\n\ndef file_exists(pth):\n  return os.path.exists(pth)\n\n\ndef listdir(pth):\n  return os.listdir(pth)\n\n\ndef isdir(pth):\n  return os.path.isdir(pth)\n\n\ndef makedirs(pth):\n  if not file_exists(pth):\n    os.makedirs(pth)\n\n\ndef shard(xs):\n  \"\"\"Split data into shards for multiple devices along the first dimension.\"\"\"\n  return jax.tree_util.tree_map(\n      lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]), xs)\n\n\ndef unshard(x, padding=0):\n  \"\"\"Collect the sharded tensor to the shape before sharding.\"\"\"\n  y = x.reshape([x.shape[0] * x.shape[1]] + list(x.shape[2:]))\n  if padding > 0:\n    y = y[:-padding]\n  return y\n\n\ndef load_img(pth: str) -> np.ndarray:\n  \"\"\"Load an image and cast to float32.\"\"\"\n  with open_file(pth, 'rb') as f:\n    image = np.array(Image.open(f), dtype=np.float32)\n  return image\n\n\ndef load_exif(pth: str) -> Dict[str, Any]:\n  \"\"\"Load EXIF data for an image.\"\"\"\n  with open_file(pth, 'rb') as f:\n    image_pil = Image.open(f)\n    
exif_pil = image_pil._getexif()  # pylint: disable=protected-access\n    if exif_pil is not None:\n      exif = {\n          ExifTags.TAGS[k]: v for k, v in exif_pil.items() if k in ExifTags.TAGS\n      }\n    else:\n      exif = {}\n  return exif\n\n\ndef save_img_u8(img, pth):\n  \"\"\"Save an image (probably RGB) in [0, 1] to disk as a uint8 PNG.\"\"\"\n  with open_file(pth, 'wb') as f:\n    Image.fromarray(\n        (np.clip(np.nan_to_num(img), 0., 1.) * 255.).astype(np.uint8)).save(\n            f, 'PNG')\n\n\ndef save_img_f32(depthmap, pth):\n  \"\"\"Save an image (probably a depthmap) to disk as a float32 TIFF.\"\"\"\n  with open_file(pth, 'wb') as f:\n    Image.fromarray(np.nan_to_num(depthmap).astype(np.float32)).save(f, 'TIFF')\n"
  },
  {
    "path": "mip360/internal/vis.py",
    "content": "# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Helper functions for visualizing things.\"\"\"\n\nfrom internal import stepfun\nimport jax.numpy as jnp\nfrom matplotlib import cm\n\n\ndef weighted_percentile(x, w, ps, assume_sorted=False):\n  \"\"\"Compute the weighted percentile(s) of a single vector.\"\"\"\n  x = x.reshape([-1])\n  w = w.reshape([-1])\n  if not assume_sorted:\n    sortidx = jnp.argsort(x)\n    x, w = x[sortidx], w[sortidx]\n  acc_w = jnp.cumsum(w)\n  return jnp.interp(jnp.array(ps) * (acc_w[-1] / 100), acc_w, x)\n\n\ndef sinebow(h):\n  \"\"\"A cyclic and uniform colormap, see http://basecase.org/env/on-rainbows.\"\"\"\n  f = lambda x: jnp.sin(jnp.pi * x)**2\n  return jnp.stack([f(3 / 6 - h), f(5 / 6 - h), f(7 / 6 - h)], -1)\n\n\ndef matte(vis, acc, dark=0.8, light=1.0, width=8):\n  \"\"\"Set non-accumulated pixels to a Photoshop-esque checker pattern.\"\"\"\n  bg_mask = jnp.logical_xor(\n      (jnp.arange(acc.shape[0]) % (2 * width) // width)[:, None],\n      (jnp.arange(acc.shape[1]) % (2 * width) // width)[None, :])\n  bg = jnp.where(bg_mask, light, dark)\n  return vis * acc[:, :, None] + (bg * (1 - acc))[:, :, None]\n\n\ndef visualize_cmap(value,\n                   weight,\n                   colormap,\n                   lo=None,\n                   hi=None,\n                   percentile=99.,\n                   curve_fn=lambda x: x,\n                   modulus=None,\n                
   matte_background=True):\n  \"\"\"Visualize a 1D image and a 1D weighting according to some colormap.\n\n  Args:\n    value: A 1D image.\n    weight: A weight map, in [0, 1].\n    colormap: A colormap function.\n    lo: The lower bound to use when rendering, if None then use a percentile.\n    hi: The upper bound to use when rendering, if None then use a percentile.\n    percentile: What percentile of the value map to crop to when automatically\n      generating `lo` and `hi`. Depends on `weight` as well as `value`.\n    curve_fn: A curve function that gets applied to `value`, `lo`, and `hi`\n      before the rest of visualization. Good choices: x, 1/(x+eps), log(x+eps).\n    modulus: If not None, mod the normalized value by `modulus`. Use (0, 1]. If\n      `modulus` is not None, `lo`, `hi` and `percentile` will have no effect.\n    matte_background: If True, matte the image over a checkerboard.\n\n  Returns:\n    A colormap rendering.\n  \"\"\"\n  # Identify the values that bound the middle of `value` according to `weight`.\n  lo_auto, hi_auto = weighted_percentile(\n      value, weight, [50 - percentile / 2, 50 + percentile / 2])\n\n  # If `lo` or `hi` is None, use the automatically-computed bounds above.\n  # Compare against None explicitly so that a caller-supplied bound of 0 is\n  # respected rather than being treated as missing.\n  eps = jnp.finfo(jnp.float32).eps\n  lo = (lo_auto - eps) if lo is None else lo\n  hi = (hi_auto + eps) if hi is None else hi\n\n  # Curve all values.\n  value, lo, hi = [curve_fn(x) for x in [value, lo, hi]]\n\n  # Wrap the values around if requested.\n  if modulus:\n    value = jnp.mod(value, modulus) / modulus\n  else:\n    # Otherwise, just scale to [0, 1].\n    value = jnp.nan_to_num(\n        jnp.clip((value - jnp.minimum(lo, hi)) / jnp.abs(hi - lo), 0, 1))\n\n  if colormap:\n    colorized = colormap(value)[:, :, :3]\n  else:\n    if len(value.shape) != 3:\n      raise ValueError(f'value must have 3 dims but has {len(value.shape)}')\n    if value.shape[-1] != 3:\n      raise ValueError(\n          f'value must have 3 channels but has {value.shape[-1]}')\n    colorized = 
value\n\n  return matte(colorized, weight) if matte_background else colorized\n\n\ndef visualize_coord_mod(coords, acc):\n  \"\"\"Visualize the coordinate of each point within its \"cell\".\"\"\"\n  return matte(((coords + 1) % 2) / 2, acc)\n\n\ndef visualize_rays(dist,\n                   dist_range,\n                   weights,\n                   rgbs,\n                   accumulate=False,\n                   renormalize=False,\n                   resolution=2048,\n                   bg_color=0.8):\n  \"\"\"Visualize a bundle of rays.\"\"\"\n  dist_vis = jnp.linspace(*dist_range, resolution + 1)\n  vis_rgb, vis_alpha = [], []\n  for ds, ws, rs in zip(dist, weights, rgbs):\n    vis_rs, vis_ws = [], []\n    for d, w, r in zip(ds, ws, rs):\n      if accumulate:\n        # Produce the accumulated color and weight at each point along the ray.\n        w_csum = jnp.cumsum(w, axis=0)\n        rw_csum = jnp.cumsum((r * w[:, None]), axis=0)\n        eps = jnp.finfo(jnp.float32).eps\n        r, w = (rw_csum + eps) / (w_csum[:, None] + 2 * eps), w_csum\n      vis_rs.append(stepfun.resample(dist_vis, d, r.T, use_avg=True).T)\n      vis_ws.append(stepfun.resample(dist_vis, d, w.T, use_avg=True).T)\n    vis_rgb.append(jnp.stack(vis_rs))\n    vis_alpha.append(jnp.stack(vis_ws))\n  vis_rgb = jnp.stack(vis_rgb, axis=1)\n  vis_alpha = jnp.stack(vis_alpha, axis=1)\n\n  if renormalize:\n    # Scale the alphas so that the largest value is 1, for visualization.\n    vis_alpha /= jnp.maximum(jnp.finfo(jnp.float32).eps, jnp.max(vis_alpha))\n\n  if resolution > vis_rgb.shape[0]:\n    rep = resolution // (vis_rgb.shape[0] * vis_rgb.shape[1] + 1)\n    stride = rep * vis_rgb.shape[1]\n\n    vis_rgb = jnp.tile(vis_rgb, (1, 1, rep, 1)).reshape((-1,) + vis_rgb.shape[2:])\n    vis_alpha = jnp.tile(vis_alpha, (1, 1, rep)).reshape((-1,) + vis_alpha.shape[2:])\n\n    # Add a strip of background pixels after each set of levels of rays.\n    vis_rgb = vis_rgb.reshape((-1, stride) + 
vis_rgb.shape[1:])\n    vis_alpha = vis_alpha.reshape((-1, stride) + vis_alpha.shape[1:])\n    vis_rgb = jnp.concatenate([vis_rgb, jnp.zeros_like(vis_rgb[:, :1])],\n                              axis=1).reshape((-1,) + vis_rgb.shape[2:])\n    vis_alpha = jnp.concatenate(\n        [vis_alpha, jnp.zeros_like(vis_alpha[:, :1])],\n        axis=1).reshape((-1,) + vis_alpha.shape[2:])\n\n  # Matte the RGB image over the background.\n  vis = vis_rgb * vis_alpha[..., None] + (bg_color * (1 - vis_alpha))[..., None]\n\n  # Remove the final row of background pixels.\n  vis = vis[:-1]\n  vis_alpha = vis_alpha[:-1]\n  return vis, vis_alpha\n\n\ndef visualize_suite(rendering, rays):\n  \"\"\"A wrapper around other visualizations for easy integration.\"\"\"\n\n  depth_curve_fn = lambda x: -jnp.log(x + jnp.finfo(jnp.float32).eps)\n\n  rgb = rendering['rgb']\n  acc = rendering['acc']\n\n  distance_mean = rendering['distance_mean']\n  distance_median = rendering['distance_median']\n  distance_p5 = rendering['distance_percentile_5']\n  distance_p95 = rendering['distance_percentile_95']\n  acc = jnp.where(jnp.isnan(distance_mean), jnp.zeros_like(acc), acc)\n\n  # The xyz coordinates where rays terminate.\n  coords = rays.origins + rays.directions * distance_mean[:, :, None]\n\n  vis_depth_mean, vis_depth_median = [\n      visualize_cmap(x, acc, cm.get_cmap('turbo'), curve_fn=depth_curve_fn)\n      for x in [distance_mean, distance_median]\n  ]\n\n  # Render three depth percentiles directly to RGB channels, where the spacing\n  # determines the color. 
delta = big change, epsilon = small change.\n  #   Gray: A strong discontinuity, [x-epsilon, x, x+epsilon]\n  #   Purple: A thin but even density, [x-delta, x, x+delta]\n  #   Red: A thin density, then a thick density, [x-delta, x, x+epsilon]\n  #   Blue: A thick density, then a thin density, [x-epsilon, x, x+delta]\n  vis_depth_triplet = visualize_cmap(\n      jnp.stack(\n          [2 * distance_median - distance_p5, distance_median, distance_p95],\n          axis=-1),\n      acc,\n      None,\n      curve_fn=lambda x: jnp.log(x + jnp.finfo(jnp.float32).eps))\n\n  dist = rendering['ray_sdist']\n  dist_range = (0, 1)\n  weights = rendering['ray_weights']\n  rgbs = [jnp.clip(r, 0, 1) for r in rendering['ray_rgbs']]\n\n  vis_ray_colors, _ = visualize_rays(dist, dist_range, weights, rgbs)\n\n  sqrt_weights = [jnp.sqrt(w) for w in weights]\n  sqrt_ray_weights, ray_alpha = visualize_rays(\n      dist,\n      dist_range,\n      [jnp.ones_like(lw) for lw in sqrt_weights],\n      [lw[..., None] for lw in sqrt_weights],\n      bg_color=0,\n  )\n  sqrt_ray_weights = sqrt_ray_weights[..., 0]\n\n  null_color = jnp.array([1., 0., 0.])\n  vis_ray_weights = jnp.where(\n      ray_alpha[:, :, None] == 0,\n      null_color[None, None],\n      visualize_cmap(\n          sqrt_ray_weights,\n          jnp.ones_like(sqrt_ray_weights),\n          cm.get_cmap('gray'),\n          lo=0,\n          hi=1,\n          matte_background=False,\n      ),\n  )\n\n  vis = {\n      'color': rgb,\n      'acc': acc,\n      'color_matte': matte(rgb, acc),\n      'depth_mean': vis_depth_mean,\n      'depth_median': vis_depth_median,\n      'depth_triplet': vis_depth_triplet,\n      'coords_mod': visualize_coord_mod(coords, acc),\n      'ray_colors': vis_ray_colors,\n      'ray_weights': vis_ray_weights,\n  }\n\n  if 'rgb_cc' in rendering:\n    vis['color_corrected'] = rendering['rgb_cc']\n\n  # Render every item named \"normals*\".\n  for key, val in rendering.items():\n    if 
key.startswith('normals'):\n      vis[key] = matte(val / 2. + 0.5, acc)\n\n  if 'roughness' in rendering:\n    vis['roughness'] = matte(jnp.tanh(rendering['roughness']), acc)\n\n  return vis\n"
  },
  {
    "path": "mip360/render.py",
    "content": "# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Render script.\"\"\"\n\nimport concurrent.futures\nimport functools\nimport glob\nimport os\nimport time\n\nfrom absl import app\nfrom flax.training import checkpoints\nimport gin\nfrom internal import configs\nfrom internal import datasets\nfrom internal import models\nfrom internal import train_utils\nfrom internal import utils\nimport jax\nfrom jax import random\nfrom matplotlib import cm\nimport mediapy as media\nimport numpy as np\n\nconfigs.define_common_flags()\njax.config.parse_flags_with_absl()\n\n\ndef create_videos(config, base_dir, out_dir, out_name, num_frames):\n  \"\"\"Creates videos out of the images saved to disk.\"\"\"\n  names = [n for n in config.checkpoint_dir.split('/') if n]\n  # Last two parts of checkpoint path are experiment name and scene name.\n  exp_name, scene_name = names[-2:]\n  video_prefix = f'{scene_name}_{exp_name}_{out_name}'\n\n  zpad = max(3, len(str(num_frames - 1)))\n  idx_to_str = lambda idx: str(idx).zfill(zpad)\n\n  utils.makedirs(base_dir)\n\n  # Load one example frame to get image shape and depth range.\n  depth_file = os.path.join(out_dir, f'distance_mean_{idx_to_str(0)}.tiff')\n  depth_frame = utils.load_img(depth_file)\n  shape = depth_frame.shape\n  p = config.render_dist_percentile\n  distance_limits = np.percentile(depth_frame.flatten(), [p, 100 - p])\n  lo, hi = [config.render_dist_curve_fn(x) for x in 
distance_limits]\n  print(f'Video shape is {shape[:2]}')\n\n  video_kwargs = {\n      'shape': shape[:2],\n      'codec': 'h264',\n      'fps': config.render_video_fps,\n      'crf': config.render_video_crf,\n  }\n\n  for k in ['color', 'normals', 'acc', 'distance_mean', 'distance_median']:\n    video_file = os.path.join(base_dir, f'{video_prefix}_{k}.mp4')\n    input_format = 'gray' if k == 'acc' else 'rgb'\n    file_ext = 'png' if k in ['color', 'normals'] else 'tiff'\n    idx = 0\n    file0 = os.path.join(out_dir, f'{k}_{idx_to_str(0)}.{file_ext}')\n    if not utils.file_exists(file0):\n      print(f'Images missing for tag {k}')\n      continue\n    print(f'Making video {video_file}...')\n    with media.VideoWriter(\n        video_file, **video_kwargs, input_format=input_format) as writer:\n      for idx in range(num_frames):\n        img_file = os.path.join(out_dir, f'{k}_{idx_to_str(idx)}.{file_ext}')\n        if not utils.file_exists(img_file):\n          raise ValueError(f'Image file {img_file} does not exist.')\n        img = utils.load_img(img_file)\n        if k in ['color', 'normals']:\n          img = img / 255.\n        elif k.startswith('distance'):\n          img = config.render_dist_curve_fn(img)\n          img = np.clip((img - np.minimum(lo, hi)) / np.abs(hi - lo), 0, 1)\n          img = cm.get_cmap('turbo')(img)[..., :3]\n\n        frame = (np.clip(np.nan_to_num(img), 0., 1.) 
* 255.).astype(np.uint8)\n        writer.add_image(frame)\n        idx += 1\n\n\ndef main(unused_argv):\n\n  config = configs.load_config(save_config=False)\n\n  dataset = datasets.load_dataset('test', config.data_dir, config)\n\n  key = random.PRNGKey(20200823)\n  _, state, render_eval_pfn, _, _ = train_utils.setup_model(config, key)\n\n  if config.rawnerf_mode:\n    postprocess_fn = dataset.metadata['postprocess_fn']\n  else:\n    postprocess_fn = lambda z: z\n\n  state = checkpoints.restore_checkpoint(config.checkpoint_dir, state)\n  step = int(state.step)\n  print(f'Rendering checkpoint at step {step}.')\n\n  out_name = 'path_renders' if config.render_path else 'test_preds'\n  out_name = f'{out_name}_step_{step}'\n  base_dir = config.render_dir\n  if base_dir is None:\n    base_dir = os.path.join(config.checkpoint_dir, 'render')\n  out_dir = os.path.join(base_dir, out_name)\n  if not utils.isdir(out_dir):\n    utils.makedirs(out_dir)\n\n  path_fn = lambda x: os.path.join(out_dir, x)\n\n  # Ensure sufficient zero-padding of image indices in output filenames.\n  zpad = max(3, len(str(dataset.size - 1)))\n  idx_to_str = lambda idx: str(idx).zfill(zpad)\n\n  if config.render_save_async:\n    async_executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)\n    async_futures = []\n    def save_fn(fn, *args, **kwargs):\n      async_futures.append(async_executor.submit(fn, *args, **kwargs))\n  else:\n    def save_fn(fn, *args, **kwargs):\n      fn(*args, **kwargs)\n\n  for idx in range(dataset.size):\n    if idx % config.render_num_jobs != config.render_job_id:\n      continue\n    # If current image and next image both already exist, skip ahead.\n    idx_str = idx_to_str(idx)\n    curr_file = path_fn(f'color_{idx_str}.png')\n    next_idx_str = idx_to_str(idx + config.render_num_jobs)\n    next_file = path_fn(f'color_{next_idx_str}.png')\n    if utils.file_exists(curr_file) and utils.file_exists(next_file):\n      print(f'Image {idx}/{dataset.size} already 
exists, skipping')\n      continue\n    print(f'Evaluating image {idx+1}/{dataset.size}')\n    eval_start_time = time.time()\n    rays = dataset.generate_ray_batch(idx).rays\n    train_frac = 1.\n    rendering = models.render_image(\n        functools.partial(render_eval_pfn, state.params, train_frac),\n        rays, None, config)\n    print(f'Rendered in {(time.time() - eval_start_time):0.3f}s')\n\n    if jax.host_id() != 0:  # Only record via host 0.\n      continue\n\n    rendering['rgb'] = postprocess_fn(rendering['rgb'])\n\n    save_fn(\n        utils.save_img_u8, rendering['rgb'], path_fn(f'color_{idx_str}.png'))\n    if 'normals' in rendering:\n      save_fn(\n          utils.save_img_u8, rendering['normals'] / 2. + 0.5,\n          path_fn(f'normals_{idx_str}.png'))\n    save_fn(\n        utils.save_img_f32, rendering['distance_mean'],\n        path_fn(f'distance_mean_{idx_str}.tiff'))\n    save_fn(\n        utils.save_img_f32, rendering['distance_median'],\n        path_fn(f'distance_median_{idx_str}.tiff'))\n    save_fn(\n        utils.save_img_f32, rendering['acc'], path_fn(f'acc_{idx_str}.tiff'))\n\n  if config.render_save_async:\n    # Wait until all worker threads finish.\n    async_executor.shutdown(wait=True)\n\n    # This will ensure that exceptions in child threads are raised to the\n    # main thread.\n    for future in async_futures:\n      future.result()\n\n  time.sleep(1)\n  num_files = len(glob.glob(path_fn('acc_*.tiff')))\n  time.sleep(10)\n  if jax.host_id() == 0 and num_files == dataset.size:\n    print(f'All files found, creating videos (job {config.render_job_id}).')\n    create_videos(config, base_dir, out_dir, out_name, dataset.size)\n\n  # A hack that forces Jax to keep all TPUs alive until every TPU is finished.\n  x = jax.numpy.ones([jax.local_device_count()])\n  x = jax.device_get(jax.pmap(lambda x: jax.lax.psum(x, 'i'), 'i')(x))\n  print(x)\n\n\nif __name__ == '__main__':\n  with gin.config_scope('eval'):  # Use the same scope as 
eval.py\n    app.run(main)\n"
  },
  {
    "path": "mip360/requirements.txt",
    "content": "numpy\njax[cuda11_cudnn82]==0.3.24 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\nflax==0.6.2\nopencv-python\nPillow\ntensorboard\ntensorflow\ngin-config\ndm_pix\nrawpy\nmediapy\neinops\ntorch==1.10.1\ntorchvision==0.11.2\ntorchmetrics==0.9.3\ntqdm\nlpips==0.1.4"
  },
  {
    "path": "mip360/scripts/eval_blender.sh",
    "content": "#!/bin/bash\n# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nexport CUDA_VISIBLE_DEVICES=0\n\nSCENE=ficus\nEXPERIMENT=blender\nDATA_DIR=/usr/local/google/home/barron/tmp/nerf_data/dors_nerf_synthetic\nCHECKPOINT_DIR=/usr/local/google/home/barron/tmp/nerf_results/\"$EXPERIMENT\"/\"$SCENE\"\n\npython -m eval \\\n  --gin_configs=configs/blender_256.gin \\\n  --gin_bindings=\"Config.data_dir = '${DATA_DIR}/${SCENE}'\" \\\n  --gin_bindings=\"Config.checkpoint_dir = '${CHECKPOINT_DIR}'\" \\\n  --logtostderr\n"
  },
  {
    "path": "mip360/scripts/eval_llff.sh",
    "content": "#!/bin/bash\n# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nexport CUDA_VISIBLE_DEVICES=0\n\nSCENE=flower\nEXPERIMENT=llff\nDATA_DIR=/usr/local/google/home/barron/tmp/nerf_data/nerf_llff_data\nCHECKPOINT_DIR=/usr/local/google/home/barron/tmp/nerf_results/\"$EXPERIMENT\"/\"$SCENE\"\n\npython -m eval \\\n  --gin_configs=configs/llff_256.gin \\\n  --gin_bindings=\"Config.data_dir = '${DATA_DIR}/${SCENE}'\" \\\n  --gin_bindings=\"Config.checkpoint_dir = '${CHECKPOINT_DIR}'\" \\\n  --logtostderr\n"
  },
  {
    "path": "mip360/scripts/eval_raw.sh",
    "content": "#!/bin/bash\n# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nexport CUDA_VISIBLE_DEVICES=0\n\nSCENE=nightpiano\nEXPERIMENT=raw\nDATA_DIR=/usr/local/google/home/barron/tmp/nerf_data/rawnerf/scenes\nCHECKPOINT_DIR=/usr/local/google/home/barron/tmp/nerf_results/\"$EXPERIMENT\"/\"$SCENE\"\n\npython -m eval \\\n  --gin_configs=configs/llff_raw.gin \\\n  --gin_bindings=\"Config.data_dir = '${DATA_DIR}/${SCENE}'\" \\\n  --gin_bindings=\"Config.checkpoint_dir = '${CHECKPOINT_DIR}'\" \\\n  --logtostderr\n"
  },
  {
    "path": "mip360/scripts/eval_shinyblender.sh",
    "content": "#!/bin/bash\n# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nexport CUDA_VISIBLE_DEVICES=0\n\nSCENE=toaster\nEXPERIMENT=shinyblender\nDATA_DIR=/usr/local/google/home/barron/tmp/nerf_data/dors_nerf_synthetic\nCHECKPOINT_DIR=/usr/local/google/home/barron/tmp/nerf_results/\"$EXPERIMENT\"/\"$SCENE\"\n\npython -m eval \\\n  --gin_configs=configs/blender_refnerf.gin \\\n  --gin_bindings=\"Config.data_dir = '${DATA_DIR}/${SCENE}'\" \\\n  --gin_bindings=\"Config.checkpoint_dir = '${CHECKPOINT_DIR}'\" \\\n  --logtostderr\n"
  },
  {
    "path": "mip360/scripts/generate_tables.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {\n    \"id\": \"wbSGA8PNKvIy\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"import glob\\n\",\n    \"import os\\n\",\n    \"import numpy as np\\n\",\n    \"import matplotlib.pyplot as plt\\n\",\n    \"from collections import defaultdict\\n\",\n    \"from matplotlib import rc\\n\",\n    \"plt.rc('font', family='serif')\\n\",\n    \"import tensorflow as tf\\n\",\n    \"import time\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"RpuHMCndxqs4\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"def scrape_folder(folder, num_iters, metric_names = ['psnr', 'ssim', 'lpips']):\\n\",\n    \"  stats = {}\\n\",\n    \"  for i_metric, metric_name in enumerate(metric_names):\\n\",\n    \"    filename = os.path.join(folder, 'test_preds', f'metric_{metric_name}_{num_iters}.txt')\\n\",\n    \"    with open(filename, mode='r') as f:\\n\",\n    \"      v = np.array([float(s) for s in f.readline().split(' ')])\\n\",\n    \"    stats[metric_name] = np.mean(v)\\n\",\n    \"\\n\",\n    \"  tic = time.time()\\n\",\n    \"  grab_tags = ['train_steps_per_sec', 'num_params', 'train_num_params']\\n\",\n    \"  grabbed_tags = {k:[] for k in grab_tags}\\n\",\n    \"  for pattern in ['events*']:#, 'eval/events*']:\\n\",\n    \"    for event_file in glob.glob(os.path.join(folder, pattern)):\\n\",\n    \"      for event in tf.compat.v1.train.summary_iterator(event_file):\\n\",\n    \"        value = event.summary.value\\n\",\n    \"        if len(value) > 0:\\n\",\n    \"          tag = event.summary.value[0].tag\\n\",\n    \"          if tag in grab_tags:\\n\",\n    \"            grabbed_tags[tag].append(np.array(tf.make_ndarray(event.summary.value[0].tensor)).item())\\n\",\n    \"\\n\",\n    \"  if grabbed_tags['train_steps_per_sec']:\\n\",\n    \"    steps_per_sec = 
np.percentile(np.array(grabbed_tags['train_steps_per_sec']), 95)\\n\",\n    \"    stats['num_hours'] = (num_iters / steps_per_sec) / (60**2)\\n\",\n    \"  else:\\n\",\n    \"    stats['num_hours'] = np.nan\\n\",\n    \"\\n\",\n    \"  if grabbed_tags['num_params']:\\n\",\n    \"    stats['mega_params'] = int(np.max(grabbed_tags['num_params'])) / (1000000) # in millions\\n\",\n    \"  elif grabbed_tags['train_num_params']:\\n\",\n    \"    stats['mega_params'] = int(np.max(grabbed_tags['train_num_params'])) / (1000000) # in millions\\n\",\n    \"  else:\\n\",\n    \"    stats['mega_params'] = np.nan\\n\",\n    \"\\n\",\n    \"  return stats\\n\",\n    \"\\n\",\n    \"def render_table(names, data, precisions, rank_order, suffixes=None, hlines = []):\\n\",\n    \"  def rankify(x, order):\\n\",\n    \"    assert len(x.shape) == 1\\n\",\n    \"    if order == 0:\\n\",\n    \"      return np.full_like(x, 1e5, dtype=np.int32)\\n\",\n    \"    u = np.sort(np.unique(x))\\n\",\n    \"    if order == 1:\\n\",\n    \"      u = u[::-1]\\n\",\n    \"    r = np.zeros_like(x, dtype=np.int32)\\n\",\n    \"    for ui, uu in enumerate(u):\\n\",\n    \"      mask = x == uu\\n\",\n    \"      r[mask] = ui\\n\",\n    \"    return np.int32(r)\\n\",\n    \"\\n\",\n    \"  tags = ['   \\\\cellcolor{red}',\\n\",\n    \"          '\\\\cellcolor{orange}',\\n\",\n    \"          '\\\\cellcolor{yellow}',\\n\",\n    \"          '                  ']\\n\",\n    \"\\n\",\n    \"  max_len = max([len(v) for v in list(names)])\\n\",\n    \"  names_padded = [v + ' '*(max_len-len(v)) for v in names]\\n\",\n    \"\\n\",\n    \"  data_quant = np.round((data * 10.**(np.array(precisions)[None, :]))) / 10.**(np.array(precisions)[None, :])\\n\",\n    \"  if suffixes is None:\\n\",\n    \"    suffixes = [''] * len(precisions)\\n\",\n    \"\\n\",\n    \"  tagranks = []\\n\",\n    \"  for d in range(data_quant.shape[1]):\\n\",\n    \"    tagranks.append(np.clip(rankify(data_quant[:,d], rank_order[d]), 0, 
len(tags)-1))\\n\",\n    \"  tagranks = np.stack(tagranks, -1)\\n\",\n    \"\\n\",\n    \"  for i_row in range(len(names)):\\n\",\n    \"    line = ''\\n\",\n    \"    if i_row in hlines:\\n\",\n    \"      line += '\\\\\\\\hline\\\\n'\\n\",\n    \"    line += names_padded[i_row]\\n\",\n    \"    for d in range(data_quant.shape[1]):\\n\",\n    \"      line += ' & '\\n\",\n    \"      if rank_order[d] != 0 and not np.isnan(data[i_row,d]):\\n\",\n    \"        line += tags[tagranks[i_row, d]]\\n\",\n    \"      if np.isnan(data[i_row,d]):\\n\",\n    \"        line += ' - '\\n\",\n    \"      else:\\n\",\n    \"        assert precisions[d] >= 0\\n\",\n    \"        line += ('{:' + f'0.{precisions[d]}f' + '}').format(data_quant[i_row,d]) + suffixes[d]\\n\",\n    \"    line += ' \\\\\\\\\\\\\\\\'\\n\",\n    \"    print(line)\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"b2wPHPku5jO8\"\n   },\n   \"source\": [\n    \"Mip-NeRF 360's results\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"executionInfo\": {\n     \"elapsed\": 27352,\n     \"status\": \"ok\",\n     \"timestamp\": 1656870904226,\n     \"user\": {\n      \"displayName\": \"\",\n      \"userId\": \"\"\n     },\n     \"user_tz\": -60\n    },\n    \"id\": \"gKnbxt2AKz8w\",\n    \"outputId\": \"c8818082-cf6d-4212-c503-558e1dec49e7\"\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"mipnerf360_450936927/360 bicycle {'psnr': 24.39237983703613, 'ssim': 0.6856040573120117, 'lpips': 0.3024290585517883, 'num_hours': 6.2325243293076875, 'mega_params': 9.007493}\\n\",\n      \"mipnerf360_450936927/360 flowerbed {'psnr': 21.722528544339266, 'ssim': 0.5822233611887152, 'lpips': 0.3457577336918224, 'num_hours': 7.029125016907558, 'mega_params': 9.007493}\\n\",\n      \"mipnerf360_450936927/360 gardenvase {'psnr': 26.994542519251507, 'ssim': 
0.8132077877720197, 'lpips': 0.16961431813736758, 'num_hours': 6.373895524145641, 'mega_params': 9.007493}\\n\",\n      \"mipnerf360_450936927/360 stump {'psnr': 26.433189034461975, 'ssim': 0.7453256547451019, 'lpips': 0.2610513372346759, 'num_hours': 6.372994107616741, 'mega_params': 9.007493}\\n\",\n      \"mipnerf360_450936927/360 treehill {'psnr': 22.80838351779514, 'ssim': 0.6286024749279022, 'lpips': 0.3419443451695972, 'num_hours': 6.373792756813197, 'mega_params': 9.007493}\\n\",\n      \"mipnerf360_450936927/360 fulllivingroom {'psnr': 31.74024293361566, 'ssim': 0.9143212055548643, 'lpips': 0.210464691504454, 'num_hours': 6.229779091791955, 'mega_params': 9.007493}\\n\",\n      \"mipnerf360_450936927/360 kitchencounter {'psnr': 29.52976067860921, 'ssim': 0.8942874411741892, 'lpips': 0.20467615375916162, 'num_hours': 6.376739565840228, 'mega_params': 9.007493}\\n\",\n      \"mipnerf360_450936927/360 kitchenlego {'psnr': 32.162456675938195, 'ssim': 0.9199009571756636, 'lpips': 0.12717486768960953, 'num_hours': 6.381604252968429, 'mega_params': 9.007493}\\n\",\n      \"mipnerf360_450936927/360 officebonsai {'psnr': 33.498427004427526, 'ssim': 0.9409479949925397, 'lpips': 0.1763693257360845, 'num_hours': 6.371603406297137, 'mega_params': 9.007493}\\n\",\n      \"mipnerf360_450936927/360 {'psnr': 27.697990082830515, 'ssim': 0.7916023260936674, 'lpips': 0.23772020349717346, 'num_hours': 6.415784227965397, 'mega_params': 9.007493}\\n\",\n      \"mipnerf360_450936927/360_glo4 bicycle {'psnr': 24.053745727539063, 'ssim': 0.6909286546707153, 'lpips': 0.2945095646381378, 'num_hours': 6.3832255160561, 'mega_params': 9.012005}\\n\",\n      \"mipnerf360_450936927/360_glo4 flowerbed {'psnr': 21.374122966419566, 'ssim': 0.5803043625571511, 'lpips': 0.3430247970602729, 'num_hours': 6.38164317846884, 'mega_params': 9.012005}\\n\",\n      \"mipnerf360_450936927/360_glo4 gardenvase {'psnr': 25.532458384831745, 'ssim': 0.8063657606641451, 'lpips': 0.17009872818986574, 
'num_hours': 6.386194917543813, 'mega_params': 9.012005}\\n\",\n      \"mipnerf360_450936927/360_glo4 stump {'psnr': 26.132395386695862, 'ssim': 0.7474702708423138, 'lpips': 0.25648255459964275, 'num_hours': 6.386754034296546, 'mega_params': 9.012005}\\n\",\n      \"mipnerf360_450936927/360_glo4 treehill {'psnr': 22.360591994391548, 'ssim': 0.6240272356404198, 'lpips': 0.33734683361318374, 'num_hours': 6.382674989177764, 'mega_params': 9.012005}\\n\",\n      \"mipnerf360_450936927/360_glo4 fulllivingroom {'psnr': 29.69650620680589, 'ssim': 0.912017149802966, 'lpips': 0.20836940980874574, 'num_hours': 6.387421971615129, 'mega_params': 9.012005}\\n\",\n      \"mipnerf360_450936927/360_glo4 kitchencounter {'psnr': 28.329934310913085, 'ssim': 0.8888227045536041, 'lpips': 0.20833696722984313, 'num_hours': 6.386267560219862, 'mega_params': 9.012005}\\n\",\n      \"mipnerf360_450936927/360_glo4 kitchenlego {'psnr': 30.207322692871095, 'ssim': 0.9143631560461861, 'lpips': 0.13056355544498988, 'num_hours': 6.387698179101134, 'mega_params': 9.012005}\\n\",\n      \"mipnerf360_450936927/360_glo4 officebonsai {'psnr': 30.50528634561075, 'ssim': 0.9331820059467006, 'lpips': 0.18204631998732285, 'num_hours': 6.254863324211862, 'mega_params': 9.012005}\\n\",\n      \"mipnerf360_450936927/360_glo4 {'psnr': 26.465818224008732, 'ssim': 0.7886090334138002, 'lpips': 0.2367531922857783, 'num_hours': 6.37074929674345, 'mega_params': 9.012005}\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"root_folder = '/cns/lu-d/home/buff/nerf/'\\n\",\n    \"scene_names = ['bicycle', 'flowerbed', 'gardenvase', 'stump', 'treehill', 'fulllivingroom', 'kitchencounter', 'kitchenlego', 'officebonsai']\\n\",\n    \"\\n\",\n    \"models_meta = {} # folder : latex_name\\n\",\n    \"models_meta[None] = 'mip-NeRF 360 (from the paper)'\\n\",\n    \"models_meta['mipnerf360_450936927/360'] = 'mip-NeRF 360'\\n\",\n    \"models_meta['mipnerf360_450936927/360_glo4'] = 'mip-NeRF 360 w/GLO'\\n\",\n    \"\\n\",\n    
\"\\n\",\n    \"NUM_ITERS = 250000\\n\",\n    \"\\n\",\n    \"all_stats = []\\n\",\n    \"avg_stats = []\\n\",\n    \"for model_path in models_meta.keys():\\n\",\n    \"\\n\",\n    \"  if model_path is None:\\n\",\n    \"    # Inject the numbers from the paper.\\n\",\n    \"    psnrs = [24.37, 21.73, 26.98, 26.40, 22.87, 31.63, 29.55, 32.23, 33.46]\\n\",\n    \"    ssims = [0.685, 0.583, 0.813, 0.744, 0.632, 0.913, 0.894, 0.920, 0.941]\\n\",\n    \"    lpips = [0.301, 0.344, 0.170, 0.261, 0.339, 0.211, 0.204, 0.127, 0.176]\\n\",\n    \"    train_times = [np.nan]*len(psnrs)\\n\",\n    \"    model_sizes = [np.nan]*len(psnrs)\\n\",\n    \"    scene_stats = []\\n\",\n    \"    for p, s, l, tt, ms in zip(psnrs, ssims, lpips, train_times, model_sizes):\\n\",\n    \"      scene_stats.append({'psnr': p, 'ssim': s, 'lpips': l, 'num_hours': tt, 'mega_params': ms})\\n\",\n    \"    avg_stats.append({k: type(scene_stats[0][k])(np.mean([s[k] for s in scene_stats])) for k in scene_stats[0].keys()})\\n\",\n    \"    all_stats.append(scene_stats)\\n\",\n    \"    continue;\\n\",\n    \"\\n\",\n    \"  scene_stats = []\\n\",\n    \"  for scene_name in scene_names:\\n\",\n    \"    folder = os.path.join(root_folder, model_path, scene_name)\\n\",\n    \"    stats = scrape_folder(folder, NUM_ITERS)\\n\",\n    \"    print(model_path, scene_name, stats)\\n\",\n    \"    scene_stats.append(stats)\\n\",\n    \"  avg_stats.append({k: type(scene_stats[0][k])(np.mean([s[k] for s in scene_stats])) for k in scene_stats[0].keys()})\\n\",\n    \"  all_stats.append(scene_stats)\\n\",\n    \"  print(model_path, avg_stats[-1])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"executionInfo\": {\n     \"elapsed\": 68,\n     \"status\": \"ok\",\n     \"timestamp\": 1656870904417,\n     \"user\": {\n      \"displayName\": \"\",\n      \"userId\": \"\"\n     },\n     \"user_tz\": -60\n    },\n    \"id\": \"ISRu0l-f0kTU\",\n    \"outputId\": 
\"0674ba64-e15e-450c-d8d1-77f6f740ca48\"\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"mip-NeRF 360 (from the paper) & \\\\cellcolor{orange}27.69 &    \\\\cellcolor{red}0.792 &    \\\\cellcolor{red}0.237 &  -  &  -  \\\\\\\\\\n\",\n      \"mip-NeRF 360                  &    \\\\cellcolor{red}27.70 &    \\\\cellcolor{red}0.792 & \\\\cellcolor{orange}0.238 & 6.42 & 9.0M \\\\\\\\\\n\",\n      \"mip-NeRF 360 w/GLO            & \\\\cellcolor{yellow}26.47 & \\\\cellcolor{orange}0.789 &    \\\\cellcolor{red}0.237 & 6.37 & 9.0M \\\\\\\\\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"names = list(models_meta.values())\\n\",\n    \"data = np.stack([list(s.values()) for s in avg_stats])\\n\",\n    \"precisions = [2, 3, 3, 2, 1]\\n\",\n    \"rank_order = [1, 1, -1, 0, 0]  # +1 = higher is better, -1 = lower is better, 0 = do not color code\\n\",\n    \"suffixes = ['', '', '', '', 'M']\\n\",\n    \"render_table(names, data, precisions, rank_order, suffixes=suffixes)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"executionInfo\": {\n     \"elapsed\": 61,\n     \"status\": \"ok\",\n     \"timestamp\": 1656870904594,\n     \"user\": {\n      \"displayName\": \"\",\n      \"userId\": \"\"\n     },\n     \"user_tz\": -60\n    },\n    \"id\": \"6FoznCOYTQUf\",\n    \"outputId\": \"8154ca71-8949-4de9-c567-5ef73e98a812\"\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"psnr\\n\",\n      \" & \\\\textit{bicycle} & \\\\textit{flowers} & \\\\textit{garden} & \\\\textit{stump} & \\\\textit{treehill} & \\\\textit{room} & \\\\textit{counter} & \\\\textit{kitchen} & \\\\textit{bonsai} \\\\\\\\\\\\hline\\n\",\n      \"\\\\hline\\n\",\n      \"mip-NeRF 360 (from the paper) & \\\\cellcolor{orange}24.37 &    \\\\cellcolor{red}21.73 & \\\\cellcolor{orange}26.98 & 
\\\\cellcolor{orange}26.40 &    \\\\cellcolor{red}22.87 & \\\\cellcolor{orange}31.63 &    \\\\cellcolor{red}29.55 &    \\\\cellcolor{red}32.23 & \\\\cellcolor{orange}33.46 \\\\\\\\\\n\",\n      \"mip-NeRF 360                  &    \\\\cellcolor{red}24.39 & \\\\cellcolor{orange}21.72 &    \\\\cellcolor{red}26.99 &    \\\\cellcolor{red}26.43 & \\\\cellcolor{orange}22.81 &    \\\\cellcolor{red}31.74 & \\\\cellcolor{orange}29.53 & \\\\cellcolor{orange}32.16 &    \\\\cellcolor{red}33.50 \\\\\\\\\\n\",\n      \"mip-NeRF 360 w/GLO            & \\\\cellcolor{yellow}24.05 & \\\\cellcolor{yellow}21.37 & \\\\cellcolor{yellow}25.53 & \\\\cellcolor{yellow}26.13 & \\\\cellcolor{yellow}22.36 & \\\\cellcolor{yellow}29.70 & \\\\cellcolor{yellow}28.33 & \\\\cellcolor{yellow}30.21 & \\\\cellcolor{yellow}30.51 \\\\\\\\\\n\",\n      \"\\n\",\n      \"ssim\\n\",\n      \" & \\\\textit{bicycle} & \\\\textit{flowers} & \\\\textit{garden} & \\\\textit{stump} & \\\\textit{treehill} & \\\\textit{room} & \\\\textit{counter} & \\\\textit{kitchen} & \\\\textit{bonsai} \\\\\\\\\\\\hline\\n\",\n      \"\\\\hline\\n\",\n      \"mip-NeRF 360 (from the paper) & \\\\cellcolor{yellow}0.685 &    \\\\cellcolor{red}0.583 &    \\\\cellcolor{red}0.813 & \\\\cellcolor{yellow}0.744 &    \\\\cellcolor{red}0.632 & \\\\cellcolor{orange}0.913 &    \\\\cellcolor{red}0.894 &    \\\\cellcolor{red}0.920 &    \\\\cellcolor{red}0.941 \\\\\\\\\\n\",\n      \"mip-NeRF 360                  & \\\\cellcolor{orange}0.686 & \\\\cellcolor{orange}0.582 &    \\\\cellcolor{red}0.813 & \\\\cellcolor{orange}0.745 & \\\\cellcolor{orange}0.629 &    \\\\cellcolor{red}0.914 &    \\\\cellcolor{red}0.894 &    \\\\cellcolor{red}0.920 &    \\\\cellcolor{red}0.941 \\\\\\\\\\n\",\n      \"mip-NeRF 360 w/GLO            &    \\\\cellcolor{red}0.691 & \\\\cellcolor{yellow}0.580 & \\\\cellcolor{orange}0.806 &    \\\\cellcolor{red}0.747 & \\\\cellcolor{yellow}0.624 & \\\\cellcolor{yellow}0.912 & \\\\cellcolor{orange}0.889 & 
\\\\cellcolor{orange}0.914 & \\\\cellcolor{orange}0.933 \\\\\\\\\\n\",\n      \"\\n\",\n      \"lpips\\n\",\n      \" & \\\\textit{bicycle} & \\\\textit{flowers} & \\\\textit{garden} & \\\\textit{stump} & \\\\textit{treehill} & \\\\textit{room} & \\\\textit{counter} & \\\\textit{kitchen} & \\\\textit{bonsai} \\\\\\\\\\\\hline\\n\",\n      \"\\\\hline\\n\",\n      \"mip-NeRF 360 (from the paper) & \\\\cellcolor{orange}0.301 & \\\\cellcolor{orange}0.344 &    \\\\cellcolor{red}0.170 & \\\\cellcolor{orange}0.261 & \\\\cellcolor{orange}0.339 & \\\\cellcolor{yellow}0.211 &    \\\\cellcolor{red}0.204 &    \\\\cellcolor{red}0.127 &    \\\\cellcolor{red}0.176 \\\\\\\\\\n\",\n      \"mip-NeRF 360                  & \\\\cellcolor{yellow}0.302 & \\\\cellcolor{yellow}0.346 &    \\\\cellcolor{red}0.170 & \\\\cellcolor{orange}0.261 & \\\\cellcolor{yellow}0.342 & \\\\cellcolor{orange}0.210 & \\\\cellcolor{orange}0.205 &    \\\\cellcolor{red}0.127 &    \\\\cellcolor{red}0.176 \\\\\\\\\\n\",\n      \"mip-NeRF 360 w/GLO            &    \\\\cellcolor{red}0.295 &    \\\\cellcolor{red}0.343 &    \\\\cellcolor{red}0.170 &    \\\\cellcolor{red}0.256 &    \\\\cellcolor{red}0.337 &    \\\\cellcolor{red}0.208 & \\\\cellcolor{yellow}0.208 & \\\\cellcolor{orange}0.131 & \\\\cellcolor{orange}0.182 \\\\\\\\\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"names = list(models_meta.values())\\n\",\n    \"\\n\",\n    \"precisions = [2, 3, 3]\\n\",\n    \"rank_orders = [1, 1, -1]\\n\",\n    \"\\n\",\n    \"name_map = {s: s for s in scene_names}\\n\",\n    \"name_map['gardenvase'] = 'garden'\\n\",\n    \"name_map['flowerbed'] = 'flowers'\\n\",\n    \"name_map['fulllivingroom'] = 'room'\\n\",\n    \"name_map['kitchencounter'] = 'counter'\\n\",\n    \"name_map['kitchenlego'] = 'kitchen'\\n\",\n    \"name_map['officebonsai'] = 'bonsai'\\n\",\n    \"pretty_scene_names = [name_map[s] for s in scene_names]\\n\",\n    \"\\n\",\n    \"for i_metric, metric in enumerate(['psnr', 'ssim', 
'lpips']):\\n\",\n    \"  print(metric)\\n\",\n    \"  precision = precisions[i_metric]\\n\",\n    \"  rank_order = rank_orders[i_metric]\\n\",\n    \"\\n\",\n    \"  print(' & ' + ' & '.join(['\\\\\\\\textit{' + s + '}' for s in pretty_scene_names]) + ' \\\\\\\\\\\\\\\\\\\\\\\\hline')\\n\",\n    \"  data = np.array([np.array([s[metric] for s in scene_stats]) for scene_stats in all_stats])\\n\",\n    \"  render_table(names, data, [precision] * len(scene_names), [rank_order] * len(scene_names), hlines = [len(names)-3])\\n\",\n    \"  print()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"executionInfo\": {\n     \"elapsed\": 13540,\n     \"status\": \"ok\",\n     \"timestamp\": 1656870918242,\n     \"user\": {\n      \"displayName\": \"\",\n      \"userId\": \"\"\n     },\n     \"user_tz\": -60\n    },\n    \"id\": \"wpFu_zPZV8rM\",\n    \"outputId\": \"bc6e8748-0017-4452-8262-956f4f31c1d5\"\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"mipnerf360_452961962/blender_256 chair {'psnr': 35.14409606933594, 'ssim': 0.9807422143220902, 'lpips': 0.019885611250065268, 'num_hours': 3.0120813620093627, 'mega_params': 0.835205}\\n\",\n      \"mipnerf360_452961962/blender_256 drums {'psnr': 25.672427253723143, 'ssim': 0.9345874139666557, 'lpips': 0.06211544882506132, 'num_hours': 2.454567881500612, 'mega_params': 0.835205}\\n\",\n      \"mipnerf360_452961962/blender_256 ficus {'psnr': 32.68634537696838, 'ssim': 0.9779657471179962, 'lpips': 0.022341654361225663, 'num_hours': 2.4591483474261193, 'mega_params': 0.835205}\\n\",\n      \"mipnerf360_452961962/blender_256 hotdog {'psnr': 37.58191083908081, 'ssim': 0.9816449856758118, 'lpips': 0.024997519506141545, 'num_hours': 2.4588705980222576, 'mega_params': 0.835205}\\n\",\n      \"mipnerf360_452961962/blender_256 lego {'psnr': 36.30012715339661, 'ssim': 0.9804214736819268, 'lpips': 
0.01723310980014503, 'num_hours': 3.00943955424293, 'mega_params': 0.835205}\\n\",\n      \"mipnerf360_452961962/blender_256 materials {'psnr': 30.475751152038573, 'ssim': 0.9558810117840767, 'lpips': 0.04359573864378035, 'num_hours': 3.0132687091659123, 'mega_params': 0.835205}\\n\",\n      \"mipnerf360_452961962/blender_256 mic {'psnr': 36.44918758392334, 'ssim': 0.9910019421577454, 'lpips': 0.008599376350175589, 'num_hours': 3.0127024649051144, 'mega_params': 0.835205}\\n\",\n      \"mipnerf360_452961962/blender_256 ship {'psnr': 30.35776879310608, 'ssim': 0.8874139302968979, 'lpips': 0.12848058927804232, 'num_hours': 3.0069640380480425, 'mega_params': 0.835205}\\n\",\n      \"mipnerf360_452961962/blender_256 {'psnr': 33.08345177769661, 'ssim': 0.9612073398754, 'lpips': 0.04090613100182963, 'num_hours': 2.803380369415044, 'mega_params': 0.835205}\\n\",\n      \"mipnerf360_452961962/blender_512 chair {'psnr': 35.857433795928955, 'ssim': 0.98388851583004, 'lpips': 0.01675196835771203, 'num_hours': 3.7615084467568956, 'mega_params': 3.065733}\\n\",\n      \"mipnerf360_452961962/blender_512 drums {'psnr': 25.507867193222047, 'ssim': 0.9318369966745377, 'lpips': 0.06635790368542076, 'num_hours': 3.7576099485237027, 'mega_params': 3.065733}\\n\",\n      \"mipnerf360_452961962/blender_512 ficus {'psnr': 33.011657524108884, 'ssim': 0.9789321529865265, 'lpips': 0.021975036268122493, 'num_hours': 3.760496131362062, 'mega_params': 3.065733}\\n\",\n      \"mipnerf360_452961962/blender_512 hotdog {'psnr': 38.04464369773865, 'ssim': 0.9834308186173439, 'lpips': 0.021500862692482768, 'num_hours': 3.7610453555148697, 'mega_params': 3.065733}\\n\",\n      \"mipnerf360_452961962/blender_512 lego {'psnr': 36.325964279174805, 'ssim': 0.9808579578995704, 'lpips': 0.01648236089386046, 'num_hours': 4.3412879804404625, 'mega_params': 3.065733}\\n\",\n      \"mipnerf360_452961962/blender_512 materials {'psnr': 30.130296745300292, 'ssim': 0.9524855437874794, 'lpips': 0.04783135158009827, 
'num_hours': 4.343275342164319, 'mega_params': 3.065733}\\n\",\n      \"mipnerf360_452961962/blender_512 mic {'psnr': 36.847799701690676, 'ssim': 0.9914190286397934, 'lpips': 0.00919440199737437, 'num_hours': 4.337300623881409, 'mega_params': 3.065733}\\n\",\n      \"mipnerf360_452961962/blender_512 ship {'psnr': 31.14282497406006, 'ssim': 0.8940365880727768, 'lpips': 0.11575569273903966, 'num_hours': 3.759209558641999, 'mega_params': 3.065733}\\n\",\n      \"mipnerf360_452961962/blender_512 {'psnr': 33.35856098890305, 'ssim': 0.9621109503135086, 'lpips': 0.03948119727676385, 'num_hours': 3.977716673410715, 'mega_params': 3.065733}\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"root_folder = '/cns/lu-d/home/buff/nerf/'\\n\",\n    \"scene_names = ['chair', 'drums', 'ficus', 'hotdog', 'lego', 'materials', 'mic', 'ship']\\n\",\n    \"\\n\",\n    \"models_meta = {} # folder : latex_name\\n\",\n    \"\\n\",\n    \"models_meta['mipnerf360_452961962/blender_256'] = 'mip-NeRF 360 (256 hidden)'\\n\",\n    \"models_meta['mipnerf360_452961962/blender_512'] = 'mip-NeRF 360 (512 hidden)'\\n\",\n    \"\\n\",\n    \"NUM_ITERS = 250000\\n\",\n    \"\\n\",\n    \"all_stats = []\\n\",\n    \"avg_stats = []\\n\",\n    \"for model_path in models_meta.keys():\\n\",\n    \"  scene_stats = []\\n\",\n    \"  for scene_name in scene_names:\\n\",\n    \"    folder = os.path.join(root_folder, model_path, scene_name)\\n\",\n    \"    stats = scrape_folder(folder, NUM_ITERS)\\n\",\n    \"    print(model_path, scene_name, stats)\\n\",\n    \"    scene_stats.append(stats)\\n\",\n    \"  avg_stats.append({k: type(scene_stats[0][k])(np.mean([s[k] for s in scene_stats])) for k in scene_stats[0].keys()})\\n\",\n    \"  all_stats.append(scene_stats)\\n\",\n    \"  print(model_path, avg_stats[-1])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"executionInfo\": {\n     \"elapsed\": 70,\n     \"status\": \"ok\",\n     \"timestamp\": 
1656870918436,\n     \"user\": {\n      \"displayName\": \"\",\n      \"userId\": \"\"\n     },\n     \"user_tz\": -60\n    },\n    \"id\": \"knNz0h93yzXP\",\n    \"outputId\": \"10364ad3-dcba-4549-a1d5-5b7380eef1d0\"\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"mip-NeRF 360 (256 hidden) & \\\\cellcolor{orange}33.08 & \\\\cellcolor{orange}0.961 & \\\\cellcolor{orange}0.041 & 2.80 & 0.8M \\\\\\\\\\n\",\n      \"mip-NeRF 360 (512 hidden) &    \\\\cellcolor{red}33.36 &    \\\\cellcolor{red}0.962 &    \\\\cellcolor{red}0.039 & 3.98 & 3.1M \\\\\\\\\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"names = list(models_meta.values())\\n\",\n    \"data = np.stack([list(s.values()) for s in avg_stats])\\n\",\n    \"precisions = [2, 3, 3, 2, 1]\\n\",\n    \"rank_order = [1, 1, -1, 0, 0]  # +1 = higher is better, -1 = lower is better, 0 = do not color code\\n\",\n    \"suffixes = ['', '', '', '', 'M']\\n\",\n    \"render_table(names, data, precisions, rank_order, suffixes=suffixes)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"executionInfo\": {\n     \"elapsed\": 68,\n     \"status\": \"ok\",\n     \"timestamp\": 1656870918615,\n     \"user\": {\n      \"displayName\": \"\",\n      \"userId\": \"\"\n     },\n     \"user_tz\": -60\n    },\n    \"id\": \"62it2f5ukbMP\",\n    \"outputId\": \"1b0d078a-b2dc-4291-c373-6f58f2489742\"\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"psnr\\n\",\n      \" & \\\\textit{chair} & \\\\textit{drums} & \\\\textit{ficus} & \\\\textit{hotdog} & \\\\textit{lego} & \\\\textit{materials} & \\\\textit{mic} & \\\\textit{ship} \\\\\\\\\\\\hline\\n\",\n      \"\\\\hline\\n\",\n      \"mip-NeRF 360 (256 hidden) & \\\\cellcolor{orange}35.14 &    \\\\cellcolor{red}25.67 & \\\\cellcolor{orange}32.69 & \\\\cellcolor{orange}37.58 & 
\\\\cellcolor{orange}36.30 &    \\\\cellcolor{red}30.48 & \\\\cellcolor{orange}36.45 & \\\\cellcolor{orange}30.36 \\\\\\\\\\n\",\n      \"mip-NeRF 360 (512 hidden) &    \\\\cellcolor{red}35.86 & \\\\cellcolor{orange}25.51 &    \\\\cellcolor{red}33.01 &    \\\\cellcolor{red}38.04 &    \\\\cellcolor{red}36.33 & \\\\cellcolor{orange}30.13 &    \\\\cellcolor{red}36.85 &    \\\\cellcolor{red}31.14 \\\\\\\\\\n\",\n      \"\\n\",\n      \"ssim\\n\",\n      \" & \\\\textit{chair} & \\\\textit{drums} & \\\\textit{ficus} & \\\\textit{hotdog} & \\\\textit{lego} & \\\\textit{materials} & \\\\textit{mic} & \\\\textit{ship} \\\\\\\\\\\\hline\\n\",\n      \"\\\\hline\\n\",\n      \"mip-NeRF 360 (256 hidden) & \\\\cellcolor{orange}0.981 &    \\\\cellcolor{red}0.935 & \\\\cellcolor{orange}0.978 & \\\\cellcolor{orange}0.982 & \\\\cellcolor{orange}0.980 &    \\\\cellcolor{red}0.956 &    \\\\cellcolor{red}0.991 & \\\\cellcolor{orange}0.887 \\\\\\\\\\n\",\n      \"mip-NeRF 360 (512 hidden) &    \\\\cellcolor{red}0.984 & \\\\cellcolor{orange}0.932 &    \\\\cellcolor{red}0.979 &    \\\\cellcolor{red}0.983 &    \\\\cellcolor{red}0.981 & \\\\cellcolor{orange}0.952 &    \\\\cellcolor{red}0.991 &    \\\\cellcolor{red}0.894 \\\\\\\\\\n\",\n      \"\\n\",\n      \"lpips\\n\",\n      \" & \\\\textit{chair} & \\\\textit{drums} & \\\\textit{ficus} & \\\\textit{hotdog} & \\\\textit{lego} & \\\\textit{materials} & \\\\textit{mic} & \\\\textit{ship} \\\\\\\\\\\\hline\\n\",\n      \"\\\\hline\\n\",\n      \"mip-NeRF 360 (256 hidden) & \\\\cellcolor{orange}0.020 &    \\\\cellcolor{red}0.062 &    \\\\cellcolor{red}0.022 & \\\\cellcolor{orange}0.025 & \\\\cellcolor{orange}0.017 &    \\\\cellcolor{red}0.044 &    \\\\cellcolor{red}0.009 & \\\\cellcolor{orange}0.128 \\\\\\\\\\n\",\n      \"mip-NeRF 360 (512 hidden) &    \\\\cellcolor{red}0.017 & \\\\cellcolor{orange}0.066 &    \\\\cellcolor{red}0.022 &    \\\\cellcolor{red}0.022 &    \\\\cellcolor{red}0.016 & \\\\cellcolor{orange}0.048 &    
\\\\cellcolor{red}0.009 &    \\\\cellcolor{red}0.116 \\\\\\\\\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"names = list(models_meta.values())\\n\",\n    \"\\n\",\n    \"precisions = [2, 3, 3]\\n\",\n    \"rank_orders = [1, 1, -1]\\n\",\n    \"\\n\",\n    \"name_map = {s: s for s in scene_names}\\n\",\n    \"pretty_scene_names = [name_map[s] for s in scene_names]\\n\",\n    \"\\n\",\n    \"for i_metric, metric in enumerate(['psnr', 'ssim', 'lpips']):\\n\",\n    \"  print(metric)\\n\",\n    \"  precision = precisions[i_metric]\\n\",\n    \"  rank_order = rank_orders[i_metric]\\n\",\n    \"\\n\",\n    \"  print(' & ' + ' & '.join(['\\\\\\\\textit{' + s + '}' for s in pretty_scene_names]) + ' \\\\\\\\\\\\\\\\\\\\\\\\hline')\\n\",\n    \"  data = np.array([np.array([s[metric] for s in scene_stats]) for scene_stats in all_stats])\\n\",\n    \"  render_table(names, data, [precision] * len(scene_names), [rank_order] * len(scene_names), hlines = [len(names)-2])\\n\",\n    \"  print()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"executionInfo\": {\n     \"elapsed\": 14157,\n     \"status\": \"ok\",\n     \"timestamp\": 1656870932910,\n     \"user\": {\n      \"displayName\": \"\",\n      \"userId\": \"\"\n     },\n     \"user_tz\": -60\n    },\n    \"id\": \"8Dq0kS3jzLQu\",\n    \"outputId\": \"657cfa79-8959-4f6b-bcb1-517a2f57e3ae\"\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"mipnerf360_452868940/llff_256 fern {'psnr': 25.53523317972819, 'ssim': 0.8355980316797892, 'lpips': 0.17796580493450165, 'num_hours': 3.0791895901515773, 'mega_params': 0.835205}\\n\",\n      \"mipnerf360_452868940/llff_256 flower {'psnr': 28.27823829650879, 'ssim': 0.8683305859565735, 'lpips': 0.13373453170061111, 'num_hours': 2.430556049280797, 'mega_params': 0.835205}\\n\",\n      \"mipnerf360_452868940/llff_256 fortress {'psnr': 
31.65355904897054, 'ssim': 0.9036512076854706, 'lpips': 0.10307400052746137, 'num_hours': 2.420831934334721, 'mega_params': 0.835205}\\n\",\n      \"mipnerf360_452868940/llff_256 horns {'psnr': 28.028242111206055, 'ssim': 0.8812565580010414, 'lpips': 0.14718765299767256, 'num_hours': 3.0909099844438597, 'mega_params': 0.835205}\\n\",\n      \"mipnerf360_452868940/llff_256 leaves {'psnr': 21.118935585021973, 'ssim': 0.7481956332921982, 'lpips': 0.18846691772341728, 'num_hours': 3.099055060584408, 'mega_params': 0.835205}\\n\",\n      \"mipnerf360_452868940/llff_256 orchids {'psnr': 20.187190532684326, 'ssim': 0.682639941573143, 'lpips': 0.2243024781346321, 'num_hours': 2.4169084620651113, 'mega_params': 0.835205}\\n\",\n      \"mipnerf360_452868940/llff_256 room {'psnr': 33.111086209615074, 'ssim': 0.9573092659314474, 'lpips': 0.08938046048084895, 'num_hours': 2.4188344266999375, 'mega_params': 0.835205}\\n\",\n      \"mipnerf360_452868940/llff_256 trex {'psnr': 27.265713827950613, 'ssim': 0.9158718841416496, 'lpips': 0.11908453490052905, 'num_hours': 3.0837948697070785, 'mega_params': 0.835205}\\n\",\n      \"mipnerf360_452868940/llff_256 {'psnr': 26.897274848960695, 'ssim': 0.8491066385326641, 'lpips': 0.14789954767495925, 'num_hours': 2.7550100471584367, 'mega_params': 0.835205}\\n\",\n      \"mipnerf360_452868940/llff_512 fern {'psnr': 25.259475072224934, 'ssim': 0.8444480299949646, 'lpips': 0.15365532040596008, 'num_hours': 3.7073142281576246, 'mega_params': 3.065733}\\n\",\n      \"mipnerf360_452868940/llff_512 flower {'psnr': 28.34323196411133, 'ssim': 0.8737815976142883, 'lpips': 0.11761842519044877, 'num_hours': 4.416412998877655, 'mega_params': 3.065733}\\n\",\n      \"mipnerf360_452868940/llff_512 fortress {'psnr': 31.545193036397297, 'ssim': 0.9064980645974478, 'lpips': 0.0853491226832072, 'num_hours': 4.410339683910223, 'mega_params': 3.065733}\\n\",\n      \"mipnerf360_452868940/llff_512 horns {'psnr': 28.975930213928223, 'ssim': 0.9083132669329643, 
'lpips': 0.10360320564359426, 'num_hours': 4.411220554963127, 'mega_params': 3.065733}\\n\",\n      \"mipnerf360_452868940/llff_512 leaves {'psnr': 20.930516242980957, 'ssim': 0.757549062371254, 'lpips': 0.1729690581560135, 'num_hours': 4.405542887670279, 'mega_params': 3.065733}\\n\",\n      \"mipnerf360_452868940/llff_512 orchids {'psnr': 19.98445224761963, 'ssim': 0.6835718005895615, 'lpips': 0.21904940903186798, 'num_hours': 4.416696944857583, 'mega_params': 3.065733}\\n\",\n      \"mipnerf360_452868940/llff_512 room {'psnr': 33.66789722442627, 'ssim': 0.9654540717601776, 'lpips': 0.06235106599827608, 'num_hours': 4.409909662392265, 'mega_params': 3.065733}\\n\",\n      \"mipnerf360_452868940/llff_512 trex {'psnr': 28.37834221976144, 'ssim': 0.929371850831168, 'lpips': 0.10232165455818176, 'num_hours': 4.407786158934037, 'mega_params': 3.065733}\\n\",\n      \"mipnerf360_452868940/llff_512 {'psnr': 27.13562977768126, 'ssim': 0.8586234680864783, 'lpips': 0.1271146577084437, 'num_hours': 4.323152889970349, 'mega_params': 3.065733}\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"root_folder = '/cns/lu-d/home/buff/nerf/'\\n\",\n    \"scene_names = ['fern', 'flower', 'fortress', 'horns', 'leaves', 'orchids', 'room', 'trex']\\n\",\n    \"\\n\",\n    \"models_meta = {} # folder : latex_name\\n\",\n    \"models_meta['mipnerf360_452868940/llff_256'] = 'mip-NeRF 360 (256 hidden)'\\n\",\n    \"models_meta['mipnerf360_452868940/llff_512'] = 'mip-NeRF 360 (512 hidden)'\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"NUM_ITERS = 250000\\n\",\n    \"\\n\",\n    \"all_stats = []\\n\",\n    \"avg_stats = []\\n\",\n    \"for i_model, model_path in enumerate(models_meta.keys()):\\n\",\n    \"  scene_stats = []\\n\",\n    \"  for scene_name in scene_names:\\n\",\n    \"    folder = os.path.join(root_folder, model_path, scene_name)\\n\",\n    \"    stats = scrape_folder(folder, NUM_ITERS)\\n\",\n    \"    print(model_path, scene_name, stats)\\n\",\n    \"    
scene_stats.append(stats)\\n\",\n    \"  avg_stats.append({k: type(scene_stats[0][k])(np.mean([s[k] for s in scene_stats])) for k in scene_stats[0].keys()})\\n\",\n    \"  all_stats.append(scene_stats)\\n\",\n    \"  print(model_path, avg_stats[-1])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"executionInfo\": {\n     \"elapsed\": 66,\n     \"status\": \"ok\",\n     \"timestamp\": 1656870933094,\n     \"user\": {\n      \"displayName\": \"\",\n      \"userId\": \"\"\n     },\n     \"user_tz\": -60\n    },\n    \"id\": \"wiF0Y8mRzSUf\",\n    \"outputId\": \"d01cec60-a22e-41e9-d268-dd8d449fcc67\"\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"mip-NeRF 360 (256 hidden) & \\\\cellcolor{orange}26.90 & \\\\cellcolor{orange}0.849 & \\\\cellcolor{orange}0.148 & 2.76 & 0.8M \\\\\\\\\\n\",\n      \"mip-NeRF 360 (512 hidden) &    \\\\cellcolor{red}27.14 &    \\\\cellcolor{red}0.859 &    \\\\cellcolor{red}0.127 & 4.32 & 3.1M \\\\\\\\\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"names = list(models_meta.values())\\n\",\n    \"data = np.stack([list(s.values()) for s in avg_stats])\\n\",\n    \"precisions = [2, 3, 3, 2, 1]\\n\",\n    \"rank_order = [1, 1, -1, 0, 0]  # +1 = higher is better, -1 = lower is better, 0 = do not color code\\n\",\n    \"suffixes = ['', '', '', '', 'M']\\n\",\n    \"render_table(names, data, precisions, rank_order, suffixes=suffixes)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"executionInfo\": {\n     \"elapsed\": 63,\n     \"status\": \"ok\",\n     \"timestamp\": 1656870933280,\n     \"user\": {\n      \"displayName\": \"\",\n      \"userId\": \"\"\n     },\n     \"user_tz\": -60\n    },\n    \"id\": \"B8iQz1uBj_BM\",\n    \"outputId\": \"f5a04137-0a8d-495b-9855-04c4fde7d77e\"\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     
\"output_type\": \"stream\",\n     \"text\": [\n      \"psnr\\n\",\n      \" & \\\\textit{fern} & \\\\textit{flower} & \\\\textit{fortress} & \\\\textit{horns} & \\\\textit{leaves} & \\\\textit{orchids} & \\\\textit{room} & \\\\textit{t-rex} \\\\\\\\\\\\hline\\n\",\n      \"\\\\hline\\n\",\n      \"mip-NeRF 360 (256 hidden) &    \\\\cellcolor{red}25.54 & \\\\cellcolor{orange}28.28 &    \\\\cellcolor{red}31.65 & \\\\cellcolor{orange}28.03 &    \\\\cellcolor{red}21.12 &    \\\\cellcolor{red}20.19 & \\\\cellcolor{orange}33.11 & \\\\cellcolor{orange}27.27 \\\\\\\\\\n\",\n      \"mip-NeRF 360 (512 hidden) & \\\\cellcolor{orange}25.26 &    \\\\cellcolor{red}28.34 & \\\\cellcolor{orange}31.55 &    \\\\cellcolor{red}28.98 & \\\\cellcolor{orange}20.93 & \\\\cellcolor{orange}19.98 &    \\\\cellcolor{red}33.67 &    \\\\cellcolor{red}28.38 \\\\\\\\\\n\",\n      \"\\n\",\n      \"ssim\\n\",\n      \" & \\\\textit{fern} & \\\\textit{flower} & \\\\textit{fortress} & \\\\textit{horns} & \\\\textit{leaves} & \\\\textit{orchids} & \\\\textit{room} & \\\\textit{t-rex} \\\\\\\\\\\\hline\\n\",\n      \"\\\\hline\\n\",\n      \"mip-NeRF 360 (256 hidden) & \\\\cellcolor{orange}0.836 & \\\\cellcolor{orange}0.868 & \\\\cellcolor{orange}0.904 & \\\\cellcolor{orange}0.881 & \\\\cellcolor{orange}0.748 & \\\\cellcolor{orange}0.683 & \\\\cellcolor{orange}0.957 & \\\\cellcolor{orange}0.916 \\\\\\\\\\n\",\n      \"mip-NeRF 360 (512 hidden) &    \\\\cellcolor{red}0.844 &    \\\\cellcolor{red}0.874 &    \\\\cellcolor{red}0.906 &    \\\\cellcolor{red}0.908 &    \\\\cellcolor{red}0.758 &    \\\\cellcolor{red}0.684 &    \\\\cellcolor{red}0.965 &    \\\\cellcolor{red}0.929 \\\\\\\\\\n\",\n      \"\\n\",\n      \"lpips\\n\",\n      \" & \\\\textit{fern} & \\\\textit{flower} & \\\\textit{fortress} & \\\\textit{horns} & \\\\textit{leaves} & \\\\textit{orchids} & \\\\textit{room} & \\\\textit{t-rex} \\\\\\\\\\\\hline\\n\",\n      \"\\\\hline\\n\",\n      \"mip-NeRF 360 (256 hidden) & 
\\\\cellcolor{orange}0.178 & \\\\cellcolor{orange}0.134 & \\\\cellcolor{orange}0.103 & \\\\cellcolor{orange}0.147 & \\\\cellcolor{orange}0.188 & \\\\cellcolor{orange}0.224 & \\\\cellcolor{orange}0.089 & \\\\cellcolor{orange}0.119 \\\\\\\\\\n\",\n      \"mip-NeRF 360 (512 hidden) &    \\\\cellcolor{red}0.154 &    \\\\cellcolor{red}0.118 &    \\\\cellcolor{red}0.085 &    \\\\cellcolor{red}0.104 &    \\\\cellcolor{red}0.173 &    \\\\cellcolor{red}0.219 &    \\\\cellcolor{red}0.062 &    \\\\cellcolor{red}0.102 \\\\\\\\\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"names = list(models_meta.values())\\n\",\n    \"\\n\",\n    \"precisions = [2, 3, 3]\\n\",\n    \"rank_orders = [1, 1, -1]\\n\",\n    \"\\n\",\n    \"name_map = {s: s for s in scene_names}\\n\",\n    \"name_map['trex'] = 't-rex'\\n\",\n    \"pretty_scene_names = [name_map[s] for s in scene_names]\\n\",\n    \"\\n\",\n    \"for i_metric, metric in enumerate(['psnr', 'ssim', 'lpips']):\\n\",\n    \"  print(metric)\\n\",\n    \"  precision = precisions[i_metric]\\n\",\n    \"  rank_order = rank_orders[i_metric]\\n\",\n    \"\\n\",\n    \"  print(' & ' + ' & '.join(['\\\\\\\\textit{' + s + '}' for s in pretty_scene_names]) + ' \\\\\\\\\\\\\\\\\\\\\\\\hline')\\n\",\n    \"  data = np.array([np.array([s[metric] for s in scene_stats]) for scene_stats in all_stats])\\n\",\n    \"  render_table(names, data, [precision] * len(scene_names), [rank_order] * len(scene_names), hlines = [len(names)-2])\\n\",\n    \"  print()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"JRakQXoa5c_s\"\n   },\n   \"source\": [\n    \"Reproducing Ref-NeRF's Results\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"executionInfo\": {\n     \"elapsed\": 6652,\n     \"status\": \"ok\",\n     \"timestamp\": 1656870940046,\n     \"user\": {\n      \"displayName\": \"\",\n      \"userId\": \"\"\n     },\n     \"user_tz\": -60\n    
},\n    \"id\": \"8W1_8o5cpdnP\",\n    \"outputId\": \"1d52a6be-be7d-415a-fac4-681b5594e1f8\"\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"blender_refnerf_453115176 chair {'psnr': 35.43609601974487, 'ssim': 0.9821090731024742, 'lpips': 0.019048675729427488, 'normals_mae': 20.85330159, 'num_hours': 6.571275680175564, 'mega_params': 0.71323}\\n\",\n      \"blender_refnerf_453115176 drums {'psnr': 25.852998418807985, 'ssim': 0.9371991902589798, 'lpips': 0.06024521112442017, 'normals_mae': 27.641556060000003, 'num_hours': 6.567777006412732, 'mega_params': 0.71323}\\n\",\n      \"blender_refnerf_453115176 ficus {'psnr': 30.604970512390135, 'ssim': 0.966888021826744, 'lpips': 0.03917563056573272, 'normals_mae': 40.80395823, 'num_hours': 6.5770709353398535, 'mega_params': 0.71323}\\n\",\n      \"blender_refnerf_453115176 hotdog {'psnr': 37.485798597335815, 'ssim': 0.9824606701731682, 'lpips': 0.02336025163065642, 'normals_mae': 8.700470121, 'num_hours': 6.570256454244849, 'mega_params': 0.71323}\\n\",\n      \"blender_refnerf_453115176 lego {'psnr': 36.05616488456726, 'ssim': 0.9799841883778572, 'lpips': 0.01861008478794247, 'normals_mae': 24.550344864999996, 'num_hours': 6.577586022517292, 'mega_params': 0.71323}\\n\",\n      \"blender_refnerf_453115176 materials {'psnr': 35.035994396209716, 'ssim': 0.9813295957446099, 'lpips': 0.023410547198727726, 'normals_mae': 10.3090648925, 'num_hours': 6.568449424984051, 'mega_params': 0.71323}\\n\",\n      \"blender_refnerf_453115176 mic {'psnr': 36.72967861175537, 'ssim': 0.9917984154820442, 'lpips': 0.007707923451671377, 'normals_mae': 23.699775159999998, 'num_hours': 6.5744581258500725, 'mega_params': 0.71323}\\n\",\n      \"blender_refnerf_453115176 ship {'psnr': 30.507489061355592, 'ssim': 0.8849879172444344, 'lpips': 0.13231919281184673, 'normals_mae': 30.48735126, 'num_hours': 6.570066162619403, 'mega_params': 0.71323}\\n\",\n      
\"blender_refnerf_453115176 {'psnr': 33.46364881277084, 'ssim': 0.963344634026289, 'lpips': 0.04048468966255314, 'normals_mae': 23.3807277723125, 'num_hours': 6.5721174765179775, 'mega_params': 0.71323}\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"root_folder = '/cns/lu-d/home/buff/dorverbin/nerf/mipnerf360/'\\n\",\n    \"scene_names = ['chair', 'drums', 'ficus', 'hotdog', 'lego', 'materials', 'mic', 'ship']\\n\",\n    \"metric_names = ['psnr', 'ssim', 'lpips', 'normals_mae']\\n\",\n    \"\\n\",\n    \"models_meta = {} # folder : latex_name\\n\",\n    \"models_meta[None] = 'ref-NeRF (from the paper)'\\n\",\n    \"models_meta['blender_refnerf_453115176'] = 'ref-NeRF 360'\\n\",\n    \"\\n\",\n    \"NUM_ITERS = 250000\\n\",\n    \"\\n\",\n    \"all_stats = []\\n\",\n    \"avg_stats = []\\n\",\n    \"for model_path in models_meta.keys():\\n\",\n    \"\\n\",\n    \"  if model_path is None:\\n\",\n    \"    # Inject the numbers from the paper.\\n\",\n    \"    psnrs = [35.83, 25.79, 33.91, 37.72, 36.25, 35.41, 36.76, 30.28]\\n\",\n    \"    ssims = [0.984, 0.937, 0.983, 0.984, 0.981, 0.983, 0.992, 0.880]\\n\",\n    \"    lpips = [0.017, 0.059, 0.019, 0.022, 0.018, 0.022, 0.007, 0.139]\\n\",\n    \"    # There was a bug in MAE computation for ['ficus', 'lego', 'mic'] in the paper, those numbers shouldn't be trusted.\\n\",\n    \"    nmaes = [19.852, 27.853, np.nan, 13.211, np.nan, 9.531, np.nan, 31.707]\\n\",\n    \"    train_times = [np.nan]*len(psnrs)\\n\",\n    \"    model_sizes = [np.nan]*len(psnrs)\\n\",\n    \"    scene_stats = []\\n\",\n    \"    for p, s, l, n, tt, ms in zip(psnrs, ssims, lpips, nmaes, train_times, model_sizes):\\n\",\n    \"      scene_stats.append({'psnr': p, 'ssim': s, 'lpips': l, 'normals_mae': n, 'num_hours': tt, 'mega_params': ms})\\n\",\n    \"    avg_stats.append({k: type(scene_stats[0][k])(np.mean([s[k] for s in scene_stats])) for k in scene_stats[0].keys()})\\n\",\n    \"    all_stats.append(scene_stats)\\n\",\n    \"    
continue;\\n\",\n    \"\\n\",\n    \"  scene_stats = []\\n\",\n    \"  for scene_name in scene_names:\\n\",\n    \"    folder = os.path.join(root_folder, model_path, scene_name)\\n\",\n    \"    stats = scrape_folder(folder, NUM_ITERS, metric_names=metric_names)\\n\",\n    \"    print(model_path, scene_name, stats)\\n\",\n    \"    scene_stats.append(stats)\\n\",\n    \"  avg_stats.append({k: type(scene_stats[0][k])(np.mean([s[k] for s in scene_stats])) for k in scene_stats[0].keys()})\\n\",\n    \"  all_stats.append(scene_stats)\\n\",\n    \"  print(model_path, avg_stats[-1])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"executionInfo\": {\n     \"elapsed\": 67,\n     \"status\": \"ok\",\n     \"timestamp\": 1656870940232,\n     \"user\": {\n      \"displayName\": \"\",\n      \"userId\": \"\"\n     },\n     \"user_tz\": -60\n    },\n    \"id\": \"EPQZUmp7p2N1\",\n    \"outputId\": \"3ee94963-dd84-4025-b61f-21c58d8f2b84\"\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"ref-NeRF (from the paper) &    \\\\cellcolor{red}33.99 &    \\\\cellcolor{red}0.966 &    \\\\cellcolor{red}0.038 &  -  &  -  &  -  \\\\\\\\\\n\",\n      \"ref-NeRF 360              & \\\\cellcolor{orange}33.46 & \\\\cellcolor{orange}0.963 & \\\\cellcolor{orange}0.040 &    \\\\cellcolor{red}23.38 & 6.57 & 0.7M \\\\\\\\\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"names = list(models_meta.values())\\n\",\n    \"data = np.stack([list(s.values()) for s in avg_stats])\\n\",\n    \"precisions = [2, 3, 3, 2, 2, 1]\\n\",\n    \"rank_order = [1, 1, -1, -1, 0, 0]  # +1 = higher is better, -1 = lower is better, 0 = do not color code\\n\",\n    \"suffixes = ['', '', '', '', '', 'M']\\n\",\n    \"render_table(names, data, precisions, rank_order, suffixes=suffixes)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    
\"executionInfo\": {\n     \"elapsed\": 72,\n     \"status\": \"ok\",\n     \"timestamp\": 1656870940411,\n     \"user\": {\n      \"displayName\": \"\",\n      \"userId\": \"\"\n     },\n     \"user_tz\": -60\n    },\n    \"id\": \"8dZGlK7iKI48\",\n    \"outputId\": \"b2c70244-b4b2-4e93-9fa1-1706053b2a1c\"\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"psnr\\n\",\n      \" & \\\\textit{chair} & \\\\textit{drums} & \\\\textit{ficus} & \\\\textit{hotdog} & \\\\textit{lego} & \\\\textit{materials} & \\\\textit{mic} & \\\\textit{ship} \\\\\\\\\\\\hline\\n\",\n      \"ref-NeRF (from the paper) &    \\\\cellcolor{red}35.83 & \\\\cellcolor{orange}25.79 &    \\\\cellcolor{red}33.91 &    \\\\cellcolor{red}37.72 &    \\\\cellcolor{red}36.25 &    \\\\cellcolor{red}35.41 &    \\\\cellcolor{red}36.76 & \\\\cellcolor{orange}30.28 \\\\\\\\\\n\",\n      \"ref-NeRF 360              & \\\\cellcolor{orange}35.44 &    \\\\cellcolor{red}25.85 & \\\\cellcolor{orange}30.60 & \\\\cellcolor{orange}37.49 & \\\\cellcolor{orange}36.06 & \\\\cellcolor{orange}35.04 & \\\\cellcolor{orange}36.73 &    \\\\cellcolor{red}30.51 \\\\\\\\\\n\",\n      \"\\n\",\n      \"ssim\\n\",\n      \" & \\\\textit{chair} & \\\\textit{drums} & \\\\textit{ficus} & \\\\textit{hotdog} & \\\\textit{lego} & \\\\textit{materials} & \\\\textit{mic} & \\\\textit{ship} \\\\\\\\\\\\hline\\n\",\n      \"ref-NeRF (from the paper) &    \\\\cellcolor{red}0.984 &    \\\\cellcolor{red}0.937 &    \\\\cellcolor{red}0.983 &    \\\\cellcolor{red}0.984 &    \\\\cellcolor{red}0.981 &    \\\\cellcolor{red}0.983 &    \\\\cellcolor{red}0.992 & \\\\cellcolor{orange}0.880 \\\\\\\\\\n\",\n      \"ref-NeRF 360              & \\\\cellcolor{orange}0.982 &    \\\\cellcolor{red}0.937 & \\\\cellcolor{orange}0.967 & \\\\cellcolor{orange}0.982 & \\\\cellcolor{orange}0.980 & \\\\cellcolor{orange}0.981 &    \\\\cellcolor{red}0.992 &    \\\\cellcolor{red}0.885 \\\\\\\\\\n\",\n  
    \"\\n\",\n      \"lpips\\n\",\n      \" & \\\\textit{chair} & \\\\textit{drums} & \\\\textit{ficus} & \\\\textit{hotdog} & \\\\textit{lego} & \\\\textit{materials} & \\\\textit{mic} & \\\\textit{ship} \\\\\\\\\\\\hline\\n\",\n      \"ref-NeRF (from the paper) &    \\\\cellcolor{red}0.017 &    \\\\cellcolor{red}0.059 &    \\\\cellcolor{red}0.019 &    \\\\cellcolor{red}0.022 &    \\\\cellcolor{red}0.018 &    \\\\cellcolor{red}0.022 &    \\\\cellcolor{red}0.007 & \\\\cellcolor{orange}0.139 \\\\\\\\\\n\",\n      \"ref-NeRF 360              & \\\\cellcolor{orange}0.019 & \\\\cellcolor{orange}0.060 & \\\\cellcolor{orange}0.039 & \\\\cellcolor{orange}0.023 & \\\\cellcolor{orange}0.019 & \\\\cellcolor{orange}0.023 & \\\\cellcolor{orange}0.008 &    \\\\cellcolor{red}0.132 \\\\\\\\\\n\",\n      \"\\n\",\n      \"normals_mae\\n\",\n      \" & \\\\textit{chair} & \\\\textit{drums} & \\\\textit{ficus} & \\\\textit{hotdog} & \\\\textit{lego} & \\\\textit{materials} & \\\\textit{mic} & \\\\textit{ship} \\\\\\\\\\\\hline\\n\",\n      \"ref-NeRF (from the paper) &    \\\\cellcolor{red}19.85 & \\\\cellcolor{orange}27.85 &  -  & \\\\cellcolor{orange}13.21 &  -  &    \\\\cellcolor{red}9.53 &  -  & \\\\cellcolor{orange}31.71 \\\\\\\\\\n\",\n      \"ref-NeRF 360              & \\\\cellcolor{orange}20.85 &    \\\\cellcolor{red}27.64 &    \\\\cellcolor{red}40.80 &    \\\\cellcolor{red}8.70 &    \\\\cellcolor{red}24.55 & \\\\cellcolor{orange}10.31 &    \\\\cellcolor{red}23.70 &    \\\\cellcolor{red}30.49 \\\\\\\\\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"names = list(models_meta.values())\\n\",\n    \"\\n\",\n    \"precisions = [2, 3, 3, 2]\\n\",\n    \"rank_orders = [1, 1, -1, -1]\\n\",\n    \"\\n\",\n    \"name_map = {s: s for s in scene_names}\\n\",\n    \"pretty_scene_names = [name_map[s] for s in scene_names]\\n\",\n    \"\\n\",\n    \"for i_metric, metric in enumerate(['psnr', 'ssim', 'lpips', 'normals_mae']):\\n\",\n    \"  print(metric)\\n\",\n    \"  
precision = precisions[i_metric]\\n\",\n    \"  rank_order = rank_orders[i_metric]\\n\",\n    \"\\n\",\n    \"  print(' & ' + ' & '.join(['\\\\\\\\textit{' + s + '}' for s in pretty_scene_names]) + ' \\\\\\\\\\\\\\\\\\\\\\\\hline')\\n\",\n    \"  data = np.array([np.array([s[metric] for s in scene_stats]) for scene_stats in all_stats])\\n\",\n    \"  render_table(names, data, [precision] * len(scene_names), [rank_order] * len(scene_names))\\n\",\n    \"  print()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"executionInfo\": {\n     \"elapsed\": 4986,\n     \"status\": \"ok\",\n     \"timestamp\": 1656870945508,\n     \"user\": {\n      \"displayName\": \"\",\n      \"userId\": \"\"\n     },\n     \"user_tz\": -60\n    },\n    \"id\": \"k54UULwA3U-N\",\n    \"outputId\": \"6543b07e-009a-4f42-c882-bd2ac9771c46\"\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"blender_refnerf_453115176 coffee {'psnr': 34.890018711090086, 'ssim': 0.9749827212095261, 'lpips': 0.07588588723912836, 'normals_mae': 11.845668504500003, 'num_hours': 6.575390185535848, 'mega_params': 0.71323}\\n\",\n      \"blender_refnerf_453115176 helmet3 {'psnr': 29.714286947250365, 'ssim': 0.9552837216854095, 'lpips': 0.08357305970042944, 'normals_mae': 36.671281015, 'num_hours': 6.560205667893172, 'mega_params': 0.71323}\\n\",\n      \"blender_refnerf_453115176 musclecar {'psnr': 31.08492564201355, 'ssim': 0.9564637777209282, 'lpips': 0.04047968026250601, 'normals_mae': 15.1815526225, 'num_hours': 6.561444666399807, 'mega_params': 0.71323}\\n\",\n      \"blender_refnerf_453115176 phong_envmap {'psnr': 38.04869948387146, 'ssim': 0.9839657709002495, 'lpips': 0.0813480182737112, 'normals_mae': 4.32499015, 'num_hours': 6.565148490103001, 'mega_params': 0.71323}\\n\",\n      \"blender_refnerf_453115176 teapot2 {'psnr': 45.91226264953613, 'ssim': 0.9964187118411064, 'lpips': 
0.009179396553663536, 'normals_mae': 26.659667385, 'num_hours': 6.568492677763366, 'mega_params': 0.71323}\\n\",\n      \"blender_refnerf_453115176 toaster {'psnr': 25.5054944562912, 'ssim': 0.9183712688088417, 'lpips': 0.10447088478133082, 'normals_mae': 43.071832670000006, 'num_hours': 6.571352772368665, 'mega_params': 0.71323}\\n\",\n      \"blender_refnerf_453115176 {'psnr': 34.19261464834213, 'ssim': 0.9642476620276769, 'lpips': 0.06582282113512823, 'normals_mae': 22.959165391166664, 'num_hours': 6.567005743343977, 'mega_params': 0.7132300000000001}\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"root_folder = '/cns/lu-d/home/buff/dorverbin/nerf/mipnerf360/'\\n\",\n    \"scene_names = ['coffee', 'helmet3', 'musclecar', 'phong_envmap', 'teapot2', 'toaster']\\n\",\n    \"metric_names = ['psnr', 'ssim', 'lpips', 'normals_mae']\\n\",\n    \"\\n\",\n    \"models_meta = {} # folder : latex_name\\n\",\n    \"models_meta[None] = 'ref-NeRF (from the paper)'\\n\",\n    \"models_meta['blender_refnerf_453115176'] = 'ref-NeRF 360'\\n\",\n    \"\\n\",\n    \"NUM_ITERS = 250000\\n\",\n    \"\\n\",\n    \"all_stats = []\\n\",\n    \"avg_stats = []\\n\",\n    \"for model_path in models_meta.keys():\\n\",\n    \"\\n\",\n    \"  if model_path is None:\\n\",\n    \"    # Inject the numbers from the paper.\\n\",\n    \"    psnrs = [34.21, 29.68, 30.82, 47.46, 47.90, 25.70]\\n\",\n    \"    ssims = [0.974, 0.958, 0.955, 0.995, 0.998, 0.922]\\n\",\n    \"    lpips = [0.078, 0.075, 0.041, 0.059, 0.004, 0.095]\\n\",\n    \"    nmaes = [12.240, 29.484, 14.927, 1.548, 9.234, 42.870]\\n\",\n    \"    train_times = [np.nan]*len(psnrs)\\n\",\n    \"    model_sizes = [np.nan]*len(psnrs)\\n\",\n    \"    scene_stats = []\\n\",\n    \"    for p, s, l, n, tt, ms in zip(psnrs, ssims, lpips, nmaes, train_times, model_sizes):\\n\",\n    \"      scene_stats.append({'psnr': p, 'ssim': s, 'lpips': l, 'normals_mae': n, 'num_hours': tt, 'mega_params': ms})\\n\",\n    \"    avg_stats.append({k: 
type(scene_stats[0][k])(np.mean([s[k] for s in scene_stats])) for k in scene_stats[0].keys()})\\n\",\n    \"    all_stats.append(scene_stats)\\n\",\n    \"    continue\\n\",\n    \"\\n\",\n    \"  scene_stats = []\\n\",\n    \"  for scene_name in scene_names:\\n\",\n    \"    folder = os.path.join(root_folder, model_path, scene_name)\\n\",\n    \"    stats = scrape_folder(folder, NUM_ITERS, metric_names=metric_names)\\n\",\n    \"    print(model_path, scene_name, stats)\\n\",\n    \"    scene_stats.append(stats)\\n\",\n    \"  avg_stats.append({k: type(scene_stats[0][k])(np.mean([s[k] for s in scene_stats])) for k in scene_stats[0].keys()})\\n\",\n    \"  all_stats.append(scene_stats)\\n\",\n    \"  print(model_path, avg_stats[-1])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"executionInfo\": {\n     \"elapsed\": 74,\n     \"status\": \"ok\",\n     \"timestamp\": 1656870945704,\n     \"user\": {\n      \"displayName\": \"\",\n      \"userId\": \"\"\n     },\n     \"user_tz\": -60\n    },\n    \"id\": \"DpWZo3IIRnq4\",\n    \"outputId\": \"05b1c264-f524-44b9-c447-eb5cc6d9603c\"\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"ref-NeRF (from the paper) &    \\\\cellcolor{red}35.96 &    \\\\cellcolor{red}0.967 &    \\\\cellcolor{red}0.059 &    \\\\cellcolor{red}18.38 &  -  &  -  \\\\\\\\\\n\",\n      \"ref-NeRF 360              & \\\\cellcolor{orange}34.19 & \\\\cellcolor{orange}0.964 & \\\\cellcolor{orange}0.066 & \\\\cellcolor{orange}22.96 & 6.57 & 0.7M \\\\\\\\\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"names = list(models_meta.values())\\n\",\n    \"data = np.stack([list(s.values()) for s in avg_stats])\\n\",\n    \"precisions = [2, 3, 3, 2, 2, 1]\\n\",\n    \"rank_order = [1, 1, -1, -1, 0, 0]  # +1 = higher is better, -1 = lower is better, 0 = do not color code\\n\",\n    \"suffixes = ['', '', '', '', '', 'M']\\n\",\n    
\"render_table(names, data, precisions, rank_order, suffixes=suffixes)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"executionInfo\": {\n     \"elapsed\": 56,\n     \"status\": \"ok\",\n     \"timestamp\": 1656870945900,\n     \"user\": {\n      \"displayName\": \"\",\n      \"userId\": \"\"\n     },\n     \"user_tz\": -60\n    },\n    \"id\": \"rJpqItPs3hLA\",\n    \"outputId\": \"8a19973b-2968-4fe3-ed12-3e7d710635b0\"\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"psnr\\n\",\n      \" & \\\\textit{coffee} & \\\\textit{helmet} & \\\\textit{car} & \\\\textit{ball} & \\\\textit{teapot} & \\\\textit{toaster} \\\\\\\\\\\\hline\\n\",\n      \"ref-NeRF (from the paper) & \\\\cellcolor{orange}34.21 & \\\\cellcolor{orange}29.68 & \\\\cellcolor{orange}30.82 &    \\\\cellcolor{red}47.46 &    \\\\cellcolor{red}47.90 &    \\\\cellcolor{red}25.70 \\\\\\\\\\n\",\n      \"ref-NeRF 360              &    \\\\cellcolor{red}34.89 &    \\\\cellcolor{red}29.71 &    \\\\cellcolor{red}31.08 & \\\\cellcolor{orange}38.05 & \\\\cellcolor{orange}45.91 & \\\\cellcolor{orange}25.51 \\\\\\\\\\n\",\n      \"\\n\",\n      \"ssim\\n\",\n      \" & \\\\textit{coffee} & \\\\textit{helmet} & \\\\textit{car} & \\\\textit{ball} & \\\\textit{teapot} & \\\\textit{toaster} \\\\\\\\\\\\hline\\n\",\n      \"ref-NeRF (from the paper) & \\\\cellcolor{orange}0.974 &    \\\\cellcolor{red}0.958 & \\\\cellcolor{orange}0.955 &    \\\\cellcolor{red}0.995 &    \\\\cellcolor{red}0.998 &    \\\\cellcolor{red}0.922 \\\\\\\\\\n\",\n      \"ref-NeRF 360              &    \\\\cellcolor{red}0.975 & \\\\cellcolor{orange}0.955 &    \\\\cellcolor{red}0.956 & \\\\cellcolor{orange}0.984 & \\\\cellcolor{orange}0.996 & \\\\cellcolor{orange}0.918 \\\\\\\\\\n\",\n      \"\\n\",\n      \"lpips\\n\",\n      \" & \\\\textit{coffee} & \\\\textit{helmet} & \\\\textit{car} & \\\\textit{ball} & 
\\\\textit{teapot} & \\\\textit{toaster} \\\\\\\\\\\\hline\\n\",\n      \"ref-NeRF (from the paper) & \\\\cellcolor{orange}0.078 &    \\\\cellcolor{red}0.075 & \\\\cellcolor{orange}0.041 &    \\\\cellcolor{red}0.059 &    \\\\cellcolor{red}0.004 &    \\\\cellcolor{red}0.095 \\\\\\\\\\n\",\n      \"ref-NeRF 360              &    \\\\cellcolor{red}0.076 & \\\\cellcolor{orange}0.084 &    \\\\cellcolor{red}0.040 & \\\\cellcolor{orange}0.081 & \\\\cellcolor{orange}0.009 & \\\\cellcolor{orange}0.104 \\\\\\\\\\n\",\n      \"\\n\",\n      \"normals_mae\\n\",\n      \" & \\\\textit{coffee} & \\\\textit{helmet} & \\\\textit{car} & \\\\textit{ball} & \\\\textit{teapot} & \\\\textit{toaster} \\\\\\\\\\\\hline\\n\",\n      \"ref-NeRF (from the paper) & \\\\cellcolor{orange}12.24 &    \\\\cellcolor{red}29.48 &    \\\\cellcolor{red}14.93 &    \\\\cellcolor{red}1.55 &    \\\\cellcolor{red}9.23 &    \\\\cellcolor{red}42.87 \\\\\\\\\\n\",\n      \"ref-NeRF 360              &    \\\\cellcolor{red}11.85 & \\\\cellcolor{orange}36.67 & \\\\cellcolor{orange}15.18 & \\\\cellcolor{orange}4.32 & \\\\cellcolor{orange}26.66 & \\\\cellcolor{orange}43.07 \\\\\\\\\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"names = list(models_meta.values())\\n\",\n    \"\\n\",\n    \"precisions = [2, 3, 3, 2]\\n\",\n    \"rank_orders = [1, 1, -1, -1]\\n\",\n    \"\\n\",\n    \"name_map = {s: s for s in scene_names}\\n\",\n    \"name_map['coffee'] = 'coffee'\\n\",\n    \"name_map['teapot2'] = 'teapot'\\n\",\n    \"name_map['musclecar'] = 'car'\\n\",\n    \"name_map['phong_envmap'] = 'ball'\\n\",\n    \"name_map['helmet3'] = 'helmet'\\n\",\n    \"name_map['toaster'] = 'toaster'\\n\",\n    \"pretty_scene_names = [name_map[s] for s in scene_names]\\n\",\n    \"\\n\",\n    \"for i_metric, metric in enumerate(['psnr', 'ssim', 'lpips', 'normals_mae']):\\n\",\n    \"  print(metric)\\n\",\n    \"  precision = precisions[i_metric]\\n\",\n    \"  rank_order = rank_orders[i_metric]\\n\",\n    
\"\\n\",\n    \"  print(' & ' + ' & '.join(['\\\\\\\\textit{' + s + '}' for s in pretty_scene_names]) + ' \\\\\\\\\\\\\\\\\\\\\\\\hline')\\n\",\n    \"  data = np.array([np.array([s[metric] for s in scene_stats]) for scene_stats in all_stats])\\n\",\n    \"  render_table(names, data, [precision] * len(scene_names), [rank_order] * len(scene_names))\\n\",\n    \"  print()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"nXzghNgdJ35g\"\n   },\n   \"source\": [\n    \"Reproducing RawNeRF's results\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"executionInfo\": {\n     \"elapsed\": 44245,\n     \"status\": \"ok\",\n     \"timestamp\": 1656870990258,\n     \"user\": {\n      \"displayName\": \"\",\n      \"userId\": \"\"\n     },\n     \"user_tz\": -60\n    },\n    \"id\": \"ueL648m2ShTT\",\n    \"outputId\": \"75855af0-a037-4f22-86b4-304638a38f35\"\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"llff_raw_test_455616468 officetest {'cc_psnr': 23.438732147216797, 'cc_ssim': 0.5584264993667603, 'cc_lpips': 0.4703865349292755, 'num_hours': 5.789938422481767, 'mega_params': 0.61574}\\n\",\n      \"llff_raw_test_455616468 pianotest {'cc_psnr': 24.394481658935547, 'cc_ssim': 0.5581468343734741, 'cc_lpips': 0.4914418160915375, 'num_hours': 5.789418753302435, 'mega_params': 0.61574}\\n\",\n      \"llff_raw_test_455616468 yuccatest {'cc_psnr': 22.472274780273438, 'cc_ssim': 0.48502984642982483, 'cc_lpips': 0.5376803278923035, 'num_hours': 5.788862314940476, 'mega_params': 0.61574}\\n\",\n      \"llff_raw_test_455616468 {'cc_psnr': 23.435162862141926, 'cc_ssim': 0.533867726723353, 'cc_lpips': 0.49983622630437213, 'num_hours': 5.789406496908227, 'mega_params': 0.61574}\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"root_folder = '/cns/lu-d/home/buff/bmild/nerf/mipnerf360'\\n\",\n    \"scene_names = 
['officetest', 'pianotest', 'yuccatest']\\n\",\n    \"\\n\",\n    \"models_meta = {} # folder : latex_name\\n\",\n    \"models_meta[None] = 'rawNeRF (from the paper)'\\n\",\n    \"models_meta['llff_raw_test_455616468'] = 'rawNeRF'\\n\",\n    \"\\n\",\n    \"NUM_ITERS = 500000\\n\",\n    \"\\n\",\n    \"all_stats = []\\n\",\n    \"avg_stats = []\\n\",\n    \"for model_path in models_meta.keys():\\n\",\n    \"\\n\",\n    \"  if model_path is None:\\n\",\n    \"    # Inject the numbers from the paper.\\n\",\n    \"    psnrs = [23.42711075794648, 24.387560024128817, 22.3921655141463]\\n\",\n    \"    ssims = [0.5596604, 0.55790526, 0.48582202]\\n\",\n    \"    lpips = [0.47882235, 0.49443525, 0.5339949]\\n\",\n    \"    train_times = [np.nan]*len(psnrs)\\n\",\n    \"    model_sizes = [np.nan]*len(psnrs)\\n\",\n    \"    scene_stats = []\\n\",\n    \"    for p, s, l, tt, ms in zip(psnrs, ssims, lpips, train_times, model_sizes):\\n\",\n    \"      scene_stats.append({'cc_psnr': p, 'cc_ssim': s, 'cc_lpips': l, 'num_hours': tt, 'mega_params': ms})\\n\",\n    \"    avg_stats.append({k: type(scene_stats[0][k])(np.mean([s[k] for s in scene_stats])) for k in scene_stats[0].keys()})\\n\",\n    \"    all_stats.append(scene_stats)\\n\",\n    \"    continue\\n\",\n    \"\\n\",\n    \"  scene_stats = []\\n\",\n    \"  for scene_name in scene_names:\\n\",\n    \"    folder = os.path.join(root_folder, model_path, scene_name)\\n\",\n    \"    stats = scrape_folder(folder, NUM_ITERS, metric_names=['cc_psnr', 'cc_ssim', 'cc_lpips'])\\n\",\n    \"    print(model_path, scene_name, stats)\\n\",\n    \"    scene_stats.append(stats)\\n\",\n    \"  avg_stats.append({k: type(scene_stats[0][k])(np.mean([s[k] for s in scene_stats])) for k in scene_stats[0].keys()})\\n\",\n    \"  all_stats.append(scene_stats)\\n\",\n    \"  print(model_path, avg_stats[-1])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"executionInfo\": {\n     
\"elapsed\": 54,\n     \"status\": \"ok\",\n     \"timestamp\": 1656870990431,\n     \"user\": {\n      \"displayName\": \"\",\n      \"userId\": \"\"\n     },\n     \"user_tz\": -60\n    },\n    \"id\": \"0Qgef6LHJ841\",\n    \"outputId\": \"b9fcb848-1134-4cd9-b239-933dec1ac454\"\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"rawNeRF (from the paper) & \\\\cellcolor{orange}23.40 &    \\\\cellcolor{red}0.534 & \\\\cellcolor{orange}0.502 &  -  &  -  \\\\\\\\\\n\",\n      \"rawNeRF                  &    \\\\cellcolor{red}23.44 &    \\\\cellcolor{red}0.534 &    \\\\cellcolor{red}0.500 & 5.79 & 0.6M \\\\\\\\\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"names = list(models_meta.values())\\n\",\n    \"data = np.stack([list(s.values()) for s in avg_stats])\\n\",\n    \"precisions = [2, 3, 3, 2, 1]\\n\",\n    \"rank_order = [1, 1, -1, 0, 0]  # +1 = higher is better, -1 = lower is better, 0 = do not color code\\n\",\n    \"suffixes = ['', '', '', '', 'M']\\n\",\n    \"render_table(names, data, precisions, rank_order, suffixes=suffixes)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"executionInfo\": {\n     \"elapsed\": 70,\n     \"status\": \"ok\",\n     \"timestamp\": 1656870990613,\n     \"user\": {\n      \"displayName\": \"\",\n      \"userId\": \"\"\n     },\n     \"user_tz\": -60\n    },\n    \"id\": \"c594nENnJ-bB\",\n    \"outputId\": \"46f0885d-b225-4484-849d-ec20818a7e88\"\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"cc_psnr\\n\",\n      \" & \\\\textit{officetest} & \\\\textit{pianotest} & \\\\textit{yuccatest} \\\\\\\\\\\\hline\\n\",\n      \"\\\\hline\\n\",\n      \"rawNeRF (from the paper) & \\\\cellcolor{orange}23.43 &    \\\\cellcolor{red}24.39 & \\\\cellcolor{orange}22.39 \\\\\\\\\\n\",\n      \"rawNeRF                  &    
\\\\cellcolor{red}23.44 &    \\\\cellcolor{red}24.39 &    \\\\cellcolor{red}22.47 \\\\\\\\\\n\",\n      \"\\n\",\n      \"cc_ssim\\n\",\n      \" & \\\\textit{officetest} & \\\\textit{pianotest} & \\\\textit{yuccatest} \\\\\\\\\\\\hline\\n\",\n      \"\\\\hline\\n\",\n      \"rawNeRF (from the paper) &    \\\\cellcolor{red}0.560 &    \\\\cellcolor{red}0.558 &    \\\\cellcolor{red}0.486 \\\\\\\\\\n\",\n      \"rawNeRF                  & \\\\cellcolor{orange}0.558 &    \\\\cellcolor{red}0.558 & \\\\cellcolor{orange}0.485 \\\\\\\\\\n\",\n      \"\\n\",\n      \"cc_lpips\\n\",\n      \" & \\\\textit{officetest} & \\\\textit{pianotest} & \\\\textit{yuccatest} \\\\\\\\\\\\hline\\n\",\n      \"\\\\hline\\n\",\n      \"rawNeRF (from the paper) & \\\\cellcolor{orange}0.479 & \\\\cellcolor{orange}0.494 &    \\\\cellcolor{red}0.534 \\\\\\\\\\n\",\n      \"rawNeRF                  &    \\\\cellcolor{red}0.470 &    \\\\cellcolor{red}0.491 & \\\\cellcolor{orange}0.538 \\\\\\\\\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"names = list(models_meta.values())\\n\",\n    \"\\n\",\n    \"precisions = [2, 3, 3]\\n\",\n    \"rank_orders = [1, 1, -1]\\n\",\n    \"\\n\",\n    \"name_map = {s: s for s in scene_names}\\n\",\n    \"pretty_scene_names = [name_map[s] for s in scene_names]\\n\",\n    \"\\n\",\n    \"for i_metric, metric in enumerate(['cc_psnr', 'cc_ssim', 'cc_lpips']):\\n\",\n    \"  print(metric)\\n\",\n    \"  precision = precisions[i_metric]\\n\",\n    \"  rank_order = rank_orders[i_metric]\\n\",\n    \"\\n\",\n    \"  print(' & ' + ' & '.join(['\\\\\\\\textit{' + s + '}' for s in pretty_scene_names]) + ' \\\\\\\\\\\\\\\\\\\\\\\\hline')\\n\",\n    \"  data = np.array([np.array([s[metric] for s in scene_stats]) for scene_stats in all_stats])\\n\",\n    \"  render_table(names, data, [precision] * len(scene_names), [rank_order] * len(scene_names), hlines = [len(names)-2])\\n\",\n    \"  print()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   
\"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"l2kJLXgyMWWM\"\n   },\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"colab\": {\n   \"collapsed_sections\": [],\n   \"last_runtime\": {\n    \"build_target\": \"//learning/deepmind/public/tools/ml_python:ml_notebook\",\n    \"kind\": \"private\"\n   },\n   \"name\": \"generate_tables.ipynb\",\n   \"provenance\": [\n    {\n     \"file_id\": \"/piper/depot/google3/googlex/gcam/buff/mipnerf360/scripts/generate_tables.ipynb?workspaceId=dorverbin:refnerf_hybrid::citc\",\n     \"timestamp\": 1654708265431\n    },\n    {\n     \"file_id\": \"/piper/depot/google3/experimental/users/barron/prob_nerf/scripts/eval_single.ipynb?workspaceId=barron:mipnerf360_paper::citc\",\n     \"timestamp\": 1635265902094\n    },\n    {\n     \"file_id\": \"/piper/depot/google3/experimental/users/barron/prob_nerf/scripts/Pre_NeRF_Eval_multi.ipynb?workspaceId=barron:jaxnerf_mono5::citc\",\n     \"timestamp\": 1614394543651\n    },\n    {\n     \"file_id\": \"/piper/depot/google3/experimental/users/barron/prob_nerf/scripts/Pre_NeRF_Eval.ipynb?workspaceId=barron:jaxnerf_mono5::citc\",\n     \"timestamp\": 1614038274387\n    },\n    {\n     \"file_id\": \"10opVizeODokMJ10R7hwq7qVyLmYZx_ZA\",\n     \"timestamp\": 1613166364224\n    }\n   ]\n  },\n  \"kernelspec\": {\n   \"display_name\": \"Python 3 (ipykernel)\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.9.12\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 1\n}\n"
  },
  {
    "path": "mip360/scripts/local_colmap_and_resize.sh",
    "content": "#!/bin/bash\n# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\n# Set to 0 if you do not have a GPU.\nUSE_GPU=1\n# Path to a directory `base/` with images in `base/images/`.\nDATASET_PATH=$1\n# Recommended CAMERA values: OPENCV for perspective, OPENCV_FISHEYE for fisheye.\nCAMERA=${2:-OPENCV}\n\n\n# Run COLMAP.\n\n### Feature extraction\n\ncolmap feature_extractor \\\n    --database_path \"$DATASET_PATH\"/database.db \\\n    --image_path \"$DATASET_PATH\"/images \\\n    --ImageReader.single_camera 1 \\\n    --ImageReader.camera_model \"$CAMERA\" \\\n    --SiftExtraction.use_gpu \"$USE_GPU\"\n\n\n### Feature matching\n\ncolmap exhaustive_matcher \\\n    --database_path \"$DATASET_PATH\"/database.db \\\n    --SiftMatching.use_gpu \"$USE_GPU\"\n\n## Use if your scene has > 500 images\n## Replace this path with your own local copy of the file.\n## Download from: https://demuc.de/colmap/#download\n# VOCABTREE_PATH=/usr/local/google/home/bmild/vocab_tree_flickr100K_words32K.bin\n# colmap vocab_tree_matcher \\\n#     --database_path \"$DATASET_PATH\"/database.db \\\n#     --VocabTreeMatching.vocab_tree_path $VOCABTREE_PATH \\\n#     --SiftMatching.use_gpu \"$USE_GPU\"\n\n\n### Bundle adjustment\n\n# The default Mapper tolerance is unnecessarily large,\n# decreasing it speeds up bundle adjustment steps.\nmkdir -p \"$DATASET_PATH\"/sparse\ncolmap mapper \\\n    --database_path \"$DATASET_PATH\"/database.db \\\n    
--image_path \"$DATASET_PATH\"/images \\\n    --output_path \"$DATASET_PATH\"/sparse \\\n    --Mapper.ba_global_function_tolerance=0.000001\n\n\n### Image undistortion\n\n## Use this if you want to undistort your images into ideal pinhole intrinsics.\n# mkdir -p \"$DATASET_PATH\"/dense\n# colmap image_undistorter \\\n#     --image_path \"$DATASET_PATH\"/images \\\n#     --input_path \"$DATASET_PATH\"/sparse/0 \\\n#     --output_path \"$DATASET_PATH\"/dense \\\n#     --output_type COLMAP\n\n# Resize images.\n\ncp -r \"$DATASET_PATH\"/images \"$DATASET_PATH\"/images_2\n\npushd \"$DATASET_PATH\"/images_2\nls | xargs -P 8 -I {} mogrify -resize 50% {}\npopd\n\ncp -r \"$DATASET_PATH\"/images \"$DATASET_PATH\"/images_4\n\npushd \"$DATASET_PATH\"/images_4\nls | xargs -P 8 -I {} mogrify -resize 25% {}\npopd\n\ncp -r \"$DATASET_PATH\"/images \"$DATASET_PATH\"/images_8\n\npushd \"$DATASET_PATH\"/images_8\nls | xargs -P 8 -I {} mogrify -resize 12.5% {}\npopd\n"
  },
  {
    "path": "mip360/scripts/msnerf/eval_360.sh",
    "content": "#!/bin/bash\n# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nexport CUDA_VISIBLE_DEVICES='0,1'\n\n# for synthetic part\nSCENE=Scene01\nEXPERIMENT=logs_Mip-NeRF-360\nDATA_DIR=/mnt/sda/T3/cvpr23/dataset/synthetic_scenes\nCHECKPOINT_DIR=/mnt/sda/experiments/cvpr23/Mip-NeRF-360/\"$EXPERIMENT\"/\"$SCENE\"\n\npython -m eval \\\n  --gin_configs=configs/ms-nerf/360.gin \\\n  --gin_bindings=\"Config.data_dir = '${DATA_DIR}/${SCENE}'\" \\\n  --gin_bindings=\"Config.checkpoint_dir = '${CHECKPOINT_DIR}'\" \\\n  --logtostderr\n\n# for real captured part\nSCENE=Scan01\nEXPERIMENT=logs_Mip-NeRF-360\nDATA_DIR=/mnt/sda/T3/cvpr23/dataset/posed_real_scenes\nCHECKPOINT_DIR=/mnt/sda/experiments/cvpr23/Mip-NeRF-360/\"$EXPERIMENT\"/\"$SCENE\"\n\npython -m eval \\\n  --gin_configs=configs/ms-nerf/360.gin \\\n  --gin_bindings=\"Config.dataset_loader = 'llff'\" \\\n  --gin_bindings=\"Config.factor = 8\" \\\n  --gin_bindings=\"Config.data_dir = '${DATA_DIR}/${SCENE}'\" \\\n  --gin_bindings=\"Config.checkpoint_dir = '${CHECKPOINT_DIR}'\" \\\n  --logtostderr"
  },
  {
    "path": "mip360/scripts/msnerf/eval_ms360.sh",
    "content": "#!/bin/bash\n# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nexport CUDA_VISIBLE_DEVICES='0,1'\n\n# for synthetic part\nSCENE=Scene04\nEXPERIMENT=MS-Mip-NeRF-360\nDATA_DIR=/jiaopengyi/ms-nerf/jax/dataset\nCHECKPOINT_DIR=/jiaopengyi/ms-nerf/jax/output/\"$EXPERIMENT\"/\"$SCENE\"\n\npython -m eval \\\n  --gin_configs=configs/ms-nerf/ms360.gin \\\n  --gin_bindings=\"Config.data_dir = '${DATA_DIR}/${SCENE}'\" \\\n  --gin_bindings=\"Config.checkpoint_dir = '${CHECKPOINT_DIR}'\" \\\n  --logtostderr\n\n# for real captured part\n# SCENE=Scan01\n# EXPERIMENT=logs_MS-Mip-NeRF-360\n# DATA_DIR=/mnt/sda/experiments/cvpr23_real_cap_dataset\n# CHECKPOINT_DIR=/mnt/sda/experiments/cvpr23/Mip-NeRF-360/\"$EXPERIMENT\"/\"$SCENE\"\n\n# python -m eval \\\n#   --gin_configs=configs/ms-nerf/ms360.gin \\\n#   --gin_bindings=\"Config.dataset_loader = 'llff'\" \\\n#   --gin_bindings=\"Config.factor = 8\" \\\n#   --gin_bindings=\"Config.data_dir = '${DATA_DIR}/${SCENE}'\" \\\n#   --gin_bindings=\"Config.checkpoint_dir = '${CHECKPOINT_DIR}'\" \\\n#   --logtostderr"
  },
  {
    "path": "mip360/scripts/msnerf/eval_one_ms360.sh",
    "content": "#!/bin/bash\n# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nexport CUDA_VISIBLE_DEVICES='0,1'\n\n# for synthetic part\nSCENE=Scene04\nEXPERIMENT=MS-Mip-NeRF-360\nDATA_DIR=/jiaopengyi/ms-nerf/jax/dataset\nCHECKPOINT_DIR=/jiaopengyi/ms-nerf/jax/output/\"$EXPERIMENT\"/\"$SCENE\"\n\npython -m eval \\\n  --gin_configs=configs/ms-nerf/ms360.gin \\\n  --gin_bindings=\"Config.data_dir = '${DATA_DIR}/${SCENE}'\" \\\n  --gin_bindings=\"Config.checkpoint_dir = '${CHECKPOINT_DIR}'\" \\\n  --gin_bindings=\"Config.eval_one = (5,300,200)\" \\\n  --logtostderr\n\n# for real captured part\n# SCENE=Scan01\n# EXPERIMENT=logs_MS-Mip-NeRF-360\n# DATA_DIR=/mnt/sda/experiments/cvpr23_real_cap_dataset\n# CHECKPOINT_DIR=/mnt/sda/experiments/cvpr23/Mip-NeRF-360/\"$EXPERIMENT\"/\"$SCENE\"\n\n# python -m eval \\\n#   --gin_configs=configs/ms-nerf/ms360.gin \\\n#   --gin_bindings=\"Config.dataset_loader = 'llff'\" \\\n#   --gin_bindings=\"Config.factor = 8\" \\\n#   --gin_bindings=\"Config.data_dir = '${DATA_DIR}/${SCENE}'\" \\\n#   --gin_bindings=\"Config.checkpoint_dir = '${CHECKPOINT_DIR}'\" \\\n#   --logtostderr"
  },
  {
    "path": "mip360/scripts/msnerf/render_360.sh",
    "content": "#!/bin/bash\n# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nexport CUDA_VISIBLE_DEVICES='0,1'\n\n# for synthetic part\nSCENE=Scene01\nEXPERIMENT=logs_Mip-NeRF-360\nDATA_DIR=/mnt/sda/T3/cvpr23/dataset/synthetic_scenes\nCHECKPOINT_DIR=/mnt/sda/experiments/cvpr23/Mip-NeRF-360/\"$EXPERIMENT\"/\"$SCENE\"\n\npython -m render \\\n  --gin_configs=configs/ms-nerf/360.gin \\\n  --gin_bindings=\"Config.data_dir = '${DATA_DIR}/${SCENE}'\" \\\n  --gin_bindings=\"Config.checkpoint_dir = '${CHECKPOINT_DIR}'\" \\\n  --gin_bindings=\"Config.render_path = True\" \\\n  --gin_bindings=\"Config.render_path_frames = 10\" \\\n  --gin_bindings=\"Config.render_dir = '${CHECKPOINT_DIR}/render/'\" \\\n  --gin_bindings=\"Config.render_video_fps = 2\" \\\n  --logtostderr\n\n# for real captured part\nSCENE=Scan01\nEXPERIMENT=logs_Mip-NeRF-360\nDATA_DIR=/mnt/sda/T3/cvpr23/dataset/posed_real_scenes\nCHECKPOINT_DIR=/mnt/sda/experiments/cvpr23/Mip-NeRF-360/\"$EXPERIMENT\"/\"$SCENE\"\n\npython -m render \\\n  --gin_configs=configs/ms-nerf/360.gin \\\n  --gin_bindings=\"Config.dataset_loader = 'llff'\" \\\n  --gin_bindings=\"Config.factor = 8\" \\\n  --gin_bindings=\"Config.data_dir = '${DATA_DIR}/${SCENE}'\" \\\n  --gin_bindings=\"Config.checkpoint_dir = '${CHECKPOINT_DIR}'\" \\\n  --gin_bindings=\"Config.render_path = True\" \\\n  --gin_bindings=\"Config.render_path_frames = 10\" \\\n  --gin_bindings=\"Config.render_dir = 
'${CHECKPOINT_DIR}/render/'\" \\\n  --gin_bindings=\"Config.render_video_fps = 2\" \\\n  --logtostderr"
  },
  {
    "path": "mip360/scripts/msnerf/render_ms360.sh",
    "content": "#!/bin/bash\n# Copyright 2023 Ze-Xin Yin\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nexport CUDA_VISIBLE_DEVICES='0,1'\n\n# for synthetic part\nSCENE=Scene01\nEXPERIMENT=logs_MS-Mip-NeRF-360\nDATA_DIR=/mnt/sda/T3/cvpr23/dataset/synthetic_scenes\nCHECKPOINT_DIR=/mnt/sda/experiments/cvpr23/Mip-NeRF-360/\"$EXPERIMENT\"/\"$SCENE\"\n\npython -m render \\\n  --gin_configs=configs/ms-nerf/ms360.gin \\\n  --gin_bindings=\"Config.data_dir = '${DATA_DIR}/${SCENE}'\" \\\n  --gin_bindings=\"Config.checkpoint_dir = '${CHECKPOINT_DIR}'\" \\\n  --gin_bindings=\"Config.render_path = True\" \\\n  --gin_bindings=\"Config.render_path_frames = 10\" \\\n  --gin_bindings=\"Config.render_dir = '${CHECKPOINT_DIR}/render/'\" \\\n  --gin_bindings=\"Config.render_video_fps = 2\" \\\n  --logtostderr\n\n# for real captured part\nSCENE=Scan01\nEXPERIMENT=logs_MS-Mip-NeRF-360\nDATA_DIR=/mnt/sda/T3/cvpr23/dataset/posed_real_scenes\nCHECKPOINT_DIR=/mnt/sda/experiments/cvpr23/Mip-NeRF-360/\"$EXPERIMENT\"/\"$SCENE\"\n\npython -m render \\\n  --gin_configs=configs/ms-nerf/ms360.gin \\\n  --gin_bindings=\"Config.dataset_loader = 'llff'\" \\\n  --gin_bindings=\"Config.factor = 8\" \\\n  --gin_bindings=\"Config.data_dir = '${DATA_DIR}/${SCENE}'\" \\\n  --gin_bindings=\"Config.checkpoint_dir = '${CHECKPOINT_DIR}'\" \\\n  --gin_bindings=\"Config.render_path = True\" \\\n  --gin_bindings=\"Config.render_path_frames = 10\" \\\n  --gin_bindings=\"Config.render_dir = 
'${CHECKPOINT_DIR}/render/'\" \\\n  --gin_bindings=\"Config.render_video_fps = 2\" \\\n  --logtostderr\n"
  },
  {
    "path": "mip360/scripts/msnerf/train_360.sh",
    "content": "#!/bin/bash\n# Copyright 2023 Ze-Xin Yin\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nexport CUDA_VISIBLE_DEVICES='0,1'\n\n# for synthetic part\nSCENE=Scene04\nEXPERIMENT=Mip-NeRF-360\nDATA_DIR=/jiaopengyi/ms-nerf/jax/dataset\nCHECKPOINT_DIR=/jiaopengyi/ms-nerf/jax/output/\"$EXPERIMENT\"/\"$SCENE\"\n\nrm \"$CHECKPOINT_DIR\"/*\npython -m train \\\n  --gin_configs=configs/ms-nerf/360.gin \\\n  --gin_bindings=\"Config.data_dir = '${DATA_DIR}/${SCENE}'\" \\\n  --gin_bindings=\"Config.checkpoint_dir = '${CHECKPOINT_DIR}'\" \\\n  --logtostderr\n\n# for real captured part\n# SCENE=Scan01\n# EXPERIMENT=logs_Mip-NeRF-360\n# DATA_DIR=/mnt/sda/T3/cvpr23/dataset/posed_real_scenes\n# CHECKPOINT_DIR=/mnt/sda/experiments/cvpr23/Mip-NeRF-360/\"$EXPERIMENT\"/\"$SCENE\"\n\n# rm \"$CHECKPOINT_DIR\"/*\n# python -m train \\\n#   --gin_configs=configs/ms-nerf/360.gin \\\n#   --gin_bindings=\"Config.dataset_loader = 'llff'\" \\\n#   --gin_bindings=\"Config.factor = 8\" \\\n#   --gin_bindings=\"Config.data_dir = '${DATA_DIR}/${SCENE}'\" \\\n#   --gin_bindings=\"Config.checkpoint_dir = '${CHECKPOINT_DIR}'\" \\\n#   --logtostderr"
  },
  {
    "path": "mip360/scripts/msnerf/train_ms360.sh",
    "content": "#!/bin/bash\n# Copyright 2023 Ze-Xin Yin\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nexport CUDA_VISIBLE_DEVICES='0,1'\n\n# for synthetic part\nSCENE=Scene04\nEXPERIMENT=MS-Mip-NeRF-360\nDATA_DIR=/jiaopengyi/ms-nerf/jax/dataset\nCHECKPOINT_DIR=/jiaopengyi/ms-nerf/jax/output/\"$EXPERIMENT\"/\"$SCENE\"\n\nrm \"$CHECKPOINT_DIR\"/*\npython -m train \\\n  --gin_configs=configs/ms-nerf/ms360.gin \\\n  --gin_bindings=\"Config.data_dir = '${DATA_DIR}/${SCENE}'\" \\\n  --gin_bindings=\"Config.checkpoint_dir = '${CHECKPOINT_DIR}'\" \\\n  --logtostderr\n\n# for real captured part\n# SCENE=Scan01\n# EXPERIMENT=logs_MS-Mip-NeRF-360\n# DATA_DIR=/mnt/sda/T3/cvpr23/dataset/posed_real_scenes\n# CHECKPOINT_DIR=/mnt/sda/experiments/cvpr23/Mip-NeRF-360/\"$EXPERIMENT\"/\"$SCENE\"\n\n# rm \"$CHECKPOINT_DIR\"/*\n# python -m train \\\n#   --gin_configs=configs/ms-nerf/ms360.gin \\\n#   --gin_bindings=\"Config.dataset_loader = 'llff'\" \\\n#   --gin_bindings=\"Config.factor = 8\" \\\n#   --gin_bindings=\"Config.data_dir = '${DATA_DIR}/${SCENE}'\" \\\n#   --gin_bindings=\"Config.checkpoint_dir = '${CHECKPOINT_DIR}'\" \\\n#   --logtostderr"
  },
  {
    "path": "mip360/scripts/render_llff.sh",
    "content": "#!/bin/bash\n# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nexport CUDA_VISIBLE_DEVICES=0\n\nSCENE=flower\nEXPERIMENT=llff\nDATA_DIR=/usr/local/google/home/barron/tmp/nerf_data/nerf_llff_data\nCHECKPOINT_DIR=/usr/local/google/home/barron/tmp/nerf_results/\"$EXPERIMENT\"/\"$SCENE\"\n\npython -m render \\\n  --gin_configs=configs/llff_256.gin \\\n  --gin_bindings=\"Config.data_dir = '${DATA_DIR}/${SCENE}'\" \\\n  --gin_bindings=\"Config.checkpoint_dir = '${CHECKPOINT_DIR}'\" \\\n  --gin_bindings=\"Config.render_path = True\" \\\n  --gin_bindings=\"Config.render_path_frames = 10\" \\\n  --gin_bindings=\"Config.render_dir = '${CHECKPOINT_DIR}/render/'\" \\\n  --gin_bindings=\"Config.render_video_fps = 2\" \\\n  --logtostderr\n"
  },
  {
    "path": "mip360/scripts/render_raw.sh",
    "content": "#!/bin/bash\n# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nexport CUDA_VISIBLE_DEVICES=0\n\nSCENE=nightpiano\nEXPERIMENT=raw\nDATA_DIR=/usr/local/google/home/barron/tmp/nerf_data/rawnerf/scenes\nCHECKPOINT_DIR=/usr/local/google/home/barron/tmp/nerf_results/\"$EXPERIMENT\"/\"$SCENE\"\n\npython -m render \\\n  --gin_configs=configs/llff_raw.gin \\\n  --gin_bindings=\"Config.data_dir = '${DATA_DIR}/${SCENE}'\" \\\n  --gin_bindings=\"Config.checkpoint_dir = '${CHECKPOINT_DIR}'\" \\\n  --gin_bindings=\"Config.render_path = True\" \\\n  --gin_bindings=\"Config.render_path_frames = 10\" \\\n  --gin_bindings=\"Config.render_dir = '${CHECKPOINT_DIR}/render/'\" \\\n  --gin_bindings=\"Config.render_video_fps = 2\" \\\n  --logtostderr\n"
  },
  {
    "path": "mip360/scripts/run_all_unit_tests.sh",
    "content": "#!/bin/bash\n# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\npython -m unittest tests.camera_utils_test\npython -m unittest tests.geopoly_test\npython -m unittest tests.stepfun_test\npython -m unittest tests.coord_test\npython -m unittest tests.image_test\npython -m unittest tests.ref_utils_test\npython -m unittest tests.utils_test\npython -m unittest tests.datasets_test\npython -m unittest tests.math_test\npython -m unittest tests.render_test\n"
  },
  {
    "path": "mip360/scripts/train_blender.sh",
    "content": "#!/bin/bash\n# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nexport CUDA_VISIBLE_DEVICES=0\n\nSCENE=ficus\nEXPERIMENT=blender\nDATA_DIR=/usr/local/google/home/barron/tmp/nerf_data/nerf_synthetic\nCHECKPOINT_DIR=/usr/local/google/home/barron/tmp/nerf_results/\"$EXPERIMENT\"/\"$SCENE\"\n\nrm \"$CHECKPOINT_DIR\"/*\npython -m train \\\n  --gin_configs=configs/blender_256.gin \\\n  --gin_bindings=\"Config.data_dir = '${DATA_DIR}/${SCENE}'\" \\\n  --gin_bindings=\"Config.checkpoint_dir = '${CHECKPOINT_DIR}'\" \\\n  --logtostderr\n"
  },
  {
    "path": "mip360/scripts/train_llff.sh",
    "content": "#!/bin/bash\n# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nexport CUDA_VISIBLE_DEVICES=0\n\nSCENE=flower\nEXPERIMENT=llff\nDATA_DIR=/usr/local/google/home/barron/tmp/nerf_data/nerf_llff_data\nCHECKPOINT_DIR=/usr/local/google/home/barron/tmp/nerf_results/\"$EXPERIMENT\"/\"$SCENE\"\n\nrm \"$CHECKPOINT_DIR\"/*\npython -m train \\\n  --gin_configs=configs/llff_256.gin \\\n  --gin_bindings=\"Config.data_dir = '${DATA_DIR}/${SCENE}'\" \\\n  --gin_bindings=\"Config.checkpoint_dir = '${CHECKPOINT_DIR}'\" \\\n  --logtostderr\n"
  },
  {
    "path": "mip360/scripts/train_raw.sh",
    "content": "#!/bin/bash\n# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nexport CUDA_VISIBLE_DEVICES=0\n\nSCENE=nightpiano\nEXPERIMENT=raw\nDATA_DIR=/usr/local/google/home/barron/tmp/nerf_data/rawnerf/scenes\nCHECKPOINT_DIR=/usr/local/google/home/barron/tmp/nerf_results/\"$EXPERIMENT\"/\"$SCENE\"\n\nrm \"$CHECKPOINT_DIR\"/*\npython -m train \\\n  --gin_configs=configs/llff_raw.gin \\\n  --gin_bindings=\"Config.data_dir = '${DATA_DIR}/${SCENE}'\" \\\n  --gin_bindings=\"Config.checkpoint_dir = '${CHECKPOINT_DIR}'\" \\\n  --logtostderr\n"
  },
  {
    "path": "mip360/scripts/train_shinyblender.sh",
    "content": "#!/bin/bash\n# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nexport CUDA_VISIBLE_DEVICES=0\n\nSCENE=toaster\nEXPERIMENT=shinyblender\nDATA_DIR=/usr/local/google/home/barron/tmp/nerf_data/dors_nerf_synthetic\nCHECKPOINT_DIR=/usr/local/google/home/barron/tmp/nerf_results/\"$EXPERIMENT\"/\"$SCENE\"\n\nrm \"$CHECKPOINT_DIR\"/*\npython -m train \\\n  --gin_configs=configs/blender_refnerf.gin \\\n  --gin_bindings=\"Config.data_dir = '${DATA_DIR}/${SCENE}'\" \\\n  --gin_bindings=\"Config.checkpoint_dir = '${CHECKPOINT_DIR}'\" \\\n  --logtostderr\n"
  },
  {
    "path": "mip360/tests/camera_utils_test.py",
    "content": "# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for camera_utils.\"\"\"\n\nfrom absl.testing import absltest\nfrom absl.testing import parameterized\nfrom internal import camera_utils\nfrom jax import random\nimport jax.numpy as jnp\nimport numpy as np\n\n\nclass CameraUtilsTest(parameterized.TestCase):\n\n  def test_convert_to_ndc(self):\n    rng = random.PRNGKey(0)\n    for _ in range(10):\n      # Random pinhole camera intrinsics.\n      key, rng = random.split(rng)\n      focal, width, height = random.uniform(key, (3,), minval=100., maxval=200.)\n      camtopix = camera_utils.intrinsic_matrix(focal, focal, width / 2.,\n                                               height / 2.)\n      pixtocam = np.linalg.inv(camtopix)\n      near = 1.\n\n      # Random rays, pointing forward (negative z direction).\n      num_rays = 1000\n      key, rng = random.split(rng)\n      origins = jnp.array([0., 0., 1.])\n      origins += random.uniform(key, (num_rays, 3), minval=-1., maxval=1.)\n      directions = jnp.array([0., 0., -1.])\n      directions += random.uniform(key, (num_rays, 3), minval=-.5, maxval=.5)\n\n      # Project world-space points along each ray into NDC space.\n      t = jnp.linspace(0., 1., 10)\n      pts_world = origins + t[:, None, None] * directions\n      pts_ndc = jnp.stack([\n          -focal / (.5 * width) * pts_world[..., 0] / pts_world[..., 2],\n          -focal / (.5 * height) * 
pts_world[..., 1] / pts_world[..., 2],\n          1. + 2. * near / pts_world[..., 2],\n      ],\n                          axis=-1)\n\n      # Get NDC space rays.\n      origins_ndc, directions_ndc = camera_utils.convert_to_ndc(\n          origins, directions, pixtocam, near)\n\n      # Ensure that the NDC space points lie on the calculated rays.\n      directions_ndc_norm = jnp.linalg.norm(\n          directions_ndc, axis=-1, keepdims=True)\n      directions_ndc_unit = directions_ndc / directions_ndc_norm\n      projection = ((pts_ndc - origins_ndc) * directions_ndc_unit).sum(axis=-1)\n      pts_ndc_proj = origins_ndc + directions_ndc_unit * projection[..., None]\n\n      # pts_ndc should be close to their projections pts_ndc_proj onto the rays.\n      np.testing.assert_allclose(pts_ndc, pts_ndc_proj, atol=1e-5, rtol=1e-5)\n\n\nif __name__ == '__main__':\n  absltest.main()\n"
  },
  {
    "path": "mip360/tests/coord_test.py",
    "content": "# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Unit tests for coord.\"\"\"\n\nfrom absl.testing import absltest\nfrom absl.testing import parameterized\nfrom internal import coord\nfrom internal import math\nimport jax\nfrom jax import random\nimport jax.numpy as jnp\nimport numpy as np\n\n\ndef sample_covariance(rng, batch_size, num_dims):\n  \"\"\"Sample a random covariance matrix.\"\"\"\n  half_cov = jax.random.normal(rng, [batch_size] + [num_dims] * 2)\n  cov = math.matmul(half_cov, jnp.moveaxis(half_cov, -1, -2))\n  return cov\n\n\ndef stable_pos_enc(x, n):\n  \"\"\"A stable pos_enc for very high degrees, courtesy of Sameer Agarwal.\"\"\"\n  sin_x = np.sin(x)\n  cos_x = np.cos(x)\n  output = []\n  rotmat = np.array([[cos_x, -sin_x], [sin_x, cos_x]], dtype='double')\n  for _ in range(n):\n    output.append(rotmat[::-1, 0, :])\n    rotmat = np.einsum('ijn,jkn->ikn', rotmat, rotmat)\n  return np.reshape(np.transpose(np.stack(output, 0), [2, 1, 0]), [-1, 2 * n])\n\n\nclass CoordTest(parameterized.TestCase):\n\n  def test_stable_pos_enc(self):\n    \"\"\"Test that the stable posenc implementation works on multiples of pi/2.\"\"\"\n    n = 10\n    x = np.linspace(-np.pi, np.pi, 5)\n    z = stable_pos_enc(x, n).reshape([-1, 2, n])\n    z0_true = np.zeros_like(z[:, 0, :])\n    z1_true = np.ones_like(z[:, 1, :])\n    z0_true[:, 0] = [0, -1, 0, 1, 0]\n    z1_true[:, 0] = [-1, 0, 1, 0, -1]\n    z1_true[:, 1] = 
[1, -1, 1, -1, 1]\n    z_true = np.stack([z0_true, z1_true], axis=1)\n    np.testing.assert_allclose(z, z_true, atol=1e-10)\n\n  def test_contract_matches_special_case(self):\n    \"\"\"Test the math for Figure 2 of https://arxiv.org/abs/2111.12077.\"\"\"\n    n = 10\n    _, s_to_t = coord.construct_ray_warps(jnp.reciprocal, 1, jnp.inf)\n    s = jnp.linspace(0, 1 - jnp.finfo(jnp.float32).eps, n + 1)\n    tc = coord.contract(s_to_t(s)[:, None])[:, 0]\n    delta_tc = tc[1:] - tc[:-1]\n    np.testing.assert_allclose(\n        delta_tc, np.full_like(delta_tc, 1 / n), atol=1E-5, rtol=1E-5)\n\n  def test_contract_is_bounded(self):\n    n, d = 10000, 3\n    rng = random.PRNGKey(0)\n    key0, key1, rng = random.split(rng, 3)\n    x = jnp.where(random.bernoulli(key0, shape=[n, d]), 1, -1) * jnp.exp(\n        random.uniform(key1, [n, d], minval=-10, maxval=10))\n    y = coord.contract(x)\n    self.assertLessEqual(jnp.max(y), 2)\n\n  def test_contract_is_noop_when_norm_is_leq_one(self):\n    n, d = 10000, 3\n    rng = random.PRNGKey(0)\n    key, rng = random.split(rng)\n    x = random.normal(key, shape=[n, d])\n    xc = x / jnp.maximum(1, jnp.linalg.norm(x, axis=-1, keepdims=True))\n\n    # Sanity check on the test itself.\n    assert jnp.abs(jnp.max(jnp.linalg.norm(xc, axis=-1)) - 1) < 1e-6\n\n    yc = coord.contract(xc)\n    np.testing.assert_allclose(xc, yc, atol=1E-5, rtol=1E-5)\n\n  def test_contract_gradients_are_finite(self):\n    # Construct x such that we probe x == 0, where things are unstable.\n    x = jnp.stack(jnp.meshgrid(*[jnp.linspace(-4, 4, 11)] * 2), axis=-1)\n    grad = jax.grad(lambda x: jnp.sum(coord.contract(x)))(x)\n    self.assertTrue(jnp.all(jnp.isfinite(grad)))\n\n  def test_inv_contract_gradients_are_finite(self):\n    z = jnp.stack(jnp.meshgrid(*[jnp.linspace(-2, 2, 21)] * 2), axis=-1)\n    z = z.reshape([-1, 2])\n    z = z[jnp.sum(z**2, axis=-1) < 2, :]\n    grad = jax.grad(lambda z: jnp.sum(coord.inv_contract(z)))(z)\n    
self.assertTrue(jnp.all(jnp.isfinite(grad)))\n\n  def test_inv_contract_inverts_contract(self):\n    \"\"\"Do a round-trip from metric space to contracted space and back.\"\"\"\n    x = jnp.stack(jnp.meshgrid(*[jnp.linspace(-4, 4, 11)] * 2), axis=-1)\n    x_recon = coord.inv_contract(coord.contract(x))\n    np.testing.assert_allclose(x, x_recon, atol=1E-5, rtol=1E-5)\n\n  @parameterized.named_parameters(\n      ('05_1e-5', 5, 1e-5),\n      ('10_1e-4', 10, 1e-4),\n      ('15_0.005', 15, 0.005),\n      ('20_0.2', 20, 0.2),  # At high degrees, our implementation is unstable.\n      ('25_2', 25, 2),  # 2 is the maximum possible error.\n      ('30_2', 30, 2),\n  )\n  def test_pos_enc(self, n, tol):\n    \"\"\"test pos_enc against a stable recursive implementation.\"\"\"\n    x = np.linspace(-np.pi, np.pi, 10001)\n    z = coord.pos_enc(x[:, None], 0, n, append_identity=False)\n    z_stable = stable_pos_enc(x, n)\n    max_err = np.max(np.abs(z - z_stable))\n    print(f'PE of degree {n} has a maximum error of {max_err}')\n    self.assertLess(max_err, tol)\n\n  def test_pos_enc_matches_integrated(self):\n    \"\"\"Integrated positional encoding with a variance of zero must be pos_enc.\"\"\"\n    min_deg = 0\n    max_deg = 10\n    np.linspace(-jnp.pi, jnp.pi, 10)\n    x = jnp.stack(\n        jnp.meshgrid(*[np.linspace(-jnp.pi, jnp.pi, 10)] * 2), axis=-1)\n    x = np.linspace(-jnp.pi, jnp.pi, 10000)\n    z_ipe = coord.integrated_pos_enc(x, jnp.zeros_like(x), min_deg, max_deg)\n    z_pe = coord.pos_enc(x, min_deg, max_deg, append_identity=False)\n    # We're using a pretty wide tolerance because IPE uses safe_sin().\n    np.testing.assert_allclose(z_pe, z_ipe, atol=1e-4)\n\n  def test_track_linearize(self):\n    rng = random.PRNGKey(0)\n    batch_size = 20\n    for _ in range(30):\n      # Construct some random Gaussians with dimensionalities in [1, 10].\n      key, rng = random.split(rng)\n      in_dims = random.randint(key, (), 1, 10)\n      key, rng = random.split(rng)\n    
  mean = jax.random.normal(key, [batch_size, in_dims])\n      key, rng = random.split(rng)\n      cov = sample_covariance(key, batch_size, in_dims)\n      key, rng = random.split(rng)\n      out_dims = random.randint(key, (), 1, 10)\n\n      # Construct a random affine transformation.\n      key, rng = random.split(rng)\n      a_mat = jax.random.normal(key, [out_dims, in_dims])\n      key, rng = random.split(rng)\n      b = jax.random.normal(key, [out_dims])\n\n      def fn(x):\n        x_vec = x.reshape([-1, x.shape[-1]])\n        y_vec = jax.vmap(lambda z: math.matmul(a_mat, z))(x_vec) + b  # pylint:disable=cell-var-from-loop\n        y = y_vec.reshape(list(x.shape[:-1]) + [y_vec.shape[-1]])\n        return y\n\n      # Apply the affine function to the Gaussians.\n      fn_mean_true = fn(mean)\n      fn_cov_true = math.matmul(math.matmul(a_mat, cov), a_mat.T)\n\n      # Tracking the Gaussians through a linearized function of a linear\n      # operator should be the same.\n      fn_mean, fn_cov = coord.track_linearize(fn, mean, cov)\n      np.testing.assert_allclose(fn_mean, fn_mean_true, atol=1E-5, rtol=1E-5)\n      np.testing.assert_allclose(fn_cov, fn_cov_true, atol=1e-5, rtol=1e-5)\n\n  @parameterized.named_parameters(('reciprocal', jnp.reciprocal),\n                                  ('log', jnp.log), ('sqrt', jnp.sqrt))\n  def test_construct_ray_warps_extents(self, fn):\n    n = 100\n    rng = random.PRNGKey(0)\n    key, rng = random.split(rng)\n    t_near = jnp.exp(jax.random.normal(key, [n]))\n    key, rng = random.split(rng)\n    t_far = t_near + jnp.exp(jax.random.normal(key, [n]))\n\n    t_to_s, s_to_t = coord.construct_ray_warps(fn, t_near, t_far)\n\n    np.testing.assert_allclose(\n        t_to_s(t_near), jnp.zeros_like(t_near), atol=1E-5, rtol=1E-5)\n    np.testing.assert_allclose(\n        t_to_s(t_far), jnp.ones_like(t_far), atol=1E-5, rtol=1E-5)\n    np.testing.assert_allclose(\n        s_to_t(jnp.zeros_like(t_near)), t_near, atol=1E-5, 
rtol=1E-5)\n    np.testing.assert_allclose(\n        s_to_t(jnp.ones_like(t_near)), t_far, atol=1E-5, rtol=1E-5)\n\n  def test_construct_ray_warps_special_reciprocal(self):\n    \"\"\"Test fn=1/x against its closed form.\"\"\"\n    n = 100\n    rng = random.PRNGKey(0)\n    key, rng = random.split(rng)\n    t_near = jnp.exp(jax.random.normal(key, [n]))\n    key, rng = random.split(rng)\n    t_far = t_near + jnp.exp(jax.random.normal(key, [n]))\n\n    key, rng = random.split(rng)\n    u = jax.random.uniform(key, [n])\n    t = t_near * (1 - u) + t_far * u\n    key, rng = random.split(rng)\n    s = jax.random.uniform(key, [n])\n\n    t_to_s, s_to_t = coord.construct_ray_warps(jnp.reciprocal, t_near, t_far)\n\n    # Special cases for fn=reciprocal.\n    s_to_t_ref = lambda s: 1 / (s / t_far + (1 - s) / t_near)\n    t_to_s_ref = lambda t: (t_far * (t - t_near)) / (t * (t_far - t_near))\n\n    np.testing.assert_allclose(t_to_s(t), t_to_s_ref(t), atol=1E-5, rtol=1E-5)\n    np.testing.assert_allclose(s_to_t(s), s_to_t_ref(s), atol=1E-5, rtol=1E-5)\n\n  def test_expected_sin(self):\n    normal_samples = random.normal(random.PRNGKey(0), (10000,))\n    for mu, var in [(0, 1), (1, 3), (-2, .2), (10, 10)]:\n      sin_mu = coord.expected_sin(mu, var)\n      x = jnp.sin(jnp.sqrt(var) * normal_samples + mu)\n      np.testing.assert_allclose(sin_mu, jnp.mean(x), atol=1e-2)\n\n  def test_integrated_pos_enc(self):\n    num_dims = 2  # The number of input dimensions.\n    min_deg = 0  # Must be 0 for this test to work.\n    max_deg = 4\n    num_samples = 100000\n    rng = random.PRNGKey(0)\n    for _ in range(5):\n      # Generate a coordinate's mean and covariance matrix.\n      key, rng = random.split(rng)\n      mean = random.normal(key, (2,))\n      key, rng = random.split(rng)\n      half_cov = jax.random.normal(key, [num_dims] * 2)\n      cov = half_cov @ half_cov.T\n      var = jnp.diag(cov)\n      # Generate an IPE.\n      enc = coord.integrated_pos_enc(\n          mean,\n      
    var,\n          min_deg,\n          max_deg,\n      )\n\n      # Draw samples, encode them, and take their mean.\n      key, rng = random.split(rng)\n      samples = random.multivariate_normal(key, mean, cov, [num_samples])\n      assert min_deg == 0\n      enc_samples = np.concatenate(\n          [stable_pos_enc(x, max_deg) for x in tuple(samples.T)], axis=-1)\n      # Correct for a different dimension ordering in stable_pos_enc.\n      enc_gt = jnp.mean(enc_samples, 0)\n      enc_gt = enc_gt.reshape([num_dims, max_deg * 2]).T.reshape([-1])\n      np.testing.assert_allclose(enc, enc_gt, rtol=1e-2, atol=1e-2)\n\n\nif __name__ == '__main__':\n  absltest.main()\n"
  },
  {
    "path": "mip360/tests/datasets_test.py",
    "content": "# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for datasets.\"\"\"\n\nfrom absl.testing import absltest\nfrom internal import camera_utils\nfrom internal import configs\nfrom internal import datasets\nfrom jax import random\nimport numpy as np\n\n\nclass DummyDataset(datasets.Dataset):\n\n  def _load_renderings(self, config):\n    \"\"\"Generates dummy image and pose data.\"\"\"\n    self._n_examples = 2\n    self.height = 3\n    self.width = 4\n    self._resolution = self.height * self.width\n    self.focal = 5.\n    self.pixtocams = np.linalg.inv(\n        camera_utils.intrinsic_matrix(self.focal, self.focal, self.width * 0.5,\n                                      self.height * 0.5))\n\n    rng = random.PRNGKey(0)\n\n    key, rng = random.split(rng)\n    images_shape = (self._n_examples, self.height, self.width, 3)\n    self.images = random.uniform(key, images_shape)\n\n    key, rng = random.split(rng)\n    self.camtoworlds = np.stack([\n        camera_utils.viewmatrix(*random.normal(k, (3, 3)))\n        for k in random.split(key, self._n_examples)\n    ],\n                                axis=0)\n\n\nclass DatasetsTest(absltest.TestCase):\n\n  def test_dataset_batch_creation(self):\n    np.random.seed(0)\n    config = configs.Config(batch_size=8)\n\n    # Check shapes are consistent across all ray attributes.\n    for split in ['train', 'test']:\n      dummy_dataset = DummyDataset(split, '', 
config)\n      rays = dummy_dataset.peek().rays\n      sh_gt = rays.origins.shape[:-1]\n      for z in rays.__dict__.values():\n        if z is not None:\n          self.assertEqual(z.shape[:-1], sh_gt)\n\n    # Check test batch generation matches golden data.\n    dummy_dataset = DummyDataset('test', '', config)\n    batch = dummy_dataset.peek()\n\n    rgb = batch.rgb.ravel()\n    rgb_gt = np.array([\n        0.5289556, 0.28869557, 0.24527192, 0.12083626, 0.8904066, 0.6259936,\n        0.57573485, 0.09355974, 0.8017353, 0.538651, 0.4998169, 0.42061496,\n        0.5591258, 0.00577283, 0.6804651, 0.9139203, 0.00444758, 0.96962905,\n        0.52956843, 0.38282406, 0.28777933, 0.6640035, 0.39736128, 0.99495006,\n        0.13100398, 0.7597165, 0.8532667, 0.67468107, 0.6804743, 0.26873016,\n        0.60699487, 0.5722265, 0.44482303, 0.6511061, 0.54807067, 0.09894073\n    ])\n    np.testing.assert_allclose(rgb, rgb_gt, atol=1e-4, rtol=1e-4)\n\n    ray_origins = batch.rays.origins.ravel()\n    ray_origins_gt = np.array([\n        -0.20050469, -0.6451472, -0.8818224, -0.20050469, -0.6451472,\n        -0.8818224, -0.20050469, -0.6451472, -0.8818224, -0.20050469,\n        -0.6451472, -0.8818224, -0.20050469, -0.6451472, -0.8818224,\n        -0.20050469, -0.6451472, -0.8818224, -0.20050469, -0.6451472,\n        -0.8818224, -0.20050469, -0.6451472, -0.8818224, -0.20050469,\n        -0.6451472, -0.8818224, -0.20050469, -0.6451472, -0.8818224,\n        -0.20050469, -0.6451472, -0.8818224, -0.20050469, -0.6451472, -0.8818224\n    ])\n    np.testing.assert_allclose(\n        ray_origins, ray_origins_gt, atol=1e-4, rtol=1e-4)\n\n    ray_dirs = batch.rays.directions.ravel()\n    ray_dirs_gt = np.array([\n        0.24370372, 0.89296186, -0.5227117, 0.05601424, 0.8468699, -0.57417226,\n        -0.13167524, 0.8007779, -0.62563276, -0.31936473, 0.75468594,\n        -0.67709327, 0.17780769, 0.96766925, -0.34928587, -0.0098818, 0.9215773,\n        -0.4007464, -0.19757128, 0.87548524, 
-0.4522069, -0.38526076,\n        0.82939327, -0.5036674, 0.11191163, 1.0423766, -0.17586003, -0.07577785,\n        0.9962846, -0.22732055, -0.26346734, 0.95019263, -0.2787811,\n        -0.45115682, 0.90410066, -0.3302416\n    ])\n    np.testing.assert_allclose(ray_dirs, ray_dirs_gt, atol=1e-4, rtol=1e-4)\n\n\nif __name__ == '__main__':\n  absltest.main()\n"
  },
  {
    "path": "mip360/tests/geopoly_test.py",
    "content": "# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Unit tests for geopoly.\"\"\"\nimport itertools\n\nfrom absl.testing import absltest\nfrom internal import geopoly\nimport jax\nfrom jax import random\nimport numpy as np\n\n\ndef is_same_basis(x, y, tol=1e-10):\n  \"\"\"Check if `x` and `y` describe the same linear basis.\"\"\"\n  match = np.minimum(\n      geopoly.compute_sq_dist(x, y), geopoly.compute_sq_dist(x, -y)) <= tol\n  return (np.all(np.array(x.shape) == np.array(y.shape)) and\n          np.all(np.sum(match, axis=0) == 1) and\n          np.all(np.sum(match, axis=1) == 1))\n\n\nclass GeopolyTest(absltest.TestCase):\n\n  def test_compute_sq_dist_reference(self):\n    \"\"\"Test against a simple reimplementation of compute_sq_dist.\"\"\"\n    num_points = 100\n    num_dims = 10\n    rng = random.PRNGKey(0)\n    key, rng = random.split(rng)\n    mat0 = jax.random.normal(key, [num_dims, num_points])\n    key, rng = random.split(rng)\n    mat1 = jax.random.normal(key, [num_dims, num_points])\n\n    sq_dist = geopoly.compute_sq_dist(mat0, mat1)\n\n    sq_dist_ref = np.zeros([num_points, num_points])\n    for i in range(num_points):\n      for j in range(num_points):\n        sq_dist_ref[i, j] = np.sum((mat0[:, i] - mat1[:, j])**2)\n\n    np.testing.assert_allclose(sq_dist, sq_dist_ref, atol=1e-4, rtol=1e-4)\n\n  def test_compute_sq_dist_single_input(self):\n    \"\"\"Test that compute_sq_dist with a single 
input works correctly.\"\"\"\n    rng = random.PRNGKey(0)\n    num_points = 100\n    num_dims = 10\n    key, rng = random.split(rng)\n    mat0 = jax.random.normal(key, [num_dims, num_points])\n\n    sq_dist = geopoly.compute_sq_dist(mat0)\n    sq_dist_ref = geopoly.compute_sq_dist(mat0, mat0)\n    np.testing.assert_allclose(sq_dist, sq_dist_ref)\n\n  def test_compute_tesselation_weights_reference(self):\n    \"\"\"A reference implementation for triangle tesselation.\"\"\"\n    for v in range(1, 10):\n      w = geopoly.compute_tesselation_weights(v)\n      perm = np.array(list(itertools.product(range(v + 1), repeat=3)))\n      w_ref = perm[np.sum(perm, axis=-1) == v, :] / v\n      # Check that every row of w matches exactly one row of w_ref (up to sign).\n      self.assertTrue(is_same_basis(w.T, w_ref.T))\n\n  def test_generate_basis_golden(self):\n    \"\"\"A mediocre golden test against two arbitrary basis choices.\"\"\"\n    basis = geopoly.generate_basis('icosahedron', 2)\n    basis_golden = np.array([[0.85065081, 0.00000000, 0.52573111],\n                             [0.80901699, 0.50000000, 0.30901699],\n                             [0.52573111, 0.85065081, 0.00000000],\n                             [1.00000000, 0.00000000, 0.00000000],\n                             [0.80901699, 0.50000000, -0.30901699],\n                             [0.85065081, 0.00000000, -0.52573111],\n                             [0.30901699, 0.80901699, -0.50000000],\n                             [0.00000000, 0.52573111, -0.85065081],\n                             [0.50000000, 0.30901699, -0.80901699],\n                             [0.00000000, 1.00000000, 0.00000000],\n                             [-0.52573111, 0.85065081, 0.00000000],\n                             [-0.30901699, 0.80901699, -0.50000000],\n                             [0.00000000, 0.52573111, 0.85065081],\n                             [-0.30901699, 0.80901699, 0.50000000],\n                             [0.30901699, 0.80901699, 
0.50000000],\n                             [0.50000000, 0.30901699, 0.80901699],\n                             [0.50000000, -0.30901699, 0.80901699],\n                             [0.00000000, 0.00000000, 1.00000000],\n                             [-0.50000000, 0.30901699, 0.80901699],\n                             [-0.80901699, 0.50000000, 0.30901699],\n                             [-0.80901699, 0.50000000, -0.30901699]])\n    self.assertTrue(is_same_basis(basis.T, basis_golden.T))\n\n    basis = geopoly.generate_basis('octahedron', 4)\n    basis_golden = np.array([[0.00000000, 0.00000000, -1.00000000],\n                             [0.00000000, -0.31622777, -0.94868330],\n                             [0.00000000, -0.70710678, -0.70710678],\n                             [0.00000000, -0.94868330, -0.31622777],\n                             [0.00000000, -1.00000000, 0.00000000],\n                             [-0.31622777, 0.00000000, -0.94868330],\n                             [-0.40824829, -0.40824829, -0.81649658],\n                             [-0.40824829, -0.81649658, -0.40824829],\n                             [-0.31622777, -0.94868330, 0.00000000],\n                             [-0.70710678, 0.00000000, -0.70710678],\n                             [-0.81649658, -0.40824829, -0.40824829],\n                             [-0.70710678, -0.70710678, 0.00000000],\n                             [-0.94868330, 0.00000000, -0.31622777],\n                             [-0.94868330, -0.31622777, 0.00000000],\n                             [-1.00000000, 0.00000000, 0.00000000],\n                             [0.00000000, -0.31622777, 0.94868330],\n                             [0.00000000, -0.70710678, 0.70710678],\n                             [0.00000000, -0.94868330, 0.31622777],\n                             [0.40824829, -0.40824829, 0.81649658],\n                             [0.40824829, -0.81649658, 0.40824829],\n                             [0.31622777, -0.94868330, 
0.00000000],\n                             [0.81649658, -0.40824829, 0.40824829],\n                             [0.70710678, -0.70710678, 0.00000000],\n                             [0.94868330, -0.31622777, 0.00000000],\n                             [0.31622777, 0.00000000, -0.94868330],\n                             [0.40824829, 0.40824829, -0.81649658],\n                             [0.40824829, 0.81649658, -0.40824829],\n                             [0.70710678, 0.00000000, -0.70710678],\n                             [0.81649658, 0.40824829, -0.40824829],\n                             [0.94868330, 0.00000000, -0.31622777],\n                             [0.40824829, -0.40824829, -0.81649658],\n                             [0.40824829, -0.81649658, -0.40824829],\n                             [0.81649658, -0.40824829, -0.40824829]])\n    self.assertTrue(is_same_basis(basis.T, basis_golden.T))\n\n\nif __name__ == '__main__':\n  absltest.main()\n"
  },
  {
    "path": "mip360/tests/image_test.py",
    "content": "# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Unit tests for image.\"\"\"\n\nfrom absl.testing import absltest\nfrom internal import image\nimport jax\nfrom jax import random\nimport jax.numpy as jnp\nimport numpy as np\n\n\ndef matmul(a, b):\n  \"\"\"jnp.matmul defaults to bfloat16, but this helper function doesn't.\"\"\"\n  return jnp.matmul(a, b, precision=jax.lax.Precision.HIGHEST)\n\n\nclass ImageTest(absltest.TestCase):\n\n  def test_color_correction(self):\n    \"\"\"Test that color correction can undo a CCM + quadratic warp + shift.\"\"\"\n    im_shape = (128, 128, 3)\n    rng = random.PRNGKey(0)\n    for _ in range(10):\n      # Construct a random image.\n      key, rng = random.split(rng)\n      im0 = random.uniform(key, shape=im_shape, minval=0.1, maxval=0.9)\n\n      # Construct a random linear + quadratic color transformation.\n      key, rng = random.split(rng)\n      ccm_scale = random.normal(key) / 10\n      key, rng = random.split(rng)\n      shift = random.normal(key) / 10\n      key, rng = random.split(rng)\n      sq_mult = random.normal(key) / 10\n      key, rng = random.split(rng)\n      ccm = jnp.eye(3) + random.normal(key, shape=(3, 3)) * ccm_scale\n\n      # Apply that random transformation to the image.\n      im1 = jnp.clip(\n          (matmul(jnp.reshape(im0, [-1, 3]), ccm)).reshape(im0.shape) +\n          sq_mult * im0**2 + shift, 0, 1)\n\n      # Check that color correction 
recovers the randomly transformed image.\n      im0_cc = image.color_correct(im0, im1)\n      np.testing.assert_allclose(im0_cc, im1, atol=1E-5, rtol=1E-5)\n\n  def test_psnr_mse_round_trip(self):\n    \"\"\"PSNR -> MSE -> PSNR is a no-op.\"\"\"\n    for psnr in [10., 20., 30.]:\n      np.testing.assert_allclose(\n          image.mse_to_psnr(image.psnr_to_mse(psnr)),\n          psnr,\n          atol=1E-5,\n          rtol=1E-5)\n\n  def test_ssim_dssim_round_trip(self):\n    \"\"\"SSIM -> DSSIM -> SSIM is a no-op.\"\"\"\n    for ssim in [-0.9, 0, 0.9]:\n      np.testing.assert_allclose(\n          image.dssim_to_ssim(image.ssim_to_dssim(ssim)),\n          ssim,\n          atol=1E-5,\n          rtol=1E-5)\n\n  def test_srgb_linearize(self):\n    x = jnp.linspace(-1, 3, 10000)  # Nobody should call this <0 but it works.\n    # Check that the round-trip transformation is a no-op.\n    np.testing.assert_allclose(\n        image.linear_to_srgb(image.srgb_to_linear(x)), x, atol=1E-5, rtol=1E-5)\n    np.testing.assert_allclose(\n        image.srgb_to_linear(image.linear_to_srgb(x)), x, atol=1E-5, rtol=1E-5)\n    # Check that gradients are finite.\n    self.assertTrue(\n        jnp.all(jnp.isfinite(jax.vmap(jax.grad(image.linear_to_srgb))(x))))\n    self.assertTrue(\n        jnp.all(jnp.isfinite(jax.vmap(jax.grad(image.srgb_to_linear))(x))))\n\n  def test_srgb_to_linear_golden(self):\n    \"\"\"A lazy golden test for srgb_to_linear.\"\"\"\n    srgb = jnp.linspace(0, 1, 64)\n    linear = image.srgb_to_linear(srgb)\n    linear_gt = jnp.array([\n        0.00000000, 0.00122856, 0.00245712, 0.00372513, 0.00526076, 0.00711347,\n        0.00929964, 0.01183453, 0.01473243, 0.01800687, 0.02167065, 0.02573599,\n        0.03021459, 0.03511761, 0.04045585, 0.04623971, 0.05247922, 0.05918410,\n        0.06636375, 0.07402734, 0.08218378, 0.09084171, 0.10000957, 0.10969563,\n        0.11990791, 0.13065430, 0.14194246, 0.15377994, 0.16617411, 0.17913227,\n        0.19266140, 0.20676863, 
0.22146071, 0.23674440, 0.25262633, 0.26911288,\n        0.28621066, 0.30392596, 0.32226467, 0.34123330, 0.36083785, 0.38108405,\n        0.40197787, 0.42352500, 0.44573134, 0.46860245, 0.49214387, 0.51636110,\n        0.54125960, 0.56684470, 0.59312177, 0.62009590, 0.64777250, 0.67615650,\n        0.70525320, 0.73506740, 0.76560410, 0.79686830, 0.82886493, 0.86159873,\n        0.89507430, 0.92929670, 0.96427040, 1.00000000\n    ])\n    np.testing.assert_allclose(linear, linear_gt, atol=1E-5, rtol=1E-5)\n\n  def test_mse_to_psnr_golden(self):\n    \"\"\"A lazy golden test for mse_to_psnr.\"\"\"\n    mse = jnp.exp(jnp.linspace(-10, 0, 64))\n    psnr = image.mse_to_psnr(mse)\n    psnr_gt = jnp.array([\n        43.429447, 42.740090, 42.050735, 41.361378, 40.6720240, 39.982666,\n        39.293310, 38.603954, 37.914597, 37.225240, 36.5358850, 35.846527,\n        35.157170, 34.467810, 33.778458, 33.089100, 32.3997460, 31.710388,\n        31.021034, 30.331675, 29.642320, 28.952961, 28.2636070, 27.574250,\n        26.884893, 26.195538, 25.506180, 24.816826, 24.1274700, 23.438112,\n        22.748756, 22.059400, 21.370045, 20.680689, 19.9913310, 19.301975,\n        18.612620, 17.923262, 17.233906, 16.544550, 15.8551940, 15.165837,\n        14.4764805, 13.787125, 13.097769, 12.408413, 11.719056, 11.029700,\n        10.3403420, 9.6509850, 8.9616290, 8.2722720, 7.5829163, 6.8935600,\n        6.2042036, 5.5148473, 4.825491, 4.136135, 3.4467785, 2.7574227,\n        2.0680661, 1.37871, 0.68935364, 0.\n    ])\n    np.testing.assert_allclose(psnr, psnr_gt, atol=1E-5, rtol=1E-5)\n\n\nif __name__ == '__main__':\n  absltest.main()\n"
  },
  {
    "path": "mip360/tests/math_test.py",
    "content": "# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Unit tests for math.\"\"\"\n\nimport functools\n\nfrom absl.testing import absltest\nfrom absl.testing import parameterized\nfrom internal import math\nimport jax\nfrom jax import random\nimport jax.numpy as jnp\nimport numpy as np\n\n\ndef safe_trig_harness(fn, max_exp):\n  x = 10**np.linspace(-30, max_exp, 10000)\n  x = np.concatenate([-x[::-1], np.array([0]), x])\n  y_true = getattr(np, fn)(x)\n  y = getattr(math, 'safe_' + fn)(x)\n  return y_true, y\n\n\nclass MathTest(parameterized.TestCase):\n\n  def test_sin(self):\n    \"\"\"In [-1e10, 1e10] safe_sin and safe_cos are accurate.\"\"\"\n    for fn in ['sin', 'cos']:\n      y_true, y = safe_trig_harness(fn, 10)\n      self.assertLess(jnp.max(jnp.abs(y - y_true)), 1e-4)\n      self.assertFalse(jnp.any(jnp.isnan(y)))\n    # Beyond that range it's less accurate but we just don't want it to be NaN.\n    for fn in ['sin', 'cos']:\n      y_true, y = safe_trig_harness(fn, 60)\n      self.assertFalse(jnp.any(jnp.isnan(y)))\n\n  def test_safe_exp_correct(self):\n    \"\"\"math.safe_exp() should match np.exp() for not-huge values.\"\"\"\n    x = jnp.linspace(-80, 80, 10001)\n    y = math.safe_exp(x)\n    g = jax.vmap(jax.grad(math.safe_exp))(x)\n    yg_true = jnp.exp(x)\n    np.testing.assert_allclose(y, yg_true)\n    np.testing.assert_allclose(g, yg_true)\n\n  def test_safe_exp_finite(self):\n    
\"\"\"math.safe_exp() behaves reasonably for huge values.\"\"\"\n    x = jnp.linspace(-100000, 100000, 10001)\n    y = math.safe_exp(x)\n    g = jax.vmap(jax.grad(math.safe_exp))(x)\n    # `y` and `g` should both always be finite.\n    self.assertTrue(jnp.all(jnp.isfinite(y)))\n    self.assertTrue(jnp.all(jnp.isfinite(g)))\n    # The derivative of exp() should be exp().\n    np.testing.assert_allclose(y, g)\n    # safe_exp()'s output and gradient should be monotonic.\n    self.assertTrue(jnp.all(y[1:] >= y[:-1]))\n    self.assertTrue(jnp.all(g[1:] >= g[:-1]))\n\n  def test_learning_rate_decay(self):\n    rng = random.PRNGKey(0)\n    for _ in range(10):\n      key, rng = random.split(rng)\n      lr_init = jnp.exp(random.normal(key) - 3)\n      key, rng = random.split(rng)\n      lr_final = lr_init * jnp.exp(random.normal(key) - 5)\n      key, rng = random.split(rng)\n      max_steps = int(jnp.ceil(100 + 100 * jnp.exp(random.normal(key))))\n\n      lr_fn = functools.partial(\n          math.learning_rate_decay,\n          lr_init=lr_init,\n          lr_final=lr_final,\n          max_steps=max_steps)\n\n      # Test that the rate at the beginning is the initial rate.\n      np.testing.assert_allclose(lr_fn(0), lr_init, atol=1E-5, rtol=1E-5)\n\n      # Test that the rate at the end is the final rate.\n      np.testing.assert_allclose(\n          lr_fn(max_steps), lr_final, atol=1E-5, rtol=1E-5)\n\n      # Test that the rate at the middle is the geometric mean of the two rates.\n      np.testing.assert_allclose(\n          lr_fn(max_steps / 2),\n          jnp.sqrt(lr_init * lr_final),\n          atol=1E-5,\n          rtol=1E-5)\n\n      # Test that the rate past the end is the final rate\n      np.testing.assert_allclose(\n          lr_fn(max_steps + 100), lr_final, atol=1E-5, rtol=1E-5)\n\n  def test_delayed_learning_rate_decay(self):\n    rng = random.PRNGKey(0)\n    for _ in range(10):\n      key, rng = random.split(rng)\n      lr_init = jnp.exp(random.normal(key) - 
3)\n      key, rng = random.split(rng)\n      lr_final = lr_init * jnp.exp(random.normal(key) - 5)\n      key, rng = random.split(rng)\n      max_steps = int(jnp.ceil(100 + 100 * jnp.exp(random.normal(key))))\n      key, rng = random.split(rng)\n      lr_delay_steps = int(\n          random.uniform(key, minval=0.1, maxval=0.4) * max_steps)\n      key, rng = random.split(rng)\n      lr_delay_mult = jnp.exp(random.normal(key) - 3)\n\n      lr_fn = functools.partial(\n          math.learning_rate_decay,\n          lr_init=lr_init,\n          lr_final=lr_final,\n          max_steps=max_steps,\n          lr_delay_steps=lr_delay_steps,\n          lr_delay_mult=lr_delay_mult)\n\n      # Test that the rate at the beginning is the delayed initial rate.\n      np.testing.assert_allclose(\n          lr_fn(0), lr_delay_mult * lr_init, atol=1E-5, rtol=1E-5)\n\n      # Test that the rate at the end is the final rate.\n      np.testing.assert_allclose(\n          lr_fn(max_steps), lr_final, atol=1E-5, rtol=1E-5)\n\n      # Test that the rate after the delay is over is the usual rate.\n      np.testing.assert_allclose(\n          lr_fn(lr_delay_steps),\n          math.learning_rate_decay(lr_delay_steps, lr_init, lr_final,\n                                   max_steps),\n          atol=1E-5,\n          rtol=1E-5)\n\n      # Test that the rate at the middle is the geometric mean of the two rates.\n      np.testing.assert_allclose(\n          lr_fn(max_steps / 2),\n          jnp.sqrt(lr_init * lr_final),\n          atol=1E-5,\n          rtol=1E-5)\n\n      # Test that the rate past the end is the final rate.\n      np.testing.assert_allclose(\n          lr_fn(max_steps + 100), lr_final, atol=1E-5, rtol=1E-5)\n\n  @parameterized.named_parameters(('', False), ('sort', True))\n  def test_interp(self, sort):\n    n, d0, d1 = 100, 10, 20\n    rng = random.PRNGKey(0)\n\n    key, rng = random.split(rng)\n    x = random.normal(key, [n, d0])\n\n    key, rng = random.split(rng)\n    xp = 
random.normal(key, [n, d1])\n\n    key, rng = random.split(rng)\n    fp = random.normal(key, [n, d1])\n\n    if sort:\n      xp = jnp.sort(xp, axis=-1)\n      fp = jnp.sort(fp, axis=-1)\n      z = math.sorted_interp(x, xp, fp)\n    else:\n      z = math.interp(x, xp, fp)\n\n    z_true = jnp.stack([jnp.interp(x[i], xp[i], fp[i]) for i in range(n)])\n    np.testing.assert_allclose(z, z_true, atol=1e-5, rtol=1e-5)\n\n\nif __name__ == '__main__':\n  absltest.main()\n"
  },
  {
    "path": "mip360/tests/ref_utils_test.py",
    "content": "# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for ref_utils.\"\"\"\n\nfrom absl.testing import absltest\nfrom internal import ref_utils\nfrom jax import random\nimport jax.numpy as jnp\nimport numpy as np\nimport scipy\n\n\ndef generate_dir_enc_fn_scipy(deg_view):\n  \"\"\"Return spherical harmonics using scipy.special.sph_harm.\"\"\"\n  ml_array = ref_utils.get_ml_array(deg_view)\n\n  def dir_enc_fn(theta, phi):\n    de = [scipy.special.sph_harm(m, l, phi, theta) for m, l in ml_array.T]\n    de = np.stack(de, axis=-1)\n    # Split into real and imaginary parts.\n    return np.concatenate([np.real(de), np.imag(de)], axis=-1)\n\n  return dir_enc_fn\n\n\nclass RefUtilsTest(absltest.TestCase):\n\n  def test_reflection(self):\n    \"\"\"Make sure reflected vectors have the same angle from normals as input.\"\"\"\n    rng = random.PRNGKey(0)\n    for shape in [(45, 3), (4, 7, 3)]:\n      key, rng = random.split(rng)\n      normals = random.normal(key, shape)\n      key, rng = random.split(rng)\n      directions = random.normal(key, shape)\n\n      # Normalize normal vectors.\n      normals = normals / (\n          jnp.linalg.norm(normals, axis=-1, keepdims=True) + 1e-10)\n\n      reflected_directions = ref_utils.reflect(directions, normals)\n\n      cos_angle_original = jnp.sum(directions * normals, axis=-1)\n      cos_angle_reflected = jnp.sum(reflected_directions * normals, axis=-1)\n\n      
np.testing.assert_allclose(\n          cos_angle_original, cos_angle_reflected, atol=1E-5, rtol=1E-5)\n\n  def test_spherical_harmonics(self):\n    \"\"\"Make sure the fast spherical harmonics are accurate.\"\"\"\n    shape = (12, 11, 13)\n\n    # Generate random points on sphere.\n    rng = random.PRNGKey(0)\n    key1, key2 = random.split(rng)\n    theta = random.uniform(key1, shape, minval=0.0, maxval=jnp.pi)\n    phi = random.uniform(key2, shape, minval=0.0, maxval=2.0*jnp.pi)\n\n    # Convert to Cartesian coordinates.\n    x = jnp.sin(theta) * jnp.cos(phi)\n    y = jnp.sin(theta) * jnp.sin(phi)\n    z = jnp.cos(theta)\n    xyz = jnp.stack([x, y, z], axis=-1)\n\n    deg_view = 5\n    de = ref_utils.generate_dir_enc_fn(deg_view)(xyz)\n    de_scipy = generate_dir_enc_fn_scipy(deg_view)(theta, phi)\n\n    np.testing.assert_allclose(\n        de, de_scipy, atol=0.02, rtol=1e6)  # Only use atol.\n    self.assertFalse(jnp.any(jnp.isnan(de)))\n\n\nif __name__ == '__main__':\n  absltest.main()\n"
  },
  {
    "path": "mip360/tests/render_test.py",
    "content": "# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Unit tests for render.\"\"\"\n\nimport functools\n\nfrom absl.testing import absltest\nfrom absl.testing import parameterized\nfrom internal import math\nfrom internal import render\nimport jax\nfrom jax import random\nimport jax.numpy as jnp\nimport numpy as np\n\n\ndef surface_stats(points):\n  \"\"\"Get the sample mean and covariance matrix of a set of matrices [..., d].\"\"\"\n  means = jnp.mean(points, -1)\n  centered = points - means[..., None]\n  covs = jnp.mean(centered[..., None, :, :] * centered[..., :, None, :], -1)\n  return means, covs\n\n\ndef sqrtm(mat):\n  \"\"\"Take the matrix square root of a PSD matrix [..., d, d].\"\"\"\n  eigval, eigvec = jax.scipy.linalg.eigh(mat)\n  scaling = jnp.sqrt(jnp.maximum(0., eigval))[..., None, :]\n  return math.matmul(eigvec * scaling, jnp.moveaxis(eigvec, -2, -1))\n\n\ndef control_points(mean, cov):\n  \"\"\"Construct \"sigma points\" using a matrix sqrt (Cholesky or SVD are fine).\"\"\"\n  sqrtm_cov = sqrtm(cov)  # or could be jax.scipy.linalg.cholesky(cov)\n  offsets = jnp.sqrt(mean.shape[-1] + 0.5) * jnp.concatenate(\n      [jnp.zeros_like(mean[..., None]), sqrtm_cov, -sqrtm_cov], -1)\n  return mean[..., None] + offsets\n\n\ndef inside_conical_frustum(x, d, t0, t1, r, ttol=1e-6, rtol=1e-6):\n  \"\"\"Test if `x` is inside the conical frustum specified by the other inputs.\"\"\"\n  d_normsq = jnp.sum(d**2)\n  
d_norm = jnp.sqrt(d_normsq)\n  x_normsq = jnp.sum(x**2, -1)\n  x_norm = jnp.sqrt(x_normsq)\n  xd = math.matmul(x, d)\n  is_inside = (\n      (t0 - ttol) <= xd / d_normsq) & (xd / d_normsq <= (t1 + ttol)) & (\n          (xd / (d_norm * x_norm)) >=\n          (1 / jnp.sqrt(1 + r**2 / d_normsq) - rtol))\n  return is_inside\n\n\ndef sample_conical_frustum(rng, num_samples, d, t0, t1, base_radius):\n  \"\"\"Draw random samples from a conical frustum.\n\n  Args:\n    rng: The RNG seed.\n    num_samples: int, the number of samples to draw.\n    d: jnp.float32 3-vector, the axis of the cone.\n    t0: float, the starting distance of the frustum.\n    t1: float, the ending distance of the frustum.\n    base_radius: float, the scale of the radius as a function of distance.\n\n  Returns:\n    A matrix of samples.\n  \"\"\"\n  key, rng = random.split(rng)\n  u = random.uniform(key, shape=[num_samples])\n  t = (t0**3 * (1 - u) + t1**3 * u)**(1 / 3)\n  key, rng = random.split(rng)\n  theta = random.uniform(key, shape=[num_samples], minval=0, maxval=jnp.pi * 2)\n  key, rng = random.split(rng)\n  r = base_radius * t * jnp.sqrt(random.uniform(key, shape=[num_samples]))\n\n  d_norm = d / jnp.linalg.norm(d)\n  null = jnp.eye(3) - d_norm[:, None] * d_norm[None, :]\n  basis = jnp.linalg.svd(null)[0][:, :2]\n  rot_samples = ((basis[:, 0:1] * r * jnp.cos(theta)) +\n                 (basis[:, 1:2] * r * jnp.sin(theta)) + d[:, None] * t).T\n  return rot_samples\n\n\ndef generate_random_cylinder(rng, num_zs=4):\n  t0, t1 = [], []\n  for _ in range(num_zs):\n    rng, key = random.split(rng)\n    z_mean = random.uniform(key, minval=1.5, maxval=3)\n    rng, key = random.split(rng)\n    z_delta = random.uniform(key, minval=0.1, maxval=.3)\n    t0.append(z_mean - z_delta)\n    t1.append(z_mean + z_delta)\n  t0 = jnp.array(t0)\n  t1 = jnp.array(t1)\n\n  rng, key = random.split(rng)\n  radius = random.uniform(key, minval=0.1, maxval=.2)\n\n  rng, key = random.split(rng)\n  raydir = 
random.normal(key, [3])\n  raydir = raydir / jnp.sqrt(jnp.sum(raydir**2, -1))\n\n  rng, key = random.split(rng)\n  scale = random.uniform(key, minval=0.4, maxval=1.2)\n  raydir = scale * raydir\n\n  return raydir, t0, t1, radius\n\n\ndef generate_random_conical_frustum(rng, num_zs=4):\n  t0, t1 = [], []\n  for _ in range(num_zs):\n    rng, key = random.split(rng)\n    z_mean = random.uniform(key, minval=1.5, maxval=3)\n    rng, key = random.split(rng)\n    z_delta = random.uniform(key, minval=0.1, maxval=.3)\n    t0.append(z_mean - z_delta)\n    t1.append(z_mean + z_delta)\n  t0 = jnp.array(t0)\n  t1 = jnp.array(t1)\n\n  rng, key = random.split(rng)\n  r = random.uniform(key, minval=0.01, maxval=.05)\n\n  rng, key = random.split(rng)\n  raydir = random.normal(key, [3])\n  raydir = raydir / jnp.sqrt(jnp.sum(raydir**2, -1))\n\n  rng, key = random.split(rng)\n  scale = random.uniform(key, minval=0.8, maxval=1.2)\n  raydir = scale * raydir\n\n  return raydir, t0, t1, r\n\n\ndef cylinder_to_gaussian_sample(key,\n                                raydir,\n                                t0,\n                                t1,\n                                radius,\n                                padding=1,\n                                num_samples=1000000):\n  # Sample uniformly from a cube that surrounds the entire cylinder.\n  z_max = max(t0, t1)\n  samples = random.uniform(\n      key, [num_samples, 3],\n      minval=jnp.min(raydir) * z_max - padding,\n      maxval=jnp.max(raydir) * z_max + padding)\n\n  # Grab only the points within the cylinder.\n  raydir_magsq = jnp.sum(raydir**2, -1, keepdims=True)\n  proj = (raydir * (samples @ raydir)[:, None]) / raydir_magsq\n  dist = samples @ raydir\n  mask = (dist >= raydir_magsq * t0) & (dist <= raydir_magsq * t1) & (\n      jnp.sum((proj - samples)**2, -1) < radius**2)\n  samples = samples[mask, :]\n\n  # Compute their mean and covariance.\n  mean = jnp.mean(samples, 0)\n  cov = jnp.cov(samples.T, bias=False)\n 
 return mean, cov\n\n\ndef conical_frustum_to_gaussian_sample(key, raydir, t0, t1, r):\n  \"\"\"A brute-force numerical approximation to conical_frustum_to_gaussian().\"\"\"\n  # Sample uniformly from a cube that surrounds the entire conical frustum.\n  samples = sample_conical_frustum(key, 100000, raydir, t0, t1, r)\n  # Compute their mean and covariance.\n  return surface_stats(samples.T)\n\n\ndef finite_output_and_gradient(fn, args):\n  \"\"\"True if fn(*args) and all of its gradients are finite.\"\"\"\n  vals = fn(*args)\n  is_finite = True\n  for do, v in enumerate(vals):\n    is_finite &= jnp.all(jnp.isfinite(v))\n    # pylint: disable=cell-var-from-loop\n    grads = jax.grad(\n        lambda *x: jnp.sum(fn(*x)[do]), argnums=range(len(args)))(*args)\n    for g in grads:\n      is_finite &= jnp.all(jnp.isfinite(g))\n  return is_finite\n\n\nclass RenderTest(parameterized.TestCase):\n\n  def test_cylinder_scaling(self):\n    d = jnp.array([0., 0., 1.])\n    t0 = jnp.array([0.3])\n    t1 = jnp.array([0.7])\n    radius = jnp.array([0.4])\n    mean, cov = render.cylinder_to_gaussian(\n        d,\n        t0,\n        t1,\n        radius,\n        False,\n    )\n    scale = 2.7\n    scaled_mean, scaled_cov = render.cylinder_to_gaussian(\n        scale * d,\n        t0,\n        t1,\n        radius,\n        False,\n    )\n    np.testing.assert_allclose(scale * mean, scaled_mean, atol=1E-5, rtol=1E-5)\n    np.testing.assert_allclose(\n        scale**2 * cov[2, 2], scaled_cov[2, 2], atol=1E-5, rtol=1E-5)\n    control = control_points(mean, cov)[0]\n    control_scaled = control_points(scaled_mean, scaled_cov)[0]\n    np.testing.assert_allclose(\n        control[:2, :], control_scaled[:2, :], atol=1E-5, rtol=1E-5)\n    np.testing.assert_allclose(\n        control[2, :] * scale, control_scaled[2, :], atol=1E-5, rtol=1E-5)\n\n  def test_conical_frustum_scaling(self):\n    d = jnp.array([0., 0., 1.])\n    t0 = jnp.array([0.3])\n    t1 = jnp.array([0.7])\n    radius = 
jnp.array([0.4])\n    mean, cov = render.conical_frustum_to_gaussian(\n        d,\n        t0,\n        t1,\n        radius,\n        False,\n    )\n    scale = 2.7\n    scaled_mean, scaled_cov = render.conical_frustum_to_gaussian(\n        scale * d,\n        t0,\n        t1,\n        radius,\n        False,\n    )\n    np.testing.assert_allclose(scale * mean, scaled_mean, atol=1E-5, rtol=1E-5)\n    np.testing.assert_allclose(\n        scale**2 * cov[2, 2], scaled_cov[2, 2], atol=1E-5, rtol=1E-5)\n    control = control_points(mean, cov)[0]\n    control_scaled = control_points(scaled_mean, scaled_cov)[0]\n    np.testing.assert_allclose(\n        control[:2, :], control_scaled[:2, :], atol=1E-5, rtol=1E-5)\n    np.testing.assert_allclose(\n        control[2, :] * scale, control_scaled[2, :], atol=1E-5, rtol=1E-5)\n\n  def test_control_points(self):\n    rng = random.PRNGKey(0)\n    batch_size = 10\n    for num_dims in [1, 2, 3]:\n      key, rng = random.split(rng)\n      mean = jax.random.normal(key, [batch_size, num_dims])\n      key, rng = random.split(rng)\n      half_cov = jax.random.normal(key, [batch_size] + [num_dims] * 2)\n      cov = half_cov @ jnp.moveaxis(half_cov, -1, -2)\n\n      sqrtm_cov = sqrtm(cov)\n      np.testing.assert_allclose(\n          sqrtm_cov @ sqrtm_cov, cov, atol=1e-5, rtol=1E-5)\n\n      points = control_points(mean, cov)\n      mean_recon, cov_recon = surface_stats(points)\n      np.testing.assert_allclose(mean, mean_recon, atol=1E-5, rtol=1E-5)\n      np.testing.assert_allclose(cov, cov_recon, atol=1e-5, rtol=1E-5)\n\n  def test_conical_frustum(self):\n    rng = random.PRNGKey(0)\n    data = []\n    for _ in range(10):\n      key, rng = random.split(rng)\n      raydir, t0, t1, r = generate_random_conical_frustum(key)\n      i_results = []\n      for i_t0, i_t1 in zip(t0, t1):\n        key, rng = random.split(rng)\n        i_results.append(\n            conical_frustum_to_gaussian_sample(key, raydir, i_t0, i_t1, r))\n      mean_gt, 
cov_gt = [jnp.stack(x, 0) for x in zip(*i_results)]\n      data.append((raydir, t0, t1, r, mean_gt, cov_gt))\n    raydir, t0, t1, r, mean_gt, cov_gt = [jnp.stack(x, 0) for x in zip(*data)]\n    diag_cov_gt = jax.vmap(jax.vmap(jnp.diag))(cov_gt)\n    for diag in [False, True]:\n      for stable in [False, True]:\n        mean, cov = render.conical_frustum_to_gaussian(\n            raydir, t0, t1, r[..., None], diag, stable=stable)\n        np.testing.assert_allclose(mean, mean_gt, atol=0.001)\n        if diag:\n          np.testing.assert_allclose(cov, diag_cov_gt, atol=0.0002)\n        else:\n          np.testing.assert_allclose(cov, cov_gt, atol=0.0002)\n\n  def test_inside_conical_frustum(self):\n    \"\"\"This test only tests helper functions used by other tests.\"\"\"\n    rng = random.PRNGKey(0)\n    for _ in range(20):\n      key, rng = random.split(rng)\n      d, t0, t1, r = generate_random_conical_frustum(key, num_zs=1)\n      key, rng = random.split(rng)\n      # Sample some points.\n      samples = sample_conical_frustum(key, 1000000, d, t0, t1, r)\n      # Check that they're all inside.\n      check = lambda x: inside_conical_frustum(x, d, t0, t1, r)  # pylint: disable=cell-var-from-loop\n      self.assertTrue(jnp.all(check(samples)))\n      # Check that wiggling them a little puts some outside (potentially flaky).\n      self.assertFalse(jnp.all(check(samples + 1e-3)))\n      self.assertFalse(jnp.all(check(samples - 1e-3)))\n\n  def test_conical_frustum_stable(self):\n    rng = random.PRNGKey(0)\n    for _ in range(10):\n      key, rng = random.split(rng)\n      d, t0, t1, r = generate_random_conical_frustum(key)\n      for diag in [False, True]:\n        mean, cov = render.conical_frustum_to_gaussian(\n            d, t0, t1, r, diag, stable=False)\n        mean_stable, cov_stable = render.conical_frustum_to_gaussian(\n            d, t0, t1, r, diag, stable=True)\n        np.testing.assert_allclose(mean, mean_stable, atol=1e-7, rtol=1E-5)\n        
np.testing.assert_allclose(cov, cov_stable, atol=1e-5, rtol=1E-5)\n\n  def test_cylinder(self):\n    rng = random.PRNGKey(0)\n    data = []\n    for _ in range(10):\n      key, rng = random.split(rng)\n      raydir, t0, t1, radius = generate_random_cylinder(rng)\n      key, rng = random.split(rng)\n      i_results = []\n      for i_t0, i_t1 in zip(t0, t1):\n        i_results.append(\n            cylinder_to_gaussian_sample(key, raydir, i_t0, i_t1, radius))\n      mean_gt, cov_gt = [jnp.stack(x, 0) for x in zip(*i_results)]\n      data.append((raydir, t0, t1, radius, mean_gt, cov_gt))\n    raydir, t0, t1, radius, mean_gt, cov_gt = [\n        jnp.stack(x, 0) for x in zip(*data)\n    ]\n    mean, cov = (\n        render.cylinder_to_gaussian(raydir, t0, t1, radius[..., None], False))\n    np.testing.assert_allclose(mean, mean_gt, atol=0.1)\n    np.testing.assert_allclose(cov, cov_gt, atol=0.01)\n\n  def test_lift_gaussian_diag(self):\n    dims, n, m = 3, 10, 4\n    rng = random.PRNGKey(0)\n    key, rng = random.split(rng)\n    d = random.normal(key, [n, dims])\n    key, rng = random.split(rng)\n    z_mean = random.normal(key, [n, m])\n    key, rng = random.split(rng)\n    z_var = jnp.exp(random.normal(key, [n, m]))\n    key, rng = random.split(rng)\n    xy_var = jnp.exp(random.normal(key, [n, m]))\n    mean, cov = render.lift_gaussian(d, z_mean, z_var, xy_var, diag=False)\n    mean_diag, cov_diag = render.lift_gaussian(\n        d, z_mean, z_var, xy_var, diag=True)\n    np.testing.assert_allclose(mean, mean_diag, atol=1E-5, rtol=1E-5)\n    np.testing.assert_allclose(\n        jax.vmap(jax.vmap(jnp.diag))(cov), cov_diag, atol=1E-5, rtol=1E-5)\n\n  def test_rotated_conic_frustums(self):\n    # Test that conic frustum Gaussians are closed under rotation.\n    diag = False\n    rng = random.PRNGKey(0)\n    for _ in range(10):\n      rng, key = random.split(rng)\n      z_mean = random.uniform(key, minval=1.5, maxval=3)\n      rng, key = random.split(rng)\n      z_delta = 
random.uniform(key, minval=0.1, maxval=.3)\n      t0 = jnp.array(z_mean - z_delta)\n      t1 = jnp.array(z_mean + z_delta)\n\n      rng, key = random.split(rng)\n      r = random.uniform(key, minval=0.1, maxval=.2)\n\n      rng, key = random.split(rng)\n      d = random.normal(key, [3])\n\n      mean, cov = render.conical_frustum_to_gaussian(d, t0, t1, r, diag)\n\n      # Make a random rotation matrix.\n      rng, key = random.split(rng)\n      x = random.normal(key, [10, 3])\n      rot_mat = x.T @ x\n      u, _, v = jnp.linalg.svd(rot_mat)\n      rot_mat = u @ v.T\n\n      mean, cov = render.conical_frustum_to_gaussian(d, t0, t1, r, diag)\n      rot_mean, rot_cov = render.conical_frustum_to_gaussian(\n          rot_mat @ d, t0, t1, r, diag)\n      gt_rot_mean, gt_rot_cov = surface_stats(\n          rot_mat @ control_points(mean, cov))\n\n      np.testing.assert_allclose(rot_mean, gt_rot_mean, atol=1E-4, rtol=1E-4)\n      np.testing.assert_allclose(rot_cov, gt_rot_cov, atol=1E-4, rtol=1E-4)\n\n  @parameterized.named_parameters(\n      ('-100 -100', -100, -100),\n      ('-100 -10', -100, -10),\n      ('-100  0', -100, 0),\n      ('-100  10', -100, 10),\n      ('-10  -100', -10, -100),\n      ('-10  -10', -10, -10),\n      ('-10   0', -10, 0),\n      ('-10   10', -10, 10),\n      (' 0   -100', 0, -100),\n      (' 0   -10', 0, -10),\n      (' 0    0', 0, 0),\n      (' 0    10', 0, 10),\n      (' 10  -10', 10, -10),\n      (' 10   0', 10, 0),\n      (' 10   10', 10, 10),\n      (' 10  -100', 10, -100),\n  )\n  def test_alpha_weights_finite(self, log_density_log_mult, tvals_log_mult):\n    rng = random.PRNGKey(0)\n    n, d = 100, 128\n\n    key, rng = random.split(rng)\n    density = jnp.exp(log_density_log_mult + random.normal(key, [n, d]))\n    key, rng = random.split(rng)\n    tvals_unsorted = 2 * random.uniform(key, [n, d + 1]) - 1\n    tvals = jnp.exp(tvals_log_mult) * jnp.sort(tvals_unsorted, axis=-1)\n    key, rng = random.split(rng)\n    dirs = 
random.normal(key, [n, 3])\n\n    fn = jax.jit(render.compute_alpha_weights)\n    args = density, tvals, dirs\n\n    self.assertTrue(finite_output_and_gradient(fn, args))\n\n  def test_alpha_weights_delta_correct(self):\n    \"\"\"A single interval with a huge density should produce a spikey weight.\"\"\"\n    max_density = 1e10\n    rng = random.PRNGKey(0)\n    n, d = 100, 128\n\n    key, rng = random.split(rng)\n    r = random.normal(key, [n, d])\n    mask = (r == jnp.max(r, axis=-1, keepdims=True))\n    density = max_density * mask\n\n    key, rng = random.split(rng)\n    tvals_unsorted = 2 * random.uniform(key, [n, d + 1]) - 1\n    tvals = jnp.sort(tvals_unsorted, axis=-1)\n\n    key, rng = random.split(rng)\n    dirs = random.normal(key, [n, 3])\n\n    weights, alpha, _ = render.compute_alpha_weights(density, tvals, dirs)\n    np.testing.assert_allclose(jnp.float32(mask), weights, atol=1E-5, rtol=1E-5)\n    np.testing.assert_allclose(jnp.float32(mask), alpha, atol=1E-5, rtol=1E-5)\n\n  @parameterized.named_parameters(\n      ('-100_-100', -100, -100),\n      ('-100_-10', -100, -10),\n      ('-100__0', -100, 0),\n      ('-100__10', -100, 10),\n      ('-10__-100', -10, -100),\n      ('-10__-10', -10, -10),\n      ('-10___0', -10, 0),\n      ('-10___10', -10, 10),\n      ('_0___-100', 0, -100),\n      ('_0___-10', 0, -10),\n      ('_0____0', 0, 0),\n      ('_0____10', 0, 10),\n      ('_10__-10', 10, -10),\n      ('_10___0', 10, 0),\n      ('_10___10', 10, 10),\n      ('_10__-100', 10, -100),\n  )\n  def test_conical_frustum_to_gaussian_finite(\n      self,\n      tvals_log_mult,\n      radius_log_mult,\n  ):\n    n, d = 10, 128\n    rng = random.PRNGKey(0)\n\n    key, rng = random.split(rng)\n    rad = jnp.exp(radius_log_mult) * jnp.exp(random.normal(key, [n, d]))\n\n    key, rng = random.split(rng)\n    tvals_unsorted = random.uniform(key, [n, d + 1], minval=-1, maxval=1)\n    tvals = jnp.exp(tvals_log_mult) * jnp.sort(tvals_unsorted, axis=-1)\n\n    key, rng = 
random.split(rng)\n    dirs = random.normal(key, [n, 3])\n\n    t0, t1 = tvals[..., :-1], tvals[..., 1:]\n\n    fn = jax.jit(\n        functools.partial(render.conical_frustum_to_gaussian, diag=True))\n    args = dirs, t0, t1, rad\n\n    self.assertTrue(finite_output_and_gradient(fn, args))\n\n\nif __name__ == '__main__':\n  absltest.main()\n"
  },
  {
    "path": "mip360/tests/stepfun_test.py",
    "content": "# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Unit tests for stepfun.\"\"\"\n\nfrom absl.testing import absltest\nfrom absl.testing import parameterized\nfrom internal import stepfun\nimport jax\nfrom jax import random\nimport jax.numpy as jnp\nimport numpy as np\nimport scipy as sp\n\n\ndef inner(t0, t1, w1):\n  \"\"\"A reference implementation for computing the inner measure of (t1, w1).\"\"\"\n  w0_inner = []\n  for i in range(len(t0) - 1):\n    w_sum = 0\n    for j in range(len(t1) - 1):\n      if (t1[j] >= t0[i]) and (t1[j + 1] < t0[i + 1]):\n        w_sum += w1[j]\n    w0_inner.append(w_sum)\n  w0_inner = jnp.array(w0_inner)\n  return w0_inner\n\n\ndef outer(t0, t1, w1):\n  \"\"\"A reference implementation for computing the outer measure of (t1, w1).\"\"\"\n  w0_outer = []\n  for i in range(len(t0) - 1):\n    w_sum = 0\n    for j in range(len(t1) - 1):\n      if (t1[j + 1] >= t0[i]) and (t1[j] <= t0[i + 1]):\n        w_sum += w1[j]\n    w0_outer.append(w_sum)\n  w0_outer = jnp.array(w0_outer)\n  return w0_outer\n\n\nclass StepFunTest(parameterized.TestCase):\n\n  def test_searchsorted_in_bounds(self):\n    \"\"\"Test that a[i] <= v < a[j], with (i, j) = searchsorted(a, v).\"\"\"\n    rng = random.PRNGKey(0)\n    eps = 1e-7\n    for _ in range(10):\n      # Sample vector lengths.\n      key, rng = random.split(rng)\n      n = random.randint(key, (), 10, 100)\n      key, rng = random.split(rng)\n      
m = random.randint(key, (), 10, 100)\n\n      # Generate query points in [eps, 1-eps].\n      key, rng = random.split(rng)\n      v = random.uniform(key, [n], minval=eps, maxval=1 - eps)\n\n      # Generate sorted reference points that span all of [0, 1].\n      key, rng = random.split(rng)\n      a = jnp.sort(random.uniform(key, [m]))\n      a = jnp.concatenate([jnp.array([0.]), a, jnp.array([1.])])\n      idx_lo, idx_hi = stepfun.searchsorted(a, v)\n\n      self.assertTrue(jnp.all(a[idx_lo] <= v))\n      self.assertTrue(jnp.all(v < a[idx_hi]))\n\n  def test_searchsorted_out_of_bounds(self):\n    \"\"\"searchsorted should produce the first/last indices when out of bounds.\"\"\"\n    rng = random.PRNGKey(0)\n    for _ in range(10):\n      # Sample vector lengths.\n      key, rng = random.split(rng)\n      n = random.randint(key, (), 10, 100)\n      key, rng = random.split(rng)\n      m = random.randint(key, (), 10, 100)\n\n      # Generate sorted reference points that span [1, 2].\n      key, rng = random.split(rng)\n      a = jnp.sort(random.uniform(key, [m], minval=1, maxval=2))\n\n      # Generated queries below and above the reference points.\n      key, rng = random.split(rng)\n      v_lo = random.uniform(key, [n], minval=0., maxval=0.9)\n\n      key, rng = random.split(rng)\n      v_hi = random.uniform(key, [n], minval=2.1, maxval=3)\n\n      idx_lo, idx_hi = stepfun.searchsorted(a, v_lo)\n      np.testing.assert_array_equal(idx_lo, jnp.zeros_like(idx_lo))\n      np.testing.assert_array_equal(idx_hi, jnp.zeros_like(idx_hi))\n\n      idx_lo, idx_hi = stepfun.searchsorted(a, v_hi)\n      np.testing.assert_array_equal(idx_lo, jnp.full_like(idx_lo, m - 1))\n      np.testing.assert_array_equal(idx_hi, jnp.full_like(idx_hi, m - 1))\n\n  def test_searchsorted_reference(self):\n    \"\"\"Test against jnp.searchsorted, which behaves similarly to ours.\"\"\"\n    rng = random.PRNGKey(0)\n    eps = 1e-7\n    n = 30\n    m = 40\n\n    # Generate query points in [eps, 
1-eps].\n    key, rng = random.split(rng)\n    v = random.uniform(key, [n], minval=eps, maxval=1 - eps)\n\n    # Generate sorted reference points that span all of [0, 1].\n    key, rng = random.split(rng)\n    a = jnp.sort(random.uniform(key, [m]))\n    a = jnp.concatenate([jnp.array([0.]), a, jnp.array([1.])])\n    _, idx_hi = stepfun.searchsorted(a, v)\n    np.testing.assert_array_equal(jnp.searchsorted(a, v), idx_hi)\n\n  def test_searchsorted(self):\n    \"\"\"An alternative correctness test for in-range queries to searchsorted.\"\"\"\n    rng = random.PRNGKey(0)\n    key, rng = random.split(rng)\n    a = jnp.sort(random.uniform(key, [10], minval=-4, maxval=4))\n\n    key, rng = random.split(rng)\n    v = random.uniform(key, [100], minval=-6, maxval=6)\n\n    idx_lo, idx_hi = stepfun.searchsorted(a, v)\n\n    for x, i0, i1 in zip(v, idx_lo, idx_hi):\n      if x < jnp.min(a):\n        i0_true, i1_true = [0] * 2\n      elif x > jnp.max(a):\n        i0_true, i1_true = [len(a) - 1] * 2\n      else:\n        i0_true = jnp.argmax(jnp.where(x >= a, a, -jnp.inf))\n        i1_true = jnp.argmin(jnp.where(x < a, a, jnp.inf))\n      np.testing.assert_array_equal(i0_true, i0)\n      np.testing.assert_array_equal(i1_true, i1)\n\n  @parameterized.named_parameters(\n      ('front_delta_0', 'front', 0.),  # Include the front of each span.\n      ('front_delta_0.05', 'front', 0.05),\n      ('front_delta_0.099', 'front', 0.099),\n      ('back_delta_1e-6', 'back', 1e-6),  # Exclude the back of each span.\n      ('back_delta_0.05', 'back', 0.05),\n      ('back_delta_0.099', 'back', 0.099),\n      ('before', 'before', 1e-6),\n      ('after', 'after', 0.),\n  )\n  def test_query(self, mode, delta):\n    \"\"\"Test that query() behaves sensibly in easy cases.\"\"\"\n    n, d = 10, 8\n    outside_value = -10.\n    max_delta = 0.1\n\n    key0, key1 = random.split(random.PRNGKey(0))\n    # Each t value is at least max_delta more than the one before.\n    t = -d / 2 + jnp.cumsum(\n        
random.uniform(key0, minval=max_delta, shape=(n, d + 1)), axis=-1)\n    y = random.normal(key1, shape=(n, d))\n\n    query = lambda tq: stepfun.query(tq, t, y, outside_value=outside_value)\n\n    if mode == 'front':\n      # Query a point relative to the front of each span, shifted by delta\n      # (if delta < max_delta this will not take you out of the current span).\n      assert delta >= 0\n      assert delta < max_delta\n      yq = query(t[..., :-1] + delta)\n      np.testing.assert_array_equal(yq, y)\n    elif mode == 'back':\n      # Query a point relative to the back of each span, shifted by delta\n      # (if delta < max_delta this will not take you out of the current span).\n      assert delta >= 0\n      assert delta < max_delta\n      yq = query(t[..., 1:] - delta)\n      np.testing.assert_array_equal(yq, y)\n    elif mode == 'before':\n      # Query values before the domain of the step function (exclusive).\n      min_val = jnp.min(t, axis=-1)\n      assert delta >= 0\n      tq = min_val[:, None] + jnp.linspace(-10, -delta, 100)[None, :]\n      yq = query(tq)\n      np.testing.assert_array_equal(yq, outside_value)\n    elif mode == 'after':\n      # Query values after the domain of the step function (inclusive).\n      max_val = jnp.max(t, axis=-1)\n      assert delta >= 0\n      tq = max_val[:, None] + jnp.linspace(delta, 10, 100)[None, :]\n      yq = query(tq)\n      np.testing.assert_array_equal(yq, outside_value)\n\n  def test_distortion_loss_against_sampling(self):\n    \"\"\"Test that the distortion loss matches a stochastic approximation.\"\"\"\n    # Construct a random step function that defines a weight distribution.\n    n, d = 10, 8\n    rng = random.PRNGKey(0)\n    key, rng = random.split(rng)\n    t = random.uniform(key, minval=-3, maxval=3, shape=(n, d + 1))\n    t = jnp.sort(t, axis=-1)\n    key, rng = random.split(rng)\n    logits = 2 * random.normal(key, shape=(n, d))\n\n    # Compute the distortion loss.\n    w = 
jax.nn.softmax(logits, axis=-1)\n    losses = stepfun.lossfun_distortion(t, w)\n\n    # Approximate the distortion loss using samples from the step function.\n    key, rng = random.split(rng)\n    samples = stepfun.sample(key, t, logits, 10000, single_jitter=False)\n    losses_stoch = []\n    for i in range(n):\n      losses_stoch.append(\n          jnp.mean(jnp.abs(samples[i][:, None] - samples[i][None, :])))\n    losses_stoch = jnp.array(losses_stoch)\n\n    np.testing.assert_allclose(losses, losses_stoch, atol=1e-4, rtol=1e-4)\n\n  def test_interval_distortion_against_brute_force(self):\n    n, d = 3, 7\n    rng = random.PRNGKey(0)\n\n    key, rng = random.split(rng)\n    t0 = random.uniform(key, minval=-3, maxval=3, shape=(n, d + 1))\n    t0 = jnp.sort(t0, axis=-1)\n\n    key, rng = random.split(rng)\n    t1 = random.uniform(key, minval=-3, maxval=3, shape=(n, d + 1))\n    t1 = jnp.sort(t1, axis=-1)\n\n    distortions = stepfun.interval_distortion(t0[..., :-1], t0[..., 1:],\n                                              t1[..., :-1], t1[..., 1:])\n\n    distortions_brute = np.array(jnp.zeros_like(distortions))\n    for i in range(n):\n      for j in range(d):\n        distortions_brute[i, j] = jnp.mean(\n            jnp.abs(\n                jnp.linspace(t0[i, j], t0[i, j + 1], 5001)[:, None] -\n                jnp.linspace(t1[i, j], t1[i, j + 1], 5001)[None, :]))\n    np.testing.assert_allclose(\n        distortions, distortions_brute, atol=1e-6, rtol=1e-3)\n\n  def test_distortion_loss_against_interval_distortion(self):\n    \"\"\"Test that the distortion loss matches a brute-force alternative.\"\"\"\n    # Construct a random step function that defines a weight distribution.\n    n, d = 3, 8\n    rng = random.PRNGKey(0)\n    key, rng = random.split(rng)\n    t = random.uniform(key, minval=-3, maxval=3, shape=(n, d + 1))\n    t = jnp.sort(t, axis=-1)\n    key, rng = random.split(rng)\n    logits = 2 * random.normal(key, shape=(n, d))\n\n    # Compute the 
distortion loss.\n    w = jax.nn.softmax(logits, axis=-1)\n    losses = stepfun.lossfun_distortion(t, w)\n\n    # Compute it again in a more brute-force way, by computing the weighted\n    # distortion of all pairs of intervals.\n    d = stepfun.interval_distortion(t[..., :-1, None], t[..., 1:, None],\n                                    t[..., None, :-1], t[..., None, 1:])\n    losses_alt = jnp.sum(w[:, None, :] * w[:, :, None] * d, axis=[-1, -2])\n\n    np.testing.assert_allclose(losses, losses_alt, atol=1e-6, rtol=1e-4)\n\n  def test_max_dilate(self):\n    \"\"\"Compare max_dilate to a brute force test on queries of step functions.\"\"\"\n    n, d, dilation = 20, 8, 0.53\n\n    # Construct a non-negative step function.\n    key0, key1 = random.split(random.PRNGKey(0))\n    t = jnp.cumsum(\n        random.randint(key0, minval=1, maxval=10, shape=(n, d + 1)),\n        axis=-1) / 10\n    w = jax.nn.softmax(random.normal(key1, shape=(n, d)), axis=-1)\n\n    # Dilate it.\n    td, wd = stepfun.max_dilate(t, w, dilation)\n\n    # Construct queries at the midpoint of each interval.\n    tq = (jnp.arange((d + 4) * 10) - 2.5) / 10\n\n    # Query the step function and its dilation.\n    wq = stepfun.query(tq[None], t, w)\n    wdq = stepfun.query(tq[None], td, wd)\n\n    # The queries of the dilation must be the max of the non-dilated queries.\n    mask = jnp.abs(tq[None, :] - tq[:, None]) <= dilation\n    for i in range(n):\n      wdq_i = jnp.max(mask * wq[i], axis=-1)\n      np.testing.assert_array_equal(wdq[i], wdq_i)\n\n  @parameterized.named_parameters(('deterministic', False, None),\n                                  ('random_multiple_jitters', True, False),\n                                  ('random_single_jitter', True, True))\n  def test_sample_train_mode(self, randomized, single_jitter):\n    \"\"\"Test that piecewise-constant sampling reproduces its distribution.\"\"\"\n    rng = random.PRNGKey(0)\n    batch_size = 4\n    num_bins = 16\n    num_samples = 
1000000\n    precision = 1e5\n\n    # Generate a series of random PDFs to sample from.\n    data = []\n    for _ in range(batch_size):\n      rng, key = random.split(rng)\n      # Randomly initialize the distances between bins.\n      # We're rolling our own fixed precision here to make cumsum exact.\n      bins_delta = jnp.round(precision * jnp.exp(\n          random.uniform(key, shape=(num_bins + 1,), minval=-3, maxval=3)))\n\n      # Set some of the bin distances to 0.\n      rng, key = random.split(rng)\n      bins_delta *= random.uniform(key, shape=bins_delta.shape) < 0.9\n\n      # Integrate the bins.\n      bins = jnp.cumsum(bins_delta) / precision\n      rng, key = random.split(rng)\n      bins += random.normal(key) * num_bins / 2\n      rng, key = random.split(rng)\n\n      # Randomly generate weights, allowing some to be zero.\n      weights = jnp.maximum(\n          0, random.uniform(key, shape=(num_bins,), minval=-0.5, maxval=1.))\n      gt_hist = weights / weights.sum()\n      data.append((bins, weights, gt_hist))\n\n    bins, weights, gt_hist = [jnp.stack(x) for x in zip(*data)]\n\n    rng = random.PRNGKey(0) if randomized else None\n    # Draw samples from the batch of PDFs.\n    samples = stepfun.sample(\n        key,\n        bins,\n        jnp.log(weights) + 0.7,\n        num_samples,\n        single_jitter=single_jitter,\n    )\n    self.assertEqual(samples.shape[-1], num_samples)\n\n    # Check that samples are sorted.\n    self.assertTrue(jnp.all(samples[..., 1:] >= samples[..., :-1]))\n\n    # Verify that each set of samples resembles the target distribution.\n    for i_samples, i_bins, i_gt_hist in zip(samples, bins, gt_hist):\n      i_hist = jnp.float32(jnp.histogram(i_samples, i_bins)[0]) / num_samples\n      i_gt_hist = jnp.array(i_gt_hist)\n\n      # Merge any of the zero-span bins until there aren't any left.\n      while jnp.any(i_bins[:-1] == i_bins[1:]):\n        j = int(jnp.where(i_bins[:-1] == i_bins[1:])[0][0])\n        i_hist = 
jnp.concatenate([\n            i_hist[:j],\n            jnp.array([i_hist[j] + i_hist[j + 1]]), i_hist[j + 2:]\n        ])\n        i_gt_hist = jnp.concatenate([\n            i_gt_hist[:j],\n            jnp.array([i_gt_hist[j] + i_gt_hist[j + 1]]), i_gt_hist[j + 2:]\n        ])\n        i_bins = jnp.concatenate([i_bins[:j], i_bins[j + 1:]])\n\n      # Angle between the two histograms in degrees.\n      angle = 180 / jnp.pi * jnp.arccos(\n          jnp.minimum(\n              1.,\n              jnp.mean((i_hist * i_gt_hist) /\n                       jnp.sqrt(jnp.mean(i_hist**2) * jnp.mean(i_gt_hist**2)))))\n      # Jensen-Shannon divergence.\n      m = (i_hist + i_gt_hist) / 2\n      js_div = jnp.sum(\n          sp.special.kl_div(i_hist, m) + sp.special.kl_div(i_gt_hist, m)) / 2\n      self.assertLessEqual(angle, 0.5)\n      self.assertLessEqual(js_div, 1e-5)\n\n  @parameterized.named_parameters(('deterministic', False, None),\n                                  ('random_multiple_jitters', True, False),\n                                  ('random_single_jitter', True, True))\n  def test_sample_large_flat(self, randomized, single_jitter):\n    \"\"\"Test sampling when given a large flat distribution.\"\"\"\n    key = random.PRNGKey(0) if randomized else None\n    num_samples = 100\n    num_bins = 100000\n    bins = jnp.arange(num_bins)\n    weights = np.ones(len(bins) - 1)\n    samples = stepfun.sample(\n        key,\n        bins[None],\n        jnp.log(jnp.maximum(1e-15, weights[None])),\n        num_samples,\n        single_jitter=single_jitter,\n    )[0]\n    # All samples should be within the range of the bins.\n    self.assertTrue(jnp.all(samples >= bins[0]))\n    self.assertTrue(jnp.all(samples <= bins[-1]))\n\n    # Samples modded by their bin index should resemble a uniform distribution.\n    samples_mod = jnp.mod(samples, 1)\n    self.assertLessEqual(\n        sp.stats.kstest(samples_mod, 'uniform', (0, 1)).statistic, 0.2)\n\n    # All samples should 
collectively resemble a uniform distribution.\n    self.assertLessEqual(\n        sp.stats.kstest(samples, 'uniform', (bins[0], bins[-1])).statistic, 0.2)\n\n  def test_gpu_vs_tpu_resampling(self):\n    \"\"\"Test that gather-based resampling matches the search-based resampler.\"\"\"\n    key = random.PRNGKey(0)\n    num_samples = 100\n    num_bins = 100000\n    bins = jnp.arange(num_bins)\n    weights = np.ones(len(bins) - 1)\n    samples_search_tpu = stepfun.sample(\n        key,\n        bins[None],\n        jnp.log(jnp.maximum(1e-15, weights[None])),\n        num_samples,\n        single_jitter=False,\n        use_gpu_resampling=False,\n    )[0]\n    samples_search_gpu = stepfun.sample(\n        key,\n        bins[None],\n        jnp.log(jnp.maximum(1e-15, weights[None])),\n        num_samples,\n        single_jitter=False,\n        use_gpu_resampling=True,\n    )[0]\n    np.testing.assert_allclose(\n        samples_search_tpu, samples_search_gpu, atol=1E-5, rtol=1E-5)\n\n  @parameterized.named_parameters(('deterministic', False, None),\n                                  ('random_multiple_jitters', True, False),\n                                  ('random_single_jitter', True, True))\n  def test_sample_sparse_delta(self, randomized, single_jitter):\n    \"\"\"Test sampling when given a large distribution with a big delta in it.\"\"\"\n    key = random.PRNGKey(0) if randomized else None\n    num_samples = 100\n    num_bins = 100000\n    bins = jnp.arange(num_bins)\n    weights = np.ones(len(bins) - 1)\n    delta_idx = len(weights) // 2\n    weights[delta_idx] = len(weights) - 1\n    samples = stepfun.sample(\n        key,\n        bins[None],\n        jnp.log(jnp.maximum(1e-15, weights[None])),\n        num_samples,\n        single_jitter=single_jitter,\n    )[0]\n\n    # All samples should be within the range of the bins.\n    self.assertTrue(jnp.all(samples >= bins[0]))\n    self.assertTrue(jnp.all(samples <= bins[-1]))\n\n    # Samples modded by their bin 
index should resemble a uniform distribution.\n    samples_mod = jnp.mod(samples, 1)\n    self.assertLessEqual(\n        sp.stats.kstest(samples_mod, 'uniform', (0, 1)).statistic, 0.2)\n\n    # The delta function bin should contain ~half of the samples.\n    in_delta = (samples >= bins[delta_idx]) & (samples <= bins[delta_idx + 1])\n    np.testing.assert_allclose(jnp.mean(in_delta), 0.5, atol=0.05)\n\n  @parameterized.named_parameters(('deterministic', False, None),\n                                  ('random_multiple_jitters', True, False),\n                                  ('random_single_jitter', True, True))\n  def test_sample_single_bin(self, randomized, single_jitter):\n    \"\"\"Test sampling when given a small `one hot' distribution.\"\"\"\n    key = random.PRNGKey(0) if randomized else None\n    num_samples = 625\n    bins = jnp.array([0, 1, 3, 6, 10], jnp.float32)\n    for i in range(len(bins) - 1):\n      weights = np.zeros(len(bins) - 1, jnp.float32)\n      weights[i] = 1.\n      samples = stepfun.sample(\n          key,\n          bins[None],\n          jnp.log(weights[None]),\n          num_samples,\n          single_jitter=single_jitter,\n      )[0]\n\n      # All samples should be within [bins[i], bins[i+1]].\n      self.assertTrue(jnp.all(samples >= bins[i]))\n      self.assertTrue(jnp.all(samples <= bins[i + 1]))\n\n  @parameterized.named_parameters(('deterministic', False, 0.1),\n                                  ('random', True, 0.1))\n  def test_sample_intervals_accuracy(self, randomized, tolerance):\n    \"\"\"Test that resampled intervals resemble their original distribution.\"\"\"\n    n, d = 50, 32\n    d_resample = 2 * d\n    domain = -3, 3\n\n    # Generate some step functions.\n    rng = random.PRNGKey(0)\n    key, rng = random.split(rng)\n    t = random.uniform(\n        key, minval=domain[0], maxval=domain[1], shape=(n, d + 1))\n    t = jnp.sort(t, axis=-1)\n    key, rng = random.split(rng)\n    logits = 2 * random.normal(key, 
shape=(n, d))\n\n    # Resample the step functions.\n    key = random.PRNGKey(999) if randomized else None\n    t_sampled = stepfun.sample_intervals(\n        key, t, logits, d_resample, single_jitter=True, domain=domain)\n\n    # Precompute the accumulated weights of the original intervals.\n    weights = jax.nn.softmax(logits, axis=-1)\n    acc_weights = stepfun.integrate_weights(weights)\n\n    errors = []\n    for i in range(t_sampled.shape[0]):\n      # Resample into the original accumulated weights.\n      acc_resampled = jnp.interp(t_sampled[i], t[i], acc_weights[i])\n      # Differentiate the accumulation to get resampled weights (that do not\n      # necessarily sum to 1 because some of the ends might get missed).\n      weights_resampled = jnp.diff(acc_resampled, axis=-1)\n      # Check that the resampled weights resemble a uniform distribution.\n      u = 1 / len(weights_resampled)\n      errors.append(float(jnp.sum(jnp.abs(weights_resampled - u))))\n    errors = jnp.array(errors)\n    mean_error = jnp.mean(errors)\n    print(f'Mean Error = {mean_error}, Tolerance = {tolerance}')\n    self.assertLess(mean_error, tolerance)\n\n  @parameterized.named_parameters(('deterministic_unbounded', False, False),\n                                  ('random_unbounded', True, False),\n                                  ('deterministic_bounded', False, True),\n                                  ('random_bounded', True, True))\n  def test_sample_intervals_unbiased(self, randomized, bound_domain):\n    \"\"\"Test that resampled intervals are unbiased.\"\"\"\n    n, d_resample = 1000, 64\n    domain = (-0.5, 0.5) if bound_domain else (-jnp.inf, jnp.inf)\n\n    # A single interval from [-0.5, 0.5].\n    t = jnp.array([-2.5, -1.5, -0.5, 0.5, 1.5, 2.5])\n    logits = jnp.array([0, 0, 100., 0, 0])\n\n    ts = jnp.tile(t[None], [n, 1])\n    logits = jnp.tile(logits[None], [n, 1])\n\n    # Resample the step functions.\n    rng = random.PRNGKey(0) if randomized else None\n    
t_sampled = stepfun.sample_intervals(\n        rng, ts, logits, d_resample, single_jitter=True, domain=domain)\n\n    # The average sample should be close to zero.\n    if randomized:\n      self.assertLess(\n          jnp.max(jnp.abs(jnp.mean(t_sampled, axis=-1))), 0.5 / d_resample)\n    else:\n      np.testing.assert_allclose(\n          jnp.mean(t_sampled, axis=-1), jnp.zeros(n), atol=1E-5, rtol=1E-5)\n\n    # The extents of the samples should be near -0.5 and 0.5.\n    if bound_domain and randomized:\n      np.testing.assert_allclose(jnp.median(t_sampled[:, 0]), -0.5, atol=1e-4)\n      np.testing.assert_allclose(jnp.median(t_sampled[:, -1]), 0.5, atol=1e-4)\n\n    # The interval edge near the extent should be centered around +/-0.5.\n    if randomized:\n      np.testing.assert_allclose(\n          jnp.mean(t_sampled[:, 0] > -0.5), 0.5, atol=1 / d_resample)\n      np.testing.assert_allclose(\n          jnp.mean(t_sampled[:, -1] < 0.5), 0.5, atol=1 / d_resample)\n\n  def test_sample_single_interval(self):\n    \"\"\"Resample a single interval and check that it's a linspace.\"\"\"\n    t = jnp.array([1, 2, 3, 4, 5, 6])\n    logits = jnp.array([0, 0, 100, 0, 0])\n    key = None\n    t_sampled = stepfun.sample_intervals(key, t, logits, 10, single_jitter=True)\n    np.testing.assert_allclose(\n        t_sampled, jnp.linspace(3, 4, 11), atol=1E-5, rtol=1E-5)\n\n  @parameterized.named_parameters(('sameset', 0, True), ('diffset', 2, False))\n  def test_lossfun_outer(self, num_ablate, is_all_zero):\n    \"\"\"Two histograms of the same/diff points have a loss of zero/non-zero.\"\"\"\n    rng = random.PRNGKey(0)\n    eps = 1e-12  # Need a little slack because of cumsum's numerical weirdness.\n    all_zero = True\n    for _ in range(10):\n      key, rng = random.split(rng)\n      num_pts, d0, d1 = random.randint(key, [3], minval=10, maxval=20)\n\n      key, rng = random.split(rng)\n      t0 = jnp.sort(random.uniform(key, [d0 + 1]), axis=-1)\n\n      key, rng = 
random.split(rng)\n      t1 = jnp.sort(random.uniform(key, [d1 + 1]), axis=-1)\n\n      lo = jnp.maximum(jnp.min(t0), jnp.min(t1)) + 0.1\n      hi = jnp.minimum(jnp.max(t0), jnp.max(t1)) - 0.1\n      rand = random.uniform(key, [num_pts], minval=lo, maxval=hi)\n\n      pts = rand\n      pts_ablate = rand[:-num_ablate] if num_ablate > 0 else pts\n\n      w0 = []\n      for i in range(len(t0) - 1):\n        w0.append(jnp.mean((pts_ablate >= t0[i]) & (pts_ablate < t0[i + 1])))\n      w0 = jnp.array(w0)\n\n      w1 = []\n      for i in range(len(t1) - 1):\n        w1.append(jnp.mean((pts >= t1[i]) & (pts < t1[i + 1])))\n      w1 = jnp.array(w1)\n\n      all_zero &= jnp.all(stepfun.lossfun_outer(t0, w0, t1, w1) < eps)\n    self.assertEqual(is_all_zero, all_zero)\n\n  def test_inner_outer(self):\n    \"\"\"Two histograms of the same points will be bounds on each other.\"\"\"\n    rng = random.PRNGKey(4)\n    for _ in range(10):\n      key, rng = random.split(rng)\n      d0, d1, num_pts = random.randint(key, [3], minval=10, maxval=20)\n\n      key, rng = random.split(rng)\n      t0 = jnp.sort(random.uniform(key, [d0 + 1]), axis=-1)\n\n      key, rng = random.split(rng)\n      t1 = jnp.sort(random.uniform(key, [d1 + 1]), axis=-1)\n\n      lo = jnp.maximum(jnp.min(t0), jnp.min(t1)) + 0.1\n      hi = jnp.minimum(jnp.max(t0), jnp.max(t1)) - 0.1\n      pts = random.uniform(key, [num_pts], minval=lo, maxval=hi)\n\n      w0 = []\n      for i in range(len(t0) - 1):\n        w0.append(jnp.sum((pts >= t0[i]) & (pts < t0[i + 1])))\n      w0 = jnp.array(w0)\n\n      w1 = []\n      for i in range(len(t1) - 1):\n        w1.append(jnp.sum((pts >= t1[i]) & (pts < t1[i + 1])))\n      w1 = jnp.array(w1)\n\n      w0_inner, w0_outer = stepfun.inner_outer(t0, t1, w1)\n      w1_inner, w1_outer = stepfun.inner_outer(t1, t0, w0)\n\n      self.assertTrue(jnp.all(w0_inner <= w0) and jnp.all(w0 <= w0_outer))\n      self.assertTrue(jnp.all(w1_inner <= w1) and jnp.all(w1 <= w1_outer))\n\n  def 
test_lossfun_outer_monotonic(self):\n    \"\"\"The loss is invariant to monotonic transformations on `t`.\"\"\"\n    rng = random.PRNGKey(0)\n\n    curve_fn = lambda x: 1 + x**3  # Some monotonic transformation.\n\n    for _ in range(10):\n      key, rng = random.split(rng)\n      d0, d1 = random.randint(key, [2], minval=10, maxval=20)\n\n      key, rng = random.split(rng)\n      t0 = jnp.sort(random.uniform(key, [d0 + 1]), axis=-1)\n\n      key, rng = random.split(rng)\n      t1 = jnp.sort(random.uniform(key, [d1 + 1]), axis=-1)\n\n      key, rng = random.split(rng)\n      w0 = jnp.exp(random.normal(key, [d0]))\n\n      key, rng = random.split(rng)\n      w1 = jnp.exp(random.normal(key, [d1]))\n\n      excess = stepfun.lossfun_outer(t0, w0, t1, w1)\n      curve_excess = stepfun.lossfun_outer(curve_fn(t0), w0, curve_fn(t1), w1)\n      self.assertTrue(jnp.all(excess == curve_excess))\n\n  def test_lossfun_outer_self_zero(self):\n    \"\"\"The loss is ~zero for the same (t, w) step function.\"\"\"\n    rng = random.PRNGKey(0)\n\n    for _ in range(10):\n      key, rng = random.split(rng)\n      d = random.randint(key, (), minval=10, maxval=20)\n\n      key, rng = random.split(rng)\n      t = jnp.sort(random.uniform(key, [d + 1]), axis=-1)\n\n      key, rng = random.split(rng)\n      w = jnp.exp(random.normal(key, [d]))\n\n      self.assertTrue(jnp.all(stepfun.lossfun_outer(t, w, t, w) < 1e-10))\n\n  def test_outer_measure_reference(self):\n    \"\"\"Test that outer measures match a reference implementation.\"\"\"\n    rng = random.PRNGKey(0)\n    for _ in range(10):\n      key, rng = random.split(rng)\n      d0, d1 = random.randint(key, [2], minval=10, maxval=20)\n\n      key, rng = random.split(rng)\n      t0 = jnp.sort(random.uniform(key, [d0 + 1]), axis=-1)\n\n      key, rng = random.split(rng)\n      t1 = jnp.sort(random.uniform(key, [d1 + 1]), axis=-1)\n\n      key, rng = random.split(rng)\n      w0 = jnp.exp(random.normal(key, [d0]))\n\n      _, w1_outer = 
stepfun.inner_outer(t1, t0, w0)\n      w1_outer_ref = outer(t1, t0, w0)\n      np.testing.assert_allclose(w1_outer, w1_outer_ref, atol=1E-5, rtol=1E-5)\n\n  def test_inner_measure_reference(self):\n    \"\"\"Test that inner measures match a reference implementation.\"\"\"\n    rng = random.PRNGKey(0)\n    for _ in range(10):\n      key, rng = random.split(rng)\n      d0, d1 = random.randint(key, [2], minval=10, maxval=20)\n\n      key, rng = random.split(rng)\n      t0 = jnp.sort(random.uniform(key, [d0 + 1]), axis=-1)\n\n      key, rng = random.split(rng)\n      t1 = jnp.sort(random.uniform(key, [d1 + 1]), axis=-1)\n\n      key, rng = random.split(rng)\n      w0 = jnp.exp(random.normal(key, [d0]))\n\n      w1_inner, _ = stepfun.inner_outer(t1, t0, w0)\n      w1_inner_ref = inner(t1, t0, w0)\n      np.testing.assert_allclose(w1_inner, w1_inner_ref, rtol=1e-5, atol=1e-5)\n\n  def test_weighted_percentile(self):\n    \"\"\"Test that step function percentiles match the empirical percentile.\"\"\"\n    num_samples = 1000000\n    rng = random.PRNGKey(0)\n    for _ in range(10):\n      rng, key = random.split(rng)\n      d = random.randint(key, (), minval=10, maxval=20)\n\n      rng, key = random.split(rng)\n      ps = 100 * random.uniform(key, [3])\n\n      key, rng = random.split(rng)\n      t = jnp.sort(random.normal(key, [d + 1]), axis=-1)\n\n      key, rng = random.split(rng)\n      w = jax.nn.softmax(random.normal(key, [d]))\n\n      key, rng = random.split(rng)\n      samples = stepfun.sample(\n          key, t, jnp.log(w), num_samples, single_jitter=False)\n      true_percentiles = jnp.percentile(samples, ps)\n\n      our_percentiles = stepfun.weighted_percentile(t, w, ps)\n      np.testing.assert_allclose(\n          our_percentiles, true_percentiles, rtol=1e-4, atol=1e-4)\n\n  def test_weighted_percentile_vectorized(self):\n    rng = random.PRNGKey(0)\n    shape = (3, 4)\n    d = 128\n\n    rng, key = random.split(rng)\n    ps = 100 * random.uniform(key, 
(5,))\n\n    key, rng = random.split(rng)\n    t = jnp.sort(random.normal(key, shape + (d + 1,)), axis=-1)\n\n    key, rng = random.split(rng)\n    w = jax.nn.softmax(random.normal(key, shape + (d,)))\n\n    percentiles_vec = stepfun.weighted_percentile(t, w, ps)\n\n    percentiles = []\n    for i in range(shape[0]):\n      percentiles.append([])\n      for j in range(shape[1]):\n        percentiles[i].append(stepfun.weighted_percentile(t[i, j], w[i, j], ps))\n      percentiles[i] = jnp.stack(percentiles[i])\n    percentiles = jnp.stack(percentiles)\n\n    np.testing.assert_allclose(\n        percentiles_vec, percentiles, rtol=1e-5, atol=1e-5)\n\n  @parameterized.named_parameters(('', False), ('_avg', True))\n  def test_resample_self_noop(self, use_avg):\n    \"\"\"Resampling a step function into itself should be a no-op.\"\"\"\n    d = 32\n    rng = random.PRNGKey(0)\n\n    key, rng = random.split(rng)\n    tp = random.normal(rng, shape=(d + 1,))\n    tp = jnp.sort(tp)\n\n    key, rng = random.split(rng)\n    vp = random.normal(key, shape=(d,))\n\n    vp_recon = stepfun.resample(tp, tp, vp, use_avg=use_avg)\n    np.testing.assert_allclose(vp, vp_recon, atol=1e-4)\n\n  @parameterized.named_parameters(('', False), ('_avg', True))\n  def test_resample_2x_downsample(self, use_avg):\n    \"\"\"Check resampling for a 2d downsample.\"\"\"\n    d = 32\n    rng = random.PRNGKey(0)\n\n    key, rng = random.split(rng)\n    tp = random.normal(rng, shape=(d + 1,))\n    tp = jnp.sort(tp)\n\n    key, rng = random.split(rng)\n    vp = random.normal(key, shape=(d,))\n\n    t = tp[::2]\n\n    v = stepfun.resample(t, tp, vp, use_avg=use_avg)\n\n    vp2 = vp.reshape([-1, 2])\n    dtp2 = jnp.diff(tp).reshape([-1, 2])\n    if use_avg:\n      v_true = jnp.sum(vp2 * dtp2, axis=-1) / jnp.sum(dtp2, axis=-1)\n    else:\n      v_true = jnp.sum(vp2, axis=-1)\n\n    np.testing.assert_allclose(v, v_true, atol=1e-4)\n\n  @parameterized.named_parameters(('', False), ('_avg', True))\n  def 
test_resample_entire_interval(self, use_avg):\n    \"\"\"Check the sum (or weighted mean) of an entire interval.\"\"\"\n    d = 32\n    rng = random.PRNGKey(0)\n    key, rng = random.split(rng)\n    tp = random.normal(rng, shape=(d + 1,))\n    tp = jnp.sort(tp)\n\n    key, rng = random.split(rng)\n    vp = random.normal(key, shape=(d,))\n\n    t = jnp.array([jnp.min(tp), jnp.max(tp)])\n\n    v = stepfun.resample(t, tp, vp, use_avg=use_avg)[0]\n    if use_avg:\n      v_true = jnp.sum(vp * jnp.diff(tp)) / sum(jnp.diff(tp))\n    else:\n      v_true = jnp.sum(vp)\n\n    np.testing.assert_allclose(v, v_true, atol=1e-4)\n\n  def test_resample_entire_domain(self):\n    \"\"\"Check the sum of the entire input domain.\"\"\"\n    d = 32\n    rng = random.PRNGKey(0)\n    key, rng = random.split(rng)\n    tp = random.normal(rng, shape=(d + 1,))\n    tp = jnp.sort(tp)\n\n    key, rng = random.split(rng)\n    vp = random.normal(key, shape=(d,))\n\n    t = jnp.array([-1e6, 1e6])\n\n    v = stepfun.resample(t, tp, vp)[0]\n    v_true = jnp.sum(vp)\n\n    np.testing.assert_allclose(v, v_true, atol=1e-4)\n\n  @parameterized.named_parameters(('', False), ('_avg', True))\n  def test_resample_single_span(self, use_avg):\n    \"\"\"Check the sum (or weighted mean) of a single span.\"\"\"\n    d = 32\n    rng = random.PRNGKey(0)\n    key, rng = random.split(rng)\n    tp = random.normal(rng, shape=(d + 1,))\n    tp = jnp.sort(tp)\n\n    key, rng = random.split(rng)\n    vp = random.normal(key, shape=(d,))\n\n    pad = (tp[d // 2 + 1] - tp[d // 2]) / 4\n    t = jnp.array([tp[d // 2] + pad, tp[d // 2 + 1] - pad])\n\n    v = stepfun.resample(t, tp, vp, use_avg=use_avg)[0]\n    if use_avg:\n      v_true = vp[d // 2]\n    else:\n      v_true = vp[d // 2] * 0.5\n\n    np.testing.assert_allclose(v, v_true, atol=1e-4)\n\n  @parameterized.named_parameters(('', False), ('_avg', True))\n  def test_resample_vectorized(self, use_avg):\n    \"\"\"Check that resample works with vectorized inputs.\"\"\"\n 
   shape = (3, 4)\n    dp = 32\n    d = 16\n    rng = random.PRNGKey(0)\n    key, rng = random.split(rng)\n    tp = random.normal(rng, shape=shape + (dp + 1,))\n    tp = jnp.sort(tp)\n\n    key, rng = random.split(rng)\n    vp = random.normal(key, shape=shape + (dp,))\n\n    key, rng = random.split(rng)\n    t = random.normal(rng, shape=shape + (d + 1,))\n    t = jnp.sort(t)\n\n    v_batch = stepfun.resample(t, tp, vp, use_avg=use_avg)\n\n    v_indiv = []\n    for i in range(t.shape[0]):\n      v_indiv.append(\n          jnp.array([\n              stepfun.resample(t[i, j], tp[i, j], vp[i, j], use_avg=use_avg)\n              for j in range(t.shape[1])\n          ]))\n    v_indiv = jnp.array(v_indiv)\n\n    np.testing.assert_allclose(v_batch, v_indiv, atol=1e-4)\n\n\nif __name__ == '__main__':\n  absltest.main()\n"
  },
  {
    "path": "mip360/tests/utils_test.py",
    "content": "# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for utils.\"\"\"\n\nfrom absl.testing import absltest\n\nfrom internal import utils\n\n\nclass UtilsTest(absltest.TestCase):\n\n  def test_dummy_rays(self):\n    \"\"\"Ensures that the dummy Rays object is correctly initialized.\"\"\"\n    rays = utils.dummy_rays()\n    self.assertEqual(rays.origins.shape[-1], 3)\n\n\nif __name__ == '__main__':\n  absltest.main()\n"
  },
  {
    "path": "mip360/train.py",
    "content": "# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Training script.\"\"\"\n# import os\n# os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] ='false'\n# os.environ['XLA_PYTHON_CLIENT_ALLOCATOR']='platform'\n# os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'\n\nimport functools\nimport gc\nimport time\n\nfrom absl import app\nimport flax\nfrom flax.metrics import tensorboard\nfrom flax.training import checkpoints\nimport gin\nfrom internal import configs\nfrom internal import datasets\nfrom internal import image\nfrom internal import models\nfrom internal import train_utils\nfrom internal import utils\nfrom internal import vis\nimport jax\nfrom jax import random\nimport jax.numpy as jnp\nimport numpy as np\n\nconfigs.define_common_flags()\njax.config.parse_flags_with_absl()\n\nTIME_PRECISION = 1000  # Internally represent integer times in milliseconds.\n\n\ndef main(unused_argv):\n  rng = random.PRNGKey(20200823)\n  # Shift the numpy random seed by host_id() to shuffle data loaded by different\n  # hosts.\n  np.random.seed(20201473 + jax.host_id())\n\n  config = configs.load_config()\n\n  if config.batch_size % jax.device_count() != 0:\n    raise ValueError('Batch size must be divisible by the number of devices.')\n\n  dataset = datasets.load_dataset('train', config.data_dir, config)\n  test_dataset = datasets.load_dataset('test', config.data_dir, config)\n\n  np_to_jax = lambda x: jnp.array(x) if isinstance(x, 
np.ndarray) else x\n  cameras = tuple(np_to_jax(x) for x in dataset.cameras)\n\n  if config.rawnerf_mode:\n    postprocess_fn = test_dataset.metadata['postprocess_fn']\n  else:\n    postprocess_fn = lambda z, _=None: z\n\n  rng, key = random.split(rng)\n  setup = train_utils.setup_model(config, key, dataset=dataset)\n  model, state, render_eval_pfn, train_pstep, lr_fn = setup\n\n  variables = state.params\n  num_params = jax.tree_util.tree_reduce(\n      lambda x, y: x + jnp.prod(jnp.array(y.shape)), variables, initializer=0)\n  print(f'Number of parameters being optimized: {num_params}')\n\n  if (dataset.size > model.num_glo_embeddings and model.num_glo_features > 0):\n    raise ValueError(f'Number of glo embeddings {model.num_glo_embeddings} '\n                     f'must be at least equal to number of train images '\n                     f'{dataset.size}')\n\n  metric_harness = image.MetricHarness()\n\n  if not utils.isdir(config.checkpoint_dir):\n    utils.makedirs(config.checkpoint_dir)\n  state = checkpoints.restore_checkpoint(config.checkpoint_dir, state)\n  # Resume training at the step of the last checkpoint.\n  init_step = state.step + 1\n  state = flax.jax_utils.replicate(state)\n\n  if jax.host_id() == 0:\n    summary_writer = tensorboard.SummaryWriter(config.checkpoint_dir)\n    if config.rawnerf_mode:\n      for name, data in zip(['train', 'test'], [dataset, test_dataset]):\n        # Log shutter speed metadata in TensorBoard for debug purposes.\n        for key in ['exposure_idx', 'exposure_values', 'unique_shutters']:\n          summary_writer.text(f'{name}_{key}', str(data.metadata[key]), 0)\n\n  # Prefetch_buffer_size = 3 x batch_size.\n  pdataset = flax.jax_utils.prefetch_to_device(dataset, 3)\n  rng = rng + jax.host_id()  # Make random seed separate across hosts.\n  rngs = random.split(rng, jax.local_device_count())  # For pmapping RNG keys.\n  gc.disable()  # Disable automatic garbage collection for efficiency.\n  total_time = 0\n  total_steps 
= 0\n  reset_stats = True\n  if config.early_exit_steps is not None:\n    num_steps = config.early_exit_steps\n  else:\n    num_steps = config.max_steps\n  for step, batch in zip(range(init_step, num_steps + 1), pdataset):\n\n    if reset_stats and (jax.host_id() == 0):\n      stats_buffer = []\n      train_start_time = time.time()\n      reset_stats = False\n\n    learning_rate = lr_fn(step)\n    train_frac = jnp.clip((step - 1) / (config.max_steps - 1), 0, 1)\n\n    state, stats, rngs = train_pstep(\n        rngs,\n        state,\n        batch,\n        cameras,\n        train_frac,\n    )\n\n    if step % config.gc_every == 0:\n      gc.collect()  # Collect garbage manually; automatic collection is disabled.\n\n    # Log training summaries. This is put behind a host_id check because in\n    # multi-host evaluation, all hosts need to run inference even though we\n    # only use host 0 to record results.\n    if jax.host_id() == 0:\n      stats = flax.jax_utils.unreplicate(stats)\n\n      stats_buffer.append(stats)\n\n      if step == init_step or step % config.print_every == 0:\n        elapsed_time = time.time() - train_start_time\n        steps_per_sec = config.print_every / elapsed_time\n        rays_per_sec = config.batch_size * steps_per_sec\n\n        # A robust approximation of total training time, in case of pre-emption.\n        total_time += int(round(TIME_PRECISION * elapsed_time))\n        total_steps += config.print_every\n        approx_total_time = int(round(step * total_time / total_steps))\n\n        # Transpose and stack stats_buffer along axis 0.\n        fs = [flax.traverse_util.flatten_dict(s, sep='/') for s in stats_buffer]\n        stats_stacked = {k: jnp.stack([f[k] for f in fs]) for k in fs[0].keys()}\n\n        # Split every statistic that isn't a vector into a set of statistics.\n        stats_split = {}\n        for k, v in stats_stacked.items():\n          if v.ndim not in [1, 2] and v.shape[0] != len(stats_buffer):\n            raise 
ValueError('statistics must be of size [n], or [n, k].')\n          if v.ndim == 1:\n            stats_split[k] = v\n          elif v.ndim == 2:\n            for i, vi in enumerate(tuple(v.T)):\n              stats_split[f'{k}/{i}'] = vi\n\n        # Summarize the entire histogram of each statistic.\n        for k, v in stats_split.items():\n          summary_writer.histogram('train_' + k, v, step)\n\n        # Take the mean and max of each statistic since the last summary.\n        avg_stats = {k: jnp.mean(v) for k, v in stats_split.items()}\n        max_stats = {k: jnp.max(v) for k, v in stats_split.items()}\n\n        summ_fn = lambda s, v: summary_writer.scalar(s, v, step)  # pylint:disable=cell-var-from-loop\n\n        # Summarize the mean and max of each statistic.\n        for k, v in avg_stats.items():\n          summ_fn(f'train_avg_{k}', v)\n        for k, v in max_stats.items():\n          summ_fn(f'train_max_{k}', v)\n\n        summ_fn('train_num_params', num_params)\n        summ_fn('train_learning_rate', learning_rate)\n        summ_fn('train_steps_per_sec', steps_per_sec)\n        summ_fn('train_rays_per_sec', rays_per_sec)\n\n        summary_writer.scalar('train_avg_psnr_timed', avg_stats['psnr'],\n                              total_time // TIME_PRECISION)\n        summary_writer.scalar('train_avg_psnr_timed_approx', avg_stats['psnr'],\n                              approx_total_time // TIME_PRECISION)\n\n        if dataset.metadata is not None and model.learned_exposure_scaling:\n          params = state.params['params']\n          scalings = params['exposure_scaling_offsets']['embedding'][0]\n          num_shutter_speeds = dataset.metadata['unique_shutters'].shape[0]\n          for i_s in range(num_shutter_speeds):\n            for j_s, value in enumerate(scalings[i_s]):\n              summary_name = f'exposure/scaling_{i_s}_{j_s}'\n              summary_writer.scalar(summary_name, value, step)\n\n        precision = 
int(np.ceil(np.log10(config.max_steps))) + 1\n        avg_loss = avg_stats['loss']\n        avg_psnr = avg_stats['psnr']\n        str_losses = {  # Grab each \"losses/{x}\" field and print it as \"x[:4]\".\n            k[7:11]: (f'{v:0.5f}' if v >= 1e-4 and v < 10 else f'{v:0.1e}')\n            for k, v in avg_stats.items()\n            if k.startswith('losses/')\n        }\n        print(f'{step:{precision}d}' + f'/{config.max_steps:d}: ' +\n              f'loss={avg_loss:0.5f}, ' + f'psnr={avg_psnr:6.3f}, ' +\n              f'lr={learning_rate:0.2e} | ' +\n              ', '.join([f'{k}={s}' for k, s in str_losses.items()]) +\n              f', {rays_per_sec:0.0f} r/s')\n\n        # Reset everything we are tracking between summarizations.\n        reset_stats = True\n\n      if step == 1 or step % config.checkpoint_every == 0:\n        state_to_save = jax.device_get(\n            flax.jax_utils.unreplicate(state))\n        checkpoints.save_checkpoint(\n            config.checkpoint_dir, state_to_save, int(step), keep=100)\n\n    # Test-set evaluation.\n    if config.train_render_every > 0 and step % config.train_render_every == 0:\n      # We reuse the same random number generator from the optimization step\n      # here on purpose so that the visualization matches what happened in\n      # training.\n      eval_start_time = time.time()\n      eval_variables = flax.jax_utils.unreplicate(state).params\n      test_case = next(test_dataset)\n      rendering = models.render_image(\n          functools.partial(render_eval_pfn, eval_variables, train_frac),\n          test_case.rays, rngs[0], config)\n\n      # Log eval summaries on host 0.\n      if jax.host_id() == 0:\n        eval_time = time.time() - eval_start_time\n        num_rays = jnp.prod(jnp.array(test_case.rays.directions.shape[:-1]))\n        rays_per_sec = num_rays / eval_time\n        summary_writer.scalar('test_rays_per_sec', rays_per_sec, step)\n        print(f'Eval {step}: {eval_time:0.3f}s., 
{rays_per_sec:0.0f} rays/sec')\n\n        metric_start_time = time.time()\n        metric = metric_harness(\n            postprocess_fn(rendering['rgb']), postprocess_fn(test_case.rgb))\n        print(f'Metrics computed in {(time.time() - metric_start_time):0.3f}s')\n        for name, val in metric.items():\n          if not np.isnan(val):\n            print(f'{name} = {val:.4f}')\n            summary_writer.scalar('train_metrics/' + name, val, step)\n\n        if config.vis_decimate > 1:\n          d = config.vis_decimate\n          decimate_fn = lambda x, d=d: None if x is None else x[::d, ::d]\n        else:\n          decimate_fn = lambda x: x\n        rendering = jax.tree_util.tree_map(decimate_fn, rendering)\n        test_case = jax.tree_util.tree_map(decimate_fn, test_case)\n        vis_start_time = time.time()\n        vis_suite = vis.visualize_suite(rendering, test_case.rays)\n        print(f'Visualized in {(time.time() - vis_start_time):0.3f}s')\n        if config.rawnerf_mode:\n          # Unprocess raw output.\n          vis_suite['color_raw'] = rendering['rgb']\n          # Autoexposed colors.\n          vis_suite['color_auto'] = postprocess_fn(rendering['rgb'], None)\n          summary_writer.image('test_true_auto',\n                               postprocess_fn(test_case.rgb, None), step)\n          # Exposure sweep colors.\n          exposures = test_dataset.metadata['exposure_levels']\n          for p, x in list(exposures.items()):\n            vis_suite[f'color/{p}'] = postprocess_fn(rendering['rgb'], x)\n            summary_writer.image(f'test_true_color/{p}',\n                                 postprocess_fn(test_case.rgb, x), step)\n        summary_writer.image('test_true_color', test_case.rgb, step)\n        if config.compute_normal_metrics:\n          summary_writer.image('test_true_normals',\n                               test_case.normals / 2. 
+ 0.5, step)\n        for k, v in vis_suite.items():\n          summary_writer.image('test_output_' + k, v, step)\n\n  if jax.host_id() == 0 and config.max_steps % config.checkpoint_every != 0:\n    state = jax.device_get(flax.jax_utils.unreplicate(state))\n    checkpoints.save_checkpoint(\n        config.checkpoint_dir, state, int(config.max_steps), keep=100)\n\n\nif __name__ == '__main__':\n  with gin.config_scope('train'):\n    app.run(main)\n"
  }
]