SYMBOL INDEX (447 symbols across 25 files) FILE: inf_cl/flash.py function _prob_fwd_kernel (line 12) | def _prob_fwd_kernel( function _dq_prob_bwd_kernel (line 70) | def _dq_prob_bwd_kernel( function _dk_prob_bwd_kernel (line 140) | def _dk_prob_bwd_kernel( function _flash_prob_forward (line 204) | def _flash_prob_forward(q, k): function _flash_prob_backward (line 242) | def _flash_prob_backward(q, k, lse, dlse): class FlashProb (line 306) | class FlashProb(torch.autograd.Function): method forward (line 309) | def forward(ctx, q, k): method backward (line 316) | def backward(ctx, dlse): function _cal_flash_loss (line 323) | def _cal_flash_loss(q, k, labels, head_dim=256): function cal_flash_loss (line 337) | def cal_flash_loss(q, k, labels=None, scale=None, head_dim=256): FILE: inf_cl/ring.py class RingComm (line 17) | class RingComm: method __init__ (line 19) | def __init__(self, process_group: dist.ProcessGroup): method send_recv (line 33) | def send_recv(self, to_send, recv_tensor = None): method commit (line 45) | def commit(self): method wait (line 50) | def wait(self): class GradientGather (line 59) | class GradientGather(torch.autograd.Function): method forward (line 62) | def forward(ctx, x): method backward (line 67) | def backward(ctx, dx): class RingProb (line 72) | class RingProb(torch.autograd.Function): method forward (line 75) | def forward(ctx, q, k, group): method backward (line 109) | def backward(ctx, dlse): class InfProb (line 154) | class InfProb(torch.autograd.Function): method forward (line 157) | def forward(ctx, q, k, group): method backward (line 190) | def backward(ctx, dlse): function set_seed (line 231) | def set_seed(rank, seed=42): function _cal_ring_loss (line 239) | def _cal_ring_loss(q, k, labels, head_dim=256): function _cal_inf_loss (line 252) | def _cal_inf_loss(q, k, labels, head_dim=256): function cal_ring_loss (line 265) | def cal_ring_loss(q, k, labels=None, scale=None, head_dim=256): function cal_inf_loss (line 289) | def cal_inf_loss(q, k, labels=None, scale=None, head_dim=256): FILE: inf_clip/factory.py function _natural_key (line 32) | def _natural_key(string_): function _rescan_model_configs (line 36) | def _rescan_model_configs(): function list_models (line 60) | def list_models(): function add_model_config (line 65) | def add_model_config(path): function get_model_config (line 73) | def get_model_config(model_name): function _get_hf_config (line 80) | def _get_hf_config(model_id, cache_dir=None): function get_tokenizer (line 87) | def get_tokenizer( function load_state_dict (line 131) | def load_state_dict(checkpoint_path: str, map_location='cpu'): function load_checkpoint (line 146) | def load_checkpoint( function create_model (line 183) | def create_model( function create_loss (line 349) | def create_loss(args): function create_model_and_transforms (line 410) | def create_model_and_transforms( function create_model_from_pretrained (line 467) | def create_model_from_pretrained( FILE: inf_clip/models/clip_arch.py class CLIPVisionCfg (line 27) | class CLIPVisionCfg: class CLIPTextCfg (line 58) | class CLIPTextCfg: function get_cast_dtype (line 86) | def get_cast_dtype(precision: str): function get_input_dtype (line 95) | def get_input_dtype(precision: str): function _build_vision_tower (line 104) | def _build_vision_tower( function _build_text_tower (line 173) | def _build_text_tower( class CLIP (line 220) | class CLIP(nn.Module): method __init__ (line 224) | def __init__( method lock_image_tower (line 257) | def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): method set_grad_checkpointing (line 262) | def set_grad_checkpointing(self, enable=True): method encode_image (line 266) | def encode_image(self, image, normalize: bool = False): method encode_text (line 270) | def encode_text(self, text, normalize: bool = False): method get_logits (line 287) | def get_logits(self, image, text): method forward (line 296) | def forward( class CustomTextCLIP (line 319) | class CustomTextCLIP(nn.Module): method __init__ (line 323) | def __init__( method lock_image_tower (line 346) | def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): method lock_text_tower (line 350) | def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm:... method set_grad_checkpointing (line 354) | def set_grad_checkpointing(self, enable=True): method encode_image (line 358) | def encode_image(self, image, normalize: bool = False): method encode_text (line 362) | def encode_text(self, text, normalize: bool = False): method get_logits (line 366) | def get_logits(self, image, text): method forward (line 375) | def forward( function convert_weights_to_lp (line 398) | def convert_weights_to_lp(model: nn.Module, dtype=torch.float16): function convert_to_custom_text_state_dict (line 432) | def convert_to_custom_text_state_dict(state_dict: dict): function build_model_from_openai_state_dict (line 450) | def build_model_from_openai_state_dict( function trace_model (line 509) | def trace_model(model, batch_size=256, device=torch.device('cpu')): function resize_pos_embed (line 525) | def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', ... function resize_text_pos_embed (line 559) | def resize_text_pos_embed(state_dict, model, interpolation: str = 'linea... function get_model_preprocess_cfg (line 591) | def get_model_preprocess_cfg(model): function set_model_preprocess_cfg (line 608) | def set_model_preprocess_cfg(model, preprocess_cfg: Dict[str, Any]): function get_model_tokenize_cfg (line 615) | def get_model_tokenize_cfg(model): FILE: inf_clip/models/coca_arch.py class MultimodalCfg (line 47) | class MultimodalCfg(CLIPTextCfg): function _build_text_decoder_tower (line 55) | def _build_text_decoder_tower( function _token_to_tensor (line 81) | def _token_to_tensor(token_id, device: str = "cpu") -> torch.Tensor: class CoCa (line 89) | class CoCa(nn.Module): method __init__ (line 92) | def __init__( method set_grad_checkpointing (line 146) | def set_grad_checkpointing(self, enable: bool = True): method _encode_image (line 151) | def _encode_image(self, images, normalize: bool = True): method _encode_text (line 156) | def _encode_text(self, text, normalize: bool = True): method encode_image (line 161) | def encode_image(self, images, normalize: bool = True): method encode_text (line 165) | def encode_text(self, text, normalize: bool = True): method forward (line 169) | def forward( method generate (line 204) | def generate( method _generate_beamsearch (line 331) | def _generate_beamsearch( function prepare_inputs_for_generation (line 481) | def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **... FILE: inf_clip/models/hf_model.py class BaseModelOutput (line 22) | class BaseModelOutput: class PretrainedConfig (line 26) | class PretrainedConfig: function _camel2snake (line 33) | def _camel2snake(s): function register_pooler (line 41) | def register_pooler(cls): class MeanPooler (line 48) | class MeanPooler(nn.Module): method forward (line 51) | def forward(self, x: BaseModelOutput, attention_mask: TensorType): class MaxPooler (line 57) | class MaxPooler(nn.Module): method forward (line 60) | def forward(self, x: BaseModelOutput, attention_mask: TensorType): class ClsPooler (line 66) | class ClsPooler(nn.Module): method __init__ (line 69) | def __init__(self, use_pooler_output=True): method forward (line 74) | def forward(self, x: BaseModelOutput, attention_mask: TensorType): class ClsLastHiddenStatePooler (line 85) | class ClsLastHiddenStatePooler(nn.Module): method __init__ (line 90) | def __init__(self): method forward (line 94) | def forward(self, x: BaseModelOutput, attention_mask: TensorType): class HFTextEncoder (line 98) | class HFTextEncoder(nn.Module): method __init__ (line 102) | def __init__( method forward (line 166) | def forward(self, x: TensorType): method lock (line 183) | def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True): method set_grad_checkpointing (line 201) | def set_grad_checkpointing(self, enable=True): method init_parameters (line 204) | def init_parameters(self): FILE: inf_clip/models/lit_arch.py class LiTVisionCfg (line 14) | class LiTVisionCfg: class LiTTextCfg (line 45) | class LiTTextCfg: class LiT (line 73) | class LiT(nn.Module): method __init__ (line 77) | def __init__( method get_embed_dim (line 103) | def get_embed_dim(self): method lock_image_tower (line 106) | def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): method lock_text_tower (line 110) | def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm:... method set_grad_checkpointing (line 114) | def set_grad_checkpointing(self, enable=True): method encode_image (line 118) | def encode_image(self, image, normalize: bool = False): method encode_trunk_image (line 122) | def encode_trunk_image(self, image, normalize: bool = False): method project_image (line 127) | def project_image(self, trunk_features, normalize: bool = False): method encode_text (line 131) | def encode_text(self, text, normalize: bool = False): method get_logits (line 135) | def get_logits(self, image, text): method forward (line 144) | def forward( FILE: inf_clip/models/loss.py function gather_features (line 21) | def gather_features( function all_reduce (line 70) | def all_reduce(tensor): class ClipLoss (line 79) | class ClipLoss(nn.Module): method __init__ (line 81) | def __init__( method get_ground_truth (line 102) | def get_ground_truth(self, device, num_logits) -> torch.Tensor: method get_logits (line 115) | def get_logits(self, image_features, text_features, logit_scale): method forward (line 133) | def forward(self, image_features, text_features, logit_scale): class DiscoClipLoss (line 152) | class DiscoClipLoss(nn.Module): method __init__ (line 154) | def __init__( method get_ground_truth (line 169) | def get_ground_truth(self, device, num_logits) -> torch.Tensor: method forward (line 180) | def forward(self, image_features, text_features, logit_scale): class FlashClipLoss (line 204) | class FlashClipLoss(nn.Module): method __init__ (line 206) | def __init__( method get_ground_truth (line 221) | def get_ground_truth(self, device, num_logits) -> torch.Tensor: method forward (line 231) | def forward(self, image_features, text_features, logit_scale): class RingClipLoss (line 251) | class RingClipLoss(nn.Module): method __init__ (line 253) | def __init__( method forward (line 267) | def forward(self, image_features, text_features, logit_scale): class InfClipLoss (line 291) | class InfClipLoss(nn.Module): method __init__ (line 293) | def __init__( method forward (line 307) | def forward(self, image_features, text_features, logit_scale): class CoCaLoss (line 384) | class CoCaLoss(ClipLoss): method __init__ (line 385) | def __init__( method forward (line 410) | def forward(self, image_features, text_features, logits, labels, logit... class DistillClipLoss (line 430) | class DistillClipLoss(ClipLoss): method dist_loss (line 432) | def dist_loss(self, teacher_logits, student_logits): method forward (line 435) | def forward( function neighbour_exchange (line 469) | def neighbour_exchange(from_rank, to_rank, tensor, group=None): function neighbour_exchange_bidir (line 489) | def neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tens... class NeighbourExchange (line 522) | class NeighbourExchange(torch.autograd.Function): method forward (line 524) | def forward(ctx, from_rank, to_rank, group, tensor): method backward (line 531) | def backward(ctx, grad_output): function neighbour_exchange_with_grad (line 535) | def neighbour_exchange_with_grad(from_rank, to_rank, tensor, group=None): class NeighbourExchangeBidir (line 539) | class NeighbourExchangeBidir(torch.autograd.Function): method forward (line 541) | def forward(ctx, left_rank, right_rank, group, tensor_to_left, tensor_... method backward (line 548) | def backward(ctx, *grad_outputs): function neighbour_exchange_bidir_with_grad (line 553) | def neighbour_exchange_bidir_with_grad(left_rank, right_rank, tensor_to_... class SigLipLoss (line 557) | class SigLipLoss(nn.Module): method __init__ (line 567) | def __init__( method get_ground_truth (line 587) | def get_ground_truth(self, device, dtype, num_logits, negative_only=Fa... method get_logits (line 593) | def get_logits(self, image_features, text_features, logit_scale, logit... method _loss (line 599) | def _loss(self, image_features, text_features, logit_scale, logit_bias... method forward (line 610) | def forward(self, image_features, text_features, logit_scale, logit_bi... FILE: inf_clip/models/modified_resnet.py class Bottleneck (line 10) | class Bottleneck(nn.Module): method __init__ (line 13) | def __init__(self, inplanes, planes, stride=1): method forward (line 42) | def forward(self, x: torch.Tensor): class AttentionPool2d (line 58) | class AttentionPool2d(nn.Module): method __init__ (line 59) | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, o... method forward (line 68) | def forward(self, x): class ModifiedResNet (line 95) | class ModifiedResNet(nn.Module): method __init__ (line 103) | def __init__(self, layers, output_dim, heads, image_size=224, width=64): method _make_layer (line 132) | def _make_layer(self, planes, blocks, stride=1): method init_parameters (line 141) | def init_parameters(self): method lock (line 154) | def lock(self, unlocked_groups=0, freeze_bn_stats=False): method set_grad_checkpointing (line 162) | def set_grad_checkpointing(self, enable=True): method stem (line 166) | def stem(self, x): method forward (line 173) | def forward(self, x): FILE: inf_clip/models/pos_embed.py function get_2d_sincos_pos_embed (line 20) | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): function get_2d_sincos_pos_embed_from_grid (line 38) | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): function get_1d_sincos_pos_embed_from_grid (line 49) | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): function interpolate_pos_embed (line 75) | def interpolate_pos_embed(model, checkpoint_model): FILE: inf_clip/models/timm_model.py class TimmModel (line 28) | class TimmModel(nn.Module): method __init__ (line 32) | def __init__( method lock (line 110) | def lock(self, unlocked_groups=0, freeze_bn_stats=False): method set_grad_checkpointing (line 143) | def set_grad_checkpointing(self, enable=True): method forward_trunk (line 149) | def forward_trunk(self, x): method forward_head (line 152) | def forward_head(self, x): method forward (line 155) | def forward(self, x): FILE: inf_clip/models/tokenizer.py function default_bpe (line 27) | def default_bpe(): function bytes_to_unicode (line 32) | def bytes_to_unicode(): function get_pairs (line 54) | def get_pairs(word): function basic_clean (line 66) | def basic_clean(text): function whitespace_clean (line 72) | def whitespace_clean(text): function _clean_canonicalize (line 78) | def _clean_canonicalize(x): function _clean_lower (line 83) | def _clean_lower(x): function _clean_whitespace (line 88) | def _clean_whitespace(x): function get_clean_fn (line 93) | def get_clean_fn(type: str): function canonicalize_text (line 104) | def canonicalize_text( class SimpleTokenizer (line 133) | class SimpleTokenizer(object): method __init__ (line 134) | def __init__( method bpe (line 172) | def bpe(self, token): method encode (line 213) | def encode(self, text): method decode (line 221) | def decode(self, tokens): method __call__ (line 226) | def __call__(self, texts: Union[str, List[str]], context_length: Optio... function decode (line 271) | def decode(output_ids: torch.Tensor): function tokenize (line 276) | def tokenize(texts: Union[str, List[str]], context_length: int = DEFAULT... function random_mask_tokenize (line 280) | def random_mask_tokenize( function simple_mask_tokenize (line 309) | def simple_mask_tokenize( function syntax_mask_tokenize (line 331) | def syntax_mask_tokenize( function get_reduction_mask_fn (line 390) | def get_reduction_mask_fn(type: str): class HFTokenizer (line 403) | class HFTokenizer: method __init__ (line 406) | def __init__( method save_pretrained (line 426) | def save_pretrained(self, dest): method __call__ (line 429) | def __call__(self, texts: Union[str, List[str]], context_length: Optio... method set_language (line 456) | def set_language(self, src_lang): class SigLipTokenizer (line 463) | class SigLipTokenizer: method __init__ (line 473) | def __init__( method save_pretrained (line 497) | def save_pretrained(self, dest): method __call__ (line 500) | def __call__(self, texts: Union[str, List[str]], context_length: Optio... FILE: inf_clip/models/transform.py class PreprocessCfg (line 17) | class PreprocessCfg: method __post_init__ (line 26) | def __post_init__(self): method num_channels (line 30) | def num_channels(self): method input_size (line 34) | def input_size(self): function merge_preprocess_dict (line 40) | def merge_preprocess_dict( function merge_preprocess_kwargs (line 57) | def merge_preprocess_kwargs(base: PreprocessCfg, **kwargs): class AugmentationCfg (line 62) | class AugmentationCfg: function _setup_size (line 75) | def _setup_size(size, error_msg): class ResizeKeepRatio (line 88) | class ResizeKeepRatio: method __init__ (line 94) | def __init__( method get_params (line 116) | def get_params( method __call__ (line 144) | def __call__(self, img): method __repr__ (line 160) | def __repr__(self): function center_crop_or_pad (line 167) | def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0... class CenterCropOrPad (line 207) | class CenterCropOrPad(torch.nn.Module): method __init__ (line 219) | def __init__(self, size, fill=0): method forward (line 224) | def forward(self, img): method __repr__ (line 234) | def __repr__(self) -> str: function _convert_to_rgb (line 238) | def _convert_to_rgb(image): class color_jitter (line 242) | class color_jitter(object): method __init__ (line 246) | def __init__(self, brightness=0., contrast=0., saturation=0., hue=0., ... method __call__ (line 251) | def __call__(self, img): class gray_scale (line 258) | class gray_scale(object): method __init__ (line 262) | def __init__(self, p=0.2): method __call__ (line 267) | def __call__(self, img): function image_transform (line 274) | def image_transform( function image_transform_v2 (line 393) | def image_transform_v2( FILE: inf_clip/models/transformer.py class LayerNormFp32 (line 15) | class LayerNormFp32(nn.LayerNorm): method forward (line 18) | def forward(self, x: torch.Tensor): class LayerNorm (line 24) | class LayerNorm(nn.LayerNorm): method forward (line 27) | def forward(self, x: torch.Tensor): class QuickGELU (line 33) | class QuickGELU(nn.Module): method forward (line 35) | def forward(self, x: torch.Tensor): class LayerScale (line 39) | class LayerScale(nn.Module): method __init__ (line 40) | def __init__(self, dim, init_values=1e-5, inplace=False): method forward (line 45) | def forward(self, x): class PatchDropout (line 49) | class PatchDropout(nn.Module): method __init__ (line 54) | def __init__(self, prob, exclude_first_token=True): method forward (line 60) | def forward(self, x): class Attention (line 89) | class Attention(nn.Module): method __init__ (line 90) | def __init__( method forward (line 132) | def forward(self, x, attn_mask: Optional[torch.Tensor] = None): class AttentionalPooler (line 187) | class AttentionalPooler(nn.Module): method __init__ (line 188) | def __init__( method forward (line 202) | def forward(self, x: torch.Tensor): class ResidualAttentionBlock (line 210) | class ResidualAttentionBlock(nn.Module): method __init__ (line 211) | def __init__( method attention (line 239) | def attention( method forward (line 254) | def forward( class CustomResidualAttentionBlock (line 268) | class CustomResidualAttentionBlock(nn.Module): method __init__ (line 269) | def __init__( method get_reference_weight (line 306) | def get_reference_weight(self): method forward (line 309) | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] =... function _expand_token (line 315) | def _expand_token(token, batch_size: int): class Transformer (line 319) | class Transformer(nn.Module): method __init__ (line 320) | def __init__( method get_cast_dtype (line 350) | def get_cast_dtype(self) -> torch.dtype: method forward (line 355) | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] =... class CustomTransformer (line 369) | class CustomTransformer(nn.Module): method __init__ (line 371) | def __init__( method get_cast_dtype (line 412) | def get_cast_dtype(self) -> torch.dtype: method forward (line 418) | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] =... class VisionTransformer (line 434) | class VisionTransformer(nn.Module): method __init__ (line 437) | def __init__( method lock (line 541) | def lock(self, unlocked_groups=0, freeze_bn_stats=False): method init_parameters (line 574) | def init_parameters(self): method set_grad_checkpointing (line 595) | def set_grad_checkpointing(self, enable=True): method _global_pool (line 598) | def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.T... method forward (line 608) | def forward(self, x: torch.Tensor): function text_global_pool (line 653) | def text_global_pool(x, text: Optional[torch.Tensor] = None, pool_type: ... class TextTransformer (line 668) | class TextTransformer(nn.Module): method __init__ (line 671) | def __init__( method init_parameters (line 731) | def init_parameters(self): method set_grad_checkpointing (line 755) | def set_grad_checkpointing(self, enable=True): method build_causal_mask (line 758) | def build_causal_mask(self): method build_cls_mask (line 766) | def build_cls_mask(self, text, cast_dtype: torch.dtype): method forward (line 775) | def forward(self, text): class MultimodalTransformer (line 812) | class MultimodalTransformer(Transformer): method __init__ (line 813) | def __init__( method init_parameters (line 856) | def init_parameters(self): method build_attention_mask (line 874) | def build_attention_mask(self): method forward (line 882) | def forward(self, image_embs, text_embs): method set_grad_checkpointing (line 907) | def set_grad_checkpointing(self, enable=True): FILE: inf_clip/openai.py function list_openai_models (line 20) | def list_openai_models() -> List[str]: function load_openai_model (line 25) | def load_openai_model( FILE: inf_clip/pretrained.py function _pcfg (line 34) | def _pcfg(url='', hf_hub='', **kwargs): function _slpcfg (line 47) | def _slpcfg(url='', hf_hub='', **kwargs): function _apcfg (line 60) | def _apcfg(url='', hf_hub='', **kwargs): function _mccfg (line 73) | def _mccfg(url='', hf_hub='', **kwargs): function _clean_tag (line 519) | def _clean_tag(tag: str): function list_pretrained (line 524) | def list_pretrained(as_str: bool = False): function list_pretrained_models_by_tag (line 531) | def list_pretrained_models_by_tag(tag: str): function list_pretrained_tags_by_model (line 541) | def list_pretrained_tags_by_model(model: str): function is_pretrained_cfg (line 549) | def is_pretrained_cfg(model: str, tag: str): function get_pretrained_cfg (line 555) | def get_pretrained_cfg(model: str, tag: str): function get_pretrained_url (line 562) | def get_pretrained_url(model: str, tag: str): function download_pretrained_from_url (line 567) | def download_pretrained_from_url( function has_hf_hub (line 613) | def has_hf_hub(necessary=False): function download_pretrained_from_hf (line 621) | def download_pretrained_from_hf( function download_pretrained (line 632) | def download_pretrained( function load_big_vision_weights (line 664) | def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str): function convert_mobile_clip_state_dict (line 793) | def convert_mobile_clip_state_dict(model: CustomTextCLIP, state_dict, fa... function convert_state_dict (line 834) | def convert_state_dict(model: Union[CustomTextCLIP, CLIP], state_dict): FILE: inf_clip/train/data.py class CsvDataset (line 24) | class CsvDataset(Dataset): method __init__ (line 25) | def __init__(self, input_filename, transforms, img_key, caption_key, s... method __len__ (line 36) | def __len__(self): method __getitem__ (line 39) | def __getitem__(self, idx): class SharedEpoch (line 45) | class SharedEpoch: method __init__ (line 46) | def __init__(self, epoch: int = 0): method set_value (line 49) | def set_value(self, epoch): method get_value (line 52) | def get_value(self): class DataInfo (line 57) | class DataInfo: method set_epoch (line 62) | def set_epoch(self, epoch): function expand_urls (line 69) | def expand_urls(urls, weights=None): function get_dataset_size (line 91) | def get_dataset_size(shards): function get_imagenet (line 113) | def get_imagenet(args, preprocess_fns, split): function count_samples (line 160) | def count_samples(dataloader): function filter_no_caption_or_no_image (line 170) | def filter_no_caption_or_no_image(sample): function log_and_continue (line 176) | def log_and_continue(exn): function group_by_keys_nothrow (line 182) | def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes... function tarfile_to_samples_nothrow (line 214) | def tarfile_to_samples_nothrow(src, handler=log_and_continue): function pytorch_worker_seed (line 222) | def pytorch_worker_seed(increment=0): function json_fetch (line 236) | def json_fetch(data, key='caption'): class detshuffle2 (line 262) | class detshuffle2(wds.PipelineStage): method __init__ (line 263) | def __init__( method run (line 275) | def run(self, src): class ResampledShards2 (line 294) | class ResampledShards2(IterableDataset): method __init__ (line 297) | def __init__( method __iter__ (line 324) | def __iter__(self): function get_wds_dataset (line 348) | def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False... function get_csv_dataset (line 468) | def get_csv_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=No... class SyntheticDataset (line 498) | class SyntheticDataset(Dataset): method __init__ (line 500) | def __init__( method __len__ (line 516) | def __len__(self): method __getitem__ (line 519) | def __getitem__(self, idx): function get_synthetic_dataset (line 525) | def get_synthetic_dataset(args, preprocess_fn, is_train, epoch=0, tokeni... function get_dataset_fn (line 548) | def get_dataset_fn(data_path, dataset_type): function get_data (line 568) | def get_data(args, preprocess_fns, epoch=0, tokenizer=None): FILE: inf_clip/train/engine.py function accuracy (line 28) | def accuracy(output, target, topk=(1,)): function get_clip_metrics (line 34) | def get_clip_metrics(image_features, text_features, logit_scale): function maybe_compute_generative_loss (line 54) | def maybe_compute_generative_loss(model_out): function get_memory (line 61) | def get_memory(): function seconds_to_hms (line 70) | def seconds_to_hms(seconds): function cal_grad_norm (line 77) | def cal_grad_norm(model): function assign_learning_rate (line 87) | def assign_learning_rate(optimizer, new_lr): function _warmup_lr (line 92) | def _warmup_lr(base_lr, warmup_length, step): function const_lr (line 96) | def const_lr(optimizer, base_lr, warmup_length, steps): function const_lr_cooldown (line 107) | def const_lr_cooldown(optimizer, base_lr, warmup_length, steps, cooldown... function cosine_lr (line 126) | def cosine_lr(optimizer, base_lr, warmup_length, steps): function postprocess_clip_output (line 139) | def postprocess_clip_output(model_out): function unwrap_model (line 147) | def unwrap_model(model): function backward (line 154) | def backward(total_loss, scaler): class AverageMeter (line 161) | class AverageMeter(object): method __init__ (line 164) | def __init__(self): method reset (line 167) | def reset(self): method update (line 173) | def update(self, val, n=1): class GradientAccum (line 180) | class GradientAccum: method __init__ (line 182) | def __init__(self, model, loss, scaler, autocast, input_dtype, device): method clear (line 203) | def clear(self): method clear_state (line 208) | def clear_state(self): method accum_inference (line 216) | def accum_inference(self, images, texts): method accum_forward_backward (line 246) | def accum_forward_backward(self): class GradientCache (line 292) | class GradientCache: method __init__ (line 294) | def __init__(self, model, loss, scaler, autocast, input_dtype, device): method clear (line 315) | def clear(self): method clear_state (line 320) | def clear_state(self): method forward_backward (line 327) | def forward_backward(self, images, texts): method accum_inference (line 345) | def accum_inference(self, images, texts): method accum_forward_backward (line 376) | def accum_forward_backward(self): function train_one_epoch (line 432) | def train_one_epoch(start_timestamp, model, data, loss, epoch, optimizer... function evaluate (line 573) | def evaluate(model, data, epoch, args, tb_writer=None, tokenizer=None): function zero_shot_run (line 682) | def zero_shot_run(model, classifier, dataloader, args): function zero_shot_eval (line 709) | def zero_shot_eval(model, data, epoch, args, tokenizer=None): FILE: inf_clip/train/main.py function random_seed (line 42) | def random_seed(seed=42, rank=0): function natural_key (line 50) | def natural_key(string_): function copy_codebase (line 55) | def copy_codebase(args): function prepare_logging (line 72) | def prepare_logging(args): function get_latest_checkpoint (line 127) | def get_latest_checkpoint(path: str, remote : bool): function prepare_resuming (line 143) | def prepare_resuming(args): function prepare_remote_sync (line 178) | def prepare_remote_sync(args): function prepare_model (line 205) | def prepare_model(args, device): function prepare_optimizer_scaler (line 308) | def prepare_optimizer_scaler(args, model): function prepare_scheduler (line 360) | def prepare_scheduler(args, optimizer, num_batches): function main (line 383) | def main(args): FILE: inf_clip/train/optims.py class ScalingViTAdafactor (line 8) | class ScalingViTAdafactor(Optimizer): method __init__ (line 18) | def __init__( method _get_lr (line 52) | def _get_lr(param_group, param_state): method _get_options (line 63) | def _get_options(param_group, param_shape): method _rms (line 69) | def _rms(tensor): method _approx_sq_grad (line 73) | def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col): method step (line 81) | def step(self, closure=None): class Lion (line 179) | class Lion(Optimizer): method __init__ (line 184) | def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0): method step (line 206) | def step(self, closure=None): FILE: inf_clip/train/params.py function get_default_params (line 9) | def get_default_params(model_name): class ParseKwargs (line 18) | class ParseKwargs(argparse.Action): method __call__ (line 19) | def __call__(self, parser, namespace, values, option_string=None): function parse_args (line 30) | def parse_args(args): function create_deepspeed_config (line 507) | def create_deepspeed_config(args): FILE: inf_clip/train/utils.py function setup_logging (line 19) | def setup_logging(log_file, level, include_host=False): function remote_sync_s3 (line 43) | def remote_sync_s3(local_dir, remote_dir): function remote_sync_fsspec (line 54) | def remote_sync_fsspec(local_dir, remote_dir): function remote_sync (line 79) | def remote_sync(local_dir, remote_dir, protocol): function keep_running_remote_sync (line 90) | def keep_running_remote_sync(sync_every, local_dir, remote_dir, protocol): function start_sync_process (line 96) | def start_sync_process(sync_every, local_dir, remote_dir, protocol): function pt_save (line 102) | def pt_save(pt_obj, file_path): function pt_load (line 108) | def pt_load(file_path, map_location=None): function check_exists (line 117) | def check_exists(file_path): function get_autocast (line 126) | def get_autocast(precision): function is_global_master (line 140) | def is_global_master(args): function is_local_master (line 144) | def is_local_master(args): function is_master (line 148) | def is_master(args, local=False): function is_using_horovod (line 152) | def is_using_horovod(): function is_using_distributed (line 163) | def is_using_distributed(): function world_info_from_env (line 171) | def world_info_from_env(): function init_distributed_device (line 191) | def init_distributed_device(args): function broadcast_object (line 254) | def broadcast_object(args, obj, src=0): function all_gather_object (line 267) | def all_gather_object(args, obj, dst=0): FILE: inf_clip/utils.py function freeze_batch_norm_2d (line 9) | def freeze_batch_norm_2d(module, module_match={}, name=''): function _ntuple (line 49) | def _ntuple(n): function replace_linear (line 65) | def replace_linear(model, linear_replacement, include_modules=['c_fc', '... function convert_int8_model_to_inference_mode (line 84) | def convert_int8_model_to_inference_mode(model): FILE: inf_clip/zero_shot_classifier.py function batched (line 9) | def batched(iterable, n): function build_zero_shot_classifier (line 21) | def build_zero_shot_classifier( function build_zero_shot_classifier_legacy (line 71) | def build_zero_shot_classifier_legacy( FILE: tests/example.py function create_cl_tensors (line 9) | def create_cl_tensors(rank, world_size):