Skip to content
Snippets Groups Projects
Commit 68c957bf authored by Tobi-Alonso's avatar Tobi-Alonso
Browse files

[FPGADataFlow] InferConvInpGen inserts padding node when required

parent 054d4dd6
No related branches found
No related tags found
No related merge requests found
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from onnx import helper from onnx import helper, TensorProto
from finn.core.datatype import DataType from finn.core.datatype import DataType
from finn.transformation import Transformation from finn.transformation import Transformation
...@@ -59,27 +59,62 @@ class InferConvInpGen(Transformation): ...@@ -59,27 +59,62 @@ class InferConvInpGen(Transformation):
ifm_ch = i2c_in_shape[-1] ifm_ch = i2c_in_shape[-1]
ifm_dim = i2c_in_shape[1] ifm_dim = i2c_in_shape[1]
ofm_dim = i2c_out_shape[1] ofm_dim = i2c_out_shape[1]
# if padding enabled, ensure pad_val supported by DataType
# default params for ConvolutionInputGenerator
ConvInpGen_node_idx = node_ind
ConvInpGen_input = i2c_input
ConvInpGen_idim = ifm_dim
if pad > 0: if pad > 0:
# if padding enabled, ensure pad_val supported by DataType
assert dt.allowed(pad_val), "Im2Col DataType must support pad_val" assert dt.allowed(pad_val), "Im2Col DataType must support pad_val"
odim_padding = ifm_dim + 2 * pad
padding_out = helper.make_tensor_value_info(
model.make_new_valueinfo_name(),
TensorProto.FLOAT,
(1, odim_padding, odim_padding, ifm_ch),
)
graph.value_info.append(padding_out)
padding_out = padding_out.name
model.set_tensor_datatype(padding_out, dt)
ConvInpGen_node_idx += 1
ConvInpGen_input = padding_out
ConvInpGen_idim = odim_padding
padding_node = helper.make_node(
"FMPadding_Batch",
[i2c_input],
[padding_out],
domain="finn",
backend="fpgadataflow",
ImgDim=ifm_dim,
OutputDim=odim_padding,
Padding=2 * pad,
NumChannels=ifm_ch,
inputDataType=dt.name,
)
graph.node.insert(node_ind, padding_node)
# create equivalent ConvolutionInputGenerator node # create equivalent ConvolutionInputGenerator node
# TODO support padding ConvInpGen_node = helper.make_node(
new_node = helper.make_node(
"ConvolutionInputGenerator", "ConvolutionInputGenerator",
[i2c_input], [ConvInpGen_input],
[i2c_output], [i2c_output],
domain="finn", domain="finn",
backend="fpgadataflow", backend="fpgadataflow",
ConvKernelDim=k, ConvKernelDim=k,
IFMChannels=ifm_ch, IFMChannels=ifm_ch,
IFMDim=ifm_dim, IFMDim=ConvInpGen_idim,
OFMDim=ofm_dim, OFMDim=ofm_dim,
SIMD=ifm_ch, SIMD=ifm_ch,
Stride=stride, Stride=stride,
inputDataType=dt.name, inputDataType=dt.name,
outputDataType=dt.name, outputDataType=dt.name,
) )
graph.node.insert(node_ind, new_node) graph.node.insert(ConvInpGen_node_idx, ConvInpGen_node)
# remove old nodes # remove old nodes
graph.node.remove(n) graph.node.remove(n)
graph_modified = True graph_modified = True
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment