Commit 9a3a2eea authored by Yaman Umuroglu

[Test] refactor test_batchnorm_to_affine to use models.LFC

parent 5a6c8a48
import os
from functools import reduce
from operator import mul
from pkgutil import get_data
import brevitas.onnx as bo
@@ -8,20 +6,14 @@ import numpy as np
import onnx
import onnx.numpy_helper as nph
import torch
from models.common import get_act_quant, get_quant_linear, get_quant_type, get_stats_op
from torch.nn import BatchNorm1d, Dropout, Module, ModuleList
from models.LFC import LFC
import finn.core.onnx_exec as oxe
import finn.transformation.batchnorm_to_affine as tx
import finn.transformation.fold_constants as fc
import finn.transformation.infer_shapes as si
from finn.core.modelwrapper import ModelWrapper
FC_OUT_FEATURES = [1024, 1024, 1024]
INTERMEDIATE_FC_PER_OUT_CH_SCALING = True
LAST_FC_PER_OUT_CH_SCALING = False
IN_DROPOUT = 0.2
HIDDEN_DROPOUT = 0.2
export_onnx_path = "test_output_lfc.onnx"
transformed_onnx_path = "test_output_lfc_transformed.onnx"
# TODO get from config instead, hardcoded to Docker path for now
@@ -30,62 +22,6 @@ trained_lfc_checkpoint = (
)
class LFC(Module):
def __init__(
self,
num_classes=10,
weight_bit_width=None,
act_bit_width=None,
in_bit_width=None,
in_ch=1,
in_features=(28, 28),
):
super(LFC, self).__init__()
weight_quant_type = get_quant_type(weight_bit_width)
act_quant_type = get_quant_type(act_bit_width)
in_quant_type = get_quant_type(in_bit_width)
stats_op = get_stats_op(weight_quant_type)
self.features = ModuleList()
self.features.append(get_act_quant(in_bit_width, in_quant_type))
self.features.append(Dropout(p=IN_DROPOUT))
in_features = reduce(mul, in_features)
for out_features in FC_OUT_FEATURES:
self.features.append(
get_quant_linear(
in_features=in_features,
out_features=out_features,
per_out_ch_scaling=INTERMEDIATE_FC_PER_OUT_CH_SCALING,
bit_width=weight_bit_width,
quant_type=weight_quant_type,
stats_op=stats_op,
)
)
in_features = out_features
self.features.append(BatchNorm1d(num_features=in_features))
self.features.append(get_act_quant(act_bit_width, act_quant_type))
self.features.append(Dropout(p=HIDDEN_DROPOUT))
self.fc = get_quant_linear(
in_features=in_features,
out_features=num_classes,
per_out_ch_scaling=LAST_FC_PER_OUT_CH_SCALING,
bit_width=weight_bit_width,
quant_type=weight_quant_type,
stats_op=stats_op,
)
def forward(self, x):
x = x.view(1, 784)
# removing the torch.tensor here creates a float64 op for some reason..
# so explicitly wrapped with torch.tensor to make a float32 one instead
x = 2.0 * x - torch.tensor([1.0])
for mod in self.features:
x = mod(x)
out = self.fc(x)
return out
def test_batchnorm_to_affine():
lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1)
checkpoint = torch.load(trained_lfc_checkpoint, map_location="cpu")
@@ -93,12 +29,15 @@ def test_batchnorm_to_affine():
bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
model = ModelWrapper(export_onnx_path)
model = model.transform_single(si.infer_shapes)
model = model.transform_repeated(fc.fold_constants)
new_model = model.transform_single(tx.batchnorm_to_affine)
# load one of the test vectors
raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb")
input_tensor = onnx.load_tensor_from_string(raw_i)
out_old = model.graph.output[0].name
out_new = new_model.graph.output[0].name
input_dict = {"0": nph.to_array(input_tensor)}
output_original = oxe.execute_onnx(model, input_dict)["53"]
output_transformed = oxe.execute_onnx(new_model, input_dict)["53"]
output_original = oxe.execute_onnx(model, input_dict)[out_old]
output_transformed = oxe.execute_onnx(new_model, input_dict)[out_new]
assert np.isclose(output_transformed, output_original, atol=1e-3).all()
os.remove(export_onnx_path)
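
For context, the transformation exercised here folds each BatchNorm node into an equivalent per-channel Mul/Add pair, which is why the transformed model must match the original to within a small tolerance. The following is a minimal NumPy sketch of that algebraic identity only; the parameter values are illustrative and this is not FINN's batchnorm_to_affine implementation.

import numpy as np

# Per-channel BatchNorm parameters (illustrative values, not from the LFC checkpoint)
gamma = np.array([1.5, 0.8])   # scale
beta = np.array([0.1, -0.2])   # shift
mean = np.array([0.0, 0.5])    # running mean
var = np.array([1.0, 0.25])    # running variance
eps = 1e-5

x = np.random.randn(4, 2)

# BatchNorm in inference mode
bn_out = gamma * (x - mean) / np.sqrt(var + eps) + beta

# Equivalent affine (Mul/Add) form: y = a * x + b
a = gamma / np.sqrt(var + eps)
b = beta - mean * a
affine_out = a * x + b

assert np.isclose(bn_out, affine_out, atol=1e-3).all()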