Skip to content
Snippets Groups Projects
Commit 828782f0 authored by Hendrik Borras's avatar Hendrik Borras
Browse files

Renamed BinaryQuant to BipolarQuant.

parent 74d2b740
No related branches found
No related tags found
No related merge requests found
...@@ -86,10 +86,10 @@ RUN pip install -e git+https://github.com/fbcotter/dataset_loading.git@0.0.4#egg ...@@ -86,10 +86,10 @@ RUN pip install -e git+https://github.com/fbcotter/dataset_loading.git@0.0.4#egg
# git-based Python repo dependencies # git-based Python repo dependencies
# these are installed in editable mode for easier co-development # these are installed in editable mode for easier co-development
ARG FINN_BASE_COMMIT="535b27013de83ff36925f2996745b12c9ba64d23" ARG FINN_BASE_COMMIT="451752a9cffed18921e297d24c58d0cbac0c6834"
ARG QONNX_COMMIT="834610ba3f668971fe2800fde7f8d0c10d825d5b" ARG QONNX_COMMIT="834610ba3f668971fe2800fde7f8d0c10d825d5b"
ARG FINN_EXP_COMMIT="f82c0d9868bb88ea045dfadb28508d327d287221" ARG FINN_EXP_COMMIT="f82c0d9868bb88ea045dfadb28508d327d287221"
ARG BREVITAS_COMMIT="0eaff006407955153594254728baeb988edcd042" ARG BREVITAS_COMMIT="efc1217b94a71d616e3b4a37e56bd28a07c55be0"
ARG PYVERILATOR_COMMIT="0c3eb9343500fc1352a02c020a736c8c2db47e8e" ARG PYVERILATOR_COMMIT="0c3eb9343500fc1352a02c020a736c8c2db47e8e"
ARG CNPY_COMMIT="4e8810b1a8637695171ed346ce68f6984e585ef4" ARG CNPY_COMMIT="4e8810b1a8637695171ed346ce68f6984e585ef4"
ARG HLSLIB_COMMIT="fbb07135b3d991602e8abe3f2c51212c11fd392b" ARG HLSLIB_COMMIT="fbb07135b3d991602e8abe3f2c51212c11fd392b"
......
...@@ -48,7 +48,7 @@ class ConvertQONNXtoFINN(Transformation): ...@@ -48,7 +48,7 @@ class ConvertQONNXtoFINN(Transformation):
If incompatibilities are found a ValueError or RuntimeError is raised. If incompatibilities are found a ValueError or RuntimeError is raised.
The optional keyword arguments `max_multithreshold_bit_width` and `filter_lambda` The optional keyword arguments `max_multithreshold_bit_width` and `filter_lambda`
present a way to control which Quant and BinaryQuant nodes in the activation path present a way to control which Quant and BipolarQuant nodes in the activation path
are converted to MultiThreshold nodes. are converted to MultiThreshold nodes.
The filters which are represented by `max_multithreshold_bit_width` and The filters which are represented by `max_multithreshold_bit_width` and
`filter_lambda` are internally connected by an `AND` operation. A warning `filter_lambda` are internally connected by an `AND` operation. A warning
......
...@@ -47,7 +47,7 @@ class FoldQuantWeights(Transformation): ...@@ -47,7 +47,7 @@ class FoldQuantWeights(Transformation):
execution_context = model.make_empty_exec_context() execution_context = model.make_empty_exec_context()
for n in graph.node: for n in graph.node:
node_ind += 1 node_ind += 1
if n.op_type == "Quant" or n.op_type == "BinaryQuant": if n.op_type == "Quant" or n.op_type == "BipolarQuant":
node_inp_inits = list(map(lambda x: model.get_initializer(x), n.input)) node_inp_inits = list(map(lambda x: model.get_initializer(x), n.input))
node_inp_dyn = list(filter(lambda x: x is None, node_inp_inits)) node_inp_dyn = list(filter(lambda x: x is None, node_inp_inits))
node_out = n.output[0] node_out = n.output[0]
......
...@@ -161,7 +161,7 @@ class QuantActBaseHandler(ABC): ...@@ -161,7 +161,7 @@ class QuantActBaseHandler(ABC):
# behind the MultiTrheshold nodes. # behind the MultiTrheshold nodes.
bias_scalar = adder_bias.shape == (1,) or len(adder_bias.shape) == 0 bias_scalar = adder_bias.shape == (1,) or len(adder_bias.shape) == 0
scale_scalar = mul_scale.shape == (1,) or len(mul_scale.shape) == 0 scale_scalar = mul_scale.shape == (1,) or len(mul_scale.shape) == 0
if scale_scalar and bias_scalar and self._q_node.op_type == "BinaryQuant": if scale_scalar and bias_scalar and self._q_node.op_type == "BipolarQuant":
# Get Quant parameters # Get Quant parameters
mul_scale = np.atleast_1d(mul_scale) mul_scale = np.atleast_1d(mul_scale)
# ONNX only accepts 64bit floats as attributes # ONNX only accepts 64bit floats as attributes
...@@ -305,7 +305,7 @@ class QuantReluHandler(QuantActBaseHandler): ...@@ -305,7 +305,7 @@ class QuantReluHandler(QuantActBaseHandler):
"Only Quant nodes with zero-point == 0 " "Only Quant nodes with zero-point == 0 "
"are currently supported for ReLu activations." "are currently supported for ReLu activations."
) )
elif self._q_node.op_type == "BinaryQuant": elif self._q_node.op_type == "BipolarQuant":
return return
else: else:
raise RuntimeError("Got an unexpected quantizer node type") raise RuntimeError("Got an unexpected quantizer node type")
...@@ -321,11 +321,13 @@ class QuantReluHandler(QuantActBaseHandler): ...@@ -321,11 +321,13 @@ class QuantReluHandler(QuantActBaseHandler):
# Gather parameters # Gather parameters
if self._q_node.op_type == "Quant": if self._q_node.op_type == "Quant":
bit_width = self._model.get_initializer(self._q_node.input[3]) bit_width = self._model.get_initializer(self._q_node.input[3])
elif self._q_node.op_type == "BinaryQuant": elif self._q_node.op_type == "BipolarQuant":
bit_width = 1.0 bit_width = 1.0
else: else:
raise RuntimeError("Got an unexpected quantizer node type") raise RuntimeError("Got an unexpected quantizer node type")
quant_scale = self._model.get_initializer(self._q_node.input[1]) quant_scale = self._model.get_initializer(self._q_node.input[1]).astype(
np.float32
)
# q_inst = getCustomOp(self._q_node) # q_inst = getCustomOp(self._q_node)
# narrow = q_inst.get_nodeattr("narrow") # narrow = q_inst.get_nodeattr("narrow")
...@@ -334,11 +336,11 @@ class QuantReluHandler(QuantActBaseHandler): ...@@ -334,11 +336,11 @@ class QuantReluHandler(QuantActBaseHandler):
# onnx/finn/handler/act.py#L21 # onnx/finn/handler/act.py#L21
num_distinct_values = 2 ** bit_width num_distinct_values = 2 ** bit_width
num_thresholds = int(num_distinct_values - 1) num_thresholds = int(num_distinct_values - 1)
flat_scale = quant_scale.flatten() flat_scale = quant_scale.flatten().astype(np.float32)
num_scale_channels = flat_scale.shape[0] num_scale_channels = flat_scale.shape[0]
step = np.abs(flat_scale) step = np.abs(flat_scale).astype(np.float32)
min_threshold = step / 2 min_threshold = step / 2
thresholds = np.empty((num_scale_channels, num_thresholds)) thresholds = np.empty((num_scale_channels, num_thresholds)).astype(np.float32)
for c in range(num_scale_channels): for c in range(num_scale_channels):
for t in range(num_thresholds): for t in range(num_thresholds):
thresholds[c][t] = min_threshold[c] + step[c] * t thresholds[c][t] = min_threshold[c] + step[c] * t
...@@ -349,6 +351,7 @@ class QuantReluHandler(QuantActBaseHandler): ...@@ -349,6 +351,7 @@ class QuantReluHandler(QuantActBaseHandler):
if thresholds.shape != final_shape: if thresholds.shape != final_shape:
thresholds = np.broadcast_to(thresholds, final_shape) thresholds = np.broadcast_to(thresholds, final_shape)
print(f"{thresholds.dtype=}")
return thresholds return thresholds
def _calculate_act_scale(self): def _calculate_act_scale(self):
...@@ -409,7 +412,7 @@ class QuantIdentityHandler(QuantActBaseHandler): ...@@ -409,7 +412,7 @@ class QuantIdentityHandler(QuantActBaseHandler):
"Only Quant nodes with zero-point == 0 " "Only Quant nodes with zero-point == 0 "
"are currently supported for identity activations." "are currently supported for identity activations."
) )
elif self._q_node.op_type == "BinaryQuant": elif self._q_node.op_type == "BipolarQuant":
quant_scale = self._model.get_initializer(self._q_node.input[1]) quant_scale = self._model.get_initializer(self._q_node.input[1])
if (quant_scale.flatten().shape[0] != 1) or quant_scale.flatten()[0] != 1.0: if (quant_scale.flatten().shape[0] != 1) or quant_scale.flatten()[0] != 1.0:
raise ValueError( raise ValueError(
...@@ -425,7 +428,7 @@ class QuantIdentityHandler(QuantActBaseHandler): ...@@ -425,7 +428,7 @@ class QuantIdentityHandler(QuantActBaseHandler):
if self._q_node.op_type == "Quant": if self._q_node.op_type == "Quant":
bit_width = self._model.get_initializer(self._q_node.input[3]) bit_width = self._model.get_initializer(self._q_node.input[3])
narrow = q_inst.get_nodeattr("narrow") narrow = q_inst.get_nodeattr("narrow")
elif self._q_node.op_type == "BinaryQuant": elif self._q_node.op_type == "BipolarQuant":
bit_width = 1.0 bit_width = 1.0
else: else:
raise RuntimeError("Got an unexpected quantizer node type") raise RuntimeError("Got an unexpected quantizer node type")
...@@ -449,7 +452,7 @@ class QuantIdentityHandler(QuantActBaseHandler): ...@@ -449,7 +452,7 @@ class QuantIdentityHandler(QuantActBaseHandler):
if self._q_node.op_type == "Quant": if self._q_node.op_type == "Quant":
bit_width = self._model.get_initializer(self._q_node.input[3]) bit_width = self._model.get_initializer(self._q_node.input[3])
narrow = q_inst.get_nodeattr("narrow") narrow = q_inst.get_nodeattr("narrow")
elif self._q_node.op_type == "BinaryQuant": elif self._q_node.op_type == "BipolarQuant":
bit_width = 1.0 bit_width = 1.0
else: else:
raise RuntimeError("Got an unexpected quantizer node type") raise RuntimeError("Got an unexpected quantizer node type")
...@@ -496,7 +499,7 @@ class QuantIdentityHandler(QuantActBaseHandler): ...@@ -496,7 +499,7 @@ class QuantIdentityHandler(QuantActBaseHandler):
# Gather parameters # Gather parameters
if self._q_node.op_type == "Quant": if self._q_node.op_type == "Quant":
bit_width = self._model.get_initializer(self._q_node.input[3]) bit_width = self._model.get_initializer(self._q_node.input[3])
elif self._q_node.op_type == "BinaryQuant": elif self._q_node.op_type == "BipolarQuant":
bit_width = 1.0 bit_width = 1.0
else: else:
raise RuntimeError("Got an unexpected quantizer node type") raise RuntimeError("Got an unexpected quantizer node type")
......
...@@ -38,7 +38,7 @@ class ConvertQuantActToMultiThreshold(Transformation): ...@@ -38,7 +38,7 @@ class ConvertQuantActToMultiThreshold(Transformation):
Converts Quant nodes in the activation path to MultiThreshold nodes. Converts Quant nodes in the activation path to MultiThreshold nodes.
The optional keyword arguments `max_multithreshold_bit_width` and `filter_lambda` The optional keyword arguments `max_multithreshold_bit_width` and `filter_lambda`
present a way to control which Quant and BinaryQuant nodes in the activation path present a way to control which Quant and BipolarQuant nodes in the activation path
are converted to MultiThreshold nodes. are converted to MultiThreshold nodes.
The filters which are represented by `max_multithreshold_bit_width` and The filters which are represented by `max_multithreshold_bit_width` and
`filter_lambda` are internally connected by an `AND` operation. A warning `filter_lambda` are internally connected by an `AND` operation. A warning
...@@ -70,7 +70,7 @@ class ConvertQuantActToMultiThreshold(Transformation): ...@@ -70,7 +70,7 @@ class ConvertQuantActToMultiThreshold(Transformation):
for n in graph.node: for n in graph.node:
node_ind += 1 node_ind += 1
if n.op_type == "Quant" or n.op_type == "BinaryQuant": if n.op_type == "Quant" or n.op_type == "BipolarQuant":
# Check that the node is in the activation path # Check that the node is in the activation path
inp = model.get_initializer(n.input[0]) inp = model.get_initializer(n.input[0])
out = model.get_initializer(n.output[0]) out = model.get_initializer(n.output[0])
...@@ -83,7 +83,7 @@ class ConvertQuantActToMultiThreshold(Transformation): ...@@ -83,7 +83,7 @@ class ConvertQuantActToMultiThreshold(Transformation):
predecessor_op_type = predecessor predecessor_op_type = predecessor
if model.is_fork_node(n): if model.is_fork_node(n):
raise ValueError( raise ValueError(
"Forking Quant/BinaryQuant nodes are currently " "Forking Quant/BipolarQuant nodes are currently "
"not supported by FINN." "not supported by FINN."
) )
if n.op_type == "Quant" and not model.get_initializer(n.input[2]) == 0: if n.op_type == "Quant" and not model.get_initializer(n.input[2]) == 0:
...@@ -94,7 +94,7 @@ class ConvertQuantActToMultiThreshold(Transformation): ...@@ -94,7 +94,7 @@ class ConvertQuantActToMultiThreshold(Transformation):
# Check if the bit width is low enough # Check if the bit width is low enough
if n.op_type == "Quant": if n.op_type == "Quant":
bit_width = model.get_initializer(n.input[3]) bit_width = model.get_initializer(n.input[3])
elif n.op_type == "BinaryQuant": elif n.op_type == "BipolarQuant":
bit_width = 1.0 bit_width = 1.0
else: else:
raise RuntimeError("Got an unexpected quantizer node type") raise RuntimeError("Got an unexpected quantizer node type")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment