Skip to content
Snippets Groups Projects
Commit b5bcd320 authored by Hendrik Borras's avatar Hendrik Borras
Browse files

Catch unsupported constant folding for SCALED datatypes.

parent 0207bc3d
No related branches found
No related tags found
No related merge requests found
......@@ -119,92 +119,94 @@ class FoldQuantWeights(Transformation):
target_node = target_node[0]
# Check next operator type
# TODO: check what this is merged into:
# Conv, MatMul and Mul nodes need
# only the extra multiplication behind
# BUT Add needs an extra Div in front!
mul_like_nodes = ["Mul", "Div", "Conv", "MatMul"]
add_like_nodes = ["Add", "Sub"]
all_supported_ops = mul_like_nodes.copy()
all_supported_ops.extend(add_like_nodes)
if (
target_node.op_type in mul_like_nodes
or target_node.op_type in add_like_nodes
):
# Move the scale factor behind the next operator
scale = model.get_initializer(n.input[1])
model.set_initializer(node_out, q_node_output / scale)
new_dtype = DataType[dtype.name.replace("SCALED", "")]
model.set_tensor_datatype(node_out, new_dtype)
if target_node.op_type not in all_supported_ops:
raise ValueError(
f"Can't constant fold Quant weight node "
f"into node type {target_node.op_type} "
f"at node: {target_node}."
)
# For both Mul and Add:
# Move the scale factor behind the next operator
scale = model.get_initializer(n.input[1])
model.set_initializer(node_out, q_node_output / scale)
new_dtype = DataType[dtype.name.replace("SCALED", "")]
model.set_tensor_datatype(node_out, new_dtype)
if target_node.op_type == "Conv" and len(scale.shape) > 0:
bias_shape = [1] * len(scale.shape)
bias_shape[1] = -1
scale = scale.reshape(bias_shape)
if scale.shape == (1,):
scale = scale[0]
mul_shape = tuple()
else:
mul_shape = scale.shape
mul_tensor = helper.make_tensor_value_info(
model.make_new_valueinfo_name(),
TensorProto.FLOAT,
mul_shape,
)
graph.value_info.append(mul_tensor)
model.set_initializer(mul_tensor.name, scale)
successor = model.find_consumers(node_out)
if successor is None:
raise RuntimeError(
"Can only constant fold scaled Quant weights "
"if a successor exists."
)
successor = successor[0]
succ_output_name = successor.output[0]
output_shape = model.get_tensor_shape(successor.output[0])
act_mul_tensor = helper.make_tensor_value_info(
model.make_new_valueinfo_name(),
TensorProto.FLOAT,
output_shape,
)
graph.value_info.append(act_mul_tensor)
successor.output[0] = act_mul_tensor.name
if target_node.op_type == "Conv" and len(scale.shape) > 0:
bias_shape = [1] * len(scale.shape)
bias_shape[1] = -1
scale = scale.reshape(bias_shape)
mul_node = helper.make_node(
"Mul",
[act_mul_tensor.name, mul_tensor.name],
[succ_output_name],
)
graph.node.insert(node_ind, mul_node)
if scale.shape == (1,):
scale = scale[0]
mul_shape = tuple()
else:
mul_shape = scale.shape
mul_tensor = helper.make_tensor_value_info(
if target_node.op_type in add_like_nodes:
# Additionally move the scale factor in front of
# the next operator
div_tensor = helper.make_tensor_value_info(
model.make_new_valueinfo_name(),
TensorProto.FLOAT,
mul_shape,
)
graph.value_info.append(mul_tensor)
model.set_initializer(mul_tensor.name, scale)
graph.value_info.append(div_tensor)
model.set_initializer(div_tensor.name, scale)
successor = model.find_consumers(node_out)
if successor is None:
raise RuntimeError(
"Can only constant fold scaled Quant weights "
"if a successor exists."
)
successor = successor[0]
succ_output_name = successor.output[0]
output_shape = model.get_tensor_shape(successor.output[0])
succ_input_name = successor.input[0]
act_mul_tensor = helper.make_tensor_value_info(
model.make_new_valueinfo_name(),
TensorProto.FLOAT,
output_shape,
)
graph.value_info.append(act_mul_tensor)
successor.output[0] = act_mul_tensor.name
successor.input[0] = act_mul_tensor.name
mul_node = helper.make_node(
"Mul",
[act_mul_tensor.name, mul_tensor.name],
[succ_output_name],
div_node = helper.make_node(
"Div",
[succ_input_name, div_tensor.name],
[act_mul_tensor.name],
)
graph.node.insert(node_ind, mul_node)
if target_node.op_type in add_like_nodes:
# Additionally move the scale factor in front of
# the next operator
div_tensor = helper.make_tensor_value_info(
model.make_new_valueinfo_name(),
TensorProto.FLOAT,
mul_shape,
)
graph.value_info.append(div_tensor)
model.set_initializer(div_tensor.name, scale)
succ_input_name = successor.input[0]
act_mul_tensor = helper.make_tensor_value_info(
model.make_new_valueinfo_name(),
TensorProto.FLOAT,
output_shape,
)
graph.value_info.append(act_mul_tensor)
successor.input[0] = act_mul_tensor.name
div_node = helper.make_node(
"Div",
[succ_input_name, div_tensor.name],
[act_mul_tensor.name],
)
graph.node.insert(node_ind, div_node)
graph.node.insert(node_ind, div_node)
else:
# use the execution result as an initializer
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment