Repository: LinkSoul-AI/LLaSM
Branch: main
Commit: 3443b79c372f
Files: 11
Total size: 39.8 KB
Directory structure:
gitextract_2906ffjf/
├── .gitignore
├── LICENSE
├── README.md
├── examples/
│ ├── 0.txt
│ ├── 1.txt
│ └── 2.txt
├── infer.py
├── infer_tokenize.py
├── llasm.py
├── logger.py
└── pyproject.toml
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# LLaSM: Large Language and Speech Model
[](https://huggingface.co/spaces/LinkSoul/LLaSM) [](https://huggingface.co/spaces/LinkSoul/LLaSM) [](https://github.com/LinkSoul-AI/LLaSM/blob/main/LICENSE) [](https://arxiv.org/abs/2308.15930) [](https://huggingface.co/spaces/LinkSoul/LLaSM) [](https://huggingface.co/datasets/LinkSoul/LLaSM-Audio-Instructions)
开源,可商用的**中英文双语语音-语言助手 LLaSM 以及中英文语音 SFT 数据集 LLaSM-Audio-Instructions**,第一个支持中英文语音-文本多模态对话的开源可商用对话模型。
## 模型框架

## 基础演示

## 在线试玩
> Talk is cheap, Show you the Demo.
- [Demo 地址 / Hugging Face Spaces](https://huggingface.co/spaces/LinkSoul/LLaSM)
## 论文
- arXiv 链接:https://arxiv.org/abs/2308.15930
## 资源下载
- Hugging Face模型下载:
- [LLaSM-Chinese-Llama-2-7B](https://huggingface.co/LinkSoul/LLaSM-Cllama2)
- [LLaSM-Baichuan-7B](https://huggingface.co/LinkSoul/LLaSM-Baichuan)
- 百度网盘下载:
- [LLaSM-Chinese-Llama-2-7B](https://pan.baidu.com/s/1PaipNDfqV7f3W1-tl5rwzA?pwd=2549)
- [LLaSM-Baichuan-7B](https://pan.baidu.com/s/1QZrXA8IJXclN77T4jM7tEw?pwd=y2p7)
- 语言模型:
- [Chinese-Llama-2-7b](https://github.com/LinkSoul-AI/Chinese-Llama-2-7b)
- [Baichuan-7B](https://huggingface.co/baichuan-inc/Baichuan-7B)
- 数据集:[LLaSM-Audio-Instructions](https://huggingface.co/datasets/LinkSoul/LLaSM-Audio-Instructions)
## 环境安装
```shell
# clone the repository
git clone https://github.com/LinkSoul-AI/LLaSM
cd LLaSM
# install package
conda create -n llasm python=3.10 -y
conda activate llasm
pip install --upgrade pip
pip install -e .
```
## 快速测试
- 下载 Whisper large v2 模型:https://huggingface.co/openai/whisper-large-v2
```shell
export LLASM_DEVICE="cuda:0"
python infer.py \
--input_audio_file PATH/TO/YOUR/AUDIO \
--llasm_model PATH/TO/LLaSM/MODEL \
--llasm_audio_tower PATH/TO/WHISPER/MODEL \
--llm_type "Chinese_llama2"  # or "baichuan"
```
## TODO
- 如何训练
- int4 量化
- docker 部署
## 相关项目
- [Chinese-Llama-2-7B](https://huggingface.co/LinkSoul/Chinese-Llama-2-7b)
- [Whisper](https://github.com/openai/whisper)
- [baichuan-inc/Baichuan-7B](https://huggingface.co/baichuan-inc/Baichuan-7B)
## 项目协议
[Apache-2.0 license](https://github.com/LinkSoul-AI/LLaSM/blob/main/LICENSE)
## Citation
如果您发现我们的工作和此仓库有用,欢迎给一个星星 :star: 鼓励我们一下 :beer::
```bibtex
@misc{shu2023llasm,
title={LLaSM: Large Language and Speech Model},
author={Yu Shu and Siwei Dong and Guangyao Chen and Wenhao Huang and Ruihua Zhang and Daochen Shi and Qiqi Xiang and Yemin Shi},
year={2023},
eprint={2308.15930},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```
## 微信交流群
================================================
FILE: examples/0.txt
================================================
请介绍一下勾股定理
================================================
FILE: examples/1.txt
================================================
世界上人口最多的是哪个国家
================================================
FILE: examples/2.txt
================================================
请介绍一下北京
================================================
FILE: infer.py
================================================
import os
import librosa
import argparse
import torch
from transformers import AutoTokenizer
from transformers import (
WhisperProcessor,
WhisperModel,
)
from llasm import LlaaaLlamaForCausalLM
from infer_tokenize import tokenize
from logger import print_signature
# Special tokens that delimit the audio span in the prompt built by main():
# one start token, CONFIG.audio_token_len patch placeholders, one end token.
# NOTE(review): all three literals are empty strings as shown here, so the
# patch/start/end tokens are indistinguishable (the ids looked up for them in
# main() would coincide). These look like values lost in transcription —
# confirm the intended token strings; they must match llasm.py and the
# tokenizer the model was trained with.
DEFAULT_AUDIO_PATCH_TOKEN = ""
DEFAULT_AUDIO_START_TOKEN = ""
DEFAULT_AUDIO_END_TOKEN = ""
class Setting:
    """Runtime settings for the LLaSM inference script.

    ``device`` is taken from the ``LLASM_DEVICE`` environment variable
    (falling back to ``"cuda"``); every other value is fixed.
    """

    def __init__(self):
        # Audio front-end: sample rate used when loading / featurizing the
        # waveform, and how many audio patch tokens go into the prompt.
        self.sampling_rate = 16000
        self.audio_token_len = 64
        # Language-model side: context length and the stop string trimmed
        # from the end of the decoded reply.
        self.llasm_context_len = 2048
        self.stop = ""
        # Target device for the models and tensors.
        self.device = os.environ.get("LLASM_DEVICE", "cuda")
# Module-level settings instance read by main() (device, sampling rate,
# audio token count, stop string).
CONFIG = Setting()
def _strip_stop_str(text, stop_str):
    """Strip *text*, drop one trailing *stop_str*, and strip again.

    An empty *stop_str* leaves the text untouched. Slicing with
    ``text[:-len(stop_str)]`` when ``stop_str == ""`` yields ``text[:-0]``,
    i.e. ``""``, which would discard the whole generated reply.
    """
    text = text.strip()
    if stop_str and text.endswith(stop_str):
        text = text[:-len(stop_str)]
    return text.strip()


def _read_reference_text(audio_path):
    """Return the transcript stored next to *audio_path* (same stem, ``.txt``).

    Each line is stripped and the lines are concatenated. Returns ``""`` when
    no transcript file exists, so arbitrary audio files can be used.
    """
    # splitext handles any extension length (.wav, .flac, none at all),
    # unlike chopping a fixed three characters off the path.
    label_path = os.path.splitext(audio_path)[0] + '.txt'
    if not os.path.isfile(label_path):
        return ''
    with open(label_path, 'r', encoding='utf-8') as fh:
        return ''.join(ln.strip() for ln in fh)


def main(args):
    """Answer one spoken query with LLaSM and print the exchange.

    Args:
        args: parsed CLI namespace with ``input_audio_file``, ``llasm_model``,
            ``llasm_audio_tower``, ``llm_type``, ``temperature`` and
            ``max_new_tokens``.
    """
    input_audio_file = args.input_audio_file
    temperature = args.temperature
    max_new_tokens = args.max_new_tokens

    # step0: load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.llasm_model)
    # step0-1: register the audio special tokens (patch / start / end)
    tokenizer.add_tokens([DEFAULT_AUDIO_PATCH_TOKEN], special_tokens=True)
    tokenizer.add_tokens([DEFAULT_AUDIO_START_TOKEN, DEFAULT_AUDIO_END_TOKEN], special_tokens=True)

    # step1: load model
    model = LlaaaLlamaForCausalLM.from_pretrained(
        args.llasm_model,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True).to(CONFIG.device)

    # step2: load audio processor
    audio_processor = WhisperProcessor.from_pretrained(args.llasm_audio_tower, torch_dtype=torch.float16)

    # step3: load audio tower
    audio_tower = WhisperModel.from_pretrained(
        args.llasm_audio_tower,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True).to(CONFIG.device)
    # step3-1: store the special-token ids on the audio tower config; the
    # model's forward() reads them from there to locate the audio span.
    audio_config = audio_tower.config
    audio_config.audio_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_AUDIO_PATCH_TOKEN])[0]
    audio_config.audio_start_token, audio_config.audio_end_token = tokenizer.convert_tokens_to_ids(
        [DEFAULT_AUDIO_START_TOKEN, DEFAULT_AUDIO_END_TOKEN])
    model.get_model().audio_tower[0] = audio_tower

    # step4: preprocess input audio. The two leading unsqueezes add the
    # (sample, clip) levels that the model's forward() iterates over.
    audio, _ = librosa.load(input_audio_file, sr=CONFIG.sampling_rate)
    audio_feat = audio_processor(audio, sampling_rate=CONFIG.sampling_rate, return_tensors="pt").input_features
    audio_feat = audio_feat.unsqueeze(0).unsqueeze(0).to(CONFIG.device, dtype=torch.float16)

    # step5: tokenize a single human turn holding the audio placeholder span,
    # followed by an empty assistant turn for the model to complete.
    qs = DEFAULT_AUDIO_START_TOKEN + DEFAULT_AUDIO_PATCH_TOKEN * CONFIG.audio_token_len + DEFAULT_AUDIO_END_TOKEN
    input_qs = {
        "conversations": [{
            "from": "human",
            "value": qs,
        }, {
            "from": "gpt",
            "value": ""
        }]
    }
    input_ids = torch.tensor([tokenize(input_qs, tokenizer, args.llm_type)]).to(CONFIG.device)

    # step6: infer run
    stop_str = CONFIG.stop
    output_ids = model.generate(input_ids, audios=audio_feat, do_sample=True, temperature=temperature,
                                max_new_tokens=max_new_tokens)

    # generate() returns prompt + continuation; sanity-check the prompt part.
    input_token_len = input_ids.shape[1]
    n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
    if n_diff_input_output > 0:
        print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
    outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
    outputs = _strip_stop_str(outputs, stop_str)

    text = _read_reference_text(input_audio_file)

    print_signature()
    print(f"Human: {input_audio_file} ({text})")
    print(f"LLaSM: {outputs}")
    print("=" * 80)
    print("Go to the Demo page, and have a try!")
if __name__ == '__main__':
    # Command-line flags as (name, type, default), registered in this order.
    cli_options = (
        ('--input_audio_file', str, './examples/0.mp3'),
        ('--llasm_model', str, 'path/to/llasm_model'),
        ('--llasm_audio_tower', str, 'path/to/whisper_large_v2'),
        ('--llm_type', str, 'Chinese_llama2'),
        ('--temperature', float, 0.2),
        ('--max_new_tokens', int, 1024),
    )
    arg_parser = argparse.ArgumentParser()
    for flag, flag_type, flag_default in cli_options:
        arg_parser.add_argument(flag, type=flag_type, default=flag_default)
    main(arg_parser.parse_args())
================================================
FILE: infer_tokenize.py
================================================
# Llama-2 chat prompt delimiters used by tokenize_Cllama2.
# NOTE(review): B_SYS / E_SYS read "<>" here; the Llama-2 chat format normally
# uses "<<SYS>>" / "<</SYS>>" — confirm whether these literals were mangled.
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<>\n", "\n<>\n\n"
# Default system prompt (used when an item has no non-empty "instruction") and
# the speaker-name map used by tokenize_baichuan.
simple_audio_conv_multimodal = {
    "system": "You are a helpful language and speech assistant. You are able to understand the speech content that the user provides, and assist the user with a variety of tasks using natural language.",
    "roles": {"human": "USER", "gpt": "ASSISTANT"},
}
def tokenize_baichuan(item, tokenizer):
    """Build Baichuan-style prompt token ids for a conversation *item*.

    Layout: ``[bos] system "ROLE:" content "ROLE:" content ...``. The system
    text is ``item["instruction"]`` when present and non-empty, otherwise the
    default system prompt. Speaker names come from
    ``simple_audio_conv_multimodal["roles"]``; unknown speakers become
    ``USER``. Only the last ``tokenizer.model_max_length`` ids are kept.

    Args:
        item: dict with a ``"conversations"`` list of ``{"from", "value"}``
            turns and an optional ``"instruction"`` string.
        tokenizer: tokenizer exposing ``encode``, ``model_max_length``,
            ``add_bos_token`` and ``bos_token_id``.

    Returns:
        list of token ids.
    """
    role_names = simple_audio_conv_multimodal["roles"]
    has_instruction = "instruction" in item and len(item["instruction"]) > 0
    system = item["instruction"] if has_instruction else simple_audio_conv_multimodal["system"]

    max_len = tokenizer.model_max_length
    ids = []
    ids += tokenizer.encode(system, add_special_tokens=False)
    for turn in item["conversations"]:
        role = role_names.get(turn['from'], 'USER')
        text = turn['value'].strip()
        if role == 'ASSISTANT' and text != '':
            # NOTE(review): appends an empty string (a no-op as written); this
            # presumably was meant to be an end-of-turn marker — confirm.
            text += ''
        ids += tokenizer.encode(role + ":", add_special_tokens=False) + tokenizer.encode(
            text, add_special_tokens=False, truncation=True, max_length=max_len)

    if tokenizer.add_bos_token:
        ids = [tokenizer.bos_token_id] + ids
    return ids[-max_len:]
def tokenize_Cllama2(item, tokenizer):
    """Build Chinese-Llama-2 chat prompt token ids for a conversation *item*.

    The system text (``item["instruction"]`` when present and non-empty,
    otherwise the default system prompt) is wrapped in ``B_SYS``/``E_SYS``.
    Human turns are rendered as ``"[INST] {text} [/INST] "`` and encoded with
    the tokenizer's default special tokens. Non-empty assistant turns are
    encoded without special tokens and followed by ``eos_token_id``. An empty
    assistant turn (the slot to be generated) contributes nothing. Only the
    last ``tokenizer.model_max_length`` ids are kept.

    The caller's *item* is not modified. The wrapped system prompt is
    prefixed to the first turn locally; writing it back into
    ``item["conversations"][0]`` made repeated calls on the same item stack
    the prompt again.

    Args:
        item: dict with a ``"conversations"`` list of ``{"from", "value"}``
            turns (``"human"`` or assistant) and an optional ``"instruction"``.
        tokenizer: tokenizer exposing ``encode``, ``eos_token_id`` and
            ``model_max_length``.

    Returns:
        list of token ids.
    """
    if "instruction" in item and len(item["instruction"]) > 0:
        system = item["instruction"]
    else:
        system = simple_audio_conv_multimodal["system"]
    system = B_SYS + system + E_SYS

    # NOTE(review): the wrapped system prompt is emitted twice — once here as
    # raw ids (before any BOS) and again inside the first [INST] block below.
    # Kept as-is so the prompt matches what the model was presumably trained
    # on; confirm against the training tokenizer before changing it.
    input_ids = []
    input_ids += tokenizer.encode(system, add_special_tokens=False)

    for idx, turn in enumerate(item["conversations"]):
        role = turn['from']
        content = turn['value']
        if idx == 0:
            # Local copy of the prefixed text; the caller's dict stays intact.
            content = system + content
        content = content.strip()
        if role == 'human':
            content = f"{B_INST} {content} {E_INST} "
            content_ids = tokenizer.encode(content)
        elif content == "":
            # Empty assistant turn: leave the continuation to generation.
            content_ids = []
        else:
            content = f"{content} "
            # add_special_tokens=False drops the BOS; EOS is appended explicitly.
            content_ids = tokenizer.encode(content, add_special_tokens=False) + [tokenizer.eos_token_id]
        input_ids += content_ids

    return input_ids[-tokenizer.model_max_length:]
def tokenize(item, tokenizer, llm_type):
    """Dispatch *item* to the prompt builder that matches *llm_type*.

    Args:
        item: conversation dict, forwarded unchanged to the builder.
        tokenizer: tokenizer, forwarded unchanged to the builder.
        llm_type: ``"Chinese_llama2"`` or ``"baichuan"``.

    Returns:
        list of token ids produced by the selected builder.

    Raises:
        ValueError: for any other *llm_type*.
    """
    if llm_type not in ("Chinese_llama2", "baichuan"):
        raise ValueError(f"Invalid llm type {llm_type}, please choose in ['Chinese_llama2', 'baichuan']")
    builder = tokenize_Cllama2 if llm_type == "Chinese_llama2" else tokenize_baichuan
    return builder(item, tokenizer)
================================================
FILE: llasm.py
================================================
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers import AutoConfig, AutoModelForCausalLM, \
LlamaConfig, LlamaModel, LlamaForCausalLM
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers import (
WhisperProcessor,
WhisperModel,
)
# Audio special tokens (patch placeholder, span start, span end).
# NOTE(review): all three literals are empty strings as shown here, which makes
# them indistinguishable once added to a tokenizer (initialize_audio_tokenizer
# adds them). These look like values lost in transcription — confirm the
# intended token strings; they must match infer.py and the trained tokenizer.
DEFAULT_AUDIO_PATCH_TOKEN = ""
DEFAULT_AUDIO_START_TOKEN = ""
DEFAULT_AUDIO_END_TOKEN = ""
class LlaaaConfig(LlamaConfig):
    """Llama configuration registered under the ``"llaaa"`` model type.

    Declares no extra fields itself; audio-related attributes such as
    ``mm_audio_tower``, ``use_mm_proj``, ``mm_hidden_size`` and
    ``audio_token_len`` are set on the config by
    ``LlaaaLlamaModel.initialize_audio_modules``.
    """

    model_type = "llaaa"
def load_whisper(audio_tower_name):
    """Load a ``WhisperModel`` from *audio_tower_name* with ``forced_decoder_ids`` cleared.

    Args:
        audio_tower_name: Hugging Face model id or local path of the Whisper
            checkpoint.

    Returns:
        the loaded ``WhisperModel``.
    """
    whisper = WhisperModel.from_pretrained(audio_tower_name)
    # No forced decoder prompt: the model is only used here for its hidden states.
    whisper.config.forced_decoder_ids = None
    return whisper
class LlaaaLlamaModel(LlamaModel):
    """Llama backbone that splices projected Whisper audio features into the prompt embeddings.

    The Whisper audio tower is kept inside a one-element list
    (``self.audio_tower``; see the "HACK: for FSDP" notes), which keeps it out
    of ``nn.Module`` submodule registration. ``mm_projector`` is a linear layer
    mapping ``config.mm_hidden_size`` to ``config.hidden_size``.
    """

    config_class = LlaaaConfig

    def __init__(self, config: LlamaConfig):
        """Build the Llama backbone and, if the config asks for them, the audio modules.

        Args:
            config: model config. If it has ``mm_audio_tower`` (a Whisper
                checkpoint name/path) the tower is loaded via ``load_whisper``;
                if it has ``use_mm_proj`` the projector is created.
        """
        super(LlaaaLlamaModel, self).__init__(config)

        if hasattr(config, "mm_audio_tower"):
            # HACK: for FSDP
            # (a plain list, so the tower is not registered as a submodule)
            self.audio_tower = [load_whisper(config.mm_audio_tower)]

        if hasattr(config, "use_mm_proj"):
            self.mm_projector = nn.Linear(config.mm_hidden_size, config.hidden_size)

    def initialize_audio_modules(self, audio_tower, audio_token_len, pretrain_mm_mlp_adapter=None):
        """Attach a frozen fp16 Whisper audio tower and the audio projector.

        Args:
            audio_tower: Whisper checkpoint name/path. It is recorded in
                ``config.mm_audio_tower`` and used to load the processor, and
                also the tower itself if none is attached yet.
            audio_token_len: number of audio patch tokens per clip; stored in
                ``config.audio_token_len``.
            pretrain_mm_mlp_adapter: optional path to a ``torch.load``-able
                checkpoint with projector weights. Keys are reduced to their
                last dotted component before ``load_state_dict``.

        Returns:
            dict with the ``WhisperProcessor`` under ``"processor"``.
        """
        self.config.mm_audio_tower = audio_tower

        processor = WhisperProcessor.from_pretrained(audio_tower)

        # From here on ``audio_tower`` is rebound from the checkpoint name to the model.
        if not hasattr(self, 'audio_tower'):
            audio_tower = load_whisper(audio_tower)
        else:
            audio_tower = self.audio_tower[0]
        # Freeze the tower and cast it to fp16.
        audio_tower.requires_grad_(False)
        audio_tower = audio_tower.to(torch.float16)
        self.audio_tower = [audio_tower]

        self.config.use_mm_proj = True
        # NOTE(review): 1280 is hard-coded; presumably the Whisper large hidden
        # size — confirm if other audio towers are used.
        self.config.mm_hidden_size = 1280
        self.config.audio_token_len = audio_token_len

        if not hasattr(self, 'mm_projector'):
            self.mm_projector = nn.Linear(self.config.mm_hidden_size, self.config.hidden_size)

        if pretrain_mm_mlp_adapter is not None:
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
            self.mm_projector.load_state_dict({k.split('.')[-1]: v for k, v in mm_projector_weights.items()})

        return dict(
            processor=processor,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        audios: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Embed tokens, overwrite the audio patch spans with projected audio features, run Llama.

        Audio is processed only when an audio tower is attached, ``audios`` is
        given, and the call is not a single-token step
        (``input_ids.shape[1] != 1``) unless the model is training. For each
        sample, the embeddings strictly between an audio start token and its
        end token are replaced with ``mm_projector`` applied to the Whisper
        model's ``last_hidden_state`` for that clip. That state is computed
        with ``config.audio_token_len`` decoder start tokens, so there is one
        feature per patch token.

        Args:
            audios: nested iterable indexed ``audios[sample][clip]``. Each clip
                is passed to the Whisper model as input features. In infer.py a
                clip is the processor's ``input_features`` tensor, presumably
                (1, n_mels, frames) — TODO confirm for other callers. An empty
                per-sample list yields a zero dummy feature.

        Raises:
            ValueError: if start/end token counts differ, if the number of
                start tokens differs from the number of clips, or if the end
                token is not ``num_patches + 1`` positions after a start token.
        """
        # HACK: replace back original embeddings for LLaAA pretraining
        # (set by initialize_audio_tokenizer when tuning the adapter; when
        # present, embeddings outside the start/end tokens are detached below).
        orig_embeds_params = getattr(self, 'orig_embeds_params', None)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        audio_tower = getattr(self, 'audio_tower', None)
        # NOTE(review): input_ids is dereferenced here even if only
        # inputs_embeds was supplied; that combination fails when audio is active.
        if audio_tower is not None and (input_ids.shape[1] != 1 or self.training) and audios is not None:
            audio_tower = audio_tower[0]  # HACK: for FSDP
            # The audio tower is frozen: extract features without gradients.
            with torch.no_grad():
                bs_audio_features = []
                for audios_list in audios:
                    if len(audios_list) == 0:
                        dummy_audio_feature = torch.zeros(self.config.audio_token_len, self.config.mm_hidden_size, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
                        audio_features = [dummy_audio_feature]
                    else:
                        audio_features = []
                        for audio in audios_list:
                            # One decoder start token per audio patch -> audio_token_len output positions.
                            decoder_input_ids = torch.ones((1, self.config.audio_token_len)) * audio_tower.config.decoder_start_token_id
                            decoder_input_ids = decoder_input_ids.to(audio.device).to(torch.long)
                            audio_feature = audio_tower(audio, decoder_input_ids=decoder_input_ids).last_hidden_state
                            audio_features.append(audio_feature)
                    bs_audio_features.append(audio_features)

            # Special-token ids were stored on the tower config by the caller (see infer.py).
            audio_config = audio_tower.config
            new_input_embeds = []
            # NOTE(review): zip stops at the shortest input; if audios has fewer
            # samples than input_ids the extra samples are silently dropped.
            for cur_input_ids, cur_input_embeds, cur_audio_features in zip(input_ids, inputs_embeds, bs_audio_features):
                if (cur_input_ids == audio_config.audio_patch_token).sum() == 0:
                    # multimodal LLM, but the current sample is not multimodal, for using both language and audio data
                    # (0. * ...).sum() keeps mm_projector in the graph without changing values.
                    dummy_audio_features = self.mm_projector(cur_audio_features[0])
                    cur_input_embeds = cur_input_embeds + (0. * dummy_audio_features).sum()
                    new_input_embeds.append(cur_input_embeds)
                    continue

                if (cur_input_ids == audio_config.audio_start_token).sum() != (cur_input_ids == audio_config.audio_end_token).sum():
                    raise ValueError("The number of audio start tokens and audio end tokens should be the same.")
                audio_start_tokens = torch.where(cur_input_ids == audio_config.audio_start_token)[0]
                if len(audio_start_tokens) != len(cur_audio_features):
                    raise ValueError(f"The number of audio start tokens ({len(audio_start_tokens)}) and audio features ({len(cur_audio_features)}) should be the same.")
                # NOTE(review): each iteration rebuilds cur_new_input_embeds from
                # the unmodified cur_input_embeds, so with several clips per
                # sample only the last clip's features survive.
                for audio_start_token_pos, cur_audio_feature in zip(audio_start_tokens, cur_audio_features):
                    # [0] drops the leading batch dim of the (1, audio_token_len, ...) Whisper output.
                    cur_audio_feature = self.mm_projector(cur_audio_feature)[0]
                    cur_audio_feature = cur_audio_feature.to(device=cur_input_embeds.device)
                    num_patches = cur_audio_feature.shape[0]
                    if cur_input_ids[audio_start_token_pos + num_patches + 1] != audio_config.audio_end_token:
                        raise ValueError("The audio end token should follow the audio start token.")
                    if orig_embeds_params is not None:
                        # Only the start/end token embeddings (and audio features) keep gradients.
                        cur_new_input_embeds = torch.cat(
                            (cur_input_embeds[:audio_start_token_pos].detach(),
                             cur_input_embeds[audio_start_token_pos:audio_start_token_pos+1],
                             cur_audio_feature,
                             cur_input_embeds[audio_start_token_pos + num_patches + 1:audio_start_token_pos + num_patches + 2],
                             cur_input_embeds[audio_start_token_pos + num_patches + 2:].detach()), dim=0)
                    else:
                        cur_new_input_embeds = torch.cat((
                            cur_input_embeds[:audio_start_token_pos+1],
                            cur_audio_feature,
                            cur_input_embeds[audio_start_token_pos + num_patches + 1:]), dim=0)
                new_input_embeds.append(cur_new_input_embeds)
            inputs_embeds = torch.stack(new_input_embeds, dim=0)

        # Embeddings are already prepared, so the base model gets input_ids=None.
        return super(LlaaaLlamaModel, self).forward(
            input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
            inputs_embeds=inputs_embeds, use_cache=use_cache,
            output_attentions=output_attentions, output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )
class LlaaaLlamaForCausalLM(LlamaForCausalLM):
config_class = LlaaaConfig
def __init__(self, config):
    """Build the causal-LM wrapper around an ``LlaaaLlamaModel`` backbone.

    ``super(LlamaForCausalLM, self).__init__`` deliberately skips
    ``LlamaForCausalLM.__init__`` and calls the next class in the MRO,
    presumably to avoid building the stock Llama backbone. The audio-aware
    backbone and the LM head are then created here, before ``post_init``.
    """
    super(LlamaForCausalLM, self).__init__(config)
    self.model = LlaaaLlamaModel(config)
    self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    # Initialize weights and apply final processing
    self.post_init()
def get_model(self):
    """Return the wrapped backbone stored on ``self.model``."""
    backbone = self.model
    return backbone
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    audios: Optional[torch.FloatTensor] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    """Run the audio-aware backbone, project to vocabulary logits, and optionally score *labels*.

    Flags left as ``None`` fall back to ``self.config`` (``output_attentions``,
    ``output_hidden_states``, ``use_return_dict``). With *labels*, the loss is
    the cross-entropy of position ``t``'s logits against label ``t + 1``.

    Returns:
        ``CausalLMOutputWithPast`` when ``return_dict`` is truthy. Otherwise a
        tuple ``(logits, *backbone_outputs[1:])``, prefixed with ``loss`` when
        labels were given.
    """
    cfg = self.config
    if output_attentions is None:
        output_attentions = cfg.output_attentions
    if output_hidden_states is None:
        output_hidden_states = cfg.output_hidden_states
    if return_dict is None:
        return_dict = cfg.use_return_dict

    backbone_out = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        audios=audios,
    )
    logits = self.lm_head(backbone_out[0])

    loss = None
    if labels is not None:
        # Drop the last prediction and the first label so position t is scored
        # against token t + 1, then flatten both for the loss.
        pred = logits[..., :-1, :].contiguous().view(-1, cfg.vocab_size)
        target = labels[..., 1:].contiguous().view(-1)
        # Move labels next to the logits (model/pipeline parallelism).
        loss = CrossEntropyLoss()(pred, target.to(pred.device))

    if return_dict:
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=backbone_out.past_key_values,
            hidden_states=backbone_out.hidden_states,
            attentions=backbone_out.attentions,
        )
    tail = (logits,) + backbone_out[1:]
    return tail if loss is None else (loss,) + tail
def prepare_inputs_for_generation(
    self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
    """Build the keyword arguments for one generation step.

    If ``past_key_values`` is truthy, only the last position of ``input_ids``
    is kept. ``inputs_embeds`` replaces ``input_ids`` only when there is no
    cache at all (``past_key_values is None``), i.e. on the first step.
    ``use_cache`` and ``audios`` are read from ``kwargs`` and default to
    ``None``.
    """
    if past_key_values:
        input_ids = input_ids[:, -1:]

    first_step_with_embeds = inputs_embeds is not None and past_key_values is None
    token_inputs = (
        {"inputs_embeds": inputs_embeds} if first_step_with_embeds else {"input_ids": input_ids}
    )
    return {
        **token_inputs,
        "past_key_values": past_key_values,
        "use_cache": kwargs.get("use_cache"),
        "attention_mask": attention_mask,
        "audios": kwargs.get("audios", None),
    }
def initialize_audio_tokenizer(self, tokenizer, device,
                               tune_mm_mlp_adapter=False, pretrain_mm_mlp_adapter=None):
    """Register the audio special tokens and initialize their embeddings.

    Adds DEFAULT_AUDIO_PATCH_TOKEN, DEFAULT_AUDIO_START_TOKEN and
    DEFAULT_AUDIO_END_TOKEN to ``tokenizer``, then resizes the token
    embeddings to ``len(tokenizer)``. Every newly added row of the input and
    output embedding matrices is set to the mean of the pre-existing rows.
    The resulting token ids are stored on ``self.model.audio_tower[0].config``.

    Args:
        tokenizer: tokenizer providing ``add_tokens``, ``__len__`` and
            ``convert_tokens_to_ids``.
        device: device for the ``orig_embeds_params`` snapshot taken when
            ``tune_mm_mlp_adapter`` is set.
        tune_mm_mlp_adapter: if True, snapshot the input embeddings into
            ``self.get_model().orig_embeds_params``, make the input embeddings
            trainable and freeze the output embeddings.
        pretrain_mm_mlp_adapter: optional path to a checkpoint loaded with
            ``torch.load``. Its ``'model.embed_tokens.weight'`` entry supplies
            the embeddings of the new tokens. Only used when new tokens were
            actually added.

    Raises:
        ValueError: if a checkpoint is given but the number of newly added
            tokens is not 3, or its embedding shape matches neither the
            current embedding matrix nor the new-token rows.
    """
    num_new_tokens = tokenizer.add_tokens([DEFAULT_AUDIO_PATCH_TOKEN], special_tokens=True)
    num_new_tokens += tokenizer.add_tokens(
        [DEFAULT_AUDIO_START_TOKEN, DEFAULT_AUDIO_END_TOKEN], special_tokens=True)
    # A single resize after all additions: the rows it allocates are
    # overwritten below, so an intermediate resize would only reallocate.
    self.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = self.get_input_embeddings().weight.data
        output_embeddings = self.get_output_embeddings().weight.data
        # Both means are taken before any row is overwritten.
        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
            dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
            dim=0, keepdim=True)
        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg

    if tune_mm_mlp_adapter:
        self.get_model().orig_embeds_params = [self.get_input_embeddings().weight.data.clone().to(device=device)]
        for p in self.get_input_embeddings().parameters():
            p.requires_grad = True
        for p in self.get_output_embeddings().parameters():
            p.requires_grad = False

    if pretrain_mm_mlp_adapter and num_new_tokens > 0:
        # Validate before loading; an `assert` here would vanish under `python -O`.
        if num_new_tokens != 3:
            raise ValueError(
                f"Expected 3 newly added audio tokens when loading pretrained "
                f"embeddings, got {num_new_tokens}.")
        # NOTE(security): torch.load unpickles the file; only pass trusted checkpoints.
        mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
        embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
        if input_embeddings.shape == embed_tokens_weight.shape:
            # Full embedding matrix saved: take only the new-token rows.
            input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
        elif embed_tokens_weight.shape[0] == num_new_tokens:
            # Only the new-token rows were saved.
            input_embeddings[-num_new_tokens:] = embed_tokens_weight
        else:
            raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Number of new tokens: {num_new_tokens}.")

    audio_config = self.model.audio_tower[0].config
    audio_config.audio_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_AUDIO_PATCH_TOKEN])[0]
    audio_start_token, audio_end_token = tokenizer.convert_tokens_to_ids(
        [DEFAULT_AUDIO_START_TOKEN, DEFAULT_AUDIO_END_TOKEN])
    audio_config.audio_start_token = audio_start_token
    audio_config.audio_end_token = audio_end_token
# Register the custom config/model pair with the transformers Auto* factories.
# This lets the "llaaa" model type be resolved through AutoConfig and
# AutoModelForCausalLM (presumably so checkpoints with this model_type load via
# the Auto classes — confirm against callers).
AutoConfig.register("llaaa", LlaaaConfig)
AutoModelForCausalLM.register(LlaaaConfig, LlaaaLlamaForCausalLM)
================================================
FILE: logger.py
================================================
def print_signature():
    """Print the LLaSM / LinkSoul banner to stdout.

    Output, in order:
    - an 80-character ``=`` rule;
    - the LLaSM ASCII-art title;
    - the LinkSoul ASCII-art logo;
    - an 80-character ``-`` rule;
    - the demo URL and a "please star" note;
    - a closing ``=`` rule.
    """
    # The art is kept in raw strings so every backslash is literal. In a
    # normal string, sequences such as "\_" or "\ " are invalid escapes
    # (SyntaxWarning on Python 3.12+). A backslash ending an art row would
    # also act as a line continuation and silently merge two rows.
    # The newline right after each opening quote is dropped by [1:].
    llasm = r"""
__ __ __
/ / / / __ _/ _\ /\/\
/ / / / / _` \ \ / \
/ /___/ /__| (_| |\ \/ /\/\ \
\____/\____/\__,_\__/\/ \/
"""[1:]
    logo = r"""
__ _ _ __ _
/ /(_)_ __ | | __/ _\ ___ _ _| |
/ / | | '_ \| |/ /\ \ / _ \| | | | |
/ /__| | | | | < _\ \ (_) | |_| | |
\____/_|_| |_|_|\_\\__/\___/ \__,_|_|
"""[1:]
    rule_width = 80
    print("=" * rule_width)
    print(llasm)
    print(logo)
    print("-" * rule_width)
    print("Demo/HuggingFace: https://huggingface.co/spaces/LinkSoul/LLaSM")
    # Chinese, roughly: "Feel free to give us a Star ^_^".
    print("欢迎点一点 Star ^_^")
    print("=" * rule_width)


if __name__ == '__main__':
    print_signature()
================================================
FILE: pyproject.toml
================================================
# PEP 517 build backend: the package is built with setuptools.
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
# PEP 621 project metadata.
[project]
name = "llasm"
version = "1.0.0"
description = "LLaSM: Large Language and Speech Model."
readme = "README.md"
requires-python = ">=3.8"
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
]
# Runtime dependencies. transformers and sentencepiece are pinned exactly;
# presumably the model code was validated against these versions — confirm
# before relaxing the pins.
dependencies = [
"numpy", "requests",
"librosa", "protobuf", "accelerate",
"tokenizers>=0.12.1",
"torch", "torchvision",
"transformers==4.31.0",
"sentencepiece==0.1.99",
]
[project.urls]
"Homepage" = "https://huggingface.co/spaces/LinkSoul/LLaSM"
"Bug Tracker" = "https://github.com/LinkSoul-AI/LLaSM/issues"
# NOTE(review): packages.find only discovers package directories, and
# namespace scanning is on by default for this table. The repository's code
# lives in top-level modules (llasm.py, infer.py, infer_tokenize.py,
# logger.py), which are not packages. A built distribution may therefore
# omit them and pick up examples/ instead — confirm, and consider listing the
# modules explicitly (e.g. via tool.setuptools py-modules).
[tool.setuptools.packages.find]
exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
# NOTE(review): setuptools does not appear to read a [tool.wheel] table, so
# this exclude list likely has no effect — confirm before relying on it.
[tool.wheel]
exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]