diff --git a/src/finn/core/datatype.py b/src/finn/core/datatype.py
index 9323d0c42fc37c7dc95e21cf6305dc40dd605b5c..42a366aafcc002a433d0e03c03ef6a9bed6adede 100644
--- a/src/finn/core/datatype.py
+++ b/src/finn/core/datatype.py
@@ -24,27 +24,29 @@
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-from enum import Enum
+from enum import Enum, auto
 
 import numpy as np
 
 
 class DataType(Enum):
-    FLOAT32 = 0
-    BINARY = 1
-    BIPOLAR = 2
-    UINT2 = 3
-    UINT3 = 4
-    UINT4 = 5
-    UINT8 = 6
-    UINT16 = 7
-    UINT32 = 8
-    INT2 = 9
-    INT3 = 10
-    INT4 = 11
-    INT8 = 12
-    INT16 = 13
-    INT32 = 14
+    # important to maintain ordering here: unsigned to signed, fewer to more
+    # bits. The get_smallest_possible() member function is dependent on this.
+    BINARY = auto()
+    UINT2 = auto()
+    UINT3 = auto()
+    UINT4 = auto()
+    UINT8 = auto()
+    UINT16 = auto()
+    UINT32 = auto()
+    BIPOLAR = auto()
+    INT2 = auto()
+    INT3 = auto()
+    INT4 = auto()
+    INT8 = auto()
+    INT16 = auto()
+    INT32 = auto()
+    FLOAT32 = auto()
 
     def bitwidth(self):
         """Returns the number of bits required for this DataType."""
@@ -109,3 +111,22 @@ class DataType(Enum):
             return value in [-1, +1]
         else:
             raise Exception("Unrecognized data type: %s" % self.name)
+
+    def get_smallest_possible(value):
+        """Return smallest (fewest bits) possible DataType that can represent
+        value. Prefers unsigned integers where possible."""
+        if not int(value) == value:
+            return DataType["FLOAT32"]
+        for k in DataType.__members__:
+            dt = DataType[k]
+            if (dt.min() <= value) and (value <= dt.max()):
+                return dt
+
+    def signed(self):
+        """Return whether this DataType can represent negative numbers."""
+        return self.min() < 0
+
+    def is_integer(self):
+        """Return whether this DataType represents integer values only."""
+        # only FLOAT32 is noninteger for now
+        return self != DataType.FLOAT32
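
The enum ordering above is what makes get_smallest_possible() work: it walks DataType.__members__ in declaration order (unsigned before signed, narrow before wide) and returns the first member whose [min(), max()] range contains the value, falling back to FLOAT32 for non-integers. A minimal sketch of the expected behavior, assuming the enum exactly as declared in this patch:

    from finn.core.datatype import DataType

    # unsigned types are preferred for non-negative values
    assert DataType.get_smallest_possible(1) == DataType.BINARY
    assert DataType.get_smallest_possible(3) == DataType.UINT2
    # negative values skip past all the unsigned members
    assert DataType.get_smallest_possible(-1) == DataType.BIPOLAR
    # anything non-integer maps to FLOAT32
    assert DataType.get_smallest_possible(0.5) == DataType.FLOAT32
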
diff --git a/src/finn/transformation/infer_datatypes.py b/src/finn/transformation/infer_datatypes.py
new file mode 100644
index 0000000000000000000000000000000000000000..19d947b57d045d4d7f2523f0f392adaba5bb367a
--- /dev/null
+++ b/src/finn/transformation/infer_datatypes.py
@@ -0,0 +1,47 @@
+from finn.core.datatype import DataType
+
+
+def _infer_node_datatype(model, node):
+    """Infer output datatype(s) for a particular node. Returns True if any
+    changes were made."""
+    idtypes = list(map(lambda x: model.get_tensor_datatype(x), node.input))
+    odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output))
+    if node.op_type == "MultiThreshold":
+        # number of thresholds determines the number of output bits;
+        # use get_smallest_possible
+        n_thres = model.get_tensor_shape(node.input[1])[1]
+        odtype = DataType.get_smallest_possible(n_thres)
+        model.set_tensor_datatype(node.output[0], odtype)
+    elif node.op_type == "Sign":
+        # always produces bipolar outputs
+        model.set_tensor_datatype(node.output[0], DataType.BIPOLAR)
+    elif node.op_type == "MatMul":
+        if len(list(filter(lambda x: x == DataType.FLOAT32, idtypes))) != 0:
+            # node has at least one float input, output is also float
+            model.set_tensor_datatype(node.output[0], DataType.FLOAT32)
+        else:
+            # TODO compute minimum / maximum result to minimize bitwidth
+            # use (u)int32 accumulators for now
+            has_signed_inp = len(list(filter(lambda x: x.signed(), idtypes))) != 0
+            if has_signed_inp:
+                odtype = DataType.INT32
+            else:
+                odtype = DataType.UINT32
+            model.set_tensor_datatype(node.output[0], odtype)
+    else:
+        # unknown, assume node produces float32 outputs
+        for o in node.output:
+            model.set_tensor_datatype(o, DataType.FLOAT32)
+    # compare old and new output dtypes to see if anything changed
+    new_odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output))
+    graph_modified = new_odtypes != odtypes
+    return graph_modified
+
+
+def infer_datatypes(model):
+    """Infer FINN DataType info for all intermediate/output tensors based on
+    inputs and node type."""
+    graph = model.graph
+    graph_modified = False
+    for node in graph.node:
+        graph_modified |= _infer_node_datatype(model, node)
+    return (model, graph_modified)
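
Because each node's inferred output datatype becomes the input datatype of its consumers, a single sweep may not reach every tensor; the pass reports graph_modified so callers can re-run it until nothing changes. A usage sketch (assuming a ModelWrapper instance `model` whose shapes have already been inferred):

    import finn.transformation.infer_datatypes as id

    # transform_repeated re-applies the pass until it reports no changes,
    # pushing datatype annotations through chains of nodes
    model = model.transform_repeated(id.infer_datatypes)
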
diff --git a/src/finn/transformation/streamline.py b/src/finn/transformation/streamline.py
index 309177e548be490a7de05aad6006b838b207f3bc..5053b23c3e09bc5f32807cdf16a997c4647dd165 100644
--- a/src/finn/transformation/streamline.py
+++ b/src/finn/transformation/streamline.py
@@ -363,3 +363,22 @@
                 graph.node.remove(consumer)
                 graph_modified = True
     return (model, graph_modified)
+
+
+def round_thresholds(model):
+    """For MultiThreshold nodes operating on integer inputs, round threshold
+    values up to the nearest integer."""
+    graph = model.graph
+    graph_modified = False
+    for n in graph.node:
+        if n.op_type == "MultiThreshold":
+            idtype = model.get_tensor_datatype(n.input[0])
+            T = model.get_initializer(n.input[1])
+            Tnew = np.ceil(T)
+            if idtype.is_integer() and (T != Tnew).any():
+                # round the thresholds up to the nearest integer
+                model.set_initializer(n.input[1], Tnew)
+                # use same datatype as inputs for thresholds
+                model.set_tensor_datatype(n.input[1], idtype)
+                graph_modified = True
+    return (model, graph_modified)
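
Rounding up is behavior-preserving here because, for an integer input x and a real threshold t, x >= t holds exactly when x >= ceil(t) (assuming the >= comparison semantics of FINN's MultiThreshold). A small numpy illustration of that equivalence, separate from the patch itself:

    import numpy as np

    T = np.asarray([-1.1, 0.7, 2.3, 5.1])  # original float thresholds
    x = np.arange(-4, 8).reshape(-1, 1)    # a range of integer input values
    # original and ceil'd thresholds classify every integer identically
    assert ((x >= T) == (x >= np.ceil(T))).all()
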
diff --git a/tests/test_datatypes.py b/tests/test_datatypes.py
index c2929fa156a593eb7d44f63cafc55da05b7c3dae..559314117f03f505d8ddb4afb88ba440f4dcd966 100644
--- a/tests/test_datatypes.py
+++ b/tests/test_datatypes.py
@@ -1,35 +1,46 @@
 # -*- coding: utf-8 -*-
-import finn.core.datatype as dt
+from finn.core.datatype import DataType
 
 
 def test_datatypes():
-    assert dt.DataType.BIPOLAR.allowed(-1)
-    assert dt.DataType.BIPOLAR.allowed(0) is False
-    assert dt.DataType.BINARY.allowed(-1) is False
-    assert dt.DataType.BINARY.allowed(1)
-    assert dt.DataType.UINT2.allowed(2)
-    assert dt.DataType.UINT2.allowed(10) is False
-    assert dt.DataType.UINT3.allowed(5)
-    assert dt.DataType.UINT3.allowed(-7) is False
-    assert dt.DataType.UINT4.allowed(15)
-    assert dt.DataType.UINT4.allowed(150) is False
-    assert dt.DataType.UINT8.allowed(150)
-    assert dt.DataType.UINT8.allowed(777) is False
-    assert dt.DataType.UINT16.allowed(14500)
-    assert dt.DataType.UINT16.allowed(-1) is False
-    assert dt.DataType.UINT32.allowed(2 ** 10)
-    assert dt.DataType.UINT32.allowed(-1) is False
-    assert dt.DataType.INT2.allowed(-1)
-    assert dt.DataType.INT2.allowed(-10) is False
-    assert dt.DataType.INT3.allowed(5) is False
-    assert dt.DataType.INT3.allowed(-2)
-    assert dt.DataType.INT4.allowed(15) is False
-    assert dt.DataType.INT4.allowed(-5)
-    assert dt.DataType.INT8.allowed(150) is False
-    assert dt.DataType.INT8.allowed(-127)
-    assert dt.DataType.INT16.allowed(-1.04) is False
-    assert dt.DataType.INT16.allowed(-7777)
-    assert dt.DataType.INT32.allowed(7.77) is False
-    assert dt.DataType.INT32.allowed(-5)
-    assert dt.DataType.INT32.allowed(5)
+    assert DataType.BIPOLAR.allowed(-1)
+    assert DataType.BIPOLAR.allowed(0) is False
+    assert DataType.BINARY.allowed(-1) is False
+    assert DataType.BINARY.allowed(1)
+    assert DataType.UINT2.allowed(2)
+    assert DataType.UINT2.allowed(10) is False
+    assert DataType.UINT3.allowed(5)
+    assert DataType.UINT3.allowed(-7) is False
+    assert DataType.UINT4.allowed(15)
+    assert DataType.UINT4.allowed(150) is False
+    assert DataType.UINT8.allowed(150)
+    assert DataType.UINT8.allowed(777) is False
+    assert DataType.UINT16.allowed(14500)
+    assert DataType.UINT16.allowed(-1) is False
+    assert DataType.UINT32.allowed(2 ** 10)
+    assert DataType.UINT32.allowed(-1) is False
+    assert DataType.INT2.allowed(-1)
+    assert DataType.INT2.allowed(-10) is False
+    assert DataType.INT3.allowed(5) is False
+    assert DataType.INT3.allowed(-2)
+    assert DataType.INT4.allowed(15) is False
+    assert DataType.INT4.allowed(-5)
+    assert DataType.INT8.allowed(150) is False
+    assert DataType.INT8.allowed(-127)
+    assert DataType.INT16.allowed(-1.04) is False
+    assert DataType.INT16.allowed(-7777)
+    assert DataType.INT32.allowed(7.77) is False
+    assert DataType.INT32.allowed(-5)
+    assert DataType.INT32.allowed(5)
+    assert DataType.BINARY.signed() is False
+    assert DataType.FLOAT32.signed()
+    assert DataType.BIPOLAR.signed()
+
+
+def test_smallest_possible():
+    assert DataType.get_smallest_possible(1) == DataType.BINARY
+    assert DataType.get_smallest_possible(1.1) == DataType.FLOAT32
+    assert DataType.get_smallest_possible(-1) == DataType.BIPOLAR
+    assert DataType.get_smallest_possible(-3) == DataType.INT3
+    assert DataType.get_smallest_possible(-3.2) == DataType.FLOAT32
diff --git a/tests/test_infer_datatypes.py b/tests/test_infer_datatypes.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd1c99a1da5f5bdb93f0b8b089fddb20e3c078ed
--- /dev/null
+++ b/tests/test_infer_datatypes.py
@@ -0,0 +1,40 @@
+import os
+
+import brevitas.onnx as bo
+import torch
+from models.LFC import LFC
+
+import finn.transformation.fold_constants as fc
+import finn.transformation.general as tg
+import finn.transformation.infer_datatypes as id
+import finn.transformation.infer_shapes as si
+from finn.core.datatype import DataType
+from finn.core.modelwrapper import ModelWrapper
+
+export_onnx_path = "test_output_lfc.onnx"
+# TODO get from config instead, hardcoded to Docker path for now
+trained_lfc_checkpoint = (
+    "/workspace/brevitas_cnv_lfc/pretrained_models/LFC_1W1A/checkpoints/best.tar"
+)
+
+
+def test_infer_datatypes():
+    lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1)
+    checkpoint = torch.load(trained_lfc_checkpoint, map_location="cpu")
+    lfc.load_state_dict(checkpoint["state_dict"])
+    bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
+    model = ModelWrapper(export_onnx_path)
+    model = model.transform_single(si.infer_shapes)
+    model = model.transform_repeated(fc.fold_constants)
+    model = model.transform_single(tg.give_unique_node_names)
+    model = model.transform_single(tg.give_readable_tensor_names)
+    model = model.transform_repeated(id.infer_datatypes)
+    assert model.get_tensor_datatype("MatMul_0_out0") == DataType.INT32
+    assert model.get_tensor_datatype("MatMul_1_out0") == DataType.INT32
+    assert model.get_tensor_datatype("MatMul_2_out0") == DataType.INT32
+    assert model.get_tensor_datatype("MatMul_3_out0") == DataType.INT32
+    assert model.get_tensor_datatype("Sign_0_out0") == DataType.BIPOLAR
+    assert model.get_tensor_datatype("Sign_1_out0") == DataType.BIPOLAR
+    assert model.get_tensor_datatype("Sign_2_out0") == DataType.BIPOLAR
+    assert model.get_tensor_datatype("Sign_3_out0") == DataType.BIPOLAR
+    os.remove(export_onnx_path)
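
The INT32 assertions on the MatMul outputs follow from the MatMul rule in _infer_node_datatype: the 1-bit LFC export annotates the MatMul operands with bipolar datatypes, BIPOLAR counts as signed (its min() is -1), and any signed integer input selects the INT32 accumulator. A worked restatement of that rule, using only names from this patch:

    from finn.core.datatype import DataType

    idtypes = [DataType.BIPOLAR, DataType.BIPOLAR]  # MatMul inputs in the 1-bit export
    has_signed_inp = any(dt.signed() for dt in idtypes)
    odtype = DataType.INT32 if has_signed_inp else DataType.UINT32
    assert odtype == DataType.INT32
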
diff --git a/tests/test_round_thresholds.py b/tests/test_round_thresholds.py
new file mode 100644
index 0000000000000000000000000000000000000000..28add7313201f1e30ecaf98d841f762a3f859835
--- /dev/null
+++ b/tests/test_round_thresholds.py
@@ -0,0 +1,38 @@
+import numpy as np
+from onnx import TensorProto, helper
+
+import finn.core.onnx_exec as oxe
+import finn.transformation.streamline as sl
+from finn.core.datatype import DataType
+from finn.core.modelwrapper import ModelWrapper
+
+
+def test_round_thresholds():
+    v = helper.make_tensor_value_info("v", TensorProto.FLOAT, [1, 4])
+    thresholds = helper.make_tensor_value_info("thresholds", TensorProto.FLOAT, [4, 1])
+    out = helper.make_tensor_value_info("out", TensorProto.FLOAT, [1, 4])
+    node_def = helper.make_node(
+        "MultiThreshold", ["v", "thresholds"], ["out"], domain="finn"
+    )
+    graph_def = helper.make_graph([node_def], "test_model", [v, thresholds], [out])
+    model_def = helper.make_model(graph_def)
+    model = ModelWrapper(model_def)
+    threshold_val = np.asarray([[-1.1], [0.7], [2.3], [5.1]], dtype=np.float32)
+    model.set_initializer("thresholds", threshold_val)
+    model.set_tensor_datatype("v", DataType.INT8)
+    inp_dict_f = {"v": np.floor(threshold_val).T}
+    inp_dict_n = {"v": np.round(threshold_val).T}
+    inp_dict_c = {"v": np.ceil(threshold_val).T}
+    orig_f = oxe.execute_onnx(model, inp_dict_f)["out"]
+    orig_n = oxe.execute_onnx(model, inp_dict_n)["out"]
+    orig_c = oxe.execute_onnx(model, inp_dict_c)["out"]
+    assert model.get_tensor_datatype("thresholds") == DataType.FLOAT32
+    new_model = model.transform_repeated(sl.round_thresholds)
+    # rounded-up thresholds should have the same dtype as the input
+    assert new_model.get_tensor_datatype("thresholds") == DataType.INT8
+    new_f = oxe.execute_onnx(new_model, inp_dict_f)["out"]
+    new_n = oxe.execute_onnx(new_model, inp_dict_n)["out"]
+    new_c = oxe.execute_onnx(new_model, inp_dict_c)["out"]
+    assert np.isclose(orig_f, new_f, atol=1e-3).all()
+    assert np.isclose(orig_n, new_n, atol=1e-3).all()
+    assert np.isclose(orig_c, new_c, atol=1e-3).all()
diff --git a/tests/test_streamline.py b/tests/test_streamline.py
index e3d97a4364ce1c7dd46b9e4ed7e66d7076751fd0..1a93c6ff128b2169da4802d7367608f499e89beb 100644
--- a/tests/test_streamline.py
+++ b/tests/test_streamline.py
@@ -12,6 +12,7 @@ import finn.core.onnx_exec as oxe
 import finn.transformation.batchnorm_to_affine as ba
 import finn.transformation.fold_constants as fc
 import finn.transformation.general as tg
+import finn.transformation.infer_datatypes as di
 import finn.transformation.infer_shapes as si
 import finn.transformation.streamline as sl
 from finn.core.modelwrapper import ModelWrapper
@@ -53,12 +54,14 @@ def test_streamline_lfc_w1a1():
         sl.factor_out_mul_sign_magnitude,
         sl.absorb_mul_into_multi_threshold,
         sl.absorb_1bit_mul_into_matmul,
+        sl.round_thresholds,
     ]
     trn_ind = 0
     for trn in transforms:
         model = model.transform_repeated(trn)
         model = model.transform_single(tg.give_unique_node_names)
         model = model.transform_single(tg.give_readable_tensor_names)
+        model = model.transform_repeated(di.infer_datatypes)
         produced_ctx = oxe.execute_onnx(model, input_dict, True)
         produced = produced_ctx[model.graph.output[0].name]
         # model.save("%d-%s.onnx" % (trn_ind, trn.__name__))
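
One ordering detail in the streamlining loop: round_thresholds only does anything once the MultiThreshold input is known to be an integer type (the idtype.is_integer() guard above), which is why di.infer_datatypes is re-run after every transform to keep the annotations current. Condensed to its essentials (a hypothetical two-step pipeline, not the full test):

    model = model.transform_repeated(di.infer_datatypes)  # marks MultiThreshold inputs as (U)INT32
    model = model.transform_repeated(sl.round_thresholds)  # guard now sees integer inputs and rounds
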