diff --git a/src/finn/transformation/batchnorm_to_affine.py b/src/finn/transformation/batchnorm_to_affine.py new file mode 100644 index 0000000000000000000000000000000000000000..0fe3e8a131fbade8f6058193f68fe8c28fe28eec --- /dev/null +++ b/src/finn/transformation/batchnorm_to_affine.py @@ -0,0 +1,75 @@ +import copy + +import numpy as np +import onnx.shape_inference as si +from onnx import TensorProto +from onnx import helper as oh + +import finn.transformation.general as tg + + +def batchnorm_to_affine(model): + """Replaces any test-time BatchNorm layers with Mul-Add layers.""" + new_model = copy.deepcopy(model) + new_model = si.infer_shapes(new_model) + graph = new_model.graph + nodes_to_remove = [] + node_ind = 0 + for n in graph.node: + node_ind += 1 + if n.op_type == "BatchNormalization": + bn_input = n.input[0] + bn_output = n.output[0] + # extract batchnorm parameters as numpy arrays + scale = tg.get_initializer(new_model, n.input[1]) + bias = tg.get_initializer(new_model, n.input[2]) + mean = tg.get_initializer(new_model, n.input[3]) + variance = tg.get_initializer(new_model, n.input[4]) + epsilon = 1e-5 + # find A and B to compute batchnorm as affine transpose Ax+B + # TODO is a division by moving avg factor needed for variance? + A = scale / np.sqrt(epsilon + variance) + B = bias - (A * mean) + nodes_to_remove += [n] + # see if we have surrounding Unsqueeze/Squeeze nodes we can remove + producer = tg.find_producer(new_model, bn_input) + if producer is not None: + if producer.op_type == "Unsqueeze": + bn_input = producer.input[0] + nodes_to_remove += [producer] + consumer = tg.find_consumer(new_model, bn_output) + if consumer is not None: + if consumer.op_type == "Squeeze": + bn_output = consumer.output[0] + nodes_to_remove += [consumer] + data_shape = tg.get_tensor_shape(new_model, bn_input) + # create value_info and initializers for Mul and Add constants + mul_const = oh.make_tensor_value_info( + tg.make_new_valueinfo_name(new_model), TensorProto.FLOAT, A.shape + ) + graph.value_info.append(mul_const) + tg.set_initializer(new_model, mul_const.name, A) + mul_output = oh.make_tensor_value_info( + tg.make_new_valueinfo_name(new_model), TensorProto.FLOAT, data_shape + ) + graph.value_info.append(mul_output) + add_const = oh.make_tensor_value_info( + tg.make_new_valueinfo_name(new_model), TensorProto.FLOAT, B.shape + ) + graph.value_info.append(add_const) + tg.set_initializer(new_model, add_const.name, B) + # create Mul and Add nodes to replace the batchnorm + mul_node = oh.make_node( + "Mul", [bn_input, mul_const.name], [mul_output.name] + ) + add_node = oh.make_node( + "Add", [mul_output.name, add_const.name], [bn_output] + ) + # insert where the batchnorm is to preserve topological ordering + graph.node.insert(node_ind, mul_node) + graph.node.insert(node_ind + 1, add_node) + # delete marked nodes (batchnorm and (un)squeezing) + for n in nodes_to_remove: + graph.node.remove(n) + new_model = si.infer_shapes(new_model) + return new_model diff --git a/src/finn/transformation/general.py b/src/finn/transformation/general.py index db01b00914db84cf31f5f6420c05e0032141669b..9086823113ec1cfe90431ce8ee904105f6698999 100644 --- a/src/finn/transformation/general.py +++ b/src/finn/transformation/general.py @@ -1,10 +1,6 @@ import copy -import numpy as np -import onnx.shape_inference as si -from onnx import TensorProto -from onnx import helper as oh -from onnx import numpy_helper as np_helper +import onnx.numpy_helper as np_helper def give_unique_names(model): @@ -93,70 +89,3 @@ def make_new_valueinfo_name(model): while candidate in names: candidate = str(int(candidate) + 1) return candidate - - -def replace_batchnorm_with_affine(model): - """Replaces any test-time BatchNorm layers with Mul-Add layers.""" - new_model = copy.deepcopy(model) - new_model = si.infer_shapes(new_model) - graph = new_model.graph - nodes_to_remove = [] - node_ind = 0 - for n in graph.node: - node_ind += 1 - if n.op_type == "BatchNormalization": - bn_input = n.input[0] - bn_output = n.output[0] - # extract batchnorm parameters as numpy arrays - scale = get_initializer(new_model, n.input[1]) - bias = get_initializer(new_model, n.input[2]) - mean = get_initializer(new_model, n.input[3]) - variance = get_initializer(new_model, n.input[4]) - epsilon = 1e-5 - # find A and B to compute batchnorm as affine transpose Ax+B - # TODO is a division by moving avg factor needed for variance? - A = scale / np.sqrt(epsilon + variance) - B = bias - (A * mean) - nodes_to_remove += [n] - # see if we have surrounding Unsqueeze/Squeeze nodes we can remove - producer = find_producer(new_model, bn_input) - if producer is not None: - if producer.op_type == "Unsqueeze": - bn_input = producer.input[0] - nodes_to_remove += [producer] - consumer = find_consumer(new_model, bn_output) - if consumer is not None: - if consumer.op_type == "Squeeze": - bn_output = consumer.output[0] - nodes_to_remove += [consumer] - data_shape = get_tensor_shape(new_model, bn_input) - # create value_info and initializers for Mul and Add constants - mul_const = oh.make_tensor_value_info( - make_new_valueinfo_name(new_model), TensorProto.FLOAT, A.shape - ) - graph.value_info.append(mul_const) - set_initializer(new_model, mul_const.name, A) - mul_output = oh.make_tensor_value_info( - make_new_valueinfo_name(new_model), TensorProto.FLOAT, data_shape - ) - graph.value_info.append(mul_output) - add_const = oh.make_tensor_value_info( - make_new_valueinfo_name(new_model), TensorProto.FLOAT, B.shape - ) - graph.value_info.append(add_const) - set_initializer(new_model, add_const.name, B) - # create Mul and Add nodes to replace the batchnorm - mul_node = oh.make_node( - "Mul", [bn_input, mul_const.name], [mul_output.name] - ) - add_node = oh.make_node( - "Add", [mul_output.name, add_const.name], [bn_output] - ) - # insert where the batchnorm is to preserve topological ordering - graph.node.insert(node_ind, mul_node) - graph.node.insert(node_ind + 1, add_node) - # delete marked nodes (batchnorm and (un)squeezing) - for n in nodes_to_remove: - graph.node.remove(n) - new_model = si.infer_shapes(new_model) - return new_model diff --git a/tests/test_batchnorm_to_affine.py b/tests/test_batchnorm_to_affine.py index babe1c2be373620fa3966b2ed923fe13bc2ecd89..79f6357bbc2127468b57e86b76f6eb28e089514e 100644 --- a/tests/test_batchnorm_to_affine.py +++ b/tests/test_batchnorm_to_affine.py @@ -13,7 +13,7 @@ from models.common import get_act_quant, get_quant_linear, get_quant_type, get_s from torch.nn import BatchNorm1d, Dropout, Module, ModuleList import finn.core.onnx_exec as oxe -import finn.transformation.general as tx +import finn.transformation.batchnorm_to_affine as tx FC_OUT_FEATURES = [1024, 1024, 1024] INTERMEDIATE_FC_PER_OUT_CH_SCALING = True @@ -94,7 +94,7 @@ def test_batchnorm_to_affine(): lfc.load_state_dict(checkpoint["state_dict"]) bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path) model = onnx.load(export_onnx_path) - new_model = tx.replace_batchnorm_with_affine(model) + new_model = tx.batchnorm_to_affine(model) try: os.remove("/tmp/" + mnist_onnx_filename) except OSError: