From c903048bf1a073d85c6513ee2b1b2825cd0a0fe8 Mon Sep 17 00:00:00 2001 From: Tobi-Alonso <tobi.alonso@gmail.com> Date: Fri, 19 Jun 2020 11:45:31 +0100 Subject: [PATCH] [Transformation] Add check to AbsorbTransposeIntoMultiThreshold Transformation to restrict it to linear segments --- src/finn/transformation/streamline/absorb.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/finn/transformation/streamline/absorb.py b/src/finn/transformation/streamline/absorb.py index dbcf97361..3dfd4a007 100644 --- a/src/finn/transformation/streamline/absorb.py +++ b/src/finn/transformation/streamline/absorb.py @@ -250,11 +250,13 @@ class AbsorbTransposeIntoMultiThreshold(Transformation): graph_modified = False for n in graph.node: node_ind += 1 - if n.op_type == "Transpose": + if n.op_type == "Transpose" and not model.is_fork_node(n): perms = list(get_by_name(n.attribute, "perm").ints) if perms == [0, 3, 1, 2]: mt_cand = model.find_consumer(n.output[0]) - if mt_cand.op_type == "MultiThreshold": + if mt_cand.op_type == "MultiThreshold" and not model.is_fork_node( + mt_cand + ): final_t_cand = model.find_consumer(mt_cand.output[0]) if final_t_cand.op_type == "Transpose": perms = list( -- GitLab