Skip to content
Snippets Groups Projects
Commit ff3d1413 authored by Hendrik Borras's avatar Hendrik Borras
Browse files

Reuse bias tensor when extracting the bias from Conv nodes.

parent 20fcaf52
No related branches found
No related tags found
No related merge requests found
......@@ -59,17 +59,12 @@ class ExtractBiasFromConv(Transformation):
# Insert bias as Add node behind the Conv node
out_shape = model.get_tensor_shape(n.output[0])
# Reshape bias tensor
add_shape = [1] * len(out_shape)
# ToDo: this must change to "add_shape[-1] = bias.shape[0]" when
# channels last comes around
add_shape[1] = bias.shape[0]
add_tensor = helper.make_tensor_value_info(
model.make_new_valueinfo_name(),
TensorProto.FLOAT,
add_shape,
)
graph.value_info.append(add_tensor)
model.set_initializer(add_tensor.name, bias.reshape(add_shape))
model.set_initializer(n.input[2], bias.reshape(add_shape))
act_add_tensor = helper.make_tensor_value_info(
model.make_new_valueinfo_name(),
......@@ -80,7 +75,7 @@ class ExtractBiasFromConv(Transformation):
add_node = helper.make_node(
"Add",
[act_add_tensor.name, add_tensor.name],
[act_add_tensor.name, n.input[2]],
[n.output[0]],
)
graph.node.insert(node_ind, add_node)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment