diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py
index dd3e3ac69de5516cb9f0e6ffe4161ef893d7dea4..19140a57ad2baa8290367ce23fdc0a55bbfbee17 100644
--- a/src/finn/transformation/streamline/reorder.py
+++ b/src/finn/transformation/streamline/reorder.py
@@ -32,6 +32,7 @@
 from onnx import helper as oh
 from finn.transformation import Transformation
 from finn.transformation.infer_shapes import InferShapes
+from finn.core.datatype import DataType
 from finn.core.onnx_exec import execute_node
 from finn.util.basic import get_by_name
 from finn.custom_op.registry import getCustomOp
@@ -361,29 +362,36 @@ class MoveMulPastDWConv(Transformation):
                 ):
                     mul_weight_name = n.input[1]
                     A = model.get_initializer(mul_weight_name)
-                    assert A is not None, "Initializer for mul weights is not set."
+                    if A is None:
+                        continue
                     conv_node = consumer
                     mul_node = n
                     start_name = mul_node.input[0]
                     conv_in_name = conv_node.input[0]
                     conv_in_shape = model.get_tensor_shape(conv_in_name)
                     ifm_ch = conv_in_shape[1]
-                    group_attribute = get_by_name(consumer.attribute, "group").i
+                    group_attribute = get_by_name(consumer.attribute, "group")
+                    if group_attribute is None:
+                        continue
+                    group_attribute = group_attribute.i
                     conv_out_name = conv_node.output[0]
                     conv_out_shape = model.get_tensor_shape(conv_out_name)
-                    if np.prod(A.shape) == ifm_ch == group_attribute:
+                    if A.shape == (1, ifm_ch, 1, 1) and ifm_ch == group_attribute:
                         # if the mul is channelwise and conv is depthwise,
                         # we can simply swap the order of ops
                         # rewire mul input to be conv input
                         conv_node.input[0] = start_name
                         model.set_tensor_shape(start_name, conv_in_shape)
+                        model.set_tensor_datatype(start_name, DataType.FLOAT32)
                         # use old conv input tensor as conv output
                         conv_node.output[0] = conv_in_name
                         model.set_tensor_shape(conv_in_name, conv_out_shape)
+                        model.set_tensor_datatype(conv_in_name, DataType.FLOAT32)
                         # use new conv output as new mul node input
                         mul_node.input[0] = conv_in_name
                         # use old conv output as new mul node output
                         mul_node.output[0] = conv_out_name
+                        model.set_tensor_datatype(conv_out_name, DataType.FLOAT32)
                         # move mul node past conv node
                         graph.node.remove(mul_node)
                         graph.node.insert(node_ind, mul_node)
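
For context, a minimal usage sketch of the transformation touched by this diff (not part of the patch). The "model.onnx" path is hypothetical; the import paths mirror the finn.core / finn.transformation layout assumed by the diff.

# Sketch: apply MoveMulPastDWConv to a FINN model wrapped in ModelWrapper.
# Assumptions: "model.onnx" and "model_reordered.onnx" are placeholder paths.
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.streamline.reorder import MoveMulPastDWConv

model = ModelWrapper("model.onnx")              # hypothetical input model
model = model.transform(InferShapes())          # ensure tensor shapes are annotated
model = model.transform(MoveMulPastDWConv())    # swap channelwise Mul past depthwise Conv
model.save("model_reordered.onnx")              # hypothetical output path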