diff --git a/src/finn/core/utils.py b/src/finn/core/utils.py index c69bd1b16f1d0c556ff92d23dd3be8a777ad6c6d..2f416c4af8d9c9024455e6c2672037be6fe21c2a 100644 --- a/src/finn/core/utils.py +++ b/src/finn/core/utils.py @@ -180,7 +180,9 @@ def gen_FINN_dt_tensor(FINN_dt, tensor_shape): if FINN_dt == DataType.BIPOLAR: tensor_values = np.random.randint(2, size=tensor_shape) tensor_values = 2 * tensor_values - 1 + elif FINN_dt == DataType.BINARY: + tensor_values = np.random.randint(2, size=tensor_shape) else: raise ValueError("Datatype {} is not supported, no tensor could be generated".format(FINN_dt)) - return[tensor_values] + return tensor_values diff --git a/tests/test_gen_FINN_dt_tensor.py b/tests/test_gen_FINN_dt_tensor.py index 5f4c89ea752567825d720c9a220f76f786dde2d7..4321498b5cc32dd783cfdd5e7c5e9cf8e765ed8c 100644 --- a/tests/test_gen_FINN_dt_tensor.py +++ b/tests/test_gen_FINN_dt_tensor.py @@ -1,3 +1,4 @@ +import numpy as np import finn.core.utils as util from finn.core.datatype import DataType @@ -6,8 +7,28 @@ def test_FINN_tensor_generator(): shape_bp = [2,2] dt_bp = DataType.BIPOLAR tensor_bp = util.gen_FINN_dt_tensor(dt_bp, shape_bp) - import pdb; pdb.set_trace() - print(tensor_bp) - # test shape + for i in range(len(shape_bp)): + assert shape_bp[i] == tensor_bp.shape[i], """Shape of generated tensor + does not match the desired shape""" + # test if elements are FINN datatype + for value in tensor_bp.flatten(): + assert dt_bp.allowed(value), """Data type of generated tensor + does not match the desired Data type""" + + # binary + shape_b = [4,2,3] + dt_b = DataType.BINARY + tensor_b = util.gen_FINN_dt_tensor(dt_b, shape_b) + # test shape + for i in range(len(shape_b)): + assert shape_b[i] == tensor_b.shape[i], """Shape of generated tensor + does not match the desired shape""" # test if elements are FINN datatype + for value in tensor_b.flatten(): + assert dt_b.allowed(value), """Data type of generated tensor + does not match the desired Data type""" + + + #import pdb; pdb.set_trace() +