Skip to content
Snippets Groups Projects
Commit 441a25a6 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Test] use CNV provided by brevitas_cnv_lfc for test_brevitas_cnv

parent fa32c273
No related branches found
No related tags found
No related merge requests found
......@@ -4,40 +4,13 @@ import pkg_resources as pk
import brevitas.onnx as bo
import numpy as np
import torch
from models.common import (
get_act_quant,
get_quant_conv2d,
get_quant_linear,
get_quant_type,
get_stats_op
)
from torch.nn import BatchNorm1d, BatchNorm2d, MaxPool2d, Module, ModuleList, Sequential
from models.CNV import CNV
import finn.core.onnx_exec as oxe
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.fold_constants import FoldConstants
from finn.transformation.infer_shapes import InferShapes
# QuantConv2d configuration
CNV_OUT_CH_POOL = [
(0, 64, False),
(1, 64, True),
(2, 128, False),
(3, 128, True),
(4, 256, False),
(5, 256, False),
]
# Intermediate QuantLinear configuration
INTERMEDIATE_FC_PER_OUT_CH_SCALING = True
INTERMEDIATE_FC_FEATURES = [(256, 512), (512, 512)]
# Last QuantLinear configuration
LAST_FC_IN_FEATURES = 512
LAST_FC_PER_OUT_CH_SCALING = False
# MaxPool2d configuration
POOL_SIZE = 2
export_onnx_path = "test_output_cnv.onnx"
# TODO get from config instead, hardcoded to Docker path for now
trained_cnv_checkpoint = (
......@@ -45,92 +18,7 @@ trained_cnv_checkpoint = (
)
class CNV(Module):
def __init__(
self,
num_classes=10,
weight_bit_width=None,
act_bit_width=None,
in_bit_width=None,
in_ch=3,
):
super(CNV, self).__init__()
weight_quant_type = get_quant_type(weight_bit_width)
act_quant_type = get_quant_type(act_bit_width)
in_quant_type = get_quant_type(in_bit_width)
stats_op = get_stats_op(weight_quant_type)
self.conv_features = ModuleList()
self.linear_features = ModuleList()
self.conv_features.append(get_act_quant(in_bit_width, in_quant_type))
for i, out_ch, is_pool_enabled in CNV_OUT_CH_POOL:
self.conv_features.append(
get_quant_conv2d(
in_ch=in_ch,
out_ch=out_ch,
bit_width=weight_bit_width,
quant_type=weight_quant_type,
stats_op=stats_op,
)
)
in_ch = out_ch
if is_pool_enabled:
self.conv_features.append(MaxPool2d(kernel_size=2))
if i == 5:
self.conv_features.append(Sequential())
self.conv_features.append(BatchNorm2d(in_ch))
self.conv_features.append(get_act_quant(act_bit_width, act_quant_type))
for in_features, out_features in INTERMEDIATE_FC_FEATURES:
self.linear_features.append(
get_quant_linear(
in_features=in_features,
out_features=out_features,
per_out_ch_scaling=INTERMEDIATE_FC_PER_OUT_CH_SCALING,
bit_width=weight_bit_width,
quant_type=weight_quant_type,
stats_op=stats_op,
)
)
self.linear_features.append(BatchNorm1d(out_features))
self.linear_features.append(get_act_quant(act_bit_width, act_quant_type))
self.fc = get_quant_linear(
in_features=LAST_FC_IN_FEATURES,
out_features=num_classes,
per_out_ch_scaling=LAST_FC_PER_OUT_CH_SCALING,
bit_width=weight_bit_width,
quant_type=weight_quant_type,
stats_op=stats_op,
)
def forward(self, x):
x = 2.0 * x - torch.tensor([1.0])
for mod in self.conv_features:
x = mod(x)
x = x.view(1, 256)
for mod in self.linear_features:
x = mod(x)
out = self.fc(x)
return out
def test_brevitas_trained_cnv_pytorch():
# load pretrained weights into CNV-w1a1
cnv = CNV(weight_bit_width=1, act_bit_width=1, in_bit_width=1, in_ch=3).eval()
checkpoint = torch.load(trained_cnv_checkpoint, map_location="cpu")
cnv.load_state_dict(checkpoint["state_dict"])
fn = pk.resource_filename("finn", "data/cifar10/cifar10-test-data-class3.npz")
input_tensor = np.load(fn)["arr_0"]
input_tensor = torch.from_numpy(input_tensor).float()
assert input_tensor.shape == (1, 3, 32, 32)
# do forward pass in PyTorch/Brevitas
cnv.forward(input_tensor).detach().numpy()
# TODO verify produced answer
def test_brevitas_cnv_export():
def test_brevitas_cnv_w1a1_export():
cnv = CNV(weight_bit_width=1, act_bit_width=1, in_bit_width=1, in_ch=3).eval()
bo.export_finn_onnx(cnv, (1, 3, 32, 32), export_onnx_path)
model = ModelWrapper(export_onnx_path)
......@@ -139,15 +27,17 @@ def test_brevitas_cnv_export():
conv0_wname = model.graph.node[3].input[1]
assert list(model.get_initializer(conv0_wname).shape) == [64, 3, 3, 3]
assert model.graph.node[4].op_type == "Mul"
os.remove(export_onnx_path)
def test_brevitas_cnv_export_exec():
def test_brevitas_cnv_w1a1_export_exec():
cnv = CNV(weight_bit_width=1, act_bit_width=1, in_bit_width=1, in_ch=3).eval()
checkpoint = torch.load(trained_cnv_checkpoint, map_location="cpu")
cnv.load_state_dict(checkpoint["state_dict"])
bo.export_finn_onnx(cnv, (1, 3, 32, 32), export_onnx_path)
model = ModelWrapper(export_onnx_path)
model = model.transform(InferShapes())
model = model.transform(FoldConstants())
model.save(export_onnx_path)
fn = pk.resource_filename("finn", "data/cifar10/cifar10-test-data-class3.npz")
input_tensor = np.load(fn)["arr_0"].astype(np.float32)
......@@ -161,3 +51,17 @@ def test_brevitas_cnv_export_exec():
expected = cnv.forward(input_tensor).detach().numpy()
assert np.isclose(produced, expected, atol=1e-3).all()
os.remove(export_onnx_path)
def test_brevitas_trained_cnv_w1a1_pytorch():
# load pretrained weights into CNV-w1a1
cnv = CNV(weight_bit_width=1, act_bit_width=1, in_bit_width=1, in_ch=3).eval()
checkpoint = torch.load(trained_cnv_checkpoint, map_location="cpu")
cnv.load_state_dict(checkpoint["state_dict"])
fn = pk.resource_filename("finn", "data/cifar10/cifar10-test-data-class3.npz")
input_tensor = np.load(fn)["arr_0"]
input_tensor = torch.from_numpy(input_tensor).float()
assert input_tensor.shape == (1, 3, 32, 32)
# do forward pass in PyTorch/Brevitas
cnv.forward(input_tensor).detach().numpy()
# TODO verify produced answer
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment