diff --git a/tests/test_brevitas_export.py b/tests/test_brevitas_export.py
index c464fc21fc30b8b5b0a7a8df91f413e947c9fc08..d7e1c5e732e29be300e17620d3ca5ea792c5c477 100644
--- a/tests/test_brevitas_export.py
+++ b/tests/test_brevitas_export.py
@@ -1,12 +1,12 @@
+import os
 from functools import reduce
 from operator import mul
 
 import onnx
+import onnx.numpy_helper as nph
 import torch
 import torch.onnx
 from models.common import get_act_quant, get_quant_linear, get_quant_type, get_stats_op
-from torch import nn
-from torch.autograd import Function
 from torch.nn import BatchNorm1d, Dropout, Module, ModuleList
 
 
@@ -70,90 +70,33 @@ def test_brevitas_to_onnx_export():
             out = self.fc(x)
             return out
 
-    class objdict(dict):
-        def __getattr__(self, name):
-            if name in self:
-                return self[name]
-            else:
-                raise AttributeError("No such attribute: " + name)
-
-        def __setattr__(self, name, value):
-            self[name] = value
-
-        def __delattr__(self, name):
-            if name in self:
-                del self[name]
-            else:
-                raise AttributeError("No such attribute: " + name)
-
-    # TODO: <all this needs to mvoe into Brevitas>
-    quantization_annotation = dict()
-
-    class QuantizedLinearPlaceholderFunction(Function):
-        @staticmethod
-        def symbolic(g, W, x, bw, out_features):
-            # import pdb; pdb.set_trace()
-            quantization_annotation[W.uniqueName()] = str(bw)
-            return g.op("MatMul", W, x, domain_s="finn")
-
-        @staticmethod
-        def forward(ctx, W, x, bw, out_features):
-            return torch.empty(1, out_features, dtype=torch.float)
-
-    class QuantizedLinearPlaceholder(nn.Module):
-        def __init__(self, quantized_linear):
-            super(QuantizedLinearPlaceholder, self).__init__()
-            self.in_features = quantized_linear.in_features
-            self.out_features = quantized_linear.out_features
-            # compute the quantized weights
-            W, s, bitwidth = quantized_linear.weight_quant(quantized_linear.weight)
-            W = W.detach().numpy().reshape(self.out_features, self.in_features)
-            s = s.detach().numpy()
-            s = s.reshape(s.size, 1)
-            W = W / s
-            self.W = torch.from_numpy(W)
-            self.bitwidth = bitwidth.item()
-
-        def forward(self, x):
-            # return linear(self.W, x)
-            return QuantizedLinearPlaceholderFunction.apply(
-                self.W, x, self.bitwidth, self.out_features
-            )
-
-    class QuantizedHardTanhPlaceholderFunction(Function):
-        @staticmethod
-        def symbolic(g, input):
-            ret = g.op("QuantizedHardTanh", input, domain_s="finn")
-            # insert quantization annotation for the resulting tensor, TODO fix bitwidth
-            quantization_annotation[ret.uniqueName()] = "1"
-            return ret
-
-        @staticmethod
-        def forward(ctx, input):
-            return input.clamp(0)
-
-    class QuantizedHardTanhPlaceholder(nn.Module):
-        def __init__(self):
-            super(QuantizedHardTanhPlaceholder, self).__init__()
-
-        def forward(self, x):
-            return QuantizedHardTanhPlaceholderFunction.apply(x)
-
-    # TODO: </all this needs to mvoe into Brevitas>
     export_onnx_path = "test_output_lfc.onnx"
-    lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1)
-    for i in range(len(lfc.features)):
-        L = lfc.features[i]
-        if type(L).__name__ == "QuantLinear":
-            lfc.features[i] = QuantizedLinearPlaceholder(L)
-        elif type(L).__name__ == "QuantHardTanh":
-            lfc.features[i] = QuantizedHardTanhPlaceholder()
-    lfc.fc = QuantizedLinearPlaceholder(lfc.fc)
-    torch.onnx.export(lfc, torch.empty(784, dtype=torch.float), export_onnx_path)
-    model = onnx.load(export_onnx_path)
-    assert len(model.graph.input) == 16
-    assert len(model.graph.node) == 33
-    assert len(model.graph.output) == 1
-    assert model.graph.output[0].type.tensor_type.shape.dim[1].dim_value == 10
-    assert model.graph.node[13].op_type == "Constant"
-    assert model.graph.node[12].op_type == "QuantizedHardTanh"
+    with torch.no_grad():
+        lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1)
+        import brevitas.onnx as bo
+
+        bo.prepare_for_onnx_export(lfc, True)
+        torch.onnx.export(
+            lfc, torch.empty(784, dtype=torch.float), export_onnx_path, verbose=True
+        )
+        model = onnx.load(export_onnx_path)
+        # TODO the following way of testing is highly sensitive to small changes
+        # in PyTorch ONNX export: the order, names, count... of nodes could
+        # easily change between different versions, and break this test.
+        assert len(model.graph.input) == 32
+        assert len(model.graph.node) == 33
+        assert len(model.graph.output) == 1
+        assert model.graph.output[0].type.tensor_type.shape.dim[1].dim_value == 10
+        assert model.graph.node[12].op_type == "QuantizedHardTanh"
+        assert model.graph.node[13].op_type == "Constant"
+        assert model.graph.node[14].op_type == "MatMul"
+        assert model.graph.node[12].output[0] == model.graph.node[14].input[1]
+        assert model.graph.node[13].output[0] == model.graph.node[14].input[0]
+        int_weights_pytorch = lfc.features[2].int_weight.detach().numpy()
+        int_weights_onnx = nph.to_array(model.graph.node[13].attribute[0].t)
+        assert (int_weights_onnx == int_weights_pytorch).all()
+        assert model.graph.node[12].attribute[0].name == "activation_qnt"
+        assert model.graph.node[12].attribute[0].s.decode("utf-8") == "1"
+        assert model.graph.node[14].attribute[1].name == "weight_qnt"
+        assert model.graph.node[14].attribute[1].s.decode("utf-8") == "1"
+        os.remove(export_onnx_path)
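The new assertions read quantization metadata straight off the ONNX protobuf: integer weights come from the `TensorProto` payload of `Constant` nodes, and bit widths from string attributes such as `weight_qnt` / `activation_qnt`. Below is a minimal sketch of that inspection pattern, assuming a `test_output_lfc.onnx` produced by the export above (the test deletes the file at the end, so keep a copy or skip the `os.remove()` to try this); as the TODO in the test warns, node order and attribute layout may shift between PyTorch/Brevitas versions.

```python
import onnx
import onnx.numpy_helper as nph

# Assumed to exist on disk: an export from the test above, kept around
# by omitting the final os.remove(export_onnx_path).
model = onnx.load("test_output_lfc.onnx")

for node in model.graph.node:
    if node.op_type == "Constant":
        # Constant nodes carry their payload as a TensorProto attribute;
        # numpy_helper.to_array converts it to an ndarray (here: the
        # integer weight matrices compared against lfc.*.int_weight).
        w = nph.to_array(node.attribute[0].t)
        print(node.output[0], "Constant", w.shape, w.dtype)
    else:
        # Bit-width annotations ride along as string attributes, e.g.
        # "weight_qnt" on MatMul and "activation_qnt" on QuantizedHardTanh.
        qnt = {a.name: a.s.decode("utf-8")
               for a in node.attribute if a.name.endswith("_qnt")}
        print(node.op_type, qnt if qnt else "(no quantization annotation)")
```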