From 011d9a686b262123596dd820de32bccd5c45ee8f Mon Sep 17 00:00:00 2001 From: auphelia <jakobapk@web.de> Date: Wed, 24 Jun 2020 12:04:08 +0100 Subject: [PATCH] [Transform] Update change datalayout trafo for quantavgpool --- src/finn/transformation/change_datalayout.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/finn/transformation/change_datalayout.py b/src/finn/transformation/change_datalayout.py index 5b1e70052..d5b393a25 100644 --- a/src/finn/transformation/change_datalayout.py +++ b/src/finn/transformation/change_datalayout.py @@ -43,7 +43,10 @@ class ChangeDataLayoutQuantAvgPool2d(Transformation): graph_modified = False for n in graph.node: node_ind += 1 - if n.op_type == "QuantAvgPool2d": + if n.op_type == "QuantAvgPool2d" and ( + get_by_name(n.attribute, "data_layout") is None + or get_by_name(n.attribute, "data_layout").s.decode("UTF-8") == "NCHW" + ): graph_modified = True node_input = n.input[0] node_output = n.output[0] @@ -78,7 +81,7 @@ class ChangeDataLayoutQuantAvgPool2d(Transformation): "Transpose", [node_input], [inp_trans_out], perm=[0, 2, 3, 1] ) quantavg_node = helper.make_node( - "QuantAvgPool2dNHWC", + "QuantAvgPool2d", [inp_trans_out], [quantavg_out], domain="finn", @@ -87,6 +90,7 @@ class ChangeDataLayoutQuantAvgPool2d(Transformation): ibits=ibits, obits=obits, signed=signed, + data_layout="NHWC", ) # NHWC -> NCHW out_trans_node = helper.make_node( @@ -99,5 +103,8 @@ class ChangeDataLayoutQuantAvgPool2d(Transformation): # remove old nodes graph.node.remove(n) + # set shapes + model.set_tensor_shape(inp_trans_out, (batchsize, idim, idim, channels)) + model.set_tensor_shape(quantavg_out, (batchsize, odim, odim, channels)) model = model.transform(InferShapes()) return (model, graph_modified) -- GitLab