Skip to content

How to use a SAE for analyzing a dataset

Once you have trained a SAE and saved it as an artifact in your Weights & Biases project, you can use it to analyze a dataset. This section will guide you through the process.

📲 Load the SAE artifact

First, import the necessary libraries and load the SAE artifact.

from sentence_transformers import SentenceTransformer
from transformers import pipeline
from config import get_default_cfg
from sae import JumpReLUSAE
import wandb
import os
import json
import torch
import numpy as np 

sbert = SentenceTransformer('sentence-transformers/paraphrase-mpnet-base-v2')
distilbert = pipeline("fill-mask", model="distilbert/distilbert-base-cased")
cfg = get_default_cfg()

run = wandb.init()
artifact = run.use_artifact('path_to_your_wandb_artifact', type='model')
artifact_dir = artifact.download()
config_path = os.path.join(artifact_dir, 'config.json')
with open(config_path, 'r') as f:
    config = json.load(f)

if "dtype" in config and isinstance(config["dtype"], str):
    if config["dtype"] == "torch.float32":
        config["dtype"] = torch.float32
    elif config["dtype"] == "torch.float16":
        config["dtype"] = torch.float16

sae = JumpReLUSAE(config).to(config["device"])
sae.load_state_dict(torch.load(os.path.join(artifact_dir, 'sae.pt')))

The script above will use the configuration stored in the artifact to load the model. This way, you will not need to worry about defining the model architecture again.

📊 Load the dataset

Next, you need to create the dataset you want to analyze. It needs to be a custom class which inherits from the IterableDataset class. The following snippet shows an example of how to do it:

from torch.utils.data import DataLoader, IterableDataset
from datasets import load_dataset

class HFDatasetWrapper(IterableDataset):
    def __init__(self, hf_dataset):
        self.hf_dataset = hf_dataset

    def __iter__(self):
        for item in self.hf_dataset:
            if item is not None and item.get("text"):
                yield item

def collate_fn_skip_none(batch):
    return [item for item in batch if item is not None]

hf_dataset = load_dataset("UniverseTBD/arxiv-astro-abstracts-all", split="train", streaming=True)
dataset = HFDatasetWrapper(hf_dataset)

num_examples = config["num_examples"]
device = config["device"]
dict_size = config["dict_size"]

sbert = sbert.to(device)
sae = sae.to(device)
sae.eval()

dataloader = DataLoader(dataset, batch_size=64, collate_fn=collate_fn_skip_none)

With this code, you are now ready to analyze the dataset.

🔍 Get feature density histogram

One of the indicators of whether the SAE has a good sparsity is the feature density histogram. To get the feature activations, you can run the following snippet:

feature_count = torch.zeros(dict_size, device=device)
processed = 0

for batch in dataloader:
    if processed >= num_examples:
        break

    texts = [item["text"] for item in batch]
    embeddings = sbert.encode(texts, convert_to_tensor=True, device=device)

    with torch.no_grad():
        sae_out = sae(embeddings)

    feature_acts = sae_out["feature_acts"]
    batch_count = (feature_acts > 0).float().sum(dim=0)
    feature_count += batch_count
    processed += len(texts)

    if processed % 20000 == 0:
        print(f"[INFO] {processed} examples processed.")

    temp = get_gpu_temperature()
    if temp is not None and temp >= 74:
        wait_for_gpu_cooling(threshold=74, resume_temp=60, check_interval=10)

feature_fire_rate = 100 * feature_count / num_examples
np.save("fire_rate_astro.npy", feature_fire_rate.cpu().numpy())

the .npy file will contain the number of times each feature was activated. If you want to plot the histogram, you can use the following code:

import matplotlib.pyplot as plt 
log_feature_fire_rate = torch.log10((feature_fire_rate / 100) + 1e-10).cpu().numpy()

plt.style.use('default')
plt.hist(log_feature_fire_rate, bins=50, color='tab:blue')
plt.xlabel("Log10 Feature density")
plt.ylabel("Number of features")
plt.title("Log Feature Fire Rate Distribution (csLG)")
plt.xlim(-10, 0)
plt.show()

Features with a 0% fire rate will be placed in the -10 value in the x-axis, and features with a 100% fire rate will be placed in the 0 value. For further details on how to interpret the histogram, I highly recommend reading this post.

📑 Get the top-10 activating texts

To generate the descriptions for the features, first you need to get examples of the texts which most activated each feature. You can do this by running the following code:

import heapq

num_features = config["dict_size"]
top_activations = [[] for _ in range(num_features)]

for i, example in enumerate(dataset):
    if i >= num_examples: break
    text = example["abstr"]
    embedding = sbert.encode(text, convert_to_tensor=True).squeeze(0).to(config["device"])
    with torch.no_grad():
        sae_out = sae(embedding)
    feature_acts = sae_out["feature_acts"]

    for j in range(num_features):
        activation_value = feature_acts[j].item()
        heap = top_activations[j]
        if len(heap) < 10: # Change this to the number of examples you want 
            heapq.heappush(heap, (activation_value, text))
        else:
            heapq.heappushpop(heap, (activation_value, text))

    if i % 5000 == 0:
        print(f"Processed {i} examples")


top_activations = [sorted(heap, key=lambda x: x[0], reverse=True) for heap in top_activations]
with open("top_activations_astro.json", "w", encoding="utf-8") as f:
    json.dump(top_activations, f, indent=2, ensure_ascii=False)

Once you have the top-k activating texts, you are now redy to generate the descriptions.