[
  {
    "path": ".gitignore",
    "content": "*.safetensors\n*.pt\n*.vscode\n**/__pycache__/\nmodules/audio_detokenizer/vocoder/alias_free_activation/cuda/build/\ntmp*\nresources/\n*.gradio"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2025 Zeqian Ju\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE."
  },
  {
    "path": "app.py",
    "content": "import gradio as gr\nfrom huggingface_hub import snapshot_download \nsnapshot_download(repo_id=\"jzq11111/mooncast\", local_dir='./resources/')\n\nfrom inference import Model\nimport base64\n\nmodel = Model()\nmodel.generate_config.max_new_tokens = 50 * 50 # no more than 50s per turn\n\n\ndef process_json_and_generate_audio(prompt_audio_role0_file, prompt_text_role0, prompt_audio_role1_file, prompt_text_role1, json_dialogue_input_str):\n    try:\n        print(json_dialogue_input_str, type(json_dialogue_input_str))\n        print(prompt_audio_role0_file, prompt_text_role0, prompt_audio_role1_file, prompt_text_role1)\n        # json_data = json.loads(json_dialogue_input_str)\n        json_data = eval(json_dialogue_input_str.strip())\n        print(json_data, type(json_data))    \n\n        def validate_json(data):\n            try:\n                if not isinstance(data, list):\n                    return \"json must be a dictionary\"\n                cur_spk_should_be = 0\n                for item in data:\n                    if item['role'] != str(cur_spk_should_be):\n                        return f\"role should be {cur_spk_should_be} in item {item}\"\n                    cur_spk_should_be = 1 - cur_spk_should_be\n                return None \n            except Exception as e:\n                return str(e)\n\n\n        validation_error = validate_json(json_data)\n        if validation_error:\n            raise gr.Error(validation_error)\n        \n        role_mapping = {\n            \"0\": {\n                \"ref_audio\": prompt_audio_role0_file,\n                \"ref_text\": prompt_text_role0, \n            },\n            \"1\": {\n                \"ref_audio\": prompt_audio_role1_file, \n                \"ref_text\": prompt_text_role1,\n            }\n        }\n\n        # 完整输入 JSON (你需要根据你的模型调整)\n        model_input_json = {\n            \"role_mapping\": role_mapping,\n            \"dialogue\": json_data, # 从用户输入的 JSON 中获取 dialogue\n        }\n        print(\"模型推理输入 JSON:\", model_input_json)\n\n\n        # 4. **[重要] 调用你的 Model 类的 `inference` 方法**\n        # audio_bytes = model.inference(model_input_json) \n\n        # 5. 返回音频 bytes 给 Gradio (Gradio 会自动处理音频 bytes 并播放)\n        # return base64.b64decode(audio_bytes)\n        for cur_chunk in model.inference(model_input_json, streaming=True):\n            yield base64.b64decode(cur_chunk)\n\n    except Exception as e:\n        # return str(e) # 返回错误信息给 Gradio\n        raise gr.Error(str(e))\n\ntitle_en = \"# PODCAST generator (supports English and Chinese)\"\ntitle_zh = \"# 播客生成 (支持英文和中文)\"\n\ninstruct_en = \"## See [Github](https://github.com/jzq2000/MoonCast) for podcast script generation.\"\ninstruct_zh = \"## 播客剧本生成请参考 [Github](https://github.com/jzq2000/MoonCast)。\"\n\ninput_labels_en = [\"Prompt Audio for Role 0\", \"Prompt Text for Role 0\", \"Prompt Audio for Role 1\", \"Prompt Text for Role 1\", \"Script JSON Input\"]\ninput_labels_zh = [\"角色 0 的 Prompt 音频\", \"角色 0 的 Prompt 文本\", \"角色 1 的 Prompt 音频\", \"角色 1 的 Prompt 文本\", \"剧本 JSON 输入\"]\n\noutput_label_en = \"Generated Audio Output (streaming)\"\noutput_label_zh = \"生成的音频输出(流式)\"\n\nexample_prompt_text_role0_en = \"Yeah, no, this is my backyard. It's never ending So just the way I like it. So social distancing has never been a problem.\"\nexample_prompt_text_role0_zh = \"可以每天都骑并且可能会让你爱上骑车，然后通过爱上骑车的你省了很多很多钱。\"\nexample_prompt_text_role1_en = \"I'm doing great And. 
Look, it couldn't be any better than having you at your set, which is the outdoors.\"\nexample_prompt_text_role1_zh = \"他最后就能让同样食材炒出来的菜味道大大提升。\"\n\ntext_placeholder_zh = \"对话轮流进行, 每轮最多50秒。文本越自然, 生成的音频效果越好。\"\ntext_placeholder_en = \"Dialogue alternates between roles. Limit each turn to a maximum of 50 seconds. The more natural the text, the better the generated audio.\"\n\n\nexample_json_en = '''[\n       {\n            \"role\": \"0\",\n            \"text\": \"In an awesome time, And, we're even gonna do a second episode too So. This is part one part two, coming at some point in the future There. We are.\",\n        },\n       {\n            \"role\": \"1\",\n            \"text\": \"I love it. So grateful Thank you So I'm really excited. That's awesome. Yeah.\",\n       },\n       {\n            \"role\": \"0\",\n            \"text\": \"All I was told, which is good because I don't want to really talk too much more is that you're really really into fitness and nutrition And overall holistic I love it Yes.\",\n       },\n        {\n            \"role\": \"1\",\n            \"text\": \"Yeah So I started around thirteen Okay But my parents were fitness instructors as well. Awesome So I came from the beginning, and now it's this transition into this wholeness because I had to chart my. Own path and they weren't into nutrition at all So I had to learn that part.\"\n        }\n]'''\nexample_json_zh = '''[\n        {\n            \"role\": \"0\",\n            \"text\": \"我觉得啊，就是经历了这么多年的经验， 就是补剂的作用就是九分的努力， 十分之一的补剂。 嗯，选的话肯定是九分更重要，但是我觉得补剂它能够让你九分的努力更加的有效率，更加的避免徒劳无功。 嗯，就是你，你你得先得真的锻炼，真的努力，真的健康饮食，然后再考虑补剂， 那你再加十十分之一的补剂的话，他可能就是说啊， 一半是心理作用，\"\n        },\n        {\n            \"role\": \"1\",\n            \"text\": \"对，其实很多时候心理作用是非常重要的。嗯，然后我每次用补剂的时候，我就会更加努力，就比如说我在健身之前我喝了一勺蛋白粉，我就会督促自己多练，\"\n        },\n        {\n            \"role\": \"0\",\n            \"text\": \"其实心理作用只要能实现你的预期目的就可以了。 就比如说给自行车链条加油， 它其实不是必要的，但是它可以让你骑行更顺畅， 然后提高你骑行的频率。\"\n        }   \n    ]\n'''\n\n# examples_en = [\n#     ['./en_prompt0.wav', example_prompt_text_role0_en, './en_prompt1.wav', example_prompt_text_role1_en, example_json_en]\n# ]\n# examples_zh = [\n#     ['./zh_prompt0.wav', example_prompt_text_role0_zh, './zh_prompt1.wav', example_prompt_text_role1_zh, example_json_zh]\n# ]\n\nexamples = [\n    ['./en_prompt0.wav', example_prompt_text_role0_en, './en_prompt1.wav', example_prompt_text_role1_en, example_json_en],\n    ['./zh_prompt0.wav', example_prompt_text_role0_zh, './zh_prompt1.wav', example_prompt_text_role1_zh, example_json_zh]\n]\n\n# -------------------- 更新界面元素的函数 --------------------\ndef update_ui_language(language):\n    if language == \"English\":\n        return  gr.update(value=title_en), \\\n                gr.update(value=instruct_en), \\\n                gr.update(label=\"UI Language\"), \\\n                gr.update(label=input_labels_en[0]), \\\n                gr.update(label=input_labels_en[1]), \\\n                gr.update(label=input_labels_en[2]), \\\n                gr.update(label=input_labels_en[3]), \\\n                gr.update(label=input_labels_en[4], placeholder=text_placeholder_en), \\\n                gr.update(label=output_label_en), \\\n                gr.update(value=\"Submit\"), \\\n                gr.update(label=\"Examples (Demonstration Use Only. 
Do Not Redistribute.)\", headers=input_labels_en)\n    \n    elif language == \"中文\":\n        return  gr.update(value=title_zh), \\\n                gr.update(value=instruct_zh), \\\n                gr.update(label=\"UI 语言\"), \\\n                gr.update(label=input_labels_zh[0]), \\\n                gr.update(label=input_labels_zh[1]), \\\n                gr.update(label=input_labels_zh[2]), \\\n                gr.update(label=input_labels_zh[3]), \\\n                gr.update(label=input_labels_zh[4], placeholder=text_placeholder_zh), \\\n                gr.update(label=output_label_zh), \\\n                gr.update(value=\"提交\"), \\\n                gr.update(label=\"示例 (仅用于展示，切勿私自传播。)\", headers=input_labels_zh)\n\n    else:\n        raise ValueError(\"Invalid language selected\")\n\n\naudio_output = gr.Audio(label=output_label_en, streaming=True) \ncss = \"\"\"\n.centered-title { /* CSS rule for centering title */\n    text-align: center !important;\n}\n\"\"\"\n# -------------------- Gradio 界面定义 (修改) --------------------\nwith gr.Blocks(css=css) as iface:\n\n    title_output = gr.Markdown(value=title_zh, elem_classes=\"centered-title\")\n    instruct_output = gr.Markdown(value=instruct_zh)\n    language_choice = gr.Radio([\"中文\", \"English\"], value=\"中文\", label=\"UI语言\") \n\n    with gr.Row(): # Main row to create two columns\n        with gr.Column(scale=2): \n            json_input = gr.TextArea(label=input_labels_zh[4], lines=15, placeholder=text_placeholder_zh) # Dialogue JSON Input\n\n        with gr.Column(scale=1): # Right column (narrower - scale=1) for prompt inputs\n            audio_input_role0 = gr.Audio(type=\"filepath\", label=input_labels_zh[0]) # Prompt Audio for Role 0\n            text_input_role0 = gr.TextArea(label=input_labels_zh[1], lines=2) # Prompt Text for Role 0\n\n        with gr.Column(scale=1): # \n            audio_input_role1 = gr.Audio(type=\"filepath\", label=input_labels_zh[2]) # Prompt Audio for Role 1\n            text_input_role1 = gr.TextArea(label=input_labels_zh[3], lines=2) # Prompt Text for Role 1\n\n    examples_component = gr.Examples(\n        examples=examples,\n        inputs=[audio_input_role0, text_input_role0, audio_input_role1, text_input_role1, json_input],\n        cache_examples=False,\n        label=\"示例(仅用于展示，切勿私自传播。)\",\n    )\n    \n    submit_button = gr.Button(\"提交\")\n    \n    submit_button.click(\n        fn=process_json_and_generate_audio,\n        inputs=[audio_input_role0, text_input_role0, audio_input_role1, text_input_role1, json_input],\n        outputs=audio_output\n    )\n    audio_output.render()\n    \n    language_choice.change(\n        fn=update_ui_language,\n        inputs=language_choice,\n        outputs=[title_output, instruct_output, language_choice, audio_input_role0, text_input_role0, audio_input_role1, text_input_role1, json_input, audio_output, submit_button, examples_component.dataset]\n    )\n\n\niface.launch(share=True)\n"
  },
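  {
    "path": "examples/validate_script_sketch.py",
    "content": "# Hedged sketch, not part of the original release: a standalone copy of the\n# dialogue-script validation performed in app.py, so a script can be checked\n# before it is pasted into the Gradio demo. ast.literal_eval mirrors app.py's\n# parsing: it accepts the trailing commas in the bundled examples but, unlike\n# eval, cannot execute code.\nimport ast\n\n\ndef validate_script(data):\n    # Turns must form a list that alternates between role \"0\" and role \"1\".\n    if not isinstance(data, list):\n        return \"script must be a list\"\n    expected = 0\n    for item in data:\n        if item.get(\"role\") != str(expected):\n            return f\"role should be {expected} in item {item}\"\n        expected = 1 - expected\n    return None\n\n\nif __name__ == \"__main__\":\n    raw = '[{\"role\": \"0\", \"text\": \"hi\"}, {\"role\": \"1\", \"text\": \"hello\"},]'\n    error = validate_script(ast.literal_eval(raw))\n    print(error or \"script is valid\")"
  },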
  {
    "path": "download_pretrain.py",
    "content": "from huggingface_hub import snapshot_download\nsnapshot_download(repo_id=\"jzq11111/mooncast\", local_dir='./resources/')"
  },
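  {
    "path": "check_pretrain.py",
    "content": "# Hedged sketch, not part of the original release: after running\n# download_pretrain.py, confirm that the checkpoints referenced by inference.py\n# and modules/audio_detokenizer/audio_detokenizer.py landed under ./resources/.\n# The expected paths below are collected from those two files.\nimport os\n\nEXPECTED = [\n    \"resources/text2semantic\",\n    \"resources/audio_detokenizer/config.yaml\",\n    \"resources/audio_detokenizer/model.pt\",\n    \"resources/vocoder/config.json\",\n    \"resources/vocoder/model.pt\",\n]\n\nif __name__ == \"__main__\":\n    missing = [p for p in EXPECTED if not os.path.exists(p)]\n    if missing:\n        print(\"missing resources:\", missing)\n    else:\n        print(\"all expected resources are present\")"
  },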
  {
    "path": "en_llmprompt_script_gen.py",
    "content": "# INPUT -> BRIEF -> SCRIPT\n\n\nINPUT2BRIEF = '''\n### Task Description  \nPlease summarize the input document in plain text format according to the following structure. The summary should be creative, comprehensive, and include all interesting, uncommon, and valuable viewpoints and information.\n\n- **Text Requirements**:  \n    1. Directly output the result without any additional information.  \n    2. The summary should be in English. Retain a small number of proper nouns, names, and abbreviations in their original form (e.g., Chinese characters).  \n    3. Do not include any mathematical formulas.  \n    4. Do not alter any proper nouns, names, or abbreviations from the original text. Unless there is a common translation, do not translate proper nouns. Do not attempt to modify the meaning of proper nouns.  \n    5. **Intelligently convert numbers in abbreviations. For example, \"a2b\" should be interpreted as \"a to b,\" not \"a two b\"; \"a4b\" as \"a for b,\" not \"a four b\"; \"v2\" may represent \"version two\" or \"second generation.\" Provide the original abbreviation and your suggested English translation.**  \n\n### Title and Author  \n- **Language Requirements**: English, formal written language.  \n- **Content Requirements**: Provide the title and author of the document. Briefly summarize the theme of the document and the author's background. Ensure all important information is included without omission and sufficient context is retained.  \n\n### Abstract  \n- **Language Requirements**: English, formal written language.  \n- **Content Requirements**:  \n    1. What this document has done.  \n    2. Whether similar work has been done before.  \n    3. If similar work exists, why this document is still necessary.  \n    4. How this document specifically addresses the topic.  \n    5. How well this document achieves its goals.  \n- **Additional Requirements**: Include an additional paragraph to explain any terms, concepts, or methods that may confuse readers unfamiliar with the field. Ensure proper nouns are explained consistently with the original text, covering all potential points of confusion, including abbreviations and entity names.  \n\n### Main Themes and Concepts  \n- **Language Requirements**: English, formal written language.  \n- **Content Requirements**: Each theme and concept should be organized according to the 3W principle:  \n    - **What**: Clearly define the problem.  \n    - **Why**: Analyze the problem and identify its root causes.  \n    - **How**: Explain how the document addresses the problem.  \n- **Additional Requirements**:  \n    1. Ensure each theme and concept is comprehensive and includes all important details. Fully elaborate on the \"What\" and \"Why\" sections.  \n    2. Avoid technical details such as mathematical formulas in the \"How\" section. Use language that is easily understood by a general audience.  \n    3. Ensure themes and concepts do not overlap and maintain clear logic.  \n    4. Include an additional paragraph to explain any terms, concepts, or methods that may confuse readers unfamiliar with the field. Ensure proper nouns are explained consistently with the original text, covering all potential points of confusion, including abbreviations and entity names.  \n\n### Key Citations  \n- **Language Requirements**: English, formal written language.  \n- **Content Requirements**: Organize the content according to the following structure:  \n    1. **Argument**: State what needs to be proven.  \n    2. 
**Evidence**: Provide the material used to support the argument.  \n    3. **Reasoning**: Describe the process of using evidence to prove the argument.  \n- **Additional Requirements**:  \n    1. Ensure all evidence and reasoning are directly sourced from the original text without fabrication.  \n    2. Ensure citation content is complete and retains sufficient context without simplification. Avoid using mathematical formulas in citations.  \n    3. Include an additional paragraph to explain any terms, concepts, or methods that may confuse readers unfamiliar with the field. Ensure proper nouns are explained consistently with the original text, covering all potential points of confusion, including abbreviations and entity names.  \n\n### Conclusion  \n- **Language Requirements**: English, formal written language.  \n- **Content Requirements**: Highlight the most important and impactful aspects of the document. Compared to the abstract, this section should provide more detailed insights related to the main themes and concepts. It may also include future directions for improvement, current application scenarios, and existing challenges.  \n'''\n\nBRIEF2SCRIPT = '''\n## 1. Task Overview\n\nPlease generate a lively English podcast script based on the provided English summary text and your knowledge of the topic. The script should feature a dialogue between two speakers who take turns speaking.  Output format should be JSON-parsable **list**. Each speaker's turn is a **dictionary** containing \"speaker\" and \"text\" fields. Example format: `[{{\"speaker\": \"1\", \"text\": \"xxx\"}}]`. The \"speaker\" field indicates the speaker's identity (1 for host, 2 for guest), and the \"text\" field is the spoken content. Output should start directly with the JSON code block, without any extra information.\n\n## 2. Content and Structure \n### (1) Text Content\n- The summary text contains all important information, which needs to be comprehensively selected and incorporated into the script.\n- Present information through a dialogue between two speakers, maintaining creativity and abstracting away unimportant details. For example, listeners aren't concerned with specific test names, but rather the task itself, the results, and the analysis.\n### (2) Structure Design\n- **Opening:** Introduce the topic and briefly describe the discussion content, without mentioning speaker names.\n- **Key Theme Discussion:**  Discuss important themes based on the summary text.  Expand on the summary, don't just repeat it verbatim.\n- **Closing:** Briefly recap the discussion highlights and offer an outlook on future or technological developments.\n\n## 3. Language Style\n### (1) Conversational Style\n- The text should be as conversational as possible, aiming for a style similar to automatic speech recognition output. Include filler words such as 'um,' 'uh,' 'like,' 'you know,' 'so,' 'right?', and so on. Response words such as 'Yeah,' 'Right,' 'Okay,' and similar. Conversational expressions, repetitions, informal grammar, etc. Use short sentences. Avoid directly copying and pasting structured text from the summary text.  Parentheses and other symbols not typically found in speech recognition transcripts should be avoided. Spaces within sentences indicate pauses. Be aware that there might be homophone errors, potentially due to accents. Questions should sound very conversational.  Pay particular attention to incorporating conversational details, especially in questions. 
For example:\n    [\n    {{  \"speaker\": \"1\", \n        \"text\": \"Welcome back to the podcast, everyone. Today we're diving into, uh, something that's really changing everything around us, A I.\"\n    }},\n    {{  \"speaker\": \"2\", \n        \"text\": \"Yeah, A I is, like, everywhere now, isn't it?  It's kinda wild to think about.\"\n    }},\n    {{  \"speaker\": \"1\", \n        \"text\": \"Totally.  And we're seeing it in so many areas of daily life.  Like, even just recommending what to watch, or, you know, suggesting products online.\"\n    }},\n    {{  \"speaker\": \"2\", \n        \"text\": \"Mhm, exactly.  And it's not just online stuff, right? Think about smart homes, or even self-driving cars.  It's getting pretty advanced.\"\n    }},\n    {{  \"speaker\": \"1\", \n        \"text\": \"Right, self-driving cars are still a bit futuristic for most of us, but, uh, even things like voice assistants on our phones, that's A I, isn't it?\"\n    }},\n    {{  \"speaker\": \"2\", \n        \"text\": \"Definitely.  Siri, Alexa, Google Assistant, all powered by A I.  It's become so normal, we almost don't even think about it anymore.\"\n    }},\n    {{  \"speaker\": \"1\", \n        \"text\": \"Yeah, it's like, integrated into everything.  But is that a good thing, you think?  Like, are there downsides to all this A I in our lives?\"\n    }},\n    {{  \"speaker\": \"2\", \n        \"text\": \"Well, that's the big question, isn't it?  On the one hand, it makes things so much more convenient, saves us time, maybe even makes things safer in some ways.\"\n    }},\n    {{  \"speaker\": \"1\", \n        \"text\": \"Safer how?\"\n    }},\n    {{  \"speaker\": \"2\", \n        \"text\": \"Uh, well, like in healthcare, for example.  A I can help doctors diagnose diseases earlier, maybe even more accurately. That's a huge plus, right?\"\n    }},\n    {{  \"speaker\": \"1\", \n        \"text\": \"Yeah, that's a really good point.  Medical applications are definitely exciting.  But what about the concerns, you know?  Like job displacement or privacy issues?\"\n    }},\n    {{  \"speaker\": \"2\", \n        \"text\": \"Right, those are super valid concerns.  Job displacement is a big one. If A I can do more and more tasks, what happens to human workers?  And privacy,\"\n    }},\n    {{  \"speaker\": \"1\", \n        \"text\": \"And privacy is huge, especially with all the data A I systems collect.  It's a lot to process.\"\n    }},\n    {{  \"speaker\": \"2\", \n        \"text\": \"Exactly.  So, it's not just sunshine and roses, is it?  We need to be mindful of the ethical implications and make sure it's used responsibly.\"\n    }},\n    {{  \"speaker\": \"1\", \n        \"text\": \"Definitely.  It's a powerful tool, but like any tool, it can be used for good or, you know, not so good.  It's up to us to guide its development, right?\"\n    }},\n    {{  \"speaker\": \"2\", \n        \"text\": \"Absolutely.  And that's a conversation we all need to be part of, not just the tech people, but everyone.\"\n    }}\n    ]\n\n### (2) Punctuation\n- Use English punctuation marks. Avoid using other punctuation marks beyond commas, periods, and question marks.  Exclamation points are prohibited.  Ellipses ('…'), parentheses, quotation marks (including ‘ ' “ ” \") or dashes are prohibited, otherwise it will be considered unqualified. do not use markdown syntax.  For example,**bold** or *italic* text should be avoided.  
Use plain text only.\n- If interrupted by the other person's response, the sentence should end with a comma, not a period.\n\n## 4. Information Organization and Logic\n### (1) Referencing Issues\n- Given that listeners won't have access to the summary text, any references must provide sufficient context for comprehension.\n- Avoid simply paraphrasing; instead, explain referenced content in your own words.\n- Explanations of technical terms should be creative and avoid simply stating 'this means what?' You can use examples, metaphors, and so on for explanations, but ensure you also clarify the rationale behind the metaphor. Explanations can be provided in response to a question from the other speaker, or you can offer explanations proactively. Technical terms that are not mentioned don't need explanation.  Technical terms that are mentioned don't necessarily need immediate explanation; they can be explained alongside other technical terms. Technical terms in the summary text might differ slightly from the surrounding text; you'll need to provide reasonable explanations based on the context.\n### (2) Information Density\n- Ensure moderate information density, avoiding excessively high or low density. The goal of appropriate information density is to enable listeners without prior knowledge to quickly grasp the document's purpose, rationale, and methodology.\n- To prevent information overload, the script should avoid delving into details like mathematical formulas, test setups, or specific experimental metrics. Instead, it should use simple, generalized language for descriptions.\n- To avoid excessively low information density, ensure each topic is discussed for at least 4 speaker turns, moving beyond simple keyword listings. Discuss topics from multiple angles whenever possible, going beyond the provided summary text. Given that the summary text is highly generalized, the script should elaborate on it and discuss further details. Feel free to use your knowledge to supplement background information, provide examples, and so forth, to enhance listener understanding.\n- Techniques to increase information density:\n\t1. Incorporate memorable quotes. Add impactful, attention-grabbing sentences to the script, either original ones or quotes from other sources.\n    2. Boost knowledge content.  Judiciously add knowledge points to the script to make listeners feel more informed and rewarded.\n    3. Introduce novel information. Incorporate new concepts to spark listener curiosity, particularly information they're unaware of but would find valuable. This is crucial.\n    4. Employ reverse thinking. Include information from diverse angles, challenging listeners' existing perspectives and presenting alternative viewpoints.\n    5. Generate contrast and impact. The script can offer unconventional (yet plausible) descriptions of familiar concepts to create a contrast with listener expectations.  This contrast contributes to information density.\n- Techniques to decrease information density:\n    1. Use short sentences: Concise and easy to understand, making the narrative more compact. Do not have too much information in one sentence.\n    2. Describe details: Vague and abstract information makes it difficult for listeners to build understanding, while more details create a sense of imagery and are easier to read.\n    3. Use more scenario-based descriptions: Scenarios are concrete and visual. Listeners can easily receive the conveyed information and be emotionally touched.\n    4. 
Talk more about facts: Talking about facts makes the script feel more real, and listeners can empathize more, thus lowering the information density of the copy.\n    5. Tell more stories: Tell your own stories, stories around you, and stories you've heard. Stories can bring listeners into the scene, making it easier to concentrate on listening.\n    6. Use more verbs and concrete nouns: Verbs and concrete nouns make it easier for listeners to visualize, while adjectives make complex copy harder to understand.\n    7. Avoid using mathematical formulas: Mathematical formulas are not conducive to public understanding.\n\n## 5. Dialogue Design\n### (1) Speaker Roles\n- The script includes a host and a guest. Speaker 1 is the host, responsible for opening and closing the show, skilled at using questions to control the pace of the conversation, and using vivid examples to make knowledge less dry. Speaker 2 is the guest, primarily responsible for introducing the document content, has deep knowledge reserves in the field, and is good at organizing language in a structured and easy-to-understand way.\n- Both speakers are enthusiastic and cheerful, like to combine personal stories or examples for discussion, and bring a direct experience to listeners. They are happy to discuss digressive stories.\n- The two speakers actively interact and frequently use backchannel words such as \"um\" to signal agreement with each other. Response words need to be inserted into the dialogue at the right moments. Sentences that are interrupted end with a comma, not a period.\n- Ensure consistent speaker roles. Do not have the host introduce technical details, or have the guest guide the host to discuss topics.\n- The host gradually increases their understanding of the field based on the guest's answers. However, the host may not understand immediately or completely correctly. The host can express misunderstanding or raise some questions that ordinary people might have. In this case, the guest will further explain in more accessible language, or specifically answer common questions or misunderstandings. This kind of interaction is more realistic and easier for listeners to follow than a host and guest who are always correct.\n### (2) Topic Order Arrangement\n- The host will arrange the topics according to the summary text and ensure logical connections between topics, such as transitioning from overall to details, from details to overall, from cause to effect, or from technology to application.\n- The host will guide the pace of the conversation and discuss topics in the order of the summary text. Guests should not interfere with topic transitions.\n### (3) Knowledge Rate\n- The knowledge rate in the script needs to be reasonable. Do not introduce a large amount of knowledge too quickly in a short period of time. Knowledge points should be spaced out across turns so listeners have time to absorb each one.\n\n## 6. Other Requirements\n### (1) English Numbers and Foreign Words\n  1. The script will be used for English podcast content recording. Please ensure most numbers and foreign words are rendered naturally in English to facilitate correct pronunciation.\n  2. Please intelligently determine the correct pronunciation according to the context. For example, \"2021\", if expressing a year, should be converted to \"twenty twenty-one\" or \"two thousand and twenty-one\"; if expressing a plain number, it should be \"two thousand and twenty-one\". 
For some uncommon English abbreviations, if the pronunciation needs to be read letter by letter according to the context, you must ensure that there is a space between each letter, such as \"AI\" adding a space as \"A I\", to avoid the model misinterpreting it as a word. For example, \"API\" should be rendered as \"A P I\".\n  3. Small amount of Chinese is allowed, especially for nouns, if it fits naturally within the conversational English context.\n### (2) Script Length\n  1. Please ensure that the total length of the 'text' values does not exceed 3,000 words and the number of speaker turns is kept within 60, otherwise it will be unqualified. Please choose technical details and topic concepts to discuss. Do not shorten the depth of discussion on each topic for the sake of word limit, do not be limited to the summary text, and give full play to your knowledge.\n\nINPUT: {BRIEF}\n\n## Re-emphasize:\nSpeaker 1 is the host, and Speaker 2 is the guest. Neither speaker has a name. The script text only uses commas, periods, and question marks. Use English punctuation marks. Avoid using other punctuation marks beyond commas, periods, and question marks. Exclamation points are prohibited.  Ellipses ('…'), parentheses, quotation marks (including ‘ ' “ ” \") or dashes are prohibited, otherwise it will be considered unqualified.  Please prioritize in-depth discussion for each topic. Don't limit yourself to the summary text; instead, use your knowledge to expand upon the topics, providing background information and illustrative examples to enhance listener understanding.\nEnsure that numbers and foreign words are rendered naturally in English for accurate pronunciation during recording. In technical contexts, English abbreviations sometimes use numerical digits in place of words (e.g., \"a2b\" for \"a to b,\" \"a4b\" for \"a for b\"). Please translate these abbreviations into appropriate English phrases based on the context. While the script is primarily in English, a small amount of Chinese, especially for nouns, is acceptable if it integrates naturally into the conversational flow.\n\nOUTPUT:\n'''\n"
  },
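  {
    "path": "examples/script_gen_sketch.py",
    "content": "# Hedged sketch, not part of the original release: shows how the two prompt\n# templates in en_llmprompt_script_gen.py chain together, INPUT -> BRIEF -> SCRIPT.\n# call_llm is a hypothetical stand-in for whatever chat-completion client you use.\n# BRIEF2SCRIPT contains an {BRIEF} placeholder (hence the {{ }} escapes in its\n# examples), so .format is confirmed; INPUT2BRIEF has no placeholder, so appending\n# the document at the end is an assumption, not the authors' documented usage.\nimport json\n\nfrom en_llmprompt_script_gen import INPUT2BRIEF, BRIEF2SCRIPT\n\n\ndef call_llm(prompt: str) -> str:\n    # Placeholder: send `prompt` to an LLM and return its text reply.\n    raise NotImplementedError\n\n\ndef document_to_script(document: str):\n    # Step 1: summarize the source document into a structured brief.\n    brief = call_llm(INPUT2BRIEF + \"\\n\\nINPUT: \" + document)\n    # Step 2: expand the brief into an alternating two-speaker script.\n    raw = call_llm(BRIEF2SCRIPT.format(BRIEF=brief))\n    # The prompt asks for a JSON-parsable list of {\"speaker\", \"text\"} dicts,\n    # possibly wrapped in a code fence, so strip any fences before parsing.\n    text = raw.strip().removeprefix(\"```json\").removeprefix(\"```\").removesuffix(\"```\")\n    return json.loads(text.strip())"
  },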
  {
    "path": "inference.py",
    "content": "\nimport sys\nsys.path.append(\".\")\nfrom modules.tokenizer.tokenizer import get_tokenizer_and_extra_tokens\nfrom modules.audio_tokenizer.audio_tokenizer import get_audio_tokenizer\nfrom modules.audio_detokenizer.audio_detokenizer import get_audio_detokenizer, detokenize, detokenize_noref, detokenize_streaming, detokenize_noref_streaming\nimport torch\nimport os\nfrom glob import glob\nimport base64\nimport io\nimport torchaudio\nfrom transformers import AutoModelForCausalLM, GenerationConfig\nimport librosa\nfrom tqdm import tqdm\nfrom pydub import AudioSegment\n\nclass Model(object):\n    def __init__(self):\n\n        \n        self.tokenizer, self.extra_tokens = get_tokenizer_and_extra_tokens()\n        self.speech_token_offset = 163840\n        print(self.extra_tokens)\n        self.assistant_ids = self.tokenizer.encode(\"assistant\") # [110866]\n        self.user_ids = self.tokenizer.encode(\"user\") # [1495]\n        self.audio_ids = self.tokenizer.encode(\"audio\") # [26229]\n        self.spk_0_ids = self.tokenizer.encode(\"0\") # [501] \n        self.spk_1_ids = self.tokenizer.encode(\"1\") # [503] \n\n        self.msg_end = self.extra_tokens.msg_end # 260\n        self.user_msg_start = self.extra_tokens.user_msg_start # 261\n        self.assistant_msg_start = self.extra_tokens.assistant_msg_start # 262\n        self.name_end = self.extra_tokens.name_end # 272\n        self.media_begin = self.extra_tokens.media_begin # 273\n        self.media_content = self.extra_tokens.media_content # 274\n        self.media_end = self.extra_tokens.media_end # 275\n\n        self.audio_tokenizer =  get_audio_tokenizer()\n        self.audio_detokenizer = get_audio_detokenizer()\n        model_path = \"resources/text2semantic\"\n        self.model =  AutoModelForCausalLM.from_pretrained(model_path, device_map=\"cuda:0\", torch_dtype=torch.bfloat16, trust_remote_code=True, force_download=True).to(torch.cuda.current_device())\n        self.generate_config = GenerationConfig(\n            max_new_tokens=200 * 50, # no more than 200s per turn\n            do_sample=True,\n            top_k=30,\n            top_p=0.8,\n            temperature=0.8,\n            eos_token_id=self.media_end,\n        )\n    \n    def _clean_text(self, text):\n        # you can add front-end processing here\n        text = text.replace(\"“\", \"\")\n        text = text.replace(\"”\", \"\")\n        text = text.replace(\"...\", \" \")\n        text = text.replace(\"…\", \" \")\n        text = text.replace(\"*\", \"\")\n        text = text.replace(\":\", \",\")\n        text = text.replace(\"‘\", \"'\")\n        text = text.replace(\"’\", \"'\")\n        text = text.strip()\n        return text\n\n    @torch.inference_mode()\n    def _process_text(self, js):\n\n        if \"role_mapping\" in js:\n            for role in js[\"role_mapping\"].keys():\n                js[\"role_mapping\"][role][\"ref_bpe_ids\"] = self.tokenizer.encode(self._clean_text(js[\"role_mapping\"][role][\"ref_text\"]))\n                \n        for turn in js[\"dialogue\"]:\n            turn[\"bpe_ids\"] = self.tokenizer.encode(self._clean_text(turn[\"text\"]))\n        return js\n        \n    def inference(self, js, streaming=False):\n        js = self._process_text(js)\n        if \"role_mapping\" not in js:\n            if streaming:\n                return self.infer_without_prompt_streaming(js)\n            else:\n                return self.infer_without_prompt(js)\n        else:\n            if streaming:\n                return 
self.infer_with_prompt_streaming(js)\n            else:\n                return self.infer_with_prompt(js)      \n    \n    @torch.inference_mode()\n    def infer_with_prompt(self, js):\n        user_role_0_ids = [self.user_msg_start] + self.user_ids + self.spk_0_ids  + [self.name_end]\n        user_role_1_ids = [self.user_msg_start] + self.user_ids + self.spk_1_ids  + [self.name_end]\n        assistant_role_0_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_0_ids + [self.name_end]\n        assistant_role_1_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_1_ids + [self.name_end]\n\n        media_start = [self.media_begin] + self.audio_ids + [self.media_content]\n        media_end = [self.media_end] + [self.msg_end]\n\n        assistant_role_0_ids = torch.LongTensor(assistant_role_0_ids).unsqueeze(0).to(torch.cuda.current_device())\n        assistant_role_1_ids = torch.LongTensor(assistant_role_1_ids).unsqueeze(0).to(torch.cuda.current_device())\n        media_start = torch.LongTensor(media_start).unsqueeze(0).to(torch.cuda.current_device())\n        media_end = torch.LongTensor(media_end).unsqueeze(0).to(torch.cuda.current_device())\n        \n\n        prompt = []\n        cur_role_dict = dict()\n        for role, role_item in js[\"role_mapping\"].items():\n            waveform_24k = librosa.load(role_item[\"ref_audio\"], sr=24000)[0]\n            waveform_24k = torch.tensor(waveform_24k).unsqueeze(0).to(torch.cuda.current_device())\n\n            waveform_16k = librosa.load(role_item[\"ref_audio\"], sr=16000)[0]\n            waveform_16k = torch.tensor(waveform_16k).unsqueeze(0).to(torch.cuda.current_device())\n\n            semantic_tokens = self.audio_tokenizer.tokenize(waveform_16k)\n            semantic_tokens = semantic_tokens.to(torch.cuda.current_device())\n            prompt_ids = semantic_tokens + self.speech_token_offset\n\n            cur_role_dict[role] = {\n                \"ref_bpe_ids\": role_item[\"ref_bpe_ids\"],\n                \"wav_24k\": waveform_24k,\n                \"semantic_tokens\": semantic_tokens,\n                \"prompt_ids\": prompt_ids\n            }\n        \n        prompt = prompt + user_role_0_ids + cur_role_dict[\"0\"][\"ref_bpe_ids\"] + [self.msg_end]\n        prompt = prompt + user_role_1_ids + cur_role_dict[\"1\"][\"ref_bpe_ids\"] + [self.msg_end]\n        \n        for seg_id, turn in enumerate(js[\"dialogue\"]):\n            role_id = turn[\"role\"]\n            cur_user_ids = user_role_0_ids if role_id == \"0\" else user_role_1_ids\n            cur_start_ids = cur_user_ids + turn[\"bpe_ids\"] + [self.msg_end]\n            prompt = prompt + cur_start_ids\n        \n        prompt = torch.LongTensor(prompt).unsqueeze(0).to(torch.cuda.current_device())\n\n        prompt = torch.cat([prompt, assistant_role_0_ids, media_start, cur_role_dict[\"0\"][\"prompt_ids\"], media_end], dim=-1)\n        prompt = torch.cat([prompt, assistant_role_1_ids, media_start, cur_role_dict[\"1\"][\"prompt_ids\"], media_end], dim=-1)\n\n        \n        generation_config = self.generate_config\n        # you can modify sampling strategy here\n\n        wav_list = []\n        for seg_id, turn in tqdm(enumerate(js[\"dialogue\"])):\n            role_id = turn[\"role\"]\n            cur_assistant_ids = assistant_role_0_ids if role_id == \"0\" else assistant_role_1_ids                \n            prompt = torch.cat([prompt, cur_assistant_ids, media_start], dim=-1)\n            len_prompt = prompt.shape[1]\n            
generation_config.min_length = len_prompt + 2\n            # print(generation_config)\n            # todo: add streaming support for generate function\n            outputs = self.model.generate(prompt,\n                                          generation_config=generation_config)\n            if outputs[0, -1] == self.media_end:\n                outputs = outputs[:, :-1]\n            output_token = outputs[:, len_prompt:]\n            prompt = torch.cat([outputs, media_end], dim=-1)            \n\n            torch_token = output_token - self.speech_token_offset\n            gen_speech_fm = detokenize(self.audio_detokenizer, torch_token, cur_role_dict[role_id][\"wav_24k\"], cur_role_dict[role_id][\"semantic_tokens\"])\n            gen_speech_fm = gen_speech_fm.cpu()\n            gen_speech_fm = gen_speech_fm / gen_speech_fm.abs().max()\n            wav_list.append(gen_speech_fm)\n            del torch_token\n        \n        concat_wav = torch.cat(wav_list, dim=-1).cpu()\n        # print(concat_wav.shape)\n        buffer = io.BytesIO()\n        torchaudio.save(buffer, concat_wav, sample_rate=24000, format=\"mp3\")\n        audio_bytes = buffer.getvalue()\n        audio_b64 = base64.b64encode(audio_bytes).decode(\"utf-8\")\n        return audio_b64\n    \n    @torch.inference_mode()\n    def infer_with_prompt_streaming(self, js):\n        user_role_0_ids = [self.user_msg_start] + self.user_ids + self.spk_0_ids  + [self.name_end]\n        user_role_1_ids = [self.user_msg_start] + self.user_ids + self.spk_1_ids  + [self.name_end]\n        assistant_role_0_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_0_ids + [self.name_end]\n        assistant_role_1_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_1_ids + [self.name_end]\n\n        media_start = [self.media_begin] + self.audio_ids + [self.media_content]\n        media_end = [self.media_end] + [self.msg_end]\n\n        assistant_role_0_ids = torch.LongTensor(assistant_role_0_ids).unsqueeze(0).to(torch.cuda.current_device())\n        assistant_role_1_ids = torch.LongTensor(assistant_role_1_ids).unsqueeze(0).to(torch.cuda.current_device())\n        media_start = torch.LongTensor(media_start).unsqueeze(0).to(torch.cuda.current_device())\n        media_end = torch.LongTensor(media_end).unsqueeze(0).to(torch.cuda.current_device())\n        \n\n        prompt = []\n        cur_role_dict = dict()\n        for role, role_item in js[\"role_mapping\"].items():\n            waveform_24k = librosa.load(role_item[\"ref_audio\"], sr=24000)[0]\n            waveform_24k = torch.tensor(waveform_24k).unsqueeze(0).to(torch.cuda.current_device())\n\n            waveform_16k = librosa.load(role_item[\"ref_audio\"], sr=16000)[0]\n            waveform_16k = torch.tensor(waveform_16k).unsqueeze(0).to(torch.cuda.current_device())\n\n            semantic_tokens = self.audio_tokenizer.tokenize(waveform_16k)\n            semantic_tokens = semantic_tokens.to(torch.cuda.current_device())\n            prompt_ids = semantic_tokens + self.speech_token_offset\n\n            cur_role_dict[role] = {\n                \"ref_bpe_ids\": role_item[\"ref_bpe_ids\"],\n                \"wav_24k\": waveform_24k,\n                \"semantic_tokens\": semantic_tokens,\n                \"prompt_ids\": prompt_ids\n            }\n        \n        prompt = prompt + user_role_0_ids + cur_role_dict[\"0\"][\"ref_bpe_ids\"] + [self.msg_end]\n        prompt = prompt + user_role_1_ids + cur_role_dict[\"1\"][\"ref_bpe_ids\"] + [self.msg_end]\n        \n        
for seg_id, turn in enumerate(js[\"dialogue\"]):\n            role_id = turn[\"role\"]\n            cur_user_ids = user_role_0_ids if role_id == \"0\" else user_role_1_ids\n            cur_start_ids = cur_user_ids + turn[\"bpe_ids\"] + [self.msg_end]\n            prompt = prompt + cur_start_ids\n        \n        prompt = torch.LongTensor(prompt).unsqueeze(0).to(torch.cuda.current_device())\n\n        prompt = torch.cat([prompt, assistant_role_0_ids, media_start, cur_role_dict[\"0\"][\"prompt_ids\"], media_end], dim=-1)\n        prompt = torch.cat([prompt, assistant_role_1_ids, media_start, cur_role_dict[\"1\"][\"prompt_ids\"], media_end], dim=-1)\n\n        \n        generation_config = self.generate_config\n        # you can modify sampling strategy here\n\n        wav_list = []\n        for seg_id, turn in tqdm(enumerate(js[\"dialogue\"])):\n            role_id = turn[\"role\"]\n            cur_assistant_ids = assistant_role_0_ids if role_id == \"0\" else assistant_role_1_ids                \n            prompt = torch.cat([prompt, cur_assistant_ids, media_start], dim=-1)\n            len_prompt = prompt.shape[1]\n            generation_config.min_length = len_prompt + 2\n            # print(generation_config)\n            # todo: add streaming support for generate function\n            outputs = self.model.generate(prompt,\n                                          generation_config=generation_config)\n            if outputs[0, -1] == self.media_end:\n                outputs = outputs[:, :-1]\n            output_token = outputs[:, len_prompt:]\n            prompt = torch.cat([outputs, media_end], dim=-1)            \n\n            torch_token = output_token - self.speech_token_offset\n            for cur_chunk in detokenize_streaming(self.audio_detokenizer, torch_token, cur_role_dict[role_id][\"wav_24k\"], cur_role_dict[role_id][\"semantic_tokens\"]):\n                cur_chunk = cur_chunk.cpu()\n                cur_chunk = cur_chunk / cur_chunk.abs().max()\n                cur_buffer = io.BytesIO()\n                torchaudio.save(cur_buffer, cur_chunk, sample_rate=24000, format=\"mp3\")\n                audio_bytes = cur_buffer.getvalue()\n                audio_b64 = base64.b64encode(audio_bytes).decode(\"utf-8\")\n                yield audio_b64\n               \n    @torch.inference_mode()\n    def infer_without_prompt(self, js):\n        user_role_0_ids = [self.user_msg_start] + self.user_ids + self.spk_0_ids  + [self.name_end]\n        user_role_1_ids = [self.user_msg_start] + self.user_ids + self.spk_1_ids  + [self.name_end]\n        assistant_role_0_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_0_ids + [self.name_end]\n        assistant_role_1_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_1_ids + [self.name_end]\n\n        media_start = [self.media_begin] + self.audio_ids + [self.media_content]\n        media_end = [self.media_end] + [self.msg_end]\n\n        assistant_role_0_ids = torch.LongTensor(assistant_role_0_ids).unsqueeze(0).to(torch.cuda.current_device())\n        assistant_role_1_ids = torch.LongTensor(assistant_role_1_ids).unsqueeze(0).to(torch.cuda.current_device())\n        media_start = torch.LongTensor(media_start).unsqueeze(0).to(torch.cuda.current_device())\n        media_end = torch.LongTensor(media_end).unsqueeze(0).to(torch.cuda.current_device())\n        \n\n        prompt = []\n        for seg_id, turn in enumerate(js[\"dialogue\"]):\n            role_id = turn[\"role\"]\n            cur_user_ids = user_role_0_ids if 
role_id == \"0\" else user_role_1_ids\n            cur_start_ids = cur_user_ids + turn[\"bpe_ids\"] + [self.msg_end]\n            prompt = prompt + cur_start_ids\n\n        prompt = torch.LongTensor(prompt).unsqueeze(0).to(torch.cuda.current_device())\n        generation_config = self.generate_config\n        # you can modify sampling strategy here\n\n        wav_list = []\n        for seg_id, turn in tqdm(enumerate(js[\"dialogue\"])):\n            role_id = turn[\"role\"]\n            cur_assistant_ids = assistant_role_0_ids if role_id == \"0\" else assistant_role_1_ids                \n            prompt = torch.cat([prompt, cur_assistant_ids, media_start], dim=-1)\n            len_prompt = prompt.shape[1]\n            generation_config.min_length = len_prompt + 2\n            # todo: add streaming support for generate function\n            outputs = self.model.generate(prompt,\n                                          generation_config=generation_config)\n            if outputs[0, -1] == self.media_end:\n                outputs = outputs[:, :-1]\n            output_token = outputs[:, len_prompt:]\n            prompt = torch.cat([outputs, media_end], dim=-1)\n\n            torch_token = output_token - self.speech_token_offset\n            gen_speech_fm = detokenize_noref(self.audio_detokenizer, torch_token)\n            gen_speech_fm = gen_speech_fm.cpu()\n            gen_speech_fm = gen_speech_fm / gen_speech_fm.abs().max()\n            wav_list.append(gen_speech_fm)\n            del torch_token\n\n        concat_wav = torch.cat(wav_list, dim=-1).cpu()\n        # print(concat_wav.shape)\n        buffer = io.BytesIO()\n        torchaudio.save(buffer, concat_wav, sample_rate=24000, format=\"mp3\")\n        audio_bytes = buffer.getvalue()\n        audio_b64 = base64.b64encode(audio_bytes).decode(\"utf-8\")\n        return audio_b64\n    \n    @torch.inference_mode()\n    def infer_without_prompt_streaming(self, js):\n        user_role_0_ids = [self.user_msg_start] + self.user_ids + self.spk_0_ids  + [self.name_end]\n        user_role_1_ids = [self.user_msg_start] + self.user_ids + self.spk_1_ids  + [self.name_end]\n        assistant_role_0_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_0_ids + [self.name_end]\n        assistant_role_1_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_1_ids + [self.name_end]\n\n        media_start = [self.media_begin] + self.audio_ids + [self.media_content]\n        media_end = [self.media_end] + [self.msg_end]\n\n        assistant_role_0_ids = torch.LongTensor(assistant_role_0_ids).unsqueeze(0).to(torch.cuda.current_device())\n        assistant_role_1_ids = torch.LongTensor(assistant_role_1_ids).unsqueeze(0).to(torch.cuda.current_device())\n        media_start = torch.LongTensor(media_start).unsqueeze(0).to(torch.cuda.current_device())\n        media_end = torch.LongTensor(media_end).unsqueeze(0).to(torch.cuda.current_device())\n        \n\n        prompt = []\n        for seg_id, turn in enumerate(js[\"dialogue\"]):\n            role_id = turn[\"role\"]\n            cur_user_ids = user_role_0_ids if role_id == \"0\" else user_role_1_ids\n            cur_start_ids = cur_user_ids + turn[\"bpe_ids\"] + [self.msg_end]\n            prompt = prompt + cur_start_ids\n\n        prompt = torch.LongTensor(prompt).unsqueeze(0).to(torch.cuda.current_device())\n        generation_config = self.generate_config\n        # you can modify sampling strategy here\n\n        wav_list = []\n        for seg_id, turn in 
tqdm(enumerate(js[\"dialogue\"])):\n            role_id = turn[\"role\"]\n            cur_assistant_ids = assistant_role_0_ids if role_id == \"0\" else assistant_role_1_ids                \n            prompt = torch.cat([prompt, cur_assistant_ids, media_start], dim=-1)\n            len_prompt = prompt.shape[1]\n            generation_config.min_length = len_prompt + 2\n            # print(generation_config)\n            # todo: add streaming support for generate function\n            outputs = self.model.generate(prompt,\n                                          generation_config=generation_config)\n            if outputs[0, -1] == self.media_end:\n                outputs = outputs[:, :-1]\n            output_token = outputs[:, len_prompt:]\n            prompt = torch.cat([outputs, media_end], dim=-1)\n\n            torch_token = output_token - self.speech_token_offset\n            for cur_chunk in detokenize_noref_streaming(self.audio_detokenizer, torch_token):\n                cur_chunk = cur_chunk.cpu()\n                cur_chunk = cur_chunk / cur_chunk.abs().max()\n                cur_buffer = io.BytesIO()\n                torchaudio.save(cur_buffer, cur_chunk, sample_rate=24000, format=\"mp3\")\n                audio_bytes = cur_buffer.getvalue()\n                audio_b64 = base64.b64encode(audio_bytes).decode(\"utf-8\")\n                yield audio_b64\n           \n        \nif __name__ == \"__main__\":\n    model = Model()\n    \n    # speaker should be interleaved\n    zh_test_json = {\n        \"role_mapping\": {\n            \"0\": {\n                \"ref_audio\": \"./zh_prompt0.wav\",\n                \"ref_text\": \"可以每天都骑并且可能会让你爱上骑车，然后通过爱上骑车的你省了很多很多钱。\", #asr output\n            },\n            \"1\": {\n                \"ref_audio\": \"./zh_prompt1.wav\",\n                \"ref_text\": \"他最后就能让同样食材炒出来的菜味道大大提升。\" #asr output\n            }\n        },      \n        \"dialogue\": [\n           {\n                \"role\": \"0\",\n                \"text\": \"我觉得啊，就是经历了这么多年的经验， 就是补剂的作用就是九分的努力， 十分之一的补剂。 嗯，选的话肯定是九分更重要，但是我觉得补剂它能够让你九分的努力更加的有效率，更加的避免徒劳无功。 嗯，就是你，你你得先得真的锻炼，真的努力，真的健康饮食，然后再考虑补剂， 那你再加十十分之一的补剂的话，他可能就是说啊， 一半是心理作用，\"\n            },\n            {\n                \"role\": \"1\",\n                \"text\": \"对，其实很多时候心理作用是非常重要的。嗯，然后我每次用补剂的时候，我就会更加努力，就比如说我在健身之前我喝了一勺蛋白粉，我就会督促自己多练，\"\n            },\n            {\n                \"role\": \"0\",\n                \"text\": \"其实心理作用只要能实现你的预期目的就可以了。 就比如说给自行车链条加油， 它其实不是必要的，但是它可以让你骑行更顺畅， 然后提高你骑行的频率。\"\n            }      \n        ]\n    }\n\n\n    audio_bytes_gen = model.inference(zh_test_json, streaming=True)\n    audio = AudioSegment.empty()\n    for cur_chunk in audio_bytes_gen:\n        cur_chunk = base64.b64decode(cur_chunk)\n        audio_chunk = AudioSegment.from_file(io.BytesIO(cur_chunk), format=\"mp3\")\n        audio += audio_chunk\n    audio.export(\"tmp_generated_zh_stream.mp3\", format=\"mp3\")\n    print(\"zh stream done\")\n    \n\n    audio_bytes = model.inference(zh_test_json)\n    file_to_save = open(f\"tmp_generated_zh.mp3\", \"wb\")\n    file_to_save.write(base64.b64decode(audio_bytes))\n    print(\"zh done\")\n\n    # speaker should be interleaved\n    en_test_json = {\n        \"role_mapping\": {\n            \"0\": {\n                \"ref_audio\": \"./en_prompt0.wav\",\n                \"ref_text\": \"Yeah, no, this is my backyard. It's never ending So just the way I like it. 
So social distancing has never been a problem.\", #asr output\n            },\n            \"1\": {\n                \"ref_audio\": \"./en_prompt1.wav\",\n                \"ref_text\": \"I'm doing great And. Look, it couldn't be any better than having you at your set, which is the outdoors.\" #asr output\n            }\n        },      \n        \"dialogue\": [\n            {\n                \"role\": \"0\",\n                \"text\": \"In an awesome time, And, we're even gonna do a second episode too So. This is part one part two, coming at some point in the future There. We are.\",\n            },\n            {\n                \"role\": \"1\",\n                \"text\": \"I love it. So grateful Thank you So I'm really excited. That's awesome. Yeah.\"\n            },\n            {\n                \"role\": \"0\",\n                \"text\": \"All I was told, which is good because I don't want to really talk too much more is that you're really really into fitness and nutrition And overall holistic I love it Yes.\"\n            },\n            {\n                \"role\": \"1\",\n                \"text\": \"Yeah So I started around thirteen Okay But my parents were fitness instructors as well. Awesome So I came from the beginning, and now it's this transition into this wholeness because I had to chart my. Own path and they weren't into nutrition at all So I had to learn that part.\"\n            }\n        ]\n    }\n    audio_bytes = model.inference(en_test_json)\n    file_to_save = open(f\"tmp_generated_en.mp3\", \"wb\")\n    file_to_save.write(base64.b64decode(audio_bytes))\n    print(\"en done\")\n\n\n    # also support inference without prompt\n    # speaker should be interleaved\n    without_prompt_test_json = {\n        \"dialogue\": [\n            {\n                \"role\": \"0\",\n                \"text\": \"我觉得啊，就是经历了这么多年的经验， 就是补剂的作用就是九分的努力， 十分之一的补剂。 嗯，选的话肯定是九分更重要，但是我觉得补剂它能够让你九分的努力更加的有效率，更加的避免徒劳无功。 嗯，就是你，你你得先得真的锻炼，真的努力，真的健康饮食，然后再考虑补剂， 那你再加十十分之一的补剂的话，他可能就是说啊， 一半是心理作用，\"\n            },\n            {\n                \"role\": \"1\",\n                \"text\": \"对，其实很多时候心理作用是非常重要的。嗯，然后我每次用补剂的时候，我就会更加努力，就比如说我在健身之前我喝了一勺蛋白粉，我就会督促自己多练，\"\n            },\n            {\n                \"role\": \"0\",\n                \"text\": \"其实心理作用只要能实现你的预期目的就可以了。 就比如说给自行车链条加油， 它其实不是必要的，但是它可以让你骑行更顺畅， 然后提高你骑行的频率。\"\n            }   \n        ]\n    }\n    audio_bytes = model.inference(without_prompt_test_json)\n    file_to_save = open(f\"tmp_generated_woprompt.mp3\", \"wb\")\n    file_to_save.write(base64.b64decode(audio_bytes))\n    print(\"without prompt done\")"
  },
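  {
    "path": "examples/turn_budget_sketch.py",
    "content": "# Hedged sketch, not part of the original release: app.py caps turns with\n# max_new_tokens = 50 * 50 (\"no more than 50s per turn\") and inference.py\n# defaults to 200 * 50 (\"no more than 200s per turn\"), which implies the\n# text2semantic model emits roughly 50 semantic tokens per second of audio.\n# This helper just makes that budget arithmetic explicit.\nTOKENS_PER_SECOND = 50  # implied by the per-turn comments in app.py and inference.py\n\n\ndef max_new_tokens_for(seconds: float) -> int:\n    # Convert a per-turn duration budget into a generation token budget.\n    return int(seconds * TOKENS_PER_SECOND)\n\n\nif __name__ == \"__main__\":\n    from inference import Model  # heavy import: loads all checkpoints onto the GPU\n\n    model = Model()\n    model.generate_config.max_new_tokens = max_new_tokens_for(50)  # cap each turn at ~50s\n    print(model.generate_config.max_new_tokens)  # 2500"
  },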
  {
    "path": "modules/audio_detokenizer/audio_detokenizer.py",
    "content": "\nimport torch\n\nfrom modules.audio_detokenizer.bigvgan_wrapper import BigVGANWrapper\nfrom modules.audio_detokenizer.semantic_fm_prefix_streaming import StreamingSemanticFMWrapper\n\n\nclass PrefixStreamingFlowMatchingDetokenizer:\n    def __init__(self, vocoder: BigVGANWrapper, fm: StreamingSemanticFMWrapper, look_ahead_tokens: int = 0) -> None:\n        self.dtype = torch.bfloat16\n\n        print(\"Currently using bfloat16 for PrefixFlowMatchingDetokenizer\")\n\n        self.vocoder = vocoder\n        self.vocoder.to_dtype(self.dtype)\n        \n        self.semantic_fm = fm\n\n        # initialize mel_spec\n        self.max_pos_size = 4096\n        self.is_timbre_semantic_token = False\n        self.pre_mel = None\n        self.frame_size = 480 # how many samples in a frame\n        self.pre_wav = None\n        self.state_dict_backup = None\n        self.hamming_window_cache = {}\n        self.previous_chunk_left = None\n        self.look_ahead_tokens = look_ahead_tokens\n\n        self.clear_states()\n\n        \n    @classmethod\n    def from_pretrained(cls, vocoder_config, vocoder_ckpt, fm_config, fm_ckpt, device, \n                        look_ahead_tokens=0,\n                        max_prompt_chunk=2, max_kv_cache_tokens=900,\n                        use_cfg=False, use_cfg_rescale=True, cfg_init=1.5, cfg_scale=7.5, cfg_schedule=\"linear\"):\n        bigvgan = BigVGANWrapper.from_pretrained(vocoder_config, vocoder_ckpt, device)\n        semantic_fm = StreamingSemanticFMWrapper.from_pretrained(fm_config, fm_ckpt, device, max_prompt_chunk=max_prompt_chunk, max_kv_cache_tokens=max_kv_cache_tokens,\n                                                                 use_cfg=use_cfg, cfg_scale=cfg_scale, use_cfg_rescale=use_cfg_rescale, cfg_init=cfg_init, cfg_schedule=cfg_schedule)        \n        return cls(bigvgan, semantic_fm, look_ahead_tokens=look_ahead_tokens)\n    \n    @torch.inference_mode()\n    def prefill(self, timbre_speech, timbre_semantic_token, chunk_size: int, timbre_mel=None):\n        \"\"\"\n            Arguments:\n                timbre_speech: torch.Tensor, shape [B, N_speech_24k]\n                timbre_semantic_token: torch.Tensor, shape [B, N]\n                chunk_size: int, chunk size for prefilling\n                timbre_mel: torch.Tensor, shape [B, N, 80], optional, if not None, use this mel spectrogram instead of extracting from timbre_speech\n        \"\"\"\n        if timbre_mel is None:\n            assert timbre_speech is not None, \"timbre_speech should not be None if timbre_mel is not None\"\n            assert len(timbre_semantic_token.shape) == 2 and len(timbre_speech.shape) == 2 and chunk_size > 0\n            assert timbre_speech.shape[0] == 1 and timbre_semantic_token.shape[0] == 1\n\n            mel_spec = self.vocoder.extract_mel_from_wav(wav_data=timbre_speech.squeeze(0))\n        else:\n            assert len(timbre_mel.shape) == 3 and len(timbre_semantic_token.shape) == 2 and chunk_size > 0\n            assert timbre_mel.shape[0] == 1 and timbre_semantic_token.shape[0] == 1\n            mel_spec = timbre_mel.squeeze(0)\n\n        if mel_spec.shape[0] < timbre_semantic_token.shape[1]:\n            # pad mel_spec\n            mel_spec = torch.nn.functional.pad(mel_spec, (0, 0, 0, timbre_semantic_token.shape[1] - mel_spec.shape[0]))\n        elif mel_spec.shape[0] > timbre_semantic_token.shape[1]:\n            # truncate mel_spec\n            mel_spec = mel_spec[:timbre_semantic_token.shape[1], :]\n\n        # clear all 
states\n        self.semantic_fm.clear_all_states()\n        self.semantic_fm.prefill(mel_spec, timbre_semantic_token.squeeze(0), chunk_size=chunk_size, verbose=False)\n        self.state_dict_backup = self.semantic_fm.state_dict()\n\n    @torch.inference_mode()\n    def detokenize_streaming(self, semantic_token, ode_step=30, verbose=False, ode_solver=\"neural_ode_euler\", is_final=False, upsample_factor=1):\n        assert len(semantic_token.shape) == 2 and ode_step > 0\n        assert semantic_token.shape[0] == 1\n\n        semantic_token = semantic_token.repeat_interleave(upsample_factor, dim=1)\n        \n        semantic_token = semantic_token.squeeze(0)\n\n        if self.look_ahead_tokens != 0 and self.previous_chunk_left is not None:\n            semantic_token_previous = self.previous_chunk_left[\"semantic_token\"]\n            semantic_token = torch.cat([semantic_token_previous, semantic_token], dim=-1)\n\n        x_t_chunk = torch.randn(semantic_token.shape[0], 80).to(semantic_token.device).to(self.dtype)\n\n        if self.look_ahead_tokens != 0 and self.previous_chunk_left is None:\n            self.previous_chunk_left = {\"semantic_token\": None}\n        \n        speech_mel = self.semantic_fm.infer_chunk(\n            xt_chunk=x_t_chunk, \n            semantic_tokens_chunk=semantic_token, \n            start_position_id=self.semantic_fm.start_position_id,\n            ode_steps=ode_step, \n            verbose=verbose, \n            look_ahead_tokens=self.look_ahead_tokens * upsample_factor if not is_final else 0,\n            cache=self.previous_chunk_left,\n            ode_solver=ode_solver\n        )\n\n        chunk_size = speech_mel.shape[0]\n        length = speech_mel.shape[0]\n        self.semantic_fm.start_position_id += length\n        self.semantic_fm.update_incremental_state()\n        self.semantic_fm.reserve_kv_cache_tokens += self.semantic_fm.ode_wrapper.kv_cache_tokens\n        \n        # smoothing\n\n        # I will maintain the history of seqlen wav\n        # For the first chunk, I will only return the half chunk wav, and save the res half chunk in history\n        # For the rest requests, I will concat the generated wav with the history, output one chunk of the history, save the \n\n        if self.pre_mel is None: # first chunk, related to TTFB\n            concat_mel = speech_mel\n            concat_reconstructed_wav = self.vocoder.decode_mel(concat_mel)\n            if is_final:\n                self.clear_states()\n                self.state_dict_backup = None\n                ret_wav = concat_reconstructed_wav.float()\n            else:\n                reconstructed_wav = concat_reconstructed_wav[:, :int(self.frame_size * chunk_size // 2)] # return the first half chunk\n\n                self.pre_wav = concat_reconstructed_wav[:, -int(self.frame_size * chunk_size // 2):] # log the last half chunk for next generation step\n                self.pre_mel = speech_mel[-chunk_size//2:, :]\n\n                ret_wav = reconstructed_wav.float()\n        else:\n            concat_mel = torch.cat([self.pre_mel, speech_mel], dim=0)\n            concat_reconstructed_wav = self.vocoder.decode_mel(concat_mel)\n\n            if is_final:\n                self.clear_states()\n                self.state_dict_backup = None\n                ret_wav = concat_reconstructed_wav.float()\n            else:\n                # fetch history\n                prev_speech_len = self.pre_wav.shape[1]\n\n                if concat_reconstructed_wav.shape[1] > prev_speech_len * 
2:\n                    gen_speech_len = prev_speech_len * 2\n                else:\n                    gen_speech_len = concat_reconstructed_wav.shape[1] // 2\n\n\n                reconstructed_wav = concat_reconstructed_wav[:, :gen_speech_len] # emit one chunk of audio\n                \n                if gen_speech_len not in self.hamming_window_cache:\n                    self.hamming_window_cache[gen_speech_len] = torch.hamming_window(gen_speech_len).to(self.dtype).to(semantic_token.device).unsqueeze(0)\n                \n                hamming_window = self.hamming_window_cache[gen_speech_len]\n                \n                \n                # cross-fade the first half of the emitted chunk with the history\n                reconstructed_wav[:, :int(gen_speech_len // 2 )] = self.pre_wav[:, :int(gen_speech_len // 2 )] * hamming_window[:,-int(gen_speech_len // 2):] + \\\n                    reconstructed_wav[:, :int(gen_speech_len // 2)] * hamming_window[:, :int(gen_speech_len // 2)]\n            \n                res_speech_len = concat_reconstructed_wav.shape[1] - gen_speech_len\n                res_mel_len = res_speech_len // self.frame_size\n\n                self.pre_wav = concat_reconstructed_wav[:, -res_speech_len:]\n                self.pre_mel = speech_mel[-res_mel_len:, :]\n                ret_wav = reconstructed_wav.float()\n        \n        if not is_final and self.semantic_fm.start_position_id + 2*chunk_size > self.max_pos_size:\n            # ran out of position ids: reset states and restore the prefilled prompt from the backup\n            self.semantic_fm.clear_all_states()\n            self.semantic_fm.load_state_dict(self.state_dict_backup)\n\n        return ret_wav\n\n    def clear_states(self):\n        self.semantic_fm.clear_all_states()\n        self.previous_chunk_left = None\n        self.pre_mel = None\n        self.pre_wav = None\n\ndef get_audio_detokenizer():\n    fm_model_config = \"resources/audio_detokenizer/config.yaml\"\n    fm_ckpt_path = \"resources/audio_detokenizer/model.pt\"\n\n    bigvgan_config_file = \"resources/vocoder/config.json\"\n    bigvgan_ckpt_path = \"resources/vocoder/model.pt\"\n\n    device = torch.cuda.current_device()\n    detokenizer = PrefixStreamingFlowMatchingDetokenizer.from_pretrained(\n        vocoder_config=bigvgan_config_file, \n        vocoder_ckpt=bigvgan_ckpt_path, \n        max_prompt_chunk=10, # 10 * 3 = 30s\n        fm_config=fm_model_config, \n        fm_ckpt=fm_ckpt_path, \n        device=device, \n        use_cfg=False,\n        look_ahead_tokens=12) \n    \n    return detokenizer\n\n\ndef detokenize(detokenizer, tokens, ref_wav, ref_tokens):\n    with torch.no_grad():\n        detokenizer.clear_states()\n        detokenizer.prefill(ref_wav, ref_tokens, chunk_size=150)\n        cache_speech_collection = []\n        chunk_size = 150\n        first_chunk_size = 100\n        first_chunk_tokens = tokens[:, :first_chunk_size]\n        gen_speech = detokenizer.detokenize_streaming(first_chunk_tokens, is_final=tokens.size(1) <= first_chunk_size)\n        cache_speech_collection.append(gen_speech)\n        res_tokens = tokens[:, first_chunk_size:]\n        for i in range(0, res_tokens.size(1), chunk_size):\n            chunk_tokens = res_tokens[:, i:i+chunk_size]\n            gen_speech = detokenizer.detokenize_streaming(chunk_tokens, is_final=(i+chunk_size >= res_tokens.size(1)))\n            cache_speech_collection.append(gen_speech)\n\n        gen_speech_all = torch.cat(cache_speech_collection, dim=-1)\n        return gen_speech_all\n\ndef detokenize_streaming(detokenizer, tokens, ref_wav, ref_tokens):\n    with torch.no_grad():\n        detokenizer.clear_states()\n        detokenizer.prefill(ref_wav, ref_tokens, chunk_size=150)\n        chunk_size = 150\n        first_chunk_size = 100\n        first_chunk_tokens = tokens[:, :first_chunk_size]\n        gen_speech = detokenizer.detokenize_streaming(first_chunk_tokens, is_final=tokens.size(1) <= first_chunk_size)\n        yield gen_speech\n        res_tokens = tokens[:, first_chunk_size:]\n        for i in range(0, res_tokens.size(1), chunk_size):\n            chunk_tokens = res_tokens[:, i:i+chunk_size]\n            gen_speech = detokenizer.detokenize_streaming(chunk_tokens, is_final=(i+chunk_size >= res_tokens.size(1)))\n            yield gen_speech\n\ndef detokenize_noref(detokenizer, tokens):\n    with torch.no_grad():\n        detokenizer.clear_states()\n        cache_speech_collection = []\n        chunk_size = 150\n        first_chunk_size = 100\n        first_chunk_tokens = tokens[:, :first_chunk_size]\n        gen_speech = detokenizer.detokenize_streaming(first_chunk_tokens, is_final=tokens.size(1) <= first_chunk_size)\n        cache_speech_collection.append(gen_speech)\n        res_tokens = tokens[:, first_chunk_size:]\n        for i in range(0, res_tokens.size(1), chunk_size):\n            chunk_tokens = res_tokens[:, i:i+chunk_size]\n            gen_speech = detokenizer.detokenize_streaming(chunk_tokens, is_final=(i+chunk_size >= res_tokens.size(1)))\n            cache_speech_collection.append(gen_speech)\n        \n        gen_speech_all = torch.cat(cache_speech_collection, dim=-1)\n        return gen_speech_all\n\n\ndef detokenize_noref_streaming(detokenizer, tokens):\n    with torch.no_grad():\n        detokenizer.clear_states()\n        chunk_size = 150\n        first_chunk_size = 100\n        first_chunk_tokens = tokens[:, :first_chunk_size]\n        gen_speech = detokenizer.detokenize_streaming(first_chunk_tokens, is_final=tokens.size(1) <= first_chunk_size)\n        yield gen_speech\n        res_tokens = tokens[:, first_chunk_size:]\n        for i in range(0, res_tokens.size(1), chunk_size):\n            chunk_tokens = res_tokens[:, i:i+chunk_size]\n            gen_speech = detokenizer.detokenize_streaming(chunk_tokens, is_final=(i+chunk_size >= res_tokens.size(1)))\n            yield gen_speech\n"
  },
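  {
    "path": "examples/detokenizer_streaming_demo.py",
    "content": "\"\"\"Minimal usage sketch for the streaming detokenizer above.\n\nIllustrative only: it assumes the pretrained checkpoints under resources/\n(the paths hard-coded in get_audio_detokenizer) are present and that a CUDA\ndevice is available; the random tokens below are stand-ins for real semantic\ntokens, and ref_wav stands in for a 24 kHz reference recording.\n\"\"\"\nimport torch\n\nfrom modules.audio_detokenizer.audio_detokenizer import get_audio_detokenizer, detokenize_streaming\n\ndetokenizer = get_audio_detokenizer()\n\nref_wav = torch.randn(1, 24000 * 6)            # 6 s of 24 kHz prompt audio (stand-in)\nref_tokens = torch.randint(0, 1024, (1, 300))  # ~50 tokens/s -> 300 tokens for 6 s (stand-in)\ntokens = torch.randint(0, 1024, (1, 400))      # ~8 s of speech to synthesize (stand-in)\n\n# Chunks arrive as [1, T] float tensors at 24 kHz; concatenate them for the full waveform.\nchunks = [chunk.cpu() for chunk in detokenize_streaming(detokenizer, tokens, ref_wav, ref_tokens)]\nfull_wav = torch.cat(chunks, dim=-1)\nprint(full_wav.shape)\n"
  },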
  {
    "path": "modules/audio_detokenizer/bigvgan_wrapper.py",
    "content": "import os\nimport json\nimport logging\n\nimport librosa\nimport torch\n\nfrom modules.audio_detokenizer.vocoder.bigvgan import BigVGAN\nfrom modules.audio_detokenizer.vocoder.utils import get_melspec, AttrDict, load_checkpoint\n\nlogger = logging.getLogger(__name__)\n\n\nclass BigVGANWrapper:\n    def __init__(self, vocoder: BigVGAN, device: torch.device, h: AttrDict, dtype=None) -> None:\n        self.vocoder = vocoder.to(device)\n        if dtype is not None:\n            self.vocoder = self.vocoder.to(dtype)\n        self.vocoder = self.vocoder.eval()\n        self.device = device\n        self.h = h\n    \n    def to_dtype(self, dtype):\n        self.vocoder = self.vocoder.to(dtype)\n\n    def extract_mel_from_wav(self, wav_path=None, wav_data=None):\n        \"\"\"\n        params:\n            wav_path: str, path of the wav, should be 24k\n            wav_data: torch.tensor or numpy array, shape [T], wav data, should be 24k\n        return:\n            mel: [T, num_mels], torch.tensor\n        \"\"\"\n        if wav_data is None:\n            wav_data, _ = librosa.load(wav_path, sr=self.h[\"sampling_rate\"])\n        \n        wav_data = torch.tensor(wav_data).unsqueeze(0)\n\n        mel = get_melspec(y=wav_data, n_fft=self.h[\"n_fft\"], num_mels=self.h[\"num_mels\"], sampling_rate=self.h[\"sampling_rate\"], \n                          hop_size=self.h[\"hop_size\"], win_size=self.h[\"win_size\"], fmin=self.h[\"fmin\"], fmax=self.h[\"fmax\"])\n        return mel.squeeze(0).transpose(0, 1)\n    \n    @torch.inference_mode()\n    def extract_mel_from_wav_batch(self, wav_data):\n        \"\"\"\n        params:\n            wav_data: torch.tensor or numpy array, shape [Batch, T], wav data, should be 24k\n        return:\n            mel: [Batch, T, num_mels], torch.tensor\n        \"\"\"\n\n        wav_data = torch.tensor(wav_data)\n\n        mel = get_melspec(wav=wav_data, n_fft=self.h[\"n_fft\"], num_mels=self.h[\"num_mels\"], sampling_rate=self.h[\"sampling_rate\"], \n                          hop_size=self.h[\"hop_size\"], win_size=self.h[\"win_size\"], fmin=self.h[\"fmin\"], fmax=self.h[\"fmax\"])\n        return mel.transpose(1, 2)\n    \n    def decode_mel(self, mel):\n        \"\"\"\n        params:\n            mel: [T, num_mels], torch.tensor\n        return:\n            wav: [1, T], torch.tensor\n        \"\"\"    \n        mel = mel.transpose(0, 1).unsqueeze(0).to(self.device)\n        wav = self.vocoder(mel)\n        return wav.squeeze(0)\n\n    def decode_mel_batch(self, mel):\n        \"\"\"\n        params:\n            mel: [B, T, num_mels], torch.tensor\n        return:\n            wav: [B, 1, T], torch.tensor\n        \"\"\"    \n        mel = mel.transpose(1, 2).to(self.device)\n        wav = self.vocoder(mel)\n        return wav\n\n    @classmethod\n    def from_pretrained(cls, model_config, ckpt_path, device):\n        with open(model_config) as f:\n            data = f.read()\n        json_config = json.loads(data)\n        h = AttrDict(json_config)\n        vocoder = BigVGAN(h, True)\n        state_dict_g = load_checkpoint(ckpt_path, \"cpu\")\n        vocoder.load_state_dict(state_dict_g[\"generator\"])\n\n        logger.info(\">>> Load vocoder from {}\".format(ckpt_path))\n        return cls(vocoder, device, h)\n\n\n\n"
  },
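  {
    "path": "examples/bigvgan_mel_roundtrip_demo.py",
    "content": "\"\"\"Mel round-trip sketch for BigVGANWrapper.\n\nIllustrative only: resources/vocoder/config.json and resources/vocoder/model.pt\nare the checkpoint paths used elsewhere in this repo and must already exist;\nthe random waveform stands in for real 24 kHz speech.\n\"\"\"\nimport torch\n\nfrom modules.audio_detokenizer.bigvgan_wrapper import BigVGANWrapper\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\nvocoder = BigVGANWrapper.from_pretrained(\"resources/vocoder/config.json\", \"resources/vocoder/model.pt\", device)\n\nwav = torch.randn(24000)                          # 1 s stand-in for 24 kHz speech\nmel = vocoder.extract_mel_from_wav(wav_data=wav)  # [T, num_mels]\nrecon = vocoder.decode_mel(mel)                   # [1, T * hop_size]\nprint(mel.shape, recon.shape)\n"
  },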
  {
    "path": "modules/audio_detokenizer/flow_matching/dit_block.py",
    "content": "import torch\nimport torch.nn as nn\n\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom flash_attn import flash_attn_varlen_func, flash_attn_varlen_qkvpacked_func\n\ndef reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):\n    # x shape: bsz, seqlen, self.n_local_heads, self.head_hidden_dim / 2\n    # the last shape is \"self.hidden_dim / 2\" because we convert to complex\n    assert x.ndim == 4\n    assert freqs_cis.shape == (x.shape[0], x.shape[1], x.shape[-1]), \\\n        f'x shape: {x.shape}, freqs_cis shape: {freqs_cis.shape}'\n     \n    # reshape freq cis to match and apply pointwise multiply\n    # new shape: bsz, seq_len, 1, self.head_hidden_dim / 2\n    shape = [x.shape[0], x.shape[1], 1, x.shape[-1]]\n    return freqs_cis.view(*shape)\n\n\ndef apply_rotary_emb(\n    xq: torch.Tensor,\n    xk: torch.Tensor,\n    freqs_cis: torch.Tensor,\n):\n    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))\n    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))\n    \n    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)\n    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)\n    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)\n    return xq_out.type_as(xq), xk_out.type_as(xk)\n\n\n\nclass Attention(nn.Module):\n\n    def __init__(\n            self,\n            dim: int,\n            num_heads: int = 8,\n            qkv_bias: bool = False,\n            qk_norm: bool = False,\n            attn_drop: float = 0.,\n            proj_drop: float = 0.,\n            norm_layer: nn.Module = nn.LayerNorm,\n            flash_attention: bool = True\n    ) -> None:\n        super().__init__()\n        assert dim % num_heads == 0, 'dim should be divisible by num_heads'\n        self.num_heads = num_heads\n        self.head_dim = dim // num_heads\n        self.scale = self.head_dim ** -0.5\n        self.fused_attn = flash_attention\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.qk_norm = qk_norm\n        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()\n        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n    def forward(self, x: torch.Tensor, seq_len, cu_seqlens, max_seqlen, cu_seqlens_k, max_seqlen_k, rotary_pos_emb=None, incremental_state=None, nopadding=True) -> torch.Tensor:\n        B, N, C = x.shape\n\n        if self.fused_attn:\n            if nopadding:\n                qkv = self.qkv(x)\n                qkv = qkv.view(B * N, self.num_heads * 3, self.head_dim)\n                q, k, v = qkv.split([self.num_heads] * 3, dim=1)\n                q, k = self.q_norm(q), self.k_norm(k)\n\n                q = q.view(B, N, self.num_heads, self.head_dim)\n                k = k.view(B, N, self.num_heads, self.head_dim)\n                v = v.view(B, N, self.num_heads, self.head_dim)\n\n                if rotary_pos_emb is not None:\n                    q, k = apply_rotary_emb(q, k, rotary_pos_emb)\n                \n                if incremental_state is not None:\n                    if \"prev_k\" in incremental_state:\n                        prev_k = incremental_state[\"prev_k\"]\n                        k = torch.cat([prev_k, k], dim=1)\n                    \n                    if \"cur_k\" not in incremental_state:\n                        incremental_state[\"cur_k\"] = {}\n  
                  incremental_state[\"cur_k\"] = k\n                \n                    if \"prev_v\" in incremental_state:\n                        prev_v = incremental_state[\"prev_v\"]\n                        v = torch.cat([prev_v, v], dim=1)\n                    \n                    if \"cur_v\" not in incremental_state:\n                        incremental_state[\"cur_v\"] = {}\n                    incremental_state[\"cur_v\"] = v\n                \n                q = q.view(B * N, self.num_heads, self.head_dim)\n                k = k.view(-1, self.num_heads, self.head_dim)\n                v = v.view(-1, self.num_heads, self.head_dim)\n\n                x = flash_attn_varlen_func(\n                    q=q,\n                    k=k,\n                    v=v,\n                    cu_seqlens_q=cu_seqlens,\n                    cu_seqlens_k=cu_seqlens_k,\n                    max_seqlen_q=max_seqlen,\n                    max_seqlen_k=max_seqlen_k,\n                    dropout_p=self.attn_drop.p if self.training else 0.,\n                )\n            else:\n                \n                if incremental_state is not None:\n                    raise NotImplementedError(\"It is designed for batching inference. AR-chunk is not supported currently.\")\n\n                qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)\n                if self.qk_norm:\n                    q, k, v = qkv.unbind(2)\n                    q, k = self.q_norm(q), self.k_norm(k)\n                    # re-bind\n                    qkv = torch.stack((q, k, v), dim=2)\n                \n                # pack qkv with seq_len\n                qkv_collect = []\n                for i in range(qkv.shape[0]):\n                    qkv_collect.append(\n                        qkv[i, :seq_len[i], :, :, :]\n                    )\n                \n                qkv = torch.cat(qkv_collect, dim=0)\n\n                x = flash_attn_varlen_qkvpacked_func(qkv=qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, dropout_p=self.attn_drop.p if self.training else 0.)\n                \n                # unpack and pad 0\n                x_collect = []\n                for i in range(B):\n                    x_collect.append(\n                        x[cu_seqlens[i]:cu_seqlens[i+1], :, :]\n                    )\n                x = torch.nn.utils.rnn.pad_sequence(x_collect, batch_first=True, padding_value=0)\n\n        else:\n            q = q * self.scale\n            attn = q @ k.transpose(-2, -1)\n            attn = attn.softmax(dim=-1)\n            attn = self.attn_drop(attn)\n            x = attn @ v\n            x = x.transpose(1, 2)\n\n        x = x.reshape(B, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\ndef modulate(x, shift, scale):\n    return x * (1 + scale) + shift\n\n\nclass FinalLayer(nn.Module):\n    \"\"\"\n    The final layer of DiT.\n    \"\"\"\n    def __init__(self, hidden_size, out_channels):\n        super().__init__()\n        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)\n        self.linear = nn.Linear(hidden_size, out_channels, bias=True)\n        self.adaLN_modulation = nn.Sequential(\n            nn.SiLU(),\n            nn.Linear(hidden_size, 2 * hidden_size, bias=True)\n        )\n\n    def forward(self, x, c):\n        shift, scale = self.adaLN_modulation(c).chunk(2, dim=2)\n        x = modulate(self.norm_final(x), shift, scale)\n        x = self.linear(x)\n        return x\n\n\nclass DiTBlock(nn.Module):\n    
\"\"\"\n    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.\n    \"\"\"\n    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, ffn_type=\"conv1d_conv1d\", ffn_gated_glu=True, ffn_act_layer=\"gelu\", ffn_conv_kernel_size=5, **block_kwargs):\n        super().__init__()\n        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)\n        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)\n\n\n        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)\n \n        if ffn_type == \"vanilla_mlp\":\n            from timm.models.vision_transformer import Mlp\n            mlp_hidden_dim = int(hidden_size * mlp_ratio)\n            approx_gelu = lambda: nn.GELU(approximate=\"tanh\")\n            self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)\n        else:\n            raise NotImplementedError(f\"FFN type {ffn_type} is not implemented\")\n        \n        self.adaLN_modulation = nn.Sequential(\n            nn.SiLU(),\n            nn.Linear(hidden_size, 6 * hidden_size, bias=True)\n        )\n\n    def forward(self, x, c, seq_len, cu_seqlens, cu_maxlen, cu_seqlens_k, cu_maxlen_k, mask, rotary_pos_emb=None, incremental_state=None, nopadding=True):\n        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=2)\n\n        x_ = modulate(self.norm1(x), shift_msa, scale_msa)\n\n        if incremental_state is not None:\n            if \"attn_kvcache\" not in incremental_state:\n                incremental_state[\"attn_kvcache\"] = {}\n            inc_attn = incremental_state[\"attn_kvcache\"]\n        else:\n            inc_attn = None\n\n        x_ = self.attn(x_, seq_len=seq_len, cu_seqlens=cu_seqlens, max_seqlen=cu_maxlen, cu_seqlens_k=cu_seqlens_k, max_seqlen_k=cu_maxlen_k, rotary_pos_emb=rotary_pos_emb, incremental_state=inc_attn, nopadding=nopadding)\n        \n        if not nopadding:\n            x_ = x_ * mask[:, :, None]\n        \n        x = x + gate_msa * x_\n\n        x_ = modulate(self.norm2(x), shift_mlp, scale_mlp)\n        \n        x_ = self.mlp(x_)\n\n        if not nopadding:\n            x_ = x_ * mask[:, :, None]\n\n        x = x + gate_mlp * x_\n        return x\n"
  },
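  {
    "path": "examples/rotary_embedding_demo.py",
    "content": "\"\"\"Shape sketch for the rotary embedding helpers in dit_block.py.\n\nIllustrative only: importing these modules pulls in the flash_attn dependency,\nso this sketch assumes flash-attn is installed even though the math below runs\non CPU; all sizes are arbitrary stand-ins.\n\"\"\"\nimport torch\n\nfrom modules.audio_detokenizer.flow_matching.dit_block import apply_rotary_emb\nfrom modules.audio_detokenizer.flow_matching.model import precompute_freqs_cis\n\nbsz, seqlen, num_heads, head_dim = 2, 16, 4, 64\n\n# One complex rotation per (position, feature pair); gather rows by position id.\nfreqs_cis = precompute_freqs_cis(head_dim, 4096)             # [4096, head_dim // 2]\nposition_ids = torch.arange(seqlen).unsqueeze(0).expand(bsz, -1)\nrope = freqs_cis[position_ids]                               # [bsz, seqlen, head_dim // 2]\n\nq = torch.randn(bsz, seqlen, num_heads, head_dim)\nk = torch.randn(bsz, seqlen, num_heads, head_dim)\nq_rot, k_rot = apply_rotary_emb(q, k, rope)\nprint(q_rot.shape, k_rot.shape)  # unchanged: [bsz, seqlen, num_heads, head_dim]\n"
  },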
  {
    "path": "modules/audio_detokenizer/flow_matching/model.py",
    "content": "import torch\nimport torch.nn as nn\nimport math\nfrom modules.audio_detokenizer.flow_matching.dit_block import DiTBlock, FinalLayer\n\ndef precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0,\n                         interpolation_factor: int = 1, max_seq_length: int = 4096):\n    print(f'using rope base theta = {theta}, interpolation factor = {interpolation_factor}')\n    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))\n    \n    # ROPE type-A extention\n    # we choose to use interpolation rather than extrapolation for better position encoding\n    # for scale purposes, t should be a float tensor\n    t = torch.arange(end, device=freqs.device).float()\n    scale = 1.0 / float(interpolation_factor)\n    t *= scale\n\n    freqs = torch.outer(t, freqs).float()  # type: ignore\n    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64\n\n    # Sometimes, we don't need so many rope emb as seq_len is smaller than max_pos_emb\n    # e.g. rope 1M but seqlen 32k, this will cause gpu memory waste\n    if max_seq_length < end:\n        freqs_cis = freqs_cis[:max_seq_length,].clone()\n    return freqs_cis\n\n\nclass TimestepEmbedder(nn.Module):\n    \"\"\"\n    Embeds scalar timesteps into vector representations.\n    \"\"\"\n    def __init__(self, hidden_size, frequency_embedding_size=256):\n        super().__init__()\n        self.mlp = nn.Sequential(\n            nn.Linear(frequency_embedding_size, hidden_size, bias=True),\n            nn.SiLU(),\n            nn.Linear(hidden_size, hidden_size, bias=True),\n        )\n        self.frequency_embedding_size = frequency_embedding_size\n\n    @staticmethod\n    def timestep_embedding(t, dim, max_period=10000):\n        \"\"\"\n        Create sinusoidal timestep embeddings.\n        :param t: a 1-D Tensor of N indices, one per batch element.\n                          These may be fractional.\n        :param dim: the dimension of the output.\n        :param max_period: controls the minimum frequency of the embeddings.\n        :return: an (N, D) Tensor of positional embeddings.\n        \"\"\"\n        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py\n        half = dim // 2\n        freqs = torch.exp(\n            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half\n        ).float().to(device=t.device)\n        args = t[:, None].float() * freqs[None]\n        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)\n        if dim % 2:\n            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)\n        return embedding\n\n    def forward(self, t):\n        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)\n        t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))\n        return t_emb\n    \n\nclass SinusoidalPositionalEmbedding(nn.Module):\n    \"\"\"This module produces sinusoidal positional embeddings of any length.\n\n    Padding symbols are ignored.\n    \"\"\"\n\n    def __init__(self, embedding_dim, padding_idx, init_size=1024):\n        super().__init__()\n        self.embedding_dim = embedding_dim\n        self.padding_idx = padding_idx\n        self.weights = SinusoidalPositionalEmbedding.get_embedding(\n            init_size,\n            embedding_dim,\n            padding_idx,\n        )\n        self.register_buffer('_float_tensor', torch.FloatTensor(1))\n\n    @staticmethod\n    def get_embedding(num_embeddings, embedding_dim, 
padding_idx=None):\n        \"\"\"Build sinusoidal embeddings.\n\n        This matches the implementation in tensor2tensor, but differs slightly\n        from the description in Section 3.5 of \"Attention Is All You Need\".\n        \"\"\"\n        half_dim = embedding_dim // 2   # d/2\n        emb = math.log(10000) / (half_dim - 1)   # 2*log(10000)/(d-2)\n        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)   # -2i/(d-2)*log(10000); i from 0 to (d-2)/2; shape: (d/2, )\n        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)   # pos/[1000 ** (2i/(d-2))]; shape: (num_embeddings, d/2)\n        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)   # shape: (num_embeddings, d)\n        if embedding_dim % 2 == 1:\n            # zero pad\n            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)\n        if padding_idx is not None:\n            emb[padding_idx, :] = 0\n        return emb\n\n    def forward(self, input, incremental_state=None, timestep=None, **kwargs):\n        \"\"\"Input is expected to be of size [bsz x seqlen].\"\"\"\n        bsz, seq_len = input.shape[:2]\n        max_pos = self.padding_idx + 1 + seq_len\n        if self.weights is None or max_pos > self.weights.size(0):\n            # recompute/expand embeddings if needed\n            self.weights = SinusoidalPositionalEmbedding.get_embedding(\n                max_pos,\n                self.embedding_dim,\n                self.padding_idx,\n            )\n        self.weights = self.weights.to(self._float_tensor)\n\n        if incremental_state is not None:\n            # positions is the same for every token when decoding a single step\n            pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len\n            return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)\n\n        positions = self.make_positions(input, self.padding_idx)\n        return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()   # (B, T, dim)\n\n    def max_positions(self):\n        \"\"\"Maximum number of supported positions.\"\"\"\n        return int(1e5)  # an arbitrary large number\n    \n    def make_positions(self, tensor, padding_idx):\n        \"\"\"Replace non-padding symbols with their position numbers.\n\n        Position numbers begin at padding_idx+1. Padding symbols are ignored.\n        \"\"\"\n        # The series of casts and type-conversions here are carefully\n        # balanced to both work with ONNX export and XLA. 
In particular XLA\n        # prefers ints, cumsum defaults to output longs, and ONNX doesn't know\n        # how to handle the dtype kwarg in cumsum.\n        mask = tensor.ne(padding_idx).int()\n        return (\n                    torch.cumsum(mask, dim=1).type_as(mask) * mask\n            ).long() + padding_idx\n    \n\nclass DiTPrefix(nn.Module):\n    \"\"\"\n    Diffusion model with a Transformer backbone.\n    \"\"\"\n    def __init__(\n        self,\n        input_size,\n        output_size,\n        semantic_vocab_size,\n        hidden_size=1024,\n        depth=12,\n        num_heads=4,\n        # mlp related\n        mlp_ratio=4.0,\n        ffn_type=\"conv1d_conv1d\",\n        ffn_gated_glu=True,\n        ffn_act_layer=\"gelu\",\n        ffn_conv_kernel_size=5,\n\n        # rope\n        use_rope=False,\n        rope_params={\n                \"max_position_embeddings\": 4096,\n                \"rope_base\": 10000.0,\n                \"rope_interpolation_factor\": 1.0,\n            },\n\n\n        position_embedding_type=\"sincos\",\n        max_seq_len=4096,\n        prompt_cfg_dropout=0.0\n    ):\n        super().__init__()\n        self.num_heads = num_heads\n\n        self.prompt_cfg_dropout = prompt_cfg_dropout\n\n        self.t_embedder = TimestepEmbedder(hidden_size)\n\n        self.semantic_token_embedding = nn.Embedding(semantic_vocab_size, hidden_size)\n\n        self.input_linear = nn.Linear(input_size, hidden_size)\n\n        # position embedding\n        if position_embedding_type == \"learnable\":\n            self.position_embedding = nn.Embedding(max_seq_len+1, hidden_size)\n        elif position_embedding_type == \"sincos\":\n            self.position_embedding = SinusoidalPositionalEmbedding(hidden_size, 0, max_seq_len+1)\n        elif position_embedding_type == \"skip\":\n            self.position_embedding = None\n        else:\n            raise NotImplementedError(\"Position embedding type: {} not implemented.\".format(position_embedding_type))\n\n        self.use_rope = use_rope\n\n        if self.use_rope:\n            \n            assert hidden_size % num_heads == 0, \"Hidden size must be divisible by num_heads for rope position embedding.\"\n            rope_dim = hidden_size // num_heads\n\n            self.rotary_pos_emb = precompute_freqs_cis(\n                rope_dim, rope_params[\"max_position_embeddings\"],\n                theta=rope_params[\"rope_base\"],\n                interpolation_factor=rope_params[\"rope_interpolation_factor\"],\n                max_seq_length=max_seq_len\n            )\n\n        self.blocks = nn.ModuleList([\n            DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, \n                     ffn_type=ffn_type, ffn_conv_kernel_size=ffn_conv_kernel_size, ffn_gated_glu=ffn_gated_glu, ffn_act_layer=ffn_act_layer) for _ in range(depth)\n        ])\n        self.final_layer = FinalLayer(hidden_size, output_size)\n        self.initialize_weights()\n\n    def initialize_weights(self):\n        # Initialize transformer layers:\n        def _basic_init(module):\n            if isinstance(module, nn.Linear):\n                torch.nn.init.xavier_uniform_(module.weight)\n                if module.bias is not None:\n                    nn.init.constant_(module.bias, 0)\n        self.apply(_basic_init)\n\n\n        # Initialize timestep embedding MLP:\n        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)\n        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)\n\n        # Zero-out adaLN modulation layers 
in DiT blocks:\n        for block in self.blocks:\n            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)\n            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)\n\n        # Zero-out output layers:\n        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)\n        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)\n        nn.init.constant_(self.final_layer.linear.weight, 0)\n        nn.init.constant_(self.final_layer.linear.bias, 0)\n\n    def forward(self, x, position_ids, t, condition, seq_len, cu_seqlens, cu_maxlen, cu_seqlens_k, cu_maxlen_k, mask, incremental_state=None, nopadding=True):\n        \"\"\"\n        Forward pass of DiT.\n        x: (N, T, C) tensor of inputs (latent representations of speech)\n        position_ids: (N, T) tensor of positional indices\n        t: (N,) tensor of diffusion timesteps\n        condition: (N, T) tensor of semantic tokens\n        seq_len: (N,) tensor of sequence lengths\n        \"\"\"\n\n        condition = self.semantic_token_embedding(condition)  # (N, T, D)\n\n        x = self.input_linear(x)   \n\n        if self.position_embedding is not None:\n            position_emb = self.position_embedding(position_ids)\n            x = x + position_emb\n        \n        # ROPE        \n        if self.use_rope:\n            bsz, seqlen = position_ids.shape\n            if self.rotary_pos_emb.device != position_ids.device:\n                self.rotary_pos_emb = self.rotary_pos_emb.to(position_ids.device)\n            rotary_pos_emb = torch.zeros((bsz, seqlen, self.rotary_pos_emb.shape[1]),\n                                          dtype=self.rotary_pos_emb.dtype,\n                                          device=self.rotary_pos_emb.device)\n            for b in range(bsz):\n                cur_rope = rotary_pos_emb[b]\n                cur_position_ids = position_ids[b]\n                cur_rope[:] = self.rotary_pos_emb[cur_position_ids]\n        else:\n            rotary_pos_emb = None\n\n        t = self.t_embedder(t)                   # (N, D)\n        c = t.unsqueeze(1) + condition           # (N, T, D)\n\n\n        for block_idx, block in enumerate(self.blocks):\n            # x = block(x, c, attn_mask)  # (N, T, D)\n            # XXX mask could be None because we always use full mask\n\n            if incremental_state is not None:\n                if block_idx not in incremental_state:\n                    incremental_state[block_idx] = {}\n                incr = incremental_state[block_idx]\n            else:\n                incr = None\n            \n            x = block(x=x, c=c, seq_len=seq_len, cu_seqlens=cu_seqlens, cu_maxlen=cu_maxlen, cu_seqlens_k=cu_seqlens_k, cu_maxlen_k=cu_maxlen_k, mask=mask, rotary_pos_emb=rotary_pos_emb, incremental_state=incr, nopadding=nopadding)\n\n        x = self.final_layer(x, c)               # (N, T, C)\n        return x\n\n"
  },
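  {
    "path": "examples/timestep_embedder_demo.py",
    "content": "\"\"\"Small sketch of TimestepEmbedder: scalar diffusion timesteps -> conditioning vectors.\n\nIllustrative only (the import pulls in flash-attn via dit_block); hidden_size=1024\nmirrors the DiTPrefix default and is otherwise arbitrary.\n\"\"\"\nimport torch\n\nfrom modules.audio_detokenizer.flow_matching.model import TimestepEmbedder\n\nembedder = TimestepEmbedder(hidden_size=1024)\nt = torch.tensor([0, 250, 999])  # three timesteps in [0, 1000)\nemb = embedder(t)                # sinusoidal features -> 2-layer MLP\nprint(emb.shape)                 # torch.Size([3, 1024])\n"
  },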
  {
    "path": "modules/audio_detokenizer/flow_matching/ode_wrapper.py",
    "content": "import torch\nimport torch.nn as nn\nfrom functools import lru_cache\nimport copy\n\n\n@lru_cache(maxsize=1)\ndef get_cached_zeros(numel, device=\"cpu\", dtype=torch.float32):\n    return torch.zeros(numel, device=device, dtype=dtype)\n\nclass StreamingODEWrapperForPrefix(nn.Module):\n    def __init__(self, net, x_mask, x_cond, use_cfg=False, use_cfg_rescale=True, cfg_init=1.0, cfg_scale=4.0, cfg_schedule=\"linear\", cfg_token_id=0):\n        super(StreamingODEWrapperForPrefix, self).__init__()\n        self.net = net\n        self.x_mask = x_mask\n        self.x_cond = x_cond\n\n        assert use_cfg == False, \"cfg is not supported in streaming detokenizer\"\n\n        self.use_cfg = use_cfg\n        self.use_cfg_rescale = use_cfg_rescale\n        self.cfg_init = cfg_init\n        self.cfg_scale = cfg_scale\n        self.cfg_token_id = cfg_token_id\n        self.cfg_schedule = cfg_schedule\n        self.position_ids = None\n        self.seq_len = None\n\n        self.incremental_state = {}\n        self.kv_cache_tokens = 0\n        self.cu_seqlens = None\n        self.cu_maxlen = None\n\n        self.cu_seqlens_k = None\n        self.cu_maxlen_k = None\n        self.previous_seqlen = None\n\n    def clear_all_states(self):\n        self.incremental_state = {}\n        self.kv_cache_tokens = 0\n        self.cu_seqlens = None\n        self.cu_maxlen = None\n\n        self.cu_seqlens_k = None\n        self.cu_maxlen_k = None\n        self.previous_seqlen = None\n    \n    def state_dict(self):\n        return {\n            \"incremental_state\": copy.deepcopy(self.incremental_state),\n            \"kv_cache_tokens\": copy.deepcopy(self.kv_cache_tokens),\n            \"cu_seqlens\": copy.deepcopy(self.cu_seqlens),\n            \"cu_maxlen\": copy.deepcopy(self.cu_maxlen),\n            \"cu_seqlens_k\": copy.deepcopy(self.cu_seqlens_k),\n            \"cu_maxlen_k\": copy.deepcopy(self.cu_maxlen_k),\n            \"previous_seqlen\": copy.deepcopy(self.previous_seqlen)\n        }\n    \n    def load_state_dict(self, state_dict):\n        self.incremental_state = state_dict[\"incremental_state\"]\n        self.kv_cache_tokens = state_dict[\"kv_cache_tokens\"]\n        self.cu_seqlens = state_dict[\"cu_seqlens\"]\n        self.cu_maxlen = state_dict[\"cu_maxlen\"]\n        self.cu_seqlens_k = state_dict[\"cu_seqlens_k\"]\n        self.cu_maxlen_k = state_dict[\"cu_maxlen_k\"]\n        self.previous_seqlen = state_dict[\"previous_seqlen\"]\n\n    def set_conditions(self, x_mask, x_cond, start_position_id, cache={}):\n        if not self.use_cfg:\n            self.x_mask = x_mask\n            self.x_cond = x_cond\n        else:\n            self.x_cond = torch.cat((x_cond, x_cond), dim=0)\n            self.x_mask = torch.cat((x_mask, x_mask), dim=0)\n\n        position_ids_cur = [i for i in range(start_position_id, self.x_cond.shape[1] + start_position_id)]\n        position_ids = torch.tensor([position_ids_cur])\n\n\n        if not self.use_cfg:\n            self.position_ids = position_ids.to(self.x_cond.device).long()\n            self.seq_len = torch.Tensor([position_ids.shape[1]]).to(self.x_cond.device).long()\n        else:\n            self.position_ids = torch.cat((position_ids, position_ids), dim=0).to(self.x_cond.device).long()\n            self.seq_len = torch.Tensor([position_ids.shape[1], position_ids.shape[1]]).to(self.x_cond.device).long()\n\n        cu_seqlens = torch.cumsum(self.seq_len, dim=0)\n        self.cu_seqlens = 
torch.cat([torch.Tensor([0]).to(cu_seqlens.device), cu_seqlens], dim=0).int()\n        self.cu_maxlen = self.seq_len.cpu().max()\n\n        if self.cu_seqlens_k is None:\n            self.cu_seqlens_k = self.cu_seqlens\n            self.cu_maxlen_k = self.cu_maxlen\n            previous_seqlen = self.seq_len\n        else:\n            previous_seqlen_old = cache[\"previous_seqlen\"]\n            previous_seqlen = previous_seqlen_old + self.seq_len\n            # calculate cu_seqlens_k\n            cu_seqlens_k = torch.cumsum(previous_seqlen, dim=0)\n            self.cu_seqlens_k = torch.cat([torch.Tensor([0]).to(cu_seqlens_k.device), cu_seqlens_k], dim=0).int()\n            self.cu_maxlen_k = previous_seqlen.cpu().max()\n        self.previous_seqlen = previous_seqlen\n        ret_cache = {\n            \"previous_seqlen\": previous_seqlen\n        }\n        return ret_cache\n\n    def update_incremental_state(self, reserve_kv_cache_tokens=0, max_kv_cache_tokens=900, condition_cache=None):\n\n        assert reserve_kv_cache_tokens <= max_kv_cache_tokens, \"reserve_kv_cache_tokens should be less than or equal to max_kv_cache_tokens\"\n\n        if condition_cache is None:\n            condition_cache = {}\n\n        for layer_idx, layer_cache in self.incremental_state.items():\n            # update attention kv cache\n            layer_cache[\"attn_kvcache\"][\"prev_k\"] = layer_cache[\"attn_kvcache\"][\"cur_k\"]\n            layer_cache[\"attn_kvcache\"][\"prev_v\"] = layer_cache[\"attn_kvcache\"][\"cur_v\"]\n\n            self.kv_cache_tokens = layer_cache[\"attn_kvcache\"][\"prev_k\"].shape[1]\n\n            if self.kv_cache_tokens > max_kv_cache_tokens:\n                # drop old tokens from reserve kv cache tokens to max_kv_cache_tokens\n                reserve_tokens_excludeprompt = max_kv_cache_tokens - reserve_kv_cache_tokens\n\n                if reserve_kv_cache_tokens == 0:\n                    layer_cache[\"attn_kvcache\"][\"prev_k\"] = layer_cache[\"attn_kvcache\"][\"prev_k\"][:, -reserve_tokens_excludeprompt:]\n                    layer_cache[\"attn_kvcache\"][\"prev_v\"] = layer_cache[\"attn_kvcache\"][\"prev_v\"][:, -reserve_tokens_excludeprompt:]\n                elif reserve_tokens_excludeprompt == 0:\n                    layer_cache[\"attn_kvcache\"][\"prev_k\"] = layer_cache[\"attn_kvcache\"][\"prev_k\"][:, :reserve_kv_cache_tokens]\n                    layer_cache[\"attn_kvcache\"][\"prev_v\"] = layer_cache[\"attn_kvcache\"][\"prev_v\"][:, :reserve_kv_cache_tokens]\n                else:\n                    layer_cache[\"attn_kvcache\"][\"prev_k\"] = torch.cat([\n                            layer_cache[\"attn_kvcache\"][\"prev_k\"][:, :reserve_kv_cache_tokens],\n                            layer_cache[\"attn_kvcache\"][\"prev_k\"][:, -reserve_tokens_excludeprompt:]\n                        ], dim=1)\n                    \n                    layer_cache[\"attn_kvcache\"][\"prev_v\"] = torch.cat([\n                            layer_cache[\"attn_kvcache\"][\"prev_v\"][:, :reserve_kv_cache_tokens],\n                            layer_cache[\"attn_kvcache\"][\"prev_v\"][:, -reserve_tokens_excludeprompt:]\n                        ], dim=1)\n\n\n                bsz = layer_cache[\"attn_kvcache\"][\"prev_k\"].shape[0]\n                self.previous_seqlen = torch.Tensor([layer_cache[\"attn_kvcache\"][\"prev_k\"].shape[1] for i in range(bsz)]).to(layer_cache[\"attn_kvcache\"][\"prev_k\"].device).long()\n                condition_cache[\"previous_seqlen\"] = self.previous_seqlen\n                self.kv_cache_tokens = 
layer_cache[\"attn_kvcache\"][\"prev_k\"].shape[1]\n\n            # clear current cache\n            layer_cache[\"attn_kvcache\"].pop(\"cur_k\")\n            layer_cache[\"attn_kvcache\"].pop(\"cur_v\")\n\n\n    def forward(self, t, x, args=None):\n        # t = torch.tensor([t * 1000] * x.shape[0], device=x.device, dtype=x.dtype).long()\n        t = get_cached_zeros(x.shape[0], device=x.device, dtype=torch.long) + (t * 1000).long()\n\n        if self.use_cfg:\n            raise NotImplementedError(\"cfg is not supported in streaming detokenizer.\")\n        else:\n            pred_noise = self.net(x=x, condition=self.x_cond, t=t, position_ids=self.position_ids, \n                                  cu_seqlens=self.cu_seqlens, cu_maxlen=self.cu_maxlen,\n                                  cu_seqlens_k=self.cu_seqlens_k, cu_maxlen_k=self.cu_maxlen_k,\n                                  incremental_state=self.incremental_state, nopadding=True,\n                                  mask=None, seq_len=None\n                                  )   \n            return pred_noise\n"
  },
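  {
    "path": "examples/kv_cache_trimming_sketch.py",
    "content": "\"\"\"Standalone sketch of the kv-cache trimming rule in update_incremental_state.\n\nIllustrative only: it replays the slicing arithmetic on a dummy cache to show how\nthe first reserve_kv_cache_tokens entries (the prefilled prompt) stay pinned at\nthe front while the oldest generated tokens are dropped once the cache exceeds\nmax_kv_cache_tokens; shapes are arbitrary stand-ins.\n\"\"\"\nimport torch\n\nreserve_kv_cache_tokens = 300  # prompt tokens kept verbatim at the front\nmax_kv_cache_tokens = 900\n\nprev_k = torch.randn(1, 1200, 4, 64)  # [bsz, cached_tokens, num_heads, head_dim]\n\nif prev_k.shape[1] > max_kv_cache_tokens:\n    # newest generated tokens to keep after the pinned prompt\n    reserve_tokens_excludeprompt = max_kv_cache_tokens - reserve_kv_cache_tokens\n    prev_k = torch.cat([\n        prev_k[:, :reserve_kv_cache_tokens],\n        prev_k[:, -reserve_tokens_excludeprompt:],\n    ], dim=1)\n\nprint(prev_k.shape)  # torch.Size([1, 900, 4, 64])\n"
  },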
  {
    "path": "modules/audio_detokenizer/flow_matching/scheduler.py",
    "content": "import torch\nfrom abc import abstractmethod, ABC\ntry:\n    from torchdyn.core import NeuralODE\n    NEURALODE_INSTALLED = True\nexcept ImportError:\n    NEURALODE_INSTALLED = False\n\nclass SchedulerBase(ABC):\n    def __init__(self) -> None:\n        pass\n    \n    @abstractmethod\n    def set_timesteps(self):\n        pass\n    \n    @abstractmethod\n    def step(self):\n        pass\n\n    @abstractmethod\n    def add_noise(self):\n        pass\n\n\nclass StreamingFlowMatchingScheduler(SchedulerBase):\n    def __init__(self, timesteps=1000, sigma_min=1e-4,\n                    ) -> None:\n        super().__init__()\n\n        self.sigma_min = sigma_min\n        self.timesteps = timesteps\n        self.t_min = 0\n        self.t_max = 1 - self.sigma_min\n\n        self.neural_ode = None\n\n    \n    def set_timesteps(self, timesteps=15):\n        self.timesteps = timesteps\n\n    def step(self, xt, predicted_v):\n\n        h = (self.t_max - self.t_min) / self.timesteps\n        h = h * torch.ones(xt.shape[0], dtype=xt.dtype, device=xt.device)\n\n        xt = xt + h * predicted_v\n        return xt\n    \n    def sample(self, ode_wrapper, time_steps, xt, verbose=False, x0=None):\n        h = (self.t_max - self.t_min) / self.timesteps\n        h = h * torch.ones(xt.shape[0], dtype=xt.dtype, device=xt.device)\n\n        if verbose:\n            gt_v = x0 - xt\n\n        for t in time_steps:\n            predicted_v = ode_wrapper(t, xt)\n            if verbose:\n                dist = torch.mean(torch.nn.functional.l1_loss(gt_v, predicted_v))\n                print(\"Time: {}, Distance: {}\".format(t, dist))\n            xt = xt + h * predicted_v\n        return xt\n    \n    def sample_by_neuralode(self, ode_wrapper, time_steps, xt, verbose=False, x0=None):\n        if not NEURALODE_INSTALLED:\n            raise ImportError(\"NeuralODE is not installed, please install it first.\")\n        \n        if self.neural_ode is None:\n            self.neural_ode = NeuralODE(ode_wrapper, solver='euler', sensitivity=\"adjoint\", atol=self.sigma_min, rtol=self.sigma_min)\n\n        eval_points, traj = self.neural_ode(xt, time_steps)\n        return traj[-1]\n\n \n    def add_noise(self, original_samples: torch.FloatTensor,\n                        noise: torch.FloatTensor,\n                        timesteps: torch.IntTensor,):\n        ut = original_samples - (1 - self.sigma_min) * noise  # 和ut的梯度没关系\n        t_unsqueeze = timesteps.unsqueeze(1).unsqueeze(1).float() / self.timesteps\n        x_noisy = t_unsqueeze * original_samples + (1. - (1 - self.sigma_min) * t_unsqueeze) * noise\n        return x_noisy, ut\n"
  },
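  {
    "path": "examples/flow_matching_euler_demo.py",
    "content": "\"\"\"Toy Euler integration with StreamingFlowMatchingScheduler.sample.\n\nIllustrative only: the ode_wrapper argument is a closure with a known target,\nstanding in for StreamingODEWrapperForPrefix. It uses the same conditional\nflow-matching velocity ut = x1 - (1 - sigma_min) * x0 that add_noise defines,\nwhich is constant in t, so Euler recovers the target up to sigma_min.\n\"\"\"\nimport torch\n\nfrom modules.audio_detokenizer.flow_matching.scheduler import StreamingFlowMatchingScheduler\n\nscheduler = StreamingFlowMatchingScheduler()\nscheduler.set_timesteps(15)\n\nx1 = torch.randn(1, 80)  # pretend target mel frame\nx0 = torch.randn(1, 80)  # pure noise at t = 0\nut = x1 - (1 - scheduler.sigma_min) * x0  # constant conditional velocity\n\nt_span = torch.linspace(0, 1, scheduler.timesteps)\nresult = scheduler.sample(lambda t, x: ut, time_steps=t_span, xt=x0)\nprint(torch.max((result - x1).abs()))  # ~1e-4: exact up to sigma_min\n"
  },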
  {
    "path": "modules/audio_detokenizer/semantic_fm_prefix_streaming.py",
    "content": "import yaml\nimport logging\nimport time\n\nimport os\nimport torch\n\nfrom modules.audio_detokenizer.flow_matching.ode_wrapper import StreamingODEWrapperForPrefix\nfrom modules.audio_detokenizer.flow_matching.model import DiTPrefix\nfrom modules.audio_detokenizer.flow_matching.scheduler import StreamingFlowMatchingScheduler\n\n\nlogger = logging.getLogger(__name__)\n\n\nclass StreamingSemanticFMWrapper:\n    def __init__(self, speech_model: DiTPrefix, max_kv_cache_tokens=900, max_prompt_chunk=2,\n                 use_cfg=True, use_cfg_rescale=True, cfg_init=1.5, cfg_scale=7.5, cfg_schedule=\"linear\", cfg_token_id=0, \n                 normalize_mel=False, mel_mean=None, mel_std=None, device: torch.device = torch.device(\"cpu\")) -> None:\n        \n        self.dtype = torch.bfloat16\n        self.speech_model = speech_model.to(device).to(self.dtype)\n        self.speech_model = self.speech_model.eval()\n        self.device = device\n        self.normalize_mel = normalize_mel\n        self.mel_mean = mel_mean\n        self.mel_std = mel_std\n\n        self.use_cfg = use_cfg\n        self.use_cfg_rescale = use_cfg_rescale\n        self.cfg_init = cfg_init\n        self.cfg_scale = cfg_scale\n        self.cfg_schedule = cfg_schedule\n        \n        self.incremental_state = {}\n        self.condition_cache = {\"previous_seqlen\": 0}\n\n        logger.info(f\">>> SemanticFMWrapper initialized with use_cfg={use_cfg}, use_cfg_rescale={use_cfg_rescale}, cfg_init={cfg_init}, cfg_scale={cfg_scale}, cfg_schedule={cfg_schedule}\")\n\n        self.scheduler = StreamingFlowMatchingScheduler()\n        self.ode_wrapper = StreamingODEWrapperForPrefix(net=self.speech_model, x_mask=None, x_cond=None,\n                                      use_cfg=use_cfg, use_cfg_rescale=use_cfg_rescale, cfg_init=cfg_init, cfg_scale=cfg_scale, cfg_schedule=cfg_schedule, cfg_token_id=cfg_token_id)\n    \n        self.max_kv_cache_tokens = max_kv_cache_tokens\n        self.max_prompt_chunk = max_prompt_chunk\n        self.reserve_kv_cache_tokens = 0\n\n    @torch.inference_mode()\n    def infer_chunk(self, xt_chunk, semantic_tokens_chunk, start_position_id, \n                    cache = None, look_ahead_tokens=0,\n                    ode_steps=15, verbose=False, ode_solver=\"neural_ode_euler\"):\n        \"\"\"\n            semantic_tokens: [T_1], torch.LongTensor\n            xt: [T_2, 80], torch.Tensor, DO NOT normalize it outside\n            ode_steps: int, number of ode steps, default 15\n            verbose: bool, default False\n            ode_solver: str, ode solver, expected in (\"neural_ode_euler\", \"naive_euler\"), default \"neural_ode_euler\"\n        \"\"\"\n        bs = 1\n\n        self.scheduler.set_timesteps(ode_steps)\n\n        semantic_tokens_chunk = semantic_tokens_chunk.unsqueeze(0).to(self.device)\n        xt_chunk = xt_chunk.unsqueeze(0).to(self.device).to(self.dtype)\n\n        t_span = torch.linspace(0, 1, self.scheduler.timesteps)\n\n        x_mask = torch.zeros(bs, xt_chunk.shape[1], device=self.device).bool()\n        \n        cache_ret = self.ode_wrapper.set_conditions(x_mask=x_mask, x_cond=semantic_tokens_chunk, start_position_id=start_position_id, cache=self.condition_cache)\n\n        if verbose:\n            t_start = time.time()\n        if ode_solver == \"neural_ode_euler\":\n            x_t = self.scheduler.sample_by_neuralode(self.ode_wrapper, time_steps=t_span, xt=xt_chunk, verbose=False)\n        elif ode_solver == \"naive_euler\":\n            x_t = 
self.scheduler.sample(ode_wrapper=self.ode_wrapper, time_steps=t_span, xt=xt_chunk, verbose=False)\n        else:\n            raise NotImplementedError(\"ode_solver should be in ('neural_ode_euler', 'naive_euler')\")\n        \n        if look_ahead_tokens > 0:\n            semantic_tokens_left = semantic_tokens_chunk.view(-1)[-look_ahead_tokens:]\n            cache[\"semantic_token\"] = semantic_tokens_left\n            x_t_ret = x_t[:, :-look_ahead_tokens, :]\n        else:\n            x_t_ret = x_t\n\n        if look_ahead_tokens > 0:\n            x_mask = torch.zeros(bs, xt_chunk.shape[1] - look_ahead_tokens, device=self.device).bool()\n            self.condition_cache = self.ode_wrapper.set_conditions(x_mask=x_mask, x_cond=semantic_tokens_chunk[:, :-look_ahead_tokens], start_position_id=start_position_id, cache=self.condition_cache)\n            self.ode_wrapper(torch.Tensor([0.999]).to(x_t_ret.device), x_t_ret)\n        else:\n            self.condition_cache = cache_ret\n\n        if verbose:\n            t_end = time.time()\n            logger.info(f\"[ODE Chunk] Time cost: {t_end - t_start}\")\n\n        if self.normalize_mel:\n            x_t_ret = x_t_ret * self.mel_std + self.mel_mean\n        return x_t_ret.squeeze(0)\n\n\n    @torch.inference_mode()\n    def infer_mel(self, semantic_tokens, ode_steps=15, chunk_size=150, verbose=False, ode_solver=\"neural_ode_euler\"):\n        \"\"\"\n            semantic_tokens: [T_1], torch.LongTensor\n            prompt: [T_2, 80], torch.Tensor, DO NOT normalize it outside\n            prompt_semantic_tokens, [T_2], torch.LongTensor\n            ode_steps: int, number of ode steps, default 15\n            verbose: bool, default False\n            ode_solver: str, ode solver, expected in (\"neural_ode_euler\", \"naive_euler\"), default \"neural_ode_euler\"\n        \"\"\"\n        assert semantic_tokens.dim() == 1\n\n        x_t = torch.randn(semantic_tokens.shape[0], 80).to(self.device).to(self.dtype)\n\n        seq_len = semantic_tokens.shape[0]\n\n        num_chunks = seq_len // chunk_size\n        if seq_len % chunk_size != 0:\n            num_chunks += 1\n\n        x_pred_collect = []\n\n        if verbose:\n            t_start = time.time()\n\n        for chunk_id in range(num_chunks):\n            start = chunk_id * chunk_size\n            end = min(start + chunk_size, seq_len)\n            semantic_tokens_chunk = semantic_tokens[start:end]\n            x_t_chunk = x_t[start:end, :]\n\n            x_pred = self.infer_chunk(xt_chunk=x_t_chunk, semantic_tokens_chunk=semantic_tokens_chunk, start_position_id=self.start_position_id,\n                                      ode_steps=ode_steps, verbose=verbose, ode_solver=ode_solver)\n            self.start_position_id += end - start\n            self.update_incremental_state()\n\n            x_pred_collect.append(x_pred)\n\n        if verbose:\n            t_end = time.time()\n            logger.info(f\"[ODE] Time cost: {t_end - t_start}\")\n        \n        x_pred = torch.cat(x_pred_collect, dim=0)\n\n        return x_pred\n    \n    def clear_all_states(self):\n        self.start_position_id = 0\n        self.condition_cache = {\"previous_seqlen\": 0}\n        self.ode_wrapper.clear_all_states()\n    \n    def state_dict(self):\n        return {\n            \"start_position_id\": self.start_position_id,\n            \"ode_wrapper\": self.ode_wrapper.state_dict(),\n            \"condition_cache\": self.condition_cache\n        }\n    \n    def load_state_dict(self, state_dict):\n       
 if state_dict is not None:\n            self.start_position_id = state_dict[\"start_position_id\"]\n            self.ode_wrapper.load_state_dict(state_dict[\"ode_wrapper\"])\n            self.condition_cache = state_dict[\"condition_cache\"]\n    \n    def update_incremental_state(self):\n        self.ode_wrapper.update_incremental_state(reserve_kv_cache_tokens=0, max_kv_cache_tokens=self.max_kv_cache_tokens, condition_cache=self.condition_cache)\n    \n    @torch.inference_mode()\n    def prefill(self, mel, semantic_token, chunk_size=150, verbose=False):\n        \"\"\"\n            mel: [T, 80], torch.Tensor\n            semantic_token: [T], torch.LongTensor\n            chunk_size: int, default 150\n        \"\"\"\n        assert mel.dim() == 2\n        assert semantic_token.dim() == 1\n        assert semantic_token.shape[0] == mel.shape[0], \"Semantic token and mel shape mismatch\"\n        seq_len = mel.shape[0]\n        num_chunks = min(seq_len // chunk_size, self.max_prompt_chunk)\n        start_pos = seq_len - num_chunks * chunk_size\n        \n        res_mel = mel[:start_pos, :]\n        res_semantic_token = semantic_token[:start_pos]\n        self.prefill_chunk(res_mel, res_semantic_token, start_position_id=self.start_position_id)\n        self.start_position_id += start_pos\n        self.update_incremental_state()\n        self.reserve_kv_cache_tokens += self.ode_wrapper.kv_cache_tokens\n\n        if verbose:\n            logger.info(\"Prefilling prompt with {} chunks\".format(num_chunks))\n            start_time = time.time()\n\n        for chunk_id in range(num_chunks):\n            start = start_pos + chunk_id * chunk_size\n            end = start + chunk_size\n            mel_chunk = mel[start:end, :]\n            semantic_token_chunk = semantic_token[start:end]\n\n            self.prefill_chunk(mel_chunk, semantic_token_chunk, start_position_id=self.start_position_id)\n            self.start_position_id += end - start\n            \n            self.update_incremental_state()\n            self.reserve_kv_cache_tokens += self.ode_wrapper.kv_cache_tokens\n        \n        \n        if verbose:\n            logger.info(\"Prefilling done in {:.2f} seconds\".format(time.time() - start_time))\n    \n    def prefill_chunk(self, mel_chunk, semantic_tokens_chunk, start_position_id=0):\n        \"\"\"\n            mel_chunk: [T, 80], torch.Tensor, T is the chunk size\n            semantic_tokens_chunk: [T], torch.LongTensor\n            start_position_id: int, default 0\n        \"\"\"\n        bs = 1\n\n        semantic_tokens_chunk = semantic_tokens_chunk.unsqueeze(0).to(self.device)\n        mel_chunk = mel_chunk.unsqueeze(0).to(self.device).to(self.dtype)\n\n        if self.normalize_mel:\n            mel_chunk = (mel_chunk - self.mel_mean) / self.mel_std\n\n        x_mask = torch.zeros(bs, mel_chunk.shape[1], device=self.device).bool()\n        \n        self.condition_cache = self.ode_wrapper.set_conditions(x_mask=x_mask, x_cond=semantic_tokens_chunk, start_position_id=start_position_id, cache=self.condition_cache)\n\n        x_t = torch.Tensor([0.999]).to(self.device)\n\n        self.ode_wrapper(x_t, mel_chunk)\n\n        \n    @classmethod\n    def from_pretrained(cls, model_config, ckpt_path, device, max_prompt_chunk=2, max_kv_cache_tokens=900, use_cfg=True, use_cfg_rescale=True, cfg_init=1.5, cfg_scale=7.5, cfg_schedule=\"linear\"):\n\n        # open yaml file\n        with open(model_config, 'r') as f:\n            config = yaml.safe_load(f)\n        model_config = 
config[\"model\"][\"dit\"]\n        dit = DiTPrefix(\n            input_size=model_config[\"input_size\"],\n            semantic_vocab_size=model_config[\"semantic_vocab_size\"] + 1,\n            hidden_size=model_config[\"hidden_size\"],\n            depth=model_config[\"depth\"],\n            num_heads=model_config[\"num_heads\"],\n            mlp_ratio=model_config[\"mlp_ratio\"],\n            ffn_type=model_config.get(\"ffn_type\", \"conv1d_conv1d\"),\n            ffn_gated_glu=model_config.get(\"ffn_gated_glu\", True),\n            ffn_act_layer=model_config.get(\"ffn_act_layer\", \"gelu\"),\n            ffn_conv_kernel_size=model_config.get(\"ffn_conv_kernel_size\", 5),\n\n            use_rope=model_config.get(\"use_rope\", False),\n            rope_params=model_config.get(\"rope_params\", { \"max_position_embeddings\": 4096,\"rope_base\": 10000,\"rope_interpolation_factor\": 1 }),\n\n            position_embedding_type=model_config[\"position_embedding_type\"],\n            max_seq_len=model_config[\"max_seq_len\"],\n            output_size=model_config[\"input_size\"],\n            prompt_cfg_dropout=0\n        )\n        cfg_semantic_token_id = model_config[\"semantic_vocab_size\"]\n        \n        # load state_dict\n        state_dict = torch.load(ckpt_path, map_location=\"cpu\", weights_only=True)[\"state_dict\"]\n        speech_model_params = {k.replace(\"speech_model.\", \"\"): v for k, v in state_dict.items() if \"speech_model\" in k}\n        dit.load_state_dict(speech_model_params, strict=True)\n        logger.info(f\">>> Loaded checkpoint from {ckpt_path}\")\n\n        return cls(speech_model=dit, device=device, normalize_mel=config[\"normalize_mel\"], mel_mean=config[\"mel_mean\"], mel_std=config[\"mel_std\"], max_prompt_chunk=max_prompt_chunk, max_kv_cache_tokens=max_kv_cache_tokens,\n                   use_cfg=use_cfg, use_cfg_rescale=use_cfg_rescale, cfg_init=cfg_init, cfg_scale=cfg_scale, cfg_schedule=cfg_schedule, cfg_token_id=cfg_semantic_token_id)\n\n\n"
  },
  {
    "path": "modules/audio_detokenizer/vocoder/activations.py",
    "content": "import torch\nfrom torch import nn, sin, pow\nfrom torch.nn import Parameter\n\n\nclass Snake(nn.Module):\n    \"\"\"\n    Implementation of a sine-based periodic activation function\n    Shape:\n        - Input: (B, C, T)\n        - Output: (B, C, T), same shape as the input\n    Parameters:\n        - alpha - trainable parameter\n    References:\n        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:\n        https://arxiv.org/abs/2006.08195\n    Examples:\n        >>> a1 = snake(256)\n        >>> x = torch.randn(256)\n        >>> x = a1(x)\n    \"\"\"\n\n    def __init__(\n        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False\n    ):\n        \"\"\"\n        Initialization.\n        INPUT:\n            - in_features: shape of the input\n            - alpha: trainable parameter\n            alpha is initialized to 1 by default, higher values = higher-frequency.\n            alpha will be trained along with the rest of your model.\n        \"\"\"\n        super(Snake, self).__init__()\n        self.in_features = in_features\n\n        # Initialize alpha\n        self.alpha_logscale = alpha_logscale\n        if self.alpha_logscale:  # Log scale alphas initialized to zeros\n            self.alpha = Parameter(torch.zeros(in_features) * alpha)\n        else:  # Linear scale alphas initialized to ones\n            self.alpha = Parameter(torch.ones(in_features) * alpha)\n\n        self.alpha.requires_grad = alpha_trainable\n\n        self.no_div_by_zero = 0.000000001\n\n    def forward(self, x):\n        \"\"\"\n        Forward pass of the function.\n        Applies the function to the input elementwise.\n        Snake ∶= x + 1/a * sin^2 (xa)\n        \"\"\"\n        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # Line up with x to [B, C, T]\n        if self.alpha_logscale:\n            alpha = torch.exp(alpha)\n        x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)\n\n        return x\n\n\nclass SnakeBeta(nn.Module):\n    \"\"\"\n    A modified Snake function which uses separate parameters for the magnitude of the periodic components\n    Shape:\n        - Input: (B, C, T)\n        - Output: (B, C, T), same shape as the input\n    Parameters:\n        - alpha - trainable parameter that controls frequency\n        - beta - trainable parameter that controls magnitude\n    References:\n        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:\n        https://arxiv.org/abs/2006.08195\n    Examples:\n        >>> a1 = snakebeta(256)\n        >>> x = torch.randn(256)\n        >>> x = a1(x)\n    \"\"\"\n\n    def __init__(\n        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False\n    ):\n        \"\"\"\n        Initialization.\n        INPUT:\n            - in_features: shape of the input\n            - alpha - trainable parameter that controls frequency\n            - beta - trainable parameter that controls magnitude\n            alpha is initialized to 1 by default, higher values = higher-frequency.\n            beta is initialized to 1 by default, higher values = higher-magnitude.\n            alpha will be trained along with the rest of your model.\n        \"\"\"\n        super(SnakeBeta, self).__init__()\n        self.in_features = in_features\n\n        # Initialize alpha\n        self.alpha_logscale = alpha_logscale\n        if self.alpha_logscale:  # Log scale alphas initialized to zeros\n     
       self.alpha = Parameter(torch.zeros(in_features) * alpha)\n            self.beta = Parameter(torch.zeros(in_features) * alpha)\n        else:  # Linear scale alphas initialized to ones\n            self.alpha = Parameter(torch.ones(in_features) * alpha)\n            self.beta = Parameter(torch.ones(in_features) * alpha)\n\n        self.alpha.requires_grad = alpha_trainable\n        self.beta.requires_grad = alpha_trainable\n\n        self.no_div_by_zero = 0.000000001\n\n    def forward(self, x):\n        \"\"\"\n        Forward pass of the function.\n        Applies the function to the input elementwise.\n        SnakeBeta ∶= x + 1/b * sin^2 (xa)\n        \"\"\"\n        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # Line up with x to [B, C, T]\n        beta = self.beta.unsqueeze(0).unsqueeze(-1)\n        if self.alpha_logscale:\n            alpha = torch.exp(alpha)\n            beta = torch.exp(beta)\n        x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)\n\n        return x\n"
  },
  {
    "path": "modules/audio_detokenizer/vocoder/alias_free_activation/__init__.py",
    "content": ""
  },
  {
    "path": "modules/audio_detokenizer/vocoder/alias_free_activation/cuda/__init__.py",
    "content": ""
  },
  {
    "path": "modules/audio_detokenizer/vocoder/alias_free_activation/cuda/activation1d.py",
    "content": "# Copyright (c) 2024 NVIDIA CORPORATION.\n#   Licensed under the MIT license.\n\nimport torch\nimport torch.nn as nn\nfrom ..torch.resample import UpSample1d, DownSample1d\n\n# load fused CUDA kernel: this enables importing anti_alias_activation_cuda\nfrom modules.audio_detokenizer.vocoder.alias_free_activation.cuda import load\n\nanti_alias_activation_cuda = load.load()\n\n\nclass FusedAntiAliasActivation(torch.autograd.Function):\n    \"\"\"\n    Assumes filter size 12, replication padding on upsampling/downsampling, and logscale alpha/beta parameters as inputs.\n    The hyperparameters are hard-coded in the kernel to maximize speed.\n    NOTE: The fused kenrel is incorrect for Activation1d with different hyperparameters.\n    \"\"\"\n\n    @staticmethod\n    def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta):\n        activation_results = anti_alias_activation_cuda.forward(\n            inputs, up_ftr, down_ftr, alpha, beta\n        )\n\n        return activation_results\n\n    @staticmethod\n    def backward(ctx, output_grads):\n        raise NotImplementedError\n        return output_grads, None, None\n\n\nclass Activation1d(nn.Module):\n    def __init__(\n        self,\n        activation,\n        up_ratio: int = 2,\n        down_ratio: int = 2,\n        up_kernel_size: int = 12,\n        down_kernel_size: int = 12,\n        fused: bool = True,\n    ):\n        super().__init__()\n        self.up_ratio = up_ratio\n        self.down_ratio = down_ratio\n        self.act = activation\n        self.upsample = UpSample1d(up_ratio, up_kernel_size)\n        self.downsample = DownSample1d(down_ratio, down_kernel_size)\n\n        self.fused = fused  # Whether to use fused CUDA kernel or not\n\n    def forward(self, x):\n        if not self.fused:\n            x = self.upsample(x)\n            x = self.act(x)\n            x = self.downsample(x)\n            return x\n        else:\n            if self.act.__class__.__name__ == \"Snake\":\n                beta = self.act.alpha.data  # Snake uses same params for alpha and beta\n            else:\n                beta = (\n                    self.act.beta.data\n                )  # Snakebeta uses different params for alpha and beta\n            alpha = self.act.alpha.data\n            if (\n                not self.act.alpha_logscale\n            ):  # Exp baked into cuda kernel, cancel it out with a log\n                alpha = torch.log(alpha)\n                beta = torch.log(beta)\n\n            x = FusedAntiAliasActivation.apply(\n                x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta\n            )\n            return x\n"
  },
  {
    "path": "modules/audio_detokenizer/vocoder/alias_free_activation/cuda/anti_alias_activation.cpp",
    "content": "/* coding=utf-8\n * Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n #include <torch/extension.h>\n\nextern \"C\" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta);\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n    m.def(\"forward\", &fwd_cuda, \"Anti-Alias Activation forward (CUDA)\");\n}"
  },
  {
    "path": "modules/audio_detokenizer/vocoder/alias_free_activation/cuda/anti_alias_activation_cuda.cu",
    "content": "/* coding=utf-8\n * Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include <ATen/ATen.h>\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <cuda_fp16.h>\n#include <cuda_profiler_api.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <torch/extension.h>\n#include \"type_shim.h\"\n#include <assert.h>\n#include <cfloat>\n#include <limits>\n#include <stdint.h>\n#include <c10/macros/Macros.h>\n\nnamespace\n{\n    // Hard-coded hyperparameters\n    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and\n    constexpr int ELEMENTS_PER_LDG_STG = 1; //(WARP_ITERATIONS < 4) ? 1 : 4;\n    constexpr int BUFFER_SIZE = 32;\n    constexpr int FILTER_SIZE = 12;\n    constexpr int HALF_FILTER_SIZE = 6;\n    constexpr int UPSAMPLE_REPLICATION_PAD = 5; // 5 on each side, matching torch impl\n    constexpr int DOWNSAMPLE_REPLICATION_PAD_LEFT = 5; // matching torch impl\n    constexpr int DOWNSAMPLE_REPLICATION_PAD_RIGHT = 6; // matching torch impl\n\n    template <typename input_t, typename output_t, typename acc_t>\n    __global__ void anti_alias_activation_forward(\n        output_t *dst,\n        const input_t *src,\n        const input_t *up_ftr,\n        const input_t *down_ftr,\n        const input_t *alpha,\n        const input_t *beta,\n        int batch_size,\n        int channels,\n        int seq_len)\n    {\n        // Up and downsample filters\n        input_t up_filter[FILTER_SIZE];\n        input_t down_filter[FILTER_SIZE];\n\n        // Load data from global memory including extra indices reserved for replication paddings\n        input_t elements[2 * FILTER_SIZE + 2 * BUFFER_SIZE + 2 * UPSAMPLE_REPLICATION_PAD] = {0};\n        input_t intermediates[2 * FILTER_SIZE + 2 * BUFFER_SIZE + DOWNSAMPLE_REPLICATION_PAD_LEFT + DOWNSAMPLE_REPLICATION_PAD_RIGHT] = {0};\n\n        // Output stores downsampled output before writing to dst\n        output_t output[BUFFER_SIZE];\n\n        // blockDim/threadIdx = (128, 1, 1)\n        // gridDim/blockIdx = (seq_blocks, channels, batches)\n        int block_offset = (blockIdx.x * 128 * BUFFER_SIZE + seq_len * (blockIdx.y + gridDim.y * blockIdx.z));\n        int local_offset = threadIdx.x * BUFFER_SIZE;\n        int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset;\n\n        // intermediate have double the seq_len\n        int intermediate_local_offset = threadIdx.x * BUFFER_SIZE * 2;\n        int intermediate_seq_offset = blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_local_offset;\n\n        // Get values needed for replication padding before moving pointer\n        const input_t *right_most_pntr = src + (seq_len * (blockIdx.y + gridDim.y * blockIdx.z));\n        input_t seq_left_most_value = right_most_pntr[0];\n        input_t seq_right_most_value = right_most_pntr[seq_len - 1];\n\n        // Move src and dst pointers\n        src += block_offset + local_offset;\n        dst += block_offset + local_offset;\n\n      
  // Alpha and beta values for snake activations. Applies exp by default\n        alpha = alpha + blockIdx.y;\n        input_t alpha_val = expf(alpha[0]);\n        beta = beta + blockIdx.y;\n        input_t beta_val = expf(beta[0]);\n\n        #pragma unroll\n        for (int it = 0; it < FILTER_SIZE; it += 1)\n        {\n            up_filter[it] = up_ftr[it];\n            down_filter[it] = down_ftr[it];\n        }\n\n        // Apply replication padding for upsampling, matching torch impl\n        #pragma unroll\n        for (int it = -HALF_FILTER_SIZE; it < BUFFER_SIZE + HALF_FILTER_SIZE; it += 1)\n        {\n            int element_index = seq_offset + it; // index for element\n            if ((element_index < 0) && (element_index >= -UPSAMPLE_REPLICATION_PAD))\n            {\n                elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_left_most_value;\n            }\n            if ((element_index >= seq_len) && (element_index < seq_len + UPSAMPLE_REPLICATION_PAD))\n            {\n                elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_right_most_value;\n            }\n            if ((element_index >= 0) && (element_index < seq_len))\n            {\n                elements[2 * (HALF_FILTER_SIZE + it)] = 2 * src[it];\n            }\n        }\n\n        // Apply upsampling strided convolution and write to intermediates. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT for replication padding of the downsampling conv later\n        #pragma unroll\n        for (int it = 0; it < (2 * BUFFER_SIZE + 2 * FILTER_SIZE); it += 1)\n        {\n            input_t acc = 0.0;\n            int element_index = intermediate_seq_offset + it; // index for intermediate\n            #pragma unroll\n            for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)\n            {\n                if ((element_index + f_idx) >= 0)\n                {\n                    acc += up_filter[f_idx] * elements[it + f_idx];\n                }\n            }\n            intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] = acc;\n        }\n\n        // Apply activation function. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT and DOWNSAMPLE_REPLICATION_PAD_RIGHT for replication padding of the downsampling conv later\n        double no_div_by_zero = 0.000000001;\n        #pragma unroll\n        for (int it = 0; it < 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it += 1)\n        {\n            intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] += (1.0 / (beta_val + no_div_by_zero)) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val);\n        }\n\n        // Apply replication padding before downsampling conv from intermediates\n        #pragma unroll\n        for (int it = 0; it < DOWNSAMPLE_REPLICATION_PAD_LEFT; it += 1)\n        {\n            intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT];\n        }\n        #pragma unroll\n        for (int it = DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it < DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE + DOWNSAMPLE_REPLICATION_PAD_RIGHT; it += 1)\n        {\n            intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE - 1];\n        }\n\n        // Apply downsample strided convolution (assuming stride=2) from intermediates\n        #pragma unroll\n        for (int it = 0; it < BUFFER_SIZE; it += 1)\n        {\n            input_t acc = 0.0;\n            #pragma unroll\n            for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)\n            {\n                // Add constant DOWNSAMPLE_REPLICATION_PAD_RIGHT to match torch implementation\n                acc += down_filter[f_idx] * intermediates[it * 2 + f_idx + DOWNSAMPLE_REPLICATION_PAD_RIGHT];\n            }\n            output[it] = acc;\n        }\n\n        // Write output to dst\n        #pragma unroll\n        for (int it = 0;  it < BUFFER_SIZE;  it += ELEMENTS_PER_LDG_STG)\n        {\n            int element_index = seq_offset + it;\n            if (element_index < seq_len)\n            {\n                dst[it] = output[it];\n            }\n        }\n\n    }\n\n    template <typename input_t, typename output_t, typename acc_t>\n    void dispatch_anti_alias_activation_forward(\n        output_t *dst,\n        const input_t *src,\n        const input_t *up_ftr,\n        const input_t *down_ftr,\n        const input_t *alpha,\n        const input_t *beta,\n        int batch_size,\n        int channels,\n        int seq_len)\n    {\n        if (seq_len == 0)\n        {\n            return;\n        }\n        else\n        {\n            // Use 128 threads per block to maximize GPU utilization\n            constexpr int threads_per_block = 128;\n            constexpr int seq_len_per_block = 4096;\n            int blocks_per_seq_len = (seq_len + seq_len_per_block - 1) / seq_len_per_block;\n            dim3 blocks(blocks_per_seq_len, channels, batch_size);\n            dim3 threads(threads_per_block, 1, 1);\n\n            anti_alias_activation_forward<input_t, output_t, acc_t>\n                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, up_ftr, down_ftr, alpha, beta, batch_size, channels, seq_len);\n        }\n    }\n}\n\nextern \"C\" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta)\n{\n    // Input is a 3d tensor with dimensions [batches, channels, seq_len]\n    const int batches = input.size(0);\n    const int channels = 
input.size(1);\n    const int seq_len = input.size(2);\n\n    // Output\n    auto act_options = input.options().requires_grad(false);\n\n    torch::Tensor anti_alias_activation_results =\n        torch::empty({batches, channels, seq_len}, act_options);\n\n    void *input_ptr = static_cast<void *>(input.data_ptr());\n    void *up_filter_ptr = static_cast<void *>(up_filter.data_ptr());\n    void *down_filter_ptr = static_cast<void *>(down_filter.data_ptr());\n    void *alpha_ptr = static_cast<void *>(alpha.data_ptr());\n    void *beta_ptr = static_cast<void *>(beta.data_ptr());\n    void *anti_alias_activation_results_ptr = static_cast<void *>(anti_alias_activation_results.data_ptr());\n\n    DISPATCH_FLOAT_HALF_AND_BFLOAT(\n        input.scalar_type(),\n        \"dispatch anti alias activation_forward\",\n        dispatch_anti_alias_activation_forward<scalar_t, scalar_t, float>(\n            reinterpret_cast<scalar_t *>(anti_alias_activation_results_ptr),\n            reinterpret_cast<const scalar_t *>(input_ptr),\n            reinterpret_cast<const scalar_t *>(up_filter_ptr),\n            reinterpret_cast<const scalar_t *>(down_filter_ptr),\n            reinterpret_cast<const scalar_t *>(alpha_ptr),\n            reinterpret_cast<const scalar_t *>(beta_ptr),\n            batches,\n            channels,\n            seq_len););\n    return anti_alias_activation_results;\n}"
  },
  {
    "path": "modules/audio_detokenizer/vocoder/alias_free_activation/cuda/compat.h",
    "content": "/* coding=utf-8\n * Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n/*This code is copied fron NVIDIA apex:\n *     https://github.com/NVIDIA/apex\n *     with minor changes. */\n\n#ifndef TORCH_CHECK\n#define TORCH_CHECK AT_CHECK\n#endif\n\n#ifdef VERSION_GE_1_3\n#define DATA_PTR data_ptr\n#else\n#define DATA_PTR data\n#endif\n"
  },
  {
    "path": "modules/audio_detokenizer/vocoder/alias_free_activation/cuda/load.py",
    "content": "# Copyright (c) 2024 NVIDIA CORPORATION.\n#   Licensed under the MIT license.\n\nimport os\nimport pathlib\nimport subprocess\n\nfrom torch.utils import cpp_extension\n\n\"\"\"\nSetting this param to a list has a problem of generating different compilation commands (with diferent order of architectures) and leading to recompilation of fused kernels. \nSet it to empty stringo avoid recompilation and assign arch flags explicity in extra_cuda_cflags below\n\"\"\"\nos.environ[\"TORCH_CUDA_ARCH_LIST\"] = \"\"\n\n\ndef load():\n    # Check if cuda 11 is installed for compute capability 8.0\n    cc_flag = []\n    _, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)\n    if int(bare_metal_major) >= 11:\n        cc_flag.append(\"-gencode\")\n        cc_flag.append(\"arch=compute_80,code=sm_80\")\n\n    # Build path\n    srcpath = pathlib.Path(__file__).parent.absolute()\n    buildpath = srcpath / \"build\"\n    _create_build_dir(buildpath)\n\n    # Helper function to build the kernels.\n    def _cpp_extention_load_helper(name, sources, extra_cuda_flags):\n        return cpp_extension.load(\n            name=name,\n            sources=sources,\n            build_directory=buildpath,\n            extra_cflags=[\n                \"-O3\",\n            ],\n            extra_cuda_cflags=[\n                \"-O3\",\n                \"-gencode\",\n                \"arch=compute_70,code=sm_70\",\n                \"--use_fast_math\",\n            ]\n            + extra_cuda_flags\n            + cc_flag,\n            verbose=True,\n        )\n\n    extra_cuda_flags = [\n        \"-U__CUDA_NO_HALF_OPERATORS__\",\n        \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n        \"--expt-relaxed-constexpr\",\n        \"--expt-extended-lambda\",\n    ]\n\n    sources = [\n        srcpath / \"anti_alias_activation.cpp\",\n        srcpath / \"anti_alias_activation_cuda.cu\",\n    ]\n    anti_alias_activation_cuda = _cpp_extention_load_helper(\n        \"anti_alias_activation_cuda\", sources, extra_cuda_flags\n    )\n\n    return anti_alias_activation_cuda\n\n\ndef _get_cuda_bare_metal_version(cuda_dir):\n    raw_output = subprocess.check_output(\n        [cuda_dir + \"/bin/nvcc\", \"-V\"], universal_newlines=True\n    )\n    output = raw_output.split()\n    release_idx = output.index(\"release\") + 1\n    release = output[release_idx].split(\".\")\n    bare_metal_major = release[0]\n    bare_metal_minor = release[1][0]\n\n    return raw_output, bare_metal_major, bare_metal_minor\n\n\ndef _create_build_dir(buildpath):\n    try:\n        os.mkdir(buildpath)\n    except OSError:\n        if not os.path.isdir(buildpath):\n            print(f\"Creation of the build directory {buildpath} failed\")\n"
  },
  {
    "path": "modules/audio_detokenizer/vocoder/alias_free_activation/cuda/type_shim.h",
    "content": "/* coding=utf-8\n * Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include <ATen/ATen.h>\n#include \"compat.h\"\n\n#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...)                 \\\n\tswitch (TYPE)                                                       \\\n\t{                                                                   \\\n\tcase at::ScalarType::Float:                                         \\\n\t{                                                                   \\\n\t\tusing scalar_t = float;                                         \\\n\t\t__VA_ARGS__;                                                    \\\n\t\tbreak;                                                          \\\n\t}                                                                   \\\n\tcase at::ScalarType::Half:                                          \\\n\t{                                                                   \\\n\t\tusing scalar_t = at::Half;                                      \\\n\t\t__VA_ARGS__;                                                    \\\n\t\tbreak;                                                          \\\n\t}                                                                   \\\n\tcase at::ScalarType::BFloat16:                                      \\\n\t{                                                                   \\\n\t\tusing scalar_t = at::BFloat16;                                  \\\n\t\t__VA_ARGS__;                                                    \\\n\t\tbreak;                                                          \\\n\t}                                                                   \\\n\tdefault:                                                            \\\n\t\tAT_ERROR(#NAME, \" not implemented for '\", toString(TYPE), \"'\"); \\\n\t}\n\n#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) 
\\\n\tswitch (TYPEIN)                                                            \\\n\t{                                                                          \\\n\tcase at::ScalarType::Float:                                                \\\n\t{                                                                          \\\n\t\tusing scalar_t_in = float;                                             \\\n\t\tswitch (TYPEOUT)                                                       \\\n\t\t{                                                                      \\\n\t\tcase at::ScalarType::Float:                                            \\\n\t\t{                                                                      \\\n\t\t\tusing scalar_t_out = float;                                        \\\n\t\t\t__VA_ARGS__;                                                       \\\n\t\t\tbreak;                                                             \\\n\t\t}                                                                      \\\n\t\tcase at::ScalarType::Half:                                             \\\n\t\t{                                                                      \\\n\t\t\tusing scalar_t_out = at::Half;                                     \\\n\t\t\t__VA_ARGS__;                                                       \\\n\t\t\tbreak;                                                             \\\n\t\t}                                                                      \\\n\t\tcase at::ScalarType::BFloat16:                                         \\\n\t\t{                                                                      \\\n\t\t\tusing scalar_t_out = at::BFloat16;                                 \\\n\t\t\t__VA_ARGS__;                                                       \\\n\t\t\tbreak;                                                             \\\n\t\t}                                                                      \\\n\t\tdefault:                                                               \\\n\t\t\tAT_ERROR(#NAME, \" not implemented for '\", toString(TYPEOUT), \"'\"); \\\n\t\t}                                                                      \\\n\t\tbreak;                                                                 \\\n\t}                                                                          \\\n\tcase at::ScalarType::Half:                                                 \\\n\t{                                                                          \\\n\t\tusing scalar_t_in = at::Half;                                          \\\n\t\tusing scalar_t_out = at::Half;                                         \\\n\t\t__VA_ARGS__;                                                           \\\n\t\tbreak;                                                                 \\\n\t}                                                                          \\\n\tcase at::ScalarType::BFloat16:                                             \\\n\t{                                                                          \\\n\t\tusing scalar_t_in = at::BFloat16;                                      \\\n\t\tusing scalar_t_out = at::BFloat16;                                     \\\n\t\t__VA_ARGS__;                                                           \\\n\t\tbreak;                                                                 \\\n\t}                                                                          \\\n\tdefault:                                                           
        \\\n\t\tAT_ERROR(#NAME, \" not implemented for '\", toString(TYPEIN), \"'\");      \\\n\t}\n"
  },
  {
    "path": "modules/audio_detokenizer/vocoder/alias_free_activation/torch/__init__.py",
    "content": "# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0\n#   LICENSE is in incl_licenses directory.\n\nfrom .filter import *\nfrom .resample import *\nfrom .act import *\n"
  },
  {
    "path": "modules/audio_detokenizer/vocoder/alias_free_activation/torch/act.py",
    "content": "# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0\n#   LICENSE is in incl_licenses directory.\n\nimport torch.nn as nn\nfrom .resample import UpSample1d, DownSample1d\n\n\nclass Activation1d(nn.Module):\n    def __init__(\n        self,\n        activation,\n        up_ratio: int = 2,\n        down_ratio: int = 2,\n        up_kernel_size: int = 12,\n        down_kernel_size: int = 12,\n    ):\n        super().__init__()\n        self.up_ratio = up_ratio\n        self.down_ratio = down_ratio\n        self.act = activation\n        self.upsample = UpSample1d(up_ratio, up_kernel_size)\n        self.downsample = DownSample1d(down_ratio, down_kernel_size)\n\n    # x: [B,C,T]\n    def forward(self, x):\n        x = self.upsample(x)\n        x = self.act(x)\n        x = self.downsample(x)\n\n        return x\n"
  },
  {
    "path": "modules/audio_detokenizer/vocoder/alias_free_activation/torch/filter.py",
    "content": "# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0\n#   LICENSE is in incl_licenses directory.\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport math\n\nif \"sinc\" in dir(torch):\n    sinc = torch.sinc\nelse:\n    # This code is adopted from adefossez's julius.core.sinc under the MIT License\n    # https://adefossez.github.io/julius/julius/core.html\n    #   LICENSE is in incl_licenses directory.\n    def sinc(x: torch.Tensor):\n        \"\"\"\n        Implementation of sinc, i.e. sin(pi * x) / (pi * x)\n        __Warning__: Different to julius.sinc, the input is multiplied by `pi`!\n        \"\"\"\n        return torch.where(\n            x == 0,\n            torch.tensor(1.0, device=x.device, dtype=x.dtype),\n            torch.sin(math.pi * x) / math.pi / x,\n        )\n\n\n# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License\n# https://adefossez.github.io/julius/julius/lowpass.html\n#   LICENSE is in incl_licenses directory.\ndef kaiser_sinc_filter1d(\n    cutoff, half_width, kernel_size\n):  # return filter [1,1,kernel_size]\n    even = kernel_size % 2 == 0\n    half_size = kernel_size // 2\n\n    # For kaiser window\n    delta_f = 4 * half_width\n    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95\n    if A > 50.0:\n        beta = 0.1102 * (A - 8.7)\n    elif A >= 21.0:\n        beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)\n    else:\n        beta = 0.0\n    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)\n\n    # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio\n    if even:\n        time = torch.arange(-half_size, half_size) + 0.5\n    else:\n        time = torch.arange(kernel_size) - half_size\n    if cutoff == 0:\n        filter_ = torch.zeros_like(time)\n    else:\n        filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)\n        \"\"\"\n        Normalize filter to have sum = 1, otherwise we will have a small leakage of the constant component in the input signal.\n        \"\"\"\n        filter_ /= filter_.sum()\n        filter = filter_.view(1, 1, kernel_size)\n\n    return filter\n\n\nclass LowPassFilter1d(nn.Module):\n    def __init__(\n        self,\n        cutoff=0.5,\n        half_width=0.6,\n        stride: int = 1,\n        padding: bool = True,\n        padding_mode: str = \"replicate\",\n        kernel_size: int = 12,\n    ):\n        \"\"\"\n        kernel_size should be even number for stylegan3 setup, in this implementation, odd number is also possible.\n        \"\"\"\n        super().__init__()\n        if cutoff < -0.0:\n            raise ValueError(\"Minimum cutoff must be larger than zero.\")\n        if cutoff > 0.5:\n            raise ValueError(\"A cutoff above 0.5 does not make sense.\")\n        self.kernel_size = kernel_size\n        self.even = kernel_size % 2 == 0\n        self.pad_left = kernel_size // 2 - int(self.even)\n        self.pad_right = kernel_size // 2\n        self.stride = stride\n        self.padding = padding\n        self.padding_mode = padding_mode\n        filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)\n        self.register_buffer(\"filter\", filter)\n\n    # Input [B, C, T]\n    def forward(self, x):\n        _, C, _ = x.shape\n\n        if self.padding:\n            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)\n        out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)\n\n        return out\n"
  },
  {
    "path": "modules/audio_detokenizer/vocoder/alias_free_activation/torch/resample.py",
    "content": "# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0\n#   LICENSE is in incl_licenses directory.\n\nimport torch.nn as nn\nfrom torch.nn import functional as F\nfrom .filter import LowPassFilter1d\nfrom .filter import kaiser_sinc_filter1d\n\n\nclass UpSample1d(nn.Module):\n    def __init__(self, ratio=2, kernel_size=None):\n        super().__init__()\n        self.ratio = ratio\n        self.kernel_size = (\n            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size\n        )\n        self.stride = ratio\n        self.pad = self.kernel_size // ratio - 1\n        self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2\n        self.pad_right = (\n            self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2\n        )\n        filter = kaiser_sinc_filter1d(\n            cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size\n        )\n        self.register_buffer(\"filter\", filter)\n\n    # x: [B, C, T]\n    def forward(self, x):\n        _, C, _ = x.shape\n\n        x = F.pad(x, (self.pad, self.pad), mode=\"replicate\")\n        x = self.ratio * F.conv_transpose1d(\n            x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C\n        )\n        x = x[..., self.pad_left : -self.pad_right]\n\n        return x\n\n\nclass DownSample1d(nn.Module):\n    def __init__(self, ratio=2, kernel_size=None):\n        super().__init__()\n        self.ratio = ratio\n        self.kernel_size = (\n            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size\n        )\n        self.lowpass = LowPassFilter1d(\n            cutoff=0.5 / ratio,\n            half_width=0.6 / ratio,\n            stride=ratio,\n            kernel_size=self.kernel_size,\n        )\n\n    def forward(self, x):\n        xx = self.lowpass(x)\n\n        return xx\n"
  },
  {
    "path": "modules/audio_detokenizer/vocoder/bigvgan.py",
    "content": "# Copyright (c) 2024 NVIDIA CORPORATION.\n#   Licensed under the MIT license.\n\n# Adapted from https://github.com/jik876/hifi-gan under the MIT license.\n#   LICENSE is in incl_licenses directory.\n\nimport os\nimport json\nfrom pathlib import Path\nfrom typing import Optional, Union, Dict\n\nimport torch\nimport torch.nn as nn\nfrom torch.nn import Conv1d, ConvTranspose1d\nfrom torch.nn.utils import weight_norm, remove_weight_norm\n\nfrom modules.audio_detokenizer.vocoder.activations import Snake, SnakeBeta\nfrom modules.audio_detokenizer.vocoder.utils import init_weights, get_padding\nfrom modules.audio_detokenizer.vocoder.alias_free_activation.torch.act import Activation1d as TorchActivation1d\nfrom modules.audio_detokenizer.vocoder.utils import AttrDict\n\nfrom huggingface_hub import PyTorchModelHubMixin, hf_hub_download\n\n\ndef load_hparams_from_json(path) -> AttrDict:\n    with open(path) as f:\n        data = f.read()\n    return AttrDict(json.loads(data))\n\n\nclass AMPBlock1(torch.nn.Module):\n    \"\"\"\n    AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.\n    AMPBlock1 has additional self.convs2 that contains additional Conv1d layers with a fixed dilation=1 followed by each layer in self.convs1\n\n    Args:\n        h (AttrDict): Hyperparameters.\n        channels (int): Number of convolution channels.\n        kernel_size (int): Size of the convolution kernel. Default is 3.\n        dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).\n        activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.\n    \"\"\"\n\n    def __init__(\n        self,\n        h: AttrDict,\n        channels: int,\n        kernel_size: int = 3,\n        dilation: tuple = (1, 3, 5),\n        activation: str = None,\n    ):\n        super().__init__()\n        \n        self.h = h\n\n        self.convs1 = nn.ModuleList(\n            [\n                weight_norm(\n                    Conv1d(\n                        channels,\n                        channels,\n                        kernel_size,\n                        stride=1,\n                        dilation=d,\n                        padding=get_padding(kernel_size, d),\n                    )\n                )\n                for d in dilation\n            ]\n        )\n        self.convs1.apply(init_weights)\n\n        self.convs2 = nn.ModuleList(\n            [\n                weight_norm(\n                    Conv1d(\n                        channels,\n                        channels,\n                        kernel_size,\n                        stride=1,\n                        dilation=1,\n                        padding=get_padding(kernel_size, 1),\n                    )\n                )\n                for _ in range(len(dilation))\n            ]\n        )\n        self.convs2.apply(init_weights)\n\n        self.num_layers = len(self.convs1) + len(\n            self.convs2\n        )  # Total number of conv layers\n\n        # Select which Activation1d, lazy-load cuda version to ensure backward compatibility\n        if self.h.get(\"use_cuda_kernel\", False):\n            from modules.audio_detokenizer.vocoder.alias_free_activation.cuda.activation1d import (\n                Activation1d as CudaActivation1d,\n            )\n\n            Activation1d = CudaActivation1d\n        else:\n            Activation1d = 
TorchActivation1d\n\n        # Activation functions\n        if activation == \"snake\":\n            self.activations = nn.ModuleList(\n                [\n                    Activation1d(\n                        activation=Snake(\n                            channels, alpha_logscale=h.snake_logscale\n                        )\n                    )\n                    for _ in range(self.num_layers)\n                ]\n            )\n        elif activation == \"snakebeta\":\n            self.activations = nn.ModuleList(\n                [\n                    Activation1d(\n                        activation=SnakeBeta(\n                            channels, alpha_logscale=h.snake_logscale\n                        )\n                    )\n                    for _ in range(self.num_layers)\n                ]\n            )\n        else:\n            raise NotImplementedError(\n                \"activation incorrectly specified. check the config file and look for 'activation'.\"\n            )\n\n    def forward(self, x):\n        acts1, acts2 = self.activations[::2], self.activations[1::2]\n        for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):\n            xt = a1(x)\n            xt = c1(xt)\n            xt = a2(xt)\n            xt = c2(xt)\n            x = xt + x\n\n        return x\n\n    def remove_weight_norm(self):\n        for l in self.convs1:\n            remove_weight_norm(l)\n        for l in self.convs2:\n            remove_weight_norm(l)\n\n\nclass AMPBlock2(torch.nn.Module):\n    \"\"\"\n    AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.\n    Unlike AMPBlock1, AMPBlock2 does not contain extra Conv1d layers with fixed dilation=1\n\n    Args:\n        h (AttrDict): Hyperparameters.\n        channels (int): Number of convolution channels.\n        kernel_size (int): Size of the convolution kernel. Default is 3.\n        dilation (tuple): Dilation rates for the convolutions. Each dilation layer has one convolution. Default is (1, 3, 5).\n        activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.\n    \"\"\"\n\n    def __init__(\n        self,\n        h: AttrDict,\n        channels: int,\n        kernel_size: int = 3,\n        dilation: tuple = (1, 3, 5),\n        activation: str = None,\n    ):\n        super().__init__()\n\n        self.h = h\n\n        self.convs = nn.ModuleList(\n            [\n                weight_norm(\n                    Conv1d(\n                        channels,\n                        channels,\n                        kernel_size,\n                        stride=1,\n                        dilation=d,\n                        padding=get_padding(kernel_size, d),\n                    )\n                )\n                for d in dilation\n            ]\n        )\n        self.convs.apply(init_weights)\n\n        self.num_layers = len(self.convs)  # Total number of conv layers\n\n        # Select which Activation1d, lazy-load cuda version to ensure backward compatibility\n        if self.h.get(\"use_cuda_kernel\", False):\n            from modules.audio_detokenizer.vocoder.alias_free_activation.cuda.activation1d import (\n                Activation1d as CudaActivation1d,\n            )\n\n            Activation1d = CudaActivation1d\n        else:\n            Activation1d = TorchActivation1d\n\n        # Activation functions\n        if activation == \"snake\":\n            self.activations = nn.ModuleList(\n                [\n                    Activation1d(\n                        activation=Snake(\n                            channels, alpha_logscale=h.snake_logscale\n                        )\n                    )\n                    for _ in range(self.num_layers)\n                ]\n            )\n        elif activation == \"snakebeta\":\n            self.activations = nn.ModuleList(\n                [\n                    Activation1d(\n                        activation=SnakeBeta(\n                            channels, alpha_logscale=h.snake_logscale\n                        )\n                    )\n                    for _ in range(self.num_layers)\n                ]\n            )\n        else:\n            raise NotImplementedError(\n                \"activation incorrectly specified. check the config file and look for 'activation'.\"\n            )\n\n    def forward(self, x):\n        for c, a in zip(self.convs, self.activations):\n            xt = a(x)\n            xt = c(xt)\n            x = xt + x\n\n        return x\n\n    def remove_weight_norm(self):\n        for l in self.convs:\n            remove_weight_norm(l)\n\n\nclass BigVGAN(\n    torch.nn.Module,\n    PyTorchModelHubMixin,\n    library_name=\"bigvgan\",\n    repo_url=\"https://github.com/NVIDIA/BigVGAN\",\n    docs_url=\"https://github.com/NVIDIA/BigVGAN/blob/main/README.md\",\n    pipeline_tag=\"audio-to-audio\",\n    license=\"mit\",\n    tags=[\"neural-vocoder\", \"audio-generation\", \"arxiv:2206.04658\"],\n):\n    \"\"\"\n    BigVGAN is a neural vocoder model that applies anti-aliased periodic activation for residual blocks (resblocks).\n    New in BigVGAN-v2: it can optionally use optimized CUDA kernels for AMP (anti-aliased multi-periodicity) blocks.\n\n    Args:\n        h (AttrDict): Hyperparameters.\n        use_cuda_kernel (bool): If set to True, loads optimized CUDA kernels for AMP. 
This should be used for inference only, as training is not supported with CUDA kernels.\n\n    Note:\n        - The `use_cuda_kernel` parameter should be used for inference only, as training with CUDA kernels is not supported.\n        - Ensure that the activation function is correctly specified in the hyperparameters (h.activation).\n    \"\"\"\n\n    def __init__(self, h: AttrDict, use_cuda_kernel: bool = False):\n        super().__init__()\n        self.h = h\n        self.h[\"use_cuda_kernel\"] = use_cuda_kernel\n\n        # Select which Activation1d, lazy-load cuda version to ensure backward compatibility\n        if self.h.get(\"use_cuda_kernel\", False):\n            from modules.audio_detokenizer.vocoder.alias_free_activation.cuda.activation1d import (\n                Activation1d as CudaActivation1d,\n            )\n\n            Activation1d = CudaActivation1d\n        else:\n            Activation1d = TorchActivation1d\n\n        self.num_kernels = len(h.resblock_kernel_sizes)\n        self.num_upsamples = len(h.upsample_rates)\n\n        # Pre-conv\n        self.conv_pre = weight_norm(\n            Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)\n        )\n\n        # Define which AMPBlock to use. BigVGAN uses AMPBlock1 as default\n        if h.resblock == \"1\":\n            resblock_class = AMPBlock1\n        elif h.resblock == \"2\":\n            resblock_class = AMPBlock2\n        else:\n            raise ValueError(\n                f\"Incorrect resblock class specified in hyperparameters. Got {h.resblock}\"\n            )\n\n        # Transposed conv-based upsamplers. does not apply anti-aliasing\n        self.ups = nn.ModuleList()\n        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):\n            self.ups.append(\n                nn.ModuleList(\n                    [\n                        weight_norm(\n                            ConvTranspose1d(\n                                h.upsample_initial_channel // (2**i),\n                                h.upsample_initial_channel // (2 ** (i + 1)),\n                                k,\n                                u,\n                                padding=(k - u) // 2,\n                            )\n                        )\n                    ]\n                )\n            )\n\n        # Residual blocks using anti-aliased multi-periodicity composition modules (AMP)\n        self.resblocks = nn.ModuleList()\n        for i in range(len(self.ups)):\n            ch = h.upsample_initial_channel // (2 ** (i + 1))\n            for j, (k, d) in enumerate(\n                zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)\n            ):\n                self.resblocks.append(\n                    resblock_class(h, ch, k, d, activation=h.activation)\n                )\n\n        # Post-conv\n        activation_post = (\n            Snake(ch, alpha_logscale=h.snake_logscale)\n            if h.activation == \"snake\"\n            else (\n                SnakeBeta(ch, alpha_logscale=h.snake_logscale)\n                if h.activation == \"snakebeta\"\n                else None\n            )\n        )\n        if activation_post is None:\n            raise NotImplementedError(\n                \"activation incorrectly specified. check the config file and look for 'activation'.\"\n            )\n\n        self.activation_post = Activation1d(activation=activation_post)\n\n        # Whether to use bias for the final conv_post. 
Default to True for backward compatibility\n        self.use_bias_at_final = h.get(\"use_bias_at_final\", True)\n        self.conv_post = weight_norm(\n            Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final)\n        )\n\n        # Weight initialization\n        for i in range(len(self.ups)):\n            self.ups[i].apply(init_weights)\n        self.conv_post.apply(init_weights)\n\n        # Final tanh activation. Defaults to True for backward compatibility\n        self.use_tanh_at_final = h.get(\"use_tanh_at_final\", True)\n\n    def forward(self, x):\n        # Pre-conv\n        x = self.conv_pre(x)\n\n        for i in range(self.num_upsamples):\n            # Upsampling\n            for i_up in range(len(self.ups[i])):\n                x = self.ups[i][i_up](x)\n            # AMP blocks\n            xs = None\n            for j in range(self.num_kernels):\n                if xs is None:\n                    xs = self.resblocks[i * self.num_kernels + j](x)\n                else:\n                    xs += self.resblocks[i * self.num_kernels + j](x)\n            x = xs / self.num_kernels\n\n        # Post-conv\n        x = self.activation_post(x)\n        x = self.conv_post(x)\n        # Final tanh activation\n        if self.use_tanh_at_final:\n            x = torch.tanh(x)\n        else:\n            x = torch.clamp(x, min=-1.0, max=1.0)  # Bound the output to [-1, 1]\n\n        return x\n\n    def remove_weight_norm(self):\n        try:\n            print(\"Removing weight norm...\")\n            for l in self.ups:\n                for l_i in l:\n                    remove_weight_norm(l_i)\n            for l in self.resblocks:\n                l.remove_weight_norm()\n            remove_weight_norm(self.conv_pre)\n            remove_weight_norm(self.conv_post)\n        except ValueError:\n            print(\"[INFO] Model already removed weight norm. 
Skipping!\")\n            pass\n\n    # Additional methods for huggingface_hub support\n    def _save_pretrained(self, save_directory: Path) -> None:\n        \"\"\"Save weights and config.json from a Pytorch model to a local directory.\"\"\"\n\n        model_path = save_directory / \"bigvgan_generator.pt\"\n        torch.save({\"generator\": self.state_dict()}, model_path)\n\n        config_path = save_directory / \"config.json\"\n        with open(config_path, \"w\") as config_file:\n            json.dump(self.h, config_file, indent=4)\n\n    @classmethod\n    def _from_pretrained(\n        cls,\n        *,\n        model_id: str,\n        revision: str,\n        cache_dir: str,\n        force_download: bool,\n        proxies: Optional[Dict],\n        resume_download: bool,\n        local_files_only: bool,\n        token: Union[str, bool, None],\n        map_location: str = \"cpu\",  # Additional argument\n        strict: bool = False,  # Additional argument\n        use_cuda_kernel: bool = False,\n        **model_kwargs,\n    ):\n        \"\"\"Load Pytorch pretrained weights and return the loaded model.\"\"\"\n\n        # Download and load hyperparameters (h) used by BigVGAN\n        if os.path.isdir(model_id):\n            print(\"Loading config.json from local directory\")\n            config_file = os.path.join(model_id, \"config.json\")\n        else:\n            config_file = hf_hub_download(\n                repo_id=model_id,\n                filename=\"config.json\",\n                revision=revision,\n                cache_dir=cache_dir,\n                force_download=force_download,\n                proxies=proxies,\n                resume_download=resume_download,\n                token=token,\n                local_files_only=local_files_only,\n            )\n        h = load_hparams_from_json(config_file)\n\n        # instantiate BigVGAN using h\n        if use_cuda_kernel:\n            print(\n                f\"[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!\"\n            )\n            print(\n                f\"[WARNING] You need nvcc and ninja installed in your system that matches your PyTorch build is using to build the kernel. 
If not, the model will fail to initialize or generate incorrect waveform!\"\n            )\n            print(\n                f\"[WARNING] For detail, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis\"\n            )\n        model = cls(h, use_cuda_kernel=use_cuda_kernel)\n\n        # Download and load pretrained generator weight\n        if os.path.isdir(model_id):\n            print(\"Loading weights from local directory\")\n            model_file = os.path.join(model_id, \"bigvgan_generator.pt\")\n        else:\n            print(f\"Loading weights from {model_id}\")\n            model_file = hf_hub_download(\n                repo_id=model_id,\n                filename=\"bigvgan_generator.pt\",\n                revision=revision,\n                cache_dir=cache_dir,\n                force_download=force_download,\n                proxies=proxies,\n                resume_download=resume_download,\n                token=token,\n                local_files_only=local_files_only,\n            )\n\n        checkpoint_dict = torch.load(model_file, map_location=map_location, weights_only=True)\n\n        try:\n            model.load_state_dict(checkpoint_dict[\"generator\"])\n        except RuntimeError:\n            print(\n                f\"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!\"\n            )\n            model.remove_weight_norm()\n            model.load_state_dict(checkpoint_dict[\"generator\"])\n\n        return model\n"
  },
  {
    "path": "modules/audio_detokenizer/vocoder/utils.py",
    "content": "from librosa.filters import mel as librosa_mel_fn\nimport torch\nimport os\nmel_basis_cache = {}\nhann_window_cache = {}\n\ndef dynamic_range_compression_torch(x, C=1, clip_val=1e-5):\n    return torch.log(torch.clamp(x, min=clip_val) * C)\n\n\ndef spectral_normalize_torch(magnitudes):\n    return dynamic_range_compression_torch(magnitudes)\n\ndef get_melspec(\n    y: torch.Tensor,\n    n_fft: int,\n    num_mels: int,\n    sampling_rate: int,\n    hop_size: int,\n    win_size: int,\n    fmin: int,\n    fmax: int = None,\n    center: bool = False,\n) -> torch.Tensor:\n    \"\"\"\n    Calculate the mel spectrogram of an input signal.\n    This function uses slaney norm for the librosa mel filterbank (using librosa.filters.mel) and uses Hann window for STFT (using torch.stft).\n\n    Args:\n        y (torch.Tensor): Input signal.\n        n_fft (int): FFT size.\n        num_mels (int): Number of mel bins.\n        sampling_rate (int): Sampling rate of the input signal.\n        hop_size (int): Hop size for STFT.\n        win_size (int): Window size for STFT.\n        fmin (int): Minimum frequency for mel filterbank.\n        fmax (int): Maximum frequency for mel filterbank. If None, defaults to half the sampling rate (fmax = sr / 2.0) inside librosa_mel_fn\n        center (bool): Whether to pad the input to center the frames. Default is False.\n\n    Returns:\n        torch.Tensor: Mel spectrogram.\n    \"\"\"\n    if torch.min(y) < -1.0:\n        print(f\"[WARNING] Min value of input waveform signal is {torch.min(y)}\")\n    if torch.max(y) > 1.0:\n        print(f\"[WARNING] Max value of input waveform signal is {torch.max(y)}\")\n\n    device = y.device\n    key = f\"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}\"\n\n    if key not in mel_basis_cache:\n        mel = librosa_mel_fn(\n            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax\n        )\n        mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)\n        hann_window_cache[key] = torch.hann_window(win_size).to(device)\n\n    mel_basis = mel_basis_cache[key]\n    hann_window = hann_window_cache[key]\n\n    padding = (n_fft - hop_size) // 2\n    y = torch.nn.functional.pad(\n        y.unsqueeze(1), (padding, padding), mode=\"reflect\"\n    ).squeeze(1)\n\n    spec = torch.stft(\n        y,\n        n_fft,\n        hop_length=hop_size,\n        win_length=win_size,\n        window=hann_window,\n        center=center,\n        pad_mode=\"reflect\",\n        normalized=False,\n        onesided=True,\n        return_complex=True,\n    )\n    spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)\n\n    mel_spec = torch.matmul(mel_basis, spec)\n    mel_spec = spectral_normalize_torch(mel_spec)\n\n    return mel_spec\n\n\nclass AttrDict(dict):\n    def __init__(self, *args, **kwargs):\n        super(AttrDict, self).__init__(*args, **kwargs)\n        self.__dict__ = self\n\ndef load_checkpoint(filepath, device):\n    assert os.path.isfile(filepath)\n    print(f\"Loading '{filepath}'\")\n    checkpoint_dict = torch.load(filepath, map_location=device, weights_only=True)\n    print(\"Complete.\")\n    return checkpoint_dict\n\ndef init_weights(m, mean=0.0, std=0.01):\n    classname = m.__class__.__name__\n    if classname.find(\"Conv\") != -1:\n        m.weight.data.normal_(mean, std)\n\n\ndef get_padding(kernel_size, dilation=1):\n    return int((kernel_size * dilation - dilation) / 2)"
  },
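  {
    "path": "examples/demo_get_melspec.py",
    "content": "# Hypothetical usage sketch for modules/audio_detokenizer/vocoder/utils.get_melspec.\n# This file and the 24 kHz / 100-mel settings below are assumptions for illustration,\n# not values taken from the repo's configs.\nimport torch\n\nfrom modules.audio_detokenizer.vocoder.utils import get_melspec\n\nif __name__ == \"__main__\":\n    sr = 24000\n    # one second of a 440 Hz sine, batched as (B, N)\n    t = torch.arange(sr, dtype=torch.float32) / sr\n    wav = 0.5 * torch.sin(2 * torch.pi * 440.0 * t).unsqueeze(0)\n\n    mel = get_melspec(\n        wav,\n        n_fft=1024,\n        num_mels=100,\n        sampling_rate=sr,\n        hop_size=256,\n        win_size=1024,\n        fmin=0,\n        fmax=None,  # defaults to sr / 2 inside librosa_mel_fn\n    )\n    print(mel.shape)  # (1, 100, n_frames)\n"
  },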
  {
    "path": "modules/audio_tokenizer/audio_tokenizer.py",
    "content": "import torch\nimport librosa\nimport yaml\nfrom transformers import Wav2Vec2BertModel, SeamlessM4TFeatureExtractor\nimport safetensors\nimport accelerate\nimport soundfile as sf\nimport math\nfrom einops import rearrange\nfrom modules.audio_tokenizer.rep_codec import RepCodec\n\n\nclass AudioTokenizer(object):\n    def __init__(self, **kwargs):\n        self.device = kwargs.pop('device')\n        print(self.device)\n        # tokenize\n        feat_stats = kwargs.pop('feat_stats')\n        feat_stats = torch.load(feat_stats, map_location='cpu')\n        self.feat_mean = feat_stats['mean']\n        self.feat_std = torch.sqrt(feat_stats['var'])\n        wav2vec_ckpt = kwargs.pop(\"wav2vec_ckpt\")\n        self.semantic_model = Wav2Vec2BertModel.from_pretrained(wav2vec_ckpt)\n        self.semantic_model.eval()\n        self.semantic_model.to(self.device)\n        self.semantic_processor = SeamlessM4TFeatureExtractor.from_pretrained(\"facebook/w2v-bert-2.0\")\n\n        self.semantic_codec = RepCodec()\n        self.semantic_codec.eval()\n        pretrained_path = kwargs.pop(\"semantic_codec_ckpt\") \n        safetensors.torch.load_model(self.semantic_codec, pretrained_path)\n        self.semantic_codec.to(self.device)\n\n        self.max_length = 2048\n        \n\n    @torch.no_grad()\n    def tokenize(self, speech):\n        # Input:\n        # speech: torch tensor, shape[B, N_speech]\n        # Output:\n        # semantic token: torch tensor, shape[B, N]\n\n        inputs = self.semantic_processor(speech.cpu(), sampling_rate=16000, return_tensors=\"pt\")\n        input_features = inputs[\"input_features\"].to(self.device)\n        attention_mask = inputs[\"attention_mask\"].to(self.device)\n        seg_num = math.ceil(input_features.shape[1] / self.max_length)\n        pad_num = seg_num * self.max_length - input_features.shape[1]\n        input_features = torch.nn.functional.pad(input_features, (0, 0, 0, pad_num, 0,0), value=0)\n        attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_num, 0, 0), value=0)\n        input_features = rearrange(input_features, \"b (s n) d -> (b s) n d\", s =seg_num)\n        attention_mask = rearrange(attention_mask, \"b (s n) -> (b s) n\", s=seg_num)\n\n\n        feats = self.semantic_model(\n            input_features=input_features,\n            attention_mask=attention_mask,\n            output_hidden_states=True,\n        )\n        feat = feats.hidden_states[17]  \n        feat = rearrange(feat, \"(b s) n d -> b (s n) d\", s=seg_num)\n        feat = feat[:, :feat.shape[1]-pad_num, :]\n        feat = (feat - self.feat_mean.to(feat)) / self.feat_std.to(feat)\n        semantic_token, _ = self.semantic_codec.quantize(feat)  \n        return semantic_token\n\ndef get_audio_tokenizer():\n    config = dict()\n    config['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'\n    config['feat_stats'] = 'resources/audio_tokenizer/stats.pt'\n    config['wav2vec_ckpt'] = 'facebook/w2v-bert-2.0'\n    config['semantic_codec_ckpt'] = 'resources/audio_tokenizer/model.safetensors'\n    audio_tokenizer = AudioTokenizer(**config)\n    return audio_tokenizer\n\n"
  },
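  {
    "path": "examples/demo_audio_tokenizer.py",
    "content": "# Hypothetical smoke test for AudioTokenizer.tokenize; this file is an\n# illustration, not an original repo script. Assumptions: the checkpoints\n# under resources/ have been downloaded (see app.py) and input audio is\n# 16 kHz mono, matching the sampling_rate used inside tokenize().\nimport torch\n\nfrom modules.audio_tokenizer.audio_tokenizer import get_audio_tokenizer\n\nif __name__ == \"__main__\":\n    tokenizer = get_audio_tokenizer()\n    wav = torch.randn(1, 16000)  # 1 s of noise, shape (B, N_speech)\n    tokens = tokenizer.tokenize(wav)\n    print(tokens.shape)  # (B, N) discrete semantic token ids\n"
  },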
  {
    "path": "modules/audio_tokenizer/quantize/__init__.py",
    "content": "from .vector_quantize import VectorQuantize\nfrom .residual_vq import ResidualVQ\nfrom .factorized_vector_quantize import FactorizedVectorQuantize\n"
  },
  {
    "path": "modules/audio_tokenizer/quantize/factorized_vector_quantize.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import rearrange\nfrom torch.nn.utils import weight_norm\n\n\ndef WNConv1d(*args, **kwargs):\n    return weight_norm(nn.Conv1d(*args, **kwargs))\n\n\ndef WNConvTranspose1d(*args, **kwargs):\n    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))\n\n\nclass FactorizedVectorQuantize(nn.Module):\n    def __init__(\n        self,\n        input_dim,\n        codebook_size,\n        codebook_dim,\n        commitment=0.005,\n        codebook_loss_weight=1.0,\n        use_l2_normlize=True,\n    ):\n        super().__init__()\n        self.input_dim = input_dim\n        self.codebook_size = codebook_size\n        self.codebook_dim = codebook_dim\n        self.commitment = commitment\n        self.codebook_loss_weight = codebook_loss_weight\n        self.use_l2_normlize = use_l2_normlize\n\n        if self.input_dim != self.codebook_dim:\n            self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)\n            self.out_project = WNConv1d(\n                self.codebook_dim, self.input_dim, kernel_size=1\n            )\n\n        else:\n            self.in_project = nn.Identity()\n            self.out_project = nn.Identity()\n\n        self.codebook = nn.Embedding(self.codebook_size, self.codebook_dim)\n\n    def forward(self, z):\n        \"\"\"\n        Parameters\n        ----------\n        z: torch.Tensor[B x D x T]\n\n        Returns\n        -------\n        z_q: torch.Tensor[B x D x T]\n            Quantized continuous representation of input\n        commit_loss: Tensor[B]\n            Commitment loss to train encoder to predict vectors closer to codebook entries\n        codebook_loss: Tensor[B]\n            Codebook loss to update the codebook\n        indices: torch.Tensor[B x T]\n            Codebook indices (quantized discrete representation of input)\n        z_e: torch.Tensor[B x D x T]\n            Projected latents (continuous representation of input before quantization)\n        \"\"\"\n\n        # Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim\n        z_e = self.in_project(z)\n        z_q, indices = self.decode_latents(z_e)\n\n        # Compute commitment loss and codebook loss\n        if self.training:\n            commit_loss = (\n                F.mse_loss(z_e, z_q.detach(), reduction=\"none\").mean([1, 2])\n                * self.commitment\n            )\n            codebook_loss = (\n                F.mse_loss(z_q, z_e.detach(), reduction=\"none\").mean([1, 2])\n                * self.codebook_loss_weight\n            )\n        else:\n            commit_loss = torch.zeros(z.shape[0], device=z.device)\n            codebook_loss = torch.zeros(z.shape[0], device=z.device)\n\n        z_q = z_e + (z_q - z_e).detach()\n\n        z_q = self.out_project(z_q)\n\n        return z_q, commit_loss, codebook_loss, indices, z_e\n\n    def embed_code(self, embed_id):\n        return F.embedding(embed_id, self.codebook.weight)\n\n    def decode_code(self, embed_id):\n        return self.embed_code(embed_id).transpose(1, 2)\n\n    def decode_latents(self, latents):\n        encodings = rearrange(latents, \"b d t -> (b t) d\")\n        codebook = self.codebook.weight\n\n        # L2 normalize encodings and codebook\n        if self.use_l2_normlize:\n            encodings = F.normalize(encodings)\n            codebook = F.normalize(codebook)\n\n        # Compute euclidean distance between 
encodings and codebook,\n        # if use_l2_normlize is True, the distance is equal to cosine distance\n        dist = (\n            encodings.pow(2).sum(1, keepdim=True)\n            - 2 * encodings @ codebook.t()\n            + codebook.pow(2).sum(1, keepdim=True).t()\n        )\n        indices = rearrange((-dist).max(1)[1], \"(b t) -> b t\", b=latents.size(0))\n        z_q = self.decode_code(indices)\n\n        return z_q, indices\n\n    def vq2emb(self, vq, out_proj=True):\n        emb = self.decode_code(vq)\n        if out_proj:\n            emb = self.out_project(emb)\n        return emb\n\n    def latent2dist(self, latents):\n        encodings = rearrange(latents, \"b d t -> (b t) d\")\n        codebook = self.codebook.weight\n\n        # L2 normalize encodings and codebook\n        if self.use_l2_normlize:\n            encodings = F.normalize(encodings)\n            codebook = F.normalize(codebook)\n\n        # Compute euclidean distance between encodings and codebook,\n        # if use_l2_normlize is True, the distance is equal to cosine distance\n        dist = (\n            encodings.pow(2).sum(1, keepdim=True)\n            - 2 * encodings @ codebook.t()\n            + codebook.pow(2).sum(1, keepdim=True).t()\n        )  # (b*t, k)\n\n        indices = rearrange((-dist).max(1)[1], \"(b t) -> b t\", b=latents.size(0))\n        dist = rearrange(dist, \"(b t) k -> b t k\", b=latents.size(0))\n        z_q = self.decode_code(indices)\n\n        return -dist, indices, z_q\n"
  },
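  {
    "path": "examples/demo_factorized_vq.py",
    "content": "# Hypothetical standalone sketch of FactorizedVectorQuantize on random\n# features; not an original repo file. The 1024-dim input and 8192x8\n# codebook mirror the RepCodec defaults, but all shapes here are demo\n# assumptions, not a training recipe.\nimport torch\n\nfrom modules.audio_tokenizer.quantize import FactorizedVectorQuantize\n\nif __name__ == \"__main__\":\n    vq = FactorizedVectorQuantize(input_dim=1024, codebook_size=8192, codebook_dim=8)\n    vq.eval()  # commit/codebook losses are zeroed outside training mode\n\n    z = torch.randn(2, 1024, 50)  # (B, D, T)\n    z_q, commit_loss, codebook_loss, indices, z_e = vq(z)\n    print(z_q.shape, indices.shape)  # (2, 1024, 50) (2, 50)\n"
  },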
  {
    "path": "modules/audio_tokenizer/quantize/residual_vq.py",
    "content": "from typing import Union\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import rearrange\nfrom torch.nn.utils import weight_norm\n\n\nfrom .vector_quantize import VectorQuantize\nfrom .factorized_vector_quantize import FactorizedVectorQuantize\n\n\nclass ResidualVQ(nn.Module):\n    \"\"\"\n    Introduced in SoundStream: An end2end neural audio codec\n    https://arxiv.org/abs/2107.03312\n    \"\"\"\n\n    def __init__(\n        self,\n        input_dim: int = 256,\n        num_quantizers: int = 8,\n        codebook_size: int = 1024,\n        codebook_dim: int = 256,\n        quantizer_type: str = \"vq\",  # \"vq\" or \"fvq\" or \"lfq\"\n        quantizer_dropout: float = 0.5,\n        **kwargs,\n    ):\n        super().__init__()\n\n        self.input_dim = input_dim\n        self.num_quantizers = num_quantizers\n        self.codebook_size = codebook_size\n        self.codebook_dim = codebook_dim\n        self.quantizer_type = quantizer_type\n        self.quantizer_dropout = quantizer_dropout\n\n        if quantizer_type == \"vq\":\n            VQ = VectorQuantize\n        elif quantizer_type == \"fvq\":\n            VQ = FactorizedVectorQuantize\n        else:\n            raise ValueError(f\"Unknown quantizer type {quantizer_type}\")\n\n        self.quantizers = nn.ModuleList(\n            [\n                VQ(\n                    input_dim=input_dim,\n                    codebook_size=codebook_size,\n                    codebook_dim=codebook_dim,\n                    **kwargs,\n                )\n                for _ in range(num_quantizers)\n            ]\n        )\n\n    def forward(self, z, n_quantizers: int = None):\n        \"\"\"\n        Parameters\n        ----------\n        z : Tensor[B x D x T]\n        n_quantizers : int, optional\n            No. 
of quantizers to use\n            (n_quantizers < self.n_codebooks ex: for quantizer dropout)\n            Note: if `self.quantizer_dropout` is True, this argument is ignored\n                when in training mode, and a random number of quantizers is used.\n        Returns\n        -------\n        \"quantized_out\" : Tensor[B x D x T]\n            Quantized continuous representation of input\n        \"all_indices\" : Tensor[N x B x T]\n            Codebook indices for each codebook\n            (quantized discrete representation of input)\n        \"all_commit_losses\" : Tensor[N]\n        \"all_codebook_losses\" : Tensor[N]\n        \"all_quantized\" : Tensor[N x B x D x T]\n        \"\"\"\n\n        quantized_out = 0.0\n        residual = z\n\n        all_commit_losses = []\n        all_codebook_losses = []\n        all_indices = []\n        all_quantized = []\n\n        if n_quantizers is None:\n            n_quantizers = self.num_quantizers\n\n        if self.training:\n            n_quantizers = torch.ones((z.shape[0],)) * self.num_quantizers + 1\n            dropout = torch.randint(1, self.num_quantizers + 1, (z.shape[0],))\n            n_dropout = int(z.shape[0] * self.quantizer_dropout)\n            n_quantizers[:n_dropout] = dropout[:n_dropout]\n            n_quantizers = n_quantizers.to(z.device)\n\n        for i, quantizer in enumerate(self.quantizers):\n            if self.training is False and i >= n_quantizers:\n                break\n\n            z_q_i, commit_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(\n                residual\n            )\n\n            # Create mask to apply quantizer dropout\n            mask = (\n                torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers\n            )\n            quantized_out = quantized_out + z_q_i * mask[:, None, None]\n            residual = residual - z_q_i\n\n            commit_loss_i = (commit_loss_i * mask).mean()\n            codebook_loss_i = (codebook_loss_i * mask).mean()\n\n            all_commit_losses.append(commit_loss_i)\n            all_codebook_losses.append(codebook_loss_i)\n            all_indices.append(indices_i)\n            all_quantized.append(z_q_i)\n\n        all_commit_losses, all_codebook_losses, all_indices, all_quantized = map(\n            torch.stack,\n            (all_commit_losses, all_codebook_losses, all_indices, all_quantized),\n        )\n\n        return (\n            quantized_out,\n            all_indices,\n            all_commit_losses,\n            all_codebook_losses,\n            all_quantized,\n        )\n\n    def vq2emb(self, vq, n_quantizers=None):\n        quantized_out = 0.0\n        if n_quantizers is None:\n            n_quantizers = self.num_quantizers\n        for idx, quantizer in enumerate(self.quantizers):\n            if idx >= n_quantizers:\n                break\n            quantized_out += quantizer.vq2emb(vq[idx])\n        return quantized_out\n\n    def latent2dist(self, z, n_quantizers=None):\n        quantized_out = 0.0\n        residual = z\n\n        all_dists = []\n        all_indices = []\n\n        if n_quantizers is None:\n            n_quantizers = self.num_quantizers\n\n        for i, quantizer in enumerate(self.quantizers):\n            if self.training is False and i >= n_quantizers:\n                break\n            dist_i, indices_i, z_q_i = quantizer.latent2dist(residual)\n            all_dists.append(dist_i)\n            all_indices.append(indices_i)\n\n            quantized_out = quantized_out + z_q_i\n 
           residual = residual - z_q_i\n\n        all_dists = torch.stack(all_dists)\n        all_indices = torch.stack(all_indices)\n\n        return all_dists, all_indices\n"
  },
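  {
    "path": "examples/demo_residual_vq.py",
    "content": "# Hypothetical sketch of ResidualVQ in inference mode; not an original repo\n# file. quantizer_type=\"fvq\" routes to FactorizedVectorQuantize and the extra\n# kwargs are forwarded to it; all sizes are illustrative assumptions.\nimport torch\n\nfrom modules.audio_tokenizer.quantize import ResidualVQ\n\nif __name__ == \"__main__\":\n    rvq = ResidualVQ(\n        input_dim=256,\n        num_quantizers=4,\n        codebook_size=1024,\n        codebook_dim=8,\n        quantizer_type=\"fvq\",\n        quantizer_dropout=0.0,\n        use_l2_normlize=True,\n    )\n    rvq.eval()\n\n    z = torch.randn(2, 256, 40)  # (B, D, T)\n    quantized, indices, commit_losses, codebook_losses, per_stage = rvq(z)\n    print(quantized.shape, indices.shape)  # (2, 256, 40) (4, 2, 40)\n"
  },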
  {
    "path": "modules/audio_tokenizer/quantize/vector_quantize.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import rearrange, repeat\nfrom torch.nn.utils import weight_norm\n\n\ndef WNConv1d(*args, **kwargs):\n    return weight_norm(nn.Conv1d(*args, **kwargs))\n\n\ndef WNConvTranspose1d(*args, **kwargs):\n    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))\n\n\ndef l2norm(t):\n    return F.normalize(t, p=2, dim=-1)\n\n\ndef ema_inplace(moving_avg, new, decay):\n    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))\n\n\ndef laplace_smoothing(x, n_categories, eps=1e-5):\n    return (x + eps) / (x.sum() + n_categories * eps)\n\n\ndef sample_vectors(samples, num):\n    num_samples, device = samples.shape[0], samples.device\n\n    if num_samples >= num:\n        indices = torch.randperm(num_samples, device=device)[:num]\n    else:\n        indices = torch.randint(0, num_samples, (num,), device=device)\n\n    return samples[indices]\n\n\ndef kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):\n    dim, dtype, device = samples.shape[-1], samples.dtype, samples.device\n\n    means = sample_vectors(samples, num_clusters)\n\n    for _ in range(num_iters):\n        if use_cosine_sim:\n            dists = samples @ means.t()\n        else:\n            diffs = rearrange(samples, \"n d -> n () d\") - rearrange(\n                means, \"c d -> () c d\"\n            )\n            dists = -(diffs**2).sum(dim=-1)\n\n        buckets = dists.max(dim=-1).indices\n        bins = torch.bincount(buckets, minlength=num_clusters)\n        zero_mask = bins == 0\n        bins_min_clamped = bins.masked_fill(zero_mask, 1)\n\n        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)\n        new_means.scatter_add_(0, repeat(buckets, \"n -> n d\", d=dim), samples)\n        new_means = new_means / bins_min_clamped[..., None]\n\n        if use_cosine_sim:\n            new_means = l2norm(new_means)\n\n        means = torch.where(zero_mask[..., None], means, new_means)\n\n    return means, bins\n\n\nclass EuclideanCodebook(nn.Module):\n    def __init__(\n        self,\n        dim,\n        codebook_size,\n        kmeans_init=False,\n        kmeans_iters=10,\n        decay=0.8,\n        eps=1e-5,\n        threshold_ema_dead_code=2,\n        weight_init=False,\n    ):\n        super().__init__()\n\n        self.decay = decay\n        init_fn = torch.randn if not weight_init else torch.zeros\n        embed = init_fn(codebook_size, dim)\n\n        if weight_init:\n            nn.init.uniform_(embed, -1 / codebook_size, 1 / codebook_size)\n\n        self.codebook_size = codebook_size\n        self.kmeans_iters = kmeans_iters\n        self.eps = eps\n        self.threshold_ema_dead_code = threshold_ema_dead_code\n\n        self.register_buffer(\n            \"initted\", torch.Tensor([not kmeans_init])\n        )  # if kmeans_init is True, then initted is False; otherwise, initted is True\n        self.register_buffer(\"cluster_size\", torch.zeros(codebook_size))\n        self.register_buffer(\"embed\", embed)\n        self.register_buffer(\"embed_avg\", embed.clone())\n\n    def init_embed_(self, data):\n        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)\n        self.embed.data.copy_(embed)\n        self.embed_avg.data.copy_(embed)\n        self.cluster_size.data.copy_(cluster_size)\n        self.initted.data.copy_(torch.Tensor([True]))\n\n    def replace(self, samples, mask):\n        modified_codebook = torch.where(\n            mask[..., 
None], sample_vectors(samples, self.codebook_size), self.embed\n        )\n        self.embed.data.copy_(modified_codebook)\n\n    def expire_codes_(self, batch_samples):\n        if self.threshold_ema_dead_code == 0:\n            return\n\n        expired_codes = self.cluster_size < self.threshold_ema_dead_code\n        if not torch.any(expired_codes):\n            return\n        batch_samples = rearrange(batch_samples, \"... d -> (...) d\")\n        self.replace(batch_samples, mask=expired_codes)\n\n    def forward(self, x):\n        shape, dtype = x.shape, x.dtype\n        flatten = rearrange(x, \"... d -> (...) d\")\n        embed = self.embed.t()  # (codebook_size, dim) -> (dim, codebook_size)\n\n        if not self.initted:\n            self.init_embed_(flatten)\n\n        dist = -(\n            flatten.pow(2).sum(1, keepdim=True)\n            - 2 * flatten @ embed\n            + embed.pow(2).sum(0, keepdim=True)\n        )\n\n        embed_ind = dist.max(dim=-1).indices\n        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)\n        embed_ind = embed_ind.view(*shape[:-1])\n        quantize = F.embedding(embed_ind, self.embed)\n\n        if self.training:\n            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)\n            embed_sum = (\n                flatten.t() @ embed_onehot\n            )  # (dim, ...) @ (..., codebook_size) -> (dim, codebook_size)\n            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)\n            cluster_size = (\n                laplace_smoothing(self.cluster_size, self.codebook_size, self.eps)\n                * self.cluster_size.sum()\n            )\n            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)\n            self.embed.data.copy_(embed_normalized)\n            self.expire_codes_(x)\n\n        return quantize, embed_ind\n\n    def vq2emb(self, vq):\n        quantize = F.embedding(vq, self.embed)\n        return quantize\n\n    def latent2dist(self, x):\n        shape, dtype = x.shape, x.dtype\n        flatten = rearrange(x, \"... d -> (...) d\")\n        embed = self.embed.t()  # (codebook_size, dim) -> (dim, codebook_size)\n\n        if not self.initted:\n            self.init_embed_(flatten)\n\n        dist = -(\n            flatten.pow(2).sum(1, keepdim=True)\n            - 2 * flatten @ embed\n            + embed.pow(2).sum(0, keepdim=True)\n        )\n\n        embed_ind = dist.max(dim=-1).indices\n        embed_ind = embed_ind.view(*shape[:-1])\n        quantize = F.embedding(embed_ind, self.embed)\n\n        dist = dist.view(*shape[:-1], -1)\n\n        return dist, embed_ind, quantize\n\n\nclass SimpleCodebook(nn.Module):\n    def __init__(\n        self,\n        dim,\n        codebook_size,\n        use_l2_normlize=False,\n    ):\n        super().__init__()\n\n        self.dim = dim\n        self.codebook_size = codebook_size\n        self.use_l2_normlize = use_l2_normlize\n\n        self.embed = nn.Embedding(self.codebook_size, self.dim)\n\n    def forward(self, x):\n        shape, dtype = x.shape, x.dtype\n        flatten = rearrange(x, \"... d -> (...) 
d\")\n        embed = self.embed.weight.t()  # (codebook_size, dim) -> (dim, codebook_size)\n\n        if self.use_l2_normlize:\n            flatten = F.normalize(flatten)\n            embed = F.normalize(embed)\n\n        dist = -(\n            flatten.pow(2).sum(1, keepdim=True)\n            - 2 * flatten @ embed\n            + embed.pow(2).sum(0, keepdim=True)\n        )\n\n        embed_ind = dist.max(dim=-1).indices\n        embed_ind = embed_ind.view(*shape[:-1])\n        quantize = F.embedding(embed_ind, self.embed)\n\n        return quantize, embed_ind\n\n    def vq2emb(self, vq):\n        quantize = F.embedding(vq, self.embed.weight)\n        return quantize\n\n    def latent2dist(self, x):\n        shape, dtype = x.shape, x.dtype\n        flatten = rearrange(x, \"... d -> (...) d\")\n        embed = self.embed.weight.t()  # (codebook_size, dim) -> (dim, codebook_size)\n\n        if self.use_l2_normlize:\n            flatten = F.normalize(flatten)\n            embed = F.normalize(embed)\n\n        dist = -(\n            flatten.pow(2).sum(1, keepdim=True)\n            - 2 * flatten @ embed\n            + embed.pow(2).sum(0, keepdim=True)\n        )\n\n        embed_ind = dist.max(dim=-1).indices\n        embed_ind = embed_ind.view(*shape[:-1])\n        quantize = F.embedding(embed_ind, self.embed)\n\n        dist = dist.view(*shape[:-1], -1)\n\n        return dist, embed_ind, quantize\n\n\nclass VectorQuantize(nn.Module):\n    \"\"\"Vector quantization and factorized vecotor quantization implementation\n    Args:\n        input_dim (int): Dimension of input.\n        codebook_size (int): Codebook size.\n        codebook_dim (int): Codebook dimension. We suggest use codebook_dim = input_dim\n            if use codebook_type == \"euclidean\", otherwise, if you want to use\n            factorized vector quantization, use codebook_dim as small number (e.g. 8 or 32).\n        commitment (float): Weight for commitment loss.\n        use_l2_normlize (bool): Whether to use l2 normlized codes for factorized vecotor quantization,\n            we suggest use it as True if you want to use factorized vector quantization\n        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.\n        kmeans_iters (int): Number of iterations used for kmeans initialization.\n        decay (float): Decay for exponential moving average over the codebooks.\n        epsilon (float): Epsilon value for numerical stability.\n        threshold_ema_dead_code (int): Threshold for dead code expiration. 
Replace any codes\n            that have an exponential moving average cluster size less than the specified threshold with\n            randomly selected vector from the current batch.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_dim,\n        codebook_size,\n        codebook_dim,\n        commitment=0.005,\n        codebook_loss_weight=1.0,\n        use_l2_normlize=False,\n        codebook_type=\"euclidean\",  # \"euclidean\" or \"simple\"\n        kmeans_init=False,\n        kmeans_iters=10,\n        decay=0.8,\n        eps=1e-5,\n        threshold_ema_dead_code=2,\n        weight_init=False,\n    ):\n        super().__init__()\n        self.input_dim = input_dim\n        self.codebook_size = codebook_size\n        self.codebook_dim = codebook_dim\n        self.commitment = commitment\n        self.codebook_loss_weight = codebook_loss_weight\n        self.use_l2_normlize = use_l2_normlize\n        self.codebook_type = codebook_type\n        self.kmeans_init = kmeans_init\n        self.kmeans_iters = kmeans_iters\n        self.decay = decay\n        self.eps = eps\n        self.threshold_ema_dead_code = threshold_ema_dead_code\n        self.weight_init = weight_init\n\n        if self.input_dim != self.codebook_dim:\n            self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)\n            self.out_project = WNConv1d(\n                self.codebook_dim, self.input_dim, kernel_size=1\n            )\n\n        else:\n            self.in_project = nn.Identity()\n            self.out_project = nn.Identity()\n\n        if self.codebook_type == \"euclidean\":\n            self.codebook = EuclideanCodebook(\n                self.codebook_dim,\n                codebook_size=self.codebook_size,\n                kmeans_init=self.kmeans_init,\n                kmeans_iters=self.kmeans_iters,\n                decay=self.decay,\n                eps=self.eps,\n                threshold_ema_dead_code=self.threshold_ema_dead_code,\n                weight_init=self.weight_init,\n            )\n        elif self.codebook_type == \"simple\":\n            self.codebook = SimpleCodebook(\n                self.codebook_dim,\n                codebook_size=self.codebook_size,\n                use_l2_normlize=self.use_l2_normlize,\n            )\n        else:\n            raise NotImplementedError(\n                f\"codebook_type {self.codebook_type} is not implemented!\"\n            )\n\n    def forward(self, z):\n        \"\"\"\n        Parameters\n        ----------\n        z: torch.Tensor[B x D x T]\n\n        Returns\n        -------\n        z_q: torch.Tensor[B x D x T]\n            Quantized continuous representation of input\n        commit_loss: Tensor[B]\n            Commitment loss to train encoder to predict vectors closer to codebook entries\n        codebook_loss: Tensor[B]\n            Codebook loss to update the codebook\n        indices: torch.Tensor[B x T]\n            Codebook indices (quantized discrete representation of input)\n        z_e: torch.Tensor[B x D x T]\n            Projected latents (continuous representation of input before quantization)\n        \"\"\"\n\n        # Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim\n        z_e = self.in_project(z)\n        z_q, indices = self.decode_latents(z_e)\n\n        # Compute commitment loss and codebook loss\n        if self.training:\n            commit_loss = (\n                F.mse_loss(z_e, z_q.detach(), reduction=\"none\").mean([1, 
2])\n                * self.commitment\n            )\n            codebook_loss = (\n                F.mse_loss(z_q, z_e.detach(), reduction=\"none\").mean([1, 2])\n                * self.codebook_loss_weight\n            )\n        else:\n            commit_loss = torch.zeros(z.shape[0], device=z.device)\n            codebook_loss = torch.zeros(z.shape[0], device=z.device)\n\n        z_q = z_e + (z_q - z_e).detach()\n\n        z_q = self.out_project(z_q)\n\n        return z_q, commit_loss, codebook_loss, indices, z_e\n\n    def decode_latents(self, latents):\n        encodings = rearrange(latents, \"b d t -> b t d\")\n        z_q, indices = self.codebook(encodings)\n        z_q = z_q.transpose(1, 2)\n        return z_q, indices\n\n    def vq2emb(self, vq, out_proj=True):\n        emb = self.codebook.vq2emb(vq)\n        emb = emb.transpose(1, 2)\n        if out_proj:\n            emb = self.out_project(emb)\n        return emb\n\n    def latent2dist(self, latents):\n        latents = rearrange(latents, \"b d t -> b t d\")\n        dist, embed_ind, quantize = self.codebook.latent2dist(latents)\n        return dist, embed_ind, quantize.transpose(1, 2)\n"
  },
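  {
    "path": "examples/demo_vector_quantize.py",
    "content": "# Hypothetical quick check of VectorQuantize with the default\n# EuclideanCodebook (EMA codebook updates only run in training mode); not an\n# original repo file, and the dims below are demo assumptions.\nimport torch\n\nfrom modules.audio_tokenizer.quantize import VectorQuantize\n\nif __name__ == \"__main__\":\n    vq = VectorQuantize(input_dim=256, codebook_size=512, codebook_dim=256)\n    vq.eval()\n\n    z = torch.randn(2, 256, 40)  # (B, D, T)\n    z_q, commit_loss, codebook_loss, indices, z_e = vq(z)\n    print(z_q.shape, indices.shape)  # (2, 256, 40) (2, 40)\n"
  },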
  {
    "path": "modules/audio_tokenizer/rep_codec.py",
    "content": "import torch\nimport torch.nn as nn\n\n\nfrom modules.audio_tokenizer.quantize import ResidualVQ\nfrom modules.audio_tokenizer.vocos import VocosBackbone\nfrom modules.audio_tokenizer.transformer import TransformerEncoder\n\ndef init_weights(m):\n    if isinstance(m, nn.Conv1d):\n        nn.init.trunc_normal_(m.weight, std=0.02)\n        nn.init.constant_(m.bias, 0)\n    if isinstance(m, nn.Linear):\n        nn.init.trunc_normal_(m.weight, std=0.02)\n        nn.init.constant_(m.bias, 0)\n\nclass RepCodec(nn.Module):\n    def __init__(\n        self,\n        codebook_size=8192,\n        hidden_size=1024,\n        codebook_dim=8,\n        vocos_dim=384,\n        vocos_intermediate_dim=2048,\n        vocos_num_layers=12,\n        num_quantizers=1,\n        use_timbre_encoder=False,\n        cfg=None,\n    ):\n        super().__init__()\n        codebook_size = (\n            cfg.codebook_size\n            if cfg is not None and hasattr(cfg, \"codebook_size\")\n            else codebook_size\n        )\n        codebook_dim = (\n            cfg.codebook_dim\n            if cfg is not None and hasattr(cfg, \"codebook_dim\")\n            else codebook_dim\n        )\n        hidden_size = (\n            cfg.hidden_size\n            if cfg is not None and hasattr(cfg, \"hidden_size\")\n            else hidden_size\n        )\n        vocos_dim = (\n            cfg.vocos_dim\n            if cfg is not None and hasattr(cfg, \"vocos_dim\")\n            else vocos_dim\n        )\n        vocos_intermediate_dim = (\n            cfg.vocos_intermediate_dim\n            if cfg is not None and hasattr(cfg, \"vocos_dim\")\n            else vocos_intermediate_dim\n        )\n        vocos_num_layers = (\n            cfg.vocos_num_layers\n            if cfg is not None and hasattr(cfg, \"vocos_dim\")\n            else vocos_num_layers\n        )\n        num_quantizers = (\n            cfg.num_quantizers\n            if cfg is not None and hasattr(cfg, \"num_quantizers\")\n            else num_quantizers\n        )\n        use_timbre_encoder = (\n            cfg.use_timbre_encoder\n            if cfg is not None and hasattr(cfg, \"use_timbre_encoder\")\n            else use_timbre_encoder\n        )\n\n        self.codebook_size = codebook_size\n        self.codebook_dim = codebook_dim\n        self.hidden_size = hidden_size\n        self.vocos_dim = vocos_dim\n        self.vocos_intermediate_dim = vocos_intermediate_dim\n        self.vocos_num_layers = vocos_num_layers\n        self.num_quantizers = num_quantizers\n        self.use_timbre_encoder = use_timbre_encoder\n\n        self.encoder = nn.Sequential(\n            VocosBackbone(\n                input_channels=self.hidden_size,\n                dim=384,\n                intermediate_dim=2048,\n                num_layers=12,\n                adanorm_num_embeddings=None\n            ),\n            nn.Linear(384, self.hidden_size)\n        )\n        self.decoder = nn.Sequential(\n            VocosBackbone(\n                input_channels=self.hidden_size,\n                dim=384,\n                intermediate_dim=2048,\n                num_layers=12,\n                adanorm_num_embeddings=None\n            ),\n            nn.Linear(384, self.hidden_size)\n        )\n\n        self.quantizer = ResidualVQ(\n            input_dim=hidden_size,\n            num_quantizers=num_quantizers,\n            codebook_size=codebook_size,\n            codebook_dim=codebook_dim,\n            quantizer_type=\"fvq\",\n            
quantizer_dropout=0.0,\n            commitment=0.15,\n            codebook_loss_weight=1.0,\n            use_l2_normlize=True,\n        )\n\n        if self.use_timbre_encoder:   #TODO: write encoder hidden (256) as a hyparam\n            self.timbre_in = nn.Linear(hidden_size, 256)\n            self.timbre_encoder = TransformerEncoder(\n                enc_emb_tokens=None,\n                encoder_layer=4,\n                encoder_hidden=256,\n                encoder_head=4,\n                conv_filter_size=1024,\n                conv_kernel_size=5,\n                encoder_dropout=0.1,\n                use_pe=False,\n                cfg=None,\n            )\n            self.timbre_out = nn.Linear(256, hidden_size)\n            self.timbre_linear = nn.Linear(hidden_size, hidden_size * 2)\n            self.timbre_linear.bias.data[:hidden_size] = 1\n            self.timbre_linear.bias.data[hidden_size:] = 0\n            self.timbre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False)\n            self.enc_ln = nn.LayerNorm(hidden_size, elementwise_affine=False)\n\n        self.reset_parameters()\n\n    def forward(self, x):\n\n        x = self.encoder(x.transpose(1, 2)).transpose(1, 2)\n\n        if self.use_timbre_encoder:\n            x_timbre = x\n            x = x.transpose(1, 2)\n            x = self.enc_ln(x)\n            x = x.transpose(1, 2)\n\n        (\n            quantized_out,\n            all_indices,\n            all_commit_losses,\n            all_codebook_losses,\n            _,\n        ) = self.quantizer(x)\n\n        if self.use_timbre_encoder:\n            x_timbre = x_timbre.transpose(1, 2)\n            x_timbre = self.timbre_in(x_timbre)\n            x_timbre = self.timbre_encoder(x_timbre, None, None)\n            x_timbre = self.timbre_out(x_timbre)\n            x_timbre = x_timbre.transpose(1, 2)\n            spk_embs = torch.mean(x_timbre, dim=2)\n\n            style = self.timbre_linear(spk_embs).unsqueeze(2)  # (B, 2d, 1)\n            gamma, beta = style.chunk(2, 1)  # (B, d, 1)\n            quantized_out = quantized_out.transpose(1, 2)\n            quantized_out = self.timbre_norm(quantized_out)\n            quantized_out = quantized_out.transpose(1, 2)\n            quantized_out = quantized_out * gamma + beta\n        \n\n        x_rec = self.decoder(quantized_out)\n\n        codebook_loss = (all_codebook_losses + all_commit_losses).mean()\n        all_indices = all_indices\n\n        return x_rec, codebook_loss, all_indices\n\n    def quantize(self, x):\n        x = self.encoder(x.transpose(1, 2)).transpose(1, 2)\n\n        if self.use_timbre_encoder:\n            x = x.transpose(1, 2)\n            x = self.enc_ln(x)\n            x = x.transpose(1, 2)\n\n        (\n            quantized_out,\n            all_indices,\n            all_commit_losses,\n            all_codebook_losses,\n            _,\n        ) = self.quantizer(x)\n        if all_indices.shape[0] == 1:\n            return all_indices.squeeze(0), quantized_out.transpose(1, 2)\n        return all_indices, quantized_out.transpose(1, 2)\n\n    def reset_parameters(self):\n        self.apply(init_weights)\n"
  },
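  {
    "path": "examples/demo_rep_codec.py",
    "content": "# Hypothetical sketch of RepCodec.quantize on random Wav2Vec2-BERT-like\n# features; not an original repo file. In the real pipeline the 1024-dim\n# input is hidden_states[17] after mean/std normalization (see\n# AudioTokenizer.tokenize); the shapes here are demo assumptions.\nimport torch\n\nfrom modules.audio_tokenizer.rep_codec import RepCodec\n\nif __name__ == \"__main__\":\n    codec = RepCodec()  # defaults: 1024-dim features, one 8192-entry codebook\n    codec.eval()\n\n    feat = torch.randn(1, 100, 1024)  # (B, T, hidden_size)\n    with torch.no_grad():\n        tokens, quantized = codec.quantize(feat)\n    print(tokens.shape, quantized.shape)  # (1, 100) (1, 100, 1024)\n"
  },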
  {
    "path": "modules/audio_tokenizer/transformer.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\nimport math\n\n\nclass StyleAdaptiveLayerNorm(nn.Module):\n    def __init__(self, normalized_shape, eps=1e-5):\n        super().__init__()\n        self.in_dim = normalized_shape\n        self.norm = nn.LayerNorm(self.in_dim, eps=eps, elementwise_affine=False)\n        self.style = nn.Linear(self.in_dim, self.in_dim * 2)\n        self.style.bias.data[: self.in_dim] = 1\n        self.style.bias.data[self.in_dim :] = 0\n\n    def forward(self, x, condition):\n        # x: (B, T, d); condition: (B, T, d)\n\n        style = self.style(torch.mean(condition, dim=1, keepdim=True))\n\n        gamma, beta = style.chunk(2, -1)\n\n        out = self.norm(x)\n\n        out = gamma * out + beta\n        return out\n\n\nclass PositionalEncoding(nn.Module):\n    def __init__(self, d_model, dropout, max_len=5000):\n        super().__init__()\n\n        self.dropout = dropout\n        position = torch.arange(max_len).unsqueeze(1)\n        div_term = torch.exp(\n            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)\n        )\n        pe = torch.zeros(max_len, 1, d_model)\n        pe[:, 0, 0::2] = torch.sin(position * div_term)\n        pe[:, 0, 1::2] = torch.cos(position * div_term)\n        self.register_buffer(\"pe\", pe)\n\n    def forward(self, x):\n        x = x + self.pe[: x.size(0)]\n        return F.dropout(x, self.dropout, training=self.training)\n\n\nclass TransformerFFNLayer(nn.Module):\n    def __init__(\n        self, encoder_hidden, conv_filter_size, conv_kernel_size, encoder_dropout\n    ):\n        super().__init__()\n\n        self.encoder_hidden = encoder_hidden\n        self.conv_filter_size = conv_filter_size\n        self.conv_kernel_size = conv_kernel_size\n        self.encoder_dropout = encoder_dropout\n\n        self.ffn_1 = nn.Conv1d(\n            self.encoder_hidden,\n            self.conv_filter_size,\n            self.conv_kernel_size,\n            padding=self.conv_kernel_size // 2,\n        )\n        self.ffn_1.weight.data.normal_(0.0, 0.02)\n        self.ffn_2 = nn.Linear(self.conv_filter_size, self.encoder_hidden)\n        self.ffn_2.weight.data.normal_(0.0, 0.02)\n\n    def forward(self, x):\n        # x: (B, T, d)\n        x = self.ffn_1(x.permute(0, 2, 1)).permute(\n            0, 2, 1\n        )  # (B, T, d) -> (B, d, T) -> (B, T, d)\n        x = F.relu(x)\n        x = F.dropout(x, self.encoder_dropout, training=self.training)\n        x = self.ffn_2(x)\n        return x\n\n\nclass TransformerEncoderLayer(nn.Module):\n    def __init__(\n        self,\n        encoder_hidden,\n        encoder_head,\n        conv_filter_size,\n        conv_kernel_size,\n        encoder_dropout,\n        use_cln,\n    ):\n        super().__init__()\n        self.encoder_hidden = encoder_hidden\n        self.encoder_head = encoder_head\n        self.conv_filter_size = conv_filter_size\n        self.conv_kernel_size = conv_kernel_size\n        self.encoder_dropout = encoder_dropout\n        self.use_cln = use_cln\n\n        if not self.use_cln:\n            self.ln_1 = nn.LayerNorm(self.encoder_hidden)\n            self.ln_2 = nn.LayerNorm(self.encoder_hidden)\n        else:\n            self.ln_1 = StyleAdaptiveLayerNorm(self.encoder_hidden)\n            self.ln_2 = StyleAdaptiveLayerNorm(self.encoder_hidden)\n\n        self.self_attn = nn.MultiheadAttention(\n            self.encoder_hidden, self.encoder_head, batch_first=True\n        )\n\n        self.ffn = 
TransformerFFNLayer(\n            self.encoder_hidden,\n            self.conv_filter_size,\n            self.conv_kernel_size,\n            self.encoder_dropout,\n        )\n\n    def forward(self, x, key_padding_mask, conditon=None):\n        # x: (B, T, d); key_padding_mask: (B, T), mask is 0; condition: (B, T, d)\n\n        # self attention\n        residual = x\n        if self.use_cln:\n            x = self.ln_1(x, conditon)\n        else:\n            x = self.ln_1(x)\n\n        if key_padding_mask != None:\n            key_padding_mask_input = ~(key_padding_mask.bool())\n        else:\n            key_padding_mask_input = None\n        x, _ = self.self_attn(\n            query=x, key=x, value=x, key_padding_mask=key_padding_mask_input\n        )\n        x = F.dropout(x, self.encoder_dropout, training=self.training)\n        x = residual + x\n\n        # ffn\n        residual = x\n        if self.use_cln:\n            x = self.ln_2(x, conditon)\n        else:\n            x = self.ln_2(x)\n        x = self.ffn(x)\n        x = residual + x\n\n        return x\n\n\nclass TransformerEncoder(nn.Module):\n    def __init__(\n        self,\n        enc_emb_tokens=None,\n        encoder_layer=4,\n        encoder_hidden=256,\n        encoder_head=4,\n        conv_filter_size=1024,\n        conv_kernel_size=5,\n        encoder_dropout=0.1,\n        use_cln=False,\n        use_pe=True,\n        cfg=None,\n    ):\n        super().__init__()\n\n        self.encoder_layer = (\n            encoder_layer if encoder_layer is not None else cfg.encoder_layer\n        )\n        self.encoder_hidden = (\n            encoder_hidden if encoder_hidden is not None else cfg.encoder_hidden\n        )\n        self.encoder_head = (\n            encoder_head if encoder_head is not None else cfg.encoder_head\n        )\n        self.conv_filter_size = (\n            conv_filter_size if conv_filter_size is not None else cfg.conv_filter_size\n        )\n        self.conv_kernel_size = (\n            conv_kernel_size if conv_kernel_size is not None else cfg.conv_kernel_size\n        )\n        self.encoder_dropout = (\n            encoder_dropout if encoder_dropout is not None else cfg.encoder_dropout\n        )\n        self.use_pe = use_pe if use_pe is not None else cfg.use_pe\n        self.use_cln = use_cln if use_cln is not None else cfg.use_cln\n\n        if enc_emb_tokens != None:\n            self.use_enc_emb = True\n            self.enc_emb_tokens = enc_emb_tokens\n        else:\n            self.use_enc_emb = False\n\n        if self.use_pe:\n            self.position_emb = PositionalEncoding(\n                self.encoder_hidden, self.encoder_dropout\n            )\n\n        self.layers = nn.ModuleList([])\n        self.layers.extend(\n            [\n                TransformerEncoderLayer(\n                    self.encoder_hidden,\n                    self.encoder_head,\n                    self.conv_filter_size,\n                    self.conv_kernel_size,\n                    self.encoder_dropout,\n                    self.use_cln,\n                )\n                for i in range(self.encoder_layer)\n            ]\n        )\n\n        if self.use_cln:\n            self.last_ln = StyleAdaptiveLayerNorm(self.encoder_hidden)\n        else:\n            self.last_ln = nn.LayerNorm(self.encoder_hidden)\n\n    def forward(self, x, key_padding_mask, condition=None):\n        if len(x.shape) == 2 and self.use_enc_emb:\n            x = self.enc_emb_tokens(x)\n            if self.use_pe:\n                x = 
self.position_emb(x)\n        else:\n            if self.use_pe:\n                x = self.position_emb(x)  # (B, T, d)\n\n        for layer in self.layers:\n            x = layer(x, key_padding_mask, condition)\n\n        if self.use_cln:\n            x = self.last_ln(x, condition)\n        else:\n            x = self.last_ln(x)\n\n        return x\n"
  },
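  {
    "path": "examples/demo_transformer_encoder.py",
    "content": "# Hypothetical minimal sketch of the TransformerEncoder with a key padding\n# mask (1 = valid frame, 0 = padding; the layer inverts it before\n# nn.MultiheadAttention). Not an original repo file; hyperparameters below\n# are illustrative.\nimport torch\n\nfrom modules.audio_tokenizer.transformer import TransformerEncoder\n\nif __name__ == \"__main__\":\n    enc = TransformerEncoder(\n        encoder_layer=2,\n        encoder_hidden=256,\n        encoder_head=4,\n        conv_filter_size=1024,\n        conv_kernel_size=5,\n        encoder_dropout=0.1,\n        use_pe=False,\n    )\n    enc.eval()\n\n    x = torch.randn(2, 30, 256)  # (B, T, d)\n    mask = torch.ones(2, 30)     # all frames valid\n    out = enc(x, mask)\n    print(out.shape)  # (2, 30, 256)\n"
  },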
  {
    "path": "modules/audio_tokenizer/vocos.py",
    "content": "from typing import Optional, Tuple\n\nimport numpy as np\nimport scipy\nimport torch\nfrom torch import nn, view_as_real, view_as_complex\nfrom torch import nn\nfrom torch.nn.utils import weight_norm, remove_weight_norm\nfrom torchaudio.functional.functional import _hz_to_mel, _mel_to_hz\n\n\ndef safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:\n    \"\"\"\n    Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.\n\n    Args:\n        x (Tensor): Input tensor.\n        clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.\n\n    Returns:\n        Tensor: Element-wise logarithm of the input tensor with clipping applied.\n    \"\"\"\n    return torch.log(torch.clip(x, min=clip_val))\n\n\ndef symlog(x: torch.Tensor) -> torch.Tensor:\n    return torch.sign(x) * torch.log1p(x.abs())\n\n\ndef symexp(x: torch.Tensor) -> torch.Tensor:\n    return torch.sign(x) * (torch.exp(x.abs()) - 1)\n\n\nclass STFT(nn.Module):\n    def __init__(\n        self,\n        n_fft: int,\n        hop_length: int,\n        win_length: int,\n        center=True,\n    ):\n        super().__init__()\n        self.center = center\n        self.n_fft = n_fft\n        self.hop_length = hop_length\n        self.win_length = win_length\n        window = torch.hann_window(win_length)\n        self.register_buffer(\"window\", window)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        # x: (B, T * hop_length)\n\n        if not self.center:\n            pad = self.win_length - self.hop_length\n            x = torch.nn.functional.pad(x, (pad // 2, pad // 2), mode=\"reflect\")\n\n        stft_spec = torch.stft(\n            x,\n            self.n_fft,\n            hop_length=self.hop_length,\n            win_length=self.win_length,\n            window=self.window,\n            center=self.center,\n            return_complex=False,\n        )  # (B, n_fft // 2 + 1, T, 2)\n\n        rea = stft_spec[:, :, :, 0]  # (B, n_fft // 2 + 1, T, 2)\n        imag = stft_spec[:, :, :, 1]  # (B, n_fft // 2 + 1, T, 2)\n\n        log_mag = torch.log(\n            torch.abs(torch.sqrt(torch.pow(rea, 2) + torch.pow(imag, 2))) + 1e-5\n        )  # (B, n_fft // 2 + 1, T)\n        phase = torch.atan2(imag, rea)  # (B, n_fft // 2 + 1, T)\n\n        return log_mag, phase\n\n\nclass ISTFT(nn.Module):\n    \"\"\"\n    Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with\n    windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.\n    See issue: https://github.com/pytorch/pytorch/issues/62323\n    Specifically, in the context of neural vocoding we are interested in \"same\" padding analogous to CNNs.\n    The NOLA constraint is met as we trim padded samples anyway.\n\n    Args:\n        n_fft (int): Size of Fourier transform.\n        hop_length (int): The distance between neighboring sliding window frames.\n        win_length (int): The size of window frame and STFT filter.\n        padding (str, optional): Type of padding. Options are \"center\" or \"same\". 
Defaults to \"same\".\n    \"\"\"\n\n    def __init__(\n        self, n_fft: int, hop_length: int, win_length: int, padding: str = \"same\"\n    ):\n        super().__init__()\n        if padding not in [\"center\", \"same\"]:\n            raise ValueError(\"Padding must be 'center' or 'same'.\")\n        self.padding = padding\n        self.n_fft = n_fft\n        self.hop_length = hop_length\n        self.win_length = win_length\n        window = torch.hann_window(win_length)\n        self.register_buffer(\"window\", window)\n\n    def forward(self, spec: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.\n\n        Args:\n            spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,\n                            N is the number of frequency bins, and T is the number of time frames.\n\n        Returns:\n            Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.\n        \"\"\"\n        if self.padding == \"center\":\n            # Fallback to pytorch native implementation\n            return torch.istft(\n                spec,\n                self.n_fft,\n                self.hop_length,\n                self.win_length,\n                self.window,\n                center=True,\n            )\n        elif self.padding == \"same\":\n            pad = (self.win_length - self.hop_length) // 2\n        else:\n            raise ValueError(\"Padding must be 'center' or 'same'.\")\n\n        assert spec.dim() == 3, \"Expected a 3D tensor as input\"\n        B, N, T = spec.shape\n\n        # Inverse FFT\n        ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm=\"backward\")\n        ifft = ifft * self.window[None, :, None]\n\n        # Overlap and Add\n        output_size = (T - 1) * self.hop_length + self.win_length\n        y = torch.nn.functional.fold(\n            ifft,\n            output_size=(1, output_size),\n            kernel_size=(1, self.win_length),\n            stride=(1, self.hop_length),\n        )[:, 0, 0, pad:-pad]\n\n        # Window envelope\n        window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)\n        window_envelope = torch.nn.functional.fold(\n            window_sq,\n            output_size=(1, output_size),\n            kernel_size=(1, self.win_length),\n            stride=(1, self.hop_length),\n        ).squeeze()[pad:-pad]\n\n        # Normalize\n        assert (window_envelope > 1e-11).all()\n        y = y / window_envelope\n\n        return y\n\n\nclass MDCT(nn.Module):\n    \"\"\"\n    Modified Discrete Cosine Transform (MDCT) module.\n\n    Args:\n        frame_len (int): Length of the MDCT frame.\n        padding (str, optional): Type of padding. Options are \"center\" or \"same\". 
Defaults to \"same\".\n    \"\"\"\n\n    def __init__(self, frame_len: int, padding: str = \"same\"):\n        super().__init__()\n        if padding not in [\"center\", \"same\"]:\n            raise ValueError(\"Padding must be 'center' or 'same'.\")\n        self.padding = padding\n        self.frame_len = frame_len\n        N = frame_len // 2\n        n0 = (N + 1) / 2\n        window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()\n        self.register_buffer(\"window\", window)\n\n        pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len)\n        post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N)\n        # view_as_real: NCCL Backend does not support ComplexFloat data type\n        # https://github.com/pytorch/pytorch/issues/71613\n        self.register_buffer(\"pre_twiddle\", view_as_real(pre_twiddle))\n        self.register_buffer(\"post_twiddle\", view_as_real(post_twiddle))\n\n    def forward(self, audio: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Apply the Modified Discrete Cosine Transform (MDCT) to the input audio.\n\n        Args:\n            audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size\n                and T is the length of the audio.\n\n        Returns:\n            Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames\n                and N is the number of frequency bins.\n        \"\"\"\n        if self.padding == \"center\":\n            audio = torch.nn.functional.pad(\n                audio, (self.frame_len // 2, self.frame_len // 2)\n            )\n        elif self.padding == \"same\":\n            # hop_length is 1/2 frame_len\n            audio = torch.nn.functional.pad(\n                audio, (self.frame_len // 4, self.frame_len // 4)\n            )\n        else:\n            raise ValueError(\"Padding must be 'center' or 'same'.\")\n\n        x = audio.unfold(-1, self.frame_len, self.frame_len // 2)\n        N = self.frame_len // 2\n        x = x * self.window.expand(x.shape)\n        X = torch.fft.fft(\n            x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1\n        )[..., :N]\n        res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N)\n        return torch.real(res) * np.sqrt(2)\n\n\nclass IMDCT(nn.Module):\n    \"\"\"\n    Inverse Modified Discrete Cosine Transform (IMDCT) module.\n\n    Args:\n        frame_len (int): Length of the MDCT frame.\n        padding (str, optional): Type of padding. Options are \"center\" or \"same\". 
Defaults to \"same\".\n    \"\"\"\n\n    def __init__(self, frame_len: int, padding: str = \"same\"):\n        super().__init__()\n        if padding not in [\"center\", \"same\"]:\n            raise ValueError(\"Padding must be 'center' or 'same'.\")\n        self.padding = padding\n        self.frame_len = frame_len\n        N = frame_len // 2\n        n0 = (N + 1) / 2\n        window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()\n        self.register_buffer(\"window\", window)\n\n        pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N)\n        post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2))\n        self.register_buffer(\"pre_twiddle\", view_as_real(pre_twiddle))\n        self.register_buffer(\"post_twiddle\", view_as_real(post_twiddle))\n\n    def forward(self, X: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients.\n\n        Args:\n            X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size,\n                L is the number of frames, and N is the number of frequency bins.\n\n        Returns:\n            Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio.\n        \"\"\"\n        B, L, N = X.shape\n        Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device)\n        Y[..., :N] = X\n        Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,)))\n        y = torch.fft.ifft(\n            Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1\n        )\n        y = (\n            torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape))\n            * np.sqrt(N)\n            * np.sqrt(2)\n        )\n        result = y * self.window.expand(y.shape)\n        output_size = (1, (L + 1) * N)\n        audio = torch.nn.functional.fold(\n            result.transpose(1, 2),\n            output_size=output_size,\n            kernel_size=(1, self.frame_len),\n            stride=(1, self.frame_len // 2),\n        )[:, 0, 0, :]\n\n        if self.padding == \"center\":\n            pad = self.frame_len // 2\n        elif self.padding == \"same\":\n            pad = self.frame_len // 4\n        else:\n            raise ValueError(\"Padding must be 'center' or 'same'.\")\n\n        audio = audio[:, pad:-pad]\n        return audio\n\n\nclass FourierHead(nn.Module):\n    \"\"\"Base class for inverse fourier modules.\"\"\"\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,\n                        L is the sequence length, and H denotes the model dimension.\n\n        Returns:\n            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.\n        \"\"\"\n        raise NotImplementedError(\"Subclasses must implement the forward method.\")\n\n\nclass ISTFTHead(FourierHead):\n    \"\"\"\n    ISTFT Head module for predicting STFT complex coefficients.\n\n    Args:\n        dim (int): Hidden dimension of the model.\n        n_fft (int): Size of Fourier transform.\n        hop_length (int): The distance between neighboring sliding window frames, which should align with\n                          the resolution of the input features.\n        padding (str, optional): Type of padding. Options are \"center\" or \"same\". 
Defaults to \"same\".\n    \"\"\"\n\n    def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = \"same\"):\n        super().__init__()\n        out_dim = n_fft + 2\n        self.out = torch.nn.Linear(dim, out_dim)\n        self.istft = ISTFT(\n            n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Forward pass of the ISTFTHead module.\n\n        Args:\n            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,\n                        L is the sequence length, and H denotes the model dimension.\n\n        Returns:\n            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.\n        \"\"\"\n        x = self.out(x).transpose(1, 2)\n        mag, p = x.chunk(2, dim=1)\n        mag = torch.exp(mag)\n        mag = torch.clip(\n            mag, max=1e2\n        )  # safeguard to prevent excessively large magnitudes\n        # wrapping happens here. These two lines produce real and imaginary value\n        x = torch.cos(p)\n        y = torch.sin(p)\n        # recalculating phase here does not produce anything new\n        # only costs time\n        # phase = torch.atan2(y, x)\n        # S = mag * torch.exp(phase * 1j)\n        # better directly produce the complex value\n        S = mag * (x + 1j * y)\n        audio = self.istft(S)\n        return audio\n\n\nclass IMDCTSymExpHead(FourierHead):\n    \"\"\"\n    IMDCT Head module for predicting MDCT coefficients with symmetric exponential function\n\n    Args:\n        dim (int): Hidden dimension of the model.\n        mdct_frame_len (int): Length of the MDCT frame.\n        padding (str, optional): Type of padding. Options are \"center\" or \"same\". Defaults to \"same\".\n        sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized\n                                     based on perceptual scaling. Defaults to None.\n        clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. 
Defaults to False.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        mdct_frame_len: int,\n        padding: str = \"same\",\n        sample_rate: Optional[int] = None,\n        clip_audio: bool = False,\n    ):\n        super().__init__()\n        out_dim = mdct_frame_len // 2\n        self.out = nn.Linear(dim, out_dim)\n        self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)\n        self.clip_audio = clip_audio\n\n        if sample_rate is not None:\n            # optionally initialize the last layer following the mel scale\n            m_max = _hz_to_mel(sample_rate // 2)\n            m_pts = torch.linspace(0, m_max, out_dim)\n            f_pts = _mel_to_hz(m_pts)\n            scale = 1 - (f_pts / f_pts.max())\n\n            with torch.no_grad():\n                self.out.weight.mul_(scale.view(-1, 1))\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Forward pass of the IMDCTSymExpHead module.\n\n        Args:\n            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,\n                        L is the sequence length, and H denotes the model dimension.\n\n        Returns:\n            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.\n        \"\"\"\n        x = self.out(x)\n        x = symexp(x)\n        x = torch.clip(\n            x, min=-1e2, max=1e2\n        )  # safeguard to prevent excessively large magnitudes\n        audio = self.imdct(x)\n        if self.clip_audio:\n            audio = torch.clip(audio, min=-1.0, max=1.0)  # clip the waveform, not the MDCT coefficients\n\n        return audio\n\n\nclass IMDCTCosHead(FourierHead):\n    \"\"\"\n    IMDCT Head module for predicting MDCT coefficients, parametrized as MDCT = exp(m) · cos(p)\n\n    Args:\n        dim (int): Hidden dimension of the model.\n        mdct_frame_len (int): Length of the MDCT frame.\n        padding (str, optional): Type of padding. Options are \"center\" or \"same\". Defaults to \"same\".\n        clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. 
Defaults to False.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        mdct_frame_len: int,\n        padding: str = \"same\",\n        clip_audio: bool = False,\n    ):\n        super().__init__()\n        self.clip_audio = clip_audio\n        self.out = nn.Linear(dim, mdct_frame_len)\n        self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Forward pass of the IMDCTCosHead module.\n\n        Args:\n            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,\n                        L is the sequence length, and H denotes the model dimension.\n\n        Returns:\n            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.\n        \"\"\"\n        x = self.out(x)\n        m, p = x.chunk(2, dim=2)\n        m = torch.exp(m).clip(\n            max=1e2\n        )  # safeguard to prevent excessively large magnitudes\n        audio = self.imdct(m * torch.cos(p))\n        if self.clip_audio:\n            audio = torch.clip(audio, min=-1.0, max=1.0)  # clip the waveform, not the MDCT coefficients\n        return audio\n\n\nclass ConvNeXtBlock(nn.Module):\n    \"\"\"ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.\n\n    Args:\n        dim (int): Number of input channels.\n        intermediate_dim (int): Dimensionality of the intermediate layer.\n        layer_scale_init_value (float): Initial value for the layer scale. A non-positive value\n            disables scaling.\n        adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.\n            None means non-conditional LayerNorm. Defaults to None.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        intermediate_dim: int,\n        layer_scale_init_value: float,\n        adanorm_num_embeddings: Optional[int] = None,\n    ):\n        super().__init__()\n        self.dwconv = nn.Conv1d(\n            dim, dim, kernel_size=7, padding=3, groups=dim\n        )  # depthwise conv\n        self.adanorm = adanorm_num_embeddings is not None\n        if adanorm_num_embeddings:\n            self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)\n        else:\n            self.norm = nn.LayerNorm(dim, eps=1e-6)\n        self.pwconv1 = nn.Linear(\n            dim, intermediate_dim\n        )  # pointwise/1x1 convs, implemented with linear layers\n        self.act = nn.GELU()\n        self.pwconv2 = nn.Linear(intermediate_dim, dim)\n        self.gamma = (\n            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)\n            if layer_scale_init_value > 0\n            else None\n        )\n\n    def forward(\n        self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None\n    ) -> torch.Tensor:\n        residual = x\n        x = self.dwconv(x)\n        x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)\n        if self.adanorm:\n            assert cond_embedding_id is not None\n            x = self.norm(x, cond_embedding_id)\n        else:\n            x = self.norm(x)\n        x = self.pwconv1(x)\n        x = self.act(x)\n        x = self.pwconv2(x)\n        if self.gamma is not None:\n            x = self.gamma * x\n        x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)\n\n        x = residual + x\n        return x\n\n\nclass AdaLayerNorm(nn.Module):\n    \"\"\"\n    Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` 
classes\n\n    Args:\n        num_embeddings (int): Number of embeddings.\n        embedding_dim (int): Dimension of the embeddings.\n    \"\"\"\n\n    def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):\n        super().__init__()\n        self.eps = eps\n        self.dim = embedding_dim\n        self.scale = nn.Embedding(\n            num_embeddings=num_embeddings, embedding_dim=embedding_dim\n        )\n        self.shift = nn.Embedding(\n            num_embeddings=num_embeddings, embedding_dim=embedding_dim\n        )\n        torch.nn.init.ones_(self.scale.weight)\n        torch.nn.init.zeros_(self.shift.weight)\n\n    def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:\n        scale = self.scale(cond_embedding_id)\n        shift = self.shift(cond_embedding_id)\n        x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)\n        x = x * scale + shift\n        return x\n\n\nclass ResBlock1(nn.Module):\n    \"\"\"\n    ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,\n    but without upsampling layers.\n\n    Args:\n        dim (int): Number of input channels.\n        kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.\n        dilation (tuple[int], optional): Dilation factors for the dilated convolutions.\n            Defaults to (1, 3, 5).\n        lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.\n            Defaults to 0.1.\n        layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.\n            Defaults to None.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        kernel_size: int = 3,\n        dilation: Tuple[int, int, int] = (1, 3, 5),\n        lrelu_slope: float = 0.1,\n        layer_scale_init_value: Optional[float] = None,\n    ):\n        super().__init__()\n        self.lrelu_slope = lrelu_slope\n        self.convs1 = nn.ModuleList(\n            [\n                weight_norm(\n                    nn.Conv1d(\n                        dim,\n                        dim,\n                        kernel_size,\n                        1,\n                        dilation=dilation[0],\n                        padding=self.get_padding(kernel_size, dilation[0]),\n                    )\n                ),\n                weight_norm(\n                    nn.Conv1d(\n                        dim,\n                        dim,\n                        kernel_size,\n                        1,\n                        dilation=dilation[1],\n                        padding=self.get_padding(kernel_size, dilation[1]),\n                    )\n                ),\n                weight_norm(\n                    nn.Conv1d(\n                        dim,\n                        dim,\n                        kernel_size,\n                        1,\n                        dilation=dilation[2],\n                        padding=self.get_padding(kernel_size, dilation[2]),\n                    )\n                ),\n            ]\n        )\n\n        self.convs2 = nn.ModuleList(\n            [\n                weight_norm(\n                    nn.Conv1d(\n                        dim,\n                        dim,\n                        kernel_size,\n                        1,\n                        dilation=1,\n                        padding=self.get_padding(kernel_size, 1),\n                    )\n                ),\n                
weight_norm(\n                    nn.Conv1d(\n                        dim,\n                        dim,\n                        kernel_size,\n                        1,\n                        dilation=1,\n                        padding=self.get_padding(kernel_size, 1),\n                    )\n                ),\n                weight_norm(\n                    nn.Conv1d(\n                        dim,\n                        dim,\n                        kernel_size,\n                        1,\n                        dilation=1,\n                        padding=self.get_padding(kernel_size, 1),\n                    )\n                ),\n            ]\n        )\n\n        self.gamma = nn.ParameterList(\n            [\n                (\n                    nn.Parameter(\n                        layer_scale_init_value * torch.ones(dim, 1), requires_grad=True\n                    )\n                    if layer_scale_init_value is not None\n                    else None\n                ),\n                (\n                    nn.Parameter(\n                        layer_scale_init_value * torch.ones(dim, 1), requires_grad=True\n                    )\n                    if layer_scale_init_value is not None\n                    else None\n                ),\n                (\n                    nn.Parameter(\n                        layer_scale_init_value * torch.ones(dim, 1), requires_grad=True\n                    )\n                    if layer_scale_init_value is not None\n                    else None\n                ),\n            ]\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):\n            xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)\n            xt = c1(xt)\n            xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)\n            xt = c2(xt)\n            if gamma is not None:\n                xt = gamma * xt\n            x = xt + x\n        return x\n\n    def remove_weight_norm(self):\n        for l in self.convs1:\n            remove_weight_norm(l)\n        for l in self.convs2:\n            remove_weight_norm(l)\n\n    @staticmethod\n    def get_padding(kernel_size: int, dilation: int = 1) -> int:\n        return int((kernel_size * dilation - dilation) / 2)\n\n\nclass Backbone(nn.Module):\n    \"\"\"Base class for the generator's backbone. It preserves the same temporal resolution across all layers.\"\"\"\n\n    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:\n        \"\"\"\n        Args:\n            x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,\n                        C denotes the number of input feature channels, and L is the sequence length.\n\n        Returns:\n            Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,\n                    and H denotes the model dimension.\n        \"\"\"\n        raise NotImplementedError(\"Subclasses must implement the forward method.\")\n\n\nclass VocosBackbone(Backbone):\n    \"\"\"\n    Vocos backbone module built with ConvNeXt blocks. 
Supports additional conditioning with Adaptive Layer Normalization.\n\n    Args:\n        input_channels (int): Number of input feature channels.\n        dim (int): Hidden dimension of the model.\n        intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.\n        num_layers (int): Number of ConvNeXtBlock layers.\n        layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.\n        adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.\n                                                None means non-conditional model. Defaults to None.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_channels: int,\n        dim: int,\n        intermediate_dim: int,\n        num_layers: int,\n        layer_scale_init_value: Optional[float] = None,\n        adanorm_num_embeddings: Optional[int] = None,\n    ):\n        super().__init__()\n        self.input_channels = input_channels\n        self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)\n        self.adanorm = adanorm_num_embeddings is not None\n        if adanorm_num_embeddings:\n            self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)\n        else:\n            self.norm = nn.LayerNorm(dim, eps=1e-6)\n        layer_scale_init_value = layer_scale_init_value or 1 / num_layers\n        self.convnext = nn.ModuleList(\n            [\n                ConvNeXtBlock(\n                    dim=dim,\n                    intermediate_dim=intermediate_dim,\n                    layer_scale_init_value=layer_scale_init_value,\n                    adanorm_num_embeddings=adanorm_num_embeddings,\n                )\n                for _ in range(num_layers)\n            ]\n        )\n        self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, (nn.Conv1d, nn.Linear)):\n            nn.init.trunc_normal_(m.weight, std=0.02)\n            nn.init.constant_(m.bias, 0)\n\n    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:\n        bandwidth_id = kwargs.get(\"bandwidth_id\", None)\n        x = self.embed(x)\n        if self.adanorm:\n            assert bandwidth_id is not None\n            x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)\n        else:\n            x = self.norm(x.transpose(1, 2))\n        x = x.transpose(1, 2)\n        for conv_block in self.convnext:\n            x = conv_block(x, cond_embedding_id=bandwidth_id)\n        x = self.final_layer_norm(x.transpose(1, 2))\n        return x\n\n\nclass VocosResNetBackbone(Backbone):\n    \"\"\"\n    Vocos backbone module built with ResBlocks.\n\n    Args:\n        input_channels (int): Number of input feature channels.\n        dim (int): Hidden dimension of the model.\n        num_blocks (int): Number of ResBlock1 blocks.\n        layer_scale_init_value (float, optional): Initial value for layer scaling. 
Defaults to None.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_channels,\n        dim,\n        num_blocks,\n        layer_scale_init_value=None,\n    ):\n        super().__init__()\n        self.input_channels = input_channels\n        self.embed = weight_norm(\n            nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)\n        )\n        layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3\n        self.resnet = nn.Sequential(\n            *[\n                ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value)\n                for _ in range(num_blocks)\n            ]\n        )\n\n    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:\n        x = self.embed(x)\n        x = self.resnet(x)\n        x = x.transpose(1, 2)\n        return x\n\n\nclass Vocos(nn.Module):\n    def __init__(\n        self,\n        input_channels: int = 256,\n        dim: int = 384,\n        intermediate_dim: int = 1152,\n        num_layers: int = 8,\n        adanorm_num_embeddings: int = 4,\n        n_fft: int = 800,\n        hop_size: int = 200,\n        padding: str = \"same\",\n    ):\n        super().__init__()\n\n        self.backbone = VocosBackbone(\n            input_channels=input_channels,\n            dim=dim,\n            intermediate_dim=intermediate_dim,\n            num_layers=num_layers,\n            adanorm_num_embeddings=adanorm_num_embeddings,\n        )\n        self.head = ISTFTHead(dim, n_fft, hop_size, padding)\n\n    def forward(self, x, **kwargs):\n        # pass conditioning through (e.g. bandwidth_id when AdaLayerNorm is enabled)\n        x = self.backbone(x, **kwargs)\n        x = self.head(x)\n\n        return x[:, None, :]\n
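\n\n# A minimal shape sanity check (a sketch, not part of the original module; it uses\n# the non-conditional LayerNorm variant so that no bandwidth_id is required):\n# >>> import torch\n# >>> vocos = Vocos(adanorm_num_embeddings=None)\n# >>> wav = vocos(torch.randn(2, 256, 100))  # (B, input_channels, L) -> (B, 1, T)\n# >>> wav.ndim, tuple(wav.shape[:2])\n# (3, (2, 1))\n"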
  },
  {
    "path": "modules/tokenizer/tokenizer.py",
    "content": "from abc import ABC\nfrom abc import abstractmethod\nimport sentencepiece as spm\nfrom sentencepiece import sentencepiece_model_pb2 as sp_pb2_model\nfrom typing import Any, Union\nimport numpy as np\nfrom dataclasses import dataclass\n\ndef encode_pieces(sp_model: spm.SentencePieceProcessor, text: str, sample=False):\n    \"\"\"Encode text into sentence pieces. Only supports py3.\"\"\"\n\n    if not sample:\n        pieces = sp_model.EncodeAsPieces(text)\n    else:\n        pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1)\n\n    return pieces\n\n\nclass AbstractTokenizer(ABC):\n    \"\"\"Abstract class for tokenizer.\"\"\"\n\n    def __init__(self, name):\n        self.name = name\n        super().__init__()\n\n    @property\n    @abstractmethod\n    def vocab_size(self):\n        pass\n\n    @property\n    @abstractmethod\n    def vocab(self):\n        \"\"\"Dictionary from vocab text token to id token.\"\"\"\n        pass\n\n    @property\n    @abstractmethod\n    def inv_vocab(self):\n        \"\"\"Dictionary from vocab id token to text token.\"\"\"\n        pass\n\n    @abstractmethod\n    def tokenize(self, text):\n        pass\n\n    def detokenize(self, token_ids):\n        raise NotImplementedError('detokenizer is not implemented for {} '\n                                  'tokenizer'.format(self.name))\n\n    @property\n    def cls(self):\n        raise NotImplementedError('CLS is not provided for {} '\n                                  'tokenizer'.format(self.name))\n\n    @property\n    def sep(self):\n        raise NotImplementedError('SEP is not provided for {} '\n                                  'tokenizer'.format(self.name))\n\n    @property\n    def pad(self):\n        raise NotImplementedError('PAD is not provided for {} '\n                                  'tokenizer'.format(self.name))\n\n    @property\n    def eod(self):\n        raise NotImplementedError('EOD is not provided for {} '\n                                  'tokenizer'.format(self.name))\n\n    @property\n    def mask(self):\n        raise NotImplementedError('MASK is not provided for {} '\n                                  'tokenizer'.format(self.name))\n\n\nclass SPieceTokenizer(AbstractTokenizer):\n    def __init__(self, spm_file: str):\n        super().__init__('Sentence Piece')\n        self.sp_model = spm.SentencePieceProcessor()\n        self.sp_model.Load(spm_file)\n        self.eod_id = self.get_token_id('</s>')\n\n        self.special_ids = set([\n            self.sp_model.pad_id(),\n            self.sp_model.eos_id(),\n            self.sp_model.bos_id(),\n            self.sp_model.unk_id(),\n            self.eod_id,\n        ])\n\n        # initialize index_2_bytes\n        self._initialize_index_2_bytes()\n    \n    def encode_pieces(self, text: str, sample=False):\n        if not sample:\n            pieces = self.sp_model.EncodeAsPieces(text)\n        else:\n            pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1)\n        return pieces\n\n    def _initialize_index_2_bytes(self):\n        proto = sp_pb2_model.ModelProto()\n        proto.ParseFromString(self.sp_model.serialized_model_proto())\n        self.index_2_numbytes = [0] * len(proto.pieces)\n        for i, p in enumerate(proto.pieces):\n            clean_piece = p.piece.replace('▁', '')\n            self.index_2_numbytes[i] = len(clean_piece.encode('utf-8'))\n\n    def set_add_dummy_prefix(self, add_dummy_prefix: bool = False):\n        proto = sp_pb2_model.ModelProto()\n        
proto.ParseFromString(self.sp_model.serialized_model_proto())\n        if proto.normalizer_spec.add_dummy_prefix != add_dummy_prefix:\n            proto.normalizer_spec.add_dummy_prefix = add_dummy_prefix\n            self.sp_model.LoadFromSerializedProto(proto.SerializeToString())\n            print(f\"> set add_dummy_prefix to {add_dummy_prefix} ...\", flush=True)\n\n    def add_special_id(self, token_id):\n        self.special_ids.add(token_id)\n\n    @property\n    def has_dummy_prefix(self):\n        pieces = self.sp_model.EncodeAsPieces(\"hello\")\n        return pieces[0].startswith('▁')\n\n    @property\n    def vocab_size(self):\n        return self.sp_model.GetPieceSize()\n\n    @property\n    def vocab(self):\n        \"\"\"Dictionary from vocab text token to id token.\"\"\"\n        return self.sp_model\n\n    def get_array_bytes(self, array):\n        return sum(self.index_2_numbytes[i] if i < self.vocab_size else 2 for i in array)\n\n    def tokenize(self, text):\n        tokens = encode_pieces(self.sp_model, text)\n        return self.convert_tokens_to_ids(tokens)\n    \n    def encode(self, text: str, bos: bool=False, eos: bool=False, **kwargs: Any) -> list[int]:\n        tokens = self.encode_pieces(text)\n        t = self.convert_tokens_to_ids(tokens)\n        if bos:\n            t.insert(0, self.bos_id)\n        if eos:\n            t.append(self.eos_id)\n        return t\n\n    def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]:\n        if isinstance(tokens, str):\n            return self.sp_model.PieceToId(tokens)\n        return [self.sp_model.PieceToId(token) for token in tokens]\n\n    def detokenize(self, token_ids):\n        if isinstance(token_ids, list):\n            pieces = [self.sp_model.IdToPiece(id) for id in token_ids]\n        else:\n            pieces = [self.sp_model.IdToPiece(id) for id in token_ids.tolist()]\n        return pieces\n    \n    def decode(self, token_ids: Union[int, list[int]], skip_special_tokens: bool = False) -> str:\n        assert not skip_special_tokens, \"skip_special_tokens is not supported\"\n        if isinstance(token_ids, (int, np.integer)):\n            return self.detokenize([int(token_ids)])[0]\n        return ''.join(self.detokenize(token_ids))\n\n    def get_token_id(self, token):\n        return self.sp_model.PieceToId(token)\n\n    @property\n    def inv_vocab(self):\n        \"\"\"Dictionary from vocab id token to text token.\"\"\"\n        # TODO: to be implemented\n        return {}\n\n    def decode_pieces(self, pieces):\n        return self.sp_model.DecodePieces(pieces)\n\n    @property\n    def eod(self):\n        return self.eod_id\n\n    @property\n    def pad_id(self):\n        return self.sp_model.pad_id()\n\n    @property\n    def eos_id(self):\n        return self.sp_model.eos_id()\n\n    @property\n    def bos_id(self):\n        return self.sp_model.bos_id()\n\n    @property\n    def unk_id(self):\n        return self.sp_model.unk_id()\n    \n    @property\n    def pad_token_id(self):\n        return self.pad_id\n\n    @property\n    def eos_token_id(self):\n        return self.eos_id\n\n    \n@dataclass\nclass ExtraTokens:\n    msg_end: int\n    user_msg_start: int\n    assistant_msg_start: int\n    name_end: int\n    media_begin: int\n    media_content: int\n    media_end: int\n    pad: int\n\n\ndef instantiate_extra_tokens(tokenizer: AbstractTokenizer):\n    if isinstance(tokenizer, SPieceTokenizer):\n        map_fn = lambda x: tokenizer.convert_tokens_to_ids(x)\n    else:\n        raise ValueError(f\"Invalid tokenizer type: {type(tokenizer)}\")\n\n    
return ExtraTokens(\n        msg_end=map_fn('[extra_id_0]'),\n        user_msg_start=map_fn('[extra_id_1]'),\n        assistant_msg_start=map_fn('[extra_id_2]'),\n        name_end=map_fn('[extra_id_12]'),\n        media_begin=map_fn('[extra_id_13]'),\n        media_content=map_fn('[extra_id_14]'),\n        media_end=map_fn('[extra_id_15]'),\n        pad=tokenizer.pad_id\n    )\n\ndef get_tokenizer_and_extra_tokens():\n    \"\"\"Load the 160k SentencePiece model and resolve the extra control tokens.\"\"\"\n    sp_model_path = \"resources/tokenizer/160k.model\"\n    tokenizer = SPieceTokenizer(sp_model_path)\n    tokenizer.set_add_dummy_prefix(False)\n    extra_tokens = instantiate_extra_tokens(tokenizer)\n    return tokenizer, extra_tokens\n
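\n\n# Expected behavior, taken from test/test_tokenizer.py (requires the pretrained\n# resources/ directory to be downloaded first; see the readme):\n# >>> tokenizer, extra_tokens = get_tokenizer_and_extra_tokens()\n# >>> tokenizer.encode('user')\n# [1495]\n# >>> tokenizer.decode([1495])\n# 'user'\n# >>> extra_tokens.msg_end\n# 260\n"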
  },
  {
    "path": "readme.md",
    "content": "# MoonCast: High-Quality Zero-Shot Podcast Generation\n\n<p align=\"center\">\n    <picture>\n        <img src=\"./fig/logo.png\" width=\"40%\">\n    </picture>\n</p>\n\n## Overview\nDemo page: [demo](https://mooncastdemo.github.io)\n\nPaper: [paper](https://arxiv.org/abs/2503.14345)\n\n2025/03/26 UPDATE: We also host a [HuggingFace space](https://huggingface.co/spaces/jzq11111/mooncast) for testing audio generation.\n\nWe open-source this system to advance the field of human-like speech synthesis. Our goal is to create more natural and expressive synthetic voices that bridge the gap between machines and humans. We hope this project will inspire researchers and developers to explore new possibilities in voice technology. We warmly welcome contributions from anyone interested in this project. Whether through code, documentation, feedback, or sharing your insights, every input helps make this project better.\n\n\n## Environment Setup\n- Create conda environment.\n\n``` sh\nconda create -n mooncast -y python=3.10\nconda activate mooncast\npip install -r requirements.txt \npip install flash-attn --no-build-isolation\npip install huggingface_hub\npip install gradio==5.22.0\n```\n\n- Download the pretrained weights.\n``` sh\npython download_pretrain.py\n```\n\n## Example Usage\n\n### Script Generation\nFor podcast script generation, we utilize specific LLM prompts defined in ``zh_llmprompt_script_gen.py`` (Chinese) and ``en_llmprompt_script_gen.py`` (English). We have selected the [Gemini 2.0 Pro Experimental 02-05](https://cloud.google.com/vertex-ai/generative-ai/docs/gemini-v2#2.0-pro) model for this task, favoring its ability to produce conversational language, design natural dialogue, and offer broad topic coverage. Our process involves two stages: first, we generate a concise summary by providing the input knowledge source as an attachment along with the ``INPUT2BRIEF`` prompt. Subsequently, this summary, paired with the ``BRIEF2SCRIPT`` prompt, is used to generate the final podcast script in JSON format.\n\n### Speech Generation\nThe audio prompts used in this project are sourced from publicly available podcast segments and are intended solely for demonstration purposes. Redistribution of these audio files, whether in their original form or as generated audio, is strictly prohibited. If you have any concerns or questions regarding the use of these audio files, please contact us at juzeqian@mail.ustc.edu.cn\n\n```sh\nCUDA_VISIBLE_DEVICIES=0 python inference.py\n```\n\n2025/03/26 UPDATE: We add a Gradio-based user interface for audio generation. Deploy it locally using:\n\n```sh\nCUDA_VISIBLE_DEVICIES=0 python app.py\n```\n\n## Disclaimer  \nThis project is intended for **research purposes only**. We strongly encourage users to **use this project and its generated audio responsibly**. **We are not responsible for any misuse or abuse of this project**. By using this project, you agree to comply with all applicable laws and ethical guidelines."
  },
  {
    "path": "requirements.txt",
    "content": "torch==2.3.1\ntorchaudio==2.3.1\nsentencepiece==0.2.0\nprotobuf\nnumpy\n\nlibrosa==0.9.1\npyyaml\ntransformers\nsafetensors\neinops\nscipy\ntimm==1.0.7\ntorchdyn\nlibrosa\naccelerate==0.26.0\nninja\ncryptography"
  },
  {
    "path": "test/test_audio_detokenizer.py",
    "content": "import sys\nimport torch\nsys.path.append('.')\nfrom modules.audio_tokenizer.audio_tokenizer import get_audio_tokenizer\nfrom modules.audio_detokenizer.audio_detokenizer import get_audio_detokenizer, detokenize\nimport torchaudio\nimport librosa\n\nif __name__ == '__main__':\n    audio_tokenizer = get_audio_tokenizer()\n    audio_detokenizer = get_audio_detokenizer()\n\n    input_wav_16k, _ = librosa.load(\"en_prompt0.wav\", sr=16000)\n    input_wav_24k, _ = librosa.load(\"en_prompt0.wav\", sr=24000)\n\n    prompt_sec = 1\n    prompt_wav_16k = input_wav_16k[:16000*prompt_sec]\n    prompt_wav_24k = input_wav_24k[:24000*prompt_sec]\n    input_wav_16k = input_wav_16k[16000*prompt_sec:]\n    input_wav_24k = input_wav_24k[24000*prompt_sec:]\n\n    prompt_wav_24k = torch.tensor(prompt_wav_24k)[None, :].cuda()\n    prompt_wav_16k = torch.tensor(prompt_wav_16k)[None, :].cuda()\n    input_wav_24k = torch.tensor(input_wav_24k)[None, :].cuda()\n    input_wav_16k = torch.tensor(input_wav_16k)[None, :].cuda()\n\n    semantic_token = audio_tokenizer.tokenize(input_wav_16k)\n    prompt_semantic_token = audio_tokenizer.tokenize(prompt_wav_16k)\n\n    recon_wav = detokenize(audio_detokenizer, semantic_token, prompt_wav_24k, prompt_semantic_token)\n    print(recon_wav.shape)    \n    torchaudio.save(\"test/tmp_recon_en_prompt0.wav\", recon_wav.cpu(), 24000)\n\n    print(\"All tests passed!\")"
  },
  {
    "path": "test/test_audio_tokenizer.py",
    "content": "import sys\nsys.path.append('.')\nfrom modules.audio_tokenizer.audio_tokenizer import get_audio_tokenizer\nimport torch\n\nif __name__ == '__main__':\n    audio_tokenizer = get_audio_tokenizer()\n\n    input_wav = torch.zeros(1, 8000)\n    semantic_token = audio_tokenizer.tokenize(input_wav)\n    semantic_token = semantic_token.cpu().numpy().tolist()\n    assert semantic_token == [[ 765, 3512, 7469, 7469, 7028, 2567, 6008, 7469, 6217, 2567, 7649, 7469,\n         3292, 2567, 7649, 7469, 3292, 2567,  948, 7469, 3292, 2567,  948, 7469]]\n\n    print(\"All tests passed!\")"
  },
  {
    "path": "test/test_tokenizer.py",
    "content": "import sys\nsys.path.append('.')\nfrom modules.tokenizer.tokenizer import get_tokenizer_and_extra_tokens\n\n\n\nif __name__ == '__main__':\n    tokenizer, extra_tokens = get_tokenizer_and_extra_tokens()\n\n    assert tokenizer.encode(\"user\") == [1495]\n    assert tokenizer.decode([1495]) == \"user\"\n\n    assert tokenizer.encode(\"0\") == [501]\n    assert tokenizer.decode([501]) == \"0\"\n\n    assert tokenizer.encode(\"1\") == [503]\n    assert tokenizer.decode([503]) == \"1\"\n\n    assert tokenizer.encode(\"assistant\") == [110866]\n    assert tokenizer.decode([110866]) == \"assistant\"\n\n    assert tokenizer.encode(\"audio\") == [26229]\n    assert tokenizer.decode([26229]) == \"audio\"\n    \n\n    assert extra_tokens.msg_end == 260\n    assert extra_tokens.user_msg_start == 261\n    assert extra_tokens.assistant_msg_start == 262\n    assert extra_tokens.name_end == 272\n    assert extra_tokens.media_begin == 273\n    assert extra_tokens.media_content == 274\n    assert extra_tokens.media_end == 275\n\n    assert [tokenizer.convert_tokens_to_ids(i) for i in ['<0x0A>', '</s>', '[extra_id_0]']] == [14, 1, 260]\n\n    print(\"All tests passed!\")\n    "
  },
  {
    "path": "zh_llmprompt_script_gen.py",
    "content": "# INPUT -> BRIEF -> SCRIPT\n\n\nINPUT2BRIEF = '''\n### 任务说明\n请按照以下结构总结输入文件, 普通文本格式。总结应当有创造性，保证信息全面，包含所有有趣、不常见、有价值的观点和信息。\n- **文本要求**：\n    1. 直接输出结果，不要包含任何额外信息。\n    2. 总结文本用中文。允许少部分实体名词、专有名词、缩写等使用英文。\n    3. 不要包含任何数学公式。\n    4. 不要修改原文的任何实体名词、专有名词、缩写等。除非有常见译名，否则不要翻译实体名词。不要试图修改实体名词意思。\n    5. **请智慧地将简写中的数字转化。如简称里“a2b”实际代表“a to b”,而不是“a二b\"；简称里“a4b”实际代表“a for b”, 而不是“a四b\"; “v2”可能代表“version 二”, 也可以进一步翻译成“第二代”。请提供原始简称，和你认为合适的中文翻译。**\n\n### 标题和作者\n- **语言要求**：中文，书面语。\n- **内容要求**：提供文档的标题和作者。简要概括文档的主题和作者的背景。确保包含所有重要信息，不要有遗漏，尽可能保留足够的信息。\n\n### 摘要\n- **语言要求**：中文，书面语。\n- **内容要求**：\n    1. 本文做了什么事情。\n    2. 之前有没有别人做过这个事情。\n    3. 如果有别人做过，那本文为什么还需要做。\n    4. 本文具体怎么做的。\n    5. 本文做的怎么样。\n- **附加要求**：额外提供一个段落，解释本节中可能让听众困惑的术语、概念、方法等，确保不了解领域的读者也能理解。专有名词的解释需贴合原文，覆盖所有可能的困惑点，包括缩写名词、专有名词、实体名等。\n\n### 主要主题和概念\n- **语言要求**：中文，书面语。\n- **内容要求**：每个主题概念需按照3W原则组织，包括：\n    - **What**：界定问题，搞清楚问题是什么。\n    - **Why**：分析问题，结构化分析问题本质原因是什么。\n    - **How**：解决问题，文档如何解决问题。\n- **附加要求**：\n    1. 确保主题概念包含所有重要信息，不要有遗漏，主题概念需足够详细，充分阐述What和Why两个部分。\n    2. How部分不要包含数学公式等技术细节。要用大众理解的语言充分概括。\n    3. 各主题概念间不要互相重叠，保证逻辑清晰。\n    4. 额外提供一个段落，解释本节中可能让听众困惑的术语、概念、方法等，确保不了解领域的读者也能理解。专有名词的解释需贴合原文，覆盖所有可能的困惑点，包括缩写名词、专有名词、实体名等。\n\n### 重要引文\n- **语言要求**：中文，书面语。\n- **内容要求**：按照以下结构组织内容：\n    1. **论点**：需要证明什么。\n    2. **论据**：用于证明论点的材料。\n    3. **论证**：运用论据证明论点的过程。\n- **附加要求**：\n    1. 论据和论证思路需严格来源于原文，不要进行任何虚构。\n    2. 确保引文内容充分，不要有遗漏，尽可能保留足够的信息，不要进行任何精简。引文避免使用数学公式。\n    3. 额外提供一个段落，解释本节中可能让听众困惑的术语、概念、方法等，确保不了解领域的读者也能理解。专有名词的解释需贴合原文，覆盖所有可能的困惑点，包括缩写名词、专有名词、实体名等。\n\n### 总结\n- **语言要求**：中文，书面语。\n- **内容要求**：突出文档最重要、最吸引人眼球的部分。与摘要相比，需更结合主题概念的具体内容，对摘要进行补充。可包含未来改进方向、当前应用场景、当前存在问题等。\n'''\n\nBRIEF2SCRIPT = '''\n## 一、任务概述\n请根据提供的总结文本，和你对这方面了解的知识，生成一个生动的中文播客文字剧本。 剧本包含两位说话人交替发言。输出格式为 JSON 可解析的**列表**。列表里每条发言是一个**字典**，包含“speaker”和“text”字段。示例格式：`[{{\"speaker\": \"1\", \"text\": \"xxx\"}}]`。“speaker”字段是说话人身份（1表示主持人，2表示嘉宾），“text”字段是具体发言内容。输出直接从json的代码块开始，不要包含任何额外的信息。\n\n## 二、内容与结构要求\n### （一）文本内容\n- 总结性文本包含所有重要信息，需全面挑选并纳入剧本。\n- 通过两位说话人的对话形式展示信息，保持创作性，适当抽象不重要的细节。例如，听众不关心具体的测试名称，而关心测试的任务，结果和分析。\n### （二）结构设计\n- **开场白**：引入主题，简要介绍讨论内容，不提及说话人姓名。\n- **关键主题讨论**：逐字阅读总结文本，讨论重要主题。\n- **结束语**：简洁总结讨论亮点，并对未来或技术发展进行展望。\n\n## 三、语言风格\n- 文本要尽量口语化，接近自动语音识别的结果，包含填充词如“嗯”、“啊”、“呃”,\"呢\",\"这个\",\"其实\",\"就是\",\"然后\"等，响应词如\"嗯。\"或“是。”等。多用口语化的表达方式，允许重复，语法可以不那么正式。避免直接照搬总结文本里的书面语。不要用括号或语音识别通常不会出现的符号。 句中的空格代表短停顿，逗号表示稍长停顿，句号表示长停顿。可能存在因口音带来的同音识别错误。提问需要非常口语化。总之，就是要像平时聊天一样自然。示例如下：\n    [\n    {{  \"speaker\": \"0\",\n        \"text\": \"欢迎收听今天的播客。那我们这一集是要聊什么东西呢？\",\n    }},\n    {{  \"speaker\": \"1\",\n        \"text\": \"我们要聊星座。\",\n    }},\n    {{  \"speaker\": \"0\",\n        \"text\": \"星座嘛，就是，他是一个好跟新的朋友认识的时候一个聊天的话题。\",\n    }},\n    {{  \"speaker\": \"1\",\n        \"text\": \"没错，现我觉得在现在已经从你好，变成了诶，请问你的星座是什么呢？。\",\n    }},\n    {{  \"speaker\": \"0\",\n        \"text\": \"对，那我天枰座。\",\n    }},\n    {{  \"speaker\": \"1\",\n        \"text\": \"那，我是摩羯座。\",\n    }},\n    {{  \"speaker\": \"0\",\n        \"text\": \"摩羯座，那你会觉得就是星座，是一个可以相信的东西吗？\",\n    }},\n    {{  \"speaker\": \"1\",\n        \"text\": \"我本人其实不太相信星座诶，在一开始的时候。我就跟大部分不相信星座的一样，觉得，呃，你总能把人就分成十二种，然后呢就它讲的就是对的。\",\n    }},\n    {{  \"speaker\": \"0\",\n        \"text\": \"啊，所以就是，会觉得说把星座就是单纯把人分成十二种事件很粗略，不太有什么科学根据的事情。\",\n    }},\n    {{  \"speaker\": \"1\",\n        \"text\": \"嗯，对，会这样觉得。\",\n    }},\n    {{  \"speaker\": \"0\",\n        \"text\": \"嗯。\",\n    }},\n    {{  \"speaker\": \"1\",\n        \"text\": \"会无法理解，到底是，那这一开始定出这十二种人格的是谁啊？\",\n    }},\n    {{  \"speaker\": \"0\",\n        \"text\": \"对，就是凭什么他可以决定，我们就是这十二种人格。\",\n    }},\n  
  {{  \"speaker\": \"1\",\n        \"text\": \"嗯？\",\n    }},\n    {{  \"speaker\": \"0\",\n        \"text\": \"为什么不是十三、十四或者更多的种类。\",\n    }},\n    {{  \"speaker\": \"1\",\n        \"text\": \"对，没有错。\",\n    }},\n    {{  \"speaker\": \"0\",\n        \"text\": \"对。那，所以你会觉得说那种就是什么星座的心理分析是完全不可信的，还是其实也会很常去看一下，呃，类似的这种星座测验。\",\n    }},\n    {{  \"speaker\": \"1\",\n        \"text\": \"其实我刚说一开始不相信啊，我真的是到后期比较相信。然后后期会开始相信的是因为，呃，要去找一些我自己没有办法有方法去理解的人，因为认识那样子的人，他就是暧昧对象，必须要了解他到底是怎样的人，可是没有其他的依据的时候呢，我就偷偷开始看起了星座，然后就偷偷我觉得，好像讲得有那么一点准，然后就会开始看了。\",\n    }},\n    {{  \"speaker\": \"0\",\n        \"text\": \"哦，所以感觉有点像是说在从，星座的这种描述测验中去找说，你想要从这个东西，去对那个人有更深一层的了解的感觉。\",\n    }},\n    {{  \"speaker\": \"1\",\n        \"text\": \"对，而且通常他会讲到一两个你好你觉得好像是那样子的点，那你就会想要看更多，然后就好像就跟着就开始相信这个东西了。\",\n    }},\n    {{  \"speaker\": \"0\",\n        \"text\": \"哦，嗯，诶，所以你是什么什么星座的？\",\n    }},\n    {{  \"speaker\": \"1\",\n        \"text\": \"就我刚刚说我是摩羯座啊。\",\n    }}\n    ]\n\n### （二）标点符号\n- 使用中文标点符号，避免英文标点。\n- 剧本文本只使用逗号，句号和问号。禁止使用叹号。禁止使用省略号（'…'）、括号、引号（包括‘’“”'\"）或波折号，否则视为不合格。\n- 如果被对方的响应词等打断，本句句末是逗号，而不是句号。\n\n## 四、信息组织与逻辑\n### （一）引用问题\n- 由于听众看不到总结性文本，引用需确保上下文完整，确保听众能理解。\n- 避免直接复述，需用自己的话解释引用内容。\n- 总结文本里提供了对专业术语的解释。你需要保证你剧本里的专业术语尽可能被充分解释。专业术语的解释请具有创意，不要简单地创作成“这个是什么意思”这样的句子。可以通过举例、比喻等方式进行解释，但需要进一步说明比喻的合理性。可以由对方提问后进行解释，也可以自行解释。没有提到的专业名词不需要解释。提到的专业名词不一定要立即进行解释，可以和别的专业名词一起解释。总结文本中的专业术语可能与文字内容存在差异，你需要根据上下文合理解释。\n### （二）信息密度\n- 确保信息密度适中，避免过高或过低。适当的信息密度希望让没有相关背景知识的听众，快速理解文档里在做什么，为什么这么做，以及如何做。\n- 为了避免信息密度过高，剧本不能讨论数学公式、测试设置、实验指标等细节，而应该用简单概括性语言描述。\n- 为了避免信息密度过低，剧本每个主题需不少于4次发言，避免停留于关键词的简单罗列。会从尽可能从不同角度讨论，不局限于提供的总结文本。总结文本高度概括，剧本应当将其展开，讨论更多细节。你可以利用自己知识，补充背景知识，举例说明等方式，让听众更好地理解。\n- 提高信息密度技巧： \n\t1. 嵌入金句。在剧本中加入令人印象深刻，眼前一亮的句子，可以是自己创作，也可以是引用他人。\n    2. 增加知识点： 在剧本中适当增加知识点，能让听众听完更有收获。\n    3. 引入新信息：剧本中加入新的概念，引起用户好奇，特别是听众不知道但想知道的信息，这种非常重要。\n    4. 逆向思维： 加入不同角度的信息，打破用户熟悉的视角，提出不一样的观点。\n    5. 制造反差冲击： 剧本可以对用户熟知的认知进行非常规（出乎意料）但合理的描述，形成与他预期的反差，这种反差是信息密度。\n- 降低信息密度技巧：\n    1. 使用短句：简洁明了，易于理解，让叙述更紧凑。不要一句话里有过多的信息。\n    2. 描述细节：模糊不清，抽象的信息难以让听众建立认知，而细节越多，越能有画面感，容易阅读\n    3. 多进行场景化塑造： 场景是具象的，有画面的。 听众能轻松接收传达的信息，还能让人触景生情。\n    4. 多讲事实：讲事实才能更显真实，读的人才能更感同身受，这样文案信息密度更低。\n    5. 多讲故事：讲自己的故事，讲身边的故事，讲听说的故事，故事能把听众带入场景，更利于聚精会神地收听。\n    6. 多用动词和具体名词：动词和具体的名词更容易让听众浮现画面，而形容词会让复杂的文案更难理解。\n    7. 避免使用数学公式： 数学公式不利于大众理解。\n\n## 五、对话设计\n### (一) 说话人角色\n- 剧本中包含主持人和嘉宾。其中说话人1是主持人，负责节目开场和结束，擅长利用提问控制对话节奏，用生动的例子让知识不枯燥。说话人2是嘉宾，是主要负责文档内容的介绍，对该领域有惊人的知识储备，擅长有条理地语言组织，通俗地讲解内容。\n- 两位说话人热情开朗，喜欢结合个人故事或者实例进行讨论，给听众带来直观的体验。大家乐于讨论离题的故事。\n- 两位说话人积极互动，会经常用\"嗯\"等打断词表示对对方的认同。需要将响应词按照时间点插入对话。被打断前的句子句末用逗号，而不是句号。\n- 保证说话人角色统一，不要出现主持人介绍技术细节，或者引导主持人讨论主题等情况。\n- 主持人根据嘉宾的回答，逐步增加对该领域的认知。但主持人不一定立刻能理解，也不一定理解地完全正确。主持人可以表达不理解或者提出一些常人可能会存在的疑问。这种情况下，嘉宾会进一步用更通俗的语言解释，或者针对性地解答常人常有的疑问或者误解。这种互动相比于永远正确的主持人和嘉宾更加真实，也更利于观众地理解。\n\n### (二) 主题顺序安排\n- 主持人会根据总结性文本，将主题排列，并保证主题间有逻辑关联，如从整体过渡到细节，从细节过渡到整体，从原因过渡到结果，从技术过渡到应用等。\n- 主持人会引导对话节奏，按照总结性文本的主题顺序进行讨论。嘉宾不应该干扰主题过渡。\n\n### (三) 知识速率\n- 剧本中知识速率需要合理，不能短时间过快引入大量知识。知识不能突然增加，要逐渐引入，确保听众能够理解。\n- 听众视角：充分考虑听众感受，从听众视角进行剧本创作。必须保证剧本不包含详细数学公式，而应该用通俗的语言介绍。确保剧本内容易懂，不要过于专业化。\n- 无论是与主题相关的信息，还是离题的故事，都要按照你的知识进行充分地讨论，切忌简单地提一句而没有展开。要保证剧本足够真实，符合日常对话的逻辑，保证说话人间足够的尊重，不敷衍，不随意打断。\n\n## 六、其他要求\n### (一) 外语数字：\n  1. 剧本将用于中文播客内容的录制。请保证大部分外语和数字转换为中文，以便于模型能正确识别读音。\n  2. 请根据上下文，智慧地判断正确的读音。例如，“2021”如果表达年份，应当转换为“二零二一”。但如果表示数字，应当转换为“两千零二十一”。一些英文简称里常用数字代表英文单词，比如“a2b”代表“a to b”，“a4b”代表“a for b”，请保证不要简单转换为中文数字，而是根据上下文，将其翻译成合适的中文。\n  3. 对于一些不常见的英文简写，如果根据上下文判断读音需要逐个字母阅读，则须保证每个字母间留有空格，如“AI”添加空格为“A I”，以避免模型误认为是一个单词。除非实体名字有常见的中文翻译，否则不要翻译实体名字。\n### (二) 剧本长度\n  1. 
请控制\"text\"值的文本总长度不超过3000字符，且不超过60个发言，否则不合格。请选择技术细节，主题概念进行讨论。不要为了字数限制缩短每个话题讨论的深度，不要局限于总结文本，充分发挥你的知识。\n\nINPUT: {BRIEF}  \n \n再次强调：\n说话人1是主持人, 说话人2是嘉宾。说话人和嘉宾没有姓名。剧本文本只使用逗号，句号和问号。禁止使用叹号。禁止使用省略号（'…'）、括号、引号（包括‘’“”'\"）或波折号，否则视为不合格。请优先保证每个话题讨论的深度，不要局限于总结文本，利用你的知识，补充背景知识，举例说明等方式，让听众更好地理解。\n请保证大部分外语和数字转换为中文，以便于模型能正确识别读音。在技术文档里，英文简称常用数字代表英文单词，比如“a2b”代表“a to b”，“a4b”代表“a for b”，请保证不要简单转换为中文数字，而是根据上下文，将其翻译成合适的中文。\n\nOUTPUT:\n'''\n"
  }
]