Skip to content
Snippets Groups Projects
Commit c641ed7a authored by auphelia's avatar auphelia
Browse files

[Test] Add comparison of pytorch execution an finn onnx execution of new quant relu node

parent 9ab0cd15
No related branches found
No related tags found
No related merge requests found
......@@ -71,10 +71,11 @@ scaling_impl.learned_value": torch.tensor(
@pytest.mark.parametrize("abits", [4])
@pytest.mark.parametrize("out_channels", [32])
def test_brevitas_act_export_relu_imagenet(abits, out_channels):
ishape = (32, 15)
@pytest.mark.parametrize("max_val", [1.0, 1.5, 1 - 2 ** (-7)])
def test_brevitas_act_export_relu_imagenet(abits, max_val):
out_channels = 32
ishape = (1, out_channels, 1, 1)
min_val = -1.0
b_act = QuantReLU(
bit_width=abits,
quant_type=QuantType.INT,
......@@ -83,7 +84,7 @@ def test_brevitas_act_export_relu_imagenet(abits, out_channels):
restrict_scaling_type=RestrictValueType.LOG_FP,
scaling_min_val=2e-16,
max_val=6.0,
return_quant_tensor=False,
return_quant_tensor=True,
per_channel_broadcastable_shape=(1, out_channels, 1, 1),
)
checkpoint = {
......@@ -131,3 +132,28 @@ scaling_impl.learned_value": torch.tensor(
}
b_act.load_state_dict(checkpoint)
bo.export_finn_onnx(b_act, ishape, export_onnx_path)
model = ModelWrapper(export_onnx_path)
model = model.transform(InferShapes())
inp_tensor = np.random.uniform(low=min_val, high=max_val, size=ishape).astype(
np.float32
)
idict = {model.graph.input[0].name: inp_tensor}
odict = oxe.execute_onnx(model, idict, True)
produced = odict[model.graph.output[0].name]
inp_tensor = torch.from_numpy(inp_tensor).float()
b_act.eval()
expected = b_act.forward(inp_tensor).tensor.detach().numpy()
if not np.isclose(produced, expected, atol=1e-3).all():
print(abits, max_val)
print("scale: ", b_act.quant_act_scale().type(torch.FloatTensor).detach())
if abits < 5:
print(
"thres:",
", ".join(["{:8.4f}".format(x) for x in b_act.export_thres[0]]),
)
print("input:", ", ".join(["{:8.4f}".format(x) for x in inp_tensor[0]]))
print("prod :", ", ".join(["{:8.4f}".format(x) for x in produced[0]]))
print("expec:", ", ".join(["{:8.4f}".format(x) for x in expected[0]]))
assert np.isclose(produced, expected, atol=1e-3).all()
os.remove(export_onnx_path)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment