Skip to content
Snippets Groups Projects
Commit d77dd7fe authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Streamline] better 1d checking for absorbption

parent 3da413bf
No related branches found
No related tags found
No related merge requests found
......@@ -55,7 +55,8 @@ class AbsorbAddIntoMultiThreshold(Transformation):
start_name = n.input[0]
# we can only absorb 0d or 1d adds
is_scalar = A.ndim == 0 or all(x == 1 for x in A.shape)
is_1d = A.ndim > 0 and np.prod(A.shape) == A.shape[-1]
actual_ndims = len(tuple(filter(lambda x: x > 1, A.shape)))
is_1d = actual_ndims == 1
if is_scalar or is_1d:
Tnew = T - A.reshape(-1, 1)
# Tnew = T - A.reshape(-1, T.shape[1])
......@@ -85,7 +86,8 @@ class AbsorbMulIntoMultiThreshold(Transformation):
assert A is not None, "Initializer for mul weights is not set."
is_signed = (A < 0).any()
is_scalar = A.ndim == 0 or all(x == 1 for x in A.shape)
is_1d = A.ndim > 0 and np.prod(A.shape) == A.shape[-1]
actual_ndims = len(tuple(filter(lambda x: x > 1, A.shape)))
is_1d = actual_ndims == 1
consumer = model.find_consumer(n.output[0])
if consumer is not None and consumer.op_type == "MultiThreshold":
if not is_signed and (is_1d or is_scalar):
......@@ -122,7 +124,8 @@ class FactorOutMulSignMagnitude(Transformation):
A = model.get_initializer(mul_weight_name)
assert A is not None, "Initializer for mul weights is not set."
is_scalar = np.prod(A.shape) == 1
is_1d = len(A.shape) == 2 and A.shape[0] == 1
actual_ndims = len(tuple(filter(lambda x: x > 1, A.shape)))
is_1d = actual_ndims == 1
is_not_bipolar = (
model.get_tensor_datatype(mul_weight_name) != DataType.BIPOLAR
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment