diff --git a/src/finn/transformation/streamline/absorb.py b/src/finn/transformation/streamline/absorb.py index f089275c221f769daace3e9628a00bf87b4e5457..f04c6ba9a79457ff23bd84fbb92a37756be86fe8 100644 --- a/src/finn/transformation/streamline/absorb.py +++ b/src/finn/transformation/streamline/absorb.py @@ -317,11 +317,13 @@ class AbsorbTransposeIntoMultiThreshold(Transformation): graph_modified = False for n in graph.node: node_ind += 1 - if n.op_type == "Transpose": + if n.op_type == "Transpose" and not model.is_fork_node(n): perms = list(get_by_name(n.attribute, "perm").ints) if perms == [0, 3, 1, 2]: mt_cand = model.find_consumer(n.output[0]) - if mt_cand.op_type == "MultiThreshold": + if mt_cand.op_type == "MultiThreshold" and not model.is_fork_node( + mt_cand + ): final_t_cand = model.find_consumer(mt_cand.output[0]) if final_t_cand.op_type == "Transpose": perms = list(