Skip to content
Snippets Groups Projects
Commit e0f7e034 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Transform] set wdt/odt as float32 for CollapseRepeated

parent 399c57ba
No related branches found
No related tags found
No related merge requests found
......@@ -30,6 +30,7 @@ from onnx import helper as oh
from finn.transformation import Transformation
from finn.transformation.infer_shapes import InferShapes
from finn.core.datatype import DataType
class CollapseRepeatedOp(Transformation):
......@@ -83,6 +84,9 @@ class CollapseRepeatedOp(Transformation):
graph.node.insert(node_ind, new_node)
# replace parameter value
model.set_initializer(new_node_param_name, new_param)
# be conservative with param/output DataTypes
model.set_tensor_datatype(new_node_param_name, DataType.FLOAT32)
model.set_tensor_datatype(end_name, DataType.FLOAT32)
# remove old nodes
graph.node.remove(n)
graph.node.remove(consumer)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment