From 9c2a7eeac4618c12c3f0e82ee06afb4b37e6aa2f Mon Sep 17 00:00:00 2001 From: mmrahorovic <mmrahorovic@hotmail.com> Date: Sun, 20 Feb 2022 14:30:08 +0000 Subject: [PATCH] [transform]: support for ceil mode MaxPool --- src/finn/transformation/streamline/reorder.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py index 0cdd6651d..746972d70 100644 --- a/src/finn/transformation/streamline/reorder.py +++ b/src/finn/transformation/streamline/reorder.py @@ -670,6 +670,7 @@ class MakeMaxPoolNHWC(Transformation): if consumer is not None and consumer.op_type == "Transpose": perms = list(get_by_name(consumer.attribute, "perm").ints) if perms == [0, 2, 3, 1]: + ceil_mode = get_by_name(n.attribute, "ceil_mode").i n.op_type = "MaxPoolNHWC" n.domain = "finn.custom_op.general" start_name = n.input[0] @@ -683,12 +684,14 @@ class MakeMaxPoolNHWC(Transformation): n.output[0] = end_name model.set_tensor_shape(mid_name, (b, hi, wi, c)) model.set_tensor_shape(end_name, (b, ho, wo, c)) + getCustomOp(n).set_nodeattr("ceil_mode", ceil_mode) graph.node.remove(consumer) graph.node.insert(node_ind - 1, consumer) graph_modified = True elif producer is not None and producer.op_type == "Transpose": perms = list(get_by_name(producer.attribute, "perm").ints) if perms == [0, 3, 1, 2]: + ceil_mode = get_by_name(n.attribute, "ceil_mode").i n.op_type = "MaxPoolNHWC" n.domain = "finn.custom_op.general" start_name = producer.input[0] @@ -702,6 +705,7 @@ class MakeMaxPoolNHWC(Transformation): n.output[0] = mid_name model.set_tensor_shape(mid_name, (b, ho, wo, c)) model.set_tensor_shape(end_name, (b, c, ho, wo)) + getCustomOp(n).set_nodeattr("ceil_mode", ceil_mode) graph.node.remove(producer) graph.node.insert(node_ind, producer) graph_modified = True -- GitLab