From d77dd7fe0de8fe473cd04d76ddde1877967a8c14 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <yamanu@xilinx.com> Date: Tue, 10 Mar 2020 18:17:29 +0000 Subject: [PATCH] [Streamline] better 1d checking for absorbption --- src/finn/transformation/streamline/absorb.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/finn/transformation/streamline/absorb.py b/src/finn/transformation/streamline/absorb.py index 5a9355bcf..68e013a6d 100644 --- a/src/finn/transformation/streamline/absorb.py +++ b/src/finn/transformation/streamline/absorb.py @@ -55,7 +55,8 @@ class AbsorbAddIntoMultiThreshold(Transformation): start_name = n.input[0] # we can only absorb 0d or 1d adds is_scalar = A.ndim == 0 or all(x == 1 for x in A.shape) - is_1d = A.ndim > 0 and np.prod(A.shape) == A.shape[-1] + actual_ndims = len(tuple(filter(lambda x: x > 1, A.shape))) + is_1d = actual_ndims == 1 if is_scalar or is_1d: Tnew = T - A.reshape(-1, 1) # Tnew = T - A.reshape(-1, T.shape[1]) @@ -85,7 +86,8 @@ class AbsorbMulIntoMultiThreshold(Transformation): assert A is not None, "Initializer for mul weights is not set." is_signed = (A < 0).any() is_scalar = A.ndim == 0 or all(x == 1 for x in A.shape) - is_1d = A.ndim > 0 and np.prod(A.shape) == A.shape[-1] + actual_ndims = len(tuple(filter(lambda x: x > 1, A.shape))) + is_1d = actual_ndims == 1 consumer = model.find_consumer(n.output[0]) if consumer is not None and consumer.op_type == "MultiThreshold": if not is_signed and (is_1d or is_scalar): @@ -122,7 +124,8 @@ class FactorOutMulSignMagnitude(Transformation): A = model.get_initializer(mul_weight_name) assert A is not None, "Initializer for mul weights is not set." is_scalar = np.prod(A.shape) == 1 - is_1d = len(A.shape) == 2 and A.shape[0] == 1 + actual_ndims = len(tuple(filter(lambda x: x > 1, A.shape))) + is_1d = actual_ndims == 1 is_not_bipolar = ( model.get_tensor_datatype(mul_weight_name) != DataType.BIPOLAR ) -- GitLab