Full Code of jzq2000/MoonCast for AI

65.2k tokens, 298 symbols
Note: this preview is truncated at 280K characters.
Repository: jzq2000/MoonCast
Branch: main
Commit: 07037e58f427
Files: 43
Total size: 252.7 KB

Directory structure:
MoonCast/

├── .gitignore
├── LICENSE
├── app.py
├── download_pretrain.py
├── en_llmprompt_script_gen.py
├── inference.py
├── modules/
│   ├── audio_detokenizer/
│   │   ├── audio_detokenizer.py
│   │   ├── bigvgan_wrapper.py
│   │   ├── flow_matching/
│   │   │   ├── dit_block.py
│   │   │   ├── model.py
│   │   │   ├── ode_wrapper.py
│   │   │   └── scheduler.py
│   │   ├── semantic_fm_prefix_streaming.py
│   │   └── vocoder/
│   │       ├── activations.py
│   │       ├── alias_free_activation/
│   │       │   ├── __init__.py
│   │       │   ├── cuda/
│   │       │   │   ├── __init__.py
│   │       │   │   ├── activation1d.py
│   │       │   │   ├── anti_alias_activation.cpp
│   │       │   │   ├── anti_alias_activation_cuda.cu
│   │       │   │   ├── compat.h
│   │       │   │   ├── load.py
│   │       │   │   └── type_shim.h
│   │       │   └── torch/
│   │       │       ├── __init__.py
│   │       │       ├── act.py
│   │       │       ├── filter.py
│   │       │       └── resample.py
│   │       ├── bigvgan.py
│   │       └── utils.py
│   ├── audio_tokenizer/
│   │   ├── audio_tokenizer.py
│   │   ├── quantize/
│   │   │   ├── __init__.py
│   │   │   ├── factorized_vector_quantize.py
│   │   │   ├── residual_vq.py
│   │   │   └── vector_quantize.py
│   │   ├── rep_codec.py
│   │   ├── transformer.py
│   │   └── vocos.py
│   └── tokenizer/
│       └── tokenizer.py
├── readme.md
├── requirements.txt
├── test/
│   ├── test_audio_detokenizer.py
│   ├── test_audio_tokenizer.py
│   └── test_tokenizer.py
└── zh_llmprompt_script_gen.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
*.safetensors
*.pt
*.vscode
**/__pycache__/
modules/audio_detokenizer/vocoder/alias_free_activation/cuda/build/
tmp*
resources/
*.gradio

================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2025 Zeqian Ju

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

================================================
FILE: app.py
================================================
import gradio as gr
from huggingface_hub import snapshot_download 
snapshot_download(repo_id="jzq11111/mooncast", local_dir='./resources/')

from inference import Model
import base64
import ast

model = Model()
model.generate_config.max_new_tokens = 50 * 50 # no more than 50s per turn
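# Rule of thumb (a sketch, assuming the text2semantic model emits ~50 semantic
# tokens per second of audio, consistent with the comment above):
#   max_new_tokens = 50 tokens/s * max seconds per turn
# e.g. model.generate_config.max_new_tokens = 50 * 30 would cap turns at ~30s.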


def process_json_and_generate_audio(prompt_audio_role0_file, prompt_text_role0, prompt_audio_role1_file, prompt_text_role1, json_dialogue_input_str):
    try:
        print(json_dialogue_input_str, type(json_dialogue_input_str))
        print(prompt_audio_role0_file, prompt_text_role0, prompt_audio_role1_file, prompt_text_role1)
        # Parse the dialogue script. ast.literal_eval is used instead of
        # json.loads because the bundled example scripts contain trailing
        # commas, which Python literals allow but strict JSON does not.
        # Unlike bare eval, literal_eval only accepts literal structures,
        # so it is safe on user input.
        json_data = ast.literal_eval(json_dialogue_input_str.strip())
        print(json_data, type(json_data))    

        def validate_json(data):
            try:
                if not isinstance(data, list):
                    return "input must be a JSON list of dialogue turns"
                cur_spk_should_be = 0
                for item in data:
                    if item['role'] != str(cur_spk_should_be):
                        return f"role should be {cur_spk_should_be} in item {item}"
                    cur_spk_should_be = 1 - cur_spk_should_be
                return None 
            except Exception as e:
                return str(e)


        validation_error = validate_json(json_data)
        if validation_error:
            raise gr.Error(validation_error)
        
        role_mapping = {
            "0": {
                "ref_audio": prompt_audio_role0_file,
                "ref_text": prompt_text_role0, 
            },
            "1": {
                "ref_audio": prompt_audio_role1_file, 
                "ref_text": prompt_text_role1,
            }
        }

        # Full model input JSON (adjust to match your model's expected schema)
        model_input_json = {
            "role_mapping": role_mapping,
            "dialogue": json_data, # dialogue taken from the user-supplied JSON
        }
        print("Model inference input JSON:", model_input_json)


        # [Important] Call the `inference` method of the Model class.
        # Non-streaming variant:
        # audio_bytes = model.inference(model_input_json)
        # return base64.b64decode(audio_bytes)

        # Streaming variant: yield decoded audio chunks; Gradio plays them as they arrive.
        for cur_chunk in model.inference(model_input_json, streaming=True):
            yield base64.b64decode(cur_chunk)

    except Exception as e:
        # return str(e) # alternatively, return the error message to Gradio
        raise gr.Error(str(e))

title_en = "# PODCAST generator (supports English and Chinese)"
title_zh = "# 播客生成 (支持英文和中文)"

instruct_en = "## See [Github](https://github.com/jzq2000/MoonCast) for podcast script generation."
instruct_zh = "## 播客剧本生成请参考 [Github](https://github.com/jzq2000/MoonCast)。"

input_labels_en = ["Prompt Audio for Role 0", "Prompt Text for Role 0", "Prompt Audio for Role 1", "Prompt Text for Role 1", "Script JSON Input"]
input_labels_zh = ["角色 0 的 Prompt 音频", "角色 0 的 Prompt 文本", "角色 1 的 Prompt 音频", "角色 1 的 Prompt 文本", "剧本 JSON 输入"]

output_label_en = "Generated Audio Output (streaming)"
output_label_zh = "生成的音频输出(流式)"

example_prompt_text_role0_en = "Yeah, no, this is my backyard. It's never ending So just the way I like it. So social distancing has never been a problem."
example_prompt_text_role0_zh = "可以每天都骑并且可能会让你爱上骑车,然后通过爱上骑车的你省了很多很多钱。"
example_prompt_text_role1_en = "I'm doing great And. Look, it couldn't be any better than having you at your set, which is the outdoors."
example_prompt_text_role1_zh = "他最后就能让同样食材炒出来的菜味道大大提升。"

text_placeholder_zh = "对话轮流进行, 每轮最多50秒。文本越自然, 生成的音频效果越好。"
text_placeholder_en = "Dialogue alternates between roles. Limit each turn to a maximum of 50 seconds. The more natural the text, the better the generated audio."


example_json_en = '''[
       {
            "role": "0",
            "text": "In an awesome time, And, we're even gonna do a second episode too So. This is part one part two, coming at some point in the future There. We are.",
        },
       {
            "role": "1",
            "text": "I love it. So grateful Thank you So I'm really excited. That's awesome. Yeah.",
       },
       {
            "role": "0",
            "text": "All I was told, which is good because I don't want to really talk too much more is that you're really really into fitness and nutrition And overall holistic I love it Yes.",
       },
        {
            "role": "1",
            "text": "Yeah So I started around thirteen Okay But my parents were fitness instructors as well. Awesome So I came from the beginning, and now it's this transition into this wholeness because I had to chart my. Own path and they weren't into nutrition at all So I had to learn that part."
        }
]'''
example_json_zh = '''[
        {
            "role": "0",
            "text": "我觉得啊,就是经历了这么多年的经验, 就是补剂的作用就是九分的努力, 十分之一的补剂。 嗯,选的话肯定是九分更重要,但是我觉得补剂它能够让你九分的努力更加的有效率,更加的避免徒劳无功。 嗯,就是你,你你得先得真的锻炼,真的努力,真的健康饮食,然后再考虑补剂, 那你再加十十分之一的补剂的话,他可能就是说啊, 一半是心理作用,"
        },
        {
            "role": "1",
            "text": "对,其实很多时候心理作用是非常重要的。嗯,然后我每次用补剂的时候,我就会更加努力,就比如说我在健身之前我喝了一勺蛋白粉,我就会督促自己多练,"
        },
        {
            "role": "0",
            "text": "其实心理作用只要能实现你的预期目的就可以了。 就比如说给自行车链条加油, 它其实不是必要的,但是它可以让你骑行更顺畅, 然后提高你骑行的频率。"
        }   
    ]
'''

# examples_en = [
#     ['./en_prompt0.wav', example_prompt_text_role0_en, './en_prompt1.wav', example_prompt_text_role1_en, example_json_en]
# ]
# examples_zh = [
#     ['./zh_prompt0.wav', example_prompt_text_role0_zh, './zh_prompt1.wav', example_prompt_text_role1_zh, example_json_zh]
# ]

examples = [
    ['./en_prompt0.wav', example_prompt_text_role0_en, './en_prompt1.wav', example_prompt_text_role1_en, example_json_en],
    ['./zh_prompt0.wav', example_prompt_text_role0_zh, './zh_prompt1.wav', example_prompt_text_role1_zh, example_json_zh]
]

# -------------------- Function that updates UI element labels --------------------
def update_ui_language(language):
    if language == "English":
        return  gr.update(value=title_en), \
                gr.update(value=instruct_en), \
                gr.update(label="UI Language"), \
                gr.update(label=input_labels_en[0]), \
                gr.update(label=input_labels_en[1]), \
                gr.update(label=input_labels_en[2]), \
                gr.update(label=input_labels_en[3]), \
                gr.update(label=input_labels_en[4], placeholder=text_placeholder_en), \
                gr.update(label=output_label_en), \
                gr.update(value="Submit"), \
                gr.update(label="Examples (Demonstration Use Only. Do Not Redistribute.)", headers=input_labels_en)
    
    elif language == "中文":
        return  gr.update(value=title_zh), \
                gr.update(value=instruct_zh), \
                gr.update(label="UI 语言"), \
                gr.update(label=input_labels_zh[0]), \
                gr.update(label=input_labels_zh[1]), \
                gr.update(label=input_labels_zh[2]), \
                gr.update(label=input_labels_zh[3]), \
                gr.update(label=input_labels_zh[4], placeholder=text_placeholder_zh), \
                gr.update(label=output_label_zh), \
                gr.update(value="提交"), \
                gr.update(label="示例 (仅用于展示,切勿私自传播。)", headers=input_labels_zh)

    else:
        raise ValueError("Invalid language selected")
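
# Note: the gr.update tuple returned by update_ui_language must match, in order,
# the `outputs` list passed to language_choice.change(...) at the bottom of this file.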


audio_output = gr.Audio(label=output_label_en, streaming=True) 
css = """
.centered-title { /* CSS rule for centering title */
    text-align: center !important;
}
"""
# -------------------- Gradio UI definition --------------------
with gr.Blocks(css=css) as iface:

    title_output = gr.Markdown(value=title_zh, elem_classes="centered-title")
    instruct_output = gr.Markdown(value=instruct_zh)
    language_choice = gr.Radio(["中文", "English"], value="中文", label="UI语言") 

    with gr.Row(): # Main row to create two columns
        with gr.Column(scale=2): 
            json_input = gr.TextArea(label=input_labels_zh[4], lines=15, placeholder=text_placeholder_zh) # Dialogue JSON Input

        with gr.Column(scale=1): # Right column (narrower - scale=1) for prompt inputs
            audio_input_role0 = gr.Audio(type="filepath", label=input_labels_zh[0]) # Prompt Audio for Role 0
            text_input_role0 = gr.TextArea(label=input_labels_zh[1], lines=2) # Prompt Text for Role 0

        with gr.Column(scale=1): # 
            audio_input_role1 = gr.Audio(type="filepath", label=input_labels_zh[2]) # Prompt Audio for Role 1
            text_input_role1 = gr.TextArea(label=input_labels_zh[3], lines=2) # Prompt Text for Role 1

    examples_component = gr.Examples(
        examples=examples,
        inputs=[audio_input_role0, text_input_role0, audio_input_role1, text_input_role1, json_input],
        cache_examples=False,
        label="示例(仅用于展示,切勿私自传播。)",
    )
    
    submit_button = gr.Button("提交")
    
    submit_button.click(
        fn=process_json_and_generate_audio,
        inputs=[audio_input_role0, text_input_role0, audio_input_role1, text_input_role1, json_input],
        outputs=audio_output
    )
    audio_output.render()
    
    language_choice.change(
        fn=update_ui_language,
        inputs=language_choice,
        outputs=[title_output, instruct_output, language_choice, audio_input_role0, text_input_role0, audio_input_role1, text_input_role1, json_input, audio_output, submit_button, examples_component.dataset]
    )


iface.launch(share=True)


================================================
FILE: download_pretrain.py
================================================
from huggingface_hub import snapshot_download
snapshot_download(repo_id="jzq11111/mooncast", local_dir='./resources/')

================================================
FILE: en_llmprompt_script_gen.py
================================================
# INPUT -> BRIEF -> SCRIPT


INPUT2BRIEF = '''
### Task Description  
Please summarize the input document in plain text format according to the following structure. The summary should be creative, comprehensive, and include all interesting, uncommon, and valuable viewpoints and information.

- **Text Requirements**:  
    1. Directly output the result without any additional information.  
    2. The summary should be in English. Retain a small number of proper nouns, names, and abbreviations in their original form (e.g., Chinese characters).  
    3. Do not include any mathematical formulas.  
    4. Do not alter any proper nouns, names, or abbreviations from the original text. Unless there is a common translation, do not translate proper nouns. Do not attempt to modify the meaning of proper nouns.  
    5. **Intelligently convert numbers in abbreviations. For example, "a2b" should be interpreted as "a to b," not "a two b"; "a4b" as "a for b," not "a four b"; "v2" may represent "version two" or "second generation." Provide the original abbreviation and your suggested English translation.**  

### Title and Author  
- **Language Requirements**: English, formal written language.  
- **Content Requirements**: Provide the title and author of the document. Briefly summarize the theme of the document and the author's background. Ensure all important information is included without omission and sufficient context is retained.  

### Abstract  
- **Language Requirements**: English, formal written language.  
- **Content Requirements**:  
    1. What this document has done.  
    2. Whether similar work has been done before.  
    3. If similar work exists, why this document is still necessary.  
    4. How this document specifically addresses the topic.  
    5. How well this document achieves its goals.  
- **Additional Requirements**: Include an additional paragraph to explain any terms, concepts, or methods that may confuse readers unfamiliar with the field. Ensure proper nouns are explained consistently with the original text, covering all potential points of confusion, including abbreviations and entity names.  

### Main Themes and Concepts  
- **Language Requirements**: English, formal written language.  
- **Content Requirements**: Each theme and concept should be organized according to the 3W principle:  
    - **What**: Clearly define the problem.  
    - **Why**: Analyze the problem and identify its root causes.  
    - **How**: Explain how the document addresses the problem.  
- **Additional Requirements**:  
    1. Ensure each theme and concept is comprehensive and includes all important details. Fully elaborate on the "What" and "Why" sections.  
    2. Avoid technical details such as mathematical formulas in the "How" section. Use language that is easily understood by a general audience.  
    3. Ensure themes and concepts do not overlap and maintain clear logic.  
    4. Include an additional paragraph to explain any terms, concepts, or methods that may confuse readers unfamiliar with the field. Ensure proper nouns are explained consistently with the original text, covering all potential points of confusion, including abbreviations and entity names.  

### Key Citations  
- **Language Requirements**: English, formal written language.  
- **Content Requirements**: Organize the content according to the following structure:  
    1. **Argument**: State what needs to be proven.  
    2. **Evidence**: Provide the material used to support the argument.  
    3. **Reasoning**: Describe the process of using evidence to prove the argument.  
- **Additional Requirements**:  
    1. Ensure all evidence and reasoning are directly sourced from the original text without fabrication.  
    2. Ensure citation content is complete and retains sufficient context without simplification. Avoid using mathematical formulas in citations.  
    3. Include an additional paragraph to explain any terms, concepts, or methods that may confuse readers unfamiliar with the field. Ensure proper nouns are explained consistently with the original text, covering all potential points of confusion, including abbreviations and entity names.  

### Conclusion  
- **Language Requirements**: English, formal written language.  
- **Content Requirements**: Highlight the most important and impactful aspects of the document. Compared to the abstract, this section should provide more detailed insights related to the main themes and concepts. It may also include future directions for improvement, current application scenarios, and existing challenges.  
'''

BRIEF2SCRIPT = '''
## 1. Task Overview

Please generate a lively English podcast script based on the provided English summary text and your knowledge of the topic. The script should feature a dialogue between two speakers who take turns speaking. The output must be a JSON-parsable **list**. Each speaker's turn is a **dictionary** containing "speaker" and "text" fields. Example format: `[{{"speaker": "1", "text": "xxx"}}]`. The "speaker" field indicates the speaker's identity (1 for host, 2 for guest), and the "text" field is the spoken content. The output should start directly with the JSON code block, without any extra information.

## 2. Content and Structure 
### (1) Text Content
- The summary text contains all important information, which needs to be comprehensively selected and incorporated into the script.
- Present information through a dialogue between two speakers, maintaining creativity and abstracting away unimportant details. For example, listeners aren't concerned with specific test names, but rather the task itself, the results, and the analysis.
### (2) Structure Design
- **Opening:** Introduce the topic and briefly describe the discussion content, without mentioning speaker names.
- **Key Theme Discussion:**  Discuss important themes based on the summary text.  Expand on the summary, don't just repeat it verbatim.
- **Closing:** Briefly recap the discussion highlights and offer an outlook on future or technological developments.

## 3. Language Style
### (1) Conversational Style
- The text should be as conversational as possible, aiming for a style similar to automatic speech recognition output. Include filler words such as 'um,' 'uh,' 'like,' 'you know,' 'so,' and 'right?'; response words such as 'Yeah,' 'Right,' and 'Okay'; and conversational expressions, repetitions, and informal grammar. Use short sentences. Avoid directly copying and pasting structured text from the summary text. Parentheses and other symbols not typically found in speech recognition transcripts should be avoided. Spaces within sentences indicate pauses. Be aware that there might be homophone errors, potentially due to accents. Questions should sound very conversational. Pay particular attention to incorporating conversational details, especially in questions. For example:
    [
    {{  "speaker": "1", 
        "text": "Welcome back to the podcast, everyone. Today we're diving into, uh, something that's really changing everything around us, A I."
    }},
    {{  "speaker": "2", 
        "text": "Yeah, A I is, like, everywhere now, isn't it?  It's kinda wild to think about."
    }},
    {{  "speaker": "1", 
        "text": "Totally.  And we're seeing it in so many areas of daily life.  Like, even just recommending what to watch, or, you know, suggesting products online."
    }},
    {{  "speaker": "2", 
        "text": "Mhm, exactly.  And it's not just online stuff, right? Think about smart homes, or even self-driving cars.  It's getting pretty advanced."
    }},
    {{  "speaker": "1", 
        "text": "Right, self-driving cars are still a bit futuristic for most of us, but, uh, even things like voice assistants on our phones, that's A I, isn't it?"
    }},
    {{  "speaker": "2", 
        "text": "Definitely.  Siri, Alexa, Google Assistant, all powered by A I.  It's become so normal, we almost don't even think about it anymore."
    }},
    {{  "speaker": "1", 
        "text": "Yeah, it's like, integrated into everything.  But is that a good thing, you think?  Like, are there downsides to all this A I in our lives?"
    }},
    {{  "speaker": "2", 
        "text": "Well, that's the big question, isn't it?  On the one hand, it makes things so much more convenient, saves us time, maybe even makes things safer in some ways."
    }},
    {{  "speaker": "1", 
        "text": "Safer how?"
    }},
    {{  "speaker": "2", 
        "text": "Uh, well, like in healthcare, for example.  A I can help doctors diagnose diseases earlier, maybe even more accurately. That's a huge plus, right?"
    }},
    {{  "speaker": "1", 
        "text": "Yeah, that's a really good point.  Medical applications are definitely exciting.  But what about the concerns, you know?  Like job displacement or privacy issues?"
    }},
    {{  "speaker": "2", 
        "text": "Right, those are super valid concerns.  Job displacement is a big one. If A I can do more and more tasks, what happens to human workers?  And privacy,"
    }},
    {{  "speaker": "1", 
        "text": "And privacy is huge, especially with all the data A I systems collect.  It's a lot to process."
    }},
    {{  "speaker": "2", 
        "text": "Exactly.  So, it's not just sunshine and roses, is it?  We need to be mindful of the ethical implications and make sure it's used responsibly."
    }},
    {{  "speaker": "1", 
        "text": "Definitely.  It's a powerful tool, but like any tool, it can be used for good or, you know, not so good.  It's up to us to guide its development, right?"
    }},
    {{  "speaker": "2", 
        "text": "Absolutely.  And that's a conversation we all need to be part of, not just the tech people, but everyone."
    }}
    ]

### (2) Punctuation
- Use English punctuation marks. Avoid using punctuation marks beyond commas, periods, and question marks. Exclamation points are prohibited. Ellipses ('…'), parentheses, quotation marks (including ‘ ' “ ” ") or dashes are prohibited, otherwise the output will be considered unqualified. Do not use markdown syntax; for example, **bold** or *italic* text should be avoided. Use plain text only.
- If interrupted by the other person's response, the sentence should end with a comma, not a period.

## 4. Information Organization and Logic
### (1) Referencing Issues
- Given that listeners won't have access to the summary text, any references must provide sufficient context for comprehension.
- Avoid simply paraphrasing; instead, explain referenced content in your own words.
- Explanations of technical terms should be creative and avoid simply stating 'this means what?' You can use examples, metaphors, and so on for explanations, but ensure you also clarify the rationale behind the metaphor. Explanations can be provided in response to a question from the other speaker, or you can offer explanations proactively. Technical terms that are not mentioned don't need explanation.  Technical terms that are mentioned don't necessarily need immediate explanation; they can be explained alongside other technical terms. Technical terms in the summary text might differ slightly from the surrounding text; you'll need to provide reasonable explanations based on the context.
### (2) Information Density
- Ensure moderate information density, avoiding excessively high or low density. The goal of appropriate information density is to enable listeners without prior knowledge to quickly grasp the document's purpose, rationale, and methodology.
- To prevent information overload, the script should avoid delving into details like mathematical formulas, test setups, or specific experimental metrics. Instead, it should use simple, generalized language for descriptions.
- To avoid excessively low information density, ensure each topic is discussed for at least 4 speaker turns, moving beyond simple keyword listings. Discuss topics from multiple angles whenever possible, going beyond the provided summary text. Given that the summary text is highly generalized, the script should elaborate on it and discuss further details. Feel free to use your knowledge to supplement background information, provide examples, and so forth, to enhance listener understanding.
- Techniques to increase information density:
	1. Incorporate memorable quotes. Add impactful, attention-grabbing sentences to the script, either original ones or quotes from other sources.
    2. Boost knowledge content.  Judiciously add knowledge points to the script to make listeners feel more informed and rewarded.
    3. Introduce novel information. Incorporate new concepts to spark listener curiosity, particularly information they're unaware of but would find valuable. This is crucial.
    4. Employ reverse thinking. Include information from diverse angles, challenging listeners' existing perspectives and presenting alternative viewpoints.
    5. Generate contrast and impact. The script can offer unconventional (yet plausible) descriptions of familiar concepts to create a contrast with listener expectations.  This contrast contributes to information density.
- Techniques to decrease information density:
    1. Use short sentences: Concise and easy to understand, making the narrative more compact. Do not have too much information in one sentence.
    2. Describe details: Vague and abstract information makes it difficult for listeners to build understanding, while more details create a sense of imagery and are easier to read.
    3. Use more scenario-based descriptions: Scenarios are concrete and visual. Listeners can easily receive the conveyed information and be emotionally touched.
    4. Talk more about facts: Talking about facts makes it more real, and readers can empathize more, thus lowering the information density of the copy.
    5. Tell more stories: Tell your own stories, stories around you, and stories you've heard. Stories can bring listeners into the scene, making it easier to concentrate on listening.
    6. Use more verbs and concrete nouns: Verbs and concrete nouns make it easier for listeners to visualize, while adjectives make complex copy harder to understand.
    7. Avoid using mathematical formulas: Mathematical formulas are not conducive to public understanding.

## 5. Dialogue Design
### (1) Speaker Roles
- The script includes a host and a guest. Speaker 1 is the host, responsible for opening and closing the show, skilled at using questions to control the pace of the conversation, and using vivid examples to make knowledge less dry. Speaker 2 is the guest, primarily responsible for introducing the document content, has amazing knowledge reserves in the field, and is good at organizing language in a structured and easy-to-understand way.
- Both speakers are enthusiastic and cheerful, like to combine personal stories or examples for discussion, and bring a direct experience to listeners. They are happy to discuss digressive stories.
- The two speakers actively interact and frequently use interruption words such as "um" to indicate agreement with each other. Response words need to be inserted into the dialogue according to the timing. Sentences before being interrupted end with a comma, not a period.
- Ensure consistent speaker roles. Do not have the host introduce technical details, or have the guest guide the host to discuss topics.
- The host gradually increases their understanding of the field based on the guest's answers. However, the host may not understand immediately or completely correctly. The host can express misunderstanding or raise some questions that ordinary people might have. In this case, the guest will further explain in more accessible language, or specifically answer common questions or misunderstandings. This kind of interaction is more realistic and easier for listeners to understand than always correct hosts and guests.
### (2) Topic Order Arrangement
- The host will arrange the topics according to the summary text and ensure logical connections between topics, such as transitioning from overall to details, from details to overall, from cause to effect, from technology to application, etc.
- The host will guide the pace of the conversation and discuss topics in the order of the summary text. Guests should not interfere with topic transitions.
### (3) Knowledge Rate
- The knowledge rate in the script needs to be reasonable. Do not introduce a large amount of knowledge too quickly in a short period of time. Knowledge points should be spaced out and interleaved with explanations, examples, and responses.

## 6. Other Requirements
### (1) English Numbers and Foreign Words
  1. The script will be used for English podcast content recording. Please ensure most numbers and foreign words are rendered naturally in English to facilitate correct pronunciation.
  2. Please intelligently determine the correct pronunciation according to the context. For example, if "2021" expresses a year, it should be converted to "two thousand and twenty-one" or "twenty twenty-one"; if it expresses a plain number, it should be "two thousand and twenty-one". For uncommon English abbreviations that need to be read letter by letter according to the context, ensure there is a space between each letter, such as "AI" written as "A I", to avoid the model misinterpreting it as a word. For example, "API" should be rendered as "A P I".
  3. Small amount of Chinese is allowed, especially for nouns, if it fits naturally within the conversational English context.
### (2) Script Length
  1. Please ensure that the total length of the 'text' values does not exceed 3,000 words and that the number of speaker turns stays within 60; otherwise the script will be considered unqualified. Choose which technical details and topic concepts to discuss. Do not shorten the depth of discussion on each topic for the sake of the word limit, do not be limited to the summary text, and give full play to your knowledge.

INPUT: {BRIEF}

## Re-emphasize:
Speaker 1 is the host, and Speaker 2 is the guest. Neither speaker has a name. The script text only uses commas, periods, and question marks. Use English punctuation marks. Avoid using other punctuation marks beyond commas, periods, and question marks. Exclamation points are prohibited.  Ellipses ('…'), parentheses, quotation marks (including ‘ ' “ ” ") or dashes are prohibited, otherwise it will be considered unqualified.  Please prioritize in-depth discussion for each topic. Don't limit yourself to the summary text; instead, use your knowledge to expand upon the topics, providing background information and illustrative examples to enhance listener understanding.
Ensure that numbers and foreign words are rendered naturally in English for accurate pronunciation during recording. In technical contexts, English abbreviations sometimes use numerical digits in place of words (e.g., "a2b" for "a to b," "a4b" for "a for b"). Please translate these abbreviations into appropriate English phrases based on the context. While the script is primarily in English, a small amount of Chinese, especially for nouns, is acceptable if it integrates naturally into the conversational flow.

OUTPUT:
'''
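
# --- Usage sketch: a minimal illustration of the INPUT -> BRIEF -> SCRIPT
# pipeline noted at the top of this file. First summarize the source document
# with INPUT2BRIEF, then expand the brief into a dialogue script with
# BRIEF2SCRIPT. `call_llm` is a hypothetical stand-in for whatever
# chat-completion client you use. ---
import json

def generate_script(document: str, call_llm) -> list:
    # Step 1: document -> brief. INPUT2BRIEF has no placeholder, so the
    # document is appended after the instructions (an assumption; adapt as needed).
    brief = call_llm(INPUT2BRIEF + "\n\nINPUT:\n" + document)
    # Step 2: brief -> script. BRIEF2SCRIPT escapes its literal braces as {{ }},
    # so str.format can substitute the brief at the {BRIEF} placeholder.
    script_str = call_llm(BRIEF2SCRIPT.format(BRIEF=brief))
    # Expected result: [{"speaker": "1", "text": "..."}, ...]
    return json.loads(script_str)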


================================================
FILE: inference.py
================================================

import sys
sys.path.append(".")
from modules.tokenizer.tokenizer import get_tokenizer_and_extra_tokens
from modules.audio_tokenizer.audio_tokenizer import get_audio_tokenizer
from modules.audio_detokenizer.audio_detokenizer import get_audio_detokenizer, detokenize, detokenize_noref, detokenize_streaming, detokenize_noref_streaming
import torch
import os
from glob import glob
import base64
import io
import torchaudio
from transformers import AutoModelForCausalLM, GenerationConfig
import librosa
from tqdm import tqdm
from pydub import AudioSegment

class Model(object):
    def __init__(self):

        
        self.tokenizer, self.extra_tokens = get_tokenizer_and_extra_tokens()
        self.speech_token_offset = 163840
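        # Semantic audio tokens share the LM vocabulary at ids >= 163840:
        # reference audio tokens are shifted up by this offset before being fed
        # to the LM, and generated ids are shifted back down before detokenizing.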
        print(self.extra_tokens)
        self.assistant_ids = self.tokenizer.encode("assistant") # [110866]
        self.user_ids = self.tokenizer.encode("user") # [1495]
        self.audio_ids = self.tokenizer.encode("audio") # [26229]
        self.spk_0_ids = self.tokenizer.encode("0") # [501] 
        self.spk_1_ids = self.tokenizer.encode("1") # [503] 

        self.msg_end = self.extra_tokens.msg_end # 260
        self.user_msg_start = self.extra_tokens.user_msg_start # 261
        self.assistant_msg_start = self.extra_tokens.assistant_msg_start # 262
        self.name_end = self.extra_tokens.name_end # 272
        self.media_begin = self.extra_tokens.media_begin # 273
        self.media_content = self.extra_tokens.media_content # 274
        self.media_end = self.extra_tokens.media_end # 275

        self.audio_tokenizer =  get_audio_tokenizer()
        self.audio_detokenizer = get_audio_detokenizer()
        model_path = "resources/text2semantic"
        self.model =  AutoModelForCausalLM.from_pretrained(model_path, device_map="cuda:0", torch_dtype=torch.bfloat16, trust_remote_code=True, force_download=True).to(torch.cuda.current_device())
        self.generate_config = GenerationConfig(
            max_new_tokens=200 * 50, # no more than 200s per turn
            do_sample=True,
            top_k=30,
            top_p=0.8,
            temperature=0.8,
            eos_token_id=self.media_end,
        )
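        # eos_token_id is the media_end token, so generation stops at the end of
        # each audio segment. At ~50 semantic tokens per second, 200 * 50 new
        # tokens corresponds to the 200s-per-turn cap noted above.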
    
    def _clean_text(self, text):
        # you can add front-end processing here
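        # Drops curly double quotes, normalizes curly single quotes, strips
        # asterisks, turns ellipses into pauses (spaces), and softens colons
        # into commas, e.g. _clean_text('“Hi…” ok:') -> 'Hi  ok,'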
        text = text.replace("“", "")
        text = text.replace("”", "")
        text = text.replace("...", " ")
        text = text.replace("…", " ")
        text = text.replace("*", "")
        text = text.replace(":", ",")
        text = text.replace("‘", "'")
        text = text.replace("’", "'")
        text = text.strip()
        return text

    @torch.inference_mode()
    def _process_text(self, js):

        if "role_mapping" in js:
            for role in js["role_mapping"].keys():
                js["role_mapping"][role]["ref_bpe_ids"] = self.tokenizer.encode(self._clean_text(js["role_mapping"][role]["ref_text"]))
                
        for turn in js["dialogue"]:
            turn["bpe_ids"] = self.tokenizer.encode(self._clean_text(turn["text"]))
        return js
        
    def inference(self, js, streaming=False):
        js = self._process_text(js)
        if "role_mapping" not in js:
            if streaming:
                return self.infer_without_prompt_streaming(js)
            else:
                return self.infer_without_prompt(js)
        else:
            if streaming:
                return self.infer_with_prompt_streaming(js)
            else:
                return self.infer_with_prompt(js)      
    
    @torch.inference_mode()
    def infer_with_prompt(self, js):
        user_role_0_ids = [self.user_msg_start] + self.user_ids + self.spk_0_ids  + [self.name_end]
        user_role_1_ids = [self.user_msg_start] + self.user_ids + self.spk_1_ids  + [self.name_end]
        assistant_role_0_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_0_ids + [self.name_end]
        assistant_role_1_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_1_ids + [self.name_end]

        media_start = [self.media_begin] + self.audio_ids + [self.media_content]
        media_end = [self.media_end] + [self.msg_end]

        assistant_role_0_ids = torch.LongTensor(assistant_role_0_ids).unsqueeze(0).to(torch.cuda.current_device())
        assistant_role_1_ids = torch.LongTensor(assistant_role_1_ids).unsqueeze(0).to(torch.cuda.current_device())
        media_start = torch.LongTensor(media_start).unsqueeze(0).to(torch.cuda.current_device())
        media_end = torch.LongTensor(media_end).unsqueeze(0).to(torch.cuda.current_device())
        

        prompt = []
        cur_role_dict = dict()
        for role, role_item in js["role_mapping"].items():
            waveform_24k = librosa.load(role_item["ref_audio"], sr=24000)[0]
            waveform_24k = torch.tensor(waveform_24k).unsqueeze(0).to(torch.cuda.current_device())

            waveform_16k = librosa.load(role_item["ref_audio"], sr=16000)[0]
            waveform_16k = torch.tensor(waveform_16k).unsqueeze(0).to(torch.cuda.current_device())

            semantic_tokens = self.audio_tokenizer.tokenize(waveform_16k)
            semantic_tokens = semantic_tokens.to(torch.cuda.current_device())
            prompt_ids = semantic_tokens + self.speech_token_offset

            cur_role_dict[role] = {
                "ref_bpe_ids": role_item["ref_bpe_ids"],
                "wav_24k": waveform_24k,
                "semantic_tokens": semantic_tokens,
                "prompt_ids": prompt_ids
            }
        
        prompt = prompt + user_role_0_ids + cur_role_dict["0"]["ref_bpe_ids"] + [self.msg_end]
        prompt = prompt + user_role_1_ids + cur_role_dict["1"]["ref_bpe_ids"] + [self.msg_end]
        
        for seg_id, turn in enumerate(js["dialogue"]):
            role_id = turn["role"]
            cur_user_ids = user_role_0_ids if role_id == "0" else user_role_1_ids
            cur_start_ids = cur_user_ids + turn["bpe_ids"] + [self.msg_end]
            prompt = prompt + cur_start_ids
        
        prompt = torch.LongTensor(prompt).unsqueeze(0).to(torch.cuda.current_device())

        prompt = torch.cat([prompt, assistant_role_0_ids, media_start, cur_role_dict["0"]["prompt_ids"], media_end], dim=-1)
        prompt = torch.cat([prompt, assistant_role_1_ids, media_start, cur_role_dict["1"]["prompt_ids"], media_end], dim=-1)
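
        # At this point the prompt token stream is laid out roughly as:
        #   [user_msg_start] user 0 [name_end] <ref_text_0 bpe> [msg_end]
        #   [user_msg_start] user 1 [name_end] <ref_text_1 bpe> [msg_end]
        #   ... one user message per dialogue turn, in order ...
        #   [assistant_msg_start] assistant 0 [name_end]
        #     [media_begin] audio [media_content] <ref_audio_0 tokens> [media_end] [msg_end]
        #   [assistant_msg_start] assistant 1 [name_end]
        #     [media_begin] audio [media_content] <ref_audio_1 tokens> [media_end] [msg_end]
        # The loop below appends, per turn, the matching assistant header plus
        # media_start, and lets the LM generate that turn's speech tokens.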

        
        generation_config = self.generate_config
        # you can modify sampling strategy here

        wav_list = []
        for seg_id, turn in tqdm(enumerate(js["dialogue"])):
            role_id = turn["role"]
            cur_assistant_ids = assistant_role_0_ids if role_id == "0" else assistant_role_1_ids                
            prompt = torch.cat([prompt, cur_assistant_ids, media_start], dim=-1)
            len_prompt = prompt.shape[1]
            generation_config.min_length = len_prompt + 2
            # print(generation_config)
            # todo: add streaming support for generate function
            outputs = self.model.generate(prompt,
                                          generation_config=generation_config)
            if outputs[0, -1] == self.media_end:
                outputs = outputs[:, :-1]
            output_token = outputs[:, len_prompt:]
            prompt = torch.cat([outputs, media_end], dim=-1)            

            torch_token = output_token - self.speech_token_offset
            gen_speech_fm = detokenize(self.audio_detokenizer, torch_token, cur_role_dict[role_id]["wav_24k"], cur_role_dict[role_id]["semantic_tokens"])
            gen_speech_fm = gen_speech_fm.cpu()
            gen_speech_fm = gen_speech_fm / gen_speech_fm.abs().max()
            wav_list.append(gen_speech_fm)
            del torch_token
        
        concat_wav = torch.cat(wav_list, dim=-1).cpu()
        # print(concat_wav.shape)
        buffer = io.BytesIO()
        torchaudio.save(buffer, concat_wav, sample_rate=24000, format="mp3")
        audio_bytes = buffer.getvalue()
        audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
        return audio_b64
    
    @torch.inference_mode()
    def infer_with_prompt_streaming(self, js):
        user_role_0_ids = [self.user_msg_start] + self.user_ids + self.spk_0_ids  + [self.name_end]
        user_role_1_ids = [self.user_msg_start] + self.user_ids + self.spk_1_ids  + [self.name_end]
        assistant_role_0_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_0_ids + [self.name_end]
        assistant_role_1_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_1_ids + [self.name_end]

        media_start = [self.media_begin] + self.audio_ids + [self.media_content]
        media_end = [self.media_end] + [self.msg_end]

        assistant_role_0_ids = torch.LongTensor(assistant_role_0_ids).unsqueeze(0).to(torch.cuda.current_device())
        assistant_role_1_ids = torch.LongTensor(assistant_role_1_ids).unsqueeze(0).to(torch.cuda.current_device())
        media_start = torch.LongTensor(media_start).unsqueeze(0).to(torch.cuda.current_device())
        media_end = torch.LongTensor(media_end).unsqueeze(0).to(torch.cuda.current_device())
        

        prompt = []
        cur_role_dict = dict()
        for role, role_item in js["role_mapping"].items():
            waveform_24k = librosa.load(role_item["ref_audio"], sr=24000)[0]
            waveform_24k = torch.tensor(waveform_24k).unsqueeze(0).to(torch.cuda.current_device())

            waveform_16k = librosa.load(role_item["ref_audio"], sr=16000)[0]
            waveform_16k = torch.tensor(waveform_16k).unsqueeze(0).to(torch.cuda.current_device())

            semantic_tokens = self.audio_tokenizer.tokenize(waveform_16k)
            semantic_tokens = semantic_tokens.to(torch.cuda.current_device())
            prompt_ids = semantic_tokens + self.speech_token_offset

            cur_role_dict[role] = {
                "ref_bpe_ids": role_item["ref_bpe_ids"],
                "wav_24k": waveform_24k,
                "semantic_tokens": semantic_tokens,
                "prompt_ids": prompt_ids
            }
        
        prompt = prompt + user_role_0_ids + cur_role_dict["0"]["ref_bpe_ids"] + [self.msg_end]
        prompt = prompt + user_role_1_ids + cur_role_dict["1"]["ref_bpe_ids"] + [self.msg_end]
        
        for seg_id, turn in enumerate(js["dialogue"]):
            role_id = turn["role"]
            cur_user_ids = user_role_0_ids if role_id == "0" else user_role_1_ids
            cur_start_ids = cur_user_ids + turn["bpe_ids"] + [self.msg_end]
            prompt = prompt + cur_start_ids
        
        prompt = torch.LongTensor(prompt).unsqueeze(0).to(torch.cuda.current_device())

        prompt = torch.cat([prompt, assistant_role_0_ids, media_start, cur_role_dict["0"]["prompt_ids"], media_end], dim=-1)
        prompt = torch.cat([prompt, assistant_role_1_ids, media_start, cur_role_dict["1"]["prompt_ids"], media_end], dim=-1)

        
        generation_config = self.generate_config
        # you can modify sampling strategy here

        wav_list = []
        for seg_id, turn in tqdm(enumerate(js["dialogue"])):
            role_id = turn["role"]
            cur_assistant_ids = assistant_role_0_ids if role_id == "0" else assistant_role_1_ids                
            prompt = torch.cat([prompt, cur_assistant_ids, media_start], dim=-1)
            len_prompt = prompt.shape[1]
            generation_config.min_length = len_prompt + 2
            # print(generation_config)
            # todo: add streaming support for generate function
            outputs = self.model.generate(prompt,
                                          generation_config=generation_config)
            if outputs[0, -1] == self.media_end:
                outputs = outputs[:, :-1]
            output_token = outputs[:, len_prompt:]
            prompt = torch.cat([outputs, media_end], dim=-1)            

            torch_token = output_token - self.speech_token_offset
            for cur_chunk in detokenize_streaming(self.audio_detokenizer, torch_token, cur_role_dict[role_id]["wav_24k"], cur_role_dict[role_id]["semantic_tokens"]):
                cur_chunk = cur_chunk.cpu()
                cur_chunk = cur_chunk / cur_chunk.abs().max()
                cur_buffer = io.BytesIO()
                torchaudio.save(cur_buffer, cur_chunk, sample_rate=24000, format="mp3")
                audio_bytes = cur_buffer.getvalue()
                audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
                yield audio_b64
               
    @torch.inference_mode()
    def infer_without_prompt(self, js):
        user_role_0_ids = [self.user_msg_start] + self.user_ids + self.spk_0_ids  + [self.name_end]
        user_role_1_ids = [self.user_msg_start] + self.user_ids + self.spk_1_ids  + [self.name_end]
        assistant_role_0_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_0_ids + [self.name_end]
        assistant_role_1_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_1_ids + [self.name_end]

        media_start = [self.media_begin] + self.audio_ids + [self.media_content]
        media_end = [self.media_end] + [self.msg_end]

        assistant_role_0_ids = torch.LongTensor(assistant_role_0_ids).unsqueeze(0).to(torch.cuda.current_device())
        assistant_role_1_ids = torch.LongTensor(assistant_role_1_ids).unsqueeze(0).to(torch.cuda.current_device())
        media_start = torch.LongTensor(media_start).unsqueeze(0).to(torch.cuda.current_device())
        media_end = torch.LongTensor(media_end).unsqueeze(0).to(torch.cuda.current_device())
        

        prompt = []
        for seg_id, turn in enumerate(js["dialogue"]):
            role_id = turn["role"]
            cur_user_ids = user_role_0_ids if role_id == "0" else user_role_1_ids
            cur_start_ids = cur_user_ids + turn["bpe_ids"] + [self.msg_end]
            prompt = prompt + cur_start_ids

        prompt = torch.LongTensor(prompt).unsqueeze(0).to(torch.cuda.current_device())
        generation_config = self.generate_config
        # you can modify sampling strategy here

        wav_list = []
        for seg_id, turn in tqdm(enumerate(js["dialogue"])):
            role_id = turn["role"]
            cur_assistant_ids = assistant_role_0_ids if role_id == "0" else assistant_role_1_ids                
            prompt = torch.cat([prompt, cur_assistant_ids, media_start], dim=-1)
            len_prompt = prompt.shape[1]
            generation_config.min_length = len_prompt + 2
            # todo: add streaming support for generate function
            outputs = self.model.generate(prompt,
                                          generation_config=generation_config)
            if outputs[0, -1] == self.media_end:
                outputs = outputs[:, :-1]
            output_token = outputs[:, len_prompt:]
            prompt = torch.cat([outputs, media_end], dim=-1)

            torch_token = output_token - self.speech_token_offset
            gen_speech_fm = detokenize_noref(self.audio_detokenizer, torch_token)
            gen_speech_fm = gen_speech_fm.cpu()
            gen_speech_fm = gen_speech_fm / gen_speech_fm.abs().max()
            wav_list.append(gen_speech_fm)
            del torch_token

        concat_wav = torch.cat(wav_list, dim=-1).cpu()
        # print(concat_wav.shape)
        buffer = io.BytesIO()
        torchaudio.save(buffer, concat_wav, sample_rate=24000, format="mp3")
        audio_bytes = buffer.getvalue()
        audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
        return audio_b64
    
    @torch.inference_mode()
    def infer_without_prompt_streaming(self, js):
        user_role_0_ids = [self.user_msg_start] + self.user_ids + self.spk_0_ids  + [self.name_end]
        user_role_1_ids = [self.user_msg_start] + self.user_ids + self.spk_1_ids  + [self.name_end]
        assistant_role_0_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_0_ids + [self.name_end]
        assistant_role_1_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_1_ids + [self.name_end]

        media_start = [self.media_begin] + self.audio_ids + [self.media_content]
        media_end = [self.media_end] + [self.msg_end]

        assistant_role_0_ids = torch.LongTensor(assistant_role_0_ids).unsqueeze(0).to(torch.cuda.current_device())
        assistant_role_1_ids = torch.LongTensor(assistant_role_1_ids).unsqueeze(0).to(torch.cuda.current_device())
        media_start = torch.LongTensor(media_start).unsqueeze(0).to(torch.cuda.current_device())
        media_end = torch.LongTensor(media_end).unsqueeze(0).to(torch.cuda.current_device())
        

        prompt = []
        for seg_id, turn in enumerate(js["dialogue"]):
            role_id = turn["role"]
            cur_user_ids = user_role_0_ids if role_id == "0" else user_role_1_ids
            cur_start_ids = cur_user_ids + turn["bpe_ids"] + [self.msg_end]
            prompt = prompt + cur_start_ids

        prompt = torch.LongTensor(prompt).unsqueeze(0).to(torch.cuda.current_device())
        generation_config = self.generate_config
        # you can modify sampling strategy here

        wav_list = []
        for seg_id, turn in tqdm(enumerate(js["dialogue"])):
            role_id = turn["role"]
            cur_assistant_ids = assistant_role_0_ids if role_id == "0" else assistant_role_1_ids                
            prompt = torch.cat([prompt, cur_assistant_ids, media_start], dim=-1)
            len_prompt = prompt.shape[1]
            generation_config.min_length = len_prompt + 2
            # print(generation_config)
            # todo: add streaming support for generate function
            outputs = self.model.generate(prompt,
                                          generation_config=generation_config)
            if outputs[0, -1] == self.media_end:
                outputs = outputs[:, :-1]
            output_token = outputs[:, len_prompt:]
            prompt = torch.cat([outputs, media_end], dim=-1)

            torch_token = output_token - self.speech_token_offset
            for cur_chunk in detokenize_noref_streaming(self.audio_detokenizer, torch_token):
                cur_chunk = cur_chunk.cpu()
                cur_chunk = cur_chunk / cur_chunk.abs().max()
                cur_buffer = io.BytesIO()
                torchaudio.save(cur_buffer, cur_chunk, sample_rate=24000, format="mp3")
                audio_bytes = cur_buffer.getvalue()
                audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
                yield audio_b64
           
        
if __name__ == "__main__":
    model = Model()
    
    # speaker should be interleaved
    zh_test_json = {
        "role_mapping": {
            "0": {
                "ref_audio": "./zh_prompt0.wav",
                "ref_text": "可以每天都骑并且可能会让你爱上骑车,然后通过爱上骑车的你省了很多很多钱。", #asr output
            },
            "1": {
                "ref_audio": "./zh_prompt1.wav",
                "ref_text": "他最后就能让同样食材炒出来的菜味道大大提升。" #asr output
            }
        },      
        "dialogue": [
           {
                "role": "0",
                "text": "我觉得啊,就是经历了这么多年的经验, 就是补剂的作用就是九分的努力, 十分之一的补剂。 嗯,选的话肯定是九分更重要,但是我觉得补剂它能够让你九分的努力更加的有效率,更加的避免徒劳无功。 嗯,就是你,你你得先得真的锻炼,真的努力,真的健康饮食,然后再考虑补剂, 那你再加十十分之一的补剂的话,他可能就是说啊, 一半是心理作用,"
            },
            {
                "role": "1",
                "text": "对,其实很多时候心理作用是非常重要的。嗯,然后我每次用补剂的时候,我就会更加努力,就比如说我在健身之前我喝了一勺蛋白粉,我就会督促自己多练,"
            },
            {
                "role": "0",
                "text": "其实心理作用只要能实现你的预期目的就可以了。 就比如说给自行车链条加油, 它其实不是必要的,但是它可以让你骑行更顺畅, 然后提高你骑行的频率。"
            }      
        ]
    }


    audio_bytes_gen = model.inference(zh_test_json, streaming=True)
    audio = AudioSegment.empty()
    for cur_chunk in audio_bytes_gen:
        cur_chunk = base64.b64decode(cur_chunk)
        audio_chunk = AudioSegment.from_file(io.BytesIO(cur_chunk), format="mp3")
        audio += audio_chunk
    audio.export("tmp_generated_zh_stream.mp3", format="mp3")
    print("zh stream done")
    

    audio_bytes = model.inference(zh_test_json)
    with open("tmp_generated_zh.mp3", "wb") as file_to_save:
        file_to_save.write(base64.b64decode(audio_bytes))
    print("zh done")

    # speaker should be interleaved
    en_test_json = {
        "role_mapping": {
            "0": {
                "ref_audio": "./en_prompt0.wav",
                "ref_text": "Yeah, no, this is my backyard. It's never ending So just the way I like it. So social distancing has never been a problem.", #asr output
            },
            "1": {
                "ref_audio": "./en_prompt1.wav",
                "ref_text": "I'm doing great And. Look, it couldn't be any better than having you at your set, which is the outdoors." #asr output
            }
        },      
        "dialogue": [
            {
                "role": "0",
                "text": "In an awesome time, And, we're even gonna do a second episode too So. This is part one part two, coming at some point in the future There. We are.",
            },
            {
                "role": "1",
                "text": "I love it. So grateful Thank you So I'm really excited. That's awesome. Yeah."
            },
            {
                "role": "0",
                "text": "All I was told, which is good because I don't want to really talk too much more is that you're really really into fitness and nutrition And overall holistic I love it Yes."
            },
            {
                "role": "1",
                "text": "Yeah So I started around thirteen Okay But my parents were fitness instructors as well. Awesome So I came from the beginning, and now it's this transition into this wholeness because I had to chart my. Own path and they weren't into nutrition at all So I had to learn that part."
            }
        ]
    }
    audio_bytes = model.inference(en_test_json)
    with open("tmp_generated_en.mp3", "wb") as file_to_save:
        file_to_save.write(base64.b64decode(audio_bytes))
    print("en done")


    # also support inference without prompt
    # speaker should be interleaved
    without_prompt_test_json = {
        "dialogue": [
            {
                "role": "0",
                "text": "我觉得啊,就是经历了这么多年的经验, 就是补剂的作用就是九分的努力, 十分之一的补剂。 嗯,选的话肯定是九分更重要,但是我觉得补剂它能够让你九分的努力更加的有效率,更加的避免徒劳无功。 嗯,就是你,你你得先得真的锻炼,真的努力,真的健康饮食,然后再考虑补剂, 那你再加十十分之一的补剂的话,他可能就是说啊, 一半是心理作用,"
            },
            {
                "role": "1",
                "text": "对,其实很多时候心理作用是非常重要的。嗯,然后我每次用补剂的时候,我就会更加努力,就比如说我在健身之前我喝了一勺蛋白粉,我就会督促自己多练,"
            },
            {
                "role": "0",
                "text": "其实心理作用只要能实现你的预期目的就可以了。 就比如说给自行车链条加油, 它其实不是必要的,但是它可以让你骑行更顺畅, 然后提高你骑行的频率。"
            }   
        ]
    }
    audio_bytes = model.inference(without_prompt_test_json)
    with open("tmp_generated_woprompt.mp3", "wb") as file_to_save:
        file_to_save.write(base64.b64decode(audio_bytes))
    print("without prompt done")

================================================
FILE: modules/audio_detokenizer/audio_detokenizer.py
================================================

import torch

from modules.audio_detokenizer.bigvgan_wrapper import BigVGANWrapper
from modules.audio_detokenizer.semantic_fm_prefix_streaming import StreamingSemanticFMWrapper


class PrefixStreamingFlowMatchingDetokenizer:
    def __init__(self, vocoder: BigVGANWrapper, fm: StreamingSemanticFMWrapper, look_ahead_tokens: int = 0) -> None:
        self.dtype = torch.bfloat16

        print("Currently using bfloat16 for PrefixFlowMatchingDetokenizer")

        self.vocoder = vocoder
        self.vocoder.to_dtype(self.dtype)
        
        self.semantic_fm = fm

        # initialize mel_spec
        self.max_pos_size = 4096
        self.is_timbre_semantic_token = False
        self.pre_mel = None
        self.frame_size = 480 # samples per frame: 480 / 24000 Hz = 20 ms, i.e. 50 frames per second
        self.pre_wav = None
        self.state_dict_backup = None
        self.hamming_window_cache = {}
        self.previous_chunk_left = None
        self.look_ahead_tokens = look_ahead_tokens

        self.clear_states()

        
    @classmethod
    def from_pretrained(cls, vocoder_config, vocoder_ckpt, fm_config, fm_ckpt, device, 
                        look_ahead_tokens=0,
                        max_prompt_chunk=2, max_kv_cache_tokens=900,
                        use_cfg=False, use_cfg_rescale=True, cfg_init=1.5, cfg_scale=7.5, cfg_schedule="linear"):
        bigvgan = BigVGANWrapper.from_pretrained(vocoder_config, vocoder_ckpt, device)
        semantic_fm = StreamingSemanticFMWrapper.from_pretrained(fm_config, fm_ckpt, device, max_prompt_chunk=max_prompt_chunk, max_kv_cache_tokens=max_kv_cache_tokens,
                                                                 use_cfg=use_cfg, cfg_scale=cfg_scale, use_cfg_rescale=use_cfg_rescale, cfg_init=cfg_init, cfg_schedule=cfg_schedule)        
        return cls(bigvgan, semantic_fm, look_ahead_tokens=look_ahead_tokens)
    
    @torch.inference_mode()
    def prefill(self, timbre_speech, timbre_semantic_token, chunk_size: int, timbre_mel=None):
        """
            Arguments:
                timbre_speech: torch.Tensor, shape [B, N_speech_24k]
                timbre_semantic_token: torch.Tensor, shape [B, N]
                chunk_size: int, chunk size for prefilling
                timbre_mel: torch.Tensor, shape [B, N, 80], optional, if not None, use this mel spectrogram instead of extracting from timbre_speech
        """
        if timbre_mel is None:
            assert timbre_speech is not None, "timbre_speech should not be None if timbre_mel is None"
            assert len(timbre_semantic_token.shape) == 2 and len(timbre_speech.shape) == 2 and chunk_size > 0
            assert timbre_speech.shape[0] == 1 and timbre_semantic_token.shape[0] == 1

            mel_spec = self.vocoder.extract_mel_from_wav(wav_data=timbre_speech.squeeze(0))
        else:
            assert len(timbre_mel.shape) == 3 and len(timbre_semantic_token.shape) == 2 and chunk_size > 0
            assert timbre_mel.shape[0] == 1 and timbre_semantic_token.shape[0] == 1
            mel_spec = timbre_mel.squeeze(0)

        if mel_spec.shape[0] < timbre_semantic_token.shape[1]:
            # pad mel_spec
            mel_spec = torch.nn.functional.pad(mel_spec, (0, 0, 0, timbre_semantic_token.shape[1] - mel_spec.shape[0]))
        elif mel_spec.shape[0] > timbre_semantic_token.shape[1]:
            # truncate mel_spec
            mel_spec = mel_spec[:timbre_semantic_token.shape[1], :]

        # clear all states
        self.semantic_fm.clear_all_states()
        self.semantic_fm.prefill(mel_spec, timbre_semantic_token.squeeze(0), chunk_size=chunk_size, verbose=False)
        self.state_dict_backup = self.semantic_fm.state_dict()

    @torch.inference_mode()
    def detokenize_streaming(self, semantic_token, ode_step=30, verbose=False, ode_solver="neural_ode_euler", is_final=False, upsample_factor=1):
        assert len(semantic_token.shape) == 2 and ode_step > 0
        assert semantic_token.shape[0] == 1

        semantic_token = semantic_token.repeat_interleave(upsample_factor, dim=1)
        
        semantic_token = semantic_token.squeeze(0)

        if self.look_ahead_tokens != 0 and self.previous_chunk_left is not None:
            semantic_token_previous = self.previous_chunk_left["semantic_token"]
            semantic_token = torch.cat([semantic_token_previous, semantic_token], dim=-1)

        x_t_chunk = torch.randn(semantic_token.shape[0], 80).to(semantic_token.device).to(self.dtype)

        if self.look_ahead_tokens != 0 and self.previous_chunk_left is None:
            self.previous_chunk_left = {"semantic_token": None}
        
        speech_mel = self.semantic_fm.infer_chunk(
            xt_chunk=x_t_chunk, 
            semantic_tokens_chunk=semantic_token, 
            start_position_id=self.semantic_fm.start_position_id,
            ode_steps=ode_step, 
            verbose=verbose, 
            look_ahead_tokens=self.look_ahead_tokens * upsample_factor if not is_final else 0,
            cache=self.previous_chunk_left,
            ode_solver=ode_solver
        )

        chunk_size = speech_mel.shape[0]
        length = speech_mel.shape[0]
        self.semantic_fm.start_position_id += length
        self.semantic_fm.update_incremental_state()
        self.semantic_fm.reserve_kv_cache_tokens += self.semantic_fm.ode_wrapper.kv_cache_tokens
        
        # Smoothing: maintain a history buffer of generated wav.
        # For the first chunk, return only the first half and keep the second half in history.
        # For subsequent chunks, concatenate the new wav with the history, return one chunk
        # (crossfaded with a Hamming window), and keep the remainder in history.

        if self.pre_mel is None: # first chunk, related to TTFB
            concat_mel = speech_mel
            concat_reconstructed_wav = self.vocoder.decode_mel(concat_mel)
            if is_final:
                self.clear_states()
                self.state_dict_backup = None
                ret_wav = concat_reconstructed_wav.float()
            else:
                reconstructed_wav = concat_reconstructed_wav[:, :int(self.frame_size * chunk_size // 2)] # return the first half chunk

                self.pre_wav = concat_reconstructed_wav[:, -int(self.frame_size * chunk_size // 2):] # log the last half chunk for next generation step
                self.pre_mel = speech_mel[-chunk_size//2:, :]

                ret_wav = reconstructed_wav.float()
        else:
            concat_mel = torch.cat([self.pre_mel, speech_mel], dim=0)
            concat_reconstructed_wav = self.vocoder.decode_mel(concat_mel)

            if is_final:
                self.clear_states()
                self.state_dict_backup = None
                ret_wav = concat_reconstructed_wav.float()
            else:
                # fetch history
                prev_speech_len = self.pre_wav.shape[1]

                if concat_reconstructed_wav.shape[1] > prev_speech_len * 2:
                    gen_speech_len = prev_speech_len * 2
                else:
                    gen_speech_len = concat_reconstructed_wav.shape[1] // 2


                reconstructed_wav = concat_reconstructed_wav[:, :gen_speech_len] # return the first gen_speech_len samples
                
                if gen_speech_len not in self.hamming_window_cache:
                    self.hamming_window_cache[gen_speech_len] = torch.hamming_window(gen_speech_len).to(self.dtype).to(semantic_token.device).unsqueeze(0)
                
                hamming_window = self.hamming_window_cache[gen_speech_len]
                
                
                # apply smoothing of the first half chunk
                reconstructed_wav[:, :int(gen_speech_len // 2 )] = self.pre_wav[:, :int(gen_speech_len // 2 )] * hamming_window[:,-int(gen_speech_len // 2):] + \
                    reconstructed_wav[:, :int(gen_speech_len // 2)] * hamming_window[:, :int(gen_speech_len // 2)]
            
                res_speech_len = concat_reconstructed_wav.shape[1] - gen_speech_len
                res_mel_len = res_speech_len // self.frame_size

                self.pre_wav = concat_reconstructed_wav[:, -res_speech_len:]
                self.pre_mel = speech_mel[-res_mel_len:, :]
                ret_wav = reconstructed_wav.float()
        
        if not is_final and self.semantic_fm.start_position_id + 2*chunk_size > self.max_pos_size:
            # position ids would overflow: reset states and restore the prefilled prompt
            self.semantic_fm.clear_all_states()
            self.semantic_fm.load_state_dict(self.state_dict_backup)

        return ret_wav

    def clear_states(self):
        self.semantic_fm.clear_all_states()
        self.previous_chunk_left = None
        self.pre_mel = None
        self.pre_wav = None

def get_audio_detokenizer():
    fm_model_config = "resources/audio_detokenizer/config.yaml"
    fm_ckpt_path = "resources/audio_detokenizer/model.pt"

    bigvgan_config_file = "resources/vocoder/config.json"
    bigvgan_ckpt_path = "resources/vocoder/model.pt"

    device = torch.cuda.current_device()
    detokenizer = PrefixStreamingFlowMatchingDetokenizer.from_pretrained(
        vocoder_config=bigvgan_config_file,
        vocoder_ckpt=bigvgan_ckpt_path,
        max_prompt_chunk=10,  # 10 chunks * 3 s = 30 s of prompt
        fm_config=fm_model_config,
        fm_ckpt=fm_ckpt_path,
        device=device,
        use_cfg=False,
        look_ahead_tokens=12)
    
    return detokenizer


def detokenize(detokenizer, tokens, ref_wav, ref_tokens):
    with torch.no_grad():
        detokenizer.clear_states()
        detokenizer.prefill(ref_wav, ref_tokens, chunk_size=150)
        cache_speech_collection = []
        chunk_size = 150
        first_chunk_size = 100
        first_chunk_tokens = tokens[:, :first_chunk_size]
        gen_speech = detokenizer.detokenize_streaming(first_chunk_tokens, is_final=tokens.size(1) <= first_chunk_size)
        cache_speech_collection.append(gen_speech)
        res_tokens = tokens[:, first_chunk_size:]
        for i in range(0, res_tokens.size(1), chunk_size):
            chunk_tokens = res_tokens[:, i:i+chunk_size]
            gen_speech = detokenizer.detokenize_streaming(chunk_tokens, is_final=(i+chunk_size >= res_tokens.size(1)))
            cache_speech_collection.append(gen_speech)

        gen_speech_all = torch.cat(cache_speech_collection, dim=-1)
        return gen_speech_all

def detokenize_streaming(detokenizer, tokens, ref_wav, ref_tokens):
    with torch.no_grad():
        detokenizer.clear_states()
        detokenizer.prefill(ref_wav, ref_tokens, chunk_size=150)
        cache_speech_collection = []
        chunk_size = 150
        first_chunk_size = 100
        first_chunk_tokens = tokens[:, :first_chunk_size]
        gen_speech = detokenizer.detokenize_streaming(first_chunk_tokens, is_final=tokens.size(1) <= first_chunk_size)
        yield gen_speech
        res_tokens = tokens[:, first_chunk_size:]
        for i in range(0, res_tokens.size(1), chunk_size):
            chunk_tokens = res_tokens[:, i:i+chunk_size]
            gen_speech = detokenizer.detokenize_streaming(chunk_tokens, is_final=(i+chunk_size >= res_tokens.size(1)))
            yield gen_speech

def detokenize_noref(detokenizer, tokens):
    with torch.no_grad():
        detokenizer.clear_states()
        cache_speech_collection = []
        chunk_size = 150
        first_chunk_size = 100
        first_chunk_tokens = tokens[:, :first_chunk_size]
        gen_speech = detokenizer.detokenize_streaming(first_chunk_tokens, is_final=tokens.size(1) <= first_chunk_size)
        cache_speech_collection.append(gen_speech)
        res_tokens = tokens[:, first_chunk_size:]
        for i in range(0, res_tokens.size(1), chunk_size):
            chunk_tokens = res_tokens[:, i:i+chunk_size]
            gen_speech = detokenizer.detokenize_streaming(chunk_tokens, is_final=(i+chunk_size >= res_tokens.size(1)))
            cache_speech_collection.append(gen_speech)
        
        gen_speech_all = torch.cat(cache_speech_collection, dim=-1)
        return gen_speech_all


def detokenize_noref_streaming(detokenizer, tokens):
    with torch.no_grad():
        detokenizer.clear_states()
        cache_speech_collection = []
        chunk_size = 150
        first_chunk_size = 100
        first_chunk_tokens = tokens[:, :first_chunk_size]
        gen_speech = detokenizer.detokenize_streaming(first_chunk_tokens, is_final=tokens.size(1) <= first_chunk_size)
        yield gen_speech
        res_tokens = tokens[:, first_chunk_size:]
        for i in range(0, res_tokens.size(1), chunk_size):
            chunk_tokens = res_tokens[:, i:i+chunk_size]
            gen_speech = detokenizer.detokenize_streaming(chunk_tokens, is_final=(i+chunk_size >= res_tokens.size(1)))
            yield gen_speech
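

# Editor's usage sketch (not part of the repo): driving the streaming helpers
# above. `tokens`, `ref_wav`, and `ref_tokens` are assumed to be [1, T] semantic
# tokens and a [1, N] 24 kHz reference wav already on the detokenizer's device;
# `consume` is a hypothetical sink for [1, n_samples] float chunks.
#
# detokenizer = get_audio_detokenizer()
# for wav_chunk in detokenize_streaming(detokenizer, tokens, ref_wav, ref_tokens):
#     consume(wav_chunk)
# # or, without a reference prompt:
# # full_wav = detokenize_noref(detokenizer, tokens)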


================================================
FILE: modules/audio_detokenizer/bigvgan_wrapper.py
================================================
import os
import json
import logging

import librosa
import torch

from modules.audio_detokenizer.vocoder.bigvgan import BigVGAN
from modules.audio_detokenizer.vocoder.utils import get_melspec, AttrDict, load_checkpoint

logger = logging.getLogger(__name__)


class BigVGANWrapper:
    def __init__(self, vocoder: BigVGAN, device: torch.device, h: AttrDict, dtype=None) -> None:
        self.vocoder = vocoder.to(device)
        if dtype is not None:
            self.vocoder = self.vocoder.to(dtype)
        self.vocoder = self.vocoder.eval()
        self.device = device
        self.h = h
    
    def to_dtype(self, dtype):
        self.vocoder = self.vocoder.to(dtype)

    def extract_mel_from_wav(self, wav_path=None, wav_data=None):
        """
        params:
            wav_path: str, path of the wav, should be 24k
            wav_data: torch.tensor or numpy array, shape [T], wav data, should be 24k
        return:
            mel: [T, num_mels], torch.tensor
        """
        if wav_data is None:
            wav_data, _ = librosa.load(wav_path, sr=self.h["sampling_rate"])
        
        wav_data = torch.tensor(wav_data).unsqueeze(0)

        mel = get_melspec(y=wav_data, n_fft=self.h["n_fft"], num_mels=self.h["num_mels"], sampling_rate=self.h["sampling_rate"], 
                          hop_size=self.h["hop_size"], win_size=self.h["win_size"], fmin=self.h["fmin"], fmax=self.h["fmax"])
        return mel.squeeze(0).transpose(0, 1)
    
    @torch.inference_mode()
    def extract_mel_from_wav_batch(self, wav_data):
        """
        params:
            wav_data: torch.tensor or numpy array, shape [Batch, T], wav data, should be 24k
        return:
            mel: [Batch, T, num_mels], torch.tensor
        """

        wav_data = torch.tensor(wav_data)

        mel = get_melspec(y=wav_data, n_fft=self.h["n_fft"], num_mels=self.h["num_mels"], sampling_rate=self.h["sampling_rate"], 
                          hop_size=self.h["hop_size"], win_size=self.h["win_size"], fmin=self.h["fmin"], fmax=self.h["fmax"])
        return mel.transpose(1, 2)
    
    def decode_mel(self, mel):
        """
        params:
            mel: [T, num_mels], torch.tensor
        return:
            wav: [1, T], torch.tensor
        """    
        mel = mel.transpose(0, 1).unsqueeze(0).to(self.device)
        wav = self.vocoder(mel)
        return wav.squeeze(0)

    def decode_mel_batch(self, mel):
        """
        params:
            mel: [B, T, num_mels], torch.tensor
        return:
            wav: [B, 1, T], torch.tensor
        """    
        mel = mel.transpose(1, 2).to(self.device)
        wav = self.vocoder(mel)
        return wav

    @classmethod
    def from_pretrained(cls, model_config, ckpt_path, device):
        with open(model_config) as f:
            data = f.read()
        json_config = json.loads(data)
        h = AttrDict(json_config)
        vocoder = BigVGAN(h, True)
        state_dict_g = load_checkpoint(ckpt_path, "cpu")
        vocoder.load_state_dict(state_dict_g["generator"])

        logger.info(">>> Load vocoder from {}".format(ckpt_path))
        return cls(vocoder, device, h)
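
# Editor's usage sketch (not part of the repo): a mel round-trip through the
# wrapper, assuming a 24 kHz mono wav file (path is illustrative).
# extract_mel_from_wav returns [T, num_mels]; decode_mel returns [1, T * hop_size].
#
# wrapper = BigVGANWrapper.from_pretrained("resources/vocoder/config.json",
#                                          "resources/vocoder/model.pt",
#                                          torch.device("cuda"))
# mel = wrapper.extract_mel_from_wav(wav_path="example_24k.wav")
# wav = wrapper.decode_mel(mel)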





================================================
FILE: modules/audio_detokenizer/flow_matching/dit_block.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F

from flash_attn import flash_attn_varlen_func, flash_attn_varlen_qkvpacked_func

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    # x shape: bsz, seqlen, self.n_local_heads, self.head_hidden_dim / 2
    # the last shape is "self.hidden_dim / 2" because we convert to complex
    assert x.ndim == 4
    assert freqs_cis.shape == (x.shape[0], x.shape[1], x.shape[-1]), \
        f'x shape: {x.shape}, freqs_cis shape: {freqs_cis.shape}'
     
    # reshape freq cis to match and apply pointwise multiply
    # new shape: bsz, seq_len, 1, self.head_hidden_dim / 2
    shape = [x.shape[0], x.shape[1], 1, x.shape[-1]]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
):
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
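
# Editor's sanity-check sketch (not part of the repo): identity rotations leave
# q/k unchanged and shapes are preserved. freqs_cis must be complex with shape
# [bsz, seqlen, head_dim // 2] to match q/k of shape [bsz, seqlen, n_heads, head_dim].
#
# q = torch.randn(1, 8, 4, 64)
# k = torch.randn(1, 8, 4, 64)
# freqs = torch.polar(torch.ones(1, 8, 32), torch.zeros(1, 8, 32))  # angle 0 => identity
# q_rot, k_rot = apply_rotary_emb(q, k, freqs)
# assert q_rot.shape == q.shape and torch.allclose(q_rot, q)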



class Attention(nn.Module):

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: nn.Module = nn.LayerNorm,
            flash_attention: bool = True
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = flash_attention

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.qk_norm = qk_norm
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, seq_len, cu_seqlens, max_seqlen, cu_seqlens_k, max_seqlen_k, rotary_pos_emb=None, incremental_state=None, nopadding=True) -> torch.Tensor:
        B, N, C = x.shape

        if self.fused_attn:
            if nopadding:
                qkv = self.qkv(x)
                qkv = qkv.view(B * N, self.num_heads * 3, self.head_dim)
                q, k, v = qkv.split([self.num_heads] * 3, dim=1)
                q, k = self.q_norm(q), self.k_norm(k)

                q = q.view(B, N, self.num_heads, self.head_dim)
                k = k.view(B, N, self.num_heads, self.head_dim)
                v = v.view(B, N, self.num_heads, self.head_dim)

                if rotary_pos_emb is not None:
                    q, k = apply_rotary_emb(q, k, rotary_pos_emb)
                
                if incremental_state is not None:
                    # concatenate cached keys/values from previous chunks, then
                    # stash the concatenated tensors for the next chunk
                    if "prev_k" in incremental_state:
                        k = torch.cat([incremental_state["prev_k"], k], dim=1)
                    incremental_state["cur_k"] = k

                    if "prev_v" in incremental_state:
                        v = torch.cat([incremental_state["prev_v"], v], dim=1)
                    incremental_state["cur_v"] = v
                
                q = q.view(B * N, self.num_heads, self.head_dim)
                k = k.view(-1, self.num_heads, self.head_dim)
                v = v.view(-1, self.num_heads, self.head_dim)

                x = flash_attn_varlen_func(
                    q=q,
                    k=k,
                    v=v,
                    cu_seqlens_q=cu_seqlens,
                    cu_seqlens_k=cu_seqlens_k,
                    max_seqlen_q=max_seqlen,
                    max_seqlen_k=max_seqlen_k,
                    dropout_p=self.attn_drop.p if self.training else 0.,
                )
            else:
                
                if incremental_state is not None:
                    raise NotImplementedError("It is designed for batching inference. AR-chunk is not supported currently.")

                qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
                if self.qk_norm:
                    q, k, v = qkv.unbind(2)
                    q, k = self.q_norm(q), self.k_norm(k)
                    # re-bind
                    qkv = torch.stack((q, k, v), dim=2)
                
                # pack qkv with seq_len
                qkv_collect = []
                for i in range(qkv.shape[0]):
                    qkv_collect.append(
                        qkv[i, :seq_len[i], :, :, :]
                    )
                
                qkv = torch.cat(qkv_collect, dim=0)

                x = flash_attn_varlen_qkvpacked_func(qkv=qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, dropout_p=self.attn_drop.p if self.training else 0.)
                
                # unpack and pad 0
                x_collect = []
                for i in range(B):
                    x_collect.append(
                        x[cu_seqlens[i]:cu_seqlens[i+1], :, :]
                    )
                x = torch.nn.utils.rnn.pad_sequence(x_collect, batch_first=True, padding_value=0)

        else:
            # naive (non-flash) scaled-dot-product attention fallback;
            # materialize q/k/v from the fused qkv projection first
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
            q, k, v = qkv.unbind(0)
            q, k = self.q_norm(q), self.k_norm(k)
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v
            x = x.transpose(1, 2)

        x = x.reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


def modulate(x, shift, scale):
    return x * (1 + scale) + shift


class FinalLayer(nn.Module):
    """
    The final layer of DiT.
    """
    def __init__(self, hidden_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=2)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


class DiTBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, ffn_type="conv1d_conv1d", ffn_gated_glu=True, ffn_act_layer="gelu", ffn_conv_kernel_size=5, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)


        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
 
        if ffn_type == "vanilla_mlp":
            from timm.models.vision_transformer import Mlp
            mlp_hidden_dim = int(hidden_size * mlp_ratio)
            approx_gelu = lambda: nn.GELU(approximate="tanh")
            self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
        else:
            raise NotImplementedError(f"FFN type {ffn_type} is not implemented")
        
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c, seq_len, cu_seqlens, cu_maxlen, cu_seqlens_k, cu_maxlen_k, mask, rotary_pos_emb=None, incremental_state=None, nopadding=True):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=2)

        x_ = modulate(self.norm1(x), shift_msa, scale_msa)

        if incremental_state is not None:
            if "attn_kvcache" not in incremental_state:
                incremental_state["attn_kvcache"] = {}
            inc_attn = incremental_state["attn_kvcache"]
        else:
            inc_attn = None

        x_ = self.attn(x_, seq_len=seq_len, cu_seqlens=cu_seqlens, max_seqlen=cu_maxlen, cu_seqlens_k=cu_seqlens_k, max_seqlen_k=cu_maxlen_k, rotary_pos_emb=rotary_pos_emb, incremental_state=inc_attn, nopadding=nopadding)
        
        if not nopadding:
            x_ = x_ * mask[:, :, None]
        
        x = x + gate_msa * x_

        x_ = modulate(self.norm2(x), shift_mlp, scale_mlp)
        
        x_ = self.mlp(x_)

        if not nopadding:
            x_ = x_ * mask[:, :, None]

        x = x + gate_mlp * x_
        return x


================================================
FILE: modules/audio_detokenizer/flow_matching/model.py
================================================
import torch
import torch.nn as nn
import math
from modules.audio_detokenizer.flow_matching.dit_block import DiTBlock, FinalLayer

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0,
                         interpolation_factor: int = 1, max_seq_length: int = 4096):
    print(f'using rope base theta = {theta}, interpolation factor = {interpolation_factor}')
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    
    # RoPE type-A extension:
    # we use interpolation rather than extrapolation for better position encoding
    # for scaling purposes, t should be a float tensor
    t = torch.arange(end, device=freqs.device).float()
    scale = 1.0 / float(interpolation_factor)
    t *= scale

    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64

    # We may not need the full rope table when seq_len is smaller than
    # max_position_embeddings (e.g. a 1M-position table with a 32k seqlen),
    # which would otherwise waste GPU memory.
    if max_seq_length < end:
        freqs_cis = freqs_cis[:max_seq_length,].clone()
    return freqs_cis
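
# Editor's note (not part of the repo): freqs_cis is a complex64 tensor of shape
# [min(end, max_seq_length), dim // 2]; DiTPrefix.forward indexes it with
# position ids to obtain the per-position rotations consumed by apply_rotary_emb.
#
# rope = precompute_freqs_cis(dim=64, end=4096)   # -> torch.Size([4096, 32])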


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                          These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).float().to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
        return t_emb
    

class SinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length.

    Padding symbols are ignored.
    """

    def __init__(self, embedding_dim, padding_idx, init_size=1024):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.weights = SinusoidalPositionalEmbedding.get_embedding(
            init_size,
            embedding_dim,
            padding_idx,
        )
        self.register_buffer('_float_tensor', torch.FloatTensor(1))

    @staticmethod
    def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
        """Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need".
        """
        half_dim = embedding_dim // 2   # d/2
        emb = math.log(10000) / (half_dim - 1)   # 2*log(10000)/(d-2)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)   # -2i/(d-2)*log(10000); i from 0 to (d-2)/2; shape: (d/2, )
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)   # pos/[1000 ** (2i/(d-2))]; shape: (num_embeddings, d/2)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)   # shape: (num_embeddings, d)
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb

    def forward(self, input, incremental_state=None, timestep=None, **kwargs):
        """Input is expected to be of size [bsz x seqlen]."""
        bsz, seq_len = input.shape[:2]
        max_pos = self.padding_idx + 1 + seq_len
        if self.weights is None or max_pos > self.weights.size(0):
            # recompute/expand embeddings if needed
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                max_pos,
                self.embedding_dim,
                self.padding_idx,
            )
        self.weights = self.weights.to(self._float_tensor)

        if incremental_state is not None:
            # positions is the same for every token when decoding a single step
            pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
            return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)

        positions = self.make_positions(input, self.padding_idx)
        return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()   # (B, T, dim)

    def max_positions(self):
        """Maximum number of supported positions."""
        return int(1e5)  # an arbitrary large number
    
    def make_positions(self, tensor, padding_idx):
        """Replace non-padding symbols with their position numbers.

        Position numbers begin at padding_idx+1. Padding symbols are ignored.
        """
        # The series of casts and type-conversions here are carefully
        # balanced to both work with ONNX export and XLA. In particular XLA
        # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
        # how to handle the dtype kwarg in cumsum.
        mask = tensor.ne(padding_idx).int()
        return (
                    torch.cumsum(mask, dim=1).type_as(mask) * mask
            ).long() + padding_idx
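
    # Editor's worked example (not part of the repo), with padding_idx = 0:
    #   tensor    = [[7, 9, 0, 0]]  ->  mask = [[1, 1, 0, 0]]
    #   cumsum * mask               ->  [[1, 2, 0, 0]]
    #   positions = [[1, 2, 0, 0]]  (real tokens count from padding_idx + 1;
    #                                padding positions stay at padding_idx)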
    

class DiTPrefix(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """
    def __init__(
        self,
        input_size,
        output_size,
        semantic_vocab_size,
        hidden_size=1024,
        depth=12,
        num_heads=4,
        # mlp related
        mlp_ratio=4.0,
        ffn_type="conv1d_conv1d",
        ffn_gated_glu=True,
        ffn_act_layer="gelu",
        ffn_conv_kernel_size=5,

        # rope
        use_rope=False,
        rope_params={
                "max_position_embeddings": 4096,
                "rope_base": 10000.0,
                "rope_interpolation_factor": 1.0,
            },


        position_embedding_type="sincos",
        max_seq_len=4096,
        prompt_cfg_dropout=0.0
    ):
        super().__init__()
        self.num_heads = num_heads

        self.prompt_cfg_dropout = prompt_cfg_dropout

        self.t_embedder = TimestepEmbedder(hidden_size)

        self.semantic_token_embedding = nn.Embedding(semantic_vocab_size, hidden_size)

        self.input_linear = nn.Linear(input_size, hidden_size)

        # position embedding
        if position_embedding_type == "learnable":
            self.position_embedding = nn.Embedding(max_seq_len+1, hidden_size)
        elif position_embedding_type == "sincos":
            self.position_embedding = SinusoidalPositionalEmbedding(hidden_size, 0, max_seq_len+1)
        elif position_embedding_type == "skip":
            self.position_embedding = None
        else:
            raise NotImplementedError("Position embedding type: {} not implemented.".format(position_embedding_type))

        self.use_rope = use_rope

        if self.use_rope:
            
            assert hidden_size % num_heads == 0, "Hidden size must be divisible by num_heads for rope position embedding."
            rope_dim = hidden_size // num_heads

            self.rotary_pos_emb = precompute_freqs_cis(
                rope_dim, rope_params["max_position_embeddings"],
                theta=rope_params["rope_base"],
                interpolation_factor=rope_params["rope_interpolation_factor"],
                max_seq_length=max_seq_len
            )

        self.blocks = nn.ModuleList([
            DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, 
                     ffn_type=ffn_type, ffn_conv_kernel_size=ffn_conv_kernel_size, ffn_gated_glu=ffn_gated_glu, ffn_act_layer=ffn_act_layer) for _ in range(depth)
        ])
        self.final_layer = FinalLayer(hidden_size, output_size)
        self.initialize_weights()

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)


        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def forward(self, x, position_ids, t, condition, seq_len, cu_seqlens, cu_maxlen, cu_seqlens_k, cu_maxlen_k, mask, incremental_state=None, nopadding=True):
        """
        Forward pass of DiT.
        x: (N, T, C) tensor of inputs (latent representations of speech)
        position_ids: (N, T) tensor of positional indices
        t: (N,) tensor of diffusion timesteps
        condition: (N, T) tensor of semantic tokens
        seq_len: (N,) tensor of sequence lengths
        """

        condition = self.semantic_token_embedding(condition)  # (N, T, D)

        x = self.input_linear(x)   

        if self.position_embedding is not None:
            position_emb = self.position_embedding(position_ids)
            x = x + position_emb
        
        # RoPE
        if self.use_rope:
            if self.rotary_pos_emb.device != position_ids.device:
                self.rotary_pos_emb = self.rotary_pos_emb.to(position_ids.device)
            # gather per-position rotations: [bsz, seqlen, head_dim // 2]
            rotary_pos_emb = self.rotary_pos_emb[position_ids]
        else:
            rotary_pos_emb = None

        t = self.t_embedder(t)                   # (N, D)
        c = t.unsqueeze(1) + condition           # (N, T, D)


        for block_idx, block in enumerate(self.blocks):
            # NOTE: mask may be None because a full (non-causal) mask is always used

            if incremental_state is not None:
                if block_idx not in incremental_state:
                    incremental_state[block_idx] = {}
                incr = incremental_state[block_idx]
            else:
                incr = None
            
            x = block(x=x, c=c, seq_len=seq_len, cu_seqlens=cu_seqlens, cu_maxlen=cu_maxlen, cu_seqlens_k=cu_seqlens_k, cu_maxlen_k=cu_maxlen_k, mask=mask, rotary_pos_emb=rotary_pos_emb, incremental_state=incr, nopadding=nopadding)

        x = self.final_layer(x, c)               # (N, T, C)
        return x
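
# Editor's sketch (not part of the repo): instantiating DiTPrefix with a toy
# config; sizes are illustrative, not the released model's. Note that only
# ffn_type="vanilla_mlp" is implemented in DiTBlock (it requires timm), and a
# real forward pass needs the flash-attn cu_seqlens bookkeeping built by
# StreamingODEWrapperForPrefix.set_conditions.
#
# toy = DiTPrefix(input_size=80, output_size=80, semantic_vocab_size=8193,
#                 hidden_size=256, depth=2, num_heads=4,
#                 ffn_type="vanilla_mlp", position_embedding_type="sincos")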



================================================
FILE: modules/audio_detokenizer/flow_matching/ode_wrapper.py
================================================
import torch
import torch.nn as nn
from functools import lru_cache
import copy


@lru_cache(maxsize=1)
def get_cached_zeros(numel, device="cpu", dtype=torch.float32):
    return torch.zeros(numel, device=device, dtype=dtype)

class StreamingODEWrapperForPrefix(nn.Module):
    def __init__(self, net, x_mask, x_cond, use_cfg=False, use_cfg_rescale=True, cfg_init=1.0, cfg_scale=4.0, cfg_schedule="linear", cfg_token_id=0):
        super(StreamingODEWrapperForPrefix, self).__init__()
        self.net = net
        self.x_mask = x_mask
        self.x_cond = x_cond

        assert not use_cfg, "cfg is not supported in the streaming detokenizer"

        self.use_cfg = use_cfg
        self.use_cfg_rescale = use_cfg_rescale
        self.cfg_init = cfg_init
        self.cfg_scale = cfg_scale
        self.cfg_token_id = cfg_token_id
        self.cfg_schedule = cfg_schedule
        self.position_ids = None
        self.seq_len = None

        self.incremental_state = {}
        self.kv_cache_tokens = 0
        self.cu_seqlens = None
        self.cu_maxlen = None

        self.cu_seqlens_k = None
        self.cu_maxlen_k = None
        self.previous_seqlen = None

    def clear_all_states(self):
        self.incremental_state = {}
        self.kv_cache_tokens = 0
        self.cu_seqlens = None
        self.cu_maxlen = None

        self.cu_seqlens_k = None
        self.cu_maxlen_k = None
        self.previous_seqlen = None
    
    def state_dict(self):
        return {
            "incremental_state": copy.deepcopy(self.incremental_state),
            "kv_cache_tokens": copy.deepcopy(self.kv_cache_tokens),
            "cu_seqlens": copy.deepcopy(self.cu_seqlens),
            "cu_maxlen": copy.deepcopy(self.cu_maxlen),
            "cu_seqlens_k": copy.deepcopy(self.cu_seqlens_k),
            "cu_maxlen_k": copy.deepcopy(self.cu_maxlen_k),
            "previous_seqlen": copy.deepcopy(self.previous_seqlen)
        }
    
    def load_state_dict(self, state_dict):
        self.incremental_state = state_dict["incremental_state"]
        self.kv_cache_tokens = state_dict["kv_cache_tokens"]
        self.cu_seqlens = state_dict["cu_seqlens"]
        self.cu_maxlen = state_dict["cu_maxlen"]
        self.cu_seqlens_k = state_dict["cu_seqlens_k"]
        self.cu_maxlen_k = state_dict["cu_maxlen_k"]
        self.previous_seqlen = state_dict["previous_seqlen"]

    def set_conditions(self, x_mask, x_cond, start_position_id, cache={}):
        if not self.use_cfg:
            self.x_mask = x_mask
            self.x_cond = x_cond
        else:
            self.x_cond = torch.cat((x_cond, x_cond), dim=0)
            self.x_mask = torch.cat((x_mask, x_mask), dim=0)

        position_ids = torch.arange(start_position_id, start_position_id + self.x_cond.shape[1]).unsqueeze(0)


        if not self.use_cfg:
            self.position_ids = position_ids.to(self.x_cond.device).long()
            self.seq_len = torch.Tensor([position_ids.shape[1]]).to(self.x_cond.device).long()
        else:
            self.position_ids = torch.cat((position_ids, position_ids), dim=0).to(self.x_cond.device).long()
            self.seq_len = torch.Tensor([position_ids.shape[1], position_ids.shape[1]]).to(self.x_cond.device).long()

        cu_seqlens = torch.cumsum(self.seq_len, dim=0)
        self.cu_seqlens = torch.cat([torch.Tensor([0]).to(cu_seqlens.device), cu_seqlens], dim=0).int()
        self.cu_maxlen = self.seq_len.cpu().max()

        if self.cu_seqlens_k is None:
            self.cu_seqlens_k = self.cu_seqlens
            self.cu_maxlen_k = self.cu_maxlen
            previous_seqlen = self.seq_len
        else:
            previous_seqlen_old = cache["previous_seqlen"]
            previous_seqlen = previous_seqlen_old + self.seq_len
            # calculate cu_seqlens_k
            cu_seqlens_k = torch.cumsum(previous_seqlen, dim=0)
            self.cu_seqlens_k = torch.cat([torch.Tensor([0]).to(cu_seqlens_k.device), cu_seqlens_k], dim=0).int()
            self.cu_maxlen_k = previous_seqlen.cpu().max()
        self.previous_seqlen = previous_seqlen
        ret_cache = {
            "previous_seqlen": previous_seqlen
        }
        return ret_cache

    def update_incremental_state(self, reserve_kv_cache_tokens=0, max_kv_cache_tokens=900, condition_cache=None):
        if condition_cache is None:
            condition_cache = {}

        assert reserve_kv_cache_tokens <= max_kv_cache_tokens, "reserve_kv_cache_tokens should be less than or equal to max_kv_cache_tokens"

        for layer_idx, layer_cache in self.incremental_state.items():
            # update attention kv cache
            layer_cache["attn_kvcache"]["prev_k"] = layer_cache["attn_kvcache"]["cur_k"]
            layer_cache["attn_kvcache"]["prev_v"] = layer_cache["attn_kvcache"]["cur_v"]

            self.kv_cache_tokens = layer_cache["attn_kvcache"]["prev_k"].shape[1]

            if self.kv_cache_tokens > max_kv_cache_tokens:
                # cache overflow: keep the first reserve_kv_cache_tokens (prompt)
                # and the most recent tokens, dropping the oldest ones in between
                reserve_tokens_excludeprompt = max_kv_cache_tokens - reserve_kv_cache_tokens

                if reserve_kv_cache_tokens == 0:
                    layer_cache["attn_kvcache"]["prev_k"] = layer_cache["attn_kvcache"]["prev_k"][:, -reserve_tokens_excludeprompt:]
                    layer_cache["attn_kvcache"]["prev_v"] = layer_cache["attn_kvcache"]["prev_v"][:, -reserve_tokens_excludeprompt:]
                elif reserve_tokens_excludeprompt == 0:
                    layer_cache["attn_kvcache"]["prev_k"] = layer_cache["attn_kvcache"]["prev_k"][:, :reserve_kv_cache_tokens]
                    layer_cache["attn_kvcache"]["prev_v"] = layer_cache["attn_kvcache"]["prev_v"][:, :reserve_kv_cache_tokens]
                else:
                    layer_cache["attn_kvcache"]["prev_k"] = torch.cat([
                            layer_cache["attn_kvcache"]["prev_k"][:, :reserve_kv_cache_tokens],
                            layer_cache["attn_kvcache"]["prev_k"][:, -reserve_tokens_excludeprompt:]
                        ], dim=1)
                    
                    layer_cache["attn_kvcache"]["prev_v"] = torch.cat([
                            layer_cache["attn_kvcache"]["prev_v"][:, :reserve_kv_cache_tokens],
                            layer_cache["attn_kvcache"]["prev_v"][:, -reserve_tokens_excludeprompt:]
                        ], dim=1)


                bsz = layer_cache["attn_kvcache"]["prev_k"].shape[0]
                self.previous_seqlen = torch.Tensor([layer_cache["attn_kvcache"]["prev_k"].shape[1] for i in range(bsz)]).to(layer_cache["attn_kvcache"]["prev_k"].device).long()
                condition_cache["previous_seqlen"] = self.previous_seqlen
                self.kv_cache_tokens = layer_cache["attn_kvcache"]["prev_k"].shape[1]

            # clear current cache
            layer_cache["attn_kvcache"].pop("cur_k")
            layer_cache["attn_kvcache"].pop("cur_v")


    def forward(self, t, x, args=None):
        t = get_cached_zeros(x.shape[0], device=x.device, dtype=torch.long) + (t * 1000).long()

        if self.use_cfg:
            raise NotImplementedError("cfg is not supported in streaming detokenizer.")
        else:
            pred_noise = self.net(x=x, condition=self.x_cond, t=t, position_ids=self.position_ids, 
                                  cu_seqlens=self.cu_seqlens, cu_maxlen=self.cu_maxlen,
                                  cu_seqlens_k=self.cu_seqlens_k, cu_maxlen_k=self.cu_maxlen_k,
                                  incremental_state=self.incremental_state, nopadding=True,
                                  mask=None, seq_len=None
                                  )   
            return pred_noise
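

# Editor's note (not part of the repo): cu_seqlens bookkeeping sketch for one
# sequence per batch. A fresh 150-token chunk arriving after a 300-token kv
# cache yields
#   cu_seqlens   = [0, 150]   (queries: current chunk only)
#   cu_seqlens_k = [0, 450]   (keys: cache + current chunk)
# which is what set_conditions derives from previous_seqlen.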


================================================
FILE: modules/audio_detokenizer/flow_matching/scheduler.py
================================================
import torch
from abc import abstractmethod, ABC
try:
    from torchdyn.core import NeuralODE
    NEURALODE_INSTALLED = True
except ImportError:
    NEURALODE_INSTALLED = False

class SchedulerBase(ABC):
    def __init__(self) -> None:
        pass
    
    @abstractmethod
    def set_timesteps(self):
        pass
    
    @abstractmethod
    def step(self):
        pass

    @abstractmethod
    def add_noise(self):
        pass


class StreamingFlowMatchingScheduler(SchedulerBase):
    def __init__(self, timesteps=1000, sigma_min=1e-4) -> None:
        super().__init__()

        self.sigma_min = sigma_min
        self.timesteps = timesteps
        self.t_min = 0
        self.t_max = 1 - self.sigma_min

        self.neural_ode = None

    
    def set_timesteps(self, timesteps=15):
        self.timesteps = timesteps

    def step(self, xt, predicted_v):

        h = (self.t_max - self.t_min) / self.timesteps
        h = h * torch.ones(xt.shape[0], dtype=xt.dtype, device=xt.device)

        xt = xt + h * predicted_v
        return xt
    
    def sample(self, ode_wrapper, time_steps, xt, verbose=False, x0=None):
        h = (self.t_max - self.t_min) / self.timesteps
        h = h * torch.ones(xt.shape[0], dtype=xt.dtype, device=xt.device)

        if verbose:
            gt_v = x0 - xt

        for t in time_steps:
            predicted_v = ode_wrapper(t, xt)
            if verbose:
                dist = torch.mean(torch.nn.functional.l1_loss(gt_v, predicted_v))
                print("Time: {}, Distance: {}".format(t, dist))
            xt = xt + h * predicted_v
        return xt
    
    def sample_by_neuralode(self, ode_wrapper, time_steps, xt, verbose=False, x0=None):
        if not NEURALODE_INSTALLED:
            raise ImportError("NeuralODE is not installed, please install it first.")
        
        if self.neural_ode is None:
            self.neural_ode = NeuralODE(ode_wrapper, solver='euler', sensitivity="adjoint", atol=self.sigma_min, rtol=self.sigma_min)

        eval_points, traj = self.neural_ode(xt, time_steps)
        return traj[-1]

 
    def add_noise(self, original_samples: torch.FloatTensor,
                        noise: torch.FloatTensor,
                        timesteps: torch.IntTensor,):
        ut = original_samples - (1 - self.sigma_min) * noise  # unrelated to the gradient of ut
        t_unsqueeze = timesteps.unsqueeze(1).unsqueeze(1).float() / self.timesteps
        x_noisy = t_unsqueeze * original_samples + (1. - (1 - self.sigma_min) * t_unsqueeze) * noise
        return x_noisy, ut
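
# Editor's sketch (not part of the repo): the naive Euler sampler above on a
# toy vector field. `toy_field` stands in for StreamingODEWrapperForPrefix;
# values are for demonstration only.
#
# sched = StreamingFlowMatchingScheduler()
# sched.set_timesteps(10)
# t_span = torch.linspace(0, 1, sched.timesteps)
# x0 = torch.randn(1, 80)
# toy_field = lambda t, x: -x          # pull the state toward zero
# x1 = sched.sample(toy_field, t_span, x0)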


================================================
FILE: modules/audio_detokenizer/semantic_fm_prefix_streaming.py
================================================
import yaml
import logging
import time

import os
import torch

from modules.audio_detokenizer.flow_matching.ode_wrapper import StreamingODEWrapperForPrefix
from modules.audio_detokenizer.flow_matching.model import DiTPrefix
from modules.audio_detokenizer.flow_matching.scheduler import StreamingFlowMatchingScheduler


logger = logging.getLogger(__name__)


class StreamingSemanticFMWrapper:
    def __init__(self, speech_model: DiTPrefix, max_kv_cache_tokens=900, max_prompt_chunk=2,
                 use_cfg=True, use_cfg_rescale=True, cfg_init=1.5, cfg_scale=7.5, cfg_schedule="linear", cfg_token_id=0, 
                 normalize_mel=False, mel_mean=None, mel_std=None, device: torch.device = torch.device("cpu")) -> None:
        
        self.dtype = torch.bfloat16
        self.speech_model = speech_model.to(device).to(self.dtype)
        self.speech_model = self.speech_model.eval()
        self.device = device
        self.normalize_mel = normalize_mel
        self.mel_mean = mel_mean
        self.mel_std = mel_std

        self.use_cfg = use_cfg
        self.use_cfg_rescale = use_cfg_rescale
        self.cfg_init = cfg_init
        self.cfg_scale = cfg_scale
        self.cfg_schedule = cfg_schedule
        
        self.incremental_state = {}
        self.condition_cache = {"previous_seqlen": 0}

        logger.info(f">>> SemanticFMWrapper initialized with use_cfg={use_cfg}, use_cfg_rescale={use_cfg_rescale}, cfg_init={cfg_init}, cfg_scale={cfg_scale}, cfg_schedule={cfg_schedule}")

        self.scheduler = StreamingFlowMatchingScheduler()
        self.ode_wrapper = StreamingODEWrapperForPrefix(net=self.speech_model, x_mask=None, x_cond=None,
                                      use_cfg=use_cfg, use_cfg_rescale=use_cfg_rescale, cfg_init=cfg_init, cfg_scale=cfg_scale, cfg_schedule=cfg_schedule, cfg_token_id=cfg_token_id)
    
        self.max_kv_cache_tokens = max_kv_cache_tokens
        self.max_prompt_chunk = max_prompt_chunk
        self.reserve_kv_cache_tokens = 0

    @torch.inference_mode()
    def infer_chunk(self, xt_chunk, semantic_tokens_chunk, start_position_id, 
                    cache = None, look_ahead_tokens=0,
                    ode_steps=15, verbose=False, ode_solver="neural_ode_euler"):
        """
            semantic_tokens: [T_1], torch.LongTensor
            xt: [T_2, 80], torch.Tensor, DO NOT normalize it outside
            ode_steps: int, number of ode steps, default 15
            verbose: bool, default False
            ode_solver: str, ode solver, expected in ("neural_ode_euler", "naive_euler"), default "neural_ode_euler"
        """
        bs = 1

        self.scheduler.set_timesteps(ode_steps)

        semantic_tokens_chunk = semantic_tokens_chunk.unsqueeze(0).to(self.device)
        xt_chunk = xt_chunk.unsqueeze(0).to(self.device).to(self.dtype)

        t_span = torch.linspace(0, 1, self.scheduler.timesteps)

        x_mask = torch.zeros(bs, xt_chunk.shape[1], device=self.device).bool()
        
        cache_ret = self.ode_wrapper.set_conditions(x_mask=x_mask, x_cond=semantic_tokens_chunk, start_position_id=start_position_id, cache=self.condition_cache)

        if verbose:
            t_start = time.time()
        if ode_solver == "neural_ode_euler":
            x_t = self.scheduler.sample_by_neuralode(self.ode_wrapper, time_steps=t_span, xt=xt_chunk, verbose=False)
        elif ode_solver == "naive_euler":
            x_t = self.scheduler.sample(ode_wrapper=self.ode_wrapper, time_steps=t_span, xt=xt_chunk, verbose=False)
        else:
            raise NotImplementedError("ode_solver should be in ('neural_ode_euler', 'naive_euler')")
        
        if look_ahead_tokens > 0:
            semantic_tokens_left = semantic_tokens_chunk.view(-1)[-look_ahead_tokens:]
            cache["semantic_token"] = semantic_tokens_left
            x_t_ret = x_t[:, :-look_ahead_tokens, :]
        else:
            x_t_ret = x_t

        if look_ahead_tokens > 0:
            x_mask = torch.zeros(bs, xt_chunk.shape[1] - look_ahead_tokens, device=self.device).bool()
            self.condition_cache = self.ode_wrapper.set_conditions(x_mask=x_mask, x_cond=semantic_tokens_chunk[:, :-look_ahead_tokens], start_position_id=start_position_id, cache=self.condition_cache)
            self.ode_wrapper(torch.Tensor([0.999]).to(x_t_ret.device), x_t_ret)
        else:
            self.condition_cache = cache_ret

        if verbose:
            t_end = time.time()
            logger.info(f"[ODE Chunk] Time cost: {t_end - t_start}")

        if self.normalize_mel:
            x_t_ret = x_t_ret * self.mel_std + self.mel_mean
        return x_t_ret.squeeze(0)


    @torch.inference_mode()
    def infer_mel(self, semantic_tokens, ode_steps=15, chunk_size=150, verbose=False, ode_solver="neural_ode_euler"):
        """
            semantic_tokens: [T_1], torch.LongTensor
            prompt: [T_2, 80], torch.Tensor, DO NOT normalize it outside
            prompt_semantic_tokens, [T_2], torch.LongTensor
            ode_steps: int, number of ode steps, default 15
            verbose: bool, default False
            ode_solver: str, ode solver, expected in ("neural_ode_euler", "naive_euler"), default "neural_ode_euler"
        """
        assert semantic_tokens.dim() == 1

        x_t = torch.randn(semantic_tokens.shape[0], 80).to(self.device).to(self.dtype)

        seq_len = semantic_tokens.shape[0]

        num_chunks = seq_len // chunk_size
        if seq_len % chunk_size != 0:
            num_chunks += 1

        x_pred_collect = []

        if verbose:
            t_start = time.time()

        for chunk_id in range(num_chunks):
            start = chunk_id * chunk_size
            end = min(start + chunk_size, seq_len)
            semantic_tokens_chunk = semantic_tokens[start:end]
            x_t_chunk = x_t[start:end, :]

            x_pred = self.infer_chunk(xt_chunk=x_t_chunk, semantic_tokens_chunk=semantic_tokens_chunk, start_position_id=self.start_position_id,
                                      ode_steps=ode_steps, verbose=verbose, ode_solver=ode_solver)
            self.start_position_id += end - start
            self.update_incremental_state()

            x_pred_collect.append(x_pred)

        if verbose:
            t_end = time.time()
            logger.info(f"[ODE] Time cost: {t_end - t_start}")
        
        x_pred = torch.cat(x_pred_collect, dim=0)

        return x_pred
    
    def clear_all_states(self):
        self.start_position_id = 0
        self.condition_cache = {"previous_seqlen": 0}
        self.ode_wrapper.clear_all_states()
    
    def state_dict(self):
        return {
            "start_position_id": self.start_position_id,
            "ode_wrapper": self.ode_wrapper.state_dict(),
            "condition_cache": self.condition_cache
        }
    
    def load_state_dict(self, state_dict):
        if state_dict is not None:
            self.start_position_id = state_dict["start_position_id"]
            self.ode_wrapper.load_state_dict(state_dict["ode_wrapper"])
            self.condition_cache = state_dict["condition_cache"]
    
    def update_incremental_state(self):
        self.ode_wrapper.update_incremental_state(reserve_kv_cache_tokens=0, max_kv_cache_tokens=self.max_kv_cache_tokens, condition_cache=self.condition_cache)
    
    @torch.inference_mode()
    def prefill(self, mel, semantic_token, chunk_size=150, verbose=False):
        """
            mel: [T, 80], torch.Tensor
            semantic_token: [T], torch.LongTensor
            chunk_size: int, default 150
        """
        assert mel.dim() == 2
        assert semantic_token.dim() == 1
        assert semantic_token.shape[0] == mel.shape[0], "Semantic token and mel shape mismatch"
        seq_len = mel.shape[0]
        num_chunks = min(seq_len // chunk_size, self.max_prompt_chunk)
        start_pos = seq_len - num_chunks * chunk_size
        
        res_mel = mel[:start_pos, :]
        res_semantic_token = semantic_token[:start_pos]
        self.prefill_chunk(res_mel, res_semantic_token, start_position_id=self.start_position_id)
        self.start_position_id += start_pos
        self.update_incremental_state()
        self.reserve_kv_cache_tokens += self.ode_wrapper.kv_cache_tokens

        if verbose:
            logger.info("Prefilling prompt with {} chunks".format(num_chunks))
            start_time = time.time()

        for chunk_id in range(num_chunks):
            start = start_pos + chunk_id * chunk_size
            end = start + chunk_size
            mel_chunk = mel[start:end, :]
            semantic_token_chunk = semantic_token[start:end]

            self.prefill_chunk(mel_chunk, semantic_token_chunk, start_position_id=self.start_position_id)
            self.start_position_id += end - start
            
            self.update_incremental_state()
            self.reserve_kv_cache_tokens += self.ode_wrapper.kv_cache_tokens
        
        
        if verbose:
            logger.info("Prefilling done in {:.2f} seconds".format(time.time() - start_time))
    
    def prefill_chunk(self, mel_chunk, semantic_tokens_chunk, start_position_id=0):
        """
            mel_chunk: [T, 80], torch.Tensor, T is the chunk size
            semantic_tokens_chunk: [T], torch.LongTensor
            start_position_id: int, default 0
        """
        bs = 1

        semantic_tokens_chunk = semantic_tokens_chunk.unsqueeze(0).to(self.device)
        mel_chunk = mel_chunk.unsqueeze(0).to(self.device).to(self.dtype)

        if self.normalize_mel:
            mel_chunk = (mel_chunk - self.mel_mean) / self.mel_std

        x_mask = torch.zeros(bs, mel_chunk.shape[1], device=self.device).bool()
        
        self.condition_cache = self.ode_wrapper.set_conditions(x_mask=x_mask, x_cond=semantic_tokens_chunk, start_position_id=start_position_id, cache=self.condition_cache)

        x_t = torch.Tensor([0.999]).to(self.device)

        self.ode_wrapper(x_t, mel_chunk)

        
    @classmethod
    def from_pretrained(cls, model_config, ckpt_path, device, max_prompt_chunk=2, max_kv_cache_tokens=900, use_cfg=True, use_cfg_rescale=True, cfg_init=1.5, cfg_scale=7.5, cfg_schedule="linear"):

        # open yaml file
        with open(model_config, 'r') as f:
            config = yaml.safe_load(f)
        model_config = config["model"]["dit"]
        dit = DiTPrefix(
            input_size=model_config["input_size"],
            semantic_vocab_size=model_config["semantic_vocab_size"] + 1,
            hidden_size=model_config["hidden_size"],
            depth=model_config["depth"],
            num_heads=model_config["num_heads"],
            mlp_ratio=model_config["mlp_ratio"],
            ffn_type=model_config.get("ffn_type", "conv1d_conv1d"),
            ffn_gated_glu=model_config.get("ffn_gated_glu", True),
            ffn_act_layer=model_config.get("ffn_act_layer", "gelu"),
            ffn_conv_kernel_size=model_config.get("ffn_conv_kernel_size", 5),

            use_rope=model_config.get("use_rope", False),
            rope_params=model_config.get("rope_params", { "max_position_embeddings": 4096,"rope_base": 10000,"rope_interpolation_factor": 1 }),

            position_embedding_type=model_config["position_embedding_type"],
            max_seq_len=model_config["max_seq_len"],
            output_size=model_config["input_size"],
            prompt_cfg_dropout=0
        )
        cfg_semantic_token_id = model_config["semantic_vocab_size"]
        
        # load state_dict
        state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)["state_dict"]
        speech_model_params = {k.replace("speech_model.", ""): v for k, v in state_dict.items() if "speech_model" in k}
        dit.load_state_dict(speech_model_params, strict=True)
        logger.info(f">>> Loaded checkpoint from {ckpt_path}")

        return cls(speech_model=dit, device=device, normalize_mel=config["normalize_mel"], mel_mean=config["mel_mean"], mel_std=config["mel_std"], max_prompt_chunk=max_prompt_chunk, max_kv_cache_tokens=max_kv_cache_tokens,
                   use_cfg=use_cfg, use_cfg_rescale=use_cfg_rescale, cfg_init=cfg_init, cfg_scale=cfg_scale, cfg_schedule=cfg_schedule, cfg_token_id=cfg_semantic_token_id)
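# Usage sketch (illustrative, not part of the original file). `FMModel` is a
# hypothetical alias for this class, the config/ckpt paths are placeholders,
# and `mel` / `semantic_token` are assumed to be an aligned [T, 80] mel prompt
# and its [T] semantic codes:
#
#   model = FMModel.from_pretrained("config.yaml", "model.ckpt", device="cuda")
#   model.clear_all_states()
#   model.prefill(mel, semantic_token, chunk_size=150, verbose=True)
#   # Subsequent streaming generation then reuses the cached prompt KV state.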




================================================
FILE: modules/audio_detokenizer/vocoder/activations.py
================================================
import torch
from torch import nn, sin, pow
from torch.nn import Parameter


class Snake(nn.Module):
    """
    Implementation of a sine-based periodic activation function
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter
    References:
        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = Snake(256)
        >>> x = torch.randn(1, 256, 50)  # (B, C, T)
        >>> x = a1(x)
    """

    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
    ):
        """
        Initialization.
        INPUT:
            - in_features: number of input channels (the C in (B, C, T))
            - alpha: trainable parameter
            alpha is initialized to 1 by default, higher values = higher-frequency.
            alpha will be trained along with the rest of your model.
        """
        super(Snake, self).__init__()
        self.in_features = in_features

        # Initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # Log scale alphas initialized to zeros
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
        else:  # Linear scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        """
        Forward pass of the function.
        Applies the function to the input elementwise.
        Snake(x) := x + (1/a) * sin^2(a * x)
        """
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # Line up with x to [B, C, T]
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
        x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x
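# Quick shape/formula check for Snake (illustrative, not part of the original file):
#
#   act = Snake(in_features=4)
#   x = torch.randn(2, 4, 8)                       # (B, C, T)
#   y = act(x)
#   assert y.shape == x.shape                      # elementwise, shape-preserving
#   expected = x + torch.sin(x) ** 2               # alpha == 1 at init (up to the eps term)
#   assert torch.allclose(y, expected, atol=1e-4)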


class SnakeBeta(nn.Module):
    """
    A modified Snake function which uses separate parameters for the magnitude of the periodic components
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter that controls frequency
        - beta - trainable parameter that controls magnitude
    References:
        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = SnakeBeta(256)
        >>> x = torch.randn(1, 256, 50)  # (B, C, T)
        >>> x = a1(x)
    """

    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
    ):
        """
        Initialization.
        INPUT:
            - in_features: number of input channels (the C in (B, C, T))
            - alpha: initial value for the trainable parameters
            alpha (frequency) is initialized to 1 by default; higher values = higher frequency.
            beta (magnitude) is initialized to the same value as alpha; higher values = higher magnitude.
            alpha and beta will be trained along with the rest of your model.
        """
        super(SnakeBeta, self).__init__()
        self.in_features = in_features

        # Initialize alpha and beta
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # Log scale alphas initialized to zeros
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
            self.beta = Parameter(torch.zeros(in_features) * alpha)
        else:  # Linear scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)
            self.beta = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        """
        Forward pass of the function.
        Applies the function to the input elementwise.
        SnakeBeta(x) := x + (1/b) * sin^2(a * x)
        """
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # Line up with x to [B, C, T]
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x


================================================
FILE: modules/audio_detokenizer/vocoder/alias_free_activation/__init__.py
================================================


================================================
FILE: modules/audio_detokenizer/vocoder/alias_free_activation/cuda/__init__.py
================================================


================================================
FILE: modules/audio_detokenizer/vocoder/alias_free_activation/cuda/activation1d.py
================================================
# Copyright (c) 2024 NVIDIA CORPORATION.
#   Licensed under the MIT license.

import torch
import torch.nn as nn
from ..torch.resample import UpSample1d, DownSample1d

# load fused CUDA kernel: this enables importing anti_alias_activation_cuda
from modules.audio_detokenizer.vocoder.alias_free_activation.cuda import load

anti_alias_activation_cuda = load.load()


class FusedAntiAliasActivation(torch.autograd.Function):
    """
    Assumes filter size 12, replication padding on upsampling/downsampling, and logscale alpha/beta parameters as inputs.
    The hyperparameters are hard-coded in the kernel to maximize speed.
    NOTE: The fused kernel is incorrect for Activation1d with different hyperparameters.
    """

    @staticmethod
    def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta):
        activation_results = anti_alias_activation_cuda.forward(
            inputs, up_ftr, down_ftr, alpha, beta
        )

        return activation_results

    @staticmethod
    def backward(ctx, output_grads):
        # Inference-only kernel: the backward pass is intentionally unimplemented.
        raise NotImplementedError


class Activation1d(nn.Module):
    def __init__(
        self,
        activation,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
        fused: bool = True,
    ):
        super().__init__()
        self.up_ratio = up_ratio
        self.down_ratio = down_ratio
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

        self.fused = fused  # Whether to use fused CUDA kernel or not

    def forward(self, x):
        if not self.fused:
            x = self.upsample(x)
            x = self.act(x)
            x = self.downsample(x)
            return x
        else:
            if self.act.__class__.__name__ == "Snake":
                beta = self.act.alpha.data  # Snake uses same params for alpha and beta
            else:
                beta = (
                    self.act.beta.data
                )  # Snakebeta uses different params for alpha and beta
            alpha = self.act.alpha.data
            if (
                not self.act.alpha_logscale
            ):  # Exp baked into cuda kernel, cancel it out with a log
                alpha = torch.log(alpha)
                beta = torch.log(beta)

            x = FusedAntiAliasActivation.apply(
                x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta
            )
            return x
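# Rough parity check between the fused CUDA path and the pure-torch path
# (illustrative, not part of the original file; assumes a CUDA device, a
# successfully built kernel, and log-scale SnakeBeta parameters, which the
# fused kernel hard-codes):
#
#   from modules.audio_detokenizer.vocoder.activations import SnakeBeta
#   act = SnakeBeta(64, alpha_logscale=True).cuda()
#   fused = Activation1d(act, fused=True).cuda()
#   ref = Activation1d(act, fused=False).cuda()
#   x = torch.randn(1, 64, 4096, device="cuda")
#   torch.testing.assert_close(fused(x), ref(x), rtol=1e-3, atol=1e-3)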


================================================
FILE: modules/audio_detokenizer/vocoder/alias_free_activation/cuda/anti_alias_activation.cpp
================================================
/* coding=utf-8
 * Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

 #include <torch/extension.h>

extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &fwd_cuda, "Anti-Alias Activation forward (CUDA)");
}

================================================
FILE: modules/audio_detokenizer/vocoder/alias_free_activation/cuda/anti_alias_activation_cuda.cu
================================================
/* coding=utf-8
 * Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "type_shim.h"
#include <assert.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <c10/macros/Macros.h>

namespace
{
    // Hard-coded hyperparameters
    constexpr int ELEMENTS_PER_LDG_STG = 1; //(WARP_ITERATIONS < 4) ? 1 : 4;
    constexpr int BUFFER_SIZE = 32;
    constexpr int FILTER_SIZE = 12;
    constexpr int HALF_FILTER_SIZE = 6;
    constexpr int UPSAMPLE_REPLICATION_PAD = 5; // 5 on each side, matching torch impl
    constexpr int DOWNSAMPLE_REPLICATION_PAD_LEFT = 5; // matching torch impl
    constexpr int DOWNSAMPLE_REPLICATION_PAD_RIGHT = 6; // matching torch impl

    template <typename input_t, typename output_t, typename acc_t>
    __global__ void anti_alias_activation_forward(
        output_t *dst,
        const input_t *src,
        const input_t *up_ftr,
        const input_t *down_ftr,
        const input_t *alpha,
        const input_t *beta,
        int batch_size,
        int channels,
        int seq_len)
    {
        // Up and downsample filters
        input_t up_filter[FILTER_SIZE];
        input_t down_filter[FILTER_SIZE];

        // Load data from global memory including extra indices reserved for replication paddings
        input_t elements[2 * FILTER_SIZE + 2 * BUFFER_SIZE + 2 * UPSAMPLE_REPLICATION_PAD] = {0};
        input_t intermediates[2 * FILTER_SIZE + 2 * BUFFER_SIZE + DOWNSAMPLE_REPLICATION_PAD_LEFT + DOWNSAMPLE_REPLICATION_PAD_RIGHT] = {0};

        // Output stores downsampled output before writing to dst
        output_t output[BUFFER_SIZE];

        // blockDim/threadIdx = (128, 1, 1)
        // gridDim/blockIdx = (seq_blocks, channels, batches)
        int block_offset = (blockIdx.x * 128 * BUFFER_SIZE + seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
        int local_offset = threadIdx.x * BUFFER_SIZE;
        int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset;

        // intermediate have double the seq_len
        int intermediate_local_offset = threadIdx.x * BUFFER_SIZE * 2;
        int intermediate_seq_offset = blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_local_offset;

        // Get values needed for replication padding before moving pointer
        const input_t *right_most_pntr = src + (seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
        input_t seq_left_most_value = right_most_pntr[0];
        input_t seq_right_most_value = right_most_pntr[seq_len - 1];

        // Move src and dst pointers
        src += block_offset + local_offset;
        dst += block_offset + local_offset;

        // Alpha and beta values for snake activations. Applies exp by default
        alpha = alpha + blockIdx.y;
        input_t alpha_val = expf(alpha[0]);
        beta = beta + blockIdx.y;
        input_t beta_val = expf(beta[0]);

        #pragma unroll
        for (int it = 0; it < FILTER_SIZE; it += 1)
        {
            up_filter[it] = up_ftr[it];
            down_filter[it] = down_ftr[it];
        }

        // Apply replication padding for upsampling, matching torch impl
        #pragma unroll
        for (int it = -HALF_FILTER_SIZE; it < BUFFER_SIZE + HALF_FILTER_SIZE; it += 1)
        {
            int element_index = seq_offset + it; // index for element
            if ((element_index < 0) && (element_index >= -UPSAMPLE_REPLICATION_PAD))
            {
                elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_left_most_value;
            }
            if ((element_index >= seq_len) && (element_index < seq_len + UPSAMPLE_REPLICATION_PAD))
            {
                elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_right_most_value;
            }
            if ((element_index >= 0) && (element_index < seq_len))
            {
                elements[2 * (HALF_FILTER_SIZE + it)] = 2 * src[it];
            }
        }

        // Apply upsampling strided convolution and write to intermediates. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT for replication padding of the downsampling conv later
        #pragma unroll
        for (int it = 0; it < (2 * BUFFER_SIZE + 2 * FILTER_SIZE); it += 1)
        {
            input_t acc = 0.0;
            int element_index = intermediate_seq_offset + it; // index for intermediate
            #pragma unroll
            for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
            {
                if ((element_index + f_idx) >= 0)
                {
                    acc += up_filter[f_idx] * elements[it + f_idx];
                }
            }
            intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] = acc;
        }

        // Apply activation function. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT and DOWNSAMPLE_REPLICATION_PAD_RIGHT for replication padding of the downsampling conv later
        double no_div_by_zero = 0.000000001;
        #pragma unroll
        for (int it = 0; it < 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it += 1)
        {
            intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] += (1.0 / (beta_val + no_div_by_zero)) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val);
        }

        // Apply replication padding before downsampling conv from intermediates
        #pragma unroll
        for (int it = 0; it < DOWNSAMPLE_REPLICATION_PAD_LEFT; it += 1)
        {
            intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT];
        }
        #pragma unroll
        for (int it = DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it < DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE + DOWNSAMPLE_REPLICATION_PAD_RIGHT; it += 1)
        {
            intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE - 1];
        }

        // Apply downsample strided convolution (assuming stride=2) from intermediates
        #pragma unroll
        for (int it = 0; it < BUFFER_SIZE; it += 1)
        {
            input_t acc = 0.0;
            #pragma unroll
            for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
            {
                // Add constant DOWNSAMPLE_REPLICATION_PAD_RIGHT to match torch implementation
                acc += down_filter[f_idx] * intermediates[it * 2 + f_idx + DOWNSAMPLE_REPLICATION_PAD_RIGHT];
            }
            output[it] = acc;
        }

        // Write output to dst
        #pragma unroll
        for (int it = 0;  it < BUFFER_SIZE;  it += ELEMENTS_PER_LDG_STG)
        {
            int element_index = seq_offset + it;
            if (element_index < seq_len)
            {
                dst[it] = output[it];
            }
        }

    }

    template <typename input_t, typename output_t, typename acc_t>
    void dispatch_anti_alias_activation_forward(
        output_t *dst,
        const input_t *src,
        const input_t *up_ftr,
        const input_t *down_ftr,
        const input_t *alpha,
        const input_t *beta,
        int batch_size,
        int channels,
        int seq_len)
    {
        if (seq_len == 0)
        {
            return;
        }
        else
        {
            // Use 128 threads per block to maximize GPU utilization
            constexpr int threads_per_block = 128;
            constexpr int seq_len_per_block = 4096;
            int blocks_per_seq_len = (seq_len + seq_len_per_block - 1) / seq_len_per_block;
            dim3 blocks(blocks_per_seq_len, channels, batch_size);
            dim3 threads(threads_per_block, 1, 1);

            anti_alias_activation_forward<input_t, output_t, acc_t>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, up_ftr, down_ftr, alpha, beta, batch_size, channels, seq_len);
        }
    }
}

extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta)
{
    // Input is a 3d tensor with dimensions [batches, channels, seq_len]
    const int batches = input.size(0);
    const int channels = input.size(1);
    const int seq_len = input.size(2);

    // Output
    auto act_options = input.options().requires_grad(false);

    torch::Tensor anti_alias_activation_results =
        torch::empty({batches, channels, seq_len}, act_options);

    void *input_ptr = static_cast<void *>(input.data_ptr());
    void *up_filter_ptr = static_cast<void *>(up_filter.data_ptr());
    void *down_filter_ptr = static_cast<void *>(down_filter.data_ptr());
    void *alpha_ptr = static_cast<void *>(alpha.data_ptr());
    void *beta_ptr = static_cast<void *>(beta.data_ptr());
    void *anti_alias_activation_results_ptr = static_cast<void *>(anti_alias_activation_results.data_ptr());

    DISPATCH_FLOAT_HALF_AND_BFLOAT(
        input.scalar_type(),
        "dispatch anti alias activation_forward",
        dispatch_anti_alias_activation_forward<scalar_t, scalar_t, float>(
            reinterpret_cast<scalar_t *>(anti_alias_activation_results_ptr),
            reinterpret_cast<const scalar_t *>(input_ptr),
            reinterpret_cast<const scalar_t *>(up_filter_ptr),
            reinterpret_cast<const scalar_t *>(down_filter_ptr),
            reinterpret_cast<const scalar_t *>(alpha_ptr),
            reinterpret_cast<const scalar_t *>(beta_ptr),
            batches,
            channels,
            seq_len););
    return anti_alias_activation_results;
}

================================================
FILE: modules/audio_detokenizer/vocoder/alias_free_activation/cuda/compat.h
================================================
/* coding=utf-8
 * Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* This code is copied from NVIDIA apex:
 *     https://github.com/NVIDIA/apex
 *     with minor changes. */

#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif

#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif


================================================
FILE: modules/audio_detokenizer/vocoder/alias_free_activation/cuda/load.py
================================================
# Copyright (c) 2024 NVIDIA CORPORATION.
#   Licensed under the MIT license.

import os
import pathlib
import subprocess

from torch.utils import cpp_extension

"""
Setting this env var to a list can generate different compilation commands (with a different
order of architectures), leading to recompilation of the fused kernels. Set it to an empty
string to avoid recompilation, and assign the arch flags explicitly in extra_cuda_cflags below.
"""
os.environ["TORCH_CUDA_ARCH_LIST"] = ""


def load():
    # Check if cuda 11 is installed for compute capability 8.0
    cc_flag = []
    _, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
    if int(bare_metal_major) >= 11:
        cc_flag.append("-gencode")
        cc_flag.append("arch=compute_80,code=sm_80")

    # Build path
    srcpath = pathlib.Path(__file__).parent.absolute()
    buildpath = srcpath / "build"
    _create_build_dir(buildpath)

    # Helper function to build the kernels.
    def _cpp_extension_load_helper(name, sources, extra_cuda_flags):
        return cpp_extension.load(
            name=name,
            sources=sources,
            build_directory=buildpath,
            extra_cflags=[
                "-O3",
            ],
            extra_cuda_cflags=[
                "-O3",
                "-gencode",
                "arch=compute_70,code=sm_70",
                "--use_fast_math",
            ]
            + extra_cuda_flags
            + cc_flag,
            verbose=True,
        )

    extra_cuda_flags = [
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
    ]

    sources = [
        srcpath / "anti_alias_activation.cpp",
        srcpath / "anti_alias_activation_cuda.cu",
    ]
    anti_alias_activation_cuda = _cpp_extension_load_helper(
        "anti_alias_activation_cuda", sources, extra_cuda_flags
    )

    return anti_alias_activation_cuda


def _get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output(
        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
    )
    output = raw_output.split()
    release_idx = output.index("release") + 1
    release = output[release_idx].split(".")
    bare_metal_major = release[0]
    bare_metal_minor = release[1][0]

    return raw_output, bare_metal_major, bare_metal_minor
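# Example (illustrative): for an `nvcc -V` output containing
# "release 12.1, V12.1.105", this returns (<raw output>, "12", "1").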


def _create_build_dir(buildpath):
    try:
        os.mkdir(buildpath)
    except OSError:
        if not os.path.isdir(buildpath):
            print(f"Creation of the build directory {buildpath} failed")


================================================
FILE: modules/audio_detokenizer/vocoder/alias_free_activation/cuda/type_shim.h
================================================
/* coding=utf-8
 * Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <ATen/ATen.h>
#include "compat.h"

#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...)                 \
	switch (TYPE)                                                       \
	{                                                                   \
	case at::ScalarType::Float:                                         \
	{                                                                   \
		using scalar_t = float;                                         \
		__VA_ARGS__;                                                    \
		break;                                                          \
	}                                                                   \
	case at::ScalarType::Half:                                          \
	{                                                                   \
		using scalar_t = at::Half;                                      \
		__VA_ARGS__;                                                    \
		break;                                                          \
	}                                                                   \
	case at::ScalarType::BFloat16:                                      \
	{                                                                   \
		using scalar_t = at::BFloat16;                                  \
		__VA_ARGS__;                                                    \
		break;                                                          \
	}                                                                   \
	default:                                                            \
		AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
	}

#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
	switch (TYPEIN)                                                            \
	{                                                                          \
	case at::ScalarType::Float:                                                \
	{                                                                          \
		using scalar_t_in = float;                                             \
		switch (TYPEOUT)                                                       \
		{                                                                      \
		case at::ScalarType::Float:                                            \
		{                                                                      \
			using scalar_t_out = float;                                        \
			__VA_ARGS__;                                                       \
			break;                                                             \
		}                                                                      \
		case at::ScalarType::Half:                                             \
		{                                                                      \
			using scalar_t_out = at::Half;                                     \
			__VA_ARGS__;                                                       \
			break;                                                             \
		}                                                                      \
		case at::ScalarType::BFloat16:                                         \
		{                                                                      \
			using scalar_t_out = at::BFloat16;                                 \
			__VA_ARGS__;                                                       \
			break;                                                             \
		}                                                                      \
		default:                                                               \
			AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
		}                                                                      \
		break;                                                                 \
	}                                                                          \
	case at::ScalarType::Half:                                                 \
	{                                                                          \
		using scalar_t_in = at::Half;                                          \
		using scalar_t_out = at::Half;                                         \
		__VA_ARGS__;                                                           \
		break;                                                                 \
	}                                                                          \
	case at::ScalarType::BFloat16:                                             \
	{                                                                          \
		using scalar_t_in = at::BFloat16;                                      \
		using scalar_t_out = at::BFloat16;                                     \
		__VA_ARGS__;                                                           \
		break;                                                                 \
	}                                                                          \
	default:                                                                   \
		AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'");      \
	}


================================================
FILE: modules/audio_detokenizer/vocoder/alias_free_activation/torch/__init__.py
================================================
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
#   LICENSE is in incl_licenses directory.

from .filter import *
from .resample import *
from .act import *


================================================
FILE: modules/audio_detokenizer/vocoder/alias_free_activation/torch/act.py
================================================
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
#   LICENSE is in incl_licenses directory.

import torch.nn as nn
from .resample import UpSample1d, DownSample1d


class Activation1d(nn.Module):
    def __init__(
        self,
        activation,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
    ):
        super().__init__()
        self.up_ratio = up_ratio
        self.down_ratio = down_ratio
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    # x: [B,C,T]
    def forward(self, x):
        x = self.upsample(x)
        x = self.act(x)
        x = self.downsample(x)

        return x


================================================
FILE: modules/audio_detokenizer/vocoder/alias_free_activation/torch/filter.py
================================================
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
#   LICENSE is in incl_licenses directory.

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

if "sinc" in dir(torch):
    sinc = torch.sinc
else:
    # This code is adopted from adefossez's julius.core.sinc under the MIT License
    # https://adefossez.github.io/julius/julius/core.html
    #   LICENSE is in incl_licenses directory.
    def sinc(x: torch.Tensor):
        """
        Implementation of sinc, i.e. sin(pi * x) / (pi * x)
        __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
        """
        return torch.where(
            x == 0,
            torch.tensor(1.0, device=x.device, dtype=x.dtype),
            torch.sin(math.pi * x) / math.pi / x,
        )


# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html
#   LICENSE is in incl_licenses directory.
def kaiser_sinc_filter1d(
    cutoff, half_width, kernel_size
):  # return filter [1,1,kernel_size]
    even = kernel_size % 2 == 0
    half_size = kernel_size // 2

    # For kaiser window
    delta_f = 4 * half_width
    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if A > 50.0:
        beta = 0.1102 * (A - 8.7)
    elif A >= 21.0:
        beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
    else:
        beta = 0.0
    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)

    # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
    if even:
        time = torch.arange(-half_size, half_size) + 0.5
    else:
        time = torch.arange(kernel_size) - half_size
    if cutoff == 0:
        filter_ = torch.zeros_like(time)
    else:
        filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
        # Normalize the filter to have sum = 1, otherwise we will have a small
        # leakage of the constant component in the input signal.
        filter_ /= filter_.sum()
    # Reshape outside the else-branch so the cutoff == 0 case also returns a
    # valid (all-zero) [1, 1, kernel_size] filter instead of raising a NameError.
    filter = filter_.view(1, 1, kernel_size)

    return filter
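# Quick sanity check (illustrative, not part of the original file):
#
#   f = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
#   assert f.shape == (1, 1, 12)
#   assert abs(f.sum().item() - 1.0) < 1e-4   # normalized to pass DC unchanged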


class LowPassFilter1d(nn.Module):
    def __init__(
        self,
        cutoff=0.5,
        half_width=0.6,
        stride: int = 1,
        padding: bool = True,
        padding_mode: str = "replicate",
        kernel_size: int = 12,
    ):
        """
        kernel_size should be an even number for the StyleGAN3 setup; in this implementation, an odd number is also possible.
        """
        super().__init__()
        if cutoff < 0.0:
            raise ValueError("Minimum cutoff must be non-negative.")
        if cutoff > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")
        self.kernel_size = kernel_size
        self.even = kernel_size % 2 == 0
        self.pad_left = kernel_size // 2 - int(self.even)
        self.pad_right = kernel_size // 2
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
        self.register_buffer("filter", filter)

    # Input [B, C, T]
    def forward(self, x):
        _, C, _ = x.shape

        if self.padding:
            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
        out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)

        return out
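# Shape behavior (illustrative, not part of the original file):
#
#   lp = LowPassFilter1d(cutoff=0.25, half_width=0.1, kernel_size=12, stride=1)
#   x = torch.randn(1, 4, 64)
#   assert lp(x).shape == x.shape   # stride=1 with padding preserves length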


================================================
FILE: modules/audio_detokenizer/vocoder/alias_free_activation/torch/resample.py
================================================
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
#   LICENSE is in incl_licenses directory.

import torch.nn as nn
from torch.nn import functional as F
from .filter import LowPassFilter1d
from .filter import kaiser_sinc_filter1d


class UpSample1d(nn.Module):
    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        self.kernel_size = (
            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        )
        self.stride = ratio
        self.pad = self.kernel_size // ratio - 1
        self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
        self.pad_right = (
            self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
        )
        filter = kaiser_sinc_filter1d(
            cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
        )
        self.register_buffer("filter", filter)

    # x: [B, C, T]
    def forward(self, x):
        _, C, _ = x.shape

        x = F.pad(x, (self.pad, self.pad), mode="replicate")
        x = self.ratio * F.conv_transpose1d(
            x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
        )
        x = x[..., self.pad_left : -self.pad_right]

        return x


class DownSample1d(nn.Module):
    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        self.kernel_size = (
            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        )
        self.lowpass = LowPassFilter1d(
            cutoff=0.5 / ratio,
            half_width=0.6 / ratio,
            stride=ratio,
            kernel_size=self.kernel_size,
        )

    def forward(self, x):
        xx = self.lowpass(x)

        return xx
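# Round-trip shape behavior (illustrative, not part of the original file):
#
#   import torch
#   up, down = UpSample1d(ratio=2), DownSample1d(ratio=2)
#   x = torch.randn(1, 8, 100)
#   assert up(x).shape == (1, 8, 200)     # T is multiplied by the ratio
#   assert down(up(x)).shape == x.shape   # downsampling restores the original T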


================================================
FILE: modules/audio_detokenizer/vocoder/bigvgan.py
================================================
# Copyright (c) 2024 NVIDIA CORPORATION.
#   Licensed under the MIT license.

# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
#   LICENSE is in incl_licenses directory.

import os
import json
from pathlib import Path
from typing import Optional, Union, Dict

import torch
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import weight_norm, remove_weight_norm

from modules.audio_detokenizer.vocoder.activations import Snake, SnakeBeta
from modules.audio_detokenizer.vocoder.utils import init_weights, get_padding
from modules.audio_detokenizer.vocoder.alias_free_activation.torch.act import Activation1d as TorchActivation1d
from modules.audio_detokenizer.vocoder.utils import AttrDict

from huggingface_hub import PyTorchModelHubMixin, hf_hub_download


def load_hparams_from_json(path) -> AttrDict:
    with open(path) as f:
        data = f.read()
    return AttrDict(json.loads(data))


class AMPBlock1(torch.nn.Module):
    """
    AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
    AMPBlock1 has an additional self.convs2 containing Conv1d layers with a fixed dilation=1, each applied after the corresponding layer in self.convs1

    Args:
        h (AttrDict): Hyperparameters.
        channels (int): Number of convolution channels.
        kernel_size (int): Size of the convolution kernel. Default is 3.
        dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
        activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
    """

    def __init__(
        self,
        h: AttrDict,
        channels: int,
        kernel_size: int = 3,
        dilation: tuple = (1, 3, 5),
        activation: str = None,
    ):
        super().__init__()
        
        self.h = h

        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        stride=1,
                        dilation=d,
                        padding=get_padding(kernel_size, d),
                    )
                )
                for d in dilation
            ]
        )
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        stride=1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                )
                for _ in range(len(dilation))
            ]
        )
        self.convs2.apply(init_weights)

        self.num_layers = len(self.convs1) + len(
            self.convs2
        )  # Total number of conv layers

        # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
        if self.h.get("use_cuda_kernel", False):
            from modules.audio_detokenizer.vocoder.alias_free_activation.cuda.activation1d import (
                Activation1d as CudaActivation1d,
            )

            Activation1d = CudaActivation1d
        else:
            Activation1d = TorchActivation1d

        # Activation functions
        if activation == "snake":
            self.activations = nn.ModuleList(
                [
                    Activation1d(
                        activation=Snake(
                            channels, alpha_logscale=h.snake_logscale
                        )
                    )
                    for _ in range(self.num_layers)
                ]
            )
        elif activation == "snakebeta":
            self.activations = nn.ModuleList(
                [
                    Activation1d(
                        activation=SnakeBeta(
                            channels, alpha_logscale=h.snake_logscale
                        )
                    )
                    for _ in range(self.num_layers)
                ]
            )
        else:
            raise NotImplementedError(
                "activation incorrectly specified. check the config file and look for 'activation'."
            )

    def forward(self, x):
        acts1, acts2 = self.activations[::2], self.activations[1::2]
        for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
            xt = a1(x)
            xt = c1(xt)
            xt = a2(xt)
            xt = c2(xt)
            x = xt + x

        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)
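# Usage sketch for AMPBlock1 (illustrative, not part of the original file; the
# AttrDict fields shown are assumptions about a typical config):
#
#   from modules.audio_detokenizer.vocoder.utils import AttrDict
#   h = AttrDict({"snake_logscale": True, "use_cuda_kernel": False})
#   block = AMPBlock1(h, channels=32, kernel_size=3, dilation=(1, 3, 5), activation="snakebeta")
#   x = torch.randn(1, 32, 128)
#   assert block(x).shape == x.shape   # residual AMP block preserves (B, C, T)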


class AMPBlock2(torch.nn.Module):
    """
    AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
    Unlike AMPBlock1, AMPBlock2 does not contain extra Conv1d layers with fixed dilation=1

    Args:
        h (AttrDict): Hyperparameters.
        channels (int): Number of convolution channels.
        kernel_size (int): Size of the convolution kernel. Default is 3.
        dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
        activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
    """

    def __init__(
        self,
        h: AttrDict,
        channels: int,
        kernel_size: int = 3,
        dilation: tuple = (1, 3, 5),
        activation: str = None,
    ):
        super().__init__()
        
        self.h = h

        self.convs = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        stride=1,
                        dilation=d,
                        padding=get_padding(kernel_size, d),
                    )
                )
                for d in dilation
            ]
        )
        self.convs.apply(init_weights)

        self.num_layers = len(self.convs)  # Total number of conv layers

        # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
        if self.h.get("use_cuda_kernel", False):
            from modules.audio_detokenizer.vocoder.alias_free_activation.cuda.activation1d  import (
                Activation1d as CudaActivation1d,
            )

            Activation1d = CudaActivation1d
        else:
            Activation1d = TorchActivation1d

        # Activation functions
        if activation == "snake":
            self.activations = nn.ModuleList(
                [
                    Activation1d(
                        activation=Snake(
                            channels, alpha_logscale=h.snake_logscale
                        )
                    )
                    for _ in range(self.num_layers)
                ]
            )
        elif activation == "snakebeta":
            self.activations = nn.ModuleList(
                [
                    Activation1d(
                        activation=SnakeBeta(
                            channels, alpha_logscale=h.snake_logscale
                        )
                    )
                    for _ in range(self.num_layers)
                ]
            )
        else:
            raise NotImplementedError(
                "activation incorrectly specified. check the config file and look for 'activation'."
            )

    def forward(self, x):
        for c, a in zip(self.convs, self.activations):
            xt = a(x)
            xt = c(xt)
            x = xt + x

        return x

    def remove_weight_norm(self):
        for l in self.convs:
            remove_weight_norm(l)


class BigVGAN(
    torch.nn.Module,
    PyTorchModelHubMixin,
    library_name="bigvgan",
    repo_url="https://github.com/NVIDIA/BigVGAN",
    docs_url="https://github.com/NVIDIA/BigVGAN/blob/main/README.md",
    pipeline_tag="audio-to-audio",
    license="mit",
    tags=["neural-vocoder", "audio-generation", "arxiv:2206.04658"],
):
    """
    BigVGAN is a neural vocoder model that applies anti-aliased periodic activation for residual blocks (resblocks).
    New in BigVGAN-v2: it can optionally use optimized CUDA kernels for AMP (anti-aliased multi-periodicity) blocks.

    Args:
        h (AttrDict): Hyperparameters.
        use_cuda_kernel (bool): If set to True, loads optimized CUDA kernels for AMP. This should be used for inference only, as training is not supported with CUDA kernels.

    Note:
        - The `use_cuda_kernel` parameter should be used for inference only, as training with CUDA kernels is not supported.
        - Ensure that the activation function is correctly specified in the hyperparameters (h.activation).
    """

    def __init__(self, h: AttrDict, use_cuda_kernel: bool = False):
        super().__init__()
        self.h = h
        self.h["use_cuda_kernel"] = use_cuda_kernel

        # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
        if self.h.get("use_cuda_kernel", False):
            from modules.audio_detokenizer.vocoder.alias_free_activation.cuda.activation1d import (
                Activation1d as CudaActivation1d,
            )

            Activation1d = CudaActivation1d
        else:
            Activation1d = TorchActivation1d

        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)

        # Pre-conv
        self.conv_pre = weight_norm(
            Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
        )

        # Define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
        if h.resblock == "1":
            resblock_class = AMPBlock1
        elif h.resblock == "2":
            resblock_class = AMPBlock2
        else:
            raise ValueError(
                f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}"
            )

        # Transposed conv-based upsamplers; they do not apply anti-aliasing
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            self.ups.append(
                nn.ModuleList(
                    [
                        weight_norm(
                            ConvTranspose1d(
                                h.upsample_initial_channel // (2**i),
                                h.upsample_initial_channel // (2 ** (i + 1)),
                                k,
                                u,
                                padding=(k - u) // 2,
                            )
                        )
                    ]
                )
            )

        # Residual blocks using anti-aliased multi-periodicity composition modules (AMP)
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(
                zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
            ):
                self.resblocks.append(
                    resblock_class(h, ch, k, d, activation=h.activation)
                )

        # Post-conv
        activation_post = (
            Snake(ch, alpha_logscale=h.snake_logscale)
            if h.activation == "snake"
            else (
                SnakeBeta(ch, alpha_logscale=h.snake_logscale)
                if h.activation == "snakebeta"
                else None
            )
        )
        if activation_post is None:
            raise NotImplementedError(
                "activation incorrectly specified. check the config file and look for 'activation'."
            )

        self.activation_post = Activation1d(activation=activation_post)

        # Whether to use bias for the final conv_post. Default to True for backward compatibility
        self.use_bias_at_final = h.get("use_bias_at_final", True)
        self.conv_post = weight_norm(
            Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final)
        )

        # Weight initialization
        for i in range(len(self.ups)):
            self.ups[i].apply(init_weights)
        self.conv_post.apply(init_weights)

        # Final tanh activation. Defaults to True for backward compatibility
        self.use_tanh_at_final = h.get("use_tanh_at_final", True)

    def forward(self, x):
        # Pre-conv
        x = self.conv_pre(x)

        for i in range(self.num_upsamples):
            # Upsampling
            for i_up in range(len(self.ups[i])):
                x = self.ups[i][i_up](x)
            # AMP blocks
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        # Post-conv
        x = self.activation_post(x)
        x = self.conv_post(x)
        # Final tanh activation
        if self.use_tanh_at_final:
            x = torch.tanh(x)
        else:
            x = torch.clamp(x, min=-1.0, max=1.0)  # Bound the output to [-1, 1]

        return x

    def remove_weight_norm(self):
        try:
            print("Removing weight norm...")
            for l in self.ups:
                for l_i in l:
                    remove_weight_norm(l_i)
            for l in self.resblocks:
                l.remove_weight_norm()
            remove_weight_norm(self.conv_pre)
            remove_weight_norm(self.conv_post)
        except ValueError:
            print("[INFO] Model already removed weight norm. Skipping!")
            pass

    # Additional methods for huggingface_hub support
    def _save_pretrained(self, save_directory: Path) -> None:
        """Save weights and config.json from a Pytorch model to a local directory."""

        model_path = save_directory / "bigvgan_generator.pt"
        torch.save({"generator": self.state_dict()}, model_path)

        config_path = save_directory / "config.json"
        with open(config_path, "w") as config_file:
            json.dump(self.h, config_file, indent=4)

    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str,
        revision: str,
        cache_dir: str,
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: bool,
        local_files_only: bool,
        token: Union[str, bool, None],
        map_location: str = "cpu",  # Additional argument
        strict: bool = False,  # Additional argument
        use_cuda_kernel: bool = False,
        **model_kwargs,
    ):
        """Load Pytorch pretrained weights and return the loaded model."""

        # Download and load hyperparameters (h) used by BigVGAN
        if os.path.isdir(model_id):
            print("Loading config.json from local directory")
            config_file = os.path.join(model_id, "config.json")
        else:
            config_file = hf_hub_download(
                repo_id=model_id,
                filename="config.json",
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )
        h = load_hparams_from_json(config_file)

        # instantiate BigVGAN using h
        if use_cuda_kernel:
            print(
                "[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!"
            )
            print(
                "[WARNING] You need nvcc and ninja installed on your system, matching the CUDA version your PyTorch build uses, in order to build the kernel. Otherwise the model will fail to initialize or will generate incorrect waveforms!"
            )
            print(
                "[WARNING] For details, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis"
            )
        model = cls(h, use_cuda_kernel=use_cuda_kernel)

        # Download and load pretrained generator weight
        if os.path.isdir(model_id):
            print("Loading weights from local directory")
            model_file = os.path.join(model_id, "bigvgan_generator.pt")
        else:
            print(f"Loading weights from {model_id}")
            model_file = hf_hub_download(
                repo_id=model_id,
                filename="bigvgan_generator.pt",
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )

        checkpoint_dict = torch.load(model_file, map_location=map_location, weights_only=True)

        try:
            model.load_state_dict(checkpoint_dict["generator"])
        except RuntimeError:
            print(
                "[INFO] The pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
            )
            model.remove_weight_norm()
            model.load_state_dict(checkpoint_dict["generator"])

        return model
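# End-to-end usage sketch (illustrative, not part of the original file; the
# Hugging Face repo id is an example from the upstream BigVGAN project):
#
#   model = BigVGAN.from_pretrained("nvidia/bigvgan_v2_22khz_80band_256x")
#   model.remove_weight_norm()
#   model.eval()
#   mel = torch.randn(1, model.h.num_mels, 100)   # (B, num_mels, T_frames)
#   with torch.inference_mode():
#       wav = model(mel)                          # (B, 1, T_samples) in [-1, 1]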


================================================
FILE: modules/audio_detokenizer/vocoder/utils.py
================================================
from librosa.filters import mel as librosa_mel_fn
import torch
import os
mel_basis_cache = {}
hann_window_cache = {}

def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def spectral_normalize_torch(magnitudes):
    return dynamic_range_compression_torch(magnitudes)

def get_melspec(
    y: torch.Tensor,
    n_fft: int,
    num_mels: int,
    sampling_rate: int,
    hop_size: int,
    win_size: int,
    fmin: int,
    fmax: int = None,
    center: bool = False,
) -> torch.Tensor:
    """
    Calculate the mel spectrogram of an input signal.
    This function uses slaney norm for the librosa mel filterbank (using librosa.filters.mel) and uses Hann window for STFT (using torch.stft).

    Args:
        y (torch.Tensor): Input signal.
        n_fft (int): FFT size.
        num_mels (int): Number of mel bins.
        sampling_rate (int): Sampling rate of the input signal.
        hop_size (int): Hop size for STFT.
        win_size (int): Window size for STFT.
        fmin (int): Minimum frequency for mel filterbank.
        fmax (int): Maximum frequency for mel filterbank. If None, defaults to half the sampling rate (fmax = sr / 2.0) inside librosa_mel_fn
        center (bool): Whether to pad the input to center the frames. Default is False.

    Returns:
        torch.Tensor: Mel spectrogram.
    """
    if torch.min(y) < -1.0:
        print(f"[WARNING] Min value of input waveform signal is {torch.min(y)}")
    if torch.max(y) > 1.0:
        print(f"[WARNING] Max value of input waveform signal is {torch.max(y)}")

    device = y.device
    key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"

    if key not in mel_basis_cache:
        mel = librosa_mel_fn(
            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
        )
        mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)
        hann_window_cache[key] = torch.hann_window(win_size).to(device)

    mel_basis = mel_basis_cache[key]
    hann_window = hann_window_cache[key]

    padding = (n_fft - hop_size) // 2
    y = torch.nn.functional.pad(
        y.unsqueeze(1), (padding, padding), mode="reflect"
    ).squeeze(1)

    spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window,
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=True,
    )
    spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)

    mel_spec = torch.matmul(mel_basis, spec)
    mel_spec = spectral_normalize_torch(mel_spec)

    return mel_spec
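
# Editor's note: a quick, hedged example of get_melspec; the STFT/mel settings
# below mirror common BigVGAN-style configs but are illustrative assumptions.
#
#   y = torch.rand(1, 24000) * 2 - 1  # 1 s of audio in [-1, 1], shape (B, T)
#   mel = get_melspec(y, n_fft=1024, num_mels=100, sampling_rate=24000,
#                     hop_size=256, win_size=1024, fmin=0, fmax=12000)
#   # mel.shape == (1, 100, 93), i.e. (B, num_mels, ~T // hop_size)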


class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self

def load_checkpoint(filepath, device):
    assert os.path.isfile(filepath)
    print(f"Loading '{filepath}'")
    checkpoint_dict = torch.load(filepath, map_location=device, weights_only=True)
    print("Complete.")
    return checkpoint_dict

def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)
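
# Editor's illustration (not in the upstream file): get_padding(7, dilation=3)
# == 9; the dilated kernel spans dilation * (kernel_size - 1) + 1 = 19 samples,
# so padding 9 on each side keeps Conv1d output length equal to input length
# at stride 1.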

================================================
FILE: modules/audio_tokenizer/audio_tokenizer.py
================================================
import torch
import librosa
import yaml
from transformers import Wav2Vec2BertModel, SeamlessM4TFeatureExtractor
import safetensors
import accelerate
import soundfile as sf
import math
from einops import rearrange
from modules.audio_tokenizer.rep_codec import RepCodec


class AudioTokenizer(object):
    def __init__(self, **kwargs):
        self.device = kwargs.pop('device')
        print(self.device)
        # tokenize
        feat_stats = kwargs.pop('feat_stats')
        feat_stats = torch.load(feat_stats, map_location='cpu')
        self.feat_mean = feat_stats['mean']
        self.feat_std = torch.sqrt(feat_stats['var'])
        wav2vec_ckpt = kwargs.pop("wav2vec_ckpt")
        self.semantic_model = Wav2Vec2BertModel.from_pretrained(wav2vec_ckpt)
        self.semantic_model.eval()
        self.semantic_model.to(self.device)
        self.semantic_processor = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")

        self.semantic_codec = RepCodec()
        self.semantic_codec.eval()
        pretrained_path = kwargs.pop("semantic_codec_ckpt") 
        safetensors.torch.load_model(self.semantic_codec, pretrained_path)
        self.semantic_codec.to(self.device)

        self.max_length = 2048
        

    @torch.no_grad()
    def tokenize(self, speech):
        # Input:
        # speech: torch tensor, shape[B, N_speech]
        # Output:
        # semantic token: torch tensor, shape[B, N]

        inputs = self.semantic_processor(speech.cpu(), sampling_rate=16000, return_tensors="pt")
        input_features = inputs["input_features"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)
        seg_num = math.ceil(input_features.shape[1] / self.max_length)
        pad_num = seg_num * self.max_length - input_features.shape[1]
        input_features = torch.nn.functional.pad(input_features, (0, 0, 0, pad_num, 0, 0), value=0)
        attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_num, 0, 0), value=0)
        input_features = rearrange(input_features, "b (s n) d -> (b s) n d", s=seg_num)
        attention_mask = rearrange(attention_mask, "b (s n) -> (b s) n", s=seg_num)


        feats = self.semantic_model(
            input_features=input_features,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        feat = feats.hidden_states[17]  # layer-17 hidden states serve as the semantic features
        feat = rearrange(feat, "(b s) n d -> b (s n) d", s=seg_num)
        feat = feat[:, :feat.shape[1]-pad_num, :]
        feat = (feat - self.feat_mean.to(feat)) / self.feat_std.to(feat)
        semantic_token, _ = self.semantic_codec.quantize(feat)  
        return semantic_token

def get_audio_tokenizer():
    config = dict()
    config['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'
    config['feat_stats'] = 'resources/audio_tokenizer/stats.pt'
    config['wav2vec_ckpt'] = 'facebook/w2v-bert-2.0'
    config['semantic_codec_ckpt'] = 'resources/audio_tokenizer/model.safetensors'
    audio_tokenizer = AudioTokenizer(**config)
    return audio_tokenizer
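
# Editor's note: a hedged usage sketch, not part of the upstream file. It
# assumes the resources/ checkpoints above have been downloaded (see
# download_pretrain.py) and that "example.wav" is a hypothetical mono clip;
# tokenize() expects 16 kHz input.
if __name__ == "__main__":
    wav, _ = librosa.load("example.wav", sr=16000, mono=True)
    tokenizer = get_audio_tokenizer()
    tokens = tokenizer.tokenize(torch.from_numpy(wav).unsqueeze(0))  # (B, N)
    print(tokens.shape)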



================================================
FILE: modules/audio_tokenizer/quantize/__init__.py
================================================
from .vector_quantize import VectorQuantize
from .residual_vq import ResidualVQ
from .factorized_vector_quantize import FactorizedVectorQuantize


================================================
FILE: modules/audio_tokenizer/quantize/factorized_vector_quantize.py
================================================
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


class FactorizedVectorQuantize(nn.Module):
    def __init__(
        self,
        input_dim,
        codebook_size,
        codebook_dim,
        commitment=0.005,
        codebook_loss_weight=1.0,
        use_l2_normlize=True,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.commitment = commitment
        self.codebook_loss_weight = codebook_loss_weight
        self.use_l2_normlize = use_l2_normlize

        if self.input_dim != self.codebook_dim:
            self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
            self.out_project = WNConv1d(
                self.codebook_dim, self.input_dim, kernel_size=1
            )

        else:
            self.in_project = nn.Identity()
            self.out_project = nn.Identity()

        self.codebook = nn.Embedding(self.codebook_size, self.codebook_dim)

    def forward(self, z):
        """
        Parameters
        ----------
        z: torch.Tensor[B x D x T]

        Returns
        -------
        z_q: torch.Tensor[B x D x T]
            Quantized continuous representation of input
        commit_loss: Tensor[B]
            Commitment loss to train encoder to predict vectors closer to codebook entries
        codebook_loss: Tensor[B]
            Codebook loss to update the codebook
        indices: torch.Tensor[B x T]
            Codebook indices (quantized discrete representation of input)
        z_e: torch.Tensor[B x D x T]
            Projected latents (continuous representation of input before quantization)
        """

        # Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim
        z_e = self.in_project(z)
        z_q, indices = self.decode_latents(z_e)

        # Compute commitment loss and codebook loss
        if self.training:
            commit_loss = (
                F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
                * self.commitment
            )
            codebook_loss = (
                F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
                * self.codebook_loss_weight
            )
        else:
            commit_loss = torch.zeros(z.shape[0], device=z.device)
            codebook_loss = torch.zeros(z.shape[0], device=z.device)

        z_q = z_e + (z_q - z_e).detach()

        z_q = self.out_project(z_q)

        return z_q, commit_loss, codebook_loss, indices, z_e

    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.codebook.weight)

    def decode_code(self, embed_id):
        return self.embed_code(embed_id).transpose(1, 2)

    def decode_latents(self, latents):
        encodings = rearrange(latents, "b d t -> (b t) d")
        codebook = self.codebook.weight

        # L2 normalize encodings and codebook
        if self.use_l2_normlize:
            encodings = F.normalize(encodings)
            codebook = F.normalize(codebook)

        # Compute squared Euclidean distance between encodings and codebook;
        # when use_l2_normlize is True this equals twice the cosine distance
        dist = (
            encodings.pow(2).sum(1, keepdim=True)
            - 2 * encodings @ codebook.t()
            + codebook.pow(2).sum(1, keepdim=True).t()
        )
        indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
        z_q = self.decode_code(indices)

        return z_q, indices

    def vq2emb(self, vq, out_proj=True):
        emb = self.decode_code(vq)
        if out_proj:
            emb = self.out_project(emb)
        return emb

    def latent2dist(self, latents):
        encodings = rearrange(latents, "b d t -> (b t) d")
        codebook = self.codebook.weight

        # L2 normalize encodings and codebook
        if self.use_l2_normlize:
            encodings = F.normalize(encodings)
            codebook = F.normalize(codebook)

        # Compute squared Euclidean distance between encodings and codebook;
        # when use_l2_normlize is True this equals twice the cosine distance
        dist = (
            encodings.pow(2).sum(1, keepdim=True)
            - 2 * encodings @ codebook.t()
            + codebook.pow(2).sum(1, keepdim=True).t()
        )  # (b*t, k)

        indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
        dist = rearrange(dist, "(b t) k -> b t k", b=latents.size(0))
        z_q = self.decode_code(indices)

        return -dist, indices, z_q
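
# Editor's note: a hedged smoke test, not part of the upstream file. It shows
# the straight-through estimator: z_q carries codebook values in the forward
# pass while gradients flow back to the encoder output z.
if __name__ == "__main__":
    fvq = FactorizedVectorQuantize(input_dim=256, codebook_size=1024, codebook_dim=8)
    z = torch.randn(2, 256, 50, requires_grad=True)  # (B, D, T)
    z_q, commit_loss, codebook_loss, indices, z_e = fvq(z)
    z_q.sum().backward()
    print(indices.shape, z.grad is not None)  # torch.Size([2, 50]) True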


================================================
FILE: modules/audio_tokenizer/quantize/residual_vq.py
================================================
from typing import Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm


from .vector_quantize import VectorQuantize
from .factorized_vector_quantize import FactorizedVectorQuantize


class ResidualVQ(nn.Module):
    """
    Introduced in SoundStream: An End-to-End Neural Audio Codec
    https://arxiv.org/abs/2107.03312
    """

    def __init__(
        self,
        input_dim: int = 256,
        num_quantizers: int = 8,
        codebook_size: int = 1024,
        codebook_dim: int = 256,
        quantizer_type: str = "vq",  # "vq" or "fvq" or "lfq"
        quantizer_dropout: float = 0.5,
        **kwargs,
    ):
        super().__init__()

        self.input_dim = input_dim
        self.num_quantizers = num_quantizers
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.quantizer_type = quantizer_type
        self.quantizer_dropout = quantizer_dropout

        if quantizer_type == "vq":
            VQ = VectorQuantize
        elif quantizer_type == "fvq":
            VQ = FactorizedVectorQuantize
        else:
            raise ValueError(f"Unknown quantizer type {quantizer_type}")

        self.quantizers = nn.ModuleList(
            [
                VQ(
                    input_dim=input_dim,
                    codebook_size=codebook_size,
                    codebook_dim=codebook_dim,
                    **kwargs,
                )
                for _ in range(num_quantizers)
            ]
        )

    def forward(self, z, n_quantizers: int = None):
        """
        Parameters
        ----------
        z : Tensor[B x D x T]
        n_quantizers : int, optional
            Number of quantizers to use
            (n_quantizers < self.num_quantizers, e.g. for quantizer dropout)
            Note: in training mode this argument is ignored and a random
                number of quantizers is used for a `quantizer_dropout`
                fraction of the batch.
        Returns
        -------
        "quantized_out" : Tensor[B x D x T]
            Quantized continuous representation of input
        "all_indices" : Tensor[N x B x T]
            Codebook indices for each codebook
            (quantized discrete representation of input)
        "all_commit_losses" : Tensor[N]
        "all_codebook_losses" : Tensor[N]
        "all_quantized" : Tensor[N x B x D x T]
        """

        quantized_out = 0.0
        residual = z

        all_commit_losses = []
        all_codebook_losses = []
        all_indices = []
        all_quantized = []

        if n_quantizers is None:
            n_quantizers = self.num_quantizers

        if self.training:
            n_quantizers = torch.ones((z.shape[0],)) * self.num_quantizers + 1
            dropout = torch.randint(1, self.num_quantizers + 1, (z.shape[0],))
            n_dropout = int(z.shape[0] * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(z.device)

        for i, quantizer in enumerate(self.quantizers):
            if self.training is False and i >= n_quantizers:
                break

            z_q_i, commit_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
                residual
            )

            # Create mask to apply quantizer dropout
            mask = (
                torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
            )
            quantized_out = quantized_out + z_q_i * mask[:, None, None]
            residual = residual - z_q_i

            commit_loss_i = (commit_loss_i * mask).mean()
            codebook_loss_i = (codebook_loss_i * mask).mean()

            all_commit_losses.append(commit_loss_i)
            all_codebook_losses.append(codebook_loss_i)
            all_indices.append(indices_i)
            all_quantized.append(z_q_i)

        all_commit_losses, all_codebook_losses, all_indices, all_quantized = map(
            torch.stack,
            (all_commit_losses, all_codebook_losses, all_indices, all_quantized),
        )

        return (
            quantized_out,
            all_indices,
            all_commit_losses,
            all_codebook_losses,
            all_quantized,
        )

    def vq2emb(self, vq, n_quantizers=None):
        quantized_out = 0.0
        if n_quantizers is None:
            n_quantizers = self.num_quantizers
        for idx, quantizer in enumerate(self.quantizers):
            if idx >= n_quantizers:
                break
            quantized_out += quantizer.vq2emb(vq[idx])
        return quantized_out

    def latent2dist(self, z, n_quantizers=None):
        quantized_out = 0.0
        residual = z

        all_dists = []
        all_indices = []

        if n_quantizers is None:
            n_quantizers = self.num_quantizers

        for i, quantizer in enumerate(self.quantizers):
            if self.training is False and i >= n_quantizers:
                break
            dist_i, indices_i, z_q_i = quantizer.latent2dist(residual)
            all_dists.append(dist_i)
            all_indices.append(indices_i)

            quantized_out = quantized_out + z_q_i
            residual = residual - z_q_i

        all_dists = torch.stack(all_dists)
        all_indices = torch.stack(all_indices)

        return all_dists, all_indices
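
# Editor's note: a hedged sketch, not part of the upstream file. Each quantizer
# encodes the residual left by the previous ones; with trained codebooks the
# reconstruction error typically shrinks as n_quantizers grows (with random
# init this merely exercises the API).
if __name__ == "__main__":
    rvq = ResidualVQ(input_dim=64, num_quantizers=4, codebook_size=256,
                     codebook_dim=64, quantizer_type="fvq").eval()
    z = torch.randn(1, 64, 20)
    with torch.no_grad():
        for n in (1, 4):
            z_hat, *_ = rvq(z, n_quantizers=n)
            print(n, torch.mean((z - z_hat) ** 2).item())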


================================================
FILE: modules/audio_tokenizer/quantize/vector_quantize.py
================================================
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from torch.nn.utils import weight_norm


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


def l2norm(t):
    return F.normalize(t, p=2, dim=-1)


def ema_inplace(moving_avg, new, decay):
    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))


def laplace_smoothing(x, n_categories, eps=1e-5):
    return (x + eps) / (x.sum() + n_categories * eps)


def sample_vectors(samples, num):
    num_samples, device = samples.shape[0], samples.device

    if num_samples >= num:
        indices = torch.randperm(num_samples, device=device)[:num]
    else:
        indices = torch.randint(0, num_samples, (num,), device=device)

    return samples[indices]


def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
    dim, dtype, device = samples.shape[-1], samples.dtype, samples.device

    means = sample_vectors(samples, num_clusters)

    for _ in range(num_iters):
        if use_cosine_sim:
            dists = samples @ means.t()
        else:
            diffs = rearrange(samples, "n d -> n () d") - rearrange(
                means, "c d -> () c d"
            )
            dists = -(diffs**2).sum(dim=-1)

        buckets = dists.max(dim=-1).indices
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
        new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
        new_means = new_means / bins_min_clamped[..., None]

        if use_cosine_sim:
            new_means = l2norm(new_means)

        means = torch.where(zero_mask[..., None], means, new_means)

    return means, bins


class EuclideanCodebook(nn.Module):
    def __init__(
        self,
        dim,
        codebook_size,
        kmeans_init=False,
        kmeans_iters=10,
        decay=0.8,
        eps=1e-5,
        threshold_ema_dead_code=2,
        weight_init=False,
    ):
        super().__init__()

        self.decay = decay
        init_fn = torch.randn if not weight_init else torch.zeros
        embed = init_fn(codebook_size, dim)

        if weight_init:
            nn.init.uniform_(embed, -1 / codebook_size, 1 / codebook_size)

        self.codebook_size = codebook_size
        self.kmeans_iters = kmeans_iters
        self.eps = eps
        self.threshold_ema_dead_code = threshold_ema_dead_code

        self.register_buffer(
            "initted", torch.Tensor([not kmeans_init])
        )  # if kmeans_init is True, then initted is False; otherwise, initted is True
        self.register_buffer("cluster_size", torch.zeros(codebook_size))
        self.register_buffer("embed", embed)
        self.register_buffer("embed_avg", embed.clone())

    def init_embed_(self, data):
        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed)
        self.cluster_size.data.copy_(cluster_size)
        self.initted.data.copy_(torch.Tensor([True]))

    def replace(self, samples, mask):
        modified_codebook = torch.where(
            mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
        )
        self.embed.data.copy_(modified_codebook)

    def expire_codes_(self, batch_samples):
        if self.threshold_ema_dead_code == 0:
            return

        expired_codes = self.cluster_size < self.threshold_ema_dead_code
        if not torch.any(expired_codes):
            return
        batch_samples = rearrange(batch_samples, "... d -> (...) d")
        self.replace(batch_samples, mask=expired_codes)

    def forward(self, x):
        shape, dtype = x.shape, x.dtype
        flatten = rearrange(x, "... d -> (...) d")
        embed = self.embed.t()  # (codebook_size, dim) -> (dim, codebook_size)

        if not self.initted:
            self.init_embed_(flatten)

        dist = -(
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )

        embed_ind = dist.max(dim=-1).indices
        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
        embed_ind = embed_ind.view(*shape[:-1])
        quantize = F.embedding(embed_ind, self.embed)

        if self.training:
            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
            embed_sum = (
                flatten.t() @ embed_onehot
            )  # (dim, ...) @ (..., codebook_size) -> (dim, codebook_size)
            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
            cluster_size = (
                laplace_smoothing(self.cluster_size, self.codebook_size, self.eps)
                * self.cluster_size.sum()
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            self.embed.data.copy_(embed_normalized)
            self.expire_codes_(x)

        return quantize, embed_ind

    def vq2emb(self, vq):
        quantize = F.embedding(vq, self.embed)
        return quantize

    def latent2dist(self, x):
        shape, dtype = x.shape, x.dtype
        flatten = rearrange(x, "... d -> (...) d")
        embed = self.embed.t()  # (codebook_size, dim) -> (dim, codebook_size)

        if not self.initted:
            self.init_embed_(flatten)

        dist = -(
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )

        embed_ind = dist.max(dim=-1).indices
        embed_ind = embed_ind.view(*shape[:-1])
        quantize = F.embedding(embed_ind, self.embed)

        dist = dist.view(*shape[:-1], -1)

        return dist, embed_ind, quantize


class SimpleCodebook(nn.Module):
    def __init__(
        self,
        dim,
        codebook_size,
        use_l2_normlize=False,
    ):
        super().__init__()

        self.dim = dim
        self.codebook_size = codebook_size
        self.use_l2_normlize = use_l2_normlize

        self.embed = nn.Embedding(self.codebook_size, self.dim)

    def forward(self, x):
        shape, dtype = x.shape, x.dtype
        flatten = rearrange(x, "... d -> (...) d")
        embed = self.embed.weight.t()  # (codebook_size, dim) -> (dim, codebook_size)

        if self.use_l2_normlize:
            flatten = F.normalize(flatten)
            embed = F.normalize(embed)

        dist = -(
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )

        embed_ind = dist.max(dim=-1).indices
        embed_ind = embed_ind.view(*shape[:-1])
        quantize = F.embedding(embed_ind, self.embed.weight)

        return quantize, embed_ind

    def vq2emb(self, vq):
        quantize = F.embedding(vq, self.embed.weight)
        return quantize

    def latent2dist(self, x):
        shape, dtype = x.shape, x.dtype
        flatten = rearrange(x, "... d -> (...) d")
        embed = self.embed.weight.t()  # (codebook_size, dim) -> (dim, codebook_size)

        if self.use_l2_normlize:
            flatten = F.normalize(flatten)
            embed = F.normalize(embed)

        dist = -(
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )

        embed_ind = dist.max(dim=-1).indices
        embed_ind = embed_ind.view(*shape[:-1])
        quantize = F.embedding(embed_ind, self.embed.weight)

        dist = dist.view(*shape[:-1], -1)

        return dist, embed_ind, quantize


class VectorQuantize(nn.Module):
    """Vector quantization and factorized vector quantization implementation
    Args:
        input_dim (int): Dimension of input.
        codebook_size (int): Codebook size.
        codebook_dim (int): Codebook dimension. We suggest codebook_dim = input_dim
            when codebook_type == "euclidean"; if you want factorized vector
            quantization instead, use a small codebook_dim (e.g. 8 or 32).
        commitment (float): Weight for commitment loss.
        use_l2_normlize (bool): Whether to use L2-normalized codes for factorized
            vector quantization; we suggest setting it to True in that case.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        decay (float): Decay for exponential moving average over the codebooks.
        eps (float): Epsilon value for numerical stability.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified
            threshold with a randomly selected vector from the current batch.
    """

    def __init__(
        self,
        input_dim,
        codebook_size,
        codebook_dim,
        commitment=0.005,
        codebook_loss_weight=1.0,
        use_l2_normlize=False,
        codebook_type="euclidean",  # "euclidean" or "simple"
        kmeans_init=False,
        kmeans_iters=10,
        decay=0.8,
        eps=1e-5,
        threshold_ema_dead_code=2,
        weight_init=False,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.commitment = commitment
        self.codebook_loss_weight = codebook_loss_weight
        self.use_l2_normlize = use_l2_normlize
        self.codebook_type = codebook_type
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.decay = decay
        self.eps = eps
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.weight_init = weight_init

        if self.input_dim != self.codebook_dim:
            self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
            self.out_project = WNConv1d(
                self.codebook_dim, self.input_dim, kernel_size=1
            )

        else:
            self.in_project = nn.Identity()
            self.out_project = nn.Identity()

        if self.codebook_type == "euclidean":
            self.codebook = EuclideanCodebook(
                self.codebook_dim,
                codebook_size=self.codebook_size,
                kmeans_init=self.kmeans_init,
                kmeans_iters=self.kmeans_iters,
                decay=self.decay,
                eps=self.eps,
                threshold_ema_dead_code=self.threshold_ema_dead_code,
                weight_init=self.weight_init,
            )
        elif self.codebook_type == "simple":
            self.codebook = SimpleCodebook(
                self.codebook_dim,
                codebook_size=self.codebook_size,
                use_l2_normlize=self.use_l2_normlize,
            )
        else:
            raise NotImplementedError(
                f"codebook_type {self.codebook_type} is not implemented!"
            )

    def forward(self, z):
        """
        Parameters
        ----------
        z: torch.Tensor[B x D x T]

        Returns
        -------
        z_q: torch.Tensor[B x D x T]
            Quantized continuous representation of input
        commit_loss: Tensor[B]
            Commitment loss to train encoder to predict vectors closer to codebook entries
        codebook_loss: Tensor[B]
            Codebook loss to update the codebook
        indices: torch.Tensor[B x T]
            Codebook indices (quantized discrete representation of input)
        z_e: torch.Tensor[B x D x T]
            Projected latents (continuous representation of input before quantization)
        """

        # Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim
        z_e = self.in_project(z)
        z_q, indices = self.decode_latents(z_e)

        # Compute commitment loss and codebook loss
        if self.training:
            commit_loss = (
                F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
                * self.commitment
            )
            codebook_loss = (
                F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
                * self.codebook_loss_weight
            )
        else:
            commit_loss = torch.zeros(z.shape[0], device=z.device)
            codebook_loss = torch.zeros(z.shape[0], device=z.device)

        z_q = z_e + (z_q - z_e).detach()

        z_q = self.out_project(z_q)

        return z_q, commit_loss, codebook_loss, indices, z_e

    def decode_latents(self, latents):
        encodings = rearrange(latents, "b d t -> b t d")
        z_q, indices = self.codebook(encodings)
        z_q = z_q.transpose(1, 2)
        return z_q, indices

    def vq2emb(self, vq, out_proj=True):
        emb = self.codebook.vq2emb(vq)
        emb = emb.transpose(1, 2)
        if out_proj:
            emb = self.out_project(emb)
        return emb

    def latent2dist(self, latents):
        latents = rearrange(latents, "b d t -> b t d")
        dist, embed_ind, quantize = self.codebook.latent2dist(latents)
        return dist, embed_ind, quantize.transpose(1, 2)
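
# Editor's note: a hedged smoke test, not part of the upstream file. In eval
# mode the EMA codebook update and dead-code expiry are skipped, so this only
# exercises the nearest-neighbour lookup and straight-through path.
if __name__ == "__main__":
    vq = VectorQuantize(input_dim=32, codebook_size=128, codebook_dim=32).eval()
    z = torch.randn(2, 32, 10)  # (B, D, T)
    with torch.no_grad():
        z_q, commit_loss, codebook_loss, indices, z_e = vq(z)
    print(z_q.shape, indices.shape)  # torch.Size([2, 32, 10]) torch.Size([2, 10])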


================================================
FILE: modules/audio_tokenizer/rep_codec.py
================================================
import torch
import torch.nn as nn


from modules.audio_tokenizer.quantize import ResidualVQ
from modules.audio_tokenizer.vocos import VocosBackbone
from modules.audio_tokenizer.transformer import TransformerEncoder

def init_weights(m):
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weig

================================================
SYMBOL INDEX (298 symbols across 26 files)
================================================

FILE: app.py
  function process_json_and_generate_audio (line 12) | def process_json_and_generate_audio(prompt_audio_role0_file, prompt_text...
  function update_ui_language (line 137) | def update_ui_language(language):

FILE: inference.py
  class Model (line 18) | class Model(object):
    method __init__ (line 19) | def __init__(self):
    method _clean_text (line 52) | def _clean_text(self, text):
    method _process_text (line 66) | def _process_text(self, js):
    method inference (line 76) | def inference(self, js, streaming=False):
    method infer_with_prompt (line 90) | def infer_with_prompt(self, js):
    method infer_with_prompt_streaming (line 175) | def infer_with_prompt_streaming(self, js):
    method infer_without_prompt (line 255) | def infer_without_prompt(self, js):
    method infer_without_prompt_streaming (line 312) | def infer_without_prompt_streaming(self, js):

FILE: modules/audio_detokenizer/audio_detokenizer.py
  class PrefixStreamingFlowMatchingDetokenizer (line 8) | class PrefixStreamingFlowMatchingDetokenizer:
    method __init__ (line 9) | def __init__(self, vocoder: BigVGANWrapper, fm: StreamingSemanticFMWra...
    method from_pretrained (line 34) | def from_pretrained(cls, vocoder_config, vocoder_ckpt, fm_config, fm_c...
    method prefill (line 44) | def prefill(self, timbre_speech, timbre_semantic_token, chunk_size: in...
    method detokenize_streaming (line 76) | def detokenize_streaming(self, semantic_token, ode_step=30, verbose=Fa...
    method clear_states (line 174) | def clear_states(self):
  function get_audio_detokenizer (line 180) | def get_audio_detokenizer():
  function detokenize (line 201) | def detokenize(detokenizer, tokens, ref_wav, ref_tokens):
  function detokenize_streaming (line 220) | def detokenize_streaming(detokenizer, tokens, ref_wav, ref_tokens):
  function detokenize_noref (line 236) | def detokenize_noref(detokenizer, tokens):
  function detokenize_noref_streaming (line 255) | def detokenize_noref_streaming(detokenizer, tokens):

FILE: modules/audio_detokenizer/bigvgan_wrapper.py
  class BigVGANWrapper (line 14) | class BigVGANWrapper:
    method __init__ (line 15) | def __init__(self, vocoder: BigVGAN, device: torch.device, h: AttrDict...
    method to_dtype (line 23) | def to_dtype(self, dtype):
    method extract_mel_from_wav (line 26) | def extract_mel_from_wav(self, wav_path=None, wav_data=None):
    method extract_mel_from_wav_batch (line 44) | def extract_mel_from_wav_batch(self, wav_data):
    method decode_mel (line 58) | def decode_mel(self, mel):
    method decode_mel_batch (line 69) | def decode_mel_batch(self, mel):
    method from_pretrained (line 81) | def from_pretrained(cls, model_config, ckpt_path, device):

FILE: modules/audio_detokenizer/flow_matching/dit_block.py
  function reshape_for_broadcast (line 11) | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
  function apply_rotary_emb (line 24) | def apply_rotary_emb(
  class Attention (line 39) | class Attention(nn.Module):
    method __init__ (line 41) | def __init__(
    method forward (line 67) | def forward(self, x: torch.Tensor, seq_len, cu_seqlens, max_seqlen, cu...
  function modulate (line 160) | def modulate(x, shift, scale):
  class FinalLayer (line 164) | class FinalLayer(nn.Module):
    method __init__ (line 168) | def __init__(self, hidden_size, out_channels):
    method forward (line 177) | def forward(self, x, c):
  class DiTBlock (line 184) | class DiTBlock(nn.Module):
    method __init__ (line 188) | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, ffn_type="co...
    method forward (line 209) | def forward(self, x, c, seq_len, cu_seqlens, cu_maxlen, cu_seqlens_k, ...

FILE: modules/audio_detokenizer/flow_matching/model.py
  function precompute_freqs_cis (line 6) | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0,
  class TimestepEmbedder (line 28) | class TimestepEmbedder(nn.Module):
    method __init__ (line 32) | def __init__(self, hidden_size, frequency_embedding_size=256):
    method timestep_embedding (line 42) | def timestep_embedding(t, dim, max_period=10000):
    method forward (line 62) | def forward(self, t):
  class SinusoidalPositionalEmbedding (line 68) | class SinusoidalPositionalEmbedding(nn.Module):
    method __init__ (line 74) | def __init__(self, embedding_dim, padding_idx, init_size=1024):
    method get_embedding (line 86) | def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
    method forward (line 104) | def forward(self, input, incremental_state=None, timestep=None, **kwar...
    method max_positions (line 125) | def max_positions(self):
    method make_positions (line 129) | def make_positions(self, tensor, padding_idx):
  class DiTPrefix (line 144) | class DiTPrefix(nn.Module):
    method __init__ (line 148) | def __init__(
    method initialize_weights (line 218) | def initialize_weights(self):
    method forward (line 243) | def forward(self, x, position_ids, t, condition, seq_len, cu_seqlens, ...

FILE: modules/audio_detokenizer/flow_matching/ode_wrapper.py
  function get_cached_zeros (line 8) | def get_cached_zeros(numel, device="cpu", dtype=torch.float32):
  class StreamingODEWrapperForPrefix (line 11) | class StreamingODEWrapperForPrefix(nn.Module):
    method __init__ (line 12) | def __init__(self, net, x_mask, x_cond, use_cfg=False, use_cfg_rescale...
    method clear_all_states (line 38) | def clear_all_states(self):
    method state_dict (line 48) | def state_dict(self):
    method load_state_dict (line 59) | def load_state_dict(self, state_dict):
    method set_conditions (line 68) | def set_conditions(self, x_mask, x_cond, start_position_id, cache={}):
    method update_incremental_state (line 108) | def update_incremental_state(self, reserve_kv_cache_tokens=0, max_kv_c...
    method forward (line 151) | def forward(self, t, x, args=None):

FILE: modules/audio_detokenizer/flow_matching/scheduler.py
  class SchedulerBase (line 9) | class SchedulerBase(ABC):
    method __init__ (line 10) | def __init__(self) -> None:
    method set_timesteps (line 14) | def set_timesteps(self):
    method step (line 18) | def step(self):
    method add_noise (line 22) | def add_noise(self):
  class StreamingFlowMatchingScheduler (line 26) | class StreamingFlowMatchingScheduler(SchedulerBase):
    method __init__ (line 27) | def __init__(self, timesteps=1000, sigma_min=1e-4,
    method set_timesteps (line 39) | def set_timesteps(self, timesteps=15):
    method step (line 42) | def step(self, xt, predicted_v):
    method sample (line 50) | def sample(self, ode_wrapper, time_steps, xt, verbose=False, x0=None):
    method sample_by_neuralode (line 65) | def sample_by_neuralode(self, ode_wrapper, time_steps, xt, verbose=Fal...
    method add_noise (line 76) | def add_noise(self, original_samples: torch.FloatTensor,

FILE: modules/audio_detokenizer/semantic_fm_prefix_streaming.py
  class StreamingSemanticFMWrapper (line 16) | class StreamingSemanticFMWrapper:
    method __init__ (line 17) | def __init__(self, speech_model: DiTPrefix, max_kv_cache_tokens=900, m...
    method infer_chunk (line 49) | def infer_chunk(self, xt_chunk, semantic_tokens_chunk, start_position_id,
    method infer_mel (line 105) | def infer_mel(self, semantic_tokens, ode_steps=15, chunk_size=150, ver...
    method clear_all_states (line 150) | def clear_all_states(self):
    method state_dict (line 155) | def state_dict(self):
    method load_state_dict (line 162) | def load_state_dict(self, state_dict):
    method update_incremental_state (line 168) | def update_incremental_state(self):
    method prefill (line 172) | def prefill(self, mel, semantic_token, chunk_size=150, verbose=False):
    method prefill_chunk (line 212) | def prefill_chunk(self, mel_chunk, semantic_tokens_chunk, start_positi...
    method from_pretrained (line 236) | def from_pretrained(cls, model_config, ckpt_path, device, max_prompt_c...

FILE: modules/audio_detokenizer/vocoder/activations.py
  class Snake (line 6) | class Snake(nn.Module):
    method __init__ (line 23) | def __init__(
    method forward (line 48) | def forward(self, x):
  class SnakeBeta (line 62) | class SnakeBeta(nn.Module):
    method __init__ (line 80) | def __init__(
    method forward (line 110) | def forward(self, x):

FILE: modules/audio_detokenizer/vocoder/alias_free_activation/cuda/activation1d.py
  class FusedAntiAliasActivation (line 14) | class FusedAntiAliasActivation(torch.autograd.Function):
    method forward (line 22) | def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta):
    method backward (line 30) | def backward(ctx, output_grads):
  class Activation1d (line 35) | class Activation1d(nn.Module):
    method __init__ (line 36) | def __init__(
    method forward (line 54) | def forward(self, x):

FILE: modules/audio_detokenizer/vocoder/alias_free_activation/cuda/anti_alias_activation.cpp
  function PYBIND11_MODULE (line 21) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

FILE: modules/audio_detokenizer/vocoder/alias_free_activation/cuda/load.py
  function load (line 17) | def load():
  function _get_cuda_bare_metal_version (line 68) | def _get_cuda_bare_metal_version(cuda_dir):
  function _create_build_dir (line 81) | def _create_build_dir(buildpath):

FILE: modules/audio_detokenizer/vocoder/alias_free_activation/torch/act.py
  class Activation1d (line 8) | class Activation1d(nn.Module):
    method __init__ (line 9) | def __init__(
    method forward (line 25) | def forward(self, x):

FILE: modules/audio_detokenizer/vocoder/alias_free_activation/torch/filter.py
  function sinc (line 15) | def sinc(x: torch.Tensor):
  function kaiser_sinc_filter1d (line 30) | def kaiser_sinc_filter1d(
  class LowPassFilter1d (line 65) | class LowPassFilter1d(nn.Module):
    method __init__ (line 66) | def __init__(
    method forward (line 94) | def forward(self, x):

FILE: modules/audio_detokenizer/vocoder/alias_free_activation/torch/resample.py
  class UpSample1d (line 10) | class UpSample1d(nn.Module):
    method __init__ (line 11) | def __init__(self, ratio=2, kernel_size=None):
    method forward (line 29) | def forward(self, x):
  class DownSample1d (line 41) | class DownSample1d(nn.Module):
    method __init__ (line 42) | def __init__(self, ratio=2, kernel_size=None):
    method forward (line 55) | def forward(self, x):

FILE: modules/audio_detokenizer/vocoder/bigvgan.py
  function load_hparams_from_json (line 25) | def load_hparams_from_json(path) -> AttrDict:
  class AMPBlock1 (line 31) | class AMPBlock1(torch.nn.Module):
    method __init__ (line 44) | def __init__(
    method forward (line 132) | def forward(self, x):
    method remove_weight_norm (line 143) | def remove_weight_norm(self):
  class AMPBlock2 (line 150) | class AMPBlock2(torch.nn.Module):
    method __init__ (line 163) | def __init__(
    method forward (line 232) | def forward(self, x):
    method remove_weight_norm (line 238) | def remove_weight_norm(self):
  class BigVGAN (line 243) | class BigVGAN(
    method __init__ (line 266) | def __init__(self, h: AttrDict, use_cuda_kernel: bool = False):
    method forward (line 360) | def forward(self, x):
    method remove_weight_norm (line 388) | def remove_weight_norm(self):
    method _save_pretrained (line 403) | def _save_pretrained(self, save_directory: Path) -> None:
    method _from_pretrained (line 414) | def _from_pretrained(

FILE: modules/audio_detokenizer/vocoder/utils.py
  function dynamic_range_compression_torch (line 7) | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
  function spectral_normalize_torch (line 11) | def spectral_normalize_torch(magnitudes):
  function get_melspec (line 14) | def get_melspec(
  class AttrDict (line 86) | class AttrDict(dict):
    method __init__ (line 87) | def __init__(self, *args, **kwargs):
  function load_checkpoint (line 91) | def load_checkpoint(filepath, device):
  function init_weights (line 98) | def init_weights(m, mean=0.0, std=0.01):
  function get_padding (line 104) | def get_padding(kernel_size, dilation=1):

FILE: modules/audio_tokenizer/audio_tokenizer.py
  class AudioTokenizer (line 13) | class AudioTokenizer(object):
    method __init__ (line 14) | def __init__(self, **kwargs):
    method tokenize (line 38) | def tokenize(self, speech):
  function get_audio_tokenizer (line 67) | def get_audio_tokenizer():

FILE: modules/audio_tokenizer/quantize/factorized_vector_quantize.py
  function WNConv1d (line 9) | def WNConv1d(*args, **kwargs):
  function WNConvTranspose1d (line 13) | def WNConvTranspose1d(*args, **kwargs):
  class FactorizedVectorQuantize (line 17) | class FactorizedVectorQuantize(nn.Module):
    method __init__ (line 18) | def __init__(
    method forward (line 47) | def forward(self, z):
    method embed_code (line 91) | def embed_code(self, embed_id):
    method decode_code (line 94) | def decode_code(self, embed_id):
    method decode_latents (line 97) | def decode_latents(self, latents):
    method vq2emb (line 118) | def vq2emb(self, vq, out_proj=True):
    method latent2dist (line 124) | def latent2dist(self, latents):

FILE: modules/audio_tokenizer/quantize/residual_vq.py
  class ResidualVQ (line 15) | class ResidualVQ(nn.Module):
    method __init__ (line 21) | def __init__(
    method forward (line 59) | def forward(self, z, n_quantizers: int = None):
    method vq2emb (line 135) | def vq2emb(self, vq, n_quantizers=None):
    method latent2dist (line 145) | def latent2dist(self, z, n_quantizers=None):

FILE: modules/audio_tokenizer/quantize/vector_quantize.py
  function WNConv1d (line 9) | def WNConv1d(*args, **kwargs):
  function WNConvTranspose1d (line 13) | def WNConvTranspose1d(*args, **kwargs):
  function l2norm (line 17) | def l2norm(t):
  function ema_inplace (line 21) | def ema_inplace(moving_avg, new, decay):
  function laplace_smoothing (line 25) | def laplace_smoothing(x, n_categories, eps=1e-5):
  function sample_vectors (line 29) | def sample_vectors(samples, num):
  function kmeans (line 40) | def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
  class EuclideanCodebook (line 71) | class EuclideanCodebook(nn.Module):
    method __init__ (line 72) | def __init__(
    method init_embed_ (line 104) | def init_embed_(self, data):
    method replace (line 111) | def replace(self, samples, mask):
    method expire_codes_ (line 117) | def expire_codes_(self, batch_samples):
    method forward (line 127) | def forward(self, x):
    method vq2emb (line 162) | def vq2emb(self, vq):
    method latent2dist (line 166) | def latent2dist(self, x):
  class SimpleCodebook (line 189) | class SimpleCodebook(nn.Module):
    method __init__ (line 190) | def __init__(
    method forward (line 204) | def forward(self, x):
    method vq2emb (line 225) | def vq2emb(self, vq):
    method latent2dist (line 229) | def latent2dist(self, x):
  class VectorQuantize (line 253) | class VectorQuantize(nn.Module):
    method __init__ (line 273) | def __init__(
    method forward (line 336) | def forward(self, z):
    method decode_latents (line 380) | def decode_latents(self, latents):
    method vq2emb (line 386) | def vq2emb(self, vq, out_proj=True):
    method latent2dist (line 393) | def latent2dist(self, latents):

FILE: modules/audio_tokenizer/rep_codec.py
  function init_weights (line 9) | def init_weights(m):
  class RepCodec (line 17) | class RepCodec(nn.Module):
    method __init__ (line 18) | def __init__(
    method forward (line 136) | def forward(self, x):
    method quantize (line 177) | def quantize(self, x):
    method reset_parameters (line 196) | def reset_parameters(self):

FILE: modules/audio_tokenizer/transformer.py
  class StyleAdaptiveLayerNorm (line 8) | class StyleAdaptiveLayerNorm(nn.Module):
    method __init__ (line 9) | def __init__(self, normalized_shape, eps=1e-5):
    method forward (line 17) | def forward(self, x, condition):
  class PositionalEncoding (line 30) | class PositionalEncoding(nn.Module):
    method __init__ (line 31) | def __init__(self, d_model, dropout, max_len=5000):
    method forward (line 44) | def forward(self, x):
  class TransformerFFNLayer (line 49) | class TransformerFFNLayer(nn.Module):
    method __init__ (line 50) | def __init__(
    method forward (line 70) | def forward(self, x):
  class TransformerEncoderLayer (line 81) | class TransformerEncoderLayer(nn.Module):
    method __init__ (line 82) | def __init__(
    method forward (line 117) | def forward(self, x, key_padding_mask, conditon=None):
  class TransformerEncoder (line 149) | class TransformerEncoder(nn.Module):
    method __init__ (line 150) | def __init__(
    method forward (line 217) | def forward(self, x, key_padding_mask, condition=None):

FILE: modules/audio_tokenizer/vocos.py
  function safe_log (line 12) | def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
  function symlog (line 26) | def symlog(x: torch.Tensor) -> torch.Tensor:
  function symexp (line 30) | def symexp(x: torch.Tensor) -> torch.Tensor:
  class STFT (line 34) | class STFT(nn.Module):
    method __init__ (line 35) | def __init__(
    method forward (line 50) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class ISTFT (line 78) | class ISTFT(nn.Module):
    method __init__ (line 93) | def __init__(
    method forward (line 106) | def forward(self, spec: torch.Tensor) -> torch.Tensor:
  class MDCT (line 164) | class MDCT(nn.Module):
    method __init__ (line 173) | def __init__(self, frame_len: int, padding: str = "same"):
    method forward (line 191) | def forward(self, audio: torch.Tensor) -> torch.Tensor:
  class IMDCT (line 225) | class IMDCT(nn.Module):
    method __init__ (line 234) | def __init__(self, frame_len: int, padding: str = "same"):
    method forward (line 250) | def forward(self, X: torch.Tensor) -> torch.Tensor:
  class FourierHead (line 293) | class FourierHead(nn.Module):
    method forward (line 296) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class ISTFTHead (line 308) | class ISTFTHead(FourierHead):
    method __init__ (line 320) | def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str...
    method forward (line 328) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class IMDCTSymExpHead (line 358) | class IMDCTSymExpHead(FourierHead):
    method __init__ (line 371) | def __init__(
    method forward (line 395) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class IMDCTCosHead (line 418) | class IMDCTCosHead(FourierHead):
    method __init__ (line 429) | def __init__(
    method forward (line 441) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class ConvNeXtBlock (line 463) | class ConvNeXtBlock(nn.Module):
    method __init__ (line 475) | def __init__(
    method forward (line 502) | def forward(
  class AdaLayerNorm (line 524) | class AdaLayerNorm(nn.Module):
    method __init__ (line 533) | def __init__(self, num_embeddings: int, embedding_dim: int, eps: float...
    method forward (line 546) | def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) ->...
  class ResBlock1 (line 554) | class ResBlock1(nn.Module):
    method __init__ (line 570) | def __init__(
    method forward (line 676) | def forward(self, x: torch.Tensor) -> torch.Tensor:
    method remove_weight_norm (line 687) | def remove_weight_norm(self):
    method get_padding (line 694) | def get_padding(kernel_size: int, dilation: int = 1) -> int:
  class Backbone (line 698) | class Backbone(nn.Module):
    method forward (line 701) | def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
  class VocosBackbone (line 714) | class VocosBackbone(Backbone):
    method __init__ (line 728) | def __init__(
    method _init_weights (line 760) | def _init_weights(self, m):
    method forward (line 765) | def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
  class VocosResNetBackbone (line 780) | class VocosResNetBackbone(Backbone):
    method __init__ (line 791) | def __init__(
    method forward (line 811) | def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
  class Vocos (line 818) | class Vocos(nn.Module):
    method __init__ (line 819) | def __init__(
    method forward (line 841) | def forward(self, x):

FILE: modules/tokenizer/tokenizer.py
  function encode_pieces (line 9) | def encode_pieces(sp_model: spm.SentencePieceProcessor, text: str, sampl...
  class AbstractTokenizer (line 20) | class AbstractTokenizer(ABC):
    method __init__ (line 23) | def __init__(self, name):
    method vocab_size (line 29) | def vocab_size(self):
    method vocab (line 34) | def vocab(self):
    method inv_vocab (line 40) | def inv_vocab(self):
    method tokenize (line 45) | def tokenize(self, text):
    method detokenize (line 48) | def detokenize(self, token_ids):
    method cls (line 53) | def cls(self):
    method sep (line 58) | def sep(self):
    method pad (line 63) | def pad(self):
    method eod (line 68) | def eod(self):
    method mask (line 73) | def mask(self):
  class SPieceTokenizer (line 78) | class SPieceTokenizer(AbstractTokenizer):
    method __init__ (line 79) | def __init__(self, spm_file: str):
    method encode_pieces (line 96) | def encode_pieces(self, text: str, sample=False):
    method _initialize_index_2_bytes (line 103) | def _initialize_index_2_bytes(self):
    method set_add_dummy_prefix (line 111) | def set_add_dummy_prefix(self, add_dummy_prefix: bool = False):
    method add_special_id (line 119) | def add_special_id(self, token_id):
    method has_dummy_prefix (line 123) | def has_dummy_prefix(self):
    method vocab_size (line 128) | def vocab_size(self):
    method vocab (line 132) | def vocab(self):
    method get_array_bytes (line 136) | def get_array_bytes(self, array):
    method tokenize (line 139) | def tokenize(self, text):
    method encode (line 143) | def encode(self, text: str, bos: bool=False, eos: bool=False, **kwargs...
    method convert_tokens_to_ids (line 152) | def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list...
    method detokenize (line 157) | def detokenize(self, token_ids):
    method decode (line 164) | def decode(self, token_ids: Union[int, list[int]], skip_special_tokens...
    method get_token_id (line 170) | def get_token_id(self, token):
    method inv_vocab (line 173) | def inv_vocab(self):
    method decode_pieces (line 177) | def decode_pieces(self, pieces):
    method eod (line 181) | def eod(self):
    method pad_id (line 185) | def pad_id(self):
    method eos_id (line 189) | def eos_id(self):
    method bos_id (line 193) | def bos_id(self):
    method unk_id (line 197) | def unk_id(self):
    method pad_token_id (line 201) | def pad_token_id(self):
    method eos_token_id (line 205) | def eos_token_id(self):
  class ExtraTokens (line 210) | class ExtraTokens:
  function instantiate_extra_tokens (line 221) | def instantiate_extra_tokens(tokenizer: AbstractTokenizer):
  function get_tokenizer_and_extra_tokens (line 238) | def get_tokenizer_and_extra_tokens():
  },
  {
    "path": "modules/audio_detokenizer/vocoder/alias_free_activation/cuda/type_shim.h",
    "chars": 5838,
    "preview": "/* coding=utf-8\n * Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License"
  },
  {
    "path": "modules/audio_detokenizer/vocoder/alias_free_activation/torch/__init__.py",
    "chars": 200,
    "preview": "# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0\n#   LICENSE is in incl_licens"
  },
  {
    "path": "modules/audio_detokenizer/vocoder/alias_free_activation/torch/act.py",
    "chars": 825,
    "preview": "# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0\n#   LICENSE is in incl_licens"
  },
  {
    "path": "modules/audio_detokenizer/vocoder/alias_free_activation/torch/filter.py",
    "chars": 3401,
    "preview": "# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0\n#   LICENSE is in incl_licens"
  },
  {
    "path": "modules/audio_detokenizer/vocoder/alias_free_activation/torch/resample.py",
    "chars": 1831,
    "preview": "# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0\n#   LICENSE is in incl_licens"
  },
  {
    "path": "modules/audio_detokenizer/vocoder/bigvgan.py",
    "chars": 17693,
    "preview": "# Copyright (c) 2024 NVIDIA CORPORATION.\n#   Licensed under the MIT license.\n\n# Adapted from https://github.com/jik876/h"
  },
  {
    "path": "modules/audio_detokenizer/vocoder/utils.py",
    "chars": 3353,
    "preview": "from librosa.filters import mel as librosa_mel_fn\nimport torch\nimport os\nmel_basis_cache = {}\nhann_window_cache = {}\n\nde"
  },
  {
    "path": "modules/audio_tokenizer/audio_tokenizer.py",
    "chars": 3061,
    "preview": "import torch\nimport librosa\nimport yaml\nfrom transformers import Wav2Vec2BertModel, SeamlessM4TFeatureExtractor\nimport s"
  },
  {
    "path": "modules/audio_tokenizer/quantize/__init__.py",
    "chars": 145,
    "preview": "from .vector_quantize import VectorQuantize\nfrom .residual_vq import ResidualVQ\nfrom .factorized_vector_quantize import "
  },
  {
    "path": "modules/audio_tokenizer/quantize/factorized_vector_quantize.py",
    "chars": 4878,
    "preview": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import rearrange\nfrom "
  },
  {
    "path": "modules/audio_tokenizer/quantize/residual_vq.py",
    "chars": 5381,
    "preview": "from typing import Union\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom ein"
  },
  {
    "path": "modules/audio_tokenizer/quantize/vector_quantize.py",
    "chars": 13476,
    "preview": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import rearrange, repe"
  },
  {
    "path": "modules/audio_tokenizer/rep_codec.py",
    "chars": 6464,
    "preview": "import torch\nimport torch.nn as nn\n\n\nfrom modules.audio_tokenizer.quantize import ResidualVQ\nfrom modules.audio_tokenize"
  },
  {
    "path": "modules/audio_tokenizer/transformer.py",
    "chars": 7207,
    "preview": "import numpy as np\nimport torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\nimport math\n\n\nclass StyleAdap"
  },
  {
    "path": "modules/audio_tokenizer/vocos.py",
    "chars": 30188,
    "preview": "from typing import Optional, Tuple\n\nimport numpy as np\nimport scipy\nimport torch\nfrom torch import nn, view_as_real, vie"
  },
  {
    "path": "modules/tokenizer/tokenizer.py",
    "chars": 7442,
    "preview": "from abc import ABC\nfrom abc import abstractmethod\nimport sentencepiece as spm\nfrom sentencepiece import sentencepiece_m"
  },
  {
    "path": "readme.md",
    "chars": 2933,
    "preview": "# MoonCast: High-Quality Zero-Shot Podcast Generation\n\n<p align=\"center\">\n    <picture>\n        <img src=\"./fig/logo.png"
  },
  {
    "path": "requirements.txt",
    "chars": 194,
    "preview": "torch==2.3.1\ntorchaudio==2.3.1\nsentencepiece==0.2.0\nprotobuf\nnumpy\n\nlibrosa==0.9.1\npyyaml\ntransformers\nsafetensors\neinop"
  },
  {
    "path": "test/test_audio_detokenizer.py",
    "chars": 1356,
    "preview": "import sys\nimport torch\nsys.path.append('.')\nfrom modules.audio_tokenizer.audio_tokenizer import get_audio_tokenizer\nfro"
  },
  {
    "path": "test/test_audio_tokenizer.py",
    "chars": 559,
    "preview": "import sys\nsys.path.append('.')\nfrom modules.audio_tokenizer.audio_tokenizer import get_audio_tokenizer\nimport torch\n\nif"
  },
  {
    "path": "test/test_tokenizer.py",
    "chars": 1119,
    "preview": "import sys\nsys.path.append('.')\nfrom modules.tokenizer.tokenizer import get_tokenizer_and_extra_tokens\n\n\n\nif __name__ =="
  },
  {
    "path": "zh_llmprompt_script_gen.py",
    "chars": 7031,
    "preview": "# INPUT -> BRIEF -> SCRIPT\n\n\nINPUT2BRIEF = '''\n### 任务说明\n请按照以下结构总结输入文件, 普通文本格式。总结应当有创造性,保证信息全面,包含所有有趣、不常见、有价值的观点和信息。\n- **"
  }
]
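
Because the condensed preview above is a valid JSON array of {path, chars, preview} records, it can be loaded and queried programmatically. Below is a minimal Python sketch; the filename mooncast_preview.json is hypothetical (use whatever name the saved .json file carries).

import json

# Load the condensed preview records shown above.
# "mooncast_preview.json" is a hypothetical filename for the saved array.
with open("mooncast_preview.json", encoding="utf-8") as f:
    files = json.load(f)

# Sum the per-file character counts and list the largest files, which is
# a reasonable way for an agent to decide what to read first.
total_chars = sum(entry["chars"] for entry in files)
print(f"{len(files)} files, {total_chars} characters")
for entry in sorted(files, key=lambda e: e["chars"], reverse=True)[:5]:
    print(f'{entry["chars"]:>6}  {entry["path"]}')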

About this extraction

This page contains the full source code of the jzq2000/MoonCast GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 43 files (252.7 KB), approximately 65.2k tokens, and a symbol index with 298 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract, a free GitHub-repository-to-text converter for AI. Built by Nikandr Surkov.