Skip to content
Snippets Groups Projects
Commit ddd4fdbf authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[ConvertBipolarMatMulToXnorPopcount] support bipolar -> binary inp

parent b805e0bd
No related branches found
No related tags found
No related merge requests found
......@@ -27,6 +27,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import numpy as np
import warnings
from onnx import TensorProto
from onnx import helper as oh
......@@ -66,26 +67,40 @@ class ConvertBipolarMatMulToXnorPopcount(Transformation):
mt_chain = model.find_upstream(mm_input, find_prod_mt)
if len(mt_chain) == 0:
raise Exception(
"""Could not find upstream bipolar
MultiThreshold"""
)
graph_modified = True
mt = mt_chain[-1]
mt_inst = getCustomOp(mt)
# ensure old scale/bias were correct for BIPOLAR
scale_ok = mt_inst.get_nodeattr("out_scale") == 2.0
bias_ok = mt_inst.get_nodeattr("out_bias") == -1.0
assert (
scale_ok and bias_ok
), """Unexpected scale/bias
attributes for BIPOLAR MultiThreshold node."""
# start conversion, set MT output to binary
# (this is what XnorPopcountMatMul expects)
mt_inst.set_nodeattr("out_dtype", "BINARY")
mt_inst.set_nodeattr("out_scale", 1.0)
mt_inst.set_nodeattr("out_bias", 0.0)
model.set_tensor_datatype(mm_input, DataType.BINARY)
if mm_input == graph.input[0].name:
# change input datatype to BINARY
model.set_tensor_datatype(mm_input, DataType.BINARY)
graph_modified = True
warnings.warn(
"""IMPORTANT: Changing graph input DataType
to BINARY instead of BIPOLAR. Ensure this is respected
when checking for correctness.
"""
)
else:
raise Exception(
"""Could not find upstream bipolar
MultiThreshold, and the MatMul is not the
first node on graph input. Unable to convert
input tensor from BIPOLAR to BINARY."""
)
else:
graph_modified = True
mt = mt_chain[-1]
mt_inst = getCustomOp(mt)
# ensure old scale/bias were correct for BIPOLAR
scale_ok = mt_inst.get_nodeattr("out_scale") == 2.0
bias_ok = mt_inst.get_nodeattr("out_bias") == -1.0
assert (
scale_ok and bias_ok
), """Unexpected scale/bias
attributes for BIPOLAR MultiThreshold node."""
# start conversion, set MT output to binary
# (this is what XnorPopcountMatMul expects)
mt_inst.set_nodeattr("out_dtype", "BINARY")
mt_inst.set_nodeattr("out_scale", 1.0)
mt_inst.set_nodeattr("out_bias", 0.0)
model.set_tensor_datatype(mm_input, DataType.BINARY)
# change node type and domain
n.op_type = "XnorPopcountMatMul"
n.domain = "finn"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment