Showing preview only (280K chars total). Download the full file or copy to clipboard to get everything.
Repository: jzq2000/MoonCast
Branch: main
Commit: 07037e58f427
Files: 43
Total size: 252.7 KB
Directory structure:
gitextract_g10rsitx/
├── .gitignore
├── LICENSE
├── app.py
├── download_pretrain.py
├── en_llmprompt_script_gen.py
├── inference.py
├── modules/
│ ├── audio_detokenizer/
│ │ ├── audio_detokenizer.py
│ │ ├── bigvgan_wrapper.py
│ │ ├── flow_matching/
│ │ │ ├── dit_block.py
│ │ │ ├── model.py
│ │ │ ├── ode_wrapper.py
│ │ │ └── scheduler.py
│ │ ├── semantic_fm_prefix_streaming.py
│ │ └── vocoder/
│ │ ├── activations.py
│ │ ├── alias_free_activation/
│ │ │ ├── __init__.py
│ │ │ ├── cuda/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── activation1d.py
│ │ │ │ ├── anti_alias_activation.cpp
│ │ │ │ ├── anti_alias_activation_cuda.cu
│ │ │ │ ├── compat.h
│ │ │ │ ├── load.py
│ │ │ │ └── type_shim.h
│ │ │ └── torch/
│ │ │ ├── __init__.py
│ │ │ ├── act.py
│ │ │ ├── filter.py
│ │ │ └── resample.py
│ │ ├── bigvgan.py
│ │ └── utils.py
│ ├── audio_tokenizer/
│ │ ├── audio_tokenizer.py
│ │ ├── quantize/
│ │ │ ├── __init__.py
│ │ │ ├── factorized_vector_quantize.py
│ │ │ ├── residual_vq.py
│ │ │ └── vector_quantize.py
│ │ ├── rep_codec.py
│ │ ├── transformer.py
│ │ └── vocos.py
│ └── tokenizer/
│ └── tokenizer.py
├── readme.md
├── requirements.txt
├── test/
│ ├── test_audio_detokenizer.py
│ ├── test_audio_tokenizer.py
│ └── test_tokenizer.py
└── zh_llmprompt_script_gen.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
*.safetensors
*.pt
*.vscode
**/__pycache__/
modules/audio_detokenizer/vocoder/alias_free_activation/cuda/build/
tmp*
resources/
*.gradio
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2025 Zeqian Ju
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: app.py
================================================
import gradio as gr
from huggingface_hub import snapshot_download
# Download the pretrained MoonCast assets BEFORE importing `inference`:
# constructing Model below loads weights from ./resources/.
snapshot_download(repo_id="jzq11111/mooncast", local_dir='./resources/')
from inference import Model
import base64
# Global model instance shared by all Gradio requests.
model = Model()
model.generate_config.max_new_tokens = 50 * 50  # no more than 50s per turn
def process_json_and_generate_audio(prompt_audio_role0_file, prompt_text_role0, prompt_audio_role1_file, prompt_text_role1, json_dialogue_input_str):
    """Validate the dialogue script and stream generated podcast audio.

    Args mirror the Gradio inputs: a reference audio file path and its
    transcript for each of the two speakers, plus the dialogue script as a
    JSON-like string (a list of {"role", "text"} dicts whose roles strictly
    alternate "0", "1", "0", ...).

    Yields:
        bytes: base64-decoded audio chunks for streaming playback.

    Raises:
        gr.Error: on an unparsable script, failed validation, or any
            downstream generation error.
    """
    import ast  # local import; only needed to parse the user-supplied script

    try:
        print(json_dialogue_input_str, type(json_dialogue_input_str))
        print(prompt_audio_role0_file, prompt_text_role0, prompt_audio_role1_file, prompt_text_role1)
        # ast.literal_eval accepts the same Python-literal syntax the bundled
        # examples use (including trailing commas, which strict json.loads
        # rejects) but, unlike eval(), cannot execute arbitrary code coming
        # from untrusted UI input.
        json_data = ast.literal_eval(json_dialogue_input_str.strip())
        print(json_data, type(json_data))

        def validate_json(data):
            # The script must be a list of turns whose roles alternate 0,1,0,1,...
            try:
                if not isinstance(data, list):
                    return "script must be a list of dialogue turns"
                cur_spk_should_be = 0
                for item in data:
                    if item['role'] != str(cur_spk_should_be):
                        return f"role should be {cur_spk_should_be} in item {item}"
                    cur_spk_should_be = 1 - cur_spk_should_be
                return None
            except Exception as e:
                return str(e)

        validation_error = validate_json(json_data)
        if validation_error:
            raise gr.Error(validation_error)

        role_mapping = {
            "0": {
                "ref_audio": prompt_audio_role0_file,
                "ref_text": prompt_text_role0,
            },
            "1": {
                "ref_audio": prompt_audio_role1_file,
                "ref_text": prompt_text_role1,
            }
        }
        model_input_json = {
            "role_mapping": role_mapping,
            "dialogue": json_data,
        }
        print("模型推理输入 JSON:", model_input_json)
        # Stream: decode and forward each audio chunk as it is produced.
        for cur_chunk in model.inference(model_input_json, streaming=True):
            yield base64.b64decode(cur_chunk)
    except Exception as e:
        # Surface any failure to the UI as a Gradio error.
        raise gr.Error(str(e))
# ---- Static UI strings, English / Chinese variants ----
title_en = "# PODCAST generator (supports English and Chinese)"
title_zh = "# 播客生成 (支持英文和中文)"
instruct_en = "## See [Github](https://github.com/jzq2000/MoonCast) for podcast script generation."
instruct_zh = "## 播客剧本生成请参考 [Github](https://github.com/jzq2000/MoonCast)。"
# Widget labels in order: role-0 audio, role-0 text, role-1 audio, role-1 text, script JSON.
input_labels_en = ["Prompt Audio for Role 0", "Prompt Text for Role 0", "Prompt Audio for Role 1", "Prompt Text for Role 1", "Script JSON Input"]
input_labels_zh = ["角色 0 的 Prompt 音频", "角色 0 的 Prompt 文本", "角色 1 的 Prompt 音频", "角色 1 的 Prompt 文本", "剧本 JSON 输入"]
output_label_en = "Generated Audio Output (streaming)"
output_label_zh = "生成的音频输出(流式)"
# Reference transcripts paired with the bundled example prompt WAVs.
example_prompt_text_role0_en = "Yeah, no, this is my backyard. It's never ending So just the way I like it. So social distancing has never been a problem."
example_prompt_text_role0_zh = "可以每天都骑并且可能会让你爱上骑车,然后通过爱上骑车的你省了很多很多钱。"
example_prompt_text_role1_en = "I'm doing great And. Look, it couldn't be any better than having you at your set, which is the outdoors."
example_prompt_text_role1_zh = "他最后就能让同样食材炒出来的菜味道大大提升。"
text_placeholder_zh = "对话轮流进行, 每轮最多50秒。文本越自然, 生成的音频效果越好。"
text_placeholder_en = "Dialogue alternates between roles. Limit each turn to a maximum of 50 seconds. The more natural the text, the better the generated audio."
# Example scripts as Python-literal strings (note the trailing commas: these
# parse via literal evaluation, not strict JSON).
example_json_en = '''[
{
"role": "0",
"text": "In an awesome time, And, we're even gonna do a second episode too So. This is part one part two, coming at some point in the future There. We are.",
},
{
"role": "1",
"text": "I love it. So grateful Thank you So I'm really excited. That's awesome. Yeah.",
},
{
"role": "0",
"text": "All I was told, which is good because I don't want to really talk too much more is that you're really really into fitness and nutrition And overall holistic I love it Yes.",
},
{
"role": "1",
"text": "Yeah So I started around thirteen Okay But my parents were fitness instructors as well. Awesome So I came from the beginning, and now it's this transition into this wholeness because I had to chart my. Own path and they weren't into nutrition at all So I had to learn that part."
}
]'''
example_json_zh = '''[
{
"role": "0",
"text": "我觉得啊,就是经历了这么多年的经验, 就是补剂的作用就是九分的努力, 十分之一的补剂。 嗯,选的话肯定是九分更重要,但是我觉得补剂它能够让你九分的努力更加的有效率,更加的避免徒劳无功。 嗯,就是你,你你得先得真的锻炼,真的努力,真的健康饮食,然后再考虑补剂, 那你再加十十分之一的补剂的话,他可能就是说啊, 一半是心理作用,"
},
{
"role": "1",
"text": "对,其实很多时候心理作用是非常重要的。嗯,然后我每次用补剂的时候,我就会更加努力,就比如说我在健身之前我喝了一勺蛋白粉,我就会督促自己多练,"
},
{
"role": "0",
"text": "其实心理作用只要能实现你的预期目的就可以了。 就比如说给自行车链条加油, 它其实不是必要的,但是它可以让你骑行更顺畅, 然后提高你骑行的频率。"
}
]
'''
# examples_en = [
#     ['./en_prompt0.wav', example_prompt_text_role0_en, './en_prompt1.wav', example_prompt_text_role1_en, example_json_en]
# ]
# examples_zh = [
#     ['./zh_prompt0.wav', example_prompt_text_role0_zh, './zh_prompt1.wav', example_prompt_text_role1_zh, example_json_zh]
# ]
# Rows for gr.Examples: one English demo and one Chinese demo, each matching
# the five inputs (role-0 audio/text, role-1 audio/text, script JSON).
examples = [
    ['./en_prompt0.wav', example_prompt_text_role0_en, './en_prompt1.wav', example_prompt_text_role1_en, example_json_en],
    ['./zh_prompt0.wav', example_prompt_text_role0_zh, './zh_prompt1.wav', example_prompt_text_role1_zh, example_json_zh]
]
def update_ui_language(language):
    """Relabel every widget for the chosen UI language.

    Returns an 11-tuple of gr.update(...) objects matching the `outputs=`
    list wired to language_choice.change: title, instructions, language
    radio, role-0 audio/text, role-1 audio/text, JSON input, audio output,
    submit button, examples dataset.

    Raises:
        ValueError: if `language` is neither "English" nor "中文".
    """
    # Select the full string set once instead of duplicating the return
    # expression per language.
    if language == "English":
        title, instruct, labels = title_en, instruct_en, input_labels_en
        lang_label, out_label = "UI Language", output_label_en
        placeholder, submit_text = text_placeholder_en, "Submit"
        examples_label = "Examples (Demonstration Use Only. Do Not Redistribute.)"
    elif language == "中文":
        title, instruct, labels = title_zh, instruct_zh, input_labels_zh
        lang_label, out_label = "UI 语言", output_label_zh
        placeholder, submit_text = text_placeholder_zh, "提交"
        examples_label = "示例 (仅用于展示,切勿私自传播。)"
    else:
        raise ValueError("Invalid language selected")
    return (gr.update(value=title),
            gr.update(value=instruct),
            gr.update(label=lang_label),
            gr.update(label=labels[0]),
            gr.update(label=labels[1]),
            gr.update(label=labels[2]),
            gr.update(label=labels[3]),
            gr.update(label=labels[4], placeholder=placeholder),
            gr.update(label=out_label),
            gr.update(value=submit_text),
            gr.update(label=examples_label, headers=labels))
# The audio output component is created before the Blocks context so it can be
# referenced in event `outputs=` lists, then placed into the layout later via
# .render().
audio_output = gr.Audio(label=output_label_en, streaming=True)
css = """
.centered-title { /* CSS rule for centering title */
text-align: center !important;
}
"""
# -------------------- Gradio UI definition --------------------
with gr.Blocks(css=css) as iface:
    # Defaults are the Chinese strings; update_ui_language relabels on toggle.
    title_output = gr.Markdown(value=title_zh, elem_classes="centered-title")
    instruct_output = gr.Markdown(value=instruct_zh)
    language_choice = gr.Radio(["中文", "English"], value="中文", label="UI语言")
    with gr.Row():  # main area: wide script column + two narrow prompt columns
        with gr.Column(scale=2):
            json_input = gr.TextArea(label=input_labels_zh[4], lines=15, placeholder=text_placeholder_zh)  # dialogue script JSON
        with gr.Column(scale=1):  # role-0 prompt inputs
            audio_input_role0 = gr.Audio(type="filepath", label=input_labels_zh[0])
            text_input_role0 = gr.TextArea(label=input_labels_zh[1], lines=2)
        with gr.Column(scale=1):  # role-1 prompt inputs
            audio_input_role1 = gr.Audio(type="filepath", label=input_labels_zh[2])
            text_input_role1 = gr.TextArea(label=input_labels_zh[3], lines=2)
    examples_component = gr.Examples(
        examples=examples,
        inputs=[audio_input_role0, text_input_role0, audio_input_role1, text_input_role1, json_input],
        cache_examples=False,
        label="示例(仅用于展示,切勿私自传播。)",
    )
    submit_button = gr.Button("提交")
    # Generator endpoint: yields audio chunks, streamed into audio_output.
    submit_button.click(
        fn=process_json_and_generate_audio,
        inputs=[audio_input_role0, text_input_role0, audio_input_role1, text_input_role1, json_input],
        outputs=audio_output
    )
    audio_output.render()  # place the pre-created audio component here
    # Swapping the UI language updates labels/values of all listed components.
    language_choice.change(
        fn=update_ui_language,
        inputs=language_choice,
        outputs=[title_output, instruct_output, language_choice, audio_input_role0, text_input_role0, audio_input_role1, text_input_role1, json_input, audio_output, submit_button, examples_component.dataset]
    )
iface.launch(share=True)
================================================
FILE: download_pretrain.py
================================================
from huggingface_hub import snapshot_download
# Fetch all pretrained MoonCast assets from the Hugging Face Hub into
# ./resources/ — the directory inference.py loads its models from.
snapshot_download(repo_id="jzq11111/mooncast", local_dir='./resources/')
================================================
FILE: en_llmprompt_script_gen.py
================================================
# INPUT -> BRIEF -> SCRIPT
INPUT2BRIEF = '''
### Task Description
Please summarize the input document in plain text format according to the following structure. The summary should be creative, comprehensive, and include all interesting, uncommon, and valuable viewpoints and information.
- **Text Requirements**:
1. Directly output the result without any additional information.
2. The summary should be in English. Retain a small number of proper nouns, names, and abbreviations in their original form (e.g., Chinese characters).
3. Do not include any mathematical formulas.
4. Do not alter any proper nouns, names, or abbreviations from the original text. Unless there is a common translation, do not translate proper nouns. Do not attempt to modify the meaning of proper nouns.
5. **Intelligently convert numbers in abbreviations. For example, "a2b" should be interpreted as "a to b," not "a two b"; "a4b" as "a for b," not "a four b"; "v2" may represent "version two" or "second generation." Provide the original abbreviation and your suggested English translation.**
### Title and Author
- **Language Requirements**: English, formal written language.
- **Content Requirements**: Provide the title and author of the document. Briefly summarize the theme of the document and the author's background. Ensure all important information is included without omission and sufficient context is retained.
### Abstract
- **Language Requirements**: English, formal written language.
- **Content Requirements**:
1. What this document has done.
2. Whether similar work has been done before.
3. If similar work exists, why this document is still necessary.
4. How this document specifically addresses the topic.
5. How well this document achieves its goals.
- **Additional Requirements**: Include an additional paragraph to explain any terms, concepts, or methods that may confuse readers unfamiliar with the field. Ensure proper nouns are explained consistently with the original text, covering all potential points of confusion, including abbreviations and entity names.
### Main Themes and Concepts
- **Language Requirements**: English, formal written language.
- **Content Requirements**: Each theme and concept should be organized according to the 3W principle:
- **What**: Clearly define the problem.
- **Why**: Analyze the problem and identify its root causes.
- **How**: Explain how the document addresses the problem.
- **Additional Requirements**:
1. Ensure each theme and concept is comprehensive and includes all important details. Fully elaborate on the "What" and "Why" sections.
2. Avoid technical details such as mathematical formulas in the "How" section. Use language that is easily understood by a general audience.
3. Ensure themes and concepts do not overlap and maintain clear logic.
4. Include an additional paragraph to explain any terms, concepts, or methods that may confuse readers unfamiliar with the field. Ensure proper nouns are explained consistently with the original text, covering all potential points of confusion, including abbreviations and entity names.
### Key Citations
- **Language Requirements**: English, formal written language.
- **Content Requirements**: Organize the content according to the following structure:
1. **Argument**: State what needs to be proven.
2. **Evidence**: Provide the material used to support the argument.
3. **Reasoning**: Describe the process of using evidence to prove the argument.
- **Additional Requirements**:
1. Ensure all evidence and reasoning are directly sourced from the original text without fabrication.
2. Ensure citation content is complete and retains sufficient context without simplification. Avoid using mathematical formulas in citations.
3. Include an additional paragraph to explain any terms, concepts, or methods that may confuse readers unfamiliar with the field. Ensure proper nouns are explained consistently with the original text, covering all potential points of confusion, including abbreviations and entity names.
### Conclusion
- **Language Requirements**: English, formal written language.
- **Content Requirements**: Highlight the most important and impactful aspects of the document. Compared to the abstract, this section should provide more detailed insights related to the main themes and concepts. It may also include future directions for improvement, current application scenarios, and existing challenges.
'''
BRIEF2SCRIPT = '''
## 1. Task Overview
Please generate a lively English podcast script based on the provided English summary text and your knowledge of the topic. The script should feature a dialogue between two speakers who take turns speaking. Output format should be JSON-parsable **list**. Each speaker's turn is a **dictionary** containing "speaker" and "text" fields. Example format: `[{{"speaker": "1", "text": "xxx"}}]`. The "speaker" field indicates the speaker's identity (1 for host, 2 for guest), and the "text" field is the spoken content. Output should start directly with the JSON code block, without any extra information.
## 2. Content and Structure
### (1) Text Content
- The summary text contains all important information, which needs to be comprehensively selected and incorporated into the script.
- Present information through a dialogue between two speakers, maintaining creativity and abstracting away unimportant details. For example, listeners aren't concerned with specific test names, but rather the task itself, the results, and the analysis.
### (2) Structure Design
- **Opening:** Introduce the topic and briefly describe the discussion content, without mentioning speaker names.
- **Key Theme Discussion:** Discuss important themes based on the summary text. Expand on the summary, don't just repeat it verbatim.
- **Closing:** Briefly recap the discussion highlights and offer an outlook on future or technological developments.
## 3. Language Style
### (1) Conversational Style
- The text should be as conversational as possible, aiming for a style similar to automatic speech recognition output. Include filler words such as 'um,' 'uh,' 'like,' 'you know,' 'so,' 'right?', and so on. Response words such as 'Yeah,' 'Right,' 'Okay,' and similar. Conversational expressions, repetitions, informal grammar, etc. Use short sentences. Avoid directly copying and pasting structured text from the summary text. Parentheses and other symbols not typically found in speech recognition transcripts should be avoided. Spaces within sentences indicate pauses. Be aware that there might be homophone errors, potentially due to accents. Questions should sound very conversational. Pay particular attention to incorporating conversational details, especially in questions. For example:
[
{{ "speaker": "1",
"text": "Welcome back to the podcast, everyone. Today we're diving into, uh, something that's really changing everything around us, A I."
}},
{{ "speaker": "2",
"text": "Yeah, A I is, like, everywhere now, isn't it? It's kinda wild to think about."
}},
{{ "speaker": "1",
"text": "Totally. And we're seeing it in so many areas of daily life. Like, even just recommending what to watch, or, you know, suggesting products online."
}},
{{ "speaker": "2",
"text": "Mhm, exactly. And it's not just online stuff, right? Think about smart homes, or even self-driving cars. It's getting pretty advanced."
}},
{{ "speaker": "1",
"text": "Right, self-driving cars are still a bit futuristic for most of us, but, uh, even things like voice assistants on our phones, that's A I, isn't it?"
}},
{{ "speaker": "2",
"text": "Definitely. Siri, Alexa, Google Assistant, all powered by A I. It's become so normal, we almost don't even think about it anymore."
}},
{{ "speaker": "1",
"text": "Yeah, it's like, integrated into everything. But is that a good thing, you think? Like, are there downsides to all this A I in our lives?"
}},
{{ "speaker": "2",
"text": "Well, that's the big question, isn't it? On the one hand, it makes things so much more convenient, saves us time, maybe even makes things safer in some ways."
}},
{{ "speaker": "1",
"text": "Safer how?"
}},
{{ "speaker": "2",
"text": "Uh, well, like in healthcare, for example. A I can help doctors diagnose diseases earlier, maybe even more accurately. That's a huge plus, right?"
}},
{{ "speaker": "1",
"text": "Yeah, that's a really good point. Medical applications are definitely exciting. But what about the concerns, you know? Like job displacement or privacy issues?"
}},
{{ "speaker": "2",
"text": "Right, those are super valid concerns. Job displacement is a big one. If A I can do more and more tasks, what happens to human workers? And privacy,"
}},
{{ "speaker": "1",
"text": "And privacy is huge, especially with all the data A I systems collect. It's a lot to process."
}},
{{ "speaker": "2",
"text": "Exactly. So, it's not just sunshine and roses, is it? We need to be mindful of the ethical implications and make sure it's used responsibly."
}},
{{ "speaker": "1",
"text": "Definitely. It's a powerful tool, but like any tool, it can be used for good or, you know, not so good. It's up to us to guide its development, right?"
}},
{{ "speaker": "2",
"text": "Absolutely. And that's a conversation we all need to be part of, not just the tech people, but everyone."
}}
]
### (2) Punctuation
- Use English punctuation marks. Avoid using other punctuation marks beyond commas, periods, and question marks. Exclamation points are prohibited. Ellipses ('…'), parentheses, quotation marks (including ‘ ' “ ” ") or dashes are prohibited, otherwise it will be considered unqualified. do not use markdown syntax. For example,**bold** or *italic* text should be avoided. Use plain text only.
- If interrupted by the other person's response, the sentence should end with a comma, not a period.
## 4. Information Organization and Logic
### (1) Referencing Issues
- Given that listeners won't have access to the summary text, any references must provide sufficient context for comprehension.
- Avoid simply paraphrasing; instead, explain referenced content in your own words.
- Explanations of technical terms should be creative and avoid simply stating 'this means what?' You can use examples, metaphors, and so on for explanations, but ensure you also clarify the rationale behind the metaphor. Explanations can be provided in response to a question from the other speaker, or you can offer explanations proactively. Technical terms that are not mentioned don't need explanation. Technical terms that are mentioned don't necessarily need immediate explanation; they can be explained alongside other technical terms. Technical terms in the summary text might differ slightly from the surrounding text; you'll need to provide reasonable explanations based on the context.
### (2) Information Density
- Ensure moderate information density, avoiding excessively high or low density. The goal of appropriate information density is to enable listeners without prior knowledge to quickly grasp the document's purpose, rationale, and methodology.
- To prevent information overload, the script should avoid delving into details like mathematical formulas, test setups, or specific experimental metrics. Instead, it should use simple, generalized language for descriptions.
- To avoid excessively low information density, ensure each topic is discussed for at least 4 speaker turns, moving beyond simple keyword listings. Discuss topics from multiple angles whenever possible, going beyond the provided summary text. Given that the summary text is highly generalized, the script should elaborate on it and discuss further details. Feel free to use your knowledge to supplement background information, provide examples, and so forth, to enhance listener understanding.
- Techniques to increase information density:
1. Incorporate memorable quotes. Add impactful, attention-grabbing sentences to the script, either original ones or quotes from other sources.
2. Boost knowledge content. Judiciously add knowledge points to the script to make listeners feel more informed and rewarded.
3. Introduce novel information. Incorporate new concepts to spark listener curiosity, particularly information they're unaware of but would find valuable. This is crucial.
4. Employ reverse thinking. Include information from diverse angles, challenging listeners' existing perspectives and presenting alternative viewpoints.
5. Generate contrast and impact. The script can offer unconventional (yet plausible) descriptions of familiar concepts to create a contrast with listener expectations. This contrast contributes to information density.
- Techniques to decrease information density:
1. Use short sentences: Concise and easy to understand, making the narrative more compact. Do not have too much information in one sentence.
2. Describe details: Vague and abstract information makes it difficult for listeners to build understanding, while more details create a sense of imagery and are easier to read.
3. Use more scenario-based descriptions: Scenarios are concrete and visual. Listeners can easily receive the conveyed information and be emotionally touched.
4. Talk more about facts: Talking about facts makes it more real, and readers can empathize more, thus lowering the information density of the copy.
5. Tell more stories: Tell your own stories, stories around you, and stories you've heard. Stories can bring listeners into the scene, making it easier to concentrate on listening.
6. Use more verbs and concrete nouns: Verbs and concrete nouns make it easier for listeners to visualize, while adjectives make complex copy harder to understand.
7. Avoid using mathematical formulas: Mathematical formulas are not conducive to public understanding.
## 5. Dialogue Design
### (1) Speaker Roles
- The script includes a host and a guest. Speaker 1 is the host, responsible for opening and closing the show, skilled at using questions to control the pace of the conversation, and using vivid examples to make knowledge less dry. Speaker 2 is the guest, primarily responsible for introducing the document content, has amazing knowledge reserves in the field, and is good at organizing language in a structured and easy-to-understand way.
- Both speakers are enthusiastic and cheerful, like to combine personal stories or examples for discussion, and bring a direct experience to listeners. They are happy to discuss digressive stories.
- The two speakers actively interact and frequently use interruption words such as "um" to indicate agreement with each other. Response words need to be inserted into the dialogue according to the timing. Sentences before being interrupted end with a comma, not a period.
- Ensure consistent speaker roles. Do not have the host introduce technical details, or have the guest guide the host to discuss topics.
- The host gradually increases their understanding of the field based on the guest's answers. However, the host may not understand immediately or completely correctly. The host can express misunderstanding or raise some questions that ordinary people might have. In this case, the guest will further explain in more accessible language, or specifically answer common questions or misunderstandings. This kind of interaction is more realistic and easier for listeners to understand than always correct hosts and guests.
### (2) Topic Order Arrangement
- The host will arrange the topics according to the summary text and ensure logical connections between topics, such as transitioning from overall to details, from details to overall, from cause to effect, from technology to application, etc.
- The host will guide the pace of the conversation and discuss topics in the order of the summary text. Guests should not interfere with topic transitions.
### (3) Knowledge Rate
- The knowledge rate in the script needs to be reasonable. Do not introduce a large amount of knowledge too quickly in a short period of time. Knowledge
## 6. Other Requirements
### (1) English Numbers and Foreign Words
1. The script will be used for English podcast content recording. Please ensure most numbers and foreign words are rendered naturally in English to facilitate correct pronunciation.
2. Please intelligently determine the correct pronunciation according to the context. For example, "2021" if expressing a year, should be converted to "two thousand and twenty-one" or "twenty twenty-one". But if expressing a number, it should be "two thousand and twenty-one". For some uncommon English abbreviations, if the pronunciation needs to be read letter by letter according to the context, you must ensure that there is a space between each letter, such as "AI" adding a space as "A I", to avoid the model misinterpreting it as a word. For example, "API" should be rendered as "A P I".
3. Small amount of Chinese is allowed, especially for nouns, if it fits naturally within the conversational English context.
### (2) Script Length
1. Please ensure that the total length of the 'text' values does not exceed 3,000 words and the number of speaker turns is kept within 60, otherwise it will be unqualified. Please choose technical details and topic concepts to discuss. Do not shorten the depth of discussion on each topic for the sake of word limit, do not be limited to the summary text, and give full play to your knowledge.
INPUT: {BRIEF}
## Re-emphasize:
Speaker 1 is the host, and Speaker 2 is the guest. Neither speaker has a name. The script text only uses commas, periods, and question marks. Use English punctuation marks. Avoid using other punctuation marks beyond commas, periods, and question marks. Exclamation points are prohibited. Ellipses ('…'), parentheses, quotation marks (including ‘ ' “ ” ") or dashes are prohibited, otherwise it will be considered unqualified. Please prioritize in-depth discussion for each topic. Don't limit yourself to the summary text; instead, use your knowledge to expand upon the topics, providing background information and illustrative examples to enhance listener understanding.
Ensure that numbers and foreign words are rendered naturally in English for accurate pronunciation during recording. In technical contexts, English abbreviations sometimes use numerical digits in place of words (e.g., "a2b" for "a to b," "a4b" for "a for b"). Please translate these abbreviations into appropriate English phrases based on the context. While the script is primarily in English, a small amount of Chinese, especially for nouns, is acceptable if it integrates naturally into the conversational flow.
OUTPUT:
'''
================================================
FILE: inference.py
================================================
import sys
sys.path.append(".")
from modules.tokenizer.tokenizer import get_tokenizer_and_extra_tokens
from modules.audio_tokenizer.audio_tokenizer import get_audio_tokenizer
from modules.audio_detokenizer.audio_detokenizer import get_audio_detokenizer, detokenize, detokenize_noref, detokenize_streaming, detokenize_noref_streaming
import torch
import os
from glob import glob
import base64
import io
import torchaudio
from transformers import AutoModelForCausalLM, GenerationConfig
import librosa
from tqdm import tqdm
from pydub import AudioSegment
class Model(object):
    """End-to-end MoonCast pipeline: text -> semantic tokens -> waveform."""

    def __init__(self):
        """Load the text tokenizer, audio (de)tokenizers, and the text2semantic LM."""
        self.tokenizer, self.extra_tokens = get_tokenizer_and_extra_tokens()
        # Text-BPE ids at/above this offset represent speech (semantic) tokens.
        self.speech_token_offset = 163840
        print(self.extra_tokens)
        # Pre-encoded BPE ids for role names and speaker tags used when
        # building chat-format prompts (expected ids noted alongside).
        self.assistant_ids = self.tokenizer.encode("assistant")  # [110866]
        self.user_ids = self.tokenizer.encode("user")  # [1495]
        self.audio_ids = self.tokenizer.encode("audio")  # [26229]
        self.spk_0_ids = self.tokenizer.encode("0")  # [501]
        self.spk_1_ids = self.tokenizer.encode("1")  # [503]
        # Special control-token ids delimiting messages and media segments.
        self.msg_end = self.extra_tokens.msg_end  # 260
        self.user_msg_start = self.extra_tokens.user_msg_start  # 261
        self.assistant_msg_start = self.extra_tokens.assistant_msg_start  # 262
        self.name_end = self.extra_tokens.name_end  # 272
        self.media_begin = self.extra_tokens.media_begin  # 273
        self.media_content = self.extra_tokens.media_content  # 274
        self.media_end = self.extra_tokens.media_end  # 275
        self.audio_tokenizer = get_audio_tokenizer()
        self.audio_detokenizer = get_audio_detokenizer()
        model_path = "resources/text2semantic"
        # NOTE(review): force_download=True forces a re-download on every
        # startup even when files are already present, and the trailing
        # .to(...) looks redundant with device_map="cuda:0" — confirm before
        # changing either.
        self.model = AutoModelForCausalLM.from_pretrained(model_path, device_map="cuda:0", torch_dtype=torch.bfloat16, trust_remote_code=True, force_download=True).to(torch.cuda.current_device())
        # Sampling configuration for semantic-token generation; generation
        # stops at the media_end control token.
        self.generate_config = GenerationConfig(
            max_new_tokens=200 * 50,  # no more than 200s per turn
            do_sample=True,
            top_k=30,
            top_p=0.8,
            temperature=0.8,
            eos_token_id=self.media_end,
        )
def _clean_text(self, text):
# you can add front-end processing here
text = text.replace("“", "")
text = text.replace("”", "")
text = text.replace("...", " ")
text = text.replace("…", " ")
text = text.replace("*", "")
text = text.replace(":", ",")
text = text.replace("‘", "'")
text = text.replace("’", "'")
text = text.strip()
return text
@torch.inference_mode()
def _process_text(self, js):
if "role_mapping" in js:
for role in js["role_mapping"].keys():
js["role_mapping"][role]["ref_bpe_ids"] = self.tokenizer.encode(self._clean_text(js["role_mapping"][role]["ref_text"]))
for turn in js["dialogue"]:
turn["bpe_ids"] = self.tokenizer.encode(self._clean_text(turn["text"]))
return js
def inference(self, js, streaming=False):
js = self._process_text(js)
if "role_mapping" not in js:
if streaming:
return self.infer_without_prompt_streaming(js)
else:
return self.infer_without_prompt(js)
else:
if streaming:
return self.infer_with_prompt_streaming(js)
else:
return self.infer_with_prompt(js)
@torch.inference_mode()
def infer_with_prompt(self, js):
user_role_0_ids = [self.user_msg_start] + self.user_ids + self.spk_0_ids + [self.name_end]
user_role_1_ids = [self.user_msg_start] + self.user_ids + self.spk_1_ids + [self.name_end]
assistant_role_0_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_0_ids + [self.name_end]
assistant_role_1_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_1_ids + [self.name_end]
media_start = [self.media_begin] + self.audio_ids + [self.media_content]
media_end = [self.media_end] + [self.msg_end]
assistant_role_0_ids = torch.LongTensor(assistant_role_0_ids).unsqueeze(0).to(torch.cuda.current_device())
assistant_role_1_ids = torch.LongTensor(assistant_role_1_ids).unsqueeze(0).to(torch.cuda.current_device())
media_start = torch.LongTensor(media_start).unsqueeze(0).to(torch.cuda.current_device())
media_end = torch.LongTensor(media_end).unsqueeze(0).to(torch.cuda.current_device())
prompt = []
cur_role_dict = dict()
for role, role_item in js["role_mapping"].items():
waveform_24k = librosa.load(role_item["ref_audio"], sr=24000)[0]
waveform_24k = torch.tensor(waveform_24k).unsqueeze(0).to(torch.cuda.current_device())
waveform_16k = librosa.load(role_item["ref_audio"], sr=16000)[0]
waveform_16k = torch.tensor(waveform_16k).unsqueeze(0).to(torch.cuda.current_device())
semantic_tokens = self.audio_tokenizer.tokenize(waveform_16k)
semantic_tokens = semantic_tokens.to(torch.cuda.current_device())
prompt_ids = semantic_tokens + self.speech_token_offset
cur_role_dict[role] = {
"ref_bpe_ids": role_item["ref_bpe_ids"],
"wav_24k": waveform_24k,
"semantic_tokens": semantic_tokens,
"prompt_ids": prompt_ids
}
prompt = prompt + user_role_0_ids + cur_role_dict["0"]["ref_bpe_ids"] + [self.msg_end]
prompt = prompt + user_role_1_ids + cur_role_dict["1"]["ref_bpe_ids"] + [self.msg_end]
for seg_id, turn in enumerate(js["dialogue"]):
role_id = turn["role"]
cur_user_ids = user_role_0_ids if role_id == "0" else user_role_1_ids
cur_start_ids = cur_user_ids + turn["bpe_ids"] + [self.msg_end]
prompt = prompt + cur_start_ids
prompt = torch.LongTensor(prompt).unsqueeze(0).to(torch.cuda.current_device())
prompt = torch.cat([prompt, assistant_role_0_ids, media_start, cur_role_dict["0"]["prompt_ids"], media_end], dim=-1)
prompt = torch.cat([prompt, assistant_role_1_ids, media_start, cur_role_dict["1"]["prompt_ids"], media_end], dim=-1)
generation_config = self.generate_config
# you can modify sampling strategy here
wav_list = []
for seg_id, turn in tqdm(enumerate(js["dialogue"])):
role_id = turn["role"]
cur_assistant_ids = assistant_role_0_ids if role_id == "0" else assistant_role_1_ids
prompt = torch.cat([prompt, cur_assistant_ids, media_start], dim=-1)
len_prompt = prompt.shape[1]
generation_config.min_length = len_prompt + 2
# print(generation_config)
# todo: add streaming support for generate function
outputs = self.model.generate(prompt,
generation_config=generation_config)
if outputs[0, -1] == self.media_end:
outputs = outputs[:, :-1]
output_token = outputs[:, len_prompt:]
prompt = torch.cat([outputs, media_end], dim=-1)
torch_token = output_token - self.speech_token_offset
gen_speech_fm = detokenize(self.audio_detokenizer, torch_token, cur_role_dict[role_id]["wav_24k"], cur_role_dict[role_id]["semantic_tokens"])
gen_speech_fm = gen_speech_fm.cpu()
gen_speech_fm = gen_speech_fm / gen_speech_fm.abs().max()
wav_list.append(gen_speech_fm)
del torch_token
concat_wav = torch.cat(wav_list, dim=-1).cpu()
# print(concat_wav.shape)
buffer = io.BytesIO()
torchaudio.save(buffer, concat_wav, sample_rate=24000, format="mp3")
audio_bytes = buffer.getvalue()
audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
return audio_b64
@torch.inference_mode()
def infer_with_prompt_streaming(self, js):
user_role_0_ids = [self.user_msg_start] + self.user_ids + self.spk_0_ids + [self.name_end]
user_role_1_ids = [self.user_msg_start] + self.user_ids + self.spk_1_ids + [self.name_end]
assistant_role_0_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_0_ids + [self.name_end]
assistant_role_1_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_1_ids + [self.name_end]
media_start = [self.media_begin] + self.audio_ids + [self.media_content]
media_end = [self.media_end] + [self.msg_end]
assistant_role_0_ids = torch.LongTensor(assistant_role_0_ids).unsqueeze(0).to(torch.cuda.current_device())
assistant_role_1_ids = torch.LongTensor(assistant_role_1_ids).unsqueeze(0).to(torch.cuda.current_device())
media_start = torch.LongTensor(media_start).unsqueeze(0).to(torch.cuda.current_device())
media_end = torch.LongTensor(media_end).unsqueeze(0).to(torch.cuda.current_device())
prompt = []
cur_role_dict = dict()
for role, role_item in js["role_mapping"].items():
waveform_24k = librosa.load(role_item["ref_audio"], sr=24000)[0]
waveform_24k = torch.tensor(waveform_24k).unsqueeze(0).to(torch.cuda.current_device())
waveform_16k = librosa.load(role_item["ref_audio"], sr=16000)[0]
waveform_16k = torch.tensor(waveform_16k).unsqueeze(0).to(torch.cuda.current_device())
semantic_tokens = self.audio_tokenizer.tokenize(waveform_16k)
semantic_tokens = semantic_tokens.to(torch.cuda.current_device())
prompt_ids = semantic_tokens + self.speech_token_offset
cur_role_dict[role] = {
"ref_bpe_ids": role_item["ref_bpe_ids"],
"wav_24k": waveform_24k,
"semantic_tokens": semantic_tokens,
"prompt_ids": prompt_ids
}
prompt = prompt + user_role_0_ids + cur_role_dict["0"]["ref_bpe_ids"] + [self.msg_end]
prompt = prompt + user_role_1_ids + cur_role_dict["1"]["ref_bpe_ids"] + [self.msg_end]
for seg_id, turn in enumerate(js["dialogue"]):
role_id = turn["role"]
cur_user_ids = user_role_0_ids if role_id == "0" else user_role_1_ids
cur_start_ids = cur_user_ids + turn["bpe_ids"] + [self.msg_end]
prompt = prompt + cur_start_ids
prompt = torch.LongTensor(prompt).unsqueeze(0).to(torch.cuda.current_device())
prompt = torch.cat([prompt, assistant_role_0_ids, media_start, cur_role_dict["0"]["prompt_ids"], media_end], dim=-1)
prompt = torch.cat([prompt, assistant_role_1_ids, media_start, cur_role_dict["1"]["prompt_ids"], media_end], dim=-1)
generation_config = self.generate_config
# you can modify sampling strategy here
wav_list = []
for seg_id, turn in tqdm(enumerate(js["dialogue"])):
role_id = turn["role"]
cur_assistant_ids = assistant_role_0_ids if role_id == "0" else assistant_role_1_ids
prompt = torch.cat([prompt, cur_assistant_ids, media_start], dim=-1)
len_prompt = prompt.shape[1]
generation_config.min_length = len_prompt + 2
# print(generation_config)
# todo: add streaming support for generate function
outputs = self.model.generate(prompt,
generation_config=generation_config)
if outputs[0, -1] == self.media_end:
outputs = outputs[:, :-1]
output_token = outputs[:, len_prompt:]
prompt = torch.cat([outputs, media_end], dim=-1)
torch_token = output_token - self.speech_token_offset
for cur_chunk in detokenize_streaming(self.audio_detokenizer, torch_token, cur_role_dict[role_id]["wav_24k"], cur_role_dict[role_id]["semantic_tokens"]):
cur_chunk = cur_chunk.cpu()
cur_chunk = cur_chunk / cur_chunk.abs().max()
cur_buffer = io.BytesIO()
torchaudio.save(cur_buffer, cur_chunk, sample_rate=24000, format="mp3")
audio_bytes = cur_buffer.getvalue()
audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
yield audio_b64
@torch.inference_mode()
def infer_without_prompt(self, js):
user_role_0_ids = [self.user_msg_start] + self.user_ids + self.spk_0_ids + [self.name_end]
user_role_1_ids = [self.user_msg_start] + self.user_ids + self.spk_1_ids + [self.name_end]
assistant_role_0_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_0_ids + [self.name_end]
assistant_role_1_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_1_ids + [self.name_end]
media_start = [self.media_begin] + self.audio_ids + [self.media_content]
media_end = [self.media_end] + [self.msg_end]
assistant_role_0_ids = torch.LongTensor(assistant_role_0_ids).unsqueeze(0).to(torch.cuda.current_device())
assistant_role_1_ids = torch.LongTensor(assistant_role_1_ids).unsqueeze(0).to(torch.cuda.current_device())
media_start = torch.LongTensor(media_start).unsqueeze(0).to(torch.cuda.current_device())
media_end = torch.LongTensor(media_end).unsqueeze(0).to(torch.cuda.current_device())
prompt = []
for seg_id, turn in enumerate(js["dialogue"]):
role_id = turn["role"]
cur_user_ids = user_role_0_ids if role_id == "0" else user_role_1_ids
cur_start_ids = cur_user_ids + turn["bpe_ids"] + [self.msg_end]
prompt = prompt + cur_start_ids
prompt = torch.LongTensor(prompt).unsqueeze(0).to(torch.cuda.current_device())
generation_config = self.generate_config
# you can modify sampling strategy here
wav_list = []
for seg_id, turn in tqdm(enumerate(js["dialogue"])):
role_id = turn["role"]
cur_assistant_ids = assistant_role_0_ids if role_id == "0" else assistant_role_1_ids
prompt = torch.cat([prompt, cur_assistant_ids, media_start], dim=-1)
len_prompt = prompt.shape[1]
generation_config.min_length = len_prompt + 2
# todo: add streaming support for generate function
outputs = self.model.generate(prompt,
generation_config=generation_config)
if outputs[0, -1] == self.media_end:
outputs = outputs[:, :-1]
output_token = outputs[:, len_prompt:]
prompt = torch.cat([outputs, media_end], dim=-1)
torch_token = output_token - self.speech_token_offset
gen_speech_fm = detokenize_noref(self.audio_detokenizer, torch_token)
gen_speech_fm = gen_speech_fm.cpu()
gen_speech_fm = gen_speech_fm / gen_speech_fm.abs().max()
wav_list.append(gen_speech_fm)
del torch_token
concat_wav = torch.cat(wav_list, dim=-1).cpu()
# print(concat_wav.shape)
buffer = io.BytesIO()
torchaudio.save(buffer, concat_wav, sample_rate=24000, format="mp3")
audio_bytes = buffer.getvalue()
audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
return audio_b64
@torch.inference_mode()
def infer_without_prompt_streaming(self, js):
user_role_0_ids = [self.user_msg_start] + self.user_ids + self.spk_0_ids + [self.name_end]
user_role_1_ids = [self.user_msg_start] + self.user_ids + self.spk_1_ids + [self.name_end]
assistant_role_0_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_0_ids + [self.name_end]
assistant_role_1_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_1_ids + [self.name_end]
media_start = [self.media_begin] + self.audio_ids + [self.media_content]
media_end = [self.media_end] + [self.msg_end]
assistant_role_0_ids = torch.LongTensor(assistant_role_0_ids).unsqueeze(0).to(torch.cuda.current_device())
assistant_role_1_ids = torch.LongTensor(assistant_role_1_ids).unsqueeze(0).to(torch.cuda.current_device())
media_start = torch.LongTensor(media_start).unsqueeze(0).to(torch.cuda.current_device())
media_end = torch.LongTensor(media_end).unsqueeze(0).to(torch.cuda.current_device())
prompt = []
for seg_id, turn in enumerate(js["dialogue"]):
role_id = turn["role"]
cur_user_ids = user_role_0_ids if role_id == "0" else user_role_1_ids
cur_start_ids = cur_user_ids + turn["bpe_ids"] + [self.msg_end]
prompt = prompt + cur_start_ids
prompt = torch.LongTensor(prompt).unsqueeze(0).to(torch.cuda.current_device())
generation_config = self.generate_config
# you can modify sampling strategy here
wav_list = []
for seg_id, turn in tqdm(enumerate(js["dialogue"])):
role_id = turn["role"]
cur_assistant_ids = assistant_role_0_ids if role_id == "0" else assistant_role_1_ids
prompt = torch.cat([prompt, cur_assistant_ids, media_start], dim=-1)
len_prompt = prompt.shape[1]
generation_config.min_length = len_prompt + 2
# print(generation_config)
# todo: add streaming support for generate function
outputs = self.model.generate(prompt,
generation_config=generation_config)
if outputs[0, -1] == self.media_end:
outputs = outputs[:, :-1]
output_token = outputs[:, len_prompt:]
prompt = torch.cat([outputs, media_end], dim=-1)
torch_token = output_token - self.speech_token_offset
for cur_chunk in detokenize_noref_streaming(self.audio_detokenizer, torch_token):
cur_chunk = cur_chunk.cpu()
cur_chunk = cur_chunk / cur_chunk.abs().max()
cur_buffer = io.BytesIO()
torchaudio.save(cur_buffer, cur_chunk, sample_rate=24000, format="mp3")
audio_bytes = cur_buffer.getvalue()
audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
yield audio_b64
if __name__ == "__main__":
    model = Model()
    # speaker should be interleaved
    zh_test_json = {
        "role_mapping": {
            "0": {
                "ref_audio": "./zh_prompt0.wav",
                "ref_text": "可以每天都骑并且可能会让你爱上骑车,然后通过爱上骑车的你省了很多很多钱。",  # asr output
            },
            "1": {
                "ref_audio": "./zh_prompt1.wav",
                "ref_text": "他最后就能让同样食材炒出来的菜味道大大提升。"  # asr output
            }
        },
        "dialogue": [
            {
                "role": "0",
                "text": "我觉得啊,就是经历了这么多年的经验, 就是补剂的作用就是九分的努力, 十分之一的补剂。 嗯,选的话肯定是九分更重要,但是我觉得补剂它能够让你九分的努力更加的有效率,更加的避免徒劳无功。 嗯,就是你,你你得先得真的锻炼,真的努力,真的健康饮食,然后再考虑补剂, 那你再加十十分之一的补剂的话,他可能就是说啊, 一半是心理作用,"
            },
            {
                "role": "1",
                "text": "对,其实很多时候心理作用是非常重要的。嗯,然后我每次用补剂的时候,我就会更加努力,就比如说我在健身之前我喝了一勺蛋白粉,我就会督促自己多练,"
            },
            {
                "role": "0",
                "text": "其实心理作用只要能实现你的预期目的就可以了。 就比如说给自行车链条加油, 它其实不是必要的,但是它可以让你骑行更顺畅, 然后提高你骑行的频率。"
            }
        ]
    }
    # Streaming path: decode each base64 mp3 chunk and append it.
    audio_bytes_gen = model.inference(zh_test_json, streaming=True)
    audio = AudioSegment.empty()
    for cur_chunk in audio_bytes_gen:
        audio_chunk = AudioSegment.from_file(io.BytesIO(base64.b64decode(cur_chunk)), format="mp3")
        audio += audio_chunk
    audio.export("tmp_generated_zh_stream.mp3", format="mp3")
    print("zh stream done")
    # Non-streaming path: one base64 mp3 for the whole dialogue.
    # (use `with` so the file handle is always closed)
    audio_bytes = model.inference(zh_test_json)
    with open("tmp_generated_zh.mp3", "wb") as file_to_save:
        file_to_save.write(base64.b64decode(audio_bytes))
    print("zh done")
    # speaker should be interleaved
    en_test_json = {
        "role_mapping": {
            "0": {
                "ref_audio": "./en_prompt0.wav",
                "ref_text": "Yeah, no, this is my backyard. It's never ending So just the way I like it. So social distancing has never been a problem.",  # asr output
            },
            "1": {
                "ref_audio": "./en_prompt1.wav",
                "ref_text": "I'm doing great And. Look, it couldn't be any better than having you at your set, which is the outdoors."  # asr output
            }
        },
        "dialogue": [
            {
                "role": "0",
                "text": "In an awesome time, And, we're even gonna do a second episode too So. This is part one part two, coming at some point in the future There. We are.",
            },
            {
                "role": "1",
                "text": "I love it. So grateful Thank you So I'm really excited. That's awesome. Yeah."
            },
            {
                "role": "0",
                "text": "All I was told, which is good because I don't want to really talk too much more is that you're really really into fitness and nutrition And overall holistic I love it Yes."
            },
            {
                "role": "1",
                "text": "Yeah So I started around thirteen Okay But my parents were fitness instructors as well. Awesome So I came from the beginning, and now it's this transition into this wholeness because I had to chart my. Own path and they weren't into nutrition at all So I had to learn that part."
            }
        ]
    }
    audio_bytes = model.inference(en_test_json)
    with open("tmp_generated_en.mp3", "wb") as file_to_save:
        file_to_save.write(base64.b64decode(audio_bytes))
    print("en done")
    # also support inference without prompt
    # speaker should be interleaved
    without_prompt_test_json = {
        "dialogue": [
            {
                "role": "0",
                "text": "我觉得啊,就是经历了这么多年的经验, 就是补剂的作用就是九分的努力, 十分之一的补剂。 嗯,选的话肯定是九分更重要,但是我觉得补剂它能够让你九分的努力更加的有效率,更加的避免徒劳无功。 嗯,就是你,你你得先得真的锻炼,真的努力,真的健康饮食,然后再考虑补剂, 那你再加十十分之一的补剂的话,他可能就是说啊, 一半是心理作用,"
            },
            {
                "role": "1",
                "text": "对,其实很多时候心理作用是非常重要的。嗯,然后我每次用补剂的时候,我就会更加努力,就比如说我在健身之前我喝了一勺蛋白粉,我就会督促自己多练,"
            },
            {
                "role": "0",
                "text": "其实心理作用只要能实现你的预期目的就可以了。 就比如说给自行车链条加油, 它其实不是必要的,但是它可以让你骑行更顺畅, 然后提高你骑行的频率。"
            }
        ]
    }
    audio_bytes = model.inference(without_prompt_test_json)
    with open("tmp_generated_woprompt.mp3", "wb") as file_to_save:
        file_to_save.write(base64.b64decode(audio_bytes))
    print("without prompt done")
================================================
FILE: modules/audio_detokenizer/audio_detokenizer.py
================================================
import torch
from modules.audio_detokenizer.bigvgan_wrapper import BigVGANWrapper
from modules.audio_detokenizer.semantic_fm_prefix_streaming import StreamingSemanticFMWrapper
class PrefixStreamingFlowMatchingDetokenizer:
    """Streaming speech detokenizer.

    Converts chunks of semantic speech tokens into waveform chunks by first
    running a chunked flow-matching model (StreamingSemanticFMWrapper) to
    predict 80-dim mel frames, then vocoding the mel with BigVGAN.  Chunk
    boundaries are hidden by re-vocoding held-back mel frames together with
    the new chunk and cross-fading the overlapping audio with a Hamming
    window.  An optional token look-ahead holds back the last few tokens of
    each chunk so they can be re-synthesized with right context.
    """

    def __init__(self, vocoder: BigVGANWrapper, fm: StreamingSemanticFMWrapper, look_ahead_tokens: int = 0) -> None:
        # Inference precision for both the FM model and the vocoder.
        self.dtype = torch.bfloat16
        print("Currently using bfloat16 for PrefixFlowMatchingDetokenizer")
        self.vocoder = vocoder
        self.vocoder.to_dtype(self.dtype)
        self.semantic_fm = fm
        # initialize mel_spec
        # Position-id budget of the FM model; states are rolled back to the
        # post-prefill snapshot before this is exceeded.
        self.max_pos_size = 4096
        self.is_timbre_semantic_token = False
        # Mel frames held back from the previous chunk (re-vocoded next call).
        self.pre_mel = None
        self.frame_size = 480  # how many samples in a frame
        # Waveform tail held back from the previous chunk (used for cross-fade).
        self.pre_wav = None
        # FM state snapshot taken right after prefill; restored on position overflow.
        self.state_dict_backup = None
        # Cross-fade windows keyed by length, so they are built only once.
        self.hamming_window_cache = {}
        # Look-ahead tokens carried over into the next chunk.
        self.previous_chunk_left = None
        self.look_ahead_tokens = look_ahead_tokens
        self.clear_states()

    @classmethod
    def from_pretrained(cls, vocoder_config, vocoder_ckpt, fm_config, fm_ckpt, device,
                        look_ahead_tokens=0,
                        max_prompt_chunk=2, max_kv_cache_tokens=900,
                        use_cfg=False, use_cfg_rescale=True, cfg_init=1.5, cfg_scale=7.5, cfg_schedule="linear"):
        """Build the detokenizer from BigVGAN and flow-matching checkpoints.

        The cfg_* arguments configure classifier-free guidance inside the FM
        wrapper; max_prompt_chunk / max_kv_cache_tokens bound its KV cache.
        """
        bigvgan = BigVGANWrapper.from_pretrained(vocoder_config, vocoder_ckpt, device)
        semantic_fm = StreamingSemanticFMWrapper.from_pretrained(fm_config, fm_ckpt, device, max_prompt_chunk=max_prompt_chunk, max_kv_cache_tokens=max_kv_cache_tokens,
                                                                 use_cfg=use_cfg, cfg_scale=cfg_scale, use_cfg_rescale=use_cfg_rescale, cfg_init=cfg_init, cfg_schedule=cfg_schedule)
        return cls(bigvgan, semantic_fm, look_ahead_tokens=look_ahead_tokens)

    @torch.inference_mode()
    def prefill(self, timbre_speech, timbre_semantic_token, chunk_size: int, timbre_mel=None):
        """
        Prime the flow-matching model with reference-speaker audio so later
        chunks are generated in that voice.  A snapshot of the resulting FM
        state is kept so generation can be rolled back to it.

        Arguments:
            timbre_speech: torch.Tensor, shape [B, N_speech_24k]
            timbre_semantic_token: torch.Tensor, shape [B, N]
            chunk_size: int, chunk size for prefilling
            timbre_mel: torch.Tensor, shape [B, N, 80], optional, if not None, use this mel spectrogram instead of extracting from timbre_speech
        """
        if timbre_mel is None:
            assert timbre_speech is not None, "timbre_speech should not be None if timbre_mel is not None"
            assert len(timbre_semantic_token.shape) == 2 and len(timbre_speech.shape) == 2 and chunk_size > 0
            assert timbre_speech.shape[0] == 1 and timbre_semantic_token.shape[0] == 1
            mel_spec = self.vocoder.extract_mel_from_wav(wav_data=timbre_speech.squeeze(0))
        else:
            assert len(timbre_mel.shape) == 3 and len(timbre_semantic_token.shape) == 2 and chunk_size > 0
            assert timbre_mel.shape[0] == 1 and timbre_semantic_token.shape[0] == 1
            mel_spec = timbre_mel.squeeze(0)
        # Align mel length to the token count (1 mel frame per token).
        if mel_spec.shape[0] < timbre_semantic_token.shape[1]:
            # pad mel_spec
            mel_spec = torch.nn.functional.pad(mel_spec, (0, 0, 0, timbre_semantic_token.shape[1] - mel_spec.shape[0]))
        elif mel_spec.shape[0] > timbre_semantic_token.shape[1]:
            # truncate mel_spec
            mel_spec = mel_spec[:timbre_semantic_token.shape[1], :]
        # clear all states
        self.semantic_fm.clear_all_states()
        self.semantic_fm.prefill(mel_spec, timbre_semantic_token.squeeze(0), chunk_size=chunk_size, verbose=False)
        # Snapshot the prompt-only state for later rollback (position overflow).
        self.state_dict_backup = self.semantic_fm.state_dict()

    @torch.inference_mode()
    def detokenize_streaming(self, semantic_token, ode_step=30, verbose=False, ode_solver="neural_ode_euler", is_final=False, upsample_factor=1):
        """Convert one chunk of semantic tokens into a waveform chunk.

        Arguments:
            semantic_token: torch.Tensor, shape [1, N]
            ode_step: number of ODE integration steps for flow matching
            is_final: True for the last chunk of an utterance; flushes all
                pending audio and resets the smoothing state
            upsample_factor: repeat each token this many times before synthesis
        Returns:
            torch.Tensor, shape [1, T], float32 waveform chunk
        """
        assert len(semantic_token.shape) == 2 and ode_step > 0
        assert semantic_token.shape[0] == 1
        semantic_token = semantic_token.repeat_interleave(upsample_factor, dim=1)
        semantic_token = semantic_token.squeeze(0)
        # Prepend the look-ahead tokens held back from the previous chunk.
        if self.look_ahead_tokens != 0 and self.previous_chunk_left is not None:
            semantic_token_previous = self.previous_chunk_left["semantic_token"]
            semantic_token = torch.cat([semantic_token_previous, semantic_token], dim=-1)
        # Gaussian-noise initial state for the ODE: one 80-dim mel frame per token.
        x_t_chunk = torch.randn(semantic_token.shape[0], 80).to(semantic_token.device).to(self.dtype)
        if self.look_ahead_tokens != 0 and self.previous_chunk_left is None:
            self.previous_chunk_left = {"semantic_token": None}
        speech_mel = self.semantic_fm.infer_chunk(
            xt_chunk=x_t_chunk,
            semantic_tokens_chunk=semantic_token,
            start_position_id=self.semantic_fm.start_position_id,
            ode_steps=ode_step,
            verbose=verbose,
            look_ahead_tokens=self.look_ahead_tokens * upsample_factor if not is_final else 0,
            cache=self.previous_chunk_left,
            ode_solver=ode_solver
        )
        chunk_size = speech_mel.shape[0]
        length = speech_mel.shape[0]
        # Advance FM positions and persist its KV cache for the next chunk.
        self.semantic_fm.start_position_id += length
        self.semantic_fm.update_incremental_state()
        self.semantic_fm.reserve_kv_cache_tokens += self.semantic_fm.ode_wrapper.kv_cache_tokens
        # smoothing:
        # For the first chunk, only the first half of the vocoded audio is
        # emitted; the second half (audio + mel) is saved as history.
        # For later chunks, the saved mel is re-vocoded together with the new
        # frames, the overlapping region is cross-faded against the saved
        # audio, and a new tail is withheld for the next call.
        if self.pre_mel is None:  # first chunk, related to TTFB
            concat_mel = speech_mel
            concat_reconstructed_wav = self.vocoder.decode_mel(concat_mel)
            if is_final:
                # Single-chunk utterance: emit everything and reset.
                self.clear_states()
                self.state_dict_backup = None
                ret_wav = concat_reconstructed_wav.float()
            else:
                reconstructed_wav = concat_reconstructed_wav[:, :int(self.frame_size * chunk_size // 2)]  # return the first half chunk
                self.pre_wav = concat_reconstructed_wav[:, -int(self.frame_size * chunk_size // 2):]  # log the last half chunk for next generation step
                self.pre_mel = speech_mel[-chunk_size//2:, :]
                ret_wav = reconstructed_wav.float()
        else:
            # Re-vocode held-back mel with the new frames so the vocoder sees
            # continuous context across the chunk boundary.
            concat_mel = torch.cat([self.pre_mel, speech_mel], dim=0)
            concat_reconstructed_wav = self.vocoder.decode_mel(concat_mel)
            if is_final:
                self.clear_states()
                self.state_dict_backup = None
                ret_wav = concat_reconstructed_wav.float()
            else:
                # fetch history
                prev_speech_len = self.pre_wav.shape[1]
                if concat_reconstructed_wav.shape[1] > prev_speech_len * 2:
                    gen_speech_len = prev_speech_len * 2
                else:
                    gen_speech_len = concat_reconstructed_wav.shape[1] // 2
                reconstructed_wav = concat_reconstructed_wav[:, :gen_speech_len]  # return the first half chunk
                if gen_speech_len not in self.hamming_window_cache:
                    self.hamming_window_cache[gen_speech_len] = torch.hamming_window(gen_speech_len).to(self.dtype).to(semantic_token.device).unsqueeze(0)
                hamming_window = self.hamming_window_cache[gen_speech_len]
                # Cross-fade the first half: old tail fades out (falling half
                # of the window) while the re-vocoded audio fades in (rising half).
                reconstructed_wav[:, :int(gen_speech_len // 2 )] = self.pre_wav[:, :int(gen_speech_len // 2 )] * hamming_window[:,-int(gen_speech_len // 2):] + \
                    reconstructed_wav[:, :int(gen_speech_len // 2)] * hamming_window[:, :int(gen_speech_len // 2)]
                # Withhold the remainder as the next call's history.
                res_speech_len = concat_reconstructed_wav.shape[1] - gen_speech_len
                res_mel_len = res_speech_len // self.frame_size
                self.pre_wav = concat_reconstructed_wav[:, -res_speech_len:]
                self.pre_mel = speech_mel[-res_mel_len:, :]
                ret_wav = reconstructed_wav.float()
        if not is_final and self.semantic_fm.start_position_id + 2*chunk_size > self.max_pos_size:
            # out of position id:
            # roll back to the post-prefill snapshot so generation can
            # continue with fresh position ids.
            self.semantic_fm.clear_all_states()
            self.semantic_fm.load_state_dict(self.state_dict_backup)
        return ret_wav

    def clear_states(self):
        """Reset all inter-chunk state (FM caches, look-ahead, smoothing buffers)."""
        self.semantic_fm.clear_all_states()
        self.previous_chunk_left = None
        self.pre_mel = None
        self.pre_wav = None
def get_audio_detokenizer():
    """Construct the prefix-streaming flow-matching detokenizer from the
    checkpoints under resources/ on the current CUDA device."""
    return PrefixStreamingFlowMatchingDetokenizer.from_pretrained(
        vocoder_config="resources/vocoder/config.json",
        vocoder_ckpt="resources/vocoder/model.pt",
        max_prompt_chunk=10,  # 10 * 3 = 30s
        fm_config="resources/audio_detokenizer/config.yaml",
        fm_ckpt="resources/audio_detokenizer/model.pt",
        device=torch.cuda.current_device(),
        use_cfg=False,
        look_ahead_tokens=12,
    )
def detokenize(detokenizer, tokens, ref_wav, ref_tokens):
    """Synthesize all of `tokens` with the reference speaker's voice and
    return the concatenated [1, T] waveform.

    Delegates the chunk schedule (first 100 tokens for low latency, then 150
    per chunk) to detokenize_streaming instead of duplicating it here.
    """
    chunks = list(detokenize_streaming(detokenizer, tokens, ref_wav, ref_tokens))
    return torch.cat(chunks, dim=-1)
def detokenize_streaming(detokenizer, tokens, ref_wav, ref_tokens):
    """Stream waveform chunks for `tokens`, voice-cloned from the reference.

    The first chunk is smaller (100 tokens) to reduce time-to-first-audio;
    the remaining tokens are processed 150 at a time.
    """
    first_chunk, later_chunk = 100, 150
    with torch.no_grad():
        detokenizer.clear_states()
        detokenizer.prefill(ref_wav, ref_tokens, chunk_size=150)
        total = tokens.size(1)
        start, size = 0, first_chunk
        while True:
            end = start + size
            yield detokenizer.detokenize_streaming(tokens[:, start:end], is_final=end >= total)
            if end >= total:
                break
            start, size = end, later_chunk
def detokenize_noref(detokenizer, tokens):
    """Synthesize all of `tokens` without a speaker reference and return the
    concatenated [1, T] waveform.

    Delegates the chunk schedule (first 100 tokens, then 150 per chunk) to
    detokenize_noref_streaming instead of duplicating it here.
    """
    chunks = list(detokenize_noref_streaming(detokenizer, tokens))
    return torch.cat(chunks, dim=-1)
def detokenize_noref_streaming(detokenizer, tokens):
    """Stream waveform chunks for `tokens` without any speaker reference.

    Same chunk schedule as detokenize_streaming (100-token first chunk, then
    150 per chunk), but with no prefill step.
    """
    first_chunk, later_chunk = 100, 150
    with torch.no_grad():
        detokenizer.clear_states()
        total = tokens.size(1)
        start, size = 0, first_chunk
        while True:
            end = start + size
            yield detokenizer.detokenize_streaming(tokens[:, start:end], is_final=end >= total)
            if end >= total:
                break
            start, size = end, later_chunk
================================================
FILE: modules/audio_detokenizer/bigvgan_wrapper.py
================================================
import os
import json
import logging
import librosa
import torch
from modules.audio_detokenizer.vocoder.bigvgan import BigVGAN
from modules.audio_detokenizer.vocoder.utils import get_melspec, AttrDict, load_checkpoint
logger = logging.getLogger(__name__)
class BigVGANWrapper:
    """Convenience wrapper around the BigVGAN vocoder: moves the model to a
    device/dtype, extracts mel spectrograms from audio, and decodes mel
    spectrograms back into waveforms."""

    def __init__(self, vocoder: BigVGAN, device: torch.device, h: AttrDict, dtype=None) -> None:
        # h holds the vocoder hyper-parameters loaded from its JSON config
        # (n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, ...).
        self.vocoder = vocoder.to(device)
        if dtype is not None:
            self.vocoder = self.vocoder.to(dtype)
        self.vocoder = self.vocoder.eval()
        self.device = device
        self.h = h

    def to_dtype(self, dtype):
        """Cast the vocoder weights to `dtype` (e.g. torch.bfloat16)."""
        self.vocoder = self.vocoder.to(dtype)

    def extract_mel_from_wav(self, wav_path=None, wav_data=None):
        """
        params:
            wav_path: str, path of the wav, should be 24k
            wav_data: torch.tensor or numpy array, shape [T], wav data, should be 24k
        return:
            mel: [T, num_mels], torch.tensor
        """
        if wav_data is None:
            wav_data, _ = librosa.load(wav_path, sr=self.h["sampling_rate"])
        wav_data = torch.tensor(wav_data).unsqueeze(0)
        mel = get_melspec(y=wav_data, n_fft=self.h["n_fft"], num_mels=self.h["num_mels"], sampling_rate=self.h["sampling_rate"],
                          hop_size=self.h["hop_size"], win_size=self.h["win_size"], fmin=self.h["fmin"], fmax=self.h["fmax"])
        return mel.squeeze(0).transpose(0, 1)

    @torch.inference_mode()
    def extract_mel_from_wav_batch(self, wav_data):
        """
        params:
            wav_data: torch.tensor or numpy array, shape [Batch, T], wav data, should be 24k
        return:
            mel: [Batch, T, num_mels], torch.tensor
        """
        wav_data = torch.tensor(wav_data)
        # NOTE(review): this passes the audio as `wav=` while the single-item
        # extract_mel_from_wav above passes it as `y=`; confirm get_melspec
        # accepts both keyword names, otherwise one of the two calls is wrong.
        mel = get_melspec(wav=wav_data, n_fft=self.h["n_fft"], num_mels=self.h["num_mels"], sampling_rate=self.h["sampling_rate"],
                          hop_size=self.h["hop_size"], win_size=self.h["win_size"], fmin=self.h["fmin"], fmax=self.h["fmax"])
        return mel.transpose(1, 2)

    def decode_mel(self, mel):
        """
        params:
            mel: [T, num_mels], torch.tensor
        return:
            wav: [1, T], torch.tensor
        """
        mel = mel.transpose(0, 1).unsqueeze(0).to(self.device)
        wav = self.vocoder(mel)
        return wav.squeeze(0)

    def decode_mel_batch(self, mel):
        """
        params:
            mel: [B, T, num_mels], torch.tensor
        return:
            wav: [B, 1, T], torch.tensor
        """
        mel = mel.transpose(1, 2).to(self.device)
        wav = self.vocoder(mel)
        return wav

    @classmethod
    def from_pretrained(cls, model_config, ckpt_path, device):
        """Load the vocoder hyper-parameters (JSON) and generator weights,
        then build a wrapper on `device`."""
        with open(model_config) as f:
            data = f.read()
        json_config = json.loads(data)
        h = AttrDict(json_config)
        vocoder = BigVGAN(h, True)
        state_dict_g = load_checkpoint(ckpt_path, "cpu")
        vocoder.load_state_dict(state_dict_g["generator"])
        logger.info(">>> Load vocoder from {}".format(ckpt_path))
        return cls(vocoder, device, h)
================================================
FILE: modules/audio_detokenizer/flow_matching/dit_block.py
================================================
import torch
import torch.nn as nn
import torch
import torch.nn as nn
import torch.nn.functional as F
from flash_attn import flash_attn_varlen_func, flash_attn_varlen_qkvpacked_func
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """Insert a singleton head axis into `freqs_cis` so it broadcasts over `x`.

    `x` is the complex-viewed query/key tensor of shape
    (bsz, seqlen, n_heads, head_dim / 2); `freqs_cis` must match it on the
    batch, sequence and last dimensions.  Returns a view of shape
    (bsz, seqlen, 1, head_dim / 2).
    """
    assert x.ndim == 4
    bsz, seqlen = x.shape[0], x.shape[1]
    half_dim = x.shape[-1]
    assert freqs_cis.shape == (bsz, seqlen, half_dim), \
        f'x shape: {x.shape}, freqs_cis shape: {freqs_cis.shape}'
    # Singleton axis broadcasts over the heads dimension in the pointwise multiply.
    return freqs_cis.view(bsz, seqlen, 1, half_dim)
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
):
    """Apply rotary position embeddings to query and key tensors.

    Adjacent pairs of the last dimension are treated as complex numbers and
    rotated by the complex phases in `freqs_cis`.  Inputs have shape
    (bsz, seqlen, n_heads, head_dim); `freqs_cis` has shape
    (bsz, seqlen, head_dim / 2) and complex dtype.  Returns the rotated
    (xq, xk), cast back to the input dtypes.
    """
    def _as_complex(t: torch.Tensor) -> torch.Tensor:
        # Pair up the last dimension as (real, imag) and view as complex.
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_c = _as_complex(xq)
    k_c = _as_complex(xk)
    # Broadcast the phases over the heads axis (inlined reshape_for_broadcast).
    assert q_c.ndim == 4
    assert freqs_cis.shape == (q_c.shape[0], q_c.shape[1], q_c.shape[-1]), \
        f'x shape: {q_c.shape}, freqs_cis shape: {freqs_cis.shape}'
    rot = freqs_cis.view(q_c.shape[0], q_c.shape[1], 1, q_c.shape[-1])
    q_out = torch.view_as_real(q_c * rot).flatten(3)
    k_out = torch.view_as_real(k_c * rot).flatten(3)
    return q_out.type_as(xq), k_out.type_as(xk)
class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.,
        proj_drop: float = 0.,
        norm_layer: nn.Module = nn.LayerNorm,
        flash_attention: bool = True
    ) -> None:
        """Multi-head self-attention with optional per-head QK normalization.

        Args:
            dim: model width; must be divisible by num_heads.
            num_heads: number of attention heads.
            qkv_bias: add a bias to the fused QKV projection.
            qk_norm: apply norm_layer to queries and keys per head.
            attn_drop: dropout probability on attention weights.
            proj_drop: dropout probability after the output projection.
            norm_layer: normalization module used when qk_norm is True.
            flash_attention: select the fused flash-attention path in forward.
        """
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # NOTE(review): scale appears unused by the visible flash-attn call
        # path (flash-attn applies its own 1/sqrt(d) scaling) — confirm.
        self.scale = self.head_dim ** -0.5
        self.fused_attn = flash_attention
        # Fused projection producing Q, K and V in a single matmul.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.qk_norm = qk_norm
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor, seq_len, cu_seqlens, max_seqlen, cu_seqlens_k, max_seqlen_k, rotary_pos_emb=None, incremental_state=None, nopadding=True) -> torch.Tensor:
B, N, C = x.shape
if self.fused_attn:
if nopadding:
qkv = self.qkv(x)
qkv = qkv.view(B * N, self.num_heads * 3, self.head_dim)
q, k, v = qkv.split([self.num_heads] * 3, dim=1)
q, k = self.q_norm(q), self.k_norm(k)
q = q.view(B, N, self.num_heads, self.head_dim)
k = k.view(B, N, self.num_heads, self.head_dim)
v = v.view(B, N, self.num_heads, self.head_dim)
if rotary_pos_emb is not None:
q, k = apply_rotary_emb(q, k, rotary_pos_emb)
if incremental_state is not None:
if "prev_k" in incremental_state:
prev_k = incremental_state["prev_k"]
k = torch.cat([prev_k, k], dim=1)
if "cur_k" not in incremental_state:
incremental_state["cur_k"] = {}
incremental_state["cur_k"] = k
if "prev_v" in incremental_state:
prev_v = incremental_state["prev_v"]
v = torch.cat([prev_v, v], dim=1)
if "cur_v" not in incremental_state:
incremental_state["cur_v"] = {}
incremental_state["cur_v"] = v
q = q.view(B * N, self.num_heads, self.head_dim)
k = k.view(-1, self.num_heads, self.head_dim)
v = v.view(-1, self.num_heads, self.head_dim)
x = flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen_k,
dropout_p=self.attn_drop.p if self.training else 0.,
)
else:
if incremental_state is not None:
raise NotImplementedError("It is designed for batching inference. AR-chunk is not supported currently.")
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
if self.qk_norm:
q, k, v = qkv.unbind(2)
q, k = self.q_norm(q), self.k_norm(k)
# re-bind
qkv = torch.stack((q, k, v), dim=2)
# pack qkv with seq_len
qkv_collect = []
for i in range(qkv.shape[0]):
qkv_collect.append(
qkv[i, :seq_len[i], :, :, :]
)
qkv = torch.cat(qkv_collect, dim=0)
x = flash_attn_varlen_qkvpacked_func(qkv=qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, dropout_p=self.attn_drop.p if self.training else 0.)
# unpack and pad 0
x_collect = []
for i in range(B):
x_collect.append(
x[cu_seqlens[i]:cu_seqlens[i+1], :, :]
)
x = torch.nn.utils.rnn.pad_sequence(x_collect, batch_first=True, padding_value=0)
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
x = x.transpose(1, 2)
x = x.reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
def modulate(x, shift, scale):
    """AdaLN modulation: scale ``x`` by ``(1 + scale)`` and add ``shift``."""
    scaled = x * (scale + 1)
    return scaled + shift
class FinalLayer(nn.Module):
    """
    The final layer of DiT: an adaLN-modulated LayerNorm followed by a
    linear projection to the output channels.
    """
    def __init__(self, hidden_size, out_channels):
        super().__init__()
        # Affine-free norm: shift/scale come from the conditioning signal.
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )
    def forward(self, x, c):
        """Project ``x`` to output channels, modulated by conditioning ``c``."""
        # Regress per-position shift and scale from the conditioning signal.
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=2)
        normed = self.norm_final(x)
        # Inlined adaLN modulation: y = norm(x) * (1 + scale) + shift.
        modulated = normed * (1 + scale) + shift
        return self.linear(modulated)
class DiTBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.

    Pre-norm attention and MLP sublayers whose shift/scale/gate tensors are
    regressed per position from the conditioning signal ``c``.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, ffn_type="conv1d_conv1d", ffn_gated_glu=True, ffn_act_layer="gelu", ffn_conv_kernel_size=5, **block_kwargs):
        super().__init__()
        # elementwise_affine=False: the affine transform comes from adaLN.
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        if ffn_type == "vanilla_mlp":
            from timm.models.vision_transformer import Mlp
            mlp_hidden_dim = int(hidden_size * mlp_ratio)
            approx_gelu = lambda: nn.GELU(approximate="tanh")
            self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
        else:
            # NOTE(review): the default ffn_type ("conv1d_conv1d") falls
            # through to this raise — callers must pass "vanilla_mlp"; confirm.
            raise NotImplementedError(f"FFN type {ffn_type} is not implemented")
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )
    def forward(self, x, c, seq_len, cu_seqlens, cu_maxlen, cu_seqlens_k, cu_maxlen_k, mask, rotary_pos_emb=None, incremental_state=None, nopadding=True):
        """Run the adaLN-Zero attention + MLP sublayers on ``x``.

        ``mask`` is only consulted when ``nopadding`` is False; the packed
        cu_seqlens metadata is forwarded to the flash-attention kernels.
        """
        # Six modulation tensors: shift/scale/gate for attention and for MLP.
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=2)
        x_ = modulate(self.norm1(x), shift_msa, scale_msa)
        if incremental_state is not None:
            # Give this block its own attention KV-cache slot.
            if "attn_kvcache" not in incremental_state:
                incremental_state["attn_kvcache"] = {}
            inc_attn = incremental_state["attn_kvcache"]
        else:
            inc_attn = None
        x_ = self.attn(x_, seq_len=seq_len, cu_seqlens=cu_seqlens, max_seqlen=cu_maxlen, cu_seqlens_k=cu_seqlens_k, max_seqlen_k=cu_maxlen_k, rotary_pos_emb=rotary_pos_emb, incremental_state=inc_attn, nopadding=nopadding)
        if not nopadding:
            # Zero padded positions before the gated residual add.
            x_ = x_ * mask[:, :, None]
        x = x + gate_msa * x_
        x_ = modulate(self.norm2(x), shift_mlp, scale_mlp)
        x_ = self.mlp(x_)
        if not nopadding:
            x_ = x_ * mask[:, :, None]
        x = x + gate_mlp * x_
        return x
================================================
FILE: modules/audio_detokenizer/flow_matching/model.py
================================================
import torch
import torch.nn as nn
import math
from modules.audio_detokenizer.flow_matching.dit_block import DiTBlock, FinalLayer
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0,
                         interpolation_factor: int = 1, max_seq_length: int = 4096):
    """Precompute the complex rotary-embedding table for RoPE.

    Returns a complex64 tensor of shape ``(min(end, max_seq_length), dim // 2)``
    whose row ``p`` holds ``exp(i * p * omega)`` for the RoPE frequency ladder.
    ``interpolation_factor`` compresses positions (type-A RoPE extension:
    interpolation rather than extrapolation).
    """
    print(f'using rope base theta = {theta}, interpolation factor = {interpolation_factor}')
    half_idx = torch.arange(0, dim, 2)[: (dim // 2)].float()
    inv_freq = 1.0 / (theta ** (half_idx / dim))
    # Positions are floats so they can be scaled by the interpolation factor.
    positions = torch.arange(end, device=inv_freq.device).float()
    positions = positions * (1.0 / float(interpolation_factor))
    angles = torch.outer(positions, inv_freq).float()  # type: ignore
    table = torch.polar(torch.ones_like(angles), angles)  # complex64
    # Trim unused rows (e.g. table for 1M positions but 32k max length)
    # to avoid wasting GPU memory.
    if max_seq_length < end:
        table = table[:max_seq_length, ].clone()
    return table
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    A fixed sinusoidal frequency encoding is computed first, then projected
    to ``hidden_size`` by a two-layer SiLU MLP.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_size
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        # Geometric frequency ladder from 1 down to ~1/max_period.
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).float().to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Odd target dimension: pad a single zero column.
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding
    def forward(self, t):
        freq_emb = self.timestep_embedding(t, self.frequency_embedding_size)
        # Cast to the MLP weight dtype (e.g. bf16) before projecting.
        return self.mlp(freq_emb.to(self.mlp[0].weight.dtype))
class SinusoidalPositionalEmbedding(nn.Module):
    """Sinusoidal positional embeddings of arbitrary length.

    Positions are numbered from ``padding_idx + 1``; padding symbols get a
    zero embedding and are skipped when numbering positions.
    """
    def __init__(self, embedding_dim, padding_idx, init_size=1024):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        # Pre-build a table of init_size rows; forward() grows it on demand.
        self.weights = SinusoidalPositionalEmbedding.get_embedding(
            init_size,
            embedding_dim,
            padding_idx,
        )
        # Dummy buffer used only to track the module's device/dtype.
        self.register_buffer('_float_tensor', torch.FloatTensor(1))
    @staticmethod
    def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
        """Build a (num_embeddings, embedding_dim) sinusoidal table.

        Matches the tensor2tensor implementation, which differs slightly from
        the description in Section 3.5 of "Attention Is All You Need".
        """
        half_dim = embedding_dim // 2
        step = math.log(10000) / (half_dim - 1)
        inv_freq = torch.exp(torch.arange(half_dim, dtype=torch.float) * -step)
        # angles[pos, i] = pos / 10000 ** (2i / (d - 2)); shape (num, d/2)
        angles = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * inv_freq.unsqueeze(0)
        table = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # Odd dimension: pad one zero column.
            table = torch.cat([table, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            table[padding_idx, :] = 0
        return table
    def forward(self, input, incremental_state=None, timestep=None, **kwargs):
        """Input is expected to be of size [bsz x seqlen]."""
        bsz, seq_len = input.shape[:2]
        max_pos = self.padding_idx + 1 + seq_len
        if self.weights is None or max_pos > self.weights.size(0):
            # Table too small for this sequence: rebuild it larger.
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                max_pos,
                self.embedding_dim,
                self.padding_idx,
            )
        self.weights = self.weights.to(self._float_tensor)
        if incremental_state is not None:
            # Single-step decoding: every token shares one position.
            pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
            return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
        positions = self.make_positions(input, self.padding_idx)
        return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()  # (B, T, dim)
    def max_positions(self):
        """Maximum number of supported positions."""
        return int(1e5)  # an arbitrary large number
    def make_positions(self, tensor, padding_idx):
        """Replace non-padding symbols with their position numbers.

        Position numbers begin at ``padding_idx + 1``; padding is ignored.
        """
        # The cast dance keeps this ONNX/XLA friendly (cumsum dtype handling).
        mask = tensor.ne(padding_idx).int()
        return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx
class DiTPrefix(nn.Module):
    """
    Diffusion model with a Transformer backbone.

    Maps noisy speech latents plus semantic-token conditioning and a
    diffusion timestep to a denoising prediction, using a stack of
    adaLN-Zero DiT blocks with packed (flash-attn varlen) attention.
    """
    def __init__(
        self,
        input_size,
        output_size,
        semantic_vocab_size,
        hidden_size=1024,
        depth=12,
        num_heads=4,
        # mlp related
        mlp_ratio=4.0,
        ffn_type="conv1d_conv1d",
        ffn_gated_glu=True,
        ffn_act_layer="gelu",
        ffn_conv_kernel_size=5,
        # rope
        use_rope=False,
        # NOTE(review): mutable default dict — fine as long as callers never
        # mutate it in place; confirm.
        rope_params={
            "max_position_embeddings": 4096,
            "rope_base": 10000.0,
            "rope_interpolation_factor": 1.0,
        },
        position_embedding_type="sincos",
        max_seq_len=4096,
        prompt_cfg_dropout=0.0
    ):
        super().__init__()
        self.num_heads = num_heads
        self.prompt_cfg_dropout = prompt_cfg_dropout
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.semantic_token_embedding = nn.Embedding(semantic_vocab_size, hidden_size)
        self.input_linear = nn.Linear(input_size, hidden_size)
        # position embedding (absolute; "skip" disables it, e.g. when RoPE is used)
        if position_embedding_type == "learnable":
            self.position_embedding = nn.Embedding(max_seq_len+1, hidden_size)
        elif position_embedding_type == "sincos":
            self.position_embedding = SinusoidalPositionalEmbedding(hidden_size, 0, max_seq_len+1)
        elif position_embedding_type == "skip":
            self.position_embedding = None
        else:
            raise NotImplementedError("Position embedding type: {} not implemented.".format(position_embedding_type))
        self.use_rope = use_rope
        if self.use_rope:
            assert hidden_size % num_heads == 0, "Hidden size must be divisible by num_heads for rope position embedding."
            rope_dim = hidden_size // num_heads
            # Precomputed complex rotary table, indexed by position id in forward().
            self.rotary_pos_emb = precompute_freqs_cis(
                rope_dim, rope_params["max_position_embeddings"],
                theta=rope_params["rope_base"],
                interpolation_factor=rope_params["rope_interpolation_factor"],
                max_seq_length=max_seq_len
            )
        self.blocks = nn.ModuleList([
            DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio,
                     ffn_type=ffn_type, ffn_conv_kernel_size=ffn_conv_kernel_size, ffn_gated_glu=ffn_gated_glu, ffn_act_layer=ffn_act_layer) for _ in range(depth)
        ])
        self.final_layer = FinalLayer(hidden_size, output_size)
        self.initialize_weights()
    def initialize_weights(self):
        """Xavier-init linear layers; zero the adaLN and output projections.

        Zeroing adaLN/output layers implements the adaLN-Zero scheme: every
        block starts as the identity mapping.
        """
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)
        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        # Zero-out adaLN modulation layers in DiT blocks:
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)
    def forward(self, x, position_ids, t, condition, seq_len, cu_seqlens, cu_maxlen, cu_seqlens_k, cu_maxlen_k, mask, incremental_state=None, nopadding=True):
        """
        Forward pass of DiT.
        x: (N, T, C) tensor of inputs (latent representations of speech)
        position_ids: (N, T) tensor of positional indices
        t: (N,) tensor of diffusion timesteps
        condition: (N, T) tensor of semantic tokens
        seq_len: (N,) tensor of sequence lengths

        cu_seqlens/cu_maxlen (queries) and cu_seqlens_k/cu_maxlen_k (keys)
        are the packed-sequence metadata for flash attention;
        incremental_state is the per-block streaming KV cache.
        """
        condition = self.semantic_token_embedding(condition) # (N, T, D)
        x = self.input_linear(x)
        if self.position_embedding is not None:
            position_emb = self.position_embedding(position_ids)
            x = x + position_emb
        # ROPE: gather each sample's rotary rows by its absolute position ids.
        if self.use_rope:
            bsz, seqlen = position_ids.shape
            if self.rotary_pos_emb.device != position_ids.device:
                self.rotary_pos_emb = self.rotary_pos_emb.to(position_ids.device)
            rotary_pos_emb = torch.zeros((bsz, seqlen, self.rotary_pos_emb.shape[1]),
                                         dtype=self.rotary_pos_emb.dtype,
                                         device=self.rotary_pos_emb.device)
            for b in range(bsz):
                cur_rope = rotary_pos_emb[b]
                cur_position_ids = position_ids[b]
                cur_rope[:] = self.rotary_pos_emb[cur_position_ids]
        else:
            rotary_pos_emb = None
        t = self.t_embedder(t) # (N, D)
        # Per-position conditioning: timestep embedding + semantic embedding.
        c = t.unsqueeze(1) + condition # (N, T, D)
        for block_idx, block in enumerate(self.blocks):
            # x = block(x, c, attn_mask) # (N, T, D)
            # XXX mask could be None because we always use full mask
            if incremental_state is not None:
                # Each block owns one slot of the streaming KV cache.
                if block_idx not in incremental_state:
                    incremental_state[block_idx] = {}
                incr = incremental_state[block_idx]
            else:
                incr = None
            x = block(x=x, c=c, seq_len=seq_len, cu_seqlens=cu_seqlens, cu_maxlen=cu_maxlen, cu_seqlens_k=cu_seqlens_k, cu_maxlen_k=cu_maxlen_k, mask=mask, rotary_pos_emb=rotary_pos_emb, incremental_state=incr, nopadding=nopadding)
        x = self.final_layer(x, c) # (N, T, C)
        return x
================================================
FILE: modules/audio_detokenizer/flow_matching/ode_wrapper.py
================================================
import torch
import torch.nn as nn
from functools import lru_cache
import copy
@lru_cache(maxsize=1)
def get_cached_zeros(numel, device="cpu", dtype=torch.float32):
    """Return a zeros tensor of ``numel`` elements, memoizing the most recent
    (numel, device, dtype) request to avoid reallocating on every ODE step.

    NOTE(review): callers share the cached tensor, so it must never be
    mutated in place — in this file it is only used as an addition operand.
    """
    return torch.zeros(numel, dtype=dtype, device=device)
class StreamingODEWrapperForPrefix(nn.Module):
    """ODE right-hand side wrapping a DiTPrefix net for streaming inference.

    Holds the per-chunk conditioning (semantic tokens, position ids, packed
    flash-attention ``cu_seqlens`` metadata) and the per-layer attention KV
    cache, so an external Euler / NeuralODE solver can call ``forward(t, x)``
    repeatedly within a chunk and ``set_conditions`` between chunks.
    """
    def __init__(self, net, x_mask, x_cond, use_cfg=False, use_cfg_rescale=True, cfg_init=1.0, cfg_scale=4.0, cfg_schedule="linear", cfg_token_id=0):
        super(StreamingODEWrapperForPrefix, self).__init__()
        self.net = net
        self.x_mask = x_mask
        self.x_cond = x_cond
        # CFG knobs are accepted for API symmetry but CFG must be disabled.
        assert use_cfg == False, "cfg is not supported in streaming detokenizer"
        self.use_cfg = use_cfg
        self.use_cfg_rescale = use_cfg_rescale
        self.cfg_init = cfg_init
        self.cfg_scale = cfg_scale
        self.cfg_token_id = cfg_token_id
        self.cfg_schedule = cfg_schedule
        # Streaming state, rebuilt by set_conditions()/update_incremental_state():
        self.position_ids = None          # (1, T) absolute positions of the current chunk
        self.seq_len = None               # per-sample chunk lengths
        self.incremental_state = {}       # per-DiT-layer attention KV cache
        self.kv_cache_tokens = 0          # current KV cache length in tokens
        self.cu_seqlens = None            # packed query offsets (current chunk)
        self.cu_maxlen = None
        self.cu_seqlens_k = None          # packed key offsets (cache + chunk)
        self.cu_maxlen_k = None
        self.previous_seqlen = None       # accumulated key lengths per sample
    def clear_all_states(self):
        """Drop the KV cache and all packed-sequence bookkeeping."""
        self.incremental_state = {}
        self.kv_cache_tokens = 0
        self.cu_seqlens = None
        self.cu_maxlen = None
        self.cu_seqlens_k = None
        self.cu_maxlen_k = None
        self.previous_seqlen = None
    def state_dict(self):
        # NOTE: overrides nn.Module.state_dict — returns a deep copy of the
        # streaming solver state, not the module parameters.
        return {
            "incremental_state": copy.deepcopy(self.incremental_state),
            "kv_cache_tokens": copy.deepcopy(self.kv_cache_tokens),
            "cu_seqlens": copy.deepcopy(self.cu_seqlens),
            "cu_maxlen": copy.deepcopy(self.cu_maxlen),
            "cu_seqlens_k": copy.deepcopy(self.cu_seqlens_k),
            "cu_maxlen_k": copy.deepcopy(self.cu_maxlen_k),
            "previous_seqlen": copy.deepcopy(self.previous_seqlen)
        }
    def load_state_dict(self, state_dict):
        # Counterpart of state_dict() above; restores streaming state in place.
        self.incremental_state = state_dict["incremental_state"]
        self.kv_cache_tokens = state_dict["kv_cache_tokens"]
        self.cu_seqlens = state_dict["cu_seqlens"]
        self.cu_maxlen = state_dict["cu_maxlen"]
        self.cu_seqlens_k = state_dict["cu_seqlens_k"]
        self.cu_maxlen_k = state_dict["cu_maxlen_k"]
        self.previous_seqlen = state_dict["previous_seqlen"]
    def set_conditions(self, x_mask, x_cond, start_position_id, cache={}):
        """Install the conditioning for the next chunk.

        Builds absolute position ids starting at ``start_position_id`` and the
        packed cu_seqlens / cu_seqlens_k metadata, accumulating previous chunk
        lengths from ``cache["previous_seqlen"]``.  Returns a dict carrying the
        updated ``previous_seqlen`` to pass back in on the next call.

        NOTE(review): the mutable default ``cache={}`` is shared across calls;
        callers in this repo always pass an explicit cache.
        """
        if not self.use_cfg:
            self.x_mask = x_mask
            self.x_cond = x_cond
        else:
            # Unreachable while the constructor asserts use_cfg == False.
            self.x_cond = torch.cat((x_cond, x_cond), dim=0)
            self.x_mask = torch.cat((x_mask, x_mask), dim=0)
        # Absolute positions for this chunk, continuing from previous chunks.
        position_ids_cur = [i for i in range(start_position_id, self.x_cond.shape[1] + start_position_id)]
        position_ids = torch.tensor([position_ids_cur])
        if not self.use_cfg:
            self.position_ids = position_ids.to(self.x_cond.device).long()
            self.seq_len = torch.Tensor([position_ids.shape[1]]).to(self.x_cond.device).long()
        else:
            self.position_ids = torch.cat((position_ids, position_ids), dim=0).to(self.x_cond.device).long()
            self.seq_len = torch.Tensor([position_ids.shape[1], position_ids.shape[1]]).to(self.x_cond.device).long()
        # Packed (flash-attn style) cumulative query lengths for this chunk.
        cu_seqlens = torch.cumsum(self.seq_len, dim=0)
        self.cu_seqlens = torch.cat([torch.Tensor([0]).to(cu_seqlens.device), cu_seqlens], dim=0).int()
        self.cu_maxlen = self.seq_len.cpu().max()
        if self.cu_seqlens_k is None:
            # First chunk: keys span exactly the current queries.
            self.cu_seqlens_k = self.cu_seqlens
            self.cu_maxlen_k = self.cu_maxlen
            previous_seqlen = self.seq_len
        else:
            # Later chunks: keys also cover the cached tokens from before.
            previous_seqlen_old = cache["previous_seqlen"]
            previous_seqlen = previous_seqlen_old + self.seq_len
            # calculate cu_seqlens_k
            cu_seqlens_k = torch.cumsum(previous_seqlen, dim=0)
            self.cu_seqlens_k = torch.cat([torch.Tensor([0]).to(cu_seqlens_k.device), cu_seqlens_k], dim=0).int()
            self.cu_maxlen_k = previous_seqlen.cpu().max()
        self.previous_seqlen = previous_seqlen
        ret_cache = {
            "previous_seqlen": previous_seqlen
        }
        return ret_cache
    def update_incremental_state(self, reserve_kv_cache_tokens=0, max_kv_cache_tokens=900, condition_cache={"previous_seqlen"}):
        """Roll the per-layer KV cache forward after a chunk is decoded.

        Promotes cur_k/cur_v to prev_k/prev_v and, when the cache exceeds
        ``max_kv_cache_tokens``, keeps the first ``reserve_kv_cache_tokens``
        (the prompt) plus the most recent tokens.

        NOTE(review): the default ``condition_cache={"previous_seqlen"}`` is a
        set literal, which would fail at the item assignment below — callers
        always pass a dict; confirm the default is never relied upon.
        """
        assert reserve_kv_cache_tokens <= max_kv_cache_tokens, "reserve_kv_cache_tokens should be less than or equal to max_kv_cache_tokens"
        for layer_idx, layer_cache in self.incremental_state.items():
            # update attention kv cache
            layer_cache["attn_kvcache"]["prev_k"] = layer_cache["attn_kvcache"]["cur_k"]
            layer_cache["attn_kvcache"]["prev_v"] = layer_cache["attn_kvcache"]["cur_v"]
            self.kv_cache_tokens = layer_cache["attn_kvcache"]["prev_k"].shape[1]
            if self.kv_cache_tokens > max_kv_cache_tokens:
                # drop old tokens from reserve kv cache tokens to max_kv_cache_tokens
                reserve_tokens_excludeprompt = max_kv_cache_tokens - reserve_kv_cache_tokens
                if reserve_kv_cache_tokens == 0:
                    # No reserved prompt: keep only the newest tokens.
                    layer_cache["attn_kvcache"]["prev_k"] = layer_cache["attn_kvcache"]["prev_k"][:, -reserve_tokens_excludeprompt:]
                    layer_cache["attn_kvcache"]["prev_v"] = layer_cache["attn_kvcache"]["prev_v"][:, -reserve_tokens_excludeprompt:]
                elif reserve_tokens_excludeprompt == 0:
                    # Cache budget is entirely prompt: keep only the prompt head.
                    layer_cache["attn_kvcache"]["prev_k"] = layer_cache["attn_kvcache"]["prev_k"][:, :reserve_kv_cache_tokens]
                    layer_cache["attn_kvcache"]["prev_v"] = layer_cache["attn_kvcache"]["prev_v"][:, :reserve_kv_cache_tokens]
                else:
                    # Keep prompt head + most recent tail.
                    layer_cache["attn_kvcache"]["prev_k"] = torch.cat([
                        layer_cache["attn_kvcache"]["prev_k"][:, :reserve_kv_cache_tokens],
                        layer_cache["attn_kvcache"]["prev_k"][:, -reserve_tokens_excludeprompt:]
                    ], dim=1)
                    layer_cache["attn_kvcache"]["prev_v"] = torch.cat([
                        layer_cache["attn_kvcache"]["prev_v"][:, :reserve_kv_cache_tokens],
                        layer_cache["attn_kvcache"]["prev_v"][:, -reserve_tokens_excludeprompt:]
                    ], dim=1)
                bsz = layer_cache["attn_kvcache"]["prev_k"].shape[0]
                # Key lengths shrank: refresh previous_seqlen and tell the caller.
                self.previous_seqlen = torch.Tensor([layer_cache["attn_kvcache"]["prev_k"].shape[1] for i in range(bsz)]).to(layer_cache["attn_kvcache"]["prev_k"].device).long()
                condition_cache["previous_seqlen"] = self.previous_seqlen
                self.kv_cache_tokens = layer_cache["attn_kvcache"]["prev_k"].shape[1]
            # clear current cache
            layer_cache["attn_kvcache"].pop("cur_k")
            layer_cache["attn_kvcache"].pop("cur_v")
    def forward(self, t, x, args=None):
        """Predict the flow-matching velocity at ODE time ``t`` for state ``x``."""
        # Broadcast the scalar solver time (in [0, 1]) to one integer DiT
        # timestep per sample, scaled to the model's 1000-step grid.
        t = get_cached_zeros(x.shape[0], device=x.device, dtype=torch.long) + (t * 1000).long()
        if self.use_cfg:
            raise NotImplementedError("cfg is not supported in streaming detokenizer.")
        else:
            pred_noise = self.net(x=x, condition=self.x_cond, t=t, position_ids=self.position_ids,
                                  cu_seqlens=self.cu_seqlens, cu_maxlen=self.cu_maxlen,
                                  cu_seqlens_k=self.cu_seqlens_k, cu_maxlen_k=self.cu_maxlen_k,
                                  incremental_state=self.incremental_state, nopadding=True,
                                  mask=None, seq_len=None
                                  )
        return pred_noise
================================================
FILE: modules/audio_detokenizer/flow_matching/scheduler.py
================================================
import torch
from abc import abstractmethod, ABC
try:
from torchdyn.core import NeuralODE
NEURALODE_INSTALLED = True
except ImportError:
NEURALODE_INSTALLED = False
class SchedulerBase(ABC):
    """Abstract interface every flow-matching/diffusion scheduler implements."""

    def __init__(self) -> None:
        pass

    @abstractmethod
    def set_timesteps(self):
        """Configure the number of integration steps."""

    @abstractmethod
    def step(self):
        """Advance the sample by one integration step."""

    @abstractmethod
    def add_noise(self):
        """Mix noise into clean samples (training-time forward process)."""
class StreamingFlowMatchingScheduler(SchedulerBase):
    """Euler-style sampler for conditional flow matching (streaming variant).

    Integrates dx/dt = v(x, t) from ``t_min`` to ``t_max`` with a fixed step
    size, either via a hand-rolled Euler loop or torchdyn's NeuralODE solver.
    """
    def __init__(self, timesteps=1000, sigma_min=1e-4,
                 ) -> None:
        super().__init__()
        self.sigma_min = sigma_min
        self.timesteps = timesteps
        self.t_min = 0
        # Stop slightly before t=1; sigma_min keeps terminal variance positive.
        self.t_max = 1 - self.sigma_min
        self.neural_ode = None
    def set_timesteps(self, timesteps=15):
        """Set the number of Euler steps used at inference time."""
        self.timesteps = timesteps
    def step(self, xt, predicted_v):
        """Single Euler update: ``x <- x + h * v``."""
        step_size = (self.t_max - self.t_min) / self.timesteps
        h = step_size * torch.ones(xt.shape[0], dtype=xt.dtype, device=xt.device)
        return xt + h * predicted_v
    def sample(self, ode_wrapper, time_steps, xt, verbose=False, x0=None):
        """Integrate with a plain Euler loop; returns the final state."""
        step_size = (self.t_max - self.t_min) / self.timesteps
        h = step_size * torch.ones(xt.shape[0], dtype=xt.dtype, device=xt.device)
        if verbose:
            # Ground-truth velocity for logging (requires x0).
            gt_v = x0 - xt
        for t in time_steps:
            predicted_v = ode_wrapper(t, xt)
            if verbose:
                dist = torch.mean(torch.nn.functional.l1_loss(gt_v, predicted_v))
                print("Time: {}, Distance: {}".format(t, dist))
            xt = xt + h * predicted_v
        return xt
    def sample_by_neuralode(self, ode_wrapper, time_steps, xt, verbose=False, x0=None):
        """Integrate with torchdyn's NeuralODE (Euler solver, built lazily)."""
        if not NEURALODE_INSTALLED:
            raise ImportError("NeuralODE is not installed, please install it first.")
        if self.neural_ode is None:
            self.neural_ode = NeuralODE(ode_wrapper, solver='euler', sensitivity="adjoint", atol=self.sigma_min, rtol=self.sigma_min)
        eval_points, traj = self.neural_ode(xt, time_steps)
        # Only the final trajectory point is needed.
        return traj[-1]
    def add_noise(self, original_samples: torch.FloatTensor,
                  noise: torch.FloatTensor,
                  timesteps: torch.IntTensor,):
        """Flow-matching forward process.

        Returns ``(x_noisy, ut)`` where ``x_noisy`` linearly interpolates
        between noise and data at normalized time t, and ``ut`` is the target
        velocity (independent of the interpolation's gradient).
        """
        ut = original_samples - (1 - self.sigma_min) * noise
        t_unsqueeze = timesteps.unsqueeze(1).unsqueeze(1).float() / self.timesteps
        x_noisy = t_unsqueeze * original_samples + (1. - (1 - self.sigma_min) * t_unsqueeze) * noise
        return x_noisy, ut
================================================
FILE: modules/audio_detokenizer/semantic_fm_prefix_streaming.py
================================================
import yaml
import logging
import time
import os
import torch
from modules.audio_detokenizer.flow_matching.ode_wrapper import StreamingODEWrapperForPrefix
from modules.audio_detokenizer.flow_matching.model import DiTPrefix
from modules.audio_detokenizer.flow_matching.scheduler import StreamingFlowMatchingScheduler
logger = logging.getLogger(__name__)
class StreamingSemanticFMWrapper:
    """Chunk-wise streaming semantic-token → mel flow-matching detokenizer.

    Drives a DiTPrefix model through StreamingODEWrapperForPrefix and
    StreamingFlowMatchingScheduler chunk by chunk, maintaining the solver's
    KV cache and an absolute position counter, with optional prompt prefill.
    """
    def __init__(self, speech_model: DiTPrefix, max_kv_cache_tokens=900, max_prompt_chunk=2,
                 use_cfg=True, use_cfg_rescale=True, cfg_init=1.5, cfg_scale=7.5, cfg_schedule="linear", cfg_token_id=0,
                 normalize_mel=False, mel_mean=None, mel_std=None, device: torch.device = torch.device("cpu")) -> None:
        # The DiT runs in bfloat16 for inference.
        self.dtype = torch.bfloat16
        self.speech_model = speech_model.to(device).to(self.dtype)
        self.speech_model = self.speech_model.eval()
        self.device = device
        self.normalize_mel = normalize_mel
        self.mel_mean = mel_mean
        self.mel_std = mel_std
        self.use_cfg = use_cfg
        self.use_cfg_rescale = use_cfg_rescale
        self.cfg_init = cfg_init
        self.cfg_scale = cfg_scale
        self.cfg_schedule = cfg_schedule
        self.incremental_state = {}
        # Carries accumulated key lengths between chunks (see set_conditions).
        self.condition_cache = {"previous_seqlen": 0}
        logger.info(f">>> SemanticFMWrapper initialized with use_cfg={use_cfg}, use_cfg_rescale={use_cfg_rescale}, cfg_init={cfg_init}, cfg_scale={cfg_scale}, cfg_schedule={cfg_schedule}")
        self.scheduler = StreamingFlowMatchingScheduler()
        self.ode_wrapper = StreamingODEWrapperForPrefix(net=self.speech_model, x_mask=None, x_cond=None,
                                                        use_cfg=use_cfg, use_cfg_rescale=use_cfg_rescale, cfg_init=cfg_init, cfg_scale=cfg_scale, cfg_schedule=cfg_schedule, cfg_token_id=cfg_token_id)
        self.max_kv_cache_tokens = max_kv_cache_tokens
        self.max_prompt_chunk = max_prompt_chunk
        self.reserve_kv_cache_tokens = 0
        # NOTE(review): self.start_position_id is only created by
        # clear_all_states()/load_state_dict(); callers must invoke one of
        # them before infer_mel()/prefill() — confirm.
    @torch.inference_mode()
    def infer_chunk(self, xt_chunk, semantic_tokens_chunk, start_position_id,
                    cache = None, look_ahead_tokens=0,
                    ode_steps=15, verbose=False, ode_solver="neural_ode_euler"):
        """
        Decode one chunk of noise into mel frames, conditioned on tokens.
        semantic_tokens: [T_1], torch.LongTensor
        xt: [T_2, 80], torch.Tensor, DO NOT normalize it outside
        ode_steps: int, number of ode steps, default 15
        verbose: bool, default False
        ode_solver: str, ode solver, expected in ("neural_ode_euler", "naive_euler"), default "neural_ode_euler"

        NOTE(review): when look_ahead_tokens > 0, ``cache`` must be a dict —
        the default None would fail at the item assignment below; confirm
        callers always pass one in that mode.
        """
        bs = 1
        self.scheduler.set_timesteps(ode_steps)
        semantic_tokens_chunk = semantic_tokens_chunk.unsqueeze(0).to(self.device)
        xt_chunk = xt_chunk.unsqueeze(0).to(self.device).to(self.dtype)
        t_span = torch.linspace(0, 1, self.scheduler.timesteps)
        x_mask = torch.zeros(bs, xt_chunk.shape[1], device=self.device).bool()
        cache_ret = self.ode_wrapper.set_conditions(x_mask=x_mask, x_cond=semantic_tokens_chunk, start_position_id=start_position_id, cache=self.condition_cache)
        if verbose:
            t_start = time.time()
        if ode_solver == "neural_ode_euler":
            x_t = self.scheduler.sample_by_neuralode(self.ode_wrapper, time_steps=t_span, xt=xt_chunk, verbose=False)
        elif ode_solver == "naive_euler":
            x_t = self.scheduler.sample(ode_wrapper=self.ode_wrapper, time_steps=t_span, xt=xt_chunk, verbose=False)
        else:
            raise NotImplementedError("ode_solver should be in ('neural_ode_euler', 'naive_euler')")
        if look_ahead_tokens > 0:
            # Hold back the look-ahead tail: remember its tokens, drop its frames.
            semantic_tokens_left = semantic_tokens_chunk.view(-1)[-look_ahead_tokens:]
            cache["semantic_token"] = semantic_tokens_left
            x_t_ret = x_t[:, :-look_ahead_tokens, :]
        else:
            x_t_ret = x_t
        if look_ahead_tokens > 0:
            # Re-run conditioning without the look-ahead tail so the KV cache
            # only commits the emitted frames.
            x_mask = torch.zeros(bs, xt_chunk.shape[1] - look_ahead_tokens, device=self.device).bool()
            self.condition_cache = self.ode_wrapper.set_conditions(x_mask=x_mask, x_cond=semantic_tokens_chunk[:, :-look_ahead_tokens], start_position_id=start_position_id, cache=self.condition_cache)
            self.ode_wrapper(torch.Tensor([0.999]).to(x_t_ret.device), x_t_ret)
        else:
            self.condition_cache = cache_ret
        if verbose:
            t_end = time.time()
            logger.info(f"[ODE Chunk] Time cost: {t_end - t_start}")
        if self.normalize_mel:
            # Undo training-time mel normalization before returning.
            x_t_ret = x_t_ret * self.mel_std + self.mel_mean
        return x_t_ret.squeeze(0)
    @torch.inference_mode()
    def infer_mel(self, semantic_tokens, ode_steps=15, chunk_size=150, verbose=False, ode_solver="neural_ode_euler"):
        """
        Decode a full token sequence to mel by iterating infer_chunk.
        semantic_tokens: [T_1], torch.LongTensor
        ode_steps: int, number of ode steps, default 15
        verbose: bool, default False
        ode_solver: str, ode solver, expected in ("neural_ode_euler", "naive_euler"), default "neural_ode_euler"

        NOTE(review): relies on self.start_position_id — call
        clear_all_states() (or load_state_dict) first.
        """
        assert semantic_tokens.dim() == 1
        # One 80-bin mel frame of Gaussian noise per semantic token.
        x_t = torch.randn(semantic_tokens.shape[0], 80).to(self.device).to(self.dtype)
        seq_len = semantic_tokens.shape[0]
        num_chunks = seq_len // chunk_size
        if seq_len % chunk_size != 0:
            num_chunks += 1
        x_pred_collect = []
        if verbose:
            t_start = time.time()
        for chunk_id in range(num_chunks):
            start = chunk_id * chunk_size
            end = min(start + chunk_size, seq_len)
            semantic_tokens_chunk = semantic_tokens[start:end]
            x_t_chunk = x_t[start:end, :]
            x_pred = self.infer_chunk(xt_chunk=x_t_chunk, semantic_tokens_chunk=semantic_tokens_chunk, start_position_id=self.start_position_id,
                                      ode_steps=ode_steps, verbose=verbose, ode_solver=ode_solver)
            # Advance the absolute position counter and commit the KV cache.
            self.start_position_id += end - start
            self.update_incremental_state()
            x_pred_collect.append(x_pred)
        if verbose:
            t_end = time.time()
            logger.info(f"[ODE] Time cost: {t_end - t_start}")
        x_pred = torch.cat(x_pred_collect, dim=0)
        return x_pred
    def clear_all_states(self):
        """Reset the position counter and all solver-side streaming state."""
        self.start_position_id = 0
        self.condition_cache = {"previous_seqlen": 0}
        self.ode_wrapper.clear_all_states()
    def state_dict(self):
        """Snapshot the streaming state (position counter + solver caches)."""
        return {
            "start_position_id": self.start_position_id,
            "ode_wrapper": self.ode_wrapper.state_dict(),
            "condition_cache": self.condition_cache
        }
    def load_state_dict(self, state_dict):
        # No-op when state_dict is None.
        if state_dict is not None:
            self.start_position_id = state_dict["start_position_id"]
            self.ode_wrapper.load_state_dict(state_dict["ode_wrapper"])
            self.condition_cache = state_dict["condition_cache"]
    def update_incremental_state(self):
        # Commit the freshly decoded chunk into the ODE wrapper's KV cache.
        # NOTE(review): always passes reserve_kv_cache_tokens=0 even though
        # prefill() accumulates self.reserve_kv_cache_tokens — confirm intended.
        self.ode_wrapper.update_incremental_state(reserve_kv_cache_tokens=0, max_kv_cache_tokens=self.max_kv_cache_tokens, condition_cache=self.condition_cache)
    @torch.inference_mode()
    def prefill(self, mel, semantic_token, chunk_size=150, verbose=False):
        """
        Prefill the KV cache with a (mel, semantic_token) prompt.
        mel: [T, 80], torch.Tensor
        semantic_token: [T], torch.LongTensor
        chunk_size: int, default 150

        The irregular head of the prompt is prefilled first, then up to
        ``max_prompt_chunk`` full chunks.
        """
        assert mel.dim() == 2
        assert semantic_token.dim() == 1
        assert semantic_token.shape[0] == mel.shape[0], "Semantic token and mel shape mismatch"
        seq_len = mel.shape[0]
        # Keep at most max_prompt_chunk full chunks; the remainder is the head.
        num_chunks = min(seq_len // chunk_size, self.max_prompt_chunk)
        start_pos = seq_len - num_chunks * chunk_size
        res_mel = mel[:start_pos, :]
        res_semantic_token = semantic_token[:start_pos]
        self.prefill_chunk(res_mel, res_semantic_token, start_position_id=self.start_position_id)
        self.start_position_id += start_pos
        self.update_incremental_state()
        # Mark prompt tokens as reserved so cache trimming keeps them.
        # NOTE(review): kv_cache_tokens is the *total* cache length, so this
        # accumulation may overcount across iterations — confirm.
        self.reserve_kv_cache_tokens += self.ode_wrapper.kv_cache_tokens
        if verbose:
            logger.info("Prefilling prompt with {} chunks".format(num_chunks))
            start_time = time.time()
        for chunk_id in range(num_chunks):
            start = start_pos + chunk_id * chunk_size
            end = start + chunk_size
            mel_chunk = mel[start:end, :]
            semantic_token_chunk = semantic_token[start:end]
            self.prefill_chunk(mel_chunk, semantic_token_chunk, start_position_id=self.start_position_id)
            self.start_position_id += end - start
            self.update_incremental_state()
            self.reserve_kv_cache_tokens += self.ode_wrapper.kv_cache_tokens
        if verbose:
            logger.info("Prefilling done in {:.2f} seconds".format(time.time() - start_time))
    def prefill_chunk(self, mel_chunk, semantic_tokens_chunk, start_position_id=0):
        """
        Prefill one prompt chunk into the KV cache.
        mel_chunk: [T, 80], torch.Tensor, T is the chunk size
        semantic_tokens_chunk: [T], torch.LongTensor
        start_position_id: int, default 0
        """
        bs = 1
        semantic_tokens_chunk = semantic_tokens_chunk.unsqueeze(0).to(self.device)
        mel_chunk = mel_chunk.unsqueeze(0).to(self.device).to(self.dtype)
        if self.normalize_mel:
            # Prompt mel is normalized the same way training data was.
            mel_chunk = (mel_chunk - self.mel_mean) / self.mel_std
        x_mask = torch.zeros(bs, mel_chunk.shape[1], device=self.device).bool()
        self.condition_cache = self.ode_wrapper.set_conditions(x_mask=x_mask, x_cond=semantic_tokens_chunk, start_position_id=start_position_id, cache=self.condition_cache)
        # One forward pass near t=1 populates the KV cache with the prompt.
        x_t = torch.Tensor([0.999]).to(self.device)
        self.ode_wrapper(x_t, mel_chunk)
    @classmethod
    def from_pretrained(cls, model_config, ckpt_path, device, max_prompt_chunk=2, max_kv_cache_tokens=900, use_cfg=True, use_cfg_rescale=True, cfg_init=1.5, cfg_scale=7.5, cfg_schedule="linear"):
        """Build a wrapper from a YAML model config and a checkpoint file."""
        # open yaml file
        with open(model_config, 'r') as f:
            config = yaml.safe_load(f)
        model_config = config["model"]["dit"]
        dit = DiTPrefix(
            input_size=model_config["input_size"],
            semantic_vocab_size=model_config["semantic_vocab_size"] + 1,  # extra slot for the CFG token id
            hidden_size=model_config["hidden_size"],
            depth=model_config["depth"],
            num_heads=model_config["num_heads"],
            mlp_ratio=model_config["mlp_ratio"],
            ffn_type=model_config.get("ffn_type", "conv1d_conv1d"),
            ffn_gated_glu=model_config.get("ffn_gated_glu", True),
            ffn_act_layer=model_config.get("ffn_act_layer", "gelu"),
            ffn_conv_kernel_size=model_config.get("ffn_conv_kernel_size", 5),
            use_rope=model_config.get("use_rope", False),
            rope_params=model_config.get("rope_params", { "max_position_embeddings": 4096,"rope_base": 10000,"rope_interpolation_factor": 1 }),
            position_embedding_type=model_config["position_embedding_type"],
            max_seq_len=model_config["max_seq_len"],
            output_size=model_config["input_size"],
            prompt_cfg_dropout=0
        )
        cfg_semantic_token_id = model_config["semantic_vocab_size"]
        # load state_dict (Lightning-style checkpoint; strip the prefix)
        state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)["state_dict"]
        speech_model_params = {k.replace("speech_model.", ""): v for k, v in state_dict.items() if "speech_model" in k}
        dit.load_state_dict(speech_model_params, strict=True)
        logger.info(f">>> Loaded checkpoint from {ckpt_path}")
        return cls(speech_model=dit, device=device, normalize_mel=config["normalize_mel"], mel_mean=config["mel_mean"], mel_std=config["mel_std"], max_prompt_chunk=max_prompt_chunk, max_kv_cache_tokens=max_kv_cache_tokens,
                   use_cfg=use_cfg, use_cfg_rescale=use_cfg_rescale, cfg_init=cfg_init, cfg_scale=cfg_scale, cfg_schedule=cfg_schedule, cfg_token_id=cfg_semantic_token_id)
================================================
FILE: modules/audio_detokenizer/vocoder/activations.py
================================================
import torch
from torch import nn, sin, pow
from torch.nn import Parameter
class Snake(nn.Module):
    """
    Sine-based periodic activation function (Snake).

    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable per-channel parameter
    References:
        - Liu Ziyin, Tilman Hartwig, Masahito Ueda:
          https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = Snake(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    """

    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
    ):
        """
        Args:
            in_features: number of input channels.
            alpha: initial frequency parameter; higher values = higher frequency.
                Ignored when alpha_logscale is True (log-alphas start at zero,
                i.e. an effective alpha of exp(0) == 1).
            alpha_trainable: whether alpha receives gradients.
            alpha_logscale: store alpha in log space (exp applied in forward).
        """
        super(Snake, self).__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:
            # Log-scale alphas start at zero.
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
        else:
            # Linear-scale alphas start at `alpha`.
            self.alpha = Parameter(torch.ones(in_features) * alpha)
        self.alpha.requires_grad = alpha_trainable
        # Guards the division when alpha underflows to zero.
        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        """Apply snake(x) = x + (1/alpha) * sin^2(alpha * x) elementwise."""
        # Reshape per-channel alpha to broadcast against [B, C, T].
        a = self.alpha.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            a = torch.exp(a)
        return x + (1.0 / (a + self.no_div_by_zero)) * pow(sin(x * a), 2)
class SnakeBeta(nn.Module):
    """
    Modified Snake activation with separate parameters for the magnitude of the
    periodic component.

    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter that controls frequency
        - beta - trainable parameter that controls magnitude
    References:
        - Modified from Liu Ziyin, Tilman Hartwig, Masahito Ueda:
          https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = SnakeBeta(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    """

    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
    ):
        """
        Args:
            in_features: number of input channels.
            alpha: initial value for both alpha (frequency) and beta (magnitude);
                higher values = higher frequency / magnitude. Ignored when
                alpha_logscale is True (log parameters start at zero).
            alpha_trainable: whether alpha and beta receive gradients.
            alpha_logscale: store both parameters in log space (exp in forward).
        """
        super(SnakeBeta, self).__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:
            # Log-scale parameters start at zero.
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
            self.beta = Parameter(torch.zeros(in_features) * alpha)
        else:
            # Linear-scale parameters start at `alpha`.
            self.alpha = Parameter(torch.ones(in_features) * alpha)
            self.beta = Parameter(torch.ones(in_features) * alpha)
        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable
        # Guards the division when beta underflows to zero.
        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        """Apply snakebeta(x) = x + (1/beta) * sin^2(alpha * x) elementwise."""
        # Reshape per-channel parameters to broadcast against [B, C, T].
        a = self.alpha.unsqueeze(0).unsqueeze(-1)
        b = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            a = torch.exp(a)
            b = torch.exp(b)
        return x + (1.0 / (b + self.no_div_by_zero)) * pow(sin(x * a), 2)
================================================
FILE: modules/audio_detokenizer/vocoder/alias_free_activation/__init__.py
================================================
================================================
FILE: modules/audio_detokenizer/vocoder/alias_free_activation/cuda/__init__.py
================================================
================================================
FILE: modules/audio_detokenizer/vocoder/alias_free_activation/cuda/activation1d.py
================================================
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.
import torch
import torch.nn as nn
from ..torch.resample import UpSample1d, DownSample1d
# load fused CUDA kernel: this enables importing anti_alias_activation_cuda
from modules.audio_detokenizer.vocoder.alias_free_activation.cuda import load
# JIT-compiles and imports the extension at module import time; requires a
# working CUDA toolchain (nvcc) — see cuda/load.py.
anti_alias_activation_cuda = load.load()
class FusedAntiAliasActivation(torch.autograd.Function):
    """
    Fused upsample -> snake activation -> downsample via a custom CUDA kernel.

    Assumes filter size 12, replication padding on upsampling/downsampling, and
    logscale alpha/beta parameters as inputs. The hyperparameters are hard-coded
    in the kernel to maximize speed.
    NOTE: The fused kernel is incorrect for Activation1d with different
    hyperparameters.
    """

    @staticmethod
    def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta):
        # Inference-only fast path: nothing is saved to ctx for backward.
        activation_results = anti_alias_activation_cuda.forward(
            inputs, up_ftr, down_ftr, alpha, beta
        )
        return activation_results

    @staticmethod
    def backward(ctx, output_grads):
        # No gradient implementation exists for the fused kernel; training must
        # use the non-fused torch path. (Fix: removed the unreachable `return`
        # statement that followed this raise.)
        raise NotImplementedError
class Activation1d(nn.Module):
    """Anti-aliased activation: upsample -> activation -> downsample, with an
    optional fused CUDA fast path."""

    def __init__(
        self,
        activation,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
        fused: bool = True,
    ):
        super().__init__()
        self.up_ratio = up_ratio
        self.down_ratio = down_ratio
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)
        # Whether to use the fused CUDA kernel instead of the torch path.
        self.fused = fused

    def forward(self, x):
        if not self.fused:
            # Reference (non-fused) torch path.
            return self.downsample(self.act(self.upsample(x)))
        # Snake shares a single parameter for alpha and beta; SnakeBeta has a
        # separate beta.
        is_snake = self.act.__class__.__name__ == "Snake"
        alpha = self.act.alpha.data
        beta = self.act.alpha.data if is_snake else self.act.beta.data
        if not self.act.alpha_logscale:
            # The kernel applies exp() internally; cancel it out with a log for
            # linear-scale parameters.
            alpha = torch.log(alpha)
            beta = torch.log(beta)
        return FusedAntiAliasActivation.apply(
            x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta
        )
================================================
FILE: modules/audio_detokenizer/vocoder/alias_free_activation/cuda/anti_alias_activation.cpp
================================================
/* coding=utf-8
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/extension.h>
// Forward declaration of the CUDA implementation (defined in
// anti_alias_activation_cuda.cu).
extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta);
// Python binding: exposes the fused anti-alias activation forward pass only
// (there is no backward implementation).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &fwd_cuda, "Anti-Alias Activation forward (CUDA)");
}
================================================
FILE: modules/audio_detokenizer/vocoder/alias_free_activation/cuda/anti_alias_activation_cuda.cu
================================================
/* coding=utf-8
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "type_shim.h"
#include <assert.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <c10/macros/Macros.h>
namespace
{
// Hard-coded hyperparameters; these must stay in sync with the torch
// implementation in alias_free_activation/torch (filter size, paddings).
constexpr int ELEMENTS_PER_LDG_STG = 1; //(WARP_ITERATIONS < 4) ? 1 : 4;
constexpr int BUFFER_SIZE = 32;
constexpr int FILTER_SIZE = 12;
constexpr int HALF_FILTER_SIZE = 6;
constexpr int UPSAMPLE_REPLICATION_PAD = 5; // 5 on each side, matching torch impl
constexpr int DOWNSAMPLE_REPLICATION_PAD_LEFT = 5; // matching torch impl
constexpr int DOWNSAMPLE_REPLICATION_PAD_RIGHT = 6; // matching torch impl
// Fused 2x upsample (FIR) -> snake activation -> 2x downsample over a
// [batch, channel, seq] tensor. Each thread produces BUFFER_SIZE output
// samples; alpha/beta are per-channel snake parameters to which exp() is
// applied below (callers pass them in log scale).
template <typename input_t, typename output_t, typename acc_t>
__global__ void anti_alias_activation_forward(
output_t *dst,
const input_t *src,
const input_t *up_ftr,
const input_t *down_ftr,
const input_t *alpha,
const input_t *beta,
int batch_size,
int channels,
int seq_len)
{
// Up and downsample filters (copied to registers/local memory per thread).
input_t up_filter[FILTER_SIZE];
input_t down_filter[FILTER_SIZE];
// Load data from global memory including extra indices reserved for replication paddings
input_t elements[2 * FILTER_SIZE + 2 * BUFFER_SIZE + 2 * UPSAMPLE_REPLICATION_PAD] = {0};
input_t intermediates[2 * FILTER_SIZE + 2 * BUFFER_SIZE + DOWNSAMPLE_REPLICATION_PAD_LEFT + DOWNSAMPLE_REPLICATION_PAD_RIGHT] = {0};
// Output stores downsampled output before writing to dst
output_t output[BUFFER_SIZE];
// blockDim/threadIdx = (128, 1, 1)
// gridDim/blockIdx = (seq_blocks, channels, batches)
int block_offset = (blockIdx.x * 128 * BUFFER_SIZE + seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
int local_offset = threadIdx.x * BUFFER_SIZE;
int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset;
// intermediates have double the seq_len (2x upsampled signal)
int intermediate_local_offset = threadIdx.x * BUFFER_SIZE * 2;
int intermediate_seq_offset = blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_local_offset;
// Get values needed for replication padding before moving pointer
const input_t *right_most_pntr = src + (seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
input_t seq_left_most_value = right_most_pntr[0];
input_t seq_right_most_value = right_most_pntr[seq_len - 1];
// Move src and dst pointers to this thread's segment
src += block_offset + local_offset;
dst += block_offset + local_offset;
// Alpha and beta values for snake activations. Applies exp by default
// (parameters arrive in log scale; see the Python-side Activation1d).
alpha = alpha + blockIdx.y;
input_t alpha_val = expf(alpha[0]);
beta = beta + blockIdx.y;
input_t beta_val = expf(beta[0]);
#pragma unroll
for (int it = 0; it < FILTER_SIZE; it += 1)
{
up_filter[it] = up_ftr[it];
down_filter[it] = down_ftr[it];
}
// Apply replication padding for upsampling, matching torch impl.
// Samples are written to even slots only (zero-stuffed 2x upsampling); the
// factor 2 compensates for the energy lost to the inserted zeros.
#pragma unroll
for (int it = -HALF_FILTER_SIZE; it < BUFFER_SIZE + HALF_FILTER_SIZE; it += 1)
{
int element_index = seq_offset + it; // index for element
if ((element_index < 0) && (element_index >= -UPSAMPLE_REPLICATION_PAD))
{
elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_left_most_value;
}
if ((element_index >= seq_len) && (element_index < seq_len + UPSAMPLE_REPLICATION_PAD))
{
elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_right_most_value;
}
if ((element_index >= 0) && (element_index < seq_len))
{
elements[2 * (HALF_FILTER_SIZE + it)] = 2 * src[it];
}
}
// Apply upsampling strided convolution and write to intermediates. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT for replication padding of the downsampling conv later
#pragma unroll
for (int it = 0; it < (2 * BUFFER_SIZE + 2 * FILTER_SIZE); it += 1)
{
input_t acc = 0.0;
int element_index = intermediate_seq_offset + it; // index for intermediate
#pragma unroll
for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
{
if ((element_index + f_idx) >= 0)
{
acc += up_filter[f_idx] * elements[it + f_idx];
}
}
intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] = acc;
}
// Apply activation function: snake(x) = x + (1/beta) * sin^2(alpha * x).
// It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT and DOWNSAMPLE_REPLICATION_PAD_RIGHT for replication padding of the downsampling conv later
double no_div_by_zero = 0.000000001;
#pragma unroll
for (int it = 0; it < 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it += 1)
{
intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] += (1.0 / (beta_val + no_div_by_zero)) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val);
}
// Apply replication padding before downsampling conv from intermediates
#pragma unroll
for (int it = 0; it < DOWNSAMPLE_REPLICATION_PAD_LEFT; it += 1)
{
intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT];
}
#pragma unroll
for (int it = DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it < DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE + DOWNSAMPLE_REPLICATION_PAD_RIGHT; it += 1)
{
intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE - 1];
}
// Apply downsample strided convolution (assuming stride=2) from intermediates
#pragma unroll
for (int it = 0; it < BUFFER_SIZE; it += 1)
{
input_t acc = 0.0;
#pragma unroll
for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
{
// Add constant DOWNSAMPLE_REPLICATION_PAD_RIGHT to match torch implementation
acc += down_filter[f_idx] * intermediates[it * 2 + f_idx + DOWNSAMPLE_REPLICATION_PAD_RIGHT];
}
output[it] = acc;
}
// Write output to dst, discarding the tail of the last partial buffer
#pragma unroll
for (int it = 0; it < BUFFER_SIZE; it += ELEMENTS_PER_LDG_STG)
{
int element_index = seq_offset + it;
if (element_index < seq_len)
{
dst[it] = output[it];
}
}
}
// Launch helper: grid is (ceil(seq_len / 4096), channels, batch_size) with 128
// threads per block, each thread covering BUFFER_SIZE output samples.
template <typename input_t, typename output_t, typename acc_t>
void dispatch_anti_alias_activation_forward(
output_t *dst,
const input_t *src,
const input_t *up_ftr,
const input_t *down_ftr,
const input_t *alpha,
const input_t *beta,
int batch_size,
int channels,
int seq_len)
{
if (seq_len == 0)
{
return;
}
else
{
// Use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
constexpr int seq_len_per_block = 4096;
int blocks_per_seq_len = (seq_len + seq_len_per_block - 1) / seq_len_per_block;
dim3 blocks(blocks_per_seq_len, channels, batch_size);
dim3 threads(threads_per_block, 1, 1);
anti_alias_activation_forward<input_t, output_t, acc_t>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, up_ftr, down_ftr, alpha, beta, batch_size, channels, seq_len);
}
}
}
// Extension entry point (bound to Python in anti_alias_activation.cpp).
// NOTE(review): performs no validation — appears to assume a contiguous
// 3-D CUDA tensor; confirm callers guarantee this.
extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta)
{
// Input is a 3d tensor with dimensions [batches, channels, seq_len]
const int batches = input.size(0);
const int channels = input.size(1);
const int seq_len = input.size(2);
// Output buffer, same shape/dtype as the input; no autograd tracking.
auto act_options = input.options().requires_grad(false);
torch::Tensor anti_alias_activation_results =
torch::empty({batches, channels, seq_len}, act_options);
void *input_ptr = static_cast<void *>(input.data_ptr());
void *up_filter_ptr = static_cast<void *>(up_filter.data_ptr());
void *down_filter_ptr = static_cast<void *>(down_filter.data_ptr());
void *alpha_ptr = static_cast<void *>(alpha.data_ptr());
void *beta_ptr = static_cast<void *>(beta.data_ptr());
void *anti_alias_activation_results_ptr = static_cast<void *>(anti_alias_activation_results.data_ptr());
// Dispatch on the element dtype (float/half/bfloat16); the accumulation
// template parameter is float.
DISPATCH_FLOAT_HALF_AND_BFLOAT(
input.scalar_type(),
"dispatch anti alias activation_forward",
dispatch_anti_alias_activation_forward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t *>(anti_alias_activation_results_ptr),
reinterpret_cast<const scalar_t *>(input_ptr),
reinterpret_cast<const scalar_t *>(up_filter_ptr),
reinterpret_cast<const scalar_t *>(down_filter_ptr),
reinterpret_cast<const scalar_t *>(alpha_ptr),
reinterpret_cast<const scalar_t *>(beta_ptr),
batches,
channels,
seq_len););
return anti_alias_activation_results;
}
================================================
FILE: modules/audio_detokenizer/vocoder/alias_free_activation/cuda/compat.h
================================================
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*This code is copied from NVIDIA apex:
* https://github.com/NVIDIA/apex
* with minor changes. */
// Fall back to the pre-1.0 AT_CHECK macro when TORCH_CHECK is unavailable.
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif
// Tensor::data() was renamed to data_ptr() in PyTorch >= 1.3.
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
================================================
FILE: modules/audio_detokenizer/vocoder/alias_free_activation/cuda/load.py
================================================
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.
import os
import pathlib
import subprocess
from torch.utils import cpp_extension
"""
Setting this param to a list has a problem of generating different compilation commands (with different order of architectures) and leading to recompilation of fused kernels.
Set it to an empty string to avoid recompilation and assign arch flags explicitly in extra_cuda_cflags below.
"""
os.environ["TORCH_CUDA_ARCH_LIST"] = ""
def load():
    """Compile (if needed) and import the fused anti-alias activation CUDA extension."""
    # Target compute capability 8.0 when CUDA >= 11 is installed.
    cc_flag = []
    _, cuda_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
    if int(cuda_major) >= 11:
        cc_flag += ["-gencode", "arch=compute_80,code=sm_80"]

    # Build next to the sources so the compiled artifact is reused across runs.
    src_dir = pathlib.Path(__file__).parent.absolute()
    build_dir = src_dir / "build"
    _create_build_dir(build_dir)

    extra_cuda_flags = [
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
    ]
    sources = [
        src_dir / "anti_alias_activation.cpp",
        src_dir / "anti_alias_activation_cuda.cu",
    ]
    return cpp_extension.load(
        name="anti_alias_activation_cuda",
        sources=sources,
        build_directory=build_dir,
        extra_cflags=[
            "-O3",
        ],
        extra_cuda_cflags=[
            "-O3",
            "-gencode",
            "arch=compute_70,code=sm_70",
            "--use_fast_math",
        ]
        + extra_cuda_flags
        + cc_flag,
        verbose=True,
    )
def _get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output(
[cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
)
output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
return raw_output, bare_metal_major, bare_metal_minor
def _create_build_dir(buildpath):
try:
os.mkdir(buildpath)
except OSError:
if not os.path.isdir(buildpath):
print(f"Creation of the build directory {buildpath} failed")
================================================
FILE: modules/audio_detokenizer/vocoder/alias_free_activation/cuda/type_shim.h
================================================
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include "compat.h"
// Expand __VA_ARGS__ with `using scalar_t = <element type>` bound from the
// runtime dtype TYPE. Supports float32 / float16 / bfloat16; any other dtype
// raises via AT_ERROR (NAME only labels the error message).
#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...) \
switch (TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
// Two-type variant: binds scalar_t_in from TYPEIN and scalar_t_out from
// TYPEOUT. Only float32 input allows a distinct output type; half/bfloat16
// inputs force the same output type.
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch (TYPEIN) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_in = float; \
switch (TYPEOUT) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_in = at::Half; \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_in = at::BFloat16; \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
}
================================================
FILE: modules/audio_detokenizer/vocoder/alias_free_activation/torch/__init__.py
================================================
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.
from .filter import *
from .resample import *
from .act import *
================================================
FILE: modules/audio_detokenizer/vocoder/alias_free_activation/torch/act.py
================================================
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.
import torch.nn as nn
from .resample import UpSample1d, DownSample1d
class Activation1d(nn.Module):
    """Anti-aliased activation: upsample -> pointwise activation -> downsample."""

    def __init__(
        self,
        activation,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
    ):
        super().__init__()
        self.up_ratio = up_ratio
        self.down_ratio = down_ratio
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    def forward(self, x):
        """x: [B, C, T] -> [B, C, T]."""
        return self.downsample(self.act(self.upsample(x)))
================================================
FILE: modules/audio_detokenizer/vocoder/alias_free_activation/torch/filter.py
================================================
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
if hasattr(torch, "sinc"):
    # torch >= 1.8 ships a native (normalized) sinc.
    sinc = torch.sinc
else:
    # This code is adopted from adefossez's julius.core.sinc under the MIT License
    # https://adefossez.github.io/julius/julius/core.html
    # LICENSE is in incl_licenses directory.
    def sinc(x: torch.Tensor):
        """
        Implementation of sinc, i.e. sin(pi * x) / (pi * x)
        __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
        """
        one = torch.tensor(1.0, device=x.device, dtype=x.dtype)
        return torch.where(x == 0, one, torch.sin(math.pi * x) / math.pi / x)
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html
# LICENSE is in incl_licenses directory.
def kaiser_sinc_filter1d(
    cutoff, half_width, kernel_size
):
    """Build a Kaiser-windowed sinc low-pass FIR filter.

    Returns a tensor of shape [1, 1, kernel_size], normalized to unit sum.
    cutoff is in cycles/sample (0.5 == Nyquist); half_width sets the
    transition band.
    """
    even = kernel_size % 2 == 0
    half_size = kernel_size // 2

    # Kaiser window beta from the desired stopband attenuation A
    # (delta_f = transition bandwidth).
    delta_f = 4 * half_width
    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if A > 50.0:
        beta = 0.1102 * (A - 8.7)
    elif A >= 21.0:
        beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
    else:
        beta = 0.0
    win = torch.kaiser_window(kernel_size, beta=beta, periodic=False)

    # Sample instants; even-length kernels are offset by half a sample.
    # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
    t = (torch.arange(-half_size, half_size) + 0.5) if even else (torch.arange(kernel_size) - half_size)

    if cutoff == 0:
        taps = torch.zeros_like(t)
    else:
        taps = 2 * cutoff * win * sinc(2 * cutoff * t)
    # Normalize to sum = 1, otherwise a small amount of the constant (DC)
    # component of the input leaks through.
    taps /= taps.sum()
    return taps.view(1, 1, kernel_size)
class LowPassFilter1d(nn.Module):
    """Depthwise low-pass FIR filter with optional replication padding and stride."""

    def __init__(
        self,
        cutoff=0.5,
        half_width=0.6,
        stride: int = 1,
        padding: bool = True,
        padding_mode: str = "replicate",
        kernel_size: int = 12,
    ):
        """
        kernel_size should be even number for stylegan3 setup, in this implementation, odd number is also possible.
        """
        super().__init__()
        if cutoff < -0.0:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if cutoff > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")
        self.kernel_size = kernel_size
        self.even = kernel_size % 2 == 0
        # Asymmetric padding keeps the output aligned for even-length kernels.
        self.pad_left = kernel_size // 2 - int(self.even)
        self.pad_right = kernel_size // 2
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        self.register_buffer(
            "filter", kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
        )

    def forward(self, x):
        """x: [B, C, T] -> filtered (and strided) output, one filter per channel."""
        channels = x.shape[1]
        if self.padding:
            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
        kernel = self.filter.expand(channels, -1, -1)
        return F.conv1d(x, kernel, stride=self.stride, groups=channels)
================================================
FILE: modules/audio_detokenizer/vocoder/alias_free_activation/torch/resample.py
================================================
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.
import torch.nn as nn
from torch.nn import functional as F
from .filter import LowPassFilter1d
from .filter import kaiser_sinc_filter1d
class UpSample1d(nn.Module):
    """Band-limited 1D upsampling by an integer ratio (transposed depthwise FIR)."""

    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        # Default kernel: 6 taps per unit ratio, rounded down to an even length.
        self.kernel_size = (
            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        )
        self.stride = ratio
        self.pad = self.kernel_size // ratio - 1
        # Crop amounts that undo the transposed-conv transient at both ends.
        self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
        self.pad_right = (
            self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
        )
        self.register_buffer(
            "filter",
            kaiser_sinc_filter1d(
                cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
            ),
        )

    def forward(self, x):
        """x: [B, C, T] -> [B, C, T * ratio]."""
        channels = x.shape[1]
        x = F.pad(x, (self.pad, self.pad), mode="replicate")
        # Multiply by ratio to restore the energy spread over inserted zeros.
        x = self.ratio * F.conv_transpose1d(
            x, self.filter.expand(channels, -1, -1), stride=self.stride, groups=channels
        )
        return x[..., self.pad_left : -self.pad_right]
class DownSample1d(nn.Module):
    """Band-limited 1D downsampling by an integer ratio (strided low-pass FIR)."""

    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        # Default kernel: 6 taps per unit ratio, rounded down to an even length.
        self.kernel_size = (
            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        )
        self.lowpass = LowPassFilter1d(
            cutoff=0.5 / ratio,
            half_width=0.6 / ratio,
            stride=ratio,
            kernel_size=self.kernel_size,
        )

    def forward(self, x):
        """x: [B, C, T] -> anti-aliased, ratio-strided output."""
        return self.lowpass(x)
================================================
FILE: modules/audio_detokenizer/vocoder/bigvgan.py
================================================
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.
import os
import json
from pathlib import Path
from typing import Optional, Union, Dict
import torch
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import weight_norm, remove_weight_norm
from modules.audio_detokenizer.vocoder.activations import Snake, SnakeBeta
from modules.audio_detokenizer.vocoder.utils import init_weights, get_padding
from modules.audio_detokenizer.vocoder.alias_free_activation.torch.act import Activation1d as TorchActivation1d
from modules.audio_detokenizer.vocoder.utils import AttrDict
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
def load_hparams_from_json(path) -> AttrDict:
    """Load a BigVGAN hyperparameter JSON file into an attribute-access dict."""
    with open(path) as f:
        return AttrDict(json.load(f))
class AMPBlock1(torch.nn.Module):
    """
    AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
    AMPBlock1 has additional self.convs2 that contains additional Conv1d layers with a fixed dilation=1 followed by each layer in self.convs1
    Args:
        h (AttrDict): Hyperparameters.
        channels (int): Number of convolution channels.
        kernel_size (int): Size of the convolution kernel. Default is 3.
        dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
        activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
    """

    def __init__(
        self,
        h: AttrDict,
        channels: int,
        kernel_size: int = 3,
        dilation: tuple = (1, 3, 5),
        activation: str = None,
    ):
        super().__init__()
        self.h = h
        # Dilated convolutions, one per dilation rate. Padding comes from
        # get_padding — presumably "same" padding so T is preserved; confirm
        # in vocoder/utils.py.
        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        stride=1,
                        dilation=d,
                        padding=get_padding(kernel_size, d),
                    )
                )
                for d in dilation
            ]
        )
        self.convs1.apply(init_weights)
        # Companion convolutions with fixed dilation=1, one per convs1 layer.
        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        stride=1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                )
                for _ in range(len(dilation))
            ]
        )
        self.convs2.apply(init_weights)
        self.num_layers = len(self.convs1) + len(
            self.convs2
        )  # Total number of conv layers
        # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
        if self.h.get("use_cuda_kernel", False):
            from modules.audio_detokenizer.vocoder.alias_free_activation.cuda.activation1d import (
                Activation1d as CudaActivation1d,
            )

            Activation1d = CudaActivation1d
        else:
            Activation1d = TorchActivation1d
        # One anti-aliased activation wrapper per conv layer.
        if activation == "snake":
            self.activations = nn.ModuleList(
                [
                    Activation1d(
                        activation=Snake(
                            channels, alpha_logscale=h.snake_logscale
                        )
                    )
                    for _ in range(self.num_layers)
                ]
            )
        elif activation == "snakebeta":
            self.activations = nn.ModuleList(
                [
                    Activation1d(
                        activation=SnakeBeta(
                            channels, alpha_logscale=h.snake_logscale
                        )
                    )
                    for _ in range(self.num_layers)
                ]
            )
        else:
            raise NotImplementedError(
                "activation incorrectly specified. check the config file and look for 'activation'."
            )

    def forward(self, x):
        """Residual pass: per dilation, x -> act -> dilated conv -> act -> conv(dilation=1) -> + x."""
        # Even-indexed activations pair with convs1, odd-indexed with convs2.
        acts1, acts2 = self.activations[::2], self.activations[1::2]
        for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
            xt = a1(x)
            xt = c1(xt)
            xt = a2(xt)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        """Strip the weight_norm reparameterization from every convolution."""
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)
class AMPBlock2(torch.nn.Module):
    """
    AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
    Unlike AMPBlock1, AMPBlock2 does not contain extra Conv1d layers with fixed dilation=1
    Args:
        h (AttrDict): Hyperparameters.
        channels (int): Number of convolution channels.
        kernel_size (int): Size of the convolution kernel. Default is 3.
        dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
        activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
    """
    def __init__(
        self,
        h: AttrDict,
        channels: int,
        kernel_size: int = 3,
        dilation: tuple = (1, 3, 5),
        activation: str = None,
    ):
        super().__init__()
        self.h = h
        # One dilated conv per dilation rate; each gets its own residual connection.
        self.convs = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        stride=1,
                        dilation=d,
                        padding=get_padding(kernel_size, d),
                    )
                )
                for d in dilation
            ]
        )
        self.convs.apply(init_weights)
        self.num_layers = len(self.convs)  # Total number of conv layers
        # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
        if self.h.get("use_cuda_kernel", False):
            from modules.audio_detokenizer.vocoder.alias_free_activation.cuda.activation1d import (
                Activation1d as CudaActivation1d,
            )
            Activation1d = CudaActivation1d
        else:
            Activation1d = TorchActivation1d
        # Activation functions: one anti-aliased activation per conv layer.
        if activation == "snake":
            self.activations = nn.ModuleList(
                [
                    Activation1d(
                        activation=Snake(
                            channels, alpha_logscale=h.snake_logscale
                        )
                    )
                    for _ in range(self.num_layers)
                ]
            )
        elif activation == "snakebeta":
            self.activations = nn.ModuleList(
                [
                    Activation1d(
                        activation=SnakeBeta(
                            channels, alpha_logscale=h.snake_logscale
                        )
                    )
                    for _ in range(self.num_layers)
                ]
            )
        else:
            raise NotImplementedError(
                "activation incorrectly specified. check the config file and look for 'activation'."
            )

    def forward(self, x):
        """Apply [activation -> dilated conv] layers, each with a residual skip.

        Bug fix: the original implementation fell off the end of the loop without
        `return x`, so forward() returned None (the NVIDIA BigVGAN reference
        returns the accumulated x).
        """
        for c, a in zip(self.convs, self.activations):
            xt = a(x)
            xt = c(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        """Strip weight normalization from every conv (typically before inference export)."""
        for l in self.convs:
            remove_weight_norm(l)
class BigVGAN(
    torch.nn.Module,
    PyTorchModelHubMixin,
    library_name="bigvgan",
    repo_url="https://github.com/NVIDIA/BigVGAN",
    docs_url="https://github.com/NVIDIA/BigVGAN/blob/main/README.md",
    pipeline_tag="audio-to-audio",
    license="mit",
    tags=["neural-vocoder", "audio-generation", "arxiv:2206.04658"],
):
    """
    BigVGAN is a neural vocoder model that applies anti-aliased periodic activation for residual blocks (resblocks).
    New in BigVGAN-v2: it can optionally use optimized CUDA kernels for AMP (anti-aliased multi-periodicity) blocks.
    Args:
        h (AttrDict): Hyperparameters.
        use_cuda_kernel (bool): If set to True, loads optimized CUDA kernels for AMP. This should be used for inference only, as training is not supported with CUDA kernels.
    Note:
        - The `use_cuda_kernel` parameter should be used for inference only, as training with CUDA kernels is not supported.
        - Ensure that the activation function is correctly specified in the hyperparameters (h.activation).
    """
    def __init__(self, h: AttrDict, use_cuda_kernel: bool = False):
        super().__init__()
        self.h = h
        # Store the flag inside h so the AMPBlocks (which only receive h) see it too.
        self.h["use_cuda_kernel"] = use_cuda_kernel
        # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
        if self.h.get("use_cuda_kernel", False):
            from modules.audio_detokenizer.vocoder.alias_free_activation.cuda.activation1d import (
                Activation1d as CudaActivation1d,
            )
            Activation1d = CudaActivation1d
        else:
            Activation1d = TorchActivation1d
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)
        # Pre-conv: project mel channels up to the initial upsample width.
        self.conv_pre = weight_norm(
            Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
        )
        # Define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
        if h.resblock == "1":
            resblock_class = AMPBlock1
        elif h.resblock == "2":
            resblock_class = AMPBlock2
        else:
            raise ValueError(
                f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}"
            )
        # Transposed conv-based upsamplers. does not apply anti-aliasing
        # Channel count halves at each upsampling stage.
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            self.ups.append(
                nn.ModuleList(
                    [
                        weight_norm(
                            ConvTranspose1d(
                                h.upsample_initial_channel // (2**i),
                                h.upsample_initial_channel // (2 ** (i + 1)),
                                k,
                                u,
                                padding=(k - u) // 2,
                            )
                        )
                    ]
                )
            )
        # Residual blocks using anti-aliased multi-periodicity composition modules (AMP)
        # num_kernels parallel resblocks per upsampling stage; forward() averages them.
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(
                zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
            ):
                self.resblocks.append(
                    resblock_class(h, ch, k, d, activation=h.activation)
                )
        # Post-conv. NOTE: `ch` carries over from the loop above, i.e. the
        # channel count of the last upsampling stage.
        activation_post = (
            Snake(ch, alpha_logscale=h.snake_logscale)
            if h.activation == "snake"
            else (
                SnakeBeta(ch, alpha_logscale=h.snake_logscale)
                if h.activation == "snakebeta"
                else None
            )
        )
        if activation_post is None:
            raise NotImplementedError(
                "activation incorrectly specified. check the config file and look for 'activation'."
            )
        self.activation_post = Activation1d(activation=activation_post)
        # Whether to use bias for the final conv_post. Default to True for backward compatibility
        self.use_bias_at_final = h.get("use_bias_at_final", True)
        self.conv_post = weight_norm(
            Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final)
        )
        # Weight initialization
        for i in range(len(self.ups)):
            self.ups[i].apply(init_weights)
        self.conv_post.apply(init_weights)
        # Final tanh activation. Defaults to True for backward compatibility
        self.use_tanh_at_final = h.get("use_tanh_at_final", True)

    def forward(self, x):
        """Synthesize a waveform from a mel spectrogram.

        Args:
            x: mel tensor with h.num_mels channels (assumed (B, num_mels, T)
               given conv_pre's in_channels — TODO confirm with callers).
        Returns:
            Waveform tensor bounded to [-1, 1].
        """
        # Pre-conv
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            # Upsampling
            for i_up in range(len(self.ups[i])):
                x = self.ups[i][i_up](x)
            # AMP blocks: average the num_kernels parallel resblock outputs.
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        # Post-conv
        x = self.activation_post(x)
        x = self.conv_post(x)
        # Final tanh activation
        if self.use_tanh_at_final:
            x = torch.tanh(x)
        else:
            x = torch.clamp(x, min=-1.0, max=1.0)  # Bound the output to [-1, 1]
        return x

    def remove_weight_norm(self):
        """Strip weight normalization from all convs; safe to call twice."""
        try:
            print("Removing weight norm...")
            for l in self.ups:
                for l_i in l:
                    remove_weight_norm(l_i)
            for l in self.resblocks:
                l.remove_weight_norm()
            remove_weight_norm(self.conv_pre)
            remove_weight_norm(self.conv_post)
        except ValueError:
            # remove_weight_norm raises ValueError when the hook is already gone.
            print("[INFO] Model already removed weight norm. Skipping!")
            pass

    # Additional methods for huggingface_hub support
    def _save_pretrained(self, save_directory: Path) -> None:
        """Save weights and config.json from a Pytorch model to a local directory."""
        model_path = save_directory / "bigvgan_generator.pt"
        torch.save({"generator": self.state_dict()}, model_path)
        config_path = save_directory / "config.json"
        with open(config_path, "w") as config_file:
            json.dump(self.h, config_file, indent=4)

    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str,
        revision: str,
        cache_dir: str,
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: bool,
        local_files_only: bool,
        token: Union[str, bool, None],
        map_location: str = "cpu",  # Additional argument
        strict: bool = False,  # Additional argument
        use_cuda_kernel: bool = False,
        **model_kwargs,
    ):
        """Load Pytorch pretrained weights and return the loaded model."""
        # Download and load hyperparameters (h) used by BigVGAN
        # model_id can be either a local directory or a HF Hub repo id.
        if os.path.isdir(model_id):
            print("Loading config.json from local directory")
            config_file = os.path.join(model_id, "config.json")
        else:
            config_file = hf_hub_download(
                repo_id=model_id,
                filename="config.json",
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )
        h = load_hparams_from_json(config_file)
        # instantiate BigVGAN using h
        if use_cuda_kernel:
            print(
                f"[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!"
            )
            print(
                f"[WARNING] You need nvcc and ninja installed in your system that matches your PyTorch build is using to build the kernel. If not, the model will fail to initialize or generate incorrect waveform!"
            )
            print(
                f"[WARNING] For detail, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis"
            )
        model = cls(h, use_cuda_kernel=use_cuda_kernel)
        # Download and load pretrained generator weight
        if os.path.isdir(model_id):
            print("Loading weights from local directory")
            model_file = os.path.join(model_id, "bigvgan_generator.pt")
        else:
            print(f"Loading weights from {model_id}")
            model_file = hf_hub_download(
                repo_id=model_id,
                filename="bigvgan_generator.pt",
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )
        checkpoint_dict = torch.load(model_file, map_location=map_location, weights_only=True)
        try:
            model.load_state_dict(checkpoint_dict["generator"])
        except RuntimeError:
            # Checkpoints saved after remove_weight_norm() have different
            # parameter names; retry with weight norm stripped.
            print(
                f"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
            )
            model.remove_weight_norm()
            model.load_state_dict(checkpoint_dict["generator"])
        return model
================================================
FILE: modules/audio_detokenizer/vocoder/utils.py
================================================
from librosa.filters import mel as librosa_mel_fn
import torch
import os
# Module-level caches keyed by (fft/mel/device) parameters so repeated
# get_melspec() calls reuse the mel filterbank and Hann window tensors.
mel_basis_cache = {}
hann_window_cache = {}
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Log-compress magnitudes: log(clamp(x, min=clip_val) * C)."""
    clamped = torch.clamp(x, min=clip_val)
    return torch.log(clamped * C)
def spectral_normalize_torch(magnitudes):
    """Apply dynamic-range (log) compression to a magnitude spectrogram."""
    compressed = dynamic_range_compression_torch(magnitudes)
    return compressed
def get_melspec(
    y: torch.Tensor,
    n_fft: int,
    num_mels: int,
    sampling_rate: int,
    hop_size: int,
    win_size: int,
    fmin: int,
    fmax: int = None,
    center: bool = False,
) -> torch.Tensor:
    """
    Calculate the mel spectrogram of an input signal.
    This function uses slaney norm for the librosa mel filterbank (using librosa.filters.mel) and uses Hann window for STFT (using torch.stft).
    Args:
        y (torch.Tensor): Input signal.
        n_fft (int): FFT size.
        num_mels (int): Number of mel bins.
        sampling_rate (int): Sampling rate of the input signal.
        hop_size (int): Hop size for STFT.
        win_size (int): Window size for STFT.
        fmin (int): Minimum frequency for mel filterbank.
        fmax (int): Maximum frequency for mel filterbank. If None, defaults to half the sampling rate (fmax = sr / 2.0) inside librosa_mel_fn
        center (bool): Whether to pad the input to center the frames. Default is False.
    Returns:
        torch.Tensor: Mel spectrogram.
    """
    # Warn (but do not fail) if the waveform is outside the expected [-1, 1] range.
    if torch.min(y) < -1.0:
        print(f"[WARNING] Min value of input waveform signal is {torch.min(y)}")
    if torch.max(y) > 1.0:
        print(f"[WARNING] Max value of input waveform signal is {torch.max(y)}")
    device = y.device
    # Cache the mel filterbank and Hann window per parameter/device combination.
    key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"
    if key not in mel_basis_cache:
        mel = librosa_mel_fn(
            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
        )
        mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)
        hann_window_cache[key] = torch.hann_window(win_size).to(device)
    mel_basis = mel_basis_cache[key]
    hann_window = hann_window_cache[key]
    # Manual reflection padding (center=False below), so frame count matches
    # the usual len(y) // hop_size convention.
    padding = (n_fft - hop_size) // 2
    y = torch.nn.functional.pad(
        y.unsqueeze(1), (padding, padding), mode="reflect"
    ).squeeze(1)
    spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window,
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=True,
    )
    # Magnitude with a small epsilon for numerical stability.
    spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
    mel_spec = torch.matmul(mel_basis, spec)
    mel_spec = spectral_normalize_torch(mel_spec)
    return mel_spec
class AttrDict(dict):
    """A dict whose entries are also accessible as attributes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Aliasing __dict__ to the mapping itself makes d.key and d["key"]
        # read/write the same storage.
        self.__dict__ = self
def load_checkpoint(filepath, device):
    """Load a torch checkpoint dict from `filepath` onto `device`.

    Args:
        filepath: Path to a torch-saved checkpoint file.
        device: Target map_location for torch.load (e.g. 'cpu').
    Returns:
        The deserialized checkpoint dictionary.
    Raises:
        FileNotFoundError: if `filepath` does not exist. (The original used
            `assert`, which is silently stripped under `python -O`.)
    """
    if not os.path.isfile(filepath):
        raise FileNotFoundError(f"Checkpoint not found: '{filepath}'")
    print(f"Loading '{filepath}'")
    # weights_only=True avoids executing arbitrary pickled code.
    checkpoint_dict = torch.load(filepath, map_location=device, weights_only=True)
    print("Complete.")
    return checkpoint_dict
def init_weights(m, mean=0.0, std=0.01):
    """In-place N(mean, std^2) init for modules whose class name contains 'Conv'."""
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
    """'Same'-style padding for a stride-1 dilated 1-D convolution."""
    return (kernel_size - 1) * dilation // 2
================================================
FILE: modules/audio_tokenizer/audio_tokenizer.py
================================================
import torch
import librosa
import yaml
from transformers import Wav2Vec2BertModel, SeamlessM4TFeatureExtractor
import safetensors
import accelerate
import soundfile as sf
import math
from einops import rearrange
from modules.audio_tokenizer.rep_codec import RepCodec
class AudioTokenizer(object):
    """Converts 16 kHz speech into discrete semantic tokens by running
    w2v-BERT 2.0 features through a RepCodec quantizer."""

    def __init__(self, **kwargs):
        # Required kwargs: 'device', 'feat_stats', 'wav2vec_ckpt', 'semantic_codec_ckpt'.
        self.device = kwargs.pop('device')
        print(self.device)
        # tokenize
        # Mean/variance statistics used to whiten the wav2vec hidden states.
        feat_stats = kwargs.pop('feat_stats')
        feat_stats = torch.load(feat_stats, map_location='cpu')
        self.feat_mean = feat_stats['mean']
        self.feat_std = torch.sqrt(feat_stats['var'])
        wav2vec_ckpt = kwargs.pop("wav2vec_ckpt")
        self.semantic_model = Wav2Vec2BertModel.from_pretrained(wav2vec_ckpt)
        self.semantic_model.eval()
        self.semantic_model.to(self.device)
        self.semantic_processor = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
        self.semantic_codec = RepCodec()
        self.semantic_codec.eval()
        pretrained_path = kwargs.pop("semantic_codec_ckpt")
        safetensors.torch.load_model(self.semantic_codec, pretrained_path)
        self.semantic_codec.to(self.device)
        # Max feature frames per segment fed to the wav2vec model; longer
        # inputs are split into chunks of this length below.
        self.max_length = 2048

    @torch.no_grad()
    def tokenize(self, speech):
        # Input:
        # speech: torch tensor, shape[B, N_speech]
        # Output:
        # semantic token: torch tensor, shape[B, N]
        inputs = self.semantic_processor(speech.cpu(), sampling_rate=16000, return_tensors="pt")
        input_features = inputs["input_features"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)
        # Pad the time axis to a multiple of max_length, then fold the segments
        # into the batch dimension so each chunk fits the model.
        seg_num = math.ceil(input_features.shape[1] / self.max_length)
        pad_num = seg_num * self.max_length - input_features.shape[1]
        input_features = torch.nn.functional.pad(input_features, (0, 0, 0, pad_num, 0,0), value=0)
        attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_num, 0, 0), value=0)
        input_features = rearrange(input_features, "b (s n) d -> (b s) n d", s =seg_num)
        attention_mask = rearrange(attention_mask, "b (s n) -> (b s) n", s=seg_num)
        feats = self.semantic_model(
            input_features=input_features,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        # Layer-17 hidden states — presumably chosen to match the codec's
        # training features; verify against the training recipe.
        feat = feats.hidden_states[17]
        # Unfold segments back into the time axis and drop the padding frames.
        feat = rearrange(feat, "(b s) n d -> b (s n) d", s=seg_num)
        feat = feat[:, :feat.shape[1]-pad_num, :]
        # Whiten features with the precomputed statistics before quantization.
        feat = (feat - self.feat_mean.to(feat)) / self.feat_std.to(feat)
        semantic_token, _ = self.semantic_codec.quantize(feat)
        return semantic_token
def get_audio_tokenizer():
    """Build an AudioTokenizer with the repo's default resource paths."""
    config = {
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
        'feat_stats': 'resources/audio_tokenizer/stats.pt',
        'wav2vec_ckpt': 'facebook/w2v-bert-2.0',
        'semantic_codec_ckpt': 'resources/audio_tokenizer/model.safetensors',
    }
    return AudioTokenizer(**config)
================================================
FILE: modules/audio_tokenizer/quantize/__init__.py
================================================
from .vector_quantize import VectorQuantize
from .residual_vq import ResidualVQ
from .factorized_vector_quantize import FactorizedVectorQuantize
================================================
FILE: modules/audio_tokenizer/quantize/factorized_vector_quantize.py
================================================
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm
def WNConv1d(*args, **kwargs):
    """Construct a Conv1d and wrap it with weight normalization."""
    conv = nn.Conv1d(*args, **kwargs)
    return weight_norm(conv)
def WNConvTranspose1d(*args, **kwargs):
    """Construct a ConvTranspose1d and wrap it with weight normalization."""
    deconv = nn.ConvTranspose1d(*args, **kwargs)
    return weight_norm(deconv)
class FactorizedVectorQuantize(nn.Module):
    """Factorized vector quantization: optionally project inputs down to a
    low-dimensional codebook space (when input_dim != codebook_dim), quantize
    by nearest-neighbour lookup, and project back up.

    Args:
        input_dim (int): Dimension of the incoming features.
        codebook_size (int): Number of codebook entries.
        codebook_dim (int): Dimension of each codebook vector.
        commitment (float): Weight for the commitment loss.
        codebook_loss_weight (float): Weight for the codebook loss.
        use_l2_normlize (bool): If True, L2-normalize encodings and codes so
            the euclidean distance equals cosine distance.
    """
    def __init__(
        self,
        input_dim,
        codebook_size,
        codebook_dim,
        commitment=0.005,
        codebook_loss_weight=1.0,
        use_l2_normlize=True,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.commitment = commitment
        self.codebook_loss_weight = codebook_loss_weight
        self.use_l2_normlize = use_l2_normlize
        # Down/up projections only when the codebook lives in a smaller space.
        if self.input_dim != self.codebook_dim:
            self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
            self.out_project = WNConv1d(
                self.codebook_dim, self.input_dim, kernel_size=1
            )
        else:
            self.in_project = nn.Identity()
            self.out_project = nn.Identity()
        self.codebook = nn.Embedding(self.codebook_size, self.codebook_dim)

    def forward(self, z):
        """
        Parameters
        ----------
        z: torch.Tensor[B x D x T]
        Returns
        -------
        z_q: torch.Tensor[B x D x T]
            Quantized continuous representation of input
        commit_loss: Tensor[B]
            Commitment loss to train encoder to predict vectors closer to codebook entries
        codebook_loss: Tensor[B]
            Codebook loss to update the codebook
        indices: torch.Tensor[B x T]
            Codebook indices (quantized discrete representation of input)
        z_e: torch.Tensor[B x D x T]
            Projected latents (continuous representation of input before quantization)
        """
        # Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim
        z_e = self.in_project(z)
        z_q, indices = self.decode_latents(z_e)
        # Compute commitment loss and codebook loss
        if self.training:
            commit_loss = (
                F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
                * self.commitment
            )
            codebook_loss = (
                F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
                * self.codebook_loss_weight
            )
        else:
            commit_loss = torch.zeros(z.shape[0], device=z.device)
            codebook_loss = torch.zeros(z.shape[0], device=z.device)
        # Straight-through estimator: gradients flow to z_e, values come from z_q.
        z_q = z_e + (z_q - z_e).detach()
        z_q = self.out_project(z_q)
        return z_q, commit_loss, codebook_loss, indices, z_e

    def embed_code(self, embed_id):
        """Look up code vectors for indices; shape (..., codebook_dim)."""
        return F.embedding(embed_id, self.codebook.weight)

    def decode_code(self, embed_id):
        """Like embed_code but transposed to the conv layout (B, D, T)."""
        return self.embed_code(embed_id).transpose(1, 2)

    def decode_latents(self, latents):
        """Nearest-neighbour quantize latents (B, D, T) -> (z_q, indices)."""
        encodings = rearrange(latents, "b d t -> (b t) d")
        codebook = self.codebook.weight
        # L2 normalize encodings and codebook
        if self.use_l2_normlize:
            encodings = F.normalize(encodings)
            codebook = F.normalize(codebook)
        # Compute euclidean distance between encodings and codebook,
        # if use_l2_normlize is True, the distance is equal to cosine distance
        dist = (
            encodings.pow(2).sum(1, keepdim=True)
            - 2 * encodings @ codebook.t()
            + codebook.pow(2).sum(1, keepdim=True).t()
        )
        indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
        z_q = self.decode_code(indices)
        return z_q, indices

    def vq2emb(self, vq, out_proj=True):
        """Map code indices back to embeddings, optionally through out_project."""
        emb = self.decode_code(vq)
        if out_proj:
            emb = self.out_project(emb)
        return emb

    def latent2dist(self, latents):
        """Return (similarities, indices, z_q) for latents (B, D, T)."""
        encodings = rearrange(latents, "b d t -> (b t) d")
        codebook = self.codebook.weight
        # L2 normalize encodings and codebook
        if self.use_l2_normlize:
            encodings = F.normalize(encodings)
            codebook = F.normalize(codebook)
        # Compute euclidean distance between encodings and codebook,
        # if use_l2_normlize is True, the distance is equal to cosine distance
        dist = (
            encodings.pow(2).sum(1, keepdim=True)
            - 2 * encodings @ codebook.t()
            + codebook.pow(2).sum(1, keepdim=True).t()
        )  # (b*t, k)
        indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
        dist = rearrange(dist, "(b t) k -> b t k", b=latents.size(0))
        z_q = self.decode_code(indices)
        # Negate so larger = more similar, matching the argmax above.
        return -dist, indices, z_q
================================================
FILE: modules/audio_tokenizer/quantize/residual_vq.py
================================================
from typing import Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm
from .vector_quantize import VectorQuantize
from .factorized_vector_quantize import FactorizedVectorQuantize
class ResidualVQ(nn.Module):
    """
    Introduced in SoundStream: An end2end neural audio codec
    https://arxiv.org/abs/2107.03312

    A stack of quantizers where each stage quantizes the residual left by the
    previous stages. Supports quantizer dropout during training.
    """
    def __init__(
        self,
        input_dim: int = 256,
        num_quantizers: int = 8,
        codebook_size: int = 1024,
        codebook_dim: int = 256,
        quantizer_type: str = "vq", # "vq" or "fvq" or "lfq"
        quantizer_dropout: float = 0.5,
        **kwargs,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.num_quantizers = num_quantizers
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.quantizer_type = quantizer_type
        self.quantizer_dropout = quantizer_dropout
        if quantizer_type == "vq":
            VQ = VectorQuantize
        elif quantizer_type == "fvq":
            VQ = FactorizedVectorQuantize
        else:
            raise ValueError(f"Unknown quantizer type {quantizer_type}")
        self.quantizers = nn.ModuleList(
            [
                VQ(
                    input_dim=input_dim,
                    codebook_size=codebook_size,
                    codebook_dim=codebook_dim,
                    **kwargs,
                )
                for _ in range(num_quantizers)
            ]
        )

    def forward(self, z, n_quantizers: int = None):
        """
        Parameters
        ----------
        z : Tensor[B x D x T]
        n_quantizers : int, optional
            No. of quantizers to use
            (n_quantizers < self.n_codebooks ex: for quantizer dropout)
            Note: if `self.quantizer_dropout` is True, this argument is ignored
            when in training mode, and a random number of quantizers is used.
        Returns
        -------
        "quantized_out" : Tensor[B x D x T]
            Quantized continuous representation of input
        "all_indices" : Tensor[N x B x T]
            Codebook indices for each codebook
            (quantized discrete representation of input)
        "all_commit_losses" : Tensor[N]
        "all_codebook_losses" : Tensor[N]
        "all_quantized" : Tensor[N x B x D x T]
        """
        quantized_out = 0.0
        residual = z
        all_commit_losses = []
        all_codebook_losses = []
        all_indices = []
        all_quantized = []
        if n_quantizers is None:
            n_quantizers = self.num_quantizers
        if self.training:
            # Per-sample quantizer count: default to all stages (+1 so the
            # mask below keeps every stage), then randomly truncate a
            # quantizer_dropout fraction of the batch.
            n_quantizers = torch.ones((z.shape[0],)) * self.num_quantizers + 1
            dropout = torch.randint(1, self.num_quantizers + 1, (z.shape[0],))
            n_dropout = int(z.shape[0] * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(z.device)
        for i, quantizer in enumerate(self.quantizers):
            if self.training is False and i >= n_quantizers:
                break
            z_q_i, commit_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
                residual
            )
            # Create mask to apply quantizer dropout
            mask = (
                torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
            )
            quantized_out = quantized_out + z_q_i * mask[:, None, None]
            residual = residual - z_q_i
            # Zero out losses for samples whose dropout excludes this stage.
            commit_loss_i = (commit_loss_i * mask).mean()
            codebook_loss_i = (codebook_loss_i * mask).mean()
            all_commit_losses.append(commit_loss_i)
            all_codebook_losses.append(codebook_loss_i)
            all_indices.append(indices_i)
            all_quantized.append(z_q_i)
        all_commit_losses, all_codebook_losses, all_indices, all_quantized = map(
            torch.stack,
            (all_commit_losses, all_codebook_losses, all_indices, all_quantized),
        )
        return (
            quantized_out,
            all_indices,
            all_commit_losses,
            all_codebook_losses,
            all_quantized,
        )

    def vq2emb(self, vq, n_quantizers=None):
        """Sum the embeddings of the first n_quantizers stages for indices vq[N, ...]."""
        quantized_out = 0.0
        if n_quantizers is None:
            n_quantizers = self.num_quantizers
        for idx, quantizer in enumerate(self.quantizers):
            if idx >= n_quantizers:
                break
            quantized_out += quantizer.vq2emb(vq[idx])
        return quantized_out

    def latent2dist(self, z, n_quantizers=None):
        """Return per-stage (distances, indices), quantizing residuals stage by stage."""
        quantized_out = 0.0
        residual = z
        all_dists = []
        all_indices = []
        if n_quantizers is None:
            n_quantizers = self.num_quantizers
        for i, quantizer in enumerate(self.quantizers):
            if self.training is False and i >= n_quantizers:
                break
            dist_i, indices_i, z_q_i = quantizer.latent2dist(residual)
            all_dists.append(dist_i)
            all_indices.append(indices_i)
            quantized_out = quantized_out + z_q_i
            residual = residual - z_q_i
        all_dists = torch.stack(all_dists)
        all_indices = torch.stack(all_indices)
        return all_dists, all_indices
================================================
FILE: modules/audio_tokenizer/quantize/vector_quantize.py
================================================
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from torch.nn.utils import weight_norm
def WNConv1d(*args, **kwargs):
    """Weight-normalized 1-D convolution layer."""
    layer = nn.Conv1d(*args, **kwargs)
    return weight_norm(layer)
def WNConvTranspose1d(*args, **kwargs):
    """Weight-normalized 1-D transposed convolution layer."""
    layer = nn.ConvTranspose1d(*args, **kwargs)
    return weight_norm(layer)
def l2norm(t):
    """Scale vectors along the last dimension to unit L2 norm."""
    normalized = F.normalize(t, p=2, dim=-1)
    return normalized
def ema_inplace(moving_avg, new, decay):
    """In-place EMA update: moving_avg <- decay * moving_avg + (1 - decay) * new."""
    moving_avg.data.mul_(decay)
    moving_avg.data.add_(new, alpha=1 - decay)
def laplace_smoothing(x, n_categories, eps=1e-5):
    """Additive (Laplace) smoothing of a count vector into normalized weights."""
    numerator = x + eps
    denominator = x.sum() + n_categories * eps
    return numerator / denominator
def sample_vectors(samples, num):
    """Draw `num` rows from `samples`: without replacement when enough rows
    exist, with replacement otherwise."""
    available = samples.shape[0]
    device = samples.device
    if available < num:
        idx = torch.randint(0, available, (num,), device=device)
    else:
        idx = torch.randperm(available, device=device)[:num]
    return samples[idx]
def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
    """Run k-means on `samples` (n, d).

    Returns:
        means: (num_clusters, d) centroids.
        bins: (num_clusters,) cluster populations from the final assignment.
    """
    dim, dtype, device = samples.shape[-1], samples.dtype, samples.device
    # Initialize centroids from randomly drawn samples.
    means = sample_vectors(samples, num_clusters)
    for _ in range(num_iters):
        if use_cosine_sim:
            # Dot-product similarity (presumably on pre-normalized inputs — TODO confirm).
            dists = samples @ means.t()
        else:
            # Negative squared euclidean distance so argmax picks the nearest.
            diffs = rearrange(samples, "n d -> n () d") - rearrange(
                means, "c d -> () c d"
            )
            dists = -(diffs**2).sum(dim=-1)
        buckets = dists.max(dim=-1).indices
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        bins_min_clamped = bins.masked_fill(zero_mask, 1)
        # Sum members of each cluster, then divide by (clamped) population.
        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
        new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
        new_means = new_means / bins_min_clamped[..., None]
        if use_cosine_sim:
            new_means = l2norm(new_means)
        # Keep the previous centroid for clusters that received no samples.
        means = torch.where(zero_mask[..., None], means, new_means)
    return means, bins
class EuclideanCodebook(nn.Module):
    """Codebook with euclidean nearest-neighbour lookup, EMA codebook updates,
    optional k-means initialization, and dead-code expiration.

    Args:
        dim: Code vector dimension.
        codebook_size: Number of codes.
        kmeans_init (bool): If True, initialize the codebook from k-means on
            the first batch instead of random normal weights.
        kmeans_iters (int): Iterations for the k-means initialization.
        decay (float): EMA decay for codebook updates.
        eps (float): Laplace-smoothing epsilon for cluster sizes.
        threshold_ema_dead_code (int): Codes whose EMA cluster size falls
            below this get re-sampled from the current batch (0 disables).
        weight_init (bool): If True, use uniform init instead of randn.
    """
    def __init__(
        self,
        dim,
        codebook_size,
        kmeans_init=False,
        kmeans_iters=10,
        decay=0.8,
        eps=1e-5,
        threshold_ema_dead_code=2,
        weight_init=False,
    ):
        super().__init__()
        self.decay = decay
        init_fn = torch.randn if not weight_init else torch.zeros
        embed = init_fn(codebook_size, dim)
        if weight_init:
            nn.init.uniform_(embed, -1 / codebook_size, 1 / codebook_size)
        self.codebook_size = codebook_size
        self.kmeans_iters = kmeans_iters
        self.eps = eps
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.register_buffer(
            "initted", torch.Tensor([not kmeans_init])
        )  # if kmeans_init is True, then initted is False; otherwise, initted is True
        self.register_buffer("cluster_size", torch.zeros(codebook_size))
        self.register_buffer("embed", embed)
        self.register_buffer("embed_avg", embed.clone())

    def init_embed_(self, data):
        """One-shot k-means initialization of the codebook from the first batch."""
        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed)
        self.cluster_size.data.copy_(cluster_size)
        self.initted.data.copy_(torch.Tensor([True]))

    def replace(self, samples, mask):
        """Replace masked codes with vectors sampled from the current batch."""
        modified_codebook = torch.where(
            mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
        )
        self.embed.data.copy_(modified_codebook)

    def expire_codes_(self, batch_samples):
        """Re-sample codes whose EMA cluster size fell below the threshold."""
        if self.threshold_ema_dead_code == 0:
            return
        expired_codes = self.cluster_size < self.threshold_ema_dead_code
        if not torch.any(expired_codes):
            return
        batch_samples = rearrange(batch_samples, "... d -> (...) d")
        self.replace(batch_samples, mask=expired_codes)

    def forward(self, x):
        """Quantize x (..., dim); returns (quantize, embed_ind). Performs EMA
        codebook updates and dead-code expiration when training."""
        shape, dtype = x.shape, x.dtype
        flatten = rearrange(x, "... d -> (...) d")
        embed = self.embed.t()  # (codebook_size, dim) -> (dim, codebook_size)
        if not self.initted:
            self.init_embed_(flatten)
        # Negative squared euclidean distance; argmax picks the nearest code.
        dist = -(
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )
        embed_ind = dist.max(dim=-1).indices
        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
        embed_ind = embed_ind.view(*shape[:-1])
        quantize = F.embedding(embed_ind, self.embed)
        if self.training:
            # EMA update of cluster sizes and per-code running sums, then
            # renormalize with Laplace smoothing to avoid divide-by-zero.
            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
            embed_sum = (
                flatten.t() @ embed_onehot
            )  # (dim, ...) @ (..., codebook_size) -> (dim, codebook_size)
            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
            cluster_size = (
                laplace_smoothing(self.cluster_size, self.codebook_size, self.eps)
                * self.cluster_size.sum()
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            self.embed.data.copy_(embed_normalized)
            self.expire_codes_(x)
        return quantize, embed_ind

    def vq2emb(self, vq):
        """Look up code vectors for precomputed indices."""
        quantize = F.embedding(vq, self.embed)
        return quantize

    def latent2dist(self, x):
        """Return (dist, embed_ind, quantize) without any codebook updates."""
        shape, dtype = x.shape, x.dtype
        flatten = rearrange(x, "... d -> (...) d")
        embed = self.embed.t()  # (codebook_size, dim) -> (dim, codebook_size)
        if not self.initted:
            self.init_embed_(flatten)
        dist = -(
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )
        embed_ind = dist.max(dim=-1).indices
        embed_ind = embed_ind.view(*shape[:-1])
        quantize = F.embedding(embed_ind, self.embed)
        dist = dist.view(*shape[:-1], -1)
        return dist, embed_ind, quantize
class SimpleCodebook(nn.Module):
    """Plain learnable codebook (nn.Embedding) with nearest-neighbour lookup.

    Unlike EuclideanCodebook there is no EMA update, k-means init, or
    dead-code expiration; the table is trained by gradients elsewhere.

    Args:
        dim: Dimension of each code vector.
        codebook_size: Number of codes in the table.
        use_l2_normlize (bool): If True, L2-normalize encodings and codes
            before computing distances.
    """
    def __init__(
        self,
        dim,
        codebook_size,
        use_l2_normlize=False,
    ):
        super().__init__()
        self.dim = dim
        self.codebook_size = codebook_size
        self.use_l2_normlize = use_l2_normlize
        self.embed = nn.Embedding(self.codebook_size, self.dim)

    def forward(self, x):
        """Quantize x (..., dim) to its nearest codes.

        Returns:
            quantize: selected code vectors with x's leading shape.
            embed_ind: integer indices of the selected codes.
        """
        shape, dtype = x.shape, x.dtype
        flatten = rearrange(x, "... d -> (...) d")
        embed = self.embed.weight.t()  # (codebook_size, dim) -> (dim, codebook_size)
        if self.use_l2_normlize:
            flatten = F.normalize(flatten)
            embed = F.normalize(embed)
        # Negative squared euclidean distance; argmax picks the nearest code.
        dist = -(
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )
        embed_ind = dist.max(dim=-1).indices
        embed_ind = embed_ind.view(*shape[:-1])
        # Bug fix: F.embedding requires the weight *tensor*; the original
        # passed the nn.Embedding module itself, which raises a TypeError.
        quantize = F.embedding(embed_ind, self.embed.weight)
        return quantize, embed_ind

    def vq2emb(self, vq):
        """Look up code vectors for precomputed indices."""
        quantize = F.embedding(vq, self.embed.weight)
        return quantize

    def latent2dist(self, x):
        """Return (dist, embed_ind, quantize) for x (..., dim)."""
        shape, dtype = x.shape, x.dtype
        flatten = rearrange(x, "... d -> (...) d")
        embed = self.embed.weight.t()  # (codebook_size, dim) -> (dim, codebook_size)
        if self.use_l2_normlize:
            flatten = F.normalize(flatten)
            embed = F.normalize(embed)
        dist = -(
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )
        embed_ind = dist.max(dim=-1).indices
        embed_ind = embed_ind.view(*shape[:-1])
        # Bug fix: use the weight tensor, not the Embedding module (see forward()).
        quantize = F.embedding(embed_ind, self.embed.weight)
        dist = dist.view(*shape[:-1], -1)
        return dist, embed_ind, quantize
class VectorQuantize(nn.Module):
    """Vector quantization layer with optional factorized (low-rank) codes.

    When ``input_dim != codebook_dim`` the input is projected down to
    ``codebook_dim`` before the codebook lookup and projected back up
    afterwards (factorized VQ); otherwise identity projections are used.

    Args:
        input_dim (int): Dimension of the input.
        codebook_size (int): Number of codebook entries.
        codebook_dim (int): Dimension of each code. Use
            ``codebook_dim == input_dim`` for plain euclidean VQ; use a small
            value (e.g. 8 or 32) for factorized vector quantization.
        commitment (float): Weight for the commitment loss.
        codebook_loss_weight (float): Weight for the codebook loss.
        use_l2_normlize (bool): Use L2-normalized codes; recommended when
            using factorized vector quantization.
        codebook_type (str): "euclidean" (EMA codebook) or "simple".
        kmeans_init (bool): Initialize the codebook with k-means.
        kmeans_iters (int): Number of iterations for k-means initialization.
        decay (float): Decay for the exponential moving average over codes.
        eps (float): Epsilon for numerical stability.
        threshold_ema_dead_code (int): Codes whose EMA cluster size falls
            below this threshold are replaced with randomly selected vectors
            from the current batch.
        weight_init (bool): Whether to apply custom weight initialization.
    """

    def __init__(
        self,
        input_dim,
        codebook_size,
        codebook_dim,
        commitment=0.005,
        codebook_loss_weight=1.0,
        use_l2_normlize=False,
        codebook_type="euclidean",  # "euclidean" or "simple"
        kmeans_init=False,
        kmeans_iters=10,
        decay=0.8,
        eps=1e-5,
        threshold_ema_dead_code=2,
        weight_init=False,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.commitment = commitment
        self.codebook_loss_weight = codebook_loss_weight
        self.use_l2_normlize = use_l2_normlize
        self.codebook_type = codebook_type
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.decay = decay
        self.eps = eps
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.weight_init = weight_init

        # Factorized VQ: project down/up only when the dimensions differ.
        if self.input_dim == self.codebook_dim:
            self.in_project = nn.Identity()
            self.out_project = nn.Identity()
        else:
            self.in_project = WNConv1d(
                self.input_dim, self.codebook_dim, kernel_size=1
            )
            self.out_project = WNConv1d(
                self.codebook_dim, self.input_dim, kernel_size=1
            )

        if self.codebook_type == "euclidean":
            self.codebook = EuclideanCodebook(
                self.codebook_dim,
                codebook_size=self.codebook_size,
                kmeans_init=self.kmeans_init,
                kmeans_iters=self.kmeans_iters,
                decay=self.decay,
                eps=self.eps,
                threshold_ema_dead_code=self.threshold_ema_dead_code,
                weight_init=self.weight_init,
            )
        elif self.codebook_type == "simple":
            self.codebook = SimpleCodebook(
                self.codebook_dim,
                codebook_size=self.codebook_size,
                use_l2_normlize=self.use_l2_normlize,
            )
        else:
            raise NotImplementedError(
                f"codebook_type {self.codebook_type} is not implemented!"
            )

    def forward(self, z):
        """Quantize a batch of latents.

        Args:
            z: Tensor of shape (B, D, T).

        Returns:
            z_q: (B, D, T) quantized continuous representation
                (straight-through gradient to the encoder).
            commit_loss: (B,) commitment loss; zeros when not training.
            codebook_loss: (B,) codebook loss; zeros when not training.
            indices: (B, T) codebook indices (discrete representation).
            z_e: projected latents before quantization.
        """
        # Project into the (possibly low-dimensional) code space.
        z_e = self.in_project(z)
        z_q, indices = self.decode_latents(z_e)

        if not self.training:
            commit_loss = torch.zeros(z.shape[0], device=z.device)
            codebook_loss = torch.zeros(z.shape[0], device=z.device)
        else:
            # Per-sample MSE; gradients flow to the encoder (commitment)
            # and to the codebook respectively via the detach placement.
            commit_loss = (
                F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
                * self.commitment
            )
            codebook_loss = (
                F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
                * self.codebook_loss_weight
            )

        # Straight-through estimator: forward uses z_q, backward sees z_e.
        z_q = z_e + (z_q - z_e).detach()
        return self.out_project(z_q), commit_loss, codebook_loss, indices, z_e

    def decode_latents(self, latents):
        """Run the codebook on (B, D, T) latents; return ((B, D, T), indices)."""
        z_q, indices = self.codebook(latents.transpose(1, 2))
        return z_q.transpose(1, 2), indices

    def vq2emb(self, vq, out_proj=True):
        """Map code indices back to (optionally out-projected) embeddings."""
        emb = self.codebook.vq2emb(vq).transpose(1, 2)
        return self.out_project(emb) if out_proj else emb

    def latent2dist(self, latents):
        """Distance logits, indices, and quantized output for (B, D, T) latents."""
        dist, embed_ind, quantize = self.codebook.latent2dist(
            latents.transpose(1, 2)
        )
        return dist, embed_ind, quantize.transpose(1, 2)
================================================
FILE: modules/audio_tokenizer/rep_codec.py
================================================
import torch
import torch.nn as nn
from modules.audio_tokenizer.quantize import ResidualVQ
from modules.audio_tokenizer.vocos import VocosBackbone
from modules.audio_tokenizer.transformer import TransformerEncoder
def init_weights(m):
if isinstance(m, nn.Conv1d):
nn.init.trunc_normal_(m.weig
gitextract_g10rsitx/ ├── .gitignore ├── LICENSE ├── app.py ├── download_pretrain.py ├── en_llmprompt_script_gen.py ├── inference.py ├── modules/ │ ├── audio_detokenizer/ │ │ ├── audio_detokenizer.py │ │ ├── bigvgan_wrapper.py │ │ ├── flow_matching/ │ │ │ ├── dit_block.py │ │ │ ├── model.py │ │ │ ├── ode_wrapper.py │ │ │ └── scheduler.py │ │ ├── semantic_fm_prefix_streaming.py │ │ └── vocoder/ │ │ ├── activations.py │ │ ├── alias_free_activation/ │ │ │ ├── __init__.py │ │ │ ├── cuda/ │ │ │ │ ├── __init__.py │ │ │ │ ├── activation1d.py │ │ │ │ ├── anti_alias_activation.cpp │ │ │ │ ├── anti_alias_activation_cuda.cu │ │ │ │ ├── compat.h │ │ │ │ ├── load.py │ │ │ │ └── type_shim.h │ │ │ └── torch/ │ │ │ ├── __init__.py │ │ │ ├── act.py │ │ │ ├── filter.py │ │ │ └── resample.py │ │ ├── bigvgan.py │ │ └── utils.py │ ├── audio_tokenizer/ │ │ ├── audio_tokenizer.py │ │ ├── quantize/ │ │ │ ├── __init__.py │ │ │ ├── factorized_vector_quantize.py │ │ │ ├── residual_vq.py │ │ │ └── vector_quantize.py │ │ ├── rep_codec.py │ │ ├── transformer.py │ │ └── vocos.py │ └── tokenizer/ │ └── tokenizer.py ├── readme.md ├── requirements.txt ├── test/ │ ├── test_audio_detokenizer.py │ ├── test_audio_tokenizer.py │ └── test_tokenizer.py └── zh_llmprompt_script_gen.py
SYMBOL INDEX (298 symbols across 26 files)
FILE: app.py
function process_json_and_generate_audio (line 12) | def process_json_and_generate_audio(prompt_audio_role0_file, prompt_text...
function update_ui_language (line 137) | def update_ui_language(language):
FILE: inference.py
class Model (line 18) | class Model(object):
method __init__ (line 19) | def __init__(self):
method _clean_text (line 52) | def _clean_text(self, text):
method _process_text (line 66) | def _process_text(self, js):
method inference (line 76) | def inference(self, js, streaming=False):
method infer_with_prompt (line 90) | def infer_with_prompt(self, js):
method infer_with_prompt_streaming (line 175) | def infer_with_prompt_streaming(self, js):
method infer_without_prompt (line 255) | def infer_without_prompt(self, js):
method infer_without_prompt_streaming (line 312) | def infer_without_prompt_streaming(self, js):
FILE: modules/audio_detokenizer/audio_detokenizer.py
class PrefixStreamingFlowMatchingDetokenizer (line 8) | class PrefixStreamingFlowMatchingDetokenizer:
method __init__ (line 9) | def __init__(self, vocoder: BigVGANWrapper, fm: StreamingSemanticFMWra...
method from_pretrained (line 34) | def from_pretrained(cls, vocoder_config, vocoder_ckpt, fm_config, fm_c...
method prefill (line 44) | def prefill(self, timbre_speech, timbre_semantic_token, chunk_size: in...
method detokenize_streaming (line 76) | def detokenize_streaming(self, semantic_token, ode_step=30, verbose=Fa...
method clear_states (line 174) | def clear_states(self):
function get_audio_detokenizer (line 180) | def get_audio_detokenizer():
function detokenize (line 201) | def detokenize(detokenizer, tokens, ref_wav, ref_tokens):
function detokenize_streaming (line 220) | def detokenize_streaming(detokenizer, tokens, ref_wav, ref_tokens):
function detokenize_noref (line 236) | def detokenize_noref(detokenizer, tokens):
function detokenize_noref_streaming (line 255) | def detokenize_noref_streaming(detokenizer, tokens):
FILE: modules/audio_detokenizer/bigvgan_wrapper.py
class BigVGANWrapper (line 14) | class BigVGANWrapper:
method __init__ (line 15) | def __init__(self, vocoder: BigVGAN, device: torch.device, h: AttrDict...
method to_dtype (line 23) | def to_dtype(self, dtype):
method extract_mel_from_wav (line 26) | def extract_mel_from_wav(self, wav_path=None, wav_data=None):
method extract_mel_from_wav_batch (line 44) | def extract_mel_from_wav_batch(self, wav_data):
method decode_mel (line 58) | def decode_mel(self, mel):
method decode_mel_batch (line 69) | def decode_mel_batch(self, mel):
method from_pretrained (line 81) | def from_pretrained(cls, model_config, ckpt_path, device):
FILE: modules/audio_detokenizer/flow_matching/dit_block.py
function reshape_for_broadcast (line 11) | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
function apply_rotary_emb (line 24) | def apply_rotary_emb(
class Attention (line 39) | class Attention(nn.Module):
method __init__ (line 41) | def __init__(
method forward (line 67) | def forward(self, x: torch.Tensor, seq_len, cu_seqlens, max_seqlen, cu...
function modulate (line 160) | def modulate(x, shift, scale):
class FinalLayer (line 164) | class FinalLayer(nn.Module):
method __init__ (line 168) | def __init__(self, hidden_size, out_channels):
method forward (line 177) | def forward(self, x, c):
class DiTBlock (line 184) | class DiTBlock(nn.Module):
method __init__ (line 188) | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, ffn_type="co...
method forward (line 209) | def forward(self, x, c, seq_len, cu_seqlens, cu_maxlen, cu_seqlens_k, ...
FILE: modules/audio_detokenizer/flow_matching/model.py
function precompute_freqs_cis (line 6) | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0,
class TimestepEmbedder (line 28) | class TimestepEmbedder(nn.Module):
method __init__ (line 32) | def __init__(self, hidden_size, frequency_embedding_size=256):
method timestep_embedding (line 42) | def timestep_embedding(t, dim, max_period=10000):
method forward (line 62) | def forward(self, t):
class SinusoidalPositionalEmbedding (line 68) | class SinusoidalPositionalEmbedding(nn.Module):
method __init__ (line 74) | def __init__(self, embedding_dim, padding_idx, init_size=1024):
method get_embedding (line 86) | def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
method forward (line 104) | def forward(self, input, incremental_state=None, timestep=None, **kwar...
method max_positions (line 125) | def max_positions(self):
method make_positions (line 129) | def make_positions(self, tensor, padding_idx):
class DiTPrefix (line 144) | class DiTPrefix(nn.Module):
method __init__ (line 148) | def __init__(
method initialize_weights (line 218) | def initialize_weights(self):
method forward (line 243) | def forward(self, x, position_ids, t, condition, seq_len, cu_seqlens, ...
FILE: modules/audio_detokenizer/flow_matching/ode_wrapper.py
function get_cached_zeros (line 8) | def get_cached_zeros(numel, device="cpu", dtype=torch.float32):
class StreamingODEWrapperForPrefix (line 11) | class StreamingODEWrapperForPrefix(nn.Module):
method __init__ (line 12) | def __init__(self, net, x_mask, x_cond, use_cfg=False, use_cfg_rescale...
method clear_all_states (line 38) | def clear_all_states(self):
method state_dict (line 48) | def state_dict(self):
method load_state_dict (line 59) | def load_state_dict(self, state_dict):
method set_conditions (line 68) | def set_conditions(self, x_mask, x_cond, start_position_id, cache={}):
method update_incremental_state (line 108) | def update_incremental_state(self, reserve_kv_cache_tokens=0, max_kv_c...
method forward (line 151) | def forward(self, t, x, args=None):
FILE: modules/audio_detokenizer/flow_matching/scheduler.py
class SchedulerBase (line 9) | class SchedulerBase(ABC):
method __init__ (line 10) | def __init__(self) -> None:
method set_timesteps (line 14) | def set_timesteps(self):
method step (line 18) | def step(self):
method add_noise (line 22) | def add_noise(self):
class StreamingFlowMatchingScheduler (line 26) | class StreamingFlowMatchingScheduler(SchedulerBase):
method __init__ (line 27) | def __init__(self, timesteps=1000, sigma_min=1e-4,
method set_timesteps (line 39) | def set_timesteps(self, timesteps=15):
method step (line 42) | def step(self, xt, predicted_v):
method sample (line 50) | def sample(self, ode_wrapper, time_steps, xt, verbose=False, x0=None):
method sample_by_neuralode (line 65) | def sample_by_neuralode(self, ode_wrapper, time_steps, xt, verbose=Fal...
method add_noise (line 76) | def add_noise(self, original_samples: torch.FloatTensor,
FILE: modules/audio_detokenizer/semantic_fm_prefix_streaming.py
class StreamingSemanticFMWrapper (line 16) | class StreamingSemanticFMWrapper:
method __init__ (line 17) | def __init__(self, speech_model: DiTPrefix, max_kv_cache_tokens=900, m...
method infer_chunk (line 49) | def infer_chunk(self, xt_chunk, semantic_tokens_chunk, start_position_id,
method infer_mel (line 105) | def infer_mel(self, semantic_tokens, ode_steps=15, chunk_size=150, ver...
method clear_all_states (line 150) | def clear_all_states(self):
method state_dict (line 155) | def state_dict(self):
method load_state_dict (line 162) | def load_state_dict(self, state_dict):
method update_incremental_state (line 168) | def update_incremental_state(self):
method prefill (line 172) | def prefill(self, mel, semantic_token, chunk_size=150, verbose=False):
method prefill_chunk (line 212) | def prefill_chunk(self, mel_chunk, semantic_tokens_chunk, start_positi...
method from_pretrained (line 236) | def from_pretrained(cls, model_config, ckpt_path, device, max_prompt_c...
FILE: modules/audio_detokenizer/vocoder/activations.py
class Snake (line 6) | class Snake(nn.Module):
method __init__ (line 23) | def __init__(
method forward (line 48) | def forward(self, x):
class SnakeBeta (line 62) | class SnakeBeta(nn.Module):
method __init__ (line 80) | def __init__(
method forward (line 110) | def forward(self, x):
FILE: modules/audio_detokenizer/vocoder/alias_free_activation/cuda/activation1d.py
class FusedAntiAliasActivation (line 14) | class FusedAntiAliasActivation(torch.autograd.Function):
method forward (line 22) | def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta):
method backward (line 30) | def backward(ctx, output_grads):
class Activation1d (line 35) | class Activation1d(nn.Module):
method __init__ (line 36) | def __init__(
method forward (line 54) | def forward(self, x):
FILE: modules/audio_detokenizer/vocoder/alias_free_activation/cuda/anti_alias_activation.cpp
function PYBIND11_MODULE (line 21) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: modules/audio_detokenizer/vocoder/alias_free_activation/cuda/load.py
function load (line 17) | def load():
function _get_cuda_bare_metal_version (line 68) | def _get_cuda_bare_metal_version(cuda_dir):
function _create_build_dir (line 81) | def _create_build_dir(buildpath):
FILE: modules/audio_detokenizer/vocoder/alias_free_activation/torch/act.py
class Activation1d (line 8) | class Activation1d(nn.Module):
method __init__ (line 9) | def __init__(
method forward (line 25) | def forward(self, x):
FILE: modules/audio_detokenizer/vocoder/alias_free_activation/torch/filter.py
function sinc (line 15) | def sinc(x: torch.Tensor):
function kaiser_sinc_filter1d (line 30) | def kaiser_sinc_filter1d(
class LowPassFilter1d (line 65) | class LowPassFilter1d(nn.Module):
method __init__ (line 66) | def __init__(
method forward (line 94) | def forward(self, x):
FILE: modules/audio_detokenizer/vocoder/alias_free_activation/torch/resample.py
class UpSample1d (line 10) | class UpSample1d(nn.Module):
method __init__ (line 11) | def __init__(self, ratio=2, kernel_size=None):
method forward (line 29) | def forward(self, x):
class DownSample1d (line 41) | class DownSample1d(nn.Module):
method __init__ (line 42) | def __init__(self, ratio=2, kernel_size=None):
method forward (line 55) | def forward(self, x):
FILE: modules/audio_detokenizer/vocoder/bigvgan.py
function load_hparams_from_json (line 25) | def load_hparams_from_json(path) -> AttrDict:
class AMPBlock1 (line 31) | class AMPBlock1(torch.nn.Module):
method __init__ (line 44) | def __init__(
method forward (line 132) | def forward(self, x):
method remove_weight_norm (line 143) | def remove_weight_norm(self):
class AMPBlock2 (line 150) | class AMPBlock2(torch.nn.Module):
method __init__ (line 163) | def __init__(
method forward (line 232) | def forward(self, x):
method remove_weight_norm (line 238) | def remove_weight_norm(self):
class BigVGAN (line 243) | class BigVGAN(
method __init__ (line 266) | def __init__(self, h: AttrDict, use_cuda_kernel: bool = False):
method forward (line 360) | def forward(self, x):
method remove_weight_norm (line 388) | def remove_weight_norm(self):
method _save_pretrained (line 403) | def _save_pretrained(self, save_directory: Path) -> None:
method _from_pretrained (line 414) | def _from_pretrained(
FILE: modules/audio_detokenizer/vocoder/utils.py
function dynamic_range_compression_torch (line 7) | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
function spectral_normalize_torch (line 11) | def spectral_normalize_torch(magnitudes):
function get_melspec (line 14) | def get_melspec(
class AttrDict (line 86) | class AttrDict(dict):
method __init__ (line 87) | def __init__(self, *args, **kwargs):
function load_checkpoint (line 91) | def load_checkpoint(filepath, device):
function init_weights (line 98) | def init_weights(m, mean=0.0, std=0.01):
function get_padding (line 104) | def get_padding(kernel_size, dilation=1):
FILE: modules/audio_tokenizer/audio_tokenizer.py
class AudioTokenizer (line 13) | class AudioTokenizer(object):
method __init__ (line 14) | def __init__(self, **kwargs):
method tokenize (line 38) | def tokenize(self, speech):
function get_audio_tokenizer (line 67) | def get_audio_tokenizer():
FILE: modules/audio_tokenizer/quantize/factorized_vector_quantize.py
function WNConv1d (line 9) | def WNConv1d(*args, **kwargs):
function WNConvTranspose1d (line 13) | def WNConvTranspose1d(*args, **kwargs):
class FactorizedVectorQuantize (line 17) | class FactorizedVectorQuantize(nn.Module):
method __init__ (line 18) | def __init__(
method forward (line 47) | def forward(self, z):
method embed_code (line 91) | def embed_code(self, embed_id):
method decode_code (line 94) | def decode_code(self, embed_id):
method decode_latents (line 97) | def decode_latents(self, latents):
method vq2emb (line 118) | def vq2emb(self, vq, out_proj=True):
method latent2dist (line 124) | def latent2dist(self, latents):
FILE: modules/audio_tokenizer/quantize/residual_vq.py
class ResidualVQ (line 15) | class ResidualVQ(nn.Module):
method __init__ (line 21) | def __init__(
method forward (line 59) | def forward(self, z, n_quantizers: int = None):
method vq2emb (line 135) | def vq2emb(self, vq, n_quantizers=None):
method latent2dist (line 145) | def latent2dist(self, z, n_quantizers=None):
FILE: modules/audio_tokenizer/quantize/vector_quantize.py
function WNConv1d (line 9) | def WNConv1d(*args, **kwargs):
function WNConvTranspose1d (line 13) | def WNConvTranspose1d(*args, **kwargs):
function l2norm (line 17) | def l2norm(t):
function ema_inplace (line 21) | def ema_inplace(moving_avg, new, decay):
function laplace_smoothing (line 25) | def laplace_smoothing(x, n_categories, eps=1e-5):
function sample_vectors (line 29) | def sample_vectors(samples, num):
function kmeans (line 40) | def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
class EuclideanCodebook (line 71) | class EuclideanCodebook(nn.Module):
method __init__ (line 72) | def __init__(
method init_embed_ (line 104) | def init_embed_(self, data):
method replace (line 111) | def replace(self, samples, mask):
method expire_codes_ (line 117) | def expire_codes_(self, batch_samples):
method forward (line 127) | def forward(self, x):
method vq2emb (line 162) | def vq2emb(self, vq):
method latent2dist (line 166) | def latent2dist(self, x):
class SimpleCodebook (line 189) | class SimpleCodebook(nn.Module):
method __init__ (line 190) | def __init__(
method forward (line 204) | def forward(self, x):
method vq2emb (line 225) | def vq2emb(self, vq):
method latent2dist (line 229) | def latent2dist(self, x):
class VectorQuantize (line 253) | class VectorQuantize(nn.Module):
method __init__ (line 273) | def __init__(
method forward (line 336) | def forward(self, z):
method decode_latents (line 380) | def decode_latents(self, latents):
method vq2emb (line 386) | def vq2emb(self, vq, out_proj=True):
method latent2dist (line 393) | def latent2dist(self, latents):
FILE: modules/audio_tokenizer/rep_codec.py
function init_weights (line 9) | def init_weights(m):
class RepCodec (line 17) | class RepCodec(nn.Module):
method __init__ (line 18) | def __init__(
method forward (line 136) | def forward(self, x):
method quantize (line 177) | def quantize(self, x):
method reset_parameters (line 196) | def reset_parameters(self):
FILE: modules/audio_tokenizer/transformer.py
class StyleAdaptiveLayerNorm (line 8) | class StyleAdaptiveLayerNorm(nn.Module):
method __init__ (line 9) | def __init__(self, normalized_shape, eps=1e-5):
method forward (line 17) | def forward(self, x, condition):
class PositionalEncoding (line 30) | class PositionalEncoding(nn.Module):
method __init__ (line 31) | def __init__(self, d_model, dropout, max_len=5000):
method forward (line 44) | def forward(self, x):
class TransformerFFNLayer (line 49) | class TransformerFFNLayer(nn.Module):
method __init__ (line 50) | def __init__(
method forward (line 70) | def forward(self, x):
class TransformerEncoderLayer (line 81) | class TransformerEncoderLayer(nn.Module):
method __init__ (line 82) | def __init__(
method forward (line 117) | def forward(self, x, key_padding_mask, conditon=None):
class TransformerEncoder (line 149) | class TransformerEncoder(nn.Module):
method __init__ (line 150) | def __init__(
method forward (line 217) | def forward(self, x, key_padding_mask, condition=None):
FILE: modules/audio_tokenizer/vocos.py
function safe_log (line 12) | def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
function symlog (line 26) | def symlog(x: torch.Tensor) -> torch.Tensor:
function symexp (line 30) | def symexp(x: torch.Tensor) -> torch.Tensor:
class STFT (line 34) | class STFT(nn.Module):
method __init__ (line 35) | def __init__(
method forward (line 50) | def forward(self, x: torch.Tensor) -> torch.Tensor:
class ISTFT (line 78) | class ISTFT(nn.Module):
method __init__ (line 93) | def __init__(
method forward (line 106) | def forward(self, spec: torch.Tensor) -> torch.Tensor:
class MDCT (line 164) | class MDCT(nn.Module):
method __init__ (line 173) | def __init__(self, frame_len: int, padding: str = "same"):
method forward (line 191) | def forward(self, audio: torch.Tensor) -> torch.Tensor:
class IMDCT (line 225) | class IMDCT(nn.Module):
method __init__ (line 234) | def __init__(self, frame_len: int, padding: str = "same"):
method forward (line 250) | def forward(self, X: torch.Tensor) -> torch.Tensor:
class FourierHead (line 293) | class FourierHead(nn.Module):
method forward (line 296) | def forward(self, x: torch.Tensor) -> torch.Tensor:
class ISTFTHead (line 308) | class ISTFTHead(FourierHead):
method __init__ (line 320) | def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str...
method forward (line 328) | def forward(self, x: torch.Tensor) -> torch.Tensor:
class IMDCTSymExpHead (line 358) | class IMDCTSymExpHead(FourierHead):
method __init__ (line 371) | def __init__(
method forward (line 395) | def forward(self, x: torch.Tensor) -> torch.Tensor:
class IMDCTCosHead (line 418) | class IMDCTCosHead(FourierHead):
method __init__ (line 429) | def __init__(
method forward (line 441) | def forward(self, x: torch.Tensor) -> torch.Tensor:
class ConvNeXtBlock (line 463) | class ConvNeXtBlock(nn.Module):
method __init__ (line 475) | def __init__(
method forward (line 502) | def forward(
class AdaLayerNorm (line 524) | class AdaLayerNorm(nn.Module):
method __init__ (line 533) | def __init__(self, num_embeddings: int, embedding_dim: int, eps: float...
method forward (line 546) | def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) ->...
class ResBlock1 (line 554) | class ResBlock1(nn.Module):
method __init__ (line 570) | def __init__(
method forward (line 676) | def forward(self, x: torch.Tensor) -> torch.Tensor:
method remove_weight_norm (line 687) | def remove_weight_norm(self):
method get_padding (line 694) | def get_padding(kernel_size: int, dilation: int = 1) -> int:
class Backbone (line 698) | class Backbone(nn.Module):
method forward (line 701) | def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
class VocosBackbone (line 714) | class VocosBackbone(Backbone):
method __init__ (line 728) | def __init__(
method _init_weights (line 760) | def _init_weights(self, m):
method forward (line 765) | def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
class VocosResNetBackbone (line 780) | class VocosResNetBackbone(Backbone):
method __init__ (line 791) | def __init__(
method forward (line 811) | def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
class Vocos (line 818) | class Vocos(nn.Module):
method __init__ (line 819) | def __init__(
method forward (line 841) | def forward(self, x):
FILE: modules/tokenizer/tokenizer.py
function encode_pieces (line 9) | def encode_pieces(sp_model: spm.SentencePieceProcessor, text: str, sampl...
class AbstractTokenizer (line 20) | class AbstractTokenizer(ABC):
method __init__ (line 23) | def __init__(self, name):
method vocab_size (line 29) | def vocab_size(self):
method vocab (line 34) | def vocab(self):
method inv_vocab (line 40) | def inv_vocab(self):
method tokenize (line 45) | def tokenize(self, text):
method detokenize (line 48) | def detokenize(self, token_ids):
method cls (line 53) | def cls(self):
method sep (line 58) | def sep(self):
method pad (line 63) | def pad(self):
method eod (line 68) | def eod(self):
method mask (line 73) | def mask(self):
class SPieceTokenizer (line 78) | class SPieceTokenizer(AbstractTokenizer):
method __init__ (line 79) | def __init__(self, spm_file: str):
method encode_pieces (line 96) | def encode_pieces(self, text: str, sample=False):
method _initialize_index_2_bytes (line 103) | def _initialize_index_2_bytes(self):
method set_add_dummy_prefix (line 111) | def set_add_dummy_prefix(self, add_dummy_prefix: bool = False):
method add_special_id (line 119) | def add_special_id(self, token_id):
method has_dummy_prefix (line 123) | def has_dummy_prefix(self):
method vocab_size (line 128) | def vocab_size(self):
method vocab (line 132) | def vocab(self):
method get_array_bytes (line 136) | def get_array_bytes(self, array):
method tokenize (line 139) | def tokenize(self, text):
method encode (line 143) | def encode(self, text: str, bos: bool=False, eos: bool=False, **kwargs...
method convert_tokens_to_ids (line 152) | def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list...
method detokenize (line 157) | def detokenize(self, token_ids):
method decode (line 164) | def decode(self, token_ids: Union[int, list[int]], skip_special_tokens...
method get_token_id (line 170) | def get_token_id(self, token):
method inv_vocab (line 173) | def inv_vocab(self):
method decode_pieces (line 177) | def decode_pieces(self, pieces):
method eod (line 181) | def eod(self):
method pad_id (line 185) | def pad_id(self):
method eos_id (line 189) | def eos_id(self):
method bos_id (line 193) | def bos_id(self):
method unk_id (line 197) | def unk_id(self):
method pad_token_id (line 201) | def pad_token_id(self):
method eos_token_id (line 205) | def eos_token_id(self):
class ExtraTokens (line 210) | class ExtraTokens:
function instantiate_extra_tokens (line 221) | def instantiate_extra_tokens(tokenizer: AbstractTokenizer):
function get_tokenizer_and_extra_tokens (line 238) | def get_tokenizer_and_extra_tokens():
Condensed preview — 43 files, each showing its path, character count, and a content snippet. Download the .json file or copy it to your clipboard to get the full structured content (283K chars).
[
{
"path": ".gitignore",
"chars": 136,
"preview": "*.safetensors\n*.pt\n*.vscode\n**/__pycache__/\nmodules/audio_detokenizer/vocoder/alias_free_activation/cuda/build/\ntmp*\nres"
},
{
"path": "LICENSE",
"chars": 1065,
"preview": "MIT License\n\nCopyright (c) 2025 Zeqian Ju\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\n"
},
{
"path": "app.py",
"chars": 9338,
"preview": "import gradio as gr\nfrom huggingface_hub import snapshot_download \nsnapshot_download(repo_id=\"jzq11111/mooncast\", local_"
},
{
"path": "download_pretrain.py",
"chars": 118,
"preview": "from huggingface_hub import snapshot_download\nsnapshot_download(repo_id=\"jzq11111/mooncast\", local_dir='./resources/')"
},
{
"path": "en_llmprompt_script_gen.py",
"chars": 19143,
"preview": "# INPUT -> BRIEF -> SCRIPT\n\n\nINPUT2BRIEF = '''\n### Task Description \nPlease summarize the input document in plain text "
},
{
"path": "inference.py",
"chars": 22900,
"preview": "\nimport sys\nsys.path.append(\".\")\nfrom modules.tokenizer.tokenizer import get_tokenizer_and_extra_tokens\nfrom modules.aud"
},
{
"path": "modules/audio_detokenizer/audio_detokenizer.py",
"chars": 12700,
"preview": "\nimport torch\n\nfrom modules.audio_detokenizer.bigvgan_wrapper import BigVGANWrapper\nfrom modules.audio_detokenizer.seman"
},
{
"path": "modules/audio_detokenizer/bigvgan_wrapper.py",
"chars": 3154,
"preview": "import os\nimport json\nimport logging\n\nimport librosa\nimport torch\n\nfrom modules.audio_detokenizer.vocoder.bigvgan import"
},
{
"path": "modules/audio_detokenizer/flow_matching/dit_block.py",
"chars": 9133,
"preview": "import torch\nimport torch.nn as nn\n\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom flash_attn"
},
{
"path": "modules/audio_detokenizer/flow_matching/model.py",
"chars": 12192,
"preview": "import torch\nimport torch.nn as nn\nimport math\nfrom modules.audio_detokenizer.flow_matching.dit_block import DiTBlock, F"
},
{
"path": "modules/audio_detokenizer/flow_matching/ode_wrapper.py",
"chars": 7809,
"preview": "import torch\nimport torch.nn as nn\nfrom functools import lru_cache\nimport copy\n\n\n@lru_cache(maxsize=1)\ndef get_cached_ze"
},
{
"path": "modules/audio_detokenizer/flow_matching/scheduler.py",
"chars": 2562,
"preview": "import torch\nfrom abc import abstractmethod, ABC\ntry:\n from torchdyn.core import NeuralODE\n NEURALODE_INSTALLED = "
},
{
"path": "modules/audio_detokenizer/semantic_fm_prefix_streaming.py",
"chars": 12226,
"preview": "import yaml\nimport logging\nimport time\n\nimport os\nimport torch\n\nfrom modules.audio_detokenizer.flow_matching.ode_wrapper"
},
{
"path": "modules/audio_detokenizer/vocoder/activations.py",
"chars": 4403,
"preview": "import torch\nfrom torch import nn, sin, pow\nfrom torch.nn import Parameter\n\n\nclass Snake(nn.Module):\n \"\"\"\n Impleme"
},
{
"path": "modules/audio_detokenizer/vocoder/alias_free_activation/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "modules/audio_detokenizer/vocoder/alias_free_activation/cuda/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "modules/audio_detokenizer/vocoder/alias_free_activation/cuda/activation1d.py",
"chars": 2569,
"preview": "# Copyright (c) 2024 NVIDIA CORPORATION.\n# Licensed under the MIT license.\n\nimport torch\nimport torch.nn as nn\nfrom .."
},
{
"path": "modules/audio_detokenizer/vocoder/alias_free_activation/cuda/anti_alias_activation.cpp",
"chars": 977,
"preview": "/* coding=utf-8\n * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.\n *\n * Licensed under the Apache License"
},
{
"path": "modules/audio_detokenizer/vocoder/alias_free_activation/cuda/anti_alias_activation_cuda.cu",
"chars": 10328,
"preview": "/* coding=utf-8\n * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.\n *\n * Licensed under the Apache License"
},
{
"path": "modules/audio_detokenizer/vocoder/alias_free_activation/cuda/compat.h",
"chars": 893,
"preview": "/* coding=utf-8\n * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.\n *\n * Licensed under the Apache License"
},
{
"path": "modules/audio_detokenizer/vocoder/alias_free_activation/cuda/load.py",
"chars": 2594,
"preview": "# Copyright (c) 2024 NVIDIA CORPORATION.\n# Licensed under the MIT license.\n\nimport os\nimport pathlib\nimport subprocess"
},
{
"path": "modules/audio_detokenizer/vocoder/alias_free_activation/cuda/type_shim.h",
"chars": 5838,
"preview": "/* coding=utf-8\n * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.\n *\n * Licensed under the Apache License"
},
{
"path": "modules/audio_detokenizer/vocoder/alias_free_activation/torch/__init__.py",
"chars": 200,
"preview": "# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0\n# LICENSE is in incl_licens"
},
{
"path": "modules/audio_detokenizer/vocoder/alias_free_activation/torch/act.py",
"chars": 825,
"preview": "# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0\n# LICENSE is in incl_licens"
},
{
"path": "modules/audio_detokenizer/vocoder/alias_free_activation/torch/filter.py",
"chars": 3401,
"preview": "# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0\n# LICENSE is in incl_licens"
},
{
"path": "modules/audio_detokenizer/vocoder/alias_free_activation/torch/resample.py",
"chars": 1831,
"preview": "# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0\n# LICENSE is in incl_licens"
},
{
"path": "modules/audio_detokenizer/vocoder/bigvgan.py",
"chars": 17693,
"preview": "# Copyright (c) 2024 NVIDIA CORPORATION.\n# Licensed under the MIT license.\n\n# Adapted from https://github.com/jik876/h"
},
{
"path": "modules/audio_detokenizer/vocoder/utils.py",
"chars": 3353,
"preview": "from librosa.filters import mel as librosa_mel_fn\nimport torch\nimport os\nmel_basis_cache = {}\nhann_window_cache = {}\n\nde"
},
{
"path": "modules/audio_tokenizer/audio_tokenizer.py",
"chars": 3061,
"preview": "import torch\nimport librosa\nimport yaml\nfrom transformers import Wav2Vec2BertModel, SeamlessM4TFeatureExtractor\nimport s"
},
{
"path": "modules/audio_tokenizer/quantize/__init__.py",
"chars": 145,
"preview": "from .vector_quantize import VectorQuantize\nfrom .residual_vq import ResidualVQ\nfrom .factorized_vector_quantize import "
},
{
"path": "modules/audio_tokenizer/quantize/factorized_vector_quantize.py",
"chars": 4878,
"preview": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import rearrange\nfrom "
},
{
"path": "modules/audio_tokenizer/quantize/residual_vq.py",
"chars": 5381,
"preview": "from typing import Union\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom ein"
},
{
"path": "modules/audio_tokenizer/quantize/vector_quantize.py",
"chars": 13476,
"preview": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import rearrange, repe"
},
{
"path": "modules/audio_tokenizer/rep_codec.py",
"chars": 6464,
"preview": "import torch\nimport torch.nn as nn\n\n\nfrom modules.audio_tokenizer.quantize import ResidualVQ\nfrom modules.audio_tokenize"
},
{
"path": "modules/audio_tokenizer/transformer.py",
"chars": 7207,
"preview": "import numpy as np\nimport torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\nimport math\n\n\nclass StyleAdap"
},
{
"path": "modules/audio_tokenizer/vocos.py",
"chars": 30188,
"preview": "from typing import Optional, Tuple\n\nimport numpy as np\nimport scipy\nimport torch\nfrom torch import nn, view_as_real, vie"
},
{
"path": "modules/tokenizer/tokenizer.py",
"chars": 7442,
"preview": "from abc import ABC\nfrom abc import abstractmethod\nimport sentencepiece as spm\nfrom sentencepiece import sentencepiece_m"
},
{
"path": "readme.md",
"chars": 2933,
"preview": "# MoonCast: High-Quality Zero-Shot Podcast Generation\n\n<p align=\"center\">\n <picture>\n <img src=\"./fig/logo.png"
},
{
"path": "requirements.txt",
"chars": 194,
"preview": "torch==2.3.1\ntorchaudio==2.3.1\nsentencepiece==0.2.0\nprotobuf\nnumpy\n\nlibrosa==0.9.1\npyyaml\ntransformers\nsafetensors\neinop"
},
{
"path": "test/test_audio_detokenizer.py",
"chars": 1356,
"preview": "import sys\nimport torch\nsys.path.append('.')\nfrom modules.audio_tokenizer.audio_tokenizer import get_audio_tokenizer\nfro"
},
{
"path": "test/test_audio_tokenizer.py",
"chars": 559,
"preview": "import sys\nsys.path.append('.')\nfrom modules.audio_tokenizer.audio_tokenizer import get_audio_tokenizer\nimport torch\n\nif"
},
{
"path": "test/test_tokenizer.py",
"chars": 1119,
"preview": "import sys\nsys.path.append('.')\nfrom modules.tokenizer.tokenizer import get_tokenizer_and_extra_tokens\n\n\n\nif __name__ =="
},
{
"path": "zh_llmprompt_script_gen.py",
"chars": 7031,
"preview": "# INPUT -> BRIEF -> SCRIPT\n\n\nINPUT2BRIEF = '''\n### 任务说明\n请按照以下结构总结输入文件, 普通文本格式。总结应当有创造性,保证信息全面,包含所有有趣、不常见、有价值的观点和信息。\n- **"
}
]
About this extraction
This page contains the full source code of the jzq2000/MoonCast GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 43 files (252.7 KB), approximately 65.2k tokens, and a symbol index with 298 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — a free GitHub repo-to-text converter for AI. Built by Nikandr Surkov.