Repository: apple/ml-fastvlm
Branch: main
Commit: 592b4add3c1c
Files: 74
Total size: 506.0 KB
Directory structure:
gitextract_4fmnuh7_/
├── .gitignore
├── ACKNOWLEDGEMENTS
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── LICENSE_MODEL
├── README.md
├── app/
│ ├── Configuration/
│ │ └── Build.xcconfig
│ ├── FastVLM/
│ │ ├── FastVLM.h
│ │ ├── FastVLM.swift
│ │ └── MediaProcessingExtensions.swift
│ ├── FastVLM App/
│ │ ├── Assets.xcassets/
│ │ │ ├── AccentColor.colorset/
│ │ │ │ └── Contents.json
│ │ │ ├── AppIcon.appiconset/
│ │ │ │ └── Contents.json
│ │ │ └── Contents.json
│ │ ├── ContentView.swift
│ │ ├── FastVLM.entitlements
│ │ ├── FastVLMApp.swift
│ │ ├── FastVLMModel.swift
│ │ ├── Info.plist
│ │ ├── InfoView.swift
│ │ └── Preview Content/
│ │ └── Preview Assets.xcassets/
│ │ └── Contents.json
│ ├── FastVLM.xcodeproj/
│ │ ├── project.pbxproj
│ │ └── xcshareddata/
│ │ └── xcschemes/
│ │ └── FastVLM App.xcscheme
│ ├── README.md
│ ├── Video/
│ │ ├── CameraController.swift
│ │ ├── CameraControlsView.swift
│ │ ├── CameraType.swift
│ │ ├── Video.h
│ │ └── VideoFrameView.swift
│ └── get_pretrained_mlx_model.sh
├── get_models.sh
├── llava/
│ ├── __init__.py
│ ├── constants.py
│ ├── conversation.py
│ ├── mm_utils.py
│ ├── model/
│ │ ├── __init__.py
│ │ ├── apply_delta.py
│ │ ├── builder.py
│ │ ├── consolidate.py
│ │ ├── language_model/
│ │ │ ├── llava_llama.py
│ │ │ ├── llava_mistral.py
│ │ │ ├── llava_mpt.py
│ │ │ └── llava_qwen.py
│ │ ├── llava_arch.py
│ │ ├── make_delta.py
│ │ ├── multimodal_encoder/
│ │ │ ├── builder.py
│ │ │ ├── clip_encoder.py
│ │ │ ├── mobileclip/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── configs/
│ │ │ │ │ └── mobileclip_l.json
│ │ │ │ └── mci.py
│ │ │ └── mobileclip_encoder.py
│ │ ├── multimodal_projector/
│ │ │ └── builder.py
│ │ └── utils.py
│ ├── serve/
│ │ ├── __init__.py
│ │ ├── cli.py
│ │ ├── controller.py
│ │ ├── gradio_web_server.py
│ │ ├── model_worker.py
│ │ ├── register_worker.py
│ │ ├── sglang_worker.py
│ │ └── test_message.py
│ ├── train/
│ │ ├── llama_flash_attn_monkey_patch.py
│ │ ├── llama_xformers_attn_monkey_patch.py
│ │ ├── llava_trainer.py
│ │ ├── train.py
│ │ ├── train_mem.py
│ │ ├── train_qwen.py
│ │ └── train_xformers.py
│ └── utils.py
├── model_export/
│ ├── README.md
│ ├── export_vision_encoder.py
│ └── fastvlm_mlx-vlm.patch
├── predict.py
└── pyproject.toml
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# macOS
**/.DS_Store
# PyCharm project settings
.idea/
# Xcode
*.xcworkspace
# FastVLM models
app/FastVLM/model
================================================
FILE: ACKNOWLEDGEMENTS
================================================
Acknowledgements
Portions of this Software may utilize the following copyrighted
material, the use of which is hereby acknowledged.
---------------------------------------------------------------------------------
LLaVA: Large Language and Vision Assistant (LLaVA)
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
---------------------------------------------------------------------------------
FastViT (ml-fastvit)
Copyright (C) 2023 Apple Inc. All Rights Reserved.
IMPORTANT: This Apple software is supplied to you by Apple
Inc. ("Apple") in consideration of your agreement to the following
terms, and your use, installation, modification or redistribution of
this Apple software constitutes acceptance of these terms. If you do
not agree with these terms, please do not use, install, modify or
redistribute this Apple software.
In consideration of your agreement to abide by the following terms, and
subject to these terms, Apple grants you a personal, non-exclusive
license, under Apple's copyrights in this original Apple software (the
"Apple Software"), to use, reproduce, modify and redistribute the Apple
Software, with or without modifications, in source and/or binary forms;
provided that if you redistribute the Apple Software in its entirety and
without modifications, you must retain this notice and the following
text and disclaimers in all such redistributions of the Apple Software.
Neither the name, trademarks, service marks or logos of Apple Inc. may
be used to endorse or promote products derived from the Apple Software
without specific prior written permission from Apple. Except as
expressly stated in this notice, no other rights or licenses, express or
implied, are granted by Apple herein, including but not limited to any
patent rights that may be infringed by your derivative works or by other
works in which the Apple Software may be incorporated.
The Apple Software is provided by Apple on an "AS IS" basis. APPLE
MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION
THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS
FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND
OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.
IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL
OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION,
MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED
AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE),
STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
---------------------------------------------------------------------------------
mlx-vlm
MIT License
Copyright © 2023 Apple Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
---------------------------------------------------------------------------------
MobileCLIP (ml-mobileclip)
MIT License
Copyright © 2024 Apple Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
-------------------------------------------------------------------------------
SOFTWARE DISTRIBUTED WITH ML-MobileCLIP:
The ML-MobileCLIP model weights and data copyright and license terms can be
found in LICENSE_weights_data.
The ML-MobileCLIP software includes a number of subcomponents with separate
copyright notices and license terms - please see the file ACKNOWLEDGEMENTS.
---------------------------------------------------------------------------------
================================================
FILE: CODE_OF_CONDUCT.md
================================================
# Code of Conduct
## Our Pledge
In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to making participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.
## Our Standards
Examples of behavior that contributes to creating a positive environment
include:
* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members
Examples of unacceptable behavior by participants include:
* The use of sexualized language or imagery and unwelcome sexual attention or
advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Our Responsibilities
Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.
Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.
## Scope
This Code of Conduct applies within all project spaces, and it also applies when
an individual is representing the project or its community in public spaces.
Examples of representing a project or community include using an official
project e-mail address, posting via an official social media account, or acting
as an appointed representative at an online or offline event. Representation of
a project may be further defined and clarified by project maintainers.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.
Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4,
available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html)
================================================
FILE: CONTRIBUTING.md
================================================
# Contribution Guide
Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository.
While we welcome new pull requests and issues, please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged.
## Before you get started
By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE).
We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md).
================================================
FILE: LICENSE
================================================
Copyright (C) 2025 Apple Inc. All Rights Reserved.
IMPORTANT: This Apple software is supplied to you by Apple
Inc. ("Apple") in consideration of your agreement to the following
terms, and your use, installation, modification or redistribution of
this Apple software constitutes acceptance of these terms. If you do
not agree with these terms, please do not use, install, modify or
redistribute this Apple software.
In consideration of your agreement to abide by the following terms, and
subject to these terms, Apple grants you a personal, non-exclusive
license, under Apple's copyrights in this original Apple software (the
"Apple Software"), to use, reproduce, modify and redistribute the Apple
Software, with or without modifications, in source and/or binary forms;
provided that if you redistribute the Apple Software in its entirety and
without modifications, you must retain this notice and the following
text and disclaimers in all such redistributions of the Apple Software.
Neither the name, trademarks, service marks or logos of Apple Inc. may
be used to endorse or promote products derived from the Apple Software
without specific prior written permission from Apple. Except as
expressly stated in this notice, no other rights or licenses, express or
implied, are granted by Apple herein, including but not limited to any
patent rights that may be infringed by your derivative works or by other
works in which the Apple Software may be incorporated.
The Apple Software is provided by Apple on an "AS IS" basis. APPLE
MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION
THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS
FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND
OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.
IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL
OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION,
MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED
AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE),
STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
-------------------------------------------------------------------------------
SOFTWARE DISTRIBUTED WITH ML-FASTVLM:
The ml-fastvlm software includes a number of subcomponents with separate
copyright notices and license terms - please see the file ACKNOWLEDGEMENTS.
The ml-fastvlm model weights copyright and license terms can be
found in LICENSE_MODEL file.
-------------------------------------------------------------------------------
================================================
FILE: LICENSE_MODEL
================================================
Disclaimer: IMPORTANT: This Apple Machine Learning Research Model is
specifically developed and released by Apple Inc. ("Apple") for the sole purpose
of scientific research of artificial intelligence and machine-learning
technology. “Apple Machine Learning Research Model” means the model, including
but not limited to algorithms, formulas, trained model weights, parameters,
configurations, checkpoints, and any related materials (including
documentation).
This Apple Machine Learning Research Model is provided to You by
Apple in consideration of your agreement to the following terms, and your use,
modification, creation of Model Derivatives, and or redistribution of the Apple
Machine Learning Research Model constitutes acceptance of this Agreement. If You
do not agree with these terms, please do not use, modify, create Model
Derivatives of, or distribute this Apple Machine Learning Research Model or
Model Derivatives.
* License Scope: In consideration of your agreement to abide by the following
terms, and subject to these terms, Apple hereby grants you a personal,
non-exclusive, worldwide, non-transferable, royalty-free, revocable, and
limited license, to use, copy, modify, distribute, and create Model
Derivatives (defined below) of the Apple Machine Learning Research Model
exclusively for Research Purposes. You agree that any Model Derivatives You
may create or that may be created for You will be limited to Research Purposes
as well. “Research Purposes” means non-commercial scientific research and
academic development activities, such as experimentation, analysis, testing
conducted by You with the sole intent to advance scientific knowledge and
research. “Research Purposes” does not include any commercial exploitation,
product development or use in any commercial product or service.
* Distribution of Apple Machine Learning Research Model and Model Derivatives:
If you choose to redistribute Apple Machine Learning Research Model or its
Model Derivatives, you must provide a copy of this Agreement to such third
party, and ensure that the following attribution notice be provided: “Apple
Machine Learning Research Model is licensed under the Apple Machine Learning
Research Model License Agreement.” Additionally, all Model Derivatives must
clearly be identified as such, including disclosure of modifications and
changes made to the Apple Machine Learning Research Model. The name,
trademarks, service marks or logos of Apple may not be used to endorse or
promote Model Derivatives or the relationship between You and Apple. “Model
Derivatives” means any models or any other artifacts created by modifications,
improvements, adaptations, alterations to the architecture, algorithm or
training processes of the Apple Machine Learning Research Model, or by any
retraining, fine-tuning of the Apple Machine Learning Research Model.
* No Other License: Except as expressly stated in this notice, no other rights
or licenses, express or implied, are granted by Apple herein, including but
not limited to any patent, trademark, and similar intellectual property rights
worldwide that may be infringed by the Apple Machine Learning Research Model,
the Model Derivatives or by other works in which the Apple Machine Learning
Research Model may be incorporated.
* Compliance with Laws: Your use of Apple Machine Learning Research Model must
be in compliance with all applicable laws and regulations.
* Term and Termination: The term of this Agreement will begin upon your
acceptance of this Agreement or use of the Apple Machine Learning Research
Model and will continue until terminated in accordance with the following
terms. Apple may terminate this Agreement at any time if You are in breach of
any term or condition of this Agreement. Upon termination of this Agreement,
You must cease to use all Apple Machine Learning Research Models and Model
Derivatives and permanently delete any copy thereof. Sections 3, 6 and 7 will
survive termination.
* Disclaimer and Limitation of Liability: This Apple Machine Learning Research
Model and any outputs generated by the Apple Machine Learning Research Model
are provided on an “AS IS” basis. APPLE MAKES NO WARRANTIES, EXPRESS OR
IMPLIED, INCLUDING WITHOUT LIMITATION THE IMPLIED WARRANTIES OF
NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE,
REGARDING THE APPLE MACHINE LEARNING RESEARCH MODEL OR OUTPUTS GENERATED BY
THE APPLE MACHINE LEARNING RESEARCH MODEL. You are solely responsible for
determining the appropriateness of using or redistributing the Apple Machine
Learning Research Model and any outputs of the Apple Machine Learning Research
Model and assume any risks associated with Your use of the Apple Machine
Learning Research Model and any output and results. IN NO EVENT SHALL APPLE BE
LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING
IN ANY WAY OUT OF THE USE, REPRODUCTION, MODIFICATION AND/OR DISTRIBUTION OF
THE APPLE MACHINE LEARNING RESEARCH MODEL AND ANY OUTPUTS OF THE APPLE MACHINE
LEARNING RESEARCH MODEL, HOWEVER CAUSED AND WHETHER UNDER THEORY OF CONTRACT,
TORT (INCLUDING NEGLIGENCE), STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS
BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
* Governing Law: This Agreement will be governed by and construed under the laws
of the State of California without regard to its choice of law principles. The
Convention on Contracts for the International Sale of Goods shall not apply to
the Agreement except that the arbitration clause and any arbitration hereunder
shall be governed by the Federal Arbitration Act, Chapters 1 and 2.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
================================================
FILE: README.md
================================================
# FastVLM: Efficient Vision Encoding for Vision Language Models
This is the official repository of
**[FastVLM: Efficient Vision Encoding for Vision Language Models](https://www.arxiv.org/abs/2412.13303). (CVPR 2025)**
[//]: # ()
### Highlights
* We introduce FastViTHD, a novel hybrid vision encoder designed to output fewer tokens and significantly reduce encoding time for high-resolution images.
* Our smallest variant outperforms LLaVA-OneVision-0.5B with 85x faster Time-to-First-Token (TTFT) and 3.4x smaller vision encoder.
* Our larger variants using Qwen2-7B LLM outperform recent works like Cambrian-1-8B while using a single image encoder with a 7.9x faster TTFT.
* Demo iOS app to demonstrate the performance of our model on a mobile device.
## Getting Started
We use the LLaVA codebase to train FastVLM variants. To train or fine-tune your own variants,
please follow the instructions provided in the [LLaVA](https://github.com/haotian-liu/LLaVA) codebase.
Below, we provide instructions for running inference with our models.
### Setup
```bash
conda create -n fastvlm python=3.10
conda activate fastvlm
pip install -e .
```
### Model Zoo
For detailed information on various evaluations, please refer to our [paper](https://www.arxiv.org/abs/2412.13303).
| Model | Stage | Pytorch Checkpoint (url) |
|:-------------|:-----:|:---------------------------------------------------------------------------------------------------------------:|
| FastVLM-0.5B | 2 | [fastvlm_0.5b_stage2](https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_0.5b_stage2.zip) |
| | 3 | [fastvlm_0.5b_stage3](https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_0.5b_stage3.zip) |
| FastVLM-1.5B | 2 | [fastvlm_1.5b_stage2](https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_1.5b_stage2.zip) |
| | 3 | [fastvlm_1.5b_stage3](https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_1.5b_stage3.zip) |
| FastVLM-7B | 2 | [fastvlm_7b_stage2](https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_7b_stage2.zip) |
| | 3 | [fastvlm_7b_stage3](https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_7b_stage3.zip) |
To download all the pretrained checkpoints, run the command below (note that this might take some time depending on your connection, so it might be a good time to grab ☕️ while you wait).
```bash
bash get_models.sh # Files will be downloaded to `checkpoints` directory.
```
### Usage Example
To run inference with a PyTorch checkpoint, follow the instructions below:
```bash
python predict.py --model-path /path/to/checkpoint-dir \
--image-file /path/to/image.png \
--prompt "Describe the image."
```
### Inference on Apple Silicon
To run inference on Apple Silicon, PyTorch checkpoints have to be exported to a format
suitable for running on Apple Silicon; detailed instructions and code can be found in the [`model_export`](model_export/) subfolder.
Please see the README there for more details.
For convenience, we provide 3 models in an Apple Silicon-compatible format: [fastvlm_0.5b_stage3](https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_0.5b_stage3_llm.fp16.zip),
[fastvlm_1.5b_stage3](https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_1.5b_stage3_llm.int8.zip),
[fastvlm_7b_stage3](https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_7b_stage3_llm.int4.zip).
We encourage developers to export the model of their choice with the appropriate quantization levels following
the instructions in [`model_export`](model_export/).
### Inference on Apple Devices
To run inference on Apple devices like iPhone, iPad or Mac, see the [`app`](app/) subfolder for more details.
## Citation
If you found this code useful, please cite the following paper:
```
@InProceedings{fastvlm2025,
author = {Pavan Kumar Anasosalu Vasu, Fartash Faghri, Chun-Liang Li, Cem Koc, Nate True, Albert Antony, Gokul Santhanam, James Gabriel, Peter Grasch, Oncel Tuzel, Hadi Pouransari},
title = {FastVLM: Efficient Vision Encoding for Vision Language Models},
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2025},
}
```
## Acknowledgements
Our codebase is built using multiple open-source contributions; please see [ACKNOWLEDGEMENTS](ACKNOWLEDGEMENTS) for more details.
## License
Please check out the repository [LICENSE](LICENSE) before using the provided code and
[LICENSE_MODEL](LICENSE_MODEL) for the released models.
================================================
FILE: app/Configuration/Build.xcconfig
================================================
// The `DISAMBIGUATOR` configuration is to make it easier to build
// and run a sample code project. Once you set your project's development team,
// you'll have a unique bundle identifier. This is because the bundle identifier
// is derived based on the 'DISAMBIGUATOR' value. Do not use this
// approach in your own projects—it's only useful for example projects because
// they are frequently downloaded and don't have a development team set.
DISAMBIGUATOR=${DEVELOPMENT_TEAM}
================================================
FILE: app/FastVLM/FastVLM.h
================================================
//
// For licensing see accompanying LICENSE file.
// Copyright (C) 2025 Apple Inc. All Rights Reserved.
//
#ifndef FastVLM_h
#define FastVLM_h
#endif /* FastVLM_h */
================================================
FILE: app/FastVLM/FastVLM.swift
================================================
//
// For licensing see accompanying LICENSE file.
// Copyright (C) 2025 Apple Inc. All Rights Reserved.
//
import CoreImage
import CoreML
import Foundation
import MLX
import MLXFast
import MLXLMCommon
import MLXNN
import MLXVLM
import Tokenizers
// FastVLM is Qwen2VL with a custom vision tower.
// MARK: - Common
/// Rotates half the hidden dims of the input.
///
/// The last axis is split into a lower half and an upper half; the result is the upper
/// half negated, followed by the lower half (`[x1, x2] -> [-x2, x1]`).
private func rotateHalf(_ x: MLXArray) -> MLXArray {
    let half = x.dim(-1) / 2
    let lower = x[.ellipsis, 0 ..< half]
    let upper = x[.ellipsis, half...]
    // negated upper half moves to the front
    let rotated = [-upper, lower]
    return concatenated(rotated, axis: -1)
}
// MARK: - Language
/// Language-model half of FastVLM: a Qwen2-style decoder (`Qwen2Model`) whose token
/// embeddings can be replaced by caller-supplied embeddings (`FastVLM.prepare` passes
/// embeddings with the projected image features already merged in).
private enum Language {
/// Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors
///
/// - Note: nothing inside this enum calls this helper; `Attention` below applies plain
///   `RoPE` through `rotaryEmbedding` and never reads its `mropeSection`.
static private func applyMultimodalRotaryPositionEmbedding(
q: MLXArray, k: MLXArray, cos: MLXArray, sin: MLXArray,
positionIds: MLXArray, mropeSection: [Int]
) -> (MLXArray, MLXArray) {
// index the cos/sin tables by positionIds
var cos = cos[positionIds]
var sin = sin[positionIds]
// split the last axis at the mropeSection boundaries; chunk i takes entry i % 3 of its
// leading axis (presumably the temporal/height/width position streams — TODO confirm),
// then the chunks are re-joined and a singleton axis is inserted at position 1
cos =
concatenated(
// [m[i % 3] for i, m in enumerate(mx.split(cos, mrope_section, axis=-1))]
split(cos, indices: mropeSection, axis: -1).enumerated().map { i, m in m[i % 3] },
axis: -1
)[0..., .newAxis, 0..., 0...]
sin =
concatenated(
split(sin, indices: mropeSection, axis: -1).enumerated().map { i, m in m[i % 3] },
axis: -1
)[0..., .newAxis, 0..., 0...]
// Apply rotary embedding
let qEmbed = (q * cos) + (rotateHalf(q) * sin)
let kEmbed = (k * cos) + (rotateHalf(k) * sin)
return (qEmbed, kEmbed)
}
/// Self-attention with separate query (`heads`) and key/value (`kvHeads`) head counts,
/// rotary position embeddings, and an optional KV cache.
fileprivate class Attention: Module {
let heads: Int
let kvHeads: Int
let headDim: Int
// softmax scale, 1 / sqrt(headDim)
let scale: Float
// cumulative multimodal-RoPE split points; computed in init but not read by callAsFunction
let mropeSection: [Int]
@ModuleInfo(key: "q_proj") var wq: Linear
@ModuleInfo(key: "k_proj") var wk: Linear
@ModuleInfo(key: "v_proj") var wv: Linear
@ModuleInfo(key: "o_proj") var wo: Linear
@ModuleInfo(key: "rotary_emb") var rotaryEmbedding: RoPE
public init(_ args: FastVLMConfiguration.TextConfiguration) {
let dim = args.hiddenSize
self.heads = args.attentionHeads
self.kvHeads = args.kvHeads
self.headDim = dim / heads
self.scale = pow(Float(headDim), -0.5)
// q/k/v projections carry a bias; the output projection does not
self._wq.wrappedValue = Linear(dim, heads * headDim, bias: true)
self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: true)
self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: true)
self._wo.wrappedValue = Linear(heads * headDim, dim, bias: false)
if let v = args.ropeScaling?["mrope_section"], let array = v.asInts() {
// mrope_section = np.cumsum(mrope_section * 2)[:-1].tolist()
// e.g. [2, 1, 1] -> doubled [4, 2, 2] -> running sums [4, 6, 8] -> drop last -> [4, 6]
self.mropeSection = sequence(state: (0, array.makeIterator())) { state in
if let v = state.1.next() {
// note the *2
state.0 += v * 2
return state.0
} else {
return nil
}
}.dropLast()
} else {
// TextConfiguration.ropeScaling supplies a default when rope_scaling is absent, so this
// fires only when the provided rope_scaling lacks an integer-array mrope_section
fatalError("rope_scaling['mrope_section'] must be an array of integers")
}
self._rotaryEmbedding.wrappedValue = RoPE(
dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta)
}
/// - Parameters:
///   - x: hidden states; `x.dim(0)` is read as batch and `x.dim(1)` as sequence length.
///   - mask: optional attention mask (trimmed below before use).
///   - cache: optional KV cache; its `offset` drives the rotary positions.
/// - Returns: `o_proj` applied to the attention output, reshaped back to `[B, L, -1]`.
public func callAsFunction(
_ x: MLXArray, mask: MLXArray? = nil, cache: KVCache?
) -> MLXArray {
let (B, L) = (x.dim(0), x.dim(1))
var queries = wq(x)
var keys = wk(x)
var values = wv(x)
// prepare the queries, keys and values for the attention computation
// -> [B, heads, L, headDim] for queries, [B, kvHeads, L, headDim] for keys/values
queries = queries.reshaped(B, L, heads, headDim).transposed(0, 2, 1, 3)
keys = keys.reshaped(B, L, kvHeads, headDim).transposed(0, 2, 1, 3)
values = values.reshaped(B, L, kvHeads, headDim).transposed(0, 2, 1, 3)
// rotary positions start after the tokens already held in the cache
let offset = cache?.offset ?? 0
// NOTE(review): the mask is trimmed to the new keys' length *before* the cache update;
// confirm this matches the layout createAttentionMask produces
let mask = mask?[0..., 0 ..< keys.dim(-2)]
queries = rotaryEmbedding(queries, offset: offset)
keys = rotaryEmbedding(keys, offset: offset)
// presumably returns the full cached keys/values including the new ones
if let cache {
(keys, values) = cache.update(keys: keys, values: values)
}
let output = MLXFast.scaledDotProductAttention(
queries: queries, keys: keys, values: values, scale: scale, mask: mask
)
.transposed(0, 2, 1, 3)
.reshaped(B, L, -1)
return wo(output)
}
}
/// Gated feed-forward block: `down(silu(gate(x)) * up(x))`, all projections bias-free.
fileprivate class MLP: Module, UnaryLayer {
@ModuleInfo(key: "gate_proj") var gate: Linear
@ModuleInfo(key: "down_proj") var down: Linear
@ModuleInfo(key: "up_proj") var up: Linear
public init(dimensions: Int, hiddenDimensions: Int) {
self._gate.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
self._down.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false)
self._up.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
}
public func callAsFunction(_ x: MLXArray) -> MLXArray {
down(silu(gate(x)) * up(x))
}
}
/// Pre-norm transformer block: `h = x + attn(norm(x))`, then `h + mlp(norm(h))`.
fileprivate class FastVLMDecoderLayer: Module {
@ModuleInfo(key: "self_attn") var attention: Attention
let mlp: MLP
@ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
@ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm
public init(_ args: FastVLMConfiguration.TextConfiguration) {
self._attention.wrappedValue = Attention(args)
self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize)
self._inputLayerNorm.wrappedValue = RMSNorm(
dimensions: args.hiddenSize, eps: args.rmsNormEps)
self._postAttentionLayerNorm.wrappedValue = RMSNorm(
dimensions: args.hiddenSize, eps: args.rmsNormEps)
}
public func callAsFunction(
_ x: MLXArray, mask: MLXArray? = nil, cache: KVCache?
) -> MLXArray {
var r = attention(inputLayerNorm(x), mask: mask, cache: cache)
let h = x + r
r = mlp(postAttentionLayerNorm(h))
let out = h + r
return out
}
}
/// Token embedding, a stack of `hiddenLayers` decoder layers, and a final RMSNorm.
fileprivate class Qwen2Model: Module {
@ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
fileprivate let layers: [FastVLMDecoderLayer]
fileprivate let norm: RMSNorm
public init(_ args: FastVLMConfiguration.TextConfiguration) {
precondition(args.vocabularySize > 0)
self._embedTokens.wrappedValue = Embedding(
embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)
self.layers = (0 ..< args.hiddenLayers)
.map { _ in
FastVLMDecoderLayer(args)
}
self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
}
/// Runs the decoder on token ids or on precomputed embeddings.
///
/// When both are supplied, `inputEmbedding` wins and `inputs` is ignored; when neither
/// is supplied this traps.
/// - Returns: the final-normed hidden states.
public func callAsFunction(
_ inputs: MLXArray?, cache: [KVCache]? = nil, inputEmbedding: MLXArray? = nil
) -> MLXArray {
var h: MLXArray
if let inputEmbedding {
h = inputEmbedding
} else if let inputs {
h = embedTokens(inputs)
} else {
fatalError("one of inputs or inputEmbedding must be non-nil")
}
// one mask shared by every layer; layer i uses cache[i]
let mask = createAttentionMask(h: h, cache: cache)
for (i, layer) in layers.enumerated() {
h = layer(h, mask: mask, cache: cache?[i])
}
return norm(h)
}
}
/// Decoder plus output projection to vocabulary logits.
///
/// With tied word embeddings no `lm_head` is created and the embedding matrix is reused
/// via `embedTokens.asLinear`.
fileprivate class LanguageModel: Module, KVCacheDimensionProvider {
@ModuleInfo var model: Qwen2Model
@ModuleInfo(key: "lm_head") var lmHead: Linear?
// one entry per layer, each equal to args.kvHeads (reported via KVCacheDimensionProvider)
var kvHeads: [Int]
public init(_ args: FastVLMConfiguration.TextConfiguration) {
self.model = Qwen2Model(args)
if !args.tieWordEmbeddings {
_lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
}
self.kvHeads = (0 ..< args.hiddenLayers).map { _ in args.kvHeads }
}
public func callAsFunction(
_ inputs: MLXArray?, cache: [KVCache]? = nil, inputEmbedding: MLXArray? = nil
) -> LMOutput {
var out = model(inputs, cache: cache, inputEmbedding: inputEmbedding)
if let lmHead {
out = lmHead(out)
} else {
out = model.embedTokens.asLinear(out)
}
return LMOutput(logits: out)
}
}
}
// MARK: - Vision
/// Vision tower for FastVLM, backed by the CoreML model class `fastvithd` rather than MLX weights.
private enum Vision {
/// Lazily creates and caches the CoreML `fastvithd` model; creation is serialized by `lock`.
fileprivate class VisionModelCoreML {
let lock = NSLock()
// cached model; nil until the first successful load()
var _model: fastvithd?
init() {
}
/// Returns the cached model, creating it on first use.
///
/// - Throws: whatever `fastvithd()` throws; nothing is cached in that case, so a later
///   call retries.
func load() throws -> fastvithd {
try lock.withLock {
if let model = _model { return model }
let model = try fastvithd()
_model = model
return model
}
}
/// Non-throwing accessor; traps (`try!`) if the model cannot be loaded.
public func model() -> fastvithd {
try! load()
}
/// Runs the CoreML vision model on a preprocessed image.
///
/// - Parameter image: must be `[1, 3, H, W]` (enforced by the preconditions below);
///   converted to float32 before being handed to CoreML.
/// - Returns: the `image_features` output as a `[1, 256, 3072]` float32 `MLXArray`;
///   any other output shape or dtype traps.
public func encode(_ image: MLXArray) -> MLXArray {
// MLMultiArray requires mutable input data
// NOTE(review): with .noCopyIfContiguous the Data may alias the temporary float32 array's
// storage; confirm that storage stays valid for the duration of this call
var (data, strides) = {
let arrayData = image.asType(.float32).asData(access: .noCopyIfContiguous)
return (arrayData.data, arrayData.strides)
}()
precondition(image.ndim == 4)
precondition(image.dim(0) == 1)
precondition(image.dim(1) == 3)
let h = NSNumber(value: image.dim(2))
let w = NSNumber(value: image.dim(3))
return data.withUnsafeMutableBytes { (ptr: UnsafeMutableRawBufferPointer) in
// wrap the backing of the MLXArray
// (the MLX strides are passed through so CoreML reads the same memory layout)
let array = try! MLMultiArray(
dataPointer: ptr.baseAddress!, shape: [1, 3, h, w], dataType: .float32,
strides: strides.map { .init(value: $0) })
// inference
let output = try! model().prediction(images: array)
precondition(output.image_features.shape == [1, 256, 3072])
precondition(output.image_features.dataType == .float32)
return output.image_features.withUnsafeBytes { ptr in
MLXArray(ptr, [1, 256, 3072], type: Float32.self)
}
}
}
}
/// MLX `Module` wrapper that places `VisionModelCoreML` in the model tree
/// (registered under `vision_tower` by `FastVLM`).
fileprivate class VisionModel: Module {
let model = VisionModelCoreML()
public override init() {}
/// Encodes `hiddenStates` (the preprocessed pixels); `gridThw` is not used.
public func callAsFunction(_ hiddenStates: MLXArray, gridThw: [THW]) -> MLXArray {
model.encode(hiddenStates)
}
}
}
// MARK: - Processor
/// FastVLM `UserInputProcessor`.
///
/// This is meant to be used with ``FastVLM`` and is typically created by ``VLMModelFactory``.
///
/// Turns a `UserInput` (chat prompt plus at most one image) into an `LMInput`. The image is
/// resized, center-cropped, normalized and laid out planar. The prompt is rendered with
/// `<|im_start|>`/`<|im_end|>` markers, and one image placeholder is added per image token.
public class FastVLMProcessor: UserInputProcessor {
// prompt-building settings; always the struct's defaults (not decoded from the model directory)
private let config: FastVLMProcessorConfiguration
// image preprocessing settings (e.g. the ones decoded in FastVLM.register)
private let imageProcessingConfig: FastVLMPreProcessorConfiguration
private let tokenizer: any Tokenizer
public init(_ config: FastVLMPreProcessorConfiguration, tokenizer: any Tokenizer) {
self.config = FastVLMProcessorConfiguration()
self.imageProcessingConfig = config
self.tokenizer = tokenizer
}
/// Converts `image` into the pixel tensor expected by the vision tower.
///
/// Steps, in order:
/// 1. Apply the optional user-requested resize.
/// 2. Bicubic-resize so the shortest edge equals `size.shortestEdge`.
/// 3. Center-crop to `cropSize`.
/// 4. Normalize with the configured mean/std.
/// 5. Convert to a planar float32 array.
///
/// - Returns: the `[1, 3, H, W]` array and a `THW` whose first component is 0 and whose
///   other two are the array's height and width.
public func preprocess(image: CIImage, processing: UserInput.Processing?) throws -> (
MLXArray, THW
) {
// first apply the user requested resizing, etc. if any
var image = MediaProcessingExtensions.apply(image, processing: processing)
// image_processing_clip.py
let size = MediaProcessingExtensions.fitIn(
image.extent.size, shortestEdge: imageProcessingConfig.size.shortestEdge)
image = MediaProcessingExtensions.resampleBicubic(image, to: size)
image = MediaProcessingExtensions.centerCrop(
image, size: imageProcessingConfig.cropSize.size)
image = MediaProcessing.normalize(
image, mean: imageProcessingConfig.imageMeanTuple,
std: imageProcessingConfig.imageStdTuple)
let array = MediaProcessingExtensions.asPlanarMLXArray(image)
return (array, .init(0, array.dim(2), array.dim(3)))
}
/// Renders the chat prompt as a single string.
///
/// A default system message is prepended unless the first message already has the
/// `system` role. When `imageTHW` is given, image placeholders (`config.imageToken`) are
/// appended to the last message's content. The count is
/// `(h / patchSize) * (w / patchSize) + numAdditionalImageTokens`, minus one under the
/// `.default` feature-select strategy. The result ends with an open assistant turn.
///
/// NOTE(review): `messages[0]` traps if `prompt.asMessages()` returns an empty array.
public func prepare(prompt: UserInput.Prompt, imageTHW: THW?) -> String {
var messages = prompt.asMessages()
if messages[0]["role"] != "system" {
messages.insert(["role": "system", "content": "You are a helpful assistant."], at: 0)
}
let lastIndex = messages.count - 1
var lastMessage = messages[lastIndex]["content"] ?? ""
// processing_llava.py
if let imageTHW {
let height = imageTHW.h
let width = imageTHW.w
let patchSize = config.patchSize
var numImageTokens =
(height / patchSize) * (width / patchSize) + config.numAdditionalImageTokens
if config.visionFeatureSelectStrategy == .default {
numImageTokens -= 1
}
lastMessage += Array(repeating: config.imageToken, count: numImageTokens)
.joined()
}
messages[lastIndex]["content"] = lastMessage
// one "<|im_start|>role\ncontent<|im_end|>" entry per message, newline-separated
return
messages
.map {
"<|im_start|>\($0["role"] ?? "user")\n\($0["content"] ?? "")<|im_end|>"
}
.joined(separator: "\n")
+ "\n<|im_start|>assistant\n"
}
/// Tokenizes `input` into an `LMInput`.
///
/// Text-only input yields 1-D token ids with no mask. With an image, the ids get a leading
/// batch axis, an all-ones `int8` mask is attached, and the processed pixels and their
/// `THW` are included.
///
/// - Throws: `VLMError.singleImageAllowed` if more than one image is supplied, or any error
///   raised while converting or preprocessing the image.
public func prepare(input: UserInput) throws -> LMInput {
if input.images.isEmpty {
// just a straight text prompt
let prompt = prepare(prompt: input.prompt, imageTHW: nil)
let promptTokens = tokenizer.encode(text: prompt)
return LMInput(tokens: MLXArray(promptTokens))
}
if input.images.count > 1 {
throw VLMError.singleImageAllowed
}
let (pixels, thw) = try preprocess(
image: input.images[0].asCIImage(), processing: input.processing)
let image = LMInput.ProcessedImage(pixels: pixels, imageGridThw: [thw])
let prompt = prepare(prompt: input.prompt, imageTHW: thw)
let promptTokens = tokenizer.encode(text: prompt)
let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0)
let mask = ones(like: promptArray).asType(.int8)
return LMInput(text: .init(tokens: promptArray, mask: mask), image: image)
}
}
// MARK: - Model
/// Two-layer MLP (`linear_0 -> GELU -> linear_2`, both with bias) that maps vision-tower
/// features (`mm_hidden_size`) into the language model's hidden size.
private class FastVLMMultiModalProjector: Module, UnaryLayer {
    @ModuleInfo(key: "linear_0") var linear0: Linear
    @ModuleInfo(key: "gelu") var gelu: GELU
    @ModuleInfo(key: "linear_2") var linear2: Linear

    public init(_ config: FastVLMConfiguration) {
        let visionWidth = config.visionConfiguration.hiddenSize
        let textWidth = config.textConfiguration.hiddenSize
        self._linear0.wrappedValue = Linear(visionWidth, textWidth, bias: true)
        self._gelu.wrappedValue = GELU()
        self._linear2.wrappedValue = Linear(textWidth, textWidth, bias: true)
    }

    /// Projects `x` through `linear_0`, GELU, then `linear_2`.
    public func callAsFunction(_ x: MLXArray) -> MLXArray {
        linear2(gelu(linear0(x)))
    }
}
/// FastVLM
///
/// This is typically created by ``VLMModelFactory``.
///
/// A CoreML vision tower produces image features. These are projected to the text width
/// and written over the image-placeholder token embeddings before the Qwen2-style
/// language model runs.
public class FastVLM: Module, VLMModel, KVCacheDimensionProvider {

    /// Configuration rooted at the directory holding `config.json` in the bundle that
    /// contains this class. Traps if that resource is missing.
    static public var modelConfiguration: ModelConfiguration {
        let bundle = Bundle(for: FastVLM.self)
        let url = bundle.url(forResource: "config", withExtension: "json")!
            .resolvingSymlinksInPath()
            .deletingLastPathComponent()
        return ModelConfiguration(directory: url)
    }

    /// Registers the `llava_qwen2` model type and the `LlavaProcessor` processor type
    /// with `modelFactory`.
    static public func register(modelFactory: VLMModelFactory) {
        modelFactory.typeRegistry.registerModelType("llava_qwen2") { url in
            let configuration = try JSONDecoder().decode(
                FastVLMConfiguration.self, from: Data(contentsOf: url))
            return FastVLM(configuration)
        }
        modelFactory.processorRegistry.registerProcessorType("LlavaProcessor") { url, tokenizer in
            let configuration = try JSONDecoder().decode(
                FastVLMPreProcessorConfiguration.self, from: Data(contentsOf: url))
            return FastVLMProcessor(configuration, tokenizer: tokenizer)
        }
    }

    @ModuleInfo(key: "vision_tower") private var visionModel: Vision.VisionModel
    @ModuleInfo(key: "language_model") private var languageModel: Language.LanguageModel
    @ModuleInfo(key: "multi_modal_projector") private var multiModalProjector:
        FastVLMMultiModalProjector

    public let config: FastVLMConfiguration

    public var vocabularySize: Int { config.baseConfiguration.vocabularySize }
    public var kvHeads: [Int] { languageModel.kvHeads }

    /// LoRA targets: `q_proj` and `v_proj` of every decoder layer's attention.
    public func loraLinearLayers() -> MLXLMCommon.LoRALinearLayers {
        languageModel.model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
    }

    public init(_ config: FastVLMConfiguration) {
        self.config = config
        self._visionModel.wrappedValue = Vision.VisionModel()
        self._languageModel.wrappedValue = Language.LanguageModel(config.textConfiguration)
        self._multiModalProjector.wrappedValue = FastVLMMultiModalProjector(config)
    }

    /// Builds the input embeddings for the language model.
    ///
    /// Without pixels, the token ids are simply embedded. With pixels, the projected image
    /// features replace the embeddings of the image-placeholder tokens.
    /// - Returns: embeddings with a leading batch axis.
    private func inputEmbeddings(inputIds: MLXArray, pixelValues: MLXArray?, gridThw: [THW]?)
        -> MLXArray
    {
        guard let pixelValues, let gridThw else {
            // Text-only input: return *embeddings*, not logits. The caller feeds this
            // result to the language model as `inputEmbedding`. `FastVLMProcessor`
            // produces 1-D ids for text-only prompts, so add the batch axis the decoder
            // expects.
            let ids = inputIds.ndim == 1 ? inputIds.expandedDimensions(axis: 0) : inputIds
            return languageModel.model.embedTokens(ids)
        }

        // Get the input embeddings from the language model
        let inputEmbeds = languageModel.model.embedTokens(inputIds)

        // Get the output hidden states from the vision model
        let imageFeaturesCoreML = self.visionModel(pixelValues, gridThw: gridThw)
        let imageFeatures = multiModalProjector(imageFeaturesCoreML)

        // Insert special image tokens in the input_ids
        return mergeInputIdsWithImageFeatures(
            inputIds: inputIds, inputEmbeds: inputEmbeds, imageFeatures: imageFeatures)
    }

    /// Overwrites, in place, the embeddings at every position whose id equals
    /// `image_token_index` with the rows of `imageFeatures`, and returns `inputEmbeds`.
    ///
    /// Positions are taken from the flattened `inputIds`, so this assumes a batch of one.
    private func mergeInputIdsWithImageFeatures(
        inputIds: MLXArray, inputEmbeds: MLXArray, imageFeatures: MLXArray
    ) -> MLXArray {
        let imageTokenIndex = config.baseConfiguration.imageTokenId
        var imageIndices = [Int]()
        for (i, v) in inputIds.asArray(Int.self).enumerated() {
            if v == imageTokenIndex {
                imageIndices.append(i)
            }
        }

        // The scatter below only succeeds when there is exactly one placeholder per image
        // feature; fail with a readable message instead of an opaque broadcast error.
        precondition(
            imageIndices.count == imageFeatures.dim(1),
            "prompt contains \(imageIndices.count) image tokens but the vision tower produced \(imageFeatures.dim(1)) features"
        )

        inputEmbeds[0..., MLXArray(imageIndices), 0...] = imageFeatures
        return inputEmbeds
    }

    /// Runs the whole prompt (with the image, if any) through the model in a single pass.
    /// `windowSize` is not used.
    public func prepare(_ input: LMInput, cache: [any KVCache], windowSize: Int?) throws
        -> PrepareResult
    {
        let gridThw = input.image?.imageGridThw

        let dtype = DType.float32
        let pixels = input.image?.pixels.asType(dtype)

        let inputEmbeddings = self.inputEmbeddings(
            inputIds: input.text.tokens, pixelValues: pixels, gridThw: gridThw)

        let result = languageModel(nil, cache: cache, inputEmbedding: inputEmbeddings)

        return .logits(result)
    }

    /// Decodes token ids (no image) and returns logits.
    public func callAsFunction(_ inputs: MLXArray, cache: [any KVCache]?) -> MLXArray {
        languageModel(inputs, cache: cache).logits
    }

    /// Weights are returned unchanged. As a side effect, this tries to load the CoreML
    /// vision model early. A load failure is ignored here; it surfaces again (as a trap)
    /// the first time the vision tower is used.
    public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
        _ = try? visionModel.model.load()
        return weights
    }
}
// MARK: - Configuration
/// Configuration for ``FastVLM``
///
/// All three sections are decoded from the same top-level JSON object (the model's
/// `config.json`); see `init(from:)`.
public struct FastVLMConfiguration: Codable, Sendable {

    /// Vision-side settings.
    public struct VisionConfiguration: Codable, Sendable {
        /// Width of the vision features consumed by the multimodal projector (`mm_hidden_size`).
        public let hiddenSize: Int

        enum CodingKeys: String, CodingKey {
            case hiddenSize = "mm_hidden_size"
        }
    }

    /// Text-decoder settings.
    ///
    /// Optional JSON keys are stored as `_`-prefixed optionals and exposed through
    /// computed accessors that supply defaults.
    public struct TextConfiguration: Codable, Sendable {
        public let modelType: String
        public let hiddenSize: Int
        public let hiddenLayers: Int
        public let intermediateSize: Int
        public let attentionHeads: Int
        private let _rmsNormEps: Float?
        /// RMSNorm epsilon; `1e-6` when absent.
        public var rmsNormEps: Float { _rmsNormEps ?? 1e-6 }
        public let vocabularySize: Int
        public let kvHeads: Int
        private let _maxPositionEmbeddings: Int?
        /// Maximum number of positions; `32768` when absent.
        public var maxPositionEmbeddings: Int { _maxPositionEmbeddings ?? 32768 }
        /// Misspelled original name of ``maxPositionEmbeddings``, kept so existing callers
        /// continue to compile.
        public var maxpPositionEmbeddings: Int { maxPositionEmbeddings }
        private let _ropeTheta: Float?
        /// RoPE base; `1_000_000` when absent.
        public var ropeTheta: Float { _ropeTheta ?? 1_000_000 }
        private let _ropeTraditional: Bool?
        /// Whether RoPE uses the traditional layout; `false` when absent.
        public var ropeTraditional: Bool { _ropeTraditional ?? false }
        public let _ropeScaling: [String: StringOrNumber]?
        /// RoPE scaling parameters; `["mrope_section": [2, 1, 1]]` when absent.
        public var ropeScaling: [String: StringOrNumber]? {
            _ropeScaling ?? ["mrope_section": .ints([2, 1, 1])]
        }
        private let _tieWordEmbeddings: Bool?
        /// Whether the output projection reuses the embedding matrix; `true` when absent.
        public var tieWordEmbeddings: Bool { _tieWordEmbeddings ?? true }

        enum CodingKeys: String, CodingKey {
            case modelType = "model_type"
            case hiddenSize = "hidden_size"
            case hiddenLayers = "num_hidden_layers"
            case intermediateSize = "intermediate_size"
            case attentionHeads = "num_attention_heads"
            case _rmsNormEps = "rms_norm_eps"
            case vocabularySize = "vocab_size"
            case kvHeads = "num_key_value_heads"
            case _maxPositionEmbeddings = "max_position_embeddings"
            case _ropeTheta = "rope_theta"
            case _ropeTraditional = "rope_traditional"
            case _ropeScaling = "rope_scaling"
            case _tieWordEmbeddings = "tie_word_embeddings"
        }
    }

    /// Settings shared by the wrapper model (vocabulary size, image placeholder id).
    public struct BaseConfiguration: Codable, Sendable {
        public let modelType: String
        public let vocabularySize: Int
        /// Token id that marks image-placeholder positions in the prompt.
        public let imageTokenId: Int
        public let hiddenSize: Int

        enum CodingKeys: String, CodingKey {
            case modelType = "model_type"
            case vocabularySize = "vocab_size"
            case imageTokenId = "image_token_index"
            case hiddenSize = "hidden_size"
        }
    }

    public let visionConfiguration: VisionConfiguration
    public let textConfiguration: TextConfiguration
    public let baseConfiguration: BaseConfiguration

    public init(from decoder: any Swift.Decoder) throws {
        // these are overlaid in the top level
        self.visionConfiguration = try VisionConfiguration(from: decoder)
        self.textConfiguration = try TextConfiguration(from: decoder)
        self.baseConfiguration = try BaseConfiguration(from: decoder)
    }
}
/// Configuration for ``FastVLMProcessor``
///
/// Image-preprocessing settings: shortest-edge resize target, center-crop size, and the
/// per-channel normalization mean/std.
public struct FastVLMPreProcessorConfiguration: Codable, Sendable {
/// Center-crop target size in pixels.
public struct CropSize: Codable, Sendable {
let width: Int
let height: Int
var size: CGSize { .init(width: CGFloat(width), height: CGFloat(height)) }
}
/// Resize target: the image's shortest edge is scaled to `shortestEdge`.
public struct Size: Codable, Sendable {
let shortestEdge: Int
enum CodingKeys: String, CodingKey {
case shortestEdge = "shortest_edge"
}
}
// per-channel normalization values; the tuple accessors below read indices 0...2
// and trap if fewer than three values are present
public var imageMean: [CGFloat]
public var imageStd: [CGFloat]
public var size: Size
public var cropSize: CropSize
public var imageMeanTuple: (CGFloat, CGFloat, CGFloat) {
(imageMean[0], imageMean[1], imageMean[2])
}
public var imageStdTuple: (CGFloat, CGFloat, CGFloat) {
(imageStd[0], imageStd[1], imageStd[2])
}
enum CodingKeys: String, CodingKey {
case imageMean = "image_mean"
case imageStd = "image_std"
case size
case cropSize = "crop_size"
}
}
/// Prompt-building settings used by ``FastVLMProcessor``.
///
/// The number of image placeholders inserted per image is
/// `(h / patchSize) * (w / patchSize) + numAdditionalImageTokens`, minus one when
/// `visionFeatureSelectStrategy` is `.default`.
public struct FastVLMProcessorConfiguration: Codable, Sendable {
    /// Feature-selection strategy; `.default` removes one placeholder from the count.
    public enum Strategy: Codable, Sendable {
        case `default`
    }

    /// Placeholder text repeated once per image token in the prompt.
    ///
    /// It is expected to tokenize to the model's `image_token_index`, which is how the
    /// model finds the positions to overwrite with image features. An empty string would
    /// insert no placeholders at all, leaving nowhere to put the image features.
    public var imageToken = "<image>"
    public var numAdditionalImageTokens = 0
    /// Pixel size of one image token along each axis.
    public var patchSize = 64
    public var visionFeatureSelectStrategy: Strategy?
}
================================================
FILE: app/FastVLM/MediaProcessingExtensions.swift
================================================
//
// For licensing see accompanying LICENSE file.
// Copyright (C) 2025 Apple Inc. All Rights Reserved.
//
import Accelerate
import CoreImage
import MLX
import MLXLMCommon
import MLXVLM
/// Additions to MediaProcessing -- not currently present in mlx-libraries
enum MediaProcessingExtensions {

    // this function is not exported in current mlx-swift-examples -- local copy until it is exposed
    // properly
    /// Applies the user-requested processing (currently only `resize`, scaled to best fit).
    public static func apply(_ image: CIImage, processing: UserInput.Processing?) -> CIImage {
        var image = image

        if let resize = processing?.resize {
            let scale = MediaProcessing.bestFitScale(image.extent.size, in: resize)
            image = image.transformed(by: CGAffineTransform(scaleX: scale, y: scale))
        }

        return image
    }

    /// True when `extent` already fits within `size` on both axes.
    public static func rectSmallerOrEqual(_ extent: CGRect, size: CGSize) -> Bool {
        return extent.width <= size.width && extent.height <= size.height
    }

    /// Returns a rect of at most `size`, centered on `extent`.
    ///
    /// Centering uses `extent.midX/midY`, so extents with a non-zero origin are handled
    /// correctly. For a zero-origin extent this equals `((width - w) / 2, (height - h) / 2)`.
    public static func centerCrop(_ extent: CGRect, size: CGSize) -> CGRect {
        let targetWidth = min(extent.width, size.width)
        let targetHeight = min(extent.height, size.height)

        return CGRect(
            x: extent.midX - targetWidth / 2,
            y: extent.midY - targetHeight / 2,
            width: targetWidth, height: targetHeight
        )
    }

    /// Center-crops `image` to at most `size`, then translates the result so its
    /// extent's origin is moved back to (0, 0).
    /// Images already within `size` are returned unchanged.
    public static func centerCrop(_ image: CIImage, size: CGSize) -> CIImage {
        let extent = image.extent
        if rectSmallerOrEqual(extent, size: size) {
            return image
        }

        let crop = centerCrop(extent, size: size)
        return
            image
            .cropped(to: crop)
            .transformed(by: CGAffineTransform(translationX: -crop.minX, y: -crop.minY))
    }

    /// Scales `size` so its shortest edge equals `shortestEdge`, preserving aspect ratio
    /// (up- or down-scaling as needed).
    public static func fitIn(_ size: CGSize, shortestEdge: Int) -> CGSize {
        let floatShortestEdge = CGFloat(shortestEdge)

        let (short, long) =
            size.width <= size.height ? (size.width, size.height) : (size.height, size.width)
        let newShort = floatShortestEdge
        let newLong = floatShortestEdge * long / short

        return size.width <= size.height
            ? CGSize(width: newShort, height: newLong) : CGSize(width: newLong, height: newShort)
    }

    /// Scales `size` down, preserving aspect ratio, so its longest edge is at most
    /// `longestEdge`. Sizes that already fit are returned unchanged.
    public static func fitIn(_ size: CGSize, longestEdge: Int) -> CGSize {
        let floatLongestEdge = CGFloat(longestEdge)

        var (newShort, newLong) =
            size.width <= size.height ? (size.width, size.height) : (size.height, size.width)

        if newLong > floatLongestEdge {
            // scale the short edge by the same factor as the long edge; this must use the
            // original long edge, so compute it before overwriting `newLong`
            newShort = floatLongestEdge * newShort / newLong
            newLong = floatLongestEdge
        }

        return size.width <= size.height
            ? CGSize(width: newShort, height: newLong) : CGSize(width: newLong, height: newShort)
    }

    // version of function from https://github.com/ml-explore/mlx-swift-examples/pull/222
    /// Bicubic-resamples `image` to exactly `size`, cropping to `(0, 0, size)`.
    public static func resampleBicubic(_ image: CIImage, to size: CGSize) -> CIImage {
        // Create a bicubic scale filter

        let yScale = size.height / image.extent.height
        let xScale = size.width / image.extent.width

        let filter = CIFilter.bicubicScaleTransform()
        filter.inputImage = image
        filter.scale = Float(yScale)
        filter.aspectRatio = Float(xScale / yScale)
        let scaledImage = filter.outputImage!

        // Create a rect with the exact dimensions we want
        let exactRect = CGRect(
            x: 0,
            y: 0,
            width: size.width,
            height: size.height
        )

        // Crop to ensure exact dimensions
        return scaledImage.cropped(to: exactRect)
    }

    static let context = CIContext()

    /// Convert the CIImage into a planar 3 channel MLXArray `[1, C, H, W]`.
    ///
    /// This physically moves the channels into a planar configuration -- this is
    /// required for feeding into the CoreML model and is faster to use
    /// dedicated functions than transforming into contiguous memory
    /// on readout.
    static public func asPlanarMLXArray(_ image: CIImage, colorSpace: CGColorSpace? = nil)
        -> MLXArray
    {
        let size = image.extent.size
        let w = Int(size.width.rounded())
        let h = Int(size.height.rounded())

        // probably not strictly necessary, but this is what happens in
        // e.g. image_processing_siglip in transformers (float32)
        let format = CIFormat.RGBAf
        let componentsPerPixel = 4
        let bytesPerComponent: Int = MemoryLayout<Float32>.size
        let bytesPerPixel = componentsPerPixel * bytesPerComponent
        let bytesPerRow = w * bytesPerPixel

        var data = Data(count: w * h * bytesPerPixel)
        var planarData = Data(count: 3 * w * h * bytesPerComponent)
        data.withUnsafeMutableBytes { ptr in
            context.render(
                image, toBitmap: ptr.baseAddress!, rowBytes: bytesPerRow, bounds: image.extent,
                format: format, colorSpace: colorSpace)
            context.clearCaches()

            let vh = vImagePixelCount(h)
            let vw = vImagePixelCount(w)

            // convert from RGBAf -> RGBf in place
            let rgbBytesPerRow = w * 3 * bytesPerComponent
            var rgbaSrc = vImage_Buffer(
                data: ptr.baseAddress!, height: vh, width: vw, rowBytes: bytesPerRow)
            var rgbDest = vImage_Buffer(
                data: ptr.baseAddress!, height: vh, width: vw, rowBytes: rgbBytesPerRow)

            vImageConvert_RGBAFFFFtoRGBFFF(&rgbaSrc, &rgbDest, vImage_Flags(kvImageNoFlags))

            // and convert to planar data in a second buffer
            planarData.withUnsafeMutableBytes { planarPtr in
                let planeBytesPerRow = w * bytesPerComponent

                var rDest = vImage_Buffer(
                    data: planarPtr.baseAddress!.advanced(by: 0 * planeBytesPerRow * h), height: vh,
                    width: vw, rowBytes: planeBytesPerRow)
                var gDest = vImage_Buffer(
                    data: planarPtr.baseAddress!.advanced(by: 1 * planeBytesPerRow * h), height: vh,
                    width: vw, rowBytes: planeBytesPerRow)
                var bDest = vImage_Buffer(
                    data: planarPtr.baseAddress!.advanced(by: 2 * planeBytesPerRow * h), height: vh,
                    width: vw, rowBytes: planeBytesPerRow)

                vImageConvert_RGBFFFtoPlanarF(
                    &rgbDest, &rDest, &gDest, &bDest, vImage_Flags(kvImageNoFlags))
            }
        }

        return MLXArray(planarData, [1, 3, h, w], type: Float32.self)
    }
}
================================================
FILE: app/FastVLM App/Assets.xcassets/AccentColor.colorset/Contents.json
================================================
{
"colors" : [
{
"idiom" : "universal"
}
],
"info" : {
"author" : "xcode",
"version" : 1
}
}
================================================
FILE: app/FastVLM App/Assets.xcassets/AppIcon.appiconset/Contents.json
================================================
{
"images" : [
{
"filename" : "FastVLM - 150 Blue - Light@2x.png",
"idiom" : "universal",
"platform" : "ios",
"size" : "1024x1024"
},
{
"appearances" : [
{
"appearance" : "luminosity",
"value" : "dark"
}
],
"filename" : "FastVLM - 150 Blue - Dark@2x.png",
"idiom" : "universal",
"platform" : "ios",
"size" : "1024x1024"
},
{
"appearances" : [
{
"appearance" : "luminosity",
"value" : "tinted"
}
],
"filename" : "FastVLM - 150 White - Tinted@2x.png",
"idiom" : "universal",
"platform" : "ios",
"size" : "1024x1024"
},
{
"idiom" : "mac",
"scale" : "1x",
"size" : "16x16"
},
{
"idiom" : "mac",
"scale" : "2x",
"size" : "16x16"
},
{
"idiom" : "mac",
"scale" : "1x",
"size" : "32x32"
},
{
"idiom" : "mac",
"scale" : "2x",
"size" : "32x32"
},
{
"idiom" : "mac",
"scale" : "1x",
"size" : "128x128"
},
{
"idiom" : "mac",
"scale" : "2x",
"size" : "128x128"
},
{
"idiom" : "mac",
"scale" : "1x",
"size" : "256x256"
},
{
"idiom" : "mac",
"scale" : "2x",
"size" : "256x256"
},
{
"filename" : "FastVLM - MacOS - Dark@1x.png",
"idiom" : "mac",
"scale" : "1x",
"size" : "512x512"
},
{
"filename" : "FastVLM - MacOS - Dark@2x.png",
"idiom" : "mac",
"scale" : "2x",
"size" : "512x512"
}
],
"info" : {
"author" : "xcode",
"version" : 1
}
}
================================================
FILE: app/FastVLM App/Assets.xcassets/Contents.json
================================================
{
"info" : {
"author" : "xcode",
"version" : 1
}
}
================================================
FILE: app/FastVLM App/ContentView.swift
================================================
//
// For licensing see accompanying LICENSE file.
// Copyright (C) 2025 Apple Inc. All Rights Reserved.
//
import AVFoundation
import MLXLMCommon
import SwiftUI
import Video
// support swift 6: CVImageBuffer / CMSampleBuffer values are sent through the
// AsyncStreams and child tasks below, so they are declared Sendable
// retroactively. `@unchecked` opts out of compiler verification, so safe
// sharing of these buffers is the author's responsibility.
extension CVImageBuffer: @unchecked @retroactive Sendable {}
extension CMSampleBuffer: @unchecked @retroactive Sendable {}
// pause taken after each completed continuous-mode analysis before the next
// frame is analyzed (see analyzeVideoFrames) -- larger values lower the
// update rate
let FRAME_DELAY = Duration.milliseconds(1)
struct ContentView: View {
    @State private var camera = CameraController()
    @State private var model = FastVLMModel()

    /// stream of frames -> VideoFrameView, see distributeVideoFrames
    @State private var framesToDisplay: AsyncStream<CVImageBuffer>?

    @State private var prompt = "Describe the image in English."
    @State private var promptSuffix = "Output should be brief, about 15 words or less."

    @State private var isShowingInfo: Bool = false

    @State private var selectedCameraType: CameraType = .continuous
    @State private var isEditingPrompt: Bool = false

    /// Placement of the info button: leading edge of the top bar on iOS,
    /// `.navigation` elsewhere.
    var toolbarItemPlacement: ToolbarItemPlacement {
        var placement: ToolbarItemPlacement = .navigation
        #if os(iOS)
        placement = .topBarLeading
        #endif
        return placement
    }

    /// Foreground color of the status capsule; black on the yellow
    /// "processing" background, white otherwise.
    var statusTextColor : Color {
        return model.evaluationState == .processingPrompt ? .black : .white
    }

    /// Background color of the status capsule for each evaluation state.
    var statusBackgroundColor : Color {
        switch model.evaluationState {
        case .idle:
            return .gray
        case .generatingResponse:
            return .green
        case .processingPrompt:
            return .yellow
        }
    }

    var body: some View {
        NavigationStack {
            Form {
                Section {
                    VStack(alignment: .leading, spacing: 10.0) {
                        Picker("Camera Type", selection: $selectedCameraType) {
                            ForEach(CameraType.allCases, id: \.self) { cameraType in
                                Text(cameraType.rawValue.capitalized).tag(cameraType)
                            }
                        }
                        // Prevent macOS from adding a text label for the picker
                        .labelsHidden()
                        .pickerStyle(.segmented)
                        .onChange(of: selectedCameraType) { _, _ in
                            // Cancel any in-flight requests when switching modes
                            model.cancel()
                        }

                        if let framesToDisplay {
                            VideoFrameView(
                                frames: framesToDisplay,
                                cameraType: selectedCameraType,
                                action: { frame in
                                    processSingleFrame(frame)
                                })
                                // Because we're using the AVCaptureSession preset
                                // `.vga640x480`, we can assume this aspect ratio
                                .aspectRatio(4/3, contentMode: .fit)
                                #if os(macOS)
                                .frame(maxWidth: 750)
                                #endif
                                .overlay(alignment: .top) {
                                    if !model.promptTime.isEmpty {
                                        Text("TTFT \(model.promptTime)")
                                            .font(.caption)
                                            .foregroundStyle(.white)
                                            .monospaced()
                                            .padding(.vertical, 4.0)
                                            .padding(.horizontal, 6.0)
                                            .background(alignment: .center) {
                                                RoundedRectangle(cornerRadius: 8)
                                                    .fill(Color.black.opacity(0.6))
                                            }
                                            .padding(.top)
                                    }
                                }
                                #if !os(macOS)
                                .overlay(alignment: .topTrailing) {
                                    CameraControlsView(
                                        backCamera: $camera.backCamera,
                                        device: $camera.device,
                                        devices: $camera.devices)
                                    .padding()
                                }
                                #endif
                                .overlay(alignment: .bottom) {
                                    if selectedCameraType == .continuous {
                                        Group {
                                            if model.evaluationState == .processingPrompt {
                                                HStack {
                                                    ProgressView()
                                                        .tint(self.statusTextColor)
                                                        .controlSize(.small)

                                                    Text(model.evaluationState.rawValue)
                                                }
                                            } else if model.evaluationState == .idle {
                                                HStack(spacing: 6.0) {
                                                    Image(systemName: "clock.fill")
                                                        .font(.caption)

                                                    Text(model.evaluationState.rawValue)
                                                }
                                            }
                                            else {
                                                // Spacing is tuned by hand to match
                                                // the spacing next to ProgressView
                                                HStack(spacing: 6.0) {
                                                    Image(systemName: "lightbulb.fill")
                                                        .font(.caption)

                                                    Text(model.evaluationState.rawValue)
                                                }
                                            }
                                        }
                                        .foregroundStyle(self.statusTextColor)
                                        .font(.caption)
                                        .bold()
                                        .padding(.vertical, 6.0)
                                        .padding(.horizontal, 8.0)
                                        .background(self.statusBackgroundColor)
                                        .clipShape(.capsule)
                                        .padding(.bottom)
                                    }
                                }
                                #if os(macOS)
                                .frame(maxWidth: .infinity)
                                .frame(minWidth: 500)
                                .frame(minHeight: 375)
                                #endif
                        }
                    }
                }
                .listRowInsets(EdgeInsets())
                .listRowBackground(Color.clear)
                .listRowSeparator(.hidden)

                promptSections

                Section {
                    if model.output.isEmpty && model.running {
                        ProgressView()
                            .controlSize(.large)
                            .frame(maxWidth: .infinity)
                    } else {
                        ScrollView {
                            Text(model.output)
                                .foregroundStyle(isEditingPrompt ? .secondary : .primary)
                                .textSelection(.enabled)
                                #if os(macOS)
                                .font(.headline)
                                .fontWeight(.regular)
                                #endif
                        }
                        .frame(minHeight: 50.0, maxHeight: 200.0)
                    }
                } header: {
                    Text("Response")
                        #if os(macOS)
                        .font(.headline)
                        .padding(.bottom, 2.0)
                        #endif
                }

                #if os(macOS)
                Spacer()
                #endif
            }
            #if os(iOS)
            .listSectionSpacing(0)
            #elseif os(macOS)
            .padding()
            #endif
            .task {
                camera.start()
            }
            .task {
                await model.load()
            }
            #if !os(macOS)
            .onAppear {
                // Prevent the screen from dimming or sleeping due to inactivity
                UIApplication.shared.isIdleTimerDisabled = true
            }
            .onDisappear {
                // Resumes normal idle timer behavior
                UIApplication.shared.isIdleTimerDisabled = false
            }
            #endif

            // task to distribute video frames -- SwiftUI starts it when this
            // view appears and cancels it when the view disappears, which in
            // turn cancels the structured child tasks in distributeVideoFrames
            .task {
                if Task.isCancelled {
                    return
                }

                await distributeVideoFrames()
            }
            .navigationTitle("FastVLM")
            #if os(iOS)
            .navigationBarTitleDisplayMode(.inline)
            #endif
            .toolbar {
                ToolbarItem(placement: toolbarItemPlacement) {
                    Button {
                        isShowingInfo.toggle()
                    } label: {
                        Image(systemName: "info.circle")
                    }
                }

                ToolbarItem(placement: .primaryAction) {
                    if isEditingPrompt {
                        Button {
                            isEditingPrompt.toggle()
                        } label: {
                            Text("Done")
                                .fontWeight(.bold)
                        }
                    }
                    else {
                        Menu {
                            Button("Describe image") {
                                prompt = "Describe the image in English."
                                promptSuffix = "Output should be brief, about 15 words or less."
                            }
                            Button("Facial expression") {
                                prompt = "What is this person's facial expression?"
                                promptSuffix = "Output only one or two words."
                            }
                            Button("Read text") {
                                prompt = "What is written in this image?"
                                promptSuffix = "Output only the text in the image."
                            }
                            #if !os(macOS)
                            Button("Customize...") {
                                isEditingPrompt.toggle()
                            }
                            #endif
                        } label: { Text("Prompts") }
                    }
                }
            }
            .sheet(isPresented: $isShowingInfo) {
                InfoView()
            }
        }
    }

    /// Read-only summary of the prompt and suffix (trimmed; empty parts hidden).
    var promptSummary: some View {
        Section("Prompt") {
            VStack(alignment: .leading, spacing: 4.0) {
                let trimmedPrompt = prompt.trimmingCharacters(in: .whitespacesAndNewlines)
                if !trimmedPrompt.isEmpty {
                    Text(trimmedPrompt)
                        .foregroundStyle(.secondary)
                }

                let trimmedSuffix = promptSuffix.trimmingCharacters(in: .whitespacesAndNewlines)
                if !trimmedSuffix.isEmpty {
                    Text(trimmedSuffix)
                        .font(.caption)
                        .foregroundStyle(.tertiary)
                }
            }
        }
    }

    /// Editable prompt and suffix fields, laid out per platform.
    var promptForm: some View {
        Group {
            #if os(iOS)
            Section("Prompt") {
                TextEditor(text: $prompt)
                    .frame(minHeight: 38)
            }

            Section("Prompt Suffix") {
                TextEditor(text: $promptSuffix)
                    .frame(minHeight: 38)
            }
            #elseif os(macOS)
            Section {
                HStack(alignment: .top) {
                    VStack(alignment: .leading) {
                        Text("Prompt")
                            .font(.headline)

                        TextEditor(text: $prompt)
                            .frame(height: 38)
                            .padding(.horizontal, 8.0)
                            .padding(.vertical, 10.0)
                            .background(Color(.textBackgroundColor))
                            .cornerRadius(10.0)
                    }

                    VStack(alignment: .leading) {
                        Text("Prompt Suffix")
                            .font(.headline)

                        TextEditor(text: $promptSuffix)
                            .frame(height: 38)
                            .padding(.horizontal, 8.0)
                            .padding(.vertical, 10.0)
                            .background(Color(.textBackgroundColor))
                            .cornerRadius(10.0)
                    }
                }
            }
            .padding(.vertical)
            #endif
        }
    }

    /// iOS toggles between the summary and the editor; macOS always edits inline.
    var promptSections: some View {
        Group {
            #if os(iOS)
            if isEditingPrompt {
                promptForm
            }
            else {
                promptSummary
            }
            #elseif os(macOS)
            promptForm
            #endif
        }
    }

    /// Run one generation per frame received, waiting for each generation
    /// (plus `FRAME_DELAY`) before taking the next frame.
    /// - Parameter frames: frames forwarded by `distributeVideoFrames` while in
    ///   continuous mode.
    func analyzeVideoFrames(_ frames: AsyncStream<CVImageBuffer>) async {
        for await frame in frames {
            // A frame may have been buffered just before the user switched to
            // single-frame mode; don't start a continuous-mode request for it.
            guard selectedCameraType == .continuous else { continue }

            let userInput = UserInput(
                prompt: .text("\(prompt) \(promptSuffix)"),
                images: [.ciImage(CIImage(cvPixelBuffer: frame))]
            )

            // generate output for a frame and wait for generation to complete
            let t = await model.generate(userInput)
            _ = await t.result

            do {
                try await Task.sleep(for: FRAME_DELAY)
            } catch { return }
        }
    }

    /// Attach to the camera and fan its frames out to the display stream and,
    /// while in continuous mode, to the analysis stream. Returns when the
    /// camera stream ends or the enclosing task is cancelled.
    func distributeVideoFrames() async {
        // attach a stream to the camera -- this code will read this
        let frames = AsyncStream<CMSampleBuffer>(bufferingPolicy: .bufferingNewest(1)) {
            camera.attach(continuation: $0)
        }

        let (framesToDisplay, framesToDisplayContinuation) = AsyncStream.makeStream(
            of: CVImageBuffer.self,
            bufferingPolicy: .bufferingNewest(1)
        )
        self.framesToDisplay = framesToDisplay

        // Frames are only yielded into this stream while in continuous mode
        // (checked per frame below). The consumer runs in every mode, so that
        // switching into continuous mode later still starts analysis.
        let (framesToAnalyze, framesToAnalyzeContinuation) = AsyncStream.makeStream(
            of: CVImageBuffer.self,
            bufferingPolicy: .bufferingNewest(1)
        )

        // set up structured tasks (important -- this means the child tasks
        // are cancelled when the parent is cancelled)
        async let distributeFrames: () = {
            for await sampleBuffer in frames {
                if let frame = sampleBuffer.imageBuffer {
                    framesToDisplayContinuation.yield(frame)
                    // Only send frames for analysis in continuous mode
                    if await selectedCameraType == .continuous {
                        framesToAnalyzeContinuation.yield(frame)
                    }
                }
            }

            // detach from the camera controller and feed to the video view
            await MainActor.run {
                self.framesToDisplay = nil
                self.camera.detatch()
            }

            framesToDisplayContinuation.finish()
            framesToAnalyzeContinuation.finish()
        }()

        // finishing framesToAnalyzeContinuation above ends this loop
        async let analyze: () = analyzeVideoFrames(framesToAnalyze)

        await distributeFrames
        await analyze
    }

    /// Perform FastVLM inference on a single frame.
    /// - Parameter frame: The frame to analyze.
    func processSingleFrame(_ frame: CVImageBuffer) {
        // Construct request to model
        let userInput = UserInput(
            prompt: .text("\(prompt) \(promptSuffix)"),
            images: [.ciImage(CIImage(cvPixelBuffer: frame))]
        )

        // One task, so resetting the Response UI (which shows the spinner) is
        // guaranteed to happen before the request is posted to FastVLM.
        Task { @MainActor in
            model.output = ""
            _ = await model.generate(userInput)
        }
    }
}
// Xcode canvas preview of the main screen.
#Preview {
ContentView()
}
================================================
FILE: app/FastVLM App/FastVLM.entitlements
================================================
com.apple.developer.kernel.increased-memory-limit
com.apple.security.app-sandbox
com.apple.security.device.camera
com.apple.security.files.user-selected.read-only
com.apple.security.network.client
================================================
FILE: app/FastVLM App/FastVLMApp.swift
================================================
//
// For licensing see accompanying LICENSE file.
// Copyright (C) 2025 Apple Inc. All Rights Reserved.
//
import SwiftUI
/// App entry point: a single window group whose content is `ContentView`.
@main
struct FastVLMApp: App {
var body: some Scene {
WindowGroup {
ContentView()
}
}
}
================================================
FILE: app/FastVLM App/FastVLMModel.swift
================================================
//
// For licensing see accompanying LICENSE file.
// Copyright (C) 2025 Apple Inc. All Rights Reserved.
//
import CoreImage
import FastVLM
import Foundation
import MLX
import MLXLMCommon
import MLXRandom
import MLXVLM
/// Observable, main-actor model wrapper that loads the FastVLM container and
/// runs one generation at a time, publishing output and timing for the UI.
@Observable
@MainActor
class FastVLMModel {

    public var running = false

    public var modelInfo = ""
    public var output = ""
    public var promptTime: String = ""

    enum LoadState {
        case idle
        case loaded(ModelContainer)
    }

    private let modelConfiguration = FastVLM.modelConfiguration

    /// parameters controlling the output
    let generateParameters = GenerateParameters(temperature: 0.0)
    let maxTokens = 240

    /// update the display every N tokens -- 4 looks like it updates continuously
    /// and is low overhead.  observed ~15% reduction in tokens/s when updating
    /// on every token
    let displayEveryNTokens = 4

    private var loadState = LoadState.idle
    /// Model load currently in flight, shared by every concurrent `_load()`
    /// caller so the weights are only loaded once.
    private var pendingLoad: Task<ModelContainer, Error>?
    private var currentTask: Task<Void, Never>?

    enum EvaluationState: String, CaseIterable {
        case idle = "Idle"
        case processingPrompt = "Processing Prompt"
        case generatingResponse = "Generating Response"
    }

    public var evaluationState = EvaluationState.idle

    public init() {
        FastVLM.register(modelFactory: VLMModelFactory.shared)
    }

    /// Return the loaded model container, loading it on first use.
    ///
    /// This method suspends while loading, so a second caller can arrive
    /// before the first load finishes. For example, `load()` from the view
    /// and the first `generate` can race. Such callers await the same
    /// in-flight load instead of starting another one.
    private func _load() async throws -> ModelContainer {
        switch loadState {
        case .loaded(let modelContainer):
            return modelContainer

        case .idle:
            if let pendingLoad {
                return try await pendingLoad.value
            }

            // limit the buffer cache
            MLX.GPU.set(cacheLimit: 20 * 1024 * 1024)

            let modelConfiguration = self.modelConfiguration
            let load = Task {
                try await VLMModelFactory.shared.loadContainer(
                    configuration: modelConfiguration
                ) { progress in
                    Task { @MainActor in
                        self.modelInfo =
                            "Downloading \(modelConfiguration.name): \(Int(progress.fractionCompleted * 100))%"
                    }
                }
            }
            pendingLoad = load
            // cleared on success and failure alike so a failed load can be retried
            defer { pendingLoad = nil }

            let modelContainer = try await load.value
            self.modelInfo = "Loaded"
            loadState = .loaded(modelContainer)
            return modelContainer
        }
    }

    /// Load the model, reporting any failure through `modelInfo`.
    public func load() async {
        do {
            _ = try await _load()
        } catch {
            self.modelInfo = "Error loading model: \(error)"
        }
    }

    /// Start generating a response for `userInput`.
    ///
    /// If a generation is already running, no new one is started and the
    /// running task is returned instead.
    /// - Returns: the task performing (or already performing) the generation.
    public func generate(_ userInput: UserInput) async -> Task<Void, Never> {
        if let currentTask, running {
            return currentTask
        }

        running = true

        // Cancel any existing task
        currentTask?.cancel()

        // Create new task and store reference
        let task = Task {
            do {
                let modelContainer = try await _load()

                // each time you generate you will get something new
                MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))

                // Check if task was cancelled
                if Task.isCancelled { return }

                let result = try await modelContainer.perform { context in
                    // Measure the time it takes to prepare the input

                    Task { @MainActor in
                        evaluationState = .processingPrompt
                    }

                    let llmStart = Date()
                    let input = try await context.processor.prepare(input: userInput)

                    var seenFirstToken = false

                    // FastVLM generates the output
                    let result = try MLXLMCommon.generate(
                        input: input, parameters: generateParameters, context: context
                    ) { tokens in
                        // Check if task was cancelled
                        if Task.isCancelled {
                            return .stop
                        }

                        if !seenFirstToken {
                            seenFirstToken = true

                            // produced first token, update the time to first token,
                            // the processing state and start displaying the text
                            let llmDuration = Date().timeIntervalSince(llmStart)
                            let text = context.tokenizer.decode(tokens: tokens)
                            Task { @MainActor in
                                evaluationState = .generatingResponse
                                self.output = text
                                self.promptTime = "\(Int(llmDuration * 1000)) ms"
                            }
                        }

                        // Show the text in the view as it generates
                        if tokens.count % displayEveryNTokens == 0 {
                            let text = context.tokenizer.decode(tokens: tokens)
                            Task { @MainActor in
                                self.output = text
                            }
                        }

                        if tokens.count >= maxTokens {
                            return .stop
                        } else {
                            return .more
                        }
                    }

                    // Return the duration of the LLM and the result
                    return result
                }

                // Check if task was cancelled before updating UI
                if !Task.isCancelled {
                    self.output = result.output
                }

            } catch {
                if !Task.isCancelled {
                    output = "Failed: \(error)"
                }
            }

            // Only a task that was not cancelled resets the shared state:
            // `cancel()` already reset it for a cancelled task, and a newer
            // generation may own `running`/`evaluationState` by now. The reset
            // is unconditional so an error before the first token does not
            // leave the UI stuck in "Processing Prompt".
            if !Task.isCancelled {
                evaluationState = .idle
                running = false
            }
        }

        currentTask = task
        return task
    }

    /// Cancel any in-flight generation and clear the displayed output/state.
    public func cancel() {
        currentTask?.cancel()
        currentTask = nil
        running = false
        evaluationState = .idle
        output = ""
        promptTime = ""
    }
}
================================================
FILE: app/FastVLM App/Info.plist
================================================
================================================
FILE: app/FastVLM App/InfoView.swift
================================================
//
// For licensing see accompanying LICENSE file.
// Copyright (C) 2025 Apple Inc. All Rights Reserved.
//
import Foundation
import SwiftUI
/// Sheet describing FastVLM and the app, with a citation footer and a
/// platform-appropriate dismiss control.
struct InfoView: View {
    @Environment(\.dismiss) var dismiss

    let paragraph1 = "**FastVLM¹** is a new family of Vision-Language models that makes use of **FastViTHD**, a hierarchical hybrid vision encoder that produces a small number of high-quality tokens at low latencies, resulting in significantly faster time-to-first-token (TTFT)."
    let paragraph2 = "This app showcases the **FastVLM** model in action, allowing users to freely customize the prompt. FastVLM utilizes Qwen2-Instruct LLMs without additional safety tuning, so please exercise caution when modifying the prompt."
    let footer = "1. **FastVLM: Efficient Vision Encoding for Vision Language Models.** (CVPR 2025) Pavan Kumar Anasosalu Vasu, Fartash Faghri, Chun-Liang Li, Cem Koc, Nate True, Albert Antony, Gokul Santhanam, James Gabriel, Peter Grasch, Oncel Tuzel, Hadi Pouransari"

    var body: some View {
        NavigationStack {
            VStack(alignment: .leading, spacing: 20.0) {
                // Wrapping the `String`s with `.init()` turns them into
                // `LocalizedStringKey`s, so their Markdown (e.g. **bold**) is
                // rendered even though the text lives in `String` properties.
                Text("\(.init(paragraph1))\n\n\(.init(paragraph2))\n\n")
                    .font(.body)

                Spacer()

                Text(.init(footer))
                    .font(.caption)
                    .foregroundStyle(.secondary)
            }
            .padding()
            .frame(maxWidth: .infinity, maxHeight: .infinity, alignment: .top)
            .textSelection(.enabled)
            .navigationTitle("Information")
            #if os(iOS)
            .navigationBarTitleDisplayMode(.inline)
            #endif
            .toolbar {
                #if os(iOS)
                // `.topBarLeading` replaces the deprecated `.navigationBarLeading`
                // and matches the placement used in ContentView.
                ToolbarItem(placement: .topBarLeading) {
                    Button {
                        dismiss()
                    } label: {
                        Image(systemName: "xmark.circle")
                            .resizable()
                            .frame(width: 25, height: 25)
                            .foregroundStyle(.secondary)
                    }
                    .buttonStyle(.plain)
                }
                #elseif os(macOS)
                ToolbarItem(placement: .cancellationAction) {
                    Button("Done") {
                        dismiss()
                    }
                    .buttonStyle(.bordered)
                }
                #endif
            }
        }
    }
}
// Xcode canvas preview of the information sheet.
#Preview {
InfoView()
}
================================================
FILE: app/FastVLM App/Preview Content/Preview Assets.xcassets/Contents.json
================================================
{
"info" : {
"author" : "xcode",
"version" : 1
}
}
================================================
FILE: app/FastVLM.xcodeproj/project.pbxproj
================================================
// !$*UTF8*$!
{
archiveVersion = 1;
classes = {
};
objectVersion = 77;
objects = {
/* Begin PBXBuildFile section */
019A3E1A2D78E7370055F93B /* MLX in Frameworks */ = {isa = PBXBuildFile; productRef = 019A3E192D78E7370055F93B /* MLX */; };
019A3E1C2D78E73E0055F93B /* MLXLMCommon in Frameworks */ = {isa = PBXBuildFile; productRef = 019A3E1B2D78E73E0055F93B /* MLXLMCommon */; };
019A3E1E2D78E7470055F93B /* MLXRandom in Frameworks */ = {isa = PBXBuildFile; productRef = 019A3E1D2D78E7470055F93B /* MLXRandom */; };
019A3E202D78E74C0055F93B /* MLXVLM in Frameworks */ = {isa = PBXBuildFile; productRef = 019A3E1F2D78E74C0055F93B /* MLXVLM */; };
019A3E212D78E7530055F93B /* Video.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = C35372AE2D08C32D00474D34 /* Video.framework */; };
019A3E222D78E7530055F93B /* Video.framework in Embed Frameworks */ = {isa = PBXBuildFile; fileRef = C35372AE2D08C32D00474D34 /* Video.framework */; settings = {ATTRIBUTES = (CodeSignOnCopy, RemoveHeadersOnCopy, ); }; };
C3ED544E2D790860005E20B3 /* MLXLMCommon in Frameworks */ = {isa = PBXBuildFile; productRef = C3ED544D2D790860005E20B3 /* MLXLMCommon */; };
C3ED54502D790860005E20B3 /* MLXVLM in Frameworks */ = {isa = PBXBuildFile; productRef = C3ED544F2D790860005E20B3 /* MLXVLM */; };
C3ED54522D790860005E20B3 /* MLX in Frameworks */ = {isa = PBXBuildFile; productRef = C3ED54512D790860005E20B3 /* MLX */; };
C3ED54542D790860005E20B3 /* MLXNN in Frameworks */ = {isa = PBXBuildFile; productRef = C3ED54532D790860005E20B3 /* MLXNN */; };
C3ED54582D790A68005E20B3 /* MLXFast in Frameworks */ = {isa = PBXBuildFile; productRef = C3ED54572D790A68005E20B3 /* MLXFast */; };
C3ED545B2D790AD6005E20B3 /* Transformers in Frameworks */ = {isa = PBXBuildFile; productRef = C3ED545A2D790AD6005E20B3 /* Transformers */; };
C3ED55012D7A0A7A005E20B3 /* FastVLM.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = C39BB3E62D79082A005DB8FB /* FastVLM.framework */; };
C3ED55022D7A0A7A005E20B3 /* FastVLM.framework in Embed Frameworks */ = {isa = PBXBuildFile; fileRef = C39BB3E62D79082A005DB8FB /* FastVLM.framework */; settings = {ATTRIBUTES = (CodeSignOnCopy, RemoveHeadersOnCopy, ); }; };
/* End PBXBuildFile section */
/* Begin PBXContainerItemProxy section */
019A3E232D78E7530055F93B /* PBXContainerItemProxy */ = {
isa = PBXContainerItemProxy;
containerPortal = C35EDB642D07699400757E80 /* Project object */;
proxyType = 1;
remoteGlobalIDString = C35372AD2D08C32D00474D34;
remoteInfo = Video;
};
C3ED55032D7A0A7A005E20B3 /* PBXContainerItemProxy */ = {
isa = PBXContainerItemProxy;
containerPortal = C35EDB642D07699400757E80 /* Project object */;
proxyType = 1;
remoteGlobalIDString = C39BB3E52D79082A005DB8FB;
remoteInfo = FastVLM;
};
/* End PBXContainerItemProxy section */
/* Begin PBXCopyFilesBuildPhase section */
019A3E252D78E7530055F93B /* Embed Frameworks */ = {
isa = PBXCopyFilesBuildPhase;
buildActionMask = 2147483647;
dstPath = "";
dstSubfolderSpec = 10;
files = (
C3ED55022D7A0A7A005E20B3 /* FastVLM.framework in Embed Frameworks */,
019A3E222D78E7530055F93B /* Video.framework in Embed Frameworks */,
);
name = "Embed Frameworks";
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXCopyFilesBuildPhase section */
/* Begin PBXFileReference section */
019A3E0A2D78E6A00055F93B /* FastVLM App.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = "FastVLM App.app"; sourceTree = BUILT_PRODUCTS_DIR; };
12FFAF3D2DC93583009C4EFA /* get_pretrained_mlx_model.sh */ = {isa = PBXFileReference; lastKnownFileType = text.script.sh; path = get_pretrained_mlx_model.sh; sourceTree = ""; };
12FFAF3E2DC93583009C4EFA /* README.md */ = {isa = PBXFileReference; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = ""; };
C35372AE2D08C32D00474D34 /* Video.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = Video.framework; sourceTree = BUILT_PRODUCTS_DIR; };
C39BB3E62D79082A005DB8FB /* FastVLM.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = FastVLM.framework; sourceTree = BUILT_PRODUCTS_DIR; };
/* End PBXFileReference section */
/* Begin PBXFileSystemSynchronizedBuildFileExceptionSet section */
120A44852D9B05A900E244A3 /* Exceptions for "FastVLM App" folder in "FastVLM App" target */ = {
isa = PBXFileSystemSynchronizedBuildFileExceptionSet;
membershipExceptions = (
Info.plist,
);
target = 019A3E092D78E6A00055F93B /* FastVLM App */;
};
C35372B92D08C32D00474D34 /* Exceptions for "Video" folder in "Video" target */ = {
isa = PBXFileSystemSynchronizedBuildFileExceptionSet;
publicHeaders = (
Video.h,
);
target = C35372AD2D08C32D00474D34 /* Video */;
};
C3ED54BB2D791BEA005E20B3 /* Exceptions for "FastVLM" folder in "FastVLM" target */ = {
isa = PBXFileSystemSynchronizedBuildFileExceptionSet;
publicHeaders = (
FastVLM.h,
);
target = C39BB3E52D79082A005DB8FB /* FastVLM */;
};
/* End PBXFileSystemSynchronizedBuildFileExceptionSet section */
/* Begin PBXFileSystemSynchronizedRootGroup section */
019A3E0B2D78E6A00055F93B /* FastVLM App */ = {
isa = PBXFileSystemSynchronizedRootGroup;
exceptions = (
120A44852D9B05A900E244A3 /* Exceptions for "FastVLM App" folder in "FastVLM App" target */,
);
path = "FastVLM App";
sourceTree = "";
};
C32B4A802DA4805400EF663D /* Configuration */ = {
isa = PBXFileSystemSynchronizedRootGroup;
path = Configuration;
sourceTree = "";
};
C35372AF2D08C32D00474D34 /* Video */ = {
isa = PBXFileSystemSynchronizedRootGroup;
exceptions = (
C35372B92D08C32D00474D34 /* Exceptions for "Video" folder in "Video" target */,
);
path = Video;
sourceTree = "";
};
C39BB3E72D79082A005DB8FB /* FastVLM */ = {
isa = PBXFileSystemSynchronizedRootGroup;
exceptions = (
C3ED54BB2D791BEA005E20B3 /* Exceptions for "FastVLM" folder in "FastVLM" target */,
);
path = FastVLM;
sourceTree = "";
};
/* End PBXFileSystemSynchronizedRootGroup section */
/* Begin PBXFrameworksBuildPhase section */
019A3E072D78E6A00055F93B /* Frameworks */ = {
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
files = (
019A3E1C2D78E73E0055F93B /* MLXLMCommon in Frameworks */,
019A3E212D78E7530055F93B /* Video.framework in Frameworks */,
C3ED55012D7A0A7A005E20B3 /* FastVLM.framework in Frameworks */,
019A3E1E2D78E7470055F93B /* MLXRandom in Frameworks */,
019A3E202D78E74C0055F93B /* MLXVLM in Frameworks */,
019A3E1A2D78E7370055F93B /* MLX in Frameworks */,
);
runOnlyForDeploymentPostprocessing = 0;
};
C35372AB2D08C32D00474D34 /* Frameworks */ = {
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
C39BB3E32D79082A005DB8FB /* Frameworks */ = {
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
files = (
C3ED545B2D790AD6005E20B3 /* Transformers in Frameworks */,
C3ED54522D790860005E20B3 /* MLX in Frameworks */,
C3ED54502D790860005E20B3 /* MLXVLM in Frameworks */,
C3ED54542D790860005E20B3 /* MLXNN in Frameworks */,
C3ED54582D790A68005E20B3 /* MLXFast in Frameworks */,
C3ED544E2D790860005E20B3 /* MLXLMCommon in Frameworks */,
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXFrameworksBuildPhase section */
/* Begin PBXGroup section */
C35EDB632D07699400757E80 = {
isa = PBXGroup;
children = (
12FFAF3E2DC93583009C4EFA /* README.md */,
12FFAF3D2DC93583009C4EFA /* get_pretrained_mlx_model.sh */,
C32B4A802DA4805400EF663D /* Configuration */,
C35372AF2D08C32D00474D34 /* Video */,
019A3E0B2D78E6A00055F93B /* FastVLM App */,
C39BB3E72D79082A005DB8FB /* FastVLM */,
C35EDB7F2D076C3C00757E80 /* Frameworks */,
C35EDB702D076A5A00757E80 /* Products */,
);
sourceTree = "";
};
C35EDB702D076A5A00757E80 /* Products */ = {
isa = PBXGroup;
children = (
C35372AE2D08C32D00474D34 /* Video.framework */,
019A3E0A2D78E6A00055F93B /* FastVLM App.app */,
C39BB3E62D79082A005DB8FB /* FastVLM.framework */,
);
name = Products;
sourceTree = "";
};
C35EDB7F2D076C3C00757E80 /* Frameworks */ = {
isa = PBXGroup;
children = (
);
name = Frameworks;
sourceTree = "";
};
/* End PBXGroup section */
/* Begin PBXHeadersBuildPhase section */
C35372A92D08C32D00474D34 /* Headers */ = {
isa = PBXHeadersBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
C39BB3E12D79082A005DB8FB /* Headers */ = {
isa = PBXHeadersBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXHeadersBuildPhase section */
/* Begin PBXNativeTarget section */
019A3E092D78E6A00055F93B /* FastVLM App */ = {
isa = PBXNativeTarget;
buildConfigurationList = 019A3E182D78E6A20055F93B /* Build configuration list for PBXNativeTarget "FastVLM App" */;
buildPhases = (
019A3E062D78E6A00055F93B /* Sources */,
019A3E072D78E6A00055F93B /* Frameworks */,
019A3E082D78E6A00055F93B /* Resources */,
019A3E252D78E7530055F93B /* Embed Frameworks */,
);
buildRules = (
);
dependencies = (
019A3E242D78E7530055F93B /* PBXTargetDependency */,
C3ED55042D7A0A7A005E20B3 /* PBXTargetDependency */,
);
fileSystemSynchronizedGroups = (
019A3E0B2D78E6A00055F93B /* FastVLM App */,
);
name = "FastVLM App";
packageProductDependencies = (
019A3E192D78E7370055F93B /* MLX */,
019A3E1B2D78E73E0055F93B /* MLXLMCommon */,
019A3E1D2D78E7470055F93B /* MLXRandom */,
019A3E1F2D78E74C0055F93B /* MLXVLM */,
);
productName = FastVLMCameraExample;
productReference = 019A3E0A2D78E6A00055F93B /* FastVLM App.app */;
productType = "com.apple.product-type.application";
};
C35372AD2D08C32D00474D34 /* Video */ = {
isa = PBXNativeTarget;
buildConfigurationList = C35372BA2D08C32D00474D34 /* Build configuration list for PBXNativeTarget "Video" */;
buildPhases = (
C35372A92D08C32D00474D34 /* Headers */,
C35372AA2D08C32D00474D34 /* Sources */,
C35372AB2D08C32D00474D34 /* Frameworks */,
C35372AC2D08C32D00474D34 /* Resources */,
);
buildRules = (
);
dependencies = (
);
fileSystemSynchronizedGroups = (
C35372AF2D08C32D00474D34 /* Video */,
);
name = Video;
packageProductDependencies = (
);
productName = Video;
productReference = C35372AE2D08C32D00474D34 /* Video.framework */;
productType = "com.apple.product-type.framework";
};
C39BB3E52D79082A005DB8FB /* FastVLM */ = {
isa = PBXNativeTarget;
buildConfigurationList = C39BB3FF2D79082A005DB8FB /* Build configuration list for PBXNativeTarget "FastVLM" */;
buildPhases = (
C39BB3E12D79082A005DB8FB /* Headers */,
C39BB3E22D79082A005DB8FB /* Sources */,
C39BB3E32D79082A005DB8FB /* Frameworks */,
C39BB3E42D79082A005DB8FB /* Resources */,
);
buildRules = (
);
dependencies = (
);
fileSystemSynchronizedGroups = (
C39BB3E72D79082A005DB8FB /* FastVLM */,
);
name = FastVLM;
packageProductDependencies = (
C3ED544D2D790860005E20B3 /* MLXLMCommon */,
C3ED544F2D790860005E20B3 /* MLXVLM */,
C3ED54512D790860005E20B3 /* MLX */,
C3ED54532D790860005E20B3 /* MLXNN */,
C3ED54572D790A68005E20B3 /* MLXFast */,
C3ED545A2D790AD6005E20B3 /* Transformers */,
);
productName = FastVLM;
productReference = C39BB3E62D79082A005DB8FB /* FastVLM.framework */;
productType = "com.apple.product-type.framework";
};
/* End PBXNativeTarget section */
/* Begin PBXProject section */
C35EDB642D07699400757E80 /* Project object */ = {
isa = PBXProject;
attributes = {
BuildIndependentTargetsInParallel = 1;
LastSwiftUpdateCheck = 1620;
LastUpgradeCheck = 1630;
TargetAttributes = {
019A3E092D78E6A00055F93B = {
CreatedOnToolsVersion = 16.2;
};
C35372AD2D08C32D00474D34 = {
CreatedOnToolsVersion = 16.0;
};
C39BB3E52D79082A005DB8FB = {
CreatedOnToolsVersion = 16.0;
LastSwiftMigration = 1600;
};
};
};
buildConfigurationList = C35EDB672D07699400757E80 /* Build configuration list for PBXProject "FastVLM" */;
developmentRegion = en;
hasScannedForEncodings = 0;
knownRegions = (
en,
Base,
);
mainGroup = C35EDB632D07699400757E80;
minimizedProjectReferenceProxies = 1;
packageReferences = (
C35EDB6A2D076A3900757E80 /* XCRemoteSwiftPackageReference "mlx-swift-examples" */,
C35EDB8A2D07777E00757E80 /* XCRemoteSwiftPackageReference "mlx-swift" */,
C3ED54592D790AC6005E20B3 /* XCRemoteSwiftPackageReference "swift-transformers" */,
);
preferredProjectObjectVersion = 77;
productRefGroup = C35EDB702D076A5A00757E80 /* Products */;
projectDirPath = "";
projectRoot = "";
targets = (
C35372AD2D08C32D00474D34 /* Video */,
019A3E092D78E6A00055F93B /* FastVLM App */,
C39BB3E52D79082A005DB8FB /* FastVLM */,
);
};
/* End PBXProject section */
/* Begin PBXResourcesBuildPhase section */
019A3E082D78E6A00055F93B /* Resources */ = {
isa = PBXResourcesBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
C35372AC2D08C32D00474D34 /* Resources */ = {
isa = PBXResourcesBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
C39BB3E42D79082A005DB8FB /* Resources */ = {
isa = PBXResourcesBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXResourcesBuildPhase section */
/* Begin PBXSourcesBuildPhase section */
019A3E062D78E6A00055F93B /* Sources */ = {
isa = PBXSourcesBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
C35372AA2D08C32D00474D34 /* Sources */ = {
isa = PBXSourcesBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
C39BB3E22D79082A005DB8FB /* Sources */ = {
isa = PBXSourcesBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXSourcesBuildPhase section */
/* Begin PBXTargetDependency section */
019A3E242D78E7530055F93B /* PBXTargetDependency */ = {
isa = PBXTargetDependency;
target = C35372AD2D08C32D00474D34 /* Video */;
targetProxy = 019A3E232D78E7530055F93B /* PBXContainerItemProxy */;
};
C3ED55042D7A0A7A005E20B3 /* PBXTargetDependency */ = {
isa = PBXTargetDependency;
target = C39BB3E52D79082A005DB8FB /* FastVLM */;
targetProxy = C3ED55032D7A0A7A005E20B3 /* PBXContainerItemProxy */;
};
/* End PBXTargetDependency section */
/* Begin XCBuildConfiguration section */
019A3E162D78E6A20055F93B /* Debug */ = {
isa = XCBuildConfiguration;
buildSettings = {
ALWAYS_SEARCH_USER_PATHS = NO;
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;
CLANG_ANALYZER_NONNULL = YES;
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
CLANG_ENABLE_MODULES = YES;
CLANG_ENABLE_OBJC_ARC = YES;
CLANG_ENABLE_OBJC_WEAK = YES;
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
CLANG_WARN_BOOL_CONVERSION = YES;
CLANG_WARN_COMMA = YES;
CLANG_WARN_CONSTANT_CONVERSION = YES;
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
CLANG_WARN_EMPTY_BODY = YES;
CLANG_WARN_ENUM_CONVERSION = YES;
CLANG_WARN_INFINITE_RECURSION = YES;
CLANG_WARN_INT_CONVERSION = YES;
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
CLANG_WARN_STRICT_PROTOTYPES = YES;
CLANG_WARN_SUSPICIOUS_MOVE = YES;
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
CLANG_WARN_UNREACHABLE_CODE = YES;
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
CODE_SIGN_ENTITLEMENTS = "FastVLM App/FastVLM.entitlements";
CODE_SIGN_STYLE = Automatic;
COPY_PHASE_STRIP = NO;
CURRENT_PROJECT_VERSION = 0.1.0;
DEAD_CODE_STRIPPING = YES;
DEBUG_INFORMATION_FORMAT = dwarf;
DEVELOPMENT_ASSET_PATHS = "\"FastVLM App/Preview Content\"";
DEVELOPMENT_TEAM = "";
ENABLE_HARDENED_RUNTIME = YES;
ENABLE_PREVIEWS = YES;
ENABLE_STRICT_OBJC_MSGSEND = YES;
ENABLE_TESTABILITY = YES;
ENABLE_USER_SCRIPT_SANDBOXING = YES;
GCC_C_LANGUAGE_STANDARD = gnu17;
GCC_DYNAMIC_NO_PIC = NO;
GCC_NO_COMMON_BLOCKS = YES;
GCC_OPTIMIZATION_LEVEL = 0;
GCC_PREPROCESSOR_DEFINITIONS = (
"DEBUG=1",
"$(inherited)",
);
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
GCC_WARN_UNDECLARED_SELECTOR = YES;
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
GENERATE_INFOPLIST_FILE = YES;
INFOPLIST_FILE = "FastVLM App/Info.plist";
INFOPLIST_KEY_CFBundleDisplayName = FastVLM;
INFOPLIST_KEY_NSCameraUsageDescription = "Use camera to get live feed of images";
"INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphoneos*]" = YES;
"INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphonesimulator*]" = YES;
"INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents[sdk=iphoneos*]" = YES;
"INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents[sdk=iphonesimulator*]" = YES;
"INFOPLIST_KEY_UILaunchScreen_Generation[sdk=iphoneos*]" = YES;
"INFOPLIST_KEY_UILaunchScreen_Generation[sdk=iphonesimulator*]" = YES;
INFOPLIST_KEY_UIRequiresFullScreen = YES;
"INFOPLIST_KEY_UIStatusBarStyle[sdk=iphoneos*]" = UIStatusBarStyleDefault;
"INFOPLIST_KEY_UIStatusBarStyle[sdk=iphonesimulator*]" = UIStatusBarStyleDefault;
INFOPLIST_KEY_UISupportedInterfaceOrientations = UIInterfaceOrientationPortrait;
IPHONEOS_DEPLOYMENT_TARGET = 18.2;
LD_RUNPATH_SEARCH_PATHS = "@executable_path/Frameworks";
"LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks";
LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
MACOSX_DEPLOYMENT_TARGET = 15.2;
MARKETING_VERSION = 1.0;
MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE;
MTL_FAST_MATH = YES;
ONLY_ACTIVE_ARCH = YES;
PRODUCT_BUNDLE_IDENTIFIER = com.apple.ml.FastVLM;
PRODUCT_NAME = "$(TARGET_NAME)";
SDKROOT = auto;
SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx xros xrsimulator";
SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)";
SWIFT_EMIT_LOC_STRINGS = YES;
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2,7";
XROS_DEPLOYMENT_TARGET = 2.2;
};
name = Debug;
};
019A3E172D78E6A20055F93B /* Release */ = {
isa = XCBuildConfiguration;
buildSettings = {
ALWAYS_SEARCH_USER_PATHS = NO;
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;
CLANG_ANALYZER_NONNULL = YES;
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
CLANG_ENABLE_MODULES = YES;
CLANG_ENABLE_OBJC_ARC = YES;
CLANG_ENABLE_OBJC_WEAK = YES;
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
CLANG_WARN_BOOL_CONVERSION = YES;
CLANG_WARN_COMMA = YES;
CLANG_WARN_CONSTANT_CONVERSION = YES;
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
CLANG_WARN_EMPTY_BODY = YES;
CLANG_WARN_ENUM_CONVERSION = YES;
CLANG_WARN_INFINITE_RECURSION = YES;
CLANG_WARN_INT_CONVERSION = YES;
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
CLANG_WARN_STRICT_PROTOTYPES = YES;
CLANG_WARN_SUSPICIOUS_MOVE = YES;
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
CLANG_WARN_UNREACHABLE_CODE = YES;
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
CODE_SIGN_ENTITLEMENTS = "FastVLM App/FastVLM.entitlements";
CODE_SIGN_STYLE = Automatic;
COPY_PHASE_STRIP = NO;
CURRENT_PROJECT_VERSION = 0.1.0;
DEAD_CODE_STRIPPING = YES;
DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym";
DEVELOPMENT_ASSET_PATHS = "\"FastVLM App/Preview Content\"";
DEVELOPMENT_TEAM = "";
ENABLE_HARDENED_RUNTIME = YES;
ENABLE_NS_ASSERTIONS = NO;
ENABLE_PREVIEWS = YES;
ENABLE_STRICT_OBJC_MSGSEND = YES;
ENABLE_USER_SCRIPT_SANDBOXING = YES;
GCC_C_LANGUAGE_STANDARD = gnu17;
GCC_NO_COMMON_BLOCKS = YES;
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
GCC_WARN_UNDECLARED_SELECTOR = YES;
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
GENERATE_INFOPLIST_FILE = YES;
INFOPLIST_FILE = "FastVLM App/Info.plist";
INFOPLIST_KEY_CFBundleDisplayName = FastVLM;
INFOPLIST_KEY_NSCameraUsageDescription = "Use camera to get live feed of images";
"INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphoneos*]" = YES;
"INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphonesimulator*]" = YES;
"INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents[sdk=iphoneos*]" = YES;
"INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents[sdk=iphonesimulator*]" = YES;
"INFOPLIST_KEY_UILaunchScreen_Generation[sdk=iphoneos*]" = YES;
"INFOPLIST_KEY_UILaunchScreen_Generation[sdk=iphonesimulator*]" = YES;
INFOPLIST_KEY_UIRequiresFullScreen = YES;
"INFOPLIST_KEY_UIStatusBarStyle[sdk=iphoneos*]" = UIStatusBarStyleDefault;
"INFOPLIST_KEY_UIStatusBarStyle[sdk=iphonesimulator*]" = UIStatusBarStyleDefault;
INFOPLIST_KEY_UISupportedInterfaceOrientations = UIInterfaceOrientationPortrait;
IPHONEOS_DEPLOYMENT_TARGET = 18.2;
LD_RUNPATH_SEARCH_PATHS = "@executable_path/Frameworks";
"LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks";
LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
MACOSX_DEPLOYMENT_TARGET = 15.2;
MARKETING_VERSION = 1.0;
MTL_ENABLE_DEBUG_INFO = NO;
MTL_FAST_MATH = YES;
PRODUCT_BUNDLE_IDENTIFIER = com.apple.ml.FastVLM;
PRODUCT_NAME = "$(TARGET_NAME)";
SDKROOT = auto;
SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx xros xrsimulator";
SWIFT_COMPILATION_MODE = wholemodule;
SWIFT_EMIT_LOC_STRINGS = YES;
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2,7";
XROS_DEPLOYMENT_TARGET = 2.2;
};
name = Release;
};
C35372B72D08C32D00474D34 /* Debug */ = {
isa = XCBuildConfiguration;
buildSettings = {
ALLOW_TARGET_PLATFORM_SPECIALIZATION = YES;
ALWAYS_SEARCH_USER_PATHS = NO;
ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
BUILD_LIBRARY_FOR_DISTRIBUTION = YES;
CLANG_ANALYZER_NONNULL = YES;
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
CLANG_ENABLE_MODULES = YES;
CLANG_ENABLE_OBJC_ARC = YES;
CLANG_ENABLE_OBJC_WEAK = YES;
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
CLANG_WARN_BOOL_CONVERSION = YES;
CLANG_WARN_COMMA = YES;
CLANG_WARN_CONSTANT_CONVERSION = YES;
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
CLANG_WARN_EMPTY_BODY = YES;
CLANG_WARN_ENUM_CONVERSION = YES;
CLANG_WARN_INFINITE_RECURSION = YES;
CLANG_WARN_INT_CONVERSION = YES;
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
CLANG_WARN_STRICT_PROTOTYPES = YES;
CLANG_WARN_SUSPICIOUS_MOVE = YES;
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
CLANG_WARN_UNREACHABLE_CODE = YES;
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
CODE_SIGN_IDENTITY = "";
CODE_SIGN_STYLE = Automatic;
COPY_PHASE_STRIP = NO;
CURRENT_PROJECT_VERSION = 1;
DEAD_CODE_STRIPPING = YES;
DEBUG_INFORMATION_FORMAT = dwarf;
DEFINES_MODULE = YES;
DYLIB_COMPATIBILITY_VERSION = 1;
DYLIB_CURRENT_VERSION = 1;
DYLIB_INSTALL_NAME_BASE = "@rpath";
ENABLE_MODULE_VERIFIER = YES;
ENABLE_STRICT_OBJC_MSGSEND = YES;
ENABLE_TESTABILITY = YES;
ENABLE_USER_SCRIPT_SANDBOXING = YES;
GCC_C_LANGUAGE_STANDARD = gnu17;
GCC_DYNAMIC_NO_PIC = NO;
GCC_NO_COMMON_BLOCKS = YES;
GCC_OPTIMIZATION_LEVEL = 0;
GCC_PREPROCESSOR_DEFINITIONS = (
"DEBUG=1",
"$(inherited)",
);
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
GCC_WARN_UNDECLARED_SELECTOR = YES;
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
GENERATE_INFOPLIST_FILE = YES;
INFOPLIST_KEY_NSHumanReadableCopyright = "";
INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks";
IPHONEOS_DEPLOYMENT_TARGET = 18.0;
LD_RUNPATH_SEARCH_PATHS = (
"@executable_path/Frameworks",
"@loader_path/Frameworks",
);
"LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = (
"@executable_path/../Frameworks",
"@loader_path/Frameworks",
);
LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
MACOSX_DEPLOYMENT_TARGET = 15.0;
MARKETING_VERSION = 1.0;
MODULE_VERIFIER_SUPPORTED_LANGUAGES = "objective-c objective-c++";
MODULE_VERIFIER_SUPPORTED_LANGUAGE_STANDARDS = "gnu17 gnu++20";
MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE;
MTL_FAST_MATH = YES;
ONLY_ACTIVE_ARCH = YES;
PRODUCT_BUNDLE_IDENTIFIER = mlx.Video;
PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)";
SDKROOT = auto;
SKIP_INSTALL = YES;
SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx xros xrsimulator";
SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)";
SWIFT_EMIT_LOC_STRINGS = YES;
SWIFT_INSTALL_OBJC_HEADER = NO;
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2,7";
VERSIONING_SYSTEM = "apple-generic";
VERSION_INFO_PREFIX = "";
XROS_DEPLOYMENT_TARGET = 2.0;
};
name = Debug;
};
C35372B82D08C32D00474D34 /* Release */ = {
isa = XCBuildConfiguration;
buildSettings = {
ALLOW_TARGET_PLATFORM_SPECIALIZATION = YES;
ALWAYS_SEARCH_USER_PATHS = NO;
ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
BUILD_LIBRARY_FOR_DISTRIBUTION = YES;
CLANG_ANALYZER_NONNULL = YES;
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
CLANG_ENABLE_MODULES = YES;
CLANG_ENABLE_OBJC_ARC = YES;
CLANG_ENABLE_OBJC_WEAK = YES;
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
CLANG_WARN_BOOL_CONVERSION = YES;
CLANG_WARN_COMMA = YES;
CLANG_WARN_CONSTANT_CONVERSION = YES;
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
CLANG_WARN_EMPTY_BODY = YES;
CLANG_WARN_ENUM_CONVERSION = YES;
CLANG_WARN_INFINITE_RECURSION = YES;
CLANG_WARN_INT_CONVERSION = YES;
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
CLANG_WARN_STRICT_PROTOTYPES = YES;
CLANG_WARN_SUSPICIOUS_MOVE = YES;
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
CLANG_WARN_UNREACHABLE_CODE = YES;
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
CODE_SIGN_IDENTITY = "";
CODE_SIGN_STYLE = Automatic;
COPY_PHASE_STRIP = NO;
CURRENT_PROJECT_VERSION = 1;
DEAD_CODE_STRIPPING = YES;
DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym";
DEFINES_MODULE = YES;
DYLIB_COMPATIBILITY_VERSION = 1;
DYLIB_CURRENT_VERSION = 1;
DYLIB_INSTALL_NAME_BASE = "@rpath";
ENABLE_MODULE_VERIFIER = YES;
ENABLE_NS_ASSERTIONS = NO;
ENABLE_STRICT_OBJC_MSGSEND = YES;
ENABLE_USER_SCRIPT_SANDBOXING = YES;
GCC_C_LANGUAGE_STANDARD = gnu17;
GCC_NO_COMMON_BLOCKS = YES;
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
GCC_WARN_UNDECLARED_SELECTOR = YES;
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
GENERATE_INFOPLIST_FILE = YES;
INFOPLIST_KEY_NSHumanReadableCopyright = "";
INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks";
IPHONEOS_DEPLOYMENT_TARGET = 18.0;
LD_RUNPATH_SEARCH_PATHS = (
"@executable_path/Frameworks",
"@loader_path/Frameworks",
);
"LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = (
"@executable_path/../Frameworks",
"@loader_path/Frameworks",
);
LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
MACOSX_DEPLOYMENT_TARGET = 15.0;
MARKETING_VERSION = 1.0;
MODULE_VERIFIER_SUPPORTED_LANGUAGES = "objective-c objective-c++";
MODULE_VERIFIER_SUPPORTED_LANGUAGE_STANDARDS = "gnu17 gnu++20";
MTL_ENABLE_DEBUG_INFO = NO;
MTL_FAST_MATH = YES;
PRODUCT_BUNDLE_IDENTIFIER = mlx.Video;
PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)";
SDKROOT = auto;
SKIP_INSTALL = YES;
SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx xros xrsimulator";
SWIFT_COMPILATION_MODE = wholemodule;
SWIFT_EMIT_LOC_STRINGS = YES;
SWIFT_INSTALL_OBJC_HEADER = NO;
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2,7";
VERSIONING_SYSTEM = "apple-generic";
VERSION_INFO_PREFIX = "";
XROS_DEPLOYMENT_TARGET = 2.0;
};
name = Release;
};
C35EDB682D07699400757E80 /* Debug */ = {
isa = XCBuildConfiguration;
baseConfigurationReferenceAnchor = C32B4A802DA4805400EF663D /* Configuration */;
baseConfigurationReferenceRelativePath = Build.xcconfig;
buildSettings = {
ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
CLANG_WARN_BOOL_CONVERSION = YES;
CLANG_WARN_COMMA = YES;
CLANG_WARN_CONSTANT_CONVERSION = YES;
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
CLANG_WARN_EMPTY_BODY = YES;
CLANG_WARN_ENUM_CONVERSION = YES;
CLANG_WARN_INFINITE_RECURSION = YES;
CLANG_WARN_INT_CONVERSION = YES;
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
CLANG_WARN_STRICT_PROTOTYPES = YES;
CLANG_WARN_SUSPICIOUS_MOVE = YES;
CLANG_WARN_UNREACHABLE_CODE = YES;
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
DEAD_CODE_STRIPPING = YES;
DEVELOPMENT_TEAM = 565ARCVNXV;
ENABLE_STRICT_OBJC_MSGSEND = YES;
ENABLE_TESTABILITY = YES;
GCC_NO_COMMON_BLOCKS = YES;
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
GCC_WARN_ABOUT_RETURN_TYPE = YES;
GCC_WARN_UNDECLARED_SELECTOR = YES;
GCC_WARN_UNINITIALIZED_AUTOS = YES;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
ONLY_ACTIVE_ARCH = YES;
};
name = Debug;
};
C35EDB692D07699400757E80 /* Release */ = {
isa = XCBuildConfiguration;
buildSettings = {
ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
CLANG_WARN_BOOL_CONVERSION = YES;
CLANG_WARN_COMMA = YES;
CLANG_WARN_CONSTANT_CONVERSION = YES;
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
CLANG_WARN_EMPTY_BODY = YES;
CLANG_WARN_ENUM_CONVERSION = YES;
CLANG_WARN_INFINITE_RECURSION = YES;
CLANG_WARN_INT_CONVERSION = YES;
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
CLANG_WARN_STRICT_PROTOTYPES = YES;
CLANG_WARN_SUSPICIOUS_MOVE = YES;
CLANG_WARN_UNREACHABLE_CODE = YES;
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
DEAD_CODE_STRIPPING = YES;
DEVELOPMENT_TEAM = 565ARCVNXV;
ENABLE_STRICT_OBJC_MSGSEND = YES;
GCC_NO_COMMON_BLOCKS = YES;
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
GCC_WARN_ABOUT_RETURN_TYPE = YES;
GCC_WARN_UNDECLARED_SELECTOR = YES;
GCC_WARN_UNINITIALIZED_AUTOS = YES;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
};
name = Release;
};
C39BB3FB2D79082A005DB8FB /* Debug */ = {
isa = XCBuildConfiguration;
buildSettings = {
ALLOW_TARGET_PLATFORM_SPECIALIZATION = YES;
ALWAYS_SEARCH_USER_PATHS = NO;
ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
BUILD_LIBRARY_FOR_DISTRIBUTION = NO;
CLANG_ANALYZER_NONNULL = YES;
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
CLANG_ENABLE_MODULES = YES;
CLANG_ENABLE_OBJC_ARC = YES;
CLANG_ENABLE_OBJC_WEAK = YES;
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
CLANG_WARN_BOOL_CONVERSION = YES;
CLANG_WARN_COMMA = YES;
CLANG_WARN_CONSTANT_CONVERSION = YES;
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
CLANG_WARN_EMPTY_BODY = YES;
CLANG_WARN_ENUM_CONVERSION = YES;
CLANG_WARN_INFINITE_RECURSION = YES;
CLANG_WARN_INT_CONVERSION = YES;
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
CLANG_WARN_STRICT_PROTOTYPES = YES;
CLANG_WARN_SUSPICIOUS_MOVE = YES;
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
CLANG_WARN_UNREACHABLE_CODE = YES;
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
CODE_SIGN_IDENTITY = "";
CODE_SIGN_STYLE = Automatic;
COPY_PHASE_STRIP = NO;
CURRENT_PROJECT_VERSION = 1;
DEAD_CODE_STRIPPING = YES;
DEBUG_INFORMATION_FORMAT = dwarf;
DEFINES_MODULE = NO;
DYLIB_COMPATIBILITY_VERSION = 1;
DYLIB_CURRENT_VERSION = 1;
DYLIB_INSTALL_NAME_BASE = "@rpath";
ENABLE_MODULE_VERIFIER = YES;
ENABLE_STRICT_OBJC_MSGSEND = YES;
ENABLE_TESTABILITY = YES;
ENABLE_USER_SCRIPT_SANDBOXING = YES;
GCC_C_LANGUAGE_STANDARD = gnu17;
GCC_DYNAMIC_NO_PIC = NO;
GCC_NO_COMMON_BLOCKS = YES;
GCC_OPTIMIZATION_LEVEL = 0;
GCC_PREPROCESSOR_DEFINITIONS = (
"DEBUG=1",
"$(inherited)",
);
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
GCC_WARN_UNDECLARED_SELECTOR = YES;
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
GENERATE_INFOPLIST_FILE = YES;
INFOPLIST_KEY_NSHumanReadableCopyright = "";
INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks";
IPHONEOS_DEPLOYMENT_TARGET = 18.0;
LD_RUNPATH_SEARCH_PATHS = (
"@executable_path/Frameworks",
"@loader_path/Frameworks",
);
"LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = (
"@executable_path/../Frameworks",
"@loader_path/Frameworks",
);
LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
MACOSX_DEPLOYMENT_TARGET = 15.0;
MARKETING_VERSION = 1.0;
MODULE_VERIFIER_SUPPORTED_LANGUAGES = "objective-c objective-c++";
MODULE_VERIFIER_SUPPORTED_LANGUAGE_STANDARDS = "gnu17 gnu++20";
MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE;
MTL_FAST_MATH = YES;
ONLY_ACTIVE_ARCH = YES;
PRODUCT_BUNDLE_IDENTIFIER = mlx.FastVLM;
PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)";
SDKROOT = auto;
SKIP_INSTALL = YES;
SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx xros xrsimulator";
SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)";
SWIFT_EMIT_LOC_STRINGS = YES;
SWIFT_INSTALL_OBJC_HEADER = NO;
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2,7";
VERSIONING_SYSTEM = "apple-generic";
VERSION_INFO_PREFIX = "";
XROS_DEPLOYMENT_TARGET = 2.0;
};
name = Debug;
};
C39BB3FC2D79082A005DB8FB /* Release */ = {
isa = XCBuildConfiguration;
buildSettings = {
ALLOW_TARGET_PLATFORM_SPECIALIZATION = YES;
ALWAYS_SEARCH_USER_PATHS = NO;
ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
BUILD_LIBRARY_FOR_DISTRIBUTION = NO;
CLANG_ANALYZER_NONNULL = YES;
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
CLANG_ENABLE_MODULES = YES;
CLANG_ENABLE_OBJC_ARC = YES;
CLANG_ENABLE_OBJC_WEAK = YES;
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
CLANG_WARN_BOOL_CONVERSION = YES;
CLANG_WARN_COMMA = YES;
CLANG_WARN_CONSTANT_CONVERSION = YES;
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
CLANG_WARN_EMPTY_BODY = YES;
CLANG_WARN_ENUM_CONVERSION = YES;
CLANG_WARN_INFINITE_RECURSION = YES;
CLANG_WARN_INT_CONVERSION = YES;
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
CLANG_WARN_STRICT_PROTOTYPES = YES;
CLANG_WARN_SUSPICIOUS_MOVE = YES;
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
CLANG_WARN_UNREACHABLE_CODE = YES;
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
CODE_SIGN_IDENTITY = "";
CODE_SIGN_STYLE = Automatic;
COPY_PHASE_STRIP = NO;
CURRENT_PROJECT_VERSION = 1;
DEAD_CODE_STRIPPING = YES;
DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym";
DEFINES_MODULE = NO;
DYLIB_COMPATIBILITY_VERSION = 1;
DYLIB_CURRENT_VERSION = 1;
DYLIB_INSTALL_NAME_BASE = "@rpath";
ENABLE_MODULE_VERIFIER = YES;
ENABLE_NS_ASSERTIONS = NO;
ENABLE_STRICT_OBJC_MSGSEND = YES;
ENABLE_USER_SCRIPT_SANDBOXING = YES;
GCC_C_LANGUAGE_STANDARD = gnu17;
GCC_NO_COMMON_BLOCKS = YES;
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
GCC_WARN_UNDECLARED_SELECTOR = YES;
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
GENERATE_INFOPLIST_FILE = YES;
INFOPLIST_KEY_NSHumanReadableCopyright = "";
INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks";
IPHONEOS_DEPLOYMENT_TARGET = 18.0;
LD_RUNPATH_SEARCH_PATHS = (
"@executable_path/Frameworks",
"@loader_path/Frameworks",
);
"LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = (
"@executable_path/../Frameworks",
"@loader_path/Frameworks",
);
LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
MACOSX_DEPLOYMENT_TARGET = 15.0;
MARKETING_VERSION = 1.0;
MODULE_VERIFIER_SUPPORTED_LANGUAGES = "objective-c objective-c++";
MODULE_VERIFIER_SUPPORTED_LANGUAGE_STANDARDS = "gnu17 gnu++20";
MTL_ENABLE_DEBUG_INFO = NO;
MTL_FAST_MATH = YES;
PRODUCT_BUNDLE_IDENTIFIER = mlx.FastVLM;
PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)";
SDKROOT = auto;
SKIP_INSTALL = YES;
SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx xros xrsimulator";
SWIFT_COMPILATION_MODE = wholemodule;
SWIFT_EMIT_LOC_STRINGS = YES;
SWIFT_INSTALL_OBJC_HEADER = NO;
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2,7";
VERSIONING_SYSTEM = "apple-generic";
VERSION_INFO_PREFIX = "";
XROS_DEPLOYMENT_TARGET = 2.0;
};
name = Release;
};
/* End XCBuildConfiguration section */
/* Begin XCConfigurationList section */
019A3E182D78E6A20055F93B /* Build configuration list for PBXNativeTarget "FastVLM App" */ = {
isa = XCConfigurationList;
buildConfigurations = (
019A3E162D78E6A20055F93B /* Debug */,
019A3E172D78E6A20055F93B /* Release */,
);
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
};
C35372BA2D08C32D00474D34 /* Build configuration list for PBXNativeTarget "Video" */ = {
isa = XCConfigurationList;
buildConfigurations = (
C35372B72D08C32D00474D34 /* Debug */,
C35372B82D08C32D00474D34 /* Release */,
);
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
};
C35EDB672D07699400757E80 /* Build configuration list for PBXProject "FastVLM" */ = {
isa = XCConfigurationList;
buildConfigurations = (
C35EDB682D07699400757E80 /* Debug */,
C35EDB692D07699400757E80 /* Release */,
);
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
};
C39BB3FF2D79082A005DB8FB /* Build configuration list for PBXNativeTarget "FastVLM" */ = {
isa = XCConfigurationList;
buildConfigurations = (
C39BB3FB2D79082A005DB8FB /* Debug */,
C39BB3FC2D79082A005DB8FB /* Release */,
);
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
};
/* End XCConfigurationList section */
/* Begin XCRemoteSwiftPackageReference section */
C35EDB6A2D076A3900757E80 /* XCRemoteSwiftPackageReference "mlx-swift-examples" */ = {
isa = XCRemoteSwiftPackageReference;
repositoryURL = "https://github.com/ml-explore/mlx-swift-examples";
requirement = {
kind = upToNextMajorVersion;
minimumVersion = 2.21.2;
};
};
C35EDB8A2D07777E00757E80 /* XCRemoteSwiftPackageReference "mlx-swift" */ = {
isa = XCRemoteSwiftPackageReference;
repositoryURL = "https://github.com/ml-explore/mlx-swift";
requirement = {
kind = upToNextMajorVersion;
minimumVersion = 0.21.2;
};
};
C3ED54592D790AC6005E20B3 /* XCRemoteSwiftPackageReference "swift-transformers" */ = {
isa = XCRemoteSwiftPackageReference;
repositoryURL = "https://github.com/huggingface/swift-transformers";
requirement = {
kind = upToNextMajorVersion;
minimumVersion = 0.1.18;
};
};
/* End XCRemoteSwiftPackageReference section */
/* Begin XCSwiftPackageProductDependency section */
019A3E192D78E7370055F93B /* MLX */ = {
isa = XCSwiftPackageProductDependency;
package = C35EDB8A2D07777E00757E80 /* XCRemoteSwiftPackageReference "mlx-swift" */;
productName = MLX;
};
019A3E1B2D78E73E0055F93B /* MLXLMCommon */ = {
isa = XCSwiftPackageProductDependency;
package = C35EDB6A2D076A3900757E80 /* XCRemoteSwiftPackageReference "mlx-swift-examples" */;
productName = MLXLMCommon;
};
019A3E1D2D78E7470055F93B /* MLXRandom */ = {
isa = XCSwiftPackageProductDependency;
package = C35EDB8A2D07777E00757E80 /* XCRemoteSwiftPackageReference "mlx-swift" */;
productName = MLXRandom;
};
019A3E1F2D78E74C0055F93B /* MLXVLM */ = {
isa = XCSwiftPackageProductDependency;
package = C35EDB6A2D076A3900757E80 /* XCRemoteSwiftPackageReference "mlx-swift-examples" */;
productName = MLXVLM;
};
C3ED544D2D790860005E20B3 /* MLXLMCommon */ = {
isa = XCSwiftPackageProductDependency;
package = C35EDB6A2D076A3900757E80 /* XCRemoteSwiftPackageReference "mlx-swift-examples" */;
productName = MLXLMCommon;
};
C3ED544F2D790860005E20B3 /* MLXVLM */ = {
isa = XCSwiftPackageProductDependency;
package = C35EDB6A2D076A3900757E80 /* XCRemoteSwiftPackageReference "mlx-swift-examples" */;
productName = MLXVLM;
};
C3ED54512D790860005E20B3 /* MLX */ = {
isa = XCSwiftPackageProductDependency;
package = C35EDB8A2D07777E00757E80 /* XCRemoteSwiftPackageReference "mlx-swift" */;
productName = MLX;
};
C3ED54532D790860005E20B3 /* MLXNN */ = {
isa = XCSwiftPackageProductDependency;
package = C35EDB8A2D07777E00757E80 /* XCRemoteSwiftPackageReference "mlx-swift" */;
productName = MLXNN;
};
C3ED54572D790A68005E20B3 /* MLXFast */ = {
isa = XCSwiftPackageProductDependency;
package = C35EDB8A2D07777E00757E80 /* XCRemoteSwiftPackageReference "mlx-swift" */;
productName = MLXFast;
};
C3ED545A2D790AD6005E20B3 /* Transformers */ = {
isa = XCSwiftPackageProductDependency;
package = C3ED54592D790AC6005E20B3 /* XCRemoteSwiftPackageReference "swift-transformers" */;
productName = Transformers;
};
/* End XCSwiftPackageProductDependency section */
};
rootObject = C35EDB642D07699400757E80 /* Project object */;
}
================================================
FILE: app/FastVLM.xcodeproj/xcshareddata/xcschemes/FastVLM App.xcscheme
================================================
================================================
FILE: app/README.md
================================================
# FastVLM
This app demonstrates the performance of **FastVLM** models for on-device visual question answering.
## Features
- FastVLM runs on iOS (18.2+) and macOS (15.2+).
- View Time-To-First-Token (TTFT) with every inference.
- All predictions are processed privately and securely using on-device models.
### Flexible Prompting
The app includes a set of built-in prompts to help you get started quickly. Tap the **Prompts** button in the top-right corner to explore them. Selecting a prompt will immediately update the active input. To create new prompts or edit existing ones, choose **Customize…** from the **Prompts** menu.
## Pretrained Model Options
There are 3 pretrained sizes of FastVLM to choose from:
- **FastVLM 0.5B**: Small and fast - great for mobile devices where speed matters.
- **FastVLM 1.5B**: Well balanced - great for larger devices where speed and accuracy matter.
- **FastVLM 7B**: Fast and accurate - ideal for situations where accuracy matters more than speed.
To download any of the FastVLM models listed above, use the [get_pretrained_mlx_model.sh](get_pretrained_mlx_model.sh) script. The script downloads the model from the web and places it in the appropriate location. Once a model has been downloaded using the steps below, no additional steps are needed to build the app in Xcode.
To explore how the other models work for your use case, re-run `get_pretrained_mlx_model.sh` with a different `--model` value, follow the prompts, and rebuild the app in Xcode.
### Download Instructions
1. Make the script executable
```shell
chmod +x app/get_pretrained_mlx_model.sh
```
2. Download FastVLM
```shell
app/get_pretrained_mlx_model.sh --model 0.5b --dest app/FastVLM/model
```
3. Open the app in Xcode, Build, and Run.
### Custom Model
In addition to the pretrained sizes of FastVLM, you can further quantize or fine-tune FastVLM to best fit your needs. To learn more, see our documentation on how to [`export the model`](../model_export#export-vlm).
Please clear the existing model in `app/FastVLM/model` before downloading or copying in a new one.
================================================
FILE: app/Video/CameraController.swift
================================================
//
// For licensing see accompanying LICENSE file.
// Copyright (C) 2025 Apple Inc. All Rights Reserved.
//
import AVFoundation
import CoreImage
#if os(iOS)
import UIKit
#endif
@Observable
public class CameraController: NSObject {
/// Sink for captured frames. Assigned by `attach(continuation:)` and cleared by `detatch()`,
/// both of which perform the mutation on `sessionQueue`.
private var framesContinuation: AsyncStream.Continuation?
/// Selects the back (`true`) or front (`false`) camera for the next session setup.
/// Changing it restarts capture: `stop()` (synchronous) followed by `start()` (asynchronous).
public var backCamera = true {
didSet {
stop()
start()
}
}
/// Video capture devices exposed to callers; presumably filled during session setup — TODO confirm.
public var devices = [AVCaptureDevice]()
/// The selected capture device, defaulting to the system default video device.
/// NOTE(review): the force-unwrap traps at initialization if `AVCaptureDevice.default(for: .video)`
/// returns nil (no video capture device available).
/// Changing it restarts capture: `stop()` (synchronous) followed by `start()` (asynchronous).
public var device: AVCaptureDevice = AVCaptureDevice.default(for: .video)! {
didSet {
stop()
start()
}
}
/// Last known camera-authorization result. Starts optimistic (`true`) and is updated by
/// `checkPermission()` and `requestPermission()`.
private var permissionGranted = true
/// The active capture session: created in `start()`, stopped and released in `stop()`.
private var captureSession: AVCaptureSession?
/// Serial queue used for session start/stop and for mutating `framesContinuation`.
private let sessionQueue = DispatchQueue(label: "sessionQueue")
/// Exposed to the Objective-C runtime (`@objc dynamic`), presumably so it can be key-value
/// observed — confirm against the rest of the file.
@objc dynamic private var rotationCoordinator : AVCaptureDevice.RotationCoordinator?
/// Keeps a KVO observation alive (presumably one tied to `rotationCoordinator` — confirm).
private var rotationObservation: NSKeyValueObservation?
public func attach(continuation: AsyncStream.Continuation) {
sessionQueue.async {
self.framesContinuation = continuation
}
}
public func detatch() {
sessionQueue.async {
self.framesContinuation = nil
}
}
public func stop() {
sessionQueue.sync { [self] in
captureSession?.stopRunning()
captureSession = nil
}
}
public func start() {
sessionQueue.async { [self] in
let captureSession = AVCaptureSession()
self.captureSession = captureSession
self.checkPermission()
self.setupCaptureSession(position: backCamera ? .back : .front)
captureSession.startRunning()
}
}
#if os(iOS)
private func setOrientation(_ orientation: UIDeviceOrientation) {
guard let captureSession else { return }
let angle: Double?
switch orientation {
case .unknown, .faceDown:
angle = nil
case .portrait, .faceUp:
angle = 90
case .portraitUpsideDown:
angle = 270
case .landscapeLeft:
angle = 0
case .landscapeRight:
angle = 180
@unknown default:
angle = nil
}
if let angle {
for output in captureSession.outputs {
output.connection(with: .video)?.videoRotationAngle = angle
}
}
}
private func updateRotation(rotation : CGFloat) {
guard let captureSession else { return }
for output in captureSession.outputs {
output.connection(with: .video)?.videoRotationAngle = rotation
}
}
#endif
func checkPermission() {
switch AVCaptureDevice.authorizationStatus(for: .video) {
case .authorized:
// The user has previously granted access to the camera.
self.permissionGranted = true
case .notDetermined:
// The user has not yet been asked for camera access.
self.requestPermission()
// Combine the two other cases into the default case
default:
self.permissionGranted = false
}
}
func requestPermission() {
// Strong reference not a problem here but might become one in the future.
AVCaptureDevice.requestAccess(for: .video) { [unowned self] granted in
self.permissionGranted = granted
}
}
func setupCaptureSession(position: AVCaptureDevice.Position) {
guard let captureSession else { return }
let videoOutput = AVCaptureVideoDataOutput()
guard permissionGranted else {
print("No permission for camera")
return
}
let deviceTypes: [AVCaptureDevice.DeviceType]
#if os(iOS)
deviceTypes = [.builtInDualCamera, .builtInWideAngleCamera]
#else
deviceTypes = [.external, .continuityCamera, .builtInWideAngleCamera]
#endif
let videoDeviceDiscoverySession = AVCaptureDevice.DiscoverySession(
deviceTypes: deviceTypes,
mediaType: .video,
position: position)
let videoDevice: AVCaptureDevice?
if videoDeviceDiscoverySession.devices.contains(self.device) {
videoDevice = self.device
} else {
videoDevice = videoDeviceDiscoverySession.devices.first
}
if devices.isEmpty {
self.devices = videoDeviceDiscoverySession.devices
}
guard
let videoDevice
else {
print("Unable to find video device")
return
}
guard let videoDeviceInput = try? AVCaptureDeviceInput(device: videoDevice) else {
print("Unable to create AVCaptureDeviceInput")
return
}
guard captureSession.canAddInput(videoDeviceInput) else {
print("Unable to add input")
return
}
captureSession.addInput(videoDeviceInput)
videoOutput.setSampleBufferDelegate(self, queue: DispatchQueue(label: "sampleBufferQueue"))
captureSession.addOutput(videoOutput)
captureSession.sessionPreset = AVCaptureSession.Preset.hd1920x1080
#if os(iOS)
rotationCoordinator = AVCaptureDevice.RotationCoordinator(device: videoDevice, previewLayer: nil)
rotationObservation = observe(\.rotationCoordinator!.videoRotationAngleForHorizonLevelCapture, options: [.initial, .new]) { [weak self] _, change in
if let nv = change.newValue {
self?.updateRotation(rotation: nv)
}
}
#endif
}
}
extension CameraController: AVCaptureVideoDataOutputSampleBufferDelegate {
    /// Delegate callback for each captured frame: forwards the buffer to the
    /// attached continuation, skipping invalid buffers and buffers that carry
    /// no image data.
    public func captureOutput(
        _ output: AVCaptureOutput, didOutput sampleBuffer: CMSampleBuffer,
        from connection: AVCaptureConnection
    ) {
        guard sampleBuffer.isValid, sampleBuffer.imageBuffer != nil else {
            return
        }
        framesContinuation?.yield(sampleBuffer)
    }
}
================================================
FILE: app/Video/CameraControlsView.swift
================================================
//
// For licensing see accompanying LICENSE file.
// Copyright (C) 2025 Apple Inc. All Rights Reserved.
//
import AVFoundation
import SwiftUI
/// Small square overlay button that flips `backCamera` (switching between the
/// back and front cameras).
public struct CameraControlsView: View {

    // Camera position flag toggled by the button.
    @Binding public var backCamera: Bool
    // NOTE(review): `device` and `devices` are accepted but not used by `body`;
    // presumably reserved for a device picker — confirm or remove.
    @Binding public var device: AVCaptureDevice
    @Binding public var devices: [AVCaptureDevice]

    /// Create the control from bindings owned by the caller.
    public init(
        backCamera: Binding,
        device: Binding,
        devices: Binding<[AVCaptureDevice]>
    ) {
        self._backCamera = backCamera
        self._device = device
        self._devices = devices
    }

    public var body: some View {
        Button {
            backCamera.toggle()
        } label: {
            // 32x32 rounded material tile with the camera-switch symbol on top.
            RoundedRectangle(cornerRadius: 8.0)
                .fill(.regularMaterial)
                .frame(width: 32.0, height: 32.0)
                .overlay(alignment: .center) {
                    // Switch cameras image
                    Image(systemName: "arrow.triangle.2.circlepath.camera.fill")
                        .foregroundStyle(.primary)
                        .padding(6.0)
                }
        }
    }
}
================================================
FILE: app/Video/CameraType.swift
================================================
//
// For licensing see accompanying LICENSE file.
// Copyright (C) 2025 Apple Inc. All Rights Reserved.
//
import Foundation
/// How camera frames are consumed: `continuous` keeps the live feed updating,
/// `single` lets the user freeze and capture one frame at a time.
public enum CameraType: String, CaseIterable {
    case continuous, single
}
================================================
FILE: app/Video/Video.h
================================================
//
// For licensing see accompanying LICENSE file.
// Copyright (C) 2025 Apple Inc. All Rights Reserved.
//
#import
//! Project version number for Video.
// NOTE(review): only declared here; the definition is presumably emitted by
// Xcode's generated versioning source for the framework target — confirm.
FOUNDATION_EXPORT double VideoVersionNumber;
//! Project version string for Video.
FOUNDATION_EXPORT const unsigned char VideoVersionString[];
================================================
FILE: app/Video/VideoFrameView.swift
================================================
//
// For licensing see accompanying LICENSE file.
// Copyright (C) 2025 Apple Inc. All Rights Reserved.
//
import AVFoundation
import CoreImage
import Foundation
import SwiftUI
/// Displays a stream of video frames
///
/// Shows the most recent frame received from `frames`, or a spinner until the
/// first one arrives. In `.single` mode an overlay button freezes the current
/// frame and hands it to `action`; tapping again resumes live updates.
public struct VideoFrameView: View {
    @Environment(\.colorScheme) private var colorScheme

    // Source of frames to display.
    public let frames: AsyncStream
    // `.single` shows the capture/resume button; switching to `.continuous` un-freezes.
    public let cameraType: CameraType
    // Called with the frozen frame when the user taps "Capture Photo".
    public let action: ((CVImageBuffer) -> Void)?

    // While true, incoming frames are dropped so the captured frame stays on screen.
    @State private var hold: Bool = false
    // Frame currently displayed.
    @State private var videoFrame: CVImageBuffer?

    // Background for the frame area, matching each platform's secondary background.
    private var backgroundColor: Color {
    #if os(iOS)
        return Color(.secondarySystemBackground)
    #elseif os(macOS)
        return Color(.secondarySystemFill)
    #else
        // When in doubt, use these values that I captured to match iOS' secondarySystemBackground
        if colorScheme == .dark {
            return Color(red: 0.11, green: 0.11, blue: 0.12)
        } else {
            return Color(red: 0.95, green: 0.95, blue: 0.97)
        }
    #endif
    }

    /// - Parameters:
    ///   - frames: stream of frames to show.
    ///   - cameraType: display mode (see `CameraType`).
    ///   - action: optional callback receiving a captured frame in `.single` mode.
    public init(
        frames: AsyncStream,
        cameraType: CameraType,
        action: ((CVImageBuffer) -> Void)?
    ) {
        self.frames = frames
        self.cameraType = cameraType
        self.action = action
    }

    public var body: some View {
        Group {
            if let videoFrame {
                _ImageView(image: videoFrame)
                    .overlay(alignment: .bottom) {
                        if cameraType == .single {
                            Button {
                                tap()
                            } label: {
                                if hold {
                                    Label("Resume", systemImage: "play.fill")
                                } else {
                                    Label("Capture Photo", systemImage: "camera.fill")
                                }
                            }
                            .clipShape(.capsule)
                            .buttonStyle(.borderedProminent)
                            .tint(hold ? .gray : .accentColor)
                            .foregroundColor(.white)
                            .padding()
                        }
                    }
            } else {
                // spinner before the camera comes up
                ProgressView()
                    .controlSize(.large)
            }
        }
        // This ensures that we take up the full 4/3 aspect ratio
        // even if we don't have an image to display
        // NOTE(review): the modifier below just fills the available space; no
        // 4/3 ratio is applied here — the comment may be stale.
        .frame(maxWidth: .infinity, maxHeight: .infinity)
        .background(backgroundColor)
        .clipShape(RoundedRectangle(cornerRadius: 10.0))
        .task {
            // feed frames to the _ImageView
            // NOTE(review): this cancellation check runs only once, before the loop.
            if Task.isCancelled {
                return
            }

            // Frames arriving while `hold` is set are dropped; the loop ends
            // when the stream finishes.
            for await frame in frames {
                if !hold {
                    videoFrame = frame
                }
            }
        }
        .onChange(of: cameraType) { _, newType in
            // No matter what, when the user switches to .continuous,
            // we need to continue showing updated frames
            if newType == .continuous {
                hold = false
            }
        }
    }

    /// Toggle between live and frozen: resume if held, otherwise freeze the
    /// current frame (if any) and pass it to `action`.
    private func tap() {
        if hold {
            // resume
            hold = false
        } else if let videoFrame {
            hold = true
            if let action {
                action(videoFrame)
            }
        }
    }
}
#if os(iOS)
/// Internal view to display a CVImageBuffer
///
/// Assigns `image` directly to the backing layer's `contents`, scaled with
/// `gravity` (aspect-fill by default).
/// NOTE(review): relies on the layer accepting a CVImageBuffer as contents — confirm.
private struct _ImageView: UIViewRepresentable {
    // Layer contents to show (the frame passed in by VideoFrameView).
    let image: Any
    var gravity = CALayerContentsGravity.resizeAspectFill

    func makeUIView(context: Context) -> UIView {
        let view = UIView()
        view.layer.contentsGravity = gravity
        return view
    }

    // Push the latest frame into the layer on every SwiftUI update.
    func updateUIView(_ uiView: UIView, context: Context) {
        uiView.layer.contents = image
    }
}
#else
/// Non-iOS (AppKit) counterpart: a layer-backed NSView whose layer contents
/// are set to `image`, scaled with `gravity`.
private struct _ImageView: NSViewRepresentable {
    // Layer contents to show (the frame passed in by VideoFrameView).
    let image: Any
    var gravity = CALayerContentsGravity.resizeAspectFill

    func makeNSView(context: Context) -> NSView {
        let view = NSView()
        // NSView is not layer-backed by default; request a layer before configuring it.
        view.wantsLayer = true
        view.layer?.contentsGravity = gravity
        return view
    }

    // Push the latest frame into the layer on every SwiftUI update.
    func updateNSView(_ uiView: NSView, context: Context) {
        uiView.layer?.contents = image
    }
}
#endif
================================================
FILE: app/get_pretrained_mlx_model.sh
================================================
#!/usr/bin/env bash
#
# For licensing see accompanying LICENSE_MODEL file.
# Copyright (C) 2025 Apple Inc. All Rights Reserved.
#
set -e

# Print usage information to stdout and terminate the script.
#   $1 - "false" for an explicit --help request (exit status 0); anything
#        else, including no argument (default "true"), is a usage error
#        (exit status 1).
show_help() {
    local is_error=${1:-true}

    printf '%s\n' \
        "Usage: $0 --model --dest " \
        "" \
        "Required arguments:" \
        " --model Size of the model to download" \
        " --dest Directory where the model will be downloaded" \
        "" \
        "Available model sizes:" \
        " 0.5b - 0.5B parameter model (FP16)" \
        " 1.5b - 1.5B parameter model (INT8)" \
        " 7b - 7B parameter model (INT4)" \
        "" \
        "Options:" \
        " --help Show help message"

    [[ "$is_error" == "false" ]] && exit 0
    exit 1
}
# Parse command line arguments.
# --model and --dest each take the next word as their value. If either one is
# the last argument, report it explicitly. Without this check the trailing
# `shift` fails on an empty argument list, and `set -e` aborts the script
# silently.
while [[ "$#" -gt 0 ]]; do
    case $1 in
        --model|--dest)
            if [[ "$#" -lt 2 ]]; then
                echo -e "Error: $1 requires a value\n"
                show_help true
            fi
            case $1 in
                --model) model_size="$2" ;;
                --dest) dest_dir="$2" ;;
            esac
            shift
            ;;
        --help) show_help false ;; # Explicit help request
        *) echo -e "Unknown parameter: $1\n"; show_help true ;; # Error case
    esac
    shift
done

# Validate required parameters
if [ -z "$model_size" ]; then
    echo -e "Error: --model parameter is required\n"
    show_help true
fi

if [ -z "$dest_dir" ]; then
    echo -e "Error: --dest parameter is required\n"
    show_help true
fi

# Map the requested size to the archive name published on the CDN.
case "$model_size" in
    "0.5b") model="llava-fastvithd_0.5b_stage3_llm.fp16" ;;
    "1.5b") model="llava-fastvithd_1.5b_stage3_llm.int8" ;;
    "7b") model="llava-fastvithd_7b_stage3_llm.int4" ;;
    *)
        echo -e "Error: Invalid model size '$model_size'\n"
        show_help true
        ;;
esac
# Delete the temporary download directory ($tmp_dir, created by download_model).
cleanup() { rm -rf "$tmp_dir"; }
# Download the archive for "$model", unpack it in a temporary directory and
# copy its contents into "$dest_dir".
#
# Reads the globals model and dest_dir. Sets the global tmp_dir, which the
# EXIT/INT/TERM trap later removes through cleanup.
# If "$dest_dir" already has content, the user is asked to confirm. That
# content is deleted only after the new archive has been downloaded and
# unpacked, so a failed download leaves the previous model in place.
download_model() {
    # Download directory
    tmp_dir=$(mktemp -d)

    # Model paths
    base_url="https://ml-site.cdn-apple.com/datasets/fastvlm"

    # Set once the user has agreed to replace the existing destination contents.
    local clear_dest=false

    # Create destination directory if it doesn't exist
    if [ ! -d "$dest_dir" ]; then
        echo "Creating destination directory: $dest_dir"
        mkdir -p "$dest_dir"
    elif [ "$(ls -A "$dest_dir")" ]; then
        echo -e "Destination directory '$dest_dir' exists and is not empty.\n"
        read -p "Do you want to clear it and continue? [y/N]: " confirm
        if [[ ! "$confirm" =~ ^[Yy]$ ]]; then
            echo -e "\nStopping."
            exit 1
        fi
        clear_dest=true
    fi

    # Create temp variables
    tmp_zip_file="${tmp_dir}/${model}.zip"
    tmp_extract_dir="${tmp_dir}/${model}"

    # Create temp extract directory
    mkdir -p "$tmp_extract_dir"

    # Download model. wget is often missing (e.g. on a stock macOS install),
    # so fall back to curl when it is not available.
    echo -e "\nDownloading '${model}' model ...\n"
    if command -v wget >/dev/null 2>&1; then
        wget -q --progress=bar:noscroll --show-progress -O "$tmp_zip_file" "$base_url/$model.zip"
    elif command -v curl >/dev/null 2>&1; then
        curl -fL --progress-bar -o "$tmp_zip_file" "$base_url/$model.zip"
    else
        echo "Error: neither wget nor curl is installed." >&2
        exit 1
    fi

    # Unzip model
    echo -e "\nUnzipping model..."
    unzip -q "$tmp_zip_file" -d "$tmp_extract_dir"

    # The new model is now unpacked, so the old destination contents can go.
    # find also removes hidden entries (e.g. .DS_Store) that a plain '*' glob
    # skips; otherwise the directory would never actually be cleared.
    if [ "$clear_dest" = true ]; then
        echo -e "\nClearing existing contents in '$dest_dir'"
        find "${dest_dir:?}" -mindepth 1 -maxdepth 1 -exec rm -rf {} +
    fi

    # Copy model files to destination directory
    # (the archive is expected to contain a top-level directory named "$model").
    echo -e "\nCopying model files to destination directory..."
    cp -r "$tmp_extract_dir/$model"/* "$dest_dir"

    # Verify destination directory exists and is not empty
    if [ ! -d "$dest_dir" ] || [ -z "$(ls -A "$dest_dir")" ]; then
        echo -e "\nModel extraction failed. Destination directory '$dest_dir' is missing or empty."
        exit 1
    fi

    echo -e "\nModel downloaded and extracted to '$dest_dir'"
}
# Remove the temporary download directory however the script ends: normal
# exit, a failing command under `set -e`, Ctrl-C (INT) or TERM. The trap is
# installed before download_model runs, so the directory it creates is always
# covered.
trap cleanup EXIT INT TERM

# Download models
download_model
================================================
FILE: get_models.sh
================================================
#!/usr/bin/env bash
#
# For licensing see accompanying LICENSE_MODEL file.
# Copyright (C) 2025 Apple Inc. All Rights Reserved.
#
# Stop at the first failure. A failed download is then never unzipped, and
# the clean-up below can never run `rm` outside the checkpoints directory.
set -e

base_url="https://ml-site.cdn-apple.com/datasets/fastvlm"

# Checkpoints to fetch: every model size (0.5b, 1.5b, 7b) at stage 2 and stage 3.
checkpoint_names=(
    llava-fastvithd_0.5b_stage2
    llava-fastvithd_0.5b_stage3
    llava-fastvithd_1.5b_stage2
    llava-fastvithd_1.5b_stage3
    llava-fastvithd_7b_stage2
    llava-fastvithd_7b_stage3
)

mkdir -p checkpoints

# Download. -O overwrites a stale archive left by an earlier interrupted run.
# With -P, wget would save the new copy as "<name>.zip.1" and the stale one
# would be unzipped.
for name in "${checkpoint_names[@]}"; do
    wget "${base_url}/${name}.zip" -O "checkpoints/${name}.zip"
done

# Extract models
cd checkpoints
for name in "${checkpoint_names[@]}"; do
    unzip -qq "${name}.zip"
done

# Clean up
for name in "${checkpoint_names[@]}"; do
    rm "${name}.zip"
done
cd -
================================================
FILE: llava/__init__.py
================================================
from .model import LlavaLlamaForCausalLM, LlavaQwen2ForCausalLM
================================================
FILE: llava/constants.py
================================================
# Serving heartbeat timing for controller/worker processes
# (presumably in seconds — TODO confirm against llava/serve).
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15

# Directory used for log output.
LOGDIR = "."

# Model Constants
# Presumably the label value skipped by the loss (PyTorch's CrossEntropyLoss
# default ignore_index is -100) — verify against the training code.
IGNORE_INDEX = -100
# Id that tokenizer_image_token (llava/mm_utils.py) splices into input_ids in
# place of each image placeholder; negative, presumably so it cannot collide
# with a real vocabulary id.
IMAGE_TOKEN_INDEX = -200
# NOTE(review): the token literals below read as empty strings in this copy;
# prompts are split on the image placeholder and str.split('') raises
# ValueError — confirm they hold the intended tag text.
DEFAULT_IMAGE_TOKEN = ""
DEFAULT_IMAGE_PATCH_TOKEN = ""
DEFAULT_IM_START_TOKEN = ""
DEFAULT_IM_END_TOKEN = ""
IMAGE_PLACEHOLDER = ""
================================================
FILE: llava/conversation.py
================================================
import dataclasses
from enum import auto, Enum
from typing import List, Tuple
import base64
from io import BytesIO
from PIL import Image
class SeparatorStyle(Enum):
    """Different separator style.

    Selects which rendering branch ``Conversation.get_prompt`` uses to join
    the system prompt, role labels, messages and separators. Values are the
    consecutive integers ``auto()`` would assign, starting at 1.
    """

    SINGLE = 1
    TWO = 2
    MPT = 3
    PLAIN = 4
    LLAMA_2 = 5
    QWEN_2 = 6  # Qwen2 chat format (<|im_start|>/<|im_end|> roles and separators)
    CHATML = 7
@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history.

    Holds a chat template (system prompt, role labels, separators and a
    ``SeparatorStyle``) together with the messages so far, and renders them
    into one prompt string with ``get_prompt``.
    """
    # System prompt placed at the start of the rendered prompt.
    system: str
    # Role labels; roles[0] is the user side (see the LLAMA_2 assertion in
    # get_prompt) and roles[1] the assistant side.
    roles: List[str]
    # [role, message] pairs. A message is a str, None/empty for the turn still
    # to be generated, or a tuple for an image-bearing turn:
    # (text, image, image_process_mode) in most code paths, while the CHATML
    # branch of get_prompt unpacks (text, images) instead.
    messages: List[List[str]]
    # Number of leading messages skipped by get_images and to_gradio_chatbot
    # (e.g. example turns baked into a template).
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    # Second separator; TWO and PLAIN alternate sep/sep2 between turns, LLAMA_2
    # appends it after assistant turns.
    sep2: str = None
    version: str = "Unknown"
    skip_next: bool = False

    def get_prompt(self):
        """Render the system prompt and all messages into one prompt string.

        The layout depends on ``self.sep_style`` (see the per-branch comments).
        If the first message is a tuple (an image-bearing turn), only its text
        part is kept and it is re-wrapped: "mmtag" versions insert two extra
        marker turns, other versions prefix the text.
        NOTE(review): the placeholder literals in that pre-processing read as
        empty strings in this copy — confirm they are intact.

        Returns:
            str: the rendered prompt. A falsy message marks the turn the model
            should complete; for most styles only its role prefix is emitted.

        Raises:
            ValueError: if ``sep_style`` is not one of the handled styles.
        """
        messages = self.messages
        if len(messages) > 0 and type(messages[0][1]) is tuple:
            # Work on a shallow copy so self.messages is left untouched.
            messages = self.messages.copy()
            init_role, init_msg = messages[0].copy()
            init_msg = init_msg[0].replace("", "").strip()
            if 'mmtag' in self.version:
                messages[0] = (init_role, init_msg)
                messages.insert(0, (self.roles[0], ""))
                messages.insert(1, (self.roles[1], "Received."))
            else:
                messages[0] = (init_role, "\n" + init_msg)

        if self.sep_style == SeparatorStyle.SINGLE:
            # "<system><sep><role>: <msg><sep>..." with a bare "<role>:" for the open turn.
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
        # NOTE: earlier experimental QWEN_2 rendering, kept commented out for reference.
        # elif self.sep_style == SeparatorStyle.QWEN_2:  # fix: add qwen2
        #     seps = [self.sep, self.sep2]
        #     ret = self.system + seps[0]
        #     ret = ""
        #     for i, (role, message) in enumerate(messages):
        #         if message:
        #             if type(message) is tuple:
        #                 message, _, _ = message
        #             ret += role + ": " + message + seps[i % 2]
        #         else:
        #             ret += role + ":"
        elif self.sep_style == SeparatorStyle.QWEN_2:  # fix: add qwen2
            # Roles already carry their own markup (e.g. "<|im_start|>user\n"),
            # so role and message are concatenated directly, each turn closed by sep.
            ret = self.system + self.sep
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
        elif self.sep_style == SeparatorStyle.CHATML:
            # "<role>\n<msg><sep>\n" per turn; an empty system prompt is omitted.
            # An image tuple is (text, images): one placeholder per image is prepended.
            ret = "" if self.system == "" else self.system + self.sep + "\n"
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, images = message
                        message = "" * len(images) + message
                    ret += role + "\n" + message + self.sep + "\n"
                else:
                    ret += role + "\n"
            return ret
        elif self.sep_style == SeparatorStyle.TWO:
            # Like SINGLE, but turns alternate between sep (even) and sep2 (odd).
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.MPT:
            # Same layout as QWEN_2: role markup + message + sep.
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
        elif self.sep_style == SeparatorStyle.LLAMA_2:
            # Llama-2 chat: the system prompt is folded into the first user turn,
            # user turns are wrapped in [INST] ... [/INST], assistant turns end with sep2.
            def wrap_sys(msg): return f"<>\n{msg}\n<>\n\n" if len(msg) > 0 else msg
            def wrap_inst(msg): return f"[INST] {msg} [/INST]"
            ret = ""

            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i == 0:
                        message = wrap_sys(self.system) + message
                    if i % 2 == 0:
                        message = wrap_inst(message)
                        ret += self.sep + message
                    else:
                        ret += " " + message + " " + self.sep2
                else:
                    ret += ""
            # Drop the leading separator characters added before the first turn.
            ret = ret.lstrip(self.sep)
        elif self.sep_style == SeparatorStyle.PLAIN:
            # No role labels: messages joined with alternating sep/sep2.
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += message + seps[i % 2]
                else:
                    ret += ""
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        return ret

    def append_message(self, role, message):
        """Append a [role, message] pair.

        ``self.messages`` must be a list; the module-level templates use tuples,
        so work on ``template.copy()`` (which rebuilds a list) before appending.
        """
        self.messages.append([role, message])

    def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672):
        """Normalize an image for display or transport.

        Args:
            image: PIL image.
            image_process_mode: "Pad" (pad to a square filled with
                (122, 116, 104)), "Default"/"Crop" (unchanged), or "Resize"
                (forced to 336x336).
            return_pil: return the PIL image instead of a base64 string.
            image_format: encoding format used for the base64 output.
            max_len, min_len: if the longer side exceeds ``max_len``, the image
                is downscaled (aspect ratio kept) so its shorter side is at most
                ``min_len`` and its longer side roughly at most ``max_len``.

        Returns:
            The processed PIL image, or its base64-encoded bytes as a str.

        Raises:
            ValueError: for an unknown ``image_process_mode``.
        """
        if image_process_mode == "Pad":
            def expand2square(pil_img, background_color=(122, 116, 104)):
                # Center the image on a square canvas sized by its longer side.
                width, height = pil_img.size
                if width == height:
                    return pil_img
                elif width > height:
                    result = Image.new(pil_img.mode, (width, width), background_color)
                    result.paste(pil_img, (0, (width - height) // 2))
                    return result
                else:
                    result = Image.new(pil_img.mode, (height, height), background_color)
                    result.paste(pil_img, ((height - width) // 2, 0))
                    return result
            image = expand2square(image)
        elif image_process_mode in ["Default", "Crop"]:
            pass
        elif image_process_mode == "Resize":
            image = image.resize((336, 336))
        else:
            raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
        if max(image.size) > max_len:
            max_hw, min_hw = max(image.size), min(image.size)
            aspect_ratio = max_hw / min_hw
            # Shorter side is capped by min_len, by the original size, and so that
            # the longer side (shortest_edge * aspect_ratio) stays within max_len.
            shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
            longest_edge = int(shortest_edge * aspect_ratio)
            W, H = image.size
            if H > W:
                H, W = longest_edge, shortest_edge
            else:
                H, W = shortest_edge, longest_edge
            image = image.resize((W, H))
        if return_pil:
            return image
        else:
            buffered = BytesIO()
            image.save(buffered, format=image_format)
            img_b64_str = base64.b64encode(buffered.getvalue()).decode()
            return img_b64_str

    def get_images(self, return_pil=False):
        """Return processed images from image-bearing user turns after ``offset``.

        Only even positions (relative to ``offset``), i.e. user turns, are
        inspected. Each tuple message is unpacked as
        (text, image, image_process_mode) and passed through ``process_image``.
        """
        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    msg, image, image_process_mode = msg
                    image = self.process_image(image, image_process_mode, return_pil=return_pil)
                    images.append(image)
        return images

    def to_gradio_chatbot(self):
        """Convert messages after ``offset`` to [[user, assistant], ...] pairs.

        User turns with an image get an inline JPEG (base64) prepended. Odd
        positions fill the assistant slot of the preceding pair.
        NOTE(review): the stored ``image_process_mode`` is ignored here; the
        image is always processed with "Default".
        """
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    msg, image, image_process_mode = msg
                    img_b64_str = self.process_image(
                        image, "Default", return_pil=False,
                        image_format='JPEG')
                    # NOTE(review): this f-string appears truncated in this copy — as
                    # shown it spans two lines, which is a SyntaxError. Restore the
                    # intended inline-image markup built from img_b64_str.
                    img_str = f'
'
                    msg = img_str + msg.replace('', '').strip()
                    ret.append([msg, None])
                else:
                    ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        """Return a copy with a fresh list of fresh [role, message] lists.

        Other fields are shared by reference; ``skip_next`` is not carried over
        (the copy gets the default False).
        """
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version)

    def dict(self):
        """Return a plain-dict snapshot (system, roles, messages, offset, sep, sep2).

        When any image is present, tuple messages are reduced to their text part.
        NOTE(review): the presence check runs get_images(), which fully processes
        and base64-encodes every image just to count them.
        """
        if len(self.get_images()) > 0:
            return {
                "system": self.system,
                "roles": self.roles,
                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,
            }
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }
# Vicuna v0 template: "###"-separated "Role: message" turns with one built-in
# example exchange; offset=2 keeps that example out of get_images and
# to_gradio_chatbot. Template messages are tuples, so use .copy() (which
# rebuilds them as a list) before append_message.
conv_vicuna_v0 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
        ("Assistant",
            "Renewable energy sources are those that can be replenished naturally in a relatively "
            "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
            "Non-renewable energy sources, on the other hand, are finite and will eventually be "
            "depleted, such as coal, oil, and natural gas. Here are some key differences between "
            "renewable and non-renewable energy sources:\n"
            "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
            "energy sources are finite and will eventually run out.\n"
            "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
            "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
            "and other negative effects.\n"
            "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
            "have lower operational costs than non-renewable sources.\n"
            "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
            "locations than non-renewable sources.\n"
            "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
            "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
            "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
            "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

# Vicuna v1 template: USER/ASSISTANT turns, alternating sep (" ") and sep2.
# NOTE(review): sep2 reads as an empty string in this copy — confirm the
# intended end-of-turn token.
conv_vicuna_v1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="",
)
# Llama-2 chat template with Meta's default safety system prompt; rendered by
# the LLAMA_2 branch of Conversation.get_prompt ([INST] ... [/INST] turns).
conv_llama_2 = Conversation(
    system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="",
    sep2="",
)

# Llama-2 layout with a vision-assistant system prompt.
conv_llava_llama_2 = Conversation(
    system="You are a helpful language and vision assistant. "
           "You are able to understand the visual content that the user provides, "
           "and assist the user with a variety of tasks using natural language.",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="",
    sep2="",
)

# MPT template: <|im_start|>/<|im_end|> markup carried in the system prompt and role labels.
conv_mpt = Conversation(
    system="""<|im_start|>system
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)

# Plain template: no system prompt or role labels; messages joined with "\n".
conv_llava_plain = Conversation(
    system="",
    roles=("", ""),
    messages=(
    ),
    offset=0,
    sep_style=SeparatorStyle.PLAIN,
    sep="\n",
)
# LLaVA v0: Vicuna-v0 style "###" turns without the built-in example exchange.
conv_llava_v0 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
    ),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

# LLaVA v0 "mmtag" variant: the "mmtag" version makes get_prompt insert
# explicit image-marker turns before the first image-bearing message.
conv_llava_v0_mmtag = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
           "The visual content will be provided with the following format: visual content.",
    roles=("Human", "Assistant"),
    messages=(
    ),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
    version="v0_mmtag",
)

# LLaVA v1: USER/ASSISTANT turns with alternating sep/sep2.
conv_llava_v1 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="",
)

# LLaVA v1 "mmtag" variant (see conv_llava_v0_mmtag).
conv_llava_v1_mmtag = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
           "The visual content will be provided with the following format: visual content.",
    roles=("USER", "ASSISTANT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="",
    version="v1_mmtag",
)
# Mistral-instruct: Llama-2 [INST] layout with no system prompt.
conv_mistral_instruct = Conversation(
    system="",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="",
    sep2="",
)

# ChatML-style prompt with a terse "Answer the questions." system message.
conv_chatml_direct = Conversation(
    system="""<|im_start|>system
Answer the questions.""",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)

# Qwen2 chat template (the module default): every turn is
# "<|im_start|>{role}\n{message}<|im_end|>\n".
conv_qwen_2 = Conversation(
    system="<|im_start|>system\nYou are a helpful assistant.",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="qwen_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.QWEN_2,
    sep="<|im_end|>\n",
)
# Earlier, unused variants of the Qwen2 template, kept commented out for
# reference; the active conv_qwen_2 is the assignment just above.
# conv_qwen_2 = Conversation(
# system="",
# roles=("user", "assistant"),
# version="qwen_v2",
# messages=(),
# offset=0,
# sep_style=SeparatorStyle.QWEN_2,
# sep=" ",
# sep2="<|im_end|>",
# )
# fix: add qwen2
# conv_qwen_2 = Conversation(
# system="A chat between a curious user and an artificial intelligence assistant. "
# "The assistant gives helpful, detailed, and polite answers to the user's questions.",
# roles=("USER", "ASSISTANT"),
# version="qwen_v2",
# messages=(),
# offset=0,
# sep_style=SeparatorStyle.QWEN_2,
# sep=" ",
# sep2="<|endoftext|>",
# )
# conv_qwen_2 = Conversation(
# system="""<|im_start|>system
# You are a helpful assistant.""",
# roles=("<|im_start|>user", "<|im_start|>assistant"),
# version="qwen_v2",
# messages=[],
# offset=0,
# sep_style=SeparatorStyle.QWEN_2,
# sep="<|im_end|>",
# sep2="<|im_end|>",
# )

# Template used when no explicit template name is given.
default_conversation = conv_qwen_2

# Template registry by name. Several names alias the same object, and the
# templates keep their messages in tuples, so callers should .copy() an entry
# (copy() rebuilds messages as a list) before append_message.
# NOTE(review): "mistral_direct" maps to the ChatML-direct template — confirm intended.
conv_templates = {
    "default": conv_qwen_2,
    "v0": conv_vicuna_v0,
    "v1": conv_vicuna_v1,
    "vicuna_v1": conv_vicuna_v1,
    "qwen_2": conv_qwen_2,
    "llama_2": conv_llama_2,
    "mistral_instruct": conv_mistral_instruct,
    "chatml_direct": conv_chatml_direct,
    "mistral_direct": conv_chatml_direct,

    "plain": conv_llava_plain,
    "v0_plain": conv_llava_plain,
    "llava_v0": conv_llava_v0,
    "v0_mmtag": conv_llava_v0_mmtag,
    "llava_v1": conv_llava_v1,
    "v1_mmtag": conv_llava_v1_mmtag,
    "llava_llama_2": conv_llava_llama_2,

    "mpt": conv_mpt,
}
if __name__ == "__main__":
    # Smoke check: render the default template (no messages yet) and print it.
    rendered_prompt = default_conversation.get_prompt()
    print("conversation:", rendered_prompt)
================================================
FILE: llava/mm_utils.py
================================================
import PIL
from PIL import Image
PIL.Image.MAX_IMAGE_PIXELS=500000000
from io import BytesIO
import base64
import torch
import math
import ast
from transformers import StoppingCriteria
from llava.constants import IMAGE_TOKEN_INDEX
def select_best_resolution(original_size, possible_resolutions):
    """
    Pick the candidate resolution that preserves the most image detail.

    Each candidate is scored by how many original pixels survive after the
    image is scaled to fit inside it (never counting more than the original
    pixel count). Ties are broken by the fewest wasted (padding) pixels, and
    remaining ties keep the earliest candidate.

    Args:
        original_size (tuple): Image size as (width, height).
        possible_resolutions (list): Candidate (width, height) pairs.

    Returns:
        tuple: The chosen (width, height), or None if there are no candidates.
    """
    src_w, src_h = original_size
    src_area = src_w * src_h

    chosen = None
    # Lexicographic score (effective pixels, -wasted pixels); strictly greater wins.
    chosen_score = (0, float('-inf'))
    for cand_w, cand_h in possible_resolutions:
        fit = min(cand_w / src_w, cand_h / src_h)
        kept = min(int(src_w * fit) * int(src_h * fit), src_area)
        waste = cand_w * cand_h - kept
        score = (kept, -waste)
        if score > chosen_score:
            chosen_score = score
            chosen = (cand_w, cand_h)
    return chosen
def resize_and_pad_image(image, target_resolution):
    """
    Letterbox an image into a target resolution.

    The image is scaled by the tighter of the two axis ratios, so it fits
    entirely inside the target; the other side is rounded up and capped at
    the target. The result is centered on a black RGB canvas of exactly the
    target size.

    Args:
        image (PIL.Image.Image): Source image.
        target_resolution (tuple): Output size as (width, height).

    Returns:
        PIL.Image.Image: RGB image of size ``target_resolution``.
    """
    src_w, src_h = image.size
    dst_w, dst_h = target_resolution

    ratio_w = dst_w / src_w
    ratio_h = dst_h / src_h
    if ratio_w < ratio_h:
        fit_w, fit_h = dst_w, min(math.ceil(src_h * ratio_w), dst_h)
    else:
        fit_w, fit_h = min(math.ceil(src_w * ratio_h), dst_w), dst_h

    scaled = image.resize((fit_w, fit_h))
    canvas = Image.new('RGB', (dst_w, dst_h), (0, 0, 0))
    top_left = ((dst_w - fit_w) // 2, (dst_h - fit_h) // 2)
    canvas.paste(scaled, top_left)
    return canvas
def divide_to_patches(image, patch_size):
    """
    Cut an image into square tiles in row-major order.

    Tiles start every ``patch_size`` pixels along each axis. Tiles on the
    right/bottom edge use the full ``patch_size`` box even when it extends
    past the image, as handed to ``image.crop``.

    Args:
        image (PIL.Image.Image): Image to split.
        patch_size (int): Side length of each tile.

    Returns:
        list: Cropped tiles, top row first, left to right within a row.
    """
    width, height = image.size
    return [
        image.crop((left, top, left + patch_size, top + patch_size))
        for top in range(0, height, patch_size)
        for left in range(0, width, patch_size)
    ]
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.

    Args:
        image_size (tuple): The size of the input image in the format (width, height).
        grid_pinpoints (list | tuple | str): Candidate resolutions, either as a
            sequence of (width, height) pairs or as a string containing the
            Python literal of such a sequence.
        patch_size (int): The size of each image patch.

    Returns:
        tuple: The shape of the image patch grid in the format (width, height).

    Raises:
        ValueError / SyntaxError: if a string ``grid_pinpoints`` is not a valid literal.
    """
    # Accept any list or tuple (subclasses included). Previously only an exact
    # list was used as-is; a tuple went to ast.literal_eval, which raises
    # ValueError for non-string input.
    if isinstance(grid_pinpoints, (list, tuple)):
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    width, height = select_best_resolution(image_size, possible_resolutions)
    return width // patch_size, height // patch_size
def process_anyres_image(image, processor, grid_pinpoints):
    """
    Process an image with variable resolutions.

    The image is letterboxed into the best-matching candidate resolution and
    cut into ``processor.crop_size['height']``-sized tiles. A square resize of
    the whole image (side ``processor.size['shortest_edge']``) is put first,
    and every piece is run through ``processor.preprocess``.

    Args:
        image (PIL.Image.Image): The input image to be processed.
        processor: The image processor object.
        grid_pinpoints (list | tuple | str): Candidate resolutions, either as a
            sequence of (width, height) pairs or as a string containing the
            Python literal of such a sequence.

    Returns:
        torch.Tensor: Processed pieces stacked along dim 0 (global view first).
    """
    # Accept any list or tuple (subclasses included). Previously a tuple was
    # handed to ast.literal_eval, which raises ValueError for non-string input.
    if isinstance(grid_pinpoints, (list, tuple)):
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    best_resolution = select_best_resolution(image.size, possible_resolutions)
    image_padded = resize_and_pad_image(image, best_resolution)

    patches = divide_to_patches(image_padded, processor.crop_size['height'])

    # Low-resolution view of the whole image, placed before the tiles.
    image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))

    image_patches = [image_original_resize] + patches
    image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
                     for image_patch in image_patches]
    return torch.stack(image_patches, dim=0)
def load_image_from_base64(image):
    """Decode a base64-encoded string and open the resulting bytes as a PIL image."""
    raw_bytes = base64.b64decode(image)
    buffer = BytesIO(raw_bytes)
    return Image.open(buffer)
def expand2square(pil_img, background_color):
    """
    Pad an image onto a square canvas with the original centred along the shorter side.

    The canvas side equals the longer of width/height and is filled with
    ``background_color``. A square input is returned unchanged (the same object).
    """
    width, height = pil_img.size
    if width == height:
        return pil_img
    side = max(width, height)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    offset = (0, (width - height) // 2) if width > height else ((height - width) // 2, 0)
    canvas.paste(pil_img, offset)
    return canvas
def process_images(images, image_processor, model_cfg):
    """
    Preprocess a batch of images according to ``model_cfg.image_aspect_ratio``.

    - ``'pad'``: each image is padded to a square filled with the processor's mean
      colour (scaled by 255), then preprocessed individually.
    - ``'anyres'``: each image is tiled via ``process_anyres_image`` using
      ``model_cfg.image_grid_pinpoints``.
    - anything else (including a missing attribute): the whole batch is handed to
      ``image_processor`` in one call and its ``pixel_values`` are returned directly.

    Args:
        images: Sequence of PIL images.
        image_processor: Image processor (callable, with ``preprocess`` and ``image_mean``).
        model_cfg: Model config; ``image_aspect_ratio`` is read with a ``None`` default.

    Returns:
        In 'pad'/'anyres' mode, a tensor stacked on a new dim 0 when every processed
        image has the same shape, otherwise the list of per-image tensors (an empty list
        for empty ``images``). In any other mode, the processor's ``pixel_values``.
    """
    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
    new_images = []
    if image_aspect_ratio == 'pad':
        for image in images:
            # Background is the processor mean scaled by 255 (presumably a [0, 1] mean).
            image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
            image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            new_images.append(image)
    elif image_aspect_ratio == "anyres":
        for image in images:
            image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
            new_images.append(image)
    else:
        return image_processor(images, return_tensors='pt')['pixel_values']
    # torch.stack rejects an empty list, and anyres outputs may differ in patch count,
    # so only stack a non-empty list whose shapes all agree.
    if new_images and all(x.shape == new_images[0].shape for x in new_images):
        new_images = torch.stack(new_images, dim=0)
    return new_images
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
    """
    Tokenize a prompt, putting ``image_token_index`` where each ``<image>`` marker was.

    The prompt is split on the literal ``'<image>'`` marker and each text chunk is
    tokenized separately. If the first chunk begins with the tokenizer's BOS id, that
    id is emitted once at the front and the first token of every chunk is dropped.

    Args:
        prompt (str): Text prompt that may contain ``<image>`` placeholders.
        tokenizer: Tokenizer called on each chunk; must expose ``bos_token_id``.
        image_token_index (int): Id inserted in place of each placeholder.
        return_tensors (str or None): ``'pt'`` to return a ``torch.long`` tensor,
            ``None`` to return a plain list.

    Returns:
        list[int] or torch.Tensor: The token ids.

    Raises:
        ValueError: If ``return_tensors`` is neither ``None`` nor ``'pt'``.
    """
    # str.split('') raises "ValueError: empty separator"; split on the image marker.
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]

    def insert_separator(X, sep):
        # Interleave sep between the elements of X (no trailing separator).
        return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    # The separator is padded by `offset` so that slicing every element with [offset:]
    # leaves exactly one image_token_index per placeholder.
    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids
def get_model_name_from_path(model_path):
    """
    Derive a model name from a checkpoint path or hub id.

    Leading/trailing slashes are ignored. The name is the last path component, except
    that a trailing ``checkpoint-*`` component is prefixed with its parent directory
    (``"runs/llava/checkpoint-500"`` -> ``"llava_checkpoint-500"``).

    Args:
        model_path (str): Local path or hub id of the model.

    Returns:
        str: The derived model name.
    """
    model_paths = model_path.strip("/").split("/")
    # A bare "checkpoint-*" path has no parent to prefix (model_paths[-2] would raise).
    if len(model_paths) > 1 and model_paths[-1].startswith('checkpoint-'):
        return model_paths[-2] + "_" + model_paths[-1]
    return model_paths[-1]
class KeywordsStoppingCriteria(StoppingCriteria):
    """
    Stop generation once every sequence in the batch ends with one of ``keywords``.

    Each sequence is checked two ways: an exact match of a keyword's token ids against
    the sequence tail, then a substring search for the keyword text in the decoded last
    ``max_keyword_len`` tokens.
    """

    def __init__(self, keywords, tokenizer, input_ids):
        """
        Args:
            keywords (list[str]): Stop strings.
            tokenizer: Tokenizer used to encode keywords and decode generated tokens.
            input_ids (torch.Tensor): Prompt ids; only ``input_ids.shape[1]`` is kept
                as the prompt length.
        """
        self.keywords = keywords
        self.keyword_ids = []
        self.max_keyword_len = 0
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            # Drop a leading BOS so the ids can match text in the middle of a sequence.
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            if len(cur_keyword_ids) > self.max_keyword_len:
                self.max_keyword_len = len(cur_keyword_ids)
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True if the single sequence ``output_ids`` (shape (1, L)) ends with a keyword."""
        num_new = output_ids.shape[1] - self.start_len
        # If output_ids holds only generated tokens (NOTE(review): presumably the case
        # when generate() is driven by inputs_embeds alone, as the Llava*ForCausalLM
        # wrappers do), num_new can be <= 0. Then `-offset:` would be `-0:` (the whole
        # sequence) or a negative offset slicing from the front, so fall back to the
        # sequence length.
        window = num_new if num_new > 0 else output_ids.shape[1]
        offset = min(window, self.max_keyword_len)
        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
        for keyword_id in self.keyword_ids:
            truncated_output_ids = output_ids[0, -keyword_id.shape[0]:]
            if torch.equal(truncated_output_ids, keyword_id):
                return True
        if offset <= 0:
            return False
        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        return any(keyword in outputs for keyword in self.keywords)

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True only when every sequence in the batch has hit a keyword."""
        outputs = []
        for i in range(output_ids.shape[0]):
            outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
        return all(outputs)
================================================
FILE: llava/model/__init__.py
================================================
# try:
from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig
from .language_model.llava_mpt import LlavaMptForCausalLM, LlavaMptConfig
from .language_model.llava_mistral import LlavaMistralForCausalLM, LlavaMistralConfig
from .language_model.llava_qwen import LlavaQwen2ForCausalLM, LlavaConfig
# except:
# pass
================================================
FILE: llava/model/apply_delta.py
================================================
"""
Usage:
python3 -m llava.model.apply_delta --base-model-path ~/model_weights/llama-7b --target-model-path ~/model_weights/vicuna-7b --delta-path lmsys/vicuna-7b-delta
"""
import argparse
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from llava import LlavaLlamaForCausalLM
def apply_delta(base_model_path, target_model_path, delta_path):
    """
    Reconstruct full weights by adding base-model tensors to a delta checkpoint in place.

    Every delta tensor that exists in the base gets the base tensor added. The multimodal
    projector (absent from the base) is kept as-is. Embedding / LM-head tensors whose
    shapes differ are updated only in their base-sized leading slice. The result and the
    delta's tokenizer are saved to ``target_model_path``.

    Args:
        base_model_path (str): Path/id of the base language model.
        target_model_path (str): Output directory.
        delta_path (str): Path/id of the LLaVA delta checkpoint.

    Raises:
        AssertionError: If a delta tensor other than the projector is missing from the
            base, or a tensor other than embeddings / LM head has a shape mismatch
            (checks are skipped under ``python -O``).
    """
    print("Loading base model")
    base = AutoModelForCausalLM.from_pretrained(
        base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)

    print("Loading delta")
    delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
    delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)

    print("Applying delta")
    # state_dict() rebuilds the full name->tensor mapping on every call; build it once
    # instead of several times per parameter.
    base_state = base.state_dict()
    for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
        if name not in base_state:
            assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
            continue
        bparam = base_state[name]
        if param.data.shape == bparam.shape:
            param.data += bparam
        else:
            assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \
                f'{name} dimension mismatch: {param.data.shape} vs {bparam.shape}'
            param.data[:bparam.shape[0], :bparam.shape[1]] += bparam

    print("Saving target model")
    delta.save_pretrained(target_model_path)
    delta_tokenizer.save_pretrained(target_model_path)
if __name__ == "__main__":
    # All three paths are mandatory string options.
    cli = argparse.ArgumentParser()
    for flag in ("--base-model-path", "--target-model-path", "--delta-path"):
        cli.add_argument(flag, type=str, required=True)
    opts = cli.parse_args()

    apply_delta(opts.base_model_path, opts.target_model_path, opts.delta_path)
================================================
FILE: llava/model/builder.py
================================================
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
import shutil
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
import torch
from llava.model import *
from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs):
    """
    Load a causal LM with its tokenizer, image processor and context length.

    The branch is chosen from case-insensitive substrings of ``model_name`` and from
    ``model_base``:

    * name contains ``llava``:
        - ``lora`` with a ``model_base``: load the base as ``LlavaLlamaForCausalLM`` with
          ``model_path``'s config, load ``non_lora_trainables.bin`` (locally or via the
          HF Hub), then apply and merge the LoRA adapter from ``model_path``.
        - any other ``model_base``: load the base with ``model_path``'s config (MPT or
          Qwen2 backbone) and load ``mm_projector.bin`` from ``model_path``.
        - no ``model_base``: load ``model_path`` directly as MPT / Mistral / OpenLM
          ("dclm") / Qwen2 according to the name.
    * otherwise a plain language model, optionally with a PEFT adapter merged onto
      ``model_base``.

    For LLaVA models the configured image special tokens are added to the tokenizer and
    the vision tower is loaded if it is not already.

    Args:
        model_path (str): Checkpoint path or hub id.
        model_base (str or None): Base model for LoRA / projector-only checkpoints.
        model_name (str): Name used to select the loading branch.
        load_8bit (bool): Load weights in 8-bit.
        load_4bit (bool): Load weights in 4-bit NF4 (ignored when ``load_8bit``).
        device_map: Passed to ``from_pretrained``; when ``device != "cuda"`` the value
            sent to ``from_pretrained`` is replaced by ``{"": device}``. The original
            value is still used for the vision tower.
        device (str): Target device.
        use_flash_attn (bool): Request the ``flash_attention_2`` attention implementation.
        **kwargs: Extra ``from_pretrained`` keyword arguments.

    Returns:
        tuple: ``(tokenizer, model, image_processor, context_len)``. ``image_processor``
        is ``None`` for non-LLaVA models. ``context_len`` is 2048 when the config has no
        ``max_sequence_length``.
    """
    kwargs = {"device_map": device_map, **kwargs}

    if device != "cuda":
        kwargs['device_map'] = {"": device}

    if load_8bit:
        kwargs['load_in_8bit'] = True
    elif load_4bit:
        kwargs['load_in_4bit'] = True
        kwargs['quantization_config'] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4'
        )
    else:
        kwargs['torch_dtype'] = torch.float16

    if use_flash_attn:
        kwargs['attn_implementation'] = 'flash_attention_2'

    if 'llava' in model_name.lower():
        # Load LLaVA model
        if 'lora' in model_name.lower() and model_base is None:
            warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
        if 'lora' in model_name.lower() and model_base is not None:
            from llava.model.language_model.llava_llama import LlavaConfig
            lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path)
            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
            print('Loading LLaVA from base model...')
            model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
            token_num, token_dim = model.lm_head.out_features, model.lm_head.in_features
            # NOTE(review): token_num is read from lm_head itself, so this condition looks
            # unreachable as written; kept unchanged.
            if model.lm_head.weight.shape[0] != token_num:
                model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
                model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))

            print('Loading additional LLaVA weights...')
            # NOTE(review): torch.load unpickles its input; only load trusted checkpoints.
            if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
                non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
            else:
                # this is probably from HF Hub
                from huggingface_hub import hf_hub_download

                def load_from_hf(repo_id, filename, subfolder=None):
                    cache_file = hf_hub_download(
                        repo_id=repo_id,
                        filename=filename,
                        subfolder=subfolder)
                    return torch.load(cache_file, map_location='cpu')
                non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
            # Strip the PEFT wrapper prefixes ('base_model.' and a doubled 'model.') so
            # the keys match the plain model's state dict.
            non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
            if any(k.startswith('model.model.') for k in non_lora_trainables):
                non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
            model.load_state_dict(non_lora_trainables, strict=False)

            from peft import PeftModel
            print('Loading LoRA weights...')
            model = PeftModel.from_pretrained(model, model_path)
            print('Merging LoRA weights...')
            model = model.merge_and_unload()
            print('Model is loaded...')
        elif model_base is not None:
            # this may be mm projector only
            print('Loading LLaVA from base model...')
            if 'mpt' in model_name.lower():
                if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')):
                    shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py'))
                tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
                cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
                model = LlavaMptForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
            else:
                tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
                cfg_pretrained = AutoConfig.from_pretrained(model_path)
                model = LlavaQwen2ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)

            mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
            mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
            model.load_state_dict(mm_projector_weights, strict=False)
        else:
            if 'mpt' in model_name.lower():
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
                model = LlavaMptForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
            elif 'mistral' in model_name.lower():
                tokenizer = AutoTokenizer.from_pretrained(model_path)
                model = LlavaMistralForCausalLM.from_pretrained(
                    model_path,
                    low_cpu_mem_usage=True,
                    **kwargs
                )
            elif 'dclm' in model_name.lower():
                # NOTE(review): LlavaOpenlmForCausalLM is not exported by llava.model in this
                # package, so this branch raises NameError unless it is provided elsewhere.
                tokenizer = AutoTokenizer.from_pretrained(model_path)
                model = LlavaOpenlmForCausalLM.from_pretrained(
                    model_path,
                    low_cpu_mem_usage=True,
                    **kwargs
                )
            else:
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
                model = LlavaQwen2ForCausalLM.from_pretrained(
                    model_path,
                    low_cpu_mem_usage=True,
                    **kwargs
                )
    else:
        # Load language model
        if model_base is not None:
            # PEFT model
            from peft import PeftModel
            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
            model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
            print(f"Loading LoRA weights from {model_path}")
            model = PeftModel.from_pretrained(model, model_path)
            print("Merging weights")
            model = model.merge_and_unload()
            print('Convert to FP16...')
            model.to(torch.float16)
        else:
            if 'mpt' in model_name.lower():
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
                model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
            else:
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
                model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)

    image_processor = None

    if 'llava' in model_name.lower():
        # Register the image special tokens the checkpoint was configured with, then
        # grow the embedding matrix to the new vocab size.
        mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
        mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
        if mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        if mm_use_im_start_end:
            tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
        model.resize_token_embeddings(len(tokenizer))

        vision_tower = model.get_vision_tower()
        if not vision_tower.is_loaded:
            vision_tower.load_model(device_map=device_map)
        if device_map != 'auto':
            vision_tower.to(device=device_map, dtype=torch.float16)
        image_processor = vision_tower.image_processor

    if hasattr(model.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    else:
        context_len = 2048

    return tokenizer, model, image_processor, context_len
================================================
FILE: llava/model/consolidate.py
================================================
"""
Usage:
python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
"""
import argparse
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from llava.model import *
from llava.model.utils import auto_upgrade
def consolidate_ckpt(src_path, dst_path):
    """
    Re-save a checkpoint as a float16 model plus a slow (``use_fast=False``) tokenizer.

    ``auto_upgrade`` is run on ``src_path`` first. Both model and tokenizer are then
    written to ``dst_path``, in that order.
    """
    print("Loading model")
    auto_upgrade(src_path)
    model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
    tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
    for artifact in (model, tokenizer):
        artifact.save_pretrained(dst_path)
if __name__ == "__main__":
    # Both the source and destination checkpoint paths are mandatory.
    cli = argparse.ArgumentParser()
    for flag in ("--src", "--dst"):
        cli.add_argument(flag, type=str, required=True)
    opts = cli.parse_args()

    consolidate_ckpt(opts.src, opts.dst)
================================================
FILE: llava/model/language_model/llava_llama.py
================================================
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModelForCausalLM, \
LlamaConfig, LlamaModel, LlamaForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput
from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
class LlavaConfig(LlamaConfig):
    """Llama config tagged with the ``llava_llama`` model type for Auto* registration."""

    model_type = "llava_llama"
class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
    """Llama backbone with the LLaVA vision modules mixed in via ``LlavaMetaModel``."""

    config_class = LlavaConfig

    def __init__(self, config: LlamaConfig):
        # Cooperative init: LlavaMetaModel runs first and forwards to LlamaModel.
        super().__init__(config)
class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
    """Llama causal LM whose inputs can interleave image features with text tokens."""

    config_class = LlavaConfig

    def __init__(self, config):
        # Start the MRO lookup *after* LlamaForCausalLM so its own __init__ is not run;
        # the backbone is the multimodal LlavaLlamaModel instead.
        super(LlamaForCausalLM, self).__init__(config)
        self.model = LlavaLlamaModel(config)
        self.pretraining_tp = config.pretraining_tp
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        """Return the multimodal backbone (used by generate() for token embeddings)."""
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
        cache_position=None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Run the LM, merging ``images`` into the inputs first when ``inputs_embeds`` is None.

        ``cache_position`` is accepted but not forwarded to the parent ``forward``.
        """
        if inputs_embeds is None:
            prepared = self.prepare_inputs_labels_for_multimodal(
                input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes
            )
            input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels = prepared

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """
        Generate from token ids, splicing in image features when ``images`` is given.

        The parent ``generate`` is always driven through ``inputs_embeds``; passing
        ``inputs_embeds`` directly raises ``NotImplementedError``.
        """
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if images is None:
            inputs_embeds = self.get_model().embed_tokens(inputs)
        else:
            inputs, position_ids, attention_mask, _, inputs_embeds, _ = self.prepare_inputs_labels_for_multimodal(
                inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes
            )

        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                      inputs_embeds=None, **kwargs):
        """Defer to the parent, then re-attach ``images`` / ``image_sizes`` when present."""
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        model_inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        for key, value in (("images", images), ("image_sizes", image_sizes)):
            if value is not None:
                model_inputs[key] = value
        return model_inputs
# Register the config/model pair with the transformers Auto* factories under the
# "llava_llama" model type.
AutoConfig.register("llava_llama", LlavaConfig)
AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
================================================
FILE: llava/model/language_model/llava_mistral.py
================================================
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import AutoConfig, AutoModelForCausalLM, \
MistralConfig, MistralModel, MistralForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput
from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
class LlavaMistralConfig(MistralConfig):
    """Mistral config tagged with the ``llava_mistral`` model type for Auto* registration."""

    model_type = "llava_mistral"
class LlavaMistralModel(LlavaMetaModel, MistralModel):
    """Mistral backbone with the LLaVA vision modules mixed in via ``LlavaMetaModel``."""

    config_class = LlavaMistralConfig

    def __init__(self, config: MistralConfig):
        # Cooperative init: LlavaMetaModel runs first and forwards to MistralModel.
        super().__init__(config)
class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM):
    """Mistral causal LM whose inputs can interleave image features with text tokens."""

    config_class = LlavaMistralConfig

    def __init__(self, config):
        # Start the MRO lookup after MistralForCausalLM so its own __init__ is not run;
        # the backbone is the multimodal LlavaMistralModel instead.
        super(MistralForCausalLM, self).__init__(config)
        self.model = LlavaMistralModel(config)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        """Return the multimodal backbone (used by generate() for token embeddings)."""
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
        cache_position=None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Run the LM, merging ``images`` into the inputs first when ``inputs_embeds`` is None.

        ``cache_position`` is accepted (and, like the Llama and Qwen2 wrappers, not
        forwarded) so callers that pass it do not get an unexpected-keyword TypeError.
        """
        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images,
                image_sizes
            )

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """
        Generate from token ids, splicing in image features when ``images`` is given.

        The parent ``generate`` is always driven through ``inputs_embeds``; passing
        ``inputs_embeds`` directly raises ``NotImplementedError``.
        """
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if images is not None:
            (
                inputs,
                position_ids,
                attention_mask,
                _,
                inputs_embeds,
                _
            ) = self.prepare_inputs_labels_for_multimodal(
                inputs,
                position_ids,
                attention_mask,
                None,
                None,
                images,
                image_sizes=image_sizes
            )
        else:
            inputs_embeds = self.get_model().embed_tokens(inputs)

        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                      inputs_embeds=None, **kwargs):
        """Defer to the parent, then re-attach ``images`` / ``image_sizes`` when present."""
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            inputs['images'] = images
        if image_sizes is not None:
            inputs['image_sizes'] = image_sizes
        return inputs
# Register the config/model pair with the transformers Auto* factories under the
# "llava_mistral" model type.
AutoConfig.register("llava_mistral", LlavaMistralConfig)
AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM)
================================================
FILE: llava/model/language_model/llava_mpt.py
================================================
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import torch
from transformers import AutoConfig, AutoModelForCausalLM, \
MptConfig, MptForCausalLM, MptModel
from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
class LlavaMptConfig(MptConfig):
    """MPT config tagged with the ``llava_mpt`` model type for Auto* registration."""

    model_type = "llava_mpt"
class LlavaMptModel(LlavaMetaModel, MptModel):
    """MPT backbone with the LLaVA vision modules mixed in via ``LlavaMetaModel``."""

    config_class = LlavaMptConfig

    def __init__(self, config: MptConfig):
        # LlavaMetaModel reads config.hidden_size (e.g. for image_newline); MPT calls it
        # d_model, so alias it on the config before initialising.
        config.hidden_size = config.d_model
        super().__init__(config)

    def embed_tokens(self, x):
        """Expose MPT's ``wte`` lookup under the ``embed_tokens`` name the LLaVA code calls."""
        return self.wte(x)
class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM):
    """MPT causal LM whose inputs can interleave image features with text tokens."""

    config_class = LlavaMptConfig
    supports_gradient_checkpointing = True

    def __init__(self, config):
        # Start the MRO lookup after MptForCausalLM so its own __init__ is not run;
        # the backbone is the multimodal LlavaMptModel instead.
        super(MptForCausalLM, self).__init__(config)

        self.transformer = LlavaMptModel(config)
        self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        """Return the multimodal backbone (MPT stores it as ``transformer``)."""
        return self.transformer

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, LlavaMptModel):
            module.gradient_checkpointing = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        images=None):
        """Merge ``images`` into the inputs, then run the MPT causal LM."""
        # The Llama / Mistral / Qwen2 wrappers call prepare_inputs_labels_for_multimodal
        # as (input_ids, position_ids, attention_mask, past_key_values, labels, images,
        # image_sizes) and unpack six results; the old 5-argument call shifted every
        # argument by one. MPT's forward takes no position_ids, so pass None and drop it.
        (
            input_ids,
            _,
            attention_mask,
            past_key_values,
            inputs_embeds,
            labels
        ) = self.prepare_inputs_labels_for_multimodal(
            input_ids,
            None,
            attention_mask,
            past_key_values,
            labels,
            images,
            None
        )
        return super().forward(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        """Defer to the parent, then always attach ``images`` (possibly None)."""
        images = kwargs.pop("images", None)
        _inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        _inputs['images'] = images
        return _inputs
# Register the config/model pair with the transformers Auto* factories under the
# "llava_mpt" model type.
AutoConfig.register("llava_mpt", LlavaMptConfig)
AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM)
================================================
FILE: llava/model/language_model/llava_qwen.py
================================================
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModelForCausalLM, Qwen2Config, Qwen2Model, Qwen2ForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput
from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
class LlavaConfig(Qwen2Config):
    """Qwen2 config tagged with the ``llava_qwen2`` model type for Auto* registration."""

    model_type = "llava_qwen2"
class LlavaQwen2Model(LlavaMetaModel, Qwen2Model):
    """Qwen2 backbone with the LLaVA vision modules mixed in via ``LlavaMetaModel``."""

    config_class = LlavaConfig

    def __init__(self, config: Qwen2Config):
        # Cooperative init: LlavaMetaModel runs first and forwards to Qwen2Model.
        super().__init__(config)
class LlavaQwen2ForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
    """Qwen2 causal LM whose inputs can interleave image features with text tokens."""

    config_class = LlavaConfig

    def __init__(self, config):
        # Start the MRO lookup *after* Qwen2ForCausalLM so its own __init__ is not run;
        # the backbone is the multimodal LlavaQwen2Model instead.
        super(Qwen2ForCausalLM, self).__init__(config)
        self.model = LlavaQwen2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        """Return the multimodal backbone (used by generate() for token embeddings)."""
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
        cache_position=None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Run the LM, merging ``images`` into the inputs first when ``inputs_embeds`` is None.

        ``cache_position`` is accepted but not forwarded to the parent ``forward``.
        """
        if inputs_embeds is None:
            prepared = self.prepare_inputs_labels_for_multimodal(
                input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes
            )
            input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels = prepared

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """
        Generate from token ids, splicing in image features when ``images`` is given.

        The parent ``generate`` is always driven through ``inputs_embeds``; passing
        ``inputs_embeds`` directly raises ``NotImplementedError``.
        """
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if images is None:
            inputs_embeds = self.get_model().embed_tokens(inputs)
        else:
            inputs, position_ids, attention_mask, _, inputs_embeds, _ = self.prepare_inputs_labels_for_multimodal(
                inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes
            )

        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                      inputs_embeds=None, **kwargs):
        """Defer to the parent, then re-attach ``images`` / ``image_sizes`` when present."""
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        model_inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        for key, value in (("images", images), ("image_sizes", image_sizes)):
            if value is not None:
                model_inputs[key] = value
        return model_inputs
# Register the config/model pair with the transformers Auto* factories under the
# "llava_qwen2" model type.
AutoConfig.register("llava_qwen2", LlavaConfig)
AutoModelForCausalLM.register(LlavaConfig, LlavaQwen2ForCausalLM)
================================================
FILE: llava/model/llava_arch.py
================================================
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
import torch
import torch.nn as nn
from .multimodal_encoder.builder import build_vision_tower
from .multimodal_projector.builder import build_vision_projector
from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.mm_utils import get_anyres_image_grid_shape
class LlavaMetaModel:
    """Mixin that attaches a vision tower and a multimodal projector to a language model.

    Intended to be combined with a model class: ``__init__`` forwards ``config`` to the
    next class in the MRO, and the methods rely on it for ``self.config`` / ``self.dtype``.
    """

    def __init__(self, config):
        super(LlavaMetaModel, self).__init__(config)

        if hasattr(config, "mm_vision_tower"):
            # delay_load=True defers loading tower weights; initialize_vision_modules()
            # later calls load_model() on the existing tower.
            self.vision_tower = build_vision_tower(config, delay_load=True)
            self.mm_projector = build_vision_projector(config)

            if 'unpad' in getattr(config, 'mm_patch_merge_type', ''):
                # Uninitialised storage; presumably overwritten by checkpoint weights.
                self.image_newline = nn.Parameter(
                    torch.empty(config.hidden_size, dtype=self.dtype)
                )

    def get_vision_tower(self):
        """Return the vision tower (or ``None``), unwrapping the list used under FSDP."""
        vision_tower = getattr(self, 'vision_tower', None)
        if isinstance(vision_tower, list):
            vision_tower = vision_tower[0]
        return vision_tower

    def initialize_vision_modules(self, model_args, fsdp=None):
        """Create/load the vision tower and projector from training arguments.

        Copies the vision settings from ``model_args`` into ``self.config``, builds the
        vision tower if none exists (otherwise calls its ``load_model()``), builds the
        projector if missing (otherwise re-enables its gradients), and optionally loads
        projector weights from ``model_args.pretrain_mm_mlp_adapter``.

        Args:
            model_args: Object exposing ``vision_tower``, ``mm_vision_select_layer``,
                ``mm_vision_select_feature``, ``pretrain_mm_mlp_adapter``,
                ``mm_patch_merge_type`` and optionally ``mm_projector_type``.
            fsdp: FSDP setting; when non-empty the tower is stored wrapped in a list.
        """
        vision_tower = model_args.vision_tower
        mm_vision_select_layer = model_args.mm_vision_select_layer
        mm_vision_select_feature = model_args.mm_vision_select_feature
        pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
        mm_patch_merge_type = model_args.mm_patch_merge_type

        self.config.mm_vision_tower = vision_tower

        if self.get_vision_tower() is None:
            vision_tower = build_vision_tower(model_args)
            # Under FSDP the tower is kept inside a plain list (presumably so it is not
            # registered/wrapped as a submodule); get_vision_tower() unwraps it.
            if fsdp:
                self.vision_tower = [vision_tower]
            else:
                self.vision_tower = vision_tower
        else:
            if fsdp:
                vision_tower = self.vision_tower[0]
            else:
                vision_tower = self.vision_tower
            vision_tower.load_model()

        self.config.use_mm_proj = True
        self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
        self.config.mm_hidden_size = vision_tower.hidden_size
        self.config.mm_vision_select_layer = mm_vision_select_layer
        self.config.mm_vision_select_feature = mm_vision_select_feature
        self.config.mm_patch_merge_type = mm_patch_merge_type

        if getattr(self, 'mm_projector', None) is None:
            self.mm_projector = build_vision_projector(self.config)

            if 'unpad' in mm_patch_merge_type:
                # Random init scaled by 1/sqrt(hidden_size).
                embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
                self.image_newline = nn.Parameter(
                    torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
                )
        else:
            # In case it is frozen by LoRA
            for p in self.mm_projector.parameters():
                p.requires_grad = True

        if pretrain_mm_mlp_adapter is not None:
            # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
            mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
            self.mm_projector.load_state_dict(
                self._select_prefixed_weights(mm_projector_weights, 'mm_projector')
            )

    @staticmethod
    def _select_prefixed_weights(weights, keyword):
        """Return entries of ``weights`` found under ``keyword.``, with that prefix stripped.

        Only keys containing ``keyword`` followed by a dot are kept, so an unrelated key
        such as ``"model.mm_projector_norm"`` is ignored instead of raising ``IndexError``.
        """
        prefix = keyword + '.'
        return {k.split(prefix)[1]: v for k, v in weights.items() if prefix in k}
def unpad_image(tensor, original_size):
    """
    Crop away the padding added when an image was letterboxed into ``tensor``'s size.

    Args:
        tensor (torch.Tensor): Padded/resized image (or feature map) in CxHxW layout.
        original_size (tuple): ``(width, height)`` of the original PIL image.

    Returns:
        torch.Tensor: ``tensor`` sliced along H when the original was relatively wider,
        otherwise along W, dropping the same amount of padding from both sides.
    """
    orig_w, orig_h = original_size
    cur_h, cur_w = tensor.shape[1:]

    if orig_w / orig_h > cur_w / cur_h:
        # Width fills the frame; the padding sits above and below the content.
        content_h = int(orig_h * (cur_w / orig_w))
        pad = (cur_h - content_h) // 2
        return tensor[:, pad:cur_h - pad, :]

    # Height fills the frame (or aspect ratios match); padding sits left and right.
    content_w = int(orig_w * (cur_h / orig_h))
    pad = (cur_w - content_w) // 2
    return tensor[:, :, pad:cur_w - pad]
class LlavaMetaForCausalLM(ABC):
    """Mixin that splices projected image features into a causal LM's input embeddings.

    Subclasses implement ``get_model()``, returning the inner model that provides
    ``get_vision_tower()``, ``mm_projector``, ``embed_tokens`` and (for 'unpad' patch
    merging) ``image_newline``. The host class must also provide ``config``, ``device``,
    ``resize_token_embeddings``, ``get_input_embeddings`` and ``get_output_embeddings``,
    all of which are used below.
    """

    @abstractmethod
    def get_model(self):
        """Return the inner model holding the vision tower, projector and embeddings."""
        pass

    def get_vision_tower(self):
        """Return the inner model's vision tower (may be ``None``)."""
        return self.get_model().get_vision_tower()

    def encode_images(self, images):
        """Run ``images`` through the vision tower and then the multimodal projector."""
        image_features = self.get_model().get_vision_tower()(images)
        image_features = self.get_model().mm_projector(image_features)
        return image_features

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, position_ids, attention_mask, past_key_values, labels,
        images, image_sizes=None
    ):
        """Replace image placeholder tokens in ``input_ids`` with projected image features.

        Every ``IMAGE_TOKEN_INDEX`` occurrence is replaced by the next image's feature
        sequence (images are consumed in order across the whole batch) and labelled
        ``IGNORE_INDEX``. Positions where ``attention_mask`` is false are dropped; the
        resulting sequences are optionally truncated to ``config.tokenizer_model_max_length``
        and zero-padded to a common length on the side given by
        ``config.tokenizer_padding_side`` (default ``'right'``).

        ``images`` may be a tensor or a list of per-sample tensors (3-D entries are treated
        as a single image). For a list or a 5-D tensor, each sample's features are merged
        according to ``config.mm_patch_merge_type`` (``'flat'`` or ``'spatial...'``; an
        ``'unpad'`` variant crops padding using ``image_sizes`` and appends
        ``image_newline`` embeddings).

        Returns:
            ``(None, position_ids, attention_mask, past_key_values, inputs_embeds, labels)``.
            ``position_ids``, ``attention_mask`` and ``labels`` come back as ``None`` when
            they were passed as ``None``. With no vision tower, no images, or a single-column
            ``input_ids`` (presumably an incremental decoding step), the inputs are returned
            unchanged with ``inputs_embeds=None``.
        """
        vision_tower = self.get_vision_tower()
        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            return input_ids, position_ids, attention_mask, past_key_values, None, labels

        if type(images) is list or images.ndim == 5:
            if type(images) is list:
                images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
            # Encode every image in one pass, then split the features back per sample.
            concat_images = torch.cat([image for image in images], dim=0)
            image_features = self.encode_images(concat_images)
            split_sizes = [image.shape[0] for image in images]
            image_features = torch.split(image_features, split_sizes, dim=0)
            mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat')
            image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square')
            if mm_patch_merge_type == 'flat':
                image_features = [x.flatten(0, 1) for x in image_features]
            elif mm_patch_merge_type.startswith('spatial'):
                new_image_features = []
                for image_idx, image_feature in enumerate(image_features):
                    if image_feature.shape[0] > 1:
                        # First entry is the base image; the rest form the patch grid.
                        base_image_feature = image_feature[0]
                        image_feature = image_feature[1:]
                        height = width = self.get_vision_tower().num_patches_per_side
                        assert height * width == base_image_feature.shape[0]
                        if image_aspect_ratio == 'anyres':
                            if hasattr(self.get_vision_tower(), 's2_image_size'):
                                img_size = self.get_vision_tower().s2_image_size
                            elif isinstance(self.get_vision_tower().config, dict):
                                img_size = self.get_vision_tower().config["image_cfg"]["image_size"]
                            else:
                                img_size = self.get_vision_tower().config.image_size

                            num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, img_size)
                            image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                        else:
                            raise NotImplementedError
                        if 'unpad' in mm_patch_merge_type:
                            # (rows, cols, h, w, dim) -> (dim, rows*h, cols*w) feature map.
                            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                            image_feature = unpad_image(image_feature, image_sizes[image_idx])
                            # Append an image_newline embedding at the end of every row.
                            image_feature = torch.cat((
                                image_feature,
                                self.get_model().image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)
                            ), dim=-1)
                            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                        else:
                            # Row-major flatten of the patch grid into (tokens, dim).
                            image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
                            image_feature = image_feature.flatten(0, 3)
                        image_feature = torch.cat((base_image_feature, image_feature), dim=0)
                    else:
                        image_feature = image_feature[0]
                        if 'unpad' in mm_patch_merge_type:
                            image_feature = torch.cat((
                                image_feature,
                                self.get_model().image_newline[None].to(image_feature.device)
                            ), dim=0)
                    new_image_features.append(image_feature)
                image_features = new_image_features
            else:
                raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
        else:
            image_features = self.encode_images(images)

        # TODO: image start / end is not implemented here to support pretraining.
        if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
            raise NotImplementedError

        # Let's just add dummy tensors if they do not exist,
        # it is a headache to deal with None all the time.
        # But it is not ideal, and if you have a better idea,
        # please open an issue / submit a PR, thanks.
        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
        if labels is None:
            labels = torch.full_like(input_ids, IGNORE_INDEX)

        # remove the padding using attention_mask -- FIXME
        input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]

        new_input_embeds = []
        new_labels = []
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
            if num_images == 0:
                # Text-only sample: still consume one image entry and concatenate an empty
                # slice of it (presumably to keep the vision modules in the autograd graph).
                cur_image_features = image_features[cur_image_idx]
                cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
                cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
                new_input_embeds.append(cur_input_embeds)
                new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue

            # Split the sequence around each image token (sentinels -1 and len).
            image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
            cur_input_ids_noim = []
            cur_labels = labels[batch_idx]
            cur_labels_noim = []
            for i in range(len(image_token_indices) - 1):
                cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
                cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
            split_sizes = [x.shape[0] for x in cur_labels_noim]
            # Embed all text segments in one call, then split back per segment.
            cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
            cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
            cur_new_input_embeds = []
            cur_new_labels = []

            # Interleave: text segment, image features, text segment, ...
            for i in range(num_images + 1):
                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
                cur_new_labels.append(cur_labels_noim[i])
                if i < num_images:
                    cur_image_features = image_features[cur_image_idx]
                    cur_image_idx += 1
                    cur_new_input_embeds.append(cur_image_features)
                    cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))

            cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]

            cur_new_input_embeds = torch.cat(cur_new_input_embeds)
            cur_new_labels = torch.cat(cur_new_labels)

            new_input_embeds.append(cur_new_input_embeds)
            new_labels.append(cur_new_labels)

        # Truncate sequences to max length as image embeddings can make the sequence longer
        tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
        if tokenizer_model_max_length is not None:
            new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
            new_labels = [x[:tokenizer_model_max_length] for x in new_labels]

        # Combine them
        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)

        new_input_embeds_padded = []
        new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
        attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)

        for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
            cur_len = cur_new_embed.shape[0]
            if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
                new_input_embeds_padded.append(torch.cat((
                    torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
                    cur_new_embed
                ), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, -cur_len:] = cur_new_labels
                    attention_mask[i, -cur_len:] = True
                    position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
            else:
                new_input_embeds_padded.append(torch.cat((
                    cur_new_embed,
                    torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
                ), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, :cur_len] = cur_new_labels
                    attention_mask[i, :cur_len] = True
                    position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)

        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

        # Restore the caller's None-ness / dtype for the optional tensors.
        if _labels is None:
            new_labels = None
        else:
            new_labels = new_labels_padded

        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

        if _position_ids is None:
            position_ids = None

        return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels

    def initialize_vision_tokenizer(self, model_args, tokenizer):
        """Add image special tokens to ``tokenizer`` and resize the model embeddings.

        New start/end token embeddings are initialised to the mean of the pre-existing
        rows; with ``model_args.pretrain_mm_mlp_adapter`` set, the input rows are then
        overwritten from that checkpoint's ``model.embed_tokens.weight``. With
        ``tune_mm_mlp_adapter``, output embeddings are frozen; input embeddings are made
        trainable when start/end tokens are added and frozen when only the patch token is.

        Raises:
            ValueError: if the checkpoint's embedding shape matches neither the current
                embedding matrix nor the number of new tokens.
        """
        if model_args.mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

        if model_args.mm_use_im_start_end:
            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)

                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False

            if model_args.pretrain_mm_mlp_adapter:
                # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
                mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
                embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
                # Also guarantees input_embeddings was bound above (num_new_tokens > 0).
                assert num_new_tokens == 2
                if input_embeddings.shape == embed_tokens_weight.shape:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
                elif embed_tokens_weight.shape[0] == num_new_tokens:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
                else:
                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Number of new tokens: {num_new_tokens}.")
        elif model_args.mm_use_im_patch_token:
            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = False
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False
================================================
FILE: llava/model/make_delta.py
================================================
"""
Usage:
python3 -m llava.model.make_delta --base-model-path ~/model_weights/llama-7b --target-model-path ~/model_weights/llava-7b --delta-path ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
"""
import argparse
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from llava.model.utils import auto_upgrade
def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
    """Save ``target - base`` weight deltas (plus the target tokenizer) to ``delta_path``.

    Tensors missing from the base model may only be the mm_projector weight/bias and are
    kept as-is. Tensors whose shape differs (only ``model.embed_tokens.weight`` and
    ``lm_head.weight`` are allowed to) have the base subtracted from their leading block.

    Args:
        base_model_path: Path/name of the base causal LM.
        target_model_path: Path/name of the fine-tuned model; passed to ``auto_upgrade``
            before loading.
        delta_path: Output directory for the delta weights and tokenizer.
        hub_repo_id: If truthy, also push the results to this Hub repository.
    """
    print("Loading base model")
    base = AutoModelForCausalLM.from_pretrained(
        base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)

    print("Loading target model")
    auto_upgrade(target_model_path)
    target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)

    print("Calculating delta")
    # Build the base state dict once: state_dict() creates a fresh mapping (detaching
    # every tensor) on each call, so calling it per iteration made this loop quadratic.
    base_state = base.state_dict()
    # In-place subtraction on the state-dict tensors updates ``target`` itself,
    # because they share storage with its parameters.
    for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
        if name not in base_state:
            assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
            continue
        bparam = base_state[name]
        if param.data.shape == bparam.shape:
            param.data -= bparam
        else:
            assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {bparam.shape}'
            param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam

    print("Saving delta")
    if hub_repo_id:
        kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
    else:
        kwargs = {}
    target.save_pretrained(delta_path, **kwargs)
    target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
    target_tokenizer.save_pretrained(delta_path, **kwargs)
if __name__ == "__main__":
    # CLI entry point: the three path options are required; --hub-repo-id is optional
    # and, when given, makes make_delta push its outputs to the Hub.
    parser = argparse.ArgumentParser()
    for flag in ("--base-model-path", "--target-model-path", "--delta-path"):
        parser.add_argument(flag, type=str, required=True)
    parser.add_argument("--hub-repo-id", type=str, default=None)
    args = parser.parse_args()

    make_delta(
        args.base_model_path,
        args.target_model_path,
        args.delta_path,
        args.hub_repo_id,
    )
================================================
FILE: llava/model/multimodal_encoder/builder.py
================================================
import os
from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2
from .mobileclip_encoder import MobileCLIPVisionTower
def build_vision_tower(vision_tower_cfg, **kwargs):
    """Instantiate the vision encoder named by ``vision_tower_cfg``.

    The name is read from ``mm_vision_tower``, falling back to ``vision_tower``.
    Existing local paths, names starting with ``openai``/``laion`` and names containing
    ``ShareGPT4V`` build a CLIP tower (the multi-scale S2 variant when ``s2`` is truthy);
    names containing ``mobileclip`` (case-insensitive) build a MobileCLIP tower. Extra
    ``kwargs`` (e.g. ``delay_load``) are forwarded to the tower constructor.

    Raises:
        ValueError: if no tower name is configured or the name is not recognised.
    """
    vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
    if vision_tower is None:
        # Fail with a clear message; os.path.exists(None) would raise an opaque TypeError.
        raise ValueError('No vision tower configured: expected `mm_vision_tower` or `vision_tower` on the config.')
    is_absolute_path_exists = os.path.exists(vision_tower)
    use_s2 = getattr(vision_tower_cfg, 's2', False)
    if is_absolute_path_exists or vision_tower.startswith(("openai", "laion")) or "ShareGPT4V" in vision_tower:
        if use_s2:
            return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs)
        return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
    if "mobileclip" in vision_tower.lower():
        return MobileCLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)

    raise ValueError(f'Unknown vision tower: {vision_tower}')
================================================
FILE: llava/model/multimodal_encoder/clip_encoder.py
================================================
import torch
import torch.nn as nn
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
class CLIPVisionTower(nn.Module):
    """Wrapper around a Hugging Face ``CLIPVisionModel`` that returns one hidden layer.

    Features come from ``hidden_states[args.mm_vision_select_layer]``; with
    ``mm_vision_select_feature == 'patch'`` the leading CLS token is dropped. The tower is
    frozen unless ``args.unfreeze_mm_vision_tower`` is set.
    """

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
        self.tune_vision_tower = getattr(args, 'unfreeze_mm_vision_tower', False)
        self.input_image_size = getattr(args, 'input_image_size', None)

        if self.tune_vision_tower:
            print("CLIP Vision tower is set to tunable")

        # A tunable tower is loaded immediately even when loading was asked to be deferred.
        if not delay_load or self.tune_vision_tower:
            self.load_model()
            return

        # Deferred: keep only the config so shape-related properties still work.
        self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
        if self.input_image_size is not None:
            self.cfg_only.image_size = self.input_image_size

    def load_model(self, device_map=None):
        """Load the image processor and vision model (no-op if already loaded)."""
        if self.is_loaded:
            print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
            return

        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
        if not self.tune_vision_tower:
            self.vision_tower.requires_grad_(False)

        size = self.input_image_size
        if size is not None:
            print("Using input image size: {}".format(size))
            self.image_processor.size['shortest_edge'] = size
            crop = self.image_processor.crop_size
            crop['height'] = crop['width'] = size

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        """Pick the configured hidden layer and optionally drop the CLS token."""
        selected = image_forward_outs.hidden_states[self.select_layer]
        if self.select_feature == 'patch':
            return selected[:, 1:]
        if self.select_feature == 'cls_patch':
            return selected
        raise ValueError(f'Unexpected select feature: {self.select_feature}')

    def forward(self, images):
        """Encode ``images``; gradients are disabled unless the tower is tunable."""
        if not self.tune_vision_tower:
            with torch.no_grad():
                return self.forward_images(images)
        return self.forward_images(images)

    def forward_images(self, images):
        """Encode a batch tensor, or each tensor of a list separately (as a batch of one).

        Outputs are cast back to the dtype of the corresponding input.
        """
        if type(images) is list:
            return [
                self.feature_select(
                    self.vision_tower(
                        img.to(device=self.device, dtype=self.dtype).unsqueeze(0),
                        output_hidden_states=True,
                    )
                ).to(img.dtype)
                for img in images
            ]

        outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
        return self.feature_select(outs).to(images.dtype)

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        # Before load_model() only the config fetched in __init__ is available.
        return self.vision_tower.config if self.is_loaded else self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches_per_side(self):
        cfg = self.config
        return cfg.image_size // cfg.patch_size

    @property
    def num_patches(self):
        return self.num_patches_per_side ** 2
class CLIPVisionTowerS2(CLIPVisionTower):
    """CLIP vision tower evaluated at several image scales through ``s2wrapper``.

    ``args.s2_scales`` is a comma-separated string of image sizes (default
    ``'336,672,1008'``). After sorting, the smallest size is passed as ``max_split_size``
    and the largest is used as the preprocessing resolution. ``hidden_size`` reports the
    base width times the number of scales (presumably per-scale features are concatenated
    channel-wise by ``s2wrapper`` — confirm against that package).
    """

    def __init__(self, vision_tower, args, delay_load=False):
        # Must be set before the parent constructor, which may call self.load_model();
        # that method reads self.s2_image_size.
        self.s2_scales = sorted(int(scale) for scale in getattr(args, 's2_scales', '336,672,1008').split(','))
        self.s2_split_size = self.s2_scales[0]
        self.s2_image_size = self.s2_scales[-1]

        super().__init__(vision_tower, args, delay_load)

        try:
            from s2wrapper import forward as multiscale_forward
        except ImportError as err:
            # Chain the original error so the underlying import failure stays visible.
            raise ImportError('Package s2wrapper not found! Please install by running: \npip install git+https://github.com/bfshi/scaling_on_scales.git') from err
        self.multiscale_forward = multiscale_forward

        # change resize/crop size in preprocessing to the largest image size in s2_scale
        if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False):
            self.image_processor.size['shortest_edge'] = self.s2_image_size
            self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size

    def load_model(self, device_map=None):
        """Load processor and (always frozen) model, sizing preprocessing to the largest scale."""
        if self.is_loaded:
            print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
            return

        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
        self.vision_tower.requires_grad_(False)

        self.image_processor.size['shortest_edge'] = self.s2_image_size
        self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size

        self.is_loaded = True

    @torch.no_grad()
    def forward_feature(self, images):
        """Single-scale encoder passed to ``multiscale_forward``."""
        image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
        image_features = self.feature_select(image_forward_outs).to(images.dtype)
        return image_features

    @torch.no_grad()
    def forward(self, images):
        """Encode a batch tensor, or each tensor of a list separately, across all scales."""
        if type(images) is list:
            return [
                self.multiscale_forward(self.forward_feature, image.unsqueeze(0), img_sizes=self.s2_scales, max_split_size=self.s2_split_size)
                for image in images
            ]
        return self.multiscale_forward(self.forward_feature, images, img_sizes=self.s2_scales, max_split_size=self.s2_split_size)

    @property
    def hidden_size(self):
        return self.config.hidden_size * len(self.s2_scales)
================================================
FILE: llava/model/multimodal_encoder/mobileclip/__init__.py
================================================
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2025 Apple Inc. All Rights Reserved.
#
import os
import json
from typing import Any
import torch.nn as nn
from timm.models import create_model
from .mci import GlobalPool2D
def load_model_config(
    model_name: str,
) -> Any:
    """Load the bundled JSON config for a MobileCLIP model.

    Only the first two underscore-separated components of ``model_name`` select the
    file, e.g. ``"mobileclip_l_1024"`` resolves to ``configs/mobileclip_l.json`` next to
    this module.

    Raises:
        ValueError: if no matching config file exists.
    """
    # Strip suffixes to model name
    model_name = "_".join(model_name.split("_")[0:2])

    # Config files live in the ``configs`` directory beside this module.
    root_dir = os.path.dirname(os.path.abspath(__file__))
    model_cfg_file = os.path.join(root_dir, "configs", model_name + ".json")

    if not os.path.exists(model_cfg_file):
        raise ValueError(f"Unsupported model name: {model_name}")
    # Context manager closes the handle deterministically (json.load(open(...)) leaked it).
    with open(model_cfg_file, "r", encoding="utf-8") as f:
        return json.load(f)
class MCi(nn.Module):
    """Wrapper around an ``MCi`` image model created through timm's ``create_model``.

    If ``projection_dim`` is passed as a keyword argument and the created model has a
    ``head``, that head is replaced by a ``GlobalPool2D`` projecting to ``projection_dim``.
    """

    def __init__(self, model_name: str, *args, **kwargs) -> None:
        super().__init__()
        # None unless the caller explicitly asked for a projection head.
        self.projection_dim = kwargs.get("projection_dim")

        # Create model
        self.model = create_model(model_name, projection_dim=self.projection_dim)

        # Build out projection head.
        if self.projection_dim is not None and hasattr(self.model, "head"):
            self.model.head = MCi._update_image_classifier(
                image_classifier=self.model.head, projection_dim=self.projection_dim
            )

    def forward(self, x: Any, *args, **kwargs) -> Any:
        """Delegate to the wrapped model."""
        return self.model(x, *args, **kwargs)

    @staticmethod
    def _get_in_feature_dimension(image_classifier: nn.Module) -> int:
        """Return the input width of a ``nn.Linear`` head or of the first
        ``nn.Linear`` inside an ``nn.Sequential`` head.

        Raises:
            NotImplementedError: if no such linear layer can be found.
        """
        if isinstance(image_classifier, nn.Sequential):
            in_features = next(
                (layer.in_features for layer in image_classifier if isinstance(layer, nn.Linear)),
                None,
            )
        elif isinstance(image_classifier, nn.Linear):
            in_features = image_classifier.in_features
        else:
            in_features = None

        if in_features is None:
            raise NotImplementedError(
                f"Cannot get input feature dimension of {image_classifier}."
            )
        return in_features

    @staticmethod
    def _update_image_classifier(
        image_classifier: nn.Module, projection_dim: int, *args, **kwargs
    ) -> nn.Module:
        """Build a ``GlobalPool2D`` head matching ``image_classifier``'s input width."""
        return GlobalPool2D(
            in_dim=MCi._get_in_feature_dimension(image_classifier),
            out_dim=projection_dim,
        )
================================================
FILE: llava/model/multimodal_encoder/mobileclip/configs/mobileclip_l.json
================================================
{
"embed_dim": 768,
"image_cfg": {
"image_size": 1024,
"model_name": "fastvithd",
"embed_dim": 3072,
"patch_size": 64
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"dim": 768,
"ffn_multiplier_per_layer": 4.0,
"n_heads_per_layer": 12,
"n_transformer_layers": 12,
"norm_layer": "layer_norm_fp32",
"causal_masking": false,
"model_name": "base"
}
}
================================================
FILE: llava/model/multimodal_encoder/mobileclip/mci.py
================================================
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2025 Apple Inc. All Rights Reserved.
#
import copy
from functools import partial
from typing import List, Tuple, Optional, Union, Dict
import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F
from torch.nn.init import normal_
from timm.models import register_model
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, SqueezeExcite
def _cfg(url="", **kwargs):
    """Return a timm-style default model config; keyword arguments override or extend it."""
    config = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 256, 256),
        pool_size=None,
        crop_pct=0.95,
        interpolation="bicubic",
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        classifier="head",
    )
    config.update(kwargs)
    return config
# Default configs for the fastvit_* entries; each overrides only ``crop_pct``.
default_cfgs = {
    variant: _cfg(crop_pct=crop_pct)
    for variant, crop_pct in (
        ("fastvit_t", 0.9),
        ("fastvit_s", 0.9),
        ("fastvit_m", 0.95),
    )
}
class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel gate.

    Pytorch implementation of `Squeeze-and-Excitation Networks` -
    https://arxiv.org/pdf/1709.01507.pdf
    """

    def __init__(self, in_channels: int, rd_ratio: float = 0.0625) -> None:
        """Build the squeeze (``reduce``) and excite (``expand``) 1x1 convolutions.

        Args:
            in_channels: Number of input (and output) channels.
            rd_ratio: Bottleneck ratio; the hidden width is ``int(in_channels * rd_ratio)``.
        """
        super().__init__()
        hidden = int(in_channels * rd_ratio)
        # The attribute names ``reduce``/``expand`` are state-dict keys; keep them stable.
        self.reduce = nn.Conv2d(in_channels, hidden, kernel_size=1, stride=1, bias=True)
        self.expand = nn.Conv2d(hidden, in_channels, kernel_size=1, stride=1, bias=True)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Scale each channel of ``inputs`` (N, C, H, W) by a sigmoid gate."""
        _, channels, height, width = inputs.size()
        # Average over the full spatial extent -> (N, C, 1, 1).
        squeezed = F.avg_pool2d(inputs, kernel_size=[height, width])
        gate = torch.sigmoid(self.expand(F.relu(self.reduce(squeezed))))
        return inputs * gate.view(-1, channels, 1, 1)
class MobileOneBlock(nn.Module):
"""MobileOne building block.
This block has a multi-branched architecture at train-time
and plain-CNN style architecture at inference time
For more details, please refer to our paper:
`An Improved One millisecond Mobile Backbone` -
https://arxiv.org/pdf/2206.04040.pdf
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
padding: int = 0,
dilation: int = 1,
groups: int = 1,
inference_mode: bool = False,
use_se: bool = False,
use_act: bool = True,
use_scale_branch: bool = True,
num_conv_branches: int = 1,
activation: nn.Module = nn.GELU(),
) -> None:
"""Construct a MobileOneBlock module.
Args:
in_channels: Number of channels in the input.
out_channels: Number of channels produced by the block.
kernel_size: Size of the convolution kernel.
stride: Stride size.
padding: Zero-padding size.
dilation: Kernel dilation factor.
groups: Group number.
inference_mode: If True, instantiates model in inference mode.
use_se: Whether to use SE-ReLU activations.
use_act: Whether to use activation. Default: ``True``
use_scale_branch: Whether to use scale branch. Default: ``True``
num_conv_branches: Number of linear conv branches.
"""
super(MobileOneBlock, self).__init__()
self.inference_mode = inference_mode
self.groups = groups
self.stride = stride
self.padding = padding
self.dilation = dilation
self.kernel_size = kernel_size
self.in_channels = in_channels
self.out_channels = out_channels
self.num_conv_branches = num_conv_branches
# Check if SE-ReLU is requested
if use_se:
self.se = SEBlock(out_channels)
else:
self.se = nn.Identity()
if use_act:
self.activation = activation
else:
self.activation = nn.Identity()
if inference_mode:
self.reparam_conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=True,
)
else:
# Re-parameterizable skip connection
# Fallback, sometimes batchnorm tensors
# do not get instantiated correctly on some processes
# when using deepspeed + accelerate
norm_layer = nn.BatchNorm2d(num_features=in_channels)
if norm_layer.weight.shape[0] == 0:
norm_layer.weight = nn.Parameter(torch.zeros(in_channels))
if norm_layer.bias.shape[0] == 0:
norm_layer.bias = nn.Parameter(torch.zeros(in_channels))
self.rbr_skip = (
norm_layer
if out_channels == in_channels and stride == 1
else None
)
# Re-parameterizable conv branches
if num_conv_branches > 0:
rbr_conv = list()
for _ in range(self.num_conv_branches):
rbr_conv.append(
self._conv_bn(kernel_size=kernel_size, padding=padding)
)
self.rbr_conv = nn.ModuleList(rbr_conv)
else:
self.rbr_conv = None
# Re-parameterizable scale branch
self.rbr_scale = None
if not isinstance(kernel_size, int):
kernel_size = kernel_size[0]
if (kernel_size > 1) and use_scale_branch:
self.rbr_scale = self._conv_bn(kernel_size=1, padding=0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply forward pass."""
# Inference mode forward pass.
if self.inference_mode:
return self.activation(self.se(self.reparam_conv(x)))
# Multi-branched train-time forward pass.
# Skip branch output
identity_out = 0
if self.rbr_skip is not None:
identity_out = self.rbr_skip(x)
# Scale branch output
scale_out = 0
if self.rbr_scale is not None:
scale_out = self.rbr_scale(x)
# Other branches
out = scale_out + identity_out
if self.rbr_conv is not None:
for ix in range(self.num_conv_branches):
out += self.rbr_conv[ix](x)
return self.activation(self.se(out))
def reparameterize(self):
"""Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
architecture used at training time to obtain a plain CNN-like structure
for inference.
"""
if self.inference_mode:
return
kernel, bias = self._get_kernel_bias()
self.reparam_conv = nn.Conv2d(
in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=self.kernel_size,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
groups=self.groups,
bias=True,
)
self.reparam_conv.weight.data = kernel
self.reparam_conv.bias.data = bias
# Delete un-used branches
self.__delattr__("rbr_conv")
self.__delattr__("rbr_scale")
if hasattr(self, "rbr_skip"):
self.__delattr__("rbr_skip")
self.inference_mode = True
def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
"""Method to obtain re-parameterized kernel and bias.
Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83
Returns:
Tuple of (kernel, bias) after fusing branches.
"""
# get weights and bias of scale branch
kernel_scale = 0
bias_scale = 0
if self.rbr_scale is not None:
kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale)
# Pad scale branch kernel to match conv branch kernel size.
pad = self.kernel_size // 2
kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad])
# get weights and bias of skip branch
kernel_identity = 0
bias_identity = 0
if self.rbr_skip is not None:
kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip)
# get weights and bias of conv branches
kernel_conv = 0
bias_conv = 0
if self.rbr_conv is not None:
for ix in range(self.num_conv_branches):
_kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix])
kernel_conv += _kernel
bias_conv += _bias
kernel_final = kernel_conv + kernel_scale + kernel_identity
bias_final = bias_conv + bias_scale + bias_identity
return kernel_final, bias_final
def _fuse_bn_tensor(
self, branch: Union[nn.Sequential, nn.BatchNorm2d]
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Method to fuse batchnorm layer with preceeding conv layer.
Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95
Args:
branch: Sequence of ops to be fused.
Returns:
Tuple of (kernel, bias) after fusing batchnorm.
"""
if isinstance(branch, nn.Sequential):
kernel = branch.conv.weight
running_mean = branch.bn.running_mean
running_var = branch.bn.running_var
gamma = branch.bn.weight
beta = branch.bn.bias
eps = branch.bn.eps
else:
assert isinstance(branch, nn.BatchNorm2d)
if not hasattr(self, "id_tensor"):
input_dim = self.in_channels // self.groups
kernel_size = self.kernel_size
if isinstance(self.kernel_size, int):
kernel_size = (self.kernel_size, self.kernel_size)
kernel_value = torch.zeros(
(self.in_channels, input_dim, kernel_size[0], kernel_size[1]),
dtype=branch.weight.dtype,
device=branch.weight.device,
)
for i in range(self.in_channels):
kernel_value[
i, i % input_dim, kernel_size[0] // 2, kernel_size[1] // 2
] = 1
self.id_tensor = kernel_value
kernel = self.id_tensor
running_mean = branch.running_mean
running_var = branch.running_var
gamma = branch.weight
beta = branch.bias
eps = branch.eps
std = (running_var + eps).sqrt()
t = (gamma / std).reshape(-1, 1, 1, 1)
return kernel * t, beta - running_mean * gamma / std
def _conv_bn(self, kernel_size: int, padding: int) -> nn.Sequential:
"""Helper method to construct conv-batchnorm layers.
Args:
kernel_size: Size of the convolution kernel.
padding: Zero-padding size.
Returns:
Conv-BN module.
"""
# Fallback, sometimes batchnorm tensors
# do not get instantiated correctly on some processes
# when using deepspeed + accelerate
norm_layer = nn.BatchNorm2d(num_features=self.out_channels)
if norm_layer.weight.shape[0] == 0:
norm_layer.weight = nn.Parameter(torch.zeros(self.out_channels))
if norm_layer.bias.shape[0] == 0:
norm_layer.bias = nn.Parameter(torch.zeros(self.out_channels))
mod_list = nn.Sequential()
mod_list.add_module(
"conv",
nn.Conv2d(
in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=kernel_size,
stride=self.stride,
padding=padding,
groups=self.groups,
bias=False,
),
)
mod_list.add_module("bn", norm_layer)
return mod_list
class ReparamLargeKernelConv(nn.Module):
    """Building Block of RepLKNet

    This class defines overparameterized large kernel conv block
    introduced in `RepLKNet`_

    Reference: https://github.com/DingXiaoH/RepLKNet-pytorch

    At train time the output is ``lkb_origin(x)`` (large-kernel conv-BN)
    plus ``small_conv(x)`` (small-kernel conv-BN) when present. After
    :meth:`reparameterize` both are replaced by the single ``lkb_reparam``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        groups: int,
        small_kernel: Optional[int],
        inference_mode: bool = False,
        use_se: bool = False,
        activation: Optional[nn.Module] = None,
    ) -> None:
        """Construct a ReparamLargeKernelConv module.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            kernel_size: Kernel size of the large kernel conv branch.
            stride: Stride size.
            groups: Group number.
            small_kernel: Kernel size of small kernel conv branch, or ``None``
                to train with the large-kernel branch only.
            inference_mode: If True, instantiates model in inference mode. Default: ``False``
            use_se: If True, apply squeeze-excite before the activation. Default: ``False``
            activation: Activation module. Default: a new ``nn.GELU()`` per instance
                (avoids one module object being shared by every block).
        """
        super(ReparamLargeKernelConv, self).__init__()

        self.stride = stride
        self.groups = groups
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.activation = activation if activation is not None else nn.GELU()

        self.kernel_size = kernel_size
        self.small_kernel = small_kernel
        self.padding = kernel_size // 2

        # Check if SE is requested
        if use_se:
            self.se = SqueezeExcite(out_channels, rd_ratio=0.25)
        else:
            self.se = nn.Identity()

        if inference_mode:
            self.lkb_reparam = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=self.padding,
                dilation=1,
                groups=groups,
                bias=True,
            )
        else:
            self.lkb_origin = self._conv_bn(
                kernel_size=kernel_size, padding=self.padding
            )
            if small_kernel is not None:
                assert (
                    small_kernel <= kernel_size
                ), "The kernel size for re-param cannot be larger than the large kernel!"
                self.small_conv = self._conv_bn(
                    kernel_size=small_kernel, padding=small_kernel // 2
                )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply forward pass."""
        if hasattr(self, "lkb_reparam"):
            out = self.lkb_reparam(x)
        else:
            out = self.lkb_origin(x)
            if hasattr(self, "small_conv"):
                out += self.small_conv(x)
        return self.activation(self.se(out))

    def get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Method to obtain re-parameterized kernel and bias.

        Reference: https://github.com/DingXiaoH/RepLKNet-pytorch

        The small-kernel weights are zero-padded symmetrically up to the
        large kernel size before being added.

        Returns:
            Tuple of (kernel, bias) after fusing branches.
        """
        eq_k, eq_b = self._fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn)
        if hasattr(self, "small_conv"):
            small_k, small_b = self._fuse_bn(self.small_conv.conv, self.small_conv.bn)
            eq_b += small_b
            eq_k += nn.functional.pad(
                small_k, [(self.kernel_size - self.small_kernel) // 2] * 4
            )
        return eq_k, eq_b

    def reparameterize(self) -> None:
        """
        Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
        https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
        architecture used at training time to obtain a plain CNN-like structure
        for inference.

        A no-op when ``lkb_reparam`` already exists, whether from an earlier
        call or from ``inference_mode=True``.
        """
        if hasattr(self, "lkb_reparam"):
            # The training branches are already gone; fusing again would fail.
            return

        eq_k, eq_b = self.get_kernel_bias()
        self.lkb_reparam = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.lkb_origin.conv.dilation,
            groups=self.groups,
            bias=True,
        )

        self.lkb_reparam.weight.data = eq_k
        self.lkb_reparam.bias.data = eq_b
        self.__delattr__("lkb_origin")
        if hasattr(self, "small_conv"):
            self.__delattr__("small_conv")

    @staticmethod
    def _fuse_bn(
        conv: nn.Conv2d, bn: nn.BatchNorm2d
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Method to fuse batchnorm layer with conv layer.

        Args:
            conv: Convolution layer whose ``weight`` is fused.
            bn: Batchnorm 2d layer.

        Returns:
            Tuple of (kernel, bias) after fusing batchnorm.
        """
        kernel = conv.weight
        running_mean = bn.running_mean
        running_var = bn.running_var
        gamma = bn.weight
        beta = bn.bias
        eps = bn.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def _conv_bn(self, kernel_size: int, padding: int = 0) -> nn.Sequential:
        """Helper method to construct conv-batchnorm layers.

        Args:
            kernel_size: Size of the convolution kernel.
            padding: Zero-padding size.

        Returns:
            A nn.Sequential Conv-BN module.
        """
        # Fallback, sometimes batchnorm tensors
        # do not get instantiated correctly on some processes
        # when using deepspeed + accelerate
        norm_layer = nn.BatchNorm2d(num_features=self.out_channels)
        if norm_layer.weight.shape[0] == 0:
            norm_layer.weight = nn.Parameter(torch.zeros(self.out_channels))
        if norm_layer.bias.shape[0] == 0:
            norm_layer.bias = nn.Parameter(torch.zeros(self.out_channels))

        mod_list = nn.Sequential()
        mod_list.add_module(
            "conv",
            nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=kernel_size,
                stride=self.stride,
                padding=padding,
                groups=self.groups,
                bias=False,
            ),
        )
        mod_list.add_module("bn", norm_layer)
        return mod_list
def convolutional_stem(
    in_channels: int, out_channels: int, inference_mode: bool = False, use_scale_branch: bool = True,
) -> nn.Sequential:
    """Build convolutional stem with MobileOne blocks.

    Three single-branch MobileOne blocks without SE:
    a stride-2 3x3 conv, a stride-2 3x3 depthwise conv
    (``groups=out_channels``) and a stride-1 1x1 conv.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        inference_mode: Flag to instantiate model in inference mode. Default: ``False``
        use_scale_branch: Passed as ``use_scale_branch`` to every block. Default: ``True``

    Returns:
        nn.Sequential object with stem elements.
    """
    # (input channels, kernel, stride, padding, groups) for each stem block.
    block_specs = (
        (in_channels, 3, 2, 1, 1),
        (out_channels, 3, 2, 1, out_channels),
        (out_channels, 1, 1, 0, 1),
    )
    stem_blocks = [
        MobileOneBlock(
            in_channels=block_in,
            out_channels=out_channels,
            kernel_size=kernel,
            stride=step,
            padding=pad,
            groups=n_groups,
            inference_mode=inference_mode,
            use_se=False,
            num_conv_branches=1,
            use_scale_branch=use_scale_branch,
        )
        for block_in, kernel, step, pad, n_groups in block_specs
    ]
    return nn.Sequential(*stem_blocks)
class LayerNormChannel(nn.Module):
    """
    LayerNorm only for Channel Dimension.

    Input: tensor in shape [B, C, H, W]. Mean and (biased) variance are taken
    over dim 1 only, then a learnable per-channel scale and shift are applied.
    """

    def __init__(self, num_features, eps=1e-05) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        self.eps = eps

    def forward(self, x) -> torch.Tensor:
        mean = x.mean(1, keepdim=True)
        centered = x - mean
        variance = centered.pow(2).mean(1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.eps)
        # Broadcast the (C,) affine parameters over H and W.
        return self.weight[:, None, None] * normalized + self.bias[:, None, None]
class MHSA(nn.Module):
    """Multi-headed Self Attention module.

    Source modified from:
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    """

    def __init__(
        self,
        dim: int,
        head_dim: int = 32,
        qkv_bias: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        """Build MHSA module that can handle 3D or 4D input tensors.

        Args:
            dim: Number of embedding dimensions.
            head_dim: Number of hidden dimensions per head. Default: ``32``
            qkv_bias: Use bias or not. Default: ``False``
            attn_drop: Dropout rate for attention tensor.
            proj_drop: Dropout rate for projection tensor.
        """
        super().__init__()
        assert dim % head_dim == 0, "dim should be divisible by head_dim"
        self.head_dim = head_dim
        self.num_heads = dim // head_dim
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention to a (B, C, H, W) or a (B, N, C) tensor.

        4-D inputs are flattened to N = H * W tokens and reshaped back to
        (B, C, H, W) on output. 3-D inputs are treated as token sequences and
        returned as (B, N, C).

        Raises:
            ValueError: If ``x`` is neither 3-D nor 4-D.
        """
        shape = x.shape
        # Unpack per rank; unpacking four dims unconditionally would reject
        # the 3-D token input this module advertises.
        if len(shape) == 4:
            B, C, H, W = shape
            N = H * W
            x = torch.flatten(x, start_dim=2).transpose(-2, -1)  # (B, N, C)
        elif len(shape) == 3:
            B, N, C = shape
        else:
            raise ValueError(
                "MHSA expects a 3-D (B, N, C) or 4-D (B, C, H, W) tensor, "
                f"got shape {tuple(shape)}"
            )

        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        # trick here to make q@k.t more stable
        attn = (q * self.scale) @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        if len(shape) == 4:
            x = x.transpose(-2, -1).reshape(B, C, H, W)

        return x
class PatchEmbed(nn.Module):
    """Convolutional patch embedding layer.

    A strided ``ReparamLargeKernelConv`` (grouped with ``groups=in_channels``
    and a 3x3 small-kernel branch) followed by a 1x1 ``MobileOneBlock``.
    """

    def __init__(
        self,
        patch_size: int,
        stride: int,
        in_channels: int,
        embed_dim: int,
        inference_mode: bool = False,
        use_se: bool = False,
    ) -> None:
        """Build patch embedding layer.

        Args:
            patch_size: Patch size for embedding computation.
            stride: Stride for convolutional embedding layer.
            in_channels: Number of channels of input tensor.
            embed_dim: Number of embedding dimensions.
            inference_mode: Flag to instantiate model in inference mode. Default: ``False``
            use_se: If ``True`` SE block will be used (on the large-kernel conv only).
        """
        super().__init__()
        downsample = ReparamLargeKernelConv(
            in_channels=in_channels,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=stride,
            groups=in_channels,
            small_kernel=3,
            inference_mode=inference_mode,
            use_se=use_se,
        )
        pointwise = MobileOneBlock(
            in_channels=embed_dim,
            out_channels=embed_dim,
            kernel_size=1,
            stride=1,
            padding=0,
            groups=1,
            inference_mode=inference_mode,
            use_se=False,
            num_conv_branches=1,
        )
        self.proj = nn.Sequential(downsample, pointwise)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)
class RepMixer(nn.Module):
    """Reparameterizable token mixer.

    For more details, please refer to our paper:
    `FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization`_

    At train time the output is ``x + layer_scale * (mixer(x) - norm(x))``
    (or ``x + mixer(x) - norm(x)`` without layer scale). :meth:`reparameterize`
    folds this into a single depthwise convolution ``reparam_conv``.
    """

    def __init__(
        self,
        dim,
        kernel_size=3,
        use_layer_scale=True,
        layer_scale_init_value=1e-5,
        inference_mode: bool = False,
    ):
        """Build RepMixer Module.

        Args:
            dim: Input feature map dimension. :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, H, W)`.
            kernel_size: Kernel size for spatial mixing. Default: 3
            use_layer_scale: If True, learnable layer scale is used. Default: ``True``
            layer_scale_init_value: Initial value for layer scale. Default: 1e-5
            inference_mode: If True, instantiates model in inference mode. Default: ``False``
        """
        super().__init__()
        self.dim = dim
        self.kernel_size = kernel_size
        self.inference_mode = inference_mode

        if inference_mode:
            self.reparam_conv = nn.Conv2d(
                in_channels=self.dim,
                out_channels=self.dim,
                kernel_size=self.kernel_size,
                stride=1,
                padding=self.kernel_size // 2,
                groups=self.dim,
                bias=True,
            )
        else:
            self.norm = MobileOneBlock(
                dim,
                dim,
                kernel_size,
                padding=kernel_size // 2,
                groups=dim,
                use_act=False,
                use_scale_branch=False,
                num_conv_branches=0,
            )
            self.mixer = MobileOneBlock(
                dim,
                dim,
                kernel_size,
                padding=kernel_size // 2,
                groups=dim,
                use_act=False,
            )
            self.use_layer_scale = use_layer_scale
            if use_layer_scale:
                self.layer_scale = nn.Parameter(
                    layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
                )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if hasattr(self, "reparam_conv"):
            x = self.reparam_conv(x)
            return x
        else:
            if self.use_layer_scale:
                x = x + self.layer_scale * (self.mixer(x) - self.norm(x))
            else:
                x = x + self.mixer(x) - self.norm(x)
            return x

    def reparameterize(self) -> None:
        """Reparameterize mixer and norm into a single
        convolutional layer for efficient inference.

        The identity path is folded in via ``self.mixer.id_tensor``. Once
        fused, the module is flagged as ``inference_mode`` so that repeated
        calls are no-ops instead of failing on the deleted branches.
        """
        if self.inference_mode:
            return

        self.mixer.reparameterize()
        self.norm.reparameterize()

        if self.use_layer_scale:
            w = self.mixer.id_tensor + self.layer_scale.unsqueeze(-1) * (
                self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight
            )
            b = torch.squeeze(self.layer_scale) * (
                self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
            )
        else:
            w = (
                self.mixer.id_tensor
                + self.mixer.reparam_conv.weight
                - self.norm.reparam_conv.weight
            )
            b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias

        self.reparam_conv = nn.Conv2d(
            in_channels=self.dim,
            out_channels=self.dim,
            kernel_size=self.kernel_size,
            stride=1,
            padding=self.kernel_size // 2,
            groups=self.dim,
            bias=True,
        )
        self.reparam_conv.weight.data = w
        self.reparam_conv.bias.data = b

        self.__delattr__("mixer")
        self.__delattr__("norm")
        if self.use_layer_scale:
            self.__delattr__("layer_scale")
        self.inference_mode = True
class ConvFFN(nn.Module):
    """Convolutional FFN Module.

    A 7x7 depthwise conv + BatchNorm over ``in_channels`` followed by a
    pointwise MLP: ``fc1`` -> activation -> dropout -> ``fc2`` -> dropout.
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        act_layer: nn.Module = nn.GELU,
        drop: float = 0.0,
    ) -> None:
        """Build convolutional FFN module.

        Args:
            in_channels: Number of input channels.
            hidden_channels: Number of channels after expansion. Default: None
            out_channels: Number of output channels. Default: None
            act_layer: Activation layer. Default: ``GELU``
            drop: Dropout rate. Default: ``0.0``.
        """
        super().__init__()
        out_channels = out_channels or in_channels
        hidden_channels = hidden_channels or in_channels
        self.conv = nn.Sequential()
        # The depthwise mixing keeps ``in_channels``: ``fc1`` below consumes its
        # output and is built for ``in_channels`` inputs. Only ``fc2`` changes
        # the width to ``out_channels``.
        self.conv.add_module(
            "conv",
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=in_channels,
                kernel_size=7,
                padding=3,
                groups=in_channels,
                bias=False,
            ),
        )

        # Fallback, sometimes batchnorm tensors
        # do not get instantiated correctly on some processes
        # when using deepspeed + accelerate
        norm_layer = nn.BatchNorm2d(num_features=in_channels)
        if norm_layer.weight.shape[0] == 0:
            norm_layer.weight = nn.Parameter(torch.zeros(in_channels))
        if norm_layer.bias.shape[0] == 0:
            norm_layer.bias = nn.Parameter(torch.zeros(in_channels))

        self.conv.add_module("bn", norm_layer)
        self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1)
        self.drop = nn.Dropout(drop)
        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module) -> None:
        """Normal(std=0.02) init for conv weights; zero conv biases."""
        if isinstance(m, nn.Conv2d):
            normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
class RepCPE(nn.Module):
    """Implementation of conditional positional encoding.

    For more details refer to paper:
    `Conditional Positional Encodings for Vision Transformers`_

    In our implementation, we can reparameterize this module to eliminate a skip connection.
    At train time the output is ``pe(x) + x`` (a depthwise conv plus identity).
    :meth:`reparameterize` folds the identity into the conv weights.
    """

    def __init__(
        self,
        in_channels: int,
        embed_dim: int = 768,
        spatial_shape: Union[int, Tuple[int, int]] = (7, 7),
        inference_mode=False,
    ) -> None:
        """Build reparameterizable conditional positional encoding

        Args:
            in_channels: Number of input channels.
            embed_dim: Number of embedding dimensions. Default: 768
            spatial_shape: Spatial shape of kernel for positional encoding, as an
                int or a length-2 tuple/list. Default: (7, 7)
            inference_mode: Flag to instantiate block in inference mode. Default: ``False``
        """
        super(RepCPE, self).__init__()
        if isinstance(spatial_shape, int):
            spatial_shape = tuple([spatial_shape] * 2)
        elif isinstance(spatial_shape, list):
            spatial_shape = tuple(spatial_shape)
        assert isinstance(spatial_shape, tuple), (
            f'"spatial_shape" must be a sequence or int, '
            f"got {type(spatial_shape)} instead."
        )
        assert len(spatial_shape) == 2, (
            f'Length of "spatial_shape" should be 2, '
            f"got {len(spatial_shape)} instead."
        )

        self.spatial_shape = spatial_shape
        self.embed_dim = embed_dim
        self.in_channels = in_channels
        self.groups = embed_dim
        # Pad each dimension separately so that, for odd kernel sizes, the
        # output keeps the input's H and W and the residual add lines up even
        # when the kernel is not square.
        self.padding = (spatial_shape[0] // 2, spatial_shape[1] // 2)

        if inference_mode:
            self.reparam_conv = nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.embed_dim,
                kernel_size=self.spatial_shape,
                stride=1,
                padding=self.padding,
                groups=self.embed_dim,
                bias=True,
            )
        else:
            self.pe = nn.Conv2d(
                in_channels,
                embed_dim,
                spatial_shape,
                1,
                self.padding,
                bias=True,
                groups=embed_dim,
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if hasattr(self, "reparam_conv"):
            x = self.reparam_conv(x)
            return x
        else:
            x = self.pe(x) + x
            return x

    def reparameterize(self) -> None:
        """Fold the identity skip into the positional-encoding conv.

        A no-op when ``reparam_conv`` already exists (earlier call or
        ``inference_mode=True``), since ``pe`` is no longer available then.
        """
        if hasattr(self, "reparam_conv"):
            return

        # Build equivalent Id tensor
        input_dim = self.in_channels // self.groups
        kernel_value = torch.zeros(
            (
                self.in_channels,
                input_dim,
                self.spatial_shape[0],
                self.spatial_shape[1],
            ),
            dtype=self.pe.weight.dtype,
            device=self.pe.weight.device,
        )
        for i in range(self.in_channels):
            kernel_value[
                i,
                i % input_dim,
                self.spatial_shape[0] // 2,
                self.spatial_shape[1] // 2,
            ] = 1
        id_tensor = kernel_value

        # Reparameterize Id tensor and conv
        w_final = id_tensor + self.pe.weight
        b_final = self.pe.bias

        # Introduce reparam conv
        self.reparam_conv = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.embed_dim,
            kernel_size=self.spatial_shape,
            stride=1,
            padding=self.padding,
            groups=self.embed_dim,
            bias=True,
        )
        self.reparam_conv.weight.data = w_final
        self.reparam_conv.bias.data = b_final

        self.__delattr__("pe")
class RepMixerBlock(nn.Module):
    """Implementation of Metaformer block with RepMixer as token mixer.

    For more details on Metaformer structure, please refer to:
    `MetaFormer Is Actually What You Need for Vision`_

    Computes ``y = token_mixer(x)`` and then
    ``y + drop_path(layer_scale * convffn(y))`` (no ``layer_scale`` when
    disabled). The residual around the token mixer lives inside ``RepMixer``.
    """

    def __init__(
        self,
        dim: int,
        kernel_size: int = 3,
        mlp_ratio: float = 4.0,
        act_layer: nn.Module = nn.GELU,
        drop: float = 0.0,
        drop_path: float = 0.0,
        use_layer_scale: bool = True,
        layer_scale_init_value: float = 1e-5,
        inference_mode: bool = False,
    ):
        """Build RepMixer Block.

        Args:
            dim: Number of embedding dimensions.
            kernel_size: Kernel size for repmixer. Default: 3
            mlp_ratio: MLP expansion ratio. Default: 4.0
            act_layer: Activation layer. Default: ``nn.GELU``
            drop: Dropout rate. Default: 0.0
            drop_path: Drop path rate. Default: 0.0
            use_layer_scale: Flag to turn on layer scale. Default: ``True``
            layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
            inference_mode: Flag to instantiate block in inference mode. Default: ``False``
        """
        super().__init__()

        self.token_mixer = RepMixer(
            dim,
            kernel_size=kernel_size,
            use_layer_scale=use_layer_scale,
            layer_scale_init_value=layer_scale_init_value,
            inference_mode=inference_mode,
        )

        assert mlp_ratio > 0, f"MLP ratio should be greater than 0, found: {mlp_ratio}"
        self.convffn = ConvFFN(
            in_channels=dim,
            hidden_channels=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

        # Drop Path
        self.drop_path = nn.Identity() if drop_path <= 0.0 else DropPath(drop_path)

        # Layer Scale
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale = nn.Parameter(
                layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
            )

    def forward(self, x):
        mixed = self.token_mixer(x)
        ffn_out = self.convffn(mixed)
        if self.use_layer_scale:
            ffn_out = self.layer_scale * ffn_out
        return mixed + self.drop_path(ffn_out)
class AttentionBlock(nn.Module):
    """Implementation of metaformer block with MHSA as token mixer.

    For more details on Metaformer structure, please refer to:
    `MetaFormer Is Actually What You Need for Vision`_

    Two pre-norm residual branches: ``x + drop_path(ls1 * MHSA(norm(x)))``
    followed by ``x + drop_path(ls2 * convffn(x))`` (layer scales omitted
    when disabled).
    """

    def __init__(
        self,
        dim: int,
        mlp_ratio: float = 4.0,
        act_layer: nn.Module = nn.GELU,
        norm_layer: nn.Module = nn.BatchNorm2d,
        drop: float = 0.0,
        drop_path: float = 0.0,
        use_layer_scale: bool = True,
        layer_scale_init_value: float = 1e-5,
    ):
        """Build Attention Block.

        Args:
            dim: Number of embedding dimensions.
            mlp_ratio: MLP expansion ratio. Default: 4.0
            act_layer: Activation layer. Default: ``nn.GELU``
            norm_layer: Normalization layer. Default: ``nn.BatchNorm2d``
            drop: Dropout rate. Default: 0.0
            drop_path: Drop path rate. Default: 0.0
            use_layer_scale: Flag to turn on layer scale. Default: ``True``
            layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
        """
        super().__init__()

        # Fallback, sometimes batchnorm tensors
        # do not get instantiated correctly on some processes
        # when using deepspeed + accelerate
        norm = norm_layer(num_features=dim)
        if norm.weight.shape[0] == 0:
            norm.weight = nn.Parameter(torch.zeros(dim))
        if norm.bias.shape[0] == 0:
            norm.bias = nn.Parameter(torch.zeros(dim))
        self.norm = norm

        self.token_mixer = MHSA(dim=dim)

        assert mlp_ratio > 0, f"MLP ratio should be greater than 0, found: {mlp_ratio}"
        self.convffn = ConvFFN(
            in_channels=dim,
            hidden_channels=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

        # Drop path
        self.drop_path = nn.Identity() if drop_path <= 0.0 else DropPath(drop_path)

        # Layer Scale
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale_1 = nn.Parameter(
                layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
            )
            self.layer_scale_2 = nn.Parameter(
                layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
            )

    def forward(self, x):
        attn_out = self.token_mixer(self.norm(x))
        if self.use_layer_scale:
            attn_out = self.layer_scale_1 * attn_out
        x = x + self.drop_path(attn_out)

        ffn_out = self.convffn(x)
        if self.use_layer_scale:
            ffn_out = self.layer_scale_2 * ffn_out
        return x + self.drop_path(ffn_out)
def basic_blocks(
    dim: int,
    block_index: int,
    num_blocks: List[int],
    token_mixer_type: str,
    kernel_size: int = 3,
    mlp_ratio: float = 4.0,
    act_layer: nn.Module = nn.GELU,
    norm_layer: nn.Module = nn.BatchNorm2d,
    drop_rate: float = 0.0,
    drop_path_rate: float = 0.0,
    use_layer_scale: bool = True,
    layer_scale_init_value: float = 1e-5,
    inference_mode=False,
) -> nn.Sequential:
    """Build FastViT blocks within a stage.

    The drop-path rate grows linearly with the block's global position: the
    first block of the network gets 0 and the last gets ``drop_path_rate``.

    Args:
        dim: Number of embedding dimensions.
        block_index: block index.
        num_blocks: List containing number of blocks per stage.
        token_mixer_type: Token mixer type.
        kernel_size: Kernel size for repmixer.
        mlp_ratio: MLP expansion ratio.
        act_layer: Activation layer.
        norm_layer: Normalization layer.
        drop_rate: Dropout rate.
        drop_path_rate: Drop path rate.
        use_layer_scale: Flag to turn on layer scale regularization.
        layer_scale_init_value: Layer scale value at initialization.
        inference_mode: Flag to instantiate block in inference mode.

    Returns:
        nn.Sequential object of all the blocks within the stage.

    Raises:
        ValueError: If ``token_mixer_type`` is not "repmixer" or "attention"
            and the stage has at least one block.
    """
    blocks = []
    blocks_before_stage = sum(num_blocks[:block_index])
    # Guard the linear schedule's denominator: with a single block overall
    # (sum == 1) the original 0/0 would raise ZeroDivisionError.
    schedule_span = max(sum(num_blocks) - 1, 1)
    for block_idx in range(num_blocks[block_index]):
        block_dpr = (
            drop_path_rate
            * (block_idx + blocks_before_stage)
            / schedule_span
        )
        if token_mixer_type == "repmixer":
            blocks.append(
                RepMixerBlock(
                    dim,
                    kernel_size=kernel_size,
                    mlp_ratio=mlp_ratio,
                    act_layer=act_layer,
                    drop=drop_rate,
                    drop_path=block_dpr,
                    use_layer_scale=use_layer_scale,
                    layer_scale_init_value=layer_scale_init_value,
                    inference_mode=inference_mode,
                )
            )
        elif token_mixer_type == "attention":
            blocks.append(
                AttentionBlock(
                    dim,
                    mlp_ratio=mlp_ratio,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    drop=drop_rate,
                    drop_path=block_dpr,
                    use_layer_scale=use_layer_scale,
                    layer_scale_init_value=layer_scale_init_value,
                )
            )
        else:
            raise ValueError(
                "Token mixer type: {} not supported".format(token_mixer_type)
            )
    blocks = nn.Sequential(*blocks)

    return blocks
class GlobalPool2D(nn.Module):
    """This class implements global pooling with linear projection.

    The input is mean-pooled over its spatial dims and multiplied by a learned
    ``(in_dim, out_dim)`` matrix that is initialised with scale ``in_dim**-0.5``.
    """

    def __init__(self, in_dim: int, out_dim: int, *args, **kwargs) -> None:
        super().__init__()
        scale = in_dim**-0.5
        self.proj = nn.Parameter(scale * torch.randn(size=(in_dim, out_dim)))
        self.in_dim = in_dim
        self.out_dim = out_dim

    def pool(self, x) -> Tensor:
        """Average over the trailing 2 dims of a 4-D tensor or 3 dims of a 5-D tensor.

        Raises:
            ValueError: For any other rank (previously an UnboundLocalError).
        """
        if x.dim() == 4:
            dims = [-2, -1]
        elif x.dim() == 5:
            dims = [-3, -2, -1]
        else:
            raise ValueError(
                f"GlobalPool2D.pool expects a 4-D or 5-D tensor, got shape {tuple(x.shape)}"
            )
        x = torch.mean(x, dim=dims, keepdim=False)
        return x

    def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
        # x is of shape [batch, in_dim, in_height, in_width]
        assert (
            x.dim() == 4
        ), "Input should be 4-dimensional (Batch x in_dim x in_height x in_width). Got: {}".format(
            x.shape
        )

        # [batch, in_dim, in_height, in_width] --> [batch, in_dim]
        x = self.pool(x)
        # [batch, in_dim] x [in_dim, out_dim] --> [batch, out_dim]
        x = x @ self.proj
        return x
class FastViT(nn.Module):
    """
    This class implements `FastViT architecture`_

    Layout: a convolutional stem, then for each stage an optional positional
    encoding module, the stage's blocks and (between stages) a ``PatchEmbed``
    downsampler, followed by a ``MobileOneBlock`` channel expansion
    (``conv_exp``) and the classifier ``head``.
    """

    def __init__(
        self,
        layers,
        token_mixers: Tuple[str, ...],
        embed_dims=None,
        mlp_ratios=None,
        downsamples=None,
        se_downsamples=None,
        repmixer_kernel_size=3,
        norm_layer: nn.Module = nn.BatchNorm2d,
        act_layer: nn.Module = nn.GELU,
        num_classes=1000,
        pos_embs=None,
        down_patch_size=7,
        down_stride=2,
        drop_rate=0.0,
        drop_path_rate=0.0,
        use_layer_scale=True,
        layer_scale_init_value=1e-5,
        init_cfg=None,
        pretrained=None,
        cls_ratio=2.0,
        inference_mode=False,
        stem_scale_branch=True,
        **kwargs,
    ) -> None:
        super().__init__()

        self.num_classes = num_classes
        num_stages = len(layers)

        # Per-depth output indices (the error message refers to FPN use);
        # they are not read by the methods of this class.
        indices_by_depth = {4: [0, 2, 4, 7], 5: [0, 2, 4, 7, 10]}
        if num_stages not in indices_by_depth:
            raise NotImplementedError("FPN is not implemented for more than 5 stages.")
        self.out_indices = indices_by_depth[num_stages]

        if pos_embs is None:
            pos_embs = [None] * num_stages
        if se_downsamples is None:
            se_downsamples = [False] * num_stages

        # Convolutional stem
        self.patch_embed = convolutional_stem(
            3, embed_dims[0], inference_mode, use_scale_branch=stem_scale_branch
        )

        # Modules in execution order: [pos-emb], stage blocks, [downsample], ...
        stage_modules = []
        last_stage = num_stages - 1
        for idx in range(num_stages):
            width = embed_dims[idx]
            pos_emb = pos_embs[idx]
            if pos_emb is not None:
                stage_modules.append(
                    pos_emb(width, width, inference_mode=inference_mode)
                )
            stage_modules.append(
                basic_blocks(
                    width,
                    idx,
                    layers,
                    token_mixer_type=token_mixers[idx],
                    kernel_size=repmixer_kernel_size,
                    mlp_ratio=mlp_ratios[idx],
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    drop_rate=drop_rate,
                    drop_path_rate=drop_path_rate,
                    use_layer_scale=use_layer_scale,
                    layer_scale_init_value=layer_scale_init_value,
                    inference_mode=inference_mode,
                )
            )
            if idx == last_stage:
                break

            next_width = embed_dims[idx + 1]
            # Patch merging/downsampling between stages.
            if downsamples[idx] or width != next_width:
                stage_modules.append(
                    PatchEmbed(
                        patch_size=down_patch_size,
                        stride=down_stride,
                        in_channels=width,
                        embed_dim=next_width,
                        inference_mode=inference_mode,
                        use_se=se_downsamples[idx + 1],
                    )
                )
        self.network = nn.ModuleList(stage_modules)

        # Classifier head
        expanded_dim = int(embed_dims[-1] * cls_ratio)
        self.conv_exp = MobileOneBlock(
            in_channels=embed_dims[-1],
            out_channels=expanded_dim,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=embed_dims[-1],
            inference_mode=inference_mode,
            use_se=True,
            num_conv_branches=1,
        )
        self.head = nn.Linear(expanded_dim, num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self.cls_init_weights)
        self.init_cfg = copy.deepcopy(init_cfg)

    def cls_init_weights(self, m: nn.Module) -> None:
        """Init. for classification: N(0, 0.02) weights and zero bias on Linear layers."""
        if not isinstance(m, nn.Linear):
            return
        normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward_embeddings(self, x: torch.Tensor) -> torch.Tensor:
        return self.patch_embed(x)

    def forward_tokens(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        for module in self.network:
            x = module(x)
        return x

    def forward(self, x: torch.Tensor, *args, **kwargs) -> Union[Tensor, Dict[str, Tensor]]:
        """Run stem, stages, ``conv_exp`` and ``head``.

        Returns the head output, or, when ``return_image_embeddings=True`` is
        passed, a dict with ``"logits"`` (head output) and
        ``"image_embeddings"`` (the ``conv_exp`` output fed to the head).
        """
        features = self.forward_tokens(self.forward_embeddings(x))
        embeddings = self.conv_exp(features)
        logits = self.head(embeddings)
        if not kwargs.get("return_image_embeddings", False):
            return logits
        return {"logits": logits, "image_embeddings": embeddings}
@register_model
def fastvithd(pretrained=False, **kwargs):
    """Instantiate FastViTHD model variant.

    Builds a 5-stage FastViT directly in inference mode. Stages 1-3 use
    RepMixer token mixing; stages 4-5 use attention, each preceded by a 7x7
    ``RepCPE`` positional encoding. The stem is built without scale branches
    and ``LayerNormChannel`` is the attention norm.

    Args:
        pretrained: Must be falsy; loading pretrained weights is not implemented.
        **kwargs: Forwarded to ``FastViT``. Passing a keyword that is already set
            here (e.g. ``inference_mode``) raises ``TypeError``.

    Raises:
        ValueError: If ``pretrained`` is truthy.
    """
    # Fail fast: there is nothing to load, so don't build the (large) model
    # only to discard it.
    if pretrained:
        raise ValueError("Functionality not implemented.")

    layers = [2, 12, 24, 4, 2]
    embed_dims = [96, 192, 384, 768, 1536]
    mlp_ratios = [4, 4, 4, 4, 4]
    downsamples = [True, True, True, True, True]
    pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7)), partial(RepCPE, spatial_shape=(7, 7))]
    token_mixers = ("repmixer", "repmixer", "repmixer", "attention", "attention")
    model = FastViT(
        layers,
        token_mixers=token_mixers,
        embed_dims=embed_dims,
        pos_embs=pos_embs,
        mlp_ratios=mlp_ratios,
        downsamples=downsamples,
        norm_layer=LayerNormChannel,
        stem_scale_branch=False,
        inference_mode=True,
        **kwargs,
    )
    model.default_cfg = default_cfgs["fastvit_m"]
    return model
================================================
FILE: llava/model/multimodal_encoder/mobileclip_encoder.py
================================================
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2025 Apple Inc. All Rights Reserved.
#
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPImageProcessor
import llava.model.multimodal_encoder.mobileclip as mobileclip
class MobileCLIPVisionTower(nn.Module):
def __init__(self, vision_tower, args, delay_load=False):
    """Wrap a MobileCLIP image encoder as a LLaVA vision tower.

    Args:
        vision_tower: Model name. Its last ``_``-separated token is parsed as
            an integer input image size.
        args: Config object. ``unfreeze_mm_vision_tower`` (default False)
            makes the encoder trainable and forces an immediate load.
        delay_load: If True and the tower is frozen, only the model config
            is loaded (into ``cfg_only``); otherwise the model is built now.
    """
    super().__init__()

    self.is_loaded = False
    self.vision_tower_name = vision_tower
    self.tune_vision_tower = getattr(args, 'unfreeze_mm_vision_tower', False)
    self.input_image_size = int(vision_tower.split("_")[-1])

    # Delay load is disabled for now
    if not delay_load or self.tune_vision_tower:
        self.load_model()
    else:
        model_cfg = mobileclip.load_model_config(self.vision_tower_name)
        # Apply the same resolution override as load_model(), so cfg_only
        # reports the size the tower will actually run at.
        model_cfg["image_cfg"]["image_size"] = self.input_image_size
        self.cfg_only = model_cfg
def load_model(self, device_map=None):
if self.is_loaded:
print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
return
# Load model config
model_cfg = mobileclip.load_model_config(self.vision_tower_name)
# Override default image resolution
model_cfg["image_cfg"]["image_size"] = self.input_image_size
self.cfg_only = model_cfg
# Build HF CLIPImageProcessor with MobileCLIP parameters
self.image_processor = CLIPImageProcessor(crop_size={"height": model_cfg["image_cfg"]["image_size"],
"width": model_cfg["image_cfg"]["image_size"]},
image_mean=[0.0, 0.0, 0.0],
image_std=[1.0, 1.0, 1.0],
size={"shortest_edge": model_cfg["image_cfg"]["image_size"]})
# Instantiate the image encoder
self.vision_tower = mobileclip.MCi(model_name=model_cfg["image_cfg"]["model_name"],
projection_dim=model_cfg["embed_dim"])
if not self.tune_vision_tower:
self.vision_tower.requires_grad_(False)
self.is_loaded = True
def feature_select(self, image_forward_outs):
# Features from penultimate layer
image_features = image_forward_outs["image_embeddings"]
# Reshape 4D tensor to 3D
B, C, H, W = image_features.shape
image_features = image_features.reshape(B, C, H*W)
image_features = image_features.transpose(1, 2)
return image_features
def forward(self, images):
if self.tune_vision_tower:
return self.forward_images(images)
else:
with torch.no_grad():
return self.forward_images(images)
def forward_images(self, images):
if type(images) is list:
image_features = []
for image in images:
image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), return_image_embeddings=True)
image_feature = self.feature_select(image_forward_out).to(image.dtype)
image_features.append(image_feature)
else:
image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), return_image_embeddings=True)
image_features = self.feature_select(image_forward_outs).to(images.dtype)
return image_features
@property
def dummy_feature(self):
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
@property
def dtype(self):
return next(self.vision_tower.parameters()).dtype
@property
def device(self):
return next(self.vision_tower.parameters()).device
@property
def config(self):
return self.cfg_only
@property
def hidden_size(self):
return self.config["image_cfg"]["embed_dim"]
@property
def num_patches_per_side(self):
return self.config["image_cfg"]["image_size"] // self.config["image_cfg"]["patch_size"]
@property
def num_patches(self):
return (self.config["image_cfg"]["image_size"] // self.config["image_cfg"]["patch_size"]) ** 2
================================================
FILE: llava/model/multimodal_projector/builder.py
================================================
import torch.nn as nn
import re
class IdentityMap(nn.Module):
    """Pass-through projector: returns its first input unchanged."""

    def forward(self, x, *args, **kwargs):
        # Extra positional/keyword arguments are accepted and ignored.
        return x

    @property
    def config(self):
        return dict(mm_projector_type='identity')
def build_vision_projector(config, delay_load=False, **kwargs):
    """Create the module that maps vision features to the LLM hidden size.

    ``config.mm_projector_type`` (default ``'linear'``) selects:
      * ``'linear'``      -> ``nn.Linear(mm_hidden_size, hidden_size)``
      * ``'mlp<N>x_gelu'`` -> N Linear layers separated by GELUs
      * ``'identity'``    -> ``IdentityMap``

    ``delay_load`` and extra keyword arguments are accepted but unused.

    Raises:
        ValueError: for any other projector type.
    """
    projector_type = getattr(config, 'mm_projector_type', 'linear')

    if projector_type == 'linear':
        return nn.Linear(config.mm_hidden_size, config.hidden_size)

    mlp_gelu = re.match(r'^mlp(\d+)x_gelu$', projector_type)
    if mlp_gelu is not None:
        depth = int(mlp_gelu.group(1))
        layers = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
        # Each additional level contributes a GELU followed by a hidden->hidden Linear.
        layers.extend(
            layer
            for _ in range(depth - 1)
            for layer in (nn.GELU(), nn.Linear(config.hidden_size, config.hidden_size))
        )
        return nn.Sequential(*layers)

    if projector_type == 'identity':
        return IdentityMap()

    raise ValueError(f'Unknown projector type: {projector_type}')
================================================
FILE: llava/model/utils.py
================================================
from transformers import AutoConfig
def auto_upgrade(config):
    """Interactively upgrade a LLaVA v0 checkpoint config to the current code base.

    Args:
        config: checkpoint path / id. It is passed to ``AutoConfig.from_pretrained``
            and, on upgrade, used as the ``save_pretrained`` target.

    If ``config`` contains 'llava' but the loaded ``model_type`` does not, the
    user is asked to confirm. On "y"/"yes", ``model_type`` and
    ``architectures`` are rewritten and the config is saved back. Otherwise,
    the process exits with status 1.

    Note: ``model_type`` is set on ``cfg.__class__``, i.e. on the config class
    itself, not just on this instance.
    """
    # sys.exit is the supported way to exit from program code; the bare exit()
    # builtin is injected by the `site` module for interactive use only.
    import sys

    cfg = AutoConfig.from_pretrained(config)
    if 'llava' in config and 'llava' not in cfg.model_type:
        assert cfg.model_type == 'llama'
        print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
        print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
        confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
        if confirm.lower() in ["y", "yes"]:
            print("Upgrading checkpoint...")
            assert len(cfg.architectures) == 1
            setattr(cfg.__class__, "model_type", "llava")
            cfg.architectures[0] = 'LlavaLlamaForCausalLM'
            cfg.save_pretrained(config)
            print("Checkpoint upgraded.")
        else:
            print("Checkpoint upgrade aborted.")
            sys.exit(1)
================================================
FILE: llava/serve/__init__.py
================================================
================================================
FILE: llava/serve/cli.py
================================================
import argparse
import torch
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
from PIL import Image
import requests
from PIL import Image
from io import BytesIO
from transformers import TextStreamer
def load_image(image_file, timeout=None):
    """Load an image from a local path or an http(s) URL, converted to RGB.

    Args:
        image_file: local file path, or a URL starting with ``http://`` / ``https://``.
        timeout: optional ``requests`` timeout in seconds for URL downloads
            (``None`` waits indefinitely, the ``requests`` default).

    Raises:
        requests.HTTPError: if the download returns an HTTP error status.
    """
    if image_file.startswith(('http://', 'https://')):
        response = requests.get(image_file, timeout=timeout)
        # Surface HTTP failures directly instead of a misleading
        # "cannot identify image file" error when decoding an error page.
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image
def _infer_conv_mode(model_name):
    """Return the conversation template name implied by ``model_name`` (case-insensitive)."""
    name = model_name.lower()
    if "llama-2" in name:
        return "llava_llama_2"
    if "mistral" in name:
        return "mistral_instruct"
    if "v1.6-34b" in name:
        return "chatml_direct"
    if "v1" in name:
        return "llava_v1"
    if "mpt" in name:
        return "mpt"
    return "llava_v0"


def main(args):
    """Interactive command-line chat about a single image.

    Loads the model from ``args.model_path`` and the image from
    ``args.image_file``. It then repeatedly reads a user turn from stdin and
    streams the reply. An empty line or EOF ends the session.

    The image token is prepended to the first user turn only; the image tensor
    is passed to every ``generate`` call. ``args.conv_mode`` is filled in from
    the model name when not given. An explicit value that disagrees with the
    inferred one is kept, with a warning.
    """
    # Model
    disable_torch_init()

    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)

    conv_mode = _infer_conv_mode(model_name)
    if args.conv_mode is not None and conv_mode != args.conv_mode:
        print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
    else:
        args.conv_mode = conv_mode

    conv = conv_templates[args.conv_mode].copy()
    # Labels printed before each turn; messages are still appended under conv.roles.
    if "mpt" in model_name.lower():
        roles = ('user', 'assistant')
    else:
        roles = conv.roles

    image = load_image(args.image_file)
    image_size = image.size
    # Similar operation in model_worker.py
    image_tensor = process_images([image], image_processor, model.config)
    if isinstance(image_tensor, list):
        image_tensor = [t.to(model.device, dtype=torch.float16) for t in image_tensor]
    else:
        image_tensor = image_tensor.to(model.device, dtype=torch.float16)

    while True:
        try:
            inp = input(f"{roles[0]}: ")
        except EOFError:
            inp = ""
        if not inp:
            print("exit...")
            break

        print(f"{roles[1]}: ", end="")

        if image is not None:
            # first message: attach the image placeholder token(s)
            if model.config.mm_use_im_start_end:
                inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
            else:
                inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
            image = None

        conv.append_message(conv.roles[0], inp)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
        # Streams the reply text to stdout as tokens are generated.
        streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=image_tensor,
                image_sizes=[image_size],
                do_sample=args.temperature > 0,
                temperature=args.temperature,
                max_new_tokens=args.max_new_tokens,
                streamer=streamer,
                use_cache=True)

        outputs = tokenizer.decode(output_ids[0]).strip()
        conv.messages[-1][-1] = outputs

        if args.debug:
            print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
parser.add_argument("--model-base", type=str, default=None)
parser.add_argument("--image-file", type=str, required=True)
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--conv-mode", type=str, default=None)
parser.add_argument("--temperature", type=float, default=0.2)
parser.add_argument("--max-new-tokens", type=int, default=512)
parser.add_argument("--load-8bit", action="store_true")
parser.add_argument("--load-4bit", action="store_true")
parser.add_argument("--debug", action="store_true")
args = parser.parse_args()
main(args)
================================================
FILE: llava/serve/controller.py
================================================
"""
A controller manages distributed workers.
It sends worker addresses to clients.
"""
import argparse
import asyncio
import dataclasses
from enum import Enum, auto
import json
import logging
import time
from typing import List, Union
import threading
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
import numpy as np
import requests
import uvicorn
from llava.constants import CONTROLLER_HEART_BEAT_EXPIRATION
from llava.utils import build_logger, server_error_msg
logger = build_logger("controller", "controller.log")
class DispatchMethod(Enum):
    """Strategy the controller uses to choose a worker for a model."""

    LOTTERY = auto()
    SHORTEST_QUEUE = auto()

    @classmethod
    def from_str(cls, name):
        """Parse a CLI name (``"lottery"`` or ``"shortest_queue"``) into a member.

        Raises:
            ValueError: if ``name`` is not a recognised dispatch method.
        """
        if name == "lottery":
            return cls.LOTTERY
        elif name == "shortest_queue":
            return cls.SHORTEST_QUEUE
        else:
            raise ValueError(f"Invalid dispatch method: {name}")
@dataclasses.dataclass
class WorkerInfo:
    """Controller-side record of one registered worker."""

    # Models served by the worker.
    model_names: List[str]
    # Relative throughput weight reported by the worker.
    speed: int
    # Pending requests, updated by heart beats and dispatch.
    queue_length: int
    # Whether heart-beat expiration applies to this worker.
    check_heart_beat: bool
    # time.time() timestamp of the last registration or heart beat.
    last_heart_beat: float
def heart_beat_controller(controller):
    """Loop forever: every CONTROLLER_HEART_BEAT_EXPIRATION seconds, evict expired workers."""
    evict_expired = controller.remove_stable_workers_by_expiration
    while True:
        time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
        evict_expired()
class Controller:
    """Registry of model workers that hands out worker addresses to clients.

    ``worker_info`` maps a worker address (its ``worker_name``) to a
    ``WorkerInfo``. A daemon thread started in ``__init__`` periodically evicts
    workers whose heart beat has expired, so methods iterate over snapshots of
    ``worker_info`` where that thread may mutate it concurrently.
    """

    def __init__(self, dispatch_method: str):
        # Dict[str -> WorkerInfo]
        self.worker_info = {}
        self.dispatch_method = DispatchMethod.from_str(dispatch_method)

        self.heart_beat_thread = threading.Thread(
            target=heart_beat_controller, args=(self,), daemon=True)
        self.heart_beat_thread.start()

        logger.info("Init controller")

    def register_worker(self, worker_name: str, check_heart_beat: bool,
                        worker_status: dict):
        """Add or refresh a worker entry.

        If ``worker_status`` is empty or None it is fetched from the worker.
        Returns True on success, False if no status could be obtained.
        """
        if worker_name not in self.worker_info:
            logger.info(f"Register a new worker: {worker_name}")
        else:
            logger.info(f"Register an existing worker: {worker_name}")

        if not worker_status:
            worker_status = self.get_worker_status(worker_name)
        if not worker_status:
            return False

        self.worker_info[worker_name] = WorkerInfo(
            worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
            check_heart_beat, time.time())

        logger.info(f"Register done: {worker_name}, {worker_status}")
        return True

    def get_worker_status(self, worker_name: str):
        """POST ``<worker_name>/worker_get_status``; return the JSON body, or None on failure."""
        try:
            r = requests.post(worker_name + "/worker_get_status", timeout=5)
        except requests.exceptions.RequestException as e:
            logger.error(f"Get status fails: {worker_name}, {e}")
            return None

        if r.status_code != 200:
            logger.error(f"Get status fails: {worker_name}, {r}")
            return None

        return r.json()

    def remove_worker(self, worker_name: str):
        """Forget a worker; raises KeyError if it is not registered."""
        del self.worker_info[worker_name]

    def refresh_all_workers(self):
        """Re-fetch every known worker's status; workers that do not answer are dropped."""
        old_info = dict(self.worker_info)
        self.worker_info = {}

        for w_name, w_info in old_info.items():
            if not self.register_worker(w_name, w_info.check_heart_beat, None):
                logger.info(f"Remove stale worker: {w_name}")

    def list_models(self):
        """Return the distinct model names served by registered workers (unordered)."""
        model_names = set()
        for w_info in self.worker_info.values():
            model_names.update(w_info.model_names)
        return list(model_names)

    def get_worker_address(self, model_name: str):
        """Pick a worker address serving ``model_name``, or return "" if none is usable.

        LOTTERY samples workers with probability proportional to ``speed``.
        SHORTEST_QUEUE picks the smallest ``queue_length / speed`` and counts
        the request against that worker's queue. Workers with zero speed are
        never chosen by either method.
        """
        if self.dispatch_method == DispatchMethod.LOTTERY:
            worker_names = []
            worker_speeds = []
            for w_name, w_info in self.worker_info.items():
                if model_name in w_info.model_names:
                    worker_names.append(w_name)
                    worker_speeds.append(w_info.speed)
            worker_speeds = np.array(worker_speeds, dtype=np.float32)
            norm = np.sum(worker_speeds)
            if norm < 1e-4:
                return ""
            worker_speeds = worker_speeds / norm
            # Return the sampled address directly, without a status round-trip.
            pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds)
            return worker_names[pt]

        elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
            worker_names = []
            worker_qlen = []
            for w_name, w_info in self.worker_info.items():
                if model_name not in w_info.model_names:
                    continue
                # A zero-speed worker cannot be ranked (division by zero).
                if w_info.speed == 0:
                    continue
                worker_names.append(w_name)
                worker_qlen.append(w_info.queue_length / w_info.speed)
            if len(worker_names) == 0:
                return ""
            min_index = np.argmin(worker_qlen)
            w_name = worker_names[min_index]
            self.worker_info[w_name].queue_length += 1
            logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
            return w_name
        else:
            raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")

    def receive_heart_beat(self, worker_name: str, queue_length: int):
        """Record a heart beat; return False if ``worker_name`` is not registered."""
        if worker_name not in self.worker_info:
            logger.info(f"Receive unknown heart beat. {worker_name}")
            return False

        self.worker_info[worker_name].queue_length = queue_length
        self.worker_info[worker_name].last_heart_beat = time.time()
        logger.info(f"Receive heart beat. {worker_name}")
        return True

    def remove_stable_workers_by_expiration(self):
        """Remove heart-beat-checked workers not heard from within the expiration window."""
        expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
        to_delete = []
        # Snapshot: request handlers may register workers while this runs.
        for worker_name, w_info in list(self.worker_info.items()):
            if w_info.check_heart_beat and w_info.last_heart_beat < expire:
                to_delete.append(worker_name)

        for worker_name in to_delete:
            self.remove_worker(worker_name)

    def worker_api_generate_stream(self, params):
        """Proxy a streaming generation request to a worker serving ``params["model"]``.

        Yields NUL-terminated JSON chunks. If no worker is available (error_code
        2) or the request fails (error_code 3), a single error chunk is yielded.
        """
        worker_addr = self.get_worker_address(params["model"])
        if not worker_addr:
            logger.info(f"no worker: {params['model']}")
            ret = {
                "text": server_error_msg,
                "error_code": 2,
            }
            yield json.dumps(ret).encode() + b"\0"
            # Stop here: there is no worker address to forward the request to.
            return

        try:
            response = requests.post(worker_addr + "/worker_generate_stream",
                                     json=params, stream=True, timeout=5)
            for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
                if chunk:
                    yield chunk + b"\0"
        except requests.exceptions.RequestException:
            logger.info(f"worker timeout: {worker_addr}")
            ret = {
                "text": server_error_msg,
                "error_code": 3,
            }
            yield json.dumps(ret).encode() + b"\0"

    # Let the controller act as a worker to achieve hierarchical
    # management. This can be used to connect isolated sub networks.
    def worker_api_get_status(self):
        """Aggregate model names, speed and queue length over all reachable workers."""
        model_names = set()
        speed = 0
        queue_length = 0

        # Snapshot: each status call can block for seconds while the heart-beat
        # thread removes expired workers.
        for w_name in list(self.worker_info):
            worker_status = self.get_worker_status(w_name)
            if worker_status is not None:
                model_names.update(worker_status["model_names"])
                speed += worker_status["speed"]
                queue_length += worker_status["queue_length"]

        return {
            "model_names": list(model_names),
            "speed": speed,
            "queue_length": queue_length,
        }
# FastAPI app exposing the controller's HTTP API. The route handlers read the
# module-level `controller`, which is created in the __main__ block.
app = FastAPI()
@app.post("/register_worker")
async def register_worker(request: Request):
data = await request.json()
controller.register_worker(
data["worker_name"], data["check_heart_beat"],
data.get("worker_status", None))
@app.post("/refresh_all_workers")
async def refresh_all_workers():
models = controller.refresh_all_workers()
@app.post("/list_models")
async def list_models():
models = controller.list_models()
return {"models": models}
@app.post("/get_worker_address")
async def get_worker_address(request: Request):
data = await request.json()
addr = controller.get_worker_address(data["model"])
return {"address": addr}
@app.post("/receive_heart_beat")
async def receive_heart_beat(request: Request):
data = await request.json()
exist = controller.receive_heart_beat(
data["worker_name"], data["queue_length"])
return {"exist": exist}
@app.post("/worker_generate_stream")
async def worker_api_generate_stream(request: Request):
params = await request.json()
generator = controller.worker_api_generate_stream(params)
return StreamingResponse(generator)
@app.post("/worker_get_status")
async def worker_api_get_status(request: Request):
return controller.worker_api_get_status()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=21001)
parser.add_argument("--dispatch-method", type=str, choices=[
"lottery", "shortest_queue"], default="shortest_queue")
args = parser.parse_args()
logger.info(f"args: {args}")
controller = Controller(args.dispatch_method)
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
================================================
FILE: llava/serve/gradio_web_server.py
================================================
import argparse
import datetime
import json
import os
import time
import gradio as gr
import requests
from llava.conversation import (default_conversation, conv_templates,
SeparatorStyle)
from llava.constants import LOGDIR
from llava.utils import (build_logger, server_error_msg,
violates_moderation, moderation_msg)
import hashlib
logger = build_logger("gradio_web_server", "gradio_web_server.log")
headers = {"User-Agent": "LLaVA Client"}
no_change_btn = gr.Button()
enable_btn = gr.Button(interactive=True)
disable_btn = gr.Button(interactive=False)
priority = {
"vicuna-13b": "aaaaaaa",
"koala-13b": "aaaaaab",
}
def get_conv_log_filename():
    """Return today's conversation log path: ``LOGDIR/<year>-<MM>-<DD>-conv.json``."""
    today = datetime.datetime.now()
    return os.path.join(LOGDIR, f"{today.year}-{today.month:02d}-{today.day:02d}-conv.json")
def get_model_list():
    """Ask the controller to refresh its workers, then return the available model names.

    Names are sorted by ``priority`` (falling back to the name itself).

    Raises:
        requests.HTTPError: if either controller request returns an error status.
    """
    ret = requests.post(args.controller_url + "/refresh_all_workers")
    # Explicit check (an assert would be stripped under `python -O`).
    ret.raise_for_status()
    ret = requests.post(args.controller_url + "/list_models")
    ret.raise_for_status()
    models = ret.json()["models"]
    models.sort(key=lambda x: priority.get(x, x))
    logger.info(f"Models: {models}")
    return models
get_window_url_params = """
function() {
const params = new URLSearchParams(window.location.search);
url_params = Object.fromEntries(params);
console.log(url_params);
return url_params;
}
"""
def load_demo(url_params, request: gr.Request):
    """Start a session with a fresh conversation; preselect ``?model=...`` if it is a known model."""
    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")

    dropdown_update = gr.Dropdown(visible=True)
    requested_model = url_params["model"] if "model" in url_params else None
    if requested_model is not None and requested_model in models:
        dropdown_update = gr.Dropdown(value=requested_model, visible=True)

    return default_conversation.copy(), dropdown_update
def load_demo_refresh_model_list(request: gr.Request):
    """Start a session after re-fetching the model list; select the first model if any."""
    logger.info(f"load_demo. ip: {request.client.host}")
    available = get_model_list()
    state = default_conversation.copy()
    first_model = available[0] if available else ""
    return state, gr.Dropdown(choices=available, value=first_model)
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
    """Append one JSON line describing a vote/flag on the conversation to today's log."""
    with open(get_conv_log_filename(), "a") as log_file:
        entry = dict(
            tstamp=round(time.time(), 4),
            type=vote_type,
            model=model_selector,
            state=state.dict(),
            ip=request.client.host,
        )
        log_file.write(f"{json.dumps(entry)}\n")
def _record_feedback(kind, state, model_selector, request):
    """Log and persist one feedback event, then clear the textbox and disable the three feedback buttons."""
    logger.info(f"{kind}. ip: {request.client.host}")
    vote_last_response(state, kind, model_selector, request)
    return ("",) + (disable_btn,) * 3


def upvote_last_response(state, model_selector, request: gr.Request):
    """Record an upvote for the current conversation."""
    return _record_feedback("upvote", state, model_selector, request)


def downvote_last_response(state, model_selector, request: gr.Request):
    """Record a downvote for the current conversation."""
    return _record_feedback("downvote", state, model_selector, request)


def flag_last_response(state, model_selector, request: gr.Request):
    """Record a flag for the current conversation."""
    return _record_feedback("flag", state, model_selector, request)
def regenerate(state, image_process_mode, request: gr.Request):
    """Clear the last bot reply so it is regenerated.

    If the last user turn holds an image tuple, its process mode is replaced
    with ``image_process_mode``.
    """
    logger.info(f"regenerate. ip: {request.client.host}")
    state.messages[-1][-1] = None
    last_user_msg = state.messages[-2]
    content = last_user_msg[1]
    if type(content) in (tuple, list):
        last_user_msg[1] = (*content[:2], image_process_mode)
    state.skip_next = False
    disabled = (disable_btn,) * 5
    return (state, state.to_gradio_chatbot(), "", None) + disabled
def clear_history(request: gr.Request):
    """Reset to a fresh conversation, clear textbox and image, and disable all action buttons."""
    logger.info(f"clear_history. ip: {request.client.host}")
    fresh = default_conversation.copy()
    return (fresh, fresh.to_gradio_chatbot(), "", None, *([disable_btn] * 5))
def add_text(state, text, image, image_process_mode, request: gr.Request):
    """Append the user's text (and optional image) as a new turn plus an empty bot turn.

    Empty input, or text rejected by moderation (when ``--moderate`` is set),
    marks the state to skip generation and leaves the buttons unchanged. Text
    is truncated to 1536 characters, or to 1200 when an image is attached. An
    attached image starts a new conversation, and its turn is stored as
    ``(text, image, image_process_mode)``.
    """
    from llava.constants import DEFAULT_IMAGE_TOKEN

    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
    if len(text) <= 0 and image is None:
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
    if args.moderate:
        flagged = violates_moderation(text)
        if flagged:
            state.skip_next = True
            return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
                no_change_btn,) * 5

    text = text[:1536]  # Hard cut-off
    if image is not None:
        text = text[:1200]  # Hard cut-off for images
        # Make sure the turn carries the image placeholder token.
        if DEFAULT_IMAGE_TOKEN not in text:
            text = text + '\n' + DEFAULT_IMAGE_TOKEN
        text = (text, image, image_process_mode)
        state = default_conversation.copy()
    state.append_message(state.roles[0], text)
    state.append_message(state.roles[1], None)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
def _select_template_name(model_name):
    """Pick the conversation template name for ``model_name``.

    Names containing "llava" are matched case-insensitively. The remaining
    "mpt" / "llama-2" checks are case-sensitive.
    """
    name = model_name.lower()
    if "llava" in name:
        if 'llama-2' in name:
            return "llava_llama_2"
        if "mistral" in name or "mixtral" in name:
            if 'orca' in name:
                return "mistral_orca"
            if 'hermes' in name:
                return "chatml_direct"
            return "mistral_instruct"
        if 'llava-v1.6-34b' in name:
            return "chatml_direct"
        if "v1" in name:
            if 'mmtag' in name:
                return "v1_mmtag"
            if 'plain' in name and 'finetune' not in name:
                return "v1_mmtag"
            return "llava_v1"
        if "mpt" in name:
            return "mpt"
        if 'mmtag' in name:
            return "v0_mmtag"
        if 'plain' in name and 'finetune' not in name:
            return "v0_mmtag"
        return "llava_v0"
    if "mpt" in model_name:
        return "mpt_text"
    if "llama-2" in model_name:
        return "llama_2"
    return "vicuna_v1"


def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
    """Stream the model's reply to the last user turn, yielding UI updates.

    Each yield is ``(state, chatbot, *buttons)`` for the outputs
    ``[state, chatbot] + btn_list``. On the first round, the conversation is
    rebuilt on the template matching the model. Images are archived under
    ``LOGDIR/serve_images``, and the finished exchange is appended to today's
    conversation log.
    """
    logger.info(f"http_bot. ip: {request.client.host}")
    start_tstamp = time.time()
    model_name = model_selector

    if state.skip_next:
        # This generate call is skipped due to invalid inputs
        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
        return

    if len(state.messages) == state.offset + 2:
        # First round of conversation: move the user turn onto the model's template.
        template_name = _select_template_name(model_name)
        new_state = conv_templates[template_name].copy()
        new_state.append_message(new_state.roles[0], state.messages[-2][1])
        new_state.append_message(new_state.roles[1], None)
        state = new_state

    # Query worker address
    controller_url = args.controller_url
    ret = requests.post(controller_url + "/get_worker_address",
                        json={"model": model_name})
    worker_addr = ret.json()["address"]
    logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")

    # No available worker
    if worker_addr == "":
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
        return

    # Construct prompt
    prompt = state.get_prompt()

    # Archive each image once, keyed by the MD5 of its pixel bytes.
    all_images = state.get_images(return_pil=True)
    all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
    for image, image_hash in zip(all_images, all_image_hash):
        t = datetime.datetime.now()
        filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{image_hash}.jpg")
        if not os.path.isfile(filename):
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            image.save(filename)

    # Make requests
    pload = {
        "model": model_name,
        "prompt": prompt,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "max_new_tokens": min(int(max_new_tokens), 1536),
        "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
        "images": f'List of {len(state.get_images())} images: {all_image_hash}',
    }
    logger.info(f"==== request ====\n{pload}")

    # Only the image hashes were logged above; send the actual images to the worker.
    pload['images'] = state.get_images()

    state.messages[-1][-1] = "▌"
    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5

    # Defined up front so the final log works even if the worker streams nothing.
    output = ""
    try:
        # Stream output
        response = requests.post(worker_addr + "/worker_generate_stream",
                                 headers=headers, json=pload, stream=True, timeout=10)
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                data = json.loads(chunk.decode())
                if data["error_code"] == 0:
                    output = data["text"][len(prompt):].strip()
                    state.messages[-1][-1] = output + "▌"
                    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
                else:
                    output = data["text"] + f" (error_code: {data['error_code']})"
                    state.messages[-1][-1] = output
                    yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
                    return
                time.sleep(0.03)
    except requests.exceptions.RequestException:
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
        return

    # Drop the trailing cursor character.
    state.messages[-1][-1] = state.messages[-1][-1][:-1]
    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

    finish_tstamp = time.time()
    logger.info(f"{output}")

    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "images": all_image_hash,
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")
title_markdown = ("""
# 🌋 LLaVA: Large Language and Vision Assistant
[[Project Page](https://llava-vl.github.io)] [[Code](https://github.com/haotian-liu/LLaVA)] [[Model](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)] | 📚 [[LLaVA](https://arxiv.org/abs/2304.08485)] [[LLaVA-v1.5](https://arxiv.org/abs/2310.03744)] [[LLaVA-v1.6](https://llava-vl.github.io/blog/2024-01-30-llava-1-6/)]
""")
tos_markdown = ("""
### Terms of use
By using this service, users are required to agree to the following terms:
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
""")
learn_more_markdown = ("""
### License
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
""")
block_css = """
#buttons button {
min-width: min(120px,100%);
}
"""
def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
    """Build the Gradio Blocks UI for the LLaVA chat demo.

    Args:
        embed_mode: if true, omit the title, terms-of-use and license markdown.
        cur_dir: directory containing ``examples/``; defaults to this file's directory.
        concurrency_count: ``concurrency_limit`` for the ``http_bot`` generation events.

    Returns:
        The ``gr.Blocks`` demo. Reads the module-level ``models`` and ``args``
        globals.

    Raises:
        ValueError: for an unknown ``args.model_list_mode``.
    """
    # Created up front and placed later via textbox.render(), because
    # gr.Examples below references it before its position in the layout.
    textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
    with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=block_css) as demo:
        state = gr.State()

        if not embed_mode:
            gr.Markdown(title_markdown)

        with gr.Row():
            # Left column: model picker, image input, examples and sampling parameters.
            with gr.Column(scale=3):
                with gr.Row(elem_id="model_selector_row"):
                    model_selector = gr.Dropdown(
                        choices=models,
                        value=models[0] if len(models) > 0 else "",
                        interactive=True,
                        show_label=False,
                        container=False)

                imagebox = gr.Image(type="pil")
                image_process_mode = gr.Radio(
                    ["Crop", "Resize", "Pad", "Default"],
                    value="Default",
                    label="Preprocess for non-square image", visible=False)

                if cur_dir is None:
                    cur_dir = os.path.dirname(os.path.abspath(__file__))
                gr.Examples(examples=[
                    [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"],
                    [f"{cur_dir}/examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"],
                ], inputs=[imagebox, textbox])

                with gr.Accordion("Parameters", open=False) as parameter_row:
                    temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
                    top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
                    max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)

            # Right column: chat history, input row and action buttons.
            with gr.Column(scale=8):
                chatbot = gr.Chatbot(
                    elem_id="chatbot",
                    label="LLaVA Chatbot",
                    height=650,
                    layout="panel",
                )
                with gr.Row():
                    with gr.Column(scale=8):
                        textbox.render()
                    with gr.Column(scale=1, min_width=50):
                        submit_btn = gr.Button(value="Send", variant="primary")
                with gr.Row(elem_id="buttons") as button_row:
                    upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                    downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                    flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
                    # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
                    regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                    clear_btn = gr.Button(value="🗑️ Clear", interactive=False)

        if not embed_mode:
            gr.Markdown(tos_markdown)
            gr.Markdown(learn_more_markdown)
        # Hidden component that receives the URL query parameters on page load.
        url_params = gr.JSON(visible=False)

        # Register listeners
        # Order must match the 5 button values returned/yielded by the handlers.
        btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
        # Feedback buttons: log the event, clear the textbox, disable the feedback buttons.
        upvote_btn.click(
            upvote_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn]
        )
        downvote_btn.click(
            downvote_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn]
        )
        flag_btn.click(
            flag_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn]
        )

        # Regenerate: clear the last reply, then stream a new one.
        regenerate_btn.click(
            regenerate,
            [state, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list
        ).then(
            http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list,
            concurrency_limit=concurrency_count
        )

        clear_btn.click(
            clear_history,
            None,
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        )

        # Pressing ENTER and clicking Send both add the user turn, then stream the reply.
        textbox.submit(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        ).then(
            http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list,
            concurrency_limit=concurrency_count
        )

        submit_btn.click(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list
        ).then(
            http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list,
            concurrency_limit=concurrency_count
        )

        # On page load, initialize state and the model dropdown: "once" reuses the
        # startup model list (and honors ?model=...), "reload" re-fetches it.
        if args.model_list_mode == "once":
            demo.load(
                load_demo,
                [url_params],
                [state, model_selector],
                js=get_window_url_params
            )
        elif args.model_list_mode == "reload":
            demo.load(
                load_demo_refresh_model_list,
                None,
                [state, model_selector],
                queue=False
            )
        else:
            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")

    return demo
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int)
parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
parser.add_argument("--concurrency-count", type=int, default=16)
parser.add_argument("--model-list-mode", type=str, default="once",
choices=["once", "reload"])
parser.add_argument("--share", action="store_true")
parser.add_argument("--moderate", action="store_true")
parser.add_argument("--embed", action="store_true")
args = parser.parse_args()
logger.info(f"args: {args}")
models = get_model_list()
logger.info(args)
demo = build_demo(args.embed, concurrency_count=args.concurrency_count)
demo.queue(
api_open=False
).launch(
server_name=args.host,
server_port=args.port,
share=args.share
)
================================================
FILE: llava/serve/model_worker.py
================================================
"""
A model worker executes the model.
"""
import argparse
import asyncio
import json
import time
import threading
import uuid
from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse
import requests
import torch
import uvicorn
from functools import partial
from llava.constants import WORKER_HEART_BEAT_INTERVAL
from llava.utils import (build_logger, server_error_msg,
pretty_print_semaphore)
from llava.model.builder import load_pretrained_model
from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from transformers import TextIteratorStreamer
from threading import Thread
# Number of bytes in one gibibyte.
GB = 2 ** 30
# Six-character random tag that distinguishes this worker process (and its log file).
worker_id = f"{uuid.uuid4()}"[:6]
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
# Count of /worker_generate_stream requests received so far (reported in heart beats).
global_counter = 0
# asyncio.Semaphore bounding concurrent generations; created lazily by the first request.
model_semaphore = None
def heart_beat_worker(controller):
    """Loop forever, calling ``controller.send_heart_beat()`` once per interval.

    Each iteration sleeps ``WORKER_HEART_BEAT_INTERVAL`` seconds before sending,
    so the first beat is sent one full interval after the loop starts. The
    function never returns; it is meant to be the target of a daemon thread.
    """
    while True:
        time.sleep(WORKER_HEART_BEAT_INTERVAL)
        controller.send_heart_beat()
class ModelWorker:
    """Serve one LLaVA-family model in-process and keep it registered with a controller.

    The model is loaded once at construction. Unless ``no_register`` is set, the
    worker registers with the controller and starts a daemon heart-beat thread.
    Generations are streamed as NUL-terminated JSON chunks by
    :meth:`generate_stream_gate`.
    """

    def __init__(self, controller_addr, worker_addr,
                 worker_id, no_register,
                 model_path, model_base, model_name,
                 load_8bit, load_4bit, device, use_flash_attn=False):
        self.controller_addr = controller_addr
        self.worker_addr = worker_addr
        self.worker_id = worker_id
        if model_path.endswith("/"):
            model_path = model_path[:-1]
        if model_name is None:
            # Derive the name from the path; a "checkpoint-N" leaf is qualified by
            # its parent directory, e.g. "run/checkpoint-100" -> "run_checkpoint-100".
            model_paths = model_path.split("/")
            if model_paths[-1].startswith('checkpoint-'):
                self.model_name = model_paths[-2] + "_" + model_paths[-1]
            else:
                self.model_name = model_paths[-1]
        else:
            self.model_name = model_name

        self.device = device
        logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device, use_flash_attn=use_flash_attn)
        # Image inputs are only honoured when the model name contains "llava".
        self.is_multimodal = 'llava' in self.model_name.lower()

        if not no_register:
            self.register_to_controller()
            self.heart_beat_thread = threading.Thread(
                target=heart_beat_worker, args=(self,), daemon=True)
            self.heart_beat_thread.start()

    def register_to_controller(self):
        """POST this worker's address and status to the controller's /register_worker."""
        logger.info("Register to controller")

        url = self.controller_addr + "/register_worker"
        data = {
            "worker_name": self.worker_addr,
            "check_heart_beat": True,
            "worker_status": self.get_status()
        }
        r = requests.post(url, json=data)
        assert r.status_code == 200

    def send_heart_beat(self):
        """Report queue length to the controller, retrying every 5 s until it answers.

        Re-registers when the controller replies that it does not know this worker.
        """
        logger.info(f"Send heart beat. Models: {[self.model_name]}. "
                    f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
                    f"global_counter: {global_counter}")

        url = self.controller_addr + "/receive_heart_beat"

        while True:
            try:
                ret = requests.post(url, json={
                    "worker_name": self.worker_addr,
                    "queue_length": self.get_queue_length()}, timeout=5)
                exist = ret.json()["exist"]
                break
            except (requests.exceptions.RequestException, ValueError, KeyError) as e:
                # A malformed reply (non-JSON body, missing "exist") is retried like a
                # network error instead of escaping and killing the heart-beat thread.
                logger.error(f"heart beat error: {e}")
            time.sleep(5)

        if not exist:
            self.register_to_controller()

    def get_queue_length(self):
        """Return busy slots plus waiting requests on the model semaphore (0 before first use).

        Reads private ``asyncio.Semaphore`` attributes (``_value``, ``_waiters``).
        """
        if model_semaphore is None:
            return 0
        else:
            return args.limit_model_concurrency - model_semaphore._value + (len(
                model_semaphore._waiters) if model_semaphore._waiters is not None else 0)

    def get_status(self):
        """Return the status dict sent to the controller on registration."""
        return {
            "model_names": [self.model_name],
            "speed": 1,
            "queue_length": self.get_queue_length(),
        }

    @staticmethod
    def _strip_stop(text, stop_str):
        """Remove a trailing ``stop_str`` from ``text``.

        A missing (None) or empty stop string is a no-op. Previously None raised
        TypeError in ``str.endswith`` and "" erased the whole text via ``text[:-0]``.
        """
        if stop_str and text.endswith(stop_str):
            return text[:-len(stop_str)]
        return text

    @staticmethod
    def _error_response():
        """Return the NUL-terminated JSON chunk sent to clients when generation fails."""
        return json.dumps({"text": server_error_msg, "error_code": 1}).encode() + b"\0"

    @torch.inference_mode()
    def generate_stream(self, params):
        """Generate for ``params`` and yield NUL-terminated JSON chunks.

        ``params`` keys read: ``prompt`` (required), ``images`` (base64 strings, one
        per image token), ``temperature``, ``top_p``, ``max_new_tokens`` (capped at
        1024) and ``stop``. Each chunk's ``"text"`` is the original prompt followed
        by the text generated so far.

        Raises:
            ValueError: if the number of images differs from the number of image
                tokens in the prompt.
        """
        tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor

        prompt = params["prompt"]
        ori_prompt = prompt
        images = params.get("images", None)
        num_image_tokens = 0
        if images is not None and len(images) > 0 and self.is_multimodal:
            if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
                raise ValueError("Number of images does not match number of tokens in prompt")

            images = [load_image_from_base64(image) for image in images]
            image_sizes = [image.size for image in images]
            images = process_images(images, image_processor, model.config)

            if isinstance(images, list):
                images = [image.to(self.model.device, dtype=torch.float16) for image in images]
            else:
                images = images.to(self.model.device, dtype=torch.float16)

            replace_token = DEFAULT_IMAGE_TOKEN
            if getattr(self.model.config, 'mm_use_im_start_end', False):
                replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
            prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)

            # Reserve context room for the visual tokens each image expands into.
            num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches
            image_args = {"images": images, "image_sizes": image_sizes}
        else:
            image_args = {}

        temperature = float(params.get("temperature", 1.0))
        top_p = float(params.get("top_p", 1.0))
        max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
        max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
        stop_str = params.get("stop", None)
        do_sample = temperature > 0.001

        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
        # NOTE: no stopping criteria is passed to generate(); the stop string is only
        # trimmed from the streamed text below.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)

        max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)

        if max_new_tokens < 1:
            yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
            return

        # generate() runs on its own thread and feeds the streamer consumed below.
        thread = Thread(target=model.generate, kwargs=dict(
            inputs=input_ids,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            use_cache=True,
            **image_args
        ))
        thread.start()

        generated_text = ori_prompt
        for new_text in streamer:
            generated_text += new_text
            generated_text = self._strip_stop(generated_text, stop_str)
            yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"

    def generate_stream_gate(self, params):
        """Wrap :meth:`generate_stream`, turning any failure into a final error chunk."""
        try:
            # `yield from` also closes the inner generator promptly if the client goes away.
            yield from self.generate_stream(params)
        except ValueError as e:
            print("Caught ValueError:", e)
            yield self._error_response()
        except torch.cuda.CudaError as e:
            print("Caught torch.cuda.CudaError:", e)
            yield self._error_response()
        except Exception as e:
            print("Caught Unknown Error", e)
            yield self._error_response()
app = FastAPI()


def release_model_semaphore(fn=None):
    """Give back one generation slot on the model semaphore, then call ``fn`` if given."""
    model_semaphore.release()
    if fn is None:
        return
    fn()
@app.post("/worker_generate_stream")
async def generate_stream(request: Request):
    """Stream a generation for the JSON request body, bounded by the model semaphore.

    One semaphore slot is held per request. It is released, and a heart beat is
    sent, by a background task after the streaming response completes.
    """
    global model_semaphore, global_counter
    global_counter += 1
    params = await request.json()

    if model_semaphore is None:
        model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
    await model_semaphore.acquire()
    try:
        # send_heart_beat() does blocking HTTP and retries forever while the
        # controller is unreachable; run it in a thread so the event loop (and
        # every other in-flight stream) is not stalled.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, worker.send_heart_beat)
        generator = worker.generate_stream_gate(params)
    except BaseException:
        # The background release below will never be scheduled, so free the slot now.
        model_semaphore.release()
        raise
    background_tasks = BackgroundTasks()
    background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
    return StreamingResponse(generator, background=background_tasks)
@app.post("/worker_get_status")
async def get_status(request: Request):
    """Return this worker's status dict (model names, speed, queue length)."""
    status = worker.get_status()
    return status
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Network endpoints.
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21002)
    parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
    parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
    # Model selection and loading.
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--model-name", type=str)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
    # Serving behaviour.
    parser.add_argument("--limit-model-concurrency", type=int, default=5)
    parser.add_argument("--stream-interval", type=int, default=1)
    parser.add_argument("--no-register", action="store_true")
    # Quantization / attention backends.
    parser.add_argument("--load-8bit", action="store_true")
    parser.add_argument("--load-4bit", action="store_true")
    parser.add_argument("--use-flash-attn", action="store_true")
    # `args` and `worker` are module globals read by the FastAPI handlers.
    args = parser.parse_args()
    logger.info(f"args: {args}")

    if args.multi_modal:
        logger.warning("Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")

    worker = ModelWorker(
        controller_addr=args.controller_address,
        worker_addr=args.worker_address,
        worker_id=worker_id,
        no_register=args.no_register,
        model_path=args.model_path,
        model_base=args.model_base,
        model_name=args.model_name,
        load_8bit=args.load_8bit,
        load_4bit=args.load_4bit,
        device=args.device,
        use_flash_attn=args.use_flash_attn,
    )
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
================================================
FILE: llava/serve/register_worker.py
================================================
"""
Manually register workers.
Usage:
python3 -m llava.serve.register_worker --controller-address http://localhost:21001 --worker-name http://localhost:21002
"""
import argparse
import requests
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Both addresses are mandatory. A missing --controller-address previously
    # crashed with a TypeError on string concatenation, and a missing
    # --worker-name registered a worker named None.
    parser.add_argument("--controller-address", type=str, required=True)
    parser.add_argument("--worker-name", type=str, required=True)
    parser.add_argument("--check-heart-beat", action="store_true")
    args = parser.parse_args()

    url = args.controller_address + "/register_worker"
    data = {
        "worker_name": args.worker_name,
        "check_heart_beat": args.check_heart_beat,
        "worker_status": None,
    }
    r = requests.post(url, json=data)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if r.status_code != 200:
        raise SystemExit(f"Registration failed: HTTP {r.status_code} from {url}: {r.text}")
================================================
FILE: llava/serve/sglang_worker.py
================================================
"""
A model worker that serves generation requests by forwarding them to an SGLang runtime endpoint.
"""
import argparse
import asyncio
from concurrent.futures import ThreadPoolExecutor
import json
import time
import threading
import uuid
from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse
import requests
import re
import uvicorn
from functools import partial
from llava.constants import WORKER_HEART_BEAT_INTERVAL
from llava.utils import (build_logger, server_error_msg,
pretty_print_semaphore)
from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, expand2square
from llava.constants import DEFAULT_IMAGE_TOKEN
import sglang as sgl
from sglang.backend.runtime_endpoint import RuntimeEndpoint
# Number of bytes in one gibibyte.
GB = 2 ** 30
# Six-character random tag that distinguishes this worker process (and its log file).
worker_id = f"{uuid.uuid4()}"[:6]
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
# Count of /worker_generate_stream requests received so far (reported in heart beats).
global_counter = 0
# asyncio.Semaphore bounding concurrent generations; created lazily by the first request.
model_semaphore = None
def heart_beat_worker(controller):
    """Loop forever, calling ``controller.send_heart_beat()`` once per interval.

    Each iteration sleeps ``WORKER_HEART_BEAT_INTERVAL`` seconds before sending.
    The function never returns; it is meant to be the target of a daemon thread.
    """
    while True:
        time.sleep(WORKER_HEART_BEAT_INTERVAL)
        controller.send_heart_beat()
@sgl.function
def pipeline(s, prompt, max_tokens):
    """SGLang program: append each prompt piece, then generate into ``"response"``.

    Pieces whose exact type is ``str`` are appended as text; any other piece is
    wrapped with ``sgl.image``.
    """
    for piece in prompt:
        s += piece if type(piece) is str else sgl.image(piece)
    s += sgl.gen("response", max_tokens=max_tokens)
class ModelWorker:
    """Proxy worker that forwards generation requests to an SGLang runtime endpoint.

    The SGLang backend becomes the process-wide default backend. Unless
    ``no_register`` is set, the worker registers with the controller and starts
    a daemon heart-beat thread.
    """

    def __init__(self, controller_addr, worker_addr, sgl_endpoint,
                 worker_id, no_register, model_name):
        self.controller_addr = controller_addr
        self.worker_addr = worker_addr
        self.worker_id = worker_id

        # Select backend
        backend = RuntimeEndpoint(sgl_endpoint)
        sgl.set_default_backend(backend)
        model_path = backend.model_info["model_path"]

        if model_path.endswith("/"):
            model_path = model_path[:-1]
        if model_name is None:
            # A "checkpoint-N" leaf is qualified by its parent directory name.
            model_paths = model_path.split("/")
            if model_paths[-1].startswith('checkpoint-'):
                self.model_name = model_paths[-2] + "_" + model_paths[-1]
            else:
                self.model_name = model_paths[-1]
        else:
            self.model_name = model_name

        logger.info(f"Loading the SGLANG model {self.model_name} on worker {worker_id} ...")

        if not no_register:
            self.register_to_controller()
            self.heart_beat_thread = threading.Thread(
                target=heart_beat_worker, args=(self,), daemon=True)
            self.heart_beat_thread.start()

    def register_to_controller(self):
        """POST this worker's address and status to the controller's /register_worker."""
        logger.info("Register to controller")

        url = self.controller_addr + "/register_worker"
        data = {
            "worker_name": self.worker_addr,
            "check_heart_beat": True,
            "worker_status": self.get_status()
        }
        r = requests.post(url, json=data)
        assert r.status_code == 200

    def send_heart_beat(self):
        """Report queue length to the controller, retrying every 5 s until it answers.

        Re-registers when the controller replies that it does not know this worker.
        """
        logger.info(f"Send heart beat. Models: {[self.model_name]}. "
                    f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
                    f"global_counter: {global_counter}")

        url = self.controller_addr + "/receive_heart_beat"

        while True:
            try:
                ret = requests.post(url, json={
                    "worker_name": self.worker_addr,
                    "queue_length": self.get_queue_length()}, timeout=5)
                exist = ret.json()["exist"]
                break
            except (requests.exceptions.RequestException, ValueError, KeyError) as e:
                # A malformed reply (non-JSON body, missing "exist") is retried like a
                # network error instead of escaping and killing the heart-beat thread.
                logger.error(f"heart beat error: {e}")
            time.sleep(5)

        if not exist:
            self.register_to_controller()

    def get_queue_length(self):
        """Return busy slots plus waiting requests on the model semaphore (0 before first use).

        Reads private ``asyncio.Semaphore`` attributes (``_value``, ``_waiters``).
        """
        if model_semaphore is None:
            return 0
        else:
            return args.limit_model_concurrency - model_semaphore._value + (len(
                model_semaphore._waiters) if model_semaphore._waiters is not None else 0)

    def get_status(self):
        """Return the status dict sent to the controller on registration."""
        return {
            "model_names": [self.model_name],
            "speed": 1,
            "queue_length": self.get_queue_length(),
        }

    @staticmethod
    def _interleave_images(prompt, images, image_token):
        """Split ``prompt`` on ``image_token`` and insert ``images`` between the text pieces.

        The i-th image follows the i-th text piece; text pieces without a
        matching image are kept as-is.
        """
        parts = []
        for i, piece in enumerate(prompt.split(image_token)):
            parts.append(piece)
            if i < len(images):
                parts.append(images[i])
        return parts

    @staticmethod
    def _error_response():
        """Return the NUL-terminated JSON chunk sent to clients when generation fails."""
        return json.dumps({"text": server_error_msg, "error_code": 1}).encode() + b"\0"

    async def generate_stream(self, params):
        """Run ``pipeline`` on the SGLang backend and yield NUL-terminated JSON chunks.

        ``params`` keys read: ``prompt`` (required), ``images`` (base64 strings, one
        per image token), ``temperature``, ``top_p``, ``max_new_tokens`` (capped at
        1024) and ``stop``. Each chunk's ``"text"`` is the original prompt followed
        by the text generated so far.

        Raises:
            ValueError: if the number of images differs from the number of image
                tokens in the prompt.
        """
        ori_prompt = prompt = params["prompt"]
        images = params.get("images", None)
        if images is not None and len(images) > 0:
            if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
                raise ValueError("Number of images does not match number of tokens in prompt")

            images = [load_image_from_base64(image) for image in images]

            # FIXME: image start/end tokens (mm_use_im_start_end) are not inserted here.
            prompt = prompt.replace(' ' + DEFAULT_IMAGE_TOKEN + '\n', DEFAULT_IMAGE_TOKEN)
            prompt = self._interleave_images(prompt, images, DEFAULT_IMAGE_TOKEN)
        else:
            # pipeline() iterates over prompt pieces; wrap the text so it is appended
            # once instead of being iterated character by character.
            prompt = [prompt]

        temperature = float(params.get("temperature", 1.0))
        top_p = float(params.get("top_p", 1.0))
        max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
        stop_str = params.get("stop", None)
        stop_str = [stop_str] if stop_str is not None else None

        print({'prompt': prompt, 'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_p': top_p})
        run_kwargs = {"temperature": temperature, "top_p": top_p, "stream": True}
        if stop_str is not None:
            # The stop list was previously computed but never passed to the backend.
            run_kwargs["stop"] = stop_str
        state = pipeline.run(prompt, max_new_tokens, **run_kwargs)

        generated_text = ori_prompt
        async for text_outputs in state.text_async_iter(var_name="response"):
            generated_text += text_outputs
            yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"

    async def generate_stream_gate(self, params):
        """Wrap :meth:`generate_stream`, turning any failure into a final error chunk."""
        try:
            async for x in self.generate_stream(params):
                yield x
        except ValueError as e:
            print("Caught ValueError:", e)
            yield self._error_response()
        except Exception as e:
            print("Caught Unknown Error", e)
            yield self._error_response()
app = FastAPI()


def release_model_semaphore(fn=None):
    """Give back one generation slot on the model semaphore, then call ``fn`` if given."""
    model_semaphore.release()
    if fn is None:
        return
    fn()
@app.post("/worker_generate_stream")
async def generate_stream(request: Request):
    """Stream a generation for the JSON request body, bounded by the model semaphore.

    One semaphore slot is held per request. It is released, and a heart beat is
    sent, by a background task after the streaming response completes.
    """
    global model_semaphore, global_counter
    global_counter += 1
    params = await request.json()

    if model_semaphore is None:
        model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
    await model_semaphore.acquire()
    try:
        # send_heart_beat() does blocking HTTP and retries forever while the
        # controller is unreachable; run it in a thread so the event loop (and
        # every other in-flight stream) is not stalled.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, worker.send_heart_beat)
        generator = worker.generate_stream_gate(params)
    except BaseException:
        # The background release below will never be scheduled, so free the slot now.
        model_semaphore.release()
        raise
    background_tasks = BackgroundTasks()
    background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
    return StreamingResponse(generator, background=background_tasks)
@app.post("/worker_get_status")
async def get_status(request: Request):
    """Return this worker's status dict (model names, speed, queue length)."""
    status = worker.get_status()
    return status
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Network endpoints.
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21002)
    parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
    parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
    # Backend selection.
    parser.add_argument("--model-name", type=str)
    parser.add_argument("--sgl-endpoint", type=str)
    # Serving behaviour.
    parser.add_argument("--limit-model-concurrency", type=int, default=5)
    parser.add_argument("--stream-interval", type=int, default=1)
    parser.add_argument("--no-register", action="store_true")
    # `args` and `worker` are module globals read by the FastAPI handlers.
    args = parser.parse_args()
    logger.info(f"args: {args}")

    worker = ModelWorker(
        controller_addr=args.controller_address,
        worker_addr=args.worker_address,
        sgl_endpoint=args.sgl_endpoint,
        worker_id=worker_id,
        no_register=args.no_register,
        model_name=args.model_name,
    )
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
================================================
FILE: llava/serve/test_message.py
================================================
import argparse
import json
import requests
from llava.conversation import default_conversation
def main():
    """Send ``--message`` to a worker and stream its reply to stdout.

    Uses ``--worker-address`` when given. Otherwise asks the controller to
    refresh its workers, prints the model list, and resolves the worker serving
    ``--model-name``. Returns early if the controller reports no address.
    """
    worker_addr = args.worker_address
    if not worker_addr:
        controller_addr = args.controller_address
        requests.post(controller_addr + "/refresh_all_workers")
        listing = requests.post(controller_addr + "/list_models")
        models = sorted(listing.json()["models"])
        print(f"Models: {models}")

        lookup = requests.post(controller_addr + "/get_worker_address",
                               json={"model": args.model_name})
        worker_addr = lookup.json()["address"]
        print(f"worker_addr: {worker_addr}")

    if worker_addr == "":
        return

    conv = default_conversation.copy()
    conv.append_message(conv.roles[0], args.message)
    prompt = conv.get_prompt()

    headers = {"User-Agent": "LLaVA Client"}
    pload = {
        "model": args.model_name,
        "prompt": prompt,
        "max_new_tokens": args.max_new_tokens,
        "temperature": 0.7,
        "stop": conv.sep,
    }
    response = requests.post(
        worker_addr + "/worker_generate_stream",
        headers=headers,
        json=pload,
        stream=True,
    )

    print(prompt.replace(conv.sep, "\n"), end="")
    # The worker emits NUL-delimited JSON chunks; each holds the full text so far.
    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
        if not chunk:
            continue
        data = json.loads(chunk.decode("utf-8"))
        output = data["text"].split(conv.sep)[-1]
        print(output, end="\r")
    print("")
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--controller-address", type=str,
                        default="http://localhost:21001")
    parser.add_argument("--worker-address", type=str)
    parser.add_argument("--model-name", type=str,
                        default="facebook/opt-350m")
    parser.add_argument("--max-new-tokens", type=int, default=32)
    parser.add_argument("--message", type=str,
                        default="Tell me a story with more than 1000 words.")
    # main() reads the parsed options through this module-level `args`.
    args = parser.parse_args()

    main()
================================================
FILE: llava/train/llama_flash_attn_monkey_patch.py
================================================
from typing import Optional, Tuple
import warnings
import torch
import transformers
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
try:
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
except ImportError:
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Flash-attention replacement for ``LlamaAttention.forward``.

    Installed by ``replace_llama_attn_with_flash_attn``. Attention weights are
    never materialised, so the second return value is always None; a warning is
    issued if ``output_attentions`` is requested.

    Args:
        attention_mask: used directly as flash-attn's key padding mask (``None``
            means no padding). Presumably the raw (bsz, seq_len) mask let through
            by the patched ``_prepare_decoder_attention_mask`` — confirm with callers.
        past_key_value: optional cached ``(key, value)``, concatenated along the
            sequence dimension before attention.
        use_cache: when True, the concatenated key/value pair (before
            ``repeat_kv``) is returned as the new cache.

    Returns:
        ``(attn_output, None, past_key_value)``.
    """
    if output_attentions:
        warnings.warn(
            "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
        )

    bsz, q_len, _ = hidden_states.size()

    query_states = (
        self.q_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    key_states = (
        self.k_proj(hidden_states)
        .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
        .transpose(1, 2)
    )
    value_states = (
        self.v_proj(hidden_states)
        .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
        .transpose(1, 2)
    )  # shape: (b, num_heads, s, head_dim)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]

    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(
        query_states, key_states, cos, sin, position_ids
    )

    if past_key_value is not None:
        # reuse k, v
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)

    # The cache stores k/v before repeat_kv, i.e. with num_key_value_heads heads.
    past_key_value = (key_states, value_states) if use_cache else None

    # repeat k/v heads if n_kv_heads < n_heads
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    # Transform the data into the format required by flash attention
    # NOTE(review): torch.stack needs q and k/v to share the sequence length, so a
    # non-empty past_key_value (k/v longer than q) presumably fails here. This path
    # looks training-only; confirm before using it for cached generation.
    qkv = torch.stack([query_states, key_states, value_states], dim=2)
    qkv = qkv.transpose(1, 3)  # shape: [b, s, 3, num_heads, head_dim]
    key_padding_mask = attention_mask

    if key_padding_mask is None:
        # No padding: every sequence has q_len tokens, so the cumulative sequence
        # offsets are 0, q_len, 2*q_len, ..., bsz*q_len.
        qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)
        cu_q_lens = torch.arange(
            0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
        )
        max_s = q_len
        output = flash_attn_unpadded_qkvpacked_func(
            qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
        )
        output = output.view(bsz, q_len, -1)
    else:
        # Padded batch: drop masked positions, run varlen attention on the packed
        # tokens, then scatter results back to (bsz, q_len, num_heads * head_dim).
        # NOTE(review): newer flash-attn releases may return an extra value from
        # unpad_input; this 4-way unpacking assumes an older API — confirm version.
        qkv = qkv.reshape(bsz, q_len, -1)
        qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
        qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
        output_unpad = flash_attn_unpadded_qkvpacked_func(
            qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
        )
        output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
        output = pad_input(output_unpad, indices, bsz, q_len)

    return self.o_proj(output), None, past_key_value
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
# [bsz, seq_len]
return attention_mask
def replace_llama_attn_with_flash_attn():
    """Monkey-patch Hugging Face LLaMA to use this module's flash-attention ``forward``.

    Replaces ``LlamaAttention.forward`` and makes
    ``LlamaModel._prepare_decoder_attention_mask`` a pass-through. On GPUs whose
    compute-capability major version is below 8, a warning is emitted first.
    """
    cuda_major, _ = torch.cuda.get_device_capability()
    if cuda_major < 8:
        # The two literals are concatenated; the trailing space keeps the
        # sentence and the reference URL from running together.
        warnings.warn(
            "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward. "
            "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
        )
    transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
        _prepare_decoder_attention_mask
    )
    transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
================================================
FILE: llava/train/llama_xformers_attn_monkey_patch.py
================================================
"""
Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
"""
import logging
import math
from typing import Optional, Tuple
import torch
import transformers.models.llama.modeling_llama
from torch import nn
try:
import xformers.ops
except ImportError:
logging.error("xformers not found! Please install it before trying to use it.")
def replace_llama_attn_with_xformers_attn():
    """Route Hugging Face LLaMA attention through ``xformers_forward``.

    Patches ``LlamaAttention.forward`` at class level, so every instance
    (existing and future) is affected.
    """
    llama_attention_cls = transformers.models.llama.modeling_llama.LlamaAttention
    llama_attention_cls.forward = xformers_forward
def xformers_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Drop-in replacement for ``LlamaAttention.forward`` backed by xformers.

    When ``output_attentions`` is False, attention uses
    ``xformers.ops.memory_efficient_attention`` and no weights are returned.
    Otherwise it falls back to an explicit ``softmax(Q K^T / sqrt(head_dim)) V``.

    Returns:
        ``(attn_output, attn_weights, past_key_value)``. ``attn_weights`` is None
        on the xformers path; ``past_key_value`` is ``(key, value)`` only when
        ``use_cache`` is True.

    Raises:
        ValueError: on the explicit path, if the weights, mask or output have an
            unexpected size.
    """
    # pylint: disable=duplicate-code
    bsz, q_len, _ = hidden_states.size()

    query_states = (
        self.q_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    key_states = (
        self.k_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    value_states = (
        self.v_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    (
        query_states,
        key_states,
    ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
        query_states, key_states, cos, sin, position_ids
    )
    # [bsz, nh, t, hd]

    if past_key_value is not None:
        # reuse k, v, self_attention
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)

    past_key_value = (key_states, value_states) if use_cache else None

    # We only apply xformers optimizations if we don't need to output the whole attention matrix
    if not output_attentions:
        # xformers expects (bsz, seq_len, num_heads, head_dim).
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
        # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
        # A mask with a single key column has no such element to probe (indexing it
        # raised IndexError); with one key position causality is moot, so use no bias.
        # NOTE(review): padding information in the mask is ignored on this path.
        if (
            attention_mask is None
            or attention_mask.shape[-1] < 2
            or attention_mask[0, 0, 0, 1] == 0
        ):
            # input and output should be of form (bsz, q_len, num_heads, head_dim)
            attn_output = xformers.ops.memory_efficient_attention(
                query_states, key_states, value_states, attn_bias=None
            )
        else:
            # input and output should be of form (bsz, q_len, num_heads, head_dim)
            attn_output = xformers.ops.memory_efficient_attention(
                query_states,
                key_states,
                value_states,
                attn_bias=xformers.ops.LowerTriangularMask(),
            )
        attn_weights = None
    else:
        attn_weights = torch.matmul(
            query_states, key_states.transpose(2, 3)
        ) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            # The message now reports the same 4-D shape that is actually checked.
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask
            # Clamp so -inf + mask cannot drop below the dtype's minimum.
            attn_weights = torch.max(
                attn_weights,
                torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device),
            )

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2)

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
    attn_output = self.o_proj(attn_output)
    return attn_output, attn_weights, past_key_value
================================================
FILE: llava/train/llava_trainer.py
================================================
import os
import torch
import torch.nn as nn
from torch.utils.data import Sampler
import transformers
from transformers import Trainer
from transformers.trainer import (
is_sagemaker_mp_enabled,
get_parameter_names,
has_length,
# ALL_LAYERNORM_LAYERS,
logger,
)
from typing import List, Optional
# Module types whose parameters are excluded from weight decay: create_optimizer
# passes this list to `get_parameter_names`, and only the names it returns get decay.
# Defined locally instead of importing transformers' ALL_LAYERNORM_LAYERS (its import
# is commented out above); note that this list also includes BatchNorm2d.
ALL_LAYERNORM_LAYERS = [nn.LayerNorm, nn.BatchNorm2d]
def maybe_zero_3(param, ignore_status=False, name=None):
    """Return a detached, cloned CPU copy of ``param``.

    Parameters managed by DeepSpeed carry a ``ds_id`` attribute. They are
    materialised with ``zero.GatheredParameters`` before copying. Plain tensors
    are copied directly and never touch DeepSpeed.

    Args:
        param: tensor or parameter to copy.
        ignore_status: when False, print a note for DeepSpeed parameters whose
            status is ``ZeroParamStatus.NOT_AVAILABLE`` before gathering them.
        name: label printed with that note.

    Returns:
        A detached clone of the parameter data on CPU.
    """
    if not hasattr(param, "ds_id"):
        return param.detach().cpu().clone()

    # Imported lazily so ordinary tensors work without DeepSpeed installed.
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if param.ds_status == ZeroParamStatus.NOT_AVAILABLE and not ignore_status:
        print(name, 'no ignore status')
    with zero.GatheredParameters([param]):
        # Copy while gathered; the clone stays valid after the context exits.
        return param.data.detach().cpu().clone()
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
    """Collect CPU copies of the parameters whose names contain any of ``keys_to_match``.

    Args:
        named_params: iterable of ``(name, tensor)`` pairs.
        keys_to_match: substrings; a parameter is kept if any occurs in its name.

    Returns:
        Dict mapping parameter name to a detached CPU tensor, produced by
        ``maybe_zero_3`` with ``ignore_status=True``.
    """
    selected = {}
    for param_name, tensor in named_params:
        if any(key in param_name for key in keys_to_match):
            selected[param_name] = tensor
    return {
        param_name: maybe_zero_3(tensor, ignore_status=True, name=param_name).cpu()
        for param_name, tensor in selected.items()
    }
def split_to_even_chunks(indices, lengths, num_chunks):
    """
    Split a list of indices into `num_chunks` chunks of roughly equal total length.

    If ``len(indices)`` is not divisible by ``num_chunks``, indices are dealt
    round-robin. Otherwise each index, in order, goes to the open chunk with the
    smallest accumulated ``lengths``. A chunk that reaches its quota of
    ``len(indices) // num_chunks`` entries is closed by setting its total to
    infinity.
    """
    if len(indices) % num_chunks != 0:
        return [indices[offset::num_chunks] for offset in range(num_chunks)]

    quota = len(indices) // num_chunks
    buckets = [[] for _ in range(num_chunks)]
    totals = [0] * num_chunks
    for idx in indices:
        # min() keeps the first minimum, i.e. the lowest-numbered lightest chunk.
        target = min(range(num_chunks), key=totals.__getitem__)
        buckets[target].append(idx)
        totals[target] += lengths[idx]
        if len(buckets[target]) == quota:
            totals[target] = float("inf")
    return buckets
def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
    """Length-grouped shuffling that keeps multimodal and language-only samples apart.

    The sign of each entry in ``lengths`` encodes modality. Positive entries are
    treated as multimodal (``mm_*``); negative entries are language-only
    (``lang_*``), with the absolute value as their length. Zero is rejected.
    Each modality is length-grouped into megabatches of ``world_size * batch_size``.
    Full megabatches from both modalities are shuffled together. The two
    leftover partial megabatches are merged into one final, sorted megabatch.

    Args:
        generator: optional ``torch.Generator``. It now drives every random step,
            including both per-modality shuffles, so the result is reproducible
            from it. With ``None`` the global torch RNG is used, as before.

    Returns:
        A flat list of dataset indices.
    """
    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
    assert all(l != 0 for l in lengths), "Should not have zero length."
    if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
        # all samples are in the same modality
        # NOTE(review): an all-negative list is passed through with its signs, so
        # "longest first" sorting is inverted for it — confirm this is intended.
        return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
    mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
    lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])

    # Previously these two calls passed generator=None and silently fell back to
    # the global RNG even when the caller supplied a generator.
    mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=generator)]
    lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=generator)]
    megabatch_size = world_size * batch_size
    mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
    lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]

    # The last megabatch of each modality may be partial; they are combined at the end.
    last_mm = mm_megabatches[-1]
    last_lang = lang_megabatches[-1]
    additional_batch = last_mm + last_lang
    megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
    megabatch_indices = torch.randperm(len(megabatches), generator=generator)
    megabatches = [megabatches[i] for i in megabatch_indices]

    if len(additional_batch) > 0:
        megabatches.append(sorted(additional_batch))

    return [i for megabatch in megabatches for i in megabatch]
def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
    """Randomly permute indices, then group them by length within megabatches.

    The permutation is cut into megabatches of ``world_size * batch_size``. Each
    megabatch is sorted longest-first and balanced across ``world_size`` chunks
    with ``split_to_even_chunks``. The result is one flat list of indices.
    ``merge`` is accepted for interface compatibility but unused.
    """
    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
    permutation = torch.randperm(len(lengths), generator=generator)
    megabatch_size = world_size * batch_size
    ordered = []
    for start in range(0, len(lengths), megabatch_size):
        megabatch = permutation[start:start + megabatch_size].tolist()
        megabatch.sort(key=lambda i: lengths[i], reverse=True)
        for chunk in split_to_even_chunks(megabatch, lengths, world_size):
            ordered.extend(chunk)
    return ordered
class LengthGroupedSampler(Sampler):
    r"""
    Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
    keeping a bit of randomness.

    With ``group_by_modality`` the sign of each length selects multimodal
    (positive) or language-only (negative) grouping.
    """

    def __init__(
        self,
        batch_size: int,
        world_size: int,
        lengths: Optional[List[int]] = None,
        generator=None,
        group_by_modality: bool = False,
    ):
        if lengths is None:
            raise ValueError("Lengths must be provided.")

        self.batch_size = batch_size
        self.world_size = world_size
        self.lengths = lengths
        self.generator = generator
        self.group_by_modality = group_by_modality

    def __len__(self):
        """Number of samples, i.e. the number of provided lengths."""
        return len(self.lengths)

    def __iter__(self):
        """Yield dataset indices in length-grouped (optionally modality-aware) order."""
        grouper = (
            get_modality_length_grouped_indices
            if self.group_by_modality
            else get_length_grouped_indices
        )
        ordered = grouper(self.lengths, self.batch_size, self.world_size, generator=self.generator)
        return iter(ordered)
class LLaVATrainer(Trainer):
    """Trainer subclass with LLaVA-specific sampling, optimizer grouping and checkpointing.

    Overrides the base ``Trainer`` to:
      * optionally use a modality/length-grouped training sampler
        (``args.group_by_modality_length``);
      * give ``mm_projector`` / ``vision_tower`` parameters their own learning
        rates (``args.mm_projector_lr`` / ``args.mm_vision_tower_lr``);
      * save only the multimodal-adapter weights when ``args.tune_mm_mlp_adapter`` is set.
    """

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        """Return the training sampler.

        Returns None when there is no train dataset or it has no length. When
        ``args.group_by_modality_length`` is set, returns a LengthGroupedSampler
        over ``train_dataset.modality_lengths`` with ``group_by_modality=True``;
        otherwise defers to the base Trainer.
        """
        if self.train_dataset is None or not has_length(self.train_dataset):
            return None
        if self.args.group_by_modality_length:
            lengths = self.train_dataset.modality_lengths
            return LengthGroupedSampler(
                self.args.train_batch_size,
                # The sampler's "world_size" is scaled by gradient accumulation, so a
                # megabatch presumably covers one optimizer step across all ranks.
                world_size=self.args.world_size * self.args.gradient_accumulation_steps,
                lengths=lengths,
                group_by_modality=True,
            )
        else:
            return super()._get_train_sampler()

    def create_optimizer(self):
        """
        Setup the optimizer.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method in a subclass.

        Trainable parameters (``requires_grad``) are split into a weight-decay group
        and a no-weight-decay group. Names containing "bias" get no decay, and so does
        anything ``get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)`` does not
        return (presumably layer-norm parameters — TODO confirm against transformers).
        If ``args.mm_projector_lr`` and/or ``args.mm_vision_tower_lr`` is set,
        parameters whose names contain "mm_projector" / "vision_tower" are moved into
        extra decay / no-decay groups carrying that learning rate.

        Returns:
            ``self.optimizer`` (only created if it was None).
        """
        if is_sagemaker_mp_enabled():
            return super().create_optimizer()
        opt_model = self.model
        if self.optimizer is None:
            decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
            decay_parameters = [name for name in decay_parameters if "bias" not in name]
            # Maps a parameter-name keyword to the learning rate for parameters matching it.
            lr_mapper = {}
            if self.args.mm_projector_lr is not None:
                lr_mapper["mm_projector"] = self.args.mm_projector_lr
            if self.args.mm_vision_tower_lr is not None:
                lr_mapper["vision_tower"] = self.args.mm_vision_tower_lr
            if len(lr_mapper) > 0:
                # Parameters that get a dedicated learning rate are excluded from the
                # default groups and added to per-keyword groups below.
                # NOTE(review): a name matching several keywords would land in more than
                # one group, which torch optimizers reject; also, membership tests
                # against these lists are O(n) each — confirm acceptable for large models.
                special_lr_parameters = [name for name, _ in opt_model.named_parameters() if
                                         any(module_keyword in name for module_keyword in lr_mapper)]
                optimizer_grouped_parameters = [
                    {
                        "params": [p for n, p in opt_model.named_parameters() if
                                   (n in decay_parameters and n not in special_lr_parameters and p.requires_grad)],
                        "weight_decay": self.args.weight_decay,
                    },
                    {
                        "params": [p for n, p in opt_model.named_parameters() if
                                   (n not in decay_parameters and n not in special_lr_parameters and p.requires_grad)],
                        "weight_decay": 0.0,
                    },
                ]
                for module_keyword, lr in lr_mapper.items():
                    module_parameters = [name for name, _ in opt_model.named_parameters() if module_keyword in name]
                    # One decay and one no-decay group per keyword, both at the keyword's lr.
                    optimizer_grouped_parameters.extend(
                        [
                            {
                                "params": [p for n, p in opt_model.named_parameters() if
                                           (n in decay_parameters and n in module_parameters and p.requires_grad)],
                                "weight_decay": self.args.weight_decay,
                                "lr": lr,
                            },
                            {
                                "params": [p for n, p in opt_model.named_parameters() if
                                           (n not in decay_parameters and n in module_parameters and p.requires_grad)],
                                "weight_decay": 0.0,
                                "lr": lr,
                            },
                        ]
                    )
            else:
                optimizer_grouped_parameters = [
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
                        ],
                        "weight_decay": self.args.weight_decay,
                    },
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
                        ],
                        "weight_decay": 0.0,
                    },
                ]
            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
            if optimizer_cls.__name__ == "Adam8bit":
                # With bitsandbytes' 8-bit Adam, keep nn.Embedding weights on a 32-bit
                # optimizer state via a per-module override.
                import bitsandbytes
                manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
                skipped = 0
                for module in opt_model.modules():
                    if isinstance(module, nn.Embedding):
                        # Dedupe by data_ptr so shared storage is counted once; running total logged.
                        skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                        logger.info(f"skipped {module}: {skipped/2**20}M params")
                        manager.register_module_override(module, "weight", {"optim_bits": 32})
                        logger.debug(f"bitsandbytes: will optimize {module} in fp32")
                logger.info(f"skipped: {skipped/2**20}M params")
        return self.optimizer

    def _save_checkpoint(self, model, trial, metrics=None):
        """Save a checkpoint.

        With ``args.tune_mm_mlp_adapter``, only parameters whose names contain
        "mm_projector" or "vision_resampler" (plus embeddings when
        ``args.use_im_start_end``) are written to ``<run_dir>/checkpoint-<step>/mm_projector.bin``
        together with the model config, on the main process only. Otherwise the
        base Trainer checkpoint is saved.
        """
        if getattr(self.args, 'tune_mm_mlp_adapter', False):
            from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
            checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
            run_dir = self._get_output_dir(trial=trial)
            output_dir = os.path.join(run_dir, checkpoint_folder)
            # Only save Adapter
            keys_to_match = ['mm_projector', 'vision_resampler']
            # NOTE(review): TrainingArguments in train.py declares no `use_im_start_end`
            # field, so unless it is attached elsewhere this branch is never taken — confirm.
            if getattr(self.args, "use_im_start_end", False):
                keys_to_match.extend(['embed_tokens', 'embed_in'])
            # Gathered on every rank (ZeRO-3 gathering is collective); written only on the main process.
            weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)
            if self.args.local_rank == 0 or self.args.local_rank == -1:
                self.model.config.save_pretrained(output_dir)
                torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
        else:
            # Workaround for the issue: https://github.com/haotian-liu/LLaVA/issues/1144
            model.generation_config = transformers.GenerationConfig(do_sample=True, temperature=None, top_p=None)
            super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics)

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Save the model, unless only the multimodal adapter is being tuned.

        With ``args.tune_mm_mlp_adapter`` this is a no-op; otherwise the base
        Trainer save runs after resetting ``generation_config``.
        """
        if getattr(self.args, 'tune_mm_mlp_adapter', False):
            pass
        else:
            # Workaround for the issue: https://github.com/haotian-liu/LLaVA/issues/1144
            self.model.generation_config = transformers.GenerationConfig(do_sample=True, temperature=None, top_p=None)
            super(LLaVATrainer, self)._save(output_dir, state_dict)
================================================
FILE: llava/train/train.py
================================================
# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from packaging import version
import os
import copy
from dataclasses import dataclass, field
import json
import logging
import pathlib
from typing import Dict, Optional, Sequence, List
import torch
import transformers
import tokenizers
from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from torch.utils.data import Dataset
from llava.train.llava_trainer import LLaVATrainer
from llava import conversation as conversation_lib
from llava.model import *
from llava.mm_utils import tokenizer_image_token, process_anyres_image
from PIL import Image
# Local rank of this process. train() assigns it from TrainingArguments.local_rank;
# it stays None until then.
local_rank = None


def rank0_print(*args):
    """Print ``args`` only from the main process.

    The main process is local rank 0, or local rank -1 (which conventionally
    denotes a non-distributed run). This matches the "local_rank == 0 or
    local_rank == -1" checks used by the checkpoint-saving code in this module.
    Previously only rank 0 printed, so single-process runs were silent.
    Nothing is printed while ``local_rank`` is still None.
    """
    if local_rank in (0, -1):
        print(*args)
# True when the installed `tokenizers` package is version 0.14 *or newer* (the
# comparison is >=, despite the name). The preprocess_* functions use it to adjust
# per-round token counts for non-legacy tokenizers.
IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14')
@dataclass
class ModelArguments:
    """Model-related arguments, parsed by ``transformers.HfArgumentParser`` in ``train()``."""

    model_name_or_path: Optional[str] = "facebook/opt-125m"
    version: Optional[str] = "v0"
    freeze_backbone: bool = False
    tune_mm_mlp_adapter: bool = False
    tune_mm_mlp_and_vision_tower: bool = False
    vision_tower: Optional[str] = None
    mm_vision_select_layer: Optional[int] = -1  # default to the last layer
    pretrain_mm_mlp_adapter: Optional[str] = None
    mm_projector_type: Optional[str] = 'linear'
    mm_use_im_start_end: bool = False
    mm_use_im_patch_token: bool = True
    mm_patch_merge_type: Optional[str] = 'flat'
    mm_vision_select_feature: Optional[str] = "patch"
    unfreeze_mm_vision_tower: bool = False
    s2: Optional[bool] = False
    hd: Optional[bool] = False
@dataclass
class DataArguments:
    """Data-related arguments, parsed by ``transformers.HfArgumentParser`` in ``train()``."""

    data_path: Optional[List[str]] = field(
        default=None,
        metadata={"help": "Optional list of paths to the training data."},
    )
    lazy_preprocess: bool = False
    is_multimodal: bool = False
    # One image folder per entry of data_path, matched by position (samples from
    # the k-th data file resolve their images against image_folder[k]).
    image_folder: Optional[List[str]] = None
    image_aspect_ratio: str = 'square'
    image_grid_pinpoints: Optional[str] = None
    image_crop_resolution: Optional[int] = None
    image_split_resolution: Optional[int] = None
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """``transformers.TrainingArguments`` extended with LLaVA-specific options."""

    # Passed as cache_dir to from_pretrained() in train().
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    # NOTE(review): presumably False so the Trainer keeps custom collator keys
    # (images, image_sizes) that the model's forward does not name — confirm.
    remove_unused_columns: bool = field(default=False)
    freeze_mm_mlp_adapter: bool = field(default=False)
    # Written into config.attn_config['attn_impl'] for MPT models in train().
    mpt_attn_impl: Optional[str] = field(default="triton")
    model_max_length: int = field(
        default=512,
        metadata={
            "help":
                "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    # Quantization options; train() only builds a BitsAndBytesConfig when bits is 4 or 8.
    double_quant: bool = field(
        default=True,
        metadata={"help": "Compress the quantization statistics through double quantization."}
    )
    quant_type: str = field(
        default="nf4",
        metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
    )
    bits: int = field(
        default=16,
        metadata={"help": "How many bits to use."}
    )
    # LoRA options. NOTE(review): lora_bias is presumably forwarded to
    # get_peft_state_maybe_zero_3, which accepts "none", "all" and "lora_only".
    lora_enable: bool = False
    lora_r: int = 64
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_weight_path: str = ""
    lora_bias: str = "none"
    # Optional per-module learning rates used by LLaVATrainer.create_optimizer for
    # parameters whose names contain "mm_projector" / "vision_tower"; None keeps
    # those parameters in the default groups.
    mm_projector_lr: Optional[float] = None
    mm_vision_tower_lr: Optional[float] = None
    # When True, LLaVATrainer samples with a modality/length-grouped LengthGroupedSampler.
    group_by_modality_length: bool = field(default=False)
def maybe_zero_3(param, ignore_status=False, name=None):
    """Return a detached CPU copy of ``param``, gathering it first under ZeRO-3.

    DeepSpeed ZeRO-3 partitioned parameters carry a ``ds_id`` attribute. Such a
    parameter is gathered with ``deepspeed.zero.GatheredParameters`` before it is
    copied. Any other tensor is simply detached, moved to CPU and cloned.
    DeepSpeed is imported only in the partitioned case, so plain tensors work
    without DeepSpeed installed.

    Args:
        param: a tensor or parameter, possibly ZeRO-3 partitioned.
        ignore_status: if False, log a warning when a partitioned parameter's
            status is ``ZeroParamStatus.NOT_AVAILABLE``.
        name: label used in that warning.

    Returns:
        A detached, cloned tensor on CPU.
    """
    if not hasattr(param, "ds_id"):
        return param.detach().cpu().clone()

    # Lazy import: only ZeRO-3-partitioned parameters need DeepSpeed.
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if param.ds_status == ZeroParamStatus.NOT_AVAILABLE and not ignore_status:
        logging.warning(f"{name}: param.ds_status == ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
    with zero.GatheredParameters([param]):
        # Copy while the full parameter is materialized inside the context.
        return param.data.detach().cpu().clone()
# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    """Collect the LoRA state (optionally with biases) as CPU tensors.

    Args:
        named_params: iterable of ``(name, tensor)`` pairs, e.g. ``model.named_parameters()``.
        bias: which bias tensors to include alongside every tensor whose name
            contains ``"lora_"``:
              * ``"none"``: no biases;
              * ``"all"``: every tensor whose name contains ``"bias"``;
              * ``"lora_only"``: only the biases of modules that own LoRA weights.

    Returns:
        dict of name -> tensor copied via ``maybe_zero_3``.

    Raises:
        NotImplementedError: for any other ``bias`` value.
    """
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                # e.g. "layer.q_proj.lora_A.weight" -> "layer.q_proj.bias"
                lora_bias_names.add(k.split("lora_")[0] + "bias")
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # Keep a bias only if it belongs to a LoRA-adapted module. Iterate the
        # dict's items (not its keys) and test each bias by its own name.
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError(f"Unsupported bias mode: {bias!r}")
    return {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
    """Collect every non-LoRA tensor as a detached CPU copy.

    Args:
        named_params: iterable of ``(name, tensor)`` pairs.
        require_grad_only: if True, keep only tensors with ``requires_grad``.

    Returns:
        dict of name -> CPU tensor (copied via ``maybe_zero_3``).
    """
    selected = {}
    for param_name, tensor in named_params:
        if "lora_" in param_name:
            continue
        selected[param_name] = tensor
    if require_grad_only:
        selected = {
            param_name: tensor
            for param_name, tensor in selected.items()
            if tensor.requires_grad
        }
    return {
        param_name: maybe_zero_3(tensor, ignore_status=True).cpu()
        for param_name, tensor in selected.items()
    }
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
    """Return CPU copies of the parameters whose name contains any of ``keys_to_match``.

    Args:
        named_params: iterable of ``(name, tensor)`` pairs.
        keys_to_match: substrings; a parameter is kept if its name contains at least one.

    Returns:
        dict of name -> CPU tensor (copied via ``maybe_zero_3``).
    """
    def _matches(param_name):
        return any(key in param_name for key in keys_to_match)

    matched = {param_name: tensor for param_name, tensor in named_params if _matches(param_name)}
    return {
        param_name: maybe_zero_3(tensor, ignore_status=True).cpu()
        for param_name, tensor in matched.items()
    }
def find_all_linear_names(model):
    """Return the short names of ``torch.nn.Linear`` modules eligible as LoRA targets.

    Modules whose path contains ``mm_projector``, ``vision_tower`` or
    ``vision_resampler`` are skipped, and ``lm_head`` is dropped from the result.
    Each name is the last dotted component of the module path
    (``"layers.0.self_attn.q_proj"`` -> ``"q_proj"``).

    Returns:
        list of unique names; order is not meaningful.
    """
    skip_keywords = ('mm_projector', 'vision_tower', 'vision_resampler')
    linear_names = set()
    for module_path, module in model.named_modules():
        if any(keyword in module_path for keyword in skip_keywords):
            continue
        if isinstance(module, torch.nn.Linear):
            linear_names.add(module_path.rsplit('.', 1)[-1])
    # Excluded from LoRA targets (original note: "needed for 16-bit").
    linear_names.discard('lm_head')
    return list(linear_names)
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
                                   output_dir: str):
    """Collect the state dict and dump it to disk.

    Three paths:
      * ``tune_mm_mlp_adapter``: save the model config plus only the parameters whose
        names contain "mm_projector" (and embeddings when ``use_im_start_end``).
        If ``output_dir``'s last path component starts with "checkpoint-", the
        weights go to ``<parent>/mm_projector/<checkpoint-name>.bin``; otherwise to
        ``<output_dir>/mm_projector.bin``. Weights are written on the main process only.
      * DeepSpeed: synchronize CUDA and delegate to ``trainer.save_model``.
      * Otherwise: copy the state dict to CPU and call ``trainer._save`` when
        ``trainer.args.should_save``.
    """
    if getattr(trainer.args, "tune_mm_mlp_adapter", False):
        # Only save Adapter
        keys_to_match = ['mm_projector']
        # NOTE(review): TrainingArguments in this module declares no `use_im_start_end`
        # field, so unless it is attached elsewhere this branch is never taken — confirm.
        if getattr(trainer.args, "use_im_start_end", False):
            keys_to_match.extend(['embed_tokens', 'embed_in', 'tok_embeddings'])
        # Gathered on every rank before the rank check below.
        weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
        # NOTE(review): the config is saved on every rank, not only the main one.
        trainer.model.config.save_pretrained(output_dir)
        # NOTE(review): splitting on '/' assumes POSIX-style paths.
        current_folder = output_dir.split('/')[-1]
        parent_folder = os.path.dirname(output_dir)
        if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
            if current_folder.startswith('checkpoint-'):
                mm_projector_folder = os.path.join(parent_folder, "mm_projector")
                os.makedirs(mm_projector_folder, exist_ok=True)
                torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
            else:
                torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
        return
    if trainer.deepspeed:
        torch.cuda.synchronize()
        trainer.save_model(output_dir)
        return
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {
            key: value.cpu()
            for key, value in state_dict.items()
        }
        # Drop the reference to the original state dict before saving the CPU copy.
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa
def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Resize tokenizer and embedding.

    Adds ``special_tokens_dict`` to the tokenizer and resizes the model's token
    embeddings to ``len(tokenizer)``. The rows for newly added tokens, in both the
    input and the output embedding matrices, are set to the mean of the
    pre-existing rows of that matrix.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    added = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))
    if added <= 0:
        return

    matrices = [
        model.get_input_embeddings().weight.data,
        model.get_output_embeddings().weight.data,
    ]
    # Compute every mean before writing any new rows.
    means = [matrix[:-added].mean(dim=0, keepdim=True) for matrix in matrices]
    for matrix, mean in zip(matrices, means):
        matrix[-added:] = mean
def _tokenize_fn(strings: Sequence[str],
                 tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize each string separately, truncating to ``tokenizer.model_max_length``.

    Returns:
        dict with:
          * ``input_ids``: one 1-D tensor per string;
          * ``labels``: the very same list object as ``input_ids``;
          * ``input_ids_lens`` / ``labels_lens``: the same list of non-pad token counts.
    """
    encoded = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]
    ids = [enc.input_ids[0] for enc in encoded]
    counts = [enc.input_ids.ne(tokenizer.pad_token_id).sum().item() for enc in encoded]
    return dict(
        input_ids=ids,
        labels=ids,
        input_ids_lens=counts,
        labels_lens=counts,
    )
def _mask_targets(target, tokenized_lens, speakers):
    """Mask ``target`` in place with IGNORE_INDEX for the header and human turns.

    ``tokenized_lens[0]`` is the header length; the remaining entries are turn
    lengths aligned with ``speakers``. For each "human" turn, positions from
    turn start + 2 up to its end are masked. NOTE(review): the first two tokens of
    a human turn stay unmasked — presumably intentional for the "### " signal;
    confirm.
    """
    header_len = tokenized_lens[0]
    turn_lens = tokenized_lens[1:]
    target[:header_len] = IGNORE_INDEX

    offset = header_len
    for turn_len, speaker in zip(turn_lens, speakers):
        if speaker == "human":
            target[offset + 2:offset + turn_len] = IGNORE_INDEX
        offset += turn_len
def _add_speaker_and_signal(header, source, get_conversation=True):
    """Add speaker and start/end signal on each round.

    Each sentence's ``"value"`` is rewritten in place to
    ``"### <role>: <text>\\n"``. ``<role>`` is taken from the default
    conversation's roles for "human"/"gpt" (case-insensitive) and is
    ``'unknown'`` for any other speaker.

    Returns:
        ``header``, followed by the rewritten sentences when ``get_conversation``
        is true, followed by a trailing ``"### "``.
    """
    begin_signal = "### "
    end_signal = "\n"
    role_slot = {"human": 0, "gpt": 1}

    pieces = [header]
    for sentence in source:
        slot = role_slot.get(sentence["from"].lower())
        if slot is None:
            role = 'unknown'
        else:
            role = conversation_lib.default_conversation.roles[slot]
        sentence["value"] = begin_signal + role + ": " + sentence["value"] + end_signal
        if get_conversation:
            pieces.append(sentence["value"])
    pieces.append(begin_signal)
    return "".join(pieces)
def preprocess_multimodal(
    sources: Sequence[str],
    data_args: DataArguments
) -> Dict:
    """Normalize image placeholders in multimodal conversations, in place.

    For each sentence containing DEFAULT_IMAGE_TOKEN, the token is moved to the
    start of the text, followed by a newline. For "mmtag" conversation versions
    the token is additionally wrapped in ``<Image>...</Image>``. For every
    sentence, when ``data_args.mm_use_im_start_end`` is set, the token is
    surrounded by DEFAULT_IM_START_TOKEN / DEFAULT_IM_END_TOKEN.

    Args:
        sources: list of conversations, each a list of ``{"from", "value"}`` dicts.
        data_args: provides ``is_multimodal`` and ``mm_use_im_start_end``. The
            latter is not a declared DataArguments field and is presumably
            attached by the caller.

    Returns:
        ``sources`` itself (mutated in place). When ``data_args.is_multimodal``
        is false, ``sources`` is returned untouched.
    """
    if not data_args.is_multimodal:
        return sources
    for source in sources:
        for sentence in source:
            if DEFAULT_IMAGE_TOKEN in sentence['value']:
                sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
                sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value']
                sentence['value'] = sentence['value'].strip()
                if "mmtag" in conversation_lib.default_conversation.version:
                    # mmtag templates expect the image wrapped in explicit tags. The
                    # previous replacement used '' on both sides and was a no-op.
                    sentence['value'] = sentence['value'].replace(
                        DEFAULT_IMAGE_TOKEN, '<Image>' + DEFAULT_IMAGE_TOKEN + '</Image>')
            replace_token = DEFAULT_IMAGE_TOKEN
            if data_args.mm_use_im_start_end:
                replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
            sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
    return sources
def preprocess_llama_2(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    """Build input ids and masked labels for the LLaMA-2 chat template.

    Each source is a list of ``{"from": "human"|"gpt", "value": str}`` turns,
    rendered with a copy of ``conversation_lib.default_conversation``. A leading
    non-human turn is dropped, and the remaining turns must alternate human/gpt
    (asserted).

    Without an image, all prompts are tokenized together (padding="longest",
    truncation at ``model_max_length``). With an image, the per-prompt ids from
    ``tokenizer_image_token`` are stacked, so they must have equal length.

    In ``labels``, the first token, the instruction part of every round and
    everything after the last round are set to IGNORE_INDEX. If the recomputed
    length differs from the non-pad token count (and is below
    ``model_max_length``), the whole sample is masked and a warning is printed.

    Returns:
        dict with ``input_ids`` and ``labels`` (same shape).
    """
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]
        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())
    # Tokenize conversations
    if has_image:
        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids
    targets = input_ids.clone()
    assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2
    # Mask targets
    # Separator between the user instruction and the assistant reply in a round.
    sep = "[/INST] "
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())
        rounds = conversation.split(conv.sep2)
        # Start after the first token (presumably BOS) and mask it.
        cur_len = 1
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break
            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            # NOTE(review): the "- 2" presumably removes BOS plus one token lost at the
            # instruction/reply boundary when tokenizing the prefix alone — confirm.
            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2
            target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX
        # A length mismatch means the masking above is misaligned; drop the sample's loss.
        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )
    return dict(
        input_ids=input_ids,
        labels=targets,
    )
# fix: add qwen2
def preprocess_qwen_2(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    """Build input ids and masked labels for the Qwen2 (SeparatorStyle.QWEN_2) template.

    Prompt rendering and tokenization follow preprocess_v1. For masking, each
    round's instruction length is the length of the common prefix between the
    round's token ids and the instruction's token ids. That instruction prefix is
    masked with IGNORE_INDEX, as is everything after the last round. A length
    mismatch (checked against ``total_len + rounds_len - 2``) masks the whole
    sample and prints a warning.

    Returns:
        dict with ``input_ids`` and ``labels`` (same shape).
    """
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]
        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())
    # Tokenize conversations
    if has_image:
        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids
    targets = input_ids.clone()
    assert conv.sep_style == conversation_lib.SeparatorStyle.QWEN_2
    # Mask targets
    # Marks the start of the assistant reply within a round.
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())
        rounds = conversation.split(conv.sep2)
        rounds_len = len(rounds)
        # Unlike preprocess_v1, no leading token is skipped.
        cur_len = 0
        # target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break
            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            # instruction_len = number of leading ids shared by the full round and the
            # instruction alone (robust to merges at the instruction/reply boundary).
            if has_image:
                round_ids = tokenizer_image_token(rou, tokenizer)
                instruction_ids = tokenizer_image_token(parts[0], tokenizer)
                equal_parts = [x == y for x, y in zip(round_ids, instruction_ids)]
                instruction_len = equal_parts.index(False) if False in equal_parts else len(equal_parts)
                round_len = len(round_ids)
            else:
                round_ids = tokenizer(rou).input_ids
                instruction_ids = tokenizer(parts[0]).input_ids
                equal_parts = [x == y for x, y in zip(round_ids, instruction_ids)]
                instruction_len = equal_parts.index(False) if False in equal_parts else len(equal_parts)
                round_len = len(round_ids)
            # NOTE(review): rounds after the first get one extra token for non-legacy
            # tokenizers (tokenizers >= 0.14) — presumably the separator dropped by
            # split(conv.sep2); confirm against the Qwen2 template.
            if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14:
                round_len += 1
                instruction_len += 1
            target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX
        # NOTE(review): the "+ rounds_len - 2" correction is template-specific and
        # unexplained here; confirm it matches the Qwen2 separator layout.
        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len + rounds_len - 2:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )
    return dict(
        input_ids=input_ids,
        labels=targets,
    )
def preprocess_v1(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    """Build input ids and masked labels for "v1"-style templates (SeparatorStyle.TWO).

    Prompts are rendered as in preprocess_llama_2, and rounds are split on
    ``conv.sep2``. The first token, each round's instruction (up to and including
    ``"<sep><assistant role>: "``) and everything after the last round are masked
    with IGNORE_INDEX. A length mismatch masks the whole sample and prints a warning.

    Returns:
        dict with ``input_ids`` and ``labels`` (same shape).
    """
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]
        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())
    # Tokenize conversations
    if has_image:
        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids
    targets = input_ids.clone()
    assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
    # Mask targets
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())
        rounds = conversation.split(conv.sep2)
        # Skip and mask the first token (presumably BOS).
        cur_len = 1
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break
            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            # NOTE(review): "- 2" presumably accounts for BOS plus a boundary token — confirm.
            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2
            # Non-legacy tokenizers (tokenizers >= 0.14) yield one token fewer per later round.
            if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14:
                round_len -= 1
                instruction_len -= 1
            target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX
        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )
    return dict(
        input_ids=input_ids,
        labels=targets,
    )
def preprocess_mpt(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    """Build input ids and masked labels for the MPT template (SeparatorStyle.MPT).

    The rendered prompt is split on ``conv.sep`` and regrouped into rounds. The
    first round is system + user + assistant; each later round is user + assistant.
    Each round's instruction and everything after the last round are masked with
    IGNORE_INDEX. A length mismatch masks the whole sample and prints a warning.

    Returns:
        dict with ``input_ids`` and ``labels`` (same shape).
    """
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]
        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())
    # Tokenize conversations
    if has_image:
        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids
    targets = input_ids.clone()
    assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
    # Mask targets
    sep = conv.sep + conv.roles[1]
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())
        rounds = conversation.split(conv.sep)
        re_rounds = [conv.sep.join(rounds[:3])]  # system + user + gpt
        for conv_idx in range(3, len(rounds), 2):
            re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2]))  # user + gpt
        # No leading token is skipped for MPT (the slice below masks nothing).
        cur_len = 0
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(re_rounds):
            if rou == "":
                break
            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            # NOTE(review): "- 1" presumably drops one boundary token from the prefix count — confirm.
            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 1
            # NOTE(review): this adjusts when the tokenizer IS legacy (and adds), the
            # opposite of preprocess_v1 (not legacy, subtracts) — confirm intended.
            if i != 0 and getattr(tokenizer, 'legacy', False) and IS_TOKENIZER_GREATER_THAN_0_14:
                round_len += 1
                instruction_len += 1
            target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX
        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )
    return dict(
        input_ids=input_ids,
        labels=targets,
    )
def preprocess_plain(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Build ``<image>`` + response samples for the PLAIN separator style.

    Each source must have exactly two turns, and the first turn must contain
    DEFAULT_IMAGE_TOKEN (both asserted). The first turn's text is replaced in
    place by the bare DEFAULT_IMAGE_TOKEN. The prompt is that token, then the
    second turn's text, then the default conversation's ``sep``. In ``labels``,
    every token of the first turn is masked with IGNORE_INDEX.

    Returns:
        dict with ``input_ids`` and ``labels``, each a list of 1-D tensors.
    """
    prompts = []
    for turns in sources:
        assert len(turns) == 2
        assert DEFAULT_IMAGE_TOKEN in turns[0]['value']
        turns[0]['value'] = DEFAULT_IMAGE_TOKEN
        prompts.append(turns[0]['value'] + turns[1]['value'] + conversation_lib.default_conversation.sep)

    input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in prompts]
    labels = copy.deepcopy(input_ids)
    for label, turns in zip(labels, sources):
        prefix_len = len(tokenizer_image_token(turns[0]['value'], tokenizer))
        label[:prefix_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)
def preprocess(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    """
    Given a list of sources, each is a conversation list. This transform:
    1. Add signal '### ' at the beginning each sentence, with end signal '\n';
    2. Concatenate conversations together;
    3. Tokenize the concatenated conversation;
    4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.

    Template-specific paths are tried first, in this order: PLAIN separator,
    LLAMA_2 separator, version "v1*", version "mpt", version "qwen_v2*". Only if
    none applies is the generic transform above used.
    """
    default_conv = conversation_lib.default_conversation
    if default_conv.sep_style == conversation_lib.SeparatorStyle.PLAIN:
        return preprocess_plain(sources, tokenizer)
    if default_conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2:
        return preprocess_llama_2(sources, tokenizer, has_image=has_image)
    template_version = default_conv.version
    if template_version.startswith("v1"):
        return preprocess_v1(sources, tokenizer, has_image=has_image)
    if template_version == "mpt":
        return preprocess_mpt(sources, tokenizer, has_image=has_image)
    # fix: add qwen2
    if template_version.startswith("qwen_v2"):
        return preprocess_qwen_2(sources, tokenizer, has_image=has_image)

    # Generic path: add speaker signals and concatenate.
    header = f"{default_conv.system}\n\n"
    conversations = [_add_speaker_and_signal(header, source) for source in sources]

    if has_image:
        input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
    else:
        input_ids = _tokenize_fn(conversations, tokenizer)["input_ids"]

    targets = copy.deepcopy(input_ids)
    for target, source in zip(targets, sources):
        pieces = [header] + [turn["value"] for turn in source]
        if has_image:
            tokenized_lens = [len(tokenizer_image_token(piece, tokenizer)) for piece in pieces]
        else:
            tokenized_lens = _tokenize_fn(pieces, tokenizer)["input_ids_lens"]
        _mask_targets(target, tokenized_lens, [turn["from"] for turn in source])
    return dict(input_ids=input_ids, labels=targets)
class LazySupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning.

    All samples are read from the JSON files up front. Image loading and
    tokenization happen lazily, in ``__getitem__``.
    """

    def __init__(self, data_path: List[str],
                 tokenizer: transformers.PreTrainedTokenizer,
                 data_args: DataArguments):
        """
        Args:
            data_path: JSON files, each holding a list of sample dicts. Every sample
                is tagged with ``img_path_idx`` = the index of its file, so its image
                is resolved against ``data_args.image_folder[img_path_idx]``.
            tokenizer: tokenizer passed to ``preprocess``.
            data_args: data configuration (image folders, image processor, aspect-ratio mode, ...).
        """
        super(LazySupervisedDataset, self).__init__()
        list_data_dict = []
        for i, _data_path in enumerate(data_path):
            # Close each file promptly, and decode as UTF-8 (the JSON standard
            # encoding) regardless of the platform's locale.
            with open(_data_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            list_data_dict += [{**entry, 'img_path_idx': i} for entry in data]
        rank0_print("Formatting inputs...Skip in lazy mode")
        self.tokenizer = tokenizer
        self.list_data_dict = list_data_dict
        self.data_args = data_args

    def __len__(self):
        return len(self.list_data_dict)

    @property
    def lengths(self):
        """Approximate per-sample length.

        Computed as the whitespace-split word count over all turns, plus a flat 128
        for samples with an image. NOTE(review): 128 is a rough stand-in; the real
        image-token count depends on the vision tower.
        """
        length_list = []
        for sample in self.list_data_dict:
            img_tokens = 128 if 'image' in sample else 0
            length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens)
        return length_list

    @property
    def modality_lengths(self):
        """Per-sample word count: positive for samples with an image, negative for text-only ones."""
        length_list = []
        for sample in self.list_data_dict:
            cur_len = sum(len(conv['value'].split()) for conv in sample['conversations'])
            cur_len = cur_len if 'image' in sample else -cur_len
            length_list.append(cur_len)
        return length_list

    @staticmethod
    def _expand2square(pil_img, background_color):
        """Pad ``pil_img`` to a square canvas of ``background_color``, centering the original."""
        width, height = pil_img.size
        if width == height:
            return pil_img
        elif width > height:
            result = Image.new(pil_img.mode, (width, width), background_color)
            result.paste(pil_img, (0, (width - height) // 2))
            return result
        else:
            result = Image.new(pil_img.mode, (height, height), background_color)
            result.paste(pil_img, ((height - width) // 2, 0))
            return result

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Load, preprocess and tokenize sample ``i``.

        Returns a dict with ``input_ids`` and ``labels``. When the sample has an image,
        the dict also holds the processed ``image`` and its original ``image_size``
        (PIL ``(width, height)``). When the sample has no image but the model is
        multimodal, a zero image of the processor's crop size is substituted.
        """
        sources = self.list_data_dict[i]
        if isinstance(i, int):
            sources = [sources]
        assert len(sources) == 1, "Don't know why it is wrapped to a list"  # FIXME
        if 'image' in sources[0]:
            image_file = self.list_data_dict[i]['image']
            img_path_idx = self.list_data_dict[i]['img_path_idx']
            image_folder = self.data_args.image_folder[img_path_idx]
            processor = self.data_args.image_processor
            # convert() returns a new, fully loaded image, so the file can be closed right away.
            with Image.open(os.path.join(image_folder, image_file)) as raw_image:
                image = raw_image.convert('RGB')
            image_size = image.size  # (width, height)
            if self.data_args.image_aspect_ratio == 'pad':
                # Pad with the processor's mean color (scaled to 0-255) before preprocessing.
                image = self._expand2square(image, tuple(int(x*255) for x in processor.image_mean))
                image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            elif self.data_args.image_aspect_ratio == "anyres" or "anyres_max" in self.data_args.image_aspect_ratio:
                image = process_anyres_image(image, self.data_args.image_processor, self.data_args.image_grid_pinpoints)
            else:
                image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            sources = preprocess_multimodal(
                copy.deepcopy([e["conversations"] for e in sources]),
                self.data_args)
        else:
            sources = copy.deepcopy([e["conversations"] for e in sources])
        data_dict = preprocess(
            sources,
            self.tokenizer,
            has_image=('image' in self.list_data_dict[i]))
        if isinstance(i, int):
            data_dict = dict(input_ids=data_dict["input_ids"][0],
                             labels=data_dict["labels"][0])
        # image exist in the data
        if 'image' in self.list_data_dict[i]:
            data_dict['image'] = image
            data_dict['image_size'] = image_size
        elif self.data_args.is_multimodal:
            # image does not exist in the data, but the model is multimodal
            crop_size = self.data_args.image_processor.crop_size
            data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
            # Same (width, height) order as PIL's image.size used for real images above.
            data_dict['image_size'] = (crop_size['width'], crop_size['height'])
        return data_dict
@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning.

    ``input_ids`` are right-padded with the tokenizer's pad id and ``labels``
    with IGNORE_INDEX. Both are truncated to ``tokenizer.model_max_length``.
    ``attention_mask`` marks the non-pad positions. When the first instance has
    an ``image``, images are stacked if every image is non-None and has the same
    shape, and are otherwise passed through as a list.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        id_seqs = [example["input_ids"] for example in instances]
        label_seqs = [example["labels"] for example in instances]
        pad_id = self.tokenizer.pad_token_id
        max_len = self.tokenizer.model_max_length

        padded_ids = torch.nn.utils.rnn.pad_sequence(id_seqs, batch_first=True, padding_value=pad_id)
        padded_labels = torch.nn.utils.rnn.pad_sequence(label_seqs, batch_first=True, padding_value=IGNORE_INDEX)
        padded_ids = padded_ids[:, :max_len]
        padded_labels = padded_labels[:, :max_len]

        batch = {
            "input_ids": padded_ids,
            "labels": padded_labels,
            "attention_mask": padded_ids.ne(pad_id),
        }

        if 'image' in instances[0]:
            images = [example['image'] for example in instances]
            batch['image_sizes'] = [example['image_size'] for example in instances]
            reference_shape = images[0].shape
            uniform = all(img is not None and img.shape == reference_shape for img in images)
            batch['images'] = torch.stack(images) if uniform else images
        return batch
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
                                data_args) -> Dict:
    """Make dataset and collator for supervised fine-tuning.

    Returns:
        A dict of trainer keyword arguments: ``train_dataset`` (a
        ``LazySupervisedDataset`` over ``data_args.data_path``),
        ``eval_dataset`` (always ``None``) and ``data_collator``.
    """
    dataset = LazySupervisedDataset(tokenizer=tokenizer,
                                    data_path=data_args.data_path,
                                    data_args=data_args)
    collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return {
        "train_dataset": dataset,
        "eval_dataset": None,
        "data_collator": collator,
    }
def train(attn_implementation=None):
    """Parse CLI arguments, build the model/tokenizer/data, train, and save.

    Arguments are parsed by ``transformers.HfArgumentParser`` into
    ``ModelArguments``, ``DataArguments`` and ``TrainingArguments``.

    Args:
        attn_implementation: Passed as ``attn_implementation`` to the
            ``from_pretrained`` calls of the DCLM, LLaMA and text-only
            branches (the MPT branch does not receive it). ``None`` keeps the
            library default.
    """
    # Module-level rank, read by rank0_print.
    global local_rank

    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    local_rank = training_args.local_rank
    # fp16 takes precedence over bf16; fp32 when neither flag is set.
    compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))

    # Extra from_pretrained kwargs for 4-/8-bit bitsandbytes loading (empty
    # otherwise). mm_projector is listed in llm_int8_skip_modules so it is
    # not quantized.
    bnb_model_from_pretrained_args = {}
    if training_args.bits in [4, 8]:
        from transformers import BitsAndBytesConfig
        bnb_model_from_pretrained_args.update(dict(
            device_map={"": training_args.device},
            load_in_4bit=training_args.bits == 4,
            load_in_8bit=training_args.bits == 8,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=training_args.bits == 4,
                load_in_8bit=training_args.bits == 8,
                llm_int8_skip_modules=["mm_projector"],
                llm_int8_threshold=6.0,
                llm_int8_has_fp16_weight=False,
                bnb_4bit_compute_dtype=compute_dtype,
                bnb_4bit_use_double_quant=training_args.double_quant,
                bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}
            )
        ))

    # Choose the model class. With a vision tower configured, the LLaVA
    # variant is picked by substrings of model_name_or_path. Without one, a
    # plain text-only LlamaForCausalLM is loaded.
    if model_args.vision_tower is not None:
        if 'mpt' in model_args.model_name_or_path:
            config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
            config.attn_config['attn_impl'] = training_args.mpt_attn_impl
            model = LlavaMptForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                config=config,
                cache_dir=training_args.cache_dir,
                **bnb_model_from_pretrained_args
            )
        elif 'dclm' in model_args.model_name_or_path.lower():
            # NOTE(review): `config` is loaded and its attn_impl set to "eager",
            # but unlike the MPT branch it is never passed to from_pretrained,
            # so the override has no effect. Confirm whether `config=config`
            # was intended here.
            config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
            config.attn_config['attn_impl'] = "eager"
            model = LlavaOpenlmForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=training_args.cache_dir,
                attn_implementation=attn_implementation,
                torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
                **bnb_model_from_pretrained_args
            )
        else:
            model = LlavaLlamaForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=training_args.cache_dir,
                attn_implementation=attn_implementation,
                torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
                **bnb_model_from_pretrained_args
            )
    else:
        model = transformers.LlamaForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            attn_implementation=attn_implementation,
            torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
            **bnb_model_from_pretrained_args
        )
    # KV cache is disabled for training and re-enabled after trainer.train().
    model.config.use_cache = False

    if model_args.freeze_backbone:
        model.model.requires_grad_(False)

    if training_args.bits in [4, 8]:
        from peft import prepare_model_for_kbit_training
        model.config.torch_dtype = (torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)

    # With gradient checkpointing, make the embedding outputs require grad.
    # Presumably this lets checkpointed blocks backprop when the embeddings
    # themselves are frozen. A forward hook is the fallback when the model
    # lacks enable_input_require_grads.
    if training_args.gradient_checkpointing:
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:
            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)
            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    # LoRA adapters target every nn.Linear found by find_all_linear_names. In
    # 16-bit mode the model is cast to bf16/fp16 before wrapping.
    if training_args.lora_enable:
        from peft import LoraConfig, get_peft_model
        lora_config = LoraConfig(
            r=training_args.lora_r,
            lora_alpha=training_args.lora_alpha,
            target_modules=find_all_linear_names(model),
            lora_dropout=training_args.lora_dropout,
            bias=training_args.lora_bias,
            task_type="CAUSAL_LM",
        )
        if training_args.bits == 16:
            if training_args.bf16:
                model.to(torch.bfloat16)
            if training_args.fp16:
                model.to(torch.float16)
        rank0_print("Adding LoRA adapters...")
        model = get_peft_model(model, lora_config)

    # MPT uses the default tokenizer class; other models load the slow
    # (use_fast=False) tokenizer. Both pad on the right.
    if 'mpt' in model_args.model_name_or_path:
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            model_max_length=training_args.model_max_length,
            padding_side="right"
        )
    else:
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            model_max_length=training_args.model_max_length,
            padding_side="right",
            use_fast=False,
        )

    # Pad-token setup and conversation-template selection depend on the
    # requested version. v0 adds a new "[PAD]" token (resizing embeddings)
    # when none exists; other versions reuse the unk token as pad.
    if model_args.version == "v0":
        if tokenizer.pad_token is None:
            smart_tokenizer_and_embedding_resize(
                special_tokens_dict=dict(pad_token="[PAD]"),
                tokenizer=tokenizer,
                model=model,
            )
    elif model_args.version == "v0.5":
        tokenizer.pad_token = tokenizer.unk_token
    else:
        tokenizer.pad_token = tokenizer.unk_token
        # Unknown versions fall back to the "vicuna_v1" template.
        if model_args.version in conversation_lib.conv_templates:
            conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
        else:
            conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"]

    if model_args.vision_tower is not None:
        model.get_model().initialize_vision_modules(
            model_args=model_args,
            fsdp=training_args.fsdp
        )

        # The vision tower is cast to bf16 when bf16 is set, and to fp16
        # otherwise (including fp32 runs).
        vision_tower = model.get_vision_tower()
        vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)

        data_args.image_processor = vision_tower.image_processor
        data_args.is_multimodal = True

        model.config.image_grid_pinpoints = data_args.image_grid_pinpoints
        model.config.image_aspect_ratio = data_args.image_aspect_ratio
        model.config.tokenizer_padding_side = tokenizer.padding_side
        model.config.tokenizer_model_max_length = tokenizer.model_max_length

        # Adapter-only training: freeze everything, then unfreeze mm_projector.
        model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
        if model_args.tune_mm_mlp_adapter:
            model.requires_grad_(False)
            for p in model.get_model().mm_projector.parameters():
                p.requires_grad = True

        # Train only mm_projector and the vision tower.
        if model_args.tune_mm_mlp_and_vision_tower:
            model.requires_grad_(False)
            for p in model.get_model().mm_projector.parameters():
                p.requires_grad = True
            for p in model.get_vision_tower().parameters():
                p.requires_grad = True

        model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
        if training_args.freeze_mm_mlp_adapter:
            for p in model.get_model().mm_projector.parameters():
                p.requires_grad = False

        # mm_projector was skipped by quantization; move it to the compute dtype.
        if training_args.bits in [4, 8]:
            model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device)

        model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
        model.config.mm_projector_lr = training_args.mm_projector_lr
        training_args.use_im_start_end = model_args.mm_use_im_start_end
        model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
        model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)

    # Quantized runs: LoRA layers to bf16 (if bf16), norm layers to fp32,
    # fp32 lm_head/embed_tokens weights to bf16 (if bf16). nn.Module.to casts
    # in place, so rebinding `module` is not required for the effect.
    if training_args.bits in [4, 8]:
        from peft.tuners.lora import LoraLayer
        for name, module in model.named_modules():
            if isinstance(module, LoraLayer):
                if training_args.bf16:
                    module = module.to(torch.bfloat16)
            if 'norm' in name:
                module = module.to(torch.float32)
            if 'lm_head' in name or 'embed_tokens' in name:
                if hasattr(module, 'weight'):
                    if training_args.bf16 and module.weight.dtype == torch.float32:
                        module = module.to(torch.bfloat16)

    data_module = make_supervised_data_module(tokenizer=tokenizer,
                                              data_args=data_args)
    trainer = LLaVATrainer(model=model,
                           tokenizer=tokenizer,
                           args=training_args,
                           **data_module)

    # Resume automatically if output_dir already holds checkpoint-* folders.
    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()

    model.config.use_cache = True

    # LoRA: save adapter weights plus the remaining trainable non-LoRA
    # weights (non_lora_trainables.bin) from local rank 0 / -1 only.
    # Otherwise delegate to safe_save_model_for_hf_trainer.
    if training_args.lora_enable:
        state_dict = get_peft_state_maybe_zero_3(
            model.named_parameters(), training_args.lora_bias
        )
        non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
            model.named_parameters()
        )
        if training_args.local_rank == 0 or training_args.local_rank == -1:
            model.config.save_pretrained(training_args.output_dir)
            model.save_pretrained(training_args.output_dir, state_dict=state_dict)
            torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
    else:
        safe_save_model_for_hf_trainer(trainer=trainer,
                                       output_dir=training_args.output_dir)
if __name__ == "__main__":
    # Script entry point; attn_implementation is left as None.
    train()
================================================
FILE: llava/train/train_mem.py
================================================
# Entry point: runs llava.train.train_qwen.train with
# attn_implementation="flash_attention_2".
from llava.train.train_qwen import train
if __name__ == "__main__":
    train(attn_implementation="flash_attention_2")
================================================
FILE: llava/train/train_qwen.py
================================================
# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from packaging import version
import os
import copy
from dataclasses import dataclass, field
import json
import logging
import pathlib
from typing import Dict, Optional, Sequence, List
import torch
import transformers
import tokenizers
from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from torch.utils.data import Dataset
from llava.train.llava_trainer import LLaVATrainer
from llava import conversation as conversation_lib
from llava.model import *
from llava.mm_utils import tokenizer_image_token, process_anyres_image
from PIL import Image
# Local process rank. None until the training entry point assigns it
# (presumably from TrainingArguments.local_rank).
local_rank = None


def rank0_print(*args):
    """Print ``args`` only from the main process.

    Prints when ``local_rank`` is 0 (the first process of a distributed run)
    or -1 (the value used for a single, non-distributed process). Before the
    rank is assigned (``None``), and on every other rank, nothing is printed.
    Accepting -1 matches the ``local_rank == 0 or local_rank == -1`` checks
    used when saving weights; previously single-process runs never printed.
    """
    if local_rank in (0, -1):
        print(*args)
# True when the installed `tokenizers` package is version 0.14 or newer (note:
# >=, despite the name). preprocess_v1 and preprocess_mpt use it to adjust
# per-round token counts for rounds after the first.
IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14')
@dataclass
class ModelArguments:
    """Command-line options describing the model to build and what to train."""

    # Base checkpoint (hub id or local path).
    model_name_or_path: Optional[str] = "facebook/opt-125m"
    # Conversation/tokenizer setup version string (e.g. "v0", "v1", "qwen_v2").
    version: Optional[str] = "v0"

    # Which parts of the network are trainable.
    freeze_backbone: bool = False
    tune_mm_mlp_adapter: bool = False
    tune_mm_mlp_and_vision_tower: bool = False

    # Vision encoder and its feature selection.
    vision_tower: Optional[str] = None
    mm_vision_select_layer: Optional[int] = -1  # default to the last layer

    # Multimodal projector.
    pretrain_mm_mlp_adapter: Optional[str] = None
    mm_projector_type: Optional[str] = 'linear'

    # Image-token handling in the text stream.
    mm_use_im_start_end: bool = False
    mm_use_im_patch_token: bool = True
    mm_patch_merge_type: Optional[str] = 'flat'
    mm_vision_select_feature: Optional[str] = "patch"

    unfreeze_mm_vision_tower: bool = False
    s2: Optional[bool] = False
    hd: Optional[bool] = False
@dataclass
class DataArguments:
    """Command-line options describing the training data and image handling."""

    data_path: Optional[List[str]] = field(
        default=None,
        metadata={"help": "Optional list of paths to the training data."},
    )
    lazy_preprocess: bool = False
    # Read by preprocess_multimodal. The training entry point sets it to True
    # when a vision tower is configured.
    is_multimodal: bool = False
    image_folder: Optional[List[str]] = None
    image_aspect_ratio: str = 'square'
    image_grid_pinpoints: Optional[str] = None
    image_crop_resolution: Optional[int] = None
    image_split_resolution: Optional[int] = None
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """``transformers.TrainingArguments`` extended with project-specific options.

    Adds sequence-length, quantization (``bits``, ``double_quant``,
    ``quant_type``), LoRA (``lora_*``) and multimodal learning-rate fields,
    and redeclares ``optim`` and ``remove_unused_columns`` with new defaults.
    """

    cache_dir: Optional[str] = None
    optim: str = "adamw_torch"
    remove_unused_columns: bool = False
    freeze_mm_mlp_adapter: bool = False
    mpt_attn_impl: Optional[str] = "triton"
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )

    # Quantization.
    double_quant: bool = field(
        default=True,
        metadata={"help": "Compress the quantization statistics through double quantization."},
    )
    quant_type: str = field(
        default="nf4",
        metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."},
    )
    bits: int = field(default=16, metadata={"help": "How many bits to use."})

    # LoRA.
    lora_enable: bool = False
    lora_r: int = 64
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_weight_path: str = ""
    lora_bias: str = "none"

    mm_projector_lr: Optional[float] = None
    group_by_modality_length: bool = False
    mm_vision_tower_lr: Optional[float] = None
def maybe_zero_3(param, ignore_status=False, name=None):
    """Return a detached CPU clone of ``param``, gathering it first under DeepSpeed ZeRO.

    A parameter carrying a ``ds_id`` attribute (DeepSpeed-managed) is copied
    inside ``deepspeed.zero.GatheredParameters`` so its full data is
    available. Any other tensor is simply detached, moved to CPU and cloned.
    DeepSpeed is imported only for such parameters, so non-DeepSpeed runs do
    not require it to be installed.

    Args:
        param: Parameter or tensor to copy.
        ignore_status: When False, log a warning if a DeepSpeed parameter's
            ``ds_status`` is ``ZeroParamStatus.NOT_AVAILABLE``.
        name: Optional parameter name included in that warning.

    Returns:
        A detached CPU tensor with the parameter's values.
    """
    if not hasattr(param, "ds_id"):
        return param.detach().cpu().clone()

    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
        if not ignore_status:
            # Message now matches the condition that triggers it.
            logging.warning(f"{name}: param.ds_status == ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
    with zero.GatheredParameters([param]):
        param = param.data.detach().cpu().clone()
    return param
# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    """Collect LoRA weights (and optionally biases) as detached CPU tensors.

    Args:
        named_params: Iterable of ``(name, tensor)`` pairs, e.g.
            ``model.named_parameters()``.
        bias: Which biases to keep alongside the ``lora_`` weights:
            ``"none"`` keeps no biases; ``"all"`` keeps every parameter whose
            name contains ``"bias"``; ``"lora_only"`` keeps only the biases of
            modules that carry LoRA weights.

    Returns:
        Dict of name -> tensor copied through ``maybe_zero_3``.

    Raises:
        NotImplementedError: for any other ``bias`` value.
    """
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                # e.g. "x.q_proj.lora_A.weight" -> "x.q_proj.bias"
                bias_name = k.split("lora_")[0] + "bias"
                lora_bias_names.add(bias_name)
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # Keep a bias only if it belongs to a LoRA-adapted module. Iterate
        # (name, tensor) pairs and test each candidate's own name.
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError
    to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
    return to_return
def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
    """Collect every non-LoRA parameter as a CPU tensor.

    Args:
        named_params: Iterable of ``(name, tensor)`` pairs.
        require_grad_only: When True (default), keep only tensors with
            ``requires_grad`` set.

    Returns:
        Dict of name -> CPU copy (via ``maybe_zero_3``) for each kept
        parameter whose name does not contain ``"lora_"``.
    """
    candidates = dict((pname, tensor) for pname, tensor in named_params if "lora_" not in pname)
    if require_grad_only:
        candidates = {pname: tensor for pname, tensor in candidates.items() if tensor.requires_grad}
    return {
        pname: maybe_zero_3(tensor, ignore_status=True).cpu()
        for pname, tensor in candidates.items()
    }
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
    """Collect parameters whose names contain any of ``keys_to_match``.

    Args:
        named_params: Iterable of ``(name, tensor)`` pairs.
        keys_to_match: Substrings. A parameter is kept when any one of them
            occurs in its name.

    Returns:
        Dict of name -> CPU copy (via ``maybe_zero_3``).
    """
    def _is_wanted(pname):
        return any(key in pname for key in keys_to_match)

    matched = dict((pname, tensor) for pname, tensor in named_params if _is_wanted(pname))
    return {
        pname: maybe_zero_3(tensor, ignore_status=True).cpu()
        for pname, tensor in matched.items()
    }
def find_all_linear_names(model):
    """Return the short names of ``torch.nn.Linear`` layers to target with LoRA.

    Layers whose qualified name contains ``mm_projector``, ``vision_tower`` or
    ``vision_resampler`` are skipped, and ``lm_head`` is excluded. Each layer
    contributes the last component of its dotted name, e.g.
    ``model.layers.0.self_attn.q_proj`` -> ``q_proj``.

    Returns:
        A list of unique short names (order unspecified).
    """
    skip_keywords = ('mm_projector', 'vision_tower', 'vision_resampler')
    short_names = set()
    for qualified_name, submodule in model.named_modules():
        if not isinstance(submodule, torch.nn.Linear):
            continue
        if any(keyword in qualified_name for keyword in skip_keywords):
            continue
        short_names.add(qualified_name.split('.')[-1])
    short_names.discard('lm_head')  # needed for 16-bit
    return list(short_names)
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
                                   output_dir: str):
    """Collect the model state dict and write it to ``output_dir``.

    Three modes:

    * ``trainer.args.tune_mm_mlp_adapter`` set: only ``mm_projector`` weights
      (plus ``embed_tokens``/``embed_in`` when ``use_im_start_end`` is set)
      are saved, together with the model config. For a ``checkpoint-*``
      directory they are written to ``<parent>/mm_projector/<name>.bin``;
      otherwise to ``<output_dir>/mm_projector.bin``. Only local rank 0 or -1
      writes the weight file.
    * DeepSpeed: synchronize CUDA and delegate to ``trainer.save_model``.
    * Otherwise: copy the state dict to CPU and save it via ``trainer._save``
      on the process where ``trainer.args.should_save`` is True.
    """
    if getattr(trainer.args, "tune_mm_mlp_adapter", False):
        # Only save Adapter
        keys_to_match = ['mm_projector']
        if getattr(trainer.args, "use_im_start_end", False):
            keys_to_match.extend(['embed_tokens', 'embed_in'])

        weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
        trainer.model.config.save_pretrained(output_dir)

        # Normalize first so a trailing separator ("out/checkpoint-100/")
        # still yields the real folder name rather than an empty string.
        normalized_dir = os.path.normpath(output_dir)
        current_folder = os.path.basename(normalized_dir)
        parent_folder = os.path.dirname(normalized_dir)
        if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
            if current_folder.startswith('checkpoint-'):
                mm_projector_folder = os.path.join(parent_folder, "mm_projector")
                os.makedirs(mm_projector_folder, exist_ok=True)
                torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
            else:
                torch.save(weight_to_save, os.path.join(output_dir, 'mm_projector.bin'))
        return

    if trainer.deepspeed:
        torch.cuda.synchronize()
        trainer.save_model(output_dir)
        return

    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {
            key: value.cpu()
            for key, value in state_dict.items()
        }
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa
def smart_tokenizer_and_embedding_resize(
        special_tokens_dict: Dict,
        tokenizer: transformers.PreTrainedTokenizer,
        model: transformers.PreTrainedModel,
):
    """Add special tokens to ``tokenizer`` and resize ``model``'s embeddings to match.

    Rows added to the input and output embedding matrices are initialized to
    the mean of the rows that existed before the resize.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    added = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))
    if added <= 0:
        return

    for matrix in (model.get_input_embeddings().weight.data,
                   model.get_output_embeddings().weight.data):
        matrix[-added:] = matrix[:-added].mean(dim=0, keepdim=True)
def _tokenize_fn(strings: Sequence[str],
                 tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize each string independently.

    Returns:
        A dict where ``input_ids`` and ``labels`` are the same list of
        per-string id tensors (row 0 of each tokenizer output), and
        ``input_ids_lens``/``labels_lens`` are the same list of non-pad token
        counts.
    """
    id_rows = []
    non_pad_counts = []
    for text in strings:
        encoded = tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        id_rows.append(encoded.input_ids[0])
        non_pad_counts.append(encoded.input_ids.ne(tokenizer.pad_token_id).sum().item())
    return dict(
        input_ids=id_rows,
        labels=id_rows,
        input_ids_lens=non_pad_counts,
        labels_lens=non_pad_counts,
    )
def _mask_targets(target, tokenized_lens, speakers):
    """Mask the header and human turns of ``target`` in place with IGNORE_INDEX.

    ``tokenized_lens[0]`` is the header length and is fully masked. The
    remaining lengths pair up with ``speakers``. For each ``"human"`` turn,
    every token after the turn's first two is masked.
    """
    offset = tokenized_lens[0]
    target[:offset] = IGNORE_INDEX
    for turn_len, speaker in zip(tokenized_lens[1:], speakers):
        if speaker == "human":
            target[offset + 2:offset + turn_len] = IGNORE_INDEX
        offset += turn_len
def _add_speaker_and_signal(header, source, get_conversation=True):
    """Add speaker and start/end signal on each round.

    Each ``sentence["value"]`` in ``source`` is rewritten in place to
    ``"### <role>: <value>\\n"``. ``<role>`` is the default conversation's
    first/second role for ``"human"``/``"gpt"`` (case-insensitive) and
    ``'unknown'`` otherwise. Returns ``header``, then the rewritten turns
    (only when ``get_conversation``), then a trailing ``"### "``.
    """
    BEGIN_SIGNAL = "### "
    END_SIGNAL = "\n"

    def _speaker_label(raw_speaker):
        lowered = raw_speaker.lower()
        if lowered == "human":
            return conversation_lib.default_conversation.roles[0]
        if lowered == "gpt":
            return conversation_lib.default_conversation.roles[1]
        return 'unknown'

    pieces = [header]
    for sentence in source:
        label = _speaker_label(sentence["from"])
        sentence["value"] = BEGIN_SIGNAL + label + ": " + sentence["value"] + END_SIGNAL
        if get_conversation:
            pieces.append(sentence["value"])
    pieces.append(BEGIN_SIGNAL)
    return "".join(pieces)
def preprocess_multimodal(
        sources: Sequence[str],
        data_args: DataArguments
) -> Dict:
    """Normalize image placeholders in every conversation turn, in place.

    If ``data_args.is_multimodal`` is False, ``sources`` is returned as-is.
    Otherwise, for each sentence:

    * if it contains DEFAULT_IMAGE_TOKEN, the token is moved to the start of
      the text and followed by a newline. For conversation templates whose
      version contains ``"mmtag"``, the token is then wrapped as
      ``<Image>DEFAULT_IMAGE_TOKEN</Image>``;
    * when ``data_args.mm_use_im_start_end`` is set, every image token is
      surrounded by DEFAULT_IM_START_TOKEN / DEFAULT_IM_END_TOKEN.

    Returns:
        The same ``sources`` object (a list of conversations, each a list of
        ``{"value": ...}`` dicts), modified in place.
    """
    is_multimodal = data_args.is_multimodal
    if not is_multimodal:
        return sources

    for source in sources:
        for sentence in source:
            if DEFAULT_IMAGE_TOKEN in sentence['value']:
                sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
                sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value']
                sentence['value'] = sentence['value'].strip()
                if "mmtag" in conversation_lib.default_conversation.version:
                    # Wrapping with empty strings was a no-op. Use explicit
                    # <Image>...</Image> tags, presumably matching the mmtag
                    # templates' prompt format (confirm in conversation.py).
                    sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '<Image>' + DEFAULT_IMAGE_TOKEN + '</Image>')
            replace_token = DEFAULT_IMAGE_TOKEN
            if data_args.mm_use_im_start_end:
                replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
            sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)

    return sources
def preprocess_llama_2(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    """Render conversations with the LLAMA_2 template, tokenize, and build masked labels.

    Args:
        sources: List of conversations. Each is a list of turns with
            ``"from"`` ("human"/"gpt") and ``"value"`` keys. A leading turn
            not from the human role is dropped; the rest must alternate
            human/gpt (asserted).
        tokenizer: Used to produce ids and to measure per-round lengths.
        has_image: When True, prompts are tokenized with
            ``tokenizer_image_token`` and combined with ``torch.stack``, which
            requires every prompt to yield an equal-length tensor.

    Returns:
        dict with ``input_ids`` and ``labels``. Labels copy the ids, except
        that the first token, each round's instruction span and everything
        after the last round are IGNORE_INDEX. When the measured length
        disagrees with the non-pad token count (and is below
        ``model_max_length``), the whole example is masked and a warning is
        printed.
    """
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize conversations
    if has_image:
        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()

    assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2

    # Mask targets
    sep = "[/INST] "
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        # Rounds are delimited by conv.sep2. The first token is always masked
        # (presumably BOS — confirm for the tokenizer in use).
        rounds = conversation.split(conv.sep2)
        cur_len = 1
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            # parts[0] is the instruction up to and including "[/INST] ";
            # parts[1] is the model response.
            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            # The "- 2" offset is an empirical adjustment for tokens counted
            # by standalone tokenization of parts[0] but not masked here
            # (presumably BOS plus a boundary token — TODO confirm).
            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            target[cur_len: cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        # Sanity check: if the summed round lengths do not match the real
        # token count, the masks would be misaligned, so drop the example.
        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )
# fix: add qwen2
# def preprocess_qwen_2(
# sources,
# tokenizer: transformers.PreTrainedTokenizer,
# has_image: bool = False
# ) -> Dict:
# # print('-----preprocess_qwen_2-------')
# conv = conversation_lib.default_conversation.copy()
# roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
#
# # Apply prompt templates
# conversations = []
# for i, source in enumerate(sources):
# if roles[source[0]["from"]] != conv.roles[0]:
# # Skip the first one if it is not from human
# source = source[1:]
#
# conv.messages = []
# for j, sentence in enumerate(source):
# role = roles[sentence["from"]]
# assert role == conv.roles[j % 2], f"{i}"
# conv.append_message(role, sentence["value"])
# conversations.append(conv.get_prompt())
#
# # Tokenize conversations
#
# if has_image:
# input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
# else:
# input_ids = tokenizer(
# conversations,
# return_tensors="pt",
# padding="longest",
# max_length=tokenizer.model_max_length,
# truncation=True,
# ).input_ids
#
# targets = input_ids.clone()
#
# assert conv.sep_style == conversation_lib.SeparatorStyle.QWEN_2
#
# rank0_print(50*'S')
# # Mask targets
# sep = conv.sep + conv.roles[1] + ": "
# for conversation, target in zip(conversations, targets):
# total_len = int(target.ne(tokenizer.pad_token_id).sum())
# rank0_print(f"target.shape={target.shape}", f"total_len={total_len}")
#
# rounds = conversation.split(conv.sep2)
# rounds_len = len(rounds)
# cur_len = 0
# # target[:cur_len] = IGNORE_INDEX
# for i, rou in enumerate(rounds):
# if rou == "":
# break
#
# parts = rou.split(sep)
# if len(parts) != 2:
# break
# parts[0] += sep
#
# if has_image:
# round_ids = tokenizer_image_token(rou, tokenizer)
# instruction_ids = tokenizer_image_token(parts[0], tokenizer)
# equal_parts = [x == y for x, y in zip(round_ids, instruction_ids)]
#
# instruction_len = equal_parts.index(False) if False in equal_parts else len(equal_parts)
# round_len = len(round_ids)
#
# else:
# round_ids = tokenizer(rou).input_ids
# instruction_ids = tokenizer(parts[0]).input_ids
# equal_parts = [x == y for x, y in zip(round_ids, instruction_ids)]
#
# instruction_len = equal_parts.index(False) if False in equal_parts else len(equal_parts)
# round_len = len(round_ids)
#
# if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14:
# round_len += 1
# instruction_len += 1
#
# rank0_print(i, rou, f"cur_len={cur_len}", f"round_len={round_len}", f"instruction_len={instruction_len}", f"cur_len + instruction_len={cur_len + instruction_len}")
# target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
#
# cur_len += round_len
# rank0_print("Outside Loop")
# rank0_print(cur_len, len(target))
# rank0_print(target)
# rank0_print(50 * 'E')
# exit(0)
# target[cur_len:] = IGNORE_INDEX
#
# if cur_len < tokenizer.model_max_length:
# if cur_len != total_len + rounds_len - 2:
# target[:] = IGNORE_INDEX
# print(
# f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
# f" (ignored)"
# )
#
# return dict(
# input_ids=input_ids,
# labels=targets,
# )
def preprocess_qwen_2(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    """Render conversations with the QWEN_2 template, tokenize, and build masked labels.

    Args:
        sources: List of conversations. Each is a list of turns with
            ``"from"`` and ``"value"`` keys. A leading turn not from the human
            role is dropped. Later turns whose ``"from"`` is neither
            ``"human"`` nor ``"gpt"`` are skipped with a message. The kept
            turns must alternate human/gpt (asserted).
        tokenizer: Used to produce ids and to measure per-round lengths.
        has_image: When True, prompts are tokenized with
            ``tokenizer_image_token`` and combined with ``torch.stack``.

    Returns:
        dict with ``input_ids`` and ``labels`` (same shape). In the labels,
        each round's system/user instruction span and everything after the
        last round are IGNORE_INDEX. A conversation whose length accounting
        disagrees with its non-pad token count (below ``model_max_length``)
        is fully masked and a warning is printed.
    """
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        # Count only the turns actually appended, so skipping an unrecognized
        # speaker does not shift the human/gpt alternation for later turns.
        kept_turns = 0
        for sentence in source:
            role = roles.get(sentence["from"])
            if role is None:
                print("skipping sentence due to unrecognized role: {}".format(sentence["from"]))
                continue
            assert role == conv.roles[kept_turns % 2], f"{i}"
            conv.append_message(role, sentence["value"])
            kept_turns += 1
        conversations.append(conv.get_prompt())

    if has_image:
        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()
    assert conv.sep_style == conversation_lib.SeparatorStyle.QWEN_2

    # The assistant response starts right after conv.sep + assistant role.
    split_sep = conv.sep + conv.roles[1]
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds_before = conversation.split(
            conv.sep)  # system->user->assistant->user->assistant->user->assistant->user->assistant->Empty
        # Fold a standalone system segment into the first user segment.
        if rounds_before[0] == conv.system:
            rounds_before[1] = conv.sep.join([rounds_before[0], rounds_before[1]])
            rounds_before = rounds_before[1:]
        # connect every pair (user + assistant) into one round:
        rounds = []
        for i in range(0, len(rounds_before), 2):
            if i < len(rounds_before) - 1:
                rounds.append(conv.sep.join([rounds_before[i], rounds_before[i + 1]]))
            else:
                rounds.append(rounds_before[i])

        cur_len = 0
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            parts = rou.split(split_sep)
            if len(parts) != 2:
                break
            assert parts[0].startswith(conv.roles[0]) or parts[0].startswith(conv.system)
            parts[0] += split_sep

            # The instruction length is the shared token prefix of the round
            # and its instruction part, which avoids off-by-one errors from
            # tokenization at the boundary.
            if has_image:
                round_ids = tokenizer_image_token(rou, tokenizer)
                instruction_ids = tokenizer_image_token(parts[0], tokenizer)
                equal_parts = [x == y for x, y in zip(round_ids, instruction_ids)]
                instruction_len = equal_parts.index(False) if False in equal_parts else len(equal_parts)
                round_len = len(round_ids)
            else:
                round_ids = tokenizer(rou).input_ids
                instruction_ids = tokenizer(parts[0]).input_ids
                equal_parts = [x == y for x, y in zip(round_ids, instruction_ids)]
                instruction_len = equal_parts.index(False) if False in equal_parts else len(equal_parts)
                round_len = len(round_ids)
            round_len += 2  # this is to compensate for the sep2
            assert rou == parts[0] + parts[1]
            target[cur_len: cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )
def preprocess_v1(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    """Render conversations with a SeparatorStyle.TWO template, tokenize, and mask labels.

    Args:
        sources: List of conversations. Each is a list of turns with
            ``"from"`` ("human"/"gpt") and ``"value"`` keys. A leading turn
            not from the human role is dropped; the rest must alternate
            human/gpt (asserted).
        tokenizer: Used to produce ids and to measure per-round lengths.
        has_image: When True, prompts are tokenized with
            ``tokenizer_image_token`` and combined with ``torch.stack``.

    Returns:
        dict with ``input_ids`` and ``labels``. In the labels, the first
        token, each round's instruction span and everything after the last
        round are IGNORE_INDEX. On a length mismatch (below
        ``model_max_length``), the whole example is masked and a warning is
        printed.
    """
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize conversations
    if has_image:
        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()

    assert conv.sep_style == conversation_lib.SeparatorStyle.TWO

    # Mask targets: the instruction ends at "<sep><assistant role>: ".
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        # Rounds are delimited by conv.sep2. The first token is always masked
        # (presumably BOS — confirm).
        rounds = conversation.split(conv.sep2)
        cur_len = 1
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            # For rounds after the first, non-legacy tokenizers on
            # tokenizers>=0.14 count one token fewer per round.
            # NOTE(review): `tokenizer.legacy` raises AttributeError for
            # tokenizers without that attribute; preprocess_mpt uses
            # getattr(tokenizer, 'legacy', False) instead.
            if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14:
                round_len -= 1
                instruction_len -= 1

            target[cur_len: cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )
def preprocess_mpt(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    """Render conversations with the MPT template, tokenize, and mask labels.

    Args:
        sources: List of conversations. Each is a list of turns with
            ``"from"`` ("human"/"gpt") and ``"value"`` keys. A leading turn
            not from the human role is dropped; the rest must alternate
            human/gpt (asserted).
        tokenizer: Used to produce ids and to measure per-round lengths.
        has_image: When True, prompts are tokenized with
            ``tokenizer_image_token`` and combined with ``torch.stack``.

    Returns:
        dict with ``input_ids`` and ``labels``. In the labels, each round's
        instruction span and everything after the last round are
        IGNORE_INDEX. On a length mismatch (below ``model_max_length``), the
        whole example is masked and a warning is printed.
    """
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize conversations
    if has_image:
        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()
    assert conv.sep_style == conversation_lib.SeparatorStyle.MPT

    # Mask targets: the instruction ends at "<sep><assistant role>".
    sep = conv.sep + conv.roles[1]
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        # Split on conv.sep, then regroup: the first round is
        # system + user + gpt, and each later round is user + gpt.
        rounds = conversation.split(conv.sep)
        re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
        for conv_idx in range(3, len(rounds), 2):
            re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2])) # user + gpt
        cur_len = 0
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(re_rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 1

            # For rounds after the first, tokenizers whose `legacy` flag is
            # truthy (missing counts as False) on tokenizers>=0.14 get one
            # extra token per round.
            if i != 0 and getattr(tokenizer, 'legacy', False) and IS_TOKENIZER_GREATER_THAN_0_14:
                round_len += 1
                instruction_len += 1

            target[cur_len: cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )
def preprocess_plain(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Tokenize samples for the PLAIN (caption-style) conversation template.

    Each source must hold exactly two turns, and the first turn must contain
    DEFAULT_IMAGE_TOKEN. The first turn is reduced to just that token (this
    mutates ``source[0]['value']`` in place), the second turn's text is
    appended, and the template separator closes the prompt. In the labels,
    the positions covered by the tokenized first turn are set to
    IGNORE_INDEX, so only the remainder is supervised.

    Returns:
        ``dict(input_ids=[...], labels=[...])`` with one entry per sample.
    """
    prompts = []
    for turns in sources:
        assert len(turns) == 2
        assert DEFAULT_IMAGE_TOKEN in turns[0]['value']
        turns[0]['value'] = DEFAULT_IMAGE_TOKEN
        separator = conversation_lib.default_conversation.sep
        prompts.append(turns[0]['value'] + turns[1]['value'] + separator)

    input_ids = [
        tokenizer_image_token(prompt, tokenizer, return_tensors='pt')
        for prompt in prompts
    ]
    labels = copy.deepcopy(input_ids)
    for label, turns in zip(labels, sources):
        # Mask the image-placeholder prefix so it contributes no loss.
        prefix_len = len(tokenizer_image_token(turns[0]['value'], tokenizer))
        label[:prefix_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)
def preprocess(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    """
    Tokenize conversations and build supervision labels.

    The active ``conversation_lib.default_conversation`` selects the
    implementation, checked in this order:
      * PLAIN separator style        -> preprocess_plain
      * LLAMA_2 separator style      -> preprocess_llama_2
      * version starting with "v1"   -> preprocess_v1
      * version equal to "mpt"       -> preprocess_mpt
      * version starting "qwen_v2"   -> preprocess_qwen_2
    Any other template takes the legacy path: each turn is decorated by
    ``_add_speaker_and_signal`` after a system-prompt header, the full
    conversation is tokenized, and ``_mask_targets`` masks the human turns
    in a deep copy of the token ids that serves as the labels.
    """
    template = conversation_lib.default_conversation
    styles = conversation_lib.SeparatorStyle
    if template.sep_style == styles.PLAIN:
        return preprocess_plain(sources, tokenizer)
    if template.sep_style == styles.LLAMA_2:
        return preprocess_llama_2(sources, tokenizer, has_image=has_image)
    if template.version.startswith("v1"):
        return preprocess_v1(sources, tokenizer, has_image=has_image)
    if template.version == "mpt":
        return preprocess_mpt(sources, tokenizer, has_image=has_image)
    if template.version.startswith("qwen_v2"):
        return preprocess_qwen_2(sources, tokenizer, has_image=has_image)

    # Legacy path: add speaker signals and concatenate each conversation.
    header = f"{template.system}\n\n"
    conversations = [_add_speaker_and_signal(header, source) for source in sources]

    if has_image:
        input_ids = [
            tokenizer_image_token(prompt, tokenizer, return_tensors='pt')
            for prompt in conversations
        ]
    else:
        input_ids = _tokenize_fn(conversations, tokenizer)["input_ids"]

    targets = copy.deepcopy(input_ids)
    for target, source in zip(targets, sources):
        segments = [header] + [turn["value"] for turn in source]
        if has_image:
            segment_lens = [len(tokenizer_image_token(seg, tokenizer)) for seg in segments]
        else:
            segment_lens = _tokenize_fn(segments, tokenizer)["input_ids_lens"]
        _mask_targets(target, segment_lens, [turn["from"] for turn in source])
    return dict(input_ids=input_ids, labels=targets)
class LazySupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning.

    Records are read from one or more JSON files. Each record is tagged with
    ``img_path_idx`` (the index of the JSON file it came from), so its image
    is resolved against ``data_args.image_folder[img_path_idx]``. Image
    loading and tokenization happen lazily in ``get_sample``.
    """

    def __init__(self, data_path: List[str],
                 tokenizer: transformers.PreTrainedTokenizer,
                 data_args: DataArguments):
        super(LazySupervisedDataset, self).__init__()
        list_data_dict = []
        for i, _data_path in enumerate(data_path):
            # Close each annotation file promptly instead of leaking the handle.
            with open(_data_path, "r") as f:
                data = json.load(f)
            list_data_dict += [{**entry, 'img_path_idx': i} for entry in data]
        self.tokenizer = tokenizer
        self.list_data_dict = list_data_dict
        self.data_args = data_args

    def __len__(self):
        return len(self.list_data_dict)

    @property
    def lengths(self):
        """Per-sample length estimate: whitespace-split word count over all
        turns, plus a fixed 128 for samples that have an image."""
        length_list = []
        for sample in self.list_data_dict:
            img_tokens = 128 if 'image' in sample else 0
            length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens)
        return length_list

    @property
    def modality_lengths(self):
        """Per-sample word count, positive for image samples and negated for
        text-only samples (the sign encodes the modality)."""
        length_list = []
        for sample in self.list_data_dict:
            cur_len = sum(len(conv['value'].split()) for conv in sample['conversations'])
            cur_len = cur_len if 'image' in sample else -cur_len
            length_list.append(cur_len)
        return length_list

    @staticmethod
    def _expand2square(pil_img, background_color):
        """Pad ``pil_img`` to a square canvas filled with ``background_color``,
        centering the original along its shorter side."""
        width, height = pil_img.size
        if width == height:
            return pil_img
        elif width > height:
            result = Image.new(pil_img.mode, (width, width), background_color)
            result.paste(pil_img, (0, (width - height) // 2))
            return result
        else:
            result = Image.new(pil_img.mode, (height, height), background_color)
            result.paste(pil_img, ((height - width) // 2, 0))
            return result

    def get_sample(self, i) -> Dict[str, torch.Tensor]:
        """Load, preprocess and tokenize sample ``i``.

        Returns ``input_ids`` and ``labels``. Image samples also carry
        ``image`` (the processed pixels) and ``image_size`` (the PIL
        ``(width, height)`` of the source image). Text-only samples on a
        multimodal model get an all-zero image of the processor's crop size.
        """
        sources = self.list_data_dict[i]
        if isinstance(i, int):
            sources = [sources]
        assert len(sources) == 1, "Don't know why it is wrapped to a list"  # FIXME
        if 'image' in sources[0]:
            image_file = self.list_data_dict[i]['image']
            img_path_idx = self.list_data_dict[i]['img_path_idx']
            image_folder = self.data_args.image_folder[img_path_idx]
            processor = self.data_args.image_processor
            # Close the file handle; convert() yields an independent, loaded image.
            with Image.open(os.path.join(image_folder, image_file)) as raw_image:
                image = raw_image.convert('RGB')
            image_size = image.size  # PIL convention: (width, height)
            if self.data_args.image_aspect_ratio == 'pad':
                image = self._expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
                image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            elif self.data_args.image_aspect_ratio == "anyres" or "anyres_max" in self.data_args.image_aspect_ratio:
                image = process_anyres_image(image, self.data_args.image_processor, self.data_args.image_grid_pinpoints)
            else:
                image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            sources = preprocess_multimodal(
                copy.deepcopy([e["conversations"] for e in sources]),
                self.data_args)
        else:
            sources = copy.deepcopy([e["conversations"] for e in sources])
        data_dict = preprocess(
            sources,
            self.tokenizer,
            has_image=('image' in self.list_data_dict[i]))
        if isinstance(i, int):
            data_dict = dict(input_ids=data_dict["input_ids"][0],
                             labels=data_dict["labels"][0])

        # image exist in the data
        if 'image' in self.list_data_dict[i]:
            data_dict['image'] = image
            data_dict['image_size'] = image_size
        elif self.data_args.is_multimodal:
            # image does not exist in the data, but the model is multimodal
            crop_size = self.data_args.image_processor.crop_size
            data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
            # Report (width, height) to match PIL's image.size used for real images.
            data_dict['image_size'] = (crop_size['width'], crop_size['height'])
        return data_dict

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        try:
            return self.get_sample(i)
        except Exception as e:
            # Best-effort: a broken sample should not abort a long training
            # run, so fall back to sample 0. Still report which sample failed
            # and why, instead of discarding the exception.
            print(f"Error loading sample {i}: {type(e).__name__}: {e}")
            return self.get_sample(0)
@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning.

    ``input_ids`` are right-padded with the tokenizer's pad id and ``labels``
    with IGNORE_INDEX, then both are cut to ``tokenizer.model_max_length``.
    The attention mask is true wherever the id differs from the pad id.
    When the first instance has an ``image`` entry, every instance's
    ``image`` and ``image_size`` are collected as well. Images are stacked
    into one tensor when all are non-None with identical shapes; otherwise
    they are passed through as a list.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        pad_sequence = torch.nn.utils.rnn.pad_sequence
        pad_id = self.tokenizer.pad_token_id
        max_len = self.tokenizer.model_max_length

        id_seqs = [inst["input_ids"] for inst in instances]
        label_seqs = [inst["labels"] for inst in instances]

        input_ids = pad_sequence(id_seqs, batch_first=True, padding_value=pad_id)[:, :max_len]
        labels = pad_sequence(label_seqs, batch_first=True, padding_value=IGNORE_INDEX)[:, :max_len]

        batch = {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": input_ids.ne(pad_id),
        }

        if 'image' not in instances[0]:
            return batch

        images = [inst['image'] for inst in instances]
        batch['image_sizes'] = [inst['image_size'] for inst in instances]
        uniform = all(img is not None and img.shape == images[0].shape for img in images)
        batch['images'] = torch.stack(images) if uniform else images
        return batch
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
                                data_args) -> Dict:
    """Build the dataset and collator for supervised fine-tuning.

    Returns a dict with ``train_dataset``, ``eval_dataset`` (always ``None``)
    and ``data_collator``, ready to be ``**``-unpacked into the trainer.
    """
    return dict(
        train_dataset=LazySupervisedDataset(
            tokenizer=tokenizer,
            data_path=data_args.data_path,
            data_args=data_args,
        ),
        eval_dataset=None,
        data_collator=DataCollatorForSupervisedDataset(tokenizer=tokenizer),
    )
def train(attn_implementation=None):
    """Parse CLI arguments, build model and tokenizer, train, and save.

    Args:
        attn_implementation: forwarded as ``attn_implementation`` to the Qwen2
            ``from_pretrained`` calls (the MPT branch does not receive it).

    All other configuration comes from the ``ModelArguments``,
    ``DataArguments`` and ``TrainingArguments`` dataclasses parsed from the
    command line. After training, this saves either the LoRA adapter plus the
    non-LoRA trainable weights, or the full model via
    ``safe_save_model_for_hf_trainer``.
    """
    # Module-level rank; presumably read by rank0_print -- TODO confirm.
    global local_rank
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    local_rank = training_args.local_rank
    # Compute precision: fp16 if requested, else bf16 if requested, else fp32.
    compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
    # Extra from_pretrained kwargs; only populated for 4/8-bit (bitsandbytes) loading.
    bnb_model_from_pretrained_args = {}
    if training_args.bits in [4, 8]:
        from transformers import BitsAndBytesConfig
        bnb_model_from_pretrained_args.update(dict(
            device_map={"": training_args.device},
            load_in_4bit=training_args.bits == 4,
            load_in_8bit=training_args.bits == 8,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=training_args.bits == 4,
                load_in_8bit=training_args.bits == 8,
                # Keep the multimodal projector out of int8 quantization.
                llm_int8_skip_modules=["mm_projector"],
                llm_int8_threshold=6.0,
                llm_int8_has_fp16_weight=False,
                bnb_4bit_compute_dtype=compute_dtype,
                bnb_4bit_use_double_quant=training_args.double_quant,
                bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}
            )
        ))
    # Model class: multimodal MPT if the path mentions 'mpt', multimodal Qwen2
    # otherwise; a plain Qwen2 causal LM when no vision tower is configured.
    if model_args.vision_tower is not None:
        if 'mpt' in model_args.model_name_or_path:
            config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
            config.attn_config['attn_impl'] = training_args.mpt_attn_impl
            model = LlavaMptForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                config=config,
                cache_dir=training_args.cache_dir,
                **bnb_model_from_pretrained_args
            )
        else:
            model = LlavaQwen2ForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=training_args.cache_dir,
                attn_implementation=attn_implementation,
                torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
                **bnb_model_from_pretrained_args
            )
    else:
        model = transformers.Qwen2ForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            attn_implementation=attn_implementation,
            torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
            **bnb_model_from_pretrained_args
        )
    # KV cache is off during training; it is switched back on before saving.
    model.config.use_cache = False
    if model_args.freeze_backbone:
        model.model.requires_grad_(False)
    if training_args.bits in [4, 8]:
        from peft import prepare_model_for_kbit_training
        model.config.torch_dtype = (torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
    if training_args.gradient_checkpointing:
        # Make the embedding output require grad so gradients can flow through
        # checkpointed blocks even when the embeddings themselves are frozen.
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:
            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)
            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
    if training_args.lora_enable:
        from peft import LoraConfig, get_peft_model
        lora_config = LoraConfig(
            r=training_args.lora_r,
            lora_alpha=training_args.lora_alpha,
            target_modules=find_all_linear_names(model),
            lora_dropout=training_args.lora_dropout,
            bias=training_args.lora_bias,
            task_type="CAUSAL_LM",
        )
        # Unquantized (16-bit) runs: cast the whole model to the training dtype.
        if training_args.bits == 16:
            if training_args.bf16:
                model.to(torch.bfloat16)
            if training_args.fp16:
                model.to(torch.float16)
        rank0_print("Adding LoRA adapters...")
        model = get_peft_model(model, lora_config)
    # Tokenizer: right padding; MPT uses the default (fast) tokenizer, others
    # explicitly request the slow tokenizer.
    if 'mpt' in model_args.model_name_or_path:
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            model_max_length=training_args.model_max_length,
            padding_side="right"
        )
    else:
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            model_max_length=training_args.model_max_length,
            padding_side="right",
            use_fast=False,
        )
    # Pad-token setup and conversation-template selection depend on the
    # prompt version. The template is only chosen in the final branch.
    if model_args.version == "v0":
        if tokenizer.pad_token is None:
            smart_tokenizer_and_embedding_resize(
                special_tokens_dict=dict(pad_token="[PAD]"),
                tokenizer=tokenizer,
                model=model,
            )
    elif model_args.version == "v0.5":
        tokenizer.pad_token = tokenizer.unk_token
    else:
        if tokenizer.unk_token:
            tokenizer.pad_token = tokenizer.unk_token
        else: # no unk token (e.g. Qwen): keep the tokenizer's own pad token
            tokenizer.legacy = False
        if model_args.version in conversation_lib.conv_templates:
            conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
        else:
            # Unknown version: fall back to the vicuna_v1 template.
            conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"]
    if model_args.vision_tower is not None:
        model.get_model().initialize_vision_modules(
            model_args=model_args,
            fsdp=training_args.fsdp
        )
        vision_tower = model.get_vision_tower()
        # NOTE(review): casts to float16 whenever bf16 is off, including
        # fp32 runs -- confirm this is intended.
        vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)
        # Expose the vision tower's image processor to the dataset.
        data_args.image_processor = vision_tower.image_processor
        data_args.is_multimodal = True
        # Persist data/tokenizer settings in the model config for inference.
        model.config.image_grid_pinpoints = data_args.image_grid_pinpoints
        model.config.image_aspect_ratio = data_args.image_aspect_ratio
        model.config.tokenizer_padding_side = tokenizer.padding_side
        model.config.tokenizer_model_max_length = tokenizer.model_max_length
        model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
        # Projector-only tuning: freeze everything, then unfreeze mm_projector.
        if model_args.tune_mm_mlp_adapter:
            model.requires_grad_(False)
            for p in model.get_model().mm_projector.parameters():
                p.requires_grad = True
        model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
        if training_args.freeze_mm_mlp_adapter:
            for p in model.get_model().mm_projector.parameters():
                p.requires_grad = False
        if training_args.bits in [4, 8]:
            model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device)
        model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
        model.config.mm_projector_lr = training_args.mm_projector_lr
        training_args.use_im_start_end = model_args.mm_use_im_start_end
        model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
        model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
    # Quantized runs: fix module dtypes after adapters were inserted.
    # nn.Module.to() casts in place, so rebinding `module` has no other effect.
    if training_args.bits in [4, 8]:
        from peft.tuners.lora import LoraLayer
        for name, module in model.named_modules():
            if isinstance(module, LoraLayer):
                if training_args.bf16:
                    module = module.to(torch.bfloat16)
            if 'norm' in name:
                module = module.to(torch.float32)
            if 'lm_head' in name or 'embed_tokens' in name:
                if hasattr(module, 'weight'):
                    if training_args.bf16 and module.weight.dtype == torch.float32:
                        module = module.to(torch.bfloat16)
    data_module = make_supervised_data_module(tokenizer=tokenizer,
                                              data_args=data_args)
    trainer = LLaVATrainer(model=model,
                           tokenizer=tokenizer,
                           args=training_args,
                           **data_module)
    # Resume automatically when output_dir already holds checkpoint-* entries.
    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()
    model.config.use_cache = True
    if training_args.lora_enable:
        # Collected on every rank (presumably required for ZeRO-3 parameter
        # gathering -- confirm); only rank 0 / non-distributed runs write.
        state_dict = get_peft_state_maybe_zero_3(
            model.named_parameters(), training_args.lora_bias
        )
        non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
            model.named_parameters()
        )
        if training_args.local_rank == 0 or training_args.local_rank == -1:
            model.config.save_pretrained(training_args.output_dir)
            model.save_pretrained(training_args.output_dir, state_dict=state_dict)
            torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
    else:
        safe_save_model_for_hf_trainer(trainer=trainer,
                                       output_dir=training_args.output_dir)
if __name__ == "__main__":
    # Script entry point. Passing None leaves the attention implementation
    # choice to from_pretrained.
    train(attn_implementation=None)
================================================
FILE: llava/train/train_xformers.py
================================================
# Make it more memory efficient by monkey patching the LLaMA model with xformers attention.
# NOTE(review): the patch is applied AFTER `llava.train.train` (and therefore
# transformers) has been imported, but before train() builds any model.
# Presumably it replaces attributes on the transformers Llama attention code,
# which works regardless of import order -- confirm in
# llava/train/llama_xformers_attn_monkey_patch.py.
# NOTE(review): train() instantiates Qwen2/MPT model classes, so a Llama-only
# attention patch may have no effect on those models -- confirm.
from llava.train.train import train
from llava.train.llama_xformers_attn_monkey_patch import (
    replace_llama_attn_with_xformers_attn,
)

replace_llama_attn_with_xformers_attn()

if __name__ == "__main__":
    train()
================================================
FILE: llava/utils.py
================================================
import datetime
import logging
import logging.handlers
import os
import sys
import requests
from llava.constants import LOGDIR
# Shared file handler created lazily by build_logger (None until first call).
handler = None

# User-facing message returned when a backend/network request fails.
server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
# User-facing message returned when input is rejected by moderation.
moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
def build_logger(logger_name, logger_filename):
    """Create (or fetch) a logger and route process output through logging.

    Side effects on every call:
      * sets the formatter on the root logger's first handler, installing a
        basic INFO handler first if the root logger has none;
      * replaces ``sys.stdout`` / ``sys.stderr`` with ``StreamToLogger``
        wrappers feeding the "stdout" (INFO) and "stderr" (ERROR) loggers;
      * attaches one shared, daily-rotating file handler to every registered
        logger that does not already have it.

    The file handler is created once per process, under ``LOGDIR`` and named
    after the ``logger_filename`` of the first call; later calls reuse it.

    Returns:
        The ``logging.Logger`` named ``logger_name``, set to INFO level.
    """
    global handler

    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Set the format of root handlers
    if not logging.getLogger().handlers:
        logging.basicConfig(level=logging.INFO)
    logging.getLogger().handlers[0].setFormatter(formatter)

    # Redirect stdout and stderr to loggers
    stdout_logger = logging.getLogger("stdout")
    stdout_logger.setLevel(logging.INFO)
    sl = StreamToLogger(stdout_logger, logging.INFO)
    sys.stdout = sl

    stderr_logger = logging.getLogger("stderr")
    stderr_logger.setLevel(logging.ERROR)
    sl = StreamToLogger(stderr_logger, logging.ERROR)
    sys.stderr = sl

    # Get logger
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)

    # Create the shared file handler once per process.
    if handler is None:
        os.makedirs(LOGDIR, exist_ok=True)
        filename = os.path.join(LOGDIR, logger_filename)
        handler = logging.handlers.TimedRotatingFileHandler(
            filename, when='D', utc=True, encoding='UTF-8')
        handler.setFormatter(formatter)

    # Add the file handler to all loggers, including ones registered since it
    # was first built (e.g. `logger_name` on a later call). Skip loggers that
    # already have it, so lines are not duplicated. Iterate over a snapshot
    # because other threads may register loggers concurrently.
    for item in list(logging.root.manager.loggerDict.values()):
        if isinstance(item, logging.Logger) and handler not in item.handlers:
            item.addHandler(handler)

    return logger
class StreamToLogger(object):
    """
    Fake file-like stream object that redirects writes to a logger instance.

    Written text is buffered until a newline arrives. Each completed line is
    logged, with trailing whitespace stripped, at ``log_level``. Attributes
    not defined here are delegated to the stream that was ``sys.stdout``
    when the instance was created.
    """
    def __init__(self, logger, log_level=logging.INFO):
        self.terminal = sys.stdout
        self.logger = logger
        self.log_level = log_level
        self.linebuf = ''

    def __getattr__(self, attr):
        # Only reached for attributes not found normally (e.g. fileno, isatty).
        return getattr(self.terminal, attr)

    def write(self, buf):
        """Log every complete line of ``linebuf + buf``; keep the remainder.

        Returns the number of characters accepted, like a text stream.
        """
        # From the io.TextIOWrapper docs:
        #   On output, if newline is None, any '\n' characters written
        #   are translated to the system default line separator.
        # By default sys.stdout.write() expects '\n' newlines and then
        # translates them so this is still cross platform.
        #
        # Split on '\n' only. str.splitlines() also breaks on '\r', '\x0b',
        # '\u2028', ... Pieces ending in those were pushed back into the
        # buffer mid-loop and later emitted out of order, or accumulated
        # without bound (e.g. '\r'-driven progress bars).
        *complete_lines, self.linebuf = (self.linebuf + buf).split('\n')
        for line in complete_lines:
            self.logger.log(self.log_level, line.rstrip())
        return len(buf)

    def flush(self):
        """Log any buffered partial line."""
        if self.linebuf != '':
            self.logger.log(self.log_level, self.linebuf.rstrip())
            self.linebuf = ''
def disable_torch_init():
    """Skip torch's default weight initialization for Linear and LayerNorm.

    Replaces ``reset_parameters`` on both classes with a no-op to speed up
    model construction. This is a process-wide monkey patch.
    """
    import torch

    def _skip_init(self):
        return None

    for layer_cls in (torch.nn.Linear, torch.nn.LayerNorm):
        setattr(layer_cls, "reset_parameters", _skip_init)
def violates_moderation(text):
    """
    Check whether the text violates OpenAI moderation API.

    Newlines are removed from ``text`` before submission. This is a
    best-effort check: network errors and unexpected or malformed responses
    are treated as "not flagged".

    Returns:
        The ``flagged`` value reported by the API, or ``False`` if the
        request or response handling failed.
    Raises:
        KeyError: if the ``OPENAI_API_KEY`` environment variable is unset.
    """
    import json

    url = "https://api.openai.com/v1/moderations"
    headers = {"Content-Type": "application/json",
               "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
    text = text.replace("\n", "")
    # Serialize with json.dumps so quotes, backslashes and control characters
    # in `text` are escaped. Hand-building the body produced invalid JSON for
    # any input containing '"' or '\\'.
    data = json.dumps({"input": text}).encode("utf-8")
    try:
        ret = requests.post(url, headers=headers, data=data, timeout=5)
        flagged = ret.json()["results"][0]["flagged"]
    except requests.exceptions.RequestException:
        flagged = False
    except (KeyError, IndexError, TypeError, ValueError):
        # Unexpected response shape, or a non-JSON body (ValueError).
        flagged = False
    return flagged
def pretty_print_semaphore(semaphore):
    """Return a short human-readable description of ``semaphore``.

    Reads the private ``_value`` counter and calls ``locked()``, so the
    argument must provide both (as ``asyncio.Semaphore`` does, for example).
    ``None`` is rendered as the string ``"None"``.
    """
    if semaphore is not None:
        value = semaphore._value
        is_locked = semaphore.locked()
        return f"Semaphore(value={value}, locked={is_locked})"
    return "None"
================================================
FILE: model_export/README.md
================================================
# Model Export for inference on Apple Silicon
Disclaimer: this is not an official recommendation, just research and exploration.
## Export Vision Encoder
We found that the LLaVA trainer does not save all the state needed for automatic loading at inference time,
which third-party libraries such as `mlx-vlm` predominantly rely on. We save the additional metadata
to the model checkpoint directory and export the vision model using coremltools.
Export the vision encoder and patch the checkpoint using the command below.
```bash
python export_vision_encoder.py --model-path /path/to/fastvlm-checkpoint
```
## Export VLM
### Install mlx-vlm
We provide a patch to `mlx-vlm` to support inference of FastVLM.
```bash
git clone https://github.com/Blaizzy/mlx-vlm.git
cd mlx-vlm
git checkout 1884b551bc741f26b2d54d68fa89d4e934b9a3de
git apply ../fastvlm_mlx-vlm.patch
pip install -e .
```
Export the model using the following command.
```bash
python -m mlx_vlm.convert --hf-path /path/to/fastvlm-checkpoint \
--mlx-path /path/to/exported-fastvlm \
--only-llm
```
To quantize the LLM, pass the additional options shown below.
`--q-bits` specifies the number of bits per weight; the command below exports the LLM with 8-bit quantization.
```bash
python -m mlx_vlm.convert --hf-path /path/to/fastvlm-checkpoint \
--mlx-path /path/to/exported-fastvlm \
--only-llm \
-q \
--q-bits 8 # For 4-bit quantization, specify 4
```
### Generate
The exported model can be used for inference in a Python environment with the command below.
```bash
python -m mlx_vlm.generate --model /path/to/exported-fastvlm \
--image /path/to/image.png \
--prompt "Describe the image." \
--max-tokens 256 \
--temp 0.0
```
## Troubleshooting
We noticed that `config.json` for the LLaVA model sometimes sets `tie_word_embeddings` incorrectly.
This causes the following error during conversion: `ValueError: Received parameters not in model: language_model.lm_head.weight.`
If you encounter this error, correct `tie_word_embeddings` in `config.json`. The error indicates that the checkpoint stores a separate `lm_head` weight, which typically means the value should be `false`.
================================================
FILE: model_export/export_vision_encoder.py
================================================
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2025 Apple Inc. All Rights Reserved.
#
import os
import json
import copy
import argparse
import torch
import numpy as np
import coremltools
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import get_model_name_from_path
def _read_json(path):
    """Return the parsed JSON document stored at ``path``."""
    with open(path, "r") as f:
        return json.load(f)


def _write_json(obj, path):
    """Serialize ``obj`` to ``path`` as 2-space-indented JSON."""
    with open(path, "w") as f:
        json.dump(obj, f, indent=2)


def _save_processor_configs(image_processor, model_path, image_token):
    """Write preprocessor_config.json and processor_config.json into ``model_path``."""
    # Save image processing config
    setattr(image_processor, "processor_class", "LlavaProcessor")
    image_processor.to_json_file(os.path.join(model_path, "preprocessor_config.json"))

    # Create processor config
    processor_config = dict()
    processor_config["image_token"] = image_token
    processor_config["num_additional_image_tokens"] = 0
    processor_config["processor_class"] = "LlavaProcessor"
    # NOTE(review): 64 presumably matches the vision encoder's overall
    # downsampling factor -- confirm.
    processor_config["patch_size"] = 64
    _write_json(processor_config, os.path.join(model_path, "processor_config.json"))


def _ensure_image_token(model_path, image_token):
    """Ensure tokenizer_config.json declares ``image_token``; return its token id.

    If the token already exists, its real id is returned and the file is left
    untouched. Otherwise it is added at (highest existing id + 1), copying the
    first added token's entry as a template, and the file is rewritten.

    Raises:
        ValueError: if ``added_tokens_decoder`` is empty (no template entry).
    """
    tokenizer_config_path = os.path.join(model_path, "tokenizer_config.json")
    tokenizer_config = _read_json(tokenizer_config_path)
    added_tokens = tokenizer_config['added_tokens_decoder']

    for token_id, entry in added_tokens.items():
        if entry["content"] == image_token:
            return int(token_id)

    if not added_tokens:
        raise ValueError(
            f"{tokenizer_config_path} has no added_tokens_decoder entries to "
            f"use as a template for the image token")

    # Append only if token is not present
    template_key = next(iter(added_tokens))
    new_id = max(int(k) for k in added_tokens) + 1
    new_entry = copy.deepcopy(added_tokens[template_key])
    new_entry["content"] = image_token
    added_tokens[f'{new_id}'] = new_entry
    _write_json(tokenizer_config, tokenizer_config_path)
    return new_id


def _set_image_token_index(model_path, image_token_index):
    """Record the image token id as ``image_token_index`` in config.json."""
    config_path = os.path.join(model_path, "config.json")
    model_config = _read_json(config_path)
    model_config["image_token_index"] = image_token_index
    _write_json(model_config, config_path)


def _export_vision_encoder_coreml(model, image_processor, model_path):
    """Trace the vision tower and save it as ``<model_path>/fastvithd.mlpackage``."""
    import tempfile

    image_res = image_processor.to_dict()['size']['shortest_edge']
    inputs = torch.rand(1, 3, image_res, image_res)
    inputs_tensor = [
        coremltools.TensorType(
            name="images",
            shape=inputs.shape,
        )
    ]
    vision_model = model.get_vision_tower()
    vision_model = vision_model.float()

    # Trace into a private temp dir. It is always removed, even if conversion
    # fails, and it never clobbers a file in the working directory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        pt_name = os.path.join(tmp_dir, "fastvithd.pt")
        traced_model = torch.jit.trace(vision_model, inputs)
        traced_model.save(pt_name)
        ml_model = coremltools.convert(
            model=pt_name,
            outputs=[coremltools.TensorType(name="image_features", dtype=np.float32)],
            inputs=inputs_tensor,
            convert_to="mlprogram",
            debug=False,
            compute_units=coremltools.ComputeUnit.CPU_AND_GPU,
            minimum_deployment_target=coremltools.target.iOS16,
            compute_precision=coremltools.precision.FLOAT32
        )
    ml_model.save(os.path.join(model_path, "fastvithd.mlpackage"))


def export(args):
    """Patch a FastVLM checkpoint for mlx-vlm and export its vision encoder.

    Steps:
      1. Load the checkpoint at ``args.model_path`` (optionally with
         ``args.model_base``) on the "mps" device.
      2. Write ``preprocessor_config.json`` and ``processor_config.json``.
      3. Ensure ``tokenizer_config.json`` declares the image token, and store
         its id as ``image_token_index`` in ``config.json``.
      4. Trace the vision tower and save it as ``fastvithd.mlpackage`` in the
         checkpoint directory.
    """
    # The same placeholder the training code inserts into prompts.
    from llava.constants import DEFAULT_IMAGE_TOKEN

    # Load model
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path,
                                                                           args.model_base,
                                                                           model_name,
                                                                           device="mps")

    # Save extra metadata that is not saved during LLaVA training,
    # required by HF for auto-loading the model and for mlx-vlm preprocessing.
    _save_processor_configs(image_processor, model_path, DEFAULT_IMAGE_TOKEN)
    image_token_index = _ensure_image_token(model_path, DEFAULT_IMAGE_TOKEN)
    _set_image_token_index(model_path, image_token_index)

    # Export the vision encoder to CoreML
    _export_vision_encoder_coreml(model, image_processor, model_path)
if __name__ == "__main__":
    # CLI: python export_vision_encoder.py --model-path /path/to/checkpoint
    # (--conv-mode is accepted but export() does not read it.)
    cli = argparse.ArgumentParser()
    for flag, options in (
        ("--model-path", dict(type=str, required=True)),
        ("--model-base", dict(type=str, default=None)),
        ("--conv-mode", dict(type=str, default="qwen_2")),
    ):
        cli.add_argument(flag, **options)
    export(cli.parse_args())
================================================
FILE: model_export/fastvlm_mlx-vlm.patch
================================================
diff --git a/mlx_vlm/convert.py b/mlx_vlm/convert.py
index 5952a88..335e9db 100644
--- a/mlx_vlm/convert.py
+++ b/mlx_vlm/convert.py
@@ -55,6 +55,12 @@ def configure_parser() -> argparse.ArgumentParser:
action="store_true",
default=False,
)
+ parser.add_argument(
+ "--only-llm",
+ help="Convert only LLM.",
+ action="store_true",
+ default=False,
+ )
return parser
diff --git a/mlx_vlm/models/fastvlm/__init__.py b/mlx_vlm/models/fastvlm/__init__.py
new file mode 100644
index 0000000..691192e
--- /dev/null
+++ b/mlx_vlm/models/fastvlm/__init__.py
@@ -0,0 +1,7 @@
+from .fastvlm import (
+ LanguageModel,
+ Model,
+ ModelConfig,
+ TextConfig,
+ VisionConfig,
+)
diff --git a/mlx_vlm/models/fastvlm/fastvlm.py b/mlx_vlm/models/fastvlm/fastvlm.py
new file mode 100644
index 0000000..7db6497
--- /dev/null
+++ b/mlx_vlm/models/fastvlm/fastvlm.py
@@ -0,0 +1,187 @@
+import glob
+import inspect
+import json
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+import numpy as np
+import coremltools
+from huggingface_hub import snapshot_download
+
+from .language import LanguageModel, TextConfig
+
+
+@dataclass
+class VisionConfig:
+ mm_hidden_size: int
+ mm_vision_tower: str
+
+ @classmethod
+ def from_dict(cls, params):
+ return cls(
+ **{
+ k: v
+ for k, v in params.items()
+ if k in inspect.signature(cls).parameters
+ }
+ )
+
+@dataclass
+class ModelConfig:
+ text_config: TextConfig
+ vision_config: VisionConfig
+ model_type: str
+ ignore_index: int = -100
+ image_token_index: int = 32000
+ vision_feature_select_strategy: str = "default"
+ vision_feature_layer: int = -2
+ vocab_size: int = 151936
+
+ @classmethod
+ def from_dict(cls, params):
+ # Copy text config parameters from root level
+ params["text_config"] = dict(
+ filter(lambda x: 'mm' not in x[0], params.items())
+ )
+ # Copy vision config parameters from root level
+ params["vision_config"] = dict(
+ filter(lambda x: 'mm' in x[0], params.items())
+ )
+
+ return cls(
+ **{
+ k: v
+ for k, v in params.items()
+ if k in inspect.signature(cls).parameters
+ }
+ )
+
+
+class FastVLMMultiModalProjector(nn.Module):
+ def __init__(self, config: ModelConfig):
+ super().__init__()
+ self.linear_0 = nn.Linear(
+ config.vision_config.mm_hidden_size, config.text_config.hidden_size, bias=True
+ )
+ self.gelu = nn.GELU()
+ self.linear_2 = nn.Linear(
+ config.text_config.hidden_size, config.text_config.hidden_size, bias=True
+ )
+
+ def __call__(self, x: mx.array) -> mx.array:
+ x = self.linear_0(x)
+ x = self.gelu(x)
+ x = self.linear_2(x)
+ return x
+
+
+class Model(nn.Module):
+ def __init__(self, config: ModelConfig):
+ super().__init__()
+ self.config = config
+ self.vision_tower = None
+ self.language_model = LanguageModel(config.text_config)
+ self.multi_modal_projector = FastVLMMultiModalProjector(config)
+ self.vision_feature_layer = config.vision_feature_layer
+ self.vision_feature_select_strategy = config.vision_feature_select_strategy
+
+ def get_input_embeddings(
+ self,
+ input_ids: Optional[mx.array] = None,
+ pixel_values: Optional[mx.array] = None,
+ ):
+ if pixel_values is None:
+ return self.language_model.model.embed_tokens(input_ids)
+
+ # Get the input embeddings from the language model
+ inputs_embeds = self.language_model.model.embed_tokens(input_ids)
+
+ # Get image features from CoreML model
+ coreml_out_dict = self.vision_tower.predict({"images": np.array(pixel_values, copy=False)})
+
+ # Pass image features through the multi-modal projector
+ image_features = self.multi_modal_projector(mx.array(coreml_out_dict["image_features"]))
+
+ # Insert special image tokens in the input_ids
+ final_inputs_embeds = self._merge_input_ids_with_image_features(
+ image_features, inputs_embeds, input_ids
+ )
+ return final_inputs_embeds
+
+ def _merge_input_ids_with_image_features(
+ self, image_features, inputs_embeds, input_ids
+ ):
+ image_token_index = self.config.image_token_index
+ num_images, num_image_patches, embed_dim = image_features.shape
+
+ # Positions of tokens in input_ids, assuming batch size is 1
+ image_positions = np.where(input_ids[0] == image_token_index)[0].tolist()
+ num_images, _, vision_hidden_size = image_features.shape
+
+ reshaped_image_hidden_states = image_features.reshape(-1, vision_hidden_size)
+
+ # cast to the dtype of the input_embeds to support quantized models
+ reshaped_image_hidden_states = reshaped_image_hidden_states.astype(
+ inputs_embeds.dtype
+ )
+ inputs_embeds[:, image_positions, :] = reshaped_image_hidden_states
+ return inputs_embeds
+
+ def __call__(
+ self,
+ input_ids: mx.array,
+ pixel_values: mx.array,
+ mask: mx.array,
+ cache=None,
+ **kwargs,
+ ):
+ input_embddings = self.get_input_embeddings(input_ids, pixel_values)
+ logits = self.language_model(
+ input_ids, cache=cache, inputs_embeds=input_embddings
+ )
+ return logits
+
+ @staticmethod
+ def from_pretrained(path_or_hf_repo: str):
+ path = Path(path_or_hf_repo)
+ if not path.exists():
+ path = Path(
+ snapshot_download(
+ repo_id=path_or_hf_repo,
+ allow_patterns=[
+ "*.json",
+ "*.safetensors",
+ "*.py",
+ "tokenizer.model",
+ "*.tiktoken",
+ ],
+ )
+ )
+
+ with open(path / "config.json", "r") as f:
+ model_config = json.load(f)
+
+ model_config = ModelConfig.from_dict(model_config)
+ model_config.text_config = TextConfig.from_dict(model_config.text_config)
+
+ model = Model(model_config)
+ weight_files = glob.glob(str(path / "*.safetensors"))
+ if not weight_files:
+ raise FileNotFoundError(f"No safetensors found in {path}")
+
+ weights = {}
+ for wf in weight_files:
+ weights.update(mx.load(wf))
+
+ weights = LanguageModel.sanitize(weights)
+
+ # Load CoreML vision tower
+ coreml_file = glob.glob(str(path / "*.mlpackage"))
+ assert len(coreml_file) == 1, "Found multiple vision model files"
+ model.vision_tower = coremltools.models.MLModel(coreml_file[0])
+
+ model.load_weights(list(weights.items()))
+ return model
diff --git a/mlx_vlm/models/fastvlm/language.py b/mlx_vlm/models/fastvlm/language.py
new file mode 100644
index 0000000..b791df4
--- /dev/null
+++ b/mlx_vlm/models/fastvlm/language.py
@@ -0,0 +1,220 @@
+import inspect
+from dataclasses import dataclass
+from typing import Dict, Optional, Tuple, Union
+
+import mlx.core as mx
+import mlx.nn as nn
+import numpy as np
+
+from ..base import KVCache, LanguageModelOutput, create_attention_mask
+
+
+@dataclass
+class TextConfig:  # text-decoder hyper-parameters; built via from_dict() from the model config's text_config
+    model_type: str
+    hidden_size: int
+    num_hidden_layers: int
+    intermediate_size: int
+    num_attention_heads: int
+    rms_norm_eps: float
+    vocab_size: int
+    num_key_value_heads: Optional[int] = None  # None -> same as num_attention_heads (set in __post_init__)
+    max_position_embeddings: Optional[int] = 32768
+    rope_theta: float = 1000000
+    rope_traditional: bool = False
+    rope_scaling: Optional[Dict[str, Union[float, str]]] = None  # validated below, but Attention here uses plain nn.RoPE and ignores it
+    tie_word_embeddings: bool = True  # True -> embed_tokens doubles as the output projection
+
+    def __post_init__(self):
+        if self.num_key_value_heads is None:
+            self.num_key_value_heads = self.num_attention_heads
+
+        if self.rope_scaling:
+            required_keys = {"mrope_section", "type"}
+            if not all(key in self.rope_scaling for key in required_keys):
+                raise ValueError(f"rope_scaling must contain keys {required_keys}")
+
+            if self.rope_scaling["type"] not in ("mrope", "default"):
+                raise ValueError(f"rope_scaling type must be 'mrope' or 'default', got {self.rope_scaling['type']!r}")
+
+    @classmethod
+    def from_dict(cls, params):  # keys that are not dataclass fields are silently dropped
+        return cls(
+            **{
+                k: v
+                for k, v in params.items()
+                if k in inspect.signature(cls).parameters
+            }
+        )
+
+
+class Attention(nn.Module):  # self-attention with RoPE; k/v may use fewer heads than q (grouped-query attention)
+    def __init__(self, args: TextConfig):
+        super().__init__()
+
+        dim = args.hidden_size
+        self.n_heads = n_heads = args.num_attention_heads
+        assert args.num_key_value_heads is not None  # TextConfig.__post_init__ fills this in
+        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
+
+        self.head_dim = head_dim = args.hidden_size // n_heads
+        self.scale = head_dim**-0.5  # 1/sqrt(head_dim) score scaling
+
+        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
+        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
+        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
+        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)  # biases on q/k/v only, none on o_proj
+
+        self.rotary_emb = nn.RoPE(
+            head_dim,
+            base=args.rope_theta,
+            traditional=args.rope_traditional,
+        )
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[KVCache] = None,
+    ) -> mx.array:
+        B, L, D = x.shape  # (batch, new tokens, hidden)
+
+        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+
+        # Prepare the queries, keys and values for the attention computation
+        queries = queries.reshape(B, L, self.n_heads, self.head_dim).transpose(
+            0, 2, 1, 3
+        )
+        keys = keys.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
+        values = values.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(
+            0, 2, 1, 3
+        )
+
+        offset = cache.offset if cache else 0  # RoPE position of the first new token
+
+        queries = self.rotary_emb(queries, offset=offset)
+        keys = self.rotary_emb(keys, offset=offset)
+
+        if cache is not None:
+            keys, values = cache.update_and_fetch(keys, values)
+
+        if mask is not None:  # slice AFTER the cache update so the mask width matches cached + new keys
+            mask = mask[..., : keys.shape[-2]]
+
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.o_proj(output)
+
+
+class MLP(nn.Module):  # gated feed-forward: down_proj(silu(gate_proj(x)) * up_proj(x))
+    def __init__(self, dim, hidden_dim):  # dim = model width, hidden_dim = expanded inner width
+        super().__init__()
+        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
+        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
+        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
+
+    def __call__(self, x) -> mx.array:
+        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))  # output has the same width as x
+
+
+class Qwen2DecoderLayer(nn.Module):  # pre-norm block: RMSNorm -> attention -> residual, RMSNorm -> MLP -> residual
+    def __init__(self, args: TextConfig):
+        super().__init__()
+        self.num_attention_heads = args.num_attention_heads
+        self.hidden_size = args.hidden_size
+        self.self_attn = Attention(args)
+        self.mlp = MLP(args.hidden_size, args.intermediate_size)
+        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.post_attention_layernorm = nn.RMSNorm(
+            args.hidden_size, eps=args.rms_norm_eps
+        )
+        self.args = args
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[KVCache] = None,  # this layer's own cache entry (or None)
+    ) -> mx.array:
+        r = self.self_attn(self.input_layernorm(x), mask, cache)
+        h = x + r  # residual around attention
+        r = self.mlp(self.post_attention_layernorm(h))
+        out = h + r  # residual around the MLP
+        return out
+
+
+class Qwen2Model(nn.Module):  # embeddings -> decoder layers -> final RMSNorm; returns hidden states, not logits
+    def __init__(self, args: TextConfig):
+        super().__init__()
+        self.args = args
+        self.vocab_size = args.vocab_size
+        self.num_hidden_layers = args.num_hidden_layers
+        assert self.vocab_size > 0
+        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+        self.layers = [
+            Qwen2DecoderLayer(args=args) for _ in range(args.num_hidden_layers)
+        ]
+        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        cache=None,  # None, or a sequence with one cache entry per layer
+        inputs_embeds: Optional[mx.array] = None,
+    ):
+        if inputs_embeds is None:  # text-only path: embed the token ids
+            h = self.embed_tokens(inputs)
+        else:  # caller supplied embeddings (e.g. with image features merged in); inputs is then ignored
+            h = inputs_embeds
+
+        mask = create_attention_mask(h, cache)
+
+        if cache is None:
+            cache = [None] * len(self.layers)  # run every layer without a cache
+
+        for layer, c in zip(self.layers, cache):
+            h = layer(h, mask, c)
+
+        return self.norm(h)
+
+
+class LanguageModel(nn.Module):  # Qwen2 decoder + output projection; __call__ returns LanguageModelOutput(logits=...)
+    def __init__(self, args: TextConfig):
+        super().__init__()
+        if "qwen2" not in args.model_type:  # fail fast, before allocating any decoder layers
+            raise ValueError(f"Unsupported model type: {args.model_type}")
+
+        self.args = args
+        self.model_type = args.model_type
+        self.model = Qwen2Model(args)
+
+        if not args.tie_word_embeddings:  # tied models reuse embed_tokens as the output projection instead
+            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        cache=None,
+        inputs_embeds: Optional[mx.array] = None,
+        mask: Optional[mx.array] = None,  # unused: Qwen2Model builds its own attention mask
+    ):
+        out = self.model(inputs, cache=cache, inputs_embeds=inputs_embeds)
+        if self.args.tie_word_embeddings:
+            out = self.model.embed_tokens.as_linear(out)
+        else:
+            out = self.lm_head(out)
+        return LanguageModelOutput(logits=out)
+
+    @property
+    def layers(self):
+        return self.model.layers
+
+    @property
+    def head_dim(self):
+        return self.args.hidden_size // self.args.num_attention_heads
+
+    @property
+    def n_kv_heads(self):
+        return self.args.num_key_value_heads
diff --git a/mlx_vlm/prompt_utils.py b/mlx_vlm/prompt_utils.py
index 725e811..ba48296 100644
--- a/mlx_vlm/prompt_utils.py
+++ b/mlx_vlm/prompt_utils.py
@@ -93,6 +93,7 @@ def get_message_json(
"idefics2": "message_list_with_image",
"idefics3": "message_list_with_image",
"llava": "message_list_with_image",
+ "llava_qwen2": "message_with_image_token_new_line",
"llava_next": "message_list_with_image",
"mllama": "message_list_with_image",
# Models that can handle both image and video formats
@@ -143,7 +144,7 @@ def get_message_json(
def get_chat_template(processor, messages, add_generation_prompt, tokenize=False):
- if "chat_template" in processor.__dict__.keys():
+ if ("chat_template" in processor.__dict__.keys()) and (processor.chat_template is not None):
return processor.apply_chat_template(
messages,
tokenize=tokenize,
diff --git a/mlx_vlm/utils.py b/mlx_vlm/utils.py
index 4acff3e..00f366f 100644
--- a/mlx_vlm/utils.py
+++ b/mlx_vlm/utils.py
@@ -1,3 +1,4 @@
+import os
import copy
import glob
import importlib
@@ -15,6 +16,7 @@ import mlx.core as mx
import mlx.nn as nn
import numpy as np
import requests
+import coremltools
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten, tree_unflatten
from PIL import Image, ImageOps
@@ -31,7 +33,7 @@ from .tokenizer_utils import load_tokenizer
from .trainer import apply_lora_layers
# Constants
-MODEL_REMAPPING = {"llava-qwen2": "llava_bunny", "bunny-llama": "llava_bunny"}
+MODEL_REMAPPING = {"llava-qwen2": "llava_bunny", "bunny-llama": "llava_bunny", "llava_qwen2": "fastvlm"}
MAX_FILE_SIZE_GB = 5
@@ -168,9 +170,19 @@ python -m mlx_vlm.convert --hf-path --mlx-path
# Sanitize weights
weights = sanitize_weights(model, weights)
- weights = sanitize_weights(
- model_class.VisionModel, weights, model_config.vision_config
- )
+ if hasattr(model_class, 'VisionModel'):
+ weights = sanitize_weights(
+ model_class.VisionModel, weights, model_config.vision_config
+ )
+ else:
+ # Load CoreML vision tower
+ print("Looking for CoreML vision tower")
+ coreml_file = glob.glob(str(model_path / "*.mlpackage"))
+ if len(coreml_file) > 0:
+ assert len(coreml_file) == 1, "Found multiple vision model files."
+ print(f"Loading {coreml_file[0]} vision tower")
+ model.vision_tower = coremltools.models.MLModel(coreml_file[0])
+
weights = sanitize_weights(
model_class.LanguageModel, weights, model_config.text_config
)
@@ -185,7 +197,21 @@ python -m mlx_vlm.convert --hf-path --mlx-path
class_predicate=class_predicate,
)
- model.load_weights(list(weights.items()))
+ if kwargs.get("only_llm", False):
+ # Ignore vision tower weights
+ new_weights = dict()
+ for k, v in weights.items():
+ if 'vision_tower' in k:
+ continue
+ if 'mm_projector' in k:
+ new_k = k.replace('model.mm_projector.', 'multi_modal_projector.linear_')
+ new_weights[new_k] = v
+ else:
+ new_weights['language_model.'+k] = v
+
+ model.load_weights(list(new_weights.items()))
+ else:
+ model.load_weights(list(weights.items()))
if not lazy:
mx.eval(model.parameters())
@@ -669,11 +695,12 @@ def convert(
dequantize: bool = False,
skip_vision: bool = False,
trust_remote_code: bool = True,
+ only_llm: bool = False
):
print("[INFO] Loading")
model_path = get_model_path(hf_path, revision=revision)
model, config, processor = fetch_from_hub(
- model_path, lazy=True, trust_remote_code=trust_remote_code
+ model_path, lazy=True, trust_remote_code=trust_remote_code, only_llm=only_llm
)
weights = dict(tree_flatten(model.parameters()))
@@ -709,6 +736,12 @@ def convert(
save_config(config, config_path=mlx_path / "config.json")
+ # Copy over any coreml files if found
+ coreml_files = glob.glob(str(model_path / "*.mlpackage"))
+ for file in coreml_files:
+ des_path = os.path.join(mlx_path, file.split(os.path.sep)[-1])
+ shutil.copytree(file, des_path)
+
if upload_repo is not None:
upload_to_hub(mlx_path, upload_repo, hf_path)
================================================
FILE: predict.py
================================================
#
# Modified from LLaVA/predict.py
# Please see ACKNOWLEDGEMENTS for details about LICENSE
#
import os
import argparse
import torch
from PIL import Image
from llava.utils import disable_torch_init
from llava.conversation import conv_templates
from llava.model.builder import load_pretrained_model
from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
def predict(args):
    """Run single-image inference with a FastVLM/LLaVA checkpoint and print the reply.

    Args:
        args: Parsed command-line namespace providing ``model_path``,
            ``model_base``, ``image_file``, ``prompt``, ``conv_mode``,
            ``temperature``, ``top_p`` and ``num_beams``. ``max_new_tokens``
            is optional and defaults to 256 when absent.

    The checkpoint's ``generation_config.json`` is temporarily renamed to
    ``.generation_config.json`` so generation parameters come from ``args``.
    It is restored when the function exits, including when loading or
    inference raises.
    """
    model_path = os.path.expanduser(args.model_path)
    original_config = os.path.join(model_path, 'generation_config.json')
    generation_config = None
    if os.path.exists(original_config):
        generation_config = os.path.join(model_path, '.generation_config.json')
        os.rename(original_config, generation_config)

    try:
        # Load model
        disable_torch_init()
        model_name = get_model_name_from_path(model_path)
        tokenizer, model, image_processor, context_len = load_pretrained_model(
            model_path, args.model_base, model_name, device="mps"
        )

        # Construct prompt
        qs = args.prompt
        if model.config.mm_use_im_start_end:
            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
        else:
            qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
        conv = conv_templates[args.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        # Set the pad token id for generation
        model.generation_config.pad_token_id = tokenizer.pad_token_id

        # Tokenize prompt
        input_ids = tokenizer_image_token(
            prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
        ).unsqueeze(0).to(torch.device("mps"))

        # Load and preprocess image
        image = Image.open(args.image_file).convert('RGB')
        image_tensor = process_images([image], image_processor, model.config)[0]

        # Run inference
        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=image_tensor.unsqueeze(0).half(),
                image_sizes=[image.size],
                do_sample=args.temperature > 0,
                temperature=args.temperature,
                top_p=args.top_p,
                num_beams=args.num_beams,
                max_new_tokens=getattr(args, "max_new_tokens", 256),
                use_cache=True)

        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
        print(outputs)
    finally:
        # Restore generation config even if loading or inference failed,
        # so the checkpoint folder is never left modified.
        if generation_config is not None:
            os.rename(generation_config, original_config)
if __name__ == "__main__":
    # Command-line entry point: parse options and run a single prediction.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="./llava-v1.5-13b")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--image-file", type=str, default=None, help="location of image file")
    parser.add_argument("--prompt", type=str, default="Describe the image.", help="Prompt for VLM.")
    parser.add_argument("--conv-mode", type=str, default="qwen_2")
    parser.add_argument("--temperature", type=float, default=0.2)
    # Underscore spellings kept for backward compatibility; hyphenated aliases
    # match the other flags. dest stays "top_p" / "num_beams".
    parser.add_argument("--top_p", "--top-p", type=float, default=None)
    parser.add_argument("--num_beams", "--num-beams", type=int, default=1)
    parser.add_argument("--max-new-tokens", type=int, default=256,
                        help="Maximum number of tokens to generate.")
    args = parser.parse_args()
    # Fail before the (slow) model load rather than crashing in Image.open(None).
    if args.image_file is None:
        parser.error("--image-file is required")
    predict(args)
================================================
FILE: pyproject.toml
================================================
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "llava"
version = "1.2.2.post1"
description = "Towards GPT-4 like large language and visual assistant."
readme = "README.md"
# The pinned dependencies set the real floor: gradio 5.x requires Python >= 3.10
# (torch 2.6, transformers 4.48 and numpy 1.26 need >= 3.9), so the previous
# ">=3.8" advertised interpreters on which this package cannot be installed.
requires-python = ">=3.10"
classifiers = [
    "Programming Language :: Python :: 3",
    "License :: OSI Approved :: Apache Software License",
]
dependencies = [
    "torch==2.6.0", "torchvision==0.21.0",
    "transformers==4.48.3", "tokenizers==0.21.0", "sentencepiece==0.1.99", "shortuuid",
    "accelerate==1.6.0", "peft>=0.10.0,<0.14.0", "bitsandbytes",
    "pydantic", "markdown2[all]", "numpy==1.26.4", "scikit-learn==1.2.2",
    "gradio==5.11.0", "requests", "uvicorn", "fastapi",
    "einops==0.6.1", "einops-exts==0.0.4", "timm==1.0.15",
    "coremltools==8.2"
]

[project.optional-dependencies]
train = ["deepspeed==0.13.1", "ninja", "wandb"]
build = ["build", "twine"]

[tool.setuptools.packages.find]
exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]

# NOTE(review): setuptools reads package selection from [tool.setuptools.*];
# this table does not appear to be consumed by setuptools — confirm before relying on it.
[tool.wheel]
exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]