diff --git a/src/finn/transformation/change_datalayout.py b/src/finn/transformation/change_datalayout.py index 5b1e70052ccc28551ba6609e4dda72bf9450c2b0..d5b393a25e57122b059a44f70904a6dbe5bbaa3f 100644 --- a/src/finn/transformation/change_datalayout.py +++ b/src/finn/transformation/change_datalayout.py @@ -43,7 +43,10 @@ class ChangeDataLayoutQuantAvgPool2d(Transformation): graph_modified = False for n in graph.node: node_ind += 1 - if n.op_type == "QuantAvgPool2d": + if n.op_type == "QuantAvgPool2d" and ( + get_by_name(n.attribute, "data_layout") is None + or get_by_name(n.attribute, "data_layout").s.decode("UTF-8") == "NCHW" + ): graph_modified = True node_input = n.input[0] node_output = n.output[0] @@ -78,7 +81,7 @@ class ChangeDataLayoutQuantAvgPool2d(Transformation): "Transpose", [node_input], [inp_trans_out], perm=[0, 2, 3, 1] ) quantavg_node = helper.make_node( - "QuantAvgPool2dNHWC", + "QuantAvgPool2d", [inp_trans_out], [quantavg_out], domain="finn", @@ -87,6 +90,7 @@ class ChangeDataLayoutQuantAvgPool2d(Transformation): ibits=ibits, obits=obits, signed=signed, + data_layout="NHWC", ) # NHWC -> NCHW out_trans_node = helper.make_node( @@ -99,5 +103,8 @@ class ChangeDataLayoutQuantAvgPool2d(Transformation): # remove old nodes graph.node.remove(n) + # set shapes + model.set_tensor_shape(inp_trans_out, (batchsize, idim, idim, channels)) + model.set_tensor_shape(quantavg_out, (batchsize, odim, odim, channels)) model = model.transform(InferShapes()) return (model, graph_modified)