diff --git a/backends/nxp/aten_passes/split_group_convolution.py b/backends/nxp/aten_passes/split_group_convolution.py index 58c87730c84..cad897071e4 100644 --- a/backends/nxp/aten_passes/split_group_convolution.py +++ b/backends/nxp/aten_passes/split_group_convolution.py @@ -111,15 +111,27 @@ def _create_convolution_node(self, conv_target, args: tuple) -> Node: # Compute the output shapes for the `convolution`, and assign the `val` meta. with FakeTensorMode() as mode: input_shapes = [ - input_.meta["val"].shape if hasattr(input_, "meta") else input_.shape + ( + input_.meta["val"].shape + if hasattr(input_, "meta") + else input_.shape if input_ is not None else None + ) for input_ in args[:3] ] input_dtypes = [ - input_.meta["val"].dtype if hasattr(input_, "meta") else input_.dtype + ( + input_.meta["val"].dtype + if hasattr(input_, "meta") + else input_.dtype if input_ is not None else None + ) for input_ in args[:3] ] fake_inputs = [ - FakeTensor.from_tensor(torch.empty(shape, dtype=dtype), mode) + ( + FakeTensor.from_tensor(torch.empty(shape, dtype=dtype), mode) + if shape is not None and dtype is not None + else None + ) for shape, dtype in zip(input_shapes, input_dtypes) ] output = conv_target(*fake_inputs, *args[3:]) @@ -211,7 +223,9 @@ def _is_conv(node_: Node): w_data = self._get_tensor_constant_from_node(w) b_data = self._get_tensor_constant_from_node(b) - if w_data is None or b_data is None: + + with_bias = b is not None + if w_data is None or (b_data is None and with_bias): continue # Only the standard case with static weights and bias is supported. # Create a `split` node to split the main input. @@ -227,10 +241,9 @@ def _is_conv(node_: Node): for i in range(groups) ] - # Split the weights and bias, across dimension `0`, slices of size `weight_split_size`. + # Split the weights across dimension `0`, slices of size `weight_split_size`. weight_split_size = w.meta["val"].shape[0] // groups split_weights_data = torch.split(w_data, weight_split_size, 0) - split_bias_data = torch.split(b_data, weight_split_size, 0) # Turn the weights and biases into parameter nodes containing the data. # Use a different name for every parameter. The function internally ensures the name's uniqueness, but @@ -241,12 +254,17 @@ def _is_conv(node_: Node): ) for i, weight_data in enumerate(split_weights_data) ] - split_bias_nodes = [ - self._create_parameter_node_for_data( - bias_data, b.name + f"_{i}_", split_node - ) - for i, bias_data in enumerate(split_bias_data) - ] + + if with_bias: + split_bias_data = torch.split(b_data, weight_split_size, 0) + split_bias_nodes = [ + self._create_parameter_node_for_data( + bias_data, b.name + f"_{i}_", split_node + ) + for i, bias_data in enumerate(split_bias_data) + ] + else: + split_bias_nodes = [None] * len(split_weight_nodes) # Create the `conv` nodes. with self.module.graph.inserting_after( diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py index 5fa994be7ae..09d7c673733 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py @@ -48,7 +48,98 @@ from torch.nn import Parameter +# The arguments of the conv are: +# [x, w, b, stride, padding, dilation, transposed, output padding, groups] +# https://github.com/pytorch/pytorch/blob/v2.6.0/aten/src/ATen/native/Convolution.cpp#L286-L291 +Stride = Padding = Dilation = OutPadding = list[int] +Transposed = bool +Groups = int +ConvolutionArgs = tuple[ + Node, Node, Node | None, Stride, Padding, Dilation, Transposed, OutPadding, Groups +] + + class ConvolutionConverter(NodeConverter): + @staticmethod + def _is_supported_on_target_new_flow( + node: Node, + parameters_mapping: dict[str, Parameter], + ) -> bool: + ( + inp_node, + w_node, + b_node, + stride, + padding, + dilation, + transposed, + _, + groups, + ) = ConvolutionConverter._get_convolution_arguments(node) + + # Input must be INT8/UINT8 + # Output must be INT8/UINT8 + inp_out_supported_types = [torch.int8, torch.uint8] + if not NodeConverter.uses_quantization_type_for_io( + node, inp_out_supported_types, [0], [0] + ): + return False + + # Weights must be INT8 + w_supported_types = [torch.int8] + if not NodeConverter.uses_quantization_type_for_io( + node, w_supported_types, [1], [] + ): + return False + + # Bias must be INT32 + if b_node is not None: + b_supported_types = [torch.int32] + if not NodeConverter.uses_quantization_type_for_io( + node, b_supported_types, [2], [] + ): + return False + + # Weights must be constant + if not node_is_effectively_static_tensor(w_node, parameters_mapping): + return False + + # Bias must be constant (if present) + if b_node is not None and not node_is_effectively_static_tensor( + b_node, parameters_mapping + ): + return False + + # kernelH <= 4096, kernelW <= 4096 + # strideH <= 4096, strideW <= 4096 + # dilationH <= 4096, dilationW <= 4096 + w_node_shape = ( + w_node.meta["val"].shape if hasattr(w_node, "meta") else w_node.shape + ) + + kernel_h = w_node_shape[2] + kernel_w = w_node_shape[3] + stride_h = stride[0] + stride_w = stride[1] + dilation_h = dilation[0] + dilation_w = dilation[1] + + dim_sizes = [kernel_h, kernel_w, stride_h, stride_w, dilation_h, dilation_w] + + if any(dim > 4096 for dim in dim_sizes): + return False + + # kernelH * kernelW * inpC <= 65535 + inp_node_shape = ( + inp_node.meta["val"].shape if hasattr(inp_node, "meta") else inp_node.shape + ) + inp_channels = inp_node_shape[1] + + if kernel_h * kernel_w * inp_channels > 65535: + return False + + return True + @staticmethod def _is_supported_on_target( node: Node, @@ -56,62 +147,66 @@ def _is_supported_on_target( parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ) -> bool: + if custom_delegation_options.use_new_flow_neutron_c: + return ConvolutionConverter._is_supported_on_target_new_flow( + node, parameters_mapping + ) + num_macs = neutron_target_spec.get_num_macs() node_t_params = get_node_tensor_params(node) - weights = node.args[1] - conv_params = ConvParameters( - *ConvolutionConverter._get_convolution_arguments(node) + _, w_node, _, stride, padding, dilation, transposed, _, groups = ( + ConvolutionConverter._get_convolution_arguments(node) ) if node_t_params["batch_size"] != 1: # Only batch size 1 is supported on neutron. return False - if conv_params.transposed: + if transposed: # TransposeConv2d with groups > 1 is not supported # TODO: split into multiple convs with groups = 1 - if conv_params.groups > 1: + if groups > 1: return False - if not node_is_effectively_static_tensor(weights, parameters_mapping): + if not node_is_effectively_static_tensor(w_node, parameters_mapping): # Only supported if the weights are static, because TFLite `TransposeConv` uses permuted # weights. In case the weights are dynamic, a Transpose operator would have to be added, which # is not supported on Neutron. return False # neutron-library/src/utils/NeutronLibraryInterrogation.cpp#876 TransposeConv2DKernelKind if ( - conv_params.dilation != [1, 1] - or conv_params.padding[0] != 0 - or conv_params.padding[1] >= node_t_params["kernel_width"] + dilation != [1, 1] + or padding[0] != 0 + or padding[1] >= node_t_params["kernel_width"] or ( - conv_params.padding[1] != 0 and node_t_params["inp_height"] != 1 + padding[1] != 0 and node_t_params["inp_height"] != 1 ) # Slice added by explicit padding - or conv_params.stride[0] != 1 + or stride[0] != 1 or ( ( - conv_params.stride[1] != node_t_params["kernel_width"] / 2 + stride[1] != node_t_params["kernel_width"] / 2 or node_t_params["out_height"] != 1 ) - and conv_params.stride[1] != node_t_params["kernel_width"] + and stride[1] != node_t_params["kernel_width"] ) - or conv_params.stride[1] % 2 != 0 + or stride[1] % 2 != 0 or node_t_params["inp_channels"] % num_macs != 0 or node_t_params["out_channels"] % num_macs != 0 or node_t_params["kernel_width"] % 2 != 0 or node_t_params["kernel_height"] != 1 ): return False - elif conv_params.groups == 1: # Regular convolution. + elif groups == 1: # Regular convolution. pass elif conv_utils.group_conv_convertible_as_depthwise( - node, conv_params.groups + node, groups ): # Depthwise convolution. # Only supported if the weights are static, because TFLite `DepthwiseConv2D` uses permuted # weights. In case the weights are dynamic, a Transpose operator would have to be added, which # is not supported on Neutron. - if not node_is_effectively_static_tensor(weights, parameters_mapping): + if not node_is_effectively_static_tensor(w_node, parameters_mapping): return False elif conv_utils.group_conv_convertible_into_multiple_convolutions( - node, conv_params.groups + node, groups ): # Separable conv. # Requires addition of `Split` and `Concatenation` operators, which are not supported on Neutron. return False @@ -149,10 +244,6 @@ def _is_supported_in_IR( return True - Stride = Padding = Dilation = OutPadding = list[int] - Transposed = bool - Groups = int - def _compute_slicing_params( self, output_shape, explicit_padding ) -> tuple[list[int], list[int]]: @@ -170,14 +261,14 @@ def _compute_slicing_params( @staticmethod def _get_convolution_arguments( conv_node: Node, - ) -> (Stride, Padding, Dilation, Transposed, OutPadding, Groups): - # The arguments of the conv are: - # [x, w, b, stride, padding, dilation, transposed, output padding, groups] - # https://github.com/pytorch/pytorch/blob/v2.6.0/aten/src/ATen/native/Convolution.cpp#L286-L291 - _, _, _, stride, padding, dilation, transposed, out_padding, groups = ( + ) -> ConvolutionArgs: + x, w, b, stride, padding, dilation, transposed, out_padding, groups = ( conv_node.args ) return ( + x, + w, + b, list(stride), list(padding), list(dilation), @@ -380,16 +471,8 @@ def _convert_2d_conv( elif conv_utils.group_conv_convertible_into_multiple_convolutions( t_op, conv_params.groups - ): # Convert to separated `Conv2D`. - t_op.builtin_options = conv_2d_options.Conv2D() - - return conv_utils.create_separated_convolutions_based_on_group( - t_op, - conv_params, - self.builder, - self._convert_unpadded_2D, - conv_utils.conv_op_factory, - ) + ): + raise RuntimeError("Group convolution was not decomposed.") else: # Convert to regular `Conv2D`. @@ -419,7 +502,7 @@ def _convert_2d_conv( def convert(self, node: Node): self.assert_convertible(node) - stride, padding, dilation, transposed, out_padding, groups = ( + _, _, _, stride, padding, dilation, transposed, out_padding, groups = ( self._get_convolution_arguments(node) ) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py index 5580d0ca729..39fba3b5ea7 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py @@ -6,7 +6,6 @@ import numpy as np import pytest import torch - from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) @@ -17,6 +16,8 @@ from executorch.backends.nxp.backend.ir.lib.tflite.BuiltinOperator import ( BuiltinOperator, ) + +from executorch.backends.nxp.tests.dataset_creator import RandomDatasetCreator from executorch.backends.nxp.tests.executorch_pipeline import ( to_edge_program, to_quantized_edge_program, @@ -27,7 +28,17 @@ ToChannelFirstPreprocess, ToChannelLastPreprocess, ) +from executorch.backends.nxp.tests.graph_verifier import DetailedGraphVerifier from executorch.backends.nxp.tests.models import Conv2dModule +from executorch.backends.nxp.tests.nsys_testing import ( + AllCloseOutputComparator, + lower_run_compare, +) +from executorch.backends.nxp.tests.ops_aliases import ( + Convolution, + ExecutorchDelegateCall, + ViewCopy, +) from executorch.exir.dialects._ops import ops as exir_ops from torch.export import ExportedProgram from executorch.backends.nxp.tests.use_qat import * # noqa F403 @@ -477,3 +488,745 @@ def test_conv2d_conversion__depthwise__delegates_without_post_quantization_state nodes = list(edge_program.graph.nodes) assert any(n.target == "lowered_module_0" for n in nodes) + + +class TestConvNewNeutronFlow: + @staticmethod + def _conv_id(ins, oc, ks=3, s=2, d=1, p=0, b=True, g=1): + return ( + f"input_shape={ins}, " + f"out_channels={oc}, " + f"kernel_size={ks}, " + f"stride={s}, " + f"dilation={d}, " + f"padding={p}, " + f"bias={b}, " + f"group={g}" + ) + + @staticmethod + def assert_delegated_and_correct(model, input_shape, mocker, use_qat): + graph_verifier = DetailedGraphVerifier( + mocker, + expected_delegated_ops={Convolution: 1}, + expected_non_delegated_ops={}, + ) + dataset = RandomDatasetCreator(low=-1, high=1) + comparator = AllCloseOutputComparator(atol=1) + + try: + lower_run_compare( + model, + input_shape, + graph_verifier, + dataset, + comparator, + use_new_flow_neutron_c=True, + use_qat=use_qat, + ) + + except AssertionError as e: + if "NPU output doesn't match reference" in str(e): + pytest.xfail( + "AIR-14660: The `conv` node was delegated, but does not compute correctly." + ) + + else: + raise + + except RuntimeError as e: + if ( + "Model converted with neutron-converter does not contain a NeutronGraph node." + in str(e) + ): + pytest.xfail( + "AIR-14661: The `conv` node should have been delegated, but is not." + ) + + else: + raise + + @staticmethod + def assert_not_delegated(model, input_shape): + delegated_ep = to_quantized_edge_program( + model, input_shape, use_new_flow_neutron_c=True + ).exported_program() + + # Make sure the `convolution` was NOT delegated. + assert not graph_contains_any_of_ops( + delegated_ep.graph, [ExecutorchDelegateCall] + ) + assert graph_contains_any_of_ops(delegated_ep.graph, [Convolution]) + + @pytest.mark.parametrize( + "input_shape, out_channels", + [ + pytest.param( + ins := (1, 8, 16, 24), + oc := 8, + id="basic inference: " + _conv_id(ins, oc), + ), + pytest.param( + ins := (8, 16, 8, 32), + oc := 16, + id="basic inference: " + _conv_id(ins, oc), + ), + pytest.param( + ins := (16, 8, 32, 64), + oc := 32, + id="basic inference: " + _conv_id(ins, oc), + ), + pytest.param( + ins := (1, 8, 32, 64), + oc := 16, + id="basic inference: " + _conv_id(ins, oc), + ), + pytest.param( + ins := (1, 32, 48, 8), + oc := 24, + id="basic inference: " + _conv_id(ins, oc), + ), + ], + ) + def test__basic_nsys_inference(self, input_shape, out_channels, use_qat, mocker): + in_channels = input_shape[1] + model = Conv2dModule(in_channels=in_channels, out_channels=out_channels) + + self.assert_delegated_and_correct(model, input_shape, mocker, use_qat) + + @pytest.mark.parametrize( + "input_shape", + [ + pytest.param( + ins := (1, 8, 16, 24), + id="basic inference, depthwise: " + _conv_id(ins, ins[1], g=ins[1]), + ), + pytest.param( + ins := (8, 16, 8, 32), + id="basic inference, depthwise: " + _conv_id(ins, ins[1], g=ins[1]), + ), + pytest.param( + ins := (16, 8, 32, 64), + id="basic inference, depthwise: " + _conv_id(ins, ins[1], g=ins[1]), + ), + pytest.param( + ins := (1, 16, 32, 64), + id="basic inference, depthwise: " + _conv_id(ins, ins[1], g=ins[1]), + ), + pytest.param( + ins := (1, 32, 48, 8), + id="basic inference, depthwise: " + _conv_id(ins, ins[1], g=ins[1]), + ), + ], + ) + def test__basic_nsys_inference_depthwise(self, input_shape, use_qat, mocker): + out_channels = input_shape[1] + group = input_shape[1] + model = Conv2dModule( + in_channels=input_shape[1], out_channels=out_channels, group=group + ) + + self.assert_delegated_and_correct(model, input_shape, mocker, use_qat) + + @pytest.mark.parametrize( + "input_shape, out_channels", + [ + pytest.param( + ins := (1, 3, 7, 14), + oc := 3, + id="unusual shape inference: " + _conv_id(ins, oc), + ), + pytest.param( + ins := (2, 3, 13, 27), + oc := 7, + id="unusual shape inference: " + _conv_id(ins, oc), + ), + pytest.param( + ins := (3, 7, 3, 14), + oc := 4, + id="unusual shape inference: " + _conv_id(ins, oc), + ), + pytest.param( + ins := (1, 7, 7, 21), + oc := 1, + id="unusual shape inference: " + _conv_id(ins, oc), + ), + pytest.param( + ins := (7, 7, 7, 7), + oc := 10, + id="unusual shape inference: " + _conv_id(ins, oc), + ), + pytest.param( + ins := (4, 21, 13, 17), + oc := 27, + id="unusual shape inference: " + _conv_id(ins, oc), + ), + ], + ) + def test__basic_nsys_inference__unusual_shapes( + self, input_shape, out_channels, use_qat, mocker + ): + model = Conv2dModule(in_channels=input_shape[1], out_channels=out_channels) + + self.assert_delegated_and_correct(model, input_shape, mocker, use_qat) + + @pytest.mark.parametrize( + "input_shape", + [ + pytest.param( + ins := (1, 3, 7, 14), + id="unusual shape inference, depthwise: " + + _conv_id(ins, ins[1], g=ins[1]), + ), + pytest.param( + ins := (2, 3, 13, 27), + id="unusual shape inference, depthwise: " + + _conv_id(ins, ins[1], g=ins[1]), + ), + pytest.param( + ins := (3, 7, 3, 14), + id="unusual shape inference, depthwise: " + + _conv_id(ins, ins[1], g=ins[1]), + ), + pytest.param( + ins := (1, 7, 7, 21), + id="unusual shape inference, depthwise: " + + _conv_id(ins, ins[1], g=ins[1]), + ), + pytest.param( + ins := (7, 7, 7, 7), + id="unusual shape inference, depthwise: " + + _conv_id(ins, ins[1], g=ins[1]), + ), + pytest.param( + ins := (4, 21, 13, 17), + id="unusual shape inference, depthwise: " + + _conv_id(ins, ins[1], g=ins[1]), + ), + ], + ) + def test__basic_nsys_inference_depthwise__unusual_shapes( + self, input_shape, use_qat, mocker + ): + out_channels = input_shape[1] + group = input_shape[1] + + model = Conv2dModule( + in_channels=input_shape[1], out_channels=out_channels, group=group + ) + + self.assert_delegated_and_correct(model, input_shape, mocker, use_qat) + + @pytest.mark.parametrize( + "input_shape, out_channels", + [ + pytest.param( + ins := (21, 4, 7), + oc := 45, + id="`conv2d` implicit batch: " + _conv_id(ins, oc), + ), + ], + ) + def test__basic_nsys_inference__implicit_batch( + self, input_shape, out_channels, use_qat, mocker + ): + in_channels = input_shape[0] + + model = Conv2dModule(in_channels=in_channels, out_channels=out_channels) + + # `view_copy` is inserted to convert to explicit batch + graph_verifier = DetailedGraphVerifier( + mocker, + expected_delegated_ops={Convolution: 1}, + expected_non_delegated_ops={ViewCopy: 2}, + ) + dataset = RandomDatasetCreator(low=-1, high=1) + comparator = AllCloseOutputComparator(atol=1) + + lower_run_compare( + model, + input_shape, + graph_verifier, + dataset, + comparator, + use_new_flow_neutron_c=True, + use_qat=use_qat, + ) + + @pytest.mark.parametrize( + "input_shape, out_channels, kernel_size, stride, dilation", + [ + pytest.param( + ins := (2, 3, 1, 4100), + oc := 7, + ks := (1, 4096), + s := 1, + d := 1, + id=f"bounds of kernel width: {_conv_id(ins, oc, ks=ks, s=s, d=d)}", + ), + pytest.param( + ins := (3, 3, 4100, 1), + oc := 9, + ks := (4096, 1), + s := 1, + d := 1, + id=f"bounds of kernel height: {_conv_id(ins, oc, ks=ks, s=s, d=d)}", + ), + pytest.param( + ins := (4, 3, 3, 8500), + oc := 5, + ks := 3, + s := (1, 4096), + d := 1, + id=f"bounds of stride width: {_conv_id(ins, oc, ks=ks, s=s, d=d)}", + ), + pytest.param( + ins := (2, 3, 8500, 3), + oc := 11, + ks := 3, + s := (4096, 1), + d := 1, + id=f"bounds of stride height: {_conv_id(ins, oc, ks=ks, s=s, d=d)}", + ), + pytest.param( + ins := (3, 3, 3, 8500), + oc := 9, + ks := 3, + s := 1, + d := (1, 4096), + id=f"bounds of dilation width: {_conv_id(ins, oc, ks=ks, s=s, d=d)}", + ), + pytest.param( + ins := (4, 3, 8500, 3), + oc := 7, + ks := 3, + s := 1, + d := (4096, 1), + id=f"bounds of dilation height: {_conv_id(ins, oc, ks=ks, s=s, d=d)}", + ), + pytest.param( + ins := (2, 15, 91, 91), + oc := 13, + ks := (61, 71), + s := 1, + d := 1, + id=f"bounds of kernel_h * kernel_w * input_channels: {_conv_id(ins, oc, ks=ks, s=s, d=d)}", + ), + ], + ) + def test__basic_nsys_inference__big( + self, input_shape, out_channels, kernel_size, stride, dilation, use_qat, mocker + ): + model = Conv2dModule( + in_channels=input_shape[1], + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + ) + + self.assert_delegated_and_correct(model, input_shape, mocker, use_qat) + + @pytest.mark.parametrize( + "input_shape, kernel_size, stride, dilation", + [ + pytest.param( + ins := (2, 3, 1, 4100), + ks := (1, 4096), + s := 1, + d := 1, + id=f"bounds of kernel width: {_conv_id(ins, ins[1], ks=ks, s=s, d=d, g=ins[1])}", + ), + pytest.param( + ins := (3, 3, 4100, 1), + ks := (4096, 1), + s := 1, + d := 1, + id=f"bounds of kernel height: {_conv_id(ins, ins[1], ks=ks, s=s, d=d, g=ins[1])}", + ), + pytest.param( + ins := (2, 3, 3, 8500), + ks := 3, + s := (1, 4096), + d := 1, + id=f"bounds of stride width: {_conv_id(ins, ins[1], ks=ks, s=s, d=d, g=ins[1])}", + ), + pytest.param( + ins := (4, 3, 8500, 3), + ks := 3, + s := (4096, 1), + d := 1, + id=f"bounds of stride height: {_conv_id(ins, ins[1], ks=ks, s=s, d=d, g=ins[1])}", + ), + pytest.param( + ins := (4, 3, 3, 8500), + ks := 3, + s := 1, + d := (1, 4096), + id=f"bounds of dilation width: {_conv_id(ins, ins[1], ks=ks, s=s, d=d, g=ins[1])}", + ), + pytest.param( + ins := (3, 3, 8500, 3), + ks := 3, + s := 1, + d := (4096, 1), + id=f"bounds of dilation height: {_conv_id(ins, ins[1], ks=ks, s=s, d=d, g=ins[1])}", + ), + pytest.param( + ins := (2, 15, 91, 91), + ks := (61, 71), + s := 1, + d := 1, + id=f"bounds of kernel_h * kernel_w * input_channels: {_conv_id(ins, ins[1], ks=ks, s=s, d=d, g=ins[1])}", + ), + ], + ) + def test__basic_nsys_inference_depthwise__big( + self, input_shape, kernel_size, stride, dilation, use_qat, mocker + ): + out_channels = input_shape[1] + group = input_shape[1] + + model = Conv2dModule( + in_channels=input_shape[1], + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + group=group, + ) + + self.assert_delegated_and_correct(model, input_shape, mocker, use_qat) + + @pytest.mark.parametrize( + "input_shape, out_channels, kernel_size, stride, dilation, padding, bias", + [ + pytest.param( + ins := (1, 8, 32, 32), + oc := 7, + ks := (5, 3), + s := (2, 1), + d := (1, 2), + p := (2, 1), + b := True, + id=f"some params not default: {_conv_id(ins, oc, ks=ks, s=s, d=d, p=p, b=b)}", + ), + pytest.param( + ins := (2, 7, 31, 17), + oc := 9, + ks := (7, 7), + s := (3, 2), + d := (2, 1), + p := (3, 3), + b := False, + id=f"some params not default: {_conv_id(ins, oc, ks=ks, s=s, d=d, p=p, b=b)}", + ), + pytest.param( + ins := (2, 12, 28, 28), + oc := 11, + ks := (3, 5), + s := (2, 2), + d := (2, 2), + p := (1, 2), + b := True, + id=f"some params not default: {_conv_id(ins, oc, ks=ks, s=s, d=d, p=p, b=b)}", + ), + pytest.param( + ins := (3, 2, 40, 20), + oc := 13, + ks := (1, 5), + s := (1, 2), + d := (3, 1), + p := (0, 2), + b := False, + id=f"some params not default: {_conv_id(ins, oc, ks=ks, s=s, d=d, p=p, b=b)}", + ), + pytest.param( + ins := (4, 6, 30, 30), + oc := 5, + ks := (3, 3), + s := (2, 2), + d := (1, 1), + p := (2, 2), + b := True, + id=f"some params not default: {_conv_id(ins, oc, ks=ks, s=s, d=d, p=p, b=b)}", + ), + pytest.param( + ins := (3, 12, 7, 7), + oc := 7, + ks := (5, 5), + s := (1, 3), + d := (1, 2), + p := (2, 4), + b := False, + id=f"some params not default: {_conv_id(ins, oc, ks=ks, s=s, d=d, p=p, b=b)}", + ), + pytest.param( + ins := (1, 4, 15, 15), + oc := 9, + ks := (2, 2), + s := (2, 2), + d := (2, 2), + p := (1, 1), + b := True, + id=f"some params not default: {_conv_id(ins, oc, ks=ks, s=s, d=d, p=p, b=b)}", + ), + ], + ) + def test__nsys_inference__non_default_params( + self, + input_shape, + out_channels, + kernel_size, + stride, + dilation, + padding, + bias, + use_qat, + mocker, + ): + model = Conv2dModule( + in_channels=input_shape[1], + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding, + bias=bias, + ) + + self.assert_delegated_and_correct(model, input_shape, mocker, use_qat) + + @pytest.mark.parametrize( + "input_shape, kernel_size, stride, dilation, padding, bias", + [ + pytest.param( + ins := (1, 8, 32, 32), + ks := (5, 3), + s := (2, 1), + d := (1, 2), + p := (2, 1), + b := True, + id=f"some params not default: {_conv_id(ins, ins[1], ks=ks, s=s, d=d, p=p, b=b, g=ins[1])}", + ), + pytest.param( + ins := (3, 7, 31, 17), + ks := (7, 7), + s := (3, 2), + d := (2, 1), + p := (3, 3), + b := False, + id=f"some params not default: {_conv_id(ins, ins[1], ks=ks, s=s, d=d, p=p, b=b, g=ins[1])}", + ), + pytest.param( + ins := (2, 12, 28, 28), + ks := (3, 5), + s := (2, 2), + d := (2, 2), + p := (1, 2), + b := True, + id=f"some params not default: {_conv_id(ins, ins[1], ks=ks, s=s, d=d, p=p, b=b, g=ins[1])}", + ), + pytest.param( + ins := (3, 2, 40, 20), + ks := (1, 5), + s := (1, 2), + d := (3, 1), + p := (0, 2), + b := False, + id=f"some params not default: {_conv_id(ins, ins[1], ks=ks, s=s, d=d, p=p, b=b, g=ins[1])}", + ), + pytest.param( + ins := (4, 6, 30, 30), + ks := (3, 3), + s := (2, 2), + d := (1, 1), + p := (2, 2), + b := True, + id=f"some params not default: {_conv_id(ins, ins[1], ks=ks, s=s, d=d, p=p, b=b, g=ins[1])}", + ), + pytest.param( + ins := (3, 12, 7, 7), + ks := (5, 5), + s := (1, 3), + d := (1, 2), + p := (2, 4), + b := False, + id=f"some params not default: {_conv_id(ins, ins[1], ks=ks, s=s, d=d, p=p, b=b, g=ins[1])}", + ), + pytest.param( + ins := (1, 4, 15, 15), + ks := (2, 2), + s := (2, 2), + d := (2, 2), + p := (1, 1), + b := True, + id=f"some params not default: {_conv_id(ins, ins[1], ks=ks, s=s, d=d, p=p, b=b, g=ins[1])}", + ), + ], + ) + def test__nsys_inference_depthwise__non_default_params( + self, input_shape, kernel_size, stride, dilation, padding, bias, use_qat, mocker + ): + out_channels = input_shape[1] + group = input_shape[1] + + model = Conv2dModule( + in_channels=input_shape[1], + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding, + bias=bias, + group=group, + ) + + self.assert_delegated_and_correct(model, input_shape, mocker, use_qat) + + @pytest.mark.parametrize( + "input_shape, out_channels, kernel_size, stride, dilation", + [ + pytest.param( + ins := (3, 7, 5000, 11), + oc := 7, + ks := (4097, 1), + s := 1, + d := 1, + id=f"kernel height too big: {_conv_id(ins, oc, ks=ks, s=s, d=d)}", + ), + pytest.param( + ins := (3, 7, 13, 5000), + oc := 9, + ks := (1, 4097), + s := 1, + d := 1, + id=f"kernel width too big: {_conv_id(ins, oc, ks=ks, s=s, d=d)}", + ), + pytest.param( + ins := (3, 7, 5000, 11), + oc := 11, + ks := 3, + s := (4097, 1), + d := 1, + id=f"stride height too big: {_conv_id(ins, oc, ks=ks, s=s, d=d)}", + ), + pytest.param( + ins := (3, 7, 13, 5000), + oc := 5, + ks := 3, + s := (1, 4097), + d := 1, + id=f"stride width too big: {_conv_id(ins, oc, ks=ks, s=s, d=d)}", + ), + pytest.param( + ins := (3, 7, 8500, 11), + oc := 7, + ks := 3, + s := 1, + d := (4097, 1), + id=f"dilation height too big: {_conv_id(ins, oc, ks=ks, s=s, d=d)}", + ), + pytest.param( + ins := (3, 7, 13, 8500), + oc := 9, + ks := 3, + s := 1, + d := (1, 4097), + id=f"dilation width too big: {_conv_id(ins, oc, ks=ks, s=s, d=d)}", + ), + pytest.param( + ins := (3, 113, 123, 133), + oc := 11, + ks := (41, 15), + s := 1, + d := 1, + id=f"kernel_h * kernel_w * input_channels too big: {_conv_id(ins, oc, ks=ks, s=s, d=d)}", + ), + ], + ) + def test__non_delegation( + self, input_shape, out_channels, kernel_size, stride, dilation + ): + in_channels = input_shape[1] + + model = Conv2dModule( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + ) + + self.assert_not_delegated(model, input_shape) + + @pytest.mark.parametrize( + "input_shape, kernel_size, stride, dilation", + [ + pytest.param( + ins := (3, 7, 5000, 11), + ks := (4097, 1), + s := 1, + d := 1, + id=f"kernel height too big, depthwise: {_conv_id(ins, ins[1], ks=ks, s=s, d=d, g=ins[1])}", + ), + pytest.param( + ins := (3, 7, 13, 5000), + ks := (1, 4097), + s := 1, + d := 1, + id=f"kernel width too big, depthwise: {_conv_id(ins, ins[1], ks=ks, s=s, d=d, g=ins[1])}", + ), + pytest.param( + ins := (3, 7, 5000, 11), + ks := 3, + s := (4097, 1), + d := 1, + id=f"stride height too big, depthwise: {_conv_id(ins, ins[1], ks=ks, s=s, d=d, g=ins[1])}", + ), + pytest.param( + ins := (3, 7, 13, 5000), + ks := 3, + s := (1, 4097), + d := 1, + id=f"stride width too big, depthwise: {_conv_id(ins, ins[1], ks=ks, s=s, d=d, g=ins[1])}", + ), + pytest.param( + ins := (3, 7, 8500, 11), + ks := 3, + s := 1, + d := (4097, 1), + id=f"dilation height too big, depthwise: {_conv_id(ins, ins[1], ks=ks, s=s, d=d, g=ins[1])}", + ), + pytest.param( + ins := (3, 7, 13, 8500), + ks := 3, + s := 1, + d := (1, 4097), + id=f"dilation width too big, depthwise: {_conv_id(ins, ins[1], ks=ks, s=s, d=d, g=ins[1])}", + ), + pytest.param( + ins := (3, 113, 123, 133), + ks := (41, 15), + s := 1, + d := 1, + id=f"kernel_h * kernel_w * input_channels too big, depthwise: {_conv_id(ins, ins[1], ks=ks, s=s, d=d, g=ins[1])}", + ), + ], + ) + def test__non_delegation_depthwise( + self, input_shape, kernel_size, stride, dilation + ): + out_channels = input_shape[1] + group = input_shape[1] + + model = Conv2dModule( + in_channels=input_shape[1], + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + group=group, + ) + + self.assert_not_delegated(model, input_shape) diff --git a/backends/nxp/tests/model_output_comparator.py b/backends/nxp/tests/model_output_comparator.py index f0dd7cd2d60..1092bae32f2 100644 --- a/backends/nxp/tests/model_output_comparator.py +++ b/backends/nxp/tests/model_output_comparator.py @@ -97,7 +97,9 @@ def compare_sample(self, sample_dir, cpu_output_tensors, npu_output_tensors): print( f"NPU output doesn't match reference. Maximum absolute difference: {max_diff}" ) - assert all_close + assert ( + all_close + ), f"NPU output doesn't match reference. Maximum absolute difference: {max_diff}" def _default_postprocess_fn(outputs: np.ndarray, _: str):