diff --git a/src/finn/transformation/streamline/collapse_repeated.py b/src/finn/transformation/streamline/collapse_repeated.py index aa059747b602bc6b659bc8b53b1f18988bba1ef0..67824ad4f633983b93e3178d03118927a1ddd85b 100644 --- a/src/finn/transformation/streamline/collapse_repeated.py +++ b/src/finn/transformation/streamline/collapse_repeated.py @@ -48,9 +48,17 @@ class CollapseRepeatedOp(Transformation): graph_modified = False for n in graph.node: node_ind += 1 - if n.op_type == self.op_name: + if ( + n.op_type == self.op_name + and not model.is_fork_node(n) + and not model.is_join_node(n) + ): consumer = model.find_consumer(n.output[0]) - if consumer is not None and consumer.op_type == self.op_name: + if ( + consumer is not None + and consumer.op_type == self.op_name + and not model.is_join_node(consumer) + ): op0_param_name = n.input[1] op1_param_name = consumer.input[1] op0_param = model.get_initializer(op0_param_name)