diff --git a/src/finn/transformation/lower_convs_to_matmul.py b/src/finn/transformation/lower_convs_to_matmul.py
index cb71f22ad51f966d6641095f7f820f7dc4f16beb..ed8fd21d812dea494e135a7967d8bb7f9aae30b5 100644
--- a/src/finn/transformation/lower_convs_to_matmul.py
+++ b/src/finn/transformation/lower_convs_to_matmul.py
@@ -48,6 +48,8 @@ class LowerConvsToMatMul(Transformation):
                 graph_modified = True
                 cnv_input = n.input[0]
                 cnv_output = n.output[0]
+                idt = model.get_tensor_datatype(cnv_input)
+                odt = model.get_tensor_datatype(cnv_output)
                 # extract conv parameters
                 k = get_by_name(n.attribute, "kernel_shape").ints[-1]
                 pad = get_by_name(n.attribute, "pads").ints[-1]
@@ -69,6 +71,7 @@ class LowerConvsToMatMul(Transformation):
                 )
                 graph.value_info.append(inp_trans_out)
                 inp_trans_out = inp_trans_out.name
+                model.set_tensor_datatype(inp_trans_out, idt)
 
                 im2col_out = helper.make_tensor_value_info(
                     model.make_new_valueinfo_name(),
@@ -77,6 +80,7 @@ class LowerConvsToMatMul(Transformation):
                 )
                 graph.value_info.append(im2col_out)
                 im2col_out = im2col_out.name
+                model.set_tensor_datatype(im2col_out, idt)
 
                 matmul_out = helper.make_tensor_value_info(
                     model.make_new_valueinfo_name(),
@@ -85,6 +89,7 @@ class LowerConvsToMatMul(Transformation):
                 )
                 graph.value_info.append(matmul_out)
                 matmul_out = matmul_out.name
+                model.set_tensor_datatype(matmul_out, odt)
 
                 # create new nodes
                 # NCHW -> NHWC