@@ -19,6 +19,7 @@ class patched_GenerationMixin:
1919 (
2020 None
2121 if pv .Version (transformers .__version__ ) >= pv .Version ("4.56" )
22+ and pv .Version (transformers .__version__ ) < pv .Version ("5.2.99" )
2223 else "prepare_inputs_for_generation"
2324 ),
2425 # (
@@ -297,192 +298,3 @@ def prepare_inputs_for_generation( # pragma: no cover
297298 # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
298299 model_inputs .pop ("labels" , None )
299300 return model_inputs
300-
301- '''
302- # drops a patch since it is for a very specific version.
303- def _sample(
304- self,
305- input_ids: torch.LongTensor,
306- logits_processor: "LogitsProcessorList", # noqa: F821
307- stopping_criteria: "StoppingCriteriaList", # noqa: F821
308- generation_config: "GenerationConfig", # noqa: F821
309- synced_gpus: bool = False,
310- streamer: Optional["BaseStreamer"] = None, # noqa: F821
311- **model_kwargs,
312- ) -> Union["GenerateNonBeamOutput", torch.LongTensor]: # noqa: F821
313- """
314- 2025/09/29: updates for Gemma3 models, fix for eager mode as well as the export.
315- """
316- # init values
317- pad_token_id = generation_config._pad_token_tensor
318- output_attentions = generation_config.output_attentions
319- output_hidden_states = generation_config.output_hidden_states
320- output_scores = generation_config.output_scores
321- output_logits = generation_config.output_logits
322- return_dict_in_generate = generation_config.return_dict_in_generate
323- has_eos_stopping_criteria = any(
324- hasattr(criteria, "eos_token_id") for criteria in stopping_criteria
325- )
326- do_sample = generation_config.do_sample
327-
328- # init attention / hidden states / scores tuples
329- scores = () if (return_dict_in_generate and output_scores) else None
330- raw_logits = () if (return_dict_in_generate and output_logits) else None
331- decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
332- cross_attentions = () if (return_dict_in_generate and output_attentions) else None
333- decoder_hidden_states = (
334- () if (return_dict_in_generate and output_hidden_states) else None
335- )
336-
337- # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
338- if return_dict_in_generate and self.config.is_encoder_decoder:
339- encoder_attentions = (
340- model_kwargs["encoder_outputs"].get("attentions")
341- if output_attentions
342- else None
343- )
344- encoder_hidden_states = (
345- model_kwargs["encoder_outputs"].get("hidden_states")
346- if output_hidden_states
347- else None
348- )
349-
350- # keep track of which sequences are already finished
351- batch_size, cur_len = input_ids.shape[:2]
352- this_peer_finished = False
353- unfinished_sequences = torch.ones(
354- batch_size, dtype=torch.long, device=input_ids.device
355- )
356- model_kwargs = self._get_initial_cache_position(
357- cur_len, input_ids.device, model_kwargs
358- )
359-
360- model_forward = self.__call__
361- compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
362- if compile_forward:
363- os.environ["TOKENIZERS_PARALLELISM"] = "0"
364- # If we use FA2 and a static cache, we cannot compile with fullgraph
365- if self.config._attn_implementation == "flash_attention_2":
366- # only raise warning if the user passed an explicit compile-config
367- if (
368- generation_config.compile_config is not None
369- and generation_config.compile_config.fullgraph
370- ):
371- generation_config.compile_config.fullgraph = False
372- model_forward = self.get_compiled_call(generation_config.compile_config)
373-
374- if generation_config.prefill_chunk_size is not None:
375- model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
376- is_prefill = False
377- else:
378- is_prefill = True
379-
380- while self._has_unfinished_sequences(
381- this_peer_finished, synced_gpus, device=input_ids.device
382- ):
383- # prepare model inputs
384- model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
385-
386- if is_prefill:
387- outputs = self(**model_inputs, return_dict=True)
388- is_prefill = False
389- else:
390- outputs = model_forward(**model_inputs, return_dict=True)
391-
392- model_kwargs = self._update_model_kwargs_for_generation(
393- outputs,
394- model_kwargs,
395- is_encoder_decoder=self.config.is_encoder_decoder,
396- )
397- if synced_gpus and this_peer_finished:
398- continue
399-
400- next_token_logits = outputs.logits[:, -1, :].to(
401- copy=True, dtype=torch.float32, device=input_ids.device
402- )
403-
404- # pre-process distribution
405- next_token_scores = logits_processor(input_ids, next_token_logits)
406-
407- # Store scores, attentions and hidden_states when required
408- if return_dict_in_generate:
409- if output_scores:
410- scores += (next_token_scores,)
411- if output_logits:
412- raw_logits += (next_token_logits,)
413- if output_attentions:
414- decoder_attentions += (
415- (outputs.decoder_attentions,)
416- if self.config.is_encoder_decoder
417- else (outputs.attentions,)
418- )
419- if self.config.is_encoder_decoder:
420- cross_attentions += (outputs.cross_attentions,)
421-
422- if output_hidden_states:
423- decoder_hidden_states += (
424- (outputs.decoder_hidden_states,)
425- if self.config.is_encoder_decoder
426- else (outputs.hidden_states,)
427- )
428-
429- # token selection
430- if do_sample:
431- probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
432- next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
433- else:
434- next_tokens = torch.argmax(next_token_scores, dim=-1)
435-
436- # finished sentences should have their next token be a padding token
437- if has_eos_stopping_criteria:
438- next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
439- 1 - unfinished_sequences
440- )
441-
442- # update generated ids, model inputs, and length for next step
443- # PATCHED: the two following lines, next_tokens can 2D already for this model
444- next_tokens_2d = (
445- next_tokens if len(next_tokens.shape) == 2 else next_tokens[:, None]
446- )
447- input_ids = torch.cat([input_ids, next_tokens_2d], dim=-1)
448- if streamer is not None:
449- streamer.put(next_tokens.cpu())
450-
451- unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
452- this_peer_finished = unfinished_sequences.max() == 0
453- cur_len += 1
454-
455- # This is needed to properly delete outputs.logits which may be very large
456- # for first iteration
457- # Otherwise a reference to outputs is kept which keeps
458- # the logits alive in the next iteration
459- del outputs
460-
461- if streamer is not None:
462- streamer.end()
463-
464- if return_dict_in_generate:
465- if self.config.is_encoder_decoder:
466- return transformers.generation.utils.GenerateEncoderDecoderOutput(
467- sequences=input_ids,
468- scores=scores,
469- logits=raw_logits,
470- encoder_attentions=encoder_attentions,
471- encoder_hidden_states=encoder_hidden_states,
472- decoder_attentions=decoder_attentions,
473- cross_attentions=cross_attentions,
474- decoder_hidden_states=decoder_hidden_states,
475- past_key_values=model_kwargs.get("past_key_values"),
476- )
477- else:
478- return transformers.generation.utils.GenerateDecoderOnlyOutput(
479- sequences=input_ids,
480- scores=scores,
481- logits=raw_logits,
482- attentions=decoder_attentions,
483- hidden_states=decoder_hidden_states,
484- past_key_values=model_kwargs.get("past_key_values"),
485- )
486- else:
487- return input_ids
488- '''
0 commit comments