From 505a0a93e75a2d7e4cfc9290f857d4d50c081165 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <maltanar@gmail.com> Date: Thu, 26 Mar 2020 19:16:17 +0000 Subject: [PATCH] [StreamingFC] add shape and type inference to StreamingFC --- .../fpgadataflow/streamingfclayer_batch.py | 22 ++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py index 22421abc8..606c02778 100644 --- a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py +++ b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py @@ -32,7 +32,7 @@ from shutil import copy import numpy as np from pyverilator import PyVerilator - +from onnx import TensorProto, helper from finn.core.datatype import DataType from finn.custom_op.fpgadataflow import HLSCustomOp from finn.util.basic import interleave_matrix_outer_dim_from_partitions @@ -113,10 +113,26 @@ class StreamingFCLayer_Batch(HLSCustomOp): return mh // pe def make_shape_compatible_op(self, model): - pass + oshape = self.get_normal_output_shape() + # implement tensor with correct shape + values = np.random.randn(*oshape).astype(np.float32) + return helper.make_node( + "Constant", + inputs=[], + outputs=[self.onnx_node.output[0]], + value=helper.make_tensor( + name="const_tensor", + data_type=TensorProto.FLOAT, + dims=values.shape, + vals=values.flatten().astype(float), + ), + ) def infer_node_datatype(self, model): - pass + node = self.onnx_node + # data type stays the same + dtype = model.get_tensor_datatype(node.input[0]) + model.set_tensor_datatype(node.output[0], dtype) def verify_node(self): info_messages = [] -- GitLab