Skip to content
Snippets Groups Projects
Unverified Commit 72bc4752 authored by Yaman Umuroglu's avatar Yaman Umuroglu Committed by GitHub
Browse files

[Transform] skip and not fail when const params not found

parent 403311d2
No related branches found
No related tags found
No related merge requests found
...@@ -67,8 +67,9 @@ class MoveAddPastMul(Transformation): ...@@ -67,8 +67,9 @@ class MoveAddPastMul(Transformation):
add_weight_name = n.input[1] add_weight_name = n.input[1]
A = model.get_initializer(mul_weight_name) A = model.get_initializer(mul_weight_name)
B = model.get_initializer(add_weight_name) B = model.get_initializer(add_weight_name)
assert A is not None, "Initializer for mul weights is not set." if (A is None) or (B is None):
assert B is not None, "Initializer for add weights is not set." warnings.warn("Mul or add does not have constant params, skipping")
continue
start_name = n.input[0] start_name = n.input[0]
middle_name = n.output[0] middle_name = n.output[0]
end_name = consumer.output[0] end_name = consumer.output[0]
...@@ -123,8 +124,9 @@ class MoveScalarMulPastMatMul(Transformation): ...@@ -123,8 +124,9 @@ class MoveScalarMulPastMatMul(Transformation):
matmul_weight_name = consumer.input[1] matmul_weight_name = consumer.input[1]
A = model.get_initializer(mul_weight_name) A = model.get_initializer(mul_weight_name)
W = model.get_initializer(matmul_weight_name) W = model.get_initializer(matmul_weight_name)
assert A is not None, "Initializer for mul weights is not set." if (A is None) or (W is None):
assert W is not None, "Initializer for matmul weights is not set." warnings.warn("MatMul or Mul params are not constant, skipping")
continue
start_name = n.input[0] start_name = n.input[0]
middle_name = n.output[0] middle_name = n.output[0]
end_name = consumer.output[0] end_name = consumer.output[0]
...@@ -180,8 +182,9 @@ class MoveScalarAddPastMatMul(Transformation): ...@@ -180,8 +182,9 @@ class MoveScalarAddPastMatMul(Transformation):
matmul_weight_name = consumer.input[1] matmul_weight_name = consumer.input[1]
A = model.get_initializer(add_weight_name) A = model.get_initializer(add_weight_name)
W = model.get_initializer(matmul_weight_name) W = model.get_initializer(matmul_weight_name)
assert A is not None, "Initializer for add weights is not set." if (A is None) or (W is None):
assert W is not None, "Initializer for matmul weights is not set." warnings.warn("MatMul or Add params are not constant, skipping")
continue
start_name = n.input[0] start_name = n.input[0]
middle_name = n.output[0] middle_name = n.output[0]
end_name = consumer.output[0] end_name = consumer.output[0]
...@@ -242,7 +245,9 @@ class MoveScalarAddPastConv(Transformation): ...@@ -242,7 +245,9 @@ class MoveScalarAddPastConv(Transformation):
conv_in_name = consumer.input[0] conv_in_name = consumer.input[0]
conv_in_shape = model.get_tensor_shape(conv_in_name) conv_in_shape = model.get_tensor_shape(conv_in_name)
A = model.get_initializer(add_weight_name) A = model.get_initializer(add_weight_name)
assert A is not None, "Initializer for add weights is not set." if A is None:
warnings.warn("Add param is not constant, skipping")
continue
start_name = n.input[0] start_name = n.input[0]
end_name = consumer.output[0] end_name = consumer.output[0]
conv_out_shape = model.get_tensor_shape(end_name) conv_out_shape = model.get_tensor_shape(end_name)
...@@ -310,7 +315,9 @@ class MoveScalarMulPastConv(Transformation): ...@@ -310,7 +315,9 @@ class MoveScalarMulPastConv(Transformation):
): ):
mul_weight_name = n.input[1] mul_weight_name = n.input[1]
A = model.get_initializer(mul_weight_name) A = model.get_initializer(mul_weight_name)
assert A is not None, "Initializer for mul weights is not set." if A is None:
warnings.warn("Mul param is not constant, skipping")
continue
conv_node = consumer conv_node = consumer
mul_node = n mul_node = n
start_name = mul_node.input[0] start_name = mul_node.input[0]
...@@ -621,7 +628,9 @@ class MoveTransposePastScalarMul(Transformation): ...@@ -621,7 +628,9 @@ class MoveTransposePastScalarMul(Transformation):
): ):
mul_weight_name = consumer.input[1] mul_weight_name = consumer.input[1]
A = model.get_initializer(mul_weight_name) A = model.get_initializer(mul_weight_name)
assert A is not None, "Initializer for mul weights is not set." if A is None:
warnings.warn("Mul param is not constant, skipping")
continue
transp_node = n transp_node = n
mul_node = consumer mul_node = consumer
start_name = transp_node.input[0] start_name = transp_node.input[0]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment