Skip to content
Snippets Groups Projects
Commit da18bc95 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Streamline] cover NN resampling for MoveScalarLinearPastInvariants

parent 3a786fab
No related branches found
No related tags found
No related merge requests found
......@@ -594,11 +594,17 @@ class MoveScalarLinearPastInvariants(Transformation):
nodes = [n for n in graph.node]
for n in nodes:
node_ind += 1
is_nearest_neighbor_resample = False
if n.op_type == "Upsample" or n.op_type == "Resize":
# Extract mode and scales and input shape
mode = get_by_name(n.attribute, "mode").s.decode("ascii")
is_nearest_neighbor_resample = mode == "nearest"
if (
n.op_type == "GlobalAveragePool"
or n.op_type == "Reshape"
or n.op_type == "Transpose"
or n.op_type == "Flatten"
or is_nearest_neighbor_resample
):
in0 = n.input[0]
if in0 is None:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment