Skip to content
Snippets Groups Projects
Commit ec4dfbe7 authored by Tobi-Alonso's avatar Tobi-Alonso
Browse files

Merge branch 'feature/fix_duplicate_batch_infer_shape' into feature/labelSelect_addStream_infer

parents ec7f824c 3bd3b1ce
No related branches found
No related tags found
No related merge requests found
......@@ -32,7 +32,7 @@ import numpy as np
from finn.core.datatype import DataType
from finn.custom_op.fpgadataflow import HLSCustomOp
from onnx import helper
from onnx import helper, TensorProto
from finn.util.data_packing import npy_to_rtlsim_input, rtlsim_output_to_npy
......@@ -82,13 +82,27 @@ class DuplicateStreams_Batch(HLSCustomOp):
exp_ishape = self.get_normal_input_shape()
ishape = tuple(model.get_tensor_shape(self.onnx_node.input[0]))
assert ishape == exp_ishape, "Unexpected input shape."
return helper.make_node(
"Dropout",
inputs=[self.onnx_node.input[0]],
oshape = self.get_normal_output_shape()
values = np.zeros(oshape).astype(np.float32)
split_input = np.concatenate((values, values), axis=0)
split_in = helper.make_tensor_value_info(
model.make_new_valueinfo_name(), TensorProto.FLOAT, oshape
)
model.graph.value_info.append(split_in) # requires clean up
model.set_initializer(split_in.name, split_input)
shape_comp_node = helper.make_node(
"Split",
inputs=[split_in.name],
outputs=[self.onnx_node.output[0], self.onnx_node.output[1]],
axis=0,
)
return shape_comp_node
def infer_node_datatype(self, model):
odt = self.get_output_datatype()
model.set_tensor_datatype(self.onnx_node.output[0], odt)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment