From 28721620f08ade4e5b672c3731cff55dbceeac7d Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <yamanu@xilinx.com> Date: Mon, 4 Nov 2019 17:38:33 +0000 Subject: [PATCH] [Transform] restrict absorb_mul_into_multi_threshold dim restriction: mul must be scalar or 1d val restriction: mul value must be non-negative --- src/finn/transformation/streamline.py | 40 +++++++++++++++------------ 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/src/finn/transformation/streamline.py b/src/finn/transformation/streamline.py index 04c7f05b3..34e926e5b 100644 --- a/src/finn/transformation/streamline.py +++ b/src/finn/transformation/streamline.py @@ -260,30 +260,34 @@ def absorb_add_into_multi_threshold(model): def absorb_mul_into_multi_threshold(model): """Absorb preceding Mul ops into MultiThreshold by updating the threshold - values.""" + values. Only *positive* scalar/1D vectors can be absorbed.""" graph = model.graph node_ind = 0 graph_modified = False for n in graph.node: node_ind += 1 if n.op_type == "Mul": + mul_weight_name = n.input[1] + A = model.get_initializer(mul_weight_name) + assert A is not None + is_signed = (A < 0).any() + is_scalar = np.prod(A.shape) == 1 + is_1d = len(A.shape) == 2 and A.shape[0] == 1 consumer = model.find_consumer(n.output[0]) if consumer is not None and consumer.op_type == "MultiThreshold": - mul_weight_name = n.input[1] - threshold_name = consumer.input[1] - A = model.get_initializer(mul_weight_name) - T = model.get_initializer(threshold_name) - assert A is not None - assert T is not None - start_name = n.input[0] - # compute new thresholds and set initializer - Tnew = T / A.reshape(-1, T.shape[1]) - # TODO: need to handle negative A values correctly; produce - # mul sign mask and merge into preceding matmul? - model.set_initializer(threshold_name, Tnew) - # wire add input directly to MultiThreshold - consumer.input[0] = start_name - # remove the mul node - graph.node.remove(n) - graph_modified = True + if not is_signed and (is_1d or is_scalar): + threshold_name = consumer.input[1] + T = model.get_initializer(threshold_name) + assert T is not None + start_name = n.input[0] + # compute new thresholds and set initializer + Tnew = T / A.reshape(-1, T.shape[1]) + # TODO: need to handle negative A values correctly; produce + # mul sign mask and merge into preceding matmul? + model.set_initializer(threshold_name, Tnew) + # wire add input directly to MultiThreshold + consumer.input[0] = start_name + # remove the mul node + graph.node.remove(n) + graph_modified = True return (model, graph_modified) -- GitLab