Skip to content
Snippets Groups Projects
Unverified Commit 4e67ce80 authored by Yaman Umuroglu's avatar Yaman Umuroglu Committed by GitHub
Browse files

Merge pull request #112 from quetric/feature/GiveUniqueParameterTensors

Add new transformation to give unique parameter tensors
parents 4bcd0bc3 f33971c1
No related branches found
No related tags found
No related merge requests found
......@@ -81,6 +81,41 @@ class GiveReadableTensorNames(Transformation):
return (model, False)
class GiveUniqueParameterTensors(Transformation):
"""Make every parameter tensor unique. The aim is to avoid affecting
other nodes apart from the one the system is currently operating on."""
def apply(self, model):
graph = model.graph
graph_modified = False
seen_parameters = []
for n in graph.node:
# copy inputs since they may be modified
node_inputs_list = [x for x in n.input]
for input_idx, node_input in enumerate(node_inputs_list):
# check if it's a parameter
input_init = model.get_initializer(node_input)
if input_init is None:
# dynamic input
continue
# check if repeated
if node_input not in seen_parameters:
# first occurance
seen_parameters += [node_input]
continue
new_param_name = model.make_new_valueinfo_name()
model.set_initializer(new_param_name, input_init)
model.set_tensor_datatype(new_param_name, model.get_tensor_datatype(node_input))
# point node input to new tensor
n.input[input_idx] = new_param_name
return (model, graph_modified)
class ConvertSubToAdd(Transformation):
"""Convert subtract-a-constant nodes to add-a-constant nodes."""
......
......@@ -31,6 +31,12 @@ from pkgutil import get_data
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.general import GiveUniqueNodeNames
import numpy as np
import onnx
import finn.core.onnx_exec as oxe
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.general import GiveUniqueParameterTensors
def test_give_unique_node_names():
raw_m = get_data("finn", "data/onnx/mnist-conv/model.onnx")
......@@ -39,3 +45,76 @@ def test_give_unique_node_names():
assert model.graph.node[0].name == "Reshape_0"
assert model.graph.node[1].name == "Conv_0"
assert model.graph.node[11].name == "Add_2"
def test_give_unique_parameter_tensors():
# Create model
input_shape = [4, 4]
in1 = onnx.helper.make_tensor_value_info("in1", onnx.TensorProto.FLOAT, input_shape)
out1 = onnx.helper.make_tensor_value_info(
"out1", onnx.TensorProto.FLOAT, input_shape
)
graph_nodes = []
graph_nodes += [
onnx.helper.make_node("Add", inputs=["in1", "param1"], outputs=["t1"])
]
graph_nodes += [
onnx.helper.make_node("Sum", inputs=["t1", "param1", "param1"], outputs=["t2"])
]
graph_nodes += [
onnx.helper.make_node("Sum", inputs=["t2", "param2", "param1"], outputs=["t3"])
]
graph_nodes += [
onnx.helper.make_node("Add", inputs=["t3", "param1"], outputs=["out1"])
]
onnx_graph = onnx.helper.make_graph(
nodes=graph_nodes, name="simple_graph", inputs=[in1], outputs=[out1],
)
onnx_model = onnx.helper.make_model(onnx_graph, producer_name="simple-model")
model = ModelWrapper(onnx_model)
# Set param values
np.random.seed(0)
param1 = np.random.rand(*input_shape).astype(np.float32)
param2 = np.random.rand(*input_shape).astype(np.float32)
model.set_initializer("param1", param1)
model.set_initializer("param2", param2)
model = model.transform(InferShapes())
# Apply transformation
new_model = model.transform(GiveUniqueParameterTensors())
new_model = new_model.transform(InferShapes())
# Test
# Breaks the model?
input_tensor = np.random.rand(*input_shape).astype(np.float32)
input_dict = {"in1": input_tensor}
# run original
expected_context = oxe.execute_onnx(model, input_dict)
expected_output = expected_context[model.graph.output[0].name]
# run modified
produced_context = oxe.execute_onnx(new_model, input_dict)
produced_output = produced_context[new_model.graph.output[0].name]
assert np.isclose(
expected_output, produced_output, atol=1e-8
).all(), " GiveUniqueParameterTensors() transform breaks the model"
# Does the job?
param_set = set()
param_cnt = 0
for n in new_model.graph.node:
for i in range(1, len(n.input)):
param_set |= {n.input[i]}
param_cnt += 1
assert len(param_set) == param_cnt, " There are still parameters reused"
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment