SYMBOL INDEX (1239 symbols across 112 files) FILE: app.py function generate_random_string (line 66) | def generate_random_string(length): function resize_video (line 71) | def resize_video(input_path, output_path, target_width, target_height): function _call_nostderr (line 83) | def _call_nostderr(*args, **kwargs): function interrupt (line 96) | def interrupt(): class FileCleaner (line 101) | class FileCleaner: method __init__ (line 102) | def __init__(self, file_lifetime: float = 3600): method add (line 106) | def add(self, path: tp.Union[str, Path]): method _cleanup (line 110) | def _cleanup(self): function make_waveform (line 124) | def make_waveform(*args, **kwargs): function load_model (line 146) | def load_model(version='GrandaddyShmax/musicgen-melody', custom_model=No... function get_audio_info (line 180) | def get_audio_info(audio_path): function info_to_params (line 250) | def info_to_params(audio_path): function info_to_params_a (line 348) | def info_to_params_a(audio_path): function make_pseudo_stereo (line 435) | def make_pseudo_stereo (filename, sr_select, pan, delay): function normalize_audio (line 453) | def normalize_audio(audio_data): function load_diffusion (line 460) | def load_diffusion(): function unload_diffusion (line 467) | def unload_diffusion(): function _do_predictions (line 474) | def _do_predictions(gen_type, texts, melodies, sample, trim_start, trim_... function predict_batched (line 638) | def predict_batched(texts, melodies): function add_tags (line 646) | def add_tags(filename, tags): function save_outputs (line 686) | def save_outputs(mp4, wav_tmp, tags, gen_type): function clear_cash (line 735) | def clear_cash(): function s2t (line 766) | def s2t(seconds, seconds2): function calc_time (line 777) | def calc_time(gen_type, s, duration, overlap, d0, d1, d2, d3, d4, d5, d6... function predict_full (line 816) | def predict_full(gen_type, model, decoder, custom_model, prompt_amount, ... function get_available_folders (line 935) | def get_available_folders(): function toggle_audio_src (line 941) | def toggle_audio_src(choice): function ui_full (line 948) | def ui_full(launch_kwargs): function ui_batched (line 1695) | def ui_batched(launch_kwargs): FILE: audiocraft/adversarial/discriminators/base.py class MultiDiscriminator (line 19) | class MultiDiscriminator(ABC, nn.Module): method __init__ (line 22) | def __init__(self): method forward (line 26) | def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType: method num_discriminators (line 31) | def num_discriminators(self) -> int: FILE: audiocraft/adversarial/discriminators/mpd.py function get_padding (line 17) | def get_padding(kernel_size: int, dilation: int = 1) -> int: class PeriodDiscriminator (line 21) | class PeriodDiscriminator(nn.Module): method __init__ (line 38) | def __init__(self, period: int, in_channels: int = 1, out_channels: in... method forward (line 58) | def forward(self, x: torch.Tensor): class MultiPeriodDiscriminator (line 79) | class MultiPeriodDiscriminator(MultiDiscriminator): method __init__ (line 88) | def __init__(self, in_channels: int = 1, out_channels: int = 1, method num_discriminators (line 96) | def num_discriminators(self): method forward (line 99) | def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType: FILE: audiocraft/adversarial/discriminators/msd.py class ScaleDiscriminator (line 17) | class ScaleDiscriminator(nn.Module): method __init__ (line 37) | def __init__(self, in_channels=1, out_channels=1, kernel_sizes: tp.Seq... method forward (line 83) | def forward(self, x: torch.Tensor): class MultiScaleDiscriminator (line 95) | class MultiScaleDiscriminator(MultiDiscriminator): method __init__ (line 105) | def __init__(self, in_channels: int = 1, out_channels: int = 1, downsa... method num_discriminators (line 114) | def num_discriminators(self): method forward (line 117) | def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType: FILE: audiocraft/adversarial/discriminators/msstftd.py function get_2d_padding (line 18) | def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[i... class DiscriminatorSTFT (line 22) | class DiscriminatorSTFT(nn.Module): method __init__ (line 41) | def __init__(self, filters: int, in_channels: int = 1, out_channels: i... method forward (line 81) | def forward(self, x: torch.Tensor): class MultiScaleSTFTDiscriminator (line 94) | class MultiScaleSTFTDiscriminator(MultiDiscriminator): method __init__ (line 107) | def __init__(self, filters: int, in_channels: int = 1, out_channels: i... method num_discriminators (line 120) | def num_discriminators(self): method _separate_channels (line 123) | def _separate_channels(self, x: torch.Tensor) -> torch.Tensor: method forward (line 127) | def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType: FILE: audiocraft/adversarial/losses.py class AdversarialLoss (line 26) | class AdversarialLoss(nn.Module): method __init__ (line 49) | def __init__(self, method _save_to_state_dict (line 67) | def _save_to_state_dict(self, destination, prefix, keep_vars): method _load_from_state_dict (line 73) | def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): method get_adversary_pred (line 78) | def get_adversary_pred(self, x): method train_adv (line 89) | def train_adv(self, fake: torch.Tensor, real: torch.Tensor) -> torch.T... method forward (line 115) | def forward(self, fake: torch.Tensor, real: torch.Tensor) -> tp.Tuple[... function get_adv_criterion (line 138) | def get_adv_criterion(loss_type: str) -> tp.Callable: function get_fake_criterion (line 149) | def get_fake_criterion(loss_type: str) -> tp.Callable: function get_real_criterion (line 158) | def get_real_criterion(loss_type: str) -> tp.Callable: function mse_real_loss (line 167) | def mse_real_loss(x: torch.Tensor) -> torch.Tensor: function mse_fake_loss (line 171) | def mse_fake_loss(x: torch.Tensor) -> torch.Tensor: function hinge_real_loss (line 175) | def hinge_real_loss(x: torch.Tensor) -> torch.Tensor: function hinge_fake_loss (line 179) | def hinge_fake_loss(x: torch.Tensor) -> torch.Tensor: function mse_loss (line 183) | def mse_loss(x: torch.Tensor) -> torch.Tensor: function hinge_loss (line 189) | def hinge_loss(x: torch.Tensor) -> torch.Tensor: function hinge2_loss (line 195) | def hinge2_loss(x: torch.Tensor) -> torch.Tensor: class FeatureMatchingLoss (line 201) | class FeatureMatchingLoss(nn.Module): method __init__ (line 209) | def __init__(self, loss: nn.Module = torch.nn.L1Loss(), normalize: boo... method forward (line 214) | def forward(self, fmap_fake: tp.List[torch.Tensor], fmap_real: tp.List... FILE: audiocraft/data/audio.py function _init_av (line 31) | def _init_av(): class AudioFileInfo (line 41) | class AudioFileInfo: function _av_info (line 47) | def _av_info(filepath: tp.Union[str, Path]) -> AudioFileInfo: function _soundfile_info (line 57) | def _soundfile_info(filepath: tp.Union[str, Path]) -> AudioFileInfo: function audio_info (line 62) | def audio_info(filepath: tp.Union[str, Path]) -> AudioFileInfo: function _av_read (line 72) | def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, durati... function audio_read (line 116) | def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0., function audio_write (line 153) | def audio_write(stem_name: tp.Union[str, Path], FILE: audiocraft/data/audio_dataset.py class BaseInfo (line 39) | class BaseInfo: method _dict2fields (line 42) | def _dict2fields(cls, dictionary: dict): method from_dict (line 49) | def from_dict(cls, dictionary: dict): method to_dict (line 53) | def to_dict(self): class AudioMeta (line 61) | class AudioMeta(BaseInfo): method from_dict (line 71) | def from_dict(cls, dictionary: dict): method to_dict (line 77) | def to_dict(self): class SegmentInfo (line 85) | class SegmentInfo(BaseInfo): function _get_audio_meta (line 101) | def _get_audio_meta(file_path: str, minimal: bool = True) -> AudioMeta: function _resolve_audio_meta (line 118) | def _resolve_audio_meta(m: AudioMeta, fast: bool = True) -> AudioMeta: function find_audio_files (line 145) | def find_audio_files(path: tp.Union[Path, str], function load_audio_meta (line 204) | def load_audio_meta(path: tp.Union[str, Path], function save_audio_meta (line 228) | def save_audio_meta(path: tp.Union[str, Path], meta: tp.List[AudioMeta]): class AudioDataset (line 244) | class AudioDataset: method __init__ (line 295) | def __init__(self, method start_epoch (line 350) | def start_epoch(self, epoch: int): method __len__ (line 353) | def __len__(self): method _get_sampling_probabilities (line 356) | def _get_sampling_probabilities(self, normalized: bool = True): method _get_file_permutation (line 373) | def _get_file_permutation(num_files: int, permutation_index: int, base... method sample_file (line 380) | def sample_file(self, index: int, rng: torch.Generator) -> AudioMeta: method _audio_read (line 404) | def _audio_read(self, path: str, seek_time: float = 0, duration: float... method __getitem__ (line 413) | def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[t... method collater (line 462) | def collater(self, samples): method _filter_duration (line 502) | def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioM... method from_meta (line 524) | def from_meta(cls, root: tp.Union[str, Path], **kwargs): method from_path (line 544) | def from_path(cls, root: tp.Union[str, Path], minimal_meta: bool = True, function main (line 562) | def main(): FILE: audiocraft/data/audio_utils.py function convert_audio_channels (line 16) | def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torc... function convert_audio (line 49) | def convert_audio(wav: torch.Tensor, from_rate: float, function normalize_loudness (line 57) | def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_hea... function _clip_wav (line 86) | def _clip_wav(wav: torch.Tensor, log_clipping: bool = False, stem_name: ... function normalize_audio (line 97) | def normalize_audio(wav: torch.Tensor, normalize: bool = True, function f32_pcm (line 149) | def f32_pcm(wav: torch.Tensor) -> torch.Tensor: function i16_pcm (line 161) | def i16_pcm(wav: torch.Tensor) -> torch.Tensor: FILE: audiocraft/data/info_audio_dataset.py function _clusterify_meta (line 25) | def _clusterify_meta(meta: AudioMeta) -> AudioMeta: function clusterify_all_meta (line 33) | def clusterify_all_meta(meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]: class AudioInfo (line 39) | class AudioInfo(SegmentWithAttributes): method to_condition_attributes (line 50) | def to_condition_attributes(self) -> ConditioningAttributes: class InfoAudioDataset (line 54) | class InfoAudioDataset(AudioDataset): method __init__ (line 59) | def __init__(self, meta: tp.List[AudioMeta], **kwargs): method __getitem__ (line 62) | def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[t... function get_keyword_or_keyword_list (line 71) | def get_keyword_or_keyword_list(value: tp.Optional[str]) -> tp.Union[tp.... function get_string (line 79) | def get_string(value: tp.Optional[str]) -> tp.Optional[str]: function get_keyword (line 87) | def get_keyword(value: tp.Optional[str]) -> tp.Optional[str]: function get_keyword_list (line 95) | def get_keyword_list(values: tp.Union[str, tp.List[str]]) -> tp.Optional... FILE: audiocraft/data/music_dataset.py class MusicInfo (line 37) | class MusicInfo(AudioInfo): method has_music_meta (line 57) | def has_music_meta(self) -> bool: method to_condition_attributes (line 60) | def to_condition_attributes(self) -> ConditioningAttributes: method attribute_getter (line 76) | def attribute_getter(attribute): method from_dict (line 92) | def from_dict(cls, dictionary: dict, fields_required: bool = False): function augment_music_info_description (line 115) | def augment_music_info_description(music_info: MusicInfo, merge_text_p: ... class Paraphraser (line 167) | class Paraphraser: method __init__ (line 168) | def __init__(self, paraphrase_source: tp.Union[str, Path], paraphrase_... method sample_paraphrase (line 175) | def sample_paraphrase(self, audio_path: str, description: str): class MusicDataset (line 187) | class MusicDataset(InfoAudioDataset): method __init__ (line 204) | def __init__(self, *args, info_fields_required: bool = True, method __getitem__ (line 220) | def __getitem__(self, index): function get_musical_key (line 252) | def get_musical_key(value: tp.Optional[str]) -> tp.Optional[str]: function get_bpm (line 263) | def get_bpm(value: tp.Optional[str]) -> tp.Optional[float]: FILE: audiocraft/data/sound_dataset.py class SoundInfo (line 35) | class SoundInfo(SegmentWithAttributes): method has_sound_meta (line 42) | def has_sound_meta(self) -> bool: method to_condition_attributes (line 45) | def to_condition_attributes(self) -> ConditioningAttributes: method attribute_getter (line 57) | def attribute_getter(attribute): method from_dict (line 65) | def from_dict(cls, dictionary: dict, fields_required: bool = False): class SoundDataset (line 87) | class SoundDataset(InfoAudioDataset): method __init__ (line 104) | def __init__( method _get_info_path (line 129) | def _get_info_path(self, path: tp.Union[str, Path]) -> Path: method __getitem__ (line 142) | def __getitem__(self, index): method collater (line 163) | def collater(self, samples): function rms_f (line 173) | def rms_f(x: torch.Tensor) -> torch.Tensor: function normalize (line 177) | def normalize(audio: torch.Tensor, target_level: int = -25) -> torch.Ten... function is_clipped (line 185) | def is_clipped(audio: torch.Tensor, clipping_threshold: float = 0.99) ->... function mix_pair (line 189) | def mix_pair(src: torch.Tensor, dst: torch.Tensor, min_overlap: float) -... function snr_mixer (line 199) | def snr_mixer(clean: torch.Tensor, noise: torch.Tensor, snr: int, min_ov... function snr_mix (line 252) | def snr_mix(src: torch.Tensor, dst: torch.Tensor, snr_low: int, snr_high... function mix_text (line 261) | def mix_text(src_text: str, dst_text: str): function mix_samples (line 268) | def mix_samples(wavs: torch.Tensor, infos: tp.List[SoundInfo], aug_p: fl... FILE: audiocraft/data/zip.py class PathInZip (line 22) | class PathInZip: method __init__ (line 36) | def __init__(self, path: str) -> None: method from_paths (line 42) | def from_paths(cls, zip_path: str, file_path: str): method __str__ (line 45) | def __str__(self) -> str: function _open_zip (line 49) | def _open_zip(path: str, mode: MODE = 'r'): function set_zip_cache_size (line 56) | def set_zip_cache_size(max_size: int): function open_file_in_zip (line 66) | def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') -> typing.IO: FILE: audiocraft/environment.py class AudioCraftEnvironment (line 25) | class AudioCraftEnvironment: method __init__ (line 49) | def __init__(self) -> None: method _get_cluster_config (line 74) | def _get_cluster_config(self) -> omegaconf.DictConfig: method instance (line 79) | def instance(cls): method reset (line 85) | def reset(cls): method get_team (line 90) | def get_team(cls) -> str: method get_cluster (line 97) | def get_cluster(cls) -> str: method get_dora_dir (line 104) | def get_dora_dir(cls) -> Path: method get_reference_dir (line 114) | def get_reference_dir(cls) -> Path: method get_slurm_exclude (line 122) | def get_slurm_exclude(cls) -> tp.Optional[str]: method get_slurm_partitions (line 128) | def get_slurm_partitions(cls, partition_types: tp.Optional[tp.List[str... method resolve_reference_path (line 146) | def resolve_reference_path(cls, path: tp.Union[str, Path]) -> Path: method apply_dataset_mappers (line 167) | def apply_dataset_mappers(cls, path: str) -> str: FILE: audiocraft/grids/_base_explorers.py function get_sheep_ping (line 14) | def get_sheep_ping(sheep) -> tp.Optional[str]: class BaseExplorer (line 31) | class BaseExplorer(ABC, Explorer): method stages (line 40) | def stages(self): method get_grid_meta (line 43) | def get_grid_meta(self): method get_grid_metrics (line 55) | def get_grid_metrics(self): method process_sheep (line 60) | def process_sheep(self, sheep, history): FILE: audiocraft/grids/audiogen/audiogen_base_16khz.py function explorer (line 12) | def explorer(launcher): FILE: audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.py function eval (line 26) | def eval(launcher, batch_size: int = 32): function explorer (line 49) | def explorer(launcher): FILE: audiocraft/grids/compression/_explorers.py class CompressionExplorer (line 12) | class CompressionExplorer(BaseExplorer): method stages (line 15) | def stages(self): method get_grid_meta (line 18) | def get_grid_meta(self): method get_grid_metrics (line 28) | def get_grid_metrics(self): FILE: audiocraft/grids/compression/debug.py function explorer (line 22) | def explorer(launcher): FILE: audiocraft/grids/compression/encodec_audiogen_16khz.py function explorer (line 20) | def explorer(launcher): FILE: audiocraft/grids/compression/encodec_base_24khz.py function explorer (line 20) | def explorer(launcher): FILE: audiocraft/grids/compression/encodec_musicgen_32khz.py function explorer (line 20) | def explorer(launcher): FILE: audiocraft/grids/diffusion/4_bands_base_32khz.py function explorer (line 17) | def explorer(launcher): FILE: audiocraft/grids/diffusion/_explorers.py class DiffusionExplorer (line 12) | class DiffusionExplorer(BaseExplorer): method stages (line 15) | def stages(self): method get_grid_meta (line 18) | def get_grid_meta(self): method get_grid_metrics (line 28) | def get_grid_metrics(self): FILE: audiocraft/grids/musicgen/_explorers.py class LMExplorer (line 14) | class LMExplorer(BaseExplorer): method stages (line 17) | def stages(self) -> tp.List[str]: method get_grid_metrics (line 20) | def get_grid_metrics(self): method process_sheep (line 45) | def process_sheep(self, sheep, history): class GenerationEvalExplorer (line 69) | class GenerationEvalExplorer(BaseExplorer): method stages (line 72) | def stages(self) -> tp.List[str]: method get_grid_metrics (line 75) | def get_grid_metrics(self): FILE: audiocraft/grids/musicgen/musicgen_base_32khz.py function explorer (line 12) | def explorer(launcher): FILE: audiocraft/grids/musicgen/musicgen_base_cached_32khz.py function explorer (line 12) | def explorer(launcher): FILE: audiocraft/grids/musicgen/musicgen_clapemb_32khz.py function explorer (line 12) | def explorer(launcher): FILE: audiocraft/grids/musicgen/musicgen_melody_32khz.py function explorer (line 12) | def explorer(launcher): FILE: audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py function eval (line 26) | def eval(launcher, batch_size: int = 32, eval_melody: bool = False): function explorer (line 63) | def explorer(launcher): FILE: audiocraft/losses/balancer.py class Balancer (line 14) | class Balancer: method __init__ (line 61) | def __init__(self, weights: tp.Dict[str, float], balance_grads: bool =... method metrics (line 74) | def metrics(self): method backward (line 77) | def backward(self, losses: tp.Dict[str, torch.Tensor], input: torch.Te... FILE: audiocraft/losses/sisnr.py function _unfold (line 15) | def _unfold(a: torch.Tensor, kernel_size: int, stride: int) -> torch.Ten... function _center (line 31) | def _center(x: torch.Tensor) -> torch.Tensor: function _norm2 (line 35) | def _norm2(x: torch.Tensor) -> torch.Tensor: class SISNR (line 39) | class SISNR(nn.Module): method __init__ (line 51) | def __init__( method forward (line 64) | def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> tor... FILE: audiocraft/losses/specloss.py class MelSpectrogramWrapper (line 18) | class MelSpectrogramWrapper(nn.Module): method __init__ (line 35) | def __init__(self, n_fft: int = 1024, hop_length: int = 256, win_lengt... method forward (line 48) | def forward(self, x): class MelSpectrogramL1Loss (line 65) | class MelSpectrogramL1Loss(torch.nn.Module): method __init__ (line 80) | def __init__(self, sample_rate: int, n_fft: int = 1024, hop_length: in... method forward (line 89) | def forward(self, x, y): class MultiScaleMelSpectrogramLoss (line 96) | class MultiScaleMelSpectrogramLoss(nn.Module): method __init__ (line 110) | def __init__(self, sample_rate: int, range_start: int = 6, range_end: ... method forward (line 137) | def forward(self, x, y): FILE: audiocraft/losses/stftloss.py function _stft (line 17) | def _stft(x: torch.Tensor, fft_size: int, hop_length: int, win_length: int, class SpectralConvergenceLoss (line 45) | class SpectralConvergenceLoss(nn.Module): method __init__ (line 48) | def __init__(self, epsilon: float = torch.finfo(torch.float32).eps): method forward (line 52) | def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor): class LogSTFTMagnitudeLoss (line 64) | class LogSTFTMagnitudeLoss(nn.Module): method __init__ (line 70) | def __init__(self, epsilon: float = torch.finfo(torch.float32).eps): method forward (line 74) | def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor): class STFTLosses (line 86) | class STFTLosses(nn.Module): method __init__ (line 97) | def __init__(self, n_fft: int = 1024, hop_length: int = 120, win_lengt... method forward (line 109) | def forward(self, x: torch.Tensor, y: torch.Tensor) -> tp.Tuple[torch.... class STFTLoss (line 129) | class STFTLoss(nn.Module): method __init__ (line 142) | def __init__(self, n_fft: int = 1024, hop_length: int = 120, win_lengt... method forward (line 151) | def forward(self, x: torch.Tensor, y: torch.Tensor) -> tp.Tuple[torch.... class MRSTFTLoss (line 164) | class MRSTFTLoss(nn.Module): method __init__ (line 177) | def __init__(self, n_ffts: tp.Sequence[int] = [1024, 2048, 512], hop_l... method forward (line 189) | def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: FILE: audiocraft/metrics/chroma_cosinesim.py class ChromaCosineSimilarityMetric (line 14) | class ChromaCosineSimilarityMetric(torchmetrics.Metric): method __init__ (line 28) | def __init__(self, sample_rate: int, n_chroma: int, radix2_exp: int, a... method update (line 38) | def update(self, preds: torch.Tensor, targets: torch.Tensor, method compute (line 69) | def compute(self) -> float: FILE: audiocraft/metrics/clap_consistency.py class TextConsistencyMetric (line 24) | class TextConsistencyMetric(torchmetrics.Metric): method update (line 27) | def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch... method compute (line 30) | def compute(self): class CLAPTextConsistencyMetric (line 34) | class CLAPTextConsistencyMetric(TextConsistencyMetric): method __init__ (line 47) | def __init__(self, model_path: tp.Union[str, Path], model_arch: str = ... method _initialize_model (line 55) | def _initialize_model(self, model_path: tp.Union[str, Path], model_arc... method _tokenizer (line 63) | def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict: method update (line 67) | def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch... method compute (line 81) | def compute(self): FILE: audiocraft/metrics/fad.py class FrechetAudioDistanceMetric (line 29) | class FrechetAudioDistanceMetric(torchmetrics.Metric): method __init__ (line 145) | def __init__(self, bin: tp.Union[Path, str], model_path: tp.Union[Path... method reset (line 167) | def reset(self, log_folder: tp.Optional[tp.Union[Path, str]] = None): method update (line 182) | def update(self, preds: torch.Tensor, targets: torch.Tensor, method _get_samples_name (line 222) | def _get_samples_name(self, is_background: bool): method _create_embedding_beams (line 225) | def _create_embedding_beams(self, is_background: bool, gpu_index: tp.O... method _compute_fad_score (line 259) | def _compute_fad_score(self, gpu_index: tp.Optional[int] = None): method _log_process_result (line 283) | def _log_process_result(self, returncode: int, log_file: tp.Union[Path... method _parallel_create_embedding_beams (line 293) | def _parallel_create_embedding_beams(self, num_of_gpus: int): method _sequential_create_embedding_beams (line 303) | def _sequential_create_embedding_beams(self): method _local_compute_frechet_audio_distance (line 313) | def _local_compute_frechet_audio_distance(self): method compute (line 323) | def compute(self) -> float: FILE: audiocraft/metrics/kld.py class _patch_passt_stft (line 22) | class _patch_passt_stft: method __init__ (line 24) | def __init__(self): method __enter__ (line 27) | def __enter__(self): method __exit__ (line 32) | def __exit__(self, *exc): function kl_divergence (line 36) | def kl_divergence(pred_probs: torch.Tensor, target_probs: torch.Tensor, ... class KLDivergenceMetric (line 53) | class KLDivergenceMetric(torchmetrics.Metric): method __init__ (line 62) | def __init__(self): method _get_label_distribution (line 69) | def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor, method update (line 82) | def update(self, preds: torch.Tensor, targets: torch.Tensor, method compute (line 105) | def compute(self) -> dict: class PasstKLDivergenceMetric (line 116) | class PasstKLDivergenceMetric(KLDivergenceMetric): method __init__ (line 131) | def __init__(self, pretrained_length: tp.Optional[float] = None): method _initialize_model (line 135) | def _initialize_model(self, pretrained_length: tp.Optional[float] = No... method _load_base_model (line 145) | def _load_base_model(self, pretrained_length: tp.Optional[float]): method _process_audio (line 172) | def _process_audio(self, wav: torch.Tensor, sample_rate: int, wav_len:... method _get_model_preds (line 187) | def _get_model_preds(self, wav: torch.Tensor) -> torch.Tensor: method _get_label_distribution (line 198) | def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor, FILE: audiocraft/metrics/rvm.py function db_to_scale (line 13) | def db_to_scale(volume: tp.Union[float, torch.Tensor]): function scale_to_db (line 17) | def scale_to_db(scale: torch.Tensor, min_volume: float = -120): class RelativeVolumeMel (line 22) | class RelativeVolumeMel(nn.Module): method __init__ (line 69) | def __init__(self, sample_rate: int = 24000, n_mels: int = 80, n_fft: ... method forward (line 84) | def forward(self, estimate: torch.Tensor, ground_truth: torch.Tensor) ... FILE: audiocraft/metrics/visqol.py class ViSQOL (line 22) | class ViSQOL: method __init__ (line 56) | def __init__(self, bin: tp.Union[Path, str], mode: str = "audio", method _get_target_sr (line 67) | def _get_target_sr(self, mode: str) -> int: method _prepare_files (line 75) | def _prepare_files( method _flush_files (line 132) | def _flush_files(self, tmp_dir: tp.Union[Path, str]): method _collect_moslqo_score (line 136) | def _collect_moslqo_score(self, results_csv_path: tp.Union[Path, str])... method _collect_debug_data (line 146) | def _collect_debug_data(self, debug_json_path: tp.Union[Path, str]) ->... method visqol_model (line 153) | def visqol_model(self): method _run_visqol (line 156) | def _run_visqol( method __call__ (line 181) | def __call__( FILE: audiocraft/models/audiogen.py class AudioGen (line 25) | class AudioGen: method __init__ (line 36) | def __init__(self, name: str, compression_model: CompressionModel, lm:... method frame_rate (line 59) | def frame_rate(self) -> float: method sample_rate (line 64) | def sample_rate(self) -> int: method audio_channels (line 69) | def audio_channels(self) -> int: method get_pretrained (line 74) | def get_pretrained(name: str = 'facebook/audiogen-medium', device=None): method set_generation_params (line 97) | def set_generation_params(self, use_sampling: bool = True, top_k: int ... method set_custom_progress_callback (line 129) | def set_custom_progress_callback(self, progress_callback: tp.Optional[... method generate (line 133) | def generate(self, descriptions: tp.List[str], progress: bool = False)... method generate_continuation (line 144) | def generate_continuation(self, prompt: torch.Tensor, prompt_sample_ra... method _prepare_tokens_and_attributes (line 168) | def _prepare_tokens_and_attributes( method _generate_tokens (line 193) | def _generate_tokens(self, attributes: tp.List[ConditioningAttributes], method to (line 273) | def to(self, device: str): FILE: audiocraft/models/builders.py function get_quantizer (line 43) | def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: ... function get_encodec_autoencoder (line 54) | def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig): function get_compression_model (line 68) | def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel: function get_lm_model (line 86) | def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel: function get_conditioner_provider (line 122) | def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig)... function get_condition_fuser (line 159) | def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser: function get_codebooks_pattern_provider (line 169) | def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) ... function get_debug_compression_model (line 184) | def get_debug_compression_model(device='cpu', sample_rate: int = 32000): function get_diffusion_model (line 211) | def get_diffusion_model(cfg: omegaconf.DictConfig): function get_processor (line 219) | def get_processor(cfg, sample_rate: int = 24000): function get_debug_lm_model (line 230) | def get_debug_lm_model(device='cpu'): function get_wrapped_compression_model (line 248) | def get_wrapped_compression_model( FILE: audiocraft/models/encodec.py class CompressionModel (line 27) | class CompressionModel(ABC, nn.Module): method forward (line 33) | def forward(self, x: torch.Tensor) -> qt.QuantizedResult: method encode (line 37) | def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optiona... method decode (line 42) | def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor]... method decode_latent (line 47) | def decode_latent(self, codes: torch.Tensor): method channels (line 53) | def channels(self) -> int: method frame_rate (line 58) | def frame_rate(self) -> float: method sample_rate (line 63) | def sample_rate(self) -> int: method cardinality (line 68) | def cardinality(self) -> int: method num_codebooks (line 73) | def num_codebooks(self) -> int: method total_codebooks (line 78) | def total_codebooks(self) -> int: method set_num_codebooks (line 82) | def set_num_codebooks(self, n: int): method get_pretrained (line 87) | def get_pretrained( class EncodecModel (line 124) | class EncodecModel(CompressionModel): method __init__ (line 143) | def __init__(self, method total_codebooks (line 167) | def total_codebooks(self): method num_codebooks (line 172) | def num_codebooks(self): method set_num_codebooks (line 176) | def set_num_codebooks(self, n: int): method cardinality (line 181) | def cardinality(self): method preprocess (line 185) | def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Opt... method postprocess (line 197) | def postprocess(self, method forward (line 205) | def forward(self, x: torch.Tensor) -> qt.QuantizedResult: method encode (line 222) | def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optiona... method decode (line 239) | def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor]... method decode_latent (line 256) | def decode_latent(self, codes: torch.Tensor): class DAC (line 261) | class DAC(CompressionModel): method __init__ (line 262) | def __init__(self, model_type: str = "44khz"): method forward (line 273) | def forward(self, x: torch.Tensor) -> qt.QuantizedResult: method encode (line 277) | def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optiona... method decode (line 281) | def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor]... method decode_latent (line 286) | def decode_latent(self, codes: torch.Tensor): method channels (line 291) | def channels(self) -> int: method frame_rate (line 295) | def frame_rate(self) -> float: method sample_rate (line 299) | def sample_rate(self) -> int: method cardinality (line 303) | def cardinality(self) -> int: method num_codebooks (line 307) | def num_codebooks(self) -> int: method total_codebooks (line 311) | def total_codebooks(self) -> int: method set_num_codebooks (line 314) | def set_num_codebooks(self, n: int): class HFEncodecCompressionModel (line 322) | class HFEncodecCompressionModel(CompressionModel): method __init__ (line 325) | def __init__(self, model: HFEncodecModel): method forward (line 339) | def forward(self, x: torch.Tensor) -> qt.QuantizedResult: method encode (line 343) | def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optiona... method decode (line 351) | def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor]... method decode_latent (line 359) | def decode_latent(self, codes: torch.Tensor): method channels (line 364) | def channels(self) -> int: method frame_rate (line 368) | def frame_rate(self) -> float: method sample_rate (line 373) | def sample_rate(self) -> int: method cardinality (line 377) | def cardinality(self) -> int: method num_codebooks (line 381) | def num_codebooks(self) -> int: method total_codebooks (line 385) | def total_codebooks(self) -> int: method set_num_codebooks (line 388) | def set_num_codebooks(self, n: int): FILE: audiocraft/models/lm.py function get_init_fn (line 36) | def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int... function init_layer (line 64) | def init_layer(m: nn.Module, class ScaledEmbedding (line 97) | class ScaledEmbedding(nn.Embedding): method __init__ (line 100) | def __init__(self, *args, lr=None, **kwargs): method make_optim_group (line 104) | def make_optim_group(self): class LMOutput (line 112) | class LMOutput: class LMModel (line 119) | class LMModel(StreamingModule): method __init__ (line 144) | def __init__(self, pattern_provider: CodebooksPatternProvider, conditi... method _init_weights (line 178) | def _init_weights(self, weight_init: tp.Optional[str], depthwise_init:... method special_token_id (line 213) | def special_token_id(self) -> int: method num_codebooks (line 217) | def num_codebooks(self) -> int: method forward (line 220) | def forward(self, sequence: torch.Tensor, method compute_predictions (line 264) | def compute_predictions( method _sample_next_token (line 309) | def _sample_next_token(self, method generate (line 381) | def generate(self, FILE: audiocraft/models/loaders.py function get_audiocraft_cache_dir (line 34) | def get_audiocraft_cache_dir() -> tp.Optional[str]: function _get_state_dict (line 38) | def _get_state_dict( function load_compression_model_ckpt (line 67) | def load_compression_model_ckpt(file_or_url_or_id: tp.Union[Path, str], ... function load_compression_model (line 71) | def load_compression_model(file_or_url_or_id: tp.Union[Path, str], devic... function load_lm_model_ckpt (line 83) | def load_lm_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir... function _delete_param (line 87) | def _delete_param(cfg: DictConfig, full_name: str): function load_lm_model (line 100) | def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', ... function load_mbd_ckpt (line 118) | def load_mbd_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.... function load_diffusion_models (line 122) | def load_diffusion_models(file_or_url_or_id: tp.Union[Path, str], device... FILE: audiocraft/models/multibanddiffusion.py class DiffusionProcess (line 25) | class DiffusionProcess: method __init__ (line 32) | def __init__(self, model: DiffusionUnet, noise_schedule: NoiseSchedule... method generate (line 38) | def generate(self, condition: torch.Tensor, initial_noise: torch.Tensor, class MultiBandDiffusion (line 50) | class MultiBandDiffusion: method __init__ (line 57) | def __init__(self, DPs: tp.List[DiffusionProcess], codec_model: Compre... method sample_rate (line 63) | def sample_rate(self) -> int: method get_mbd_musicgen (line 67) | def get_mbd_musicgen(device=None): method get_mbd_24khz (line 82) | def get_mbd_24khz(bw: float = 3.0, pretrained: bool = True, method get_condition (line 116) | def get_condition(self, wav: torch.Tensor, sample_rate: int) -> torch.... method get_emb (line 129) | def get_emb(self, codes: torch.Tensor): method generate (line 136) | def generate(self, emb: torch.Tensor, size: tp.Optional[torch.Size] = ... method re_eq (line 154) | def re_eq(self, wav: torch.Tensor, ref: torch.Tensor, n_bands: int = 3... method regenerate (line 170) | def regenerate(self, wav: torch.Tensor, sample_rate: int): method tokens_to_wav (line 185) | def tokens_to_wav(self, tokens: torch.Tensor, n_bands: int = 32): FILE: audiocraft/models/musicgen.py class MusicGen (line 39) | class MusicGen: method __init__ (line 50) | def __init__(self, name: str, compression_model: CompressionModel, lm:... method frame_rate (line 73) | def frame_rate(self) -> float: method sample_rate (line 78) | def sample_rate(self) -> int: method audio_channels (line 83) | def audio_channels(self) -> int: method get_pretrained (line 88) | def get_pretrained(name: str = 'GrandaddyShmax/musicgen-melody', devic... method set_generation_params (line 118) | def set_generation_params(self, use_sampling: bool = True, top_k: int ... method set_custom_progress_callback (line 150) | def set_custom_progress_callback(self, progress_callback: tp.Optional[... method generate_unconditional (line 154) | def generate_unconditional(self, num_samples: int, progress: bool = Fa... method generate (line 168) | def generate(self, descriptions: tp.List[str], progress: bool = False,... method generate_with_chroma (line 183) | def generate_with_chroma(self, descriptions: tp.List[str], melody_wavs... method generate_continuation (line 218) | def generate_continuation(self, prompt: torch.Tensor, prompt_sample_ra... method _prepare_tokens_and_attributes (line 246) | def _prepare_tokens_and_attributes( method _generate_tokens (line 303) | def _generate_tokens(self, attributes: tp.List[ConditioningAttributes], method generate_audio (line 399) | def generate_audio(self, gen_tokens: torch.Tensor): method to (line 406) | def to(self, device: str): FILE: audiocraft/models/unet.py class Output (line 21) | class Output: function get_model (line 25) | def get_model(cfg, channels: int, side: int, num_steps: int): class ResBlock (line 33) | class ResBlock(nn.Module): method __init__ (line 34) | def __init__(self, channels: int, kernel: int = 3, norm_groups: int = 4, method forward (line 52) | def forward(self, x): class DecoderLayer (line 58) | class DecoderLayer(nn.Module): method __init__ (line 59) | def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int... method forward (line 72) | def forward(self, x: torch.Tensor) -> torch.Tensor: class EncoderLayer (line 80) | class EncoderLayer(nn.Module): method __init__ (line 81) | def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int... method forward (line 94) | def forward(self, x: torch.Tensor) -> torch.Tensor: class BLSTM (line 107) | class BLSTM(nn.Module): method __init__ (line 110) | def __init__(self, dim, layers=2): method forward (line 115) | def forward(self, x): class DiffusionUnet (line 123) | class DiffusionUnet(nn.Module): method __init__ (line 124) | def __init__(self, chin: int = 3, hidden: int = 24, depth: int = 3, gr... method forward (line 163) | def forward(self, x: torch.Tensor, step: tp.Union[int, torch.Tensor], ... FILE: audiocraft/modules/activations.py class CustomGLU (line 13) | class CustomGLU(nn.Module): method __init__ (line 33) | def __init__(self, activation: nn.Module, dim: int = -1): method forward (line 38) | def forward(self, x: Tensor): class SwiGLU (line 44) | class SwiGLU(CustomGLU): method __init__ (line 52) | def __init__(self, dim: int = -1): class GeGLU (line 56) | class GeGLU(CustomGLU): method __init__ (line 64) | def __init__(self, dim: int = -1): class ReGLU (line 68) | class ReGLU(CustomGLU): method __init__ (line 76) | def __init__(self, dim: int = -1): function get_activation_fn (line 80) | def get_activation_fn( FILE: audiocraft/modules/chroma.py class ChromaExtractor (line 16) | class ChromaExtractor(nn.Module): method __init__ (line 29) | def __init__(self, sample_rate: int, n_chroma: int = 12, radix2_exp: i... method forward (line 46) | def forward(self, wav: torch.Tensor) -> torch.Tensor: FILE: audiocraft/modules/codebooks_patterns.py class Pattern (line 22) | class Pattern: method __post_init__ (line 50) | def __post_init__(self): method _validate_layout (line 58) | def _validate_layout(self): method num_sequence_steps (line 80) | def num_sequence_steps(self): method max_delay (line 84) | def max_delay(self): method valid_layout (line 92) | def valid_layout(self): method get_sequence_coords_with_timestep (line 96) | def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int... method get_steps_with_timestep (line 111) | def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) ... method get_first_step_with_timesteps (line 114) | def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = ... method _build_pattern_sequence_scatter_indexes (line 118) | def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q:... method build_pattern_sequence (line 152) | def build_pattern_sequence(self, z: torch.Tensor, special_token: int, ... method _build_reverted_sequence_scatter_indexes (line 179) | def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int... method revert_pattern_sequence (line 223) | def revert_pattern_sequence(self, s: torch.Tensor, special_token: int,... method revert_pattern_logits (line 248) | def revert_pattern_logits(self, logits: torch.Tensor, special_token: f... class CodebooksPatternProvider (line 270) | class CodebooksPatternProvider(ABC): method __init__ (line 288) | def __init__(self, n_q: int, cached: bool = True): method get_pattern (line 294) | def get_pattern(self, timesteps: int) -> Pattern: class DelayedPatternProvider (line 303) | class DelayedPatternProvider(CodebooksPatternProvider): method __init__ (line 326) | def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None, method get_pattern (line 337) | def get_pattern(self, timesteps: int) -> Pattern: class ParallelPatternProvider (line 356) | class ParallelPatternProvider(DelayedPatternProvider): method __init__ (line 364) | def __init__(self, n_q: int): class UnrolledPatternProvider (line 368) | class UnrolledPatternProvider(CodebooksPatternProvider): method __init__ (line 419) | def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = N... method _build_flattened_codebooks (line 433) | def _build_flattened_codebooks(self, delays: tp.List[int], flattening:... method _num_inner_steps (line 453) | def _num_inner_steps(self): method num_virtual_steps (line 458) | def num_virtual_steps(self, timesteps: int) -> int: method get_pattern (line 461) | def get_pattern(self, timesteps: int) -> Pattern: class VALLEPattern (line 489) | class VALLEPattern(CodebooksPatternProvider): method __init__ (line 498) | def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None): method get_pattern (line 506) | def get_pattern(self, timesteps: int) -> Pattern: class MusicLMPattern (line 521) | class MusicLMPattern(CodebooksPatternProvider): method __init__ (line 529) | def __init__(self, n_q: int, group_by: int = 2): method get_pattern (line 533) | def get_pattern(self, timesteps: int) -> Pattern: FILE: audiocraft/modules/conditioners.py class WavCondition (line 46) | class WavCondition(tp.NamedTuple): class JointEmbedCondition (line 54) | class JointEmbedCondition(tp.NamedTuple): class ConditioningAttributes (line 64) | class ConditioningAttributes: method __getitem__ (line 69) | def __getitem__(self, item): method text_attributes (line 73) | def text_attributes(self): method wav_attributes (line 77) | def wav_attributes(self): method joint_embed_attributes (line 81) | def joint_embed_attributes(self): method attributes (line 85) | def attributes(self): method to_flat_dict (line 92) | def to_flat_dict(self): method from_flat_dict (line 100) | def from_flat_dict(cls, x): class SegmentWithAttributes (line 108) | class SegmentWithAttributes(SegmentInfo): method to_condition_attributes (line 113) | def to_condition_attributes(self) -> ConditioningAttributes: function nullify_condition (line 117) | def nullify_condition(condition: ConditionType, dim: int = 1): function nullify_wav (line 144) | def nullify_wav(cond: WavCondition) -> WavCondition: function nullify_joint_embed (line 163) | def nullify_joint_embed(embed: JointEmbedCondition) -> JointEmbedCondition: class Tokenizer (line 180) | class Tokenizer: method __call__ (line 184) | def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch... class WhiteSpaceTokenizer (line 188) | class WhiteSpaceTokenizer(Tokenizer): method __init__ (line 197) | def __init__(self, n_bins: int, pad_idx: int = 0, language: str = "en_... method __call__ (line 210) | def __call__(self, texts: tp.List[tp.Optional[str]], class NoopTokenizer (line 256) | class NoopTokenizer(Tokenizer): method __init__ (line 266) | def __init__(self, n_bins: int, pad_idx: int = 0): method __call__ (line 270) | def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch... class BaseConditioner (line 286) | class BaseConditioner(nn.Module): method __init__ (line 296) | def __init__(self, dim: int, output_dim: int): method tokenize (line 302) | def tokenize(self, *args, **kwargs) -> tp.Any: method forward (line 310) | def forward(self, inputs: tp.Any) -> ConditionType: class TextConditioner (line 323) | class TextConditioner(BaseConditioner): class LUTConditioner (line 327) | class LUTConditioner(TextConditioner): method __init__ (line 337) | def __init__(self, n_bins: int, dim: int, output_dim: int, tokenizer: ... method tokenize (line 348) | def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Ten... method forward (line 354) | def forward(self, inputs: tp.Tuple[torch.Tensor, torch.Tensor]) -> Con... class T5Conditioner (line 362) | class T5Conditioner(TextConditioner): method __init__ (line 390) | def __init__(self, name: str, output_dim: int, finetune: bool, device:... method tokenize (line 430) | def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch... method forward (line 449) | def forward(self, inputs: tp.Dict[str, torch.Tensor]) -> ConditionType: class WaveformConditioner (line 458) | class WaveformConditioner(BaseConditioner): method __init__ (line 469) | def __init__(self, dim: int, output_dim: int, device: tp.Union[torch.d... method tokenize (line 473) | def tokenize(self, x: WavCondition) -> WavCondition: method _get_wav_embedding (line 478) | def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor: method _downsampling_factor (line 482) | def _downsampling_factor(self): method forward (line 486) | def forward(self, x: WavCondition) -> ConditionType: class ChromaStemConditioner (line 509) | class ChromaStemConditioner(WaveformConditioner): method __init__ (line 531) | def __init__(self, output_dim: int, sample_rate: int, n_chroma: int, r... method _downsampling_factor (line 554) | def _downsampling_factor(self) -> int: method _load_eval_wavs (line 557) | def _load_eval_wavs(self, path: tp.Optional[str], num_samples: int) ->... method reset_eval_wavs (line 578) | def reset_eval_wavs(self, eval_wavs: tp.Optional[torch.Tensor]) -> None: method has_eval_wavs (line 581) | def has_eval_wavs(self) -> bool: method _sample_eval_wavs (line 584) | def _sample_eval_wavs(self, num_samples: int) -> torch.Tensor: method _get_chroma_len (line 593) | def _get_chroma_len(self) -> int: method _get_stemmed_wav (line 600) | def _get_stemmed_wav(self, wav: torch.Tensor, sample_rate: int) -> tor... method _extract_chroma (line 614) | def _extract_chroma(self, wav: torch.Tensor) -> torch.Tensor: method _compute_wav_embedding (line 620) | def _compute_wav_embedding(self, wav: torch.Tensor, sample_rate: int) ... method _get_full_chroma_for_cache (line 630) | def _get_full_chroma_for_cache(self, path: tp.Union[str, Path], x: Wav... method _extract_chroma_chunk (line 638) | def _extract_chroma_chunk(self, full_chroma: torch.Tensor, x: WavCondi... method _get_wav_embedding (line 654) | def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor: method tokenize (line 688) | def tokenize(self, x: WavCondition) -> WavCondition: class JointEmbeddingConditioner (line 698) | class JointEmbeddingConditioner(BaseConditioner): method __init__ (line 712) | def __init__(self, dim: int, output_dim: int, device: str, attribute: ... method _get_embed (line 731) | def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor,... method forward (line 740) | def forward(self, x: JointEmbedCondition) -> ConditionType: method tokenize (line 755) | def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition: class CLAPEmbeddingConditioner (line 759) | class CLAPEmbeddingConditioner(JointEmbeddingConditioner): method __init__ (line 786) | def __init__(self, dim: int, output_dim: int, device: str, attribute: ... method _tokenizer (line 825) | def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict: method _compute_text_embedding (line 829) | def _compute_text_embedding(self, text: tp.List[str]) -> torch.Tensor: method _get_text_embedding_for_cache (line 841) | def _get_text_embedding_for_cache(self, path: tp.Union[Path, str], method _preprocess_wav (line 848) | def _preprocess_wav(self, wav: torch.Tensor, length: torch.Tensor, sam... method _compute_wav_embedding (line 869) | def _compute_wav_embedding(self, wav: torch.Tensor, length: torch.Tensor, method _get_wav_embedding_for_cache (line 904) | def _get_wav_embedding_for_cache(self, path: tp.Union[str, Path], method _extract_wav_embedding_chunk (line 920) | def _extract_wav_embedding_chunk(self, full_embed: torch.Tensor, x: Jo... method _get_text_embedding (line 941) | def _get_text_embedding(self, x: JointEmbedCondition) -> torch.Tensor: method _get_wav_embedding (line 955) | def _get_wav_embedding(self, x: JointEmbedCondition) -> torch.Tensor: method tokenize (line 968) | def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition: method _get_embed (line 981) | def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor,... function dropout_condition (line 994) | def dropout_condition(sample: ConditioningAttributes, condition_type: st... class DropoutModule (line 1025) | class DropoutModule(nn.Module): method __init__ (line 1027) | def __init__(self, seed: int = 1234): class AttributeDropout (line 1033) | class AttributeDropout(DropoutModule): method __init__ (line 1050) | def __init__(self, p: tp.Dict[str, tp.Dict[str, float]], active_on_eva... method forward (line 1058) | def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List... method __repr__ (line 1076) | def __repr__(self): class ClassifierFreeGuidanceDropout (line 1080) | class ClassifierFreeGuidanceDropout(DropoutModule): method __init__ (line 1088) | def __init__(self, p: float, seed: int = 1234): method forward (line 1092) | def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List... method __repr__ (line 1115) | def __repr__(self): class ConditioningProvider (line 1119) | class ConditioningProvider(nn.Module): method __init__ (line 1126) | def __init__(self, conditioners: tp.Dict[str, BaseConditioner], device... method joint_embed_conditions (line 1132) | def joint_embed_conditions(self): method has_joint_embed_conditions (line 1136) | def has_joint_embed_conditions(self): method text_conditions (line 1140) | def text_conditions(self): method wav_conditions (line 1144) | def wav_conditions(self): method has_wav_condition (line 1148) | def has_wav_condition(self): method tokenize (line 1151) | def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict... method forward (line 1179) | def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, Con... method _collate_text (line 1197) | def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> t... method _collate_wavs (line 1224) | def _collate_wavs(self, samples: tp.List[ConditioningAttributes]) -> t... method _collate_joint_embeds (line 1268) | def _collate_joint_embeds(self, samples: tp.List[ConditioningAttribute... class ConditionFuser (line 1322) | class ConditionFuser(StreamingModule): method __init__ (line 1339) | def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attent... method forward (line 1353) | def forward( FILE: audiocraft/modules/conv.py function apply_parametrization_norm (line 21) | def apply_parametrization_norm(module: nn.Module, norm: str = 'none'): function get_norm_module (line 33) | def get_norm_module(module: nn.Module, causal: bool = False, norm: str =... function get_extra_padding_for_conv1d (line 47) | def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stri... function pad_for_conv1d (line 56) | def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, paddi... function pad1d (line 71) | def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'co... function unpad1d (line 91) | def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): class NormConv1d (line 100) | class NormConv1d(nn.Module): method __init__ (line 104) | def __init__(self, *args, causal: bool = False, norm: str = 'none', method forward (line 111) | def forward(self, x): class NormConv2d (line 117) | class NormConv2d(nn.Module): method __init__ (line 121) | def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str... method forward (line 127) | def forward(self, x): class NormConvTranspose1d (line 133) | class NormConvTranspose1d(nn.Module): method __init__ (line 137) | def __init__(self, *args, causal: bool = False, norm: str = 'none', method forward (line 144) | def forward(self, x): class NormConvTranspose2d (line 150) | class NormConvTranspose2d(nn.Module): method __init__ (line 154) | def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str... method forward (line 159) | def forward(self, x): class StreamableConv1d (line 165) | class StreamableConv1d(nn.Module): method __init__ (line 169) | def __init__(self, in_channels: int, out_channels: int, method forward (line 185) | def forward(self, x): class StreamableConvTranspose1d (line 204) | class StreamableConvTranspose1d(nn.Module): method __init__ (line 208) | def __init__(self, in_channels: int, out_channels: int, method forward (line 221) | def forward(self, x): FILE: audiocraft/modules/diffusion_schedule.py function betas_from_alpha_bar (line 20) | def betas_from_alpha_bar(alpha_bar): class SampleProcessor (line 25) | class SampleProcessor(torch.nn.Module): method project_sample (line 26) | def project_sample(self, x: torch.Tensor): method return_sample (line 30) | def return_sample(self, z: torch.Tensor): class MultiBandProcessor (line 35) | class MultiBandProcessor(SampleProcessor): method __init__ (line 57) | def __init__(self, n_bands: int = 8, sample_rate: float = 24_000, method mean (line 77) | def mean(self): method std (line 82) | def std(self): method target_std (line 87) | def target_std(self): method project_sample (line 91) | def project_sample(self, x: torch.Tensor): method return_sample (line 104) | def return_sample(self, x: torch.Tensor): class NoiseSchedule (line 112) | class NoiseSchedule: method __init__ (line 127) | def __init__(self, beta_t0: float = 1e-4, beta_t1: float = 0.02, num_s... method get_beta (line 149) | def get_beta(self, step: tp.Union[int, torch.Tensor]): method get_initial_noise (line 155) | def get_initial_noise(self, x: torch.Tensor): method get_alpha_bar (line 160) | def get_alpha_bar(self, step: tp.Optional[tp.Union[int, torch.Tensor]]... method get_training_item (line 169) | def get_training_item(self, x: torch.Tensor, tensor_step: bool = False... method generate (line 192) | def generate(self, model: torch.nn.Module, initial: tp.Optional[torch.... method generate_subsampled (line 238) | def generate_subsampled(self, model: torch.nn.Module, initial: torch.T... FILE: audiocraft/modules/lstm.py class StreamableLSTM (line 10) | class StreamableLSTM(nn.Module): method __init__ (line 14) | def __init__(self, dimension: int, num_layers: int = 2, skip: bool = T... method forward (line 19) | def forward(self, x): FILE: audiocraft/modules/rope.py class XPos (line 13) | class XPos(nn.Module): method __init__ (line 24) | def __init__(self, dim: int, smoothing: float = 0.4, base_scale: int =... method get_decay (line 38) | def get_decay(self, start: int, end: int): class RotaryEmbedding (line 49) | class RotaryEmbedding(nn.Module): method __init__ (line 60) | def __init__(self, dim: int, max_period: float = 10000.0, xpos: bool =... method get_rotation (line 75) | def get_rotation(self, start: int, end: int): method rotate (line 84) | def rotate(self, x: torch.Tensor, start: int = 0, invert_decay: bool =... method rotate_qk (line 103) | def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int... FILE: audiocraft/modules/seanet.py class SEANetResnetBlock (line 16) | class SEANetResnetBlock(nn.Module): method __init__ (line 33) | def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dila... method forward (line 59) | def forward(self, x): class SEANetEncoder (line 63) | class SEANetEncoder(nn.Module): method __init__ (line 91) | def __init__(self, channels: int = 1, dimension: int = 128, n_filters:... method forward (line 152) | def forward(self, x): class SEANetDecoder (line 156) | class SEANetDecoder(nn.Module): method __init__ (line 186) | def __init__(self, channels: int = 1, dimension: int = 128, n_filters:... method forward (line 256) | def forward(self, z): FILE: audiocraft/modules/streaming.py class StreamingModule (line 20) | class StreamingModule(nn.Module): method __init__ (line 43) | def __init__(self) -> None: method _apply_named_streaming (line 48) | def _apply_named_streaming(self, fn: tp.Any): method _set_streaming (line 53) | def _set_streaming(self, streaming: bool): method streaming (line 59) | def streaming(self): method reset_streaming (line 68) | def reset_streaming(self): method get_streaming_state (line 75) | def get_streaming_state(self) -> State: method set_streaming_state (line 88) | def set_streaming_state(self, state: State): method flush (line 107) | def flush(self, x: tp.Optional[torch.Tensor] = None): class StreamingSequential (line 122) | class StreamingSequential(StreamingModule, nn.Sequential): method flush (line 125) | def flush(self, x: tp.Optional[torch.Tensor] = None): FILE: audiocraft/modules/transformer.py function set_efficient_attention_backend (line 31) | def set_efficient_attention_backend(backend: str = 'torch'): function _get_attention_time_dimension (line 38) | def _get_attention_time_dimension() -> int: function _is_profiled (line 45) | def _is_profiled() -> bool: function create_norm_fn (line 54) | def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module: function create_sin_embedding (line 70) | def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: ... function expand_repeated_kv (line 92) | def expand_repeated_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: class LayerScale (line 112) | class LayerScale(nn.Module): method __init__ (line 123) | def __init__(self, channels: int, init: float = 1e-4, channel_last: bo... method forward (line 131) | def forward(self, x: torch.Tensor): class StreamingMultiheadAttention (line 138) | class StreamingMultiheadAttention(StreamingModule): method __init__ (line 164) | def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.... method _load_from_state_dict (line 224) | def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): method _get_mask (line 233) | def _get_mask(self, current_steps: int, device: torch.device, dtype: t... method _complete_kv (line 266) | def _complete_kv(self, k, v): method _apply_rope (line 300) | def _apply_rope(self, query: torch.Tensor, key: torch.Tensor): method forward (line 316) | def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch... class StreamingTransformerLayer (line 445) | class StreamingTransformerLayer(nn.TransformerEncoderLayer): method __init__ (line 479) | def __init__(self, d_model: int, num_heads: int, dim_feedforward: int ... method _cross_attention_block (line 533) | def _cross_attention_block(self, src: torch.Tensor, method forward (line 541) | def forward(self, src: torch.Tensor, src_mask: tp.Optional[torch.Tenso... class StreamingTransformer (line 568) | class StreamingTransformer(StreamingModule): method __init__ (line 605) | def __init__(self, d_model: int, num_heads: int, num_layers: int, dim_... method _apply_layer (line 654) | def _apply_layer(self, layer, *args, **kwargs): method forward (line 685) | def forward(self, x: torch.Tensor, *args, **kwargs): method make_optim_group (line 707) | def make_optim_group(self): function _verify_xformers_memory_efficient_compat (line 718) | def _verify_xformers_memory_efficient_compat(): function _verify_xformers_internal_compat (line 732) | def _verify_xformers_internal_compat(): function _is_custom (line 746) | def _is_custom(custom: bool, memory_efficient: bool): FILE: audiocraft/optim/cosine_lr_scheduler.py class CosineLRScheduler (line 13) | class CosineLRScheduler(_LRScheduler): method __init__ (line 23) | def __init__(self, optimizer: Optimizer, total_steps: int, warmup_step... method _get_sched_lr (line 33) | def _get_sched_lr(self, lr: float, step: int): method get_lr (line 47) | def get_lr(self): FILE: audiocraft/optim/dadam.py function to_real (line 23) | def to_real(x): class DAdaptAdam (line 30) | class DAdaptAdam(torch.optim.Optimizer): method __init__ (line 62) | def __init__(self, params, lr=1.0, method supports_memory_efficient_fp16 (line 99) | def supports_memory_efficient_fp16(self): method supports_flat_params (line 103) | def supports_flat_params(self): method step (line 106) | def step(self, closure=None): FILE: audiocraft/optim/ema.py function _get_all_non_persistent_buffers_set (line 17) | def _get_all_non_persistent_buffers_set(module: nn.Module, root: str = "... function _get_named_tensors (line 32) | def _get_named_tensors(module: nn.Module): class ModuleDictEMA (line 40) | class ModuleDictEMA: method __init__ (line 45) | def __init__(self, module_dict: nn.ModuleDict, decay: float = 0.999, method _init (line 55) | def _init(self): method step (line 64) | def step(self): method state_dict (line 78) | def state_dict(self): method load_state_dict (line 81) | def load_state_dict(self, state): FILE: audiocraft/optim/fsdp.py function is_fsdp_used (line 22) | def is_fsdp_used() -> bool: function is_sharded_tensor (line 32) | def is_sharded_tensor(x: tp.Any) -> bool: function switch_to_full_state_dict (line 37) | def switch_to_full_state_dict(models: tp.List[FSDP]): function wrap_with_fsdp (line 51) | def wrap_with_fsdp(cfg, model: torch.nn.Module, function purge_fsdp (line 120) | def purge_fsdp(model: FSDP): class _FSDPFixStateDict (line 138) | class _FSDPFixStateDict(FSDP): method _name_without_fsdp_prefix (line 140) | def _name_without_fsdp_prefix(name: str) -> str: method state_dict (line 146) | def state_dict(self) -> tp.Dict[str, tp.Any]: # type: ignore method load_state_dict (line 153) | def load_state_dict(self, state: tp.Dict[str, tp.Any]): # type: ignore function _fix_post_backward_hook (line 175) | def _fix_post_backward_hook(): FILE: audiocraft/optim/inverse_sqrt_lr_scheduler.py class InverseSquareRootLRScheduler (line 13) | class InverseSquareRootLRScheduler(_LRScheduler): method __init__ (line 22) | def __init__(self, optimizer: Optimizer, warmup_steps: int, warmup_ini... method _get_sched_lr (line 27) | def _get_sched_lr(self, lr: float, step: int): method get_lr (line 37) | def get_lr(self): FILE: audiocraft/optim/linear_warmup_lr_scheduler.py class LinearWarmupLRScheduler (line 13) | class LinearWarmupLRScheduler(_LRScheduler): method __init__ (line 22) | def __init__(self, optimizer: Optimizer, warmup_steps: int, warmup_ini... method _get_sched_lr (line 27) | def _get_sched_lr(self, lr: float, step: int): method get_lr (line 34) | def get_lr(self): FILE: audiocraft/optim/polynomial_decay_lr_scheduler.py class PolynomialDecayLRScheduler (line 11) | class PolynomialDecayLRScheduler(_LRScheduler): method __init__ (line 22) | def __init__(self, optimizer: Optimizer, warmup_steps: int, total_step... method _get_sched_lr (line 31) | def _get_sched_lr(self, lr: float, step: int): method get_lr (line 46) | def get_lr(self): FILE: audiocraft/quantization/base.py class QuantizedResult (line 19) | class QuantizedResult: class BaseQuantizer (line 27) | class BaseQuantizer(nn.Module): method forward (line 31) | def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult: method encode (line 40) | def encode(self, x: torch.Tensor) -> torch.Tensor: method decode (line 44) | def decode(self, codes: torch.Tensor) -> torch.Tensor: method total_codebooks (line 49) | def total_codebooks(self): method num_codebooks (line 54) | def num_codebooks(self): method set_num_codebooks (line 58) | def set_num_codebooks(self, n: int): class DummyQuantizer (line 63) | class DummyQuantizer(BaseQuantizer): method __init__ (line 66) | def __init__(self): method forward (line 69) | def forward(self, x: torch.Tensor, frame_rate: int): method encode (line 73) | def encode(self, x: torch.Tensor) -> torch.Tensor: method decode (line 80) | def decode(self, codes: torch.Tensor) -> torch.Tensor: method total_codebooks (line 88) | def total_codebooks(self): method num_codebooks (line 93) | def num_codebooks(self): method set_num_codebooks (line 97) | def set_num_codebooks(self, n: int): FILE: audiocraft/quantization/core_vq.py function exists (line 16) | def exists(val: tp.Optional[tp.Any]) -> bool: function default (line 20) | def default(val: tp.Any, d: tp.Any) -> tp.Any: function l2norm (line 24) | def l2norm(t): function ema_inplace (line 28) | def ema_inplace(moving_avg, new, decay: float): function laplace_smoothing (line 32) | def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5): function uniform_init (line 36) | def uniform_init(*shape: int): function sample_vectors (line 42) | def sample_vectors(samples, num: int): function kmeans (line 53) | def kmeans(samples, num_clusters: int, num_iters: int = 10): function orthogonal_loss_fn (line 78) | def orthogonal_loss_fn(t): class EuclideanCodebook (line 87) | class EuclideanCodebook(nn.Module): method __init__ (line 103) | def __init__( method init_embed_ (line 130) | def init_embed_(self, data): method replace_ (line 142) | def replace_(self, samples, mask): method expire_codes_ (line 148) | def expire_codes_(self, batch_samples): method preprocess (line 160) | def preprocess(self, x): method quantize (line 164) | def quantize(self, x): method postprocess_emb (line 174) | def postprocess_emb(self, embed_ind, shape): method dequantize (line 177) | def dequantize(self, embed_ind): method encode (line 181) | def encode(self, x): method decode (line 191) | def decode(self, embed_ind): method forward (line 195) | def forward(self, x): class VectorQuantization (line 222) | class VectorQuantization(nn.Module): method __init__ (line 245) | def __init__( method codebook (line 284) | def codebook(self): method inited (line 288) | def inited(self): method _preprocess (line 291) | def _preprocess(self, x): method _postprocess (line 296) | def _postprocess(self, quantize): method encode (line 301) | def encode(self, x): method decode (line 307) | def decode(self, embed_ind): method forward (line 313) | def forward(self, x): class ResidualVectorQuantization (line 352) | class ResidualVectorQuantization(nn.Module): method __init__ (line 357) | def __init__(self, *, num_quantizers, **kwargs): method forward (line 363) | def forward(self, x, n_q: tp.Optional[int] = None): method encode (line 382) | def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> tor... method decode (line 394) | def decode(self, q_indices: torch.Tensor) -> torch.Tensor: FILE: audiocraft/quantization/vq.py class ResidualVectorQuantizer (line 16) | class ResidualVectorQuantizer(BaseQuantizer): method __init__ (line 35) | def __init__( method forward (line 76) | def forward(self, x: torch.Tensor, frame_rate: int): method encode (line 87) | def encode(self, x: torch.Tensor) -> torch.Tensor: method decode (line 98) | def decode(self, codes: torch.Tensor) -> torch.Tensor: method total_codebooks (line 106) | def total_codebooks(self): method num_codebooks (line 110) | def num_codebooks(self): method set_num_codebooks (line 113) | def set_num_codebooks(self, n: int): FILE: audiocraft/solvers/audiogen.py class AudioGenSolver (line 10) | class AudioGenSolver(musicgen.MusicGenSolver): FILE: audiocraft/solvers/base.py class StandardSolver (line 27) | class StandardSolver(ABC, flashy.BaseSolver): method __init__ (line 38) | def __init__(self, cfg: omegaconf.DictConfig): method autocast (line 98) | def autocast(self): method _get_state_source (line 102) | def _get_state_source(self, name) -> flashy.state.StateDictSource: method best_metric_name (line 107) | def best_metric_name(self) -> tp.Optional[str]: method register_best_state (line 114) | def register_best_state(self, *args: str): method register_ema (line 127) | def register_ema(self, *args: str): method wrap_with_fsdp (line 141) | def wrap_with_fsdp(self, model: torch.nn.Module, *args, **kwargs): method update_best_state_from_stage (line 147) | def update_best_state_from_stage(self, stage_name: str = 'valid'): method _load_new_state_dict (line 189) | def _load_new_state_dict(self, state_dict: dict) -> dict: method swap_best_state (line 198) | def swap_best_state(self): method swap_ema_state (line 210) | def swap_ema_state(self): method is_training (line 226) | def is_training(self): method log_model_summary (line 229) | def log_model_summary(self, model: nn.Module): method build_model (line 236) | def build_model(self): method initialize_ema (line 240) | def initialize_ema(self): method build_dataloaders (line 256) | def build_dataloaders(self): method show (line 261) | def show(self): method log_updates (line 266) | def log_updates(self): method checkpoint_path (line 270) | def checkpoint_path(self, **kwargs): method epoch_checkpoint_path (line 274) | def epoch_checkpoint_path(self, epoch: int, **kwargs): method checkpoint_path_with_name (line 278) | def checkpoint_path_with_name(self, name: str, **kwargs): method save_checkpoints (line 282) | def save_checkpoints(self): method load_from_pretrained (line 311) | def load_from_pretrained(self, name: str) -> dict: method load_checkpoints (line 314) | def load_checkpoints(self, load_best: bool = False, ignore_state_keys:... method restore (line 432) | def restore(self, load_best: bool = False, replay_metrics: bool = False, method commit (line 456) | def commit(self, save_checkpoints: bool = True): method run_epoch (line 466) | def run_epoch(self): method run (line 489) | def run(self): method should_stop_training (line 501) | def should_stop_training(self) -> bool: method should_run_stage (line 505) | def should_run_stage(self, stage_name) -> bool: method run_step (line 513) | def run_step(self, idx: int, batch: tp.Any, metrics: dict): method common_train_valid (line 517) | def common_train_valid(self, dataset_split: str, **kwargs: tp.Any): method train (line 559) | def train(self): method valid (line 563) | def valid(self): method evaluate (line 568) | def evaluate(self): method generate (line 573) | def generate(self): method run_one_stage (line 577) | def run_one_stage(self, stage_name: str): method get_eval_solver_from_sig (line 597) | def get_eval_solver_from_sig(sig: str, dtype: tp.Optional[str] = None, FILE: audiocraft/solvers/builders.py class DatasetType (line 36) | class DatasetType(Enum): function get_solver (line 42) | def get_solver(cfg: omegaconf.DictConfig) -> StandardSolver: function get_optim_parameter_groups (line 59) | def get_optim_parameter_groups(model: nn.Module): function get_optimizer (line 86) | def get_optimizer(params: tp.Union[nn.Module, tp.Iterable[torch.Tensor]]... function get_lr_scheduler (line 115) | def get_lr_scheduler(optimizer: torch.optim.Optimizer, function get_ema (line 159) | def get_ema(module_dict: nn.ModuleDict, cfg: omegaconf.DictConfig) -> tp... function get_loss (line 180) | def get_loss(loss_name: str, cfg: omegaconf.DictConfig): function get_balancer (line 194) | def get_balancer(loss_weights: tp.Dict[str, float], cfg: omegaconf.DictC... function get_adversary (line 200) | def get_adversary(name: str, cfg: omegaconf.DictConfig) -> nn.Module: function get_adversarial_losses (line 211) | def get_adversarial_losses(cfg) -> nn.ModuleDict: function get_visqol (line 244) | def get_visqol(cfg: omegaconf.DictConfig) -> metrics.ViSQOL: function get_fad (line 250) | def get_fad(cfg: omegaconf.DictConfig) -> metrics.FrechetAudioDistanceMe... function get_kldiv (line 258) | def get_kldiv(cfg: omegaconf.DictConfig) -> metrics.KLDivergenceMetric: function get_text_consistency (line 268) | def get_text_consistency(cfg: omegaconf.DictConfig) -> metrics.TextConsi... function get_chroma_cosine_similarity (line 278) | def get_chroma_cosine_similarity(cfg: omegaconf.DictConfig) -> metrics.C... function get_audio_datasets (line 285) | def get_audio_datasets(cfg: omegaconf.DictConfig, FILE: audiocraft/solvers/compression.py class CompressionSolver (line 27) | class CompressionSolver(base.StandardSolver): method __init__ (line 34) | def __init__(self, cfg: omegaconf.DictConfig): method best_metric_name (line 55) | def best_metric_name(self) -> tp.Optional[str]: method build_model (line 59) | def build_model(self): method build_dataloaders (line 68) | def build_dataloaders(self): method show (line 72) | def show(self): method run_step (line 83) | def run_step(self, idx: int, batch: torch.Tensor, metrics: dict): method run_epoch (line 176) | def run_epoch(self): method evaluate (line 183) | def evaluate(self): method generate (line 213) | def generate(self): method load_from_pretrained (line 236) | def load_from_pretrained(self, name: str) -> dict: method model_from_checkpoint (line 269) | def model_from_checkpoint(checkpoint_path: tp.Union[Path, str], method wrapped_model_from_checkpoint (line 304) | def wrapped_model_from_checkpoint(cfg: omegaconf.DictConfig, function evaluate_audio_reconstruction (line 320) | def evaluate_audio_reconstruction(y_pred: torch.Tensor, y: torch.Tensor,... FILE: audiocraft/solvers/diffusion.py class PerStageMetrics (line 25) | class PerStageMetrics: method __init__ (line 30) | def __init__(self, num_steps: int, num_stages: int = 4): method __call__ (line 34) | def __call__(self, losses: dict, step: tp.Union[int, torch.Tensor]): class DataProcess (line 53) | class DataProcess: method __init__ (line 67) | def __init__(self, initial_sr: int = 24000, target_sr: int = 16000, us... method process_data (line 95) | def process_data(self, x, metric=False): method inverse_process (line 107) | def inverse_process(self, x): class DiffusionSolver (line 114) | class DiffusionSolver(base.StandardSolver): method __init__ (line 122) | def __init__(self, cfg: omegaconf.DictConfig): method best_metric_name (line 155) | def best_metric_name(self) -> tp.Optional[str]: method get_condition (line 162) | def get_condition(self, wav: torch.Tensor) -> torch.Tensor: method build_model (line 168) | def build_model(self): method build_dataloaders (line 178) | def build_dataloaders(self): method show (line 182) | def show(self): method run_step (line 186) | def run_step(self, idx: int, batch: torch.Tensor, metrics: dict): method run_epoch (line 215) | def run_epoch(self): method evaluate (line 223) | def evaluate(self): method regenerate (line 253) | def regenerate(self, wav: torch.Tensor, step_list: tp.Optional[list] =... method generate (line 262) | def generate(self): FILE: audiocraft/solvers/musicgen.py class MusicGenSolver (line 30) | class MusicGenSolver(base.StandardSolver): method __init__ (line 37) | def __init__(self, cfg: omegaconf.DictConfig): method get_eval_solver_from_sig (line 64) | def get_eval_solver_from_sig(sig: str, dtype: tp.Optional[str] = None, method get_formatter (line 100) | def get_formatter(self, stage_name: str) -> flashy.Formatter: method best_metric_name (line 109) | def best_metric_name(self) -> tp.Optional[str]: method build_model (line 112) | def build_model(self) -> None: method build_dataloaders (line 163) | def build_dataloaders(self) -> None: method show (line 167) | def show(self) -> None: method load_state_dict (line 174) | def load_state_dict(self, state: dict) -> None: method load_from_pretrained (line 185) | def load_from_pretrained(self, name: str): method _compute_cross_entropy (line 195) | def _compute_cross_entropy( method _prepare_tokens_and_attributes (line 230) | def _prepare_tokens_and_attributes( method run_step (line 330) | def run_step(self, idx: int, batch: tp.Tuple[torch.Tensor, tp.List[Seg... method run_generate_step (line 404) | def run_generate_step(self, batch: tp.Tuple[torch.Tensor, tp.List[Segm... method generate_audio (line 471) | def generate_audio(self) -> dict: method generate (line 561) | def generate(self) -> dict: method run_epoch (line 567) | def run_epoch(self): method train (line 573) | def train(self): method evaluate_audio_generation (line 586) | def evaluate_audio_generation(self) -> dict: method evaluate (line 691) | def evaluate(self) -> dict: FILE: audiocraft/train.py function resolve_config_dset_paths (line 29) | def resolve_config_dset_paths(cfg): function get_solver (line 37) | def get_solver(cfg): function get_solver_from_xp (line 51) | def get_solver_from_xp(xp: XP, override_cfg: tp.Optional[tp.Union[dict, ... function get_solver_from_sig (line 96) | def get_solver_from_sig(sig: str, *args, **kwargs): function init_seed_and_system (line 104) | def init_seed_and_system(cfg): function main (line 125) | def main(cfg): FILE: audiocraft/utils/autocast.py class TorchAutocast (line 10) | class TorchAutocast: method __init__ (line 21) | def __init__(self, enabled: bool, *args, **kwargs): method __enter__ (line 24) | def __enter__(self): method __exit__ (line 37) | def __exit__(self, *args, **kwargs): FILE: audiocraft/utils/best_state.py class BestStateDictManager (line 21) | class BestStateDictManager(flashy.state.StateDictSource): method __init__ (line 36) | def __init__(self, device: tp.Union[torch.device, str] = 'cpu', method _get_parameter_ids (line 43) | def _get_parameter_ids(self, state_dict): method _validate_no_parameter_ids_overlap (line 46) | def _validate_no_parameter_ids_overlap(self, name: str, param_ids: dict): method update (line 53) | def update(self, name: str, source: flashy.state.StateDictSource): method register (line 58) | def register(self, name: str, source: flashy.state.StateDictSource): method state_dict (line 75) | def state_dict(self) -> flashy.state.StateDict: method load_state_dict (line 78) | def load_state_dict(self, state: flashy.state.StateDict): FILE: audiocraft/utils/cache.py function get_full_embed (line 24) | def get_full_embed(full_embed: torch.Tensor, x: tp.Any, idx: int, device... class EmbeddingCache (line 39) | class EmbeddingCache: method __init__ (line 60) | def __init__(self, cache_path: tp.Union[Path], device: tp.Union[str, t... method _get_cache_path (line 79) | def _get_cache_path(self, path: tp.Union[Path, str]): method _get_full_embed_from_cache (line 85) | def _get_full_embed_from_cache(cache: Path): method get_embed_from_cache (line 94) | def get_embed_from_cache(self, paths: tp.List[Path], x: tp.Any) -> tor... method populate_embed_cache (line 124) | def populate_embed_cache(self, paths: tp.List[Path], x: tp.Any) -> None: class CachedBatchWriter (line 161) | class CachedBatchWriter: method __init__ (line 180) | def __init__(self, cache_folder: Path): method start_epoch (line 185) | def start_epoch(self, epoch: int): method _get_zip_path (line 193) | def _get_zip_path(cache_folder: Path, epoch: int, index: int): method _zip_path (line 197) | def _zip_path(self): method save (line 201) | def save(self, *content): class CachedBatchLoader (line 224) | class CachedBatchLoader: method __init__ (line 237) | def __init__(self, cache_folder: Path, batch_size: int, method __len__ (line 246) | def __len__(self): method start_epoch (line 250) | def start_epoch(self, epoch: int): method _zip_path (line 255) | def _zip_path(self, index: int): method _load_one (line 259) | def _load_one(self, index: int): method __iter__ (line 296) | def __iter__(self): FILE: audiocraft/utils/checkpoint.py class CheckpointSource (line 22) | class CheckpointSource(Enum): function checkpoint_name (line 28) | def checkpoint_name(name: tp.Optional[str] = None, rank: tp.Optional[int... function is_sharded_checkpoint (line 51) | def is_sharded_checkpoint(path: Path) -> bool: function resolve_checkpoint_path (line 56) | def resolve_checkpoint_path(sig_or_path: tp.Union[Path, str], name: tp.O... function load_checkpoint (line 87) | def load_checkpoint(checkpoint_path: Path, is_sharded: bool = False) -> ... function save_checkpoint (line 98) | def save_checkpoint(state: tp.Any, checkpoint_path: Path, is_sharded: bo... function flush_stale_checkpoints (line 104) | def flush_stale_checkpoints(checkpoint_path: Path, keep_last: tp.Optiona... function check_sharded_checkpoint (line 125) | def check_sharded_checkpoint(checkpoint_path: Path, rank0_checkpoint_pat... function _safe_save_checkpoint (line 142) | def _safe_save_checkpoint(state: tp.Any, checkpoint_path: Path, is_shard... FILE: audiocraft/utils/cluster.py class ClusterType (line 19) | class ClusterType(Enum): function _guess_cluster_type (line 27) | def _guess_cluster_type() -> ClusterType: function get_cluster_type (line 45) | def get_cluster_type( function get_slurm_parameters (line 54) | def get_slurm_parameters( FILE: audiocraft/utils/deadlock.py class DeadlockDetect (line 18) | class DeadlockDetect: method __init__ (line 19) | def __init__(self, use: bool = False, timeout: float = 120.): method update (line 24) | def update(self, stage: str): method __enter__ (line 28) | def __enter__(self): method __exit__ (line 33) | def __exit__(self, exc_type, exc_val, exc_tb): method _detector_thread (line 38) | def _detector_thread(self): FILE: audiocraft/utils/export.py function export_encodec (line 20) | def export_encodec(checkpoint_path: tp.Union[Path, str], out_file: tp.Un... function export_pretrained_compression_model (line 36) | def export_pretrained_compression_model(pretrained_encodec: str, out_fil... function export_lm (line 61) | def export_lm(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[P... FILE: audiocraft/utils/export_legacy.py function _clean_lm_cfg (line 18) | def _clean_lm_cfg(cfg: DictConfig): function export_encodec (line 33) | def export_encodec(checkpoint_path: tp.Union[Path, str], out_folder: tp.... function export_lm (line 46) | def export_lm(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union... FILE: audiocraft/utils/notebook.py function display_audio (line 17) | def display_audio(samples: torch.Tensor, sample_rate: int): FILE: audiocraft/utils/profiler.py class Profiler (line 17) | class Profiler: method __init__ (line 20) | def __init__(self, module: torch.nn.Module, enabled: bool = False): method step (line 28) | def step(self): method __enter__ (line 32) | def __enter__(self): method __exit__ (line 36) | def __exit__(self, exc_type, exc_value, exc_tb): FILE: audiocraft/utils/samples/manager.py class ReferenceSample (line 42) | class ReferenceSample: class Sample (line 49) | class Sample: method __hash__ (line 59) | def __hash__(self): method audio (line 62) | def audio(self) -> tp.Tuple[torch.Tensor, int]: method audio_prompt (line 65) | def audio_prompt(self) -> tp.Optional[tp.Tuple[torch.Tensor, int]]: method audio_reference (line 68) | def audio_reference(self) -> tp.Optional[tp.Tuple[torch.Tensor, int]]: class SampleManager (line 72) | class SampleManager: method __init__ (line 89) | def __init__(self, xp: dora.XP, map_reference_to_sample_id: bool = Fal... method latest_epoch (line 98) | def latest_epoch(self): method _load_samples (line 102) | def _load_samples(self): method _load_sample (line 110) | def _load_sample(json_file: Path) -> Sample: method _init_hash (line 126) | def _init_hash(self): method _get_tensor_id (line 129) | def _get_tensor_id(self, tensor: torch.Tensor) -> str: method _get_sample_id (line 134) | def _get_sample_id(self, index: int, prompt_wav: tp.Optional[torch.Ten... method _store_audio (line 173) | def _store_audio(self, wav: torch.Tensor, stem_path: Path, overwrite: ... method add_sample (line 196) | def add_sample(self, sample_wav: torch.Tensor, epoch: int, index: int ... method add_samples (line 238) | def add_samples(self, samples_wavs: torch.Tensor, epoch: int, method get_samples (line 269) | def get_samples(self, epoch: int = -1, max_epoch: int = -1, exclude_pr... function slugify (line 305) | def slugify(value: tp.Any, allow_unicode: bool = False): function _match_stable_samples (line 328) | def _match_stable_samples(samples_per_xp: tp.List[tp.Set[Sample]]) -> tp... function _match_unstable_samples (line 343) | def _match_unstable_samples(samples_per_xp: tp.List[tp.Set[Sample]]) -> ... function get_samples_for_xps (line 358) | def get_samples_for_xps(xps: tp.List[dora.XP], **kwargs) -> tp.Dict[str,... FILE: audiocraft/utils/ui.py class ToolButton (line 8) | class ToolButton(gr.Button, gr.components.IOComponent): method __init__ (line 11) | def __init__(self, **kwargs): method get_block_name (line 14) | def get_block_name(self): function create_refresh_button (line 18) | def create_refresh_button(refresh_component, refresh_method, refreshed_a... FILE: audiocraft/utils/utils.py function model_hash (line 26) | def model_hash(model: torch.nn.Module) -> str: function dict_from_config (line 36) | def dict_from_config(cfg: omegaconf.DictConfig) -> dict: function random_subset (line 49) | def random_subset(dataset, max_samples: int, seed: int = 42) -> torch.ut... function get_loader (line 58) | def get_loader(dataset, num_samples: tp.Optional[int], batch_size: int, function get_dataset_from_loader (line 81) | def get_dataset_from_loader(dataloader): function multinomial (line 89) | def multinomial(input: torch.Tensor, num_samples: int, replacement=False... function sample_top_k (line 109) | def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor: function sample_top_p (line 126) | def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor: class DummyPoolExecutor (line 145) | class DummyPoolExecutor: class DummyResult (line 149) | class DummyResult: method __init__ (line 150) | def __init__(self, func, *args, **kwargs): method result (line 155) | def result(self): method __init__ (line 158) | def __init__(self, workers, mp_context=None): method submit (line 161) | def submit(self, func, *args, **kwargs): method __enter__ (line 164) | def __enter__(self): method __exit__ (line 167) | def __exit__(self, exc_type, exc_value, exc_tb): function get_pool_executor (line 171) | def get_pool_executor(num_workers: int, mp_context=None): function length_to_mask (line 175) | def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = No... function hash_trick (line 191) | def hash_trick(word: str, vocab_size: int) -> int: function with_rank_rng (line 204) | def with_rank_rng(base_seed: int = 1234): function collate (line 227) | def collate(tensors: tp.List[torch.Tensor], dim: int = 0) -> tp.Tuple[to... function copy_state (line 251) | def copy_state(state: tp.Any, device: tp.Union[torch.device, str] = 'cpu', function swap_state (line 265) | def swap_state(model, state, **kwargs): function warn_once (line 275) | def warn_once(logger, msg): function is_jsonable (line 280) | def is_jsonable(x: tp.Any): function load_clap_state_dict (line 289) | def load_clap_state_dict(clap_model, path: tp.Union[str, Path]): FILE: demos/musicgen_app.py function _call_nostderr (line 39) | def _call_nostderr(*args, **kwargs): function interrupt (line 52) | def interrupt(): class FileCleaner (line 57) | class FileCleaner: method __init__ (line 58) | def __init__(self, file_lifetime: float = 3600): method add (line 62) | def add(self, path: tp.Union[str, Path]): method _cleanup (line 66) | def _cleanup(self): function make_waveform (line 80) | def make_waveform(*args, **kwargs): function load_model (line 90) | def load_model(version='facebook/musicgen-melody'): function load_diffusion (line 97) | def load_diffusion(): function _do_predictions (line 104) | def _do_predictions(texts, melodies, duration, progress=False, **gen_kwa... function predict_batched (line 154) | def predict_batched(texts, melodies): function predict_full (line 162) | def predict_full(model, decoder, text, melody, duration, topk, topp, tem... function toggle_audio_src (line 195) | def toggle_audio_src(choice): function toggle_diffusion (line 202) | def toggle_diffusion(choice): function ui_full (line 209) | def ui_full(launch_kwargs): function ui_batched (line 338) | def ui_batched(launch_kwargs): FILE: scripts/mos.py function normalize_path (line 43) | def normalize_path(path: Path): function get_full_path (line 51) | def get_full_path(normalized_path: Path): function get_signature (line 57) | def get_signature(xps: tp.List[str]): function ensure_logged (line 63) | def ensure_logged(func): function login (line 76) | def login(): function index (line 98) | def index(): function survey (line 135) | def survey(signature): function audio (line 236) | def audio(path: str): function mean (line 242) | def mean(x): function std (line 246) | def std(x): function results (line 253) | def results(signature): FILE: scripts/resample_dataset.py function read_txt_files (line 22) | def read_txt_files(path: tp.Union[str, Path]): function read_egs_files (line 31) | def read_egs_files(path: tp.Union[str, Path]): function process_dataset (line 45) | def process_dataset(args, n_shards: int, node_index: int, task_index: tp... FILE: tests/adversarial/test_discriminators.py class TestMultiPeriodDiscriminator (line 18) | class TestMultiPeriodDiscriminator: method test_mpd_discriminator (line 20) | def test_mpd_discriminator(self): class TestMultiScaleDiscriminator (line 33) | class TestMultiScaleDiscriminator: method test_msd_discriminator (line 35) | def test_msd_discriminator(self): class TestMultiScaleStftDiscriminator (line 49) | class TestMultiScaleStftDiscriminator: method test_msstftd_discriminator (line 51) | def test_msstftd_discriminator(self): FILE: tests/adversarial/test_losses.py class TestAdversarialLoss (line 22) | class TestAdversarialLoss: method test_adversarial_single_multidiscriminator (line 24) | def test_adversarial_single_multidiscriminator(self): method test_adversarial_feat_loss (line 45) | def test_adversarial_feat_loss(self): class TestGeneratorAdversarialLoss (line 65) | class TestGeneratorAdversarialLoss: method test_hinge_generator_adv_loss (line 67) | def test_hinge_generator_adv_loss(self): method test_mse_generator_adv_loss (line 76) | def test_mse_generator_adv_loss(self): class TestDiscriminatorAdversarialLoss (line 88) | class TestDiscriminatorAdversarialLoss: method _disc_loss (line 90) | def _disc_loss(self, loss_type: str, fake: torch.Tensor, real: torch.T... method test_hinge_discriminator_adv_loss (line 97) | def test_hinge_discriminator_adv_loss(self): method test_mse_discriminator_adv_loss (line 105) | def test_mse_discriminator_adv_loss(self): class TestFeatureMatchingLoss (line 115) | class TestFeatureMatchingLoss: method test_features_matching_loss_base (line 117) | def test_features_matching_loss_base(self): method test_features_matching_loss_raises_exception (line 126) | def test_features_matching_loss_raises_exception(self): method test_features_matching_loss_output (line 141) | def test_features_matching_loss_output(self): FILE: tests/common_utils/temp_utils.py class TempDirMixin (line 11) | class TempDirMixin: method get_base_temp_dir (line 18) | def get_base_temp_dir(cls): method tearDownClass (line 29) | def tearDownClass(cls): method id (line 43) | def id(self): method get_temp_path (line 46) | def get_temp_path(self, *paths): method get_temp_dir (line 52) | def get_temp_dir(self, *paths): FILE: tests/common_utils/wav_utils.py function get_white_noise (line 14) | def get_white_noise(chs: int = 1, num_frames: int = 1): function get_batch_white_noise (line 19) | def get_batch_white_noise(bs: int = 1, chs: int = 1, num_frames: int = 1): function save_wav (line 24) | def save_wav(path: str, wav: torch.Tensor, sample_rate: int): FILE: tests/data/test_audio.py class TestInfo (line 19) | class TestInfo(TempDirMixin): method test_info_mp3 (line 21) | def test_info_mp3(self): method _test_info_format (line 34) | def _test_info_format(self, ext: str): method test_info_wav (line 48) | def test_info_wav(self): method test_info_flac (line 51) | def test_info_flac(self): method test_info_ogg (line 54) | def test_info_ogg(self): method test_info_m4a (line 57) | def test_info_m4a(self): class TestRead (line 63) | class TestRead(TempDirMixin): method test_read_full_wav (line 65) | def test_read_full_wav(self): method test_read_partial_wav (line 80) | def test_read_partial_wav(self): method test_read_seek_time_wav (line 97) | def test_read_seek_time_wav(self): method test_read_seek_time_wav_padded (line 116) | def test_read_seek_time_wav_padded(self): class TestAvRead (line 139) | class TestAvRead(TempDirMixin): method test_avread_seek_base (line 141) | def test_avread_seek_base(self): method test_avread_seek_partial (line 159) | def test_avread_seek_partial(self): method test_avread_seek_outofbound (line 178) | def test_avread_seek_outofbound(self): method test_avread_seek_edge (line 193) | def test_avread_seek_edge(self): class TestAudioWrite (line 212) | class TestAudioWrite(TempDirMixin): method test_audio_write_wav (line 214) | def test_audio_write_wav(self): FILE: tests/data/test_audio_dataset.py class TestAudioMeta (line 31) | class TestAudioMeta(TempDirMixin): method test_get_audio_meta (line 33) | def test_get_audio_meta(self): method test_save_audio_meta (line 49) | def test_save_audio_meta(self): method test_load_audio_meta (line 65) | def test_load_audio_meta(self): class TestAudioDataset (line 90) | class TestAudioDataset(TempDirMixin): method _create_audio_files (line 92) | def _create_audio_files(self, method _create_audio_dataset (line 114) | def _create_audio_dataset(self, method test_dataset_full (line 135) | def test_dataset_full(self): method test_dataset_segment (line 152) | def test_dataset_segment(self): method test_dataset_equal_audio_and_segment_durations (line 170) | def test_dataset_equal_audio_and_segment_durations(self): method test_dataset_samples (line 192) | def test_dataset_samples(self): method test_dataset_return_info (line 218) | def test_dataset_return_info(self): method test_dataset_return_info_no_segment_duration (line 240) | def test_dataset_return_info_no_segment_duration(self): method test_dataset_collate_fn (line 260) | def test_dataset_collate_fn(self): method test_dataset_with_meta_collate_fn (line 280) | def test_dataset_with_meta_collate_fn(self, segment_duration): method test_sample_with_weight (line 308) | def test_sample_with_weight(self, segment_duration, sample_on_weight, ... method test_meta_duration_filter_all (line 333) | def test_meta_duration_filter_all(self): method test_meta_duration_filter_long (line 345) | def test_meta_duration_filter_long(self): FILE: tests/data/test_audio_utils.py class TestConvertAudioChannels (line 20) | class TestConvertAudioChannels: method test_convert_audio_channels_downmix (line 22) | def test_convert_audio_channels_downmix(self): method test_convert_audio_channels_nochange (line 28) | def test_convert_audio_channels_nochange(self): method test_convert_audio_channels_upmix (line 34) | def test_convert_audio_channels_upmix(self): method test_convert_audio_channels_upmix_error (line 40) | def test_convert_audio_channels_upmix_error(self): class TestConvertAudio (line 47) | class TestConvertAudio: method test_convert_audio_channels_downmix (line 49) | def test_convert_audio_channels_downmix(self): method test_convert_audio_channels_upmix (line 56) | def test_convert_audio_channels_upmix(self): method test_convert_audio_upsample (line 63) | def test_convert_audio_upsample(self): method test_convert_audio_resample (line 72) | def test_convert_audio_resample(self): class TestNormalizeAudio (line 82) | class TestNormalizeAudio: method test_clip_wav (line 84) | def test_clip_wav(self): method test_normalize_audio_clip (line 91) | def test_normalize_audio_clip(self): method test_normalize_audio_rms (line 98) | def test_normalize_audio_rms(self): method test_normalize_audio_peak (line 105) | def test_normalize_audio_peak(self): FILE: tests/losses/test_losses.py function test_mel_l1_loss (line 20) | def test_mel_l1_loss(): function test_msspec_loss (line 34) | def test_msspec_loss(): function test_mrstft_loss (line 48) | def test_mrstft_loss(): function test_sisnr_loss (line 59) | def test_sisnr_loss(): function test_stft_loss (line 70) | def test_stft_loss(): FILE: tests/models/test_audiogen.py class TestAudioGenModel (line 13) | class TestAudioGenModel: method get_audiogen (line 14) | def get_audiogen(self): method test_base (line 19) | def test_base(self): method test_generate_continuation (line 25) | def test_generate_continuation(self): method test_generate (line 41) | def test_generate(self): method test_generate_long (line 47) | def test_generate_long(self): FILE: tests/models/test_encodec_model.py class TestEncodecModel (line 17) | class TestEncodecModel: method _create_encodec_model (line 19) | def _create_encodec_model(self, method test_model (line 37) | def test_model(self): method test_model_renorm (line 48) | def test_model_renorm(self): FILE: tests/models/test_multibanddiffusion.py class TestMBD (line 18) | class TestMBD: method _create_mbd (line 20) | def _create_mbd(self, method test_model (line 43) | def test_model(self): FILE: tests/models/test_musicgen.py class TestMusicGenModel (line 13) | class TestMusicGenModel: method get_musicgen (line 14) | def get_musicgen(self): method test_base (line 19) | def test_base(self): method test_generate_unconditional (line 25) | def test_generate_unconditional(self): method test_generate_continuation (line 30) | def test_generate_continuation(self): method test_generate (line 46) | def test_generate(self): method test_generate_long (line 52) | def test_generate_long(self): FILE: tests/modules/test_activations.py class TestActivations (line 13) | class TestActivations: method test_custom_glu_calculation (line 14) | def test_custom_glu_calculation(self): FILE: tests/modules/test_codebooks_patterns.py class TestParallelPatternProvider (line 18) | class TestParallelPatternProvider: method test_get_pattern (line 22) | def test_get_pattern(self, n_q: int, timesteps: int): method test_pattern_content (line 30) | def test_pattern_content(self, n_q: int, timesteps: int): method test_pattern_max_delay (line 40) | def test_pattern_max_delay(self, n_q: int, timesteps: int): class TestDelayedPatternProvider (line 47) | class TestDelayedPatternProvider: method test_get_pattern (line 51) | def test_get_pattern(self, n_q: int, timesteps: int): method test_pattern_content (line 65) | def test_pattern_content(self, n_q: int, timesteps: int): method test_pattern_max_delay (line 75) | def test_pattern_max_delay(self, timesteps: int, delay: list): class TestUnrolledPatternProvider (line 82) | class TestUnrolledPatternProvider: method test_get_pattern (line 87) | def test_get_pattern(self, timesteps: int, flattening: list, delays: l... method test_pattern_max_delay (line 97) | def test_pattern_max_delay(self, timesteps: int, flattening: list, del... class TestPattern (line 105) | class TestPattern: method ref_build_pattern_sequence (line 107) | def ref_build_pattern_sequence(self, z: torch.Tensor, pattern: Pattern... method ref_revert_pattern_sequence (line 121) | def ref_revert_pattern_sequence(self, z: torch.Tensor, pattern: Patter... method ref_revert_pattern_logits (line 134) | def ref_revert_pattern_logits(self, z: torch.Tensor, pattern: Pattern,... method _get_pattern_providers (line 149) | def _get_pattern_providers(self, n_q: int): method test_build_pattern_sequence (line 173) | def test_build_pattern_sequence(self, n_q: int, timesteps: int): method test_revert_pattern_sequence (line 205) | def test_revert_pattern_sequence(self, n_q: int, timesteps: int): method test_revert_pattern_logits (line 228) | def test_revert_pattern_logits(self, n_q: int, timesteps: int, card: i... FILE: tests/modules/test_conv.py function test_get_extra_padding_for_conv1d (line 25) | def test_get_extra_padding_for_conv1d(): function test_pad1d_zeros (line 30) | def test_pad1d_zeros(): function test_pad1d_reflect (line 52) | def test_pad1d_reflect(): function test_unpad1d (line 74) | def test_unpad1d(): class TestNormConv1d (line 96) | class TestNormConv1d: method test_norm_conv1d_modules (line 98) | def test_norm_conv1d_modules(self): class TestNormConvTranspose1d (line 123) | class TestNormConvTranspose1d: method test_normalizations (line 125) | def test_normalizations(self): class TestStreamableConv1d (line 151) | class TestStreamableConv1d: method get_streamable_conv1d_output_length (line 153) | def get_streamable_conv1d_output_length(self, length, kernel_size, str... method test_streamable_conv1d (line 160) | def test_streamable_conv1d(self): class TestStreamableConvTranspose1d (line 176) | class TestStreamableConvTranspose1d: method get_streamable_convtr1d_output_length (line 178) | def get_streamable_convtr1d_output_length(self, length, kernel_size, s... method test_streamable_convtr1d (line 182) | def test_streamable_convtr1d(self): FILE: tests/modules/test_lstm.py class TestStreamableLSTM (line 13) | class TestStreamableLSTM: method test_lstm (line 15) | def test_lstm(self): method test_lstm_skip (line 25) | def test_lstm_skip(self): FILE: tests/modules/test_rope.py function test_rope (line 13) | def test_rope(): function test_rope_io_dtypes (line 26) | def test_rope_io_dtypes(): function test_transformer_with_rope (line 50) | def test_transformer_with_rope(): function test_rope_streaming (line 66) | def test_rope_streaming(): function test_rope_streaming_past_context (line 94) | def test_rope_streaming_past_context(): function test_rope_memory_efficient (line 124) | def test_rope_memory_efficient(): function test_rope_with_xpos (line 145) | def test_rope_with_xpos(): function test_positional_scale (line 158) | def test_positional_scale(): FILE: tests/modules/test_seanet.py class TestSEANetModel (line 16) | class TestSEANetModel: method test_base (line 18) | def test_base(self): method test_causal (line 28) | def test_causal(self): method test_conv_skip_connection (line 38) | def test_conv_skip_connection(self): method test_seanet_encoder_decoder_final_act (line 48) | def test_seanet_encoder_decoder_final_act(self): method _check_encoder_blocks_norm (line 58) | def _check_encoder_blocks_norm(self, encoder: SEANetEncoder, n_disable... method test_encoder_disable_norm (line 70) | def test_encoder_disable_norm(self): method _check_decoder_blocks_norm (line 79) | def _check_decoder_blocks_norm(self, decoder: SEANetDecoder, n_disable... method test_decoder_disable_norm (line 94) | def test_decoder_disable_norm(self): method test_disable_norm_raises_exception (line 103) | def test_disable_norm_raises_exception(self): FILE: tests/modules/test_transformer.py function test_transformer_causal_streaming (line 16) | def test_transformer_causal_streaming(): function test_transformer_vs_pytorch (line 52) | def test_transformer_vs_pytorch(): function test_streaming_api (line 71) | def test_streaming_api(): function test_memory_efficient (line 88) | def test_memory_efficient(): function test_attention_as_float32 (line 108) | def test_attention_as_float32(): function test_streaming_memory_efficient (line 134) | def test_streaming_memory_efficient(): function test_cross_attention (line 164) | def test_cross_attention(): function test_cross_attention_compat (line 192) | def test_cross_attention_compat(): function test_repeat_kv (line 224) | def test_repeat_kv(): function test_qk_layer_norm (line 241) | def test_qk_layer_norm(): FILE: tests/quantization/test_vq.py class TestResidualVectorQuantizer (line 12) | class TestResidualVectorQuantizer: method test_rvq (line 14) | def test_rvq(self):