Unverified commit e4cd5b67, authored by auphelia, committed by GitHub

Merge pull request #777 from Xilinx/feature/relu_act_export_test

Update QuantReLU export test
Parents: fc8e7d9f, 3a2d5e3f
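For context, the export-and-convert round trip these tests exercise boils down to the following sketch, assembled from the calls visible in the diff below (the bit width and output path are illustrative, not taken from the test):

import torch
from brevitas.export import export_qonnx
from brevitas.nn import QuantReLU
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.util.cleanup import cleanup as qonnx_cleanup

from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN

# quantized ReLU exported to QONNX, then lowered to FINN's internal ops
b_act = QuantReLU(bit_width=4)
m_path = "relu_export.onnx"  # illustrative path
export_qonnx(b_act, torch.randn(1, 32, 1, 1), m_path)
qonnx_cleanup(m_path, out_file=m_path)  # canonicalize the exported graph
model = ModelWrapper(m_path)
model = model.transform(ConvertQONNXtoFINN())
model.save(m_path)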
@@ -32,15 +32,12 @@ import numpy as np
 import onnx  # noqa
 import os
 import torch
-from brevitas.core.quant import QuantType
-from brevitas.core.restrict_val import RestrictValueType
 from brevitas.core.scaling import ScalingImplType
 from brevitas.export import export_finn_onnx, export_qonnx
 from brevitas.nn import QuantReLU
 from qonnx.core.modelwrapper import ModelWrapper
 from qonnx.transformation.infer_shapes import InferShapes
 from qonnx.util.cleanup import cleanup as qonnx_cleanup
-from torch import nn
 
 import finn.core.onnx_exec as oxe
 from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN
@@ -50,30 +47,17 @@ export_onnx_path = "test_brevitas_relu_act_export.onnx"
 
 @pytest.mark.brevitas_export
 @pytest.mark.parametrize("abits", [2, 4, 8])
-@pytest.mark.parametrize(
-    "scaling_impl_type", [ScalingImplType.CONST, ScalingImplType.PARAMETER]
-)
+@pytest.mark.parametrize("ishape", [(1, 15), (1, 32, 1, 1)])
 @pytest.mark.parametrize("QONNX_export", [False, True])
-def test_brevitas_act_export_relu(abits, scaling_impl_type, QONNX_export):
-    ishape = (1, 15)
+def test_brevitas_act_export_relu(
+    abits,
+    ishape,
+    QONNX_export,
+):
     b_act = QuantReLU(
         bit_width=abits,
-        max_val=6.0,
-        scaling_impl_type=scaling_impl_type,
-        restrict_scaling_type=RestrictValueType.LOG_FP,
-        quant_type=QuantType.INT,
     )
-    if scaling_impl_type == ScalingImplType.PARAMETER:
-        checkpoint = {
-            "act_quant_proxy.fused_activation_quant_proxy.tensor_quant.\
-scaling_impl.learned_value": torch.tensor(
-                0.49
-            ).type(
-                torch.FloatTensor
-            )
-        }
-        b_act.load_state_dict(checkpoint)
     if QONNX_export:
         m_path = export_onnx_path
         export_qonnx(b_act, torch.randn(ishape), m_path)
@@ -92,17 +76,6 @@ scaling_impl.learned_value": torch.tensor(
     inp_tensor = torch.from_numpy(inp_tensor).float()
     b_act.eval()
     expected = b_act.forward(inp_tensor).detach().numpy()
-    if not np.isclose(produced, expected, atol=1e-3).all():
-        print(abits, scaling_impl_type)
-        print("scale: ", b_act.quant_act_scale().type(torch.FloatTensor).detach())
-        if abits < 5:
-            print(
-                "thres:",
-                ", ".join(["{:8.4f}".format(x) for x in b_act.export_thres[0]]),
-            )
-        print("input:", ", ".join(["{:8.4f}".format(x) for x in inp_tensor[0]]))
-        print("prod :", ", ".join(["{:8.4f}".format(x) for x in produced[0]]))
-        print("expec:", ", ".join(["{:8.4f}".format(x) for x in expected[0]]))
     assert np.isclose(produced, expected, atol=1e-3).all()
     os.remove(export_onnx_path)
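For reference, the execute-and-compare pattern feeding the assert kept above looks like this (a minimal sketch reusing the calls from the test file; the trailing True mirrors the tests and asks execute_onnx for the full execution context rather than only the graph outputs):

import numpy as np
import finn.core.onnx_exec as oxe
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.infer_shapes import InferShapes

model = ModelWrapper("test_brevitas_relu_act_export.onnx")
model = model.transform(InferShapes())
inp = np.random.uniform(low=-1.0, high=6.0, size=(1, 15)).astype(np.float32)
idict = {model.graph.input[0].name: inp}
odict = oxe.execute_onnx(model, idict, True)  # True: return all intermediate tensors too
produced = odict[model.graph.output[0].name]
# the tests then compare `produced` against the Brevitas forward pass
# with np.isclose(produced, expected, atol=1e-3)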
@@ -110,35 +83,22 @@ scaling_impl.learned_value": torch.tensor(
 
 @pytest.mark.brevitas_export
 @pytest.mark.parametrize("abits", [2, 4, 8])
-@pytest.mark.parametrize("scaling_per_output_channel", [True, False])
+@pytest.mark.parametrize("ishape", [(1, 15, 4, 4), (1, 32, 1, 1)])
 @pytest.mark.parametrize("QONNX_export", [False, True])
-def test_brevitas_act_export_relu_imagenet(
-    abits, scaling_per_output_channel, QONNX_export
+def test_brevitas_act_export_relu_channel(
+    abits,
+    ishape,
+    QONNX_export,
 ):
-    out_channels = 32
-    ishape = (1, out_channels, 1, 1)
+    ch = ishape[1]
     b_act = QuantReLU(
         bit_width=abits,
-        quant_type=QuantType.INT,
-        scaling_impl_type=ScalingImplType.PARAMETER,
-        scaling_per_output_channel=scaling_per_output_channel,
-        restrict_scaling_type=RestrictValueType.LOG_FP,
-        scaling_min_val=2e-16,
         max_val=6.0,
         return_quant_tensor=False,
-        per_channel_broadcastable_shape=(1, out_channels, 1, 1),
+        scaling_impl_type=ScalingImplType.CONST,
+        scaling_per_output_channel=True,
+        per_channel_broadcastable_shape=(1, ch, 1, 1),
     )
-    if scaling_per_output_channel is True:
-        rand_tensor = (2) * torch.rand((1, out_channels, 1, 1))
-    else:
-        rand_tensor = torch.tensor(1.2398)
-    checkpoint = {
-        "act_quant_proxy.fused_activation_quant_proxy.tensor_quant.\
-scaling_impl.learned_value": rand_tensor.type(
-            torch.FloatTensor
-        )
-    }
-    b_act.load_state_dict(checkpoint)
     if QONNX_export:
         m_path = export_onnx_path
         export_qonnx(b_act, torch.randn(ishape), m_path)
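A note on the kept per_channel_broadcastable_shape=(1, ch, 1, 1) argument: with scaling_per_output_channel=True, the per-channel scale has to broadcast against an NCHW activation tensor, which is exactly what that shape guarantees. A minimal sketch of the broadcasting involved (channel count arbitrary):

import torch

x = torch.randn(1, 32, 4, 4)     # NCHW activation, 32 channels
scale = torch.rand(1, 32, 1, 1)  # one scale value per channel
y = x * scale                    # broadcasts across height and width
assert y.shape == x.shape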
@@ -157,93 +117,6 @@ scaling_impl.learned_value": rand_tensor.type(
     inp_tensor = torch.from_numpy(inp_tensor).float()
     b_act.eval()
     expected = b_act.forward(inp_tensor).detach().numpy()
-    if not np.isclose(produced, expected, atol=1e-3).all():
-        print(abits)
-        print("scale: ", b_act.quant_act_scale().type(torch.FloatTensor).detach())
-        if abits < 5:
-            print(
-                "thres:",
-                ", ".join(["{:8.4f}".format(x) for x in b_act.export_thres[0]]),
-            )
-        print("input:", ", ".join(["{:8.4f}".format(x) for x in inp_tensor[0]]))
-        print("prod :", ", ".join(["{:8.4f}".format(x) for x in produced[0]]))
-        print("expec:", ", ".join(["{:8.4f}".format(x) for x in expected[0]]))
     assert np.isclose(produced, expected, atol=1e-3).all()
     os.remove(export_onnx_path)
 
-
-class PyTorchTestModel(nn.Module):
-    def __init__(self, abits):
-        super(PyTorchTestModel, self).__init__()
-        out_channels = 32
-        self.b_act = QuantReLU(
-            bit_width=abits,
-            quant_type=QuantType.INT,
-            scaling_impl_type=ScalingImplType.PARAMETER,
-            scaling_per_output_channel=True,
-            restrict_scaling_type=RestrictValueType.LOG_FP,
-            scaling_min_val=2e-16,
-            max_val=6.0,
-            return_quant_tensor=False,
-            per_channel_broadcastable_shape=(1, out_channels, 1, 1),
-        )
-
-    def forward(self, x):
-        act_out = self.b_act(x)
-        y0 = act_out * 2.0
-        y1 = act_out * -1.0
-        y = y0 + y1
-        return y
-
-
-@pytest.mark.brevitas_export
-@pytest.mark.parametrize("abits", [2, 4, 8])
-@pytest.mark.parametrize("scaling_per_output_channel", [True])
-@pytest.mark.parametrize("QONNX_export", [True])
-def test_brevitas_act_export_relu_forking(
-    abits, scaling_per_output_channel, QONNX_export
-):
-    out_channels = 32
-    ishape = (1, out_channels, 1, 1)
-    model_pyt = PyTorchTestModel(abits)
-    rand_tensor = (2) * torch.rand((1, out_channels, 1, 1))
-    checkpoint = {
-        "b_act.act_quant_proxy.fused_activation_quant_proxy."
-        "tensor_quant.scaling_impl.learned_value": rand_tensor.type(torch.FloatTensor)
-    }
-    model_pyt.load_state_dict(checkpoint)
-    if QONNX_export:
-        m_path = export_onnx_path
-        export_qonnx(model_pyt, torch.randn(ishape), m_path)
-        qonnx_cleanup(m_path, out_file=m_path)
-        model = ModelWrapper(m_path)
-        model = model.transform(ConvertQONNXtoFINN())
-        model.save(m_path)
-    model = ModelWrapper(export_onnx_path)
-    model = model.transform(InferShapes())
-    inp_tensor = np.random.uniform(low=-1.0, high=6.0, size=ishape).astype(np.float32)
-    idict = {model.graph.input[0].name: inp_tensor}
-    odict = oxe.execute_onnx(model, idict, True)
-    produced = odict[model.graph.output[0].name]
-    inp_tensor = torch.from_numpy(inp_tensor).float()
-    model_pyt.eval()
-    expected = model_pyt.forward(inp_tensor).detach().numpy()
-    if not np.isclose(produced, expected, atol=1e-3).all():
-        print(abits)
-        print("scale: ", model_pyt.quant_act_scale().type(torch.FloatTensor).detach())
-        if abits < 5:
-            print(
-                "thres:",
-                ", ".join(["{:8.4f}".format(x) for x in model_pyt.export_thres[0]]),
-            )
-        print("input:", ", ".join(["{:8.4f}".format(x) for x in inp_tensor[0]]))
-        print("prod :", ", ".join(["{:8.4f}".format(x) for x in produced[0]]))
-        print("expec:", ", ".join(["{:8.4f}".format(x) for x in expected[0]]))
-    assert np.isclose(produced, expected, atol=1e-3).all()
-    os.remove(export_onnx_path)
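To run only the tests touched here, the brevitas_export marker they carry can be selected via pytest; a sketch, assuming the file lives at FINN's usual tests/brevitas/ location and that the marker is registered in the project's pytest configuration:

import pytest

pytest.main(["-m", "brevitas_export", "tests/brevitas/test_brevitas_relu_act_export.py"])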