Skip to content
Snippets Groups Projects
Commit 3bc81a6e authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Test] spec signed in pack_quant_tensor

parent 2553e084
No related branches found
No related tags found
No related merge requests found
......@@ -45,7 +45,7 @@ def test_brevitas_avg_pool_export(
scale = np.ones((1, channels, 1, 1))
output_scale = torch.from_numpy(scale).float()
input_quant_tensor = pack_quant_tensor(
tensor=input_tensor, scale=output_scale, bit_width=ibw_tensor
tensor=input_tensor, scale=output_scale, bit_width=ibw_tensor, signed=signed
)
bo.export_finn_onnx(b_avgpool, ishape, export_onnx_path, input_t=input_quant_tensor)
model = ModelWrapper(export_onnx_path)
......@@ -65,7 +65,7 @@ def test_brevitas_avg_pool_export(
inp = gen_finn_dt_tensor(dtype, ishape)
input_tensor = torch.from_numpy(inp).float()
input_quant_tensor = pack_quant_tensor(
tensor=input_tensor, scale=output_scale, bit_width=ibw_tensor
tensor=input_tensor, scale=output_scale, bit_width=ibw_tensor, signed=signed
)
b_avgpool.eval()
expected = b_avgpool.forward(input_quant_tensor).tensor.detach().numpy()
......@@ -84,7 +84,7 @@ def test_brevitas_avg_pool_export(
input_tensor = torch.from_numpy(inp_tensor).float()
input_scale = torch.from_numpy(scale).float()
input_quant_tensor = pack_quant_tensor(
tensor=input_tensor, scale=input_scale, bit_width=ibw_tensor
tensor=input_tensor, scale=input_scale, bit_width=ibw_tensor, signed=signed
)
# export again to set the scale values correctly
bo.export_finn_onnx(b_avgpool, ishape, export_onnx_path, input_t=input_quant_tensor)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment