Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ dependencies = [
"more-itertools>=10.7.0",
"json-repair>=0.44.1",
]
requires-python = "==3.12.*"
requires-python = ">=3.11,<3.13"
readme = "README.md"

[[project.authors]]
Expand Down
60 changes: 21 additions & 39 deletions src/lm_saes/backend/attribution.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations

from dataclasses import dataclass, field, replace
from dataclasses import dataclass, field
from typing import (
TYPE_CHECKING,
Any,
Generic,
Iterator,
Sequence,
TypeVar,
cast,
)

Expand All @@ -16,7 +17,7 @@

from lm_saes.backend.hooks import replace_biases_with_leaves
from lm_saes.backend.indexed_tensor import Dimension, NodeIndexedMatrix, NodeIndexedVector, NodeInfo
from lm_saes.utils.distributed import DimMap
from lm_saes.utils.distributed import DimMap, full_tensor
from lm_saes.utils.distributed.ops import maybe_local_map, nonzero, searchsorted
from lm_saes.utils.misc import ensure_tokenized
from lm_saes.utils.timer import timer
Expand All @@ -36,14 +37,17 @@ class NodeInfoRef(NodeInfo):
ref: torch.Tensor


class NodeInfoQueue[T: NodeInfo]:
def __init__(self, node_infos: Sequence[T] = []):
NodeInfoT = TypeVar("NodeInfoT", bound=NodeInfo)


class NodeInfoQueue(Generic[NodeInfoT]):
def __init__(self, node_infos: Sequence[NodeInfoT] = ()):
self.queue = list(node_infos)

def enqueue(self, node_info: Sequence[T]):
def enqueue(self, node_info: Sequence[NodeInfoT]):
self.queue.extend(node_info)

def dequeue(self, batch_size: int) -> Sequence[T]:
def dequeue(self, batch_size: int) -> Sequence[NodeInfoT]:
accumulated = 0
results = []
while accumulated < batch_size and len(self.queue) > 0:
Expand All @@ -56,33 +60,11 @@ def dequeue(self, batch_size: int) -> Sequence[T]:
accumulated += len(results[-1])
return results

def iter(self, batch_size: int) -> Iterator[Sequence[T]]:
def iter(self, batch_size: int) -> Iterator[Sequence[NodeInfoT]]:
while len(self.queue) > 0:
yield self.dequeue(batch_size)


class NodeInfoSet[T: NodeInfo]:
def __init__(self, node_infos: Sequence[T] = []):
self.node_dict: dict[Any, T] = {}
self.extend(node_infos)

def extend(self, node_infos: Sequence[T]):
for node_info in node_infos:
if node_info.key not in self.node_dict:
self.node_dict[node_info.key] = replace(node_info)
else:
self.node_dict[node_info.key].indices = torch.cat(
[self.node_dict[node_info.key].indices, node_info.indices],
dim=0,
)

def __len__(self) -> int:
return sum(len(node_info) for node_info in self.node_dict.values())

def to_list(self) -> list[T]:
return list(self.node_dict.values())


@dataclass
class AttributionResult:
activations: NodeIndexedVector
Expand Down Expand Up @@ -206,8 +188,8 @@ def greedily_collect_attribution(
attribution[Dimension.from_node_infos(target_batch), None] = torch.cat(
[
einops.einsum(
value[: root.shape[0]],
grad[: root.shape[0]],
value.detach()[: root.shape[0]],
grad.detach()[: root.shape[0]],
"batch n_elements ..., batch n_elements ... -> batch n_elements",
)
for value, grad in zip(values(all_sources), grads(all_sources))
Expand Down Expand Up @@ -246,8 +228,8 @@ def greedily_collect_attribution(
torch.cat(
[
einops.einsum(
value[: root.shape[0]],
grad[: root.shape[0]],
value.detach()[: root.shape[0]],
grad.detach()[: root.shape[0]],
"batch n_elements ..., batch n_elements ... -> batch n_elements",
)
for value, grad in zip(values(all_sources), grads(all_sources))
Expand Down Expand Up @@ -565,9 +547,9 @@ def attribute(
sources_dimension = Dimension.from_node_infos(sources)
attribution = attribution[None, sources_dimension + collected_intermediates]

intermediate_ref_map = {node_info.key: node_info.ref for node_info, _ in intermediates}
intermediate_ref_map = {node_info.key: node_info.ref.detach() for node_info, _ in intermediates}
activations = torch.cat(
[node_info.ref[0, *node_info.indices.unbind(dim=1)] for node_info in targets]
[node_info.ref.detach()[0, *node_info.indices.unbind(dim=1)] for node_info in targets]
+ [
intermediate_ref_map[node_info.key][0, *node_info.indices.unbind(dim=1)]
for node_info in collected_intermediates
Expand All @@ -581,13 +563,13 @@ def attribute(
dimensions=(Dimension.from_node_infos(targets) + collected_intermediates + Dimension.from_node_infos(sources),),
)

prompt_token_ids = tokens.detach().cpu().tolist()
logit_token_ids = top_idx.detach().cpu().tolist()
prompt_token_ids = full_tensor(tokens).detach().cpu().tolist()
logit_token_ids = full_tensor(top_idx).detach().cpu().tolist()

return AttributionResult(
activations=activations_vec,
attribution=attribution,
logits=batch_logits[:, -1, top_idx],
logits=batch_logits[:, -1, top_idx].detach(),
probs=top_p,
prompt_token_ids=prompt_token_ids,
prompt_tokens=[model.tokenizer.decode([token_id]) for token_id in prompt_token_ids],
Expand Down
15 changes: 13 additions & 2 deletions src/lm_saes/backend/indexed_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,9 @@ def __add__(self, other: "Dimension") -> Self:
for key in all_keys
}

ret = self.__class__._from_node_mappings(node_mappings=node_mappings, device=self.device, mapper=self.mapper)
ret = self.__class__._from_node_mappings(
node_mappings=node_mappings, device=self.device, mapper=self.mapper, device_mesh=self.device_mesh
)
assert len(ret) == len(self) + len(other), "Dimension length mismatch"
return ret

Expand Down Expand Up @@ -619,7 +621,16 @@ def topk(self, k: int, ignore_dimension: Dimension | None = None):
data = self.data
if ignore_indices is not None:
data = data.clone()
data[ignore_indices] = float("-inf")
if not isinstance(data, DTensor):
data[ignore_indices] = float("-inf")
else:
data = DimMap({}).distribute(
full_tensor(data).index_put(
(full_tensor(ignore_indices),),
torch.tensor(float("-inf"), device=data.device),
),
device_mesh=data.device_mesh,
)
topk_values, topk_indices = torch.topk(data, k=k, dim=0)
return topk_values, self.dimensions[0].offsets_to_nodes(topk_indices)

Expand Down
Loading
Loading