32 changes: 32 additions & 0 deletions tests/acceptance/test_hooked_transformer.py
@@ -233,6 +233,38 @@ def test_bloom_similarity_with_hf_model_with_kv_cache_activated():
    assert output_tf == output_hf_str


def test_bloom_similarity_with_hf_model_with_kv_cache_activated_stream():
    tf_model = HookedTransformer.from_pretrained(
        "bigscience/bloom-560m", default_prepend_bos=False, device="cpu"
    )

    hf_model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
    hf_tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")

    final_output = ""
    for result in tf_model.generate_stream(
        text,
        do_sample=False,
        use_past_kv_cache=True,
        verbose=False,
        max_new_tokens=10,
        max_tokens_per_yield=10,
    ):
        final_output += tf_model.to_string(result[0])

    hf_input_ids = hf_tokenizer(text, return_tensors="pt").input_ids
    output_hf_tokens = hf_model.generate(
        hf_input_ids,
        do_sample=False,
        max_new_tokens=10,
    )
    output_hf_str = hf_tokenizer.decode(output_hf_tokens[0], skip_special_tokens=True)

    assert (
        final_output == output_hf_str
    ), f"\nStreaming output: {final_output}\nHF output: {output_hf_str}"


def check_norm_folding(
    model_name,
    hf_model=None,
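For context, a typical consumer of the new streaming API decodes each yielded chunk as it arrives, much as the test above does. A minimal usage sketch (the prompt and generation settings here are illustrative placeholders, not part of this PR):

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("bigscience/bloom-560m", device="cpu")

prompt = "The quick brown fox"
streamed = ""
for chunk in model.generate_stream(
    prompt,
    max_new_tokens=20,
    max_tokens_per_yield=5,
    do_sample=False,
    verbose=False,
):
    # Each chunk is a [batch, n] tensor of token ids; the first chunk also contains the prompt.
    piece = model.to_string(chunk[0])
    streamed += piece
    print(piece, end="", flush=True)  # stream text to the terminal as it is produced

Because chunks after the first contain only newly generated tokens, concatenating the decoded pieces reproduces the prompt followed by the full completion, which is exactly what the test compares against Hugging Face's generate output.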
4 changes: 2 additions & 2 deletions transformer_lens/HookedEncoderDecoder.py
@@ -486,13 +486,13 @@ def generate(
            else:
                return decoder_input

    @overload
    @overload  # type: ignore[overload-overlap]
    def run_with_cache(
        self, *model_args: Any, return_cache_object: Literal[True] = True, **kwargs: Any
    ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], ActivationCache]:
        ...

    @overload
    @overload  # type: ignore[overload-overlap]
    def run_with_cache(
        self, *model_args: Any, return_cache_object: Literal[False] = False, **kwargs: Any
    ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Dict[str, torch.Tensor]]:
227 changes: 227 additions & 0 deletions transformer_lens/HookedTransformer.py
@@ -13,6 +13,7 @@

import logging
import os
from collections.abc import Generator
from typing import (
    Any,
    Dict,
@@ -2353,6 +2354,232 @@ def generate(
            else:
                return embeds

    @torch.inference_mode()
    def generate_stream(
        self,
        input: Union[str, Float[torch.Tensor, "batch pos"]] = "",
        max_new_tokens: int = 10,
        max_tokens_per_yield: int = 25,
        stop_at_eos: bool = True,
        eos_token_id: Optional[int] = None,
        do_sample: bool = True,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        temperature: float = 1.0,
        freq_penalty: float = 0.0,
        use_past_kv_cache: bool = True,
        prepend_bos: Optional[bool] = USE_DEFAULT_VALUE,
        padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE,
        return_type: Optional[str] = "input",
        verbose: bool = True,
    ) -> Generator[Union[Int[torch.Tensor, "batch pos"], str], None, None]:
        """Stream tokens from the model as they are generated.

        Sample tokens from the model until the model outputs eos_token or max_new_tokens is
        reached, yielding batches of tokens progressively during generation rather than waiting
        for the entire sequence to be generated.

        To avoid fiddling with ragged tensors, if we input a batch of text and some sequences
        finish (by producing an EOT token), we keep running the model on the entire batch, but
        throw away the output for a finished sequence and just keep adding EOTs to pad.

        This supports entering a single string, but not a list of strings - if the strings don't
        tokenize to exactly the same length, this gets messy. If that functionality is needed,
        convert them to a batch of tokens and input that instead.

        Args:
            input (Union[str, Int[torch.Tensor, "batch pos"]]): Either a batch of tokens ([batch,
                pos]) or a text string (this will be converted to a batch of tokens with batch
                size 1).
            max_new_tokens (int): Maximum number of tokens to generate.
            max_tokens_per_yield (int): Maximum number of tokens to accumulate before yielding.
                Controls how frequently the function yields tokens during generation.
            stop_at_eos (bool): If True, stop generating tokens when the model outputs eos_token.
            eos_token_id (Optional[Union[int, Sequence]]): The token ID to use for end of
                sentence. If None, use the tokenizer's eos_token_id - required if using
                stop_at_eos. It's also possible to provide a list of token IDs (not just the
                eos_token_id), in which case the generation will stop when any of them are output
                (useful e.g. for stable_lm).
            do_sample (bool): If True, sample from the model's output distribution. Otherwise, use
                greedy search (take the max logit each time).
            top_k (int): Number of tokens to sample from. If None, sample from all tokens.
            top_p (float): Probability mass to sample from. If 1.0, sample from all tokens. If
                <1.0, we take the top tokens with cumulative probability >= top_p.
            temperature (float): Temperature for sampling. Higher values will make the model more
                random (limit of temp -> 0 is just taking the top token, limit of temp -> inf is
                sampling from a uniform distribution).
            freq_penalty (float): Frequency penalty for sampling - how much to penalise previous
                tokens. Higher values will make the model more random.
            use_past_kv_cache (bool): If True, create and use cache to speed up generation.
            prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to
                prepend the BOS token to the input (applicable when input is a string). Defaults
                to None, implying usage of self.cfg.default_prepend_bos (default is True unless
                specified otherwise). Pass True or False to override the default.
            padding_side (Union[Literal["left", "right"], None], optional): Overrides
                self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple
                strings of different lengths.
            return_type (Optional[str]): The type of the output to return - either a string (str),
                a tensor of tokens (tensor) or whatever the format of the input was (input).
            verbose (bool): If True, show tqdm progress bars for generation.

        Yields:
            outputs (Union[Int[torch.Tensor, "batch pos"], str]): Batches of generated tokens,
                yielded progressively during generation. Each yield contains the tokens
                accumulated since the last yield, up to max_tokens_per_yield; the first yield also
                includes the input prompt tokens.
        """

        with utils.LocallyOverridenDefaults(
            self, prepend_bos=prepend_bos, padding_side=padding_side
        ):
            if type(input) == str:
                # If text, convert to tokens (batch_size=1)
                assert (
                    self.tokenizer is not None
                ), "Must provide a tokenizer if passing a string to the model"
                tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)
            else:
                assert isinstance(input, torch.Tensor), "Input must be a tensor when not a string"
                tokens = input

            if return_type == "input":
                if type(input) == str:
                    return_type = "str"
                else:
                    return_type = "tensor"

            assert isinstance(tokens, torch.Tensor)
            batch_size, ctx_length = tokens.shape
            device = devices.get_device_for_block_index(0, self.cfg)
            tokens = tokens.to(device)
            if use_past_kv_cache:
                past_kv_cache = HookedTransformerKeyValueCache.init_cache(
                    self.cfg, self.cfg.device, batch_size
                )
            else:
                past_kv_cache = None

            stop_tokens: List[int] = []
            eos_token_for_padding = 0
            assert self.tokenizer is not None
            if stop_at_eos:
                tokenizer_has_eos_token = (
                    self.tokenizer is not None and self.tokenizer.eos_token_id is not None
                )
                if eos_token_id is None:
                    assert (
                        tokenizer_has_eos_token
                    ), "Must pass a eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id"

                    eos_token_id = self.tokenizer.eos_token_id

                if isinstance(eos_token_id, int):
                    stop_tokens = [eos_token_id]
                    eos_token_for_padding = eos_token_id
                else:
                    # eos_token_id is a Sequence (e.g. list or tuple)
                    stop_tokens = eos_token_id
                    eos_token_for_padding = (
                        self.tokenizer.eos_token_id if tokenizer_has_eos_token else eos_token_id[0]
                    )

            # An array to track which sequences in the batch have finished.
            finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device)

            accumulated_tokens: Optional[torch.Tensor] = None
            tokens_since_last_yield = 0

            # Currently nothing in HookedTransformer changes with eval, but this is here in case
            # that changes in the future.
            self.eval()
            for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose):
                # While generating, we keep generating logits, throw away all but the final
                # logits, and then use those logits to sample from the distribution. We keep
                # adding the sampled tokens to the end of tokens.
                if use_past_kv_cache:
                    # We just take the final tokens, as a [batch, 1] tensor
                    if index > 0:
                        logits = self.forward(
                            tokens[:, -1:],
                            return_type="logits",
                            prepend_bos=prepend_bos,
                            padding_side=padding_side,
                            past_kv_cache=past_kv_cache,
                        )
                    else:
                        logits = self.forward(
                            tokens,
                            return_type="logits",
                            prepend_bos=prepend_bos,
                            padding_side=padding_side,
                            past_kv_cache=past_kv_cache,
                        )
                else:
                    # We input the entire sequence, as a [batch, pos] tensor, since we aren't
                    # using the cache.
                    logits = self.forward(
                        tokens,
                        return_type="logits",
                        prepend_bos=prepend_bos,
                        padding_side=padding_side,
                    )
                final_logits = logits[:, -1, :]

                if do_sample:
                    sampled_tokens = utils.sample_logits(
                        final_logits,
                        top_k=top_k,
                        top_p=top_p,
                        temperature=temperature,
                        freq_penalty=freq_penalty,
                        tokens=tokens,
                    ).to(devices.get_device_for_block_index(0, self.cfg))
                else:
                    sampled_tokens = final_logits.argmax(-1).to(
                        devices.get_device_for_block_index(0, self.cfg)
                    )

                if stop_at_eos:
                    # For all unfinished sequences, add on the next token. If a sequence was
                    # finished, throw away the generated token and add eos_token_for_padding
                    # instead.
                    sampled_tokens[finished_sequences] = eos_token_for_padding
                    finished_sequences.logical_or_(
                        torch.isin(
                            sampled_tokens.to(self.cfg.device),
                            torch.tensor(stop_tokens).to(self.cfg.device),
                        )
                    )

                new_tokens = sampled_tokens.unsqueeze(-1)

                # Accumulate tokens until we hit max_tokens_per_yield
                if index == 0:
                    accumulated_tokens = torch.cat([tokens, new_tokens], dim=-1)
                    tokens_since_last_yield = accumulated_tokens.shape[1]
                else:
                    if accumulated_tokens is None:
                        accumulated_tokens = new_tokens
                    else:
                        accumulated_tokens = torch.cat([accumulated_tokens, new_tokens], dim=-1)
                    tokens_since_last_yield += 1

                if tokens_since_last_yield >= max_tokens_per_yield:
                    yield accumulated_tokens
                    tokens_since_last_yield = 0
                    accumulated_tokens = None

                tokens = torch.cat([tokens, new_tokens], dim=-1)

                if stop_at_eos and finished_sequences.all():
                    # Yield any remaining accumulated tokens before breaking
                    if accumulated_tokens is not None:
                        yield accumulated_tokens
                    break

            # Only yield remaining tokens if we didn't already yield them in the break case
            if accumulated_tokens is not None and not (stop_at_eos and finished_sequences.all()):
                yield accumulated_tokens

    # Give access to all weights as properties.
    @property
    def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]:
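As implemented, max_tokens_per_yield bounds how many tokens are buffered between yields: the first yield carries the prompt plus the initial sampled tokens, and every later yield carries at most max_tokens_per_yield newly sampled tokens (fewer on the final flush or when all sequences hit EOS). A rough sketch of what a consumer observes, continuing from the example above (prompt and settings are placeholders; exact chunk widths depend on the tokenizer):

import torch

chunks = list(
    model.generate_stream(
        "Hello world",
        max_new_tokens=12,
        max_tokens_per_yield=4,
        do_sample=False,
        verbose=False,
    )
)

prompt_len = model.to_tokens("Hello world").shape[1]
# The first chunk contains the prompt plus at least one sampled token.
assert chunks[0].shape[1] >= prompt_len + 1
# Later chunks each contain at most max_tokens_per_yield newly sampled tokens.
assert all(chunk.shape[1] <= 4 for chunk in chunks[1:])
# Concatenating all chunks along the position axis reconstructs the full sequence.
full_sequence = torch.cat(chunks, dim=-1)
print(model.to_string(full_sequence[0]))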