diff --git a/src/finn/transformation/general.py b/src/finn/transformation/general.py index 412fddaed15721b36e6fbc0e778f394c6343255d..12fd14646728d3d629d7dffd87b23f4d4dd256e6 100644 --- a/src/finn/transformation/general.py +++ b/src/finn/transformation/general.py @@ -1,5 +1,8 @@ import copy +import numpy as np +from onnx import TensorProto +from onnx import helper as oh from onnx import numpy_helper as np_helper @@ -77,3 +80,67 @@ def make_new_valueinfo_name(model): while candidate in names: candidate = str(int(candidate) + 1) return candidate + + +def replace_batchnorm_with_affine(model): + """Replaces any test-time BatchNorm layers with Mul-Add layers.""" + new_model = copy.deepcopy(model) + graph = new_model.graph + nodes_to_remove = [] + for n in graph.node: + if n.op_type == "BatchNormalization": + bn_input = n.input[0] + bn_output = n.output[0] + # extract batchnorm parameters as numpy arrays + scale = get_initializer(n.input[1]) + bias = get_initializer(n.input[2]) + mean = get_initializer(n.input[3]) + variance = get_initializer(n.input[4]) + epsilon = 1e-5 + # find A and B to compute batchnorm as affine transpose Ax+B + # TODO is a division by moving avg factor needed for variance? + A = scale / np.sqrt(epsilon + variance) + B = bias - (A * mean) + nodes_to_remove += [n] + # see if we have surrounding Unsqueeze/Squeeze nodes we can remove + producer = find_producer(n) + if producer is not None: + if producer.op_type == "Unsqueeze": + bn_input = producer.input[0] + nodes_to_remove += [producer] + consumer = find_consumer(n) + if consumer is not None: + if consumer.op_type == "Squeeze": + bn_output = consumer.output[0] + nodes_to_remove += [consumer] + # create value_info and initializers for Mul and Add constants + mul_const = oh.make_tensor_value_info( + make_new_valueinfo_name(new_model), TensorProto.FLOAT, A.shape + ) + graph.value_info.append(mul_const) + set_initializer(new_model, mul_const.name, A) + mul_output = oh.make_tensor_value_info( + make_new_valueinfo_name(new_model), TensorProto.FLOAT, A.shape + ) + graph.value_info.append(mul_output) + add_const = oh.make_tensor_value_info( + make_new_valueinfo_name(new_model), TensorProto.FLOAT, B.shape + ) + graph.value_info.append(add_const) + set_initializer(new_model, add_const.name, B) + # create Mul and Add nodes to replace the batchnorm + mul_node = oh.make_node( + "Mul", [bn_input.name, mul_const.name], [mul_output.name] + ) + add_node = oh.make_node( + "Add", [mul_output.name, add_const.name], [bn_output.name] + ) + graph.node.append(mul_node) + graph.node.append(add_node) + + # delete marked nodes + for n in nodes_to_remove: + graph.node.remove(n) + # TODO topologically sort nodes + # TODO give new names, maybe run shape inference? + return new_model