diff --git a/src/finn/transformation/infer_datatypes.py b/src/finn/transformation/infer_datatypes.py index 089d0ff81e6e7d1593b421cd2224c32346f9eaf7..19d947b57d045d4d7f2523f0f392adaba5bb367a 100644 --- a/src/finn/transformation/infer_datatypes.py +++ b/src/finn/transformation/infer_datatypes.py @@ -4,8 +4,8 @@ from finn.core.datatype import DataType def _infer_node_datatype(model, node): """Infer output datatype(s) for a particular node. Returns True if any changes were made.""" - idtypes = map(lambda x: model.get_tensor_datatype(x), node.input) - odtypes = map(lambda x: model.get_tensor_datatype(x), node.output) + idtypes = list(map(lambda x: model.get_tensor_datatype(x), node.input)) + odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output)) if node.op_type == "MultiThreshold": # number of thresholds decides # output buts, use get_smallest_possible n_thres = model.get_tensor_shape(node.input[1])[1] @@ -15,13 +15,13 @@ def _infer_node_datatype(model, node): # always produces bipolar outputs model.set_tensor_datatype(node.output[0], DataType.BIPOLAR) elif node.op_type == "MatMul": - if len(filter(lambda x: x == DataType.FLOAT32, idtypes)) != 0: + if len(list(filter(lambda x: x == DataType.FLOAT32, idtypes))) != 0: # node has at least one float input, output is also float model.set_tensor_datatype(node.output[0], DataType.FLOAT32) else: # TODO compute minimum / maximum result to minimize bitwidth # use (u)int32 accumulators for now - has_signed_inp = len(filter(lambda x: x.signed(), idtypes)) != 0 + has_signed_inp = len(list(filter(lambda x: x.signed(), idtypes))) != 0 if has_signed_inp: odtype = DataType.INT32 else: @@ -32,8 +32,8 @@ def _infer_node_datatype(model, node): for o in node.output: model.set_tensor_datatype(o, DataType.FLOAT32) # compare old and new output dtypes to see if anything changed - new_odtypes = map(lambda x: model.get_tensor_datatype(x), node.output) - graph_modified = new_odtypes == odtypes + new_odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output)) + graph_modified = new_odtypes != odtypes return graph_modified