Skip to content
Snippets Groups Projects
Commit 6d7026ee authored by Hendrik Borras's avatar Hendrik Borras
Browse files

Removed reliance on ScaledInt data type.

parent 5f3cf652
No related branches found
No related tags found
No related merge requests found
...@@ -86,7 +86,7 @@ RUN pip install -e git+https://github.com/fbcotter/dataset_loading.git@0.0.4#egg ...@@ -86,7 +86,7 @@ RUN pip install -e git+https://github.com/fbcotter/dataset_loading.git@0.0.4#egg
# git-based Python repo dependencies # git-based Python repo dependencies
# these are installed in editable mode for easier co-development # these are installed in editable mode for easier co-development
ARG FINN_BASE_COMMIT="ec3997c3f4276f7746bfd08a1a9508bd02a132fa" ARG FINN_BASE_COMMIT="aefd5c8779db22d8c7a60c53cc32179f0c2ced67"
ARG QONNX_COMMIT="834610ba3f668971fe2800fde7f8d0c10d825d5b" ARG QONNX_COMMIT="834610ba3f668971fe2800fde7f8d0c10d825d5b"
ARG FINN_EXP_COMMIT="f82c0d9868bb88ea045dfadb28508d327d287221" ARG FINN_EXP_COMMIT="f82c0d9868bb88ea045dfadb28508d327d287221"
ARG BREVITAS_COMMIT="0eaff006407955153594254728baeb988edcd042" ARG BREVITAS_COMMIT="0eaff006407955153594254728baeb988edcd042"
......
...@@ -30,25 +30,11 @@ import numpy as np ...@@ -30,25 +30,11 @@ import numpy as np
from onnx import TensorProto, helper from onnx import TensorProto, helper
import finn.core.onnx_exec as oxe import finn.core.onnx_exec as oxe
from finn.core.datatype import DataType
from finn.custom_op.registry import getCustomOp from finn.custom_op.registry import getCustomOp
from finn.transformation.base import Transformation from finn.transformation.base import Transformation
from finn.transformation.infer_shapes import InferShapes from finn.transformation.infer_shapes import InferShapes
def get_dtype(bit_width: int, signed: bool) -> DataType:
    """Map a quantizer bit width and signedness to a FINN DataType.

    A 1-bit quantizer maps to BIPOLAR; wider quantizers map to INT<w>
    when signed and UINT<w> otherwise.

    :param bit_width: Quantizer bit width; coerced to int because it may
        arrive as a float read from an ONNX initializer.
    :param signed: Whether the quantizer emits signed values; coerced to bool.
    :return: The corresponding FINN DataType entry.
    """
    bit_width = int(bit_width)
    signed = bool(signed)
    # Compare against the int 1, not the float 1.0, since bit_width was
    # just cast to int above.
    if bit_width == 1:
        return DataType["BIPOLAR"]
    prefix = "INT" if signed else "UINT"
    return DataType[prefix + str(bit_width)]
class FoldQuantWeights(Transformation): class FoldQuantWeights(Transformation):
"""Merges Quant nodes, which are used as weights into the initializer """Merges Quant nodes, which are used as weights into the initializer
of the weight tensor. of the weight tensor.
...@@ -123,16 +109,8 @@ class FoldQuantWeights(Transformation): ...@@ -123,16 +109,8 @@ class FoldQuantWeights(Transformation):
# Round, to correct for floating point errors # Round, to correct for floating point errors
new_initializer = np.round(new_initializer) new_initializer = np.round(new_initializer)
model.set_initializer(node_out, new_initializer) model.set_initializer(node_out, new_initializer)
if n.op_type == "Quant": q_inst = getCustomOp(n)
bit_width = model.get_initializer(n.input[3]) new_dtype = q_inst.get_internal_dtype(model)
q_inst = getCustomOp(n)
signed = q_inst.get_nodeattr("signed")
elif n.op_type == "BinaryQuant":
bit_width = 1.0
signed = True
else:
raise RuntimeError("Got an unexpected quantizer node type")
new_dtype = get_dtype(bit_width, signed)
model.set_tensor_datatype(node_out, new_dtype) model.set_tensor_datatype(node_out, new_dtype)
if target_node.op_type == "Conv" and len(scale.shape) > 0: if target_node.op_type == "Conv" and len(scale.shape) > 0:
......
...@@ -95,9 +95,9 @@ class QuantActBaseHandler(ABC): ...@@ -95,9 +95,9 @@ class QuantActBaseHandler(ABC):
def _extract_output_datatype(self): def _extract_output_datatype(self):
"""Get the output datatype for the MultiThreshold node.""" """Get the output datatype for the MultiThreshold node."""
dtype = self._model.get_tensor_datatype(self._q_node.output[0]).name q_inst = getCustomOp(self._q_node)
if dtype is not None: dtype = q_inst.get_internal_dtype(self._model)
dtype = dtype.replace("SCALED", "") dtype = dtype.name
return dtype return dtype
def calculate_node_parameters(self): def calculate_node_parameters(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment