Skip to content
Snippets Groups Projects
Commit e4bb7c2d authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Refactor] move onnx-importing util to own file

parent 07b3b2db
No related branches found
No related tags found
No related merge requests found
......@@ -6,6 +6,7 @@ import onnx.numpy_helper as np_helper
from onnx import TensorProto
import finn.util.basic as util
import finn.util.onnx as onnxutil
from finn.core.datatype import DataType
......@@ -248,14 +249,14 @@ class ModelWrapper:
graph = self._model_proto.graph
# make empty tensors for all the graph inputs and outputs
for vi in graph.input:
new_tensor = util.valueinfo_to_tensor(vi)
new_tensor = onnxutil.valueinfo_to_tensor(vi)
execution_context[vi.name] = new_tensor
for vi in graph.output:
new_tensor = util.valueinfo_to_tensor(vi)
new_tensor = onnxutil.valueinfo_to_tensor(vi)
execution_context[vi.name] = new_tensor
# make empty tensors for all intermediate buffers
for vi in graph.value_info:
new_tensor = util.valueinfo_to_tensor(vi)
new_tensor = onnxutil.valueinfo_to_tensor(vi)
execution_context[vi.name] = new_tensor
# fill in the constants provided by the initializers (TensorProto to npy)
for t in graph.initializer:
......
......@@ -5,7 +5,6 @@ import subprocess
import tempfile
import numpy as np
import onnx
from finn.core.datatype import DataType
......@@ -38,15 +37,6 @@ def make_build_dir(prefix=""):
)
def valueinfo_to_tensor(vi):
"""Creates an all-zeroes numpy tensor from a ValueInfoProto."""
dims = [x.dim_value for x in vi.type.tensor_type.shape.dim]
return np.zeros(
dims, dtype=onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[vi.type.tensor_type.elem_type]
)
def get_by_name(container, name, name_field="name"):
"""Return item from container by .name field if it exists, None otherwise"""
names = [getattr(x, name_field) for x in container]
......
import numpy as np
import onnx
def valueinfo_to_tensor(vi):
"""Creates an all-zeroes numpy tensor from a ValueInfoProto."""
dims = [x.dim_value for x in vi.type.tensor_type.shape.dim]
return np.zeros(
dims, dtype=onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[vi.type.tensor_type.elem_type]
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment