diff --git a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py index 652136c82303d8b9ca6772f228312fe96efd33ea..46a97f6fab25029727fd7c106d978ff60130dab6 100644 --- a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py +++ b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py @@ -349,15 +349,15 @@ class InferStreamingMaxPool(Transformation): graph = model.graph node_ind = 0 graph_modified = False - for n in graph.node: + for node in graph.node: node_ind += 1 - if n.op_type == "MaxPoolNHWC": - mp_input = n.input[0] - mp_output = n.output[0] + if node.op_type == "MaxPoolNHWC": + mp_input = node.input[0] + mp_output = node.output[0] mp_in_shape = model.get_tensor_shape(mp_input) # mp_out_shape = model.get_tensor_shape(mp_output) dt = model.get_tensor_datatype(mp_input) - mp_inst = getCustomOp(n) + mp_inst = getCustomOp(node) k_h, k_w = mp_inst.get_nodeattr("kernel_shape") ifm_ch = mp_in_shape[-1] ifm_dim_h = mp_in_shape[1] @@ -365,8 +365,11 @@ class InferStreamingMaxPool(Transformation): pe = 1 ceil_mode = mp_inst.get_nodeattr("ceil_mode") is_1d = (ifm_dim_h == 1 and k_h == 1) or (ifm_dim_w == 1 and k_w == 1) - is_divisable = ifm_dim_h % k_h == 0 or ifm_dim_w % k_w == 0 - if is_1d or is_divisable: + is_divisable = (ifm_dim_h % k_h == 0) or (ifm_dim_w % k_w == 0) + is_bipolar = dt == DataType["BIPOLAR"] + pass_1d = is_1d and (not is_bipolar) + pass_2d = (not is_1d) and is_divisable + if pass_1d or pass_2d: # create equivalent StreamingMaxPool_Batch node new_node = helper.make_node( "StreamingMaxPool_Batch", @@ -380,12 +383,14 @@ class InferStreamingMaxPool(Transformation): dataType=dt.name, PE=pe, CeilMode=ceil_mode, - name="StreamingMaxPool_Batch_" + n.name, + name="StreamingMaxPool_Batch_" + node.name, ) graph.node.insert(node_ind, new_node) # remove old nodes - graph.node.remove(n) + graph.node.remove(node) graph_modified = True + else: + warnings.warn(node.name + ": could not convert to HLS") if graph_modified: model = model.transform(InferShapes()) model = model.transform(InferDataTypes())