From fcc457e5a020694329ca8da67e43043c77d43253 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <yamanu@xilinx.com> Date: Fri, 22 Nov 2019 19:09:52 +0000 Subject: [PATCH] [Transform] reiterated fixes for AbsorbAdd/Mul --- src/finn/transformation/streamline/absorb.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/finn/transformation/streamline/absorb.py b/src/finn/transformation/streamline/absorb.py index 0a0c8e301..6806137e0 100644 --- a/src/finn/transformation/streamline/absorb.py +++ b/src/finn/transformation/streamline/absorb.py @@ -26,8 +26,8 @@ class AbsorbAddIntoMultiThreshold(Transformation): assert T is not None start_name = n.input[0] # we can only absorb 0d or 1d adds - is_scalar = all(x == 1 for x in A.shape) - is_1d = np.prod(A.shape) == A.shape[-1] + is_scalar = A.ndim == 0 or all(x == 1 for x in A.shape) + is_1d = A.ndim > 0 and np.prod(A.shape) == A.shape[-1] if is_scalar or is_1d: Tnew = T - A.reshape(-1, 1) # Tnew = T - A.reshape(-1, T.shape[1]) @@ -56,8 +56,8 @@ class AbsorbMulIntoMultiThreshold(Transformation): A = model.get_initializer(mul_weight_name) assert A is not None is_signed = (A < 0).any() - is_scalar = np.prod(A.shape) == 1 - is_1d = len(A.shape) == 2 and A.shape[0] == 1 + is_scalar = A.ndim == 0 or all(x == 1 for x in A.shape) + is_1d = A.ndim > 0 and np.prod(A.shape) == A.shape[-1] consumer = model.find_consumer(n.output[0]) if consumer is not None and consumer.op_type == "MultiThreshold": if not is_signed and (is_1d or is_scalar): @@ -66,7 +66,7 @@ class AbsorbMulIntoMultiThreshold(Transformation): assert T is not None start_name = n.input[0] # compute new thresholds and set initializer - Tnew = T / A.reshape(-1, T.shape[1]) + Tnew = T / A.reshape(-1, 1) # TODO: need to handle negative A values correctly; produce # mul sign mask and merge into preceding matmul? model.set_initializer(threshold_name, Tnew) -- GitLab