Skip to content
Snippets Groups Projects
Commit 99035686 authored by Lucian Petrică's avatar Lucian Petrică
Browse files

Switched to tensor intializer instead of constant node

parent 04281211
No related branches found
No related tags found
No related merge requests found
......@@ -59,10 +59,7 @@ class InsertTopK(Transformation):
#adjust shape
out_shape[self.axis] = self.k
# make new buffer
k_tensor = oh.make_tensor(name='k_tensor',
data_type=TensorProto.INT64,
dims=(1,),
vals=np.array([self.k]).astype(np.int64))
k_tensor = np.array([self.k]).astype(np.int64)
k_value = oh.make_tensor_value_info(
model.make_new_valueinfo_name(), TensorProto.INT64, [1]
)
......@@ -77,12 +74,7 @@ class InsertTopK(Transformation):
model.graph.value_info.append(topk_values)
model.set_tensor_datatype(topk_values.name, out_dtype)
#create and append topk node
k_node = oh.make_node(
'Constant',
inputs=[],
outputs=[k_value.name],
value=k_tensor
)
model.set_initializer(k_value.name, k_tensor)
topk_node = oh.make_node(
'TopK',
inputs=[graph_out_name, k_value.name],
......@@ -91,7 +83,6 @@ class InsertTopK(Transformation):
largest=self.largest,
sorted=self.sorted
)
model.graph.node.append(k_node)
model.graph.node.append(topk_node)
#replace the existing output definition with topk indices
model.graph.output.insert(0,topk_indices)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment