Skip to content
Snippets Groups Projects
Commit cc648647 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[SetFolding] handle ConvolutionInputGenerator1D as well

parent ae5cb636
No related branches found
No related tags found
No related merge requests found
......@@ -104,7 +104,12 @@ class SetFolding(Transformation):
]
# these ops use SIMD parallelism, up to a max value of NumChannels
# ConvolutionInputGenerator has a special case when depthwise=1
simd_ops = ["DownSampler", "FMPadding_Batch", "ConvolutionInputGenerator"]
simd_ops = [
"DownSampler",
"FMPadding_Batch",
"ConvolutionInputGenerator",
"ConvolutionInputGenerator1D",
]
# these ops are preceded by depthwise SWG and have special behavior,
# as explained in the SetFolding docstring
depthwise_op_exceptions = ["Vector_Vector_Activate_Batch", "Pool_Batch"]
......@@ -166,7 +171,10 @@ class SetFolding(Transformation):
"Expected SWU on DW op input, found " + swu_node.op_type
)
elif op_type in simd_ops:
if op_type == "ConvolutionInputGenerator":
if op_type in [
"ConvolutionInputGenerator",
"ConvolutionInputGenerator1D",
]:
depthwise = node_inst.get_nodeattr("depthwise")
if depthwise == 0:
max_simd = node_inst.get_nodeattr("IFMChannels")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment