Skip to content
Snippets Groups Projects
Commit 9c2a7eea authored by mmrahorovic's avatar mmrahorovic
Browse files

[transform]: support for ceil mode MaxPool

parent 3addec34
No related branches found
No related tags found
No related merge requests found
......@@ -670,6 +670,7 @@ class MakeMaxPoolNHWC(Transformation):
if consumer is not None and consumer.op_type == "Transpose":
perms = list(get_by_name(consumer.attribute, "perm").ints)
if perms == [0, 2, 3, 1]:
ceil_mode = get_by_name(n.attribute, "ceil_mode").i
n.op_type = "MaxPoolNHWC"
n.domain = "finn.custom_op.general"
start_name = n.input[0]
......@@ -683,12 +684,14 @@ class MakeMaxPoolNHWC(Transformation):
n.output[0] = end_name
model.set_tensor_shape(mid_name, (b, hi, wi, c))
model.set_tensor_shape(end_name, (b, ho, wo, c))
getCustomOp(n).set_nodeattr("ceil_mode", ceil_mode)
graph.node.remove(consumer)
graph.node.insert(node_ind - 1, consumer)
graph_modified = True
elif producer is not None and producer.op_type == "Transpose":
perms = list(get_by_name(producer.attribute, "perm").ints)
if perms == [0, 3, 1, 2]:
ceil_mode = get_by_name(n.attribute, "ceil_mode").i
n.op_type = "MaxPoolNHWC"
n.domain = "finn.custom_op.general"
start_name = producer.input[0]
......@@ -702,6 +705,7 @@ class MakeMaxPoolNHWC(Transformation):
n.output[0] = mid_name
model.set_tensor_shape(mid_name, (b, ho, wo, c))
model.set_tensor_shape(end_name, (b, c, ho, wo))
getCustomOp(n).set_nodeattr("ceil_mode", ceil_mode)
graph.node.remove(producer)
graph.node.insert(node_ind, producer)
graph_modified = True
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment