Repository: jzq2000/MoonCast Branch: main Commit: 07037e58f427 Files: 43 Total size: 252.7 KB Directory structure: gitextract_g10rsitx/ ├── .gitignore ├── LICENSE ├── app.py ├── download_pretrain.py ├── en_llmprompt_script_gen.py ├── inference.py ├── modules/ │ ├── audio_detokenizer/ │ │ ├── audio_detokenizer.py │ │ ├── bigvgan_wrapper.py │ │ ├── flow_matching/ │ │ │ ├── dit_block.py │ │ │ ├── model.py │ │ │ ├── ode_wrapper.py │ │ │ └── scheduler.py │ │ ├── semantic_fm_prefix_streaming.py │ │ └── vocoder/ │ │ ├── activations.py │ │ ├── alias_free_activation/ │ │ │ ├── __init__.py │ │ │ ├── cuda/ │ │ │ │ ├── __init__.py │ │ │ │ ├── activation1d.py │ │ │ │ ├── anti_alias_activation.cpp │ │ │ │ ├── anti_alias_activation_cuda.cu │ │ │ │ ├── compat.h │ │ │ │ ├── load.py │ │ │ │ └── type_shim.h │ │ │ └── torch/ │ │ │ ├── __init__.py │ │ │ ├── act.py │ │ │ ├── filter.py │ │ │ └── resample.py │ │ ├── bigvgan.py │ │ └── utils.py │ ├── audio_tokenizer/ │ │ ├── audio_tokenizer.py │ │ ├── quantize/ │ │ │ ├── __init__.py │ │ │ ├── factorized_vector_quantize.py │ │ │ ├── residual_vq.py │ │ │ └── vector_quantize.py │ │ ├── rep_codec.py │ │ ├── transformer.py │ │ └── vocos.py │ └── tokenizer/ │ └── tokenizer.py ├── readme.md ├── requirements.txt ├── test/ │ ├── test_audio_detokenizer.py │ ├── test_audio_tokenizer.py │ └── test_tokenizer.py └── zh_llmprompt_script_gen.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ *.safetensors *.pt *.vscode **/__pycache__/ modules/audio_detokenizer/vocoder/alias_free_activation/cuda/build/ tmp* resources/ *.gradio ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2025 Zeqian Ju Permission is hereby granted, free of charge, to any person obtaining a copy of this software and 
associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: app.py ================================================ import gradio as gr from huggingface_hub import snapshot_download snapshot_download(repo_id="jzq11111/mooncast", local_dir='./resources/') from inference import Model import base64 model = Model() model.generate_config.max_new_tokens = 50 * 50 # no more than 50s per turn def process_json_and_generate_audio(prompt_audio_role0_file, prompt_text_role0, prompt_audio_role1_file, prompt_text_role1, json_dialogue_input_str): try: print(json_dialogue_input_str, type(json_dialogue_input_str)) print(prompt_audio_role0_file, prompt_text_role0, prompt_audio_role1_file, prompt_text_role1) # json_data = json.loads(json_dialogue_input_str) json_data = eval(json_dialogue_input_str.strip()) print(json_data, type(json_data)) def validate_json(data): try: if not isinstance(data, list): return "json must be a dictionary" cur_spk_should_be = 0 for item in data: if item['role'] != str(cur_spk_should_be): return f"role should be {cur_spk_should_be} in item {item}" 
cur_spk_should_be = 1 - cur_spk_should_be return None except Exception as e: return str(e) validation_error = validate_json(json_data) if validation_error: raise gr.Error(validation_error) role_mapping = { "0": { "ref_audio": prompt_audio_role0_file, "ref_text": prompt_text_role0, }, "1": { "ref_audio": prompt_audio_role1_file, "ref_text": prompt_text_role1, } } # 完整输入 JSON (你需要根据你的模型调整) model_input_json = { "role_mapping": role_mapping, "dialogue": json_data, # 从用户输入的 JSON 中获取 dialogue } print("模型推理输入 JSON:", model_input_json) # 4. **[重要] 调用你的 Model 类的 `inference` 方法** # audio_bytes = model.inference(model_input_json) # 5. 返回音频 bytes 给 Gradio (Gradio 会自动处理音频 bytes 并播放) # return base64.b64decode(audio_bytes) for cur_chunk in model.inference(model_input_json, streaming=True): yield base64.b64decode(cur_chunk) except Exception as e: # return str(e) # 返回错误信息给 Gradio raise gr.Error(str(e)) title_en = "# PODCAST generator (supports English and Chinese)" title_zh = "# 播客生成 (支持英文和中文)" instruct_en = "## See [Github](https://github.com/jzq2000/MoonCast) for podcast script generation." instruct_zh = "## 播客剧本生成请参考 [Github](https://github.com/jzq2000/MoonCast)。" input_labels_en = ["Prompt Audio for Role 0", "Prompt Text for Role 0", "Prompt Audio for Role 1", "Prompt Text for Role 1", "Script JSON Input"] input_labels_zh = ["角色 0 的 Prompt 音频", "角色 0 的 Prompt 文本", "角色 1 的 Prompt 音频", "角色 1 的 Prompt 文本", "剧本 JSON 输入"] output_label_en = "Generated Audio Output (streaming)" output_label_zh = "生成的音频输出(流式)" example_prompt_text_role0_en = "Yeah, no, this is my backyard. It's never ending So just the way I like it. So social distancing has never been a problem." example_prompt_text_role0_zh = "可以每天都骑并且可能会让你爱上骑车,然后通过爱上骑车的你省了很多很多钱。" example_prompt_text_role1_en = "I'm doing great And. Look, it couldn't be any better than having you at your set, which is the outdoors." 
example_prompt_text_role1_zh = "他最后就能让同样食材炒出来的菜味道大大提升。" text_placeholder_zh = "对话轮流进行, 每轮最多50秒。文本越自然, 生成的音频效果越好。" text_placeholder_en = "Dialogue alternates between roles. Limit each turn to a maximum of 50 seconds. The more natural the text, the better the generated audio." example_json_en = '''[ { "role": "0", "text": "In an awesome time, And, we're even gonna do a second episode too So. This is part one part two, coming at some point in the future There. We are.", }, { "role": "1", "text": "I love it. So grateful Thank you So I'm really excited. That's awesome. Yeah.", }, { "role": "0", "text": "All I was told, which is good because I don't want to really talk too much more is that you're really really into fitness and nutrition And overall holistic I love it Yes.", }, { "role": "1", "text": "Yeah So I started around thirteen Okay But my parents were fitness instructors as well. Awesome So I came from the beginning, and now it's this transition into this wholeness because I had to chart my. Own path and they weren't into nutrition at all So I had to learn that part." 
} ]''' example_json_zh = '''[ { "role": "0", "text": "我觉得啊,就是经历了这么多年的经验, 就是补剂的作用就是九分的努力, 十分之一的补剂。 嗯,选的话肯定是九分更重要,但是我觉得补剂它能够让你九分的努力更加的有效率,更加的避免徒劳无功。 嗯,就是你,你你得先得真的锻炼,真的努力,真的健康饮食,然后再考虑补剂, 那你再加十十分之一的补剂的话,他可能就是说啊, 一半是心理作用," }, { "role": "1", "text": "对,其实很多时候心理作用是非常重要的。嗯,然后我每次用补剂的时候,我就会更加努力,就比如说我在健身之前我喝了一勺蛋白粉,我就会督促自己多练," }, { "role": "0", "text": "其实心理作用只要能实现你的预期目的就可以了。 就比如说给自行车链条加油, 它其实不是必要的,但是它可以让你骑行更顺畅, 然后提高你骑行的频率。" } ] ''' # examples_en = [ # ['./en_prompt0.wav', example_prompt_text_role0_en, './en_prompt1.wav', example_prompt_text_role1_en, example_json_en] # ] # examples_zh = [ # ['./zh_prompt0.wav', example_prompt_text_role0_zh, './zh_prompt1.wav', example_prompt_text_role1_zh, example_json_zh] # ] examples = [ ['./en_prompt0.wav', example_prompt_text_role0_en, './en_prompt1.wav', example_prompt_text_role1_en, example_json_en], ['./zh_prompt0.wav', example_prompt_text_role0_zh, './zh_prompt1.wav', example_prompt_text_role1_zh, example_json_zh] ] # -------------------- 更新界面元素的函数 -------------------- def update_ui_language(language): if language == "English": return gr.update(value=title_en), \ gr.update(value=instruct_en), \ gr.update(label="UI Language"), \ gr.update(label=input_labels_en[0]), \ gr.update(label=input_labels_en[1]), \ gr.update(label=input_labels_en[2]), \ gr.update(label=input_labels_en[3]), \ gr.update(label=input_labels_en[4], placeholder=text_placeholder_en), \ gr.update(label=output_label_en), \ gr.update(value="Submit"), \ gr.update(label="Examples (Demonstration Use Only. 
Do Not Redistribute.)", headers=input_labels_en) elif language == "中文": return gr.update(value=title_zh), \ gr.update(value=instruct_zh), \ gr.update(label="UI 语言"), \ gr.update(label=input_labels_zh[0]), \ gr.update(label=input_labels_zh[1]), \ gr.update(label=input_labels_zh[2]), \ gr.update(label=input_labels_zh[3]), \ gr.update(label=input_labels_zh[4], placeholder=text_placeholder_zh), \ gr.update(label=output_label_zh), \ gr.update(value="提交"), \ gr.update(label="示例 (仅用于展示,切勿私自传播。)", headers=input_labels_zh) else: raise ValueError("Invalid language selected") audio_output = gr.Audio(label=output_label_en, streaming=True) css = """ .centered-title { /* CSS rule for centering title */ text-align: center !important; } """ # -------------------- Gradio 界面定义 (修改) -------------------- with gr.Blocks(css=css) as iface: title_output = gr.Markdown(value=title_zh, elem_classes="centered-title") instruct_output = gr.Markdown(value=instruct_zh) language_choice = gr.Radio(["中文", "English"], value="中文", label="UI语言") with gr.Row(): # Main row to create two columns with gr.Column(scale=2): json_input = gr.TextArea(label=input_labels_zh[4], lines=15, placeholder=text_placeholder_zh) # Dialogue JSON Input with gr.Column(scale=1): # Right column (narrower - scale=1) for prompt inputs audio_input_role0 = gr.Audio(type="filepath", label=input_labels_zh[0]) # Prompt Audio for Role 0 text_input_role0 = gr.TextArea(label=input_labels_zh[1], lines=2) # Prompt Text for Role 0 with gr.Column(scale=1): # audio_input_role1 = gr.Audio(type="filepath", label=input_labels_zh[2]) # Prompt Audio for Role 1 text_input_role1 = gr.TextArea(label=input_labels_zh[3], lines=2) # Prompt Text for Role 1 examples_component = gr.Examples( examples=examples, inputs=[audio_input_role0, text_input_role0, audio_input_role1, text_input_role1, json_input], cache_examples=False, label="示例(仅用于展示,切勿私自传播。)", ) submit_button = gr.Button("提交") submit_button.click( fn=process_json_and_generate_audio, 
inputs=[audio_input_role0, text_input_role0, audio_input_role1, text_input_role1, json_input], outputs=audio_output ) audio_output.render() language_choice.change( fn=update_ui_language, inputs=language_choice, outputs=[title_output, instruct_output, language_choice, audio_input_role0, text_input_role0, audio_input_role1, text_input_role1, json_input, audio_output, submit_button, examples_component.dataset] ) iface.launch(share=True) ================================================ FILE: download_pretrain.py ================================================ from huggingface_hub import snapshot_download snapshot_download(repo_id="jzq11111/mooncast", local_dir='./resources/') ================================================ FILE: en_llmprompt_script_gen.py ================================================ # INPUT -> BRIEF -> SCRIPT INPUT2BRIEF = ''' ### Task Description Please summarize the input document in plain text format according to the following structure. The summary should be creative, comprehensive, and include all interesting, uncommon, and valuable viewpoints and information. - **Text Requirements**: 1. Directly output the result without any additional information. 2. The summary should be in English. Retain a small number of proper nouns, names, and abbreviations in their original form (e.g., Chinese characters). 3. Do not include any mathematical formulas. 4. Do not alter any proper nouns, names, or abbreviations from the original text. Unless there is a common translation, do not translate proper nouns. Do not attempt to modify the meaning of proper nouns. 5. **Intelligently convert numbers in abbreviations. For example, "a2b" should be interpreted as "a to b," not "a two b"; "a4b" as "a for b," not "a four b"; "v2" may represent "version two" or "second generation." Provide the original abbreviation and your suggested English translation.** ### Title and Author - **Language Requirements**: English, formal written language. 
- **Content Requirements**: Provide the title and author of the document. Briefly summarize the theme of the document and the author's background. Ensure all important information is included without omission and sufficient context is retained. ### Abstract - **Language Requirements**: English, formal written language. - **Content Requirements**: 1. What this document has done. 2. Whether similar work has been done before. 3. If similar work exists, why this document is still necessary. 4. How this document specifically addresses the topic. 5. How well this document achieves its goals. - **Additional Requirements**: Include an additional paragraph to explain any terms, concepts, or methods that may confuse readers unfamiliar with the field. Ensure proper nouns are explained consistently with the original text, covering all potential points of confusion, including abbreviations and entity names. ### Main Themes and Concepts - **Language Requirements**: English, formal written language. - **Content Requirements**: Each theme and concept should be organized according to the 3W principle: - **What**: Clearly define the problem. - **Why**: Analyze the problem and identify its root causes. - **How**: Explain how the document addresses the problem. - **Additional Requirements**: 1. Ensure each theme and concept is comprehensive and includes all important details. Fully elaborate on the "What" and "Why" sections. 2. Avoid technical details such as mathematical formulas in the "How" section. Use language that is easily understood by a general audience. 3. Ensure themes and concepts do not overlap and maintain clear logic. 4. Include an additional paragraph to explain any terms, concepts, or methods that may confuse readers unfamiliar with the field. Ensure proper nouns are explained consistently with the original text, covering all potential points of confusion, including abbreviations and entity names. 
### Key Citations - **Language Requirements**: English, formal written language. - **Content Requirements**: Organize the content according to the following structure: 1. **Argument**: State what needs to be proven. 2. **Evidence**: Provide the material used to support the argument. 3. **Reasoning**: Describe the process of using evidence to prove the argument. - **Additional Requirements**: 1. Ensure all evidence and reasoning are directly sourced from the original text without fabrication. 2. Ensure citation content is complete and retains sufficient context without simplification. Avoid using mathematical formulas in citations. 3. Include an additional paragraph to explain any terms, concepts, or methods that may confuse readers unfamiliar with the field. Ensure proper nouns are explained consistently with the original text, covering all potential points of confusion, including abbreviations and entity names. ### Conclusion - **Language Requirements**: English, formal written language. - **Content Requirements**: Highlight the most important and impactful aspects of the document. Compared to the abstract, this section should provide more detailed insights related to the main themes and concepts. It may also include future directions for improvement, current application scenarios, and existing challenges. ''' BRIEF2SCRIPT = ''' ## 1. Task Overview Please generate a lively English podcast script based on the provided English summary text and your knowledge of the topic. The script should feature a dialogue between two speakers who take turns speaking. Output format should be JSON-parsable **list**. Each speaker's turn is a **dictionary** containing "speaker" and "text" fields. Example format: `[{{"speaker": "1", "text": "xxx"}}]`. The "speaker" field indicates the speaker's identity (1 for host, 2 for guest), and the "text" field is the spoken content. Output should start directly with the JSON code block, without any extra information. ## 2. 
Content and Structure ### (1) Text Content - The summary text contains all important information, which needs to be comprehensively selected and incorporated into the script. - Present information through a dialogue between two speakers, maintaining creativity and abstracting away unimportant details. For example, listeners aren't concerned with specific test names, but rather the task itself, the results, and the analysis. ### (2) Structure Design - **Opening:** Introduce the topic and briefly describe the discussion content, without mentioning speaker names. - **Key Theme Discussion:** Discuss important themes based on the summary text. Expand on the summary, don't just repeat it verbatim. - **Closing:** Briefly recap the discussion highlights and offer an outlook on future or technological developments. ## 3. Language Style ### (1) Conversational Style - The text should be as conversational as possible, aiming for a style similar to automatic speech recognition output. Include filler words such as 'um,' 'uh,' 'like,' 'you know,' 'so,' 'right?', and so on. Response words such as 'Yeah,' 'Right,' 'Okay,' and similar. Conversational expressions, repetitions, informal grammar, etc. Use short sentences. Avoid directly copying and pasting structured text from the summary text. Parentheses and other symbols not typically found in speech recognition transcripts should be avoided. Spaces within sentences indicate pauses. Be aware that there might be homophone errors, potentially due to accents. Questions should sound very conversational. Pay particular attention to incorporating conversational details, especially in questions. For example: [ {{ "speaker": "1", "text": "Welcome back to the podcast, everyone. Today we're diving into, uh, something that's really changing everything around us, A I." }}, {{ "speaker": "2", "text": "Yeah, A I is, like, everywhere now, isn't it? It's kinda wild to think about." }}, {{ "speaker": "1", "text": "Totally. 
And we're seeing it in so many areas of daily life. Like, even just recommending what to watch, or, you know, suggesting products online." }}, {{ "speaker": "2", "text": "Mhm, exactly. And it's not just online stuff, right? Think about smart homes, or even self-driving cars. It's getting pretty advanced." }}, {{ "speaker": "1", "text": "Right, self-driving cars are still a bit futuristic for most of us, but, uh, even things like voice assistants on our phones, that's A I, isn't it?" }}, {{ "speaker": "2", "text": "Definitely. Siri, Alexa, Google Assistant, all powered by A I. It's become so normal, we almost don't even think about it anymore." }}, {{ "speaker": "1", "text": "Yeah, it's like, integrated into everything. But is that a good thing, you think? Like, are there downsides to all this A I in our lives?" }}, {{ "speaker": "2", "text": "Well, that's the big question, isn't it? On the one hand, it makes things so much more convenient, saves us time, maybe even makes things safer in some ways." }}, {{ "speaker": "1", "text": "Safer how?" }}, {{ "speaker": "2", "text": "Uh, well, like in healthcare, for example. A I can help doctors diagnose diseases earlier, maybe even more accurately. That's a huge plus, right?" }}, {{ "speaker": "1", "text": "Yeah, that's a really good point. Medical applications are definitely exciting. But what about the concerns, you know? Like job displacement or privacy issues?" }}, {{ "speaker": "2", "text": "Right, those are super valid concerns. Job displacement is a big one. If A I can do more and more tasks, what happens to human workers? And privacy," }}, {{ "speaker": "1", "text": "And privacy is huge, especially with all the data A I systems collect. It's a lot to process." }}, {{ "speaker": "2", "text": "Exactly. So, it's not just sunshine and roses, is it? We need to be mindful of the ethical implications and make sure it's used responsibly." }}, {{ "speaker": "1", "text": "Definitely. 
It's a powerful tool, but like any tool, it can be used for good or, you know, not so good. It's up to us to guide its development, right?" }}, {{ "speaker": "2", "text": "Absolutely. And that's a conversation we all need to be part of, not just the tech people, but everyone." }} ] ### (2) Punctuation - Use English punctuation marks. Avoid using other punctuation marks beyond commas, periods, and question marks. Exclamation points are prohibited. Ellipses ('…'), parentheses, quotation marks (including ‘ ' “ ” ") or dashes are prohibited, otherwise it will be considered unqualified. do not use markdown syntax. For example,**bold** or *italic* text should be avoided. Use plain text only. - If interrupted by the other person's response, the sentence should end with a comma, not a period. ## 4. Information Organization and Logic ### (1) Referencing Issues - Given that listeners won't have access to the summary text, any references must provide sufficient context for comprehension. - Avoid simply paraphrasing; instead, explain referenced content in your own words. - Explanations of technical terms should be creative and avoid simply stating 'this means what?' You can use examples, metaphors, and so on for explanations, but ensure you also clarify the rationale behind the metaphor. Explanations can be provided in response to a question from the other speaker, or you can offer explanations proactively. Technical terms that are not mentioned don't need explanation. Technical terms that are mentioned don't necessarily need immediate explanation; they can be explained alongside other technical terms. Technical terms in the summary text might differ slightly from the surrounding text; you'll need to provide reasonable explanations based on the context. ### (2) Information Density - Ensure moderate information density, avoiding excessively high or low density. 
The goal of appropriate information density is to enable listeners without prior knowledge to quickly grasp the document's purpose, rationale, and methodology. - To prevent information overload, the script should avoid delving into details like mathematical formulas, test setups, or specific experimental metrics. Instead, it should use simple, generalized language for descriptions. - To avoid excessively low information density, ensure each topic is discussed for at least 4 speaker turns, moving beyond simple keyword listings. Discuss topics from multiple angles whenever possible, going beyond the provided summary text. Given that the summary text is highly generalized, the script should elaborate on it and discuss further details. Feel free to use your knowledge to supplement background information, provide examples, and so forth, to enhance listener understanding. - Techniques to increase information density: 1. Incorporate memorable quotes. Add impactful, attention-grabbing sentences to the script, either original ones or quotes from other sources. 2. Boost knowledge content. Judiciously add knowledge points to the script to make listeners feel more informed and rewarded. 3. Introduce novel information. Incorporate new concepts to spark listener curiosity, particularly information they're unaware of but would find valuable. This is crucial. 4. Employ reverse thinking. Include information from diverse angles, challenging listeners' existing perspectives and presenting alternative viewpoints. 5. Generate contrast and impact. The script can offer unconventional (yet plausible) descriptions of familiar concepts to create a contrast with listener expectations. This contrast contributes to information density. - Techniques to decrease information density: 1. Use short sentences: Concise and easy to understand, making the narrative more compact. Do not have too much information in one sentence. 2. 
Describe details: Vague and abstract information makes it difficult for listeners to build understanding, while more details create a sense of imagery and are easier to read. 3. Use more scenario-based descriptions: Scenarios are concrete and visual. Listeners can easily receive the conveyed information and be emotionally touched. 4. Talk more about facts: Talking about facts makes it more real, and readers can empathize more, thus lowering the information density of the copy. 5. Tell more stories: Tell your own stories, stories around you, and stories you've heard. Stories can bring listeners into the scene, making it easier to concentrate on listening. 6. Use more verbs and concrete nouns: Verbs and concrete nouns make it easier for listeners to visualize, while adjectives make complex copy harder to understand. 7. Avoid using mathematical formulas: Mathematical formulas are not conducive to public understanding. ## 5. Dialogue Design ### (1) Speaker Roles - The script includes a host and a guest. Speaker 1 is the host, responsible for opening and closing the show, skilled at using questions to control the pace of the conversation, and using vivid examples to make knowledge less dry. Speaker 2 is the guest, primarily responsible for introducing the document content, has amazing knowledge reserves in the field, and is good at organizing language in a structured and easy-to-understand way. - Both speakers are enthusiastic and cheerful, like to combine personal stories or examples for discussion, and bring a direct experience to listeners. They are happy to discuss digressive stories. - The two speakers actively interact and frequently use interruption words such as "um" to indicate agreement with each other. Response words need to be inserted into the dialogue according to the timing. Sentences before being interrupted end with a comma, not a period. - Ensure consistent speaker roles. 
Do not have the host introduce technical details, or have the guest guide the host to discuss topics. - The host gradually increases their understanding of the field based on the guest's answers. However, the host may not understand immediately or completely correctly. The host can express misunderstanding or raise some questions that ordinary people might have. In this case, the guest will further explain in more accessible language, or specifically answer common questions or misunderstandings. This kind of interaction is more realistic and easier for listeners to understand than always correct hosts and guests. ### (2) Topic Order Arrangement - The host will arrange the topics according to the summary text and ensure logical connections between topics, such as transitioning from overall to details, from details to overall, from cause to effect, from technology to application, etc. - The host will guide the pace of the conversation and discuss topics in the order of the summary text. Guests should not interfere with topic transitions. ### (3) Knowledge Rate - The knowledge rate in the script needs to be reasonable. Do not introduce a large amount of knowledge too quickly in a short period of time. Knowledge ## 6. Other Requirements ### (1) English Numbers and Foreign Words 1. The script will be used for English podcast content recording. Please ensure most numbers and foreign words are rendered naturally in English to facilitate correct pronunciation. 2. Please intelligently determine the correct pronunciation according to the context. For example, "2021" if expressing a year, should be converted to "two thousand and twenty-one" or "twenty twenty-one". But if expressing a number, it should be "two thousand and twenty-one". 
For some uncommon English abbreviations, if the pronunciation needs to be read letter by letter according to the context, you must ensure that there is a space between each letter, such as "AI" adding a space as "A I", to avoid the model misinterpreting it as a word. For example, "API" should be rendered as "A P I". 3. Small amount of Chinese is allowed, especially for nouns, if it fits naturally within the conversational English context. ### (2) Script Length 1. Please ensure that the total length of the 'text' values does not exceed 3,000 words and the number of speaker turns is kept within 60, otherwise it will be unqualified. Please choose technical details and topic concepts to discuss. Do not shorten the depth of discussion on each topic for the sake of word limit, do not be limited to the summary text, and give full play to your knowledge. INPUT: {BRIEF} ## Re-emphasize: Speaker 1 is the host, and Speaker 2 is the guest. Neither speaker has a name. The script text only uses commas, periods, and question marks. Use English punctuation marks. Avoid using other punctuation marks beyond commas, periods, and question marks. Exclamation points are prohibited. Ellipses ('…'), parentheses, quotation marks (including ‘ ' “ ” ") or dashes are prohibited, otherwise it will be considered unqualified. Please prioritize in-depth discussion for each topic. Don't limit yourself to the summary text; instead, use your knowledge to expand upon the topics, providing background information and illustrative examples to enhance listener understanding. Ensure that numbers and foreign words are rendered naturally in English for accurate pronunciation during recording. In technical contexts, English abbreviations sometimes use numerical digits in place of words (e.g., "a2b" for "a to b," "a4b" for "a for b"). Please translate these abbreviations into appropriate English phrases based on the context. 
While the script is primarily in English, a small amount of Chinese, especially for nouns, is acceptable if it integrates naturally into the conversational flow. OUTPUT: ''' ================================================ FILE: inference.py ================================================ import sys sys.path.append(".") from modules.tokenizer.tokenizer import get_tokenizer_and_extra_tokens from modules.audio_tokenizer.audio_tokenizer import get_audio_tokenizer from modules.audio_detokenizer.audio_detokenizer import get_audio_detokenizer, detokenize, detokenize_noref, detokenize_streaming, detokenize_noref_streaming import torch import os from glob import glob import base64 import io import torchaudio from transformers import AutoModelForCausalLM, GenerationConfig import librosa from tqdm import tqdm from pydub import AudioSegment class Model(object): def __init__(self): self.tokenizer, self.extra_tokens = get_tokenizer_and_extra_tokens() self.speech_token_offset = 163840 print(self.extra_tokens) self.assistant_ids = self.tokenizer.encode("assistant") # [110866] self.user_ids = self.tokenizer.encode("user") # [1495] self.audio_ids = self.tokenizer.encode("audio") # [26229] self.spk_0_ids = self.tokenizer.encode("0") # [501] self.spk_1_ids = self.tokenizer.encode("1") # [503] self.msg_end = self.extra_tokens.msg_end # 260 self.user_msg_start = self.extra_tokens.user_msg_start # 261 self.assistant_msg_start = self.extra_tokens.assistant_msg_start # 262 self.name_end = self.extra_tokens.name_end # 272 self.media_begin = self.extra_tokens.media_begin # 273 self.media_content = self.extra_tokens.media_content # 274 self.media_end = self.extra_tokens.media_end # 275 self.audio_tokenizer = get_audio_tokenizer() self.audio_detokenizer = get_audio_detokenizer() model_path = "resources/text2semantic" self.model = AutoModelForCausalLM.from_pretrained(model_path, device_map="cuda:0", torch_dtype=torch.bfloat16, trust_remote_code=True, 
force_download=True).to(torch.cuda.current_device()) self.generate_config = GenerationConfig( max_new_tokens=200 * 50, # no more than 200s per turn do_sample=True, top_k=30, top_p=0.8, temperature=0.8, eos_token_id=self.media_end, ) def _clean_text(self, text): # you can add front-end processing here text = text.replace("“", "") text = text.replace("”", "") text = text.replace("...", " ") text = text.replace("…", " ") text = text.replace("*", "") text = text.replace(":", ",") text = text.replace("‘", "'") text = text.replace("’", "'") text = text.strip() return text @torch.inference_mode() def _process_text(self, js): if "role_mapping" in js: for role in js["role_mapping"].keys(): js["role_mapping"][role]["ref_bpe_ids"] = self.tokenizer.encode(self._clean_text(js["role_mapping"][role]["ref_text"])) for turn in js["dialogue"]: turn["bpe_ids"] = self.tokenizer.encode(self._clean_text(turn["text"])) return js def inference(self, js, streaming=False): js = self._process_text(js) if "role_mapping" not in js: if streaming: return self.infer_without_prompt_streaming(js) else: return self.infer_without_prompt(js) else: if streaming: return self.infer_with_prompt_streaming(js) else: return self.infer_with_prompt(js) @torch.inference_mode() def infer_with_prompt(self, js): user_role_0_ids = [self.user_msg_start] + self.user_ids + self.spk_0_ids + [self.name_end] user_role_1_ids = [self.user_msg_start] + self.user_ids + self.spk_1_ids + [self.name_end] assistant_role_0_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_0_ids + [self.name_end] assistant_role_1_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_1_ids + [self.name_end] media_start = [self.media_begin] + self.audio_ids + [self.media_content] media_end = [self.media_end] + [self.msg_end] assistant_role_0_ids = torch.LongTensor(assistant_role_0_ids).unsqueeze(0).to(torch.cuda.current_device()) assistant_role_1_ids = 
torch.LongTensor(assistant_role_1_ids).unsqueeze(0).to(torch.cuda.current_device()) media_start = torch.LongTensor(media_start).unsqueeze(0).to(torch.cuda.current_device()) media_end = torch.LongTensor(media_end).unsqueeze(0).to(torch.cuda.current_device()) prompt = [] cur_role_dict = dict() for role, role_item in js["role_mapping"].items(): waveform_24k = librosa.load(role_item["ref_audio"], sr=24000)[0] waveform_24k = torch.tensor(waveform_24k).unsqueeze(0).to(torch.cuda.current_device()) waveform_16k = librosa.load(role_item["ref_audio"], sr=16000)[0] waveform_16k = torch.tensor(waveform_16k).unsqueeze(0).to(torch.cuda.current_device()) semantic_tokens = self.audio_tokenizer.tokenize(waveform_16k) semantic_tokens = semantic_tokens.to(torch.cuda.current_device()) prompt_ids = semantic_tokens + self.speech_token_offset cur_role_dict[role] = { "ref_bpe_ids": role_item["ref_bpe_ids"], "wav_24k": waveform_24k, "semantic_tokens": semantic_tokens, "prompt_ids": prompt_ids } prompt = prompt + user_role_0_ids + cur_role_dict["0"]["ref_bpe_ids"] + [self.msg_end] prompt = prompt + user_role_1_ids + cur_role_dict["1"]["ref_bpe_ids"] + [self.msg_end] for seg_id, turn in enumerate(js["dialogue"]): role_id = turn["role"] cur_user_ids = user_role_0_ids if role_id == "0" else user_role_1_ids cur_start_ids = cur_user_ids + turn["bpe_ids"] + [self.msg_end] prompt = prompt + cur_start_ids prompt = torch.LongTensor(prompt).unsqueeze(0).to(torch.cuda.current_device()) prompt = torch.cat([prompt, assistant_role_0_ids, media_start, cur_role_dict["0"]["prompt_ids"], media_end], dim=-1) prompt = torch.cat([prompt, assistant_role_1_ids, media_start, cur_role_dict["1"]["prompt_ids"], media_end], dim=-1) generation_config = self.generate_config # you can modify sampling strategy here wav_list = [] for seg_id, turn in tqdm(enumerate(js["dialogue"])): role_id = turn["role"] cur_assistant_ids = assistant_role_0_ids if role_id == "0" else assistant_role_1_ids prompt = torch.cat([prompt, 
cur_assistant_ids, media_start], dim=-1) len_prompt = prompt.shape[1] generation_config.min_length = len_prompt + 2 # print(generation_config) # todo: add streaming support for generate function outputs = self.model.generate(prompt, generation_config=generation_config) if outputs[0, -1] == self.media_end: outputs = outputs[:, :-1] output_token = outputs[:, len_prompt:] prompt = torch.cat([outputs, media_end], dim=-1) torch_token = output_token - self.speech_token_offset gen_speech_fm = detokenize(self.audio_detokenizer, torch_token, cur_role_dict[role_id]["wav_24k"], cur_role_dict[role_id]["semantic_tokens"]) gen_speech_fm = gen_speech_fm.cpu() gen_speech_fm = gen_speech_fm / gen_speech_fm.abs().max() wav_list.append(gen_speech_fm) del torch_token concat_wav = torch.cat(wav_list, dim=-1).cpu() # print(concat_wav.shape) buffer = io.BytesIO() torchaudio.save(buffer, concat_wav, sample_rate=24000, format="mp3") audio_bytes = buffer.getvalue() audio_b64 = base64.b64encode(audio_bytes).decode("utf-8") return audio_b64 @torch.inference_mode() def infer_with_prompt_streaming(self, js): user_role_0_ids = [self.user_msg_start] + self.user_ids + self.spk_0_ids + [self.name_end] user_role_1_ids = [self.user_msg_start] + self.user_ids + self.spk_1_ids + [self.name_end] assistant_role_0_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_0_ids + [self.name_end] assistant_role_1_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_1_ids + [self.name_end] media_start = [self.media_begin] + self.audio_ids + [self.media_content] media_end = [self.media_end] + [self.msg_end] assistant_role_0_ids = torch.LongTensor(assistant_role_0_ids).unsqueeze(0).to(torch.cuda.current_device()) assistant_role_1_ids = torch.LongTensor(assistant_role_1_ids).unsqueeze(0).to(torch.cuda.current_device()) media_start = torch.LongTensor(media_start).unsqueeze(0).to(torch.cuda.current_device()) media_end = torch.LongTensor(media_end).unsqueeze(0).to(torch.cuda.current_device()) 
prompt = [] cur_role_dict = dict() for role, role_item in js["role_mapping"].items(): waveform_24k = librosa.load(role_item["ref_audio"], sr=24000)[0] waveform_24k = torch.tensor(waveform_24k).unsqueeze(0).to(torch.cuda.current_device()) waveform_16k = librosa.load(role_item["ref_audio"], sr=16000)[0] waveform_16k = torch.tensor(waveform_16k).unsqueeze(0).to(torch.cuda.current_device()) semantic_tokens = self.audio_tokenizer.tokenize(waveform_16k) semantic_tokens = semantic_tokens.to(torch.cuda.current_device()) prompt_ids = semantic_tokens + self.speech_token_offset cur_role_dict[role] = { "ref_bpe_ids": role_item["ref_bpe_ids"], "wav_24k": waveform_24k, "semantic_tokens": semantic_tokens, "prompt_ids": prompt_ids } prompt = prompt + user_role_0_ids + cur_role_dict["0"]["ref_bpe_ids"] + [self.msg_end] prompt = prompt + user_role_1_ids + cur_role_dict["1"]["ref_bpe_ids"] + [self.msg_end] for seg_id, turn in enumerate(js["dialogue"]): role_id = turn["role"] cur_user_ids = user_role_0_ids if role_id == "0" else user_role_1_ids cur_start_ids = cur_user_ids + turn["bpe_ids"] + [self.msg_end] prompt = prompt + cur_start_ids prompt = torch.LongTensor(prompt).unsqueeze(0).to(torch.cuda.current_device()) prompt = torch.cat([prompt, assistant_role_0_ids, media_start, cur_role_dict["0"]["prompt_ids"], media_end], dim=-1) prompt = torch.cat([prompt, assistant_role_1_ids, media_start, cur_role_dict["1"]["prompt_ids"], media_end], dim=-1) generation_config = self.generate_config # you can modify sampling strategy here wav_list = [] for seg_id, turn in tqdm(enumerate(js["dialogue"])): role_id = turn["role"] cur_assistant_ids = assistant_role_0_ids if role_id == "0" else assistant_role_1_ids prompt = torch.cat([prompt, cur_assistant_ids, media_start], dim=-1) len_prompt = prompt.shape[1] generation_config.min_length = len_prompt + 2 # print(generation_config) # todo: add streaming support for generate function outputs = self.model.generate(prompt, 
generation_config=generation_config) if outputs[0, -1] == self.media_end: outputs = outputs[:, :-1] output_token = outputs[:, len_prompt:] prompt = torch.cat([outputs, media_end], dim=-1) torch_token = output_token - self.speech_token_offset for cur_chunk in detokenize_streaming(self.audio_detokenizer, torch_token, cur_role_dict[role_id]["wav_24k"], cur_role_dict[role_id]["semantic_tokens"]): cur_chunk = cur_chunk.cpu() cur_chunk = cur_chunk / cur_chunk.abs().max() cur_buffer = io.BytesIO() torchaudio.save(cur_buffer, cur_chunk, sample_rate=24000, format="mp3") audio_bytes = cur_buffer.getvalue() audio_b64 = base64.b64encode(audio_bytes).decode("utf-8") yield audio_b64 @torch.inference_mode() def infer_without_prompt(self, js): user_role_0_ids = [self.user_msg_start] + self.user_ids + self.spk_0_ids + [self.name_end] user_role_1_ids = [self.user_msg_start] + self.user_ids + self.spk_1_ids + [self.name_end] assistant_role_0_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_0_ids + [self.name_end] assistant_role_1_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_1_ids + [self.name_end] media_start = [self.media_begin] + self.audio_ids + [self.media_content] media_end = [self.media_end] + [self.msg_end] assistant_role_0_ids = torch.LongTensor(assistant_role_0_ids).unsqueeze(0).to(torch.cuda.current_device()) assistant_role_1_ids = torch.LongTensor(assistant_role_1_ids).unsqueeze(0).to(torch.cuda.current_device()) media_start = torch.LongTensor(media_start).unsqueeze(0).to(torch.cuda.current_device()) media_end = torch.LongTensor(media_end).unsqueeze(0).to(torch.cuda.current_device()) prompt = [] for seg_id, turn in enumerate(js["dialogue"]): role_id = turn["role"] cur_user_ids = user_role_0_ids if role_id == "0" else user_role_1_ids cur_start_ids = cur_user_ids + turn["bpe_ids"] + [self.msg_end] prompt = prompt + cur_start_ids prompt = torch.LongTensor(prompt).unsqueeze(0).to(torch.cuda.current_device()) generation_config = 
self.generate_config # you can modify sampling strategy here wav_list = [] for seg_id, turn in tqdm(enumerate(js["dialogue"])): role_id = turn["role"] cur_assistant_ids = assistant_role_0_ids if role_id == "0" else assistant_role_1_ids prompt = torch.cat([prompt, cur_assistant_ids, media_start], dim=-1) len_prompt = prompt.shape[1] generation_config.min_length = len_prompt + 2 # todo: add streaming support for generate function outputs = self.model.generate(prompt, generation_config=generation_config) if outputs[0, -1] == self.media_end: outputs = outputs[:, :-1] output_token = outputs[:, len_prompt:] prompt = torch.cat([outputs, media_end], dim=-1) torch_token = output_token - self.speech_token_offset gen_speech_fm = detokenize_noref(self.audio_detokenizer, torch_token) gen_speech_fm = gen_speech_fm.cpu() gen_speech_fm = gen_speech_fm / gen_speech_fm.abs().max() wav_list.append(gen_speech_fm) del torch_token concat_wav = torch.cat(wav_list, dim=-1).cpu() # print(concat_wav.shape) buffer = io.BytesIO() torchaudio.save(buffer, concat_wav, sample_rate=24000, format="mp3") audio_bytes = buffer.getvalue() audio_b64 = base64.b64encode(audio_bytes).decode("utf-8") return audio_b64 @torch.inference_mode() def infer_without_prompt_streaming(self, js): user_role_0_ids = [self.user_msg_start] + self.user_ids + self.spk_0_ids + [self.name_end] user_role_1_ids = [self.user_msg_start] + self.user_ids + self.spk_1_ids + [self.name_end] assistant_role_0_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_0_ids + [self.name_end] assistant_role_1_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_1_ids + [self.name_end] media_start = [self.media_begin] + self.audio_ids + [self.media_content] media_end = [self.media_end] + [self.msg_end] assistant_role_0_ids = torch.LongTensor(assistant_role_0_ids).unsqueeze(0).to(torch.cuda.current_device()) assistant_role_1_ids = torch.LongTensor(assistant_role_1_ids).unsqueeze(0).to(torch.cuda.current_device()) 
media_start = torch.LongTensor(media_start).unsqueeze(0).to(torch.cuda.current_device()) media_end = torch.LongTensor(media_end).unsqueeze(0).to(torch.cuda.current_device()) prompt = [] for seg_id, turn in enumerate(js["dialogue"]): role_id = turn["role"] cur_user_ids = user_role_0_ids if role_id == "0" else user_role_1_ids cur_start_ids = cur_user_ids + turn["bpe_ids"] + [self.msg_end] prompt = prompt + cur_start_ids prompt = torch.LongTensor(prompt).unsqueeze(0).to(torch.cuda.current_device()) generation_config = self.generate_config # you can modify sampling strategy here wav_list = [] for seg_id, turn in tqdm(enumerate(js["dialogue"])): role_id = turn["role"] cur_assistant_ids = assistant_role_0_ids if role_id == "0" else assistant_role_1_ids prompt = torch.cat([prompt, cur_assistant_ids, media_start], dim=-1) len_prompt = prompt.shape[1] generation_config.min_length = len_prompt + 2 # print(generation_config) # todo: add streaming support for generate function outputs = self.model.generate(prompt, generation_config=generation_config) if outputs[0, -1] == self.media_end: outputs = outputs[:, :-1] output_token = outputs[:, len_prompt:] prompt = torch.cat([outputs, media_end], dim=-1) torch_token = output_token - self.speech_token_offset for cur_chunk in detokenize_noref_streaming(self.audio_detokenizer, torch_token): cur_chunk = cur_chunk.cpu() cur_chunk = cur_chunk / cur_chunk.abs().max() cur_buffer = io.BytesIO() torchaudio.save(cur_buffer, cur_chunk, sample_rate=24000, format="mp3") audio_bytes = cur_buffer.getvalue() audio_b64 = base64.b64encode(audio_bytes).decode("utf-8") yield audio_b64 if __name__ == "__main__": model = Model() # speaker should be interleaved zh_test_json = { "role_mapping": { "0": { "ref_audio": "./zh_prompt0.wav", "ref_text": "可以每天都骑并且可能会让你爱上骑车,然后通过爱上骑车的你省了很多很多钱。", #asr output }, "1": { "ref_audio": "./zh_prompt1.wav", "ref_text": "他最后就能让同样食材炒出来的菜味道大大提升。" #asr output } }, "dialogue": [ { "role": "0", "text": "我觉得啊,就是经历了这么多年的经验, 
就是补剂的作用就是九分的努力, 十分之一的补剂。 嗯,选的话肯定是九分更重要,但是我觉得补剂它能够让你九分的努力更加的有效率,更加的避免徒劳无功。 嗯,就是你,你你得先得真的锻炼,真的努力,真的健康饮食,然后再考虑补剂, 那你再加十十分之一的补剂的话,他可能就是说啊, 一半是心理作用," }, { "role": "1", "text": "对,其实很多时候心理作用是非常重要的。嗯,然后我每次用补剂的时候,我就会更加努力,就比如说我在健身之前我喝了一勺蛋白粉,我就会督促自己多练," }, { "role": "0", "text": "其实心理作用只要能实现你的预期目的就可以了。 就比如说给自行车链条加油, 它其实不是必要的,但是它可以让你骑行更顺畅, 然后提高你骑行的频率。" } ] } audio_bytes_gen = model.inference(zh_test_json, streaming=True) audio = AudioSegment.empty() for cur_chunk in audio_bytes_gen: cur_chunk = base64.b64decode(cur_chunk) audio_chunk = AudioSegment.from_file(io.BytesIO(cur_chunk), format="mp3") audio += audio_chunk audio.export("tmp_generated_zh_stream.mp3", format="mp3") print("zh stream done") audio_bytes = model.inference(zh_test_json) file_to_save = open(f"tmp_generated_zh.mp3", "wb") file_to_save.write(base64.b64decode(audio_bytes)) print("zh done") # speaker should be interleaved en_test_json = { "role_mapping": { "0": { "ref_audio": "./en_prompt0.wav", "ref_text": "Yeah, no, this is my backyard. It's never ending So just the way I like it. So social distancing has never been a problem.", #asr output }, "1": { "ref_audio": "./en_prompt1.wav", "ref_text": "I'm doing great And. Look, it couldn't be any better than having you at your set, which is the outdoors." #asr output } }, "dialogue": [ { "role": "0", "text": "In an awesome time, And, we're even gonna do a second episode too So. This is part one part two, coming at some point in the future There. We are.", }, { "role": "1", "text": "I love it. So grateful Thank you So I'm really excited. That's awesome. Yeah." }, { "role": "0", "text": "All I was told, which is good because I don't want to really talk too much more is that you're really really into fitness and nutrition And overall holistic I love it Yes." }, { "role": "1", "text": "Yeah So I started around thirteen Okay But my parents were fitness instructors as well. 
Awesome So I came from the beginning, and now it's this transition into this wholeness because I had to chart my. Own path and they weren't into nutrition at all So I had to learn that part." } ] } audio_bytes = model.inference(en_test_json) file_to_save = open(f"tmp_generated_en.mp3", "wb") file_to_save.write(base64.b64decode(audio_bytes)) print("en done") # also support inference without prompt # speaker should be interleaved without_prompt_test_json = { "dialogue": [ { "role": "0", "text": "我觉得啊,就是经历了这么多年的经验, 就是补剂的作用就是九分的努力, 十分之一的补剂。 嗯,选的话肯定是九分更重要,但是我觉得补剂它能够让你九分的努力更加的有效率,更加的避免徒劳无功。 嗯,就是你,你你得先得真的锻炼,真的努力,真的健康饮食,然后再考虑补剂, 那你再加十十分之一的补剂的话,他可能就是说啊, 一半是心理作用," }, { "role": "1", "text": "对,其实很多时候心理作用是非常重要的。嗯,然后我每次用补剂的时候,我就会更加努力,就比如说我在健身之前我喝了一勺蛋白粉,我就会督促自己多练," }, { "role": "0", "text": "其实心理作用只要能实现你的预期目的就可以了。 就比如说给自行车链条加油, 它其实不是必要的,但是它可以让你骑行更顺畅, 然后提高你骑行的频率。" } ] } audio_bytes = model.inference(without_prompt_test_json) file_to_save = open(f"tmp_generated_woprompt.mp3", "wb") file_to_save.write(base64.b64decode(audio_bytes)) print("without prompt done") ================================================ FILE: modules/audio_detokenizer/audio_detokenizer.py ================================================ import torch from modules.audio_detokenizer.bigvgan_wrapper import BigVGANWrapper from modules.audio_detokenizer.semantic_fm_prefix_streaming import StreamingSemanticFMWrapper class PrefixStreamingFlowMatchingDetokenizer: def __init__(self, vocoder: BigVGANWrapper, fm: StreamingSemanticFMWrapper, look_ahead_tokens: int = 0) -> None: self.dtype = torch.bfloat16 print("Currently using bfloat16 for PrefixFlowMatchingDetokenizer") self.vocoder = vocoder self.vocoder.to_dtype(self.dtype) self.semantic_fm = fm # initialize mel_spec self.max_pos_size = 4096 self.is_timbre_semantic_token = False self.pre_mel = None self.frame_size = 480 # how many samples in a frame self.pre_wav = None self.state_dict_backup = None self.hamming_window_cache = {} self.previous_chunk_left = None 
self.look_ahead_tokens = look_ahead_tokens self.clear_states() @classmethod def from_pretrained(cls, vocoder_config, vocoder_ckpt, fm_config, fm_ckpt, device, look_ahead_tokens=0, max_prompt_chunk=2, max_kv_cache_tokens=900, use_cfg=False, use_cfg_rescale=True, cfg_init=1.5, cfg_scale=7.5, cfg_schedule="linear"): bigvgan = BigVGANWrapper.from_pretrained(vocoder_config, vocoder_ckpt, device) semantic_fm = StreamingSemanticFMWrapper.from_pretrained(fm_config, fm_ckpt, device, max_prompt_chunk=max_prompt_chunk, max_kv_cache_tokens=max_kv_cache_tokens, use_cfg=use_cfg, cfg_scale=cfg_scale, use_cfg_rescale=use_cfg_rescale, cfg_init=cfg_init, cfg_schedule=cfg_schedule) return cls(bigvgan, semantic_fm, look_ahead_tokens=look_ahead_tokens) @torch.inference_mode() def prefill(self, timbre_speech, timbre_semantic_token, chunk_size: int, timbre_mel=None): """ Arguments: timbre_speech: torch.Tensor, shape [B, N_speech_24k] timbre_semantic_token: torch.Tensor, shape [B, N] chunk_size: int, chunk size for prefilling timbre_mel: torch.Tensor, shape [B, N, 80], optional, if not None, use this mel spectrogram instead of extracting from timbre_speech """ if timbre_mel is None: assert timbre_speech is not None, "timbre_speech should not be None if timbre_mel is not None" assert len(timbre_semantic_token.shape) == 2 and len(timbre_speech.shape) == 2 and chunk_size > 0 assert timbre_speech.shape[0] == 1 and timbre_semantic_token.shape[0] == 1 mel_spec = self.vocoder.extract_mel_from_wav(wav_data=timbre_speech.squeeze(0)) else: assert len(timbre_mel.shape) == 3 and len(timbre_semantic_token.shape) == 2 and chunk_size > 0 assert timbre_mel.shape[0] == 1 and timbre_semantic_token.shape[0] == 1 mel_spec = timbre_mel.squeeze(0) if mel_spec.shape[0] < timbre_semantic_token.shape[1]: # pad mel_spec mel_spec = torch.nn.functional.pad(mel_spec, (0, 0, 0, timbre_semantic_token.shape[1] - mel_spec.shape[0])) elif mel_spec.shape[0] > timbre_semantic_token.shape[1]: # truncate mel_spec mel_spec = 
mel_spec[:timbre_semantic_token.shape[1], :] # clear all states self.semantic_fm.clear_all_states() self.semantic_fm.prefill(mel_spec, timbre_semantic_token.squeeze(0), chunk_size=chunk_size, verbose=False) self.state_dict_backup = self.semantic_fm.state_dict() @torch.inference_mode() def detokenize_streaming(self, semantic_token, ode_step=30, verbose=False, ode_solver="neural_ode_euler", is_final=False, upsample_factor=1): assert len(semantic_token.shape) == 2 and ode_step > 0 assert semantic_token.shape[0] == 1 semantic_token = semantic_token.repeat_interleave(upsample_factor, dim=1) semantic_token = semantic_token.squeeze(0) if self.look_ahead_tokens != 0 and self.previous_chunk_left is not None: semantic_token_previous = self.previous_chunk_left["semantic_token"] semantic_token = torch.cat([semantic_token_previous, semantic_token], dim=-1) x_t_chunk = torch.randn(semantic_token.shape[0], 80).to(semantic_token.device).to(self.dtype) if self.look_ahead_tokens != 0 and self.previous_chunk_left is None: self.previous_chunk_left = {"semantic_token": None} speech_mel = self.semantic_fm.infer_chunk( xt_chunk=x_t_chunk, semantic_tokens_chunk=semantic_token, start_position_id=self.semantic_fm.start_position_id, ode_steps=ode_step, verbose=verbose, look_ahead_tokens=self.look_ahead_tokens * upsample_factor if not is_final else 0, cache=self.previous_chunk_left, ode_solver=ode_solver ) chunk_size = speech_mel.shape[0] length = speech_mel.shape[0] self.semantic_fm.start_position_id += length self.semantic_fm.update_incremental_state() self.semantic_fm.reserve_kv_cache_tokens += self.semantic_fm.ode_wrapper.kv_cache_tokens # smoothing # I will maintain the history of seqlen wav # For the first chunk, I will only return the half chunk wav, and save the res half chunk in history # For the rest requests, I will concat the generated wav with the history, output one chunk of the history, save the if self.pre_mel is None: # first chunk, related to TTFB concat_mel = speech_mel 
concat_reconstructed_wav = self.vocoder.decode_mel(concat_mel) if is_final: self.clear_states() self.state_dict_backup = None ret_wav = concat_reconstructed_wav.float() else: reconstructed_wav = concat_reconstructed_wav[:, :int(self.frame_size * chunk_size // 2)] # return the first half chunk self.pre_wav = concat_reconstructed_wav[:, -int(self.frame_size * chunk_size // 2):] # log the last half chunk for next generation step self.pre_mel = speech_mel[-chunk_size//2:, :] ret_wav = reconstructed_wav.float() else: concat_mel = torch.cat([self.pre_mel, speech_mel], dim=0) concat_reconstructed_wav = self.vocoder.decode_mel(concat_mel) if is_final: self.clear_states() self.state_dict_backup = None ret_wav = concat_reconstructed_wav.float() else: # fetch history prev_speech_len = self.pre_wav.shape[1] if concat_reconstructed_wav.shape[1] > prev_speech_len * 2: gen_speech_len = prev_speech_len * 2 else: gen_speech_len = concat_reconstructed_wav.shape[1] // 2 reconstructed_wav = concat_reconstructed_wav[:, :gen_speech_len] # return the first half chunk if gen_speech_len not in self.hamming_window_cache: self.hamming_window_cache[gen_speech_len] = torch.hamming_window(gen_speech_len).to(self.dtype).to(semantic_token.device).unsqueeze(0) hamming_window = self.hamming_window_cache[gen_speech_len] # apply smoothing of the first half chunk reconstructed_wav[:, :int(gen_speech_len // 2 )] = self.pre_wav[:, :int(gen_speech_len // 2 )] * hamming_window[:,-int(gen_speech_len // 2):] + \ reconstructed_wav[:, :int(gen_speech_len // 2)] * hamming_window[:, :int(gen_speech_len // 2)] res_speech_len = concat_reconstructed_wav.shape[1] - gen_speech_len res_mel_len = res_speech_len // self.frame_size self.pre_wav = concat_reconstructed_wav[:, -res_speech_len:] self.pre_mel = speech_mel[-res_mel_len:, :] ret_wav = reconstructed_wav.float() if not is_final and self.semantic_fm.start_position_id + 2*chunk_size > self.max_pos_size: # out of position id, self.semantic_fm.clear_all_states() 
def get_audio_detokenizer():
    """Build the prefix-streaming flow-matching detokenizer from the bundled
    resource checkpoints (paths are fixed relative to the repo root)."""
    fm_model_config = "resources/audio_detokenizer/config.yaml"
    fm_ckpt_path = "resources/audio_detokenizer/model.pt"
    bigvgan_config_file = "resources/vocoder/config.json"
    bigvgan_ckpt_path = "resources/vocoder/model.pt"
    device = torch.cuda.current_device()
    detokenizer = PrefixStreamingFlowMatchingDetokenizer.from_pretrained(
        vocoder_config=bigvgan_config_file,
        vocoder_ckpt=bigvgan_ckpt_path,
        max_prompt_chunk=10,  # 10 * 3 = 30s
        fm_config=fm_model_config,
        fm_ckpt=fm_ckpt_path,
        device=device,
        use_cfg=False,
        look_ahead_tokens=12)
    return detokenizer


def _iter_token_chunks(tokens, first_chunk_size=100, chunk_size=150):
    """Yield ``(chunk, is_final)`` pairs over ``tokens`` (shape [1, N]).

    The first chunk is smaller to reduce time-to-first-audio; the rest of
    the sequence is sliced in ``chunk_size`` steps. ``is_final`` is True on
    the last chunk so the detokenizer can flush its streaming state.
    Shared by all four detokenize* entry points below (they were previously
    copy-pasted implementations of the same chunking loop).
    """
    yield tokens[:, :first_chunk_size], tokens.size(1) <= first_chunk_size
    rest = tokens[:, first_chunk_size:]
    for i in range(0, rest.size(1), chunk_size):
        yield rest[:, i:i + chunk_size], i + chunk_size >= rest.size(1)


def detokenize(detokenizer, tokens, ref_wav, ref_tokens):
    """Synthesize the full waveform for ``tokens``, cloning the voice from
    the reference audio/tokens. Returns one concatenated tensor [1, T]."""
    with torch.no_grad():
        detokenizer.clear_states()
        detokenizer.prefill(ref_wav, ref_tokens, chunk_size=150)
        pieces = [detokenizer.detokenize_streaming(chunk, is_final=final)
                  for chunk, final in _iter_token_chunks(tokens)]
        return torch.cat(pieces, dim=-1)


def detokenize_streaming(detokenizer, tokens, ref_wav, ref_tokens):
    """Like :func:`detokenize` but yields waveform chunks as they are ready."""
    with torch.no_grad():
        detokenizer.clear_states()
        detokenizer.prefill(ref_wav, ref_tokens, chunk_size=150)
        for chunk, final in _iter_token_chunks(tokens):
            yield detokenizer.detokenize_streaming(chunk, is_final=final)


def detokenize_noref(detokenizer, tokens):
    """Synthesize the full waveform without a reference prompt (no prefill)."""
    with torch.no_grad():
        detokenizer.clear_states()
        pieces = [detokenizer.detokenize_streaming(chunk, is_final=final)
                  for chunk, final in _iter_token_chunks(tokens)]
        return torch.cat(pieces, dim=-1)


def detokenize_noref_streaming(detokenizer, tokens):
    """Streaming variant of :func:`detokenize_noref`; yields chunks."""
    with torch.no_grad():
        detokenizer.clear_states()
        for chunk, final in _iter_token_chunks(tokens):
            yield detokenizer.detokenize_streaming(chunk, is_final=final)
logging.getLogger(__name__) class BigVGANWrapper: def __init__(self, vocoder: BigVGAN, device: torch.device, h: AttrDict, dtype=None) -> None: self.vocoder = vocoder.to(device) if dtype is not None: self.vocoder = self.vocoder.to(dtype) self.vocoder = self.vocoder.eval() self.device = device self.h = h def to_dtype(self, dtype): self.vocoder = self.vocoder.to(dtype) def extract_mel_from_wav(self, wav_path=None, wav_data=None): """ params: wav_path: str, path of the wav, should be 24k wav_data: torch.tensor or numpy array, shape [T], wav data, should be 24k return: mel: [T, num_mels], torch.tensor """ if wav_data is None: wav_data, _ = librosa.load(wav_path, sr=self.h["sampling_rate"]) wav_data = torch.tensor(wav_data).unsqueeze(0) mel = get_melspec(y=wav_data, n_fft=self.h["n_fft"], num_mels=self.h["num_mels"], sampling_rate=self.h["sampling_rate"], hop_size=self.h["hop_size"], win_size=self.h["win_size"], fmin=self.h["fmin"], fmax=self.h["fmax"]) return mel.squeeze(0).transpose(0, 1) @torch.inference_mode() def extract_mel_from_wav_batch(self, wav_data): """ params: wav_data: torch.tensor or numpy array, shape [Batch, T], wav data, should be 24k return: mel: [Batch, T, num_mels], torch.tensor """ wav_data = torch.tensor(wav_data) mel = get_melspec(wav=wav_data, n_fft=self.h["n_fft"], num_mels=self.h["num_mels"], sampling_rate=self.h["sampling_rate"], hop_size=self.h["hop_size"], win_size=self.h["win_size"], fmin=self.h["fmin"], fmax=self.h["fmax"]) return mel.transpose(1, 2) def decode_mel(self, mel): """ params: mel: [T, num_mels], torch.tensor return: wav: [1, T], torch.tensor """ mel = mel.transpose(0, 1).unsqueeze(0).to(self.device) wav = self.vocoder(mel) return wav.squeeze(0) def decode_mel_batch(self, mel): """ params: mel: [B, T, num_mels], torch.tensor return: wav: [B, 1, T], torch.tensor """ mel = mel.transpose(1, 2).to(self.device) wav = self.vocoder(mel) return wav @classmethod def from_pretrained(cls, model_config, ckpt_path, device): with 
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """Reshape rotary frequencies so they broadcast over attention heads.

    ``x`` is the complex-viewed query/key tensor of shape
    [bsz, seqlen, n_heads, head_dim / 2]; ``freqs_cis`` must match it on the
    batch, sequence and last dimensions. A singleton head axis is inserted
    so a pointwise multiply broadcasts across all heads.
    """
    assert x.ndim == 4
    expected = (x.shape[0], x.shape[1], x.shape[-1])
    assert freqs_cis.shape == expected, \
        f'x shape: {x.shape}, freqs_cis shape: {freqs_cis.shape}'
    # new shape: bsz, seq_len, 1, head_hidden_dim / 2
    return freqs_cis.view(x.shape[0], x.shape[1], 1, x.shape[-1])


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
):
    """Apply rotary position embeddings to queries and keys.

    The last dimension is viewed as complex pairs, rotated by ``freqs_cis``,
    then flattened back to real; outputs keep the input dtypes.
    """
    def as_complex(t):
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_c = as_complex(xq)
    k_c = as_complex(xk)
    rot = reshape_for_broadcast(freqs_cis, q_c)
    q_out = torch.view_as_real(q_c * rot).flatten(3)
    k_out = torch.view_as_real(k_c * rot).flatten(3)
    return q_out.type_as(xq), k_out.type_as(xk)
self.head_dim = dim // num_heads self.scale = self.head_dim ** -0.5 self.fused_attn = flash_attention self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.qk_norm = qk_norm self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x: torch.Tensor, seq_len, cu_seqlens, max_seqlen, cu_seqlens_k, max_seqlen_k, rotary_pos_emb=None, incremental_state=None, nopadding=True) -> torch.Tensor: B, N, C = x.shape if self.fused_attn: if nopadding: qkv = self.qkv(x) qkv = qkv.view(B * N, self.num_heads * 3, self.head_dim) q, k, v = qkv.split([self.num_heads] * 3, dim=1) q, k = self.q_norm(q), self.k_norm(k) q = q.view(B, N, self.num_heads, self.head_dim) k = k.view(B, N, self.num_heads, self.head_dim) v = v.view(B, N, self.num_heads, self.head_dim) if rotary_pos_emb is not None: q, k = apply_rotary_emb(q, k, rotary_pos_emb) if incremental_state is not None: if "prev_k" in incremental_state: prev_k = incremental_state["prev_k"] k = torch.cat([prev_k, k], dim=1) if "cur_k" not in incremental_state: incremental_state["cur_k"] = {} incremental_state["cur_k"] = k if "prev_v" in incremental_state: prev_v = incremental_state["prev_v"] v = torch.cat([prev_v, v], dim=1) if "cur_v" not in incremental_state: incremental_state["cur_v"] = {} incremental_state["cur_v"] = v q = q.view(B * N, self.num_heads, self.head_dim) k = k.view(-1, self.num_heads, self.head_dim) v = v.view(-1, self.num_heads, self.head_dim) x = flash_attn_varlen_func( q=q, k=k, v=v, cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen_k, dropout_p=self.attn_drop.p if self.training else 0., ) else: if incremental_state is not None: raise NotImplementedError("It is designed for batching inference. 
def modulate(x, shift, scale):
    """AdaLN modulation: apply a learned scale-then-shift to activations."""
    return shift + x * (scale + 1)


class FinalLayer(nn.Module):
    """
    The final layer of DiT.
    """
    def __init__(self, hidden_size, out_channels):
        super().__init__()
        # Non-affine LayerNorm: scale/shift come from the conditioning signal.
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        # Conditioning c yields a (shift, scale) pair per sequence position.
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=2)
        return self.linear(modulate(self.norm_final(x), shift, scale))
""" def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, ffn_type="conv1d_conv1d", ffn_gated_glu=True, ffn_act_layer="gelu", ffn_conv_kernel_size=5, **block_kwargs): super().__init__() self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) if ffn_type == "vanilla_mlp": from timm.models.vision_transformer import Mlp mlp_hidden_dim = int(hidden_size * mlp_ratio) approx_gelu = lambda: nn.GELU(approximate="tanh") self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) else: raise NotImplementedError(f"FFN type {ffn_type} is not implemented") self.adaLN_modulation = nn.Sequential( nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True) ) def forward(self, x, c, seq_len, cu_seqlens, cu_maxlen, cu_seqlens_k, cu_maxlen_k, mask, rotary_pos_emb=None, incremental_state=None, nopadding=True): shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=2) x_ = modulate(self.norm1(x), shift_msa, scale_msa) if incremental_state is not None: if "attn_kvcache" not in incremental_state: incremental_state["attn_kvcache"] = {} inc_attn = incremental_state["attn_kvcache"] else: inc_attn = None x_ = self.attn(x_, seq_len=seq_len, cu_seqlens=cu_seqlens, max_seqlen=cu_maxlen, cu_seqlens_k=cu_seqlens_k, max_seqlen_k=cu_maxlen_k, rotary_pos_emb=rotary_pos_emb, incremental_state=inc_attn, nopadding=nopadding) if not nopadding: x_ = x_ * mask[:, :, None] x = x + gate_msa * x_ x_ = modulate(self.norm2(x), shift_mlp, scale_mlp) x_ = self.mlp(x_) if not nopadding: x_ = x_ * mask[:, :, None] x = x + gate_mlp * x_ return x ================================================ FILE: modules/audio_detokenizer/flow_matching/model.py ================================================ import torch import torch.nn as nn import math 
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, interpolation_factor: int = 1, max_seq_length: int = 4096):
    """Precompute the complex RoPE rotation table.

    Returns a complex64 tensor of shape (min(end, max_seq_length), dim // 2)
    whose entry [t, i] is exp(j * t * theta**(-2i/dim)), with positions scaled
    by 1/interpolation_factor (ROPE type-A interpolation rather than
    extrapolation). Rows beyond max_seq_length are dropped to save memory.
    """
    print(f'using rope base theta = {theta}, interpolation factor = {interpolation_factor}')
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    # Interpolated (scaled-down) positions; float so the scaling is exact.
    scale = 1.0 / float(interpolation_factor)
    positions = torch.arange(end, device=inv_freq.device).float()
    positions *= scale
    angles = torch.outer(positions, inv_freq).float()
    # Unit-modulus complex rotations e^{j*angle}.
    rotations = torch.polar(torch.ones_like(angles), angles)
    # A table longer than the model's max sequence length would only waste GPU memory.
    if max_seq_length < end:
        rotations = rotations[:max_seq_length, ].clone()
    return rotations


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        # Sinusoidal features -> 2-layer SiLU MLP -> hidden_size embedding.
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element. May be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        scales = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).float().to(device=t.device)
        phase = t[:, None].float() * scales[None]
        emb = torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
        if dim % 2:
            # odd dim: pad with a zero column
            emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
        return emb

    def forward(self, t):
        freq_emb = self.timestep_embedding(t, self.frequency_embedding_size)
        # Match the MLP parameter dtype (e.g. bfloat16 at inference).
        return self.mlp(freq_emb.to(self.mlp[0].weight.dtype))
""" half_dim = embedding_dim // 2 # d/2 emb = math.log(10000) / (half_dim - 1) # 2*log(10000)/(d-2) emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) # -2i/(d-2)*log(10000); i from 0 to (d-2)/2; shape: (d/2, ) emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) # pos/[1000 ** (2i/(d-2))]; shape: (num_embeddings, d/2) emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) # shape: (num_embeddings, d) if embedding_dim % 2 == 1: # zero pad emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) if padding_idx is not None: emb[padding_idx, :] = 0 return emb def forward(self, input, incremental_state=None, timestep=None, **kwargs): """Input is expected to be of size [bsz x seqlen].""" bsz, seq_len = input.shape[:2] max_pos = self.padding_idx + 1 + seq_len if self.weights is None or max_pos > self.weights.size(0): # recompute/expand embeddings if needed self.weights = SinusoidalPositionalEmbedding.get_embedding( max_pos, self.embedding_dim, self.padding_idx, ) self.weights = self.weights.to(self._float_tensor) if incremental_state is not None: # positions is the same for every token when decoding a single step pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1) positions = self.make_positions(input, self.padding_idx) return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach() # (B, T, dim) def max_positions(self): """Maximum number of supported positions.""" return int(1e5) # an arbitrary large number def make_positions(self, tensor, padding_idx): """Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols are ignored. """ # The series of casts and type-conversions here are carefully # balanced to both work with ONNX export and XLA. 
In particular XLA # prefers ints, cumsum defaults to output longs, and ONNX doesn't know # how to handle the dtype kwarg in cumsum. mask = tensor.ne(padding_idx).int() return ( torch.cumsum(mask, dim=1).type_as(mask) * mask ).long() + padding_idx class DiTPrefix(nn.Module): """ Diffusion model with a Transformer backbone. """ def __init__( self, input_size, output_size, semantic_vocab_size, hidden_size=1024, depth=12, num_heads=4, # mlp related mlp_ratio=4.0, ffn_type="conv1d_conv1d", ffn_gated_glu=True, ffn_act_layer="gelu", ffn_conv_kernel_size=5, # rope use_rope=False, rope_params={ "max_position_embeddings": 4096, "rope_base": 10000.0, "rope_interpolation_factor": 1.0, }, position_embedding_type="sincos", max_seq_len=4096, prompt_cfg_dropout=0.0 ): super().__init__() self.num_heads = num_heads self.prompt_cfg_dropout = prompt_cfg_dropout self.t_embedder = TimestepEmbedder(hidden_size) self.semantic_token_embedding = nn.Embedding(semantic_vocab_size, hidden_size) self.input_linear = nn.Linear(input_size, hidden_size) # position embedding if position_embedding_type == "learnable": self.position_embedding = nn.Embedding(max_seq_len+1, hidden_size) elif position_embedding_type == "sincos": self.position_embedding = SinusoidalPositionalEmbedding(hidden_size, 0, max_seq_len+1) elif position_embedding_type == "skip": self.position_embedding = None else: raise NotImplementedError("Position embedding type: {} not implemented.".format(position_embedding_type)) self.use_rope = use_rope if self.use_rope: assert hidden_size % num_heads == 0, "Hidden size must be divisible by num_heads for rope position embedding." 
            rope_dim = hidden_size // num_heads  # RoPE rotates each head independently
            self.rotary_pos_emb = precompute_freqs_cis(
                rope_dim,
                rope_params["max_position_embeddings"],
                theta=rope_params["rope_base"],
                interpolation_factor=rope_params["rope_interpolation_factor"],
                max_seq_length=max_seq_len
            )
        self.blocks = nn.ModuleList([
            DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, ffn_type=ffn_type, ffn_conv_kernel_size=ffn_conv_kernel_size, ffn_gated_glu=ffn_gated_glu, ffn_act_layer=ffn_act_layer)
            for _ in range(depth)
        ])
        self.final_layer = FinalLayer(hidden_size, output_size)
        self.initialize_weights()

    def initialize_weights(self):
        """DiT-style initialization: xavier linears, then zero the adaLN
        modulation and output projections so every block starts as identity."""
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks (adaLN-Zero):
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def forward(self, x, position_ids, t, condition, seq_len, cu_seqlens, cu_maxlen, cu_seqlens_k, cu_maxlen_k, mask, incremental_state=None, nopadding=True):
        """
        Forward pass of DiT.
        x: (N, T, C) tensor of inputs (latent representations of speech)
        position_ids: (N, T) tensor of positional indices
        t: (N,) tensor of diffusion timesteps
        condition: (N, T) tensor of semantic tokens
        seq_len: (N,) tensor of sequence lengths
        """
        condition = self.semantic_token_embedding(condition)  # (N, T, D)
        x = self.input_linear(x)
        # Additive absolute position embedding (skipped for the "skip" type).
        if self.position_embedding is not None:
            position_emb = self.position_embedding(position_ids)
            x = x + position_emb
        # ROPE
        if self.use_rope:
            bsz, seqlen = position_ids.shape
            # Table is created on CPU at init; move it lazily to the input device.
            if self.rotary_pos_emb.device != position_ids.device:
                self.rotary_pos_emb = self.rotary_pos_emb.to(position_ids.device)
            # Gather per-sample rotations by position id.
            # NOTE(review): the per-batch loop could be a single fancy-index
            # self.rotary_pos_emb[position_ids] — left as-is here (doc-only pass).
            rotary_pos_emb = torch.zeros((bsz, seqlen, self.rotary_pos_emb.shape[1]), dtype=self.rotary_pos_emb.dtype, device=self.rotary_pos_emb.device)
            for b in range(bsz):
                cur_rope = rotary_pos_emb[b]
                cur_position_ids = position_ids[b]
                cur_rope[:] = self.rotary_pos_emb[cur_position_ids]
        else:
            rotary_pos_emb = None
        t = self.t_embedder(t)  # (N, D)
        # Conditioning for adaLN: timestep embedding broadcast over tokens
        # plus the per-token semantic embedding.
        c = t.unsqueeze(1) + condition  # (N, T, D)
        for block_idx, block in enumerate(self.blocks):
            # x = block(x, c, attn_mask)    # (N, T, D)
            # XXX mask could be None because we always use full mask
            if incremental_state is not None:
                # One private cache dict per block, keyed by block index.
                if block_idx not in incremental_state:
                    incremental_state[block_idx] = {}
                incr = incremental_state[block_idx]
            else:
                incr = None
            x = block(x=x, c=c, seq_len=seq_len, cu_seqlens=cu_seqlens, cu_maxlen=cu_maxlen, cu_seqlens_k=cu_seqlens_k, cu_maxlen_k=cu_maxlen_k, mask=mask, rotary_pos_emb=rotary_pos_emb, incremental_state=incr, nopadding=nopadding)
        x = self.final_layer(x, c)  # (N, T, C)
        return x
import torch
import torch.nn as nn
from functools import lru_cache
import copy


@lru_cache(maxsize=1)
def get_cached_zeros(numel, device="cpu", dtype=torch.float32):
    # Shared zero tensor reused across ODE steps (callers only read/add to it,
    # never mutate in place). maxsize=1 means alternating shapes/devices will
    # re-allocate each call.
    return torch.zeros(numel, device=device, dtype=dtype)


class StreamingODEWrapperForPrefix(nn.Module):
    """Callable (t, x) -> velocity wrapper around the DiT for streaming Euler/ODE
    sampling, holding the per-chunk conditions and the cross-chunk KV-cache state."""

    def __init__(self, net, x_mask, x_cond, use_cfg=False, use_cfg_rescale=True, cfg_init=1.0, cfg_scale=4.0, cfg_schedule="linear", cfg_token_id=0):
        super(StreamingODEWrapperForPrefix, self).__init__()
        self.net = net          # the DiTPrefix backbone
        self.x_mask = x_mask
        self.x_cond = x_cond
        # CFG parameters are stored but classifier-free guidance is rejected here.
        assert use_cfg == False, "cfg is not supported in streaming detokenizer"
        self.use_cfg = use_cfg
        self.use_cfg_rescale = use_cfg_rescale
        self.cfg_init = cfg_init
        self.cfg_scale = cfg_scale
        self.cfg_token_id = cfg_token_id
        self.cfg_schedule = cfg_schedule
        # Streaming state: positions/lengths of the current chunk and the
        # cumulative (prefix) lengths the KV cache already covers.
        self.position_ids = None
        self.seq_len = None
        self.incremental_state = {}
        self.kv_cache_tokens = 0
        self.cu_seqlens = None
        self.cu_maxlen = None
        self.cu_seqlens_k = None
        self.cu_maxlen_k = None
        self.previous_seqlen = None

    def clear_all_states(self):
        """Reset all streaming state (call between utterances)."""
        self.incremental_state = {}
        self.kv_cache_tokens = 0
        self.cu_seqlens = None
        self.cu_maxlen = None
        self.cu_seqlens_k = None
        self.cu_maxlen_k = None
        self.previous_seqlen = None

    def state_dict(self):
        # Deep-copied snapshot so a later rollback is unaffected by further decoding.
        return {
            "incremental_state": copy.deepcopy(self.incremental_state),
            "kv_cache_tokens": copy.deepcopy(self.kv_cache_tokens),
            "cu_seqlens": copy.deepcopy(self.cu_seqlens),
            "cu_maxlen": copy.deepcopy(self.cu_maxlen),
            "cu_seqlens_k": copy.deepcopy(self.cu_seqlens_k),
            "cu_maxlen_k": copy.deepcopy(self.cu_maxlen_k),
            "previous_seqlen": copy.deepcopy(self.previous_seqlen)
        }

    def load_state_dict(self, state_dict):
        """Restore a snapshot produced by state_dict()."""
        self.incremental_state = state_dict["incremental_state"]
        self.kv_cache_tokens = state_dict["kv_cache_tokens"]
        self.cu_seqlens = state_dict["cu_seqlens"]
        self.cu_maxlen = state_dict["cu_maxlen"]
        self.cu_seqlens_k = state_dict["cu_seqlens_k"]
        self.cu_maxlen_k = state_dict["cu_maxlen_k"]
        self.previous_seqlen = state_dict["previous_seqlen"]

    def set_conditions(self, x_mask, x_cond, start_position_id, cache={}):
        """Install the next chunk's semantic condition and rebuild the packed-
        attention bookkeeping (cu_seqlens for queries, cu_seqlens_k for the
        cache-extended keys). Returns {"previous_seqlen": ...} to be passed back
        via `cache` on the next call.

        NOTE(review): mutable default `cache={}` — callers in this repo always
        pass an explicit cache, but the default is a classic Python footgun.
        """
        if not self.use_cfg:
            self.x_mask = x_mask
            self.x_cond = x_cond
        else:
            # CFG would duplicate the batch (cond + uncond); unreachable given the
            # assert in __init__.
            self.x_cond = torch.cat((x_cond, x_cond), dim=0)
            self.x_mask = torch.cat((x_mask, x_mask), dim=0)
        # Absolute positions for this chunk, continuing from start_position_id.
        position_ids_cur = [i for i in range(start_position_id, self.x_cond.shape[1] + start_position_id)]
        position_ids = torch.tensor([position_ids_cur])
        if not self.use_cfg:
            self.position_ids = position_ids.to(self.x_cond.device).long()
            self.seq_len = torch.Tensor([position_ids.shape[1]]).to(self.x_cond.device).long()
        else:
            self.position_ids = torch.cat((position_ids, position_ids), dim=0).to(self.x_cond.device).long()
            self.seq_len = torch.Tensor([position_ids.shape[1], position_ids.shape[1]]).to(self.x_cond.device).long()
        # Query-side cumulative lengths: [0, len_0, len_0+len_1, ...].
        cu_seqlens = torch.cumsum(self.seq_len, dim=0)
        self.cu_seqlens = torch.cat([torch.Tensor([0]).to(cu_seqlens.device), cu_seqlens], dim=0).int()
        self.cu_maxlen = self.seq_len.cpu().max()
        if self.cu_seqlens_k is None:
            # First chunk: keys == queries.
            self.cu_seqlens_k = self.cu_seqlens
            self.cu_maxlen_k = self.cu_maxlen
            previous_seqlen = self.seq_len
        else:
            # Later chunks: keys also span the cached prefix.
            previous_seqlen_old = cache["previous_seqlen"]
            previous_seqlen = previous_seqlen_old + self.seq_len
            # calculate cu_seqlens_k
            cu_seqlens_k = torch.cumsum(previous_seqlen, dim=0)
            self.cu_seqlens_k = torch.cat([torch.Tensor([0]).to(cu_seqlens_k.device), cu_seqlens_k], dim=0).int()
            self.cu_maxlen_k = previous_seqlen.cpu().max()
        self.previous_seqlen = previous_seqlen
        ret_cache = {
            "previous_seqlen": previous_seqlen
        }
        return ret_cache

    def update_incremental_state(self, reserve_kv_cache_tokens=0, max_kv_cache_tokens=900, condition_cache={"previous_seqlen"}):
        """Commit this chunk's KV entries (cur_* -> prev_*) and trim the cache to
        max_kv_cache_tokens, always keeping the first reserve_kv_cache_tokens
        (the prompt) and the most recent tokens.

        NOTE(review): the default `condition_cache={"previous_seqlen"}` is a SET
        literal, not a dict — the item assignment below would raise on the
        default. All call sites pass a real dict, so this is latent only.
        """
        assert reserve_kv_cache_tokens <= max_kv_cache_tokens, "reserve_kv_cache_tokens should be less than or equal to max_kv_cache_tokens"
        for layer_idx, layer_cache in self.incremental_state.items():
            # update attention kv cache
            layer_cache["attn_kvcache"]["prev_k"] = layer_cache["attn_kvcache"]["cur_k"]
            layer_cache["attn_kvcache"]["prev_v"] = layer_cache["attn_kvcache"]["cur_v"]
            self.kv_cache_tokens = layer_cache["attn_kvcache"]["prev_k"].shape[1]
            if self.kv_cache_tokens > max_kv_cache_tokens:
                # drop old tokens, keeping [0:reserve) (prompt) + the newest tail
                reserve_tokens_excludeprompt = max_kv_cache_tokens - reserve_kv_cache_tokens
                if reserve_kv_cache_tokens == 0:
                    layer_cache["attn_kvcache"]["prev_k"] = layer_cache["attn_kvcache"]["prev_k"][:, -reserve_tokens_excludeprompt:]
                    layer_cache["attn_kvcache"]["prev_v"] = layer_cache["attn_kvcache"]["prev_v"][:, -reserve_tokens_excludeprompt:]
                elif reserve_tokens_excludeprompt == 0:
                    layer_cache["attn_kvcache"]["prev_k"] = layer_cache["attn_kvcache"]["prev_k"][:, :reserve_kv_cache_tokens]
                    layer_cache["attn_kvcache"]["prev_v"] = layer_cache["attn_kvcache"]["prev_v"][:, :reserve_kv_cache_tokens]
                else:
                    layer_cache["attn_kvcache"]["prev_k"] = torch.cat([
                        layer_cache["attn_kvcache"]["prev_k"][:, :reserve_kv_cache_tokens],
                        layer_cache["attn_kvcache"]["prev_k"][:, -reserve_tokens_excludeprompt:]
                    ], dim=1)
                    layer_cache["attn_kvcache"]["prev_v"] = torch.cat([
                        layer_cache["attn_kvcache"]["prev_v"][:, :reserve_kv_cache_tokens],
                        layer_cache["attn_kvcache"]["prev_v"][:, -reserve_tokens_excludeprompt:]
                    ], dim=1)
                # After trimming, the effective prefix length shrank: propagate the
                # new length into the condition cache used by set_conditions.
                bsz = layer_cache["attn_kvcache"]["prev_k"].shape[0]
                self.previous_seqlen = torch.Tensor([layer_cache["attn_kvcache"]["prev_k"].shape[1] for i in range(bsz)]).to(layer_cache["attn_kvcache"]["prev_k"].device).long()
                condition_cache["previous_seqlen"] = self.previous_seqlen
                self.kv_cache_tokens = layer_cache["attn_kvcache"]["prev_k"].shape[1]
            # clear current cache
            layer_cache["attn_kvcache"].pop("cur_k")
            layer_cache["attn_kvcache"].pop("cur_v")

    def forward(self, t, x, args=None):
        """ODE right-hand side: map (scalar time t in [0,1], state x) to the
        network's predicted velocity. t is broadcast to the batch and scaled
        to the network's 0..1000 discrete timestep range."""
        t = get_cached_zeros(x.shape[0], device=x.device, dtype=torch.long) + (t * 1000).long()
        if self.use_cfg:
            raise NotImplementedError("cfg is not supported in streaming detokenizer.")
        else:
            pred_noise = self.net(x=x, condition=self.x_cond, t=t,
                                  position_ids=self.position_ids,
                                  cu_seqlens=self.cu_seqlens,
                                  cu_maxlen=self.cu_maxlen,
                                  cu_seqlens_k=self.cu_seqlens_k,
                                  cu_maxlen_k=self.cu_maxlen_k,
                                  incremental_state=self.incremental_state,
                                  nopadding=True,
                                  mask=None,
                                  seq_len=None
                                  )
        return pred_noise
try:
    from torchdyn.core import NeuralODE
    NEURALODE_INSTALLED = True
except ImportError:
    NEURALODE_INSTALLED = False


class SchedulerBase(ABC):
    """Minimal scheduler interface: timestep setup, one solver step, noising."""

    def __init__(self) -> None:
        pass

    @abstractmethod
    def set_timesteps(self):
        pass

    @abstractmethod
    def step(self):
        pass

    @abstractmethod
    def add_noise(self):
        pass


class StreamingFlowMatchingScheduler(SchedulerBase):
    """Conditional flow-matching scheduler with a fixed-step Euler integrator.

    Time runs from t_min=0 to t_max=1-sigma_min; each Euler step advances the
    state by h * v where h = (t_max - t_min) / timesteps.
    """

    def __init__(self,
                 timesteps=1000,
                 sigma_min=1e-4,
                 ) -> None:
        super().__init__()
        self.sigma_min = sigma_min
        self.timesteps = timesteps
        self.t_min = 0
        self.t_max = 1 - self.sigma_min
        self.neural_ode = None  # lazily built torchdyn solver

    def set_timesteps(self, timesteps=15):
        """Choose the number of Euler steps used at inference."""
        self.timesteps = timesteps

    def step(self, xt, predicted_v):
        """Single Euler update: xt <- xt + h * v (h broadcast over the batch)."""
        h = (self.t_max - self.t_min) / self.timesteps
        h = h * torch.ones(xt.shape[0], dtype=xt.dtype, device=xt.device)
        return xt + h * predicted_v

    def sample(self, ode_wrapper, time_steps, xt, verbose=False, x0=None):
        """Integrate the ODE with naive fixed-step Euler.

        ode_wrapper: callable (t, x) -> velocity.
        verbose: if True, logs the L1 distance to the ground-truth velocity
        derived from x0 (requires x0).
        """
        h = (self.t_max - self.t_min) / self.timesteps
        h = h * torch.ones(xt.shape[0], dtype=xt.dtype, device=xt.device)
        if verbose:
            gt_v = x0 - xt
        for t in time_steps:
            predicted_v = ode_wrapper(t, xt)
            if verbose:
                dist = torch.mean(torch.nn.functional.l1_loss(gt_v, predicted_v))
                print("Time: {}, Distance: {}".format(t, dist))
            xt = xt + h * predicted_v
        return xt

    def sample_by_neuralode(self, ode_wrapper, time_steps, xt, verbose=False, x0=None):
        """Integrate with torchdyn's NeuralODE (euler solver); returns the final state."""
        if not NEURALODE_INSTALLED:
            raise ImportError("NeuralODE is not installed, please install it first.")
        if self.neural_ode is None:
            self.neural_ode = NeuralODE(ode_wrapper, solver='euler', sensitivity="adjoint", atol=self.sigma_min, rtol=self.sigma_min)
        eval_points, traj = self.neural_ode(xt, time_steps)
        return traj[-1]

    def add_noise(self,
                  original_samples: torch.FloatTensor,
                  noise: torch.FloatTensor,
                  timesteps: torch.IntTensor,):
        """Return (x_noisy, ut): the interpolated training sample and the
        flow-matching target velocity (which does not depend on t)."""
        ut = original_samples - (1 - self.sigma_min) * noise
        # Normalize integer timesteps to [0, 1] and broadcast over (B, T, C).
        t_unsqueeze = timesteps.unsqueeze(1).unsqueeze(1).float() / self.timesteps
        x_noisy = t_unsqueeze * original_samples + (1. - (1 - self.sigma_min) * t_unsqueeze) * noise
        return x_noisy, ut
        self.max_kv_cache_tokens = max_kv_cache_tokens  # streaming KV-cache cap (tokens)
        self.max_prompt_chunk = max_prompt_chunk        # prompt chunks prefilled at full chunk size
        self.reserve_kv_cache_tokens = 0                # prompt tokens pinned at the cache front

    @torch.inference_mode()
    def infer_chunk(self, xt_chunk, semantic_tokens_chunk, start_position_id, cache = None, look_ahead_tokens=0, ode_steps=15, verbose=False, ode_solver="neural_ode_euler"):
        """
        semantic_tokens: [T_1], torch.LongTensor
        xt: [T_2, 80], torch.Tensor, DO NOT normalize it outside
        ode_steps: int, number of ode steps, default 15
        verbose: bool, default False
        ode_solver: str, ode solver, expected in ("neural_ode_euler", "naive_euler"), default "neural_ode_euler"

        Returns the generated mel chunk [T, 80] (de-normalized if normalize_mel).
        NOTE(review): when look_ahead_tokens > 0 the default cache=None would
        crash at cache["semantic_token"] — confirm callers always pass a dict
        in that mode.
        """
        bs = 1
        self.scheduler.set_timesteps(ode_steps)
        semantic_tokens_chunk = semantic_tokens_chunk.unsqueeze(0).to(self.device)
        xt_chunk = xt_chunk.unsqueeze(0).to(self.device).to(self.dtype)
        t_span = torch.linspace(0, 1, self.scheduler.timesteps)
        x_mask = torch.zeros(bs, xt_chunk.shape[1], device=self.device).bool()
        # Install this chunk's condition + packed-attention metadata.
        cache_ret = self.ode_wrapper.set_conditions(x_mask=x_mask, x_cond=semantic_tokens_chunk, start_position_id=start_position_id, cache=self.condition_cache)
        if verbose:
            t_start = time.time()
        if ode_solver == "neural_ode_euler":
            x_t = self.scheduler.sample_by_neuralode(self.ode_wrapper, time_steps=t_span, xt=xt_chunk, verbose=False)
        elif ode_solver == "naive_euler":
            x_t = self.scheduler.sample(ode_wrapper=self.ode_wrapper, time_steps=t_span, xt=xt_chunk, verbose=False)
        else:
            raise NotImplementedError("ode_solver should be in ('neural_ode_euler', 'naive_euler')")
        if look_ahead_tokens > 0:
            # Keep the trailing look-ahead tokens for the next chunk; return only
            # the committed part of the mel.
            semantic_tokens_left = semantic_tokens_chunk.view(-1)[-look_ahead_tokens:]
            cache["semantic_token"] = semantic_tokens_left
            x_t_ret = x_t[:, :-look_ahead_tokens, :]
        else:
            x_t_ret = x_t
        if look_ahead_tokens > 0:
            # Re-run conditions without the look-ahead and do one near-t=1 pass so
            # the KV cache only commits the non-look-ahead tokens.
            x_mask = torch.zeros(bs, xt_chunk.shape[1] - look_ahead_tokens, device=self.device).bool()
            self.condition_cache = self.ode_wrapper.set_conditions(x_mask=x_mask, x_cond=semantic_tokens_chunk[:, :-look_ahead_tokens], start_position_id=start_position_id, cache=self.condition_cache)
            self.ode_wrapper(torch.Tensor([0.999]).to(x_t_ret.device), x_t_ret)
        else:
            self.condition_cache = cache_ret
        if verbose:
            t_end = time.time()
            logger.info(f"[ODE Chunk] Time cost: {t_end - t_start}")
        if self.normalize_mel:
            # Model operates in normalized mel space; undo it for the caller.
            x_t_ret = x_t_ret * self.mel_std + self.mel_mean
        return x_t_ret.squeeze(0)

    @torch.inference_mode()
    def infer_mel(self, semantic_tokens, ode_steps=15, chunk_size=150, verbose=False, ode_solver="neural_ode_euler"):
        """
        semantic_tokens: [T_1], torch.LongTensor
        prompt: [T_2, 80], torch.Tensor, DO NOT normalize it outside
        prompt_semantic_tokens, [T_2], torch.LongTensor
        ode_steps: int, number of ode steps, default 15
        verbose: bool, default False
        ode_solver: str, ode solver, expected in ("neural_ode_euler", "naive_euler"), default "neural_ode_euler"

        NOTE(review): relies on self.start_position_id, which is only created by
        clear_all_states() / load_state_dict() — call one of those (and
        optionally prefill()) before this method.
        """
        assert semantic_tokens.dim() == 1
        # Start from Gaussian noise over the whole utterance; decode chunk by chunk.
        x_t = torch.randn(semantic_tokens.shape[0], 80).to(self.device).to(self.dtype)
        seq_len = semantic_tokens.shape[0]
        num_chunks = seq_len // chunk_size
        if seq_len % chunk_size != 0:
            num_chunks += 1
        x_pred_collect = []
        if verbose:
            t_start = time.time()
        for chunk_id in range(num_chunks):
            start = chunk_id * chunk_size
            end = min(start + chunk_size, seq_len)
            semantic_tokens_chunk = semantic_tokens[start:end]
            x_t_chunk = x_t[start:end, :]
            x_pred = self.infer_chunk(xt_chunk=x_t_chunk, semantic_tokens_chunk=semantic_tokens_chunk, start_position_id=self.start_position_id, ode_steps=ode_steps, verbose=verbose, ode_solver=ode_solver)
            self.start_position_id += end - start
            # Commit this chunk's KV entries before decoding the next chunk.
            self.update_incremental_state()
            x_pred_collect.append(x_pred)
        if verbose:
            t_end = time.time()
            logger.info(f"[ODE] Time cost: {t_end - t_start}")
        x_pred = torch.cat(x_pred_collect, dim=0)
        return x_pred

    def clear_all_states(self):
        """Reset position counter, condition cache and the ODE wrapper's KV state."""
        self.start_position_id = 0
        self.condition_cache = {"previous_seqlen": 0}
        self.ode_wrapper.clear_all_states()

    def state_dict(self):
        # Snapshot of the streaming state (e.g. taken right after prompt prefill).
        return {
            "start_position_id": self.start_position_id,
            "ode_wrapper": self.ode_wrapper.state_dict(),
            "condition_cache": self.condition_cache
        }

    def load_state_dict(self, state_dict):
        """Restore a snapshot from state_dict(); a None argument is a no-op."""
        if state_dict is not None:
            self.start_position_id = state_dict["start_position_id"]
            self.ode_wrapper.load_state_dict(state_dict["ode_wrapper"])
            self.condition_cache = state_dict["condition_cache"]

    def update_incremental_state(self):
        # Generation never pins extra tokens beyond what prefill reserved.
        self.ode_wrapper.update_incremental_state(reserve_kv_cache_tokens=0, max_kv_cache_tokens=self.max_kv_cache_tokens, condition_cache=self.condition_cache)

    @torch.inference_mode()
    def prefill(self, mel, semantic_token, chunk_size=150, verbose=False):
        """
        mel: [T, 80], torch.Tensor
        semantic_token: [T], torch.LongTensor
        chunk_size: int, default 150

        Prefills the KV cache with the prompt: first the residual head (so the
        remaining length is a multiple of chunk_size), then up to
        max_prompt_chunk full chunks.
        """
        assert mel.dim() == 2
        assert semantic_token.dim() == 1
        assert semantic_token.shape[0] == mel.shape[0], "Semantic token and mel shape mismatch"
        seq_len = mel.shape[0]
        num_chunks = min(seq_len // chunk_size, self.max_prompt_chunk)
        start_pos = seq_len - num_chunks * chunk_size
        res_mel = mel[:start_pos, :]
        res_semantic_token = semantic_token[:start_pos]
        self.prefill_chunk(res_mel, res_semantic_token, start_position_id=self.start_position_id)
        self.start_position_id += start_pos
        self.update_incremental_state()
        # Pin the prompt tokens so later cache trimming never evicts them.
        self.reserve_kv_cache_tokens += self.ode_wrapper.kv_cache_tokens
        if verbose:
            logger.info("Prefilling prompt with {} chunks".format(num_chunks))
            start_time = time.time()
        for chunk_id in range(num_chunks):
            start = start_pos + chunk_id * chunk_size
            end = start + chunk_size
            mel_chunk = mel[start:end, :]
            semantic_token_chunk = semantic_token[start:end]
            self.prefill_chunk(mel_chunk, semantic_token_chunk, start_position_id=self.start_position_id)
            self.start_position_id += end - start
            self.update_incremental_state()
            self.reserve_kv_cache_tokens += self.ode_wrapper.kv_cache_tokens
        if verbose:
            logger.info("Prefilling done in {:.2f} seconds".format(time.time() - start_time))

    def prefill_chunk(self, mel_chunk, semantic_tokens_chunk, start_position_id=0):
        """
        mel_chunk: [T, 80], torch.Tensor, T is the chunk size
        semantic_tokens_chunk: [T], torch.LongTensor
        start_position_id: int, default 0

        Runs one near-t=1 network pass over ground-truth mel so the prompt's
        keys/values land in the streaming KV cache.
        """
        bs = 1
        semantic_tokens_chunk = semantic_tokens_chunk.unsqueeze(0).to(self.device)
        mel_chunk = mel_chunk.unsqueeze(0).to(self.device).to(self.dtype)
        if self.normalize_mel:
            mel_chunk = (mel_chunk - self.mel_mean) / self.mel_std
        x_mask = torch.zeros(bs, mel_chunk.shape[1], device=self.device).bool()
        self.condition_cache = self.ode_wrapper.set_conditions(x_mask=x_mask, x_cond=semantic_tokens_chunk, start_position_id=start_position_id, cache=self.condition_cache)
        # ode_wrapper signature is (t, x): t≈1 so the "denoised" input is the mel itself.
        x_t = torch.Tensor([0.999]).to(self.device)
        self.ode_wrapper(x_t, mel_chunk)

    @classmethod
    def from_pretrained(cls, model_config, ckpt_path, device, max_prompt_chunk=2, max_kv_cache_tokens=900, use_cfg=True, use_cfg_rescale=True, cfg_init=1.5, cfg_scale=7.5, cfg_schedule="linear"):
        """Build the DiT from a YAML config and load its checkpoint.

        NOTE(review): the use_cfg=True default conflicts with the
        `assert use_cfg == False` in StreamingODEWrapperForPrefix — callers must
        pass use_cfg=False; confirm intended defaults.
        """
        # open yaml file
        with open(model_config, 'r') as f:
            config = yaml.safe_load(f)
        model_config = config["model"]["dit"]
        dit = DiTPrefix(
            input_size=model_config["input_size"],
            # +1 reserves an extra semantic id used as the CFG/unconditional token.
            semantic_vocab_size=model_config["semantic_vocab_size"] + 1,
            hidden_size=model_config["hidden_size"],
            depth=model_config["depth"],
            num_heads=model_config["num_heads"],
            mlp_ratio=model_config["mlp_ratio"],
            ffn_type=model_config.get("ffn_type", "conv1d_conv1d"),
            ffn_gated_glu=model_config.get("ffn_gated_glu", True),
            ffn_act_layer=model_config.get("ffn_act_layer", "gelu"),
            ffn_conv_kernel_size=model_config.get("ffn_conv_kernel_size", 5),
            use_rope=model_config.get("use_rope", False),
            rope_params=model_config.get("rope_params", {
                "max_position_embeddings": 4096, "rope_base": 10000, "rope_interpolation_factor": 1
            }),
            position_embedding_type=model_config["position_embedding_type"],
            max_seq_len=model_config["max_seq_len"],
            output_size=model_config["input_size"],
            prompt_cfg_dropout=0
        )
        cfg_semantic_token_id = model_config["semantic_vocab_size"]
        # load state_dict (weights_only guards against arbitrary pickled objects)
        state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)["state_dict"]
        speech_model_params = {k.replace("speech_model.", ""): v for k, v in state_dict.items() if "speech_model" in k}
        dit.load_state_dict(speech_model_params, strict=True)
        logger.info(f">>> Loaded checkpoint from {ckpt_path}")
        return cls(speech_model=dit, device=device, normalize_mel=config["normalize_mel"], mel_mean=config["mel_mean"], mel_std=config["mel_std"], max_prompt_chunk=max_prompt_chunk, max_kv_cache_tokens=max_kv_cache_tokens, use_cfg=use_cfg, use_cfg_rescale=use_cfg_rescale, cfg_init=cfg_init, cfg_scale=cfg_scale, cfg_schedule=cfg_schedule, cfg_token_id=cfg_semantic_token_id)
class Snake(nn.Module):
    """
    Sine-based periodic activation: Snake(x) = x + (1/alpha) * sin^2(alpha * x).

    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha: trainable per-channel frequency parameter
    References:
        - Liu Ziyin, Tilman Hartwig, Masahito Ueda:
          https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = snake(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    """

    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
    ):
        """
        in_features: number of channels.
        alpha: initial frequency; higher values = higher frequency.
        alpha_logscale: if True, alpha is stored in log space (zeros -> exp = 1).
        """
        super(Snake, self).__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale
        # Log-scale alphas start at 0 (=> exp(0)=1); linear alphas start at 1.
        initial = torch.zeros(in_features) if alpha_logscale else torch.ones(in_features)
        self.alpha = Parameter(initial * alpha)
        self.alpha.requires_grad = alpha_trainable
        self.no_div_by_zero = 0.000000001  # guards 1/alpha when alpha ~ 0

    def forward(self, x):
        """
        Apply Snake elementwise:
        Snake ∶= x + 1/a * sin^2 (xa)
        """
        # Broadcast alpha from (C,) to (1, C, 1) so it lines up with (B, C, T).
        a = self.alpha.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            a = torch.exp(a)
        return x + (1.0 / (a + self.no_div_by_zero)) * torch.sin(x * a) ** 2


class SnakeBeta(nn.Module):
    """
    Modified Snake with separate magnitude parameter:
    SnakeBeta(x) = x + (1/beta) * sin^2(alpha * x).

    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha: trainable parameter that controls frequency
        - beta: trainable parameter that controls magnitude
    References:
        - Modified from Liu Ziyin, Tilman Hartwig, Masahito Ueda:
          https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = snakebeta(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    """

    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
    ):
        """
        in_features: number of channels.
        alpha: initial value for both alpha (frequency) and beta (magnitude);
            higher alpha = higher frequency, higher beta = higher magnitude.
        alpha_logscale: if True, both parameters are stored in log space.
        """
        super(SnakeBeta, self).__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale
        initial = torch.zeros(in_features) if alpha_logscale else torch.ones(in_features)
        self.alpha = Parameter(initial * alpha)
        self.beta = Parameter(initial.clone() * alpha)
        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable
        self.no_div_by_zero = 0.000000001  # guards 1/beta when beta ~ 0

    def forward(self, x):
        """
        Apply SnakeBeta elementwise:
        SnakeBeta ∶= x + 1/b * sin^2 (xa)
        """
        a = self.alpha.unsqueeze(0).unsqueeze(-1)  # (1, C, 1) to match (B, C, T)
        b = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            a = torch.exp(a)
            b = torch.exp(b)
        return x + (1.0 / (b + self.no_div_by_zero)) * torch.sin(x * a) ** 2
    """

    @staticmethod
    def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta):
        # One fused CUDA op replacing the upsample -> snake -> downsample chain.
        activation_results = anti_alias_activation_cuda.forward(
            inputs, up_ftr, down_ftr, alpha, beta
        )

        return activation_results

    @staticmethod
    def backward(ctx, output_grads):
        # Inference-only kernel: autograd through the fused path is unsupported.
        raise NotImplementedError
        return output_grads, None, None  # NOTE(review): unreachable after the raise above


class Activation1d(nn.Module):
    # Anti-aliased activation wrapper: either runs the plain torch pipeline
    # (upsample -> activation -> downsample) or dispatches to the fused CUDA op.
    def __init__(
        self,
        activation,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
        fused: bool = True,
    ):
        super().__init__()
        self.up_ratio = up_ratio
        self.down_ratio = down_ratio
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)
        self.fused = fused  # Whether to use fused CUDA kernel or not

    def forward(self, x):
        if not self.fused:
            x = self.upsample(x)
            x = self.act(x)
            x = self.downsample(x)
            return x
        else:
            # The fused kernel needs raw alpha/beta tensors from the wrapped
            # activation; it distinguishes Snake vs SnakeBeta by class name.
            if self.act.__class__.__name__ == "Snake":
                beta = self.act.alpha.data  # Snake uses same params for alpha and beta
            else:
                beta = (
                    self.act.beta.data
                )  # Snakebeta uses different params for alpha and beta
            alpha = self.act.alpha.data
            if (
                not self.act.alpha_logscale
            ):  # Exp baked into cuda kernel, cancel it out with a log
                alpha = torch.log(alpha)
                beta = torch.log(beta)

            x = FusedAntiAliasActivation.apply(
                x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta
            )
            return x

================================================
FILE: modules/audio_detokenizer/vocoder/alias_free_activation/cuda/anti_alias_activation.cpp
================================================
/* coding=utf-8
 * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("forward", &fwd_cuda, "Anti-Alias Activation forward (CUDA)"); } ================================================ FILE: modules/audio_detokenizer/vocoder/alias_free_activation/cuda/anti_alias_activation_cuda.cu ================================================ /* coding=utf-8 * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include #include #include #include #include #include "type_shim.h" #include #include #include #include #include namespace { // Hard-coded hyperparameters // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and constexpr int ELEMENTS_PER_LDG_STG = 1; //(WARP_ITERATIONS < 4) ? 
1 : 4; constexpr int BUFFER_SIZE = 32; constexpr int FILTER_SIZE = 12; constexpr int HALF_FILTER_SIZE = 6; constexpr int UPSAMPLE_REPLICATION_PAD = 5; // 5 on each side, matching torch impl constexpr int DOWNSAMPLE_REPLICATION_PAD_LEFT = 5; // matching torch impl constexpr int DOWNSAMPLE_REPLICATION_PAD_RIGHT = 6; // matching torch impl template __global__ void anti_alias_activation_forward( output_t *dst, const input_t *src, const input_t *up_ftr, const input_t *down_ftr, const input_t *alpha, const input_t *beta, int batch_size, int channels, int seq_len) { // Up and downsample filters input_t up_filter[FILTER_SIZE]; input_t down_filter[FILTER_SIZE]; // Load data from global memory including extra indices reserved for replication paddings input_t elements[2 * FILTER_SIZE + 2 * BUFFER_SIZE + 2 * UPSAMPLE_REPLICATION_PAD] = {0}; input_t intermediates[2 * FILTER_SIZE + 2 * BUFFER_SIZE + DOWNSAMPLE_REPLICATION_PAD_LEFT + DOWNSAMPLE_REPLICATION_PAD_RIGHT] = {0}; // Output stores downsampled output before writing to dst output_t output[BUFFER_SIZE]; // blockDim/threadIdx = (128, 1, 1) // gridDim/blockIdx = (seq_blocks, channels, batches) int block_offset = (blockIdx.x * 128 * BUFFER_SIZE + seq_len * (blockIdx.y + gridDim.y * blockIdx.z)); int local_offset = threadIdx.x * BUFFER_SIZE; int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset; // intermediate have double the seq_len int intermediate_local_offset = threadIdx.x * BUFFER_SIZE * 2; int intermediate_seq_offset = blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_local_offset; // Get values needed for replication padding before moving pointer const input_t *right_most_pntr = src + (seq_len * (blockIdx.y + gridDim.y * blockIdx.z)); input_t seq_left_most_value = right_most_pntr[0]; input_t seq_right_most_value = right_most_pntr[seq_len - 1]; // Move src and dst pointers src += block_offset + local_offset; dst += block_offset + local_offset; // Alpha and beta values for snake activatons. 
Applies exp by default alpha = alpha + blockIdx.y; input_t alpha_val = expf(alpha[0]); beta = beta + blockIdx.y; input_t beta_val = expf(beta[0]); #pragma unroll for (int it = 0; it < FILTER_SIZE; it += 1) { up_filter[it] = up_ftr[it]; down_filter[it] = down_ftr[it]; } // Apply replication padding for upsampling, matching torch impl #pragma unroll for (int it = -HALF_FILTER_SIZE; it < BUFFER_SIZE + HALF_FILTER_SIZE; it += 1) { int element_index = seq_offset + it; // index for element if ((element_index < 0) && (element_index >= -UPSAMPLE_REPLICATION_PAD)) { elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_left_most_value; } if ((element_index >= seq_len) && (element_index < seq_len + UPSAMPLE_REPLICATION_PAD)) { elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_right_most_value; } if ((element_index >= 0) && (element_index < seq_len)) { elements[2 * (HALF_FILTER_SIZE + it)] = 2 * src[it]; } } // Apply upsampling strided convolution and write to intermediates. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT for replication padding of the downsampilng conv later #pragma unroll for (int it = 0; it < (2 * BUFFER_SIZE + 2 * FILTER_SIZE); it += 1) { input_t acc = 0.0; int element_index = intermediate_seq_offset + it; // index for intermediate #pragma unroll for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1) { if ((element_index + f_idx) >= 0) { acc += up_filter[f_idx] * elements[it + f_idx]; } } intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] = acc; } // Apply activation function. 
It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT and DOWNSAMPLE_REPLICATION_PAD_RIGHT for replication padding of the downsampilng conv later double no_div_by_zero = 0.000000001; #pragma unroll for (int it = 0; it < 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it += 1) { intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] += (1.0 / (beta_val + no_div_by_zero)) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val); } // Apply replication padding before downsampling conv from intermediates #pragma unroll for (int it = 0; it < DOWNSAMPLE_REPLICATION_PAD_LEFT; it += 1) { intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT]; } #pragma unroll for (int it = DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it < DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE + DOWNSAMPLE_REPLICATION_PAD_RIGHT; it += 1) { intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE - 1]; } // Apply downsample strided convolution (assuming stride=2) from intermediates #pragma unroll for (int it = 0; it < BUFFER_SIZE; it += 1) { input_t acc = 0.0; #pragma unroll for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1) { // Add constant DOWNSAMPLE_REPLICATION_PAD_RIGHT to match torch implementation acc += down_filter[f_idx] * intermediates[it * 2 + f_idx + DOWNSAMPLE_REPLICATION_PAD_RIGHT]; } output[it] = acc; } // Write output to dst #pragma unroll for (int it = 0; it < BUFFER_SIZE; it += ELEMENTS_PER_LDG_STG) { int element_index = seq_offset + it; if (element_index < seq_len) { dst[it] = output[it]; } } } template void dispatch_anti_alias_activation_forward( output_t *dst, const input_t *src, const input_t *up_ftr, const input_t *down_ftr, const input_t *alpha, const input_t *beta, int batch_size, int channels, int seq_len) { if (seq_len == 0) { return; } else { // Use 128 threads per block to maximimize gpu utilization constexpr 
int threads_per_block = 128; constexpr int seq_len_per_block = 4096; int blocks_per_seq_len = (seq_len + seq_len_per_block - 1) / seq_len_per_block; dim3 blocks(blocks_per_seq_len, channels, batch_size); dim3 threads(threads_per_block, 1, 1); anti_alias_activation_forward <<>>(dst, src, up_ftr, down_ftr, alpha, beta, batch_size, channels, seq_len); } } } extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta) { // Input is a 3d tensor with dimensions [batches, channels, seq_len] const int batches = input.size(0); const int channels = input.size(1); const int seq_len = input.size(2); // Output auto act_options = input.options().requires_grad(false); torch::Tensor anti_alias_activation_results = torch::empty({batches, channels, seq_len}, act_options); void *input_ptr = static_cast(input.data_ptr()); void *up_filter_ptr = static_cast(up_filter.data_ptr()); void *down_filter_ptr = static_cast(down_filter.data_ptr()); void *alpha_ptr = static_cast(alpha.data_ptr()); void *beta_ptr = static_cast(beta.data_ptr()); void *anti_alias_activation_results_ptr = static_cast(anti_alias_activation_results.data_ptr()); DISPATCH_FLOAT_HALF_AND_BFLOAT( input.scalar_type(), "dispatch anti alias activation_forward", dispatch_anti_alias_activation_forward( reinterpret_cast(anti_alias_activation_results_ptr), reinterpret_cast(input_ptr), reinterpret_cast(up_filter_ptr), reinterpret_cast(down_filter_ptr), reinterpret_cast(alpha_ptr), reinterpret_cast(beta_ptr), batches, channels, seq_len);); return anti_alias_activation_results; } ================================================ FILE: modules/audio_detokenizer/vocoder/alias_free_activation/cuda/compat.h ================================================ /* coding=utf-8 * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /*This code is copied fron NVIDIA apex: * https://github.com/NVIDIA/apex * with minor changes. */ #ifndef TORCH_CHECK #define TORCH_CHECK AT_CHECK #endif #ifdef VERSION_GE_1_3 #define DATA_PTR data_ptr #else #define DATA_PTR data #endif ================================================ FILE: modules/audio_detokenizer/vocoder/alias_free_activation/cuda/load.py ================================================ # Copyright (c) 2024 NVIDIA CORPORATION. # Licensed under the MIT license. import os import pathlib import subprocess from torch.utils import cpp_extension """ Setting this param to a list has a problem of generating different compilation commands (with diferent order of architectures) and leading to recompilation of fused kernels. Set it to empty stringo avoid recompilation and assign arch flags explicity in extra_cuda_cflags below """ os.environ["TORCH_CUDA_ARCH_LIST"] = "" def load(): # Check if cuda 11 is installed for compute capability 8.0 cc_flag = [] _, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME) if int(bare_metal_major) >= 11: cc_flag.append("-gencode") cc_flag.append("arch=compute_80,code=sm_80") # Build path srcpath = pathlib.Path(__file__).parent.absolute() buildpath = srcpath / "build" _create_build_dir(buildpath) # Helper function to build the kernels. 
def _cpp_extention_load_helper(name, sources, extra_cuda_flags): return cpp_extension.load( name=name, sources=sources, build_directory=buildpath, extra_cflags=[ "-O3", ], extra_cuda_cflags=[ "-O3", "-gencode", "arch=compute_70,code=sm_70", "--use_fast_math", ] + extra_cuda_flags + cc_flag, verbose=True, ) extra_cuda_flags = [ "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", "--expt-relaxed-constexpr", "--expt-extended-lambda", ] sources = [ srcpath / "anti_alias_activation.cpp", srcpath / "anti_alias_activation_cuda.cu", ] anti_alias_activation_cuda = _cpp_extention_load_helper( "anti_alias_activation_cuda", sources, extra_cuda_flags ) return anti_alias_activation_cuda def _get_cuda_bare_metal_version(cuda_dir): raw_output = subprocess.check_output( [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True ) output = raw_output.split() release_idx = output.index("release") + 1 release = output[release_idx].split(".") bare_metal_major = release[0] bare_metal_minor = release[1][0] return raw_output, bare_metal_major, bare_metal_minor def _create_build_dir(buildpath): try: os.mkdir(buildpath) except OSError: if not os.path.isdir(buildpath): print(f"Creation of the build directory {buildpath} failed") ================================================ FILE: modules/audio_detokenizer/vocoder/alias_free_activation/cuda/type_shim.h ================================================ /* coding=utf-8 * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. */ #include #include "compat.h" #define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...) \ switch (TYPE) \ { \ case at::ScalarType::Float: \ { \ using scalar_t = float; \ __VA_ARGS__; \ break; \ } \ case at::ScalarType::Half: \ { \ using scalar_t = at::Half; \ __VA_ARGS__; \ break; \ } \ case at::ScalarType::BFloat16: \ { \ using scalar_t = at::BFloat16; \ __VA_ARGS__; \ break; \ } \ default: \ AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ } #define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \ switch (TYPEIN) \ { \ case at::ScalarType::Float: \ { \ using scalar_t_in = float; \ switch (TYPEOUT) \ { \ case at::ScalarType::Float: \ { \ using scalar_t_out = float; \ __VA_ARGS__; \ break; \ } \ case at::ScalarType::Half: \ { \ using scalar_t_out = at::Half; \ __VA_ARGS__; \ break; \ } \ case at::ScalarType::BFloat16: \ { \ using scalar_t_out = at::BFloat16; \ __VA_ARGS__; \ break; \ } \ default: \ AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \ } \ break; \ } \ case at::ScalarType::Half: \ { \ using scalar_t_in = at::Half; \ using scalar_t_out = at::Half; \ __VA_ARGS__; \ break; \ } \ case at::ScalarType::BFloat16: \ { \ using scalar_t_in = at::BFloat16; \ using scalar_t_out = at::BFloat16; \ __VA_ARGS__; \ break; \ } \ default: \ AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \ } ================================================ FILE: modules/audio_detokenizer/vocoder/alias_free_activation/torch/__init__.py ================================================ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 # LICENSE is in incl_licenses directory. 
from .filter import * from .resample import * from .act import * ================================================ FILE: modules/audio_detokenizer/vocoder/alias_free_activation/torch/act.py ================================================ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 # LICENSE is in incl_licenses directory. import torch.nn as nn from .resample import UpSample1d, DownSample1d class Activation1d(nn.Module): def __init__( self, activation, up_ratio: int = 2, down_ratio: int = 2, up_kernel_size: int = 12, down_kernel_size: int = 12, ): super().__init__() self.up_ratio = up_ratio self.down_ratio = down_ratio self.act = activation self.upsample = UpSample1d(up_ratio, up_kernel_size) self.downsample = DownSample1d(down_ratio, down_kernel_size) # x: [B,C,T] def forward(self, x): x = self.upsample(x) x = self.act(x) x = self.downsample(x) return x ================================================ FILE: modules/audio_detokenizer/vocoder/alias_free_activation/torch/filter.py ================================================ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 # LICENSE is in incl_licenses directory. import torch import torch.nn as nn import torch.nn.functional as F import math if "sinc" in dir(torch): sinc = torch.sinc else: # This code is adopted from adefossez's julius.core.sinc under the MIT License # https://adefossez.github.io/julius/julius/core.html # LICENSE is in incl_licenses directory. def sinc(x: torch.Tensor): """ Implementation of sinc, i.e. sin(pi * x) / (pi * x) __Warning__: Different to julius.sinc, the input is multiplied by `pi`! """ return torch.where( x == 0, torch.tensor(1.0, device=x.device, dtype=x.dtype), torch.sin(math.pi * x) / math.pi / x, ) # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License # https://adefossez.github.io/julius/julius/lowpass.html # LICENSE is in incl_licenses directory. 
def kaiser_sinc_filter1d(
    cutoff, half_width, kernel_size
):  # return filter [1,1,kernel_size]
    # Build a Kaiser-windowed sinc low-pass FIR filter; cutoff and half_width
    # are normalized frequencies (Nyquist = 0.5).
    even = kernel_size % 2 == 0
    half_size = kernel_size // 2

    # For kaiser window
    delta_f = 4 * half_width
    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if A > 50.0:
        beta = 0.1102 * (A - 8.7)
    elif A >= 21.0:
        beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
    else:
        beta = 0.0
    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)

    # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
    if even:
        time = torch.arange(-half_size, half_size) + 0.5
    else:
        time = torch.arange(kernel_size) - half_size
    if cutoff == 0:
        filter_ = torch.zeros_like(time)
    else:
        filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
        """
        Normalize filter to have sum = 1, otherwise we will have a small leakage of the constant component in the input signal.
        """
        filter_ /= filter_.sum()
        filter = filter_.view(1, 1, kernel_size)

    return filter


class LowPassFilter1d(nn.Module):
    # Depthwise low-pass FIR filtering with optional replicate padding and stride.
    def __init__(
        self,
        cutoff=0.5,
        half_width=0.6,
        stride: int = 1,
        padding: bool = True,
        padding_mode: str = "replicate",
        kernel_size: int = 12,
    ):
        """
        kernel_size should be even number for stylegan3 setup, in this implementation, odd number is also possible.
        """
        super().__init__()
        # NOTE(review): `cutoff < -0.0` only rejects strictly negative cutoffs,
        # while the message says "larger than zero" — matches upstream; confirm intent.
        if cutoff < -0.0:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if cutoff > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")
        self.kernel_size = kernel_size
        self.even = kernel_size % 2 == 0
        # Asymmetric padding for even kernels keeps the output aligned.
        self.pad_left = kernel_size // 2 - int(self.even)
        self.pad_right = kernel_size // 2
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
        # Registered as a buffer so it moves with .to()/.cuda() and is saved in state dicts.
        self.register_buffer("filter", filter)

    # Input [B, C, T]
    def forward(self, x):
        _, C, _ = x.shape

        if self.padding:
            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
        # Depthwise convolution: the same 1-D filter is applied per channel.
        out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)

        return out

================================================
FILE: modules/audio_detokenizer/vocoder/alias_free_activation/torch/resample.py
================================================
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.
import torch.nn as nn
from torch.nn import functional as F
from .filter import LowPassFilter1d
from .filter import kaiser_sinc_filter1d


class UpSample1d(nn.Module):
    # Band-limited upsampling by `ratio` via zero-insertion transposed conv
    # with a Kaiser-sinc anti-imaging filter.
    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        self.kernel_size = (
            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        )
        self.stride = ratio
        # Padding amounts chosen so the output length is exactly ratio * T.
        self.pad = self.kernel_size // ratio - 1
        self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
        self.pad_right = (
            self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
        )
        filter = kaiser_sinc_filter1d(
            cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
        )
        # Buffer (not Parameter): fixed filter, moves with the module's device/dtype.
        self.register_buffer("filter", filter)

    # x: [B, C, T]
    def forward(self, x):
        _, C, _ = x.shape

        x = F.pad(x, (self.pad, self.pad), mode="replicate")
        # Scale by ratio to compensate for the energy lost by zero insertion.
        x = self.ratio * F.conv_transpose1d(
            x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
        )
        x = x[..., self.pad_left : -self.pad_right]

        return x


class DownSample1d(nn.Module):
    # Band-limited downsampling by `ratio`: low-pass then stride.
    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        self.kernel_size = (
            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        )
        self.lowpass = LowPassFilter1d(
            cutoff=0.5 / ratio,
            half_width=0.6 / ratio,
            stride=ratio,
            kernel_size=self.kernel_size,
        )

    def forward(self, x):
        xx = self.lowpass(x)

        return xx

================================================
FILE: modules/audio_detokenizer/vocoder/bigvgan.py
================================================
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.

# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.
import os
import json
from pathlib import Path
from typing import Optional, Union, Dict

import torch
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import weight_norm, remove_weight_norm

from modules.audio_detokenizer.vocoder.activations import Snake, SnakeBeta
from modules.audio_detokenizer.vocoder.utils import init_weights, get_padding
from modules.audio_detokenizer.vocoder.alias_free_activation.torch.act import Activation1d as TorchActivation1d
from modules.audio_detokenizer.vocoder.utils import AttrDict

from huggingface_hub import PyTorchModelHubMixin, hf_hub_download


def load_hparams_from_json(path) -> AttrDict:
    """Read a JSON config file and wrap it in an attribute-access dict."""
    with open(path) as f:
        data = f.read()
    return AttrDict(json.loads(data))


class AMPBlock1(torch.nn.Module):
    """
    AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
    AMPBlock1 has additional self.convs2 that contains additional Conv1d layers with a fixed dilation=1 followed by each layer in self.convs1

    Args:
        h (AttrDict): Hyperparameters.
        channels (int): Number of convolution channels.
        kernel_size (int): Size of the convolution kernel. Default is 3.
        dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
        activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
    """

    def __init__(
        self,
        h: AttrDict,
        channels: int,
        kernel_size: int = 3,
        dilation: tuple = (1, 3, 5),
        activation: str = None,
    ):
        super().__init__()

        self.h = h

        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        stride=1,
                        dilation=d,
                        padding=get_padding(kernel_size, d),
                    )
                )
                for d in dilation
            ]
        )
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        stride=1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                )
                for _ in range(len(dilation))
            ]
        )
        self.convs2.apply(init_weights)

        self.num_layers = len(self.convs1) + len(
            self.convs2
        )  # Total number of conv layers

        # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
        if self.h.get("use_cuda_kernel", False):
            from modules.audio_detokenizer.vocoder.alias_free_activation.cuda.activation1d import (
                Activation1d as CudaActivation1d,
            )

            Activation1d = CudaActivation1d
        else:
            Activation1d = TorchActivation1d

        # Activation functions
        if activation == "snake":
            self.activations = nn.ModuleList(
                [
                    Activation1d(
                        activation=Snake(
                            channels, alpha_logscale=h.snake_logscale
                        )
                    )
                    for _ in range(self.num_layers)
                ]
            )
        elif activation == "snakebeta":
            self.activations = nn.ModuleList(
                [
                    Activation1d(
                        activation=SnakeBeta(
                            channels, alpha_logscale=h.snake_logscale
                        )
                    )
                    for _ in range(self.num_layers)
                ]
            )
        else:
            raise NotImplementedError(
                "activation incorrectly specified. check the config file and look for 'activation'."
            )

    def forward(self, x):
        # Residual: activation -> dilated conv -> activation -> dilation-1 conv.
        acts1, acts2 = self.activations[::2], self.activations[1::2]
        for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
            xt = a1(x)
            xt = c1(xt)
            xt = a2(xt)
            xt = c2(xt)
            x = xt + x

        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)


class AMPBlock2(torch.nn.Module):
    """
    AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
    Unlike AMPBlock1, AMPBlock2 does not contain extra Conv1d layers with fixed dilation=1

    Args:
        h (AttrDict): Hyperparameters.
        channels (int): Number of convolution channels.
        kernel_size (int): Size of the convolution kernel. Default is 3.
        dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
        activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
    """

    def __init__(
        self,
        h: AttrDict,
        channels: int,
        kernel_size: int = 3,
        dilation: tuple = (1, 3, 5),
        activation: str = None,
    ):
        super().__init__()

        self.h = h

        self.convs = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        stride=1,
                        dilation=d,
                        padding=get_padding(kernel_size, d),
                    )
                )
                for d in dilation
            ]
        )
        self.convs.apply(init_weights)

        self.num_layers = len(self.convs)  # Total number of conv layers

        # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
        if self.h.get("use_cuda_kernel", False):
            from modules.audio_detokenizer.vocoder.alias_free_activation.cuda.activation1d import (
                Activation1d as CudaActivation1d,
            )

            Activation1d = CudaActivation1d
        else:
            Activation1d = TorchActivation1d

        # Activation functions
        if activation == "snake":
            self.activations = nn.ModuleList(
                [
                    Activation1d(
                        activation=Snake(
                            channels, alpha_logscale=h.snake_logscale
                        )
                    )
                    for _ in range(self.num_layers)
                ]
            )
        elif activation == "snakebeta":
            self.activations = nn.ModuleList(
                [
                    Activation1d(
                        activation=SnakeBeta(
                            channels, alpha_logscale=h.snake_logscale
                        )
                    )
                    for _ in range(self.num_layers)
                ]
            )
        else:
            raise NotImplementedError(
                "activation incorrectly specified. check the config file and look for 'activation'."
            )

    def forward(self, x):
        for c, a in zip(self.convs, self.activations):
            xt = a(x)
            xt = c(xt)
            x = xt + x
        # BUG FIX: the original fell off the end and returned None; the residual
        # output must be returned (consistent with AMPBlock1.forward and with
        # upstream NVIDIA BigVGAN).
        return x

    def remove_weight_norm(self):
        for l in self.convs:
            remove_weight_norm(l)


class BigVGAN(
    torch.nn.Module,
    PyTorchModelHubMixin,
    library_name="bigvgan",
    repo_url="https://github.com/NVIDIA/BigVGAN",
    docs_url="https://github.com/NVIDIA/BigVGAN/blob/main/README.md",
    pipeline_tag="audio-to-audio",
    license="mit",
    tags=["neural-vocoder", "audio-generation", "arxiv:2206.04658"],
):
    """
    BigVGAN is a neural vocoder model that applies anti-aliased periodic activation for residual blocks (resblocks).
    New in BigVGAN-v2: it can optionally use optimized CUDA kernels for AMP (anti-aliased multi-periodicity) blocks.

    Args:
        h (AttrDict): Hyperparameters.
        use_cuda_kernel (bool): If set to True, loads optimized CUDA kernels for AMP. This should be used for inference only, as training is not supported with CUDA kernels.

    Note:
        - The `use_cuda_kernel` parameter should be used for inference only, as training with CUDA kernels is not supported.
        - Ensure that the activation function is correctly specified in the hyperparameters (h.activation).
    """

    def __init__(self, h: AttrDict, use_cuda_kernel: bool = False):
        super().__init__()
        self.h = h
        self.h["use_cuda_kernel"] = use_cuda_kernel

        # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
        if self.h.get("use_cuda_kernel", False):
            from modules.audio_detokenizer.vocoder.alias_free_activation.cuda.activation1d import (
                Activation1d as CudaActivation1d,
            )

            Activation1d = CudaActivation1d
        else:
            Activation1d = TorchActivation1d

        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)

        # Pre-conv
        self.conv_pre = weight_norm(
            Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
        )

        # Define which AMPBlock to use.
        # BigVGAN uses AMPBlock1 as default
        if h.resblock == "1":
            resblock_class = AMPBlock1
        elif h.resblock == "2":
            resblock_class = AMPBlock2
        else:
            raise ValueError(
                f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}"
            )

        # Transposed conv-based upsamplers. does not apply anti-aliasing
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            # Channel count halves at each upsampling stage.
            self.ups.append(
                nn.ModuleList(
                    [
                        weight_norm(
                            ConvTranspose1d(
                                h.upsample_initial_channel // (2**i),
                                h.upsample_initial_channel // (2 ** (i + 1)),
                                k,
                                u,
                                padding=(k - u) // 2,
                            )
                        )
                    ]
                )
            )

        # Residual blocks using anti-aliased multi-periodicity composition modules (AMP)
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(
                zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
            ):
                self.resblocks.append(
                    resblock_class(h, ch, k, d, activation=h.activation)
                )

        # Post-conv
        activation_post = (
            Snake(ch, alpha_logscale=h.snake_logscale)
            if h.activation == "snake"
            else (
                SnakeBeta(ch, alpha_logscale=h.snake_logscale)
                if h.activation == "snakebeta"
                else None
            )
        )
        if activation_post is None:
            raise NotImplementedError(
                "activation incorrectly specified. check the config file and look for 'activation'."
            )

        self.activation_post = Activation1d(activation=activation_post)

        # Whether to use bias for the final conv_post. Default to True for backward compatibility
        self.use_bias_at_final = h.get("use_bias_at_final", True)
        self.conv_post = weight_norm(
            Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final)
        )

        # Weight initialization
        for i in range(len(self.ups)):
            self.ups[i].apply(init_weights)
        self.conv_post.apply(init_weights)

        # Final tanh activation. Defaults to True for backward compatibility
        self.use_tanh_at_final = h.get("use_tanh_at_final", True)

    def forward(self, x):
        # x: mel spectrogram -> waveform in [-1, 1].
        # Pre-conv
        x = self.conv_pre(x)

        for i in range(self.num_upsamples):
            # Upsampling
            for i_up in range(len(self.ups[i])):
                x = self.ups[i][i_up](x)
            # AMP blocks: average the outputs of the num_kernels parallel resblocks.
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        # Post-conv
        x = self.activation_post(x)
        x = self.conv_post(x)
        # Final tanh activation
        if self.use_tanh_at_final:
            x = torch.tanh(x)
        else:
            x = torch.clamp(x, min=-1.0, max=1.0)  # Bound the output to [-1, 1]

        return x

    def remove_weight_norm(self):
        # Strip weight-norm reparametrization for inference/export.
        try:
            print("Removing weight norm...")
            for l in self.ups:
                for l_i in l:
                    remove_weight_norm(l_i)
            for l in self.resblocks:
                l.remove_weight_norm()
            remove_weight_norm(self.conv_pre)
            remove_weight_norm(self.conv_post)
        except ValueError:
            print("[INFO] Model already removed weight norm. Skipping!")
            pass

    # Additional methods for huggingface_hub support
    def _save_pretrained(self, save_directory: Path) -> None:
        """Save weights and config.json from a Pytorch model to a local directory."""

        model_path = save_directory / "bigvgan_generator.pt"
        torch.save({"generator": self.state_dict()}, model_path)

        config_path = save_directory / "config.json"
        with open(config_path, "w") as config_file:
            json.dump(self.h, config_file, indent=4)

    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str,
        revision: str,
        cache_dir: str,
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: bool,
        local_files_only: bool,
        token: Union[str, bool, None],
        map_location: str = "cpu",  # Additional argument
        strict: bool = False,  # Additional argument
        use_cuda_kernel: bool = False,
        **model_kwargs,
    ):
        """Load Pytorch pretrained weights and return the loaded model."""

        # Download and load hyperparameters (h) used by BigVGAN
        if os.path.isdir(model_id):
            print("Loading config.json from local directory")
config_file = os.path.join(model_id, "config.json") else: config_file = hf_hub_download( repo_id=model_id, filename="config.json", revision=revision, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, token=token, local_files_only=local_files_only, ) h = load_hparams_from_json(config_file) # instantiate BigVGAN using h if use_cuda_kernel: print( f"[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!" ) print( f"[WARNING] You need nvcc and ninja installed in your system that matches your PyTorch build is using to build the kernel. If not, the model will fail to initialize or generate incorrect waveform!" ) print( f"[WARNING] For detail, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis" ) model = cls(h, use_cuda_kernel=use_cuda_kernel) # Download and load pretrained generator weight if os.path.isdir(model_id): print("Loading weights from local directory") model_file = os.path.join(model_id, "bigvgan_generator.pt") else: print(f"Loading weights from {model_id}") model_file = hf_hub_download( repo_id=model_id, filename="bigvgan_generator.pt", revision=revision, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, token=token, local_files_only=local_files_only, ) checkpoint_dict = torch.load(model_file, map_location=map_location, weights_only=True) try: model.load_state_dict(checkpoint_dict["generator"]) except RuntimeError: print( f"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!" 
            )
            # Strict load failed: the checkpoint was saved after weight norm
            # was removed. Strip weight norm from the freshly built model and
            # retry the load.
            model.remove_weight_norm()
            model.load_state_dict(checkpoint_dict["generator"])

        return model

================================================ FILE: modules/audio_detokenizer/vocoder/utils.py ================================================

from librosa.filters import mel as librosa_mel_fn
import torch
import os

# Module-level caches keyed by the full STFT/mel configuration (including the
# device string), so repeated calls with the same settings reuse the mel
# filterbank and Hann window instead of rebuilding them.
mel_basis_cache = {}
hann_window_cache = {}


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    # Log-compress magnitudes; clip_val guards against log(0) on silent frames.
    return torch.log(torch.clamp(x, min=clip_val) * C)


def spectral_normalize_torch(magnitudes):
    # Thin alias: "spectral normalization" here is simply log dynamic-range
    # compression of the magnitude spectrogram.
    return dynamic_range_compression_torch(magnitudes)


def get_melspec(
    y: torch.Tensor,
    n_fft: int,
    num_mels: int,
    sampling_rate: int,
    hop_size: int,
    win_size: int,
    fmin: int,
    fmax: int = None,
    center: bool = False,
) -> torch.Tensor:
    """
    Calculate the mel spectrogram of an input signal.

    This function uses slaney norm for the librosa mel filterbank (using librosa.filters.mel)
    and uses Hann window for STFT (using torch.stft).

    Args:
        y (torch.Tensor): Input signal.
        n_fft (int): FFT size.
        num_mels (int): Number of mel bins.
        sampling_rate (int): Sampling rate of the input signal.
        hop_size (int): Hop size for STFT.
        win_size (int): Window size for STFT.
        fmin (int): Minimum frequency for mel filterbank.
        fmax (int): Maximum frequency for mel filterbank. If None, defaults to half the sampling rate (fmax = sr / 2.0) inside librosa_mel_fn
        center (bool): Whether to pad the input to center the frames. Default is False.

    Returns:
        torch.Tensor: Mel spectrogram.
    """
    # Sanity warnings: the vocoder pipeline expects waveforms in [-1, 1].
    if torch.min(y) < -1.0:
        print(f"[WARNING] Min value of input waveform signal is {torch.min(y)}")
    if torch.max(y) > 1.0:
        print(f"[WARNING] Max value of input waveform signal is {torch.max(y)}")

    device = y.device
    key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"

    if key not in mel_basis_cache:
        # Build (and cache) the mel filterbank and Hann window once per
        # configuration/device combination.
        mel = librosa_mel_fn(
            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
        )
        mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)
        hann_window_cache[key] = torch.hann_window(win_size).to(device)

    mel_basis = mel_basis_cache[key]
    hann_window = hann_window_cache[key]

    # Manual reflect padding so that framing with center=False still covers the
    # signal edges ((n_fft - hop_size) // 2 samples on each side).
    padding = (n_fft - hop_size) // 2
    y = torch.nn.functional.pad(
        y.unsqueeze(1), (padding, padding), mode="reflect"
    ).squeeze(1)

    spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window,
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=True,
    )
    # Magnitude of the complex STFT with a small epsilon for sqrt stability.
    spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)

    mel_spec = torch.matmul(mel_basis, spec)
    mel_spec = spectral_normalize_torch(mel_spec)

    return mel_spec


class AttrDict(dict):
    # Dict whose keys are also accessible as attributes (h.n_fft == h["n_fft"]).
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def load_checkpoint(filepath, device):
    # Load a torch checkpoint onto `device`; weights_only=True restricts
    # unpickling to tensors and primitive containers.
    assert os.path.isfile(filepath)
    print(f"Loading '{filepath}'")
    checkpoint_dict = torch.load(filepath, map_location=device, weights_only=True)
    print("Complete.")
    return checkpoint_dict


def init_weights(m, mean=0.0, std=0.01):
    # Initialize conv-layer weights with N(mean, std); other modules untouched.
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    # "Same" padding for a dilated 1-D convolution with stride 1.
    return int((kernel_size * dilation - dilation) / 2)

================================================ FILE: modules/audio_tokenizer/audio_tokenizer.py ================================================

import torch
import librosa
import yaml
from transformers import Wav2Vec2BertModel, SeamlessM4TFeatureExtractor
import safetensors
import accelerate
import
soundfile as sf
import math
from einops import rearrange
from modules.audio_tokenizer.rep_codec import RepCodec


class AudioTokenizer(object):
    # Wraps a w2v-BERT encoder plus a RepCodec quantizer to turn 16 kHz speech
    # into discrete semantic tokens.

    def __init__(self, **kwargs):
        # Device the models are placed on ('cuda' or 'cpu').
        self.device = kwargs.pop('device')
        print(self.device)

        # tokenize
        # Per-dimension mean/std used to normalize encoder features before
        # quantization; loaded from a stats checkpoint {'mean': ..., 'var': ...}.
        feat_stats = kwargs.pop('feat_stats')
        feat_stats = torch.load(feat_stats, map_location='cpu')
        self.feat_mean = feat_stats['mean']
        self.feat_std = torch.sqrt(feat_stats['var'])

        wav2vec_ckpt = kwargs.pop("wav2vec_ckpt")
        self.semantic_model = Wav2Vec2BertModel.from_pretrained(wav2vec_ckpt)
        self.semantic_model.eval()
        self.semantic_model.to(self.device)
        self.semantic_processor = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")

        self.semantic_codec = RepCodec()
        self.semantic_codec.eval()
        pretrained_path = kwargs.pop("semantic_codec_ckpt")
        safetensors.torch.load_model(self.semantic_codec, pretrained_path)
        self.semantic_codec.to(self.device)

        # Maximum number of feature frames fed through the encoder at once;
        # longer inputs are split into segments of this length in tokenize().
        self.max_length = 2048

    @torch.no_grad()
    def tokenize(self, speech):
        # Input:
        #   speech: torch tensor, shape[B, N_speech]
        # Output:
        #   semantic token: torch tensor, shape[B, N]
        inputs = self.semantic_processor(speech.cpu(), sampling_rate=16000, return_tensors="pt")
        input_features = inputs["input_features"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)

        # Zero-pad the time axis up to a multiple of max_length, then fold the
        # segments into the batch dimension so each encoder call sees at most
        # max_length frames.
        seg_num = math.ceil(input_features.shape[1] / self.max_length)
        pad_num = seg_num * self.max_length - input_features.shape[1]
        input_features = torch.nn.functional.pad(input_features, (0, 0, 0, pad_num, 0,0), value=0)
        attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_num, 0, 0), value=0)
        input_features = rearrange(input_features, "b (s n) d -> (b s) n d", s =seg_num)
        attention_mask = rearrange(attention_mask, "b (s n) -> (b s) n", s=seg_num)

        feats = self.semantic_model(
            input_features=input_features,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        # Hidden states of layer 17 are taken as the semantic representation
        # (presumably matched to the codec's training layer — TODO confirm);
        # unfold the segments back onto the time axis and drop the padding.
        feat = feats.hidden_states[17]
        feat = rearrange(feat, "(b s) n d -> b (s n) d", s=seg_num)
        feat = feat[:, :feat.shape[1]-pad_num, :]
        feat =
(feat - self.feat_mean.to(feat)) / self.feat_std.to(feat)
        # Quantize the normalized features into discrete semantic tokens.
        semantic_token, _ = self.semantic_codec.quantize(feat)
        return semantic_token


def get_audio_tokenizer():
    # Factory: build an AudioTokenizer from the default resource layout
    # (paths populated by download_pretrain.py / snapshot_download).
    config = dict()
    config['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'
    config['feat_stats'] = 'resources/audio_tokenizer/stats.pt'
    config['wav2vec_ckpt'] = 'facebook/w2v-bert-2.0'
    config['semantic_codec_ckpt'] = 'resources/audio_tokenizer/model.safetensors'
    audio_tokenizer = AudioTokenizer(**config)
    return audio_tokenizer

================================================ FILE: modules/audio_tokenizer/quantize/__init__.py ================================================

from .vector_quantize import VectorQuantize
from .residual_vq import ResidualVQ
from .factorized_vector_quantize import FactorizedVectorQuantize

================================================ FILE: modules/audio_tokenizer/quantize/factorized_vector_quantize.py ================================================

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm


def WNConv1d(*args, **kwargs):
    # Conv1d wrapped with weight normalization.
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    # ConvTranspose1d wrapped with weight normalization.
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


class FactorizedVectorQuantize(nn.Module):
    # VQ layer that quantizes in a (typically low-dimensional) factorized
    # codebook space: inputs are projected input_dim -> codebook_dim before
    # the lookup and projected back afterwards.
    def __init__(
        self,
        input_dim,
        codebook_size,
        codebook_dim,
        commitment=0.005,
        codebook_loss_weight=1.0,
        use_l2_normlize=True,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.commitment = commitment
        self.codebook_loss_weight = codebook_loss_weight
        self.use_l2_normlize = use_l2_normlize

        if self.input_dim != self.codebook_dim:
            # Factorization: 1x1 convs project into/out of the codebook space.
            self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
            self.out_project = WNConv1d(
                self.codebook_dim, self.input_dim, kernel_size=1
            )
        else:
            self.in_project = nn.Identity()
            self.out_project = nn.Identity()

        self.codebook =
nn.Embedding(self.codebook_size, self.codebook_dim)

    def forward(self, z):
        """
        Parameters
        ----------
        z: torch.Tensor[B x D x T]

        Returns
        -------
        z_q: torch.Tensor[B x D x T]
            Quantized continuous representation of input
        commit_loss: Tensor[B]
            Commitment loss to train encoder to predict vectors closer to codebook entries
        codebook_loss: Tensor[B]
            Codebook loss to update the codebook
        indices: torch.Tensor[B x T]
            Codebook indices (quantized discrete representation of input)
        z_e: torch.Tensor[B x D x T]
            Projected latents (continuous representation of input before quantization)
        """
        # Factorized codes: project input into the low-dimensional codebook
        # space when self.input_dim != self.codebook_dim (Identity otherwise).
        z_e = self.in_project(z)
        z_q, indices = self.decode_latents(z_e)

        # Compute commitment loss and codebook loss (training only; zeros at
        # eval time so the return arity stays constant).
        if self.training:
            commit_loss = (
                F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
                * self.commitment
            )
            codebook_loss = (
                F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
                * self.codebook_loss_weight
            )
        else:
            commit_loss = torch.zeros(z.shape[0], device=z.device)
            codebook_loss = torch.zeros(z.shape[0], device=z.device)

        # Straight-through estimator: gradients flow to z_e as if quantization
        # were the identity.
        z_q = z_e + (z_q - z_e).detach()

        z_q = self.out_project(z_q)

        return z_q, commit_loss, codebook_loss, indices, z_e

    def embed_code(self, embed_id):
        # Integer codes -> codebook vectors, last dim = codebook_dim.
        return F.embedding(embed_id, self.codebook.weight)

    def decode_code(self, embed_id):
        # Codes -> embeddings transposed to (B, codebook_dim, T).
        return self.embed_code(embed_id).transpose(1, 2)

    def decode_latents(self, latents):
        # Nearest-neighbour lookup for latents (B, D, T); returns the quantized
        # tensor and the chosen indices.
        encodings = rearrange(latents, "b d t -> (b t) d")
        codebook = self.codebook.weight

        # L2 normalize encodings and codebook
        if self.use_l2_normlize:
            encodings = F.normalize(encodings)
            codebook = F.normalize(codebook)

        # Compute euclidean distance between encodings and codebook,
        # if use_l2_normlize is True, the distance is equal to cosine distance
        dist = (
            encodings.pow(2).sum(1, keepdim=True)
            - 2 * encodings @ codebook.t()
            + codebook.pow(2).sum(1, keepdim=True).t()
        )
        indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
        z_q =
self.decode_code(indices) return z_q, indices def vq2emb(self, vq, out_proj=True): emb = self.decode_code(vq) if out_proj: emb = self.out_project(emb) return emb def latent2dist(self, latents): encodings = rearrange(latents, "b d t -> (b t) d") codebook = self.codebook.weight # L2 normalize encodings and codebook if self.use_l2_normlize: encodings = F.normalize(encodings) codebook = F.normalize(codebook) # Compute euclidean distance between encodings and codebook, # if use_l2_normlize is True, the distance is equal to cosine distance dist = ( encodings.pow(2).sum(1, keepdim=True) - 2 * encodings @ codebook.t() + codebook.pow(2).sum(1, keepdim=True).t() ) # (b*t, k) indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) dist = rearrange(dist, "(b t) k -> b t k", b=latents.size(0)) z_q = self.decode_code(indices) return -dist, indices, z_q ================================================ FILE: modules/audio_tokenizer/quantize/residual_vq.py ================================================ from typing import Union import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange from torch.nn.utils import weight_norm from .vector_quantize import VectorQuantize from .factorized_vector_quantize import FactorizedVectorQuantize class ResidualVQ(nn.Module): """ Introduced in SoundStream: An end2end neural audio codec https://arxiv.org/abs/2107.03312 """ def __init__( self, input_dim: int = 256, num_quantizers: int = 8, codebook_size: int = 1024, codebook_dim: int = 256, quantizer_type: str = "vq", # "vq" or "fvq" or "lfq" quantizer_dropout: float = 0.5, **kwargs, ): super().__init__() self.input_dim = input_dim self.num_quantizers = num_quantizers self.codebook_size = codebook_size self.codebook_dim = codebook_dim self.quantizer_type = quantizer_type self.quantizer_dropout = quantizer_dropout if quantizer_type == "vq": VQ = VectorQuantize elif quantizer_type == "fvq": VQ = FactorizedVectorQuantize else: 
raise ValueError(f"Unknown quantizer type {quantizer_type}") self.quantizers = nn.ModuleList( [ VQ( input_dim=input_dim, codebook_size=codebook_size, codebook_dim=codebook_dim, **kwargs, ) for _ in range(num_quantizers) ] ) def forward(self, z, n_quantizers: int = None): """ Parameters ---------- z : Tensor[B x D x T] n_quantizers : int, optional No. of quantizers to use (n_quantizers < self.n_codebooks ex: for quantizer dropout) Note: if `self.quantizer_dropout` is True, this argument is ignored when in training mode, and a random number of quantizers is used. Returns ------- "quantized_out" : Tensor[B x D x T] Quantized continuous representation of input "all_indices" : Tensor[N x B x T] Codebook indices for each codebook (quantized discrete representation of input) "all_commit_losses" : Tensor[N] "all_codebook_losses" : Tensor[N] "all_quantized" : Tensor[N x B x D x T] """ quantized_out = 0.0 residual = z all_commit_losses = [] all_codebook_losses = [] all_indices = [] all_quantized = [] if n_quantizers is None: n_quantizers = self.num_quantizers if self.training: n_quantizers = torch.ones((z.shape[0],)) * self.num_quantizers + 1 dropout = torch.randint(1, self.num_quantizers + 1, (z.shape[0],)) n_dropout = int(z.shape[0] * self.quantizer_dropout) n_quantizers[:n_dropout] = dropout[:n_dropout] n_quantizers = n_quantizers.to(z.device) for i, quantizer in enumerate(self.quantizers): if self.training is False and i >= n_quantizers: break z_q_i, commit_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer( residual ) # Create mask to apply quantizer dropout mask = ( torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers ) quantized_out = quantized_out + z_q_i * mask[:, None, None] residual = residual - z_q_i commit_loss_i = (commit_loss_i * mask).mean() codebook_loss_i = (codebook_loss_i * mask).mean() all_commit_losses.append(commit_loss_i) all_codebook_losses.append(codebook_loss_i) all_indices.append(indices_i) all_quantized.append(z_q_i) 
all_commit_losses, all_codebook_losses, all_indices, all_quantized = map( torch.stack, (all_commit_losses, all_codebook_losses, all_indices, all_quantized), ) return ( quantized_out, all_indices, all_commit_losses, all_codebook_losses, all_quantized, ) def vq2emb(self, vq, n_quantizers=None): quantized_out = 0.0 if n_quantizers is None: n_quantizers = self.num_quantizers for idx, quantizer in enumerate(self.quantizers): if idx >= n_quantizers: break quantized_out += quantizer.vq2emb(vq[idx]) return quantized_out def latent2dist(self, z, n_quantizers=None): quantized_out = 0.0 residual = z all_dists = [] all_indices = [] if n_quantizers is None: n_quantizers = self.num_quantizers for i, quantizer in enumerate(self.quantizers): if self.training is False and i >= n_quantizers: break dist_i, indices_i, z_q_i = quantizer.latent2dist(residual) all_dists.append(dist_i) all_indices.append(indices_i) quantized_out = quantized_out + z_q_i residual = residual - z_q_i all_dists = torch.stack(all_dists) all_indices = torch.stack(all_indices) return all_dists, all_indices ================================================ FILE: modules/audio_tokenizer/quantize/vector_quantize.py ================================================ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange, repeat from torch.nn.utils import weight_norm def WNConv1d(*args, **kwargs): return weight_norm(nn.Conv1d(*args, **kwargs)) def WNConvTranspose1d(*args, **kwargs): return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) def l2norm(t): return F.normalize(t, p=2, dim=-1) def ema_inplace(moving_avg, new, decay): moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) def laplace_smoothing(x, n_categories, eps=1e-5): return (x + eps) / (x.sum() + n_categories * eps) def sample_vectors(samples, num): num_samples, device = samples.shape[0], samples.device if num_samples >= num: indices = torch.randperm(num_samples, device=device)[:num] else: indices 
= torch.randint(0, num_samples, (num,), device=device) return samples[indices] def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False): dim, dtype, device = samples.shape[-1], samples.dtype, samples.device means = sample_vectors(samples, num_clusters) for _ in range(num_iters): if use_cosine_sim: dists = samples @ means.t() else: diffs = rearrange(samples, "n d -> n () d") - rearrange( means, "c d -> () c d" ) dists = -(diffs**2).sum(dim=-1) buckets = dists.max(dim=-1).indices bins = torch.bincount(buckets, minlength=num_clusters) zero_mask = bins == 0 bins_min_clamped = bins.masked_fill(zero_mask, 1) new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples) new_means = new_means / bins_min_clamped[..., None] if use_cosine_sim: new_means = l2norm(new_means) means = torch.where(zero_mask[..., None], means, new_means) return means, bins class EuclideanCodebook(nn.Module): def __init__( self, dim, codebook_size, kmeans_init=False, kmeans_iters=10, decay=0.8, eps=1e-5, threshold_ema_dead_code=2, weight_init=False, ): super().__init__() self.decay = decay init_fn = torch.randn if not weight_init else torch.zeros embed = init_fn(codebook_size, dim) if weight_init: nn.init.uniform_(embed, -1 / codebook_size, 1 / codebook_size) self.codebook_size = codebook_size self.kmeans_iters = kmeans_iters self.eps = eps self.threshold_ema_dead_code = threshold_ema_dead_code self.register_buffer( "initted", torch.Tensor([not kmeans_init]) ) # if kmeans_init is True, then initted is False; otherwise, initted is True self.register_buffer("cluster_size", torch.zeros(codebook_size)) self.register_buffer("embed", embed) self.register_buffer("embed_avg", embed.clone()) def init_embed_(self, data): embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) self.embed.data.copy_(embed) self.embed_avg.data.copy_(embed) self.cluster_size.data.copy_(cluster_size) 
self.initted.data.copy_(torch.Tensor([True])) def replace(self, samples, mask): modified_codebook = torch.where( mask[..., None], sample_vectors(samples, self.codebook_size), self.embed ) self.embed.data.copy_(modified_codebook) def expire_codes_(self, batch_samples): if self.threshold_ema_dead_code == 0: return expired_codes = self.cluster_size < self.threshold_ema_dead_code if not torch.any(expired_codes): return batch_samples = rearrange(batch_samples, "... d -> (...) d") self.replace(batch_samples, mask=expired_codes) def forward(self, x): shape, dtype = x.shape, x.dtype flatten = rearrange(x, "... d -> (...) d") embed = self.embed.t() # (codebook_size, dim) -> (dim, codebook_size) if not self.initted: self.init_embed_(flatten) dist = -( flatten.pow(2).sum(1, keepdim=True) - 2 * flatten @ embed + embed.pow(2).sum(0, keepdim=True) ) embed_ind = dist.max(dim=-1).indices embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) embed_ind = embed_ind.view(*shape[:-1]) quantize = F.embedding(embed_ind, self.embed) if self.training: ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay) embed_sum = ( flatten.t() @ embed_onehot ) # (dim, ...) @ (..., codebook_size) -> (dim, codebook_size) ema_inplace(self.embed_avg, embed_sum.t(), self.decay) cluster_size = ( laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum() ) embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) self.embed.data.copy_(embed_normalized) self.expire_codes_(x) return quantize, embed_ind def vq2emb(self, vq): quantize = F.embedding(vq, self.embed) return quantize def latent2dist(self, x): shape, dtype = x.shape, x.dtype flatten = rearrange(x, "... d -> (...) 
d") embed = self.embed.t() # (codebook_size, dim) -> (dim, codebook_size) if not self.initted: self.init_embed_(flatten) dist = -( flatten.pow(2).sum(1, keepdim=True) - 2 * flatten @ embed + embed.pow(2).sum(0, keepdim=True) ) embed_ind = dist.max(dim=-1).indices embed_ind = embed_ind.view(*shape[:-1]) quantize = F.embedding(embed_ind, self.embed) dist = dist.view(*shape[:-1], -1) return dist, embed_ind, quantize class SimpleCodebook(nn.Module): def __init__( self, dim, codebook_size, use_l2_normlize=False, ): super().__init__() self.dim = dim self.codebook_size = codebook_size self.use_l2_normlize = use_l2_normlize self.embed = nn.Embedding(self.codebook_size, self.dim) def forward(self, x): shape, dtype = x.shape, x.dtype flatten = rearrange(x, "... d -> (...) d") embed = self.embed.weight.t() # (codebook_size, dim) -> (dim, codebook_size) if self.use_l2_normlize: flatten = F.normalize(flatten) embed = F.normalize(embed) dist = -( flatten.pow(2).sum(1, keepdim=True) - 2 * flatten @ embed + embed.pow(2).sum(0, keepdim=True) ) embed_ind = dist.max(dim=-1).indices embed_ind = embed_ind.view(*shape[:-1]) quantize = F.embedding(embed_ind, self.embed) return quantize, embed_ind def vq2emb(self, vq): quantize = F.embedding(vq, self.embed.weight) return quantize def latent2dist(self, x): shape, dtype = x.shape, x.dtype flatten = rearrange(x, "... d -> (...) d") embed = self.embed.weight.t() # (codebook_size, dim) -> (dim, codebook_size) if self.use_l2_normlize: flatten = F.normalize(flatten) embed = F.normalize(embed) dist = -( flatten.pow(2).sum(1, keepdim=True) - 2 * flatten @ embed + embed.pow(2).sum(0, keepdim=True) ) embed_ind = dist.max(dim=-1).indices embed_ind = embed_ind.view(*shape[:-1]) quantize = F.embedding(embed_ind, self.embed) dist = dist.view(*shape[:-1], -1) return dist, embed_ind, quantize class VectorQuantize(nn.Module): """Vector quantization and factorized vecotor quantization implementation Args: input_dim (int): Dimension of input. 
codebook_size (int): Codebook size. codebook_dim (int): Codebook dimension. We suggest use codebook_dim = input_dim if use codebook_type == "euclidean", otherwise, if you want to use factorized vector quantization, use codebook_dim as small number (e.g. 8 or 32). commitment (float): Weight for commitment loss. use_l2_normlize (bool): Whether to use l2 normlized codes for factorized vecotor quantization, we suggest use it as True if you want to use factorized vector quantization kmeans_init (bool): Whether to use kmeans to initialize the codebooks. kmeans_iters (int): Number of iterations used for kmeans initialization. decay (float): Decay for exponential moving average over the codebooks. epsilon (float): Epsilon value for numerical stability. threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes that have an exponential moving average cluster size less than the specified threshold with randomly selected vector from the current batch. """ def __init__( self, input_dim, codebook_size, codebook_dim, commitment=0.005, codebook_loss_weight=1.0, use_l2_normlize=False, codebook_type="euclidean", # "euclidean" or "simple" kmeans_init=False, kmeans_iters=10, decay=0.8, eps=1e-5, threshold_ema_dead_code=2, weight_init=False, ): super().__init__() self.input_dim = input_dim self.codebook_size = codebook_size self.codebook_dim = codebook_dim self.commitment = commitment self.codebook_loss_weight = codebook_loss_weight self.use_l2_normlize = use_l2_normlize self.codebook_type = codebook_type self.kmeans_init = kmeans_init self.kmeans_iters = kmeans_iters self.decay = decay self.eps = eps self.threshold_ema_dead_code = threshold_ema_dead_code self.weight_init = weight_init if self.input_dim != self.codebook_dim: self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1) self.out_project = WNConv1d( self.codebook_dim, self.input_dim, kernel_size=1 ) else: self.in_project = nn.Identity() self.out_project = nn.Identity() if 
self.codebook_type == "euclidean": self.codebook = EuclideanCodebook( self.codebook_dim, codebook_size=self.codebook_size, kmeans_init=self.kmeans_init, kmeans_iters=self.kmeans_iters, decay=self.decay, eps=self.eps, threshold_ema_dead_code=self.threshold_ema_dead_code, weight_init=self.weight_init, ) elif self.codebook_type == "simple": self.codebook = SimpleCodebook( self.codebook_dim, codebook_size=self.codebook_size, use_l2_normlize=self.use_l2_normlize, ) else: raise NotImplementedError( f"codebook_type {self.codebook_type} is not implemented!" ) def forward(self, z): """ Parameters ---------- z: torch.Tensor[B x D x T] Returns ------- z_q: torch.Tensor[B x D x T] Quantized continuous representation of input commit_loss: Tensor[B] Commitment loss to train encoder to predict vectors closer to codebook entries codebook_loss: Tensor[B] Codebook loss to update the codebook indices: torch.Tensor[B x T] Codebook indices (quantized discrete representation of input) z_e: torch.Tensor[B x D x T] Projected latents (continuous representation of input before quantization) """ # Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim z_e = self.in_project(z) z_q, indices = self.decode_latents(z_e) # Compute commitment loss and codebook loss if self.training: commit_loss = ( F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) * self.commitment ) codebook_loss = ( F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) * self.codebook_loss_weight ) else: commit_loss = torch.zeros(z.shape[0], device=z.device) codebook_loss = torch.zeros(z.shape[0], device=z.device) z_q = z_e + (z_q - z_e).detach() z_q = self.out_project(z_q) return z_q, commit_loss, codebook_loss, indices, z_e def decode_latents(self, latents): encodings = rearrange(latents, "b d t -> b t d") z_q, indices = self.codebook(encodings) z_q = z_q.transpose(1, 2) return z_q, indices def vq2emb(self, vq, out_proj=True): emb = self.codebook.vq2emb(vq) emb = 
        emb.transpose(1, 2)
        if out_proj:
            emb = self.out_project(emb)
        return emb

    def latent2dist(self, latents):
        # Codebook distance lookup in (B, T, D) layout; the quantized result is
        # transposed back to channel-first (B, D, T) for the caller.
        latents = rearrange(latents, "b d t -> b t d")
        dist, embed_ind, quantize = self.codebook.latent2dist(latents)
        return dist, embed_ind, quantize.transpose(1, 2)


================================================
FILE: modules/audio_tokenizer/rep_codec.py
================================================
import torch
import torch.nn as nn
from modules.audio_tokenizer.quantize import ResidualVQ
from modules.audio_tokenizer.vocos import VocosBackbone
from modules.audio_tokenizer.transformer import TransformerEncoder


def init_weights(m):
    # Truncated-normal (std=0.02) weight init with zero bias for conv/linear layers.
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)


class RepCodec(nn.Module):
    """Representation codec: Vocos-style encoder -> residual VQ -> decoder,
    with an optional timbre (speaker) conditioning branch."""

    def __init__(
        self,
        codebook_size=8192,
        hidden_size=1024,
        codebook_dim=8,
        vocos_dim=384,
        vocos_intermediate_dim=2048,
        vocos_num_layers=12,
        num_quantizers=1,
        use_timbre_encoder=False,
        cfg=None,
    ):
        super().__init__()
        # Each hyperparameter may be overridden by a matching attribute on cfg.
        codebook_size = (
            cfg.codebook_size
            if cfg is not None and hasattr(cfg, "codebook_size")
            else codebook_size
        )
        codebook_dim = (
            cfg.codebook_dim
            if cfg is not None and hasattr(cfg, "codebook_dim")
            else codebook_dim
        )
        hidden_size = (
            cfg.hidden_size
            if cfg is not None and hasattr(cfg, "hidden_size")
            else hidden_size
        )
        vocos_dim = (
            cfg.vocos_dim
            if cfg is not None and hasattr(cfg, "vocos_dim")
            else vocos_dim
        )
        vocos_intermediate_dim = (
            cfg.vocos_intermediate_dim
            # NOTE(review): this guards on hasattr(cfg, "vocos_dim"), not
            # "vocos_intermediate_dim" — looks like a copy-paste slip; confirm.
            if cfg is not None and hasattr(cfg, "vocos_dim")
            else vocos_intermediate_dim
        )
        vocos_num_layers = (
            cfg.vocos_num_layers
            # NOTE(review): same copy-paste issue — guards on "vocos_dim"
            # instead of "vocos_num_layers"; confirm intended attribute.
            if cfg is not None and hasattr(cfg, "vocos_dim")
            else vocos_num_layers
        )
        num_quantizers = (
            cfg.num_quantizers
            if cfg is not None and hasattr(cfg, "num_quantizers")
            else num_quantizers
        )
        use_timbre_encoder = (
            cfg.use_timbre_encoder
            if cfg is not None and hasattr(cfg, "use_timbre_encoder")
            else use_timbre_encoder
        )
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.hidden_size = hidden_size
        self.vocos_dim = vocos_dim
        self.vocos_intermediate_dim = vocos_intermediate_dim
        self.vocos_num_layers = vocos_num_layers
        self.num_quantizers = num_quantizers
        self.use_timbre_encoder = use_timbre_encoder
        # NOTE(review): encoder/decoder hard-code dim=384 / intermediate_dim=2048 /
        # num_layers=12 instead of using the vocos_* parameters resolved above.
        self.encoder = nn.Sequential(
            VocosBackbone(
                input_channels=self.hidden_size,
                dim=384,
                intermediate_dim=2048,
                num_layers=12,
                adanorm_num_embeddings=None
            ),
            nn.Linear(384, self.hidden_size)
        )
        self.decoder = nn.Sequential(
            VocosBackbone(
                input_channels=self.hidden_size,
                dim=384,
                intermediate_dim=2048,
                num_layers=12,
                adanorm_num_embeddings=None
            ),
            nn.Linear(384, self.hidden_size)
        )
        self.quantizer = ResidualVQ(
            input_dim=hidden_size,
            num_quantizers=num_quantizers,
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            quantizer_type="fvq",
            quantizer_dropout=0.0,
            commitment=0.15,
            codebook_loss_weight=1.0,
            use_l2_normlize=True,
        )
        if self.use_timbre_encoder:
            #TODO: write encoder hidden (256) as a hyparam
            self.timbre_in = nn.Linear(hidden_size, 256)
            self.timbre_encoder = TransformerEncoder(
                enc_emb_tokens=None,
                encoder_layer=4,
                encoder_hidden=256,
                encoder_head=4,
                conv_filter_size=1024,
                conv_kernel_size=5,
                encoder_dropout=0.1,
                use_pe=False,
                cfg=None,
            )
            self.timbre_out = nn.Linear(256, hidden_size)
            # timbre_linear produces per-speaker (gamma, beta); biases are
            # initialised so the initial transform is the identity (gamma=1, beta=0).
            self.timbre_linear = nn.Linear(hidden_size, hidden_size * 2)
            self.timbre_linear.bias.data[:hidden_size] = 1
            self.timbre_linear.bias.data[hidden_size:] = 0
            self.timbre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False)
            self.enc_ln = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.reset_parameters()

    def forward(self, x):
        # x is processed channel-first by the Vocos encoder, hence the transposes.
        x = self.encoder(x.transpose(1, 2)).transpose(1, 2)
        if self.use_timbre_encoder:
            # Keep a pre-normalisation copy for the timbre branch.
            x_timbre = x
            x = x.transpose(1, 2)
            x = self.enc_ln(x)
            x = x.transpose(1, 2)
        (
            quantized_out,
            all_indices,
            all_commit_losses,
            all_codebook_losses,
            _,
        ) = self.quantizer(x)
        if self.use_timbre_encoder:
            x_timbre = x_timbre.transpose(1, 2)
            x_timbre = self.timbre_in(x_timbre)
            x_timbre = self.timbre_encoder(x_timbre, None, None)
            x_timbre = self.timbre_out(x_timbre)
            x_timbre = x_timbre.transpose(1, 2)
            # Mean-pool over time to a single speaker embedding, then apply
            # FiLM-style conditioning (gamma * norm(x) + beta) to the quantized path.
            spk_embs = torch.mean(x_timbre, dim=2)
            style = self.timbre_linear(spk_embs).unsqueeze(2)  # (B, 2d, 1)
            gamma, beta = style.chunk(2, 1)  # (B, d, 1)
            quantized_out = quantized_out.transpose(1, 2)
            quantized_out = self.timbre_norm(quantized_out)
            quantized_out = quantized_out.transpose(1, 2)
            quantized_out = quantized_out * gamma + beta
        x_rec = self.decoder(quantized_out)
        codebook_loss = (all_codebook_losses + all_commit_losses).mean()
        all_indices = all_indices
        return x_rec, codebook_loss, all_indices

    def quantize(self, x):
        # Inference path: encode and return (indices, quantized features); the
        # leading quantizer axis is squeezed when there is a single quantizer.
        x = self.encoder(x.transpose(1, 2)).transpose(1, 2)
        if self.use_timbre_encoder:
            x = x.transpose(1, 2)
            x = self.enc_ln(x)
            x = x.transpose(1, 2)
        (
            quantized_out,
            all_indices,
            all_commit_losses,
            all_codebook_losses,
            _,
        ) = self.quantizer(x)
        if all_indices.shape[0] == 1:
            return all_indices.squeeze(0), quantized_out.transpose(1, 2)
        return all_indices, quantized_out.transpose(1, 2)

    def reset_parameters(self):
        self.apply(init_weights)


================================================
FILE: modules/audio_tokenizer/transformer.py
================================================
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
import math


class StyleAdaptiveLayerNorm(nn.Module):
    """LayerNorm whose affine (gamma, beta) is predicted from a condition tensor."""

    def __init__(self, normalized_shape, eps=1e-5):
        super().__init__()
        self.in_dim = normalized_shape
        self.norm = nn.LayerNorm(self.in_dim, eps=eps, elementwise_affine=False)
        self.style = nn.Linear(self.in_dim, self.in_dim * 2)
        # Identity at init: gamma bias 1, beta bias 0.
        self.style.bias.data[: self.in_dim] = 1
        self.style.bias.data[self.in_dim :] = 0

    def forward(self, x, condition):
        # x: (B, T, d); condition: (B, T, d)
        style = self.style(torch.mean(condition, dim=1, keepdim=True))
        gamma, beta = style.chunk(2, -1)
        out = self.norm(x)
        out = gamma * out + beta
        return out


class PositionalEncoding(nn.Module):
    """Classic fixed sinusoidal positional encoding with dropout."""

    def __init__(self, d_model, dropout, max_len=5000):
        super().__init__()
        self.dropout = dropout
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        # pe buffer laid out (max_len, 1, d_model), i.e. sequence-first.
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x):
        # NOTE(review): indexing pe by x.size(0) assumes a sequence-first (T, B, d)
        # input; the encoders in this file are batch-first but are constructed
        # with use_pe=False in this repo — confirm before enabling use_pe.
        x = x + self.pe[: x.size(0)]
        return F.dropout(x, self.dropout, training=self.training)


class TransformerFFNLayer(nn.Module):
    """Feed-forward block: 1D conv expansion -> ReLU -> dropout -> linear projection."""

    def __init__(
        self, encoder_hidden, conv_filter_size, conv_kernel_size, encoder_dropout
    ):
        super().__init__()
        self.encoder_hidden = encoder_hidden
        self.conv_filter_size = conv_filter_size
        self.conv_kernel_size = conv_kernel_size
        self.encoder_dropout = encoder_dropout
        self.ffn_1 = nn.Conv1d(
            self.encoder_hidden,
            self.conv_filter_size,
            self.conv_kernel_size,
            padding=self.conv_kernel_size // 2,
        )
        self.ffn_1.weight.data.normal_(0.0, 0.02)
        self.ffn_2 = nn.Linear(self.conv_filter_size, self.encoder_hidden)
        self.ffn_2.weight.data.normal_(0.0, 0.02)

    def forward(self, x):
        # x: (B, T, d)
        x = self.ffn_1(x.permute(0, 2, 1)).permute(
            0, 2, 1
        )  # (B, T, d) -> (B, d, T) -> (B, T, d)
        x = F.relu(x)
        x = F.dropout(x, self.encoder_dropout, training=self.training)
        x = self.ffn_2(x)
        return x


class TransformerEncoderLayer(nn.Module):
    """Pre-norm self-attention layer; optionally condition-adaptive LayerNorm (CLN)."""

    def __init__(
        self,
        encoder_hidden,
        encoder_head,
        conv_filter_size,
        conv_kernel_size,
        encoder_dropout,
        use_cln,
    ):
        super().__init__()
        self.encoder_hidden = encoder_hidden
        self.encoder_head = encoder_head
        self.conv_filter_size = conv_filter_size
        self.conv_kernel_size = conv_kernel_size
        self.encoder_dropout = encoder_dropout
        self.use_cln = use_cln
        if not self.use_cln:
            self.ln_1 = nn.LayerNorm(self.encoder_hidden)
            self.ln_2 = nn.LayerNorm(self.encoder_hidden)
        else:
            self.ln_1 = StyleAdaptiveLayerNorm(self.encoder_hidden)
            self.ln_2 = StyleAdaptiveLayerNorm(self.encoder_hidden)
        self.self_attn = nn.MultiheadAttention(
            self.encoder_hidden, self.encoder_head, batch_first=True
        )
        self.ffn = TransformerFFNLayer(
            self.encoder_hidden,
            self.conv_filter_size,
            self.conv_kernel_size,
            self.encoder_dropout,
        )

    # NOTE(review): "conditon" is a typo for "condition", kept for compatibility
    # with existing positional callers.
    def forward(self, x, key_padding_mask, conditon=None):
        # x: (B, T, d); key_padding_mask: (B, T), mask is 0; condition: (B, T, d)
        # self attention
        residual = x
        if self.use_cln:
            x = self.ln_1(x, conditon)
        else:
            x = self.ln_1(x)
        if key_padding_mask != None:
            # Incoming mask uses 0 for padded positions; MultiheadAttention
            # expects True for positions to ignore, hence the inversion.
            key_padding_mask_input = ~(key_padding_mask.bool())
        else:
            key_padding_mask_input = None
        x, _ = self.self_attn(
            query=x, key=x, value=x, key_padding_mask=key_padding_mask_input
        )
        x = F.dropout(x, self.encoder_dropout, training=self.training)
        x = residual + x
        # ffn
        residual = x
        if self.use_cln:
            x = self.ln_2(x, conditon)
        else:
            x = self.ln_2(x)
        x = self.ffn(x)
        x = residual + x
        return x


class TransformerEncoder(nn.Module):
    """Stack of TransformerEncoderLayer with optional token embedding, positional
    encoding and condition-adaptive normalisation; all hyperparameters may fall
    back to cfg when passed as None."""

    def __init__(
        self,
        enc_emb_tokens=None,
        encoder_layer=4,
        encoder_hidden=256,
        encoder_head=4,
        conv_filter_size=1024,
        conv_kernel_size=5,
        encoder_dropout=0.1,
        use_cln=False,
        use_pe=True,
        cfg=None,
    ):
        super().__init__()
        self.encoder_layer = (
            encoder_layer if encoder_layer is not None else cfg.encoder_layer
        )
        self.encoder_hidden = (
            encoder_hidden if encoder_hidden is not None else cfg.encoder_hidden
        )
        self.encoder_head = (
            encoder_head if encoder_head is not None else cfg.encoder_head
        )
        self.conv_filter_size = (
            conv_filter_size if conv_filter_size is not None else cfg.conv_filter_size
        )
        self.conv_kernel_size = (
            conv_kernel_size if conv_kernel_size is not None else cfg.conv_kernel_size
        )
        self.encoder_dropout = (
            encoder_dropout if encoder_dropout is not None else cfg.encoder_dropout
        )
        self.use_pe = use_pe if use_pe is not None else cfg.use_pe
        self.use_cln = use_cln if use_cln is not None else cfg.use_cln
        if enc_emb_tokens != None:
            self.use_enc_emb = True
            self.enc_emb_tokens = enc_emb_tokens
        else:
            self.use_enc_emb = False
        if self.use_pe:
            self.position_emb = PositionalEncoding(
                self.encoder_hidden, self.encoder_dropout
            )
        self.layers = nn.ModuleList([])
        self.layers.extend(
            [
                TransformerEncoderLayer(
                    self.encoder_hidden,
                    self.encoder_head,
                    self.conv_filter_size,
                    self.conv_kernel_size,
                    self.encoder_dropout,
                    self.use_cln,
                )
                for i in range(self.encoder_layer)
            ]
        )
        if self.use_cln:
            self.last_ln = StyleAdaptiveLayerNorm(self.encoder_hidden)
        else:
            self.last_ln = nn.LayerNorm(self.encoder_hidden)

    def forward(self, x, key_padding_mask, condition=None):
        # A 2-D input is treated as token ids when an embedding table is present.
        if len(x.shape) == 2 and self.use_enc_emb:
            x = self.enc_emb_tokens(x)
            if self.use_pe:
                x = self.position_emb(x)
        else:
            if self.use_pe:
                x = self.position_emb(x)  # (B, T, d)
        for layer in self.layers:
            x = layer(x, key_padding_mask, condition)
        if self.use_cln:
            x = self.last_ln(x, condition)
        else:
            x = self.last_ln(x)
        return x


================================================
FILE: modules/audio_tokenizer/vocos.py
================================================
from typing import Optional, Tuple
import numpy as np
import scipy
import torch
from torch import nn, view_as_real, view_as_complex
# NOTE(review): duplicate import of nn below, kept as-is.
from torch import nn
from torch.nn.utils import weight_norm, remove_weight_norm
from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz


def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
    """
    Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.

    Args:
        x (Tensor): Input tensor.
        clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.

    Returns:
        Tensor: Element-wise logarithm of the input tensor with clipping applied.
""" return torch.log(torch.clip(x, min=clip_val)) def symlog(x: torch.Tensor) -> torch.Tensor: return torch.sign(x) * torch.log1p(x.abs()) def symexp(x: torch.Tensor) -> torch.Tensor: return torch.sign(x) * (torch.exp(x.abs()) - 1) class STFT(nn.Module): def __init__( self, n_fft: int, hop_length: int, win_length: int, center=True, ): super().__init__() self.center = center self.n_fft = n_fft self.hop_length = hop_length self.win_length = win_length window = torch.hann_window(win_length) self.register_buffer("window", window) def forward(self, x: torch.Tensor) -> torch.Tensor: # x: (B, T * hop_length) if not self.center: pad = self.win_length - self.hop_length x = torch.nn.functional.pad(x, (pad // 2, pad // 2), mode="reflect") stft_spec = torch.stft( x, self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window, center=self.center, return_complex=False, ) # (B, n_fft // 2 + 1, T, 2) rea = stft_spec[:, :, :, 0] # (B, n_fft // 2 + 1, T, 2) imag = stft_spec[:, :, :, 1] # (B, n_fft // 2 + 1, T, 2) log_mag = torch.log( torch.abs(torch.sqrt(torch.pow(rea, 2) + torch.pow(imag, 2))) + 1e-5 ) # (B, n_fft // 2 + 1, T) phase = torch.atan2(imag, rea) # (B, n_fft // 2 + 1, T) return log_mag, phase class ISTFT(nn.Module): """ Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges. See issue: https://github.com/pytorch/pytorch/issues/62323 Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs. The NOLA constraint is met as we trim padded samples anyway. Args: n_fft (int): Size of Fourier transform. hop_length (int): The distance between neighboring sliding window frames. win_length (int): The size of window frame and STFT filter. padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". 
""" def __init__( self, n_fft: int, hop_length: int, win_length: int, padding: str = "same" ): super().__init__() if padding not in ["center", "same"]: raise ValueError("Padding must be 'center' or 'same'.") self.padding = padding self.n_fft = n_fft self.hop_length = hop_length self.win_length = win_length window = torch.hann_window(win_length) self.register_buffer("window", window) def forward(self, spec: torch.Tensor) -> torch.Tensor: """ Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram. Args: spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size, N is the number of frequency bins, and T is the number of time frames. Returns: Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal. """ if self.padding == "center": # Fallback to pytorch native implementation return torch.istft( spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True, ) elif self.padding == "same": pad = (self.win_length - self.hop_length) // 2 else: raise ValueError("Padding must be 'center' or 'same'.") assert spec.dim() == 3, "Expected a 3D tensor as input" B, N, T = spec.shape # Inverse FFT ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward") ifft = ifft * self.window[None, :, None] # Overlap and Add output_size = (T - 1) * self.hop_length + self.win_length y = torch.nn.functional.fold( ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length), )[:, 0, 0, pad:-pad] # Window envelope window_sq = self.window.square().expand(1, T, -1).transpose(1, 2) window_envelope = torch.nn.functional.fold( window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length), ).squeeze()[pad:-pad] # Normalize assert (window_envelope > 1e-11).all() y = y / window_envelope return y class MDCT(nn.Module): """ Modified Discrete Cosine Transform (MDCT) module. 
Args: frame_len (int): Length of the MDCT frame. padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". """ def __init__(self, frame_len: int, padding: str = "same"): super().__init__() if padding not in ["center", "same"]: raise ValueError("Padding must be 'center' or 'same'.") self.padding = padding self.frame_len = frame_len N = frame_len // 2 n0 = (N + 1) / 2 window = torch.from_numpy(scipy.signal.cosine(frame_len)).float() self.register_buffer("window", window) pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len) post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N) # view_as_real: NCCL Backend does not support ComplexFloat data type # https://github.com/pytorch/pytorch/issues/71613 self.register_buffer("pre_twiddle", view_as_real(pre_twiddle)) self.register_buffer("post_twiddle", view_as_real(post_twiddle)) def forward(self, audio: torch.Tensor) -> torch.Tensor: """ Apply the Modified Discrete Cosine Transform (MDCT) to the input audio. Args: audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size and T is the length of the audio. Returns: Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames and N is the number of frequency bins. 
""" if self.padding == "center": audio = torch.nn.functional.pad( audio, (self.frame_len // 2, self.frame_len // 2) ) elif self.padding == "same": # hop_length is 1/2 frame_len audio = torch.nn.functional.pad( audio, (self.frame_len // 4, self.frame_len // 4) ) else: raise ValueError("Padding must be 'center' or 'same'.") x = audio.unfold(-1, self.frame_len, self.frame_len // 2) N = self.frame_len // 2 x = x * self.window.expand(x.shape) X = torch.fft.fft( x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1 )[..., :N] res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N) return torch.real(res) * np.sqrt(2) class IMDCT(nn.Module): """ Inverse Modified Discrete Cosine Transform (IMDCT) module. Args: frame_len (int): Length of the MDCT frame. padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". """ def __init__(self, frame_len: int, padding: str = "same"): super().__init__() if padding not in ["center", "same"]: raise ValueError("Padding must be 'center' or 'same'.") self.padding = padding self.frame_len = frame_len N = frame_len // 2 n0 = (N + 1) / 2 window = torch.from_numpy(scipy.signal.cosine(frame_len)).float() self.register_buffer("window", window) pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N) post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2)) self.register_buffer("pre_twiddle", view_as_real(pre_twiddle)) self.register_buffer("post_twiddle", view_as_real(post_twiddle)) def forward(self, X: torch.Tensor) -> torch.Tensor: """ Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients. Args: X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size, L is the number of frames, and N is the number of frequency bins. Returns: Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio. 
""" B, L, N = X.shape Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device) Y[..., :N] = X Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,))) y = torch.fft.ifft( Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1 ) y = ( torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape)) * np.sqrt(N) * np.sqrt(2) ) result = y * self.window.expand(y.shape) output_size = (1, (L + 1) * N) audio = torch.nn.functional.fold( result.transpose(1, 2), output_size=output_size, kernel_size=(1, self.frame_len), stride=(1, self.frame_len // 2), )[:, 0, 0, :] if self.padding == "center": pad = self.frame_len // 2 elif self.padding == "same": pad = self.frame_len // 4 else: raise ValueError("Padding must be 'center' or 'same'.") audio = audio[:, pad:-pad] return audio class FourierHead(nn.Module): """Base class for inverse fourier modules.""" def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, L is the sequence length, and H denotes the model dimension. Returns: Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. """ raise NotImplementedError("Subclasses must implement the forward method.") class ISTFTHead(FourierHead): """ ISTFT Head module for predicting STFT complex coefficients. Args: dim (int): Hidden dimension of the model. n_fft (int): Size of Fourier transform. hop_length (int): The distance between neighboring sliding window frames, which should align with the resolution of the input features. padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". 
""" def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"): super().__init__() out_dim = n_fft + 2 self.out = torch.nn.Linear(dim, out_dim) self.istft = ISTFT( n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding ) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass of the ISTFTHead module. Args: x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, L is the sequence length, and H denotes the model dimension. Returns: Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. """ x = self.out(x).transpose(1, 2) mag, p = x.chunk(2, dim=1) mag = torch.exp(mag) mag = torch.clip( mag, max=1e2 ) # safeguard to prevent excessively large magnitudes # wrapping happens here. These two lines produce real and imaginary value x = torch.cos(p) y = torch.sin(p) # recalculating phase here does not produce anything new # only costs time # phase = torch.atan2(y, x) # S = mag * torch.exp(phase * 1j) # better directly produce the complex value S = mag * (x + 1j * y) audio = self.istft(S) return audio class IMDCTSymExpHead(FourierHead): """ IMDCT Head module for predicting MDCT coefficients with symmetric exponential function Args: dim (int): Hidden dimension of the model. mdct_frame_len (int): Length of the MDCT frame. padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized based on perceptual scaling. Defaults to None. clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False. 
""" def __init__( self, dim: int, mdct_frame_len: int, padding: str = "same", sample_rate: Optional[int] = None, clip_audio: bool = False, ): super().__init__() out_dim = mdct_frame_len // 2 self.out = nn.Linear(dim, out_dim) self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding) self.clip_audio = clip_audio if sample_rate is not None: # optionally init the last layer following mel-scale m_max = _hz_to_mel(sample_rate // 2) m_pts = torch.linspace(0, m_max, out_dim) f_pts = _mel_to_hz(m_pts) scale = 1 - (f_pts / f_pts.max()) with torch.no_grad(): self.out.weight.mul_(scale.view(-1, 1)) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass of the IMDCTSymExpHead module. Args: x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, L is the sequence length, and H denotes the model dimension. Returns: Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. """ x = self.out(x) x = symexp(x) x = torch.clip( x, min=-1e2, max=1e2 ) # safeguard to prevent excessively large magnitudes audio = self.imdct(x) if self.clip_audio: audio = torch.clip(x, min=-1.0, max=1.0) return audio class IMDCTCosHead(FourierHead): """ IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p) Args: dim (int): Hidden dimension of the model. mdct_frame_len (int): Length of the MDCT frame. padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False. """ def __init__( self, dim: int, mdct_frame_len: int, padding: str = "same", clip_audio: bool = False, ): super().__init__() self.clip_audio = clip_audio self.out = nn.Linear(dim, mdct_frame_len) self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass of the IMDCTCosHead module. 
Args: x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, L is the sequence length, and H denotes the model dimension. Returns: Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. """ x = self.out(x) m, p = x.chunk(2, dim=2) m = torch.exp(m).clip( max=1e2 ) # safeguard to prevent excessively large magnitudes audio = self.imdct(m * torch.cos(p)) if self.clip_audio: audio = torch.clip(x, min=-1.0, max=1.0) return audio class ConvNeXtBlock(nn.Module): """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal. Args: dim (int): Number of input channels. intermediate_dim (int): Dimensionality of the intermediate layer. layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. Defaults to None. adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. None means non-conditional LayerNorm. Defaults to None. """ def __init__( self, dim: int, intermediate_dim: int, layer_scale_init_value: float, adanorm_num_embeddings: Optional[int] = None, ): super().__init__() self.dwconv = nn.Conv1d( dim, dim, kernel_size=7, padding=3, groups=dim ) # depthwise conv self.adanorm = adanorm_num_embeddings is not None if adanorm_num_embeddings: self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6) else: self.norm = nn.LayerNorm(dim, eps=1e-6) self.pwconv1 = nn.Linear( dim, intermediate_dim ) # pointwise/1x1 convs, implemented with linear layers self.act = nn.GELU() self.pwconv2 = nn.Linear(intermediate_dim, dim) self.gamma = ( nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) if layer_scale_init_value > 0 else None ) def forward( self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None ) -> torch.Tensor: residual = x x = self.dwconv(x) x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) if self.adanorm: assert cond_embedding_id is not None x = self.norm(x, cond_embedding_id) 
        else:
            x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)
        x = residual + x
        return x


class AdaLayerNorm(nn.Module):
    """
    Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes

    Args:
        num_embeddings (int): Number of embeddings.
        embedding_dim (int): Dimension of the embeddings.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.dim = embedding_dim
        self.scale = nn.Embedding(
            num_embeddings=num_embeddings, embedding_dim=embedding_dim
        )
        self.shift = nn.Embedding(
            num_embeddings=num_embeddings, embedding_dim=embedding_dim
        )
        # Identity at init: scale 1, shift 0.
        torch.nn.init.ones_(self.scale.weight)
        torch.nn.init.zeros_(self.shift.weight)

    def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
        scale = self.scale(cond_embedding_id)
        shift = self.shift(cond_embedding_id)
        x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
        x = x * scale + shift
        return x


class ResBlock1(nn.Module):
    """
    ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
    but without upsampling layers.

    Args:
        dim (int): Number of input channels.
        kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
        dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
            Defaults to (1, 3, 5).
        lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
            Defaults to 0.1.
        layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
            Defaults to None.
    """

    def __init__(
        self,
        dim: int,
        kernel_size: int = 3,
        dilation: Tuple[int, int, int] = (1, 3, 5),
        lrelu_slope: float = 0.1,
        layer_scale_init_value: Optional[float] = None,
    ):
        super().__init__()
        self.lrelu_slope = lrelu_slope
        # First conv of each pair uses an increasing dilation; padding keeps length.
        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    nn.Conv1d(
                        dim,
                        dim,
                        kernel_size,
                        1,
                        dilation=dilation[0],
                        padding=self.get_padding(kernel_size, dilation[0]),
                    )
                ),
                weight_norm(
                    nn.Conv1d(
                        dim,
                        dim,
                        kernel_size,
                        1,
                        dilation=dilation[1],
                        padding=self.get_padding(kernel_size, dilation[1]),
                    )
                ),
                weight_norm(
                    nn.Conv1d(
                        dim,
                        dim,
                        kernel_size,
                        1,
                        dilation=dilation[2],
                        padding=self.get_padding(kernel_size, dilation[2]),
                    )
                ),
            ]
        )
        # Second conv of each pair is undilated.
        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    nn.Conv1d(
                        dim,
                        dim,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=self.get_padding(kernel_size, 1),
                    )
                ),
                weight_norm(
                    nn.Conv1d(
                        dim,
                        dim,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=self.get_padding(kernel_size, 1),
                    )
                ),
                weight_norm(
                    nn.Conv1d(
                        dim,
                        dim,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=self.get_padding(kernel_size, 1),
                    )
                ),
            ]
        )
        # Optional per-branch layer-scale parameters (None disables scaling).
        self.gamma = nn.ParameterList(
            [
                (
                    nn.Parameter(
                        layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
                    )
                    if layer_scale_init_value is not None
                    else None
                ),
                (
                    nn.Parameter(
                        layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
                    )
                    if layer_scale_init_value is not None
                    else None
                ),
                (
                    nn.Parameter(
                        layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
                    )
                    if layer_scale_init_value is not None
                    else None
                ),
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
            xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
            xt = c1(xt)
            xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
            xt = c2(xt)
            if gamma is not None:
                xt = gamma * xt
            x = xt + x
        return x

    def remove_weight_norm(self):
        # Resolves to the module-level torch.nn.utils.remove_weight_norm import.
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)

    @staticmethod
    def get_padding(kernel_size: int, dilation: int = 1) -> int:
        # "Same" padding for a dilated conv with stride 1.
        return int((kernel_size * dilation - dilation) / 2)


class Backbone(nn.Module):
    """Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Args:
            x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
                C denotes output features, and L is the sequence length.

        Returns:
            Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
                and H denotes the model dimension.
        """
        raise NotImplementedError("Subclasses must implement the forward method.")


class VocosBackbone(Backbone):
    """
    Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization

    Args:
        input_channels (int): Number of input features channels.
        dim (int): Hidden dimension of the model.
        intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
        num_layers (int): Number of ConvNeXtBlock layers.
        layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
        adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
            None means non-conditional model. Defaults to None.
""" def __init__( self, input_channels: int, dim: int, intermediate_dim: int, num_layers: int, layer_scale_init_value: Optional[float] = None, adanorm_num_embeddings: Optional[int] = None, ): super().__init__() self.input_channels = input_channels self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3) self.adanorm = adanorm_num_embeddings is not None if adanorm_num_embeddings: self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6) else: self.norm = nn.LayerNorm(dim, eps=1e-6) layer_scale_init_value = layer_scale_init_value or 1 / num_layers self.convnext = nn.ModuleList( [ ConvNeXtBlock( dim=dim, intermediate_dim=intermediate_dim, layer_scale_init_value=layer_scale_init_value, adanorm_num_embeddings=adanorm_num_embeddings, ) for _ in range(num_layers) ] ) self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, (nn.Conv1d, nn.Linear)): nn.init.trunc_normal_(m.weight, std=0.02) nn.init.constant_(m.bias, 0) def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: bandwidth_id = kwargs.get("bandwidth_id", None) x = self.embed(x) if self.adanorm: assert bandwidth_id is not None x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id) else: x = self.norm(x.transpose(1, 2)) x = x.transpose(1, 2) for conv_block in self.convnext: x = conv_block(x, cond_embedding_id=bandwidth_id) x = self.final_layer_norm(x.transpose(1, 2)) return x class VocosResNetBackbone(Backbone): """ Vocos backbone module built with ResBlocks. Args: input_channels (int): Number of input features channels. dim (int): Hidden dimension of the model. num_blocks (int): Number of ResBlock1 blocks. layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None. 
""" def __init__( self, input_channels, dim, num_blocks, layer_scale_init_value=None, ): super().__init__() self.input_channels = input_channels self.embed = weight_norm( nn.Conv1d(input_channels, dim, kernel_size=3, padding=1) ) layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3 self.resnet = nn.Sequential( *[ ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value) for _ in range(num_blocks) ] ) def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: x = self.embed(x) x = self.resnet(x) x = x.transpose(1, 2) return x class Vocos(nn.Module): def __init__( self, input_channels: int = 256, dim: int = 384, intermediate_dim: int = 1152, num_layers: int = 8, adanorm_num_embeddings: int = 4, n_fft: int = 800, hop_size: int = 200, padding: str = "same", ): super().__init__() self.backbone = VocosBackbone( input_channels=input_channels, dim=dim, intermediate_dim=intermediate_dim, num_layers=num_layers, adanorm_num_embeddings=adanorm_num_embeddings, ) self.head = ISTFTHead(dim, n_fft, hop_size, padding) def forward(self, x): x = self.backbone(x) x = self.head(x) return x[:, None, :] ================================================ FILE: modules/tokenizer/tokenizer.py ================================================ from abc import ABC from abc import abstractmethod import sentencepiece as spm from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model from typing import Any, Union import numpy as np from dataclasses import dataclass def encode_pieces(sp_model: spm.SentencePieceProcessor, text: str, sample=False): """Encode text into sentence pieces. 
    Only supports py3."""
    if not sample:
        pieces = sp_model.EncodeAsPieces(text)
    else:
        # Subword-regularisation sampling (nbest=64, alpha=0.1).
        pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1)
    return pieces


class AbstractTokenizer(ABC):
    """Abstract class for tokenizer."""

    def __init__(self, name):
        self.name = name
        super().__init__()

    @property
    @abstractmethod
    def vocab_size(self):
        pass

    @property
    @abstractmethod
    def vocab(self):
        """Dictionary from vocab text token to id token."""
        pass

    @property
    @abstractmethod
    def inv_vocab(self):
        """Dictionary from vocab id token to text token."""
        pass

    @abstractmethod
    def tokenize(self, text):
        pass

    def detokenize(self, token_ids):
        raise NotImplementedError('detokenizer is not implemented for {} '
                                  'tokenizer'.format(self.name))

    @property
    def cls(self):
        raise NotImplementedError('CLS is not provided for {} '
                                  'tokenizer'.format(self.name))

    @property
    def sep(self):
        raise NotImplementedError('SEP is not provided for {} '
                                  'tokenizer'.format(self.name))

    @property
    def pad(self):
        raise NotImplementedError('PAD is not provided for {} '
                                  'tokenizer'.format(self.name))

    @property
    def eod(self):
        raise NotImplementedError('EOD is not provided for {} '
                                  'tokenizer'.format(self.name))

    @property
    def mask(self):
        raise NotImplementedError('MASK is not provided for {} '
                                  'tokenizer'.format(self.name))


class SPieceTokenizer(AbstractTokenizer):
    """SentencePiece-backed tokenizer with byte-count bookkeeping and extra ids."""

    def __init__(self, spm_file: str):
        super().__init__('Sentence Piece')
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(spm_file)
        # NOTE(review): looks up the id of the empty string as the EOD id —
        # presumably a dedicated EOD piece was intended; verify against the
        # 160k.model vocabulary.
        self.eod_id = self.get_token_id('')
        self.special_ids = set([
            self.sp_model.pad_id(),
            self.sp_model.eos_id(),
            self.sp_model.bos_id(),
            self.sp_model.unk_id(),
            self.eod_id,
        ])
        # initialize index_2_bytes
        self._initialize_index_2_bytes()

    def encode_pieces(self, text: str, sample=False):
        # Same behaviour as the module-level encode_pieces helper.
        if not sample:
            pieces = self.sp_model.EncodeAsPieces(text)
        else:
            pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1)
        return pieces

    def _initialize_index_2_bytes(self):
        # Precompute UTF-8 byte length of each piece (word-boundary marker stripped).
        proto = sp_pb2_model.ModelProto()
        proto.ParseFromString(self.sp_model.serialized_model_proto())
        self.index_2_numbytes = [0] * len(proto.pieces)
        for i, p in enumerate(proto.pieces):
            clean_piece = p.piece.replace('▁', '')
            self.index_2_numbytes[i] = len(clean_piece.encode('utf-8'))

    def set_add_dummy_prefix(self, add_dummy_prefix: bool = False):
        # Rewrites the serialized model's normalizer spec and reloads it.
        proto = sp_pb2_model.ModelProto()
        proto.ParseFromString(self.sp_model.serialized_model_proto())
        if proto.normalizer_spec.add_dummy_prefix != add_dummy_prefix:
            proto.normalizer_spec.add_dummy_prefix = add_dummy_prefix
            self.sp_model.LoadFromSerializedProto(proto.SerializeToString())
            print(f"> set add_dummy_prefix to {add_dummy_prefix} ...", flush=True)

    def add_special_id(self, token_id):
        self.special_ids.add(token_id)

    @property
    def has_dummy_prefix(self):
        pieces = self.sp_model.EncodeAsPieces("hello")
        return pieces[0].startswith('▁')

    @property
    def vocab_size(self):
        return self.sp_model.GetPieceSize()

    @property
    def vocab(self):
        """Dictionary from vocab text token to id token."""
        return self.sp_model

    def get_array_bytes(self, array):
        # Out-of-vocab ids (e.g. audio tokens appended beyond the text vocab)
        # are counted as 2 bytes each.
        return sum(self.index_2_numbytes[i] if i < self.vocab_size else 2 for i in array)

    def tokenize(self, text):
        tokens = encode_pieces(self.sp_model, text)
        return self.convert_tokens_to_ids(tokens)

    def encode(self, text: str, bos: bool=False, eos: bool=False, **kwargs: Any) -> list[int]:
        tokens = self.encode_pieces(text)
        t = self.convert_tokens_to_ids(tokens)
        if bos:
            t.insert(0, self.bos_id)
        if eos:
            t.append(self.eos_id)
        return t

    def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]:
        if isinstance(tokens, str):
            return self.sp_model.PieceToId(tokens)
        return [self.sp_model.PieceToId(token) for token in tokens]

    def detokenize(self, token_ids):
        # Accepts a list or any object with .tolist() (e.g. a numpy array / tensor).
        if isinstance(token_ids, list):
            pieces = [self.sp_model.IdToPiece(id) for id in token_ids]
        else:
            pieces = [self.sp_model.IdToPiece(id) for id in token_ids.tolist()]
        return pieces

    def decode(self, token_ids: Union[int, list[int]], skip_special_tokens: bool = False) -> str:
        assert not skip_special_tokens, "skip_special_tokens is not supported"
        if isinstance(token_ids, (int, np.integer)):
            return self.detokenize([int(token_ids)])[0]
        return ''.join(self.detokenize(token_ids))

    def get_token_id(self, token):
        return self.sp_model.PieceToId(token)

    # NOTE(review): the abstract base declares inv_vocab as a @property, but this
    # override is a plain method — accessing .inv_vocab yields the bound method,
    # not {}; confirm intent before relying on it.
    def inv_vocab(self):
        # TODO: to be implemented
        return {}

    def decode_pieces(self, pieces):
        return self.sp_model.DecodePieces(pieces)

    @property
    def eod(self):
        return self.eod_id

    @property
    def pad_id(self):
        return self.sp_model.pad_id()

    @property
    def eos_id(self):
        return self.sp_model.eos_id()

    @property
    def bos_id(self):
        return self.sp_model.bos_id()

    @property
    def unk_id(self):
        return self.sp_model.unk_id()

    @property
    def pad_token_id(self):
        return self.pad_id

    @property
    def eos_token_id(self):
        return self.eos_id


@dataclass
class ExtraTokens:
    # Ids of the reserved [extra_id_*] control pieces used for chat framing.
    msg_end: int
    user_msg_start: int
    assistant_msg_start: int
    name_end: int
    media_begin: int
    media_content: int
    media_end: int
    pad: int


def instantiate_extra_tokens(tokenizer: AbstractTokenizer):
    # Resolve the reserved control pieces to concrete ids for this tokenizer.
    if isinstance(tokenizer, SPieceTokenizer):
        map_fn = lambda x: tokenizer.convert_tokens_to_ids(x)
    else:
        raise ValueError(f"Invalid tokenizer type: {type(tokenizer)}")
    return ExtraTokens(
        msg_end=map_fn('[extra_id_0]'),
        user_msg_start=map_fn('[extra_id_1]'),
        assistant_msg_start=map_fn('[extra_id_2]'),
        name_end=map_fn('[extra_id_12]'),
        media_begin=map_fn('[extra_id_13]'),
        media_content=map_fn('[extra_id_14]'),
        media_end=map_fn('[extra_id_15]'),
        pad=tokenizer.pad_id
    )


def get_tokenizer_and_extra_tokens():
    # Loads the pretrained SentencePiece model and disables the dummy prefix so
    # tokenisation is prefix-space free.
    sp_model_path = "resources/tokenizer/160k.model"
    tokenizer = SPieceTokenizer(sp_model_path)
    tokenizer.set_add_dummy_prefix(False)
    extra_tokens = instantiate_extra_tokens(tokenizer)
    return tokenizer, extra_tokens


================================================
FILE: readme.md
================================================
# MoonCast: High-Quality Zero-Shot Podcast Generation

## Overview Demo page: [demo](https://mooncastdemo.github.io) Paper: [paper](https://arxiv.org/abs/2503.14345) 2025/03/26 UPDATE: We also host a [HuggingFace space](https://huggingface.co/spaces/jzq11111/mooncast) for testing audio generation. We open-source this system to advance the field of human-like speech synthesis. Our goal is to create more natural and expressive synthetic voices that bridge the gap between machines and humans. We hope this project will inspire researchers and developers to explore new possibilities in voice technology. We warmly welcome contributions from anyone interested in this project. Whether through code, documentation, feedback, or sharing your insights, every input helps make this project better. ## Environment Setup - Create conda environment. ``` sh conda create -n mooncast -y python=3.10 conda activate mooncast pip install -r requirements.txt pip install flash-attn --no-build-isolation pip install huggingface_hub pip install gradio==5.22.0 ``` - Download the pretrained weights. ``` sh python download_pretrain.py ``` ## Example Usage ### Script Generation For podcast script generation, we utilize specific LLM prompts defined in ``zh_llmprompt_script_gen.py`` (Chinese) and ``en_llmprompt_script_gen.py`` (English). We have selected the [Gemini 2.0 Pro Experimental 02-05](https://cloud.google.com/vertex-ai/generative-ai/docs/gemini-v2#2.0-pro) model for this task, favoring its ability to produce conversational language, design natural dialogue, and offer broad topic coverage. Our process involves two stages: first, we generate a concise summary by providing the input knowledge source as an attachment along with the ``INPUT2BRIEF`` prompt. Subsequently, this summary, paired with the ``BRIEF2SCRIPT`` prompt, is used to generate the final podcast script in JSON format. ### Speech Generation The audio prompts used in this project are sourced from publicly available podcast segments and are intended solely for demonstration purposes. 
Redistribution of these audio files, whether in their original form or as generated audio, is strictly prohibited. If you have any concerns or questions regarding the use of these audio files, please contact us at juzeqian@mail.ustc.edu.cn ```sh CUDA_VISIBLE_DEVICES=0 python inference.py ``` 2025/03/26 UPDATE: We added a Gradio-based user interface for audio generation. Deploy it locally using: ```sh CUDA_VISIBLE_DEVICES=0 python app.py ``` ## Disclaimer This project is intended for **research purposes only**. We strongly encourage users to **use this project and its generated audio responsibly**. **We are not responsible for any misuse or abuse of this project**. By using this project, you agree to comply with all applicable laws and ethical guidelines. ================================================ FILE: requirements.txt ================================================ torch==2.3.1 torchaudio==2.3.1 sentencepiece==0.2.0 protobuf numpy librosa==0.9.1 pyyaml transformers safetensors einops scipy timm==1.0.7 torchdyn librosa accelerate==0.26.0 ninja cryptography ================================================ FILE: test/test_audio_detokenizer.py ================================================ import sys import torch sys.path.append('.') from modules.audio_tokenizer.audio_tokenizer import get_audio_tokenizer from modules.audio_detokenizer.audio_detokenizer import get_audio_detokenizer, detokenize import torchaudio import librosa if __name__ == '__main__': audio_tokenizer = get_audio_tokenizer() audio_detokenizer = get_audio_detokenizer() input_wav_16k, _ = librosa.load("en_prompt0.wav", sr=16000) input_wav_24k, _ = librosa.load("en_prompt0.wav", sr=24000) prompt_sec = 1 prompt_wav_16k = input_wav_16k[:16000*prompt_sec] prompt_wav_24k = input_wav_24k[:24000*prompt_sec] input_wav_16k = input_wav_16k[16000*prompt_sec:] input_wav_24k = input_wav_24k[24000*prompt_sec:] prompt_wav_24k = torch.tensor(prompt_wav_24k)[None, :].cuda() prompt_wav_16k =
import sys
sys.path.append('.')
from modules.audio_tokenizer.audio_tokenizer import get_audio_tokenizer
import torch


def _main():
    """Smoke test: a (1, 8000) zeros waveform must map to a fixed token sequence."""
    tokenizer = get_audio_tokenizer()
    silence = torch.zeros(1, 8000)
    tokens = tokenizer.tokenize(silence).cpu().numpy().tolist()
    expected = [[765, 3512, 7469, 7469, 7028, 2567, 6008, 7469,
                 6217, 2567, 7649, 7469, 3292, 2567, 7649, 7469,
                 3292, 2567, 948, 7469, 3292, 2567, 948, 7469]]
    assert tokens == expected
    print("All tests passed!")


if __name__ == '__main__':
    _main()
extra_tokens.assistant_msg_start == 262 assert extra_tokens.name_end == 272 assert extra_tokens.media_begin == 273 assert extra_tokens.media_content == 274 assert extra_tokens.media_end == 275 assert [tokenizer.convert_tokens_to_ids(i) for i in ['<0x0A>', '', '[extra_id_0]']] == [14, 1, 260] print("All tests passed!") ================================================ FILE: zh_llmprompt_script_gen.py ================================================ # INPUT -> BRIEF -> SCRIPT INPUT2BRIEF = ''' ### 任务说明 请按照以下结构总结输入文件, 普通文本格式。总结应当有创造性,保证信息全面,包含所有有趣、不常见、有价值的观点和信息。 - **文本要求**: 1. 直接输出结果,不要包含任何额外信息。 2. 总结文本用中文。允许少部分实体名词、专有名词、缩写等使用英文。 3. 不要包含任何数学公式。 4. 不要修改原文的任何实体名词、专有名词、缩写等。除非有常见译名,否则不要翻译实体名词。不要试图修改实体名词意思。 5. **请智慧地将简写中的数字转化。如简称里“a2b”实际代表“a to b”,而不是“a二b";简称里“a4b”实际代表“a for b”, 而不是“a四b"; “v2”可能代表“version 二”, 也可以进一步翻译成“第二代”。请提供原始简称,和你认为合适的中文翻译。** ### 标题和作者 - **语言要求**:中文,书面语。 - **内容要求**:提供文档的标题和作者。简要概括文档的主题和作者的背景。确保包含所有重要信息,不要有遗漏,尽可能保留足够的信息。 ### 摘要 - **语言要求**:中文,书面语。 - **内容要求**: 1. 本文做了什么事情。 2. 之前有没有别人做过这个事情。 3. 如果有别人做过,那本文为什么还需要做。 4. 本文具体怎么做的。 5. 本文做的怎么样。 - **附加要求**:额外提供一个段落,解释本节中可能让听众困惑的术语、概念、方法等,确保不了解领域的读者也能理解。专有名词的解释需贴合原文,覆盖所有可能的困惑点,包括缩写名词、专有名词、实体名等。 ### 主要主题和概念 - **语言要求**:中文,书面语。 - **内容要求**:每个主题概念需按照3W原则组织,包括: - **What**:界定问题,搞清楚问题是什么。 - **Why**:分析问题,结构化分析问题本质原因是什么。 - **How**:解决问题,文档如何解决问题。 - **附加要求**: 1. 确保主题概念包含所有重要信息,不要有遗漏,主题概念需足够详细,充分阐述What和Why两个部分。 2. How部分不要包含数学公式等技术细节。要用大众理解的语言充分概括。 3. 各主题概念间不要互相重叠,保证逻辑清晰。 4. 额外提供一个段落,解释本节中可能让听众困惑的术语、概念、方法等,确保不了解领域的读者也能理解。专有名词的解释需贴合原文,覆盖所有可能的困惑点,包括缩写名词、专有名词、实体名等。 ### 重要引文 - **语言要求**:中文,书面语。 - **内容要求**:按照以下结构组织内容: 1. **论点**:需要证明什么。 2. **论据**:用于证明论点的材料。 3. **论证**:运用论据证明论点的过程。 - **附加要求**: 1. 论据和论证思路需严格来源于原文,不要进行任何虚构。 2. 确保引文内容充分,不要有遗漏,尽可能保留足够的信息,不要进行任何精简。引文避免使用数学公式。 3. 
额外提供一个段落,解释本节中可能让听众困惑的术语、概念、方法等,确保不了解领域的读者也能理解。专有名词的解释需贴合原文,覆盖所有可能的困惑点,包括缩写名词、专有名词、实体名等。 ### 总结 - **语言要求**:中文,书面语。 - **内容要求**:突出文档最重要、最吸引人眼球的部分。与摘要相比,需更结合主题概念的具体内容,对摘要进行补充。可包含未来改进方向、当前应用场景、当前存在问题等。 ''' BRIEF2SCRIPT = ''' ## 一、任务概述 请根据提供的总结文本,和你对这方面了解的知识,生成一个生动的中文播客文字剧本。 剧本包含两位说话人交替发言。输出格式为 JSON 可解析的**列表**。列表里每条发言是一个**字典**,包含“speaker”和“text”字段。示例格式:`[{{"speaker": "1", "text": "xxx"}}]`。“speaker”字段是说话人身份(1表示主持人,2表示嘉宾),“text”字段是具体发言内容。输出直接从json的代码块开始,不要包含任何额外的信息。 ## 二、内容与结构要求 ### (一)文本内容 - 总结性文本包含所有重要信息,需全面挑选并纳入剧本。 - 通过两位说话人的对话形式展示信息,保持创作性,适当抽象不重要的细节。例如,听众不关心具体的测试名称,而关心测试的任务,结果和分析。 ### (二)结构设计 - **开场白**:引入主题,简要介绍讨论内容,不提及说话人姓名。 - **关键主题讨论**:逐字阅读总结文本,讨论重要主题。 - **结束语**:简洁总结讨论亮点,并对未来或技术发展进行展望。 ## 三、语言风格 - 文本要尽量口语化,接近自动语音识别的结果,包含填充词如“嗯”、“啊”、“呃”,"呢","这个","其实","就是","然后"等,响应词如"嗯。"或“是。”等。多用口语化的表达方式,允许重复,语法可以不那么正式。避免直接照搬总结文本里的书面语。不要用括号或语音识别通常不会出现的符号。 句中的空格代表短停顿,逗号表示稍长停顿,句号表示长停顿。可能存在因口音带来的同音识别错误。提问需要非常口语化。总之,就是要像平时聊天一样自然。示例如下: [ {{ "speaker": "0", "text": "欢迎收听今天的播客。那我们这一集是要聊什么东西呢?", }}, {{ "speaker": "1", "text": "我们要聊星座。", }}, {{ "speaker": "0", "text": "星座嘛,就是,他是一个好跟新的朋友认识的时候一个聊天的话题。", }}, {{ "speaker": "1", "text": "没错,现我觉得在现在已经从你好,变成了诶,请问你的星座是什么呢?。", }}, {{ "speaker": "0", "text": "对,那我天枰座。", }}, {{ "speaker": "1", "text": "那,我是摩羯座。", }}, {{ "speaker": "0", "text": "摩羯座,那你会觉得就是星座,是一个可以相信的东西吗?", }}, {{ "speaker": "1", "text": "我本人其实不太相信星座诶,在一开始的时候。我就跟大部分不相信星座的一样,觉得,呃,你总能把人就分成十二种,然后呢就它讲的就是对的。", }}, {{ "speaker": "0", "text": "啊,所以就是,会觉得说把星座就是单纯把人分成十二种事件很粗略,不太有什么科学根据的事情。", }}, {{ "speaker": "1", "text": "嗯,对,会这样觉得。", }}, {{ "speaker": "0", "text": "嗯。", }}, {{ "speaker": "1", "text": "会无法理解,到底是,那这一开始定出这十二种人格的是谁啊?", }}, {{ "speaker": "0", "text": "对,就是凭什么他可以决定,我们就是这十二种人格。", }}, {{ "speaker": "1", "text": "嗯?", }}, {{ "speaker": "0", "text": "为什么不是十三、十四或者更多的种类。", }}, {{ "speaker": "1", "text": "对,没有错。", }}, {{ "speaker": "0", "text": "对。那,所以你会觉得说那种就是什么星座的心理分析是完全不可信的,还是其实也会很常去看一下,呃,类似的这种星座测验。", }}, {{ "speaker": "1", "text": 
"其实我刚说一开始不相信啊,我真的是到后期比较相信。然后后期会开始相信的是因为,呃,要去找一些我自己没有办法有方法去理解的人,因为认识那样子的人,他就是暧昧对象,必须要了解他到底是怎样的人,可是没有其他的依据的时候呢,我就偷偷开始看起了星座,然后就偷偷我觉得,好像讲得有那么一点准,然后就会开始看了。", }}, {{ "speaker": "0", "text": "哦,所以感觉有点像是说在从,星座的这种描述测验中去找说,你想要从这个东西,去对那个人有更深一层的了解的感觉。", }}, {{ "speaker": "1", "text": "对,而且通常他会讲到一两个你好你觉得好像是那样子的点,那你就会想要看更多,然后就好像就跟着就开始相信这个东西了。", }}, {{ "speaker": "0", "text": "哦,嗯,诶,所以你是什么什么星座的?", }}, {{ "speaker": "1", "text": "就我刚刚说我是摩羯座啊。", }} ] ### (二)标点符号 - 使用中文标点符号,避免英文标点。 - 剧本文本只使用逗号,句号和问号。禁止使用叹号。禁止使用省略号('…')、括号、引号(包括‘’“”'")或波折号,否则视为不合格。 - 如果被对方的响应词等打断,本句句末是逗号,而不是句号。 ## 四、信息组织与逻辑 ### (一)引用问题 - 由于听众看不到总结性文本,引用需确保上下文完整,确保听众能理解。 - 避免直接复述,需用自己的话解释引用内容。 - 总结文本里提供了对专业术语的解释。你需要保证你剧本里的专业术语尽可能被充分解释。专业术语的解释请具有创意,不要简单地创作成“这个是什么意思”这样的句子。可以通过举例、比喻等方式进行解释,但需要进一步说明比喻的合理性。可以由对方提问后进行解释,也可以自行解释。没有提到的专业名词不需要解释。提到的专业名词不一定要立即进行解释,可以和别的专业名词一起解释。总结文本中的专业术语可能与文字内容存在差异,你需要根据上下文合理解释。 ### (二)信息密度 - 确保信息密度适中,避免过高或过低。适当的信息密度希望让没有相关背景知识的听众,快速理解文档里在做什么,为什么这么做,以及如何做。 - 为了避免信息密度过高,剧本不能讨论数学公式、测试设置、实验指标等细节,而应该用简单概括性语言描述。 - 为了避免信息密度过低,剧本每个主题需不少于4次发言,避免停留于关键词的简单罗列。会从尽可能从不同角度讨论,不局限于提供的总结文本。总结文本高度概括,剧本应当将其展开,讨论更多细节。你可以利用自己知识,补充背景知识,举例说明等方式,让听众更好地理解。 - 提高信息密度技巧: 1. 嵌入金句。在剧本中加入令人印象深刻,眼前一亮的句子,可以是自己创作,也可以是引用他人。 2. 增加知识点: 在剧本中适当增加知识点,能让听众听完更有收获。 3. 引入新信息:剧本中加入新的概念,引起用户好奇,特别是听众不知道但想知道的信息,这种非常重要。 4. 逆向思维: 加入不同角度的信息,打破用户熟悉的视角,提出不一样的观点。 5. 制造反差冲击: 剧本可以对用户熟知的认知进行非常规(出乎意料)但合理的描述,形成与他预期的反差,这种反差是信息密度。 - 降低信息密度技巧: 1. 使用短句:简洁明了,易于理解,让叙述更紧凑。不要一句话里有过多的信息。 2. 描述细节:模糊不清,抽象的信息难以让听众建立认知,而细节越多,越能有画面感,容易阅读 3. 多进行场景化塑造: 场景是具象的,有画面的。 听众能轻松接收传达的信息,还能让人触景生情。 4. 多讲事实:讲事实才能更显真实,读的人才能更感同身受,这样文案信息密度更低。 5. 多讲故事:讲自己的故事,讲身边的故事,讲听说的故事,故事能把听众带入场景,更利于聚精会神地收听。 6. 多用动词和具体名词:动词和具体的名词更容易让听众浮现画面,而形容词会让复杂的文案更难理解。 7. 
避免使用数学公式: 数学公式不利于大众理解。 ## 五、对话设计 ### (一) 说话人角色 - 剧本中包含主持人和嘉宾。其中说话人1是主持人,负责节目开场和结束,擅长利用提问控制对话节奏,用生动的例子让知识不枯燥。说话人2是嘉宾,是主要负责文档内容的介绍,对该领域有惊人的知识储备,擅长有条理地语言组织,通俗地讲解内容。 - 两位说话人热情开朗,喜欢结合个人故事或者实例进行讨论,给听众带来直观的体验。大家乐于讨论离题的故事。 - 两位说话人积极互动,会经常用"嗯"等打断词表示对对方的认同。需要将响应词按照时间点插入对话。被打断前的句子句末用逗号,而不是句号。 - 保证说话人角色统一,不要出现主持人介绍技术细节,或者引导主持人讨论主题等情况。 - 主持人根据嘉宾的回答,逐步增加对该领域的认知。但主持人不一定立刻能理解,也不一定理解地完全正确。主持人可以表达不理解或者提出一些常人可能会存在的疑问。这种情况下,嘉宾会进一步用更通俗的语言解释,或者针对性地解答常人常有的疑问或者误解。这种互动相比于永远正确的主持人和嘉宾更加真实,也更利于观众地理解。 ### (二) 主题顺序安排 - 主持人会根据总结性文本,将主题排列,并保证主题间有逻辑关联,如从整体过渡到细节,从细节过渡到整体,从原因过渡到结果,从技术过渡到应用等。 - 主持人会引导对话节奏,按照总结性文本的主题顺序进行讨论。嘉宾不应该干扰主题过渡。 ### (三) 知识速率 - 剧本中知识速率需要合理,不能短时间过快引入大量知识。知识不能突然增加,要逐渐引入,确保听众能够理解。 - 听众视角:充分考虑听众感受,从听众视角进行剧本创作。必须保证剧本不包含详细数学公式,而应该用通俗的语言介绍。确保剧本内容易懂,不要过于专业化。 - 无论是与主题相关的信息,还是离题的故事,都要按照你的知识进行充分地讨论,切忌简单地提一句而没有展开。要保证剧本足够真实,符合日常对话的逻辑,保证说话人间足够的尊重,不敷衍,不随意打断。 ## 六、其他要求 ### (一) 外语数字: 1. 剧本将用于中文播客内容的录制。请保证大部分外语和数字转换为中文,以便于模型能正确识别读音。 2. 请根据上下文,智慧地判断正确的读音。例如,“2021”如果表达年份,应当转换为“二零二一”。但如果表示数字,应当转换为“两千零二十一”。一些英文简称里常用数字代表英文单词,比如“a2b”代表“a to b”,“a4b”代表“a for b”,请保证不要简单转换为中文数字,而是根据上下文,将其翻译成合适的中文。 3. 对于一些不常见的英文简写,如果根据上下文判断读音需要逐个字母阅读,则须保证每个字母间留有空格,如“AI”添加空格为“A I”,以避免模型误认为是一个单词。除非实体名字有常见的中文翻译,否则不要翻译实体名字。 ### (二) 剧本长度 1. 请控制"text"值的文本总长度不超过3000字符,且不超过60个发言,否则不合格。请选择技术细节,主题概念进行讨论。不要为了字数限制缩短每个话题讨论的深度,不要局限于总结文本,充分发挥你的知识。 INPUT: {BRIEF} 再次强调: 说话人1是主持人, 说话人2是嘉宾。说话人和嘉宾没有姓名。剧本文本只使用逗号,句号和问号。禁止使用叹号。禁止使用省略号('…')、括号、引号(包括‘’“”'")或波折号,否则视为不合格。请优先保证每个话题讨论的深度,不要局限于总结文本,利用你的知识,补充背景知识,举例说明等方式,让听众更好地理解。 请保证大部分外语和数字转换为中文,以便于模型能正确识别读音。在技术文档里,英文简称常用数字代表英文单词,比如“a2b”代表“a to b”,“a4b”代表“a for b”,请保证不要简单转换为中文数字,而是根据上下文,将其翻译成合适的中文。 OUTPUT: '''