Skip to content
Snippets Groups Projects
Commit 05d0ccb1 authored by Felix Jentzsch's avatar Felix Jentzsch
Browse files

Fix merge

parent fd5a21a0
No related branches found
No related tags found
No related merge requests found
......@@ -51,7 +51,6 @@ class InferConvInpGen(Transformation):
def __init__(self, use_rtl_variant=False):
super().__init__()
self.use_rtl_variant = use_rtl_variant
self.use_rtl_variant = True #testing
def apply(self, model):
graph = model.graph
......@@ -225,15 +224,15 @@ class InferConvInpGen(Transformation):
depthwise=depthwise,
name="ConvolutionInputGenerator_" + n.name,
)
else: # non-square images and/or kernels
else: # 1D images and/or kernels
assert is_1d_convolution, (
"%s: ConvolutionInputGenerator1D works only for 1D convs"
% n.name
)
if dilation_h > 1 or dilation_w > 1:
assert stride_h == 1 and stride_w == 1, (
"""%s: Stride value of greater than 1 is not supported for convolutions
with dilation value greater than 1"""
assert depthwise == 1, (
"""%s: Dilation value > 1 is only supported for
1D depthwise separable convolutions"""
% n.name
)
ConvInpGen_node = helper.make_node(
......@@ -1689,4 +1688,4 @@ class InferConcatLayer(Transformation):
if graph_modified:
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())
return (model, graph_modified)
return (model, graph_modified)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment