Skip to content
Snippets Groups Projects
Commit fc3ec9c7 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

fix issues from pre-commit

parent 6dd32dd1
No related branches found
No related tags found
No related merge requests found
...@@ -24,88 +24,95 @@ ...@@ -24,88 +24,95 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import numpy as np
from enum import Enum from enum import Enum
import numpy as np
class DataType(Enum): class DataType(Enum):
FLOAT32 = 0 FLOAT32 = 0
BINARY = 1 BINARY = 1
BIPOLAR = 2 BIPOLAR = 2
UINT2 = 3 UINT2 = 3
UINT3 = 4 UINT3 = 4
UINT4 = 5 UINT4 = 5
UINT8 = 6 UINT8 = 6
UINT16 = 7 UINT16 = 7
UINT32 = 8 UINT32 = 8
INT2 = 9 INT2 = 9
INT3 = 10 INT3 = 10
INT4 = 11 INT4 = 11
INT8 = 12 INT8 = 12
INT16 = 13 INT16 = 13
INT32 = 14 INT32 = 14
def bitwidth(self): def bitwidth(self):
"""Returns the number of bits required for this DataType.""" """Returns the number of bits required for this DataType."""
if self.name.startswith("UINT"): if self.name.startswith("UINT"):
return int(self.name.strip("UINT")) return int(self.name.strip("UINT"))
elif self.name.startswith("INT"): elif self.name.startswith("INT"):
return int(self.name.strip("INT")) return int(self.name.strip("INT"))
elif "FLOAT" in self.name: elif "FLOAT" in self.name:
return int(self.name.strip("FLOAT")) return int(self.name.strip("FLOAT"))
elif self.name in ["BINARY", "BIPOLAR"]: elif self.name in ["BINARY", "BIPOLAR"]:
return 1 return 1
else: else:
raise Exception("Unrecognized data type: %s" % self.name) raise Exception("Unrecognized data type: %s" % self.name)
def min(self): def min(self):
"""Returns the smallest possible value allowed by this DataType.""" """Returns the smallest possible value allowed by this DataType."""
if self.name.startswith("UINT") or self.name == "BINARY": if self.name.startswith("UINT") or self.name == "BINARY":
return 0 return 0
elif self.name.startswith("INT"): elif self.name.startswith("INT"):
return -(2 ** (self.bitwidth() - 1)) return -(2 ** (self.bitwidth() - 1))
elif self.name == "FLOAT32": elif self.name == "FLOAT32":
return np.finfo(np.float32).min return np.finfo(np.float32).min
elif self.name == "BIPOLAR": elif self.name == "BIPOLAR":
return -1 return -1
else: else:
raise Exception("Unrecognized data type: %s" % self.name) raise Exception("Unrecognized data type: %s" % self.name)
def max(self): def max(self):
"""Returns the largest possible value allowed by this DataType.""" """Returns the largest possible value allowed by this DataType."""
if self.name.startswith("UINT"): if self.name.startswith("UINT"):
return (2 ** (self.bitwidth())) - 1 return (2 ** (self.bitwidth())) - 1
elif self.name == "BINARY": elif self.name == "BINARY":
return +1 return +1
elif self.name.startswith("INT"): elif self.name.startswith("INT"):
return (2 ** (self.bitwidth() - 1)) - 1 return (2 ** (self.bitwidth() - 1)) - 1
elif self.name == "FLOAT32": elif self.name == "FLOAT32":
return np.finfo(np.float32).max return np.finfo(np.float32).max
elif self.name == "BIPOLAR": elif self.name == "BIPOLAR":
return +1 return +1
else: else:
raise Exception("Unrecognized data type: %s" % self.name) raise Exception("Unrecognized data type: %s" % self.name)
def allowed(self, value): def allowed(self, value):
"""Check whether given value is allowed for this DataType. """Check whether given value is allowed for this DataType.
value (float32): value to be checked""" value (float32): value to be checked"""
if "FLOAT" in self.name: if "FLOAT" in self.name:
return True return True
elif "INT" in self.name: elif "INT" in self.name:
return (self.min() <= value) and (value <= self.max()) and float(value).is_integer() return (
elif self.name == "BINARY": (self.min() <= value)
return value in [0, 1] and (value <= self.max())
elif self.name == "BIPOLAR": and float(value).is_integer()
return value in [-1, +1] )
else: elif self.name == "BINARY":
raise Exception("Unrecognized data type: %s" % self.name) return value in [0, 1]
elif self.name == "BIPOLAR":
return value in [-1, +1]
else:
raise Exception("Unrecognized data type: %s" % self.name)
class Tensor(object): class Tensor(object):
"""A multidimensional array of numbers of given datatype. """A multidimensional array of numbers of given datatype.
Attributes: Attributes:
dtype (DataType): Element data type for this Tensor dtype (DataType): Element data type for this Tensor
...@@ -114,7 +121,7 @@ class Tensor(object): ...@@ -114,7 +121,7 @@ class Tensor(object):
["N", "C", "H", "W"] ["N", "C", "H", "W"]
""" """
def __init__(self, dtype, data, dim_names = []): def __init__(self, dtype, data, dim_names=[]):
self.dtype = dtype self.dtype = dtype
self.data = data self.data = data
self.dim_names = dim_names self.dim_names = dim_names
...@@ -24,52 +24,55 @@ ...@@ -24,52 +24,55 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import numpy as np
import onnx import onnx
import onnx.helper as helper import onnx.helper as helper
import numpy as np
from functools import reduce
from onnx import numpy_helper as np_helper
import onnxruntime as rt import onnxruntime as rt
from onnx import numpy_helper as np_helper
model = onnx.load_model("model.onnx") model = onnx.load_model("model.onnx")
graph = model.graph graph = model.graph
def valueinfo_to_tensor(vi): def valueinfo_to_tensor(vi):
"""Creates an all-zeroes numpy tensor from a ValueInfoProto.""" """Creates an all-zeroes numpy tensor from a ValueInfoProto."""
dims = [x.dim_value for x in vi.type.tensor_type.shape.dim]
return np.zeros(
dims, dtype=onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[vi.type.tensor_type.elem_type]
)
dims = [x.dim_value for x in vi.type.tensor_type.shape.dim]
return np.zeros(dims, dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[vi.type.tensor_type.elem_type])
def execute_node(node, context, graph): def execute_node(node, context, graph):
"""Call onnxruntime to execute a single node. Input/output provided via context.""" """Call onnxruntime to execute a single node. Input/output provided via context."""
# onnxruntime unfortunately does not implement run_node as defined by ONNX,
# it can only execute entire models -- so we create a model which solely
# consists of our current node.
node_inputs = list(filter(lambda x: x.name in node.input, graph.input))
node_inputs += list(filter(lambda x: x.name in node.input, graph.value_info))
node_outputs = list(filter(lambda x: x.name in node.output, graph.output))
node_outputs += list(filter(lambda x: x.name in node.output, graph.value_info))
node_graph = helper.make_graph(
nodes=[node], name="single-node-exec", inputs=node_inputs, outputs=node_outputs
)
node_model = helper.make_model(node_graph)
input_dict = dict()
for inp in node.input:
input_dict[inp] = context[inp]
print("Input shape for %s: %s" % (inp, context[inp].shape))
sess = rt.InferenceSession(node_model.SerializeToString())
output_list = sess.run(None, input_dict)
for output_ind in range(len(node.output)):
outp = node.output[output_ind]
if output_list[output_ind].shape != context[outp].shape:
raise Exception(
"Output shapes disagree after node execution: found %s vs expected %s"
% (str(output_list[output_ind].shape.shape), str(context[outp].shape))
)
context[outp] = output_list[output_ind]
print("Output shape for %s: %s" % (outp, context[outp].shape))
# onnxruntime unfortunately does not implement run_node as defined by ONNX,
# it can only execute entire models -- so we create a model which solely
# consists of our current node.
node_inputs = list(filter(lambda x: x.name in node.input, graph.input))
node_inputs += list(filter(lambda x: x.name in node.input, graph.value_info))
node_outputs = list(filter(lambda x: x.name in node.output, graph.output))
node_outputs += (list(filter(lambda x: x.name in node.output, graph.value_info)))
node_graph = helper.make_graph(
nodes = [node],
name = "single-node-exec",
inputs = node_inputs,
outputs = node_outputs
)
node_model = helper.make_model(node_graph)
input_dict = dict()
for inp in node.input:
input_dict[inp] = context[inp]
print("Input shape for %s: %s" % (inp, context[inp].shape))
sess = rt.InferenceSession(node_model.SerializeToString())
output_list = sess.run(None, input_dict)
for output_ind in range(len(node.output)):
outp = node.output[output_ind]
if output_list[output_ind].shape != context[outp].shape:
raise Exception("Output shapes disagree after node execution: found %s vs expected %s" % (str(output_list[output_ind].shape.shape), str(context[outp].shape)))
context[outp] = output_list[output_ind]
print("Output shape for %s: %s" % (outp, context[outp].shape))
# first, we need to make sure that every variable required by the graph has # first, we need to make sure that every variable required by the graph has
# some buffer associated with it. this includes graph inputs (which includes # some buffer associated with it. this includes graph inputs (which includes
...@@ -79,32 +82,32 @@ def execute_node(node, context, graph): ...@@ -79,32 +82,32 @@ def execute_node(node, context, graph):
execution_context = dict() execution_context = dict()
# make empty tensors for all the graph inputs and outputs # make empty tensors for all the graph inputs and outputs
for vi in graph.input: for vi in graph.input:
new_tensor = valueinfo_to_tensor(vi) new_tensor = valueinfo_to_tensor(vi)
execution_context[vi.name] = new_tensor execution_context[vi.name] = new_tensor
for vi in graph.output: for vi in graph.output:
new_tensor = valueinfo_to_tensor(vi) new_tensor = valueinfo_to_tensor(vi)
execution_context[vi.name] = new_tensor execution_context[vi.name] = new_tensor
# make empty tensors for all intermediate buffers # make empty tensors for all intermediate buffers
# TODO are we guaranteed to have the .value_info filled? # TODO are we guaranteed to have the .value_info filled?
# do we need to call ONNX shape inference first? # do we need to call ONNX shape inference first?
for vi in graph.value_info: for vi in graph.value_info:
new_tensor = valueinfo_to_tensor(vi) new_tensor = valueinfo_to_tensor(vi)
execution_context[vi.name] = new_tensor execution_context[vi.name] = new_tensor
# fill in the constants provided by the initializers (TensorProto to npy) # fill in the constants provided by the initializers (TensorProto to npy)
for t in graph.initializer: for t in graph.initializer:
execution_context[t.name] = np_helper.to_array(t) execution_context[t.name] = np_helper.to_array(t)
# now call each node in the graph nodes list # now call each node in the graph nodes list
# we can simply walk down the list since the ONNX spec guarantees that it is # we can simply walk down the list since the ONNX spec guarantees that it is
# topologically sorted # topologically sorted
all_used_ops = set() all_used_ops = set()
for node in graph.node: for node in graph.node:
print("Node name: %s Type: %s" % (node.name, node.op_type)) print("Node name: %s Type: %s" % (node.name, node.op_type))
all_used_ops.add(node.op_type) all_used_ops.add(node.op_type)
print("Input(s): " + str(node.input)) print("Input(s): " + str(node.input))
print("Output(s): " + str(node.output)) print("Output(s): " + str(node.output))
print("Attribute(s): " + str(node.attribute)) print("Attribute(s): " + str(node.attribute))
execute_node(node, execution_context, graph) execute_node(node, execution_context, graph)
print("Final output(s): ") print("Final output(s): ")
print(execution_context[graph.output[0].name]) print(execution_context[graph.output[0].name])
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import pytest
import finn.core.tensor as ten import finn.core.tensor as ten
def test_datatypes(): def test_datatypes():
assert ten.DataType.BIPOLAR.allowed(-1) == True assert ten.DataType.BIPOLAR.allowed(-1) is True
assert ten.DataType.BIPOLAR.allowed(0) == False assert ten.DataType.BIPOLAR.allowed(0) is False
assert ten.DataType.BINARY.allowed(-1) == False assert ten.DataType.BINARY.allowed(-1) is False
assert ten.DataType.BINARY.allowed(-1) == True assert ten.DataType.BINARY.allowed(-1) is True
assert ten.DataType.UINT2.allowed(2) == True assert ten.DataType.UINT2.allowed(2) is True
assert ten.DataType.UINT2.allowed(10) == False assert ten.DataType.UINT2.allowed(10) is False
assert ten.DataType.UINT3.allowed(5) == True assert ten.DataType.UINT3.allowed(5) is True
assert ten.DataType.UINT3.allowed(-7) == False assert ten.DataType.UINT3.allowed(-7) is False
assert ten.DataType.UINT4.allowed(15) == True assert ten.DataType.UINT4.allowed(15) is True
assert ten.DataType.UINT4.allowed(150) == False assert ten.DataType.UINT4.allowed(150) is False
assert ten.DataType.UINT8.allowed(150) == True assert ten.DataType.UINT8.allowed(150) is True
assert ten.DataType.UINT8.allowed(777) == False assert ten.DataType.UINT8.allowed(777) is False
assert ten.DataType.UINT16.allowed(14500) == True assert ten.DataType.UINT16.allowed(14500) is True
assert ten.DataType.UINT16.allowed(-1) == False assert ten.DataType.UINT16.allowed(-1) is False
assert ten.DataType.UINT32.allowed(2 ** 10) == True assert ten.DataType.UINT32.allowed(2 ** 10) is True
assert ten.DataType.UINT32.allowed(-1) == False assert ten.DataType.UINT32.allowed(-1) is False
assert ten.DataType.INT2.allowed(2) == True assert ten.DataType.INT2.allowed(2) is True
assert ten.DataType.INT2.allowed(-10) == False assert ten.DataType.INT2.allowed(-10) is False
assert ten.DataType.INT3.allowed(5) == False assert ten.DataType.INT3.allowed(5) is False
assert ten.DataType.INT3.allowed(-2) == True assert ten.DataType.INT3.allowed(-2) is True
assert ten.DataType.INT4.allowed(15) == False assert ten.DataType.INT4.allowed(15) is False
assert ten.DataType.INT4.allowed(-5) == True assert ten.DataType.INT4.allowed(-5) is True
assert ten.DataType.INT8.allowed(150) == False assert ten.DataType.INT8.allowed(150) is False
assert ten.DataType.INT8.allowed(-127) == True assert ten.DataType.INT8.allowed(-127) is True
assert ten.DataType.INT16.allowed(-1.04) == False assert ten.DataType.INT16.allowed(-1.04) is False
assert ten.DataType.INT16.allowed(-7777) == True assert ten.DataType.INT16.allowed(-7777) is True
assert ten.DataType.INT32.allowed(7.77) == False assert ten.DataType.INT32.allowed(7.77) is False
assert ten.DataType.INT32.allowed(-5) == True assert ten.DataType.INT32.allowed(-5) is True
assert ten.DataType.INT32.allowed(5) == True assert ten.DataType.INT32.allowed(5) is True
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment