Skip to content
Snippets Groups Projects
Unverified Commit a13faced authored by auphelia's avatar auphelia Committed by GitHub
Browse files

Merge pull request #646 from Xilinx/fix/qonnx_quant_conv_bias

Improvements to QONNX -> FINN-ONNX conversion
parents 1aca8432 6495cf87
No related branches found
No related tags found
No related merge requests found
...@@ -126,10 +126,20 @@ class FoldQuantWeights(Transformation): ...@@ -126,10 +126,20 @@ class FoldQuantWeights(Transformation):
model.set_tensor_datatype(node_out, new_dtype) model.set_tensor_datatype(node_out, new_dtype)
# Reshape scale for Conv if required # Reshape scale for Conv if required
target_output_shape = model.get_tensor_shape(
target_node.output[0]
)
if target_node.op_type == "Conv" and len(scale.shape) > 0: if target_node.op_type == "Conv" and len(scale.shape) > 0:
bias_shape = [1] * len(scale.shape) conv_out_shape = [1] * len(target_output_shape)
bias_shape[1] = -1 # only support per-output channel scaling
scale = scale.reshape(bias_shape) # (i.e. all scale shape elems besides 0th must be 1s)
if len(scale.shape) > 1:
assert (
np.prod(scale.shape[1:]) == 1
), "Can't fold scale beyond per-out-channel granularity"
# collect all scaling in channels dim (since we constrain)
conv_out_shape[1] = -1
scale = scale.reshape(conv_out_shape)
if scale.shape == (1,): if scale.shape == (1,):
scale = scale[0] scale = scale[0]
......
...@@ -110,11 +110,6 @@ class ConvertQuantActToMultiThreshold(Transformation): ...@@ -110,11 +110,6 @@ class ConvertQuantActToMultiThreshold(Transformation):
predecessor_op_type = predecessor[0].op_type predecessor_op_type = predecessor[0].op_type
else: else:
predecessor_op_type = predecessor predecessor_op_type = predecessor
if model.is_fork_node(n):
raise ValueError(
"Forking Quant/BipolarQuant nodes are currently "
"not supported by FINN."
)
if n.op_type == "Quant" and not model.get_initializer(n.input[2]) == 0: if n.op_type == "Quant" and not model.get_initializer(n.input[2]) == 0:
raise ValueError( raise ValueError(
"Only Quant nodes with zero-point == 0 are currently supported." "Only Quant nodes with zero-point == 0 are currently supported."
......
...@@ -41,6 +41,7 @@ from brevitas.nn import QuantReLU ...@@ -41,6 +41,7 @@ from brevitas.nn import QuantReLU
from qonnx.core.modelwrapper import ModelWrapper from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.infer_shapes import InferShapes from qonnx.transformation.infer_shapes import InferShapes
from qonnx.util.cleanup import cleanup as qonnx_cleanup from qonnx.util.cleanup import cleanup as qonnx_cleanup
from torch import nn
import finn.core.onnx_exec as oxe import finn.core.onnx_exec as oxe
from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN
...@@ -179,3 +180,83 @@ scaling_impl.learned_value": rand_tensor.type( ...@@ -179,3 +180,83 @@ scaling_impl.learned_value": rand_tensor.type(
assert np.isclose(produced, expected, atol=1e-3).all() assert np.isclose(produced, expected, atol=1e-3).all()
os.remove(export_onnx_path) os.remove(export_onnx_path)
class PyTorchTestModel(nn.Module):
    """Small network whose quantized activation output is consumed twice.

    A single per-channel ``QuantReLU`` feeds two separate multiplications
    whose results are then summed, so the activation output has more than
    one consumer in the forward computation.
    """

    def __init__(self, abits, out_channels=32):
        """Build the model.

        Args:
            abits: Bit width passed to ``QuantReLU`` as ``bit_width``.
            out_channels: Number of channels used for the per-channel
                scaling shape ``(1, out_channels, 1, 1)``. Defaults to 32,
                the value that was previously hard-coded.
        """
        super().__init__()
        self.b_act = QuantReLU(
            bit_width=abits,
            quant_type=QuantType.INT,
            scaling_impl_type=ScalingImplType.PARAMETER,
            scaling_per_channel=True,
            restrict_scaling_type=RestrictValueType.LOG_FP,
            scaling_min_val=2e-16,
            max_val=6.0,
            return_quant_tensor=False,
            per_channel_broadcastable_shape=(1, out_channels, 1, 1),
        )

    def forward(self, x):
        """Return ``2.0 * act(x) + (-1.0) * act(x)`` for the quantized activation."""
        act_out = self.b_act(x)
        # Two independent consumers of the same activation output.
        y0 = act_out * 2.0
        y1 = act_out * -1.0
        y = y0 + y1
        return y
@pytest.mark.brevitas_export
@pytest.mark.parametrize("abits", [2, 4, 8])
@pytest.mark.parametrize("max_val", [1.0, 1.5, 1 - 2 ** (-7)])
@pytest.mark.parametrize("scaling_per_channel", [True])
@pytest.mark.parametrize("QONNX_export", [True])
def test_brevitas_act_export_relu_forking(
    abits, max_val, scaling_per_channel, QONNX_export
):
    """Check QONNX -> FINN-ONNX conversion when a QuantReLU output is used twice.

    The PyTorch model is exported, converted with ``ConvertQONNXtoFINN``,
    executed on a random input, and compared against the PyTorch forward pass
    with an absolute tolerance of 1e-3.

    Args:
        abits: Activation bit width handed to ``PyTorchTestModel``.
        max_val: Upper bound of the uniformly sampled input values. It is not
            forwarded to the model, whose own ``max_val`` is fixed.
        scaling_per_channel: Parametrization placeholder; not read in the
            body (the model always uses per-channel scaling).
        QONNX_export: When true, export through ``BrevitasONNXManager`` and
            convert the result to FINN-ONNX.
    """
    out_channels = 32
    ishape = (1, out_channels, 1, 1)
    min_val = -1.0
    model_pyt = PyTorchTestModel(abits)

    # Overwrite the learned per-channel scale with random values in [0, 2).
    rand_tensor = (2) * torch.rand((1, out_channels, 1, 1))
    checkpoint = {
        "b_act.act_quant_proxy.fused_activation_quant_proxy."
        "tensor_quant.scaling_impl.learned_value": rand_tensor.type(torch.FloatTensor)
    }
    model_pyt.load_state_dict(checkpoint)

    if QONNX_export:
        m_path = export_onnx_path
        BrevitasONNXManager.export(model_pyt, ishape, m_path)
        qonnx_cleanup(m_path, out_file=m_path)
        model = ModelWrapper(m_path)
        model = model.transform(ConvertQONNXtoFINN())
        model.save(m_path)

    model = ModelWrapper(export_onnx_path)
    model = model.transform(InferShapes())
    inp_tensor = np.random.uniform(low=min_val, high=max_val, size=ishape).astype(
        np.float32
    )
    idict = {model.graph.input[0].name: inp_tensor}
    odict = oxe.execute_onnx(model, idict, True)
    produced = odict[model.graph.output[0].name]

    inp_tensor = torch.from_numpy(inp_tensor).float()
    model_pyt.eval()
    expected = model_pyt.forward(inp_tensor).detach().numpy()

    if not np.isclose(produced, expected, atol=1e-3).all():
        # Diagnostics only; they must not raise, otherwise the real mismatch
        # reported by the assertion below is hidden behind an unrelated error.
        print(abits, max_val)
        # The scale lives on the wrapped activation, not on the wrapper module.
        print(
            "scale: ",
            model_pyt.b_act.quant_act_scale().type(torch.FloatTensor).detach(),
        )
        # NOTE(review): export_thres may not be populated by this export
        # path - print it only when the activation actually carries it.
        thres = getattr(model_pyt.b_act, "export_thres", None)
        if abits < 5 and thres is not None:
            print(
                "thres:",
                ", ".join(["{:8.4f}".format(x) for x in thres[0].flatten()]),
            )
        # Tensors are (1, C, 1, 1): flatten so each formatted item is a scalar.
        print(
            "input:",
            ", ".join(["{:8.4f}".format(x) for x in inp_tensor[0].flatten()]),
        )
        print(
            "prod :",
            ", ".join(["{:8.4f}".format(x) for x in produced[0].flatten()]),
        )
        print(
            "expec:",
            ", ".join(["{:8.4f}".format(x) for x in expected[0].flatten()]),
        )

    assert np.isclose(produced, expected, atol=1e-3).all()
    os.remove(export_onnx_path)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment