diff --git a/src/finn/transformation/streamline/absorb.py b/src/finn/transformation/streamline/absorb.py index 3c06e7913682d6865059d6fdeb71d8455122f1ac..cf712b38054c78c6e414ad914ab67378daec5d12 100644 --- a/src/finn/transformation/streamline/absorb.py +++ b/src/finn/transformation/streamline/absorb.py @@ -563,55 +563,71 @@ class AbsorbTransposeIntoResize(Transformation): graph = model.graph node_ind = 0 graph_modified = False - for n in graph.node: + for node in graph.node: node_ind += 1 - if n.op_type == "Transpose" and not model.is_fork_node(n): - perms = list(get_by_name(n.attribute, "perm").ints) + if node.op_type == "Transpose" and not model.is_fork_node(node): + perms = list(get_by_name(node.attribute, "perm").ints) if perms == [0, 3, 1, 2]: - mt_cand = model.find_consumer(n.output[0]) + mt_cand = model.find_consumer(node.output[0]) if mt_cand is not None and mt_cand.op_type == "Resize": - final_t_cands = model.find_consumers(mt_cand.output[0]) mode = get_by_name(mt_cand.attribute, "mode").s.decode("ascii") # skip if mode is not nearest if mode != "nearest": continue # if sizes specified, turn into scales - sizes = model.get_initializer(mt_cand.input[3]) + if len(mt_cand.input) > 3: + sizes = model.get_initializer(mt_cand.input[3]) + else: + sizes = None if sizes is not None: ishape = model.get_tensor_shape(mt_cand.input[0]) ns, cs, hs, ws = sizes / np.asarray(ishape) model.set_initializer( - mt_cand.input[2], np.asarray([ns, hs, ws, cs]) + mt_cand.input[2], np.asarray([ns, cs, hs, ws]) ) mt_cand.input.remove(mt_cand.input[3]) - # get rid of first tranpose node - mt_cand.input[0] = n.input[0] - graph.node.remove(n) - # fix output shape for Resize - mt_orig_oshape = model.get_tensor_shape(mt_cand.output[0]) - mt_ishape = model.get_tensor_shape(mt_cand.input[0]) - model.set_tensor_shape(mt_cand.output[0], mt_ishape) - # re-insert Transpose behind Resize - transpose_output = model.make_new_valueinfo_name() - model.set_tensor_shape(transpose_output, mt_orig_oshape) - new_transpose = oh.make_node( - "Transpose", - [mt_cand.output[0]], - [transpose_output], - perm=[0, 3, 1, 2], - ) - graph.node.insert(node_ind + 1, new_transpose) - if final_t_cands is not None: - # rewire next nodes' inputs - for final_t_cand in final_t_cands: - final_t_cand.input[0] = transpose_output - else: - # replace graph top-level output - get_by_name( - model.graph.output, mt_cand.output[0] - ).name = transpose_output - model.set_tensor_shape(mt_cand.output[0], mt_ishape) - graph_modified = True + # scales already specified, transpose indices to NHWC + scales = model.get_initializer(mt_cand.input[2]) + assert scales is not None + ns, cs, hs, ws = scales + model.set_initializer( + mt_cand.input[2], np.asarray([ns, hs, ws, cs]) + ) + # get rid of first tranpose node + mt_cand.input[0] = node.input[0] + graph.node.remove(node) + is_last_node = mt_cand.output[0] in [ + x.name for x in model.graph.output + ] + + new_tensor_name = model.make_new_valueinfo_name() + if is_last_node: + trans_input = new_tensor_name + trans_output = mt_cand.output[0] + else: + trans_input = mt_cand.output[0] + trans_output = new_tensor_name + # fix tensor shapes for Resize and Transpose + # n, c, h, w = model.get_tensor_shape(mt_cand.input[0]) + n, c, hx, wx = model.get_tensor_shape(mt_cand.output[0]) + model.set_tensor_shape(trans_input, (n, hx, wx, c)) + model.set_tensor_shape(trans_output, (n, c, hx, wx)) + # re-insert Transpose behind Resize + new_transpose = oh.make_node( + "Transpose", + [trans_input], + [trans_output], + perm=[0, 3, 1, 2], + ) + graph.node.insert(node_ind + 1, new_transpose) + # rewire nodes + final_t_cands = model.find_consumers(mt_cand.output[0]) + if final_t_cands is not None: + # rewire next nodes' inputs + for final_t_cand in final_t_cands: + final_t_cand.input[0] = trans_output + mt_cand.output[0] = trans_input + graph_modified = True if graph_modified: model = model.transform(InferDataTypes()) return (model, graph_modified)