diff --git a/src/finn/core/utils.py b/src/finn/core/utils.py index 1f2aa8510a6c5cb9d70aa19bb315a0224b69e1b2..c69bd1b16f1d0c556ff92d23dd3be8a777ad6c6d 100644 --- a/src/finn/core/utils.py +++ b/src/finn/core/utils.py @@ -174,3 +174,13 @@ def pad_tensor_to_multiple_of(ndarray, pad_to_dims, val=0, distr_pad=False): ret = np.pad(ndarray, pad_amt, mode="constant", constant_values=val) assert (np.asarray(ret.shape, dtype=np.int32) == desired).all() return ret + +def gen_FINN_dt_tensor(FINN_dt, tensor_shape): + # generates random tensor in given shape and with given FINN data type + if FINN_dt == DataType.BIPOLAR: + tensor_values = np.random.randint(2, size=tensor_shape) + tensor_values = 2 * tensor_values - 1 + else: + raise ValueError("Datatype {} is not supported, no tensor could be generated".format(FINN_dt)) + + return[tensor_values] diff --git a/tests/test_gen_FINN_dt_tensor.py b/tests/test_gen_FINN_dt_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..5f4c89ea752567825d720c9a220f76f786dde2d7 --- /dev/null +++ b/tests/test_gen_FINN_dt_tensor.py @@ -0,0 +1,13 @@ +import finn.core.utils as util +from finn.core.datatype import DataType + +def test_FINN_tensor_generator(): + # bipolar + shape_bp = [2,2] + dt_bp = DataType.BIPOLAR + tensor_bp = util.gen_FINN_dt_tensor(dt_bp, shape_bp) + import pdb; pdb.set_trace() + print(tensor_bp) + + # test shape + # test if elements are FINN datatype