diff --git a/src/finn/core/datatype.py b/src/finn/core/datatype.py index b71922d479b62e5eba07ddd159e2e3243fa77d95..42a366aafcc002a433d0e03c03ef6a9bed6adede 100644 --- a/src/finn/core/datatype.py +++ b/src/finn/core/datatype.py @@ -125,3 +125,8 @@ class DataType(Enum): def signed(self): """Return whether this DataType can represent negative numbers.""" return self.min() < 0 + + def is_integer(self): + """Return whether this DataType represents integer values only.""" + # only FLOAT32 is noninteger for now + return self != DataType.FLOAT32 diff --git a/src/finn/transformation/streamline.py b/src/finn/transformation/streamline.py index 309177e548be490a7de05aad6006b838b207f3bc..5053b23c3e09bc5f32807cdf16a997c4647dd165 100644 --- a/src/finn/transformation/streamline.py +++ b/src/finn/transformation/streamline.py @@ -363,3 +363,22 @@ def absorb_1bit_mul_into_matmul(model): graph.node.remove(consumer) graph_modified = True return (model, graph_modified) + + +def round_thresholds(model): + """For MultiThreshold nodes operating on integer inputs, round up + thresholds values to the nearest integer.""" + graph = model.graph + graph_modified = False + for n in graph.node: + if n.op_type == "MultiThreshold": + idtype = model.get_tensor_datatype(n.input[0]) + T = model.get_initializer(n.input[1]) + Tnew = np.ceil(T) + if idtype.is_integer() and (T != Tnew).any(): + # round up the thresholds to nearest integer + model.set_initializer(n.input[1], Tnew) + # use same datatype as inputs for thresholds + model.set_tensor_datatype(n.input[1], idtype) + graph_modified = True + return (model, graph_modified) diff --git a/tests/test_round_thresholds.py b/tests/test_round_thresholds.py new file mode 100644 index 0000000000000000000000000000000000000000..28add7313201f1e30ecaf98d841f762a3f859835 --- /dev/null +++ b/tests/test_round_thresholds.py @@ -0,0 +1,38 @@ +import numpy as np +from onnx import TensorProto, helper + +import finn.core.onnx_exec as oxe +import finn.transformation.streamline as sl +from finn.core.datatype import DataType +from finn.core.modelwrapper import ModelWrapper + + +def test_round_thresholds(): + v = helper.make_tensor_value_info("v", TensorProto.FLOAT, [1, 4]) + thresholds = helper.make_tensor_value_info("thresholds", TensorProto.FLOAT, [4, 1]) + out = helper.make_tensor_value_info("out", TensorProto.FLOAT, [1, 4]) + node_def = helper.make_node( + "MultiThreshold", ["v", "thresholds"], ["out"], domain="finn" + ) + graph_def = helper.make_graph([node_def], "test_model", [v, thresholds], [out]) + model_def = helper.make_model(graph_def) + model = ModelWrapper(model_def) + threshold_val = np.asarray([[-1.1], [0.7], [2.3], [5.1]], dtype=np.float32) + model.set_initializer("thresholds", threshold_val) + model.set_tensor_datatype("v", DataType.INT8) + inp_dict_f = {"v": np.floor(threshold_val).T} + inp_dict_n = {"v": np.round(threshold_val).T} + inp_dict_c = {"v": np.ceil(threshold_val).T} + orig_f = oxe.execute_onnx(model, inp_dict_f)["out"] + orig_n = oxe.execute_onnx(model, inp_dict_n)["out"] + orig_c = oxe.execute_onnx(model, inp_dict_c)["out"] + assert model.get_tensor_datatype("thresholds") == DataType.FLOAT32 + new_model = model.transform_repeated(sl.round_thresholds) + # rounded up thresholds should have same dtype as input + assert new_model.get_tensor_datatype("thresholds") == DataType.INT8 + new_f = oxe.execute_onnx(new_model, inp_dict_f)["out"] + new_n = oxe.execute_onnx(new_model, inp_dict_n)["out"] + new_c = oxe.execute_onnx(new_model, inp_dict_c)["out"] + assert np.isclose(orig_f, new_f, atol=1e-3).all() + assert np.isclose(orig_n, new_n, atol=1e-3).all() + assert np.isclose(orig_c, new_c, atol=1e-3).all() diff --git a/tests/test_streamline.py b/tests/test_streamline.py index e3d97a4364ce1c7dd46b9e4ed7e66d7076751fd0..1a93c6ff128b2169da4802d7367608f499e89beb 100644 --- a/tests/test_streamline.py +++ b/tests/test_streamline.py @@ -12,6 +12,7 @@ import finn.core.onnx_exec as oxe import finn.transformation.batchnorm_to_affine as ba import finn.transformation.fold_constants as fc import finn.transformation.general as tg +import finn.transformation.infer_datatypes as di import finn.transformation.infer_shapes as si import finn.transformation.streamline as sl from finn.core.modelwrapper import ModelWrapper @@ -53,12 +54,14 @@ def test_streamline_lfc_w1a1(): sl.factor_out_mul_sign_magnitude, sl.absorb_mul_into_multi_threshold, sl.absorb_1bit_mul_into_matmul, + sl.round_thresholds, ] trn_ind = 0 for trn in transforms: model = model.transform_repeated(trn) model = model.transform_single(tg.give_unique_node_names) model = model.transform_single(tg.give_readable_tensor_names) + model = model.transform_repeated(di.infer_datatypes) produced_ctx = oxe.execute_onnx(model, input_dict, True) produced = produced_ctx[model.graph.output[0].name] # model.save("%d-%s.onnx" % (trn_ind, trn.__name__))