Skip to content
Snippets Groups Projects
Commit 011d9a68 authored by auphelia's avatar auphelia
Browse files

[Transform] Update change datalayout trafo for quantavgpool

parent b291eb06
No related branches found
No related tags found
No related merge requests found
......@@ -43,7 +43,10 @@ class ChangeDataLayoutQuantAvgPool2d(Transformation):
graph_modified = False
for n in graph.node:
node_ind += 1
if n.op_type == "QuantAvgPool2d":
if n.op_type == "QuantAvgPool2d" and (
get_by_name(n.attribute, "data_layout") is None
or get_by_name(n.attribute, "data_layout").s.decode("UTF-8") == "NCHW"
):
graph_modified = True
node_input = n.input[0]
node_output = n.output[0]
......@@ -78,7 +81,7 @@ class ChangeDataLayoutQuantAvgPool2d(Transformation):
"Transpose", [node_input], [inp_trans_out], perm=[0, 2, 3, 1]
)
quantavg_node = helper.make_node(
"QuantAvgPool2dNHWC",
"QuantAvgPool2d",
[inp_trans_out],
[quantavg_out],
domain="finn",
......@@ -87,6 +90,7 @@ class ChangeDataLayoutQuantAvgPool2d(Transformation):
ibits=ibits,
obits=obits,
signed=signed,
data_layout="NHWC",
)
# NHWC -> NCHW
out_trans_node = helper.make_node(
......@@ -99,5 +103,8 @@ class ChangeDataLayoutQuantAvgPool2d(Transformation):
# remove old nodes
graph.node.remove(n)
# set shapes
model.set_tensor_shape(inp_trans_out, (batchsize, idim, idim, channels))
model.set_tensor_shape(quantavg_out, (batchsize, odim, odim, channels))
model = model.transform(InferShapes())
return (model, graph_modified)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment