diff --git a/src/finn/transformation/streamline/absorb.py b/src/finn/transformation/streamline/absorb.py
index 3c06e7913682d6865059d6fdeb71d8455122f1ac..cf712b38054c78c6e414ad914ab67378daec5d12 100644
--- a/src/finn/transformation/streamline/absorb.py
+++ b/src/finn/transformation/streamline/absorb.py
@@ -563,55 +563,71 @@ class AbsorbTransposeIntoResize(Transformation):
         graph = model.graph
         node_ind = 0
         graph_modified = False
-        for n in graph.node:
+        for node in graph.node:
             node_ind += 1
-            if n.op_type == "Transpose" and not model.is_fork_node(n):
-                perms = list(get_by_name(n.attribute, "perm").ints)
+            if node.op_type == "Transpose" and not model.is_fork_node(node):
+                perms = list(get_by_name(node.attribute, "perm").ints)
                 if perms == [0, 3, 1, 2]:
-                    mt_cand = model.find_consumer(n.output[0])
+                    mt_cand = model.find_consumer(node.output[0])
                     if mt_cand is not None and mt_cand.op_type == "Resize":
-                        final_t_cands = model.find_consumers(mt_cand.output[0])
                         mode = get_by_name(mt_cand.attribute, "mode").s.decode("ascii")
                         # skip if mode is not nearest
                         if mode != "nearest":
                             continue
                         # if sizes specified, turn into scales
-                        sizes = model.get_initializer(mt_cand.input[3])
+                        if len(mt_cand.input) > 3:
+                            sizes = model.get_initializer(mt_cand.input[3])
+                        else:
+                            sizes = None
                         if sizes is not None:
                             ishape = model.get_tensor_shape(mt_cand.input[0])
                             ns, cs, hs, ws = sizes / np.asarray(ishape)
                             model.set_initializer(
-                                mt_cand.input[2], np.asarray([ns, hs, ws, cs])
+                                mt_cand.input[2], np.asarray([ns, cs, hs, ws])
                             )
                             mt_cand.input.remove(mt_cand.input[3])
-                            # get rid of first tranpose node
-                            mt_cand.input[0] = n.input[0]
-                            graph.node.remove(n)
-                            # fix output shape for Resize
-                            mt_orig_oshape = model.get_tensor_shape(mt_cand.output[0])
-                            mt_ishape = model.get_tensor_shape(mt_cand.input[0])
-                            model.set_tensor_shape(mt_cand.output[0], mt_ishape)
-                            # re-insert Transpose behind Resize
-                            transpose_output = model.make_new_valueinfo_name()
-                            model.set_tensor_shape(transpose_output, mt_orig_oshape)
-                            new_transpose = oh.make_node(
-                                "Transpose",
-                                [mt_cand.output[0]],
-                                [transpose_output],
-                                perm=[0, 3, 1, 2],
-                            )
-                            graph.node.insert(node_ind + 1, new_transpose)
-                            if final_t_cands is not None:
-                                # rewire next nodes' inputs
-                                for final_t_cand in final_t_cands:
-                                    final_t_cand.input[0] = transpose_output
-                            else:
-                                # replace graph top-level output
-                                get_by_name(
-                                    model.graph.output, mt_cand.output[0]
-                                ).name = transpose_output
-                                model.set_tensor_shape(mt_cand.output[0], mt_ishape)
-                            graph_modified = True
+                        # scales already specified, transpose indices to NHWC
+                        scales = model.get_initializer(mt_cand.input[2])
+                        assert scales is not None
+                        ns, cs, hs, ws = scales
+                        model.set_initializer(
+                            mt_cand.input[2], np.asarray([ns, hs, ws, cs])
+                        )
+                        # get rid of first tranpose node
+                        mt_cand.input[0] = node.input[0]
+                        graph.node.remove(node)
+                        is_last_node = mt_cand.output[0] in [
+                            x.name for x in model.graph.output
+                        ]
+
+                        new_tensor_name = model.make_new_valueinfo_name()
+                        if is_last_node:
+                            trans_input = new_tensor_name
+                            trans_output = mt_cand.output[0]
+                        else:
+                            trans_input = mt_cand.output[0]
+                            trans_output = new_tensor_name
+                        # fix tensor shapes for Resize and Transpose
+                        # n, c, h, w = model.get_tensor_shape(mt_cand.input[0])
+                        n, c, hx, wx = model.get_tensor_shape(mt_cand.output[0])
+                        model.set_tensor_shape(trans_input, (n, hx, wx, c))
+                        model.set_tensor_shape(trans_output, (n, c, hx, wx))
+                        # re-insert Transpose behind Resize
+                        new_transpose = oh.make_node(
+                            "Transpose",
+                            [trans_input],
+                            [trans_output],
+                            perm=[0, 3, 1, 2],
+                        )
+                        graph.node.insert(node_ind + 1, new_transpose)
+                        # rewire nodes
+                        final_t_cands = model.find_consumers(mt_cand.output[0])
+                        if final_t_cands is not None:
+                            # rewire next nodes' inputs
+                            for final_t_cand in final_t_cands:
+                                final_t_cand.input[0] = trans_output
+                        mt_cand.output[0] = trans_input
+                        graph_modified = True
         if graph_modified:
             model = model.transform(InferDataTypes())
         return (model, graph_modified)