Skip to content
Snippets Groups Projects
Commit 82a1a1e0 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Conv] add assertions for square k/img/stride and dilation of 1

parent 862e8587
No related branches found
No related tags found
No related merge requests found
...@@ -67,7 +67,8 @@ class ConvolutionInputGenerator(HLSCustomOp): ...@@ -67,7 +67,8 @@ class ConvolutionInputGenerator(HLSCustomOp):
"OFMDim": ("ints", True, []), # [H, W] = [Y, X] "OFMDim": ("ints", True, []), # [H, W] = [Y, X]
"SIMD": ("i", True, 0), "SIMD": ("i", True, 0),
"Stride": ("ints", True, [1, 1]), # [H, W] = [Y, X] "Stride": ("ints", True, [1, 1]), # [H, W] = [Y, X]
"Dilation": ("ints", True, [1, 1]), # [H, W] = [Y, X] # note: only dilation=1 supported for now
"Dilation": ("ints", True, [1, 1], {[1, 1]}), # [H, W] = [Y, X]
# FINN DataTypes for inputs, weights, outputs # FINN DataTypes for inputs, weights, outputs
"inputDataType": ("s", True, ""), "inputDataType": ("s", True, ""),
"outputDataType": ("s", True, ""), "outputDataType": ("s", True, ""),
...@@ -87,6 +88,17 @@ class ConvolutionInputGenerator(HLSCustomOp): ...@@ -87,6 +88,17 @@ class ConvolutionInputGenerator(HLSCustomOp):
my_attrs.update(super().get_nodeattr_types()) my_attrs.update(super().get_nodeattr_types())
return my_attrs return my_attrs
def get_nodeattr(self, name):
# overriding get_nodeattr to check for square kernel/img.. requirement
# since this can't be done with the attribute restriction in nodeattr_types
# TODO non-square can be enabled in theory but needs testing
ret = super().get_nodeattr(name)
props_to_check = ["ConvKernelDim", "IFMDim", "OFMDim", "Stride", "Dilation"]
if name in props_to_check:
is_square = ret[0] == ret[1]
assert is_square, "Only square %s supported" % name
return ret
def get_normal_input_shape(self): def get_normal_input_shape(self):
ifm_dim_h, ifm_dim_w = self.get_nodeattr("IFMDim") ifm_dim_h, ifm_dim_w = self.get_nodeattr("IFMDim")
ifm_ch = self.get_nodeattr("IFMChannels") ifm_ch = self.get_nodeattr("IFMChannels")
...@@ -384,14 +396,7 @@ class ConvolutionInputGenerator(HLSCustomOp): ...@@ -384,14 +396,7 @@ class ConvolutionInputGenerator(HLSCustomOp):
#define Input_precision1 {}\n #define IFMDim1 {}\n #define Input_precision1 {}\n #define IFMDim1 {}\n
#define OFMDim1 {}\n #define SIMD1 {}\n #define OFMDim1 {}\n #define SIMD1 {}\n
#define Stride1 {}\n #define numReps {}""".format( #define Stride1 {}\n #define numReps {}""".format(
k, k, ifm_ch, ifm_precision, ifm_dim, ofm_dim, simd, stride, numReps
ifm_ch,
ifm_precision,
ifm_dim,
ofm_dim,
simd,
stride,
numReps,
) )
] ]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment