Skip to content
Snippets Groups Projects
Commit ac73f06d authored by auphelia's avatar auphelia
Browse files

[Test] Add first draft of brevitas relu export for multichannel thresholds

parent 4062647c
No related branches found
No related tags found
No related merge requests found
......@@ -23,6 +23,7 @@ export_onnx_path = "test_act.onnx"
def test_brevitas_act_export_relu(abits, max_val, scaling_impl_type):
min_val = -1.0
ishape = (1, 15)
b_act = QuantReLU(
bit_width=abits,
max_val=max_val,
......@@ -67,3 +68,66 @@ scaling_impl.learned_value": torch.tensor(
assert np.isclose(produced, expected, atol=1e-3).all()
os.remove(export_onnx_path)
@pytest.mark.parametrize("abits", [4])
@pytest.mark.parametrize("out_channels", [32])
def test_brevitas_act_export_relu_imagenet(abits, out_channels):
ishape = (32, 15)
b_act = QuantReLU(
bit_width=abits,
quant_type=QuantType.INT,
scaling_impl_type=ScalingImplType.PARAMETER,
scaling_per_channel=True,
restrict_scaling_type=RestrictValueType.LOG_FP,
scaling_min_val=2e-16,
max_val=6.0,
return_quant_tensor=False,
per_channel_broadcastable_shape=(1, out_channels, 1, 1),
)
checkpoint = {
"act_quant_proxy.fused_activation_quant_proxy.tensor_quant.\
scaling_impl.learned_value": torch.tensor(
[
[
[[0.8520]],
[[0.7784]],
[[0.8643]],
[[1.0945]],
[[0.8520]],
[[0.7239]],
[[0.8520]],
[[1.0609]],
[[1.0299]],
[[0.0121]],
[[0.8520]],
[[0.8489]],
[[0.1376]],
[[0.8520]],
[[0.7096]],
[[0.8520]],
[[0.8520]],
[[0.8608]],
[[0.9484]],
[[0.8520]],
[[0.7237]],
[[0.5425]],
[[0.5774]],
[[0.5975]],
[[0.6685]],
[[0.4472]],
[[0.8520]],
[[0.7879]],
[[0.8520]],
[[1.0322]],
[[0.4550]],
[[1.0612]],
]
]
).type(
torch.FloatTensor
)
}
b_act.load_state_dict(checkpoint)
bo.export_finn_onnx(b_act, ishape, export_onnx_path)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment