SYMBOL INDEX (1517 symbols across 145 files) FILE: audiocraft/adversarial/discriminators/base.py class MultiDiscriminator (line 19) | class MultiDiscriminator(ABC, nn.Module): method __init__ (line 22) | def __init__(self): method forward (line 26) | def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType: method num_discriminators (line 31) | def num_discriminators(self) -> int: FILE: audiocraft/adversarial/discriminators/mpd.py function get_padding (line 17) | def get_padding(kernel_size: int, dilation: int = 1) -> int: class PeriodDiscriminator (line 21) | class PeriodDiscriminator(nn.Module): method __init__ (line 38) | def __init__(self, period: int, in_channels: int = 1, out_channels: in... method forward (line 58) | def forward(self, x: torch.Tensor): class MultiPeriodDiscriminator (line 79) | class MultiPeriodDiscriminator(MultiDiscriminator): method __init__ (line 88) | def __init__(self, in_channels: int = 1, out_channels: int = 1, method num_discriminators (line 96) | def num_discriminators(self): method forward (line 99) | def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType: FILE: audiocraft/adversarial/discriminators/msd.py class ScaleDiscriminator (line 17) | class ScaleDiscriminator(nn.Module): method __init__ (line 37) | def __init__(self, in_channels=1, out_channels=1, kernel_sizes: tp.Seq... method forward (line 83) | def forward(self, x: torch.Tensor): class MultiScaleDiscriminator (line 95) | class MultiScaleDiscriminator(MultiDiscriminator): method __init__ (line 105) | def __init__(self, in_channels: int = 1, out_channels: int = 1, downsa... method num_discriminators (line 114) | def num_discriminators(self): method forward (line 117) | def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType: FILE: audiocraft/adversarial/discriminators/msstftd.py function get_2d_padding (line 18) | def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[i... class DiscriminatorSTFT (line 22) | class DiscriminatorSTFT(nn.Module): method __init__ (line 41) | def __init__(self, filters: int, in_channels: int = 1, out_channels: i... method forward (line 81) | def forward(self, x: torch.Tensor): class MultiScaleSTFTDiscriminator (line 94) | class MultiScaleSTFTDiscriminator(MultiDiscriminator): method __init__ (line 107) | def __init__(self, filters: int, in_channels: int = 1, out_channels: i... method num_discriminators (line 120) | def num_discriminators(self): method _separate_channels (line 123) | def _separate_channels(self, x: torch.Tensor) -> torch.Tensor: method forward (line 127) | def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType: FILE: audiocraft/adversarial/losses.py class AdversarialLoss (line 26) | class AdversarialLoss(nn.Module): method __init__ (line 49) | def __init__(self, method _save_to_state_dict (line 67) | def _save_to_state_dict(self, destination, prefix, keep_vars): method _load_from_state_dict (line 73) | def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): method get_adversary_pred (line 78) | def get_adversary_pred(self, x): method train_adv (line 89) | def train_adv(self, fake: torch.Tensor, real: torch.Tensor) -> torch.T... method forward (line 115) | def forward(self, fake: torch.Tensor, real: torch.Tensor) -> tp.Tuple[... function get_adv_criterion (line 138) | def get_adv_criterion(loss_type: str) -> tp.Callable: function get_fake_criterion (line 149) | def get_fake_criterion(loss_type: str) -> tp.Callable: function get_real_criterion (line 158) | def get_real_criterion(loss_type: str) -> tp.Callable: function mse_real_loss (line 167) | def mse_real_loss(x: torch.Tensor) -> torch.Tensor: function mse_fake_loss (line 171) | def mse_fake_loss(x: torch.Tensor) -> torch.Tensor: function hinge_real_loss (line 175) | def hinge_real_loss(x: torch.Tensor) -> torch.Tensor: function hinge_fake_loss (line 179) | def hinge_fake_loss(x: torch.Tensor) -> torch.Tensor: function mse_loss (line 183) | def mse_loss(x: torch.Tensor) -> torch.Tensor: function hinge_loss (line 189) | def hinge_loss(x: torch.Tensor) -> torch.Tensor: function hinge2_loss (line 195) | def hinge2_loss(x: torch.Tensor) -> torch.Tensor: class FeatureMatchingLoss (line 201) | class FeatureMatchingLoss(nn.Module): method __init__ (line 209) | def __init__(self, loss: nn.Module = torch.nn.L1Loss(), normalize: boo... method forward (line 214) | def forward(self, fmap_fake: tp.List[torch.Tensor], fmap_real: tp.List... FILE: audiocraft/data/audio.py function _init_av (line 31) | def _init_av(): class AudioFileInfo (line 41) | class AudioFileInfo: function _av_info (line 47) | def _av_info(filepath: tp.Union[str, Path]) -> AudioFileInfo: function _soundfile_info (line 57) | def _soundfile_info(filepath: tp.Union[str, Path]) -> AudioFileInfo: function audio_info (line 62) | def audio_info(filepath: tp.Union[str, Path]) -> AudioFileInfo: function _av_read (line 72) | def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, durati... function audio_read (line 116) | def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0., function _piping_to_ffmpeg (line 147) | def _piping_to_ffmpeg(out_path: tp.Union[str, Path], wav: torch.Tensor, ... function audio_write (line 159) | def audio_write(stem_name: tp.Union[str, Path], function get_spec (line 234) | def get_spec(y, sr=16000, n_fft=4096, hop_length=128, dur=8) -> np.ndarray: function save_spectrograms (line 256) | def save_spectrograms( FILE: audiocraft/data/audio_dataset.py class BaseInfo (line 39) | class BaseInfo: method _dict2fields (line 42) | def _dict2fields(cls, dictionary: dict): method from_dict (line 49) | def from_dict(cls, dictionary: dict): method to_dict (line 53) | def to_dict(self): class AudioMeta (line 61) | class AudioMeta(BaseInfo): method from_dict (line 71) | def from_dict(cls, dictionary: dict): method to_dict (line 77) | def to_dict(self): class SegmentInfo (line 85) | class SegmentInfo(BaseInfo): function _get_audio_meta (line 101) | def _get_audio_meta(file_path: str, minimal: bool = True) -> AudioMeta: function _resolve_audio_meta (line 118) | def _resolve_audio_meta(m: AudioMeta, fast: bool = True) -> AudioMeta: function find_audio_files (line 145) | def find_audio_files(path: tp.Union[Path, str], function load_audio_meta (line 204) | def load_audio_meta(path: tp.Union[str, Path], function save_audio_meta (line 228) | def save_audio_meta(path: tp.Union[str, Path], meta: tp.List[AudioMeta]): class AudioDataset (line 244) | class AudioDataset: method __init__ (line 295) | def __init__(self, method start_epoch (line 350) | def start_epoch(self, epoch: int): method __len__ (line 353) | def __len__(self): method _get_sampling_probabilities (line 356) | def _get_sampling_probabilities(self, normalized: bool = True): method _get_file_permutation (line 373) | def _get_file_permutation(num_files: int, permutation_index: int, base... method sample_file (line 380) | def sample_file(self, index: int, rng: torch.Generator) -> AudioMeta: method _audio_read (line 404) | def _audio_read(self, path: str, seek_time: float = 0, duration: float... method __getitem__ (line 413) | def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[t... method collater (line 462) | def collater(self, samples): method _filter_duration (line 502) | def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioM... method from_meta (line 524) | def from_meta(cls, root: tp.Union[str, Path], **kwargs): method from_path (line 544) | def from_path(cls, root: tp.Union[str, Path], minimal_meta: bool = True, function main (line 562) | def main(): FILE: audiocraft/data/audio_utils.py function convert_audio_channels (line 21) | def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torc... function convert_audio (line 54) | def convert_audio(wav: torch.Tensor, from_rate: float, function normalize_loudness (line 62) | def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_hea... function _clip_wav (line 91) | def _clip_wav(wav: torch.Tensor, log_clipping: bool = False, stem_name: ... function normalize_audio (line 103) | def normalize_audio(wav: torch.Tensor, normalize: bool = True, function f32_pcm (line 155) | def f32_pcm(wav: torch.Tensor) -> torch.Tensor: function i16_pcm (line 172) | def i16_pcm(wav: torch.Tensor) -> torch.Tensor: function compress (line 195) | def compress(wav: torch.Tensor, sr: int, function get_mp3 (line 233) | def get_mp3(wav_tensor: torch.Tensor, sr: int, bitrate: str = "128k") ->... function get_aac (line 274) | def get_aac( FILE: audiocraft/data/info_audio_dataset.py function _clusterify_meta (line 25) | def _clusterify_meta(meta: AudioMeta) -> AudioMeta: function clusterify_all_meta (line 33) | def clusterify_all_meta(meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]: class AudioInfo (line 39) | class AudioInfo(SegmentWithAttributes): method to_condition_attributes (line 50) | def to_condition_attributes(self) -> ConditioningAttributes: class InfoAudioDataset (line 54) | class InfoAudioDataset(AudioDataset): method __init__ (line 59) | def __init__(self, meta: tp.List[AudioMeta], **kwargs): method __getitem__ (line 62) | def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[t... function get_keyword_or_keyword_list (line 71) | def get_keyword_or_keyword_list(value: tp.Optional[str]) -> tp.Union[tp.... function get_string (line 79) | def get_string(value: tp.Optional[str]) -> tp.Optional[str]: function get_keyword (line 87) | def get_keyword(value: tp.Optional[str]) -> tp.Optional[str]: function get_keyword_list (line 95) | def get_keyword_list(values: tp.Union[str, tp.List[str]]) -> tp.Optional... FILE: audiocraft/data/jasco_dataset.py class JascoInfo (line 23) | class JascoInfo(MusicInfo): method to_condition_attributes (line 32) | def to_condition_attributes(self) -> ConditioningAttributes: class MelodyData (line 50) | class MelodyData: method __init__ (line 55) | def __init__(self, method load_saliency_from_saliency_dict (line 112) | def load_saliency_from_saliency_dict(self, method get_null_salience (line 150) | def get_null_salience(self) -> torch.Tensor: method __call__ (line 153) | def __call__(self, x: MusicInfo) -> torch.Tensor: class JascoDataset (line 173) | class JascoDataset(MusicDataset): method from_meta (line 183) | def from_meta(cls, root: tp.Union[str, Path], **kwargs): method __init__ (line 210) | def __init__(self, *args, method _get_relevant_sublist (line 239) | def _get_relevant_sublist(self, chords, timestamp): method _get_chords (line 269) | def _get_chords(self, music_info: MusicInfo, effective_segment_dur: fl... method __getitem__ (line 296) | def __getitem__(self, index): FILE: audiocraft/data/music_dataset.py class MusicInfo (line 37) | class MusicInfo(AudioInfo): method has_music_meta (line 57) | def has_music_meta(self) -> bool: method to_condition_attributes (line 60) | def to_condition_attributes(self) -> ConditioningAttributes: method attribute_getter (line 76) | def attribute_getter(attribute): method from_dict (line 92) | def from_dict(cls, dictionary: dict, fields_required: bool = False): function augment_music_info_description (line 115) | def augment_music_info_description(music_info: MusicInfo, merge_text_p: ... class Paraphraser (line 167) | class Paraphraser: method __init__ (line 168) | def __init__(self, paraphrase_source: tp.Union[str, Path], paraphrase_... method sample_paraphrase (line 175) | def sample_paraphrase(self, audio_path: str, description: str): class MusicDataset (line 187) | class MusicDataset(InfoAudioDataset): method __init__ (line 204) | def __init__(self, *args, info_fields_required: bool = True, method __getitem__ (line 220) | def __getitem__(self, index): function get_musical_key (line 252) | def get_musical_key(value: tp.Optional[str]) -> tp.Optional[str]: function get_bpm (line 263) | def get_bpm(value: tp.Optional[str]) -> tp.Optional[float]: FILE: audiocraft/data/sound_dataset.py class SoundInfo (line 35) | class SoundInfo(SegmentWithAttributes): method has_sound_meta (line 42) | def has_sound_meta(self) -> bool: method to_condition_attributes (line 45) | def to_condition_attributes(self) -> ConditioningAttributes: method attribute_getter (line 57) | def attribute_getter(attribute): method from_dict (line 65) | def from_dict(cls, dictionary: dict, fields_required: bool = False): class SoundDataset (line 87) | class SoundDataset(InfoAudioDataset): method __init__ (line 104) | def __init__( method _get_info_path (line 129) | def _get_info_path(self, path: tp.Union[str, Path]) -> Path: method __getitem__ (line 142) | def __getitem__(self, index): method collater (line 163) | def collater(self, samples): function rms_f (line 173) | def rms_f(x: torch.Tensor) -> torch.Tensor: function normalize (line 177) | def normalize(audio: torch.Tensor, target_level: int = -25) -> torch.Ten... function is_clipped (line 185) | def is_clipped(audio: torch.Tensor, clipping_threshold: float = 0.99) ->... function mix_pair (line 189) | def mix_pair(src: torch.Tensor, dst: torch.Tensor, min_overlap: float) -... function snr_mixer (line 199) | def snr_mixer(clean: torch.Tensor, noise: torch.Tensor, snr: int, min_ov... function snr_mix (line 252) | def snr_mix(src: torch.Tensor, dst: torch.Tensor, snr_low: int, snr_high... function mix_text (line 261) | def mix_text(src_text: str, dst_text: str): function mix_samples (line 268) | def mix_samples(wavs: torch.Tensor, infos: tp.List[SoundInfo], aug_p: fl... FILE: audiocraft/data/zip.py class PathInZip (line 22) | class PathInZip: method __init__ (line 36) | def __init__(self, path: str) -> None: method from_paths (line 42) | def from_paths(cls, zip_path: str, file_path: str): method __str__ (line 45) | def __str__(self) -> str: function _open_zip (line 49) | def _open_zip(path: str, mode: MODE = 'r'): function set_zip_cache_size (line 56) | def set_zip_cache_size(max_size: int): function open_file_in_zip (line 66) | def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') -> typing.IO: FILE: audiocraft/environment.py class AudioCraftEnvironment (line 25) | class AudioCraftEnvironment: method __init__ (line 49) | def __init__(self) -> None: method _get_cluster_config (line 74) | def _get_cluster_config(self) -> omegaconf.DictConfig: method instance (line 79) | def instance(cls): method reset (line 85) | def reset(cls): method get_team (line 90) | def get_team(cls) -> str: method get_cluster (line 97) | def get_cluster(cls) -> str: method get_dora_dir (line 104) | def get_dora_dir(cls) -> Path: method get_reference_dir (line 114) | def get_reference_dir(cls) -> Path: method get_slurm_exclude (line 122) | def get_slurm_exclude(cls) -> tp.Optional[str]: method get_slurm_partitions (line 128) | def get_slurm_partitions(cls, partition_types: tp.Optional[tp.List[str... method resolve_reference_path (line 146) | def resolve_reference_path(cls, path: tp.Union[str, Path]) -> Path: method apply_dataset_mappers (line 167) | def apply_dataset_mappers(cls, path: str) -> str: FILE: audiocraft/grids/_base_explorers.py function get_sheep_ping (line 14) | def get_sheep_ping(sheep) -> tp.Optional[str]: class BaseExplorer (line 31) | class BaseExplorer(ABC, Explorer): method stages (line 40) | def stages(self): method get_grid_meta (line 43) | def get_grid_meta(self): method get_grid_metrics (line 55) | def get_grid_metrics(self): method process_sheep (line 60) | def process_sheep(self, sheep, history): FILE: audiocraft/grids/audiogen/audiogen_base_16khz.py function explorer (line 12) | def explorer(launcher): FILE: audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.py function eval (line 26) | def eval(launcher, batch_size: int = 32): function explorer (line 49) | def explorer(launcher): FILE: audiocraft/grids/compression/_explorers.py class CompressionExplorer (line 12) | class CompressionExplorer(BaseExplorer): method stages (line 15) | def stages(self): method get_grid_meta (line 18) | def get_grid_meta(self): method get_grid_metrics (line 28) | def get_grid_metrics(self): FILE: audiocraft/grids/compression/debug.py function explorer (line 22) | def explorer(launcher): FILE: audiocraft/grids/compression/encodec_audiogen_16khz.py function explorer (line 20) | def explorer(launcher): FILE: audiocraft/grids/compression/encodec_base_24khz.py function explorer (line 20) | def explorer(launcher): FILE: audiocraft/grids/compression/encodec_musicgen_32khz.py function explorer (line 20) | def explorer(launcher): FILE: audiocraft/grids/diffusion/4_bands_base_32khz.py function explorer (line 17) | def explorer(launcher): FILE: audiocraft/grids/diffusion/_explorers.py class DiffusionExplorer (line 12) | class DiffusionExplorer(BaseExplorer): method stages (line 15) | def stages(self): method get_grid_meta (line 18) | def get_grid_meta(self): method get_grid_metrics (line 28) | def get_grid_metrics(self): FILE: audiocraft/grids/magnet/audio_magnet_16khz.py function explorer (line 12) | def explorer(launcher): FILE: audiocraft/grids/magnet/audio_magnet_pretrained_16khz_eval.py function eval (line 26) | def eval(launcher, batch_size: int = 32): function explorer (line 47) | def explorer(launcher): FILE: audiocraft/grids/magnet/magnet_32khz.py function explorer (line 12) | def explorer(launcher): FILE: audiocraft/grids/magnet/magnet_pretrained_32khz_eval.py function eval (line 26) | def eval(launcher, batch_size: int = 32): function explorer (line 47) | def explorer(launcher): FILE: audiocraft/grids/musicgen/_explorers.py class LMExplorer (line 14) | class LMExplorer(BaseExplorer): method stages (line 17) | def stages(self) -> tp.List[str]: method get_grid_metrics (line 20) | def get_grid_metrics(self): method process_sheep (line 45) | def process_sheep(self, sheep, history): class GenerationEvalExplorer (line 69) | class GenerationEvalExplorer(BaseExplorer): method stages (line 72) | def stages(self) -> tp.List[str]: method get_grid_metrics (line 75) | def get_grid_metrics(self): FILE: audiocraft/grids/musicgen/musicgen_base_32khz.py function explorer (line 12) | def explorer(launcher): FILE: audiocraft/grids/musicgen/musicgen_base_cached_32khz.py function explorer (line 12) | def explorer(launcher): FILE: audiocraft/grids/musicgen/musicgen_clapemb_32khz.py function explorer (line 12) | def explorer(launcher): FILE: audiocraft/grids/musicgen/musicgen_melody_32khz.py function explorer (line 12) | def explorer(launcher): FILE: audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py function eval (line 26) | def eval(launcher, batch_size: int = 32, eval_melody: bool = False): function explorer (line 63) | def explorer(launcher): FILE: audiocraft/grids/musicgen/musicgen_stereo_finetune_32khz.py function explorer (line 13) | def explorer(launcher): FILE: audiocraft/grids/musicgen/musicgen_style_32khz.py function explorer (line 11) | def explorer(launcher): FILE: audiocraft/grids/watermarking/_explorers.py class WatermarkingMbExplorer (line 12) | class WatermarkingMbExplorer(BaseExplorer): method stages (line 15) | def stages(self): method get_grid_meta (line 18) | def get_grid_meta(self): method get_grid_metrics (line 27) | def get_grid_metrics(self): class WatermarkingExplorer (line 66) | class WatermarkingExplorer(BaseExplorer): method stages (line 69) | def stages(self): method get_grid_meta (line 72) | def get_grid_meta(self): method get_grid_metrics (line 81) | def get_grid_metrics(self): FILE: audiocraft/grids/watermarking/audioseal.py function explorer (line 15) | def explorer(launcher): FILE: audiocraft/grids/watermarking/kbits.py function explorer (line 16) | def explorer(launcher): FILE: audiocraft/losses/balancer.py class Balancer (line 14) | class Balancer: method __init__ (line 61) | def __init__(self, weights: tp.Dict[str, float], balance_grads: bool =... method metrics (line 74) | def metrics(self): method backward (line 77) | def backward(self, losses: tp.Dict[str, torch.Tensor], input: torch.Te... FILE: audiocraft/losses/loudnessloss.py function basic_loudness (line 18) | def basic_loudness(waveform: torch.Tensor, sample_rate: int) -> torch.Te... function _unfold (line 53) | def _unfold(a: torch.Tensor, kernel_size: int, stride: int) -> torch.Ten... class FLoudnessRatio (line 69) | class FLoudnessRatio(nn.Module): method __init__ (line 82) | def __init__( method forward (line 101) | def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> tor... class TLoudnessRatio (line 114) | class TLoudnessRatio(nn.Module): method __init__ (line 125) | def __init__( method forward (line 137) | def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> tor... class TFLoudnessRatio (line 153) | class TFLoudnessRatio(nn.Module): method __init__ (line 166) | def __init__( method forward (line 187) | def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> tor... FILE: audiocraft/losses/sisnr.py function _unfold (line 15) | def _unfold(a: torch.Tensor, kernel_size: int, stride: int) -> torch.Ten... function _center (line 31) | def _center(x: torch.Tensor) -> torch.Tensor: function _norm2 (line 35) | def _norm2(x: torch.Tensor) -> torch.Tensor: class SISNR (line 39) | class SISNR(nn.Module): method __init__ (line 56) | def __init__( method forward (line 69) | def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> tor... FILE: audiocraft/losses/specloss.py class MelSpectrogramWrapper (line 18) | class MelSpectrogramWrapper(nn.Module): method __init__ (line 35) | def __init__(self, n_fft: int = 1024, hop_length: int = 256, win_lengt... method forward (line 48) | def forward(self, x): class MelSpectrogramL1Loss (line 65) | class MelSpectrogramL1Loss(torch.nn.Module): method __init__ (line 80) | def __init__(self, sample_rate: int, n_fft: int = 1024, hop_length: in... method forward (line 89) | def forward(self, x, y): class MultiScaleMelSpectrogramLoss (line 96) | class MultiScaleMelSpectrogramLoss(nn.Module): method __init__ (line 110) | def __init__(self, sample_rate: int, range_start: int = 6, range_end: ... method forward (line 137) | def forward(self, x, y): FILE: audiocraft/losses/stftloss.py function _stft (line 17) | def _stft(x: torch.Tensor, fft_size: int, hop_length: int, win_length: int, class SpectralConvergenceLoss (line 45) | class SpectralConvergenceLoss(nn.Module): method __init__ (line 48) | def __init__(self, epsilon: float = torch.finfo(torch.float32).eps): method forward (line 52) | def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor): class LogSTFTMagnitudeLoss (line 64) | class LogSTFTMagnitudeLoss(nn.Module): method __init__ (line 70) | def __init__(self, epsilon: float = torch.finfo(torch.float32).eps): method forward (line 74) | def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor): class STFTLosses (line 86) | class STFTLosses(nn.Module): method __init__ (line 97) | def __init__(self, n_fft: int = 1024, hop_length: int = 120, win_lengt... method forward (line 109) | def forward(self, x: torch.Tensor, y: torch.Tensor) -> tp.Tuple[torch.... class STFTLoss (line 129) | class STFTLoss(nn.Module): method __init__ (line 142) | def __init__(self, n_fft: int = 1024, hop_length: int = 120, win_lengt... method forward (line 151) | def forward(self, x: torch.Tensor, y: torch.Tensor) -> tp.Tuple[torch.... class MRSTFTLoss (line 164) | class MRSTFTLoss(nn.Module): method __init__ (line 177) | def __init__(self, n_ffts: tp.Sequence[int] = [1024, 2048, 512], hop_l... method forward (line 189) | def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: FILE: audiocraft/losses/wmloss.py class WMDetectionLoss (line 13) | class WMDetectionLoss(nn.Module): method __init__ (line 15) | def __init__(self, p_weight: float = 1.0, n_weight: float = 1.0) -> None: method forward (line 21) | def forward(self, positive, negative, mask, message=None): class WMMbLoss (line 55) | class WMMbLoss(nn.Module): method __init__ (line 56) | def __init__(self, temperature: float, loss_type: Literal["bce", "mse"... method forward (line 73) | def forward(self, positive, negative, mask, message): FILE: audiocraft/metrics/chroma_cosinesim.py class ChromaCosineSimilarityMetric (line 14) | class ChromaCosineSimilarityMetric(torchmetrics.Metric): method __init__ (line 28) | def __init__(self, sample_rate: int, n_chroma: int, radix2_exp: int, a... method update (line 38) | def update(self, preds: torch.Tensor, targets: torch.Tensor, method compute (line 69) | def compute(self) -> float: FILE: audiocraft/metrics/clap_consistency.py class TextConsistencyMetric (line 24) | class TextConsistencyMetric(torchmetrics.Metric): method update (line 27) | def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch... method compute (line 30) | def compute(self): class CLAPTextConsistencyMetric (line 34) | class CLAPTextConsistencyMetric(TextConsistencyMetric): method __init__ (line 47) | def __init__(self, model_path: tp.Union[str, Path], model_arch: str = ... method _initialize_model (line 55) | def _initialize_model(self, model_path: tp.Union[str, Path], model_arc... method _tokenizer (line 63) | def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict: method update (line 67) | def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch... method compute (line 81) | def compute(self): FILE: audiocraft/metrics/fad.py class FrechetAudioDistanceMetric (line 29) | class FrechetAudioDistanceMetric(torchmetrics.Metric): method __init__ (line 145) | def __init__(self, bin: tp.Union[Path, str], model_path: tp.Union[Path... method reset (line 167) | def reset(self, log_folder: tp.Optional[tp.Union[Path, str]] = None): method update (line 182) | def update(self, preds: torch.Tensor, targets: torch.Tensor, method _get_samples_name (line 222) | def _get_samples_name(self, is_background: bool): method _create_embedding_beams (line 225) | def _create_embedding_beams(self, is_background: bool, gpu_index: tp.O... method _compute_fad_score (line 259) | def _compute_fad_score(self, gpu_index: tp.Optional[int] = None): method _log_process_result (line 283) | def _log_process_result(self, returncode: int, log_file: tp.Union[Path... method _parallel_create_embedding_beams (line 293) | def _parallel_create_embedding_beams(self, num_of_gpus: int): method _sequential_create_embedding_beams (line 303) | def _sequential_create_embedding_beams(self): method _local_compute_frechet_audio_distance (line 313) | def _local_compute_frechet_audio_distance(self): method compute (line 323) | def compute(self) -> float: FILE: audiocraft/metrics/kld.py class _patch_passt_stft (line 22) | class _patch_passt_stft: method __init__ (line 24) | def __init__(self): method __enter__ (line 27) | def __enter__(self): method __exit__ (line 32) | def __exit__(self, *exc): function kl_divergence (line 36) | def kl_divergence(pred_probs: torch.Tensor, target_probs: torch.Tensor, ... class KLDivergenceMetric (line 53) | class KLDivergenceMetric(torchmetrics.Metric): method __init__ (line 62) | def __init__(self): method _get_label_distribution (line 69) | def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor, method update (line 82) | def update(self, preds: torch.Tensor, targets: torch.Tensor, method compute (line 105) | def compute(self) -> dict: class PasstKLDivergenceMetric (line 116) | class PasstKLDivergenceMetric(KLDivergenceMetric): method __init__ (line 131) | def __init__(self, pretrained_length: tp.Optional[float] = None): method _initialize_model (line 135) | def _initialize_model(self, pretrained_length: tp.Optional[float] = No... method _load_base_model (line 145) | def _load_base_model(self, pretrained_length: tp.Optional[float]): method _process_audio (line 172) | def _process_audio(self, wav: torch.Tensor, sample_rate: int, wav_len:... method _get_model_preds (line 187) | def _get_model_preds(self, wav: torch.Tensor) -> torch.Tensor: method _get_label_distribution (line 198) | def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor, FILE: audiocraft/metrics/miou.py function calculate_miou (line 10) | def calculate_miou(y_pred: torch.Tensor, y_true: torch.Tensor) -> float: FILE: audiocraft/metrics/pesq.py class PesqMetric (line 14) | class PesqMetric(torchmetrics.Metric): method __init__ (line 23) | def __init__(self, sample_rate: int): method update (line 30) | def update(self, preds: torch.Tensor, targets: torch.Tensor): method compute (line 45) | def compute(self) -> torch.Tensor: FILE: audiocraft/metrics/rvm.py function db_to_scale (line 13) | def db_to_scale(volume: tp.Union[float, torch.Tensor]): function scale_to_db (line 17) | def scale_to_db(scale: torch.Tensor, min_volume: float = -120): class RelativeVolumeMel (line 22) | class RelativeVolumeMel(nn.Module): method __init__ (line 69) | def __init__(self, sample_rate: int = 24000, n_mels: int = 80, n_fft: ... method forward (line 84) | def forward(self, estimate: torch.Tensor, ground_truth: torch.Tensor) ... FILE: audiocraft/metrics/visqol.py class ViSQOL (line 22) | class ViSQOL: method __init__ (line 56) | def __init__(self, bin: tp.Union[Path, str], mode: str = "audio", method _get_target_sr (line 67) | def _get_target_sr(self, mode: str) -> int: method _prepare_files (line 75) | def _prepare_files( method _flush_files (line 132) | def _flush_files(self, tmp_dir: tp.Union[Path, str]): method _collect_moslqo_score (line 136) | def _collect_moslqo_score(self, results_csv_path: tp.Union[Path, str])... method _collect_debug_data (line 146) | def _collect_debug_data(self, debug_json_path: tp.Union[Path, str]) ->... method visqol_model (line 153) | def visqol_model(self): method _run_visqol (line 156) | def _run_visqol( method __call__ (line 181) | def __call__( FILE: audiocraft/models/audiogen.py class AudioGen (line 23) | class AudioGen(BaseGenModel): method __init__ (line 34) | def __init__(self, name: str, compression_model: CompressionModel, lm:... method get_pretrained (line 40) | def get_pretrained(name: str = 'facebook/audiogen-medium', device=None): method set_generation_params (line 63) | def set_generation_params(self, use_sampling: bool = True, top_k: int ... FILE: audiocraft/models/builders.py function get_quantizer (line 44) | def get_quantizer( function get_encodec_autoencoder (line 56) | def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig): function get_compression_model (line 70) | def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel: function get_jasco_model (line 94) | def get_jasco_model(cfg: omegaconf.DictConfig, function get_lm_model (line 136) | def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel: function get_conditioner_provider (line 178) | def get_conditioner_provider( function get_condition_fuser (line 230) | def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser: function get_codebooks_pattern_provider (line 240) | def get_codebooks_pattern_provider( function get_debug_compression_model (line 257) | def get_debug_compression_model(device="cpu", sample_rate: int = 32000): function get_diffusion_model (line 291) | def get_diffusion_model(cfg: omegaconf.DictConfig): function get_processor (line 298) | def get_processor(cfg, sample_rate: int = 24000): function get_debug_lm_model (line 309) | def get_debug_lm_model(device="cpu"): function get_wrapped_compression_model (line 338) | def get_wrapped_compression_model( function get_watermark_model (line 354) | def get_watermark_model(cfg: omegaconf.DictConfig) -> WMModel: FILE: audiocraft/models/encodec.py class CompressionModel (line 28) | class CompressionModel(ABC, nn.Module): method forward (line 34) | def forward(self, x: torch.Tensor) -> qt.QuantizedResult: method encode (line 38) | def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optiona... method decode (line 43) | def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor]... method decode_latent (line 48) | def decode_latent(self, codes: torch.Tensor): method channels (line 54) | def channels(self) -> int: method frame_rate (line 59) | def frame_rate(self) -> float: method sample_rate (line 64) | def sample_rate(self) -> int: method cardinality (line 69) | def cardinality(self) -> int: method num_codebooks (line 74) | def num_codebooks(self) -> int: method total_codebooks (line 79) | def total_codebooks(self) -> int: method set_num_codebooks (line 83) | def set_num_codebooks(self, n: int): method get_pretrained (line 88) | def get_pretrained( class EncodecModel (line 125) | class EncodecModel(CompressionModel): method __init__ (line 144) | def __init__(self, method total_codebooks (line 168) | def total_codebooks(self): method num_codebooks (line 173) | def num_codebooks(self): method set_num_codebooks (line 177) | def set_num_codebooks(self, n: int): method cardinality (line 182) | def cardinality(self): method preprocess (line 186) | def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Opt... method postprocess (line 198) | def postprocess(self, method forward (line 206) | def forward(self, x: torch.Tensor) -> qt.QuantizedResult: method encode (line 223) | def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optiona... method decode (line 240) | def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor]... method decode_latent (line 257) | def decode_latent(self, codes: torch.Tensor): class DAC (line 262) | class DAC(CompressionModel): method __init__ (line 263) | def __init__(self, model_type: str = "44khz"): method forward (line 274) | def forward(self, x: torch.Tensor) -> qt.QuantizedResult: method encode (line 278) | def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optiona... method decode (line 282) | def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor]... method decode_latent (line 287) | def decode_latent(self, codes: torch.Tensor): method channels (line 292) | def channels(self) -> int: method frame_rate (line 296) | def frame_rate(self) -> float: method sample_rate (line 300) | def sample_rate(self) -> int: method cardinality (line 304) | def cardinality(self) -> int: method num_codebooks (line 308) | def num_codebooks(self) -> int: method total_codebooks (line 312) | def total_codebooks(self) -> int: method set_num_codebooks (line 315) | def set_num_codebooks(self, n: int): class HFEncodecCompressionModel (line 323) | class HFEncodecCompressionModel(CompressionModel): method __init__ (line 326) | def __init__(self, model: HFEncodecModel): method forward (line 340) | def forward(self, x: torch.Tensor) -> qt.QuantizedResult: method encode (line 344) | def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optiona... method decode (line 352) | def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor]... method decode_latent (line 360) | def decode_latent(self, codes: torch.Tensor): method channels (line 365) | def channels(self) -> int: method frame_rate (line 369) | def frame_rate(self) -> float: method sample_rate (line 374) | def sample_rate(self) -> int: method cardinality (line 378) | def cardinality(self) -> int: method num_codebooks (line 382) | def num_codebooks(self) -> int: method total_codebooks (line 386) | def total_codebooks(self) -> int: method set_num_codebooks (line 389) | def set_num_codebooks(self, n: int): class InterleaveStereoCompressionModel (line 397) | class InterleaveStereoCompressionModel(CompressionModel): method __init__ (line 409) | def __init__(self, model: CompressionModel, per_timestep: bool = False): method total_codebooks (line 416) | def total_codebooks(self): method num_codebooks (line 420) | def num_codebooks(self): method set_num_codebooks (line 428) | def set_num_codebooks(self, n: int): method num_virtual_steps (line 436) | def num_virtual_steps(self) -> float: method frame_rate (line 443) | def frame_rate(self) -> float: method sample_rate (line 447) | def sample_rate(self) -> int: method channels (line 451) | def channels(self) -> int: method cardinality (line 455) | def cardinality(self): method forward (line 460) | def forward(self, x: torch.Tensor) -> qt.QuantizedResult: method encode (line 463) | def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optiona... method get_left_right_codes (line 481) | def get_left_right_codes(self, codes: torch.Tensor) -> tp.Tuple[torch.... method decode (line 488) | def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor]... method decode_latent (line 504) | def decode_latent(self, codes: torch.Tensor): FILE: audiocraft/models/flow_matching.py class FMOutput (line 35) | class FMOutput: class CFGTerm (line 40) | class CFGTerm: method __init__ (line 48) | def __init__(self, conditions, weight): method drop_irrelevant_conds (line 52) | def drop_irrelevant_conds(self, conditions): class AllCFGTerm (line 63) | class AllCFGTerm(CFGTerm): method __init__ (line 67) | def __init__(self, conditions, weight): method drop_irrelevant_conds (line 71) | def drop_irrelevant_conds(self): class NullCFGTerm (line 75) | class NullCFGTerm(CFGTerm): method __init__ (line 79) | def __init__(self, conditions, weight): method drop_irrelevant_conds (line 83) | def drop_irrelevant_conds(self): class TextCFGTerm (line 92) | class TextCFGTerm(CFGTerm): method __init__ (line 97) | def __init__(self, conditions, weight, model_att_dropout): method drop_irrelevant_conds (line 116) | def drop_irrelevant_conds(self): class FlowMatchingModel (line 121) | class FlowMatchingModel(StreamingModule): method __init__ (line 150) | def __init__(self, condition_provider: JascoConditioningProvider, method _get_timestep_embedding (line 209) | def _get_timestep_embedding(self, timesteps, embedding_dim): method _embed_time_parameter (line 232) | def _embed_time_parameter(self, t: torch.Tensor): method _init_weights (line 244) | def _init_weights(self, weight_init: tp.Optional[str], depthwise_init:... method _align_seq_length (line 276) | def _align_seq_length(self, method forward (line 289) | def forward(self, method _multi_source_cfg_preprocess (line 345) | def _multi_source_cfg_preprocess(self, method estimated_vector_field (line 386) | def estimated_vector_field(self, z, t, condition_tensors=None, cfg_ter... method _multi_source_cfg_postprocess (line 403) | def _multi_source_cfg_postprocess(self, v_thetas, cfg_terms): method generate (line 419) | def generate(self, FILE: audiocraft/models/genmodel.py class BaseGenModel (line 28) | class BaseGenModel(ABC): method __init__ (line 39) | def __init__(self, name: str, compression_model: CompressionModel, lm:... method frame_rate (line 81) | def frame_rate(self) -> float: method sample_rate (line 86) | def sample_rate(self) -> int: method audio_channels (line 91) | def audio_channels(self) -> int: method set_custom_progress_callback (line 95) | def set_custom_progress_callback(self, progress_callback: tp.Optional[... method set_generation_params (line 100) | def set_generation_params(self, *args, **kwargs): method get_pretrained (line 106) | def get_pretrained(name: str, device=None): method _prepare_tokens_and_attributes (line 110) | def _prepare_tokens_and_attributes( method generate_unconditional (line 135) | def generate_unconditional(self, num_samples: int, progress: bool = Fa... method generate (line 151) | def generate(self, descriptions: tp.List[str], progress: bool = False,... method generate_continuation (line 166) | def generate_continuation(self, prompt: torch.Tensor, prompt_sample_ra... method _generate_tokens (line 193) | def _generate_tokens(self, attributes: tp.List[ConditioningAttributes], method generate_audio (line 262) | def generate_audio(self, gen_tokens: torch.Tensor) -> torch.Tensor: FILE: audiocraft/models/jasco.py class JASCO (line 24) | class JASCO(BaseGenModel): method __init__ (line 30) | def __init__(self, chords_mapping_path='assets/chord_to_index_mapping.... method get_pretrained (line 43) | def get_pretrained(name: str = 'facebook/jasco-chords-drums-400M', dev... method set_generation_params (line 66) | def set_generation_params(self, method _unnormalized_latents (line 85) | def _unnormalized_latents(self, latents: torch.Tensor) -> torch.Tensor: method generate_audio (line 91) | def generate_audio(self, gen_latents: torch.Tensor) -> torch.Tensor: method _generate_tokens (line 99) | def _generate_tokens(self, attributes: tp.List[ConditioningAttributes], method _prepare_chord_conditions (line 137) | def _prepare_chord_conditions( method _prepare_drums_conditions (line 176) | def _prepare_drums_conditions(self, method _prepare_melody_conditions (line 214) | def _prepare_melody_conditions( method _prepare_temporal_conditions (line 240) | def _prepare_temporal_conditions( method generate_music (line 269) | def generate_music( method generate (line 318) | def generate(self, descriptions: tp.List[str], progress: bool = False,... FILE: audiocraft/models/lm.py function get_init_fn (line 37) | def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int... function init_layer (line 65) | def init_layer(m: nn.Module, class ScaledEmbedding (line 98) | class ScaledEmbedding(nn.Embedding): method __init__ (line 101) | def __init__(self, *args, lr=None, **kwargs): method make_optim_group (line 105) | def make_optim_group(self): class LMOutput (line 113) | class LMOutput: class LMModel (line 120) | class LMModel(StreamingModule): method __init__ (line 145) | def __init__(self, pattern_provider: CodebooksPatternProvider, conditi... method _init_weights (line 179) | def _init_weights(self, weight_init: tp.Optional[str], depthwise_init:... method special_token_id (line 214) | def special_token_id(self) -> int: method num_codebooks (line 218) | def num_codebooks(self) -> int: method forward (line 221) | def forward(self, sequence: torch.Tensor, method compute_predictions (line 270) | def compute_predictions( method _sample_next_token (line 323) | def _sample_next_token(self, method generate (line 421) | def generate(self, FILE: audiocraft/models/lm_magnet.py class MagnetLMModel (line 26) | class MagnetLMModel(LMModel): method __init__ (line 37) | def __init__(self, subcodes_context: int = 5, compression_model_framer... method restricted_context_attn_mask (line 48) | def restricted_context_attn_mask(self, seq_len: int, device: torch.dev... method _stage_attn_mask (line 69) | def _stage_attn_mask(self, stage: int, seq_len: int, num_heads: int, method _build_attn_masks (line 102) | def _build_attn_masks(self, compression_model_framerate: int, segment_... method generate (line 118) | def generate(self, method _generate_magnet (line 152) | def _generate_magnet(self, method _generate_stage (line 265) | def _generate_stage(self, method _construct_spans_mask (line 442) | def _construct_spans_mask(self, span_starts: torch.Tensor, T: int, dev... method _least_probable_span_masking (line 461) | def _least_probable_span_masking(self, scores: torch.Tensor, num_maske... FILE: audiocraft/models/loaders.py function get_audiocraft_cache_dir (line 36) | def get_audiocraft_cache_dir() -> tp.Optional[str]: function _get_state_dict (line 40) | def _get_state_dict( function load_compression_model_ckpt (line 74) | def load_compression_model_ckpt(file_or_url_or_id: tp.Union[Path, str], ... function load_compression_model (line 78) | def load_compression_model( function load_lm_model_ckpt (line 94) | def load_lm_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir... function _delete_param (line 98) | def _delete_param(cfg: DictConfig, full_name: str): function load_lm_model (line 111) | def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', ... function load_lm_model_magnet (line 129) | def load_lm_model_magnet(file_or_url_or_id: tp.Union[Path, str], compres... function load_jasco_model (line 158) | def load_jasco_model(file_or_url_or_id: tp.Union[Path, str], function load_mbd_ckpt (line 175) | def load_mbd_ckpt(file_or_url_or_id: tp.Union[Path, str], function load_diffusion_models (line 181) | def load_diffusion_models(file_or_url_or_id: tp.Union[Path, str], function load_audioseal_models (line 206) | def load_audioseal_models( FILE: audiocraft/models/magnet.py class MAGNeT (line 18) | class MAGNeT(BaseGenModel): method __init__ (line 23) | def __init__(self, **kwargs): method get_pretrained (line 30) | def get_pretrained(name: str = 'facebook/magnet-small-10secs', device=... method set_generation_params (line 60) | def set_generation_params(self, use_sampling: bool = True, top_k: int ... FILE: audiocraft/models/multibanddiffusion.py class DiffusionProcess (line 25) | class DiffusionProcess: method __init__ (line 32) | def __init__(self, model: DiffusionUnet, noise_schedule: NoiseSchedule... method generate (line 36) | def generate(self, condition: torch.Tensor, initial_noise: torch.Tensor, class MultiBandDiffusion (line 48) | class MultiBandDiffusion: method __init__ (line 55) | def __init__(self, DPs: tp.List[DiffusionProcess], codec_model: Compre... method sample_rate (line 61) | def sample_rate(self) -> int: method get_mbd_musicgen (line 65) | def get_mbd_musicgen(device=None): method get_mbd_24khz (line 81) | def get_mbd_24khz(bw: float = 3.0, method get_condition (line 113) | def get_condition(self, wav: torch.Tensor, sample_rate: int) -> torch.... method get_emb (line 126) | def get_emb(self, codes: torch.Tensor): method generate (line 133) | def generate(self, emb: torch.Tensor, size: tp.Optional[torch.Size] = ... method re_eq (line 151) | def re_eq(self, wav: torch.Tensor, ref: torch.Tensor, n_bands: int = 3... method regenerate (line 167) | def regenerate(self, wav: torch.Tensor, sample_rate: int): method tokens_to_wav (line 182) | def tokens_to_wav(self, tokens: torch.Tensor, n_bands: int = 32): FILE: audiocraft/models/musicgen.py class MusicGen (line 40) | class MusicGen(BaseGenModel): method __init__ (line 51) | def __init__(self, name: str, compression_model: CompressionModel, lm:... method get_pretrained (line 57) | def get_pretrained(name: str = 'facebook/musicgen-melody', device=None): method set_generation_params (line 96) | def set_generation_params(self, use_sampling: bool = True, top_k: int ... method set_style_conditioner_params (line 134) | def set_style_conditioner_params(self, eval_q: int = 3, excerpt_length... method generate_with_chroma (line 155) | def generate_with_chroma(self, descriptions: tp.List[str], melody_wavs... method _prepare_tokens_and_attributes (line 194) | def _prepare_tokens_and_attributes( method _generate_tokens (line 251) | def _generate_tokens(self, attributes: tp.List[ConditioningAttributes], FILE: audiocraft/models/unet.py class Output (line 21) | class Output: function get_model (line 25) | def get_model(cfg, channels: int, side: int, num_steps: int): class ResBlock (line 33) | class ResBlock(nn.Module): method __init__ (line 34) | def __init__(self, channels: int, kernel: int = 3, norm_groups: int = 4, method forward (line 52) | def forward(self, x): class DecoderLayer (line 58) | class DecoderLayer(nn.Module): method __init__ (line 59) | def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int... method forward (line 72) | def forward(self, x: torch.Tensor) -> torch.Tensor: class EncoderLayer (line 80) | class EncoderLayer(nn.Module): method __init__ (line 81) | def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int... method forward (line 94) | def forward(self, x: torch.Tensor) -> torch.Tensor: class BLSTM (line 107) | class BLSTM(nn.Module): method __init__ (line 110) | def __init__(self, dim, layers=2): method forward (line 115) | def forward(self, x): class DiffusionUnet (line 123) | class DiffusionUnet(nn.Module): method __init__ (line 124) | def __init__(self, chin: int = 3, hidden: int = 24, depth: int = 3, gr... method forward (line 163) | def forward(self, x: torch.Tensor, step: tp.Union[int, torch.Tensor], ... FILE: audiocraft/models/watermark.py class WMModel (line 17) | class WMModel(ABC, nn.Module): method get_watermark (line 24) | def get_watermark( method detect_watermark (line 36) | def detect_watermark(self, x: torch.Tensor) -> torch.Tensor: class AudioSeal (line 49) | class AudioSeal(WMModel): method __init__ (line 54) | def __init__( method get_watermark (line 67) | def get_watermark( method detect_watermark (line 75) | def detect_watermark(self, x: torch.Tensor) -> torch.Tensor: method forward (line 93) | def forward( # generator method get_pretrained (line 105) | def get_pretrained(name="base", device=None) -> WMModel: FILE: audiocraft/modules/activations.py class CustomGLU (line 13) | class CustomGLU(nn.Module): method __init__ (line 33) | def __init__(self, activation: nn.Module, dim: int = -1): method forward (line 38) | def forward(self, x: Tensor): class SwiGLU (line 44) | class SwiGLU(CustomGLU): method __init__ (line 52) | def __init__(self, dim: int = -1): class GeGLU (line 56) | class GeGLU(CustomGLU): method __init__ (line 64) | def __init__(self, dim: int = -1): class ReGLU (line 68) | class ReGLU(CustomGLU): method __init__ (line 76) | def __init__(self, dim: int = -1): function get_activation_fn (line 80) | def get_activation_fn( FILE: audiocraft/modules/chroma.py class ChromaExtractor (line 16) | class ChromaExtractor(nn.Module): method __init__ (line 29) | def __init__(self, sample_rate: int, n_chroma: int = 12, radix2_exp: i... method forward (line 46) | def forward(self, wav: torch.Tensor) -> torch.Tensor: FILE: audiocraft/modules/codebooks_patterns.py class Pattern (line 22) | class Pattern: method __post_init__ (line 50) | def __post_init__(self): method _validate_layout (line 57) | def _validate_layout(self): method num_sequence_steps (line 79) | def num_sequence_steps(self): method max_delay (line 83) | def max_delay(self): method valid_layout (line 91) | def valid_layout(self): method starts_with_special_token (line 95) | def starts_with_special_token(self): method get_sequence_coords_with_timestep (line 98) | def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int... method get_steps_with_timestep (line 113) | def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) ... method get_first_step_with_timesteps (line 116) | def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = ... method _build_pattern_sequence_scatter_indexes (line 120) | def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q:... method build_pattern_sequence (line 154) | def build_pattern_sequence(self, z: torch.Tensor, special_token: int, ... method _build_reverted_sequence_scatter_indexes (line 181) | def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int... method revert_pattern_sequence (line 225) | def revert_pattern_sequence(self, s: torch.Tensor, special_token: int,... method revert_pattern_logits (line 250) | def revert_pattern_logits(self, logits: torch.Tensor, special_token: f... class CodebooksPatternProvider (line 272) | class CodebooksPatternProvider(ABC): method __init__ (line 290) | def __init__(self, n_q: int, cached: bool = True): method get_pattern (line 296) | def get_pattern(self, timesteps: int) -> Pattern: class DelayedPatternProvider (line 305) | class DelayedPatternProvider(CodebooksPatternProvider): method __init__ (line 328) | def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None, method get_pattern (line 339) | def get_pattern(self, timesteps: int) -> Pattern: class ParallelPatternProvider (line 359) | class ParallelPatternProvider(DelayedPatternProvider): method __init__ (line 368) | def __init__(self, n_q: int, empty_initial: int = 0): class UnrolledPatternProvider (line 372) | class UnrolledPatternProvider(CodebooksPatternProvider): method __init__ (line 423) | def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = N... method _build_flattened_codebooks (line 437) | def _build_flattened_codebooks(self, delays: tp.List[int], flattening:... method _num_inner_steps (line 457) | def _num_inner_steps(self): method num_virtual_steps (line 462) | def num_virtual_steps(self, timesteps: int) -> int: method get_pattern (line 465) | def get_pattern(self, timesteps: int) -> Pattern: class CoarseFirstPattern (line 493) | class CoarseFirstPattern(CodebooksPatternProvider): method __init__ (line 507) | def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None): method get_pattern (line 515) | def get_pattern(self, timesteps: int) -> Pattern: class MusicLMPattern (line 530) | class MusicLMPattern(CodebooksPatternProvider): method __init__ (line 538) | def __init__(self, n_q: int, group_by: int = 2): method get_pattern (line 542) | def get_pattern(self, timesteps: int) -> Pattern: FILE: audiocraft/modules/conditioners.py class JascoCondConst (line 46) | class JascoCondConst(Enum): class WavCondition (line 55) | class WavCondition(tp.NamedTuple): class JointEmbedCondition (line 63) | class JointEmbedCondition(tp.NamedTuple): class SymbolicCondition (line 72) | class SymbolicCondition(tp.NamedTuple): class ConditioningAttributes (line 78) | class ConditioningAttributes: method __getitem__ (line 84) | def __getitem__(self, item): method text_attributes (line 88) | def text_attributes(self): method wav_attributes (line 92) | def wav_attributes(self): method joint_embed_attributes (line 96) | def joint_embed_attributes(self): method symbolic_attributes (line 100) | def symbolic_attributes(self): method attributes (line 104) | def attributes(self): method to_flat_dict (line 112) | def to_flat_dict(self): method from_flat_dict (line 121) | def from_flat_dict(cls, x): class SegmentWithAttributes (line 129) | class SegmentWithAttributes(SegmentInfo): method to_condition_attributes (line 134) | def to_condition_attributes(self) -> ConditioningAttributes: function nullify_condition (line 138) | def nullify_condition(condition: ConditionType, dim: int = 1): function nullify_wav (line 165) | def nullify_wav(cond: WavCondition) -> WavCondition: function nullify_joint_embed (line 184) | def nullify_joint_embed(embed: JointEmbedCondition) -> JointEmbedCondition: function nullify_chords (line 201) | def nullify_chords(sym_cond: SymbolicCondition, null_chord_idx: int = 19... function nullify_melody (line 212) | def nullify_melody(sym_cond: SymbolicCondition) -> SymbolicCondition: function _drop_description_condition (line 223) | def _drop_description_condition(conditions: tp.List[ConditioningAttribut... class Tokenizer (line 239) | class Tokenizer: method __call__ (line 243) | def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch... class WhiteSpaceTokenizer (line 247) | class WhiteSpaceTokenizer(Tokenizer): method __init__ (line 256) | def __init__(self, n_bins: int, pad_idx: int = 0, language: str = "en_... method __call__ (line 269) | def __call__(self, texts: tp.List[tp.Optional[str]], class NoopTokenizer (line 315) | class NoopTokenizer(Tokenizer): method __init__ (line 325) | def __init__(self, n_bins: int, pad_idx: int = 0): method __call__ (line 329) | def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch... class BaseConditioner (line 345) | class BaseConditioner(nn.Module): method __init__ (line 355) | def __init__(self, dim: int, output_dim: int): method tokenize (line 362) | def tokenize(self, *args, **kwargs) -> tp.Any: method forward (line 370) | def forward(self, inputs: tp.Any) -> ConditionType: class TextConditioner (line 383) | class TextConditioner(BaseConditioner): class LUTConditioner (line 387) | class LUTConditioner(TextConditioner): method __init__ (line 397) | def __init__(self, n_bins: int, dim: int, output_dim: int, tokenizer: ... method tokenize (line 408) | def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Ten... method forward (line 414) | def forward(self, inputs: tp.Tuple[torch.Tensor, torch.Tensor]) -> Con... class T5Conditioner (line 422) | class T5Conditioner(TextConditioner): method __init__ (line 450) | def __init__(self, name: str, output_dim: int, finetune: bool, device:... method tokenize (line 490) | def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch... method forward (line 509) | def forward(self, inputs: tp.Dict[str, torch.Tensor]) -> ConditionType: class WaveformConditioner (line 518) | class WaveformConditioner(BaseConditioner): method __init__ (line 529) | def __init__(self, dim: int, output_dim: int, device: tp.Union[torch.d... method tokenize (line 535) | def tokenize(self, x: WavCondition) -> WavCondition: method _get_wav_embedding (line 540) | def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor: method _downsampling_factor (line 544) | def _downsampling_factor(self): method forward (line 548) | def forward(self, x: WavCondition) -> ConditionType: class ChromaStemConditioner (line 571) | class ChromaStemConditioner(WaveformConditioner): method __init__ (line 593) | def __init__(self, output_dim: int, sample_rate: int, n_chroma: int, r... method _downsampling_factor (line 618) | def _downsampling_factor(self) -> int: method _load_eval_wavs (line 621) | def _load_eval_wavs(self, path: tp.Optional[str], num_samples: int) ->... method reset_eval_wavs (line 642) | def reset_eval_wavs(self, eval_wavs: tp.Optional[torch.Tensor]) -> None: method has_eval_wavs (line 645) | def has_eval_wavs(self) -> bool: method _sample_eval_wavs (line 648) | def _sample_eval_wavs(self, num_samples: int) -> torch.Tensor: method _get_chroma_len (line 657) | def _get_chroma_len(self) -> int: method _get_stemmed_wav (line 664) | def _get_stemmed_wav(self, wav: torch.Tensor, sample_rate: int) -> tor... method _extract_chroma (line 678) | def _extract_chroma(self, wav: torch.Tensor) -> torch.Tensor: method _compute_wav_embedding (line 684) | def _compute_wav_embedding(self, wav: torch.Tensor, sample_rate: int) ... method _get_full_chroma_for_cache (line 694) | def _get_full_chroma_for_cache(self, path: tp.Union[str, Path], x: Wav... method _extract_chroma_chunk (line 702) | def _extract_chroma_chunk(self, full_chroma: torch.Tensor, x: WavCondi... method _get_wav_embedding (line 718) | def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor: method tokenize (line 752) | def tokenize(self, x: WavCondition) -> WavCondition: class FeatureExtractor (line 762) | class FeatureExtractor(WaveformConditioner): method __init__ (line 790) | def __init__( method _get_wav_embedding (line 827) | def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor: method _downsampling_factor (line 854) | def _downsampling_factor(self): method _get_mask_wav (line 860) | def _get_mask_wav(self, x: WavCondition, start: int) -> tp.Union[torch... class StyleConditioner (line 872) | class StyleConditioner(FeatureExtractor): method __init__ (line 897) | def __init__(self, transformer_scale: str = 'default', ds_factor: int ... method _get_wav_embedding (line 937) | def _get_wav_embedding(self, wav: WavCondition) -> torch.Tensor: method set_params (line 970) | def set_params(self, eval_q: int = 3, method _downsampling_factor (line 987) | def _downsampling_factor(self): method forward (line 991) | def forward(self, x: WavCondition) -> ConditionType: class JointEmbeddingConditioner (line 1006) | class JointEmbeddingConditioner(BaseConditioner): method __init__ (line 1020) | def __init__(self, dim: int, output_dim: int, device: str, attribute: ... method _get_embed (line 1039) | def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor,... method forward (line 1048) | def forward(self, x: JointEmbedCondition) -> ConditionType: method tokenize (line 1063) | def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition: class CLAPEmbeddingConditioner (line 1067) | class CLAPEmbeddingConditioner(JointEmbeddingConditioner): method __init__ (line 1094) | def __init__(self, dim: int, output_dim: int, device: str, attribute: ... method _tokenizer (line 1135) | def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict: method _compute_text_embedding (line 1139) | def _compute_text_embedding(self, text: tp.List[str]) -> torch.Tensor: method _get_text_embedding_for_cache (line 1151) | def _get_text_embedding_for_cache(self, path: tp.Union[Path, str], method _preprocess_wav (line 1158) | def _preprocess_wav(self, wav: torch.Tensor, length: torch.Tensor, sam... method _compute_wav_embedding (line 1179) | def _compute_wav_embedding(self, wav: torch.Tensor, length: torch.Tensor, method _get_wav_embedding_for_cache (line 1214) | def _get_wav_embedding_for_cache(self, path: tp.Union[str, Path], method _extract_wav_embedding_chunk (line 1230) | def _extract_wav_embedding_chunk(self, full_embed: torch.Tensor, x: Jo... method _get_text_embedding (line 1251) | def _get_text_embedding(self, x: JointEmbedCondition) -> torch.Tensor: method _get_wav_embedding (line 1265) | def _get_wav_embedding(self, x: JointEmbedCondition) -> torch.Tensor: method tokenize (line 1278) | def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition: method _get_embed (line 1291) | def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor,... function dropout_symbolic_conditions (line 1304) | def dropout_symbolic_conditions(sample: ConditioningAttributes, function dropout_condition (line 1337) | def dropout_condition(sample: ConditioningAttributes, class DropoutModule (line 1372) | class DropoutModule(nn.Module): method __init__ (line 1374) | def __init__(self, seed: int = 1234): class AttributeDropout (line 1380) | class AttributeDropout(DropoutModule): method __init__ (line 1397) | def __init__(self, p: tp.Dict[str, tp.Dict[str, float]], active_on_eva... method forward (line 1405) | def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List... method __repr__ (line 1423) | def __repr__(self): class ClassifierFreeGuidanceDropout (line 1427) | class ClassifierFreeGuidanceDropout(DropoutModule): method __init__ (line 1435) | def __init__(self, p: float, seed: int = 1234): method forward (line 1439) | def forward(self, samples: tp.List[ConditioningAttributes], method __repr__ (line 1465) | def __repr__(self): class ConditioningProvider (line 1469) | class ConditioningProvider(nn.Module): method __init__ (line 1476) | def __init__(self, conditioners: tp.Dict[str, BaseConditioner], device... method joint_embed_conditions (line 1482) | def joint_embed_conditions(self): method has_joint_embed_conditions (line 1486) | def has_joint_embed_conditions(self): method text_conditions (line 1490) | def text_conditions(self): method wav_conditions (line 1494) | def wav_conditions(self): method has_wav_condition (line 1498) | def has_wav_condition(self): method tokenize (line 1501) | def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict... method forward (line 1529) | def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, Con... method _collate_text (line 1547) | def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> t... method _collate_wavs (line 1574) | def _collate_wavs(self, samples: tp.List[ConditioningAttributes]) -> t... method _collate_joint_embeds (line 1618) | def _collate_joint_embeds(self, samples: tp.List[ConditioningAttribute... class ConditionFuser (line 1672) | class ConditionFuser(StreamingModule): method __init__ (line 1689) | def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attent... method forward (line 1703) | def forward( FILE: audiocraft/modules/conv.py function apply_parametrization_norm (line 21) | def apply_parametrization_norm(module: nn.Module, norm: str = 'none'): function get_norm_module (line 33) | def get_norm_module(module: nn.Module, causal: bool = False, norm: str =... function get_extra_padding_for_conv1d (line 47) | def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stri... function pad_for_conv1d (line 56) | def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, paddi... function pad1d (line 71) | def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'co... function unpad1d (line 91) | def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): class NormConv1d (line 100) | class NormConv1d(nn.Module): method __init__ (line 104) | def __init__(self, *args, causal: bool = False, norm: str = 'none', method forward (line 111) | def forward(self, x): class NormConv2d (line 117) | class NormConv2d(nn.Module): method __init__ (line 121) | def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str... method forward (line 127) | def forward(self, x): class NormConvTranspose1d (line 133) | class NormConvTranspose1d(nn.Module): method __init__ (line 137) | def __init__(self, *args, causal: bool = False, norm: str = 'none', method forward (line 144) | def forward(self, x): class NormConvTranspose2d (line 150) | class NormConvTranspose2d(nn.Module): method __init__ (line 154) | def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str... method forward (line 159) | def forward(self, x): class StreamableConv1d (line 165) | class StreamableConv1d(nn.Module): method __init__ (line 169) | def __init__(self, in_channels: int, out_channels: int, method forward (line 185) | def forward(self, x): class StreamableConvTranspose1d (line 204) | class StreamableConvTranspose1d(nn.Module): method __init__ (line 208) | def __init__(self, in_channels: int, out_channels: int, method forward (line 221) | def forward(self, x): FILE: audiocraft/modules/diffusion_schedule.py function betas_from_alpha_bar (line 20) | def betas_from_alpha_bar(alpha_bar): class SampleProcessor (line 25) | class SampleProcessor(torch.nn.Module): method project_sample (line 26) | def project_sample(self, x: torch.Tensor): method return_sample (line 30) | def return_sample(self, z: torch.Tensor): class MultiBandProcessor (line 35) | class MultiBandProcessor(SampleProcessor): method __init__ (line 57) | def __init__(self, n_bands: int = 8, sample_rate: float = 24_000, method mean (line 77) | def mean(self): method std (line 82) | def std(self): method target_std (line 87) | def target_std(self): method project_sample (line 91) | def project_sample(self, x: torch.Tensor): method return_sample (line 104) | def return_sample(self, x: torch.Tensor): class NoiseSchedule (line 112) | class NoiseSchedule: method __init__ (line 127) | def __init__(self, beta_t0: float = 1e-4, beta_t1: float = 0.02, num_s... method get_beta (line 149) | def get_beta(self, step: tp.Union[int, torch.Tensor]): method get_initial_noise (line 155) | def get_initial_noise(self, x: torch.Tensor): method get_alpha_bar (line 160) | def get_alpha_bar(self, step: tp.Optional[tp.Union[int, torch.Tensor]]... method get_training_item (line 169) | def get_training_item(self, x: torch.Tensor, tensor_step: bool = False... method generate (line 192) | def generate(self, model: torch.nn.Module, initial: tp.Optional[torch.... method generate_subsampled (line 238) | def generate_subsampled(self, model: torch.nn.Module, initial: torch.T... FILE: audiocraft/modules/jasco_conditioners.py class MelodyConditioner (line 15) | class MelodyConditioner(BaseConditioner): method __init__ (line 23) | def __init__(self, card: int, out_dim: int, device: tp.Union[torch.dev... method tokenize (line 27) | def tokenize(self, x: SymbolicCondition) -> SymbolicCondition: method forward (line 30) | def forward(self, x: SymbolicCondition) -> ConditionType: class ChordsEmbConditioner (line 36) | class ChordsEmbConditioner(BaseConditioner): method __init__ (line 44) | def __init__(self, card: int, out_dim: int, device: tp.Union[torch.dev... method tokenize (line 50) | def tokenize(self, x: SymbolicCondition) -> SymbolicCondition: method forward (line 53) | def forward(self, x: SymbolicCondition) -> ConditionType: class DrumsConditioner (line 59) | class DrumsConditioner(WaveformConditioner): method __init__ (line 60) | def __init__(self, out_dim: int, sample_rate: int, blurring_factor: in... method create_embedding_cache (line 93) | def create_embedding_cache(self, cache_path): method _get_drums_stem (line 100) | def _get_drums_stem(self, wav: torch.Tensor, sample_rate: int) -> torc... method _temporal_blur (line 111) | def _temporal_blur(self, z: torch.Tensor): method _extract_coarse_drum_codes (line 125) | def _extract_coarse_drum_codes(self, wav: torch.Tensor, sample_rate: i... method _calc_coarse_drum_codes_for_cache (line 140) | def _calc_coarse_drum_codes_for_cache(self, path: tp.Union[str, Path], method _load_drum_codes_chunk (line 161) | def _load_drum_codes_chunk(self, full_coarse_drum_codes: torch.Tensor,... method _get_wav_embedding (line 179) | def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor: method tokenize (line 206) | def tokenize(self, x: WavCondition) -> WavCondition: class JascoConditioningProvider (line 216) | class JascoConditioningProvider(ConditioningProvider): method __init__ (line 224) | def __init__(self, *args, method tokenize (line 233) | def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict... method _collate_symbolic (line 262) | def _collate_symbolic(self, samples: tp.List[ConditioningAttributes], FILE: audiocraft/modules/lstm.py class StreamableLSTM (line 10) | class StreamableLSTM(nn.Module): method __init__ (line 14) | def __init__(self, dimension: int, num_layers: int = 2, skip: bool = T... method forward (line 19) | def forward(self, x): FILE: audiocraft/modules/rope.py class XPos (line 13) | class XPos(nn.Module): method __init__ (line 24) | def __init__(self, dim: int, smoothing: float = 0.4, base_scale: int =... method get_decay (line 38) | def get_decay(self, start: int, end: int): class RotaryEmbedding (line 49) | class RotaryEmbedding(nn.Module): method __init__ (line 60) | def __init__(self, dim: int, max_period: float = 10000.0, xpos: bool =... method get_rotation (line 75) | def get_rotation(self, start: int, end: int): method rotate (line 84) | def rotate(self, x: torch.Tensor, start: int = 0, time_dim: int = 1, i... method rotate_qk (line 106) | def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int... FILE: audiocraft/modules/seanet.py class SEANetResnetBlock (line 16) | class SEANetResnetBlock(nn.Module): method __init__ (line 33) | def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dila... method forward (line 59) | def forward(self, x): class SEANetEncoder (line 63) | class SEANetEncoder(nn.Module): method __init__ (line 91) | def __init__(self, channels: int = 1, dimension: int = 128, n_filters:... method forward (line 152) | def forward(self, x): class SEANetDecoder (line 156) | class SEANetDecoder(nn.Module): method __init__ (line 186) | def __init__(self, channels: int = 1, dimension: int = 128, n_filters:... method forward (line 256) | def forward(self, z): FILE: audiocraft/modules/streaming.py class StreamingModule (line 20) | class StreamingModule(nn.Module): method __init__ (line 43) | def __init__(self) -> None: method _apply_named_streaming (line 48) | def _apply_named_streaming(self, fn: tp.Any): method _set_streaming (line 53) | def _set_streaming(self, streaming: bool): method streaming (line 59) | def streaming(self): method reset_streaming (line 68) | def reset_streaming(self): method get_streaming_state (line 75) | def get_streaming_state(self) -> State: method set_streaming_state (line 88) | def set_streaming_state(self, state: State): method flush (line 107) | def flush(self, x: tp.Optional[torch.Tensor] = None): class StreamingSequential (line 122) | class StreamingSequential(StreamingModule, nn.Sequential): method flush (line 125) | def flush(self, x: tp.Optional[torch.Tensor] = None): FILE: audiocraft/modules/transformer.py function set_efficient_attention_backend (line 31) | def set_efficient_attention_backend(backend: str = 'torch'): function _get_attention_time_dimension (line 38) | def _get_attention_time_dimension(memory_efficient: bool) -> int: function _is_profiled (line 45) | def _is_profiled() -> bool: function create_norm_fn (line 54) | def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module: function create_sin_embedding (line 70) | def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: ... function expand_repeated_kv (line 92) | def expand_repeated_kv(x: torch.Tensor, n_rep: int, memory_efficient: bo... class LayerScale (line 112) | class LayerScale(nn.Module): method __init__ (line 123) | def __init__(self, channels: int, init: float = 1e-4, channel_last: bo... method forward (line 131) | def forward(self, x: torch.Tensor): class StreamingMultiheadAttention (line 138) | class StreamingMultiheadAttention(StreamingModule): method __init__ (line 164) | def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.... method _load_from_state_dict (line 224) | def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): method _get_mask (line 233) | def _get_mask(self, current_steps: int, device: torch.device, dtype: t... method _complete_kv (line 266) | def _complete_kv(self, k, v): method _apply_rope (line 300) | def _apply_rope(self, query: torch.Tensor, key: torch.Tensor): method forward (line 315) | def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch... class StreamingTransformerLayer (line 454) | class StreamingTransformerLayer(nn.TransformerEncoderLayer): method __init__ (line 488) | def __init__(self, d_model: int, num_heads: int, dim_feedforward: int ... method _cross_attention_block (line 542) | def _cross_attention_block(self, src: torch.Tensor, method forward (line 550) | def forward(self, src: torch.Tensor, src_mask: tp.Optional[torch.Tenso... class StreamingTransformer (line 577) | class StreamingTransformer(StreamingModule): method __init__ (line 614) | def __init__(self, d_model: int, num_heads: int, num_layers: int, dim_... method _apply_layer (line 662) | def _apply_layer(self, layer, *args, **kwargs): method forward (line 693) | def forward(self, x: torch.Tensor, *args, **kwargs): method make_optim_group (line 715) | def make_optim_group(self): function _verify_xformers_memory_efficient_compat (line 726) | def _verify_xformers_memory_efficient_compat(): function _verify_xformers_internal_compat (line 740) | def _verify_xformers_internal_compat(): function _is_custom (line 754) | def _is_custom(custom: bool, memory_efficient: bool): FILE: audiocraft/modules/unet_transformer.py class UnetTransformer (line 6) | class UnetTransformer(StreamingTransformer): method __init__ (line 20) | def __init__(self, d_model: int, num_layers: int, skip_connections: bo... method forward (line 32) | def forward(self, x: torch.Tensor, *args, **kwargs): FILE: audiocraft/modules/watermark.py function pad (line 13) | def pad( function mix (line 42) | def mix( FILE: audiocraft/optim/cosine_lr_scheduler.py class CosineLRScheduler (line 13) | class CosineLRScheduler(_LRScheduler): method __init__ (line 23) | def __init__(self, optimizer: Optimizer, total_steps: int, warmup_step... method _get_sched_lr (line 33) | def _get_sched_lr(self, lr: float, step: int): method get_lr (line 47) | def get_lr(self): FILE: audiocraft/optim/dadam.py function to_real (line 19) | def to_real(x): class DAdaptAdam (line 26) | class DAdaptAdam(torch.optim.Optimizer): method __init__ (line 58) | def __init__(self, params, lr=1.0, method supports_memory_efficient_fp16 (line 95) | def supports_memory_efficient_fp16(self): method supports_flat_params (line 99) | def supports_flat_params(self): method step (line 102) | def step(self, closure=None): FILE: audiocraft/optim/ema.py function _get_all_non_persistent_buffers_set (line 17) | def _get_all_non_persistent_buffers_set(module: nn.Module, root: str = "... function _get_named_tensors (line 32) | def _get_named_tensors(module: nn.Module): class ModuleDictEMA (line 40) | class ModuleDictEMA: method __init__ (line 45) | def __init__(self, module_dict: nn.ModuleDict, decay: float = 0.999, method _init (line 55) | def _init(self): method step (line 64) | def step(self): method state_dict (line 78) | def state_dict(self): method load_state_dict (line 81) | def load_state_dict(self, state): FILE: audiocraft/optim/fsdp.py function is_fsdp_used (line 22) | def is_fsdp_used() -> bool: function is_sharded_tensor (line 32) | def is_sharded_tensor(x: tp.Any) -> bool: function switch_to_full_state_dict (line 37) | def switch_to_full_state_dict(models: tp.List[FSDP]): function wrap_with_fsdp (line 51) | def wrap_with_fsdp(cfg, model: torch.nn.Module, function purge_fsdp (line 120) | def purge_fsdp(model: FSDP): class _FSDPFixStateDict (line 149) | class _FSDPFixStateDict(FSDP): method _name_without_fsdp_prefix (line 151) | def _name_without_fsdp_prefix(name: str) -> str: method state_dict (line 157) | def state_dict(self, *args, **kwargs) -> tp.Dict[str, tp.Any]: # type... method load_state_dict (line 164) | def load_state_dict(self, state: tp.Dict[str, tp.Any]): # type: ignore function _fix_post_backward_hook (line 186) | def _fix_post_backward_hook(): FILE: audiocraft/optim/inverse_sqrt_lr_scheduler.py class InverseSquareRootLRScheduler (line 13) | class InverseSquareRootLRScheduler(_LRScheduler): method __init__ (line 22) | def __init__(self, optimizer: Optimizer, warmup_steps: int, warmup_ini... method _get_sched_lr (line 27) | def _get_sched_lr(self, lr: float, step: int): method get_lr (line 37) | def get_lr(self): FILE: audiocraft/optim/linear_warmup_lr_scheduler.py class LinearWarmupLRScheduler (line 13) | class LinearWarmupLRScheduler(_LRScheduler): method __init__ (line 22) | def __init__(self, optimizer: Optimizer, warmup_steps: int, warmup_ini... method _get_sched_lr (line 27) | def _get_sched_lr(self, lr: float, step: int): method get_lr (line 34) | def get_lr(self): FILE: audiocraft/optim/polynomial_decay_lr_scheduler.py class PolynomialDecayLRScheduler (line 11) | class PolynomialDecayLRScheduler(_LRScheduler): method __init__ (line 22) | def __init__(self, optimizer: Optimizer, warmup_steps: int, total_step... method _get_sched_lr (line 31) | def _get_sched_lr(self, lr: float, step: int): method get_lr (line 46) | def get_lr(self): FILE: audiocraft/quantization/base.py class QuantizedResult (line 19) | class QuantizedResult: class BaseQuantizer (line 27) | class BaseQuantizer(nn.Module): method forward (line 31) | def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult: method encode (line 40) | def encode(self, x: torch.Tensor) -> torch.Tensor: method decode (line 44) | def decode(self, codes: torch.Tensor) -> torch.Tensor: method total_codebooks (line 49) | def total_codebooks(self): method num_codebooks (line 54) | def num_codebooks(self): method set_num_codebooks (line 58) | def set_num_codebooks(self, n: int): class DummyQuantizer (line 63) | class DummyQuantizer(BaseQuantizer): method __init__ (line 66) | def __init__(self): method forward (line 69) | def forward(self, x: torch.Tensor, frame_rate: int): method encode (line 73) | def encode(self, x: torch.Tensor) -> torch.Tensor: method decode (line 80) | def decode(self, codes: torch.Tensor) -> torch.Tensor: method total_codebooks (line 88) | def total_codebooks(self): method num_codebooks (line 93) | def num_codebooks(self): method set_num_codebooks (line 97) | def set_num_codebooks(self, n: int): FILE: audiocraft/quantization/core_vq.py function exists (line 16) | def exists(val: tp.Optional[tp.Any]) -> bool: function default (line 20) | def default(val: tp.Any, d: tp.Any) -> tp.Any: function l2norm (line 24) | def l2norm(t): function ema_inplace (line 28) | def ema_inplace(moving_avg, new, decay: float): function laplace_smoothing (line 32) | def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5): function uniform_init (line 36) | def uniform_init(*shape: int): function sample_vectors (line 42) | def sample_vectors(samples, num: int): function kmeans (line 53) | def kmeans(samples, num_clusters: int, num_iters: int = 10): function orthogonal_loss_fn (line 78) | def orthogonal_loss_fn(t): class EuclideanCodebook (line 87) | class EuclideanCodebook(nn.Module): method __init__ (line 103) | def __init__( method init_embed_ (line 130) | def init_embed_(self, data): method replace_ (line 142) | def replace_(self, samples, mask): method expire_codes_ (line 148) | def expire_codes_(self, batch_samples): method preprocess (line 160) | def preprocess(self, x): method quantize (line 164) | def quantize(self, x): method postprocess_emb (line 174) | def postprocess_emb(self, embed_ind, shape): method dequantize (line 177) | def dequantize(self, embed_ind): method encode (line 181) | def encode(self, x): method decode (line 191) | def decode(self, embed_ind): method forward (line 195) | def forward(self, x): class VectorQuantization (line 222) | class VectorQuantization(nn.Module): method __init__ (line 244) | def __init__( method codebook (line 283) | def codebook(self): method inited (line 287) | def inited(self): method _preprocess (line 290) | def _preprocess(self, x): method _postprocess (line 295) | def _postprocess(self, quantize): method encode (line 300) | def encode(self, x): method decode (line 306) | def decode(self, embed_ind): method forward (line 312) | def forward(self, x): class ResidualVectorQuantization (line 351) | class ResidualVectorQuantization(nn.Module): method __init__ (line 356) | def __init__(self, *, num_quantizers, **kwargs): method forward (line 362) | def forward(self, x, n_q: tp.Optional[int] = None): method encode (line 386) | def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> tor... method decode (line 398) | def decode(self, q_indices: torch.Tensor) -> torch.Tensor: FILE: audiocraft/quantization/vq.py class ResidualVectorQuantizer (line 16) | class ResidualVectorQuantizer(BaseQuantizer): method __init__ (line 35) | def __init__( method forward (line 76) | def forward(self, x: torch.Tensor, frame_rate: int): method encode (line 87) | def encode(self, x: torch.Tensor) -> torch.Tensor: method decode (line 98) | def decode(self, codes: torch.Tensor) -> torch.Tensor: method total_codebooks (line 106) | def total_codebooks(self): method num_codebooks (line 110) | def num_codebooks(self): method set_num_codebooks (line 113) | def set_num_codebooks(self, n: int): FILE: audiocraft/solvers/audiogen.py class AudioGenSolver (line 10) | class AudioGenSolver(musicgen.MusicGenSolver): FILE: audiocraft/solvers/base.py class StandardSolver (line 27) | class StandardSolver(ABC, flashy.BaseSolver): method __init__ (line 38) | def __init__(self, cfg: omegaconf.DictConfig): method autocast (line 98) | def autocast(self): method _get_state_source (line 102) | def _get_state_source(self, name) -> flashy.state.StateDictSource: method best_metric_name (line 107) | def best_metric_name(self) -> tp.Optional[str]: method register_best_state (line 114) | def register_best_state(self, *args: str): method register_ema (line 127) | def register_ema(self, *args: str): method wrap_with_fsdp (line 141) | def wrap_with_fsdp(self, model: torch.nn.Module, *args, **kwargs): method update_best_state_from_stage (line 147) | def update_best_state_from_stage(self, stage_name: str = 'valid'): method _load_new_state_dict (line 189) | def _load_new_state_dict(self, state_dict: dict) -> dict: method swap_best_state (line 198) | def swap_best_state(self): method swap_ema_state (line 210) | def swap_ema_state(self): method is_training (line 226) | def is_training(self): method log_model_summary (line 229) | def log_model_summary(self, model: nn.Module): method build_model (line 236) | def build_model(self): method initialize_ema (line 240) | def initialize_ema(self): method build_dataloaders (line 256) | def build_dataloaders(self): method show (line 261) | def show(self): method log_updates (line 266) | def log_updates(self): method checkpoint_path (line 270) | def checkpoint_path(self, **kwargs): method epoch_checkpoint_path (line 274) | def epoch_checkpoint_path(self, epoch: int, **kwargs): method checkpoint_path_with_name (line 278) | def checkpoint_path_with_name(self, name: str, **kwargs): method save_checkpoints (line 282) | def save_checkpoints(self): method load_from_pretrained (line 311) | def load_from_pretrained(self, name: str) -> dict: method load_checkpoints (line 314) | def load_checkpoints(self, load_best: bool = False, ignore_state_keys:... method restore (line 432) | def restore(self, load_best: bool = False, replay_metrics: bool = False, method commit (line 456) | def commit(self, save_checkpoints: bool = True): method run_epoch (line 466) | def run_epoch(self): method run (line 489) | def run(self): method should_stop_training (line 501) | def should_stop_training(self) -> bool: method should_run_stage (line 505) | def should_run_stage(self, stage_name) -> bool: method run_step (line 513) | def run_step(self, idx: int, batch: tp.Any, metrics: dict): method common_train_valid (line 517) | def common_train_valid(self, dataset_split: str, **kwargs: tp.Any): method train (line 559) | def train(self): method valid (line 563) | def valid(self): method evaluate (line 568) | def evaluate(self): method generate (line 573) | def generate(self): method run_one_stage (line 577) | def run_one_stage(self, stage_name: str): method get_eval_solver_from_sig (line 597) | def get_eval_solver_from_sig(sig: str, dtype: tp.Optional[str] = None, FILE: audiocraft/solvers/builders.py class DatasetType (line 37) | class DatasetType(Enum): function get_solver (line 44) | def get_solver(cfg: omegaconf.DictConfig) -> StandardSolver: function get_optim_parameter_groups (line 68) | def get_optim_parameter_groups(model: nn.Module): function get_optimizer (line 95) | def get_optimizer(params: tp.Union[nn.Module, tp.Iterable[torch.Tensor]]... function get_lr_scheduler (line 124) | def get_lr_scheduler(optimizer: torch.optim.Optimizer, function get_ema (line 168) | def get_ema(module_dict: nn.ModuleDict, cfg: omegaconf.DictConfig) -> tp... function get_loss (line 189) | def get_loss(loss_name: str, cfg: omegaconf.DictConfig): function get_balancer (line 206) | def get_balancer(loss_weights: tp.Dict[str, float], cfg: omegaconf.DictC... function get_adversary (line 212) | def get_adversary(name: str, cfg: omegaconf.DictConfig) -> nn.Module: function get_adversarial_losses (line 223) | def get_adversarial_losses(cfg) -> nn.ModuleDict: function get_visqol (line 256) | def get_visqol(cfg: omegaconf.DictConfig) -> metrics.ViSQOL: function get_fad (line 262) | def get_fad(cfg: omegaconf.DictConfig) -> metrics.FrechetAudioDistanceMe... function get_kldiv (line 270) | def get_kldiv(cfg: omegaconf.DictConfig) -> metrics.KLDivergenceMetric: function get_text_consistency (line 280) | def get_text_consistency(cfg: omegaconf.DictConfig) -> metrics.TextConsi... function get_chroma_cosine_similarity (line 290) | def get_chroma_cosine_similarity(cfg: omegaconf.DictConfig) -> metrics.C... function get_audio_datasets (line 297) | def get_audio_datasets(cfg: omegaconf.DictConfig, FILE: audiocraft/solvers/compression.py class CompressionSolver (line 27) | class CompressionSolver(base.StandardSolver): method __init__ (line 34) | def __init__(self, cfg: omegaconf.DictConfig): method best_metric_name (line 55) | def best_metric_name(self) -> tp.Optional[str]: method build_model (line 59) | def build_model(self): method build_dataloaders (line 68) | def build_dataloaders(self): method show (line 72) | def show(self): method run_step (line 83) | def run_step(self, idx: int, batch: torch.Tensor, metrics: dict): method run_epoch (line 176) | def run_epoch(self): method evaluate (line 183) | def evaluate(self): method generate (line 213) | def generate(self): method load_from_pretrained (line 236) | def load_from_pretrained(self, name: str) -> dict: method model_from_checkpoint (line 269) | def model_from_checkpoint(checkpoint_path: tp.Union[Path, str], method wrapped_model_from_checkpoint (line 304) | def wrapped_model_from_checkpoint(cfg: omegaconf.DictConfig, function evaluate_audio_reconstruction (line 320) | def evaluate_audio_reconstruction(y_pred: torch.Tensor, y: torch.Tensor,... FILE: audiocraft/solvers/diffusion.py class PerStageMetrics (line 25) | class PerStageMetrics: method __init__ (line 30) | def __init__(self, num_steps: int, num_stages: int = 4): method __call__ (line 34) | def __call__(self, losses: dict, step: tp.Union[int, torch.Tensor]): class DataProcess (line 53) | class DataProcess: method __init__ (line 67) | def __init__(self, initial_sr: int = 24000, target_sr: int = 16000, us... method process_data (line 95) | def process_data(self, x, metric=False): method inverse_process (line 107) | def inverse_process(self, x): class DiffusionSolver (line 114) | class DiffusionSolver(base.StandardSolver): method __init__ (line 122) | def __init__(self, cfg: omegaconf.DictConfig): method best_metric_name (line 155) | def best_metric_name(self) -> tp.Optional[str]: method get_condition (line 162) | def get_condition(self, wav: torch.Tensor) -> torch.Tensor: method build_model (line 168) | def build_model(self): method build_dataloaders (line 178) | def build_dataloaders(self): method show (line 182) | def show(self): method run_step (line 186) | def run_step(self, idx: int, batch: torch.Tensor, metrics: dict): method run_epoch (line 215) | def run_epoch(self): method evaluate (line 223) | def evaluate(self): method regenerate (line 253) | def regenerate(self, wav: torch.Tensor, step_list: tp.Optional[list] =... method generate (line 262) | def generate(self): FILE: audiocraft/solvers/jasco.py class JascoSolver (line 19) | class JascoSolver(musicgen.MusicGenSolver): method __init__ (line 25) | def __init__(self, cfg: DictConfig): method build_model (line 39) | def build_model(self) -> None: method _get_latents (line 55) | def _get_latents(self, audio): method _prepare_latents_and_attributes (line 60) | def _prepare_latents_and_attributes( method _normalized_latents (line 104) | def _normalized_latents(self, latents: torch.Tensor) -> torch.Tensor: method _unnormalized_latents (line 108) | def _unnormalized_latents(self, latents: torch.Tensor) -> torch.Tensor: method _z (line 112) | def _z(self, z_0: torch.Tensor, z_1: torch.Tensor, t: torch.Tensor, si... method _vector_field (line 116) | def _vector_field(self, z_0: torch.Tensor, z_1: torch.Tensor, sigma_mi... method _compute_loss (line 121) | def _compute_loss(self, t: torch.Tensor, v_theta: torch.Tensor, v: tor... method run_step (line 134) | def run_step(self, idx: int, batch: tp.Tuple[torch.Tensor, tp.List[Seg... method _decode_latents (line 216) | def _decode_latents(self, latents): method run_generate_step (line 220) | def run_generate_step(self, batch: tp.Tuple[torch.Tensor, tp.List[Segm... FILE: audiocraft/solvers/magnet.py class MagnetSolver (line 21) | class MagnetSolver(musicgen.MusicGenSolver): method __init__ (line 25) | def __init__(self, cfg: DictConfig): method build_model (line 47) | def build_model(self) -> None: method _calc_mean_maskrate_to_u_LUT (line 53) | def _calc_mean_maskrate_to_u_LUT(self, T: int): method _non_spans_mask (line 87) | def _non_spans_mask(self, mask_probs: torch.Tensor, B: int, T: int, de... method _spans_mask (line 102) | def _spans_mask(self, mask_probs: torch.Tensor, B: int, T: int, device... method _get_mask (line 127) | def _get_mask(self, mask_probs: torch.Tensor, B: int, T: int, device: ... method _compute_cross_entropy_magnet (line 143) | def _compute_cross_entropy_magnet(self, logits: torch.Tensor, method run_step (line 172) | def run_step(self, idx: int, batch: tp.Tuple[torch.Tensor, tp.List[Seg... class AudioMagnetSolver (line 271) | class AudioMagnetSolver(MagnetSolver): FILE: audiocraft/solvers/musicgen.py class MusicGenSolver (line 32) | class MusicGenSolver(base.StandardSolver): method __init__ (line 39) | def __init__(self, cfg: omegaconf.DictConfig): method get_eval_solver_from_sig (line 66) | def get_eval_solver_from_sig(sig: str, dtype: tp.Optional[str] = None, method get_formatter (line 102) | def get_formatter(self, stage_name: str) -> flashy.Formatter: method best_metric_name (line 111) | def best_metric_name(self) -> tp.Optional[str]: method initialize_optimization (line 114) | def initialize_optimization(self) -> None: method build_model (line 140) | def build_model(self) -> None: method build_dataloaders (line 171) | def build_dataloaders(self) -> None: method show (line 175) | def show(self) -> None: method load_state_dict (line 182) | def load_state_dict(self, state: dict) -> None: method load_from_pretrained (line 209) | def load_from_pretrained(self, name: str): method _compute_cross_entropy (line 219) | def _compute_cross_entropy( method _get_audio_tokens (line 253) | def _get_audio_tokens(self, audio: torch.Tensor): method _prepare_tokens_and_attributes (line 259) | def _prepare_tokens_and_attributes( method run_step (line 363) | def run_step(self, idx: int, batch: tp.Tuple[torch.Tensor, tp.List[Seg... method run_generate_step (line 445) | def run_generate_step(self, batch: tp.Tuple[torch.Tensor, tp.List[Segm... method generate_audio (line 511) | def generate_audio(self) -> dict: method generate (line 611) | def generate(self) -> dict: method run_epoch (line 617) | def run_epoch(self): method train (line 623) | def train(self): method evaluate_audio_generation (line 636) | def evaluate_audio_generation(self) -> dict: method evaluate (line 741) | def evaluate(self) -> dict: FILE: audiocraft/solvers/watermark.py function get_encodec_audio_effect (line 45) | def get_encodec_audio_effect(encodec_cfg: DictConfig, sr: int) -> tp.Dict: function random_message (line 69) | def random_message(nbits: int, batch_size: int) -> torch.Tensor: class WatermarkSolver (line 76) | class WatermarkSolver(base.StandardSolver): method __init__ (line 79) | def __init__(self, cfg: DictConfig): method _init_losses (line 93) | def _init_losses(self): method _init_augmentations (line 133) | def _init_augmentations(self): method best_metric_name (line 162) | def best_metric_name(self) -> tp.Optional[str]: method build_model (line 166) | def build_model(self): method build_dataloaders (line 176) | def build_dataloaders(self): method show (line 180) | def show(self): method crop (line 185) | def crop( method run_step (line 251) | def run_step(self, idx: int, batch: torch.Tensor, metrics: dict): method run_epoch (line 393) | def run_epoch(self): method evaluate (line 400) | def evaluate(self) -> dict: method generate (line 533) | def generate(self): method load_from_pretrained (line 576) | def load_from_pretrained(self, name: str) -> dict: method model_from_checkpoint (line 580) | def model_from_checkpoint( function evaluate_localizations (line 617) | def evaluate_localizations(predictions, true_predictions, name): function evaluate_augmentations (line 633) | def evaluate_augmentations( function evaluate_audio_watermark (line 654) | def evaluate_audio_watermark( function tensor_pesq (line 672) | def tensor_pesq(y_pred: torch.Tensor, y: torch.Tensor, sr: int): function compute_accuracy (line 677) | def compute_accuracy(positive, negative): function compute_FPR (line 685) | def compute_FPR(negative): function compute_FNR (line 691) | def compute_FNR(positive): function _bit_acc (line 697) | def _bit_acc(decoded, original): function compute_bit_acc (line 702) | def compute_bit_acc(positive, original, mask=None): FILE: audiocraft/train.py function resolve_config_dset_paths (line 30) | def resolve_config_dset_paths(cfg): function get_solver (line 38) | def get_solver(cfg): function get_solver_from_xp (line 52) | def get_solver_from_xp(xp: XP, override_cfg: tp.Optional[tp.Union[dict, ... function get_solver_from_sig (line 97) | def get_solver_from_sig(sig: str, *args, **kwargs): function init_seed_and_system (line 105) | def init_seed_and_system(cfg): function main (line 131) | def main(cfg): FILE: audiocraft/utils/audio_effects.py function select_audio_effects (line 24) | def select_audio_effects( function get_audio_effects (line 84) | def get_audio_effects(cfg: omegaconf.DictConfig): function audio_effect_return (line 99) | def audio_effect_return( function generate_pink_noise (line 109) | def generate_pink_noise(length: int) -> torch.Tensor: function compress_with_encodec (line 121) | def compress_with_encodec( function apply_compression_skip_grad (line 146) | def apply_compression_skip_grad(tensor: torch.Tensor, compression_fn, **... class AudioEffects (line 177) | class AudioEffects: method speed (line 179) | def speed( method updownresample (line 206) | def updownresample( method echo (line 223) | def echo( method random_noise (line 278) | def random_noise( method pink_noise (line 289) | def pink_noise( method lowpass_filter (line 302) | def lowpass_filter( method highpass_filter (line 315) | def highpass_filter( method bandpass_filter (line 328) | def bandpass_filter( method smooth (line 358) | def smooth( method boost_audio (line 390) | def boost_audio( method duck_audio (line 399) | def duck_audio( method identity (line 408) | def identity( method mp3_compression (line 414) | def mp3_compression( method aac_compression (line 436) | def aac_compression( FILE: audiocraft/utils/autocast.py class TorchAutocast (line 10) | class TorchAutocast: method __init__ (line 21) | def __init__(self, enabled: bool, *args, **kwargs): method __enter__ (line 24) | def __enter__(self): method __exit__ (line 37) | def __exit__(self, *args, **kwargs): FILE: audiocraft/utils/best_state.py class BestStateDictManager (line 21) | class BestStateDictManager(flashy.state.StateDictSource): method __init__ (line 36) | def __init__(self, device: tp.Union[torch.device, str] = 'cpu', method _get_parameter_ids (line 43) | def _get_parameter_ids(self, state_dict): method _validate_no_parameter_ids_overlap (line 46) | def _validate_no_parameter_ids_overlap(self, name: str, param_ids: dict): method update (line 53) | def update(self, name: str, source: flashy.state.StateDictSource): method register (line 58) | def register(self, name: str, source: flashy.state.StateDictSource): method state_dict (line 75) | def state_dict(self) -> flashy.state.StateDict: method load_state_dict (line 78) | def load_state_dict(self, state: flashy.state.StateDict): FILE: audiocraft/utils/cache.py function get_full_embed (line 24) | def get_full_embed(full_embed: torch.Tensor, x: tp.Any, idx: int, device... class EmbeddingCache (line 39) | class EmbeddingCache: method __init__ (line 60) | def __init__(self, cache_path: tp.Union[str, Path], device: tp.Union[s... method _get_cache_path (line 79) | def _get_cache_path(self, path: tp.Union[Path, str]): method _get_full_embed_from_cache (line 85) | def _get_full_embed_from_cache(cache: Path): method get_embed_from_cache (line 94) | def get_embed_from_cache(self, paths: tp.List[Path], x: tp.Any) -> tor... method populate_embed_cache (line 124) | def populate_embed_cache(self, paths: tp.List[Path], x: tp.Any) -> None: class CachedBatchWriter (line 161) | class CachedBatchWriter: method __init__ (line 180) | def __init__(self, cache_folder: Path): method start_epoch (line 185) | def start_epoch(self, epoch: int): method _get_zip_path (line 193) | def _get_zip_path(cache_folder: Path, epoch: int, index: int): method _zip_path (line 197) | def _zip_path(self): method save (line 201) | def save(self, *content): class CachedBatchLoader (line 224) | class CachedBatchLoader: method __init__ (line 237) | def __init__(self, cache_folder: Path, batch_size: int, method __len__ (line 246) | def __len__(self): method start_epoch (line 250) | def start_epoch(self, epoch: int): method _zip_path (line 255) | def _zip_path(self, index: int): method _load_one (line 259) | def _load_one(self, index: int): method __iter__ (line 297) | def __iter__(self): FILE: audiocraft/utils/checkpoint.py class CheckpointSource (line 22) | class CheckpointSource(Enum): function checkpoint_name (line 28) | def checkpoint_name(name: tp.Optional[str] = None, rank: tp.Optional[int... function is_sharded_checkpoint (line 51) | def is_sharded_checkpoint(path: Path) -> bool: function resolve_checkpoint_path (line 56) | def resolve_checkpoint_path(sig_or_path: tp.Union[Path, str], name: tp.O... function load_checkpoint (line 87) | def load_checkpoint(checkpoint_path: Path, is_sharded: bool = False) -> ... function save_checkpoint (line 98) | def save_checkpoint(state: tp.Any, checkpoint_path: Path, is_sharded: bo... function flush_stale_checkpoints (line 104) | def flush_stale_checkpoints(checkpoint_path: Path, keep_last: tp.Optiona... function check_sharded_checkpoint (line 125) | def check_sharded_checkpoint(checkpoint_path: Path, rank0_checkpoint_pat... function _safe_save_checkpoint (line 142) | def _safe_save_checkpoint(state: tp.Any, checkpoint_path: Path, is_shard... FILE: audiocraft/utils/cluster.py class ClusterType (line 19) | class ClusterType(Enum): function _guess_cluster_type (line 27) | def _guess_cluster_type() -> ClusterType: function get_cluster_type (line 45) | def get_cluster_type( function get_slurm_parameters (line 54) | def get_slurm_parameters( FILE: audiocraft/utils/deadlock.py class DeadlockDetect (line 18) | class DeadlockDetect: method __init__ (line 19) | def __init__(self, use: bool = False, timeout: float = 120.): method update (line 24) | def update(self, stage: str): method __enter__ (line 28) | def __enter__(self): method __exit__ (line 33) | def __exit__(self, exc_type, exc_val, exc_tb): method _detector_thread (line 38) | def _detector_thread(self): FILE: audiocraft/utils/export.py function export_encodec (line 20) | def export_encodec(checkpoint_path: tp.Union[Path, str], out_file: tp.Un... function export_pretrained_compression_model (line 36) | def export_pretrained_compression_model(pretrained_encodec: str, out_fil... function export_lm (line 61) | def export_lm(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[P... FILE: audiocraft/utils/export_legacy.py function _clean_lm_cfg (line 20) | def _clean_lm_cfg(cfg: DictConfig): function export_encodec (line 41) | def export_encodec(checkpoint_path: tp.Union[Path, str], out_file: tp.Un... function export_lm (line 55) | def export_lm(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[P... FILE: audiocraft/utils/notebook.py function display_audio (line 17) | def display_audio(samples: torch.Tensor, sample_rate: int): FILE: audiocraft/utils/profiler.py class Profiler (line 17) | class Profiler: method __init__ (line 20) | def __init__(self, module: torch.nn.Module, enabled: bool = False): method step (line 28) | def step(self): method __enter__ (line 32) | def __enter__(self): method __exit__ (line 36) | def __exit__(self, exc_type, exc_value, exc_tb): FILE: audiocraft/utils/samples/manager.py class ReferenceSample (line 42) | class ReferenceSample: class Sample (line 49) | class Sample: method __hash__ (line 59) | def __hash__(self): method audio (line 62) | def audio(self) -> tp.Tuple[torch.Tensor, int]: method audio_prompt (line 65) | def audio_prompt(self) -> tp.Optional[tp.Tuple[torch.Tensor, int]]: method audio_reference (line 68) | def audio_reference(self) -> tp.Optional[tp.Tuple[torch.Tensor, int]]: class SampleManager (line 72) | class SampleManager: method __init__ (line 89) | def __init__(self, xp: dora.XP, map_reference_to_sample_id: bool = Fal... method latest_epoch (line 98) | def latest_epoch(self): method _load_samples (line 102) | def _load_samples(self): method _load_sample (line 110) | def _load_sample(json_file: Path) -> Sample: method _init_hash (line 126) | def _init_hash(self): method _get_tensor_id (line 129) | def _get_tensor_id(self, tensor: torch.Tensor) -> str: method _get_sample_id (line 134) | def _get_sample_id(self, index: int, prompt_wav: tp.Optional[torch.Ten... method _store_audio (line 173) | def _store_audio(self, wav: torch.Tensor, stem_path: Path, overwrite: ... method add_sample (line 196) | def add_sample(self, sample_wav: torch.Tensor, epoch: int, index: int ... method add_samples (line 238) | def add_samples(self, samples_wavs: torch.Tensor, epoch: int, method get_samples (line 269) | def get_samples(self, epoch: int = -1, max_epoch: int = -1, exclude_pr... function slugify (line 305) | def slugify(value: tp.Any, allow_unicode: bool = False): function _match_stable_samples (line 328) | def _match_stable_samples(samples_per_xp: tp.List[tp.Set[Sample]]) -> tp... function _match_unstable_samples (line 343) | def _match_unstable_samples(samples_per_xp: tp.List[tp.Set[Sample]]) -> ... function get_samples_for_xps (line 358) | def get_samples_for_xps(xps: tp.List[dora.XP], **kwargs) -> tp.Dict[str,... FILE: audiocraft/utils/utils.py function model_hash (line 25) | def model_hash(model: torch.nn.Module) -> str: function dict_from_config (line 35) | def dict_from_config(cfg: omegaconf.DictConfig) -> dict: function random_subset (line 48) | def random_subset(dataset, max_samples: int, seed: int = 42) -> torch.ut... function get_loader (line 57) | def get_loader(dataset, num_samples: tp.Optional[int], batch_size: int, function get_dataset_from_loader (line 80) | def get_dataset_from_loader(dataloader): function multinomial (line 88) | def multinomial(input: torch.Tensor, num_samples: int, replacement=False... function sample_top_k (line 108) | def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor: function sample_top_p (line 125) | def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor: class DummyPoolExecutor (line 144) | class DummyPoolExecutor: class DummyResult (line 148) | class DummyResult: method __init__ (line 149) | def __init__(self, func, *args, **kwargs): method result (line 154) | def result(self): method __init__ (line 157) | def __init__(self, workers, mp_context=None): method submit (line 160) | def submit(self, func, *args, **kwargs): method __enter__ (line 163) | def __enter__(self): method __exit__ (line 166) | def __exit__(self, exc_type, exc_value, exc_tb): function get_pool_executor (line 170) | def get_pool_executor(num_workers: int, mp_context=None): function length_to_mask (line 174) | def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = No... function hash_trick (line 190) | def hash_trick(word: str, vocab_size: int) -> int: function with_rank_rng (line 203) | def with_rank_rng(base_seed: int = 1234): function collate (line 226) | def collate(tensors: tp.List[torch.Tensor], dim: int = 0) -> tp.Tuple[to... function copy_state (line 250) | def copy_state(state: tp.Any, device: tp.Union[torch.device, str] = 'cpu', function swap_state (line 264) | def swap_state(model, state, **kwargs): function warn_once (line 274) | def warn_once(logger, msg): function is_jsonable (line 279) | def is_jsonable(x: tp.Any): function load_clap_state_dict (line 288) | def load_clap_state_dict(clap_model, path: tp.Union[str, Path]): function construct_frame_chords (line 300) | def construct_frame_chords( FILE: demos/jasco_app.py function _call_nostderr (line 32) | def _call_nostderr(*args, **kwargs): function interrupt (line 45) | def interrupt(): class FileCleaner (line 50) | class FileCleaner: method __init__ (line 51) | def __init__(self, file_lifetime: float = 3600): method add (line 55) | def add(self, path: tp.Union[str, Path]): method _cleanup (line 59) | def _cleanup(self): function chords_string_to_list (line 73) | def chords_string_to_list(chords: str): function load_model (line 85) | def load_model(version='facebook/jasco-chords-drums-400M'): function _do_predictions (line 93) | def _do_predictions(texts, chords, melody_matrix, drum_prompt, progress=... function predict_full (line 140) | def predict_full(model, function ui_full (line 180) | def ui_full(launch_kwargs): FILE: demos/magnet_app.py function _call_nostderr (line 37) | def _call_nostderr(*args, **kwargs): function interrupt (line 50) | def interrupt(): class FileCleaner (line 55) | class FileCleaner: method __init__ (line 56) | def __init__(self, file_lifetime: float = 3600): method add (line 60) | def add(self, path: tp.Union[str, Path]): method _cleanup (line 64) | def _cleanup(self): function make_waveform (line 78) | def make_waveform(*args, **kwargs): function load_model (line 88) | def load_model(version='facebook/magnet-small-10secs'): function _do_predictions (line 96) | def _do_predictions(texts, progress=False, gradio_progress=None, **gen_k... function predict_batched (line 125) | def predict_batched(texts, melodies): function predict_full (line 133) | def predict_full(model, model_path, text, temperature, topp, function ui_full (line 175) | def ui_full(launch_kwargs): FILE: demos/musicgen_app.py function _call_nostderr (line 44) | def _call_nostderr(*args, **kwargs): function interrupt (line 57) | def interrupt(): class FileCleaner (line 62) | class FileCleaner: method __init__ (line 63) | def __init__(self, file_lifetime: float = 3600): method add (line 67) | def add(self, path: tp.Union[str, Path]): method _cleanup (line 71) | def _cleanup(self): function make_waveform (line 84) | def make_waveform(*args, **kwargs): function load_model (line 94) | def load_model(version='facebook/musicgen-melody'): function load_diffusion (line 105) | def load_diffusion(): function _do_predictions (line 112) | def _do_predictions(texts, melodies, duration, progress=False, gradio_pr... function predict_batched (line 174) | def predict_batched(texts, melodies): function predict_full (line 182) | def predict_full(model, model_path, decoder, text, melody, duration, top... function toggle_audio_src (line 230) | def toggle_audio_src(choice): function toggle_diffusion (line 237) | def toggle_diffusion(choice): function ui_full (line 244) | def ui_full(launch_kwargs): function ui_batched (line 387) | def ui_batched(launch_kwargs): FILE: demos/musicgen_style_app.py function _call_nostderr (line 39) | def _call_nostderr(*args, **kwargs): function interrupt (line 52) | def interrupt(): class FileCleaner (line 57) | class FileCleaner: method __init__ (line 58) | def __init__(self, file_lifetime: float = 3600): method add (line 62) | def add(self, path: tp.Union[str, Path]): method _cleanup (line 66) | def _cleanup(self): function make_waveform (line 79) | def make_waveform(*args, **kwargs): function load_model (line 89) | def load_model(version='facebook/musicgen-style'): function load_diffusion (line 100) | def load_diffusion(): function _do_predictions (line 107) | def _do_predictions(texts, melodies, duration, top_k, top_p, temperature... function predict_full (line 164) | def predict_full(model, model_path, decoder, text, melody, duration, top... function toggle_audio_src (line 220) | def toggle_audio_src(choice): function toggle_diffusion (line 227) | def toggle_diffusion(choice): function ui_full (line 234) | def ui_full(launch_kwargs): FILE: scripts/chords/build_chord_maps.py function parse_args (line 12) | def parse_args(): function get_chord_dict (line 25) | def get_chord_dict(chord_folder: str): function get_predefined_chord_to_index_map (line 50) | def get_predefined_chord_to_index_map(path_to_chords_to_index_map: str): FILE: scripts/chords/extract_chords.py function parse_args (line 11) | def parse_args(): function save_to_db_cb (line 22) | def save_to_db_cb(tgt_dir: str): FILE: scripts/mos.py function normalize_path (line 43) | def normalize_path(path: Path): function get_full_path (line 51) | def get_full_path(normalized_path: Path): function get_signature (line 57) | def get_signature(xps: tp.List[str]): function ensure_logged (line 63) | def ensure_logged(func): function login (line 76) | def login(): function index (line 98) | def index(): function survey (line 135) | def survey(signature): function audio (line 236) | def audio(path: str): function mean (line 242) | def mean(x): function std (line 246) | def std(x): function results (line 253) | def results(signature): FILE: scripts/resample_dataset.py function read_txt_files (line 22) | def read_txt_files(path: tp.Union[str, Path]): function read_egs_files (line 31) | def read_egs_files(path: tp.Union[str, Path]): function process_dataset (line 45) | def process_dataset(args, n_shards: int, node_index: int, task_index: tp... FILE: tests/adversarial/test_discriminators.py class TestMultiPeriodDiscriminator (line 18) | class TestMultiPeriodDiscriminator: method test_mpd_discriminator (line 20) | def test_mpd_discriminator(self): class TestMultiScaleDiscriminator (line 33) | class TestMultiScaleDiscriminator: method test_msd_discriminator (line 35) | def test_msd_discriminator(self): class TestMultiScaleStftDiscriminator (line 49) | class TestMultiScaleStftDiscriminator: method test_msstftd_discriminator (line 51) | def test_msstftd_discriminator(self): FILE: tests/adversarial/test_losses.py class TestAdversarialLoss (line 22) | class TestAdversarialLoss: method test_adversarial_single_multidiscriminator (line 24) | def test_adversarial_single_multidiscriminator(self): method test_adversarial_feat_loss (line 45) | def test_adversarial_feat_loss(self): class TestGeneratorAdversarialLoss (line 65) | class TestGeneratorAdversarialLoss: method test_hinge_generator_adv_loss (line 67) | def test_hinge_generator_adv_loss(self): method test_mse_generator_adv_loss (line 76) | def test_mse_generator_adv_loss(self): class TestDiscriminatorAdversarialLoss (line 88) | class TestDiscriminatorAdversarialLoss: method _disc_loss (line 90) | def _disc_loss(self, loss_type: str, fake: torch.Tensor, real: torch.T... method test_hinge_discriminator_adv_loss (line 97) | def test_hinge_discriminator_adv_loss(self): method test_mse_discriminator_adv_loss (line 105) | def test_mse_discriminator_adv_loss(self): class TestFeatureMatchingLoss (line 115) | class TestFeatureMatchingLoss: method test_features_matching_loss_base (line 117) | def test_features_matching_loss_base(self): method test_features_matching_loss_raises_exception (line 126) | def test_features_matching_loss_raises_exception(self): method test_features_matching_loss_output (line 141) | def test_features_matching_loss_output(self): FILE: tests/common_utils/temp_utils.py class TempDirMixin (line 11) | class TempDirMixin: method get_base_temp_dir (line 18) | def get_base_temp_dir(cls): method tearDownClass (line 29) | def tearDownClass(cls): method id (line 43) | def id(self): method get_temp_path (line 46) | def get_temp_path(self, *paths): method get_temp_dir (line 52) | def get_temp_dir(self, *paths): FILE: tests/common_utils/wav_utils.py function get_white_noise (line 14) | def get_white_noise(chs: int = 1, num_frames: int = 1): function get_batch_white_noise (line 19) | def get_batch_white_noise(bs: int = 1, chs: int = 1, num_frames: int = 1): function save_wav (line 24) | def save_wav(path: str, wav: torch.Tensor, sample_rate: int): FILE: tests/data/test_audio.py class TestInfo (line 19) | class TestInfo(TempDirMixin): method test_info_mp3 (line 21) | def test_info_mp3(self): method _test_info_format (line 34) | def _test_info_format(self, ext: str): method test_info_wav (line 48) | def test_info_wav(self): method test_info_flac (line 51) | def test_info_flac(self): method test_info_ogg (line 54) | def test_info_ogg(self): method test_info_m4a (line 57) | def test_info_m4a(self): class TestRead (line 63) | class TestRead(TempDirMixin): method test_read_full_wav (line 65) | def test_read_full_wav(self): method test_read_partial_wav (line 80) | def test_read_partial_wav(self): method test_read_seek_time_wav (line 97) | def test_read_seek_time_wav(self): method test_read_seek_time_wav_padded (line 116) | def test_read_seek_time_wav_padded(self): class TestAvRead (line 139) | class TestAvRead(TempDirMixin): method test_avread_seek_base (line 141) | def test_avread_seek_base(self): method test_avread_seek_partial (line 159) | def test_avread_seek_partial(self): method test_avread_seek_outofbound (line 178) | def test_avread_seek_outofbound(self): method test_avread_seek_edge (line 193) | def test_avread_seek_edge(self): class TestAudioWrite (line 212) | class TestAudioWrite(TempDirMixin): method test_audio_write_wav (line 214) | def test_audio_write_wav(self): FILE: tests/data/test_audio_dataset.py class TestAudioMeta (line 31) | class TestAudioMeta(TempDirMixin): method test_get_audio_meta (line 33) | def test_get_audio_meta(self): method test_save_audio_meta (line 49) | def test_save_audio_meta(self): method test_load_audio_meta (line 65) | def test_load_audio_meta(self): class TestAudioDataset (line 90) | class TestAudioDataset(TempDirMixin): method _create_audio_files (line 92) | def _create_audio_files(self, method _create_audio_dataset (line 114) | def _create_audio_dataset(self, method test_dataset_full (line 135) | def test_dataset_full(self): method test_dataset_segment (line 152) | def test_dataset_segment(self): method test_dataset_equal_audio_and_segment_durations (line 170) | def test_dataset_equal_audio_and_segment_durations(self): method test_dataset_samples (line 192) | def test_dataset_samples(self): method test_dataset_return_info (line 218) | def test_dataset_return_info(self): method test_dataset_return_info_no_segment_duration (line 240) | def test_dataset_return_info_no_segment_duration(self): method test_dataset_collate_fn (line 260) | def test_dataset_collate_fn(self): method test_dataset_with_meta_collate_fn (line 280) | def test_dataset_with_meta_collate_fn(self, segment_duration): method test_sample_with_weight (line 308) | def test_sample_with_weight(self, segment_duration, sample_on_weight, ... method test_meta_duration_filter_all (line 333) | def test_meta_duration_filter_all(self): method test_meta_duration_filter_long (line 345) | def test_meta_duration_filter_long(self): FILE: tests/data/test_audio_utils.py class TestConvertAudioChannels (line 22) | class TestConvertAudioChannels: method test_convert_audio_channels_downmix (line 24) | def test_convert_audio_channels_downmix(self): method test_convert_audio_channels_nochange (line 30) | def test_convert_audio_channels_nochange(self): method test_convert_audio_channels_upmix (line 36) | def test_convert_audio_channels_upmix(self): method test_convert_audio_channels_upmix_error (line 42) | def test_convert_audio_channels_upmix_error(self): class TestConvertAudio (line 49) | class TestConvertAudio: method test_convert_audio_channels_downmix (line 51) | def test_convert_audio_channels_downmix(self): method test_convert_audio_channels_upmix (line 58) | def test_convert_audio_channels_upmix(self): method test_convert_audio_upsample (line 65) | def test_convert_audio_upsample(self): method test_convert_audio_resample (line 74) | def test_convert_audio_resample(self): method test_convert_pcm (line 83) | def test_convert_pcm(self): class TestNormalizeAudio (line 92) | class TestNormalizeAudio: method test_clip_wav (line 94) | def test_clip_wav(self): method test_normalize_audio_clip (line 101) | def test_normalize_audio_clip(self): method test_normalize_audio_rms (line 108) | def test_normalize_audio_rms(self): method test_normalize_audio_peak (line 115) | def test_normalize_audio_peak(self): FILE: tests/losses/test_losses.py function test_mel_l1_loss (line 23) | def test_mel_l1_loss(): function test_msspec_loss (line 37) | def test_msspec_loss(): function test_mrstft_loss (line 51) | def test_mrstft_loss(): function test_sisnr_loss (line 62) | def test_sisnr_loss(): function test_stft_loss (line 73) | def test_stft_loss(): function test_wm_loss (line 84) | def test_wm_loss(): function test_loudness_loss (line 96) | def test_loudness_loss(): FILE: tests/metrics/test_pesq.py function tensor_pesq (line 14) | def tensor_pesq(y_pred: torch.Tensor, y: torch.Tensor, sr: int): class TestPesq (line 30) | class TestPesq(TempDirMixin): method test (line 32) | def test(self): FILE: tests/models/test_audiogen.py class TestAudioGenModel (line 13) | class TestAudioGenModel: method get_audiogen (line 14) | def get_audiogen(self): method test_base (line 19) | def test_base(self): method test_generate_continuation (line 25) | def test_generate_continuation(self): method test_generate (line 41) | def test_generate(self): method test_generate_long (line 47) | def test_generate_long(self): FILE: tests/models/test_encodec_model.py class TestEncodecModel (line 17) | class TestEncodecModel: method _create_encodec_model (line 19) | def _create_encodec_model(self, method test_model (line 37) | def test_model(self): method test_model_renorm (line 48) | def test_model_renorm(self): FILE: tests/models/test_multibanddiffusion.py class TestMBD (line 18) | class TestMBD: method _create_mbd (line 20) | def _create_mbd(self, method test_model (line 43) | def test_model(self): FILE: tests/models/test_musicgen.py class TestMusicGenModel (line 13) | class TestMusicGenModel: method get_musicgen (line 14) | def get_musicgen(self): method test_base (line 19) | def test_base(self): method test_generate_unconditional (line 25) | def test_generate_unconditional(self): method test_generate_continuation (line 30) | def test_generate_continuation(self): method test_generate (line 46) | def test_generate(self): method test_generate_long (line 52) | def test_generate_long(self): method test_generate_two_step_cfg (line 60) | def test_generate_two_step_cfg(self): FILE: tests/models/test_watermark.py class TestWatermarkModel (line 13) | class TestWatermarkModel: method test_base (line 15) | def test_base(self): FILE: tests/modules/test_activations.py class TestActivations (line 13) | class TestActivations: method test_custom_glu_calculation (line 14) | def test_custom_glu_calculation(self): FILE: tests/modules/test_codebooks_patterns.py class TestParallelPatternProvider (line 18) | class TestParallelPatternProvider: method test_get_pattern (line 22) | def test_get_pattern(self, n_q: int, timesteps: int): method test_pattern_content (line 30) | def test_pattern_content(self, n_q: int, timesteps: int): method test_pattern_max_delay (line 40) | def test_pattern_max_delay(self, n_q: int, timesteps: int): class TestDelayedPatternProvider (line 47) | class TestDelayedPatternProvider: method test_get_pattern (line 51) | def test_get_pattern(self, n_q: int, timesteps: int): method test_pattern_content (line 65) | def test_pattern_content(self, n_q: int, timesteps: int): method test_pattern_max_delay (line 75) | def test_pattern_max_delay(self, timesteps: int, delay: list): class TestUnrolledPatternProvider (line 82) | class TestUnrolledPatternProvider: method test_get_pattern (line 87) | def test_get_pattern(self, timesteps: int, flattening: list, delays: l... method test_pattern_max_delay (line 97) | def test_pattern_max_delay(self, timesteps: int, flattening: list, del... class TestPattern (line 105) | class TestPattern: method ref_build_pattern_sequence (line 107) | def ref_build_pattern_sequence(self, z: torch.Tensor, pattern: Pattern... method ref_revert_pattern_sequence (line 121) | def ref_revert_pattern_sequence(self, z: torch.Tensor, pattern: Patter... method ref_revert_pattern_logits (line 134) | def ref_revert_pattern_logits(self, z: torch.Tensor, pattern: Pattern,... method _get_pattern_providers (line 149) | def _get_pattern_providers(self, n_q: int): method test_build_pattern_sequence (line 173) | def test_build_pattern_sequence(self, n_q: int, timesteps: int): method test_revert_pattern_sequence (line 205) | def test_revert_pattern_sequence(self, n_q: int, timesteps: int): method test_revert_pattern_logits (line 228) | def test_revert_pattern_logits(self, n_q: int, timesteps: int, card: i... FILE: tests/modules/test_conv.py function test_get_extra_padding_for_conv1d (line 25) | def test_get_extra_padding_for_conv1d(): function test_pad1d_zeros (line 30) | def test_pad1d_zeros(): function test_pad1d_reflect (line 52) | def test_pad1d_reflect(): function test_unpad1d (line 74) | def test_unpad1d(): class TestNormConv1d (line 96) | class TestNormConv1d: method test_norm_conv1d_modules (line 98) | def test_norm_conv1d_modules(self): class TestNormConvTranspose1d (line 123) | class TestNormConvTranspose1d: method test_normalizations (line 125) | def test_normalizations(self): class TestStreamableConv1d (line 151) | class TestStreamableConv1d: method get_streamable_conv1d_output_length (line 153) | def get_streamable_conv1d_output_length(self, length, kernel_size, str... method test_streamable_conv1d (line 160) | def test_streamable_conv1d(self): class TestStreamableConvTranspose1d (line 176) | class TestStreamableConvTranspose1d: method get_streamable_convtr1d_output_length (line 178) | def get_streamable_convtr1d_output_length(self, length, kernel_size, s... method test_streamable_convtr1d (line 182) | def test_streamable_convtr1d(self): FILE: tests/modules/test_lstm.py class TestStreamableLSTM (line 13) | class TestStreamableLSTM: method test_lstm (line 15) | def test_lstm(self): method test_lstm_skip (line 25) | def test_lstm_skip(self): FILE: tests/modules/test_rope.py function test_rope (line 13) | def test_rope(): function test_rope_io_dtypes (line 26) | def test_rope_io_dtypes(): function test_transformer_with_rope (line 50) | def test_transformer_with_rope(): function test_rope_streaming (line 66) | def test_rope_streaming(): function test_rope_streaming_past_context (line 94) | def test_rope_streaming_past_context(): function test_rope_memory_efficient (line 124) | def test_rope_memory_efficient(): function test_rope_with_xpos (line 145) | def test_rope_with_xpos(): function test_positional_scale (line 158) | def test_positional_scale(): FILE: tests/modules/test_seanet.py class TestSEANetModel (line 16) | class TestSEANetModel: method test_base (line 18) | def test_base(self): method test_causal (line 28) | def test_causal(self): method test_conv_skip_connection (line 38) | def test_conv_skip_connection(self): method test_seanet_encoder_decoder_final_act (line 48) | def test_seanet_encoder_decoder_final_act(self): method _check_encoder_blocks_norm (line 58) | def _check_encoder_blocks_norm(self, encoder: SEANetEncoder, n_disable... method test_encoder_disable_norm (line 70) | def test_encoder_disable_norm(self): method _check_decoder_blocks_norm (line 79) | def _check_decoder_blocks_norm(self, decoder: SEANetDecoder, n_disable... method test_decoder_disable_norm (line 94) | def test_decoder_disable_norm(self): method test_disable_norm_raises_exception (line 103) | def test_disable_norm_raises_exception(self): FILE: tests/modules/test_transformer.py function test_transformer_causal_streaming (line 16) | def test_transformer_causal_streaming(): function test_transformer_vs_pytorch (line 52) | def test_transformer_vs_pytorch(): function test_streaming_api (line 71) | def test_streaming_api(): function test_memory_efficient (line 88) | def test_memory_efficient(): function test_attention_as_float32 (line 108) | def test_attention_as_float32(): function test_streaming_memory_efficient (line 134) | def test_streaming_memory_efficient(): function test_cross_attention (line 164) | def test_cross_attention(): function test_cross_attention_compat (line 192) | def test_cross_attention_compat(): function test_repeat_kv (line 224) | def test_repeat_kv(): function test_qk_layer_norm (line 241) | def test_qk_layer_norm(): FILE: tests/quantization/test_vq.py class TestResidualVectorQuantizer (line 12) | class TestResidualVectorQuantizer: method test_rvq (line 14) | def test_rvq(self): FILE: tests/utils/test_audio_effects.py class TestAudioEffect (line 15) | class TestAudioEffect: method audio_effects (line 19) | def audio_effects(self): method test_select_empty_effects (line 86) | def test_select_empty_effects(self): method test_select_wrong_strategy (line 90) | def test_select_wrong_strategy(self): method test_selection (line 97) | def test_selection(self, audio_effects):