Skip to content
Snippets Groups Projects
Commit ca07070a authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Transform] call InferDataTypes after bipolar2xnor

parent 40e224c0
No related branches found
No related tags found
No related merge requests found
......@@ -33,6 +33,7 @@ from onnx import helper as oh
from finn.core.datatype import DataType
from finn.transformation import Transformation
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.infer_datatypes import InferDataTypes
from finn.util.basic import get_by_name
......@@ -68,6 +69,7 @@ class ConvertBipolarMatMulToXnorPopcount(Transformation):
"""Could not find upstream bipolar
MultiThreshold"""
)
graph_modified = True
mt = mt_chain[-1]
bin_dt_attr = "BINARY".encode("utf-8")
get_by_name(mt.attribute, "out_dtype").s = bin_dt_attr
......@@ -83,7 +85,6 @@ class ConvertBipolarMatMulToXnorPopcount(Transformation):
K = Wbin.shape[0]
model.set_initializer(mm_weight, Wbin)
model.set_tensor_datatype(mm_weight, DataType.BINARY)
graph_modified = True
# make new output node with correct shape
mm_out_shape = model.get_tensor_shape(mm_output)
xnorpcout = oh.make_tensor_value_info(
......@@ -121,5 +122,7 @@ class ConvertBipolarMatMulToXnorPopcount(Transformation):
# insert where the batchnorm is to preserve topological ordering
graph.node.insert(node_ind, mul_node)
graph.node.insert(node_ind + 1, add_node)
model = model.transform(InferShapes())
if graph_modified:
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())
return (model, graph_modified)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment