diff --git a/src/finn/custom_op/fpgadataflow/duplicatestreams_batch.py b/src/finn/custom_op/fpgadataflow/duplicatestreams_batch.py index 54051af5e0387081a23e1f8fa77ec9e363098830..bb7ff6efa5c257e172bfa03b6a14251208ae4f1c 100644 --- a/src/finn/custom_op/fpgadataflow/duplicatestreams_batch.py +++ b/src/finn/custom_op/fpgadataflow/duplicatestreams_batch.py @@ -32,7 +32,7 @@ import numpy as np from finn.core.datatype import DataType from finn.custom_op.fpgadataflow import HLSCustomOp -from onnx import TensorProto, helper +from onnx import helper, TensorProto from finn.util.data_packing import npy_to_rtlsim_input, rtlsim_output_to_npy @@ -80,24 +80,33 @@ class DuplicateStreams_Batch(HLSCustomOp): def make_shape_compatible_op(self, model): exp_ishape = self.get_normal_input_shape() - oshape = self.get_normal_output_shape() ishape = tuple(model.get_tensor_shape(self.onnx_node.input[0])) assert ishape == exp_ishape, "Unexpected input shape." - # implement tensor with correct shape - values = np.random.randn(*oshape).astype(np.float32) + + oshape = self.get_normal_output_shape() + values = np.zeros(oshape).astype(np.float32) split_input = np.concatenate((values, values), axis=0) - return helper.make_node( + + split_in = helper.make_tensor_value_info( + model.make_new_valueinfo_name(), TensorProto.FLOAT, oshape + ) + + model.graph.value_info.append(split_in) # requires clean up + model.set_initializer(split_in.name, split_input) + + shape_comp_node = helper.make_node( "Split", - inputs=[split_input], - outputs=[self.onnx_node.output[0], self.onnx_node.output[0]], - value=helper.make_tensor( - name="const_tensor", data_type=TensorProto.FLOAT, axis=0 - ), + inputs=[split_in.name], + outputs=[self.onnx_node.output[0], self.onnx_node.output[1]], + axis=0, ) + return shape_comp_node + def infer_node_datatype(self, model): odt = self.get_output_datatype() model.set_tensor_datatype(self.onnx_node.output[0], odt) + model.set_tensor_datatype(self.onnx_node.output[1], odt) def verify_node(self): info_messages = [] diff --git a/tests/fpgadataflow/test_fpgadataflow_duplicatestreams.py b/tests/fpgadataflow/test_fpgadataflow_duplicatestreams.py index 4fb84be59333ef0e696204c9064fcf77e35b5d9b..59ac1c09f4fe338ef03a8166c63b9d4b29bbc08e 100644 --- a/tests/fpgadataflow/test_fpgadataflow_duplicatestreams.py +++ b/tests/fpgadataflow/test_fpgadataflow_duplicatestreams.py @@ -33,6 +33,8 @@ from onnx import TensorProto, helper import finn.core.onnx_exec as oxe from finn.core.datatype import DataType from finn.core.modelwrapper import ModelWrapper +from finn.transformation.infer_shapes import InferShapes +from finn.transformation.infer_datatypes import InferDataTypes from finn.transformation.fpgadataflow.prepare_ip import PrepareIP from finn.transformation.fpgadataflow.prepare_cppsim import PrepareCppSim from finn.transformation.fpgadataflow.compile_cppsim import CompileCppSim @@ -72,6 +74,9 @@ def make_dupstreams_modelwrapper(ch, pe, idim, idt): model.set_tensor_datatype("inp", idt) + model = model.transform(InferShapes()) + model = model.transform(InferDataTypes()) + return model