diff --git a/src/finn/custom_op/fpgadataflow/globalaccpool_batch.py b/src/finn/custom_op/fpgadataflow/globalaccpool_batch.py index 9e6c63dc510aab5f6baff9cb6326a2d0476f67a9..83152dea6cc494b8464c78605399b21b38d48b80 100644 --- a/src/finn/custom_op/fpgadataflow/globalaccpool_batch.py +++ b/src/finn/custom_op/fpgadataflow/globalaccpool_batch.py @@ -75,16 +75,19 @@ class GlobalAccPool_Batch(HLSCustomOp): def get_normal_output_shape(self): ch = self.get_nodeattr("NumChannels") vecs = list(self.get_nodeattr("numInputVectors")) - oshape = tuple([vecs[0]] + [ch]) + if len(vecs) == 1: + oshape = tuple(vecs + [ch]) + elif len(vecs) == 3: + oshape = tuple([vecs[0]] + [1, 1, ch]) return oshape def get_folded_output_shape(self): ch = self.get_nodeattr("NumChannels") pe = self.get_nodeattr("PE") - vecs = list(self.get_nodeattr("numInputVectors")) + unfolded_shape = list(self.get_normal_output_shape()) assert ch % pe == 0, "PE must divide NumChannels" folds = int(ch / pe) - oshape = tuple([vecs[0]] + [folds, pe]) + oshape = tuple(unfolded_shape[:-1] + [folds, pe]) return oshape def make_shape_compatible_op(self, model): diff --git a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py index 538af51de8174712099d73a39519baf9acaca1b4..afb37b1ec0e0fc19a170b19337bceafd34f11a0e 100644 --- a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py +++ b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py @@ -785,3 +785,88 @@ class InferChannelwiseLinearLayer(Transformation): model = model.transform(InferShapes()) model = model.transform(InferDataTypes()) return (model, graph_modified) + + +class InferGlobalAccPoolLayer(Transformation): + """Convert any GlobalAveragePool into a GlobalAccPool HLS layer and a scalar Mul.""" + + def apply(self, model): + graph = model.graph + node_ind = 0 + graph_modified = False + for node in graph.node: + node_ind += 1 + if node.op_type == "GlobalAveragePool": + in0 = node.input[0] + result = node.output[0] + in0_shape = model.get_tensor_shape(in0) + + idt = model.get_tensor_datatype(in0) + + # skip conversion for layers with float input + if not idt.is_integer(): + continue + + # check layout and convert if necessary + in0_layout = model.get_tensor_layout(in0) + result_layout = model.get_tensor_layout(result) + + if in0_layout == DataLayout.NCHW: + in0 = nchw_to_nhwc(in0, model, node_ind) + node_ind += 1 + in0_shape = model.get_tensor_shape(in0) + + # keep track of where we need to insert the HLS Op + # it has to be ahead of the output transform + insert_point = node_ind + + if result_layout == DataLayout.NCHW: + result = nchw_to_nhwc(result, model, node_ind, reverse=True) + node_ind += 1 + + num_ch = int(in0_shape[-1]) + vecs = in0_shape[:-1] + # create node with no parallelization first + pe = 1 + assert ( + num_ch % pe == 0 + ), "Requirement Labels divisable by PE is violated." + + # create an additional tensor of the same shape and layout as result + out_shape = model.get_tensor_shape(result) + pool_out = helper.make_tensor_value_info( + model.make_new_valueinfo_name(), TensorProto.FLOAT, out_shape + ) + model.graph.value_info.append(pool_out) + pool_out = pool_out.name + model.set_tensor_layout(pool_out, model.get_tensor_layout(result)) + + new_pool = helper.make_node( + "GlobalAccPool_Batch", + [in0], + [pool_out], + domain="finn", + backend="fpgadataflow", + NumChannels=num_ch, + PE=pe, + inputDataType=idt.name, + numInputVectors=vecs, + ) + + mul_value = helper.make_tensor_value_info( + model.make_new_valueinfo_name(), TensorProto.FLOAT, [1] + ) + model.graph.value_info.append(mul_value) + model.set_initializer(mul_value.name, np.array(1 / (vecs[1] * vecs[2]))) + new_mul = helper.make_node("Mul", [pool_out, mul_value.name], [result],) + graph.node.insert(insert_point, new_pool) + graph.node.insert(insert_point + 1, new_mul) + node_ind += 1 + # remove old node + graph.node.remove(node) + graph_modified = True + + if graph_modified: + model = model.transform(InferShapes()) + model = model.transform(InferDataTypes()) + return (model, graph_modified)