Skip to content
Snippets Groups Projects
Commit cc648647 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[SetFolding] handle ConvolutionInputGenerator1D as well

parent ae5cb636
No related branches found
No related tags found
No related merge requests found
...@@ -104,7 +104,12 @@ class SetFolding(Transformation): ...@@ -104,7 +104,12 @@ class SetFolding(Transformation):
] ]
# these ops use SIMD parallelism, up to a max value of NumChannels # these ops use SIMD parallelism, up to a max value of NumChannels
# ConvolutionInputGenerator has a special case when depthwise=1 # ConvolutionInputGenerator has a special case when depthwise=1
simd_ops = ["DownSampler", "FMPadding_Batch", "ConvolutionInputGenerator"] simd_ops = [
"DownSampler",
"FMPadding_Batch",
"ConvolutionInputGenerator",
"ConvolutionInputGenerator1D",
]
# these ops are preceded by depthwise SWG and have special behavior, # these ops are preceded by depthwise SWG and have special behavior,
# as explained in the SetFolding docstring # as explained in the SetFolding docstring
depthwise_op_exceptions = ["Vector_Vector_Activate_Batch", "Pool_Batch"] depthwise_op_exceptions = ["Vector_Vector_Activate_Batch", "Pool_Batch"]
...@@ -166,7 +171,10 @@ class SetFolding(Transformation): ...@@ -166,7 +171,10 @@ class SetFolding(Transformation):
"Expected SWU on DW op input, found " + swu_node.op_type "Expected SWU on DW op input, found " + swu_node.op_type
) )
elif op_type in simd_ops: elif op_type in simd_ops:
if op_type == "ConvolutionInputGenerator": if op_type in [
"ConvolutionInputGenerator",
"ConvolutionInputGenerator1D",
]:
depthwise = node_inst.get_nodeattr("depthwise") depthwise = node_inst.get_nodeattr("depthwise")
if depthwise == 0: if depthwise == 0:
max_simd = node_inst.get_nodeattr("IFMChannels") max_simd = node_inst.get_nodeattr("IFMChannels")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment