diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py index 0cdd6651d982426b1d81d7313346dcd899294bf7..746972d70e75c50595a548590dc4f8d2adf87992 100644 --- a/src/finn/transformation/streamline/reorder.py +++ b/src/finn/transformation/streamline/reorder.py @@ -670,6 +670,7 @@ class MakeMaxPoolNHWC(Transformation): if consumer is not None and consumer.op_type == "Transpose": perms = list(get_by_name(consumer.attribute, "perm").ints) if perms == [0, 2, 3, 1]: + ceil_mode = get_by_name(n.attribute, "ceil_mode").i n.op_type = "MaxPoolNHWC" n.domain = "finn.custom_op.general" start_name = n.input[0] @@ -683,12 +684,14 @@ class MakeMaxPoolNHWC(Transformation): n.output[0] = end_name model.set_tensor_shape(mid_name, (b, hi, wi, c)) model.set_tensor_shape(end_name, (b, ho, wo, c)) + getCustomOp(n).set_nodeattr("ceil_mode", ceil_mode) graph.node.remove(consumer) graph.node.insert(node_ind - 1, consumer) graph_modified = True elif producer is not None and producer.op_type == "Transpose": perms = list(get_by_name(producer.attribute, "perm").ints) if perms == [0, 3, 1, 2]: + ceil_mode = get_by_name(n.attribute, "ceil_mode").i n.op_type = "MaxPoolNHWC" n.domain = "finn.custom_op.general" start_name = producer.input[0] @@ -702,6 +705,7 @@ class MakeMaxPoolNHWC(Transformation): n.output[0] = mid_name model.set_tensor_shape(mid_name, (b, ho, wo, c)) model.set_tensor_shape(end_name, (b, c, ho, wo)) + getCustomOp(n).set_nodeattr("ceil_mode", ceil_mode) graph.node.remove(producer) graph.node.insert(node_ind, producer) graph_modified = True