diff --git a/src/finn/custom_op/quantavgpool2d.py b/src/finn/custom_op/quantavgpool2d.py index 3bc328a9f4f6670041d33491d58af6c553bafac9..5a81d8081d13853cb2b660ead980a7de89821d89 100644 --- a/src/finn/custom_op/quantavgpool2d.py +++ b/src/finn/custom_op/quantavgpool2d.py @@ -4,6 +4,7 @@ import onnxruntime as rt from finn.custom_op import CustomOp from finn.core.datatype import DataType +from finn.custom_op.maxpoolnhwc import compute_pool_output_dim class QuantAvgPool2d(CustomOp): @@ -17,18 +18,44 @@ class QuantAvgPool2d(CustomOp): "ibits": ("i", True, 1), "obits": ("i", True, 1), "signed": ("i", True, 0), + "data_layout": ("s", False, "NCHW"), } def make_shape_compatible_op(self, model): node = self.onnx_node + iname = node.input[0] + ishape = model.get_tensor_shape(iname) k = self.get_nodeattr("kernel") s = self.get_nodeattr("stride") + data_layout = self.get_nodeattr("data_layout") + if data_layout == "NCHW": + (n, c, hi, wi) = ishape + ho = compute_pool_output_dim(hi, k, s) + wo = compute_pool_output_dim(wi, k, s) + oshape = (n, c, ho, wo) + elif data_layout == "NHWC": + (n, hi, wi, c) = ishape + ho = compute_pool_output_dim(hi, k, s) + wo = compute_pool_output_dim(wi, k, s) + oshape = (n, ho, wo, c) + else: + raise Exception( + """Datalayout for QuantAvgPool2d is set to an unvalid value. + Has to be set to "NCHW" or "NHWC".""" + ) + + # implement tensor with correct shape + values = np.random.randn(*oshape).astype(np.float32) return helper.make_node( - "AveragePool", - inputs=[node.input[0]], + "Constant", + inputs=[], outputs=[node.output[0]], - kernel_shape=[k, k], - strides=[s, s], + value=helper.make_tensor( + name="const_tensor", + data_type=TensorProto.FLOAT, + dims=values.shape, + vals=values.flatten().astype(float), + ), ) def infer_node_datatype(self, model): @@ -48,8 +75,12 @@ class QuantAvgPool2d(CustomOp): node = self.onnx_node k = self.get_nodeattr("kernel") s = self.get_nodeattr("stride") - ishape = context[node.input[0]].shape + inp_values = context[node.input[0]] oshape = context[node.output[0]].shape + if self.get_nodeattr("data_layout") == "NHWC": + inp_values = inp_values.transpose(0, 3, 1, 2) + oshape = (context[node.output[0]]).transpose(0, 3, 1, 2).shape + ishape = inp_values.shape inp = helper.make_tensor_value_info(node.input[0], TensorProto.FLOAT, ishape) outp = helper.make_tensor_value_info(node.output[0], TensorProto.FLOAT, oshape) node_avgpool = helper.make_node( @@ -66,7 +97,7 @@ class QuantAvgPool2d(CustomOp): outputs=[outp], ) model_avgpool = helper.make_model(graph_avgpool) - idict = {node.input[0]: context[node.input[0]]} + idict = {node.input[0]: inp_values} sess = rt.InferenceSession(model_avgpool.SerializeToString()) result_temp = sess.run(None, idict) # remove scaling introduced by average @@ -77,7 +108,16 @@ class QuantAvgPool2d(CustomOp): max_bit_width = int(max_value).bit_length() shift_bits = max_bit_width - self.get_nodeattr("obits") result = np.right_shift(result_temp.astype(int), shift_bits) + if self.get_nodeattr("data_layout") == "NHWC": + result = result.transpose(0, 2, 3, 1) context[node.output[0]] = result.astype(np.float32) def verify_node(self): - pass + info_messages = [] + # verify that "domain" is set to "finn" + domain_value = self.onnx_node.domain + if domain_value == "finn": + info_messages.append("Attribute domain is set correctly") + else: + info_messages.append('Attribute domain should be set to "finn"') + return info_messages