diff --git a/src/finn/transformation/streamline/absorb.py b/src/finn/transformation/streamline/absorb.py
index 5a9355bcfefc57f9608490a156906e65f7672271..68e013a6d5bd2191c9b54c84646022f3b491ec60 100644
--- a/src/finn/transformation/streamline/absorb.py
+++ b/src/finn/transformation/streamline/absorb.py
@@ -55,7 +55,8 @@ class AbsorbAddIntoMultiThreshold(Transformation):
                     start_name = n.input[0]
                     # we can only absorb 0d or 1d adds
                     is_scalar = A.ndim == 0 or all(x == 1 for x in A.shape)
-                    is_1d = A.ndim > 0 and np.prod(A.shape) == A.shape[-1]
+                    actual_ndims = len(tuple(filter(lambda x: x > 1, A.shape)))
+                    is_1d = actual_ndims == 1
                     if is_scalar or is_1d:
                         Tnew = T - A.reshape(-1, 1)
                         # Tnew = T - A.reshape(-1, T.shape[1])
@@ -85,7 +86,8 @@ class AbsorbMulIntoMultiThreshold(Transformation):
                 assert A is not None, "Initializer for mul weights is not set."
                 is_signed = (A < 0).any()
                 is_scalar = A.ndim == 0 or all(x == 1 for x in A.shape)
-                is_1d = A.ndim > 0 and np.prod(A.shape) == A.shape[-1]
+                actual_ndims = len(tuple(filter(lambda x: x > 1, A.shape)))
+                is_1d = actual_ndims == 1
                 consumer = model.find_consumer(n.output[0])
                 if consumer is not None and consumer.op_type == "MultiThreshold":
                     if not is_signed and (is_1d or is_scalar):
@@ -122,7 +124,8 @@ class FactorOutMulSignMagnitude(Transformation):
                 A = model.get_initializer(mul_weight_name)
                 assert A is not None, "Initializer for mul weights is not set."
                 is_scalar = np.prod(A.shape) == 1
-                is_1d = len(A.shape) == 2 and A.shape[0] == 1
+                actual_ndims = len(tuple(filter(lambda x: x > 1, A.shape)))
+                is_1d = actual_ndims == 1
                 is_not_bipolar = (
                     model.get_tensor_datatype(mul_weight_name) != DataType.BIPOLAR
                 )