Skip to content
Snippets Groups Projects
Commit 0207bc3d authored by Hendrik Borras
Browse files

Added support for folding Quant weights into Add nodes in FoldQuantWeights.

parent 9b5a7f9f
No related branches found
No related tags found
No related merge requests found
......@@ -103,13 +103,12 @@ class FoldQuantWeights(Transformation):
# Check if the datatype can be directly constant folded
dtype = model.get_tensor_datatype(n.output[0])
if "SCALED" in dtype.name:
# Move the scale factor behind the next operator
scale = model.get_initializer(n.input[1])
model.set_initializer(node_out, q_node_output / scale)
new_dtype = DataType[dtype.name.replace("SCALED", "")]
model.set_tensor_datatype(node_out, new_dtype)
# Reshape scale for Conv if required
if model.is_fork_node(n):
raise RuntimeError(
"Weights quantized with the Quant node are not "
"allowed to be join nodes node."
)
target_node = model.find_direct_successors(n)
if target_node is None:
raise RuntimeError(
......@@ -119,48 +118,94 @@ class FoldQuantWeights(Transformation):
else:
target_node = target_node[0]
if target_node.op_type == "Conv" and len(scale.shape) > 0:
bias_shape = [1] * len(scale.shape)
bias_shape[1] = -1
scale = scale.reshape(bias_shape)
# Check next operator type
# ToDo: CHECK what this is merged into:
# Conv, MatMul and Mul nodes need
# only the extra multiplication behind
# BUT Add needs an extra Div in front!
mul_like_nodes = ["Mul", "Div", "Conv", "MatMul"]
add_like_nodes = ["Add", "Sub"]
if scale.shape == (1,):
scale = scale[0]
mul_shape = tuple()
else:
mul_shape = scale.shape
mul_tensor = helper.make_tensor_value_info(
model.make_new_valueinfo_name(),
TensorProto.FLOAT,
mul_shape,
)
graph.value_info.append(mul_tensor)
model.set_initializer(mul_tensor.name, scale)
if (
target_node.op_type in mul_like_nodes
or target_node.op_type in add_like_nodes
):
# Move the scale factor behind the next operator
scale = model.get_initializer(n.input[1])
model.set_initializer(node_out, q_node_output / scale)
new_dtype = DataType[dtype.name.replace("SCALED", "")]
model.set_tensor_datatype(node_out, new_dtype)
successor = model.find_consumers(node_out)
if successor is None:
raise RuntimeError(
"Can only constant fold scaled Quant weights "
"if a successor exists."
if target_node.op_type == "Conv" and len(scale.shape) > 0:
bias_shape = [1] * len(scale.shape)
bias_shape[1] = -1
scale = scale.reshape(bias_shape)
if scale.shape == (1,):
scale = scale[0]
mul_shape = tuple()
else:
mul_shape = scale.shape
mul_tensor = helper.make_tensor_value_info(
model.make_new_valueinfo_name(),
TensorProto.FLOAT,
mul_shape,
)
successor = successor[0]
mul_output_name = successor.output[0]
output_shape = model.get_tensor_shape(successor.output[0])
act_mul_tensor = helper.make_tensor_value_info(
model.make_new_valueinfo_name(),
TensorProto.FLOAT,
output_shape,
)
graph.value_info.append(act_mul_tensor)
successor.output[0] = act_mul_tensor.name
graph.value_info.append(mul_tensor)
model.set_initializer(mul_tensor.name, scale)
successor = model.find_consumers(node_out)
if successor is None:
raise RuntimeError(
"Can only constant fold scaled Quant weights "
"if a successor exists."
)
successor = successor[0]
succ_output_name = successor.output[0]
output_shape = model.get_tensor_shape(successor.output[0])
act_mul_tensor = helper.make_tensor_value_info(
model.make_new_valueinfo_name(),
TensorProto.FLOAT,
output_shape,
)
graph.value_info.append(act_mul_tensor)
successor.output[0] = act_mul_tensor.name
mul_node = helper.make_node(
"Mul",
[act_mul_tensor.name, mul_tensor.name],
[succ_output_name],
)
graph.node.insert(node_ind, mul_node)
if target_node.op_type in add_like_nodes:
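# For Add-like nodes, also divide the successor's first input by the
# scale in front of the operator; the Mul inserted behind the operator
# then multiplies the result by the scale again.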
div_tensor = helper.make_tensor_value_info(
model.make_new_valueinfo_name(),
TensorProto.FLOAT,
mul_shape,
)
graph.value_info.append(div_tensor)
model.set_initializer(div_tensor.name, scale)
succ_input_name = successor.input[0]
act_mul_tensor = helper.make_tensor_value_info(
model.make_new_valueinfo_name(),
TensorProto.FLOAT,
output_shape,
)
graph.value_info.append(act_mul_tensor)
successor.input[0] = act_mul_tensor.name
div_node = helper.make_node(
"Div",
[succ_input_name, div_tensor.name],
[act_mul_tensor.name],
)
graph.node.insert(node_ind, div_node)
mul_node = helper.make_node(
"Mul",
[act_mul_tensor.name, mul_tensor.name],
[mul_output_name],
)
graph.node.insert(node_ind, mul_node)
else:
# use the execution result as an initializer
model.set_initializer(node_out, q_node_output)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment