Commit a9d7d896 authored by Yaman Umuroglu

use Brevitas ONNX export functions, add some assertions

parent 458a161e
+import os
 from functools import reduce
 from operator import mul
 import onnx
+import onnx.numpy_helper as nph
 import torch
 import torch.onnx
 from models.common import get_act_quant, get_quant_linear, get_quant_type, get_stats_op
-from torch import nn
-from torch.autograd import Function
 from torch.nn import BatchNorm1d, Dropout, Module, ModuleList
@@ -70,90 +70,33 @@ def test_brevitas_to_onnx_export():
         out = self.fc(x)
         return out
-class objdict(dict):
-    def __getattr__(self, name):
-        if name in self:
-            return self[name]
-        else:
-            raise AttributeError("No such attribute: " + name)
-
-    def __setattr__(self, name, value):
-        self[name] = value
-
-    def __delattr__(self, name):
-        if name in self:
-            del self[name]
-        else:
-            raise AttributeError("No such attribute: " + name)
-
-
-# TODO: <all this needs to move into Brevitas>
-quantization_annotation = dict()
-
-
-class QuantizedLinearPlaceholderFunction(Function):
-    @staticmethod
-    def symbolic(g, W, x, bw, out_features):
-        # record the weight bitwidth for the quantized weight tensor
-        quantization_annotation[W.uniqueName()] = str(bw)
-        return g.op("MatMul", W, x, domain_s="finn")
-
-    @staticmethod
-    def forward(ctx, W, x, bw, out_features):
-        return torch.empty(1, out_features, dtype=torch.float)
-
-
-class QuantizedLinearPlaceholder(nn.Module):
-    def __init__(self, quantized_linear):
-        super(QuantizedLinearPlaceholder, self).__init__()
-        self.in_features = quantized_linear.in_features
-        self.out_features = quantized_linear.out_features
-        # compute the quantized weights
-        W, s, bitwidth = quantized_linear.weight_quant(quantized_linear.weight)
-        W = W.detach().numpy().reshape(self.out_features, self.in_features)
-        s = s.detach().numpy()
-        s = s.reshape(s.size, 1)
-        W = W / s
-        self.W = torch.from_numpy(W)
-        self.bitwidth = bitwidth.item()
-
-    def forward(self, x):
-        return QuantizedLinearPlaceholderFunction.apply(
-            self.W, x, self.bitwidth, self.out_features
-        )
-
-
-class QuantizedHardTanhPlaceholderFunction(Function):
-    @staticmethod
-    def symbolic(g, input):
-        ret = g.op("QuantizedHardTanh", input, domain_s="finn")
-        # insert quantization annotation for the resulting tensor, TODO fix bitwidth
-        quantization_annotation[ret.uniqueName()] = "1"
-        return ret
-
-    @staticmethod
-    def forward(ctx, input):
-        return input.clamp(0)
-
-
-class QuantizedHardTanhPlaceholder(nn.Module):
-    def __init__(self):
-        super(QuantizedHardTanhPlaceholder, self).__init__()
-
-    def forward(self, x):
-        return QuantizedHardTanhPlaceholderFunction.apply(x)
-
-
-# TODO: </all this needs to move into Brevitas>
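
The placeholder classes removed above follow the standard torch.autograd.Function pattern for emitting custom ONNX ops: symbolic() is called by torch.onnx.export during tracing and describes the node to emit (here into a custom "finn" domain), while forward() only has to return a tensor of the right shape so tracing can proceed. A minimal self-contained sketch of the same pattern, with a hypothetical QuantReLU op name and output filename:

import torch
from torch.autograd import Function


class QuantReLUFunction(Function):
    # symbolic() runs during torch.onnx.export and emits one custom node,
    # mirroring the deleted placeholders' g.op(..., domain_s="finn") calls.
    @staticmethod
    def symbolic(g, x):
        return g.op("QuantReLU", x, domain_s="finn")

    # forward() only needs a correctly-shaped result so tracing continues;
    # its values are irrelevant to the exported graph structure.
    @staticmethod
    def forward(ctx, x):
        return x.clamp(min=0)


class QuantReLU(torch.nn.Module):
    def forward(self, x):
        return QuantReLUFunction.apply(x)


torch.onnx.export(QuantReLU(), torch.empty(1, 4), "quantrelu_sketch.onnx")

This is the mechanism the commit moves behind bo.prepare_for_onnx_export, so the test no longer has to patch modules by hand.
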
 export_onnx_path = "test_output_lfc.onnx"
-lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1)
-for i in range(len(lfc.features)):
-    L = lfc.features[i]
-    if type(L).__name__ == "QuantLinear":
-        lfc.features[i] = QuantizedLinearPlaceholder(L)
-    elif type(L).__name__ == "QuantHardTanh":
-        lfc.features[i] = QuantizedHardTanhPlaceholder()
-lfc.fc = QuantizedLinearPlaceholder(lfc.fc)
-torch.onnx.export(lfc, torch.empty(784, dtype=torch.float), export_onnx_path)
-model = onnx.load(export_onnx_path)
-assert len(model.graph.input) == 16
-assert len(model.graph.node) == 33
-assert len(model.graph.output) == 1
-assert model.graph.output[0].type.tensor_type.shape.dim[1].dim_value == 10
-assert model.graph.node[13].op_type == "Constant"
-assert model.graph.node[12].op_type == "QuantizedHardTanh"
+with torch.no_grad():
+    lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1)
+    import brevitas.onnx as bo
+    bo.prepare_for_onnx_export(lfc, True)
+    torch.onnx.export(
+        lfc, torch.empty(784, dtype=torch.float), export_onnx_path, verbose=True
+    )
+    model = onnx.load(export_onnx_path)
+    # TODO the following way of testing is highly sensitive to small changes
+    # in PyTorch ONNX export: the order, names, count... of nodes could
+    # easily change between different versions, and break this test.
+    assert len(model.graph.input) == 32
+    assert len(model.graph.node) == 33
+    assert len(model.graph.output) == 1
+    assert model.graph.output[0].type.tensor_type.shape.dim[1].dim_value == 10
+    assert model.graph.node[12].op_type == "QuantizedHardTanh"
+    assert model.graph.node[13].op_type == "Constant"
+    assert model.graph.node[14].op_type == "MatMul"
+    assert model.graph.node[12].output[0] == model.graph.node[14].input[1]
+    assert model.graph.node[13].output[0] == model.graph.node[14].input[0]
+    int_weights_pytorch = lfc.features[2].int_weight.detach().numpy()
+    int_weights_onnx = nph.to_array(model.graph.node[13].attribute[0].t)
+    assert (int_weights_onnx == int_weights_pytorch).all()
+    assert model.graph.node[12].attribute[0].name == "activation_qnt"
+    assert model.graph.node[12].attribute[0].s.decode("utf-8") == "1"
+    assert model.graph.node[14].attribute[1].name == "weight_qnt"
+    assert model.graph.node[14].attribute[1].s.decode("utf-8") == "1"
+    os.remove(export_onnx_path)
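
As the new TODO notes, the assertions are tied to the exact node ordering a given PyTorch version produces. A quick way to check where indices like node[12] and node[14] come from is to dump the graph after export; a minimal sketch using only the onnx APIs already imported by the test, run before the final os.remove deletes test_output_lfc.onnx:

import onnx

# List each node's index, op type and quantization annotations so the
# hard-coded indices in the assertions can be verified by eye.
model = onnx.load("test_output_lfc.onnx")
for i, node in enumerate(model.graph.node):
    annotations = {
        a.name: a.s.decode("utf-8")
        for a in node.attribute
        if a.name in ("activation_qnt", "weight_qnt")
    }
    print(i, node.op_type, annotations)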