diff --git a/src/finn/transformation/infer_datatypes.py b/src/finn/transformation/infer_datatypes.py
index 089d0ff81e6e7d1593b421cd2224c32346f9eaf7..19d947b57d045d4d7f2523f0f392adaba5bb367a 100644
--- a/src/finn/transformation/infer_datatypes.py
+++ b/src/finn/transformation/infer_datatypes.py
@@ -4,8 +4,8 @@ from finn.core.datatype import DataType
 def _infer_node_datatype(model, node):
     """Infer output datatype(s) for a particular node. Returns True if any
     changes were made."""
-    idtypes = map(lambda x: model.get_tensor_datatype(x), node.input)
-    odtypes = map(lambda x: model.get_tensor_datatype(x), node.output)
+    idtypes = list(map(lambda x: model.get_tensor_datatype(x), node.input))
+    odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output))
     if node.op_type == "MultiThreshold":
         # number of thresholds decides # output buts, use get_smallest_possible
         n_thres = model.get_tensor_shape(node.input[1])[1]
@@ -15,13 +15,13 @@ def _infer_node_datatype(model, node):
         # always produces bipolar outputs
         model.set_tensor_datatype(node.output[0], DataType.BIPOLAR)
     elif node.op_type == "MatMul":
-        if len(filter(lambda x: x == DataType.FLOAT32, idtypes)) != 0:
+        if len(list(filter(lambda x: x == DataType.FLOAT32, idtypes))) != 0:
             # node has at least one float input, output is also float
             model.set_tensor_datatype(node.output[0], DataType.FLOAT32)
         else:
             # TODO compute minimum / maximum result to minimize bitwidth
             # use (u)int32 accumulators for now
-            has_signed_inp = len(filter(lambda x: x.signed(), idtypes)) != 0
+            has_signed_inp = len(list(filter(lambda x: x.signed(), idtypes))) != 0
             if has_signed_inp:
                 odtype = DataType.INT32
             else:
@@ -32,8 +32,8 @@ def _infer_node_datatype(model, node):
         for o in node.output:
             model.set_tensor_datatype(o, DataType.FLOAT32)
     # compare old and new output dtypes to see if anything changed
-    new_odtypes = map(lambda x: model.get_tensor_datatype(x), node.output)
-    graph_modified = new_odtypes == odtypes
+    new_odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output))
+    graph_modified = new_odtypes != odtypes
     return graph_modified