Commit b2eec820 authored by Yaman Umuroglu
[Transform] use new interface for batchnorm_to_affine

parent dde783bc
-import copy
 import numpy as np
 import onnx.shape_inference as si
 from onnx import TensorProto
 from onnx import helper as oh
-import finn.transformation.general as tg


 def batchnorm_to_affine(model):
     """Replaces any test-time BatchNorm layers with Mul-Add layers."""
-    new_model = copy.deepcopy(model)
-    graph = new_model.graph
+    graph = model.graph
     nodes_to_remove = []
     node_ind = 0
     graph_modified = False
     for n in graph.node:
         node_ind += 1
         if n.op_type == "BatchNormalization":
             graph_modified = True
             bn_input = n.input[0]
             bn_output = n.output[0]
             # extract batchnorm parameters as numpy arrays
-            scale = tg.get_initializer(new_model, n.input[1])
-            bias = tg.get_initializer(new_model, n.input[2])
-            mean = tg.get_initializer(new_model, n.input[3])
-            variance = tg.get_initializer(new_model, n.input[4])
+            scale = model.get_initializer(n.input[1])
+            bias = model.get_initializer(n.input[2])
+            mean = model.get_initializer(n.input[3])
+            variance = model.get_initializer(n.input[4])
             epsilon = 1e-5
             # find A and B to compute batchnorm as affine transpose Ax+B
             # TODO is a division by moving avg factor needed for variance?
@@ -31,32 +28,32 @@ def batchnorm_to_affine(model):
             B = bias - (A * mean)
             nodes_to_remove += [n]
             # see if we have surrounding Unsqueeze/Squeeze nodes we can remove
-            producer = tg.find_producer(new_model, bn_input)
+            producer = model.find_producer(bn_input)
             if producer is not None:
                 if producer.op_type == "Unsqueeze":
                     bn_input = producer.input[0]
                     nodes_to_remove += [producer]
-            consumer = tg.find_consumer(new_model, bn_output)
+            consumer = model.find_consumer(bn_output)
             if consumer is not None:
                 if consumer.op_type == "Squeeze":
                     bn_output = consumer.output[0]
                     nodes_to_remove += [consumer]
-            data_shape = tg.get_tensor_shape(new_model, bn_input)
+            data_shape = model.get_tensor_shape(bn_input)
             # create value_info and initializers for Mul and Add constants
             mul_const = oh.make_tensor_value_info(
-                tg.make_new_valueinfo_name(new_model), TensorProto.FLOAT, A.shape
+                model.make_new_valueinfo_name(), TensorProto.FLOAT, A.shape
             )
             graph.value_info.append(mul_const)
-            tg.set_initializer(new_model, mul_const.name, A)
+            model.set_initializer(mul_const.name, A)
             mul_output = oh.make_tensor_value_info(
-                tg.make_new_valueinfo_name(new_model), TensorProto.FLOAT, data_shape
+                model.make_new_valueinfo_name(), TensorProto.FLOAT, data_shape
             )
             graph.value_info.append(mul_output)
             add_const = oh.make_tensor_value_info(
-                tg.make_new_valueinfo_name(new_model), TensorProto.FLOAT, B.shape
+                model.make_new_valueinfo_name(), TensorProto.FLOAT, B.shape
             )
             graph.value_info.append(add_const)
-            tg.set_initializer(new_model, add_const.name, B)
+            model.set_initializer(add_const.name, B)
             # create Mul and Add nodes to replace the batchnorm
             mul_node = oh.make_node(
                 "Mul", [bn_input, mul_const.name], [mul_output.name]
@@ -70,5 +67,6 @@ def batchnorm_to_affine(model):
     # delete marked nodes (batchnorm and (un)squeezing)
     for n in nodes_to_remove:
         graph.node.remove(n)
-    new_model = si.infer_shapes(new_model)
-    return new_model
+        graph_modified = True
+    model.model = si.infer_shapes(model.model)
+    return (model, graph_modified)
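
The folding math itself is unchanged by this commit (the A computation sits in the collapsed context of the hunk above). For reference, a minimal self-contained numpy sketch, not part of the diff and with purely illustrative variable names, of the Mul-Add equivalence the transform relies on:

# Minimal sketch (not part of the commit): test-time BatchNorm folded into Mul-Add.
# All names here are illustrative.
import numpy as np

x = np.random.randn(4, 8).astype(np.float32)   # (batch, channels)
scale = np.random.rand(8).astype(np.float32)   # gamma
bias = np.random.rand(8).astype(np.float32)    # beta
mean = np.random.rand(8).astype(np.float32)    # running mean
var = np.random.rand(8).astype(np.float32)     # running variance
epsilon = 1e-5

# reference test-time batchnorm
bn_out = scale * (x - mean) / np.sqrt(var + epsilon) + bias

# folded form: y = A * x + B
A = scale / np.sqrt(var + epsilon)
B = bias - A * mean
affine_out = A * x + B

assert np.isclose(bn_out, affine_out, atol=1e-5).all()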
@@ -15,6 +15,7 @@ from torch.nn import BatchNorm1d, Dropout, Module, ModuleList

 import finn.core.onnx_exec as oxe
 import finn.transformation.batchnorm_to_affine as tx
+from finn.core.modelwrapper import ModelWrapper

 FC_OUT_FEATURES = [1024, 1024, 1024]
 INTERMEDIATE_FC_PER_OUT_CH_SCALING = True
@@ -94,9 +95,9 @@ def test_batchnorm_to_affine():
     checkpoint = torch.load(trained_lfc_checkpoint, map_location="cpu")
     lfc.load_state_dict(checkpoint["state_dict"])
     bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
-    model = onnx.load(export_onnx_path)
-    model = si.infer_shapes(model)
-    new_model = tx.batchnorm_to_affine(model)
+    model = ModelWrapper(export_onnx_path)
+    model.model = si.infer_shapes(model.model)
+    new_model = model.transform_single(tx.batchnorm_to_affine)
     try:
         os.remove("/tmp/" + mnist_onnx_filename)
     except OSError:
@@ -108,8 +109,8 @@ def test_batchnorm_to_affine():
     with open(mnist_onnx_local_dir + "/mnist/test_data_set_0/input_0.pb", "rb") as f:
         input_tensor.ParseFromString(f.read())
     input_dict = {"0": nph.to_array(input_tensor)}
-    output_original = oxe.execute_onnx(model, input_dict)["53"]
-    output_transformed = oxe.execute_onnx(new_model, input_dict)["53"]
+    output_original = oxe.execute_onnx(model.model, input_dict)["53"]
+    output_transformed = oxe.execute_onnx(new_model.model, input_dict)["53"]
     assert np.isclose(output_transformed, output_original, atol=1e-3).all()
     # remove the downloaded model and extracted files
     os.remove(dl_ret)
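
The substance of the new interface is the return value: the transform now hands back (model, graph_modified) on the wrapped model instead of returning a fresh deep copy, so a caller can re-apply it until the graph stops changing. A hypothetical driver, shown purely as an illustration and not FINN's actual ModelWrapper code, could look like this:

# Hypothetical illustration only -- not FINN's actual ModelWrapper.transform_single.
# Shows why returning (model, graph_modified) is useful: the caller can keep
# applying a transform until it reports that nothing changed anymore.
def apply_repeated(model, transform):
    modified = True
    while modified:
        model, modified = transform(model)
    return model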