Skip to content
Snippets Groups Projects
Commit 5364fb10 authored by Tobi-Alonso's avatar Tobi-Alonso
Browse files

Merge branch 'feature/fix_duplicate_batch_infer_shape' into feature/labelSelect_addStream_infer

parents 92ddad90 1ee4b944
No related branches found
No related tags found
No related merge requests found
......@@ -32,7 +32,7 @@ import numpy as np
from finn.core.datatype import DataType
from finn.custom_op.fpgadataflow import HLSCustomOp
from onnx import TensorProto, helper
from onnx import helper
from finn.util.data_packing import npy_to_rtlsim_input, rtlsim_output_to_npy
......@@ -80,24 +80,19 @@ class DuplicateStreams_Batch(HLSCustomOp):
def make_shape_compatible_op(self, model):
exp_ishape = self.get_normal_input_shape()
oshape = self.get_normal_output_shape()
ishape = tuple(model.get_tensor_shape(self.onnx_node.input[0]))
assert ishape == exp_ishape, "Unexpected input shape."
# implement tensor with correct shape
values = np.random.randn(*oshape).astype(np.float32)
split_input = np.concatenate((values, values), axis=0)
return helper.make_node(
"Split",
inputs=[split_input],
outputs=[self.onnx_node.output[0], self.onnx_node.output[0]],
value=helper.make_tensor(
name="const_tensor", data_type=TensorProto.FLOAT, axis=0
),
"Dropout",
inputs=[self.onnx_node.input[0]],
outputs=[self.onnx_node.output[0], self.onnx_node.output[1]],
axis=0,
)
def infer_node_datatype(self, model):
odt = self.get_output_datatype()
model.set_tensor_datatype(self.onnx_node.output[0], odt)
model.set_tensor_datatype(self.onnx_node.output[1], odt)
def verify_node(self):
info_messages = []
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment