diff --git a/tests/test_brevitas_export.py b/tests/test_brevitas_export.py
new file mode 100644
index 0000000000000000000000000000000000000000..c464fc21fc30b8b5b0a7a8df91f413e947c9fc08
--- /dev/null
+++ b/tests/test_brevitas_export.py
@@ -0,0 +1,159 @@
+from functools import reduce
+from operator import mul
+
+import onnx
+import torch
+import torch.onnx
+from models.common import get_act_quant, get_quant_linear, get_quant_type, get_stats_op
+from torch import nn
+from torch.autograd import Function
+from torch.nn import BatchNorm1d, Dropout, Module, ModuleList
+
+
+def test_brevitas_to_onnx_export():
+    FC_OUT_FEATURES = [1024, 1024, 1024]
+    INTERMEDIATE_FC_PER_OUT_CH_SCALING = True
+    LAST_FC_PER_OUT_CH_SCALING = False
+    IN_DROPOUT = 0.2
+    HIDDEN_DROPOUT = 0.2
+
+    class LFC(Module):
+        def __init__(
+            self,
+            num_classes=10,
+            weight_bit_width=None,
+            act_bit_width=None,
+            in_bit_width=None,
+            in_ch=1,
+            in_features=(28, 28),
+        ):
+            super(LFC, self).__init__()
+
+            weight_quant_type = get_quant_type(weight_bit_width)
+            act_quant_type = get_quant_type(act_bit_width)
+            in_quant_type = get_quant_type(in_bit_width)
+            stats_op = get_stats_op(weight_quant_type)
+
+            self.features = ModuleList()
+            self.features.append(get_act_quant(in_bit_width, in_quant_type))
+            self.features.append(Dropout(p=IN_DROPOUT))
+            in_features = reduce(mul, in_features)
+            for out_features in FC_OUT_FEATURES:
+                self.features.append(
+                    get_quant_linear(
+                        in_features=in_features,
+                        out_features=out_features,
+                        per_out_ch_scaling=INTERMEDIATE_FC_PER_OUT_CH_SCALING,
+                        bit_width=weight_bit_width,
+                        quant_type=weight_quant_type,
+                        stats_op=stats_op,
+                    )
+                )
+                in_features = out_features
+                self.features.append(BatchNorm1d(num_features=in_features))
+                self.features.append(get_act_quant(act_bit_width, act_quant_type))
+                self.features.append(Dropout(p=HIDDEN_DROPOUT))
+            self.fc = get_quant_linear(
+                in_features=in_features,
+                out_features=num_classes,
+                per_out_ch_scaling=LAST_FC_PER_OUT_CH_SCALING,
+                bit_width=weight_bit_width,
+                quant_type=weight_quant_type,
+                stats_op=stats_op,
+            )
+
+        def forward(self, x):
+            x = 2.0 * x - 1.0
+            x = x.view(x.shape[0], -1)
+            for mod in self.features:
+                x = mod(x)
+            out = self.fc(x)
+            return out
+
+    class objdict(dict):
+        def __getattr__(self, name):
+            if name in self:
+                return self[name]
+            else:
+                raise AttributeError("No such attribute: " + name)
+
+        def __setattr__(self, name, value):
+            self[name] = value
+
+        def __delattr__(self, name):
+            if name in self:
+                del self[name]
+            else:
+                raise AttributeError("No such attribute: " + name)
+
+    # TODO: <all this needs to move into Brevitas>
+    quantization_annotation = dict()
+
+    class QuantizedLinearPlaceholderFunction(Function):
+        @staticmethod
+        def symbolic(g, W, x, bw, out_features):
+            # import pdb; pdb.set_trace()
+            quantization_annotation[W.uniqueName()] = str(bw)
+            return g.op("MatMul", W, x, domain_s="finn")
+
+        @staticmethod
+        def forward(ctx, W, x, bw, out_features):
+            return torch.empty(1, out_features, dtype=torch.float)
+
+    class QuantizedLinearPlaceholder(nn.Module):
+        def __init__(self, quantized_linear):
+            super(QuantizedLinearPlaceholder, self).__init__()
+            self.in_features = quantized_linear.in_features
+            self.out_features = quantized_linear.out_features
+            # compute the quantized weights
+            W, s, bitwidth = quantized_linear.weight_quant(quantized_linear.weight)
+            W = W.detach().numpy().reshape(self.out_features, self.in_features)
+            s = s.detach().numpy()
+            s = s.reshape(s.size, 1)
+            W = W / s
+            self.W = torch.from_numpy(W)
+            self.bitwidth = bitwidth.item()
+
+        def forward(self, x):
+            # return linear(self.W, x)
+            return QuantizedLinearPlaceholderFunction.apply(
+                self.W, x, self.bitwidth, self.out_features
+            )
+
+    class QuantizedHardTanhPlaceholderFunction(Function):
+        @staticmethod
+        def symbolic(g, input):
+            ret = g.op("QuantizedHardTanh", input, domain_s="finn")
+            # insert quantization annotation for the resulting tensor, TODO fix bitwidth
+            quantization_annotation[ret.uniqueName()] = "1"
+            return ret
+
+        @staticmethod
+        def forward(ctx, input):
+            return input.clamp(0)
+
+    class QuantizedHardTanhPlaceholder(nn.Module):
+        def __init__(self):
+            super(QuantizedHardTanhPlaceholder, self).__init__()
+
+        def forward(self, x):
+            return QuantizedHardTanhPlaceholderFunction.apply(x)
+
+    # TODO: </all this needs to move into Brevitas>
+    export_onnx_path = "test_output_lfc.onnx"
+    lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1)
+    for i in range(len(lfc.features)):
+        L = lfc.features[i]
+        if type(L).__name__ == "QuantLinear":
+            lfc.features[i] = QuantizedLinearPlaceholder(L)
+        elif type(L).__name__ == "QuantHardTanh":
+            lfc.features[i] = QuantizedHardTanhPlaceholder()
+    lfc.fc = QuantizedLinearPlaceholder(lfc.fc)
+    torch.onnx.export(lfc, torch.empty(784, dtype=torch.float), export_onnx_path)
+    model = onnx.load(export_onnx_path)
+    assert len(model.graph.input) == 16
+    assert len(model.graph.node) == 33
+    assert len(model.graph.output) == 1
+    assert model.graph.output[0].type.tensor_type.shape.dim[1].dim_value == 10
+    assert model.graph.node[13].op_type == "Constant"
+    assert model.graph.node[12].op_type == "QuantizedHardTanh"