diff --git a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py index 3ff86cab48d365c10e69bc2c764e8083c6a36880..6f45d498cc7d6fd663cd6b507cfd780b69d51dd8 100644 --- a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py +++ b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py @@ -26,7 +26,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from onnx import helper +from onnx import helper, TensorProto from finn.core.datatype import DataType from finn.transformation import Transformation @@ -59,27 +59,62 @@ class InferConvInpGen(Transformation): ifm_ch = i2c_in_shape[-1] ifm_dim = i2c_in_shape[1] ofm_dim = i2c_out_shape[1] - # if padding enabled, ensure pad_val supported by DataType + + # default params for ConvolutionInputGenerator + ConvInpGen_node_idx = node_ind + ConvInpGen_input = i2c_input + ConvInpGen_idim = ifm_dim + if pad > 0: + # if padding enabled, ensure pad_val supported by DataType assert dt.allowed(pad_val), "Im2Col DataType must support pad_val" + + odim_padding = ifm_dim + 2 * pad + + padding_out = helper.make_tensor_value_info( + model.make_new_valueinfo_name(), + TensorProto.FLOAT, + (1, odim_padding, odim_padding, ifm_ch), + ) + graph.value_info.append(padding_out) + padding_out = padding_out.name + model.set_tensor_datatype(padding_out, dt) + + ConvInpGen_node_idx += 1 + ConvInpGen_input = padding_out + ConvInpGen_idim = odim_padding + + padding_node = helper.make_node( + "FMPadding_Batch", + [i2c_input], + [padding_out], + domain="finn", + backend="fpgadataflow", + ImgDim=ifm_dim, + OutputDim=odim_padding, + Padding=2 * pad, + NumChannels=ifm_ch, + inputDataType=dt.name, + ) + graph.node.insert(node_ind, padding_node) + # create equivalent ConvolutionInputGenerator node - # TODO support padding - new_node = helper.make_node( + ConvInpGen_node = helper.make_node( "ConvolutionInputGenerator", - [i2c_input], + [ConvInpGen_input], [i2c_output], domain="finn", backend="fpgadataflow", ConvKernelDim=k, IFMChannels=ifm_ch, - IFMDim=ifm_dim, + IFMDim=ConvInpGen_idim, OFMDim=ofm_dim, SIMD=ifm_ch, Stride=stride, inputDataType=dt.name, outputDataType=dt.name, ) - graph.node.insert(node_ind, new_node) + graph.node.insert(ConvInpGen_node_idx, ConvInpGen_node) # remove old nodes graph.node.remove(n) graph_modified = True