Skip to content
Snippets Groups Projects
Commit 3dd4957c authored by auphelia's avatar auphelia
Browse files

[Tensor generator] Added option(+test) for ternary data type

parent ec09c481
No related branches found
No related tags found
No related merge requests found
......@@ -182,6 +182,8 @@ def gen_FINN_dt_tensor(FINN_dt, tensor_shape):
tensor_values = 2 * tensor_values - 1
elif FINN_dt == DataType.BINARY:
tensor_values = np.random.randint(2, size=tensor_shape)
elif FINN_dt == DataType.TERNARY:
tensor_values = np.random.randint(-1, high=1, size=tensor_shape)
else:
raise ValueError("Datatype {} is not supported, no tensor could be generated".format(FINN_dt))
......
......@@ -30,5 +30,19 @@ def test_FINN_tensor_generator():
does not match the desired Data type"""
# ternary
shape_t = [7,1,3,1]
dt_t = DataType.TERNARY
tensor_t = util.gen_FINN_dt_tensor(dt_t, shape_t)
# test shape
for i in range(len(shape_t)):
assert shape_t[i] == tensor_t.shape[i], """Shape of generated tensor
does not match the desired shape"""
# test if elements are FINN datatype
for value in tensor_t.flatten():
assert dt_t.allowed(value), """Data type of generated tensor
does not match the desired Data type"""
#import pdb; pdb.set_trace()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment