SYMBOL INDEX (2093 symbols across 170 files) FILE: .github/compute_wheel_version.py function get_tagged_version (line 15) | def get_tagged_version() -> Optional[str]: function get_dev_version (line 33) | def get_dev_version() -> str: FILE: .github/gpu_benchmark_diff.py class NamedObject (line 13) | class NamedObject: method __init__ (line 14) | def __init__(self, name) -> None: function git_file_at (line 18) | def git_file_at(filename: str, ref: str) -> str: FILE: .github/run-clang-format.py class ExitStatus (line 45) | class ExitStatus: function list_files (line 51) | def list_files(files, recursive=False, extensions=None, exclude=None): function make_diff (line 81) | def make_diff(file, original, reformatted): class DiffError (line 93) | class DiffError(Exception): method __init__ (line 94) | def __init__(self, message, errs=None): class UnexpectedError (line 99) | class UnexpectedError(Exception): method __init__ (line 100) | def __init__(self, message, exc=None): function run_clang_format_diff_wrapper (line 106) | def run_clang_format_diff_wrapper(args, file): function run_clang_format_diff (line 116) | def run_clang_format_diff(args, file): function bold_red (line 172) | def bold_red(s): function colorize (line 176) | def colorize(diff_lines): function print_diff (line 202) | def print_diff(diff_lines, use_color): function print_trouble (line 208) | def print_trouble(prog, message, use_colors): function main (line 215) | def main(): FILE: .github/selective_ci/selective_ci.py class ComponentInfo (line 16) | class ComponentInfo: function list_files_in_commit (line 86) | def list_files_in_commit(commit: git.Commit): function check_patterns_are_valid (line 100) | def check_patterns_are_valid(patterns): FILE: docs/source/conf.py function setup (line 136) | def setup(app): FILE: examples/llama_inference/generate.py class GenArgs (line 29) | class GenArgs: class FastGen (line 37) | class FastGen: method build (line 42) | def build( method __init__ (line 87) | def __init__( method generate_all (line 100) | def generate_all( function get_prompts (line 207) | def get_prompts(interactive: bool) -> Iterable[list[str]]: function main (line 224) | def main(ckpt_dir: str, interactive: bool, add_instruction_tags: bool): FILE: examples/llama_inference/model.py class ModelArgs (line 21) | class ModelArgs: class Attention (line 51) | class Attention(nn.Module): method __init__ (line 52) | def __init__( method load_hook (line 84) | def load_hook( method forward (line 100) | def forward( class FeedForward (line 154) | class FeedForward(nn.Module): method __init__ (line 155) | def __init__( method load_hook (line 186) | def load_hook( method forward (line 201) | def forward(self, x: torch.Tensor) -> torch.Tensor: class TransformerBlock (line 209) | class TransformerBlock(nn.Module): method __init__ (line 210) | def __init__(self, args: ModelArgs, layer_index: int): method forward (line 243) | def forward( class Transformer (line 266) | class Transformer(nn.Module): method __init__ (line 267) | def __init__(self, args: ModelArgs): method forward_with_attn_bias (line 292) | def forward_with_attn_bias( method forward (line 308) | def forward( function make_cache (line 324) | def make_cache( function cache_prefix (line 371) | def cache_prefix(cache: list[LayerCache], length: int) -> list[LayerCache]: FILE: examples/llama_inference/mp_utils.py function initialize (line 18) | def initialize( function get_world_size (line 83) | def get_world_size() -> int: function get_rank (line 90) | def get_rank() -> int: function all_gather (line 96) | def all_gather(x: torch.Tensor) -> torch.Tensor: function all_reduce (line 110) | def all_reduce(x: torch.Tensor): FILE: examples/llama_inference/sample_utils.py function top_p (line 9) | def top_p(probs: torch.Tensor, p: float) -> torch.Tensor: FILE: examples/llama_inference/stats.py class PhaseStats (line 12) | class PhaseStats: method show (line 17) | def show(self) -> str: class Stats (line 27) | class Stats: method __init__ (line 32) | def __init__(self): method end_phase (line 36) | def end_phase(self, tokens: int, now: Optional[float] = None): method phase (line 50) | def phase(self, name: str, tokens: int = 0): FILE: examples/llama_inference/tokenizer.py class Tokenizer (line 11) | class Tokenizer: method __init__ (line 14) | def __init__(self, model_path: str): method encode (line 36) | def encode(self, s: str, bos: bool = True, eos: bool = False) -> list[... method decode (line 56) | def decode(self, t: list[int]) -> str: FILE: setup.py function get_extra_nvcc_flags_for_build_type (line 54) | def get_extra_nvcc_flags_for_build_type(cuda_version: int) -> List[str]: function fetch_requirements (line 72) | def fetch_requirements(): function get_local_version_suffix (line 78) | def get_local_version_suffix() -> str: function generate_version_py (line 89) | def generate_version_py(version: str) -> str: function get_cuda_version (line 98) | def get_cuda_version(cuda_dir) -> int: function get_hip_version (line 111) | def get_hip_version(rocm_dir) -> Optional[str]: function rename_cpp_cu (line 128) | def rename_cpp_cu(cpp_files): function get_extensions (line 133) | def get_extensions(): class clean (line 362) | class clean(distutils.command.clean.clean): # type: ignore method run (line 363) | def run(self): class bdist_wheel_abi_none (line 378) | class bdist_wheel_abi_none(_bdist_wheel if _bdist_wheel else object): #... method get_tag (line 386) | def get_tag(self): class BuildExtensionWithExtraFiles (line 399) | class BuildExtensionWithExtraFiles(BuildExtension): method __init__ (line 400) | def __init__(self, *args, **kwargs) -> None: method get_export_symbols (line 405) | def get_export_symbols(self, ext): method build_extensions (line 411) | def build_extensions(self) -> None: method copy_extensions_to_source (line 420) | def copy_extensions_to_source(self) -> None: method get_ext_filename (line 434) | def get_ext_filename(self, ext_name): FILE: stubs/fvcore/nn.pyi function __getattr__ (line 8) | def __getattr__(name) -> Any: ... FILE: stubs/matplotlib/pyplot.pyi function __getattr__ (line 8) | def __getattr__(name) -> Any: ... FILE: stubs/numpy/__init__.pyi class _ArrayOrScalarCommon (line 37) | class _ArrayOrScalarCommon( class float (line 45) | class float: ... class ndarray (line 47) | class ndarray(_ArrayOrScalarCommon[DType, Unpack[Ts]], Iterable, Sized, ... method __init__ (line 48) | def __init__( method __getitem__ (line 58) | def __getitem__( method __getitem__ (line 62) | def __getitem__( method __setitem__ (line 65) | def __setitem__(self, key, value): ... method shape (line 67) | def shape(self) -> Tuple[Unpack[Ts]]: ... method reshape (line 69) | def reshape(self, shape: Tuple[Unpack[Ts2]]) -> ndarray[DType, Unpack[... method reshape (line 71) | def reshape(self, *shape: Unpack[Ts2]) -> ndarray[DType, Unpack[Ts2]]:... method __add__ (line 72) | def __add__(self, other) -> ndarray[DType, Unpack[Ts]]: ... method __div__ (line 73) | def __div__(self, other) -> ndarray[DType, Unpack[Ts]]: ... method __truediv__ (line 74) | def __truediv__(self, other) -> ndarray[DType, Unpack[Ts]]: ... method astype (line 77) | def astype(self, dtype: Type[NewDType]) -> ndarray[NewDType, Unpack[Ts... method astype (line 79) | def astype(self, dtype: Literal["int64"]) -> ndarray[int64, Unpack[Ts]... method astype (line 81) | def astype(self, dtype: Literal["float32"]) -> ndarray[float32, Unpack... method astype (line 83) | def astype(self, dtype: Literal["float64"]) -> ndarray[float64, Unpack... function empty (line 89) | def empty( function empty (line 95) | def empty( function empty (line 101) | def empty(shape: N, dtype: Type[DType]) -> ndarray[DType, N]: ... function array (line 104) | def array( function sin (line 111) | def sin(x: ndarray[DType, Unpack[Ts]]) -> ndarray[DType, Unpack[Ts]]: ... class int64 (line 113) | class int64: method __init__ (line 114) | def __init__(self, value=...): ... class float32 (line 116) | class float32: method __init__ (line 117) | def __init__(self, value=...): ... class float64 (line 119) | class float64: method __init__ (line 120) | def __init__(self, value=...): ... FILE: stubs/pandas.pyi function __getattr__ (line 8) | def __getattr__(name) -> Any: ... FILE: stubs/recommonmark/transform.pyi function __getattr__ (line 8) | def __getattr__(name) -> Any: ... FILE: stubs/seaborn.pyi function __getattr__ (line 8) | def __getattr__(name) -> Any: ... FILE: stubs/sklearn/model_selection.pyi function __getattr__ (line 8) | def __getattr__(name) -> Any: ... FILE: stubs/submitit.pyi function __getattr__ (line 8) | def __getattr__(name) -> Any: ... FILE: stubs/tensorflow.pyi function __getattr__ (line 8) | def __getattr__(name) -> Any: ... FILE: stubs/torch/__init__.pyi class complex64 (line 66) | class complex64: ... class complex128 (line 67) | class complex128: ... class float16 (line 68) | class float16: ... class float32 (line 69) | class float32: ... class float64 (line 70) | class float64: ... class int64 (line 71) | class int64: ... class int32 (line 72) | class int32: ... class bool (line 73) | class bool: ... class memory_format (line 74) | class memory_format: ... class long (line 80) | class long: ... class layout (line 81) | class layout: ... class MaxNamedTuple (line 87) | class MaxNamedTuple(Generic[DType, Unpack[Ts]]): method __getitem__ (line 91) | def __getitem__(self, key: L[0]) -> Tensor[DType, Unpack[Ts]]: ... method __getitem__ (line 93) | def __getitem__(self, key: L[1]) -> Tensor[int64, Unpack[Ts]]: ... class device (line 95) | class device: method __init__ (line 96) | def __init__(self, device_str: str): ... class Size (line 101) | class Size(Tuple[builtins.int, ...]): method __getitem__ (line 103) | def __getitem__(self: Size, key: builtins.int) -> builtins.int: ... method __getitem__ (line 105) | def __getitem__(self: Size, key: slice) -> Size: ... method numel (line 106) | def numel(self: Size) -> builtins.int: ... class Generator (line 108) | class Generator(object): method __init__ (line 110) | def __init__(self, device: Union[_device, str, None] = None) -> None: ... method get_state (line 111) | def get_state(self) -> Tensor: ... method set_state (line 112) | def set_state(self, _new_state: Tensor) -> Generator: ... method manual_seed (line 113) | def manual_seed(self, seed: builtins.int) -> Generator: ... method seed (line 114) | def seed(self) -> builtins.int: ... method initial_seed (line 115) | def initial_seed(self) -> builtins.int: ... class Storage (line 119) | class Storage(object): method __deepcopy__ (line 121) | def __deepcopy__(self, memo) -> "Storage": ... method _new_shared (line 122) | def _new_shared(self, int) -> "Storage": ... method _write_file (line 123) | def _write_file( method element_size (line 126) | def element_size(self) -> int: ... method is_shared (line 127) | def is_shared(self) -> bool: ... method share_memory_ (line 128) | def share_memory_(self) -> "Storage": ... method size (line 129) | def size(self) -> int: ... class Tensor (line 131) | class Tensor(Generic[DType, Unpack[Ts]]): method __init__ (line 147) | def __init__(self, other: Tensor[DType, Unpack[Ts]]) -> None: ... method __init__ (line 149) | def __init__( method __init__ (line 153) | def __init__(self, storage: Storage) -> None: ... method __init__ (line 155) | def __init__( method device (line 159) | def device(self) -> _device: ... method dtype (line 161) | def dtype(self) -> Type[DType]: ... method long (line 162) | def long(self) -> "LongTensor[DType, Unpack[Ts]]": ... method size (line 171) | def size(self: Tensor[DType, N1, Unpack[Rs]], axis: L[0]) -> N1: ... method size (line 173) | def size(self: Tensor[DType, N1, N2, Unpack[Rs]], axis: L[1]) -> N2: ... method size (line 175) | def size(self: Tensor[DType, Unpack[Rs], N1], axis: L[-1]) -> N1: ... method size (line 177) | def size(self: Tensor[DType, Unpack[Rs], N1, N2], axis: L[-2]) -> N1: ... method size (line 179) | def size(self: Tensor[DType, Unpack[Rs]]) -> Tuple[Unpack[Rs]]: ... method split (line 181) | def split( method split (line 185) | def split( method item (line 191) | def item(self: Tensor[DType, L[1]]) -> DType: ... method item (line 193) | def item(self: Tensor[DType]) -> DType: ... method numel (line 194) | def numel(self) -> builtins.int: ... method backward (line 195) | def backward(self) -> None: ... method __getitem__ (line 197) | def __getitem__( method __getitem__ (line 201) | def __getitem__( method __getitem__ (line 205) | def __getitem__( method __getitem__ (line 209) | def __getitem__(self, item: Any) -> Any: ... method expand (line 211) | def expand( method expand (line 215) | def expand( method detach (line 218) | def detach(self: T) -> T: ... method numpy (line 220) | def numpy(self) -> ndarray[DType, Unpack[Ts]]: ... method to (line 224) | def to( method to (line 228) | def to( method __add__ (line 233) | def __add__( method __add__ (line 237) | def __add__( method __iadd__ (line 242) | def __iadd__( method __iadd__ (line 246) | def __iadd__( method __radd__ (line 251) | def __radd__( method __radd__ (line 255) | def __radd__( method __sub__ (line 260) | def __sub__( method __sub__ (line 264) | def __sub__( method __isub__ (line 269) | def __isub__( method __isub__ (line 273) | def __isub__( method __rsub__ (line 278) | def __rsub__( method __rsub__ (line 282) | def __rsub__( method __mul__ (line 287) | def __mul__( method __mul__ (line 292) | def __mul__( method __imul__ (line 297) | def __imul__( method __imul__ (line 302) | def __imul__( method __rmul__ (line 307) | def __rmul__( method __rmul__ (line 312) | def __rmul__( method __pow__ (line 317) | def __pow__( method __pow__ (line 322) | def __pow__( method __truediv__ (line 327) | def __truediv__( method __truediv__ (line 332) | def __truediv__( method __itruediv__ (line 337) | def __itruediv__( method __itruediv__ (line 342) | def __itruediv__( method __rtruediv__ (line 347) | def __rtruediv__( method __floordiv__ (line 352) | def __floordiv__( method __floordiv__ (line 357) | def __floordiv__( method __ifloordiv__ (line 362) | def __ifloordiv__( method __ifloordiv__ (line 367) | def __ifloordiv__( method __rfloordiv__ (line 372) | def __rfloordiv__( method __invert__ (line 376) | def __invert__(self) -> Tensor[DType, Unpack[Ts]]: ... method __neg__ (line 377) | def __neg__(self) -> Tensor[DType, Unpack[Ts]]: ... method __iand__ (line 378) | def __iand__( method __and__ (line 382) | def __and__( method __matmul__ (line 387) | def __matmul__( method __matmul__ (line 392) | def __matmul__( method __ne__ (line 398) | def __ne__( method abs (line 401) | def abs(self) -> Tensor[DType, Unpack[Ts]]: ... method all (line 403) | def all( method all (line 407) | def all( method all (line 412) | def all( method argmax (line 417) | def argmax( method argmax (line 423) | def argmax( method argmax (line 429) | def argmax( method argmax (line 435) | def argmax( method argmax (line 441) | def argmax( method argmax (line 447) | def argmax( method argmax (line 453) | def argmax( method argmax (line 459) | def argmax( method argmax (line 465) | def argmax( method argmin (line 471) | def argmin( method argmin (line 477) | def argmin( method argmin (line 483) | def argmin( method argmin (line 489) | def argmin( method argmin (line 495) | def argmin( method argmin (line 501) | def argmin( method argmin (line 507) | def argmin( method argmin (line 513) | def argmin( method argmin (line 519) | def argmin( method chunk (line 528) | def chunk(self: Tensor[DType, Unpack[Rs], N], chunks: L[2], dim: L[-1]... method chunk (line 533) | def chunk( method clone (line 539) | def clone( method count_nonzero (line 543) | def count_nonzero( method count_nonzero (line 548) | def count_nonzero( method count_nonzero (line 553) | def count_nonzero( method count_nonzero (line 558) | def count_nonzero( method count_nonzero (line 563) | def count_nonzero( method dim (line 569) | def dim(self: Tensor[DType]) -> L[0]: ... method dim (line 571) | def dim(self: Tensor[DType, builtins.int]) -> L[1]: ... method dim (line 573) | def dim(self: Tensor[DType, builtins.int, builtins.int]) -> L[2]: ... method dim (line 575) | def dim(self: Tensor[DType, builtins.int, builtins.int, builtins.int])... method half (line 576) | def half( method is_contiguous (line 579) | def is_contiguous( method indices (line 582) | def indices(self) -> Tensor: ... method masked_select (line 584) | def masked_select(self, mask: Tensor, *, out: Optional[Tensor] = ...) ... method max (line 586) | def max( method max (line 592) | def max( method max (line 598) | def max( method max (line 604) | def max( method max (line 610) | def max( method max (line 616) | def max( method max (line 622) | def max( method max (line 628) | def max( method max (line 634) | def max( method max (line 640) | def max( method max (line 646) | def max( method mean (line 650) | def mean( method mean (line 656) | def mean( method mean (line 662) | def mean( method mean (line 668) | def mean( method mean (line 674) | def mean( method mean (line 680) | def mean( method mean (line 686) | def mean( method mean (line 692) | def mean( method mean (line 698) | def mean( method bitwise_not (line 703) | def bitwise_not(self) -> Tensor[DType, Unpack[Ts]]: ... method bitwise_not_ (line 704) | def bitwise_not_(self) -> Tensor[DType, Unpack[Ts]]: ... method diff (line 706) | def diff( method diff (line 711) | def diff( method diff (line 716) | def diff( method diff (line 721) | def diff( method is_sparse (line 725) | def is_sparse(self) -> builtins.bool: ... method coalesce (line 726) | def coalesce(self: Tensor[DType, Unpack[Rs]]) -> Tensor[DType, Unpack[... method values (line 727) | def values(self: Tensor[DType, Unpack[Rs]]) -> Tensor[DType, Unpack[Rs... method to_sparse (line 728) | def to_sparse(self: Tensor[DType, Unpack[Ts]]) -> Tensor[DType, Unpack... method __eq__ (line 733) | def __eq__( method __eq__ (line 738) | def __eq__( method argsort (line 742) | def argsort( method bmm (line 745) | def bmm( method diag_embed (line 748) | def diag_embed( method matmul (line 752) | def matmul( method matmul (line 757) | def matmul( method multinomial (line 763) | def multinomial( method new_ones (line 771) | def new_ones( method new_ones (line 779) | def new_ones( method unsqueeze (line 787) | def unsqueeze( method unsqueeze (line 791) | def unsqueeze( method unsqueeze (line 795) | def unsqueeze( method unsqueeze (line 799) | def unsqueeze( method unsqueeze_ (line 803) | def unsqueeze_( method unsqueeze_ (line 807) | def unsqueeze_( method unsqueeze_ (line 811) | def unsqueeze_( method unsqueeze_ (line 815) | def unsqueeze_( method real (line 819) | def real(self: Tensor[complex64, Unpack[Rs]]) -> Tensor[float32, Unpac... method repeat (line 821) | def repeat( method repeat (line 825) | def repeat( method repeat (line 829) | def repeat( method repeat_interleave (line 833) | def repeat_interleave( method repeat_interleave (line 837) | def repeat_interleave( method repeat_interleave (line 841) | def repeat_interleave( method repeat_interleave (line 845) | def repeat_interleave( method repeat_interleave (line 850) | def repeat_interleave( method __setitem__ (line 853) | def __setitem__(self, item: object, other: object) -> None: ... method scatter (line 855) | def scatter( method scatter_ (line 863) | def scatter_( method softmax (line 871) | def softmax(self, dim: builtins.int) -> Tensor[DType, Unpack[Ts]]: ... method softmax (line 873) | def softmax( method stride (line 877) | def stride( method stride (line 881) | def stride( method stride (line 885) | def stride( method stride (line 889) | def stride(self) -> Tuple[Unpack[Ts]]: ... method squeeze (line 891) | def squeeze( method squeeze (line 895) | def squeeze( method squeeze (line 899) | def squeeze( method squeeze (line 906) | def squeeze( method squeeze (line 913) | def squeeze( method type_as (line 916) | def type_as( method squeeze_ (line 920) | def squeeze_( method squeeze_ (line 924) | def squeeze_( method squeeze_ (line 928) | def squeeze_( method squeeze_ (line 935) | def squeeze_( method squeeze_ (line 942) | def squeeze_( method view (line 946) | def view( method view (line 952) | def view( method view (line 958) | def view( method view (line 964) | def view(self, *shape: Unpack[Rs]) -> Tensor[DType, Unpack[Rs]]: ... method transpose (line 966) | def transpose( method transpose (line 970) | def transpose( method transpose (line 974) | def transpose( method transpose (line 978) | def transpose( method transpose (line 982) | def transpose( method flatten (line 986) | def flatten( method flatten (line 992) | def flatten( method flatten (line 998) | def flatten( method flatten (line 1004) | def flatten( method flatten (line 1010) | def flatten( method __lt__ (line 1016) | def __lt__( method __lt__ (line 1020) | def __lt__( method __gt__ (line 1024) | def __gt__( method __gt__ (line 1028) | def __gt__( method logical_and (line 1031) | def logical_and( method logical_and_ (line 1037) | def logical_and_( method reshape (line 1044) | def reshape( method reshape (line 1050) | def reshape( method reshape (line 1056) | def reshape(self, *shape: Unpack[Rs]) -> Tensor[DType, Unpack[Rs]]: ... method unbind (line 1058) | def unbind( method unbind (line 1062) | def unbind( method unbind (line 1066) | def unbind( method sign (line 1069) | def sign(self, *, out: Optional[Tensor] = ...) -> Tensor[DType, Unpack... method sum (line 1071) | def sum( method sum (line 1078) | def sum( method sum (line 1085) | def sum( method sum (line 1092) | def sum( method sum (line 1099) | def sum( method cumsum (line 1105) | def cumsum( method contiguous (line 1110) | def contiguous(input: Tensor[DType, Unpack[Rs]]) -> Tensor[DType, Unpa... class LongTensor (line 1112) | class LongTensor(Tensor[DType, Unpack[Ts]], Generic[DType, Unpack[Ts]]): method __getitem__ (line 1114) | def __getitem__( method __getitem__ (line 1118) | def __getitem__( method __getitem__ (line 1122) | def __getitem__( method __eq__ (line 1125) | def __eq__( function allclose (line 1133) | def allclose( function bitwise_not (line 1140) | def bitwise_not( function einsum (line 1143) | def einsum( function eye (line 1148) | def eye( function eye (line 1159) | def eye( function eye (line 1171) | def eye( function eye (line 1182) | def eye( function zeros (line 1194) | def zeros( function zeros (line 1205) | def zeros( function zeros (line 1217) | def zeros( function zeros (line 1227) | def zeros( function ones (line 1238) | def ones(*size: Unpack[Ts]) -> Tensor[float, Unpack[Ts]]: ... function ones (line 1240) | def ones( function ones_like (line 1244) | def ones_like( function ones_like (line 1255) | def ones_like( function tril (line 1265) | def tril( function arange (line 1269) | def arange( function arange (line 1279) | def arange( function arange (line 1290) | def arange( function arange (line 1303) | def arange( function arange (line 1313) | def arange( function arange (line 1324) | def arange( function arange (line 1335) | def arange( function argmax (line 1346) | def argmax( function argmax (line 1352) | def argmax( function argmax (line 1358) | def argmax( function argmax (line 1364) | def argmax( function argmax (line 1370) | def argmax( function argmax (line 1376) | def argmax( function argmax (line 1382) | def argmax( function argmax (line 1388) | def argmax( function argmax (line 1394) | def argmax( function argmin (line 1400) | def argmin( function argmin (line 1406) | def argmin( function argmin (line 1412) | def argmin( function argmin (line 1418) | def argmin( function argmin (line 1424) | def argmin( function argmin (line 1430) | def argmin( function argmin (line 1436) | def argmin( function argmin (line 1442) | def argmin( function argmin (line 1448) | def argmin( function bmm (line 1453) | def bmm( function chunk (line 1457) | def chunk(input: Tensor[DType, Unpack[Ts], N], chunks: L[2], dim: L[-1])... function chunk (line 1462) | def chunk(input: Tensor[DType, N, Unpack[Ts]], chunks: L[2], dim: L[0] =... function diag (line 1466) | def diag( function diagonal (line 1472) | def diagonal( function diag_embed (line 1478) | def diag_embed( function empty_like (line 1485) | def empty_like( function empty_like (line 1497) | def empty_like( function logical_and (line 1508) | def logical_and( function log_softmax (line 1515) | def log_softmax( function log_softmax (line 1521) | def log_softmax( function masked_select (line 1527) | def masked_select( function max (line 1531) | def max( function max (line 1537) | def max( function max (line 1543) | def max( function max (line 1549) | def max( function max (line 1555) | def max( function max (line 1561) | def max( function max (line 1567) | def max( function max (line 1573) | def max( function max (line 1579) | def max( function max (line 1585) | def max( function max (line 1591) | def max( function mean (line 1595) | def mean( function mean (line 1601) | def mean( function mean (line 1607) | def mean( function mean (line 1613) | def mean( function mean (line 1619) | def mean( function mean (line 1625) | def mean( function mean (line 1631) | def mean( function mean (line 1637) | def mean( function mean (line 1643) | def mean( function meshgrid (line 1649) | def meshgrid(tensor1: Tensor[DType, N1]) -> Tuple[Tensor[DType, N1]]: ... function meshgrid (line 1651) | def meshgrid( function meshgrid (line 1656) | def meshgrid( function meshgrid (line 1664) | def meshgrid(*tensors: Tensor) -> Tuple[Tensor, ...]: ... function norm (line 1666) | def norm( function norm (line 1675) | def norm( function norm (line 1684) | def norm( function norm (line 1693) | def norm( function norm (line 1702) | def norm( function norm (line 1711) | def norm( function norm (line 1720) | def norm( function norm (line 1729) | def norm( function norm (line 1738) | def norm( function normal (line 1747) | def normal( function rand (line 1756) | def rand( function rand (line 1767) | def rand( function rand (line 1779) | def rand( function rand (line 1789) | def rand( function randint (line 1800) | def randint( function randint (line 1813) | def randint( function randint (line 1826) | def randint( function randint (line 1839) | def randint( function rand_like (line 1851) | def rand_like( function nonzero (line 1854) | def nonzero( function repeat_interleave (line 1858) | def repeat_interleave( function repeat_interleave (line 1862) | def repeat_interleave( function repeat_interleave (line 1866) | def repeat_interleave( function repeat_interleave (line 1870) | def repeat_interleave( function repeat_interleave (line 1876) | def repeat_interleave( function stack (line 1880) | def stack( function stack (line 1887) | def stack( function stack (line 1894) | def stack( function stack (line 1905) | def stack( function stack (line 1916) | def stack( function cdist (line 1922) | def cdist( function clone (line 1928) | def clone( function count_nonzero (line 1932) | def count_nonzero( function count_nonzero (line 1937) | def count_nonzero( function count_nonzero (line 1942) | def count_nonzero( function count_nonzero (line 1947) | def count_nonzero( function count_nonzero (line 1952) | def count_nonzero( function sum (line 1958) | def sum( function sum (line 1962) | def sum( function sum (line 1969) | def sum( function sum (line 1973) | def sum( function sum (line 1980) | def sum( function sin (line 1987) | def sin( function cos (line 1990) | def cos( function exp (line 1993) | def exp( function matmul (line 1997) | def matmul( function matmul (line 2004) | def matmul( function multinomial (line 2010) | def multinomial( function unbind (line 2018) | def unbind( function unbind (line 2022) | def unbind( function unbind (line 2026) | def unbind( function unsqueeze (line 2030) | def unsqueeze( function unsqueeze (line 2034) | def unsqueeze( function unsqueeze (line 2038) | def unsqueeze( function unsqueeze (line 2042) | def unsqueeze( function real (line 2046) | def real(input: Tensor[complex64, Unpack[Ts]]) -> Tensor[float32, Unpack... function real (line 2048) | def real(input: Tensor[complex128, Unpack[Ts]]) -> Tensor[float64, Unpac... function zeros_like (line 2049) | def zeros_like( function randn (line 2053) | def randn( function randn (line 2063) | def randn( function randn (line 2072) | def randn( function randn (line 2082) | def randn( function all (line 2091) | def all( function all (line 2095) | def all( function all (line 2100) | def all( function randperm (line 2105) | def randperm( function randperm (line 2117) | def randperm( function sqrt (line 2128) | def sqrt( function where (line 2132) | def where( function where (line 2146) | def where(condition: Tensor[DType, Unpack[Ts]]) -> Any: ... function diff (line 2148) | def diff( function diff (line 2153) | def diff( function diff (line 2158) | def diff( function diff (line 2163) | def diff( function argsort (line 2167) | def argsort( function cat (line 2175) | def cat( function cat (line 2182) | def cat( function cat (line 2189) | def cat( function cat (line 2198) | def cat( function cat (line 2209) | def cat( function cat (line 2220) | def cat( function cat (line 2235) | def cat( function sign (line 2248) | def sign( function sparse_coo_tensor (line 2252) | def sparse_coo_tensor( function sparse_coo_tensor (line 2262) | def sparse_coo_tensor( function sparse_coo_tensor (line 2272) | def sparse_coo_tensor( function softmax (line 2282) | def softmax( function softmax (line 2286) | def softmax( function transpose (line 2290) | def transpose( function transpose (line 2294) | def transpose( function transpose (line 2298) | def transpose( function transpose (line 2302) | def transpose( function transpose (line 2306) | def transpose( function empty (line 2310) | def empty( function empty (line 2320) | def empty( function empty (line 2331) | def empty( function empty (line 2343) | def empty( function flatten (line 2354) | def flatten( function flatten (line 2360) | def flatten( function flatten (line 2366) | def flatten( function flatten (line 2372) | def flatten( function flatten (line 2378) | def flatten( function reshape (line 2384) | def reshape( function reshape (line 2388) | def reshape( function reshape (line 2394) | def reshape( FILE: stubs/torch/autograd/__init__.pyi class Function (line 12) | class Function: method apply (line 14) | def apply(cls, *args: object) -> Any: ... class enable_grad (line 16) | class enable_grad: method __enter__ (line 17) | def __enter__(self) -> None: ... method __exit__ (line 18) | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> N... function backward (line 20) | def backward( FILE: stubs/torch/autograd/profiler.pyi class record_function (line 9) | class record_function(contextlib.ContextDecorator): method __init__ (line 10) | def __init__(self, name: str) -> None: ... method __enter__ (line 11) | def __enter__(self) -> Any: ... method __exit__ (line 12) | def __exit__(self, *exctype: Any) -> None: ... FILE: stubs/torch/cuda/__init__.pyi function __getattr__ (line 8) | def __getattr__(name) -> Any: ... FILE: stubs/torch/fft/__init__.pyi function fft (line 16) | def fft( function fft2 (line 25) | def fft2( FILE: stubs/torch/linalg/__init__.pyi function pinv (line 20) | def pinv( function pinv (line 28) | def pinv( function qr (line 36) | def qr( function qr (line 48) | def qr( function norm (line 57) | def norm( FILE: stubs/torch/nn/__init__.pyi class Module (line 38) | class Module: method __call__ (line 39) | def __call__(self, *args: Any, **kwargs: Any) -> Any: ... method parameters (line 40) | def parameters(self) -> Iterator[Any]: ... method double (line 41) | def double(self: T) -> T: ... method to (line 42) | def to(self, dtype: Type[T], device: torch._device = ...) -> Module: ... method eval (line 43) | def eval(self) -> Module: ... method train (line 44) | def train(self, mode: bool) -> Module: ... method register_parameter (line 45) | def register_parameter(self, name: str, param: Optional[Parameter]) ->... class LSTMCell (line 49) | class LSTMCell(Module, Generic[InputSize, HiddenSize]): method __init__ (line 50) | def __init__( method __call__ (line 53) | def __call__( class Linear (line 61) | class Linear(Module, Generic[InputSize, OutputSize]): method __init__ (line 62) | def __init__( method __call__ (line 65) | def __call__( class _Loss (line 70) | class _Loss(Module): ... class MSELoss (line 72) | class MSELoss(_Loss): method __init__ (line 73) | def __init__( method __call__ (line 79) | def __call__( class Conv2d (line 98) | class Conv2d( method __init__ (line 102) | def __init__( method __call__ (line 110) | def __call__( class ReflectionPad2d (line 123) | class ReflectionPad2d(Module, Generic[Padding]): method __init__ (line 124) | def __init__( method __call__ (line 128) | def __call__( class InstanceNorm2d (line 139) | class InstanceNorm2d(Generic[Channels]): method __init__ (line 140) | def __init__(self, num_features: Channels, affine: bool = False) -> No... method __call__ (line 141) | def __call__( class LeakyReLU (line 145) | class LeakyReLU(Module): method __init__ (line 146) | def __init__(self, negative_slope: float = ..., inplace: bool = ...) -... method __call__ (line 147) | def __call__( class ReLU (line 151) | class ReLU(Module): method __call__ (line 152) | def __call__( class GELU (line 156) | class GELU(Module): method __call__ (line 157) | def __call__( class Dropout (line 161) | class Dropout(Module): method __init__ (line 162) | def __init__(self, p: float, inplace: bool = ...) -> None: ... method __call__ (line 163) | def __call__( class Embedding (line 167) | class Embedding(Module, Generic[N, EmbeddingDimension]): method __init__ (line 168) | def __init__( method padding_idx (line 180) | def padding_idx(self) -> int: ... method max_norm (line 182) | def max_norm(self) -> float: ... method norm_type (line 184) | def norm_type(self) -> float: ... method scale_grad_by_freq (line 186) | def scale_grad_by_freq(self) -> bool: ... method sparse (line 188) | def sparse(self) -> bool: ... method weight (line 190) | def weight(self) -> Tensor[torch.float32, N, EmbeddingDimension]: ... method from_pretrained (line 192) | def from_pretrained( method forward (line 202) | def forward( method __call__ (line 205) | def __call__( class LayerNorm (line 211) | class LayerNorm(Module): method __init__ (line 212) | def __init__( method forward (line 220) | def forward(self, x: Tensor[DType, Unpack[Ts]]) -> Tensor[DType, Unpac... method __call__ (line 221) | def __call__(self, x: Tensor[DType, Unpack[Ts]]) -> Tensor[DType, Unpa... class AdaptiveAvgPool2d (line 223) | class AdaptiveAvgPool2d(Module, Generic[H, W]): method __new__ (line 225) | def __new__( method __new__ (line 230) | def __new__( method __new__ (line 235) | def __new__( method __new__ (line 240) | def __new__( method forward (line 244) | def forward(self, x: Tensor[DType, Unpack[Ts]]) -> Tensor[DType, Unpac... method __call__ (line 246) | def __call__( method __call__ (line 250) | def __call__( method __call__ (line 254) | def __call__( class ModuleList (line 258) | class ModuleList(Module): method __init__ (line 259) | def __init__(self, modules: Optional[Iterable[Module]] = ...) -> None:... method __iter__ (line 260) | def __iter__(self) -> Iterator[Module]: ... method __len__ (line 261) | def __len__(self) -> int: ... class Parameter (line 263) | class Parameter(Tensor[DType, Unpack[Ts]]): method __init__ (line 264) | def __init__( FILE: stubs/torch/nn/functional.pyi function pad (line 26) | def pad( function pad (line 33) | def pad( FILE: stubs/torch/nn/functional/__init__.pyi function pad (line 26) | def pad( function pad (line 33) | def pad( FILE: stubs/torch/nn/init.pyi function _calculate_fan_in_and_fan_out (line 15) | def _calculate_fan_in_and_fan_out(tensor: Tensor) -> Tuple[int, int]: ... function constant_ (line 16) | def constant_( function kaiming_uniform_ (line 19) | def kaiming_uniform_( function normal_ (line 22) | def normal_( function uniform_ (line 25) | def uniform_( function _no_grad_uniform_ (line 28) | def _no_grad_uniform_(tensor: Tensor, a, b): ... function xavier_uniform_ (line 29) | def xavier_uniform_(tensor: Tensor, gain: float = ...) -> Tensor: ... FILE: stubs/torch/nn/utils/__init__.pyi function clip_grad_norm_ (line 16) | def clip_grad_norm_( function clip_grad_value_ (line 22) | def clip_grad_value_( FILE: stubs/torch/ops.pyi function __getattr__ (line 8) | def __getattr__(name) -> Any: ... FILE: stubs/torch/optim/__init__.pyi function __getattr__ (line 8) | def __getattr__(name) -> Any: ... FILE: stubs/torch/profiler/__init__.pyi function __getattr__ (line 8) | def __getattr__(name) -> Any: ... FILE: stubs/torch/random/__init__.pyi function initial_seed (line 8) | def initial_seed() -> int: ... function __getattr__ (line 9) | def __getattr__(name) -> Any: ... FILE: stubs/torch/sparse/__init__.pyi function softmax (line 20) | def softmax( function softmax (line 24) | def softmax( FILE: stubs/torch_stub_tests.py function test_sin (line 34) | def test_sin() -> None: function test_unsqueeze (line 54) | def test_unsqueeze() -> None: function test_unsqueeze_ (line 73) | def test_unsqueeze_() -> None: function test_squeeze_ (line 83) | def test_squeeze_() -> None: function test_squeeze (line 100) | def test_squeeze() -> None: function test_repeat (line 117) | def test_repeat() -> None: function test_multiply (line 140) | def test_multiply() -> None: function test_floor_division (line 187) | def test_floor_division() -> None: function test_division (line 220) | def test_division() -> None: function test_setitem (line 257) | def test_setitem() -> None: function test_arange (line 262) | def test_arange(n: N) -> None: function test_embedding (line 284) | def test_embedding() -> None: function test_init_normal (line 307) | def test_init_normal() -> None: function test_view (line 317) | def test_view() -> None: function test_reshape (line 350) | def test_reshape() -> None: function test_transpose (line 364) | def test_transpose() -> None: function test_flatten (line 389) | def test_flatten() -> None: function test_empty (line 424) | def test_empty() -> None: function test_empty_like (line 460) | def test_empty_like() -> None: function test_randn (line 482) | def test_randn() -> None: function test_all (line 512) | def test_all() -> None: function test_where (line 526) | def test_where() -> None: function test_getitem (line 551) | def test_getitem() -> None: function test_expand (line 575) | def test_expand() -> None: function test_to (line 588) | def test_to() -> None: function test_Linear_to (line 606) | def test_Linear_to() -> None: function test_Module_eval (line 613) | def test_Module_eval() -> None: function test_Module_train (line 618) | def test_Module_train() -> None: function test_Linear_bias (line 624) | def test_Linear_bias() -> None: function test_sum (line 630) | def test_sum() -> None: function test_cumsum (line 647) | def test_cumsum() -> None: function test_contiguous (line 663) | def test_contiguous() -> None: function test_diff (line 671) | def test_diff() -> None: function test_argsort (line 684) | def test_argsort() -> None: function test_functional_pad (line 705) | def test_functional_pad() -> None: function test_allclose (line 715) | def test_allclose() -> None: function test_new_ones (line 725) | def test_new_ones() -> None: function test_ones_like (line 736) | def test_ones_like() -> None: function test_sparse_softmax (line 755) | def test_sparse_softmax() -> None: function test_eye (line 769) | def test_eye() -> None: function test_adaptive_average_pool2d (line 779) | def test_adaptive_average_pool2d() -> None: function test_randperm (line 801) | def test_randperm() -> None: function test_sqrt (line 809) | def test_sqrt() -> None: function test_multinomial (line 819) | def test_multinomial() -> None: function test_bmm (line 833) | def test_bmm() -> None: function test_subtract (line 847) | def test_subtract() -> None: function test_add (line 875) | def test_add() -> None: function test_torch_fft (line 899) | def test_torch_fft() -> None: function test_torch_real (line 907) | def test_torch_real() -> None: function test_logical_and (line 920) | def test_logical_and() -> None: function test_and (line 940) | def test_and() -> None: function test_linalg_pinv (line 960) | def test_linalg_pinv() -> None: function test_linalg_qr (line 980) | def test_linalg_qr() -> None: function test_torch_matmul (line 997) | def test_torch_matmul() -> None: function test_torch_optim (line 1016) | def test_torch_optim() -> None: function test_torch_cuda (line 1021) | def test_torch_cuda() -> None: function test_torch_profiler (line 1025) | def test_torch_profiler() -> None: function test_mse_loss (line 1029) | def test_mse_loss() -> None: function test_clip_grad_norm (line 1040) | def test_clip_grad_norm() -> None: function test_clip_grad_value (line 1052) | def test_clip_grad_value() -> None: function test_bitwise_not (line 1057) | def test_bitwise_not() -> None: function test_cdist (line 1071) | def test_cdist() -> None: function test_random_manual_seed (line 1086) | def test_random_manual_seed() -> None: function test_clone (line 1090) | def test_clone() -> None: function test_equal (line 1101) | def test_equal() -> None: function test_diag_embed (line 1117) | def test_diag_embed() -> None: function test_unbind (line 1122) | def test_unbind() -> None: function test_size (line 1137) | def test_size() -> None: function test_stack (line 1151) | def test_stack( function test_repeat_interleave (line 1172) | def test_repeat_interleave() -> None: function test_meshgrid (line 1198) | def test_meshgrid() -> None: function test_argmax (line 1227) | def test_argmax() -> None: function test_argmin (line 1258) | def test_argmin() -> None: function test_mean (line 1289) | def test_mean() -> None: function test_count_nonzero (line 1318) | def test_count_nonzero() -> None: function test_cat (line 1334) | def test_cat() -> None: function test_sign (line 1376) | def test_sign() -> None: function test_diagonal (line 1384) | def test_diagonal() -> None: function test_diag (line 1389) | def test_diag() -> None: function test_module_list (line 1394) | def test_module_list() -> None: function test_sparse_coo_tensor (line 1404) | def test_sparse_coo_tensor() -> None: function test_max (line 1418) | def test_max() -> None: function test_einsum (line 1453) | def test_einsum() -> None: function test_type_as (line 1457) | def test_type_as() -> None: function test_softmax (line 1463) | def test_softmax() -> None: function test_conv2d (line 1474) | def test_conv2d() -> None: function test_nn_Parameter (line 1488) | def test_nn_Parameter() -> None: function test_torch_datatypes (line 1496) | def test_torch_datatypes() -> None: function test_norm (line 1501) | def test_norm() -> None: function test_rand (line 1512) | def test_rand() -> None: function test_randint (line 1534) | def test_randint() -> None: function test_zeros (line 1554) | def test_zeros() -> None: function test_stride (line 1575) | def test_stride() -> None: function test_chunk (line 1586) | def test_chunk() -> None: function test_abs (line 1605) | def test_abs() -> None: function test_enable_grad (line 1613) | def test_enable_grad() -> None: function test_normal (line 1618) | def test_normal() -> None: function test_dim (line 1628) | def test_dim() -> None: function test_is_cuda (line 1642) | def test_is_cuda() -> None: function test_autograd_backward (line 1647) | def test_autograd_backward() -> None: function test_linalg_norm (line 1652) | def test_linalg_norm() -> None: function test_Sized (line 1659) | def test_Sized() -> None: function test_initial_seed (line 1663) | def test_initial_seed() -> None: function test_log_softmax (line 1667) | def test_log_softmax() -> None: function test_masked_select (line 1678) | def test_masked_select() -> None: function test__lt__ (line 1687) | def test__lt__() -> None: function test_pow (line 1693) | def test_pow() -> None: function test_item (line 1701) | def test_item() -> None: function test_uniform_ (line 1711) | def test_uniform_() -> None: function test_kaiming_uniform_ (line 1720) | def test_kaiming_uniform_() -> None: function test_constant_ (line 1729) | def test_constant_() -> None: function test_leaky_relu (line 1736) | def test_leaky_relu() -> None: function test_fft_fft2 (line 1747) | def test_fft_fft2() -> None: function test_real (line 1754) | def test_real() -> None: function test_Tensor_init (line 1767) | def test_Tensor_init() -> None: function test_reflection_pad2d (line 1778) | def test_reflection_pad2d() -> None: function test_half (line 1789) | def test_half() -> None: function test_is_contiguous (line 1797) | def test_is_contiguous() -> None: function test_scatter (line 1802) | def test_scatter() -> None: function test_scatter_ (line 1814) | def test_scatter_() -> None: function test_bool (line 1826) | def test_bool() -> None: FILE: stubs/tqdm.pyi function __getattr__ (line 8) | def __getattr__(name) -> Any: ... FILE: stubs/triton/__init__.pyi function __getattr__ (line 8) | def __getattr__(name) -> Any: ... FILE: stubs/triton/language.pyi function __getattr__ (line 8) | def __getattr__(name) -> Any: ... FILE: stubs/triton/ops/blocksparse.pyi function __getattr__ (line 8) | def __getattr__(name) -> Any: ... FILE: tests/multiprocessing_utils.py class SafeMpContext (line 17) | class SafeMpContext(multiprocessing.context.BaseContext): method __init__ (line 18) | def __init__(self) -> None: method Process (line 22) | def Process(self, *args, **kwargs) -> multiprocessing.context.SpawnPro... method kill_all_processes (line 28) | def kill_all_processes(self): method log_bad_exit_codes (line 43) | def log_bad_exit_codes(self): method __getattr__ (line 58) | def __getattr__(self, name: str): method __enter__ (line 61) | def __enter__(self): method __exit__ (line 64) | def __exit__(self, exc_type, exc_val, exc_tb): function init_process_group (line 69) | def init_process_group(init_method: str, rank: int, world_size: int): function _launch_subprocesses_fn_wrapper (line 87) | def _launch_subprocesses_fn_wrapper( function get_global_pool_allocator (line 121) | def get_global_pool_allocator( class ProcessPoolExecutorManager (line 142) | class ProcessPoolExecutorManager: method __init__ (line 143) | def __init__(self, world_size: int): method __enter__ (line 146) | def __enter__(self): method submit (line 155) | def submit(self, fn, *args, **kwargs): method __exit__ (line 158) | def __exit__(self, exc_type, exc_val, exc_tb): function launch_subprocesses (line 189) | def launch_subprocesses(world_size: int, fn, *args, **kwargs): FILE: tests/test_attention_patterns.py function _local_1d_pattern (line 15) | def _local_1d_pattern(attn_size: int, window_size: int) -> torch.Tensor: function _generate_2d_grid (line 29) | def _generate_2d_grid(H, W): function _horizontal_axial_2d_distance (line 36) | def _horizontal_axial_2d_distance(H, W, p=2.0): function _vertical_axial_2d_distance (line 43) | def _vertical_axial_2d_distance(H, W, p=2.0): function _local_2d_distance (line 50) | def _local_2d_distance(H, W, p=2.0): function _local_2d_gaussian_distribution (line 58) | def _local_2d_gaussian_distribution(H, W, sigma=1.0): function test_local_1d_pattern (line 66) | def test_local_1d_pattern(attn_size, window_size): function test_horizontal_axial_2d_distance (line 75) | def test_horizontal_axial_2d_distance(H, W, p): function test_vertical_axial_2d_distance (line 84) | def test_vertical_axial_2d_distance(H, W, p): function test_local_2d_distance (line 93) | def test_local_2d_distance(H, W, p): function test_local_2d_gaussian_distribution (line 102) | def test_local_2d_gaussian_distribution(H, W, sigma): function test_swin_attention_pattern (line 111) | def test_swin_attention_pattern(H, W, window_size): function test_dilated_2d_pattern (line 156) | def test_dilated_2d_pattern(H, W, k): function test_pattern_to_layout (line 174) | def test_pattern_to_layout(): function test_alibi_pattern (line 209) | def test_alibi_pattern(): function test_layout_to_pattern (line 215) | def test_layout_to_pattern(): FILE: tests/test_checkpoint.py function _relu_policy (line 32) | def _relu_policy(ctx, func, *args, **kwargs): function _all_policy (line 36) | def _all_policy(ctx, func, *args, **kwargs): function test_checkpoint (line 44) | def test_checkpoint(policy_fn, input_requires_grad, device, autocast): function test_checkpoint_with_grad (line 80) | def test_checkpoint_with_grad(policy_fn, input_requires_grad, grad_mode): function test_checkpoint_attention (line 122) | def test_checkpoint_attention(policy_fn, input_requires_grad, device, au... function test_list_operators (line 187) | def test_list_operators(): function test_optimize_runtime_with_given_memory (line 222) | def test_optimize_runtime_with_given_memory(max_memory, optimal_soln): function _get_model_blocks (line 253) | def _get_model_blocks(num_layers, dtype, device, inplace, random, first_... class _Model (line 278) | class _Model(torch.nn.Module): method __init__ (line 279) | def __init__(self, blocks, policy_fn): method forward (line 284) | def forward(self, x): function test_optimal_checkpoint_policy (line 296) | def test_optimal_checkpoint_policy( function test_selective_checkpoint_wrapper_compile (line 340) | def test_selective_checkpoint_wrapper_compile( FILE: tests/test_fmha_flop_formula.py function test_flop_formula (line 44) | def test_flop_formula( function test_mask_nonzeros (line 148) | def test_mask_nonzeros() -> None: FILE: tests/test_fmha_merge_attentions.py function get_supported_attn_bias_types (line 38) | def get_supported_attn_bias_types(op): function test_merge_attentions_nobias (line 80) | def test_merge_attentions_nobias( function test_partial_paged (line 155) | def test_partial_paged( function test_merge_attentions_decoding (line 234) | def test_merge_attentions_decoding( function test_merge_attentions_sharedinput (line 377) | def test_merge_attentions_sharedinput( function test_merge_attentions_against_ref (line 490) | def test_merge_attentions_against_ref(bmghk: bool): function _merge_attentions_ref (line 513) | def _merge_attentions_ref(attn_split, lse_split): function test_merge_attention_with_compile (line 542) | def test_merge_attention_with_compile() -> None: function test_merge_training (line 574) | def test_merge_training(): function _pad_seqdim (line 628) | def _pad_seqdim(partial: Partial, left: int, right: int) -> Partial: function _slice (line 633) | def _slice(partial: Partial, a: int, b: int) -> Partial: function test_merge_training_compile (line 638) | def test_merge_training_compile(): function test_merge_training_zilch (line 684) | def test_merge_training_zilch(): function test_merge_training_undilate (line 690) | def test_merge_training_undilate(): FILE: tests/test_fwbw_overlap.py function test_fwbw_overlap (line 19) | def test_fwbw_overlap() -> None: function test_fwbw_nothing_to_overlap (line 116) | def test_fwbw_nothing_to_overlap() -> None: class ExceptionInBW (line 130) | class ExceptionInBW(Exception): class ExceptionInBWOp (line 134) | class ExceptionInBWOp(torch.autograd.Function): method forward (line 136) | def forward(ctx: Any, x: torch.Tensor) -> torch.Tensor: method backward (line 140) | def backward(ctx: Any, gx: torch.Tensor) -> torch.Tensor: # type: ignore function test_exception_in_bw_pass (line 144) | def test_exception_in_bw_pass() -> None: function test_exception_in_first_bw_pass (line 164) | def test_exception_in_first_bw_pass() -> None: FILE: tests/test_indexing.py function test_scaled_index_add (line 25) | def test_scaled_index_add(out_shape, with_scaling: bool) -> None: function test_index_select_cat (line 80) | def test_index_select_cat(D, batches) -> None: FILE: tests/test_mem_eff_attention.py function _filter_unsupported_ops (line 99) | def _filter_unsupported_ops(ops: Sequence[T]) -> List[T]: function sample_random_supported_fw (line 122) | def sample_random_supported_fw( function generate_test_shapes_B_Mq_Mkv_H_K_Kv (line 143) | def generate_test_shapes_B_Mq_Mkv_H_K_Kv(op): function make_id (line 214) | def make_id(op, device, dtype, bias_type, *shape): function get_supported_attn_bias_types (line 226) | def get_supported_attn_bias_types(op): function _generate_op_device_dtype_biasT_B_Mq_Mkv_H_K_Kv (line 249) | def _generate_op_device_dtype_biasT_B_Mq_Mkv_H_K_Kv( function _rand_partition (line 346) | def _rand_partition(r: random.Random, total: int, n: int) -> List[int]: function get_bias_grad (line 355) | def get_bias_grad(attn_bias, clear: bool = False) -> Optional[torch.Tens... function create_tensors (line 367) | def create_tensors( function bmhk2bmk (line 483) | def bmhk2bmk(tensor) -> torch.Tensor: function bmk2bmhk (line 491) | def bmk2bmhk(tensor, num_heads: int) -> torch.Tensor: function nanify_oob_seqlen (line 497) | def nanify_oob_seqlen(x: torch.Tensor) -> torch.Tensor: function test_forward (line 510) | def test_forward(opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, packed, fmt, *... function _block_diag_reshape_lse (line 612) | def _block_diag_reshape_lse( function test_logsumexp (line 624) | def test_logsumexp(opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv): function test_logsumexp_mqa (line 706) | def test_logsumexp_mqa(op): function test_backward (line 749) | def test_backward( function _vec_binom_test (line 948) | def _vec_binom_test(x, n, p): function _get_drop_mask (line 975) | def _get_drop_mask(op, batch_size, q_len, kv_len, p, device): function test_dropout_ck (line 994) | def test_dropout_ck(q_len, kv_len, batch_size, k_len, p, seed, attn_bias): function test_dropout_backward_ck (line 1052) | def test_dropout_backward_ck(q_len, kv_len, batch_size, k, p): function test_lowlevel_api_shapes (line 1125) | def test_lowlevel_api_shapes(opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, fmt): function test_cuda_streams (line 1152) | def test_cuda_streams( function test_custom_scale (line 1219) | def test_custom_scale(opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv): function apply_attention (line 1283) | def apply_attention(query, key, value, attn_bias, op_fw, proj): function test_grad_checkpointing (line 1296) | def test_grad_checkpointing( function test_unsupported_cpu (line 1370) | def test_unsupported_cpu(op: Type[fmha.AttentionFwOpBase]): function test_unsupported_stride_lastdim (line 1378) | def test_unsupported_stride_lastdim(op: Type[fmha.AttentionFwOpBase]): function test_unsupported_stride_alignment (line 1395) | def test_unsupported_stride_alignment(op: Type[fmha.AttentionFwOpBase]): function test_unsupported_dropout_combine_flash_cutlass (line 1409) | def test_unsupported_dropout_combine_flash_cutlass() -> None: function test_attn_bias_causal (line 1425) | def test_attn_bias_causal() -> None: function test_attn_bias_torch_tensor (line 1440) | def test_attn_bias_torch_tensor() -> None: function test_attn_bias_blockdiag (line 1450) | def test_attn_bias_blockdiag() -> None: function test_attn_bias_blockdiag_batched (line 1472) | def test_attn_bias_blockdiag_batched() -> None: function test_attn_bias_blockdiag_crossattn_causal (line 1496) | def test_attn_bias_blockdiag_crossattn_causal() -> None: function test_attn_bias_blockdiag_crossattn_causal_with_prefix_qk_cond (line 1539) | def test_attn_bias_blockdiag_crossattn_causal_with_prefix_qk_cond() -> N... function test_attn_bias_blockdiag_crossattn_causal_with_prefix (line 1553) | def test_attn_bias_blockdiag_crossattn_causal_with_prefix() -> None: function test_attn_bias_padded (line 1589) | def test_attn_bias_padded() -> None: function _kv_heads_label (line 1647) | def _kv_heads_label(kv_heads: Optional[int]) -> str: function _test_decoder (line 1655) | def _test_decoder( function test_triton_splitk_decoder (line 1767) | def test_triton_splitk_decoder( function test_ck_splitk_decoder (line 1797) | def test_ck_splitk_decoder( function test_triton_splitk_decoder_manyqueries (line 1832) | def test_triton_splitk_decoder_manyqueries( function test_attn_bias_from_seqlens (line 1854) | def test_attn_bias_from_seqlens() -> None: function test_attn_bias_blockdiag_doc (line 1862) | def test_attn_bias_blockdiag_doc() -> None: class TestAttnBias (line 1893) | class TestAttnBias: method create_tensors (line 1895) | def create_tensors( method pad_bias (line 1912) | def pad_bias(bias: torch.Tensor) -> torch.Tensor: method test_f16_biasf32 (line 1920) | def test_f16_biasf32(self) -> None: method test_f32_biasf16 (line 1929) | def test_f32_biasf16(self) -> None: method test_wrong_alignment (line 1938) | def test_wrong_alignment(self, dtype) -> None: method test_permuted_attn_bias (line 1963) | def test_permuted_attn_bias(self) -> None: function test_window_size_materialize (line 1995) | def test_window_size_materialize() -> None: function test_forward_gqa (line 2041) | def test_forward_gqa(opFW_biasT, Mq: int): function test_backward_gqa (line 2074) | def test_backward_gqa(opBW): function test_forward_gqa_one_group (line 2121) | def test_forward_gqa_one_group(opFW): function test_flash_gqa_wrong_strides (line 2147) | def test_flash_gqa_wrong_strides() -> None: function _dispatches_to_splitK (line 2174) | def _dispatches_to_splitK(q, kv): function _dispatches_to_flash_decoding (line 2181) | def _dispatches_to_flash_decoding(q, kv): function test_dispatch_decoding_bmhk (line 2188) | def test_dispatch_decoding_bmhk() -> None: function test_dispatch_decoding_bmghk (line 2211) | def test_dispatch_decoding_bmghk() -> None: function test_forward_splitk (line 2268) | def test_forward_splitk( function test_mqa_decoding (line 2292) | def test_mqa_decoding(op: Type[fmha.AttentionFwOpBase], dtype, B_Mkv_H_K): function test_empty_tensors_empty_query (line 2315) | def test_empty_tensors_empty_query( function test_empty_tensors_empty_kv (line 2340) | def test_empty_tensors_empty_kv( function test_empty_tensors_empty_b (line 2367) | def test_empty_tensors_empty_b( function test_local_attn_bias (line 2387) | def test_local_attn_bias() -> None: function test_paged_attention (line 2415) | def test_paged_attention( function test_paged_attention_ck (line 2439) | def test_paged_attention_ck(B, MAX_T: int, page_size: int, gappy: bool): function test_paged_attention_flash (line 2458) | def test_paged_attention_flash(B, MAX_T: int, page_size: int): function test_paged_attention_flash3 (line 2479) | def test_paged_attention_flash3( function paged_attention_run_inner (line 2491) | def paged_attention_run_inner( function test_memeff_compile (line 2754) | def test_memeff_compile(bias_t, create_bias_inside_compiled: bool, op) -... function test_triton_splitk_rowwise_fp8 (line 2823) | def test_triton_splitk_rowwise_fp8( function fp8_per_head_quantize (line 2882) | def fp8_per_head_quantize( function test_fp8_attention (line 2908) | def test_fp8_attention(dtype_init, deterministic, causal, B, nheads, seq... function _pack_xformer_input (line 2949) | def _pack_xformer_input( function test_fav3_kvsplit_attn (line 2994) | def test_fav3_kvsplit_attn( function test_nans_in_padding (line 3051) | def test_nans_in_padding(op): FILE: tests/test_multiprocessing_utils.py function inner_test (line 12) | def inner_test(present_parent_keys: List[str] = [], absent_parent_keys: ... function test_env_vars (line 32) | def test_env_vars(): FILE: tests/test_profiler.py function test_profiler_dispatcher_stream_workaround (line 34) | def test_profiler_dispatcher_stream_workaround() -> None: function test_profiler_overhead (line 59) | def test_profiler_overhead(device_bs_mm) -> None: function assert_flops (line 114) | def assert_flops( function test_analyze_prof (line 157) | def test_analyze_prof(dtype) -> None: function test_analyze_prof_sdpa (line 179) | def test_analyze_prof_sdpa(dtype, backend, causal: bool) -> None: function test_analyze_prof_memeff (line 209) | def test_analyze_prof_memeff(op, causal: bool) -> None: FILE: tests/test_rmsnorm.py class RMSNormPytorch (line 26) | class RMSNormPytorch(torch.nn.Module): method __init__ (line 27) | def __init__(self, dim: int, include_weight: bool = True, eps: float =... method _norm (line 35) | def _norm(self, x): method forward (line 38) | def forward(self, x): function test_forward (line 48) | def test_forward(K: int, dtype: str): function test_increment (line 80) | def test_increment(K: int, include_weight: bool, dtype: str): FILE: tests/test_rope_padded.py function apply_scaling (line 26) | def apply_scaling( function _slow_rope (line 53) | def _slow_rope( function _slow_rope2 (line 107) | def _slow_rope2( function test_consistency (line 166) | def test_consistency( function test_rope_prefill (line 269) | def test_rope_prefill(seqlen) -> None: function test_rope_seqpos (line 302) | def test_rope_seqpos() -> None: FILE: tests/test_seqpar.py function reference_leading (line 31) | def reference_leading(input_, w1, w2): function reference_trailing (line 37) | def reference_trailing(hidden, w): function xformers_leading (line 42) | def xformers_leading(input_, w1, w2, *, fuse, group): function xformers_trailing (line 48) | def xformers_trailing(hidden, w, *, fuse, group): function inner_seqpar (line 54) | def inner_seqpar( function test_seqpar (line 266) | def test_seqpar( FILE: tests/test_sequence_parallel_fused_ops.py function compare_fused_and_non_fused_ops (line 29) | def compare_fused_and_non_fused_ops( function inner_sequence_parallel_fused (line 117) | def inner_sequence_parallel_fused( function test_sequence_parallel_fused (line 178) | def test_sequence_parallel_fused( function inner_sequence_parallel_fused_handle_all_dtypes (line 195) | def inner_sequence_parallel_fused_handle_all_dtypes( function test_sequence_parallel_fused_handle_all_dtypes (line 229) | def test_sequence_parallel_fused_handle_all_dtypes( FILE: tests/test_sparse_tensors.py function _create_blocksparse_tensor (line 23) | def _create_blocksparse_tensor( function _create_tensor (line 36) | def _create_tensor(tensor_type, device, dtype, shape, sparsity): function _seed (line 44) | def _seed(): function _get_dtype_atol (line 49) | def _get_dtype_atol(tensor_type, device: str): function test_masked_matmul (line 70) | def test_masked_matmul(tensor_type, device): function test_bmm (line 123) | def test_bmm(tensor_type, device): function test_sparse_softmax (line 173) | def test_sparse_softmax(tensor_type, device): function test_deepcopy (line 218) | def test_deepcopy(tensor_type, device): function test_module_buffer (line 239) | def test_module_buffer(tensor_type, device): FILE: tests/test_sparsity24.py function test_sparse24_largest_mask_2d (line 65) | def test_sparse24_largest_mask_2d() -> None: function test_sparse24_causal1122 (line 82) | def test_sparse24_causal1122(dtype) -> None: function test_sparse24_largest_abs_values_greedy (line 103) | def test_sparse24_largest_abs_values_greedy(dtype, backend) -> None: function test_sparse24_largest_mask_2d_notaligned (line 123) | def test_sparse24_largest_mask_2d_notaligned(dtype) -> None: function test_sparse24_largest_mask_2d_big (line 131) | def test_sparse24_largest_mask_2d_big(dtype) -> None: function create_random_mask (line 136) | def create_random_mask(shape) -> torch.Tensor: function test_detach_requires_grad (line 156) | def test_detach_requires_grad() -> None: function test_detach2 (line 174) | def test_detach2() -> None: function test_meta_pack_and_reorder (line 190) | def test_meta_pack_and_reorder() -> None: function test_pack_tensor_according_to_mask (line 239) | def test_pack_tensor_according_to_mask() -> None: function test_sp24_gemm (line 279) | def test_sp24_gemm(dtype) -> None: function test_pack_meta_shuffle (line 302) | def test_pack_meta_shuffle(transpose: bool) -> None: function test_pack_both_ways_meta_correctness (line 353) | def test_pack_both_ways_meta_correctness(dtype, backend) -> None: function test_pack_both_ways_id (line 382) | def test_pack_both_ways_id(dtype) -> None: function test_pack_both_ways_edge_case1 (line 416) | def test_pack_both_ways_edge_case1(dtype) -> None: function test_sp24_apply (line 444) | def test_sp24_apply(dtype) -> None: function test_sp24_api_different_pattern (line 461) | def test_sp24_api_different_pattern(dtype) -> None: function test_sp24_api_different_pattern_transposed (line 479) | def test_sp24_api_different_pattern_transposed(dtype) -> None: function _gen4x4 (line 496) | def _gen4x4(r: random.Random): function _gen_24_sparsifiable_both_ways (line 512) | def _gen_24_sparsifiable_both_ways( function test_sp24_transpose_invariant (line 532) | def test_sp24_transpose_invariant(dtype, backend) -> None: function test_cusparselt_format (line 558) | def test_cusparselt_format(M: int, N: int) -> None: function test_sp24_matmuls (line 573) | def test_sp24_matmuls(dtype) -> None: function test_sp24_matmuls_mat_vec (line 593) | def test_sp24_matmuls_mat_vec() -> None: function test_sp24_matmuls_bmm (line 604) | def test_sp24_matmuls_bmm() -> None: function sparsify24_dense (line 614) | def sparsify24_dense(tensor: torch.Tensor): function test_sp24_api_mlp_act24_correctness (line 622) | def test_sp24_api_mlp_act24_correctness(dtype, act) -> None: function test_sp24_api_swiglu_correctness (line 670) | def test_sp24_api_swiglu_correctness(dtype) -> None: function test_not_aligned (line 726) | def test_not_aligned(dtype, M): function test_sparsify24_like_dense (line 740) | def test_sparsify24_like_dense(dtype, input_rowmajor, backend): function test_sparsify24_weights (line 756) | def test_sparsify24_weights(dtype, backend): class LinearW24 (line 769) | class LinearW24(torch.nn.Linear): method forward (line 770) | def forward(self, input: torch.Tensor) -> torch.Tensor: function _workaround_cusparselt_internal_error (line 795) | def _workaround_cusparselt_internal_error() -> None: function test_linearw24 (line 808) | def test_linearw24(dtype, bias: bool, aligned: bool, amp: bool) -> None: function test_wrong_alignment_error_message (line 872) | def test_wrong_alignment_error_message() -> None: function test_min_alignment (line 882) | def test_min_alignment() -> None: function test_wrong_dtype_error_message (line 891) | def test_wrong_dtype_error_message() -> None: function test_linear_dispatch_inference_mode (line 902) | def test_linear_dispatch_inference_mode(backend: str, with_bias: bool) -... function test_sp24_meta (line 929) | def test_sp24_meta() -> None: function test_sp24_compile (line 939) | def test_sp24_compile(backend) -> None: class _TransformerFFN (line 956) | class _TransformerFFN(nn.Module): method __init__ (line 957) | def __init__( method forward (line 973) | def forward(self, x: torch.Tensor) -> torch.Tensor: function test_linearw24_block_compile (line 982) | def test_linearw24_block_compile() -> None: function test_sp24_ste (line 1018) | def test_sp24_ste(): function test_sparsify24_ste (line 1028) | def test_sparsify24_ste(dtype): class _Sp24X (line 1040) | class _Sp24X(torch.autograd.Function): method forward (line 1042) | def forward(ctx, x): method backward (line 1050) | def backward(ctx, x): function test_compile_unflatten (line 1069) | def test_compile_unflatten(): function _to_fp8_rowwise (line 1077) | def _to_fp8_rowwise(x: torch.Tensor, dtype) -> Tuple[torch.Tensor, torch... function test_sparseNM_dense (line 1087) | def test_sparseNM_dense(M: int, sort_preproc: str) -> None: function test_sparse24_fp8_sm90_cutlass_gemm_eye (line 1116) | def test_sparse24_fp8_sm90_cutlass_gemm_eye( function test_sparse24_fp8_sm90_cutlass_gemm_random_tensor (line 1151) | def test_sparse24_fp8_sm90_cutlass_gemm_random_tensor( FILE: tests/test_splitk_reference.py function ref_attention_splitk_bmhk (line 18) | def ref_attention_splitk_bmhk( function ref_attention_splitk (line 42) | def ref_attention_splitk( function _kv_heads_label (line 157) | def _kv_heads_label(kv_heads: Optional[int]) -> str: function test_splitk_reference (line 171) | def test_splitk_reference( FILE: tests/test_tiled_matmul.py function generate_test_shapes (line 31) | def generate_test_shapes(*repeats, num_shapes=5): function ceil_of_ratio (line 51) | def ceil_of_ratio(n, k): function make_operands (line 55) | def make_operands(m, n, k, *, dtype): function test_forward_backward (line 112) | def test_forward_backward( FILE: tests/test_tree_attention.py function test_tree_attention (line 92) | def test_tree_attention( class SplitKAutotune (line 105) | class SplitKAutotune(fmha.triton_splitk.FwOp): function run_tree_attention_inner (line 109) | def run_tree_attention_inner( function ref_tree_attention (line 293) | def ref_tree_attention( function tree_attention_with_sync (line 347) | def tree_attention_with_sync( function test_tree_attention_metadata_full_tree (line 427) | def test_tree_attention_metadata_full_tree(depth: int, branching: int) -... function test_tree_attention_metadata_arbitrary_tree (line 498) | def test_tree_attention_metadata_arbitrary_tree(branching: List[int]) ->... FILE: tests/test_triton_varargs.py function test_triton_varargs_kernel (line 41) | def test_triton_varargs_kernel(): function test_triton_multiple_varargs_kernel (line 71) | def test_triton_multiple_varargs_kernel(conditional: bool): function test_triton_varargs_conditional (line 109) | def test_triton_varargs_conditional(): function test_subscripting_call (line 146) | def test_subscripting_call(): FILE: tests/test_unbind.py function test_unbind (line 17) | def test_unbind(dim: int, contiguous: bool): function test_unbind_get_stack_strides (line 54) | def test_unbind_get_stack_strides(dim: int, contiguous: bool): FILE: tests/utils.py function use_cpu_ref (line 44) | def use_cpu_ref(device: str): function maybe_use_cpu_ref (line 48) | def maybe_use_cpu_ref(fn): function disable_tf32 (line 79) | def disable_tf32(fn): function assert_allclose (line 105) | def assert_allclose( function construct_fp8_attention_inputs (line 134) | def construct_fp8_attention_inputs( function _combine_scale_shift (line 311) | def _combine_scale_shift(scale: torch.Tensor, shift: torch.Tensor) -> to... function quantize_fp8_asymmetric (line 320) | def quantize_fp8_asymmetric( function dequantize_fp8_asymmetric (line 337) | def dequantize_fp8_asymmetric( FILE: xformers/__init__.py function compute_once (line 32) | def compute_once(func): function _is_triton_available (line 45) | def _is_triton_available(): function get_python_lib (line 69) | def get_python_lib(): FILE: xformers/_cpp_lib.py class _BuildInfo (line 23) | class _BuildInfo: method cuda_version (line 27) | def cuda_version(self) -> Optional[int]: method hip_version (line 31) | def hip_version(self) -> Optional[int]: method torch_version (line 35) | def torch_version(self) -> str: method python_version (line 39) | def python_version(self) -> str: method flash_version (line 43) | def flash_version(self) -> str: method use_torch_flash (line 47) | def use_torch_flash(self) -> bool: method build_env (line 51) | def build_env(self) -> Dict[str, Any]: class xFormersWasNotBuiltException (line 55) | class xFormersWasNotBuiltException(Exception): method __str__ (line 56) | def __str__(self) -> str: class xFormersInvalidLibException (line 65) | class xFormersInvalidLibException(Exception): method __init__ (line 66) | def __init__(self, build_info: Optional[_BuildInfo]) -> None: method __str__ (line 69) | def __str__(self) -> str: function _register_extensions (line 85) | def _register_extensions(): FILE: xformers/_deprecation_warning.py function deprecated_function (line 9) | def deprecated_function(self): FILE: xformers/attn_bias_utils.py function _create_aligned_bias (line 16) | def _create_aligned_bias(*shape: int, **kwargs) -> torch.Tensor: function create_attn_bias (line 30) | def create_attn_bias( function _rand_seqlens (line 263) | def _rand_seqlens( function _rand_maxed_partition (line 336) | def _rand_maxed_partition( function _rand_seqlens_padded_k (line 354) | def _rand_seqlens_padded_k( function ref_attention (line 374) | def ref_attention(q, k, v, attn_bias=None, drop_mask=None, p=0.0, scale=... function ref_attention_bmhk (line 432) | def ref_attention_bmhk(q, k, v, attn_bias, scale=None) -> torch.Tensor: function pack_kv_cache (line 451) | def pack_kv_cache( FILE: xformers/benchmarks/benchmark_attn_decoding.py function quantize_kv_int4 (line 38) | def quantize_kv_int4(k: torch.Tensor, num_groups: int = 1) -> torch.Tensor: class AttentionDecodingBase (line 75) | class AttentionDecodingBase: method __init__ (line 78) | def __init__( method get_inputs (line 149) | def get_inputs(self): method fw (line 155) | def fw(self) -> None: class AttentionDecodingCUTLASS (line 164) | class AttentionDecodingCUTLASS(AttentionDecodingBase): class AttentionDecodingCK (line 168) | class AttentionDecodingCK(AttentionDecodingBase): method __init__ (line 171) | def __init__( class AttentionDecodingSplitKV (line 239) | class AttentionDecodingSplitKV(AttentionDecodingBase): class AttentionDecodingCKSplitKV (line 243) | class AttentionDecodingCKSplitKV(AttentionDecodingBase): class AttentionDecodingSplitInt4KV (line 247) | class AttentionDecodingSplitInt4KV(AttentionDecodingBase): method __init__ (line 250) | def __init__( class AttentionDecodingPyTorchRepeat (line 335) | class AttentionDecodingPyTorchRepeat(AttentionDecodingBase): method fw (line 336) | def fw(self) -> None: class AttentionDecodingFlashAttention (line 369) | class AttentionDecodingFlashAttention(AttentionDecodingBase): method fw (line 370) | def fw(self) -> None: function get_benchmark_names (line 405) | def get_benchmark_names(): function test_flash_attention_decoder (line 416) | def test_flash_attention_decoder(name, case): function main (line 453) | def main() -> None: FILE: xformers/benchmarks/benchmark_indexing.py class ScaledIndexAddBenchmark (line 50) | class ScaledIndexAddBenchmark: method __init__ (line 51) | def __init__(self, dtype, scaling: bool, shape, bw: bool) -> None: method fw (line 76) | def fw(self) -> None: method bw (line 85) | def bw(self): class ScaledIndexAddBenchmarkBaseline (line 93) | class ScaledIndexAddBenchmarkBaseline(ScaledIndexAddBenchmark): method fw (line 94) | def fw(self) -> None: class IndexSelectBenchmark (line 106) | class IndexSelectBenchmark: method __init__ (line 107) | def __init__(self, dtype, batches, D, keep_ratio, bw: bool) -> None: method fw (line 131) | def fw(self) -> None: method bw (line 134) | def bw(self): class IndexSelectBenchmarkBaseline (line 140) | class IndexSelectBenchmarkBaseline(IndexSelectBenchmark): method fw (line 141) | def fw(self) -> None: FILE: xformers/benchmarks/benchmark_mem_eff_attention.py function product_dict (line 80) | def product_dict(**kwargs): function create_tensors (line 142) | def create_tensors(shape_q, Hkv, dtype, requires_grad=False, packed=True): function mem_eff_attention_fw (line 168) | def mem_eff_attention_fw( function mem_eff_attention_bw (line 266) | def mem_eff_attention_bw( function main (line 343) | def main(): FILE: xformers/benchmarks/benchmark_merge_attentions.py function _merge_attentions_varargs_ref (line 12) | def _merge_attentions_varargs_ref(attn_split, lse_split): function benchmark_merge_attentions_backward (line 34) | def benchmark_merge_attentions_backward(split_k, B, M, G, N_H_L, D_H, dt... function main (line 80) | def main(): FILE: xformers/benchmarks/benchmark_sequence_parallel_fused.py class Scenario (line 23) | class Scenario: class Step (line 33) | class Step(enum.Enum): method __str__ (line 37) | def __str__(self): class Bench (line 42) | class Bench: method __getitem__ (line 46) | def __getitem__(self, step: Step): function round_up_to_nearest_multiple (line 62) | def round_up_to_nearest_multiple(n: int, m: int) -> int: function llama_07B_MHA (line 66) | def llama_07B_MHA(world_size: int) -> Scenario: function llama_07B_FFN (line 76) | def llama_07B_FFN(world_size: int) -> Scenario: function llama_70B_MHA (line 87) | def llama_70B_MHA(world_size: int) -> Scenario: function llama_70B_FFN (line 97) | def llama_70B_FFN(world_size: int) -> Scenario: function run_one_rank (line 120) | def run_one_rank( function main (line 422) | def main(): FILE: xformers/benchmarks/benchmark_sp24.py class Mlp (line 43) | class Mlp(nn.Module): method __init__ (line 46) | def __init__( method fw (line 65) | def fw(self): method bw (line 72) | def bw(self): class MlpDenseMask (line 76) | class MlpDenseMask(Mlp): method fw (line 77) | def fw(self): class MlpAct24 (line 89) | class MlpAct24(Mlp): method fw (line 90) | def fw(self): class LinearW24 (line 101) | class LinearW24(torch.nn.Linear): method forward (line 102) | def forward(self, input: torch.Tensor) -> torch.Tensor: class MlpW24 (line 111) | class MlpW24(Mlp): class MicrobenchmarkBase (line 115) | class MicrobenchmarkBase: method __init__ (line 116) | def __init__( method bw (line 131) | def bw(self) -> None: class MicrobenchmarkSparsify24 (line 135) | class MicrobenchmarkSparsify24(MicrobenchmarkBase): method fw (line 136) | def fw(self) -> torch.Tensor: class MicrobenchmarkSp24ApplyDense (line 141) | class MicrobenchmarkSp24ApplyDense(MicrobenchmarkBase): method fw (line 142) | def fw(self) -> torch.Tensor: class MicrobenchmarkSp24ApplyDenseT (line 147) | class MicrobenchmarkSp24ApplyDenseT(MicrobenchmarkBase): method fw (line 148) | def fw(self) -> torch.Tensor: class MicrobenchmarkInputClone (line 153) | class MicrobenchmarkInputClone(MicrobenchmarkBase): method fw (line 154) | def fw(self) -> torch.Tensor: FILE: xformers/benchmarks/benchmark_tiled_matmul.py function product_dict (line 34) | def product_dict(**kwargs): function matmul_per_tile (line 53) | def matmul_per_tile(a, b): function benchmark_tiled_matmul (line 64) | def benchmark_tiled_matmul(shape_name, dtype): FILE: xformers/benchmarks/utils.py class NotSupportedInputError (line 34) | class NotSupportedInputError(Exception): function get_func_name (line 47) | def get_func_name(fn): function pretty_print (line 53) | def pretty_print(results, title, units) -> None: function pretty_plot (line 79) | def pretty_plot( function bench_functions (line 123) | def bench_functions( function pretty_barplot (line 153) | def pretty_barplot(results, title, units: str, filename=None, dash_key=""): function rmf (line 211) | def rmf(filename: str) -> None: function temp_files_ctx (line 220) | def temp_files_ctx(num: int) -> Generator: function _benchmark_results_from_csv (line 237) | def _benchmark_results_from_csv(filename: str) -> List[Tuple[Dict[str, A... function _benchmark_results_to_csv (line 280) | def _benchmark_results_to_csv( function _finalize_results (line 306) | def _finalize_results(results: List[Tuple[Dict[str, Any], Any]]) -> List... function _render_bar_plot (line 343) | def _render_bar_plot(results: List[Any], store_results_folder: str) -> N... function create_argparser (line 403) | def create_argparser() -> argparse.ArgumentParser: function benchmark_main_helper (line 438) | def benchmark_main_helper( function benchmark_run_and_compare (line 463) | def benchmark_run_and_compare( function _is_oom_error (line 643) | def _is_oom_error(e): function _fail_if_regressions (line 649) | def _fail_if_regressions( function benchmark_main_helper2 (line 707) | def benchmark_main_helper2( function product_dict (line 756) | def product_dict(**kwargs): FILE: xformers/checkpoint.py class _NotAvailable (line 44) | class _NotAvailable: method __init__ (line 45) | def __init__(self, *args, **kwargs): class ProfileMetadata (line 71) | class ProfileMetadata: function _get_default_policy (line 82) | def _get_default_policy(allow_list=None): class VerboseTorchDispatchMode (line 98) | class VerboseTorchDispatchMode(TorchDispatchMode): method __init__ (line 99) | def __init__(self): method __torch_dispatch__ (line 102) | def __torch_dispatch__(self, func, types, args=(), kwargs=None): function list_operators (line 109) | def list_operators(function, *args, **kwargs): class CachedTorchDispatchMode (line 120) | class CachedTorchDispatchMode(_CachedTorchDispatchMode): method __init__ (line 121) | def __init__(self, policy_fn, storage, allow_cache_entry_mutation): method pop_from_storage (line 129) | def pop_from_storage(self, func, args, kwargs): class NullTorchDispatchMode (line 137) | class NullTorchDispatchMode(TorchDispatchMode): method __torch_dispatch__ (line 138) | def __torch_dispatch__(self, func, types, args=(), kwargs=None): function selective_checkpoint_context_fn (line 144) | def selective_checkpoint_context_fn(policy_fn=None): function checkpoint (line 175) | def checkpoint( class ProfileOperatorsTorchDispatchMode (line 209) | class ProfileOperatorsTorchDispatchMode(TorchDispatchMode): method __init__ (line 210) | def __init__(self, num_runs: int = 10) -> None: method _get_inplace_metadata (line 214) | def _get_inplace_metadata(self, func, out) -> Tuple[int, int, Tuple[in... method __torch_dispatch__ (line 245) | def __torch_dispatch__(self, func, types, args=(), kwargs=None): function _analyze_operators (line 287) | def _analyze_operators(function, *args) -> List[ProfileMetadata]: function get_optimal_checkpoint_policy (line 309) | def get_optimal_checkpoint_policy(function, *args, memory_budget: float)... function _optimize_runtime_with_given_memory (line 386) | def _optimize_runtime_with_given_memory( class _OptimalPolicy (line 461) | class _OptimalPolicy: method __init__ (line 462) | def __init__(self, optim_output: torch.Tensor): method __call__ (line 466) | def __call__(self, ctx, func, *args, **kwargs) -> bool: class SelectiveCheckpointWrapper (line 475) | class SelectiveCheckpointWrapper(ActivationWrapper): method __init__ (line 476) | def __init__(self, mod, memory_budget=None, policy_fn=None): method _get_policy_fn (line 492) | def _get_policy_fn(self, *args, **kwargs): method get_policy_fn (line 515) | def get_policy_fn(self, *args, **kwargs): method forward (line 520) | def forward(self, *args, **kwargs): function selective_checkpoint_wrapper (line 527) | def selective_checkpoint_wrapper( FILE: xformers/components/attention/attention_patterns.py function _generate_nd_grid (line 15) | def _generate_nd_grid(*sizes): function local_nd_distance (line 20) | def local_nd_distance(*sizes, p=2.0, weights=None): function local_nd_gaussian_distribution (line 31) | def local_nd_gaussian_distribution(*sizes, sigma=1): function local_nd_pattern (line 37) | def local_nd_pattern(*sizes, distance, p=2.0): function axial_nd_pattern (line 42) | def axial_nd_pattern(*sizes): function random_pattern_from_probability_matrix (line 48) | def random_pattern_from_probability_matrix(dist_matrix, nnz): function global_token_pattern (line 69) | def global_token_pattern(attention_query_mask: torch.Tensor) -> torch.Te... function random_pattern (line 77) | def random_pattern(attn_size: int, sparsity: float) -> torch.Tensor: function local_1d_pattern (line 84) | def local_1d_pattern(attn_size: int, window_size: int) -> torch.Tensor: function causal_1d_pattern (line 92) | def causal_1d_pattern(attn_size: int) -> torch.Tensor: function horizontal_axial_2d_distance (line 98) | def horizontal_axial_2d_distance(H, W, p=2.0): function vertical_axial_2d_distance (line 103) | def vertical_axial_2d_distance(H, W, p=2.0): function local_2d_distance (line 108) | def local_2d_distance(H, W, p=2.0): function local_2d_gausian_distribution (line 112) | def local_2d_gausian_distribution(H, W, sigma=1): function local_2d_pattern (line 116) | def local_2d_pattern(H, W, distance, p=2.0): function axial_2d_pattern (line 120) | def axial_2d_pattern(H, W): function swin_attention_pattern (line 124) | def swin_attention_pattern(H, W, window_size, shift_size=0): function dilated_2d_pattern (line 155) | def dilated_2d_pattern(H, W, k=2): function block_sparsify_tensor (line 168) | def block_sparsify_tensor(x, mask, block_size): function pattern_to_layout (line 186) | def pattern_to_layout(mask: torch.Tensor, block_size: int) -> torch.Tensor: function alibi_pattern (line 214) | def alibi_pattern(threshold: float, mask_shape: torch.Size) -> torch.Ten... function layout_to_pattern (line 263) | def layout_to_pattern(layout: torch.Tensor, block_size: int): FILE: xformers/csrc/attention/attention.cpp function STABLE_TORCH_LIBRARY_FRAGMENT (line 10) | STABLE_TORCH_LIBRARY_FRAGMENT(xformers, m) { FILE: xformers/csrc/attention/hip_decoder/attention_forward_splitk.cpp type c10_to_data_t (line 26) | struct c10_to_data_t type c10_to_data_t (line 28) | struct c10_to_data_t { type c10_to_data_t (line 33) | struct c10_to_data_t { type c10_to_data_t (line 38) | struct c10_to_data_t { function instantiate_and_launch_kernels (line 58) | void instantiate_and_launch_kernels( function efficient_attention_forward_decoder_splitk_ck_impl (line 251) | at::Tensor efficient_attention_forward_decoder_splitk_ck_impl( function efficient_attention_forward_decoder_splitk_ck (line 293) | at::Tensor efficient_attention_forward_decoder_splitk_ck( function TORCH_LIBRARY_IMPL (line 312) | TORCH_LIBRARY_IMPL(xformers, CUDA, m) { FILE: xformers/csrc/attention/hip_decoder/ck_tile_attention_forward_decoder_splitk.h function a_u (line 18) | union { function __device__ (line 32) | __device__ __forceinline__ wavefrontReduce(float val, F f) { function load_v (line 41) | void load_v( function store_v (line 49) | void store_v( function namespace (line 58) | namespace ck_tile { FILE: xformers/csrc/attention/hip_decoder/ck_tile_attention_inner_product.h function namespace (line 11) | namespace ck_tile { FILE: xformers/csrc/attention/hip_fmha/attention_backward_generic_ck_tiled.cpp function efficient_attention_backward_ck (line 36) | std::tuple function efficient_attention_backward_ck_meta (line 548) | std::tuple function TORCH_LIBRARY_IMPL (line 630) | TORCH_LIBRARY_IMPL(xformers, CUDA, m) { function TORCH_LIBRARY_IMPL (line 636) | TORCH_LIBRARY_IMPL(xformers, Meta, m) { FILE: xformers/csrc/attention/hip_fmha/attention_ck_rand_uniform.cpp function rand_uniform_int (line 27) | at::Tensor rand_uniform_int( function TORCH_LIBRARY_IMPL (line 93) | TORCH_LIBRARY_IMPL(xformers, CUDA, m) { FILE: xformers/csrc/attention/hip_fmha/attention_forward_generic_ck_tiled.cpp function efficient_attention_forward_ck (line 50) | std::tuple, int64_t, int64_t> function efficient_attention_forward_ck_meta (line 475) | std::tuple, int64_t, int64_t> function TORCH_LIBRARY_IMPL (line 523) | TORCH_LIBRARY_IMPL(xformers, CUDA, m) { function TORCH_LIBRARY_IMPL (line 529) | TORCH_LIBRARY_IMPL(xformers, Meta, m) { FILE: xformers/csrc/attention/hip_fmha/ck_fmha_test.cpp function is_ck_fmha_available (line 14) | bool is_ck_fmha_available(double val) { function TORCH_LIBRARY_FRAGMENT (line 21) | TORCH_LIBRARY_FRAGMENT(xformers, m) { FILE: xformers/csrc/attention/hip_fmha/ck_fmha_util.h function at (line 61) | static inline at::Tensor get_bias_4d_view( function get_number_of_cu (line 95) | static inline int get_number_of_cu() { FILE: xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_backward.h function Run (line 58) | static void Run(BatchedBackwardParams& param, hipStream_t stream) { FILE: xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_backward_bf16.cpp function batched_backward_bf16 (line 16) | void batched_backward_bf16(BatchedBackwardParams& param, hipStream_t str... FILE: xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_backward_fp16.cpp function batched_backward_fp16 (line 16) | void batched_backward_fp16(BatchedBackwardParams& param, hipStream_t str... FILE: xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward_bf16.cpp function batched_forward_bf16 (line 16) | void batched_forward_bf16(BatchedForwardParams& param, hipStream_t strea... FILE: xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward_fp16.cpp function batched_forward_fp16 (line 16) | void batched_forward_fp16(BatchedForwardParams& param, hipStream_t strea... FILE: xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward_splitkv_smallq_dispatch.h function else (line 25) | struct batched_forward_splitkv_smallq_mask_bias_dropout_dispatch { FILE: xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_bf16.cpp function batched_infer_bf16 (line 15) | void batched_infer_bf16(BatchedForwardParams& param, hipStream_t stream) { FILE: xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_fp16.cpp function batched_infer_fp16 (line 15) | void batched_infer_fp16(BatchedForwardParams& param, hipStream_t stream) { FILE: xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_splitkv_smallq_dispatch.h function else (line 25) | struct batched_infer_splitkv_smallq_mask_bias_dropout_dispatch { FILE: xformers/csrc/attention/hip_fmha/ck_tiled_fmha_bwd_setting.h function fp16_t (line 17) | struct FmhaBwdTypeConfig { function bf16_t (line 36) | struct FmhaBwdTypeConfig { type FmhaBwdBlockTile (line 58) | struct FmhaBwdBlockTile type FmhaBwdBlockTile (line 66) | struct FmhaBwdBlockTile type FmhaBwdBlockTile (line 74) | struct FmhaBwdBlockTile type FmhaBwdBlockTile (line 82) | struct FmhaBwdBlockTile type FmhaBwdBlockTile (line 91) | struct FmhaBwdBlockTile type FmhaBwdShape (line 107) | struct FmhaBwdShape type FmhaBwdShape (line 121) | struct FmhaBwdShape type FmhaBwdShape (line 135) | struct FmhaBwdShape type FmhaBwdShape (line 149) | struct FmhaBwdShape type FmhaBwdShape (line 163) | struct FmhaBwdShape type FmhaBwdPipelineMaker (line 183) | struct FmhaBwdPipelineMaker type FmhaBwdBlockDropoutMaker (line 203) | struct FmhaBwdBlockDropoutMaker FILE: xformers/csrc/attention/hip_fmha/ck_tiled_fmha_fwd_setting.h type FmhaFwdBlockTile (line 47) | struct FmhaFwdBlockTile type FmhaFwdBlockTile (line 54) | struct FmhaFwdBlockTile type FmhaFwdShape (line 83) | struct FmhaFwdShape type FmhaFwdShape (line 128) | struct FmhaFwdShape type FmhaFwdShape (line 139) | struct FmhaFwdShape function get_fmha_fwd_mtile (line 177) | static int get_fmha_fwd_mtile( function get_fmha_fwd_least_mtile (line 195) | static int get_fmha_fwd_least_mtile() { FILE: xformers/csrc/attention/hip_fmha/ck_tiled_fmha_fwd_splitkv_selector.h function generate_splits_list (line 19) | static int generate_splits_list(int i) { FILE: xformers/csrc/attention/hip_fmha/ck_tiled_fmha_fwd_splitkv_setting.h type FmhaFwdSplitKVBlockTile (line 47) | struct FmhaFwdSplitKVBlockTile type FmhaFwdSplitKVBlockTile (line 54) | struct FmhaFwdSplitKVBlockTile type FmhaFwdSplitKVShape (line 72) | struct FmhaFwdSplitKVShape type FmhaFwdSplitKVShape (line 117) | struct FmhaFwdSplitKVShape type FmhaFwdSplitKVShape (line 128) | struct FmhaFwdSplitKVShape function fwd_splitkv_get_mtile_size (line 153) | int fwd_splitkv_get_mtile_size() { function get_mtile_size_for_splitkv (line 159) | static int get_mtile_size_for_splitkv(int max_seqlen_q, int max_headdim) { FILE: xformers/csrc/attention/hip_fmha/ck_tiled_fmha_fwd_splitkv_smallq_selector.h function use_splitkv_smallq (line 13) | static bool use_splitkv_smallq(int max_seqlen_q, int max_headdim) { FILE: xformers/csrc/attention/hip_fmha/ck_tiled_fmha_fwd_splitkv_smallq_setting.h type FmhaFwdSplitKVSmallQBlockTile (line 19) | struct FmhaFwdSplitKVSmallQBlockTile type FmhaFwdSplitKVSmallQBlockTile (line 26) | struct FmhaFwdSplitKVSmallQBlockTile type FmhaFwdSplitKVSmallQBlockTile (line 33) | struct FmhaFwdSplitKVSmallQBlockTile type FmhaFwdSplitKVSmallQBlockTile (line 40) | struct FmhaFwdSplitKVSmallQBlockTile type FmhaFwdSplitKVSmallQBlockTile (line 47) | struct FmhaFwdSplitKVSmallQBlockTile type FmhaFwdSplitKVSmallQShape (line 60) | struct FmhaFwdSplitKVSmallQShape type FmhaFwdSplitKVSmallQShape (line 71) | struct FmhaFwdSplitKVSmallQShape type FmhaFwdSplitKVSmallQShape (line 82) | struct FmhaFwdSplitKVSmallQShape type FmhaFwdSplitKVSmallQShape (line 93) | struct FmhaFwdSplitKVSmallQShape type FmhaFwdSplitKVSmallQShape (line 104) | struct FmhaFwdSplitKVSmallQShape function get_mtile_size_for_splitkv_smallq (line 121) | static int get_mtile_size_for_splitkv_smallq(int max_headdim) { FILE: xformers/csrc/attention/hip_fmha/ck_tiled_fmha_fwd_type_config.h function fp16_t (line 15) | struct FmhaFwdTypeConfig { function bf16_t (line 31) | struct FmhaFwdTypeConfig { FILE: xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_backward.h function Run (line 58) | static void Run(GroupedBackwardParams& param, hipStream_t stream) { FILE: xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_backward_bf16.cpp function grouped_backward_bf16 (line 16) | void grouped_backward_bf16(GroupedBackwardParams& param, hipStream_t str... FILE: xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_backward_fp16.cpp function grouped_backward_fp16 (line 16) | void grouped_backward_fp16(GroupedBackwardParams& param, hipStream_t str... FILE: xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward_bf16.cpp function grouped_forward_bf16 (line 16) | void grouped_forward_bf16(GroupedForwardParams& param, hipStream_t strea... FILE: xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward_fp16.cpp function grouped_forward_fp16 (line 16) | void grouped_forward_fp16(GroupedForwardParams& param, hipStream_t strea... FILE: xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward_splitkv_smallq_dispatch.h function else (line 25) | struct grouped_forward_splitkv_smallq_mask_bias_dropout_dispatch { FILE: xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_bf16.cpp function grouped_infer_bf16 (line 15) | void grouped_infer_bf16(GroupedForwardParams& param, hipStream_t stream) { FILE: xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_fp16.cpp function grouped_infer_fp16 (line 15) | void grouped_infer_fp16(GroupedForwardParams& param, hipStream_t stream) { FILE: xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_splitkv_smallq_dispatch.h function else (line 25) | struct grouped_infer_splitkv_smallq_mask_bias_dropout_dispatch { FILE: xformers/csrc/attention/hip_fmha/ck_tiled_fmha_params.h type BatchedInferParams (line 12) | struct BatchedInferParams { function BatchedInferParams (line 42) | struct BatchedForwardParams : public BatchedInferParams { type GroupedInferParams (line 70) | struct GroupedInferParams { function GroupedInferParams (line 114) | struct GroupedForwardParams : public GroupedInferParams { type BatchedBackwardParams (line 144) | struct BatchedBackwardParams { type GroupedBackwardParams (line 202) | struct GroupedBackwardParams { FILE: xformers/csrc/attention/hip_fmha/ck_tiled_rand_uniform_kernel.h function __device__ (line 25) | __device__ static constexpr auto GetBlockGemm() { type FmhaRandUniformCommonKargs (line 54) | struct FmhaRandUniformCommonKargs { function FmhaRandUniformCommonKargs (line 72) | struct FmhaRandUniformBatchModeKargs : FmhaRandUniformCommonKargs { function FmhaRandUniformCommonKargs (line 76) | struct FmhaRandUniformGroupModeKargs : FmhaRandUniformCommonKargs { function Kargs (line 99) | Kargs kargs{ function Kargs (line 129) | Kargs kargs{ FILE: xformers/csrc/attention/hip_fmha/generate_instances.py function create_infer_instances (line 127) | def create_infer_instances(instance_dir: Path, headdims: List) -> None: function create_infer_instances_ref (line 165) | def create_infer_instances_ref(instance_dir: Path, headdims: List) -> None: function create_forward_instances (line 198) | def create_forward_instances(instance_dir: Path, headdims: List) -> None: function create_forward_instances_ref (line 236) | def create_forward_instances_ref(instance_dir: Path, headdims: List) -> ... function create_backward_instances (line 271) | def create_backward_instances(instance_dir: Path, headdims: List) -> None: function create_backward_instances_ref (line 315) | def create_backward_instances_ref(instance_dir: Path, headdims: List) ->... FILE: xformers/csrc/pt_stable_utils.h function T (line 56) | inline T ceil_div(T a, T b) { function xf_get_layout (line 68) | inline int32_t xf_get_layout(const torch::stable::Tensor& self) { function xf_is_sparse (line 74) | inline bool xf_is_sparse(const torch::stable::Tensor& self) { function torch (line 170) | inline torch::stable::Tensor xf_new_full( function torch (line 198) | inline torch::stable::Tensor xf_resize_( FILE: xformers/csrc/sparse24/compute_sparse_tile.h function namespace (line 13) | namespace xformers { FILE: xformers/csrc/sparse24/sparse24.cpp function STABLE_TORCH_LIBRARY_FRAGMENT (line 3) | STABLE_TORCH_LIBRARY_FRAGMENT(xformers, m) { FILE: xformers/csrc/sparse24/sparse24_metadata.h function namespace (line 12) | namespace xformers { function CUTLASS_HOST_DEVICE (line 105) | CUTLASS_HOST_DEVICE function CUTLASS_HOST_DEVICE (line 127) | CUTLASS_HOST_DEVICE function CUTLASS_HOST_DEVICE (line 136) | CUTLASS_HOST_DEVICE function MetadataCutlassSm80 (line 147) | struct MetadataCutlassSm80 { function CUTLASS_HOST_DEVICE (line 198) | CUTLASS_HOST_DEVICE function CUTLASS_HOST_DEVICE (line 216) | CUTLASS_HOST_DEVICE function CUTLASS_HOST_DEVICE (line 226) | CUTLASS_HOST_DEVICE type MetadataCutlass8bitsSm90 (line 242) | struct MetadataCutlass8bitsSm90 { FILE: xformers/csrc/sparse24/sparse24_pack.h function namespace (line 11) | namespace xformers { FILE: xformers/csrc/sparse24/static_sort.h function CUTLASS_HOST_DEVICE (line 24) | CUTLASS_HOST_DEVICE Swap(A& a, const int& i0, const int& i1) { function CUTLASS_HOST_DEVICE (line 31) | CUTLASS_HOST_DEVICE PB(A& a) { type PB (line 52) | struct PB function CUTLASS_HOST_DEVICE (line 53) | CUTLASS_HOST_DEVICE PB(A& a) { type PB (line 60) | struct PB function CUTLASS_HOST_DEVICE (line 61) | CUTLASS_HOST_DEVICE PB(A& a) { type PS (line 78) | struct PS function CUTLASS_HOST_DEVICE (line 79) | CUTLASS_HOST_DEVICE PS(A& a) {} FILE: xformers/csrc/sparse24/warp_tensor.h function namespace (line 9) | namespace xformers { function TileValueOrdered1d (line 313) | struct TileValueOrdered1d { type Identity (line 398) | struct Identity { FILE: xformers/fwbw_overlap.py class EventHandle (line 24) | class EventHandle: # type: ignore[no-redef] method __init__ (line 25) | def __init__(self) -> None: method current_stream_wait (line 28) | def current_stream_wait(self) -> None: class EventOverlap (line 31) | class EventOverlap: # type: ignore[no-redef] method __init__ (line 32) | def __init__(self, event: Union[EventHandle, None] = None) -> None: method current_stream_wait (line 35) | def current_stream_wait(self) -> None: class EventOverlapHolder (line 43) | class EventOverlapHolder(torch.Tensor): method capture (line 55) | def capture( method __new__ (line 68) | def __new__( method __init__ (line 83) | def __init__( method __tensor_flatten__ (line 94) | def __tensor_flatten__(self): method __repr__ (line 97) | def __repr__(self) -> str: # type: ignore method current_stream_wait (line 100) | def current_stream_wait(self) -> None: method __torch_dispatch__ (line 107) | def __torch_dispatch__( class _ExitCompute (line 137) | class _ExitCompute(torch.autograd.Function): method forward (line 147) | def forward(ctx: torch.autograd.function.FunctionCtx, *tensors: torch.... method backward (line 154) | def backward( # type: ignore class _EnterCompute (line 166) | class _EnterCompute(torch.autograd.Function): method forward (line 176) | def forward( method backward (line 189) | def backward(ctx: torch.autograd.function.FunctionCtx, *gtensors: torc... class _FillGradientForOverlapHolder (line 198) | class _FillGradientForOverlapHolder(torch.autograd.Function): method forward (line 206) | def forward( method backward (line 220) | def backward( # type: ignore function enter_comm (line 242) | def enter_comm( function enter_compute (line 253) | def enter_compute( function enter_compute (line 264) | def enter_compute( function enter_compute (line 273) | def enter_compute( # type: ignore class PhaseBoundary (line 286) | class PhaseBoundary: method __post_init__ (line 293) | def __post_init__(self) -> None: method __str__ (line 297) | def __str__(self) -> str: method __call__ (line 303) | def __call__(self) -> None: class InitialBw (line 327) | class InitialBw: method __init__ (line 328) | def __init__(self, trigger_bw: Callable[[], None]) -> None: method __call__ (line 331) | def __call__(self) -> None: class _GlobalAutogradThread (line 347) | class _GlobalAutogradThread: method run (line 355) | def run(cls) -> None: method cleanup_at_exit (line 379) | def cleanup_at_exit(cls) -> None: function async_bw (line 387) | def async_bw(backward_fn: Callable[[], None]) -> threading.Semaphore: class _WaitInBW (line 397) | class _WaitInBW(torch.autograd.Function): method forward (line 399) | def forward( method backward (line 411) | def backward(ctx: torch.autograd.function.FunctionCtx, *gx: torch.Tens... class _CurrentForwardState (line 432) | class _CurrentForwardState: function before_forward (line 445) | def before_forward(record_fw_chunks: bool) -> None: function enter_phase (line 455) | def enter_phase(enter: str, *tensors: torch.Tensor) -> tuple[torch.Tenso... function flush_single_bw_chunk (line 480) | def flush_single_bw_chunk() -> bool: function flush_pending_bw (line 488) | def flush_pending_bw() -> None: function overlap_fw_bw (line 497) | def overlap_fw_bw( function _overlap_fw_bw (line 508) | def _overlap_fw_bw( FILE: xformers/info.py function get_features_status (line 16) | def get_features_status() -> Dict[str, str]: function print_info (line 25) | def print_info(): FILE: xformers/ops/__init__.py function masked_matmul (line 44) | def masked_matmul(a, b, mask=None): FILE: xformers/ops/_triton/k_index_select_cat.py function index_select_cat_fwd_kernel (line 12) | def index_select_cat_fwd_kernel( function index_select_cat_fwd (line 38) | def index_select_cat_fwd( function index_select_cat_bwd_kernel (line 83) | def index_select_cat_bwd_kernel( function index_select_cat_bwd (line 126) | def index_select_cat_bwd( FILE: xformers/ops/_triton/k_scaled_index_add.py function scaled_index_add_fwd_kernel (line 14) | def scaled_index_add_fwd_kernel( function scaled_index_add_fwd (line 77) | def scaled_index_add_fwd( function scaled_index_add_bwd_kernel (line 176) | def scaled_index_add_bwd_kernel( function scaled_index_add_bwd (line 256) | def scaled_index_add_bwd( FILE: xformers/ops/_triton/matmul_perf_model.py function get_clock_rate_in_khz (line 45) | def get_clock_rate_in_khz(): function get_tensorcore_tflops (line 56) | def get_tensorcore_tflops(device, num_ctas, num_warps, dtype): function get_simd_tflops (line 70) | def get_simd_tflops(device, num_ctas, num_warps, dtype): function get_tflops (line 84) | def get_tflops(device, num_ctas, num_warps, dtype): function estimate_matmul_time (line 91) | def estimate_matmul_time( function early_config_prune (line 173) | def early_config_prune(configs, named_args, **kwargs): FILE: xformers/ops/_triton/rmsnorm_kernels.py function _rms_norm_kernel (line 15) | def _rms_norm_kernel( function _rms_norm_add_kernel (line 51) | def _rms_norm_add_kernel( function _rms_norm_forward (line 94) | def _rms_norm_forward(x, attn_norm_weights, eps): function _rms_norm_add_forward (line 125) | def _rms_norm_add_forward(x, y, attn_norm_weights, eps): FILE: xformers/ops/_triton/rope_padded_kernels.py function _rope_padded_kernel (line 14) | def _rope_padded_kernel( FILE: xformers/ops/_triton/tiled_matmul_kernels.py function init_to_zero (line 20) | def init_to_zero(*names): function gen_config (line 28) | def gen_config( function our_estimate_matmul_time (line 112) | def our_estimate_matmul_time( function our_early_config_prune (line 129) | def our_early_config_prune(config, named_args, **kwargs): function _xformers_tiled_matmul_kernel (line 158) | def _xformers_tiled_matmul_kernel( function _check_row_or_column (line 349) | def _check_row_or_column(row_or_col_type, row_or_col_idx, tensor_name, d... function _get_strides (line 360) | def _get_strides( function _launch_triton_matmul (line 383) | def _launch_triton_matmul( FILE: xformers/ops/common.py function get_operator (line 11) | def get_operator(library: str, name: str): function get_xformers_operator (line 23) | def get_xformers_operator(name: str): class BaseOperator (line 27) | class BaseOperator: method is_available (line 33) | def is_available(cls) -> bool: function register_operator (line 49) | def register_operator(cls: ClsT) -> ClsT: function _get_storage_base (line 63) | def _get_storage_base(x: torch.Tensor) -> int: FILE: xformers/ops/differentiable_collectives.py function all_reduce (line 13) | def all_reduce( function gather_along_first_dim_async (line 25) | def gather_along_first_dim_async( function reduce_scatter_along_first_dim_async (line 43) | def reduce_scatter_along_first_dim_async( function gather_along_first_dim (line 63) | def gather_along_first_dim( function reduce_scatter_along_first_dim (line 72) | def reduce_scatter_along_first_dim( class _CopyToModelParallelRegion (line 83) | class _CopyToModelParallelRegion(torch.autograd.Function): method forward (line 85) | def forward( # type: ignore[override] method backward (line 92) | def backward( # type: ignore[override] function copy_to_model_parallel_region (line 99) | def copy_to_model_parallel_region( class _ReduceFromModelParallelRegion (line 107) | class _ReduceFromModelParallelRegion(torch.autograd.Function): method forward (line 109) | def forward( # type: ignore[override] method backward (line 117) | def backward( # type: ignore[override] function reduce_from_model_parallel_region (line 123) | def reduce_from_model_parallel_region( class _GatherFromSequenceParallelRegion (line 131) | class _GatherFromSequenceParallelRegion(torch.autograd.Function): method forward (line 133) | def forward( # type: ignore[override] method backward (line 140) | def backward( # type: ignore[override] function gather_from_sequence_parallel_region (line 151) | def gather_from_sequence_parallel_region( class _ScatterToSequenceParallelRegion (line 159) | class _ScatterToSequenceParallelRegion(torch.autograd.Function): method forward (line 161) | def forward( # type: ignore[override] method backward (line 168) | def backward( # type: ignore[override] function scatter_to_sequence_parallel_region (line 177) | def scatter_to_sequence_parallel_region( FILE: xformers/ops/fmha/__init__.py function _deserialize_bias (line 55) | def _deserialize_bias(attn_bias_ctx, attn_bias_tensor: Optional[torch.Te... function _serialize_op (line 71) | def _serialize_op(op): function _unserialize_op (line 77) | def _unserialize_op(op): class _fMHA (line 83) | class _fMHA(torch.autograd.Function): method forward (line 86) | def forward(ctx, op_fw, op_bw, *args: Any) -> Any: method backward (line 168) | def backward(ctx, grad, grad_lse): function memory_efficient_attention (line 199) | def memory_efficient_attention( function memory_efficient_attention_forward_meta (line 332) | def memory_efficient_attention_forward_meta(q, k, v): function memory_efficient_attention_forward_torch_wrapper (line 339) | def memory_efficient_attention_forward_torch_wrapper( function memory_efficient_attention_forward (line 367) | def memory_efficient_attention_forward( function memory_efficient_attention_forward_requires_grad (line 395) | def memory_efficient_attention_forward_requires_grad( function memory_efficient_attention_backward (line 431) | def memory_efficient_attention_backward( function _memory_efficient_attention (line 467) | def _memory_efficient_attention( function _memory_efficient_attention_forward (line 485) | def _memory_efficient_attention_forward( function _memory_efficient_attention_forward_requires_grad (line 499) | def _memory_efficient_attention_forward_requires_grad( function _detect_lse_packed_or_raise (line 513) | def _detect_lse_packed_or_raise(lse: torch.Tensor, inp: Inputs) -> Optio... function _memory_efficient_attention_backward (line 552) | def _memory_efficient_attention_backward( function memory_efficient_attention_partial (line 597) | def memory_efficient_attention_partial( function merge_attentions (line 638) | def merge_attentions( FILE: xformers/ops/fmha/_triton/splitk_kernels.py function _fwd_kernel_splitK (line 31) | def _fwd_kernel_splitK( function gen_config (line 589) | def gen_config( function _get_splitk_kernel (line 607) | def _get_splitk_kernel(num_groups): function early_config_prune (line 631) | def early_config_prune(configs, named_args, **kwargs): function autotune_kernel (line 643) | def autotune_kernel(kernel: Callable): function get_autotuner_cache (line 683) | def get_autotuner_cache( function set_autotuner_cache (line 692) | def set_autotuner_cache( function load_dequantize_k_v_group (line 699) | def load_dequantize_k_v_group( function cast_uint32_to_half2 (line 784) | def cast_uint32_to_half2(scale_shift): function cast_uint32_to_float (line 794) | def cast_uint32_to_float(scale_shift): function dequantize_k_hip (line 804) | def dequantize_k_hip( function dequantize (line 852) | def dequantize( function _splitK_reduce (line 908) | def _splitK_reduce( function _splitK_reduce_varargs (line 1024) | def _splitK_reduce_varargs( function _splitK_reduce_varargs_backward (line 1136) | def _splitK_reduce_varargs_backward( FILE: xformers/ops/fmha/attn_bias.py function _to_device (line 39) | def _to_device(t: torch.Tensor, device: torch.device) -> torch.Tensor: function _to_device_tensor (line 48) | def _to_device_tensor(seq: Sequence[int], dtype: torch.dtype, device: to... class AttentionBias (line 55) | class AttentionBias: method materialize (line 89) | def materialize( function _get_default_bias_device (line 104) | def _get_default_bias_device(device: Optional[torch.device] = None) -> t... function _materialize_causal_mask (line 114) | def _materialize_causal_mask( class LowerTriangularMask (line 142) | class LowerTriangularMask(AttentionBias): method __init__ (line 153) | def __init__(self, device: Union[torch.device, None] = None) -> None: method to (line 156) | def to(self, device: torch.device) -> "LowerTriangularMask": method materialize (line 160) | def materialize( method add_bias (line 168) | def add_bias(self, bias: torch.Tensor) -> "LowerTriangularMaskWithTens... class LocalAttentionFromBottomRightMask (line 176) | class LocalAttentionFromBottomRightMask(AttentionBias): method to (line 221) | def to(self, device) -> "LocalAttentionFromBottomRightMask": method __post_init__ (line 224) | def __post_init__(self) -> None: method materialize (line 238) | def materialize( class LowerTriangularFromBottomRightMask (line 261) | class LowerTriangularFromBottomRightMask(AttentionBias): method to (line 281) | def to(self, device: torch.device) -> "LowerTriangularFromBottomRightM... method materialize (line 287) | def materialize( method make_local_attention (line 297) | def make_local_attention( class LowerTriangularFromBottomRightLocalAttentionMask (line 309) | class LowerTriangularFromBottomRightLocalAttentionMask( method to (line 331) | def to( method __post_init__ (line 339) | def __post_init__(self) -> None: method materialize (line 345) | def materialize( class LowerTriangularMaskWithTensorBias (line 360) | class LowerTriangularMaskWithTensorBias(LowerTriangularMask): method __init__ (line 363) | def __init__(self, bias: torch.Tensor) -> None: method to (line 366) | def to(self, device: torch.device) -> "LowerTriangularMaskWithTensorBi... method materialize (line 372) | def materialize( class _SeqLenInfo (line 382) | class _SeqLenInfo: method to (line 400) | def to(self, device: torch.device) -> "_SeqLenInfo": method intervals (line 411) | def intervals(self) -> Iterable[Tuple[int, int]]: method _get_seqstart (line 415) | def _get_seqstart( method from_seqlens (line 436) | def from_seqlens( method from_seqlens_inplace (line 454) | def from_seqlens_inplace(self, seqlens: Iterable[int]) -> None: method split (line 475) | def split( class _PaddedSeqLenInfo (line 500) | class _PaddedSeqLenInfo(_SeqLenInfo): method __post_init__ (line 542) | def __post_init__(self) -> None: method to (line 545) | def to(self, device: torch.device) -> "_PaddedSeqLenInfo": method intervals (line 561) | def intervals(self) -> Iterable[Tuple[int, int]]: method from_seqlens (line 566) | def from_seqlens( method from_seqlens_padded (line 574) | def from_seqlens_padded( method from_seqlens_padded_inplace (line 602) | def from_seqlens_padded_inplace(self, seqlens: Sequence[int]) -> None: method split (line 629) | def split( class _GappySeqInfo (line 636) | class _GappySeqInfo(_SeqLenInfo): method to (line 689) | def to(self, device: torch.device) -> "_GappySeqInfo": method intervals (line 704) | def intervals(self) -> Iterable[Tuple[int, int]]: method from_seqlens (line 709) | def from_seqlens( method from_seqlens_gappy (line 715) | def from_seqlens_gappy( method split (line 746) | def split( class BlockDiagonalMask (line 753) | class BlockDiagonalMask(AttentionBias): method to (line 796) | def to(self, device) -> "BlockDiagonalMask": method _create_block_mask (line 804) | def _create_block_mask( method materialize (line 816) | def materialize( method from_seqlens (line 849) | def from_seqlens( method from_tensor_list (line 875) | def from_tensor_list( method from_tensor_lists_qkv (line 908) | def from_tensor_lists_qkv( method split_queries (line 936) | def split_queries(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]: method split_kv (line 939) | def split_kv(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]: method split (line 942) | def split(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]: method make_causal (line 954) | def make_causal(self) -> "BlockDiagonalCausalMask": method make_causal_from_bottomright (line 962) | def make_causal_from_bottomright(self) -> "BlockDiagonalCausalFromBott... method make_local_attention (line 970) | def make_local_attention( method make_local_attention_from_bottomright (line 981) | def make_local_attention_from_bottomright( class BlockDiagonalCausalMask (line 994) | class BlockDiagonalCausalMask(BlockDiagonalMask): method to (line 1004) | def to(self, device) -> "BlockDiagonalCausalMask": method _create_block_mask (line 1012) | def _create_block_mask( class BlockDiagonalCausalFromBottomRightMask (line 1026) | class BlockDiagonalCausalFromBottomRightMask(BlockDiagonalMask): method to (line 1039) | def to(self, device) -> "BlockDiagonalCausalFromBottomRightMask": method __post_init__ (line 1049) | def __post_init__(self) -> None: method _create_block_mask (line 1064) | def _create_block_mask( class BlockDiagonalPaddedKeysMask (line 1076) | class BlockDiagonalPaddedKeysMask(AttentionBias): method to (line 1096) | def to(self, device) -> "BlockDiagonalPaddedKeysMask": method _create_block_mask (line 1103) | def _create_block_mask( method materialize (line 1111) | def materialize( method from_seqlens (line 1140) | def from_seqlens( method make_paged (line 1171) | def make_paged( method make_local_attention (line 1193) | def make_local_attention( class BlockDiagonalCausalWithOffsetPaddedKeysMask (line 1205) | class BlockDiagonalCausalWithOffsetPaddedKeysMask(BlockDiagonalPaddedKey... method to (line 1226) | def to(self, device) -> "BlockDiagonalCausalWithOffsetPaddedKeysMask": method _create_block_mask (line 1235) | def _create_block_mask( method from_seqlens (line 1246) | def from_seqlens( class BlockDiagonalLocalAttentionPaddedKeysMask (line 1279) | class BlockDiagonalLocalAttentionPaddedKeysMask(BlockDiagonalPaddedKeysM... method to (line 1300) | def to(self, device) -> "BlockDiagonalLocalAttentionPaddedKeysMask": method _create_block_mask (line 1311) | def _create_block_mask( method from_seqlens_local (line 1322) | def from_seqlens_local( class BlockDiagonalCausalLocalAttentionPaddedKeysMask (line 1345) | class BlockDiagonalCausalLocalAttentionPaddedKeysMask(BlockDiagonalPadde... method to (line 1360) | def to(self, device) -> "BlockDiagonalCausalLocalAttentionPaddedKeysMa... method _create_block_mask (line 1370) | def _create_block_mask( method from_seqlens_local (line 1385) | def from_seqlens_local( class PagedBlockDiagonalPaddedKeysMask (line 1402) | class PagedBlockDiagonalPaddedKeysMask(AttentionBias): method to (line 1419) | def to(self, device: torch.device) -> "PagedBlockDiagonalPaddedKeysMask": method materialize (line 1430) | def materialize( method from_seqlens (line 1470) | def from_seqlens( class PagedBlockDiagonalCausalWithOffsetPaddedKeysMask (line 1508) | class PagedBlockDiagonalCausalWithOffsetPaddedKeysMask( method to (line 1520) | def to( class BlockDiagonalGappyKeysMask (line 1535) | class BlockDiagonalGappyKeysMask(AttentionBias): method to (line 1546) | def to(self, device: torch.device) -> "BlockDiagonalGappyKeysMask": method materialize (line 1553) | def materialize( method from_seqlens (line 1576) | def from_seqlens( method make_paged (line 1598) | def make_paged( class BlockDiagonalCausalWithOffsetGappyKeysMask (line 1658) | class BlockDiagonalCausalWithOffsetGappyKeysMask(BlockDiagonalGappyKeysM... method to (line 1668) | def to(self, device: torch.device) -> "BlockDiagonalCausalWithOffsetGa... method materialize (line 1677) | def materialize( class PagedBlockDiagonalGappyKeysMask (line 1708) | class PagedBlockDiagonalGappyKeysMask(AttentionBias): method to (line 1725) | def to(self, device: torch.device) -> "PagedBlockDiagonalGappyKeysMask": method materialize (line 1736) | def materialize( method from_seqlens (line 1787) | def from_seqlens( class PagedBlockDiagonalCausalWithOffsetGappyKeysMask (line 1828) | class PagedBlockDiagonalCausalWithOffsetGappyKeysMask(PagedBlockDiagonal... method to (line 1838) | def to( class BlockDiagonalCausalLocalAttentionMask (line 1853) | class BlockDiagonalCausalLocalAttentionMask(BlockDiagonalCausalMask): method to (line 1865) | def to(self, device) -> "BlockDiagonalCausalLocalAttentionMask": method __post_init__ (line 1876) | def __post_init__(self): method _create_block_mask (line 1901) | def _create_block_mask( class BlockDiagonalCausalLocalAttentionFromBottomRightMask (line 1916) | class BlockDiagonalCausalLocalAttentionFromBottomRightMask( method to (line 1931) | def to(self, device) -> "BlockDiagonalCausalLocalAttentionFromBottomRi... method __post_init__ (line 1942) | def __post_init__(self): method _create_block_mask (line 1949) | def _create_block_mask( FILE: xformers/ops/fmha/ck.py function _minimum_gemm_alignment (line 45) | def _minimum_gemm_alignment(inp: Inputs) -> int: function _get_seqlen_info (line 49) | def _get_seqlen_info( function _get_tensor_bias (line 90) | def _get_tensor_bias( function _check_bias_alignment (line 100) | def _check_bias_alignment( class _CustomMaskType (line 128) | class _CustomMaskType(int, Enum): function _custom_mask_type (line 138) | def _custom_mask_type(bias: Optional[Union[torch.Tensor, AttentionBias]]... class FwOp (line 165) | class FwOp(AttentionFwOpBase): method apply (line 222) | def apply( method apply_bmhk (line 280) | def apply_bmhk( method not_supported_reasons (line 353) | def not_supported_reasons(cls, d: Inputs) -> List[str]: class BwOp (line 363) | class BwOp(AttentionBwOpBase): method not_supported_reasons (line 399) | def not_supported_reasons(cls, d: Inputs) -> List[str]: method apply (line 432) | def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradi... FILE: xformers/ops/fmha/ck_splitk.py class FwOp (line 21) | class FwOp(AttentionFwOpBase): method shape_not_supported_reasons (line 47) | def shape_not_supported_reasons( method not_supported_reasons (line 56) | def not_supported_reasons(cls, d: Inputs) -> List[str]: method get_split_k (line 95) | def get_split_k(cls, B: int, H: int, Mk: int) -> int: method apply (line 107) | def apply( class FwOp_S1 (line 171) | class FwOp_S1(FwOp): class FwOp_S2 (line 176) | class FwOp_S2(FwOp): class FwOp_S4 (line 181) | class FwOp_S4(FwOp): class FwOp_S8 (line 186) | class FwOp_S8(FwOp): class FwOp_S16 (line 191) | class FwOp_S16(FwOp): class FwOp_S32 (line 196) | class FwOp_S32(FwOp): class FwOp_S64 (line 201) | class FwOp_S64(FwOp): class FwOp_S128 (line 206) | class FwOp_S128(FwOp): FILE: xformers/ops/fmha/common.py function _is_bias_type_supported_in_BMK (line 38) | def _is_bias_type_supported_in_BMK(attn_bias_type: Any) -> bool: function _attn_bias_apply (line 47) | def _attn_bias_apply( class ScaledTensor (line 58) | class ScaledTensor(torch.Tensor): method __new__ (line 64) | def __new__( method dequantize (line 95) | def dequantize(self) -> torch.Tensor: method unpack (line 109) | def unpack(self) -> Tuple[torch.Tensor, torch.Tensor]: method __repr__ (line 117) | def __repr__(self): function pack_fp8_tensorwise_per_head (line 124) | def pack_fp8_tensorwise_per_head( class Inputs (line 145) | class Inputs: method device (line 160) | def device(self) -> torch.device: method scale_float (line 164) | def scale_float(self) -> float: method get_qkv_in_bmghk (line 167) | def get_qkv_in_bmghk(self) -> Tuple[torch.Tensor, torch.Tensor, torch.... method normalize_bmhk (line 184) | def normalize_bmhk(self) -> Tuple[int, ...]: method validate_inputs (line 208) | def validate_inputs(self) -> None: method get_output_dtype (line 327) | def get_output_dtype(self) -> torch.dtype: method nbytes (line 335) | def nbytes(self) -> int: class Context (line 345) | class Context: method get_padded_lse (line 354) | def get_padded_lse(self, pad_to: int, force_pad_inf: bool = False) -> ... class Gradients (line 368) | class Gradients: class AttentionOpBase (line 376) | class AttentionOpBase(BaseOperator): method supports (line 416) | def supports(cls, d: Inputs) -> bool: method shape_not_supported_reasons (line 420) | def shape_not_supported_reasons( method not_supported_reasons (line 437) | def not_supported_reasons(cls, d: Inputs) -> List[str]: class AttentionFwOpBase (line 508) | class AttentionFwOpBase(AttentionOpBase): method apply (line 521) | def apply( class AttentionBwOpBase (line 527) | class AttentionBwOpBase(AttentionOpBase): method not_supported_reasons (line 547) | def not_supported_reasons(cls, d: Inputs) -> List[str]: method apply (line 561) | def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradi... function bmk2bmhk (line 570) | def bmk2bmhk(tensor, num_heads: int) -> torch.Tensor: function check_lastdim_alignment_stride1 (line 578) | def check_lastdim_alignment_stride1( FILE: xformers/ops/fmha/cutlass.py function _uses_tensorcores (line 40) | def _uses_tensorcores(sm: int, is_half: bool) -> bool: function _minimum_gemm_alignment (line 48) | def _minimum_gemm_alignment(inp: Inputs) -> int: function _get_seqlen_info (line 65) | def _get_seqlen_info( function _get_tensor_bias (line 86) | def _get_tensor_bias( function _check_bias_alignment (line 96) | def _check_bias_alignment( class _CustomMaskType (line 124) | class _CustomMaskType(int, Enum): function _custom_mask_type (line 134) | def _custom_mask_type(bias: Optional[Union[torch.Tensor, AttentionBias]]... class FwOp (line 159) | class FwOp(AttentionFwOpBase): method apply (line 202) | def apply( method apply_bmhk (line 266) | def apply_bmhk( method not_supported_reasons (line 317) | def not_supported_reasons(cls, d: Inputs) -> List[str]: class BwOp (line 327) | class BwOp(AttentionBwOpBase): method not_supported_reasons (line 368) | def not_supported_reasons(cls, d: Inputs) -> List[str]: method apply (line 400) | def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradi... FILE: xformers/ops/fmha/cutlass_blackwell.py function _get_operator (line 34) | def _get_operator(name: str): function _convert_input_format (line 63) | def _convert_input_format( function _is_seqlen_q_le_seqlen_k (line 156) | def _is_seqlen_q_le_seqlen_k( function _is_causal (line 169) | def _is_causal(attn_bias: Union[torch.Tensor, AttentionBias, None]) -> b... function _is_bottom_right (line 187) | def _is_bottom_right(attn_bias: Union[torch.Tensor, AttentionBias, None]... function _window_size (line 203) | def _window_size( class FwOp (line 231) | class FwOp(AttentionFwOpBase): method not_supported_reasons (line 267) | def not_supported_reasons(cls, d: Inputs) -> List[str]: method shape_not_supported_reasons (line 290) | def shape_not_supported_reasons( method apply (line 301) | def apply( class BwOp (line 350) | class BwOp(AttentionBwOpBase): method not_supported_reasons (line 382) | def not_supported_reasons(cls, d: Inputs) -> List[str]: method shape_not_supported_reasons (line 405) | def shape_not_supported_reasons( method apply (line 418) | def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradi... FILE: xformers/ops/fmha/dispatch.py function _set_use_fa3 (line 22) | def _set_use_fa3(use_flash_attention3: bool) -> None: function _get_use_fa3 (line 27) | def _get_use_fa3() -> bool: function fa3_available (line 32) | def fa3_available() -> bool: function _format_inputs_description (line 39) | def _format_inputs_description(inp: Inputs) -> str: function _ensure_op_supports_or_raise (line 47) | def _ensure_op_supports_or_raise(exc_type, name: str, op, inp: Inputs) -... function _format_not_supported_reasons (line 56) | def _format_not_supported_reasons(op, reasons: List[str]) -> str: function _run_priority_list (line 60) | def _run_priority_list( function _dispatch_fw_priority_list (line 84) | def _dispatch_fw_priority_list( function _dispatch_fw (line 131) | def _dispatch_fw(inp: Inputs, needs_gradient: bool) -> Type[AttentionFwO... function _dispatch_bw (line 147) | def _dispatch_bw( FILE: xformers/ops/fmha/flash.py function _flash_fwd (line 92) | def _flash_fwd( function _flash_fwd_abstract (line 177) | def _flash_fwd_abstract( function _flash_bwd (line 211) | def _flash_bwd( function _flash_bwd_abstract (line 309) | def _flash_bwd_abstract( function _create_dq_dk_dv (line 320) | def _create_dq_dk_dv( function _convert_input_format (line 336) | def _convert_input_format( function _is_causal (line 438) | def _is_causal(attn_bias: Optional[Union[torch.Tensor, AttentionBias]]) ... function _is_paged_attention_supported (line 458) | def _is_paged_attention_supported(attn_bias_type) -> bool: function _window_size (line 470) | def _window_size( function _check_needs_no_topleft (line 497) | def _check_needs_no_topleft(d: Inputs, reasons: List[str]) -> None: function _check_strides_for_bmghk (line 519) | def _check_strides_for_bmghk(x: torch.Tensor, name: str, reasons: List[s... function _post_process_lse (line 537) | def _post_process_lse( class FwOp (line 558) | class FwOp(AttentionFwOpBase): method not_supported_reasons (line 603) | def not_supported_reasons(cls, d: Inputs) -> List[str]: method apply (line 614) | def apply( class BwOp (line 694) | class BwOp(AttentionBwOpBase): method not_supported_reasons (line 727) | def not_supported_reasons(cls, d: Inputs) -> List[str]: method apply (line 747) | def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradi... FILE: xformers/ops/fmha/flash3.py function maybe_contiguous (line 65) | def maybe_contiguous(x: T) -> T: function _flash_attention3_incompatible_reason (line 69) | def _flash_attention3_incompatible_reason() -> Optional[str]: function _heuristic_kvsplit (line 132) | def _heuristic_kvsplit( function mask_non_zeros (line 151) | def mask_non_zeros(s_q: int, s_k: int, window_left: int, window_right: i... function sdpa_flop_count (line 193) | def sdpa_flop_count( function mha_fwd (line 225) | def mha_fwd( function mha_fwd_fake (line 327) | def mha_fwd_fake( function mha_fwd_flops (line 365) | def mha_fwd_flops( function _create_dq_dk_dv (line 430) | def _create_dq_dk_dv( function mha_bwd (line 448) | def mha_bwd( function mha_bwd_fake (line 501) | def mha_bwd_fake( function mha_bwd_flops (line 521) | def mha_bwd_flops( function _check_different_value_headdim_ampere (line 568) | def _check_different_value_headdim_ampere(d: Inputs, reasons: List[str])... function _get_blocktables (line 582) | def _get_blocktables(inp_attn_bias) -> Optional[torch.Tensor]: class FwOp (line 594) | class FwOp(AttentionFwOpBase): method not_supported_reasons (line 648) | def not_supported_reasons(cls, d: Inputs) -> List[str]: method apply (line 666) | def apply( class BwOp (line 768) | class BwOp(AttentionBwOpBase): method not_supported_reasons (line 802) | def not_supported_reasons(cls, d: Inputs) -> List[str]: method apply (line 812) | def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradi... class FwOp_KVSplit (line 872) | class FwOp_KVSplit(FwOp): method apply (line 901) | def apply( # type: ignore[override] FILE: xformers/ops/fmha/merge_training.py class _PartialFunc (line 39) | class _PartialFunc(torch.autograd.Function): method forward (line 41) | def forward( method backward (line 64) | def backward( # type: ignore[override] class _MergeFunc (line 86) | class _MergeFunc(torch.autograd.Function): method forward (line 88) | def forward( method backward (line 104) | def backward( # type: ignore[override] class Partial (line 111) | class Partial: method __init__ (line 125) | def __init__( method is_bmghk (line 138) | def is_bmghk(self) -> bool: method apply (line 141) | def apply(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> "Partial": method _tuple (line 160) | def _tuple(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: function memory_efficient_attention_partial_autograd (line 164) | def memory_efficient_attention_partial_autograd( function merge_attentions_autograd (line 184) | def merge_attentions_autograd( FILE: xformers/ops/fmha/torch_attention_compat.py function is_pt_cutlass_compatible (line 10) | def is_pt_cutlass_compatible(force: bool = False) -> bool: function ensure_pt_flash_ok (line 62) | def ensure_pt_flash_ok() -> None: FILE: xformers/ops/fmha/triton_splitk.py function _strides (line 42) | def _strides(x: Optional[torch.Tensor], *stride_names: str): function _is_supported_causal_bias (line 49) | def _is_supported_causal_bias(attn_bias: Any) -> bool: function _is_supported_local_bias (line 62) | def _is_supported_local_bias(attn_bias: Any) -> bool: function _is_supported_gappy_bias (line 72) | def _is_supported_gappy_bias(attn_bias: Any) -> bool: function _is_supported_paged_bias (line 82) | def _is_supported_paged_bias(attn_bias: Any) -> bool: class InputsFp8 (line 93) | class InputsFp8(Inputs): method nbytes (line 105) | def nbytes(self) -> int: function _is_cuda (line 131) | def _is_cuda() -> bool: function _is_cuda_at_least_sm80 (line 135) | def _is_cuda_at_least_sm80(device: torch.device) -> bool: class FwOp (line 143) | class FwOp(AttentionFwOpBase): method shape_not_supported_reasons (line 252) | def shape_not_supported_reasons( method not_supported_reasons (line 261) | def not_supported_reasons(cls, d: Inputs) -> List[str]: method get_split_k (line 337) | def get_split_k( method get_kernel (line 383) | def get_kernel(cls): method get_fp8_scale_shift (line 395) | def get_fp8_scale_shift( method get_extra_args (line 415) | def get_extra_args( method apply (line 606) | def apply( method get_operator (line 1016) | def get_operator( function merge_attentions (line 1047) | def merge_attentions( function merge_attentions_varargs (line 1106) | def merge_attentions_varargs( function merge_attentions_varargs_fake (line 1158) | def merge_attentions_varargs_fake( function _merge_attentions_backward (line 1184) | def _merge_attentions_backward( function merge_attentions_varargs_backward (line 1204) | def merge_attentions_varargs_backward( function merge_attentions_varargs_backward_fake (line 1243) | def merge_attentions_varargs_backward_fake( function _prepare_reduce_kernel_params (line 1256) | def _prepare_reduce_kernel_params( FILE: xformers/ops/indexing.py class ScaledIndexAddFw (line 23) | class ScaledIndexAddFw(BaseOperator): class ScaledIndexAddBw (line 30) | class ScaledIndexAddBw(BaseOperator): class IndexSelect (line 37) | class IndexSelect(BaseOperator): class _ScaledIndexAdd (line 43) | class _ScaledIndexAdd(torch.autograd.Function): method forward (line 46) | def forward( method backward (line 69) | def backward(ctx, grad_output): function scaled_index_add (line 104) | def scaled_index_add( class _IndexSelectCat (line 132) | class _IndexSelectCat(torch.autograd.Function): method forward (line 135) | def forward( method backward (line 180) | def backward(ctx, grad_output): function index_select_cat (line 215) | def index_select_cat( FILE: xformers/ops/modpar_layers.py function _init_2d_weight (line 18) | def _init_2d_weight( class ColumnParallelLinear (line 45) | class ColumnParallelLinear(torch.nn.Module): method __init__ (line 46) | def __init__( method forward (line 94) | def forward(self, input_: torch.Tensor) -> List[torch.Tensor]: class RowParallelLinear (line 108) | class RowParallelLinear(torch.nn.Module): method __init__ (line 109) | def __init__( method forward (line 149) | def forward(self, input_: torch.Tensor) -> torch.Tensor: FILE: xformers/ops/rmsnorm.py function rms_norm (line 13) | def rms_norm(x, weight: Optional[torch.Tensor], eps: float = 1e-6): function rms_norm_add (line 42) | def rms_norm_add( class RMSNorm (line 72) | class RMSNorm(torch.nn.Module): method __init__ (line 91) | def __init__(self, dim: int, include_weight: bool = True, eps: float =... method forward (line 99) | def forward(self, x: torch.Tensor): method increment_and_forward_ (line 102) | def increment_and_forward_(self, x: torch.Tensor, y: torch.Tensor): FILE: xformers/ops/rope_padded.py function rope_padded (line 16) | def rope_padded( FILE: xformers/ops/seqpar.py function sequence_parallel_leading_matmul_fwd (line 32) | def sequence_parallel_leading_matmul_fwd( function sequence_parallel_leading_matmul_fwd_fake (line 56) | def sequence_parallel_leading_matmul_fwd_fake( function sequence_parallel_leading_matmul_bwd (line 74) | def sequence_parallel_leading_matmul_bwd( function sequence_parallel_leading_matmul_bwd_fake (line 170) | def sequence_parallel_leading_matmul_bwd_fake( function sequence_parallel_leading_matmul_setup_context (line 180) | def sequence_parallel_leading_matmul_setup_context(ctx, inputs, output): function sequence_parallel_leading_matmul_bwd_bridge (line 187) | def sequence_parallel_leading_matmul_bwd_bridge(ctx, grad_gathered_outpu... function sequence_parallel_leading_matmul (line 209) | def sequence_parallel_leading_matmul( function sequence_parallel_trailing_matmul_fwd (line 227) | def sequence_parallel_trailing_matmul_fwd( function sequence_parallel_trailing_matmul_fwd_fake (line 248) | def sequence_parallel_trailing_matmul_fwd_fake( function sequence_parallel_trailing_matmul_bwd (line 265) | def sequence_parallel_trailing_matmul_bwd( function sequence_parallel_trailing_matmul_bwd_fake (line 316) | def sequence_parallel_trailing_matmul_bwd_fake( function sequence_parallel_trailing_matmul_setup_context (line 326) | def sequence_parallel_trailing_matmul_setup_context(ctx, inputs, output): function sequence_parallel_trailing_matmul_bwd_bridge (line 333) | def sequence_parallel_trailing_matmul_bwd_bridge(ctx, grad_scattered_out... function sequence_parallel_trailing_matmul (line 355) | def sequence_parallel_trailing_matmul( FILE: xformers/ops/sequence_parallel_fused_ops.py function _is_fp8_dtype (line 20) | def _is_fp8_dtype(dt: torch.dtype): class _FusedSequenceParallel (line 26) | class _FusedSequenceParallel: method __init__ (line 68) | def __init__( method make_stream_factory (line 89) | def make_stream_factory( method allgather_and_linear (line 100) | def allgather_and_linear( method linear_and_reducescatter (line 209) | def linear_and_reducescatter( function _can_ranks_communicate_all_to_all_over_nvlink (line 342) | def _can_ranks_communicate_all_to_all_over_nvlink(group: dist.ProcessGro... function _lazy_init (line 355) | def _lazy_init( function _default_stream_factory (line 374) | def _default_stream_factory() -> torch.cuda.Stream: function fused_allgather_and_linear (line 379) | def fused_allgather_and_linear( function fused_allgather_and_linear (line 394) | def fused_allgather_and_linear( function fused_allgather_and_linear (line 408) | def fused_allgather_and_linear( function _fused_allgather_and_linear_custom_op (line 522) | def _fused_allgather_and_linear_custom_op( function fused_allgather_and_anything (line 564) | def fused_allgather_and_anything( function fused_linear_and_reducescatter (line 620) | def fused_linear_and_reducescatter( function fused_linear_and_reducescatter (line 635) | def fused_linear_and_reducescatter( function fused_linear_and_reducescatter (line 649) | def fused_linear_and_reducescatter( function _fused_linear_and_reducescatter_custom_op (line 749) | def _fused_linear_and_reducescatter_custom_op( function fused_anything_and_reducescatter (line 791) | def fused_anything_and_reducescatter( FILE: xformers/ops/sp24.py class SparsifyBothWays (line 18) | class SparsifyBothWays(BaseOperator): class SparsifyApply (line 25) | class SparsifyApply(BaseOperator): class SparsifyApplyDenseOutput (line 32) | class SparsifyApplyDenseOutput(BaseOperator): class Sp24Gemm (line 39) | class Sp24Gemm(BaseOperator): function _get_cusparselt_torch_version (line 45) | def _get_cusparselt_torch_version() -> Tuple[int, int, int]: class Sp24GemmCuspltSearch (line 62) | class Sp24GemmCuspltSearch(BaseOperator): class Sp24GemmCusplt (line 69) | class Sp24GemmCusplt(BaseOperator): function _has_cusparseLt (line 75) | def _has_cusparseLt() -> bool: function sparse24_pointwise_op (line 90) | def sparse24_pointwise_op( function sparse24_mm (line 140) | def sparse24_mm(func, types, args=(), kwargs=None) -> torch.Tensor: function sparse24_addmm (line 155) | def sparse24_addmm(func, types, args=(), kwargs=None) -> torch.Tensor: function sparse24_linear (line 175) | def sparse24_linear(func, types, args=(), kwargs=None) -> torch.Tensor: function sparse24_t (line 188) | def sparse24_t(func, types, args=(), kwargs=None) -> torch.Tensor: function sparse24_view (line 203) | def sparse24_view(func, types, args=(), kwargs=None) -> torch.Tensor: function sparse24_detach (line 213) | def sparse24_detach(func, types, args, kwargs) -> torch.Tensor: function no_dispatch (line 228) | def no_dispatch(): function fallback_dispatcher (line 236) | def fallback_dispatcher(func, types, args, kwargs): class Sparse24Tensor (line 289) | class Sparse24Tensor(torch.Tensor): method __new__ (line 300) | def __new__( method __repr__ (line 326) | def __repr__(self): method _sp24_to_dense (line 329) | def _sp24_to_dense(self) -> torch.Tensor: method _mm (line 337) | def _mm( method __tensor_flatten__ (line 348) | def __tensor_flatten__(self): method __tensor_unflatten__ (line 352) | def __tensor_unflatten__( class Sparse24TensorCutlass (line 363) | class Sparse24TensorCutlass(Sparse24Tensor): method _mm (line 364) | def _mm( method __torch_dispatch__ (line 392) | def __torch_dispatch__(cls, func, types, args=(), kwargs=None): function _cusplt_find_alg (line 410) | def _cusplt_find_alg( function _cusplt_mm (line 463) | def _cusplt_mm( function _cusplt_mm_meta (line 483) | def _cusplt_mm_meta( class Sparse24TensorCuSparseLt (line 497) | class Sparse24TensorCuSparseLt(Sparse24Tensor): method _mm (line 498) | def _mm( method __torch_dispatch__ (line 548) | def __torch_dispatch__(cls, func, types, args=(), kwargs=None): function _sparsify24_forward (line 571) | def _sparsify24_forward(x: torch.Tensor, *, algo: str, backend: str) -> ... class _Sparsify24Func (line 600) | class _Sparsify24Func(torch.autograd.Function): method forward (line 602) | def forward(ctx, x: torch.Tensor, algo: str, gradient: str, backend: s... method backward (line 617) | def backward(ctx, grad_out: torch.Tensor): # type: ignore[override] class _Sparsify24STEFunc (line 646) | class _Sparsify24STEFunc(torch.autograd.Function): method forward (line 648) | def forward( method backward (line 663) | def backward(ctx, grad_out: torch.Tensor): # type: ignore[override] class _Sparsify24LikeFunc (line 680) | class _Sparsify24LikeFunc(torch.autograd.Function): method forward (line 682) | def forward(ctx, x: torch.Tensor, pattern: Sparse24Tensor, gradient: s... method backward (line 728) | def backward(ctx, grad_out: torch.Tensor): # type: ignore[override] function allow_in_graph (line 769) | def allow_in_graph(func: F) -> F: function sparsify24 (line 774) | def sparsify24( function sparsify24_ste (line 784) | def sparsify24_ste( function sparsify24_like (line 801) | def sparsify24_like( FILE: xformers/ops/swiglu_op.py class _SwiGLUDecomposedFunc (line 16) | class _SwiGLUDecomposedFunc(torch.autograd.Function): method _silu_backward (line 32) | def _silu_backward(dy, x): method forward (line 39) | def forward(cls, ctx, x, w1, b1, w2, b2, w3, b3): method backward (line 51) | def backward(cls, ctx, dx5): class SwiGLUOp (line 72) | class SwiGLUOp: method __init__ (line 75) | def __init__(self, op, packed_weights: bool, name: str, constraints): method supports (line 81) | def supports(self, op: "SwiGLUOpDispatch") -> bool: method __call__ (line 86) | def __call__(self, *args: Optional[torch.Tensor]) -> torch.Tensor: method __str__ (line 89) | def __str__(self) -> str: class _ForwardToPythonAutogradFunc (line 93) | class _ForwardToPythonAutogradFunc(SwiGLUOp): method supports (line 94) | def supports(self, op: "SwiGLUOpDispatch") -> bool: method __call__ (line 97) | def __call__(self, *args, **kwargs): class _ForwardToFunc (line 101) | class _ForwardToFunc(SwiGLUOp): method __call__ (line 102) | def __call__(self, *args, **kwargs): method info (line 105) | def info(self): function _eager_functional_swiglu (line 111) | def _eager_functional_swiglu( class SwiGLUOpDispatch (line 127) | class SwiGLUOpDispatch: method op (line 139) | def op(self) -> SwiGLUOp: method from_arguments (line 148) | def from_arguments( function _bias_enabled (line 170) | def _bias_enabled(op: SwiGLUOpDispatch) -> bool: function swiglu (line 185) | def swiglu( function swiglu_packed (line 262) | def swiglu_packed( class SwiGLU (line 302) | class SwiGLU(nn.Module): method __init__ (line 308) | def __init__( method forward (line 343) | def forward(self, x: torch.Tensor) -> torch.Tensor: method _ordered_params (line 361) | def _ordered_params( method _packed_ordered_params (line 398) | def _packed_ordered_params( FILE: xformers/ops/tiled_matmul.py function _should_use_triton (line 16) | def _should_use_triton(device: torch.device, dtype: torch.dtype) -> bool: function check_inputs (line 32) | def check_inputs( function check_output (line 95) | def check_output(out: List[List[torch.Tensor]], ms: List[int], ns: List[... function tiled_matmul_out (line 137) | def tiled_matmul_out( function _flatten (line 167) | def _flatten(x: List[List[torch.Tensor]], rows: int, cols: int) -> List[... function _unflatten (line 175) | def _unflatten( function _flattened_transpose (line 188) | def _flattened_transpose( function tiled_matmul_fwd (line 206) | def tiled_matmul_fwd( function tiled_matmul_fwd_fake (line 223) | def tiled_matmul_fwd_fake( function tiled_matmul_setup_context (line 234) | def tiled_matmul_setup_context(ctx, inputs, output): function tiled_matmul_bwd (line 239) | def tiled_matmul_bwd(ctx, flat_grad_c): function tiled_matmul (line 266) | def tiled_matmul( FILE: xformers/ops/tree_attention.py class TreeAttnMetadata (line 34) | class TreeAttnMetadata: method from_tree_choices_cached (line 133) | def from_tree_choices_cached( method from_tree_choices (line 142) | def from_tree_choices( function _get_subtree_size_and_num_children_per_node_at_level (line 219) | def _get_subtree_size_and_num_children_per_node_at_level( function _get_depth_counts (line 239) | def _get_depth_counts(sorted_tree_choices: List[Tuple[int, ...]]) -> Lis... function _get_num_nodes_per_level (line 252) | def _get_num_nodes_per_level( function _prepare_tree_attn_bias (line 259) | def _prepare_tree_attn_bias( function _prepare_tree_indices (line 317) | def _prepare_tree_indices( function _prepare_retrieval_indices (line 348) | def _prepare_retrieval_indices( function _prepare_tree_position_ids (line 383) | def _prepare_tree_position_ids( function _prepare_parent_node_indices (line 404) | def _prepare_parent_node_indices( function _prepare_child_node_indices (line 416) | def _prepare_child_node_indices( function _prepare_candidate_idx (line 445) | def _prepare_candidate_idx( function use_triton_splitk_for_prefix (line 458) | def use_triton_splitk_for_prefix(B: int, G: int, tree_size: int) -> bool: function select_prefix_op (line 469) | def select_prefix_op( function tree_attention (line 513) | def tree_attention( class SplitKAutotune (line 661) | class SplitKAutotune(triton_splitk.FwOp): function construct_full_tree_choices (line 666) | def construct_full_tree_choices( function construct_tree_choices (line 679) | def construct_tree_choices( function get_full_tree_size (line 691) | def get_full_tree_size(tree_depth: int, branching: int) -> int: FILE: xformers/ops/unbind.py function get_stack_strides (line 13) | def get_stack_strides( function _stack_or_none_fw (line 59) | def _stack_or_none_fw( function _stack_fw (line 71) | def _stack_fw( class _Unbind (line 81) | class _Unbind(torch.autograd.Function): method forward (line 88) | def forward(ctx, x: torch.Tensor, dim: int): method backward (line 94) | def backward(cls, ctx, *tensors: torch.Tensor): class _StackOrNone (line 98) | class _StackOrNone(torch.autograd.Function): method forward (line 105) | def forward(ctx, dim: int, *tensors: torch.Tensor): method backward (line 111) | def backward(cls, ctx, grad: torch.Tensor): function unbind (line 115) | def unbind(x: torch.Tensor, dim: int) -> Tuple[torch.Tensor, ...]: function stack_or_none (line 124) | def stack_or_none(tensors: Sequence[torch.Tensor], dim: int) -> torch.Te... FILE: xformers/profiler/api.py function profile (line 30) | def profile( function step (line 87) | def step() -> None: FILE: xformers/profiler/device_limits.py class DeviceLimit (line 14) | class DeviceLimit: function get_device_limits (line 104) | def get_device_limits(device) -> Optional[DeviceLimit]: FILE: xformers/profiler/find_slowest.py function print_json_as_dataframe (line 17) | def print_json_as_dataframe(json_list): function compute_std_dev_of_event_durations_over_ranks (line 44) | def compute_std_dev_of_event_durations_over_ranks(events, top=5): function sort_nccl_events (line 61) | def sort_nccl_events( function read_one_file (line 80) | def read_one_file(profile_trace_path: str) -> pd.DataFrame: function parse_one_file (line 106) | def parse_one_file(profile_trace_path: str) -> tuple[pd.DataFrame, pd.Da... function print_profiling_info (line 116) | def print_profiling_info(cuda_profile_dir: str): FILE: xformers/profiler/profile_analyzer.py class FakeKinetoEvent (line 14) | class FakeKinetoEvent: method __init__ (line 15) | def __init__(self, e: torch._C._autograd._KinetoEvent) -> None: function _attention_flops (line 23) | def _attention_flops(queries, values, causal: bool, fmt: str = "BHMK") -... function _get_arg_idx (line 44) | def _get_arg_idx(op, *arg_names: str) -> int: function _replace_if_needed (line 51) | def _replace_if_needed( class AnalyzedTrace (line 120) | class AnalyzedTrace: method compute_num_ops (line 125) | def compute_num_ops( method compute_hfu (line 135) | def compute_hfu(self, hardware_flops: Dict[torch.dtype, float]) -> float: method compute_mfu (line 141) | def compute_mfu(self, hardware_flops: Dict[torch.dtype, float]) -> float: method _find_all_root_events_with_flops (line 156) | def _find_all_root_events_with_flops( method from_profile (line 192) | def from_profile( FILE: xformers/profiler/profiler.py class NsightProfiler (line 30) | class NsightProfiler: method __init__ (line 38) | def __init__(self, main_profiler: "_Profiler") -> None: method __enter__ (line 42) | def __enter__(self): method __exit__ (line 45) | def __exit__(self, exc_type, exc_val, exc_tb): method step (line 48) | def step(self) -> None: class PyTorchProfiler (line 52) | class PyTorchProfiler: method __init__ (line 62) | def __init__(self, main_profiler: "_Profiler") -> None: method _on_trace (line 74) | def _on_trace(self, prof: torch.profiler.profiler.profile) -> None: method _preprocess_trace (line 98) | def _preprocess_trace( method _analyze_trace (line 113) | def _analyze_trace(self, prof: torch.profiler.profiler.profile) -> None: method __enter__ (line 137) | def __enter__(self): method __exit__ (line 141) | def __exit__(self, exc_type, exc_val, exc_tb): method step (line 145) | def step(self) -> None: class PyTorchProfiler_CUDAOnly (line 150) | class PyTorchProfiler_CUDAOnly(PyTorchProfiler): method _analyze_trace (line 155) | def _analyze_trace(self, prof: torch.profiler.profiler.profile) -> None: class MemSnapshotsProfiler (line 160) | class MemSnapshotsProfiler: method __init__ (line 165) | def __init__(self, main_profiler: "_Profiler") -> None: method _has_trace_plot (line 170) | def _has_trace_plot(self) -> bool: method __enter__ (line 173) | def __enter__(self): method __exit__ (line 188) | def __exit__(self, exc_type, exc_val, exc_tb): method step (line 211) | def step(self) -> None: class _ProfilerState (line 216) | class _ProfilerState: class _Profiler (line 223) | class _Profiler: method __init__ (line 226) | def __init__( method init_schedule (line 246) | def init_schedule(self, offset: int = 0) -> None: method check_schedule (line 257) | def check_schedule(self, schedule: Sequence[Tuple[Any, int, int]]) -> ... method update_profilers_on_step (line 284) | def update_profilers_on_step(self) -> None: method _create_output_filename (line 304) | def _create_output_filename(self, filename: str) -> Path: method start (line 318) | def start(self): method stop (line 321) | def stop(self, exc_type=None, exc_val=None, exc_tb=None): method __enter__ (line 324) | def __enter__(self): method __exit__ (line 332) | def __exit__(self, exc_type, exc_val, exc_tb): method step (line 339) | def step(self) -> None: method format_summary (line 369) | def format_summary(self) -> str: FILE: xformers/profiler/profiler_dcgm.py class DCGMProfiler (line 20) | class DCGMProfiler: # type: ignore method __init__ (line 23) | def __init__( method __enter__ (line 32) | def __enter__(self) -> None: method __exit__ (line 38) | def __exit__(self, exc_type, exc_val, exc_tb) -> None: method step (line 41) | def step(self) -> None: FILE: xformers/profiler/profiler_dcgm_impl.py class DCGMProfiler (line 18) | class DCGMProfiler: method __init__ (line 21) | def __init__( method create_dcgm_group (line 65) | def create_dcgm_group( method get_profilable_fields (line 96) | def get_profilable_fields(self) -> Set[int]: method create_profiling_field_group (line 107) | def create_profiling_field_group( method __enter__ (line 152) | def __enter__(self) -> None: method __exit__ (line 176) | def __exit__(self, exc_type, exc_val, exc_tb) -> None: method step (line 189) | def step(self) -> None: FILE: xformers/sparse/blocksparse_tensor.py function _spmm (line 16) | def _spmm(b, layout, values): function _softmax (line 40) | def _softmax(layout, values): function _sddmm (line 61) | def _sddmm(a, b, layout): class BlockSparseTensor (line 76) | class BlockSparseTensor(torch.Tensor): method __new__ (line 78) | def __new__(cls, values, layout): method __init__ (line 91) | def __init__(self, values, layout): method __repr__ (line 104) | def __repr__(self): method values (line 107) | def values(self): method _raw_wrap (line 111) | def _raw_wrap(cls, values, layout): method _wrap (line 118) | def _wrap(cls, values, bmat): method _bmm (line 125) | def _bmm(cls, arg0, arg1): method _masked_matmul (line 132) | def _masked_matmul(cls, a, b, mask): method _softmax (line 141) | def _softmax(cls, arg0, dim): method _to (line 148) | def _to(cls, arg0, device): method _copy (line 158) | def _copy(cls, arg0, arg1): method _equal (line 169) | def _equal(cls, arg0, arg1): method _to_dense (line 181) | def _to_dense(cls, arg0): method __torch_function__ (line 200) | def __torch_function__(cls, func, types, args=(), kwargs=None): method __torch_dispatch__ (line 277) | def __torch_dispatch__(cls, func, types, args, kwargs): FILE: xformers/sparse/utils.py function _coo_to_csr (line 10) | def _coo_to_csr(m, n, row_indices, column_indices): function _csr_to_coo (line 17) | def _csr_to_coo(m, n, row_offsets, column_indices): function _diffsort (line 25) | def _diffsort(a): function _get_transpose_info (line 29) | def _get_transpose_info(m, n, row_indices, row_offsets, column_indices): function _transpose_with_info (line 48) | def _transpose_with_info(values, _transpose_info): function _transpose (line 54) | def _transpose(m, n, row_indices, values, row_offsets, column_indices): function _nonzero_mask_to_sparse_csr_indices (line 61) | def _nonzero_mask_to_sparse_csr_indices(mask, device): function _dense_to_sparse (line 83) | def _dense_to_sparse(matrix, device): function _round_nnz (line 99) | def _round_nnz(mask, divisible_by=4): function _dense3d_to_sparse (line 108) | def _dense3d_to_sparse(matrix, device): FILE: xformers/triton/importing.py function libdevice_find (line 9) | def libdevice_find(name): FILE: xformers/triton/vararg_kernel.py class _ForLoopUnroller (line 19) | class _ForLoopUnroller(ast.NodeTransformer): method __init__ (line 20) | def __init__(self, target, inline_variables, loop_iter): method visit_Name (line 25) | def visit_Name(self, node): method visit_Subscript (line 30) | def visit_Subscript(self, node): class _VisitorVarargKernel (line 42) | class _VisitorVarargKernel(ast.NodeTransformer): method __init__ (line 43) | def __init__(self, N): method visit_AnnAssign (line 47) | def visit_AnnAssign(self, node): method visit_arguments (line 67) | def visit_arguments(self, node): class _VisitorUnrollKernel (line 90) | class _VisitorUnrollKernel(_VisitorVarargKernel): method visit_For (line 91) | def visit_For(self, node): class _VisitorConditionalKernel (line 119) | class _VisitorConditionalKernel(_VisitorVarargKernel): method __init__ (line 120) | def __init__(self, *args, **kwargs): method visit_Subscript (line 124) | def visit_Subscript(self, node): method visit_Call (line 148) | def visit_Call(self, node): function _monkey_patched_getlines (line 173) | def _monkey_patched_getlines(filename, module_globals=None): class VarargMode (line 180) | class VarargMode(Enum): function unroll_varargs (line 186) | def unroll_varargs(kernel, N: int, mode: VarargMode = VarargMode.UNROLL): FILE: xformers/utils.py function import_all_modules (line 20) | def import_all_modules(root: str, base_module: str) -> List[str]: function get_registry_decorator (line 33) | def get_registry_decorator( function generate_matching_config (line 68) | def generate_matching_config(superset: Dict[str, Any], config_class: Any... function do_bench_cudagraph (line 85) | def do_bench_cudagraph(