diff --git a/src/finn/transformation/streamline/remove.py b/src/finn/transformation/streamline/remove.py
index 12c6984c6e66e1917d2a1e0a74c8620ccb6afabc..0a36b8bbe5c05a8226ae647e0061c1551f3b1cbf 100644
--- a/src/finn/transformation/streamline/remove.py
+++ b/src/finn/transformation/streamline/remove.py
@@ -32,6 +32,23 @@ from finn.transformation.infer_shapes import InferShapes
 import numpy as np
 
 
+def _remove_node_and_rewire(model, node):
+    producer = model.find_producer(node.input[0])
+    if producer is not None:
+        # wire output tensor to
+        # output of producer node
+        producer.output[0] = node.output[0]
+    else:
+        # node is first in graph
+        consumer = model.find_consumer(node.output[0])
+        assert consumer is not None, "Whole graph is identity"
+        assert consumer.input[0] == node.output[0]
+        # rewire consumer's input directly to graph input
+        consumer.input[0] = node.input[0]
+    # remove node
+    model.graph.node.remove(node)
+
+
 class RemoveIdentityOps(Transformation):
     """Remove identity ops like Add/Sub with zero or Mul/Div with one"""
 
@@ -48,11 +65,7 @@ class RemoveIdentityOps(Transformation):
             ):
                 A = model.get_initializer(n.input[1])
                 if A is not None and (A == np.zeros_like(A)).all():
-                    producer = model.find_producer(n.input[0])
-                    # remove node and wire output tensor to
-                    # output of producer node
-                    producer.output[0] = n.output[0]
-                    graph.node.remove(n)
+                    _remove_node_and_rewire(model, n)
 
             elif (
                 n.op_type in ["Mul", "Div"]
@@ -61,10 +74,6 @@ class RemoveIdentityOps(Transformation):
             ):
                 A = model.get_initializer(n.input[1])
                 if A is not None and (A == np.ones_like(A)).all():
-                    producer = model.find_producer(n.input[0])
-                    # remove node and wire output tensor to
-                    # output of producer node
-                    producer.output[0] = n.output[0]
-                    graph.node.remove(n)
+                    _remove_node_and_rewire(model, n)
         model = model.transform(InferShapes())
         return (model, graph_modified)
diff --git a/tests/transformation/streamline/test_remove_identity_ops.py b/tests/transformation/streamline/test_remove_identity_ops.py
index 536c1ab0b48fa44388da23f45b528da3c5f3b2f2..98430fad0e0f4c17d77ddbf44afeeccd44372047 100644
--- a/tests/transformation/streamline/test_remove_identity_ops.py
+++ b/tests/transformation/streamline/test_remove_identity_ops.py
@@ -11,7 +11,7 @@ from finn.transformation.streamline.remove import RemoveIdentityOps
 from finn.util.basic import gen_finn_dt_tensor
 
 
-def insert_identity_op(model, op):
+def insert_identity_op(model, op, as_first_node):
     if op in ["Add", "Sub"]:
         val = np.asarray([0.0], dtype=np.float32)
     elif op in ["Mul", "Div"]:
@@ -19,10 +19,15 @@ def insert_identity_op(model, op):
     else:
         return
 
-    identity_node = helper.make_node(op, ["div_out", "value"], ["ident_out"])
     graph = model.graph
-    graph.node.insert(3, identity_node)
-    graph.node[-1].input[0] = "ident_out"
+    if as_first_node:
+        identity_node = helper.make_node(op, ["inp", "value"], ["ident_out"])
+        graph.node.insert(0, identity_node)
+        graph.node[1].input[0] = "ident_out"
+    else:
+        identity_node = helper.make_node(op, ["div_out", "value"], ["ident_out"])
+        graph.node.insert(3, identity_node)
+        graph.node[-1].input[0] = "ident_out"
     model.set_initializer("value", val)
 
     return model
@@ -30,7 +35,8 @@ def insert_identity_op(model, op):
 
 # identity operations to be inserted
 @pytest.mark.parametrize("op", ["Add", "Sub", "Mul", "Div"])
-def test_remove_identity_ops(op):
+@pytest.mark.parametrize("as_first_node", [False, True])
+def test_remove_identity_ops(op, as_first_node):
 
     # set up onnx model
     inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, 4, 1, 1])
@@ -64,7 +70,7 @@ def test_remove_identity_ops(op):
     model.set_initializer("shape", shape_values)
     model.set_initializer("div", div_values)
     model.set_initializer("matmul", matmul_values)
-    insert_identity_op(model, op)
+    insert_identity_op(model, op, as_first_node)
     model = model.transform(InferShapes())
     model = model.transform(InferDataTypes())
     idict = {"inp": inp_values}