diff --git a/src/finn/builder/build_dataflow_steps.py b/src/finn/builder/build_dataflow_steps.py
index 4c9b076b6acb1e725421d5067ac74ad42c3dbafb..50cd9ed4ff663044433c5d8997f2fcd9337a0dd6 100644
--- a/src/finn/builder/build_dataflow_steps.py
+++ b/src/finn/builder/build_dataflow_steps.py
@@ -46,10 +46,8 @@ from finn.transformation.streamline import Streamline
 from finn.transformation.infer_data_layouts import InferDataLayouts
 from finn.transformation.move_reshape import RemoveCNVtoFCFlatten
 from finn.transformation.lower_convs_to_matmul import LowerConvsToMatMul
-from finn.transformation.streamline.reorder import (
-    MakeMaxPoolNHWC,
-    MoveScalarLinearPastInvariants,
-)
+from finn.transformation.streamline.reorder import MakeMaxPoolNHWC
+
 from shutil import copy, copytree
 from finn.transformation.fpgadataflow.insert_dwc import InsertDWC
 from finn.transformation.fpgadataflow.insert_fifo import InsertFIFO
@@ -159,13 +157,13 @@ def step_streamline(model: ModelWrapper, cfg: DataflowBuildConfig):
     """
 
     model = model.transform(absorb.AbsorbSignBiasIntoMultiThreshold())
-    model = model.transform(MoveScalarLinearPastInvariants())
     model = model.transform(Streamline())
     need_lowering = len(model.get_nodes_by_op_type("Conv")) > 0
     if need_lowering:
         model = model.transform(LowerConvsToMatMul())
         model = model.transform(MakeMaxPoolNHWC())
         model = model.transform(absorb.AbsorbTransposeIntoMultiThreshold())
+        model = model.transform(MakeMaxPoolNHWC())
     model = model.transform(ConvertBipolarMatMulToXnorPopcount())
     model = model.transform(Streamline())
     # absorb final add-mul nodes into TopK
@@ -182,7 +180,7 @@ def step_streamline(model: ModelWrapper, cfg: DataflowBuildConfig):
 def step_convert_to_hls(model: ModelWrapper, cfg: DataflowBuildConfig):
     """Convert eligible nodes to `HLSCustomOp` subclasses that represent HLS
     layers. Which nodes and particular configurations can be converted to HLS
-    is limited, see the source code of the `convert_to_hls` module for more. """
+    is limited, see the source code of the `convert_to_hls` module for more."""
 
     mem_mode = cfg.default_mem_mode.value
     if cfg.standalone_thresholds:
diff --git a/src/finn/transformation/streamline/__init__.py b/src/finn/transformation/streamline/__init__.py
index 876f8892dbc9c42189ee8dc06ff5eb407f7a0946..97cd957ce166255c1442a544f6e865b42c33a1df 100644
--- a/src/finn/transformation/streamline/__init__.py
+++ b/src/finn/transformation/streamline/__init__.py
@@ -60,6 +60,7 @@ from finn.transformation.streamline.reorder import (
     MoveAddPastConv,
     MoveScalarMulPastConv,
     MoveMulPastMaxPool,
+    MoveScalarLinearPastInvariants,
 )
 from finn.transformation.streamline.round_thresholds import RoundAndClipThresholds
 
@@ -78,6 +79,7 @@ class Streamline(Transformation):
             BatchNormToAffine(),
             ConvertSignToThres(),
             MoveMulPastMaxPool(),
+            MoveScalarLinearPastInvariants(),
             AbsorbSignBiasIntoMultiThreshold(),
             MoveAddPastMul(),
             MoveScalarAddPastMatMul(),
diff --git a/src/finn/transformation/streamline/absorb.py b/src/finn/transformation/streamline/absorb.py
index 9b842162f7f751c60b18bbd288ff96ef28d3aa88..9c6472d4f7fcebb6680b5f78b08bb4538d36df8f 100644
--- a/src/finn/transformation/streamline/absorb.py
+++ b/src/finn/transformation/streamline/absorb.py
@@ -308,9 +308,8 @@ class Absorb1BitMulIntoConv(Transformation):
 
 
 class AbsorbTransposeIntoMultiThreshold(Transformation):
-    """Change (NHWCTranpose -> MultiThreshold -> NCHWTranspose) to (MultiThreshold)
-    with NHWC mode. For (NHWCTranpose -> MultiThreshold -> Flatten), move Transpose
-    past MultiThreshold to prepare for the RemoveCNVtoFCFlatten() transformation."""
+    """Change (NCHWTranspose -> MultiThreshold -> NHWCTranspose) to (MultiThreshold)
+    with NHWC mode. For (NCHWTranspose -> MultiThreshold), move Transpose past MT."""
 
     def apply(self, model):
         graph = model.graph
@@ -339,35 +338,26 @@ class AbsorbTransposeIntoMultiThreshold(Transformation):
                                 graph.node.remove(n)
                                 graph.node.remove(final_t_cand)
                                 graph_modified = True
-                        # also support implicit flatten via reshape, e.g. reshape(1,-1)
-                        elif (
-                            final_t_cand.op_type == "Flatten"
-                            or final_t_cand.op_type == "Reshape"
-                        ):
-                            ishape = model.get_tensor_shape(final_t_cand.input[0])
-                            oshape = model.get_tensor_shape(final_t_cand.output[0])
-                            if len(oshape) == 2 and ishape[0] == oshape[0]:
-                                # transition to FC part, can still use NHWC
-                                mt = getCustomOp(mt_cand)
-                                mt.set_nodeattr("data_layout", "NHWC")
-                                # get rid of first tranpose node
-                                mt_cand.input[0] = n.input[0]
-                                graph.node.remove(n)
-                                # fix output shape for MultiThreshold
-                                mt_ishape = model.get_tensor_shape(mt_cand.input[0])
-                                (b, h, w, c) = mt_ishape
-                                model.set_tensor_shape(mt_cand.output[0], mt_ishape)
-                                # re-insert Transpose behind MultiThreshold
-                                transpose_output = model.make_new_valueinfo_name()
-                                new_transpose = oh.make_node(
-                                    "Transpose",
-                                    [mt_cand.output[0]],
-                                    [transpose_output],
-                                    perm=[0, 3, 1, 2],
-                                )
-                                graph.node.insert(node_ind + 1, new_transpose)
-                                final_t_cand.input[0] = transpose_output
-                                graph_modified = True
+                        else:
+                            mt = getCustomOp(mt_cand)
+                            mt.set_nodeattr("data_layout", "NHWC")
+                            # get rid of first tranpose node
+                            mt_cand.input[0] = n.input[0]
+                            graph.node.remove(n)
+                            # fix output shape for MultiThreshold
+                            mt_ishape = model.get_tensor_shape(mt_cand.input[0])
+                            model.set_tensor_shape(mt_cand.output[0], mt_ishape)
+                            # re-insert Transpose behind MultiThreshold
+                            transpose_output = model.make_new_valueinfo_name()
+                            new_transpose = oh.make_node(
+                                "Transpose",
+                                [mt_cand.output[0]],
+                                [transpose_output],
+                                perm=[0, 3, 1, 2],
+                            )
+                            graph.node.insert(node_ind + 1, new_transpose)
+                            final_t_cand.input[0] = transpose_output
+                            graph_modified = True
         if graph_modified:
             model = model.transform(InferDataTypes())
         return (model, graph_modified)
diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py
index 7163a95c4dbbe5c8bcee4ebeea87c5e9611c179e..4049d7bc8b555dce6f3ab6f9475f6aebf41fbc2c 100644
--- a/src/finn/transformation/streamline/reorder.py
+++ b/src/finn/transformation/streamline/reorder.py
@@ -645,7 +645,8 @@ class MoveScalarLinearPastInvariants(Transformation):
 
 
 class MakeMaxPoolNHWC(Transformation):
-    """Convert (MaxPool, NHWCTranpose) into (MaxPoolNHWC)."""
+    """Convert (MaxPool, NHWCTranspose) into (NHWCTranspose, MaxPoolNHWC)
+    and (NCHWTranspose, MaxPool) into (MaxPoolNHWC, NCHWTranspose)."""
 
     def apply(self, model):
         graph = model.graph
@@ -655,6 +656,7 @@ class MakeMaxPoolNHWC(Transformation):
             node_ind += 1
             if n.op_type == "MaxPool":
                 consumer = model.find_consumer(n.output[0])
+                producer = model.find_producer(n.input[0])
                 if consumer is not None and consumer.op_type == "Transpose":
                     perms = list(get_by_name(consumer.attribute, "perm").ints)
                     if perms == [0, 2, 3, 1]:
@@ -674,6 +676,25 @@ class MakeMaxPoolNHWC(Transformation):
                         graph.node.remove(consumer)
                         graph.node.insert(node_ind - 1, consumer)
                         graph_modified = True
+                elif producer is not None and producer.op_type == "Transpose":
+                    perms = list(get_by_name(producer.attribute, "perm").ints)
+                    if perms == [0, 3, 1, 2]:
+                        n.op_type = "MaxPoolNHWC"
+                        n.domain = "finn.custom_op.general"
+                        start_name = producer.input[0]
+                        mid_name = n.input[0]
+                        end_name = n.output[0]
+                        (b, hi, wi, c) = model.get_tensor_shape(start_name)
+                        (b, c, ho, wo) = model.get_tensor_shape(end_name)
+                        producer.input[0] = mid_name
+                        producer.output[0] = end_name
+                        n.input[0] = start_name
+                        n.output[0] = mid_name
+                        model.set_tensor_shape(mid_name, (b, ho, wo, c))
+                        model.set_tensor_shape(end_name, (b, c, ho, wo))
+                        graph.node.remove(producer)
+                        graph.node.insert(node_ind, producer)
+                        graph_modified = True
         return (model, graph_modified)