Skip to content
Snippets Groups Projects
Commit 0fda042a authored by Hendrik Borras's avatar Hendrik Borras
Browse files

Updated finn-base commit and refactored changed function names.

parent 4aa96eba
No related branches found
No related tags found
No related merge requests found
...@@ -86,7 +86,7 @@ RUN pip install -e git+https://github.com/fbcotter/dataset_loading.git@0.0.4#egg ...@@ -86,7 +86,7 @@ RUN pip install -e git+https://github.com/fbcotter/dataset_loading.git@0.0.4#egg
# git-based Python repo dependencies # git-based Python repo dependencies
# these are installed in editable mode for easier co-development # these are installed in editable mode for easier co-development
ARG FINN_BASE_COMMIT="831073dc94d8788197e0e9a28e989bfcb9248045" ARG FINN_BASE_COMMIT="686639de94f96f02fef794a693a295591d3f25c6"
ARG QONNX_COMMIT="6d55dce220c7f744ef23585686460b9370b672a0" ARG QONNX_COMMIT="6d55dce220c7f744ef23585686460b9370b672a0"
ARG FINN_EXP_COMMIT="f82c0d9868bb88ea045dfadb28508d327d287221" ARG FINN_EXP_COMMIT="f82c0d9868bb88ea045dfadb28508d327d287221"
ARG BREVITAS_COMMIT="5b22551871d25bf5e26917fe5900fcaa49406faf" ARG BREVITAS_COMMIT="5b22551871d25bf5e26917fe5900fcaa49406faf"
......
...@@ -40,7 +40,7 @@ from finn.transformation.qonnx.quant_act_to_multithreshold import ( ...@@ -40,7 +40,7 @@ from finn.transformation.qonnx.quant_act_to_multithreshold import (
ConvertQuantActToMultiThreshold, ConvertQuantActToMultiThreshold,
default_filter_function_generator, default_filter_function_generator,
) )
from finn.transformation.remove import RemoveEmptyPadding from finn.transformation.remove import RemoveIdentityOps
class ConvertQONNXtoFINN(Transformation): class ConvertQONNXtoFINN(Transformation):
...@@ -94,6 +94,6 @@ class ConvertQONNXtoFINN(Transformation): ...@@ -94,6 +94,6 @@ class ConvertQONNXtoFINN(Transformation):
# Convert AvgPool -> Mul -> Trunc structure to QuantAvgPool2d # Convert AvgPool -> Mul -> Trunc structure to QuantAvgPool2d
model = model.transform(AvgPoolAndTruncToQuantAvgPool()) model = model.transform(AvgPoolAndTruncToQuantAvgPool())
# Remove empty padding if it exists # Remove empty padding if it exists
model = model.transform(RemoveEmptyPadding()) model = model.transform(RemoveIdentityOps())
return model, False return model, False
...@@ -123,7 +123,7 @@ class FoldQuantWeights(Transformation): ...@@ -123,7 +123,7 @@ class FoldQuantWeights(Transformation):
new_initializer = np.round(new_initializer) new_initializer = np.round(new_initializer)
model.set_initializer(node_out, new_initializer) model.set_initializer(node_out, new_initializer)
q_inst = getCustomOp(n) q_inst = getCustomOp(n)
new_dtype = q_inst.get_internal_dtype(model) new_dtype = q_inst.get_integer_datatype(model)
model.set_tensor_datatype(node_out, new_dtype) model.set_tensor_datatype(node_out, new_dtype)
# Reshape scale for Conv if required # Reshape scale for Conv if required
......
...@@ -96,7 +96,7 @@ class QuantActBaseHandler(ABC): ...@@ -96,7 +96,7 @@ class QuantActBaseHandler(ABC):
def _extract_output_datatype(self): def _extract_output_datatype(self):
"""Get the output datatype for the MultiThreshold node.""" """Get the output datatype for the MultiThreshold node."""
q_inst = getCustomOp(self._q_node) q_inst = getCustomOp(self._q_node)
dtype = q_inst.get_internal_dtype(self._model) dtype = q_inst.get_integer_datatype(self._model)
dtype = dtype.name dtype = dtype.name
return dtype return dtype
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment