[
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\n!hy3dgen/texgen/custom_rasterizer/lib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# UV\n#   Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for libraries.\n#uv.lock\n\n# poetry\n#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for libraries.\n#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control\n#poetry.lock\n\n# pdm\n#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.\n#pdm.lock\n#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it\n#   in version control.\n#   https://pdm.fming.dev/latest/usage/project/#working-with-version-control\n.pdm.toml\n.pdm-python\n.pdm-build/\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n.DS_Store\n# Cython debug symbols\ncython_debug/\ngradio_cache/\n# PyCharm\n#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can\n#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore\n#  and can be added to the global gitignore or merged into this file.  
For a more nuclear\n#  option (not recommended) you can uncomment the following to ignore the entire idea folder.\n#.idea/\n"
  },
  {
    "path": "LICENSE",
    "content": "TENCENT HUNYUAN 3D 2.1 COMMUNITY LICENSE AGREEMENT\nTencent Hunyuan 3D 2.1 Release Date: June 13, 2025\nTHIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION, UNITED KINGDOM AND SOUTH KOREA AND IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.\nBy clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent Hunyuan 3D 2.1 Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.\n1.\tDEFINITIONS.\na.\t“Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A.\nb.\t“Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of Tencent Hunyuan 3D 2.1 Works or any portion or element thereof set forth herein.\nc.\t“Documentation” shall mean the specifications, manuals and documentation for Tencent Hunyuan 3D 2.1 made publicly available by Tencent.\nd.\t“Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means.\ne.\t“Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent Hunyuan 3D 2.1 Works for any purpose and in any field of use.\nf.\t“Materials” shall mean, collectively, Tencent’s proprietary Tencent Hunyuan 3D 2.1 and Documentation (and any portion thereof) as made available by Tencent under this Agreement.\ng.\t“Model Derivatives” shall mean all: (i) modifications to Tencent Hunyuan 3D 2.1 or any Model Derivative of Tencent Hunyuan 3D 2.1; (ii) works based on Tencent Hunyuan 3D 2.1 or any Model Derivative of Tencent Hunyuan 3D 2.1; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent Hunyuan 3D 2.1 or any Model Derivative of Tencent Hunyuan 3D 2.1, to that model in order to cause that model to perform similarly to Tencent Hunyuan 3D 2.1 or a Model Derivative of Tencent Hunyuan 3D 2.1, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent Hunyuan 3D 2.1 or a Model Derivative of Tencent Hunyuan 3D 2.1 for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives.\nh.\t“Output” shall mean the information and/or content output of Tencent Hunyuan 3D 2.1 or a Model Derivative that results from operating or otherwise using Tencent Hunyuan 3D 2.1 or a Model Derivative, including via a Hosted Service.\ni.\t“Tencent,” “We” or “Us” shall mean THL Q Limited.\nj.\t“Tencent Hunyuan 3D 2.1” shall mean the 3D generation models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us at [ https://github.com/Tencent-Hunyuan/Hunyuan3D-2.1].\nk.\t“Tencent Hunyuan 3D 2.1 Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof.\nl.\t“Territory” shall mean the worldwide territory, excluding the territory of the European Union, United Kingdom and South Korea. 
\nm.\t“Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You.\nn.\t“including” shall mean including but not limited to.\n2.\tGRANT OF RIGHTS.\nWe grant You, for the Territory only, a non-exclusive, non-transferable and royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy.\n3.\tDISTRIBUTION.\nYou may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent Hunyuan 3D 2.1 Works, exclusively in the Territory, provided that You meet all of the following conditions:\na.\tYou must provide all such Third Party recipients of the Tencent Hunyuan 3D 2.1 Works or products or services using them a copy of this Agreement;\nb.\tYou must cause any modified files to carry prominent notices stating that You changed the files;\nc.\tYou are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent Hunyuan 3D 2.1 Works; and (ii) mark the products or services developed by using the Tencent Hunyuan 3D 2.1 Works to indicate that the product/service is “Powered by Tencent Hunyuan”; and\nd.\tAll distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent Hunyuan 3D 2.1 is licensed under the Tencent Hunyuan 3D 2.1 Community License Agreement, Copyright © 2025 Tencent. All Rights Reserved. The trademark rights of “Tencent Hunyuan” are owned by Tencent or its affiliate.”\nYou may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement (including as regards the Territory). If You receive Tencent Hunyuan 3D 2.1 Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You.\n4.\tADDITIONAL COMMERCIAL TERMS.\nIf, on the Tencent Hunyuan 3D 2.1 version release date, the monthly active users of all products or services made available by or for Licensee is greater than 1 million monthly active users in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights. 
\nSubject to Tencent's written approval, you may request a license for the use of Tencent Hunyuan 3D 2.1 by submitting the following information to hunyuan3d@tencent.com:\na.\tYour company’s name and associated business sector that plans to use Tencent Hunyuan 3D 2.1.\nb.\tYour intended use case and the purpose of using Tencent Hunyuan 3D 2.1.\nc.\tYour plans to modify Tencent Hunyuan 3D 2.1 or create Model Derivatives.\n5.\tRULES OF USE.\na.\tYour use of the Tencent Hunyuan 3D 2.1 Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent Hunyuan 3D 2.1 Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Tencent Hunyuan 3D 2.1 Works and You must provide notice to subsequent users to whom You distribute that Tencent Hunyuan 3D 2.1 Works are subject to the use restrictions in these Sections 5(a) and 5(b).\nb.\tYou must not use the Tencent Hunyuan 3D 2.1 Works or any Output or results of the Tencent Hunyuan 3D 2.1 Works to improve any other AI model (other than Tencent Hunyuan 3D 2.1 or Model Derivatives thereof).\nc.\tYou must not use, reproduce, modify, distribute, or display the Tencent Hunyuan 3D 2.1 Works, Output or results of the Tencent Hunyuan 3D 2.1 Works outside the Territory. Any such use outside the Territory is unlicensed and unauthorized under this Agreement.\n6.\tINTELLECTUAL PROPERTY.\na.\tSubject to Tencent’s ownership of Tencent Hunyuan 3D 2.1 Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You.\nb.\tNo trademark licenses are granted under this Agreement, and in connection with the Tencent Hunyuan 3D 2.1 Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent Hunyuan 3D 2.1 Works. Tencent hereby grants You a license to use “Tencent Hunyuan” (the “Mark”) in the Territory solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent.\nc.\tIf You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent Hunyuan 3D 2.1 Works.\nd.\tTencent claims no rights in Outputs You generate. 
You and Your users are solely responsible for Outputs and their subsequent uses.\n7.\tDISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY.\na.\tWe are not obligated to support, update, provide training for, or develop any further version of the Tencent Hunyuan 3D 2.1 Works or to grant any license thereto.\nb.\tUNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE TENCENT HUNYUAN 3D 2.1 WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE TENCENT HUNYUAN 3D 2.1 WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE TENCENT HUNYUAN 3D 2.1 WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.\nc.\tTO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE TENCENT HUNYUAN 3D 2.1 WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.\n8.\tSURVIVAL AND TERMINATION.\na.\tThe term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.\nb.\tWe may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent Hunyuan 3D 2.1 Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement.\n9.\tGOVERNING LAW AND JURISDICTION.\na.\tThis Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.\nb.\tExclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute.\n \nEXHIBIT A\nACCEPTABLE USE POLICY\n\nTencent reserves the right to update this Acceptable Use Policy from time to time.\nLast modified: November 5, 2024\n\nTencent endeavors to promote safe and fair use of its tools and features, including Tencent Hunyuan 3D 2.1. 
You agree not to use Tencent Hunyuan 3D 2.1 or Model Derivatives:\n1.\tOutside the Territory;\n2.\tIn any way that violates any applicable national, federal, state, local, international or any other law or regulation;\n3.\tTo harm Yourself or others;\n4.\tTo repurpose or distribute output from Tencent Hunyuan 3D 2.1 or any Model Derivatives to harm Yourself or others; \n5.\tTo override or circumvent the safety guardrails and safeguards We have put in place;\n6.\tFor the purpose of exploiting, harming or attempting to exploit or harm minors in any way;\n7.\tTo generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections;\n8.\tTo generate or facilitate false online engagement, including fake reviews and other means of fake online engagement;\n9.\tTo intentionally defame, disparage or otherwise harass others;\n10.\tTo generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems;\n11.\tTo generate or disseminate personal identifiable information with the purpose of harming others;\n12.\tTo generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including –through the use of bot generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated;\n13.\tTo impersonate another individual without consent, authorization, or legal right;\n14.\tTo make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance);\n15.\tIn a manner that violates or disrespects the social ethics and moral standards of other countries or regions;\n16.\tTo perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism;\n17.\tFor any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics;\n18.\tTo intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;\n19.\tFor military purposes;\n20.\tTo engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices.\n"
  },
  {
    "path": "Notice.txt",
    "content": "Usage and Legal Notices:\n\nTencent is pleased to support the open source community by making Hunyuan 3D 2.1 available.\n\nCopyright (C) 2025 Tencent.  All rights reserved. The below software and/or models in this distribution may have been modified by Tencent (\"Tencent Modifications\"). All Tencent Modifications are Copyright (C) Tencent.\n\nHunyuan 3D 2.1 is licensed under the TENCENT HUNYUAN 3D 2.1 COMMUNITY LICENSE AGREEMENT except for the third-party components listed below, which is licensed under different terms. Hunyuan 3D 2.1 does not impose any additional limitations beyond what is outlined in the respective licenses of these third-party components. Users must comply with all terms and conditions of original licenses of these third-party components and must ensure that the usage of the third party components adheres to all relevant laws and regulations. \n\nFor avoidance of doubts, Hunyuan 3D 2.1 means inference-enabling code, parameters, and weights of this Model only, which are made publicly available by Tencent in accordance with TENCENT HUNYUAN 3D 2.1 COMMUNITY LICENSE AGREEMENT.\n\n\nOther dependencies and licenses:\n\n\nOpen Source Model Licensed under the MIT and CreativeML Open RAIL++-M License:\n--------------------------------------------------------------------\n1. Stable Diffusion\nCopyright (c) 2022 Stability AI\n\n\nTerms of the MIT and CreativeML Open RAIL++-M License:\n--------------------------------------------------------------------\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n\n\nCreativeML Open RAIL++-M License\ndated November 24, 2022\n\nSection I: PREAMBLE\n\nMultimodal generative models are being widely adopted and used, and have the potential to transform the way artists, among other individuals, conceive and benefit from AI or ML technologies as a tool for content creation.\n\nNotwithstanding the current and potential benefits that these artifacts can bring to society at large, there are also concerns about potential misuses of them, either due to their technical limitations or ethical considerations.\n\nIn short, this license strives for both the open and responsible downstream use of the accompanying model. When it comes to the open character, we took inspiration from open source permissive licenses regarding the grant of IP rights. 
Referring to the downstream responsible use, we added use-based restrictions not permitting the use of the Model in very specific scenarios, in order for the licensor to be able to enforce the license in case potential misuses of the Model may occur. At the same time, we strive to promote open and responsible research on generative models for art and content generation.\n\nEven though downstream derivative versions of the model could be released under different licensing terms, the latter will always have to include - at minimum - the same use-based restrictions as the ones in the original license (this license). We believe in the intersection between open and responsible AI development; thus, this License aims to strike a balance between both in order to enable responsible open-science in the field of AI.\n\nThis License governs the use of the model (and its derivatives) and is informed by the model card associated with the model.\n\nNOW THEREFORE, You and Licensor agree as follows:\n\n1. Definitions\n\n- \"License\" means the terms and conditions for use, reproduction, and Distribution as defined in this document.\n- \"Data\" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License.\n- \"Output\" means the results of operating a Model as embodied in informational content resulting therefrom.\n- \"Model\" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material.\n- \"Derivatives of the Model\" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model.\n- \"Complementary Material\" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any.\n- \"Distribution\" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access.\n- \"Licensor\" means the copyright owner or entity authorized by the copyright owner that is granting the License, including the persons or entities that may have rights in the Model and/or distributing the Model.\n- \"You\" (or \"Your\") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. 
chatbot, translator, image generator.\n- \"Third Parties\" means individuals or legal entities that are not under common control with Licensor or You.\n- \"Contribution\" means any work of authorship, including the original version of the Model and any modifications or additions to that Model or Derivatives of the Model thereof, that is intentionally submitted to Licensor for inclusion in the Model by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, \"submitted\" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Model, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as \"Not a Contribution.\"\n- \"Contributor\" means Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Model.\n\nSection II: INTELLECTUAL PROPERTY RIGHTS\n\nBoth copyright and patent grants apply to the Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III.\n\n2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Complementary Material, the Model, and Derivatives of the Model.\n3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Model to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution incorporated within the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or Work shall terminate as of the date such litigation is asserted or filed.\n\nSection III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION\n\n4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions:\nUse-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. 
a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material.\nYou must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License;\nYou must cause any modified files to carry prominent notices stating that You changed the files;\nYou must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model.\nYou may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. - for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License.\n5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5).\n6. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License.\n\nSection IV: OTHER PROVISIONS\n\n7. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License.\n8. Trademarks and related. Nothing in this License permits You to make use of Licensors’ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensors.\n9. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Model and the Complementary Material (and each Contributor provides its Contributions) on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License.\n10. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.\n11. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.\n12. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.\n\nEND OF TERMS AND CONDITIONS\n\n\n\n\nAttachment A\n\nUse Restrictions\n\nYou agree not to use the Model or Derivatives of the Model:\n\n- In any way that violates any applicable national, federal, state, local or international law or regulation;\n- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;\n- To generate or disseminate verifiably false information and/or content with the purpose of harming others;\n- To generate or disseminate personal identifiable information that can be used to harm an individual;\n- To defame, disparage or otherwise harass others;\n- For fully automated decision making that adversely impacts an individual’s legal rights or otherwise creates or modifies a binding, enforceable obligation;\n- For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics;\n- To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;\n- For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories;\n- To provide medical advice and medical results interpretation;\n- To generate or disseminate information for the purpose to be used for administration of justice, law enforcement, immigration or asylum processes, such as predicting an individual will commit fraud/crime commitment (e.g. by text profiling, drawing causal relationships between assertions made in documents, indiscriminate and arbitrarily-targeted use).\n"
  },
  {
    "path": "README.md",
    "content": "<div align=\"center\">\n\n<h1>UltraShape 1.0: High-Fidelity 3D Shape Generation via Scalable Geometric Refinement</h1>\n\n<a href=\"https://arxiv.org/pdf/2512.21185\"><img src=\"https://img.shields.io/badge/arXiv-2512.21185-b31b1b.svg?style=flat-square\" alt=\"arXiv\"></a>\n<a href=\"https://pku-yuangroup.github.io/UltraShape-1.0/\"><img src=\"https://img.shields.io/badge/Project-Page-blue?style=flat-square\" alt=\"Project Page\"></a>\n<a href=\"https://huggingface.co/infinith/UltraShape\"><img src=\"https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-yellow?style=flat-square\" alt=\"HuggingFace Models\"></a>\n\n</div>\n\n<br/>\n\n<div align=\"center\">\n  <img src=\"docs/assets/images/teaser.png\" width=\"100%\" alt=\"UltraShape 1.0 Teaser\" />\n</div>\n\n<br/>\n\n## 📖 Abstract\n\nIn this report, we introduce **UltraShape 1.0**, a scalable 3D diffusion framework for high-fidelity 3D geometry generation. The proposed approach adopts a **two-stage generation pipeline**: a coarse global structure is first synthesized and then refined to produce detailed, high-quality geometry.\n\nTo support reliable 3D generation, we develop a comprehensive data processing pipeline that includes a novel **watertight processing method** and **high-quality data filtering**. This pipeline improves the geometric quality of publicly available 3D datasets by removing low-quality samples, filling holes, and thickening thin structures, while preserving fine-grained geometric details. \n\nTo enable fine-grained geometry refinement, we decouple spatial localization from geometric detail synthesis in the diffusion process. We achieve this by performing **voxel-based refinement** at fixed spatial locations, where voxel queries derived from coarse geometry provide explicit positional anchors encoded via **RoPE**, allowing the diffusion model to focus on synthesizing local geometric details within a reduced, structured solution space.\n\nExtensive evaluations demonstrate that UltraShape 1.0 performs competitively with existing open-source methods in both data processing quality and geometry generation.\n\n## 🔥 News\n\n* **[2025-12-25]** 📄 We released the technical report of **UltraShape 1.0** on arXiv.\n* **[2025-12-26]** 🚀 We released the inference code and pre-trained models.\n* **[2025-12-31]** 🚀 We released the training code.\n\n## 🗓️ To-Do List\n- [x] Release inference code\n- [x] Release pre-trained weights (Hugging Face)\n- [x] Release training code\n- [ ] Release data processing scripts\n\n## 🛠️ Installation & Usage\n\n### 1. Environment Setup\n```bash\ngit clone https://github.com/PKU-YuanGroup/UltraShape-1.0.git\ncd UltraShape-1.0\n# 1. Create and activate the environment\nconda create -n ultrashape python=3.10\nconda activate ultrashape\n\n# 2. Install PyTorch (CUDA 12.1 recommended)\npip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121\n\n# 3. Install dependencies\npip install -r requirements.txt\n\n# 4. 
Install cubvh (Required for MC acceleration)\npip install git+https://github.com/ashawkey/cubvh --no-build-isolation\n\n# For Training & Sampling (Optional)\npip install --no-build-isolation \"git+https://github.com/facebookresearch/pytorch3d.git@stable\"\npip install https://data.pyg.org/whl/torch-2.5.0%2Bcu121/torch_cluster-1.6.3%2Bpt25cu121-cp310-cp310-linux_x86_64.whl\n```\n\n### ⬇️ Model Weights\n\nPlease download the pre-trained weights from Hugging Face [ [infinith/UltraShape](https://huggingface.co/infinith/UltraShape/tree/main) ] and place them in your checkpoint directory (e.g., ./checkpoints/).\n\n\n### 2. Generate Coarse Mesh\n\nFirst, use Hunyuan3D-2.1 to generate a coarse mesh from your input image.\n\nRepository: [Tencent-Hunyuan/Hunyuan3D-2.1](https://github.com/Tencent-Hunyuan/Hunyuan3D-2.1)\n\nFollow the instructions in the Hunyuan3D-2.1 repository to obtain the initial mesh file (e.g., .glb or .obj).\n\n### 3. Generate Refined Mesh\n\nOnce you have the coarse mesh, use the provided script to run the refinement stage.\n\nRun the inference script:\n```bash\nsh scripts/run.sh\n```\n\n**image**: Path to the reference image.\n\n**mesh**: Path to the coarse mesh.\n\n**output_dir**: Directory to save the refined result.\n\n**ckpt**: Path to the downloaded UltraShape checkpoint.\n\n**step**: Number of DiT inference sampling steps. The default is 50; it can be reduced to 12 to speed up generation.\n\n*Alternatively, you can run the Gradio app for interactive inference:*\n```bash\npython scripts/gradio_app.py --ckpt <path_to_checkpoint>\n```\n\n#### Low VRAM\n1. Use a lower value for `num_latents` (try 8192).\n2. Use a lower `chunk_size` (try 2048).\n3. Use the `--low_vram` flag in gradio_app.py and infer_dit_refine.py.\n\n### 4. Data Preparation & Training\n\nFirst, prepare the data, including watertight meshes and rendered images.\nThen, run the sampling script as follows:\n```bash\npython scripts/sampling.py \\\n    --mesh_json data/mesh_paths.json \\\n    --output_dir data/sample\n```\n\nHere, `mesh_json` is a JSON file containing a list of the watertight mesh file paths.\n\n\nThe multi-node training script is:\n```bash\nsh train.sh [node_idx]\n```\n\n**training_data_list**: The folder containing train.json and val.json, which store the dataset ID lists.\n\n**sample_pcd_dir**: The directory containing the sampled .npz files.\n\n**image_data_json**: The JSON file containing the paths of the rendered images.\n\nYou can switch between VAE and DiT training in train.sh, and specify the output directory and configuration file there as well.\n\n## 🔗 BibTeX\n\nIf you find this repository helpful, please cite our report:\n\n```bibtex\n@article{jia2025ultrashape,\n    title={UltraShape 1.0: High-Fidelity 3D Shape Generation via Scalable Geometric Refinement},\n    author={Jia, Tanghui and Yan, Dongyu and Hao, Dehao and Li, Yang and Zhang, Kaiyi and He, Xianyi and Li, Lanjiong and Chen, Jinnan and Jiang, Lutao and Yin, Qishen and Quan, Long and Chen, Ying-Cong and Yuan, Li},\n    journal={arXiv preprint arXiv:2512.21185},\n    year={2025}\n}\n```\n\n## Acknowledgements\n\nOur code is built upon the excellent work of [Hunyuan3D-2.1](https://github.com/Tencent-Hunyuan/Hunyuan3D-2.1). The core idea of our method is greatly inspired by [LATTICE](https://arxiv.org/abs/2512.03052). We deeply appreciate the contributions of these works to the 3D generation community. 
Please also consider citing **Hunyuan3D 2.1** and **LATTICE**:\n\n- **[Hunyuan3D-2.1](https://github.com/Tencent-Hunyuan/Hunyuan3D-2.1)**\n- **[Lattice3D](https://lattice3d.github.io/)**\n"
  },
  {
    "path": "configs/infer_dit_refine.yaml",
    "content": "model:\n  target: ultrashape.pipelines.UltraShapePipeline\n  params:\n    # 1. VAE Config\n    vae_config:\n      target: ultrashape.models.autoencoders.ShapeVAE\n      params:\n        num_latents: &token_num 32768 # infer token_num\n        embed_dim: 64\n        num_freqs: 8\n        include_pi: false\n        heads: 16\n        width: 1024\n        point_feats: 4\n        num_encoder_layers: 8\n        num_decoder_layers: 16\n        pc_size: 409600 # num_s (204800) + num_u (204800)\n        pc_sharpedge_size: 0\n        downsample_ratio: 20\n        qkv_bias: false\n        qk_norm: true\n        scale_factor: 1.0039506158752403\n        geo_decoder_mlp_expand_ratio: 4\n        # geo_decoder_downsample_ratio: 1\n        geo_decoder_ln_post: true\n        enable_flashvdm: true\n        jitter_query: false\n        voxel_query: true\n        voxel_query_res: &voxel_query_res 128\n      \n    # 2. DiT Denoiser Config\n    dit_cfg:\n      target: ultrashape.models.denoisers.dit_mask.RefineDiT\n      params:\n        input_size: *token_num\n        in_channels: 64\n        hidden_size: 2048\n        context_dim: 1024\n        depth: 21\n        num_heads: 16\n        qk_norm: true\n        text_len: 1370\n        qk_norm_type: 'rms'\n        qkv_bias: false\n        num_moe_layers: 6\n        num_experts: 8\n        moe_top_k: 2\n        voxel_query_res: *voxel_query_res\n\n    # 3. Image Encoder Config\n    conditioner_config:\n      target: ultrashape.models.conditioner_mask.SingleImageEncoder\n      params:\n        drop_ratio: 0.0\n        main_image_encoder:\n            type: DinoImageEncoder \n            kwargs:\n                version: 'facebook/dinov2-large' \n                image_size: 1022\n                use_cls_token: true\n        \n    # 4. Scheduler Config\n    scheduler_cfg:\n      target: ultrashape.schedulers.FlowMatchEulerDiscreteScheduler\n      params:\n        num_train_timesteps: 1000\n\n    # 5. Image Processor\n    image_processor_cfg:\n      target: ultrashape.preprocessors.ImageProcessorV2\n      params: \n        size: 1024\n"
  },
  {
    "path": "configs/train_dit_refine.yaml",
    "content": "name: \"UltraShape Refine DiT\"\n\ntraining:\n  # ckpt_path:\n  steps: 10_0000_0000\n  use_amp: true\n  amp_type: \"bf16\"\n  base_lr: 1e-5\n  gradient_clip_val: 1.0\n  gradient_clip_algorithm: \"norm\"\n  every_n_train_steps: 2500\n  val_check_interval: 1000\n  limit_val_batches: 16\n  accumulate_grad_batches: 4\n\ndataset:\n  target: ultrashape.data.objaverse_dit.ObjaverseDataModule\n  params:\n    batch_size: 1\n    num_workers: 4\n    val_num_workers: 4\n\n    # data\n    training_data_list: data/data_list\n    sample_pcd_dir: data/sample\n    image_data_json: data/render.json\n\n    # image\n    image_size: &image_size 1022  # 518\n    mean: &mean [0.5, 0.5, 0.5]\n    std: &std [0.5, 0.5, 0.5]\n    padding: true\n\n    # input_pcd\n    pc_size: &pc_size 163840\n    pc_sharpedge_size: &pc_sharpedge_size 0\n    sharpedge_label: &sharpedge_label true\n    return_normal: true\n\nmodel:\n  target: ultrashape.models.diffusion.flow_matching_dit_trainer.Diffuser\n  params:\n    ckpt_path: ckpt/dit_step=XXX.ckpt\n    scale_by_std: false\n    z_scale_factor: &z_scale_factor 1.0039506158752403\n    torch_compile: false\n\n    vae_config:\n      target: ultrashape.models.autoencoders.ShapeVAE\n      from_pretrained: ckpt/vae_step=XXX.ckpt\n      params:\n        num_latents: &num_latents 8192  # 4096\n        embed_dim: 64\n        num_freqs: 8\n        include_pi: false\n        heads: 16\n        width: 1024\n        point_feats: 4\n        num_encoder_layers: 8\n        num_decoder_layers: 16\n        pc_size: *pc_size\n        pc_sharpedge_size: *pc_sharpedge_size\n        downsample_ratio: 20\n        qkv_bias: false\n        qk_norm: true\n        scale_factor: *z_scale_factor\n        geo_decoder_mlp_expand_ratio: 4\n        geo_decoder_downsample_ratio: 1\n        geo_decoder_ln_post: true\n        enable_flashvdm: true\n        jitter_query: false\n        voxel_query: true\n        voxel_query_res: 128\n\n    cond_config:\n      target: ultrashape.models.conditioner_mask.SingleImageEncoder\n      params:\n        drop_ratio: 0.1\n        # disable_drop: false\n        main_image_encoder:\n            type: DinoImageEncoder \n            kwargs:\n                version: 'facebook/dinov2-large'\n                image_size: *image_size\n                use_cls_token: true\n\n    dit_cfg:\n      target: ultrashape.models.denoisers.dit_mask.RefineDiT\n      params:\n        input_size: *num_latents\n        in_channels: 64\n        hidden_size: 2048\n        context_dim: 1024\n        depth: 21\n        num_heads: 16\n        qk_norm: true\n        text_len: 5330  # 1370\n        qk_norm_type: 'rms'\n        qkv_bias: false\n        num_moe_layers: 6\n        num_experts: 8\n        moe_top_k: 2\n        \n    scheduler_cfg:\n      transport:\n        target: ultrashape.models.diffusion.transport.create_transport\n        params:\n          path_type: Linear\n          prediction: velocity\n      sampler:\n        target: ultrashape.models.diffusion.transport.Sampler\n        params: {}\n        ode_params:\n          sampling_method: euler\n          num_steps: &num_steps 50\n\n    optimizer_cfg:\n      optimizer:\n        target: torch.optim.AdamW\n        params:\n          betas: [0.9, 0.99]\n          eps: 1.e-6\n          weight_decay: 1.e-2\n\n      scheduler:\n        target: ultrashape.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler\n        params:\n          warm_up_steps: 500 # 5000\n          f_start: 1.e-6\n          f_min: 1.e-3\n          
f_max: 1.0\n\n    pipeline_cfg:\n      target: ultrashape.pipelines.UltraShapePipeline\n\n    image_processor_cfg:\n      target: ultrashape.preprocessors.ImageProcessorV2\n      params: {}\n"
  },
  {
    "path": "configs/train_vae_refine.yaml",
    "content": "name: \"UltraShape Refine VAE\"\n\ntraining:\n  # ckpt_path: \n  steps: 10_0000_0000\n  use_amp: true\n  amp_type: \"bf16\"\n  base_lr: 1e-5\n  gradient_clip_val: 1.0\n  gradient_clip_algorithm: \"norm\"\n  every_n_train_steps: 2500\n  val_check_interval: 1000\n  limit_val_batches: 16\n\ndataset:\n  target: ultrashape.data.objaverse_vae.ObjaverseDataModule\n  params:\n    batch_size: 4\n    num_workers: 4\n    val_num_workers: 4\n\n    # data \n    training_data_list: data/data_list\n    sample_pcd_dir: data/sample\n\n    # input_pcd\n    pc_size: &pc_size 163840\n    pc_sharpedge_size: &pc_sharpedge_size 0\n    sharpedge_label: &sharpedge_label true\n    return_normal: true\n\n    # sup_pcd\n    sup_near_uni_size: 100000\n    sup_near_sharp_size: 100000\n    sup_space_size: 100000\n    tsdf_threshold: 0.01\n\nmodel:\n  target: ultrashape.models.autoencoders.VAETrainer\n  params:\n    ckpt_path: ckpt/vae_step=15000.ckpt\n    torch_compile: false\n    save_dir: outputs/vae_recon\n    mc_res: 512\n    vae_config:\n      target: ultrashape.models.autoencoders.ShapeVAE\n      params:\n        num_latents: &num_latents 8192 # 4096\n        embed_dim: 64\n        num_freqs: 8\n        include_pi: false\n        heads: 16\n        width: 1024\n        point_feats: 4\n        num_encoder_layers: 8\n        num_decoder_layers: 16\n        pc_size: *pc_size\n        pc_sharpedge_size: *pc_sharpedge_size\n        downsample_ratio: 20\n        qkv_bias: false\n        qk_norm: true\n        geo_decoder_mlp_expand_ratio: 4\n        geo_decoder_downsample_ratio: 1\n        geo_decoder_ln_post: true\n        enable_flashvdm: true\n        jitter_query: true\n\n    optimizer_cfg:\n      optimizer:\n        target: torch.optim.AdamW\n        params:\n          betas: [0.9, 0.99]\n          eps: 1.e-6\n          weight_decay: 1.e-2\n\n      scheduler:\n        target: ultrashape.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler\n        params:\n          warm_up_steps: 500 # 5000\n          f_start: 1.e-6\n          f_min: 1.e-3\n          f_max: 1.0\n        \n    loss_cfg:\n      lambda_logits: 1.\n      lambda_kl: 0.001\n      # lambda_eik: -1.\n      # lambda_sn: -1.\n      # lambda_sign: -1.\n"
  },
  {
    "path": "docs/carousel.css",
    "content": ".x-carousel-tags {\n    width: 100%;\n    display: flex;\n    align-items: center;\n    justify-content: left;\n    flex-wrap: wrap;\n}\n\n.x-carousel-tag {\n    background-color: rgba(255, 255, 255, 0.9);\n    box-shadow: rgba(0, 0, 0, 0.1) 0px 2px 4px;\n    border: 2px solid transparent;\n    border-radius: 8px;\n    cursor: pointer;\n    display: flex;\n    flex-direction: column;\n    align-items: center;\n    justify-content: center;\n    transition: all 0.3s ease;\n    color: #2a2a2a;\n    margin: 4px;\n    padding: 8px 16px;\n    text-align: center;\n}\n\n.x-carousel-tag:hover {\n    background-color: rgba(255, 255, 255, 1);\n    transform: translateY(-2px);\n    box-shadow: rgba(0, 0, 0, 0.15) 0px 4px 6px;\n}\n\n.x-carousel-tag.active { border-color: #666; background-color: rgba(255, 255, 255, 1); }\n\n.x-carousel-slider {\n    width: 100%;\n    display: flex;\n    align-items: center;\n    justify-content: center;\n    flex-wrap: wrap;\n}\n\n.x-carousel-slider-item {\n    max-width: 100%;\n    width: 100%;\n    flex: 1 0 0;\n}\n\n.x-carousel-nav {\n    width: 100%;\n    height: 40px;\n    display: flex;\n    align-items: center;\n    justify-content: space-between;\n}\n\n.x-carousel-switch {\n    width: 50px;\n    height: 25px;\n    margin: 8px;\n    border-radius: 25px;\n    cursor: pointer;\n    user-select: none;\n    color: rgba(42, 42, 42, 0.4);\n    font-size: 24px;\n    font-weight: 500;\n    transition: all 0.25s ease;\n    display: flex;\n    align-items: center;\n    justify-content: center;\n}\n\n.x-carousel-switch:hover {\n    color: rgba(42, 42, 42, 0.8);\n    transform: scale(1.05);\n}\n\n.x-carousel-pages {\n    display: flex;\n    align-items: center;\n    justify-content: center;\n    flex-wrap: wrap;\n}\n\n.x-carousel-page {\n    width: 10px;\n    height: 10px;\n    border-radius: 10px;\n    background: rgba(42, 42, 42, 0.2);\n    /* background: linear-gradient(107.54deg, #0078d4 .39%, #8661c5 51.23%, #ff9349 100%) fixed; */\n    margin: 0 3px;\n    cursor: pointer;\n    transition: all 0.25s ease;\n}\n\n.x-carousel-page:hover {\n    background: rgba(42, 42, 42, 0.5);\n}\n\n.x-carousel-page.x-carousel-page-active {\n    width: 12px;\n    height: 13px;\n}\n"
  },
  {
    "path": "docs/carousel.js",
    "content": "/**\n * Carousel functionality for research project page\n * Handles navigation, filtering, and page indicators for carousel components\n */\n\n(function() {\n    'use strict';\n\n    /**\n     * Initialize carousel functionality\n     * @param {string} carouselId - ID of the carousel container\n     */\n    function initCarousel(carouselId) {\n        const carousel = document.getElementById(carouselId);\n        if (!carousel) return;\n\n        const slider = carousel.querySelector('.x-carousel-slider');\n        const allItems = carousel.querySelectorAll('.x-carousel-slider-item');\n        const prevBtn = carousel.querySelector('.x-carousel-nav .x-carousel-switch:first-child');\n        const nextBtn = carousel.querySelector('.x-carousel-nav .x-carousel-switch:last-child');\n        const pages = carousel.querySelectorAll('.x-carousel-page');\n        const tags = carousel.querySelectorAll('.x-carousel-tag');\n\n        if (!slider || !allItems.length) return;\n\n        let currentIndex = 0;\n        let currentFilter = 'all'; // Current filter: 'all', 'class1', 'class2', 'class3'\n        let filteredItems = Array.from(allItems); // Currently visible items\n\n        /**\n         * Filter items by tag\n         * @param {string} filter - Filter value: 'all', 'class1', 'class2', 'class3'\n         */\n        function filterItems(filter) {\n            currentFilter = filter;\n            \n            // Filter items based on data-tag attribute\n            if (filter === 'all') {\n                filteredItems = Array.from(allItems);\n            } else {\n                filteredItems = Array.from(allItems).filter(item => {\n                    return item.getAttribute('data-tag') === filter;\n                });\n            }\n\n            // Reset to first item after filtering\n            currentIndex = 0;\n            \n            // Update visibility of all items\n            allItems.forEach(item => {\n                if (filteredItems.includes(item)) {\n                    item.style.display = 'block';\n                } else {\n                    item.style.display = 'none';\n                }\n            });\n\n            // Show first filtered item\n            goToSlide(0);\n            updatePages();\n        }\n\n        /**\n         * Navigate to a specific slide (within filtered items)\n         * @param {number} index - Index of the slide to show\n         */\n        function goToSlide(index) {\n            const totalItems = filteredItems.length;\n            currentIndex = Math.max(0, Math.min(index, totalItems - 1));\n\n            // Hide all filtered items\n            filteredItems.forEach(item => {\n                item.style.display = 'none';\n            });\n\n            // Show current item - use block instead of flex to preserve card's flex layout\n            if (filteredItems[currentIndex]) {\n                filteredItems[currentIndex].style.display = 'block';\n            }\n\n            updatePages();\n            updateButtons();\n        }\n\n        /**\n         * Update page indicators to reflect current slide\n         */\n        function updatePages() {\n            const totalItems = filteredItems.length;\n            pages.forEach((page, index) => {\n                if (index === currentIndex && index < totalItems) {\n                    page.classList.add('x-carousel-page-active');\n                } else {\n                    page.classList.remove('x-carousel-page-active');\n                }\n            
});\n            \n            // Hide unused page indicators\n            pages.forEach((page, index) => {\n                if (index >= totalItems) {\n                    page.style.display = 'none';\n                } else {\n                    page.style.display = '';\n                }\n            });\n        }\n\n        /**\n         * Update navigation buttons state (enable/disable at boundaries)\n         */\n        function updateButtons() {\n            const totalItems = filteredItems.length;\n            if (prevBtn) {\n                prevBtn.style.opacity = currentIndex === 0 ? '0.3' : '1';\n                prevBtn.style.cursor = currentIndex === 0 ? 'not-allowed' : 'pointer';\n            }\n            if (nextBtn) {\n                nextBtn.style.opacity = currentIndex === totalItems - 1 ? '0.3' : '1';\n                nextBtn.style.cursor = currentIndex === totalItems - 1 ? 'not-allowed' : 'pointer';\n            }\n        }\n\n        /**\n         * Tag filtering (only for results-gen carousel)\n         */\n        if (tags.length && carouselId === 'results-gen') {\n            tags.forEach((tag) => {\n                tag.addEventListener('click', function() {\n                    const filter = tag.getAttribute('data-filter');\n                    if (!filter) return;\n                    \n                    // Remove active class from all tags\n                    tags.forEach(t => t.classList.remove('active'));\n                    // Add active class to clicked tag\n                    tag.classList.add('active');\n                    // Filter items\n                    filterItems(filter);\n                });\n            });\n        }\n\n        // Previous/Next buttons\n        if (prevBtn) {\n            prevBtn.addEventListener('click', function() {\n                if (currentIndex > 0) {\n                    goToSlide(currentIndex - 1);\n                }\n            });\n        }\n\n        if (nextBtn) {\n            nextBtn.addEventListener('click', function() {\n                const totalItems = filteredItems.length;\n                if (currentIndex < totalItems - 1) {\n                    goToSlide(currentIndex + 1);\n                }\n            });\n        }\n\n        // Page indicators - click to jump to specific slide\n        pages.forEach((page, index) => {\n            page.addEventListener('click', function() {\n                goToSlide(index);\n            });\n        });\n\n        // Initialize - filter to 'all' and show first item\n        filterItems('all');\n    }\n\n    // Initialize all carousels when DOM is ready\n    document.addEventListener('DOMContentLoaded', function() {\n        initCarousel('results-gen');\n        initCarousel('results-recon');\n    });\n})();\n\n\ndocument.addEventListener('DOMContentLoaded', function() {\n    \n    // 1. 
Pre-create a container for displaying the prompt image (hidden initially)\n    const promptImgContainer = document.createElement('div');\n    promptImgContainer.id = 'glb-prompt-image-container';\n    promptImgContainer.style.cssText = `\n        position: fixed; \n        bottom: 20px; \n        right: 20px; \n        width: 200px; \n        height: 200px; \n        z-index: 10000; /* keep it above everything else */\n        display: none; /* hidden by default */\n        background-color: white;\n        padding: 5px;\n        border-radius: 8px;\n        box-shadow: 0 4px 12px rgba(0,0,0,0.3);\n        cursor: pointer; /* hint that clicking closes it */\n    `;\n    \n    // Create the image element\n    const promptImg = document.createElement('img');\n    promptImg.style.cssText = `\n        width: 100%; \n        height: 100%; \n        object-fit: contain; \n        display: block;\n    `;\n    promptImgContainer.appendChild(promptImg);\n\n    // Add a close hint label (optional)\n    const closeTip = document.createElement('div');\n    closeTip.innerText = \"Click to close\";\n    closeTip.style.cssText = \"position:absolute; top:-25px; right:0; color:white; font-size:12px; background:rgba(0,0,0,0.5); padding:2px 5px; border-radius:4px;\";\n    promptImgContainer.appendChild(closeTip);\n\n    document.body.appendChild(promptImgContainer);\n\n    // Hide the container when it is clicked\n    promptImgContainer.addEventListener('click', function() {\n        this.style.display = 'none';\n    });\n\n\n    // 2. Attach a click handler to every View GLB button\n    const buttons = document.querySelectorAll('.x-button');\n    \n    buttons.forEach(btn => {\n        btn.addEventListener('click', function(e) {\n            // Read the data-prompt attribute defined in the HTML (assets/images/1.png)\n            const imgUrl = this.getAttribute('data-prompt');\n            \n            if (imgUrl) {\n                promptImg.src = imgUrl;\n                promptImgContainer.style.display = 'block'; // show the image\n            }\n        });\n    });\n\n    // 3. (Optional) If your GLB viewer has its own close button (e.g. class .close-viewer),\n    // add logic here so the prompt image is hidden when the viewer is closed.\n    // Assuming the close button class is .close-btn (confirm the actual class name)\n    /*\n    const closeGlbBtn = document.querySelector('.close-btn-class-name');\n    if(closeGlbBtn) {\n        closeGlbBtn.addEventListener('click', () => {\n             promptImgContainer.style.display = 'none';\n        });\n    }\n    */\n});"
  },
  {
    "path": "docs/index copy.html",
    "content": "<html lang=\"en\"><head><style type=\"text/css\">\n.anticon {\n  display: inline-block;\n  color: inherit;\n  font-style: normal;\n  line-height: 0;\n  text-align: center;\n  text-transform: none;\n  vertical-align: -0.125em;\n  text-rendering: optimizeLegibility;\n  -webkit-font-smoothing: antialiased;\n  -moz-osx-font-smoothing: grayscale;\n}\n\n.anticon > * {\n  line-height: 1;\n}\n\n.anticon svg {\n  display: inline-block;\n}\n\n.anticon::before {\n  display: none;\n}\n\n.anticon .anticon-icon {\n  display: block;\n}\n\n.anticon[tabindex] {\n  cursor: pointer;\n}\n\n.anticon-spin::before,\n.anticon-spin {\n  display: inline-block;\n  -webkit-animation: loadingCircle 1s infinite linear;\n  animation: loadingCircle 1s infinite linear;\n}\n\n@-webkit-keyframes loadingCircle {\n  100% {\n    -webkit-transform: rotate(360deg);\n    transform: rotate(360deg);\n  }\n}\n\n@keyframes loadingCircle {\n  100% {\n    -webkit-transform: rotate(360deg);\n    transform: rotate(360deg);\n  }\n}\n</style>\n        <meta charset=\"UTF-8\">\n        <meta name=\"viewport\" content=\"width=device-width, initial-scale=1.0\">\n        <title>UltraShape 1.0: High-Fidelity 3D Shape Generation via Scalable Geometric Refinement</title>\n        <!-- TODO: Replace with UltraShape 1.0 favicon -->\n        <link rel=\"icon\" href=\"assets/favicon.png\">\n        <link rel=\"stylesheet\" href=\"stylesheet.css\">\n        <link rel=\"stylesheet\" href=\"style.css\">\n        <link rel=\"stylesheet\" href=\"window.css\">\n        <link rel=\"stylesheet\" href=\"carousel.css\">\n        <link rel=\"stylesheet\" href=\"main.css\">\n        <link rel=\"stylesheet\" href=\"pv.css\">\n        <link href=\"https://fonts.googleapis.com/icon?family=Material+Icons\" rel=\"stylesheet\">\n        <script type=\"module\" src=\"https://ajax.googleapis.com/ajax/libs/model-viewer/3.3.0/model-viewer.min.js\"></script>\n        <script src=\"carousel.js\" defer></script>\n    </head>\n    <body>\n        <div class=\"animated-gradient hero-expandable\" style=\"width: 100%; height: 90vh; position: relative; overflow: hidden;\">\n    \n            <img \n                src=\"assets\\images\\teaser.png\" \n                alt=\"Hero background\" \n                style=\"position: absolute; top: 0; left: 0; width: 100%; height: 100%; object-fit: cover; opacity: 0.9;\"\n            >\n            <div style=\"position: relative; z-index: 1; height: 100%; display: flex; flex-direction: column; justify-content: center; align-items: center; color: #ffffff; padding: 80px 20px; text-align: center;\">\n\n                <link href=\"https://fonts.googleapis.com/css2?family=Outfit:wght@400;700&family=Inter:wght@300;400&display=swap\" rel=\"stylesheet\">\n\n                <div style=\"font-family: 'Outfit', sans-serif;\">\n                    \n                    <div style=\"\n                        font-size: 100px; \n                        font-weight: 700; \n                        margin-bottom: 24px; \n                        background: linear-gradient(135deg, #ffffff 0%, #c3ecff 100%);\n                        -webkit-background-clip: text;\n                        -webkit-text-fill-color: transparent;\n                        letter-spacing: -2px; \n                        text-shadow: 0px 10px 30px rgba(0,0,0,0.3); \n                    \">\n                        UltraShape 1.0\n                    </div>\n\n                    <div style=\"\n                        font-family: 'Inter', sans-serif;\n             
           font-size: 60px; \n                        font-weight: 600; \n                        color: rgba(255, 255, 255, 0.9);\n                        margin-bottom: 48px; \n                        max-width: 1200px; \n                        line-height: 1.6;\n                        letter-spacing: 1px; \n                    \">\n                        High-Fidelity 3D Shape Generation via Scalable Geometric Refinement\n                    </div>\n\n                </div>\n\n                <div style=\"font-size: 22px; font-weight: 400; margin-bottom: 48px; max-width: 1200px; line-height: 1.8; opacity: 0.95;\">\n                    We introduce UltraShape-1.0, a scalable two-stage diffusion framework for high-quality 3D geometry generation, enhanced by an advanced data processing pipeline that ensures geometric details through watertight processing and quality filtering.\n                </div>\n                <div id=\"links\" style=\"display: flex; gap: 16px; flex-wrap: wrap; justify-content: center;\">\n                    <div><a id=\"paper\" href=\"https://arxiv.org/pdf/2512.21185\">Tech Report</a></div>\n                    <div><a id=\"code\" href=\"https://github.com/PKU-YuanGroup/UltraShape-1.0\">Code</a></div>\n                    <div><a id=\"demo\" href=\"#demo\">Demo</a></div>\n                </div>\n            </div>\n        </div>\n\n        <div id=\"main\" style=\"background: #faf9f7;\">\n            <div class=\"x-section-title\" style=\"text-align: center; margin-top: 80px;\">Main Pipeline</div>\n            \n            <div style=\"background: #f8f9fa; padding: 80px 20px;\">\n                <div style=\"max-width: 1200px; margin: 40px auto; border-radius: 16px; overflow: hidden; box-shadow: 0 10px 40px rgba(0,0,0,0.15); position: relative; padding: 20px; background-color: #ffffff;\">\n                    \n                    <img src=\"assets/images/pipeline.png?v=2\" alt=\"Video cover\" style=\"width: 100%; height: auto; display: block;\">\n                    \n                    \n                </div>\n            </div>    \n\n            <div class=\"x-section-title\"  style=\"text-align: center; margin-top: 50px;\">Image to 3D Shape Generation</div>\n            <div id=\"results-gen\">\n                <div class=\"x-carousel-tags\">\n                    <!-- <div class=\"x-carousel-tag active\" data-filter=\"all\">All</div> -->\n                    <!-- <div class=\"x-carousel-tag\" data-filter=\"class1\">Object</div> -->\n                    <!-- <div class=\"x-carousel-tag\" data-filter=\"class2\">Multi-Object</div> -->\n                    <!-- <div class=\"x-carousel-tag\" data-filter=\"class3\">Character</div> -->\n                    <!-- <div class=\"x-carousel-tag\" data-filter=\"class4\">Car</div> -->\n                </div>\n                <div class=\"x-carousel-slider\">\n                    <!-- Item 1 -->\n                    <div class=\"x-carousel-slider-item\" data-tag=\"class1\" style=\"flex-basis: calc(100% / 1);\">\n                        <div class=\"x-card\" style=\"height: 70vh;\">\n                            <div class=\"x-button\" \n                                style=\"position: absolute; top: 10px; right: 85px; padding: 15px 30px; \n                                        font-size: 22px; font-weight: bold; z-index: 10; cursor: pointer;\" \n                                data-glb=\"https://huggingface.co/datasets/infinith/ultrashape_page/resolve/main/1.glb\">\n                                View GLB\n         
                   </div>\n                            <!-- <div class=\"x-button\" style=\"position: absolute; top: 10px; right: 10px; padding: 12px 24px; font-size: 24px;\" data-glb=\"assets/meshs/1.glb\" data-prompt=\"assets/images/1.png\">View GLB</div> -->\n                            <div style=\"width: 85%; aspect-ratio: 1; background-color: rgba(0,0,0,0.05); border-radius: 4px; overflow: hidden;\">\n                                <video \n                                    autoplay \n                                    loop \n                                    muted \n                                    playsinline \n                                    style=\"width: 100%; height: 100%; object-fit: contain;\">\n                                    <source src=\"assets/videos/1.mp4\" type=\"video/mp4\">\n                                </video>\n                            </div>\n                            <div style=\"width: 200px; aspect-ratio: 1; background-color: transparent; border-radius: 4px; overflow: hidden;\">\n                                <img \n                                    src=\"assets/images/1.png\" \n                                    alt=\"Input Image\"\n                                    style=\"width: 100%; height: 100%; object-fit: contain; display: block;\"\n                                >\n                            </div>\n                        </div>\n                    </div>\n                    <!-- Item 2 -->\n                    <div class=\"x-carousel-slider-item\" data-tag=\"class1\" style=\"flex-basis: calc(100% / 1);\">\n                        <div class=\"x-card\" style=\"height: 70vh;\">\n                            <!-- <div class=\"x-button\" style=\"position: absolute; top: 10px; right: 10px\" data-glb=\"assets/meshs/2.glb\" data-prompt=\"Input Prompt 1\">View GLB</div> -->\n                            <div style=\"width: 85%; aspect-ratio: 1; background-color: rgba(0,0,0,0.05); border-radius: 4px; overflow: hidden;\">\n                                <video \n                                    autoplay \n                                    loop \n                                    muted \n                                    playsinline \n                                    style=\"width: 100%; height: 100%; object-fit: contain;\">\n                                    <source src=\"assets/videos/2.mp4\" type=\"video/mp4\">\n                                </video>\n                            </div>\n                            <div style=\"width: 200px; aspect-ratio: 1; background-color: transparent; border-radius: 4px; overflow: hidden;\">\n                                <img \n                                    src=\"assets/images/2.png\" \n                                    alt=\"Input Image\"\n                                    style=\"width: 100%; height: 100%; object-fit: contain; display: block;\"\n                                >\n                            </div>\n                        </div>\n                    </div>\n                    <!-- Item 3 -->\n                    <div class=\"x-carousel-slider-item\" data-tag=\"class3\" style=\"flex-basis: calc(100% / 1);\">\n                        <div class=\"x-card\" style=\"height: 70vh;\">\n                            <!-- <div class=\"x-button\" style=\"position: absolute; top: 10px; right: 10px\" data-glb=\"assets/meshs/3.glb\" data-prompt=\"Input Prompt 1\">View GLB</div> -->\n                            <div style=\"width: 85%; aspect-ratio: 1; 
background-color: rgba(0,0,0,0.05); border-radius: 4px; overflow: hidden;\">\n                                <video \n                                    autoplay \n                                    loop \n                                    muted \n                                    playsinline \n                                    style=\"width: 100%; height: 100%; object-fit: contain;\">\n                                    <source src=\"assets/videos/3.mp4\" type=\"video/mp4\">\n                                </video>\n                            </div>\n                            <div style=\"width: 200px; aspect-ratio: 1; background-color: transparent; border-radius: 4px; overflow: hidden;\">\n                                <img \n                                    src=\"assets/images/3.png\" \n                                    alt=\"Input Image\"\n                                    style=\"width: 100%; height: 100%; object-fit: contain; display: block;\"\n                                >\n                            </div>\n                        </div>\n                    </div>\n                    <!-- Item 4 -->\n                    <div class=\"x-carousel-slider-item\" data-tag=\"class3\" style=\"flex-basis: calc(100% / 1);\">\n                        <div class=\"x-card\" style=\"height: 70vh;\">\n                            <!-- <div class=\"x-button\" style=\"position: absolute; top: 10px; right: 10px\" data-glb=\"assets/meshs/4.glb\" data-prompt=\"Input Prompt 1\">View GLB</div> -->\n                            <div style=\"width: 85%; aspect-ratio: 1; background-color: rgba(0,0,0,0.05); border-radius: 4px; overflow: hidden;\">\n                                <video \n                                    autoplay \n                                    loop \n                                    muted \n                                    playsinline \n                                    style=\"width: 100%; height: 100%; object-fit: contain;\">\n                                    <source src=\"assets/videos/4.mp4\" type=\"video/mp4\">\n                                </video>\n                            </div>\n                            <div style=\"width: 200px; aspect-ratio: 1; background-color: transparent; border-radius: 4px; overflow: hidden;\">\n                                <img \n                                    src=\"assets/images/4.png\" \n                                    alt=\"Input Image\"\n                                    style=\"width: 100%; height: 100%; object-fit: contain; display: block;\"\n                                >\n                            </div>\n                        </div>\n                    </div>\n                    <!-- Item 5 -->\n                    <div class=\"x-carousel-slider-item\" data-tag=\"class1\" style=\"flex-basis: calc(100% / 1);\">\n                        <div class=\"x-card\" style=\"height: 70vh;\">\n                            <!-- <div class=\"x-button\" style=\"position: absolute; top: 10px; right: 10px\" data-glb=\"assets/meshs/5.glb\" data-prompt=\"Input Prompt 1\">View GLB</div> -->\n                            <div style=\"width: 85%; aspect-ratio: 1; background-color: rgba(0,0,0,0.05); border-radius: 4px; overflow: hidden;\">\n                                <video \n                                    autoplay \n                                    loop \n                                    muted \n                                    playsinline \n                                  
  style=\"width: 100%; height: 100%; object-fit: contain;\">\n                                    <source src=\"assets/videos/5.mp4\" type=\"video/mp4\">\n                                </video>\n                            </div>\n                            <div style=\"width: 200px; aspect-ratio: 1; background-color: transparent; border-radius: 4px; overflow: hidden;\">\n                                <img \n                                    src=\"assets/images/5.png\" \n                                    alt=\"Input Image\"\n                                    style=\"width: 100%; height: 100%; object-fit: contain; display: block;\"\n                                >\n                            </div>\n                        </div>\n                    </div>\n                    <!-- Item 6 -->\n                    <div class=\"x-carousel-slider-item\" data-tag=\"class1\" style=\"flex-basis: calc(100% / 1);\">\n                        <div class=\"x-card\" style=\"height: 70vh;\">\n                            <!-- <div class=\"x-button\" style=\"position: absolute; top: 10px; right: 10px\" data-glb=\"assets/meshs/6.glb\" data-prompt=\"Input Prompt 1\">View GLB</div> -->\n                            <div style=\"width: 85%; aspect-ratio: 1; background-color: rgba(0,0,0,0.05); border-radius: 4px; overflow: hidden;\">\n                                <video \n                                    autoplay \n                                    loop \n                                    muted \n                                    playsinline \n                                    style=\"width: 100%; height: 100%; object-fit: contain;\">\n                                    <source src=\"assets/videos/6.mp4\" type=\"video/mp4\">\n                                </video>\n                            </div>\n                            <div style=\"width: 200px; aspect-ratio: 1; background-color: transparent; border-radius: 4px; overflow: hidden;\">\n                                <img \n                                    src=\"assets/images/6.png\" \n                                    alt=\"Input Image\"\n                                    style=\"width: 100%; height: 100%; object-fit: contain; display: block;\"\n                                >\n                            </div>\n                        </div>\n                    </div>\n                    <!-- Item 7 -->\n                    <div class=\"x-carousel-slider-item\" data-tag=\"class1\" style=\"flex-basis: calc(100% / 1);\">\n                        <div class=\"x-card\" style=\"height: 70vh;\">\n                            <!-- <div class=\"x-button\" style=\"position: absolute; top: 10px; right: 10px\" data-glb=\"assets/meshs/7.glb\" data-prompt=\"Input Prompt 1\">View GLB</div> -->\n                            <div style=\"width: 85%; aspect-ratio: 1; background-color: rgba(0,0,0,0.05); border-radius: 4px; overflow: hidden;\">\n                                <video \n                                    autoplay \n                                    loop \n                                    muted \n                                    playsinline \n                                    style=\"width: 100%; height: 100%; object-fit: contain;\">\n                                    <source src=\"assets/videos/7.mp4\" type=\"video/mp4\">\n                                </video>\n                            </div>\n                            <div style=\"width: 200px; aspect-ratio: 1; background-color: transparent; 
border-radius: 4px; overflow: hidden;\">\n                                <img \n                                    src=\"assets/images/7.png\" \n                                    alt=\"Input Image\"\n                                    style=\"width: 100%; height: 100%; object-fit: contain; display: block;\"\n                                >\n                            </div>\n                        </div>\n                    </div>\n                    <!-- Item 8 -->\n                    <div class=\"x-carousel-slider-item\" data-tag=\"class3\" style=\"flex-basis: calc(100% / 1);\">\n                        <div class=\"x-card\" style=\"height: 70vh;\">\n                            <!-- <div class=\"x-button\" style=\"position: absolute; top: 10px; right: 10px\" data-glb=\"assets/meshs/8.glb\" data-prompt=\"Input Prompt 1\">View GLB</div> -->\n                            <div style=\"width: 85%; aspect-ratio: 1; background-color: rgba(0,0,0,0.05); border-radius: 4px; overflow: hidden;\">\n                                <video \n                                    autoplay \n                                    loop \n                                    muted \n                                    playsinline \n                                    style=\"width: 100%; height: 100%; object-fit: contain;\">\n                                    <source src=\"assets/videos/8.mp4\" type=\"video/mp4\">\n                                </video>\n                            </div>\n                            <div style=\"width: 200px; aspect-ratio: 1; background-color: transparent; border-radius: 4px; overflow: hidden;\">\n                                <img \n                                    src=\"assets/images/8.png\" \n                                    alt=\"Input Image\"\n                                    style=\"width: 100%; height: 100%; object-fit: contain; display: block;\"\n                                >\n                            </div>\n                        </div>\n                    </div>\n                    <!-- Item 9 -->\n                    <div class=\"x-carousel-slider-item\" data-tag=\"class3\" style=\"flex-basis: calc(100% / 1);\">\n                        <div class=\"x-card\" style=\"height: 70vh;\">\n                            <!-- <div class=\"x-button\" style=\"position: absolute; top: 10px; right: 10px\" data-glb=\"assets/meshs/9.glb\" data-prompt=\"Input Prompt 1\">View GLB</div> -->\n                            <div style=\"width: 85%; aspect-ratio: 1; background-color: rgba(0,0,0,0.05); border-radius: 4px; overflow: hidden;\">\n                                <video \n                                    autoplay \n                                    loop \n                                    muted \n                                    playsinline \n                                    style=\"width: 100%; height: 100%; object-fit: contain;\">\n                                    <source src=\"assets/videos/9.mp4\" type=\"video/mp4\">\n                                </video>\n                            </div>\n                            <div style=\"width: 200px; aspect-ratio: 1; background-color: transparent; border-radius: 4px; overflow: hidden;\">\n                                <img \n                                    src=\"assets/images/9.png\" \n                                    alt=\"Input Image\"\n                                    style=\"width: 100%; height: 100%; object-fit: contain; display: block;\"\n                      
          >\n                            </div>\n                        </div>\n                    </div>\n                    <!-- Item 10 -->\n                    <div class=\"x-carousel-slider-item\" data-tag=\"class3\" style=\"flex-basis: calc(100% / 1);\">\n                        <div class=\"x-card\" style=\"height: 70vh;\">\n                            <!-- <div class=\"x-button\" style=\"position: absolute; top: 10px; right: 10px\" data-glb=\"assets/meshs/10.glb\" data-prompt=\"Input Prompt 1\">View GLB</div> -->\n                            <div style=\"width: 85%; aspect-ratio: 1; background-color: rgba(0,0,0,0.05); border-radius: 4px; overflow: hidden;\">\n                                <video \n                                    autoplay \n                                    loop \n                                    muted \n                                    playsinline \n                                    style=\"width: 100%; height: 100%; object-fit: contain;\">\n                                    <source src=\"assets/videos/10.mp4\" type=\"video/mp4\">\n                                </video>\n                            </div>\n                            <div style=\"width: 200px; aspect-ratio: 1; background-color: transparent; border-radius: 4px; overflow: hidden;\">\n                                <img \n                                    src=\"assets/images/10.png\" \n                                    alt=\"Input Image\"\n                                    style=\"width: 100%; height: 100%; object-fit: contain; display: block;\"\n                                >\n                            </div>\n                        </div>\n                    </div>\n                    <!-- Item 11 -->\n                    <div class=\"x-carousel-slider-item\" data-tag=\"class1\" style=\"flex-basis: calc(100% / 1);\">\n                        <div class=\"x-card\" style=\"height: 70vh;\">\n                            <!-- <div class=\"x-button\" style=\"position: absolute; top: 10px; right: 10px\" data-glb=\"assets/meshs/11.glb\" data-prompt=\"Input Prompt 1\">View GLB</div> -->\n                            <div style=\"width: 85%; aspect-ratio: 1; background-color: rgba(0,0,0,0.05); border-radius: 4px; overflow: hidden;\">\n                                <video \n                                    autoplay \n                                    loop \n                                    muted \n                                    playsinline \n                                    style=\"width: 100%; height: 100%; object-fit: contain;\">\n                                    <source src=\"assets/videos/11.mp4\" type=\"video/mp4\">\n                                </video>\n                            </div>\n                            <div style=\"width: 200px; aspect-ratio: 1; background-color: transparent; border-radius: 4px; overflow: hidden;\">\n                                <img \n                                    src=\"assets/images/11.png\" \n                                    alt=\"Input Image\"\n                                    style=\"width: 100%; height: 100%; object-fit: contain; display: block;\"\n                                >\n                            </div>\n                        </div>\n                    </div>\n                    <!-- Item 12 -->\n                    <div class=\"x-carousel-slider-item\" data-tag=\"class1\" style=\"flex-basis: calc(100% / 1);\">\n                        <div class=\"x-card\" style=\"height: 
70vh;\">\n                            <!-- <div class=\"x-button\" style=\"position: absolute; top: 10px; right: 10px\" data-glb=\"assets/meshs/12.glb\" data-prompt=\"Input Prompt 1\">View GLB</div> -->\n                            <div style=\"width: 85%; aspect-ratio: 1; background-color: rgba(0,0,0,0.05); border-radius: 4px; overflow: hidden;\">\n                                <video \n                                    autoplay \n                                    loop \n                                    muted \n                                    playsinline \n                                    style=\"width: 100%; height: 100%; object-fit: contain;\">\n                                    <source src=\"assets/videos/12.mp4\" type=\"video/mp4\">\n                                </video>\n                            </div>\n                            <div style=\"width: 200px; aspect-ratio: 1; background-color: transparent; border-radius: 4px; overflow: hidden;\">\n                                <img \n                                    src=\"assets/images/12.png\" \n                                    alt=\"Input Image\"\n                                    style=\"width: 100%; height: 100%; object-fit: contain; display: block;\"\n                                >\n                            </div>\n                        </div>\n                    </div>\n                    <!-- Item 13 -->\n                    <div class=\"x-carousel-slider-item\" data-tag=\"class1\" style=\"flex-basis: calc(100% / 1);\">\n                        <div class=\"x-card\" style=\"height: 70vh;\">\n                            <!-- <div class=\"x-button\" style=\"position: absolute; top: 10px; right: 10px\" data-glb=\"assets/meshs/13.glb\" data-prompt=\"Input Prompt 1\">View GLB</div> -->\n                            <div style=\"width: 85%; aspect-ratio: 1; background-color: rgba(0,0,0,0.05); border-radius: 4px; overflow: hidden;\">\n                                <video \n                                    autoplay \n                                    loop \n                                    muted \n                                    playsinline \n                                    style=\"width: 100%; height: 100%; object-fit: contain;\">\n                                    <source src=\"assets/videos/13.mp4\" type=\"video/mp4\">\n                                </video>\n                            </div>\n                            <div style=\"width: 200px; aspect-ratio: 1; background-color: transparent; border-radius: 4px; overflow: hidden;\">\n                                <img \n                                    src=\"assets/images/13.png\" \n                                    alt=\"Input Image\"\n                                    style=\"width: 100%; height: 100%; object-fit: contain; display: block;\"\n                                >\n                            </div>\n                        </div>\n                    </div>\n                </div>\n                <div class=\"x-carousel-nav\">\n                    <div class=\"x-carousel-switch\">&lt;</div>\n                    <div class=\"x-carousel-pages\">\n                        <div class=\"x-carousel-page x-carousel-page-active\"></div>\n                        <div class=\"x-carousel-page\"></div>\n                        <div class=\"x-carousel-page\"></div>\n                        <div class=\"x-carousel-page\"></div>\n                        <div class=\"x-carousel-page\"></div>\n                  
      <div class=\"x-carousel-page\"></div>\n                        <div class=\"x-carousel-page\"></div>\n                        <div class=\"x-carousel-page\"></div>\n                        <div class=\"x-carousel-page\"></div>\n                        <div class=\"x-carousel-page\"></div>\n                        <div class=\"x-carousel-page\"></div>\n                        <div class=\"x-carousel-page\"></div>\n                        <div class=\"x-carousel-page\"></div>\n                    </div>\n                    <div class=\"x-carousel-switch\">&gt;</div>\n                </div>\n            </div>\n\n            <div class=\"x-section-title\" style=\"text-align: center; margin-top: 150px;\">Comparison of 3D Shape Generation</div>\n\n            <div id=\"results-recon\" style=\"max-width: 1200px; margin: 0 auto; padding: 0 20px;\">\n                <div class=\"x-carousel-slider\">\n                    \n                    <!-- 第一张图 -->\n                    <div class=\"x-carousel-slider-item\" style=\"flex-basis: calc(100% / 1);\">\n                        <!-- 修改1：增加 display: flex; flex-direction: column; -->\n                        <div class=\"x-card\" style=\"height: 86vh; background-color: #faf9f7; border-radius: 4px; display: flex; flex-direction: column; padding: 0 0 40px 0;\">\n                            <div class=\"x-img-caption\" style=\"padding: 10px; text-align: center; background: #faf9f7; font-size: 24px; color: #333; flex-shrink: 0;\">\n                                Comparison with Open-Source Models 1\n                            </div>\n                            <!-- 修改2：将图片高度改为 flex: 1 或 auto，去掉 height: 100% -->\n                            <div style=\"flex: 1; min-height: 0; padding: 20px; background-color: #ffffff; border-radius: 16px; box-shadow: 0 10px 40px rgba(0,0,0,0.15); margin: 0 20px 20px 20px; display: flex; align-items: center; justify-content: center;\">\n                                <img \n                                    src=\"assets/images/comp-open1.png\" \n                                    alt=\"Comparison with Open-Source Models 1\"\n                                    style=\"max-width: 100%; max-height: 100%; width: auto; height: auto; display: block;\"\n                                >\n                            </div>\n                        </div>\n                    </div>\n\n                    <!-- 第二张图 -->\n                    <div class=\"x-carousel-slider-item\" style=\"flex-basis: calc(100% / 1);\">\n                        <div class=\"x-card\" style=\"height: 86vh; background-color: #faf9f7; border-radius: 4px; display: flex; flex-direction: column; padding: 0 0 40px 0;\">\n                            <div class=\"x-img-caption\" style=\"padding: 10px; text-align: center; background: #faf9f7; font-size: 24px; color: #333; flex-shrink: 0;\">\n                                Comparison with Open-Source Models 2\n                            </div>\n                            <div style=\"flex: 1; min-height: 0; padding: 20px; background-color: #ffffff; border-radius: 16px; box-shadow: 0 10px 40px rgba(0,0,0,0.15); margin: 0 20px 20px 20px; display: flex; align-items: center; justify-content: center;\">\n                                <img \n                                    src=\"assets/images/comp-open2.png\" \n                                    alt=\"Comparison with Open-Source Models 2\"\n                                    style=\"max-width: 100%; max-height: 100%; width: auto; height: auto; 
display: block;\"\n                                >\n                            </div>\n                            <!-- <div class=\"x-img-caption\" style=\"padding: 10px; text-align: center; background: #faf9f7; font-size: 16px; color: #333;\">\n                                Comparison with Open-Source Models 2\n                            </div> -->\n                        </div>\n                    </div>\n\n                    <!-- 第三张图 -->\n                    <div class=\"x-carousel-slider-item\" style=\"flex-basis: calc(100% / 1);\">\n                        <div class=\"x-card\" style=\"height: 86vh; background-color: #faf9f7; border-radius: 4px; display: flex; flex-direction: column; padding: 0 0 40px 0;\">\n                            <div class=\"x-img-caption\" style=\"padding: 10px; text-align: center; background: #faf9f7; font-size: 24px; color: #333; flex-shrink: 0;\">\n                                Comparison with Commercial Models\n                            </div>\n                            <div style=\"flex: 1; min-height: 0; padding: 20px; background-color: #ffffff; border-radius: 16px; box-shadow: 0 10px 40px rgba(0,0,0,0.15); margin: 0 20px 20px 20px; display: flex; align-items: center; justify-content: center;\">\n                                <img \n                                    src=\"assets/images/comp-close.png\" \n                                    alt=\"Comparison with Close Models 1\"\n                                    style=\"max-width: 100%; max-height: 100%; width: auto; height: auto; display: block;\"\n                                >\n                            </div>\n                            <!-- <div class=\"x-img-caption\" style=\"padding: 15px; text-align: center; background: #faf9f7; font-size: 16px; color: #333;\">\n                                Comparison with Close Models 1\n                            </div> -->\n                        </div>\n                    </div>\n\n                </div>\n                \n                <div class=\"x-carousel-nav\">\n                    <div class=\"x-carousel-switch\">&lt;</div>\n                    <div class=\"x-carousel-pages\">\n                        <div class=\"x-carousel-page x-carousel-page-active\"></div>\n                        <div class=\"x-carousel-page\"></div>\n                        <div class=\"x-carousel-page\"></div>\n                    </div>\n                    <div class=\"x-carousel-switch\">&gt;</div>\n                </div>\n            </div>\n\n            </div>\n            <div class=\"x-small-header\" style=\"text-align: center; margin-top: -100px;\">Citation</div>\n            <div class=\"bibtex-entry\" style=\"max-width: 800px; width: 90%; margin: 20px auto; box-sizing: border-box; background-color: #f5f5f5; padding: 15px; border-radius: 5px; overflow-x: auto;\">\n                <div class=\"line\">\n                    <span class=\"value\">@article{jia2025ultrashape,</span>\n                </div>\n                <div class=\"line\">\n                    <span class=\"key\">    title={</span>\n                    <span class=\"value\">UltraShape 1.0: High-Fidelity 3D Shape Generation via Scalable Geometric Refinement},</span>\n                </div>\n                <div class=\"line\">\n                    <span class=\"key\">    author={</span>\n                    <span class=\"value\">Jia, Tanghui and Yan, Dongyu and Hao, Dehao and Li, Yang and Zhang, Kaiyi and He, Xianyi and Li, Lanjiong and Chen, Jinnan and Jiang, Lutao 
and Yin, Qishen and Quan, Long and Chen, Ying-Cong and Yuan, Li},</span>\n                </div>\n                <div class=\"line\">\n                    <span class=\"key\">    journal={</span>\n                    <span class=\"value\">arxiv preprint arXiv:2512.21185},</span>\n                </div>\n                <div class=\"line\">\n                    <span class=\"key\">    year={</span>\n                    <span class=\"value\">2025}</span>\n                </div>\n                <div class=\"line\">}</div>\n            </div>\n            <div style=\"height: 100px;\"> </div>\n\n        </div>\n\n        <footer style=\"text-align: center; padding: 20px 0; color: #666; font-size: 14px;\">\n            <p>\n                The website template was borrowed from <a href=\"https://microsoft.github.io/TRELLIS.2/\" target=\"_blank\">TRELLIS.2</a>.\n            </p>\n        </footer>\n\n        <!-- Scroll Down Indicator -->\n        <div id=\"scroll-indicator\" class=\"scroll-indicator\">\n            <svg width=\"32\" height=\"32\" viewBox=\"0 0 24 24\" fill=\"none\" stroke=\"currentColor\" stroke-width=\"2\" stroke-linecap=\"round\" stroke-linejoin=\"round\">\n                <path d=\"M7 13l5 5 5-5\"></path>\n            </svg>\n        </div>\n\n    <div id=\"fullscreen\">\n        <div id=\"window\">\n            <div id=\"close\">✕</div>\n            <div id=\"content\"></div>\n        </div>\n    </div>\n        <script>\n            // Mouse-following gradient effect\n            document.addEventListener('DOMContentLoaded', function() {\n                const gradientEl = document.querySelector('.animated-gradient');\n                if (!gradientEl) return;\n\n                gradientEl.addEventListener('mousemove', function(e) {\n                    const rect = gradientEl.getBoundingClientRect();\n                    const x = ((e.clientX - rect.left) / rect.width) * 100;\n                    const y = ((e.clientY - rect.top) / rect.height) * 100;\n                    \n                    gradientEl.style.setProperty('--mouse-x', x + '%');\n                    gradientEl.style.setProperty('--mouse-y', y + '%');\n                });\n\n                // Reset to center when mouse leaves\n                gradientEl.addEventListener('mouseleave', function() {\n                    gradientEl.style.setProperty('--mouse-x', '50%');\n                    gradientEl.style.setProperty('--mouse-y', '50%');\n                });\n            });\n\n            // Scroll expand effect - hero section expands on scroll\n            (function() {\n                const hero = document.querySelector('.hero-expandable');\n                if (!hero) return;\n\n                const windowHeight = window.innerHeight;\n                const expandThreshold = windowHeight * 0.5; // Start expanding after scrolling 50vh\n                \n                function updateHero() {\n                    const scrollY = window.scrollY;\n                    \n                    if (scrollY === 0) {\n                        // Initial state - hero is 90vh (positioned lower, showing content below)\n                        hero.style.height = '90vh';\n                    } else if (scrollY <= expandThreshold) {\n                        // Expanding phase - hero grows from 90vh to 100vh\n                        const progress = scrollY / expandThreshold;\n                        const newHeight = 90 + (10 * progress);\n                        hero.style.height = newHeight + 'vh';\n                   
 } else {\n                        // Fully expanded - hero is 100vh\n                        hero.style.height = '100vh';\n                    }\n                }\n\n                // Initial state\n                hero.style.height = '90vh';\n                \n                // Throttle scroll for better performance\n                let ticking = false;\n                window.addEventListener('scroll', function() {\n                    if (!ticking) {\n                        window.requestAnimationFrame(function() {\n                            updateHero();\n                            ticking = false;\n                        });\n                        ticking = true;\n                    }\n                }, { passive: true });\n                \n                // Handle initial load\n                updateHero();\n            })();\n\n            // View GLB functionality\n            (function() {\n                const fullscreen = document.getElementById('fullscreen');\n                const windowEl = document.getElementById('window');\n                const closeBtn = document.getElementById('close');\n                const content = document.getElementById('content');\n\n                if (!fullscreen || !windowEl || !closeBtn || !content) return;\n\n                /**\n                 * Open fullscreen window with GLB model\n                 * @param {string} glbPath - Path to GLB file\n                 * @param {string} prompt - Input prompt text\n                 */\n                function openGLBViewer(glbPath, prompt) {\n                    // Create model-viewer element\n                    content.innerHTML = `\n                        <div style=\"display: flex; flex-wrap: wrap; align-items: start; justify-content: center;\">\n                            <div class=\"modelviewer-container\" style=\"width: 500px; height: 500px;\">\n                                <model-viewer\n                                    src=\"${glbPath}\"\n                                    alt=\"3D Model\"\n                                    camera-controls\n                                    auto-rotate\n                                    interaction-policy=\"allow-when-focused\"\n                                    reveal=\"auto\"\n                                    style=\"width: 100%; height: 100%; background-color: #f5f5f5; display: block;\"\n                                    loading=\"eager\"\n                                    shadow-intensity=\"1\"\n                                    exposure=\"1\"\n                                    tone-mapping=\"neutral\"\n                                >\n                                </model-viewer>\n                            </div>\n                            <div class=\"modelviewer-panel\">\n                                <div class=\"modelviewer-panel-desc\">\n                                    <div>Input Prompt</div>\n                                </div>\n                                <div class=\"modelviewer-panel-prompt\">\n                                    ${prompt || '[PLACEHOLDER: Input prompt text]'}\n                                </div>\n                            </div>\n                        </div>\n                    `;\n\n                    // Show fullscreen window\n                    fullscreen.style.display = 'flex';\n                    // Trigger opacity transition after a short delay to ensure model-viewer is ready\n                    setTimeout(() => {\n                        
fullscreen.style.opacity = '1';\n                        // Get the model-viewer element and add error handling\n                        const modelViewer = content.querySelector('model-viewer');\n                        if (modelViewer) {\n                            // Add load success handler\n                            modelViewer.addEventListener('load', () => {\n                                console.log('Model loaded successfully:', glbPath);\n                                // Force update after load\n                                setTimeout(() => {\n                                    modelViewer.style.width = '100%';\n                                    modelViewer.style.height = '100%';\n                                    modelViewer.style.display = 'block';\n                                }, 100);\n                            });\n                            \n                            // Add error handler\n                            modelViewer.addEventListener('error', (e) => {\n                                console.error('Model loading error:', e);\n                                console.error('Failed path:', glbPath);\n                                const container = content.querySelector('.modelviewer-container');\n                                if (container) {\n                                    container.innerHTML = `\n                                        <div style=\"display: flex; align-items: center; justify-content: center; width: 100%; height: 100%; color: #d32f2f; text-align: center; padding: 20px;\">\n                                            <div>\n                                                <div style=\"font-size: 16px; margin-bottom: 8px; font-weight: 600;\">Failed to load model</div>\n                                                <div style=\"font-size: 12px; color: #666; margin-bottom: 8px;\">Path: ${glbPath}</div>\n                                                <div style=\"font-size: 11px; color: #999;\">Please check the file path and ensure the file exists</div>\n                                            </div>\n                                        </div>\n                                    `;\n                                }\n                            });\n                        } else {\n                            console.error('model-viewer element not found after creation');\n                        }\n                    }, 100);\n                }\n\n                /**\n                 * Close fullscreen window\n                 */\n                function closeGLBViewer() {\n                    fullscreen.style.opacity = '0';\n                    setTimeout(() => {\n                        fullscreen.style.display = 'none';\n                        content.innerHTML = '';\n                    }, 250); // Match transition duration\n                }\n\n                // Add click handlers to all View GLB buttons\n                document.addEventListener('click', function(e) {\n                    const button = e.target.closest('.x-button');\n                    if (button && button.textContent.trim() === 'View GLB') {\n                        const glbPath = button.getAttribute('data-glb');\n                        const prompt = button.getAttribute('data-prompt');\n                        if (glbPath) {\n                            openGLBViewer(glbPath, prompt);\n                        }\n                    }\n                });\n\n                // Close button handler\n                
closeBtn.addEventListener('click', closeGLBViewer);\n\n                // Close on backdrop click\n                fullscreen.addEventListener('click', function(e) {\n                    if (e.target === fullscreen) {\n                        closeGLBViewer();\n                    }\n                });\n\n                // Close on Escape key\n                document.addEventListener('keydown', function(e) {\n                    if (e.key === 'Escape' && fullscreen.style.display === 'flex') {\n                        closeGLBViewer();\n                    }\n                });\n            })();\n        </script>\n    </body></html>"
  },
  {
    "path": "docs/index.html",
    "content": "<html lang=\"en\"><head><style type=\"text/css\">\n.anticon {\n  display: inline-block;\n  color: inherit;\n  font-style: normal;\n  line-height: 0;\n  text-align: center;\n  text-transform: none;\n  vertical-align: -0.125em;\n  text-rendering: optimizeLegibility;\n  -webkit-font-smoothing: antialiased;\n  -moz-osx-font-smoothing: grayscale;\n}\n\n.anticon > * {\n  line-height: 1;\n}\n\n.anticon svg {\n  display: inline-block;\n}\n\n.anticon::before {\n  display: none;\n}\n\n.anticon .anticon-icon {\n  display: block;\n}\n\n.anticon[tabindex] {\n  cursor: pointer;\n}\n\n.anticon-spin::before,\n.anticon-spin {\n  display: inline-block;\n  -webkit-animation: loadingCircle 1s infinite linear;\n  animation: loadingCircle 1s infinite linear;\n}\n\n@-webkit-keyframes loadingCircle {\n  100% {\n    -webkit-transform: rotate(360deg);\n    transform: rotate(360deg);\n  }\n}\n\n@keyframes loadingCircle {\n  100% {\n    -webkit-transform: rotate(360deg);\n    transform: rotate(360deg);\n  }\n}\n</style>\n        <meta charset=\"UTF-8\">\n        <meta name=\"viewport\" content=\"width=device-width, initial-scale=1.0\">\n        <title>UltraShape 1.0: High-Fidelity 3D Shape Generation via Scalable Geometric Refinement</title>\n        <!-- TODO: Replace with UltraShape 1.0 favicon -->\n        <link rel=\"icon\" href=\"assets/favicon.png\">\n        <link rel=\"stylesheet\" href=\"stylesheet.css\">\n        <link rel=\"stylesheet\" href=\"style.css\">\n        <link rel=\"stylesheet\" href=\"window.css\">\n        <link rel=\"stylesheet\" href=\"carousel.css\">\n        <link rel=\"stylesheet\" href=\"main.css\">\n        <link rel=\"stylesheet\" href=\"pv.css\">\n        <link href=\"https://fonts.googleapis.com/icon?family=Material+Icons\" rel=\"stylesheet\">\n        <script type=\"module\" src=\"https://ajax.googleapis.com/ajax/libs/model-viewer/3.3.0/model-viewer.min.js\"></script>\n        <script src=\"carousel.js\" defer></script>\n    </head>\n    <body>\n        <div class=\"animated-gradient hero-expandable\" style=\"width: 100%; height: 90vh; position: relative; overflow: hidden;\">\n    \n            <img \n                src=\"assets\\images\\teaser.png\" \n                alt=\"Hero background\" \n                style=\"position: absolute; top: 0; left: 0; width: 100%; height: 100%; object-fit: cover; opacity: 0.9;\"\n            >\n            <div style=\"position: relative; z-index: 1; height: 100%; display: flex; flex-direction: column; justify-content: center; align-items: center; color: #ffffff; padding: 80px 20px; text-align: center;\">\n\n                <link href=\"https://fonts.googleapis.com/css2?family=Outfit:wght@400;700&family=Inter:wght@300;400&display=swap\" rel=\"stylesheet\">\n\n                <div style=\"font-family: 'Outfit', sans-serif;\">\n                    \n                    <div style=\"\n                        font-size: 100px; \n                        font-weight: 700; \n                        margin-bottom: 24px; \n                        background: linear-gradient(135deg, #ffffff 0%, #c3ecff 100%);\n                        -webkit-background-clip: text;\n                        -webkit-text-fill-color: transparent;\n                        letter-spacing: -2px; \n                        text-shadow: 0px 10px 30px rgba(0,0,0,0.3); \n                    \">\n                        UltraShape 1.0\n                    </div>\n\n                    <div style=\"\n                        font-family: 'Inter', sans-serif;\n             
           font-size: 60px; \n                        font-weight: 600; \n                        color: rgba(255, 255, 255, 0.9);\n                        margin-bottom: 48px; \n                        max-width: 1200px; \n                        line-height: 1.6;\n                        letter-spacing: 1px; \n                    \">\n                        High-Fidelity 3D Shape Generation via Scalable Geometric Refinement\n                    </div>\n\n                </div>\n\n                <div style=\"font-size: 22px; font-weight: 400; margin-bottom: 48px; max-width: 1200px; line-height: 1.8; opacity: 0.95;\">\n                    We introduce UltraShape-1.0, a scalable two-stage diffusion framework for high-quality 3D geometry generation, enhanced by an advanced data processing pipeline that ensures geometric details through watertight processing and quality filtering.\n                </div>\n                <div id=\"links\" style=\"display: flex; gap: 16px; flex-wrap: wrap; justify-content: center;\">\n                    <div><a id=\"paper\" href=\"https://arxiv.org/pdf/2512.21185\">Tech Report</a></div>\n                    <div><a id=\"code\" href=\"https://github.com/PKU-YuanGroup/UltraShape-1.0\">Code</a></div>\n                    <div><a id=\"demo\" href=\"#demo\">Demo</a></div>\n                </div>\n            </div>\n        </div>\n\n        <div id=\"main\" style=\"background: #faf9f7;\">\n            <div class=\"x-section-title\" style=\"text-align: center; margin-top: 80px;\">Main Pipeline</div>\n            \n            <div style=\"background: #f8f9fa; padding: 80px 20px;\">\n                <div style=\"max-width: 1200px; margin: 40px auto; border-radius: 16px; overflow: hidden; box-shadow: 0 10px 40px rgba(0,0,0,0.15); position: relative; padding: 20px; background-color: #ffffff;\">\n                    \n                    <img src=\"assets/images/pipeline.png?v=2\" alt=\"Video cover\" style=\"width: 100%; height: auto; display: block;\">\n                    \n                    \n                </div>\n            </div>    \n\n            <div class=\"x-section-title\"  style=\"text-align: center; margin-top: 50px;\">Image to 3D Shape Generation</div>\n            <div id=\"results-gen\">\n                <div class=\"x-carousel-tags\">\n                    <!-- <div class=\"x-carousel-tag active\" data-filter=\"all\">All</div> -->\n                    <!-- <div class=\"x-carousel-tag\" data-filter=\"class1\">Object</div> -->\n                    <!-- <div class=\"x-carousel-tag\" data-filter=\"class2\">Multi-Object</div> -->\n                    <!-- <div class=\"x-carousel-tag\" data-filter=\"class3\">Character</div> -->\n                    <!-- <div class=\"x-carousel-tag\" data-filter=\"class4\">Car</div> -->\n                </div>\n                <div class=\"x-carousel-slider\">\n                    <!-- Item 1 -->\n                    <div class=\"x-carousel-slider-item\" data-tag=\"class1\" style=\"flex-basis: calc(100% / 1);\">\n                        <div class=\"x-card\" style=\"height: 70vh;\">\n                            <div class=\"x-button\" \n                                style=\"position: absolute; top: 10px; right: 85px; padding: 15px 30px; \n                                        font-size: 22px; font-weight: bold; z-index: 10; cursor: pointer;\" \n                                data-glb=\"https://huggingface.co/datasets/infinith/ultrashape_page/resolve/main/1.glb\"\n                                
data-prompt=\"assets/images/1.png\">\n                                View GLB\n                            </div>\n                            <!-- <div class=\"x-button\" style=\"position: absolute; top: 10px; right: 10px; padding: 12px 24px; font-size: 24px;\" data-glb=\"assets/meshs/1.glb\" data-prompt=\"assets/images/1.png\">View GLB</div> -->\n                            <div style=\"width: 85%; aspect-ratio: 1; background-color: rgba(0,0,0,0.05); border-radius: 4px; overflow: hidden;\">\n                                <video \n                                    autoplay \n                                    loop \n                                    muted \n                                    playsinline \n                                    style=\"width: 100%; height: 100%; object-fit: contain;\">\n                                    <source src=\"assets/videos/1.mp4\" type=\"video/mp4\">\n                                </video>\n                            </div>\n                            <div style=\"width: 200px; aspect-ratio: 1; background-color: transparent; border-radius: 4px; overflow: hidden;\">\n                                <img \n                                    src=\"assets/images/1.png\" \n                                    alt=\"Input Image\"\n                                    style=\"width: 100%; height: 100%; object-fit: contain; display: block;\"\n                                >\n                            </div>\n                        </div>\n                    </div>\n                    <!-- Item 2 -->\n                    <div class=\"x-carousel-slider-item\" data-tag=\"class1\" style=\"flex-basis: calc(100% / 1);\">\n                        <div class=\"x-card\" style=\"height: 70vh;\">\n                            <!-- <div class=\"x-button\" style=\"position: absolute; top: 10px; right: 10px\" data-glb=\"assets/meshs/2.glb\" data-prompt=\"Input Prompt 1\">View GLB</div> -->\n                            <div class=\"x-button\" \n                                style=\"position: absolute; top: 10px; right: 85px; padding: 15px 30px; \n                                        font-size: 22px; font-weight: bold; z-index: 10; cursor: pointer;\" \n                                data-glb=\"https://huggingface.co/datasets/infinith/ultrashape_page/resolve/main/2.glb\"\n                                data-prompt=\"assets/images/2.png\">\n                                View GLB\n                            </div>\n                            <div style=\"width: 85%; aspect-ratio: 1; background-color: rgba(0,0,0,0.05); border-radius: 4px; overflow: hidden;\">\n                                <video \n                                    autoplay \n                                    loop \n                                    muted \n                                    playsinline \n                                    style=\"width: 100%; height: 100%; object-fit: contain;\">\n                                    <source src=\"assets/videos/2.mp4\" type=\"video/mp4\">\n                                </video>\n                            </div>\n                            <div style=\"width: 200px; aspect-ratio: 1; background-color: transparent; border-radius: 4px; overflow: hidden;\">\n                                <img \n                                    src=\"assets/images/2.png\" \n                                    alt=\"Input Image\"\n                                    style=\"width: 100%; height: 100%; object-fit: contain; display: 
block;\"\n                                >\n                            </div>\n                        </div>\n                    </div>\n                    <!-- Item 3 -->\n                    <div class=\"x-carousel-slider-item\" data-tag=\"class3\" style=\"flex-basis: calc(100% / 1);\">\n                        <div class=\"x-card\" style=\"height: 70vh;\">\n                            <!-- <div class=\"x-button\" style=\"position: absolute; top: 10px; right: 10px\" data-glb=\"assets/meshs/3.glb\" data-prompt=\"Input Prompt 1\">View GLB</div> -->\n                            <div class=\"x-button\" \n                                style=\"position: absolute; top: 10px; right: 85px; padding: 15px 30px; \n                                        font-size: 22px; font-weight: bold; z-index: 10; cursor: pointer;\" \n                                data-glb=\"https://huggingface.co/datasets/infinith/ultrashape_page/resolve/main/3.glb\"\n                                data-prompt=\"assets/images/3.png\">\n                                View GLB\n                            </div>\n                            <div style=\"width: 85%; aspect-ratio: 1; background-color: rgba(0,0,0,0.05); border-radius: 4px; overflow: hidden;\">\n                                <video \n                                    autoplay \n                                    loop \n                                    muted \n                                    playsinline \n                                    style=\"width: 100%; height: 100%; object-fit: contain;\">\n                                    <source src=\"assets/videos/3.mp4\" type=\"video/mp4\">\n                                </video>\n                            </div>\n                            <div style=\"width: 200px; aspect-ratio: 1; background-color: transparent; border-radius: 4px; overflow: hidden;\">\n                                <img \n                                    src=\"assets/images/3.png\" \n                                    alt=\"Input Image\"\n                                    style=\"width: 100%; height: 100%; object-fit: contain; display: block;\"\n                                >\n                            </div>\n                        </div>\n                    </div>\n                    <!-- Item 4 -->\n                    <div class=\"x-carousel-slider-item\" data-tag=\"class3\" style=\"flex-basis: calc(100% / 1);\">\n                        <div class=\"x-card\" style=\"height: 70vh;\">\n                            <!-- <div class=\"x-button\" style=\"position: absolute; top: 10px; right: 10px\" data-glb=\"assets/meshs/4.glb\" data-prompt=\"Input Prompt 1\">View GLB</div> -->\n                            <div class=\"x-button\" \n                                style=\"position: absolute; top: 10px; right: 85px; padding: 15px 30px; \n                                        font-size: 22px; font-weight: bold; z-index: 10; cursor: pointer;\" \n                                data-glb=\"https://huggingface.co/datasets/infinith/ultrashape_page/resolve/main/4.glb\"\n                                data-prompt=\"assets/images/4.png\">\n                                View GLB\n                            </div>\n                            <div style=\"width: 85%; aspect-ratio: 1; background-color: rgba(0,0,0,0.05); border-radius: 4px; overflow: hidden;\">\n                                <video \n                                    autoplay \n                                    loop \n                          
          muted \n                                    playsinline \n                                    style=\"width: 100%; height: 100%; object-fit: contain;\">\n                                    <source src=\"assets/videos/4.mp4\" type=\"video/mp4\">\n                                </video>\n                            </div>\n                            <div style=\"width: 200px; aspect-ratio: 1; background-color: transparent; border-radius: 4px; overflow: hidden;\">\n                                <img \n                                    src=\"assets/images/4.png\" \n                                    alt=\"Input Image\"\n                                    style=\"width: 100%; height: 100%; object-fit: contain; display: block;\"\n                                >\n                            </div>\n                        </div>\n                    </div>\n                    <!-- Item 5 -->\n                    <div class=\"x-carousel-slider-item\" data-tag=\"class1\" style=\"flex-basis: calc(100% / 1);\">\n                        <div class=\"x-card\" style=\"height: 70vh;\">\n                            <!-- <div class=\"x-button\" style=\"position: absolute; top: 10px; right: 10px\" data-glb=\"assets/meshs/5.glb\" data-prompt=\"Input Prompt 1\">View GLB</div> -->\n                            <div class=\"x-button\" \n                                style=\"position: absolute; top: 10px; right: 85px; padding: 15px 30px; \n                                        font-size: 22px; font-weight: bold; z-index: 10; cursor: pointer;\" \n                                data-glb=\"https://huggingface.co/datasets/infinith/ultrashape_page/resolve/main/5.glb\"\n                                data-prompt=\"assets/images/5.png\">\n                                View GLB\n                            </div>\n                            <div style=\"width: 85%; aspect-ratio: 1; background-color: rgba(0,0,0,0.05); border-radius: 4px; overflow: hidden;\">\n                                <video \n                                    autoplay \n                                    loop \n                                    muted \n                                    playsinline \n                                    style=\"width: 100%; height: 100%; object-fit: contain;\">\n                                    <source src=\"assets/videos/5.mp4\" type=\"video/mp4\">\n                                </video>\n                            </div>\n                            <div style=\"width: 200px; aspect-ratio: 1; background-color: transparent; border-radius: 4px; overflow: hidden;\">\n                                <img \n                                    src=\"assets/images/5.png\" \n                                    alt=\"Input Image\"\n                                    style=\"width: 100%; height: 100%; object-fit: contain; display: block;\"\n                                >\n                            </div>\n                        </div>\n                    </div>\n                    <!-- Item 6 -->\n                    <div class=\"x-carousel-slider-item\" data-tag=\"class1\" style=\"flex-basis: calc(100% / 1);\">\n                        <div class=\"x-card\" style=\"height: 70vh;\">\n                            <!-- <div class=\"x-button\" style=\"position: absolute; top: 10px; right: 10px\" data-glb=\"assets/meshs/6.glb\" data-prompt=\"Input Prompt 1\">View GLB</div> -->\n                            <div class=\"x-button\" \n                                
style=\"position: absolute; top: 10px; right: 85px; padding: 15px 30px; \n                                        font-size: 22px; font-weight: bold; z-index: 10; cursor: pointer;\" \n                                data-glb=\"https://huggingface.co/datasets/infinith/ultrashape_page/resolve/main/6.glb\"\n                                data-prompt=\"assets/images/6.png\">\n                                View GLB\n                            </div>\n                            <div style=\"width: 85%; aspect-ratio: 1; background-color: rgba(0,0,0,0.05); border-radius: 4px; overflow: hidden;\">\n                                <video \n                                    autoplay \n                                    loop \n                                    muted \n                                    playsinline \n                                    style=\"width: 100%; height: 100%; object-fit: contain;\">\n                                    <source src=\"assets/videos/6.mp4\" type=\"video/mp4\">\n                                </video>\n                            </div>\n                            <div style=\"width: 200px; aspect-ratio: 1; background-color: transparent; border-radius: 4px; overflow: hidden;\">\n                                <img \n                                    src=\"assets/images/6.png\" \n                                    alt=\"Input Image\"\n                                    style=\"width: 100%; height: 100%; object-fit: contain; display: block;\"\n                                >\n                            </div>\n                        </div>\n                    </div>\n                    <!-- Item 7 -->\n                    <div class=\"x-carousel-slider-item\" data-tag=\"class1\" style=\"flex-basis: calc(100% / 1);\">\n                        <div class=\"x-card\" style=\"height: 70vh;\">\n                            <!-- <div class=\"x-button\" style=\"position: absolute; top: 10px; right: 10px\" data-glb=\"assets/meshs/7.glb\" data-prompt=\"Input Prompt 1\">View GLB</div> -->\n                            <div class=\"x-button\" \n                                style=\"position: absolute; top: 10px; right: 85px; padding: 15px 30px; \n                                        font-size: 22px; font-weight: bold; z-index: 10; cursor: pointer;\" \n                                data-glb=\"https://huggingface.co/datasets/infinith/ultrashape_page/resolve/main/7.glb\"\n                                data-prompt=\"assets/images/7.png\">\n                                View GLB\n                            </div>\n                            <div style=\"width: 85%; aspect-ratio: 1; background-color: rgba(0,0,0,0.05); border-radius: 4px; overflow: hidden;\">\n                                <video \n                                    autoplay \n                                    loop \n                                    muted \n                                    playsinline \n                                    style=\"width: 100%; height: 100%; object-fit: contain;\">\n                                    <source src=\"assets/videos/7.mp4\" type=\"video/mp4\">\n                                </video>\n                            </div>\n                            <div style=\"width: 200px; aspect-ratio: 1; background-color: transparent; border-radius: 4px; overflow: hidden;\">\n                                <img \n                                    src=\"assets/images/7.png\" \n                                    alt=\"Input Image\"\n     
                               style=\"width: 100%; height: 100%; object-fit: contain; display: block;\"\n                                >\n                            </div>\n                        </div>\n                    </div>\n                    <!-- Item 8 -->\n                    <div class=\"x-carousel-slider-item\" data-tag=\"class3\" style=\"flex-basis: calc(100% / 1);\">\n                        <div class=\"x-card\" style=\"height: 70vh;\">\n                            <!-- <div class=\"x-button\" style=\"position: absolute; top: 10px; right: 10px\" data-glb=\"assets/meshs/8.glb\" data-prompt=\"Input Prompt 1\">View GLB</div> -->\n                            <div class=\"x-button\" \n                                style=\"position: absolute; top: 10px; right: 85px; padding: 15px 30px; \n                                        font-size: 22px; font-weight: bold; z-index: 10; cursor: pointer;\" \n                                data-glb=\"https://huggingface.co/datasets/infinith/ultrashape_page/resolve/main/8.glb\"\n                                data-prompt=\"assets/images/8.png\">\n                                View GLB\n                            </div>\n                            <div style=\"width: 85%; aspect-ratio: 1; background-color: rgba(0,0,0,0.05); border-radius: 4px; overflow: hidden;\">\n                                <video \n                                    autoplay \n                                    loop \n                                    muted \n                                    playsinline \n                                    style=\"width: 100%; height: 100%; object-fit: contain;\">\n                                    <source src=\"assets/videos/8.mp4\" type=\"video/mp4\">\n                                </video>\n                            </div>\n                            <div style=\"width: 200px; aspect-ratio: 1; background-color: transparent; border-radius: 4px; overflow: hidden;\">\n                                <img \n                                    src=\"assets/images/8.png\" \n                                    alt=\"Input Image\"\n                                    style=\"width: 100%; height: 100%; object-fit: contain; display: block;\"\n                                >\n                            </div>\n                        </div>\n                    </div>\n                    <!-- Item 9 -->\n                    <div class=\"x-carousel-slider-item\" data-tag=\"class3\" style=\"flex-basis: calc(100% / 1);\">\n                        <div class=\"x-card\" style=\"height: 70vh;\">\n                            <!-- <div class=\"x-button\" style=\"position: absolute; top: 10px; right: 10px\" data-glb=\"assets/meshs/9.glb\" data-prompt=\"Input Prompt 1\">View GLB</div> -->\n                            <div class=\"x-button\" \n                                style=\"position: absolute; top: 10px; right: 85px; padding: 15px 30px; \n                                        font-size: 22px; font-weight: bold; z-index: 10; cursor: pointer;\" \n                                data-glb=\"https://huggingface.co/datasets/infinith/ultrashape_page/resolve/main/9.glb\"\n                                data-prompt=\"assets/images/9.png\">\n                                View GLB\n                            </div>\n                            <div style=\"width: 85%; aspect-ratio: 1; background-color: rgba(0,0,0,0.05); border-radius: 4px; overflow: hidden;\">\n                                <video \n                    
                autoplay \n                                    loop \n                                    muted \n                                    playsinline \n                                    style=\"width: 100%; height: 100%; object-fit: contain;\">\n                                    <source src=\"assets/videos/9.mp4\" type=\"video/mp4\">\n                                </video>\n                            </div>\n                            <div style=\"width: 200px; aspect-ratio: 1; background-color: transparent; border-radius: 4px; overflow: hidden;\">\n                                <img \n                                    src=\"assets/images/9.png\" \n                                    alt=\"Input Image\"\n                                    style=\"width: 100%; height: 100%; object-fit: contain; display: block;\"\n                                >\n                            </div>\n                        </div>\n                    </div>\n                    <!-- Item 10 -->\n                    <div class=\"x-carousel-slider-item\" data-tag=\"class3\" style=\"flex-basis: calc(100% / 1);\">\n                        <div class=\"x-card\" style=\"height: 70vh;\">\n                            <!-- <div class=\"x-button\" style=\"position: absolute; top: 10px; right: 10px\" data-glb=\"assets/meshs/10.glb\" data-prompt=\"Input Prompt 1\">View GLB</div> -->\n                            <div class=\"x-button\" \n                                style=\"position: absolute; top: 10px; right: 85px; padding: 15px 30px; \n                                        font-size: 22px; font-weight: bold; z-index: 10; cursor: pointer;\" \n                                data-glb=\"https://huggingface.co/datasets/infinith/ultrashape_page/resolve/main/10.glb\"\n                                data-prompt=\"assets/images/10.png\">\n                                View GLB\n                            </div>\n                            <div style=\"width: 85%; aspect-ratio: 1; background-color: rgba(0,0,0,0.05); border-radius: 4px; overflow: hidden;\">\n                                <video \n                                    autoplay \n                                    loop \n                                    muted \n                                    playsinline \n                                    style=\"width: 100%; height: 100%; object-fit: contain;\">\n                                    <source src=\"assets/videos/10.mp4\" type=\"video/mp4\">\n                                </video>\n                            </div>\n                            <div style=\"width: 200px; aspect-ratio: 1; background-color: transparent; border-radius: 4px; overflow: hidden;\">\n                                <img \n                                    src=\"assets/images/10.png\" \n                                    alt=\"Input Image\"\n                                    style=\"width: 100%; height: 100%; object-fit: contain; display: block;\"\n                                >\n                            </div>\n                        </div>\n                    </div>\n                    <!-- Item 11 -->\n                    <div class=\"x-carousel-slider-item\" data-tag=\"class1\" style=\"flex-basis: calc(100% / 1);\">\n                        <div class=\"x-card\" style=\"height: 70vh;\">\n                            <!-- <div class=\"x-button\" style=\"position: absolute; top: 10px; right: 10px\" data-glb=\"assets/meshs/11.glb\" data-prompt=\"Input Prompt 1\">View GLB</div> 
-->\n                            <div class=\"x-button\" \n                                style=\"position: absolute; top: 10px; right: 85px; padding: 15px 30px; \n                                        font-size: 22px; font-weight: bold; z-index: 10; cursor: pointer;\" \n                                data-glb=\"https://huggingface.co/datasets/infinith/ultrashape_page/resolve/main/11.glb\"\n                                data-prompt=\"assets/images/11.png\">\n                                View GLB\n                            </div>\n                            <div style=\"width: 85%; aspect-ratio: 1; background-color: rgba(0,0,0,0.05); border-radius: 4px; overflow: hidden;\">\n                                <video \n                                    autoplay \n                                    loop \n                                    muted \n                                    playsinline \n                                    style=\"width: 100%; height: 100%; object-fit: contain;\">\n                                    <source src=\"assets/videos/11.mp4\" type=\"video/mp4\">\n                                </video>\n                            </div>\n                            <div style=\"width: 200px; aspect-ratio: 1; background-color: transparent; border-radius: 4px; overflow: hidden;\">\n                                <img \n                                    src=\"assets/images/11.png\" \n                                    alt=\"Input Image\"\n                                    style=\"width: 100%; height: 100%; object-fit: contain; display: block;\"\n                                >\n                            </div>\n                        </div>\n                    </div>\n                    <!-- Item 12 -->\n                    <div class=\"x-carousel-slider-item\" data-tag=\"class1\" style=\"flex-basis: calc(100% / 1);\">\n                        <div class=\"x-card\" style=\"height: 70vh;\">\n                            <!-- <div class=\"x-button\" style=\"position: absolute; top: 10px; right: 10px\" data-glb=\"assets/meshs/12.glb\" data-prompt=\"Input Prompt 1\">View GLB</div> -->\n                            <div class=\"x-button\" \n                                style=\"position: absolute; top: 10px; right: 85px; padding: 15px 30px; \n                                        font-size: 22px; font-weight: bold; z-index: 10; cursor: pointer;\" \n                                data-glb=\"https://huggingface.co/datasets/infinith/ultrashape_page/resolve/main/12.glb\"\n                                data-prompt=\"assets/images/12.png\">\n                                View GLB\n                            </div>\n                            <div style=\"width: 85%; aspect-ratio: 1; background-color: rgba(0,0,0,0.05); border-radius: 4px; overflow: hidden;\">\n                                <video \n                                    autoplay \n                                    loop \n                                    muted \n                                    playsinline \n                                    style=\"width: 100%; height: 100%; object-fit: contain;\">\n                                    <source src=\"assets/videos/12.mp4\" type=\"video/mp4\">\n                                </video>\n                            </div>\n                            <div style=\"width: 200px; aspect-ratio: 1; background-color: transparent; border-radius: 4px; overflow: hidden;\">\n                                <img \n                            
        src=\"assets/images/12.png\" \n                                    alt=\"Input Image\"\n                                    style=\"width: 100%; height: 100%; object-fit: contain; display: block;\"\n                                >\n                            </div>\n                        </div>\n                    </div>\n                    <!-- Item 13 -->\n                    <div class=\"x-carousel-slider-item\" data-tag=\"class1\" style=\"flex-basis: calc(100% / 1);\">\n                        <div class=\"x-card\" style=\"height: 70vh;\">\n                            <!-- <div class=\"x-button\" style=\"position: absolute; top: 10px; right: 10px\" data-glb=\"assets/meshs/13.glb\" data-prompt=\"Input Prompt 1\">View GLB</div> -->\n                            <div class=\"x-button\" \n                                style=\"position: absolute; top: 10px; right: 85px; padding: 15px 30px; \n                                        font-size: 22px; font-weight: bold; z-index: 10; cursor: pointer;\" \n                                data-glb=\"https://huggingface.co/datasets/infinith/ultrashape_page/resolve/main/13.glb\"\n                                data-prompt=\"assets/images/13.png\">\n                                View GLB\n                            </div>\n                            <div style=\"width: 85%; aspect-ratio: 1; background-color: rgba(0,0,0,0.05); border-radius: 4px; overflow: hidden;\">\n                                <video \n                                    autoplay \n                                    loop \n                                    muted \n                                    playsinline \n                                    style=\"width: 100%; height: 100%; object-fit: contain;\">\n                                    <source src=\"assets/videos/13.mp4\" type=\"video/mp4\">\n                                </video>\n                            </div>\n                            <div style=\"width: 200px; aspect-ratio: 1; background-color: transparent; border-radius: 4px; overflow: hidden;\">\n                                <img \n                                    src=\"assets/images/13.png\" \n                                    alt=\"Input Image\"\n                                    style=\"width: 100%; height: 100%; object-fit: contain; display: block;\"\n                                >\n                            </div>\n                        </div>\n                    </div>\n                </div>\n                <div class=\"x-carousel-nav\">\n                    <div class=\"x-carousel-switch\">&lt;</div>\n                    <div class=\"x-carousel-pages\">\n                        <div class=\"x-carousel-page x-carousel-page-active\"></div>\n                        <div class=\"x-carousel-page\"></div>\n                        <div class=\"x-carousel-page\"></div>\n                        <div class=\"x-carousel-page\"></div>\n                        <div class=\"x-carousel-page\"></div>\n                        <div class=\"x-carousel-page\"></div>\n                        <div class=\"x-carousel-page\"></div>\n                        <div class=\"x-carousel-page\"></div>\n                        <div class=\"x-carousel-page\"></div>\n                        <div class=\"x-carousel-page\"></div>\n                        <div class=\"x-carousel-page\"></div>\n                        <div class=\"x-carousel-page\"></div>\n                        <div class=\"x-carousel-page\"></div>\n                    </div>\n 
                   <div class=\"x-carousel-switch\">&gt;</div>\n                </div>\n            </div>\n\n            <div class=\"x-section-title\" style=\"text-align: center; margin-top: 150px;\">Comparison of 3D Shape Generation</div>\n\n            <div id=\"results-recon\" style=\"max-width: 1200px; margin: 0 auto; padding: 0 20px;\">\n                <div class=\"x-carousel-slider\">\n                    \n                    <!-- 第一张图 -->\n                    <div class=\"x-carousel-slider-item\" style=\"flex-basis: calc(100% / 1);\">\n                        <!-- 修改1：增加 display: flex; flex-direction: column; -->\n                        <div class=\"x-card\" style=\"height: 86vh; background-color: #faf9f7; border-radius: 4px; display: flex; flex-direction: column; padding: 0 0 40px 0;\">\n                            <div class=\"x-img-caption\" style=\"padding: 10px; text-align: center; background: #faf9f7; font-size: 24px; color: #333; flex-shrink: 0;\">\n                                Comparison with Open-Source Models 1\n                            </div>\n                            <!-- 修改2：将图片高度改为 flex: 1 或 auto，去掉 height: 100% -->\n                            <div style=\"flex: 1; min-height: 0; padding: 20px; background-color: #ffffff; border-radius: 16px; box-shadow: 0 10px 40px rgba(0,0,0,0.15); margin: 0 20px 20px 20px; display: flex; align-items: center; justify-content: center;\">\n                                <img \n                                    src=\"assets/images/comp-open1.png\" \n                                    alt=\"Comparison with Open-Source Models 1\"\n                                    style=\"max-width: 100%; max-height: 100%; width: auto; height: auto; display: block;\"\n                                >\n                            </div>\n                        </div>\n                    </div>\n\n                    <!-- 第二张图 -->\n                    <div class=\"x-carousel-slider-item\" style=\"flex-basis: calc(100% / 1);\">\n                        <div class=\"x-card\" style=\"height: 86vh; background-color: #faf9f7; border-radius: 4px; display: flex; flex-direction: column; padding: 0 0 40px 0;\">\n                            <div class=\"x-img-caption\" style=\"padding: 10px; text-align: center; background: #faf9f7; font-size: 24px; color: #333; flex-shrink: 0;\">\n                                Comparison with Open-Source Models 2\n                            </div>\n                            <div style=\"flex: 1; min-height: 0; padding: 20px; background-color: #ffffff; border-radius: 16px; box-shadow: 0 10px 40px rgba(0,0,0,0.15); margin: 0 20px 20px 20px; display: flex; align-items: center; justify-content: center;\">\n                                <img \n                                    src=\"assets/images/comp-open2.png\" \n                                    alt=\"Comparison with Open-Source Models 2\"\n                                    style=\"max-width: 100%; max-height: 100%; width: auto; height: auto; display: block;\"\n                                >\n                            </div>\n                            <!-- <div class=\"x-img-caption\" style=\"padding: 10px; text-align: center; background: #faf9f7; font-size: 16px; color: #333;\">\n                                Comparison with Open-Source Models 2\n                            </div> -->\n                        </div>\n                    </div>\n\n                    <!-- 第三张图 -->\n                    <div class=\"x-carousel-slider-item\" 
style=\"flex-basis: calc(100% / 1);\">\n                        <div class=\"x-card\" style=\"height: 86vh; background-color: #faf9f7; border-radius: 4px; display: flex; flex-direction: column; padding: 0 0 40px 0;\">\n                            <div class=\"x-img-caption\" style=\"padding: 10px; text-align: center; background: #faf9f7; font-size: 24px; color: #333; flex-shrink: 0;\">\n                                Comparison with Commercial Models\n                            </div>\n                            <div style=\"flex: 1; min-height: 0; padding: 20px; background-color: #ffffff; border-radius: 16px; box-shadow: 0 10px 40px rgba(0,0,0,0.15); margin: 0 20px 20px 20px; display: flex; align-items: center; justify-content: center;\">\n                                <img \n                                    src=\"assets/images/comp-close.png\" \n                                    alt=\"Comparison with Close Models 1\"\n                                    style=\"max-width: 100%; max-height: 100%; width: auto; height: auto; display: block;\"\n                                >\n                            </div>\n                            <!-- <div class=\"x-img-caption\" style=\"padding: 15px; text-align: center; background: #faf9f7; font-size: 16px; color: #333;\">\n                                Comparison with Close Models 1\n                            </div> -->\n                        </div>\n                    </div>\n\n                </div>\n                \n                <div class=\"x-carousel-nav\">\n                    <div class=\"x-carousel-switch\">&lt;</div>\n                    <div class=\"x-carousel-pages\">\n                        <div class=\"x-carousel-page x-carousel-page-active\"></div>\n                        <div class=\"x-carousel-page\"></div>\n                        <div class=\"x-carousel-page\"></div>\n                    </div>\n                    <div class=\"x-carousel-switch\">&gt;</div>\n                </div>\n            </div>\n\n            </div>\n            <div class=\"x-small-header\" style=\"text-align: center; margin-top: -100px;\">Citation</div>\n            <div class=\"bibtex-entry\" style=\"max-width: 800px; width: 90%; margin: 20px auto; box-sizing: border-box; background-color: #f5f5f5; padding: 15px; border-radius: 5px; overflow-x: auto;\">\n                <div class=\"line\">\n                    <span class=\"value\">@article{jia2025ultrashape,</span>\n                </div>\n                <div class=\"line\">\n                    <span class=\"key\">    title={</span>\n                    <span class=\"value\">UltraShape 1.0: High-Fidelity 3D Shape Generation via Scalable Geometric Refinement},</span>\n                </div>\n                <div class=\"line\">\n                    <span class=\"key\">    author={</span>\n                    <span class=\"value\">Jia, Tanghui and Yan, Dongyu and Hao, Dehao and Li, Yang and Zhang, Kaiyi and He, Xianyi and Li, Lanjiong and Chen, Jinnan and Jiang, Lutao and Yin, Qishen and Quan, Long and Chen, Ying-Cong and Yuan, Li},</span>\n                </div>\n                <div class=\"line\">\n                    <span class=\"key\">    journal={</span>\n                    <span class=\"value\">arxiv preprint arXiv:2512.21185},</span>\n                </div>\n                <div class=\"line\">\n                    <span class=\"key\">    year={</span>\n                    <span class=\"value\">2025}</span>\n                </div>\n                <div 
class=\"line\">}</div>\n            </div>\n            <div style=\"height: 100px;\"> </div>\n\n        </div>\n\n        <footer style=\"text-align: center; padding: 20px 0; color: #666; font-size: 14px;\">\n            <p>\n                The website template was borrowed from <a href=\"https://microsoft.github.io/TRELLIS.2/\" target=\"_blank\">TRELLIS.2</a>.\n            </p>\n        </footer>\n\n        <!-- Scroll Down Indicator -->\n        <div id=\"scroll-indicator\" class=\"scroll-indicator\">\n            <svg width=\"32\" height=\"32\" viewBox=\"0 0 24 24\" fill=\"none\" stroke=\"currentColor\" stroke-width=\"2\" stroke-linecap=\"round\" stroke-linejoin=\"round\">\n                <path d=\"M7 13l5 5 5-5\"></path>\n            </svg>\n        </div>\n\n    <div id=\"fullscreen\">\n        <div id=\"window\">\n            <div id=\"close\">✕</div>\n            <div id=\"content\"></div>\n        </div>\n    </div>\n        <script>\n            // Mouse-following gradient effect\n            document.addEventListener('DOMContentLoaded', function() {\n                const gradientEl = document.querySelector('.animated-gradient');\n                if (!gradientEl) return;\n\n                gradientEl.addEventListener('mousemove', function(e) {\n                    const rect = gradientEl.getBoundingClientRect();\n                    const x = ((e.clientX - rect.left) / rect.width) * 100;\n                    const y = ((e.clientY - rect.top) / rect.height) * 100;\n                    \n                    gradientEl.style.setProperty('--mouse-x', x + '%');\n                    gradientEl.style.setProperty('--mouse-y', y + '%');\n                });\n\n                // Reset to center when mouse leaves\n                gradientEl.addEventListener('mouseleave', function() {\n                    gradientEl.style.setProperty('--mouse-x', '50%');\n                    gradientEl.style.setProperty('--mouse-y', '50%');\n                });\n            });\n\n            // Scroll expand effect - hero section expands on scroll\n            (function() {\n                const hero = document.querySelector('.hero-expandable');\n                if (!hero) return;\n\n                const windowHeight = window.innerHeight;\n                const expandThreshold = windowHeight * 0.5; // Start expanding after scrolling 50vh\n                \n                function updateHero() {\n                    const scrollY = window.scrollY;\n                    \n                    if (scrollY === 0) {\n                        // Initial state - hero is 90vh (positioned lower, showing content below)\n                        hero.style.height = '90vh';\n                    } else if (scrollY <= expandThreshold) {\n                        // Expanding phase - hero grows from 90vh to 100vh\n                        const progress = scrollY / expandThreshold;\n                        const newHeight = 90 + (10 * progress);\n                        hero.style.height = newHeight + 'vh';\n                    } else {\n                        // Fully expanded - hero is 100vh\n                        hero.style.height = '100vh';\n                    }\n                }\n\n                // Initial state\n                hero.style.height = '90vh';\n                \n                // Throttle scroll for better performance\n                let ticking = false;\n                window.addEventListener('scroll', function() {\n                    if (!ticking) {\n                        
window.requestAnimationFrame(function() {\n                            updateHero();\n                            ticking = false;\n                        });\n                        ticking = true;\n                    }\n                }, { passive: true });\n                \n                // Handle initial load\n                updateHero();\n            })();\n\n            // View GLB functionality\n            (function() {\n                const fullscreen = document.getElementById('fullscreen');\n                const windowEl = document.getElementById('window');\n                const closeBtn = document.getElementById('close');\n                const content = document.getElementById('content');\n\n                if (!fullscreen || !windowEl || !closeBtn || !content) return;\n\n                /**\n                 * Open fullscreen window with GLB model\n                 * @param {string} glbPath - Path to GLB file\n                 * @param {string} image - Input prompt image\n                 */\n                function openGLBViewer(glbPath, imagePath) {\n                    const overlay = document.createElement('div');\n                    overlay.style.cssText = `\n                        position: fixed; top: 0; left: 0; width: 100%; height: 100%;\n                        background: rgba(0, 0, 0, 0.85); z-index: 9999;\n                        display: flex; align-items: center; justify-content: center;\n                        backdrop-filter: blur(5px);\n                    `;\n\n                    const container = document.createElement('div');\n                    container.style.cssText = `\n                        width: 80%; height: 80%; background: #fff; border-radius: 12px;\n                        display: flex; overflow: hidden; box-shadow: 0 10px 30px rgba(0,0,0,0.5);\n                        position: relative;\n                    `;\n\n                    const imgContainer = document.createElement('div');\n                    imgContainer.style.cssText = `\n                        flex: 1; background: #f0f0f0; display: flex; \n                        align-items: center; justify-content: center; border-right: 1px solid #ddd;\n                        flex-direction: column;\n                    `;\n                    const imgTitle = document.createElement('div');\n                    imgTitle.innerText = \"Input Image\";\n                    imgTitle.style.cssText = \"margin-bottom: 10px; font-weight: bold; color: #555;\";\n                    const img = document.createElement('img');\n                    img.src = imagePath;\n                    img.style.cssText = \"max-width: 90%; max-height: 80%; object-fit: contain; border-radius: 4px;\";\n                    imgContainer.appendChild(imgTitle);\n                    imgContainer.appendChild(img);\n\n                    const modelContainer = document.createElement('div');\n                    modelContainer.style.cssText = `\n                        flex: 1; background: #e0e0e0; display: flex; \n                        align-items: center; justify-content: center; flex-direction: column; position: relative;\n                    `;\n                    const modelTitle = document.createElement('div');\n                    modelTitle.innerText = \"Generated Mesh\";\n                    modelTitle.style.cssText = \"position: absolute; top: 20px; font-weight: bold; color: #555; z-index: 10;\";\n\n                    const viewer = document.createElement('model-viewer');\n                    
viewer.src = glbPath;\n                    viewer.setAttribute('auto-rotate', '');\n                    viewer.setAttribute('camera-controls', '');\n                    viewer.setAttribute('shadow-intensity', '1');\n                    viewer.style.cssText = \"width: 100%; height: 100%;\";\n\n                    viewer.addEventListener('load', () => {\n                        if (viewer.model && viewer.model.materials.length > 0) {\n                            const material = viewer.model.materials[0];\n                            \n                            // [Red, Green, Blue, Alpha] (0.0 - 1.0)\n                            // material.pbrMetallicRoughness.setBaseColorFactor([0.8, 0.35, 0.2, 1.0]);\n                            material.pbrMetallicRoughness.setBaseColorFactor([0.203, 0.374, 0.637, 1.0]); \n                            // const r = Math.random();\n                            // const g = Math.random();\n                            // const b = Math.random();\n                            // material.pbrMetallicRoughness.setBaseColorFactor([r, g, b, 1.0]);\n\n                            material.pbrMetallicRoughness.setMetallicFactor(0.1); \n                            material.pbrMetallicRoughness.setRoughnessFactor(0.7);\n                        }\n                    });\n\n                    modelContainer.appendChild(modelTitle);\n                    modelContainer.appendChild(viewer);\n\n                    const overlayCloseBtn = document.createElement('button');\n                    overlayCloseBtn.innerHTML = \"&times;\";\n                    overlayCloseBtn.style.cssText = `\n                        position: absolute; top: 15px; right: 20px; \n                        background: none; border: none; color: #333; font-size: 30px; \n                        cursor: pointer; z-index: 100; line-height: 1;\n                    `;\n                    overlayCloseBtn.onclick = () => document.body.removeChild(overlay);\n                    overlay.onclick = (e) => { if (e.target === overlay) document.body.removeChild(overlay); };\n\n                    container.appendChild(imgContainer);\n                    container.appendChild(modelContainer);\n                    container.appendChild(overlayCloseBtn);\n                    overlay.appendChild(container);\n                    document.body.appendChild(overlay);\n                }\n\n                /**\n                 * Close fullscreen window\n                 */\n                function closeGLBViewer() {\n                    fullscreen.style.opacity = '0';\n                    setTimeout(() => {\n                        fullscreen.style.display = 'none';\n                        content.innerHTML = '';\n                    }, 250); // Match transition duration\n                }\n\n                // Add click handlers to all View GLB buttons\n                document.addEventListener('click', function(e) {\n                    const button = e.target.closest('.x-button');\n                    if (button && button.textContent.trim() === 'View GLB') {\n                        const glbPath = button.getAttribute('data-glb');\n                        const prompt = button.getAttribute('data-prompt');\n                        if (glbPath) {\n                            openGLBViewer(glbPath, prompt);\n                        }\n                    }\n                });\n\n                // Close button handler\n                closeBtn.addEventListener('click', closeGLBViewer);\n\n                // Close on backdrop click\n                
fullscreen.addEventListener('click', function(e) {\n                    if (e.target === fullscreen) {\n                        closeGLBViewer();\n                    }\n                });\n\n                // Close on Escape key\n                document.addEventListener('keydown', function(e) {\n                    if (e.key === 'Escape' && fullscreen.style.display === 'flex') {\n                        closeGLBViewer();\n                    }\n                });\n            })();\n        </script>\n    </body></html>"
  },
  {
    "path": "docs/main.css",
    "content": "* {\n    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;\n}\n\na {\n    color: #5b6acf;\n    text-decoration: none;\n}\n  \na.link:focus, a.link:hover {\n    color: #8b5cf6;\n    text-decoration: none;\n}\n\nhtml {\n    scroll-behavior: smooth;\n    scroll-snap-type: y proximity;\n}\n\nbody {\n    background: #faf9f7;\n    position: relative;\n    margin: 0px;\n    padding: 0px;\n    color: #2a2a2a;\n    overflow-x: hidden;\n}\n\np {\n    position: relative;\n    margin: 16px;\n    font-size: 16px;\n    font-weight: 300;\n    text-align: justify;\n}\n\np span {\n    font-weight: 500;\n}\n\n.x-row {\n    width: 100%;\n    display: flex;\n    align-items: center;\n    justify-content: center;\n    flex-wrap: nowrap;\n}\n\n.x-column {\n    display: flex;\n    align-items: center;\n    justify-content: center;\n    flex-wrap: nowrap;\n    flex-direction: column;\n}\n\n.x-center-text {\n    margin: 16px 32px;\n    text-align: center;\n}\n\n.x-left-align {\n    display: flex;\n    align-items: center;\n    justify-content: left;\n    flex-wrap: nowrap;\n}\n\n.x-right-align {\n    display: flex;\n    align-items: center;\n    justify-content: right;\n    flex-wrap: nowrap;\n}\n\n.x-flex-spacer {\n    flex: 1;\n}\n\n.x-labels {\n    position: absolute;\n    top: 8px;\n    right: 6px;\n    display: flex;\n    align-items: center;\n    justify-content: left;\n    flex-direction: row-reverse;\n}\n\n.x-label {\n    height: 20px;\n    padding: 0px 6px;\n    margin: 0px 2px;\n    color: #2a2a2a;\n    font-size: 12px;\n    font-weight: 600;\n    background: rgba(45, 45, 45, 0.1);\n    border-radius: 16px;\n    display: flex;\n    align-items: center;\n    justify-content: center;\n}\n\n.x-button {\n    height: 36px;\n    padding: 0px 14px;\n    background: rgba(45, 45, 45, 0.08);\n    color: #2a2a2a;\n    border-radius: 50px;\n    box-shadow: 0px 2px 4px rgba(0, 0, 0, 0.1);\n    font-size: 16px;\n    font-weight: 300;\n    display: flex;\n    align-items: center;\n    justify-content: center;\n    cursor: pointer;\n    transition: all 0.2s ease;\n}\n\n.x-button.small {\n    height: 32px;\n    padding: 0px 12px;\n    border-radius: 50px;\n    font-size: 14px;\n    font-weight: 600;\n}\n\n.x-button:hover {\n    background: rgba(45, 45, 45, 0.15);\n    transform: translateY(-2px);\n    box-shadow: 0px 4px 6px rgba(0, 0, 0, 0.15);\n}\n\n.x-button.disabled {\n    background: rgba(45, 45, 45, 0.05);\n    color: rgba(42, 42, 42, 0.4);\n    cursor: default;\n}\n\n.x-button.disabled:hover {\n    background: rgba(45, 45, 45, 0.05);\n    transform: none;\n}\n\n.x-gradient-font {\n    background: linear-gradient(270deg, #845ade 0%, #2e6ed6 25%, #ff7d4b 75%, #ec9b0b 100%);\n    background-clip: text;\n    -webkit-background-clip: text;\n    -webkit-text-fill-color: transparent;\n}\n\n.x-gradient-block {\n    color: #3f3f3f;\n    background: linear-gradient(270deg, #845ade2f 0%, #2e6ed62f 25%, #ff7d4b2f 75%, #ec9b0b2f 100%);\n    border-radius: 16px;\n}\n\n.x-gradient-border {\n    position: relative;\n    padding: 1px;\n    margin: 3px;\n    border: 3px;\n    background: white;\n    background-clip: padding-box;\n    border: solid border transparent;\n    border-radius: 16px;\n}\n\n.x-gradient-border::before {\n    content: '';\n    position: absolute;\n    top: 0; right: 0; bottom: 0; left: 0;\n    z-index: -1;\n    margin: -3px;\n    border-radius: 16px;\n    background: linear-gradient(270deg, #845ade 0%, #2e6ed6 25%, #ff7d4b 75%, #ec9b0b 100%);\n}\n\n.x-section-title {\n    
text-align: center;\n    margin: 100px 0px 48px 0px;\n    font-size: 36px;\n    font-weight: 600;\n    letter-spacing: 2px;\n    text-transform: uppercase;\n    color: #333;\n    position: relative;\n    padding-bottom: 20px;\n}\n\n.x-section-title::after {\n    content: '';\n    position: absolute;\n    bottom: 0;\n    left: 50%;\n    transform: translateX(-50%);\n    width: 80px;\n    height: 3px;\n    background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);\n    border-radius: 2px;\n}\n\n.x-note {\n    color: rgba(42, 42, 42, 0.7);\n    font-size: 14px;\n    font-weight: 300;\n}\n\n.x-card {\n    position: relative;\n    display: flex;\n    align-items: center;\n    justify-content: center;\n    flex-wrap: wrap;\n}\n\n.x-card .caption {\n    height: 200px;\n    display: flex;\n    flex-direction: column;\n    align-items: center;\n    justify-content: center;\n    color: #2a2a2a;\n    font-size: 16px;\n    font-weight: 600;\n    width: 100%;\n}\n\n.x-handwriting {\n    width: 100%;\n    font-family: 'Segoe Print';\n    font-size: 12px;\n    font-weight: 600;\n    line-height: 1.5;\n    color: black;\n    text-align: justify;\n}\n\n.x-image-prompt {\n    position: relative;\n    height: calc(100% - 2px);\n    aspect-ratio: 1/1;\n    display: flex;\n    align-items: center;\n    justify-content: center;\n}\n\n.x-image-prompt img {\n    max-width: 100%;\n    max-height: 100%;\n}\n\n.x-small-header {\n    text-align: center;\n    margin-top: 64px;\n    margin-bottom: 24px;\n    margin-left: 4px;\n    font-size: 16px;\n    font-weight: 500;\n    letter-spacing: 4px;\n    text-transform: uppercase;\n    color: #666;\n}\n\n.x-dot-card {\n    background: rgba(255, 255, 255, 0.8);\n    border-radius: 16px;\n    padding: 24px;\n    display: flex;\n    flex-direction: column;\n    box-shadow: 0px 2px 8px rgba(0, 0, 0, 0.08);\n}\n\n.x-dot-card-title {\n    margin-top: 0;\n    margin-bottom: 0px;\n    font-size: 20px; \n    color: #2a2a2a;\n    display: flex;\n    align-items: center;\n    gap: 10px;\n}\n\n#main {\n    max-width: 1000px;\n    margin: 0px auto;\n    padding-bottom: 200px;\n}\n\n.author-info {\n    display: flex;\n    justify-content: center;\n    align-items: center;\n    gap: 32px;\n    padding: 8px;\n}\n\n.author-link {\n    color: #2a2a2a;\n    text-decoration: none;\n    font-weight: 500;\n}\n  \n.author-link:focus, .author-link:hover {\n    text-decoration: underline;\n}\n\n.affiliation-link {\n    font-size: 14px;\n    color: rgba(42, 42, 42, 0.7);\n    text-decoration: none;\n    font-weight: 300;\n}\n\n#links {\n    margin: 16px 0;\n    display: flex;\n    align-items: center;\n    justify-content: center;\n    flex-wrap: wrap;\n}\n\n#links div {\n    margin: 4px 8px;\n    height: 38px;\n    display: flex;\n    align-items: center;\n    justify-content: center;\n}\n\n#links a {\n    height: 20px;\n    padding: 8px 16px;\n    color: #2a2a2a;\n    font-size: 16px;\n    font-weight: 300;\n    background: rgba(255, 255, 255, 0.9);\n    border-radius: 50px;\n    box-shadow: 0px 2px 4px rgba(0, 0, 0, 0.1);\n    display: flex;\n    align-items: center;\n    justify-content: center;\n    transition: all 0.2s ease;\n}\n\n#links a:hover {\n    background: rgba(255, 255, 255, 1);\n    transform: translateY(-2px);\n    box-shadow: 0px 4px 6px rgba(0, 0, 0, 0.15);\n}\n\n#links a.disabled {\n    background-color: rgba(200, 200, 200, 0.5);\n    color: rgba(42, 42, 42, 0.5);\n}\n\n#links a.disabled:hover {\n    background-color: rgba(200, 200, 200, 0.5);\n}\n\n#links a::before {\n    /* 
Material Icons */\n    font-family: 'Material Icons' !important;\n    font-style: normal;\n    font-weight: normal;\n    font-variant: normal;\n    text-transform: none;\n    line-height: 1;\n    letter-spacing: normal;\n    word-wrap: normal;\n    white-space: nowrap;\n    direction: ltr;\n\n    /* Better Font Rendering =========== */\n    -webkit-font-smoothing: antialiased;\n    -moz-osx-font-smoothing: grayscale;\n    text-rendering: optimizeLegibility;\n    font-feature-settings: 'liga';\n\n    margin-right: 8px;\n    font-size: 20px;\n}\n\n#links #paper::before {\n    content: \"description\";\n}\n\n#links #arxiv::before {\n    content: \"article\";\n}\n\n#links #code::before {\n    content: \"code\";\n}\n\n#links #poster::before {\n    content: \"picture_as_pdf\";\n}\n\n#links #video::before {\n    content: \"play_circle\";\n}\n\n#links #demo::before {\n    content: \"rocket_launch\";\n}\n\n.feature-container {\n    max-width: 1000px;\n    margin: 32px auto;\n    font-family: ui-sans-serif, system-ui, -apple-system, BlinkMacSystemFont, \"Segoe UI\", Roboto, \"Helvetica Neue\", Arial, sans-serif;\n}\n\n.feature-tabs {\n    display: grid;\n    grid-template-columns: repeat(4, 1fr);\n    gap: 16px;\n}\n\n.feature-tab {\n    aspect-ratio: 1 / 1;\n    background-color: rgba(255, 255, 255, 0.8);\n    border: 2px solid transparent;\n    border-radius: 12px;\n    cursor: pointer;\n    display: flex;\n    flex-direction: column;\n    align-items: center;\n    justify-content: center;\n    transition: all 0.3s ease;\n    color: #666;\n    padding: 10px;\n    text-align: center;\n    box-shadow: 0px 2px 4px rgba(0, 0, 0, 0.08);\n}\n\n.feature-tab:hover {\n    background-color: rgba(255, 255, 255, 1);\n    transform: translateY(-2px);\n    box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.12);\n}\n\n.feature-tab.active-yellow { border-color: #facd5c; color: #facd5c; background-color: rgba(251, 191, 36, 0.1); }\n.feature-tab.active-red    { border-color: #d17969; color: #d17969; background-color: rgba(209, 121, 105, 0.1); }\n.feature-tab.active-blue   { border-color: #60a5fa; color: #60a5fa; background-color: rgba(96, 165, 250, 0.1); }\n.feature-tab.active-purple { border-color: #b7a5ff; color: #b7a5ff; background-color: rgba(167, 139, 250, 0.1); }\n\n.feature-tab svg {\n    width: 64px;\n    height: 64px;\n    margin-bottom: 8px;\n    fill: currentColor;\n}\n\n.feature-tab span {\n    font-size: 15px;\n    font-weight: 400;\n    line-height: 1.2;\n}\n\n.feature-panel {\n    display: none;\n    padding: 32px;\n    animation: fadeIn 0.4s ease;\n}\n\n.feature-panel.active {\n    display: block;\n}\n\n@keyframes fadeIn {\n    from { opacity: 0; transform: translateY(10px); }\n    to { opacity: 1; transform: translateY(0); }\n}\n\n@media (max-width: 600px) {\n    .feature-tabs {\n        grid-template-columns: repeat(2, 1fr);\n    }\n    .feature-panel {\n        padding: 16px;\n    }\n}\n\n.bibtex-entry {\n    margin: 32px auto;\n    max-width: 900px;\n    padding: 0px 24px;\n    color: #2a2a2a;\n    font-family: consolas, monospace;\n    white-space: pre;\n    text-wrap: wrap;\n    font-size: 14px;\n    font-weight: 300;\n    display: flex;\n    flex-direction: column;\n    align-items: left;\n    justify-content: center;\n}\n\n.line {\n    display: grid;\n    grid-template-columns: max-content 1fr; \n    gap: 0; \n}\n\n.key {\n    font-family: consolas, monospace;\n    text-align: right;\n    padding-right: 0; \n}\n\n.value {\n    font-family: consolas, monospace;\n    word-break: break-word; 
\n}\n\n#bottombar {\n    position: absolute;\n    bottom: 0px;\n    height: 100px;\n    width: 100%;\n    display: flex;\n    flex-direction: column;\n    align-items: center;\n    justify-content: space-around;\n    user-select: none;\n}\n\n#bottombar .row {\n    width: 90%;\n    padding: 0px 5%;\n    display: flex;\n    align-items: center;\n    justify-content: space-between;\n    user-select: none;\n}\n\n#bottombar div {\n    color: rgba(42, 42, 42, 0.7);\n    font-size: 12px;\n    font-weight: 500;\n}\n\n#bottombar div a {\n    color: rgba(42, 42, 42, 0.7);\n    font-size: 12px;\n    font-weight: 500;\n}\n\n#bottombar div a:hover {\n    color: rgba(42, 42, 42, 1);\n    font-size: 12px;\n}\n\n#bottombar div span {\n    font-weight: 700;\n}\n\n.scroll-indicator {\n    position: fixed;\n    bottom: 4px;\n    left: 50%;\n    transform: translateX(-50%);\n    display: flex;\n    flex-direction: column;\n    align-items: center;\n    color: #666;\n    cursor: pointer;\n    z-index: 1000;\n    \n    animation: scroll-bounce 2s infinite;\n    transition: opacity 0.5s ease, visibility 0.5s;\n}\n\n.scroll-indicator.hidden {\n    opacity: 0;\n    visibility: hidden;\n}\n\n@keyframes scroll-bounce {\n    0%, 20%, 50%, 80%, 100% {\n        transform: translateX(-50%) translateY(0);\n    }\n    40% {\n        transform: translateX(-50%) translateY(-10px);\n    }\n    60% {\n        transform: translateX(-50%) translateY(-5px);\n    }\n}\n\n/* Animated Gradient Background with Mouse Follow */\n.animated-gradient {\n    --mouse-x: 50%;\n    --mouse-y: 50%;\n    background: \n        radial-gradient(circle at var(--mouse-x) var(--mouse-y), rgba(102, 126, 234, 0.8) 0%, transparent 50%),\n        radial-gradient(circle at calc(100% - var(--mouse-x)) calc(100% - var(--mouse-y)), rgba(118, 75, 162, 0.8) 0%, transparent 50%),\n        radial-gradient(circle at var(--mouse-x) calc(100% - var(--mouse-y)), rgba(240, 147, 251, 0.6) 0%, transparent 50%),\n        linear-gradient(135deg, #667eea 0%, #764ba2 50%, #f093fb 100%);\n    background-size: 200% 200%;\n    animation: gradient-shift 20s ease infinite;\n    transition: background 0.1s ease-out;\n}\n\n/* Hero expandable effect */\n.hero-expandable {\n    position: sticky;\n    top: 0;\n    z-index: 10;\n    transition: height 0.6s cubic-bezier(0.4, 0, 0.2, 1);\n    margin-bottom: 0;\n}\n\n@keyframes gradient-shift {\n    0% {\n        background-position: 0% 50%;\n    }\n    50% {\n        background-position: 100% 50%;\n    }\n    100% {\n        background-position: 0% 50%;\n    }\n}"
  },
  {
    "path": "docs/pv.css",
    "content": ".pv-video-wrapper {\n    position: relative;\n    width: 100%;\n    aspect-ratio: 16 / 9;\n    margin: 0 auto;\n    background-color: #2a2a2a;\n    overflow: hidden;\n    border-radius: 8px;\n}\n\n.pv-video-element {\n    width: 100%;\n    height: 100%;\n    display: block;\n    object-fit: contain;\n}\n\n.pv-poster-overlay {\n    position: absolute;\n    top: 0;\n    left: 0;\n    width: 100%;\n    height: 100%;\n    cursor: pointer;\n    z-index: 10;\n    display: flex;\n    justify-content: center;\n    align-items: center;\n    background-size: cover;\n    background-position: center;\n    transition: opacity 0.3s ease;\n}\n\n.pv-play-btn {\n    width: 64px;\n    height: 64px;\n    background-color: rgba(255, 255, 255, 0.9);\n    border: 2px solid #2a2a2a;\n    border-radius: 50%;\n    display: flex;\n    justify-content: center;\n    align-items: center;\n    transition: transform 0.2s ease, background-color 0.2s;\n    box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.2);\n}\n\n.pv-play-btn::after {\n    content: '';\n    display: block;\n    width: 0;\n    height: 0;\n    border-top: 10px solid transparent;\n    border-bottom: 10px solid transparent;\n    border-left: 16px solid #2a2a2a;\n    margin-left: 4px;\n}\n\n.pv-poster-overlay:hover .pv-play-btn {\n    transform: scale(1.1);\n    background-color: rgba(255, 255, 255, 1);\n}\n\n.pv-video-wrapper.is-playing .pv-poster-overlay {\n    opacity: 0;\n    pointer-events: none;\n}"
  },
  {
    "path": "docs/style.css",
    "content": "/* Icomoon font removed - now using Material Icons via CDN */\n/* Material Icons are loaded in index.html via Google Fonts CDN */\n"
  },
  {
    "path": "docs/stylesheet.css",
    "content": "/* Font definitions removed - using system Segoe UI font */\n/* All font-face declarations for Avenir Next Cyr have been removed */\n/* The site now uses 'Segoe UI' as defined in main.css and style.css */\n"
  },
  {
    "path": "docs/window.css",
    "content": "#fullscreen {\n    position: fixed;\n    top: 0;\n    left: 0;\n    width: 100vw;\n    height: 100vh;\n    background: transparent;\n    display: none;\n    align-items: center;\n    justify-content: center;\n    z-index: 1000;\n    user-select: none;\n    backdrop-filter: blur(10px);\n    opacity: 0;\n    transition: opacity 0.25s ease;\n}\n\n#fullscreen #window {\n    position: relative;\n    min-width: 25vw;\n    min-height: 25vh;\n    max-width: 100vw;\n    max-height: 90vh;\n    background: #ffffff;\n    border-radius: 16px;\n    box-shadow: 0px 4px 16px rgba(0, 0, 0, 0.2);\n    padding: 8px;\n    \n}\n\n#fullscreen #window #close {\n    position: absolute;\n    top: 0px;\n    right: 0px;\n    width: 31px;\n    height: 30px;\n    padding: 0px 0px 2px 1px;\n    color: black;\n    font-size: 16px;\n    font-weight: 700;\n    display: flex;\n    align-items: center;\n    justify-content: center;\n    cursor: pointer;\n    transition: all 0.2s ease;\n    z-index: 100;\n}\n\n#fullscreen #window #close {\n    color: #2a2a2a;\n}\n\n#fullscreen #window #close:hover {\n    color: #d32f2f;\n}\n\n#fullscreen #window #content {\n    max-width: calc(100vw - 16px);\n    max-height: calc(90vh - 16px);\n    overflow-x: hidden;\n    overflow-y: auto;\n}\n\n.modelviewer-container {\n    width: 500px;\n    height: 500px;\n    margin: 8px;\n    border-radius: 8px;\n    background: white;\n    box-shadow: inset 0px 0px 4px rgba(0, 0, 0, 0.25);\n    overflow: hidden;\n    position: relative;\n}\n\n.modelviewer-container model-viewer {\n    width: 100% !important;\n    height: 100% !important;\n    display: block !important;\n    background-color: #f5f5f5;\n}\n\n.modelviewer-container model-viewer button {\n    height: 16px;\n    padding: 0px 6px;\n    background: rgba(255, 255, 255, 0.75);\n    border-radius: 50px;\n    box-shadow: 0px 0px 4px rgba(0, 0, 0, 0.25);\n    border: none;\n    font-size: 12px;\n    font-weight: 300;\n    display: none;\n    opacity: 0;\n    align-items: center;\n    justify-content: center;\n    pointer-events: none;\n}\n\n.modelviewer-panel {\n    width: 300px;\n    margin: 8px;\n    margin-top: 0px;\n    display: flex;\n    flex-direction: column;\n    align-items: start;\n    justify-content: start;\n}\n\n.modelviewer-panel-desc {\n    width: 100%;\n}\n\n.modelviewer-panel-desc div {\n    font-size: 16px;\n    font-weight: 500;\n    margin: 4px;\n}\n\n.modelviewer-panel-prompt {\n    width: calc(100% - 16px);\n    height: 250px;\n    padding: 8px;\n    background: #f5f5f5;\n    border-radius: 8px;\n    box-shadow: inset 0px 0px 4px rgba(0, 0, 0, 0.1);\n    display: flex;\n    align-items: start;\n    justify-content: center;\n    overflow-y: auto;\n    user-select: text;\n}\n\n.modelviewer-panel-button {\n    height: 40px;\n    margin: 4px 4px;\n    padding: 0px 14px;\n    background: rgba(45, 45, 45, 0.08);\n    color: #2a2a2a;\n    border-radius: 50px;\n    box-shadow: 0px 2px 4px rgba(0, 0, 0, 0.1);\n    font-size: 16px;\n    font-weight: 300;\n    display: flex;\n    align-items: center;\n    justify-content: center;\n    cursor: pointer;\n    transition: all 0.2s ease;\n}\n\n.modelviewer-panel-button.small {\n    height: 32px;\n    padding: 0px 12px;\n    border-radius: 50px;\n    font-size: 14px;\n    font-weight: 300;\n}\n\n.modelviewer-panel-button.tiny {\n    height: 24px;\n    padding: 0px 10px;\n    border-radius: 50px;\n    font-size: 12px;\n    font-weight: 300;\n}\n\n.modelviewer-panel-button:hover {\n    background: rgba(45, 45, 45, 0.15);\n    
transform: translateY(-2px);\n    box-shadow: 0px 4px 6px rgba(0, 0, 0, 0.15);\n}\n\n.modelviewer-panel-button.checked {\n    border: 2px solid #2a2a2a;\n    background: rgba(45, 45, 45, 0.1);\n    color: #2a2a2a;\n}\n"
  },
  {
    "path": "main.py",
    "content": "# ==============================================================================\n# Original work Copyright (c) 2025 Tencent.\n# Modified work Copyright (c) 2025 UltraShape Team.\n# \n# Modified by UltraShape on 2025.12.25\n# ==============================================================================\n\n# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT\n# except for the third-party components listed below.\n# Hunyuan 3D does not impose any additional limitations beyond what is outlined\n# in the repsective licenses of these third-party components.\n# Users must comply with all terms and conditions of original licenses of these third-party\n# components and must ensure that the usage of the third party components adheres to\n# all relevant laws and regulations.\n\n# For avoidance of doubts, Hunyuan 3D means the large language models and\n# their software and algorithms, including trained model weights, parameters (including\n# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,\n# fine-tuning enabling code and other elements of the foregoing made publicly available\n# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.\n\n\n# import warnings\n# warnings.filterwarnings(\"ignore\")\n\nimport os\nimport torch\nimport argparse\nfrom pathlib import Path\nfrom typing import Tuple, List\nfrom omegaconf import OmegaConf, DictConfig\nfrom einops._torch_specific import allow_ops_in_compiled_graph  # requires einops>=0.6.1\nallow_ops_in_compiled_graph()\n\nimport pytorch_lightning as pl\nfrom pytorch_lightning.callbacks import ModelCheckpoint, Callback\nfrom pytorch_lightning.strategies import DDPStrategy, DeepSpeedStrategy\nfrom pytorch_lightning.loggers import Logger, TensorBoardLogger\nfrom pytorch_lightning.utilities import rank_zero_info\n\nfrom ultrashape.utils import get_config_from_file, instantiate_from_config\n\n\nclass SetupCallback(Callback):\n    def __init__(self, config: DictConfig, basedir: Path, logdir: str = \"log\", ckptdir: str = \"ckpt\") -> None:\n        super().__init__()\n        self.logdir = basedir / logdir\n        self.ckptdir = basedir / ckptdir\n        self.config = config\n\n    def on_fit_start(self, trainer: pl.trainer.Trainer, pl_module: pl.LightningModule) -> None:\n        if trainer.global_rank == 0:\n            os.makedirs(self.logdir, exist_ok=True)\n            os.makedirs(self.ckptdir, exist_ok=True)\n\n\ndef setup_callbacks(config: DictConfig) -> Tuple[List[Callback], Logger]:\n    training_cfg = config.training\n    basedir = Path(training_cfg.output_dir)\n    os.makedirs(basedir, exist_ok=True)\n    all_callbacks = []\n\n    setup_callback = SetupCallback(config, basedir)\n    all_callbacks.append(setup_callback)\n    \n    checkpoint_callback = ModelCheckpoint(\n        dirpath=setup_callback.ckptdir,\n        filename=\"ckpt-{step:08d}\",\n        save_top_k=-1,\n        verbose=False,\n        every_n_train_steps=training_cfg.every_n_train_steps)\n    all_callbacks.append(checkpoint_callback)\n\n    if \"callbacks\" in config:\n        for key, value in config['callbacks'].items():\n            custom_callback = instantiate_from_config(value)\n            all_callbacks.append(custom_callback)\n\n    logger = TensorBoardLogger(save_dir=str(setup_callback.logdir), name=\"tensorboard\")\n\n    return all_callbacks, logger\n\n\ndef merge_cfg(cfg, arg_cfg):\n    for key in arg_cfg.keys():\n        if key in cfg.training:\n            
arg_cfg[key] = cfg.training[key]\n    cfg.training = DictConfig(arg_cfg)\n    return cfg\n\n\ndef get_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--fast\", action='store_true')\n    parser.add_argument(\"-c\", \"--config\", type=str, required=True)\n    parser.add_argument(\"-s\", \"--seed\", type=int, default=0)\n    parser.add_argument(\"-nn\", \"--num_nodes\", type=int, default=1)\n    parser.add_argument(\"-ng\", \"--num_gpus\", type=int, default=1)\n    parser.add_argument(\"-u\", \"--update_every\", type=int, default=1)\n    parser.add_argument(\"-st\", \"--steps\", type=int, default=50000000)\n    parser.add_argument(\"-lr\", \"--base_lr\", type=float, default=4.5e-6)\n    parser.add_argument(\"-a\", \"--use_amp\", default=False, action=\"store_true\")\n    parser.add_argument(\"--amp_type\", type=str, default=\"16\")\n    parser.add_argument(\"--gradient_clip_val\", type=float, default=None)\n    parser.add_argument(\"--gradient_clip_algorithm\", type=str, default=None)\n    parser.add_argument(\"--every_n_train_steps\", type=int, default=50000)\n    parser.add_argument(\"--log_every_n_steps\", type=int, default=50)\n    parser.add_argument(\"--val_check_interval\", type=int, default=1024)\n    parser.add_argument(\"--limit_val_batches\", type=int, default=64)\n    parser.add_argument(\"--monitor\", type=str, default=\"val/total_loss\")\n    parser.add_argument(\"--output_dir\", type=str, help=\"the output directory to save everything.\")\n    parser.add_argument(\"--ckpt_path\", type=str, default=\"\", help=\"the restore checkpoints.\")\n    parser.add_argument(\"--deepspeed\", default=False, action=\"store_true\")\n    parser.add_argument(\"--deepspeed2\", default=False, action=\"store_true\")\n    parser.add_argument(\"--scale_lr\", type=bool, nargs=\"?\", const=True, default=False,\n                        help=\"scale base-lr by ngpu * batch_size * n_accumulate\")\n    return parser.parse_args()\n    \n\nif __name__ == \"__main__\":\n    \n    args = get_args()\n    \n    if args.fast:\n        torch.backends.cudnn.allow_tf32 = True\n        torch.backends.cuda.matmul.allow_tf32 = True\n        torch.set_float32_matmul_precision('medium')\n        torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL = 0.05\n\n    # Set random seed\n    pl.seed_everything(args.seed, workers=True)\n\n    # Load configuration\n    config = get_config_from_file(args.config)\n    config = merge_cfg(config, vars(args))\n    training_cfg = config.training\n\n    # print config\n    rank_zero_info(\"Begin to print configuration ...\")\n    rank_zero_info(OmegaConf.to_yaml(config))\n    rank_zero_info(\"Finish print ...\")\n\n    # Setup callbacks\n    callbacks, loggers = setup_callbacks(config)\n\n    # Build data modules\n    data: pl.LightningDataModule = instantiate_from_config(config.dataset)\n\n    # Build model\n    model: pl.LightningModule = instantiate_from_config(config.model)\n    \n    nodes = args.num_nodes\n    ngpus = args.num_gpus\n    base_lr = training_cfg.base_lr\n    accumulate_grad_batches = training_cfg.update_every\n    batch_size = config.dataset.params.batch_size\n\n    if 'NNODES' in os.environ:\n        nodes = int(os.environ['NNODES'])\n        training_cfg.num_nodes = nodes\n        args.num_nodes = nodes\n\n    if args.scale_lr:\n        model.learning_rate = accumulate_grad_batches * nodes * ngpus * batch_size * base_lr\n        info = f\"Setting learning rate to {model.learning_rate:.2e} = {accumulate_grad_batches} (accumulate)\"\n        
info += f\" * {nodes} (nodes) * {ngpus} (num_gpus) * {batch_size} (batchsize) * {base_lr:.2e} (base_lr)\"\n        rank_zero_info(info)\n    else:\n        model.learning_rate = base_lr\n        rank_zero_info(\"++++ NOT USING LR SCALING ++++\")\n        rank_zero_info(f\"Setting learning rate to {model.learning_rate:.2e}\")\n\n    # Build trainer\n    if args.num_nodes > 1 or args.num_gpus > 1:\n        if args.deepspeed:\n            ddp_strategy = DeepSpeedStrategy(stage=1)\n        elif args.deepspeed2:\n            ddp_strategy = 'deepspeed_stage_2'\n        else:\n            ddp_strategy = DDPStrategy(find_unused_parameters=False, bucket_cap_mb=1500)\n    else:\n        ddp_strategy = 'ddp'\n\n    rank_zero_info(f'*' * 100)\n    if training_cfg.use_amp:\n        amp_type = training_cfg.amp_type\n        assert amp_type in ['bf16', '16', '32'], f\"Invalid amp_type: {amp_type}\"\n        rank_zero_info(f'Using {amp_type} precision')\n    else:\n        amp_type = 32\n        rank_zero_info(f'Using 32 bit precision')\n    rank_zero_info(f'*' * 100)\n\n    trainer = pl.Trainer(\n        max_steps=training_cfg.steps,\n        precision=amp_type,\n        callbacks=callbacks,\n        accelerator=\"gpu\",\n        devices=args.num_gpus,\n        num_nodes=training_cfg.num_nodes,\n        strategy=ddp_strategy,\n        gradient_clip_val=training_cfg.get('gradient_clip_val'),\n        gradient_clip_algorithm=training_cfg.get('gradient_clip_algorithm'),\n        accumulate_grad_batches=args.update_every,\n        logger=loggers,\n        log_every_n_steps=training_cfg.log_every_n_steps,\n        val_check_interval=training_cfg.val_check_interval,\n        limit_val_batches=training_cfg.limit_val_batches,\n        check_val_every_n_epoch=None\n    )\n\n    # Train\n    if training_cfg.ckpt_path == '': \n        training_cfg.ckpt_path = None\n    trainer.fit(model, datamodule=data, ckpt_path=training_cfg.ckpt_path)\n"
  },
  {
    "path": "requirements.txt",
    "content": "accelerate==1.1.1\ndiffusers==0.30.0\ndeepspeed\ndiso==0.1.4\neinops==0.8.1\nflash_attn==2.8.3\nhuggingface_hub==0.36.0\nimageio==2.36.0\nipywidgets==8.1.7\njaxtyping==0.3.4\nmatplotlib==3.10.8\nnumpy==1.24.4\nomegaconf==2.3.0\nopencv_python==4.10.0.84\nopencv_python_headless==4.11.0.86\npandas==2.3.3\nPillow==12.0.0\npymeshlab==2022.2.post3\npythreejs==2.4.2\npytorch_lightning==1.9.5\nPyYAML==6.0.2\nsafetensors==0.7.0\nsageattention==1.0.6\nscikit-image==0.24.0\nonnxruntime\nrembg\ntensorboard\ntimm==1.0.22\ntorchdiffeq==0.2.5\ntqdm==4.66.5\ntransformers==4.37.2\ntrimesh==4.4.7\ntypeguard==4.3.0\nwandb==0.23.1\n"
  },
  {
    "path": "scripts/gradio_app.py",
    "content": "import argparse\nimport gc\nimport os\nimport sys\n\nimport gradio as gr\nimport torch\nfrom omegaconf import OmegaConf\n\n# Add project root to path\nsys.path.append(os.getcwd())\n\nfrom ultrashape.rembg import BackgroundRemover\nfrom ultrashape.utils.misc import instantiate_from_config\nfrom ultrashape.surface_loaders import SharpEdgeSurfaceLoader\nfrom ultrashape.utils import voxelize_from_point\nfrom ultrashape.pipelines import UltraShapePipeline\n\n# Global variables to cache the model\nMODEL_CACHE = {}\n\n\ndef get_pipeline_cached(config_path, ckpt_path, device='cuda', low_vram=False):\n    # Check if we have a valid cached pipeline for this checkpoint\n    if \"pipeline\" in MODEL_CACHE and MODEL_CACHE.get(\"ckpt_path\") == ckpt_path:\n        print(\"Using cached pipeline...\")\n        return MODEL_CACHE[\"pipeline\"], MODEL_CACHE[\"config\"]\n\n    # Clear old cache if it exists (e.g. different checkpoint)\n    if MODEL_CACHE:\n        print(\"Clearing old model cache...\")\n        MODEL_CACHE.clear()\n        gc.collect()\n        torch.cuda.empty_cache()\n\n    print(f\"Loading config from {config_path}...\")\n    config = OmegaConf.load(config_path)\n\n    print(\"Instantiating VAE...\")\n    vae = instantiate_from_config(config.model.params.vae_config)\n\n    print(\"Instantiating DiT...\")\n    dit = instantiate_from_config(config.model.params.dit_cfg)\n\n    print(\"Instantiating Conditioner...\")\n    conditioner = instantiate_from_config(config.model.params.conditioner_config)\n\n    print(\"Instantiating Scheduler & Processor...\")\n    scheduler = instantiate_from_config(config.model.params.scheduler_cfg)\n    image_processor = instantiate_from_config(config.model.params.image_processor_cfg)\n\n    print(f\"Loading weights from {ckpt_path}...\")\n    weights = torch.load(ckpt_path, map_location='cpu')\n\n    vae.load_state_dict(weights['vae'], strict=True)\n    dit.load_state_dict(weights['dit'], strict=True)\n    conditioner.load_state_dict(weights['conditioner'], strict=True)\n\n    vae.eval().to(device)\n    dit.eval().to(device)\n    conditioner.eval().to(device)\n\n    if hasattr(vae, 'enable_flashvdm_decoder'):\n        vae.enable_flashvdm_decoder()\n\n    print(\"Creating Pipeline...\")\n    pipeline = UltraShapePipeline(\n        vae=vae,\n        model=dit,\n        scheduler=scheduler,\n        conditioner=conditioner,\n        image_processor=image_processor\n    )\n\n    if low_vram:\n        pipeline.enable_model_cpu_offload()\n\n    MODEL_CACHE[\"pipeline\"] = pipeline\n    MODEL_CACHE[\"config\"] = config\n    MODEL_CACHE[\"ckpt_path\"] = ckpt_path\n\n    return pipeline, config\n\n\ndef predict(\n        image_input,\n        mesh_input,\n        steps,\n        scale,\n        octree_res,\n        num_latents,\n        chunk_size,\n        seed,\n        remove_bg,\n        ckpt_path,\n        low_vram\n):\n    # Aggressive memory cleanup at start\n    gc.collect()\n    torch.cuda.empty_cache()\n\n    try:\n        device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n        config_path = \"configs/infer_dit_refine.yaml\"\n\n        if not os.path.exists(config_path):\n            raise FileNotFoundError(f\"Config not found at {config_path}\")\n\n        pipeline, config = get_pipeline_cached(config_path, ckpt_path, device, low_vram)\n\n        voxel_res = config.model.params.vae_config.params.voxel_query_res\n\n        print(f\"Initializing Surface Loader (Token Num: {num_latents})...\")\n        loader = 
SharpEdgeSurfaceLoader(\n            num_sharp_points=204800,\n            num_uniform_points=204800,\n        )\n\n        print(f\"Processing inputs...\")\n        if image_input is None:\n            raise ValueError(\"Image input is required\")\n        if mesh_input is None:\n            raise ValueError(\"Mesh input is required\")\n\n        # Handle image input\n        if isinstance(image_input, dict):\n            # In newer gradio versions Image component might return a dict for mask etc, but usually just PIL/numpy\n            # if type='pil' it is PIL.Image\n            pass\n\n        image = image_input.convert(\"RGBA\")\n\n        if remove_bg or image.mode != 'RGBA':\n            rembg = BackgroundRemover()\n            image = rembg(image)\n\n        # Handle mesh input - Gradio Model3D returns path to file\n        surface = loader(mesh_input, normalize_scale=scale).to(device, dtype=torch.float16)\n        pc = surface[:, :, :3]  # [B, N, 3]\n\n        # Voxelize\n        _, voxel_idx = voxelize_from_point(pc, num_latents, resolution=voxel_res)\n\n        print(\"Running diffusion process...\")\n        gen_device = \"cpu\" if low_vram else device\n        generator = torch.Generator(gen_device).manual_seed(int(seed))\n\n        with torch.autocast(device_type=\"cuda\", dtype=torch.bfloat16):\n            mesh_out_list, _ = pipeline(\n                image=image,\n                voxel_cond=voxel_idx,\n                generator=generator,\n                box_v=1.0,\n                mc_level=0.0,\n                octree_resolution=int(octree_res),\n                num_chunks=int(chunk_size),\n                num_inference_steps=int(steps)\n            )\n\n        # Save output\n        output_dir = \"outputs_gradio\"\n        os.makedirs(output_dir, exist_ok=True)\n        base_name = \"output\"\n        save_path = os.path.join(output_dir, f\"{base_name}_refined.glb\")\n\n        mesh_out = mesh_out_list[0]\n        mesh_out.export(save_path)\n        print(f\"Successfully saved to {save_path}\")\n\n        return save_path\n\n    finally:\n        # Aggressive memory cleanup at end\n        gc.collect()\n        torch.cuda.empty_cache()\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"UltraShape Gradio App\")\n    parser.add_argument(\"--ckpt\", type=str, required=True, help=\"Path to split checkpoint (.pt)\")\n    parser.add_argument(\"--share\", action=\"store_true\", help=\"Share the gradio app\")\n    parser.add_argument(\"--low_vram\", action=\"store_true\", help=\"Optimize for low VRAM usage\")\n\n    args = parser.parse_args()\n\n    # Define Gradio Interface\n    with gr.Blocks(title=\"UltraShape Inference\") as demo:\n        gr.Markdown(\"# UltraShape Inference: Mesh & Image Refinement\")\n\n        with gr.Row():\n            with gr.Column():\n                image_input = gr.Image(type=\"pil\", label=\"Input Image\", image_mode=\"RGBA\")\n                mesh_input = gr.Model3D(label=\"Input Coarse Mesh (.glb, .obj)\")\n\n                with gr.Accordion(\"Advanced Parameters\", open=True):\n                    steps = gr.Slider(minimum=1, maximum=200, value=50, step=1, label=\"Inference Steps\")\n                    scale = gr.Slider(minimum=0.1, maximum=2.0, value=0.99, label=\"Mesh Normalization Scale\")\n                    octree_res = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, label=\"Octree Resolution\")\n                    num_latents = gr.Slider(minimum=1024, maximum=32768, value=32768, 
step=128,\n                                            label=\"Number of Latent Tokens (Use 8192 if OOM)\")\n                    chunk_size = gr.Slider(minimum=512, maximum=10000, value=2048, step=512,\n                                           label=\"Chunk Size (Use 2000 if OOM)\")\n                    seed = gr.Number(value=42, label=\"Random Seed\")\n                    remove_bg = gr.Checkbox(label=\"Remove Background\", value=False)\n\n                run_btn = gr.Button(\"Run Inference\", variant=\"primary\")\n\n            with gr.Column():\n                output_model = gr.Model3D(label=\"Refined Output Mesh\")\n\n        run_btn.click(\n            fn=lambda img, mesh, s, sc, oct, nml, chk, sd, rm: predict(img, mesh, s, sc, oct, nml, chk, sd, rm, args.ckpt,\n                                                                      args.low_vram),\n            inputs=[image_input, mesh_input, steps, scale, octree_res, num_latents, chunk_size, seed, remove_bg],\n            outputs=[output_model]\n        )\n\n    demo.launch(share=args.share, server_name='0.0.0.0', server_port=7860)\n"
  },
  {
    "path": "scripts/infer_dit_refine.py",
    "content": "import os\nimport sys\nimport argparse\nimport torch\nimport numpy as np\nfrom PIL import Image\nfrom omegaconf import OmegaConf\n\n# project_root = '[your_project_root_path]'  # Replace with your project root path\n# sys.path.insert(0, project_root)\n\nfrom ultrashape.rembg import BackgroundRemover\nfrom ultrashape.utils.misc import instantiate_from_config\nfrom ultrashape.surface_loaders import SharpEdgeSurfaceLoader\nfrom ultrashape.utils import voxelize_from_point\nfrom ultrashape.pipelines import UltraShapePipeline \n\ndef load_models(config_path, ckpt_path, device='cuda'):\n\n    print(f\"Loading config from {config_path}...\")\n    config = OmegaConf.load(config_path)\n    \n    print(\"Instantiating VAE...\")\n    vae = instantiate_from_config(config.model.params.vae_config)\n    \n    print(\"Instantiating DiT...\")\n    dit = instantiate_from_config(config.model.params.dit_cfg)\n    \n    print(\"Instantiating Conditioner...\")\n    conditioner = instantiate_from_config(config.model.params.conditioner_config)\n    \n    print(\"Instantiating Scheduler & Processor...\")\n    scheduler = instantiate_from_config(config.model.params.scheduler_cfg)\n    image_processor = instantiate_from_config(config.model.params.image_processor_cfg)\n    \n    print(f\"Loading weights from {ckpt_path}...\")\n    weights = torch.load(ckpt_path, map_location='cpu')\n    \n    vae.load_state_dict(weights['vae'], strict=True)\n    dit.load_state_dict(weights['dit'], strict=True)\n    conditioner.load_state_dict(weights['conditioner'], strict=True)\n    \n    vae.eval().to(device)\n    dit.eval().to(device)\n    conditioner.eval().to(device)\n    \n    if hasattr(vae, 'enable_flashvdm_decoder'):\n        vae.enable_flashvdm_decoder()\n\n    components = {\n        \"vae\": vae,\n        \"dit\": dit,\n        \"conditioner\": conditioner,\n        \"scheduler\": scheduler,\n        \"image_processor\": image_processor,\n    }\n    \n    return components, config\n\ndef run_inference(args):\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n    components, config = load_models(args.config, args.ckpt, device)\n    \n    pipeline = UltraShapePipeline(\n        vae=components['vae'],\n        model=components['dit'],\n        scheduler=components['scheduler'],\n        conditioner=components['conditioner'],\n        image_processor=components['image_processor']\n    )\n\n    if args.low_vram:\n        pipeline.enable_model_cpu_offload()\n\n    token_num = args.num_latents\n    voxel_res = config.model.params.vae_config.params.voxel_query_res\n    \n    print(f\"Initializing Surface Loader (Token Num: {token_num})...\")\n    loader = SharpEdgeSurfaceLoader(\n        num_sharp_points=204800,\n        num_uniform_points=204800,\n    )\n\n    print(f\"Processing inputs: {args.image} & {args.mesh}\")\n    image = Image.open(args.image)\n    \n    if args.remove_bg or image.mode != 'RGBA':\n        rembg = BackgroundRemover()\n        image = rembg(image)\n    \n    surface = loader(args.mesh, normalize_scale=args.scale).to(device, dtype=torch.float16)\n    pc = surface[:, :, :3] # [B, N, 3]\n    \n    # Voxelize\n    _, voxel_idx = voxelize_from_point(pc, token_num, resolution=voxel_res)\n    \n    print(\"Running diffusion process...\")\n    generator = torch.Generator(device).manual_seed(args.seed)\n    \n    with torch.autocast(device_type=\"cuda\", dtype=torch.bfloat16):\n        mesh, _ = pipeline(\n            image=image,\n            voxel_cond=voxel_idx,\n   
         generator=generator,\n            box_v=1.0,\n            mc_level=0.0,\n            octree_resolution=args.octree_res,\n            num_inference_steps=args.steps,\n            num_chunks=args.chunk_size,\n        )\n    \n    os.makedirs(args.output_dir, exist_ok=True)\n    base_name = os.path.splitext(os.path.basename(args.image))[0]\n    save_path = os.path.join(args.output_dir, f\"{base_name}_refined.glb\")\n    \n    mesh = mesh[0]\n    mesh.export(save_path)\n    print(f\"Successfully saved to {save_path}\")\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"UltraShape Inference Script\")\n    \n    parser.add_argument(\"--config\", type=str, default=\"configs/infer_dit2.yaml\", help=\"Path to inference config\")\n    parser.add_argument(\"--ckpt\", type=str, required=True, help=\"Path to split checkpoint (.pt)\")\n    parser.add_argument(\"--low_vram\", action=\"store_true\", help=\"Optimize for low VRAM usage\")\n    \n    parser.add_argument(\"--image\", type=str, required=True, help=\"Input image path\")\n    parser.add_argument(\"--mesh\", type=str, required=True, help=\"Input coarse mesh (.glb/.obj)\")\n    parser.add_argument(\"--output_dir\", type=str, default=\"outputs\", help=\"Output directory\")\n    \n    parser.add_argument(\"--steps\", type=int, default=50, help=\"Inference steps\")\n    parser.add_argument(\"--scale\", type=float, default=0.99, help=\"Mesh normalization scale\")\n    parser.add_argument(\"--num_latents\", type=int, default=32768, help=\"Number of latents\")\n    parser.add_argument(\"--chunk_size\", type=int, default=8000, help=\"Chunk size for inference\")\n    parser.add_argument(\"--octree_res\", type=int, default=1024, help=\"Marching Cubes resolution\")\n    parser.add_argument(\"--seed\", type=int, default=42, help=\"Random seed\")\n    parser.add_argument(\"--remove_bg\", action=\"store_true\", help=\"Force remove background\")\n\n    args = parser.parse_args()\n    \n    run_inference(args)\n"
  },
  {
    "path": "scripts/install_env.sh",
    "content": "conda create -n ultrashape python=3.10\nconda activate ultrashape \npip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121\npip install -r requirements.txt\npip install git+https://github.com/ashawkey/cubvh --no-build-isolation\n\npip install --no-build-isolation \"git+https://github.com/facebookresearch/pytorch3d.git@stable\"\npip install https://data.pyg.org/whl/torch-2.5.0%2Bcu121/torch_cluster-1.6.3%2Bpt25cu121-cp310-cp310-linux_x86_64.whl\n"
  },
  {
    "path": "scripts/run.sh",
    "content": "# sampling \n# python scripts/sampling.py \\\n#     --mesh_json data/mesh_paths.json \\\n#     --output_dir data/sample\n\n# inference refine_dit\npython scripts/infer_dit_refine.py \\\n    --ckpt checkpoints/ultrashape_v1.pt \\\n    --image inputs/image/1.png \\\n    --mesh inputs/coarse_mesh/1.glb \\\n    --config configs/infer_dit_refine.yaml\n    # --steps 12\n"
  },
  {
    "path": "scripts/sampling.py",
    "content": "import os\nimport trimesh\nimport numpy as np\nfrom typing import List, Optional, Any, Tuple, Union\nimport pytorch_lightning as pl\nfrom pytorch_lightning.utilities.types import STEP_OUTPUT\nimport torch\nfrom torch.utils.data import Dataset, DataLoader\nimport pytorch3d.structures\nimport pytorch3d.ops\nfrom scipy.stats import truncnorm\nimport json\nimport argparse\nimport cubvh\n\n# import logging\n# from tools.logger import init_log, set_all_log\n# sys_logger = init_log(\"sampler\", logging.DEBUG) \n# set_all_log(level=logging.DEBUG, path='./debug/logs')   \n    \ndef load_mesh(mesh_path: str, device: str = \"cuda\") -> Tuple[torch.Tensor, torch.Tensor]:\n    if mesh_path.endswith(\".npz\"):\n        mesh_np = np.load(mesh_path)\n        vertices, faces = torch.tensor(mesh_np[\"vertices\"], device=device), torch.tensor(mesh_np[\"faces\"].astype('i8'), device=device)\n    else:\n        mesh = trimesh.load(mesh_path, force='mesh')\n        vertices = torch.tensor(mesh.vertices, dtype=torch.float32, device=device)\n        faces = torch.tensor(mesh.faces, dtype=torch.long, device=device)\n    if faces.shape[0] > 2 * 1e8:\n        raise ValueError(f\"too many faces {faces.shape}\")\n    return vertices, faces\n\ndef compute_mesh_features(vertices: torch.Tensor, faces: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n    device = vertices.device\n    \n    v0 = vertices[faces[:, 0]]\n    v1 = vertices[faces[:, 1]]\n    v2 = vertices[faces[:, 2]]\n    face_normals = torch.cross(v1 - v0, v2 - v0)\n    face_areas = torch.norm(face_normals, dim=1) * 0.5\n    face_normals = face_normals / (face_areas.unsqueeze(1) * 2 + 1e-12)\n    \n    vertex_normals = torch.zeros_like(vertices)\n    face_normals_weighted = face_normals * face_areas.unsqueeze(1)\n    \n    vertex_normals.scatter_add_(0, faces[:, 0:1].expand(-1, 3), face_normals_weighted)\n    vertex_normals.scatter_add_(0, faces[:, 1:2].expand(-1, 3), face_normals_weighted)\n    vertex_normals.scatter_add_(0, faces[:, 2:3].expand(-1, 3), face_normals_weighted)\n    \n    vertex_normals = vertex_normals / (torch.norm(vertex_normals, dim=1, keepdim=True) + 1e-12)\n    \n    edges = torch.cat([\n        faces[:, [0, 1]],\n        faces[:, [1, 2]],\n        faces[:, [2, 0]]\n    ], dim=0)\n    \n    edges_unique, edges_inverse = torch.unique(torch.sort(edges, dim=1)[0], dim=0, return_inverse=True)\n    edge_normals_diff = torch.norm(\n        vertex_normals[edges[:, 0]] - vertex_normals[edges[:, 1]],\n        dim=1\n    )\n    \n    vertex_curvatures = torch.zeros(len(vertices), device=device)\n    vertex_curvatures.scatter_add_(0, edges[:, 0], edge_normals_diff)\n    vertex_curvatures.scatter_add_(0, edges[:, 1], edge_normals_diff)\n\n    vertex_degrees = torch.zeros(len(vertices), device=device)\n    vertex_degrees.scatter_add_(0, edges[:, 0], torch.ones_like(edge_normals_diff))\n    vertex_degrees.scatter_add_(0, edges[:, 1], torch.ones_like(edge_normals_diff))\n    \n    vertex_curvatures = vertex_curvatures / (vertex_degrees + 1e-12)\n    vertex_curvatures = (vertex_curvatures - vertex_curvatures.min()) / (\n        vertex_curvatures.max() - vertex_curvatures.min() + 1e-12)\n    \n    return face_areas, vertex_curvatures\n\ndef sample_uniform_points(\n    vertices: torch.Tensor,\n    faces: torch.Tensor,\n    num_samples: int,\n    random_seed: Optional[int] = None\n) -> Tuple[torch.Tensor, torch.Tensor]:\n\n    if random_seed is not None:\n        torch.manual_seed(random_seed)\n    mesh = 
pytorch3d.structures.Meshes(verts=[vertices], faces=[faces])\n    \n    points, normals = pytorch3d.ops.sample_points_from_meshes(\n        mesh, num_samples=num_samples, return_normals=True)\n    \n    return points[0], normals[0]\n\ndef sample_surface_points(\n    vertices: torch.Tensor,\n    faces: torch.Tensor,\n    num_samples: int,\n    min_samples_per_face: int = 0,\n    use_curvature: bool = True,\n    random_seed: Optional[int] = None\n) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"Curvature-based surface sampling\"\"\"\n    device = vertices.device\n    if random_seed is not None:\n        torch.manual_seed(random_seed)\n    \n    # Compute face areas and vertex curvatures\n    face_areas, vertex_curvatures = compute_mesh_features(vertices, faces)\n    \n    # Compute average curvature of faces\n    face_curvatures = torch.mean(vertex_curvatures[faces], dim=1)\n    sampling_weights = face_curvatures  # Use only curvature as weights\n    # Calculate number of sample points per face\n    num_faces = len(faces)\n    \n    # Chunk forward\n    if min_samples_per_face > 0:\n        base_samples = torch.full((num_faces,), min_samples_per_face, device=device)\n        remaining_samples = num_samples - torch.sum(base_samples).item()\n        \n        if remaining_samples > 0:\n            # Block sampling to avoid large mesh issues\n            if num_faces > 2**24:\n                chunk_size = 1000000  # Process 1 million faces at a time\n                additional_counts = torch.zeros(num_faces, device=device)\n                \n                for start in range(0, num_faces, chunk_size):\n                    end = min(start + chunk_size, num_faces)\n                    chunk_weights = sampling_weights[start:end]\n                    chunk_probs = chunk_weights / chunk_weights.sum()\n                    \n                    # Proportinally allocate remaining samples\n                    chunk_samples = int(remaining_samples * (end - start) / num_faces)\n                    samples = torch.multinomial(chunk_probs, chunk_samples, replacement=True)\n                    chunk_counts = torch.bincount(samples, minlength=chunk_size)\n                    additional_counts[start:end] += chunk_counts[:end-start]\n                \n                sample_counts = additional_counts + base_samples\n            else:\n                probs = sampling_weights / sampling_weights.sum()\n                additional_samples = torch.multinomial(probs, remaining_samples, replacement=True)\n                sample_counts = torch.bincount(additional_samples, minlength=num_faces) + base_samples\n        else:\n            sample_counts = base_samples\n    else:\n        if num_faces > 2**24:\n            # Chunk sampling strategy\n            sample_counts = torch.zeros(num_faces, device=device)\n            chunk_size = 1000000  # Process 1 million faces at a time\n            chunk_samples = num_samples // ((num_faces + chunk_size - 1) // chunk_size)\n            \n            for start in range(0, num_faces, chunk_size):\n                end = min(start + chunk_size, num_faces)\n                chunk_weights = sampling_weights[start:end]\n                chunk_probs = chunk_weights / chunk_weights.sum()\n                \n                samples = torch.multinomial(chunk_probs, chunk_samples, replacement=True)\n                chunk_counts = torch.bincount(samples, minlength=chunk_size)\n                sample_counts[start:end] += chunk_counts[:end-start]\n        else:\n            
probs = sampling_weights / sampling_weights.sum()\n            samples = torch.multinomial(probs, num_samples, replacement=True)\n            sample_counts = torch.bincount(samples, minlength=num_faces)\n    \n    # Generate barycentric coordinates for sampled points\n    total_samples = sample_counts.sum().item()\n    r1 = torch.sqrt(torch.rand(total_samples, device=device))\n    r2 = torch.rand(total_samples, device=device)\n    \n    barycentric_coords = torch.stack([\n        1 - r1,\n        r1 * (1 - r2),\n        r1 * r2\n    ], dim=1)\n    \n    # Generate face indices\n    face_indices = torch.repeat_interleave(\n        torch.arange(num_faces, device=device),\n        sample_counts\n    )\n    \n    # Get vertices of corresponding faces\n    face_vertices = vertices[faces[face_indices]]\n    \n    # Compute 3D coordinates of sampled points\n    points = (barycentric_coords.unsqueeze(1) @ face_vertices).squeeze(1)\n    \n    # Compute normal vectors of sampled points\n    v0, v1, v2 = face_vertices[:, 0], face_vertices[:, 1], face_vertices[:, 2]\n    face_normals = torch.cross(v1 - v0, v2 - v0)\n    normals = face_normals / (torch.norm(face_normals, dim=1, keepdim=True) + 1e-12)\n    \n    return points, face_indices, normals\n\ndef normalize_points_and_mesh(vertices: torch.Tensor, points: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"Normalize mesh and point cloud to unit cube\"\"\"\n    device = vertices.device\n    vmin = vertices.min(dim=0)[0]\n    vmax = vertices.max(dim=0)[0]\n    center = (vmax + vmin) / 2\n    scale = (vmax - vmin).max()\n    margin = 0.01\n    scale = scale * (1 + 2 * margin)\n    \n    vertices_normalized = (vertices - center) / scale + 0.5\n    points_normalized = (points - center) / scale + 0.5\n    \n    return vertices_normalized, points_normalized, center, scale\n\ndef add_gaussian_noise(uniform_surface_points: torch.Tensor, curvature_surface_points: torch.Tensor, sigma: float = 0.01) -> torch.Tensor:\n    \"\"\"Add Gaussian noise to point cloud\"\"\"\n    # noise = torch.randn_like(points) * sigma\n    # print(\"u_num:\",uniform_surface_points.shape)\n    # print(\"c_num:\",curvature_surface_points.shape)\n\n    idx1 = torch.randperm(uniform_surface_points.shape[0])\n    idx2 = torch.randperm(curvature_surface_points.shape[0])\n    uniform_surface_points = uniform_surface_points[idx1]\n    curvature_surface_points = curvature_surface_points[idx2]\n\n    a, b = -0.25, 0.25\n    mu = 0\n\n    # get near points (add offset on surface points)\n    offset1 = torch.tensor(truncnorm.rvs((a - mu) / 0.005, (b - mu) / 0.005, loc=mu, scale=0.005, size=(len(uniform_surface_points), 3)), \n                         dtype=uniform_surface_points.dtype, device=uniform_surface_points.device)\n    offset2 = torch.tensor(truncnorm.rvs((a - mu) / 0.05, (b - mu) / 0.05, loc=mu, scale=0.05,  size=(len(uniform_surface_points), 3)), \n                         dtype=uniform_surface_points.dtype, device=uniform_surface_points.device)\n    uniform_near_points = torch.cat([\n        uniform_surface_points + offset1,\n        uniform_surface_points + offset2\n    ], dim=0)\n\n    # Generate multi-scale noise for curvature sample points\n    unit_num = curvature_surface_points.shape[0] // 6\n    scales = [0.001, 0.003, 0.006, 0.01, 0.02, 0.04]\n    \n    curvature_near_points = []\n    for i in range(6):\n        start = i * unit_num\n        end = (i + 1) * unit_num if i < 5 else curvature_surface_points.shape[0]\n        noise = 
torch.randn((end - start, 3), dtype=curvature_surface_points.dtype, \n                          device=curvature_surface_points.device) * scales[i]\n        curvature_near_points.append(curvature_surface_points[start:end] + noise)\n    \n    curvature_near_points = torch.cat(curvature_near_points, dim=0)\n\n    return uniform_near_points, curvature_near_points\n\ndef compute_points_value_bvh(\n    vertices: torch.Tensor,\n    faces: torch.Tensor,\n    points: torch.Tensor,\n    use_sdf: bool = True,\n    batch_size: int = 100_00000\n) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"Compute SDF or occupancy values for sampled points\"\"\"\n    device = vertices.device\n    \n    # Normalize mesh and point cloud\n    vertices_norm, points_norm, center, scale = normalize_points_and_mesh(vertices, points)\n    \n    BVH = cubvh.cuBVH(vertices_norm, faces)\n    distances, face_id, uvw = BVH.signed_distance(points, return_uvw=True, mode='watertight')\n    values = distances\n    \n    return values, points_norm, center, scale\n\ndef save_point_cloud(\n    points: torch.Tensor,\n    output_path: str,\n    normals: Optional[torch.Tensor] = None,\n    colors: Optional[torch.Tensor] = None\n) -> None:\n    \"\"\"Save point cloud to file\"\"\"\n    points_np = points.cpu().numpy()\n    normals_np = normals.cpu().numpy() if normals is not None else None\n    colors_np = None\n    \n    if colors is not None:\n        colors_np = colors.cpu().numpy()\n        if colors_np.max() <= 1.0:\n            colors_np = (colors_np * 255).astype(np.uint8)\n    \n    ext = os.path.splitext(output_path)[1].lower()\n    \n    if ext == '.txt':\n        data_list = [points_np]\n        if normals_np is not None:\n            data_list.append(normals_np)\n        if colors_np is not None:\n            data_list.append(colors_np)\n            \n        combined_data = np.hstack(data_list)\n        np.savetxt(output_path, combined_data, fmt='%.6f')\n        \n    elif ext == '.ply':\n        cloud = trimesh.PointCloud(points_np, colors=colors_np)\n        if normals_np is not None:\n            cloud.metadata['normals'] = normals_np\n        cloud.export(output_path)\n        \n    else:\n        raise ValueError(f\"Unsupported file format: {ext}. 
Please use .txt or .ply\")\n\ndef sample_points_in_bbox(\n    bbox_min: torch.Tensor,\n    bbox_max: torch.Tensor,\n    num_samples: int,\n    device: str = \"cuda\"\n) -> torch.Tensor:\n    \"\"\"Uniformly sample points within bounding box\"\"\"\n    points = torch.rand(num_samples, 3, device=device)\n    points = points * (bbox_max - bbox_min) + bbox_min\n    return points\n\ndef process_single_mesh(\n    mesh_name:str,\n    mesh_path: str,\n    output_dir: str,\n    data_type:str = 'mesh',\n    surface_uniform_samples: int = 100000,      # Number of uniform sample points on the surface\n    surface_curvature_samples: int = 200000,    # Number of curvature-based sample points on the surface\n    space_samples: int = 300000,               # Number of sample points in space\n    noise_sigma: float = 0.01,\n    device: str = \"cuda\"\n) -> None:\n    \"\"\"Process a single mesh file\n    Args:\n        mesh_path: Input mesh path\n        output_dir: Output directory\n        surface_uniform_samples: Number of uniform sample points on surface\n        surface_curvature_samples: Number of curvature-based sample points on surface\n        space_samples: Number of sample points in space\n        noise_sigma: Gaussian noise standard deviation\n        device: Computation device\n    \"\"\"\n    os.makedirs(output_dir, exist_ok=True)\n    \n    if data_type == \"mesh\":\n        vertices, faces = load_mesh(mesh_path, device)\n    elif data_type == \"sparse_voxel\":\n        pass\n    vertices_normalized, _, center, scale = normalize_points_and_mesh(vertices, vertices)\n    \n    space_points = torch.rand(space_samples, 3, device=device)\n    \n    uniform_surface_points, uniform_surface_normals = sample_uniform_points(\n        vertices=vertices_normalized,\n        faces=faces,\n        num_samples=surface_uniform_samples\n    )\n    \n    curvature_surface_points, _, curvature_surface_normals = sample_surface_points(\n        vertices=vertices_normalized,\n        faces=faces,\n        num_samples=surface_curvature_samples,\n        use_curvature=True\n    )\n    \n    clean_surface_points = torch.cat([uniform_surface_points, curvature_surface_points], dim=0)\n    clean_surface_normals = torch.cat([uniform_surface_normals, curvature_surface_normals], dim=0)\n\n    surface_uni_save_path = os.path.join(output_dir, f\"{mesh_name}_uni_surface\")\n    save_point_cloud(\n        points=uniform_surface_points,\n        output_path=f\"{surface_uni_save_path}.ply\",\n        normals=uniform_surface_normals\n    )   \n\n    surface_cur_save_path = os.path.join(output_dir, f\"{mesh_name}_cur_surface\")\n    save_point_cloud(\n        points=curvature_surface_points,\n        output_path=f\"{surface_cur_save_path}.ply\",\n        normals=curvature_surface_normals\n    )\n    \n    uniform_near_points, curvature_near_points = add_gaussian_noise(uniform_surface_points = uniform_surface_points.clone(),\n                            curvature_surface_points = curvature_surface_points.clone(), sigma=noise_sigma)\n\n    space_sdf, _, _, _ = compute_points_value_bvh(\n        vertices=vertices_normalized,\n        faces=faces,\n        points=space_points,\n        use_sdf=True,\n        batch_size=1000_00000\n    )\n    \n    # clean_surface_sdf = torch.zeros(len(clean_surface_points), device=device)\n    uniform_near_sdf, _, _, _ = compute_points_value_bvh(\n        vertices=vertices_normalized,\n        faces=faces,\n        points=uniform_near_points,\n        use_sdf=True,\n        batch_size=1000_00000\n    )\n    \n    curvature_near_sdf, _, _, _ = compute_points_value_bvh(\n        
vertices=vertices_normalized,\n        faces=faces,\n        points=curvature_near_points,\n        use_sdf=True,\n        batch_size=1000_00000\n    )\n    \n    print(\"sdf:\",uniform_near_sdf.shape, curvature_near_sdf.shape)\n    \n    base_save_path = os.path.join(output_dir, mesh_name)\n    \n    np.savez(f\"{base_save_path}.npz\",\n             space_points=space_points.cpu().numpy(),\n             space_sdf=space_sdf.cpu().numpy(),\n             clean_surface_points=clean_surface_points.cpu().numpy(),\n             clean_surface_normals=clean_surface_normals.cpu().numpy(),\n             uniform_near_points=uniform_near_points.cpu().numpy(),\n             curvature_near_points=curvature_near_points.cpu().numpy(),\n             uniform_near_sdf=uniform_near_sdf.cpu().numpy(),\n             curvature_near_sdf=curvature_near_sdf.cpu().numpy(),\n             center=center.cpu().numpy(),\n             scale=scale.cpu().numpy())\n\nclass MeshDataset(Dataset):\n    def __init__(self, mesh_json: str):\n        with open(mesh_json, \"r\") as f:\n            self.mesh_paths = json.load(f)\n        # print(len(self.mesh_paths))\n            \n    def __len__(self) -> int:\n        return len(self.mesh_paths)\n    def __getitem__(self, idx: int) -> dict:\n        mesh_path = self.mesh_paths[idx]\n        mesh_name = os.path.basename(mesh_path)[:-4]\n        mesh =  {\n            \"mesh_path\": mesh_path,\n            \"mesh_name\": mesh_name,\n        }\n        return mesh\n\nclass MeshProcessor(pl.LightningModule):\n    def __init__(\n        self,\n        mesh_json: str,\n        output_dir: str,\n        data_type:str,\n        surface_uniform_samples: int = 20000,\n        surface_curvature_samples: int = 40000,\n        space_samples: int = 300000,\n        noise_sigma: float = 0.01,\n        batch_size: int = 1,\n        num_workers: int = 4\n    ):\n        super().__init__()\n        self.save_hyperparameters()\n        os.makedirs(output_dir, exist_ok=True)\n    \n    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> STEP_OUTPUT:\n        mesh_path = batch[\"mesh_path\"][0]\n        mesh_name = batch[\"mesh_name\"][0]\n        \n        # sys_logger.info(f\"Processing {batch_idx}/{len(self.trainer.predict_dataloaders)}: {mesh_name} from {mesh_path}\")\n        \n        output_subdir = self.hparams.output_dir\n        \n        try:\n            filename = os.path.splitext(os.path.basename(mesh_path))[0]\n            if os.path.exists(os.path.join(output_subdir, f\"{filename}.npz\")):\n                # sys_logger.info(f\"Skipping {mesh_name} as it already exists.\")\n                return {\n                    \"status\": \"success\",\n                    \"mesh_name\": mesh_name\n                }\n            process_single_mesh(\n                mesh_name=mesh_name,\n                mesh_path=mesh_path,\n                output_dir=output_subdir,\n                data_type = self.hparams.data_type,\n                surface_uniform_samples=self.hparams.surface_uniform_samples,\n                surface_curvature_samples=self.hparams.surface_curvature_samples,\n                space_samples=self.hparams.space_samples,\n                noise_sigma=self.hparams.noise_sigma,\n                device=self.device\n            )\n            \n            return {\n                \"status\": \"success\",\n                \"mesh_name\": mesh_name\n            }\n        \n        except Exception as e:\n                print(f\"Error processing {mesh_name}: 
{str(e)}\")\n                return {\n                    \"status\": \"error\",\n                    \"mesh_name\": mesh_name,\n                    \"error\": str(e)\n                }\n\n    def predict_dataloader(self) -> DataLoader:\n        dataset = MeshDataset(\n            self.hparams.mesh_json)\n        return DataLoader(\n            dataset,\n            batch_size=self.hparams.batch_size,\n            num_workers=self.hparams.num_workers,\n            persistent_workers=True,\n            shuffle=False\n        )\n\ndef process_mesh_directory(\n    mesh_json: str,\n    output_dir: str,\n    data_type: str,\n    surface_uniform_samples: int = 100000,\n    surface_curvature_samples: int = 200000,\n    space_samples: int = 300000,\n    noise_sigma: float = 0.01,\n    num_gpus: int = -1,\n    batch_size: int = 1,\n    num_workers: int = 4\n) -> None:\n    model = MeshProcessor(\n        mesh_json=mesh_json,\n        output_dir=output_dir,\n        data_type=data_type,\n        surface_uniform_samples=surface_uniform_samples,\n        surface_curvature_samples=surface_curvature_samples,\n        space_samples=space_samples,\n        noise_sigma=noise_sigma,\n        batch_size=batch_size,\n        num_workers=num_workers\n    )\n\n    trainer = pl.Trainer(\n        accelerator=\"gpu\",\n        devices=num_gpus,\n        strategy=\"ddp\",\n        precision=32,\n        logger=False,\n        enable_progress_bar=True\n    )\n    \n    predictions = trainer.predict(model)\n    \n    success_count = sum(1 for p in predictions if p[\"status\"] == \"success\")\n    error_count = sum(1 for p in predictions if p[\"status\"] == \"error\")\n    \n    print(f\"\\nProcessing completed:\")\n    print(f\"Successfully processed: {success_count} files\")\n    print(f\"Failed to process: {error_count} files\")\n    \n    if error_count > 0:\n        print(\"\\nFailed files:\")\n        for p in predictions:\n            if p[\"status\"] == \"error\":\n                print(f\"- {p['mesh_name']}: {p['error']}\")\n                \nif __name__ == \"__main__\":\n\n    parser = argparse.ArgumentParser(description=\"Process Mesh Directory for Sampling\")\n\n    parser.add_argument(\"--mesh_json\", type=str, default=\"test_mesh.json\", help=\"Path to the mesh json file\")\n    parser.add_argument(\"--output_dir\", type=str, default=\"ultrashape_test1\", help=\"Directory to save outputs\")\n\n    parser.add_argument(\"--surface_uniform_samples\", type=int, default=300000, help=\"Number of uniform samples on surface\")\n    parser.add_argument(\"--surface_curvature_samples\", type=int, default=300000, help=\"Number of curvature-based samples on surface\")\n    parser.add_argument(\"--space_samples\", type=int, default=400000, help=\"Number of samples in space\")\n\n    parser.add_argument(\"--noise_sigma\", type=float, default=0.01, help=\"Sigma for Gaussian noise\")\n    parser.add_argument(\"--num_gpus\", type=int, default=1, help=\"Number of GPUs to use\")\n    parser.add_argument(\"--num_workers\", type=int, default=16, help=\"Number of data loading workers\")\n    parser.add_argument(\"--batch_size\", type=int, default=1, help=\"Batch size per GPU\")\n\n    args = parser.parse_args()\n    # print(f\"Arguments: {args}\")\n\n    process_mesh_directory(\n        mesh_json=args.mesh_json,\n        output_dir=args.output_dir,\n        data_type='mesh',\n        surface_uniform_samples=args.surface_uniform_samples,\n        surface_curvature_samples=args.surface_curvature_samples,\n        
space_samples=args.space_samples,\n        noise_sigma=args.noise_sigma,\n        num_gpus=args.num_gpus,\n        num_workers=args.num_workers,\n        batch_size=args.batch_size\n    )\n"
  },
  {
    "path": "scripts/train_deepspeed.sh",
    "content": "\nexport NCCL_IB_TIMEOUT=24\nexport NCCL_NVLS_ENABLE=0\nNET_TYPE=\"high\"\nif [[ \"${NET_TYPE}\" = \"low\" ]]; then\n    export NCCL_SOCKET_IFNAME=eth1\n    export NCCL_IB_GID_INDEX=3\n    export NCCL_IB_HCA=mlx5_2:1,mlx5_2:1\n    export NCCL_IB_SL=3\n    export NCCL_CHECKS_DISABLE=1\n    export NCCL_P2P_DISABLE=0\n    export NCCL_LL_THRESHOLD=16384\n    export NCCL_IB_CUDA_SUPPORT=1\nelse\n    export NCCL_IB_GID_INDEX=3\n    export NCCL_IB_SL=3\n    export NCCL_CHECKS_DISABLE=1\n    export NCCL_P2P_DISABLE=0\n    export NCCL_IB_DISABLE=0\n    export NCCL_LL_THRESHOLD=16384\n    export NCCL_IB_CUDA_SUPPORT=1\n    export NCCL_SOCKET_IFNAME=bond1\n    export NCCL_COLLNET_ENABLE=0\n    export SHARP_COLL_ENABLE_SAT=0\n    export NCCL_NET_GDR_LEVEL=2\n    export NCCL_IB_QPS_PER_CONNECTION=4\n    export NCCL_IB_TC=160\n    export NCCL_PXN_DISABLE=1\nfi\n# export NCCL_DEBUG=INFO\n\nnode_num=$1\nnode_rank=$2\nnum_gpu_per_node=$3\nmaster_ip=$4\nconfig=$5\noutput_dir=$6\n\necho node_num $node_num\necho node_rank $node_rank\necho master_ip $master_ip\necho config $config\necho output_dir $output_dir\n\nif test -d \"$output_dir\"; then\n    cp $config $output_dir\nelse\n    mkdir -p \"$output_dir\"\n    cp $config $output_dir\nfi\n\nNODE_RANK=$node_rank \\\nHF_HUB_OFFLINE=0 \\\nMASTER_PORT=12348 \\\nMASTER_ADDR=$master_ip \\\nNCCL_SOCKET_IFNAME=bond1 \\\nNCCL_IB_GID_INDEX=3 \\\nNCCL_NVLS_ENABLE=0 \\\npython3 main.py \\\n    --num_nodes $node_num \\\n    --num_gpus $num_gpu_per_node \\\n    --config $config \\\n    --output_dir $output_dir \\\n    --deepspeed\n"
  },
  {
    "path": "train.sh",
    "content": "export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\nexport num_gpu_per_node=8\n\nexport node_num=1\nexport node_rank=$1\nexport master_ip= # [your master ip here]\n\n\n############## vae ##############\n# export config=configs/train_vae_refine.yaml\n# export output_dir=outputs/vae_ultrashape/exp1_token8192\n# bash scripts/train_deepspeed.sh $node_num $node_rank $num_gpu_per_node $master_ip $config $output_dir\n\n############## dit ##############\nexport config=configs/train_dit_refine.yaml\nexport output_dir=outputs/dit_ultrashape/exp1_token8192\nbash scripts/train_deepspeed.sh $node_num $node_rank $num_gpu_per_node $master_ip $config $output_dir\n\n"
  },
  {
    "path": "ultrashape/__init__.py",
    "content": "# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT\n# except for the third-party components listed below.\n# Hunyuan 3D does not impose any additional limitations beyond what is outlined\n# in the repsective licenses of these third-party components.\n# Users must comply with all terms and conditions of original licenses of these third-party\n# components and must ensure that the usage of the third party components adheres to\n# all relevant laws and regulations.\n\n# For avoidance of doubts, Hunyuan 3D means the large language models and\n# their software and algorithms, including trained model weights, parameters (including\n# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,\n# fine-tuning enabling code and other elements of the foregoing made publicly available\n# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.\n\nfrom .pipelines import UltraShapePipeline\nfrom .postprocessors import FaceReducer, FloaterRemover, DegenerateFaceRemover, MeshSimplifier\nfrom .preprocessors import ImageProcessorV2, IMAGE_PROCESSORS, DEFAULT_IMAGEPROCESSOR\n"
  },
  {
    "path": "ultrashape/data/objaverse_dit.py",
    "content": "# -*- coding: utf-8 -*-\n\n# ==============================================================================\n# Original work Copyright (c) 2025 Tencent.\n# Modified work Copyright (c) 2025 UltraShape Team.\n# \n# Modified by UltraShape on 2025.12.25\n# ==============================================================================\n\n# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT\n# except for the third-party components listed below.\n# Hunyuan 3D does not impose any additional limitations beyond what is outlined\n# in the repsective licenses of these third-party components.\n# Users must comply with all terms and conditions of original licenses of these third-party\n# components and must ensure that the usage of the third party components adheres to\n# all relevant laws and regulations.\n\n# For avoidance of doubts, Hunyuan 3D means the large language models and\n# their software and algorithms, including trained model weights, parameters (including\n# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,\n# fine-tuning enabling code and other elements of the foregoing made publicly available\n# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.\n\nimport math\nimport os\nimport json\nfrom dataclasses import dataclass, field\n\nimport random\nimport imageio\nimport numpy as np\nimport pytorch_lightning as pl\nimport torch\nimport torch.nn.functional as F\nfrom torch.utils.data import DataLoader, Dataset\nfrom PIL import Image\nimport pickle\nfrom ultrashape.utils.typing import *\nimport pandas as pd\nimport cv2\nimport torchvision.transforms as transforms\nfrom pytorch_lightning.utilities import rank_zero_info\n\ndef padding(image, mask, center=True, padding_ratio_range=[1.15, 1.15]):\n    \"\"\"\n    Pad the input image and mask to a square shape with padding ratio.\n\n    Args:\n        image (np.ndarray): Input image array of shape (H, W, C).\n        mask (np.ndarray): Corresponding mask array of shape (H, W).\n        center (bool): Whether to center the original image in the padded output.\n        padding_ratio_range (list): Range [min, max] to randomly select padding ratio.\n\n    Returns:\n        newimg (np.ndarray): Padded image of shape (resize_side, resize_side, 3).\n        newmask (np.ndarray): Padded mask of shape (resize_side, resize_side).\n    \"\"\"\n    h, w = image.shape[:2]\n    max_side = max(h, w)\n\n    # Select padding ratio either fixed or randomly within the given range\n    if padding_ratio_range[0] == padding_ratio_range[1]:\n        padding_ratio = padding_ratio_range[0]\n    else:\n        padding_ratio = random.uniform(padding_ratio_range[0], padding_ratio_range[1])\n    resize_side = int(max_side * padding_ratio)\n\n    pad_h = resize_side - h\n    pad_w = resize_side - w\n    if center:\n        start_h = pad_h // 2\n    else:\n        start_h = pad_h - resize_side // 20\n        \n    start_w = pad_w // 2\n\n    # Create new white image and black mask with padded size\n    newimg = np.ones((resize_side, resize_side, 3), dtype=np.uint8) * 255\n    newmask = np.zeros((resize_side, resize_side), dtype=np.uint8)\n    \n    # Place original image and mask into the padded canvas\n    newimg[start_h:start_h + h, start_w:start_w + w] = image\n    newmask[start_h:start_h + h, start_w:start_w + w] = mask\n    \n    return newimg, newmask\n\n\nclass ObjaverseDataset(Dataset):\n    def __init__(\n        self,\n        data_json,\n        
sample_root,\n        image_path,\n        image_transform = None,\n        pc_size: int = 2048,\n        pc_sharpedge_size: int = 2048,\n        sharpedge_label: bool = False,\n        return_normal: bool = False,\n        padding = True,\n        padding_ratio_range=[1.15, 1.15],\n    ):\n        super().__init__()\n\n        self.uids = json.load(open(data_json))\n        self.sample_root = sample_root\n        self.image_paths = json.load(open(image_path))\n        self.image_transform = image_transform\n        \n        self.pc_size = pc_size\n        self.pc_sharpedge_size = pc_sharpedge_size\n        self.sharpedge_label = sharpedge_label\n        self.return_normal = return_normal\n\n        self.padding = padding\n        self.padding_ratio_range = padding_ratio_range\n        \n        print(f\"Loaded {len(self.uids)} uids from {data_json}.\")\n\n        rank_zero_info(f'*' * 50)\n        rank_zero_info(f'Dataset Infos:')\n        rank_zero_info(f'# of 3D file: {len(self.uids)}')\n        rank_zero_info(f'# of Surface Points: {self.pc_size}')\n        rank_zero_info(f'# of Sharpedge Surface Points: {self.pc_sharpedge_size}')\n        rank_zero_info(f'Using sharp edge label: {self.sharpedge_label}')\n        rank_zero_info(f'*' * 50)\n\n    def __len__(self):\n        return len(self.uids)\n\n    def _load_shape(self, index: int) -> Dict[str, Any]:\n\n        data = np.load(f'{self.sample_root}/{self.uids[index]}.npz')\n\n        surface_og = (np.asarray(data['clean_surface_points'])-0.5) * 2 \n        normal = np.asarray(data['clean_surface_normals']) \n        surface_og_n = np.concatenate([surface_og, normal], axis=1) \n        rng = np.random.default_rng()\n\n        # hard code: first 300k are uniform, last 300k are sharp\n        assert surface_og_n.shape[0] == 600000, f\"assume that suface points = 30w uniform + 30w curvature, but {len(surface_og_n)=}\"\n        coarse_surface = surface_og_n[:300000]\n        sharp_surface = surface_og_n[300000:]\n\n        surface_normal = []\n        rng = np.random.default_rng()\n        if self.pc_size > 0:\n            ind = rng.choice(coarse_surface.shape[0], self.pc_size // 2, replace=False)\n            coarse_surface = coarse_surface[ind]\n            if self.sharpedge_label:\n                sharpedge_label = np.zeros((self.pc_size // 2, 1))\n                coarse_surface = np.concatenate((coarse_surface, sharpedge_label), axis=1)\n            surface_normal.append(coarse_surface)\n\n            ind_sharpedge = rng.choice(sharp_surface.shape[0], self.pc_size // 2, replace=False)\n            sharp_surface = sharp_surface[ind_sharpedge]\n            if self.sharpedge_label:\n                sharpedge_label = np.ones((self.pc_size // 2, 1))\n                sharp_surface = np.concatenate((sharp_surface, sharpedge_label), axis=1)\n            surface_normal.append(sharp_surface)\n        \n        surface_normal = np.concatenate(surface_normal, axis=0)\n        surface_normal = torch.FloatTensor(surface_normal)\n        surface = surface_normal[:, 0:3]\n        normal = surface_normal[:, 3:6]\n        assert surface.shape[0] == self.pc_size + self.pc_sharpedge_size\n\n        geo_points = 0.0\n        normal = torch.nn.functional.normalize(normal, p=2, dim=1)\n        if self.return_normal:\n            surface = torch.cat([surface, normal], dim=-1)\n        if self.sharpedge_label:\n            surface = torch.cat([surface, surface_normal[:, -1:]], dim=-1)\n\n        ret = {\n                \"uid\": self.uids[index],\n            
    \"surface\": surface,\n                \"geo_points\": geo_points\n            }\n        return ret\n\n        \n    def _load_image(self, index: int) -> Dict[str, Any]:\n        ret = {}\n        sel_idx = random.randint(0, 15)\n        ret[\"sel_image_idx\"] = sel_idx\n        obj_name = self.uids[index]\n        img_path = f'{self.image_paths[obj_name]}/{os.path.basename(self.image_paths[obj_name])}/rgba/' + f\"{sel_idx:03d}.png\"\n \n        images, masks = [], []\n        image = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)\n        assert image.shape[2] == 4\n        alpha = image[:, :, 3:4].astype(np.float32) / 255\n        forground = image[:, :, :3]\n        background = np.ones_like(forground) * 255\n        img_new = forground * alpha + background * (1 - alpha)\n        image = img_new.astype(np.uint8)\n        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n        mask = (alpha[:, :, 0] * 255).astype(np.uint8)\n\n        if self.padding:\n            h, w = image.shape[:2]\n            binary = mask > 0.3\n            non_zero_coords = np.argwhere(binary)\n            x_min, y_min = non_zero_coords.min(axis=0)\n            x_max, y_max = non_zero_coords.max(axis=0)\n            image, mask = padding(\n                image[max(x_min - 5, 0):min(x_max + 5, h), max(y_min - 5, 0):min(y_max + 5, w)],\n                mask[max(x_min - 5, 0):min(x_max + 5, h), max(y_min - 5, 0):min(y_max + 5, w)],\n                center=True, padding_ratio_range=self.padding_ratio_range)\n        \n        if self.image_transform:\n            image = self.image_transform(image)\n            mask = np.stack((mask, mask, mask), axis=-1)\n            mask = self.image_transform(mask)\n        \n        images.append(image)\n        masks.append(mask)\n        ret[\"image\"] = torch.cat(images, dim=0)\n        ret[\"mask\"] = torch.cat(masks, dim=0)[:1, ...]\n        \n        return ret\n\n    def get_data(self, index):\n        ret = self._load_shape(index)\n        ret.update(self._load_image(index))\n        return ret\n        \n    def __getitem__(self, index):\n        try:\n            return self.get_data(index)\n        except Exception as e:\n            print(f\"Error in {self.uids[index]}: {e}\")\n            return self.__getitem__(np.random.randint(len(self)))\n\n    def collate(self, batch):\n        batch = torch.utils.data.default_collate(batch)\n        return batch\n\n\nclass ObjaverseDataModule(pl.LightningDataModule):\n    def __init__(\n        self,\n        batch_size: int = 1,\n        num_workers: int = 4,\n        val_num_workers: int = 2,\n        training_data_list: str = None,\n        sample_pcd_dir: str = None,\n        image_data_json: str = None,\n        image_size: int = 224,\n        mean: Union[List[float], Tuple[float]] = (0.485, 0.456, 0.406),\n        std: Union[List[float], Tuple[float]] = (0.229, 0.224, 0.225),\n        pc_size: int = 2048,\n        pc_sharpedge_size: int = 2048,\n        sharpedge_label: bool = False,\n        return_normal: bool = False, \n        padding = True,\n        padding_ratio_range=[1.15, 1.15]\n    ):\n\n        super().__init__()\n        self.batch_size = batch_size\n        self.num_workers = num_workers\n        self.val_num_workers = val_num_workers\n\n        self.training_data_list = training_data_list\n        self.sample_pcd_dir = sample_pcd_dir\n        self.image_data_json = image_data_json\n        \n        self.image_size = image_size\n        self.mean = mean\n        self.std = std\n        
self.train_image_transform = transforms.Compose([\n            transforms.ToTensor(),\n            transforms.Resize(self.image_size),\n            transforms.Normalize(mean=self.mean, std=self.std)])\n        self.val_image_transform = transforms.Compose([\n            transforms.ToTensor(),\n            transforms.Resize(self.image_size),\n            transforms.Normalize(mean=self.mean, std=self.std)])\n\n        self.pc_size = pc_size\n        self.pc_sharpedge_size = pc_sharpedge_size\n        self.sharpedge_label = sharpedge_label\n        self.return_normal = return_normal\n\n        self.padding = padding\n        self.padding_ratio_range = padding_ratio_range\n\n    def train_dataloader(self):\n        asl_params = {\n            \"data_json\": f'{self.training_data_list}/train.json',\n            \"sample_root\": self.sample_pcd_dir,\n            \"image_path\": self.image_data_json,\n            \"image_transform\": self.train_image_transform,\n            \"pc_size\": self.pc_size,\n            \"pc_sharpedge_size\": self.pc_sharpedge_size,\n            \"sharpedge_label\": self.sharpedge_label,\n            \"return_normal\": self.return_normal,\n            \"padding\": self.padding,\n            \"padding_ratio_range\": self.padding_ratio_range,\n        }\n        dataset = ObjaverseDataset(**asl_params)\n        return torch.utils.data.DataLoader(\n            dataset,\n            batch_size=self.batch_size,\n            num_workers=self.num_workers,\n            pin_memory=True,\n            drop_last=True,\n        )\n\n    def val_dataloader(self):\n        asl_params = {\n            \"data_json\": f'{self.training_data_list}/val.json',\n            \"sample_root\": self.sample_pcd_dir,\n            \"image_path\": self.image_data_json,\n            \"image_transform\": self.val_image_transform,\n            \"pc_size\": self.pc_size,\n            \"pc_sharpedge_size\": self.pc_sharpedge_size,\n            \"sharpedge_label\": self.sharpedge_label,\n            \"return_normal\": self.return_normal, \n            \"padding\": self.padding,\n            \"padding_ratio_range\": self.padding_ratio_range,\n        }\n        dataset = ObjaverseDataset(**asl_params)\n        return torch.utils.data.DataLoader(\n            dataset,\n            batch_size=self.batch_size,\n            num_workers=self.val_num_workers,\n            pin_memory=True,\n            drop_last=True,\n        )\n"
  },
  {
    "path": "ultrashape/data/objaverse_vae.py",
    "content": "# -*- coding: utf-8 -*-\n\n# ==============================================================================\n# Original work Copyright (c) 2025 Tencent.\n# Modified work Copyright (c) 2025 UltraShape Team.\n# \n# Modified by UltraShape on 2025.12.25\n# ==============================================================================\n\n# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT\n# except for the third-party components listed below.\n# Hunyuan 3D does not impose any additional limitations beyond what is outlined\n# in the repsective licenses of these third-party components.\n# Users must comply with all terms and conditions of original licenses of these third-party\n# components and must ensure that the usage of the third party components adheres to\n# all relevant laws and regulations.\n\n# For avoidance of doubts, Hunyuan 3D means the large language models and\n# their software and algorithms, including trained model weights, parameters (including\n# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,\n# fine-tuning enabling code and other elements of the foregoing made publicly available\n# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.\n\n\nimport os\nimport cv2\nimport json\nimport math\nimport random\nimport imageio\nimport pickle\nimport numpy as np\nfrom PIL import Image\nimport pandas as pd\nfrom dataclasses import dataclass, field\n\nimport torch\nimport torch.nn.functional as F\nimport pytorch_lightning as pl\nfrom torch.utils.data import DataLoader, Dataset\nimport torchvision.transforms as transforms\nfrom pytorch_lightning.utilities import rank_zero_info\nfrom ultrashape.utils.typing import *\n\nclass ObjaverseDataset(Dataset):\n    def __init__(\n        self,\n        data_json,\n        sample_root,\n        pc_size: int = 2048,\n        pc_sharpedge_size: int = 2048,\n        sup_near_uni_size: int = 4096,\n        sup_near_sharp_size: int = 4096,\n        sup_space_size: int = 4096,\n        tsdf_threshold: float = 0.05,\n        sharpedge_label: bool = False,\n        return_normal: bool = False,\n    ):\n        super().__init__()\n\n        self.uids = json.load(open(data_json))\n        self.sample_root = sample_root\n        \n        self.pc_size = pc_size\n        self.pc_sharpedge_size = pc_sharpedge_size\n        self.sharpedge_label = sharpedge_label\n        self.return_normal = return_normal\n\n        self.sup_near_uni_size = sup_near_uni_size\n        self.sup_near_sharp_size = sup_near_sharp_size\n        self.sup_space_size = sup_space_size\n        self.tsdf_threshold = tsdf_threshold\n        \n        print(f\"Loaded {len(self.uids)} uids from {data_json}.\")\n\n        rank_zero_info(f'*' * 50)\n        rank_zero_info(f'Dataset Infos:')\n        rank_zero_info(f'# of 3D file: {len(self.uids)}')\n        rank_zero_info(f'# of Surface Points: {self.pc_size}')\n        rank_zero_info(f'# of Sharpedge Surface Points: {self.pc_sharpedge_size}')\n        rank_zero_info(f'# of Uniform Near-Surface Sup-Points: {self.sup_near_uni_size}')\n        rank_zero_info(f'# of Sharpedge Near-Surface Sup-Points: {self.sup_near_sharp_size}')\n        rank_zero_info(f'# of Random Space Sup-Points: {self.sup_space_size}')\n        rank_zero_info(f'Using sharp edge label: {self.sharpedge_label}')\n        rank_zero_info(f'*' * 50)\n\n    def __len__(self):\n        return len(self.uids)\n\n    def _load_shape(self, index: int) -> Dict[str, Any]:\n    
    rng = np.random.default_rng()\n\n        data = np.load(f'{self.sample_root}/{self.uids[index]}.npz')\n        \n        ##################### sup pcd&sdf ######################\n        uniform_near_points =  (np.asarray(data['uniform_near_points'])-0.5) * 2\n        curvature_near_points = (np.asarray(data['curvature_near_points'])-0.5) * 2\n        space_points = (np.asarray(data['space_points'])-0.5) * 2 \n        uniform_near_sdf = np.asarray(data['uniform_near_sdf']) * 2 \n        curvature_near_sdf = np.asarray(data['curvature_near_sdf']) * 2\n        space_sdf = np.asarray(data['space_sdf']) * 2 \n\n        uni_noisy_idx = rng.choice(uniform_near_points.shape[0], self.sup_near_uni_size, replace=False)\n        cur_noisy_idx = rng.choice(curvature_near_points.shape[0], self.sup_near_sharp_size, replace=False)\n        space_idx = rng.choice(space_points.shape[0], self.sup_space_size, replace=False)\n\n        uniform_near_points = uniform_near_points[uni_noisy_idx]\n        curvature_near_points = curvature_near_points[cur_noisy_idx]\n        space_points = space_points[space_idx]\n        uniform_near_sdf = uniform_near_sdf[uni_noisy_idx]\n        curvature_near_sdf = curvature_near_sdf[cur_noisy_idx]\n        space_sdf = space_sdf[space_idx]\n\n        uniform_near_sdf, curvature_near_sdf, space_sdf = map(self._clip_to_tsdf, (uniform_near_sdf, curvature_near_sdf, space_sdf))\n\n        surface_og = (np.asarray(data['clean_surface_points'])-0.5) * 2 \n        normal = np.asarray(data['clean_surface_normals'])\n        surface_og_n = np.concatenate([surface_og, normal], axis=1) \n        rng = np.random.default_rng()\n\n        # hard code: first 300k are uniform, last 300k are sharp\n        assert surface_og_n.shape[0] == 600000, f\"assume that suface points = 30w uniform + 30w curvature, but {len(surface_og_n)=}\"\n        coarse_surface = surface_og_n[:300000]\n        sharp_surface = surface_og_n[300000:]\n\n        surface_normal = []\n        \n        if self.pc_size > 0:\n            ind = rng.choice(coarse_surface.shape[0], self.pc_size // 2, replace=False)\n            coarse_surface = coarse_surface[ind]\n            if self.sharpedge_label:\n                sharpedge_label = np.zeros((self.pc_size // 2, 1))\n                coarse_surface = np.concatenate((coarse_surface, sharpedge_label), axis=1)\n            surface_normal.append(coarse_surface)\n\n            ind_sharpedge = rng.choice(sharp_surface.shape[0], self.pc_size // 2, replace=False)\n            sharp_surface = sharp_surface[ind_sharpedge]\n            if self.sharpedge_label:\n                sharpedge_label = np.ones((self.pc_size // 2, 1))\n                sharp_surface = np.concatenate((sharp_surface, sharpedge_label), axis=1)\n            surface_normal.append(sharp_surface)\n        \n        surface_normal = np.concatenate(surface_normal, axis=0)\n        surface_normal = torch.FloatTensor(surface_normal)\n        surface = surface_normal[:, 0:3]\n        normal = surface_normal[:, 3:6]\n        assert surface.shape[0] == self.pc_size + self.pc_sharpedge_size\n\n        geo_points = 0.0\n        normal = torch.nn.functional.normalize(normal, p=2, dim=1)\n        if self.return_normal:\n            surface = torch.cat([surface, normal], dim=-1)\n        if self.sharpedge_label:\n            surface = torch.cat([surface, surface_normal[:, -1:]], dim=-1)\n\n        ret = {\n                \"uid\": self.uids[index],\n                \"surface\": surface,\n                \"sup_near_uniform\": 
np.concatenate([uniform_near_points, uniform_near_sdf[...,None]], axis=1), \n                \"sup_near_sharp\": np.concatenate([curvature_near_points, curvature_near_sdf[...,None]], axis=1), \n                \"sup_space\": np.concatenate([space_points, space_sdf[...,None]], axis=1),\n                \"geo_points\": geo_points\n            }\n        return ret\n    \n    def _clip_to_tsdf(self, sdf: np.array):\n        nan_mask = np.isnan(sdf)\n        if np.any(nan_mask):\n            sdf=np.nan_to_num(sdf, nan=1.0, posinf=1.0, neginf=-1.0)\n        return sdf.flatten().astype(np.float32).clip(-self.tsdf_threshold, self.tsdf_threshold) / self.tsdf_threshold\n\n    def get_data(self, index):\n        ret = self._load_shape(index)\n        return ret\n        \n    def __getitem__(self, index):\n        return self.get_data(index)\n\n    def collate(self, batch):\n        batch = torch.utils.data.default_collate(batch)\n        return batch\n\n\nclass ObjaverseDataModule(pl.LightningDataModule):\n    def __init__(\n        self,\n        batch_size: int = 1,\n        num_workers: int = 4,\n        val_num_workers: int = 2,\n        training_data_list: str = None,\n        sample_pcd_dir: str = None,\n        pc_size: int = 2048,\n        pc_sharpedge_size: int = 2048,\n        sup_near_uni_size: int = 4096,\n        sup_near_sharp_size: int = 4096,\n        sup_space_size: int = 4096,\n        tsdf_threshold: float = 0.05,\n        sharpedge_label: bool = False,\n        return_normal: bool = False, \n    ):\n\n        super().__init__()\n        self.batch_size = batch_size\n        self.num_workers = num_workers\n        self.val_num_workers = val_num_workers\n\n        self.training_data_list = training_data_list\n        self.sample_pcd_dir = sample_pcd_dir\n\n        self.pc_size = pc_size\n        self.pc_sharpedge_size = pc_sharpedge_size\n        self.sharpedge_label = sharpedge_label\n        self.return_normal = return_normal\n\n        self.sup_near_uni_size = sup_near_uni_size\n        self.sup_near_sharp_size = sup_near_sharp_size\n        self.sup_space_size = sup_space_size\n        self.tsdf_threshold = tsdf_threshold\n\n    def train_dataloader(self):\n        asl_params = {\n            \"data_json\": f'{self.training_data_list}/train.json',\n            \"sample_root\": self.sample_pcd_dir,\n            \"pc_size\": self.pc_size,\n            \"pc_sharpedge_size\": self.pc_sharpedge_size,\n            \"sup_near_uni_size\": self.sup_near_uni_size,\n            \"sup_near_sharp_size\": self.sup_near_sharp_size,\n            \"sup_space_size\": self.sup_space_size,\n            \"tsdf_threshold\": self.tsdf_threshold,\n            \"sharpedge_label\": self.sharpedge_label,\n            \"return_normal\": self.return_normal,\n        }\n        dataset = ObjaverseDataset(**asl_params)\n        return torch.utils.data.DataLoader(\n            dataset,\n            batch_size=self.batch_size,\n            num_workers=self.num_workers,\n            pin_memory=True,\n            drop_last=True,\n        )\n\n    def val_dataloader(self):\n        asl_params = {\n            \"data_json\": f'{self.training_data_list}/val.json',\n            \"sample_root\": self.sample_pcd_dir,\n            \"pc_size\": self.pc_size,\n            \"pc_sharpedge_size\": self.pc_sharpedge_size,\n            \"sup_near_uni_size\": self.sup_near_uni_size,\n            \"sup_near_sharp_size\": self.sup_near_sharp_size,\n            \"sup_space_size\": self.sup_space_size,\n            
\"tsdf_threshold\": self.tsdf_threshold,\n            \"sharpedge_label\": self.sharpedge_label,\n            \"return_normal\": self.return_normal, \n        }\n        dataset = ObjaverseDataset(**asl_params)\n        return torch.utils.data.DataLoader(\n            dataset,\n            batch_size=self.batch_size,\n            num_workers=self.val_num_workers,\n            pin_memory=True,\n            drop_last=True,\n        )\n"
  },
  {
    "path": "ultrashape/data/utils.py",
    "content": "# -*- coding: utf-8 -*-\n\n# ==============================================================================\n# Original work Copyright (c) 2025 Tencent.\n# Modified work Copyright (c) 2025 UltraShape Team.\n# \n# Modified by UltraShape on 2025.12.25\n# ==============================================================================\n\n# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.\n# This file is part of the WebDataset library.\n# See the LICENSE file for licensing terms (BSD-style).\n\n\n\"\"\"Miscellaneous utility functions.\"\"\"\n\nimport importlib\nimport itertools as itt\nimport os\nimport re\nimport sys\nfrom typing import Any, Callable, Iterator, Union\nimport torch\nimport numpy as np\n\n\ndef make_seed(*args):\n    seed = 0\n    for arg in args:\n        seed = (seed * 31 + hash(arg)) & 0x7FFFFFFF\n    return seed\n\n\nclass PipelineStage:\n    def invoke(self, *args, **kw):\n        raise NotImplementedError\n\n\ndef identity(x: Any) -> Any:\n    \"\"\"Return the argument as is.\"\"\"\n    return x\n\n\ndef safe_eval(s: str, expr: str = \"{}\"):\n    \"\"\"Evaluate the given expression more safely.\"\"\"\n    if re.sub(\"[^A-Za-z0-9_]\", \"\", s) != s:\n        raise ValueError(f\"safe_eval: illegal characters in: '{s}'\")\n    return eval(expr.format(s))\n\n\ndef lookup_sym(sym: str, modules: list):\n    \"\"\"Look up a symbol in a list of modules.\"\"\"\n    for mname in modules:\n        module = importlib.import_module(mname, package=\"webdataset\")\n        result = getattr(module, sym, None)\n        if result is not None:\n            return result\n    return None\n\n\ndef repeatedly0(\n    loader: Iterator, nepochs: int = sys.maxsize, nbatches: int = sys.maxsize\n):\n    \"\"\"Repeatedly returns batches from a DataLoader.\"\"\"\n    for _ in range(nepochs):\n        yield from itt.islice(loader, nbatches)\n\n\ndef guess_batchsize(batch: Union[tuple, list]):\n    \"\"\"Guess the batch size by looking at the length of the first element in a tuple.\"\"\"\n    return len(batch[0])\n\n\ndef repeatedly(\n    source: Iterator,\n    nepochs: int = None,\n    nbatches: int = None,\n    nsamples: int = None,\n    batchsize: Callable[..., int] = guess_batchsize,\n):\n    \"\"\"Repeatedly yield samples from an iterator.\"\"\"\n    epoch = 0\n    batch = 0\n    total = 0\n    while True:\n        for sample in source:\n            yield sample\n            batch += 1\n            if nbatches is not None and batch >= nbatches:\n                return\n            if nsamples is not None:\n                total += guess_batchsize(sample)\n                if total >= nsamples:\n                    return\n        epoch += 1\n        if nepochs is not None and epoch >= nepochs:\n            return\n\n\ndef pytorch_worker_info(group=None):  # sourcery skip: use-contextlib-suppress\n    \"\"\"Return node and worker info for PyTorch and some distributed environments.\"\"\"\n    rank = 0\n    world_size = 1\n    worker = 0\n    num_workers = 1\n    if \"RANK\" in os.environ and \"WORLD_SIZE\" in os.environ:\n        rank = int(os.environ[\"RANK\"])\n        world_size = int(os.environ[\"WORLD_SIZE\"])\n    else:\n        try:\n            import torch.distributed\n\n            if torch.distributed.is_available() and torch.distributed.is_initialized():\n                group = group or torch.distributed.group.WORLD\n                rank = torch.distributed.get_rank(group=group)\n                world_size = torch.distributed.get_world_size(group=group)\n    
    except ModuleNotFoundError:\n            pass\n    if \"WORKER\" in os.environ and \"NUM_WORKERS\" in os.environ:\n        worker = int(os.environ[\"WORKER\"])\n        num_workers = int(os.environ[\"NUM_WORKERS\"])\n    else:\n        try:\n            import torch.utils.data\n\n            worker_info = torch.utils.data.get_worker_info()\n            if worker_info is not None:\n                worker = worker_info.id\n                num_workers = worker_info.num_workers\n        except ModuleNotFoundError:\n            pass\n\n    return rank, world_size, worker, num_workers\n\n\ndef pytorch_worker_seed(group=None):\n    \"\"\"Compute a distinct, deterministic RNG seed for each worker and node.\"\"\"\n    rank, world_size, worker, num_workers = pytorch_worker_info(group=group)\n    return rank * 1000 + worker\n\ndef worker_init_fn(_):\n    worker_info = torch.utils.data.get_worker_info()\n    worker_id = worker_info.id\n\n    # dataset = worker_info.dataset\n    # split_size = dataset.num_records // worker_info.num_workers\n    # # reset num_records to the true number to retain reliable length information\n    # dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size]\n    # current_id = np.random.choice(len(np.random.get_state()[1]), 1)\n    # return np.random.seed(np.random.get_state()[1][current_id] + worker_id)\n\n    return np.random.seed(np.random.get_state()[1][0] + worker_id)\n\n\ndef collation_fn(samples, combine_tensors=True, combine_scalars=True):\n    \"\"\"\n\n    Args:\n        samples (list[dict]):\n        combine_tensors:\n        combine_scalars:\n\n    Returns:\n\n    \"\"\"\n\n    result = {}\n\n    keys = samples[0].keys()\n\n    for key in keys:\n        result[key] = []\n\n    for sample in samples:\n        for key in keys:\n            val = sample[key]\n            result[key].append(val)\n\n    for key in keys:\n        val_list = result[key]\n        if isinstance(val_list[0], (int, float)):\n            if combine_scalars:\n                result[key] = np.array(result[key])\n\n        elif isinstance(val_list[0], torch.Tensor):\n            if combine_tensors:\n                result[key] = torch.stack(val_list)\n\n        elif isinstance(val_list[0], np.ndarray):\n            if combine_tensors:\n                result[key] = np.stack(val_list)\n\n    return result\n"
  },
  {
    "path": "ultrashape/models/__init__.py",
    "content": "# Open Source Model Licensed under the Apache License Version 2.0\n# and Other Licenses of the Third-Party Components therein:\n# The below Model in this distribution may have been modified by THL A29 Limited\n# (\"Tencent Modifications\"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.\n\n# Copyright (C) 2024 THL A29 Limited, a Tencent company.  All rights reserved.\n# The below software and/or models in this distribution may have been\n# modified by THL A29 Limited (\"Tencent Modifications\").\n# All Tencent Modifications are Copyright (C) THL A29 Limited.\n\n# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT\n# except for the third-party components listed below.\n# Hunyuan 3D does not impose any additional limitations beyond what is outlined\n# in the repsective licenses of these third-party components.\n# Users must comply with all terms and conditions of original licenses of these third-party\n# components and must ensure that the usage of the third party components adheres to\n# all relevant laws and regulations.\n\n# For avoidance of doubts, Hunyuan 3D means the large language models and\n# their software and algorithms, including trained model weights, parameters (including\n# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,\n# fine-tuning enabling code and other elements of the foregoing made publicly available\n# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.\n\nfrom .autoencoders import ShapeVAE\nfrom .conditioner_mask import DualImageEncoder, SingleImageEncoder, DinoImageEncoder, CLIPImageEncoder\nfrom .denoisers import RefineDiT\n"
  },
  {
    "path": "ultrashape/models/autoencoders/__init__.py",
    "content": "# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT\n# except for the third-party components listed below.\n# Hunyuan 3D does not impose any additional limitations beyond what is outlined\n# in the repsective licenses of these third-party components.\n# Users must comply with all terms and conditions of original licenses of these third-party\n# components and must ensure that the usage of the third party components adheres to\n# all relevant laws and regulations.\n\n# For avoidance of doubts, Hunyuan 3D means the large language models and\n# their software and algorithms, including trained model weights, parameters (including\n# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,\n# fine-tuning enabling code and other elements of the foregoing made publicly available\n# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.\n\nfrom .attention_blocks import CrossAttentionDecoder\nfrom .attention_processors import FlashVDMCrossAttentionProcessor, CrossAttentionProcessor, \\\n    FlashVDMTopMCrossAttentionProcessor\nfrom .model import ShapeVAE, VectsetVAE\nfrom .surface_extractors import SurfaceExtractors, MCSurfaceExtractor, DMCSurfaceExtractor, Latent2MeshOutput\nfrom .volume_decoders import HierarchicalVolumeDecoding, FlashVDMVolumeDecoding, VanillaVolumeDecoder\nfrom .vae_trainer import VAETrainer\n"
  },
  {
    "path": "ultrashape/models/autoencoders/attention_blocks.py",
    "content": "# ==============================================================================\n# Original work Copyright (c) 2025 Tencent.\n# Modified work Copyright (c) 2025 UltraShape Team.\n# \n# Modified by UltraShape on 2025.12.25\n# ==============================================================================\n\n# Open Source Model Licensed under the Apache License Version 2.0\n# and Other Licenses of the Third-Party Components therein:\n# The below Model in this distribution may have been modified by THL A29 Limited\n# (\"Tencent Modifications\"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.\n\n# Copyright (C) 2024 THL A29 Limited, a Tencent company.  All rights reserved.\n# The below software and/or models in this distribution may have been\n# modified by THL A29 Limited (\"Tencent Modifications\").\n# All Tencent Modifications are Copyright (C) THL A29 Limited.\n\n# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT\n# except for the third-party components listed below.\n# Hunyuan 3D does not impose any additional limitations beyond what is outlined\n# in the repsective licenses of these third-party components.\n# Users must comply with all terms and conditions of original licenses of these third-party\n# components and must ensure that the usage of the third party components adheres to\n# all relevant laws and regulations.\n\n# For avoidance of doubts, Hunyuan 3D means the large language models and\n# their software and algorithms, including trained model weights, parameters (including\n# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,\n# fine-tuning enabling code and other elements of the foregoing made publicly available\n# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.\n\n\nimport os\nfrom typing import Optional, Union, List\n\nimport torch\nimport torch.nn as nn\nfrom einops import rearrange\nfrom torch import Tensor\n\nfrom .attention_processors import CrossAttentionProcessor\nfrom ...utils import logger\nfrom ultrashape.utils import voxelize_from_point\n\nscaled_dot_product_attention = nn.functional.scaled_dot_product_attention\n\nif os.environ.get('USE_SAGEATTN', '0') == '1':\n    try:\n        from sageattention import sageattn\n    except ImportError:\n        raise ImportError('Please install the package \"sageattention\" to use this USE_SAGEATTN.')\n    scaled_dot_product_attention = sageattn\n\n\nclass FourierEmbedder(nn.Module):\n    \"\"\" The sin/cosine positional embedding. 
\"\"\"\n\n    def __init__(self,\n                 num_freqs: int = 6,\n                 logspace: bool = True,\n                 input_dim: int = 3,\n                 include_input: bool = True,\n                 include_pi: bool = True) -> None:\n\n        super().__init__()\n\n        if logspace:\n            frequencies = 2.0 ** torch.arange(\n                num_freqs,\n                dtype=torch.float32\n            )\n        else:\n            frequencies = torch.linspace(\n                1.0,\n                2.0 ** (num_freqs - 1),\n                num_freqs,\n                dtype=torch.float32\n            )\n\n        if include_pi:\n            frequencies *= torch.pi\n\n        self.register_buffer(\"frequencies\", frequencies, persistent=False)\n        self.include_input = include_input\n        self.num_freqs = num_freqs\n\n        self.out_dim = self.get_dims(input_dim)\n\n    def get_dims(self, input_dim):\n        temp = 1 if self.include_input or self.num_freqs == 0 else 0\n        out_dim = input_dim * (self.num_freqs * 2 + temp)\n\n        return out_dim\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\" Forward process.\n\n        Args:\n            x: tensor of shape [..., dim]\n\n        Returns:\n            embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]\n                where temp is 1 if include_input is True and 0 otherwise.\n        \"\"\"\n\n        if self.num_freqs > 0:\n            embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1)\n            if self.include_input:\n                return torch.cat((x, embed.sin(), embed.cos()), dim=-1)\n            else:\n                return torch.cat((embed.sin(), embed.cos()), dim=-1)\n        else:\n            return x\n\n\nclass DropPath(nn.Module):\n    \"\"\"Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).\n    \"\"\"\n\n    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):\n        super(DropPath, self).__init__()\n        self.drop_prob = drop_prob\n        self.scale_by_keep = scale_by_keep\n\n    def forward(self, x):\n        \"\"\"Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).\n\n        This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,\n        the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...\n        See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for\n        changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use\n        'survival rate' as the argument.\n\n        \"\"\"\n        if self.drop_prob == 0. 
or not self.training:\n            return x\n        keep_prob = 1 - self.drop_prob\n        shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets\n        random_tensor = x.new_empty(shape).bernoulli_(keep_prob)\n        if keep_prob > 0.0 and self.scale_by_keep:\n            random_tensor.div_(keep_prob)\n        return x * random_tensor\n\n    def extra_repr(self):\n        return f'drop_prob={round(self.drop_prob, 3):0.3f}'\n\n\nclass MLP(nn.Module):\n    def __init__(\n        self, *,\n        width: int,\n        expand_ratio: int = 4,\n        output_width: int = None,\n        drop_path_rate: float = 0.0\n    ):\n        super().__init__()\n        self.width = width\n        self.c_fc = nn.Linear(width, width * expand_ratio)\n        self.c_proj = nn.Linear(width * expand_ratio, output_width if output_width is not None else width)\n        self.gelu = nn.GELU()\n        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()\n\n    def forward(self, x):\n        return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))\n\n\nclass QKVMultiheadCrossAttention(nn.Module):\n    def __init__(\n        self,\n        *,\n        heads: int,\n        n_data: Optional[int] = None,\n        width=None,\n        qk_norm=False,\n        norm_layer=nn.LayerNorm\n    ):\n        super().__init__()\n        self.heads = heads\n        self.n_data = n_data\n        self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()\n        self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()\n\n        self.attn_processor = CrossAttentionProcessor()\n\n    def forward(self, q, kv):\n        _, n_ctx, _ = q.shape\n        bs, n_data, width = kv.shape\n        attn_ch = width // self.heads // 2\n        q = q.view(bs, n_ctx, self.heads, -1)\n        kv = kv.view(bs, n_data, self.heads, -1)\n        k, v = torch.split(kv, attn_ch, dim=-1)\n\n        q = self.q_norm(q)\n        k = self.k_norm(k)\n        q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))\n        out = self.attn_processor(self, q, k, v)\n        out = out.transpose(1, 2).reshape(bs, n_ctx, -1)\n        return out\n\n\nclass MultiheadCrossAttention(nn.Module):\n    def __init__(\n        self,\n        *,\n        width: int,\n        heads: int,\n        qkv_bias: bool = True,\n        n_data: Optional[int] = None,\n        data_width: Optional[int] = None,\n        norm_layer=nn.LayerNorm,\n        qk_norm: bool = False,\n        kv_cache: bool = False,\n    ):\n        super().__init__()\n        self.n_data = n_data\n        self.width = width\n        self.heads = heads\n        self.data_width = width if data_width is None else data_width\n        self.c_q = nn.Linear(width, width, bias=qkv_bias)\n        self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias)\n        self.c_proj = nn.Linear(width, width)\n        self.attention = QKVMultiheadCrossAttention(\n            heads=heads,\n            n_data=n_data,\n            width=width,\n            norm_layer=norm_layer,\n            qk_norm=qk_norm\n        )\n        self.kv_cache = kv_cache\n        self.data = None\n\n    def forward(self, x, data):\n        x = self.c_q(x)\n        if self.kv_cache:\n            if self.data is None:\n                self.data = self.c_kv(data)\n                logger.info('Save kv cache,this should be called only once for one 
mesh')\n            data = self.data\n        else:\n            data = self.c_kv(data)\n        x = self.attention(x, data)\n        x = self.c_proj(x)\n        return x\n\n\nclass ResidualCrossAttentionBlock(nn.Module):\n    def __init__(\n        self,\n        *,\n        n_data: Optional[int] = None,\n        width: int,\n        heads: int,\n        mlp_expand_ratio: int = 4,\n        data_width: Optional[int] = None,\n        qkv_bias: bool = True,\n        norm_layer=nn.LayerNorm,\n        qk_norm: bool = False\n    ):\n        super().__init__()\n\n        if data_width is None:\n            data_width = width\n\n        self.attn = MultiheadCrossAttention(\n            n_data=n_data,\n            width=width,\n            heads=heads,\n            data_width=data_width,\n            qkv_bias=qkv_bias,\n            norm_layer=norm_layer,\n            qk_norm=qk_norm\n        )\n        self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)\n        self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6)\n        self.ln_3 = norm_layer(width, elementwise_affine=True, eps=1e-6)\n        self.mlp = MLP(width=width, expand_ratio=mlp_expand_ratio)\n\n    def forward(self, x: torch.Tensor, data: torch.Tensor):\n        x = x + self.attn(self.ln_1(x), self.ln_2(data))\n        x = x + self.mlp(self.ln_3(x))\n        return x\n\n\nclass QKVMultiheadAttention(nn.Module):\n    def __init__(\n        self,\n        *,\n        heads: int,\n        n_ctx: int,\n        width=None,\n        qk_norm=False,\n        norm_layer=nn.LayerNorm\n    ):\n        super().__init__()\n        self.heads = heads\n        self.n_ctx = n_ctx\n        self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()\n        self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()\n\n    def forward(self, qkv):\n        bs, n_ctx, width = qkv.shape\n        attn_ch = width // self.heads // 3\n        qkv = qkv.view(bs, n_ctx, self.heads, -1)\n        q, k, v = torch.split(qkv, attn_ch, dim=-1)\n\n        q = self.q_norm(q)\n        k = self.k_norm(k)\n\n        q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))\n        out = scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)\n        return out\n\n\nclass MultiheadAttention(nn.Module):\n    def __init__(\n        self,\n        *,\n        n_ctx: int,\n        width: int,\n        heads: int,\n        qkv_bias: bool,\n        norm_layer=nn.LayerNorm,\n        qk_norm: bool = False,\n        drop_path_rate: float = 0.0\n    ):\n        super().__init__()\n        self.n_ctx = n_ctx\n        self.width = width\n        self.heads = heads\n        self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias)\n        self.c_proj = nn.Linear(width, width)\n        self.attention = QKVMultiheadAttention(\n            heads=heads,\n            n_ctx=n_ctx,\n            width=width,\n            norm_layer=norm_layer,\n            qk_norm=qk_norm\n        )\n        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. 
else nn.Identity()\n\n    def forward(self, x):\n        x = self.c_qkv(x)\n        x = self.attention(x)\n        x = self.drop_path(self.c_proj(x))\n        return x\n\n\nclass ResidualAttentionBlock(nn.Module):\n    def __init__(\n        self,\n        *,\n        n_ctx: int,\n        width: int,\n        heads: int,\n        qkv_bias: bool = True,\n        norm_layer=nn.LayerNorm,\n        qk_norm: bool = False,\n        drop_path_rate: float = 0.0,\n    ):\n        super().__init__()\n        self.attn = MultiheadAttention(\n            n_ctx=n_ctx,\n            width=width,\n            heads=heads,\n            qkv_bias=qkv_bias,\n            norm_layer=norm_layer,\n            qk_norm=qk_norm,\n            drop_path_rate=drop_path_rate\n        )\n        self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)\n        self.mlp = MLP(width=width, drop_path_rate=drop_path_rate)\n        self.ln_2 = norm_layer(width, elementwise_affine=True, eps=1e-6)\n\n    def forward(self, x: torch.Tensor):\n        x = x + self.attn(self.ln_1(x))\n        x = x + self.mlp(self.ln_2(x))\n        return x\n\n\nclass Transformer(nn.Module):\n    def __init__(\n        self,\n        *,\n        n_ctx: int,\n        width: int,\n        layers: int,\n        heads: int,\n        qkv_bias: bool = True,\n        norm_layer=nn.LayerNorm,\n        qk_norm: bool = False,\n        drop_path_rate: float = 0.0\n    ):\n        super().__init__()\n        self.n_ctx = n_ctx\n        self.width = width\n        self.layers = layers\n        self.resblocks = nn.ModuleList(\n            [\n                ResidualAttentionBlock(\n                    n_ctx=n_ctx,\n                    width=width,\n                    heads=heads,\n                    qkv_bias=qkv_bias,\n                    norm_layer=norm_layer,\n                    qk_norm=qk_norm,\n                    drop_path_rate=drop_path_rate\n                )\n                for _ in range(layers)\n            ]\n        )\n\n    def forward(self, x: torch.Tensor):\n        for block in self.resblocks:\n            x = block(x)\n        return x\n\n\nclass CrossAttentionDecoder(nn.Module):\n\n    def __init__(\n        self,\n        *,\n        num_latents: int,\n        out_channels: int,\n        fourier_embedder: FourierEmbedder,\n        width: int,\n        heads: int,\n        mlp_expand_ratio: int = 4,\n        downsample_ratio: int = 1,\n        enable_ln_post: bool = True,\n        qkv_bias: bool = True,\n        qk_norm: bool = False,\n        label_type: str = \"binary\"\n    ):\n        super().__init__()\n\n        self.enable_ln_post = enable_ln_post\n        self.fourier_embedder = fourier_embedder\n        self.downsample_ratio = downsample_ratio\n        self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width)\n        if self.downsample_ratio != 1:\n            self.latents_proj = nn.Linear(width * downsample_ratio, width)\n        if self.enable_ln_post == False:\n            qk_norm = False\n        self.cross_attn_decoder = ResidualCrossAttentionBlock(\n            n_data=num_latents,\n            width=width,\n            mlp_expand_ratio=mlp_expand_ratio,\n            heads=heads,\n            qkv_bias=qkv_bias,\n            qk_norm=qk_norm\n        )\n\n        if self.enable_ln_post:\n            self.ln_post = nn.LayerNorm(width)\n        self.output_proj = nn.Linear(width, out_channels)\n        self.label_type = label_type\n        self.count = 0\n\n    def set_cross_attention_processor(self, processor):\n 
       self.cross_attn_decoder.attn.attention.attn_processor = processor\n\n    def set_default_cross_attention_processor(self):\n        self.cross_attn_decoder.attn.attention.attn_processor = CrossAttentionProcessor\n\n    def forward(self, queries=None, query_embeddings=None, latents=None):\n        if query_embeddings is None:\n            query_embeddings = self.query_proj(self.fourier_embedder(queries).to(latents.dtype))\n        self.count += query_embeddings.shape[1]\n        if self.downsample_ratio != 1:\n            latents = self.latents_proj(latents)\n        x = self.cross_attn_decoder(query_embeddings, latents)\n        if self.enable_ln_post:\n            x = self.ln_post(x)\n        occ = self.output_proj(x)\n        return occ\n\n\ndef fps(\n    src: torch.Tensor,\n    batch: Optional[Tensor] = None,\n    ratio: Optional[Union[Tensor, float]] = None,\n    random_start: bool = True,\n    batch_size: Optional[int] = None,\n    ptr: Optional[Union[Tensor, List[int]]] = None,\n):\n    src = src.float()\n    from torch_cluster import fps as fps_fn\n    output = fps_fn(src, batch, ratio, random_start, batch_size, ptr)\n    return output\n\n\nclass PointCrossAttentionEncoder(nn.Module):\n\n    def __init__(\n        self, *,\n        num_latents: int,\n        downsample_ratio: float,\n        pc_size: int,\n        pc_sharpedge_size: int,\n        fourier_embedder: FourierEmbedder,\n        point_feats: int,\n        width: int,\n        heads: int,\n        layers: int,\n        voxel_query_res: int,\n        normal_pe: bool = False,\n        qkv_bias: bool = True,\n        use_ln_post: bool = False,\n        use_checkpoint: bool = False,\n        qk_norm: bool = False,\n        jitter_query: bool = False,\n        voxel_query: bool = False,\n    ):\n\n        super().__init__()\n\n        self.use_checkpoint = use_checkpoint\n        self.num_latents = num_latents\n        self.downsample_ratio = downsample_ratio\n        self.point_feats = point_feats\n        self.normal_pe = normal_pe\n        self.jitter_query = jitter_query\n        self.voxel_query = voxel_query\n        self.voxel_query_res = voxel_query_res\n\n        if pc_sharpedge_size == 0:\n            print(\n                f'PointCrossAttentionEncoder INFO: pc_sharpedge_size is zero')\n        else:\n            print(\n                f'PointCrossAttentionEncoder INFO: pc_sharpedge_size is given, using pc_size={pc_size}, pc_sharpedge_size={pc_sharpedge_size}')\n\n        self.pc_size = pc_size\n        self.pc_sharpedge_size = pc_sharpedge_size\n\n        self.fourier_embedder = fourier_embedder\n\n        if self.jitter_query or self.voxel_query:\n            self.input_proj_q = nn.Linear(self.fourier_embedder.out_dim, width)\n            self.input_proj_kv = nn.Linear(self.fourier_embedder.out_dim + point_feats, width)\n        else:\n            self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width)\n        self.cross_attn = ResidualCrossAttentionBlock(\n            width=width,\n            heads=heads,\n            qkv_bias=qkv_bias,\n            qk_norm=qk_norm\n        )\n\n        self.self_attn = None\n        if layers > 0:\n            self.self_attn = Transformer(\n                n_ctx=num_latents,\n                width=width,\n                layers=layers,\n                heads=heads,\n                qkv_bias=qkv_bias,\n                qk_norm=qk_norm\n            )\n\n        if use_ln_post:\n            self.ln_post = nn.LayerNorm(width)\n        else:\n           
 self.ln_post = None\n\n    def sample_points_and_latents(self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None):\n        B, N, D = pc.shape\n        num_pts = self.num_latents * self.downsample_ratio\n\n        # Compute number of latents\n        num_latents = int(num_pts / self.downsample_ratio)\n\n        # Compute the number of random and sharpedge latents\n        num_random_query = self.pc_size / (self.pc_size + self.pc_sharpedge_size) * num_latents\n        num_sharpedge_query = num_latents - num_random_query\n\n        # Split random and sharpedge surface points\n        random_pc, sharpedge_pc = torch.split(pc, [self.pc_size, self.pc_sharpedge_size], dim=1)\n        assert random_pc.shape[1] <= self.pc_size, \"Random surface points size must be less than or equal to pc_size\"\n        assert sharpedge_pc.shape[1] <= self.pc_sharpedge_size, \"Sharpedge surface points size must be less than or equal to pc_sharpedge_size\"\n\n        # Randomly select random surface points and random query points\n        input_random_pc_size = int(num_random_query * self.downsample_ratio)\n        random_query_ratio = num_random_query / input_random_pc_size\n        idx_random_pc = torch.randperm(random_pc.shape[1], device=random_pc.device)[:input_random_pc_size]\n        input_random_pc = random_pc[:, idx_random_pc, :]\n\n        if self.voxel_query:\n            query_random_pc, query_voxel_indices = voxelize_from_point(pc, num_latents, resolution=self.voxel_query_res)\n        else:\n            flatten_input_random_pc = input_random_pc.view(B * input_random_pc_size, D)\n            N_down = int(flatten_input_random_pc.shape[0] / B)\n            batch_down = torch.arange(B).to(pc.device)\n            batch_down = torch.repeat_interleave(batch_down, N_down)\n            idx_query_random = fps(flatten_input_random_pc, batch_down, ratio=random_query_ratio)\n            query_random_pc = flatten_input_random_pc[idx_query_random].view(B, -1, D)\n\n        # Randomly select sharpedge surface points and sharpedge query points\n        input_sharpedge_pc_size = int(num_sharpedge_query * self.downsample_ratio)\n        if input_sharpedge_pc_size == 0 or self.voxel_query:\n            input_sharpedge_pc = torch.zeros(B, 0, D, dtype=input_random_pc.dtype).to(pc.device)\n            query_sharpedge_pc = torch.zeros(B, 0, D, dtype=query_random_pc.dtype).to(pc.device)\n        else:\n            sharpedge_query_ratio = num_sharpedge_query / input_sharpedge_pc_size\n            idx_sharpedge_pc = torch.randperm(sharpedge_pc.shape[1], device=sharpedge_pc.device)[\n                            :input_sharpedge_pc_size]\n            input_sharpedge_pc = sharpedge_pc[:, idx_sharpedge_pc, :]\n            flatten_input_sharpedge_surface_points = input_sharpedge_pc.view(B * input_sharpedge_pc_size, D)\n            N_down = int(flatten_input_sharpedge_surface_points.shape[0] / B)\n            batch_down = torch.arange(B).to(pc.device)\n            batch_down = torch.repeat_interleave(batch_down, N_down)\n            idx_query_sharpedge = fps(flatten_input_sharpedge_surface_points, batch_down, ratio=sharpedge_query_ratio)\n            query_sharpedge_pc = flatten_input_sharpedge_surface_points[idx_query_sharpedge].view(B, -1, D)\n\n        # Concatenate random and sharpedge surface points and query points\n        query_pc = torch.cat([query_random_pc, query_sharpedge_pc], dim=1)\n        input_pc = torch.cat([input_random_pc, input_sharpedge_pc], dim=1)\n\n        if self.jitter_query:\n            R 
= self.voxel_query_res // 2\n            noise = torch.rand_like(query_pc)\n            query_pc += (noise - 0.5) / R\n\n        # PE\n        query = self.fourier_embedder(query_pc)\n        data = self.fourier_embedder(input_pc)\n\n        # Concat normal if given\n        if self.point_feats != 0:\n\n            random_surface_feats, sharpedge_surface_feats = torch.split(feats, [self.pc_size, self.pc_sharpedge_size],\n                                                                        dim=1)\n            input_random_surface_feats = random_surface_feats[:, idx_random_pc, :]\n            if not self.voxel_query and not self.jitter_query:\n                flatten_input_random_surface_feats = input_random_surface_feats.view(B * input_random_pc_size, -1)\n                query_random_feats = flatten_input_random_surface_feats[idx_query_random].view(B, -1,\n                                                                                           flatten_input_random_surface_feats.shape[\n                                                                                               -1])\n\n            if input_sharpedge_pc_size == 0:\n                input_sharpedge_surface_feats = torch.zeros(B, 0, self.point_feats,\n                                                            dtype=input_random_surface_feats.dtype).to(pc.device)\n                if not self.voxel_query and not self.jitter_query:\n                    query_sharpedge_feats = torch.zeros(B, 0, self.point_feats, dtype=query_random_feats.dtype).to(\n                        pc.device)\n            else:\n                input_sharpedge_surface_feats = sharpedge_surface_feats[:, idx_sharpedge_pc, :]\n                if not self.voxel_query and not self.jitter_query:\n                    flatten_input_sharpedge_surface_feats = input_sharpedge_surface_feats.view(B * input_sharpedge_pc_size,\n                                                                                            -1)\n                    query_sharpedge_feats = flatten_input_sharpedge_surface_feats[idx_query_sharpedge].view(B, -1,\n                                                                                                        flatten_input_sharpedge_surface_feats.shape[\n                                                                                                            -1])\n            if not self.voxel_query and not self.jitter_query:\n                query_feats = torch.cat([query_random_feats, query_sharpedge_feats], dim=1)\n            input_feats = torch.cat([input_random_surface_feats, input_sharpedge_surface_feats], dim=1)\n\n            if self.normal_pe:\n                if not self.voxel_query and not self.jitter_query:\n                    query_normal_pe = self.fourier_embedder(query_feats[..., :3])\n                    query_feats = torch.cat([query_normal_pe, query_feats[..., 3:]], dim=-1)\n                input_normal_pe = self.fourier_embedder(input_feats[..., :3])\n                input_feats = torch.cat([input_normal_pe, input_feats[..., 3:]], dim=-1)\n\n            if not self.voxel_query and not self.jitter_query:\n                query = torch.cat([query, query_feats], dim=-1)\n            data = torch.cat([data, input_feats], dim=-1)\n\n        if input_sharpedge_pc_size == 0:\n            query_sharpedge_pc = torch.zeros(B, 1, D).to(pc.device)\n            input_sharpedge_pc = torch.zeros(B, 1, D).to(pc.device)\n\n        if self.voxel_query:\n            pc_infos = [query_voxel_indices, query_random_pc]\n        else:\n 
           pc_infos = [query_pc, input_pc, query_random_pc, input_random_pc, query_sharpedge_pc, input_sharpedge_pc]\n        return query.view(B, -1, query.shape[-1]), data.view(B, -1, data.shape[-1]), pc_infos\n\n\n    def forward(self, pc, feats):\n        \"\"\"\n        Encode a sampled surface point cloud into latent tokens.\n\n        Args:\n            pc (torch.FloatTensor): [B, N, 3]\n            feats (torch.FloatTensor or None): [B, N, C]\n\n        Returns:\n            latents (torch.FloatTensor): [B, num_latents, width] encoded latent tokens.\n            pc_infos (list): query/input point clouds (or voxel indices) produced during sampling.\n        \"\"\"\n        query, data, pc_infos = self.sample_points_and_latents(pc, feats)\n\n        if self.jitter_query or self.voxel_query:\n            query = self.input_proj_q(query)\n            data = self.input_proj_kv(data)\n        else:\n            query = self.input_proj(query)\n            data = self.input_proj(data)\n\n        latents = self.cross_attn(query, data)\n        if self.self_attn is not None:\n            latents = self.self_attn(latents)\n\n        if self.ln_post is not None:\n            latents = self.ln_post(latents)\n\n        return latents, pc_infos\n"
  },
  {
    "path": "ultrashape/models/autoencoders/attention_processors.py",
    "content": "# ==============================================================================\n# Original work Copyright (c) 2025 Tencent.\n# Modified work Copyright (c) 2025 UltraShape Team.\n# \n# Modified by UltraShape on 2025.12.25\n# ==============================================================================\n\n# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT\n# except for the third-party components listed below.\n# Hunyuan 3D does not impose any additional limitations beyond what is outlined\n# in the repsective licenses of these third-party components.\n# Users must comply with all terms and conditions of original licenses of these third-party\n# components and must ensure that the usage of the third party components adheres to\n# all relevant laws and regulations.\n\n# For avoidance of doubts, Hunyuan 3D means the large language models and\n# their software and algorithms, including trained model weights, parameters (including\n# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,\n# fine-tuning enabling code and other elements of the foregoing made publicly available\n# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.\n\nimport os\n\nimport torch\nimport torch.nn.functional as F\n\nscaled_dot_product_attention = F.scaled_dot_product_attention\nif os.environ.get('CA_USE_SAGEATTN', '0') == '1':\n    try:\n        from sageattention import sageattn\n    except ImportError:\n        raise ImportError('Please install the package \"sageattention\" to use this USE_SAGEATTN.')\n    scaled_dot_product_attention = sageattn\n\n\nclass CrossAttentionProcessor:\n    def __call__(self, attn, q, k, v):\n        out = scaled_dot_product_attention(q, k, v)\n        return out\n\n\nclass FlashVDMCrossAttentionProcessor:\n    def __init__(self, topk=None):\n        self.topk = topk\n\n    def __call__(self, attn, q, k, v):\n        if k.shape[-2] == 3072:\n            topk = 1024\n        elif k.shape[-2] == 512:\n            topk = 256\n        else:\n            topk = k.shape[-2] // 3\n\n        if self.topk is True:\n            q1 = q[:, :, ::100, :]\n            sim = q1 @ k.transpose(-1, -2)\n            sim = torch.mean(sim, -2)\n            topk_ind = torch.topk(sim, dim=-1, k=topk).indices.squeeze(-2).unsqueeze(-1)\n            topk_ind = topk_ind.expand(-1, -1, -1, v.shape[-1])\n            v0 = torch.gather(v, dim=-2, index=topk_ind)\n            k0 = torch.gather(k, dim=-2, index=topk_ind)\n            out = scaled_dot_product_attention(q, k0, v0)\n        elif self.topk is False:\n            out = scaled_dot_product_attention(q, k, v)\n        else:\n            idx, counts = self.topk\n            start = 0\n            outs = []\n            for grid_coord, count in zip(idx, counts):\n                end = start + count\n                q_chunk = q[:, :, start:end, :]\n                k0, v0 = self.select_topkv(q_chunk, k, v, topk)\n                out = scaled_dot_product_attention(q_chunk, k0, v0)\n                outs.append(out)\n                start += count\n            out = torch.cat(outs, dim=-2)\n        self.topk = False\n        return out\n\n    def select_topkv(self, q_chunk, k, v, topk):\n        q1 = q_chunk[:, :, ::50, :]\n        sim = q1 @ k.transpose(-1, -2)\n        sim = torch.mean(sim, -2)\n        topk_ind = torch.topk(sim, dim=-1, k=topk).indices.squeeze(-2).unsqueeze(-1)\n        topk_ind = topk_ind.expand(-1, -1, -1, v.shape[-1])\n        v0 = 
torch.gather(v, dim=-2, index=topk_ind)\n        k0 = torch.gather(k, dim=-2, index=topk_ind)\n        return k0, v0\n\n\nclass FlashVDMTopMCrossAttentionProcessor(FlashVDMCrossAttentionProcessor):\n    def select_topkv(self, q_chunk, k, v, topk):\n        q1 = q_chunk[:, :, ::30, :]\n        sim = q1 @ k.transpose(-1, -2)\n        # sim = sim.to(torch.float32)\n        sim = sim.softmax(-1)\n        sim = torch.mean(sim, 1)\n        activated_token = torch.where(sim > 1e-6)[2]\n        index = torch.unique(activated_token, return_counts=True)[0].unsqueeze(0).unsqueeze(0).unsqueeze(-1)\n        index = index.expand(-1, v.shape[1], -1, v.shape[-1])\n        v0 = torch.gather(v, dim=-2, index=index)\n        k0 = torch.gather(k, dim=-2, index=index)\n        return k0, v0\n"
  },
  {
    "path": "ultrashape/models/autoencoders/model.py",
    "content": "# ==============================================================================\n# Original work Copyright (c) 2025 Tencent.\n# Modified work Copyright (c) 2025 UltraShape Team.\n# \n# Modified by UltraShape on 2025.12.25\n# ==============================================================================\n\n# Open Source Model Licensed under the Apache License Version 2.0\n# and Other Licenses of the Third-Party Components therein:\n# The below Model in this distribution may have been modified by THL A29 Limited\n# (\"Tencent Modifications\"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.\n\n# Copyright (C) 2024 THL A29 Limited, a Tencent company.  All rights reserved.\n# The below software and/or models in this distribution may have been\n# modified by THL A29 Limited (\"Tencent Modifications\").\n# All Tencent Modifications are Copyright (C) THL A29 Limited.\n\n# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT\n# except for the third-party components listed below.\n# Hunyuan 3D does not impose any additional limitations beyond what is outlined\n# in the repsective licenses of these third-party components.\n# Users must comply with all terms and conditions of original licenses of these third-party\n# components and must ensure that the usage of the third party components adheres to\n# all relevant laws and regulations.\n\n# For avoidance of doubts, Hunyuan 3D means the large language models and\n# their software and algorithms, including trained model weights, parameters (including\n# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,\n# fine-tuning enabling code and other elements of the foregoing made publicly available\n# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.\n\nimport os\nfrom typing import Union, List\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport yaml\n\nfrom .attention_blocks import FourierEmbedder, Transformer, CrossAttentionDecoder, PointCrossAttentionEncoder\nfrom .surface_extractors import MCSurfaceExtractor, SurfaceExtractors\nfrom .volume_decoders import VanillaVolumeDecoder, FlashVDMVolumeDecoding, HierarchicalVolumeDecoding\nfrom ...utils import logger, synchronize_timer, smart_load_model\n\n\nclass DiagonalGaussianDistribution(object):\n    def __init__(self, parameters: Union[torch.Tensor, List[torch.Tensor]], deterministic=False, feat_dim=1):\n        \"\"\"\n        Initialize a diagonal Gaussian distribution with mean and log-variance parameters.\n\n        Args:\n            parameters (Union[torch.Tensor, List[torch.Tensor]]): \n                Either a single tensor containing concatenated mean and log-variance along `feat_dim`,\n                or a list of two tensors [mean, logvar].\n            deterministic (bool, optional): If True, the distribution is deterministic (zero variance). \n                Default is False. feat_dim (int, optional): Dimension along which mean and logvar are \n                concatenated if parameters is a single tensor. 
Default is 1.\n        \"\"\"\n        self.feat_dim = feat_dim\n        self.parameters = parameters\n\n        if isinstance(parameters, list):\n            self.mean = parameters[0]\n            self.logvar = parameters[1]\n        else:\n            self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim)\n\n        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)\n        self.deterministic = deterministic\n        self.std = torch.exp(0.5 * self.logvar)\n        self.var = torch.exp(self.logvar)\n        if self.deterministic:\n            self.var = self.std = torch.zeros_like(self.mean)\n\n    def sample(self):\n        \"\"\"\n        Sample from the diagonal Gaussian distribution.\n\n        Returns:\n            torch.Tensor: A sample tensor with the same shape as the mean.\n        \"\"\"\n        x = self.mean + self.std * torch.randn_like(self.mean)\n        return x\n\n    def kl(self, other=None, dims=(1, 2)):\n        \"\"\"\n        Compute the Kullback-Leibler (KL) divergence between this distribution and another.\n\n        If `other` is None, compute KL divergence to a standard normal distribution N(0, I).\n\n        Args:\n            other (DiagonalGaussianDistribution, optional): Another diagonal Gaussian distribution.\n            dims (tuple, optional): Dimensions along which to compute the mean KL divergence. \n                Default is (1, 2, 3).\n\n        Returns:\n            torch.Tensor: The mean KL divergence value.\n        \"\"\"\n        if self.deterministic:\n            return torch.Tensor([0.])\n        else:\n            if other is None:\n                return 0.5 * torch.mean(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=dims)\n            else:\n                return 0.5 * torch.mean(\n                    torch.pow(self.mean - other.mean, 2) / other.var\n                    + self.var / other.var - 1.0 - self.logvar + other.logvar,\n                    dim=dims)\n\n    def nll(self, sample, dims=(1, 2, 3)):\n        if self.deterministic:\n            return torch.Tensor([0.])\n        logtwopi = np.log(2.0 * np.pi)\n        return 0.5 * torch.sum(\n            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,\n            dim=dims)\n\n    def mode(self):\n        return self.mean\n\n\nclass VectsetVAE(nn.Module):\n\n    @classmethod\n    @synchronize_timer('VectsetVAE Model Loading')\n    def from_single_file(\n        cls,\n        ckpt_path,\n        config_path=None,\n        params=None,\n        device='cuda',\n        dtype=torch.float16,\n        use_safetensors=None,\n        **kwargs,\n    ):\n        # load config\n        with open(config_path, 'r') as f:\n            config = yaml.safe_load(f)\n\n        # load ckpt\n        if use_safetensors:\n            ckpt_path = ckpt_path.replace('.ckpt', '.safetensors')\n        if not os.path.exists(ckpt_path):\n            raise FileNotFoundError(f\"Model file {ckpt_path} not found\")\n\n        logger.info(f\"Loading model from {ckpt_path}\")\n        if use_safetensors:\n            import safetensors.torch\n            ckpt = safetensors.torch.load_file(ckpt_path, device='cpu')\n        else:\n            ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)\n\n        if params is not None:\n            model_kwargs = params\n        else:\n            model_kwargs = config['params']\n        model_kwargs.update(kwargs)\n\n        model = cls(**model_kwargs)\n        model.load_state_dict(ckpt)\n\n        
model.to(device=device, dtype=dtype)\n        return model\n\n    @classmethod\n    def from_pretrained(\n        cls,\n        model_path,\n        device='cuda',\n        params=None,\n        dtype=torch.float16,\n        use_safetensors=False,\n        variant='fp16',\n        subfolder='hunyuan3d-vae-v2-1',\n        **kwargs,\n    ):\n        config_path, ckpt_path = smart_load_model(\n            model_path,\n            subfolder=subfolder,\n            use_safetensors=use_safetensors,\n            variant=variant\n        )\n\n        return cls.from_single_file(\n            ckpt_path,\n            config_path=config_path,\n            params=params,\n            device=device,\n            dtype=dtype,\n            use_safetensors=use_safetensors,\n            **kwargs\n        )\n        \n    def init_from_ckpt(self, path, ignore_keys=()):\n        state_dict = torch.load(path, map_location=\"cpu\")\n        state_dict = state_dict.get(\"state_dict\", state_dict)\n        keys = list(state_dict.keys())\n        for k in keys:\n            for ik in ignore_keys:\n                if k.startswith(ik):\n                    print(\"Deleting key {} from state_dict.\".format(k))\n                    del state_dict[k]\n        missing, unexpected = self.load_state_dict(state_dict, strict=False)\n        print(f\"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys\")\n        if len(missing) > 0:\n            print(f\"Missing Keys: {missing}\")\n            print(f\"Unexpected Keys: {unexpected}\")\n\n    def __init__(\n        self,\n        volume_decoder=None,\n        surface_extractor=None\n    ):\n        super().__init__()\n        if volume_decoder is None:\n            volume_decoder = VanillaVolumeDecoder()\n        if surface_extractor is None:\n            surface_extractor = MCSurfaceExtractor()\n        self.volume_decoder = volume_decoder\n        self.surface_extractor = surface_extractor\n\n    def latents2mesh(self, latents: torch.FloatTensor, **kwargs):\n        with synchronize_timer('Volume decoding'):\n            grid_logits = self.volume_decoder(latents, self.geo_decoder, **kwargs)\n        with synchronize_timer('Surface extraction'):\n            outputs = self.surface_extractor(grid_logits, **kwargs)\n        return outputs, grid_logits\n\n    def enable_flashvdm_decoder(\n        self,\n        enabled: bool = True,\n        adaptive_kv_selection=True,\n        topk_mode='mean',\n        mc_algo='mc',\n    ):\n        if enabled:\n            if adaptive_kv_selection:\n                self.volume_decoder = FlashVDMVolumeDecoding(topk_mode)\n            else:\n                self.volume_decoder = HierarchicalVolumeDecoding()\n            if mc_algo not in SurfaceExtractors.keys():\n                raise ValueError(f'Unsupported mc_algo {mc_algo}, available:{list(SurfaceExtractors.keys())}')\n            self.surface_extractor = SurfaceExtractors[mc_algo]()\n        else:\n            self.volume_decoder = VanillaVolumeDecoder()\n            self.surface_extractor = MCSurfaceExtractor()\n\n\nclass ShapeVAE(VectsetVAE):\n    def __init__(\n        self,\n        *,\n        num_latents: int,\n        embed_dim: int,\n        width: int,\n        heads: int,\n        num_decoder_layers: int,\n        num_encoder_layers: int = 8,\n        pc_size: int = 5120,\n        pc_sharpedge_size: int = 5120,\n        point_feats: int = 3,\n        downsample_ratio: int = 20,\n        geo_decoder_downsample_ratio: int = 1,\n        
geo_decoder_mlp_expand_ratio: int = 4,\n        geo_decoder_ln_post: bool = True,\n        num_freqs: int = 8,\n        include_pi: bool = True,\n        qkv_bias: bool = True,\n        qk_norm: bool = False,\n        label_type: str = \"binary\",\n        drop_path_rate: float = 0.0,\n        scale_factor: float = 1.0,\n        use_ln_post: bool = True,\n        enable_flashvdm: bool = False,\n        ckpt_path = None,\n        jitter_query: bool = False,\n        voxel_query: bool = False,\n        voxel_query_res: int = 128,\n    ):\n        super().__init__()\n        self.geo_decoder_ln_post = geo_decoder_ln_post\n        self.downsample_ratio = downsample_ratio\n\n        self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)\n\n        self.encoder = PointCrossAttentionEncoder(\n            fourier_embedder=self.fourier_embedder,\n            num_latents=num_latents,\n            downsample_ratio=self.downsample_ratio,\n            pc_size=pc_size,\n            pc_sharpedge_size=pc_sharpedge_size,\n            point_feats=point_feats,\n            width=width,\n            heads=heads,\n            layers=num_encoder_layers,\n            qkv_bias=qkv_bias,\n            use_ln_post=use_ln_post,\n            qk_norm=qk_norm,\n            jitter_query=jitter_query,\n            voxel_query=voxel_query,\n            voxel_query_res=voxel_query_res\n        )\n\n        self.pre_kl = nn.Linear(width, embed_dim * 2)\n        self.post_kl = nn.Linear(embed_dim, width)\n\n        self.transformer = Transformer(\n            n_ctx=num_latents,\n            width=width,\n            layers=num_decoder_layers,\n            heads=heads,\n            qkv_bias=qkv_bias,\n            qk_norm=qk_norm,\n            drop_path_rate=drop_path_rate\n        )\n\n        self.geo_decoder = CrossAttentionDecoder(\n            fourier_embedder=self.fourier_embedder,\n            out_channels=1,\n            num_latents=num_latents,\n            mlp_expand_ratio=geo_decoder_mlp_expand_ratio,\n            downsample_ratio=geo_decoder_downsample_ratio,\n            enable_ln_post=self.geo_decoder_ln_post,\n            width=width // geo_decoder_downsample_ratio,\n            heads=heads // geo_decoder_downsample_ratio,\n            qkv_bias=qkv_bias,\n            qk_norm=qk_norm,\n            label_type=label_type,\n        )\n\n        self.scale_factor = scale_factor\n        self.latent_shape = (num_latents, embed_dim)\n\n        if ckpt_path is not None:\n            self.init_from_ckpt(ckpt_path)\n\n        if enable_flashvdm:\n            self.enable_flashvdm_decoder()\n\n    def forward(self, latents):\n        latents = self.post_kl(latents)\n        latents = self.transformer(latents)\n        return latents\n\n    def encode(self, surface, sample_posterior=True, need_kl=False, need_voxel=False):\n        pc, feats = surface[:, :, :3], surface[:, :, 3:]\n        latents, pc_infos = self.encoder(pc, feats)\n        # print(latents.shape, self.pre_kl.weight.shape)\n        moments = self.pre_kl(latents)\n        posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)\n        if sample_posterior:\n            latents = posterior.sample()\n        else:\n            latents = posterior.mode()\n        if need_kl:\n            return latents, posterior \n        if need_voxel:\n            return latents, pc_infos[0]\n        return latents\n\n    def decode(self, latents, voxel_idx=None):\n        latents = self.post_kl(latents)\n        latents = 
self.transformer(latents)\n        return latents\n\n    def query(self, latents, queries, voxel_idx=None):\n        \"\"\"\n        Args:\n            queries (torch.FloatTensor): [B, N, 3]\n            latents (torch.FloatTensor): [B, num_latents, width], decoded latent tokens\n\n        Returns:\n            logits (torch.FloatTensor): [B, N], occupancy logits\n        \"\"\"\n        logits = self.geo_decoder(queries=queries, latents=latents).squeeze(-1)\n        return logits\n"
  },
  {
    "path": "ultrashape/models/autoencoders/surface_extractors.py",
    "content": "# ==============================================================================\n# Original work Copyright (c) 2025 Tencent.\n# Modified work Copyright (c) 2025 UltraShape Team.\n# \n# Modified by UltraShape on 2025.12.25\n# ==============================================================================\n\n# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT\n# except for the third-party components listed below.\n# Hunyuan 3D does not impose any additional limitations beyond what is outlined\n# in the repsective licenses of these third-party components.\n# Users must comply with all terms and conditions of original licenses of these third-party\n# components and must ensure that the usage of the third party components adheres to\n# all relevant laws and regulations.\n\n# For avoidance of doubts, Hunyuan 3D means the large language models and\n# their software and algorithms, including trained model weights, parameters (including\n# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,\n# fine-tuning enabling code and other elements of the foregoing made publicly available\n# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.\n\nfrom typing import Union, Tuple, List\n\nimport numpy as np\nimport torch\nfrom skimage import measure\nimport cubvh\n\n\nclass Latent2MeshOutput:\n    def __init__(self, mesh_v=None, mesh_f=None):\n        self.mesh_v = mesh_v\n        self.mesh_f = mesh_f\n\n\ndef center_vertices(vertices):\n    \"\"\"Translate the vertices so that bounding box is centered at zero.\"\"\"\n    vert_min = vertices.min(dim=0)[0]\n    vert_max = vertices.max(dim=0)[0]\n    vert_center = 0.5 * (vert_min + vert_max)\n    return vertices - vert_center\n\n\nclass SurfaceExtractor:\n    def _compute_box_stat(self, bounds: Union[Tuple[float], List[float], float], octree_resolution: int):\n        \"\"\"\n        Compute grid size, bounding box minimum coordinates, and bounding box size based on input \n        bounds and resolution.\n\n        Args:\n            bounds (Union[Tuple[float], List[float], float]): Bounding box coordinates or a single \n            float representing half side length.\n                If float, bounds are assumed symmetric around zero in all axes.\n                Expected format if list/tuple: [xmin, ymin, zmin, xmax, ymax, zmax].\n            octree_resolution (int): Resolution of the octree grid.\n\n        Returns:\n            grid_size (List[int]): Grid size along each axis (x, y, z), each equal to octree_resolution + 1.\n            bbox_min (np.ndarray): Minimum coordinates of the bounding box (xmin, ymin, zmin).\n            bbox_size (np.ndarray): Size of the bounding box along each axis (xmax - xmin, etc.).\n        \"\"\"\n        if isinstance(bounds, float):\n            bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]\n\n        bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])\n        bbox_size = bbox_max - bbox_min\n        grid_size = [int(octree_resolution) + 1, int(octree_resolution) + 1, int(octree_resolution) + 1]\n        return grid_size, bbox_min, bbox_size\n\n    def run(self, *args, **kwargs):\n        \"\"\"\n        Abstract method to extract surface mesh from grid logits.\n\n        This method should be implemented by subclasses.\n\n        Raises:\n            NotImplementedError: Always, since this is an abstract method.\n        \"\"\"\n        return NotImplementedError\n\n    def 
__call__(self, grid_logits, **kwargs):\n        \"\"\"\n        Process a batch of grid logits to extract surface meshes.\n\n        Args:\n            grid_logits (torch.Tensor): Batch of grid logits with shape (batch_size, ...).\n            **kwargs: Additional keyword arguments passed to the `run` method.\n\n        Returns:\n            List[Optional[Latent2MeshOutput]]: List of mesh outputs for each grid in the batch.\n                If extraction fails for a grid, None is appended at that position.\n        \"\"\"\n        outputs = []\n        for i in range(grid_logits.shape[0]):\n            try:\n                vertices, faces = self.run(grid_logits[i], **kwargs)\n                vertices = vertices.astype(np.float32)\n                faces = np.ascontiguousarray(faces)\n                outputs.append(Latent2MeshOutput(mesh_v=vertices, mesh_f=faces))\n\n            except Exception:\n                import traceback\n                traceback.print_exc()\n                outputs.append(None)\n\n        return outputs\n\n\ndef get_sparse_valid_voxels(grid_logit: torch.Tensor):\n\n    if not isinstance(grid_logit, torch.Tensor):\n        raise TypeError(\"Input must be a PyTorch tensor.\")\n    if grid_logit.dim() != 3 or grid_logit.shape[0] != grid_logit.shape[1] or grid_logit.shape[0] != grid_logit.shape[2]:\n        raise ValueError(\"Input tensor must have shape (N, N, N)\")\n\n    N = grid_logit.shape[0]\n    device = grid_logit.device\n\n    # Chunk processing to save memory\n    chunk_size = 128\n\n    all_sparse_coords = []\n    all_sparse_logits = []\n\n    # Process in chunks along x-axis\n    for start_x in range(0, N - 1, chunk_size):\n        end_x = min(start_x + chunk_size, N - 1)\n\n        # Determine slice range including +1 for neighbor checks\n        # slice_end needs to be end_x + 1 to include the neighbors for the last voxel in chunk\n        slice_end = end_x + 1\n\n        chunk = grid_logit[start_x:slice_end, :, :]\n        nan_mask = torch.isnan(chunk)\n\n        # Compute mask for this chunk (valid voxels are 0 to end_x - start_x)\n        # Note: chunk shape is [D_chunk, N, N].\n        # We want to check validity for [0..D_chunk-1, :-1, :-1]\n\n        sub_nan_mask = nan_mask\n\n        # Validity check requires looking at i and i+1\n        # Invalid if ANY corner is NaN\n        invalid_voxel_mask = (\n            sub_nan_mask[:-1, :-1, :-1] |\n            sub_nan_mask[1:, :-1, :-1]  |\n            sub_nan_mask[:-1, 1:, :-1]  |\n            sub_nan_mask[:-1, :-1, 1:]  |\n            sub_nan_mask[:-1, 1:, 1:]   |\n            sub_nan_mask[1:, :-1, 1:]   |\n            sub_nan_mask[1:, 1:, :-1]   |\n            sub_nan_mask[1:, 1:, 1:]\n        )\n\n        valid_voxel_mask = ~invalid_voxel_mask\n\n        # Get local coordinates\n        local_coords = valid_voxel_mask.nonzero(as_tuple=False)\n\n        if local_coords.shape[0] > 0:\n            lx, ly, lz = local_coords[:, 0], local_coords[:, 1], local_coords[:, 2]\n\n            # Extract logits using local indices on the chunk\n            # v0 is at lx, v1 is at lx+1, etc.\n            sparse_vertex_logits = torch.stack([\n                chunk[lx,     ly,     lz],     # v0\n                chunk[lx + 1, ly,     lz],     # v1\n                chunk[lx + 1, ly + 1, lz],     # v2\n                chunk[lx,     ly + 1, lz],     # v3\n                chunk[lx,     ly,     lz + 1], # v4\n                chunk[lx + 1, ly,     lz + 1], # v5\n                chunk[lx + 1, ly + 1, lz + 1], # v6\n        
        chunk[lx,     ly + 1, lz + 1]  # v7\n            ], dim=1)\n\n            # Convert local coords to global coords\n            # x coordinate needs offset added\n            global_coords = local_coords.clone()\n            global_coords[:, 0] += start_x\n\n            all_sparse_coords.append(global_coords)\n            all_sparse_logits.append(sparse_vertex_logits)\n\n        # Free memory\n        del chunk, nan_mask, invalid_voxel_mask, valid_voxel_mask, local_coords\n\n    if not all_sparse_coords:\n        return torch.empty((0, 3), dtype=torch.long, device=device), torch.empty((0, 8), dtype=grid_logit.dtype, device=device)\n\n    sparse_coords = torch.cat(all_sparse_coords, dim=0)\n    sparse_vertex_logits = torch.cat(all_sparse_logits, dim=0)\n\n    return sparse_coords, sparse_vertex_logits\n\n\nclass MCSurfaceExtractor(SurfaceExtractor):\n    def run(self, grid_logit, *, mc_level, bounds, octree_resolution, **kwargs):\n        \"\"\"\n        Extract surface mesh using the Marching Cubes algorithm.\n\n        Args:\n            grid_logit (torch.Tensor): 3D grid logits tensor representing the scalar field.\n            mc_level (float): The level (iso-value) at which to extract the surface.\n            bounds (Union[Tuple[float], List[float], float]): Bounding box coordinates or half side length.\n            octree_resolution (int): Resolution of the octree grid.\n            **kwargs: Additional keyword arguments (ignored).\n\n        Returns:\n            Tuple[np.ndarray, np.ndarray]: Tuple containing:\n                - vertices (np.ndarray): Extracted mesh vertices, scaled and translated to bounding \n                  box coordinates.\n                - faces (np.ndarray): Extracted mesh faces (triangles).\n        \"\"\"\n\n        grid_logit = grid_logit.detach()\n\n        sparse_coords, sparse_logits = get_sparse_valid_voxels(grid_logit)\n        # Convert to float32 only for the sparse set\n        vertices, faces = cubvh.sparse_marching_cubes(sparse_coords, sparse_logits.float(), mc_level)\n\n        vertices, faces = vertices.cpu().numpy(), faces.cpu().numpy()\n        # vertices, faces, normals, _ = measure.marching_cubes(grid_logit,\n        #             mc_level, method=\"lewiner\", mask=(~np.isnan(grid_logit)))\n        grid_size, bbox_min, bbox_size = self._compute_box_stat(bounds, octree_resolution)\n        vertices = vertices / grid_size * bbox_size + bbox_min\n        return vertices, faces\n\n\nclass DMCSurfaceExtractor(SurfaceExtractor):\n    def run(self, grid_logit, *, octree_resolution, **kwargs):\n        \"\"\"\n        Extract surface mesh using Differentiable Marching Cubes (DMC) algorithm.\n\n        Args:\n            grid_logit (torch.Tensor): 3D grid logits tensor representing the scalar field.\n            octree_resolution (int): Resolution of the octree grid.\n            **kwargs: Additional keyword arguments (ignored).\n\n        Returns:\n            Tuple[np.ndarray, np.ndarray]: Tuple containing:\n                - vertices (np.ndarray): Extracted mesh vertices, centered and converted to numpy.\n                - faces (np.ndarray): Extracted mesh faces (triangles), with reversed vertex order.\n        \n        Raises:\n            ImportError: If the 'diso' package is not installed.\n        \"\"\"\n        device = grid_logit.device\n        if not hasattr(self, 'dmc'):\n            try:\n                from diso import DiffDMC\n                self.dmc = DiffDMC(dtype=torch.float32).to(device)\n            except:\n       
         raise ImportError(\"Please install diso via `pip install diso`, or set mc_algo to 'mc'\")\n        sdf = -grid_logit / octree_resolution\n        sdf = sdf.to(torch.float32).contiguous()\n        verts, faces = self.dmc(sdf, deform=None, return_quads=False, normalize=True)\n        verts = center_vertices(verts)\n        vertices = verts.detach().cpu().numpy()\n        faces = faces.detach().cpu().numpy()[:, ::-1]\n        return vertices, faces\n\n\nSurfaceExtractors = {\n    'mc': MCSurfaceExtractor,\n    'dmc': DMCSurfaceExtractor,\n}\n"
  },
  {
    "path": "ultrashape/models/autoencoders/vae_trainer.py",
    "content": "import os\nfrom contextlib import contextmanager\nfrom typing import List, Tuple, Optional, Union\n\nimport torch\nimport torch.nn as nn\nfrom torch.optim import lr_scheduler\nimport pytorch_lightning as pl\nfrom pytorch_lightning.utilities import rank_zero_info\nfrom pytorch_lightning.utilities import rank_zero_only\nimport trimesh\n\nfrom ...utils.misc import instantiate_from_config, instantiate_non_trainable_model, instantiate_vae_model\n\n\ndef export_to_trimesh(mesh_output):\n    if isinstance(mesh_output, list):\n        outputs = []\n        for mesh in mesh_output:\n            if mesh is None:\n                outputs.append(None)\n            else:\n                mesh.mesh_f = mesh.mesh_f[:, ::-1]\n                mesh_output = trimesh.Trimesh(mesh.mesh_v, mesh.mesh_f)\n                outputs.append(mesh_output)\n        return outputs\n    else:\n        mesh_output.mesh_f = mesh_output.mesh_f[:, ::-1]\n        mesh_output = trimesh.Trimesh(mesh_output.mesh_v, mesh_output.mesh_f)\n        return mesh_output\n\nclass VAETrainer(pl.LightningModule):\n    def __init__(\n        self,\n        *,\n        vae_config,\n        optimizer_cfg,\n        loss_cfg,\n        save_dir,\n        mc_res,\n        ckpt_path: Optional[str] = None,\n        ignore_keys: Union[Tuple[str], List[str]] = (),\n        torch_compile: bool = False,\n    ):\n        super().__init__()\n\n        # ========= init optimizer config ========= #\n        self.optimizer_cfg = optimizer_cfg\n        self.loss_cfg = loss_cfg\n        self.ckpt_path = ckpt_path\n        self.vae_model = instantiate_vae_model(vae_config, requires_grad=True)\n        if ckpt_path is not None:\n            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)\n\n        self.mc_res = mc_res\n        self.save_root = save_dir\n        if not os.path.exists(save_dir):\n            os.makedirs(save_dir)\n\n        # ========= torch compile to accelerate ========= #\n        self.torch_compile = torch_compile\n        if self.torch_compile:\n            torch.nn.Module.compile(self.vae_model)\n            print(f'*' * 100)\n            print(f'Compile model for acceleration')\n            print(f'*' * 100)\n\n    def init_from_ckpt(self, path, ignore_keys=()):\n        ckpt = torch.load(path, map_location=\"cpu\")\n        if 'state_dict' not in ckpt:\n            # deepspeed ckpt\n            state_dict = {}\n            for k in ckpt.keys():\n                new_k = k.replace('_forward_module.', '')\n                state_dict[new_k] = ckpt[k]\n        else:\n            state_dict = ckpt[\"state_dict\"]\n\n        keys = list(state_dict.keys())\n        for k in keys:\n            for ik in ignore_keys:\n                if ik in k:\n                    print(\"Deleting key {} from state_dict.\".format(k))\n                    del state_dict[k]\n        \n        # # ==================== Weight Surgery Start ====================\n        # old_key_base = \"vae_model.encoder.input_proj\"\n        # old_weight_key = f\"{old_key_base}.weight\"\n        # old_bias_key = f\"{old_key_base}.bias\"\n\n        # if old_weight_key in state_dict:\n        #     print(f\"[*] Detected legacy '{old_key_base}' in checkpoint. Performing weight surgery...\")\n            \n        #     src_weight = state_dict[old_weight_key]\n        #     src_bias = state_dict[old_bias_key]\n            \n        #     encoder = self.vae_model.encoder\n        #     fourier_dim = encoder.fourier_embedder.out_dim\n\n        #     # --- A. 
input_proj_kv ---\n        #     # shape: [width, fourier_dim + point_feats]\n        #     encoder.input_proj_kv.weight.data.copy_(src_weight)\n        #     encoder.input_proj_kv.bias.data.copy_(src_bias)\n        #     print(f\"    -> Loaded input_proj_kv from {old_key_base}\")\n\n        #     # --- B. input_proj_q ---\n        #     # shape: [width, fourier_dim]\n        #     sliced_weight = src_weight[:, :fourier_dim]\n        #     encoder.input_proj_q.weight.data.copy_(sliced_weight)\n        #     encoder.input_proj_q.bias.data.copy_(src_bias)\n        #     print(f\"    -> Loaded input_proj_q (sliced) from {old_key_base}\")\n\n        #     del state_dict[old_weight_key]\n        #     if old_bias_key in state_dict:\n        #         del state_dict[old_bias_key]\n        # # ==================== Weight Surgery End ====================\n\n        missing, unexpected = self.load_state_dict(state_dict, strict=False)\n        print(f\"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys\")\n        if len(missing) > 0:\n            print(f\"Missing Keys: {missing}\")\n            print(f\"Unexpected Keys: {unexpected}\")\n\n\n    def configure_optimizers(self) -> Tuple[List, List]:\n        lr = self.learning_rate\n\n        params_list = []\n        trainable_parameters = list(self.vae_model.parameters())\n        params_list.append({'params': trainable_parameters, 'lr': lr})\n\n        optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=params_list, lr=lr)\n        if hasattr(self.optimizer_cfg, 'scheduler'):\n            scheduler_func = instantiate_from_config(\n                self.optimizer_cfg.scheduler,\n                max_decay_steps=self.trainer.max_steps,\n                lr_max=lr\n            )\n            scheduler = {\n                \"scheduler\": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),\n                \"interval\": \"step\",\n                \"frequency\": 1\n            }\n            schedulers = [scheduler]\n        else:\n            schedulers = []\n        optimizers = [optimizer]\n\n        return optimizers, schedulers\n\n    def on_train_epoch_start(self) -> None:\n        pl.seed_everything(self.trainer.global_rank)\n\n    def forward(self, batch):\n        sup_pc_s_list = [batch[\"sup_near_uniform\"], batch[\"sup_near_sharp\"], batch[\"sup_space\"]]\n        rand_points = [sup_pc_s[:,:,:3] for sup_pc_s in sup_pc_s_list]\n        rand_points_val = [sup_pc_s[:,:,3:] for sup_pc_s in sup_pc_s_list]\n\n        rand_points = torch.cat(rand_points, dim=1)\n        target = torch.cat(rand_points_val, dim=1)[...,0]\n        target = -target\n\n        latents, posterior = self.vae_model.encode(\n            batch['surface'], sample_posterior=True, need_kl=True)\n        latents = self.vae_model.decode(latents)\n        logits = self.vae_model.query(latents, rand_points)\n        \n        loss_kl = posterior.kl()\n        loss_kl = torch.sum(loss_kl) / loss_kl.shape[0]\n\n        criteria = torch.nn.MSELoss()\n        criteria2 = torch.nn.L1Loss()\n        loss_logits = criteria(logits, target).mean() + criteria2(logits, target).mean()\n        loss = self.loss_cfg.lambda_logits * loss_logits + self.loss_cfg.lambda_kl * loss_kl\n\n        loss_dict = {\n            \"loss\": loss,\n            \"loss_logits\": loss_logits,\n            \"loss_kl\": loss_kl\n        }\n        return loss_dict, latents\n\n    def training_step(self, batch, batch_idx, optimizer_idx=0):\n        
loss, latents = self.forward(batch)\n        split = 'train'\n        loss_dict = {\n            f\"{split}/total_loss\": loss[\"loss\"].detach(),\n            f\"{split}/loss_logits\": loss[\"loss_logits\"].detach(),\n            f\"{split}/loss_kl\": loss[\"loss_kl\"].detach(),\n            f\"{split}/lr_abs\": self.optimizers().param_groups[0]['lr'],\n        }\n        self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)\n\n        return loss\n\n    def validation_step(self, batch, batch_idx, optimizer_idx=0):\n        loss, latents = self.forward(batch)\n        split = 'val'\n        loss_dict = {\n            f\"{split}/total_loss\": loss[\"loss\"].detach(),\n            f\"{split}/loss_logits\": loss[\"loss_logits\"].detach(),\n            f\"{split}/loss_kl\": loss[\"loss_kl\"].detach(),\n            f\"{split}/lr_abs\": self.optimizers().param_groups[0]['lr'],\n        }\n        self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)\n        if self.trainer.global_rank < 2:\n            with torch.no_grad():\n                save_dir = f\"{self.save_root}/gs{self.global_step:010d}_rank{self.trainer.global_rank}\"\n                if not os.path.exists(save_dir):\n                    os.makedirs(save_dir)\n                uids = batch.get('uid')\n                for i, latent in enumerate(latents[:5]):\n                    mesh, grid_logits = self.vae_model.latents2mesh(\n                            latent[None],\n                            output_type='trimesh',\n                            bounds=1.01,\n                            mc_level=0.0,\n                            num_chunks=20000,\n                            octree_resolution=self.mc_res,\n                            mc_algo='mc',\n                            enable_pbar=True\n                        )\n        \n                    mesh = export_to_trimesh(mesh[0])\n                    \n                    save_path = f\"{save_dir}/recon_{os.path.splitext(os.path.basename(uids[i]))[0]}_mc{self.mc_res}.obj\"               \n                    mesh.export(save_path)\n\n        return loss\n"
  },
  {
    "path": "ultrashape/models/autoencoders/volume_decoders.py",
    "content": "# ==============================================================================\n# Original work Copyright (c) 2025 Tencent.\n# Modified work Copyright (c) 2025 UltraShape Team.\n# \n# Modified by UltraShape on 2025.12.25\n# ==============================================================================\n\n# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT\n# except for the third-party components listed below.\n# Hunyuan 3D does not impose any additional limitations beyond what is outlined\n# in the repsective licenses of these third-party components.\n# Users must comply with all terms and conditions of original licenses of these third-party\n# components and must ensure that the usage of the third party components adheres to\n# all relevant laws and regulations.\n\n# For avoidance of doubts, Hunyuan 3D means the large language models and\n# their software and algorithms, including trained model weights, parameters (including\n# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,\n# fine-tuning enabling code and other elements of the foregoing made publicly available\n# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.\n\nfrom typing import Union, Tuple, List, Callable\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import repeat\nfrom tqdm import tqdm\n\nfrom .attention_blocks import CrossAttentionDecoder\nfrom .attention_processors import FlashVDMCrossAttentionProcessor, FlashVDMTopMCrossAttentionProcessor\nfrom ...utils import logger\n\n\ndef extract_near_surface_volume_fn(input_tensor: torch.Tensor, alpha: float):\n    val = input_tensor + alpha\n    valid_mask = val > -9000\n\n    mask = torch.ones_like(val, dtype=torch.int32)\n    sign = torch.sign(val.to(torch.float32))\n\n    # Helper to compute neighbor for a single direction\n    def check_neighbor_sign(shift, axis):\n        if shift == 0:\n            return\n\n        pad_dims = [0, 0, 0, 0, 0, 0]\n        if axis == 0:\n            pad_idx = 0 if shift > 0 else 1\n            pad_dims[pad_idx] = abs(shift)\n        elif axis == 1:\n            pad_idx = 2 if shift > 0 else 3\n            pad_dims[pad_idx] = abs(shift)\n        elif axis == 2:\n            pad_idx = 4 if shift > 0 else 5\n            pad_dims[pad_idx] = abs(shift)\n\n        padded = F.pad(val.unsqueeze(0).unsqueeze(0), pad_dims[::-1], mode='replicate')\n\n        slice_dims = [slice(None)] * 3\n        if axis == 0:\n            if shift > 0: slice_dims[0] = slice(shift, None)\n            else: slice_dims[0] = slice(None, shift)\n        elif axis == 1:\n            if shift > 0: slice_dims[1] = slice(shift, None)\n            else: slice_dims[1] = slice(None, shift)\n        elif axis == 2:\n            if shift > 0: slice_dims[2] = slice(shift, None)\n            else: slice_dims[2] = slice(None, shift)\n\n        padded = padded.squeeze(0).squeeze(0)\n        neighbor = padded[slice_dims]\n        neighbor = torch.where(neighbor > -9000, neighbor, val)\n\n        # Check sign consistency\n        neighbor_sign = torch.sign(neighbor.to(torch.float32))\n        return (neighbor_sign == sign)\n\n    # Iteratively check neighbors and update mask\n    # directions: (shift, axis)\n    directions = [(1, 0), (-1, 0), (1, 1), (-1, 1), (1, 2), (-1, 2)]\n\n    for shift, axis in directions:\n        is_same = check_neighbor_sign(shift, axis)\n        mask = mask & is_same.to(torch.int32)\n\n  
  # Invert mask: we want 1 where ANY neighbor has different sign\n    mask = (~(mask.bool())).to(torch.int32)\n    return mask * valid_mask.to(torch.int32)\n\n\ndef generate_dense_grid_points(\n    bbox_min: np.ndarray,\n    bbox_max: np.ndarray,\n    octree_resolution: int,\n    indexing: str = \"ij\",\n):\n    length = bbox_max - bbox_min\n    num_cells = octree_resolution\n\n    x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)\n    y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)\n    z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)\n    [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)\n    xyz = np.stack((xs, ys, zs), axis=-1)\n    grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]\n\n    return xyz, grid_size, length\n\n\nclass VanillaVolumeDecoder:\n    @torch.no_grad()\n    def __call__(\n        self,\n        latents: torch.FloatTensor,\n        geo_decoder: Callable,\n        bounds: Union[Tuple[float], List[float], float] = 1.01,\n        num_chunks: int = 10000,\n        octree_resolution: int = None,\n        enable_pbar: bool = True,\n        **kwargs,\n    ):\n        device = latents.device\n        dtype = latents.dtype\n        batch_size = latents.shape[0]\n\n        # 1. generate query points\n        if isinstance(bounds, float):\n            bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]\n\n        bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])\n        xyz_samples, grid_size, length = generate_dense_grid_points(\n            bbox_min=bbox_min,\n            bbox_max=bbox_max,\n            octree_resolution=octree_resolution,\n            indexing=\"ij\"\n        )\n        xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)\n\n        # 2. latents to 3d volume\n        batch_logits = []\n        for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc=f\"Volume Decoding\",\n                          disable=not enable_pbar):\n            chunk_queries = xyz_samples[start: start + num_chunks, :]\n            chunk_queries = repeat(chunk_queries, \"p c -> b p c\", b=batch_size)\n            logits = geo_decoder(queries=chunk_queries, latents=latents)\n            batch_logits.append(logits)\n\n        grid_logits = torch.cat(batch_logits, dim=1)\n        grid_logits = grid_logits.view((batch_size, *grid_size)).float()\n\n        return grid_logits\n\n\nclass HierarchicalVolumeDecoding:\n    @torch.no_grad()\n    def __call__(\n        self,\n        latents: torch.FloatTensor,\n        geo_decoder: Callable,\n        bounds: Union[Tuple[float], List[float], float] = 1.01,\n        num_chunks: int = 10000,\n        mc_level: float = 0.0,\n        octree_resolution: int = None,\n        min_resolution: int = 63,\n        enable_pbar: bool = True,\n        **kwargs,\n    ):\n        device = latents.device\n        dtype = latents.dtype\n\n        resolutions = []\n        if octree_resolution < min_resolution:\n            resolutions.append(octree_resolution)\n        while octree_resolution >= min_resolution:\n            resolutions.append(octree_resolution)\n            octree_resolution = octree_resolution // 2\n        resolutions.reverse()\n\n        # 1. 
generate query points\n        if isinstance(bounds, float):\n            bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]\n        bbox_min = np.array(bounds[0:3])\n        bbox_max = np.array(bounds[3:6])\n        bbox_size = bbox_max - bbox_min\n\n        xyz_samples, grid_size, length = generate_dense_grid_points(\n            bbox_min=bbox_min,\n            bbox_max=bbox_max,\n            octree_resolution=resolutions[0],\n            indexing=\"ij\"\n        )\n\n        dilate = nn.Conv3d(1, 1, 3, padding=1, bias=False, device=device, dtype=dtype)\n        dilate.weight = torch.nn.Parameter(torch.ones(dilate.weight.shape, dtype=dtype, device=device))\n\n        grid_size = np.array(grid_size)\n        xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)\n\n        # 2. latents to 3d volume\n        batch_logits = []\n        batch_size = latents.shape[0]\n        for start in tqdm(range(0, xyz_samples.shape[0], num_chunks),\n                          desc=f\"Hierarchical Volume Decoding [r{resolutions[0] + 1}]\"):\n            queries = xyz_samples[start: start + num_chunks, :]\n            batch_queries = repeat(queries, \"p c -> b p c\", b=batch_size)\n            logits = geo_decoder(queries=batch_queries, latents=latents)\n            batch_logits.append(logits)\n\n        grid_logits = torch.cat(batch_logits, dim=1).view((batch_size, grid_size[0], grid_size[1], grid_size[2]))\n\n        for octree_depth_now in resolutions[1:]:\n            grid_size = np.array([octree_depth_now + 1] * 3)\n            resolution = bbox_size / octree_depth_now\n            next_index = torch.zeros(tuple(grid_size), dtype=dtype, device=device)\n            next_logits = torch.full(next_index.shape, -10000., dtype=dtype, device=device)\n            curr_points = extract_near_surface_volume_fn(grid_logits.squeeze(0), mc_level)\n            curr_points += grid_logits.squeeze(0).abs() < 0.95\n\n            if octree_depth_now == resolutions[-1]:\n                expand_num = 0\n            else:\n                expand_num = 1\n            for i in range(expand_num):\n                curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0)\n            (cidx_x, cidx_y, cidx_z) = torch.where(curr_points > 0)\n            next_index[cidx_x * 2, cidx_y * 2, cidx_z * 2] = 1\n            for i in range(2 - expand_num):\n                next_index = dilate(next_index.unsqueeze(0)).squeeze(0)\n            nidx = torch.where(next_index > 0)\n\n            # Store shape before deleting\n            next_index_shape = next_index.shape\n            del next_index\n            torch.cuda.empty_cache()\n\n            next_points = torch.stack(nidx, dim=1)\n            next_points = (next_points * torch.tensor(resolution, dtype=next_points.dtype, device=device) +\n                           torch.tensor(bbox_min, dtype=next_points.dtype, device=device))\n            batch_logits = []\n            for start in tqdm(range(0, next_points.shape[0], num_chunks),\n                              desc=f\"Hierarchical Volume Decoding [r{octree_depth_now + 1}]\"):\n                queries = next_points[start: start + num_chunks, :]\n                batch_queries = repeat(queries, \"p c -> b p c\", b=batch_size)\n                logits = geo_decoder(queries=batch_queries.to(latents.dtype), latents=latents)\n                batch_logits.append(logits)\n\n            # Delayed allocation of next_logits\n            next_logits = torch.full(next_index_shape, 
-10000., dtype=dtype, device=device)\n            grid_logits = torch.cat(batch_logits, dim=1)\n            next_logits[nidx] = grid_logits[0, ..., 0]\n            grid_logits = next_logits.unsqueeze(0)\n        grid_logits[grid_logits == -10000.] = float('nan')\n\n        return grid_logits\n\n\nclass FlashVDMVolumeDecoding:\n    def __init__(self, topk_mode='mean'):\n        if topk_mode not in ['mean', 'merge']:\n            raise ValueError(f'Unsupported topk_mode {topk_mode}, available: {[\"mean\", \"merge\"]}')\n\n        if topk_mode == 'mean':\n            self.processor = FlashVDMCrossAttentionProcessor()\n        else:\n            self.processor = FlashVDMTopMCrossAttentionProcessor()\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        latents: torch.FloatTensor,\n        geo_decoder: CrossAttentionDecoder,\n        bounds: Union[Tuple[float], List[float], float] = 1.01,\n        num_chunks: int = 10000,\n        mc_level: float = 0.0,\n        octree_resolution: int = None,\n        min_resolution: int = 63,\n        mini_grid_num: int = 4,\n        enable_pbar: bool = True,\n        **kwargs,\n    ):\n        processor = self.processor\n        geo_decoder.set_cross_attention_processor(processor)\n\n        device = latents.device\n        dtype = latents.dtype\n\n        resolutions = []\n        if octree_resolution < min_resolution:\n            resolutions.append(octree_resolution)\n        while octree_resolution >= min_resolution:\n            resolutions.append(octree_resolution)\n            octree_resolution = octree_resolution // 2\n        resolutions.reverse()\n        resolutions[0] = round(resolutions[0] / mini_grid_num) * mini_grid_num - 1\n        for i, resolution in enumerate(resolutions[1:]):\n            resolutions[i + 1] = resolutions[0] * 2 ** (i + 1)\n\n        logger.info(f\"FlashVDMVolumeDecoding Resolution: {resolutions}\")\n\n        # 1. generate query points\n        if isinstance(bounds, float):\n            bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]\n        bbox_min = np.array(bounds[0:3])\n        bbox_max = np.array(bounds[3:6])\n        bbox_size = bbox_max - bbox_min\n\n        xyz_samples, grid_size, length = generate_dense_grid_points(\n            bbox_min=bbox_min,\n            bbox_max=bbox_max,\n            octree_resolution=resolutions[0],\n            indexing=\"ij\"\n        )\n\n        dilate = nn.Conv3d(1, 1, 3, padding=1, bias=False, device=device, dtype=dtype)\n        dilate.weight = torch.nn.Parameter(torch.ones(dilate.weight.shape, dtype=dtype, device=device))\n\n        grid_size = np.array(grid_size)\n\n        # 2. 
latents to 3d volume\n        xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype)\n        batch_size = latents.shape[0]\n        mini_grid_size = xyz_samples.shape[0] // mini_grid_num\n        xyz_samples = xyz_samples.view(\n            mini_grid_num, mini_grid_size,\n            mini_grid_num, mini_grid_size,\n            mini_grid_num, mini_grid_size, 3\n        ).permute(\n            0, 2, 4, 1, 3, 5, 6\n        ).reshape(\n            -1, mini_grid_size * mini_grid_size * mini_grid_size, 3\n        )\n        batch_logits = []\n        num_batchs = max(num_chunks // xyz_samples.shape[1], 1)\n        for start in tqdm(range(0, xyz_samples.shape[0], num_batchs),\n                          desc=f\"FlashVDM Volume Decoding\", disable=not enable_pbar):\n            queries = xyz_samples[start: start + num_batchs, :]\n            batch = queries.shape[0]\n            batch_latents = repeat(latents.squeeze(0), \"p c -> b p c\", b=batch)\n            processor.topk = True\n\n            # Chunk queries along dim 1 if too large\n            if queries.shape[1] > num_chunks:\n                batch_logits_sub = []\n                for sub_start in range(0, queries.shape[1], num_chunks):\n                    sub_queries = queries[:, sub_start: sub_start + num_chunks, :]\n                    logits = geo_decoder(queries=sub_queries, latents=batch_latents)\n                    batch_logits_sub.append(logits)\n                logits = torch.cat(batch_logits_sub, dim=1)\n            else:\n                logits = geo_decoder(queries=queries, latents=batch_latents)\n\n            batch_logits.append(logits)\n        grid_logits = torch.cat(batch_logits, dim=0).reshape(\n            mini_grid_num, mini_grid_num, mini_grid_num,\n            mini_grid_size, mini_grid_size,\n            mini_grid_size\n        ).permute(0, 3, 1, 4, 2, 5).contiguous().view(\n            (batch_size, grid_size[0], grid_size[1], grid_size[2])\n        )\n\n        for octree_depth_now in resolutions[1:]:\n            grid_size = np.array([octree_depth_now + 1] * 3)\n            resolution = bbox_size / octree_depth_now\n            next_index = torch.zeros(tuple(grid_size), dtype=dtype, device=device)\n            curr_points = extract_near_surface_volume_fn(grid_logits.squeeze(0), mc_level)\n            curr_points += grid_logits.squeeze(0).abs() < 0.95\n\n            if octree_depth_now == resolutions[-1]:\n                expand_num = 0\n            else:\n                expand_num = 1\n            for i in range(expand_num):\n                curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0)\n                curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0)\n            (cidx_x, cidx_y, cidx_z) = torch.where(curr_points > 0)\n\n            next_index[cidx_x * 2, cidx_y * 2, cidx_z * 2] = 1\n            for i in range(2 - expand_num):\n                next_index = dilate(next_index.unsqueeze(0)).squeeze(0)\n            nidx = torch.where(next_index > 0)\n\n            # Store shape before deleting\n            next_index_shape = next_index.shape\n            del next_index\n            torch.cuda.empty_cache()\n\n            next_points = torch.stack(nidx, dim=1)\n            next_points = (next_points * torch.tensor(resolution, dtype=torch.float32, device=device) +\n                           torch.tensor(bbox_min, dtype=torch.float32, device=device))\n\n            query_grid_num = 6\n            min_val = next_points.min(axis=0).values\n            max_val = 
next_points.max(axis=0).values\n            vol_queries_index = (next_points - min_val) / (max_val - min_val) * (query_grid_num - 0.001)\n            index = torch.floor(vol_queries_index).long()\n            index = index[..., 0] * (query_grid_num ** 2) + index[..., 1] * query_grid_num + index[..., 2]\n            index = index.sort()\n            next_points = next_points[index.indices].unsqueeze(0).contiguous()\n            unique_values = torch.unique(index.values, return_counts=True)\n            grid_logits = torch.zeros((next_points.shape[1]), dtype=latents.dtype, device=latents.device)\n            input_grid = [[], []]\n            logits_grid_list = []\n            start_num = 0\n            sum_num = 0\n            for grid_index, count in zip(unique_values[0].cpu().tolist(), unique_values[1].cpu().tolist()):\n                remaining_count = count\n                while remaining_count > 0:\n                    space_left = num_chunks - sum_num\n                    # If buffer is full, flush it\n                    if space_left <= 0:\n                        processor.topk = input_grid\n                        logits_grid = geo_decoder(queries=next_points[:, start_num:start_num + sum_num], latents=latents)\n                        start_num = start_num + sum_num\n                        logits_grid_list.append(logits_grid)\n                        input_grid = [[], []]\n                        sum_num = 0\n                        space_left = num_chunks\n\n                    take = min(remaining_count, space_left)\n                    input_grid[0].append(grid_index)\n                    input_grid[1].append(take)\n                    sum_num += take\n                    remaining_count -= take\n            if sum_num > 0:\n                processor.topk = input_grid\n                logits_grid = geo_decoder(queries=next_points[:, start_num:start_num + sum_num], latents=latents)\n                logits_grid_list.append(logits_grid)\n            logits_grid = torch.cat(logits_grid_list, dim=1)\n            grid_logits[index.indices] = logits_grid.squeeze(0).squeeze(-1)\n\n            # Delayed allocation of next_logits\n            next_logits = torch.full(next_index_shape, -10000., dtype=dtype, device=device)\n            next_logits[nidx] = grid_logits\n            grid_logits = next_logits.unsqueeze(0)\n\n        grid_logits[grid_logits == -10000.] = float('nan')\n\n        return grid_logits\n"
  },
  {
    "path": "ultrashape/models/conditioner_mask.py",
    "content": "# ==============================================================================\n# Original work Copyright (c) 2025 Tencent.\n# Modified work Copyright (c) 2025 UltraShape Team.\n# \n# Modified by UltraShape on 2025.12.25\n# ==============================================================================\n\n# Open Source Model Licensed under the Apache License Version 2.0\n# and Other Licenses of the Third-Party Components therein:\n# The below Model in this distribution may have been modified by THL A29 Limited\n# (\"Tencent Modifications\"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.\n\n# Copyright (C) 2024 THL A29 Limited, a Tencent company.  All rights reserved.\n# The below software and/or models in this distribution may have been\n# modified by THL A29 Limited (\"Tencent Modifications\").\n# All Tencent Modifications are Copyright (C) THL A29 Limited.\n\n# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT\n# except for the third-party components listed below.\n# Hunyuan 3D does not impose any additional limitations beyond what is outlined\n# in the repsective licenses of these third-party components.\n# Users must comply with all terms and conditions of original licenses of these third-party\n# components and must ensure that the usage of the third party components adheres to\n# all relevant laws and regulations.\n\n# For avoidance of doubts, Hunyuan 3D means the large language models and\n# their software and algorithms, including trained model weights, parameters (including\n# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,\n# fine-tuning enabling code and other elements of the foregoing made publicly available\n# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.\n\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom torchvision import transforms\nfrom transformers import (\n    CLIPVisionModelWithProjection,\n    CLIPVisionConfig,\n    Dinov2Model,\n    Dinov2Config,\n)\nfrom transformers import AutoImageProcessor, AutoModel\n\ndef get_1d_sincos_pos_embed_from_grid(embed_dim, pos):\n    \"\"\"\n    embed_dim: output dimension for each position\n    pos: a list of positions to be encoded: size (M,)\n    out: (M, D)\n    \"\"\"\n    assert embed_dim % 2 == 0\n    omega = np.arange(embed_dim // 2, dtype=np.float64)\n    omega /= embed_dim / 2.\n    omega = 1. 
/ 10000 ** omega  # (D/2,)\n\n    pos = pos.reshape(-1)  # (M,)\n    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product\n\n    emb_sin = np.sin(out)  # (M, D/2)\n    emb_cos = np.cos(out)  # (M, D/2)\n\n    return np.concatenate([emb_sin, emb_cos], axis=1)\n\n\nclass ImageEncoder(nn.Module):\n    def __init__(\n        self,\n        version=None,\n        config=None,\n        use_cls_token=True,\n        image_size=224,\n        **kwargs,\n    ):\n        super().__init__()\n\n        if config is None:\n            self.model = AutoModel.from_pretrained(version)\n        else:\n            self.model = self.MODEL_CLASS(self.MODEL_CONFIG_CLASS.from_dict(config))\n            \n        self.model.eval()\n        self.model.requires_grad_(False)\n        self.use_cls_token = use_cls_token\n        self.size = image_size // 14\n        self.num_patches = (image_size // 14) ** 2\n        if self.use_cls_token:\n            self.num_patches += 1\n\n        self.transform = transforms.Compose(\n            [\n                transforms.Resize(image_size, transforms.InterpolationMode.BILINEAR, antialias=True),\n                transforms.CenterCrop(image_size),\n                transforms.Normalize(\n                    mean=self.mean,\n                    std=self.std,\n                ),\n            ]\n        )\n\n        self.mask_transform = transforms.Compose(\n            [\n                transforms.Resize(image_size, interpolation=transforms.InterpolationMode.NEAREST),\n                transforms.CenterCrop(image_size),\n            ]\n        )\n\n    def forward(self, image, mask=None, value_range=(-1, 1), **kwargs):\n        if value_range is not None:\n            low, high = value_range\n            image = (image - low) / (high - low)\n\n        image = image.to(self.model.device, dtype=self.model.dtype)\n        inputs = self.transform(image)\n        outputs = self.model(inputs)\n\n        last_hidden_state = outputs.last_hidden_state\n        if not self.use_cls_token:\n            last_hidden_state = last_hidden_state[:, 1:, :]\n\n        if mask is not None:\n            pool = nn.MaxPool2d(kernel_size=(14, 14), stride=(14, 14))\n            \n            mask = self.mask_transform(mask)\n            mask = mask.to(image.device, dtype=image.dtype)\n            downsampled_mask = pool(mask)\n            flattened_mask = downsampled_mask.view(downsampled_mask.shape[0], -1)\n            flattened_mask = flattened_mask.unsqueeze(-1)\n\n            if self.use_cls_token:\n                flattened_mask = torch.cat(\n                    [torch.ones(flattened_mask.shape[0], 1, 1, device=flattened_mask.device, dtype=flattened_mask.dtype),\n                     flattened_mask], dim=1)\n\n            valid_mask = (flattened_mask != -1).float()\n            masked_hidden_state = last_hidden_state * valid_mask\n            valid_mask_bool = valid_mask.squeeze(-1) > 0\n            \n            valid_counts = valid_mask_bool.sum(dim=1)\n            max_valid_tokens = valid_counts.max().item()\n            \n            batch_indices = torch.arange(valid_mask_bool.shape[0], device=valid_mask_bool.device)\n            batch_indices = batch_indices.unsqueeze(1).expand(-1, valid_mask_bool.shape[1])\n            \n            flat_batch_indices = batch_indices[valid_mask_bool]\n            flat_token_indices = torch.arange(valid_mask_bool.shape[1], device=valid_mask_bool.device)\n            flat_token_indices = flat_token_indices.unsqueeze(0).expand(valid_mask_bool.shape[0], 
-1)\n            flat_token_indices = flat_token_indices[valid_mask_bool]\n            \n            valid_tokens = masked_hidden_state[flat_batch_indices, flat_token_indices]\n            # Create output tensor with special padding value (-1) instead of zeros\n            final_output = torch.full(\n                (valid_mask_bool.shape[0], max_valid_tokens, last_hidden_state.shape[-1]),\n                -1.0,  # Use -1 as padding value to clearly distinguish from valid tokens\n                device=last_hidden_state.device, dtype=last_hidden_state.dtype\n            )\n            \n            cum_counts = torch.cumsum(valid_counts, dim=0) - valid_counts\n            for i in range(valid_mask_bool.shape[0]):\n                if valid_counts[i] > 0:\n                    start_idx = cum_counts[i]\n                    end_idx = start_idx + valid_counts[i]\n                    final_output[i, :valid_counts[i]] = valid_tokens[start_idx:end_idx]\n            \n            return final_output\n        \n        return last_hidden_state\n\n    def unconditional_embedding(self, batch_size, **kwargs):\n        device = next(self.model.parameters()).device\n        dtype = next(self.model.parameters()).dtype\n\n        num_tokens = kwargs.get('num_tokens', self.num_patches)\n\n        zero = torch.zeros(\n            batch_size,\n            num_tokens,\n            self.model.config.hidden_size,\n            device=device,\n            dtype=dtype,\n        )\n\n        return zero\n\n\nclass CLIPImageEncoder(ImageEncoder):\n    MODEL_CLASS = CLIPVisionModelWithProjection\n    MODEL_CONFIG_CLASS = CLIPVisionConfig\n    mean = [0.48145466, 0.4578275, 0.40821073]\n    std = [0.26862954, 0.26130258, 0.27577711]\n\n\nclass DinoImageEncoder(ImageEncoder):\n    MODEL_CLASS = Dinov2Model\n    MODEL_CONFIG_CLASS = Dinov2Config\n    mean = [0.485, 0.456, 0.406]\n    std = [0.229, 0.224, 0.225]\n\nclass DinoImageEncoderMV(DinoImageEncoder):\n    def __init__(\n        self,\n        version=None,\n        config=None,\n        use_cls_token=True,\n        image_size=224,\n        view_num=4,\n        **kwargs,\n    ):\n        super().__init__(version, config, use_cls_token, image_size, **kwargs)\n        self.view_num = view_num\n        self.num_patches = self.num_patches\n        pos = np.arange(self.view_num, dtype=np.float32)\n        view_embedding = torch.from_numpy(\n            get_1d_sincos_pos_embed_from_grid(self.model.config.hidden_size, pos)).float()\n\n        view_embedding = view_embedding.unsqueeze(1).repeat(1, self.num_patches, 1)\n        self.view_embed = view_embedding.unsqueeze(0)\n\n    def forward(self, image, mask=None, value_range=(-1, 1), view_idxs=None):\n        if value_range is not None:\n            low, high = value_range\n            image = (image - low) / (high - low)\n\n        image = image.to(self.model.device, dtype=self.model.dtype)\n\n        bs, num_views, c, h, w = image.shape\n        image = image.view(bs * num_views, c, h, w)\n\n        inputs = self.transform(image)\n        outputs = self.model(inputs)\n\n        last_hidden_state = outputs.last_hidden_state\n        last_hidden_state = last_hidden_state.view(\n            bs, num_views, last_hidden_state.shape[-2],\n            last_hidden_state.shape[-1]\n        )\n\n        view_embedding = self.view_embed.to(last_hidden_state.dtype).to(last_hidden_state.device)\n        if view_idxs is not None:\n            assert len(view_idxs) == bs\n            view_embeddings = []\n            for i in 
range(bs):\n                view_idx = view_idxs[i]\n                assert num_views == len(view_idx)\n                view_embeddings.append(self.view_embed[:, view_idx, ...])\n            view_embedding = torch.cat(view_embeddings, 0).to(last_hidden_state.dtype).to(last_hidden_state.device)\n\n        if num_views != self.view_num:\n            view_embedding = view_embedding[:, :num_views, ...]\n        last_hidden_state = last_hidden_state + view_embedding\n        last_hidden_state = last_hidden_state.view(bs, num_views * last_hidden_state.shape[-2],\n                                                   last_hidden_state.shape[-1])\n        return last_hidden_state\n\n    def unconditional_embedding(self, batch_size, view_idxs=None, **kwargs):\n        device = next(self.model.parameters()).device\n        dtype = next(self.model.parameters()).dtype\n        zero = torch.zeros(\n            batch_size,\n            self.num_patches * len(view_idxs[0]),\n            self.model.config.hidden_size,\n            device=device,\n            dtype=dtype,\n        )\n        return zero\n\n\ndef build_image_encoder(config):\n    if config['type'] == 'CLIPImageEncoder':\n        return CLIPImageEncoder(**config['kwargs'])\n    elif config['type'] == 'DinoImageEncoder':\n        return DinoImageEncoder(**config['kwargs'])\n    elif config['type'] == 'DinoImageEncoderMV':\n        return DinoImageEncoderMV(**config['kwargs'])\n    else:\n        raise ValueError(f'Unknown image encoder type: {config[\"type\"]}')\n\n\nclass DualImageEncoder(nn.Module):\n    def __init__(\n        self,\n        main_image_encoder,\n        additional_image_encoder,\n    ):\n        super().__init__()\n        self.main_image_encoder = build_image_encoder(main_image_encoder)\n        self.additional_image_encoder = build_image_encoder(additional_image_encoder)\n\n    def forward(self, image, mask=None, **kwargs):\n        outputs = {\n            'main': self.main_image_encoder(image, mask=mask, **kwargs),\n            'additional': self.additional_image_encoder(image, mask=mask, **kwargs),\n        }\n        return outputs\n\n    def unconditional_embedding(self, batch_size, **kwargs):\n        outputs = {\n            'main': self.main_image_encoder.unconditional_embedding(batch_size, **kwargs),\n            'additional': self.additional_image_encoder.unconditional_embedding(batch_size, **kwargs),\n        }\n        return outputs\n\n\nclass SingleImageEncoder(nn.Module):\n    def __init__(\n        self,\n        main_image_encoder,\n        drop_ratio=0.1,\n    ):\n        super().__init__()\n        self.main_image_encoder = build_image_encoder(main_image_encoder)\n        self.drop_ratio = drop_ratio\n        # self.disable_drop = disable_drop\n\n    def forward(self, image, disable_drop=True, mask=None, **kwargs):\n        outputs = {\n            'main': self.main_image_encoder(image, mask=mask, **kwargs),\n        }\n        \n        if disable_drop:\n            return outputs\n        else:\n            random_p = torch.rand(len(image), device='cuda')\n            remain_bool_tensor = random_p > self.drop_ratio\n            outputs['main'] *= remain_bool_tensor.view(-1,1,1)\n        return outputs\n\n    def unconditional_embedding(self, batch_size, 
**kwargs),\n        }\n        return outputs\n"
  },
  {
    "path": "ultrashape/models/denoisers/__init__.py",
    "content": "# ==============================================================================\n# Original work Copyright (c) 2025 Tencent.\n# Modified work Copyright (c) 2025 UltraShape Team.\n# \n# Modified by UltraShape on 2025.12.25\n# ==============================================================================\n\n# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT\n# except for the third-party components listed below.\n# Hunyuan 3D does not impose any additional limitations beyond what is outlined\n# in the repsective licenses of these third-party components.\n# Users must comply with all terms and conditions of original licenses of these third-party\n# components and must ensure that the usage of the third party components adheres to\n# all relevant laws and regulations.\n\n# For avoidance of doubts, Hunyuan 3D means the large language models and\n# their software and algorithms, including trained model weights, parameters (including\n# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,\n# fine-tuning enabling code and other elements of the foregoing made publicly available\n# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.\n\nfrom .dit_mask import RefineDiT\n"
  },
  {
    "path": "ultrashape/models/denoisers/dit_mask.py",
    "content": "# ==============================================================================\n# Original work Copyright (c) 2025 Tencent.\n# Modified work Copyright (c) 2025 UltraShape Team.\n# \n# Modified by UltraShape on 2025.12.25\n# ==============================================================================\n\n# Open Source Model Licensed under the Apache License Version 2.0\n# and Other Licenses of the Third-Party Components therein:\n# The below Model in this distribution may have been modified by THL A29 Limited\n# (\"Tencent Modifications\"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.\n\n# Copyright (C) 2024 THL A29 Limited, a Tencent company.  All rights reserved.\n# The below software and/or models in this distribution may have been\n# modified by THL A29 Limited (\"Tencent Modifications\").\n# All Tencent Modifications are Copyright (C) THL A29 Limited.\n\n# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT\n# except for the third-party components listed below.\n# Hunyuan 3D does not impose any additional limitations beyond what is outlined\n# in the repsective licenses of these third-party components.\n# Users must comply with all terms and conditions of original licenses of these third-party\n# components and must ensure that the usage of the third party components adheres to\n# all relevant laws and regulations.\n\n# For avoidance of doubts, Hunyuan 3D means the large language models and\n# their software and algorithms, including trained model weights, parameters (including\n# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,\n# fine-tuning enabling code and other elements of the foregoing made publicly available\n# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.\n\nimport os\nimport yaml\nimport math\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import rearrange\n\nfrom .moe_layers import MoEBlock\nfrom ...utils import logger, synchronize_timer, smart_load_model\n\nfrom flash_attn import flash_attn_varlen_func\n\n\ndef modulate(x, shift, scale):\n    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)\n\n\nclass Timesteps(nn.Module):\n    def __init__(self,\n                 num_channels: int,\n                 downscale_freq_shift: float = 0.0,\n                 scale: int = 1,\n                 max_period: int = 10000\n                 ):\n        super().__init__()\n        self.num_channels = num_channels\n        self.downscale_freq_shift = downscale_freq_shift\n        self.scale = scale\n        self.max_period = max_period\n\n    def forward(self, timesteps):\n        assert len(timesteps.shape) == 1, \"Timesteps should be a 1d-array\"\n        embedding_dim = self.num_channels\n        half_dim = embedding_dim // 2\n        exponent = -math.log(self.max_period) * torch.arange(\n            start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)\n        exponent = exponent / (half_dim - self.downscale_freq_shift)\n        emb = torch.exp(exponent)\n        emb = timesteps[:, None].float() * emb[None, :]\n        emb = self.scale * emb\n        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)\n        if embedding_dim % 2 == 1:\n            emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))\n        return emb\n\n\nclass TimestepEmbedder(nn.Module):\n    \"\"\"\n    Embeds scalar timesteps into vector representations.\n    \"\"\"\n\n    def 
__init__(self, hidden_size, frequency_embedding_size=256, cond_proj_dim=None, out_size=None):\n        super().__init__()\n        if out_size is None:\n            out_size = hidden_size\n        self.mlp = nn.Sequential(\n            nn.Linear(hidden_size, frequency_embedding_size, bias=True),\n            nn.GELU(),\n            nn.Linear(frequency_embedding_size, out_size, bias=True),\n        )\n        self.frequency_embedding_size = frequency_embedding_size\n\n        if cond_proj_dim is not None:\n            self.cond_proj = nn.Linear(cond_proj_dim, frequency_embedding_size, bias=False)\n\n        self.time_embed = Timesteps(hidden_size)\n\n    def forward(self, t, condition):\n\n        t_freq = self.time_embed(t).type(self.mlp[0].weight.dtype)\n\n        # t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype)\n        if condition is not None:\n            t_freq = t_freq + self.cond_proj(condition)\n\n        t = self.mlp(t_freq)\n        t = t.unsqueeze(dim=1)\n        return t\n\n\nclass MLP(nn.Module):\n    def __init__(self, *, width: int):\n        super().__init__()\n        self.width = width\n        self.fc1 = nn.Linear(width, width * 4)\n        self.fc2 = nn.Linear(width * 4, width)\n        self.gelu = nn.GELU()\n\n    def forward(self, x):\n        return self.fc2(self.gelu(self.fc1(x)))\n\n\nclass CrossAttention(nn.Module):\n    def __init__(\n        self,\n        qdim,\n        kdim,\n        num_heads,\n        qkv_bias=True,\n        qk_norm=False,\n        norm_layer=nn.LayerNorm,\n        **kwargs,\n    ):\n        super().__init__()\n        self.qdim = qdim\n        self.kdim = kdim\n        self.num_heads = num_heads\n        assert self.qdim % num_heads == 0, \"self.qdim must be divisible by num_heads\"\n        self.head_dim = self.qdim // num_heads\n        assert self.head_dim % 8 == 0 and self.head_dim <= 128, \"Only support head_dim <= 128 and divisible by 8\"\n        self.scale = self.head_dim ** -0.5\n\n        self.to_q = nn.Linear(qdim, qdim, bias=qkv_bias)\n        self.to_k = nn.Linear(kdim, qdim, bias=qkv_bias)\n        self.to_v = nn.Linear(kdim, qdim, bias=qkv_bias)\n\n        # TODO: eps should be 1 / 65530 if using fp16\n        self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()\n        self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()\n        self.out_proj = nn.Linear(qdim, qdim, bias=True)\n\n\n    def forward(self, x, y):\n        \"\"\"\n        Parameters\n        ----------\n        x: torch.Tensor\n            (batch, seqlen1, hidden_dim) (where hidden_dim = num heads * head dim)\n        y: torch.Tensor\n            (batch, seqlen2, hidden_dim2) - may contain padding (marked with -1)\n        freqs_cis_img: torch.Tensor\n            (batch, hidden_dim // 2), RoPE for image\n        \"\"\"\n        b, s1, c = x.shape  # [b, s1, D]\n\n        # Detect padding tokens: check if all values in the feature dimension are -1\n        # y_mask: [b, s2], True for valid tokens, False for padding\n        y_mask = (y != -1).any(dim=-1)  # [b, s2]\n        has_padding = not y_mask.all()\n\n        _, s2, c = y.shape  # [b, s2, 1024]\n        q = self.to_q(x)\n        k = self.to_k(y)\n        v = self.to_v(y)\n\n        kv = torch.cat((k, v), dim=-1)\n        split_size = kv.shape[-1] // self.num_heads // 2\n        kv = kv.view(1, -1, self.num_heads, split_size * 2)\n        k, v = 
torch.split(kv, split_size, dim=-1)\n\n        q = q.view(b, s1, self.num_heads, self.head_dim)  # [b, s1, h, d]\n        k = k.view(b, s2, self.num_heads, self.head_dim)  # [b, s2, h, d]\n        v = v.view(b, s2, self.num_heads, self.head_dim)  # [b, s2, h, d]\n\n        q = self.q_norm(q)\n        k = self.k_norm(k)\n\n        if has_padding:\n            seqlens_k = y_mask.sum(dim=1).int()\n            q_flat = q.reshape(-1, self.num_heads, self.head_dim)\n            \n            # For k, v: only keep valid tokens (remove padding)\n            # Create indices for valid tokens\n            valid_indices = []\n            cu_seqlens_k = [0]\n            for i in range(b):\n                valid_len = seqlens_k[i].item()\n                batch_indices = torch.arange(valid_len, device=y.device) + i * s2\n                valid_indices.append(batch_indices)\n                cu_seqlens_k.append(cu_seqlens_k[-1] + valid_len)\n            \n            valid_indices = torch.cat(valid_indices)\n            k_flat = k.reshape(b * s2, self.num_heads, self.head_dim)[valid_indices]  # [total_k, h, d]\n            v_flat = v.reshape(b * s2, self.num_heads, self.head_dim)[valid_indices]  # [total_k, h, d]\n            \n            # Create cumulative sequence lengths\n            cu_seqlens_q = torch.arange(0, (b + 1) * s1, s1, dtype=torch.int32, device=x.device)\n            cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, device=x.device)\n            \n            # Call flash attention varlen\n            q_flat = q_flat.to(torch.bfloat16)\n            k_flat = k_flat.to(torch.bfloat16)\n            v_flat = v_flat.to(torch.bfloat16)\n\n            context = flash_attn_varlen_func(\n                q_flat, k_flat, v_flat,\n                cu_seqlens_q, cu_seqlens_k,\n                s1, seqlens_k.max().item(),\n                dropout_p=0.0,\n                softmax_scale=None,\n                causal=False\n            )\n            context = context.reshape(b, s1, -1)\n        else:\n            with torch.backends.cuda.sdp_kernel(\n                enable_flash=True,\n                enable_math=False,\n                enable_mem_efficient=True\n            ):\n                q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.num_heads), (q, k, v))\n                \n                attn_mask = None\n                context = F.scaled_dot_product_attention(\n                    q, k, v, attn_mask=attn_mask\n                ).transpose(1, 2).reshape(b, s1, -1)\n\n        out = self.out_proj(context)\n\n        return out\n\n\nclass Attention(nn.Module):\n    \"\"\"\n    We rename some layer names to align with flash attention\n    \"\"\"\n\n    def __init__(\n        self,\n        dim,\n        num_heads,\n        qkv_bias=True,\n        qk_norm=False,\n        norm_layer=nn.LayerNorm,\n    ):\n        super().__init__()\n        self.dim = dim\n        self.num_heads = num_heads\n        assert self.dim % num_heads == 0, 'dim should be divisible by num_heads'\n        self.head_dim = self.dim // num_heads\n        # This assertion is aligned with flash attention\n        assert self.head_dim % 8 == 0 and self.head_dim <= 128, \"Only support head_dim <= 128 and divisible by 8\"\n        self.scale = self.head_dim ** -0.5\n\n        self.to_q = nn.Linear(dim, dim, bias=qkv_bias)\n        self.to_k = nn.Linear(dim, dim, bias=qkv_bias)\n        self.to_v = nn.Linear(dim, dim, bias=qkv_bias)\n        self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, 
eps=1e-6) if qk_norm else nn.Identity()\n        self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()\n        self.out_proj = nn.Linear(dim, dim)\n\n    # def forward(self, x):\n    def forward(self, x, rotary_cos=None, rotary_sin=None):\n        B, N, C = x.shape\n\n        q = self.to_q(x)\n        k = self.to_k(x)\n        v = self.to_v(x)\n\n        qkv = torch.cat((q, k, v), dim=-1)\n        split_size = qkv.shape[-1] // self.num_heads // 3\n        qkv = qkv.view(1, -1, self.num_heads, split_size * 3)\n        q, k, v = torch.split(qkv, split_size, dim=-1)\n\n        q = q.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)  # [b, h, s, d]\n        k = k.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)  # [b, h, s, d]\n        v = v.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)\n\n        q = self.q_norm(q)  # [b, h, s, d]\n        k = self.k_norm(k)  # [b, h, s, d]\n\n        # ========================= Apply RoPE =========================\n        if rotary_cos is not None:\n            q = apply_rotary_emb(q, rotary_cos, rotary_sin)\n            k = apply_rotary_emb(k, rotary_cos, rotary_sin)\n        # ==============================================================\n\n        with torch.backends.cuda.sdp_kernel(\n            enable_flash=True,\n            enable_math=False,\n            enable_mem_efficient=True\n        ):\n            x = F.scaled_dot_product_attention(q, k, v)\n            x = x.transpose(1, 2).reshape(B, N, -1)\n\n        x = self.out_proj(x)\n        return x\n\n\nclass DiTBlock(nn.Module):\n    def __init__(\n        self,\n        hidden_size,\n        c_emb_size,\n        num_heads,\n        text_states_dim=1024,\n        use_flash_attn=False,\n        qk_norm=False,\n        norm_layer=nn.LayerNorm,\n        qk_norm_layer=nn.RMSNorm,\n        init_scale=1.0,\n        qkv_bias=True,\n        skip_connection=True,\n        timested_modulate=False,\n        use_moe: bool = False,\n        num_experts: int = 8,\n        moe_top_k: int = 2,\n        **kwargs,\n    ):\n        super().__init__()\n        self.use_flash_attn = use_flash_attn\n        use_ele_affine = True\n\n        # ========================= Self-Attention =========================\n        self.norm1 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6)\n        self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm,\n                               norm_layer=qk_norm_layer)\n\n        # ========================= FFN =========================\n        self.norm2 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6)\n\n        # ========================= Add =========================\n        # Simply use add like SDXL.\n        self.timested_modulate = timested_modulate\n        if self.timested_modulate:\n            self.default_modulation = nn.Sequential(\n                nn.SiLU(),\n                nn.Linear(c_emb_size, hidden_size, bias=True)\n            )\n\n        # ========================= Cross-Attention =========================\n        self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=qkv_bias,\n                                    qk_norm=qk_norm, norm_layer=qk_norm_layer, init_scale=init_scale)\n        self.norm3 = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6)\n\n        if skip_connection:\n            self.skip_norm = norm_layer(hidden_size, elementwise_affine=True, 
eps=1e-6)\n            self.skip_linear = nn.Linear(2 * hidden_size, hidden_size)\n        else:\n            self.skip_linear = None\n\n        self.use_moe = use_moe\n        if self.use_moe:\n            self.moe = MoEBlock(\n                hidden_size,\n                num_experts=num_experts,\n                moe_top_k=moe_top_k,\n                dropout=0.0,\n                activation_fn=\"gelu\",\n                final_dropout=False,\n                ff_inner_dim=int(hidden_size * 4.0),\n                ff_bias=True,\n            )\n        else:\n            self.mlp = MLP(width=hidden_size)\n\n    def forward(self, x, c=None, text_states=None, skip_value=None, rotary_cos=None, rotary_sin=None):\n\n        if self.skip_linear is not None:\n            cat = torch.cat([skip_value, x], dim=-1)\n            x = self.skip_linear(cat)\n            x = self.skip_norm(x)\n\n        # Self-Attention\n        if self.timested_modulate:\n            shift_msa = self.default_modulation(c).unsqueeze(dim=1)\n            x = x + shift_msa\n\n        attn_out = self.attn1(self.norm1(x), rotary_cos=rotary_cos, rotary_sin=rotary_sin)\n\n        x = x + attn_out\n\n        # Cross-Attention\n        x = x + self.attn2(self.norm2(x), text_states)\n\n        # FFN Layer\n        mlp_inputs = self.norm3(x)\n\n        if self.use_moe:\n            x = x + self.moe(mlp_inputs)\n        else:\n            x = x + self.mlp(mlp_inputs)\n\n        return x\n\n\nclass AttentionPool(nn.Module):\n    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):\n        super().__init__()\n        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5)\n        self.k_proj = nn.Linear(embed_dim, embed_dim)\n        self.q_proj = nn.Linear(embed_dim, embed_dim)\n        self.v_proj = nn.Linear(embed_dim, embed_dim)\n        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)\n        self.num_heads = num_heads\n\n    def forward(self, x, attention_mask=None):\n        x = x.permute(1, 0, 2)  # NLC -> LNC\n        if attention_mask is not None:\n            attention_mask = attention_mask.unsqueeze(-1).permute(1, 0, 2)\n            global_emb = (x * attention_mask).sum(dim=0) / attention_mask.sum(dim=0)\n            x = torch.cat([global_emb[None,], x], dim=0)\n\n        else:\n            x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (L+1)NC\n        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (L+1)NC\n        x, _ = F.multi_head_attention_forward(\n            query=x[:1], key=x, value=x,\n            embed_dim_to_check=x.shape[-1],\n            num_heads=self.num_heads,\n            q_proj_weight=self.q_proj.weight,\n            k_proj_weight=self.k_proj.weight,\n            v_proj_weight=self.v_proj.weight,\n            in_proj_weight=None,\n            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),\n            bias_k=None,\n            bias_v=None,\n            add_zero_attn=False,\n            dropout_p=0,\n            out_proj_weight=self.c_proj.weight,\n            out_proj_bias=self.c_proj.bias,\n            use_separate_proj_weight=True,\n            training=self.training,\n            need_weights=False\n        )\n        return x.squeeze(0)\n\n\nclass FinalLayer(nn.Module):\n    \"\"\"\n    The final layer of DiT.\n    \"\"\"\n\n    def __init__(self, final_hidden_size, out_channels):\n        super().__init__()\n        self.final_hidden_size = 
final_hidden_size\n        self.norm_final = nn.LayerNorm(final_hidden_size, elementwise_affine=True, eps=1e-6)\n        self.linear = nn.Linear(final_hidden_size, out_channels, bias=True)\n\n    def forward(self, x):\n        x = self.norm_final(x)\n        x = x[:, 1:]\n        x = self.linear(x)\n        return x\n\n\nclass RefineDiT(nn.Module):\n\n    @classmethod\n    @synchronize_timer('Refine Model Loading')\n    def from_single_file(\n        cls,\n        ckpt_path,\n        config_path,\n        device='cuda',\n        dtype=torch.float16,\n        use_safetensors=None,\n        **kwargs,\n    ):\n        # load config\n        with open(config_path, 'r') as f:\n            config = yaml.safe_load(f)\n\n        # load ckpt\n        if use_safetensors:\n            ckpt_path = ckpt_path.replace('.ckpt', '.safetensors')\n        if not os.path.exists(ckpt_path):\n            raise FileNotFoundError(f\"Model file {ckpt_path} not found\")\n\n        logger.info(f\"Loading model from {ckpt_path}\")\n        if use_safetensors:\n            import safetensors.torch\n            ckpt = safetensors.torch.load_file(ckpt_path, device='cpu')\n        else:\n            ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)\n\n        if 'model' in ckpt:\n            ckpt = ckpt['model']\n        if 'model' in config:\n            config = config['model']\n\n        model_kwargs = config['params']\n        model_kwargs.update(kwargs)\n\n        model = cls(**model_kwargs)\n        model.load_state_dict(ckpt)\n        model.to(device=device, dtype=dtype)\n        return model\n\n    @classmethod\n    def from_pretrained(\n        cls,\n        model_path,\n        device='cuda',\n        dtype=torch.float16,\n        use_safetensors=False,\n        variant='fp16',\n        subfolder='hunyuan3d-dit-v2-1',\n        **kwargs,\n    ):\n        config_path, ckpt_path = smart_load_model(\n            model_path,\n            subfolder=subfolder,\n            use_safetensors=use_safetensors,\n            variant=variant\n        )\n\n        return cls.from_single_file(\n            ckpt_path,\n            config_path,\n            device=device,\n            dtype=dtype,\n            use_safetensors=use_safetensors,\n            **kwargs\n        )\n\n    def __init__(\n        self,\n        input_size=1024,\n        in_channels=4,\n        hidden_size=1024,\n        context_dim=1024,\n        depth=24,\n        num_heads=16,\n        mlp_ratio=4.0,\n        norm_type='layer',\n        qk_norm_type='rms',\n        qk_norm=False,\n        text_len=257,\n        guidance_cond_proj_dim=None,\n        qkv_bias=True,\n        num_moe_layers: int = 6,\n        num_experts: int = 8,\n        moe_top_k: int = 2,\n        voxel_query_res: int = 128,\n        **kwargs\n    ):\n        super().__init__()\n        self.input_size = input_size\n        self.depth = depth\n        self.in_channels = in_channels\n        self.out_channels = in_channels\n        self.num_heads = num_heads\n\n        self.hidden_size = hidden_size\n        self.norm = nn.LayerNorm if norm_type == 'layer' else nn.RMSNorm\n        self.qk_norm = nn.RMSNorm if qk_norm_type == 'rms' else nn.LayerNorm\n        self.context_dim = context_dim\n        self.voxel_query_res = voxel_query_res\n\n        self.guidance_cond_proj_dim = guidance_cond_proj_dim\n\n        self.text_len = text_len\n\n        self.x_embedder = nn.Linear(in_channels, hidden_size, bias=True)\n        self.t_embedder = TimestepEmbedder(hidden_size, 
hidden_size * 4, cond_proj_dim=guidance_cond_proj_dim)\n\n        self.blocks = nn.ModuleList([\n            DiTBlock(hidden_size=hidden_size,\n                            c_emb_size=hidden_size,\n                            num_heads=num_heads,\n                            mlp_ratio=mlp_ratio,\n                            text_states_dim=context_dim,\n                            qk_norm=qk_norm,\n                            norm_layer=self.norm,\n                            qk_norm_layer=self.qk_norm,\n                            skip_connection=layer > depth // 2,\n                            qkv_bias=qkv_bias,\n                            use_moe=True if depth - layer <= num_moe_layers else False,\n                            num_experts=num_experts,\n                            moe_top_k=moe_top_k\n                            )\n            for layer in range(depth)\n        ])\n        self.depth = depth\n\n        self.final_layer = FinalLayer(hidden_size, self.out_channels)\n\n    def forward(self, x, t, contexts, **kwargs):\n        cond = contexts['main']\n\n        t = self.t_embedder(t, condition=kwargs.get('guidance_cond'))\n        x = self.x_embedder(x)\n        c = t\n\n        ##########################################\n        head_dim = self.blocks[0].attn1.head_dim\n        num_cond_tokens = c.shape[1] if c.dim() == 3 else 1\n\n        device = x.device\n        cond_cos = torch.ones(x.shape[0], num_cond_tokens, head_dim, device=device)\n        cond_sin = torch.zeros(x.shape[0], num_cond_tokens, head_dim, device=device)\n\n        voxel_cond = kwargs.get('voxel_cond')\n        # rotary_cos_vox, rotary_sin_vox = precompute_freqs_cis_3d(head_dim, voxel_cond)\n        rotary_cos_vox, rotary_sin_vox = precompute_freqs_cis_3d_interpolated(\n            head_dim, voxel_cond, current_res=self.voxel_query_res)\n\n        rotary_cos = torch.cat([cond_cos, rotary_cos_vox], dim=1)\n        rotary_sin = torch.cat([cond_sin, rotary_sin_vox], dim=1)\n        ##########################################\n\n        x = torch.cat([c, x], dim=1)\n\n        skip_value_list = []\n        for layer, block in enumerate(self.blocks):\n            skip_value = None if layer <= self.depth // 2 else skip_value_list.pop()\n            x = block(x, c, cond, rotary_cos=rotary_cos, rotary_sin=rotary_sin, skip_value=skip_value)\n            if layer < self.depth // 2:\n                skip_value_list.append(x)\n\n        x = self.final_layer(x)\n        return x\n\n\ndef apply_rotary_emb(x, cos, sin):\n    \"\"\"\n    x: [B, H, N, D]\n    cos, sin: [B, N, D]\n    \"\"\"\n\n    cos = cos.unsqueeze(1)\n    sin = sin.unsqueeze(1)\n\n    def rotate_half(x):\n        x1, x2 = x.chunk(2, dim=-1)\n        return torch.cat((-x2, x1), dim=-1)\n        \n    return (x * cos) + (rotate_half(x) * sin)\n\n\ndef precompute_freqs_cis_3d(dim: int, grid_indices: torch.Tensor, theta: float = 10000.0):\n    \"\"\"\n    grid_indices: [B, N, 3] voxel idx\n    \"\"\"\n    dim_x = dim // 3\n    dim_y = dim // 3\n    dim_z = dim - dim_x - dim_y \n    \n    device = grid_indices.device\n    freqs_x = 1.0 / (theta ** (torch.arange(0, dim_x, 2, device=device).float() / dim_x))\n    freqs_y = 1.0 / (theta ** (torch.arange(0, dim_y, 2, device=device).float() / dim_y))\n    freqs_z = 1.0 / (theta ** (torch.arange(0, dim_z, 2, device=device).float() / dim_z))\n    \n    x_idx = grid_indices[..., 0].float()\n    y_idx = grid_indices[..., 1].float()\n    z_idx = grid_indices[..., 2].float()\n\n    args_x = x_idx.unsqueeze(-1) * 
freqs_x.unsqueeze(0).unsqueeze(0)\n    args_y = y_idx.unsqueeze(-1) * freqs_y.unsqueeze(0).unsqueeze(0)\n    args_z = z_idx.unsqueeze(-1) * freqs_z.unsqueeze(0).unsqueeze(0)\n\n    args = torch.cat([args_x, args_y, args_z], dim=-1)\n    args = torch.cat([args, args], dim=-1)\n    \n    return torch.cos(args), torch.sin(args)\n\n\ndef precompute_freqs_cis_3d_interpolated(\n    dim: int, \n    grid_indices: torch.Tensor, \n    theta: float = 10000.0,\n    trained_res: float = 128.0,  # training resolution\n    current_res: float = 256.0,  # inference resolution\n):\n    scale_factor = current_res / trained_res\n    \n    dim_x = dim // 3\n    dim_y = dim // 3\n    dim_z = dim - dim_x - dim_y \n    \n    device = grid_indices.device\n\n    freqs_x = 1.0 / (theta ** (torch.arange(0, dim_x, 2, device=device).float() / dim_x))\n    freqs_y = 1.0 / (theta ** (torch.arange(0, dim_y, 2, device=device).float() / dim_y))\n    freqs_z = 1.0 / (theta ** (torch.arange(0, dim_z, 2, device=device).float() / dim_z))\n\n    num_freqs_x = dim_x // 2 + (dim_x % 2)\n    num_freqs_y = dim_y // 2 + (dim_y % 2)\n    target_len = dim // 2\n    freqs_x = freqs_x[:num_freqs_x]\n    freqs_y = freqs_y[:num_freqs_y]\n    freqs_z = freqs_z[:(target_len - len(freqs_x) - len(freqs_y))]\n\n    input_x = grid_indices[..., 0].float()\n    input_y = grid_indices[..., 1].float()\n    input_z = grid_indices[..., 2].float()\n\n    # Apply Scaling\n    pos_x = input_x / scale_factor\n    pos_y = input_y / scale_factor\n    pos_z = input_z / scale_factor\n\n    # pos * freq\n    args_x = pos_x.unsqueeze(-1) * freqs_x.unsqueeze(0).unsqueeze(0)\n    args_y = pos_y.unsqueeze(-1) * freqs_y.unsqueeze(0).unsqueeze(0)\n    args_z = pos_z.unsqueeze(-1) * freqs_z.unsqueeze(0).unsqueeze(0)\n\n    args = torch.cat([args_x, args_y, args_z], dim=-1)\n    args = torch.cat([args, args], dim=-1) \n    \n    return torch.cos(args), torch.sin(args)\n"
  },
  {
    "path": "ultrashape/models/denoisers/moe_layers.py",
    "content": "# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT\n# except for the third-party components listed below.\n# Hunyuan 3D does not impose any additional limitations beyond what is outlined\n# in the repsective licenses of these third-party components.\n# Users must comply with all terms and conditions of original licenses of these third-party\n# components and must ensure that the usage of the third party components adheres to\n# all relevant laws and regulations.\n\n# For avoidance of doubts, Hunyuan 3D means the large language models and\n# their software and algorithms, including trained model weights, parameters (including\n# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,\n# fine-tuning enabling code and other elements of the foregoing made publicly available\n# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.\n\nimport torch\nimport torch.nn as nn\nimport numpy as np\nimport math\nfrom timm.models.vision_transformer import PatchEmbed, Attention, Mlp\n\nimport torch.nn.functional as F\nfrom diffusers.models.attention import FeedForward\n\nclass AddAuxiliaryLoss(torch.autograd.Function):\n    \"\"\"\n    The trick function of adding auxiliary (aux) loss, \n    which includes the gradient of the aux loss during backpropagation.\n    \"\"\"\n    @staticmethod\n    def forward(ctx, x, loss):\n        assert loss.numel() == 1\n        ctx.dtype = loss.dtype\n        ctx.required_aux_loss = loss.requires_grad\n        return x\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        grad_loss = None\n        if ctx.required_aux_loss:\n            grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)\n        return grad_output, grad_loss\n\nclass MoEGate(nn.Module):\n    def __init__(self, embed_dim, num_experts=16, num_experts_per_tok=2, aux_loss_alpha=0.01):\n        super().__init__()\n        self.top_k = num_experts_per_tok\n        self.n_routed_experts = num_experts\n\n        self.scoring_func = 'softmax'\n        self.alpha = aux_loss_alpha\n        self.seq_aux = False\n\n        # topk selection algorithm\n        self.norm_topk_prob = False\n        self.gating_dim = embed_dim\n        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))\n        self.reset_parameters()\n\n    def reset_parameters(self) -> None:\n        import torch.nn.init  as init\n        init.kaiming_uniform_(self.weight, a=math.sqrt(5))\n    \n    def forward(self, hidden_states):\n        bsz, seq_len, h = hidden_states.shape    \n        # print(bsz, seq_len, h)    \n        ### compute gating score\n        hidden_states = hidden_states.view(-1, h)\n        logits = F.linear(hidden_states, self.weight, None)\n        if self.scoring_func == 'softmax':\n            scores = logits.softmax(dim=-1)\n        else:\n            raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')\n        \n        ### select top-k experts\n        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)\n        \n        ### norm gate to sum 1\n        if self.top_k > 1 and self.norm_topk_prob:\n            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20\n            topk_weight = topk_weight / denominator\n\n        ### expert-level computation auxiliary loss\n        if self.training and self.alpha > 0.0:\n            scores_for_aux = scores\n            aux_topk = 
self.top_k\n            # always compute aux loss based on the naive greedy topk method\n            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)\n            if self.seq_aux:\n                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)\n                ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)\n                ce.scatter_add_(\n                    1, \n                    topk_idx_for_aux_loss, \n                    torch.ones(\n                        bsz, seq_len * aux_topk,\n                        device=hidden_states.device\n                    )\n                ).div_(seq_len * aux_topk / self.n_routed_experts)\n                aux_loss = (ce * scores_for_seq_aux.mean(dim = 1)).sum(dim = 1).mean()\n                aux_loss = aux_loss * self.alpha\n            else:\n                mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1),\n                                    num_classes=self.n_routed_experts)\n                ce = mask_ce.float().mean(0)\n                Pi = scores_for_aux.mean(0)\n                fi = ce * self.n_routed_experts\n                aux_loss = (Pi * fi).sum() * self.alpha\n        else:\n            aux_loss = None\n        return topk_idx, topk_weight, aux_loss\n\nclass MoEBlock(nn.Module):\n    def __init__(self, dim, num_experts=8, moe_top_k=2,\n                    activation_fn = \"gelu\", dropout=0.0, final_dropout = False, \n                    ff_inner_dim = None, ff_bias = True):\n        super().__init__()\n        self.moe_top_k = moe_top_k\n        self.experts = nn.ModuleList([\n                FeedForward(dim,dropout=dropout, \n                            activation_fn=activation_fn,  \n                            final_dropout=final_dropout,  \n                            inner_dim=ff_inner_dim,  \n                            bias=ff_bias)\n        for i in range(num_experts)])\n        self.gate = MoEGate(embed_dim=dim, num_experts=num_experts, num_experts_per_tok=moe_top_k)\n\n        self.shared_experts = FeedForward(dim,dropout=dropout, activation_fn=activation_fn,  \n                                          final_dropout=final_dropout,  inner_dim=ff_inner_dim,  \n                                          bias=ff_bias)\n\n    def initialize_weight(self):\n        pass\n    \n    def forward(self, hidden_states):\n        identity = hidden_states\n        orig_shape = hidden_states.shape\n        topk_idx, topk_weight, aux_loss = self.gate(hidden_states) \n\n        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])\n        flat_topk_idx = topk_idx.view(-1)\n        if self.training:\n            hidden_states = hidden_states.repeat_interleave(self.moe_top_k, dim=0)\n            y = torch.empty_like(hidden_states, dtype=hidden_states.dtype)\n            for i, expert in enumerate(self.experts): \n                tmp = expert(hidden_states[flat_topk_idx == i])\n                y[flat_topk_idx == i] = tmp.to(hidden_states.dtype)\n            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)\n            y =  y.view(*orig_shape)\n            y = AddAuxiliaryLoss.apply(y, aux_loss)\n        else:\n            y = self.moe_infer(hidden_states, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)\n        y = y + self.shared_experts(identity)\n        return y\n    \n\n    @torch.no_grad()\n    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):\n        expert_cache = torch.zeros_like(x) \n        idxs = flat_expert_indices.argsort()\n    
    tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)\n        token_idxs = idxs // self.moe_top_k \n        for i, end_idx in enumerate(tokens_per_expert):\n            start_idx = 0 if i == 0 else tokens_per_expert[i-1]\n            if start_idx == end_idx:\n                continue\n            expert = self.experts[i]\n            exp_token_idx = token_idxs[start_idx:end_idx]\n            expert_tokens = x[exp_token_idx]\n            expert_out = expert(expert_tokens)\n            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]]) \n            \n            # for fp16 and other dtype\n            expert_cache = expert_cache.to(expert_out.dtype)\n            expert_cache.scatter_reduce_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]),\n                                         expert_out, \n                                         reduce='sum')\n        return expert_cache\n"
  },
  {
    "path": "ultrashape/models/diffusion/flow_matching_dit_trainer.py",
    "content": "\nimport os\nfrom contextlib import contextmanager\nfrom typing import List, Tuple, Optional, Union\n\nimport torch\nimport torch.nn as nn\nfrom torch.optim import lr_scheduler\nimport pytorch_lightning as pl\nfrom pytorch_lightning.utilities import rank_zero_info\nfrom pytorch_lightning.utilities import rank_zero_only\nfrom ultrashape.pipelines import export_to_trimesh\n\nfrom ...utils.ema import LitEma\nfrom ...utils.misc import instantiate_from_config, instantiate_non_trainable_model, instantiate_vae_model, instantiate_vae_model_local\n\n\nclass Diffuser(pl.LightningModule):\n    def __init__(\n        self,\n        *,\n        vae_config,\n        cond_config,\n        dit_cfg,\n        scheduler_cfg,\n        optimizer_cfg,\n        pipeline_cfg=None,\n        image_processor_cfg=None,\n        lora_config=None,\n        ema_config=None,\n        scale_by_std: bool = False,\n        z_scale_factor: float = 1.0,\n        ckpt_path: Optional[str] = None,\n        ignore_keys: Union[Tuple[str], List[str]] = (),\n        torch_compile: bool = False,\n    ):\n        super().__init__()\n\n        # ========= init optimizer config ========= #\n        self.optimizer_cfg = optimizer_cfg\n\n        # ========= init diffusion scheduler ========= #\n        self.scheduler_cfg = scheduler_cfg\n        self.sampler = None\n        if 'transport' in scheduler_cfg:\n            self.transport = instantiate_from_config(scheduler_cfg.transport)\n            self.sampler = instantiate_from_config(scheduler_cfg.sampler, transport=self.transport)\n            self.sample_fn = self.sampler.sample_ode(**scheduler_cfg.sampler.ode_params)\n\n        # ========= init the model ========= #\n        self.dit_cfg = dit_cfg\n        self.model = instantiate_from_config(dit_cfg, device=None, dtype=None)\n        \n        self.cond_stage_model = instantiate_from_config(cond_config)\n\n        self.ckpt_path = ckpt_path\n        if ckpt_path is not None:\n            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)\n\n        # ========= config lora model ========= #\n        if lora_config is not None:\n            from peft import LoraConfig, get_peft_model\n            loraconfig = LoraConfig(\n                r=lora_config.rank,\n                lora_alpha=lora_config.rank,\n                target_modules=lora_config.get('target_modules')\n            )\n            self.model = get_peft_model(self.model, loraconfig)\n\n        # ========= config ema model ========= #\n        self.ema_config = ema_config\n        if self.ema_config is not None:\n            if self.ema_config.ema_model == 'DSEma':\n                # from michelangelo.models.modules.ema_deepspeed import DSEma\n                from ..utils.ema_deepspeed import DSEma\n                self.model_ema = DSEma(self.model, decay=self.ema_config.ema_decay)\n            else:\n                self.model_ema = LitEma(self.model, decay=self.ema_config.ema_decay)\n            #do not initilize EMA weight from ckpt path, since I need to change moe layers\n            if ckpt_path is not None:\n                self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)\n            print(f\"Keeping EMAs of {len(list(self.model_ema.buffers()))}.\")\n\n        # ========= init vae at last to prevent it is overridden by loaded ckpt ========= #\n        self.first_stage_model = instantiate_vae_model_local(vae_config)\n        self.first_stage_model.enable_flashvdm_decoder()\n\n        self.scale_by_std = scale_by_std\n        if 
scale_by_std:\n            self.register_buffer(\"z_scale_factor\", torch.tensor(z_scale_factor))\n        else:\n            self.z_scale_factor = z_scale_factor\n\n        # ========= init pipeline for inference ========= #\n        self.image_processor_cfg = image_processor_cfg\n        self.image_processor = None\n        if self.image_processor_cfg is not None:\n            self.image_processor = instantiate_from_config(self.image_processor_cfg)\n        self.pipeline_cfg = pipeline_cfg\n        from ...schedulers import FlowMatchEulerDiscreteScheduler\n        scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)\n        self.pipeline = instantiate_from_config(\n            pipeline_cfg,\n            vae=self.first_stage_model,\n            model=self.model,\n            scheduler=scheduler,\n            conditioner=self.cond_stage_model,\n            image_processor=self.image_processor,\n        )\n\n        # ========= torch compile to accelerate ========= #\n        self.torch_compile = torch_compile\n        if self.torch_compile:\n            torch.nn.Module.compile(self.model)\n            torch.nn.Module.compile(self.first_stage_model)\n            torch.nn.Module.compile(self.cond_stage_model)\n            print(f'*' * 100)\n            print(f'Compile model for acceleration')\n            print(f'*' * 100)\n\n    @contextmanager\n    def ema_scope(self, context=None):\n        if self.ema_config is not None and self.ema_config.get('ema_inference', False):\n            self.model_ema.store(self.model)\n            self.model_ema.copy_to(self.model)\n            if context is not None:\n                print(f\"{context}: Switched to EMA weights\")\n        try:\n            yield None\n        finally:\n            if self.ema_config is not None and self.ema_config.get('ema_inference', False):\n                self.model_ema.restore(self.model)\n                if context is not None:\n                    print(f\"{context}: Restored training weights\")\n\n    def init_from_ckpt(self, path, ignore_keys=()):\n        ckpt = torch.load(path, map_location=\"cpu\")\n        if 'state_dict' not in ckpt:\n            # deepspeed ckpt\n            state_dict = {}\n            for k in ckpt.keys():\n                new_k = k.replace('_forward_module.', '')\n                state_dict[new_k] = ckpt[k]\n        else:\n            state_dict = ckpt[\"state_dict\"]\n\n        keys = list(state_dict.keys())\n        for k in keys:\n            for ik in ignore_keys:\n                if ik in k:\n                    print(\"Deleting key {} from state_dict.\".format(k))\n                    del state_dict[k]\n\n        missing, unexpected = self.load_state_dict(state_dict, strict=False)\n        print(f\"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys\")\n        if len(missing) > 0:\n            print(f\"Missing Keys: {missing}\")\n            print(f\"Unexpected Keys: {unexpected}\")\n\n    def on_load_checkpoint(self, checkpoint):\n        \"\"\"\n        The pt_model is trained separately, so we already have access to its\n        checkpoint and load it separately with `self.set_pt_model`.\n\n        However, the PL Trainer is strict about\n        checkpoint loading (not configurable), so it expects the loaded state_dict\n        to match exactly the keys in the model state_dict.\n\n        So, when loading the checkpoint, before matching keys, we add all pt_model keys\n        from self.state_dict() to the checkpoint state dict, 
so that they match\n        \"\"\"\n        for key in self.state_dict().keys():\n            if key.startswith(\"model_ema\") and key not in checkpoint[\"state_dict\"]:\n                checkpoint[\"state_dict\"][key] = self.state_dict()[key]\n\n    def configure_optimizers(self) -> Tuple[List, List]:\n        lr = self.learning_rate\n\n        params_list = []\n        trainable_parameters = list(self.model.parameters())\n        params_list.append({'params': trainable_parameters, 'lr': lr})\n\n        no_decay = ['bias', 'norm.weight', 'norm.bias', 'norm1.weight', 'norm1.bias', 'norm2.weight', 'norm2.bias']\n\n\n        if self.optimizer_cfg.get('train_image_encoder', False):\n            image_encoder_parameters = list(self.cond_stage_model.named_parameters())\n            image_encoder_parameters_decay = [param for name, param in image_encoder_parameters if\n                                              not any((no_decay_name in name) for no_decay_name in no_decay)]\n            image_encoder_parameters_nodecay = [param for name, param in image_encoder_parameters if\n                                                any((no_decay_name in name) for no_decay_name in no_decay)]\n            # filter trainable params\n            image_encoder_parameters_decay = [param for param in image_encoder_parameters_decay if\n                                              param.requires_grad]\n            image_encoder_parameters_nodecay = [param for param in image_encoder_parameters_nodecay if\n                                                param.requires_grad]\n\n            print(f\"Image Encoder Params: {len(image_encoder_parameters_decay)} decay, \")\n            print(f\"Image Encoder Params: {len(image_encoder_parameters_nodecay)} nodecay, \")\n\n            image_encoder_lr = self.optimizer_cfg['image_encoder_lr']\n            image_encoder_lr_multiply = self.optimizer_cfg.get('image_encoder_lr_multiply', 1.0)\n            image_encoder_lr = image_encoder_lr if image_encoder_lr is not None else lr * image_encoder_lr_multiply\n            params_list.append(\n                {'params': image_encoder_parameters_decay, 'lr': image_encoder_lr,\n                 'weight_decay': 0.05})\n            params_list.append(\n                {'params': image_encoder_parameters_nodecay, 'lr': image_encoder_lr,\n                 'weight_decay': 0.})\n\n        optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=params_list, lr=lr)\n        if hasattr(self.optimizer_cfg, 'scheduler'):\n            scheduler_func = instantiate_from_config(\n                self.optimizer_cfg.scheduler,\n                max_decay_steps=self.trainer.max_steps,\n                lr_max=lr\n            )\n            scheduler = {\n                \"scheduler\": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),\n                \"interval\": \"step\",\n                \"frequency\": 1\n            }\n            schedulers = [scheduler]\n        else:\n            schedulers = []\n        optimizers = [optimizer]\n\n        return optimizers, schedulers\n\n\n    def on_train_batch_end(self, *args, **kwargs):\n        if self.ema_config is not None:\n            self.model_ema(self.model)\n\n    def on_train_epoch_start(self) -> None:\n        pl.seed_everything(self.trainer.global_rank)\n\n    def forward(self, batch, disable_drop):\n        with torch.autocast(device_type=\"cuda\", dtype=torch.bfloat16): #float32 for text\n            contexts = 
self.cond_stage_model(image=batch.get('image'), text=batch.get('text'), mask=batch.get('mask'), disable_drop=disable_drop)\n\n        with torch.autocast(device_type=\"cuda\", dtype=torch.float16):\n            with torch.no_grad():\n                latents, voxel_idx = self.first_stage_model.encode(batch[\"surface\"], sample_posterior=True, need_voxel=True)\n                latents = self.z_scale_factor * latents\n                # print(latents.shape)\n                \n        with torch.autocast(device_type=\"cuda\", dtype=torch.bfloat16):\n            loss = self.transport.training_losses(self.model, latents, \n                    dict(contexts=contexts, voxel_cond=voxel_idx))[\"loss\"].mean()\n\n        return loss\n\n    def training_step(self, batch, batch_idx, optimizer_idx=0):\n        loss = self.forward(batch, disable_drop=False)\n        split = 'train'\n        loss_dict = {\n            f\"{split}/total_loss\": loss.detach(),\n            f\"{split}/lr_abs\": self.optimizers().param_groups[0]['lr'],\n        }\n        self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)\n\n        return loss\n\n    def validation_step(self, batch, batch_idx, optimizer_idx=0):\n        loss = self.forward(batch, disable_drop=True)\n        split = 'val'\n        loss_dict = {\n            f\"{split}/total_loss\": loss.detach(),\n            f\"{split}/lr_abs\": self.optimizers().param_groups[0]['lr'],\n        }\n        self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)\n\n        return loss\n\n    @torch.no_grad()\n    def sample(self, batch, output_type='trimesh', **kwargs):\n        self.cond_stage_model.disable_drop = True\n\n        generator = torch.Generator().manual_seed(0)\n\n        with self.ema_scope(\"Sample\"):\n            with torch.amp.autocast(device_type='cuda'):\n                try:\n                    self.pipeline.device = self.device\n                    self.pipeline.dtype = self.dtype\n                    print(\"### USING PIPELINE ###\")\n                    print(f'device: {self.device} dtype : {self.dtype}')\n                    additional_params = {'output_type':output_type}\n\n                    image = batch.get(\"image\", None)\n                    mask = batch.get('mask', None)\n                    \n                    outputs = self.pipeline(image=image, \n                                            mask=mask,\n                                            generator=generator,\n                                            box_v=1.0,\n                                            mc_level=0.0,\n                                            octree_resolution=1024,\n                                            **additional_params)\n\n                except Exception as e:\n                    import traceback\n                    traceback.print_exc()\n                    print(f\"Unexpected {e=}, {type(e)=}\")\n                    with open(\"error.txt\", \"a\") as f:\n                        f.write(str(e))\n                        f.write(traceback.format_exc())\n                        f.write(\"\\n\")\n                    outputs = [None]\n\n        self.cond_stage_model.disable_drop = False\n        return [outputs]\n"
  },
  {
    "path": "ultrashape/models/diffusion/transport/__init__.py",
    "content": "# This file includes code derived from the SiT project (https://github.com/willisma/SiT),\n# which is licensed under the MIT License.\n#\n# MIT License\n#\n# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\nfrom .transport import Transport, ModelType, WeightType, PathType, Sampler\n\n\ndef create_transport(\n    path_type='Linear',\n    prediction=\"velocity\",\n    loss_weight=None,\n    train_eps=None,\n    sample_eps=None,\n    train_sample_type=\"uniform\",\n    mean = 0.0,\n    std = 1.0,\n    shift_scale = 1.0,\n):\n    \"\"\"function for creating Transport object\n    **Note**: model prediction defaults to velocity\n    Args:\n    - path_type: type of path to use; default to linear\n    - learn_score: set model prediction to score\n    - learn_noise: set model prediction to noise\n    - velocity_weighted: weight loss by velocity weight\n    - likelihood_weighted: weight loss by likelihood weight\n    - train_eps: small epsilon for avoiding instability during training\n    - sample_eps: small epsilon for avoiding instability during sampling\n    \"\"\"\n\n    if prediction == \"noise\":\n        model_type = ModelType.NOISE\n    elif prediction == \"score\":\n        model_type = ModelType.SCORE\n    else:\n        model_type = ModelType.VELOCITY\n\n    if loss_weight == \"velocity\":\n        loss_type = WeightType.VELOCITY\n    elif loss_weight == \"likelihood\":\n        loss_type = WeightType.LIKELIHOOD\n    else:\n        loss_type = WeightType.NONE\n\n    path_choice = {\n        \"Linear\": PathType.LINEAR,\n        \"GVP\": PathType.GVP,\n        \"VP\": PathType.VP,\n    }\n\n    path_type = path_choice[path_type]\n\n    if (path_type in [PathType.VP]):\n        train_eps = 1e-5 if train_eps is None else train_eps\n        sample_eps = 1e-3 if train_eps is None else sample_eps\n    elif (path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY):\n        train_eps = 1e-3 if train_eps is None else train_eps\n        sample_eps = 1e-3 if train_eps is None else sample_eps\n    else:  # velocity & [GVP, LINEAR] is stable everywhere\n        train_eps = 0\n        sample_eps = 0\n\n    # create flow state\n    state = Transport(\n        model_type=model_type,\n        path_type=path_type,\n        loss_type=loss_type,\n        train_eps=train_eps,\n        sample_eps=sample_eps,\n        train_sample_type=train_sample_type,\n        mean=mean,\n        std=std,\n        
shift_scale =shift_scale,\n    )\n\n    return state\n"
  },
  {
    "path": "ultrashape/models/diffusion/transport/integrators.py",
    "content": "# This file includes code derived from the SiT project (https://github.com/willisma/SiT),\n# which is licensed under the MIT License.\n#\n# MIT License\n#\n# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\nimport numpy as np\nimport torch as th\nimport torch.nn as nn\nfrom torchdiffeq import odeint\nfrom functools import partial\nfrom tqdm import tqdm\n\nclass sde:\n    \"\"\"SDE solver class\"\"\"\n    def __init__(\n        self, \n        drift,\n        diffusion,\n        *,\n        t0,\n        t1,\n        num_steps,\n        sampler_type,\n    ):\n        assert t0 < t1, \"SDE sampler has to be in forward time\"\n\n        self.num_timesteps = num_steps\n        self.t = th.linspace(t0, t1, num_steps)\n        self.dt = self.t[1] - self.t[0]\n        self.drift = drift\n        self.diffusion = diffusion\n        self.sampler_type = sampler_type\n\n    def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs):\n        w_cur = th.randn(x.size()).to(x)\n        t = th.ones(x.size(0)).to(x) * t\n        dw = w_cur * th.sqrt(self.dt)\n        drift = self.drift(x, t, model, **model_kwargs)\n        diffusion = self.diffusion(x, t)\n        mean_x = x + drift * self.dt\n        x = mean_x + th.sqrt(2 * diffusion) * dw\n        return x, mean_x\n    \n    def __Heun_step(self, x, _, t, model, **model_kwargs):\n        w_cur = th.randn(x.size()).to(x)\n        dw = w_cur * th.sqrt(self.dt)\n        t_cur = th.ones(x.size(0)).to(x) * t\n        diffusion = self.diffusion(x, t_cur)\n        xhat = x + th.sqrt(2 * diffusion) * dw\n        K1 = self.drift(xhat, t_cur, model, **model_kwargs)\n        xp = xhat + self.dt * K1\n        K2 = self.drift(xp, t_cur + self.dt, model, **model_kwargs)\n        return xhat + 0.5 * self.dt * (K1 + K2), xhat # at last time point we do not perform the heun step\n\n    def __forward_fn(self):\n        \"\"\"TODO: generalize here by adding all private functions ending with steps to it\"\"\"\n        sampler_dict = {\n            \"Euler\": self.__Euler_Maruyama_step,\n            \"Heun\": self.__Heun_step,\n        }\n\n        try:\n            sampler = sampler_dict[self.sampler_type]\n        except:\n            raise NotImplementedError(\"Smapler type not implemented.\")\n    \n        return sampler\n\n    def sample(self, init, model, **model_kwargs):\n        \"\"\"forward loop of sde\"\"\"\n        x = init\n        mean_x = init \n        samples = []\n   
     sampler = self.__forward_fn()\n        for ti in self.t[:-1]:\n            with th.no_grad():\n                x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs)\n                samples.append(x)\n\n        return samples\n\nclass ode:\n    \"\"\"ODE solver class\"\"\"\n    def __init__(\n        self,\n        drift,\n        *,\n        t0,\n        t1,\n        sampler_type,\n        num_steps,\n        atol,\n        rtol,\n    ):\n        assert t0 < t1, \"ODE sampler has to be in forward time\"\n\n        self.drift = drift\n        self.t = th.linspace(t0, t1, num_steps)\n        self.atol = atol\n        self.rtol = rtol\n        self.sampler_type = sampler_type\n\n    def sample(self, x, model, **model_kwargs):\n        \n        device = x[0].device if isinstance(x, tuple) else x.device\n        def _fn(t, x):\n            t = th.ones(x[0].size(0)).to(device) * t if isinstance(x, tuple) else th.ones(x.size(0)).to(device) * t\n            model_output = self.drift(x, t, model, **model_kwargs)\n            return model_output\n\n        t = self.t.to(device)\n        atol = [self.atol] * len(x) if isinstance(x, tuple) else [self.atol]\n        rtol = [self.rtol] * len(x) if isinstance(x, tuple) else [self.rtol]\n        samples = odeint(\n            _fn,\n            x,\n            t,\n            method=self.sampler_type,\n            atol=atol,\n            rtol=rtol\n        )\n        return samples\n"
  },
  {
    "path": "ultrashape/models/diffusion/transport/path.py",
    "content": "# This file includes code derived from the SiT project (https://github.com/willisma/SiT),\n# which is licensed under the MIT License.\n#\n# MIT License\n#\n# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\nimport torch as th\nimport numpy as np\nfrom functools import partial\n\ndef expand_t_like_x(t, x):\n    \"\"\"Function to reshape time t to broadcastable dimension of x\n    Args:\n      t: [batch_dim,], time vector\n      x: [batch_dim,...], data point\n    \"\"\"\n    dims = [1] * (len(x.size()) - 1)\n    t = t.view(t.size(0), *dims)\n    return t\n\n\n#################### Coupling Plans ####################\n\nclass ICPlan:\n    \"\"\"Linear Coupling Plan\"\"\"\n    def __init__(self, sigma=0.0):\n        self.sigma = sigma\n\n    def compute_alpha_t(self, t):\n        \"\"\"Compute the data coefficient along the path\"\"\"\n        return t, 1\n    \n    def compute_sigma_t(self, t):\n        \"\"\"Compute the noise coefficient along the path\"\"\"\n        return 1 - t, -1\n    \n    def compute_d_alpha_alpha_ratio_t(self, t):\n        \"\"\"Compute the ratio between d_alpha and alpha\"\"\"\n        return 1 / t\n\n    def compute_drift(self, x, t):\n        \"\"\"We always output sde according to score parametrization; \"\"\"\n        t = expand_t_like_x(t, x)\n        alpha_ratio = self.compute_d_alpha_alpha_ratio_t(t)\n        sigma_t, d_sigma_t = self.compute_sigma_t(t)\n        drift = alpha_ratio * x\n        diffusion = alpha_ratio * (sigma_t ** 2) - sigma_t * d_sigma_t\n\n        return -drift, diffusion\n\n    def compute_diffusion(self, x, t, form=\"constant\", norm=1.0):\n        \"\"\"Compute the diffusion term of the SDE\n        Args:\n          x: [batch_dim, ...], data point\n          t: [batch_dim,], time vector\n          form: str, form of the diffusion term\n          norm: float, norm of the diffusion term\n        \"\"\"\n        t = expand_t_like_x(t, x)\n        choices = {\n            \"constant\": norm,\n            \"SBDM\": norm * self.compute_drift(x, t)[1],\n            \"sigma\": norm * self.compute_sigma_t(t)[0],\n            \"linear\": norm * (1 - t),\n            \"decreasing\": 0.25 * (norm * th.cos(np.pi * t) + 1) ** 2,\n            \"inccreasing-decreasing\": norm * th.sin(np.pi * t) ** 2,\n        }\n\n        try:\n            diffusion = choices[form]\n        except KeyError:\n            raise NotImplementedError(f\"Diffusion form {form} not implemented\")\n 
       \n        return diffusion\n\n    def get_score_from_velocity(self, velocity, x, t):\n        \"\"\"Wrapper function: transfrom velocity prediction model to score\n        Args:\n            velocity: [batch_dim, ...] shaped tensor; velocity model output\n            x: [batch_dim, ...] shaped tensor; x_t data point\n            t: [batch_dim,] time tensor\n        \"\"\"\n        t = expand_t_like_x(t, x)\n        alpha_t, d_alpha_t = self.compute_alpha_t(t)\n        sigma_t, d_sigma_t = self.compute_sigma_t(t)\n        mean = x\n        reverse_alpha_ratio = alpha_t / d_alpha_t\n        var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t\n        score = (reverse_alpha_ratio * velocity - mean) / var\n        return score\n    \n    def get_noise_from_velocity(self, velocity, x, t):\n        \"\"\"Wrapper function: transfrom velocity prediction model to denoiser\n        Args:\n            velocity: [batch_dim, ...] shaped tensor; velocity model output\n            x: [batch_dim, ...] shaped tensor; x_t data point\n            t: [batch_dim,] time tensor\n        \"\"\"\n        t = expand_t_like_x(t, x)\n        alpha_t, d_alpha_t = self.compute_alpha_t(t)\n        sigma_t, d_sigma_t = self.compute_sigma_t(t)\n        mean = x\n        reverse_alpha_ratio = alpha_t / d_alpha_t\n        var = reverse_alpha_ratio * d_sigma_t - sigma_t\n        noise = (reverse_alpha_ratio * velocity - mean) / var\n        return noise\n\n    def get_velocity_from_score(self, score, x, t):\n        \"\"\"Wrapper function: transfrom score prediction model to velocity\n        Args:\n            score: [batch_dim, ...] shaped tensor; score model output\n            x: [batch_dim, ...] shaped tensor; x_t data point\n            t: [batch_dim,] time tensor\n        \"\"\"\n        t = expand_t_like_x(t, x)\n        drift, var = self.compute_drift(x, t)\n        velocity = var * score - drift\n        return velocity\n\n    def compute_mu_t(self, t, x0, x1):\n        \"\"\"Compute the mean of time-dependent density p_t\"\"\"\n        t = expand_t_like_x(t, x1)\n        alpha_t, _ = self.compute_alpha_t(t)\n        sigma_t, _ = self.compute_sigma_t(t)\n        # t*x1 + (1-t)*x0 ; t=0 x0; t=1 x1\n        return alpha_t * x1 + sigma_t * x0\n    \n    def compute_xt(self, t, x0, x1):\n        \"\"\"Sample xt from time-dependent density p_t; rng is required\"\"\"\n        xt = self.compute_mu_t(t, x0, x1)\n        return xt\n    \n    def compute_ut(self, t, x0, x1, xt):\n        \"\"\"Compute the vector field corresponding to p_t\"\"\"\n        t = expand_t_like_x(t, x1)\n        _, d_alpha_t = self.compute_alpha_t(t)\n        _, d_sigma_t = self.compute_sigma_t(t)\n        return d_alpha_t * x1 + d_sigma_t * x0\n    \n    def plan(self, t, x0, x1):\n        xt = self.compute_xt(t, x0, x1)\n        ut = self.compute_ut(t, x0, x1, xt)\n        return t, xt, ut\n    \n\nclass VPCPlan(ICPlan):\n    \"\"\"class for VP path flow matching\"\"\"\n\n    def __init__(self, sigma_min=0.1, sigma_max=20.0):\n        self.sigma_min = sigma_min\n        self.sigma_max = sigma_max\n        self.log_mean_coeff = lambda t: -0.25 * ((1 - t) ** 2) * \\\n            (self.sigma_max - self.sigma_min) - 0.5 * (1 - t) * self.sigma_min \n        self.d_log_mean_coeff = lambda t: 0.5 * (1 - t) * \\\n            (self.sigma_max - self.sigma_min) + 0.5 * self.sigma_min\n\n\n    def compute_alpha_t(self, t):\n        \"\"\"Compute coefficient of x1\"\"\"\n        alpha_t = self.log_mean_coeff(t)\n        alpha_t = 
th.exp(alpha_t)\n        d_alpha_t = alpha_t * self.d_log_mean_coeff(t)\n        return alpha_t, d_alpha_t\n    \n    def compute_sigma_t(self, t):\n        \"\"\"Compute coefficient of x0\"\"\"\n        p_sigma_t = 2 * self.log_mean_coeff(t)\n        sigma_t = th.sqrt(1 - th.exp(p_sigma_t))\n        d_sigma_t = th.exp(p_sigma_t) * (2 * self.d_log_mean_coeff(t)) / (-2 * sigma_t)\n        return sigma_t, d_sigma_t\n    \n    def compute_d_alpha_alpha_ratio_t(self, t):\n        \"\"\"Special-purpose function for computing a numerically stable d_alpha_t / alpha_t\"\"\"\n        return self.d_log_mean_coeff(t)\n\n    def compute_drift(self, x, t):\n        \"\"\"Compute the drift term of the SDE\"\"\"\n        t = expand_t_like_x(t, x)\n        beta_t = self.sigma_min + (1 - t) * (self.sigma_max - self.sigma_min)\n        return -0.5 * beta_t * x, beta_t / 2\n    \n\nclass GVPCPlan(ICPlan):\n    def __init__(self, sigma=0.0):\n        super().__init__(sigma)\n    \n    def compute_alpha_t(self, t):\n        \"\"\"Compute coefficient of x1\"\"\"\n        alpha_t = th.sin(t * np.pi / 2)\n        d_alpha_t = np.pi / 2 * th.cos(t * np.pi / 2)\n        return alpha_t, d_alpha_t\n    \n    def compute_sigma_t(self, t):\n        \"\"\"Compute coefficient of x0\"\"\"\n        sigma_t = th.cos(t * np.pi / 2)\n        d_sigma_t = -np.pi / 2 * th.sin(t * np.pi / 2)\n        return sigma_t, d_sigma_t\n    \n    def compute_d_alpha_alpha_ratio_t(self, t):\n        \"\"\"Special-purpose function for computing a numerically stable d_alpha_t / alpha_t\"\"\"\n        return np.pi / (2 * th.tan(t * np.pi / 2))\n"
  },
  {
    "path": "ultrashape/models/diffusion/transport/transport.py",
    "content": "# This file includes code derived from the SiT project (https://github.com/willisma/SiT),\n# which is licensed under the MIT License.\n#\n# MIT License\n#\n# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\nimport torch as th\nimport numpy as np\nimport logging\n\nimport enum\n\nfrom . import path\nfrom .utils import EasyDict, log_state, mean_flat\nfrom .integrators import ode, sde\n\n\nclass ModelType(enum.Enum):\n    \"\"\"\n    Which type of output the model predicts.\n    \"\"\"\n\n    NOISE = enum.auto()  # the model predicts epsilon\n    SCORE = enum.auto()  # the model predicts \\nabla \\log p(x)\n    VELOCITY = enum.auto()  # the model predicts v(x)\n\n\nclass PathType(enum.Enum):\n    \"\"\"\n    Which type of path to use.\n    \"\"\"\n\n    LINEAR = enum.auto()\n    GVP = enum.auto()\n    VP = enum.auto()\n\n\nclass WeightType(enum.Enum):\n    \"\"\"\n    Which type of weighting to use.\n    \"\"\"\n\n    NONE = enum.auto()\n    VELOCITY = enum.auto()\n    LIKELIHOOD = enum.auto()\n\n\nclass Transport:\n\n    def __init__(\n        self,\n        *,\n        model_type,\n        path_type,\n        loss_type,\n        train_eps,\n        sample_eps,\n        train_sample_type = \"uniform\",\n        **kwargs,\n    ):\n        path_options = {\n            PathType.LINEAR: path.ICPlan,\n            PathType.GVP: path.GVPCPlan,\n            PathType.VP: path.VPCPlan,\n        }\n\n        self.loss_type = loss_type\n        self.model_type = model_type\n        self.path_sampler = path_options[path_type]()\n        self.train_eps = train_eps\n        self.sample_eps = sample_eps\n        self.train_sample_type = train_sample_type\n        if self.train_sample_type == \"logit_normal\":\n            self.mean = kwargs['mean']\n            self.std = kwargs['std']\n            self.shift_scale = kwargs['shift_scale']\n            print(f\"using logit normal sample, shift scale is {self.shift_scale}\")\n\n    def prior_logp(self, z):\n        '''\n            Standard multivariate normal prior\n            Assume z is batched\n        '''\n        shape = th.tensor(z.size())\n        N = th.prod(shape[1:])\n        _fn = lambda x: -N / 2. 
* np.log(2 * np.pi) - th.sum(x ** 2) / 2.\n        return th.vmap(_fn)(z)\n\n    def check_interval(\n        self,\n        train_eps,\n        sample_eps,\n        *,\n        diffusion_form=\"SBDM\",\n        sde=False,\n        reverse=False,\n        eval=False,\n        last_step_size=0.0,\n    ):\n        t0 = 0\n        t1 = 1\n        eps = train_eps if not eval else sample_eps\n        if (type(self.path_sampler) in [path.VPCPlan]):\n\n            t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size\n\n        elif (type(self.path_sampler) in [path.ICPlan, path.GVPCPlan]) \\\n            and (\n            self.model_type != ModelType.VELOCITY or sde):  # avoid numerical issue by taking a first semi-implicit step\n\n            t0 = eps if (diffusion_form == \"SBDM\" and sde) or self.model_type != ModelType.VELOCITY else 0\n            t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size\n\n        if reverse:\n            t0, t1 = 1 - t0, 1 - t1\n\n        return t0, t1\n\n    def sample(self, x1):\n        \"\"\"Sampling x0 & t based on shape of x1 (if needed)\n          Args:\n            x1 - data point; [batch, *dim]\n        \"\"\"\n\n        x0 = th.randn_like(x1)\n        if self.train_sample_type==\"uniform\":\n            t0, t1 = self.check_interval(self.train_eps, self.sample_eps)\n            t = th.rand((x1.shape[0],)) * (t1 - t0) + t0\n            t = t.to(x1)\n        elif self.train_sample_type==\"logit_normal\":\n            t = th.randn((x1.shape[0],)) * self.std + self.mean\n            t = t.to(x1)\n            t = 1/(1+th.exp(-t))\n\n            t = np.sqrt(self.shift_scale)*t/(1+(np.sqrt(self.shift_scale)-1)*t)\n\n        return t, x0, x1\n\n    def training_losses(\n        self,\n        model,\n        x1,\n        model_kwargs=None\n    ):\n        \"\"\"Loss for training the score model\n        Args:\n        - model: backbone model; could be score, noise, or velocity\n        - x1: datapoint\n        - model_kwargs: additional arguments for the model\n        \"\"\"\n        if model_kwargs == None:\n            model_kwargs = {}\n\n        t, x0, x1 = self.sample(x1)\n        t, xt, ut = self.path_sampler.plan(t, x0, x1)\n        model_output = model(xt, t, **model_kwargs)\n        B, *_, C = xt.shape\n        assert model_output.size() == (B, *xt.size()[1:-1], C)\n\n        terms = {}\n        terms['pred'] = model_output\n        if self.model_type == ModelType.VELOCITY:\n            terms['loss'] = mean_flat(((model_output - ut) ** 2))\n        else:\n            _, drift_var = self.path_sampler.compute_drift(xt, t)\n            sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, xt))\n            if self.loss_type in [WeightType.VELOCITY]:\n                weight = (drift_var / sigma_t) ** 2\n            elif self.loss_type in [WeightType.LIKELIHOOD]:\n                weight = drift_var / (sigma_t ** 2)\n            elif self.loss_type in [WeightType.NONE]:\n                weight = 1\n            else:\n                raise NotImplementedError()\n\n            if self.model_type == ModelType.NOISE:\n                terms['loss'] = mean_flat(weight * ((model_output - x0) ** 2))\n            else:\n                terms['loss'] = mean_flat(weight * ((model_output * sigma_t + x0) ** 2))\n\n        return terms\n\n    def get_drift(\n        self\n    ):\n        \"\"\"member function for obtaining the drift of the probability flow ODE\"\"\"\n\n        def score_ode(x, t, model, 
**model_kwargs):\n            drift_mean, drift_var = self.path_sampler.compute_drift(x, t)\n            model_output = model(x, t, **model_kwargs)\n            return (-drift_mean + drift_var * model_output)  # by change of variable\n\n        def noise_ode(x, t, model, **model_kwargs):\n            drift_mean, drift_var = self.path_sampler.compute_drift(x, t)\n            sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))\n            model_output = model(x, t, **model_kwargs)\n            score = model_output / -sigma_t\n            return (-drift_mean + drift_var * score)\n\n        def velocity_ode(x, t, model, **model_kwargs):\n            model_output = model(x, t, **model_kwargs)\n            return model_output\n\n        if self.model_type == ModelType.NOISE:\n            drift_fn = noise_ode\n        elif self.model_type == ModelType.SCORE:\n            drift_fn = score_ode\n        else:\n            drift_fn = velocity_ode\n\n        def body_fn(x, t, model, **model_kwargs):\n            model_output = drift_fn(x, t, model, **model_kwargs)\n            assert model_output.shape == x.shape, \"Output shape from ODE solver must match input shape\"\n            return model_output\n\n        return body_fn\n\n    def get_score(\n        self,\n    ):\n        \"\"\"member function for obtaining score of \n            x_t = alpha_t * x + sigma_t * eps\"\"\"\n        if self.model_type == ModelType.NOISE:\n            score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs) / - \\\n                self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))[0]\n        elif self.model_type == ModelType.SCORE:\n            score_fn = lambda x, t, model, **kwagrs: model(x, t, **kwagrs)\n        elif self.model_type == ModelType.VELOCITY:\n            score_fn = lambda x, t, model, **kwargs: self.path_sampler.get_score_from_velocity(model(x, t, **kwargs), x,\n                                                                                               t)\n        else:\n            raise NotImplementedError()\n\n        return score_fn\n\n\nclass Sampler:\n    \"\"\"Sampler class for the transport model\"\"\"\n\n    def __init__(\n        self,\n        transport,\n    ):\n        \"\"\"Constructor for a general sampler; supporting different sampling methods\n        Args:\n        - transport: an tranport object specify model prediction & interpolant type\n        \"\"\"\n\n        self.transport = transport\n        self.drift = self.transport.get_drift()\n        self.score = self.transport.get_score()\n\n    def __get_sde_diffusion_and_drift(\n        self,\n        *,\n        diffusion_form=\"SBDM\",\n        diffusion_norm=1.0,\n    ):\n\n        def diffusion_fn(x, t):\n            diffusion = self.transport.path_sampler.compute_diffusion(x, t, form=diffusion_form, norm=diffusion_norm)\n            return diffusion\n\n        sde_drift = \\\n            lambda x, t, model, **kwargs: \\\n                self.drift(x, t, model, **kwargs) + diffusion_fn(x, t) * self.score(x, t, model, **kwargs)\n\n        sde_diffusion = diffusion_fn\n\n        return sde_drift, sde_diffusion\n\n    def __get_last_step(\n        self,\n        sde_drift,\n        *,\n        last_step,\n        last_step_size,\n    ):\n        \"\"\"Get the last step function of the SDE solver\"\"\"\n\n        if last_step is None:\n            last_step_fn = \\\n                lambda x, t, model, **model_kwargs: \\\n                    x\n        elif last_step == \"Mean\":\n      
      last_step_fn = \\\n                lambda x, t, model, **model_kwargs: \\\n                    x + sde_drift(x, t, model, **model_kwargs) * last_step_size\n        elif last_step == \"Tweedie\":\n            alpha = self.transport.path_sampler.compute_alpha_t  # simple aliasing; the original name was too long\n            sigma = self.transport.path_sampler.compute_sigma_t\n            last_step_fn = \\\n                lambda x, t, model, **model_kwargs: \\\n                    x / alpha(t)[0][0] + (sigma(t)[0][0] ** 2) / alpha(t)[0][0] * self.score(x, t, model,\n                                                                                             **model_kwargs)\n        elif last_step == \"Euler\":\n            last_step_fn = \\\n                lambda x, t, model, **model_kwargs: \\\n                    x + self.drift(x, t, model, **model_kwargs) * last_step_size\n        else:\n            raise NotImplementedError()\n\n        return last_step_fn\n\n    def sample_sde(\n        self,\n        *,\n        sampling_method=\"Euler\",\n        diffusion_form=\"SBDM\",\n        diffusion_norm=1.0,\n        last_step=\"Mean\",\n        last_step_size=0.04,\n        num_steps=250,\n    ):\n        \"\"\"returns a sampling function with given SDE settings\n        Args:\n        - sampling_method: type of sampler used in solving the SDE; default to be Euler-Maruyama\n        - diffusion_form: function form of diffusion coefficient; default to be matching SBDM\n        - diffusion_norm: function magnitude of diffusion coefficient; default to 1\n        - last_step: type of the last step; default to identity\n        - last_step_size: size of the last step; default to match the stride of 250 steps over [0,1]\n        - num_steps: total integration step of SDE\n        \"\"\"\n\n        if last_step is None:\n            last_step_size = 0.0\n\n        sde_drift, sde_diffusion = self.__get_sde_diffusion_and_drift(\n            diffusion_form=diffusion_form,\n            diffusion_norm=diffusion_norm,\n        )\n\n        t0, t1 = self.transport.check_interval(\n            self.transport.train_eps,\n            self.transport.sample_eps,\n            diffusion_form=diffusion_form,\n            sde=True,\n            eval=True,\n            reverse=False,\n            last_step_size=last_step_size,\n        )\n\n        _sde = sde(\n            sde_drift,\n            sde_diffusion,\n            t0=t0,\n            t1=t1,\n            num_steps=num_steps,\n            sampler_type=sampling_method\n        )\n\n        last_step_fn = self.__get_last_step(sde_drift, last_step=last_step, last_step_size=last_step_size)\n\n        def _sample(init, model, **model_kwargs):\n            xs = _sde.sample(init, model, **model_kwargs)\n            ts = th.ones(init.size(0), device=init.device) * t1\n            x = last_step_fn(xs[-1], ts, model, **model_kwargs)\n            xs.append(x)\n\n            assert len(xs) == num_steps, \"Samples does not match the number of steps\"\n\n            return xs\n\n        return _sample\n\n    def sample_ode(\n        self,\n        *,\n        sampling_method=\"dopri5\",\n        num_steps=50,\n        atol=1e-6,\n        rtol=1e-3,\n        reverse=False,\n    ):\n        \"\"\"returns a sampling function with given ODE settings\n        Args:\n        - sampling_method: type of sampler used in solving the ODE; default to be Dopri5\n        - num_steps: \n            - fixed solver (Euler, Heun): the actual number of integration steps performed\n    
        - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation\n        - atol: absolute error tolerance for the solver\n        - rtol: relative error tolerance for the solver\n        - reverse: whether solving the ODE in reverse (data to noise); default to False\n        \"\"\"\n        if reverse:\n            drift = lambda x, t, model, **kwargs: self.drift(x, th.ones_like(t) * (1 - t), model, **kwargs)\n        else:\n            drift = self.drift\n\n        t0, t1 = self.transport.check_interval(\n            self.transport.train_eps,\n            self.transport.sample_eps,\n            sde=False,\n            eval=True,\n            reverse=reverse,\n            last_step_size=0.0,\n        )\n\n        _ode = ode(\n            drift=drift,\n            t0=t0,\n            t1=t1,\n            sampler_type=sampling_method,\n            num_steps=num_steps,\n            atol=atol,\n            rtol=rtol,\n        )\n\n        return _ode.sample\n\n    def sample_ode_intermediate(\n        self,\n        *,\n        sampling_method=\"dopri5\",\n        num_steps=50,\n        atol=1e-6,\n        rtol=1e-3,\n        t=0.5,\n        reverse=False,\n    ):\n        \"\"\"returns a sampling function with given ODE settings\n        Args:\n        - sampling_method: type of sampler used in solving the ODE; default to be Dopri5\n        - num_steps: \n            - fixed solver (Euler, Heun): the actual number of integration steps performed\n            - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation\n        - atol: absolute error tolerance for the solver\n        - rtol: relative error tolerance for the solver\n        - reverse: whether solving the ODE in reverse (data to noise); default to False\n        \"\"\"\n        if reverse:\n            drift = lambda x, t, model, **kwargs: self.drift(x, th.ones_like(t) * (1 - t), model, **kwargs)\n        else:\n            drift = self.drift\n\n        t0, t1 = self.transport.check_interval(\n            self.transport.train_eps,\n            self.transport.sample_eps,\n            sde=False,\n            eval=True,\n            reverse=reverse,\n            last_step_size=0.0,\n        )\n\n        _ode = ode(\n            drift=drift,\n            t0=t,\n            t1=t1,\n            sampler_type=sampling_method,\n            num_steps=num_steps,\n            atol=atol,\n            rtol=rtol,\n        )\n\n        return _ode.sample\n\n    def sample_ode_likelihood(\n        self,\n        *,\n        sampling_method=\"dopri5\",\n        num_steps=50,\n        atol=1e-6,\n        rtol=1e-3,\n    ):\n\n        \"\"\"returns a sampling function for calculating likelihood with given ODE settings\n        Args:\n        - sampling_method: type of sampler used in solving the ODE; default to be Dopri5\n        - num_steps: \n            - fixed solver (Euler, Heun): the actual number of integration steps performed\n            - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation\n        - atol: absolute error tolerance for the solver\n        - rtol: relative error tolerance for the solver\n        \"\"\"\n\n        def _likelihood_drift(x, t, model, **model_kwargs):\n            x, _ = x\n            eps = th.randint(2, x.size(), dtype=th.float, device=x.device) * 2 - 1\n            t = th.ones_like(t) * (1 - t)\n            with th.enable_grad():\n                x.requires_grad = 
True\n                grad = th.autograd.grad(th.sum(self.drift(x, t, model, **model_kwargs) * eps), x)[0]\n                logp_grad = th.sum(grad * eps, dim=tuple(range(1, len(x.size()))))\n                drift = self.drift(x, t, model, **model_kwargs)\n            return (-drift, logp_grad)\n\n        t0, t1 = self.transport.check_interval(\n            self.transport.train_eps,\n            self.transport.sample_eps,\n            sde=False,\n            eval=True,\n            reverse=False,\n            last_step_size=0.0,\n        )\n\n        _ode = ode(\n            drift=_likelihood_drift,\n            t0=t0,\n            t1=t1,\n            sampler_type=sampling_method,\n            num_steps=num_steps,\n            atol=atol,\n            rtol=rtol,\n        )\n\n        def _sample_fn(x, model, **model_kwargs):\n            init_logp = th.zeros(x.size(0)).to(x)\n            input = (x, init_logp)\n            drift, delta_logp = _ode.sample(input, model, **model_kwargs)\n            drift, delta_logp = drift[-1], delta_logp[-1]\n            prior_logp = self.transport.prior_logp(drift)\n            logp = prior_logp - delta_logp\n            return logp, drift\n\n        return _sample_fn\n"
  },
  {
    "path": "ultrashape/models/diffusion/transport/utils.py",
    "content": "# This file includes code derived from the SiT project (https://github.com/willisma/SiT),\n# which is licensed under the MIT License.\n#\n# MIT License\n#\n# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\nimport torch as th\n\nclass EasyDict:\n\n    def __init__(self, sub_dict):\n        for k, v in sub_dict.items():\n            setattr(self, k, v)\n\n    def __getitem__(self, key):\n        return getattr(self, key)\n\ndef mean_flat(x):\n    \"\"\"\n    Take the mean over all non-batch dimensions.\n    \"\"\"\n    return th.mean(x, dim=list(range(1, len(x.size()))))\n\ndef log_state(state):\n    result = []\n    \n    sorted_state = dict(sorted(state.items()))\n    for key, value in sorted_state.items():\n        # Check if the value is an instance of a class\n        if \"<object\" in str(value) or \"object at\" in str(value):\n            result.append(f\"{key}: [{value.__class__.__name__}]\")\n        else:\n            result.append(f\"{key}: {value}\")\n    \n    return '\\n'.join(result)\n"
  },
  {
    "path": "ultrashape/pipelines.py",
    "content": "# ==============================================================================\n# Original work Copyright (c) 2025 Tencent.\n# Modified work Copyright (c) 2025 UltraShape Team.\n# \n# Modified by UltraShape on 2025.12.25\n# ==============================================================================\n\n# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT\n# except for the third-party components listed below.\n# Hunyuan 3D does not impose any additional limitations beyond what is outlined\n# in the repsective licenses of these third-party components.\n# Users must comply with all terms and conditions of original licenses of these third-party\n# components and must ensure that the usage of the third party components adheres to\n# all relevant laws and regulations.\n\n# For avoidance of doubts, Hunyuan 3D means the large language models and\n# their software and algorithms, including trained model weights, parameters (including\n# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,\n# fine-tuning enabling code and other elements of the foregoing made publicly available\n# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.\n\nimport copy\nimport importlib\nimport inspect\nimport os\nfrom typing import List, Optional, Union\n\nimport numpy as np\nimport torch\nimport trimesh\nimport yaml\nfrom PIL import Image\nfrom diffusers.utils.torch_utils import randn_tensor\nfrom diffusers.utils.import_utils import is_accelerate_version, is_accelerate_available\nfrom tqdm import tqdm\n\nfrom .models.autoencoders import ShapeVAE\nfrom .models.autoencoders import SurfaceExtractors\nfrom .utils import logger, synchronize_timer, smart_load_model\n\n\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    sigmas: Optional[List[float]] = None,\n    **kwargs,\n):\n    \"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`\n            must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,\n            `num_inference_steps` and `sigmas` must be `None`.\n        sigmas (`List[float]`, *optional*):\n            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,\n            `num_inference_steps` and `timesteps` must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None and sigmas is not None:\n        raise ValueError(\"Only one of `timesteps` or `sigmas` can be passed. 
Please choose one to set custom values\")\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    elif sigmas is not None:\n        accept_sigmas = \"sigmas\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accept_sigmas:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" sigmas schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\n@synchronize_timer('Export to trimesh')\ndef export_to_trimesh(mesh_output):\n    if isinstance(mesh_output, list):\n        outputs = []\n        for mesh in mesh_output:\n            if mesh is None:\n                outputs.append(None)\n            else:\n                mesh.mesh_f = mesh.mesh_f[:, ::-1]\n                mesh_output = trimesh.Trimesh(mesh.mesh_v, mesh.mesh_f)\n                outputs.append(mesh_output)\n        return outputs\n    else:\n        mesh_output.mesh_f = mesh_output.mesh_f[:, ::-1]\n        mesh_output = trimesh.Trimesh(mesh_output.mesh_v, mesh_output.mesh_f)\n        return mesh_output\n\n\ndef get_obj_from_str(string, reload=False):\n    module, cls = string.rsplit(\".\", 1)\n    if reload:\n        module_imp = importlib.import_module(module)\n        importlib.reload(module_imp)\n    return getattr(importlib.import_module(module, package=None), cls)\n\n\ndef instantiate_from_config(config, **kwargs):\n    if \"target\" not in config:\n        raise KeyError(\"Expected key `target` to instantiate.\")\n    cls = get_obj_from_str(config[\"target\"])\n    params = config.get(\"params\", dict())\n    kwargs.update(params)\n    instance = cls(**kwargs)\n    return instance\n\n\nclass DiTPipeline:\n    model_cpu_offload_seq = \"conditioner->model->vae\"\n    _exclude_from_cpu_offload = []\n\n    @classmethod\n    @synchronize_timer('DiTPipeline Model Loading')\n    def from_single_file(\n        cls,\n        ckpt_path,\n        config_path,\n        device='cuda',\n        dtype=torch.float16,\n        use_safetensors=None,\n        **kwargs,\n    ):\n        # load config\n        with open(config_path, 'r') as f:\n            config = yaml.safe_load(f)\n\n        # load ckpt\n        if use_safetensors:\n            ckpt_path = ckpt_path.replace('.ckpt', '.safetensors')\n        if not os.path.exists(ckpt_path):\n            raise FileNotFoundError(f\"Model file {ckpt_path} not found\")\n        logger.info(f\"Loading model from {ckpt_path}\")\n\n        if use_safetensors:\n            # parse safetensors\n            import safetensors.torch\n            
safetensors_ckpt = safetensors.torch.load_file(ckpt_path, device='cpu')\n            ckpt = {}\n            for key, value in safetensors_ckpt.items():\n                model_name = key.split('.')[0]\n                new_key = key[len(model_name) + 1:]\n                if model_name not in ckpt:\n                    ckpt[model_name] = {}\n                ckpt[model_name][new_key] = value\n        else:\n            ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)\n        # load model\n        model = instantiate_from_config(config['model'])\n        model.load_state_dict(ckpt['model'])\n        vae = instantiate_from_config(config['vae'])\n        vae.load_state_dict(ckpt['vae'], strict=False)\n        conditioner = instantiate_from_config(config['conditioner'])\n        if 'conditioner' in ckpt:\n            conditioner.load_state_dict(ckpt['conditioner'])\n        image_processor = instantiate_from_config(config['image_processor'])\n        scheduler = instantiate_from_config(config['scheduler'])\n\n        model_kwargs = dict(\n            vae=vae,\n            model=model,\n            scheduler=scheduler,\n            conditioner=conditioner,\n            image_processor=image_processor,\n            device=device,\n            dtype=dtype,\n        )\n        model_kwargs.update(kwargs)\n\n        return cls(\n            **model_kwargs\n        )\n\n    @classmethod\n    def from_pretrained(\n        cls,\n        model_path,\n        device='cuda',\n        dtype=torch.float16,\n        use_safetensors=False,\n        variant='fp16',\n        subfolder='hunyuan3d-dit-v2-1',\n        **kwargs,\n    ):\n        kwargs['from_pretrained_kwargs'] = dict(\n            model_path=model_path,\n            subfolder=subfolder,\n            use_safetensors=use_safetensors,\n            variant=variant,\n            dtype=dtype,\n            device=device,\n        )\n        config_path, ckpt_path = smart_load_model(\n            model_path,\n            subfolder=subfolder,\n            use_safetensors=use_safetensors,\n            variant=variant\n        )\n        return cls.from_single_file(\n            ckpt_path,\n            config_path,\n            device=device,\n            dtype=dtype,\n            use_safetensors=use_safetensors,\n            **kwargs\n        )\n\n    def __init__(\n        self,\n        vae,\n        model,\n        scheduler,\n        conditioner,\n        image_processor,\n        device='cuda',\n        dtype=torch.float16,\n        ref_model=None,\n        **kwargs\n    ):\n        self.vae = vae\n        self.model = model\n        self.ref_model = ref_model\n        self.scheduler = scheduler\n        self.conditioner = conditioner\n        self.image_processor = image_processor\n        self.kwargs = kwargs\n\n        self.components = {\n            \"vae\": vae,\n            \"model\": model,\n            \"scheduler\": scheduler,\n            \"conditioner\": conditioner,\n            \"image_processor\": image_processor,\n        }\n        if ref_model is not None:\n             self.components[\"ref_model\"] = ref_model\n\n        self.to(device, dtype)\n\n    def compile(self):\n        self.vae = torch.compile(self.vae)\n        self.model = torch.compile(self.model)\n        self.conditioner = torch.compile(self.conditioner)\n\n    def enable_flashvdm(\n        self,\n        enabled: bool = True,\n        adaptive_kv_selection=True,\n        topk_mode='mean',\n        mc_algo='mc',\n        replace_vae=True,\n    ):\n     
   if enabled:\n            self.vae.enable_flashvdm_decoder(\n                enabled=enabled,\n                adaptive_kv_selection=adaptive_kv_selection,\n                topk_mode=topk_mode,\n                mc_algo=mc_algo\n            )\n        else:\n            self.vae.enable_flashvdm_decoder(enabled=False)\n\n    def to(self, device=None, dtype=None):\n        if dtype is not None:\n            self.dtype = dtype\n            self.vae.to(dtype=dtype)\n            self.model.to(dtype=dtype)\n            self.conditioner.to(dtype=dtype)\n        if device is not None:\n            self.device = torch.device(device)\n            self.vae.to(device)\n            self.model.to(device)\n            self.conditioner.to(device)\n\n    @property\n    def _execution_device(self):\n        r\"\"\"\n        Returns the device on which the pipeline's models will be executed. After calling\n        [`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from\n        Accelerate's module hooks.\n        \"\"\"\n        for name, model in self.components.items():\n            if not isinstance(model, torch.nn.Module) or name in self._exclude_from_cpu_offload:\n                continue\n\n            if not hasattr(model, \"_hf_hook\"):\n                return self.device\n            for module in model.modules():\n                if (\n                    hasattr(module, \"_hf_hook\")\n                    and hasattr(module._hf_hook, \"execution_device\")\n                    and module._hf_hook.execution_device is not None\n                ):\n                    return torch.device(module._hf_hook.execution_device)\n        return self.device\n\n    def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = \"cuda\"):\n        r\"\"\"\n        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared\n        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`\n        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with\n        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.\n\n        Arguments:\n            gpu_id (`int`, *optional*):\n                The ID of the accelerator that shall be used in inference. If not specified, it will default to 0.\n            device (`torch.Device` or `str`, *optional*, defaults to \"cuda\"):\n                The PyTorch device type of the accelerator that shall be used in inference. 
If not specified, it will\n                default to \"cuda\".\n        \"\"\"\n        if self.model_cpu_offload_seq is None:\n            raise ValueError(\n                \"Model CPU offload cannot be enabled because no `model_cpu_offload_seq` class attribute is set.\"\n            )\n\n        if is_accelerate_available() and is_accelerate_version(\">=\", \"0.17.0.dev0\"):\n            from accelerate import cpu_offload_with_hook\n        else:\n            raise ImportError(\"`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.\")\n\n        torch_device = torch.device(device)\n        device_index = torch_device.index\n\n        if gpu_id is not None and device_index is not None:\n            raise ValueError(\n                f\"You have passed both `gpu_id`={gpu_id} and an index as part of the passed device `device`={device}\"\n                f\"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of \"\n                f\"the device: `device`={torch_device.type}\"\n            )\n\n        # _offload_gpu_id should be set to passed gpu_id (or id in passed `device`)\n        # or default to previously set id or default to 0\n        self._offload_gpu_id = gpu_id or torch_device.index or getattr(self, \"_offload_gpu_id\", 0)\n\n        device_type = torch_device.type\n        device = torch.device(f\"{device_type}:{self._offload_gpu_id}\")\n\n        if self.device.type != \"cpu\":\n            self.to(\"cpu\")\n            device_mod = getattr(torch, self.device.type, None)\n            if hasattr(device_mod, \"empty_cache\") and device_mod.is_available():\n                device_mod.empty_cache()  \n                # otherwise we don't see the memory savings (but they probably exist)\n\n        all_model_components = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)}\n\n        self._all_hooks = []\n        hook = None\n        for model_str in self.model_cpu_offload_seq.split(\"->\"):\n            model = all_model_components.pop(model_str, None)\n            if not isinstance(model, torch.nn.Module):\n                continue\n\n            _, hook = cpu_offload_with_hook(model, device, prev_module_hook=hook)\n            self._all_hooks.append(hook)\n\n        # CPU offload models that are not in the seq chain unless they are explicitly excluded\n        # these models will stay on CPU until maybe_free_model_hooks is called\n        # some models cannot be in the seq chain because they are iteratively called, \n        # such as controlnet\n        for name, model in all_model_components.items():\n            if not isinstance(model, torch.nn.Module):\n                continue\n\n            if name in self._exclude_from_cpu_offload:\n                model.to(device)\n            else:\n                _, hook = cpu_offload_with_hook(model, device)\n                self._all_hooks.append(hook)\n\n    def maybe_free_model_hooks(self):\n        r\"\"\"\n        Function that offloads all components, removes all model hooks that were added when using\n        `enable_model_cpu_offload` and then applies them again. In case the model has not been offloaded this function\n        is a no-op. 
Make sure to add this function to the end of the `__call__` function of your pipeline so that it\n        functions correctly when applying enable_model_cpu_offload.\n        \"\"\"\n        if not hasattr(self, \"_all_hooks\") or len(self._all_hooks) == 0:\n            # `enable_model_cpu_offload` has not be called, so silently do nothing\n            return\n\n        for hook in self._all_hooks:\n            # offload model and remove hook from model\n            hook.offload()\n            hook.remove()\n\n        # make sure the model is in the same state as before calling it\n        self.enable_model_cpu_offload()\n\n    @synchronize_timer('Encode cond')\n    def encode_cond(self, image, additional_cond_inputs, do_classifier_free_guidance, dual_guidance):\n        bsz = image.shape[0]\n        cond = self.conditioner(image=image, **additional_cond_inputs)  # cond['main'].shape\n\n        if do_classifier_free_guidance:\n            cond_token_num = cond[\"main\"].shape[1]\n            additional_cond_inputs[\"num_tokens\"] = cond_token_num\n            un_cond = self.conditioner.unconditional_embedding(bsz, **additional_cond_inputs)\n\n            if dual_guidance:\n                un_cond_drop_main = copy.deepcopy(un_cond)\n                un_cond_drop_main['additional'] = cond['additional']\n\n                def cat_recursive(a, b, c):\n                    if isinstance(a, torch.Tensor):\n                        return torch.cat([a, b, c], dim=0).to(self.dtype)\n                    out = {}\n                    for k in a.keys():\n                        out[k] = cat_recursive(a[k], b[k], c[k])\n                    return out\n\n                cond = cat_recursive(cond, un_cond_drop_main, un_cond)\n            else:\n                def cat_recursive(a, b):\n                    if isinstance(a, torch.Tensor):\n                        return torch.cat([a, b], dim=0).to(self.dtype)\n                    out = {}\n                    for k in a.keys():\n                        out[k] = cat_recursive(a[k], b[k])\n                    return out\n\n                cond = cat_recursive(cond, un_cond)\n        return cond\n\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def prepare_latents(self, batch_size, dtype, device, generator, latents=None, shape=None):\n        if shape is None:\n            shape = (batch_size, *self.vae.latent_shape)\n        else:\n            shape = (batch_size, *shape)\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n            )\n\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * getattr(self.scheduler, 'init_noise_sigma', 1.0)\n        return latents\n\n    def prepare_image(self, image, mask=None) -> dict:\n        if isinstance(image, torch.Tensor) and isinstance(mask, torch.Tensor):\n            outputs = {\n                'image': image,\n                'mask': mask\n            }\n            return outputs\n            \n        if isinstance(image, str) and not os.path.exists(image):\n            raise FileNotFoundError(f\"Couldn't find image at path {image}\")\n\n        if not isinstance(image, list):\n            image = [image]\n\n        outputs = []\n        for img in image:\n            output = self.image_processor(img)  # output['image'].shape\n            outputs.append(output)\n\n        cond_input = {k: [] for k in outputs[0].keys()}\n        for output in outputs:\n            for key, value in output.items():\n                cond_input[key].append(value)\n        for key, value in cond_input.items():\n            if isinstance(value[0], torch.Tensor):\n                cond_input[key] = torch.cat(value, dim=0)\n\n        return cond_input\n\n    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):\n        \"\"\"\n        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298\n\n        Args:\n            timesteps (`torch.Tensor`):\n                generate embedding vectors at these timesteps\n            embedding_dim (`int`, *optional*, defaults to 512):\n                dimension of the embeddings to generate\n            dtype:\n                data type of the generated embeddings\n\n        Returns:\n            `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`\n        \"\"\"\n        assert len(w.shape) == 1\n        w = w * 1000.0\n\n        half_dim = embedding_dim // 2\n        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)\n        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)\n        emb = w.to(dtype)[:, None] * emb[None, :]\n        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)\n        if embedding_dim % 2 == 1:  # zero pad\n            emb = torch.nn.functional.pad(emb, (0, 1))\n        assert emb.shape == (w.shape[0], embedding_dim)\n        return emb\n\n    def set_surface_extractor(self, mc_algo):\n        if mc_algo is None:\n            return\n        logger.info('The parameters `mc_algo` is deprecated, and will be removed in future versions.\\n'\n                    'Please use: \\n'\n                    'from hy3dshape.models.autoencoders import SurfaceExtractors\\n'\n                    'pipeline.vae.surface_extractor = SurfaceExtractors[mc_algo]() instead\\n')\n        if mc_algo not in SurfaceExtractors.keys():\n            raise ValueError(f\"Unknown mc_algo {mc_algo}\")\n        self.vae.surface_extractor = SurfaceExtractors[mc_algo]()\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        image: Union[str, List[str], Image.Image] = None,\n        num_inference_steps: int = 50,\n        timesteps: List[int] = None,\n        sigmas: List[float] = None,\n        eta: float = 
0.0,\n        guidance_scale: float = 7.5,\n        dual_guidance_scale: float = 10.5,\n        dual_guidance: bool = True,\n        generator=None,\n        box_v=1.01,\n        octree_resolution=384,\n        mc_level=-1 / 512,\n        num_chunks=8000,\n        mc_algo=None,\n        output_type: Optional[str] = \"trimesh\",\n        enable_pbar=True,\n        **kwargs,\n    ) -> List[List[trimesh.Trimesh]]:\n        callback = kwargs.pop(\"callback\", None)\n        callback_steps = kwargs.pop(\"callback_steps\", None)\n\n        self.set_surface_extractor(mc_algo)\n\n        device = self.device\n        dtype = self.dtype\n        do_classifier_free_guidance = guidance_scale >= 0 and \\\n                                      getattr(self.model, 'guidance_cond_proj_dim', None) is None\n        dual_guidance = dual_guidance_scale >= 0 and dual_guidance\n\n        if isinstance(image, torch.Tensor):\n            pass\n        else:\n            cond_inputs = self.prepare_image(image)\n            image = cond_inputs.pop('image')\n        \n        cond = self.encode_cond(\n            image=image,\n            additional_cond_inputs=cond_inputs,\n            do_classifier_free_guidance=do_classifier_free_guidance,\n            dual_guidance=False,\n        )\n        batch_size = image.shape[0]\n\n        t_dtype = torch.long\n        timesteps, num_inference_steps = retrieve_timesteps(\n            self.scheduler, num_inference_steps, device, timesteps, sigmas)\n\n        latents = self.prepare_latents(batch_size, dtype, device, generator)\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        guidance_cond = None\n        if getattr(self.model, 'guidance_cond_proj_dim', None) is not None:\n            logger.info('Using lcm guidance scale')\n            guidance_scale_tensor = torch.tensor(guidance_scale - 1).repeat(batch_size)\n            guidance_cond = self.get_guidance_scale_embedding(\n                guidance_scale_tensor, embedding_dim=self.model.guidance_cond_proj_dim\n            ).to(device=device, dtype=latents.dtype)\n        with synchronize_timer('Diffusion Sampling'):\n            for i, t in enumerate(tqdm(timesteps, disable=not enable_pbar, desc=\"Diffusion Sampling:\", leave=False)):\n                # expand the latents if we are doing classifier free guidance\n                if do_classifier_free_guidance:\n                    latent_model_input = torch.cat([latents] * (3 if dual_guidance else 2))\n                else:\n                    latent_model_input = latents\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # predict the noise residual\n                timestep_tensor = torch.tensor([t], dtype=t_dtype, device=device)\n                timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0])\n                noise_pred = self.model(latent_model_input, timestep_tensor, cond, guidance_cond=guidance_cond)\n\n                # no drop, drop clip, all drop\n                if do_classifier_free_guidance:\n                    if dual_guidance:\n                        noise_pred_clip, noise_pred_dino, noise_pred_uncond = noise_pred.chunk(3)\n                        noise_pred = (\n                            noise_pred_uncond\n                            + guidance_scale * (noise_pred_clip - noise_pred_dino)\n                            + dual_guidance_scale * (noise_pred_dino - noise_pred_uncond)\n                        )\n                    else:\n   
                     noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)\n                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                outputs = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)\n                latents = outputs.prev_sample\n\n                if callback is not None and i % callback_steps == 0:\n                    step_idx = i // getattr(self.scheduler, \"order\", 1)\n                    callback(step_idx, t, outputs)\n\n        return self._export(\n            latents,\n            output_type,\n            box_v, mc_level, num_chunks, octree_resolution, mc_algo,\n        )\n\n    def _export(\n        self,\n        latents,\n        output_type='trimesh',\n        box_v=1.01,\n        mc_level=0.0,\n        num_chunks=20000,\n        octree_resolution=256,\n        mc_algo='mc',\n        enable_pbar=True\n    ):\n        if not output_type == \"latent\":\n            latents = 1. / self.vae.scale_factor * latents\n            latents = self.vae(latents)\n            outputs, _ = self.vae.latents2mesh(\n                latents,\n                bounds=box_v,\n                mc_level=mc_level,\n                num_chunks=num_chunks,\n                octree_resolution=octree_resolution,\n                mc_algo=mc_algo,\n                enable_pbar=enable_pbar,\n            )\n        else:\n            outputs = latents\n\n        if output_type == 'trimesh':\n            outputs = export_to_trimesh(outputs)\n\n        return outputs\n\n\nclass UltraShapePipeline(DiTPipeline):\n\n    @torch.inference_mode()\n    def __call__(\n        self,\n        image: Union[str, List[str], Image.Image, dict, List[dict], torch.Tensor] = None,\n        voxel_cond: torch.Tensor = None,\n        num_inference_steps: int = 50,\n        timesteps: List[int] = None,\n        sigmas: List[float] = None,\n        eta: float = 0.0,\n        guidance_scale: float = 5.0,\n        generator=None,\n        box_v=1.01,\n        octree_resolution=384,\n        mc_level=0.0,\n        mc_algo=None,\n        num_chunks=8000,\n        output_type: Optional[str] = \"trimesh\",\n        enable_pbar=True,\n        mask = None,\n        **kwargs,\n    ) -> List[List[trimesh.Trimesh]]:\n        callback = kwargs.pop(\"callback\", None)\n        callback_steps = kwargs.pop(\"callback_steps\", None)\n\n        self.set_surface_extractor(mc_algo)\n\n        device = self.device\n        dtype = self.dtype\n        do_classifier_free_guidance = guidance_scale >= 0 and not (\n            hasattr(self.model, 'guidance_embed') and\n            self.model.guidance_embed is True\n        )\n\n        # print('image', type(image), 'mask', type(mask))\n        cond_inputs = self.prepare_image(image, mask)\n        image = cond_inputs.pop('image')\n        cond = self.encode_cond(\n            image=image,\n            additional_cond_inputs=cond_inputs,\n            do_classifier_free_guidance=do_classifier_free_guidance,\n            dual_guidance=False,\n        )\n\n        batch_size = image.shape[0]\n\n        # 5. 
Prepare timesteps\n        # NOTE: this is slightly different from common usage, we start from 0.\n        sigmas = np.linspace(0, 1, num_inference_steps) if sigmas is None else sigmas\n        timesteps, num_inference_steps = retrieve_timesteps(\n            self.scheduler,\n            num_inference_steps,\n            device,\n            sigmas=sigmas,\n        )\n        latents_shape = None\n        if voxel_cond is not None:\n            # the encoder expects voxel_cond of shape [B, N, 3]; its token count sets the latent sequence length\n            num_tokens = voxel_cond.shape[1]\n            latents_shape = (num_tokens, self.vae.latent_shape[-1])\n\n        latents = self.prepare_latents(batch_size, dtype, device, generator, shape=latents_shape)\n\n        guidance = None\n        if hasattr(self.model, 'guidance_embed') and \\\n            self.model.guidance_embed is True:\n            guidance = torch.tensor([guidance_scale] * batch_size, device=device, dtype=dtype)\n            # logger.info(f'Using guidance embed with scale {guidance_scale}')\n        if do_classifier_free_guidance and voxel_cond is not None:\n            voxel_cond = torch.cat([voxel_cond] * 2)\n        with synchronize_timer('Diffusion Sampling'):\n            for i, t in enumerate(tqdm(timesteps, disable=not enable_pbar, desc=\"Diffusion Sampling:\")):\n                # expand the latents if we are doing classifier free guidance\n                if do_classifier_free_guidance:\n                    latent_model_input = torch.cat([latents] * 2)\n                else:\n                    latent_model_input = latents\n\n                # NOTE: we assume the model takes timesteps in the range [0, 1]\n                timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype).to(latent_model_input.device)\n                timestep = timestep / self.scheduler.config.num_train_timesteps\n                if voxel_cond is None:\n                    noise_pred = self.model(latent_model_input, timestep, cond, guidance=guidance)\n                else:\n                    noise_pred = self.model(latent_model_input, timestep, cond, \n                            guidance=guidance, voxel_cond=voxel_cond)\n\n                if do_classifier_free_guidance:\n                    noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                outputs = self.scheduler.step(noise_pred, t, latents)\n                latents = outputs.prev_sample\n\n                if callback is not None and i % callback_steps == 0:\n                    step_idx = i // getattr(self.scheduler, \"order\", 1)\n                    callback(step_idx, t, outputs)\n\n        return self._export(\n            latents,\n            output_type,\n            box_v, mc_level, num_chunks, octree_resolution, mc_algo,\n            enable_pbar=enable_pbar,\n        ), latents\n"
  },
  {
    "path": "ultrashape/postprocessors.py",
    "content": "# ==============================================================================\n# Original work Copyright (c) 2025 Tencent.\n# Modified work Copyright (c) 2025 UltraShape Team.\n# \n# Modified by UltraShape on 2025.12.25\n# ==============================================================================\n\n# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT\n# except for the third-party components listed below.\n# Hunyuan 3D does not impose any additional limitations beyond what is outlined\n# in the repsective licenses of these third-party components.\n# Users must comply with all terms and conditions of original licenses of these third-party\n# components and must ensure that the usage of the third party components adheres to\n# all relevant laws and regulations.\n\n# For avoidance of doubts, Hunyuan 3D means the large language models and\n# their software and algorithms, including trained model weights, parameters (including\n# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,\n# fine-tuning enabling code and other elements of the foregoing made publicly available\n# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.\n\nimport os\nimport tempfile\nfrom typing import Union\n\nimport numpy as np\nimport pymeshlab\nimport torch\nimport trimesh\n\nfrom .models.autoencoders import Latent2MeshOutput\nfrom .utils import synchronize_timer\n\n\ndef load_mesh(path):\n    if path.endswith(\".glb\"):\n        mesh = trimesh.load(path)\n    else:\n        mesh = pymeshlab.MeshSet()\n        mesh.load_new_mesh(path)\n    return mesh\n\n\ndef reduce_face(mesh: pymeshlab.MeshSet, max_facenum: int = 200000):\n    if max_facenum > mesh.current_mesh().face_number():\n        return mesh\n\n    mesh.apply_filter(\n        \"meshing_decimation_quadric_edge_collapse\",\n        targetfacenum=max_facenum,\n        qualitythr=1.0,\n        preserveboundary=True,\n        boundaryweight=3,\n        preservenormal=True,\n        preservetopology=True,\n        autoclean=True\n    )\n    return mesh\n\n\ndef remove_floater(mesh: pymeshlab.MeshSet):\n    mesh.apply_filter(\"compute_selection_by_small_disconnected_components_per_face\",\n                      nbfaceratio=0.005)\n    mesh.apply_filter(\"compute_selection_transfer_face_to_vertex\", inclusive=False)\n    mesh.apply_filter(\"meshing_remove_selected_vertices_and_faces\")\n    return mesh\n\n\ndef pymeshlab2trimesh(mesh: pymeshlab.MeshSet):\n    with tempfile.NamedTemporaryFile(suffix='.ply', delete=False) as temp_file:\n        mesh.save_current_mesh(temp_file.name)\n        mesh = trimesh.load(temp_file.name)\n    # 检查加载的对象类型\n    if isinstance(mesh, trimesh.Scene):\n        combined_mesh = trimesh.Trimesh()\n        # 如果是Scene，遍历所有的geometry并合并\n        for geom in mesh.geometry.values():\n            combined_mesh = trimesh.util.concatenate([combined_mesh, geom])\n        mesh = combined_mesh\n    return mesh\n\n\ndef trimesh2pymeshlab(mesh: trimesh.Trimesh):\n    with tempfile.NamedTemporaryFile(suffix='.ply', delete=False) as temp_file:\n        if isinstance(mesh, trimesh.scene.Scene):\n            for idx, obj in enumerate(mesh.geometry.values()):\n                if idx == 0:\n                    temp_mesh = obj\n                else:\n                    temp_mesh = temp_mesh + obj\n            mesh = temp_mesh\n        mesh.export(temp_file.name)\n        mesh = pymeshlab.MeshSet()\n        mesh.load_new_mesh(temp_file.name)\n  
  return mesh\n\n\ndef export_mesh(input, output):\n    if isinstance(input, pymeshlab.MeshSet):\n        mesh = output\n    elif isinstance(input, Latent2MeshOutput):\n        # convert the processed MeshSet back into a Latent2MeshOutput container\n        ms = output\n        output = Latent2MeshOutput()\n        output.mesh_v = ms.current_mesh().vertex_matrix()\n        output.mesh_f = ms.current_mesh().face_matrix()\n        mesh = output\n    else:\n        mesh = pymeshlab2trimesh(output)\n    return mesh\n\n\ndef import_mesh(mesh: Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutput, str]) -> pymeshlab.MeshSet:\n    if isinstance(mesh, str):\n        mesh = load_mesh(mesh)\n    elif isinstance(mesh, Latent2MeshOutput):\n        mesh = pymeshlab.MeshSet()\n        mesh_pymeshlab = pymeshlab.Mesh(vertex_matrix=mesh.mesh_v, face_matrix=mesh.mesh_f)\n        mesh.add_mesh(mesh_pymeshlab, \"converted_mesh\")\n\n    if isinstance(mesh, (trimesh.Trimesh, trimesh.scene.Scene)):\n        mesh = trimesh2pymeshlab(mesh)\n\n    return mesh\n\n\nclass FaceReducer:\n    @synchronize_timer('FaceReducer')\n    def __call__(\n        self,\n        mesh: Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutput, str],\n        max_facenum: int = 40000\n    ) -> Union[pymeshlab.MeshSet, trimesh.Trimesh]:\n        ms = import_mesh(mesh)\n        ms = reduce_face(ms, max_facenum=max_facenum)\n        mesh = export_mesh(mesh, ms)\n        return mesh\n\n\nclass FloaterRemover:\n    @synchronize_timer('FloaterRemover')\n    def __call__(\n        self,\n        mesh: Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutput, str],\n    ) -> Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutput]:\n        ms = import_mesh(mesh)\n        ms = remove_floater(ms)\n        mesh = export_mesh(mesh, ms)\n        return mesh\n\n\nclass DegenerateFaceRemover:\n    @synchronize_timer('DegenerateFaceRemover')\n    def __call__(\n        self,\n        mesh: Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutput, str],\n    ) -> Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutput]:\n        ms = import_mesh(mesh)\n\n        with tempfile.NamedTemporaryFile(suffix='.ply', delete=False) as temp_file:\n            ms.save_current_mesh(temp_file.name)\n            ms = pymeshlab.MeshSet()\n            ms.load_new_mesh(temp_file.name)\n\n        mesh = export_mesh(mesh, ms)\n        return mesh\n\n\ndef mesh_normalize(mesh):\n    \"\"\"\n    Normalize mesh vertices to sphere\n    \"\"\"\n    scale_factor = 1.2\n    vtx_pos = np.asarray(mesh.vertices)\n    max_bb = (vtx_pos - 0).max(0)[0]\n    min_bb = (vtx_pos - 0).min(0)[0]\n\n    center = (max_bb + min_bb) / 2\n\n    scale = torch.norm(torch.tensor(vtx_pos - center, dtype=torch.float32), dim=1).max() * 2.0\n\n    vtx_pos = (vtx_pos - center) * (scale_factor / float(scale))\n    mesh.vertices = vtx_pos\n\n    return mesh\n\n\nclass MeshSimplifier:\n    def __init__(self, executable: str = None):\n        if executable is None:\n            CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))\n            executable = os.path.join(CURRENT_DIR, \"mesh_simplifier.bin\")\n        self.executable = executable\n\n    @synchronize_timer('MeshSimplifier')\n    def __call__(\n        self,\n        mesh: Union[trimesh.Trimesh],\n    ) -> Union[trimesh.Trimesh]:\n        with tempfile.NamedTemporaryFile(suffix='.obj', delete=False) as temp_input:\n            with tempfile.NamedTemporaryFile(suffix='.obj', delete=False) as temp_output:\n                mesh.export(temp_input.name)\n                os.system(f'{self.executable} {temp_input.name} {temp_output.name}')\n                ms = trimesh.load(temp_output.name, process=False)\n                if isinstance(ms, trimesh.Scene):\n                    combined_mesh = trimesh.Trimesh()\n                    for geom in ms.geometry.values():\n                        combined_mesh = trimesh.util.concatenate([combined_mesh, geom])\n                    ms = combined_mesh\n                ms = mesh_normalize(ms)\n                return ms\n"
  },
  {
    "path": "ultrashape/preprocessors.py",
    "content": "# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT\n# except for the third-party components listed below.\n# Hunyuan 3D does not impose any additional limitations beyond what is outlined\n# in the repsective licenses of these third-party components.\n# Users must comply with all terms and conditions of original licenses of these third-party\n# components and must ensure that the usage of the third party components adheres to\n# all relevant laws and regulations.\n\n# For avoidance of doubts, Hunyuan 3D means the large language models and\n# their software and algorithms, including trained model weights, parameters (including\n# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,\n# fine-tuning enabling code and other elements of the foregoing made publicly available\n# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.\n\nimport cv2\nimport numpy as np\nimport torch\nfrom PIL import Image\nfrom einops import repeat, rearrange\n\n\ndef array_to_tensor(np_array):\n    image_pt = torch.tensor(np_array).float()\n    image_pt = image_pt / 255 * 2 - 1\n    image_pt = rearrange(image_pt, \"h w c -> c h w\")\n    image_pts = repeat(image_pt, \"c h w -> b c h w\", b=1)\n    return image_pts\n\n\nclass ImageProcessorV2:\n    def __init__(self, size=512, border_ratio=None):\n        self.size = size\n        self.border_ratio = border_ratio\n\n    @staticmethod\n    def recenter(image, border_ratio: float = 0.2):\n        \"\"\" recenter an image to leave some empty space at the image border.\n\n        Args:\n            image (ndarray): input image, float/uint8 [H, W, 3/4]\n            mask (ndarray): alpha mask, bool [H, W]\n            border_ratio (float, optional): border ratio, image will be resized to (1 - border_ratio). 
Defaults to 0.2.\n\n        Returns:\n            ndarray: output image, float/uint8 [H, W, 3/4]\n        \"\"\"\n\n        if image.shape[-1] == 4:\n            mask = image[..., 3]\n        else:\n            mask = np.ones_like(image[..., 0:1]) * 255\n            image = np.concatenate([image, mask], axis=-1)\n            mask = mask[..., 0]\n\n        H, W, C = image.shape\n\n        size = max(H, W)\n        result = np.zeros((size, size, C), dtype=np.uint8)\n\n        coords = np.nonzero(mask)\n        x_min, x_max = coords[0].min(), coords[0].max()\n        y_min, y_max = coords[1].min(), coords[1].max()\n        h = x_max - x_min\n        w = y_max - y_min\n        if h == 0 or w == 0:\n            raise ValueError('input image is empty')\n        desired_size = int(size * (1 - border_ratio))\n        scale = desired_size / max(h, w)\n        h2 = int(h * scale)\n        w2 = int(w * scale)\n        x2_min = (size - h2) // 2\n        x2_max = x2_min + h2\n\n        y2_min = (size - w2) // 2\n        y2_max = y2_min + w2\n\n        result[x2_min:x2_max, y2_min:y2_max] = cv2.resize(image[x_min:x_max, y_min:y_max], (w2, h2),\n                                                          interpolation=cv2.INTER_AREA)\n\n        bg = np.ones((result.shape[0], result.shape[1], 3), dtype=np.uint8) * 255\n\n        mask = result[..., 3:].astype(np.float32) / 255\n        result = result[..., :3] * mask + bg * (1 - mask)\n\n        mask = mask * 255\n        result = result.clip(0, 255).astype(np.uint8)\n        mask = mask.clip(0, 255).astype(np.uint8)\n        return result, mask\n\n    def load_image(self, image, border_ratio=0.15, to_tensor=True):\n        if isinstance(image, str):\n            image = cv2.imread(image, cv2.IMREAD_UNCHANGED)\n            image, mask = self.recenter(image, border_ratio=border_ratio)\n            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n        elif isinstance(image, Image.Image):\n            image = image.convert(\"RGBA\")\n            image = np.asarray(image)\n            image, mask = self.recenter(image, border_ratio=border_ratio)\n\n        image = cv2.resize(image, (self.size, self.size), interpolation=cv2.INTER_CUBIC)\n        mask = cv2.resize(mask, (self.size, self.size), interpolation=cv2.INTER_NEAREST)\n        mask = mask[..., np.newaxis]\n\n        if to_tensor:\n            image = array_to_tensor(image)\n            mask = array_to_tensor(mask)\n        return image, mask\n\n    def __call__(self, image, border_ratio=0.15, to_tensor=True, **kwargs):\n        if self.border_ratio is not None:\n            border_ratio = self.border_ratio\n        image, mask = self.load_image(image, border_ratio=border_ratio, to_tensor=to_tensor)\n        outputs = {\n            'image': image,\n            'mask': mask\n        }\n        return outputs\n\n\nclass MVImageProcessorV2(ImageProcessorV2):\n    \"\"\"\n    view order: front, front clockwise 90, back, front clockwise 270\n    \"\"\"\n    return_view_idx = True\n\n    def __init__(self, size=512, border_ratio=None):\n        super().__init__(size, border_ratio)\n        self.view2idx = {\n            'front': 0,\n            'left': 1,\n            'back': 2,\n            'right': 3\n        }\n\n    def __call__(self, image_dict, border_ratio=0.15, to_tensor=True, **kwargs):\n        if self.border_ratio is not None:\n            border_ratio = self.border_ratio\n\n        images = []\n        masks = []\n        view_idxs = []\n        for idx, (view_tag, image) in 
enumerate(image_dict.items()):\n            view_idxs.append(self.view2idx[view_tag])\n            image, mask = self.load_image(image, border_ratio=border_ratio, to_tensor=to_tensor)\n            images.append(image)\n            masks.append(mask)\n\n        zipped_lists = zip(view_idxs, images, masks)\n        sorted_zipped_lists = sorted(zipped_lists)\n        view_idxs, images, masks = zip(*sorted_zipped_lists)\n\n        image = torch.cat(images, 0).unsqueeze(0)\n        mask = torch.cat(masks, 0).unsqueeze(0)\n        outputs = {\n            'image': image,\n            'mask': mask,\n            'view_idxs': view_idxs\n        }\n        return outputs\n\n\nIMAGE_PROCESSORS = {\n    \"v2\": ImageProcessorV2,\n    'mv_v2': MVImageProcessorV2,\n}\n\nDEFAULT_IMAGEPROCESSOR = 'v2'\n"
  },
  {
    "path": "ultrashape/rembg.py",
    "content": "# ==============================================================================\n# Original work Copyright (c) 2025 Tencent.\n# Modified work Copyright (c) 2025 UltraShape Team.\n# \n# Modified by UltraShape on 2025.12.25\n# ==============================================================================\n\n# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT\n# except for the third-party components listed below.\n# Hunyuan 3D does not impose any additional limitations beyond what is outlined\n# in the repsective licenses of these third-party components.\n# Users must comply with all terms and conditions of original licenses of these third-party\n# components and must ensure that the usage of the third party components adheres to\n# all relevant laws and regulations.\n\n# For avoidance of doubts, Hunyuan 3D means the large language models and\n# their software and algorithms, including trained model weights, parameters (including\n# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,\n# fine-tuning enabling code and other elements of the foregoing made publicly available\n# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.\n\nfrom PIL import Image\nfrom rembg import remove, new_session\n\n\nclass BackgroundRemover():\n    def __init__(self):\n        self.session = new_session()\n\n    def __call__(self, image: Image.Image):\n        output = remove(image, session=self.session, bgcolor=[255, 255, 255, 0])\n        return output\n"
  },
  {
    "path": "ultrashape/schedulers.py",
    "content": "# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT\n# except for the third-party components listed below.\n# Hunyuan 3D does not impose any additional limitations beyond what is outlined\n# in the repsective licenses of these third-party components.\n# Users must comply with all terms and conditions of original licenses of these third-party\n# components and must ensure that the usage of the third party components adheres to\n# all relevant laws and regulations.\n\n# For avoidance of doubts, Hunyuan 3D means the large language models and\n# their software and algorithms, including trained model weights, parameters (including\n# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,\n# fine-tuning enabling code and other elements of the foregoing made publicly available\n# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.\n\nimport math\nfrom dataclasses import dataclass\nfrom typing import List, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.schedulers.scheduling_utils import SchedulerMixin\nfrom diffusers.utils import BaseOutput, logging\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\n@dataclass\nclass FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):\n    \"\"\"\n    Output class for the scheduler's `step` function output.\n\n    Args:\n        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):\n            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the\n            denoising loop.\n    \"\"\"\n\n    prev_sample: torch.FloatTensor\n\n\nclass FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):\n    \"\"\"\n    NOTE: this is very similar to diffusers.FlowMatchEulerDiscreteScheduler. Except our timesteps are reversed\n\n    Euler scheduler.\n\n    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic\n    methods the library implements for all schedulers such as loading and saving.\n\n    Args:\n        num_train_timesteps (`int`, defaults to 1000):\n            The number of diffusion steps to train the model.\n        timestep_spacing (`str`, defaults to `\"linspace\"`):\n            The way the timesteps should be scaled. 
Refer to Table 2 of the [Common Diffusion Noise Schedules and\n            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.\n        shift (`float`, defaults to 1.0):\n            The shift value for the timestep schedule.\n    \"\"\"\n\n    _compatibles = []\n    order = 1\n\n    @register_to_config\n    def __init__(\n        self,\n        num_train_timesteps: int = 1000,\n        shift: float = 1.0,\n        use_dynamic_shifting=False,\n    ):\n        timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32).copy()\n        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)\n\n        sigmas = timesteps / num_train_timesteps\n        if not use_dynamic_shifting:\n            # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution\n            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)\n\n        self.timesteps = sigmas * num_train_timesteps\n\n        self._step_index = None\n        self._begin_index = None\n\n        self.sigmas = sigmas.to(\"cpu\")  # to avoid too much CPU/GPU communication\n        self.sigma_min = self.sigmas[-1].item()\n        self.sigma_max = self.sigmas[0].item()\n\n    @property\n    def step_index(self):\n        \"\"\"\n        The index counter for current timestep. It will increase 1 after each scheduler step.\n        \"\"\"\n        return self._step_index\n\n    @property\n    def begin_index(self):\n        \"\"\"\n        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.\n        \"\"\"\n        return self._begin_index\n\n    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index\n    def set_begin_index(self, begin_index: int = 0):\n        \"\"\"\n        Sets the begin index for the scheduler. 
This function should be run from pipeline before the inference.\n\n        Args:\n            begin_index (`int`):\n                The begin index for the scheduler.\n        \"\"\"\n        self._begin_index = begin_index\n\n    def scale_noise(\n        self,\n        sample: torch.FloatTensor,\n        timestep: Union[float, torch.FloatTensor],\n        noise: Optional[torch.FloatTensor] = None,\n    ) -> torch.FloatTensor:\n        \"\"\"\n        Forward process in flow-matching\n\n        Args:\n            sample (`torch.FloatTensor`):\n                The input sample.\n            timestep (`int`, *optional*):\n                The current timestep in the diffusion chain.\n\n        Returns:\n            `torch.FloatTensor`:\n                A scaled input sample.\n        \"\"\"\n        # Make sure sigmas and timesteps have the same device and dtype as original_samples\n        sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)\n\n        if sample.device.type == \"mps\" and torch.is_floating_point(timestep):\n            # mps does not support float64\n            schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)\n            timestep = timestep.to(sample.device, dtype=torch.float32)\n        else:\n            schedule_timesteps = self.timesteps.to(sample.device)\n            timestep = timestep.to(sample.device)\n\n        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index\n        if self.begin_index is None:\n            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep]\n        elif self.step_index is not None:\n            # add_noise is called after first denoising step (for inpainting)\n            step_indices = [self.step_index] * timestep.shape[0]\n        else:\n            # add noise is called before first denoising step to create initial latent(img2img)\n            step_indices = [self.begin_index] * timestep.shape[0]\n\n        sigma = sigmas[step_indices].flatten()\n        while len(sigma.shape) < len(sample.shape):\n            sigma = sigma.unsqueeze(-1)\n\n        sample = sigma * noise + (1.0 - sigma) * sample\n\n        return sample\n\n    def _sigma_to_t(self, sigma):\n        return sigma * self.config.num_train_timesteps\n\n    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):\n        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)\n\n    def set_timesteps(\n        self,\n        num_inference_steps: int = None,\n        device: Union[str, torch.device] = None,\n        sigmas: Optional[List[float]] = None,\n        mu: Optional[float] = None,\n    ):\n        \"\"\"\n        Sets the discrete timesteps used for the diffusion chain (to be run before inference).\n\n        Args:\n            num_inference_steps (`int`):\n                The number of diffusion steps used when generating samples with a pre-trained model.\n            device (`str` or `torch.device`, *optional*):\n                The device to which the timesteps should be moved to. 
If `None`, the timesteps are not moved.\n        \"\"\"\n\n        if self.config.use_dynamic_shifting and mu is None:\n            raise ValueError(\" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`\")\n\n        if sigmas is None:\n            self.num_inference_steps = num_inference_steps\n            timesteps = np.linspace(\n                self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps\n            )\n\n            sigmas = timesteps / self.config.num_train_timesteps\n\n        if self.config.use_dynamic_shifting:\n            sigmas = self.time_shift(mu, 1.0, sigmas)\n        else:\n            sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)\n\n        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)\n        timesteps = sigmas * self.config.num_train_timesteps\n\n        self.timesteps = timesteps.to(device=device)\n        self.sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)])\n\n        self._step_index = None\n        self._begin_index = None\n\n    def index_for_timestep(self, timestep, schedule_timesteps=None):\n        if schedule_timesteps is None:\n            schedule_timesteps = self.timesteps\n\n        indices = (schedule_timesteps == timestep).nonzero()\n\n        # The sigma index that is taken for the **very** first `step`\n        # is always the second index (or the last index if there is only 1)\n        # This way we can ensure we don't accidentally skip a sigma in\n        # case we start in the middle of the denoising schedule (e.g. for image-to-image)\n        pos = 1 if len(indices) > 1 else 0\n\n        return indices[pos].item()\n\n    def _init_step_index(self, timestep):\n        if self.begin_index is None:\n            if isinstance(timestep, torch.Tensor):\n                timestep = timestep.to(self.timesteps.device)\n            self._step_index = self.index_for_timestep(timestep)\n        else:\n            self._step_index = self._begin_index\n\n    def step(\n        self,\n        model_output: torch.FloatTensor,\n        timestep: Union[float, torch.FloatTensor],\n        sample: torch.FloatTensor,\n        s_churn: float = 0.0,\n        s_tmin: float = 0.0,\n        s_tmax: float = float(\"inf\"),\n        s_noise: float = 1.0,\n        generator: Optional[torch.Generator] = None,\n        return_dict: bool = True,\n    ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:\n        \"\"\"\n        Predict the sample from the previous timestep by reversing the SDE. 
This function propagates the diffusion\n        process from the learned model outputs (most often the predicted noise).\n\n        Args:\n            model_output (`torch.FloatTensor`):\n                The direct output from learned diffusion model.\n            timestep (`float`):\n                The current discrete timestep in the diffusion chain.\n            sample (`torch.FloatTensor`):\n                A current instance of a sample created by the diffusion process.\n            s_churn (`float`):\n            s_tmin  (`float`):\n            s_tmax  (`float`):\n            s_noise (`float`, defaults to 1.0):\n                Scaling factor for noise added to the sample.\n            generator (`torch.Generator`, *optional*):\n                A random number generator.\n            return_dict (`bool`):\n                Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or\n                tuple.\n\n        Returns:\n            [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:\n                If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is\n                returned, otherwise a tuple is returned where the first element is the sample tensor.\n        \"\"\"\n\n        if (\n            isinstance(timestep, int)\n            or isinstance(timestep, torch.IntTensor)\n            or isinstance(timestep, torch.LongTensor)\n        ):\n            raise ValueError(\n                (\n                    \"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to\"\n                    \" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass\"\n                    \" one of the `scheduler.timesteps` as a timestep.\"\n                ),\n            )\n\n        if self.step_index is None:\n            self._init_step_index(timestep)\n\n        # Upcast to avoid precision issues when computing prev_sample\n        sample = sample.to(torch.float32).to(model_output.device)\n\n        sigma = self.sigmas[self.step_index].to(model_output.device)\n        sigma_next = self.sigmas[self.step_index + 1].to(model_output.device)\n\n        prev_sample = sample + (sigma_next - sigma) * model_output\n\n        # Cast sample back to model compatible dtype\n        prev_sample = prev_sample.to(model_output.dtype)\n\n        # upon completion increase step index by one\n        self._step_index += 1\n\n        if not return_dict:\n            return (prev_sample,)\n\n        return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)\n\n    def __len__(self):\n        return self.config.num_train_timesteps\n\n\n@dataclass\nclass ConsistencyFlowMatchEulerDiscreteSchedulerOutput(BaseOutput):\n    prev_sample: torch.FloatTensor\n    pred_original_sample: torch.FloatTensor\n\n\nclass ConsistencyFlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):\n    _compatibles = []\n    order = 1\n\n    @register_to_config\n    def __init__(\n        self,\n        num_train_timesteps: int = 1000,\n        pcm_timesteps: int = 50,\n    ):\n        sigmas = np.linspace(0, 1, num_train_timesteps)\n        step_ratio = num_train_timesteps // pcm_timesteps\n\n        euler_timesteps = (np.arange(1, pcm_timesteps) * step_ratio).round().astype(np.int64) - 1\n        euler_timesteps = np.asarray([0] + euler_timesteps.tolist())\n\n        self.euler_timesteps = euler_timesteps\n        self.sigmas = sigmas[self.euler_timesteps]\n        
self.sigmas = torch.from_numpy((self.sigmas.copy())).to(dtype=torch.float32)\n        self.timesteps = self.sigmas * num_train_timesteps\n        self._step_index = None\n        self._begin_index = None\n        self.sigmas = self.sigmas.to(\"cpu\")  # to avoid too much CPU/GPU communication\n\n    @property\n    def step_index(self):\n        \"\"\"\n        The index counter for current timestep. It will increase 1 after each scheduler step.\n        \"\"\"\n        return self._step_index\n\n    @property\n    def begin_index(self):\n        \"\"\"\n        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.\n        \"\"\"\n        return self._begin_index\n\n    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index\n    def set_begin_index(self, begin_index: int = 0):\n        \"\"\"\n        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.\n\n        Args:\n            begin_index (`int`):\n                The begin index for the scheduler.\n        \"\"\"\n        self._begin_index = begin_index\n\n    def _sigma_to_t(self, sigma):\n        return sigma * self.config.num_train_timesteps\n\n    def set_timesteps(\n        self,\n        num_inference_steps: int = None,\n        device: Union[str, torch.device] = None,\n        sigmas: Optional[List[float]] = None,\n    ):\n        \"\"\"\n        Sets the discrete timesteps used for the diffusion chain (to be run before inference).\n\n        Args:\n            num_inference_steps (`int`):\n                The number of diffusion steps used when generating samples with a pre-trained model.\n            device (`str` or `torch.device`, *optional*):\n                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n        \"\"\"\n        self.num_inference_steps = num_inference_steps if num_inference_steps is not None else len(sigmas)\n        inference_indices = np.linspace(\n            0, self.config.pcm_timesteps, num=self.num_inference_steps, endpoint=False\n        )\n        inference_indices = np.floor(inference_indices).astype(np.int64)\n        inference_indices = torch.from_numpy(inference_indices).long()\n\n        self.sigmas_ = self.sigmas[inference_indices]\n        timesteps = self.sigmas_ * self.config.num_train_timesteps\n        self.timesteps = timesteps.to(device=device)\n        self.sigmas_ = torch.cat(\n            [self.sigmas_, torch.ones(1, device=self.sigmas_.device)]\n        )\n\n        self._step_index = None\n        self._begin_index = None\n\n    def index_for_timestep(self, timestep, schedule_timesteps=None):\n        if schedule_timesteps is None:\n            schedule_timesteps = self.timesteps\n\n        indices = (schedule_timesteps == timestep).nonzero()\n\n        # The sigma index that is taken for the **very** first `step`\n        # is always the second index (or the last index if there is only 1)\n        # This way we can ensure we don't accidentally skip a sigma in\n        # case we start in the middle of the denoising schedule (e.g. 
for image-to-image)\n        pos = 1 if len(indices) > 1 else 0\n\n        return indices[pos].item()\n\n    def _init_step_index(self, timestep):\n        if self.begin_index is None:\n            if isinstance(timestep, torch.Tensor):\n                timestep = timestep.to(self.timesteps.device)\n            self._step_index = self.index_for_timestep(timestep)\n        else:\n            self._step_index = self._begin_index\n\n    def step(\n        self,\n        model_output: torch.FloatTensor,\n        timestep: Union[float, torch.FloatTensor],\n        sample: torch.FloatTensor,\n        generator: Optional[torch.Generator] = None,\n        return_dict: bool = True,\n    ) -> Union[ConsistencyFlowMatchEulerDiscreteSchedulerOutput, Tuple]:\n        if (\n            isinstance(timestep, int)\n            or isinstance(timestep, torch.IntTensor)\n            or isinstance(timestep, torch.LongTensor)\n        ):\n            raise ValueError(\n                (\n                    \"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to\"\n                    \" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass\"\n                    \" one of the `scheduler.timesteps` as a timestep.\"\n                ),\n            )\n\n        if self.step_index is None:\n            self._init_step_index(timestep)\n\n        sample = sample.to(torch.float32).to(model_output.device)\n\n        sigma = self.sigmas_[self.step_index].to(model_output.device)\n        sigma_next = self.sigmas_[self.step_index + 1].to(model_output.device)\n\n        prev_sample = sample + (sigma_next - sigma) * model_output\n        prev_sample = prev_sample.to(model_output.dtype)\n\n        pred_original_sample = sample + (1.0 - sigma) * model_output\n        pred_original_sample = pred_original_sample.to(model_output.dtype)\n\n        self._step_index += 1\n\n        if not return_dict:\n            return (prev_sample,)\n\n        return ConsistencyFlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample,\n                                                                pred_original_sample=pred_original_sample)\n\n    def __len__(self):\n        return self.config.num_train_timesteps\n"
  },
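  {
    "path": "examples/flow_match_euler_step_demo.py",
    "content": "# Hypothetical usage sketch, not part of the original repository: the file name\n# and the toy velocity model are assumptions. It only illustrates the Euler\n# flow-matching update used by ConsistencyFlowMatchEulerDiscreteScheduler.step():\n#     prev_sample          = sample + (sigma_next - sigma) * model_output\n#     pred_original_sample = sample + (1.0 - sigma) * model_output\n# with a hand-made linear sigma schedule standing in for the trained one.\nimport torch\n\nTARGET = torch.tensor([1.0, -2.0, 0.5])  # assumed clean sample for the demo\n\n\ndef velocity_model(x, sigma):\n    # ideal flow-matching velocity for a straight path ending at TARGET:\n    # v = (x_clean - x_sigma) / (1 - sigma)\n    return (TARGET - x) / (1.0 - sigma)\n\n\ndef euler_sample(num_steps=8, seed=0):\n    g = torch.Generator().manual_seed(seed)\n    x = torch.randn(3, generator=g)  # start from noise at sigma = 0\n    sigmas = torch.linspace(0.0, 1.0, num_steps + 1)\n    for i in range(num_steps):\n        sigma, sigma_next = sigmas[i], sigmas[i + 1]\n        v = velocity_model(x, sigma)\n        x = x + (sigma_next - sigma) * v  # same update as scheduler.step()\n    return x\n\n\nif __name__ == '__main__':\n    print(euler_sample())  # converges to TARGET under the ideal velocity field\n"
  },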
  {
    "path": "ultrashape/surface_loaders.py",
    "content": "# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT\n# except for the third-party components listed below.\n# Hunyuan 3D does not impose any additional limitations beyond what is outlined\n# in the repsective licenses of these third-party components.\n# Users must comply with all terms and conditions of original licenses of these third-party\n# components and must ensure that the usage of the third party components adheres to\n# all relevant laws and regulations.\n\n# For avoidance of doubts, Hunyuan 3D means the large language models and\n# their software and algorithms, including trained model weights, parameters (including\n# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,\n# fine-tuning enabling code and other elements of the foregoing made publicly available\n# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.\n\nimport numpy as np\nimport torch\nimport trimesh\n\n\ndef normalize_mesh(mesh, scale=0.9999):\n    \"\"\"\n    Normalize the mesh to fit inside a centered cube with a specified scale.\n\n    The mesh is translated so that its bounding box center is at the origin,\n    then uniformly scaled so that the longest side of the bounding box fits within [-scale, scale].\n\n    Args:\n        mesh (trimesh.Trimesh): Input mesh to normalize.\n        scale (float, optional): Scaling factor to slightly shrink the mesh inside the unit cube. Default is 0.9999.\n\n    Returns:\n        trimesh.Trimesh: The normalized mesh with applied translation and scaling.\n    \"\"\"\n    bbox = mesh.bounds\n    center = (bbox[1] + bbox[0]) / 2\n    scale_ = (bbox[1] - bbox[0]).max()\n\n    mesh.apply_translation(-center)\n    mesh.apply_scale(1 / scale_ * 2 * scale)\n\n    return mesh\n\n\ndef sample_pointcloud(mesh, num=200000):\n    \"\"\"\n    Sample points uniformly from the surface of the mesh along with their corresponding face normals.\n\n    Args:\n        mesh (trimesh.Trimesh): Input mesh to sample from.\n        num (int, optional): Number of points to sample. Default is 200000.\n\n    Returns:\n        Tuple[torch.Tensor, torch.Tensor]:\n            - points: Sampled points as a float tensor of shape (num, 3).\n            - normals: Corresponding normals as a float tensor of shape (num, 3).\n    \"\"\"\n    points, face_idx = mesh.sample(num, return_index=True)\n    normals = mesh.face_normals[face_idx]\n    points = torch.from_numpy(points.astype(np.float32))\n    normals = torch.from_numpy(normals.astype(np.float32))\n    return points, normals\n\n\ndef load_surface(mesh, num_points=8192):\n    \"\"\"\n    Normalize the mesh, sample points and normals from its surface, and randomly select a subset.\n\n    Args:\n        mesh (trimesh.Trimesh): Input mesh to process.\n        num_points (int, optional): Number of points to randomly select \n                from the sampled surface points. 
Default is 8192.\n\n    Returns:\n        Tuple[torch.Tensor, trimesh.Trimesh]:\n            - surface: Tensor of shape (1, num_points, 6), concatenating points and normals.\n            - mesh: The normalized mesh.\n    \"\"\"\n\n    mesh = normalize_mesh(mesh, scale=0.98)\n    surface, normal = sample_pointcloud(mesh)\n\n    rng = np.random.default_rng(seed=0)\n    ind = rng.choice(surface.shape[0], num_points, replace=False)\n    surface = torch.FloatTensor(surface[ind])\n    normal = torch.FloatTensor(normal[ind])\n\n    surface = torch.cat([surface, normal], dim=-1).unsqueeze(0)\n\n    return surface, mesh\n\n\ndef sharp_sample_pointcloud(mesh, num=16384):\n    \"\"\"\n    Sample points and normals preferentially from sharp edges of the mesh.\n\n    Sharp edges are detected based on the angle between vertex normals and face normals.\n    Points are sampled along these edges proportionally to edge length.\n\n    Args:\n        mesh (trimesh.Trimesh): Input mesh to sample from.\n        num (int, optional): Number of points to sample from sharp edges. Default is 16384.\n\n    Returns:\n        Tuple[np.ndarray, np.ndarray]:\n            - samples: Sampled points along sharp edges, shape (num, 3).\n            - normals: Corresponding interpolated normals, shape (num, 3).\n    \"\"\"\n    V = mesh.vertices\n    N = mesh.face_normals\n    VN = mesh.vertex_normals\n    F = mesh.faces\n    VN2 = np.ones(V.shape[0])\n    for i in range(3):\n        dot = np.stack((VN2[F[:, i]], np.sum(VN[F[:, i]] * N, axis=-1)), axis=-1)\n        VN2[F[:, i]] = np.min(dot, axis=-1)\n\n    sharp_mask = VN2 < 0.985\n    # collect edge\n    edge_a = np.concatenate((F[:, 0], F[:, 1], F[:, 2]))\n    edge_b = np.concatenate((F[:, 1], F[:, 2], F[:, 0]))\n    sharp_edge = ((sharp_mask[edge_a] * sharp_mask[edge_b]))\n    edge_a = edge_a[sharp_edge > 0]\n    edge_b = edge_b[sharp_edge > 0]\n\n    sharp_verts_a = V[edge_a]\n    sharp_verts_b = V[edge_b]\n    sharp_verts_an = VN[edge_a]\n    sharp_verts_bn = VN[edge_b]\n\n    weights = np.linalg.norm(sharp_verts_b - sharp_verts_a, axis=-1)\n    weights /= np.sum(weights)\n\n    random_number = np.random.rand(num)\n    w = np.random.rand(num, 1)\n    index = np.searchsorted(weights.cumsum(), random_number)\n    samples = w * sharp_verts_a[index] + (1 - w) * sharp_verts_b[index]\n    normals = w * sharp_verts_an[index] + (1 - w) * sharp_verts_bn[index]\n    return samples, normals\n\n\ndef load_surface_sharpegde(mesh, num_points=4096, num_sharp_points=4096, sharpedge_flag=True, normalize_scale=0.9999):\n    try:\n        mesh_full = trimesh.util.concatenate(mesh.dump())\n    except Exception as err:\n        mesh_full = trimesh.util.concatenate(mesh)\n    mesh_full = normalize_mesh(mesh_full, scale=normalize_scale)\n\n    origin_num = mesh_full.faces.shape[0]\n    original_vertices = mesh_full.vertices\n    original_faces = mesh_full.faces\n\n    mesh = trimesh.Trimesh(vertices=original_vertices, faces=original_faces[:origin_num])\n    mesh_fill = trimesh.Trimesh(vertices=original_vertices, faces=original_faces[origin_num:])\n    area = mesh.area\n    area_fill = mesh_fill.area\n    sample_num = 819200 // 2 # 499712 // 2\n    num_fill = int(sample_num * (area_fill / (area + area_fill)))\n    num = sample_num - num_fill\n\n    random_surface, random_normal = sample_pointcloud(mesh, num=num)\n    if num_fill == 0:\n        random_surface_fill, random_normal_fill = np.zeros((0, 3)), np.zeros((0, 3))\n    else:\n        random_surface_fill, random_normal_fill = 
sample_pointcloud(mesh_fill, num=num_fill)\n    random_sharp_surface, sharp_normal = sharp_sample_pointcloud(mesh, num=sample_num)\n\n    # save_surface\n    surface = np.concatenate((random_surface, random_normal), axis=1).astype(np.float16)\n    surface_fill = np.concatenate((random_surface_fill, random_normal_fill), axis=1).astype(np.float16)\n    sharp_surface = np.concatenate((random_sharp_surface, sharp_normal), axis=1).astype(np.float16)\n    surface = np.concatenate((surface, surface_fill), axis=0)\n    if sharpedge_flag:\n        sharpedge_label = np.zeros((surface.shape[0], 1))\n        surface = np.concatenate((surface, sharpedge_label), axis=1)\n        sharpedge_label = np.ones((sharp_surface.shape[0], 1))\n        sharp_surface = np.concatenate((sharp_surface, sharpedge_label), axis=1)\n    rng = np.random.default_rng()\n    ind = rng.choice(surface.shape[0], num_points, replace=False)\n    surface = torch.FloatTensor(surface[ind])\n    ind = rng.choice(sharp_surface.shape[0], num_sharp_points, replace=False)\n    sharp_surface = torch.FloatTensor(sharp_surface[ind])\n\n    return torch.cat([surface, sharp_surface], dim=0).unsqueeze(0), mesh_full\n\n\nclass SurfaceLoader:\n    def __init__(self, num_points=8192):\n        self.num_points = num_points\n\n    def __call__(self, mesh_or_mesh_path, num_points=None):\n        if num_points is None:\n            num_points = self.num_points\n\n        mesh = mesh_or_mesh_path\n        if isinstance(mesh, str):\n            mesh = trimesh.load(mesh, force=\"mesh\", merge_primitives=True)\n        if isinstance(mesh, trimesh.scene.Scene):\n            for idx, obj in enumerate(mesh.geometry.values()):\n                if idx == 0:\n                    temp_mesh = obj\n                else:\n                    temp_mesh = temp_mesh + obj\n            mesh = temp_mesh\n        surface, mesh = load_surface(mesh, num_points=num_points)\n        return surface\n\n\nclass SharpEdgeSurfaceLoader:\n    def __init__(self, num_uniform_points=8192, num_sharp_points=8192, **kwargs):\n        self.num_uniform_points = num_uniform_points\n        self.num_sharp_points = num_sharp_points\n        self.num_points = num_uniform_points + num_sharp_points\n\n    def __call__(self, mesh_or_mesh_path, num_uniform_points=None, \n        num_sharp_points=None, normalize_scale=0.9999):\n        if num_uniform_points is None:\n            num_uniform_points = self.num_uniform_points\n        if num_sharp_points is None:\n            num_sharp_points = self.num_sharp_points\n\n        mesh = mesh_or_mesh_path\n        if isinstance(mesh, str):\n            mesh = trimesh.load(mesh, force=\"mesh\", merge_primitives=True)\n        if isinstance(mesh, trimesh.scene.Scene):\n            for idx, obj in enumerate(mesh.geometry.values()):\n                if idx == 0:\n                    temp_mesh = obj\n                else:\n                    temp_mesh = temp_mesh + obj\n            mesh = temp_mesh\n        surface, mesh = load_surface_sharpegde(mesh, num_points=num_uniform_points, \n            num_sharp_points=num_sharp_points, normalize_scale=normalize_scale)\n        return surface\n"
  },
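  {
    "path": "examples/surface_loader_demo.py",
    "content": "# Hypothetical usage sketch, not part of the original repository: the file name\n# and the trimesh primitives are assumptions made only to show how the loaders\n# in ultrashape/surface_loaders.py are called.\nimport trimesh\n\nfrom ultrashape.surface_loaders import SurfaceLoader, SharpEdgeSurfaceLoader\n\nif __name__ == '__main__':\n    # uniform surface sampling: (1, num_points, 6) = xyz + normal per point\n    loader = SurfaceLoader(num_points=8192)\n    sphere = trimesh.creation.icosphere(subdivisions=4)\n    surface = loader(sphere)\n    print(surface.shape)  # torch.Size([1, 8192, 6])\n\n    # sharp-edge-aware sampling adds a 7th channel flagging sharp-edge points\n    sharp_loader = SharpEdgeSurfaceLoader(num_uniform_points=4096, num_sharp_points=4096)\n    box = trimesh.creation.box(extents=(1.0, 1.0, 1.0))\n    sharp_surface = sharp_loader(box)\n    print(sharp_surface.shape)  # torch.Size([1, 8192, 7])\n"
  },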
  {
    "path": "ultrashape/utils/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n\nfrom .misc import get_config_from_file\nfrom .misc import instantiate_from_config\nfrom .utils import get_logger, logger, synchronize_timer, smart_load_model\nfrom .voxelize import voxelize_from_point\n"
  },
  {
    "path": "ultrashape/utils/ema.py",
    "content": "import torch\nfrom torch import nn\n\n\nclass LitEma(nn.Module):\n    def __init__(self, model, decay=0.9999, use_num_updates=True):\n        super().__init__()\n        if decay < 0.0 or decay > 1.0:\n            raise ValueError('Decay must be between 0 and 1')\n\n        self.m_name2s_name = {}\n        self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))\n        self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_updates\n        else torch.tensor(-1, dtype=torch.int))\n\n        for name, p in model.named_parameters():\n            if p.requires_grad:\n                # remove as '.'-character is not allowed in buffers\n                s_name = name.replace('.', '_____')\n                self.m_name2s_name.update({name: s_name})\n                self.register_buffer(s_name, p.clone().detach().data)\n\n        self.collected_params = []\n\n    def forward(self, model):\n        decay = self.decay\n\n        if self.num_updates >= 0:\n            self.num_updates += 1\n            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))\n\n        one_minus_decay = 1.0 - decay\n\n        with torch.no_grad():\n            m_param = dict(model.named_parameters())\n            shadow_params = dict(self.named_buffers())\n\n            for key in m_param:\n                if m_param[key].requires_grad:\n                    sname = self.m_name2s_name[key]\n                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])\n                    shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))\n                else:\n                    assert not key in self.m_name2s_name\n\n    def copy_to(self, model):\n        m_param = dict(model.named_parameters())\n        shadow_params = dict(self.named_buffers())\n        for key in m_param:\n            if m_param[key].requires_grad:\n                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)\n            else:\n                assert not key in self.m_name2s_name\n\n    def store(self, model):\n        \"\"\"\n        Save the current parameters for restoring later.\n        Args:\n          parameters: Iterable of `torch.nn.Parameter`; the parameters to be\n            temporarily stored.\n        \"\"\"\n        self.collected_params = [param.clone() for param in model.parameters()]\n\n    def restore(self, model):\n        \"\"\"\n        Restore the parameters stored with the `store` method.\n        Useful to validate the model with EMA parameters without affecting the\n        original optimization process. Store the parameters before the\n        `copy_to` method. After validation (or model saving), use this to\n        restore the former parameters.\n        Args:\n          parameters: Iterable of `torch.nn.Parameter`; the parameters to be\n            updated with the stored parameters.\n        \"\"\"\n        for c_param, param in zip(self.collected_params, model.parameters()):\n            param.data.copy_(c_param.data)\n"
  },
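  {
    "path": "examples/ema_usage_demo.py",
    "content": "# Hypothetical usage sketch, not part of the original repository: the toy model,\n# data and file name are assumptions. It shows the intended call pattern of\n# LitEma from ultrashape/utils/ema.py.\nimport torch\nfrom torch import nn\n\nfrom ultrashape.utils.ema import LitEma\n\nif __name__ == '__main__':\n    model = nn.Linear(4, 2)\n    ema = LitEma(model, decay=0.999)\n    opt = torch.optim.SGD(model.parameters(), lr=1e-2)\n\n    for _ in range(10):\n        x = torch.randn(8, 4)\n        loss = model(x).pow(2).mean()\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n        ema(model)  # update the shadow (EMA) weights after each optimizer step\n\n    ema.store(model)    # stash the raw training weights\n    ema.copy_to(model)  # evaluate / export with the EMA weights\n    print(model.weight[0])\n    ema.restore(model)  # put the raw training weights back\n"
  },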
  {
    "path": "ultrashape/utils/misc.py",
    "content": "# -*- coding: utf-8 -*-\n\nimport importlib\nfrom omegaconf import OmegaConf, DictConfig, ListConfig\n\nimport torch\nimport torch.distributed as dist\nfrom typing import Union\nfrom .utils import logger\nimport os\n\n\ndef get_config_from_file(config_file: str) -> Union[DictConfig, ListConfig]:\n    config_file = OmegaConf.load(config_file)\n\n    if 'base_config' in config_file.keys():\n        if config_file['base_config'] == \"default_base\":\n            base_config = OmegaConf.create()\n            # base_config = get_default_config()\n        elif config_file['base_config'].endswith(\".yaml\"):\n            base_config = get_config_from_file(config_file['base_config'])\n        else:\n            raise ValueError(f\"{config_file} must be `.yaml` file or it contains `base_config` key.\")\n\n        config_file = {key: value for key, value in config_file if key != \"base_config\"}\n\n        return OmegaConf.merge(base_config, config_file)\n\n    return config_file\n\n\ndef get_obj_from_str(string, reload=False):\n    module, cls = string.rsplit(\".\", 1)\n    if reload:\n        module_imp = importlib.import_module(module)\n        importlib.reload(module_imp)\n    return getattr(importlib.import_module(module, package=None), cls)\n\n\ndef get_obj_from_config(config):\n    if \"target\" not in config:\n        raise KeyError(\"Expected key `target` to instantiate.\")\n\n    return get_obj_from_str(config[\"target\"])\n\n\ndef instantiate_from_config(config, **kwargs):\n    if \"target\" not in config:\n        raise KeyError(\"Expected key `target` to instantiate.\")\n\n    cls = get_obj_from_str(config[\"target\"])\n\n    if config.get(\"from_pretrained\", None):\n        return cls.from_pretrained(\n                    config[\"from_pretrained\"], \n                    use_safetensors=config.get('use_safetensors', False),\n                    variant=config.get('variant', 'fp16'))\n\n    params = config.get(\"params\", dict())\n    # params.update(kwargs)\n    # instance = cls(**params)\n    kwargs.update(params)\n    instance = cls(**kwargs)\n\n    return instance\n\n\ndef instantiate_vae_from_config(config, **kwargs):\n    if \"target\" not in config:\n        raise KeyError(\"Expected key `target` to instantiate.\")\n\n    cls = get_obj_from_str(config[\"target\"])\n\n    if config.get(\"from_pretrained\", None):\n        return cls.from_pretrained(\n                    config[\"from_pretrained\"], \n                    params=config.get(\"params\", dict()),\n                    use_safetensors=config.get('use_safetensors', False),\n                    variant=config.get('variant', 'fp16'))\n\n    params = config.get(\"params\", dict())\n    kwargs.update(params)\n    instance = cls(**kwargs)\n\n    return instance\n\ndef instantiate_vae_from_config_local(config, **kwargs):\n    if \"target\" not in config:\n        raise KeyError(\"Expected key `target` to instantiate.\")\n\n    cls = get_obj_from_str(config[\"target\"])\n\n    if not config.get(\"from_pretrained\", None):\n        raise FileNotFoundError(f\"Need from_pretrained!\")\n    \n    ckpt_path = config[\"from_pretrained\"]\n            \n    logger.info(f\"Loading model from {ckpt_path}\")\n    if not os.path.exists(ckpt_path):\n        raise FileNotFoundError(f\"Model file {ckpt_path} not found\")\n    ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)\n\n    if 'state_dict' not in ckpt:\n        # deepspeed ckpt\n        state_dict = {}\n        for k in ckpt.keys():\n            
new_k = k.replace('vae_model.', '')\n            state_dict[new_k] = ckpt[k]\n    else:\n        state_dict = ckpt[\"state_dict\"]\n\n    params = config.get(\"params\", dict())\n    kwargs.update(params)\n    instance = cls(**kwargs)\n\n\n    missing, unexpected = instance.load_state_dict(state_dict)\n    print(f\"VAE Missing Keys: {missing}\")\n    print(f\"VAE Unexpected Keys: {unexpected}\")\n\n    return instance\n\ndef disabled_train(self, mode=True):\n    \"\"\"Overwrite model.train with this function to make sure train/eval mode\n    does not change anymore.\"\"\"\n    return self\n\n\ndef instantiate_non_trainable_model(config):\n    model = instantiate_from_config(config)\n    model = model.eval()\n    model.train = disabled_train\n    for param in model.parameters():\n        param.requires_grad = False\n\n    return model\n\n\ndef instantiate_vae_model(config, requires_grad=False):\n    model = instantiate_vae_from_config(config)\n    model = model.eval()\n    model.train = disabled_train\n    for param in model.parameters():\n        param.requires_grad = requires_grad\n\n    return model\n\ndef instantiate_vae_model_local(config, requires_grad=False):\n    model = instantiate_vae_from_config_local(config)\n    model = model.eval()\n    model.train = disabled_train\n    for param in model.parameters():\n        param.requires_grad = requires_grad\n\n    return model\n\ndef is_dist_avail_and_initialized():\n    if not dist.is_available():\n        return False\n    if not dist.is_initialized():\n        return False\n    return True\n\n\ndef get_rank():\n    if not is_dist_avail_and_initialized():\n        return 0\n    return dist.get_rank()\n\n\ndef get_world_size():\n    if not is_dist_avail_and_initialized():\n        return 1\n    return dist.get_world_size()\n\n\ndef all_gather_batch(tensors):\n    \"\"\"\n    Performs all_gather operation on the provided tensors.\n    \"\"\"\n    # Queue the gathered tensors\n    world_size = get_world_size()\n    # There is no need for reduction in the single-proc case\n    if world_size == 1:\n        return tensors\n    tensor_list = []\n    output_tensor = []\n    for tensor in tensors:\n        tensor_all = [torch.ones_like(tensor) for _ in range(world_size)]\n        dist.all_gather(\n            tensor_all,\n            tensor,\n            async_op=False  # performance opt\n        )\n\n        tensor_list.append(tensor_all)\n\n    for tensor_all in tensor_list:\n        output_tensor.append(torch.cat(tensor_all, dim=0))\n    return output_tensor\n"
  },
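  {
    "path": "examples/instantiate_from_config_demo.py",
    "content": "# Hypothetical usage sketch, not part of the original repository: the file name\n# and the torch.nn.Linear target are assumptions. It shows the `target`/`params`\n# config convention handled by ultrashape/utils/misc.py.\nfrom omegaconf import OmegaConf\n\nfrom ultrashape.utils.misc import instantiate_from_config, instantiate_non_trainable_model\n\nif __name__ == '__main__':\n    config = OmegaConf.create({\n        'target': 'torch.nn.Linear',\n        'params': {'in_features': 8, 'out_features': 4},\n    })\n\n    layer = instantiate_from_config(config)\n    print(layer)  # Linear(in_features=8, out_features=4, bias=True)\n\n    # same config, but frozen and forced into eval mode\n    frozen = instantiate_non_trainable_model(config)\n    print(any(p.requires_grad for p in frozen.parameters()))  # False\n"
  },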
  {
    "path": "ultrashape/utils/trainings/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n"
  },
  {
    "path": "ultrashape/utils/trainings/callback.py",
    "content": "# ------------------------------------------------------------------------------------\n# Modified from Taming Transformers (https://github.com/CompVis/taming-transformers)\n# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.\n# ------------------------------------------------------------------------------------\n\nimport os\nimport time\nimport wandb\nimport numpy as np\nfrom PIL import Image\nfrom pathlib import Path\nfrom omegaconf import OmegaConf, DictConfig\nfrom typing import Tuple, Generic, Dict, Callable, Optional, Any\nfrom pprint import pprint\n\nimport torch\nimport torchvision\nimport pytorch_lightning as pl\nimport pytorch_lightning.loggers\nfrom pytorch_lightning.loggers import WandbLogger\nfrom pytorch_lightning.loggers.logger import DummyLogger\nfrom pytorch_lightning.utilities import rank_zero_only, rank_zero_info\nfrom pytorch_lightning.callbacks import Callback\n\nfrom functools import wraps\n\ndef node_zero_only(fn: Callable) -> Callable:\n    @wraps(fn)\n    def wrapped_fn(*args, **kwargs) -> Optional[Any]:\n        if node_zero_only.node == 0:\n            return fn(*args, **kwargs)\n        return None\n    return wrapped_fn\n\nnode_zero_only.node = getattr(node_zero_only, 'node', int(os.environ.get('NODE_RANK', 0)))\n\ndef node_zero_experiment(fn: Callable) -> Callable:\n    \"\"\"Returns the real experiment on rank 0 and otherwise the DummyExperiment.\"\"\"\n    @wraps(fn)\n    def experiment(self):\n        @node_zero_only\n        def get_experiment():\n            return fn(self)\n        return get_experiment() or DummyLogger.experiment\n    return experiment\n\n# customize wandb for node 0 only\nclass MyWandbLogger(WandbLogger):\n    @WandbLogger.experiment.getter\n    @node_zero_experiment\n    def experiment(self):\n        return super().experiment\n\nclass SetupCallback(Callback):\n    def __init__(self, config: DictConfig, exp_config: DictConfig,\n                 basedir: Path, logdir: str = \"log\", ckptdir: str = \"ckpt\") -> None:\n        super().__init__()\n        self.logdir = basedir / logdir\n        self.ckptdir = basedir / ckptdir\n        self.config = config\n        self.exp_config = exp_config\n\n    # def on_pretrain_routine_start(self, trainer: pl.trainer.Trainer, pl_module: pl.LightningModule) -> None:\n    #     if trainer.global_rank == 0:\n    #         # Create logdirs and save configs\n    #         os.makedirs(self.logdir, exist_ok=True)\n    #         os.makedirs(self.ckptdir, exist_ok=True)\n    #\n    #         print(\"Experiment config\")\n    #         print(self.exp_config.pretty())\n    #\n    #         print(\"Model config\")\n    #         print(self.config.pretty())\n\n    def on_fit_start(self, trainer: pl.trainer.Trainer, pl_module: pl.LightningModule) -> None:\n        if trainer.global_rank == 0:\n            # Create logdirs and save configs\n            os.makedirs(self.logdir, exist_ok=True)\n            os.makedirs(self.ckptdir, exist_ok=True)\n\n            # print(\"Experiment config\")\n            # pprint(self.exp_config)\n            #\n            # print(\"Model config\")\n            # pprint(self.config)\n\n\nclass ImageLogger(Callback):\n    def __init__(self, batch_frequency: int, max_images: int, clamp: bool = True,\n                 increase_log_steps: bool = True) -> None:\n\n        super().__init__()\n        self.batch_freq = batch_frequency\n        self.max_images = max_images\n        self.logger_log_images = {\n            
pl.loggers.WandbLogger: self._wandb,\n            pl.loggers.TestTubeLogger: self._testtube,\n        }\n        self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]\n        if not increase_log_steps:\n            self.log_steps = [self.batch_freq]\n        self.clamp = clamp\n\n    @rank_zero_only\n    def _wandb(self, pl_module, images, batch_idx, split):\n        # raise ValueError(\"No way wandb\")\n        grids = dict()\n        for k in images:\n            grid = torchvision.utils.make_grid(images[k])\n            grids[f\"{split}/{k}\"] = wandb.Image(grid)\n        pl_module.logger.experiment.log(grids)\n\n    @rank_zero_only\n    def _testtube(self, pl_module, images, batch_idx, split):\n        for k in images:\n            grid = torchvision.utils.make_grid(images[k])\n            grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w\n\n            tag = f\"{split}/{k}\"\n            pl_module.logger.experiment.add_image(\n                tag, grid,\n                global_step=pl_module.global_step)\n\n    @rank_zero_only\n    def log_local(self, save_dir: str, split: str, images: Dict,\n                  global_step: int, current_epoch: int, batch_idx: int) -> None:\n        root = os.path.join(save_dir, \"results\", split)\n        os.makedirs(root, exist_ok=True)\n        for k in images:\n            grid = torchvision.utils.make_grid(images[k], nrow=4)\n\n            grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)\n            grid = grid.numpy()\n            grid = (grid * 255).astype(np.uint8)\n            filename = \"{}_gs-{:06}_e-{:06}_b-{:06}.png\".format(\n                k,\n                global_step,\n                current_epoch,\n                batch_idx)\n            path = os.path.join(root, filename)\n            os.makedirs(os.path.split(path)[0], exist_ok=True)\n            Image.fromarray(grid).save(path)\n\n    def log_img(self, pl_module: pl.LightningModule, batch: Tuple[torch.LongTensor, torch.FloatTensor], batch_idx: int,\n                split: str = \"train\") -> None:\n        if (self.check_frequency(batch_idx) and  # batch_idx % self.batch_freq == 0\n                hasattr(pl_module, \"log_images\") and\n                callable(pl_module.log_images) and\n                self.max_images > 0):\n            logger = type(pl_module.logger)\n\n            is_train = pl_module.training\n            if is_train:\n                pl_module.eval()\n\n            with torch.no_grad():\n                images = pl_module.log_images(batch, split=split, pl_module=pl_module)\n\n            for k in images:\n                N = min(images[k].shape[0], self.max_images)\n                images[k] = images[k][:N].detach().cpu()\n                if self.clamp:\n                    images[k] = images[k].clamp(0, 1)\n\n            self.log_local(pl_module.logger.save_dir, split, images,\n                           pl_module.global_step, pl_module.current_epoch, batch_idx)\n\n            logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)\n            logger_log_images(pl_module, images, pl_module.global_step, split)\n\n            if is_train:\n                pl_module.train()\n\n    def check_frequency(self, batch_idx: int) -> bool:\n        if (batch_idx % self.batch_freq) == 0 or (batch_idx in self.log_steps):\n            try:\n                self.log_steps.pop(0)\n            except IndexError:\n                pass\n            return True\n        return False\n\n    def 
on_train_batch_end(self, trainer: pl.trainer.Trainer, pl_module: pl.LightningModule,\n                           outputs: Generic, batch: Tuple[torch.LongTensor, torch.FloatTensor], batch_idx: int) -> None:\n        self.log_img(pl_module, batch, batch_idx, split=\"train\")\n\n    def on_validation_batch_end(self, trainer: pl.trainer.Trainer, pl_module: pl.LightningModule,\n                                outputs: Generic, batch: Tuple[torch.LongTensor, torch.FloatTensor],\n                                dataloader_idx: int, batch_idx: int) -> None:\n        self.log_img(pl_module, batch, batch_idx, split=\"val\")\n\n\nclass CUDACallback(Callback):\n    # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py\n    def on_train_epoch_start(self, trainer, pl_module):\n        # Reset the memory use counter\n        torch.cuda.reset_peak_memory_stats(trainer.root_gpu)\n        torch.cuda.synchronize(trainer.root_gpu)\n        self.start_time = time.time()\n\n    def on_train_epoch_end(self, trainer, pl_module, outputs):\n        torch.cuda.synchronize(trainer.root_gpu)\n        max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2 ** 20\n        epoch_time = time.time() - self.start_time\n\n        try:\n            max_memory = trainer.training_type_plugin.reduce(max_memory)\n            epoch_time = trainer.training_type_plugin.reduce(epoch_time)\n\n            rank_zero_info(f\"Average Epoch time: {epoch_time:.2f} seconds\")\n            rank_zero_info(f\"Average Peak memory {max_memory:.2f}MiB\")\n        except AttributeError:\n            pass\n"
  },
  {
    "path": "ultrashape/utils/trainings/lr_scheduler.py",
    "content": "# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT\n# except for the third-party components listed below.\n# Hunyuan 3D does not impose any additional limitations beyond what is outlined\n# in the repsective licenses of these third-party components.\n# Users must comply with all terms and conditions of original licenses of these third-party\n# components and must ensure that the usage of the third party components adheres to\n# all relevant laws and regulations.\n\n# For avoidance of doubts, Hunyuan 3D means the large language models and\n# their software and algorithms, including trained model weights, parameters (including\n# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,\n# fine-tuning enabling code and other elements of the foregoing made publicly available\n# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.\n\nimport numpy as np\n\n\nclass BaseScheduler(object):\n\n    def schedule(self, n, **kwargs):\n        raise NotImplementedError\n\n\nclass LambdaWarmUpCosineFactorScheduler(BaseScheduler):\n    \"\"\"\n    note: use with a base_lr of 1.0\n    \"\"\"\n    def __init__(self, warm_up_steps, f_min, f_max, f_start, max_decay_steps, verbosity_interval=0, **ignore_kwargs):\n        self.lr_warm_up_steps = warm_up_steps\n        self.f_start = f_start\n        self.f_min = f_min\n        self.f_max = f_max\n        self.lr_max_decay_steps = max_decay_steps\n        self.last_f = 0.\n        self.verbosity_interval = verbosity_interval\n\n    def schedule(self, n, **kwargs):\n        if self.verbosity_interval > 0:\n            if n % self.verbosity_interval == 0:\n                print(f\"current step: {n}, recent lr-multiplier: {self.f_start}\")\n        if n < self.lr_warm_up_steps:\n            f = (self.f_max - self.f_start) / self.lr_warm_up_steps * n + self.f_start\n            self.last_f = f\n            return f\n        else:\n            t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)\n            t = min(t, 1.0)\n            f = self.f_min + 0.5 * (self.f_max - self.f_min) * (1 + np.cos(t * np.pi))\n            self.last_f = f\n            return f\n\n    def __call__(self, n, **kwargs):\n        return self.schedule(n, **kwargs)\n"
  },
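  {
    "path": "examples/lr_scheduler_demo.py",
    "content": "# Hypothetical usage sketch, not part of the original repository: the optimizer,\n# model, hyper-parameters and file name are assumptions. It wires the factor\n# scheduler from ultrashape/utils/trainings/lr_scheduler.py into a standard\n# torch LambdaLR; per its docstring the base_lr is 1.0, so the scheduler output\n# is the effective learning rate.\nimport torch\nfrom torch import nn\n\nfrom ultrashape.utils.trainings.lr_scheduler import LambdaWarmUpCosineFactorScheduler\n\nif __name__ == '__main__':\n    model = nn.Linear(4, 2)\n    opt = torch.optim.AdamW(model.parameters(), lr=1.0)  # base_lr of 1.0 as noted\n\n    factor_fn = LambdaWarmUpCosineFactorScheduler(\n        warm_up_steps=100,     # linear warm-up from f_start to f_max\n        f_start=1e-6,\n        f_max=1e-4,            # peak multiplier (== peak lr with base_lr 1.0)\n        f_min=1e-6,            # floor reached after the cosine decay\n        max_decay_steps=1000,\n    )\n    sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=factor_fn)\n\n    for step in range(5):\n        opt.step()\n        sched.step()\n        print(step, sched.get_last_lr())\n"
  },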
  {
    "path": "ultrashape/utils/trainings/mesh.py",
    "content": "# -*- coding: utf-8 -*-\n\n# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT\n# except for the third-party components listed below.\n# Hunyuan 3D does not impose any additional limitations beyond what is outlined\n# in the repsective licenses of these third-party components.\n# Users must comply with all terms and conditions of original licenses of these third-party\n# components and must ensure that the usage of the third party components adheres to\n# all relevant laws and regulations.\n\n# For avoidance of doubts, Hunyuan 3D means the large language models and\n# their software and algorithms, including trained model weights, parameters (including\n# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,\n# fine-tuning enabling code and other elements of the foregoing made publicly available\n# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.\n\nimport os\nimport cv2\nimport numpy as np\nimport PIL.Image\nfrom typing import Optional\n\nimport trimesh\n\n\ndef save_obj(pointnp_px3, facenp_fx3, fname):\n    fid = open(fname, \"w\")\n    write_str = \"\"\n    for pidx, p in enumerate(pointnp_px3):\n        pp = p\n        write_str += \"v %f %f %f\\n\" % (pp[0], pp[1], pp[2])\n\n    for i, f in enumerate(facenp_fx3):\n        f1 = f + 1\n        write_str += \"f %d %d %d\\n\" % (f1[0], f1[1], f1[2])\n    fid.write(write_str)\n    fid.close()\n    return\n\n\ndef savemeshtes2(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, tex_map, fname):\n    fol, na = os.path.split(fname)\n    na, _ = os.path.splitext(na)\n\n    matname = \"%s/%s.mtl\" % (fol, na)\n    fid = open(matname, \"w\")\n    fid.write(\"newmtl material_0\\n\")\n    fid.write(\"Kd 1 1 1\\n\")\n    fid.write(\"Ka 0 0 0\\n\")\n    fid.write(\"Ks 0.4 0.4 0.4\\n\")\n    fid.write(\"Ns 10\\n\")\n    fid.write(\"illum 2\\n\")\n    fid.write(\"map_Kd %s.png\\n\" % na)\n    fid.close()\n    ####\n\n    fid = open(fname, \"w\")\n    fid.write(\"mtllib %s.mtl\\n\" % na)\n\n    for pidx, p3 in enumerate(pointnp_px3):\n        pp = p3\n        fid.write(\"v %f %f %f\\n\" % (pp[0], pp[1], pp[2]))\n\n    for pidx, p2 in enumerate(tcoords_px2):\n        pp = p2\n        fid.write(\"vt %f %f\\n\" % (pp[0], pp[1]))\n\n    fid.write(\"usemtl material_0\\n\")\n    for i, f in enumerate(facenp_fx3):\n        f1 = f + 1\n        f2 = facetex_fx3[i] + 1\n        fid.write(\"f %d/%d %d/%d %d/%d\\n\" % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2]))\n    fid.close()\n\n    PIL.Image.fromarray(np.ascontiguousarray(tex_map), \"RGB\").save(\n        os.path.join(fol, \"%s.png\" % na))\n\n    return\n\n\nclass MeshOutput(object):\n\n    def __init__(self,\n                 mesh_v: np.ndarray,\n                 mesh_f: np.ndarray,\n                 vertex_colors: Optional[np.ndarray] = None,\n                 uvs: Optional[np.ndarray] = None,\n                 mesh_tex_idx: Optional[np.ndarray] = None,\n                 tex_map: Optional[np.ndarray] = None):\n\n        self.mesh_v = mesh_v\n        self.mesh_f = mesh_f\n        self.vertex_colors = vertex_colors\n        self.uvs = uvs\n        self.mesh_tex_idx = mesh_tex_idx\n        self.tex_map = tex_map\n\n    def contain_uv_texture(self):\n        return (self.uvs is not None) and (self.mesh_tex_idx is not None) and (self.tex_map is not None)\n\n    def contain_vertex_colors(self):\n        return self.vertex_colors is not None\n\n    def export(self, fname):\n\n        if 
self.contain_uv_texture():\n            savemeshtes2(\n                self.mesh_v,\n                self.uvs,\n                self.mesh_f,\n                self.mesh_tex_idx,\n                self.tex_map,\n                fname\n            )\n\n        elif self.contain_vertex_colors():\n            mesh_obj = trimesh.Trimesh(vertices=self.mesh_v, faces=self.mesh_f, vertex_colors=self.vertex_colors)\n            mesh_obj.export(fname)\n\n        else:\n            save_obj(\n                self.mesh_v,\n                self.mesh_f,\n                fname\n            )\n\n\n\n"
  },
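  {
    "path": "examples/mesh_output_export_demo.py",
    "content": "# Hypothetical usage sketch, not part of the original repository: the tetrahedron\n# data, colors and file names are assumptions. It exercises the export paths of\n# MeshOutput from ultrashape/utils/trainings/mesh.py.\nimport numpy as np\n\nfrom ultrashape.utils.trainings.mesh import MeshOutput\n\nif __name__ == '__main__':\n    mesh_v = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float64)\n    mesh_f = np.array([[0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3]], dtype=np.int64)\n\n    # no colors / UVs -> falls back to the plain save_obj() writer\n    MeshOutput(mesh_v, mesh_f).export('tetra.obj')\n\n    # per-vertex colors -> exported through trimesh instead\n    colors = np.array([[255, 0, 0], [0, 255, 0], [0, 0, 255], [255, 255, 0]], dtype=np.uint8)\n    MeshOutput(mesh_v, mesh_f, vertex_colors=colors).export('tetra_colored.ply')\n"
  },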
  {
    "path": "ultrashape/utils/trainings/mesh_log_callback.py",
    "content": "# -*- coding: utf-8 -*-\n\n# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT\n# except for the third-party components listed below.\n# Hunyuan 3D does not impose any additional limitations beyond what is outlined\n# in the repsective licenses of these third-party components.\n# Users must comply with all terms and conditions of original licenses of these third-party\n# components and must ensure that the usage of the third party components adheres to\n# all relevant laws and regulations.\n\n# For avoidance of doubts, Hunyuan 3D means the large language models and\n# their software and algorithms, including trained model weights, parameters (including\n# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,\n# fine-tuning enabling code and other elements of the foregoing made publicly available\n# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.\n\nimport json\nimport math\nimport os\nfrom typing import Tuple, Generic, Dict, List, Union, Optional\n\nimport trimesh\nimport numpy as np\nimport pytorch_lightning as pl\nimport pytorch_lightning.loggers\nimport torch\nimport torchvision\nfrom pytorch_lightning.callbacks import Callback\nfrom pytorch_lightning.utilities import rank_zero_only\n\nfrom hy3dshape.pipelines import export_to_trimesh\nfrom hy3dshape.utils.trainings.mesh import MeshOutput\nfrom hy3dshape.utils.visualizers import html_util\nfrom hy3dshape.utils.visualizers.pythreejs_viewer import PyThreeJSViewer\n\n\nclass ImageConditionalASLDiffuserLogger(Callback): \n    def __init__(self,\n                 step_frequency: int,\n                 num_samples: int = 1,\n                 mean: Optional[Union[List[float], Tuple[float]]] = None,\n                 std: Optional[Union[List[float], Tuple[float]]] = None,\n                 bounds: Union[List[float], Tuple[float]] = (-1.1, -1.1, -1.1, 1.1, 1.1, 1.1),\n                 **kwargs) -> None:\n\n        super().__init__()\n        self.bbox_size = np.array(bounds[3:6]) - np.array(bounds[0:3])\n\n        if mean is not None:\n            mean = np.asarray(mean)\n\n        if std is not None:\n            std = np.asarray(std)\n\n        self.mean = mean\n        self.std = std\n\n        self.step_freq = step_frequency\n        self.num_samples = num_samples\n        self.has_train_logged = False\n        self.logger_log_images = {\n            pl.loggers.WandbLogger: self._wandb,\n        }\n\n        self.viewer = PyThreeJSViewer(settings={}, render_mode=\"WEBSITE\")\n\n    @rank_zero_only\n    def _wandb(self, pl_module, images, batch_idx, split):\n        # raise ValueError(\"No way wandb\")\n        grids = dict()\n        for k in images:\n            grid = torchvision.utils.make_grid(images[k])\n            grids[f\"{split}/{k}\"] = wandb.Image(grid)\n        pl_module.logger.experiment.log(grids)\n\n    def log_local(self,\n                  outputs: List[List['Latent2MeshOutput']],\n                  images: Union[np.ndarray, List[np.ndarray]],\n                  description: List[str],\n                  keys: List[str],\n                  save_dir: str, split: str,\n                  global_step: int, current_epoch: int, batch_idx: int,\n                  prog_bar: bool = False,\n                  multi_views=None,  # yf ...\n                  ) -> None:\n\n        folder = \"gs-{:010}_e-{:06}_b-{:06}\".format(global_step, current_epoch, batch_idx)\n        visual_dir = os.path.join(save_dir, \"visuals\", 
split, folder)\n        os.makedirs(visual_dir, exist_ok=True)\n\n        num_samples = len(images)\n        \n        for i in range(num_samples):\n            key_i = keys[i]\n            image_i = self.denormalize_image(images[i])\n            shape_tag_i = description[i]\n\n            for j in range(1):\n                mesh = outputs[j][i]\n                if mesh is None:\n                    continue\n\n                mesh_v = mesh.mesh_v.copy()\n                mesh_v[:, 0] += j * np.max(self.bbox_size)\n                self.viewer.add_mesh(mesh_v, mesh.mesh_f)\n\n            image_tag = html_util.to_image_embed_tag(image_i)\n            mesh_tag = self.viewer.to_html(html_frame=False)\n\n            table_tag = f\"\"\"\n            <table border = \"1\">\n                <caption> {shape_tag_i} - {key_i} </caption>\n                <caption> Input Image | Generated Mesh </caption>\n                <tr>\n                    <td>{image_tag}</td>\n                    <td>{mesh_tag}</td>\n                </tr>\n            </table>\n            \"\"\"\n\n            if multi_views is not None:\n                multi_views_i = self.make_grid(multi_views[i])\n                views_tag = html_util.to_image_embed_tag(self.denormalize_image(multi_views_i))\n                table_tag = f\"\"\"\n                <table border = \"1\">\n                    <caption> {shape_tag_i} - {key_i} </caption>\n                    <caption> Input Image | Generated Mesh </caption>\n                    <tr>\n                        <td>{image_tag}</td>\n                        <td>{views_tag}</td>\n                        <td>{mesh_tag}</td>\n                    </tr>\n                </table>\n                \"\"\"\n\n            html_frame = html_util.to_html_frame(table_tag)\n            if len(key_i) > 100:\n                key_i = key_i[:100]\n            with open(os.path.join(visual_dir, f\"{key_i}.html\"), \"w\") as writer:\n                writer.write(html_frame)\n\n            self.viewer.reset()\n\n    def log_sample(self,\n                   pl_module: pl.LightningModule,\n                   batch: Dict[str, torch.FloatTensor],\n                   batch_idx: int,\n                   split: str = \"train\") -> None:\n        \"\"\"\n\n        Args:\n            pl_module:\n            batch (dict): the batch sample information, and it contains:\n                 - surface (torch.FloatTensor):\n                 - image (torch.FloatTensor):\n            batch_idx (int):\n            split (str):\n\n        Returns:\n\n        \"\"\"\n\n        is_train = pl_module.training\n        if is_train:\n            pl_module.eval()\n\n        batch_size = len(batch[\"surface\"])\n        replace = batch_size < self.num_samples\n        ids = np.random.choice(batch_size, self.num_samples, replace=replace)\n\n        with torch.no_grad():\n            # run text to mesh\n            # keys = [batch[\"__key__\"][i] for i in ids]\n            keys = [f'key_{i}' for i in ids]\n            # texts = [batch[\"text\"][i] for i in ids]\n            texts = [f'text_{i}'for i in ids]\n            # description = [batch[\"description\"][i] for i in ids]\n            description = [f'desc_{i}_{os.path.splitext(os.path.basename(batch[\"uid\"][i]))[0]}' for i in ids]\n            images = batch[\"image\"][ids]\n            mask_input = batch[\"mask\"][ids] if 'mask' in batch else None\n            # uids = batch[\"uid\"][ids]\n            sample_batch = {\n                \"__key__\": keys,\n                
\"image\": images,\n                'text': texts,\n                'mask': mask_input,\n            }\n\n            # if 'cam_parm' in batch:\n            #     sample_batch['cam_parm'] = batch['cam_parm'][ids]\n\n            # if 'multi_views' in batch:  # yf ...\n            #     sample_batch['multi_views'] = batch['multi_views'][ids]\n\n            outputs = pl_module.sample(\n                batch=sample_batch,\n                output_type='latents2mesh'\n            )\n\n            images = images.cpu().float().numpy()\n            # images = self.denormalize_image(images)\n            # images = np.transpose(images, (0, 2, 3, 1))\n            # images = ((images + 1) / 2 * 255).astype(np.uint8)\n\n        self.log_local(outputs, images, description, keys, pl_module.logger.save_dir, split,\n                       pl_module.global_step, pl_module.current_epoch, batch_idx, prog_bar=False,\n                       multi_views=sample_batch.get('multi_views'))\n\n        if is_train: pl_module.train()\n\n    def make_grid(self, images):  # return (3,h,w) in (0,1) ...\n        images_resized = []\n        for img in images:\n            img_resized = torchvision.transforms.functional.resize(img, (320, 320))\n            images_resized.append(img_resized)\n        image = torchvision.utils.make_grid(images_resized, nrow=2, padding=5, pad_value=255)\n\n        image = image.cpu().numpy()\n        #       image = np.transpose(image, (1, 2, 0))\n        #       image = (image * 255).astype(np.uint8)\n\n        return image\n\n    def check_frequency(self, step: int) -> bool:\n        if step % self.step_freq == 0:\n            return True\n        return False\n\n    def on_train_batch_end(self, trainer: pl.trainer.Trainer, pl_module: pl.LightningModule,\n                           outputs: Generic, batch: Dict[str, torch.FloatTensor], batch_idx: int) -> None:\n\n        if (self.check_frequency(pl_module.global_step) and  # batch_idx % self.batch_freq == 0\n            hasattr(pl_module, \"sample\") and\n            callable(pl_module.sample) and\n            self.num_samples > 0):\n            self.log_sample(pl_module, batch, batch_idx, split=\"train\")\n            self.has_train_logged = True\n\n    def on_validation_batch_end(self, trainer: pl.trainer.Trainer, pl_module: pl.LightningModule,\n                                outputs: Generic, batch: Dict[str, torch.FloatTensor],\n                                dataloader_idx: int, batch_idx: int) -> None:\n\n        if self.has_train_logged:\n            self.log_sample(pl_module, batch, batch_idx, split=\"val\")\n            self.has_train_logged = False\n\n    def denormalize_image(self, image):\n        \"\"\"\n\n        Args:\n            image (np.ndarray): [3, h, w]\n\n        Returns:\n            image (np.ndarray): [h, w, 3], np.uint8, [0, 255].\n        \"\"\"\n        # image = np.transpose(image, (0, 2, 3, 1))\n        image = np.transpose(image, (1, 2, 0))\n\n        if self.std is not None:\n            image = image * self.std\n\n        if self.mean is not None:\n            image = image + self.mean\n\n        image = (image * 255).astype(np.uint8)\n\n        return image\n\n\nclass ImageConditionalFixASLDiffuserLogger(Callback):\n    def __init__(\n        self,\n        step_frequency: int,\n        test_data_path: str,\n        max_size: int = None,\n        save_dir: str = 'infer',\n        **kwargs,\n    ) -> None:\n        super().__init__()\n        self.step_freq = step_frequency\n        self.viewer = 
PyThreeJSViewer(settings={}, render_mode=\"WEBSITE\")\n\n        self.test_data_path = test_data_path\n        with open(self.test_data_path, 'r') as f:\n            data = json.load(f)\n            self.file_list = data['file_list']\n            # self.file_folder = data['file_folder']\n            if max_size is not None:\n                self.file_list = self.file_list[:max_size]\n        self.kwargs = kwargs\n        self.save_dir = save_dir\n\n    def on_train_batch_end(\n        self,\n        trainer: pl.trainer.Trainer,\n        pl_module: pl.LightningModule,\n        outputs: Generic,\n        batch: Dict[str, torch.FloatTensor],\n        batch_idx: int,\n    ):\n        if pl_module.global_step % self.step_freq == 0:\n            with open(self.test_data_path, 'r') as f:\n                data = json.load(f)\n                self.file_list = data['file_list']\n            is_train = pl_module.training\n            if is_train:\n                pl_module.eval()\n\n            # folder_path = self.file_folder\n            # folder_name = os.path.basename(folder_path)\n            folder = \"gs-{:010}_e-{:06}_b-{:06}\".format(pl_module.global_step, pl_module.current_epoch, batch_idx)\n            visual_dir = os.path.join(pl_module.logger.save_dir, self.save_dir, folder)\n            os.makedirs(visual_dir, exist_ok=True)\n\n            image_paths = self.file_list\n            chunk_size = math.ceil(len(image_paths) / trainer.world_size)\n            if pl_module.global_rank == trainer.world_size - 1:\n                image_paths = image_paths[pl_module.global_rank * chunk_size:]\n            else:\n                image_paths = image_paths[pl_module.global_rank * chunk_size:(pl_module.global_rank + 1) * chunk_size]\n\n            print(f'Rank{pl_module.global_rank}: processing {len(image_paths)}|{len(self.file_list)} images')\n            for image_path in image_paths:\n                # if folder_path in image_path:\n                #     save_path = image_path.replace(folder_path, visual_dir)\n                # else:\n                save_path = os.path.join(visual_dir, os.path.basename(image_path))\n                save_path = os.path.splitext(save_path)[0] + '.glb'\n\n                if isinstance(image_path, str):\n                    print(image_path)\n                    \n                with torch.no_grad():\n                    mesh = pl_module.sample(batch={\"image\": image_path}, **self.kwargs)[0][0]\n                    if isinstance(mesh, tuple) and len(mesh)==2:\n                        mesh = export_to_trimesh(mesh)\n                    elif isinstance(mesh, trimesh.Trimesh):\n                        os.makedirs(os.path.dirname(save_path), exist_ok=True)\n                        mesh.export(save_path)\n\n            if is_train:\n                pl_module.train()\n"
  },
  {
    "path": "ultrashape/utils/trainings/peft.py",
    "content": "# -*- coding: utf-8 -*-\n\n# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT\n# except for the third-party components listed below.\n# Hunyuan 3D does not impose any additional limitations beyond what is outlined\n# in the repsective licenses of these third-party components.\n# Users must comply with all terms and conditions of original licenses of these third-party\n# components and must ensure that the usage of the third party components adheres to\n# all relevant laws and regulations.\n\n# For avoidance of doubts, Hunyuan 3D means the large language models and\n# their software and algorithms, including trained model weights, parameters (including\n# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,\n# fine-tuning enabling code and other elements of the foregoing made publicly available\n# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.\n\nimport os\nfrom pytorch_lightning.callbacks import Callback\nfrom omegaconf import OmegaConf, ListConfig\n\nclass PeftSaveCallback(Callback):\n    def __init__(self, peft_model, save_dir: str, save_every_n_steps: int = None):\n        super().__init__()\n        self.peft_model = peft_model\n        self.save_dir = save_dir\n        self.save_every_n_steps = save_every_n_steps\n        os.makedirs(self.save_dir, exist_ok=True)\n\n    def recursive_convert(self, obj):\n        from omegaconf import OmegaConf, ListConfig\n        if isinstance(obj, (OmegaConf, ListConfig)):\n            return OmegaConf.to_container(obj, resolve=True)\n        elif isinstance(obj, dict):\n            return {k: self.recursive_convert(v) for k, v in obj.items()}\n        elif isinstance(obj, list):\n            return [self.recursive_convert(i) for i in obj]\n        elif isinstance(obj, type):\n            # 避免修改类对象\n            return obj\n        elif hasattr(obj, '__dict__'):\n            for attr_name, attr_value in vars(obj).items():\n                setattr(obj, attr_name, self.recursive_convert(attr_value))\n            return obj\n        else:\n            return obj\n\n    # def recursive_convert(self, obj):\n    #     if isinstance(obj, (OmegaConf, ListConfig)):\n    #         return OmegaConf.to_container(obj, resolve=True)\n    #     elif isinstance(obj, dict):\n    #         return {k: self.recursive_convert(v) for k, v in obj.items()}\n    #     elif isinstance(obj, list):\n    #         return [self.recursive_convert(i) for i in obj]\n    #     elif hasattr(obj, '__dict__'):\n    #         for attr_name, attr_value in vars(obj).items():\n    #             setattr(obj, attr_name, self.recursive_convert(attr_value))\n    #         return obj\n    #     else:\n    #         return obj\n\n    def _convert_peft_config(self):\n        pc = self.peft_model.peft_config\n        self.peft_model.peft_config = self.recursive_convert(pc)\n\n    def on_train_epoch_end(self, trainer, pl_module):\n        self._convert_peft_config()\n        save_path = os.path.join(self.save_dir, f\"epoch_{trainer.current_epoch}\")\n        self.peft_model.save_pretrained(save_path)\n        print(f\"[PeftSaveCallback] Saved LoRA weights to {save_path}\")\n\n    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):\n        if self.save_every_n_steps is not None:\n            global_step = trainer.global_step\n            if global_step % self.save_every_n_steps == 0 and global_step > 0:\n                self._convert_peft_config()\n        
        save_path = os.path.join(self.save_dir, f\"step_{global_step}\")\n                self.peft_model.save_pretrained(save_path)\n                print(f\"[PeftSaveCallback] Saved LoRA weights to {save_path}\")\n"
  },
  {
    "path": "ultrashape/utils/typing.py",
    "content": "\"\"\"\nThis module contains type annotations for the project, using\n1. Python type hints (https://docs.python.org/3/library/typing.html) for Python objects\n2. jaxtyping (https://github.com/google/jaxtyping/blob/main/API.md) for PyTorch tensors\n\nTwo types of typing checking can be used:\n1. Static type checking with mypy (install with pip and enabled as the default linter in VSCode)\n2. Runtime type checking with typeguard (install with pip and triggered at runtime, mainly for tensor dtype and shape checking)\n\"\"\"\n\n# Basic types\nfrom typing import (\n    Any,\n    Callable,\n    Dict,\n    Iterable,\n    List,\n    Literal,\n    NamedTuple,\n    NewType,\n    Optional,\n    Sized,\n    Tuple,\n    Type,\n    TypeVar,\n    Union,\n    Sequence,\n)\n\n# Tensor dtype\n# for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md\nfrom jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt\n\n# Config type\nfrom omegaconf import DictConfig\n\n# PyTorch Tensor type\nfrom torch import Tensor\n\n# Runtime type checking decorator\nfrom typeguard import typechecked as typechecker\n"
  },
  {
    "path": "ultrashape/utils/utils.py",
    "content": "# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT\n# except for the third-party components listed below.\n# Hunyuan 3D does not impose any additional limitations beyond what is outlined\n# in the repsective licenses of these third-party components.\n# Users must comply with all terms and conditions of original licenses of these third-party\n# components and must ensure that the usage of the third party components adheres to\n# all relevant laws and regulations.\n\n# For avoidance of doubts, Hunyuan 3D means the large language models and\n# their software and algorithms, including trained model weights, parameters (including\n# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,\n# fine-tuning enabling code and other elements of the foregoing made publicly available\n# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.\n\nimport logging\nimport os\nfrom functools import wraps\n\nimport torch\n\n\ndef get_logger(name):\n    logger = logging.getLogger(name)\n    logger.setLevel(logging.INFO)\n\n    console_handler = logging.StreamHandler()\n    console_handler.setLevel(logging.INFO)\n\n    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')\n    console_handler.setFormatter(formatter)\n    logger.addHandler(console_handler)\n    return logger\n\n\nlogger = get_logger('hy3dgen.shapgen')\n\n\nclass synchronize_timer:\n    \"\"\" Synchronized timer to count the inference time of `nn.Module.forward`.\n\n        Supports both context manager and decorator usage.\n\n        Example as context manager:\n        ```python\n        with synchronize_timer('name') as t:\n            run()\n        ```\n\n        Example as decorator:\n        ```python\n        @synchronize_timer('Export to trimesh')\n        def export_to_trimesh(mesh_output):\n            pass\n        ```\n    \"\"\"\n\n    def __init__(self, name=None):\n        self.name = name\n\n    def __enter__(self):\n        \"\"\"Context manager entry: start timing.\"\"\"\n        if os.environ.get('HY3DGEN_DEBUG', '0') == '1':\n            self.start = torch.cuda.Event(enable_timing=True)\n            self.end = torch.cuda.Event(enable_timing=True)\n            self.start.record()\n            return lambda: self.time\n\n    def __exit__(self, exc_type, exc_value, exc_tb):\n        \"\"\"Context manager exit: stop timing and log results.\"\"\"\n        if os.environ.get('HY3DGEN_DEBUG', '0') == '1':\n            self.end.record()\n            torch.cuda.synchronize()\n            self.time = self.start.elapsed_time(self.end)\n            if self.name is not None:\n                logger.info(f'{self.name} takes {self.time} ms')\n\n    def __call__(self, func):\n        \"\"\"Decorator: wrap the function to time its execution.\"\"\"\n\n        @wraps(func)\n        def wrapper(*args, **kwargs):\n            with self:\n                result = func(*args, **kwargs)\n            return result\n\n        return wrapper\n\n\ndef smart_load_model(\n    model_path,\n    subfolder,\n    use_safetensors,\n    variant,\n):\n    original_model_path = model_path\n    # try local path\n    base_dir = os.environ.get('HY3DGEN_MODELS', '~/.cache/hy3dgen')\n    model_fld = os.path.expanduser(os.path.join(base_dir, model_path))\n    model_path = os.path.expanduser(os.path.join(base_dir, model_path, subfolder))\n    logger.info(f'Try to load model from local path: {model_path}')\n    if not 
os.path.exists(model_path):\n        logger.info('Model path does not exist, trying to download from the Hugging Face Hub')\n        try:\n            from huggingface_hub import snapshot_download\n            # Only download the requested subfolder\n            path = snapshot_download(\n                repo_id=original_model_path,\n                allow_patterns=[f\"{subfolder}/*\"],  # key change: pattern-match the subfolder\n                local_dir=model_fld\n            )\n            model_path = os.path.join(path, subfolder)  # keep the path-joining logic unchanged\n        except ImportError:\n            logger.warning(\n                \"You need to install HuggingFace Hub to load models from the hub.\"\n            )\n            raise RuntimeError(f\"Model path {model_path} not found\")\n        except Exception as e:\n            raise e\n\n    if not os.path.exists(model_path):\n        raise FileNotFoundError(f\"Model path {original_model_path} not found\")\n\n    extension = 'ckpt' if not use_safetensors else 'safetensors'\n    variant = '' if variant is None else f'.{variant}'\n    ckpt_name = f'model{variant}.{extension}'\n    config_path = os.path.join(model_path, 'config.yaml')\n    ckpt_path = os.path.join(model_path, ckpt_name)\n    return config_path, ckpt_path\n"
  },
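  {
    "path": "ultrashape/utils/utils_usage_example.py",
    "content": "# Hypothetical example file, not part of the original codebase: an illustrative sketch of\n# the intended call patterns for the helpers in ultrashape/utils/utils.py. Assumes the\n# package is importable as `ultrashape`, that a CUDA device is available for the\n# synchronized timer, and that the repo id / subfolder below are placeholders.\nimport os\n\nimport torch\n\nfrom ultrashape.utils.utils import smart_load_model, synchronize_timer\n\nos.environ['HY3DGEN_DEBUG'] = '1'  # timing is a no-op unless this flag is '1'\n\n\n@synchronize_timer('matmul forward')\ndef matmul_forward():\n    x = torch.randn(1024, 1024, device='cuda')\n    return x @ x\n\n\nif __name__ == '__main__':\n    matmul_forward()  # decorator form: logs 'matmul forward takes ... ms'\n\n    with synchronize_timer('matmul block'):  # context-manager form\n        matmul_forward()\n\n    # Resolve config.yaml / checkpoint paths, downloading the subfolder from the\n    # Hugging Face Hub when it is not found in the local HY3DGEN_MODELS cache.\n    config_path, ckpt_path = smart_load_model(\n        'tencent/Hunyuan3D-2',  # placeholder repo id\n        subfolder='hunyuan3d-dit-v2-0',  # placeholder subfolder\n        use_safetensors=True,\n        variant=None,\n    )\n    print(config_path, ckpt_path)\n"
  },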
  {
    "path": "ultrashape/utils/visualizers/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n"
  },
  {
    "path": "ultrashape/utils/visualizers/color_util.py",
    "content": "# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT\n# except for the third-party components listed below.\n# Hunyuan 3D does not impose any additional limitations beyond what is outlined\n# in the repsective licenses of these third-party components.\n# Users must comply with all terms and conditions of original licenses of these third-party\n# components and must ensure that the usage of the third party components adheres to\n# all relevant laws and regulations.\n\n# For avoidance of doubts, Hunyuan 3D means the large language models and\n# their software and algorithms, including trained model weights, parameters (including\n# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,\n# fine-tuning enabling code and other elements of the foregoing made publicly available\n# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.\n\nimport numpy as np\nimport matplotlib.pyplot as plt\n\n\n# Helper functions\ndef get_colors(inp, colormap=\"viridis\", normalize=True, vmin=None, vmax=None):\n    colormap = plt.cm.get_cmap(colormap)\n    if normalize:\n        vmin = np.min(inp)\n        vmax = np.max(inp)\n\n    norm = plt.Normalize(vmin, vmax)\n    return colormap(norm(inp))[:, :3]\n\n\ndef gen_checkers(n_checkers_x, n_checkers_y, width=256, height=256):\n    # tex dims need to be power of two.\n    array = np.ones((width, height, 3), dtype='float32')\n\n    # width in texels of each checker\n    checker_w = width / n_checkers_x\n    checker_h = height / n_checkers_y\n\n    for y in range(height):\n        for x in range(width):\n            color_key = int(x / checker_w) + int(y / checker_h)\n            if color_key % 2 == 0:\n                array[x, y, :] = [1., 0.874, 0.0]\n            else:\n                array[x, y, :] = [0., 0., 0.]\n    return array\n\n\ndef gen_circle(width=256, height=256):\n    xx, yy = np.mgrid[:width, :height]\n    circle = (xx - width / 2 + 0.5) ** 2 + (yy - height / 2 + 0.5) ** 2\n    array = np.ones((width, height, 4), dtype='float32')\n    array[:, :, 0] = (circle <= width)\n    array[:, :, 1] = (circle <= width)\n    array[:, :, 2] = (circle <= width)\n    array[:, :, 3] = circle <= width\n    return array\n\n"
  },
  {
    "path": "ultrashape/utils/visualizers/html_util.py",
    "content": "# -*- coding: utf-8 -*-\n\n# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT\n# except for the third-party components listed below.\n# Hunyuan 3D does not impose any additional limitations beyond what is outlined\n# in the repsective licenses of these third-party components.\n# Users must comply with all terms and conditions of original licenses of these third-party\n# components and must ensure that the usage of the third party components adheres to\n# all relevant laws and regulations.\n\n# For avoidance of doubts, Hunyuan 3D means the large language models and\n# their software and algorithms, including trained model weights, parameters (including\n# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,\n# fine-tuning enabling code and other elements of the foregoing made publicly available\n# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.\n\nimport io\nimport base64\nimport numpy as np\nfrom PIL import Image\n\n\ndef to_html_frame(content):\n\n    html_frame = f\"\"\"\n    <html>\n      <body>\n        {content}\n      </body>\n    </html>\n    \"\"\"\n\n    return html_frame\n\n\ndef to_single_row_table(caption: str, content: str):\n\n    table_html = f\"\"\"\n    <table border = \"1\">\n        <caption>{caption}</caption>\n        <tr>\n            <td>{content}</td>\n        </tr>\n    </table>\n    \"\"\"\n\n    return table_html\n\n\ndef to_image_embed_tag(image: np.ndarray):\n\n    # Convert np.ndarray to bytes\n    img = Image.fromarray(image)\n    raw_bytes = io.BytesIO()\n    img.save(raw_bytes, \"PNG\")\n\n    # Encode bytes to base64\n    image_base64 = base64.b64encode(raw_bytes.getvalue()).decode(\"utf-8\")\n\n    image_tag = f\"\"\"\n    <img src=\"data:image/png;base64,{image_base64}\" alt=\"Embedded Image\">\n    \"\"\"\n\n    return image_tag\n"
  },
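  {
    "path": "ultrashape/utils/visualizers/html_preview_example.py",
    "content": "# Hypothetical example file, not part of the original codebase: an illustrative sketch that\n# renders the checkerboard texture from color_util as a base64-embedded image inside a\n# small HTML report using the helpers in html_util. Assumes numpy, Pillow and matplotlib\n# are installed and that the output filename is a placeholder.\nimport numpy as np\n\nfrom ultrashape.utils.visualizers.color_util import gen_checkers\nfrom ultrashape.utils.visualizers.html_util import (\n    to_html_frame,\n    to_image_embed_tag,\n    to_single_row_table,\n)\n\nif __name__ == '__main__':\n    # gen_checkers returns a float32 array with values in [0, 1]; PIL expects uint8.\n    checkers = (gen_checkers(8, 8, width=256, height=256) * 255).astype(np.uint8)\n\n    img_tag = to_image_embed_tag(checkers)\n    table = to_single_row_table('Checkerboard texture', img_tag)\n    html = to_html_frame(table)\n\n    with open('checkers_preview.html', 'w') as f:\n        f.write(html)\n"
  },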
  {
    "path": "ultrashape/utils/visualizers/pythreejs_viewer.py",
    "content": "# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT\n# except for the third-party components listed below.\n# Hunyuan 3D does not impose any additional limitations beyond what is outlined\n# in the repsective licenses of these third-party components.\n# Users must comply with all terms and conditions of original licenses of these third-party\n# components and must ensure that the usage of the third party components adheres to\n# all relevant laws and regulations.\n\n# For avoidance of doubts, Hunyuan 3D means the large language models and\n# their software and algorithms, including trained model weights, parameters (including\n# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,\n# fine-tuning enabling code and other elements of the foregoing made publicly available\n# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.\n\n\nimport numpy as np\nfrom ipywidgets import embed\nimport pythreejs as p3s\nimport uuid\n\nfrom .color_util import get_colors, gen_circle, gen_checkers\n\n\nEMBED_URL = \"https://cdn.jsdelivr.net/npm/@jupyter-widgets/html-manager@1.0.1/dist/embed-amd.js\"\n\n\nclass PyThreeJSViewer(object):\n\n    def __init__(self, settings, render_mode=\"WEBSITE\"):\n        self.render_mode = render_mode\n        self.__update_settings(settings)\n        self._light = p3s.DirectionalLight(color='white', position=[0, 0, 1], intensity=0.6)\n        self._light2 = p3s.AmbientLight(intensity=0.5)\n        self._cam = p3s.PerspectiveCamera(position=[0, 0, 1], lookAt=[0, 0, 0], fov=self.__s[\"fov\"],\n                                          aspect=self.__s[\"width\"] / self.__s[\"height\"], children=[self._light])\n        self._orbit = p3s.OrbitControls(controlling=self._cam)\n        self._scene = p3s.Scene(children=[self._cam, self._light2], background=self.__s[\"background\"])  # \"#4c4c80\"\n        self._renderer = p3s.Renderer(camera=self._cam, scene=self._scene, controls=[self._orbit],\n                                      width=self.__s[\"width\"], height=self.__s[\"height\"],\n                                      antialias=self.__s[\"antialias\"])\n\n        self.__objects = {}\n        self.__cnt = 0\n\n    def jupyter_mode(self):\n        self.render_mode = \"JUPYTER\"\n\n    def offline(self):\n        self.render_mode = \"OFFLINE\"\n\n    def website(self):\n        self.render_mode = \"WEBSITE\"\n\n    def __get_shading(self, shading):\n        shad = {\"flat\": True, \"wireframe\": False, \"wire_width\": 0.03, \"wire_color\": \"black\",\n                \"side\": 'DoubleSide', \"colormap\": \"viridis\", \"normalize\": [None, None],\n                \"bbox\": False, \"roughness\": 0.5, \"metalness\": 0.25, \"reflectivity\": 1.0,\n                \"line_width\": 1.0, \"line_color\": \"black\",\n                \"point_color\": \"red\", \"point_size\": 0.01, \"point_shape\": \"circle\",\n                \"text_color\": \"red\"\n                }\n        for k in shading:\n            shad[k] = shading[k]\n        return shad\n\n    def __update_settings(self, settings={}):\n        sett = {\"width\": 1600, \"height\": 800, \"antialias\": True, \"scale\": 1.5, \"background\": \"#ffffff\",\n                \"fov\": 30}\n        for k in settings:\n            sett[k] = settings[k]\n        self.__s = sett\n\n    def __add_object(self, obj, parent=None):\n        if not parent:  # Object is added to global scene and objects dict\n            
self.__objects[self.__cnt] = obj\n            self.__cnt += 1\n            self._scene.add(obj[\"mesh\"])\n        else:  # Object is added to parent object and NOT to objects dict\n            parent.add(obj[\"mesh\"])\n\n        self.__update_view()\n\n        if self.render_mode == \"JUPYTER\":\n            return self.__cnt - 1\n        elif self.render_mode == \"WEBSITE\":\n            return self\n\n    def __add_line_geometry(self, lines, shading, obj=None):\n        lines = lines.astype(\"float32\", copy=False)\n        mi = np.min(lines, axis=0)\n        ma = np.max(lines, axis=0)\n\n        geometry = p3s.LineSegmentsGeometry(positions=lines.reshape((-1, 2, 3)))\n        material = p3s.LineMaterial(linewidth=shading[\"line_width\"], color=shading[\"line_color\"])\n        # , vertexColors='VertexColors'),\n        lines = p3s.LineSegments2(geometry=geometry, material=material)  # type='LinePieces')\n        line_obj = {\"geometry\": geometry, \"mesh\": lines, \"material\": material,\n                    \"max\": ma, \"min\": mi, \"type\": \"Lines\", \"wireframe\": None}\n\n        if obj:\n            return self.__add_object(line_obj, obj), line_obj\n        else:\n            return self.__add_object(line_obj)\n\n    def __update_view(self):\n        if len(self.__objects) == 0:\n            return\n        ma = np.zeros((len(self.__objects), 3))\n        mi = np.zeros((len(self.__objects), 3))\n        for r, obj in enumerate(self.__objects):\n            ma[r] = self.__objects[obj][\"max\"]\n            mi[r] = self.__objects[obj][\"min\"]\n        ma = np.max(ma, axis=0)\n        mi = np.min(mi, axis=0)\n        diag = np.linalg.norm(ma - mi)\n        mean = ((ma - mi) / 2 + mi).tolist()\n        scale = self.__s[\"scale\"] * (diag)\n        self._orbit.target = mean\n        self._cam.lookAt(mean)\n        self._cam.position = [mean[0], mean[1], mean[2] + scale]\n        self._light.position = [mean[0], mean[1], mean[2] + scale]\n\n        self._orbit.exec_three_obj_method('update')\n        self._cam.exec_three_obj_method('updateProjectionMatrix')\n\n    def __get_bbox(self, v):\n        m = np.min(v, axis=0)\n        M = np.max(v, axis=0)\n\n        # Corners of the bounding box\n        v_box = np.array([[m[0], m[1], m[2]], [M[0], m[1], m[2]], [M[0], M[1], m[2]], [m[0], M[1], m[2]],\n                          [m[0], m[1], M[2]], [M[0], m[1], M[2]], [M[0], M[1], M[2]], [m[0], M[1], M[2]]])\n\n        f_box = np.array([[0, 1], [1, 2], [2, 3], [3, 0], [4, 5], [5, 6], [6, 7], [7, 4],\n                          [0, 4], [1, 5], [2, 6], [7, 3]], dtype=np.uint32)\n        return v_box, f_box\n\n    def __get_colors(self, v, f, c, sh):\n        coloring = \"VertexColors\"\n        if type(c) == np.ndarray and c.size == 3:  # Single color\n            colors = np.ones_like(v)\n            colors[:, 0] = c[0]\n            colors[:, 1] = c[1]\n            colors[:, 2] = c[2]\n            # print(\"Single colors\")\n        elif type(c) == np.ndarray and len(c.shape) == 2 and c.shape[1] == 3:  # Color values for\n            if c.shape[0] == f.shape[0]:  # faces\n                colors = np.hstack([c, c, c]).reshape((-1, 3))\n                coloring = \"FaceColors\"\n                # print(\"Face color values\")\n            elif c.shape[0] == v.shape[0]:  # vertices\n                colors = c\n                # print(\"Vertex color values\")\n            else:  # Wrong size, fallback\n                print(\"Invalid color array given! 
Supported are numpy arrays.\", type(c))\n                colors = np.ones_like(v)\n                colors[:, 0] = 1.0\n                colors[:, 1] = 0.874\n                colors[:, 2] = 0.0\n        elif type(c) == np.ndarray and c.size == f.shape[0]:  # Function values for faces\n            normalize = sh[\"normalize\"][0] != None and sh[\"normalize\"][1] != None\n            cc = get_colors(c, sh[\"colormap\"], normalize=normalize,\n                            vmin=sh[\"normalize\"][0], vmax=sh[\"normalize\"][1])\n            # print(cc.shape)\n            colors = np.hstack([cc, cc, cc]).reshape((-1, 3))\n            coloring = \"FaceColors\"\n            # print(\"Face function values\")\n        elif type(c) == np.ndarray and c.size == v.shape[0]:  # Function values for vertices\n            normalize = sh[\"normalize\"][0] != None and sh[\"normalize\"][1] != None\n            colors = get_colors(c, sh[\"colormap\"], normalize=normalize,\n                                vmin=sh[\"normalize\"][0], vmax=sh[\"normalize\"][1])\n            # print(\"Vertex function values\")\n\n        else:\n            colors = np.ones_like(v)\n            colors[:, 0] = 1.0\n            colors[:, 1] = 0.874\n            colors[:, 2] = 0.0\n\n            # No color\n            if c is not None:\n                print(\"Invalid color array given! Supported are numpy arrays.\", type(c))\n\n        return colors, coloring\n\n    def __get_point_colors(self, v, c, sh):\n        v_color = True\n        if c is None:  # No color given, use global color\n            # conv = mpl.colors.ColorConverter()\n            colors = sh[\"point_color\"]  # np.array(conv.to_rgb(sh[\"point_color\"]))\n            v_color = False\n        elif isinstance(c, str):  # No color given, use global color\n            # conv = mpl.colors.ColorConverter()\n            colors = c  # np.array(conv.to_rgb(c))\n            v_color = False\n        elif type(c) == np.ndarray and len(c.shape) == 2 and c.shape[0] == v.shape[0] and c.shape[1] == 3:\n            # Point color\n            colors = c.astype(\"float32\", copy=False)\n\n        elif isinstance(c, np.ndarray) and len(c.shape) == 2 and c.shape[0] == v.shape[0] and c.shape[1] != 3:\n            # Function values for vertices, but the colors are features\n            c_norm = np.linalg.norm(c, ord=2, axis=-1)\n            normalize = sh[\"normalize\"][0] != None and sh[\"normalize\"][1] != None\n            colors = get_colors(c_norm, sh[\"colormap\"], normalize=normalize,\n                                vmin=sh[\"normalize\"][0], vmax=sh[\"normalize\"][1])\n            colors = colors.astype(\"float32\", copy=False)\n\n        elif type(c) == np.ndarray and c.size == v.shape[0]:  # Function color\n            normalize = sh[\"normalize\"][0] != None and sh[\"normalize\"][1] != None\n            colors = get_colors(c, sh[\"colormap\"], normalize=normalize,\n                                vmin=sh[\"normalize\"][0], vmax=sh[\"normalize\"][1])\n            colors = colors.astype(\"float32\", copy=False)\n            # print(\"Vertex function values\")\n\n        else:\n            print(\"Invalid color array given! 
Supported are numpy arrays.\", type(c))\n            colors = sh[\"point_color\"]\n            v_color = False\n\n        return colors, v_color\n\n    def add_mesh(self, v, f, c=None, uv=None, n=None, shading={}, texture_data=None, **kwargs):\n        shading.update(kwargs)\n        sh = self.__get_shading(shading)\n        mesh_obj = {}\n\n        # it is a tet\n        if v.shape[1] == 3 and f.shape[1] == 4:\n            f_tmp = np.ndarray([f.shape[0] * 4, 3], dtype=f.dtype)\n            for i in range(f.shape[0]):\n                f_tmp[i * 4 + 0] = np.array([f[i][1], f[i][0], f[i][2]])\n                f_tmp[i * 4 + 1] = np.array([f[i][0], f[i][1], f[i][3]])\n                f_tmp[i * 4 + 2] = np.array([f[i][1], f[i][2], f[i][3]])\n                f_tmp[i * 4 + 3] = np.array([f[i][2], f[i][0], f[i][3]])\n            f = f_tmp\n\n        if v.shape[1] == 2:\n            v = np.append(v, np.zeros([v.shape[0], 1]), 1)\n\n        # Type adjustment vertices\n        v = v.astype(\"float32\", copy=False)\n\n        # Color setup\n        colors, coloring = self.__get_colors(v, f, c, sh)\n\n        # Type adjustment faces and colors\n        c = colors.astype(\"float32\", copy=False)\n\n        # Material and geometry setup\n        ba_dict = {\"color\": p3s.BufferAttribute(c)}\n        if coloring == \"FaceColors\":\n            verts = np.zeros((f.shape[0] * 3, 3), dtype=\"float32\")\n            for ii in range(f.shape[0]):\n                # print(ii*3, f[ii])\n                verts[ii * 3] = v[f[ii, 0]]\n                verts[ii * 3 + 1] = v[f[ii, 1]]\n                verts[ii * 3 + 2] = v[f[ii, 2]]\n            v = verts\n        else:\n            f = f.astype(\"uint32\", copy=False).ravel()\n            ba_dict[\"index\"] = p3s.BufferAttribute(f, normalized=False)\n\n        ba_dict[\"position\"] = p3s.BufferAttribute(v, normalized=False)\n\n        if uv is not None:\n            uv = (uv - np.min(uv)) / (np.max(uv) - np.min(uv))\n            if texture_data is None:\n                texture_data = gen_checkers(20, 20)\n            tex = p3s.DataTexture(data=texture_data, format=\"RGBFormat\", type=\"FloatType\")\n            material = p3s.MeshStandardMaterial(map=tex, reflectivity=sh[\"reflectivity\"], side=sh[\"side\"],\n                                                roughness=sh[\"roughness\"], metalness=sh[\"metalness\"],\n                                                flatShading=sh[\"flat\"],\n                                                polygonOffset=True, polygonOffsetFactor=1, polygonOffsetUnits=5)\n            ba_dict[\"uv\"] = p3s.BufferAttribute(uv.astype(\"float32\", copy=False))\n        else:\n            material = p3s.MeshStandardMaterial(vertexColors=coloring, reflectivity=sh[\"reflectivity\"],\n                                                side=sh[\"side\"], roughness=sh[\"roughness\"], metalness=sh[\"metalness\"],\n                                                flatShading=sh[\"flat\"],\n                                                polygonOffset=True, polygonOffsetFactor=1, polygonOffsetUnits=5)\n\n        if type(n) != type(None) and coloring == \"VertexColors\":  # TODO: properly handle normals for FaceColors as well\n            ba_dict[\"normal\"] = p3s.BufferAttribute(n.astype(\"float32\", copy=False), normalized=True)\n\n        geometry = p3s.BufferGeometry(attributes=ba_dict)\n\n        if coloring == \"VertexColors\" and type(n) == type(None):\n            geometry.exec_three_obj_method('computeVertexNormals')\n        elif coloring == 
\"FaceColors\" and type(n) == type(None):\n            geometry.exec_three_obj_method('computeFaceNormals')\n\n        # Mesh setup\n        mesh = p3s.Mesh(geometry=geometry, material=material)\n\n        # Wireframe setup\n        mesh_obj[\"wireframe\"] = None\n        if sh[\"wireframe\"]:\n            wf_geometry = p3s.WireframeGeometry(mesh.geometry)  # WireframeGeometry\n            wf_material = p3s.LineBasicMaterial(color=sh[\"wire_color\"], linewidth=sh[\"wire_width\"])\n            wireframe = p3s.LineSegments(wf_geometry, wf_material)\n            mesh.add(wireframe)\n            mesh_obj[\"wireframe\"] = wireframe\n\n        # Bounding box setup\n        if sh[\"bbox\"]:\n            v_box, f_box = self.__get_bbox(v)\n            _, bbox = self.add_edges(v_box, f_box, sh, mesh)\n            mesh_obj[\"bbox\"] = [bbox, v_box, f_box]\n\n        # Object setup\n        mesh_obj[\"max\"] = np.max(v, axis=0)\n        mesh_obj[\"min\"] = np.min(v, axis=0)\n        mesh_obj[\"geometry\"] = geometry\n        mesh_obj[\"mesh\"] = mesh\n        mesh_obj[\"material\"] = material\n        mesh_obj[\"type\"] = \"Mesh\"\n        mesh_obj[\"shading\"] = sh\n        mesh_obj[\"coloring\"] = coloring\n        mesh_obj[\"arrays\"] = [v, f, c]  # TODO replays with proper storage or remove if not needed\n\n        return self.__add_object(mesh_obj)\n\n    def add_lines(self, beginning, ending, shading={}, obj=None, **kwargs):\n        shading.update(kwargs)\n        if len(beginning.shape) == 1:\n            if len(beginning) == 2:\n                beginning = np.array([[beginning[0], beginning[1], 0]])\n        else:\n            if beginning.shape[1] == 2:\n                beginning = np.append(\n                    beginning, np.zeros([beginning.shape[0], 1]), 1)\n        if len(ending.shape) == 1:\n            if len(ending) == 2:\n                ending = np.array([[ending[0], ending[1], 0]])\n        else:\n            if ending.shape[1] == 2:\n                ending = np.append(\n                    ending, np.zeros([ending.shape[0], 1]), 1)\n\n        sh = self.__get_shading(shading)\n        lines = np.hstack([beginning, ending])\n        lines = lines.reshape((-1, 3))\n        return self.__add_line_geometry(lines, sh, obj)\n\n    def add_edges(self, vertices, edges, shading={}, obj=None, **kwargs):\n        shading.update(kwargs)\n        if vertices.shape[1] == 2:\n            vertices = np.append(\n                vertices, np.zeros([vertices.shape[0], 1]), 1)\n        sh = self.__get_shading(shading)\n        lines = np.zeros((edges.size, 3))\n        cnt = 0\n        for e in edges:\n            lines[cnt, :] = vertices[e[0]]\n            lines[cnt + 1, :] = vertices[e[1]]\n            cnt += 2\n        return self.__add_line_geometry(lines, sh, obj)\n\n    def add_points(self, points, c=None, shading={}, obj=None, **kwargs):\n        shading.update(kwargs)\n        if len(points.shape) == 1:\n            if len(points) == 2:\n                points = np.array([[points[0], points[1], 0]])\n        else:\n            if points.shape[1] == 2:\n                points = np.append(\n                    points, np.zeros([points.shape[0], 1]), 1)\n        sh = self.__get_shading(shading)\n        points = points.astype(\"float32\", copy=False)\n        mi = np.min(points, axis=0)\n        ma = np.max(points, axis=0)\n\n        g_attributes = {\"position\": p3s.BufferAttribute(points, normalized=False)}\n        m_attributes = {\"size\": sh[\"point_size\"]}\n\n        if 
sh[\"point_shape\"] == \"circle\":  # Plot circles\n            tex = p3s.DataTexture(data=gen_circle(16, 16), format=\"RGBAFormat\", type=\"FloatType\")\n            m_attributes[\"map\"] = tex\n            m_attributes[\"alphaTest\"] = 0.5\n            m_attributes[\"transparency\"] = True\n        else:  # Plot squares\n            pass\n\n        colors, v_colors = self.__get_point_colors(points, c, sh)\n        if v_colors:  # Colors per point\n            m_attributes[\"vertexColors\"] = 'VertexColors'\n            g_attributes[\"color\"] = p3s.BufferAttribute(colors, normalized=False)\n\n        else:  # Colors for all points\n            m_attributes[\"color\"] = colors\n\n        material = p3s.PointsMaterial(**m_attributes)\n        geometry = p3s.BufferGeometry(attributes=g_attributes)\n        points = p3s.Points(geometry=geometry, material=material)\n        point_obj = {\"geometry\": geometry, \"mesh\": points, \"material\": material,\n                     \"max\": ma, \"min\": mi, \"type\": \"Points\", \"wireframe\": None}\n\n        if obj:\n            return self.__add_object(point_obj, obj), point_obj\n        else:\n            return self.__add_object(point_obj)\n\n    def remove_object(self, obj_id):\n        if obj_id not in self.__objects:\n            print(\"Invalid object id. Valid ids are: \", list(self.__objects.keys()))\n            return\n        self._scene.remove(self.__objects[obj_id][\"mesh\"])\n        del self.__objects[obj_id]\n        self.__update_view()\n\n    def reset(self):\n        for obj_id in list(self.__objects.keys()).copy():\n            self._scene.remove(self.__objects[obj_id][\"mesh\"])\n            del self.__objects[obj_id]\n        self.__update_view()\n\n    def update_object(self, oid=0, vertices=None, colors=None, faces=None):\n        obj = self.__objects[oid]\n        if type(vertices) != type(None):\n            if obj[\"coloring\"] == \"FaceColors\":\n                f = obj[\"arrays\"][1]\n                verts = np.zeros((f.shape[0] * 3, 3), dtype=\"float32\")\n                for ii in range(f.shape[0]):\n                    # print(ii*3, f[ii])\n                    verts[ii * 3] = vertices[f[ii, 0]]\n                    verts[ii * 3 + 1] = vertices[f[ii, 1]]\n                    verts[ii * 3 + 2] = vertices[f[ii, 2]]\n                v = verts\n\n            else:\n                v = vertices.astype(\"float32\", copy=False)\n            obj[\"geometry\"].attributes[\"position\"].array = v\n            # self.wireframe.attributes[\"position\"].array = v # Wireframe updates?\n            obj[\"geometry\"].attributes[\"position\"].needsUpdate = True\n        #           obj[\"geometry\"].exec_three_obj_method('computeVertexNormals')\n        if type(colors) != type(None):\n            colors, coloring = self.__get_colors(obj[\"arrays\"][0], obj[\"arrays\"][1], colors, obj[\"shading\"])\n            colors = colors.astype(\"float32\", copy=False)\n            obj[\"geometry\"].attributes[\"color\"].array = colors\n            obj[\"geometry\"].attributes[\"color\"].needsUpdate = True\n        if type(faces) != type(None):\n            if obj[\"coloring\"] == \"FaceColors\":\n                print(\"Face updates are currently only possible in vertex color mode.\")\n                return\n            f = faces.astype(\"uint32\", copy=False).ravel()\n            print(obj[\"geometry\"].attributes)\n            obj[\"geometry\"].attributes[\"index\"].array = f\n            # self.wireframe.attributes[\"position\"].array = v # 
Wireframe updates?\n            obj[\"geometry\"].attributes[\"index\"].needsUpdate = True\n        #            obj[\"geometry\"].exec_three_obj_method('computeVertexNormals')\n        # self.mesh.geometry.verticesNeedUpdate = True\n        # self.mesh.geometry.elementsNeedUpdate = True\n        # self.update()\n        if self.render_mode == \"WEBSITE\":\n            return self\n\n    #    def update(self):\n    #        self.mesh.exec_three_obj_method('update')\n    #        self.orbit.exec_three_obj_method('update')\n    #        self.cam.exec_three_obj_method('updateProjectionMatrix')\n    #        self.scene.exec_three_obj_method('update')\n\n    def add_text(self, text, shading={}, **kwargs):\n        shading.update(kwargs)\n        sh = self.__get_shading(shading)\n        tt = p3s.TextTexture(string=text, color=sh[\"text_color\"])\n        sm = p3s.SpriteMaterial(map=tt)\n        text = p3s.Sprite(material=sm, scaleToTexture=True)\n        self._scene.add(text)\n\n    # def add_widget(self, widget, callback):\n    #    self.widgets.append(widget)\n    #    widget.observe(callback, names='value')\n\n    #    def add_dropdown(self, options, default, desc, cb):\n    #        widget = widgets.Dropdown(options=options, value=default, description=desc)\n    #        self.__widgets.append(widget)\n    #        widget.observe(cb, names=\"value\")\n    #        display(widget)\n\n    #    def add_button(self, text, cb):\n    #        button = widgets.Button(description=text)\n    #        self.__widgets.append(button)\n    #        button.on_click(cb)\n    #        display(button)\n\n    def to_html(self, imports=True, html_frame=True):\n        # Bake positions (fixes centering bug in offline rendering)\n        if len(self.__objects) == 0:\n            return\n        ma = np.zeros((len(self.__objects), 3))\n        mi = np.zeros((len(self.__objects), 3))\n        for r, obj in enumerate(self.__objects):\n            ma[r] = self.__objects[obj][\"max\"]\n            mi[r] = self.__objects[obj][\"min\"]\n        ma = np.max(ma, axis=0)\n        mi = np.min(mi, axis=0)\n        diag = np.linalg.norm(ma - mi)\n        mean = (ma - mi) / 2 + mi\n        for r, obj in enumerate(self.__objects):\n            v = self.__objects[obj][\"geometry\"].attributes[\"position\"].array\n            v -= mean\n            # v += np.array([0.0, .9, 0.0]) #! to move the obj to the center of window\n\n        scale = self.__s[\"scale\"] * (diag)\n        self._orbit.target = [0.0, 0.0, 0.0]\n        self._cam.lookAt([0.0, 0.0, 0.0])\n        # self._cam.position = [0.0, 0.0, scale]\n        self._cam.position = [0.0, 0.5, scale * 1.3] #! 
show four complete meshes in the window\n        self._light.position = [0.0, 0.0, scale]\n\n        state = embed.dependency_state(self._renderer)\n\n        # Somehow these entries are missing when the state is exported in python.\n        # Exporting from the GUI works, so we are inserting the missing entries.\n        for k in state:\n            if state[k][\"model_name\"] == \"OrbitControlsModel\":\n                state[k][\"state\"][\"maxAzimuthAngle\"] = \"inf\"\n                state[k][\"state\"][\"maxDistance\"] = \"inf\"\n                state[k][\"state\"][\"maxZoom\"] = \"inf\"\n                state[k][\"state\"][\"minAzimuthAngle\"] = \"-inf\"\n\n        tpl = embed.load_requirejs_template\n        if not imports:\n            embed.load_requirejs_template = \"\"\n\n        s = embed.embed_snippet(self._renderer, state=state, embed_url=EMBED_URL)\n        # s = embed.embed_snippet(self.__w, state=state)\n        embed.load_requirejs_template = tpl\n\n        if html_frame:\n            s = \"<html>\\n<body>\\n\" + s + \"\\n</body>\\n</html>\"\n\n        # Revert changes\n        for r, obj in enumerate(self.__objects):\n            v = self.__objects[obj][\"geometry\"].attributes[\"position\"].array\n            v += mean\n        self.__update_view()\n\n        return s\n\n    def save(self, filename=\"\"):\n        if filename == \"\":\n            uid = str(uuid.uuid4()) + \".html\"\n        else:\n            filename = filename.replace(\".html\", \"\")\n            uid = filename + '.html'\n        with open(uid, \"w\") as f:\n            f.write(self.to_html())\n        print(\"Plot saved to file %s.\" % uid)\n"
  },
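  {
    "path": "ultrashape/utils/visualizers/pythreejs_viewer_example.py",
    "content": "# Hypothetical example file, not part of the original codebase: an illustrative sketch that\n# renders a single triangle with PyThreeJSViewer and saves a standalone HTML page.\n# Assumes pythreejs and ipywidgets are installed; the output filename is a placeholder.\nimport numpy as np\n\nfrom ultrashape.utils.visualizers.pythreejs_viewer import PyThreeJSViewer\n\nif __name__ == '__main__':\n    vertices = np.array([\n        [0.0, 0.0, 0.0],\n        [1.0, 0.0, 0.0],\n        [0.0, 1.0, 0.0],\n    ], dtype=np.float32)\n    faces = np.array([[0, 1, 2]], dtype=np.uint32)\n\n    viewer = PyThreeJSViewer(settings={}, render_mode='OFFLINE')\n    viewer.add_mesh(vertices, faces, shading={'wireframe': True})\n    viewer.save('triangle_preview')  # writes triangle_preview.html\n"
  },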
  {
    "path": "ultrashape/utils/voxelize.py",
    "content": "import torch\n\ndef voxelize_from_point(pc, num_latents, resolution=128):\n\n    B, N, D = pc.shape\n    device = pc.device\n    \n    norm_pc = (pc + 1.0) / 2.0\n    voxel_indices = torch.floor(norm_pc * resolution).long()\n    voxel_indices = torch.clamp(voxel_indices, 0, resolution - 1) # (B, N, 3)\n    \n    batch_idx = torch.arange(B, device=device).view(B, 1).expand(B, N)\n    flat_indices = torch.cat([batch_idx.unsqueeze(-1), voxel_indices], dim=-1).view(-1, 4)\n    unique_voxels = torch.unique(flat_indices, dim=0)\n    u_batch_ids = unique_voxels[:, 0]\n\n    noise = torch.rand_like(u_batch_ids, dtype=torch.float)\n    sort_keys = u_batch_ids.float() + noise\n    perm = torch.argsort(sort_keys)\n    shuffled_voxels = unique_voxels[perm]\n    shuffled_batch_ids = shuffled_voxels[:, 0].contiguous()\n    \n    counts = torch.bincount(shuffled_batch_ids, minlength=B)\n    min_count = counts.min().item()\n    \n    # Always aim for num_latents\n    actual_k = num_latents\n    \n    if min_count < num_latents:\n        print(f\"[Info] Voxel count ({min_count}) < Target ({num_latents}). Sampling with replacement.\")\n        # If we don't have enough unique voxels, we need to sample with replacement/repetition\n        # We can just repeat the indices to fill the gap\n\n        batch_starts = torch.searchsorted(shuffled_batch_ids, torch.arange(B, device=device))\n\n        # Create gathering indices that wrap around for each batch\n        # For each batch element i, we want actual_k indices\n        # They start at batch_starts[i] and go up to batch_starts[i] + counts[i]\n        # We use modulo to wrap around: (j % counts[i]) + batch_starts[i]\n\n        # Expand for broadcasting\n        batch_starts_exp = batch_starts.unsqueeze(1) # [B, 1]\n        counts_exp = counts.unsqueeze(1) # [B, 1]\n\n        offsets = torch.arange(actual_k, device=device).unsqueeze(0) # [1, K]\n\n        # Calculate offsets modulo the available count for each batch\n        # This effectively repeats the available voxels to fill the desired size\n        # We need to be careful about division by zero if a batch has 0 voxels (shouldn't happen with valid PC)\n        counts_exp = torch.maximum(counts_exp, torch.tensor(1, device=device))\n\n        wrapped_offsets = offsets % counts_exp\n        gather_indices = batch_starts_exp + wrapped_offsets\n        gather_indices = gather_indices.view(-1)\n\n    else:\n        # Standard case: enough points, just take the first k\n        batch_starts = torch.searchsorted(shuffled_batch_ids, torch.arange(B, device=device))\n        offsets = torch.arange(actual_k, device=device).unsqueeze(0)\n        gather_indices = batch_starts.unsqueeze(1) + offsets\n        gather_indices = gather_indices.view(-1)\n\n    selected_indices = shuffled_voxels[gather_indices]\n    \n    final_grid_coords = selected_indices[:, 1:]\n    \n    # Grid Index -> Voxel Center\n    voxel_size = 2.0 / resolution\n    final_centers = (final_grid_coords.float() + 0.5) * voxel_size - 1.0\n    \n    sampled_pc = final_centers.view(B, actual_k, 3)\n    sampled_indices = final_grid_coords.view(B, actual_k, 3)\n    \n    return sampled_pc, sampled_indices\n"
  },
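  {
    "path": "ultrashape/utils/voxelize_example.py",
    "content": "# Hypothetical example file, not part of the original codebase: an illustrative sketch that\n# downsamples a random point cloud in [-1, 1]^3 to a fixed number of occupied-voxel\n# centers with voxelize_from_point. Assumes the package is importable as `ultrashape`.\nimport torch\n\nfrom ultrashape.utils.voxelize import voxelize_from_point\n\nif __name__ == '__main__':\n    # Batch of 2 point clouds with 10000 points each, coordinates in [-1, 1].\n    pc = torch.rand(2, 10000, 3) * 2.0 - 1.0\n\n    # Pick 512 voxel centers per batch element on a 128^3 grid; when fewer than 512\n    # voxels are occupied, the function falls back to sampling with repetition.\n    centers, grid_indices = voxelize_from_point(pc, num_latents=512, resolution=128)\n\n    print(centers.shape)       # torch.Size([2, 512, 3]): voxel-center coordinates\n    print(grid_indices.shape)  # torch.Size([2, 512, 3]): integer grid indices\n"
  }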
]