diff --git a/notebooks/brevitas-network-import.ipynb b/notebooks/brevitas-network-import.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..cbe494e3165983ddc170989196c5c133e84b168b --- /dev/null +++ b/notebooks/brevitas-network-import.ipynb @@ -0,0 +1,167 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Importing Brevitas networks into FINN" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import inspect\n", + "\n", + "def showSrc(what):\n", + " print(\"\".join(inspect.getsourcelines(what)[0]))" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "class LFC(Module):\n", + "\n", + " def __init__(self, num_classes=10, weight_bit_width=None, act_bit_width=None,\n", + " in_bit_width=None, in_ch=1, in_features=(28, 28)):\n", + " super(LFC, self).__init__()\n", + "\n", + " weight_quant_type = get_quant_type(weight_bit_width)\n", + " act_quant_type = get_quant_type(act_bit_width)\n", + " in_quant_type = get_quant_type(in_bit_width)\n", + " stats_op = get_stats_op(weight_quant_type)\n", + "\n", + " self.features = ModuleList()\n", + " self.features.append(get_act_quant(in_bit_width, in_quant_type))\n", + " self.features.append(Dropout(p=IN_DROPOUT))\n", + " in_features = reduce(mul, in_features)\n", + " for out_features in FC_OUT_FEATURES:\n", + " self.features.append(get_quant_linear(in_features=in_features,\n", + " out_features=out_features,\n", + " per_out_ch_scaling=INTERMEDIATE_FC_PER_OUT_CH_SCALING,\n", + " bit_width=weight_bit_width,\n", + " quant_type=weight_quant_type,\n", + " stats_op=stats_op))\n", + " in_features = out_features\n", + " self.features.append(BatchNorm1d(num_features=in_features))\n", + " self.features.append(get_act_quant(act_bit_width, act_quant_type))\n", + " self.features.append(Dropout(p=HIDDEN_DROPOUT))\n", + " self.fc = get_quant_linear(in_features=in_features,\n", + " out_features=num_classes,\n", + " per_out_ch_scaling=LAST_FC_PER_OUT_CH_SCALING,\n", + " bit_width=weight_bit_width,\n", + " quant_type=weight_quant_type,\n", + " stats_op=stats_op)\n", + "\n", + " def forward(self, x):\n", + " x = x.view(x.shape[0], -1)\n", + " x = 2.0 * x - torch.tensor([1.0])\n", + " for mod in self.features:\n", + " x = mod(x)\n", + " out = self.fc(x)\n", + " return out\n", + "\n" + ] + } + ], + "source": [ + "from models.LFC import LFC\n", + "showSrc(LFC)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: netron in /opt/conda/lib/python3.6/site-packages (3.5.9)\r\n" + ] + } + ], + "source": [ + "!pip install --user netron" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Serving 'LFCW1A1.onnx' at http://0.0.0.0:8081\n" + ] + } + ], + "source": [ + "import netron\n", + "netron.start('LFCW1A1.onnx', port=8081, host=\"0.0.0.0\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<iframe src=\"http://0.0.0.0:8081/\" style=\"position: relative; width: 100%;\" height=\"400\"></iframe>\n" + ], + "text/plain": [ + "<IPython.core.display.HTML object>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%%html\n", + "<iframe src=\"http://0.0.0.0:8081/\" style=\"position: relative; width: 100%;\" height=\"400\"></iframe>" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}