Skip to content
Snippets Groups Projects
Unverified Commit 7ec36b01 authored by Yaman Umuroglu's avatar Yaman Umuroglu Committed by GitHub
Browse files

Merge pull request #194 from quetric/feature/fix_duplicate_batch_infer_shape

[HLSCustomOp] Fix inferShapes for InferDuplicateStreamsLayer
parents d701634a 8e1eea34
No related branches found
No related tags found
No related merge requests found
......@@ -32,7 +32,7 @@ import numpy as np
from finn.core.datatype import DataType
from finn.custom_op.fpgadataflow import HLSCustomOp
from onnx import TensorProto, helper
from onnx import helper, TensorProto
from finn.util.data_packing import npy_to_rtlsim_input, rtlsim_output_to_npy
......@@ -80,24 +80,33 @@ class DuplicateStreams_Batch(HLSCustomOp):
def make_shape_compatible_op(self, model):
exp_ishape = self.get_normal_input_shape()
oshape = self.get_normal_output_shape()
ishape = tuple(model.get_tensor_shape(self.onnx_node.input[0]))
assert ishape == exp_ishape, "Unexpected input shape."
# implement tensor with correct shape
values = np.random.randn(*oshape).astype(np.float32)
oshape = self.get_normal_output_shape()
values = np.zeros(oshape).astype(np.float32)
split_input = np.concatenate((values, values), axis=0)
return helper.make_node(
split_in = helper.make_tensor_value_info(
model.make_new_valueinfo_name(), TensorProto.FLOAT, oshape
)
model.graph.value_info.append(split_in) # requires clean up
model.set_initializer(split_in.name, split_input)
shape_comp_node = helper.make_node(
"Split",
inputs=[split_input],
outputs=[self.onnx_node.output[0], self.onnx_node.output[0]],
value=helper.make_tensor(
name="const_tensor", data_type=TensorProto.FLOAT, axis=0
),
inputs=[split_in.name],
outputs=[self.onnx_node.output[0], self.onnx_node.output[1]],
axis=0,
)
return shape_comp_node
def infer_node_datatype(self, model):
odt = self.get_output_datatype()
model.set_tensor_datatype(self.onnx_node.output[0], odt)
model.set_tensor_datatype(self.onnx_node.output[1], odt)
def verify_node(self):
info_messages = []
......
......@@ -33,6 +33,8 @@ from onnx import TensorProto, helper
import finn.core.onnx_exec as oxe
from finn.core.datatype import DataType
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.infer_datatypes import InferDataTypes
from finn.transformation.fpgadataflow.prepare_ip import PrepareIP
from finn.transformation.fpgadataflow.prepare_cppsim import PrepareCppSim
from finn.transformation.fpgadataflow.compile_cppsim import CompileCppSim
......@@ -72,6 +74,9 @@ def make_dupstreams_modelwrapper(ch, pe, idim, idt):
model.set_tensor_datatype("inp", idt)
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())
return model
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment