diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py index 1b22f474abe3f59ac91551efa3661b2612442776..cd44e115eed42d1c0529b86a49e8855ff7c492ce 100644 --- a/src/finn/transformation/streamline/reorder.py +++ b/src/finn/transformation/streamline/reorder.py @@ -594,11 +594,17 @@ class MoveScalarLinearPastInvariants(Transformation): nodes = [n for n in graph.node] for n in nodes: node_ind += 1 + is_nearest_neighbor_resample = False + if n.op_type == "Upsample" or n.op_type == "Resize": + # Extract mode and scales and input shape + mode = get_by_name(n.attribute, "mode").s.decode("ascii") + is_nearest_neighbor_resample = mode == "nearest" if ( n.op_type == "GlobalAveragePool" or n.op_type == "Reshape" or n.op_type == "Transpose" or n.op_type == "Flatten" + or is_nearest_neighbor_resample ): in0 = n.input[0] if in0 is None: