From 634371914f14bdcbb9dd7c3974e7caff551f46f7 Mon Sep 17 00:00:00 2001
From: Johnny Lin
Date: Thu, 30 Jan 2025 13:41:21 -0800
Subject: [PATCH 1/4] feat: streaming response for HookedTransformer.generate

---
 transformer_lens/HookedTransformer.py | 21 +++++++++++++++++++--
 1 file changed, 19 insertions(+), 2 deletions(-)

diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py
index a34a5c4a0..48f1a7cb6 100644
--- a/transformer_lens/HookedTransformer.py
+++ b/transformer_lens/HookedTransformer.py
@@ -11,6 +11,7 @@
 
 import logging
 import os
+from collections.abc import Generator
 from typing import (
     Dict,
     List,
@@ -2044,7 +2045,12 @@ def generate(
         padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE,
         return_type: Optional[str] = "input",
         verbose: bool = True,
-    ) -> Union[Int[torch.Tensor, "batch pos_plus_new_tokens"], str]:
+        stream_output: bool = False,
+    ) -> Union[
+        Int[torch.Tensor, "batch pos_plus_new_tokens"],
+        str,
+        Generator[Union[Int[torch.Tensor, "batch"], str], None, None],
+    ]:
         """Sample Tokens from the Model.
 
         Sample tokens from the model until the model outputs eos_token or max_new_tokens is reached.
@@ -2213,7 +2219,18 @@ def generate(
                         )
                     )
 
-                tokens = torch.cat([tokens, sampled_tokens.unsqueeze(-1)], dim=-1)
+                new_tokens = sampled_tokens.unsqueeze(-1)
+                if stream_output:
+                    if return_type == "str":
+                        tokens_to_return = self.tokenizer.decode(new_tokens)
+                        if self.cfg.default_prepend_bos and index == 0:
+                            yield tokens_to_return[1:]
+                        else:
+                            yield tokens_to_return
+                    else:
+                        yield new_tokens
+
+                tokens = torch.cat([tokens, new_tokens], dim=-1)
 
                 if stop_at_eos and finished_sequences.all():
                     break

From 2c1d49266138d01de177b3e701c84e5584179f57 Mon Sep 17 00:00:00 2001
From: Johnny Lin
Date: Thu, 30 Jan 2025 16:23:14 -0800
Subject: [PATCH 2/4] fix: hacky duplicate generate_stream

---
 transformer_lens/HookedTransformer.py | 226 ++++++++++++++++++++++++--
 1 file changed, 211 insertions(+), 15 deletions(-)

diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py
index 48f1a7cb6..90d5c9bbd 100644
--- a/transformer_lens/HookedTransformer.py
+++ b/transformer_lens/HookedTransformer.py
@@ -11,7 +11,6 @@
 
 import logging
 import os
-from collections.abc import Generator
 from typing import (
     Dict,
     List,
@@ -24,6 +23,7 @@
     cast,
     overload,
 )
+from collections.abc import Generator
 
 import einops
 import numpy as np
@@ -2045,12 +2045,208 @@ def generate(
         padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE,
         return_type: Optional[str] = "input",
         verbose: bool = True,
-        stream_output: bool = False,
-    ) -> Union[
-        Int[torch.Tensor, "batch pos_plus_new_tokens"],
-        str,
-        Generator[Union[Int[torch.Tensor, "batch"], str], None, None],
-    ]:
+    ) -> Union[Int[torch.Tensor, "batch pos_plus_new_tokens"], str]:
+        """Sample Tokens from the Model.
+
+        Sample tokens from the model until the model outputs eos_token or max_new_tokens is reached.
+
+        To avoid fiddling with ragged tensors, if we input a batch of text and some sequences finish
+        (by producing an EOT token), we keep running the model on the entire batch, but throw away
+        the output for a finished sequence and just keep adding EOTs to pad.
+
+        This supports entering a single string, but not a list of strings - if the strings don't
+        tokenize to exactly the same length, this gets messy. If that functionality is needed,
+        convert them to a batch of tokens and input that instead.
+
+        Args:
+            input (Union[str, Int[torch.Tensor, "batch pos"])]): Either a batch of tokens ([batch,
+                pos]) or a text string (this will be converted to a batch of tokens with batch size
+                1).
+            max_new_tokens (int): Maximum number of tokens to generate.
+            stop_at_eos (bool): If True, stop generating tokens when the model outputs eos_token.
+            eos_token_id (Optional[Union[int, Sequence]]): The token ID to use for end
+                of sentence. If None, use the tokenizer's eos_token_id - required if using
+                stop_at_eos. It's also possible to provide a list of token IDs (not just the
+                eos_token_id), in which case the generation will stop when any of them are output
+                (useful e.g. for stable_lm).
+            do_sample (bool): If True, sample from the model's output distribution. Otherwise, use
+                greedy search (take the max logit each time).
+            top_k (int): Number of tokens to sample from. If None, sample from all tokens.
+            top_p (float): Probability mass to sample from. If 1.0, sample from all tokens. If <1.0,
+                we take the top tokens with cumulative probability >= top_p.
+            temperature (float): Temperature for sampling. Higher values will make the model more
+                random (limit of temp -> 0 is just taking the top token, limit of temp -> inf is
+                sampling from a uniform distribution).
+            freq_penalty (float): Frequency penalty for sampling - how much to penalise previous
+                tokens. Higher values will make the model more random.
+            use_past_kv_cache (bool): If True, create and use cache to speed up generation.
+            prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend
+                the BOS token to the input (applicable when input is a string). Defaults to None,
+                implying usage of self.cfg.default_prepend_bos (default is True unless specified
+                otherwise). Pass True or False to override the default.
+            padding_side (Union[Literal["left", "right"], None], optional): Overrides
+                self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple
+                strings of different lengths.
+            return_type (Optional[str]): The type of the output to return - either a string (str),
+                a tensor of tokens (tensor) or whatever the format of the input was (input).
+            verbose (bool): If True, show tqdm progress bars for generation.
+
+        Returns:
+            outputs (torch.Tensor): [batch, pos + max_new_tokens], generated sequence of new tokens
+                (by default returns same type as input).
+        """
+
+        with utils.LocallyOverridenDefaults(
+            self, prepend_bos=prepend_bos, padding_side=padding_side
+        ):
+            if type(input) == str:
+                # If text, convert to tokens (batch_size=1)
+                assert (
+                    self.tokenizer is not None
+                ), "Must provide a tokenizer if passing a string to the model"
+                tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)
+            else:
+                tokens = input
+
+            if return_type == "input":
+                if type(input) == str:
+                    return_type = "str"
+                else:
+                    return_type = "tensor"
+
+            assert isinstance(tokens, torch.Tensor)
+            batch_size, ctx_length = tokens.shape
+            device = devices.get_device_for_block_index(0, self.cfg)
+            tokens = tokens.to(device)
+            if use_past_kv_cache:
+                past_kv_cache = HookedTransformerKeyValueCache.init_cache(
+                    self.cfg, self.cfg.device, batch_size
+                )
+            else:
+                past_kv_cache = None
+
+            stop_tokens: List[int] = []
+            eos_token_for_padding = 0
+            assert self.tokenizer is not None
+            if stop_at_eos:
+                tokenizer_has_eos_token = (
+                    self.tokenizer is not None and self.tokenizer.eos_token_id is not None
+                )
+                if eos_token_id is None:
+                    assert (
+                        tokenizer_has_eos_token
+                    ), "Must pass a eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id"
+
+                    eos_token_id = self.tokenizer.eos_token_id
+
+                if isinstance(eos_token_id, int):
+                    stop_tokens = [eos_token_id]
+                    eos_token_for_padding = eos_token_id
+                else:
+                    # eos_token_id is a Sequence (e.g. list or tuple)
+                    stop_tokens = eos_token_id
+                    eos_token_for_padding = (
+                        self.tokenizer.eos_token_id if tokenizer_has_eos_token else eos_token_id[0]
+                    )
+
+            # An array to track which sequences in the batch have finished.
+            finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device)
+
+            # Currently nothing in HookedTransformer changes with eval, but this is here in case
+            # that changes in the future.
+            self.eval()
+            for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose):
+                # While generating, we keep generating logits, throw away all but the final logits,
+                # and then use those logits to sample from the distribution We keep adding the
+                # sampled tokens to the end of tokens.
+                if use_past_kv_cache:
+                    # We just take the final tokens, as a [batch, 1] tensor
+                    if index > 0:
+                        logits = self.forward(
+                            tokens[:, -1:],
+                            return_type="logits",
+                            prepend_bos=prepend_bos,
+                            padding_side=padding_side,
+                            past_kv_cache=past_kv_cache,
+                        )
+                    else:
+                        logits = self.forward(
+                            tokens,
+                            return_type="logits",
+                            prepend_bos=prepend_bos,
+                            padding_side=padding_side,
+                            past_kv_cache=past_kv_cache,
+                        )
+                else:
+                    # We input the entire sequence, as a [batch, pos] tensor, since we aren't using
+                    # the cache.
+                    logits = self.forward(
+                        tokens,
+                        return_type="logits",
+                        prepend_bos=prepend_bos,
+                        padding_side=padding_side,
+                    )
+                final_logits = logits[:, -1, :]
+
+                if do_sample:
+                    sampled_tokens = utils.sample_logits(
+                        final_logits,
+                        top_k=top_k,
+                        top_p=top_p,
+                        temperature=temperature,
+                        freq_penalty=freq_penalty,
+                        tokens=tokens,
+                    ).to(devices.get_device_for_block_index(0, self.cfg))
+                else:
+                    sampled_tokens = final_logits.argmax(-1).to(
+                        devices.get_device_for_block_index(0, self.cfg)
+                    )
+
+                if stop_at_eos:
+                    # For all unfinished sequences, add on the next token. If a sequence was
+                    # finished, throw away the generated token and add eos_token_for_padding
+                    # instead.
+                    sampled_tokens[finished_sequences] = eos_token_for_padding
+                    finished_sequences.logical_or_(
+                        torch.isin(
+                            sampled_tokens.to(self.cfg.device),
+                            torch.tensor(stop_tokens).to(self.cfg.device),
+                        )
+                    )
+
+                tokens = torch.cat([tokens, sampled_tokens.unsqueeze(-1)], dim=-1)
+
+                if stop_at_eos and finished_sequences.all():
+                    break
+
+            if return_type == "str":
+                if self.cfg.default_prepend_bos:
+                    # If we prepended a BOS token, remove it when returning output.
+                    return self.tokenizer.decode(tokens[0, 1:])
+                else:
+                    return self.tokenizer.decode(tokens[0])
+
+            else:
+                return tokens
+
+    @torch.inference_mode()
+    def generate_stream(
+        self,
+        input: Union[str, Float[torch.Tensor, "batch pos"]] = "",
+        max_new_tokens: int = 10,
+        stop_at_eos: bool = True,
+        eos_token_id: Optional[int] = None,
+        do_sample: bool = True,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        temperature: float = 1.0,
+        freq_penalty: float = 0.0,
+        use_past_kv_cache: bool = True,
+        prepend_bos: Optional[bool] = USE_DEFAULT_VALUE,
+        padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE,
+        return_type: Optional[str] = "input",
+        verbose: bool = True,
+    ) -> Generator[Union[Int[torch.Tensor, "batch"], str], None, None]:
         """Sample Tokens from the Model.
 
         Sample tokens from the model until the model outputs eos_token or max_new_tokens is reached.
@@ -2220,15 +2416,15 @@ def generate(
                     )
 
                 new_tokens = sampled_tokens.unsqueeze(-1)
-                if stream_output:
-                    if return_type == "str":
-                        tokens_to_return = self.tokenizer.decode(new_tokens)
-                        if self.cfg.default_prepend_bos and index == 0:
-                            yield tokens_to_return[1:]
-                        else:
-                            yield tokens_to_return
+
+                if return_type == "str":
+                    tokens_to_return = self.tokenizer.decode(new_tokens)
+                    if self.cfg.default_prepend_bos and index == 0:
+                        yield tokens_to_return[1:]
                     else:
-                        yield new_tokens
+                        yield tokens_to_return
+                else:
+                    yield new_tokens
 
                 tokens = torch.cat([tokens, new_tokens], dim=-1)
 

From ee54d7595a5c7f389826b94223fe4c26f79ee748 Mon Sep 17 00:00:00 2001
From: Johnny Lin
Date: Fri, 31 Jan 2025 00:00:57 -0800
Subject: [PATCH 3/4] fix: simplify the yield

---
 transformer_lens/HookedTransformer.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py
index 90d5c9bbd..447c1f886 100644
--- a/transformer_lens/HookedTransformer.py
+++ b/transformer_lens/HookedTransformer.py
@@ -2417,12 +2417,8 @@ def generate_stream(
                     )
 
                 new_tokens = sampled_tokens.unsqueeze(-1)
 
-                if return_type == "str":
-                    tokens_to_return = self.tokenizer.decode(new_tokens)
-                    if self.cfg.default_prepend_bos and index == 0:
-                        yield tokens_to_return[1:]
-                    else:
-                        yield tokens_to_return
+                if index == 0:
+                    yield torch.cat([tokens, new_tokens], dim=-1)
                 else:
                     yield new_tokens

From 94d38613c501349e9d9dfc3ad3a6f504f91e3f2c Mon Sep 17 00:00:00 2001
From: Bryce Meyer
Date: Tue, 4 Feb 2025 00:06:11 +0100
Subject: [PATCH 4/4] ran format

---
 transformer_lens/HookedTransformer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py
index 447c1f886..d54e07920 100644
--- a/transformer_lens/HookedTransformer.py
+++ b/transformer_lens/HookedTransformer.py
@@ -11,6 +11,7 @@
 
 import logging
 import os
+from collections.abc import Generator
 from typing import (
     Dict,
     List,
@@ -23,7 +24,6 @@
     cast,
     overload,
 )
-from collections.abc import Generator
 
 import einops
 import numpy as np
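
Usage note (illustrative sketch only, not part of the patch series above): with these patches applied, generate_stream is a generator. After PATCH 3 it yields Int tensors: the full prompt plus the first sampled token on the first step, then a [batch, 1] tensor of newly sampled tokens on each later step. The snippet below assumes the patched TransformerLens is installed and that a small model such as "gpt2" can be loaded; the decoding helper used here is the existing HookedTransformer.to_string.

    from transformer_lens import HookedTransformer

    # Assumed example model; any HookedTransformer-supported checkpoint works.
    model = HookedTransformer.from_pretrained("gpt2")

    # Stream tokens to stdout as they are sampled. The first chunk contains the
    # (possibly BOS-prefixed) prompt plus the first new token; later chunks are
    # single new tokens per batch element.
    for chunk in model.generate_stream("The quick brown fox", max_new_tokens=20):
        print(model.to_string(chunk[0]), end="", flush=True)
    print()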