Skip to content
Snippets Groups Projects
Commit 7aed59b6 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[CustomOp] add shape and type inference to ConvInpGen

parent 346a5140
No related branches found
No related tags found
No related merge requests found
......@@ -33,6 +33,8 @@ from pyverilator import PyVerilator
from finn.core.datatype import DataType
from finn.custom_op.fpgadataflow import HLSCustomOp
from finn.custom_op.im2col import compute_conv_output_dim
from onnx import TensorProto, helper
# ONNX i/o tensor shape assumptions for ConvolutionInputGenerator:
# input 0 is the input tensor, shape NHWC = (1, IFMDim, IFMDim, IFMChannels)
......@@ -63,10 +65,34 @@ class ConvolutionInputGenerator(HLSCustomOp):
return my_attrs
def make_shape_compatible_op(self, model):
pass
k = self.get_nodeattr("ConvKernelDim")
ifm_dim = self.get_nodeattr("IFMDim")
ifm_ch = self.get_nodeattr("IFMChannels")
stride = self.get_nodeattr("Stride")
pad = 0
exp_ishape = (1, ifm_dim, ifm_dim, ifm_ch)
ishape = tuple(model.get_tensor_shape(self.onnx_node.input[0]))
assert ishape == exp_ishape, "Unexpect input shape for ConvInpGen."
ofm_dim = compute_conv_output_dim(ifm_dim, k, stride, pad)
# implement tensor with correct shape
values = np.random.randn(1, ofm_dim, ofm_dim, k * k * ifm_ch).astype(np.float32)
return helper.make_node(
"Constant",
inputs=[],
outputs=[self.onnx_node.output[0]],
value=helper.make_tensor(
name="const_tensor",
data_type=TensorProto.FLOAT,
dims=values.shape,
vals=values.flatten().astype(float),
),
)
def infer_node_datatype(self, model):
pass
node = self.onnx_node
# data type stays the same
dtype = model.get_tensor_datatype(node.input[0])
model.set_tensor_datatype(node.output[0], dtype)
def verify_node(self):
pass
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment