Repository: apple/ml-fastvlm Branch: main Commit: 592b4add3c1c Files: 74 Total size: 506.0 KB Directory structure: gitextract_4fmnuh7_/ ├── .gitignore ├── ACKNOWLEDGEMENTS ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── LICENSE_MODEL ├── README.md ├── app/ │ ├── Configuration/ │ │ └── Build.xcconfig │ ├── FastVLM/ │ │ ├── FastVLM.h │ │ ├── FastVLM.swift │ │ └── MediaProcessingExtensions.swift │ ├── FastVLM App/ │ │ ├── Assets.xcassets/ │ │ │ ├── AccentColor.colorset/ │ │ │ │ └── Contents.json │ │ │ ├── AppIcon.appiconset/ │ │ │ │ └── Contents.json │ │ │ └── Contents.json │ │ ├── ContentView.swift │ │ ├── FastVLM.entitlements │ │ ├── FastVLMApp.swift │ │ ├── FastVLMModel.swift │ │ ├── Info.plist │ │ ├── InfoView.swift │ │ └── Preview Content/ │ │ └── Preview Assets.xcassets/ │ │ └── Contents.json │ ├── FastVLM.xcodeproj/ │ │ ├── project.pbxproj │ │ └── xcshareddata/ │ │ └── xcschemes/ │ │ └── FastVLM App.xcscheme │ ├── README.md │ ├── Video/ │ │ ├── CameraController.swift │ │ ├── CameraControlsView.swift │ │ ├── CameraType.swift │ │ ├── Video.h │ │ └── VideoFrameView.swift │ └── get_pretrained_mlx_model.sh ├── get_models.sh ├── llava/ │ ├── __init__.py │ ├── constants.py │ ├── conversation.py │ ├── mm_utils.py │ ├── model/ │ │ ├── __init__.py │ │ ├── apply_delta.py │ │ ├── builder.py │ │ ├── consolidate.py │ │ ├── language_model/ │ │ │ ├── llava_llama.py │ │ │ ├── llava_mistral.py │ │ │ ├── llava_mpt.py │ │ │ └── llava_qwen.py │ │ ├── llava_arch.py │ │ ├── make_delta.py │ │ ├── multimodal_encoder/ │ │ │ ├── builder.py │ │ │ ├── clip_encoder.py │ │ │ ├── mobileclip/ │ │ │ │ ├── __init__.py │ │ │ │ ├── configs/ │ │ │ │ │ └── mobileclip_l.json │ │ │ │ └── mci.py │ │ │ └── mobileclip_encoder.py │ │ ├── multimodal_projector/ │ │ │ └── builder.py │ │ └── utils.py │ ├── serve/ │ │ ├── __init__.py │ │ ├── cli.py │ │ ├── controller.py │ │ ├── gradio_web_server.py │ │ ├── model_worker.py │ │ ├── register_worker.py │ │ ├── sglang_worker.py │ │ └── test_message.py │ 
├── train/ │ │ ├── llama_flash_attn_monkey_patch.py │ │ ├── llama_xformers_attn_monkey_patch.py │ │ ├── llava_trainer.py │ │ ├── train.py │ │ ├── train_mem.py │ │ ├── train_qwen.py │ │ └── train_xformers.py │ └── utils.py ├── model_export/ │ ├── README.md │ ├── export_vision_encoder.py │ └── fastvlm_mlx-vlm.patch ├── predict.py └── pyproject.toml ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ pip-wheel-metadata/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # macOS **/.DS_Store # PyCharm project settings .idea/ # Xcode *.xcworkspace # FastVLM models app/FastVLM/model ================================================ FILE: ACKNOWLEDGEMENTS ================================================ Acknowledgements Portions of this Software may utilize the following copyrighted material, the use of which is hereby acknowledged. --------------------------------------------------------------------------------- LLaVA: Large Language and Vision Assistant (LLaVA) Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. 
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. 
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --------------------------------------------------------------------------------- FastViT (ml-fastvit) Copyright (C) 2023 Apple Inc. All Rights Reserved. IMPORTANT: This Apple software is supplied to you by Apple Inc. ("Apple") in consideration of your agreement to the following terms, and your use, installation, modification or redistribution of this Apple software constitutes acceptance of these terms. If you do not agree with these terms, please do not use, install, modify or redistribute this Apple software. In consideration of your agreement to abide by the following terms, and subject to these terms, Apple grants you a personal, non-exclusive license, under Apple's copyrights in this original Apple software (the "Apple Software"), to use, reproduce, modify and redistribute the Apple Software, with or without modifications, in source and/or binary forms; provided that if you redistribute the Apple Software in its entirety and without modifications, you must retain this notice and the following text and disclaimers in all such redistributions of the Apple Software. Neither the name, trademarks, service marks or logos of Apple Inc. may be used to endorse or promote products derived from the Apple Software without specific prior written permission from Apple. 
Except as expressly stated in this notice, no other rights or licenses, express or implied, are granted by Apple herein, including but not limited to any patent rights that may be infringed by your derivative works or by other works in which the Apple Software may be incorporated. The Apple Software is provided by Apple on an "AS IS" basis. APPLE MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS. IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE), STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. --------------------------------------------------------------------------------- mlx-vlm MIT License Copyright © 2023 Apple Inc. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. --------------------------------------------------------------------------------- MobileCLIP (ml-mobileclip) MIT License Copyright © 2024 Apple Inc. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ------------------------------------------------------------------------------- SOFTWARE DISTRIBUTED WITH ML-MobileCLIP: The ML-MobileCLIP model weights and data copyright and license terms can be found in LICENSE_weights_data. 
The ML-MobileCLIP software includes a number of subcomponents with separate copyright notices and license terms - please see the file ACKNOWLEDGEMENTS. --------------------------------------------------------------------------------- ================================================ FILE: CODE_OF_CONDUCT.md ================================================ # Code of Conduct ## Our Pledge In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. ## Our Standards Examples of behavior that contributes to creating a positive environment include: * Using welcoming and inclusive language * Being respectful of differing viewpoints and experiences * Gracefully accepting constructive criticism * Focusing on what is best for the community * Showing empathy towards other community members Examples of unacceptable behavior by participants include: * The use of sexualized language or imagery and unwelcome sexual attention or advances * Trolling, insulting/derogatory comments, and personal or political attacks * Public or private harassment * Publishing others' private information, such as a physical or electronic address, without explicit permission * Other conduct which could reasonably be considered inappropriate in a professional setting ## Our Responsibilities Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. 
Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. ## Scope This Code of Conduct applies within all project spaces, and it also applies when an individual is representing the project or its community in public spaces. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. ## Enforcement Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership. 
## Attribution This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4, available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html) ================================================ FILE: CONTRIBUTING.md ================================================ # Contribution Guide Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository. While we welcome new pull requests and issues please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged. ## Before you get started By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE). We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md). ================================================ FILE: LICENSE ================================================ Copyright (C) 2025 Apple Inc. All Rights Reserved. IMPORTANT: This Apple software is supplied to you by Apple Inc. ("Apple") in consideration of your agreement to the following terms, and your use, installation, modification or redistribution of this Apple software constitutes acceptance of these terms. If you do not agree with these terms, please do not use, install, modify or redistribute this Apple software. 
In consideration of your agreement to abide by the following terms, and subject to these terms, Apple grants you a personal, non-exclusive license, under Apple's copyrights in this original Apple software (the "Apple Software"), to use, reproduce, modify and redistribute the Apple Software, with or without modifications, in source and/or binary forms; provided that if you redistribute the Apple Software in its entirety and without modifications, you must retain this notice and the following text and disclaimers in all such redistributions of the Apple Software. Neither the name, trademarks, service marks or logos of Apple Inc. may be used to endorse or promote products derived from the Apple Software without specific prior written permission from Apple. Except as expressly stated in this notice, no other rights or licenses, express or implied, are granted by Apple herein, including but not limited to any patent rights that may be infringed by your derivative works or by other works in which the Apple Software may be incorporated. The Apple Software is provided by Apple on an "AS IS" basis. APPLE MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS. IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE), STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
------------------------------------------------------------------------------- SOFTWARE DISTRIBUTED WITH ML-FASTVLM: The ml-fastvlm software includes a number of subcomponents with separate copyright notices and license terms - please see the file ACKNOWLEDGEMENTS. The ml-fastvlm model weights copyright and license terms can be found in LICENSE_MODEL file. ------------------------------------------------------------------------------- ================================================ FILE: LICENSE_MODEL ================================================ Disclaimer: IMPORTANT: This Apple Machine Learning Research Model is specifically developed and released by Apple Inc. ("Apple") for the sole purpose of scientific research of artificial intelligence and machine-learning technology. “Apple Machine Learning Research Model” means the model, including but not limited to algorithms, formulas, trained model weights, parameters, configurations, checkpoints, and any related materials (including documentation). This Apple Machine Learning Research Model is provided to You by Apple in consideration of your agreement to the following terms, and your use, modification, creation of Model Derivatives, and or redistribution of the Apple Machine Learning Research Model constitutes acceptance of this Agreement. If You do not agree with these terms, please do not use, modify, create Model Derivatives of, or distribute this Apple Machine Learning Research Model or Model Derivatives. * License Scope: In consideration of your agreement to abide by the following terms, and subject to these terms, Apple hereby grants you a personal, non-exclusive, worldwide, non-transferable, royalty-free, revocable, and limited license, to use, copy, modify, distribute, and create Model Derivatives (defined below) of the Apple Machine Learning Research Model exclusively for Research Purposes. 
You agree that any Model Derivatives You may create or that may be created for You will be limited to Research Purposes as well. “Research Purposes” means non-commercial scientific research and academic development activities, such as experimentation, analysis, testing conducted by You with the sole intent to advance scientific knowledge and research. “Research Purposes” does not include any commercial exploitation, product development or use in any commercial product or service. * Distribution of Apple Machine Learning Research Model and Model Derivatives: If you choose to redistribute Apple Machine Learning Research Model or its Model Derivatives, you must provide a copy of this Agreement to such third party, and ensure that the following attribution notice be provided: “Apple Machine Learning Research Model is licensed under the Apple Machine Learning Research Model License Agreement.” Additionally, all Model Derivatives must clearly be identified as such, including disclosure of modifications and changes made to the Apple Machine Learning Research Model. The name, trademarks, service marks or logos of Apple may not be used to endorse or promote Model Derivatives or the relationship between You and Apple. “Model Derivatives” means any models or any other artifacts created by modifications, improvements, adaptations, alterations to the architecture, algorithm or training processes of the Apple Machine Learning Research Model, or by any retraining, fine-tuning of the Apple Machine Learning Research Model. * No Other License: Except as expressly stated in this notice, no other rights or licenses, express or implied, are granted by Apple herein, including but not limited to any patent, trademark, and similar intellectual property rights worldwide that may be infringed by the Apple Machine Learning Research Model, the Model Derivatives or by other works in which the Apple Machine Learning Research Model may be incorporated. 
* Compliance with Laws: Your use of Apple Machine Learning Research Model must be in compliance with all applicable laws and regulations. * Term and Termination: The term of this Agreement will begin upon your acceptance of this Agreement or use of the Apple Machine Learning Research Model and will continue until terminated in accordance with the following terms. Apple may terminate this Agreement at any time if You are in breach of any term or condition of this Agreement. Upon termination of this Agreement, You must cease to use all Apple Machine Learning Research Models and Model Derivatives and permanently delete any copy thereof. Sections 3, 6 and 7 will survive termination. * Disclaimer and Limitation of Liability: This Apple Machine Learning Research Model and any outputs generated by the Apple Machine Learning Research Model are provided on an “AS IS” basis. APPLE MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE, REGARDING THE APPLE MACHINE LEARNING RESEARCH MODEL OR OUTPUTS GENERATED BY THE APPLE MACHINE LEARNING RESEARCH MODEL. You are solely responsible for determining the appropriateness of using or redistributing the Apple Machine Learning Research Model and any outputs of the Apple Machine Learning Research Model and assume any risks associated with Your use of the Apple Machine Learning Research Model and any output and results. IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, MODIFICATION AND/OR DISTRIBUTION OF THE APPLE MACHINE LEARNING RESEARCH MODEL AND ANY OUTPUTS OF THE APPLE MACHINE LEARNING RESEARCH MODEL, HOWEVER CAUSED AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE), STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
* Governing Law: This Agreement will be governed by and construed under the laws of the State of California without regard to its choice of law principles. The Convention on Contracts for the International Sale of Goods shall not apply to the Agreement except that the arbitration clause and any arbitration hereunder shall be governed by the Federal Arbitration Act, Chapters 1 and 2. Copyright (C) 2025 Apple Inc. All Rights Reserved.

================================================
FILE: README.md
================================================
# FastVLM: Efficient Vision Encoding for Vision Language Models

This is the official repository of **[FastVLM: Efficient Vision Encoding for Vision Language Models](https://www.arxiv.org/abs/2412.13303) (CVPR 2025)**.

[//]: # (![FastViTHD Performance](docs/acc_vs_latency_qwen-2.png))

![Accuracy vs latency figure.](docs/acc_vs_latency_qwen-2.png)

### Highlights

* We introduce FastViTHD, a novel hybrid vision encoder designed to output fewer tokens and significantly reduce encoding time for high-resolution images.
* Our smallest variant outperforms LLaVA-OneVision-0.5B with an 85x faster Time-to-First-Token (TTFT) and a 3.4x smaller vision encoder.
* Our larger variants using the Qwen2-7B LLM outperform recent works like Cambrian-1-8B while using a single image encoder, with a 7.9x faster TTFT.
* A demo iOS app demonstrates the performance of our model on a mobile device.
FastVLM - Counting FastVLM - Handwriting FastVLM - Emoji
## Getting Started

We use the LLaVA codebase to train FastVLM variants. To train or finetune your own variants, please follow the instructions provided in the [LLaVA](https://github.com/haotian-liu/LLaVA) codebase. We provide instructions for running inference with our models.

### Setup

```bash
conda create -n fastvlm python=3.10
conda activate fastvlm
pip install -e .
```

### Model Zoo

For detailed information on various evaluations, please refer to our [paper](https://www.arxiv.org/abs/2412.13303).

| Model        | Stage | PyTorch Checkpoint (url) |
|:-------------|:-----:|:------------------------:|
| FastVLM-0.5B |   2   | [fastvlm_0.5b_stage2](https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_0.5b_stage2.zip) |
|              |   3   | [fastvlm_0.5b_stage3](https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_0.5b_stage3.zip) |
| FastVLM-1.5B |   2   | [fastvlm_1.5b_stage2](https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_1.5b_stage2.zip) |
|              |   3   | [fastvlm_1.5b_stage3](https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_1.5b_stage3.zip) |
| FastVLM-7B   |   2   | [fastvlm_7b_stage2](https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_7b_stage2.zip) |
|              |   3   | [fastvlm_7b_stage3](https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_7b_stage3.zip) |

To download all the pretrained checkpoints, run the command below (this might take some time depending on your connection, so it might be a good moment to grab ☕️ while you wait).

```bash
bash get_models.sh   # Files will be downloaded to `checkpoints` directory.
```

### Usage Example

To run inference with a PyTorch checkpoint, follow the instructions below:

```bash
python predict.py --model-path /path/to/checkpoint-dir \
                  --image-file /path/to/image.png \
                  --prompt "Describe the image."
```

### Inference on Apple Silicon

To run inference on Apple Silicon, PyTorch checkpoints have to be exported to a format suitable for running on Apple Silicon. Detailed instructions and code can be found in the [`model_export`](model_export/) subfolder; please see the README there for more details.

For convenience, we provide 3 models that are already in an Apple Silicon-compatible format: [fastvlm_0.5b_stage3](https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_0.5b_stage3_llm.fp16.zip), [fastvlm_1.5b_stage3](https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_1.5b_stage3_llm.int8.zip), [fastvlm_7b_stage3](https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_7b_stage3_llm.int4.zip). We encourage developers to export the model of their choice with the appropriate quantization levels following the instructions in [`model_export`](model_export/).

### Inference on Apple Devices

To run inference on Apple devices like iPhone, iPad or Mac, see the [`app`](app/) subfolder for more details.

## Citation

If you found this code useful, please cite the following paper:

```
@InProceedings{fastvlm2025,
  author = {Pavan Kumar Anasosalu Vasu and Fartash Faghri and Chun-Liang Li and Cem Koc and Nate True and Albert Antony and Gokul Santhanam and James Gabriel and Peter Grasch and Oncel Tuzel and Hadi Pouransari},
  title = {FastVLM: Efficient Vision Encoding for Vision Language Models},
  booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  month = {June},
  year = {2025},
}
```

## Acknowledgements

Our codebase is built using multiple open-source contributions; please see [ACKNOWLEDGEMENTS](ACKNOWLEDGEMENTS) for more details.

## License

Please check out the repository [LICENSE](LICENSE) before using the provided code and [LICENSE_MODEL](LICENSE_MODEL) for the released models.
================================================
FILE: app/Configuration/Build.xcconfig
================================================
// The `DISAMBIGUATOR` configuration is to make it easier to build
// and run a sample code project. Once you set your project's development team,
// you'll have a unique bundle identifier. This is because the bundle identifier
// is derived based on the 'DISAMBIGUATOR' value. Do not use this
// approach in your own projects—it's only useful for example projects because
// they are frequently downloaded and don't have a development team set.
DISAMBIGUATOR=${DEVELOPMENT_TEAM}

================================================
FILE: app/FastVLM/FastVLM.h
================================================
//
// For licensing see accompanying LICENSE file.
// Copyright (C) 2025 Apple Inc. All Rights Reserved.
//

#ifndef FastVLM_h
#define FastVLM_h


#endif /* FastVLM_h */

================================================
FILE: app/FastVLM/FastVLM.swift
================================================
//
// For licensing see accompanying LICENSE file.
// Copyright (C) 2025 Apple Inc. All Rights Reserved.
//

import CoreImage
import CoreML
import Foundation
import MLX
import MLXFast
import MLXLMCommon
import MLXNN
import MLXVLM
import Tokenizers

// FastVLM is Qwen2VL with a custom vision tower.

// MARK: - Common

/// Rotates half the hidden dims of the input
private func rotateHalf(_ x: MLXArray) -> MLXArray {
    let index = x.dim(-1) / 2
    let x1 = x[.ellipsis, 0 ..< index]
    let x2 = x[.ellipsis, index...]
    return concatenated([-x2, x1], axis: -1)
}

// MARK: - Language

private enum Language {

    /// Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors
    static private func applyMultimodalRotaryPositionEmbedding(
        q: MLXArray, k: MLXArray, cos: MLXArray, sin: MLXArray,
        positionIds: MLXArray, mropeSection: [Int]
    ) -> (MLXArray, MLXArray) {
        var cos = cos[positionIds]
        var sin = sin[positionIds]

        cos =
            concatenated(
                // [m[i % 3] for i, m in enumerate(mx.split(cos, mrope_section, axis=-1))]
                split(cos, indices: mropeSection, axis: -1).enumerated().map { i, m in m[i % 3] },
                axis: -1
            )[0..., .newAxis, 0..., 0...]

        sin =
            concatenated(
                split(sin, indices: mropeSection, axis: -1).enumerated().map { i, m in m[i % 3] },
                axis: -1
            )[0..., .newAxis, 0..., 0...]

        // Apply rotary embedding
        let qEmbed = (q * cos) + (rotateHalf(q) * sin)
        let kEmbed = (k * cos) + (rotateHalf(k) * sin)
        return (qEmbed, kEmbed)
    }

    /// Multi-head attention with grouped key/value heads and rotary position embedding.
    fileprivate class Attention: Module {

        let heads: Int
        let kvHeads: Int
        let headDim: Int
        let scale: Float
        let mropeSection: [Int]

        @ModuleInfo(key: "q_proj") var wq: Linear
        @ModuleInfo(key: "k_proj") var wk: Linear
        @ModuleInfo(key: "v_proj") var wv: Linear
        @ModuleInfo(key: "o_proj") var wo: Linear

        @ModuleInfo(key: "rotary_emb") var rotaryEmbedding: RoPE

        public init(_ args: FastVLMConfiguration.TextConfiguration) {
            let dim = args.hiddenSize
            self.heads = args.attentionHeads
            self.kvHeads = args.kvHeads
            self.headDim = dim / heads
            self.scale = pow(Float(headDim), -0.5)

            self._wq.wrappedValue = Linear(dim, heads * headDim, bias: true)
            self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: true)
            self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: true)
            self._wo.wrappedValue = Linear(heads * headDim, dim, bias: false)

            if let v = args.ropeScaling?["mrope_section"], let array = v.asInts() {
                // mrope_section = np.cumsum(mrope_section * 2)[:-1].tolist()
                self.mropeSection = sequence(state: (0, array.makeIterator())) { state in
                    if let v = state.1.next() {
                        // note the *2
                        state.0 += v * 2
                        return state.0
                    } else {
                        return nil
                    }
                }.dropLast()
            } else {
                fatalError("rope_scaling['mrope_section'] must be an array of integers")
            }

            self._rotaryEmbedding.wrappedValue = RoPE(
                dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta)
        }

        public func callAsFunction(
            _ x: MLXArray, mask: MLXArray? = nil, cache: KVCache?
        ) -> MLXArray {
            let (B, L) = (x.dim(0), x.dim(1))

            var queries = wq(x)
            var keys = wk(x)
            var values = wv(x)

            // prepare the queries, keys and values for the attention computation
            queries = queries.reshaped(B, L, heads, headDim).transposed(0, 2, 1, 3)
            keys = keys.reshaped(B, L, kvHeads, headDim).transposed(0, 2, 1, 3)
            values = values.reshaped(B, L, kvHeads, headDim).transposed(0, 2, 1, 3)

            let offset = cache?.offset ?? 0
            let mask = mask?[0..., 0 ..< keys.dim(-2)]

            queries = rotaryEmbedding(queries, offset: offset)
            keys = rotaryEmbedding(keys, offset: offset)

            if let cache {
                (keys, values) = cache.update(keys: keys, values: values)
            }

            let output = MLXFast.scaledDotProductAttention(
                queries: queries, keys: keys, values: values, scale: scale, mask: mask
            )
            .transposed(0, 2, 1, 3)
            .reshaped(B, L, -1)

            return wo(output)
        }
    }

    /// Gated (SiLU) feed-forward block: `down(silu(gate(x)) * up(x))`.
    fileprivate class MLP: Module, UnaryLayer {

        @ModuleInfo(key: "gate_proj") var gate: Linear
        @ModuleInfo(key: "down_proj") var down: Linear
        @ModuleInfo(key: "up_proj") var up: Linear

        public init(dimensions: Int, hiddenDimensions: Int) {
            self._gate.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
            self._down.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false)
            self._up.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
        }

        public func callAsFunction(_ x: MLXArray) -> MLXArray {
            down(silu(gate(x)) * up(x))
        }
    }

    /// Pre-norm transformer layer: attention and MLP, each wrapped in a residual connection.
    fileprivate class FastVLMDecoderLayer: Module {
        @ModuleInfo(key: "self_attn") var attention: Attention
        let mlp: MLP

        @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
        @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm

        public init(_ args: FastVLMConfiguration.TextConfiguration) {
            self._attention.wrappedValue = Attention(args)
            self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize)
            self._inputLayerNorm.wrappedValue = RMSNorm(
                dimensions: args.hiddenSize, eps: args.rmsNormEps)
            self._postAttentionLayerNorm.wrappedValue = RMSNorm(
                dimensions: args.hiddenSize, eps: args.rmsNormEps)
        }

        public func callAsFunction(
            _ x: MLXArray, mask: MLXArray? = nil, cache: KVCache?
        ) -> MLXArray {
            var r = attention(inputLayerNorm(x), mask: mask, cache: cache)
            let h = x + r
            r = mlp(postAttentionLayerNorm(h))
            let out = h + r
            return out
        }
    }

    /// Token embedding, a stack of decoder layers and a final RMSNorm.
    fileprivate class Qwen2Model: Module {

        @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding

        fileprivate let layers: [FastVLMDecoderLayer]
        fileprivate let norm: RMSNorm

        public init(_ args: FastVLMConfiguration.TextConfiguration) {
            precondition(args.vocabularySize > 0)

            self._embedTokens.wrappedValue = Embedding(
                embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)

            self.layers = (0 ..< args.hiddenLayers)
                .map { _ in
                    FastVLMDecoderLayer(args)
                }
            self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
        }

        /// Runs the decoder on either token ids (`inputs`) or precomputed embeddings
        /// (`inputEmbedding`, which takes precedence when both are given).
        public func callAsFunction(
            _ inputs: MLXArray?, cache: [KVCache]? = nil, inputEmbedding: MLXArray? = nil
        ) -> MLXArray {
            var h: MLXArray
            if let inputEmbedding {
                h = inputEmbedding
            } else if let inputs {
                h = embedTokens(inputs)
            } else {
                fatalError("one of inputs or inputEmbedding must be non-nil")
            }

            let mask = createAttentionMask(h: h, cache: cache)

            for (i, layer) in layers.enumerated() {
                h = layer(h, mask: mask, cache: cache?[i])
            }

            return norm(h)
        }
    }

    /// Qwen2 decoder plus output head; uses the tied embedding matrix when
    /// `tieWordEmbeddings` is set, otherwise a separate `lm_head`.
    fileprivate class LanguageModel: Module, KVCacheDimensionProvider {
        @ModuleInfo var model: Qwen2Model
        @ModuleInfo(key: "lm_head") var lmHead: Linear?

        var kvHeads: [Int]

        public init(_ args: FastVLMConfiguration.TextConfiguration) {
            self.model = Qwen2Model(args)

            if !args.tieWordEmbeddings {
                _lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
            }

            self.kvHeads = (0 ..< args.hiddenLayers).map { _ in args.kvHeads }
        }

        public func callAsFunction(
            _ inputs: MLXArray?, cache: [KVCache]? = nil, inputEmbedding: MLXArray? = nil
        ) -> LMOutput {
            var out = model(inputs, cache: cache, inputEmbedding: inputEmbedding)
            if let lmHead {
                out = lmHead(out)
            } else {
                out = model.embedTokens.asLinear(out)
            }
            return LMOutput(logits: out)
        }
    }
}

// MARK: - Vision

private enum Vision {

    /// Wraps the CoreML `fastvithd` vision encoder. The model is created lazily on first
    /// use; creation is serialized by `lock` so it is constructed at most once.
    fileprivate class VisionModelCoreML {

        let lock = NSLock()
        var _model: fastvithd?

        init() {
        }

        func load() throws -> fastvithd {
            try lock.withLock {
                if let model = _model { return model }
                let model = try fastvithd()
                _model = model
                return model
            }
        }

        public func model() -> fastvithd {
            try! load()
        }

        /// Runs the CoreML encoder on a `[1, 3, H, W]` image and returns the
        /// `[1, 256, 3072]` float32 feature array (both enforced by preconditions).
        public func encode(_ image: MLXArray) -> MLXArray {
            // MLMultiArray requires mutable input data
            var (data, strides) = {
                let arrayData = image.asType(.float32).asData(access: .noCopyIfContiguous)
                return (arrayData.data, arrayData.strides)
            }()

            precondition(image.ndim == 4)
            precondition(image.dim(0) == 1)
            precondition(image.dim(1) == 3)

            let h = NSNumber(value: image.dim(2))
            let w = NSNumber(value: image.dim(3))

            return data.withUnsafeMutableBytes { (ptr: UnsafeMutableRawBufferPointer) in
                // wrap the backing of the MLXArray
                let array = try! MLMultiArray(
                    dataPointer: ptr.baseAddress!, shape: [1, 3, h, w], dataType: .float32,
                    strides: strides.map { .init(value: $0) })

                // inference
                let output = try! model().prediction(images: array)
                precondition(output.image_features.shape == [1, 256, 3072])
                precondition(output.image_features.dataType == .float32)

                // copy the features out while the CoreML output buffer is still alive
                return output.image_features.withUnsafeBytes { ptr in
                    MLXArray(ptr, [1, 256, 3072], type: Float32.self)
                }
            }
        }
    }

    /// `Module` shim so the CoreML encoder sits under the `vision_tower` key.
    fileprivate class VisionModel: Module {

        let model = VisionModelCoreML()

        public override init() {}

        public func callAsFunction(_ hiddenStates: MLXArray, gridThw: [THW]) -> MLXArray {
            model.encode(hiddenStates)
        }
    }
}

// MARK: - Processor

/// FastVLM `UserInputProcessor`.
///
/// This is meant to be used with ``FastVLM`` and is typically created by ``VLMModelFactory``.
public class FastVLMProcessor: UserInputProcessor {
    private let config: FastVLMProcessorConfiguration
    private let imageProcessingConfig: FastVLMPreProcessorConfiguration
    private let tokenizer: any Tokenizer

    public init(_ config: FastVLMPreProcessorConfiguration, tokenizer: any Tokenizer) {
        self.config = FastVLMProcessorConfiguration()
        self.imageProcessingConfig = config
        self.tokenizer = tokenizer
    }

    /// Resizes (shortest edge), center-crops and normalizes `image`, returning a planar
    /// `[1, 3, H, W]` array and its THW descriptor (T is always 0 here).
    public func preprocess(image: CIImage, processing: UserInput.Processing?) throws -> (
        MLXArray, THW
    ) {
        // first apply the user requested resizing, etc. if any
        var image = MediaProcessingExtensions.apply(image, processing: processing)

        // image_processing_clip.py
        let size = MediaProcessingExtensions.fitIn(
            image.extent.size, shortestEdge: imageProcessingConfig.size.shortestEdge)
        image = MediaProcessingExtensions.resampleBicubic(image, to: size)

        image = MediaProcessingExtensions.centerCrop(
            image, size: imageProcessingConfig.cropSize.size)

        image = MediaProcessing.normalize(
            image, mean: imageProcessingConfig.imageMeanTuple,
            std: imageProcessingConfig.imageStdTuple)

        let array = MediaProcessingExtensions.asPlanarMLXArray(image)
        return (array, .init(0, array.dim(2), array.dim(3)))
    }

    /// Renders the chat template, inserting a system prompt if missing and appending one
    /// image placeholder token per patch to the last message when an image is present.
    public func prepare(prompt: UserInput.Prompt, imageTHW: THW?) -> String {
        var messages = prompt.asMessages()
        // `first` (instead of `[0]`) avoids a crash when the prompt yields no messages
        if messages.first?["role"] != "system" {
            messages.insert(["role": "system", "content": "You are a helpful assistant."], at: 0)
        }

        let lastIndex = messages.count - 1
        var lastMessage = messages[lastIndex]["content"] ?? ""

        // processing_llava.py
        if let imageTHW {
            let height = imageTHW.h
            let width = imageTHW.w
            let patchSize = config.patchSize

            var numImageTokens =
                (height / patchSize) * (width / patchSize) + config.numAdditionalImageTokens

            if config.visionFeatureSelectStrategy == .default {
                numImageTokens -= 1
            }

            lastMessage += Array(repeating: config.imageToken, count: numImageTokens)
                .joined()
        }

        messages[lastIndex]["content"] = lastMessage

        return
            messages
            .map {
                "<|im_start|>\($0["role"] ?? "user")\n\($0["content"] ?? "")<|im_end|>"
            }
            .joined(separator: "\n")
            + "\n<|im_start|>assistant\n"
    }

    /// Converts user input into model input. Text-only prompts produce a 1-D token array;
    /// image prompts produce `[1, L]` tokens, a mask and the preprocessed image.
    ///
    /// - Throws: `VLMError.singleImageAllowed` when more than one image is supplied.
    public func prepare(input: UserInput) throws -> LMInput {
        if input.images.isEmpty {
            // just a straight text prompt
            let prompt = prepare(prompt: input.prompt, imageTHW: nil)
            let promptTokens = tokenizer.encode(text: prompt)
            return LMInput(tokens: MLXArray(promptTokens))
        }

        if input.images.count > 1 {
            throw VLMError.singleImageAllowed
        }

        let (pixels, thw) = try preprocess(
            image: input.images[0].asCIImage(), processing: input.processing)
        let image = LMInput.ProcessedImage(pixels: pixels, imageGridThw: [thw])

        let prompt = prepare(prompt: input.prompt, imageTHW: thw)
        let promptTokens = tokenizer.encode(text: prompt)
        let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0)
        let mask = ones(like: promptArray).asType(.int8)

        return LMInput(text: .init(tokens: promptArray, mask: mask), image: image)
    }
}

// MARK: - Model

/// Two-layer MLP (linear -> GELU -> linear) mapping vision features to the text hidden size.
private class FastVLMMultiModalProjector: Module, UnaryLayer {

    @ModuleInfo(key: "linear_0") var linear0: Linear
    @ModuleInfo(key: "gelu") var gelu: GELU
    @ModuleInfo(key: "linear_2") var linear2: Linear

    public init(_ config: FastVLMConfiguration) {
        self._linear0.wrappedValue = Linear(
            config.visionConfiguration.hiddenSize,
            config.textConfiguration.hiddenSize,
            bias: true)
        self._gelu.wrappedValue = GELU()
        self._linear2.wrappedValue = Linear(
            config.textConfiguration.hiddenSize,
            config.textConfiguration.hiddenSize,
            bias: true)
    }

    public func callAsFunction(_ x: MLXArray) -> MLXArray {
        var x = linear0(x)
        x = gelu(x)
        x = linear2(x)
        return x
    }
}

/// FastVLM
///
/// This is typically created by ``VLMModelFactory``.
public class FastVLM: Module, VLMModel, KVCacheDimensionProvider {

    static public var modelConfiguration: ModelConfiguration {
        let bundle = Bundle(for: FastVLM.self)
        let url = bundle.url(forResource: "config", withExtension: "json")!
            .resolvingSymlinksInPath()
            .deletingLastPathComponent()
        return ModelConfiguration(directory: url)
    }

    static public func register(modelFactory: VLMModelFactory) {
        modelFactory.typeRegistry.registerModelType("llava_qwen2") { url in
            let configuration = try JSONDecoder().decode(
                FastVLMConfiguration.self, from: Data(contentsOf: url))
            return FastVLM(configuration)
        }

        modelFactory.processorRegistry.registerProcessorType("LlavaProcessor") { url, tokenizer in
            let configuration = try JSONDecoder().decode(
                FastVLMPreProcessorConfiguration.self, from: Data(contentsOf: url))
            return FastVLMProcessor(configuration, tokenizer: tokenizer)
        }
    }

    @ModuleInfo(key: "vision_tower") private var visionModel: Vision.VisionModel
    @ModuleInfo(key: "language_model") private var languageModel: Language.LanguageModel
    @ModuleInfo(key: "multi_modal_projector") private var multiModalProjector:
        FastVLMMultiModalProjector

    public let config: FastVLMConfiguration

    public var vocabularySize: Int { config.baseConfiguration.vocabularySize }
    public var kvHeads: [Int] { languageModel.kvHeads }

    public func loraLinearLayers() -> MLXLMCommon.LoRALinearLayers {
        languageModel.model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
    }

    public init(_ config: FastVLMConfiguration) {
        self.config = config
        self._visionModel.wrappedValue = Vision.VisionModel()
        self._languageModel.wrappedValue = Language.LanguageModel(config.textConfiguration)
        self._multiModalProjector.wrappedValue = FastVLMMultiModalProjector(config)
    }

    /// Builds the input embeddings for the language model.
    ///
    /// Without image data this is the token embedding of `inputIds`. With an image, the
    /// vision tower output is projected to the text hidden size and written into the token
    /// embeddings at the image-token positions.
    private func inputEmbeddings(inputIds: MLXArray, pixelValues: MLXArray?, gridThw: [THW]?)
        -> MLXArray
    {
        guard let pixelValues, let gridThw else {
            // Text-only prompt: return token *embeddings* (previously this returned logits,
            // which were then fed back in as embeddings). The text-only processor path
            // produces a 1-D token array, so add the batch dimension the decoder expects.
            let batchedIds = inputIds.ndim == 1 ? inputIds.expandedDimensions(axis: 0) : inputIds
            return languageModel.model.embedTokens(batchedIds)
        }

        // Get the input embeddings from the language model
        let inputEmbeds = languageModel.model.embedTokens(inputIds)

        // Get the output hidden states from the vision model
        let imageFeaturesCoreML = self.visionModel(pixelValues, gridThw: gridThw)
        let imageFeatures = multiModalProjector(imageFeaturesCoreML)

        // Insert special image tokens in the input_ids
        return mergeInputIdsWithImageFeatures(
            inputIds: inputIds, inputEmbeds: inputEmbeds, imageFeatures: imageFeatures)
    }

    /// Overwrites the embeddings at every position whose token id equals
    /// `imageTokenId` with the corresponding rows of `imageFeatures`.
    ///
    /// NOTE(review): indices come from the flattened `inputIds`, so this assumes a batch
    /// size of 1 and exactly as many image tokens as feature rows — confirm with callers.
    private func mergeInputIdsWithImageFeatures(
        inputIds: MLXArray, inputEmbeds: MLXArray, imageFeatures: MLXArray
    ) -> MLXArray {
        let imageTokenIndex = config.baseConfiguration.imageTokenId

        var imageIndices = [Int]()
        for (i, v) in inputIds.asArray(Int.self).enumerated() {
            if v == imageTokenIndex {
                imageIndices.append(i)
            }
        }

        inputEmbeds[0..., MLXArray(imageIndices), 0...] = imageFeatures
        return inputEmbeds
    }

    public func prepare(_ input: LMInput, cache: [any KVCache], windowSize: Int?) throws
        -> PrepareResult
    {
        let gridThw = input.image?.imageGridThw

        let dtype = DType.float32
        let pixels = input.image?.pixels.asType(dtype)

        let inputEmbeddings = self.inputEmbeddings(
            inputIds: input.text.tokens, pixelValues: pixels, gridThw: gridThw)

        let result = languageModel(nil, cache: cache, inputEmbedding: inputEmbeddings)

        return .logits(result)
    }

    public func callAsFunction(_ inputs: MLXArray, cache: [any KVCache]?) -> MLXArray {
        languageModel(inputs, cache: cache).logits
    }

    /// Returns `weights` unchanged; used as a hook to eagerly load the CoreML vision model.
    public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
        _ = try? visionModel.model.load()
        return weights
    }
}

// MARK: - Configuration

/// Configuration for ``FastVLM``
public struct FastVLMConfiguration: Codable, Sendable {
    public struct VisionConfiguration: Codable, Sendable {
        public let hiddenSize: Int

        enum CodingKeys: String, CodingKey {
            case hiddenSize = "mm_hidden_size"
        }
    }

    public struct TextConfiguration: Codable, Sendable {
        public let modelType: String
        public let hiddenSize: Int
        public let hiddenLayers: Int
        public let intermediateSize: Int
        public let attentionHeads: Int
        private let _rmsNormEps: Float?
        public var rmsNormEps: Float { _rmsNormEps ?? 1e-6 }
        public let vocabularySize: Int
        public let kvHeads: Int
        private let _maxPositionEmbeddings: Int?
        public var maxPositionEmbeddings: Int { _maxPositionEmbeddings ?? 32768 }
        /// Misspelled name kept for source compatibility.
        @available(*, deprecated, renamed: "maxPositionEmbeddings")
        public var maxpPositionEmbeddings: Int { maxPositionEmbeddings }
        private let _ropeTheta: Float?
        public var ropeTheta: Float { _ropeTheta ?? 1_000_000 }
        private let _ropeTraditional: Bool?
        public var ropeTraditional: Bool { _ropeTraditional ?? false }
        public let _ropeScaling: [String: StringOrNumber]?
        public var ropeScaling: [String: StringOrNumber]? {
            _ropeScaling ?? ["mrope_section": .ints([2, 1, 1])]
        }
        private let _tieWordEmbeddings: Bool?
        public var tieWordEmbeddings: Bool { _tieWordEmbeddings ?? true }

        enum CodingKeys: String, CodingKey {
            case modelType = "model_type"
            case hiddenSize = "hidden_size"
            case hiddenLayers = "num_hidden_layers"
            case intermediateSize = "intermediate_size"
            case attentionHeads = "num_attention_heads"
            case _rmsNormEps = "rms_norm_eps"
            case vocabularySize = "vocab_size"
            case kvHeads = "num_key_value_heads"
            case _maxPositionEmbeddings = "max_position_embeddings"
            case _ropeTheta = "rope_theta"
            case _ropeTraditional = "rope_traditional"
            case _ropeScaling = "rope_scaling"
            case _tieWordEmbeddings = "tie_word_embeddings"
        }
    }

    public struct BaseConfiguration: Codable, Sendable {
        public let modelType: String
        public let vocabularySize: Int
        public let imageTokenId: Int
        public let hiddenSize: Int

        enum CodingKeys: String, CodingKey {
            case modelType = "model_type"
            case vocabularySize = "vocab_size"
            case imageTokenId = "image_token_index"
            case hiddenSize = "hidden_size"
        }
    }

    public let visionConfiguration: VisionConfiguration
    public let textConfiguration: TextConfiguration
    public let baseConfiguration: BaseConfiguration

    public init(from decoder: any Swift.Decoder) throws {
        // these are overlaid in the top level
        self.visionConfiguration = try VisionConfiguration(from: decoder)
        self.textConfiguration = try TextConfiguration(from: decoder)
        self.baseConfiguration = try BaseConfiguration(from: decoder)
    }
}

/// Configuration for ``FastVLMProcessor``
public struct FastVLMPreProcessorConfiguration: Codable, Sendable {

    public struct CropSize: Codable, Sendable {
        let width: Int
        let height: Int

        var size: CGSize { .init(width: CGFloat(width), height: CGFloat(height)) }
    }

    public struct Size: Codable, Sendable {
        let shortestEdge: Int

        enum CodingKeys: String, CodingKey {
            case shortestEdge = "shortest_edge"
        }
    }

    public var imageMean: [CGFloat]
    public var imageStd: [CGFloat]
    public var size: Size
    public var cropSize: CropSize

    public var imageMeanTuple: (CGFloat, CGFloat, CGFloat) {
        (imageMean[0], imageMean[1], imageMean[2])
    }
    public var imageStdTuple: (CGFloat, CGFloat, CGFloat) {
        (imageStd[0], imageStd[1], imageStd[2])
    }

    enum CodingKeys: String, CodingKey {
        case imageMean = "image_mean"
        case imageStd = "image_std"
        case size
        case cropSize = "crop_size"
    }
}

public struct FastVLMProcessorConfiguration: Codable, Sendable {

    public enum Strategy: Codable, Sendable {
        case `default`
    }

    // Placeholder repeated once per image patch in the prompt; an empty string would insert
    // no image tokens at all. Presumably it tokenizes to `image_token_index` — confirm
    // against the tokenizer.
    public var imageToken = "<image>"
    public var numAdditionalImageTokens = 0
    public var patchSize = 64
    public var visionFeatureSelectStrategy: Strategy?
}

================================================
FILE: app/FastVLM/MediaProcessingExtensions.swift
================================================
//
// For licensing see accompanying LICENSE file.
// Copyright (C) 2025 Apple Inc. All Rights Reserved.
//

import Accelerate
import CoreImage
import MLX
import MLXLMCommon
import MLXVLM

/// Additions to MediaProcessing -- not currently present in mlx-libraries
enum MediaProcessingExtensions {

    // this function is not exported in current mlx-swift-examples -- local copy until it is exposed
    // properly
    public static func apply(_ image: CIImage, processing: UserInput.Processing?) -> CIImage {
        var image = image

        if let resize = processing?.resize {
            let scale = MediaProcessing.bestFitScale(image.extent.size, in: resize)
            image = image.transformed(by: CGAffineTransform(scaleX: scale, y: scale))
        }

        return image
    }

    public static func rectSmallerOrEqual(_ extent: CGRect, size: CGSize) -> Bool {
        return extent.width <= size.width && extent.height <= size.height
    }

    /// Returns the rect of (at most) `size` centered within `extent`.
    public static func centerCrop(_ extent: CGRect, size: CGSize) -> CGRect {
        let targetWidth = min(extent.width, size.width)
        let targetHeight = min(extent.height, size.height)

        // offset from the extent's origin so the crop is centered even when the
        // extent does not start at (0, 0); identical to before for origin-based extents
        return CGRect(
            x: extent.minX + (extent.width - targetWidth) / 2,
            y: extent.minY + (extent.height - targetHeight) / 2,
            width: targetWidth, height: targetHeight
        )
    }

    /// Center-crops `image` to `size` and translates the result to the origin.
    /// Images already no larger than `size` are returned unchanged.
    public static func centerCrop(_ image: CIImage, size: CGSize) -> CIImage {
        let extent = image.extent
        if rectSmallerOrEqual(extent, size: size) {
            return image
        }

        let crop = centerCrop(extent, size: size)
        return
            image
            .cropped(to: crop)
            .transformed(by: CGAffineTransform(translationX: -crop.minX, y: -crop.minY))
    }

    /// Scales `size` so its shorter side equals `shortestEdge`, preserving aspect ratio.
    public static func fitIn(_ size: CGSize, shortestEdge: Int) -> CGSize {
        let floatShortestEdge = CGFloat(shortestEdge)

        let (short, long) =
            size.width <= size.height ? (size.width, size.height) : (size.height, size.width)
        let newShort = floatShortestEdge
        let newLong = floatShortestEdge * long / short

        return size.width <= size.height
            ? CGSize(width: newShort, height: newLong) : CGSize(width: newLong, height: newShort)
    }

    /// Scales `size` down (never up) so its longer side is at most `longestEdge`,
    /// preserving aspect ratio.
    public static func fitIn(_ size: CGSize, longestEdge: Int) -> CGSize {
        let floatLongestEdge = CGFloat(longestEdge)

        var (newShort, newLong) =
            size.width <= size.height ? (size.width, size.height) : (size.height, size.width)

        if newLong > floatLongestEdge {
            // scale the short side using the ORIGINAL long side before overwriting it;
            // previously newLong was updated first, leaving newShort unchanged
            newShort = floatLongestEdge * newShort / newLong
            newLong = floatLongestEdge
        }

        return size.width <= size.height
            ? CGSize(width: newShort, height: newLong) : CGSize(width: newLong, height: newShort)
    }

    // version of function from https://github.com/ml-explore/mlx-swift-examples/pull/222
    public static func resampleBicubic(_ image: CIImage, to size: CGSize) -> CIImage {
        // Create a bicubic scale filter

        let yScale = size.height / image.extent.height
        let xScale = size.width / image.extent.width

        let filter = CIFilter.bicubicScaleTransform()
        filter.inputImage = image
        filter.scale = Float(yScale)
        filter.aspectRatio = Float(xScale / yScale)
        let scaledImage = filter.outputImage!

        // Create a rect with the exact dimensions we want
        let exactRect = CGRect(
            x: 0,
            y: 0,
            width: size.width,
            height: size.height
        )

        // Crop to ensure exact dimensions
        return scaledImage.cropped(to: exactRect)
    }

    static let context = CIContext()

    /// Convert the CIImage into a planar 3 channel MLXArray `[1, C, H, W]`.
    ///
    /// This physically moves the channels into a planar configuration -- this is
    /// required for feeding into the CoreML model and is faster to use
    /// dedicated functions than transforming into contiguous memory
    /// on readout.
    static public func asPlanarMLXArray(_ image: CIImage, colorSpace: CGColorSpace? = nil)
        -> MLXArray
    {
        let size = image.extent.size
        let w = Int(size.width.rounded())
        let h = Int(size.height.rounded())

        // probably not strictly necessary, but this is what happens in
        // e.g. image_processing_siglip in transformers (float32)
        let format = CIFormat.RGBAf
        let componentsPerPixel = 4
        let bytesPerComponent: Int = MemoryLayout<Float32>.size
        let bytesPerPixel = componentsPerPixel * bytesPerComponent
        let bytesPerRow = w * bytesPerPixel

        var data = Data(count: w * h * bytesPerPixel)
        var planarData = Data(count: 3 * w * h * bytesPerComponent)
        data.withUnsafeMutableBytes { ptr in
            context.render(
                image, toBitmap: ptr.baseAddress!, rowBytes: bytesPerRow, bounds: image.extent,
                format: format, colorSpace: colorSpace)
            context.clearCaches()

            let vh = vImagePixelCount(h)
            let vw = vImagePixelCount(w)

            // convert from RGBAf -> RGBf in place
            let rgbBytesPerRow = w * 3 * bytesPerComponent
            var rgbaSrc = vImage_Buffer(
                data: ptr.baseAddress!, height: vh, width: vw, rowBytes: bytesPerRow)
            var rgbDest = vImage_Buffer(
                data: ptr.baseAddress!, height: vh, width: vw, rowBytes: rgbBytesPerRow)

            vImageConvert_RGBAFFFFtoRGBFFF(&rgbaSrc, &rgbDest, vImage_Flags(kvImageNoFlags))

            // and convert to planar data in a second buffer
            planarData.withUnsafeMutableBytes { planarPtr in
                let planeBytesPerRow = w * bytesPerComponent

                var rDest = vImage_Buffer(
                    data: planarPtr.baseAddress!.advanced(by: 0 * planeBytesPerRow * h), height: vh,
                    width: vw, rowBytes: planeBytesPerRow)
                var gDest = vImage_Buffer(
                    data: planarPtr.baseAddress!.advanced(by: 1 * planeBytesPerRow * h), height: vh,
                    width: vw, rowBytes: planeBytesPerRow)
                var bDest = vImage_Buffer(
                    data: planarPtr.baseAddress!.advanced(by: 2 * planeBytesPerRow * h), height: vh,
                    width: vw, rowBytes: planeBytesPerRow)

                vImageConvert_RGBFFFtoPlanarF(
                    &rgbDest, &rDest, &gDest, &bDest, vImage_Flags(kvImageNoFlags))
            }
        }

        return MLXArray(planarData, [1, 3, h, w], type: Float32.self)
    }
}

================================================
FILE: app/FastVLM App/Assets.xcassets/AccentColor.colorset/Contents.json
================================================
{
  "colors" : [
    {
      "idiom" : "universal"
    }
  ],
  "info" : {
    "author" : "xcode",
    "version" : 1
  }
}
================================================ FILE: app/FastVLM App/Assets.xcassets/AppIcon.appiconset/Contents.json ================================================ { "images" : [ { "filename" : "FastVLM - 150 Blue - Light@2x.png", "idiom" : "universal", "platform" : "ios", "size" : "1024x1024" }, { "appearances" : [ { "appearance" : "luminosity", "value" : "dark" } ], "filename" : "FastVLM - 150 Blue - Dark@2x.png", "idiom" : "universal", "platform" : "ios", "size" : "1024x1024" }, { "appearances" : [ { "appearance" : "luminosity", "value" : "tinted" } ], "filename" : "FastVLM - 150 White - Tinted@2x.png", "idiom" : "universal", "platform" : "ios", "size" : "1024x1024" }, { "idiom" : "mac", "scale" : "1x", "size" : "16x16" }, { "idiom" : "mac", "scale" : "2x", "size" : "16x16" }, { "idiom" : "mac", "scale" : "1x", "size" : "32x32" }, { "idiom" : "mac", "scale" : "2x", "size" : "32x32" }, { "idiom" : "mac", "scale" : "1x", "size" : "128x128" }, { "idiom" : "mac", "scale" : "2x", "size" : "128x128" }, { "idiom" : "mac", "scale" : "1x", "size" : "256x256" }, { "idiom" : "mac", "scale" : "2x", "size" : "256x256" }, { "filename" : "FastVLM - MacOS - Dark@1x.png", "idiom" : "mac", "scale" : "1x", "size" : "512x512" }, { "filename" : "FastVLM - MacOS - Dark@2x.png", "idiom" : "mac", "scale" : "2x", "size" : "512x512" } ], "info" : { "author" : "xcode", "version" : 1 } } ================================================ FILE: app/FastVLM App/Assets.xcassets/Contents.json ================================================ { "info" : { "author" : "xcode", "version" : 1 } } ================================================ FILE: app/FastVLM App/ContentView.swift ================================================ // // For licensing see accompanying LICENSE file. // Copyright (C) 2025 Apple Inc. All Rights Reserved. 
//

import AVFoundation
import MLXLMCommon
import SwiftUI
import Video

// support swift 6
extension CVImageBuffer: @unchecked @retroactive Sendable {}
extension CMSampleBuffer: @unchecked @retroactive Sendable {}

// delay between frames -- controls the frame rate of the updates
let FRAME_DELAY = Duration.milliseconds(1)

/// Main screen: shows the camera feed and runs FastVLM on its frames.
///
/// In `.continuous` camera mode, frames are analyzed back-to-back (see
/// `analyzeVideoFrames`). Otherwise, `VideoFrameView` calls
/// `processSingleFrame` for an individual frame. The generated text is shown in
/// the "Response" section, and the prompt + suffix sent with every frame can be
/// picked from presets or edited.
struct ContentView: View {
    @State private var camera = CameraController()
    @State private var model = FastVLMModel()

    /// stream of frames -> VideoFrameView, see distributeVideoFrames
    @State private var framesToDisplay: AsyncStream<CVImageBuffer>?

    @State private var prompt = "Describe the image in English."
    @State private var promptSuffix = "Output should be brief, about 15 words or less."

    @State private var isShowingInfo: Bool = false

    @State private var selectedCameraType: CameraType = .continuous
    @State private var isEditingPrompt: Bool = false

    /// Placement of the info button: leading on iOS, `.navigation` elsewhere.
    var toolbarItemPlacement: ToolbarItemPlacement {
        var placement: ToolbarItemPlacement = .navigation
        #if os(iOS)
        placement = .topBarLeading
        #endif
        return placement
    }

    /// Text color of the status capsule; dark text on the yellow
    /// "processing prompt" background, white otherwise.
    var statusTextColor : Color {
        return model.evaluationState == .processingPrompt ? .black : .white
    }

    /// Background color of the status capsule for each evaluation state.
    var statusBackgroundColor : Color {
        switch model.evaluationState {
        case .idle:
            return .gray
        case .generatingResponse:
            return .green
        case .processingPrompt:
            return .yellow
        }
    }

    var body: some View {
        NavigationStack {
            Form {
                Section {
                    VStack(alignment: .leading, spacing: 10.0) {
                        Picker("Camera Type", selection: $selectedCameraType) {
                            ForEach(CameraType.allCases, id: \.self) { cameraType in
                                Text(cameraType.rawValue.capitalized).tag(cameraType)
                            }
                        }
                        // Prevent macOS from adding a text label for the picker
                        .labelsHidden()
                        .pickerStyle(.segmented)
                        .onChange(of: selectedCameraType) { _, _ in
                            // Cancel any in-flight requests when switching modes
                            model.cancel()
                        }

                        if let framesToDisplay {
                            VideoFrameView(
                                frames: framesToDisplay,
                                cameraType: selectedCameraType,
                                action: { frame in
                                    processSingleFrame(frame)
                                })
                                // Because we're using the AVCaptureSession preset
                                // `.vga640x480`, we can assume this aspect ratio
                                .aspectRatio(4/3, contentMode: .fit)
                                #if os(macOS)
                                .frame(maxWidth: 750)
                                #endif
                                .overlay(alignment: .top) {
                                    // time-to-first-token badge
                                    if !model.promptTime.isEmpty {
                                        Text("TTFT \(model.promptTime)")
                                            .font(.caption)
                                            .foregroundStyle(.white)
                                            .monospaced()
                                            .padding(.vertical, 4.0)
                                            .padding(.horizontal, 6.0)
                                            .background(alignment: .center) {
                                                RoundedRectangle(cornerRadius: 8)
                                                    .fill(Color.black.opacity(0.6))
                                            }
                                            .padding(.top)
                                    }
                                }
                                #if !os(macOS)
                                .overlay(alignment: .topTrailing) {
                                    CameraControlsView(
                                        backCamera: $camera.backCamera,
                                        device: $camera.device,
                                        devices: $camera.devices)
                                    .padding()
                                }
                                #endif
                                .overlay(alignment: .bottom) {
                                    // evaluation status capsule (continuous mode only)
                                    if selectedCameraType == .continuous {
                                        Group {
                                            if model.evaluationState == .processingPrompt {
                                                HStack {
                                                    ProgressView()
                                                        .tint(self.statusTextColor)
                                                        .controlSize(.small)

                                                    Text(model.evaluationState.rawValue)
                                                }
                                            } else if model.evaluationState == .idle {
                                                HStack(spacing: 6.0) {
                                                    Image(systemName: "clock.fill")
                                                        .font(.caption)

                                                    Text(model.evaluationState.rawValue)
                                                }
                                            }
                                            else {
                                                // I'm manually tweaking the spacing to
                                                // better match the spacing with ProgressView
                                                HStack(spacing: 6.0) {
                                                    Image(systemName: "lightbulb.fill")
                                                        .font(.caption)

                                                    Text(model.evaluationState.rawValue)
                                                }
                                            }
                                        }
                                        .foregroundStyle(self.statusTextColor)
                                        .font(.caption)
                                        .bold()
                                        .padding(.vertical, 6.0)
                                        .padding(.horizontal, 8.0)
                                        .background(self.statusBackgroundColor)
                                        .clipShape(.capsule)
                                        .padding(.bottom)
                                    }
                                }
                                #if os(macOS)
                                .frame(maxWidth: .infinity)
                                .frame(minWidth: 500)
                                .frame(minHeight: 375)
                                #endif
                        }
                    }
                }
                .listRowInsets(EdgeInsets())
                .listRowBackground(Color.clear)
                .listRowSeparator(.hidden)

                promptSections

                Section {
                    if model.output.isEmpty && model.running {
                        ProgressView()
                            .controlSize(.large)
                            .frame(maxWidth: .infinity)
                    } else {
                        ScrollView {
                            Text(model.output)
                                .foregroundStyle(isEditingPrompt ? .secondary : .primary)
                                .textSelection(.enabled)
                                #if os(macOS)
                                .font(.headline)
                                .fontWeight(.regular)
                                #endif
                        }
                        .frame(minHeight: 50.0, maxHeight: 200.0)
                    }
                } header: {
                    Text("Response")
                        #if os(macOS)
                        .font(.headline)
                        .padding(.bottom, 2.0)
                        #endif
                }

                #if os(macOS)
                Spacer()
                #endif
            }

            #if os(iOS)
            .listSectionSpacing(0)
            #elseif os(macOS)
            .padding()
            #endif
            .task {
                camera.start()
            }
            .task {
                await model.load()
            }

            #if !os(macOS)
            .onAppear {
                // Prevent the screen from dimming or sleeping due to inactivity
                UIApplication.shared.isIdleTimerDisabled = true
            }
            .onDisappear {
                // Resumes normal idle timer behavior
                UIApplication.shared.isIdleTimerDisabled = false
            }
            #endif

            // task to distribute video frames -- this will cancel
            // and restart when the view is on/off screen. It is attached to
            // the Form (inside the NavigationStack) so that it follows that
            // view's lifecycle rather than the outer NavigationStack's.
            .task {
                if Task.isCancelled {
                    return
                }

                await distributeVideoFrames()
            }

            .navigationTitle("FastVLM")
            #if os(iOS)
            .navigationBarTitleDisplayMode(.inline)
            #endif
            .toolbar {
                ToolbarItem(placement: toolbarItemPlacement) {
                    Button {
                        isShowingInfo.toggle()
                    }
                    label: {
                        Image(systemName: "info.circle")
                    }
                }

                ToolbarItem(placement: .primaryAction) {
                    if isEditingPrompt {
                        Button {
                            isEditingPrompt.toggle()
                        }
                        label: {
                            Text("Done")
                                .fontWeight(.bold)
                        }
                    }
                    else {
                        Menu {
                            Button("Describe image") {
                                prompt = "Describe the image in English."
                                promptSuffix = "Output should be brief, about 15 words or less."
                            }
                            Button("Facial expression") {
                                prompt = "What is this person's facial expression?"
                                promptSuffix = "Output only one or two words."
                            }
                            Button("Read text") {
                                prompt = "What is written in this image?"
                                promptSuffix = "Output only the text in the image."
                            }
                            #if !os(macOS)
                            Button("Customize...") {
                                isEditingPrompt.toggle()
                            }
                            #endif
                        } label: { Text("Prompts") }
                    }
                }
            }
            .sheet(isPresented: $isShowingInfo) {
                InfoView()
            }
        }
    }

    /// Read-only view of the current prompt and suffix (whitespace-trimmed);
    /// empty parts are omitted.
    var promptSummary: some View {
        Section("Prompt") {
            VStack(alignment: .leading, spacing: 4.0) {
                let trimmedPrompt = prompt.trimmingCharacters(in: .whitespacesAndNewlines)
                if !trimmedPrompt.isEmpty {
                    Text(trimmedPrompt)
                        .foregroundStyle(.secondary)
                }

                let trimmedSuffix = promptSuffix.trimmingCharacters(in: .whitespacesAndNewlines)
                if !trimmedSuffix.isEmpty {
                    Text(trimmedSuffix)
                        .font(.caption)
                        .foregroundStyle(.tertiary)
                }
            }
        }
    }

    /// Editable prompt and suffix. macOS lays the two editors out side by
    /// side; every other platform (iOS, visionOS, ...) gets one Section each.
    var promptForm: some View {
        Group {
        #if os(macOS)
            Section {
                HStack(alignment: .top) {
                    VStack(alignment: .leading) {
                        Text("Prompt")
                            .font(.headline)

                        TextEditor(text: $prompt)
                            .frame(height: 38)
                            .padding(.horizontal, 8.0)
                            .padding(.vertical, 10.0)
                            .background(Color(.textBackgroundColor))
                            .cornerRadius(10.0)
                    }

                    VStack(alignment: .leading) {
                        Text("Prompt Suffix")
                            .font(.headline)

                        TextEditor(text: $promptSuffix)
                            .frame(height: 38)
                            .padding(.horizontal, 8.0)
                            .padding(.vertical, 10.0)
                            .background(Color(.textBackgroundColor))
                            .cornerRadius(10.0)
                    }
                }
            }
            .padding(.vertical)
        #else
            Section("Prompt") {
                TextEditor(text: $prompt)
                    .frame(minHeight: 38)
            }

            Section("Prompt Suffix") {
                TextEditor(text: $promptSuffix)
                    .frame(minHeight: 38)
            }
        #endif
        }
    }

    /// The prompt UI.
    ///
    /// macOS always shows the editors. Every other platform shows a summary
    /// until "Customize..." (offered on all non-macOS platforms) toggles
    /// `isEditingPrompt`. The branch conditions deliberately mirror that
    /// `#if !os(macOS)` so the toggle has a visible effect everywhere it is
    /// offered.
    var promptSections: some View {
        Group {
        #if os(macOS)
            promptForm
        #else
            if isEditingPrompt {
                promptForm
            }
            else {
                promptSummary
            }
        #endif
        }
    }

    /// Continuously run FastVLM on incoming frames, one at a time.
    ///
    /// Each frame waits for its generation to finish, then sleeps for
    /// `FRAME_DELAY` before the next frame is taken. Returns when the stream
    /// ends or the sleep is cancelled.
    /// - Parameter frames: Frames to analyze.
    func analyzeVideoFrames(_ frames: AsyncStream<CVImageBuffer>) async {
        for await frame in frames {
            let userInput = UserInput(
                prompt: .text("\(prompt) \(promptSuffix)"),
                images: [.ciImage(CIImage(cvPixelBuffer: frame))]
            )

            // generate output for a frame and wait for generation to complete
            let t = await model.generate(userInput)
            _ = await t.result

            do {
                try await Task.sleep(for: FRAME_DELAY)
            } catch { return }
        }
    }

    /// Attach to the camera and fan frames out to the video view and, in
    /// continuous mode, to `analyzeVideoFrames`.
    ///
    /// Runs until the enclosing task is cancelled. It then clears
    /// `framesToDisplay`, detaches from the camera and finishes both
    /// downstream streams.
    func distributeVideoFrames() async {
        // attach a stream to the camera -- this code will read this
        let frames = AsyncStream<CMSampleBuffer>(bufferingPolicy: .bufferingNewest(1)) {
            camera.attach(continuation: $0)
        }

        let (framesToDisplay, framesToDisplayContinuation) = AsyncStream.makeStream(
            of: CVImageBuffer.self,
            bufferingPolicy: .bufferingNewest(1)
        )
        self.framesToDisplay = framesToDisplay

        // The analysis stream is always created. Frames are only fed into it
        // (and it is only consumed) in continuous mode -- see below.
        let (framesToAnalyze, framesToAnalyzeContinuation) = AsyncStream.makeStream(
            of: CVImageBuffer.self,
            bufferingPolicy: .bufferingNewest(1)
        )

        // set up structured tasks (important -- this means the child tasks
        // are cancelled when the parent is cancelled)
        async let distributeFrames: () = {
            for await sampleBuffer in frames {
                if let frame = sampleBuffer.imageBuffer {
                    framesToDisplayContinuation.yield(frame)
                    // Only send frames for analysis in continuous mode
                    if await selectedCameraType == .continuous {
                        framesToAnalyzeContinuation.yield(frame)
                    }
                }
            }

            // the camera stream ended: clear the video view's stream
            // and detach from the camera controller
            await MainActor.run {
                self.framesToDisplay = nil
                self.camera.detatch()
            }

            framesToDisplayContinuation.finish()
            framesToAnalyzeContinuation.finish()
        }()

        // Only analyze frames if in continuous mode.
        // NOTE: the mode is sampled once, when this task starts.
        if selectedCameraType == .continuous {
            async let analyze: () = analyzeVideoFrames(framesToAnalyze)

            await distributeFrames
            await analyze
        } else {
            await distributeFrames
        }
    }

    /// Perform FastVLM inference on a single frame.
    /// - Parameter frame: The frame to analyze.
    func processSingleFrame(_ frame: CVImageBuffer) {
        // Construct request to model
        let userInput = UserInput(
            prompt: .text("\(prompt) \(promptSuffix)"),
            images: [.ciImage(CIImage(cvPixelBuffer: frame))]
        )

        // Reset the response UI (shows the spinner) and post the request to
        // FastVLM from one task. This guarantees the reset happens before the
        // request starts, instead of relying on the scheduling order of two
        // independent tasks.
        Task { @MainActor in
            model.output = ""
            // `generate` hands back the task doing the work; nothing needs
            // to wait on it here, so discard it explicitly.
            _ = await model.generate(userInput)
        }
    }
}

#Preview {
    ContentView()
}

================================================
FILE: app/FastVLM App/FastVLM.entitlements
================================================
com.apple.developer.kernel.increased-memory-limit
com.apple.security.app-sandbox
com.apple.security.device.camera
com.apple.security.files.user-selected.read-only
com.apple.security.network.client

================================================
FILE: app/FastVLM App/FastVLMApp.swift
================================================
//
// For licensing see accompanying LICENSE file.
// Copyright (C) 2025 Apple Inc. All Rights Reserved.
//

import SwiftUI

@main
struct FastVLMApp: App {
    var body: some Scene {
        WindowGroup {
            ContentView()
        }
    }
}

================================================
FILE: app/FastVLM App/FastVLMModel.swift
================================================
//
// For licensing see accompanying LICENSE file.
// Copyright (C) 2025 Apple Inc. All Rights Reserved.
//

import CoreImage
import FastVLM
import Foundation
import MLX
import MLXLMCommon
import MLXRandom
import MLXVLM

/// Main-actor, observable model that loads FastVLM and runs generation.
///
/// Generated text is published through `output`, the time-to-first-token
/// through `promptTime`, and progress through `evaluationState`. At most one
/// generation runs at a time; see `generate(_:)`.
@Observable
@MainActor
class FastVLMModel {

    /// true while a generation request is in flight
    public var running = false

    public var modelInfo = ""
    public var output = ""
    public var promptTime: String = ""

    /// Lifecycle of the model container.
    ///
    /// `loading` holds the in-flight load so that concurrent callers share one
    /// download/initialization instead of each loading a copy of the weights.
    /// For example, `load()` from the view's `.task` and `generate(_:)` from
    /// the frame-analysis loop can both call `_load()` at startup.
    enum LoadState {
        case idle
        case loading(Task<ModelContainer, Error>)
        case loaded(ModelContainer)
    }

    private let modelConfiguration = FastVLM.modelConfiguration

    /// parameters controlling the output
    let generateParameters = GenerateParameters(temperature: 0.0)
    let maxTokens = 240

    /// update the display every N tokens -- 4 looks like it updates continuously
    /// and is low overhead.  observed ~15% reduction in tokens/s when updating
    /// on every token
    let displayEveryNTokens = 4

    private var loadState = LoadState.idle
    private var currentTask: Task<Void, Never>?

    enum EvaluationState: String, CaseIterable {
        case idle = "Idle"
        case processingPrompt = "Processing Prompt"
        case generatingResponse = "Generating Response"
    }

    public var evaluationState = EvaluationState.idle

    public init() {
        FastVLM.register(modelFactory: VLMModelFactory.shared)
    }

    /// Return the model container, loading it on first use.
    ///
    /// Concurrent callers await the same in-flight load. If loading fails,
    /// the state returns to `.idle` so a later call can retry.
    /// - Throws: whatever `VLMModelFactory.loadContainer` throws.
    private func _load() async throws -> ModelContainer {
        switch loadState {
        case .loaded(let modelContainer):
            return modelContainer

        case .loading(let loadTask):
            // another caller already started the load -- share it
            return try await loadTask.value

        case .idle:
            // limit the buffer cache
            MLX.GPU.set(cacheLimit: 20 * 1024 * 1024)

            let loadTask = Task { [modelConfiguration] in
                try await VLMModelFactory.shared.loadContainer(
                    configuration: modelConfiguration
                ) { progress in
                    Task { @MainActor in
                        self.modelInfo =
                            "Downloading \(modelConfiguration.name): \(Int(progress.fractionCompleted * 100))%"
                    }
                }
            }
            loadState = .loading(loadTask)

            do {
                let modelContainer = try await loadTask.value
                self.modelInfo = "Loaded"
                loadState = .loaded(modelContainer)
                return modelContainer
            } catch {
                // allow a later call to retry the load
                loadState = .idle
                throw error
            }
        }
    }

    /// Load the model, reporting any failure through `modelInfo`.
    public func load() async {
        do {
            _ = try await _load()
        } catch {
            self.modelInfo = "Error loading model: \(error)"
        }
    }

    /// Start generating a response for `userInput`.
    ///
    /// If a generation is already running, its task is returned and no new
    /// request is started. The returned task finishes when generation
    /// completes, fails (the error text is written to `output`) or is
    /// cancelled via `cancel()`.
    /// - Parameter userInput: Prompt and images to send to the model.
    /// - Returns: The task performing the generation.
    public func generate(_ userInput: UserInput) async -> Task<Void, Never> {
        if let currentTask, running {
            return currentTask
        }

        running = true

        // Cancel any existing task
        currentTask?.cancel()

        // Create new task and store reference
        let task = Task {
            defer {
                // A cancelled task was superseded by `cancel()`, which resets
                // the shared state itself. Resetting it here as well could
                // clobber the state of a newer request started afterwards.
                if !Task.isCancelled {
                    // Always return to idle -- including when generation
                    // failed while still processing the prompt.
                    evaluationState = .idle
                    running = false
                }
            }

            do {
                let modelContainer = try await _load()

                // each time you generate you will get something new
                MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))

                // Check if task was cancelled
                if Task.isCancelled { return }

                let result = try await modelContainer.perform { context in
                    // Measure the time it takes to prepare the input

                    Task { @MainActor in
                        evaluationState = .processingPrompt
                    }

                    let llmStart = Date()
                    let input = try await context.processor.prepare(input: userInput)

                    var seenFirstToken = false

                    // FastVLM generates the output
                    let result = try MLXLMCommon.generate(
                        input: input, parameters: generateParameters, context: context
                    ) { tokens in
                        // Check if task was cancelled
                        if Task.isCancelled {
                            return .stop
                        }

                        if !seenFirstToken {
                            seenFirstToken = true

                            // produced first token, update the time to first token,
                            // the processing state and start displaying the text
                            let llmDuration = Date().timeIntervalSince(llmStart)
                            let text = context.tokenizer.decode(tokens: tokens)
                            Task { @MainActor in
                                evaluationState = .generatingResponse
                                self.output = text
                                self.promptTime = "\(Int(llmDuration * 1000)) ms"
                            }
                        }

                        // Show the text in the view as it generates
                        if tokens.count % displayEveryNTokens == 0 {
                            let text = context.tokenizer.decode(tokens: tokens)
                            Task { @MainActor in
                                self.output = text
                            }
                        }

                        if tokens.count >= maxTokens {
                            return .stop
                        } else {
                            return .more
                        }
                    }

                    // Return the duration of the LLM and the result
                    return result
                }

                // Check if task was cancelled before updating UI
                if !Task.isCancelled {
                    self.output = result.output
                }

            } catch {
                if !Task.isCancelled {
                    output = "Failed: \(error)"
                }
            }
        }

        currentTask = task
        return task
    }

    /// Cancel any in-flight generation and reset the published state.
    public func cancel() {
        currentTask?.cancel()
        currentTask = nil
        running = false
        evaluationState = .idle
        output = ""
        promptTime = ""
    }
}

================================================
FILE: app/FastVLM App/Info.plist
================================================

================================================
FILE:
app/FastVLM App/InfoView.swift
================================================
//
// For licensing see accompanying LICENSE file.
// Copyright (C) 2025 Apple Inc. All Rights Reserved.
//

import Foundation
import SwiftUI

/// Sheet describing FastVLM and this demo app, with a citation footer.
///
/// Shown from the info button in `ContentView`. It has a close button on iOS
/// and a "Done" button on macOS, and both dismiss it.
struct InfoView: View {
    @Environment(\.dismiss) var dismiss

    // The texts below contain Markdown (`**bold**`); see the note in `body`
    // on how they are rendered.
    let paragraph1 = "**FastVLM¹** is a new family of Vision-Language models that makes use of **FastViTHD**, a hierarchical hybrid vision encoder that produces small number of high quality tokens at low latencies, resulting in significantly faster time-to-first-token (TTFT)."
    let paragraph2 = "This app showcases the **FastVLM** model in action, allowing users to freely customize the prompt. FastVLM utilizes Qwen2-Instruct LLMs without additional safety tuning, so please exercise caution when modifying the prompt."
    // Citation for the "¹" marker in paragraph1.
    let footer = "1. **FastVLM: Efficient Vision Encoding for Vision Language Models.** (CVPR 2025) Pavan Kumar Anasosalu Vasu, Fartash Faghri, Chun-Liang Li, Cem Koc, Nate True, Albert Antony, Gokul Santhanam, James Gabriel, Peter Grasch, Oncel Tuzel, Hadi Pouransari"

    var body: some View {
        NavigationStack {
            VStack(alignment: .leading, spacing: 20.0) {
                // Markdown rendering: wrapping the `String` constants in
                // `.init(...)` turns them into `LocalizedStringKey`s (per the
                // original author). As a result, their `**bold**` markup is
                // rendered rather than shown literally, while the text can
                // still live in plain `String` properties.
                Text("\(.init(paragraph1))\n\n\(.init(paragraph2))\n\n")
                    .font(.body)

                Spacer()

                Text(.init(footer))
                    .font(.caption)
                    .foregroundStyle(.secondary)
            }
            .padding()
            .frame(maxWidth: .infinity, maxHeight: .infinity, alignment: .top)
            .textSelection(.enabled)
            .navigationTitle("Information")
            #if os(iOS)
            .navigationBarTitleDisplayMode(.inline)
            #endif
            .toolbar {
                #if os(iOS)
                // iOS: icon-only close button in the leading position
                ToolbarItem(placement: .navigationBarLeading) {
                    Button {
                        dismiss()
                    } label: {
                        Image(systemName: "xmark.circle")
                            .resizable()
                            .frame(width: 25, height: 25)
                            .foregroundStyle(.secondary)
                    }
                    .buttonStyle(.plain)
                }
                #elseif os(macOS)
                // macOS: standard "Done" button in the cancellation slot
                ToolbarItem(placement: .cancellationAction) {
                    Button("Done") {
                        dismiss()
                    }
                    .buttonStyle(.bordered)
                }
                #endif
            }
        }
    }
}

#Preview {
    InfoView()
}

================================================
FILE: app/FastVLM App/Preview Content/Preview Assets.xcassets/Contents.json
================================================
{ "info" : { "author" : "xcode", "version" : 1 } }

================================================
FILE: app/FastVLM.xcodeproj/project.pbxproj
================================================
// !$*UTF8*$!
{ archiveVersion = 1; classes = { }; objectVersion = 77; objects = { /* Begin PBXBuildFile section */ 019A3E1A2D78E7370055F93B /* MLX in Frameworks */ = {isa = PBXBuildFile; productRef = 019A3E192D78E7370055F93B /* MLX */; }; 019A3E1C2D78E73E0055F93B /* MLXLMCommon in Frameworks */ = {isa = PBXBuildFile; productRef = 019A3E1B2D78E73E0055F93B /* MLXLMCommon */; }; 019A3E1E2D78E7470055F93B /* MLXRandom in Frameworks */ = {isa = PBXBuildFile; productRef = 019A3E1D2D78E7470055F93B /* MLXRandom */; }; 019A3E202D78E74C0055F93B /* MLXVLM in Frameworks */ = {isa = PBXBuildFile; productRef = 019A3E1F2D78E74C0055F93B /* MLXVLM */; }; 019A3E212D78E7530055F93B /* Video.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = C35372AE2D08C32D00474D34 /* Video.framework */; }; 019A3E222D78E7530055F93B /* Video.framework in Embed Frameworks */ = {isa = PBXBuildFile; fileRef = C35372AE2D08C32D00474D34 /* Video.framework */; settings = {ATTRIBUTES = (CodeSignOnCopy, RemoveHeadersOnCopy, ); }; }; C3ED544E2D790860005E20B3 /* MLXLMCommon in Frameworks */ = {isa = PBXBuildFile; productRef = C3ED544D2D790860005E20B3 /* MLXLMCommon */; }; C3ED54502D790860005E20B3 /* MLXVLM in Frameworks */ = {isa = PBXBuildFile; productRef = C3ED544F2D790860005E20B3 /* MLXVLM */; }; C3ED54522D790860005E20B3 /* MLX in Frameworks */ = {isa = PBXBuildFile; productRef = C3ED54512D790860005E20B3 /* MLX */; }; C3ED54542D790860005E20B3 /* MLXNN in Frameworks */ = {isa = PBXBuildFile; productRef = C3ED54532D790860005E20B3 /* MLXNN */; }; C3ED54582D790A68005E20B3 /* MLXFast in Frameworks */ = {isa = PBXBuildFile; productRef = C3ED54572D790A68005E20B3 /* MLXFast */; }; C3ED545B2D790AD6005E20B3 /* Transformers in Frameworks */ = {isa = PBXBuildFile; productRef = C3ED545A2D790AD6005E20B3 /* Transformers */; }; C3ED55012D7A0A7A005E20B3 /* FastVLM.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = C39BB3E62D79082A005DB8FB /* FastVLM.framework */; }; C3ED55022D7A0A7A005E20B3 /* FastVLM.framework in Embed 
Frameworks */ = {isa = PBXBuildFile; fileRef = C39BB3E62D79082A005DB8FB /* FastVLM.framework */; settings = {ATTRIBUTES = (CodeSignOnCopy, RemoveHeadersOnCopy, ); }; }; /* End PBXBuildFile section */ /* Begin PBXContainerItemProxy section */ 019A3E232D78E7530055F93B /* PBXContainerItemProxy */ = { isa = PBXContainerItemProxy; containerPortal = C35EDB642D07699400757E80 /* Project object */; proxyType = 1; remoteGlobalIDString = C35372AD2D08C32D00474D34; remoteInfo = Video; }; C3ED55032D7A0A7A005E20B3 /* PBXContainerItemProxy */ = { isa = PBXContainerItemProxy; containerPortal = C35EDB642D07699400757E80 /* Project object */; proxyType = 1; remoteGlobalIDString = C39BB3E52D79082A005DB8FB; remoteInfo = FastVLM; }; /* End PBXContainerItemProxy section */ /* Begin PBXCopyFilesBuildPhase section */ 019A3E252D78E7530055F93B /* Embed Frameworks */ = { isa = PBXCopyFilesBuildPhase; buildActionMask = 2147483647; dstPath = ""; dstSubfolderSpec = 10; files = ( C3ED55022D7A0A7A005E20B3 /* FastVLM.framework in Embed Frameworks */, 019A3E222D78E7530055F93B /* Video.framework in Embed Frameworks */, ); name = "Embed Frameworks"; runOnlyForDeploymentPostprocessing = 0; }; /* End PBXCopyFilesBuildPhase section */ /* Begin PBXFileReference section */ 019A3E0A2D78E6A00055F93B /* FastVLM App.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = "FastVLM App.app"; sourceTree = BUILT_PRODUCTS_DIR; }; 12FFAF3D2DC93583009C4EFA /* get_pretrained_mlx_model.sh */ = {isa = PBXFileReference; lastKnownFileType = text.script.sh; path = get_pretrained_mlx_model.sh; sourceTree = ""; }; 12FFAF3E2DC93583009C4EFA /* README.md */ = {isa = PBXFileReference; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = ""; }; C35372AE2D08C32D00474D34 /* Video.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = Video.framework; sourceTree = BUILT_PRODUCTS_DIR; }; C39BB3E62D79082A005DB8FB /* 
FastVLM.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = FastVLM.framework; sourceTree = BUILT_PRODUCTS_DIR; }; /* End PBXFileReference section */ /* Begin PBXFileSystemSynchronizedBuildFileExceptionSet section */ 120A44852D9B05A900E244A3 /* Exceptions for "FastVLM App" folder in "FastVLM App" target */ = { isa = PBXFileSystemSynchronizedBuildFileExceptionSet; membershipExceptions = ( Info.plist, ); target = 019A3E092D78E6A00055F93B /* FastVLM App */; }; C35372B92D08C32D00474D34 /* Exceptions for "Video" folder in "Video" target */ = { isa = PBXFileSystemSynchronizedBuildFileExceptionSet; publicHeaders = ( Video.h, ); target = C35372AD2D08C32D00474D34 /* Video */; }; C3ED54BB2D791BEA005E20B3 /* Exceptions for "FastVLM" folder in "FastVLM" target */ = { isa = PBXFileSystemSynchronizedBuildFileExceptionSet; publicHeaders = ( FastVLM.h, ); target = C39BB3E52D79082A005DB8FB /* FastVLM */; }; /* End PBXFileSystemSynchronizedBuildFileExceptionSet section */ /* Begin PBXFileSystemSynchronizedRootGroup section */ 019A3E0B2D78E6A00055F93B /* FastVLM App */ = { isa = PBXFileSystemSynchronizedRootGroup; exceptions = ( 120A44852D9B05A900E244A3 /* Exceptions for "FastVLM App" folder in "FastVLM App" target */, ); path = "FastVLM App"; sourceTree = ""; }; C32B4A802DA4805400EF663D /* Configuration */ = { isa = PBXFileSystemSynchronizedRootGroup; path = Configuration; sourceTree = ""; }; C35372AF2D08C32D00474D34 /* Video */ = { isa = PBXFileSystemSynchronizedRootGroup; exceptions = ( C35372B92D08C32D00474D34 /* Exceptions for "Video" folder in "Video" target */, ); path = Video; sourceTree = ""; }; C39BB3E72D79082A005DB8FB /* FastVLM */ = { isa = PBXFileSystemSynchronizedRootGroup; exceptions = ( C3ED54BB2D791BEA005E20B3 /* Exceptions for "FastVLM" folder in "FastVLM" target */, ); path = FastVLM; sourceTree = ""; }; /* End PBXFileSystemSynchronizedRootGroup section */ /* Begin PBXFrameworksBuildPhase section */ 
019A3E072D78E6A00055F93B /* Frameworks */ = { isa = PBXFrameworksBuildPhase; buildActionMask = 2147483647; files = ( 019A3E1C2D78E73E0055F93B /* MLXLMCommon in Frameworks */, 019A3E212D78E7530055F93B /* Video.framework in Frameworks */, C3ED55012D7A0A7A005E20B3 /* FastVLM.framework in Frameworks */, 019A3E1E2D78E7470055F93B /* MLXRandom in Frameworks */, 019A3E202D78E74C0055F93B /* MLXVLM in Frameworks */, 019A3E1A2D78E7370055F93B /* MLX in Frameworks */, ); runOnlyForDeploymentPostprocessing = 0; }; C35372AB2D08C32D00474D34 /* Frameworks */ = { isa = PBXFrameworksBuildPhase; buildActionMask = 2147483647; files = ( ); runOnlyForDeploymentPostprocessing = 0; }; C39BB3E32D79082A005DB8FB /* Frameworks */ = { isa = PBXFrameworksBuildPhase; buildActionMask = 2147483647; files = ( C3ED545B2D790AD6005E20B3 /* Transformers in Frameworks */, C3ED54522D790860005E20B3 /* MLX in Frameworks */, C3ED54502D790860005E20B3 /* MLXVLM in Frameworks */, C3ED54542D790860005E20B3 /* MLXNN in Frameworks */, C3ED54582D790A68005E20B3 /* MLXFast in Frameworks */, C3ED544E2D790860005E20B3 /* MLXLMCommon in Frameworks */, ); runOnlyForDeploymentPostprocessing = 0; }; /* End PBXFrameworksBuildPhase section */ /* Begin PBXGroup section */ C35EDB632D07699400757E80 = { isa = PBXGroup; children = ( 12FFAF3E2DC93583009C4EFA /* README.md */, 12FFAF3D2DC93583009C4EFA /* get_pretrained_mlx_model.sh */, C32B4A802DA4805400EF663D /* Configuration */, C35372AF2D08C32D00474D34 /* Video */, 019A3E0B2D78E6A00055F93B /* FastVLM App */, C39BB3E72D79082A005DB8FB /* FastVLM */, C35EDB7F2D076C3C00757E80 /* Frameworks */, C35EDB702D076A5A00757E80 /* Products */, ); sourceTree = ""; }; C35EDB702D076A5A00757E80 /* Products */ = { isa = PBXGroup; children = ( C35372AE2D08C32D00474D34 /* Video.framework */, 019A3E0A2D78E6A00055F93B /* FastVLM App.app */, C39BB3E62D79082A005DB8FB /* FastVLM.framework */, ); name = Products; sourceTree = ""; }; C35EDB7F2D076C3C00757E80 /* Frameworks */ = { isa = PBXGroup; children = ( 
); name = Frameworks; sourceTree = ""; }; /* End PBXGroup section */ /* Begin PBXHeadersBuildPhase section */ C35372A92D08C32D00474D34 /* Headers */ = { isa = PBXHeadersBuildPhase; buildActionMask = 2147483647; files = ( ); runOnlyForDeploymentPostprocessing = 0; }; C39BB3E12D79082A005DB8FB /* Headers */ = { isa = PBXHeadersBuildPhase; buildActionMask = 2147483647; files = ( ); runOnlyForDeploymentPostprocessing = 0; }; /* End PBXHeadersBuildPhase section */ /* Begin PBXNativeTarget section */ 019A3E092D78E6A00055F93B /* FastVLM App */ = { isa = PBXNativeTarget; buildConfigurationList = 019A3E182D78E6A20055F93B /* Build configuration list for PBXNativeTarget "FastVLM App" */; buildPhases = ( 019A3E062D78E6A00055F93B /* Sources */, 019A3E072D78E6A00055F93B /* Frameworks */, 019A3E082D78E6A00055F93B /* Resources */, 019A3E252D78E7530055F93B /* Embed Frameworks */, ); buildRules = ( ); dependencies = ( 019A3E242D78E7530055F93B /* PBXTargetDependency */, C3ED55042D7A0A7A005E20B3 /* PBXTargetDependency */, ); fileSystemSynchronizedGroups = ( 019A3E0B2D78E6A00055F93B /* FastVLM App */, ); name = "FastVLM App"; packageProductDependencies = ( 019A3E192D78E7370055F93B /* MLX */, 019A3E1B2D78E73E0055F93B /* MLXLMCommon */, 019A3E1D2D78E7470055F93B /* MLXRandom */, 019A3E1F2D78E74C0055F93B /* MLXVLM */, ); productName = FastVLMCameraExample; productReference = 019A3E0A2D78E6A00055F93B /* FastVLM App.app */; productType = "com.apple.product-type.application"; }; C35372AD2D08C32D00474D34 /* Video */ = { isa = PBXNativeTarget; buildConfigurationList = C35372BA2D08C32D00474D34 /* Build configuration list for PBXNativeTarget "Video" */; buildPhases = ( C35372A92D08C32D00474D34 /* Headers */, C35372AA2D08C32D00474D34 /* Sources */, C35372AB2D08C32D00474D34 /* Frameworks */, C35372AC2D08C32D00474D34 /* Resources */, ); buildRules = ( ); dependencies = ( ); fileSystemSynchronizedGroups = ( C35372AF2D08C32D00474D34 /* Video */, ); name = Video; packageProductDependencies = ( ); 
productName = Video; productReference = C35372AE2D08C32D00474D34 /* Video.framework */; productType = "com.apple.product-type.framework"; }; C39BB3E52D79082A005DB8FB /* FastVLM */ = { isa = PBXNativeTarget; buildConfigurationList = C39BB3FF2D79082A005DB8FB /* Build configuration list for PBXNativeTarget "FastVLM" */; buildPhases = ( C39BB3E12D79082A005DB8FB /* Headers */, C39BB3E22D79082A005DB8FB /* Sources */, C39BB3E32D79082A005DB8FB /* Frameworks */, C39BB3E42D79082A005DB8FB /* Resources */, ); buildRules = ( ); dependencies = ( ); fileSystemSynchronizedGroups = ( C39BB3E72D79082A005DB8FB /* FastVLM */, ); name = FastVLM; packageProductDependencies = ( C3ED544D2D790860005E20B3 /* MLXLMCommon */, C3ED544F2D790860005E20B3 /* MLXVLM */, C3ED54512D790860005E20B3 /* MLX */, C3ED54532D790860005E20B3 /* MLXNN */, C3ED54572D790A68005E20B3 /* MLXFast */, C3ED545A2D790AD6005E20B3 /* Transformers */, ); productName = FastVLM; productReference = C39BB3E62D79082A005DB8FB /* FastVLM.framework */; productType = "com.apple.product-type.framework"; }; /* End PBXNativeTarget section */ /* Begin PBXProject section */ C35EDB642D07699400757E80 /* Project object */ = { isa = PBXProject; attributes = { BuildIndependentTargetsInParallel = 1; LastSwiftUpdateCheck = 1620; LastUpgradeCheck = 1630; TargetAttributes = { 019A3E092D78E6A00055F93B = { CreatedOnToolsVersion = 16.2; }; C35372AD2D08C32D00474D34 = { CreatedOnToolsVersion = 16.0; }; C39BB3E52D79082A005DB8FB = { CreatedOnToolsVersion = 16.0; LastSwiftMigration = 1600; }; }; }; buildConfigurationList = C35EDB672D07699400757E80 /* Build configuration list for PBXProject "FastVLM" */; developmentRegion = en; hasScannedForEncodings = 0; knownRegions = ( en, Base, ); mainGroup = C35EDB632D07699400757E80; minimizedProjectReferenceProxies = 1; packageReferences = ( C35EDB6A2D076A3900757E80 /* XCRemoteSwiftPackageReference "mlx-swift-examples" */, C35EDB8A2D07777E00757E80 /* XCRemoteSwiftPackageReference "mlx-swift" */, 
C3ED54592D790AC6005E20B3 /* XCRemoteSwiftPackageReference "swift-transformers" */, ); preferredProjectObjectVersion = 77; productRefGroup = C35EDB702D076A5A00757E80 /* Products */; projectDirPath = ""; projectRoot = ""; targets = ( C35372AD2D08C32D00474D34 /* Video */, 019A3E092D78E6A00055F93B /* FastVLM App */, C39BB3E52D79082A005DB8FB /* FastVLM */, ); }; /* End PBXProject section */ /* Begin PBXResourcesBuildPhase section */ 019A3E082D78E6A00055F93B /* Resources */ = { isa = PBXResourcesBuildPhase; buildActionMask = 2147483647; files = ( ); runOnlyForDeploymentPostprocessing = 0; }; C35372AC2D08C32D00474D34 /* Resources */ = { isa = PBXResourcesBuildPhase; buildActionMask = 2147483647; files = ( ); runOnlyForDeploymentPostprocessing = 0; }; C39BB3E42D79082A005DB8FB /* Resources */ = { isa = PBXResourcesBuildPhase; buildActionMask = 2147483647; files = ( ); runOnlyForDeploymentPostprocessing = 0; }; /* End PBXResourcesBuildPhase section */ /* Begin PBXSourcesBuildPhase section */ 019A3E062D78E6A00055F93B /* Sources */ = { isa = PBXSourcesBuildPhase; buildActionMask = 2147483647; files = ( ); runOnlyForDeploymentPostprocessing = 0; }; C35372AA2D08C32D00474D34 /* Sources */ = { isa = PBXSourcesBuildPhase; buildActionMask = 2147483647; files = ( ); runOnlyForDeploymentPostprocessing = 0; }; C39BB3E22D79082A005DB8FB /* Sources */ = { isa = PBXSourcesBuildPhase; buildActionMask = 2147483647; files = ( ); runOnlyForDeploymentPostprocessing = 0; }; /* End PBXSourcesBuildPhase section */ /* Begin PBXTargetDependency section */ 019A3E242D78E7530055F93B /* PBXTargetDependency */ = { isa = PBXTargetDependency; target = C35372AD2D08C32D00474D34 /* Video */; targetProxy = 019A3E232D78E7530055F93B /* PBXContainerItemProxy */; }; C3ED55042D7A0A7A005E20B3 /* PBXTargetDependency */ = { isa = PBXTargetDependency; target = C39BB3E52D79082A005DB8FB /* FastVLM */; targetProxy = C3ED55032D7A0A7A005E20B3 /* PBXContainerItemProxy */; }; /* End PBXTargetDependency section */ /* Begin 
XCBuildConfiguration section */ 019A3E162D78E6A20055F93B /* Debug */ = { isa = XCBuildConfiguration; buildSettings = { ALWAYS_SEARCH_USER_PATHS = NO; ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES; ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor; CLANG_ANALYZER_NONNULL = YES; CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; CLANG_ENABLE_MODULES = YES; CLANG_ENABLE_OBJC_ARC = YES; CLANG_ENABLE_OBJC_WEAK = YES; CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; CLANG_WARN_BOOL_CONVERSION = YES; CLANG_WARN_COMMA = YES; CLANG_WARN_CONSTANT_CONVERSION = YES; CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; CLANG_WARN_DOCUMENTATION_COMMENTS = YES; CLANG_WARN_EMPTY_BODY = YES; CLANG_WARN_ENUM_CONVERSION = YES; CLANG_WARN_INFINITE_RECURSION = YES; CLANG_WARN_INT_CONVERSION = YES; CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; CLANG_WARN_STRICT_PROTOTYPES = YES; CLANG_WARN_SUSPICIOUS_MOVE = YES; CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; CLANG_WARN_UNREACHABLE_CODE = YES; CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; CODE_SIGN_ENTITLEMENTS = "FastVLM App/FastVLM.entitlements"; CODE_SIGN_STYLE = Automatic; COPY_PHASE_STRIP = NO; CURRENT_PROJECT_VERSION = 0.1.0; DEAD_CODE_STRIPPING = YES; DEBUG_INFORMATION_FORMAT = dwarf; DEVELOPMENT_ASSET_PATHS = "\"FastVLM App/Preview Content\""; DEVELOPMENT_TEAM = ""; ENABLE_HARDENED_RUNTIME = YES; ENABLE_PREVIEWS = YES; ENABLE_STRICT_OBJC_MSGSEND = YES; ENABLE_TESTABILITY = YES; ENABLE_USER_SCRIPT_SANDBOXING = YES; GCC_C_LANGUAGE_STANDARD = gnu17; GCC_DYNAMIC_NO_PIC = NO; GCC_NO_COMMON_BLOCKS = YES; GCC_OPTIMIZATION_LEVEL = 0; 
GCC_PREPROCESSOR_DEFINITIONS = ( "DEBUG=1", "$(inherited)", ); GCC_WARN_64_TO_32_BIT_CONVERSION = YES; GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; GCC_WARN_UNDECLARED_SELECTOR = YES; GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; GCC_WARN_UNUSED_FUNCTION = YES; GCC_WARN_UNUSED_VARIABLE = YES; GENERATE_INFOPLIST_FILE = YES; INFOPLIST_FILE = "FastVLM App/Info.plist"; INFOPLIST_KEY_CFBundleDisplayName = FastVLM; INFOPLIST_KEY_NSCameraUsageDescription = "Use camera to get live feed of images"; "INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphoneos*]" = YES; "INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphonesimulator*]" = YES; "INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents[sdk=iphoneos*]" = YES; "INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents[sdk=iphonesimulator*]" = YES; "INFOPLIST_KEY_UILaunchScreen_Generation[sdk=iphoneos*]" = YES; "INFOPLIST_KEY_UILaunchScreen_Generation[sdk=iphonesimulator*]" = YES; INFOPLIST_KEY_UIRequiresFullScreen = YES; "INFOPLIST_KEY_UIStatusBarStyle[sdk=iphoneos*]" = UIStatusBarStyleDefault; "INFOPLIST_KEY_UIStatusBarStyle[sdk=iphonesimulator*]" = UIStatusBarStyleDefault; INFOPLIST_KEY_UISupportedInterfaceOrientations = UIInterfaceOrientationPortrait; IPHONEOS_DEPLOYMENT_TARGET = 18.2; LD_RUNPATH_SEARCH_PATHS = "@executable_path/Frameworks"; "LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks"; LOCALIZATION_PREFERS_STRING_CATALOGS = YES; MACOSX_DEPLOYMENT_TARGET = 15.2; MARKETING_VERSION = 1.0; MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; MTL_FAST_MATH = YES; ONLY_ACTIVE_ARCH = YES; PRODUCT_BUNDLE_IDENTIFIER = com.apple.ml.FastVLM; PRODUCT_NAME = "$(TARGET_NAME)"; SDKROOT = auto; SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx xros xrsimulator"; SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)"; SWIFT_EMIT_LOC_STRINGS = YES; SWIFT_OPTIMIZATION_LEVEL = "-Onone"; SWIFT_VERSION = 5.0; TARGETED_DEVICE_FAMILY = "1,2,7"; XROS_DEPLOYMENT_TARGET = 2.2; }; name = Debug; }; 
019A3E172D78E6A20055F93B /* Release */ = { isa = XCBuildConfiguration; buildSettings = { ALWAYS_SEARCH_USER_PATHS = NO; ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES; ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor; CLANG_ANALYZER_NONNULL = YES; CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; CLANG_ENABLE_MODULES = YES; CLANG_ENABLE_OBJC_ARC = YES; CLANG_ENABLE_OBJC_WEAK = YES; CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; CLANG_WARN_BOOL_CONVERSION = YES; CLANG_WARN_COMMA = YES; CLANG_WARN_CONSTANT_CONVERSION = YES; CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; CLANG_WARN_DOCUMENTATION_COMMENTS = YES; CLANG_WARN_EMPTY_BODY = YES; CLANG_WARN_ENUM_CONVERSION = YES; CLANG_WARN_INFINITE_RECURSION = YES; CLANG_WARN_INT_CONVERSION = YES; CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; CLANG_WARN_STRICT_PROTOTYPES = YES; CLANG_WARN_SUSPICIOUS_MOVE = YES; CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; CLANG_WARN_UNREACHABLE_CODE = YES; CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; CODE_SIGN_ENTITLEMENTS = "FastVLM App/FastVLM.entitlements"; CODE_SIGN_STYLE = Automatic; COPY_PHASE_STRIP = NO; CURRENT_PROJECT_VERSION = 0.1.0; DEAD_CODE_STRIPPING = YES; DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; DEVELOPMENT_ASSET_PATHS = "\"FastVLM App/Preview Content\""; DEVELOPMENT_TEAM = ""; ENABLE_HARDENED_RUNTIME = YES; ENABLE_NS_ASSERTIONS = NO; ENABLE_PREVIEWS = YES; ENABLE_STRICT_OBJC_MSGSEND = YES; ENABLE_USER_SCRIPT_SANDBOXING = YES; GCC_C_LANGUAGE_STANDARD = gnu17; GCC_NO_COMMON_BLOCKS = YES; GCC_WARN_64_TO_32_BIT_CONVERSION = YES; GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; 
GCC_WARN_UNDECLARED_SELECTOR = YES; GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; GCC_WARN_UNUSED_FUNCTION = YES; GCC_WARN_UNUSED_VARIABLE = YES; GENERATE_INFOPLIST_FILE = YES; INFOPLIST_FILE = "FastVLM App/Info.plist"; INFOPLIST_KEY_CFBundleDisplayName = FastVLM; INFOPLIST_KEY_NSCameraUsageDescription = "Use camera to get live feed of images"; "INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphoneos*]" = YES; "INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphonesimulator*]" = YES; "INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents[sdk=iphoneos*]" = YES; "INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents[sdk=iphonesimulator*]" = YES; "INFOPLIST_KEY_UILaunchScreen_Generation[sdk=iphoneos*]" = YES; "INFOPLIST_KEY_UILaunchScreen_Generation[sdk=iphonesimulator*]" = YES; INFOPLIST_KEY_UIRequiresFullScreen = YES; "INFOPLIST_KEY_UIStatusBarStyle[sdk=iphoneos*]" = UIStatusBarStyleDefault; "INFOPLIST_KEY_UIStatusBarStyle[sdk=iphonesimulator*]" = UIStatusBarStyleDefault; INFOPLIST_KEY_UISupportedInterfaceOrientations = UIInterfaceOrientationPortrait; IPHONEOS_DEPLOYMENT_TARGET = 18.2; LD_RUNPATH_SEARCH_PATHS = "@executable_path/Frameworks"; "LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks"; LOCALIZATION_PREFERS_STRING_CATALOGS = YES; MACOSX_DEPLOYMENT_TARGET = 15.2; MARKETING_VERSION = 1.0; MTL_ENABLE_DEBUG_INFO = NO; MTL_FAST_MATH = YES; PRODUCT_BUNDLE_IDENTIFIER = com.apple.ml.FastVLM; PRODUCT_NAME = "$(TARGET_NAME)"; SDKROOT = auto; SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx xros xrsimulator"; SWIFT_COMPILATION_MODE = wholemodule; SWIFT_EMIT_LOC_STRINGS = YES; SWIFT_VERSION = 5.0; TARGETED_DEVICE_FAMILY = "1,2,7"; XROS_DEPLOYMENT_TARGET = 2.2; }; name = Release; }; C35372B72D08C32D00474D34 /* Debug */ = { isa = XCBuildConfiguration; buildSettings = { ALLOW_TARGET_PLATFORM_SPECIALIZATION = YES; ALWAYS_SEARCH_USER_PATHS = NO; ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES; 
BUILD_LIBRARY_FOR_DISTRIBUTION = YES; CLANG_ANALYZER_NONNULL = YES; CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; CLANG_ENABLE_MODULES = YES; CLANG_ENABLE_OBJC_ARC = YES; CLANG_ENABLE_OBJC_WEAK = YES; CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; CLANG_WARN_BOOL_CONVERSION = YES; CLANG_WARN_COMMA = YES; CLANG_WARN_CONSTANT_CONVERSION = YES; CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; CLANG_WARN_DOCUMENTATION_COMMENTS = YES; CLANG_WARN_EMPTY_BODY = YES; CLANG_WARN_ENUM_CONVERSION = YES; CLANG_WARN_INFINITE_RECURSION = YES; CLANG_WARN_INT_CONVERSION = YES; CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; CLANG_WARN_STRICT_PROTOTYPES = YES; CLANG_WARN_SUSPICIOUS_MOVE = YES; CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; CLANG_WARN_UNREACHABLE_CODE = YES; CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; CODE_SIGN_IDENTITY = ""; CODE_SIGN_STYLE = Automatic; COPY_PHASE_STRIP = NO; CURRENT_PROJECT_VERSION = 1; DEAD_CODE_STRIPPING = YES; DEBUG_INFORMATION_FORMAT = dwarf; DEFINES_MODULE = YES; DYLIB_COMPATIBILITY_VERSION = 1; DYLIB_CURRENT_VERSION = 1; DYLIB_INSTALL_NAME_BASE = "@rpath"; ENABLE_MODULE_VERIFIER = YES; ENABLE_STRICT_OBJC_MSGSEND = YES; ENABLE_TESTABILITY = YES; ENABLE_USER_SCRIPT_SANDBOXING = YES; GCC_C_LANGUAGE_STANDARD = gnu17; GCC_DYNAMIC_NO_PIC = NO; GCC_NO_COMMON_BLOCKS = YES; GCC_OPTIMIZATION_LEVEL = 0; GCC_PREPROCESSOR_DEFINITIONS = ( "DEBUG=1", "$(inherited)", ); GCC_WARN_64_TO_32_BIT_CONVERSION = YES; GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; GCC_WARN_UNDECLARED_SELECTOR = YES; GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; GCC_WARN_UNUSED_FUNCTION = YES; GCC_WARN_UNUSED_VARIABLE = YES; GENERATE_INFOPLIST_FILE = YES; 
INFOPLIST_KEY_NSHumanReadableCopyright = ""; INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks"; IPHONEOS_DEPLOYMENT_TARGET = 18.0; LD_RUNPATH_SEARCH_PATHS = ( "@executable_path/Frameworks", "@loader_path/Frameworks", ); "LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = ( "@executable_path/../Frameworks", "@loader_path/Frameworks", ); LOCALIZATION_PREFERS_STRING_CATALOGS = YES; MACOSX_DEPLOYMENT_TARGET = 15.0; MARKETING_VERSION = 1.0; MODULE_VERIFIER_SUPPORTED_LANGUAGES = "objective-c objective-c++"; MODULE_VERIFIER_SUPPORTED_LANGUAGE_STANDARDS = "gnu17 gnu++20"; MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; MTL_FAST_MATH = YES; ONLY_ACTIVE_ARCH = YES; PRODUCT_BUNDLE_IDENTIFIER = mlx.Video; PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)"; SDKROOT = auto; SKIP_INSTALL = YES; SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx xros xrsimulator"; SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)"; SWIFT_EMIT_LOC_STRINGS = YES; SWIFT_INSTALL_OBJC_HEADER = NO; SWIFT_OPTIMIZATION_LEVEL = "-Onone"; SWIFT_VERSION = 5.0; TARGETED_DEVICE_FAMILY = "1,2,7"; VERSIONING_SYSTEM = "apple-generic"; VERSION_INFO_PREFIX = ""; XROS_DEPLOYMENT_TARGET = 2.0; }; name = Debug; }; C35372B82D08C32D00474D34 /* Release */ = { isa = XCBuildConfiguration; buildSettings = { ALLOW_TARGET_PLATFORM_SPECIALIZATION = YES; ALWAYS_SEARCH_USER_PATHS = NO; ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES; BUILD_LIBRARY_FOR_DISTRIBUTION = YES; CLANG_ANALYZER_NONNULL = YES; CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; CLANG_ENABLE_MODULES = YES; CLANG_ENABLE_OBJC_ARC = YES; CLANG_ENABLE_OBJC_WEAK = YES; CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; CLANG_WARN_BOOL_CONVERSION = YES; CLANG_WARN_COMMA = YES; CLANG_WARN_CONSTANT_CONVERSION = YES; CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; CLANG_WARN_DOCUMENTATION_COMMENTS = YES; CLANG_WARN_EMPTY_BODY = YES; CLANG_WARN_ENUM_CONVERSION = YES; 
CLANG_WARN_INFINITE_RECURSION = YES; CLANG_WARN_INT_CONVERSION = YES; CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; CLANG_WARN_STRICT_PROTOTYPES = YES; CLANG_WARN_SUSPICIOUS_MOVE = YES; CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; CLANG_WARN_UNREACHABLE_CODE = YES; CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; CODE_SIGN_IDENTITY = ""; CODE_SIGN_STYLE = Automatic; COPY_PHASE_STRIP = NO; CURRENT_PROJECT_VERSION = 1; DEAD_CODE_STRIPPING = YES; DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; DEFINES_MODULE = YES; DYLIB_COMPATIBILITY_VERSION = 1; DYLIB_CURRENT_VERSION = 1; DYLIB_INSTALL_NAME_BASE = "@rpath"; ENABLE_MODULE_VERIFIER = YES; ENABLE_NS_ASSERTIONS = NO; ENABLE_STRICT_OBJC_MSGSEND = YES; ENABLE_USER_SCRIPT_SANDBOXING = YES; GCC_C_LANGUAGE_STANDARD = gnu17; GCC_NO_COMMON_BLOCKS = YES; GCC_WARN_64_TO_32_BIT_CONVERSION = YES; GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; GCC_WARN_UNDECLARED_SELECTOR = YES; GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; GCC_WARN_UNUSED_FUNCTION = YES; GCC_WARN_UNUSED_VARIABLE = YES; GENERATE_INFOPLIST_FILE = YES; INFOPLIST_KEY_NSHumanReadableCopyright = ""; INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks"; IPHONEOS_DEPLOYMENT_TARGET = 18.0; LD_RUNPATH_SEARCH_PATHS = ( "@executable_path/Frameworks", "@loader_path/Frameworks", ); "LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = ( "@executable_path/../Frameworks", "@loader_path/Frameworks", ); LOCALIZATION_PREFERS_STRING_CATALOGS = YES; MACOSX_DEPLOYMENT_TARGET = 15.0; MARKETING_VERSION = 1.0; MODULE_VERIFIER_SUPPORTED_LANGUAGES = "objective-c objective-c++"; MODULE_VERIFIER_SUPPORTED_LANGUAGE_STANDARDS = "gnu17 gnu++20"; MTL_ENABLE_DEBUG_INFO = NO; MTL_FAST_MATH = YES; PRODUCT_BUNDLE_IDENTIFIER = mlx.Video; PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)"; SDKROOT = auto; 
SKIP_INSTALL = YES; SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx xros xrsimulator"; SWIFT_COMPILATION_MODE = wholemodule; SWIFT_EMIT_LOC_STRINGS = YES; SWIFT_INSTALL_OBJC_HEADER = NO; SWIFT_VERSION = 5.0; TARGETED_DEVICE_FAMILY = "1,2,7"; VERSIONING_SYSTEM = "apple-generic"; VERSION_INFO_PREFIX = ""; XROS_DEPLOYMENT_TARGET = 2.0; }; name = Release; }; C35EDB682D07699400757E80 /* Debug */ = { isa = XCBuildConfiguration; baseConfigurationReferenceAnchor = C32B4A802DA4805400EF663D /* Configuration */; baseConfigurationReferenceRelativePath = Build.xcconfig; buildSettings = { ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES; CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; CLANG_WARN_BOOL_CONVERSION = YES; CLANG_WARN_COMMA = YES; CLANG_WARN_CONSTANT_CONVERSION = YES; CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; CLANG_WARN_EMPTY_BODY = YES; CLANG_WARN_ENUM_CONVERSION = YES; CLANG_WARN_INFINITE_RECURSION = YES; CLANG_WARN_INT_CONVERSION = YES; CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; CLANG_WARN_STRICT_PROTOTYPES = YES; CLANG_WARN_SUSPICIOUS_MOVE = YES; CLANG_WARN_UNREACHABLE_CODE = YES; CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; DEAD_CODE_STRIPPING = YES; DEVELOPMENT_TEAM = 565ARCVNXV; ENABLE_STRICT_OBJC_MSGSEND = YES; ENABLE_TESTABILITY = YES; GCC_NO_COMMON_BLOCKS = YES; GCC_WARN_64_TO_32_BIT_CONVERSION = YES; GCC_WARN_ABOUT_RETURN_TYPE = YES; GCC_WARN_UNDECLARED_SELECTOR = YES; GCC_WARN_UNINITIALIZED_AUTOS = YES; GCC_WARN_UNUSED_FUNCTION = YES; GCC_WARN_UNUSED_VARIABLE = YES; ONLY_ACTIVE_ARCH = YES; }; name = Debug; }; C35EDB692D07699400757E80 /* Release */ = { isa = XCBuildConfiguration; buildSettings = { ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES; CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; CLANG_WARN_BOOL_CONVERSION = YES; 
CLANG_WARN_COMMA = YES; CLANG_WARN_CONSTANT_CONVERSION = YES; CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; CLANG_WARN_EMPTY_BODY = YES; CLANG_WARN_ENUM_CONVERSION = YES; CLANG_WARN_INFINITE_RECURSION = YES; CLANG_WARN_INT_CONVERSION = YES; CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; CLANG_WARN_STRICT_PROTOTYPES = YES; CLANG_WARN_SUSPICIOUS_MOVE = YES; CLANG_WARN_UNREACHABLE_CODE = YES; CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; DEAD_CODE_STRIPPING = YES; DEVELOPMENT_TEAM = 565ARCVNXV; ENABLE_STRICT_OBJC_MSGSEND = YES; GCC_NO_COMMON_BLOCKS = YES; GCC_WARN_64_TO_32_BIT_CONVERSION = YES; GCC_WARN_ABOUT_RETURN_TYPE = YES; GCC_WARN_UNDECLARED_SELECTOR = YES; GCC_WARN_UNINITIALIZED_AUTOS = YES; GCC_WARN_UNUSED_FUNCTION = YES; GCC_WARN_UNUSED_VARIABLE = YES; }; name = Release; }; C39BB3FB2D79082A005DB8FB /* Debug */ = { isa = XCBuildConfiguration; buildSettings = { ALLOW_TARGET_PLATFORM_SPECIALIZATION = YES; ALWAYS_SEARCH_USER_PATHS = NO; ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES; BUILD_LIBRARY_FOR_DISTRIBUTION = NO; CLANG_ANALYZER_NONNULL = YES; CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; CLANG_ENABLE_MODULES = YES; CLANG_ENABLE_OBJC_ARC = YES; CLANG_ENABLE_OBJC_WEAK = YES; CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; CLANG_WARN_BOOL_CONVERSION = YES; CLANG_WARN_COMMA = YES; CLANG_WARN_CONSTANT_CONVERSION = YES; CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; CLANG_WARN_DOCUMENTATION_COMMENTS = YES; CLANG_WARN_EMPTY_BODY = YES; CLANG_WARN_ENUM_CONVERSION = YES; CLANG_WARN_INFINITE_RECURSION = YES; CLANG_WARN_INT_CONVERSION = YES; CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; 
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; CLANG_WARN_STRICT_PROTOTYPES = YES; CLANG_WARN_SUSPICIOUS_MOVE = YES; CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; CLANG_WARN_UNREACHABLE_CODE = YES; CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; CODE_SIGN_IDENTITY = ""; CODE_SIGN_STYLE = Automatic; COPY_PHASE_STRIP = NO; CURRENT_PROJECT_VERSION = 1; DEAD_CODE_STRIPPING = YES; DEBUG_INFORMATION_FORMAT = dwarf; DEFINES_MODULE = NO; DYLIB_COMPATIBILITY_VERSION = 1; DYLIB_CURRENT_VERSION = 1; DYLIB_INSTALL_NAME_BASE = "@rpath"; ENABLE_MODULE_VERIFIER = YES; ENABLE_STRICT_OBJC_MSGSEND = YES; ENABLE_TESTABILITY = YES; ENABLE_USER_SCRIPT_SANDBOXING = YES; GCC_C_LANGUAGE_STANDARD = gnu17; GCC_DYNAMIC_NO_PIC = NO; GCC_NO_COMMON_BLOCKS = YES; GCC_OPTIMIZATION_LEVEL = 0; GCC_PREPROCESSOR_DEFINITIONS = ( "DEBUG=1", "$(inherited)", ); GCC_WARN_64_TO_32_BIT_CONVERSION = YES; GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; GCC_WARN_UNDECLARED_SELECTOR = YES; GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; GCC_WARN_UNUSED_FUNCTION = YES; GCC_WARN_UNUSED_VARIABLE = YES; GENERATE_INFOPLIST_FILE = YES; INFOPLIST_KEY_NSHumanReadableCopyright = ""; INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks"; IPHONEOS_DEPLOYMENT_TARGET = 18.0; LD_RUNPATH_SEARCH_PATHS = ( "@executable_path/Frameworks", "@loader_path/Frameworks", ); "LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = ( "@executable_path/../Frameworks", "@loader_path/Frameworks", ); LOCALIZATION_PREFERS_STRING_CATALOGS = YES; MACOSX_DEPLOYMENT_TARGET = 15.0; MARKETING_VERSION = 1.0; MODULE_VERIFIER_SUPPORTED_LANGUAGES = "objective-c objective-c++"; MODULE_VERIFIER_SUPPORTED_LANGUAGE_STANDARDS = "gnu17 gnu++20"; MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; MTL_FAST_MATH = YES; ONLY_ACTIVE_ARCH = YES; PRODUCT_BUNDLE_IDENTIFIER = mlx.FastVLM; PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)"; SDKROOT = auto; SKIP_INSTALL = YES; SUPPORTED_PLATFORMS = "iphoneos iphonesimulator 
macosx xros xrsimulator"; SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)"; SWIFT_EMIT_LOC_STRINGS = YES; SWIFT_INSTALL_OBJC_HEADER = NO; SWIFT_OPTIMIZATION_LEVEL = "-Onone"; SWIFT_VERSION = 5.0; TARGETED_DEVICE_FAMILY = "1,2,7"; VERSIONING_SYSTEM = "apple-generic"; VERSION_INFO_PREFIX = ""; XROS_DEPLOYMENT_TARGET = 2.0; }; name = Debug; }; C39BB3FC2D79082A005DB8FB /* Release */ = { isa = XCBuildConfiguration; buildSettings = { ALLOW_TARGET_PLATFORM_SPECIALIZATION = YES; ALWAYS_SEARCH_USER_PATHS = NO; ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES; BUILD_LIBRARY_FOR_DISTRIBUTION = NO; CLANG_ANALYZER_NONNULL = YES; CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; CLANG_ENABLE_MODULES = YES; CLANG_ENABLE_OBJC_ARC = YES; CLANG_ENABLE_OBJC_WEAK = YES; CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; CLANG_WARN_BOOL_CONVERSION = YES; CLANG_WARN_COMMA = YES; CLANG_WARN_CONSTANT_CONVERSION = YES; CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; CLANG_WARN_DOCUMENTATION_COMMENTS = YES; CLANG_WARN_EMPTY_BODY = YES; CLANG_WARN_ENUM_CONVERSION = YES; CLANG_WARN_INFINITE_RECURSION = YES; CLANG_WARN_INT_CONVERSION = YES; CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; CLANG_WARN_STRICT_PROTOTYPES = YES; CLANG_WARN_SUSPICIOUS_MOVE = YES; CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; CLANG_WARN_UNREACHABLE_CODE = YES; CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; CODE_SIGN_IDENTITY = ""; CODE_SIGN_STYLE = Automatic; COPY_PHASE_STRIP = NO; CURRENT_PROJECT_VERSION = 1; DEAD_CODE_STRIPPING = YES; DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; DEFINES_MODULE = NO; DYLIB_COMPATIBILITY_VERSION = 1; DYLIB_CURRENT_VERSION = 1; DYLIB_INSTALL_NAME_BASE = 
"@rpath"; ENABLE_MODULE_VERIFIER = YES; ENABLE_NS_ASSERTIONS = NO; ENABLE_STRICT_OBJC_MSGSEND = YES; ENABLE_USER_SCRIPT_SANDBOXING = YES; GCC_C_LANGUAGE_STANDARD = gnu17; GCC_NO_COMMON_BLOCKS = YES; GCC_WARN_64_TO_32_BIT_CONVERSION = YES; GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; GCC_WARN_UNDECLARED_SELECTOR = YES; GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; GCC_WARN_UNUSED_FUNCTION = YES; GCC_WARN_UNUSED_VARIABLE = YES; GENERATE_INFOPLIST_FILE = YES; INFOPLIST_KEY_NSHumanReadableCopyright = ""; INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks"; IPHONEOS_DEPLOYMENT_TARGET = 18.0; LD_RUNPATH_SEARCH_PATHS = ( "@executable_path/Frameworks", "@loader_path/Frameworks", ); "LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = ( "@executable_path/../Frameworks", "@loader_path/Frameworks", ); LOCALIZATION_PREFERS_STRING_CATALOGS = YES; MACOSX_DEPLOYMENT_TARGET = 15.0; MARKETING_VERSION = 1.0; MODULE_VERIFIER_SUPPORTED_LANGUAGES = "objective-c objective-c++"; MODULE_VERIFIER_SUPPORTED_LANGUAGE_STANDARDS = "gnu17 gnu++20"; MTL_ENABLE_DEBUG_INFO = NO; MTL_FAST_MATH = YES; PRODUCT_BUNDLE_IDENTIFIER = mlx.FastVLM; PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)"; SDKROOT = auto; SKIP_INSTALL = YES; SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx xros xrsimulator"; SWIFT_COMPILATION_MODE = wholemodule; SWIFT_EMIT_LOC_STRINGS = YES; SWIFT_INSTALL_OBJC_HEADER = NO; SWIFT_VERSION = 5.0; TARGETED_DEVICE_FAMILY = "1,2,7"; VERSIONING_SYSTEM = "apple-generic"; VERSION_INFO_PREFIX = ""; XROS_DEPLOYMENT_TARGET = 2.0; }; name = Release; }; /* End XCBuildConfiguration section */ /* Begin XCConfigurationList section */ 019A3E182D78E6A20055F93B /* Build configuration list for PBXNativeTarget "FastVLM App" */ = { isa = XCConfigurationList; buildConfigurations = ( 019A3E162D78E6A20055F93B /* Debug */, 019A3E172D78E6A20055F93B /* Release */, ); defaultConfigurationIsVisible = 0; defaultConfigurationName = Release; }; C35372BA2D08C32D00474D34 /* Build configuration list for PBXNativeTarget "Video" */ 
= { isa = XCConfigurationList; buildConfigurations = ( C35372B72D08C32D00474D34 /* Debug */, C35372B82D08C32D00474D34 /* Release */, ); defaultConfigurationIsVisible = 0; defaultConfigurationName = Release; }; C35EDB672D07699400757E80 /* Build configuration list for PBXProject "FastVLM" */ = { isa = XCConfigurationList; buildConfigurations = ( C35EDB682D07699400757E80 /* Debug */, C35EDB692D07699400757E80 /* Release */, ); defaultConfigurationIsVisible = 0; defaultConfigurationName = Release; }; C39BB3FF2D79082A005DB8FB /* Build configuration list for PBXNativeTarget "FastVLM" */ = { isa = XCConfigurationList; buildConfigurations = ( C39BB3FB2D79082A005DB8FB /* Debug */, C39BB3FC2D79082A005DB8FB /* Release */, ); defaultConfigurationIsVisible = 0; defaultConfigurationName = Release; }; /* End XCConfigurationList section */ /* Begin XCRemoteSwiftPackageReference section */ C35EDB6A2D076A3900757E80 /* XCRemoteSwiftPackageReference "mlx-swift-examples" */ = { isa = XCRemoteSwiftPackageReference; repositoryURL = "https://github.com/ml-explore/mlx-swift-examples"; requirement = { kind = upToNextMajorVersion; minimumVersion = 2.21.2; }; }; C35EDB8A2D07777E00757E80 /* XCRemoteSwiftPackageReference "mlx-swift" */ = { isa = XCRemoteSwiftPackageReference; repositoryURL = "https://github.com/ml-explore/mlx-swift"; requirement = { kind = upToNextMajorVersion; minimumVersion = 0.21.2; }; }; C3ED54592D790AC6005E20B3 /* XCRemoteSwiftPackageReference "swift-transformers" */ = { isa = XCRemoteSwiftPackageReference; repositoryURL = "https://github.com/huggingface/swift-transformers"; requirement = { kind = upToNextMajorVersion; minimumVersion = 0.1.18; }; }; /* End XCRemoteSwiftPackageReference section */ /* Begin XCSwiftPackageProductDependency section */ 019A3E192D78E7370055F93B /* MLX */ = { isa = XCSwiftPackageProductDependency; package = C35EDB8A2D07777E00757E80 /* XCRemoteSwiftPackageReference "mlx-swift" */; productName = MLX; }; 019A3E1B2D78E73E0055F93B /* MLXLMCommon */ = { 
isa = XCSwiftPackageProductDependency; package = C35EDB6A2D076A3900757E80 /* XCRemoteSwiftPackageReference "mlx-swift-examples" */; productName = MLXLMCommon; }; 019A3E1D2D78E7470055F93B /* MLXRandom */ = { isa = XCSwiftPackageProductDependency; package = C35EDB8A2D07777E00757E80 /* XCRemoteSwiftPackageReference "mlx-swift" */; productName = MLXRandom; }; 019A3E1F2D78E74C0055F93B /* MLXVLM */ = { isa = XCSwiftPackageProductDependency; package = C35EDB6A2D076A3900757E80 /* XCRemoteSwiftPackageReference "mlx-swift-examples" */; productName = MLXVLM; }; C3ED544D2D790860005E20B3 /* MLXLMCommon */ = { isa = XCSwiftPackageProductDependency; package = C35EDB6A2D076A3900757E80 /* XCRemoteSwiftPackageReference "mlx-swift-examples" */; productName = MLXLMCommon; }; C3ED544F2D790860005E20B3 /* MLXVLM */ = { isa = XCSwiftPackageProductDependency; package = C35EDB6A2D076A3900757E80 /* XCRemoteSwiftPackageReference "mlx-swift-examples" */; productName = MLXVLM; }; C3ED54512D790860005E20B3 /* MLX */ = { isa = XCSwiftPackageProductDependency; package = C35EDB8A2D07777E00757E80 /* XCRemoteSwiftPackageReference "mlx-swift" */; productName = MLX; }; C3ED54532D790860005E20B3 /* MLXNN */ = { isa = XCSwiftPackageProductDependency; package = C35EDB8A2D07777E00757E80 /* XCRemoteSwiftPackageReference "mlx-swift" */; productName = MLXNN; }; C3ED54572D790A68005E20B3 /* MLXFast */ = { isa = XCSwiftPackageProductDependency; package = C35EDB8A2D07777E00757E80 /* XCRemoteSwiftPackageReference "mlx-swift" */; productName = MLXFast; }; C3ED545A2D790AD6005E20B3 /* Transformers */ = { isa = XCSwiftPackageProductDependency; package = C3ED54592D790AC6005E20B3 /* XCRemoteSwiftPackageReference "swift-transformers" */; productName = Transformers; }; /* End XCSwiftPackageProductDependency section */ }; rootObject = C35EDB642D07699400757E80 /* Project object */; } ================================================ FILE: app/FastVLM.xcodeproj/xcshareddata/xcschemes/FastVLM App.xcscheme 
================================================ ================================================ FILE: app/README.md ================================================ # FastVLM Demonstrates the performance of **FastVLM** models for on-device visual question answering.
FastVLM - Counting FastVLM - Handwriting FastVLM - Emoji
## Features - FastVLM runs on iOS (18.2+) and macOS (15.2+). - View Time-To-First-Token (TTFT) with every inference. - All predictions are processed privately and securely using on-device models. ### Flexible Prompting Flexible prompting The app includes a set of built-in prompts to help you get started quickly. Tap the **Prompts** button in the top-right corner to explore them. Selecting a prompt will immediately update the active input. To create new prompts or edit existing ones, choose **Customize…** from the **Prompts** menu. ## Pretrained Model Options There are 3 pretrained sizes of FastVLM to choose from: - **FastVLM 0.5B**: Small and fast - great for mobile devices where speed matters. - **FastVLM 1.5B**: Well balanced - great for larger devices where speed and accuracy matter. - **FastVLM 7B**: Fast and accurate - ideal for situations where accuracy matters more than speed. To download any FastVLM model listed above, use the [get_pretrained_mlx_model.sh](get_pretrained_mlx_model.sh) script. The script downloads the model from the web and places it in the appropriate location. Once a model has been downloaded using the steps below, no additional steps are needed to build the app in Xcode. To explore how the other models work for your use case, simply re-run `get_pretrained_mlx_model.sh` with the new model selected, follow the prompts, and rebuild your app in Xcode. ### Download Instructions 1. Make the script executable ```shell chmod +x app/get_pretrained_mlx_model.sh ``` 2. Download FastVLM ```shell app/get_pretrained_mlx_model.sh --model 0.5b --dest app/FastVLM/model ``` 3. Open the app in Xcode, Build, and Run. ### Custom Model In addition to the pretrained sizes of FastVLM, you can further quantize or fine-tune FastVLM to best fit your needs. To learn more, check out our documentation on how to [`export the model`](../model_export#export-vlm). Please clear the existing model in `app/FastVLM/model` before downloading or copying a new model.
================================================ FILE: app/Video/CameraController.swift ================================================ // // For licensing see accompanying LICENSE file. // Copyright (C) 2025 Apple Inc. All Rights Reserved. // import AVFoundation import CoreImage #if os(iOS) import UIKit #endif @Observable public class CameraController: NSObject { private var framesContinuation: AsyncStream.Continuation? public var backCamera = true { didSet { stop() start() } } public var devices = [AVCaptureDevice]() public var device: AVCaptureDevice = AVCaptureDevice.default(for: .video)! { didSet { stop() start() } } private var permissionGranted = true private var captureSession: AVCaptureSession? private let sessionQueue = DispatchQueue(label: "sessionQueue") @objc dynamic private var rotationCoordinator : AVCaptureDevice.RotationCoordinator? private var rotationObservation: NSKeyValueObservation? public func attach(continuation: AsyncStream.Continuation) { sessionQueue.async { self.framesContinuation = continuation } } public func detatch() { sessionQueue.async { self.framesContinuation = nil } } public func stop() { sessionQueue.sync { [self] in captureSession?.stopRunning() captureSession = nil } } public func start() { sessionQueue.async { [self] in let captureSession = AVCaptureSession() self.captureSession = captureSession self.checkPermission() self.setupCaptureSession(position: backCamera ? .back : .front) captureSession.startRunning() } } #if os(iOS) private func setOrientation(_ orientation: UIDeviceOrientation) { guard let captureSession else { return } let angle: Double? 
switch orientation { case .unknown, .faceDown: angle = nil case .portrait, .faceUp: angle = 90 case .portraitUpsideDown: angle = 270 case .landscapeLeft: angle = 0 case .landscapeRight: angle = 180 @unknown default: angle = nil } if let angle { for output in captureSession.outputs { output.connection(with: .video)?.videoRotationAngle = angle } } } private func updateRotation(rotation : CGFloat) { guard let captureSession else { return } for output in captureSession.outputs { output.connection(with: .video)?.videoRotationAngle = rotation } } #endif func checkPermission() { switch AVCaptureDevice.authorizationStatus(for: .video) { case .authorized: // The user has previously granted access to the camera. self.permissionGranted = true case .notDetermined: // The user has not yet been asked for camera access. self.requestPermission() // Combine the two other cases into the default case default: self.permissionGranted = false } } func requestPermission() { // Strong reference not a problem here but might become one in the future. AVCaptureDevice.requestAccess(for: .video) { [unowned self] granted in self.permissionGranted = granted } } func setupCaptureSession(position: AVCaptureDevice.Position) { guard let captureSession else { return } let videoOutput = AVCaptureVideoDataOutput() guard permissionGranted else { print("No permission for camera") return } let deviceTypes: [AVCaptureDevice.DeviceType] #if os(iOS) deviceTypes = [.builtInDualCamera, .builtInWideAngleCamera] #else deviceTypes = [.external, .continuityCamera, .builtInWideAngleCamera] #endif let videoDeviceDiscoverySession = AVCaptureDevice.DiscoverySession( deviceTypes: deviceTypes, mediaType: .video, position: position) let videoDevice: AVCaptureDevice? 
if videoDeviceDiscoverySession.devices.contains(self.device) { videoDevice = self.device } else { videoDevice = videoDeviceDiscoverySession.devices.first } if devices.isEmpty { self.devices = videoDeviceDiscoverySession.devices } guard let videoDevice else { print("Unable to find video device") return } guard let videoDeviceInput = try? AVCaptureDeviceInput(device: videoDevice) else { print("Unable to create AVCaptureDeviceInput") return } guard captureSession.canAddInput(videoDeviceInput) else { print("Unable to add input") return } captureSession.addInput(videoDeviceInput) videoOutput.setSampleBufferDelegate(self, queue: DispatchQueue(label: "sampleBufferQueue")) captureSession.addOutput(videoOutput) captureSession.sessionPreset = AVCaptureSession.Preset.hd1920x1080 #if os(iOS) rotationCoordinator = AVCaptureDevice.RotationCoordinator(device: videoDevice, previewLayer: nil) rotationObservation = observe(\.rotationCoordinator!.videoRotationAngleForHorizonLevelCapture, options: [.initial, .new]) { [weak self] _, change in if let nv = change.newValue { self?.updateRotation(rotation: nv) } } #endif } } extension CameraController: AVCaptureVideoDataOutputSampleBufferDelegate { public func captureOutput( _ output: AVCaptureOutput, didOutput sampleBuffer: CMSampleBuffer, from connection: AVCaptureConnection ) { if sampleBuffer.isValid && sampleBuffer.imageBuffer != nil { framesContinuation?.yield(sampleBuffer) } } } ================================================ FILE: app/Video/CameraControlsView.swift ================================================ // // For licensing see accompanying LICENSE file. // Copyright (C) 2025 Apple Inc. All Rights Reserved. 
//

import AVFoundation
import SwiftUI

/// A small square button that flips `backCamera` each time it is tapped.
///
/// `device` and `devices` are stored as bindings for the caller, but `body`
/// does not read them.
public struct CameraControlsView: View {

    @Binding public var backCamera: Bool
    @Binding public var device: AVCaptureDevice
    @Binding public var devices: [AVCaptureDevice]

    public init(
        backCamera: Binding<Bool>,
        device: Binding<AVCaptureDevice>,
        devices: Binding<[AVCaptureDevice]>
    ) {
        self._backCamera = backCamera
        self._device = device
        self._devices = devices
    }

    public var body: some View {
        Button {
            backCamera.toggle()
        } label: {
            RoundedRectangle(cornerRadius: 8.0)
                .fill(.regularMaterial)
                .frame(width: 32.0, height: 32.0)
                .overlay(alignment: .center) {
                    // Switch cameras image
                    Image(systemName: "arrow.triangle.2.circlepath.camera.fill")
                        .foregroundStyle(.primary)
                        .padding(6.0)
                }
        }
    }
}

================================================
FILE: app/Video/CameraType.swift
================================================
//
// For licensing see accompanying LICENSE file.
// Copyright (C) 2025 Apple Inc. All Rights Reserved.
//

import Foundation

/// Display mode for the camera feed.
///
/// `VideoFrameView` shows a capture/resume button for `.single` and keeps
/// frames updating for `.continuous`.
public enum CameraType: String, CaseIterable {
    case continuous
    case single
}

================================================
FILE: app/Video/Video.h
================================================
//
// For licensing see accompanying LICENSE file.
// Copyright (C) 2025 Apple Inc. All Rights Reserved.
//

#import <Foundation/Foundation.h>

//! Project version number for Video.
FOUNDATION_EXPORT double VideoVersionNumber;

//! Project version string for Video.
FOUNDATION_EXPORT const unsigned char VideoVersionString[];

================================================
FILE: app/Video/VideoFrameView.swift
================================================
//
// For licensing see accompanying LICENSE file.
// Copyright (C) 2025 Apple Inc. All Rights Reserved.
//

import AVFoundation
import CoreImage
import Foundation
import SwiftUI

/// Displays a stream of video frames.
///
/// Frames from `frames` are shown as they arrive. In `.single` mode a button
/// freezes the current frame (`hold`) and passes it to `action`; tapping again
/// resumes the live feed.
public struct VideoFrameView: View {
    @Environment(\.colorScheme) private var colorScheme

    public let frames: AsyncStream<CVImageBuffer>
    public let cameraType: CameraType
    public let action: ((CVImageBuffer) -> Void)?

    @State private var hold: Bool = false
    @State private var videoFrame: CVImageBuffer?

    private var backgroundColor: Color {
        #if os(iOS)
        return Color(.secondarySystemBackground)
        #elseif os(macOS)
        return Color(.secondarySystemFill)
        #else
        // When in doubt, use these values that I captured to match iOS' secondarySystemBackground
        if colorScheme == .dark {
            return Color(red: 0.11, green: 0.11, blue: 0.12)
        } else {
            return Color(red: 0.95, green: 0.95, blue: 0.97)
        }
        #endif
    }

    public init(
        frames: AsyncStream<CVImageBuffer>,
        cameraType: CameraType,
        action: ((CVImageBuffer) -> Void)?
    ) {
        self.frames = frames
        self.cameraType = cameraType
        self.action = action
    }

    public var body: some View {
        Group {
            if let videoFrame {
                _ImageView(image: videoFrame)
                    .overlay(alignment: .bottom) {
                        if cameraType == .single {
                            Button {
                                tap()
                            } label: {
                                if hold {
                                    Label("Resume", systemImage: "play.fill")
                                } else {
                                    Label("Capture Photo", systemImage: "camera.fill")
                                }
                            }
                            .clipShape(.capsule)
                            .buttonStyle(.borderedProminent)
                            .tint(hold ? .gray : .accentColor)
                            .foregroundColor(.white)
                            .padding()
                        }
                    }
            } else {
                // spinner before the camera comes up
                ProgressView()
                    .controlSize(.large)
            }
        }
        // This ensures that we take up the full 4/3 aspect ratio
        // even if we don't have an image to display
        .frame(maxWidth: .infinity, maxHeight: .infinity)
        .background(backgroundColor)
        .clipShape(RoundedRectangle(cornerRadius: 10.0))
        .task {
            // feed frames to the _ImageView
            if Task.isCancelled {
                return
            }
            for await frame in frames {
                // While holding, keep draining the stream but keep the frozen frame.
                if !hold {
                    videoFrame = frame
                }
            }
        }
        .onChange(of: cameraType) { _, newType in
            // No matter what, when the user switches to .continuous,
            // we need to continue showing updated frames
            if newType == .continuous {
                hold = false
            }
        }
    }

    /// Toggles between frozen and live display; freezing hands the frame to `action`.
    private func tap() {
        if hold {
            // resume
            hold = false
        } else if let videoFrame {
            hold = true
            if let action {
                action(videoFrame)
            }
        }
    }
}

// Gate on UIKit availability rather than `os(iOS)` so UIKit-based platforms
// other than iOS (which `backgroundColor` already handles) do not fall into
// the AppKit branch.
#if canImport(UIKit)
/// Internal view to display a CVImageBuffer
private struct _ImageView: UIViewRepresentable {

    let image: Any
    var gravity = CALayerContentsGravity.resizeAspectFill

    func makeUIView(context: Context) -> UIView {
        let view = UIView()
        view.layer.contentsGravity = gravity
        return view
    }

    func updateUIView(_ uiView: UIView, context: Context) {
        uiView.layer.contents = image
    }
}
#else
/// Internal view to display a CVImageBuffer
private struct _ImageView: NSViewRepresentable {

    let image: Any
    var gravity = CALayerContentsGravity.resizeAspectFill

    func makeNSView(context: Context) -> NSView {
        let view = NSView()
        // NSView is not layer-backed by default.
        view.wantsLayer = true
        view.layer?.contentsGravity = gravity
        return view
    }

    func updateNSView(_ nsView: NSView, context: Context) {
        nsView.layer?.contents = image
    }
}
#endif

================================================
FILE: app/get_pretrained_mlx_model.sh
================================================
#!/usr/bin/env bash
#
# For licensing see accompanying LICENSE_MODEL file.
# Copyright (C) 2025 Apple Inc. All Rights Reserved.
#

set -e

# Help function.
# $1: "false" for an explicit --help request (exit 0); anything else exits 1.
show_help() {
    local is_error=${1:-true}  # Default to error mode if no argument provided

    echo "Usage: $0 --model <model_size> --dest <destination_directory>"
    echo
    echo "Required arguments:"
    echo "  --model <model_size>    Size of the model to download"
    echo "  --dest <directory>      Directory where the model will be downloaded"
    echo
    echo "Available model sizes:"
    echo "  0.5b  - 0.5B parameter model (FP16)"
    echo "  1.5b  - 1.5B parameter model (INT8)"
    echo "  7b    - 7B parameter model (INT4)"
    echo
    echo "Options:"
    echo "  --help                  Show help message"

    # Exit with success (0) for help flag, error (1) for usage errors
    if [ "$is_error" = "false" ]; then
        exit 0
    else
        exit 1
    fi
}

# Abort with usage if option $1 was given without a value ($2).
# Without this check, a trailing `--model` makes the second `shift` fail,
# and `set -e` then exits silently with no message.
require_value() {
    if [[ -z "${2:-}" || "$2" == --* ]]; then
        echo -e "Error: $1 requires a value\n"
        show_help true
    fi
}

# Parse command line arguments
while [[ "$#" -gt 0 ]]; do
    case $1 in
        --model) require_value "$1" "${2:-}"; model_size="$2"; shift ;;
        --dest) require_value "$1" "${2:-}"; dest_dir="$2"; shift ;;
        --help) show_help false ;;  # Explicit help request
        *) echo -e "Unknown parameter: $1\n"; show_help true ;;  # Error case
    esac
    shift
done

# Validate required parameters
if [ -z "$model_size" ]; then
    echo -e "Error: --model parameter is required\n"
    show_help true
fi

if [ -z "$dest_dir" ]; then
    echo -e "Error: --dest parameter is required\n"
    show_help true
fi

# Map model size to full model name
case "$model_size" in
    "0.5b") model="llava-fastvithd_0.5b_stage3_llm.fp16" ;;
    "1.5b") model="llava-fastvithd_1.5b_stage3_llm.int8" ;;
    "7b") model="llava-fastvithd_7b_stage3_llm.int4" ;;
    *)
        echo -e "Error: Invalid model size '$model_size'\n"
        show_help true
        ;;
esac

# Remove the temporary download directory, if one was created.
cleanup() {
    if [ -n "${tmp_dir:-}" ]; then
        rm -rf "$tmp_dir"
    fi
}

# Download $model into a temp dir, unzip it, and copy its contents into $dest_dir.
# If $dest_dir is non-empty, the user is asked before it is cleared.
download_model() {
    # Download directory
    tmp_dir=$(mktemp -d)

    # Model paths
    base_url="https://ml-site.cdn-apple.com/datasets/fastvlm"

    # Create destination directory if it doesn't exist
    if [ ! -d "$dest_dir" ]; then
        echo "Creating destination directory: $dest_dir"
        mkdir -p "$dest_dir"
    elif [ "$(ls -A "$dest_dir")" ]; then
        echo -e "Destination directory '$dest_dir' exists and is not empty.\n"
        read -p "Do you want to clear it and continue? [y/N]: " confirm
        if [[ ! "$confirm" =~ ^[Yy]$ ]]; then
            echo -e "\nStopping."
            exit 1
        fi
        echo -e "\nClearing existing contents in '$dest_dir'"
        # ${dest_dir:?} guards against expanding to "/*" if dest_dir were empty.
        rm -rf "${dest_dir:?}"/*
    fi

    # Create temp variables
    tmp_zip_file="${tmp_dir}/${model}.zip"
    tmp_extract_dir="${tmp_dir}/${model}"

    # Create temp extract directory
    mkdir -p "$tmp_extract_dir"

    # Download model
    echo -e "\nDownloading '${model}' model ...\n"
    wget -q --progress=bar:noscroll --show-progress -O "$tmp_zip_file" "$base_url/$model.zip"

    # Unzip model
    echo -e "\nUnzipping model..."
    unzip -q "$tmp_zip_file" -d "$tmp_extract_dir"

    # Copy model files to destination directory
    echo -e "\nCopying model files to destination directory..."
    cp -r "$tmp_extract_dir/$model"/* "$dest_dir"

    # Verify destination directory exists and is not empty
    if [ ! -d "$dest_dir" ] || [ -z "$(ls -A "$dest_dir")" ]; then
        echo -e "\nModel extraction failed. Destination directory '$dest_dir' is missing or empty."
        exit 1
    fi

    echo -e "\nModel downloaded and extracted to '$dest_dir'"
}

# Cleanup download directory on exit. INT/TERM are converted into an exit
# (which fires the EXIT trap). A handler that only ran `cleanup` would let the
# script resume after Ctrl-C with its temp dir already deleted.
trap cleanup EXIT
trap 'exit 130' INT
trap 'exit 143' TERM

# Download models
download_model

================================================
FILE: get_models.sh
================================================
#!/usr/bin/env bash
#
# For licensing see accompanying LICENSE_MODEL file.
# Copyright (C) 2025 Apple Inc. All Rights Reserved.
#

# Download every FastVLM checkpoint archive into ./checkpoints, extract them
# there, then delete the archives.

# Stop at the first failed download/extraction. Otherwise a missing archive
# cascades into unzip/rm errors while the script keeps going.
set -e

base_url="https://ml-site.cdn-apple.com/datasets/fastvlm"
models=(
    llava-fastvithd_0.5b_stage2
    llava-fastvithd_0.5b_stage3
    llava-fastvithd_1.5b_stage2
    llava-fastvithd_1.5b_stage3
    llava-fastvithd_7b_stage2
    llava-fastvithd_7b_stage3
)

mkdir -p checkpoints
for model in "${models[@]}"; do
    wget "$base_url/$model.zip" -P checkpoints
done

# Extract models
cd checkpoints
for model in "${models[@]}"; do
    unzip -qq "$model.zip"
done

# Clean up
for model in "${models[@]}"; do
    rm "$model.zip"
done
cd -

================================================
FILE: llava/__init__.py
================================================
from .model import LlavaLlamaForCausalLM, LlavaQwen2ForCausalLM

================================================
FILE: llava/constants.py
================================================
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15

LOGDIR = "."
# Model Constants IGNORE_INDEX = -100 IMAGE_TOKEN_INDEX = -200 DEFAULT_IMAGE_TOKEN = "" DEFAULT_IMAGE_PATCH_TOKEN = "" DEFAULT_IM_START_TOKEN = "" DEFAULT_IM_END_TOKEN = "" IMAGE_PLACEHOLDER = "" ================================================ FILE: llava/conversation.py ================================================ import dataclasses from enum import auto, Enum from typing import List, Tuple import base64 from io import BytesIO from PIL import Image class SeparatorStyle(Enum): """Different separator style.""" SINGLE = auto() TWO = auto() MPT = auto() PLAIN = auto() LLAMA_2 = auto() QWEN_2 = auto() # fix: add qwen2 CHATML = auto() @dataclasses.dataclass class Conversation: """A class that keeps all conversation history.""" system: str roles: List[str] messages: List[List[str]] offset: int sep_style: SeparatorStyle = SeparatorStyle.SINGLE sep: str = "###" sep2: str = None version: str = "Unknown" skip_next: bool = False def get_prompt(self): messages = self.messages if len(messages) > 0 and type(messages[0][1]) is tuple: messages = self.messages.copy() init_role, init_msg = messages[0].copy() init_msg = init_msg[0].replace("", "").strip() if 'mmtag' in self.version: messages[0] = (init_role, init_msg) messages.insert(0, (self.roles[0], "")) messages.insert(1, (self.roles[1], "Received.")) else: messages[0] = (init_role, "\n" + init_msg) if self.sep_style == SeparatorStyle.SINGLE: ret = self.system + self.sep for role, message in messages: if message: if type(message) is tuple: message, _, _ = message ret += role + ": " + message + self.sep else: ret += role + ":" # elif self.sep_style == SeparatorStyle.QWEN_2: # fix: add qwen2 # seps = [self.sep, self.sep2] # ret = self.system + seps[0] # ret = "" # for i, (role, message) in enumerate(messages): # if message: # if type(message) is tuple: # message, _, _ = message # ret += role + ": " + message + seps[i % 2] # else: # ret += role + ":" elif self.sep_style == SeparatorStyle.QWEN_2: # fix: add qwen2 ret = 
self.system + self.sep for i, (role, message) in enumerate(messages): if message: if type(message) is tuple: message, _, _ = message ret += role + message + self.sep else: ret += role elif self.sep_style == SeparatorStyle.CHATML: ret = "" if self.system == "" else self.system + self.sep + "\n" for role, message in messages: if message: if type(message) is tuple: message, images = message message = "" * len(images) + message ret += role + "\n" + message + self.sep + "\n" else: ret += role + "\n" return ret elif self.sep_style == SeparatorStyle.TWO: seps = [self.sep, self.sep2] ret = self.system + seps[0] for i, (role, message) in enumerate(messages): if message: if type(message) is tuple: message, _, _ = message ret += role + ": " + message + seps[i % 2] else: ret += role + ":" elif self.sep_style == SeparatorStyle.MPT: ret = self.system + self.sep for role, message in messages: if message: if type(message) is tuple: message, _, _ = message ret += role + message + self.sep else: ret += role elif self.sep_style == SeparatorStyle.LLAMA_2: def wrap_sys(msg): return f"<>\n{msg}\n<>\n\n" if len(msg) > 0 else msg def wrap_inst(msg): return f"[INST] {msg} [/INST]" ret = "" for i, (role, message) in enumerate(messages): if i == 0: assert message, "first message should not be none" assert role == self.roles[0], "first message should come from user" if message: if type(message) is tuple: message, _, _ = message if i == 0: message = wrap_sys(self.system) + message if i % 2 == 0: message = wrap_inst(message) ret += self.sep + message else: ret += " " + message + " " + self.sep2 else: ret += "" ret = ret.lstrip(self.sep) elif self.sep_style == SeparatorStyle.PLAIN: seps = [self.sep, self.sep2] ret = self.system for i, (role, message) in enumerate(messages): if message: if type(message) is tuple: message, _, _ = message ret += message + seps[i % 2] else: ret += "" else: raise ValueError(f"Invalid style: {self.sep_style}") return ret def append_message(self, role, message): 
self.messages.append([role, message]) def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672): if image_process_mode == "Pad": def expand2square(pil_img, background_color=(122, 116, 104)): width, height = pil_img.size if width == height: return pil_img elif width > height: result = Image.new(pil_img.mode, (width, width), background_color) result.paste(pil_img, (0, (width - height) // 2)) return result else: result = Image.new(pil_img.mode, (height, height), background_color) result.paste(pil_img, ((height - width) // 2, 0)) return result image = expand2square(image) elif image_process_mode in ["Default", "Crop"]: pass elif image_process_mode == "Resize": image = image.resize((336, 336)) else: raise ValueError(f"Invalid image_process_mode: {image_process_mode}") if max(image.size) > max_len: max_hw, min_hw = max(image.size), min(image.size) aspect_ratio = max_hw / min_hw shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) longest_edge = int(shortest_edge * aspect_ratio) W, H = image.size if H > W: H, W = longest_edge, shortest_edge else: H, W = shortest_edge, longest_edge image = image.resize((W, H)) if return_pil: return image else: buffered = BytesIO() image.save(buffered, format=image_format) img_b64_str = base64.b64encode(buffered.getvalue()).decode() return img_b64_str def get_images(self, return_pil=False): images = [] for i, (role, msg) in enumerate(self.messages[self.offset:]): if i % 2 == 0: if type(msg) is tuple: msg, image, image_process_mode = msg image = self.process_image(image, image_process_mode, return_pil=return_pil) images.append(image) return images def to_gradio_chatbot(self): ret = [] for i, (role, msg) in enumerate(self.messages[self.offset:]): if i % 2 == 0: if type(msg) is tuple: msg, image, image_process_mode = msg img_b64_str = self.process_image( image, "Default", return_pil=False, image_format='JPEG') img_str = f'user upload image' msg = img_str + msg.replace('', 
'').strip() ret.append([msg, None]) else: ret.append([msg, None]) else: ret[-1][-1] = msg return ret def copy(self): return Conversation( system=self.system, roles=self.roles, messages=[[x, y] for x, y in self.messages], offset=self.offset, sep_style=self.sep_style, sep=self.sep, sep2=self.sep2, version=self.version) def dict(self): if len(self.get_images()) > 0: return { "system": self.system, "roles": self.roles, "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages], "offset": self.offset, "sep": self.sep, "sep2": self.sep2, } return { "system": self.system, "roles": self.roles, "messages": self.messages, "offset": self.offset, "sep": self.sep, "sep2": self.sep2, } conv_vicuna_v0 = Conversation( system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.", roles=("Human", "Assistant"), messages=( ("Human", "What are the key differences between renewable and non-renewable energy sources?"), ("Assistant", "Renewable energy sources are those that can be replenished naturally in a relatively " "short amount of time, such as solar, wind, hydro, geothermal, and biomass. " "Non-renewable energy sources, on the other hand, are finite and will eventually be " "depleted, such as coal, oil, and natural gas. Here are some key differences between " "renewable and non-renewable energy sources:\n" "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable " "energy sources are finite and will eventually run out.\n" "2. Environmental impact: Renewable energy sources have a much lower environmental impact " "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, " "and other negative effects.\n" "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically " "have lower operational costs than non-renewable sources.\n" "4. 
Reliability: Renewable energy sources are often more reliable and can be used in more remote " "locations than non-renewable sources.\n" "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different " "situations and needs, while non-renewable sources are more rigid and inflexible.\n" "6. Sustainability: Renewable energy sources are more sustainable over the long term, while " "non-renewable sources are not, and their depletion can lead to economic and social instability.\n") ), offset=2, sep_style=SeparatorStyle.SINGLE, sep="###", ) conv_vicuna_v1 = Conversation( system="A chat between a curious user and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the user's questions.", roles=("USER", "ASSISTANT"), version="v1", messages=(), offset=0, sep_style=SeparatorStyle.TWO, sep=" ", sep2="", ) conv_llama_2 = Conversation( system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""", roles=("USER", "ASSISTANT"), version="llama_v2", messages=(), offset=0, sep_style=SeparatorStyle.LLAMA_2, sep="", sep2="", ) conv_llava_llama_2 = Conversation( system="You are a helpful language and vision assistant. 
" "You are able to understand the visual content that the user provides, " "and assist the user with a variety of tasks using natural language.", roles=("USER", "ASSISTANT"), version="llama_v2", messages=(), offset=0, sep_style=SeparatorStyle.LLAMA_2, sep="", sep2="", ) conv_mpt = Conversation( system="""<|im_start|>system A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""", roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), version="mpt", messages=(), offset=0, sep_style=SeparatorStyle.MPT, sep="<|im_end|>", ) conv_llava_plain = Conversation( system="", roles=("", ""), messages=( ), offset=0, sep_style=SeparatorStyle.PLAIN, sep="\n", ) conv_llava_v0 = Conversation( system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.", roles=("Human", "Assistant"), messages=( ), offset=0, sep_style=SeparatorStyle.SINGLE, sep="###", ) conv_llava_v0_mmtag = Conversation( system="A chat between a curious user and an artificial intelligence assistant. " "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." "The visual content will be provided with the following format: visual content.", roles=("Human", "Assistant"), messages=( ), offset=0, sep_style=SeparatorStyle.SINGLE, sep="###", version="v0_mmtag", ) conv_llava_v1 = Conversation( system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.", roles=("USER", "ASSISTANT"), version="v1", messages=(), offset=0, sep_style=SeparatorStyle.TWO, sep=" ", sep2="", ) conv_llava_v1_mmtag = Conversation( system="A chat between a curious user and an artificial intelligence assistant. 
" "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." "The visual content will be provided with the following format: visual content.", roles=("USER", "ASSISTANT"), messages=(), offset=0, sep_style=SeparatorStyle.TWO, sep=" ", sep2="", version="v1_mmtag", ) conv_mistral_instruct = Conversation( system="", roles=("USER", "ASSISTANT"), version="llama_v2", messages=(), offset=0, sep_style=SeparatorStyle.LLAMA_2, sep="", sep2="", ) conv_chatml_direct = Conversation( system="""<|im_start|>system Answer the questions.""", roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), version="mpt", messages=(), offset=0, sep_style=SeparatorStyle.MPT, sep="<|im_end|>", ) conv_qwen_2 = Conversation( system="<|im_start|>system\nYou are a helpful assistant.", roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), version="qwen_v2", messages=(), offset=0, sep_style=SeparatorStyle.QWEN_2, sep="<|im_end|>\n", ) # conv_qwen_2 = Conversation( # system="", # roles=("user", "assistant"), # version="qwen_v2", # messages=(), # offset=0, # sep_style=SeparatorStyle.QWEN_2, # sep=" ", # sep2="<|im_end|>", # ) # fix: add qwen2 # conv_qwen_2 = Conversation( # system="A chat between a curious user and an artificial intelligence assistant. 
" # "The assistant gives helpful, detailed, and polite answers to the user's questions.", # roles=("USER", "ASSISTANT"), # version="qwen_v2", # messages=(), # offset=0, # sep_style=SeparatorStyle.QWEN_2, # sep=" ", # sep2="<|endoftext|>", # ) # conv_qwen_2 = Conversation( # system="""<|im_start|>system # You are a helpful assistant.""", # roles=("<|im_start|>user", "<|im_start|>assistant"), # version="qwen_v2", # messages=[], # offset=0, # sep_style=SeparatorStyle.QWEN_2, # sep="<|im_end|>", # sep2="<|im_end|>", # ) default_conversation = conv_qwen_2 conv_templates = { "default": conv_qwen_2, "v0": conv_vicuna_v0, "v1": conv_vicuna_v1, "vicuna_v1": conv_vicuna_v1, "qwen_2": conv_qwen_2, "llama_2": conv_llama_2, "mistral_instruct": conv_mistral_instruct, "chatml_direct": conv_chatml_direct, "mistral_direct": conv_chatml_direct, "plain": conv_llava_plain, "v0_plain": conv_llava_plain, "llava_v0": conv_llava_v0, "v0_mmtag": conv_llava_v0_mmtag, "llava_v1": conv_llava_v1, "v1_mmtag": conv_llava_v1_mmtag, "llava_llama_2": conv_llava_llama_2, "mpt": conv_mpt, } if __name__ == "__main__": print("conversation:", default_conversation.get_prompt()) ================================================ FILE: llava/mm_utils.py ================================================ import PIL from PIL import Image PIL.Image.MAX_IMAGE_PIXELS=500000000 from io import BytesIO import base64 import torch import math import ast from transformers import StoppingCriteria from llava.constants import IMAGE_TOKEN_INDEX def select_best_resolution(original_size, possible_resolutions): """ Selects the best resolution from a list of possible resolutions based on the original size. Args: original_size (tuple): The original size of the image in the format (width, height). possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. Returns: tuple: The best fit resolution in the format (width, height). 
""" original_width, original_height = original_size best_fit = None max_effective_resolution = 0 min_wasted_resolution = float('inf') for width, height in possible_resolutions: scale = min(width / original_width, height / original_height) downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale) effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height) wasted_resolution = (width * height) - effective_resolution if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution): max_effective_resolution = effective_resolution min_wasted_resolution = wasted_resolution best_fit = (width, height) return best_fit def resize_and_pad_image(image, target_resolution): """ Resize and pad an image to a target resolution while maintaining aspect ratio. Args: image (PIL.Image.Image): The input image. target_resolution (tuple): The target resolution (width, height) of the image. Returns: PIL.Image.Image: The resized and padded image. """ original_width, original_height = image.size target_width, target_height = target_resolution scale_w = target_width / original_width scale_h = target_height / original_height if scale_w < scale_h: new_width = target_width new_height = min(math.ceil(original_height * scale_w), target_height) else: new_height = target_height new_width = min(math.ceil(original_width * scale_h), target_width) # Resize the image resized_image = image.resize((new_width, new_height)) new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0)) paste_x = (target_width - new_width) // 2 paste_y = (target_height - new_height) // 2 new_image.paste(resized_image, (paste_x, paste_y)) return new_image def divide_to_patches(image, patch_size): """ Divides an image into patches of a specified size. Args: image (PIL.Image.Image): The input image. patch_size (int): The size of each patch. 
Returns: list: A list of PIL.Image.Image objects representing the patches. """ patches = [] width, height = image.size for i in range(0, height, patch_size): for j in range(0, width, patch_size): box = (j, i, j + patch_size, i + patch_size) patch = image.crop(box) patches.append(patch) return patches def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): """ Calculate the shape of the image patch grid after the preprocessing for images of any resolution. Args: image_size (tuple): The size of the input image in the format (width, height). grid_pinpoints (str): A string representation of a list of possible resolutions. patch_size (int): The size of each image patch. Returns: tuple: The shape of the image patch grid in the format (width, height). """ if type(grid_pinpoints) is list: possible_resolutions = grid_pinpoints else: possible_resolutions = ast.literal_eval(grid_pinpoints) width, height = select_best_resolution(image_size, possible_resolutions) return width // patch_size, height // patch_size def process_anyres_image(image, processor, grid_pinpoints): """ Process an image with variable resolutions. Args: image (PIL.Image.Image): The input image to be processed. processor: The image processor object. grid_pinpoints (str): A string representation of a list of possible resolutions. Returns: torch.Tensor: A tensor containing the processed image patches. 
""" if type(grid_pinpoints) is list: possible_resolutions = grid_pinpoints else: possible_resolutions = ast.literal_eval(grid_pinpoints) best_resolution = select_best_resolution(image.size, possible_resolutions) image_padded = resize_and_pad_image(image, best_resolution) patches = divide_to_patches(image_padded, processor.crop_size['height']) image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge'])) image_patches = [image_original_resize] + patches image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0] for image_patch in image_patches] return torch.stack(image_patches, dim=0) def load_image_from_base64(image): return Image.open(BytesIO(base64.b64decode(image))) def expand2square(pil_img, background_color): width, height = pil_img.size if width == height: return pil_img elif width > height: result = Image.new(pil_img.mode, (width, width), background_color) result.paste(pil_img, (0, (width - height) // 2)) return result else: result = Image.new(pil_img.mode, (height, height), background_color) result.paste(pil_img, ((height - width) // 2, 0)) return result def process_images(images, image_processor, model_cfg): image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None) new_images = [] if image_aspect_ratio == 'pad': for image in images: image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean)) image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] new_images.append(image) elif image_aspect_ratio == "anyres": for image in images: image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints) new_images.append(image) else: return image_processor(images, return_tensors='pt')['pixel_values'] if all(x.shape == new_images[0].shape for x in new_images): new_images = torch.stack(new_images, dim=0) return new_images def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, 
return_tensors=None): prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')] def insert_separator(X, sep): return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] input_ids = [] offset = 0 if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: offset = 1 input_ids.append(prompt_chunks[0][0]) for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): input_ids.extend(x[offset:]) if return_tensors is not None: if return_tensors == 'pt': return torch.tensor(input_ids, dtype=torch.long) raise ValueError(f'Unsupported tensor type: {return_tensors}') return input_ids def get_model_name_from_path(model_path): model_path = model_path.strip("/") model_paths = model_path.split("/") if model_paths[-1].startswith('checkpoint-'): return model_paths[-2] + "_" + model_paths[-1] else: return model_paths[-1] class KeywordsStoppingCriteria(StoppingCriteria): def __init__(self, keywords, tokenizer, input_ids): self.keywords = keywords self.keyword_ids = [] self.max_keyword_len = 0 for keyword in keywords: cur_keyword_ids = tokenizer(keyword).input_ids if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: cur_keyword_ids = cur_keyword_ids[1:] if len(cur_keyword_ids) > self.max_keyword_len: self.max_keyword_len = len(cur_keyword_ids) self.keyword_ids.append(torch.tensor(cur_keyword_ids)) self.tokenizer = tokenizer self.start_len = input_ids.shape[1] def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len) self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] for keyword_id in self.keyword_ids: truncated_output_ids = output_ids[0, -keyword_id.shape[0]:] if torch.equal(truncated_output_ids, keyword_id): return True outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] for 
keyword in self.keywords: if keyword in outputs: return True return False def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: outputs = [] for i in range(output_ids.shape[0]): outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores)) return all(outputs) ================================================ FILE: llava/model/__init__.py ================================================ # try: from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig from .language_model.llava_mpt import LlavaMptForCausalLM, LlavaMptConfig from .language_model.llava_mistral import LlavaMistralForCausalLM, LlavaMistralConfig from .language_model.llava_qwen import LlavaQwen2ForCausalLM, LlavaConfig # except: # pass ================================================ FILE: llava/model/apply_delta.py ================================================ """ Usage: python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta """ import argparse import torch from tqdm import tqdm from transformers import AutoTokenizer, AutoModelForCausalLM from llava import LlavaLlamaForCausalLM def apply_delta(base_model_path, target_model_path, delta_path): print("Loading base model") base = AutoModelForCausalLM.from_pretrained( base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) print("Loading delta") delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) print("Applying delta") for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): if name not in base.state_dict(): assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' continue if param.data.shape == base.state_dict()[name].shape: param.data += base.state_dict()[name] else: assert name in ['model.embed_tokens.weight', 
'lm_head.weight'], \ f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' bparam = base.state_dict()[name] param.data[:bparam.shape[0], :bparam.shape[1]] += bparam print("Saving target model") delta.save_pretrained(target_model_path) delta_tokenizer.save_pretrained(target_model_path) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--base-model-path", type=str, required=True) parser.add_argument("--target-model-path", type=str, required=True) parser.add_argument("--delta-path", type=str, required=True) args = parser.parse_args() apply_delta(args.base_model_path, args.target_model_path, args.delta_path) ================================================ FILE: llava/model/builder.py ================================================ # Copyright 2023 Haotian Liu # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import os import warnings import shutil from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig import torch from llava.model import * from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs): kwargs = {"device_map": device_map, **kwargs} if device != "cuda": kwargs['device_map'] = {"": device} if load_8bit: kwargs['load_in_8bit'] = True elif load_4bit: kwargs['load_in_4bit'] = True kwargs['quantization_config'] = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type='nf4' ) else: kwargs['torch_dtype'] = torch.float16 if use_flash_attn: kwargs['attn_implementation'] = 'flash_attention_2' if 'llava' in model_name.lower(): # Load LLaVA model if 'lora' in model_name.lower() and model_base is None: warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. 
Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.') if 'lora' in model_name.lower() and model_base is not None: from llava.model.language_model.llava_llama import LlavaConfig lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path) tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) print('Loading LLaVA from base model...') model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs) token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features if model.lm_head.weight.shape[0] != token_num: model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) print('Loading additional LLaVA weights...') if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')): non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu') else: # this is probably from HF Hub from huggingface_hub import hf_hub_download def load_from_hf(repo_id, filename, subfolder=None): cache_file = hf_hub_download( repo_id=repo_id, filename=filename, subfolder=subfolder) return torch.load(cache_file, map_location='cpu') non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin') non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()} if any(k.startswith('model.model.') for k in non_lora_trainables): non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()} model.load_state_dict(non_lora_trainables, strict=False) from peft import PeftModel print('Loading LoRA weights...') model = PeftModel.from_pretrained(model, model_path) print('Merging LoRA weights...') model = model.merge_and_unload() print('Model 
is loaded...') elif model_base is not None: # this may be mm projector only print('Loading LLaVA from base model...') if 'mpt' in model_name.lower(): if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')): shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py')) tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True) cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True) model = LlavaMptForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) else: tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) cfg_pretrained = AutoConfig.from_pretrained(model_path) # model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) model = LlavaQwen2ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu') mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()} model.load_state_dict(mm_projector_weights, strict=False) else: if 'mpt' in model_name.lower(): tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) model = LlavaMptForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) elif 'mistral' in model_name.lower(): tokenizer = AutoTokenizer.from_pretrained(model_path) model = LlavaMistralForCausalLM.from_pretrained( model_path, low_cpu_mem_usage=True, **kwargs ) elif 'dclm' in model_name.lower(): tokenizer = AutoTokenizer.from_pretrained(model_path) model = LlavaOpenlmForCausalLM.from_pretrained( model_path, low_cpu_mem_usage=True, **kwargs ) else: tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) # model = LlavaLlamaForCausalLM.from_pretrained( # model_path, # low_cpu_mem_usage=True, # **kwargs # ) model = 
LlavaQwen2ForCausalLM.from_pretrained( model_path, low_cpu_mem_usage=True, **kwargs ) else: # Load language model if model_base is not None: # PEFT model from peft import PeftModel tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs) print(f"Loading LoRA weights from {model_path}") model = PeftModel.from_pretrained(model, model_path) print(f"Merging weights") model = model.merge_and_unload() print('Convert to FP16...') model.to(torch.float16) else: use_fast = False if 'mpt' in model_name.lower(): tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs) else: tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) image_processor = None if 'llava' in model_name.lower(): mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True) if mm_use_im_patch_token: tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) if mm_use_im_start_end: tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) model.resize_token_embeddings(len(tokenizer)) vision_tower = model.get_vision_tower() if not vision_tower.is_loaded: vision_tower.load_model(device_map=device_map) if device_map != 'auto': vision_tower.to(device=device_map, dtype=torch.float16) image_processor = vision_tower.image_processor if hasattr(model.config, "max_sequence_length"): context_len = model.config.max_sequence_length else: context_len = 2048 return tokenizer, model, image_processor, context_len ================================================ FILE: llava/model/consolidate.py ================================================ """ Usage: python3 -m 
llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate """ import argparse import torch from transformers import AutoTokenizer, AutoModelForCausalLM from llava.model import * from llava.model.utils import auto_upgrade def consolidate_ckpt(src_path, dst_path): print("Loading model") auto_upgrade(src_path) src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) src_model.save_pretrained(dst_path) src_tokenizer.save_pretrained(dst_path) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--src", type=str, required=True) parser.add_argument("--dst", type=str, required=True) args = parser.parse_args() consolidate_ckpt(args.src, args.dst) ================================================ FILE: llava/model/language_model/llava_llama.py ================================================ # Copyright 2023 Haotian Liu # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn

from transformers import AutoConfig, AutoModelForCausalLM, \
                         LlamaConfig, LlamaModel, LlamaForCausalLM

from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput

from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM


class LlavaConfig(LlamaConfig):
    """Llama config registered under the ``llava_llama`` model type."""
    model_type = "llava_llama"


class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
    """Llama decoder combined with the LLaVA multimodal mixin (``LlavaMetaModel``)."""
    config_class = LlavaConfig

    def __init__(self, config: LlamaConfig):
        super(LlavaLlamaModel, self).__init__(config)


class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
    """Llama causal-LM head over ``LlavaLlamaModel`` that also accepts images."""
    config_class = LlavaConfig

    def __init__(self, config):
        # super(LlamaForCausalLM, ...) skips LlamaForCausalLM.__init__ in the MRO;
        # presumably so the stock inner model is not built before being
        # replaced by LlavaLlamaModel below.
        super(LlamaForCausalLM, self).__init__(config)
        self.model = LlavaLlamaModel(config)
        self.pretraining_tp = config.pretraining_tp
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        """Return the inner ``LlavaLlamaModel`` (required by ``LlavaMetaForCausalLM``)."""
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
        cache_position=None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the Llama forward pass, splicing image features in first.

        When ``inputs_embeds`` is not given, ``prepare_inputs_labels_for_multimodal``
        rewrites the token inputs (and labels/masks) to include image features.
        ``cache_position`` is accepted but not forwarded to the parent
        (NOTE(review): presumably accepted so generation can pass it).
        """

        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images,
                image_sizes
            )

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """Generate from token ids (and optional images) by passing embeddings only.

        The parent ``generate`` always receives ``inputs_embeds``: multimodal
        embeddings when ``images`` is given, plain token embeddings otherwise.

        Raises:
            NotImplementedError: If the caller supplies ``inputs_embeds``.
        """
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if images is not None:
            (
                inputs,
                position_ids,
                attention_mask,
                _,
                inputs_embeds,
                _
            ) = self.prepare_inputs_labels_for_multimodal(
                inputs,
                position_ids,
                attention_mask,
                None,
                None,
                images,
                image_sizes=image_sizes
            )
        else:
            inputs_embeds = self.get_model().embed_tokens(inputs)

        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                      inputs_embeds=None, **kwargs):
        """Delegate to the parent, then re-attach ``images`` / ``image_sizes`` if given."""
        # Pop the image kwargs so the parent implementation never sees them.
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            inputs['images'] = images
        if image_sizes is not None:
            inputs['image_sizes'] = image_sizes
        return inputs

# Register so AutoConfig / AutoModelForCausalLM resolve the "llava_llama" type.
AutoConfig.register("llava_llama", LlavaConfig)
AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)

================================================
FILE: llava/model/language_model/llava_mistral.py
================================================
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import List, Optional, Tuple, Union import torch import torch.nn as nn from torch.nn import CrossEntropyLoss from transformers import AutoConfig, AutoModelForCausalLM, \ MistralConfig, MistralModel, MistralForCausalLM from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.generation.utils import GenerateOutput from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM class LlavaMistralConfig(MistralConfig): model_type = "llava_mistral" class LlavaMistralModel(LlavaMetaModel, MistralModel): config_class = LlavaMistralConfig def __init__(self, config: MistralConfig): super(LlavaMistralModel, self).__init__(config) class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM): config_class = LlavaMistralConfig def __init__(self, config): super(MistralForCausalLM, self).__init__(config) self.model = LlavaMistralModel(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() def get_model(self): return self.model def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, images: Optional[torch.FloatTensor] = None, image_sizes: Optional[List[List[int]]] = None, 
return_dict: Optional[bool] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: if inputs_embeds is None: ( input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels ) = self.prepare_inputs_labels_for_multimodal( input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes ) return super().forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict ) @torch.no_grad() def generate( self, inputs: Optional[torch.Tensor] = None, images: Optional[torch.Tensor] = None, image_sizes: Optional[torch.Tensor] = None, **kwargs, ) -> Union[GenerateOutput, torch.LongTensor]: position_ids = kwargs.pop("position_ids", None) attention_mask = kwargs.pop("attention_mask", None) if "inputs_embeds" in kwargs: raise NotImplementedError("`inputs_embeds` is not supported") if images is not None: ( inputs, position_ids, attention_mask, _, inputs_embeds, _ ) = self.prepare_inputs_labels_for_multimodal( inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes ) else: inputs_embeds = self.get_model().embed_tokens(inputs) return super().generate( position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs ) def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): images = kwargs.pop("images", None) image_sizes = kwargs.pop("image_sizes", None) inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs ) if images is not None: inputs['images'] = images if image_sizes is not None: inputs['image_sizes'] = image_sizes return inputs AutoConfig.register("llava_mistral", LlavaMistralConfig) AutoModelForCausalLM.register(LlavaMistralConfig, 
LlavaMistralForCausalLM) ================================================ FILE: llava/model/language_model/llava_mpt.py ================================================ # Copyright 2023 Haotian Liu # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Optional, Tuple import torch from transformers import AutoConfig, AutoModelForCausalLM, \ MptConfig, MptForCausalLM, MptModel from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM class LlavaMptConfig(MptConfig): model_type = "llava_mpt" class LlavaMptModel(LlavaMetaModel, MptModel): config_class = LlavaMptConfig def __init__(self, config: MptConfig): config.hidden_size = config.d_model super(LlavaMptModel, self).__init__(config) def embed_tokens(self, x): return self.wte(x) class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM): config_class = LlavaMptConfig supports_gradient_checkpointing = True def __init__(self, config): super(MptForCausalLM, self).__init__(config) self.transformer = LlavaMptModel(config) self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() def get_model(self): return self.transformer def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, LlavaMptModel): module.gradient_checkpointing = value def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, attention_mask: 
Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, images=None): input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images) return super().forward( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): images = kwargs.pop("images", None) _inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs ) _inputs['images'] = images return _inputs AutoConfig.register("llava_mpt", LlavaMptConfig) AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM) ================================================ FILE: llava/model/language_model/llava_qwen.py ================================================ # Copyright 2023 Haotian Liu # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn

from transformers import AutoConfig, AutoModelForCausalLM, Qwen2Config, Qwen2Model, Qwen2ForCausalLM

from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput

from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM


class LlavaConfig(Qwen2Config):
    """Qwen2 config registered under the ``llava_qwen2`` model type."""
    model_type = "llava_qwen2"


class LlavaQwen2Model(LlavaMetaModel, Qwen2Model):
    """Qwen2 decoder combined with the LLaVA multimodal mixin (``LlavaMetaModel``)."""
    config_class = LlavaConfig

    def __init__(self, config: Qwen2Config):
        super(LlavaQwen2Model, self).__init__(config)


class LlavaQwen2ForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
    """Qwen2 causal-LM head over ``LlavaQwen2Model`` that also accepts images."""
    config_class = LlavaConfig

    def __init__(self, config):
        # super(Qwen2ForCausalLM, ...) skips Qwen2ForCausalLM.__init__ in the MRO;
        # the inner model is set to LlavaQwen2Model explicitly below.
        super(Qwen2ForCausalLM, self).__init__(config)
        self.model = LlavaQwen2Model(config)
        # self.pretraining_tp = config.pretraining_tp
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        """Return the inner ``LlavaQwen2Model`` (required by ``LlavaMetaForCausalLM``)."""
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
        cache_position=None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the Qwen2 forward pass, splicing image features in first.

        When ``inputs_embeds`` is not given, ``prepare_inputs_labels_for_multimodal``
        rewrites the token inputs (and labels/masks) to include image features.
        ``cache_position`` is accepted but not forwarded to the parent
        (NOTE(review): presumably accepted so generation can pass it).
        """

        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images,
                image_sizes
            )

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """Generate from token ids (and optional images) by passing embeddings only.

        The parent ``generate`` always receives ``inputs_embeds``: multimodal
        embeddings when ``images`` is given, plain token embeddings otherwise.

        Raises:
            NotImplementedError: If the caller supplies ``inputs_embeds``.
        """
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if images is not None:
            (
                inputs,
                position_ids,
                attention_mask,
                _,
                inputs_embeds,
                _
            ) = self.prepare_inputs_labels_for_multimodal(
                inputs,
                position_ids,
                attention_mask,
                None,
                None,
                images,
                image_sizes=image_sizes
            )
        else:
            inputs_embeds = self.get_model().embed_tokens(inputs)

        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                      inputs_embeds=None, **kwargs):
        """Delegate to the parent, then re-attach ``images`` / ``image_sizes`` if given."""
        # Pop the image kwargs so the parent implementation never sees them.
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            inputs['images'] = images
        if image_sizes is not None:
            inputs['image_sizes'] = image_sizes
        return inputs

# Register so AutoConfig / AutoModelForCausalLM resolve the "llava_qwen2" type.
AutoConfig.register("llava_qwen2", LlavaConfig)
AutoModelForCausalLM.register(LlavaConfig, LlavaQwen2ForCausalLM)

================================================
FILE: llava/model/llava_arch.py
================================================
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod import torch import torch.nn as nn from .multimodal_encoder.builder import build_vision_tower from .multimodal_projector.builder import build_vision_projector from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN from llava.mm_utils import get_anyres_image_grid_shape class LlavaMetaModel: def __init__(self, config): super(LlavaMetaModel, self).__init__(config) if hasattr(config, "mm_vision_tower"): self.vision_tower = build_vision_tower(config, delay_load=True) self.mm_projector = build_vision_projector(config) if 'unpad' in getattr(config, 'mm_patch_merge_type', ''): self.image_newline = nn.Parameter( torch.empty(config.hidden_size, dtype=self.dtype) ) def get_vision_tower(self): vision_tower = getattr(self, 'vision_tower', None) if type(vision_tower) is list: vision_tower = vision_tower[0] return vision_tower def initialize_vision_modules(self, model_args, fsdp=None): vision_tower = model_args.vision_tower mm_vision_select_layer = model_args.mm_vision_select_layer mm_vision_select_feature = model_args.mm_vision_select_feature pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter mm_patch_merge_type = model_args.mm_patch_merge_type self.config.mm_vision_tower = vision_tower if self.get_vision_tower() is None: vision_tower = build_vision_tower(model_args) if fsdp is not None and len(fsdp) > 0: self.vision_tower = [vision_tower] else: self.vision_tower = vision_tower else: if fsdp is not None and len(fsdp) > 0: vision_tower = self.vision_tower[0] 
else: vision_tower = self.vision_tower vision_tower.load_model() self.config.use_mm_proj = True self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear') self.config.mm_hidden_size = vision_tower.hidden_size self.config.mm_vision_select_layer = mm_vision_select_layer self.config.mm_vision_select_feature = mm_vision_select_feature self.config.mm_patch_merge_type = mm_patch_merge_type if getattr(self, 'mm_projector', None) is None: self.mm_projector = build_vision_projector(self.config) if 'unpad' in mm_patch_merge_type: embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype)) self.image_newline = nn.Parameter( torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std ) else: # In case it is frozen by LoRA for p in self.mm_projector.parameters(): p.requires_grad = True if pretrain_mm_mlp_adapter is not None: mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu') def get_w(weights, keyword): return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k} self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector')) def unpad_image(tensor, original_size): """ Unpads a PyTorch tensor of a padded and resized image. Args: tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format. original_size (tuple): The original size of PIL image (width, height). Returns: torch.Tensor: The unpadded image tensor. 
""" original_width, original_height = original_size current_height, current_width = tensor.shape[1:] original_aspect_ratio = original_width / original_height current_aspect_ratio = current_width / current_height if original_aspect_ratio > current_aspect_ratio: scale_factor = current_width / original_width new_height = int(original_height * scale_factor) padding = (current_height - new_height) // 2 unpadded_tensor = tensor[:, padding:current_height - padding, :] else: scale_factor = current_height / original_height new_width = int(original_width * scale_factor) padding = (current_width - new_width) // 2 unpadded_tensor = tensor[:, :, padding:current_width - padding] return unpadded_tensor class LlavaMetaForCausalLM(ABC): @abstractmethod def get_model(self): pass def get_vision_tower(self): return self.get_model().get_vision_tower() def encode_images(self, images): image_features = self.get_model().get_vision_tower()(images) image_features = self.get_model().mm_projector(image_features) return image_features def prepare_inputs_labels_for_multimodal( self, input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes=None ): vision_tower = self.get_vision_tower() if vision_tower is None or images is None or input_ids.shape[1] == 1: return input_ids, position_ids, attention_mask, past_key_values, None, labels if type(images) is list or images.ndim == 5: if type(images) is list: images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images] concat_images = torch.cat([image for image in images], dim=0) image_features = self.encode_images(concat_images) split_sizes = [image.shape[0] for image in images] image_features = torch.split(image_features, split_sizes, dim=0) mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat') image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square') if mm_patch_merge_type == 'flat': image_features = [x.flatten(0, 1) for x in image_features] elif 
mm_patch_merge_type.startswith('spatial'): new_image_features = [] for image_idx, image_feature in enumerate(image_features): if image_feature.shape[0] > 1: base_image_feature = image_feature[0] image_feature = image_feature[1:] height = width = self.get_vision_tower().num_patches_per_side assert height * width == base_image_feature.shape[0] if image_aspect_ratio == 'anyres': if hasattr(self.get_vision_tower(), 's2_image_size'): img_size = self.get_vision_tower().s2_image_size elif isinstance(self.get_vision_tower().config, dict): img_size = self.get_vision_tower().config["image_cfg"]["image_size"] else: img_size = self.get_vision_tower().config.image_size num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, img_size) image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1) else: raise NotImplementedError if 'unpad' in mm_patch_merge_type: image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() image_feature = image_feature.flatten(1, 2).flatten(2, 3) image_feature = unpad_image(image_feature, image_sizes[image_idx]) image_feature = torch.cat(( image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device) ), dim=-1) image_feature = image_feature.flatten(1, 2).transpose(0, 1) else: image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous() image_feature = image_feature.flatten(0, 3) image_feature = torch.cat((base_image_feature, image_feature), dim=0) else: image_feature = image_feature[0] if 'unpad' in mm_patch_merge_type: image_feature = torch.cat(( image_feature, self.model.image_newline[None].to(image_feature.device) ), dim=0) new_image_features.append(image_feature) image_features = new_image_features else: raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}") else: image_features = self.encode_images(images) # TODO: image start / end is not implemented here 
to support pretraining. if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False): raise NotImplementedError # Let's just add dummy tensors if they do not exist, # it is a headache to deal with None all the time. # But it is not ideal, and if you have a better idea, # please open an issue / submit a PR, thanks. _labels = labels _position_ids = position_ids _attention_mask = attention_mask if attention_mask is None: attention_mask = torch.ones_like(input_ids, dtype=torch.bool) else: attention_mask = attention_mask.bool() if position_ids is None: position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device) if labels is None: labels = torch.full_like(input_ids, IGNORE_INDEX) # remove the padding using attention_mask -- FIXME _input_ids = input_ids input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)] labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)] new_input_embeds = [] new_labels = [] cur_image_idx = 0 for batch_idx, cur_input_ids in enumerate(input_ids): num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() if num_images == 0: cur_image_features = image_features[cur_image_idx] cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids) cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0) new_input_embeds.append(cur_input_embeds) new_labels.append(labels[batch_idx]) cur_image_idx += 1 continue image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]] cur_input_ids_noim = [] cur_labels = labels[batch_idx] cur_labels_noim = [] for i in range(len(image_token_indices) - 1): cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]]) cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]]) split_sizes = 
[x.shape[0] for x in cur_labels_noim] cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim)) cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0) cur_new_input_embeds = [] cur_new_labels = [] for i in range(num_images + 1): cur_new_input_embeds.append(cur_input_embeds_no_im[i]) cur_new_labels.append(cur_labels_noim[i]) if i < num_images: cur_image_features = image_features[cur_image_idx] cur_image_idx += 1 cur_new_input_embeds.append(cur_image_features) cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype)) cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds] cur_new_input_embeds = torch.cat(cur_new_input_embeds) cur_new_labels = torch.cat(cur_new_labels) new_input_embeds.append(cur_new_input_embeds) new_labels.append(cur_new_labels) # Truncate sequences to max length as image embeddings can make the sequence longer tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None) if tokenizer_model_max_length is not None: new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds] new_labels = [x[:tokenizer_model_max_length] for x in new_labels] # Combine them max_len = max(x.shape[0] for x in new_input_embeds) batch_size = len(new_input_embeds) new_input_embeds_padded = [] new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device) attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device) position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device) for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)): cur_len = cur_new_embed.shape[0] if getattr(self.config, 'tokenizer_padding_side', 'right') == "left": new_input_embeds_padded.append(torch.cat(( torch.zeros((max_len - cur_len, 
cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), cur_new_embed ), dim=0)) if cur_len > 0: new_labels_padded[i, -cur_len:] = cur_new_labels attention_mask[i, -cur_len:] = True position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) else: new_input_embeds_padded.append(torch.cat(( cur_new_embed, torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device) ), dim=0)) if cur_len > 0: new_labels_padded[i, :cur_len] = cur_new_labels attention_mask[i, :cur_len] = True position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) new_input_embeds = torch.stack(new_input_embeds_padded, dim=0) if _labels is None: new_labels = None else: new_labels = new_labels_padded if _attention_mask is None: attention_mask = None else: attention_mask = attention_mask.to(dtype=_attention_mask.dtype) if _position_ids is None: position_ids = None return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels def initialize_vision_tokenizer(self, model_args, tokenizer): if model_args.mm_use_im_patch_token: tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) self.resize_token_embeddings(len(tokenizer)) if model_args.mm_use_im_start_end: num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) self.resize_token_embeddings(len(tokenizer)) if num_new_tokens > 0: input_embeddings = self.get_input_embeddings().weight.data output_embeddings = self.get_output_embeddings().weight.data input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( dim=0, keepdim=True) output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( dim=0, keepdim=True) input_embeddings[-num_new_tokens:] = input_embeddings_avg output_embeddings[-num_new_tokens:] = output_embeddings_avg if model_args.tune_mm_mlp_adapter: for p in 
self.get_input_embeddings().parameters(): p.requires_grad = True for p in self.get_output_embeddings().parameters(): p.requires_grad = False if model_args.pretrain_mm_mlp_adapter: mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu') embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight'] assert num_new_tokens == 2 if input_embeddings.shape == embed_tokens_weight.shape: input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:] elif embed_tokens_weight.shape[0] == num_new_tokens: input_embeddings[-num_new_tokens:] = embed_tokens_weight else: raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.") elif model_args.mm_use_im_patch_token: if model_args.tune_mm_mlp_adapter: for p in self.get_input_embeddings().parameters(): p.requires_grad = False for p in self.get_output_embeddings().parameters(): p.requires_grad = False ================================================ FILE: llava/model/make_delta.py ================================================ """ Usage: python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta """ import argparse import torch from tqdm import tqdm from transformers import AutoTokenizer, AutoModelForCausalLM from llava.model.utils import auto_upgrade def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): print("Loading base model") base = AutoModelForCausalLM.from_pretrained( base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) print("Loading target model") auto_upgrade(target_model_path) target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) print("Calculating delta") for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): if 
name not in base.state_dict(): assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' continue if param.data.shape == base.state_dict()[name].shape: param.data -= base.state_dict()[name] else: assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' bparam = base.state_dict()[name] param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam print("Saving delta") if hub_repo_id: kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} else: kwargs = {} target.save_pretrained(delta_path, **kwargs) target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) target_tokenizer.save_pretrained(delta_path, **kwargs) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--base-model-path", type=str, required=True) parser.add_argument("--target-model-path", type=str, required=True) parser.add_argument("--delta-path", type=str, required=True) parser.add_argument("--hub-repo-id", type=str, default=None) args = parser.parse_args() make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id) ================================================ FILE: llava/model/multimodal_encoder/builder.py ================================================ import os from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2 from .mobileclip_encoder import MobileCLIPVisionTower def build_vision_tower(vision_tower_cfg, **kwargs): vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None)) is_absolute_path_exists = os.path.exists(vision_tower) use_s2 = getattr(vision_tower_cfg, 's2', False) if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower: if use_s2: return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs) else: return CLIPVisionTower(vision_tower, 
args=vision_tower_cfg, **kwargs) elif "mobileclip" in vision_tower.lower(): return MobileCLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) raise ValueError(f'Unknown vision tower: {vision_tower}') ================================================ FILE: llava/model/multimodal_encoder/clip_encoder.py ================================================ import torch import torch.nn as nn from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig class CLIPVisionTower(nn.Module): def __init__(self, vision_tower, args, delay_load=False): super().__init__() self.is_loaded = False self.vision_tower_name = vision_tower self.select_layer = args.mm_vision_select_layer self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') self.tune_vision_tower = getattr(args, 'unfreeze_mm_vision_tower', False) self.input_image_size = getattr(args, 'input_image_size', None) if self.tune_vision_tower: print("CLIP Vision tower is set to tunable") if not delay_load: self.load_model() elif getattr(args, 'unfreeze_mm_vision_tower', False): self.load_model() else: self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) if self.input_image_size is not None: self.cfg_only.image_size = self.input_image_size def load_model(self, device_map=None): if self.is_loaded: print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name)) return self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map) if not self.tune_vision_tower: self.vision_tower.requires_grad_(False) if self.input_image_size is not None: print("Using input image size: {}".format(self.input_image_size)) self.image_processor.size['shortest_edge'] = self.input_image_size self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.input_image_size self.is_loaded = True def feature_select(self, 
image_forward_outs): image_features = image_forward_outs.hidden_states[self.select_layer] if self.select_feature == 'patch': image_features = image_features[:, 1:] elif self.select_feature == 'cls_patch': image_features = image_features else: raise ValueError(f'Unexpected select feature: {self.select_feature}') return image_features def forward(self, images): if self.tune_vision_tower: return self.forward_images(images) else: with torch.no_grad(): return self.forward_images(images) def forward_images(self, images): if type(images) is list: image_features = [] for image in images: image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) image_feature = self.feature_select(image_forward_out).to(image.dtype) image_features.append(image_feature) else: image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) image_features = self.feature_select(image_forward_outs).to(images.dtype) return image_features @property def dummy_feature(self): return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) @property def dtype(self): return self.vision_tower.dtype @property def device(self): return self.vision_tower.device @property def config(self): if self.is_loaded: return self.vision_tower.config else: return self.cfg_only @property def hidden_size(self): return self.config.hidden_size @property def num_patches_per_side(self): return self.config.image_size // self.config.patch_size @property def num_patches(self): return (self.config.image_size // self.config.patch_size) ** 2 class CLIPVisionTowerS2(CLIPVisionTower): def __init__(self, vision_tower, args, delay_load=False): self.s2_scales = getattr(args, 's2_scales', '336,672,1008') self.s2_scales = list(map(int, self.s2_scales.split(','))) self.s2_scales.sort() self.s2_split_size = self.s2_scales[0] self.s2_image_size = self.s2_scales[-1] super().__init__(vision_tower, args, 
delay_load) try: from s2wrapper import forward as multiscale_forward except ImportError: raise ImportError('Package s2wrapper not found! Please install by running: \npip install git+https://github.com/bfshi/scaling_on_scales.git') self.multiscale_forward = multiscale_forward # change resize/crop size in preprocessing to the largest image size in s2_scale if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False): self.image_processor.size['shortest_edge'] = self.s2_image_size self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size def load_model(self, device_map=None): if self.is_loaded: print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name)) return self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map) self.vision_tower.requires_grad_(False) self.image_processor.size['shortest_edge'] = self.s2_image_size self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size self.is_loaded = True @torch.no_grad() def forward_feature(self, images): image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) image_features = self.feature_select(image_forward_outs).to(images.dtype) return image_features @torch.no_grad() def forward(self, images): if type(images) is list: image_features = [] for image in images: image_feature = self.multiscale_forward(self.forward_feature, image.unsqueeze(0), img_sizes=self.s2_scales, max_split_size=self.s2_split_size) image_features.append(image_feature) else: image_features = self.multiscale_forward(self.forward_feature, images, img_sizes=self.s2_scales, max_split_size=self.s2_split_size) return image_features @property def hidden_size(self): return self.config.hidden_size * len(self.s2_scales) 
================================================ FILE: llava/model/multimodal_encoder/mobileclip/__init__.py ================================================ # # For licensing see accompanying LICENSE file. # Copyright (C) 2025 Apple Inc. All Rights Reserved. # import os import json from typing import Any import torch.nn as nn from timm.models import create_model from .mci import GlobalPool2D def load_model_config( model_name: str, ) -> Any: # Strip suffixes to model name model_name = "_".join(model_name.split("_")[0:2]) # Config files root_dir = os.path.dirname(os.path.abspath(__file__)) configs_dir = os.path.join(root_dir, "configs") model_cfg_file = os.path.join(configs_dir, model_name + ".json") # Get config from yaml file if not os.path.exists(model_cfg_file): raise ValueError(f"Unsupported model name: {model_name}") model_cfg = json.load(open(model_cfg_file, "r")) return model_cfg class MCi(nn.Module): """ This class implements `MCi Models `_ """ def __init__(self, model_name: str, *args, **kwargs) -> None: super().__init__() self.projection_dim = None if "projection_dim" in kwargs: self.projection_dim = kwargs.get("projection_dim") # Create model self.model = create_model(model_name, projection_dim=self.projection_dim) # Build out projection head. if self.projection_dim is not None: if hasattr(self.model, "head"): self.model.head = MCi._update_image_classifier( image_classifier=self.model.head, projection_dim=self.projection_dim ) def forward(self, x: Any, *args, **kwargs) -> Any: """A forward function of the model.""" x = self.model(x, *args, **kwargs) return x @staticmethod def _get_in_feature_dimension(image_classifier: nn.Module) -> int: """Return the input feature dimension to the image classification head.""" in_features = None if isinstance(image_classifier, nn.Sequential): # Classifier that uses nn.Sequential usually has global pooling and # multiple linear layers. 
Find the first linear layer and get its # in_features for layer in image_classifier: if isinstance(layer, nn.Linear): in_features = layer.in_features break elif isinstance(image_classifier, nn.Linear): in_features = image_classifier.in_features if in_features is None: raise NotImplementedError( f"Cannot get input feature dimension of {image_classifier}." ) return in_features @staticmethod def _update_image_classifier( image_classifier: nn.Module, projection_dim: int, *args, **kwargs ) -> nn.Module: in_features = MCi._get_in_feature_dimension(image_classifier) new_img_classifier = GlobalPool2D(in_dim=in_features, out_dim=projection_dim) return new_img_classifier ================================================ FILE: llava/model/multimodal_encoder/mobileclip/configs/mobileclip_l.json ================================================ { "embed_dim": 768, "image_cfg": { "image_size": 1024, "model_name": "fastvithd", "embed_dim": 3072, "patch_size": 64 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "dim": 768, "ffn_multiplier_per_layer": 4.0, "n_heads_per_layer": 12, "n_transformer_layers": 12, "norm_layer": "layer_norm_fp32", "causal_masking": false, "model_name": "base" } } ================================================ FILE: llava/model/multimodal_encoder/mobileclip/mci.py ================================================ # # For licensing see accompanying LICENSE file. # Copyright (C) 2025 Apple Inc. All Rights Reserved. 
# import copy from functools import partial from typing import List, Tuple, Optional, Union, Dict import torch import torch.nn as nn from torch import Tensor import torch.nn.functional as F from torch.nn.init import normal_ from timm.models import register_model from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import DropPath, SqueezeExcite def _cfg(url="", **kwargs): return { "url": url, "num_classes": 1000, "input_size": (3, 256, 256), "pool_size": None, "crop_pct": 0.95, "interpolation": "bicubic", "mean": IMAGENET_DEFAULT_MEAN, "std": IMAGENET_DEFAULT_STD, "classifier": "head", **kwargs, } default_cfgs = { "fastvit_t": _cfg(crop_pct=0.9), "fastvit_s": _cfg(crop_pct=0.9), "fastvit_m": _cfg(crop_pct=0.95), } class SEBlock(nn.Module): """Squeeze and Excite module. Pytorch implementation of `Squeeze-and-Excitation Networks` - https://arxiv.org/pdf/1709.01507.pdf """ def __init__(self, in_channels: int, rd_ratio: float = 0.0625) -> None: """Construct a Squeeze and Excite Module. Args: in_channels: Number of input channels. rd_ratio: Input channel reduction ratio. """ super(SEBlock, self).__init__() self.reduce = nn.Conv2d( in_channels=in_channels, out_channels=int(in_channels * rd_ratio), kernel_size=1, stride=1, bias=True, ) self.expand = nn.Conv2d( in_channels=int(in_channels * rd_ratio), out_channels=in_channels, kernel_size=1, stride=1, bias=True, ) def forward(self, inputs: torch.Tensor) -> torch.Tensor: """Apply forward pass.""" b, c, h, w = inputs.size() x = F.avg_pool2d(inputs, kernel_size=[h, w]) x = self.reduce(x) x = F.relu(x) x = self.expand(x) x = torch.sigmoid(x) x = x.view(-1, c, 1, 1) return inputs * x class MobileOneBlock(nn.Module): """MobileOne building block. 
This block has a multi-branched architecture at train-time and plain-CNN style architecture at inference time For more details, please refer to our paper: `An Improved One millisecond Mobile Backbone` - https://arxiv.org/pdf/2206.04040.pdf """ def __init__( self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1, inference_mode: bool = False, use_se: bool = False, use_act: bool = True, use_scale_branch: bool = True, num_conv_branches: int = 1, activation: nn.Module = nn.GELU(), ) -> None: """Construct a MobileOneBlock module. Args: in_channels: Number of channels in the input. out_channels: Number of channels produced by the block. kernel_size: Size of the convolution kernel. stride: Stride size. padding: Zero-padding size. dilation: Kernel dilation factor. groups: Group number. inference_mode: If True, instantiates model in inference mode. use_se: Whether to use SE-ReLU activations. use_act: Whether to use activation. Default: ``True`` use_scale_branch: Whether to use scale branch. Default: ``True`` num_conv_branches: Number of linear conv branches. 
""" super(MobileOneBlock, self).__init__() self.inference_mode = inference_mode self.groups = groups self.stride = stride self.padding = padding self.dilation = dilation self.kernel_size = kernel_size self.in_channels = in_channels self.out_channels = out_channels self.num_conv_branches = num_conv_branches # Check if SE-ReLU is requested if use_se: self.se = SEBlock(out_channels) else: self.se = nn.Identity() if use_act: self.activation = activation else: self.activation = nn.Identity() if inference_mode: self.reparam_conv = nn.Conv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=True, ) else: # Re-parameterizable skip connection # Fallback, sometimes batchnorm tensors # do not get instantiated correctly on some processes # when using deepspeed + accelerate norm_layer = nn.BatchNorm2d(num_features=in_channels) if norm_layer.weight.shape[0] == 0: norm_layer.weight = nn.Parameter(torch.zeros(in_channels)) if norm_layer.bias.shape[0] == 0: norm_layer.bias = nn.Parameter(torch.zeros(in_channels)) self.rbr_skip = ( norm_layer if out_channels == in_channels and stride == 1 else None ) # Re-parameterizable conv branches if num_conv_branches > 0: rbr_conv = list() for _ in range(self.num_conv_branches): rbr_conv.append( self._conv_bn(kernel_size=kernel_size, padding=padding) ) self.rbr_conv = nn.ModuleList(rbr_conv) else: self.rbr_conv = None # Re-parameterizable scale branch self.rbr_scale = None if not isinstance(kernel_size, int): kernel_size = kernel_size[0] if (kernel_size > 1) and use_scale_branch: self.rbr_scale = self._conv_bn(kernel_size=1, padding=0) def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply forward pass.""" # Inference mode forward pass. if self.inference_mode: return self.activation(self.se(self.reparam_conv(x))) # Multi-branched train-time forward pass. 
# Skip branch output identity_out = 0 if self.rbr_skip is not None: identity_out = self.rbr_skip(x) # Scale branch output scale_out = 0 if self.rbr_scale is not None: scale_out = self.rbr_scale(x) # Other branches out = scale_out + identity_out if self.rbr_conv is not None: for ix in range(self.num_conv_branches): out += self.rbr_conv[ix](x) return self.activation(self.se(out)) def reparameterize(self): """Following works like `RepVGG: Making VGG-style ConvNets Great Again` - https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched architecture used at training time to obtain a plain CNN-like structure for inference. """ if self.inference_mode: return kernel, bias = self._get_kernel_bias() self.reparam_conv = nn.Conv2d( in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups, bias=True, ) self.reparam_conv.weight.data = kernel self.reparam_conv.bias.data = bias # Delete un-used branches self.__delattr__("rbr_conv") self.__delattr__("rbr_scale") if hasattr(self, "rbr_skip"): self.__delattr__("rbr_skip") self.inference_mode = True def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]: """Method to obtain re-parameterized kernel and bias. Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83 Returns: Tuple of (kernel, bias) after fusing branches. """ # get weights and bias of scale branch kernel_scale = 0 bias_scale = 0 if self.rbr_scale is not None: kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale) # Pad scale branch kernel to match conv branch kernel size. 
pad = self.kernel_size // 2 kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad]) # get weights and bias of skip branch kernel_identity = 0 bias_identity = 0 if self.rbr_skip is not None: kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip) # get weights and bias of conv branches kernel_conv = 0 bias_conv = 0 if self.rbr_conv is not None: for ix in range(self.num_conv_branches): _kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix]) kernel_conv += _kernel bias_conv += _bias kernel_final = kernel_conv + kernel_scale + kernel_identity bias_final = bias_conv + bias_scale + bias_identity return kernel_final, bias_final def _fuse_bn_tensor( self, branch: Union[nn.Sequential, nn.BatchNorm2d] ) -> Tuple[torch.Tensor, torch.Tensor]: """Method to fuse batchnorm layer with preceeding conv layer. Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95 Args: branch: Sequence of ops to be fused. Returns: Tuple of (kernel, bias) after fusing batchnorm. 
""" if isinstance(branch, nn.Sequential): kernel = branch.conv.weight running_mean = branch.bn.running_mean running_var = branch.bn.running_var gamma = branch.bn.weight beta = branch.bn.bias eps = branch.bn.eps else: assert isinstance(branch, nn.BatchNorm2d) if not hasattr(self, "id_tensor"): input_dim = self.in_channels // self.groups kernel_size = self.kernel_size if isinstance(self.kernel_size, int): kernel_size = (self.kernel_size, self.kernel_size) kernel_value = torch.zeros( (self.in_channels, input_dim, kernel_size[0], kernel_size[1]), dtype=branch.weight.dtype, device=branch.weight.device, ) for i in range(self.in_channels): kernel_value[ i, i % input_dim, kernel_size[0] // 2, kernel_size[1] // 2 ] = 1 self.id_tensor = kernel_value kernel = self.id_tensor running_mean = branch.running_mean running_var = branch.running_var gamma = branch.weight beta = branch.bias eps = branch.eps std = (running_var + eps).sqrt() t = (gamma / std).reshape(-1, 1, 1, 1) return kernel * t, beta - running_mean * gamma / std def _conv_bn(self, kernel_size: int, padding: int) -> nn.Sequential: """Helper method to construct conv-batchnorm layers. Args: kernel_size: Size of the convolution kernel. padding: Zero-padding size. Returns: Conv-BN module. 
""" # Fallback, sometimes batchnorm tensors # do not get instantiated correctly on some processes # when using deepspeed + accelerate norm_layer = nn.BatchNorm2d(num_features=self.out_channels) if norm_layer.weight.shape[0] == 0: norm_layer.weight = nn.Parameter(torch.zeros(self.out_channels)) if norm_layer.bias.shape[0] == 0: norm_layer.bias = nn.Parameter(torch.zeros(self.out_channels)) mod_list = nn.Sequential() mod_list.add_module( "conv", nn.Conv2d( in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=kernel_size, stride=self.stride, padding=padding, groups=self.groups, bias=False, ), ) mod_list.add_module("bn", norm_layer) return mod_list class ReparamLargeKernelConv(nn.Module): """Building Block of RepLKNet This class defines overparameterized large kernel conv block introduced in `RepLKNet `_ Reference: https://github.com/DingXiaoH/RepLKNet-pytorch """ def __init__( self, in_channels: int, out_channels: int, kernel_size: int, stride: int, groups: int, small_kernel: int, inference_mode: bool = False, use_se: bool = False, activation: nn.Module = nn.GELU(), ) -> None: """Construct a ReparamLargeKernelConv module. Args: in_channels: Number of input channels. out_channels: Number of output channels. kernel_size: Kernel size of the large kernel conv branch. stride: Stride size. Default: 1 groups: Group number. Default: 1 small_kernel: Kernel size of small kernel conv branch. inference_mode: If True, instantiates model in inference mode. Default: ``False`` activation: Activation module. 
Default: ``nn.GELU`` """ super(ReparamLargeKernelConv, self).__init__() self.stride = stride self.groups = groups self.in_channels = in_channels self.out_channels = out_channels self.activation = activation self.kernel_size = kernel_size self.small_kernel = small_kernel self.padding = kernel_size // 2 # Check if SE is requested if use_se: self.se = SqueezeExcite(out_channels, rd_ratio=0.25) else: self.se = nn.Identity() if inference_mode: self.lkb_reparam = nn.Conv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=self.padding, dilation=1, groups=groups, bias=True, ) else: self.lkb_origin = self._conv_bn( kernel_size=kernel_size, padding=self.padding ) if small_kernel is not None: assert ( small_kernel <= kernel_size ), "The kernel size for re-param cannot be larger than the large kernel!" self.small_conv = self._conv_bn( kernel_size=small_kernel, padding=small_kernel // 2 ) def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply forward pass.""" if hasattr(self, "lkb_reparam"): out = self.lkb_reparam(x) else: out = self.lkb_origin(x) if hasattr(self, "small_conv"): out += self.small_conv(x) return self.activation(self.se(out)) def get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]: """Method to obtain re-parameterized kernel and bias. Reference: https://github.com/DingXiaoH/RepLKNet-pytorch Returns: Tuple of (kernel, bias) after fusing branches. """ eq_k, eq_b = self._fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn) if hasattr(self, "small_conv"): small_k, small_b = self._fuse_bn(self.small_conv.conv, self.small_conv.bn) eq_b += small_b eq_k += nn.functional.pad( small_k, [(self.kernel_size - self.small_kernel) // 2] * 4 ) return eq_k, eq_b def reparameterize(self) -> None: """ Following works like `RepVGG: Making VGG-style ConvNets Great Again` - https://arxiv.org/pdf/2101.03697.pdf. 
We re-parameterize multi-branched architecture used at training time to obtain a plain CNN-like structure for inference. """ eq_k, eq_b = self.get_kernel_bias() self.lkb_reparam = nn.Conv2d( in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, dilation=self.lkb_origin.conv.dilation, groups=self.groups, bias=True, ) self.lkb_reparam.weight.data = eq_k self.lkb_reparam.bias.data = eq_b self.__delattr__("lkb_origin") if hasattr(self, "small_conv"): self.__delattr__("small_conv") @staticmethod def _fuse_bn( conv: torch.Tensor, bn: nn.BatchNorm2d ) -> Tuple[torch.Tensor, torch.Tensor]: """Method to fuse batchnorm layer with conv layer. Args: conv: Convolutional kernel weights. bn: Batchnorm 2d layer. Returns: Tuple of (kernel, bias) after fusing batchnorm. """ kernel = conv.weight running_mean = bn.running_mean running_var = bn.running_var gamma = bn.weight beta = bn.bias eps = bn.eps std = (running_var + eps).sqrt() t = (gamma / std).reshape(-1, 1, 1, 1) return kernel * t, beta - running_mean * gamma / std def _conv_bn(self, kernel_size: int, padding: int = 0) -> nn.Sequential: """Helper method to construct conv-batchnorm layers. Args: kernel_size: Size of the convolution kernel. padding: Zero-padding size. Returns: A nn.Sequential Conv-BN module. 
""" # Fallback, sometimes batchnorm tensors # do not get instantiated correctly on some processes # when using deepspeed + accelerate norm_layer = nn.BatchNorm2d(num_features=self.out_channels) if norm_layer.weight.shape[0] == 0: norm_layer.weight = nn.Parameter(torch.zeros(self.out_channels)) if norm_layer.bias.shape[0] == 0: norm_layer.bias = nn.Parameter(torch.zeros(self.out_channels)) mod_list = nn.Sequential() mod_list.add_module( "conv", nn.Conv2d( in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=kernel_size, stride=self.stride, padding=padding, groups=self.groups, bias=False, ), ) mod_list.add_module("bn", norm_layer) return mod_list def convolutional_stem( in_channels: int, out_channels: int, inference_mode: bool = False, use_scale_branch: bool = True, ) -> nn.Sequential: """Build convolutional stem with MobileOne blocks. Args: in_channels: Number of input channels. out_channels: Number of output channels. inference_mode: Flag to instantiate model in inference mode. Default: ``False`` Returns: nn.Sequential object with stem elements. """ return nn.Sequential( MobileOneBlock( in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, padding=1, groups=1, inference_mode=inference_mode, use_se=False, num_conv_branches=1, use_scale_branch=use_scale_branch ), MobileOneBlock( in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=2, padding=1, groups=out_channels, inference_mode=inference_mode, use_se=False, num_conv_branches=1, use_scale_branch=use_scale_branch ), MobileOneBlock( in_channels=out_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0, groups=1, inference_mode=inference_mode, use_se=False, num_conv_branches=1, use_scale_branch=use_scale_branch ), ) class LayerNormChannel(nn.Module): """ LayerNorm only for Channel Dimension. 
Input: tensor in shape [B, C, H, W] """ def __init__(self, num_features, eps=1e-05) -> None: super().__init__() self.weight = nn.Parameter(torch.ones(num_features)) self.bias = nn.Parameter(torch.zeros(num_features)) self.eps = eps def forward(self, x) -> torch.Tensor: u = x.mean(1, keepdim=True) s = (x - u).pow(2).mean(1, keepdim=True) x = (x - u) / torch.sqrt(s + self.eps) x = self.weight.unsqueeze(-1).unsqueeze(-1) * x \ + self.bias.unsqueeze(-1).unsqueeze(-1) return x class MHSA(nn.Module): """Multi-headed Self Attention module. Source modified from: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py """ def __init__( self, dim: int, head_dim: int = 32, qkv_bias: bool = False, attn_drop: float = 0.0, proj_drop: float = 0.0, ) -> None: """Build MHSA module that can handle 3D or 4D input tensors. Args: dim: Number of embedding dimensions. head_dim: Number of hidden dimensions per head. Default: ``32`` qkv_bias: Use bias or not. Default: ``False`` attn_drop: Dropout rate for attention tensor. proj_drop: Dropout rate for projection tensor. 
""" super().__init__() assert dim % head_dim == 0, "dim should be divisible by head_dim" self.head_dim = head_dim self.num_heads = dim // head_dim self.scale = head_dim**-0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x: torch.Tensor) -> torch.Tensor: shape = x.shape B, C, H, W = shape N = H * W if len(shape) == 4: x = torch.flatten(x, start_dim=2).transpose(-2, -1) # (B, N, C) qkv = ( self.qkv(x) .reshape(B, N, 3, self.num_heads, self.head_dim) .permute(2, 0, 3, 1, 4) ) q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) # trick here to make q@k.t more stable attn = (q * self.scale) @ k.transpose(-2, -1) attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) if len(shape) == 4: x = x.transpose(-2, -1).reshape(B, C, H, W) return x class PatchEmbed(nn.Module): """Convolutional patch embedding layer.""" def __init__( self, patch_size: int, stride: int, in_channels: int, embed_dim: int, inference_mode: bool = False, use_se: bool = False, ) -> None: """Build patch embedding layer. Args: patch_size: Patch size for embedding computation. stride: Stride for convolutional embedding layer. in_channels: Number of channels of input tensor. embed_dim: Number of embedding dimensions. inference_mode: Flag to instantiate model in inference mode. Default: ``False`` use_se: If ``True`` SE block will be used. 
"""
        super().__init__()
        # Stage 1: large-kernel conv carrying the stride (grouped by in_channels).
        # Stage 2: 1x1 MobileOne block mapping embed_dim -> embed_dim.
        block = list()
        block.append(
            ReparamLargeKernelConv(
                in_channels=in_channels,
                out_channels=embed_dim,
                kernel_size=patch_size,
                stride=stride,
                groups=in_channels,
                small_kernel=3,
                inference_mode=inference_mode,
                use_se=use_se,
            )
        )
        block.append(
            MobileOneBlock(
                in_channels=embed_dim,
                out_channels=embed_dim,
                kernel_size=1,
                stride=1,
                padding=0,
                groups=1,
                inference_mode=inference_mode,
                use_se=False,
                num_conv_branches=1,
            )
        )
        self.proj = nn.Sequential(*block)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the large-kernel conv and the 1x1 block in sequence."""
        x = self.proj(x)
        return x


class RepMixer(nn.Module):
    """Reparameterizable token mixer.

    For more details, please refer to our paper:
    `FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization `_
    """

    def __init__(
        self,
        dim,
        kernel_size=3,
        use_layer_scale=True,
        layer_scale_init_value=1e-5,
        inference_mode: bool = False,
    ):
        """Build RepMixer Module.

        Args:
            dim: Input feature map dimension. :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, H, W)`.
            kernel_size: Kernel size for spatial mixing. Default: 3
            use_layer_scale: If True, learnable layer scale is used. Default: ``True``
            layer_scale_init_value: Initial value for layer scale. Default: 1e-5
            inference_mode: If True, instantiates model in inference mode. Default: ``False``
        """
        super().__init__()
        self.dim = dim
        self.kernel_size = kernel_size
        self.inference_mode = inference_mode

        if inference_mode:
            # Fused form: one depthwise (groups=dim) conv with bias, the same
            # layer that reparameterize() builds from the training branches.
            self.reparam_conv = nn.Conv2d(
                in_channels=self.dim,
                out_channels=self.dim,
                kernel_size=self.kernel_size,
                stride=1,
                padding=self.kernel_size // 2,
                groups=self.dim,
                bias=True,
            )
        else:
            # Training form: out = x + layer_scale * (mixer(x) - norm(x)).
            # NOTE(review): `use_layer_scale`/`layer_scale` are only set in
            # this branch; the fused branch never reads them.
            self.norm = MobileOneBlock(
                dim,
                dim,
                kernel_size,
                padding=kernel_size // 2,
                groups=dim,
                use_act=False,
                use_scale_branch=False,
                num_conv_branches=0,
            )
            self.mixer = MobileOneBlock(
                dim,
                dim,
                kernel_size,
                padding=kernel_size // 2,
                groups=dim,
                use_act=False,
            )
            self.use_layer_scale = use_layer_scale
            if use_layer_scale:
                self.layer_scale = nn.Parameter(
                    layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
                )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Mix tokens spatially.

        Uses the fused ``reparam_conv`` when it exists; otherwise computes
        ``x + layer_scale * (mixer(x) - norm(x))`` (or the same expression
        without ``layer_scale``). The residual ``x`` is included either way.
        """
        if hasattr(self, "reparam_conv"):
            x = self.reparam_conv(x)
            return x
        else:
            if self.use_layer_scale:
                x = x + self.layer_scale * (self.mixer(x) - self.norm(x))
            else:
                x = x + self.mixer(x) - self.norm(x)
            return x

    def reparameterize(self) -> None:
        """Reparameterize mixer and norm into a single convolutional layer for efficient inference.
        """
        # Modules built with inference_mode=True are already fused.
        if self.inference_mode:
            return

        # Fuse each MobileOneBlock into its own reparam_conv first.
        self.mixer.reparameterize()
        self.norm.reparameterize()

        # Fold the residual and the difference of the two branches into one
        # kernel/bias pair. NOTE(review): `mixer.id_tensor` is defined in
        # MobileOneBlock (not visible here); presumably it is the identity
        # kernel that reproduces the `x +` residual — confirm.
        if self.use_layer_scale:
            w = self.mixer.id_tensor + self.layer_scale.unsqueeze(-1) * (
                self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight
            )
            b = torch.squeeze(self.layer_scale) * (
                self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
            )
        else:
            w = (
                self.mixer.id_tensor
                + self.mixer.reparam_conv.weight
                - self.norm.reparam_conv.weight
            )
            b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias

        self.reparam_conv = nn.Conv2d(
            in_channels=self.dim,
            out_channels=self.dim,
            kernel_size=self.kernel_size,
            stride=1,
            padding=self.kernel_size // 2,
            groups=self.dim,
            bias=True,
        )
        self.reparam_conv.weight.data = w
        self.reparam_conv.bias.data = b

        # Drop the training-time branches so only reparam_conv remains.
        self.__delattr__("mixer")
        self.__delattr__("norm")
        if self.use_layer_scale:
            self.__delattr__("layer_scale")


class ConvFFN(nn.Module):
    """Convolutional FFN Module."""

    def __init__(
        self,
        in_channels: int,
        hidden_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        act_layer: nn.Module = nn.GELU,
        drop: float = 0.0,
    ) -> None:
        """Build convolutional FFN module.

        Args:
            in_channels: Number of input channels.
            hidden_channels: Number of channels after expansion. Default: None
            out_channels: Number of output channels. Default: None
            act_layer: Activation layer. Default: ``GELU``
            drop: Dropout rate. Default: ``0.0``.
""" super().__init__() out_channels = out_channels or in_channels hidden_channels = hidden_channels or in_channels self.conv = nn.Sequential() self.conv.add_module( "conv", nn.Conv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=7, padding=3, groups=in_channels, bias=False, ), ) # Fallback, sometimes batchnorm tensors # do not get instantiated correctly on some processes # when using deepspeed + accelerate norm_layer = nn.BatchNorm2d(num_features=out_channels) if norm_layer.weight.shape[0] == 0: norm_layer.weight = nn.Parameter(torch.zeros(out_channels)) if norm_layer.bias.shape[0] == 0: norm_layer.bias = nn.Parameter(torch.zeros(out_channels)) self.conv.add_module("bn", norm_layer) self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1) self.act = act_layer() self.fc2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1) self.drop = nn.Dropout(drop) self.apply(self._init_weights) def _init_weights(self, m: nn.Module) -> None: if isinstance(m, nn.Conv2d): normal_(m.weight, std=0.02) if m.bias is not None: nn.init.constant_(m.bias, 0) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.conv(x) x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x class RepCPE(nn.Module): """Implementation of conditional positional encoding. For more details refer to paper: `Conditional Positional Encodings for Vision Transformers `_ In our implementation, we can reparameterize this module to eliminate a skip connection. """ def __init__( self, in_channels: int, embed_dim: int = 768, spatial_shape: Union[int, Tuple[int, int]] = (7, 7), inference_mode=False, ) -> None: """Build reparameterizable conditional positional encoding Args: in_channels: Number of input channels. embed_dim: Number of embedding dimensions. Default: 768 spatial_shape: Spatial shape of kernel for positional encoding. Default: (7, 7) inference_mode: Flag to instantiate block in inference mode. 
Default: ``False`` """ super(RepCPE, self).__init__() if isinstance(spatial_shape, int): spatial_shape = tuple([spatial_shape] * 2) assert isinstance(spatial_shape, Tuple), ( f'"spatial_shape" must by a sequence or int, ' f"get {type(spatial_shape)} instead." ) assert len(spatial_shape) == 2, ( f'Length of "spatial_shape" should be 2, ' f"got {len(spatial_shape)} instead." ) self.spatial_shape = spatial_shape self.embed_dim = embed_dim self.in_channels = in_channels self.groups = embed_dim if inference_mode: self.reparam_conv = nn.Conv2d( in_channels=self.in_channels, out_channels=self.embed_dim, kernel_size=self.spatial_shape, stride=1, padding=int(self.spatial_shape[0] // 2), groups=self.embed_dim, bias=True, ) else: self.pe = nn.Conv2d( in_channels, embed_dim, spatial_shape, 1, int(spatial_shape[0] // 2), bias=True, groups=embed_dim, ) def forward(self, x: torch.Tensor) -> torch.Tensor: if hasattr(self, "reparam_conv"): x = self.reparam_conv(x) return x else: x = self.pe(x) + x return x def reparameterize(self) -> None: # Build equivalent Id tensor input_dim = self.in_channels // self.groups kernel_value = torch.zeros( ( self.in_channels, input_dim, self.spatial_shape[0], self.spatial_shape[1], ), dtype=self.pe.weight.dtype, device=self.pe.weight.device, ) for i in range(self.in_channels): kernel_value[ i, i % input_dim, self.spatial_shape[0] // 2, self.spatial_shape[1] // 2, ] = 1 id_tensor = kernel_value # Reparameterize Id tensor and conv w_final = id_tensor + self.pe.weight b_final = self.pe.bias # Introduce reparam conv self.reparam_conv = nn.Conv2d( in_channels=self.in_channels, out_channels=self.embed_dim, kernel_size=self.spatial_shape, stride=1, padding=int(self.spatial_shape[0] // 2), groups=self.embed_dim, bias=True, ) self.reparam_conv.weight.data = w_final self.reparam_conv.bias.data = b_final self.__delattr__("pe") class RepMixerBlock(nn.Module): """Implementation of Metaformer block with RepMixer as token mixer. 
For more details on Metaformer structure, please refer to: `MetaFormer Is Actually What You Need for Vision `_ """ def __init__( self, dim: int, kernel_size: int = 3, mlp_ratio: float = 4.0, act_layer: nn.Module = nn.GELU, drop: float = 0.0, drop_path: float = 0.0, use_layer_scale: bool = True, layer_scale_init_value: float = 1e-5, inference_mode: bool = False, ): """Build RepMixer Block. Args: dim: Number of embedding dimensions. kernel_size: Kernel size for repmixer. Default: 3 mlp_ratio: MLP expansion ratio. Default: 4.0 act_layer: Activation layer. Default: ``nn.GELU`` drop: Dropout rate. Default: 0.0 drop_path: Drop path rate. Default: 0.0 use_layer_scale: Flag to turn on layer scale. Default: ``True`` layer_scale_init_value: Layer scale value at initialization. Default: 1e-5 inference_mode: Flag to instantiate block in inference mode. Default: ``False`` """ super().__init__() self.token_mixer = RepMixer( dim, kernel_size=kernel_size, use_layer_scale=use_layer_scale, layer_scale_init_value=layer_scale_init_value, inference_mode=inference_mode, ) assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format( mlp_ratio ) mlp_hidden_dim = int(dim * mlp_ratio) self.convffn = ConvFFN( in_channels=dim, hidden_channels=mlp_hidden_dim, act_layer=act_layer, drop=drop, ) # Drop Path self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() # Layer Scale self.use_layer_scale = use_layer_scale if use_layer_scale: self.layer_scale = nn.Parameter( layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True ) def forward(self, x): if self.use_layer_scale: x = self.token_mixer(x) x = x + self.drop_path(self.layer_scale * self.convffn(x)) else: x = self.token_mixer(x) x = x + self.drop_path(self.convffn(x)) return x class AttentionBlock(nn.Module): """Implementation of metaformer block with MHSA as token mixer. 
For more details on Metaformer structure, please refer to: `MetaFormer Is Actually What You Need for Vision `_ """ def __init__( self, dim: int, mlp_ratio: float = 4.0, act_layer: nn.Module = nn.GELU, norm_layer: nn.Module = nn.BatchNorm2d, drop: float = 0.0, drop_path: float = 0.0, use_layer_scale: bool = True, layer_scale_init_value: float = 1e-5, ): """Build Attention Block. Args: dim: Number of embedding dimensions. mlp_ratio: MLP expansion ratio. Default: 4.0 act_layer: Activation layer. Default: ``nn.GELU`` norm_layer: Normalization layer. Default: ``nn.BatchNorm2d`` drop: Dropout rate. Default: 0.0 drop_path: Drop path rate. Default: 0.0 use_layer_scale: Flag to turn on layer scale. Default: ``True`` layer_scale_init_value: Layer scale value at initialization. Default: 1e-5 """ super().__init__() # Fallback, sometimes batchnorm tensors # do not get instantiated correctly on some processes # when using deepspeed + accelerate norm_layer_ = norm_layer(num_features=dim) if norm_layer_.weight.shape[0] == 0: norm_layer_.weight = nn.Parameter(torch.zeros(dim)) if norm_layer_.bias.shape[0] == 0: norm_layer_.bias = nn.Parameter(torch.zeros(dim)) self.norm = norm_layer_ self.token_mixer = MHSA(dim=dim) assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format( mlp_ratio ) mlp_hidden_dim = int(dim * mlp_ratio) self.convffn = ConvFFN( in_channels=dim, hidden_channels=mlp_hidden_dim, act_layer=act_layer, drop=drop, ) # Drop path self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() # Layer Scale self.use_layer_scale = use_layer_scale if use_layer_scale: self.layer_scale_1 = nn.Parameter( layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True ) self.layer_scale_2 = nn.Parameter( layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True ) def forward(self, x): if self.use_layer_scale: x = x + self.drop_path(self.layer_scale_1 * self.token_mixer(self.norm(x))) x = x + self.drop_path(self.layer_scale_2 * 
self.convffn(x)) else: x = x + self.drop_path(self.token_mixer(self.norm(x))) x = x + self.drop_path(self.convffn(x)) return x def basic_blocks( dim: int, block_index: int, num_blocks: List[int], token_mixer_type: str, kernel_size: int = 3, mlp_ratio: float = 4.0, act_layer: nn.Module = nn.GELU, norm_layer: nn.Module = nn.BatchNorm2d, drop_rate: float = 0.0, drop_path_rate: float = 0.0, use_layer_scale: bool = True, layer_scale_init_value: float = 1e-5, inference_mode=False, ) -> nn.Sequential: """Build FastViT blocks within a stage. Args: dim: Number of embedding dimensions. block_index: block index. num_blocks: List containing number of blocks per stage. token_mixer_type: Token mixer type. kernel_size: Kernel size for repmixer. mlp_ratio: MLP expansion ratio. act_layer: Activation layer. norm_layer: Normalization layer. drop_rate: Dropout rate. drop_path_rate: Drop path rate. use_layer_scale: Flag to turn on layer scale regularization. layer_scale_init_value: Layer scale value at initialization. inference_mode: Flag to instantiate block in inference mode. Returns: nn.Sequential object of all the blocks within the stage. 
""" blocks = [] for block_idx in range(num_blocks[block_index]): block_dpr = ( drop_path_rate * (block_idx + sum(num_blocks[:block_index])) / (sum(num_blocks) - 1) ) if token_mixer_type == "repmixer": blocks.append( RepMixerBlock( dim, kernel_size=kernel_size, mlp_ratio=mlp_ratio, act_layer=act_layer, drop=drop_rate, drop_path=block_dpr, use_layer_scale=use_layer_scale, layer_scale_init_value=layer_scale_init_value, inference_mode=inference_mode, ) ) elif token_mixer_type == "attention": blocks.append( AttentionBlock( dim, mlp_ratio=mlp_ratio, act_layer=act_layer, norm_layer=norm_layer, drop=drop_rate, drop_path=block_dpr, use_layer_scale=use_layer_scale, layer_scale_init_value=layer_scale_init_value, ) ) else: raise ValueError( "Token mixer type: {} not supported".format(token_mixer_type) ) blocks = nn.Sequential(*blocks) return blocks class GlobalPool2D(nn.Module): """This class implements global pooling with linear projection.""" def __init__(self, in_dim: int, out_dim: int, *args, **kwargs) -> None: super().__init__() scale = in_dim**-0.5 self.proj = nn.Parameter(scale * torch.randn(size=(in_dim, out_dim))) self.in_dim = in_dim self.out_dim = out_dim def pool(self, x) -> Tensor: if x.dim() == 4: dims = [-2, -1] elif x.dim() == 5: dims = [-3, -2, -1] x = torch.mean(x, dim=dims, keepdim=False) return x def forward(self, x: Tensor, *args, **kwargs) -> Tensor: # x is of shape [batch, in_dim] assert ( x.dim() == 4 ), "Input should be 4-dimensional (Batch x in_dim x in_height x in_width). 
Got: {}".format( x.shape ) # [batch, in_dim, in_height, in_width] --> [batch, in_dim] x = self.pool(x) # [batch, in_dim] x [in_dim, out_dim] --> [batch, out_dim] x = x @ self.proj return x class FastViT(nn.Module): """ This class implements `FastViT architecture `_ """ def __init__( self, layers, token_mixers: Tuple[str, ...], embed_dims=None, mlp_ratios=None, downsamples=None, se_downsamples=None, repmixer_kernel_size=3, norm_layer: nn.Module = nn.BatchNorm2d, act_layer: nn.Module = nn.GELU, num_classes=1000, pos_embs=None, down_patch_size=7, down_stride=2, drop_rate=0.0, drop_path_rate=0.0, use_layer_scale=True, layer_scale_init_value=1e-5, init_cfg=None, pretrained=None, cls_ratio=2.0, inference_mode=False, stem_scale_branch=True, **kwargs, ) -> None: super().__init__() self.num_classes = num_classes if len(layers) == 4: self.out_indices = [0, 2, 4, 7] elif len(layers) == 5: self.out_indices = [0, 2, 4, 7, 10] else: raise NotImplementedError("FPN is not implemented for more than 5 stages.") if pos_embs is None: pos_embs = [None] * len(layers) if se_downsamples is None: se_downsamples = [False] * len(layers) # Convolutional stem self.patch_embed = convolutional_stem(3, embed_dims[0], inference_mode, use_scale_branch=stem_scale_branch) # Build the main stages of the network architecture network = [] for i in range(len(layers)): # Add position embeddings if requested if pos_embs[i] is not None: network.append( pos_embs[i]( embed_dims[i], embed_dims[i], inference_mode=inference_mode ) ) stage = basic_blocks( embed_dims[i], i, layers, token_mixer_type=token_mixers[i], kernel_size=repmixer_kernel_size, mlp_ratio=mlp_ratios[i], act_layer=act_layer, norm_layer=norm_layer, drop_rate=drop_rate, drop_path_rate=drop_path_rate, use_layer_scale=use_layer_scale, layer_scale_init_value=layer_scale_init_value, inference_mode=inference_mode, ) network.append(stage) if i >= len(layers) - 1: break # Patch merging/downsampling between stages. 
if downsamples[i] or embed_dims[i] != embed_dims[i + 1]: network.append( PatchEmbed( patch_size=down_patch_size, stride=down_stride, in_channels=embed_dims[i], embed_dim=embed_dims[i + 1], inference_mode=inference_mode, use_se=se_downsamples[i + 1], ) ) self.network = nn.ModuleList(network) # Classifier head self.conv_exp = MobileOneBlock( in_channels=embed_dims[-1], out_channels=int(embed_dims[-1] * cls_ratio), kernel_size=3, stride=1, padding=1, groups=embed_dims[-1], inference_mode=inference_mode, use_se=True, num_conv_branches=1, ) self.head = ( nn.Linear(int(embed_dims[-1] * cls_ratio), num_classes) if num_classes > 0 else nn.Identity() ) self.apply(self.cls_init_weights) self.init_cfg = copy.deepcopy(init_cfg) def cls_init_weights(self, m: nn.Module) -> None: """Init. for classification""" if isinstance(m, nn.Linear): normal_(m.weight, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) def forward_embeddings(self, x: torch.Tensor) -> torch.Tensor: x = self.patch_embed(x) return x def forward_tokens(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: for idx, block in enumerate(self.network): x = block(x) return x def forward(self, x: torch.Tensor, *args, **kwargs) -> Union[Tensor, Dict[str, Tensor]]: # input embedding x = self.forward_embeddings(x) # through backbone x = self.forward_tokens(x) # for image classification/embedding x = self.conv_exp(x) cls_out = self.head(x) out_dict = dict() if kwargs.get("return_image_embeddings", False): out_dict.update({"logits": cls_out}) out_dict.update({"image_embeddings": x}) return out_dict else: return cls_out @register_model def fastvithd(pretrained=False, **kwargs): """Instantiate FastViTHD model variant.""" layers = [2, 12, 24, 4, 2] embed_dims = [96, 192, 384, 768, 1536] mlp_ratios = [4, 4, 4, 4, 4] downsamples = [True, True, True, True, True] pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7)), partial(RepCPE, spatial_shape=(7, 7))] token_mixers = 
("repmixer", "repmixer", "repmixer", "attention", "attention") model = FastViT( layers, token_mixers=token_mixers, embed_dims=embed_dims, pos_embs=pos_embs, mlp_ratios=mlp_ratios, downsamples=downsamples, norm_layer=LayerNormChannel, stem_scale_branch=False, inference_mode=True, **kwargs, ) model.default_cfg = default_cfgs["fastvit_m"] if pretrained: raise ValueError("Functionality not implemented.") return model ================================================ FILE: llava/model/multimodal_encoder/mobileclip_encoder.py ================================================ # # For licensing see accompanying LICENSE file. # Copyright (C) 2025 Apple Inc. All Rights Reserved. # import torch import torch.nn as nn import torch.nn.functional as F from transformers import CLIPImageProcessor import llava.model.multimodal_encoder.mobileclip as mobileclip class MobileCLIPVisionTower(nn.Module): def __init__(self, vision_tower, args, delay_load=False): super().__init__() self.is_loaded = False self.vision_tower_name = vision_tower self.tune_vision_tower = getattr(args, 'unfreeze_mm_vision_tower', False) self.input_image_size = int(vision_tower.split("_")[-1]) # Delay load is disabled for now if not delay_load: self.load_model() elif getattr(args, 'unfreeze_mm_vision_tower', False): self.load_model() else: model_cfg = mobileclip.load_model_config(self.vision_tower_name) self.cfg_only = model_cfg def load_model(self, device_map=None): if self.is_loaded: print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name)) return # Load model config model_cfg = mobileclip.load_model_config(self.vision_tower_name) # Override default image resolution model_cfg["image_cfg"]["image_size"] = self.input_image_size self.cfg_only = model_cfg # Build HF CLIPImageProcessor with MobileCLIP parameters self.image_processor = CLIPImageProcessor(crop_size={"height": model_cfg["image_cfg"]["image_size"], "width": model_cfg["image_cfg"]["image_size"]}, image_mean=[0.0, 
0.0, 0.0], image_std=[1.0, 1.0, 1.0], size={"shortest_edge": model_cfg["image_cfg"]["image_size"]}) # Instantiate the image encoder self.vision_tower = mobileclip.MCi(model_name=model_cfg["image_cfg"]["model_name"], projection_dim=model_cfg["embed_dim"]) if not self.tune_vision_tower: self.vision_tower.requires_grad_(False) self.is_loaded = True def feature_select(self, image_forward_outs): # Features from penultimate layer image_features = image_forward_outs["image_embeddings"] # Reshape 4D tensor to 3D B, C, H, W = image_features.shape image_features = image_features.reshape(B, C, H*W) image_features = image_features.transpose(1, 2) return image_features def forward(self, images): if self.tune_vision_tower: return self.forward_images(images) else: with torch.no_grad(): return self.forward_images(images) def forward_images(self, images): if type(images) is list: image_features = [] for image in images: image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), return_image_embeddings=True) image_feature = self.feature_select(image_forward_out).to(image.dtype) image_features.append(image_feature) else: image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), return_image_embeddings=True) image_features = self.feature_select(image_forward_outs).to(images.dtype) return image_features @property def dummy_feature(self): return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) @property def dtype(self): return next(self.vision_tower.parameters()).dtype @property def device(self): return next(self.vision_tower.parameters()).device @property def config(self): return self.cfg_only @property def hidden_size(self): return self.config["image_cfg"]["embed_dim"] @property def num_patches_per_side(self): return self.config["image_cfg"]["image_size"] // self.config["image_cfg"]["patch_size"] @property def num_patches(self): return (self.config["image_cfg"]["image_size"] // 
self.config["image_cfg"]["patch_size"]) ** 2 ================================================ FILE: llava/model/multimodal_projector/builder.py ================================================ import torch.nn as nn import re class IdentityMap(nn.Module): def __init__(self): super().__init__() def forward(self, x, *args, **kwargs): return x @property def config(self): return {"mm_projector_type": 'identity'} def build_vision_projector(config, delay_load=False, **kwargs): projector_type = getattr(config, 'mm_projector_type', 'linear') if projector_type == 'linear': return nn.Linear(config.mm_hidden_size, config.hidden_size) mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) if mlp_gelu_match: mlp_depth = int(mlp_gelu_match.group(1)) modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] for _ in range(1, mlp_depth): modules.append(nn.GELU()) modules.append(nn.Linear(config.hidden_size, config.hidden_size)) return nn.Sequential(*modules) if projector_type == 'identity': return IdentityMap() raise ValueError(f'Unknown projector type: {projector_type}') ================================================ FILE: llava/model/utils.py ================================================ from transformers import AutoConfig def auto_upgrade(config): cfg = AutoConfig.from_pretrained(config) if 'llava' in config and 'llava' not in cfg.model_type: assert cfg.model_type == 'llama' print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") print("You must upgrade the checkpoint to the new code base (this can be done automatically).") confirm = input("Please confirm that you want to upgrade the checkpoint. 
[Y/N]") if confirm.lower() in ["y", "yes"]: print("Upgrading checkpoint...") assert len(cfg.architectures) == 1 setattr(cfg.__class__, "model_type", "llava") cfg.architectures[0] = 'LlavaLlamaForCausalLM' cfg.save_pretrained(config) print("Checkpoint upgraded.") else: print("Checkpoint upgrade aborted.") exit(1) ================================================ FILE: llava/serve/__init__.py ================================================ ================================================ FILE: llava/serve/cli.py ================================================ import argparse import torch from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN from llava.conversation import conv_templates, SeparatorStyle from llava.model.builder import load_pretrained_model from llava.utils import disable_torch_init from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path from PIL import Image import requests from PIL import Image from io import BytesIO from transformers import TextStreamer def load_image(image_file): if image_file.startswith('http://') or image_file.startswith('https://'): response = requests.get(image_file) image = Image.open(BytesIO(response.content)).convert('RGB') else: image = Image.open(image_file).convert('RGB') return image def main(args): # Model disable_torch_init() model_name = get_model_name_from_path(args.model_path) tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device) if "llama-2" in model_name.lower(): conv_mode = "llava_llama_2" elif "mistral" in model_name.lower(): conv_mode = "mistral_instruct" elif "v1.6-34b" in model_name.lower(): conv_mode = "chatml_direct" elif "v1" in model_name.lower(): conv_mode = "llava_v1" elif "mpt" in model_name.lower(): conv_mode = "mpt" else: conv_mode = "llava_v0" if args.conv_mode is not None and conv_mode != 
args.conv_mode: print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) else: args.conv_mode = conv_mode conv = conv_templates[args.conv_mode].copy() if "mpt" in model_name.lower(): roles = ('user', 'assistant') else: roles = conv.roles image = load_image(args.image_file) image_size = image.size # Similar operation in model_worker.py image_tensor = process_images([image], image_processor, model.config) if type(image_tensor) is list: image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor] else: image_tensor = image_tensor.to(model.device, dtype=torch.float16) while True: try: inp = input(f"{roles[0]}: ") except EOFError: inp = "" if not inp: print("exit...") break print(f"{roles[1]}: ", end="") if image is not None: # first message if model.config.mm_use_im_start_end: inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp else: inp = DEFAULT_IMAGE_TOKEN + '\n' + inp image = None conv.append_message(conv.roles[0], inp) conv.append_message(conv.roles[1], None) prompt = conv.get_prompt() input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device) stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 keywords = [stop_str] streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) with torch.inference_mode(): output_ids = model.generate( input_ids, images=image_tensor, image_sizes=[image_size], do_sample=True if args.temperature > 0 else False, temperature=args.temperature, max_new_tokens=args.max_new_tokens, streamer=streamer, use_cache=True) outputs = tokenizer.decode(output_ids[0]).strip() conv.messages[-1][-1] = outputs if args.debug: print("\n", {"prompt": prompt, "outputs": outputs}, "\n") if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model-path", type=str, 
default="facebook/opt-350m") parser.add_argument("--model-base", type=str, default=None) parser.add_argument("--image-file", type=str, required=True) parser.add_argument("--device", type=str, default="cuda") parser.add_argument("--conv-mode", type=str, default=None) parser.add_argument("--temperature", type=float, default=0.2) parser.add_argument("--max-new-tokens", type=int, default=512) parser.add_argument("--load-8bit", action="store_true") parser.add_argument("--load-4bit", action="store_true") parser.add_argument("--debug", action="store_true") args = parser.parse_args() main(args) ================================================ FILE: llava/serve/controller.py ================================================ """ A controller manages distributed workers. It sends worker addresses to clients. """ import argparse import asyncio import dataclasses from enum import Enum, auto import json import logging import time from typing import List, Union import threading from fastapi import FastAPI, Request from fastapi.responses import StreamingResponse import numpy as np import requests import uvicorn from llava.constants import CONTROLLER_HEART_BEAT_EXPIRATION from llava.utils import build_logger, server_error_msg logger = build_logger("controller", "controller.log") class DispatchMethod(Enum): LOTTERY = auto() SHORTEST_QUEUE = auto() @classmethod def from_str(cls, name): if name == "lottery": return cls.LOTTERY elif name == "shortest_queue": return cls.SHORTEST_QUEUE else: raise ValueError(f"Invalid dispatch method") @dataclasses.dataclass class WorkerInfo: model_names: List[str] speed: int queue_length: int check_heart_beat: bool last_heart_beat: str def heart_beat_controller(controller): while True: time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION) controller.remove_stable_workers_by_expiration() class Controller: def __init__(self, dispatch_method: str): # Dict[str -> WorkerInfo] self.worker_info = {} self.dispatch_method = DispatchMethod.from_str(dispatch_method) 
self.heart_beat_thread = threading.Thread( target=heart_beat_controller, args=(self,), daemon=True) self.heart_beat_thread.start() logger.info("Init controller") def register_worker(self, worker_name: str, check_heart_beat: bool, worker_status: dict): if worker_name not in self.worker_info: logger.info(f"Register a new worker: {worker_name}") else: logger.info(f"Register an existing worker: {worker_name}") if not worker_status: worker_status = self.get_worker_status(worker_name) if not worker_status: return False self.worker_info[worker_name] = WorkerInfo( worker_status["model_names"], worker_status["speed"], worker_status["queue_length"], check_heart_beat, time.time()) logger.info(f"Register done: {worker_name}, {worker_status}") return True def get_worker_status(self, worker_name: str): try: r = requests.post(worker_name + "/worker_get_status", timeout=5) except requests.exceptions.RequestException as e: logger.error(f"Get status fails: {worker_name}, {e}") return None if r.status_code != 200: logger.error(f"Get status fails: {worker_name}, {r}") return None return r.json() def remove_worker(self, worker_name: str): del self.worker_info[worker_name] def refresh_all_workers(self): old_info = dict(self.worker_info) self.worker_info = {} for w_name, w_info in old_info.items(): if not self.register_worker(w_name, w_info.check_heart_beat, None): logger.info(f"Remove stale worker: {w_name}") def list_models(self): model_names = set() for w_name, w_info in self.worker_info.items(): model_names.update(w_info.model_names) return list(model_names) def get_worker_address(self, model_name: str): if self.dispatch_method == DispatchMethod.LOTTERY: worker_names = [] worker_speeds = [] for w_name, w_info in self.worker_info.items(): if model_name in w_info.model_names: worker_names.append(w_name) worker_speeds.append(w_info.speed) worker_speeds = np.array(worker_speeds, dtype=np.float32) norm = np.sum(worker_speeds) if norm < 1e-4: return "" worker_speeds = worker_speeds / norm 
if True: # Directly return address pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds) worker_name = worker_names[pt] return worker_name # Check status before returning while True: pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds) worker_name = worker_names[pt] if self.get_worker_status(worker_name): break else: self.remove_worker(worker_name) worker_speeds[pt] = 0 norm = np.sum(worker_speeds) if norm < 1e-4: return "" worker_speeds = worker_speeds / norm continue return worker_name elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE: worker_names = [] worker_qlen = [] for w_name, w_info in self.worker_info.items(): if model_name in w_info.model_names: worker_names.append(w_name) worker_qlen.append(w_info.queue_length / w_info.speed) if len(worker_names) == 0: return "" min_index = np.argmin(worker_qlen) w_name = worker_names[min_index] self.worker_info[w_name].queue_length += 1 logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}") return w_name else: raise ValueError(f"Invalid dispatch method: {self.dispatch_method}") def receive_heart_beat(self, worker_name: str, queue_length: int): if worker_name not in self.worker_info: logger.info(f"Receive unknown heart beat. {worker_name}") return False self.worker_info[worker_name].queue_length = queue_length self.worker_info[worker_name].last_heart_beat = time.time() logger.info(f"Receive heart beat. 
{worker_name}") return True def remove_stable_workers_by_expiration(self): expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION to_delete = [] for worker_name, w_info in self.worker_info.items(): if w_info.check_heart_beat and w_info.last_heart_beat < expire: to_delete.append(worker_name) for worker_name in to_delete: self.remove_worker(worker_name) def worker_api_generate_stream(self, params): worker_addr = self.get_worker_address(params["model"]) if not worker_addr: logger.info(f"no worker: {params['model']}") ret = { "text": server_error_msg, "error_code": 2, } yield json.dumps(ret).encode() + b"\0" try: response = requests.post(worker_addr + "/worker_generate_stream", json=params, stream=True, timeout=5) for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): if chunk: yield chunk + b"\0" except requests.exceptions.RequestException as e: logger.info(f"worker timeout: {worker_addr}") ret = { "text": server_error_msg, "error_code": 3, } yield json.dumps(ret).encode() + b"\0" # Let the controller act as a worker to achieve hierarchical # management. This can be used to connect isolated sub networks. 
def worker_api_get_status(self): model_names = set() speed = 0 queue_length = 0 for w_name in self.worker_info: worker_status = self.get_worker_status(w_name) if worker_status is not None: model_names.update(worker_status["model_names"]) speed += worker_status["speed"] queue_length += worker_status["queue_length"] return { "model_names": list(model_names), "speed": speed, "queue_length": queue_length, } app = FastAPI() @app.post("/register_worker") async def register_worker(request: Request): data = await request.json() controller.register_worker( data["worker_name"], data["check_heart_beat"], data.get("worker_status", None)) @app.post("/refresh_all_workers") async def refresh_all_workers(): models = controller.refresh_all_workers() @app.post("/list_models") async def list_models(): models = controller.list_models() return {"models": models} @app.post("/get_worker_address") async def get_worker_address(request: Request): data = await request.json() addr = controller.get_worker_address(data["model"]) return {"address": addr} @app.post("/receive_heart_beat") async def receive_heart_beat(request: Request): data = await request.json() exist = controller.receive_heart_beat( data["worker_name"], data["queue_length"]) return {"exist": exist} @app.post("/worker_generate_stream") async def worker_api_generate_stream(request: Request): params = await request.json() generator = controller.worker_api_generate_stream(params) return StreamingResponse(generator) @app.post("/worker_get_status") async def worker_api_get_status(request: Request): return controller.worker_api_get_status() if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--port", type=int, default=21001) parser.add_argument("--dispatch-method", type=str, choices=[ "lottery", "shortest_queue"], default="shortest_queue") args = parser.parse_args() logger.info(f"args: {args}") controller = Controller(args.dispatch_method) 
uvicorn.run(app, host=args.host, port=args.port, log_level="info") ================================================ FILE: llava/serve/gradio_web_server.py ================================================ import argparse import datetime import json import os import time import gradio as gr import requests from llava.conversation import (default_conversation, conv_templates, SeparatorStyle) from llava.constants import LOGDIR from llava.utils import (build_logger, server_error_msg, violates_moderation, moderation_msg) import hashlib logger = build_logger("gradio_web_server", "gradio_web_server.log") headers = {"User-Agent": "LLaVA Client"} no_change_btn = gr.Button() enable_btn = gr.Button(interactive=True) disable_btn = gr.Button(interactive=False) priority = { "vicuna-13b": "aaaaaaa", "koala-13b": "aaaaaab", } def get_conv_log_filename(): t = datetime.datetime.now() name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json") return name def get_model_list(): ret = requests.post(args.controller_url + "/refresh_all_workers") assert ret.status_code == 200 ret = requests.post(args.controller_url + "/list_models") models = ret.json()["models"] models.sort(key=lambda x: priority.get(x, x)) logger.info(f"Models: {models}") return models get_window_url_params = """ function() { const params = new URLSearchParams(window.location.search); url_params = Object.fromEntries(params); console.log(url_params); return url_params; } """ def load_demo(url_params, request: gr.Request): logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}") dropdown_update = gr.Dropdown(visible=True) if "model" in url_params: model = url_params["model"] if model in models: dropdown_update = gr.Dropdown(value=model, visible=True) state = default_conversation.copy() return state, dropdown_update def load_demo_refresh_model_list(request: gr.Request): logger.info(f"load_demo. 
ip: {request.client.host}") models = get_model_list() state = default_conversation.copy() dropdown_update = gr.Dropdown( choices=models, value=models[0] if len(models) > 0 else "" ) return state, dropdown_update def vote_last_response(state, vote_type, model_selector, request: gr.Request): with open(get_conv_log_filename(), "a") as fout: data = { "tstamp": round(time.time(), 4), "type": vote_type, "model": model_selector, "state": state.dict(), "ip": request.client.host, } fout.write(json.dumps(data) + "\n") def upvote_last_response(state, model_selector, request: gr.Request): logger.info(f"upvote. ip: {request.client.host}") vote_last_response(state, "upvote", model_selector, request) return ("",) + (disable_btn,) * 3 def downvote_last_response(state, model_selector, request: gr.Request): logger.info(f"downvote. ip: {request.client.host}") vote_last_response(state, "downvote", model_selector, request) return ("",) + (disable_btn,) * 3 def flag_last_response(state, model_selector, request: gr.Request): logger.info(f"flag. ip: {request.client.host}") vote_last_response(state, "flag", model_selector, request) return ("",) + (disable_btn,) * 3 def regenerate(state, image_process_mode, request: gr.Request): logger.info(f"regenerate. ip: {request.client.host}") state.messages[-1][-1] = None prev_human_msg = state.messages[-2] if type(prev_human_msg[1]) in (tuple, list): prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode) state.skip_next = False return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5 def clear_history(request: gr.Request): logger.info(f"clear_history. ip: {request.client.host}") state = default_conversation.copy() return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5 def add_text(state, text, image, image_process_mode, request: gr.Request): logger.info(f"add_text. ip: {request.client.host}. 
len: {len(text)}") if len(text) <= 0 and image is None: state.skip_next = True return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5 if args.moderate: flagged = violates_moderation(text) if flagged: state.skip_next = True return (state, state.to_gradio_chatbot(), moderation_msg, None) + ( no_change_btn,) * 5 text = text[:1536] # Hard cut-off if image is not None: text = text[:1200] # Hard cut-off for images if '' not in text: # text = '' + text text = text + '\n' text = (text, image, image_process_mode) state = default_conversation.copy() state.append_message(state.roles[0], text) state.append_message(state.roles[1], None) state.skip_next = False return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5 def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request): logger.info(f"http_bot. ip: {request.client.host}") start_tstamp = time.time() model_name = model_selector if state.skip_next: # This generate call is skipped due to invalid inputs yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5 return if len(state.messages) == state.offset + 2: # First round of conversation if "llava" in model_name.lower(): if 'llama-2' in model_name.lower(): template_name = "llava_llama_2" elif "mistral" in model_name.lower() or "mixtral" in model_name.lower(): if 'orca' in model_name.lower(): template_name = "mistral_orca" elif 'hermes' in model_name.lower(): template_name = "chatml_direct" else: template_name = "mistral_instruct" elif 'llava-v1.6-34b' in model_name.lower(): template_name = "chatml_direct" elif "v1" in model_name.lower(): if 'mmtag' in model_name.lower(): template_name = "v1_mmtag" elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower(): template_name = "v1_mmtag" else: template_name = "llava_v1" elif "mpt" in model_name.lower(): template_name = "mpt" else: if 'mmtag' in model_name.lower(): template_name = "v0_mmtag" elif 'plain' in model_name.lower() and 'finetune' 
not in model_name.lower(): template_name = "v0_mmtag" else: template_name = "llava_v0" elif "mpt" in model_name: template_name = "mpt_text" elif "llama-2" in model_name: template_name = "llama_2" else: template_name = "vicuna_v1" new_state = conv_templates[template_name].copy() new_state.append_message(new_state.roles[0], state.messages[-2][1]) new_state.append_message(new_state.roles[1], None) state = new_state # Query worker address controller_url = args.controller_url ret = requests.post(controller_url + "/get_worker_address", json={"model": model_name}) worker_addr = ret.json()["address"] logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}") # No available worker if worker_addr == "": state.messages[-1][-1] = server_error_msg yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) return # Construct prompt prompt = state.get_prompt() all_images = state.get_images(return_pil=True) all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images] for image, hash in zip(all_images, all_image_hash): t = datetime.datetime.now() filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg") if not os.path.isfile(filename): os.makedirs(os.path.dirname(filename), exist_ok=True) image.save(filename) # Make requests pload = { "model": model_name, "prompt": prompt, "temperature": float(temperature), "top_p": float(top_p), "max_new_tokens": min(int(max_new_tokens), 1536), "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2, "images": f'List of {len(state.get_images())} images: {all_image_hash}', } logger.info(f"==== request ====\n{pload}") pload['images'] = state.get_images() state.messages[-1][-1] = "▌" yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5 try: # Stream output response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, json=pload, stream=True, timeout=10) for 
chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): if chunk: data = json.loads(chunk.decode()) if data["error_code"] == 0: output = data["text"][len(prompt):].strip() state.messages[-1][-1] = output + "▌" yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5 else: output = data["text"] + f" (error_code: {data['error_code']})" state.messages[-1][-1] = output yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) return time.sleep(0.03) except requests.exceptions.RequestException as e: state.messages[-1][-1] = server_error_msg yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) return state.messages[-1][-1] = state.messages[-1][-1][:-1] yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 finish_tstamp = time.time() logger.info(f"{output}") with open(get_conv_log_filename(), "a") as fout: data = { "tstamp": round(finish_tstamp, 4), "type": "chat", "model": model_name, "start": round(start_tstamp, 4), "finish": round(finish_tstamp, 4), "state": state.dict(), "images": all_image_hash, "ip": request.client.host, } fout.write(json.dumps(data) + "\n") title_markdown = (""" # 🌋 LLaVA: Large Language and Vision Assistant [[Project Page](https://llava-vl.github.io)] [[Code](https://github.com/haotian-liu/LLaVA)] [[Model](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)] | 📚 [[LLaVA](https://arxiv.org/abs/2304.08485)] [[LLaVA-v1.5](https://arxiv.org/abs/2310.03744)] [[LLaVA-v1.6](https://llava-vl.github.io/blog/2024-01-30-llava-1-6/)] """) tos_markdown = (""" ### Terms of use By using this service, users are required to agree to the following terms: The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. 
The service may collect user dialogue data for future research. Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator. For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality. """) learn_more_markdown = (""" ### License The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation. """) block_css = """ #buttons button { min-width: min(120px,100%); } """ def build_demo(embed_mode, cur_dir=None, concurrency_count=10): textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False) with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=block_css) as demo: state = gr.State() if not embed_mode: gr.Markdown(title_markdown) with gr.Row(): with gr.Column(scale=3): with gr.Row(elem_id="model_selector_row"): model_selector = gr.Dropdown( choices=models, value=models[0] if len(models) > 0 else "", interactive=True, show_label=False, container=False) imagebox = gr.Image(type="pil") image_process_mode = gr.Radio( ["Crop", "Resize", "Pad", "Default"], value="Default", label="Preprocess for non-square image", visible=False) if cur_dir is None: cur_dir = os.path.dirname(os.path.abspath(__file__)) gr.Examples(examples=[ [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"], [f"{cur_dir}/examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"], ], inputs=[imagebox, textbox]) with gr.Accordion("Parameters", open=False) as parameter_row: temperature = 
gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",) top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",) max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",) with gr.Column(scale=8): chatbot = gr.Chatbot( elem_id="chatbot", label="LLaVA Chatbot", height=650, layout="panel", ) with gr.Row(): with gr.Column(scale=8): textbox.render() with gr.Column(scale=1, min_width=50): submit_btn = gr.Button(value="Send", variant="primary") with gr.Row(elem_id="buttons") as button_row: upvote_btn = gr.Button(value="👍 Upvote", interactive=False) downvote_btn = gr.Button(value="👎 Downvote", interactive=False) flag_btn = gr.Button(value="⚠️ Flag", interactive=False) # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False) regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False) clear_btn = gr.Button(value="🗑️ Clear", interactive=False) if not embed_mode: gr.Markdown(tos_markdown) gr.Markdown(learn_more_markdown) url_params = gr.JSON(visible=False) # Register listeners btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn] upvote_btn.click( upvote_last_response, [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn] ) downvote_btn.click( downvote_last_response, [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn] ) flag_btn.click( flag_last_response, [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn] ) regenerate_btn.click( regenerate, [state, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list ).then( http_bot, [state, model_selector, temperature, top_p, max_output_tokens], [state, chatbot] + btn_list, concurrency_limit=concurrency_count ) clear_btn.click( clear_history, None, [state, chatbot, textbox, imagebox] + btn_list, queue=False ) textbox.submit( add_text, [state, textbox, imagebox, image_process_mode], 
[state, chatbot, textbox, imagebox] + btn_list, queue=False ).then( http_bot, [state, model_selector, temperature, top_p, max_output_tokens], [state, chatbot] + btn_list, concurrency_limit=concurrency_count ) submit_btn.click( add_text, [state, textbox, imagebox, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list ).then( http_bot, [state, model_selector, temperature, top_p, max_output_tokens], [state, chatbot] + btn_list, concurrency_limit=concurrency_count ) if args.model_list_mode == "once": demo.load( load_demo, [url_params], [state, model_selector], js=get_window_url_params ) elif args.model_list_mode == "reload": demo.load( load_demo_refresh_model_list, None, [state, model_selector], queue=False ) else: raise ValueError(f"Unknown model list mode: {args.model_list_mode}") return demo if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="0.0.0.0") parser.add_argument("--port", type=int) parser.add_argument("--controller-url", type=str, default="http://localhost:21001") parser.add_argument("--concurrency-count", type=int, default=16) parser.add_argument("--model-list-mode", type=str, default="once", choices=["once", "reload"]) parser.add_argument("--share", action="store_true") parser.add_argument("--moderate", action="store_true") parser.add_argument("--embed", action="store_true") args = parser.parse_args() logger.info(f"args: {args}") models = get_model_list() logger.info(args) demo = build_demo(args.embed, concurrency_count=args.concurrency_count) demo.queue( api_open=False ).launch( server_name=args.host, server_port=args.port, share=args.share ) ================================================ FILE: llava/serve/model_worker.py ================================================ """ A model worker executes the model. 
""" import argparse import asyncio import json import time import threading import uuid from fastapi import FastAPI, Request, BackgroundTasks from fastapi.responses import StreamingResponse import requests import torch import uvicorn from functools import partial from llava.constants import WORKER_HEART_BEAT_INTERVAL from llava.utils import (build_logger, server_error_msg, pretty_print_semaphore) from llava.model.builder import load_pretrained_model from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN from transformers import TextIteratorStreamer from threading import Thread GB = 1 << 30 worker_id = str(uuid.uuid4())[:6] logger = build_logger("model_worker", f"model_worker_{worker_id}.log") global_counter = 0 model_semaphore = None def heart_beat_worker(controller): while True: time.sleep(WORKER_HEART_BEAT_INTERVAL) controller.send_heart_beat() class ModelWorker: def __init__(self, controller_addr, worker_addr, worker_id, no_register, model_path, model_base, model_name, load_8bit, load_4bit, device, use_flash_attn=False): self.controller_addr = controller_addr self.worker_addr = worker_addr self.worker_id = worker_id if model_path.endswith("/"): model_path = model_path[:-1] if model_name is None: model_paths = model_path.split("/") if model_paths[-1].startswith('checkpoint-'): self.model_name = model_paths[-2] + "_" + model_paths[-1] else: self.model_name = model_paths[-1] else: self.model_name = model_name self.device = device logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...") self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model( model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device, use_flash_attn=use_flash_attn) self.is_multimodal = 'llava' in self.model_name.lower() if not no_register: self.register_to_controller() 
self.heart_beat_thread = threading.Thread( target=heart_beat_worker, args=(self,), daemon=True) self.heart_beat_thread.start() def register_to_controller(self): logger.info("Register to controller") url = self.controller_addr + "/register_worker" data = { "worker_name": self.worker_addr, "check_heart_beat": True, "worker_status": self.get_status() } r = requests.post(url, json=data) assert r.status_code == 200 def send_heart_beat(self): logger.info(f"Send heart beat. Models: {[self.model_name]}. " f"Semaphore: {pretty_print_semaphore(model_semaphore)}. " f"global_counter: {global_counter}") url = self.controller_addr + "/receive_heart_beat" while True: try: ret = requests.post(url, json={ "worker_name": self.worker_addr, "queue_length": self.get_queue_length()}, timeout=5) exist = ret.json()["exist"] break except requests.exceptions.RequestException as e: logger.error(f"heart beat error: {e}") time.sleep(5) if not exist: self.register_to_controller() def get_queue_length(self): if model_semaphore is None: return 0 else: return args.limit_model_concurrency - model_semaphore._value + (len( model_semaphore._waiters) if model_semaphore._waiters is not None else 0) def get_status(self): return { "model_names": [self.model_name], "speed": 1, "queue_length": self.get_queue_length(), } @torch.inference_mode() def generate_stream(self, params): tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor prompt = params["prompt"] ori_prompt = prompt images = params.get("images", None) num_image_tokens = 0 if images is not None and len(images) > 0 and self.is_multimodal: if len(images) > 0: if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN): raise ValueError("Number of images does not match number of tokens in prompt") images = [load_image_from_base64(image) for image in images] image_sizes = [image.size for image in images] images = process_images(images, image_processor, model.config) if type(images) is list: images = [image.to(self.model.device, 
dtype=torch.float16) for image in images] else: images = images.to(self.model.device, dtype=torch.float16) replace_token = DEFAULT_IMAGE_TOKEN if getattr(self.model.config, 'mm_use_im_start_end', False): replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token) num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches else: images = None image_sizes = None image_args = {"images": images, "image_sizes": image_sizes} else: images = None image_args = {} temperature = float(params.get("temperature", 1.0)) top_p = float(params.get("top_p", 1.0)) max_context_length = getattr(model.config, 'max_position_embeddings', 2048) max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024) stop_str = params.get("stop", None) do_sample = True if temperature > 0.001 else False input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device) keywords = [stop_str] # stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15) max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens) if max_new_tokens < 1: yield json.dumps({"text": ori_prompt + "Exceeds max token length. 
Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0" return thread = Thread(target=model.generate, kwargs=dict( inputs=input_ids, do_sample=do_sample, temperature=temperature, top_p=top_p, max_new_tokens=max_new_tokens, streamer=streamer, use_cache=True, **image_args )) thread.start() generated_text = ori_prompt for new_text in streamer: generated_text += new_text if generated_text.endswith(stop_str): generated_text = generated_text[:-len(stop_str)] yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0" def generate_stream_gate(self, params): try: for x in self.generate_stream(params): yield x except ValueError as e: print("Caught ValueError:", e) ret = { "text": server_error_msg, "error_code": 1, } yield json.dumps(ret).encode() + b"\0" except torch.cuda.CudaError as e: print("Caught torch.cuda.CudaError:", e) ret = { "text": server_error_msg, "error_code": 1, } yield json.dumps(ret).encode() + b"\0" except Exception as e: print("Caught Unknown Error", e) ret = { "text": server_error_msg, "error_code": 1, } yield json.dumps(ret).encode() + b"\0" app = FastAPI() def release_model_semaphore(fn=None): model_semaphore.release() if fn is not None: fn() @app.post("/worker_generate_stream") async def generate_stream(request: Request): global model_semaphore, global_counter global_counter += 1 params = await request.json() if model_semaphore is None: model_semaphore = asyncio.Semaphore(args.limit_model_concurrency) await model_semaphore.acquire() worker.send_heart_beat() generator = worker.generate_stream_gate(params) background_tasks = BackgroundTasks() background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat)) return StreamingResponse(generator, background=background_tasks) @app.post("/worker_get_status") async def get_status(request: Request): return worker.get_status() if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="localhost") 
parser.add_argument("--port", type=int, default=21002) parser.add_argument("--worker-address", type=str, default="http://localhost:21002") parser.add_argument("--controller-address", type=str, default="http://localhost:21001") parser.add_argument("--model-path", type=str, default="facebook/opt-350m") parser.add_argument("--model-base", type=str, default=None) parser.add_argument("--model-name", type=str) parser.add_argument("--device", type=str, default="cuda") parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.") parser.add_argument("--limit-model-concurrency", type=int, default=5) parser.add_argument("--stream-interval", type=int, default=1) parser.add_argument("--no-register", action="store_true") parser.add_argument("--load-8bit", action="store_true") parser.add_argument("--load-4bit", action="store_true") parser.add_argument("--use-flash-attn", action="store_true") args = parser.parse_args() logger.info(f"args: {args}") if args.multi_modal: logger.warning("Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.") worker = ModelWorker(args.controller_address, args.worker_address, worker_id, args.no_register, args.model_path, args.model_base, args.model_name, args.load_8bit, args.load_4bit, args.device, use_flash_attn=args.use_flash_attn) uvicorn.run(app, host=args.host, port=args.port, log_level="info") ================================================ FILE: llava/serve/register_worker.py ================================================ """ Manually register workers. 
Usage: python3 -m llava.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002

`--controller` is accepted as an argparse prefix abbreviation of `--controller-address`.
"""

import argparse

import requests

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--controller-address", type=str)
    parser.add_argument("--worker-name", type=str)
    parser.add_argument("--check-heart-beat", action="store_true")
    args = parser.parse_args()

    # POST a registration for `--worker-name` to the controller; no status is reported.
    url = args.controller_address + "/register_worker"
    data = {
        "worker_name": args.worker_name,
        "check_heart_beat": args.check_heart_beat,
        "worker_status": None,
    }
    r = requests.post(url, json=data)
    # NOTE(review): `assert` is stripped under `python -O`, which would hide a failed registration.
    assert r.status_code == 200

================================================
FILE: llava/serve/sglang_worker.py
================================================
"""
A model worker that serves streaming generation requests by forwarding them to
an SGLang runtime endpoint, and registers itself with a controller.
"""
import argparse
import asyncio
from concurrent.futures import ThreadPoolExecutor
import json
import time
import threading
import uuid

from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse
import requests
import re
import uvicorn
from functools import partial

from llava.constants import WORKER_HEART_BEAT_INTERVAL
from llava.utils import (build_logger, server_error_msg, pretty_print_semaphore)
from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, expand2square
from llava.constants import DEFAULT_IMAGE_TOKEN

import sglang as sgl
from sglang.backend.runtime_endpoint import RuntimeEndpoint


GB = 1 << 30

worker_id = str(uuid.uuid4())[:6]
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
# Number of /worker_generate_stream requests received (reported in heart-beat logs).
global_counter = 0

# Created lazily on the first request with `args.limit_model_concurrency` slots.
model_semaphore = None


def heart_beat_worker(controller):
    """Forever: sleep WORKER_HEART_BEAT_INTERVAL seconds, then call controller.send_heart_beat()."""
    while True:
        time.sleep(WORKER_HEART_BEAT_INTERVAL)
        controller.send_heart_beat()


@sgl.function
def pipeline(s, prompt, max_tokens):
    """SGLang program: append each prompt segment to state `s` (str as text,
    anything else via sgl.image), then generate into variable "response"."""
    for p in prompt:
        if type(p) is str:
            s += p
        else:
            s += sgl.image(p)
    s += sgl.gen("response", max_tokens=max_tokens)


class ModelWorker:
    """Worker that forwards generation to an SGLang backend and keeps itself
    registered with a controller via periodic heart beats."""

    def __init__(self, controller_addr, worker_addr,
                 sgl_endpoint, worker_id, no_register, model_name):
        """
        Args:
            controller_addr: Base URL of the controller.
            worker_addr: Address reported to the controller as this worker's name.
            sgl_endpoint: SGLang runtime URL; installed as the default sgl backend.
            worker_id: Identifier used in log messages.
            no_register: If true, skip controller registration and the heart-beat thread.
            model_name: Advertised model name; when None it is derived from the
                backend's model_path (last path component, prefixed with the parent
                directory for `checkpoint-*` folders).
        """
        self.controller_addr = controller_addr
        self.worker_addr = worker_addr
        self.worker_id = worker_id

        # Select backend
        backend = RuntimeEndpoint(sgl_endpoint)
        sgl.set_default_backend(backend)
        model_path = backend.model_info["model_path"]

        if model_path.endswith("/"):
            model_path = model_path[:-1]
        if model_name is None:
            model_paths = model_path.split("/")
            if model_paths[-1].startswith('checkpoint-'):
                self.model_name = model_paths[-2] + "_" + model_paths[-1]
            else:
                self.model_name = model_paths[-1]
        else:
            self.model_name = model_name

        logger.info(f"Loading the SGLANG model {self.model_name} on worker {worker_id} ...")

        if not no_register:
            self.register_to_controller()
            # Daemon thread so it does not keep the process alive on shutdown.
            self.heart_beat_thread = threading.Thread(
                target=heart_beat_worker, args=(self,), daemon=True)
            self.heart_beat_thread.start()

    def register_to_controller(self):
        """POST this worker's address and current status to the controller's /register_worker."""
        logger.info("Register to controller")

        url = self.controller_addr + "/register_worker"
        data = {
            "worker_name": self.worker_addr,
            "check_heart_beat": True,
            "worker_status": self.get_status()
        }
        r = requests.post(url, json=data)
        assert r.status_code == 200

    def send_heart_beat(self):
        """Report the queue length to the controller, retrying every 5 s on
        request errors; re-register if the controller no longer knows us.

        NOTE(review): this blocks (requests + time.sleep) and is also called from
        the async /worker_generate_stream handler, where it blocks the event loop.
        """
        logger.info(f"Send heart beat. Models: {[self.model_name]}. "
                    f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
                    f"global_counter: {global_counter}")

        url = self.controller_addr + "/receive_heart_beat"

        while True:
            try:
                ret = requests.post(url, json={
                    "worker_name": self.worker_addr,
                    "queue_length": self.get_queue_length()}, timeout=5)
                exist = ret.json()["exist"]
                break
            except requests.exceptions.RequestException as e:
                logger.error(f"heart beat error: {e}")
            time.sleep(5)

        if not exist:
            self.register_to_controller()

    def get_queue_length(self):
        """Return in-flight plus waiting requests (0 before the first request).

        Reads the module-level `args` set in __main__ and the private asyncio.Semaphore
        attributes `_value` / `_waiters`.
        """
        if model_semaphore is None:
            return 0
        else:
            return args.limit_model_concurrency - model_semaphore._value + (len(
                model_semaphore._waiters) if model_semaphore._waiters is not None else 0)

    def get_status(self):
        """Return the status dict sent to the controller."""
        return {
            "model_names": [self.model_name],
            "speed": 1,
            "queue_length": self.get_queue_length(),
        }

    async def generate_stream(self, params):
        """Stream generation for one request.

        Args:
            params: Request JSON; reads "prompt", optional "images" (base64 strings,
                one per DEFAULT_IMAGE_TOKEN in the prompt), "temperature", "top_p",
                "max_new_tokens" (capped at 1024) and "stop".

        Yields:
            b"\\0"-terminated JSON chunks {"text": ..., "error_code": 0} where
            "text" is the original prompt plus everything generated so far.

        Raises:
            ValueError: if the image count does not match the image-token count.
        """
        ori_prompt = prompt = params["prompt"]
        images = params.get("images", None)
        if images is not None and len(images) > 0:
            # NOTE(review): the inner check is redundant with the outer condition.
            if len(images) > 0:
                if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
                    raise ValueError("Number of images does not match number of tokens in prompt")

                images = [load_image_from_base64(image) for image in images]

                # FIXME: for image-start/end token
                # replace_token = DEFAULT_IMAGE_TOKEN
                # if getattr(self.model.config, 'mm_use_im_start_end', False):
                #     replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
                # prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
                prompt = prompt.replace(' ' + DEFAULT_IMAGE_TOKEN + '\n', DEFAULT_IMAGE_TOKEN)
                # Interleave text segments with decoded images: [text, img, text, img, ..., text].
                prompt_split = prompt.split(DEFAULT_IMAGE_TOKEN)
                prompt = []
                for i in range(len(prompt_split)):
                    prompt.append(prompt_split[i])
                    if i < len(images):
                        prompt.append(images[i])
        else:
            prompt = [prompt]

        temperature = float(params.get("temperature", 1.0))
        top_p = float(params.get("top_p", 1.0))
        # max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
        max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
        # NOTE(review): stop_str is computed but never passed to pipeline.run, so the
        # client's "stop" string has no effect on generation.
        stop_str = params.get("stop", None)
        stop_str = [stop_str] if stop_str is not None else None

        print({'prompt': prompt, 'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_p': top_p})
        state = pipeline.run(prompt, max_new_tokens, temperature=temperature, top_p=top_p, stream=True)

        generated_text = ori_prompt
        async for text_outputs in state.text_async_iter(var_name="response"):
            generated_text += text_outputs
            yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"

    async def generate_stream_gate(self, params):
        """Wrap generate_stream, turning any exception into a final error chunk
        ({"text": server_error_msg, "error_code": 1})."""
        try:
            async for x in self.generate_stream(params):
                yield x
        except ValueError as e:
            print("Caught ValueError:", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"
        except Exception as e:
            print("Caught Unknown Error", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"


app = FastAPI()


def release_model_semaphore(fn=None):
    """Release one concurrency slot, then call `fn` if given."""
    model_semaphore.release()
    if fn is not None:
        fn()


@app.post("/worker_generate_stream")
async def generate_stream(request: Request):
    """Acquire a concurrency slot and stream the response; the slot is released
    (followed by a heart beat) as a background task after the response completes."""
    global model_semaphore, global_counter
    global_counter += 1
    params = await request.json()

    if model_semaphore is None:
        model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
    await model_semaphore.acquire()
    worker.send_heart_beat()
    generator = worker.generate_stream_gate(params)
    background_tasks = BackgroundTasks()
    background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
    return StreamingResponse(generator, background=background_tasks)


@app.post("/worker_get_status")
async def get_status(request: Request):
    return worker.get_status()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21002)
    parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
    parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
    parser.add_argument("--model-name", type=str)
    parser.add_argument("--sgl-endpoint", type=str)
    parser.add_argument("--limit-model-concurrency", type=int, default=5)
    parser.add_argument("--stream-interval", type=int, default=1)
    parser.add_argument("--no-register", action="store_true")
    args = parser.parse_args()
    logger.info(f"args: {args}")

    worker = ModelWorker(args.controller_address, args.worker_address, args.sgl_endpoint, worker_id, args.no_register, args.model_name)
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")

================================================
FILE: llava/serve/test_message.py
================================================
import argparse
import json

import requests

from llava.conversation import default_conversation


def main():
    """Send one text-only message to a worker and print the streamed reply.

    The worker is `--worker-address` if given, otherwise the address the
    controller returns for `--model-name`. Reads the module-level `args`.
    """
    if args.worker_address:
        worker_addr = args.worker_address
    else:
        controller_addr = args.controller_address
        ret = requests.post(controller_addr + "/refresh_all_workers")
        ret = requests.post(controller_addr + "/list_models")
        models = ret.json()["models"]
        models.sort()
        print(f"Models: {models}")

        ret = requests.post(controller_addr + "/get_worker_address", json={"model": args.model_name})
        worker_addr = ret.json()["address"]
        print(f"worker_addr: {worker_addr}")

    # An empty address means no worker serves the requested model.
    if worker_addr == "":
        return

    conv = default_conversation.copy()
    conv.append_message(conv.roles[0], args.message)
    prompt = conv.get_prompt()

    headers = {"User-Agent": "LLaVA Client"}
    pload = {
        "model": args.model_name,
        "prompt": prompt,
        "max_new_tokens": args.max_new_tokens,
        "temperature": 0.7,
        "stop": conv.sep,
    }
    response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, json=pload, stream=True)

    print(prompt.replace(conv.sep, "\n"), end="")
    # The worker streams b"\0"-delimited JSON chunks carrying the cumulative text;
    # print the last conversation turn, overwriting the line each time.
    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            output = data["text"].split(conv.sep)[-1]
            print(output, end="\r")
    print("")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
    parser.add_argument("--worker-address", type=str)
    parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
    parser.add_argument("--max-new-tokens", type=int, default=32)
    parser.add_argument("--message", type=str, default="Tell me a story with more than 1000 words.")
    args = parser.parse_args()

    main()

================================================
FILE: llava/train/llama_flash_attn_monkey_patch.py
================================================
"""Monkey patch that routes LlamaAttention through flash-attention's packed-QKV kernel."""
from typing import Optional, Tuple
import warnings

import torch

import transformers
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv

# Older flash_attn exposes the "unpadded" name; newer versions renamed it to "varlen".
try:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
except ImportError:
    from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input


def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Replacement for LlamaAttention.forward using causal flash attention.

    `attention_mask` is used directly as the key-padding mask; the patched
    _prepare_decoder_attention_mask below passes it through as [bsz, seq_len].
    Attention weights are never returned (always None).

    NOTE(review): torch.stack below needs q/k/v with the same sequence length, so
    this presumably only works with no past_key_value (i.e. training) — confirm.
    """
    if output_attentions:
        warnings.warn(
            "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
        )

    bsz, q_len, _ = hidden_states.size()

    query_states = (
        self.q_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    key_states = (
        self.k_proj(hidden_states)
        .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
        .transpose(1, 2)
    )
    value_states = (
        self.v_proj(hidden_states)
        .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
        .transpose(1, 2)
    )  # shape: (b, num_heads, s, head_dim)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]

    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(
        query_states, key_states, cos, sin, position_ids
    )

    if past_key_value is not None:
        # reuse k, v
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)

    past_key_value = (key_states, value_states) if use_cache else None

    # repeat k/v heads if n_kv_heads < n_heads
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    # Transform the data into the format required by flash attention
    qkv = torch.stack([query_states, key_states, value_states], dim=2)
    qkv = qkv.transpose(1, 3)  # shape: [b, s, 3, num_heads, head_dim]
    key_padding_mask = attention_mask

    if key_padding_mask is None:
        # No padding: every sequence has length q_len, so cumulative offsets are a stride-q_len range.
        qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)
        cu_q_lens = torch.arange(
            0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
        )
        max_s = q_len
        output = flash_attn_unpadded_qkvpacked_func(
            qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
        )
        output = output.view(bsz, q_len, -1)
    else:
        # Drop padded positions, run the varlen kernel, then scatter back to [bsz, q_len, ...].
        # NOTE(review): assumes unpad_input returns 4 values (older flash_attn API) — verify pinned version.
        qkv = qkv.reshape(bsz, q_len, -1)
        qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
        qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
        output_unpad = flash_attn_unpadded_qkvpacked_func(
            qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
        )
        output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
        output = pad_input(output_unpad, indices, bsz, q_len)

    return self.o_proj(output), None, past_key_value


# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
    self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
    # [bsz, seq_len]
    return attention_mask


def replace_llama_attn_with_flash_attn():
    """Install the flash-attention forward and mask pass-through on transformers' Llama classes.

    Warns when the current CUDA device's compute capability major version is below 8.
    """
    cuda_major, cuda_minor = torch.cuda.get_device_capability()
    if cuda_major < 8:
        # NOTE(review): the two literals concatenate without a space ("backward.ref:").
        warnings.warn(
            "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
            "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
        )
    transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
        _prepare_decoder_attention_mask
    )
    transformers.models.llama.modeling_llama.LlamaAttention.forward = forward

================================================
FILE: llava/train/llama_xformers_attn_monkey_patch.py
================================================
"""
Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments.

Provides replace_llama_attn_with_xformers_attn(), which swaps LlamaAttention.forward
for an xformers memory-efficient-attention implementation.
"""

import logging
import math
from typing import Optional, Tuple

import torch
import transformers.models.llama.modeling_llama
from torch import nn

# Import failure is only logged; calling the patched forward without xformers will fail later.
try:
    import xformers.ops
except ImportError:
    logging.error("xformers not found! 
Please install it before trying to use it.") def replace_llama_attn_with_xformers_attn(): transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward def xformers_forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: # pylint: disable=duplicate-code bsz, q_len, _ = hidden_states.size() query_states = ( self.q_proj(hidden_states) .view(bsz, q_len, self.num_heads, self.head_dim) .transpose(1, 2) ) key_states = ( self.k_proj(hidden_states) .view(bsz, q_len, self.num_heads, self.head_dim) .transpose(1, 2) ) value_states = ( self.v_proj(hidden_states) .view(bsz, q_len, self.num_heads, self.head_dim) .transpose(1, 2) ) kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) ( query_states, key_states, ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb( query_states, key_states, cos, sin, position_ids ) # [bsz, nh, t, hd] if past_key_value is not None: # reuse k, v, self_attention key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) past_key_value = (key_states, value_states) if use_cache else None # We only apply xformers optimizations if we don't need to output the whole attention matrix if not output_attentions: query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros. # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros. 
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0: # input and output should be of form (bsz, q_len, num_heads, head_dim) attn_output = xformers.ops.memory_efficient_attention( query_states, key_states, value_states, attn_bias=None ) else: # input and output should be of form (bsz, q_len, num_heads, head_dim) attn_output = xformers.ops.memory_efficient_attention( query_states, key_states, value_states, attn_bias=xformers.ops.LowerTriangularMask(), ) attn_weights = None else: attn_weights = torch.matmul( query_states, key_states.transpose(2, 3) ) / math.sqrt(self.head_dim) if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" f" {attn_weights.size()}" ) if attention_mask is not None: if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): raise ValueError( f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" ) attn_weights = attn_weights + attention_mask attn_weights = torch.max( attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) ) # upcast attention to fp32 attn_weights = nn.functional.softmax( attn_weights, dim=-1, dtype=torch.float32 ).to(query_states.dtype) attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" f" {attn_output.size()}" ) attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) return attn_output, attn_weights, past_key_value ================================================ FILE: llava/train/llava_trainer.py ================================================ import os import torch import torch.nn as nn from torch.utils.data import Sampler import transformers from transformers import Trainer from 
transformers.trainer import (
    is_sagemaker_mp_enabled,
    get_parameter_names,
    has_length,
    # ALL_LAYERNORM_LAYERS,
    logger,
)
from typing import List, Optional

# Layer types passed as "forbidden" layers to get_parameter_names in create_optimizer,
# so their parameters are excluded from weight decay. Defined locally (instead of
# importing transformers' ALL_LAYERNORM_LAYERS) and extended with BatchNorm2d.
ALL_LAYERNORM_LAYERS = [nn.LayerNorm, nn.BatchNorm2d]


def maybe_zero_3(param, ignore_status=False, name=None):
    """Return a detached CPU copy of `param`.

    Parameters carrying a DeepSpeed `ds_id` are gathered with
    zero.GatheredParameters before copying. Note that deepspeed is imported
    unconditionally, so it must be installed even for ordinary tensors.

    Args:
        param: Parameter tensor to copy.
        ignore_status: If False, print a notice when the parameter's ds_status is NOT_AVAILABLE.
        name: Parameter name used in that notice.
    """
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
    if hasattr(param, "ds_id"):
        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            if not ignore_status:
                print(name, 'no ignore status')
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param


def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
    """Return {name: CPU tensor} for every parameter whose name contains any of `keys_to_match`."""
    to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
    to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
    return to_return


def split_to_even_chunks(indices, lengths, num_chunks):
    """
    Split `indices` into `num_chunks` chunks of roughly equal total length.

    If len(indices) is not divisible by num_chunks, fall back to round-robin
    striding. Otherwise each index goes greedily to the chunk with the smallest
    accumulated length. A full chunk (len(indices) // num_chunks items) is
    closed by setting its length to inf.
    """
    if len(indices) % num_chunks != 0:
        return [indices[i::num_chunks] for i in range(num_chunks)]

    num_indices_per_chunk = len(indices) // num_chunks

    chunks = [[] for _ in range(num_chunks)]
    chunks_lengths = [0 for _ in range(num_chunks)]
    for index in indices:
        shortest_chunk = chunks_lengths.index(min(chunks_lengths))
        chunks[shortest_chunk].append(index)
        chunks_lengths[shortest_chunk] += lengths[index]
        if len(chunks[shortest_chunk]) == num_indices_per_chunk:
            chunks_lengths[shortest_chunk] = float("inf")

    return chunks


def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
    """Length-grouped ordering that keeps multimodal and language-only samples in separate megabatches.

    Sign convention: a positive length marks a multimodal sample and a negative
    length a language-only one (|length| is the size). Each modality is
    length-grouped and cut into megabatches of world_size * batch_size. The full
    megabatches are shuffled together. The two leftover tails are merged,
    sorted by index, and appended last.

    NOTE(review): the per-modality grouping passes generator=None (global torch RNG);
    only the megabatch shuffle uses `generator`.
    """
    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
    assert all(l != 0 for l in lengths), "Should not have zero length."
    if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
        # all samples are in the same modality
        return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
    mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
    lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])

    mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
    lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
    megabatch_size = world_size * batch_size
    mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
    lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]

    last_mm = mm_megabatches[-1]
    last_lang = lang_megabatches[-1]
    additional_batch = last_mm + last_lang
    megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
    megabatch_indices = torch.randperm(len(megabatches), generator=generator)
    megabatches = [megabatches[i] for i in megabatch_indices]

    if len(additional_batch) > 0:
        megabatches.append(sorted(additional_batch))

    return [i for megabatch in megabatches for i in megabatch]


def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
    """Randomly permute indices, then group them into megabatches of similar length.

    Steps:
      1. Shuffle with torch.randperm.
      2. Cut into megabatches of world_size * batch_size.
      3. Sort each megabatch by length, descending.
      4. Balance each megabatch into world_size chunks with split_to_even_chunks.
      5. Return the flattened order.

    `merge` is accepted but unused.
    """
    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
    indices = torch.randperm(len(lengths), generator=generator)
    megabatch_size = world_size * batch_size
    megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
    megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
    megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]

    return [i for megabatch in megabatches for batch in megabatch for i in batch]


class LengthGroupedSampler(Sampler):
    r"""
    Sampler that groups together samples of roughly the same length while keeping
    some randomness. With `group_by_modality`, multimodal (positive length) and
    language-only (negative length) samples are grouped separately.
    """

    def __init__(
        self,
        batch_size: int,
        world_size: int,
        lengths: Optional[List[int]] = None,
        generator=None,
        group_by_modality: bool = False,
    ):
        if lengths is None:
            raise ValueError("Lengths must be provided.")

        self.batch_size = batch_size
        self.world_size = world_size
        self.lengths = lengths
        self.generator = generator
        self.group_by_modality = group_by_modality

    def __len__(self):
        return len(self.lengths)

    def __iter__(self):
        # A fresh ordering is computed on every iteration (i.e. every epoch).
        if self.group_by_modality:
            indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
        else:
            indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
        return iter(indices)


class LLaVATrainer(Trainer):
    """HF Trainer with modality-grouped sampling, per-module learning rates,
    and adapter-only checkpointing when `tune_mm_mlp_adapter` is set."""

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        """Use LengthGroupedSampler over `train_dataset.modality_lengths` when
        `group_by_modality_length` is set; otherwise defer to Trainer."""
        if self.train_dataset is None or not has_length(self.train_dataset):
            return None

        if self.args.group_by_modality_length:
            lengths = self.train_dataset.modality_lengths
            # world_size includes gradient accumulation so each optimizer step's
            # combined batches are length-grouped together.
            return LengthGroupedSampler(
                self.args.train_batch_size,
                world_size=self.args.world_size * self.args.gradient_accumulation_steps,
                lengths=lengths,
                group_by_modality=True,
            )
        else:
            return super()._get_train_sampler()

    def create_optimizer(self):
        """
        Set up the optimizer.

        Parameters are split into weight-decay / no-decay groups. Biases and
        parameters of ALL_LAYERNORM_LAYERS get no decay. When `mm_projector_lr` or
        `mm_vision_tower_lr` is set, parameters whose names contain "mm_projector" /
        "vision_tower" go into extra groups with that learning rate. For Adam8bit,
        nn.Embedding weights are registered for 32-bit optimizer state.

        To use something else, pass a tuple through Trainer's `optimizers` argument
        or override this method.

        NOTE(review): `decay_parameters` is a list, so each `n in decay_parameters`
        test is linear; a name matching both keywords would land in two groups.
        """
        if is_sagemaker_mp_enabled():
            return super().create_optimizer()

        opt_model = self.model

        if self.optimizer is None:
            decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
            decay_parameters = [name for name in decay_parameters if "bias" not in name]
            lr_mapper = {}
            if self.args.mm_projector_lr is not None:
                lr_mapper["mm_projector"] = self.args.mm_projector_lr
            if self.args.mm_vision_tower_lr is not None:
                lr_mapper["vision_tower"] = self.args.mm_vision_tower_lr
            if len(lr_mapper) > 0:
                special_lr_parameters = [name for name, _ in opt_model.named_parameters() if any(module_keyword in name for module_keyword in lr_mapper)]
                optimizer_grouped_parameters = [
                    {
                        "params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in special_lr_parameters and p.requires_grad)],
                        "weight_decay": self.args.weight_decay,
                    },
                    {
                        "params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in special_lr_parameters and p.requires_grad)],
                        "weight_decay": 0.0,
                    },
                ]
                for module_keyword, lr in lr_mapper.items():
                    module_parameters = [name for name, _ in opt_model.named_parameters() if module_keyword in name]
                    optimizer_grouped_parameters.extend(
                        [
                            {
                                "params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in module_parameters and p.requires_grad)],
                                "weight_decay": self.args.weight_decay,
                                "lr": lr,
                            },
                            {
                                "params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in module_parameters and p.requires_grad)],
                                "weight_decay": 0.0,
                                "lr": lr,
                            },
                        ]
                    )
            else:
                optimizer_grouped_parameters = [
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
                        ],
                        "weight_decay": self.args.weight_decay,
                    },
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
                        ],
                        "weight_decay": 0.0,
                    },
                ]

            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
            if optimizer_cls.__name__ == "Adam8bit":
                import bitsandbytes

                manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

                skipped = 0
                for module in opt_model.modules():
                    if isinstance(module, nn.Embedding):
                        # Count unique storages (dict keyed by data_ptr) so tied weights are counted once.
                        skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                        logger.info(f"skipped {module}: {skipped/2**20}M params")
                        manager.register_module_override(module, "weight", {"optim_bits": 32})
                        logger.debug(f"bitsandbytes: will optimize {module} in fp32")
                logger.info(f"skipped: {skipped/2**20}M params")

        return self.optimizer

    def _save_checkpoint(self, model, trial, metrics=None):
        """When tuning only the MM adapter, save the config plus matching adapter
        weights to <run_dir>/<PREFIX_CHECKPOINT_DIR>-<step>/mm_projector.bin on
        local rank 0 / -1. Otherwise defer to Trainer._save_checkpoint."""
        if getattr(self.args, 'tune_mm_mlp_adapter', False):
            from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
            checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

            run_dir = self._get_output_dir(trial=trial)
            output_dir = os.path.join(run_dir, checkpoint_folder)

            # Only save Adapter
            keys_to_match = ['mm_projector', 'vision_resampler']
            if getattr(self.args, "use_im_start_end", False):
                keys_to_match.extend(['embed_tokens', 'embed_in'])

            # Called on every rank (it may gather ZeRO-3 partitions); only rank 0 / -1 writes.
            weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)

            if self.args.local_rank == 0 or self.args.local_rank == -1:
                self.model.config.save_pretrained(output_dir)
                torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
        else:
            # Workaround for the issue: https://github.com/haotian-liu/LLaVA/issues/1144
            model.generation_config = transformers.GenerationConfig(do_sample=True, temperature=None, top_p=None)
            super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics)

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Skip the full-model save when only the MM adapter is tuned; otherwise defer to Trainer._save."""
        if getattr(self.args, 'tune_mm_mlp_adapter', False):
            pass
        else:
            # Workaround for the issue: https://github.com/haotian-liu/LLaVA/issues/1144
            self.model.generation_config = transformers.GenerationConfig(do_sample=True, temperature=None, top_p=None)
            super(LLaVATrainer, self)._save(output_dir, state_dict)

================================================
FILE: llava/train/train.py
================================================
# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
#    Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.
from packaging import version import os import copy from dataclasses import dataclass, field import json import logging import pathlib from typing import Dict, Optional, Sequence, List import torch import transformers import tokenizers from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN from torch.utils.data import Dataset from llava.train.llava_trainer import LLaVATrainer from llava import conversation as conversation_lib from llava.model import * from llava.mm_utils import tokenizer_image_token, process_anyres_image from PIL import Image local_rank = None def rank0_print(*args): if local_rank == 0: print(*args) IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14') @dataclass class ModelArguments: model_name_or_path: Optional[str] = field(default="facebook/opt-125m") version: Optional[str] = field(default="v0") freeze_backbone: bool = field(default=False) tune_mm_mlp_adapter: bool = field(default=False) tune_mm_mlp_and_vision_tower: bool = field(default=False) vision_tower: Optional[str] = field(default=None) mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer pretrain_mm_mlp_adapter: Optional[str] = field(default=None) mm_projector_type: Optional[str] = field(default='linear') mm_use_im_start_end: bool = field(default=False) mm_use_im_patch_token: bool = field(default=True) mm_patch_merge_type: Optional[str] = field(default='flat') mm_vision_select_feature: Optional[str] = field(default="patch") unfreeze_mm_vision_tower: bool = field(default=False) s2: Optional[bool] = field(default=False) hd: Optional[bool] = field(default=False) @dataclass class DataArguments: data_path: Optional[List[str]] = field(default=None, metadata={"help": "Optional list of paths to the training data."}) lazy_preprocess: bool = False is_multimodal: bool = False image_folder: Optional[List[str]] = field(default=None) image_aspect_ratio: str 
= 'square' image_grid_pinpoints: Optional[str] = field(default=None) image_crop_resolution: Optional[int] = field(default=None) image_split_resolution: Optional[int] = field(default=None) @dataclass class TrainingArguments(transformers.TrainingArguments): cache_dir: Optional[str] = field(default=None) optim: str = field(default="adamw_torch") remove_unused_columns: bool = field(default=False) freeze_mm_mlp_adapter: bool = field(default=False) mpt_attn_impl: Optional[str] = field(default="triton") model_max_length: int = field( default=512, metadata={ "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." }, ) double_quant: bool = field( default=True, metadata={"help": "Compress the quantization statistics through double quantization."} ) quant_type: str = field( default="nf4", metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."} ) bits: int = field( default=16, metadata={"help": "How many bits to use."} ) lora_enable: bool = False lora_r: int = 64 lora_alpha: int = 16 lora_dropout: float = 0.05 lora_weight_path: str = "" lora_bias: str = "none" mm_projector_lr: Optional[float] = None mm_vision_tower_lr: Optional[float] = None group_by_modality_length: bool = field(default=False) def maybe_zero_3(param, ignore_status=False, name=None): from deepspeed import zero from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus if hasattr(param, "ds_id"): if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: if not ignore_status: logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}") with zero.GatheredParameters([param]): param = param.data.detach().cpu().clone() else: param = param.detach().cpu().clone() return param # Borrowed from peft.utils.get_peft_model_state_dict def get_peft_state_maybe_zero_3(named_params, bias): if bias == "none": to_return = {k: t for k, t in named_params if "lora_" in k} elif bias == "all": to_return = {k: t for k, t in 
named_params if "lora_" in k or "bias" in k} elif bias == "lora_only": to_return = {} maybe_lora_bias = {} lora_bias_names = set() for k, t in named_params: if "lora_" in k: to_return[k] = t bias_name = k.split("lora_")[0] + "bias" lora_bias_names.add(bias_name) elif "bias" in k: maybe_lora_bias[k] = t for k, t in maybe_lora_bias: if bias_name in lora_bias_names: to_return[bias_name] = t else: raise NotImplementedError to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()} return to_return def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True): to_return = {k: t for k, t in named_params if "lora_" not in k} if require_grad_only: to_return = {k: t for k, t in to_return.items() if t.requires_grad} to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} return to_return def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} return to_return def find_all_linear_names(model): cls = torch.nn.Linear lora_module_names = set() multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler'] for name, module in model.named_modules(): if any(mm_keyword in name for mm_keyword in multimodal_keywords): continue if isinstance(module, cls): names = name.split('.') lora_module_names.add(names[0] if len(names) == 1 else names[-1]) if 'lm_head' in lora_module_names: # needed for 16-bit lora_module_names.remove('lm_head') return list(lora_module_names) def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str): """Collects the state dict and dump to disk.""" if getattr(trainer.args, "tune_mm_mlp_adapter", False): # Only save Adapter keys_to_match = ['mm_projector'] if getattr(trainer.args, "use_im_start_end", False): keys_to_match.extend(['embed_tokens', 'embed_in', 
'tok_embeddings']) weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match) trainer.model.config.save_pretrained(output_dir) current_folder = output_dir.split('/')[-1] parent_folder = os.path.dirname(output_dir) if trainer.args.local_rank == 0 or trainer.args.local_rank == -1: if current_folder.startswith('checkpoint-'): mm_projector_folder = os.path.join(parent_folder, "mm_projector") os.makedirs(mm_projector_folder, exist_ok=True) torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin')) else: torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin')) return if trainer.deepspeed: torch.cuda.synchronize() trainer.save_model(output_dir) return state_dict = trainer.model.state_dict() if trainer.args.should_save: cpu_state_dict = { key: value.cpu() for key, value in state_dict.items() } del state_dict trainer._save(output_dir, state_dict=cpu_state_dict) # noqa def smart_tokenizer_and_embedding_resize( special_tokens_dict: Dict, tokenizer: transformers.PreTrainedTokenizer, model: transformers.PreTrainedModel, ): """Resize tokenizer and embedding. Note: This is the unoptimized version that may make your embedding size not be divisible by 64. 
""" num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) model.resize_token_embeddings(len(tokenizer)) if num_new_tokens > 0: input_embeddings = model.get_input_embeddings().weight.data output_embeddings = model.get_output_embeddings().weight.data input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( dim=0, keepdim=True) output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( dim=0, keepdim=True) input_embeddings[-num_new_tokens:] = input_embeddings_avg output_embeddings[-num_new_tokens:] = output_embeddings_avg def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict: """Tokenize a list of strings.""" tokenized_list = [ tokenizer( text, return_tensors="pt", padding="longest", max_length=tokenizer.model_max_length, truncation=True, ) for text in strings ] input_ids = labels = [ tokenized.input_ids[0] for tokenized in tokenized_list ] input_ids_lens = labels_lens = [ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list ] return dict( input_ids=input_ids, labels=labels, input_ids_lens=input_ids_lens, labels_lens=labels_lens, ) def _mask_targets(target, tokenized_lens, speakers): # cur_idx = 0 cur_idx = tokenized_lens[0] tokenized_lens = tokenized_lens[1:] target[:cur_idx] = IGNORE_INDEX for tokenized_len, speaker in zip(tokenized_lens, speakers): if speaker == "human": target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX cur_idx += tokenized_len def _add_speaker_and_signal(header, source, get_conversation=True): """Add speaker and start/end signal on each round.""" BEGIN_SIGNAL = "### " END_SIGNAL = "\n" conversation = header for sentence in source: from_str = sentence["from"] if from_str.lower() == "human": from_str = conversation_lib.default_conversation.roles[0] elif from_str.lower() == "gpt": from_str = conversation_lib.default_conversation.roles[1] else: from_str = 'unknown' sentence["value"] = (BEGIN_SIGNAL + from_str + ": " + sentence["value"] + 
END_SIGNAL) if get_conversation: conversation += sentence["value"] conversation += BEGIN_SIGNAL return conversation def preprocess_multimodal( sources: Sequence[str], data_args: DataArguments ) -> Dict: is_multimodal = data_args.is_multimodal if not is_multimodal: return sources for source in sources: for sentence in source: if DEFAULT_IMAGE_TOKEN in sentence['value']: sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip() sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value'] sentence['value'] = sentence['value'].strip() if "mmtag" in conversation_lib.default_conversation.version: sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '' + DEFAULT_IMAGE_TOKEN + '') replace_token = DEFAULT_IMAGE_TOKEN if data_args.mm_use_im_start_end: replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token) return sources def preprocess_llama_2( sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False ) -> Dict: conv = conversation_lib.default_conversation.copy() roles = {"human": conv.roles[0], "gpt": conv.roles[1]} # Apply prompt templates conversations = [] for i, source in enumerate(sources): if roles[source[0]["from"]] != conv.roles[0]: # Skip the first one if it is not from human source = source[1:] conv.messages = [] for j, sentence in enumerate(source): role = roles[sentence["from"]] assert role == conv.roles[j % 2], f"{i}" conv.append_message(role, sentence["value"]) conversations.append(conv.get_prompt()) # Tokenize conversations if has_image: input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) else: input_ids = tokenizer( conversations, return_tensors="pt", padding="longest", max_length=tokenizer.model_max_length, truncation=True, ).input_ids targets = input_ids.clone() assert conv.sep_style == 
conversation_lib.SeparatorStyle.LLAMA_2 # Mask targets sep = "[/INST] " for conversation, target in zip(conversations, targets): total_len = int(target.ne(tokenizer.pad_token_id).sum()) rounds = conversation.split(conv.sep2) cur_len = 1 target[:cur_len] = IGNORE_INDEX for i, rou in enumerate(rounds): if rou == "": break parts = rou.split(sep) if len(parts) != 2: break parts[0] += sep if has_image: round_len = len(tokenizer_image_token(rou, tokenizer)) instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2 else: round_len = len(tokenizer(rou).input_ids) instruction_len = len(tokenizer(parts[0]).input_ids) - 2 target[cur_len: cur_len + instruction_len] = IGNORE_INDEX cur_len += round_len target[cur_len:] = IGNORE_INDEX if cur_len < tokenizer.model_max_length: if cur_len != total_len: target[:] = IGNORE_INDEX print( f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)" ) return dict( input_ids=input_ids, labels=targets, ) # fix: add qwen2 def preprocess_qwen_2( sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False ) -> Dict: conv = conversation_lib.default_conversation.copy() roles = {"human": conv.roles[0], "gpt": conv.roles[1]} # Apply prompt templates conversations = [] for i, source in enumerate(sources): if roles[source[0]["from"]] != conv.roles[0]: # Skip the first one if it is not from human source = source[1:] conv.messages = [] for j, sentence in enumerate(source): role = roles[sentence["from"]] assert role == conv.roles[j % 2], f"{i}" conv.append_message(role, sentence["value"]) conversations.append(conv.get_prompt()) # Tokenize conversations if has_image: input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) else: input_ids = tokenizer( conversations, return_tensors="pt", padding="longest", max_length=tokenizer.model_max_length, truncation=True, ).input_ids targets = input_ids.clone() assert conv.sep_style == 
conversation_lib.SeparatorStyle.QWEN_2 # Mask targets sep = conv.sep + conv.roles[1] + ": " for conversation, target in zip(conversations, targets): total_len = int(target.ne(tokenizer.pad_token_id).sum()) rounds = conversation.split(conv.sep2) rounds_len = len(rounds) cur_len = 0 # target[:cur_len] = IGNORE_INDEX for i, rou in enumerate(rounds): if rou == "": break parts = rou.split(sep) if len(parts) != 2: break parts[0] += sep if has_image: round_ids = tokenizer_image_token(rou, tokenizer) instruction_ids = tokenizer_image_token(parts[0], tokenizer) equal_parts = [x == y for x, y in zip(round_ids, instruction_ids)] instruction_len = equal_parts.index(False) if False in equal_parts else len(equal_parts) round_len = len(round_ids) else: round_ids = tokenizer(rou).input_ids instruction_ids = tokenizer(parts[0]).input_ids equal_parts = [x == y for x, y in zip(round_ids, instruction_ids)] instruction_len = equal_parts.index(False) if False in equal_parts else len(equal_parts) round_len = len(round_ids) if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14: round_len += 1 instruction_len += 1 target[cur_len: cur_len + instruction_len] = IGNORE_INDEX cur_len += round_len target[cur_len:] = IGNORE_INDEX if cur_len < tokenizer.model_max_length: if cur_len != total_len + rounds_len - 2: target[:] = IGNORE_INDEX print( f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 
f" (ignored)" ) return dict( input_ids=input_ids, labels=targets, ) def preprocess_v1( sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False ) -> Dict: conv = conversation_lib.default_conversation.copy() roles = {"human": conv.roles[0], "gpt": conv.roles[1]} # Apply prompt templates conversations = [] for i, source in enumerate(sources): if roles[source[0]["from"]] != conv.roles[0]: # Skip the first one if it is not from human source = source[1:] conv.messages = [] for j, sentence in enumerate(source): role = roles[sentence["from"]] assert role == conv.roles[j % 2], f"{i}" conv.append_message(role, sentence["value"]) conversations.append(conv.get_prompt()) # Tokenize conversations if has_image: input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) else: input_ids = tokenizer( conversations, return_tensors="pt", padding="longest", max_length=tokenizer.model_max_length, truncation=True, ).input_ids targets = input_ids.clone() assert conv.sep_style == conversation_lib.SeparatorStyle.TWO # Mask targets sep = conv.sep + conv.roles[1] + ": " for conversation, target in zip(conversations, targets): total_len = int(target.ne(tokenizer.pad_token_id).sum()) rounds = conversation.split(conv.sep2) cur_len = 1 target[:cur_len] = IGNORE_INDEX for i, rou in enumerate(rounds): if rou == "": break parts = rou.split(sep) if len(parts) != 2: break parts[0] += sep if has_image: round_len = len(tokenizer_image_token(rou, tokenizer)) instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2 else: round_len = len(tokenizer(rou).input_ids) instruction_len = len(tokenizer(parts[0]).input_ids) - 2 if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14: round_len -= 1 instruction_len -= 1 target[cur_len: cur_len + instruction_len] = IGNORE_INDEX cur_len += round_len target[cur_len:] = IGNORE_INDEX if cur_len < tokenizer.model_max_length: if cur_len != total_len: target[:] 
= IGNORE_INDEX print( f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)" ) return dict( input_ids=input_ids, labels=targets, ) def preprocess_mpt( sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False ) -> Dict: conv = conversation_lib.default_conversation.copy() roles = {"human": conv.roles[0], "gpt": conv.roles[1]} # Apply prompt templates conversations = [] for i, source in enumerate(sources): if roles[source[0]["from"]] != conv.roles[0]: # Skip the first one if it is not from human source = source[1:] conv.messages = [] for j, sentence in enumerate(source): role = roles[sentence["from"]] assert role == conv.roles[j % 2], f"{i}" conv.append_message(role, sentence["value"]) conversations.append(conv.get_prompt()) # Tokenize conversations if has_image: input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) else: input_ids = tokenizer( conversations, return_tensors="pt", padding="longest", max_length=tokenizer.model_max_length, truncation=True, ).input_ids targets = input_ids.clone() assert conv.sep_style == conversation_lib.SeparatorStyle.MPT # Mask targets sep = conv.sep + conv.roles[1] for conversation, target in zip(conversations, targets): total_len = int(target.ne(tokenizer.pad_token_id).sum()) rounds = conversation.split(conv.sep) re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt for conv_idx in range(3, len(rounds), 2): re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2])) # user + gpt cur_len = 0 target[:cur_len] = IGNORE_INDEX for i, rou in enumerate(re_rounds): if rou == "": break parts = rou.split(sep) if len(parts) != 2: break parts[0] += sep if has_image: round_len = len(tokenizer_image_token(rou, tokenizer)) instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1 else: round_len = len(tokenizer(rou).input_ids) instruction_len = len(tokenizer(parts[0]).input_ids) - 1 if i != 0 and getattr(tokenizer, 
'legacy', False) and IS_TOKENIZER_GREATER_THAN_0_14: round_len += 1 instruction_len += 1 target[cur_len: cur_len + instruction_len] = IGNORE_INDEX cur_len += round_len target[cur_len:] = IGNORE_INDEX if cur_len < tokenizer.model_max_length: if cur_len != total_len: target[:] = IGNORE_INDEX print( f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)" ) return dict( input_ids=input_ids, labels=targets, ) def preprocess_plain( sources: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, ) -> Dict: # add end signal and concatenate together conversations = [] for source in sources: assert len(source) == 2 assert DEFAULT_IMAGE_TOKEN in source[0]['value'] source[0]['value'] = DEFAULT_IMAGE_TOKEN conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep conversations.append(conversation) # tokenize conversations input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations] targets = copy.deepcopy(input_ids) for target, source in zip(targets, sources): tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer)) target[:tokenized_len] = IGNORE_INDEX return dict(input_ids=input_ids, labels=targets) def preprocess( sources: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False ) -> Dict: """ Given a list of sources, each is a conversation list. This transform: 1. Add signal '### ' at the beginning each sentence, with end signal '\n'; 2. Concatenate conversations together; 3. Tokenize the concatenated conversation; 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX. 
""" if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN: return preprocess_plain(sources, tokenizer) if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2: return preprocess_llama_2(sources, tokenizer, has_image=has_image) if conversation_lib.default_conversation.version.startswith("v1"): return preprocess_v1(sources, tokenizer, has_image=has_image) if conversation_lib.default_conversation.version == "mpt": return preprocess_mpt(sources, tokenizer, has_image=has_image) # fix: add qwen2 if conversation_lib.default_conversation.version.startswith("qwen_v2"): return preprocess_qwen_2(sources, tokenizer, has_image=has_image) # add end signal and concatenate together conversations = [] for source in sources: header = f"{conversation_lib.default_conversation.system}\n\n" conversation = _add_speaker_and_signal(header, source) conversations.append(conversation) # tokenize conversations def get_tokenize_len(prompts): return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts] if has_image: input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations] else: conversations_tokenized = _tokenize_fn(conversations, tokenizer) input_ids = conversations_tokenized["input_ids"] targets = copy.deepcopy(input_ids) for target, source in zip(targets, sources): if has_image: tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source]) else: tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"] speakers = [sentence["from"] for sentence in source] _mask_targets(target, tokenized_lens, speakers) return dict(input_ids=input_ids, labels=targets) class LazySupervisedDataset(Dataset): """Dataset for supervised fine-tuning.""" def __init__(self, data_path: List[str], tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments): super(LazySupervisedDataset, self).__init__() #list_data_dict = 
json.load(open(data_path, "r")) list_data_dict = [] for i, _data_path in enumerate(data_path): data = json.load(open(_data_path, "r")) data = [{**entry, 'img_path_idx': i} for entry in data] list_data_dict += data rank0_print("Formatting inputs...Skip in lazy mode") self.tokenizer = tokenizer self.list_data_dict = list_data_dict self.data_args = data_args def __len__(self): return len(self.list_data_dict) @property def lengths(self): length_list = [] for sample in self.list_data_dict: img_tokens = 128 if 'image' in sample else 0 length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens) return length_list @property def modality_lengths(self): length_list = [] for sample in self.list_data_dict: cur_len = sum(len(conv['value'].split()) for conv in sample['conversations']) cur_len = cur_len if 'image' in sample else -cur_len length_list.append(cur_len) return length_list def __getitem__(self, i) -> Dict[str, torch.Tensor]: sources = self.list_data_dict[i] if isinstance(i, int): sources = [sources] assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME if 'image' in sources[0]: image_file = self.list_data_dict[i]['image'] img_path_idx = self.list_data_dict[i]['img_path_idx'] image_folder = self.data_args.image_folder[img_path_idx] processor = self.data_args.image_processor image = Image.open(os.path.join(image_folder, image_file)).convert('RGB') image_size = image.size if self.data_args.image_aspect_ratio == 'pad': def expand2square(pil_img, background_color): width, height = pil_img.size if width == height: return pil_img elif width > height: result = Image.new(pil_img.mode, (width, width), background_color) result.paste(pil_img, (0, (width - height) // 2)) return result else: result = Image.new(pil_img.mode, (height, height), background_color) result.paste(pil_img, ((height - width) // 2, 0)) return result image = expand2square(image, tuple(int(x*255) for x in processor.image_mean)) image = 
processor.preprocess(image, return_tensors='pt')['pixel_values'][0] elif self.data_args.image_aspect_ratio == "anyres" or "anyres_max" in self.data_args.image_aspect_ratio: image = process_anyres_image(image, self.data_args.image_processor, self.data_args.image_grid_pinpoints) else: image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] sources = preprocess_multimodal( copy.deepcopy([e["conversations"] for e in sources]), self.data_args) else: sources = copy.deepcopy([e["conversations"] for e in sources]) data_dict = preprocess( sources, self.tokenizer, has_image=('image' in self.list_data_dict[i])) if isinstance(i, int): data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0]) # image exist in the data if 'image' in self.list_data_dict[i]: data_dict['image'] = image data_dict['image_size'] = image_size elif self.data_args.is_multimodal: # image does not exist in the data, but the model is multimodal crop_size = self.data_args.image_processor.crop_size data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width']) data_dict['image_size'] = (crop_size['height'], crop_size['width']) return data_dict @dataclass class DataCollatorForSupervisedDataset(object): """Collate examples for supervised fine-tuning.""" tokenizer: transformers.PreTrainedTokenizer def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) input_ids = torch.nn.utils.rnn.pad_sequence( input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id) labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) input_ids = input_ids[:, :self.tokenizer.model_max_length] labels = labels[:, :self.tokenizer.model_max_length] batch = dict( input_ids=input_ids, labels=labels, attention_mask=input_ids.ne(self.tokenizer.pad_token_id), ) if 'image' in instances[0]: images = [instance['image'] 
for instance in instances] batch['image_sizes'] = [instance['image_size'] for instance in instances] if all(x is not None and x.shape == images[0].shape for x in images): batch['images'] = torch.stack(images) else: batch['images'] = images return batch def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict: """Make dataset and collator for supervised fine-tuning.""" train_dataset = LazySupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path, data_args=data_args) data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator) def train(attn_implementation=None): global local_rank parser = transformers.HfArgumentParser( (ModelArguments, DataArguments, TrainingArguments)) model_args, data_args, training_args = parser.parse_args_into_dataclasses() local_rank = training_args.local_rank compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) bnb_model_from_pretrained_args = {} if training_args.bits in [4, 8]: from transformers import BitsAndBytesConfig bnb_model_from_pretrained_args.update(dict( device_map={"": training_args.device}, load_in_4bit=training_args.bits == 4, load_in_8bit=training_args.bits == 8, quantization_config=BitsAndBytesConfig( load_in_4bit=training_args.bits == 4, load_in_8bit=training_args.bits == 8, llm_int8_skip_modules=["mm_projector"], llm_int8_threshold=6.0, llm_int8_has_fp16_weight=False, bnb_4bit_compute_dtype=compute_dtype, bnb_4bit_use_double_quant=training_args.double_quant, bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'} ) )) if model_args.vision_tower is not None: if 'mpt' in model_args.model_name_or_path: config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True) config.attn_config['attn_impl'] = training_args.mpt_attn_impl model = LlavaMptForCausalLM.from_pretrained( 
model_args.model_name_or_path, config=config, cache_dir=training_args.cache_dir, **bnb_model_from_pretrained_args ) elif 'dclm' in model_args.model_name_or_path.lower(): config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True) config.attn_config['attn_impl'] = "eager" model = LlavaOpenlmForCausalLM.from_pretrained( model_args.model_name_or_path, cache_dir=training_args.cache_dir, attn_implementation=attn_implementation, torch_dtype=(torch.bfloat16 if training_args.bf16 else None), **bnb_model_from_pretrained_args ) else: model = LlavaLlamaForCausalLM.from_pretrained( model_args.model_name_or_path, cache_dir=training_args.cache_dir, attn_implementation=attn_implementation, torch_dtype=(torch.bfloat16 if training_args.bf16 else None), **bnb_model_from_pretrained_args ) else: model = transformers.LlamaForCausalLM.from_pretrained( model_args.model_name_or_path, cache_dir=training_args.cache_dir, attn_implementation=attn_implementation, torch_dtype=(torch.bfloat16 if training_args.bf16 else None), **bnb_model_from_pretrained_args ) model.config.use_cache = False if model_args.freeze_backbone: model.model.requires_grad_(False) if training_args.bits in [4, 8]: from peft import prepare_model_for_kbit_training model.config.torch_dtype = (torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing) if training_args.gradient_checkpointing: if hasattr(model, "enable_input_require_grads"): model.enable_input_require_grads() else: def make_inputs_require_grad(module, input, output): output.requires_grad_(True) model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) if training_args.lora_enable: from peft import LoraConfig, get_peft_model lora_config = LoraConfig( r=training_args.lora_r, lora_alpha=training_args.lora_alpha, 
target_modules=find_all_linear_names(model), lora_dropout=training_args.lora_dropout, bias=training_args.lora_bias, task_type="CAUSAL_LM", ) if training_args.bits == 16: if training_args.bf16: model.to(torch.bfloat16) if training_args.fp16: model.to(torch.float16) rank0_print("Adding LoRA adapters...") model = get_peft_model(model, lora_config) if 'mpt' in model_args.model_name_or_path: tokenizer = transformers.AutoTokenizer.from_pretrained( model_args.model_name_or_path, cache_dir=training_args.cache_dir, model_max_length=training_args.model_max_length, padding_side="right" ) else: tokenizer = transformers.AutoTokenizer.from_pretrained( model_args.model_name_or_path, cache_dir=training_args.cache_dir, model_max_length=training_args.model_max_length, padding_side="right", use_fast=False, ) if model_args.version == "v0": if tokenizer.pad_token is None: smart_tokenizer_and_embedding_resize( special_tokens_dict=dict(pad_token="[PAD]"), tokenizer=tokenizer, model=model, ) elif model_args.version == "v0.5": tokenizer.pad_token = tokenizer.unk_token else: tokenizer.pad_token = tokenizer.unk_token if model_args.version in conversation_lib.conv_templates: conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version] else: conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"] if model_args.vision_tower is not None: model.get_model().initialize_vision_modules( model_args=model_args, fsdp=training_args.fsdp ) vision_tower = model.get_vision_tower() vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device) data_args.image_processor = vision_tower.image_processor data_args.is_multimodal = True model.config.image_grid_pinpoints = data_args.image_grid_pinpoints model.config.image_aspect_ratio = data_args.image_aspect_ratio model.config.tokenizer_padding_side = tokenizer.padding_side model.config.tokenizer_model_max_length = tokenizer.model_max_length 
model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter if model_args.tune_mm_mlp_adapter: model.requires_grad_(False) for p in model.get_model().mm_projector.parameters(): p.requires_grad = True if model_args.tune_mm_mlp_and_vision_tower: model.requires_grad_(False) for p in model.get_model().mm_projector.parameters(): p.requires_grad = True for p in model.get_vision_tower().parameters(): p.requires_grad = True model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter if training_args.freeze_mm_mlp_adapter: for p in model.get_model().mm_projector.parameters(): p.requires_grad = False if training_args.bits in [4, 8]: model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device) model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end model.config.mm_projector_lr = training_args.mm_projector_lr training_args.use_im_start_end = model_args.mm_use_im_start_end model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer) if training_args.bits in [4, 8]: from peft.tuners.lora import LoraLayer for name, module in model.named_modules(): if isinstance(module, LoraLayer): if training_args.bf16: module = module.to(torch.bfloat16) if 'norm' in name: module = module.to(torch.float32) if 'lm_head' in name or 'embed_tokens' in name: if hasattr(module, 'weight'): if training_args.bf16 and module.weight.dtype == torch.float32: module = module.to(torch.bfloat16) data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) trainer = LLaVATrainer(model=model, tokenizer=tokenizer, args=training_args, **data_module) if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): trainer.train(resume_from_checkpoint=True) else: trainer.train() trainer.save_state() model.config.use_cache = True if training_args.lora_enable: state_dict = 
get_peft_state_maybe_zero_3( model.named_parameters(), training_args.lora_bias ) non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3( model.named_parameters() ) if training_args.local_rank == 0 or training_args.local_rank == -1: model.config.save_pretrained(training_args.output_dir) model.save_pretrained(training_args.output_dir, state_dict=state_dict) torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin')) else: safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir) if __name__ == "__main__": train() ================================================ FILE: llava/train/train_mem.py ================================================ from llava.train.train_qwen import train if __name__ == "__main__": train(attn_implementation="flash_attention_2") ================================================ FILE: llava/train/train_qwen.py ================================================ # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from packaging import version
import os
import copy
from dataclasses import dataclass, field
import json
import logging
import pathlib
from typing import Dict, Optional, Sequence, List

import torch

import transformers
import tokenizers

from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from torch.utils.data import Dataset
from llava.train.llava_trainer import LLaVATrainer

from llava import conversation as conversation_lib
from llava.model import *
from llava.mm_utils import tokenizer_image_token, process_anyres_image

from PIL import Image


# Set from training_args.local_rank inside train(); used by rank0_print.
local_rank = None


def rank0_print(*args):
    """Print only on the process whose global ``local_rank`` is 0."""
    if local_rank == 0:
        print(*args)


# Newer `tokenizers` releases change how round boundaries tokenize; the
# preprocess_* functions adjust their length bookkeeping based on this flag.
IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14')


@dataclass
class ModelArguments:
    """CLI arguments describing the language model and the vision modules."""
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
    version: Optional[str] = field(default="v0")
    freeze_backbone: bool = field(default=False)
    tune_mm_mlp_adapter: bool = field(default=False)
    tune_mm_mlp_and_vision_tower: bool = field(default=False)
    vision_tower: Optional[str] = field(default=None)
    mm_vision_select_layer: Optional[int] = field(default=-1)   # default to the last layer
    pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
    mm_projector_type: Optional[str] = field(default='linear')
    mm_use_im_start_end: bool = field(default=False)
    mm_use_im_patch_token: bool = field(default=True)
    mm_patch_merge_type: Optional[str] = field(default='flat')
    mm_vision_select_feature: Optional[str] = field(default="patch")
    unfreeze_mm_vision_tower: bool = field(default=False)
    s2: Optional[bool] = field(default=False)
    hd: Optional[bool] = field(default=False)


@dataclass
class DataArguments:
    """CLI arguments describing the training data and image preprocessing."""
    # One JSON file per entry; each sample remembers its file index (img_path_idx)
    # so the matching entry of image_folder can be used to resolve its image.
    data_path: Optional[List[str]] = field(default=None,
                                           metadata={"help": "Optional list of paths to the training data."})
    lazy_preprocess: bool = False
    is_multimodal: bool = False
    image_folder: Optional[List[str]] = field(default=None)
    image_aspect_ratio: str = 'square'
    image_grid_pinpoints: Optional[str] = field(default=None)
    image_crop_resolution: Optional[int] = field(default=None)
    image_split_resolution: Optional[int] = field(default=None)


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """HF TrainingArguments extended with quantization, LoRA and projector options."""
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    remove_unused_columns: bool = field(default=False)
    freeze_mm_mlp_adapter: bool = field(default=False)
    mpt_attn_impl: Optional[str] = field(default="triton")
    model_max_length: int = field(
        default=512,
        metadata={
            "help":
            "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    double_quant: bool = field(
        default=True,
        metadata={"help": "Compress the quantization statistics through double quantization."}
    )
    quant_type: str = field(
        default="nf4",
        metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
    )
    bits: int = field(
        default=16,
        metadata={"help": "How many bits to use."}
    )
    lora_enable: bool = False
    lora_r: int = 64
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_weight_path: str = ""
    lora_bias: str = "none"
    mm_projector_lr: Optional[float] = None
    group_by_modality_length: bool = field(default=False)
    mm_vision_tower_lr: Optional[float] = None


def maybe_zero_3(param, ignore_status=False, name=None):
    """Return a detached CPU clone of ``param``.

    If the parameter carries a DeepSpeed ``ds_id`` (i.e. it is managed by
    DeepSpeed ZeRO), it is first gathered with ``zero.GatheredParameters``.
    A warning is logged when the parameter was NOT_AVAILABLE, unless
    ``ignore_status`` is set.
    """
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
    if hasattr(param, "ds_id"):
        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            if not ignore_status:
                logging.warning(f"{name}: param.ds_status == ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param


# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    """Collect LoRA weights (plus biases per ``bias``) as CPU tensors.

    Args:
        named_params: iterable of ``(name, tensor)`` pairs, e.g. ``model.named_parameters()``.
        bias: ``"none"`` (LoRA weights only), ``"all"`` (LoRA weights and every
            bias), or ``"lora_only"`` (LoRA weights and the biases of the
            modules that carry LoRA weights).

    Raises:
        NotImplementedError: for any other ``bias`` value.
    """
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                bias_name = k.split("lora_")[0] + "bias"
                lora_bias_names.add(bias_name)
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # Keep only the biases whose module prefix also owns LoRA weights.
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError
    to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
    return to_return


def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
    """Collect non-LoRA parameters (optionally only trainable ones) as CPU tensors."""
    to_return = {k: t for k, t in named_params if "lora_" not in k}
    if require_grad_only:
        to_return = {k: t for k, t in to_return.items() if t.requires_grad}
    to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
    return to_return


def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
    """Collect parameters whose name contains any of ``keys_to_match`` as CPU tensors."""
    to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
    to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
    return to_return


def find_all_linear_names(model):
    """Return the short names of ``torch.nn.Linear`` modules to target with LoRA.

    Modules under the multimodal components (projector, vision tower,
    resampler) are excluded, and ``lm_head`` is dropped.
    """
    cls = torch.nn.Linear
    lora_module_names = set()
    multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler']
    for name, module in model.named_modules():
        if any(mm_keyword in name for mm_keyword in multimodal_keywords):
            continue
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if 'lm_head' in lora_module_names:  # needed for 16-bit
        lora_module_names.remove('lm_head')
    return list(lora_module_names)


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
                                   output_dir: str):
    """Collects the state dict and dump to disk.

    Three paths:
      * ``tune_mm_mlp_adapter``: save only the projector (and embedding
        weights when ``use_im_start_end``) as ``mm_projector.bin``. For a
        ``checkpoint-*`` directory, the file is written to a sibling
        ``mm_projector/<checkpoint-name>.bin`` instead.
      * DeepSpeed: delegate to ``trainer.save_model``.
      * otherwise: copy the state dict to CPU and call ``trainer._save``.
    """
    if getattr(trainer.args, "tune_mm_mlp_adapter", False):
        # Only save Adapter
        keys_to_match = ['mm_projector']
        if getattr(trainer.args, "use_im_start_end", False):
            keys_to_match.extend(['embed_tokens', 'embed_in'])

        weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
        trainer.model.config.save_pretrained(output_dir)

        current_folder = output_dir.split('/')[-1]
        parent_folder = os.path.dirname(output_dir)
        if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
            if current_folder.startswith('checkpoint-'):
                mm_projector_folder = os.path.join(parent_folder, "mm_projector")
                os.makedirs(mm_projector_folder, exist_ok=True)
                torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
            else:
                torch.save(weight_to_save, os.path.join(output_dir, 'mm_projector.bin'))
        return

    if trainer.deepspeed:
        torch.cuda.synchronize()
        trainer.save_model(output_dir)
        return

    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {
            key: value.cpu()
            for key, value in state_dict.items()
        }
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa


def smart_tokenizer_and_embedding_resize(
        special_tokens_dict: Dict,
        tokenizer: transformers.PreTrainedTokenizer,
        model: transformers.PreTrainedModel,
):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        # Initialise the new rows with the mean of the pre-existing embeddings.
        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
            dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
            dim=0, keepdim=True)

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg


def _tokenize_fn(strings: Sequence[str],
                 tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize a list of strings.

    Returns a dict with ``input_ids``/``labels`` (the same list of 1-D id
    tensors) and ``input_ids_lens``/``labels_lens`` (non-pad token counts).
    """
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ) for text in strings
    ]
    input_ids = labels = [
        tokenized.input_ids[0] for tokenized in tokenized_list
    ]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
        for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


def _mask_targets(target, tokenized_lens, speakers):
    """In place: mask the header and each human turn of ``target`` with IGNORE_INDEX.

    ``tokenized_lens[0]`` is the header length; the remaining entries pair up
    with ``speakers``. The first two tokens of a human turn are left unmasked.
    """
    # cur_idx = 0
    cur_idx = tokenized_lens[0]
    tokenized_lens = tokenized_lens[1:]
    target[:cur_idx] = IGNORE_INDEX
    for tokenized_len, speaker in zip(tokenized_lens, speakers):
        if speaker == "human":
            target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX
        cur_idx += tokenized_len


def _add_speaker_and_signal(header, source, get_conversation=True):
    """Add speaker and start/end signal on each round.

    Mutates each ``sentence["value"]`` in ``source`` to
    ``"### <role>: <value>\\n"`` and returns ``header`` followed by all
    rounds (when ``get_conversation``) plus a trailing ``"### "``.
    """
    BEGIN_SIGNAL = "### "
    END_SIGNAL = "\n"
    conversation = header
    for sentence in source:
        from_str = sentence["from"]
        if from_str.lower() == "human":
            from_str = conversation_lib.default_conversation.roles[0]
        elif from_str.lower() == "gpt":
            from_str = conversation_lib.default_conversation.roles[1]
        else:
            from_str = 'unknown'
        sentence["value"] = (BEGIN_SIGNAL + from_str + ": " +
                             sentence["value"] + END_SIGNAL)
        if get_conversation:
            conversation += sentence["value"]
    conversation += BEGIN_SIGNAL
    return conversation


def preprocess_multimodal(
    sources: Sequence[str],
    data_args: DataArguments
) -> Dict:
    """Normalise image placeholders in every sentence of ``sources`` (in place).

    A sentence containing DEFAULT_IMAGE_TOKEN gets the token moved to its
    start, followed by a newline. When ``mm_use_im_start_end`` is set, every
    image token is wrapped in DEFAULT_IM_START_TOKEN / DEFAULT_IM_END_TOKEN.
    Returns ``sources`` unchanged when the data is not multimodal.
    """
    is_multimodal = data_args.is_multimodal
    if not is_multimodal:
        return sources

    for source in sources:
        for sentence in source:
            if DEFAULT_IMAGE_TOKEN in sentence['value']:
                sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
                sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value']
                sentence['value'] = sentence['value'].strip()
                if "mmtag" in conversation_lib.default_conversation.version:
                    # NOTE(review): with empty wrappers this replace is a no-op; upstream
                    # LLaVA wraps the token in tags here — confirm the intended strings.
                    sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '' + DEFAULT_IMAGE_TOKEN + '')
            replace_token = DEFAULT_IMAGE_TOKEN
            if data_args.mm_use_im_start_end:
                replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
            sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)

    return sources


def preprocess_llama_2(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    """Tokenize LLAMA_2-style conversations and mask every instruction span.

    Rounds are split on ``conv.sep2``; each round's ``[/INST] `` prefix is
    masked with IGNORE_INDEX. A sample whose computed length disagrees with
    its non-pad length is masked entirely and a warning is printed.
    """
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize conversations

    if has_image:
        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()

    assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2

    # Mask targets
    sep = "[/INST] "
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep2)
        cur_len = 1
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            target[cur_len: cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )


# fix: add qwen2
def preprocess_qwen_2(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    """Tokenize QWEN_2-style conversations and mask every user/system span.

    Sentences with an unrecognised ``from`` role are skipped. The
    human/assistant alternation check counts only the messages actually kept.
    The prompt is split on ``conv.sep`` and regrouped into (user, assistant)
    rounds, with the system prompt merged into the first round. Instruction
    length is the common token prefix of the round and its instruction part.
    A sample whose computed length disagrees with its non-pad length is masked
    entirely and a warning is printed.
    """
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        kept = 0  # messages appended so far; drives the role-alternation check
        for sentence in source:
            role = roles.get(sentence["from"])
            if role is None:
                print("skipping sentence due to unrecognized role: {}".format(sentence["from"]))
                continue
            assert role == conv.roles[kept % 2], f"{i}"
            conv.append_message(role, sentence["value"])
            kept += 1
        conversations.append(conv.get_prompt())

    if has_image:
        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()

    assert conv.sep_style == conversation_lib.SeparatorStyle.QWEN_2

    split_sep = conv.sep + conv.roles[1]
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds_before = conversation.split(
            conv.sep)  # system->user->assistant->user->assistant->user->assistant->user->assistant->Empty
        if rounds_before[0] == conv.system:
            rounds_before[1] = conv.sep.join([rounds_before[0], rounds_before[1]])
            rounds_before = rounds_before[1:]
        # connect every pair:
        rounds = []
        for i in range(0, len(rounds_before), 2):
            if i < len(rounds_before)-1:
                rounds.append(conv.sep.join([rounds_before[i], rounds_before[i + 1]]))
            else:
                rounds.append(rounds_before[i])

        cur_len = 0
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break
            parts = rou.split(split_sep)
            if len(parts) != 2:
                break
            assert parts[0].startswith(conv.roles[0]) or parts[0].startswith(conv.system)
            parts[0] += split_sep

            if has_image:
                round_ids = tokenizer_image_token(rou, tokenizer)
                instruction_ids = tokenizer_image_token(parts[0], tokenizer)
            else:
                round_ids = tokenizer(rou).input_ids
                instruction_ids = tokenizer(parts[0]).input_ids
            # The instruction length is the length of the shared token prefix.
            equal_parts = [x == y for x, y in zip(round_ids, instruction_ids)]
            instruction_len = equal_parts.index(False) if False in equal_parts else len(equal_parts)
            round_len = len(round_ids)

            round_len += 2  # this is to compensate for the sep2
            assert rou == parts[0]+parts[1]
            target[cur_len: cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )


def preprocess_v1(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    """Tokenize SeparatorStyle.TWO (vicuna v1) conversations and mask instructions.

    Rounds are split on ``conv.sep2``. For non-legacy tokenizers with
    tokenizers >= 0.14, rounds after the first are one token shorter.
    A sample whose computed length disagrees with its non-pad length is
    masked entirely and a warning is printed.
    """
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize conversations

    if has_image:
        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()

    assert conv.sep_style == conversation_lib.SeparatorStyle.TWO

    # Mask targets
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep2)
        cur_len = 1
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14:
                round_len -= 1
                instruction_len -= 1

            target[cur_len: cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )


def preprocess_mpt(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    """Tokenize MPT-style conversations and mask instructions.

    The prompt is split on ``conv.sep``. The first three pieces
    (system + user + gpt) form round 0; each following pair forms one round.
    A sample whose computed length disagrees with its non-pad length is
    masked entirely and a warning is printed.
    """
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize conversations

    if has_image:
        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()
    assert conv.sep_style == conversation_lib.SeparatorStyle.MPT

    # Mask targets
    sep = conv.sep + conv.roles[1]
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep)
        re_rounds = [conv.sep.join(rounds[:3])]  # system + user + gpt
        for conv_idx in range(3, len(rounds), 2):
            re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2]))  # user + gpt
        cur_len = 0
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(re_rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 1

            if i != 0 and getattr(tokenizer, 'legacy', False) and IS_TOKENIZER_GREATER_THAN_0_14:
                round_len += 1
                instruction_len += 1

            target[cur_len: cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )


def preprocess_plain(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Build image+caption samples (pretraining format).

    Each source must have exactly two sentences, with the first containing
    the image token. The first sentence is replaced by the bare image token,
    and that prefix is masked out of the labels.
    """
    # add end signal and concatenate together
    conversations = []
    for source in sources:
        assert len(source) == 2
        assert DEFAULT_IMAGE_TOKEN in source[0]['value']
        source[0]['value'] = DEFAULT_IMAGE_TOKEN
        conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep
        conversations.append(conversation)
    # tokenize conversations
    input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
    targets = copy.deepcopy(input_ids)
    for target, source in zip(targets, sources):
        tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer))
        target[:tokenized_len] = IGNORE_INDEX

    return dict(input_ids=input_ids, labels=targets)


def preprocess(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    """
    Given a list of sources, each is a conversation list. This transform:
    1. Add signal '### ' at the beginning each sentence, with end signal '\n';
    2. Concatenate conversations together;
    3. Tokenize the concatenated conversation;
    4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.

    Template-specific preprocessors (plain, llama_2, v1*, mpt, qwen_v2*) are
    dispatched to first, based on the active default conversation template.
    """
    if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
        return preprocess_plain(sources, tokenizer)
    if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2:
        return preprocess_llama_2(sources, tokenizer, has_image=has_image)
    if conversation_lib.default_conversation.version.startswith("v1"):
        return preprocess_v1(sources, tokenizer, has_image=has_image)
    if conversation_lib.default_conversation.version == "mpt":
        return preprocess_mpt(sources, tokenizer, has_image=has_image)
    # fix: add qwen2
    if conversation_lib.default_conversation.version.startswith("qwen_v2"):
        return preprocess_qwen_2(sources, tokenizer, has_image=has_image)
    # add end signal and concatenate together
    conversations = []
    for source in sources:
        header = f"{conversation_lib.default_conversation.system}\n\n"
        conversation = _add_speaker_and_signal(header, source)
        conversations.append(conversation)

    # tokenize conversations
    def get_tokenize_len(prompts):
        return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]

    if has_image:
        input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
    else:
        conversations_tokenized = _tokenize_fn(conversations, tokenizer)
        input_ids = conversations_tokenized["input_ids"]

    targets = copy.deepcopy(input_ids)
    for target, source in zip(targets, sources):
        # `header` is identical for every source, so the last loop value is reused.
        if has_image:
            tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
        else:
            tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"]
        speakers = [sentence["from"] for sentence in source]
        _mask_targets(target, tokenized_lens, speakers)

    return dict(input_ids=input_ids, labels=targets)


class LazySupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning.

    Samples from every JSON file in ``data_path`` are concatenated. Each
    sample is tagged with ``img_path_idx`` (the index of its source file), so
    its image is resolved against ``data_args.image_folder[img_path_idx]``.
    Images and tokenization are processed lazily in ``get_sample``.
    """

    def __init__(self, data_path: List[str],
                 tokenizer: transformers.PreTrainedTokenizer,
                 data_args: DataArguments):
        super(LazySupervisedDataset, self).__init__()
        list_data_dict = []
        for i, _data_path in enumerate(data_path):
            with open(_data_path, "r") as f:
                data = json.load(f)
            data = [{**entry, 'img_path_idx': i} for entry in data]
            list_data_dict += data

        self.tokenizer = tokenizer
        self.list_data_dict = list_data_dict
        self.data_args = data_args

    def __len__(self):
        return len(self.list_data_dict)

    @property
    def lengths(self):
        """Approximate lengths: whitespace word count, plus 128 for image samples."""
        length_list = []
        for sample in self.list_data_dict:
            img_tokens = 128 if 'image' in sample else 0
            length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens)
        return length_list

    @property
    def modality_lengths(self):
        """Word counts: positive for image samples, negative for text-only ones."""
        length_list = []
        for sample in self.list_data_dict:
            cur_len = sum(len(conv['value'].split()) for conv in sample['conversations'])
            cur_len = cur_len if 'image' in sample else -cur_len
            length_list.append(cur_len)
        return length_list

    def get_sample(self, i) -> Dict[str, torch.Tensor]:
        """Load, preprocess and tokenize sample ``i``.

        Returns ``input_ids`` and ``labels``, plus ``image`` and ``image_size``
        when the sample has an image. For a multimodal model without an image,
        a zero image of the processor's crop size is returned instead.
        """
        sources = self.list_data_dict[i]
        if isinstance(i, int):
            sources = [sources]
        assert len(sources) == 1, "Don't know why it is wrapped to a list"  # FIXME
        if 'image' in sources[0]:
            image_file = self.list_data_dict[i]['image']
            img_path_idx = self.list_data_dict[i]['img_path_idx']
            image_folder = self.data_args.image_folder[img_path_idx]
            processor = self.data_args.image_processor
            image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
            image_size = image.size
            if self.data_args.image_aspect_ratio == 'pad':
                def expand2square(pil_img, background_color):
                    # Pad the shorter side symmetrically to make the image square.
                    width, height = pil_img.size
                    if width == height:
                        return pil_img
                    elif width > height:
                        result = Image.new(pil_img.mode, (width, width), background_color)
                        result.paste(pil_img, (0, (width - height) // 2))
                        return result
                    else:
                        result = Image.new(pil_img.mode, (height, height), background_color)
                        result.paste(pil_img, ((height - width) // 2, 0))
                        return result
                image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
                image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            elif self.data_args.image_aspect_ratio == "anyres" or "anyres_max" in self.data_args.image_aspect_ratio:
                image = process_anyres_image(image, self.data_args.image_processor, self.data_args.image_grid_pinpoints)
            else:
                image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            sources = preprocess_multimodal(
                copy.deepcopy([e["conversations"] for e in sources]),
                self.data_args)
        else:
            sources = copy.deepcopy([e["conversations"] for e in sources])
        data_dict = preprocess(
            sources,
            self.tokenizer,
            has_image=('image' in self.list_data_dict[i]))
        if isinstance(i, int):
            data_dict = dict(input_ids=data_dict["input_ids"][0],
                             labels=data_dict["labels"][0])

        # image exist in the data
        if 'image' in self.list_data_dict[i]:
            data_dict['image'] = image
            data_dict['image_size'] = image_size
        elif self.data_args.is_multimodal:
            # image does not exist in the data, but the model is multimodal
            crop_size = self.data_args.image_processor.crop_size
            data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
            data_dict['image_size'] = (crop_size['height'], crop_size['width'])
        return data_dict

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Return sample ``i``; on any error, report it and fall back to sample 0."""
        try:
            return self.get_sample(i)
        except Exception as e:
            print("Error loading sample")
            print(f"  index={i}: {e!r}")
            return self.get_sample(0)


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning.

    Pads ``input_ids`` with the tokenizer's pad id and ``labels`` with
    IGNORE_INDEX, truncates both to ``model_max_length``, and derives
    ``attention_mask``. Images are stacked when all shapes match; otherwise
    they are passed on as a list.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances]
                                  for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id)
        labels = torch.nn.utils.rnn.pad_sequence(labels,
                                                 batch_first=True,
                                                 padding_value=IGNORE_INDEX)
        input_ids = input_ids[:, :self.tokenizer.model_max_length]
        labels = labels[:, :self.tokenizer.model_max_length]
        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

        if 'image' in instances[0]:
            images = [instance['image'] for instance in instances]
            batch['image_sizes'] = [instance['image_size'] for instance in instances]
            if all(x is not None and x.shape == images[0].shape for x in images):
                batch['images'] = torch.stack(images)
            else:
                batch['images'] = images

        return batch


def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
                                data_args) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
                                          data_path=data_args.data_path,
                                          data_args=data_args)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset,
                eval_dataset=None,
                data_collator=data_collator)


def train(attn_implementation=None):
    """Parse CLI dataclasses, build the model and tokenizer, and run LLaVATrainer.

    ``attn_implementation`` is forwarded to the Qwen2 ``from_pretrained``
    calls (train_mem.py passes ``"flash_attention_2"``).
    """
    global local_rank

    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    local_rank = training_args.local_rank
    compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))

    bnb_model_from_pretrained_args = {}
    if training_args.bits in [4, 8]:
        from transformers import BitsAndBytesConfig
        bnb_model_from_pretrained_args.update(dict(
            device_map={"": training_args.device},
            load_in_4bit=training_args.bits == 4,
            load_in_8bit=training_args.bits == 8,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=training_args.bits == 4,
                load_in_8bit=training_args.bits == 8,
                llm_int8_skip_modules=["mm_projector"],
                llm_int8_threshold=6.0,
                llm_int8_has_fp16_weight=False,
                bnb_4bit_compute_dtype=compute_dtype,
                bnb_4bit_use_double_quant=training_args.double_quant,
                bnb_4bit_quant_type=training_args.quant_type  # {'fp4', 'nf4'}
            )
        ))

    if model_args.vision_tower is not None:
        if 'mpt' in model_args.model_name_or_path:
            config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
            config.attn_config['attn_impl'] = training_args.mpt_attn_impl
            model = LlavaMptForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                config=config,
                cache_dir=training_args.cache_dir,
                **bnb_model_from_pretrained_args
            )
        else:
            model = LlavaQwen2ForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=training_args.cache_dir,
                attn_implementation=attn_implementation,
                torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
                **bnb_model_from_pretrained_args
            )
    else:
        model = transformers.Qwen2ForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            attn_implementation=attn_implementation,
            torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
            **bnb_model_from_pretrained_args
        )
    model.config.use_cache = False

    if model_args.freeze_backbone:
        model.model.requires_grad_(False)

    if training_args.bits in [4, 8]:
        from peft import prepare_model_for_kbit_training
        model.config.torch_dtype = (torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)

    if training_args.gradient_checkpointing:
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:
            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)
            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    if training_args.lora_enable:
        from peft import LoraConfig, get_peft_model
        lora_config = LoraConfig(
            r=training_args.lora_r,
            lora_alpha=training_args.lora_alpha,
            target_modules=find_all_linear_names(model),
            lora_dropout=training_args.lora_dropout,
            bias=training_args.lora_bias,
            task_type="CAUSAL_LM",
        )
        if training_args.bits == 16:
            if training_args.bf16:
                model.to(torch.bfloat16)
            if training_args.fp16:
                model.to(torch.float16)
        rank0_print("Adding LoRA adapters...")
        model = get_peft_model(model, lora_config)

    if 'mpt' in model_args.model_name_or_path:
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            model_max_length=training_args.model_max_length,
            padding_side="right"
        )
    else:
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            model_max_length=training_args.model_max_length,
            padding_side="right",
            use_fast=False,
        )

    if model_args.version == "v0":
        if tokenizer.pad_token is None:
            smart_tokenizer_and_embedding_resize(
                special_tokens_dict=dict(pad_token="[PAD]"),
                tokenizer=tokenizer,
                model=model,
            )
    elif model_args.version == "v0.5":
        tokenizer.pad_token = tokenizer.unk_token
    else:
        if tokenizer.unk_token:
            tokenizer.pad_token = tokenizer.unk_token
        else:  # use qwen
            tokenizer.legacy = False
        if model_args.version in conversation_lib.conv_templates:
            conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
        else:
            conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"]

    if model_args.vision_tower is not None:
        model.get_model().initialize_vision_modules(
            model_args=model_args,
            fsdp=training_args.fsdp
        )

        vision_tower = model.get_vision_tower()
        vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)

        data_args.image_processor = vision_tower.image_processor
        data_args.is_multimodal = True

        model.config.image_grid_pinpoints = data_args.image_grid_pinpoints
        model.config.image_aspect_ratio = data_args.image_aspect_ratio
        model.config.tokenizer_padding_side = tokenizer.padding_side
        model.config.tokenizer_model_max_length = tokenizer.model_max_length

        model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
        if model_args.tune_mm_mlp_adapter:
            # Train only the projector.
            model.requires_grad_(False)
            for p in model.get_model().mm_projector.parameters():
                p.requires_grad = True

        model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
        if training_args.freeze_mm_mlp_adapter:
            for p in model.get_model().mm_projector.parameters():
                p.requires_grad = False

        if training_args.bits in [4, 8]:
            model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device)

        model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
        model.config.mm_projector_lr = training_args.mm_projector_lr
        training_args.use_im_start_end = model_args.mm_use_im_start_end
        model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
        model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)

    if training_args.bits in [4, 8]:
        from peft.tuners.lora import LoraLayer
        for name, module in model.named_modules():
            if isinstance(module, LoraLayer):
                if training_args.bf16:
                    module = module.to(torch.bfloat16)
            if 'norm' in name:
                module = module.to(torch.float32)
            if 'lm_head' in name or 'embed_tokens' in name:
                if hasattr(module, 'weight'):
                    if training_args.bf16 and module.weight.dtype == torch.float32:
                        module = module.to(torch.bfloat16)

    data_module = make_supervised_data_module(tokenizer=tokenizer,
                                              data_args=data_args)
    trainer = LLaVATrainer(model=model,
                           tokenizer=tokenizer,
                           args=training_args,
                           **data_module)

    # Resume automatically when checkpoints already exist in output_dir.
    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()

    model.config.use_cache = True

    if training_args.lora_enable:
        state_dict = get_peft_state_maybe_zero_3(
            model.named_parameters(), training_args.lora_bias
        )
        non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
            model.named_parameters()
        )
        if training_args.local_rank == 0 or training_args.local_rank == -1:
            model.config.save_pretrained(training_args.output_dir)
model.save_pretrained(training_args.output_dir, state_dict=state_dict) torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin')) else: safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir) if __name__ == "__main__": train() ================================================ FILE: llava/train/train_xformers.py ================================================ # Make it more memory efficient by monkey patching the LLaMA model with xformers attention. # Need to call this before importing transformers. from llava.train.train import train from llava.train.llama_xformers_attn_monkey_patch import ( replace_llama_attn_with_xformers_attn, ) replace_llama_attn_with_xformers_attn() if __name__ == "__main__": train() ================================================ FILE: llava/utils.py ================================================ import datetime import logging import logging.handlers import os import sys import requests from llava.constants import LOGDIR server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 
handler = None def build_logger(logger_name, logger_filename): global handler formatter = logging.Formatter( fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S", ) # Set the format of root handlers if not logging.getLogger().handlers: logging.basicConfig(level=logging.INFO) logging.getLogger().handlers[0].setFormatter(formatter) # Redirect stdout and stderr to loggers stdout_logger = logging.getLogger("stdout") stdout_logger.setLevel(logging.INFO) sl = StreamToLogger(stdout_logger, logging.INFO) sys.stdout = sl stderr_logger = logging.getLogger("stderr") stderr_logger.setLevel(logging.ERROR) sl = StreamToLogger(stderr_logger, logging.ERROR) sys.stderr = sl # Get logger logger = logging.getLogger(logger_name) logger.setLevel(logging.INFO) # Add a file handler for all loggers if handler is None: os.makedirs(LOGDIR, exist_ok=True) filename = os.path.join(LOGDIR, logger_filename) handler = logging.handlers.TimedRotatingFileHandler( filename, when='D', utc=True, encoding='UTF-8') handler.setFormatter(formatter) for name, item in logging.root.manager.loggerDict.items(): if isinstance(item, logging.Logger): item.addHandler(handler) return logger class StreamToLogger(object): """ Fake file-like stream object that redirects writes to a logger instance. """ def __init__(self, logger, log_level=logging.INFO): self.terminal = sys.stdout self.logger = logger self.log_level = log_level self.linebuf = '' def __getattr__(self, attr): return getattr(self.terminal, attr) def write(self, buf): temp_linebuf = self.linebuf + buf self.linebuf = '' for line in temp_linebuf.splitlines(True): # From the io.TextIOWrapper docs: # On output, if newline is None, any '\n' characters written # are translated to the system default line separator. # By default sys.stdout.write() expects '\n' newlines and then # translates them so this is still cross platform. 
if line[-1] == '\n': self.logger.log(self.log_level, line.rstrip()) else: self.linebuf += line def flush(self): if self.linebuf != '': self.logger.log(self.log_level, self.linebuf.rstrip()) self.linebuf = '' def disable_torch_init(): """ Disable the redundant torch default initialization to accelerate model creation. """ import torch setattr(torch.nn.Linear, "reset_parameters", lambda self: None) setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) def violates_moderation(text): """ Check whether the text violates OpenAI moderation API. """ url = "https://api.openai.com/v1/moderations" headers = {"Content-Type": "application/json", "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} text = text.replace("\n", "") data = "{" + '"input": ' + f'"{text}"' + "}" data = data.encode("utf-8") try: ret = requests.post(url, headers=headers, data=data, timeout=5) flagged = ret.json()["results"][0]["flagged"] except requests.exceptions.RequestException as e: flagged = False except KeyError as e: flagged = False return flagged def pretty_print_semaphore(semaphore): if semaphore is None: return "None" return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" ================================================ FILE: model_export/README.md ================================================ # Model Export for inference on Apple Silicon Disclaimer: this is not an official recommendation, just research and exploration. ## Export Vision Encoder We found that LLaVA trainer does not save all the states needed for auto inference, predominantly used in third party libraries like `mlx-vlm`. We save additional metadata to model checkpoint directory and export the vision model using coremltools. Export vision encoder and patch the checkpoint using the instruction below. ```bash python export_vision_encoder.py --model-path /path/to/fastvlm-checkpoint ``` ## Export VLM ### Install mlx-vlm We provide a patch to `mlx-vlm` to support inference of FastVLM. 
```bash git clone https://github.com/Blaizzy/mlx-vlm.git cd mlx-vlm git checkout 1884b551bc741f26b2d54d68fa89d4e934b9a3de git apply ../fastvlm_mlx-vlm.patch pip install -e . ``` Export model using the following instruction. ```bash python -m mlx_vlm.convert --hf-path /path/to/fastvlm-checkpoint \ --mlx-path /path/to/exported-fastvlm \ --only-llm ``` To quantize the LLM, additional options can be provided as shown below. `--q-bits` specifies bits per weight, the command below exports the LLM with 8-bit quantization. ```bash python -m mlx_vlm.convert --hf-path /path/to/fastvlm-checkpoint \ --mlx-path /path/to/exported-fastvlm \ --only-llm \ -q \ --q-bits 8 # For 4-bit quantization, specify 4 ``` ### Generate The exported model can be used for inference in a python environment following the instruction below. ```bash python -m mlx_vlm.generate --model /path/to/exported-fastvlm \ --image /path/to/image.png \ --prompt "Describe the image." \ --max-tokens 256 \ --temp 0.0 ``` ## Troubleshooting We noticed that sometimes `config.json` for the LLaVA model incorrectly sets the value for `tie_word_embeddings`. This causes the following error during conversion, `ValueError: Received parameters not in model: language_model.lm_head.weight.` If you encounter this error, set the value of `tie_word_embeddings` accordingly. ================================================ FILE: model_export/export_vision_encoder.py ================================================ # # For licensing see accompanying LICENSE file. # Copyright (C) 2025 Apple Inc. All Rights Reserved. 
# import os import json import copy import argparse import torch import numpy as np import coremltools from llava.model.builder import load_pretrained_model from llava.utils import disable_torch_init from llava.mm_utils import get_model_name_from_path def export(args): # Load model disable_torch_init() model_path = os.path.expanduser(args.model_path) model_name = get_model_name_from_path(model_path) tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name, device="mps") # Save extra metadata that is not saved during LLaVA training # required by HF for auto-loading model and for mlx-vlm preprocessing # Save image processing config setattr(image_processor, "processor_class", "LlavaProcessor") output_path = os.path.join(model_path, "preprocessor_config.json") image_processor.to_json_file(output_path) # Create processor config processor_config = dict() processor_config["image_token"] = "" processor_config["num_additional_image_tokens"] = 0 processor_config["processor_class"] = "LlavaProcessor" processor_config["patch_size"] = 64 output_path = os.path.join(model_path, "processor_config.json") json.dump(processor_config, open(output_path, "w"), indent=2) # Modify tokenizer to include special token. 
tokenizer_config_path = os.path.join(model_path, "tokenizer_config.json") tokenizer_config = json.load(open(tokenizer_config_path, 'r')) token_ids = list() image_token_is_present = False for k, v in tokenizer_config['added_tokens_decoder'].items(): token_ids.append(int(k)) if v["content"] == "": image_token_is_present = True token_ids.pop() # Append only if token is not present if not image_token_is_present: tokenizer_config['added_tokens_decoder'][f'{max(token_ids) + 1}'] = copy.deepcopy( tokenizer_config['added_tokens_decoder'][f'{token_ids[0]}']) tokenizer_config['added_tokens_decoder'][f'{max(token_ids) + 1}']["content"] = "" json.dump(tokenizer_config, open(tokenizer_config_path, 'w'), indent=2) # Modify config to contain token id for config_path = os.path.join(model_path, "config.json") model_config = json.load(open(config_path, 'r')) model_config["image_token_index"] = max(token_ids) + 1 json.dump(model_config, open(config_path, 'w'), indent=2) # Export the vision encoder to CoreML image_res = image_processor.to_dict()['size']['shortest_edge'] inputs = torch.rand(1, 3, image_res, image_res) inputs_tensor = [ coremltools.TensorType( name="images", shape=inputs.shape, ) ] vision_model = model.get_vision_tower() vision_model = vision_model.float() traced_model = torch.jit.trace(vision_model, torch.Tensor(inputs)) pt_name = "fastvithd.pt" traced_model.save(pt_name) # Export ml_model = coremltools.convert( model=pt_name, outputs=[coremltools.TensorType(name="image_features", dtype=np.float32)], inputs=inputs_tensor, convert_to="mlprogram", debug=False, compute_units=coremltools.ComputeUnit.CPU_AND_GPU, minimum_deployment_target=coremltools.target.iOS16, compute_precision=coremltools.precision.FLOAT32 ) ml_model_path = os.path.join(model_path, "fastvithd.mlpackage") ml_model.save(ml_model_path) # Remove traced model os.remove(pt_name) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model-path", type=str, required=True) 
parser.add_argument("--model-base", type=str, default=None) parser.add_argument("--conv-mode", type=str, default="qwen_2") args = parser.parse_args() export(args) ================================================ FILE: model_export/fastvlm_mlx-vlm.patch ================================================ diff --git a/mlx_vlm/convert.py b/mlx_vlm/convert.py index 5952a88..335e9db 100644 --- a/mlx_vlm/convert.py +++ b/mlx_vlm/convert.py @@ -55,6 +55,12 @@ def configure_parser() -> argparse.ArgumentParser: action="store_true", default=False, ) + parser.add_argument( + "--only-llm", + help="Convert only LLM.", + action="store_true", + default=False, + ) return parser diff --git a/mlx_vlm/models/fastvlm/__init__.py b/mlx_vlm/models/fastvlm/__init__.py new file mode 100644 index 0000000..691192e --- /dev/null +++ b/mlx_vlm/models/fastvlm/__init__.py @@ -0,0 +1,7 @@ +from .fastvlm import ( + LanguageModel, + Model, + ModelConfig, + TextConfig, + VisionConfig, +) diff --git a/mlx_vlm/models/fastvlm/fastvlm.py b/mlx_vlm/models/fastvlm/fastvlm.py new file mode 100644 index 0000000..7db6497 --- /dev/null +++ b/mlx_vlm/models/fastvlm/fastvlm.py @@ -0,0 +1,187 @@ +import glob +import inspect +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + +import mlx.core as mx +import mlx.nn as nn +import numpy as np +import coremltools +from huggingface_hub import snapshot_download + +from .language import LanguageModel, TextConfig + + +@dataclass +class VisionConfig: + mm_hidden_size: int + mm_vision_tower: str + + @classmethod + def from_dict(cls, params): + return cls( + **{ + k: v + for k, v in params.items() + if k in inspect.signature(cls).parameters + } + ) + +@dataclass +class ModelConfig: + text_config: TextConfig + vision_config: VisionConfig + model_type: str + ignore_index: int = -100 + image_token_index: int = 32000 + vision_feature_select_strategy: str = "default" + vision_feature_layer: int = -2 + vocab_size: int = 151936 + + 
@classmethod + def from_dict(cls, params): + # Copy text config parameters from root level + params["text_config"] = dict( + filter(lambda x: 'mm' not in x[0], params.items()) + ) + # Copy vision config parameters from root level + params["vision_config"] = dict( + filter(lambda x: 'mm' in x[0], params.items()) + ) + + return cls( + **{ + k: v + for k, v in params.items() + if k in inspect.signature(cls).parameters + } + ) + + +class FastVLMMultiModalProjector(nn.Module): + def __init__(self, config: ModelConfig): + super().__init__() + self.linear_0 = nn.Linear( + config.vision_config.mm_hidden_size, config.text_config.hidden_size, bias=True + ) + self.gelu = nn.GELU() + self.linear_2 = nn.Linear( + config.text_config.hidden_size, config.text_config.hidden_size, bias=True + ) + + def __call__(self, x: mx.array) -> mx.array: + x = self.linear_0(x) + x = self.gelu(x) + x = self.linear_2(x) + return x + + +class Model(nn.Module): + def __init__(self, config: ModelConfig): + super().__init__() + self.config = config + self.vision_tower = None + self.language_model = LanguageModel(config.text_config) + self.multi_modal_projector = FastVLMMultiModalProjector(config) + self.vision_feature_layer = config.vision_feature_layer + self.vision_feature_select_strategy = config.vision_feature_select_strategy + + def get_input_embeddings( + self, + input_ids: Optional[mx.array] = None, + pixel_values: Optional[mx.array] = None, + ): + if pixel_values is None: + return self.language_model.model.embed_tokens(input_ids) + + # Get the input embeddings from the language model + inputs_embeds = self.language_model.model.embed_tokens(input_ids) + + # Get image features from CoreML model + coreml_out_dict = self.vision_tower.predict({"images": np.array(pixel_values, copy=False)}) + + # Pass image features through the multi-modal projector + image_features = self.multi_modal_projector(mx.array(coreml_out_dict["image_features"])) + + # Insert special image tokens in the input_ids + 
final_inputs_embeds = self._merge_input_ids_with_image_features( + image_features, inputs_embeds, input_ids + ) + return final_inputs_embeds + + def _merge_input_ids_with_image_features( + self, image_features, inputs_embeds, input_ids + ): + image_token_index = self.config.image_token_index + num_images, num_image_patches, embed_dim = image_features.shape + + # Positions of tokens in input_ids, assuming batch size is 1 + image_positions = np.where(input_ids[0] == image_token_index)[0].tolist() + num_images, _, vision_hidden_size = image_features.shape + + reshaped_image_hidden_states = image_features.reshape(-1, vision_hidden_size) + + # cast to the dtype of the input_embeds to support quantized models + reshaped_image_hidden_states = reshaped_image_hidden_states.astype( + inputs_embeds.dtype + ) + inputs_embeds[:, image_positions, :] = reshaped_image_hidden_states + return inputs_embeds + + def __call__( + self, + input_ids: mx.array, + pixel_values: mx.array, + mask: mx.array, + cache=None, + **kwargs, + ): + input_embddings = self.get_input_embeddings(input_ids, pixel_values) + logits = self.language_model( + input_ids, cache=cache, inputs_embeds=input_embddings + ) + return logits + + @staticmethod + def from_pretrained(path_or_hf_repo: str): + path = Path(path_or_hf_repo) + if not path.exists(): + path = Path( + snapshot_download( + repo_id=path_or_hf_repo, + allow_patterns=[ + "*.json", + "*.safetensors", + "*.py", + "tokenizer.model", + "*.tiktoken", + ], + ) + ) + + with open(path / "config.json", "r") as f: + model_config = json.load(f) + + model_config = ModelConfig.from_dict(model_config) + model_config.text_config = TextConfig.from_dict(model_config.text_config) + + model = Model(model_config) + weight_files = glob.glob(str(path / "*.safetensors")) + if not weight_files: + raise FileNotFoundError(f"No safetensors found in {path}") + + weights = {} + for wf in weight_files: + weights.update(mx.load(wf)) + + weights = LanguageModel.sanitize(weights) + + 
# Load CoreML vision tower + coreml_file = glob.glob(str(path / "*.mlpackage")) + assert len(coreml_file) == 1, "Found multiple vision model files" + model.vision_tower = coremltools.models.MLModel(coreml_file[0]) + + model.load_weights(list(weights.items())) + return model diff --git a/mlx_vlm/models/fastvlm/language.py b/mlx_vlm/models/fastvlm/language.py new file mode 100644 index 0000000..b791df4 --- /dev/null +++ b/mlx_vlm/models/fastvlm/language.py @@ -0,0 +1,220 @@ +import inspect +from dataclasses import dataclass +from typing import Dict, Optional, Tuple, Union + +import mlx.core as mx +import mlx.nn as nn +import numpy as np + +from ..base import KVCache, LanguageModelOutput, create_attention_mask + + +@dataclass +class TextConfig: + model_type: str + hidden_size: int + num_hidden_layers: int + intermediate_size: int + num_attention_heads: int + rms_norm_eps: float + vocab_size: int + num_key_value_heads: Optional[int] = None + max_position_embeddings: Optional[int] = 32768 + rope_theta: float = 1000000 + rope_traditional: bool = False + rope_scaling: Optional[Dict[str, Union[float, str]]] = None + tie_word_embeddings: bool = True + + def __post_init__(self): + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + + if self.rope_scaling: + required_keys = {"mrope_section", "type"} + if not all(key in self.rope_scaling for key in required_keys): + raise ValueError(f"rope_scaling must contain keys {required_keys}") + + if not self.rope_scaling["type"] in ["mrope", "default"]: + raise ValueError(f"rope_scaling type must be 'mrope' or 'default'") + + @classmethod + def from_dict(cls, params): + return cls( + **{ + k: v + for k, v in params.items() + if k in inspect.signature(cls).parameters + } + ) + + +class Attention(nn.Module): + def __init__(self, args: TextConfig): + super().__init__() + + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + assert args.num_key_value_heads is not None + 
self.n_kv_heads = n_kv_heads = args.num_key_value_heads + + self.head_dim = head_dim = args.hidden_size // n_heads + self.scale = head_dim**-0.5 + + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) + + self.rotary_emb = nn.RoPE( + head_dim, + base=args.rope_theta, + traditional=args.rope_traditional, + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[KVCache] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, self.n_heads, self.head_dim).transpose( + 0, 2, 1, 3 + ) + keys = keys.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, self.head_dim).transpose( + 0, 2, 1, 3 + ) + + offset = cache.offset if cache else 0 + + if mask is not None: + mask = mask[..., : keys.shape[-2]] + + queries = self.rotary_emb(queries, offset=offset) + keys = self.rotary_emb(keys, offset=offset) + + if cache is not None: + keys, values = cache.update_and_fetch(keys, values) + + output = mx.fast.scaled_dot_product_attention( + queries, keys, values, scale=self.scale, mask=mask + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output) + + +class MLP(nn.Module): + def __init__(self, dim, hidden_dim): + super().__init__() + self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) + self.down_proj = nn.Linear(hidden_dim, dim, bias=False) + self.up_proj = nn.Linear(dim, hidden_dim, bias=False) + + def __call__(self, x) -> mx.array: + return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class Qwen2DecoderLayer(nn.Module): + def __init__(self, args: TextConfig): 
+ super().__init__() + self.num_attention_heads = args.num_attention_heads + self.hidden_size = args.hidden_size + self.self_attn = Attention(args) + self.mlp = MLP(args.hidden_size, args.intermediate_size) + self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + self.args = args + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[KVCache] = None, + ) -> mx.array: + r = self.self_attn(self.input_layernorm(x), mask, cache) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + return out + + +class Qwen2Model(nn.Module): + def __init__(self, args: TextConfig): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.num_hidden_layers = args.num_hidden_layers + assert self.vocab_size > 0 + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + Qwen2DecoderLayer(args=args) for _ in range(args.num_hidden_layers) + ] + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + def __call__( + self, + inputs: mx.array, + cache=None, + inputs_embeds: Optional[mx.array] = None, + ): + if inputs_embeds is None: + h = self.embed_tokens(inputs) + else: + h = inputs_embeds + + mask = create_attention_mask(h, cache) + + if cache is None: + cache = [None] * len(self.layers) + + for layer, c in zip(self.layers, cache): + h = layer(h, mask, c) + + return self.norm(h) + + +class LanguageModel(nn.Module): + def __init__(self, args: TextConfig): + super().__init__() + self.args = args + self.model_type = args.model_type + self.model = Qwen2Model(args) + + if "qwen2" not in args.model_type: + raise ValueError(f"Unsupported model type: {args.model_type}") + + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache=None, + inputs_embeds: 
Optional[mx.array] = None, + mask: Optional[mx.array] = None, + ): + out = self.model(inputs, cache=cache, inputs_embeds=inputs_embeds) + if self.args.tie_word_embeddings: + out = self.model.embed_tokens.as_linear(out) + else: + out = self.lm_head(out) + return LanguageModelOutput(logits=out) + + @property + def layers(self): + return self.model.layers + + @property + def head_dim(self): + return self.args.hidden_size // self.args.num_attention_heads + + @property + def n_kv_heads(self): + return self.args.num_key_value_heads diff --git a/mlx_vlm/prompt_utils.py b/mlx_vlm/prompt_utils.py index 725e811..ba48296 100644 --- a/mlx_vlm/prompt_utils.py +++ b/mlx_vlm/prompt_utils.py @@ -93,6 +93,7 @@ def get_message_json( "idefics2": "message_list_with_image", "idefics3": "message_list_with_image", "llava": "message_list_with_image", + "llava_qwen2": "message_with_image_token_new_line", "llava_next": "message_list_with_image", "mllama": "message_list_with_image", # Models that can handle both image and video formats @@ -143,7 +144,7 @@ def get_message_json( def get_chat_template(processor, messages, add_generation_prompt, tokenize=False): - if "chat_template" in processor.__dict__.keys(): + if ("chat_template" in processor.__dict__.keys()) and (processor.chat_template is not None): return processor.apply_chat_template( messages, tokenize=tokenize, diff --git a/mlx_vlm/utils.py b/mlx_vlm/utils.py index 4acff3e..00f366f 100644 --- a/mlx_vlm/utils.py +++ b/mlx_vlm/utils.py @@ -1,3 +1,4 @@ +import os import copy import glob import importlib @@ -15,6 +16,7 @@ import mlx.core as mx import mlx.nn as nn import numpy as np import requests +import coremltools from huggingface_hub import snapshot_download from mlx.utils import tree_flatten, tree_unflatten from PIL import Image, ImageOps @@ -31,7 +33,7 @@ from .tokenizer_utils import load_tokenizer from .trainer import apply_lora_layers # Constants -MODEL_REMAPPING = {"llava-qwen2": "llava_bunny", "bunny-llama": "llava_bunny"} 
+MODEL_REMAPPING = {"llava-qwen2": "llava_bunny", "bunny-llama": "llava_bunny", "llava_qwen2": "fastvlm"} MAX_FILE_SIZE_GB = 5 @@ -168,9 +170,19 @@ python -m mlx_vlm.convert --hf-path --mlx-path # Sanitize weights weights = sanitize_weights(model, weights) - weights = sanitize_weights( - model_class.VisionModel, weights, model_config.vision_config - ) + if hasattr(model_class, 'VisionModel'): + weights = sanitize_weights( + model_class.VisionModel, weights, model_config.vision_config + ) + else: + # Load CoreML vision tower + print("Looking for CoreML vision tower") + coreml_file = glob.glob(str(model_path / "*.mlpackage")) + if len(coreml_file) > 0: + assert len(coreml_file) == 1, "Found multiple vision model files." + print(f"Loading {coreml_file[0]} vision tower") + model.vision_tower = coremltools.models.MLModel(coreml_file[0]) + weights = sanitize_weights( model_class.LanguageModel, weights, model_config.text_config ) @@ -185,7 +197,21 @@ python -m mlx_vlm.convert --hf-path --mlx-path class_predicate=class_predicate, ) - model.load_weights(list(weights.items())) + if kwargs.get("only_llm", False): + # Ignore vision tower weights + new_weights = dict() + for k, v in weights.items(): + if 'vision_tower' in k: + continue + if 'mm_projector' in k: + new_k = k.replace('model.mm_projector.', 'multi_modal_projector.linear_') + new_weights[new_k] = v + else: + new_weights['language_model.'+k] = v + + model.load_weights(list(new_weights.items())) + else: + model.load_weights(list(weights.items())) if not lazy: mx.eval(model.parameters()) @@ -669,11 +695,12 @@ def convert( dequantize: bool = False, skip_vision: bool = False, trust_remote_code: bool = True, + only_llm: bool = False ): print("[INFO] Loading") model_path = get_model_path(hf_path, revision=revision) model, config, processor = fetch_from_hub( - model_path, lazy=True, trust_remote_code=trust_remote_code + model_path, lazy=True, trust_remote_code=trust_remote_code, only_llm=only_llm ) weights = 
dict(tree_flatten(model.parameters())) @@ -709,6 +736,12 @@ def convert( save_config(config, config_path=mlx_path / "config.json") + # Copy over any coreml files if found + coreml_files = glob.glob(str(model_path / "*.mlpackage")) + for file in coreml_files: + des_path = os.path.join(mlx_path, file.split(os.path.sep)[-1]) + shutil.copytree(file, des_path) + if upload_repo is not None: upload_to_hub(mlx_path, upload_repo, hf_path) ================================================ FILE: predict.py ================================================ # # Modified from LLaVA/predict.py # Please see ACKNOWLEDGEMENTS for details about LICENSE # import os import argparse import torch from PIL import Image from llava.utils import disable_torch_init from llava.conversation import conv_templates from llava.model.builder import load_pretrained_model from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN def predict(args): # Remove generation config from model folder # to read generation parameters from args model_path = os.path.expanduser(args.model_path) generation_config = None if os.path.exists(os.path.join(model_path, 'generation_config.json')): generation_config = os.path.join(model_path, '.generation_config.json') os.rename(os.path.join(model_path, 'generation_config.json'), generation_config) # Load model disable_torch_init() model_name = get_model_name_from_path(model_path) tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name, device="mps") # Construct prompt qs = args.prompt if model.config.mm_use_im_start_end: qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs else: qs = DEFAULT_IMAGE_TOKEN + '\n' + qs conv = conv_templates[args.conv_mode].copy() conv.append_message(conv.roles[0], qs) conv.append_message(conv.roles[1], None) prompt = 
conv.get_prompt() # Set the pad token id for generation model.generation_config.pad_token_id = tokenizer.pad_token_id # Tokenize prompt input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(torch.device("mps")) # Load and preprocess image image = Image.open(args.image_file).convert('RGB') image_tensor = process_images([image], image_processor, model.config)[0] # Run inference with torch.inference_mode(): output_ids = model.generate( input_ids, images=image_tensor.unsqueeze(0).half(), image_sizes=[image.size], do_sample=True if args.temperature > 0 else False, temperature=args.temperature, top_p=args.top_p, num_beams=args.num_beams, max_new_tokens=256, use_cache=True) outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() print(outputs) # Restore generation config if generation_config is not None: os.rename(generation_config, os.path.join(model_path, 'generation_config.json')) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model-path", type=str, default="./llava-v1.5-13b") parser.add_argument("--model-base", type=str, default=None) parser.add_argument("--image-file", type=str, default=None, help="location of image file") parser.add_argument("--prompt", type=str, default="Describe the image.", help="Prompt for VLM.") parser.add_argument("--conv-mode", type=str, default="qwen_2") parser.add_argument("--temperature", type=float, default=0.2) parser.add_argument("--top_p", type=float, default=None) parser.add_argument("--num_beams", type=int, default=1) args = parser.parse_args() predict(args) ================================================ FILE: pyproject.toml ================================================ [build-system] requires = ["setuptools>=61.0"] build-backend = "setuptools.build_meta" [project] name = "llava" version = "1.2.2.post1" description = "Towards GPT-4 like large language and visual assistant." 
readme = "README.md" requires-python = ">=3.8" classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: Apache Software License", ] dependencies = [ "torch==2.6.0", "torchvision==0.21.0", "transformers==4.48.3", "tokenizers==0.21.0", "sentencepiece==0.1.99", "shortuuid", "accelerate==1.6.0", "peft>=0.10.0,<0.14.0", "bitsandbytes", "pydantic", "markdown2[all]", "numpy==1.26.4", "scikit-learn==1.2.2", "gradio==5.11.0", "requests", "uvicorn", "fastapi", "einops==0.6.1", "einops-exts==0.0.4", "timm==1.0.15", "coremltools==8.2" ] [project.optional-dependencies] train = ["deepspeed==0.13.1", "ninja", "wandb"] build = ["build", "twine"] [tool.setuptools.packages.find] exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] [tool.wheel] exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]