From fb612a1aadf94df9855bdc29107291f851c61097 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <yamanu@xilinx.com> Date: Tue, 3 Sep 2019 15:50:51 +0100 Subject: [PATCH] remove Tensor, rename file to datatype.py let's try to stick with ONNX Tensor defs and see what happens --- src/finn/core/{tensor.py => datatype.py} | 16 ------- tests/test_datatypes.py | 60 ++++++++++++------------ 2 files changed, 30 insertions(+), 46 deletions(-) rename src/finn/core/{tensor.py => datatype.py} (89%) diff --git a/src/finn/core/tensor.py b/src/finn/core/datatype.py similarity index 89% rename from src/finn/core/tensor.py rename to src/finn/core/datatype.py index 257078ad1..9323d0c42 100644 --- a/src/finn/core/tensor.py +++ b/src/finn/core/datatype.py @@ -109,19 +109,3 @@ class DataType(Enum): return value in [-1, +1] else: raise Exception("Unrecognized data type: %s" % self.name) - - -class Tensor(object): - """A multidimensional array of numbers of given datatype. - - Attributes: - dtype (DataType): Element data type for this Tensor - data (numpy ndarray of float32): Numpy container for data - dim_names (list of str): names associated with each dimension, e.g. - ["N", "C", "H", "W"] - """ - - def __init__(self, dtype, data, dim_names=[]): - self.dtype = dtype - self.data = data - self.dim_names = dim_names diff --git a/tests/test_datatypes.py b/tests/test_datatypes.py index 98d5d387d..c2929fa15 100644 --- a/tests/test_datatypes.py +++ b/tests/test_datatypes.py @@ -1,35 +1,35 @@ # -*- coding: utf-8 -*- -import finn.core.tensor as ten +import finn.core.datatype as dt def test_datatypes(): - assert ten.DataType.BIPOLAR.allowed(-1) - assert ten.DataType.BIPOLAR.allowed(0) is False - assert ten.DataType.BINARY.allowed(-1) is False - assert ten.DataType.BINARY.allowed(1) - assert ten.DataType.UINT2.allowed(2) - assert ten.DataType.UINT2.allowed(10) is False - assert ten.DataType.UINT3.allowed(5) - assert ten.DataType.UINT3.allowed(-7) is False - assert ten.DataType.UINT4.allowed(15) - assert ten.DataType.UINT4.allowed(150) is False - assert ten.DataType.UINT8.allowed(150) - assert ten.DataType.UINT8.allowed(777) is False - assert ten.DataType.UINT16.allowed(14500) - assert ten.DataType.UINT16.allowed(-1) is False - assert ten.DataType.UINT32.allowed(2 ** 10) - assert ten.DataType.UINT32.allowed(-1) is False - assert ten.DataType.INT2.allowed(-1) - assert ten.DataType.INT2.allowed(-10) is False - assert ten.DataType.INT3.allowed(5) is False - assert ten.DataType.INT3.allowed(-2) - assert ten.DataType.INT4.allowed(15) is False - assert ten.DataType.INT4.allowed(-5) - assert ten.DataType.INT8.allowed(150) is False - assert ten.DataType.INT8.allowed(-127) - assert ten.DataType.INT16.allowed(-1.04) is False - assert ten.DataType.INT16.allowed(-7777) - assert ten.DataType.INT32.allowed(7.77) is False - assert ten.DataType.INT32.allowed(-5) - assert ten.DataType.INT32.allowed(5) + assert dt.DataType.BIPOLAR.allowed(-1) + assert dt.DataType.BIPOLAR.allowed(0) is False + assert dt.DataType.BINARY.allowed(-1) is False + assert dt.DataType.BINARY.allowed(1) + assert dt.DataType.UINT2.allowed(2) + assert dt.DataType.UINT2.allowed(10) is False + assert dt.DataType.UINT3.allowed(5) + assert dt.DataType.UINT3.allowed(-7) is False + assert dt.DataType.UINT4.allowed(15) + assert dt.DataType.UINT4.allowed(150) is False + assert dt.DataType.UINT8.allowed(150) + assert dt.DataType.UINT8.allowed(777) is False + assert dt.DataType.UINT16.allowed(14500) + assert dt.DataType.UINT16.allowed(-1) is False + assert dt.DataType.UINT32.allowed(2 ** 10) + assert dt.DataType.UINT32.allowed(-1) is False + assert dt.DataType.INT2.allowed(-1) + assert dt.DataType.INT2.allowed(-10) is False + assert dt.DataType.INT3.allowed(5) is False + assert dt.DataType.INT3.allowed(-2) + assert dt.DataType.INT4.allowed(15) is False + assert dt.DataType.INT4.allowed(-5) + assert dt.DataType.INT8.allowed(150) is False + assert dt.DataType.INT8.allowed(-127) + assert dt.DataType.INT16.allowed(-1.04) is False + assert dt.DataType.INT16.allowed(-7777) + assert dt.DataType.INT32.allowed(7.77) is False + assert dt.DataType.INT32.allowed(-5) + assert dt.DataType.INT32.allowed(5) -- GitLab