Skip to content
Snippets Groups Projects
Commit ec09c481 authored by auphelia's avatar auphelia
Browse files

[Test] Added checking for shape and right FINN data type and added binary case

parent ac844480
No related branches found
No related tags found
No related merge requests found
......@@ -180,7 +180,9 @@ def gen_FINN_dt_tensor(FINN_dt, tensor_shape):
if FINN_dt == DataType.BIPOLAR:
tensor_values = np.random.randint(2, size=tensor_shape)
tensor_values = 2 * tensor_values - 1
elif FINN_dt == DataType.BINARY:
tensor_values = np.random.randint(2, size=tensor_shape)
else:
raise ValueError("Datatype {} is not supported, no tensor could be generated".format(FINN_dt))
return[tensor_values]
return tensor_values
import numpy as np
import finn.core.utils as util
from finn.core.datatype import DataType
......@@ -6,8 +7,28 @@ def test_FINN_tensor_generator():
shape_bp = [2,2]
dt_bp = DataType.BIPOLAR
tensor_bp = util.gen_FINN_dt_tensor(dt_bp, shape_bp)
import pdb; pdb.set_trace()
print(tensor_bp)
# test shape
for i in range(len(shape_bp)):
assert shape_bp[i] == tensor_bp.shape[i], """Shape of generated tensor
does not match the desired shape"""
# test if elements are FINN datatype
for value in tensor_bp.flatten():
assert dt_bp.allowed(value), """Data type of generated tensor
does not match the desired Data type"""
# binary
shape_b = [4,2,3]
dt_b = DataType.BINARY
tensor_b = util.gen_FINN_dt_tensor(dt_b, shape_b)
# test shape
for i in range(len(shape_b)):
assert shape_b[i] == tensor_b.shape[i], """Shape of generated tensor
does not match the desired shape"""
# test if elements are FINN datatype
for value in tensor_b.flatten():
assert dt_b.allowed(value), """Data type of generated tensor
does not match the desired Data type"""
#import pdb; pdb.set_trace()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment