Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions transformer_lens/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,9 @@
from rich import print as rprint
from transformers import AutoTokenizer

from transformer_lens import FactoredMatrix

CACHE_DIR = transformers.TRANSFORMERS_CACHE
USE_DEFAULT_VALUE = None


def select_compatible_kwargs(
kwargs_dict: Dict[str, Any], callable: Callable
) -> Dict[str, Any]:
Expand Down Expand Up @@ -97,8 +94,14 @@ def get_corner(tensor, n=3):
# Prints the top left corner of the tensor
if isinstance(tensor, torch.Tensor):
return tensor[tuple(slice(n) for _ in range(tensor.ndim))]
elif isinstance(tensor, FactoredMatrix):
return tensor[tuple(slice(n) for _ in range(tensor.ndim))].AB
else:
# pylint: disable=wrong-import-position
# isort: off
from transformer_lens import FactoredMatrix # Lazy import to stop circular dependencies
# isort: on
# pylint: enable=wrong-import-position
if isinstance(tensor, FactoredMatrix):
return tensor[tuple(slice(n) for _ in range(tensor.ndim))].AB


def to_numpy(tensor):
Expand Down