From ec09c481262e864cc19e9b1c582245774137a17a Mon Sep 17 00:00:00 2001 From: auphelia <jakobapk@web.de> Date: Tue, 26 Nov 2019 17:42:33 +0000 Subject: [PATCH] [Test] Added checking for shape and right FINN data type and added binary case --- src/finn/core/utils.py | 4 +++- tests/test_gen_FINN_dt_tensor.py | 27 ++++++++++++++++++++++++--- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/src/finn/core/utils.py b/src/finn/core/utils.py index c69bd1b16..2f416c4af 100644 --- a/src/finn/core/utils.py +++ b/src/finn/core/utils.py @@ -180,7 +180,9 @@ def gen_FINN_dt_tensor(FINN_dt, tensor_shape): if FINN_dt == DataType.BIPOLAR: tensor_values = np.random.randint(2, size=tensor_shape) tensor_values = 2 * tensor_values - 1 + elif FINN_dt == DataType.BINARY: + tensor_values = np.random.randint(2, size=tensor_shape) else: raise ValueError("Datatype {} is not supported, no tensor could be generated".format(FINN_dt)) - return[tensor_values] + return tensor_values diff --git a/tests/test_gen_FINN_dt_tensor.py b/tests/test_gen_FINN_dt_tensor.py index 5f4c89ea7..4321498b5 100644 --- a/tests/test_gen_FINN_dt_tensor.py +++ b/tests/test_gen_FINN_dt_tensor.py @@ -1,3 +1,4 @@ +import numpy as np import finn.core.utils as util from finn.core.datatype import DataType @@ -6,8 +7,28 @@ def test_FINN_tensor_generator(): shape_bp = [2,2] dt_bp = DataType.BIPOLAR tensor_bp = util.gen_FINN_dt_tensor(dt_bp, shape_bp) - import pdb; pdb.set_trace() - print(tensor_bp) - # test shape + for i in range(len(shape_bp)): + assert shape_bp[i] == tensor_bp.shape[i], """Shape of generated tensor + does not match the desired shape""" + # test if elements are FINN datatype + for value in tensor_bp.flatten(): + assert dt_bp.allowed(value), """Data type of generated tensor + does not match the desired Data type""" + + # binary + shape_b = [4,2,3] + dt_b = DataType.BINARY + tensor_b = util.gen_FINN_dt_tensor(dt_b, shape_b) + # test shape + for i in range(len(shape_b)): + assert shape_b[i] == tensor_b.shape[i], """Shape of generated tensor + does not match the desired shape""" # test if elements are FINN datatype + for value in tensor_b.flatten(): + assert dt_b.allowed(value), """Data type of generated tensor + does not match the desired Data type""" + + + #import pdb; pdb.set_trace() + -- GitLab