From b2eec8207d899dd95ab759412007a7798dcd1d96 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <yamanu@xilinx.com> Date: Tue, 22 Oct 2019 14:03:41 +0100 Subject: [PATCH] [Transform] use new interface for batchnorm_to_affine --- .../transformation/batchnorm_to_affine.py | 38 +++++++++---------- tests/test_batchnorm_to_affine.py | 11 +++--- 2 files changed, 24 insertions(+), 25 deletions(-) diff --git a/src/finn/transformation/batchnorm_to_affine.py b/src/finn/transformation/batchnorm_to_affine.py index dfdad3ba2..4f3984cd7 100644 --- a/src/finn/transformation/batchnorm_to_affine.py +++ b/src/finn/transformation/batchnorm_to_affine.py @@ -1,29 +1,26 @@ -import copy - import numpy as np import onnx.shape_inference as si from onnx import TensorProto from onnx import helper as oh -import finn.transformation.general as tg - def batchnorm_to_affine(model): """Replaces any test-time BatchNorm layers with Mul-Add layers.""" - new_model = copy.deepcopy(model) - graph = new_model.graph + graph = model.graph nodes_to_remove = [] node_ind = 0 + graph_modified = False for n in graph.node: node_ind += 1 if n.op_type == "BatchNormalization": + graph_modified = True bn_input = n.input[0] bn_output = n.output[0] # extract batchnorm parameters as numpy arrays - scale = tg.get_initializer(new_model, n.input[1]) - bias = tg.get_initializer(new_model, n.input[2]) - mean = tg.get_initializer(new_model, n.input[3]) - variance = tg.get_initializer(new_model, n.input[4]) + scale = model.get_initializer(n.input[1]) + bias = model.get_initializer(n.input[2]) + mean = model.get_initializer(n.input[3]) + variance = model.get_initializer(n.input[4]) epsilon = 1e-5 # find A and B to compute batchnorm as affine transpose Ax+B # TODO is a division by moving avg factor needed for variance? @@ -31,32 +28,32 @@ def batchnorm_to_affine(model): B = bias - (A * mean) nodes_to_remove += [n] # see if we have surrounding Unsqueeze/Squeeze nodes we can remove - producer = tg.find_producer(new_model, bn_input) + producer = model.find_producer(bn_input) if producer is not None: if producer.op_type == "Unsqueeze": bn_input = producer.input[0] nodes_to_remove += [producer] - consumer = tg.find_consumer(new_model, bn_output) + consumer = model.find_consumer(bn_output) if consumer is not None: if consumer.op_type == "Squeeze": bn_output = consumer.output[0] nodes_to_remove += [consumer] - data_shape = tg.get_tensor_shape(new_model, bn_input) + data_shape = model.get_tensor_shape(bn_input) # create value_info and initializers for Mul and Add constants mul_const = oh.make_tensor_value_info( - tg.make_new_valueinfo_name(new_model), TensorProto.FLOAT, A.shape + model.make_new_valueinfo_name(), TensorProto.FLOAT, A.shape ) graph.value_info.append(mul_const) - tg.set_initializer(new_model, mul_const.name, A) + model.set_initializer(mul_const.name, A) mul_output = oh.make_tensor_value_info( - tg.make_new_valueinfo_name(new_model), TensorProto.FLOAT, data_shape + model.make_new_valueinfo_name(), TensorProto.FLOAT, data_shape ) graph.value_info.append(mul_output) add_const = oh.make_tensor_value_info( - tg.make_new_valueinfo_name(new_model), TensorProto.FLOAT, B.shape + model.make_new_valueinfo_name(), TensorProto.FLOAT, B.shape ) graph.value_info.append(add_const) - tg.set_initializer(new_model, add_const.name, B) + model.set_initializer(add_const.name, B) # create Mul and Add nodes to replace the batchnorm mul_node = oh.make_node( "Mul", [bn_input, mul_const.name], [mul_output.name] @@ -70,5 +67,6 @@ def batchnorm_to_affine(model): # delete marked nodes (batchnorm and (un)squeezing) for n in nodes_to_remove: graph.node.remove(n) - new_model = si.infer_shapes(new_model) - return new_model + graph_modified = True + model.model = si.infer_shapes(model.model) + return (model, graph_modified) diff --git a/tests/test_batchnorm_to_affine.py b/tests/test_batchnorm_to_affine.py index 34abb9a19..7c7efc494 100644 --- a/tests/test_batchnorm_to_affine.py +++ b/tests/test_batchnorm_to_affine.py @@ -15,6 +15,7 @@ from torch.nn import BatchNorm1d, Dropout, Module, ModuleList import finn.core.onnx_exec as oxe import finn.transformation.batchnorm_to_affine as tx +from finn.core.modelwrapper import ModelWrapper FC_OUT_FEATURES = [1024, 1024, 1024] INTERMEDIATE_FC_PER_OUT_CH_SCALING = True @@ -94,9 +95,9 @@ def test_batchnorm_to_affine(): checkpoint = torch.load(trained_lfc_checkpoint, map_location="cpu") lfc.load_state_dict(checkpoint["state_dict"]) bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path) - model = onnx.load(export_onnx_path) - model = si.infer_shapes(model) - new_model = tx.batchnorm_to_affine(model) + model = ModelWrapper(export_onnx_path) + model.model = si.infer_shapes(model.model) + new_model = model.transform_single(tx.batchnorm_to_affine) try: os.remove("/tmp/" + mnist_onnx_filename) except OSError: @@ -108,8 +109,8 @@ def test_batchnorm_to_affine(): with open(mnist_onnx_local_dir + "/mnist/test_data_set_0/input_0.pb", "rb") as f: input_tensor.ParseFromString(f.read()) input_dict = {"0": nph.to_array(input_tensor)} - output_original = oxe.execute_onnx(model, input_dict)["53"] - output_transformed = oxe.execute_onnx(new_model, input_dict)["53"] + output_original = oxe.execute_onnx(model.model, input_dict)["53"] + output_transformed = oxe.execute_onnx(new_model.model, input_dict)["53"] assert np.isclose(output_transformed, output_original, atol=1e-3).all() # remove the downloaded model and extracted files os.remove(dl_ret) -- GitLab