Skip to content
Snippets Groups Projects
Unverified Commit 72bc4752 authored by Yaman Umuroglu's avatar Yaman Umuroglu Committed by GitHub
Browse files

[Transform] skip and not fail when const params not found

parent 403311d2
No related branches found
No related tags found
No related merge requests found
......@@ -67,8 +67,9 @@ class MoveAddPastMul(Transformation):
add_weight_name = n.input[1]
A = model.get_initializer(mul_weight_name)
B = model.get_initializer(add_weight_name)
assert A is not None, "Initializer for mul weights is not set."
assert B is not None, "Initializer for add weights is not set."
if (A is None) or (B is None):
warnings.warn("Mul or add does not have constant params, skipping")
continue
start_name = n.input[0]
middle_name = n.output[0]
end_name = consumer.output[0]
......@@ -123,8 +124,9 @@ class MoveScalarMulPastMatMul(Transformation):
matmul_weight_name = consumer.input[1]
A = model.get_initializer(mul_weight_name)
W = model.get_initializer(matmul_weight_name)
assert A is not None, "Initializer for mul weights is not set."
assert W is not None, "Initializer for matmul weights is not set."
if (A is None) or (W is None):
warnings.warn("MatMul or Mul params are not constant, skipping")
continue
start_name = n.input[0]
middle_name = n.output[0]
end_name = consumer.output[0]
......@@ -180,8 +182,9 @@ class MoveScalarAddPastMatMul(Transformation):
matmul_weight_name = consumer.input[1]
A = model.get_initializer(add_weight_name)
W = model.get_initializer(matmul_weight_name)
assert A is not None, "Initializer for add weights is not set."
assert W is not None, "Initializer for matmul weights is not set."
if (A is None) or (W is None):
warnings.warn("MatMul or Add params are not constant, skipping")
continue
start_name = n.input[0]
middle_name = n.output[0]
end_name = consumer.output[0]
......@@ -242,7 +245,9 @@ class MoveScalarAddPastConv(Transformation):
conv_in_name = consumer.input[0]
conv_in_shape = model.get_tensor_shape(conv_in_name)
A = model.get_initializer(add_weight_name)
assert A is not None, "Initializer for add weights is not set."
if A is None:
warnings.warn("Add param is not constant, skipping")
continue
start_name = n.input[0]
end_name = consumer.output[0]
conv_out_shape = model.get_tensor_shape(end_name)
......@@ -310,7 +315,9 @@ class MoveScalarMulPastConv(Transformation):
):
mul_weight_name = n.input[1]
A = model.get_initializer(mul_weight_name)
assert A is not None, "Initializer for mul weights is not set."
if A is None:
warnings.warn("Mul param is not constant, skipping")
continue
conv_node = consumer
mul_node = n
start_name = mul_node.input[0]
......@@ -621,7 +628,9 @@ class MoveTransposePastScalarMul(Transformation):
):
mul_weight_name = consumer.input[1]
A = model.get_initializer(mul_weight_name)
assert A is not None, "Initializer for mul weights is not set."
if A is None:
warnings.warn("Mul param is not constant, skipping")
continue
transp_node = n
mul_node = consumer
start_name = transp_node.input[0]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment