Commit 3e7c1de6 authored by Hendrik Borras

Added support for constant folding Quant weight nodes with per-channel scaling for convolutions.

parent d8495f34
@@ -109,6 +109,21 @@ class FoldQuantWeights(Transformation):
new_dtype = DataType[dtype.name.replace("SCALED", "")]
model.set_tensor_datatype(node_out, new_dtype)
# Reshape scale for Conv if required
target_node = model.find_direct_successors(n)
if target_node is None:
    raise RuntimeError(
        "Weights quantized with the Quant node must have "
        "a successor node."
    )
else:
    target_node = target_node[0]
if target_node.op_type == "Conv" and len(scale.shape) > 0:
    bias_shape = [1] * len(scale.shape)
    bias_shape[1] = -1
    scale = scale.reshape(bias_shape)
if scale.shape == (1,):
    scale = scale[0]
    mul_shape = tuple()
......
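
The reshape in this hunk can be illustrated in isolation with plain NumPy. The sketch below is not part of the commit: the helper name `reshape_scale_for_conv_output`, the example shapes, and the `ndim < 2` guard are assumptions made for illustration. It assumes the per-channel scale arrives shaped like the Conv weights, i.e. `(OC, 1, 1, 1)`, and that it is later multiplied onto a Conv output in `(N, C, H, W)` layout, so the channel values have to sit on axis 1 to broadcast.

```python
import numpy as np


def reshape_scale_for_conv_output(scale):
    """Move a per-channel scale onto axis 1 so it broadcasts over an NCHW tensor.

    Hypothetical helper mirroring the reshape in the hunk above; the ndim
    guard is an added assumption, since axis 1 does not exist for scalars
    or 1-D arrays.
    """
    scale = np.asarray(scale)
    if scale.ndim < 2:
        return scale
    target_shape = [1] * scale.ndim
    target_shape[1] = -1  # all scale values end up on the channel axis
    return scale.reshape(target_shape)


# Assumed example: three output channels, scale shaped like Conv weights.
weight_scale = np.array([0.5, 0.25, 2.0], dtype=np.float32).reshape(3, 1, 1, 1)
conv_out = np.ones((2, 3, 4, 4), dtype=np.float32)  # N, C, H, W

out_scale = reshape_scale_for_conv_output(weight_scale)
scaled = conv_out * out_scale  # each channel is multiplied by its own scale
```

Without the reshape, `conv_out * weight_scale` would raise a broadcasting error here, because the leading dimensions (2 and 3) do not match.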