Commit f5fef066 authored by Yaman Umuroglu

Merge branch 'feature/roundup_thresholds' into dev

parents e00cf5cf cb61db2f
@@ -125,3 +125,8 @@ class DataType(Enum):
    def signed(self):
        """Return whether this DataType can represent negative numbers."""
        return self.min() < 0

    def is_integer(self):
        """Return whether this DataType represents integer values only."""
        # only FLOAT32 is noninteger for now
        return self != DataType.FLOAT32
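A quick illustration of the new predicate (a sketch, using only DataType members that already appear elsewhere in this commit):

from finn.core.datatype import DataType

# FLOAT32 is currently the only non-integer member, so every other DataType
# reports True here
assert DataType.INT8.is_integer()
assert not DataType.FLOAT32.is_integer()
# signed() is the existing companion predicate shown above
assert DataType.INT8.signed()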
@@ -363,3 +363,22 @@ def absorb_1bit_mul_into_matmul(model):
                    graph.node.remove(consumer)
                    graph_modified = True
    return (model, graph_modified)


def round_thresholds(model):
    """For MultiThreshold nodes operating on integer inputs, round up
    threshold values to the nearest integer."""
    graph = model.graph
    graph_modified = False
    for n in graph.node:
        if n.op_type == "MultiThreshold":
            idtype = model.get_tensor_datatype(n.input[0])
            T = model.get_initializer(n.input[1])
            Tnew = np.ceil(T)
            if idtype.is_integer() and (T != Tnew).any():
                # round up the thresholds to the nearest integer
                model.set_initializer(n.input[1], Tnew)
                # use the same datatype as the input for the thresholds
                model.set_tensor_datatype(n.input[1], idtype)
                graph_modified = True
    return (model, graph_modified)
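Why ceiling is the safe rounding here: for an integer input x and a real threshold t, x >= t holds exactly when x >= ceil(t), so replacing t with ceil(t) cannot change any MultiThreshold output once the input datatype is integer. A small numpy check of that equivalence (illustrative only):

import numpy as np

x = np.arange(-8, 9)              # integer input values
for t in [-1.1, 0.7, 2.3, 5.1]:   # non-integer thresholds, same values as the test below
    # for integer x, comparing against t and ceil(t) gives identical results
    assert ((x >= t) == (x >= np.ceil(t))).all()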
import numpy as np
from onnx import TensorProto, helper

import finn.core.onnx_exec as oxe
import finn.transformation.streamline as sl
from finn.core.datatype import DataType
from finn.core.modelwrapper import ModelWrapper


def test_round_thresholds():
    v = helper.make_tensor_value_info("v", TensorProto.FLOAT, [1, 4])
    thresholds = helper.make_tensor_value_info("thresholds", TensorProto.FLOAT, [4, 1])
    out = helper.make_tensor_value_info("out", TensorProto.FLOAT, [1, 4])
    node_def = helper.make_node(
        "MultiThreshold", ["v", "thresholds"], ["out"], domain="finn"
    )
    graph_def = helper.make_graph([node_def], "test_model", [v, thresholds], [out])
    model_def = helper.make_model(graph_def)
    model = ModelWrapper(model_def)
    threshold_val = np.asarray([[-1.1], [0.7], [2.3], [5.1]], dtype=np.float32)
    model.set_initializer("thresholds", threshold_val)
    model.set_tensor_datatype("v", DataType.INT8)
    inp_dict_f = {"v": np.floor(threshold_val).T}
    inp_dict_n = {"v": np.round(threshold_val).T}
    inp_dict_c = {"v": np.ceil(threshold_val).T}
    orig_f = oxe.execute_onnx(model, inp_dict_f)["out"]
    orig_n = oxe.execute_onnx(model, inp_dict_n)["out"]
    orig_c = oxe.execute_onnx(model, inp_dict_c)["out"]
    assert model.get_tensor_datatype("thresholds") == DataType.FLOAT32
    new_model = model.transform_repeated(sl.round_thresholds)
    # rounded up thresholds should have same dtype as input
    assert new_model.get_tensor_datatype("thresholds") == DataType.INT8
    new_f = oxe.execute_onnx(new_model, inp_dict_f)["out"]
    new_n = oxe.execute_onnx(new_model, inp_dict_n)["out"]
    new_c = oxe.execute_onnx(new_model, inp_dict_c)["out"]
    assert np.isclose(orig_f, new_f, atol=1e-3).all()
    assert np.isclose(orig_n, new_n, atol=1e-3).all()
    assert np.isclose(orig_c, new_c, atol=1e-3).all()
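The three input dictionaries probe the floor, round and ceiling of each threshold, i.e. the integer values right at the decision boundaries. A more exhaustive (but slower) variant could sweep every representable INT8 value; a sketch of such a check, meant to sit at the end of test_round_thresholds and assuming `model` still holds the original, unrounded thresholds at that point:

    # optional exhaustive check (sketch): the rounded model must agree with
    # the original on every integer value the INT8 input can take
    for x in range(-128, 128):
        inp = {"v": np.full((1, 4), x, dtype=np.float32)}
        orig_out = oxe.execute_onnx(model, inp)["out"]
        new_out = oxe.execute_onnx(new_model, inp)["out"]
        assert np.isclose(orig_out, new_out, atol=1e-3).all()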
@@ -12,6 +12,7 @@ import finn.core.onnx_exec as oxe
import finn.transformation.batchnorm_to_affine as ba
import finn.transformation.fold_constants as fc
import finn.transformation.general as tg
import finn.transformation.infer_datatypes as di
import finn.transformation.infer_shapes as si
import finn.transformation.streamline as sl
from finn.core.modelwrapper import ModelWrapper
@@ -53,12 +54,14 @@ def test_streamline_lfc_w1a1():
        sl.factor_out_mul_sign_magnitude,
        sl.absorb_mul_into_multi_threshold,
        sl.absorb_1bit_mul_into_matmul,
        sl.round_thresholds,
    ]
    trn_ind = 0
    for trn in transforms:
        model = model.transform_repeated(trn)
        model = model.transform_single(tg.give_unique_node_names)
        model = model.transform_single(tg.give_readable_tensor_names)
        model = model.transform_repeated(di.infer_datatypes)
        produced_ctx = oxe.execute_onnx(model, input_dict, True)
        produced = produced_ctx[model.graph.output[0].name]
        # model.save("%d-%s.onnx" % (trn_ind, trn.__name__))
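Every pass used here follows the (model, graph_modified) return convention seen in round_thresholds above, and transform_repeated re-applies a pass until it stops reporting changes; re-running infer_datatypes after each pass keeps the tensor datatype annotations current, which round_thresholds presumably relies on to detect integer inputs. A minimal sketch of such a fixed-point driver loop (illustrative only, not necessarily ModelWrapper's actual implementation):

def apply_repeated(model, transform):
    # hypothetical helper: re-apply the pass until it reports no further change
    modified = True
    while modified:
        (model, modified) = transform(model)
    return model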