Skip to content
Snippets Groups Projects
Commit 529771f5 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Transform] fix batchnorm to affine problems, add get_tensor_shape

parent cc0ccaf3
No related branches found
No related tags found
No related merge requests found
import copy
import numpy as np
import onnx.shape_inference as si
from onnx import TensorProto
from onnx import helper as oh
from onnx import numpy_helper as np_helper
......@@ -16,6 +17,21 @@ def give_unique_names(model):
return new_model
def get_tensor_shape(model, tensor_name):
    """Returns the shape of tensor with given name, if it has ValueInfoProto.

    Searches the graph's inputs, outputs and value_info entries (in that
    order) for a ValueInfoProto whose name matches tensor_name and returns
    its dimensions as a list of ints. Returns None when no matching
    ValueInfoProto exists.

    NOTE(review): symbolic (non-fixed) dimensions are reported via
    dim_value, which the protobuf defaults to 0 for unset fields --
    presumably callers only use this on fully shape-inferred models.
    """
    graph = model.graph
    # search order matters: graph inputs first, then outputs, then value_info
    for value_info in list(graph.input) + list(graph.output) + list(graph.value_info):
        if value_info.name == tensor_name:
            return [dim.dim_value for dim in value_info.type.tensor_type.shape.dim]
    return None
def set_initializer(model, tensor_name, tensor_value):
"""Set the initializer value for tensor with given name."""
graph = model.graph
......@@ -34,12 +50,9 @@ def set_initializer(model, tensor_name, tensor_value):
graph.initializer.append(tensor_init_proto)
def get_initializer(model, tensor_name, tensor_value):
def get_initializer(model, tensor_name):
"""Get the initializer value for tensor with given name, if any."""
graph = model.graph
# convert tensor_value (numpy array) into TensorProto w/ correct name
tensor_init_proto = np_helper.from_array(tensor_value)
tensor_init_proto.name = tensor_name
init_names = [x.name for x in graph.initializer]
try:
init_ind = init_names.index(tensor_name)
......@@ -51,7 +64,7 @@ def get_initializer(model, tensor_name, tensor_value):
def find_producer(model, tensor_name):
"""Find and return the node that produces the tensor with given name.
Currently only works for linear graphs."""
all_outputs = [x.output[0].name for x in model.graph.node]
all_outputs = [x.output[0] for x in model.graph.node]
try:
producer_ind = all_outputs.index(tensor_name)
return model.graph.node[producer_ind]
......@@ -62,7 +75,7 @@ def find_producer(model, tensor_name):
def find_consumer(model, tensor_name):
"""Find and return the node that consumes the tensor with given name.
Currently only works for linear graphs."""
all_inputs = [x.input[0].name for x in model.graph.node]
all_inputs = [x.input[0] for x in model.graph.node]
try:
consumer_ind = all_inputs.index(tensor_name)
return model.graph.node[consumer_ind]
......@@ -85,17 +98,20 @@ def make_new_valueinfo_name(model):
def replace_batchnorm_with_affine(model):
"""Replaces any test-time BatchNorm layers with Mul-Add layers."""
new_model = copy.deepcopy(model)
new_model = si.infer_shapes(new_model)
graph = new_model.graph
nodes_to_remove = []
node_ind = 0
for n in graph.node:
node_ind += 1
if n.op_type == "BatchNormalization":
bn_input = n.input[0]
bn_output = n.output[0]
# extract batchnorm parameters as numpy arrays
scale = get_initializer(n.input[1])
bias = get_initializer(n.input[2])
mean = get_initializer(n.input[3])
variance = get_initializer(n.input[4])
scale = get_initializer(new_model, n.input[1])
bias = get_initializer(new_model, n.input[2])
mean = get_initializer(new_model, n.input[3])
variance = get_initializer(new_model, n.input[4])
epsilon = 1e-5
# find A and B to compute batchnorm as affine transpose Ax+B
# TODO is a division by moving avg factor needed for variance?
......@@ -103,16 +119,17 @@ def replace_batchnorm_with_affine(model):
B = bias - (A * mean)
nodes_to_remove += [n]
# see if we have surrounding Unsqueeze/Squeeze nodes we can remove
producer = find_producer(n)
producer = find_producer(new_model, bn_input)
if producer is not None:
if producer.op_type == "Unsqueeze":
bn_input = producer.input[0]
nodes_to_remove += [producer]
consumer = find_consumer(n)
consumer = find_consumer(new_model, bn_output)
if consumer is not None:
if consumer.op_type == "Squeeze":
bn_output = consumer.output[0]
nodes_to_remove += [consumer]
data_shape = get_tensor_shape(new_model, bn_input)
# create value_info and initializers for Mul and Add constants
mul_const = oh.make_tensor_value_info(
make_new_valueinfo_name(new_model), TensorProto.FLOAT, A.shape
......@@ -120,7 +137,7 @@ def replace_batchnorm_with_affine(model):
graph.value_info.append(mul_const)
set_initializer(new_model, mul_const.name, A)
mul_output = oh.make_tensor_value_info(
make_new_valueinfo_name(new_model), TensorProto.FLOAT, A.shape
make_new_valueinfo_name(new_model), TensorProto.FLOAT, data_shape
)
graph.value_info.append(mul_output)
add_const = oh.make_tensor_value_info(
......@@ -130,17 +147,16 @@ def replace_batchnorm_with_affine(model):
set_initializer(new_model, add_const.name, B)
# create Mul and Add nodes to replace the batchnorm
mul_node = oh.make_node(
"Mul", [bn_input.name, mul_const.name], [mul_output.name]
"Mul", [bn_input, mul_const.name], [mul_output.name]
)
add_node = oh.make_node(
"Add", [mul_output.name, add_const.name], [bn_output.name]
"Add", [mul_output.name, add_const.name], [bn_output]
)
graph.node.append(mul_node)
graph.node.append(add_node)
# delete marked nodes
# insert where the batchnorm is to preserve topological ordering
graph.node.insert(node_ind, mul_node)
graph.node.insert(node_ind + 1, add_node)
# delete marked nodes (batchnorm and (un)squeezing)
for n in nodes_to_remove:
graph.node.remove(n)
# TODO topologically sort nodes
# TODO give new names, maybe run shape inference?
new_model = si.infer_shapes(new_model)
return new_model
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment