Skip to content

Commit 8989eb7

Browse files
committed
disable patches
1 parent caab6da commit 8989eb7

File tree

4 files changed

+11
-192
lines changed

4 files changed

+11
-192
lines changed

_unittests/ut_ci_models/test_ci_export.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def test_main_qwen25_tiny_llm(self):
2020
pretrained=False,
2121
part="",
2222
output_folder=self.get_dump_folder("test_main_qwen25_tiny_llm"),
23+
opset=24,
2324
)
2425
self.clean_dump()
2526

_unittests/ut_torch_export_patches/test_patch_transformers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -996,6 +996,8 @@ def test_prepare_inputs_for_generation_decoder_llm(self):
996996
with self.subTest(case="case5"):
997997
if not has_transformers("4.57"):
998998
raise unittest.SkipTest("transformers 4.57+.")
999+
if has_transformers("5.2.99"):
1000+
raise unittest.SkipTest("transformers 5.2+.")
9991001
with self.assertRaises((AttributeError, TypeError)):
10001002
model_inputs = model.prepare_inputs_for_generation(
10011003
input_ids, past_key_values=dynamic_cache

onnx_diagnostic/ci_models/export_qwen25_vl.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
import sys
6161
import time
6262
import warnings
63-
from typing import Any, Dict, List, Tuple
63+
from typing import Any, Dict, List, Optional, Tuple
6464
from .ci_helpers import (
6565
check_for_discrepancies_and_log_everything_into_a_json_file,
6666
compute_expected_outputs,
@@ -199,6 +199,7 @@ def main(
199199
atol: float = 0.01,
200200
mismatch01: float = 0.1,
201201
profile_exporter: bool = False,
202+
opset: Optional[int] = None,
202203
):
203204
"""
204205
Exports model Qwen/Qwen2.5-VL-7B-Instruct or pieces of it.
@@ -221,6 +222,8 @@ def main(
221222
:param atol: raises an exception if tolerance is above that threshold
222223
:param mismatch01: raises an exception if the ratio of mismatches
223224
is above that threshold
225+
:param opset: target ONNX opset, if not specified, a value is chosen based on the
226+
attention rewriting selected with ``QWEN25ATTENTION`` (22, or 23 for ``LOOPA23``)
224227
:param profile_exporter: profiles the exporter
225228
"""
226229
prefix = simplify_model_id_for_a_filename(model_id)
@@ -243,6 +246,7 @@ def main(
243246
print(f"-- make_zip={make_zip}")
244247
print(f"-- output_folder={output_folder}")
245248
print(f"-- atol={atol}")
249+
print(f"-- opset={opset}")
246250
print(f"-- mismatch01={mismatch01}")
247251
print(f"-- profile_exporter={profile_exporter}")
248252
print("------------------------------------------------------------------")
@@ -473,15 +477,15 @@ def process_image(inputs_embeds, image_features):
473477

474478
begin = time.perf_counter()
475479

476-
target_opset = 22
480+
target_opset = opset or 22
477481
if (
478482
exporter == "onnx-dynamo"
479483
and device == "cuda"
480484
and "QWEN25ATTENTION" not in os.environ
481485
):
482486
os.environ["QWEN25ATTENTION"] = "PACKED"
483487
elif "QWEN25ATTENTION" in os.environ and os.environ["QWEN25ATTENTION"] == "LOOPA23":
484-
target_opset = 23
488+
target_opset = opset or 23
485489

486490
with torch_export_patches(
487491
patch_torch=False,

onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py

Lines changed: 1 addition & 189 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ class patched_GenerationMixin:
1919
(
2020
None
2121
if pv.Version(transformers.__version__) >= pv.Version("4.56")
22+
and pv.Version(transformers.__version__) < pv.Version("5.2.99")
2223
else "prepare_inputs_for_generation"
2324
),
2425
# (
@@ -297,192 +298,3 @@ def prepare_inputs_for_generation( # pragma: no cover
297298
# 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
298299
model_inputs.pop("labels", None)
299300
return model_inputs
300-
301-
'''
302-
# drops a patch since it is for a very specific version.
303-
def _sample(
304-
self,
305-
input_ids: torch.LongTensor,
306-
logits_processor: "LogitsProcessorList", # noqa: F821
307-
stopping_criteria: "StoppingCriteriaList", # noqa: F821
308-
generation_config: "GenerationConfig", # noqa: F821
309-
synced_gpus: bool = False,
310-
streamer: Optional["BaseStreamer"] = None, # noqa: F821
311-
**model_kwargs,
312-
) -> Union["GenerateNonBeamOutput", torch.LongTensor]: # noqa: F821
313-
"""
314-
2025/09/29: updates for Gemma3 models, fix for eager mode as well as the export.
315-
"""
316-
# init values
317-
pad_token_id = generation_config._pad_token_tensor
318-
output_attentions = generation_config.output_attentions
319-
output_hidden_states = generation_config.output_hidden_states
320-
output_scores = generation_config.output_scores
321-
output_logits = generation_config.output_logits
322-
return_dict_in_generate = generation_config.return_dict_in_generate
323-
has_eos_stopping_criteria = any(
324-
hasattr(criteria, "eos_token_id") for criteria in stopping_criteria
325-
)
326-
do_sample = generation_config.do_sample
327-
328-
# init attention / hidden states / scores tuples
329-
scores = () if (return_dict_in_generate and output_scores) else None
330-
raw_logits = () if (return_dict_in_generate and output_logits) else None
331-
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
332-
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
333-
decoder_hidden_states = (
334-
() if (return_dict_in_generate and output_hidden_states) else None
335-
)
336-
337-
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
338-
if return_dict_in_generate and self.config.is_encoder_decoder:
339-
encoder_attentions = (
340-
model_kwargs["encoder_outputs"].get("attentions")
341-
if output_attentions
342-
else None
343-
)
344-
encoder_hidden_states = (
345-
model_kwargs["encoder_outputs"].get("hidden_states")
346-
if output_hidden_states
347-
else None
348-
)
349-
350-
# keep track of which sequences are already finished
351-
batch_size, cur_len = input_ids.shape[:2]
352-
this_peer_finished = False
353-
unfinished_sequences = torch.ones(
354-
batch_size, dtype=torch.long, device=input_ids.device
355-
)
356-
model_kwargs = self._get_initial_cache_position(
357-
cur_len, input_ids.device, model_kwargs
358-
)
359-
360-
model_forward = self.__call__
361-
compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
362-
if compile_forward:
363-
os.environ["TOKENIZERS_PARALLELISM"] = "0"
364-
# If we use FA2 and a static cache, we cannot compile with fullgraph
365-
if self.config._attn_implementation == "flash_attention_2":
366-
# only raise warning if the user passed an explicit compile-config
367-
if (
368-
generation_config.compile_config is not None
369-
and generation_config.compile_config.fullgraph
370-
):
371-
generation_config.compile_config.fullgraph = False
372-
model_forward = self.get_compiled_call(generation_config.compile_config)
373-
374-
if generation_config.prefill_chunk_size is not None:
375-
model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
376-
is_prefill = False
377-
else:
378-
is_prefill = True
379-
380-
while self._has_unfinished_sequences(
381-
this_peer_finished, synced_gpus, device=input_ids.device
382-
):
383-
# prepare model inputs
384-
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
385-
386-
if is_prefill:
387-
outputs = self(**model_inputs, return_dict=True)
388-
is_prefill = False
389-
else:
390-
outputs = model_forward(**model_inputs, return_dict=True)
391-
392-
model_kwargs = self._update_model_kwargs_for_generation(
393-
outputs,
394-
model_kwargs,
395-
is_encoder_decoder=self.config.is_encoder_decoder,
396-
)
397-
if synced_gpus and this_peer_finished:
398-
continue
399-
400-
next_token_logits = outputs.logits[:, -1, :].to(
401-
copy=True, dtype=torch.float32, device=input_ids.device
402-
)
403-
404-
# pre-process distribution
405-
next_token_scores = logits_processor(input_ids, next_token_logits)
406-
407-
# Store scores, attentions and hidden_states when required
408-
if return_dict_in_generate:
409-
if output_scores:
410-
scores += (next_token_scores,)
411-
if output_logits:
412-
raw_logits += (next_token_logits,)
413-
if output_attentions:
414-
decoder_attentions += (
415-
(outputs.decoder_attentions,)
416-
if self.config.is_encoder_decoder
417-
else (outputs.attentions,)
418-
)
419-
if self.config.is_encoder_decoder:
420-
cross_attentions += (outputs.cross_attentions,)
421-
422-
if output_hidden_states:
423-
decoder_hidden_states += (
424-
(outputs.decoder_hidden_states,)
425-
if self.config.is_encoder_decoder
426-
else (outputs.hidden_states,)
427-
)
428-
429-
# token selection
430-
if do_sample:
431-
probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
432-
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
433-
else:
434-
next_tokens = torch.argmax(next_token_scores, dim=-1)
435-
436-
# finished sentences should have their next token be a padding token
437-
if has_eos_stopping_criteria:
438-
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
439-
1 - unfinished_sequences
440-
)
441-
442-
# update generated ids, model inputs, and length for next step
443-
# PATCHED: the two following lines, next_tokens can already be 2D for this model
444-
next_tokens_2d = (
445-
next_tokens if len(next_tokens.shape) == 2 else next_tokens[:, None]
446-
)
447-
input_ids = torch.cat([input_ids, next_tokens_2d], dim=-1)
448-
if streamer is not None:
449-
streamer.put(next_tokens.cpu())
450-
451-
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
452-
this_peer_finished = unfinished_sequences.max() == 0
453-
cur_len += 1
454-
455-
# This is needed to properly delete outputs.logits which may be very large
456-
# for first iteration
457-
# Otherwise a reference to outputs is kept which keeps
458-
# the logits alive in the next iteration
459-
del outputs
460-
461-
if streamer is not None:
462-
streamer.end()
463-
464-
if return_dict_in_generate:
465-
if self.config.is_encoder_decoder:
466-
return transformers.generation.utils.GenerateEncoderDecoderOutput(
467-
sequences=input_ids,
468-
scores=scores,
469-
logits=raw_logits,
470-
encoder_attentions=encoder_attentions,
471-
encoder_hidden_states=encoder_hidden_states,
472-
decoder_attentions=decoder_attentions,
473-
cross_attentions=cross_attentions,
474-
decoder_hidden_states=decoder_hidden_states,
475-
past_key_values=model_kwargs.get("past_key_values"),
476-
)
477-
else:
478-
return transformers.generation.utils.GenerateDecoderOnlyOutput(
479-
sequences=input_ids,
480-
scores=scores,
481-
logits=raw_logits,
482-
attentions=decoder_attentions,
483-
hidden_states=decoder_hidden_states,
484-
past_key_values=model_kwargs.get("past_key_values"),
485-
)
486-
else:
487-
return input_ids
488-
'''

0 commit comments

Comments
 (0)