Skip to content
Snippets Groups Projects
Commit d6bfc7c2 authored by Felix Jentzsch's avatar Felix Jentzsch
Browse files

Streamlining modifications to support MaxPool before Flatten

parent 23ca60b8
No related branches found
No related tags found
No related merge requests found
...@@ -46,10 +46,8 @@ from finn.transformation.streamline import Streamline ...@@ -46,10 +46,8 @@ from finn.transformation.streamline import Streamline
from finn.transformation.infer_data_layouts import InferDataLayouts from finn.transformation.infer_data_layouts import InferDataLayouts
from finn.transformation.move_reshape import RemoveCNVtoFCFlatten from finn.transformation.move_reshape import RemoveCNVtoFCFlatten
from finn.transformation.lower_convs_to_matmul import LowerConvsToMatMul from finn.transformation.lower_convs_to_matmul import LowerConvsToMatMul
from finn.transformation.streamline.reorder import ( from finn.transformation.streamline.reorder import MakeMaxPoolNHWC
MakeMaxPoolNHWC,
MoveScalarLinearPastInvariants,
)
from shutil import copy, copytree from shutil import copy, copytree
from finn.transformation.fpgadataflow.insert_dwc import InsertDWC from finn.transformation.fpgadataflow.insert_dwc import InsertDWC
from finn.transformation.fpgadataflow.insert_fifo import InsertFIFO from finn.transformation.fpgadataflow.insert_fifo import InsertFIFO
...@@ -159,13 +157,13 @@ def step_streamline(model: ModelWrapper, cfg: DataflowBuildConfig): ...@@ -159,13 +157,13 @@ def step_streamline(model: ModelWrapper, cfg: DataflowBuildConfig):
""" """
model = model.transform(absorb.AbsorbSignBiasIntoMultiThreshold()) model = model.transform(absorb.AbsorbSignBiasIntoMultiThreshold())
model = model.transform(MoveScalarLinearPastInvariants())
model = model.transform(Streamline()) model = model.transform(Streamline())
need_lowering = len(model.get_nodes_by_op_type("Conv")) > 0 need_lowering = len(model.get_nodes_by_op_type("Conv")) > 0
if need_lowering: if need_lowering:
model = model.transform(LowerConvsToMatMul()) model = model.transform(LowerConvsToMatMul())
model = model.transform(MakeMaxPoolNHWC()) model = model.transform(MakeMaxPoolNHWC())
model = model.transform(absorb.AbsorbTransposeIntoMultiThreshold()) model = model.transform(absorb.AbsorbTransposeIntoMultiThreshold())
model = model.transform(MakeMaxPoolNHWC())
model = model.transform(ConvertBipolarMatMulToXnorPopcount()) model = model.transform(ConvertBipolarMatMulToXnorPopcount())
model = model.transform(Streamline()) model = model.transform(Streamline())
# absorb final add-mul nodes into TopK # absorb final add-mul nodes into TopK
...@@ -182,7 +180,7 @@ def step_streamline(model: ModelWrapper, cfg: DataflowBuildConfig): ...@@ -182,7 +180,7 @@ def step_streamline(model: ModelWrapper, cfg: DataflowBuildConfig):
def step_convert_to_hls(model: ModelWrapper, cfg: DataflowBuildConfig): def step_convert_to_hls(model: ModelWrapper, cfg: DataflowBuildConfig):
"""Convert eligible nodes to `HLSCustomOp` subclasses that represent HLS """Convert eligible nodes to `HLSCustomOp` subclasses that represent HLS
layers. Which nodes and particular configurations can be converted to HLS layers. Which nodes and particular configurations can be converted to HLS
is limited, see the source code of the `convert_to_hls` module for more. """ is limited, see the source code of the `convert_to_hls` module for more."""
mem_mode = cfg.default_mem_mode.value mem_mode = cfg.default_mem_mode.value
if cfg.standalone_thresholds: if cfg.standalone_thresholds:
......
...@@ -60,6 +60,7 @@ from finn.transformation.streamline.reorder import ( ...@@ -60,6 +60,7 @@ from finn.transformation.streamline.reorder import (
MoveAddPastConv, MoveAddPastConv,
MoveScalarMulPastConv, MoveScalarMulPastConv,
MoveMulPastMaxPool, MoveMulPastMaxPool,
MoveScalarLinearPastInvariants,
) )
from finn.transformation.streamline.round_thresholds import RoundAndClipThresholds from finn.transformation.streamline.round_thresholds import RoundAndClipThresholds
...@@ -78,6 +79,7 @@ class Streamline(Transformation): ...@@ -78,6 +79,7 @@ class Streamline(Transformation):
BatchNormToAffine(), BatchNormToAffine(),
ConvertSignToThres(), ConvertSignToThres(),
MoveMulPastMaxPool(), MoveMulPastMaxPool(),
MoveScalarLinearPastInvariants(),
AbsorbSignBiasIntoMultiThreshold(), AbsorbSignBiasIntoMultiThreshold(),
MoveAddPastMul(), MoveAddPastMul(),
MoveScalarAddPastMatMul(), MoveScalarAddPastMatMul(),
......
...@@ -308,9 +308,8 @@ class Absorb1BitMulIntoConv(Transformation): ...@@ -308,9 +308,8 @@ class Absorb1BitMulIntoConv(Transformation):
class AbsorbTransposeIntoMultiThreshold(Transformation): class AbsorbTransposeIntoMultiThreshold(Transformation):
"""Change (NHWCTranpose -> MultiThreshold -> NCHWTranspose) to (MultiThreshold) """Change (NCHWTranspose -> MultiThreshold -> NHWCTranspose) to (MultiThreshold)
with NHWC mode. For (NHWCTranspose -> MultiThreshold -> Flatten), move Transpose with NHWC mode. For (NCHWTranspose -> MultiThreshold), move Transpose past MT."""
past MultiThreshold to prepare for the RemoveCNVtoFCFlatten() transformation."""
def apply(self, model): def apply(self, model):
graph = model.graph graph = model.graph
...@@ -339,35 +338,26 @@ class AbsorbTransposeIntoMultiThreshold(Transformation): ...@@ -339,35 +338,26 @@ class AbsorbTransposeIntoMultiThreshold(Transformation):
graph.node.remove(n) graph.node.remove(n)
graph.node.remove(final_t_cand) graph.node.remove(final_t_cand)
graph_modified = True graph_modified = True
# also support implicit flatten via reshape, e.g. reshape(1,-1) else:
elif ( mt = getCustomOp(mt_cand)
final_t_cand.op_type == "Flatten" mt.set_nodeattr("data_layout", "NHWC")
or final_t_cand.op_type == "Reshape" # get rid of first transpose node
): mt_cand.input[0] = n.input[0]
ishape = model.get_tensor_shape(final_t_cand.input[0]) graph.node.remove(n)
oshape = model.get_tensor_shape(final_t_cand.output[0]) # fix output shape for MultiThreshold
if len(oshape) == 2 and ishape[0] == oshape[0]: mt_ishape = model.get_tensor_shape(mt_cand.input[0])
# transition to FC part, can still use NHWC model.set_tensor_shape(mt_cand.output[0], mt_ishape)
mt = getCustomOp(mt_cand) # re-insert Transpose behind MultiThreshold
mt.set_nodeattr("data_layout", "NHWC") transpose_output = model.make_new_valueinfo_name()
# get rid of first transpose node mt_cand.input[0] = n.input[0]
mt_cand.input[0] = n.input[0] "Transpose",
graph.node.remove(n) [mt_cand.output[0]],
# fix output shape for MultiThreshold [transpose_output],
mt_ishape = model.get_tensor_shape(mt_cand.input[0]) perm=[0, 3, 1, 2],
(b, h, w, c) = mt_ishape )
model.set_tensor_shape(mt_cand.output[0], mt_ishape) graph.node.insert(node_ind + 1, new_transpose)
# re-insert Transpose behind MultiThreshold final_t_cand.input[0] = transpose_output
transpose_output = model.make_new_valueinfo_name() graph_modified = True
new_transpose = oh.make_node(
"Transpose",
[mt_cand.output[0]],
[transpose_output],
perm=[0, 3, 1, 2],
)
graph.node.insert(node_ind + 1, new_transpose)
final_t_cand.input[0] = transpose_output
graph_modified = True
if graph_modified: if graph_modified:
model = model.transform(InferDataTypes()) model = model.transform(InferDataTypes())
return (model, graph_modified) return (model, graph_modified)
......
...@@ -645,7 +645,8 @@ class MoveScalarLinearPastInvariants(Transformation): ...@@ -645,7 +645,8 @@ class MoveScalarLinearPastInvariants(Transformation):
class MakeMaxPoolNHWC(Transformation): class MakeMaxPoolNHWC(Transformation):
"""Convert (MaxPool, NHWCTranpose) into (MaxPoolNHWC).""" """Convert (MaxPool, NHWCTranspose) into (NHWCTranspose, MaxPoolNHWC)
and (NCHWTranspose, MaxPool) into (MaxPoolNHWC, NCHWTranspose)."""
def apply(self, model): def apply(self, model):
graph = model.graph graph = model.graph
...@@ -655,6 +656,7 @@ class MakeMaxPoolNHWC(Transformation): ...@@ -655,6 +656,7 @@ class MakeMaxPoolNHWC(Transformation):
node_ind += 1 node_ind += 1
if n.op_type == "MaxPool": if n.op_type == "MaxPool":
consumer = model.find_consumer(n.output[0]) consumer = model.find_consumer(n.output[0])
producer = model.find_producer(n.input[0])
if consumer is not None and consumer.op_type == "Transpose": if consumer is not None and consumer.op_type == "Transpose":
perms = list(get_by_name(consumer.attribute, "perm").ints) perms = list(get_by_name(consumer.attribute, "perm").ints)
if perms == [0, 2, 3, 1]: if perms == [0, 2, 3, 1]:
...@@ -674,6 +676,25 @@ class MakeMaxPoolNHWC(Transformation): ...@@ -674,6 +676,25 @@ class MakeMaxPoolNHWC(Transformation):
graph.node.remove(consumer) graph.node.remove(consumer)
graph.node.insert(node_ind - 1, consumer) graph.node.insert(node_ind - 1, consumer)
graph_modified = True graph_modified = True
elif producer is not None and producer.op_type == "Transpose":
perms = list(get_by_name(producer.attribute, "perm").ints)
if perms == [0, 3, 1, 2]:
n.op_type = "MaxPoolNHWC"
n.domain = "finn.custom_op.general"
start_name = producer.input[0]
mid_name = n.input[0]
end_name = n.output[0]
(b, hi, wi, c) = model.get_tensor_shape(start_name)
(b, c, ho, wo) = model.get_tensor_shape(end_name)
producer.input[0] = mid_name
producer.output[0] = end_name
n.input[0] = start_name
n.output[0] = mid_name
model.set_tensor_shape(mid_name, (b, ho, wo, c))
model.set_tensor_shape(end_name, (b, c, ho, wo))
graph.node.remove(producer)
graph.node.insert(node_ind, producer)
graph_modified = True
return (model, graph_modified) return (model, graph_modified)
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment