diff --git a/src/finn/core/tensor.py b/src/finn/core/datatype.py similarity index 89% rename from src/finn/core/tensor.py rename to src/finn/core/datatype.py index 257078ad183dc9df561971c6e3d70677e4b5282e..9323d0c42fc37c7dc95e21cf6305dc40dd605b5c 100644 --- a/src/finn/core/tensor.py +++ b/src/finn/core/datatype.py @@ -109,19 +109,3 @@ class DataType(Enum): return value in [-1, +1] else: raise Exception("Unrecognized data type: %s" % self.name) - - -class Tensor(object): - """A multidimensional array of numbers of given datatype. - - Attributes: - dtype (DataType): Element data type for this Tensor - data (numpy ndarray of float32): Numpy container for data - dim_names (list of str): names associated with each dimension, e.g. - ["N", "C", "H", "W"] - """ - - def __init__(self, dtype, data, dim_names=[]): - self.dtype = dtype - self.data = data - self.dim_names = dim_names diff --git a/tests/test_datatypes.py b/tests/test_datatypes.py index 98d5d387dd580ad6894daabe71075ce5e7d992f6..c2929fa156a593eb7d44f63cafc55da05b7c3dae 100644 --- a/tests/test_datatypes.py +++ b/tests/test_datatypes.py @@ -1,35 +1,35 @@ # -*- coding: utf-8 -*- -import finn.core.tensor as ten +import finn.core.datatype as dt def test_datatypes(): - assert ten.DataType.BIPOLAR.allowed(-1) - assert ten.DataType.BIPOLAR.allowed(0) is False - assert ten.DataType.BINARY.allowed(-1) is False - assert ten.DataType.BINARY.allowed(1) - assert ten.DataType.UINT2.allowed(2) - assert ten.DataType.UINT2.allowed(10) is False - assert ten.DataType.UINT3.allowed(5) - assert ten.DataType.UINT3.allowed(-7) is False - assert ten.DataType.UINT4.allowed(15) - assert ten.DataType.UINT4.allowed(150) is False - assert ten.DataType.UINT8.allowed(150) - assert ten.DataType.UINT8.allowed(777) is False - assert ten.DataType.UINT16.allowed(14500) - assert ten.DataType.UINT16.allowed(-1) is False - assert ten.DataType.UINT32.allowed(2 ** 10) - assert ten.DataType.UINT32.allowed(-1) is False - assert ten.DataType.INT2.allowed(-1) - assert ten.DataType.INT2.allowed(-10) is False - assert ten.DataType.INT3.allowed(5) is False - assert ten.DataType.INT3.allowed(-2) - assert ten.DataType.INT4.allowed(15) is False - assert ten.DataType.INT4.allowed(-5) - assert ten.DataType.INT8.allowed(150) is False - assert ten.DataType.INT8.allowed(-127) - assert ten.DataType.INT16.allowed(-1.04) is False - assert ten.DataType.INT16.allowed(-7777) - assert ten.DataType.INT32.allowed(7.77) is False - assert ten.DataType.INT32.allowed(-5) - assert ten.DataType.INT32.allowed(5) + assert dt.DataType.BIPOLAR.allowed(-1) + assert dt.DataType.BIPOLAR.allowed(0) is False + assert dt.DataType.BINARY.allowed(-1) is False + assert dt.DataType.BINARY.allowed(1) + assert dt.DataType.UINT2.allowed(2) + assert dt.DataType.UINT2.allowed(10) is False + assert dt.DataType.UINT3.allowed(5) + assert dt.DataType.UINT3.allowed(-7) is False + assert dt.DataType.UINT4.allowed(15) + assert dt.DataType.UINT4.allowed(150) is False + assert dt.DataType.UINT8.allowed(150) + assert dt.DataType.UINT8.allowed(777) is False + assert dt.DataType.UINT16.allowed(14500) + assert dt.DataType.UINT16.allowed(-1) is False + assert dt.DataType.UINT32.allowed(2 ** 10) + assert dt.DataType.UINT32.allowed(-1) is False + assert dt.DataType.INT2.allowed(-1) + assert dt.DataType.INT2.allowed(-10) is False + assert dt.DataType.INT3.allowed(5) is False + assert dt.DataType.INT3.allowed(-2) + assert dt.DataType.INT4.allowed(15) is False + assert dt.DataType.INT4.allowed(-5) + assert dt.DataType.INT8.allowed(150) is False + assert dt.DataType.INT8.allowed(-127) + assert dt.DataType.INT16.allowed(-1.04) is False + assert dt.DataType.INT16.allowed(-7777) + assert dt.DataType.INT32.allowed(7.77) is False + assert dt.DataType.INT32.allowed(-5) + assert dt.DataType.INT32.allowed(5)