Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 30 additions & 12 deletions backends/nxp/aten_passes/split_group_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,15 +111,27 @@ def _create_convolution_node(self, conv_target, args: tuple) -> Node:
# Compute the output shapes for the `convolution`, and assign the `val` meta.
with FakeTensorMode() as mode:
input_shapes = [
input_.meta["val"].shape if hasattr(input_, "meta") else input_.shape
(
input_.meta["val"].shape
if hasattr(input_, "meta")
else input_.shape if input_ is not None else None
)
for input_ in args[:3]
]
input_dtypes = [
input_.meta["val"].dtype if hasattr(input_, "meta") else input_.dtype
(
input_.meta["val"].dtype
if hasattr(input_, "meta")
else input_.dtype if input_ is not None else None
)
for input_ in args[:3]
]
fake_inputs = [
FakeTensor.from_tensor(torch.empty(shape, dtype=dtype), mode)
(
FakeTensor.from_tensor(torch.empty(shape, dtype=dtype), mode)
if shape is not None and dtype is not None
else None
)
for shape, dtype in zip(input_shapes, input_dtypes)
]
output = conv_target(*fake_inputs, *args[3:])
Expand Down Expand Up @@ -211,7 +223,9 @@ def _is_conv(node_: Node):

w_data = self._get_tensor_constant_from_node(w)
b_data = self._get_tensor_constant_from_node(b)
if w_data is None or b_data is None:

with_bias = b is not None
if w_data is None or (b_data is None and with_bias):
continue # Only the standard case with static weights and bias is supported.

# Create a `split` node to split the main input.
Expand All @@ -227,10 +241,9 @@ def _is_conv(node_: Node):
for i in range(groups)
]

# Split the weights and bias, across dimension `0`, slices of size `weight_split_size`.
# Split the weights across dimension `0`, slices of size `weight_split_size`.
weight_split_size = w.meta["val"].shape[0] // groups
split_weights_data = torch.split(w_data, weight_split_size, 0)
split_bias_data = torch.split(b_data, weight_split_size, 0)

# Turn the weights and biases into parameter nodes containing the data.
# Use a different name for every parameter. The function internally ensures the name's uniqueness, but
Expand All @@ -241,12 +254,17 @@ def _is_conv(node_: Node):
)
for i, weight_data in enumerate(split_weights_data)
]
split_bias_nodes = [
self._create_parameter_node_for_data(
bias_data, b.name + f"_{i}_", split_node
)
for i, bias_data in enumerate(split_bias_data)
]

if with_bias:
split_bias_data = torch.split(b_data, weight_split_size, 0)
split_bias_nodes = [
self._create_parameter_node_for_data(
bias_data, b.name + f"_{i}_", split_node
)
for i, bias_data in enumerate(split_bias_data)
]
else:
split_bias_nodes = [None] * len(split_weight_nodes)

# Create the `conv` nodes.
with self.module.graph.inserting_after(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,70 +48,165 @@
from torch.nn import Parameter


# The arguments of the conv are:
# [x, w, b, stride, padding, dilation, transposed, output padding, groups]
# https://github.com/pytorch/pytorch/blob/v2.6.0/aten/src/ATen/native/Convolution.cpp#L286-L291
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit:
This link only describes the arguments after x, w and b, so for example it doesn't mention that the bias is optional (as your code suggests).
Perhaps this could be a better source for this information: https://docs.pytorch.org/docs/main/user_guide/torch_compiler/torch.compiler_ir.html (mostly for future PRs)

convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups) -> Tensor

Stride = Padding = Dilation = OutPadding = list[int]
Transposed = bool
Groups = int
ConvolutionArgs = tuple[
Node, Node, Node | None, Stride, Padding, Dilation, Transposed, OutPadding, Groups
]


class ConvolutionConverter(NodeConverter):
@staticmethod
def _is_supported_on_target_new_flow(
node: Node,
parameters_mapping: dict[str, Parameter],
) -> bool:
(
inp_node,
w_node,
b_node,
stride,
padding,
dilation,
transposed,
_,
groups,
) = ConvolutionConverter._get_convolution_arguments(node)

# Input must be INT8/UINT8
# Output must be INT8/UINT8
inp_out_supported_types = [torch.int8, torch.uint8]
if not NodeConverter.uses_quantization_type_for_io(
node, inp_out_supported_types, [0], [0]
):
return False

# Weights must be INT8
w_supported_types = [torch.int8]
if not NodeConverter.uses_quantization_type_for_io(
node, w_supported_types, [1], []
):
return False

# Bias must be INT32
if b_node is not None:
b_supported_types = [torch.int32]
if not NodeConverter.uses_quantization_type_for_io(
node, b_supported_types, [2], []
):
return False

# Weights must be constant
if not node_is_effectively_static_tensor(w_node, parameters_mapping):
return False

# Bias must be constant (if present)
if b_node is not None and not node_is_effectively_static_tensor(
b_node, parameters_mapping
):
return False

# kernelH <= 4096, kernelW <= 4096
# strideH <= 4096, strideW <= 4096
# dilationH <= 4096, dilationW <= 4096
w_node_shape = (
w_node.meta["val"].shape if hasattr(w_node, "meta") else w_node.shape
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have you encountered a case where the weights don't have the meta attribute?
I have not seen such a case yet and I believe we rely meta being available in our codebase.

)

kernel_h = w_node_shape[2]
kernel_w = w_node_shape[3]
stride_h = stride[0]
stride_w = stride[1]
dilation_h = dilation[0]
dilation_w = dilation[1]

dim_sizes = [kernel_h, kernel_w, stride_h, stride_w, dilation_h, dilation_w]

if any(dim > 4096 for dim in dim_sizes):
return False

# kernelH * kernelW * inpC <= 65535
inp_node_shape = (
inp_node.meta["val"].shape if hasattr(inp_node, "meta") else inp_node.shape
)
inp_channels = inp_node_shape[1]

if kernel_h * kernel_w * inp_channels > 65535:
return False

return True

@staticmethod
def _is_supported_on_target(
node: Node,
neutron_target_spec: NeutronTargetSpec,
parameters_mapping: dict[str, Parameter],
custom_delegation_options: CustomDelegationOptions,
) -> bool:
if custom_delegation_options.use_new_flow_neutron_c:
return ConvolutionConverter._is_supported_on_target_new_flow(
node, parameters_mapping
)

num_macs = neutron_target_spec.get_num_macs()
node_t_params = get_node_tensor_params(node)
weights = node.args[1]
conv_params = ConvParameters(
*ConvolutionConverter._get_convolution_arguments(node)
_, w_node, _, stride, padding, dilation, transposed, _, groups = (
ConvolutionConverter._get_convolution_arguments(node)
)

if node_t_params["batch_size"] != 1:
# Only batch size 1 is supported on neutron.
return False

if conv_params.transposed:
if transposed:
# TransposeConv2d with groups > 1 is not supported
# TODO: split into multiple convs with groups = 1
if conv_params.groups > 1:
if groups > 1:
return False
if not node_is_effectively_static_tensor(weights, parameters_mapping):
if not node_is_effectively_static_tensor(w_node, parameters_mapping):
# Only supported if the weights are static, because TFLite `TransposeConv` uses permuted
# weights. In case the weights are dynamic, a Transpose operator would have to be added, which
# is not supported on Neutron.
return False
# neutron-library/src/utils/NeutronLibraryInterrogation.cpp#876 TransposeConv2DKernelKind
if (
conv_params.dilation != [1, 1]
or conv_params.padding[0] != 0
or conv_params.padding[1] >= node_t_params["kernel_width"]
dilation != [1, 1]
or padding[0] != 0
or padding[1] >= node_t_params["kernel_width"]
or (
conv_params.padding[1] != 0 and node_t_params["inp_height"] != 1
padding[1] != 0 and node_t_params["inp_height"] != 1
) # Slice added by explicit padding
or conv_params.stride[0] != 1
or stride[0] != 1
or (
(
conv_params.stride[1] != node_t_params["kernel_width"] / 2
stride[1] != node_t_params["kernel_width"] / 2
or node_t_params["out_height"] != 1
)
and conv_params.stride[1] != node_t_params["kernel_width"]
and stride[1] != node_t_params["kernel_width"]
)
or conv_params.stride[1] % 2 != 0
or stride[1] % 2 != 0
or node_t_params["inp_channels"] % num_macs != 0
or node_t_params["out_channels"] % num_macs != 0
or node_t_params["kernel_width"] % 2 != 0
or node_t_params["kernel_height"] != 1
):
return False
elif conv_params.groups == 1: # Regular convolution.
elif groups == 1: # Regular convolution.
pass
elif conv_utils.group_conv_convertible_as_depthwise(
node, conv_params.groups
node, groups
): # Depthwise convolution.
# Only supported if the weights are static, because TFLite `DepthwiseConv2D` uses permuted
# weights. In case the weights are dynamic, a Transpose operator would have to be added, which
# is not supported on Neutron.
if not node_is_effectively_static_tensor(weights, parameters_mapping):
if not node_is_effectively_static_tensor(w_node, parameters_mapping):
return False
elif conv_utils.group_conv_convertible_into_multiple_convolutions(
node, conv_params.groups
node, groups
): # Separable conv.
# Requires addition of `Split` and `Concatenation` operators, which are not supported on Neutron.
return False
Expand Down Expand Up @@ -149,10 +244,6 @@ def _is_supported_in_IR(

return True

Stride = Padding = Dilation = OutPadding = list[int]
Transposed = bool
Groups = int

def _compute_slicing_params(
self, output_shape, explicit_padding
) -> tuple[list[int], list[int]]:
Expand All @@ -170,14 +261,14 @@ def _compute_slicing_params(
@staticmethod
def _get_convolution_arguments(
conv_node: Node,
) -> (Stride, Padding, Dilation, Transposed, OutPadding, Groups):
# The arguments of the conv are:
# [x, w, b, stride, padding, dilation, transposed, output padding, groups]
# https://github.com/pytorch/pytorch/blob/v2.6.0/aten/src/ATen/native/Convolution.cpp#L286-L291
_, _, _, stride, padding, dilation, transposed, out_padding, groups = (
) -> ConvolutionArgs:
x, w, b, stride, padding, dilation, transposed, out_padding, groups = (
conv_node.args
)
return (
x,
w,
b,
list(stride),
list(padding),
list(dilation),
Expand Down Expand Up @@ -380,16 +471,8 @@ def _convert_2d_conv(

elif conv_utils.group_conv_convertible_into_multiple_convolutions(
t_op, conv_params.groups
): # Convert to separated `Conv2D`.
t_op.builtin_options = conv_2d_options.Conv2D()

return conv_utils.create_separated_convolutions_based_on_group(
t_op,
conv_params,
self.builder,
self._convert_unpadded_2D,
conv_utils.conv_op_factory,
)
):
raise RuntimeError("Group convolution was not decomposed.")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit:
Consider mentioning the NXP backend in the message. E.g. raise RuntimeError("NXP backend: Group convolution was not decomposed.").


else:
# Convert to regular `Conv2D`.
Expand Down Expand Up @@ -419,7 +502,7 @@ def _convert_2d_conv(
def convert(self, node: Node):
self.assert_convertible(node)

stride, padding, dilation, transposed, out_padding, groups = (
_, _, _, stride, padding, dilation, transposed, out_padding, groups = (
self._get_convolution_arguments(node)
)

Expand Down
Loading
Loading