Skip to content
Snippets Groups Projects
Commit d7fe1ef4 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Transform] AbsorbMulIntoTopK -> AbsorbMulAddIntoTopK

parent 232810ec
No related branches found
No related tags found
No related merge requests found
......@@ -34,6 +34,7 @@ from finn.core.datatype import DataType
from finn.custom_op.fpgadataflow import HLSCustomOp
from onnx import TensorProto, helper
from finn.util.data_packing import npy_to_rtlsim_input, rtlsim_output_to_npy
from finn.util.basic import roundup_to_integer_multiple
class LabelSelect_Batch(HLSCustomOp):
......@@ -46,6 +47,10 @@ class LabelSelect_Batch(HLSCustomOp):
# If not provided compute min size
labels = self.get_nodeattr("Labels")
odt = DataType.get_smallest_possible(labels - 1)
# ensure a datatype divisible by 8-bits in case this is the last node
bw = roundup_to_integer_multiple(odt.bitwidth(), 8)
new_odt_name = odt.name.replace(str(odt.bitwidth()), str(bw))
odt = DataType[new_odt_name]
odt_name = odt.name
self.set_nodeattr("outputDataType", odt_name)
......
......@@ -424,8 +424,9 @@ class AbsorbTransposeIntoFlatten(Transformation):
return (model, graph_modified)
class AbsorbScalarMulIntoTopK(Transformation):
"""Absorb a mul node into a suceeding topk node if the mul is scalar."""
class AbsorbScalarMulAddIntoTopK(Transformation):
"""Remove mul/add node prior to topk node if the op is scalar. Note that
the TopK output probabilities will change, but the indices won't."""
def apply(self, model):
graph = model.graph
......@@ -435,14 +436,17 @@ class AbsorbScalarMulIntoTopK(Transformation):
node_ind += 1
if n.op_type == "TopK":
prod = model.find_producer(n.input[0])
if prod is not None and prod.op_type == "Mul":
if prod is not None and (prod.op_type in ["Mul", "Add"]):
prod_input = prod.input[0]
param_name = prod.input[1]
A = model.get_initializer(param_name)
if A is None:
warnings.warn("Param is not constant, skipping")
continue
if all(x == 1 for x in A.shape) and A > 0:
is_scalar = all(x == 1 for x in A.shape)
is_scalar_pos_mul = is_scalar and (prod.op_type == "Mul") and A > 0
is_scalar_add = is_scalar and (prod.op_type == "Add")
if is_scalar_pos_mul or is_scalar_add:
# if the mul is scalar and positive, we can just delete the
# mul node and rewire the top k node. Because the top k node
# works with probabilities and their relation to each other
......
......@@ -52,7 +52,7 @@ from finn.transformation.fpgadataflow.prepare_cppsim import PrepareCppSim
from finn.transformation.fpgadataflow.compile_cppsim import CompileCppSim
from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
from finn.transformation.streamline.absorb import (
AbsorbScalarMulIntoTopK,
AbsorbScalarMulAddIntoTopK,
AbsorbConsecutiveTransposes,
)
from finn.transformation.streamline.collapse_repeated import (
......@@ -192,7 +192,7 @@ def test_convert_to_hls_layers_synthetic(ch, ifmdim, idt):
model = model.transform(to_hls.InferGlobalAccPoolLayer())
model = model.transform(MoveScalarLinearPastInvariants())
model = model.transform(InsertTopK())
model = model.transform(AbsorbScalarMulIntoTopK())
model = model.transform(AbsorbScalarMulAddIntoTopK())
model = model.transform(InferDataTypes())
model = model.transform(to_hls.InferLabelSelectLayer())
model = model.transform(AbsorbConsecutiveTransposes())
......
......@@ -35,7 +35,7 @@ from finn.transformation.infer_shapes import InferShapes
from finn.transformation.infer_datatypes import InferDataTypes
from finn.transformation.general import GiveUniqueNodeNames, GiveReadableTensorNames
from finn.transformation.insert_topk import InsertTopK
from finn.transformation.streamline.absorb import AbsorbScalarMulIntoTopK
from finn.transformation.streamline.absorb import AbsorbScalarMulAddIntoTopK
import finn.core.onnx_exec as oxe
# parameter to indicate if mul parameter is negative or positive
......@@ -49,20 +49,24 @@ def test_absorb_mul_into_topk(mul_positive, scalar):
shape = [1, 1, 1, 1000]
inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, 1, 1, 1000])
a0 = helper.make_tensor_value_info("a0", TensorProto.FLOAT, shape)
b0 = helper.make_tensor_value_info("b0", TensorProto.FLOAT, [1, 1, 1, 1000])
c0 = helper.make_tensor_value_info("c0", TensorProto.FLOAT, shape)
outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, [1, 1, 1, 1000])
mul_node = helper.make_node("Mul", ["inp", "a0"], ["outp"])
mul_node = helper.make_node("Mul", ["inp", "a0"], ["b0"])
add_node = helper.make_node("Add", ["b0", "c0"], ["outp"])
mul_graph = helper.make_graph(
nodes=[mul_node],
nodes=[mul_node, add_node],
name="mul-graph",
inputs=[inp],
outputs=[outp],
value_info=[a0],
value_info=[a0, b0, c0],
)
model = helper.make_model(mul_graph, producer_name="mul_model")
model = ModelWrapper(model)
# initialize values
# for mul
if mul_positive is True:
a0_values = np.random.uniform(low=0.1, high=1, size=tuple(shape)).astype(
np.float32
......@@ -72,12 +76,17 @@ def test_absorb_mul_into_topk(mul_positive, scalar):
np.float32
)
model.set_initializer("a0", a0_values)
# for add
c0_values = np.random.uniform(low=-1, high=-0.1, size=tuple(shape)).astype(
np.float32
)
model.set_initializer("c0", c0_values)
model = model.transform(InsertTopK())
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())
model = model.transform(GiveUniqueNodeNames())
model = model.transform(GiveReadableTensorNames())
model_transformed = model.transform(AbsorbScalarMulIntoTopK())
model_transformed = model.transform(AbsorbScalarMulAddIntoTopK())
# compare execution results
inp_values = np.random.uniform(low=-10, high=10, size=(1, 1, 1, 1000)).astype(
......@@ -100,9 +109,5 @@ def test_absorb_mul_into_topk(mul_positive, scalar):
# check for new order
assert model.graph != model_transformed.graph
assert len(model.graph.node) - 1 == len(model_transformed.graph.node)
assert len(model.graph.node) - 2 == len(model_transformed.graph.node)
assert model_transformed.graph.node[0].op_type == "TopK"
else:
assert (y_values == y_tr_values).all()
assert model.graph == model_transformed.graph
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment