Commit b725bdae authored by Yaman Umuroglu's avatar Yaman Umuroglu
[Test] add test for batchnorm to affine

parent 529771f5
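# Background for this test: in inference mode, a BatchNorm node computes
#   y = gamma * (x - mean) / sqrt(var + eps) + beta
# which is an affine function of x, i.e. y = A * x + B with
#   A = gamma / sqrt(var + eps)  and  B = beta - A * mean.
# The replace_batchnorm_with_affine transformation exercised below relies on
# this identity, so the transformed model should match the original output
# to within the numerical tolerance asserted at the end of the test.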
import os
import shutil
from functools import reduce
from operator import mul

import brevitas.onnx as bo
import numpy as np
import onnx
import onnx.numpy_helper as nph
import torch
import wget
from models.common import get_act_quant, get_quant_linear, get_quant_type, get_stats_op
from torch.nn import BatchNorm1d, Dropout, Module, ModuleList

import finn.core.onnx_exec as oxe
import finn.transformation.general as tx

FC_OUT_FEATURES = [1024, 1024, 1024]
INTERMEDIATE_FC_PER_OUT_CH_SCALING = True
LAST_FC_PER_OUT_CH_SCALING = False
IN_DROPOUT = 0.2
HIDDEN_DROPOUT = 0.2

mnist_onnx_url_base = "https://onnxzoo.blob.core.windows.net/models/opset_8/mnist"
mnist_onnx_filename = "mnist.tar.gz"
mnist_onnx_local_dir = "/tmp/mnist_onnx"
export_onnx_path = "test_output_lfc.onnx"
transformed_onnx_path = "test_output_lfc_transformed.onnx"
# TODO get from config instead, hardcoded to Docker path for now
trained_lfc_checkpoint = (
    "/workspace/brevitas_cnv_lfc/pretrained_models/LFC_1W1A/checkpoints/best.tar"
)


class LFC(Module):
    def __init__(
        self,
        num_classes=10,
        weight_bit_width=None,
        act_bit_width=None,
        in_bit_width=None,
        in_ch=1,
        in_features=(28, 28),
    ):
        super(LFC, self).__init__()
        weight_quant_type = get_quant_type(weight_bit_width)
        act_quant_type = get_quant_type(act_bit_width)
        in_quant_type = get_quant_type(in_bit_width)
        stats_op = get_stats_op(weight_quant_type)
        self.features = ModuleList()
        # quantize the input, then apply input dropout
        self.features.append(get_act_quant(in_bit_width, in_quant_type))
        self.features.append(Dropout(p=IN_DROPOUT))
        in_features = reduce(mul, in_features)
        # hidden blocks: quantized linear -> batchnorm -> quantized act -> dropout
        for out_features in FC_OUT_FEATURES:
            self.features.append(
                get_quant_linear(
                    in_features=in_features,
                    out_features=out_features,
                    per_out_ch_scaling=INTERMEDIATE_FC_PER_OUT_CH_SCALING,
                    bit_width=weight_bit_width,
                    quant_type=weight_quant_type,
                    stats_op=stats_op,
                )
            )
            in_features = out_features
            self.features.append(BatchNorm1d(num_features=in_features))
            self.features.append(get_act_quant(act_bit_width, act_quant_type))
            self.features.append(Dropout(p=HIDDEN_DROPOUT))
        # final quantized classifier layer
        self.fc = get_quant_linear(
            in_features=in_features,
            out_features=num_classes,
            per_out_ch_scaling=LAST_FC_PER_OUT_CH_SCALING,
            bit_width=weight_bit_width,
            quant_type=weight_quant_type,
            stats_op=stats_op,
        )

    def forward(self, x):
        x = x.view(1, 784)
        # removing the torch.tensor here creates a float64 op for some reason,
        # so explicitly wrap with torch.tensor to make a float32 one instead
        x = 2.0 * x - torch.tensor([1.0])
        for mod in self.features:
            x = mod(x)
        out = self.fc(x)
        return out


def test_batchnorm_to_affine():
    # load the pretrained binarized LFC and export it to FINN ONNX
    lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1)
    checkpoint = torch.load(trained_lfc_checkpoint, map_location="cpu")
    lfc.load_state_dict(checkpoint["state_dict"])
    bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
    model = onnx.load(export_onnx_path)
    # apply the transformation under test
    new_model = tx.replace_batchnorm_with_affine(model)
    # download the MNIST test data, removing any stale copy first
    try:
        os.remove("/tmp/" + mnist_onnx_filename)
    except OSError:
        pass
    dl_ret = wget.download(mnist_onnx_url_base + "/" + mnist_onnx_filename, out="/tmp")
    shutil.unpack_archive(dl_ret, mnist_onnx_local_dir)
    # load one of the test vectors
    input_tensor = onnx.TensorProto()
    with open(mnist_onnx_local_dir + "/mnist/test_data_set_0/input_0.pb", "rb") as f:
        input_tensor.ParseFromString(f.read())
    # execute original and transformed models and compare their outputs
    input_dict = {"0": nph.to_array(input_tensor)}
    output_original = oxe.execute_onnx(model, input_dict)["53"]
    output_transformed = oxe.execute_onnx(new_model, input_dict)["53"]
    assert np.isclose(output_transformed, output_original, atol=1e-3).all()
    # remove the downloaded model and extracted files
    os.remove(dl_ret)
    shutil.rmtree(mnist_onnx_local_dir)
    os.remove(export_onnx_path)
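For reference, here is a minimal standalone sketch of the batchnorm-to-affine folding the transformation performs. The fold_batchnorm helper below is hypothetical and not part of the FINN API; it just demonstrates numerically that an inference-mode batchnorm collapses to a scale and a shift, which is the identity the assertion above depends on.

import numpy as np

def fold_batchnorm(gamma, beta, mean, var, eps=1e-5):
    # fold inference-mode batchnorm parameters into affine scale A and bias B
    A = gamma / np.sqrt(var + eps)
    B = beta - A * mean
    return A, B

rng = np.random.RandomState(0)
gamma, beta, mean = rng.randn(3, 4)
var = rng.rand(4) + 0.1  # variances must be positive
x = rng.randn(8, 4)
A, B = fold_batchnorm(gamma, beta, mean, var)
bn_out = gamma * (x - mean) / np.sqrt(var + 1e-5) + beta
assert np.isclose(A * x + B, bn_out).all()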