From 3dd4957cf358d374b3f82a3e3f4d41ead8be933c Mon Sep 17 00:00:00 2001 From: auphelia <jakobapk@web.de> Date: Tue, 26 Nov 2019 17:56:47 +0000 Subject: [PATCH] [Tensor generator] Added option(+test) for ternary data type --- src/finn/core/utils.py | 2 ++ tests/test_gen_FINN_dt_tensor.py | 14 ++++++++++++++ 2 files changed, 16 insertions(+) diff --git a/src/finn/core/utils.py b/src/finn/core/utils.py index 2f416c4af..6f18fe149 100644 --- a/src/finn/core/utils.py +++ b/src/finn/core/utils.py @@ -182,6 +182,8 @@ def gen_FINN_dt_tensor(FINN_dt, tensor_shape): tensor_values = 2 * tensor_values - 1 elif FINN_dt == DataType.BINARY: tensor_values = np.random.randint(2, size=tensor_shape) + elif FINN_dt == DataType.TERNARY: + tensor_values = np.random.randint(-1, high=1, size=tensor_shape) else: raise ValueError("Datatype {} is not supported, no tensor could be generated".format(FINN_dt)) diff --git a/tests/test_gen_FINN_dt_tensor.py b/tests/test_gen_FINN_dt_tensor.py index 4321498b5..0cb31a169 100644 --- a/tests/test_gen_FINN_dt_tensor.py +++ b/tests/test_gen_FINN_dt_tensor.py @@ -30,5 +30,19 @@ def test_FINN_tensor_generator(): does not match the desired Data type""" + # ternary + shape_t = [7,1,3,1] + dt_t = DataType.TERNARY + tensor_t = util.gen_FINN_dt_tensor(dt_t, shape_t) + # test shape + for i in range(len(shape_t)): + assert shape_t[i] == tensor_t.shape[i], """Shape of generated tensor + does not match the desired shape""" + # test if elements are FINN datatype + for value in tensor_t.flatten(): + assert dt_t.allowed(value), """Data type of generated tensor + does not match the desired Data type""" + + #import pdb; pdb.set_trace() -- GitLab