Skip to content
Snippets Groups Projects
Commit 011d9a68 authored by auphelia's avatar auphelia
Browse files

[Transform] Update change datalayout trafo for quantavgpool

parent b291eb06
No related branches found
No related tags found
No related merge requests found
...@@ -43,7 +43,10 @@ class ChangeDataLayoutQuantAvgPool2d(Transformation): ...@@ -43,7 +43,10 @@ class ChangeDataLayoutQuantAvgPool2d(Transformation):
graph_modified = False graph_modified = False
for n in graph.node: for n in graph.node:
node_ind += 1 node_ind += 1
if n.op_type == "QuantAvgPool2d": if n.op_type == "QuantAvgPool2d" and (
get_by_name(n.attribute, "data_layout") is None
or get_by_name(n.attribute, "data_layout").s.decode("UTF-8") == "NCHW"
):
graph_modified = True graph_modified = True
node_input = n.input[0] node_input = n.input[0]
node_output = n.output[0] node_output = n.output[0]
...@@ -78,7 +81,7 @@ class ChangeDataLayoutQuantAvgPool2d(Transformation): ...@@ -78,7 +81,7 @@ class ChangeDataLayoutQuantAvgPool2d(Transformation):
"Transpose", [node_input], [inp_trans_out], perm=[0, 2, 3, 1] "Transpose", [node_input], [inp_trans_out], perm=[0, 2, 3, 1]
) )
quantavg_node = helper.make_node( quantavg_node = helper.make_node(
"QuantAvgPool2dNHWC", "QuantAvgPool2d",
[inp_trans_out], [inp_trans_out],
[quantavg_out], [quantavg_out],
domain="finn", domain="finn",
...@@ -87,6 +90,7 @@ class ChangeDataLayoutQuantAvgPool2d(Transformation): ...@@ -87,6 +90,7 @@ class ChangeDataLayoutQuantAvgPool2d(Transformation):
ibits=ibits, ibits=ibits,
obits=obits, obits=obits,
signed=signed, signed=signed,
data_layout="NHWC",
) )
# NHWC -> NCHW # NHWC -> NCHW
out_trans_node = helper.make_node( out_trans_node = helper.make_node(
...@@ -99,5 +103,8 @@ class ChangeDataLayoutQuantAvgPool2d(Transformation): ...@@ -99,5 +103,8 @@ class ChangeDataLayoutQuantAvgPool2d(Transformation):
# remove old nodes # remove old nodes
graph.node.remove(n) graph.node.remove(n)
# set shapes
model.set_tensor_shape(inp_trans_out, (batchsize, idim, idim, channels))
model.set_tensor_shape(quantavg_out, (batchsize, odim, odim, channels))
model = model.transform(InferShapes()) model = model.transform(InferShapes())
return (model, graph_modified) return (model, graph_modified)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment