From 505a0a93e75a2d7e4cfc9290f857d4d50c081165 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Thu, 26 Mar 2020 19:16:17 +0000
Subject: [PATCH] [StreamingFC] add shape and type inference to StreamingFC

---
 .../fpgadataflow/streamingfclayer_batch.py    | 22 ++++++++++++++++---
 1 file changed, 19 insertions(+), 3 deletions(-)

diff --git a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
index 22421abc8..606c02778 100644
--- a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
+++ b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
@@ -32,7 +32,7 @@ from shutil import copy
 
 import numpy as np
 from pyverilator import PyVerilator
-
+from onnx import TensorProto, helper
 from finn.core.datatype import DataType
 from finn.custom_op.fpgadataflow import HLSCustomOp
 from finn.util.basic import interleave_matrix_outer_dim_from_partitions
@@ -113,10 +113,26 @@ class StreamingFCLayer_Batch(HLSCustomOp):
             return mh // pe
 
     def make_shape_compatible_op(self, model):
-        pass
+        oshape = self.get_normal_output_shape()
+        # implement tensor with correct shape
+        values = np.random.randn(*oshape).astype(np.float32)
+        return helper.make_node(
+            "Constant",
+            inputs=[],
+            outputs=[self.onnx_node.output[0]],
+            value=helper.make_tensor(
+                name="const_tensor",
+                data_type=TensorProto.FLOAT,
+                dims=values.shape,
+                vals=values.flatten().astype(float),
+            ),
+        )
 
     def infer_node_datatype(self, model):
-        pass
+        node = self.onnx_node
+        # data type stays the same
+        dtype = model.get_tensor_datatype(node.input[0])
+        model.set_tensor_datatype(node.output[0], dtype)
 
     def verify_node(self):
         info_messages = []
-- 
GitLab