We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 9af6f4f commit caab6daCopy full SHA for caab6da
onnx_diagnostic/tasks/image_text_to_text.py
@@ -156,7 +156,7 @@ def _get_inputs_gemma3(
156
},
157
"position_ids": {0: batch, 1: seq_length},
158
"cache_position": {0: seq_length},
159
- "past_key_values": [{0: batch} for _ in range(num_hidden_layers * 2)],
+ "past_key_values": [{0: batch, 2: seq_length} for _ in range(num_hidden_layers * 2)],
160
"pixel_values": {0: batch},
161
"use_cache": None,
162
}
@@ -280,7 +280,7 @@ def get_inputs_default(
280
"past_key_values": list(
281
itertools.chain.from_iterable(
282
zip(
283
- [{0: batch} for _ in range(num_hidden_layers)],
+ [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
284
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
285
)
286
0 commit comments