Skip to content
Snippets Groups Projects
Commit 67a44b4b authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

Merge branch 'quetric-feature/insert_topk' into dev

parents b2dc7f81 c7c7b6c5
No related branches found
No related tags found
No related merge requests found
......@@ -253,14 +253,12 @@ class ModelWrapper:
return None
def find_producer(self, tensor_name):
"""Finds and returns the node that produces the tensor with given name.
Currently only works for linear graphs."""
all_outputs = [x.output[0] for x in self._model_proto.graph.node]
try:
producer_ind = all_outputs.index(tensor_name)
return self._model_proto.graph.node[producer_ind]
except ValueError:
return None
"""Finds and returns the node that produces the tensor with given name."""
ret = None
for x in self._model_proto.graph.node:
if tensor_name in x.output:
ret = x
return ret
def find_upstream(self, tensor_name, finder_fxn):
"""Follow the producer chain upstream, calling finder_fxn on each upstream
......
......@@ -61,6 +61,10 @@ def execute_node(node, context, graph):
# onnxruntime unfortunately does not implement run_node as defined by ONNX,
# it can only execute entire models -- so we create a model which solely
# consists of our current node.
# note: ensure that the same ValueInfo does not appear both in
# graph.value_info as well as graph.output or graph.input
# nodes with multiple outputs that are a mix of value_info and
# input/outputs may get them reordered below
node_inputs = list(filter(lambda x: x.name in node.input, graph.input))
node_inputs += list(
filter(lambda x: x.name in node.input, graph.value_info)
......@@ -84,17 +88,25 @@ def execute_node(node, context, graph):
output_list = sess.run(None, input_dict)
for output_ind in range(len(node.output)):
# get the name of the target buffer from node.output
outp = node.output[output_ind]
if output_list[output_ind].shape != context[outp].shape:
# retrieve the index of that name in node_outputs
for i in range(len(node_outputs)):
if outp == node_outputs[i].name:
list_ind = i
# use that index to index output_list
if output_list[list_ind].shape != context[outp].shape:
raise Exception(
"""Output shapes disagree after node execution:
found %s vs expected %s"""
% (
str(output_list[output_ind].shape.shape),
str(output_list[list_ind].shape.shape),
str(context[outp].shape),
)
)
context[outp] = output_list[output_ind]
context[outp] = output_list[list_ind]
def execute_onnx(model, input_dict, return_full_exec_context=False):
......
# Copyright (c) 2020, Xilinx
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of FINN nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import numpy as np
from onnx import TensorProto
from onnx import helper as oh
from finn.transformation import Transformation
from finn.core.datatype import DataType
class InsertTopK(Transformation):
"""Add TopK node at the network output and replace the graph output with
the TopK indices."""
def __init__(self, k=5, axis=-1, largest=1, sorted=1):
super().__init__()
self.k = k
self.axis = axis
self.largest = largest
self.sorted = sorted
def apply(self, model):
# get name of output tensor
graph_out_name = model.graph.output[0].name
# find final node
final_node = model.find_producer(graph_out_name)
# if a top-select op is already present, do nothing
if final_node.op_type == "TopK":
return (model, False)
else:
out_shape = model.get_tensor_shape(graph_out_name)
out_dtype = model.get_tensor_datatype(graph_out_name)
# adjust shape
out_shape[self.axis] = self.k
# make new buffer
k_tensor = np.array([self.k]).astype(np.int64)
k_value = oh.make_tensor_value_info(
model.make_new_valueinfo_name(), TensorProto.INT64, [1]
)
topk_values = oh.make_tensor_value_info(
model.make_new_valueinfo_name(), TensorProto.FLOAT, out_shape
)
topk_indices = oh.make_tensor_value_info(
model.make_new_valueinfo_name(), TensorProto.INT64, out_shape
)
model.graph.value_info.append(k_value)
model.set_tensor_datatype(k_value.name, out_dtype) # TODO set to int64
model.graph.value_info.append(topk_values)
model.set_tensor_datatype(topk_values.name, out_dtype)
# create and append topk node
model.set_initializer(k_value.name, k_tensor)
topk_node = oh.make_node(
"TopK",
inputs=[graph_out_name, k_value.name],
outputs=[topk_values.name, topk_indices.name],
axis=self.axis,
largest=self.largest,
sorted=self.sorted,
)
model.graph.node.append(topk_node)
# replace the existing output definition with topk indices
model.graph.output.insert(0, topk_indices)
model.graph.output.pop(1)
# set quantization annotation for indices
# minimal output dtype for TopK indices dependens on num. classes
# assuming UINT32 is large enough for now (FINN has currently no
# DataType.INT64)
model.set_tensor_datatype(topk_indices.name, DataType.UINT32)
return (model, True)
import onnx
from finn.util.test import get_test_model_trained
import brevitas.onnx as bo
import numpy as np
import onnx.numpy_helper as nph
import torch
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.infer_datatypes import InferDataTypes
from finn.transformation.fold_constants import FoldConstants
from finn.transformation.insert_topk import InsertTopK
import finn.core.onnx_exec as oxe
from pkgutil import get_data
import pytest
export_onnx_path = "test_output_lfc.onnx"
@pytest.mark.parametrize("k", [1, 5, 10])
def test_topk_insert(k):
tfc = get_test_model_trained("TFC", 1, 1)
bo.export_finn_onnx(tfc, (1, 1, 28, 28), export_onnx_path)
model = ModelWrapper(export_onnx_path)
# do transformations (no topk)
model = model.transform(InferShapes())
model = model.transform(FoldConstants())
model = model.transform(GiveUniqueNodeNames())
model = model.transform(GiveReadableTensorNames())
model = model.transform(InferDataTypes())
# verification: generate random input, run through net, streamline,
# run again, check that output is top-k
raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb")
input_tensor = onnx.load_tensor_from_string(raw_i)
input_brevitas = torch.from_numpy(nph.to_array(input_tensor)).float()
output_golden = tfc.forward(input_brevitas).detach().numpy()
output_golden_topk = np.flip(output_golden.flatten().argsort())[:k]
output_golden_topk = output_golden_topk.flatten()
input_dict = {"global_in": nph.to_array(input_tensor)}
# insert top-k
model = model.transform(InsertTopK(k))
model = model.transform(GiveUniqueNodeNames())
model = model.transform(GiveReadableTensorNames())
model = model.transform(InferShapes())
# verify output of top-k
output_dict_topk = oxe.execute_onnx(model, input_dict)
output_pysim_topk = output_dict_topk[list(output_dict_topk.keys())[0]]
output_pysim_topk = output_pysim_topk.astype(np.int).flatten()
assert np.array_equal(output_golden_topk, output_pysim_topk)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment