Skip to content
Snippets Groups Projects
Commit 3bce6698 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Transform] fix container problems in infer_datatypes

parent 68c7f4e0
No related branches found
No related tags found
No related merge requests found
......@@ -4,8 +4,8 @@ from finn.core.datatype import DataType
def _infer_node_datatype(model, node):
"""Infer output datatype(s) for a particular node. Returns True if any
changes were made."""
idtypes = map(lambda x: model.get_tensor_datatype(x), node.input)
odtypes = map(lambda x: model.get_tensor_datatype(x), node.output)
idtypes = list(map(lambda x: model.get_tensor_datatype(x), node.input))
odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output))
if node.op_type == "MultiThreshold":
# number of thresholds decides # output buts, use get_smallest_possible
n_thres = model.get_tensor_shape(node.input[1])[1]
......@@ -15,13 +15,13 @@ def _infer_node_datatype(model, node):
# always produces bipolar outputs
model.set_tensor_datatype(node.output[0], DataType.BIPOLAR)
elif node.op_type == "MatMul":
if len(filter(lambda x: x == DataType.FLOAT32, idtypes)) != 0:
if len(list(filter(lambda x: x == DataType.FLOAT32, idtypes))) != 0:
# node has at least one float input, output is also float
model.set_tensor_datatype(node.output[0], DataType.FLOAT32)
else:
# TODO compute minimum / maximum result to minimize bitwidth
# use (u)int32 accumulators for now
has_signed_inp = len(filter(lambda x: x.signed(), idtypes)) != 0
has_signed_inp = len(list(filter(lambda x: x.signed(), idtypes))) != 0
if has_signed_inp:
odtype = DataType.INT32
else:
......@@ -32,8 +32,8 @@ def _infer_node_datatype(model, node):
for o in node.output:
model.set_tensor_datatype(o, DataType.FLOAT32)
# compare old and new output dtypes to see if anything changed
new_odtypes = map(lambda x: model.get_tensor_datatype(x), node.output)
graph_modified = new_odtypes == odtypes
new_odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output))
graph_modified = new_odtypes != odtypes
return graph_modified
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment