diff --git a/src/finn/transformation/streamline/absorb.py b/src/finn/transformation/streamline/absorb.py index 65e67721c8b6b44af060ba59eddfcd72dadc5fc7..8398a277443530e84632d26fbfca6d90ea4b0b9e 100644 --- a/src/finn/transformation/streamline/absorb.py +++ b/src/finn/transformation/streamline/absorb.py @@ -317,11 +317,13 @@ class AbsorbTransposeIntoMultiThreshold(Transformation): graph_modified = False for n in graph.node: node_ind += 1 - if n.op_type == "Transpose": + if n.op_type == "Transpose" and not model.is_fork_node(n): perms = list(get_by_name(n.attribute, "perm").ints) if perms == [0, 3, 1, 2]: mt_cand = model.find_consumer(n.output[0]) - if mt_cand.op_type == "MultiThreshold": + if mt_cand.op_type == "MultiThreshold" and not model.is_fork_node( + mt_cand + ): final_t_cand = model.find_consumer(mt_cand.output[0]) if final_t_cand.op_type == "Transpose": perms = list(