Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 32 additions & 13 deletions backends/nxp/aten_passes/split_group_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,15 +111,27 @@ def _create_convolution_node(self, conv_target, args: tuple) -> Node:
# Compute the output shapes for the `convolution`, and assign the `val` meta.
with FakeTensorMode() as mode:
input_shapes = [
input_.meta["val"].shape if hasattr(input_, "meta") else input_.shape
(
input_.meta["val"].shape
if hasattr(input_, "meta")
else input_.shape if input_ is not None else None
)
for input_ in args[:3]
]
input_dtypes = [
input_.meta["val"].dtype if hasattr(input_, "meta") else input_.dtype
(
input_.meta["val"].dtype
if hasattr(input_, "meta")
else input_.dtype if input_ is not None else None
)
for input_ in args[:3]
]
fake_inputs = [
FakeTensor.from_tensor(torch.empty(shape, dtype=dtype), mode)
(
FakeTensor.from_tensor(torch.empty(shape, dtype=dtype), mode)
if shape is not None and dtype is not None
else None
)
for shape, dtype in zip(input_shapes, input_dtypes)
]
output = conv_target(*fake_inputs, *args[3:])
Expand Down Expand Up @@ -211,8 +223,11 @@ def _is_conv(node_: Node):

w_data = self._get_tensor_constant_from_node(w)
b_data = self._get_tensor_constant_from_node(b)
if w_data is None or b_data is None:
continue # Only the standard case with static weights and bias is supported.

with_bias = b is not None
# Only the standard case with static weights and static bias (or bias=False) is supported.
if w_data is None or (b_data is None and with_bias):
continue
Comment thread
novak-vaclav marked this conversation as resolved.
Comment thread
novak-vaclav marked this conversation as resolved.
Comment on lines 224 to +230

# Create a `split` node to split the main input.
# Split across dimension `1` (channels), `groups` slices of size `input_split_size`.
Expand All @@ -227,10 +242,9 @@ def _is_conv(node_: Node):
for i in range(groups)
]

# Split the weights and bias, across dimension `0`, slices of size `weight_split_size`.
# Split the weights across dimension `0`, slices of size `weight_split_size`.
weight_split_size = w.meta["val"].shape[0] // groups
split_weights_data = torch.split(w_data, weight_split_size, 0)
split_bias_data = torch.split(b_data, weight_split_size, 0)

# Turn the weights and biases into parameter nodes containing the data.
# Use a different name for every parameter. The function internally ensures the name's uniqueness, but
Expand All @@ -241,12 +255,17 @@ def _is_conv(node_: Node):
)
for i, weight_data in enumerate(split_weights_data)
]
split_bias_nodes = [
self._create_parameter_node_for_data(
bias_data, b.name + f"_{i}_", split_node
)
for i, bias_data in enumerate(split_bias_data)
]

if with_bias:
split_bias_data = torch.split(b_data, weight_split_size, 0)
split_bias_nodes = [
self._create_parameter_node_for_data(
bias_data, b.name + f"_{i}_", split_node
)
for i, bias_data in enumerate(split_bias_data)
]
else:
split_bias_nodes = [None] * len(split_weight_nodes)

# Create the `conv` nodes.
with self.module.graph.inserting_after(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,78 +48,171 @@
from torch.nn import Parameter


# The arguments of the conv are:
# convolution(
# Tensor input, Tensor weight, Tensor? bias,
# SymInt[] stride, SymInt[] padding, SymInt[] dilation,
# bool transposed, SymInt[] output_padding, SymInt groups
# ) -> Tensor
Stride = Padding = Dilation = OutPadding = list[int]
Transposed = bool
Groups = int
ConvolutionArgs = tuple[
Node, Node, Node | None, Stride, Padding, Dilation, Transposed, OutPadding, Groups
]


class ConvolutionConverter(NodeConverter):
@staticmethod
def _is_supported_on_target(
def _is_supported_on_target_regular_conv(
    node: Node,
    parameters_mapping: dict[str, Parameter],
) -> bool:
    """Check a regular (non-transposed) convolution node against the target constraints.

    The node is accepted only if every one of the following holds:
      * input and output use INT8/UINT8 quantization,
      * weights use INT8 quantization and the bias (if present) INT32,
      * weights and the bias (if present) are effectively static tensors,
      * kernel, stride and dilation height/width are all at most 4096,
      * kernelH * kernelW * input channels is at most 65535.

    :param node: the convolution node to check.
    :param parameters_mapping: forwarded to `node_is_effectively_static_tensor`
        when checking the weights and the bias.
    :return: True if all checks pass, False otherwise.
    """
    (
        inp_node,
        w_node,
        b_node,
        stride,
        _padding,
        dilation,
        _transposed,
        _out_padding,
        _groups,
    ) = ConvolutionConverter._get_convolution_arguments(node)
    has_bias = b_node is not None
    uses_io_types = NodeConverter.uses_quantization_type_for_io

    # Input must be INT8/UINT8, output must be INT8/UINT8.
    if not uses_io_types(node, [torch.int8, torch.uint8], [0], [0]):
        return False

    # Weights must be INT8.
    if not uses_io_types(node, [torch.int8], [1], []):
        return False

    # Bias must be INT32 (only checked when a bias is present).
    if has_bias and not uses_io_types(node, [torch.int32], [2], []):
        return False

    # Weights must be constant.
    if not node_is_effectively_static_tensor(w_node, parameters_mapping):
        return False

    # Bias must be constant (if present).
    if has_bias and not node_is_effectively_static_tensor(
        b_node, parameters_mapping
    ):
        return False

    # kernelH <= 4096, kernelW <= 4096
    # strideH <= 4096, strideW <= 4096
    # dilationH <= 4096, dilationW <= 4096
    kernel_shape = w_node.meta["val"].shape
    kernel_h, kernel_w = kernel_shape[2], kernel_shape[3]
    limited_sizes = (
        kernel_h,
        kernel_w,
        stride[0],
        stride[1],
        dilation[0],
        dilation[1],
    )
    for size in limited_sizes:
        if size > 4096:
            return False

    # kernelH * kernelW * inpC <= 65535
    # A 4-dimensional input keeps its channels at index 1; any other rank is
    # read with the channels at index 0.
    inp_shape = inp_node.meta["val"].shape
    channels_dim = 1 if len(inp_shape) == 4 else 0
    inp_channels = inp_shape[channels_dim]

    return not (kernel_h * kernel_w * inp_channels > 65535)

@staticmethod
def _is_supported_on_target_transp_conv(
node: Node,
neutron_target_spec: NeutronTargetSpec,
parameters_mapping: dict[str, Parameter],
custom_delegation_options: CustomDelegationOptions,
) -> bool:
# TODO: EIEX-894 update the requirements of delegation for new Neutron flow
_, w_node, _, stride, padding, dilation, transposed, _, groups = (
ConvolutionConverter._get_convolution_arguments(node)
)

num_macs = neutron_target_spec.get_num_macs()
node_t_params = get_node_tensor_params(node)
weights = node.args[1]
conv_params = ConvParameters(
*ConvolutionConverter._get_convolution_arguments(node)
)

if node_t_params["batch_size"] != 1:
# Only batch size 1 is supported on neutron.
# Only TransposeConv2d with batch size = 1 is supported on neutron.
return False

if conv_params.transposed:
# TransposeConv2d with groups > 1 is not supported
# TODO: split into multiple convs with groups = 1
if conv_params.groups > 1:
return False
if not node_is_effectively_static_tensor(weights, parameters_mapping):
# Only supported if the weights are static, because TFLite `TransposeConv` uses permuted
# weights. In case the weights are dynamic, a Transpose operator would have to be added, which
# is not supported on Neutron.
return False
# neutron-library/src/utils/NeutronLibraryInterrogation.cpp#876 TransposeConv2DKernelKind
if (
conv_params.dilation != [1, 1]
or conv_params.padding[0] != 0
or conv_params.padding[1] >= node_t_params["kernel_width"]
or (
conv_params.padding[1] != 0 and node_t_params["inp_height"] != 1
) # Slice added by explicit padding
or conv_params.stride[0] != 1
or (
(
conv_params.stride[1] != node_t_params["kernel_width"] / 2
or node_t_params["out_height"] != 1
)
and conv_params.stride[1] != node_t_params["kernel_width"]
)
or conv_params.stride[1] % 2 != 0
or node_t_params["inp_channels"] % num_macs != 0
or node_t_params["out_channels"] % num_macs != 0
or node_t_params["kernel_width"] % 2 != 0
or node_t_params["kernel_height"] != 1
):
return False
elif conv_params.groups == 1: # Regular convolution.
pass
elif conv_utils.group_conv_convertible_as_depthwise(
node, conv_params.groups
): # Depthwise convolution.
# Only supported if the weights are static, because TFLite `DepthwiseConv2D` uses permuted
# TransposeConv2d with groups > 1 is not supported
# TODO: split into multiple convs with groups = 1
if groups > 1:
return False
if not node_is_effectively_static_tensor(w_node, parameters_mapping):
# Only supported if the weights are static, because TFLite `TransposeConv` uses permuted
# weights. In case the weights are dynamic, a Transpose operator would have to be added, which
# is not supported on Neutron.
if not node_is_effectively_static_tensor(weights, parameters_mapping):
return False
elif conv_utils.group_conv_convertible_into_multiple_convolutions(
node, conv_params.groups
): # Separable conv.
# Requires addition of `Split` and `Concatenation` operators, which are not supported on Neutron.
return False
else: # Unexpected case (should never happen).
# neutron-library/src/utils/NeutronLibraryInterrogation.cpp#876 TransposeConv2DKernelKind
if (
dilation != [1, 1]
or padding[0] != 0
or padding[1] >= node_t_params["kernel_width"]
or (
padding[1] != 0 and node_t_params["inp_height"] != 1
) # Slice added by explicit padding
or stride[0] != 1
or (
(
stride[1] != node_t_params["kernel_width"] / 2
or node_t_params["out_height"] != 1
)
and stride[1] != node_t_params["kernel_width"]
)
or stride[1] % 2 != 0
or node_t_params["inp_channels"] % num_macs != 0
or node_t_params["out_channels"] % num_macs != 0
or node_t_params["kernel_width"] % 2 != 0
or node_t_params["kernel_height"] != 1
):
return False

return True

@staticmethod
def _is_supported_on_target(
    node: Node,
    neutron_target_spec: NeutronTargetSpec,
    parameters_mapping: dict[str, Parameter],
    custom_delegation_options: CustomDelegationOptions,
) -> bool:
    """Route the target-support check to the regular or transposed convolution variant.

    :param node: the convolution node to check.
    :param neutron_target_spec: forwarded to the transposed-convolution check only.
    :param parameters_mapping: forwarded to whichever check is selected.
    :param custom_delegation_options: accepted but not used in this method.
    :return: the result of the selected check.
    """
    # NOTE(review): `custom_delegation_options` is not forwarded to either
    # helper - confirm that neither of them needs it.
    conv_args = ConvolutionConverter._get_convolution_arguments(node)
    transposed = conv_args[6]  # The `transposed` flag is the 7th conv argument.

    if not transposed:
        return ConvolutionConverter._is_supported_on_target_regular_conv(
            node, parameters_mapping
        )

    return ConvolutionConverter._is_supported_on_target_transp_conv(
        node, neutron_target_spec, parameters_mapping
    )
Comment thread
novak-vaclav marked this conversation as resolved.

@staticmethod
def _is_supported_in_IR(
node: Node,
Expand Down Expand Up @@ -149,10 +242,6 @@ def _is_supported_in_IR(

return True

Stride = Padding = Dilation = OutPadding = list[int]
Transposed = bool
Groups = int

def _compute_slicing_params(
self, output_shape, explicit_padding
) -> tuple[list[int], list[int]]:
Expand All @@ -170,14 +259,14 @@ def _compute_slicing_params(
@staticmethod
def _get_convolution_arguments(
conv_node: Node,
) -> (Stride, Padding, Dilation, Transposed, OutPadding, Groups):
# The arguments of the conv are:
# [x, w, b, stride, padding, dilation, transposed, output padding, groups]
# https://github.com/pytorch/pytorch/blob/v2.6.0/aten/src/ATen/native/Convolution.cpp#L286-L291
_, _, _, stride, padding, dilation, transposed, out_padding, groups = (
) -> ConvolutionArgs:
x, w, b, stride, padding, dilation, transposed, out_padding, groups = (
conv_node.args
)
return (
x,
w,
b,
list(stride),
list(padding),
list(dilation),
Expand Down Expand Up @@ -380,16 +469,8 @@ def _convert_2d_conv(

elif conv_utils.group_conv_convertible_into_multiple_convolutions(
t_op, conv_params.groups
): # Convert to separated `Conv2D`.
t_op.builtin_options = conv_2d_options.Conv2D()

return conv_utils.create_separated_convolutions_based_on_group(
t_op,
conv_params,
self.builder,
self._convert_unpadded_2D,
conv_utils.conv_op_factory,
)
):
raise RuntimeError("NXP backend: Group convolution was not decomposed.")
Comment thread
novak-vaclav marked this conversation as resolved.
Comment thread
novak-vaclav marked this conversation as resolved.
Comment thread
novak-vaclav marked this conversation as resolved.
Comment thread
novak-vaclav marked this conversation as resolved.
Comment on lines 470 to +473

else:
# Convert to regular `Conv2D`.
Expand Down Expand Up @@ -419,7 +500,7 @@ def _convert_2d_conv(
def convert(self, node: Node):
self.assert_convertible(node)

stride, padding, dilation, transposed, out_padding, groups = (
_, _, _, stride, padding, dilation, transposed, out_padding, groups = (
self._get_convolution_arguments(node)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def test_batch_norm_conv_fusing__full_pipeline__1d(bias: bool):
module, tuple(input_shape)
).exported_program()

assert len(edge_program.graph.nodes) == 21
assert len(edge_program.graph.nodes) == 7
assert not graph_contains_any_of_ops(edge_program.graph, batch_norm_target_ops)


Expand Down
Loading
Loading