Skip to content
Snippets Groups Projects
Commit cc0ccaf3 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Transform] add a first unfinished version of batchnorm to affine

parent 8be08279
No related branches found
No related tags found
No related merge requests found
import copy
import numpy as np
from onnx import TensorProto
from onnx import helper as oh
from onnx import numpy_helper as np_helper
......@@ -77,3 +80,67 @@ def make_new_valueinfo_name(model):
while candidate in names:
candidate = str(int(candidate) + 1)
return candidate
def replace_batchnorm_with_affine(model):
"""Replaces any test-time BatchNorm layers with Mul-Add layers."""
new_model = copy.deepcopy(model)
graph = new_model.graph
nodes_to_remove = []
for n in graph.node:
if n.op_type == "BatchNormalization":
bn_input = n.input[0]
bn_output = n.output[0]
# extract batchnorm parameters as numpy arrays
scale = get_initializer(n.input[1])
bias = get_initializer(n.input[2])
mean = get_initializer(n.input[3])
variance = get_initializer(n.input[4])
epsilon = 1e-5
# find A and B to compute batchnorm as affine transpose Ax+B
# TODO is a division by moving avg factor needed for variance?
A = scale / np.sqrt(epsilon + variance)
B = bias - (A * mean)
nodes_to_remove += [n]
# see if we have surrounding Unsqueeze/Squeeze nodes we can remove
producer = find_producer(n)
if producer is not None:
if producer.op_type == "Unsqueeze":
bn_input = producer.input[0]
nodes_to_remove += [producer]
consumer = find_consumer(n)
if consumer is not None:
if consumer.op_type == "Squeeze":
bn_output = consumer.output[0]
nodes_to_remove += [consumer]
# create value_info and initializers for Mul and Add constants
mul_const = oh.make_tensor_value_info(
make_new_valueinfo_name(new_model), TensorProto.FLOAT, A.shape
)
graph.value_info.append(mul_const)
set_initializer(new_model, mul_const.name, A)
mul_output = oh.make_tensor_value_info(
make_new_valueinfo_name(new_model), TensorProto.FLOAT, A.shape
)
graph.value_info.append(mul_output)
add_const = oh.make_tensor_value_info(
make_new_valueinfo_name(new_model), TensorProto.FLOAT, B.shape
)
graph.value_info.append(add_const)
set_initializer(new_model, add_const.name, B)
# create Mul and Add nodes to replace the batchnorm
mul_node = oh.make_node(
"Mul", [bn_input.name, mul_const.name], [mul_output.name]
)
add_node = oh.make_node(
"Add", [mul_output.name, add_const.name], [bn_output.name]
)
graph.node.append(mul_node)
graph.node.append(add_node)
# delete marked nodes
for n in nodes_to_remove:
graph.node.remove(n)
# TODO topologically sort nodes
# TODO give new names, maybe run shape inference?
return new_model
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment