[
  {
    "path": ".gitignore",
    "content": "artefacts/*\nenv_dataspeech/*\n**/__pycache__/*\nwip_scripts/*\nplots/*\n.vscode/*"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2024 The Hugging Face team.\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# Data-Speech\n\nData-Speech is a suite of utility scripts designed to tag speech datasets. \n\nIts aim is to provide a simple, clean codebase for applying audio transformations (or annotations) that may be requested as part of the development of speech-based AI models, such as text-to-speech engines.\n\nIts primary use is to reproduce the annotation method from Dan Lyth and Simon King's research paper [Natural language guidance of high-fidelity text-to-speech with synthetic annotations](https://arxiv.org/abs/2402.01912), that labels various speaker characteristics with natural language descriptions.\n\nApplying these tools allows us to prepare and release tagged versions of [LibriTTS-R](https://huggingface.co/datasets/parler-tts/libritts-r-filtered-speaker-descriptions), and of [the English version of MLS](https://huggingface.co/datasets/parler-tts/mls-eng-speaker-descriptions).\n\nThis repository is designed to accompany the [Parler-TTS library](https://github.com/huggingface/parler-tts), which contains the inference and training code for Parler-TTS, a new family of high-quality text-to-speech models.\n\n---------\n\n## 📖 Quick Index\n* [Requirements](#set-up)\n* [Annotating datasets to fine-tune Parler-TTS](#annotating-datasets-to-fine-tune-parler-tts)\n* [Annotating datasets from scratch](#annotating-datasets-from-scratch)\n* [Using Data-Speech to filter your speech datasets](#using-data-speech-to-filter-your-speech-datasets)\n* [❓ FAQ](#faq)\n* [Logs](#logs)\n\n\n## Set-up\n\nYou first need to clone this repository before installing requirements.\n\n```sh\ngit clone git@github.com:huggingface/dataspeech.git\ncd dataspeech\npip install -r requirements.txt\n```\n\n## Annotating datasets to fine-tune Parler-TTS\n\nIn the following examples, we'll load 30 hours of audio data from the [Jenny TTS dataset](https://github.com/dioco-group/jenny-tts-dataset), a high-quality mono-speaker TTS dataset, from an Irish female speaker named Jenny.\n\nThe aim 
here is to create an annotated version of Jenny TTS, in order to fine-tune the [Parler-TTS v1 checkpoint](https://huggingface.co/parler-tts/parler-tts-mini-v1) on this dataset.\n\nThanks to a [script similar to what's described in the FAQ](#how-do-i-use-datasets-that-i-have-with-this-repository), we've uploaded the dataset to the HuggingFace hub, under the name [reach-vb/jenny_tts_dataset](https://huggingface.co/datasets/reach-vb/jenny_tts_dataset).\n\nFeel free to follow the link above to listen to some samples of the Jenny TTS dataset thanks to the hub viewer.\n\n> [!IMPORTANT]\n> Refer to the section [Annotating datasets from scratch](#annotating-datasets-from-scratch) for more detailed explanations of what's going on under-the-hood.\n\nWe'll:\n1. Annotate the Jenny dataset with continuous variables that measure the speech characteristics.\n2. Map those annotations to text bins that characterize the speech characteristics.\n3. Create natural language descriptions from those text bins.\n\n### 1. Annotate the Jenny dataset\n\nWe'll use [`main.py`](main.py) to get the following continuous variables:\n    - Speaking rate `(nb_phonemes / utterance_length)`\n    - Scale-Invariant Signal-to-Distortion Ratio (SI-SDR)\n    - Reverberation\n    - Speech monotony\n\n```sh\npython main.py \"reach-vb/jenny_tts_dataset\" \\\n  --configuration \"default\" \\\n  --text_column_name \"transcription\" \\\n  --audio_column_name \"audio\" \\\n  --cpu_num_workers 8 \\\n  --rename_column \\\n  --repo_id \"jenny-tts-tags-v1\" \\\n  --apply_squim_quality_estimation\n```\n\nNote that the script will be faster if you have GPUs at your disposal. It will automatically scale up to every GPU available in your environment.\n\nThe resulting dataset will be pushed to the HuggingFace hub under your HuggingFace handle. Mine was pushed to [ylacombe/jenny-tts-tags-v1](https://huggingface.co/datasets/ylacombe/jenny-tts-tags-v1).\n\n### 2. Map annotations to text bins\n\nSince the ultimate goal here is to fine-tune the [Parler-TTS v1 checkpoint](https://huggingface.co/parler-tts/parler-tts-mini-v1) on the Jenny dataset, we want to stay consistent with the text bins of the datasets on which the latter model was trained.\n\nThis is easy to do thanks to the following command:\n\n```sh\npython ./scripts/metadata_to_text.py \\\n    \"ylacombe/jenny-tts-tags-v1\" \\\n    --repo_id \"jenny-tts-tags-v1\" \\\n    --configuration \"default\" \\\n    --cpu_num_workers \"8\" \\\n    --path_to_bin_edges \"./examples/tags_to_annotations/v02_bin_edges.json\" \\\n    --path_to_text_bins \"./examples/tags_to_annotations/v02_text_bins.json\" \\\n    --avoid_pitch_computation \\\n    --apply_squim_quality_estimation\n```\n\nThanks to [`v02_bin_edges.json`](/examples/tags_to_annotations/v02_bin_edges.json), we don't need to recompute bins from scratch and the above script takes a few seconds.\n\nThe resulting dataset will be pushed to the HuggingFace hub under your HuggingFace handle. Mine was pushed to [ylacombe/jenny-tts-tags-v1](https://huggingface.co/datasets/ylacombe/jenny-tts-tags-v1).\n\nNotice that text bins such as `slightly slowly` and `very monotone` have been added to the samples.\n\n### 3. Create natural language descriptions from those text bins\n\nNow that we have text bins associated with the Jenny dataset, the next step is to create natural language descriptions from the features we just created.\n\nHere, we decided to create prompts that use the name `Jenny`, prompts that'll look like the following: `In a very expressive voice, Jenny pronounces her words incredibly slowly. There's some background noise in this room with a bit of echo.`\n\nThis step generally demands more resources and time, and should use one or more GPUs.\n\n[`run_prompt_creation_jenny.sh`](examples/prompt_creation/run_prompt_creation_jenny.sh) indicates how to run it on the Jenny dataset:\n\n```sh\npython ./scripts/run_prompt_creation.py \\\n  --speaker_name \"Jenny\" \\\n  --is_single_speaker \\\n  --is_new_speaker_prompt \\\n  --dataset_name \"ylacombe/jenny-tts-tags-v1\" \\\n  --dataset_config_name \"default\" \\\n  --model_name_or_path \"mistralai/Mistral-7B-Instruct-v0.2\" \\\n  --per_device_eval_batch_size 128 \\\n  --attn_implementation \"sdpa\" \\\n  --output_dir \"./tmp_jenny\" \\\n  --load_in_4bit \\\n  --push_to_hub \\\n  --hub_dataset_id \"jenny-tts-tagged-v1\" \\\n  --preprocessing_num_workers 24 \\\n  --dataloader_num_workers 24\n```\n\nAs usual, we specify the dataset name and configuration we want to annotate. `model_name_or_path` should point to a `transformers` model for prompt annotation. You can find a list of such models [here](https://huggingface.co/models?pipeline_tag=text-generation&library=transformers&sort=trending). Here, we used a version of Mistral's 7B model.\n\n> [!NOTE]\n> If you want to use this on a multi-speaker dataset, you'll have to adapt the logic of the script. First, you need to remove the `--is_single_speaker` and `--speaker_name \"Jenny\"` flags.\n> \n> Then, there are two cases:\n> 1. In case you want to associate names to some speakers, you need to pass the speaker id column name, and a JSON file which maps the speaker ids to these names. For example, `--speaker_id_column \"speaker_id\" --speaker_ids_to_name_json ./examples/prompt_creation/speaker_ids_to_names.json`. Feel free to take a look at [speaker_ids_to_names.json](examples/prompt_creation/speaker_ids_to_names.json) to get inspiration.\n> 2. In case you don't want to associate names to speakers, you don't have to do anything else.
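\n\nIf you go the multi-speaker route, the mapping file is a plain JSON object from speaker ids to names. A minimal sketch of writing one (the ids and names below are hypothetical; see [speaker_ids_to_names.json](examples/prompt_creation/speaker_ids_to_names.json) for the real format):\n\n```python\nimport json\n\n# Hypothetical speaker ids mapped to the names the descriptions should use\nspeaker_ids_to_names = {\"S1\": \"Anna\", \"S2\": \"Tom\"}\n\nwith open(\"speaker_ids_to_names.json\", \"w\") as f:\n    json.dump(speaker_ids_to_names, f, indent=2)\n```\n\nYou would then point `--speaker_ids_to_name_json` at this file.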
\n\n\n## Annotating datasets from scratch\n\nIn the following examples, we'll load 1,000 hours of labelled audio data from the [LibriTTS-R dataset](https://huggingface.co/datasets/blabble-io/libritts_r) and add annotations using the dataspeech library. The resulting dataset is complete with discrete annotation tags, as well as a coherent audio description of the spoken audio characteristics.\n\n\nThere are 3 steps to be completed in order to generate annotations:\n1. [Annotate the speech dataset](#predict-annotations) to get the following continuous variables:\n    - Speaking rate `(nb_phonemes / utterance_length)`\n    - Scale-Invariant Signal-to-Distortion Ratio (SI-SDR)\n    - Reverberation\n    - Speech monotony\n2. [Map the previous continuous annotations to discrete keyword bins](#map-continuous-annotations-to-key-words)\n3. [Create natural language descriptions from a set of keywords](#generate-natural-language-descriptions)\n\n\n### 1. Predict annotations\n\nFor the time being, [`main.py`](main.py) can be used to generate speaking rate, SNR, reverberation, PESQ, SI-SDR and pitch estimation.\n\nTo use it, you need a dataset from the [datasets](https://huggingface.co/docs/datasets/v2.17.0/en/index) library, either locally or on the [hub](https://huggingface.co/datasets).\n\n\n```sh\npython main.py \"blabble-io/libritts_r\" \\\n  --configuration \"dev\" \\\n  --output_dir ./tmp_libritts_r_dev/ \\\n  --text_column_name \"text_normalized\" \\\n  --audio_column_name \"audio\" \\\n  --cpu_num_workers 8 \\\n  --rename_column \\\n  --apply_squim_quality_estimation\n```\n\nHere, we've used 8 processes for operations that don't use GPUs, namely to compute the speaking rate. If GPUs are present in the environment, the operations that can be computed on GPUs - namely pitch, SNR and reverberation estimation - will automatically use every GPU available.\n\nYou can learn more about the arguments you can pass to `main.py` by passing:\n\n```sh\npython main.py --help\n```\n\nIn [`/examples/tagging/run_main_1k.sh`](/examples/tagging/run_main_1k.sh), we scaled up the initial command line to the whole dataset. Note that we've used the `repo_id` argument to push the dataset to the hub, resulting in [this dataset](https://huggingface.co/datasets/ylacombe/libritts-r-text-tags-v3).\n\nThe dataset viewer gives an idea of what has been done, namely:\n- new columns were added:\n    - `utterance_pitch_std`: Gives a measure of the standard deviation of pitch in the utterance.\n    - `utterance_pitch_mean`: Gives a measure of average pitch in the utterance.\n    - `snr`: Speech-to-noise ratio\n    - `c50`: Reverberation estimation\n    - `speaking_rate`\n    - `phonemes`: the phoneme transcription, used to compute the speaking rate\n    - `pesq` and `si-sdr`: measures of intelligibility and a proxy for noise, as indicated [here](https://pytorch.org/audio/main/tutorials/squim_tutorial.html)\n- the audio column was removed - this is especially useful when dealing with big datasets, as writing and pushing audio data can become a bottleneck.\n\n![image](https://github.com/ylacombe/dataspeech/assets/52246514/f422a728-f2af-4c8f-bf2a-65c6722bc0c6)\n\n\n### 2. Map continuous annotations to key-words\n\nThe next step is to map the continuous annotations from the previous steps to key-words. To do so, continuous annotations are mapped to categorical bins that are then associated with key-words. For example, the speaking rate can be associated with 7 text bins which are: `\"very slowly\", \"quite slowly\", \"slightly slowly\", \"moderate speed\", \"slightly fast\", \"quite fast\", \"very fast\"`.\n\n[`scripts/metadata_to_text.py`](/scripts/metadata_to_text.py) computes bins on aggregated statistics from multiple datasets:\n- A speaker's pitch is calculated by averaging pitch across their voice clips. This estimate is then compared to those of speakers of the same gender to derive the speaker's pitch keyword (very high-pitched to very low-pitched).\n- The rest of the keywords are derived by [computing histograms](https://numpy.org/doc/stable/reference/generated/numpy.histogram.html) of the continuous variables over all training samples, from which the extreme values have been eliminated, and associating a keyword with each bin.\n\n```sh\npython ./scripts/metadata_to_text.py \"ylacombe/libritts-r-text-tags-v3+ylacombe/libritts-r-text-tags-v3\" \\\n--configuration \"clean+other\" \\\n--output_dir \"./tmp_tts_clean+./tmp_tts_other\" \\\n--cpu_num_workers \"8\" \\\n--leading_split_for_bins \"train\" \\\n--plot_directory \"./plots/\" \\\n--path_to_text_bins \"./examples/tags_to_annotations/v02_text_bins.json\" \\\n--apply_squim_quality_estimation\n```\nNote how we've been able to pass different datasets with different configurations by separating the relevant arguments with `\"+\"`.\n\nBy passing `--repo_id parler-tts/libritts-r-tags-and-text+parler-tts/libritts-r-tags-and-text`, we pushed the resulting dataset to [this hub repository](https://huggingface.co/datasets/parler-tts/libritts-r-tags-and-text).\n\nNote that this step is a bit more subtle than the previous one, as we generally want to collect a wide variety of speech data to compute accurate key-words.
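\n\nThe keyword mapping described above can be sketched in a few lines. This is a simplified illustration, not the exact logic of `metadata_to_text.py` (the 1%/99% cut-off and equal-width bins are assumptions):\n\n```python\nimport numpy as np\n\nSPEAKING_RATE_BINS = [\"very slowly\", \"quite slowly\", \"slightly slowly\", \"moderate speed\", \"slightly fast\", \"quite fast\", \"very fast\"]\n\ndef value_to_keyword(value, train_values, bins=SPEAKING_RATE_BINS):\n    # Drop extreme values, then split the remaining range into equal-width bins\n    low, high = np.percentile(train_values, [1, 99])\n    edges = np.histogram_bin_edges(np.clip(train_values, low, high), bins=len(bins))\n    # Map the value to its bin's keyword, clipping out-of-range values to the end bins\n    index = int(np.clip(np.digitize(value, edges) - 1, 0, len(bins) - 1))\n    return bins[index]\n```\n\nA speaking rate at the bottom of the training distribution then maps to `very slowly`, and one at the top to `very fast`.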
\n\nIndeed, some datasets, such as LibriTTS-R, collect data from only one or a few sources; for LibriTTS-R, these are audiobooks, and the process of collecting or processing the data can result in homogeneous data that has little variation. In the case of LibriTTS-R, the data has been cleaned to have little noise and little reverberation, and the audiobooks collected leave little variety in intonation.\n\nYou can learn more about the arguments you can pass to `metadata_to_text.py` by passing:\n\n```sh\npython ./scripts/metadata_to_text.py --help\n```\n\n### 3. Generate natural language descriptions\n\nNow that we have text bins associated with our datasets, the next step is to create natural language descriptions. To achieve this, we pass the discrete features to an LLM, and have it generate a natural language description. This step generally demands more resources and time, and should use one or more GPUs. It can be performed in one of two ways:\n1. Using the [Accelerate](https://huggingface.co/docs/accelerate/index)-based script, [`scripts/run_prompt_creation.py`](/scripts/run_prompt_creation.py), or\n2. Using the [TGI](https://huggingface.co/docs/text-generation-inference/en/index)-based script, [`scripts/run_prompt_creation_llm_swarm.py`](/scripts/run_prompt_creation_llm_swarm.py)\n\nWe recommend you first try the Accelerate script, since it makes no assumptions about the GPU hardware available and is thus easier to run. Should you need faster inference, you can switch to the TGI script, which assumes you have a SLURM cluster with Docker support.\n\n### 3.1 Accelerate Inference\n\n[`scripts/run_prompt_creation.py`](/scripts/run_prompt_creation.py) relies on [`accelerate`](https://huggingface.co/docs/accelerate/index) and [`transformers`](https://huggingface.co/docs/transformers/index) to generate natural language descriptions from LLMs.
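\n\nConceptually, the script folds the discrete keywords into an instruction for the LLM. A toy sketch of this prompt construction (the actual template used by `run_prompt_creation.py` is more elaborate):\n\n```python\ndef build_description_prompt(tags):\n    # tags maps a feature name to its keyword bin, e.g. as produced by metadata_to_text.py\n    keywords = \", \".join(f\"{name}: {keyword}\" for name, keyword in tags.items())\n    return (\n        \"Describe, in one or two natural sentences, a voice with the following \"\n        f\"characteristics: {keywords}.\"\n    )\n\nprompt = build_description_prompt({\"speaking rate\": \"very slowly\", \"pitch\": \"quite monotone\"})\n```\n\nThe LLM's answer to such a prompt becomes the sample's natural language description.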
\n\n[`examples/prompt_creation/run_prompt_creation_1k.sh`](examples/prompt_creation/run_prompt_creation_1k.sh) indicates how to run it on LibriTTS-R with 8 GPUs in half-precision:\n\n```sh\naccelerate launch --multi_gpu --mixed_precision=fp16 --num_processes=8 run_prompt_creation.py \\\n  --dataset_name \"parler-tts/libritts-r-tags-and-text\" \\\n  --dataset_config_name \"clean\" \\\n  --model_name_or_path \"meta-llama/Meta-Llama-3-8B-Instruct\" \\\n  --per_device_eval_batch_size 64 \\\n  --attn_implementation \"sdpa\" \\\n  --torch_compile \\\n  --dataloader_num_workers 4 \\\n  --output_dir \"./\" \\\n  --load_in_4bit \\\n  --push_to_hub \\\n  --hub_dataset_id \"parler-tts/libritts-r-tags-and-text-generated\" \\\n  --is_new_speaker_prompt\n```\n\nAs usual, we define the dataset name and configuration we want to annotate. `model_name_or_path` should point to a `transformers` model for prompt annotation. You can find a list of such models [here](https://huggingface.co/models?pipeline_tag=text-generation&library=transformers&sort=trending). Here, we used an instruction-tuned version of Meta's LLaMA-3 8B model. Should you use LLaMA or Gemma, you can enable torch compile with the flag `--torch_compile` for up to 1.5x faster inference.\n\nThe folder [`examples/prompt_creation/`](examples/prompt_creation/) contains more examples.\n\nIn particular, [`run_prompt_creation_1k_with_speaker_consistency.sh`](examples/prompt_creation/run_prompt_creation_1k_with_speaker_consistency.sh) adapts the previous example but introduces speaker consistency. Here, \"speaker consistency\" simply means associating certain speakers with specific names. In this case, all descriptions linked to these speakers will specify their names, rather than generating anonymous descriptions.\n\n\n> [!TIP]\n> Scripts from this library can also be used as a starting point for applying other models to other datasets from the [datasets library](https://huggingface.co/docs/datasets/v2.17.0/en/index) in large-scale settings.\n> \n> For example, `scripts/run_prompt_creation.py` can be adapted to perform large-scale inference using other LLMs and prompts.\n\n### 3.2 TGI Inference\n\n[`scripts/run_prompt_creation_llm_swarm.py`](/scripts/run_prompt_creation_llm_swarm.py) relies on [TGI](https://huggingface.co/docs/text-generation-inference/en/index) and [LLM-Swarm](https://github.com/huggingface/llm-swarm/tree/main) to generate descriptions from an LLM endpoint. Compared to the Accelerate script, it uses continuous batching, which improves throughput by up to 1.5x. It requires one extra dependency, LLM-Swarm:\n\n```sh\npip install git+https://github.com/huggingface/llm-swarm.git\n```\n\n[`examples/prompt_creation_llm_swarm/run_prompt_creation_1k.sh`](examples/prompt_creation_llm_swarm/run_prompt_creation_1k.sh) indicates how to run it on LibriTTS-R with 1 TGI instance:\n\n```sh\npython run_prompt_creation_llm_swarm.py \\\n  --dataset_name \"stable-speech/libritts-r-tags-and-text\" \\\n  --dataset_config_name \"clean\" \\\n  --model_name_or_path \"mistralai/Mistral-7B-Instruct-v0.2\" \\\n  --num_instances \"1\" \\\n  --output_dir \"./\" \\\n  --push_to_hub \\\n  --hub_dataset_id \"parler-tts/libritts-r-tags-and-text-generated\"\n```\n\nNote that the script relies on the SLURM file [`examples/prompt_creation_llm_swarm/tgi_h100.template.slurm`](examples/prompt_creation_llm_swarm/tgi_h100.template.slurm), which is a template configuration for the Hugging Face H100 cluster. You can update the config based on your cluster.\n\n### To conclude\n\nIn the [`/examples`](/examples/) folder, we applied this recipe to both [MLS Eng](https://huggingface.co/datasets/parler-tts/mls-eng-speaker-descriptions) and [LibriTTS-R](https://huggingface.co/datasets/parler-tts/libritts-r-filtered-speaker-descriptions). The resulting datasets were used to train [Parler-TTS](https://github.com/huggingface/parler-tts), a new text-to-speech model.\n\nThis recipe is both scalable and easily modifiable, and will hopefully help the TTS research community explore new ways of conditioning speech synthesis.\n\n## Using Data-Speech to filter your speech datasets\n\nWhile the rest of the README explains how to use this repository to create text descriptions of speech utterances, Data-Speech can also be used to perform filtering on speech datasets.\n\nFor example, you can:\n1. Use the [`Predict annotations`](#1-predict-annotations) step to predict SNR and reverberation.\n2. Filter your datasets to retain only the highest-quality samples.\n\nYou could also, to give more examples, filter on a certain pitch level (e.g. only low-pitched voices), or a certain speech rate (e.g. only fast speech).\n\n## FAQ\n\n### What kind of datasets do I need?\n\nWe rely on the [`datasets`](https://huggingface.co/docs/datasets/v2.17.0/en/index) library, which is optimized for speed and efficiency, and is deeply integrated with the [HuggingFace Hub](https://huggingface.co/datasets), which allows easy sharing and loading.\n\nIn order to use this repository, you need a speech dataset from [`datasets`](https://huggingface.co/docs/datasets/v2.17.0/en/index) with at least one audio column and a text transcription column. You also need gender and speaker id columns, especially if you want to compute pitch.\n\n### How do I use datasets that I have with this repository?\n\nIf you have a local dataset, and want to create a dataset from [`datasets`](https://huggingface.co/docs/datasets/v2.17.0/en/index) to use Data-Speech, you can use the following recipe or refer to the [`datasets` docs](https://huggingface.co/docs/datasets/v2.17.0/en/index) for more complex use-cases.\n\n1. You first need to create a csv file that contains the **full paths** to the audio files. The column holding those paths could be named `audio`, for example, but any name works. You also need a column with the transcriptions of the audio; it could be named `transcript`, or again, anything you want.\n\n2. Once you have this csv file, you can load it to a dataset like this:\n```python\nfrom datasets import DatasetDict\n\ndataset = DatasetDict.from_csv({\"train\": PATH_TO_CSV_FILE})\n```\n3. You then need to convert the audio column name to [`Audio`](https://huggingface.co/docs/datasets/v2.19.0/en/package_reference/main_classes#datasets.Audio) so that `datasets` understands that it deals with audio files.\n```python\nfrom datasets import Audio\ndataset = dataset.cast_column(\"audio\", Audio())\n```\n4. You can then [push the dataset to the hub](https://huggingface.co/docs/datasets/v2.19.0/en/package_reference/main_classes#datasets.DatasetDict.push_to_hub):\n```python\ndataset.push_to_hub(REPO_ID)\n```\n\nNote that you can make the dataset private by passing [`private=True`](https://huggingface.co/docs/datasets/v2.19.0/en/package_reference/main_classes#datasets.DatasetDict.push_to_hub.private) to the [`push_to_hub`](https://huggingface.co/docs/datasets/v2.19.0/en/package_reference/main_classes#datasets.DatasetDict.push_to_hub) method. Find other possible arguments [here](https://huggingface.co/docs/datasets/v2.19.0/en/package_reference/main_classes#datasets.DatasetDict.push_to_hub).\n\nWhen using Data-Speech, you can then use `REPO_ID` (replace it with the name you chose above) as the dataset name.\n\n## Logs\n\n\n* [August 2024]: Updated version of Data-Speech, suited for Parler-TTS v1\n  * New measures: PESQ and SI-SDR, the latter being used for better noise estimation\n  * Improved prompts\n  * Prompt creation can deal with speaker consistency and accents\n* [April 2024]: Release of the first version of Data-Speech\n\n\n## Acknowledgements\n\nThis library builds on top of a number of open-source giants, to whom we'd like to extend our warmest thanks for providing these tools!\n\nSpecial thanks to:\n- Dan Lyth and Simon King, from Stability AI and Edinburgh University respectively, for publishing such a promising and clear research paper: [Natural language guidance of high-fidelity text-to-speech with synthetic annotations](https://arxiv.org/abs/2402.01912).\n- and the many libraries used, namely [datasets](https://huggingface.co/docs/datasets/v2.17.0/en/index), [brouhaha](https://github.com/marianne-m/brouhaha-vad/blob/main/README.md), [penn](https://github.com/interactiveaudiolab/penn/blob/master/README.md), [g2p](https://github.com/Kyubyong/g2p), [accelerate](https://huggingface.co/docs/accelerate/en/index) and [transformers](https://huggingface.co/docs/transformers/index).\n\n## Citation\n\nIf you found this repository useful, please consider citing this work and also the original Stability AI paper:\n\n```\n@misc{lacombe-etal-2024-dataspeech,\n  author = {Yoach Lacombe and Vaibhav Srivastav and Sanchit Gandhi},\n  title = {Data-Speech},\n  year = {2024},\n  publisher = {GitHub},\n  journal = {GitHub repository},\n  howpublished = {\\url{https://github.com/ylacombe/dataspeech}}\n}\n```\n\n```\n@misc{lyth2024natural,\n      title={Natural language guidance of high-fidelity text-to-speech with synthetic annotations},\n      author={Dan Lyth and Simon King},\n      year={2024},\n      eprint={2402.01912},\n      archivePrefix={arXiv},\n      primaryClass={cs.SD}\n}\n```\n\n### TODOs\n- [ ] Accent classification training script\n- [ ] Accent classification inference script\n- [x] Better speaking rate estimation with long silence removal\n- [x] Better SNR estimation with other SNR models\n- [ ] Add more annotation categories\n- [ ] Multilingual speaking rate estimation\n\n- [ ] (long term) Benchmark for best audio dataset format\n- [ ] (long term) Compatibility with streaming\n"
  },
  {
    "path": "dataspeech/__init__.py",
    "content": "from .cpu_enrichments import rate_apply\nfrom .gpu_enrichments import pitch_apply, snr_apply, squim_apply"
  },
  {
    "path": "dataspeech/cpu_enrichments/__init__.py",
    "content": "from .rate import rate_apply\n\n"
  },
  {
    "path": "dataspeech/cpu_enrichments/rate.py",
    "content": "from g2p import make_g2p\n\ntransducer = make_g2p('eng', 'eng-ipa')\n\ndef rate_apply(batch, rank=None, audio_column_name=\"audio\", text_column_name=\"text\"):\n    if isinstance(batch[text_column_name], list):  \n        speaking_rates = []\n        phonemes_list = []\n        if \"speech_duration\" in batch:\n            for text, audio_duration in zip(batch[text_column_name], batch[\"speech_duration\"]):\n                phonemes = transducer(text).output_string\n                audio_duration = audio_duration if audio_duration != 0 else 0.01\n                speaking_rate = len(phonemes) / audio_duration\n                speaking_rates.append(speaking_rate)\n                phonemes_list.append(phonemes)\n        else:\n            for text, audio in zip(batch[text_column_name], batch[audio_column_name]):\n                phonemes = transducer(text).output_string\n                \n                sample_rate = audio[\"sampling_rate\"]\n                audio_length = len(audio[\"array\"].squeeze()) / sample_rate\n                \n                speaking_rate = len(phonemes) / audio_length\n\n                \n                speaking_rates.append(speaking_rate)\n                phonemes_list.append(phonemes)\n        \n        batch[\"speaking_rate\"] = speaking_rates\n        batch[\"phonemes\"] = phonemes_list\n    else:\n        phonemes = transducer(batch[text_column_name]).output_string\n        if \"speech_duration\" in batch:\n            audio_length = batch[\"speech_duration\"] if batch[\"speech_duration\"] != 0 else 0.01\n        else:\n            sample_rate = batch[audio_column_name][\"sampling_rate\"]\n            audio_length = len(batch[audio_column_name][\"array\"].squeeze()) / sample_rate\n\n        speaking_rate = len(phonemes) / audio_length\n        \n        batch[\"speaking_rate\"] = speaking_rate\n        batch[\"phonemes\"] = phonemes\n\n    return batch"
  },
  {
    "path": "dataspeech/gpu_enrichments/__init__.py",
    "content": "from .pitch import pitch_apply\nfrom .snr_and_reverb import snr_apply\nfrom .squim import squim_apply"
  },
  {
    "path": "dataspeech/gpu_enrichments/pitch.py",
    "content": "import torch \nimport penn\n\n\n# Here we'll use a 10 millisecond hopsize\nhopsize = .01\n\n# Provide a sensible frequency range given your domain and model\nfmin = 30.\nfmax = 1000.\n\n# Select a checkpoint to use for inference. Selecting None will\n# download and use FCNF0++ pretrained on MDB-stem-synth and PTDB\ncheckpoint = None\n\n# Centers frames at hopsize / 2, 3 * hopsize / 2, 5 * hopsize / 2, ...\ncenter = 'half-hop'\n\n# (Optional) Linearly interpolate unvoiced regions below periodicity threshold\ninterp_unvoiced_at = .065\n\n\ndef pitch_apply(batch, rank=None, audio_column_name=\"audio\", output_column_name=\"utterance_pitch\", penn_batch_size=4096):\n    if isinstance(batch[audio_column_name], list):  \n        utterance_pitch_mean = []\n        utterance_pitch_std = []\n        for sample in batch[audio_column_name]:\n            # Infer pitch and periodicity\n            pitch, periodicity = penn.from_audio(\n                torch.tensor(sample[\"array\"][None, :]).float(),\n                sample[\"sampling_rate\"],\n                hopsize=hopsize,\n                fmin=fmin,\n                fmax=fmax,\n                checkpoint=checkpoint,\n                batch_size=penn_batch_size,\n                center=center,\n                interp_unvoiced_at=interp_unvoiced_at,\n                gpu=(rank or 0)% torch.cuda.device_count() if torch.cuda.device_count() > 0 else rank\n                )\n            \n            utterance_pitch_mean.append(pitch.mean().cpu())\n            utterance_pitch_std.append(pitch.std().cpu())\n            \n        batch[f\"{output_column_name}_mean\"] = utterance_pitch_mean \n        batch[f\"{output_column_name}_std\"] = utterance_pitch_std \n    else:\n        sample = batch[audio_column_name]\n        pitch, periodicity = penn.from_audio(\n                torch.tensor(sample[\"array\"][None, :]).float(),\n                sample[\"sampling_rate\"],\n                hopsize=hopsize,\n                
fmin=fmin,\n                fmax=fmax,\n                checkpoint=checkpoint,\n                batch_size=penn_batch_size,\n                center=center,\n                interp_unvoiced_at=interp_unvoiced_at,\n                gpu=(rank or 0)% torch.cuda.device_count() if torch.cuda.device_count() > 0 else rank\n                )        \n        batch[f\"{output_column_name}_mean\"] = pitch.mean().cpu()\n        batch[f\"{output_column_name}_std\"] = pitch.std().cpu()\n\n    return batch\n"
  },
  {
    "path": "dataspeech/gpu_enrichments/snr_and_reverb.py",
    "content": "from pyannote.audio import Model\nfrom pathlib import Path\nfrom brouhaha.pipeline import RegressiveActivityDetectionPipeline\nimport torch \nfrom huggingface_hub import hf_hub_download\nimport numpy as np\n\nmodel = None\nratio = 16000/270\n\ndef snr_apply(batch, rank=None, audio_column_name=\"audio\", batch_size=32):\n    global model\n    if model is None:\n        model = Model.from_pretrained(\n            Path(hf_hub_download(repo_id=\"ylacombe/brouhaha-best\", filename=\"best.ckpt\")),\n            strict=False,\n        )\n    if rank is not None or torch.cuda.device_count() > 0:\n        # move the model to the right GPU if not there already\n        device = f\"cuda:{(rank or 0)% torch.cuda.device_count()}\"\n        # move to device and create pipeline here because the pipeline moves to the first GPU it finds anyway\n        model.to(device)\n\n    pipeline = RegressiveActivityDetectionPipeline(segmentation=model, batch_size = batch_size)\n    if rank:\n        pipeline.to(torch.device(device))\n    \n    device = pipeline._models[\"segmentation\"].device\n\n    if isinstance(batch[audio_column_name], list):  \n        snr = []\n        c50 = []\n        vad_durations = []\n        for sample in batch[audio_column_name]:\n            res = pipeline({\"sample_rate\": sample[\"sampling_rate\"],\n                            \"waveform\": torch.tensor(sample[\"array\"][None, :]).to(device).float()})\n            \n            mask = np.full(res[\"snr\"].shape, False)\n            for (segment, _) in res[\"annotation\"].itertracks():\n                start = int(segment.start * ratio)\n                end = int(segment.end * ratio)\n                mask[start:end] = True\n            mask =  (~((res[\"snr\"] == 0.0) & (res[\"c50\"] == 0.0)) & mask)\n\n            vad_duration = sum(map(lambda x: x[0].duration, res[\"annotation\"].itertracks()))\n            \n            snr.append(res[\"snr\"][mask].mean())\n            
c50.append(res[\"c50\"][mask].mean())\n            vad_durations.append(np.float32(vad_duration))\n        \n        # 16ms window\n        batch[\"snr\"] = snr\n        batch[\"c50\"] = c50\n        batch[\"speech_duration\"] = vad_durations\n        \n    else:\n        res = pipeline({\"sample_rate\": batch[audio_column_name][\"sampling_rate\"],\n                        \"waveform\": torch.tensor(batch[audio_column_name][\"array\"][None, :]).to(device).float()})\n        \n        mask = np.full(res[\"snr\"].shape, False)\n        for (segment, _) in res[\"annotation\"].itertracks():\n            start = int(segment.start * ratio)\n            end = int(segment.end * ratio)\n            mask[start:end] = True\n        mask =  (~((res[\"snr\"] == 0.0) & (res[\"c50\"] == 0.0)) & mask)\n\n        vad_duration = sum(map(lambda x: x[0].duration, res[\"annotation\"].itertracks()))     \n        \n        batch[\"snr\"] = res[\"snr\"][mask].mean()\n        batch[\"c50\"] = res[\"c50\"][mask].mean()\n        batch[\"speech_duration\"] = vad_duration\n        \n    return batch"
  },
  {
    "path": "dataspeech/gpu_enrichments/squim.py",
    "content": "from torchaudio.pipelines import SQUIM_OBJECTIVE\nimport torch \nimport torchaudio\n\nmodel = None\nmax_audio_length = 15 * SQUIM_OBJECTIVE.sample_rate\n\ndef squim_apply(batch, rank=None, audio_column_name=\"audio\"):\n    global model\n    if model is None:\n        model = SQUIM_OBJECTIVE.get_model()\n    if rank is not None or torch.cuda.device_count() > 0:\n        # move the model to the right GPU if not there already\n        device = f\"cuda:{(rank or 0)% torch.cuda.device_count()}\"\n        # move to device and create pipeline here because the pipeline moves to the first GPU it finds anyway\n        model.to(device)\n    else:\n        device = \"cpu\"\n    if isinstance(batch[audio_column_name], list):  \n        sdr = []\n        pesq = []\n        stoi = []\n        for sample in batch[audio_column_name]:\n            waveform = torchaudio.functional.resample(torch.tensor(sample[\"array\"])[None, :].to(device).float(), sample[\"sampling_rate\"], SQUIM_OBJECTIVE.sample_rate)\n            with torch.no_grad():\n                waveform = waveform[:, :min(max_audio_length, waveform.shape[1])]\n                stoi_sample, pesq_sample, sdr_sample = model(waveform)\n            sdr.append(sdr_sample.cpu()[0])\n            pesq.append(pesq_sample.cpu()[0])\n            stoi.append(stoi_sample.cpu()[0])\n\n        batch[\"sdr\"] = sdr\n        batch[\"pesq\"] = pesq\n        batch[\"stoi\"] = stoi\n    else:\n    \n        waveform = torchaudio.functional.resample(torch.tensor(batch[audio_column_name][\"array\"][None, :]).to(device).float(), batch[audio_column_name][\"sampling_rate\"], SQUIM_OBJECTIVE.sample_rate)\n        with torch.no_grad():\n            stoi_sample, pesq_sample, sdr_sample = model(waveform)\n        batch[\"sdr\"] = sdr_sample.cpu()[0]\n        batch[\"pesq\"] = pesq_sample.cpu()[0]\n        batch[\"stoi\"] = stoi_sample.cpu()[0]\n        # TODO\n    return batch\n\n"
  },
  {
    "path": "examples/prompt_creation/run_prompt_creation_10k.sh",
    "content": "#!/usr/bin/env bash\n\naccelerate launch --multi_gpu --mixed_precision=fp16 --num_processes=8 run_prompt_creation.py \\\n  --dataset_name \"ylacombe/libritts_r_tags_tagged_10k\" \\\n  --dataset_config_name \"clean\" \\\n  --model_name_or_path \"mistralai/Mistral-7B-Instruct-v0.2\" \\\n  --per_device_eval_batch_size 64 \\\n  --attn_implementation \"sdpa\" \\\n  --dataloader_num_workers 4 \\\n  --output_dir \"./libritts_r_tags_tagged_10k_generated\" \\\n  --load_in_4bit \\\n  --push_to_hub \\\n  --hub_dataset_id \"parler-tts/libritts_r_tags_tagged_10k_generated\"\n\naccelerate launch --multi_gpu --mixed_precision=fp16 --num_processes=8 run_prompt_creation.py \\\n  --dataset_name \"ylacombe/libritts_r_tags_tagged_10k\" \\\n  --dataset_config_name \"other\" \\\n  --model_name_or_path \"mistralai/Mistral-7B-Instruct-v0.2\" \\\n  --per_device_eval_batch_size 64 \\\n  --attn_implementation \"sdpa\" \\\n  --dataloader_num_workers 4 \\\n  --output_dir \"./libritts_r_tags_tagged_10k_generated\" \\\n  --load_in_4bit \\\n  --push_to_hub \\\n  --hub_dataset_id \"parler-tts/libritts_r_tags_tagged_10k_generated\"\n\naccelerate launch --multi_gpu --mixed_precision=fp16 --num_processes=8 run_prompt_creation.py \\\n  --dataset_name \"ylacombe/mls-eng-10k-tags_tagged_10k\" \\\n  --model_name_or_path \"mistralai/Mistral-7B-Instruct-v0.2\" \\\n  --per_device_eval_batch_size 64 \\\n  --attn_implementation \"sdpa\" \\\n  --dataloader_num_workers 4 \\\n  --output_dir \"./mls-eng-10k-tags_tagged_10k_generated\" \\\n  --load_in_4bit \\\n  --push_to_hub \\\n  --hub_dataset_id \"parler-tts/mls-eng-10k-tags_tagged_10k_generated\"\n"
  },
  {
    "path": "examples/prompt_creation/run_prompt_creation_1k.sh",
    "content": "#!/usr/bin/env bash\n\naccelerate launch --multi_gpu --mixed_precision=fp16 --num_processes=8 run_prompt_creation.py \\\n  --dataset_name \"parler-tts/libritts-r-tags-and-text\" \\\n  --dataset_config_name \"clean\" \\\n  --model_name_or_path \"mistralai/Mistral-7B-Instruct-v0.2\" \\\n  --per_device_eval_batch_size 64 \\\n  --attn_implementation \"sdpa\" \\\n  --dataloader_num_workers 4 \\\n  --output_dir \"./\" \\\n  --push_to_hub \\\n  --is_new_speaker_prompt \\\n  --hub_dataset_id \"parler-tts/libritts-r-tags-and-text-generated\"\n\naccelerate launch --multi_gpu --mixed_precision=fp16 --num_processes=8 run_prompt_creation.py \\\n  --dataset_name \"parler-tts/libritts-r-tags-and-text\" \\\n  --dataset_config_name \"other\" \\\n  --model_name_or_path \"mistralai/Mistral-7B-Instruct-v0.2\" \\\n  --per_device_eval_batch_size 64 \\\n  --attn_implementation \"sdpa\" \\\n  --dataloader_num_workers 4 \\\n  --output_dir \"./\" \\\n  --push_to_hub \\\n  --is_new_speaker_prompt \\\n  --hub_dataset_id \"parler-tts/libritts-r-tags-and-text-generated\"\n"
  },
  {
    "path": "examples/prompt_creation/run_prompt_creation_1k_with_speaker_consistency.sh",
    "content": "#!/usr/bin/env bash\n\naccelerate launch --multi_gpu --mixed_precision=fp16 --num_processes=8 run_prompt_creation.py \\\n  --dataset_name \"parler-tts/libritts-r-tags-and-text\" \\\n  --dataset_config_name \"clean\" \\\n  --model_name_or_path \"mistralai/Mistral-7B-Instruct-v0.2\" \\\n  --per_device_eval_batch_size 64 \\\n  --attn_implementation \"sdpa\" \\\n  --dataloader_num_workers 4 \\\n  --output_dir \"./\" \\\n  --push_to_hub \\\n  --is_new_speaker_prompt \\\n  --speaker_id_column 'speaker_id' \\\n  --hub_dataset_id \"parler-tts/libritts-r-tags-and-text-generated\" \\\n  --speaker_ids_to_name_json ./examples/prompt_creation/speaker_ids_to_names.json \\\n\naccelerate launch --multi_gpu --mixed_precision=fp16 --num_processes=8 run_prompt_creation.py \\\n  --dataset_name \"parler-tts/libritts-r-tags-and-text\" \\\n  --dataset_config_name \"other\" \\\n  --model_name_or_path \"mistralai/Mistral-7B-Instruct-v0.2\" \\\n  --per_device_eval_batch_size 64 \\\n  --attn_implementation \"sdpa\" \\\n  --dataloader_num_workers 4 \\\n  --output_dir \"./\" \\\n  --push_to_hub \\\n  --is_new_speaker_prompt \\\n  --speaker_id_column 'speaker_id' \\\n  --hub_dataset_id \"parler-tts/libritts-r-tags-and-text-generated\" \\\n  --speaker_ids_to_name_json ./examples/prompt_creation/speaker_ids_to_names.json \\\n\n"
  },
  {
    "path": "examples/prompt_creation/run_prompt_creation_45k.sh",
    "content": "#!/usr/bin/env bash\n\naccelerate launch --multi_gpu --mixed_precision=fp16 --num_processes=8 run_prompt_creation.py \\\n  --dataset_name \"ylacombe/libritts-r-text-tags-v4\" \\\n  --dataset_config_name \"clean\" \\\n  --model_name_or_path \"mistralai/Mistral-7B-Instruct-v0.2\" \\\n  --per_device_eval_batch_size 64 \\\n  --attn_implementation \"sdpa\" \\\n  --dataloader_num_workers 4 \\\n  --output_dir \"./libritts_r_descriptions_clean\" \\\n  --push_to_hub \\\n  --is_new_speaker_prompt \\\n  --speaker_id_column 'speaker_id' \\\n  --speaker_ids_to_name_json ./examples/prompt_creation/speaker_ids_to_names.json \\\n  --hub_dataset_id \"ylacombe/libritts-r-descriptions-10k-v5-without-accents\"\n\naccelerate launch --multi_gpu --mixed_precision=fp16 --num_processes=8 run_prompt_creation.py \\\n  --dataset_name \"ylacombe/libritts-r-text-tags-v4\" \\\n  --dataset_config_name \"other\" \\\n  --model_name_or_path \"mistralai/Mistral-7B-Instruct-v0.2\" \\\n  --per_device_eval_batch_size 64 \\\n  --attn_implementation \"sdpa\" \\\n  --dataloader_num_workers 4 \\\n  --output_dir \"./libritts_r_descriptions_other\" \\\n  --push_to_hub \\\n  --is_new_speaker_prompt \\\n  --speaker_id_column 'speaker_id' \\\n  --speaker_ids_to_name_json ./examples/prompt_creation/speaker_ids_to_names.json \\\n  --hub_dataset_id \"ylacombe/libritts-r-descriptions-10k-v5-without-accents\"\n\naccelerate launch --multi_gpu --mixed_precision=fp16 --num_processes=8 run_prompt_creation.py \\\n  --dataset_name \"ylacombe/mls-eng-text-tags-v5\" \\\n  --model_name_or_path \"mistralai/Mistral-7B-Instruct-v0.2\" \\\n  --per_device_eval_batch_size 64 \\\n  --attn_implementation \"sdpa\" \\\n  --dataloader_num_workers 4 \\\n  --output_dir \"./mls-eng-descriptions\" \\\n  --push_to_hub \\\n  --is_new_speaker_prompt \\\n  --speaker_id_column 'speaker_id' \\\n  --speaker_ids_to_name_json ./examples/prompt_creation/speaker_ids_to_names.json \\\n  --hub_dataset_id 
\"parler-tts/mls-eng-speaker-descriptions\"\n"
  },
  {
    "path": "examples/prompt_creation/run_prompt_creation_dummy.sh",
    "content": "#!/usr/bin/env bash\n\npython run_prompt_creation.py \\\n  --dataset_name \"ylacombe/libritts_r_tags_and_text\" \\\n  --dataset_config_name \"clean\" \\\n  --dataset_split_name \"dev.clean\" \\\n  --model_name_or_path \"hf-internal-testing/tiny-random-LlamaForCausalLM\" \\\n  --per_device_eval_batch_size 2 \\\n  --attn_implementation \"sdpa\" \\\n  --torch_compile \\\n  --max_eval_samples 128 \\\n  --max_new_tokens 4 \\\n  --dataloader_num_workers 0 \\\n  --save_steps 32 \\\n  --save_total_limit 2 \\\n  --output_dir \"./\" \\\n  --do_sample False\n"
  },
  {
    "path": "examples/prompt_creation/run_prompt_creation_jenny.sh",
    "content": "#!/usr/bin/env bash\npython ./scripts/run_prompt_creation.py \\\n  --speaker_name \"Jenny\" \\\n  --is_single_speaker \\\n  --is_new_speaker_prompt \\\n  --dataset_name \"ylacombe/jenny-tts-tags-v1\" \\\n  --dataset_config_name \"default\" \\\n  --model_name_or_path \"mistralai/Mistral-7B-Instruct-v0.2\" \\\n  --per_device_eval_batch_size 128 \\\n  --attn_implementation \"sdpa\" \\\n  --dataloader_num_workers 8 \\\n  --output_dir \"./tmp_jenny\" \\\n  --load_in_4bit \\\n  --push_to_hub \\\n  --hub_dataset_id \"jenny-tts-tagged-v1\" \\\n  --preprocessing_num_workers 48 \\\n  --dataloader_num_workers 24"
  },
  {
    "path": "examples/prompt_creation/speaker_ids_to_names.json",
    "content": "{\n    \"192\": \"Brenda\",\n    \"274\": \"Eileen\",\n    \"392\": \"Joy\",\n    \"409\": \"James\",\n    \"412\": \"Eric\",\n    \"505\": \"Aaron\",\n    \"887\": \"Emily\",\n    \"1088\": \"Laura\",\n    \"1112\": \"Gary\",\n    \"1355\": \"Jon\",\n    \"1502\": \"Lea\",\n    \"1509\": \"Karen\",\n    \"1646\": \"Rick\",\n    \"2588\": \"David\",\n    \"2769\": \"Jordan\",\n    \"2990\": \"Mike\",\n    \"3114\": \"Yann\",\n    \"4195\": \"Lauren\",\n    \"4297\": \"Rose\",\n    \"4397\": \"Will\",\n    \"4719\": \"Jason\",\n    \"5514\": \"Naomie\",\n    \"5724\": \"Alisa\",\n    \"5746\": \"Patrick\",\n    \"5909\": \"Jerry\",\n    \"6054\": \"Tina\",\n    \"6904\": \"Jenna\",\n    \"6912\": \"Bill\",\n    \"7190\": \"Tom\",\n    \"7434\": \"Carol\",\n    \"7584\": \"Barbara\",\n    \"7789\": \"Rebecca\",\n    \"8684\": \"Anna\",\n    \"8791\": \"Bruce\"\n  }"
  },
  {
    "path": "examples/prompt_creation_llm_swarm/nginx.template.conf",
    "content": "events {\n    # resolve \"worker_connections are not enough while connecting to upstream\"\n    # https://stackoverflow.com/questions/28265717/worker-connections-are-not-enough\n    worker_connections 100000;\n}\n\nhttp {\n    upstream mytgi {\n        least_conn;\n        {{servers}}\n    }\n\n    server {\n        listen {{port}};\n\n        location / {\n            proxy_pass http://mytgi;\n            proxy_read_timeout 300s;  # Increase this to 300 seconds (5 minutes)\n            proxy_connect_timeout 60s;  # Increase this to 60 seconds (1 minute)\n        }\n    }\n}\n\n\n# sudo docker run  -p 80:80 --network host -v $(pwd)/nginx.conf:/etc/nginx/nginx.conf nginx\n# curl 127.0.0.1:80/generate \\\n#     -X POST \\\n#     -d '{\"inputs\":\"What is Deep Learning?\",\"parameters\":{\"max_new_tokens\":20}}' \\\n#     -H 'Content-Type: application/json'"
  },
  {
    "path": "examples/prompt_creation_llm_swarm/run_prompt_creation_10k.sh",
    "content": "#!/usr/bin/env bash\n\npython run_prompt_creation_llm_swarm.py \\\n  --dataset_name \"ylacombe/mls-eng-10k-text-tags-v2\" \\\n  --model_name_or_path \"mistralai/Mistral-7B-Instruct-v0.2\" \\\n  --num_instances \"2\" \\\n  --output_dir \"./mls-eng-10k-descriptions-v2\" \\\n  --push_to_hub \\\n  --hub_dataset_id \"stable-speech/mls-eng-10k-descriptions-v2\"\n"
  },
  {
    "path": "examples/prompt_creation_llm_swarm/run_prompt_creation_1k.sh",
    "content": "#!/usr/bin/env bash\n\npython run_prompt_creation_llm_swarm.py \\\n  --dataset_name \"stable-speech/libritts-r-tags-and-text\" \\\n  --dataset_config_name \"clean\" \\\n  --model_name_or_path \"mistralai/Mistral-7B-Instruct-v0.2\" \\\n  --output_dir \"./\"\n\npython run_prompt_creation_llm_swarm.py \\\n  --dataset_name \"stable-speech/libritts-r-tags-and-text\" \\\n  --dataset_config_name \"other\" \\\n  --model_name_or_path \"mistralai/Mistral-7B-Instruct-v0.2\" \\\n  --output_dir \"./\"\n"
  },
  {
    "path": "examples/prompt_creation_llm_swarm/run_prompt_creation_dummy.sh",
    "content": "#!/usr/bin/env bash\n\npython run_prompt_creation_llm_swarm.py \\\n  --dataset_name \"stable-speech/libritts-r-tags-and-text\" \\\n  --dataset_config_name \"clean\" \\\n  --dataset_split_name \"train.clean.100\" \\\n  --model_name_or_path \"mistralai/Mistral-7B-Instruct-v0.2\" \\\n  --output_dir \"./libritts-r-descriptions\"\n"
  },
  {
    "path": "examples/prompt_creation_llm_swarm/run_prompt_creation_full_mls.sh",
    "content": "#!/usr/bin/env bash\n\npython ./run_prompt_creation_llm_swarm.py \\\n  --dataset_name 'ylacombe/mls-eng-text-tags-v5' \\\n  --dataset_config_name 'default' \\\n  --model_name_or_path \"mistralai/Mistral-7B-Instruct-v0.2\" \\\n  --num_instances \"8\" \\\n  --output_dir \"./tmp_mls_with_accent\" \\\n  --push_to_hub \\\n  --hub_dataset_id 'ylacombe/mls-eng-descriptions-v5' \\\n  --temperature 1.2 \\\n  --is_new_speaker_prompt \\\n  --speaker_id_column 'speaker_id' \\\n  --speaker_ids_to_name_json ./examples/prompt_creation/speaker_ids_to_names.json \\\n  --accent_column 'accent'"
  },
  {
    "path": "examples/prompt_creation_llm_swarm/tgi_h100.template.slurm",
    "content": "#!/bin/bash\n#SBATCH --job-name=llm-swarm\n#SBATCH --partition hopper-prod\n#SBATCH --gpus={{gpus}}\n#SBATCH --cpus-per-task=12\n#SBATCH --mem-per-cpu=11G\n#SBATCH -o slurm/logs/%x_%j.out\n\n# START EDIT\nsource ~/.bashrc\nVOLUME=\"/fsx/yoach/.cache\"\n# END EDIT\n\nexport model={{model}}\nexport revision={{revision}}\n\nfunction unused_port() {\n    N=${1:-1}\n    comm -23 \\\n        <(seq \"1025\" \"65535\" | sort) \\\n        <(ss -Htan |\n            awk '{print $4}' |\n            cut -d':' -f2 |\n            sort -u) |\n        shuf |\n        head -n \"$N\"\n}\nexport PORT=$(unused_port)\n\nif [ -z \"$HUGGING_FACE_HUB_TOKEN\" ]; then\n    # try reading from file\n    export HUGGING_FACE_HUB_TOKEN=$(cat \"${HF_HOME}\"/token)\nfi\n\necho \"Starting TGI container port $PORT\"\necho \"http://$(hostname -I | awk '{print $1}'):$PORT\" >> {{slurm_hosts_path}}\n\n# unset cache dirs to avoid pyxis having host env var somehow get into the container\nunset HF_HUB_CACHE HF_ASSETS_CACHE HF_DATASETS_CACHE HF_MODULES_CACHE\nsrun --container-image='ghcr.io#huggingface/text-generation-inference:2.0' \\\n    --container-env=HUGGING_FACE_HUB_TOKEN,PORT \\\n    --container-mounts=\"${VOLUME}:/data\" \\\n    --no-container-mount-home \\\n    --qos normal \\\n    /usr/local/bin/text-generation-launcher \\\n    --model-id $model \\\n    --revision $revision \\\n    --max-concurrent-requests 2500 \\\n    --max-total-tokens {{model_max_length}} \\\n    --max-input-length {{model_input_length}} \\\n    --max-batch-prefill-tokens {{model_max_length}} \\\n\necho \"End of job\""
  },
  {
    "path": "examples/tagging/run_main_10k.sh",
    "content": "#!/usr/bin/env bash\n\npython main.py \"blabble-io/libritts_r\" \\\n    --configuration \"clean\" \\\n    --output_dir ./tmp_libritts_r_clean/ \\\n    --text_column_name \"text_normalized\" \\\n    --audio_column_name \"audio\" \\\n    --cpu_num_workers 32 \\\n    --num_workers_per_gpu 4 \\\n    --rename_column \\\n    --repo_id \"ylacombe/libritts_r_tags\"\\\n\npython main.py \"blabble-io/libritts_r\" \\\n    --configuration \"other\" \\\n    --output_dir ./tmp_libritts_r_other/ \\\n    --text_column_name \"text_normalized\" \\\n    --audio_column_name \"audio\" \\\n    --cpu_num_workers 32 \\\n    --num_workers_per_gpu 4 \\\n    --rename_column \\\n    --repo_id \"ylacombe/libritts_r_tags\"\\\n\npython main.py \"parler-tts/mls_eng_10k\" \\\n    --output_dir ./tmp_mls_eng_10k/ \\\n    --text_column_name \"transcript\" \\\n    --audio_column_name \"audio\" \\\n    --cpu_num_workers 32 \\\n    --num_workers_per_gpu 4 \\\n    --rename_column \\\n    --repo_id \"ylacombe/mls_eng_10k_tags\"\\"
  },
  {
    "path": "examples/tagging/run_main_1k.sh",
    "content": "#!/usr/bin/env bash\n\npython main.py \"blabble-io/libritts_r\" \\\n    --configuration \"clean\" \\\n    --output_dir ./tmp_libritts_r_clean/ \\\n    --text_column_name \"text_normalized\" \\\n    --audio_column_name \"audio\" \\\n    --cpu_num_workers 32 \\\n    --num_workers_per_gpu 4 \\\n    --rename_column \\\n    --repo_id \"ylacombe/libritts-r-text-tags-v3\"\\\n    --apply_squim_quality_estimation \\\n\n\npython main.py \"blabble-io/libritts_r\" \\\n    --configuration \"other\" \\\n    --output_dir ./tmp_libritts_r_other/ \\\n    --text_column_name \"text_normalized\" \\\n    --audio_column_name \"audio\" \\\n    --cpu_num_workers 32 \\\n    --num_workers_per_gpu 4 \\\n    --rename_column \\\n    --repo_id \"ylacombe/libritts-r-text-tags-v3\"\\\n    --apply_squim_quality_estimation \\\n\n"
  },
  {
    "path": "examples/tagging/run_main_45k.sh",
    "content": "#!/usr/bin/env bash\n\npython main.py \"blabble-io/libritts_r\" \\\n    --configuration \"clean\" \\\n    --output_dir ./tmp_libritts_r_clean/ \\\n    --text_column_name \"text_normalized\" \\\n    --audio_column_name \"audio\" \\\n    --cpu_num_workers 32 \\\n    --num_workers_per_gpu 4 \\\n    --rename_column \\\n    --repo_id \"ylacombe/libritts-r-text-tags-v3\"\\\n    --apply_squim_quality_estimation \\\n\n\npython main.py \"blabble-io/libritts_r\" \\\n    --configuration \"other\" \\\n    --output_dir ./tmp_libritts_r_other/ \\\n    --text_column_name \"text_normalized\" \\\n    --audio_column_name \"audio\" \\\n    --cpu_num_workers 32 \\\n    --num_workers_per_gpu 4 \\\n    --rename_column \\\n    --repo_id \"ylacombe/libritts-r-text-tags-v3\"\\\n    --apply_squim_quality_estimation \\\n\npython main.py \"parler-tts/mls_eng\" \\\n    --output_dir ./tmp_mls_eng/ \\\n    --text_column_name \"transcript\" \\\n    --audio_column_name \"audio\" \\\n    --cpu_num_workers 32 \\\n    --num_workers_per_gpu 4 \\\n    --rename_column \\\n    --repo_id \"ylacombe/mls-eng-tags-v4\"\\\n    --apply_squim_quality_estimation \\\n"
  },
  {
    "path": "examples/tagging/run_main_dummy.sh",
    "content": "#!/usr/bin/env bash\n\npython main.py \"blabble-io/libritts_r\" \\\n    --configuration \"dev\" \\\n    --output_dir ./tmp_libritts_r_dev/ \\\n    --text_column_name \"text_normalized\" \\\n    --audio_column_name \"audio\" \\\n    --cpu_num_workers 8 \\\n    --num_workers_per_gpu 4 \\\n    --rename_column \\"
  },
  {
    "path": "examples/tags_to_annotations/run_metadata_to_text_10k.sh",
    "content": "#!/usr/bin/env bash\n\npython ./scripts/metadata_to_text.py \"ylacombe/mls-eng-10k-tags+ylacombe/libritts_r_tags+ylacombe/libritts_r_tags\" \\\n    --configuration \"default+clean+other\" \\\n    --output_dir \"./tmp_mls+./tmp_tts_clean+./tmp_tts_other\" \\\n    --cpu_num_workers \"8\" \\\n    --leading_split_for_bins \"train\" \\\n    --plot_directory \"./plots/\" \\\n    --save_bin_edges \"./examples/tags_to_annotations/v01_bin_edges.json\" \\\n    --only_save_plot\n"
  },
  {
    "path": "examples/tags_to_annotations/run_metadata_to_text_10k_v02.sh",
    "content": "#!/usr/bin/env bash\n\npython ./scripts/metadata_to_text.py \"ylacombe/mls-eng-10k-tags+ylacombe/libritts_r_tags+ylacombe/libritts_r_tags\" \\\n    --configuration \"default+clean+other\" \\\n    --output_dir \"./tmp_mls+./tmp_tts_clean+./tmp_tts_other\" \\\n    --cpu_num_workers \"8\" \\\n    --leading_split_for_bins \"train\" \\\n    --plot_directory \"./plots/\" \\\n    --save_bin_edges \"./examples/tags_to_annotations/v02_bin_edges.json\" \\\n    --path_to_text_bins \".examples/tags_to_annotations/v02_text_bins.json\" \\\n    --pitch_std_tolerance \"1.5\"\\\n    --reverberation_std_tolerance \"8.\"\\\n    --speech_monotony_std_tolerance \"2.\"\\\n    --speaking_rate_std_tolerance \"5.5\"\\\n    --snr_std_tolerance \"3.5\"\\\n    --only_save_plot"
  },
  {
    "path": "examples/tags_to_annotations/run_metadata_to_text_for_finetuning.sh",
    "content": "#!/usr/bin/env bash\n\npython ./scripts/metadata_to_text.py \\\n    \"ylacombe/jenny-tts-tags-v1\" \\\n    --repo_id \"jenny-tts-tags-v1\" \\\n    --configuration \"default\" \\\n    --cpu_num_workers \"8\" \\\n    --path_to_bin_edges \"./examples/tags_to_annotations/v02_bin_edges.json\" \\\n    --path_to_text_bins \"./examples/tags_to_annotations/v02_text_bins.json\" \\\n    --avoid_pitch_computation \\\n    --apply_squim_quality_estimation \\\n\n"
  },
  {
    "path": "examples/tags_to_annotations/v01_bin_edges.json",
    "content": "{\"speaking_rate\": [3.508771929824561, 6.187242299296628, 8.865712668768696, 11.544183038240764, 14.22265340771283, 16.901123777184896, 19.579594146656966, 22.258064516129032], \"noise\": [27.179607391357422, 33.90050179617746, 40.62139620099749, 47.342290605817524, 54.063185010637554, 60.78407941545759, 67.50497382027763, 74.22586822509766], \"reverberation\": [30.498437881469727, 34.706024169921875, 38.91361045837402, 43.12119674682617, 47.32878303527832, 51.53636932373047, 55.74395561218262, 59.951541900634766], \"speech_monotony\": [0.0, 17.430070059640066, 34.86014011928013, 52.2902101789202, 69.72028023856026, 87.15035029820032, 104.5804203578404, 122.01049041748047], \"pitch_bins_male\": [74.04898071289062, 88.6379623413086, 103.22694396972656, 117.81592559814453, 132.4049072265625, 146.993896484375, 161.58287048339844, 176.17185974121094], \"pitch_bins_female\": [130.46119689941406, 149.0537567138672, 167.64630126953125, 186.23886108398438, 204.83140563964844, 223.42396545410156, 242.01651000976562, 260.60906982421875]}"
  },
  {
    "path": "examples/tags_to_annotations/v01_text_bins.json",
    "content": "{\n    \"speaker_rate_bins\":\n        [\"very slowly\", \"quite slowly\", \"slightly slowly\", \"moderate speed\", \"slightly fast\", \"quite fast\", \"very fast\"],\n    \"snr_bins\":\n        [\"very noisy\", \"quite noisy\", \"slightly noisy\", \"moderate ambient sound\", \"slightly clear\", \"quite clear\", \"very clear\"],\n    \"reverberation_bins\":\n        [\"very roomy sounding\", \"quite roomy sounding\", \"slightly roomy sounding\", \"moderate reverberation\", \"slightly confined sounding\", \"quite confined sounding\", \"very confined sounding\"],\n    \"utterance_level_std\":\n        [\"very monotone\", \"quite monotone\", \"slightly monotone\", \"moderate intonation\", \"slightly expressive\", \"quite expressive\", \"very expressive\"],\n    \"speaker_level_pitch_bins\":\n        [\"very low pitch\", \"quite low pitch\", \"slightly low pitch\", \"moderate pitch\", \"slightly high pitch\", \"quite high pitch\", \"very high pitch\"]\n}"
  },
  {
    "path": "examples/tags_to_annotations/v02_bin_edges.json",
    "content": "{\n    \"speaking_rate\": [0.0, 3.8258038258038254, 7.651607651607651, 11.477411477411476, 15.303215303215302, 19.129019129019127, 22.95482295482295, 26.78062678062678], \n    \"noise\": [17.12751579284668, 25.4012325831822, 33.67494937351772, 41.94866616385323, 50.22238295418875, 58.49609974452427, 66.76981653485979, 75.04353332519531], \n    \"reverberation\": [10, 35, 45, 55, 59, 60], \n    \"speech_monotony\": [0.0, 20.37920924595424, 40.75841849190848, 70, 90, 142.6544647216797], \n    \"pitch_bins_male\": [64.6531982421875, 81.66683959960938, 98.68048095703125, 115.69412231445312, 132.707763671875, 149.72140502929688, 166.73504638671875, 183.74868774414062], \n    \"pitch_bins_female\": [120.17855072021484, 141.6242690945264, 163.06998746883795, 184.51570584314953, 205.96142421746106, 227.40714259177264, 248.8528609660842, 270.29857934039575], \n    \"si-sdr\": [-17.804332733154297, -0.40644073486328125, 10, 20, 25, 28, 34.38934326171875], \n    \"pesq\": [1, 1.7, 2.4, 3.1, 3.6, 4, 4.499948978424072]\n}"
  },
  {
    "path": "examples/tags_to_annotations/v02_text_bins.json",
    "content": "{\n    \"speaker_rate_bins\":\n        [\"very slowly\", \"slowly\", \"slightly slowly\", \"moderate speed\", \"slightly fast\", \"fast\", \"very fast\"],\n    \"snr_bins\":\n        [\"very noisy\", \"noisy\", \"slightly noisy\", \"balanced in clarity\", \"slightly clean\", \"clean\", \"very clean\"],\n    \"reverberation_bins\":\n        [\"very distant-sounding\", \"distant-sounding\", \"slightly distant-sounding\", \"slightly close-sounding\", \"very close-sounding\"],\n    \"utterance_level_std\":\n        [\"very monotone\", \"monotone\", \"slightly expressive and animated\", \"expressive and animated\", \"very expressive and animated\"],\n    \"speaker_level_pitch_bins\":\n        [\"very low-pitch\", \"low-pitch\", \"slightly low-pitch\", \"moderate pitch\", \"slightly high-pitch\", \"high-pitch\", \"very high-pitch\"]\n}"
  },
  {
    "path": "main.py",
    "content": "from datasets import load_dataset, Audio\nfrom multiprocess import set_start_method\nfrom dataspeech import rate_apply, pitch_apply, snr_apply, squim_apply\nimport torch\nimport argparse\n\n\nif __name__ == \"__main__\":\n    set_start_method(\"spawn\")\n    parser = argparse.ArgumentParser()\n    \n    \n    parser.add_argument(\"dataset_name\", type=str, help=\"Path or name of the dataset. See: https://huggingface.co/docs/datasets/v2.17.0/en/package_reference/loading_methods#datasets.load_dataset.path\")\n    parser.add_argument(\"--configuration\", default=None, type=str, help=\"Dataset configuration to use, if necessary.\")\n    parser.add_argument(\"--output_dir\", default=None, type=str, help=\"If specified, save the dataset on disk with this path.\")\n    parser.add_argument(\"--repo_id\", default=None, type=str, help=\"If specified, push the dataset to the hub.\")\n    parser.add_argument(\"--audio_column_name\", default=\"audio\", type=str, help=\"Column name of the audio column to be enriched.\")\n    parser.add_argument(\"--text_column_name\", default=\"text\", type=str, help=\"Text column name.\")\n    parser.add_argument(\"--rename_column\", action=\"store_true\", help=\"If activated, rename audio and text column names to 'audio' and 'text'. Useful if you want to merge datasets afterwards.\")\n    parser.add_argument(\"--cpu_num_workers\", default=1, type=int, help=\"Number of CPU workers for transformations that don't use GPUs or if no GPU are available.\")\n    parser.add_argument(\"--cpu_writer_batch_size\", default=1000, type=int, help=\"writer_batch_size for transformations that don't use GPUs. 
See: https://huggingface.co/docs/datasets/v2.17.0/en/package_reference/main_classes#datasets.Dataset.map.writer_batch_size\")\n    parser.add_argument(\"--batch_size\", default=2, type=int, help=\"This parameters specify how many samples are passed by workers for operations that are using GPUs.\")\n    parser.add_argument(\"--penn_batch_size\", default=4096, type=int, help=\"Pitch estimation chunks audio into smaller pieces and processes them in batch. This specify the batch size. If you are using a gpu, pick a batch size that doesn't cause memory errors.\")\n    parser.add_argument(\"--num_workers_per_gpu_for_pitch\", default=1, type=int, help=\"Number of workers per GPU for the pitch estimation if GPUs are available. Defaults to 1 if some are avaiable. Useful if you want multiple processes per GPUs to maximise GPU usage.\")\n    parser.add_argument(\"--num_workers_per_gpu_for_snr\", default=1, type=int, help=\"Number of workers per GPU for the SNR and reverberation estimation if GPUs are available. Defaults to 1 if some are avaiable. Useful if you want multiple processes per GPUs to maximise GPU usage.\")\n    parser.add_argument(\"--apply_squim_quality_estimation\", action=\"store_true\", help=\"If set, will also use torchaudio-squim estimation (SI-SNR, STOI and PESQ).\")\n    parser.add_argument(\"--num_workers_per_gpu_for_squim\", default=1, type=int, help=\"Number of workers per GPU for the SI-SNR, STOI and PESQ estimation if GPUs are available. Defaults to 1 if some are avaiable. 
Useful if you want multiple processes per GPU to maximise GPU usage.\")\n\n\n    args = parser.parse_args()\n    \n    if args.configuration:\n        dataset = load_dataset(args.dataset_name, args.configuration, num_proc=args.cpu_num_workers,)\n    else:\n        dataset = load_dataset(args.dataset_name, num_proc=args.cpu_num_workers,)\n        \n    audio_column_name = \"audio\" if args.rename_column else args.audio_column_name\n    text_column_name = \"text\" if args.rename_column else args.text_column_name\n    if args.rename_column:\n        dataset = dataset.rename_columns({args.audio_column_name: \"audio\", args.text_column_name: \"text\"})\n        \n\n    if args.apply_squim_quality_estimation:\n        print(\"Compute SI-SDR, PESQ, STOI\")\n        squim_dataset = dataset.map(\n            squim_apply,\n            batched=True,\n            batch_size=args.batch_size,\n            with_rank=True if torch.cuda.device_count()>0 else False,\n            num_proc=torch.cuda.device_count()*args.num_workers_per_gpu_for_squim if torch.cuda.device_count()>0 else args.cpu_num_workers,\n            remove_columns=[audio_column_name], # trick to avoid rewriting audio\n            fn_kwargs={\"audio_column_name\": audio_column_name,},\n        )\n\n    print(\"Compute pitch\")\n    pitch_dataset = dataset.cast_column(audio_column_name, Audio(sampling_rate=16_000)).map(\n        pitch_apply,\n        batched=True,\n        batch_size=args.batch_size,\n        with_rank=True if torch.cuda.device_count()>0 else False,\n        num_proc=torch.cuda.device_count()*args.num_workers_per_gpu_for_pitch if torch.cuda.device_count()>0 else args.cpu_num_workers,\n        remove_columns=[audio_column_name], # trick to avoid rewriting audio\n        fn_kwargs={\"audio_column_name\": audio_column_name, \"penn_batch_size\": args.penn_batch_size},\n    )\n\n    print(\"Compute snr and reverb\")\n    snr_dataset = dataset.map(\n        snr_apply,\n        batched=True,\n
        batch_size=args.batch_size,\n        with_rank=True if torch.cuda.device_count()>0 else False,\n        num_proc=torch.cuda.device_count()*args.num_workers_per_gpu_for_snr if torch.cuda.device_count()>0 else args.cpu_num_workers,\n        remove_columns=[audio_column_name], # trick to avoid rewriting audio\n        fn_kwargs={\"audio_column_name\": audio_column_name},\n    )\n    \n    print(\"Compute speaking rate\")\n    if \"speech_duration\" in snr_dataset[next(iter(snr_dataset.keys()))].features:\n        rate_dataset = snr_dataset.map(\n            rate_apply,\n            with_rank=False,\n            num_proc=args.cpu_num_workers,\n            writer_batch_size=args.cpu_writer_batch_size,\n            fn_kwargs={\"audio_column_name\": audio_column_name, \"text_column_name\": text_column_name},\n        )\n    else:\n        rate_dataset = dataset.map(\n            rate_apply,\n            with_rank=False,\n            num_proc=args.cpu_num_workers,\n            writer_batch_size=args.cpu_writer_batch_size,\n            remove_columns=[audio_column_name], # trick to avoid rewriting audio\n            fn_kwargs={\"audio_column_name\": audio_column_name, \"text_column_name\": text_column_name},\n        )\n    \n    for split in dataset.keys():\n        dataset[split] = pitch_dataset[split].add_column(\"snr\", snr_dataset[split][\"snr\"]).add_column(\"c50\", snr_dataset[split][\"c50\"])\n        if \"speech_duration\" in snr_dataset[split].features:\n            dataset[split] = dataset[split].add_column(\"speech_duration\", snr_dataset[split][\"speech_duration\"])\n        dataset[split] = dataset[split].add_column(\"speaking_rate\", rate_dataset[split][\"speaking_rate\"]).add_column(\"phonemes\", rate_dataset[split][\"phonemes\"])\n        if args.apply_squim_quality_estimation:\n            dataset[split] = dataset[split].add_column(\"stoi\", squim_dataset[split][\"stoi\"]).add_column(\"si-sdr\", squim_dataset[split][\"sdr\"]).add_column(\"pesq\", 
squim_dataset[split][\"pesq\"])\n    \n    if args.output_dir:\n        print(\"Saving to disk...\")\n        dataset.save_to_disk(args.output_dir)\n    if args.repo_id:\n        print(\"Pushing to the hub...\")\n        if args.configuration:\n            dataset.push_to_hub(args.repo_id, args.configuration)\n        else:\n            dataset.push_to_hub(args.repo_id)\n    \n"
  },
  {
    "path": "requirements.txt",
    "content": "datasets[audio]\nhttps://github.com/marianne-m/brouhaha-vad/archive/main.zip\npenn\ng2p\ndemucs\ntransformers\naccelerate\nbitsandbytes"
  },
  {
    "path": "scripts/filter_audio_separation.py",
    "content": "from demucs import pretrained\nfrom demucs.apply import apply_model\nfrom demucs.audio import convert_audio\nfrom datasets import load_dataset\nfrom multiprocess import set_start_method\nimport torch\nimport argparse\nfrom datasets import Audio\n\n\n\ndemucs = pretrained.get_model('htdemucs')\nsource = demucs.sources\n\ndef wrap_audio(audio, sr):\n    return {\n        \"array\": audio.cpu().numpy(),\n        \"sampling_rate\": sr\n    }\n\n\n# TODO(YL): make compatible with other naming and stems\ndef filter_stems(batch, rank=None):\n    # default to CPU; overridden below when a GPU rank is provided\n    device = \"cpu\"\n    if rank is not None:\n        # move the model to the right GPU if not there already\n        device = f\"cuda:{(rank or 0)% torch.cuda.device_count()}\"\n        # move to device and create pipeline here because the pipeline moves to the first GPU it finds anyway\n        demucs.to(device)\n\n    if isinstance(batch[\"audio\"], list):  \n        wavs = [convert_audio(\n                    torch.tensor(audio[\"array\"][None], device=device).to(torch.float32), audio[\"sampling_rate\"], demucs.samplerate, demucs.audio_channels).T for audio in batch[\"audio\"]]\n        wavs_length = [audio.shape[0] for audio in wavs]\n        \n        wavs = torch.nn.utils.rnn.pad_sequence(wavs, batch_first=True, padding_value=0.0).transpose(1,2)\n        stems = apply_model(demucs, wavs)\n        \n        batch[\"vocals\"] = [wrap_audio(s[-1,:,:length].mean(0), demucs.samplerate) for (s,length) in zip(stems, wavs_length)]\n        batch[\"others\"] = [wrap_audio(s[:-1, :,:length].sum(0).mean(0), demucs.samplerate) for (s,length) in zip(stems, wavs_length)]\n        \n    else:\n        audio = torch.tensor(batch[\"audio\"][\"array\"].squeeze(), device=device).to(torch.float32)\n        sample_rate = batch[\"audio\"][\"sampling_rate\"]\n        audio = convert_audio(\n                audio, sample_rate, demucs.samplerate, demucs.audio_channels)\n        stems = apply_model(demucs, audio[None])\n        \n        batch[\"vocals\"] = 
wrap_audio(stems[0,-1].mean(0), demucs.samplerate)\n        batch[\"others\"] = wrap_audio(stems[0, :-1].sum(0).mean(0), demucs.samplerate)\n\n    return batch\n    \nif __name__ == \"__main__\":\n    set_start_method(\"spawn\")\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"dataset_name\", type=str, help=\"Path or name of the dataset. See: https://huggingface.co/docs/datasets/v2.17.0/en/package_reference/loading_methods#datasets.load_dataset.path\")\n    parser.add_argument(\"--configuration\", default=None, type=str, help=\"Dataset configuration to use, if necessary.\")\n    parser.add_argument(\"--output_dir\", default=None, type=str, help=\"If specified, save the dataset on disk with this path.\")\n    parser.add_argument(\"--repo_id\", default=None, type=str, help=\"If specified, push the dataset to the hub.\")\n    parser.add_argument(\"--audio_column_name\", default=\"audio\", type=str, help=\"Column name of the audio column to be separated.\")\n    parser.add_argument(\"--batch_size\", default=8, type=int, help=\"Batch size. Speeds up operations on GPU.\")\n    parser.add_argument(\"--num_workers_per_gpu\", default=1, type=int, help=\"Number of workers per GPU for transformations that use GPUs, if GPUs are available. Defaults to 1 if some are available. 
Useful if you want multiple processes per GPU to maximise GPU usage.\")\n    args = parser.parse_args()\n    \n    if args.configuration:\n        dataset = load_dataset(args.dataset_name, args.configuration)\n    else:\n        dataset = load_dataset(args.dataset_name)    \n\n\n    num_proc = torch.cuda.device_count()*args.num_workers_per_gpu if torch.cuda.device_count() >= 1 else None\n\n    updated_dataset = dataset.map(\n        filter_stems,\n        batched=True,\n        batch_size=args.batch_size,\n        with_rank=True,\n        num_proc=num_proc,\n    )\n    \n    updated_dataset = updated_dataset.cast_column(\"vocals\", Audio())\n    updated_dataset = updated_dataset.cast_column(\"others\", Audio())\n    \n    if args.output_dir:\n        print(\"Saving to disk...\")\n        updated_dataset.save_to_disk(args.output_dir)\n    if args.repo_id:\n        print(\"Pushing to the hub...\")\n        if args.configuration:\n            updated_dataset.push_to_hub(args.repo_id, args.configuration)\n        else:\n            updated_dataset.push_to_hub(args.repo_id)\n    \n\n"
  },
  {
    "path": "scripts/merge_audio_to_metadata.py",
    "content": "import numpy as np\nimport pandas as pd\nfrom datasets import load_dataset, concatenate_datasets\nfrom multiprocess import set_start_method\nimport argparse\n\n\n\nif __name__ == \"__main__\":\n    set_start_method(\"spawn\")\n    parser = argparse.ArgumentParser()\n    \n    \n    parser.add_argument(\"dataset_name\", type=str, help=\"Repo id of the main dataset.\")\n    parser.add_argument(\"metadata_dataset_name\", type=str, help=\"Repo id of the metadata dataset to merge in.\")\n    parser.add_argument(\"--configuration\", default=None, type=str, help=\"Dataset configuration to use.\")\n    parser.add_argument(\"--output_dir\", default=None, type=str, help=\"If specified, save the dataset on disk.\")\n    parser.add_argument(\"--repo_id\", default=None, type=str, help=\"If specified, push the dataset to the hub.\")\n    parser.add_argument(\"--cpu_num_workers\", default=1, type=int, help=\"Number of CPU workers.\")\n    parser.add_argument(\"--strategy\", default=\"concatenate\", type=str, help=\"For now only concatenate.\")\n    parser.add_argument(\"--id_column_name\", default=\"id\", type=str, help=\"Id column name, used to check that both datasets are aligned.\") # TODO\n    parser.add_argument(\"--columns_to_drop\", default=None, type=str, help=\"Column names to drop from the metadata dataset, e.g. if some columns are duplicated. Separated by '+'. 
\")\n    \n\n    args = parser.parse_args()\n    \n    if args.configuration:\n        dataset = load_dataset(args.dataset_name, args.configuration)\n    else:\n        dataset = load_dataset(args.dataset_name)\n        \n    if args.configuration:\n        metadata_dataset = load_dataset(args.metadata_dataset_name, args.configuration)\n    else:\n        metadata_dataset = load_dataset(args.metadata_dataset_name)\n\n    columns_to_drop = None\n    if args.columns_to_drop is not None:\n        columns_to_drop = args.columns_to_drop.split(\"+\")\n        metadata_dataset = metadata_dataset.remove_columns(columns_to_drop)\n    \n    # TODO: for now, suppose that both datasets kept the same sample ordering\n    for split in dataset:\n        if split in metadata_dataset:\n            dataset[split] = concatenate_datasets([dataset[split], metadata_dataset[split].rename_column(args.id_column_name, f\"metadata_{args.id_column_name}\")], axis=1)\n        else:\n            raise ValueError(f\"The metadata dataset is missing split {split}, which is present in the main dataset\")\n        \n        if len(dataset[split].filter(lambda id1, id2: id1!=id2, input_columns=[args.id_column_name, f\"metadata_{args.id_column_name}\"])) != 0:\n            raise ValueError(f\"Concatenation failed: some ids don't match on split {split}\")\n    \n\n    if args.output_dir:\n        dataset.save_to_disk(args.output_dir)\n    if args.repo_id:\n        if args.configuration:\n            dataset.push_to_hub(args.repo_id, args.configuration)\n        else:\n            dataset.push_to_hub(args.repo_id)\n    "
  },
  {
    "path": "scripts/metadata_to_text.py",
    "content": "import numpy as np\nimport pandas as pd\nfrom datasets import load_dataset, DatasetDict\nfrom multiprocess import set_start_method\nimport argparse\nfrom pathlib import Path\nimport os\nimport matplotlib.pyplot as plt\nimport json\n\nSPEAKER_RATE_BINS = [\"very slowly\", \"quite slowly\", \"slightly slowly\", \"moderate speed\", \"slightly fast\", \"quite fast\", \"very fast\"]\nSNR_BINS = [\"very noisy\", \"quite noisy\", \"slightly noisy\", \"moderate ambient sound\", \"slightly clear\", \"quite clear\", \"very clear\"]\nREVERBERATION_BINS = [\"very roomy sounding\", \"quite roomy sounding\", \"slightly roomy sounding\", \"moderate reverberation\", \"slightly confined sounding\", \"quite confined sounding\", \"very confined sounding\"]\nUTTERANCE_LEVEL_STD = [\"very monotone\", \"quite monotone\", \"slightly monotone\", \"moderate intonation\", \"slightly expressive\", \"quite expressive\", \"very expressive\"]\nSI_SDR_BINS = [\"extremely noisy\", \"very noisy\", \"noisy\", \"slightly noisy\", \"almost no noise\", \"very clear\"]\nPESQ_BINS = [\"very bad speech quality\", \"bad speech quality\", \"slightly bad speech quality\", \"moderate speech quality\", \"great speech quality\", \"wonderful speech quality\"]\n\n# this one is supposed to be applied to speaker-level mean pitch, relative to gender\nSPEAKER_LEVEL_PITCH_BINS = [\"very low pitch\", \"quite low pitch\", \"slightly low pitch\", \"moderate pitch\", \"slightly high pitch\", \"quite high pitch\", \"very high pitch\"]\n\n\ndef visualize_bins_to_text(values_1, values_2, name_1, name_2, text_bins, save_dir, output_column_name, default_bins=100, lower_range=None):\n    # Save both histograms into a single figure\n    fig, axs = plt.subplots(2, figsize=(8,6), sharex=True)\n    \n    # Plot histogram and vertical lines for subplot 1\n    axs[0].hist(values_1, bins=default_bins, color='blue', alpha=0.7)\n    _, bin_edges1 = np.histogram(values_1, bins=len(text_bins), range=(lower_range, 
values_1.max()) if lower_range else None)\n    for edge in bin_edges1:\n        axs[0].axvline(x=edge, color='red', linestyle='--', linewidth=1)\n\n\n    # Plot histogram and vertical lines for subplot 2\n    axs[1].hist(values_2, bins=default_bins, color='green', alpha=0.7)\n    _, bin_edges2 = np.histogram(values_2, bins=len(text_bins), range=(lower_range, values_2.max()) if lower_range else None)\n    for edge in bin_edges2:\n        axs[1].axvline(x=edge, color='red', linestyle='--', linewidth=1)\n\n    # Add labels and title\n    axs[0].set_title(name_1)\n    axs[1].set_title(name_2)\n    axs[0].set_yscale('log')\n    axs[1].set_yscale('log')\n    axs[0].set_ylabel('Frequency')\n    axs[1].set_ylabel('Frequency')\n    axs[1].set_xlabel(f'{output_column_name}')\n\n    # Adjust layout\n    plt.tight_layout()\n\n    filename = f\"{output_column_name}.png\"\n    filepath = os.path.join(save_dir, filename)\n    plt.savefig(filepath)\n    print(f\"Plots saved at {filepath}!\")\n\ndef bins_to_text(dataset, text_bins, column_name, output_column_name, leading_split_for_bins=\"train\", batch_size = 4, num_workers = 1, std_tolerance=5, save_dir=None, only_save_plot=False, lower_range=None, bin_edges=None):\n    '''\n    Compute bins of `column_name` from the splits `leading_split_for_bins` and apply text bins to every split.\n    `leading_split_for_bins` can be a string or a list.\n    '''\n    if bin_edges is None:\n        values = []\n        for df in dataset:\n            for split in df:\n                if leading_split_for_bins is None or leading_split_for_bins in split:\n                    values.extend(df[split][column_name])\n        \n        # filter out outliers\n        values = np.array(values)\n        values = values[~np.isnan(values)]\n        filtered_values = values\n        if std_tolerance is not None:\n            filtered_values = values[np.abs(values - np.mean(values)) < std_tolerance * np.std(values)]\n\n        if save_dir is not None:\n    
        visualize_bins_to_text(values, filtered_values, \"Before filtering\", \"After filtering\", text_bins, save_dir, output_column_name, lower_range=lower_range)\n            \n        # speaking_rate can easily have outliers\n        if save_dir is not None and output_column_name==\"speaking_rate\":\n            visualize_bins_to_text(filtered_values, filtered_values, \"After filtering\", \"After filtering\", text_bins, save_dir, f\"{output_column_name}_after_filtering\", lower_range=lower_range)\n        \n        values = filtered_values\n        hist, bin_edges = np.histogram(values, bins = len(text_bins), range=(lower_range, values.max()) if lower_range else None)\n        \n        if only_save_plot:\n            return dataset, bin_edges\n    else:\n        print(f\"Already computed bin edges have been passed for {output_column_name}. Will use: {bin_edges}.\")\n\n    def batch_association(batch):\n        index_bins = np.searchsorted(bin_edges, batch, side=\"left\")\n        # do min(max(...)) when values are outside of the main bins\n        # it happens when value = min or max or have been filtered out from bins computation\n        batch_bins = [text_bins[min(max(i-1, 0), len(text_bins)-1)] for i in index_bins]\n        return {\n            output_column_name: batch_bins\n        }\n    \n    dataset = [df.map(batch_association, batched=True, batch_size=batch_size, input_columns=[column_name], num_proc=num_workers) for df in dataset]\n    return dataset, bin_edges\n\ndef speaker_level_relative_to_gender(dataset, text_bins, speaker_column_name, gender_column_name, column_name, output_column_name, batch_size = 4, num_workers=1, std_tolerance=None, save_dir=None, only_save_plot=False, bin_edges=None):\n    '''\n    Computes mean values on a speaker level and computes bins on top relative to the gender column name.\n    Then associate a text bin to the column.\n    This time, doesn't use leading_split_for_bins, computes it for all. 
Could probably be optimized\n    '''\n    list_data = []\n    for df in dataset:\n        for split in df:\n            panda_data = df[split].remove_columns([col for col in df[split].column_names if col not in {speaker_column_name, column_name, gender_column_name}]).to_pandas()\n            list_data.append(panda_data)\n        \n    dataframe = pd.concat(list_data, ignore_index=True)\n    dataframe = dataframe.groupby(speaker_column_name).agg({column_name: \"mean\", gender_column_name: \"first\"})\n    if bin_edges is None:\n        bin_edges = {}\n        if save_dir is not None:\n            save_dict = {}\n            save_dict_after_filtering = {}\n        for category in [\"male\", \"female\"]:\n            values = dataframe[dataframe[gender_column_name] == category][column_name]\n            values = np.array(values)\n            if save_dir is not None:\n                save_dict[category] = values\n            if std_tolerance is not None:\n                # filter out outliers\n                values = values[np.abs(values - np.mean(values)) < std_tolerance * np.std(values)]\n                if save_dir is not None:\n                    save_dict_after_filtering[category] = values\n            bin_edges[category] = np.histogram(values, len(text_bins))[1]\n        \n        if save_dir is not None:\n            visualize_bins_to_text(save_dict[\"male\"], save_dict[\"female\"], \"Male distribution\", \"Female distribution\", text_bins, save_dir, output_column_name)\n            if std_tolerance is not None:\n                visualize_bins_to_text(save_dict_after_filtering[\"male\"], save_dict_after_filtering[\"female\"], \"Male distribution\", \"Female distribution\", text_bins, save_dir, f\"{output_column_name}_after_filtering\")\n\n        if only_save_plot:\n            return dataset, bin_edges\n    else:\n        print(f\"Already computed bin edges have been passed for {output_column_name}. 
Will use: {bin_edges}.\")\n     \n    speaker_id_to_bins = dataframe.apply(lambda x: np.searchsorted(bin_edges[x[gender_column_name]], x[column_name]), axis=1).to_dict()\n        \n    def batch_association(batch):\n        index_bins = [speaker_id_to_bins[speaker] for speaker in batch]\n        # do min(max(...)) when values are outside of the main bins\n        # it happens when value = min or max or have been filtered out from bins computation\n        batch_bins = [text_bins[min(max(i-1, 0), len(text_bins)-1)] for i in index_bins]\n        return {\n            output_column_name: batch_bins\n        }\n        \n    \n    dataset = [df.map(batch_association, batched=True, input_columns=[speaker_column_name], batch_size=batch_size, num_proc=num_workers) for df in dataset]\n    return dataset, bin_edges\n\nif __name__ == \"__main__\":\n    set_start_method(\"spawn\")\n    parser = argparse.ArgumentParser()\n    \n    \n    parser.add_argument(\"dataset_name\", type=str, help=\"Path or name of the dataset(s). If multiple datasets, names have to be separated by `+`.\")\n    parser.add_argument(\"--configuration\", default=None, type=str, help=\"Dataset configuration(s) to use (or configurations separated by +).\")\n    parser.add_argument(\"--output_dir\", default=None, type=str, help=\"If specified, save the dataset(s) on disk. If multiple datasets, paths have to be separated by `+`.\")\n    parser.add_argument(\"--repo_id\", default=None, type=str, help=\"If specified, push the dataset(s) to the hub. If multiple datasets, names have to be separated by `+`.\")\n    parser.add_argument(\"--path_to_text_bins\", default=None, type=str, help=\"If specified, points to a JSON file which contains the text bins that will be associated to each bin. If not specified, will use the default bins.\")\n    parser.add_argument(\"--path_to_bin_edges\", default=None, type=str, help=\"If specified, points to a JSON file which contains the bin edges. 
Useful if you want to apply already computed bins to new datasets. If not specified, will recompute bin edges from scratch.\")\n    parser.add_argument(\"--save_bin_edges\", default=None, type=str, help=\"If specified, the name of the JSON file in which the computed bin edges will be saved. Useful if you want to reuse those bin edges on new datasets. By default, it won't save those edges.\")\n    parser.add_argument(\"--avoid_pitch_computation\", default=False, action=\"store_true\", help=\"If `True`, will not compute `pitch`. Note that `pitch` is computed on a speaker-level, relative to gender, so you don't need it in a mono-speaker setting.\")\n    parser.add_argument(\"--cpu_num_workers\", default=1, type=int, help=\"Number of CPU workers.\")\n    parser.add_argument(\"--batch_size\", default=16, type=int, help=\"Batch size in `Dataset.map` operations. https://huggingface.co/docs/datasets/v2.17.0/en/package_reference/main_classes#datasets.Dataset.map\")\n    parser.add_argument(\"--speaker_id_column_name\", default=\"speaker_id\", type=str, help=\"Speaker id column name. Only used if `avoid_pitch_computation=False`\")\n    parser.add_argument(\"--gender_column_name\", default=\"gender\", type=str, help=\"Gender column name. Only used if `avoid_pitch_computation=False`\")\n    parser.add_argument(\"--pitch_std_tolerance\", default=2., type=float, help=\"Standard deviation tolerance for pitch estimation. Any value that is outside mean ± std * tolerance is discarded. Only used if `avoid_pitch_computation=False`.\")\n    parser.add_argument(\"--speaking_rate_std_tolerance\", default=4., type=float, help=\"Standard deviation tolerance for speaking rate estimation. Any value that is outside mean ± std * tolerance is discarded. Only used if `path_to_bin_edges=False`.\")\n    parser.add_argument(\"--snr_std_tolerance\", default=3.5, type=float, help=\"Standard deviation tolerance for SNR estimation. 
Any value that is outside mean ± std * tolerance is discarded. Only used if `path_to_bin_edges=False`.\")\n    parser.add_argument(\"--reverberation_std_tolerance\", default=4, type=float, help=\"Standard deviation tolerance for reverberation estimation. Any value that is outside mean ± std * tolerance is discarded. Only used if `path_to_bin_edges=False`.\")\n    parser.add_argument(\"--speech_monotony_std_tolerance\", default=4, type=float, help=\"Standard deviation tolerance for speech monotony estimation. Any value that is outside mean ± std * tolerance is discarded. Only used if `path_to_bin_edges=False`.\")\n    parser.add_argument(\"--leading_split_for_bins\", default=None, type=str, help=\"If specified, will use every split that contains this string to compute statistics. If not specified, will use every split. Only used if `path_to_bin_edges=False`.\")\n    parser.add_argument(\"--plot_directory\", default=None, type=str, help=\"If specified, will save visualizing plots to this directory. Only used if `path_to_bin_edges=False`.\")\n    parser.add_argument(\"--only_save_plot\", default=False, action=\"store_true\", help=\"If `True` and `--plot_directory` is specified, will only compute the plots. Only used if `path_to_bin_edges=False`.\")\n    parser.add_argument(\"--snr_lower_range\", default=None, type=float, help=\"The lower range of the SNR bins.\")\n    parser.add_argument(\"--speaking_rate_lower_range\", default=None, type=float, help=\"The lower range of the speaking rate bins.\")\n    parser.add_argument(\"--apply_squim_quality_estimation\", action=\"store_true\", help=\"If set, will also compute bins for torchaudio-squim estimation (SI-SNR, PESQ).\")\n    parser.add_argument(\"--pesq_std_tolerance\", default=None, type=float, help=\"Used if `apply_squim_quality_estimation=True`. Standard deviation tolerance for PESQ estimation. Any value that is outside mean ± std * tolerance is discarded. 
Only used if `path_to_bin_edges=False`.\")\n    parser.add_argument(\"--sdr_std_tolerance\", default=None, type=float, help=\"Used if `apply_squim_quality_estimation=True`. Standard deviation tolerance for SI-SDR estimation. Any value that is outside mean ± std * tolerance is discarded. Only used if `path_to_bin_edges=False`.\")\n\n    args = parser.parse_args()\n    \n    if args.plot_directory is None and args.only_save_plot:\n        raise ValueError(\"`only_save_plot=true` but `plot_directory` is not specified. Please give a path to the directory where you want the plot to be saved.\")\n    if args.only_save_plot and args.path_to_bin_edges:\n        raise ValueError(\"`only_save_plot=true` but `path_to_bin_edges` is specified. Since the latter is specified, we won't redo computations that would have been used for plotting. Choose one or the other. Note that if you use this script to label a new dataset for fine-tuning, I'd recommend avoiding plotting and setting `only_save_plot=false`\")\n        \n    text_bins_dict = {}\n    if args.path_to_text_bins:\n        with open(args.path_to_text_bins) as json_file:\n            text_bins_dict = json.load(json_file)\n            \n    bin_edges_dict = {}\n    if args.path_to_bin_edges:\n        with open(args.path_to_bin_edges) as json_file:\n            bin_edges_dict = json.load(json_file)\n\n    speaker_level_pitch_bins = text_bins_dict.get(\"speaker_level_pitch_bins\", SPEAKER_LEVEL_PITCH_BINS)\n    speaker_rate_bins = text_bins_dict.get(\"speaker_rate_bins\", SPEAKER_RATE_BINS)\n    snr_bins = text_bins_dict.get(\"snr_bins\", SNR_BINS)\n    reverberation_bins = text_bins_dict.get(\"reverberation_bins\", REVERBERATION_BINS)\n    utterance_level_std = text_bins_dict.get(\"utterance_level_std\", UTTERANCE_LEVEL_STD)\n    \n    if args.apply_squim_quality_estimation:\n        sdr_bins = text_bins_dict.get(\"sdr_bins\", SI_SDR_BINS)\n        pesq_std = text_bins_dict.get(\"pesq_bins\", PESQ_BINS)\n\n    output_dirs = 
[args.output_dir] if args.output_dir is not None else None\n    repo_ids = [args.repo_id] if args.repo_id is not None else None\n    if args.configuration:\n        if \"+\" in args.dataset_name:\n            dataset_names = args.dataset_name.split(\"+\")\n            dataset_configs = args.configuration.split(\"+\")\n            if len(dataset_names) != len(dataset_configs):\n                raise ValueError(f\"There are {len(dataset_names)} datasets spotted but {len(dataset_configs)} configurations spotted\")\n            \n            if args.repo_id is not None:\n                repo_ids = args.repo_id.split(\"+\")\n                if len(dataset_names) != len(repo_ids):\n                    raise ValueError(f\"There are {len(dataset_names)} datasets spotted but {len(repo_ids)} repository ids spotted\")\n\n            if args.output_dir is not None:\n                output_dirs = args.output_dir.split(\"+\")\n                if len(dataset_names) != len(output_dirs):\n                    raise ValueError(f\"There are {len(dataset_names)} datasets spotted but {len(output_dirs)} local paths on which to save the datasets spotted\")\n            \n            dataset = []\n            for dataset_name, dataset_config in zip(dataset_names, dataset_configs):\n                tmp_dataset = load_dataset(dataset_name, dataset_config, num_proc=args.cpu_num_workers)\n                dataset.append(tmp_dataset)\n        else:\n            dataset = [load_dataset(args.dataset_name, args.configuration, num_proc=args.cpu_num_workers)]\n            dataset_configs = [args.configuration]\n    else:\n        if \"+\" in args.dataset_name:\n            dataset_names = args.dataset_name.split(\"+\")\n            if args.repo_id is not None:\n                repo_ids = args.repo_id.split(\"+\")\n                if len(dataset_names) != len(repo_ids):\n                    raise ValueError(f\"There are {len(dataset_names)} datasets spotted but {len(repo_ids)} repository ids 
spotted\")\n\n            if args.output_dir is not None:\n                output_dirs = args.output_dir.split(\"+\")\n                if len(dataset_names) != len(output_dirs):\n                    raise ValueError(f\"There are {len(dataset_names)} datasets spotted but {len(output_dirs)} local paths on which to save the datasets spotted\")\n            \n            dataset = []\n            for dataset_name in dataset_names:\n                tmp_dataset = load_dataset(dataset_name, num_proc=args.cpu_num_workers)\n                dataset.append(tmp_dataset)\n\n        else:\n            dataset = [load_dataset(args.dataset_name, num_proc=args.cpu_num_workers)]\n\n    if args.plot_directory:\n        Path(args.plot_directory).mkdir(parents=True, exist_ok=True)\n    \n    if not args.avoid_pitch_computation:\n        bin_edges = None\n        if \"pitch_bins_male\" in bin_edges_dict and \"pitch_bins_female\" in bin_edges_dict:\n            bin_edges = {\"male\": bin_edges_dict[\"pitch_bins_male\"], \"female\": bin_edges_dict[\"pitch_bins_female\"]}\n\n        dataset, pitch_bin_edges = speaker_level_relative_to_gender(dataset, speaker_level_pitch_bins, args.speaker_id_column_name, args.gender_column_name, \"utterance_pitch_mean\", \"pitch\", batch_size=args.batch_size, num_workers=args.cpu_num_workers, std_tolerance=args.pitch_std_tolerance, save_dir=args.plot_directory, only_save_plot=args.only_save_plot, bin_edges=bin_edges)\n\n    dataset, speaking_rate_bin_edges = bins_to_text(dataset, speaker_rate_bins, \"speaking_rate\", \"speaking_rate\", batch_size=args.batch_size, num_workers=args.cpu_num_workers, leading_split_for_bins=args.leading_split_for_bins, std_tolerance=args.speaking_rate_std_tolerance, save_dir=args.plot_directory, only_save_plot=args.only_save_plot, bin_edges=bin_edges_dict.get(\"speaking_rate\",None), lower_range=args.speaking_rate_lower_range)\n    dataset, noise_bin_edges = bins_to_text(dataset, snr_bins, \"snr\", 
\"noise\", batch_size=args.batch_size, num_workers=args.cpu_num_workers, leading_split_for_bins=args.leading_split_for_bins, std_tolerance=args.snr_std_tolerance, save_dir=args.plot_directory, only_save_plot=args.only_save_plot, bin_edges=bin_edges_dict.get(\"noise\",None), lower_range=args.snr_lower_range)\n    dataset, reverberation_bin_edges = bins_to_text(dataset, reverberation_bins, \"c50\", \"reverberation\", batch_size=args.batch_size, num_workers=args.cpu_num_workers, leading_split_for_bins=args.leading_split_for_bins, std_tolerance=args.reverberation_std_tolerance, save_dir=args.plot_directory, only_save_plot=args.only_save_plot, bin_edges=bin_edges_dict.get(\"reverberation\",None))\n    dataset, speech_monotony_bin_edges = bins_to_text(dataset, utterance_level_std, \"utterance_pitch_std\", \"speech_monotony\", batch_size=args.batch_size, num_workers=args.cpu_num_workers, leading_split_for_bins=args.leading_split_for_bins, std_tolerance=args.speech_monotony_std_tolerance, save_dir=args.plot_directory, only_save_plot=args.only_save_plot, bin_edges=bin_edges_dict.get(\"speech_monotony\",None))\n\n    if args.apply_squim_quality_estimation:\n        dataset, sdr_bin_edges = bins_to_text(dataset, sdr_bins, \"si-sdr\", \"sdr_noise\", batch_size=args.batch_size, num_workers=args.cpu_num_workers, leading_split_for_bins=args.leading_split_for_bins, std_tolerance=args.sdr_std_tolerance, save_dir=args.plot_directory, only_save_plot=args.only_save_plot, bin_edges=bin_edges_dict.get(\"si-sdr\",None))\n        dataset, pesq_bin_edges = bins_to_text(dataset, pesq_std, \"pesq\", \"pesq_speech_quality\", batch_size=args.batch_size, num_workers=args.cpu_num_workers, leading_split_for_bins=args.leading_split_for_bins, std_tolerance=args.pesq_std_tolerance, save_dir=args.plot_directory, only_save_plot=args.only_save_plot, bin_edges=bin_edges_dict.get(\"pesq\",None))\n\n    if args.save_bin_edges:\n        bin_edges = {\n            \"speaking_rate\": 
speaking_rate_bin_edges.tolist(),\n            \"noise\": noise_bin_edges.tolist(),\n            \"reverberation\": reverberation_bin_edges.tolist(),\n            \"speech_monotony\": speech_monotony_bin_edges.tolist(),\n        }\n        if not args.avoid_pitch_computation:\n            bin_edges[\"pitch_bins_male\"] = pitch_bin_edges[\"male\"].tolist()\n            bin_edges[\"pitch_bins_female\"] = pitch_bin_edges[\"female\"].tolist()\n        if args.apply_squim_quality_estimation:\n            bin_edges[\"si-sdr\"] = sdr_bin_edges.tolist()\n            bin_edges[\"pesq\"] = pesq_bin_edges.tolist()\n        \n        with open(args.save_bin_edges, \"w\") as outfile: \n            json.dump(bin_edges, outfile)\n        \n    if not args.only_save_plot:\n        if args.output_dir:\n            for output_dir, df in zip(output_dirs, dataset):\n                df.save_to_disk(output_dir)\n        if args.repo_id:\n            for i, (repo_id, df) in enumerate(zip(repo_ids, dataset)):\n                if args.configuration:\n                    df.push_to_hub(repo_id, dataset_configs[i])\n                else:\n                    df.push_to_hub(repo_id)\n"
  },
  {
    "path": "scripts/per_dataset_script/add_gender_to_MLS.py",
    "content": "from datasets import load_dataset\nfrom multiprocess import set_start_method\nimport pandas as pd\nimport argparse\n\n\nif __name__ == \"__main__\":\n    set_start_method(\"spawn\")\n    parser = argparse.ArgumentParser()\n    \n    \n    parser.add_argument(\"dataset_name\", type=str, help=\"Repo id or local path.\")\n    parser.add_argument(\"tsv_path\", default=None, type=str, help=\"Path to the pipe-separated MLS speaker metadata file.\")\n    parser.add_argument(\"--configuration\", default=None, type=str, help=\"Dataset configuration to use.\")\n    parser.add_argument(\"--output_dir\", default=None, type=str, help=\"If specified, save the dataset on disk.\")\n    parser.add_argument(\"--repo_id\", default=None, type=str, help=\"If specified, push the dataset to the hub.\")\n    parser.add_argument(\"--speaker_id_column_name\", default=\"speaker_id\", type=str, help=\"Speaker id column name.\")\n    parser.add_argument(\"--cpu_num_workers\", default=1, type=int, help=\"Number of CPU workers for transformations that don't use GPUs or if no GPUs are available.\")\n\n    args = parser.parse_args()\n    \n    if args.configuration:\n        dataset = load_dataset(args.dataset_name, args.configuration)\n    else:\n        dataset = load_dataset(args.dataset_name)\n        \n    speaker_id_column_name = args.speaker_id_column_name\n\n    speaker_dataset = pd.read_csv(args.tsv_path, sep=\"|\", on_bad_lines='skip')\n    # the MLS metadata file uses padded column names, hence the extra spaces\n    speaker_column = ' SPEAKER   ' \n    gender_column = '   GENDER   '\n    speaker_dataset = speaker_dataset.set_index(speaker_column)[gender_column]\n    speaker_dataset = speaker_dataset.to_dict()\n    \n    def map_gender(speaker_ids):\n        genders = [speaker_dataset[int(speaker)].strip() for speaker in speaker_ids]\n        return {\"gender\": [\"male\" if g==\"M\" else \"female\" for g in genders]}\n    \n    dataset = dataset.map(map_gender, batched=True, batch_size=128, input_columns=speaker_id_column_name, num_proc=args.cpu_num_workers)\n\n    \n    if 
args.output_dir:\n        dataset.save_to_disk(args.output_dir)\n    if args.repo_id:\n        if args.configuration:\n            dataset.push_to_hub(args.repo_id, args.configuration)\n        else:\n            dataset.push_to_hub(args.repo_id)\n    \n"
  },
  {
    "path": "scripts/per_dataset_script/add_gender_to_libritts_r.py",
    "content": "from datasets import load_dataset\nfrom multiprocess import set_start_method\nimport pandas as pd\nimport argparse\n\n\nif __name__ == \"__main__\":\n    set_start_method(\"spawn\")\n    parser = argparse.ArgumentParser()\n    \n    \n    parser.add_argument(\"dataset_name\", type=str, help=\"Repo id or local path.\")\n    parser.add_argument(\"tsv_path\", default=None, type=str, help=\"Text column name.\")\n    parser.add_argument(\"--configuration\", default=None, type=str, help=\"Dataset configuration to use.\")\n    parser.add_argument(\"--output_dir\", default=None, type=str, help=\"If specified, save the dasaset on disk.\")\n    parser.add_argument(\"--repo_id\", default=None, type=str, help=\"If specified, push the model to the hub.\")\n    parser.add_argument(\"--speaker_id_column_name\", default=\"speaker_id\", type=str, help=\"Audio column name.\")\n    parser.add_argument(\"--cpu_num_workers\", default=1, type=int, help=\"Number of CPU workers for transformations that don't use GPUs or if no GPU are available.\")\n\n    args = parser.parse_args()\n    \n    if args.configuration:\n        dataset = load_dataset(args.dataset_name, args.configuration)\n    else:\n        dataset = load_dataset(args.dataset_name)\n        \n    speaker_id_column_name = args.speaker_id_column_name\n\n    speaker_dataset = pd.read_csv(args.tsv_path, sep=\"\\t\").to_dict()\n    \n    def map_gender(speaker_ids):\n        genders = [speaker_dataset[\"READER\"][int(speaker)] for speaker in speaker_ids]\n        return {\"gender\": [\"male\" if g==\"M\" else \"female\" for g in genders]}\n    \n    dataset = dataset.map(map_gender, batched=True, batch_size=128, input_columns=speaker_id_column_name, num_proc=args.cpu_num_workers)\n\n    \n    if args.output_dir:\n        dataset.save_to_disk(args.output_dir)\n    if args.repo_id:\n        if args.configuration:\n            dataset.push_to_hub(args.repo_id, args.configuration)\n        else:\n            
dataset.push_to_hub(args.repo_id)\n    \n"
  },
  {
    "path": "scripts/per_dataset_script/clean_libritts_r.py",
    "content": "from datasets import load_dataset\nfrom multiprocess import set_start_method\nimport pandas as pd\nimport argparse\nfrom os import listdir\nimport os\n\n\nif __name__ == \"__main__\":\n    set_start_method(\"spawn\")\n    parser = argparse.ArgumentParser()\n    \n    \n    parser.add_argument(\"dataset_name\", type=str, help=\"Repo id or local path.\")\n    parser.add_argument(\"bad_samples_folder\", default=None, type=str, help=\"Path to LibriTTS-R bad folder samples.\")\n    parser.add_argument(\"--configuration\", default=None, type=str, help=\"Dataset configuration to use.\")\n    parser.add_argument(\"--output_dir\", default=None, type=str, help=\"If specified, save the dasaset on disk.\")\n    parser.add_argument(\"--repo_id\", default=None, type=str, help=\"If specified, push the model to the hub.\")\n    parser.add_argument(\"--speaker_id_column_name\", default=\"speaker_id\", type=str, help=\"Speaker id column name.\")\n    parser.add_argument(\"--cpu_num_workers\", default=1, type=int, help=\"Number of CPU workers for transformations that don't use GPUs or if no GPU are available.\")\n\n    args = parser.parse_args()\n    \n    if args.configuration:\n        dataset = load_dataset(args.dataset_name, args.configuration)\n    else:\n        dataset = load_dataset(args.dataset_name)\n        \n    speaker_id_column_name = args.speaker_id_column_name\n    \n    # speakers to exclude because of mixed gender detection\n    # cf: https://github.com/line/LibriTTS-P/blob/main/data/excluded_spk_list.txt\n    speakers_to_remove = {2074, 4455, 6032, 3546, 2262, 8097, 1734, 3793, 8295}\n    \n    def filter_speakers(speaker, speakers_to_remove):\n        return int(speaker) not in speakers_to_remove \n\n    print(dataset)\n    dataset = dataset.filter(filter_speakers, input_columns=speaker_id_column_name, num_proc=args.cpu_num_workers, fn_kwargs={\"speakers_to_remove\": speakers_to_remove})\n    print(dataset)\n    \n    bad_samples_txt_files = 
[os.path.join(args.bad_samples_folder, f) for f in listdir(args.bad_samples_folder) if \"bad_sample\" in f] \n\n    samples_to_filter = set()\n    for txt_file in bad_samples_txt_files:\n        with open(txt_file, 'r') as file:\n            for line in file:\n                line = line.strip().split(\"/\")[-1].split(\".\")[0]\n\n                samples_to_filter.add(line)\n\n    print(len(samples_to_filter))\n    def filter_samples(id, samples_to_filter):\n        return id not in samples_to_filter \n    dataset = dataset.filter(filter_samples, input_columns=\"id\", num_proc=args.cpu_num_workers, fn_kwargs={\"samples_to_filter\": samples_to_filter})\n\n    print(dataset)\n    if args.output_dir:\n        dataset.save_to_disk(args.output_dir)\n    if args.repo_id:\n        if args.configuration:\n            dataset.push_to_hub(args.repo_id, args.configuration)\n        else:\n            dataset.push_to_hub(args.repo_id)\n    \n"
  },
  {
    "path": "scripts/run_prompt_creation.py",
    "content": "import json\nimport logging\nimport os\nimport re\nimport shutil\nimport sys\nfrom dataclasses import dataclass, field\nfrom pathlib import Path\nfrom typing import Any, Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nfrom accelerate import Accelerator, skip_first_batches\nfrom accelerate.logging import get_logger\nfrom datasets import DatasetDict, load_dataset\nfrom torch.utils.data import DataLoader\nfrom tqdm import tqdm\nfrom transformers import (\n    AutoModelForCausalLM,\n    AutoTokenizer,\n    BitsAndBytesConfig,\n    HfArgumentParser,\n)\nfrom datetime import timedelta\nfrom accelerate import InitProcessGroupKwargs\n\n\nlogger = get_logger(__name__, log_level=\"INFO\")\n\n\n@dataclass\nclass ModelArguments:\n    \"\"\"\n    Arguments pertaining to what data we are going to input our model for training and eval.\n    \"\"\"\n\n    model_name_or_path: str = field(\n        metadata={\"help\": \"The name of the model to use (via the transformers library) for the prompt annotation.\"},\n    )\n    per_device_eval_batch_size: int = field(\n        metadata={\"help\": \"The per-device batch size to use for inference.\"},\n    )\n    model_variant: str = field(\n        default=None,\n        metadata={\"help\": \"If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. \"},\n    )\n    model_revision: str = field(\n        default=\"main\",\n        metadata={\"help\": \"The specific model version to use (can be a branch name, tag name or commit id).\"},\n    )\n    cache_dir: Optional[str] = field(\n        default=None,\n        metadata={\"help\": \"Where to store the pretrained models downloaded from huggingface.co\"},\n    )\n    torch_dtype: Optional[str] = field(\n        default=\"float16\",\n        metadata={\n            \"help\": (\n                \"Floating-point format in which the model weights should be initialized\"\n                \" and the computations run. 
Choose one of `[float32, float16, bfloat16]`.\"\n            )\n        },\n    )\n    attn_implementation: Optional[str] = field(\n        default=\"sdpa\",\n        metadata={\"help\": \"Which attention implementation to use: ['eager', 'sdpa', 'flash_attention_2']\"},\n    )\n    load_in_8bit: Optional[bool] = field(\n        default=False, metadata={\"help\": \"Whether to use 8-bit precision for inference.\"}\n    )\n    load_in_4bit: Optional[bool] = field(\n        default=False, metadata={\"help\": \"Whether to use 4-bit precision for inference.\"}\n    )\n    bnb_4bit_quant_type: Optional[str] = field(\n        default=\"nf4\", metadata={\"help\": \"Specify the quantization type (fp4 or nf4).\"}\n    )\n    use_bnb_nested_quant: Optional[bool] = field(default=False, metadata={\"help\": \"Whether to use nested quantization.\"})\n    trust_remote_code: Optional[bool] = field(\n        default=False,\n        metadata={\n            \"help\": (\n                \"Whether or not to allow for custom models defined on the Hub in their own modeling files. 
This option \"\n                \"should only be set to `True` for repositories you trust and in which you have read the code, as it will \"\n                \"execute code present on the Hub on your local machine.\"\n            )\n        },\n    )\n    use_fast_tokenizer: Optional[bool] = field(\n        default=True, metadata={\"help\": \"Use fast tokenizer for encoding/decoding input ids\"}\n    )\n    token: Optional[bool] = field(\n        default=True,\n        metadata={\n            \"help\": \"Whether or not to use an authentication token when loading/uploading from the Hugging Face Hub\"\n        },\n    )\n    do_sample: Optional[bool] = field(default=True, metadata={\"help\": \"Whether to use sampling mode for generation\"})\n    temperature: Optional[float] = field(default=0.6, metadata={\"help\": \"Temperature for sampling-based generation\"})\n    max_new_tokens: Optional[int] = field(\n        default=256, metadata={\"help\": \"Maximum number of new tokens during generation\"}\n    )\n    torch_compile: Optional[bool] = field(\n        default=False,\n        metadata={\n            \"help\": \"Whether to compile the forward pass (not sampling) in generate. Only compatible with Gemma and LlaMA.\"\n        },\n    )\n\n\n@dataclass\nclass DataArguments:\n    \"\"\"\n    Arguments pertaining to what data we are going to input our model for training and eval.\n    \"\"\"\n\n    output_dir: str = field(\n        metadata={\n            \"help\": \"Where to save the processed dataset to disk. If unspecified, uses a 'pretty' version of the \"\n            \"original dataset name. E.g. 
'facebook/voxpopuli' will be saved under 'voxpopuli'.\"\n        },\n    )\n    dataset_name: str = field(\n        default=None,\n        metadata={\"help\": \"The name of the dataset to use (via the datasets library)\"},\n    )\n    dataset_config_name: Optional[str] = field(\n        default=None,\n        metadata={\"help\": \"The configuration name of the dataset to use (via the datasets library).\"},\n    )\n    dataset_split_name: Optional[str] = field(\n        default=None,\n        metadata={\"help\": \"The split name of the dataset to use (via the datasets library).\"},\n    )\n    dataset_cache_dir: Optional[str] = field(\n        default=None,\n        metadata={\"help\": \"Path to cache directory for saving and loading datasets\"},\n    )\n    max_eval_samples: Optional[int] = field(\n        default=None,\n        metadata={\"help\": \"Maximum number of samples for generation - use for debugging purposes.\"},\n    )\n    overwrite_cache: bool = field(\n        default=False,\n        metadata={\"help\": \"Overwrite the cached training and evaluation sets\"},\n    )\n    preprocessing_num_workers: Optional[int] = field(\n        default=None,\n        metadata={\"help\": \"The number of processes to use for the preprocessing.\"},\n    )\n    dataloader_num_workers: Optional[int] = field(\n        default=0,\n        metadata={\"help\": \"The number of processes to use for the dataloader.\"},\n    )\n    push_to_hub: Optional[bool] = field(\n        default=False,\n        metadata={\"help\": \"Whether or not to push the processed dataset to the Hub.\"},\n    )\n    hub_dataset_id: Optional[str] = field(\n        default=None,\n        metadata={\"help\": \"Repository namespace if pushing to the Hugging Face Hub.\"},\n    )\n    overwrite_output_dir: Optional[bool] = field(\n        default=False,\n        metadata={\"help\": \"Overwrite the content of the output directory each time the script is run.\"},\n    )\n    save_steps: Optional[int] = 
field(\n        default=500,\n        metadata={\"help\": \"Save the generated prompts every save_steps.\"},\n    )\n    save_total_limit: Optional[int] = field(\n        default=1, metadata={\"help\": (\"If a value is passed, will limit the total number of saved checkpoints\")}\n    )\n    speaker_name: Optional[str] = field(\n        default=None,\n        metadata={\"help\": \"If `is_single_speaker`, specifies the speaker name to give to the single speaker of your dataset.\"},\n    )\n    is_single_speaker: Optional[bool] = field(\n        default=False, metadata={\"help\": \"Whether to use a single-speaker prompt, with a single name, specified by `speaker_name`.\"}\n    )\n    is_new_speaker_prompt: Optional[bool] = field(\n        default=False, metadata={\"help\": \"Whether to use the newest speaker prompt, which will be used for the next Parler-TTS.\"}\n    )\n    speaker_id_column: Optional[str] = field(\n        default=None, metadata={\"help\": \"Speaker id column name. Only used if creating a dataset with multiple speaker names (i.e. if `speaker_ids_to_name_json` is specified).\"}\n    )\n    speaker_ids_to_name_json: Optional[str] = field(\n        default=None, metadata={\"help\": \"Path to a JSON file which maps speaker ids to names. 
Only used if `speaker_id_column` is specified.\"}\n    )\n    accent_column: Optional[str] = field(\n        default=None, metadata={\"help\": \"Accent column name, if any.\"}\n    )\n\n\n    def __post_init__(self):\n        if self.push_to_hub and self.hub_dataset_id is None:\n            raise ValueError(\"You must specify the `hub_dataset_id` when setting `--push_to_hub=True`\")\n\n\ndef get_quantization_config(model_args: ModelArguments) -> Union[BitsAndBytesConfig, None]:\n    if model_args.load_in_4bit:\n        compute_dtype = torch.float16\n        if model_args.torch_dtype not in {\"auto\", None}:\n            compute_dtype = getattr(torch, model_args.torch_dtype)\n\n        quantization_config = BitsAndBytesConfig(\n            load_in_4bit=True,\n            bnb_4bit_compute_dtype=compute_dtype,\n            bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,\n            bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,\n        )\n    elif model_args.load_in_8bit:\n        quantization_config = BitsAndBytesConfig(\n            load_in_8bit=True,\n        )\n    else:\n        quantization_config = None\n\n    return quantization_config\n\n\ndef get_current_device() -> int:\n    \"\"\"Get the current device. 
For GPU we return the local process index to enable multi-GPU inference.\"\"\"\n    return Accelerator().local_process_index if torch.cuda.is_available() else \"cpu\"\n\n\ndef get_kbit_device_map() -> Union[Dict[str, int], None]:\n    \"\"\"Useful for running inference with quantized models by setting `device_map=get_kbit_device_map()`\"\"\"\n    return {\"\": get_current_device()} if torch.cuda.is_available() else None\n\n\nCHECKPOINT_PREFIX = \"checkpoint\"\n_RE_CHECKPOINT = re.compile(r\"^checkpoint-(\\d+)\\.json$\")\n\n\ndef save_checkpoint(output_dir, all_generated_ids, step):\n    checkpoint_path = f\"{CHECKPOINT_PREFIX}-{step}.json\"\n    output_path = os.path.join(output_dir, checkpoint_path)\n    all_generated_ids = [ids.tolist() for ids in all_generated_ids]\n    with open(output_path, \"w\") as file:\n        json.dump(all_generated_ids, file)\n\n\ndef load_checkpoint(checkpoint_path):\n    with open(checkpoint_path, \"r\") as file:\n        all_generated_ids = json.load(file)\n    logger.info(f\"JSON file {checkpoint_path} loaded.\")\n    all_generated_ids = [np.array(lst) for lst in all_generated_ids]\n    return all_generated_ids\n\n\ndef sorted_checkpoints(output_dir=None) -> List[str]:\n    \"\"\"Helper function to sort saved checkpoints from oldest to newest.\"\"\"\n    ordering_and_checkpoint_path = []\n\n    glob_checkpoints = [str(x) for x in Path(output_dir).glob(f\"{CHECKPOINT_PREFIX}-*\")]\n\n    for path in glob_checkpoints:\n        regex_match = re.match(f\".*{CHECKPOINT_PREFIX}-([0-9]+)\", path)\n        if regex_match is not None and regex_match.groups() is not None:\n            ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))\n\n    checkpoints_sorted = sorted(ordering_and_checkpoint_path)\n    checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]\n    return checkpoints_sorted\n\n\ndef rotate_checkpoints(save_total_limit=None, output_dir=None) -> None:\n    \"\"\"Helper function to delete 
old checkpoints.\"\"\"\n    if save_total_limit is None or save_total_limit <= 0:\n        return\n    # Check if we should delete older checkpoint(s)\n    checkpoints_sorted = sorted_checkpoints(output_dir=output_dir)\n    if len(checkpoints_sorted) <= save_total_limit:\n        return\n\n    number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)\n    checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]\n    for checkpoint in checkpoints_to_be_deleted:\n        logger.info(f\"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit\")\n        os.remove(checkpoint)\n\n\ndef get_last_checkpoint(folder, return_list=False) -> Tuple[List, int]:\n    if not os.path.exists(folder) or not os.path.isdir(folder):\n        os.makedirs(folder, exist_ok=True)\n        return [], 0\n    content = os.listdir(folder)\n    checkpoints = [path for path in content if _RE_CHECKPOINT.search(path) is not None]\n    if len(checkpoints) == 0:\n        return [], 0\n    last_checkpoint = os.path.join(folder, max(checkpoints, key=lambda x: int(_RE_CHECKPOINT.search(x).groups()[0])))\n    # Find num steps saved state string pattern\n    pattern = r\"checkpoint-(\\d+).json\"\n    match = re.search(pattern, last_checkpoint)\n    cur_step = int(match.group(1))\n    if return_list:\n        # load corresponding generated ids\n        all_generated_ids = load_checkpoint(last_checkpoint)\n        return all_generated_ids, cur_step\n    else:\n        return [], cur_step\n\n\n@dataclass\nclass DataCollatorWithPadding:\n    \"\"\"\n    Data collator that will dynamically pad the inputs received to the longest sequence in the batch.\n    \"\"\"\n\n    tokenizer: Any\n\n    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:\n        # split inputs and labels since they have to be of different lengths and need\n        # different padding methods\n        input_ids = 
{\"input_ids\": [feature[\"input_ids\"] for feature in features]}\n        batch = self.tokenizer.pad(input_ids, return_tensors=\"pt\", padding=\"longest\", return_attention_mask=True)\n        return batch\n\n\nPROMPT = \"\"\"You will be given six descriptive keywords related to an audio sample of a person's speech. These keywords include:\n1. The gender (e.g., male, female)\n2. The level of reverberation (e.g., very roomy sounding, quite roomy sounding, slightly roomy sounding, moderate reverberation, slightly confined sounding, quite confined sounding, very confined sounding)\n3. The amount of noise the sample (e.g., very noisy, quite noisy, slightly noisy, moderate ambient sound, slightly clear, quite clear, very clear)\n4. The tone of the speaker's voice (e.g., very monotone, quite monotone, slightly monotone, moderate intonation, slightly expressive, quite expressive, very expressive)\n5. The pace of the speaker's delivery (e.g., very slowly, quite slowly, slightly slowly, moderate speed, slightly fast, quite fast, very fast)\n6. The pitch of the speaker's voice (e.g., very low pitch, quite low pitch, slightly low pitch, moderate pitch, slightly high pitch, quite high pitch, very high pitch)\nYour task is to create a text description using these keywords that accurately describes the speech sample while ensuring the description remains grammatically correct and easy to understand. You should rearrange the keyword order as necessary, and substitute synonymous terms where appropriate. If the amount of noise is 'very noisy' and the level of reverberation is 'very roomy sounding', include terms like 'very bad recording' in the description. Likewise, if the amount of noise is 'very clear' and the level of reverberation is 'very confined sounding', include terms like 'very good recording' in the description. 
Otherwise, do not add extra details beyond what has been provided, and only return the generated description.\nFor example, given the following keywords: 'female', 'slightly roomy sounding', 'slightly noisy', 'very expressive', 'slightly low pitch', 'very slowly', a valid description would be: 'a woman with a deep voice speaks slowly but has an animated delivery in an echoey room with some background noise'.\nFor the keywords: '[gender]', '[reverberation]', '[noise]', '[speech_monotony]', '[pitch]', '[speaking_rate]', the corresponding description is:\"\n\"\"\"\n\nNEW_PROMPT = \"\"\"You will be given six descriptive keywords related to an audio sample of a person's speech. These keywords include:\n1. The gender (male, female)\n2. The level of reverberation (very distant-sounding, distant-sounding, slightly distant-sounding, slightly close-sounding, very close-sounding)\n3. The amount of noise in the sample (extremely noisy, very noisy, noisy, slightly noisy, almost no noise, very clear)\n4. The tone of the speaker's voice (very monotone, monotone, slightly expressive and animated, expressive and animated, very expressive and animated)\n5. The pace of the speaker's delivery (very slowly, slowly, slightly slowly, moderate speed, slightly fast, fast, very fast)\n6. The pitch of the speaker's voice (very low-pitch, low-pitch, slightly low-pitch, moderate pitch, slightly high-pitch, high-pitch, very high-pitch)\n\nYour task is to create a text description using these keywords that accurately describes the speech sample.\nIf the amount of noise is 'very noisy' and the level of reverberation is 'very distant-sounding', you must include terms such as 'very poor recording' or `very bad recording` in the description. \nLikewise, if the amount of noise is 'very clear' and the level of reverberation is 'very close-sounding', you must include terms like 'very good recording' or `excellent recording` in the description. 
\nYou can randomly omit the following terms, as they are default terms: 'moderate speed' and 'moderate pitch'.\nDo not add extra details beyond what has been provided above. You can change the order of keywords, and replace synonymous terms.\n\nFor example, given the following keywords: 'female', 'slightly distant-sounding', 'noisy', 'very expressive and animated', 'very slowly', 'moderate pitch', a valid description would be: 'A woman speaks very slowly but has a very animated delivery. The recording is noisy and there is some roominess.'\nAnother valid description would be: 'In a noisy room, a female speaker delivers a very animated and expressive speech, at a very slow pace.'\nAnother valid description would be: 'A woman enunciates a very expressive speech. Her voice is slightly distant-sounding, with some background noise present. She speaks very slowly with a moderate pitch but a very expressive tone.'\n\nEnsure that the generated description is grammatically correct, easy to understand, and concise. Only return one and only one description.\n\nFor the keywords: '[gender]', '[reverberation]', '[sdr_noise]', '[speech_monotony]', '[speaking_rate]', '[pitch]', the corresponding description is:\n\"\"\"\n\nNEW_PROMPT_WITH_ACCENT = \"\"\"You will be given 7 descriptive keywords related to an audio sample of a person's speech. These keywords include:\n1. The gender (male, female)\n2. The level of reverberation (very distant-sounding, distant-sounding, slightly distant-sounding, slightly close-sounding, very close-sounding)\n3. The amount of noise in the sample (extremely noisy, very noisy, noisy, slightly noisy, almost no noise, very clear)\n4. The tone of the speaker's voice (very monotone, monotone, slightly expressive and animated, expressive and animated, very expressive and animated)\n5. The pace of the speaker's delivery (very slowly, slowly, slightly slowly, moderate speed, slightly fast, fast, very fast)\n6. 
The pitch of the speaker's voice (very low-pitch, low-pitch, slightly low-pitch, moderate pitch, slightly high-pitch, high-pitch, very high-pitch)\n7. The accent of the speaker.\n\nYour task is to create a text description using these keywords that accurately describes the speech sample.\nIf the amount of noise is 'very noisy' and the level of reverberation is 'very distant-sounding', you must include terms such as 'very poor recording' or `very bad recording` in the description. \nLikewise, if the amount of noise is 'very clear' and the level of reverberation is 'very close-sounding', you must include terms like 'very good recording' or `excellent recording` in the description. \nYou can randomly omit the following terms, as they are default terms: 'moderate speed' and 'moderate pitch'.\nDo not add extra details beyond what has been provided above. You can change the order of keywords, and replace synonymous terms.\n\nFor example, given the following keywords: 'female', 'slightly distant-sounding', 'noisy', 'very expressive and animated', 'very slowly', 'moderate pitch', 'Chinese', a valid description would be: 'A woman with a Chinese accent speaks very slowly but has a very animated delivery. The recording is noisy and there is some roominess.'\nAnother valid description would be: 'In a noisy room, a female speaker with a Chinese accent delivers a very animated and expressive speech, at a very slow pace.'\nAnother valid description would be: 'A woman with a Chinese accent enunciates a very expressive speech. Her voice is slightly distant-sounding, with some background noise present. She speaks very slowly with a moderate pitch but a very expressive tone.'\n\nEnsure that the generated description is grammatically correct, easy to understand, and concise. 
Only return one and only one description.\n\nFor the keywords: '[gender]', '[reverberation]', '[sdr_noise]', '[speech_monotony]', '[speaking_rate]', '[pitch]', '[accent]', the corresponding description is:\n\"\"\"\n\n\nNEW_SINGLE_SPEAKER_PROMPT = \"\"\"You will be given four descriptive keywords related to an audio sample of [speaker_name]'s speech. These keywords include:\n1. The level of reverberation (very distant-sounding, distant-sounding, slightly distant-sounding, slightly close-sounding, very close-sounding)\n2. The amount of noise in the sample (extremely noisy, very noisy, noisy, slightly noisy, almost no noise, very clear)\n3. The tone of the speaker's voice (very monotone, monotone, slightly expressive and animated, expressive and animated, very expressive and animated)\n4. The pace of the speaker's delivery (very slowly, slowly, slightly slowly, moderate speed, slightly fast, fast, very fast)\n\nYour task is to create a text description using these keywords that accurately describes [speaker_name]'s speech sample.\nIf the amount of noise is 'very noisy' and the level of reverberation is 'very distant-sounding', you must include terms such as 'very poor recording' or `very bad recording` in the description. \nLikewise, if the amount of noise is 'very clear' and the level of reverberation is 'very close-sounding', you must include terms like 'very good recording' or `excellent recording` in the description. \nYou can randomly omit the following terms, as they are default terms: 'moderate speed' and 'moderate pitch'.\nDo not add extra details beyond what has been provided above. 
You can change the order of keywords, and replace synonymous terms.\n\nFor example, given the following keywords: 'slightly distant-sounding', 'clear', 'very expressive and animated', 'slightly fast', a valid description would be: '[speaker_name] speaks slightly fast but has a very animated delivery in a room with slight echo but no background noise.'\nAnother valid description would be: 'In a very animated voice, [speaker_name] delivers words slightly quickly. The room is quiet, but there's a bit of echo.'\n\nEnsure that the generated description is grammatically correct, easy to understand, and concise. Only return one and only one description.\n\nFor the keywords: '[reverberation]', '[sdr_noise]', '[speech_monotony]', '[speaking_rate]', the corresponding description is:\n\"\"\"\n\nSINGLE_SPEAKER_PROMPT = \"\"\"You will be given four descriptive keywords related to an audio sample of [speaker_name]'s speech. These keywords include:\n1. The level of reverberation (e.g., very roomy sounding, quite roomy sounding, slightly roomy sounding, moderate reverberation, slightly confined sounding, quite confined sounding, very confined sounding)\n2. The amount of noise in the sample (e.g., very noisy, quite noisy, slightly noisy, moderate ambient sound, slightly clear, quite clear, very clear)\n3. The tone of the speaker's voice (e.g., very monotone, quite monotone, slightly monotone, moderate intonation, slightly expressive, quite expressive, very expressive)\n4. The pace of the speaker's delivery (e.g., very slowly, quite slowly, slightly slowly, moderate speed, slightly fast, quite fast, very fast)\n\nYour task is to create a single short text description using these keywords that accurately describes the speech sample while ensuring the description remains grammatically correct and easy to understand. You should rearrange the keyword order as necessary, and substitute synonymous terms where appropriate. 
If the amount of noise is 'very noisy' and the level of reverberation is 'very roomy sounding', you must include terms like 'very bad recording' in the description. Likewise, if the amount of noise is 'very clear' and the level of reverberation is 'very confined sounding', you must include terms like 'very good recording' in the description. Otherwise, do not add extra details beyond what has been provided, and only return the generated description.\n\nFor example, given the following keywords: 'slightly roomy sounding', 'quite noisy', 'very expressive', 'very slowly', a valid description would be: '[speaker_name] speaks very slowly but has an animated delivery in an echoey room with background noise.'\nFeel free to change the order of keywords, and to use synonyms, for example, with the previous keywords: 'In a very expressive voice, [speaker_name] pronounces her words incredibly slowly. There's some background noise in this room with a bit of echo.'\n\nFor the keywords: '[reverberation]', '[noise]', '[speech_monotony]', '[speaking_rate]', the corresponding description is:\n\"\"\"\n\ndef main():\n    # 1. Parse input arguments\n    parser = HfArgumentParser((ModelArguments, DataArguments))\n    if len(sys.argv) == 2 and sys.argv[1].endswith(\".json\"):\n        # If we pass only one argument to the script and it's the path to a json file,\n        # let's parse it to get our arguments.\n        model_args, data_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))\n    else:\n        model_args, data_args = parser.parse_args_into_dataclasses()\n\n    # 2. 
Setup logging\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        handlers=[logging.StreamHandler(sys.stdout)],\n    )\n\n    if data_args.is_single_speaker and data_args.speaker_name is None:\n        raise ValueError(\"`is_single_speaker=True` but `speaker_name` is not specified. Specify it or remove `is_single_speaker`.\")\n\n    if not data_args.is_single_speaker and data_args.speaker_name:\n        raise ValueError(f\"`is_single_speaker=False` but `speaker_name={data_args.speaker_name}` is specified. Add `--is_single_speaker` or remove `speaker_name`.\")\n\n    # Create the custom configuration\n    process_group_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=3600*3))\n    accelerator = Accelerator(kwargs_handlers=[process_group_kwargs])\n\n    if data_args.overwrite_output_dir and os.path.exists(data_args.output_dir) and os.path.isdir(data_args.output_dir):\n        logger.info(\"Cleaning output dir from previous run...\")\n        shutil.rmtree(data_args.output_dir)\n\n    # 3. 
Load annotated dataset\n    logger.info(\"*** Load annotated dataset ***\")\n    if data_args.dataset_split_name is not None:\n        raw_datasets = DatasetDict()\n        data_splits = data_args.dataset_split_name.split(\"+\")\n        # load on a split-wise basis\n        for split in data_splits:\n            with accelerator.local_main_process_first():\n                raw_datasets[split] = load_dataset(\n                    data_args.dataset_name,\n                    data_args.dataset_config_name,\n                    split=split,\n                    cache_dir=model_args.cache_dir,\n                    token=model_args.token,\n                    num_proc=data_args.preprocessing_num_workers,\n                )\n    else:\n        with accelerator.local_main_process_first():\n            # load all splits for annotation\n            raw_datasets = load_dataset(\n                data_args.dataset_name,\n                data_args.dataset_config_name,\n                cache_dir=model_args.cache_dir,\n                token=model_args.token,\n                num_proc=data_args.preprocessing_num_workers,\n            )\n\n    raw_datasets_features = set(raw_datasets[next(iter(raw_datasets))].features.keys())\n\n    if data_args.max_eval_samples is not None:\n        for split in raw_datasets:\n            raw_datasets[split] = raw_datasets[split].select(range(data_args.max_eval_samples))\n\n    EXPECTED_COLUMNS = {\"gender\", \"pitch\", \"noise\", \"reverberation\", \"speech_monotony\", \"speaking_rate\"}\n    if data_args.is_single_speaker:\n        EXPECTED_COLUMNS = {\"noise\", \"reverberation\", \"speech_monotony\", \"speaking_rate\"}\n        \n    if data_args.is_new_speaker_prompt:\n        EXPECTED_COLUMNS.remove(\"noise\")\n        EXPECTED_COLUMNS.add(\"sdr_noise\")\n        \n    speaker_ids_to_name = {}\n    speaker_id_column = data_args.speaker_id_column\n    if data_args.speaker_id_column and data_args.speaker_ids_to_name_json:\n        import json\n 
       if data_args.is_single_speaker:\n            raise ValueError(f\"`is_single_speaker=True` but `speaker_ids_to_name_json={data_args.speaker_ids_to_name_json}`. Specify one or the other.\")\n        \n        EXPECTED_COLUMNS.add(data_args.speaker_id_column)\n        with open(data_args.speaker_ids_to_name_json, \"r\") as read_file:\n            speaker_ids_to_name = json.load(read_file)\n\n    if not EXPECTED_COLUMNS.issubset(raw_datasets_features):\n        missing_columns = EXPECTED_COLUMNS - raw_datasets_features\n        raise ValueError(\n            f\"Missing columns {missing_columns} from the dataset features. Got dataset features {raw_datasets_features}\"\n        )\n\n    # 4. Load pre-trained model\n    logger.info(\"*** Load pretrained model ***\")\n    torch_dtype = (\n        model_args.torch_dtype if model_args.torch_dtype in [\"auto\", None] else getattr(torch, model_args.torch_dtype)\n    )\n    quantization_config = get_quantization_config(model_args)\n\n    model = AutoModelForCausalLM.from_pretrained(\n        model_args.model_name_or_path,\n        revision=model_args.model_revision,\n        variant=model_args.model_variant,\n        trust_remote_code=model_args.trust_remote_code,\n        attn_implementation=model_args.attn_implementation,\n        torch_dtype=torch_dtype,\n        device_map=get_kbit_device_map() if quantization_config is not None else None,\n        quantization_config=quantization_config,\n        low_cpu_mem_usage=True,\n        token=model_args.token,\n    ).eval()\n\n    if model_args.torch_compile:\n        # torch compile only compatible with gemma and llama\n        if not callable(getattr(model, \"_setup_cache\", None)):\n            raise ValueError(\n                f\"Static k/v cache is not compatible with the model {model.__class__.__name__}. 
Set `--torch_compile=False` \"\n                \"for dynamic k/v cache\"\n            )\n        model.generation_config.cache_implementation = \"static\"\n        # compile the forward pass (but not the top-{p,k} sampling)\n        model = torch.compile(model, mode=\"reduce-overhead\", fullgraph=True)\n\n    tokenizer = AutoTokenizer.from_pretrained(\n        model_args.model_name_or_path,\n        revision=model_args.model_revision,\n        trust_remote_code=model_args.trust_remote_code,\n        use_fast=model_args.use_fast_tokenizer,\n        padding_side=\"left\",\n    )\n    if tokenizer.pad_token_id is None:\n        tokenizer.pad_token_id = tokenizer.bos_token_id\n        model.generation_config.pad_token_id = model.generation_config.eos_token_id\n\n    speaker_name = data_args.speaker_name\n    is_single_speaker = data_args.is_single_speaker\n    is_new_speaker_prompt = data_args.is_new_speaker_prompt\n    accent_column_name = data_args.accent_column\n\n    def prepare_dataset(sample):\n        sample_prompt = PROMPT\n        if is_single_speaker:\n            sample_prompt = SINGLE_SPEAKER_PROMPT if not is_new_speaker_prompt else NEW_SINGLE_SPEAKER_PROMPT\n            sample_prompt = sample_prompt.replace(\"[speaker_name]\", speaker_name)\n        elif speaker_id_column and speaker_ids_to_name.get(str(sample.get(speaker_id_column)), None):\n            name = speaker_ids_to_name.get(str(sample.get(speaker_id_column)), None)\n            sample_prompt = SINGLE_SPEAKER_PROMPT if not is_new_speaker_prompt else NEW_SINGLE_SPEAKER_PROMPT\n            sample_prompt = sample_prompt.replace(\"[speaker_name]\", name)\n        elif is_new_speaker_prompt and accent_column_name is not None:\n            sample_prompt = NEW_PROMPT if sample.get(accent_column_name, \"Unindentified\") == \"Unindentified\" else NEW_PROMPT_WITH_ACCENT\n        elif is_new_speaker_prompt:\n            sample_prompt = NEW_PROMPT\n        for key in EXPECTED_COLUMNS:\n            
sample_prompt = sample_prompt.replace(f\"[{key}]\", sample[key])\n        if accent_column_name is not None and sample.get(accent_column_name, \"Unindentified\") != \"Unindentified\":\n            sample_prompt = sample_prompt.replace(\"[accent]\", sample[accent_column_name])\n            \n        sample_prompt = [{\"role\": \"user\", \"content\": sample_prompt}]\n        token_ids = tokenizer.apply_chat_template(sample_prompt)\n        sample[\"input_ids\"] = token_ids\n        return sample\n\n    with accelerator.local_main_process_first():\n        vectorized_datasets = raw_datasets.map(\n            prepare_dataset, num_proc=data_args.preprocessing_num_workers, desc=\"Preparing prompts\"\n        )\n\n    # Prepare everything with our `accelerator`\n    model = accelerator.prepare(model)\n    data_collator = DataCollatorWithPadding(tokenizer)\n\n    def generate_step(batch):\n        output_ids = accelerator.unwrap_model(model).generate(\n            batch[\"input_ids\"],\n            attention_mask=batch[\"attention_mask\"],\n            do_sample=model_args.do_sample,\n            temperature=model_args.temperature,\n            max_new_tokens=model_args.max_new_tokens,\n        )\n        output_ids = accelerator.pad_across_processes(output_ids, dim=1, pad_index=tokenizer.pad_token_id)\n        return output_ids\n\n    def postprocess_dataset(batch):\n        prompt_texts = tokenizer.batch_decode(batch[\"input_ids\"], skip_special_tokens=True)\n        generated_texts = tokenizer.batch_decode(batch[\"generated_ids\"], skip_special_tokens=True)\n        \n        batch[\"text_description\"] = [generated_text[len(prompt_text) :] for (prompt_text, generated_text) in zip(prompt_texts, generated_texts)]\n        return batch\n\n    for split in vectorized_datasets:\n        data_loader = DataLoader(\n            vectorized_datasets[split],\n            batch_size=model_args.per_device_eval_batch_size,\n            collate_fn=data_collator,\n            
num_workers=data_args.dataloader_num_workers,\n            pin_memory=True,\n        )\n        data_loader = accelerator.prepare(data_loader)\n        total_inference_steps = len(data_loader)\n        progress_bar = tqdm(\n            range(total_inference_steps), desc=\" ... \", position=0, disable=not accelerator.is_local_main_process\n        )\n\n        split_output_dir = os.path.join(data_args.output_dir, split)\n        all_generated_ids, cur_step = get_last_checkpoint(split_output_dir, accelerator.is_local_main_process)\n        accelerator.wait_for_everyone()\n\n        if cur_step > 0:\n            logger.info(f\"Resuming {split} from step {cur_step}\")\n            # efficiently skip the first n batches\n            data_loader = skip_first_batches(data_loader, cur_step)\n            progress_bar.update(cur_step)\n\n        while cur_step < total_inference_steps:\n            for batch in data_loader:\n                generated_ids = generate_step(batch)\n                generated_ids = accelerator.gather_for_metrics(generated_ids)\n                if accelerator.is_local_main_process:\n                    all_generated_ids.extend(generated_ids.cpu().numpy())\n\n                cur_step += 1\n                progress_bar.update(1)\n\n                if (cur_step % data_args.save_steps == 0) or (cur_step == total_inference_steps):\n                    if accelerator.is_main_process:\n                        save_checkpoint(split_output_dir, all_generated_ids, cur_step)\n                        rotate_checkpoints(data_args.save_total_limit, output_dir=split_output_dir)\n                    accelerator.wait_for_everyone()\n\n        if accelerator.is_local_main_process:\n            vectorized_datasets[split] = vectorized_datasets[split].add_column(\"generated_ids\", all_generated_ids)\n\n        if accelerator.is_main_process:\n            vectorized_datasets[split] = vectorized_datasets[split].map(\n                postprocess_dataset,\n                
batched=True,\n                num_proc=data_args.preprocessing_num_workers,\n                desc=\"Postprocessing dataset\",\n                remove_columns=[\"input_ids\", \"generated_ids\"],\n            )\n        accelerator.wait_for_everyone()\n\n    if accelerator.is_main_process:\n        vectorized_datasets.save_to_disk(data_args.output_dir)\n        if data_args.push_to_hub:\n            vectorized_datasets.push_to_hub(\n                data_args.hub_dataset_id,\n                config_name=data_args.dataset_config_name if data_args.dataset_config_name is not None else \"default\",\n                token=model_args.token,\n            )\n    accelerator.wait_for_everyone()\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    main()"
  },
  {
    "path": "scripts/run_prompt_creation_llm_swarm.py",
    "content": "import json\nimport os\nimport re\nimport shutil\nimport sys\nfrom dataclasses import dataclass, field\nfrom pathlib import Path\nfrom typing import Optional, Tuple, List\n\nimport logging\n\nimport math\nfrom datasets import DatasetDict, load_dataset\nfrom tqdm import tqdm\nfrom transformers import (\n    AutoTokenizer,\n    HfArgumentParser,\n)\nimport asyncio\nfrom llm_swarm import LLMSwarm, LLMSwarmConfig\nfrom huggingface_hub import AsyncInferenceClient\n\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclass\nclass ModelArguments:\n    \"\"\"\n    Arguments pertaining to the model used for prompt annotation and to its generation settings.\n    \"\"\"\n\n    model_name_or_path: str = field(\n        metadata={\n            \"help\": \"The name of the model to use (via the transformers library) for the prompt annotation.\"\n        },\n    )\n    num_instances: int = field(\n        default=1,\n        metadata={\"help\": \"Number of TGI instances.\"},\n    )\n    per_instance_max_parallel_requests: int = field(\n        default=500,\n        metadata={\"help\": \"Maximum number of parallel requests per instance.\"},\n    )\n    checkpoint_interval: Optional[int] = field(\n        default=1000,\n        metadata={\n            \"help\": \"Interval for streaming chunks of generation.\"\n        },\n    )\n    model_revision: str = field(\n        default=\"main\",\n        metadata={\n            \"help\": \"The specific model version to use (can be a branch name, tag name or commit id).\"\n        },\n    )\n    cache_dir: Optional[str] = field(\n        default=None,\n        metadata={\n            \"help\": \"Where to store the pretrained models downloaded from huggingface.co\"\n        },\n    )\n    do_sample: Optional[bool] = field(\n        default=True, metadata={\"help\": \"Whether to use sampling mode for generation\"}\n    )\n    temperature: Optional[float] = field(\n        default=0.6, metadata={\"help\": \"Temperature for 
sampling-based generation\"}\n    )\n    max_new_tokens: Optional[int] = field(\n        default=256, metadata={\"help\": \"Maximum number of new tokens during generation\"}\n    )\n    token: Optional[bool] = field(\n        default=True,\n        metadata={\n            \"help\": \"Whether or not to use an authentication token when loading/uploading from the Hugging Face Hub\"\n        },\n    )\n    debug_endpoint: Optional[str] = field(\n        default=None,\n        metadata={\"help\": \"Endpoint to use for debugging (e.g. http://localhost:13120).\"},\n    )\n    max_retries: Optional[int] = field(\n        default=5,\n        metadata={\"help\": \"Maximum number of retries per sample.\"},\n    )\n    retry_delay_in_s: Optional[float] = field(\n        default=5.0,\n        metadata={\"help\": \"Time to wait between successive retries in seconds.\"},\n    )\n\n\n@dataclass\nclass DataArguments:\n    \"\"\"\n    Arguments pertaining to what data we are going to input our model for training and eval.\n    \"\"\"\n\n    output_dir: str = field(\n        metadata={\n            \"help\": \"Where to save the processed dataset to disk. If unspecified, uses a 'pretty' version of the \"\n            \"original dataset name. E.g. 
'facebook/voxpopuli' will be saved under 'voxpopuli'.\"\n        },\n    )\n    dataset_name: str = field(\n        default=None,\n        metadata={\"help\": \"The name of the dataset to use (via the datasets library)\"},\n    )\n    dataset_config_name: Optional[str] = field(\n        default=None,\n        metadata={\n            \"help\": \"The configuration name of the dataset to use (via the datasets library).\"\n        },\n    )\n    dataset_split_name: Optional[str] = field(\n        default=None,\n        metadata={\n            \"help\": \"The split name of the dataset to use (via the datasets library).\"\n        },\n    )\n    dataset_cache_dir: Optional[str] = field(\n        default=None,\n        metadata={\"help\": \"Path to cache directory for saving and loading datasets\"},\n    )\n    max_eval_samples: Optional[int] = field(\n        default=None,\n        metadata={\n            \"help\": \"Maximum number of samples for generation - use for debugging purposes.\"\n        },\n    )\n    overwrite_cache: bool = field(\n        default=False,\n        metadata={\"help\": \"Overwrite the cached training and evaluation sets\"},\n    )\n    preprocessing_num_workers: Optional[int] = field(\n        default=None,\n        metadata={\"help\": \"The number of processes to use for the preprocessing.\"},\n    )\n    push_to_hub: Optional[bool] = field(\n        default=False,\n        metadata={\"help\": \"Whether or not to push the processed dataset to the Hub.\"},\n    )\n    hub_dataset_id: Optional[str] = field(\n        default=None,\n        metadata={\"help\": \"Repository namespace if pushing to the Hugging Face Hub.\"},\n    )\n    overwrite_output_dir: Optional[bool] = field(\n        default=False,\n        metadata={\n            \"help\": \"Overwrite the content of the output directory each time the script is run.\"\n        },\n    )\n    save_steps: Optional[int] = field(\n        default=100,\n        metadata={\"help\": \"Save the 
generated prompts every save_steps.\"},\n    )\n    save_total_limit: Optional[int] = field(\n        default=1, metadata={\"help\": (\"If a value is passed, will limit the total number of saved checkpoints\")}\n    )\n    speaker_name: Optional[str] = field(\n        default=None,\n        metadata={\"help\": \"If `is_single_speaker`, it specifies the name to give to the single speaker of your dataset.\"},\n    )\n    is_single_speaker: Optional[bool] = field(\n        default=False, metadata={\"help\": \"Whether to use a single speaker prompt, with a single name, specified by `speaker_name`.\"}\n    )\n    is_new_speaker_prompt: Optional[bool] = field(\n        default=False, metadata={\"help\": \"Whether to use the newest speaker prompt, which will be used for the next Parler-TTS.\"}\n    )\n    speaker_id_column: Optional[str] = field(\n        default=None, metadata={\"help\": \"Speaker id column name. Only used if creating a dataset with multiple speaker names (i.e. if `speaker_ids_to_name_json` is specified)\"}\n    )\n    speaker_ids_to_name_json: Optional[str] = field(\n        default=None, metadata={\"help\": \"Path to a JSON file which maps speaker ids to names. 
Only used if `speaker_id_column` is specified.\"}\n    )\n    accent_column: Optional[str] = field(\n        default=None, metadata={\"help\": \"Accent column name, if any.\"}\n    )\n\n    def __post_init__(self):\n        if self.push_to_hub and self.hub_dataset_id is None:\n            raise ValueError(\n                \"You must specify the `hub_dataset_id` when setting `--push_to_hub=True`\"\n            )\n\nCHECKPOINT_PREFIX = \"checkpoint\"\n_RE_CHECKPOINT = re.compile(r\"^checkpoint-(\\d+).json$\")\n\n\ndef save_checkpoint(output_dir, all_generated_ids, step):\n    checkpoint_path = f\"{CHECKPOINT_PREFIX}-{step}.json\"\n    output_path = os.path.join(output_dir, checkpoint_path)\n    with open(output_path, \"w\") as file:\n        json.dump(all_generated_ids, file)\n\n\ndef load_checkpoint(checkpoint_path):\n    with open(checkpoint_path, \"r\") as file:\n        all_generated_ids = json.load(file)\n    return all_generated_ids\n\n\ndef sorted_checkpoints(output_dir=None) -> List[str]:\n    \"\"\"Helper function to sort saved checkpoints from oldest to newest.\"\"\"\n    ordering_and_checkpoint_path = []\n\n    glob_checkpoints = [str(x) for x in Path(output_dir).glob(f\"{CHECKPOINT_PREFIX}-*\")]\n\n    for path in glob_checkpoints:\n        regex_match = re.match(f\".*{CHECKPOINT_PREFIX}-([0-9]+)\", path)\n        if regex_match is not None and regex_match.groups() is not None:\n            ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))\n\n    checkpoints_sorted = sorted(ordering_and_checkpoint_path)\n    checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]\n    return checkpoints_sorted\n\n\ndef rotate_checkpoints(save_total_limit=None, output_dir=None) -> None:\n    \"\"\"Helper function to delete old checkpoints.\"\"\"\n    if save_total_limit is None or save_total_limit <= 0:\n        return\n    # Check if we should delete older checkpoint(s)\n    checkpoints_sorted = 
sorted_checkpoints(output_dir=output_dir)\n    if len(checkpoints_sorted) <= save_total_limit:\n        return\n\n    number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)\n    checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]\n    for checkpoint in checkpoints_to_be_deleted:\n        logger.info(f\"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit\")\n        os.remove(checkpoint)\n\n\ndef get_last_checkpoint(folder) -> Tuple[List, int]:\n    if not os.path.exists(folder) or not os.path.isdir(folder):\n        os.makedirs(folder, exist_ok=True)\n        return [], 0\n    content = os.listdir(folder)\n    checkpoints = [path for path in content if _RE_CHECKPOINT.search(path) is not None]\n    if len(checkpoints) == 0:\n        return [], 0\n    last_checkpoint = os.path.join(folder, max(checkpoints, key=lambda x: int(_RE_CHECKPOINT.search(x).groups()[0])))\n    # Find num steps saved state string pattern\n    pattern = r\"checkpoint-(\\d+).json\"\n    match = re.search(pattern, last_checkpoint)\n    cur_step = int(match.group(1))\n    # load corresponding generated ids\n    all_generated_ids = load_checkpoint(last_checkpoint)\n    return all_generated_ids, cur_step\n\n\n\nPROMPT = \"\"\"You will be given six descriptive keywords related to an audio sample of a person's speech. These keywords include:\n1. The gender (e.g., male, female)\n2. The level of reverberation (e.g., very roomy sounding, quite roomy sounding, slightly roomy sounding, moderate reverberation, slightly confined sounding, quite confined sounding, very confined sounding)\n3. The amount of noise in the sample (e.g., very noisy, quite noisy, slightly noisy, moderate ambient sound, slightly clear, quite clear, very clear)\n4. The tone of the speaker's voice (e.g., very monotone, quite monotone, slightly monotone, moderate intonation, slightly expressive, quite expressive, very expressive)\n5. 
The pace of the speaker's delivery (e.g., very slowly, quite slowly, slightly slowly, moderate speed, slightly fast, quite fast, very fast)\n6. The pitch of the speaker's voice (e.g., very low pitch, quite low pitch, slightly low pitch, moderate pitch, slightly high pitch, quite high pitch, very high pitch)\nYour task is to create a text description using these keywords that accurately describes the speech sample while ensuring the description remains grammatically correct and easy to understand. You should rearrange the keyword order as necessary, and substitute synonymous terms where appropriate. If the amount of noise is 'very noisy' and the level of reverberation is 'very roomy sounding', include terms like 'very bad recording' in the description. Likewise, if the amount of noise is 'very clear' and the level of reverberation is 'very confined sounding', include terms like 'very good recording' in the description. Otherwise, do not add extra details beyond what has been provided, and only return the generated description.\nFor example, given the following keywords: 'female', 'slightly roomy sounding', 'slightly noisy', 'very expressive', 'slightly low pitch', 'very slowly', a valid description would be: 'a woman with a deep voice speaks slowly but has an animated delivery in an echoey room with some background noise'.\nFor the keywords: '[gender]', '[reverberation]', '[noise]', '[speech_monotony]', '[pitch]', '[speaking_rate]', the corresponding description is:\n\"\"\"\n\nNEW_PROMPT = \"\"\"You will be given six descriptive keywords related to an audio sample of a person's speech. These keywords include:\n1. The gender (male, female)\n2. The level of reverberation (very distant-sounding, distant-sounding, slightly distant-sounding, slightly close-sounding, very close-sounding)\n3. The amount of noise in the sample (extremely noisy, very noisy, noisy, slightly noisy, almost no noise, very clear)\n4. 
The tone of the speaker's voice (very monotone, monotone, slightly expressive and animated, expressive and animated, very expressive and animated)\n5. The pace of the speaker's delivery (very slowly, slowly, slightly slowly, moderate speed, slightly fast, fast, very fast)\n6. The pitch of the speaker's voice (very low-pitch, low-pitch, slightly low-pitch, moderate pitch, slightly high-pitch, high-pitch, very high-pitch)\n\nYour task is to create a text description using these keywords that accurately describes the speech sample.\nIf the amount of noise is 'very noisy' and the level of reverberation is 'very distant-sounding', you must include terms such as 'very poor recording' or `very bad recording` in the description. \nLikewise, if the amount of noise is 'very clear' and the level of reverberation is 'very close-sounding', you must include terms like 'very good recording' or `excellent recording` in the description. \nYou can randomly omit the following terms, as they are default terms: 'moderate speed' and 'moderate pitch'.\nDo not add extra details beyond what has been provided above. You can change the order of keywords, and replace synonymous terms.\n\nFor example, given the following keywords: 'female', 'slightly distant-sounding', 'noisy', 'very expressive and animated', 'very slowly', 'moderate pitch', a valid description would be: 'A woman speaks very slowly but has a very animated delivery. The recording is noisy and there is some roominess.'\nAnother valid description would be: 'In a noisy room, a female speaker delivers a very animated and expressive speech, at a very slow pace.'\nAnother valid description would be: 'A woman enunciates a very expressive speech. Her voice is slightly distant-sounding, with some background noise present. She speaks very slowly with a moderate pitch but a very expressive tone.'\n\nEnsure that the generated description is grammatically correct, easy to understand, and concise. 
Only return one and only one description.\n\nFor the keywords: '[gender]', '[reverberation]', '[sdr_noise]', '[speech_monotony]', '[speaking_rate]', '[pitch]', the corresponding description is:\n\"\"\"\n\nNEW_PROMPT_WITH_ACCENT = \"\"\"You will be given 7 descriptive keywords related to an audio sample of a person's speech. These keywords include:\n1. The gender (male, female)\n2. The level of reverberation (very distant-sounding, distant-sounding, slightly distant-sounding, slightly close-sounding, very close-sounding)\n3. The amount of noise in the sample (extremely noisy, very noisy, noisy, slightly noisy, almost no noise, very clear)\n4. The tone of the speaker's voice (very monotone, monotone, slightly expressive and animated, expressive and animated, very expressive and animated)\n5. The pace of the speaker's delivery (very slowly, slowly, slightly slowly, moderate speed, slightly fast, fast, very fast)\n6. The pitch of the speaker's voice (very low-pitch, low-pitch, slightly low-pitch, moderate pitch, slightly high-pitch, high-pitch, very high-pitch)\n7. The accent of the speaker.\n\nYour task is to create a text description using these keywords that accurately describes the speech sample.\nIf the amount of noise is 'very noisy' and the level of reverberation is 'very distant-sounding', you must include terms such as 'very poor recording' or `very bad recording` in the description. \nLikewise, if the amount of noise is 'very clear' and the level of reverberation is 'very close-sounding', you must include terms like 'very good recording' or `excellent recording` in the description. \nYou can randomly omit the following terms, as they are default terms: 'moderate speed' and 'moderate pitch'.\nDo not add extra details beyond what has been provided above. 
You can change the order of keywords, and replace synonymous terms.\n\nFor example, given the following keywords: 'female', 'slightly distant-sounding', 'noisy', 'very expressive and animated', 'very slowly', 'moderate pitch', 'Chinese', a valid description would be: 'A woman with a Chinese accent speaks very slowly but has a very animated delivery. The recording is noisy and there is some roominess.'\nAnother valid description would be: 'In a noisy room, a female speaker with a Chinese accent delivers a very animated and expressive speech, at a very slow pace.'\nAnother valid description would be: 'A woman with a Chinese accent enunciates a very expressive speech. Her voice is slightly distant-sounding, with some background noise present. She speaks very slowly with a moderate pitch but a very expressive tone.'\n\nEnsure that the generated description is grammatically correct, easy to understand, and concise. Only return one and only one description.\n\nFor the keywords: '[gender]', '[reverberation]', '[sdr_noise]', '[speech_monotony]', '[speaking_rate]', '[pitch]', '[accent]', the corresponding description is:\n\"\"\"\n\n\nNEW_SINGLE_SPEAKER_PROMPT = \"\"\"You will be given four descriptive keywords related to an audio sample of [speaker_name]'s speech. These keywords include:\n1. The level of reverberation (very distant-sounding, distant-sounding, slightly distant-sounding, slightly close-sounding, very close-sounding)\n2. The amount of noise in the sample (extremely noisy, very noisy, noisy, slightly noisy, almost no noise, very clear)\n3. The tone of the speaker's voice (very monotone, monotone, slightly expressive and animated, expressive and animated, very expressive and animated)\n4. 
The pace of the speaker's delivery (very slowly, slowly, slightly slowly, moderate speed, slightly fast, fast, very fast)\n\nYour task is to create a text description using these keywords that accurately describes [speaker_name]'s speech sample.\nIf the amount of noise is 'very noisy' and the level of reverberation is 'very distant-sounding', you must include terms such as 'very poor recording' or `very bad recording` in the description. \nLikewise, if the amount of noise is 'very clear' and the level of reverberation is 'very close-sounding', you must include terms like 'very good recording' or `excellent recording` in the description. \nYou can randomly omit the following terms, as they are default terms: 'moderate speed' and 'moderate pitch'.\nDo not add extra details beyond what has been provided above. You can change the order of keywords, and replace synonymous terms.\n\nFor example, given the following keywords: 'slightly distant-sounding', 'clear', 'very expressive and animated', 'slightly fast', a valid description would be: '[speaker_name] speaks slightly fast but has a very animated delivery in a room with slight echo but no background noise.'\nAnother valid description would be: 'In a very animated voice, [speaker_name] delivers words slightly quickly. The room is quiet, but there's a bit of echo.'\n\nEnsure that the generated description is grammatically correct, easy to understand, and concise. Only return one and only one description.\n\nFor the keywords: '[reverberation]', '[sdr_noise]', '[speech_monotony]', '[speaking_rate]', the corresponding description is:\n\"\"\"\n\nSINGLE_SPEAKER_PROMPT = \"\"\"You will be given four descriptive keywords related to an audio sample of [speaker_name]'s speech. These keywords include:\n1. The level of reverberation (e.g., very roomy sounding, quite roomy sounding, slightly roomy sounding, moderate reverberation, slightly confined sounding, quite confined sounding, very confined sounding)\n2. 
The amount of noise in the sample (e.g., very noisy, quite noisy, slightly noisy, moderate ambient sound, slightly clear, quite clear, very clear)\n3. The tone of the speaker's voice (e.g., very monotone, quite monotone, slightly monotone, moderate intonation, slightly expressive, quite expressive, very expressive)\n4. The pace of the speaker's delivery (e.g., very slowly, quite slowly, slightly slowly, moderate speed, slightly fast, quite fast, very fast)\n\nYour task is to create a single short text description using these keywords that accurately describes the speech sample while ensuring the description remains grammatically correct and easy to understand. You should rearrange the keyword order as necessary, and substitute synonymous terms where appropriate. If the amount of noise is 'very noisy' and the level of reverberation is 'very roomy sounding', you must include terms like 'very bad recording' in the description. Likewise, if the amount of noise is 'very clear' and the level of reverberation is 'very confined sounding', you must include terms like 'very good recording' in the description. Otherwise, do not add extra details beyond what has been provided, and only return the generated description.\n\nFor example, given the following keywords: 'slightly roomy sounding', 'quite noisy', 'very expressive', 'very slowly', a valid description would be: '[speaker_name] speaks very slowly but has an animated delivery in an echoey room with background noise.'.\nFeel free to change the order of keywords, and to use synonyms, for example, with the previous keywords: 'In a very expressive voice, [speaker_name] pronounces her words incredibly slowly. There's some background noise in this room with a bit of echo.'.\n\nFor the keywords: '[reverberation]', '[noise]', '[speech_monotony]', '[speaking_rate]', the corresponding description is:\n\"\"\"\n\n# 1. 
Parse input arguments\nparser = HfArgumentParser((ModelArguments, DataArguments))\nif len(sys.argv) == 2 and sys.argv[1].endswith(\".json\"):\n    # If we pass only one argument to the script and it's the path to a json file,\n    # let's parse it to get our arguments.\n    model_args, data_args = parser.parse_json_file(\n        json_file=os.path.abspath(sys.argv[1])\n    )\nelse:\n    model_args, data_args = parser.parse_args_into_dataclasses()\n\nif data_args.is_single_speaker and data_args.speaker_name is None:\n    raise ValueError(\"`is_single_speaker=True` but `speaker_name` is not specified. Specify it or remove `is_single_speaker`.\")\n\nif not data_args.is_single_speaker and data_args.speaker_name:\n    raise ValueError(f\"`is_single_speaker=False` but `speaker_name={data_args.speaker_name}` is specified. Add `--is_single_speaker` or remove `speaker_name`.\")\n\nEXPECTED_COLUMNS = {\"gender\", \"pitch\", \"noise\", \"reverberation\", \"speech_monotony\", \"speaking_rate\"}\nif data_args.is_single_speaker:\n    EXPECTED_COLUMNS = {\"noise\", \"reverberation\", \"speech_monotony\", \"speaking_rate\"}\n    \nif data_args.is_new_speaker_prompt:\n    EXPECTED_COLUMNS.remove(\"noise\")\n    EXPECTED_COLUMNS.add(\"sdr_noise\")\n\nspeaker_ids_to_name = {}\nspeaker_id_column = data_args.speaker_id_column\nif data_args.speaker_id_column and data_args.speaker_ids_to_name_json:\n    import json\n    if data_args.is_single_speaker:\n        raise ValueError(f\"`is_single_speaker=True` but `speaker_ids_to_name_json={data_args.speaker_ids_to_name_json}`. 
Specify one or the other.\")\n    \n    EXPECTED_COLUMNS.add(data_args.speaker_id_column)\n    with open(data_args.speaker_ids_to_name_json, \"r\") as read_file:\n        speaker_ids_to_name = json.load(read_file)\n\nspeaker_name = data_args.speaker_name\nis_single_speaker = data_args.is_single_speaker\nis_new_speaker_prompt = data_args.is_new_speaker_prompt\naccent_column_name = data_args.accent_column\n    \nwith LLMSwarm(\n    LLMSwarmConfig(\n        instances=model_args.num_instances,\n        inference_engine=\"tgi\",\n        slurm_template_path=\"./tgi_h100.template.slurm\",\n        load_balancer_template_path=\"./nginx.template.conf\",\n        model=model_args.model_name_or_path,\n        revision=model_args.model_revision,\n        per_instance_max_parallel_requests=model_args.per_instance_max_parallel_requests,\n        debug_endpoint=model_args.debug_endpoint,\n    )\n) as llm_swarm:\n    semaphore = asyncio.Semaphore(llm_swarm.suggested_max_parallel_requests)\n    client = AsyncInferenceClient(model=llm_swarm.endpoint)\n    tokenizer = AutoTokenizer.from_pretrained(\n        model_args.model_name_or_path,\n        revision=model_args.model_revision,\n    )\n\n    async def process_text(sample):\n        sample_prompt = PROMPT\n        if is_single_speaker:\n            sample_prompt = SINGLE_SPEAKER_PROMPT if not is_new_speaker_prompt else NEW_SINGLE_SPEAKER_PROMPT\n            sample_prompt = sample_prompt.replace(\"[speaker_name]\", speaker_name)\n        elif (speaker_id_column and speaker_ids_to_name.get(str(sample.get(speaker_id_column)), None)):\n            name = speaker_ids_to_name.get(str(sample.get(speaker_id_column)), None)\n            sample_prompt = SINGLE_SPEAKER_PROMPT if not is_new_speaker_prompt else NEW_SINGLE_SPEAKER_PROMPT\n            sample_prompt = sample_prompt.replace(\"[speaker_name]\", name)\n        elif is_new_speaker_prompt and accent_column_name is not None:\n            sample_prompt = NEW_PROMPT if 
sample.get(accent_column_name, \"Unindentified\") == \"Unindentified\" else NEW_PROMPT_WITH_ACCENT\n        elif is_new_speaker_prompt:\n            sample_prompt = NEW_PROMPT\n        for key in EXPECTED_COLUMNS:\n            sample_prompt = sample_prompt.replace(f\"[{key}]\", sample[key])\n        if accent_column_name is not None and sample.get(accent_column_name, \"Unindentified\") != \"Unindentified\":\n            sample_prompt = sample_prompt.replace(\"[accent]\", sample[accent_column_name])\n\n        sample_prompt = [{\"role\": \"user\", \"content\": sample_prompt}]\n        sample_prompt = tokenizer.apply_chat_template(sample_prompt, tokenize=False)\n        attempt = 0\n        while attempt < model_args.max_retries:\n            try:\n                async with semaphore:\n                    return await client.text_generation(\n                        prompt=sample_prompt,\n                        max_new_tokens=model_args.max_new_tokens,\n                        temperature=model_args.temperature,\n                        do_sample=model_args.do_sample,\n                    )\n            except Exception as e:\n                attempt += 1\n                if attempt < model_args.max_retries:\n                    print(\n                        f\"Request failed due to {e}\\nRetrying in {model_args.retry_delay_in_s} seconds... (Attempt {attempt}/{model_args.max_retries})\"\n                    )\n                    await asyncio.sleep(model_args.retry_delay_in_s)\n                else:\n                    raise ValueError(\n                        f\"Max retries reached. Failed with error: {e}.\"\n                    )\n\n    async def main():\n        # 2. 
Setup logging\n        logger.setLevel(logging.INFO)\n        logging.basicConfig(\n            format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n            datefmt=\"%m/%d/%Y %H:%M:%S\",\n            handlers=[logging.StreamHandler(sys.stdout)],\n        )\n\n        if (\n            data_args.overwrite_output_dir\n            and os.path.exists(data_args.output_dir)\n            and os.path.isdir(data_args.output_dir)\n        ):\n            logger.info(\"Cleaning output dir from previous run...\")\n            shutil.rmtree(data_args.output_dir)\n\n        # 3. Load annotated dataset\n        logger.info(\"*** Load annotated dataset ***\")\n        if data_args.dataset_split_name is not None:\n            raw_datasets = DatasetDict()\n            data_splits = data_args.dataset_split_name.split(\"+\")\n            # load on a split-wise basis\n            for split in data_splits:\n                raw_datasets[split] = load_dataset(\n                    data_args.dataset_name,\n                    data_args.dataset_config_name,\n                    split=split,\n                    cache_dir=model_args.cache_dir,\n                    token=model_args.token,\n                    num_proc=data_args.preprocessing_num_workers,\n                )\n        else:\n            # load all splits for annotation\n            raw_datasets = load_dataset(\n                data_args.dataset_name,\n                data_args.dataset_config_name,\n                cache_dir=model_args.cache_dir,\n                token=model_args.token,\n                num_proc=data_args.preprocessing_num_workers,\n            )\n\n        raw_datasets_features = set(\n            raw_datasets[next(iter(raw_datasets))].features.keys()\n        )\n\n        if data_args.max_eval_samples is not None:\n            # cap each split, guarding against splits smaller than max_eval_samples\n            for split in raw_datasets:\n                raw_datasets[split] = raw_datasets[split].select(\n                    range(min(len(raw_datasets[split]), data_args.max_eval_samples))\n                
)\n\n        if not EXPECTED_COLUMNS.issubset(raw_datasets_features):\n            missing_columns = EXPECTED_COLUMNS - raw_datasets_features\n            raise ValueError(\n                f\"Missing columns {missing_columns} from the dataset features. Got dataset features {raw_datasets_features}\"\n            )\n\n        for split in raw_datasets:\n            total_samples = len(raw_datasets[split])\n            total_inference_steps = math.ceil(total_samples / model_args.checkpoint_interval)\n\n            split_output_dir = os.path.join(data_args.output_dir, split)\n            progress_bar = tqdm(range(total_inference_steps), desc=f\"{split}\", position=0)\n\n            all_generated_ids, inference_step = get_last_checkpoint(split_output_dir)\n            if inference_step > 0:\n                logger.info(f\"Resuming {split} from step {inference_step}\")\n                progress_bar.update(inference_step)\n\n            while inference_step < total_inference_steps:\n                start_index = inference_step * model_args.checkpoint_interval\n                end_index = min((inference_step + 1) * model_args.checkpoint_interval, total_samples)\n                inference_chunk = raw_datasets[split].select(range(start_index, end_index))\n                results = await asyncio.gather(\n                    *(process_text(sample) for sample in inference_chunk)\n                )\n                inference_step += 1\n                progress_bar.update(1)\n                all_generated_ids.extend(results)\n\n                if (inference_step % data_args.save_steps == 0) or (inference_step == total_inference_steps):\n                    logger.info(f\"Saving generations of step {inference_step}\")\n                    save_checkpoint(split_output_dir, all_generated_ids, inference_step)\n                    rotate_checkpoints(data_args.save_total_limit, output_dir=split_output_dir)\n\n            raw_datasets[split] = raw_datasets[split].add_column(\n          
      \"text_description\", all_generated_ids\n            )\n\n        raw_datasets.save_to_disk(data_args.output_dir)\n        if data_args.push_to_hub:\n            raw_datasets.push_to_hub(\n                data_args.hub_dataset_id,\n                config_name=(\n                    data_args.dataset_config_name\n                    if data_args.dataset_config_name is not None\n                    else \"default\"\n                ),\n                token=model_args.token,\n            )\n\n    asyncio.run(main())\n"
  }
]