SYMBOL INDEX (136 symbols across 17 files) FILE: sae_training/activations_store.py class ActivationsStore (line 11) | class ActivationsStore: method __init__ (line 16) | def __init__( method get_batch_tokens (line 62) | def get_batch_tokens(self): method get_activations (line 139) | def get_activations(self, batch_tokens, get_loss=False): method get_buffer (line 172) | def get_buffer(self, n_batches_in_buffer): method get_data_loader (line 278) | def get_data_loader(self,) -> DataLoader: method next_batch (line 339) | def next_batch(self): FILE: sae_training/config.py class RunnerConfig (line 11) | class RunnerConfig(ABC): method __post_init__ (line 54) | def __post_init__(self): class LanguageModelSAERunnerConfig (line 63) | class LanguageModelSAERunnerConfig(RunnerConfig): method __post_init__ (line 113) | def __post_init__(self): class CacheActivationsRunnerConfig (line 187) | class CacheActivationsRunnerConfig(RunnerConfig): method __post_init__ (line 198) | def __post_init__(self): FILE: sae_training/geom_median/src/geom_median/numpy/main.py function compute_geometric_median (line 7) | def compute_geometric_median( FILE: sae_training/geom_median/src/geom_median/numpy/utils.py function check_list_of_array_format (line 4) | def check_list_of_array_format(points): function check_list_of_list_of_array_format (line 7) | def check_list_of_list_of_array_format(points): function check_shapes_compatibility (line 13) | def check_shapes_compatibility(list_of_arrays, i): FILE: sae_training/geom_median/src/geom_median/numpy/weiszfeld_array.py function geometric_median_array (line 4) | def geometric_median_array(points, weights, eps=1e-6, maxiter=100, ftol=... function geometric_median_per_component (line 42) | def geometric_median_per_component(points, weights, eps=1e-6, maxiter=10... function weighted_average (line 67) | def weighted_average(points, weights): function geometric_median_objective (line 77) | def geometric_median_objective(median, points, weights): FILE: sae_training/geom_median/src/geom_median/numpy/weiszfeld_list_of_array.py function geometric_median_list_of_array (line 4) | def geometric_median_list_of_array(points, weights, eps=1e-6, maxiter=10... function weighted_average (line 42) | def weighted_average(points, weights): function geometric_median_objective (line 45) | def geometric_median_objective(median, points, weights): function l2distance (line 49) | def l2distance(p1, p2): function subtract (line 52) | def subtract(p1, p2): FILE: sae_training/geom_median/src/geom_median/torch/main.py function compute_geometric_median (line 7) | def compute_geometric_median( FILE: sae_training/geom_median/src/geom_median/torch/utils.py function check_list_of_array_format (line 4) | def check_list_of_array_format(points): function check_list_of_list_of_array_format (line 7) | def check_list_of_list_of_array_format(points): function check_shapes_compatibility (line 13) | def check_shapes_compatibility(list_of_arrays, i): FILE: sae_training/geom_median/src/geom_median/torch/weiszfeld_array.py function geometric_median_array (line 8) | def geometric_median_array(points, weights, eps=1e-6, maxiter=100, ftol=... function geometric_median_per_component (line 53) | def geometric_median_per_component(points, weights, eps=1e-6, maxiter=10... function weighted_average (line 81) | def weighted_average(points, weights): function geometric_median_objective (line 89) | def geometric_median_objective(median, points, weights): FILE: sae_training/geom_median/src/geom_median/torch/weiszfeld_list_of_array.py function geometric_median_list_of_array (line 5) | def geometric_median_list_of_array(points, weights, eps=1e-6, maxiter=10... function weighted_average_component (line 49) | def weighted_average_component(points, weights): function weighted_average (line 55) | def weighted_average(points, weights): function geometric_median_objective (line 60) | def geometric_median_objective(median, points, weights): function l2distance (line 64) | def l2distance(p1, p2): FILE: sae_training/optim.py function get_scheduler (line 15) | def get_scheduler( FILE: sae_training/sparse_autoencoder.py class SparseAutoencoder (line 23) | class SparseAutoencoder(HookedRootModule): method __init__ (line 27) | def __init__( method forward (line 107) | def forward(self, x, dead_neuron_mask = None, mse_target=None): method get_sparse_connection_loss (line 183) | def get_sparse_connection_loss(self): method initialize_b_dec (line 190) | def initialize_b_dec(self, activation_store): method initialize_b_dec_with_geometric_median (line 202) | def initialize_b_dec_with_geometric_median(self, activation_store): method initialize_b_dec_with_mean (line 245) | def initialize_b_dec_with_mean(self, activation_store): method resample_neurons_l2 (line 278) | def resample_neurons_l2( method resample_neurons_anthropic (line 367) | def resample_neurons_anthropic( method collect_anthropic_resampling_losses (line 459) | def collect_anthropic_resampling_losses(self, model, activation_store): method get_test_loss (line 512) | def get_test_loss(self, batch_tokens, model): method set_decoder_norm_to_unit_norm (line 559) | def set_decoder_norm_to_unit_norm(self): method remove_gradient_parallel_to_decoder_directions (line 563) | def remove_gradient_parallel_to_decoder_directions(self): method save_model (line 581) | def save_model(self, path: str): method load_from_pretrained (line 607) | def load_from_pretrained(cls, path: str): method get_name (line 653) | def get_name(self): FILE: sae_training/train_sae_on_language_model.py function train_sae_on_language_model (line 16) | def train_sae_on_language_model( function run_evals (line 293) | def run_evals(sparse_autoencoder: SparseAutoencoder, activation_store: A... function get_recons_loss (line 410) | def get_recons_loss(sparse_autoencoder, model, activation_store, batch_t... function mean_ablate_hook (line 441) | def mean_ablate_hook(mlp_post, hook): function zero_ablate_hook (line 446) | def zero_ablate_hook(mlp_post, hook): function kl_divergence_attention (line 451) | def kl_divergence_attention(y_true, y_pred): FILE: sae_training/utils.py class LMSparseAutoencoderSessionloader (line 11) | class LMSparseAutoencoderSessionloader(): method __init__ (line 19) | def __init__(self, cfg: LanguageModelSAERunnerConfig): method load_session (line 23) | def load_session(self) -> Tuple[HookedTransformer, SparseAutoencoder, ... method load_session_from_pretrained (line 36) | def load_session_from_pretrained(cls, path: str) -> Tuple[HookedTransf... method get_model (line 53) | def get_model(self, model_name: str): method initialize_sparse_autoencoder (line 64) | def initialize_sparse_autoencoder(self, cfg: LanguageModelSAERunnerCon... method get_activations_loader (line 73) | def get_activations_loader(self, cfg: LanguageModelSAERunnerConfig, mo... function shuffle_activations_pairwise (line 84) | def shuffle_activations_pairwise(datapath: str, buffer_idx_range: Tuple[... FILE: transcoder_circuits/circuit_analysis.py function get_attn_head_contribs (line 13) | def get_attn_head_contribs(model, cache, layer, range_normal): function get_transcoder_ixg (line 37) | def get_transcoder_ixg(transcoder, cache, range_normal, input_layer, inp... function get_mean_ixg (line 57) | def get_mean_ixg(model, tokens_arr, range_transcoder, range_feature_idx,... function get_ln_constant (line 113) | def get_ln_constant(model, cache, vector, layer, token, is_ln2=False, re... class ComponentType (line 134) | class ComponentType(enum.Enum): class FeatureType (line 144) | class FeatureType(enum.Enum): class ContribType (line 149) | class ContribType(enum.Enum): class Component (line 155) | class Component: method __str__ (line 166) | def __str__(self, show_token=True): method __repr__ (line 184) | def __repr__(self): class FeatureVector (line 190) | class FeatureVector: method __post_init__ (line 208) | def __post_init__(self): method __str__ (line 214) | def __str__(self, show_full=True, show_contrib=True, show_last_token=T... method __repr__ (line 227) | def __repr__(self): function make_sae_feature_vector (line 232) | def make_sae_feature_vector(sae, feature_idx, use_encoder=True, token=-1): function get_top_transcoder_features (line 273) | def get_top_transcoder_features(model, transcoder, cache, feature_vector... function get_top_contribs (line 317) | def get_top_contribs(model, transcoders, cache, feature_vector, k=5, ign... function greedy_get_top_paths (line 399) | def greedy_get_top_paths(model, transcoders, cache, feature_vector, num_... function print_all_paths (line 436) | def print_all_paths(paths): function get_raw_top_features_among_paths (line 453) | def get_raw_top_features_among_paths(all_paths, use_tokens=True, top_k=5... class FilterType (line 494) | class FilterType(enum.Enum): class FeatureFilter (line 503) | class FeatureFilter: method match (line 522) | def match(self, feature): function flatten_nested_list (line 554) | def flatten_nested_list(x): function get_paths_via_filter (line 557) | def get_paths_via_filter(all_paths, infix_path=None, not_infix_path=None... function path_to_str (line 622) | def path_to_str(path, show_contrib=False, show_last_token=False): function paths_to_graph (line 628) | def paths_to_graph(all_paths): function add_error_nodes_to_graph (line 685) | def add_error_nodes_to_graph(model, cache, transcoders, edges, nodes, do... function sum_over_tokens (line 794) | def sum_over_tokens(edges, nodes): function layer_to_float (line 817) | def layer_to_float(feature): function nodes_to_coords (line 824) | def nodes_to_coords(nodes, y_jitter=0.3, y_mult=1.0): function get_contribs_in_graph (line 844) | def get_contribs_in_graph(edges, nodes): function plot_graph (line 859) | def plot_graph(edges, nodes, y_mult=1.0, width=800, height=600, arrow_wi... FILE: transcoder_circuits/feature_dashboards.py function get_feature_scores (line 12) | def get_feature_scores(model, encoder, tokens_arr, feature_idx, batch_si... function sample_percentiles (line 39) | def sample_percentiles(arr, num_samples): function sample_uniform (line 63) | def sample_uniform(arr, num_samples, unique=True, use_tqdm=False, only_m... function make_sequence_html (line 103) | def make_sequence_html(token_strs, scores, function get_uniform_band_examples (line 174) | def get_uniform_band_examples(scores, uniform_vals, uniform_idxs, num_ba... function display_activating_examples_dash (line 193) | def display_activating_examples_dash(model, all_tokens, scores, function get_logits_for_feature (line 221) | def get_logits_for_feature(model, sae, feature_idx, k=7): function batch_color_interpolate (line 238) | def batch_color_interpolate(scores, max_color, zero_color, scores_min=No... function display_logits_for_feature (line 250) | def display_logits_for_feature(model, sae, feature_idx, k=7): function plot_pulledback_feature (line 301) | def plot_pulledback_feature(model, feature_vector, transcoder, size=None... function get_transcoder_pullback_features (line 343) | def get_transcoder_pullback_features(model, feature_vector, transcoder, ... function display_transcoder_pullback_features (line 361) | def display_transcoder_pullback_features(model, feature_vector, transcod... function get_ov_norms_for_transcoder_feature (line 408) | def get_ov_norms_for_transcoder_feature(model, transcoder, feature_idx, ... function get_deembeddings_for_transcoder_feature (line 427) | def get_deembeddings_for_transcoder_feature(model, transcoder, feature_i... function get_deembeddings_for_feature_vector (line 449) | def get_deembeddings_for_feature_vector(model, feature_vector, k=7): function plot_deembedding_for_transcoder_feature (line 468) | def plot_deembedding_for_transcoder_feature(model, transcoder, feature_i... function display_deembeddings_for_transcoder_feature (line 483) | def display_deembeddings_for_transcoder_feature(model, transcoder, featu... function display_deembeddings_for_feature_vector (line 528) | def display_deembeddings_for_feature_vector(model, feature_vector, k=7): function display_analysis_for_transcoder_feature (line 573) | def display_analysis_for_transcoder_feature(model, transcoder, feature_i... FILE: transcoder_circuits/replacement_ctx.py class TranscoderWrapper (line 4) | class TranscoderWrapper(torch.nn.Module): method __init__ (line 5) | def __init__(self, transcoder): method forward (line 8) | def forward(self, x): class TranscoderReplacementContext (line 11) | class TranscoderReplacementContext: method __init__ (line 12) | def __init__(self, model, transcoders): method __enter__ (line 20) | def __enter__(self): method __exit__ (line 24) | def __exit__(self, exc_type, exc_value, exc_tb): class ZeroAblationWrapper (line 28) | class ZeroAblationWrapper(torch.nn.Module): method __init__ (line 29) | def __init__(self): method forward (line 31) | def forward(self, x): class ZeroAblationContext (line 34) | class ZeroAblationContext: method __init__ (line 35) | def __init__(self, model, layers): method __enter__ (line 41) | def __enter__(self): method __exit__ (line 45) | def __exit__(self, exc_type, exc_value, exc_tb):