From 1ee4b9446c4c782ff200db75d14c1260b8ab3fee Mon Sep 17 00:00:00 2001
From: Tobi-Alonso <tobi.alonso@gmail.com>
Date: Tue, 7 Jul 2020 17:14:12 +0100
Subject: [PATCH] [HLSCustomOp] Fix inferShapes for InferDuplicateStreamsLayer

---
 .../fpgadataflow/duplicatestreams_batch.py      | 17 ++++++-----------
 1 file changed, 6 insertions(+), 11 deletions(-)

diff --git a/src/finn/custom_op/fpgadataflow/duplicatestreams_batch.py b/src/finn/custom_op/fpgadataflow/duplicatestreams_batch.py
index 54051af5e..8143b9c55 100644
--- a/src/finn/custom_op/fpgadataflow/duplicatestreams_batch.py
+++ b/src/finn/custom_op/fpgadataflow/duplicatestreams_batch.py
@@ -32,7 +32,7 @@ import numpy as np
 
 from finn.core.datatype import DataType
 from finn.custom_op.fpgadataflow import HLSCustomOp
-from onnx import TensorProto, helper
+from onnx import helper
 from finn.util.data_packing import npy_to_rtlsim_input, rtlsim_output_to_npy
 
 
@@ -80,24 +80,19 @@ class DuplicateStreams_Batch(HLSCustomOp):
 
     def make_shape_compatible_op(self, model):
         exp_ishape = self.get_normal_input_shape()
-        oshape = self.get_normal_output_shape()
         ishape = tuple(model.get_tensor_shape(self.onnx_node.input[0]))
         assert ishape == exp_ishape, "Unexpected input shape."
-        # implement tensor with correct shape
-        values = np.random.randn(*oshape).astype(np.float32)
-        split_input = np.concatenate((values, values), axis=0)
         return helper.make_node(
-            "Split",
-            inputs=[split_input],
-            outputs=[self.onnx_node.output[0], self.onnx_node.output[0]],
-            value=helper.make_tensor(
-                name="const_tensor", data_type=TensorProto.FLOAT, axis=0
-            ),
+            "Dropout",
+            inputs=[self.onnx_node.input[0]],
+            outputs=[self.onnx_node.output[0], self.onnx_node.output[1]],
+            axis=0,
         )
 
     def infer_node_datatype(self, model):
         odt = self.get_output_datatype()
         model.set_tensor_datatype(self.onnx_node.output[0], odt)
+        model.set_tensor_datatype(self.onnx_node.output[1], odt)
 
     def verify_node(self):
         info_messages = []
-- 
GitLab